In [53]:
import pandas as pd
import lightgbm as lgb
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt

In [54]:
from s4_reg.core import s4regressor as regressor
from s4_reg.src_dataloaders_original import StandardScaler
from s4_reg.src_utils_visualize import prediction_result as post

In [55]:
import pickle

# Load the prepared input data; it is indexed by target symbol further down
# (dict_prepared_data[target]) when the features are built.
# NOTE(review): pickle.load can execute arbitrary code while unpickling --
# only point this at a file from a trusted source.
with open("tests/prepared_data_all.pkl", "rb") as tf:
    dict_prepared_data = pickle.load(tf)

def make_s4_features(
    data,
    target,
    seq_len_target=180,
    pred_len_target=1,
    d_model_target=10, # 2048
    seq_len_others=10,
    pred_len_others=1,
    d_model_others=10
    ):
    """Build the feature table and the matching target series for one symbol.

    Two ``regressor`` models are created on the same frame and their
    ``get_features`` outputs are joined column-wise:

    * a model with ``features='S'`` using the ``*_target`` settings, whose
      ``target`` column is split off and returned separately;
    * a model with ``features='MS'`` using the ``*_others`` settings, whose
      remaining columns are renamed ``exog_feat_1 .. exog_feat_n``.

    NOTE(review): 'S' / 'MS' presumably select target-only vs. multi-input
    modes of the s4_reg regressor -- confirm against that package.

    Parameters
    ----------
    data : DataFrame that exposes a ``'Date'`` column after ``reset_index()``
        (so ``'Date'`` is either a column or the index name). The caller's
        object is not modified; ``reset_index()`` returns a new frame.
    target : name of the column passed to ``regressor`` as ``target`` and
        extracted from the feature output.
    seq_len_target, pred_len_target : passed as ``size`` to the 'S' model.
    d_model_target : passed as ``d_model`` to the 'S' model.
    seq_len_others, pred_len_others : passed as ``size`` to the 'MS' model.
    d_model_others : passed as ``d_model`` to the 'MS' model.

    Returns
    -------
    (features, stock) : the joined feature DataFrame and a one-column
        DataFrame holding the ``target`` column, each with its last row
        removed.

    Raises
    ------
    AssertionError
        If ``seq_len_target`` is not strictly greater than ``seq_len_others``.
    """
    # Work on a copy with a parsed 'date' column appended and 'Date' removed.
    frame = data.reset_index()
    frame = frame.assign(date=pd.to_datetime(frame['Date'])).drop(columns=['Date'])

    assert seq_len_target > seq_len_others

    def _extract(window, horizon, mode, width):
        # Build one regressor on the shared frame and return its feature table.
        model = regressor(
            dataset=frame,
            target=target,
            size=[window, horizon],
            features=mode,
            d_model=width,
            device='cpu',
        )
        return model.get_features(frame)

    # Target-only model: keep the target column aside, the rest are features.
    target_table = _extract(seq_len_target, pred_len_target, 'S', d_model_target)
    target_series = target_table[target]
    target_feats = target_table.drop(columns=[target])

    # Second model: skip the leading rows so it starts where the first one does.
    # NOTE(review): assumes the shorter window yields exactly
    # seq_len_target - seq_len_others extra leading rows -- confirm.
    offset = seq_len_target - seq_len_others
    exog_table = _extract(seq_len_others, pred_len_others, 'MS', d_model_others)
    exog_feats = exog_table.iloc[offset:, :].drop(columns=[target])
    exog_feats.columns = [
        f'exog_feat_{k}' for k in range(1, exog_feats.shape[1] + 1)
    ]

    combined = pd.concat([target_feats, exog_feats], axis=1)

    # Both outputs drop their final row.
    return combined.iloc[:-1, :], pd.DataFrame(target_series.iloc[:-1])

# Read the target symbols, one per line. Blank lines are skipped so that a
# stray empty line does not turn into a '' key lookup in dict_prepared_data.
with open('./target_symbols_all.txt', 'r') as f:
    targets = [line.strip() for line in f if line.strip()]

# Build the features for every symbol. Feature tables are kept per symbol in a
# dict; the one-column target frames are collected and joined column-wise once
# after the loop instead of re-concatenating on every iteration.
features = {}
stock_frames = []
for target in targets:
    feat, stock = make_s4_features(dict_prepared_data[target], target)
    features[target] = feat
    stock_frames.append(stock)

# pd.concat raises on an empty list, so fall back to an empty frame; this also
# guarantees stock_data is always (re)defined by this cell.
stock_data = pd.concat(stock_frames, axis=1) if stock_frames else pd.DataFrame()
    

In [56]:
# Persist the results: the per-symbol feature dict as a pickle, the combined
# target table as CSV (the index is written as the first, unnamed column).
pd.to_pickle(features, './data/features_all.pkl')
pd.DataFrame(stock_data).to_csv('./data/stock_data_all.csv')

In [57]:
# Reload the saved results, replacing the in-memory objects of the same name.
# 'Unnamed: 0' is the header pandas gives the unnamed index column that
# to_csv wrote above; it is restored as the index here.
features = pd.read_pickle('./data/features_all.pkl')
stock_data = pd.read_csv('./data/stock_data_all.csv', index_col='Unnamed: 0')

In [58]:
# Show the reloaded objects: the feature dict first, then the target table.
display(features)
display(stock_data)

{'4584.T':           feat_1      feat_2     feat_3      feat_4      feat_5      feat_6  \
 180   724.858215  121.349426  31.804905 -340.099091  427.807434 -443.279755   
 181   715.975830  118.337715  32.622684 -333.855835  420.758728 -437.826019   
 182   771.152954  126.490562  34.313980 -359.629700  457.323975 -471.962067   
 183   753.990662  126.658371  34.267521 -351.010376  445.800079 -462.006989   
 184   756.594055  126.939079  33.697132 -352.581238  449.329559 -464.032410   
 ...          ...         ...        ...         ...         ...         ...   
 1299  223.451508   39.885864  11.101674 -102.470436  138.206482 -143.450790   
 1300  230.035980   42.742805   9.579368 -105.303726  140.051514 -141.415268   
 1301  226.375488   37.585320  12.369206 -103.315849  140.573196 -144.255249   
 1302  230.569122   39.036400  13.432136 -105.590363  142.691833 -142.796387   
 1303  226.544662   41.507702  13.038601 -102.488571  138.436844 -143.210434   
 
           feat_7      feat_

Unnamed: 0,4584.T,1557.T,8789.T,1893.T,7974.T,4661.T,6232.T,5406.T,5201.T,6758.T,9006.T,1360.T,1579.T,MSFT,GOOG,IBM
180,749.0,28960.0,138.0,540.45746,2619.020,10995.1220,3600.0,737.23270,3107.2370,5196.5960,1779.1500,3125.0,17640.0,11151.431,5880.9480,11136.021
181,808.0,28870.0,138.0,532.77325,2634.572,10945.2590,3540.0,740.71030,3128.7253,5109.1616,1793.6935,3125.0,17600.0,11112.039,5879.1730,11119.768
182,790.0,29090.0,145.0,532.77325,2635.350,10865.4770,3735.0,744.18770,3167.4048,5172.3090,1745.2155,3065.0,17970.0,11224.288,5978.4614,11251.368
183,793.0,28960.0,139.0,530.21185,2559.143,10795.6660,3800.0,736.36340,3141.6190,5219.9126,1733.5807,3115.0,17680.0,10985.611,5853.3380,11256.295
184,770.0,28790.0,140.0,533.62700,2601.912,10925.3125,3860.0,741.57965,3141.6190,5196.5960,1756.8501,3110.0,17720.0,10747.522,5793.8486,11246.297
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1299,243.0,55640.0,80.0,680.00000,5730.000,4925.0000,1244.0,1073.00000,5150.0000,12635.0000,1342.0000,727.6,17305.0,41467.410,14579.7790,16366.795
1300,241.0,55680.0,79.0,681.00000,5725.000,4860.0000,1232.0,1034.00000,5190.0000,12520.0000,1341.0000,733.1,17195.0,42251.793,15190.1360,16507.840
1301,244.0,55600.0,69.0,675.00000,5631.000,4955.0000,1192.0,1020.00000,5140.0000,12555.0000,1328.0000,732.6,17185.0,41581.720,15674.7710,16211.119
1302,241.0,55650.0,68.0,678.00000,5758.000,5097.0000,1180.0,1142.00000,5110.0000,12780.0000,1355.0000,719.3,17505.0,41555.848,15860.0040,16521.734
