In [1]:
# Gym stuff
import gym
import gym_anytrading
from stable_baselines3 import PPO

# Stable baselines - rl stuff
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3 import A2C, PPO

# Processing libraries
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt

In [2]:
# Load the raw price + indicator table; 'Date' is still an ordinary column here.
df = pd.read_csv(filepath_or_buffer='Reliance.csv')

In [3]:
# Column names: OHLCV, corporate actions, and precomputed indicator columns.
df.columns

Index(['Date', 'Open', 'High', 'Low', 'Close', 'Volume', 'Dividends',
       'Stock Splits', 'SUPERT_7_1.0', 'SUPERTd_7_1.0', 'SUPERTl_7_1.0',
       'SUPERTs_7_1.0', 'SMA_5', 'WMA_10', 'SMA_14', 'WMA_28', 'SMA_44',
       'PSARl_0.02_0.2', 'PSARs_0.02_0.2', 'PSARaf_0.02_0.2',
       'PSARr_0.02_0.2'],
      dtype='object')

In [4]:
# Raw frame preview; 'Date' is still a regular column at this point.
df

                           Date         Open         High          Low  \
0     1996-01-01 00:00:00+05:30    10.439797    10.495815    10.371046   
1     1996-01-02 00:00:00+05:30    10.452527    10.503452    10.320119   
2     1996-01-03 00:00:00+05:30    10.567110    11.048360    10.452527   
3     1996-01-04 00:00:00+05:30    10.376140    10.409242    10.238640   
4     1996-01-05 00:00:00+05:30    10.337946    10.337946    10.218270   
...                         ...          ...          ...          ...   
6790  2023-01-02 00:00:00+05:30  2550.000000  2579.000000  2548.199951   
6791  2023-01-03 00:00:00+05:30  2565.050049  2573.000000  2547.800049   
6792  2023-01-04 00:00:00+05:30  2557.000000  2561.050049  2514.000000   
6793  2023-01-05 00:00:00+05:30  2523.500000  2536.399902  2504.000000   
6794  2023-01-06 00:00:00+05:30  2526.649902  2547.949951  2518.300049   

            Close     Volume  Dividends  Stock Splits  SUPERT_7_1.0  \
0       10.477991   48051995        0.0 

In [5]:
# 'Date' is stored as strings (dtype object) carrying a +05:30 offset.
df['Date']

0       1996-01-01 00:00:00+05:30
1       1996-01-02 00:00:00+05:30
2       1996-01-03 00:00:00+05:30
3       1996-01-04 00:00:00+05:30
4       1996-01-05 00:00:00+05:30
                  ...            
6790    2023-01-02 00:00:00+05:30
6791    2023-01-03 00:00:00+05:30
6792    2023-01-04 00:00:00+05:30
6793    2023-01-05 00:00:00+05:30
6794    2023-01-06 00:00:00+05:30
Name: Date, Length: 6795, dtype: object

In [6]:
# Index rows by their timestamp. Guarded so re-running this cell does not
# raise KeyError once 'Date' has already been moved out of the columns.
if 'Date' in df.columns:
    df = df.set_index('Date')
df.head()

                                Open       High        Low      Close  \
Date                                                                    
1996-01-01 00:00:00+05:30  10.439797  10.495815  10.371046  10.477991   
1996-01-02 00:00:00+05:30  10.452527  10.503452  10.320119  10.396508   
1996-01-03 00:00:00+05:30  10.567110  11.048360  10.452527  10.475444   
1996-01-04 00:00:00+05:30  10.376140  10.409242  10.238640  10.378686   
1996-01-05 00:00:00+05:30  10.337946  10.337946  10.218270  10.307390   

                              Volume  Dividends  Stock Splits  SUPERT_7_1.0  \
Date                                                                          
1996-01-01 00:00:00+05:30   48051995        0.0           0.0           0.0   
1996-01-02 00:00:00+05:30   77875009        0.0           0.0           NaN   
1996-01-03 00:00:00+05:30   96602936        0.0           0.0           NaN   
1996-01-04 00:00:00+05:30  100099436        0.0           0.0           NaN   
1996-01-05 00:

In [7]:
# Frame now indexed by 'Date'.
df

                                  Open         High          Low        Close  \
Date                                                                            
1996-01-01 00:00:00+05:30    10.439797    10.495815    10.371046    10.477991   
1996-01-02 00:00:00+05:30    10.452527    10.503452    10.320119    10.396508   
1996-01-03 00:00:00+05:30    10.567110    11.048360    10.452527    10.475444   
1996-01-04 00:00:00+05:30    10.376140    10.409242    10.238640    10.378686   
1996-01-05 00:00:00+05:30    10.337946    10.337946    10.218270    10.307390   
...                                ...          ...          ...          ...   
2023-01-02 00:00:00+05:30  2550.000000  2579.000000  2548.199951  2575.899902   
2023-01-03 00:00:00+05:30  2565.050049  2573.000000  2547.800049  2557.050049   
2023-01-04 00:00:00+05:30  2557.000000  2561.050049  2514.000000  2518.550049   
2023-01-05 00:00:00+05:30  2523.500000  2536.399902  2504.000000  2514.050049   
2023-01-06 00:00:00+05:30  2

In [8]:
# Map lower-case OHLCV headers to the capitalised names the default stocks env
# reads (its first feature matches 'Close'). This CSV is already capitalised,
# so here it is effectively a no-op kept for other data sources.
df.rename(
    mapper={'close': 'Close', 'open': 'Open', 'low': 'Low', 'volume': 'Volume'},
    axis='columns',
    inplace=True,
)

In [9]:
# 20 columns: 'Date' moved into the index (the CSV had 21).
df.shape

(6795, 20)

In [10]:
# Baseline env on the default features (Close and its tick-to-tick change).
env = gym.make('stocks-v0', window_size=5, frame_bound=(5, 4000), df=df)

In [11]:
# Default stocks-v0 features per tick: [Close, change in Close] (first change is 0).
env.signal_features

array([[ 1.04779911e+01,  0.00000000e+00],
       [ 1.03965082e+01, -8.14828873e-02],
       [ 1.04754438e+01,  7.89356232e-02],
       ...,
       [ 3.67148712e+02,  1.04157104e+01],
       [ 3.59581970e+02, -7.56674194e+00],
       [ 3.74464844e+02,  1.48828735e+01]])

In [12]:
# One row per tick from frame_bound[0] - window_size (= 0) to frame_bound[1] (= 4000).
env.signal_features.shape

(4000, 2)

In [13]:
# Same default-feature env, extended to tick 5000.
env = gym.make('stocks-v0', window_size=5, frame_bound=(5, 5000), df=df)

In [14]:
# Now one feature row per tick up to frame_bound[1] = 5000.
env.signal_features.shape

(5000, 2)

In [15]:
# Random-action baseline: step until the episode ends, then plot the trades.
state = env.reset()
done = False
while not done:
    action = env.action_space.sample()
    n_state, reward, done, info = env.step(action)
print("info", info)

plt.figure(figsize=(15, 6))
plt.cla()
env.render_all()
plt.show()

In [16]:
# Attribute lookup only (not a call): shows the wrapper's compute_reward method.
env.compute_reward

<bound method Wrapper.compute_reward of <OrderEnforcing<StocksEnv<stocks-v0>>>>

In [17]:
# NOTE(review): gym's Wrapper.compute_reward is the goal-env API and expects
# (achieved_goal, desired_goal, info); calling it bare raises TypeError.
# Show the reward accumulated by the episode just played instead
# (private attribute of the unwrapped gym-anytrading env — confirm per version).
env.unwrapped._total_reward

In [18]:
# The env holds the full input frame, not just the frame_bound slice.
env.df

                                  Open         High          Low        Close  \
Date                                                                            
1996-01-01 00:00:00+05:30    10.439797    10.495815    10.371046    10.477991   
1996-01-02 00:00:00+05:30    10.452527    10.503452    10.320119    10.396508   
1996-01-03 00:00:00+05:30    10.567110    11.048360    10.452527    10.475444   
1996-01-04 00:00:00+05:30    10.376140    10.409242    10.238640    10.378686   
1996-01-05 00:00:00+05:30    10.337946    10.337946    10.218270    10.307390   
...                                ...          ...          ...          ...   
2023-01-02 00:00:00+05:30  2550.000000  2579.000000  2548.199951  2575.899902   
2023-01-03 00:00:00+05:30  2565.050049  2573.000000  2547.800049  2557.050049   
2023-01-04 00:00:00+05:30  2557.000000  2561.050049  2514.000000  2518.550049   
2023-01-05 00:00:00+05:30  2523.500000  2536.399902  2504.000000  2514.050049   
2023-01-06 00:00:00+05:30  2

In [19]:
# Full dataset shape (6795 rows), vs. 5000 rows of signal features.
env.df.shape

(6795, 20)

In [20]:
# Re-create the default-feature env over ticks 5..5000 (fresh, un-stepped state).
env = gym.make('stocks-v0', window_size=5, frame_bound=(5, 5000), df=df)

In [21]:
# Features for the re-created env: frame_bound (5, 5000) -> 5000 rows.
env.signal_features.shape

(5000, 2)

In [22]:
# The stored frame remains the full dataset.
env.df.shape

(6795, 20)

In [23]:
# Observations are (window_size, n_features) = (5, 2) unbounded float64 boxes.
env.observation_space

Box([[-inf -inf]
 [-inf -inf]
 [-inf -inf]
 [-inf -inf]
 [-inf -inf]], [[inf inf]
 [inf inf]
 [inf inf]
 [inf inf]
 [inf inf]], (5, 2), float64)

In [24]:
# Method reference (not called): the wrapper's class_name.
env.class_name

<bound method Wrapper.class_name of <class 'gym.wrappers.order_enforcing.OrderEnforcing'>>

In [25]:
# Underlying StocksEnv beneath the OrderEnforcing wrapper.
env.env

<gym_anytrading.envs.stocks_env.StocksEnv at 0x15ce62e20>

In [26]:
# Tick range (start, end) the env trades over.
env.env.frame_bound

(5, 5000)

In [27]:
# frame_bound is a plain tuple; this is just tuple.index (not called).
env.env.frame_bound.index

<function tuple.index(value, start=0, stop=9223372036854775807, /)>

In [28]:
# frame_bound is a plain tuple; this is just tuple.count (not called).
env.env.frame_bound.count

<function tuple.count(value, /)>

In [29]:
# tuple.count(value) needs the value to count, so count() with no argument
# raises TypeError. len() gives the number of bounds instead.
len(env.env.frame_bound)

In [30]:
# Occurrences of the value 3 in frame_bound (none).
env.env.frame_bound.count(3)

0

In [31]:
# Per-step info history. Displays nothing here — presumably None until the env is reset and stepped.
env.env.history

In [32]:
# `history` is an attribute, not a method. Accessing it displays nothing before
# the env has been reset/stepped (presumably None), so calling it raises
# TypeError. Inspect the value directly instead.
env.env.history

In [33]:
# Method reference only; calling it draws the episode chart.
env.env.render_all

<bound method TradingEnv.render_all of <gym_anytrading.envs.stocks_env.StocksEnv object at 0x15ce62e20>>

In [34]:
# Draw the episode chart for the underlying env.
# NOTE(review): this env was re-created and has not been reset/stepped since,
# so there may be no trading history to plot — confirm this cell runs cleanly.
env.env.render_all()

In [35]:
# Method reference (not called).
env.env.max_possible_profit

<bound method StocksEnv.max_possible_profit of <gym_anytrading.envs.stocks_env.StocksEnv object at 0x15ce62e20>>

In [36]:
# Theoretical maximum profit over the frame (presumably a perfect-foresight
# trader compounded over every tick). NOTE(review): ~3.4e19 is far too large
# to serve as a benchmark for agent profit.
env.env.max_possible_profit()

3.4226217697017766e+19

In [37]:
# Random-action baseline: step until the episode ends, then plot the trades.
state = env.reset()
done = False
while not done:
    action = env.action_space.sample()
    n_state, reward, done, info = env.step(action)
print("info", info)

plt.figure(figsize=(15, 6))
plt.cla()
env.render_all()
plt.show()

In [38]:
# Another random-action rollout for comparison.
state = env.reset()
done = False
while not done:
    action = env.action_space.sample()
    n_state, reward, done, info = env.step(action)
print("info", info)

plt.figure(figsize=(15, 6))
plt.cla()
env.render_all()
plt.show()

In [39]:
# Only the 'human' render mode is declared.
env.env.metadata

{'render.modes': ['human']}

In [40]:
# Price series the env trades on (values match the Close column).
env.env.prices

array([ 10.4779911 ,  10.39650822,  10.47544384, ..., 397.29522705,
       400.39071655, 410.6534729 ])

In [41]:
# Reward is unbounded in both directions.
env.env.reward_range

(-inf, inf)

In [42]:
def env_maker():
    """Build the default-feature stocks env with a 10-tick window."""
    return gym.make('stocks-v0', df=df, frame_bound=(10, 5000), window_size=10)


env = DummyVecEnv([env_maker])

In [43]:
# Baseline A2C on the default (price, price-change) observations.
model = A2C(policy='MlpPolicy', env=env, verbose=1)
model.learn(total_timesteps=1_000_000)

In [44]:
# SuperTrend direction: only +1/-1, and the counts sum to all 6795 rows.
df['SUPERTd_7_1.0'].value_counts()

 1    3577
-1    3218
Name: SUPERTd_7_1.0, dtype: int64

In [45]:
import pandas_ta as ta

In [46]:
# 14-period RSI via the pandas_ta DataFrame accessor; early warm-up rows are NaN.
df['RSI'] = df.ta.rsi(length=14)

In [47]:
# PSARr flag: 1 on 600 rows, 0 elsewhere (presumably marks PSAR reversals).
df['PSARr_0.02_0.2'].value_counts()

0    6195
1     600
Name: PSARr_0.02_0.2, dtype: int64

In [48]:
# 'RSI' has been appended as the last column.
df.columns

Index(['Open', 'High', 'Low', 'Close', 'Volume', 'Dividends', 'Stock Splits',
       'SUPERT_7_1.0', 'SUPERTd_7_1.0', 'SUPERTl_7_1.0', 'SUPERTs_7_1.0',
       'SMA_5', 'WMA_10', 'SMA_14', 'WMA_28', 'SMA_44', 'PSARl_0.02_0.2',
       'PSARs_0.02_0.2', 'PSARaf_0.02_0.2', 'PSARr_0.02_0.2', 'RSI'],
      dtype='object')

In [49]:
# Preview the frame with NaNs replaced by 0. fillna returns a new DataFrame
# that is only displayed; `df` itself is NOT modified.
df.fillna(0)

                                  Open         High          Low        Close  \
Date                                                                            
1996-01-01 00:00:00+05:30    10.439797    10.495815    10.371046    10.477991   
1996-01-02 00:00:00+05:30    10.452527    10.503452    10.320119    10.396508   
1996-01-03 00:00:00+05:30    10.567110    11.048360    10.452527    10.475444   
1996-01-04 00:00:00+05:30    10.376140    10.409242    10.238640    10.378686   
1996-01-05 00:00:00+05:30    10.337946    10.337946    10.218270    10.307390   
...                                ...          ...          ...          ...   
2023-01-02 00:00:00+05:30  2550.000000  2579.000000  2548.199951  2575.899902   
2023-01-03 00:00:00+05:30  2565.050049  2573.000000  2547.800049  2557.050049   
2023-01-04 00:00:00+05:30  2557.000000  2561.050049  2514.000000  2518.550049   
2023-01-05 00:00:00+05:30  2523.500000  2536.399902  2504.000000  2514.050049   
2023-01-06 00:00:00+05:30  2

In [50]:
# Preview the custom feature matrix (NaN indicator warm-up values -> 0).
# `env` is an SB3 DummyVecEnv at this point, which does not forward attributes
# such as `.df` to the wrapped env. The gym env was built with df=df, so read
# the frame directly.
df.loc[:, ['Low', 'Volume', 'SMA_5', 'RSI', 'SUPERTd_7_1.0', 'PSARr_0.02_0.2']].fillna(0).to_numpy()

In [51]:
def env_maker():
    """Build the default-feature stocks env with a 10-tick window."""
    return gym.make('stocks-v0', df=df, frame_bound=(10, 5000), window_size=10)


env = DummyVecEnv([env_maker])

In [52]:
# SuperTrend direction: only +1/-1, and the counts sum to all 6795 rows.
df['SUPERTd_7_1.0'].value_counts()

 1    3577
-1    3218
Name: SUPERTd_7_1.0, dtype: int64

In [53]:
import pandas_ta as ta

In [54]:
# 14-period RSI via the pandas_ta DataFrame accessor; early warm-up rows are NaN.
df['RSI'] = df.ta.rsi(length=14)

In [55]:
# PSARr flag: 1 on 600 rows, 0 elsewhere (presumably marks PSAR reversals).
df['PSARr_0.02_0.2'].value_counts()

0    6195
1     600
Name: PSARr_0.02_0.2, dtype: int64

In [56]:
# 'RSI' is the last column.
df.columns

Index(['Open', 'High', 'Low', 'Close', 'Volume', 'Dividends', 'Stock Splits',
       'SUPERT_7_1.0', 'SUPERTd_7_1.0', 'SUPERTl_7_1.0', 'SUPERTs_7_1.0',
       'SMA_5', 'WMA_10', 'SMA_14', 'WMA_28', 'SMA_44', 'PSARl_0.02_0.2',
       'PSARs_0.02_0.2', 'PSARaf_0.02_0.2', 'PSARr_0.02_0.2', 'RSI'],
      dtype='object')

In [57]:
# Preview the frame with NaNs replaced by 0. fillna returns a new DataFrame
# that is only displayed; `df` itself is NOT modified.
df.fillna(0)

                                  Open         High          Low        Close  \
Date                                                                            
1996-01-01 00:00:00+05:30    10.439797    10.495815    10.371046    10.477991   
1996-01-02 00:00:00+05:30    10.452527    10.503452    10.320119    10.396508   
1996-01-03 00:00:00+05:30    10.567110    11.048360    10.452527    10.475444   
1996-01-04 00:00:00+05:30    10.376140    10.409242    10.238640    10.378686   
1996-01-05 00:00:00+05:30    10.337946    10.337946    10.218270    10.307390   
...                                ...          ...          ...          ...   
2023-01-02 00:00:00+05:30  2550.000000  2579.000000  2548.199951  2575.899902   
2023-01-03 00:00:00+05:30  2565.050049  2573.000000  2547.800049  2557.050049   
2023-01-04 00:00:00+05:30  2557.000000  2561.050049  2514.000000  2518.550049   
2023-01-05 00:00:00+05:30  2523.500000  2536.399902  2504.000000  2514.050049   
2023-01-06 00:00:00+05:30  2

In [58]:
# Preview the custom feature matrix (NaN indicator warm-up values -> 0).
# `env` is an SB3 DummyVecEnv at this point, which does not forward attributes
# such as `.df` to the wrapped env. The gym env was built with df=df, so read
# the frame directly.
df.loc[:, ['Low', 'Volume', 'SMA_5', 'RSI', 'SUPERTd_7_1.0', 'PSARr_0.02_0.2']].fillna(0).to_numpy()

In [59]:
# Columns stacked into the agent's observation. NOTE(review): these are raw,
# unscaled values (Volume is ~1e6-1e8 while the signal columns are -1/0/1),
# which an MLP policy may struggle with; consider normalising them.
SIGNAL_COLUMNS = ('Low', 'Volume', 'SMA_5', 'RSI', 'SUPERTd_7_1.0', 'PSARr_0.02_0.2')


def add_signals(env, price_column='Low', feature_columns=SIGNAL_COLUMNS):
    """Build ``(prices, signal_features)`` for a StocksEnv subclass.

    Meant to be assigned as ``_process_data`` on a ``StocksEnv`` subclass, so
    ``env`` is the environment instance. It must expose ``df``,
    ``frame_bound`` and ``window_size``.

    Rows are taken by position from ``frame_bound[0] - window_size`` up to
    (but excluding) ``frame_bound[1]``, so the first observation window has
    ``window_size`` rows of history.

    Args:
        env: object exposing ``df``, ``frame_bound`` and ``window_size``.
        price_column: column used as the traded price. Defaults to 'Low'.
            NOTE(review): the stock env's default features use Close; trading
            on the daily low is optimistic — confirm this is intended.
        feature_columns: columns stacked into the observation matrix.
            Missing values (indicator warm-up periods) are filled with 0.

    Returns:
        prices: 1-D array with one entry per selected row.
        signal_features: 2-D array of shape ``(n_rows, len(feature_columns))``.

    Raises:
        ValueError: if ``frame_bound[0] < window_size``. A negative start would
            otherwise wrap around to the end of the frame and silently yield an
            empty or wrong slice.
    """
    start = env.frame_bound[0] - env.window_size
    end = env.frame_bound[1]
    if start < 0:
        raise ValueError(
            f"frame_bound[0] ({env.frame_bound[0]}) must be >= window_size "
            f"({env.window_size})"
        )
    window = env.df.iloc[start:end]
    prices = window.loc[:, price_column].to_numpy()
    # list(): .loc interprets a tuple as a MultiIndex key, not a column list.
    # Slicing before fillna means only the needed rows are filled/copied.
    signal_features = window.loc[:, list(feature_columns)].fillna(0).to_numpy()
    return prices, signal_features

In [60]:
from gym_anytrading.envs import StocksEnv

class MyCustomEnv(StocksEnv):
    """StocksEnv whose price/feature extraction is replaced by add_signals."""

    # A plain function stored as a class attribute becomes a method, so the
    # env's own self._process_data() call runs add_signals(self).
    _process_data = add_signals


env2 = MyCustomEnv(
    df=df,
    window_size=12,
    frame_bound=(12, 50),
)

In [61]:
# Custom features per tick: [Low, Volume, SMA_5, RSI, SUPERTd_7_1.0, PSARr_0.02_0.2], NaNs filled with 0.
env2.signal_features

array([[ 1.03710464e+01,  4.80519950e+07,  0.00000000e+00,
         0.00000000e+00,  1.00000000e+00,  0.00000000e+00],
       [ 1.03201192e+01,  7.78750090e+07,  0.00000000e+00,
         0.00000000e+00,  1.00000000e+00,  1.00000000e+00],
       [ 1.04525269e+01,  9.66029360e+07,  0.00000000e+00,
         0.00000000e+00,  1.00000000e+00,  0.00000000e+00],
       [ 1.02386398e+01,  1.00099436e+08,  0.00000000e+00,
         0.00000000e+00,  1.00000000e+00,  1.00000000e+00],
       [ 1.02182699e+01,  7.69359300e+07,  1.04072039e+01,
         0.00000000e+00,  1.00000000e+00,  0.00000000e+00],
       [ 9.79303950e+00,  8.62885840e+07,  1.02905840e+01,
         0.00000000e+00,  1.00000000e+00,  0.00000000e+00],
       [ 9.28123321e+00,  1.79415702e+08,  1.01332233e+01,
         0.00000000e+00,  1.00000000e+00,  0.00000000e+00],
       [ 9.31942837e+00,  1.27653926e+08,  9.91067753e+00,
         0.00000000e+00,  1.00000000e+00,  0.00000000e+00],
       [ 9.25576849e+00,  1.89051436e+08,  9.766

In [62]:
from gym_anytrading.envs import StocksEnv

class MyCustomEnv(StocksEnv):
    """StocksEnv whose price/feature extraction is replaced by add_signals."""

    # A plain function stored as a class attribute becomes a method, so the
    # env's own self._process_data() call runs add_signals(self).
    _process_data = add_signals


# NOTE(review): this frame overlaps the (5012, 6700) range used later for
# evaluation. Don't evaluate a model trained on this env.
env2 = MyCustomEnv(
    df=df,
    window_size=12,
    frame_bound=(5001, 6700),
)

In [63]:
# Custom features per tick: [Low, Volume, SMA_5, RSI, SUPERTd_7_1.0, PSARr_0.02_0.2], NaNs filled with 0.
env2.signal_features

array([[ 4.05700720e+02,  3.67541800e+06,  4.10934485e+02,
         4.25042364e+01,  1.00000000e+00,  0.00000000e+00],
       [ 4.07462751e+02,  5.31176500e+06,  4.12687012e+02,
         4.35500728e+01,  1.00000000e+00,  0.00000000e+00],
       [ 4.12034576e+02,  6.19554800e+06,  4.12544147e+02,
         4.54334169e+01,  1.00000000e+00,  0.00000000e+00],
       ...,
       [ 2.64019995e+03,  3.41973000e+06,  2.63404209e+03,
         6.64535491e+01,  1.00000000e+00,  0.00000000e+00],
       [ 2.60494995e+03,  4.36612300e+06,  2.64014761e+03,
         5.85132709e+01,  1.00000000e+00,  0.00000000e+00],
       [ 2.58600000e+03,  4.37117900e+06,  2.63644834e+03,
         5.73850269e+01, -1.00000000e+00,  0.00000000e+00]])

In [64]:
import wandb
from wandb.integration.sb3 import WandbCallback

In [65]:
def env_maker():
    """Return the pre-built custom env (the same instance on every call)."""
    return env2


env = DummyVecEnv([env_maker])

In [66]:
# Train A2C on the custom-feature env, syncing metrics to Weights & Biases.
config = {
    "policy_type": "MlpPolicy",
    "total_timesteps": 1000000,
    "env_id": "Stocks-v0",  # NOTE(review): label only; the env here wraps MyCustomEnv
}
run = wandb.init(
    project="sb3-New",
    config=config,
    sync_tensorboard=True,  # auto-upload sb3's tensorboard metrics
    monitor_gym=True,  # auto-upload the videos of agents playing the game
    save_code=True,  # optional
)
# Close the run even if construction/training raises or is interrupted,
# so it is not left open in the kernel.
try:
    model_with_signals_A2C = A2C(config["policy_type"], env, verbose=1, tensorboard_log=f"runs/{run.id}")
    model_with_signals_A2C.learn(
        total_timesteps=config["total_timesteps"],
        callback=WandbCallback(
            model_save_path=f"models/{run.id}",
            verbose=2,
        ),
    )
finally:
    run.finish()

In [67]:
from gym_anytrading.envs import StocksEnv

class MyCustomEnv(StocksEnv):
    """StocksEnv whose price/feature extraction is replaced by add_signals."""

    # A plain function stored as a class attribute becomes a method, so the
    # env's own self._process_data() call runs add_signals(self).
    _process_data = add_signals


# Training frame: ticks up to 5000; later evaluation uses (5012, 6700).
env2 = MyCustomEnv(
    df=df,
    window_size=12,
    frame_bound=(12, 5000),
)

In [68]:
# Custom features per tick: [Low, Volume, SMA_5, RSI, SUPERTd_7_1.0, PSARr_0.02_0.2], NaNs filled with 0.
env2.signal_features

array([[ 1.03710464e+01,  4.80519950e+07,  0.00000000e+00,
         0.00000000e+00,  1.00000000e+00,  0.00000000e+00],
       [ 1.03201192e+01,  7.78750090e+07,  0.00000000e+00,
         0.00000000e+00,  1.00000000e+00,  1.00000000e+00],
       [ 1.04525269e+01,  9.66029360e+07,  0.00000000e+00,
         0.00000000e+00,  1.00000000e+00,  0.00000000e+00],
       ...,
       [ 3.94652159e+02,  5.36476900e+06,  4.04600623e+02,
         3.75337633e+01, -1.00000000e+00,  0.00000000e+00],
       [ 3.93413934e+02,  7.31637400e+06,  4.01043170e+02,
         3.99135617e+01, -1.00000000e+00,  1.00000000e+00],
       [ 3.99319180e+02,  7.98026400e+06,  4.02414697e+02,
         4.71081063e+01,  1.00000000e+00,  0.00000000e+00]])

In [69]:
import wandb
from wandb.integration.sb3 import WandbCallback

In [70]:
def env_maker():
    """Return the pre-built custom env (the same instance on every call)."""
    return env2


env = DummyVecEnv([env_maker])

In [71]:
# Train A2C on the custom-feature training frame, syncing metrics to W&B.
config = {
    "policy_type": "MlpPolicy",
    "total_timesteps": 1000000,
    "env_id": "Stocks-v0",  # NOTE(review): label only; the env here wraps MyCustomEnv
}
run = wandb.init(
    project="sb3-Latest",
    config=config,
    sync_tensorboard=True,  # auto-upload sb3's tensorboard metrics
    monitor_gym=True,  # auto-upload the videos of agents playing the game
    save_code=True,  # optional
)
# Close the run even if construction/training raises or is interrupted,
# so it is not left open in the kernel.
try:
    model_with_signals_A2C = A2C(config["policy_type"], env, verbose=1, tensorboard_log=f"runs/{run.id}")
    model_with_signals_A2C.learn(
        total_timesteps=config["total_timesteps"],
        callback=WandbCallback(
            model_save_path=f"models/{run.id}",
            verbose=2,
        ),
    )
finally:
    run.finish()

In [72]:
# Out-of-sample evaluation on ticks 5000-6699 (training used ticks < 5000).
# deterministic=True takes the policy's most likely action, so re-running this
# cell reproduces the same trajectory (the default samples stochastically).
env = MyCustomEnv(df=df, window_size=12, frame_bound=(5012, 6700))
obs = env.reset()
while True:
    # predict() accepts a single unbatched observation and returns one action.
    action, _states = model_with_signals_A2C.predict(obs, deterministic=True)
    obs, rewards, done, info = env.step(int(action))
    if done:
        print("info", info)
        break

In [73]:
# Plot the evaluation episode.
plt.figure(figsize=(15, 6))
env.render_all()
# plt.legend('PSB') splits the string into one label per character.
# Pass one label per drawn artist instead: presumably the price line, then the
# short and long position markers — verify against the installed gym-anytrading.
plt.legend(['Price', 'Sell (short)', 'Buy (long)'])
plt.title('A2C Performance')
plt.show()

In [74]:
# Wider plot of the evaluation episode.
plt.figure(figsize=(25, 6))
env.render_all()
# plt.legend('PSB') splits the string into one label per character.
# Pass one label per drawn artist instead: presumably the price line, then the
# short and long position markers — verify against the installed gym-anytrading.
plt.legend(['Price', 'Sell (short)', 'Buy (long)'])
plt.title('A2C Performance')
plt.show()

In [75]:
# Out-of-sample evaluation on ticks 5000-6699 (training used ticks < 5000).
# deterministic=True takes the policy's most likely action, so re-running this
# cell reproduces the same trajectory (the default samples stochastically).
env = MyCustomEnv(df=df, window_size=12, frame_bound=(5012, 6700))
obs = env.reset()
while True:
    # predict() accepts a single unbatched observation and returns one action.
    action, _states = model_with_signals_A2C.predict(obs, deterministic=True)
    obs, rewards, done, info = env.step(int(action))
    if done:
        print("info", info)
        break

In [76]:
# Out-of-sample evaluation on ticks 5000-6699 (training used ticks < 5000).
# deterministic=True takes the policy's most likely action, so re-running this
# cell reproduces the same trajectory (the default samples stochastically).
env = MyCustomEnv(df=df, window_size=12, frame_bound=(5012, 6700))
obs = env.reset()
while True:
    # predict() accepts a single unbatched observation and returns one action.
    action, _states = model_with_signals_A2C.predict(obs, deterministic=True)
    obs, rewards, done, info = env.step(int(action))
    if done:
        print("info", info)
        break

In [77]:
# Wider plot of the evaluation episode.
plt.figure(figsize=(25, 6))
env.render_all()
# plt.legend('PSB') splits the string into one label per character.
# Pass one label per drawn artist instead: presumably the price line, then the
# short and long position markers — verify against the installed gym-anytrading.
plt.legend(['Price', 'Sell (short)', 'Buy (long)'])
plt.title('A2C Performance')
plt.show()

In [78]:
import tensorboard

In [79]:
# Out-of-sample evaluation on ticks 5000-6699 (training used ticks < 5000).
# deterministic=True takes the policy's most likely action, so re-running this
# cell reproduces the same trajectory (the default samples stochastically).
env = MyCustomEnv(df=df, window_size=12, frame_bound=(5012, 6700))
obs = env.reset()
while True:
    # predict() accepts a single unbatched observation and returns one action.
    action, _states = model_with_signals_A2C.predict(obs, deterministic=True)
    obs, rewards, done, info = env.step(int(action))
    if done:
        print("info", info)
        break

In [80]:
# Out-of-sample evaluation on ticks 5000-6699 (training used ticks < 5000).
# deterministic=True takes the policy's most likely action, so re-running this
# cell reproduces the same trajectory (the default samples stochastically).
env = MyCustomEnv(df=df, window_size=12, frame_bound=(5012, 6700))
obs = env.reset()
while True:
    # predict() accepts a single unbatched observation and returns one action.
    action, _states = model_with_signals_A2C.predict(obs, deterministic=True)
    obs, rewards, done, info = env.step(int(action))
    if done:
        print("info", info)
        break

In [81]:
# Out-of-sample evaluation on ticks 5000-6699 (training used ticks < 5000).
# deterministic=True takes the policy's most likely action, so re-running this
# cell reproduces the same trajectory (the default samples stochastically).
env = MyCustomEnv(df=df, window_size=12, frame_bound=(5012, 6700))
obs = env.reset()
while True:
    # predict() accepts a single unbatched observation and returns one action.
    action, _states = model_with_signals_A2C.predict(obs, deterministic=True)
    obs, rewards, done, info = env.step(int(action))
    if done:
        print("info", info)
        break

In [82]:
# Out-of-sample evaluation on ticks 5000-6699 (training used ticks < 5000).
# deterministic=True takes the policy's most likely action, so re-running this
# cell reproduces the same trajectory (the default samples stochastically).
env = MyCustomEnv(df=df, window_size=12, frame_bound=(5012, 6700))
obs = env.reset()
while True:
    # predict() accepts a single unbatched observation and returns one action.
    action, _states = model_with_signals_A2C.predict(obs, deterministic=True)
    obs, rewards, done, info = env.step(int(action))
    if done:
        print("info", info)
        break

In [83]:
# Out-of-sample evaluation on ticks 5000-6699 (training used ticks < 5000).
# deterministic=True takes the policy's most likely action, so re-running this
# cell reproduces the same trajectory (the default samples stochastically).
env = MyCustomEnv(df=df, window_size=12, frame_bound=(5012, 6700))
obs = env.reset()
while True:
    # predict() accepts a single unbatched observation and returns one action.
    action, _states = model_with_signals_A2C.predict(obs, deterministic=True)
    obs, rewards, done, info = env.step(int(action))
    if done:
        print("info", info)
        break

In [84]:
# Out-of-sample evaluation on ticks 5000-6699 (training used ticks < 5000).
# deterministic=True takes the policy's most likely action, so re-running this
# cell reproduces the same trajectory (the default samples stochastically).
env = MyCustomEnv(df=df, window_size=12, frame_bound=(5012, 6700))
obs = env.reset()
while True:
    # predict() accepts a single unbatched observation and returns one action.
    action, _states = model_with_signals_A2C.predict(obs, deterministic=True)
    obs, rewards, done, info = env.step(int(action))
    if done:
        print("info", info)
        break

In [85]:
# Wider plot of the evaluation episode.
plt.figure(figsize=(25, 6))
env.render_all()
# plt.legend('PSB') splits the string into one label per character.
# Pass one label per drawn artist instead: presumably the price line, then the
# short and long position markers — verify against the installed gym-anytrading.
plt.legend(['Price', 'Sell (short)', 'Buy (long)'])
plt.title('A2C Performance')
plt.show()

In [86]:
def env_maker():
    """Return the pre-built custom env (the same instance on every call)."""
    return env2


env = DummyVecEnv([env_maker])

In [87]:
# Train PPO on the custom-feature training frame, syncing metrics to W&B.
config = {
    "policy_type": "MlpPolicy",
    "total_timesteps": 1000000,
    "env_id": "Stocks-v0",  # NOTE(review): label only; the env here wraps MyCustomEnv
}
run = wandb.init(
    project="sb3-Latest",
    config=config,
    sync_tensorboard=True,  # auto-upload sb3's tensorboard metrics
    monitor_gym=True,  # auto-upload the videos of agents playing the game
    save_code=True,  # optional
)
# Close the run even if construction/training raises or is interrupted,
# so it is not left open in the kernel.
try:
    model_with_signals_PPO = PPO(config["policy_type"], env, verbose=1, tensorboard_log=f"runs/{run.id}")
    model_with_signals_PPO.learn(
        total_timesteps=config["total_timesteps"],
        callback=WandbCallback(
            model_save_path=f"models/{run.id}",
            verbose=2,
        ),
    )
finally:
    run.finish()

In [88]:
# Out-of-sample PPO evaluation on ticks 5000-6699 (training used ticks < 5000).
# deterministic=True takes the policy's most likely action, so re-running this
# cell reproduces the same trajectory (the default samples stochastically).
env = MyCustomEnv(df=df, window_size=12, frame_bound=(5012, 6700))
obs = env.reset()
while True:
    # predict() accepts a single unbatched observation and returns one action.
    action, _states = model_with_signals_PPO.predict(obs, deterministic=True)
    obs, rewards, done, info = env.step(int(action))
    if done:
        print("info", info)
        break
        

In [89]:
# Plot the PPO evaluation episode.
plt.figure(figsize=(15, 6))
env.render_all()
# plt.legend('PSB') splits the string into one label per character.
# Pass one label per drawn artist instead: presumably the price line, then the
# short and long position markers — verify against the installed gym-anytrading.
plt.legend(['Price', 'Sell (short)', 'Buy (long)'])
plt.title('PPO Performance')
plt.show()

In [90]:
def env_maker():
    """Return the pre-built custom env (the same instance on every call)."""
    return env2


env = DummyVecEnv([env_maker])

In [91]:
from stable_baselines3 import DQN

In [92]:
# Train DQN on the custom-feature training frame, syncing metrics to W&B.
config = {
    "policy_type": "MlpPolicy",
    "total_timesteps": 1000000,
    "env_id": "Stocks-v0",  # NOTE(review): label only; the env here wraps MyCustomEnv
}
run = wandb.init(
    project="sb3",
    config=config,
    sync_tensorboard=True,  # auto-upload sb3's tensorboard metrics
    monitor_gym=True,  # auto-upload the videos of agents playing the game
    save_code=True,  # optional
)
# Close the run even if construction/training raises or is interrupted,
# so it is not left open in the kernel.
try:
    model_with_signals_DQN = DQN(config["policy_type"], env, verbose=1, tensorboard_log=f"runs/{run.id}")
    model_with_signals_DQN.learn(
        total_timesteps=config["total_timesteps"],
        callback=WandbCallback(
            model_save_path=f"models/{run.id}",
            verbose=2,
        ),
    )
finally:
    run.finish()

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016834687499999744, max=1.0…