In [7]:
from pathlib import Path

import lightgbm as lgb
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split

In [8]:
from s4_reg.core import s4regressor as regressor
from s4_reg.src_dataloaders_original import StandardScaler
from s4_reg.src_utils_visualize import prediction_result as post

In [9]:
import pickle

# Prepared input frames, later looked up per ticker symbol (see the loop that
# calls make_s4_features). NOTE: pickle.load can execute arbitrary code, so only
# load files from a trusted source.
with open("tests/prepared_data.pkl", mode="rb") as pkl_file:
    dict_prepared_data = pickle.load(pkl_file)

def make_s4_features(
    data,
    target,
    seq_len_target=180,
    pred_len_target=1,
    d_model_target=10,  # NOTE(review): original comment suggested 2048 here
    seq_len_others=10,
    pred_len_others=1,
    d_model_others=10,
    device='cpu',
    ):
    """Build S4-regressor features for a single ticker.

    Two ``regressor`` models are built over ``data``:

    * a ``features='S'`` model on ``target`` with window ``seq_len_target``;
      its ``target`` column is split off as the label series, and
    * a ``features='MS'`` model with window ``seq_len_others``; its non-target
      columns are renamed ``exog_feat_1 .. exog_feat_k``.

    The second frame has its first ``seq_len_target - seq_len_others`` rows
    dropped (presumably so both frames start at the same row), and the two are
    concatenated column-wise. pandas aligns the concat on the row index.

    Parameters
    ----------
    data : pd.DataFrame
        Per-ticker input. After ``reset_index()`` it must have a ``'Date'``
        column (presumably the original index, TODO confirm) and a
        ``target`` column.
    target : str
        Column to model.
    seq_len_target, pred_len_target, d_model_target :
        ``size`` and ``d_model`` arguments for the target-only model.
    seq_len_others, pred_len_others, d_model_others :
        ``size`` and ``d_model`` arguments for the multivariate model.
    device : str, default 'cpu'
        Passed to both regressors.

    Returns
    -------
    tuple of (pd.DataFrame, pd.DataFrame)
        The combined feature frame and a one-column frame of ``target``
        values. The final row of each is dropped.

    Raises
    ------
    ValueError
        If ``seq_len_target <= seq_len_others``.
    """
    # Explicit check instead of ``assert``: asserts vanish under ``python -O``,
    # and a non-positive offset would make the ``iloc`` below slice wrong rows.
    # Done first so bad arguments fail before any data work.
    if seq_len_target <= seq_len_others:
        raise ValueError(
            f"seq_len_target ({seq_len_target}) must be greater than "
            f"seq_len_others ({seq_len_others})"
        )

    # Move the index into a column and normalise 'Date' -> datetime 'date'.
    data = data.reset_index()
    data['date'] = pd.to_datetime(data['Date'])
    data = data.drop(['Date'], axis=1)

    model_target = regressor(
        dataset=data,
        target=target,
        size=[seq_len_target, pred_len_target],
        features='S',
        d_model=d_model_target,
        device=device,
    )
    feat_df_target = model_target.get_features(data)
    stock_data = feat_df_target[target]
    feat_df_target = feat_df_target.drop([target], axis=1)

    model_others = regressor(
        dataset=data,
        target=target,
        size=[seq_len_others, pred_len_others],
        features='MS',
        d_model=d_model_others,
        device=device,
    )
    # The shorter window presumably yields features starting earlier; trim the
    # surplus leading rows so both frames begin together.
    offset = seq_len_target - seq_len_others
    feat_df_others = model_others.get_features(data).iloc[offset:, :]
    feat_df_others = feat_df_others.drop([target], axis=1)
    feat_df_others.columns = [
        f'exog_feat_{i + 1}' for i in range(len(feat_df_others.columns))
    ]

    features = pd.concat([feat_df_target, feat_df_others], axis=1)

    # NOTE(review): the last row is dropped from both outputs; the reason is not
    # visible here (possibly no next-step label for it) -- confirm.
    return features.iloc[:-1, :], pd.DataFrame(stock_data.iloc[:-1])

# Tickers to featurize; each must be a key of ``dict_prepared_data``.
targets = [
    '4584.T',
    '1557.T',
    '8789.T',
    '1893.T',
    'MSFT'
]

# ``features``: ticker -> feature frame.
# ``stock_data``: one label column per ticker, side by side.
features = {}
stock_frames = []
for target in targets:
    feat, stock = make_s4_features(dict_prepared_data[target], target)
    features[target] = feat
    stock_frames.append(stock)

# Concatenate once rather than growing the frame inside the loop (quadratic).
# NOTE(review): axis=1 concat aligns on the row index, which make_s4_features
# derives from a reset (positional) index, so tickers line up by row number,
# not by date. Presumably dict_prepared_data is already date-aligned; confirm.
stock_data = pd.concat(stock_frames, axis=1) if stock_frames else pd.DataFrame()
    

In [10]:
# Persist the per-ticker features and the label table so later cells can
# reload them without recomputing. pandas does not create missing parent
# directories, so make sure ./data exists on a fresh checkout.
Path('./data').mkdir(parents=True, exist_ok=True)
pd.to_pickle(features, './data/features_test.pkl')
stock_data.to_csv('./data/stock_data_test.csv')

In [13]:
# Reload the cached outputs. The CSV's first column is the saved row index;
# select it by position rather than by pandas' 'Unnamed: 0' placeholder, which
# only exists when the index was written without a name.
features = pd.read_pickle('./data/features_test.pkl')
stock_data = pd.read_csv('./data/stock_data_test.csv', index_col=0)

In [14]:
# Show the ticker -> features mapping, then the label table (rendered in order).
display(features, stock_data)

{'4584.T':           feat_1     feat_2      feat_3      feat_4     feat_5      feat_6  \
 180   423.826660 -50.535824  620.076904  647.075989  29.847284 -176.509766   
 181   365.338318 -43.820969  530.817871  554.934692  25.334932 -150.839783   
 182   393.040863 -48.985344  570.585327  597.884583  26.800694 -159.956039   
 183   404.464508 -48.618393  592.458801  610.869812  25.291115 -167.483673   
 184   397.536987 -48.184841  581.841614  599.444336  23.852419 -166.875946   
 ...          ...        ...         ...         ...        ...         ...   
 1294  132.887314 -13.847143  194.207443  201.568588   5.951747  -55.492905   
 1295  133.602676 -14.692342  194.495743  204.649796  10.407352  -55.476749   
 1296  132.842178 -12.357491  193.529541  202.219360   8.310191  -54.819019   
 1297  133.726883 -13.716507  195.077179  205.540329   7.835347  -55.761353   
 1298  132.368713 -10.806239  196.750214  202.363480   8.234961  -56.342850   
 
           feat_7      feat_8      feat_

Unnamed: 0,4584.T,1557.T,8789.T,1893.T,MSFT
180,661.0,28860.0,137.0,533.62700,11123.248
181,708.0,29190.0,140.0,543.87270,10875.842
182,730.0,29370.0,145.0,529.35803,10690.548
183,718.0,29630.0,144.0,547.28790,11058.452
184,732.0,29800.0,141.0,552.41077,11266.821
...,...,...,...,...,...
1294,243.0,55640.0,80.0,680.00000,41467.410
1295,241.0,55680.0,79.0,681.00000,42251.793
1296,244.0,55600.0,69.0,675.00000,41581.720
1297,241.0,55650.0,68.0,678.00000,41555.848
