In [1]:
import pandas as pd
import pickle 
import sys
import finrl
from finenv.env_stocktrading import StockTradingEnv
from finenv.preprocessors import data_split
from finenv.save_model import upload_files
import psutil
import ray
from datetime import datetime
# Monkey-patch Ray's private memory probe so it returns the total reported by
# psutil.virtual_memory().
# NOTE(review): presumably a workaround for Ray misdetecting system memory on
# this machine -- confirm it is still needed. `ray._private` is not a public
# API, so this attribute may move or change between Ray versions.
ray._private.utils.get_system_memory = lambda: psutil.virtual_memory().total
from ray.tune.registry import register_env
from gymnasium.wrappers import EnvCompatibility

# Load the training data; the first CSV column becomes the index and its
# name is blanked out.
train = pd.read_csv('dataset/train_data.csv')
train = train.set_index(train.columns[0])
train.index.names = ['']

# Technical-indicator columns handed to the environment as `tech_indicator_list`.
INDICATORS = ['macd','boll_ub','boll_lb','rsi_30','cci_30','dx_30','close_30_sma','close_60_sma']

# Number of distinct values in the `tic` column (one per traded ticker).
stock_dimension = len(train['tic'].unique())

# Observation size: 1 scalar + 2 values per ticker + one value per indicator
# per ticker.
# NOTE(review): presumably cash balance, then price and holdings per ticker --
# confirm against StockTradingEnv.
state_space = 1 + 2*stock_dimension + len(INDICATORS)*stock_dimension

# Per-ticker transaction fees. Built as two independent lists: the previous
# chained assignment bound both names to the *same* list object, so mutating
# one fee schedule would silently change the other.
buy_cost_list = [0.001] * stock_dimension
sell_cost_list = [0.001] * stock_dimension

# Initial holdings: zero shares of every ticker.
num_stock_shares = [0] * stock_dimension

In [2]:
def env_creator(env_config):
    """Build the stock-trading environment that is registered with RLlib.

    Every setting is read from ``env_config`` with ``.get``; a missing key
    falls back to the notebook-level default, so an empty mapping ``{}``
    yields the standard training environment.

    Args:
        env_config: Mapping of overrides. Recognised keys are ``df``,
            ``hmax``, ``initial_amount``, ``num_stock_shares``,
            ``buy_cost_pct``, ``sell_cost_pct``, ``state_space``,
            ``stock_dim``, ``tech_indicator_list``, ``action_space`` and
            ``reward_scaling``.

    Returns:
        A ``StockTradingEnv`` wrapped in ``EnvCompatibility``.
    """
    # Resolve the two settings that determine the sizes of everything else
    # first, so the derived defaults below stay consistent when a caller
    # overrides only `stock_dim` or `tech_indicator_list`.
    stock_dim = env_config.get('stock_dim', stock_dimension)
    tech_indicator_list = env_config.get('tech_indicator_list', INDICATORS)

    df = env_config.get('df', train)
    hmax = env_config.get('hmax', 200)
    initial_amount = env_config.get('initial_amount', 1000000)
    reward_scaling = env_config.get('reward_scaling', 1e-2)

    # Size-dependent defaults are derived from the resolved values above
    # rather than from the module-level globals.
    num_stock_shares = env_config.get('num_stock_shares', [0] * stock_dim)
    state_space = env_config.get(
        'state_space',
        1 + 2*stock_dim + len(tech_indicator_list)*stock_dim,
    )
    action_space = env_config.get('action_space', stock_dim)

    # Copy the default fee schedules so that environments never share (and
    # cannot mutate) the module-level lists or each other's lists.
    buy_cost_pct = env_config.get('buy_cost_pct', list(buy_cost_list))
    sell_cost_pct = env_config.get('sell_cost_pct', list(sell_cost_list))

    return EnvCompatibility(StockTradingEnv(
        df=df,
        hmax=hmax,
        initial_amount=initial_amount,
        num_stock_shares=num_stock_shares,
        buy_cost_pct=buy_cost_pct,
        sell_cost_pct=sell_cost_pct,
        state_space=state_space,
        stock_dim=stock_dim,
        tech_indicator_list=tech_indicator_list,
        action_space=action_space,
        reward_scaling=reward_scaling
    ))

# Make the environment factory available to RLlib under the name "finrl".
register_env(name="finrl", env_creator=env_creator)
from ray.rllib.agents import ppo

In [3]:
# Test transferring trained weights into an algorithm built from a new config.
from ray.rllib.algorithms.algorithm import Algorithm

# Restart Ray so this cell always begins with a fresh local instance.
ray.shutdown()
ray.init(num_cpus=12)

# Restore the previously trained algorithm only to read its policy weights.
cwd_checkpoint = 'model/ppo_230405/checkpoint_000025'
trainer = Algorithm.from_checkpoint(cwd_checkpoint)
model_weights = trainer.get_policy().get_weights()
# The restored algorithm is no longer needed once the weights are copied out;
# stop it so its rollout workers do not keep holding resources for the rest of
# the session.
trainer.stop()
print('passed model weights')

config = ppo.PPOConfig()
print('config created')
# Use the AlgorithmConfig builder methods instead of dict-style assignment
# (`config[...] = ...`), which is only kept as a backward-compatibility shim.
config = (
    config
    .environment(env_config={'hmax': 500, 'initial_amount': 1000000})
    .training(
        gamma=0.9,
        lr=0.001,
        kl_coeff=0.3,
        sgd_minibatch_size=128,
        num_sgd_iter=30,
        train_batch_size=10000,
        # Start from the existing model settings so that only the hidden
        # layer sizes are overridden.
        model={**config.model, 'fcnet_hiddens': [256, 256, 128, 16]},
    )
    .rollouts(num_rollout_workers=0, rollout_fragment_length=1000)
    .framework(framework="torch")
    .debugging(seed=42)
)

# Build the new algorithm on the registered "finrl" environment and overwrite
# its freshly initialised policy weights with the restored ones.
# NOTE(review): this only works if the network defined above has the same
# layer shapes as the one stored in the checkpoint -- confirm the checkpoint
# was trained with the same `fcnet_hiddens`.
trainer2 = config.build(env='finrl')
trainer2.get_policy().set_weights(model_weights)
print('New Weights loaded. ')

# Save the re-parameterised algorithm; the returned checkpoint path is the
# cell's displayed output.
cwd_checkpoint = "model/ppo_new_230405"
trainer2.save(cwd_checkpoint)

2023-04-05 10:58:39,116	INFO worker.py:1544 -- Started a local Ray instance. View the dashboard at [1m[32mhttp://127.0.0.1:8265 [39m[22m
2023-04-05 10:58:40,851	INFO algorithm.py:506 -- Current log_level is WARN. For more information, set 'log_level': 'INFO' / 'DEBUG' or use the -v and -vv flags.
[2m[36m(RolloutWorker pid=28457)[0m   logger.warn("Casting input x to numpy array.")
[2m[36m(RolloutWorker pid=28402)[0m   logger.warn("Casting input x to numpy array.")
[2m[36m(RolloutWorker pid=28429)[0m   logger.warn("Casting input x to numpy array.")
[2m[36m(RolloutWorker pid=28415)[0m   logger.warn("Casting input x to numpy array.")
[2m[36m(RolloutWorker pid=28471)[0m   logger.warn("Casting input x to numpy array.")
[2m[36m(RolloutWorker pid=28443)[0m   logger.warn("Casting input x to numpy array.")
[2m[36m(RolloutWorker pid=28485)[0m   logger.warn("Casting input x to numpy array.")
[2m[36m(RolloutWorker pid=28616)[0m   logger.warn("Casting input x to numpy arr

passed model weights
config created


  logger.warn("Casting input x to numpy array.")


New Weights loaded. 


'model/ppo_new_230405/checkpoint_000000'