In [1]:
import pandas as pd
import lightgbm as lgb
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt

In [2]:
from s4_reg.core import s4regressor as regressor
from s4_reg.src_dataloaders_original import StandardScaler
from s4_reg.src_utils_visualize import prediction_result as post

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
import pickle

with open("tests/prepared_data.pkl", "rb") as tf:
    dict_prepared_data = pickle.load(tf)

def make_s4_features(
    data,
    target,
    seq_len_target=180,
    pred_len_target=1,
    d_model_target=10, # 2048
    seq_len_others=10,
    pred_len_others=1,
    d_model_others=10
    ):
    
    data = data.reset_index()
    data['date'] = pd.to_datetime(data['Date'])
    data = data.drop(['Date'],axis=1)

    assert seq_len_target > seq_len_others
       
    model_target = regressor(
        dataset = data,
        target = target,
        size = [seq_len_target, pred_len_target],
        features = 'S',
        d_model = d_model_target,
        device = 'cpu'
    )
    
    feat_df_target = model_target.get_features(data)
    stock_data = feat_df_target[target]
    feat_df_target = feat_df_target.drop([target], axis=1)
    
    model_others = regressor(
        dataset = data,
        target = target,
        size = [seq_len_others, pred_len_others],
        features = 'MS',
        d_model = d_model_others,
        device = 'cpu'
    )
    
    feat_df_others = model_others.get_features(data).iloc[seq_len_target-seq_len_others:,:]
    feat_df_others = feat_df_others.drop([target], axis=1)
    feat_df_others.columns = [f'exog_feat_{i+1}' for i in range(len(feat_df_others.columns))]

    features = pd.concat([
                          feat_df_target,
                          feat_df_others
                          ], axis=1)
    
    return features.iloc[:-1,:], pd.DataFrame(stock_data.iloc[:-1])

targets = [
    '4584.T',
    '1557.T',
    '8789.T',
    '1893.T',
    'MSFT'
]

features = {}
for i, target in enumerate(targets):
    feat, stock = make_s4_features(dict_prepared_data[target], target)
    if i==0:
        features[target] = feat
        stock_data = stock
    else:
        features[target] = feat
        stock_data = pd.concat([stock_data, stock], axis=1)   
    

CUDA extension for Cauchy multiplication not found. Install by going to extensions/cauchy/ and running `python setup.py install`. This should speed up end-to-end training by 10-50%
Falling back on slow Cauchy kernel. Install at least one of pykeops or the CUDA extension for memory efficiency.
Falling back on slow Vandermonde kernel. Install pykeops for improved memory efficiency.


In [4]:
display(features)
display(stock_data)

{'4584.T':           feat_1      feat_2     feat_3      feat_4      feat_5      feat_6  \
 180   759.772217  124.378746  35.910980 -355.077606  450.786469 -466.541962   
 181   780.926331  130.576157  35.047318 -363.625000  462.847321 -479.167175   
 182   785.210571  130.537582  36.907253 -366.455231  465.384583 -481.690155   
 183   724.858215  121.349426  31.804905 -340.099091  427.807434 -443.279755   
 184   715.975830  118.337715  32.622684 -333.855835  420.758728 -437.826019   
 ...          ...         ...        ...         ...         ...         ...   
 1248  240.497849   41.462242  13.131885 -109.656715  145.300552 -148.154724   
 1249  237.722244   43.641411  10.955653 -108.320877  143.320831 -149.069626   
 1250  244.596085   42.466618  12.683497 -111.768013  145.887741 -148.545349   
 1251  240.614731   40.725067  10.429011 -111.842094  145.893997 -151.098907   
 1252  238.514633   42.008884  13.450680 -108.465218  145.272293 -147.449829   
 
           feat_7      feat_

Unnamed: 0,4584.T,1557.T,8789.T,1893.T,MSFT
180,816.0,28420.0,149.0,537.042297,101.410782
181,822.0,28830.0,147.0,549.849304,102.930206
182,758.0,29120.0,137.0,547.287903,100.990318
183,749.0,28960.0,138.0,540.457520,101.974594
184,808.0,28870.0,138.0,532.773193,101.487228
...,...,...,...,...,...
1248,249.0,55370.0,65.0,631.000000,288.799988
1249,254.0,55640.0,67.0,635.000000,288.369995
1250,253.0,55630.0,67.0,629.000000,288.450012
1251,251.0,55640.0,67.0,628.000000,286.109985


In [5]:
from ray.tune import register_env

# 環境クラスのインポート
from gym_stock_trading_env import StockTradingEnv

# Define a function to create the environment
def create_stock_trading_env(_):
    env = StockTradingEnv(stock_data, features)
    return env

# Register the environment
env_name = "stock_trading_env"
register_env(env_name, create_stock_trading_env)

In [7]:
# import gymnasium as gym
import ray
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.algorithms.dqn import DQNConfig

# PPOConfig.environment(disable_env_checking=True)
# DQNConfig.environment(disable_env_checking=True)

# Initialize Ray in local mode
ray.init(ignore_reinit_error=True, local_mode=True)

# Create the environment
env = StockTradingEnv(stock_data, features)

# Create PPO agent
ppo_config = PPOConfig().environment(env=env_name)
ppo_agent = ppo_config.build()

# Train the agents
num_iterations = 10

for i in range(num_iterations):
    ppo_result = ppo_agent.train()
    # dqn_result = dqn_agent.train()
    print(f"Iteration {i + 1}:")
    print(f"PPO: episode_reward_mean={ppo_result['episode_reward_mean']}")

# Evaluate the trained agents
def evaluate(agent, env):
    episode_reward = 0
    terminated = False
    obs = env.reset()

    while not terminated:
        action = agent.compute_single_action(obs)
        obs, reward, terminated, _ = env.step(action)
        episode_reward += reward

    return episode_reward

print("\nEvaluating trained agents:")
print(f"PPO: {evaluate(ppo_agent, env)}")

# Shutdown Ray
ray.shutdown()


2023-05-09 01:04:38,879	INFO worker.py:1454 -- Calling ray.init() again after it has already been called.
:actor_name:RolloutWorker
:actor_name:RolloutWorker


:actor_name:RolloutWorker
:actor_name:RolloutWorker
Iteration 1:
PPO: episode_reward_mean=26740534.956261754
Iteration 2:
PPO: episode_reward_mean=106779938.77561693
Iteration 3:
PPO: episode_reward_mean=133604222.52575299
Iteration 4:
PPO: episode_reward_mean=140973899.9382567
Iteration 5:
PPO: episode_reward_mean=165548033.00363925
Iteration 6:
PPO: episode_reward_mean=173857002.68560383
Iteration 7:
PPO: episode_reward_mean=182374462.7843407
Iteration 8:
PPO: episode_reward_mean=186907041.44266376
Iteration 9:
PPO: episode_reward_mean=193746350.60624272
Iteration 10:
PPO: episode_reward_mean=202322647.4261042

Evaluating trained agents:


ValueError: The two structures don't have the same nested structure.

First structure: type=tuple str=({'T_4584_T': array([  816.      ,   759.7722  ,   124.378746,    35.91098 ,
        -355.0776  ,   450.78647 ,  -466.54196 ,  -209.51523 ,
        -668.83704 ,  -765.913   ,   446.53607 , -3296.4219  ,
         -81.178085,  -725.35504 ,   409.0702  ,  3429.948   ,
        1671.3894  ,  1932.4185  , -2117.947   , -1082.8972  ,
        -933.7444  ], dtype=float32), 'T_1557_T': array([ 28420.     ,  27069.906  ,   4526.612  ,   1195.2346 ,
       -12677.845  ,  15991.0625 , -16409.21   ,  -7438.825  ,
       -23656.59   , -27206.695  ,  15819.948  ,  -3414.8308 ,
          743.98334,   -746.6085 ,   -212.78941,   3106.8596 ,
         1120.533  ,   2602.14   ,  -1764.2444 ,    -54.14319,
        -1516.7076 ], dtype=float32), 'T_8789_T': array([  149.      ,   141.81313 ,    20.981348,     8.681377,
         -65.64669 ,    85.784325,   -92.08866 ,   -39.716877,
        -128.91243 ,  -144.8853  ,    85.453156, -3293.6406  ,
        -100.55906 ,  -724.8559  ,   423.67612 ,  3437.5366  ,
        1684.3276  ,  1916.6884  , -2126.2546  , -1107.06    ,
        -920.052   ], dtype=float32), 'T_1893_T': array([  537.0423 ,   515.3221 ,    83.47716,    25.13959,  -240.58551,
         306.40015,  -318.417  ,  -142.34697,  -455.25555,  -520.249  ,
         303.7001 , -3295.3218 ,   -88.84475,  -725.1576 ,   414.84793,
        3432.9497 ,  1676.5074 ,  1926.1959 , -2121.2334 , -1092.4556 ,
        -928.328  ], dtype=float32), 'T_MSFT': array([  101.41078  ,    94.65655  ,    13.091073 ,     6.6034813,
         -43.560165 ,    57.930943 ,   -63.514057 ,   -26.759531 ,
         -87.71068  ,   -97.49457  ,    57.89886  , -3262.2866   ,
       -1525.1359   ,   474.78043  ,  2436.5427   , -1194.3904   ,
         953.6162   ,  4111.0034   ,   794.4325   ,    82.489815 ,
        1999.4463   ], dtype=float32)}, {})

Second structure: type=OrderedDict str=OrderedDict([('T_1557_T', array([ 0.10784248,  0.67715967,  0.57268596, -1.6430274 ,  1.3163116 ,
       -0.5264981 , -1.1280802 , -1.1352262 ,  0.02416827, -0.28174487,
        1.4613514 ,  0.6141178 ,  0.2833027 , -0.7785898 , -0.27564847,
       -1.84351   ,  0.5266593 , -1.1479619 , -0.59976804, -0.68078136,
       -0.21678886], dtype=float32)), ('T_1893_T', array([-1.1713661 , -0.55651677,  0.641187  , -0.37834498, -0.30654532,
        0.14443097, -1.9130996 ,  0.13391708,  1.9052804 ,  0.91851896,
       -0.6931156 , -0.26904592, -0.7914823 , -0.6163682 ,  0.4263788 ,
        1.1251028 ,  0.2916034 , -0.04903826, -0.09718093, -2.325044  ,
        0.83541334], dtype=float32)), ('T_4584_T', array([-0.27869388,  2.5827847 , -0.6521827 ,  0.11909572, -0.75043494,
        0.11965664, -0.18395741, -0.7597558 , -2.2602937 , -0.30668557,
       -0.44791943, -0.20690829,  0.84592086,  0.09911109, -1.215115  ,
        0.32420343,  0.03252498,  0.02946276, -0.18021423,  0.6699253 ,
       -1.3172419 ], dtype=float32)), ('T_8789_T', array([-0.8026259 , -0.4984091 , -1.2810272 ,  0.7019346 ,  0.27599326,
       -0.6144182 ,  0.14961636,  1.7969097 , -2.1039674 ,  0.39290228,
        0.4275565 ,  0.03254066, -0.14300314,  0.3739648 ,  0.8233785 ,
        1.2567726 ,  2.0034873 ,  0.7076085 ,  1.3662488 ,  0.15213673,
       -0.5293134 ], dtype=float32)), ('T_MSFT', array([ 0.30507818, -0.05848873, -1.7709571 ,  0.20654203, -0.12504694,
        0.7090478 ,  0.72867936,  1.19398   ,  1.0777693 ,  0.08289409,
        0.826449  ,  0.8625432 ,  0.47418347,  1.2604042 , -0.54698974,
        1.1916039 , -1.0068083 , -0.09994678,  2.1662927 , -1.3903582 ,
        0.871625  ], dtype=float32))])

More specifically: Substructure "type=dict str={'T_4584_T': array([  816.      ,   759.7722  ,   124.378746,    35.91098 ,
        -355.0776  ,   450.78647 ,  -466.54196 ,  -209.51523 ,
        -668.83704 ,  -765.913   ,   446.53607 , -3296.4219  ,
         -81.178085,  -725.35504 ,   409.0702  ,  3429.948   ,
        1671.3894  ,  1932.4185  , -2117.947   , -1082.8972  ,
        -933.7444  ], dtype=float32), 'T_1557_T': array([ 28420.     ,  27069.906  ,   4526.612  ,   1195.2346 ,
       -12677.845  ,  15991.0625 , -16409.21   ,  -7438.825  ,
       -23656.59   , -27206.695  ,  15819.948  ,  -3414.8308 ,
          743.98334,   -746.6085 ,   -212.78941,   3106.8596 ,
         1120.533  ,   2602.14   ,  -1764.2444 ,    -54.14319,
        -1516.7076 ], dtype=float32), 'T_8789_T': array([  149.      ,   141.81313 ,    20.981348,     8.681377,
         -65.64669 ,    85.784325,   -92.08866 ,   -39.716877,
        -128.91243 ,  -144.8853  ,    85.453156, -3293.6406  ,
        -100.55906 ,  -724.8559  ,   423.67612 ,  3437.5366  ,
        1684.3276  ,  1916.6884  , -2126.2546  , -1107.06    ,
        -920.052   ], dtype=float32), 'T_1893_T': array([  537.0423 ,   515.3221 ,    83.47716,    25.13959,  -240.58551,
         306.40015,  -318.417  ,  -142.34697,  -455.25555,  -520.249  ,
         303.7001 , -3295.3218 ,   -88.84475,  -725.1576 ,   414.84793,
        3432.9497 ,  1676.5074 ,  1926.1959 , -2121.2334 , -1092.4556 ,
        -928.328  ], dtype=float32), 'T_MSFT': array([  101.41078  ,    94.65655  ,    13.091073 ,     6.6034813,
         -43.560165 ,    57.930943 ,   -63.514057 ,   -26.759531 ,
         -87.71068  ,   -97.49457  ,    57.89886  , -3262.2866   ,
       -1525.1359   ,   474.78043  ,  2436.5427   , -1194.3904   ,
         953.6162   ,  4111.0034   ,   794.4325   ,    82.489815 ,
        1999.4463   ], dtype=float32)}" is a sequence, while substructure "type=ndarray str=[ 0.10784248  0.67715967  0.57268596 -1.6430274   1.3163116  -0.5264981
 -1.1280802  -1.1352262   0.02416827 -0.28174487  1.4613514   0.6141178
  0.2833027  -0.7785898  -0.27564847 -1.84351     0.5266593  -1.1479619
 -0.59976804 -0.68078136 -0.21678886]" is not
Entire first structure:
({'T_4584_T': ., 'T_1557_T': ., 'T_8789_T': ., 'T_1893_T': ., 'T_MSFT': .}, {})
Entire second structure:
OrderedDict([('T_1557_T', .), ('T_1893_T', .), ('T_4584_T', .), ('T_8789_T', .), ('T_MSFT', .)])