In [1]:
import gym
import tianshou as ts
import torch, numpy as np
from torch import nn
from torch.utils.tensorboard import SummaryWriter
from tianshou.utils import BasicLogger


In [5]:
test = [(1,2,3), (2,3,4), (3,4,5), (4,5,6)]
for item in zip(*test):
    print(item)

(1, 2, 3, 4)
(2, 3, 4, 5)
(3, 4, 5, 6)


In [None]:
print(*test)

# Self-defined experiment

In [8]:
from module import *
from module.agent_network import *
from module.environment import market_envrionment
import tensorflow as tf
from tqdm import tqdm
import matplotlib.pyplot as plt
tf.config.run_functions_eagerly(True)

In [9]:
# initialization for agent and environment
env = market_envrionment()
state_size = env.observation_space[0] # given from environment
action_size = env.action_space.shape[0]
agent = DQNAgent(state_size, action_size)

# hyper-parameter
done = False
batch_size = 5
history_1 = []
EPISODES = 100

for e in tqdm(range(EPISODES)): # one episode is M trading years in a period
    
    # initialize state
    state = env.reset()
    state = np.reshape(state, [1, state_size]) # an array containing only one array [[a,b,c,d]]
    rewards = 0
    
    for time in range(17): # how many years in a training period

        # take an action
        action = agent.act(state)
        
        # environment responds to the action and return new state and reward
        next_state, reward, done = env.step(action)
        
        # record reward
        rewards += reward
        
        # reshape state(can be reshaped within env)
        next_state = np.reshape(next_state, [1, state_size]) # an 1 x n 2d array
        
        # record the experience for replay
        agent.memorize(state, action, reward, next_state, done) # record every trading 
        
        # transit to next state
        state = next_state
        
        # determine if the training is over or not
        if done:
            break
            
        # replay to train the network    
        if len(agent.memory) > batch_size: # batch_size = 2 to make agent learn for every 3 trading events
            agent.replay(batch_size)
            
    print("episode: {}/{}, e: {:.2}, rewards: {}"
                    .format(e, EPISODES, agent.epsilon, rewards))
    history_1.append([e, time, agent.epsilon, rewards])

  "The `lr` argument is deprecated, use `learning_rate` instead.")
  "Even though the `tf.config.experimental_run_functions_eagerly` "
  1%|          | 1/100 [00:05<09:28,  5.74s/it]

episode: 0/100, e: 0.94, rewards: 1.5114835969498384


  2%|▏         | 2/100 [00:13<11:37,  7.11s/it]

episode: 1/100, e: 0.86, rewards: 1.4361682103095597


  3%|▎         | 3/100 [00:21<11:49,  7.32s/it]

episode: 2/100, e: 0.79, rewards: 1.5498683842038905


  4%|▍         | 4/100 [00:30<12:32,  7.83s/it]

episode: 3/100, e: 0.73, rewards: 1.4524489991580327


  5%|▌         | 5/100 [00:37<12:18,  7.77s/it]

episode: 4/100, e: 0.67, rewards: 1.6443779736264352


  6%|▌         | 6/100 [00:45<12:16,  7.84s/it]

episode: 5/100, e: 0.61, rewards: 1.5249821208487324


  7%|▋         | 7/100 [00:53<12:12,  7.87s/it]

episode: 6/100, e: 0.56, rewards: 1.5882407896299229


  8%|▊         | 8/100 [01:03<13:15,  8.65s/it]

episode: 7/100, e: 0.52, rewards: 1.5875649062539117


  9%|▉         | 9/100 [01:12<13:19,  8.78s/it]

episode: 8/100, e: 0.48, rewards: 1.4945008380948677


 10%|█         | 10/100 [01:20<12:46,  8.52s/it]

episode: 9/100, e: 0.44, rewards: 1.5458187525020994


 11%|█         | 11/100 [01:28<12:19,  8.31s/it]

episode: 10/100, e: 0.4, rewards: 1.296511047095364


 12%|█▏        | 12/100 [01:37<12:13,  8.33s/it]

episode: 11/100, e: 0.37, rewards: 1.4182194540000683


 13%|█▎        | 13/100 [01:45<12:01,  8.29s/it]

episode: 12/100, e: 0.34, rewards: 1.4909754365581167


 14%|█▍        | 14/100 [01:52<11:28,  8.01s/it]

episode: 13/100, e: 0.31, rewards: 1.5933883288523436


 15%|█▌        | 15/100 [01:59<11:02,  7.80s/it]

episode: 14/100, e: 0.29, rewards: 1.5512512683216255


 16%|█▌        | 16/100 [02:07<10:41,  7.64s/it]

episode: 15/100, e: 0.26, rewards: 1.5400461083787267


 17%|█▋        | 17/100 [02:14<10:22,  7.50s/it]

episode: 16/100, e: 0.24, rewards: 1.5174830936307417


 18%|█▊        | 18/100 [02:22<10:17,  7.53s/it]

episode: 17/100, e: 0.22, rewards: 1.474712581353935


 19%|█▉        | 19/100 [02:29<10:03,  7.45s/it]

episode: 18/100, e: 0.2, rewards: 1.4998871605883939


 20%|██        | 20/100 [02:36<09:54,  7.43s/it]

episode: 19/100, e: 0.19, rewards: 1.475099133237893


 21%|██        | 21/100 [02:43<09:41,  7.36s/it]

episode: 20/100, e: 0.17, rewards: 1.4832341811402752


 22%|██▏       | 22/100 [02:51<09:34,  7.36s/it]

episode: 21/100, e: 0.16, rewards: 1.6185206851445328


 23%|██▎       | 23/100 [02:58<09:23,  7.31s/it]

episode: 22/100, e: 0.14, rewards: 1.4706521009435336


 24%|██▍       | 24/100 [03:05<09:13,  7.28s/it]

episode: 23/100, e: 0.13, rewards: 1.4542694546694799


 25%|██▌       | 25/100 [03:13<09:09,  7.32s/it]

episode: 24/100, e: 0.12, rewards: 1.5085139834857035


 26%|██▌       | 26/100 [03:20<09:02,  7.33s/it]

episode: 25/100, e: 0.11, rewards: 1.5702502231977116


 27%|██▋       | 27/100 [03:27<09:00,  7.41s/it]

episode: 26/100, e: 0.1, rewards: 1.5331123224783316


 28%|██▊       | 28/100 [03:35<08:51,  7.39s/it]

episode: 27/100, e: 0.094, rewards: 1.5410023367519683


 29%|██▉       | 29/100 [03:42<08:48,  7.44s/it]

episode: 28/100, e: 0.087, rewards: 1.5244423214447334


 30%|███       | 30/100 [03:50<08:40,  7.43s/it]

episode: 29/100, e: 0.08, rewards: 1.5221819066937623


 31%|███       | 31/100 [03:57<08:36,  7.49s/it]

episode: 30/100, e: 0.073, rewards: 1.4532938388341679


 32%|███▏      | 32/100 [04:05<08:26,  7.44s/it]

episode: 31/100, e: 0.067, rewards: 1.65865303634445


 33%|███▎      | 33/100 [04:12<08:16,  7.42s/it]

episode: 32/100, e: 0.062, rewards: 1.384366864393528


 34%|███▍      | 34/100 [04:20<08:11,  7.45s/it]

episode: 33/100, e: 0.057, rewards: 1.4618548438690482


 35%|███▌      | 35/100 [04:27<08:03,  7.43s/it]

episode: 34/100, e: 0.052, rewards: 1.51251801844733


 36%|███▌      | 36/100 [04:35<07:58,  7.48s/it]

episode: 35/100, e: 0.048, rewards: 1.5818395168318526


 37%|███▋      | 37/100 [04:42<07:49,  7.45s/it]

episode: 36/100, e: 0.044, rewards: 1.4798275968512886


 38%|███▊      | 38/100 [04:50<07:44,  7.50s/it]

episode: 37/100, e: 0.04, rewards: 1.562797508130004


 39%|███▉      | 39/100 [04:57<07:35,  7.47s/it]

episode: 38/100, e: 0.037, rewards: 1.5344039115704373


 40%|████      | 40/100 [05:05<07:29,  7.48s/it]

episode: 39/100, e: 0.034, rewards: 1.5073190087230788


 41%|████      | 41/100 [05:12<07:22,  7.49s/it]

episode: 40/100, e: 0.031, rewards: 1.57851125412942


 42%|████▏     | 42/100 [05:19<07:12,  7.46s/it]

episode: 41/100, e: 0.029, rewards: 1.625786560779943


 43%|████▎     | 43/100 [05:27<07:08,  7.52s/it]

episode: 42/100, e: 0.026, rewards: 1.4231748073039057


 44%|████▍     | 44/100 [05:35<07:02,  7.55s/it]

episode: 43/100, e: 0.024, rewards: 1.4672352323197142


 45%|████▌     | 45/100 [05:42<06:56,  7.57s/it]

episode: 44/100, e: 0.022, rewards: 1.5328954909003927


 46%|████▌     | 46/100 [05:50<06:46,  7.53s/it]

episode: 45/100, e: 0.02, rewards: 1.4974446748123151


 47%|████▋     | 47/100 [05:57<06:40,  7.55s/it]

episode: 46/100, e: 0.019, rewards: 1.5656956366002677


 48%|████▊     | 48/100 [06:05<06:30,  7.51s/it]

episode: 47/100, e: 0.017, rewards: 1.5777990307683811


 49%|████▉     | 49/100 [06:12<06:24,  7.54s/it]

episode: 48/100, e: 0.016, rewards: 1.5550855724681014


 50%|█████     | 50/100 [06:20<06:14,  7.50s/it]

episode: 49/100, e: 0.014, rewards: 1.4624250920640527


 51%|█████     | 51/100 [06:27<06:05,  7.46s/it]

episode: 50/100, e: 0.013, rewards: 1.492149604431204


 52%|█████▏    | 52/100 [06:35<06:00,  7.50s/it]

episode: 51/100, e: 0.012, rewards: 1.6173147154731535


 53%|█████▎    | 53/100 [06:42<05:51,  7.48s/it]

episode: 52/100, e: 0.011, rewards: 1.4213452500819126


 54%|█████▍    | 54/100 [06:50<05:45,  7.51s/it]

episode: 53/100, e: 0.01, rewards: 1.4632892785014369


 55%|█████▌    | 55/100 [06:57<05:36,  7.48s/it]

episode: 54/100, e: 0.01, rewards: 1.3978625656305346


 56%|█████▌    | 56/100 [07:05<05:33,  7.57s/it]

episode: 55/100, e: 0.01, rewards: 1.5411885460791612


 57%|█████▋    | 57/100 [07:12<05:24,  7.55s/it]

episode: 56/100, e: 0.01, rewards: 1.4133025707670779


 58%|█████▊    | 58/100 [07:20<05:18,  7.57s/it]

episode: 57/100, e: 0.01, rewards: 1.420090209620016


 59%|█████▉    | 59/100 [07:27<05:08,  7.52s/it]

episode: 58/100, e: 0.01, rewards: 1.5397617446133027


 60%|██████    | 60/100 [07:35<04:59,  7.49s/it]

episode: 59/100, e: 0.01, rewards: 1.5285201853366415


 61%|██████    | 61/100 [07:43<04:53,  7.53s/it]

episode: 60/100, e: 0.01, rewards: 1.3605097961667805


 62%|██████▏   | 62/100 [07:50<04:44,  7.49s/it]

episode: 61/100, e: 0.01, rewards: 1.48923272962952


 63%|██████▎   | 63/100 [07:58<04:38,  7.52s/it]

episode: 62/100, e: 0.01, rewards: 1.6627059238702693


 64%|██████▍   | 64/100 [08:05<04:29,  7.48s/it]

episode: 63/100, e: 0.01, rewards: 1.4871391874786077


 65%|██████▌   | 65/100 [08:13<04:22,  7.51s/it]

episode: 64/100, e: 0.01, rewards: 1.4344345764356148


 66%|██████▌   | 66/100 [08:20<04:14,  7.47s/it]

episode: 65/100, e: 0.01, rewards: 1.4709511321425048


 67%|██████▋   | 67/100 [08:28<04:08,  7.53s/it]

episode: 66/100, e: 0.01, rewards: 1.5737260780648101


 68%|██████▊   | 68/100 [08:35<04:01,  7.54s/it]

episode: 67/100, e: 0.01, rewards: 1.4344754260031916


 69%|██████▉   | 69/100 [08:43<03:52,  7.50s/it]

episode: 68/100, e: 0.01, rewards: 1.474772889928426


 70%|███████   | 70/100 [08:50<03:46,  7.54s/it]

episode: 69/100, e: 0.01, rewards: 1.314930722419385


 71%|███████   | 71/100 [08:58<03:37,  7.50s/it]

episode: 70/100, e: 0.01, rewards: 1.4933500991472526


 72%|███████▏  | 72/100 [09:05<03:31,  7.54s/it]

episode: 71/100, e: 0.01, rewards: 1.4105258899531326


 73%|███████▎  | 73/100 [09:13<03:22,  7.49s/it]

episode: 72/100, e: 0.01, rewards: 1.5591342224156846


 74%|███████▍  | 74/100 [09:20<03:15,  7.53s/it]

episode: 73/100, e: 0.01, rewards: 1.481283253965126


 75%|███████▌  | 75/100 [09:28<03:07,  7.49s/it]

episode: 74/100, e: 0.01, rewards: 1.566481660055593


 76%|███████▌  | 76/100 [09:35<03:00,  7.50s/it]

episode: 75/100, e: 0.01, rewards: 1.419343498393765


 77%|███████▋  | 77/100 [09:43<02:51,  7.47s/it]

episode: 76/100, e: 0.01, rewards: 1.7480691428962654


 78%|███████▊  | 78/100 [09:50<02:43,  7.45s/it]

episode: 77/100, e: 0.01, rewards: 1.5139052061475928


 79%|███████▉  | 79/100 [09:58<02:37,  7.50s/it]

episode: 78/100, e: 0.01, rewards: 1.3807509407448841


 80%|████████  | 80/100 [10:05<02:29,  7.46s/it]

episode: 79/100, e: 0.01, rewards: 1.6039580208059487


 81%|████████  | 81/100 [10:13<02:22,  7.52s/it]

episode: 80/100, e: 0.01, rewards: 1.5516321367817958


 82%|████████▏ | 82/100 [10:20<02:14,  7.49s/it]

episode: 81/100, e: 0.01, rewards: 1.464408623274913


 83%|████████▎ | 83/100 [10:28<02:08,  7.53s/it]

episode: 82/100, e: 0.01, rewards: 1.553570359095182


 84%|████████▍ | 84/100 [10:35<01:59,  7.49s/it]

episode: 83/100, e: 0.01, rewards: 1.5040864675948307


 85%|████████▌ | 85/100 [10:43<01:52,  7.53s/it]

episode: 84/100, e: 0.01, rewards: 1.4794321386026554


 86%|████████▌ | 86/100 [10:50<01:44,  7.49s/it]

episode: 85/100, e: 0.01, rewards: 1.4576506928324389


 87%|████████▋ | 87/100 [10:58<01:37,  7.52s/it]

episode: 86/100, e: 0.01, rewards: 1.5410329291491454


 88%|████████▊ | 88/100 [11:05<01:29,  7.48s/it]

episode: 87/100, e: 0.01, rewards: 1.52687826880751


 89%|████████▉ | 89/100 [11:12<01:22,  7.46s/it]

episode: 88/100, e: 0.01, rewards: 1.6196696518257299


 90%|█████████ | 90/100 [11:20<01:15,  7.51s/it]

episode: 89/100, e: 0.01, rewards: 1.5286589263143109


 91%|█████████ | 91/100 [11:27<01:07,  7.48s/it]

episode: 90/100, e: 0.01, rewards: 1.5837018938117822


 92%|█████████▏| 92/100 [11:35<01:00,  7.56s/it]

episode: 91/100, e: 0.01, rewards: 1.4825246536886256


 93%|█████████▎| 93/100 [11:43<00:52,  7.51s/it]

episode: 92/100, e: 0.01, rewards: 1.544007871324209


 94%|█████████▍| 94/100 [11:50<00:45,  7.54s/it]

episode: 93/100, e: 0.01, rewards: 1.5535341898794328


 95%|█████████▌| 95/100 [11:58<00:37,  7.50s/it]

episode: 94/100, e: 0.01, rewards: 1.521378554870091


 96%|█████████▌| 96/100 [12:05<00:30,  7.52s/it]

episode: 95/100, e: 0.01, rewards: 1.6919492527180708


 97%|█████████▋| 97/100 [12:13<00:22,  7.49s/it]

episode: 96/100, e: 0.01, rewards: 1.6434065315890103


 98%|█████████▊| 98/100 [12:20<00:14,  7.47s/it]

episode: 97/100, e: 0.01, rewards: 1.4426326860519474


 99%|█████████▉| 99/100 [12:28<00:07,  7.55s/it]

episode: 98/100, e: 0.01, rewards: 1.5262979089853963


100%|██████████| 100/100 [12:35<00:00,  7.56s/it]

episode: 99/100, e: 0.01, rewards: 1.5460787303204409





In [10]:
def training_plot(history):
    df = pd.DataFrame(history, columns =["episode", "total_time","epsilon",'reward'])
    df.set_index("episode")
    df.to_csv('../output/RL_training_data_1.jpg')
    figs, axs = plt.subplots(nrows=2, ncols=1, figsize=(12, 6))

    axs[0].set_title('Total Rewards for each episode')
    axs[0].plot(df.index, df['reward'].rolling(5).mean())
    axs[0].set_ylabel('total rewards')

    axs[1].set_title('Epsilon(exploring rate) for each episode')
    axs[1].plot(df.index, df['epsilon'])
    axs[1].set_ylabel('epsilon')

    plt.tight_layout()
    # figs.savefig('../picture/RL_training_plots_1.jpg')
    plt.show()

In [11]:
training_plot(history_1)

NameError: name 'plt' is not defined

# Switch to other platform

In [2]:
train_envs = ts.env.DummyVectorEnv([lambda: gym.make('CartPole-v0') for _ in range(8)])
test_envs = ts.env.DummyVectorEnv([lambda: gym.make('CartPole-v0') for _ in range(100)])

env = gym.make('CartPole-v0')

In [5]:
np.prod(state_shape)

4

In [6]:
state_shape

(4,)

In [7]:
class Net(nn.Module):
    def __init__(self, state_shape, action_shape):
        super().__init__()
        self.model = nn.Sequential(*[
            nn.Linear(np.prod(state_shape), 128), nn.ReLU(inplace=True),
            nn.Linear(128, 128), nn.ReLU(inplace=True),
            nn.Linear(128, 128), nn.ReLU(inplace=True),
            nn.Linear(128, np.prod(action_shape))
        ])
    def forward(self, obs, state=None, info={}):
        if not isinstance(obs, torch.Tensor):
            obs = torch.tensor(obs, dtype=torch.float)
        batch = obs.shape[0]
        logits = self.model(obs.view(batch, -1))
        return logits, state

state_shape = env.observation_space.shape or env.observation_space.n
action_shape = env.action_space.shape or env.action_space.n
net = Net(state_shape, action_shape)
optim = torch.optim.Adam(net.parameters(), lr=1e-3)

In [9]:
net.parameters

<bound method Module.parameters of Net(
  (model): Sequential(
    (0): Linear(in_features=4, out_features=128, bias=True)
    (1): ReLU(inplace=True)
    (2): Linear(in_features=128, out_features=128, bias=True)
    (3): ReLU(inplace=True)
    (4): Linear(in_features=128, out_features=128, bias=True)
    (5): ReLU(inplace=True)
    (6): Linear(in_features=128, out_features=2, bias=True)
  )
)>

In [4]:
policy = ts.policy.DQNPolicy(net, optim, discount_factor=0.9, estimation_step=3, target_update_freq=320)


In [5]:
train_collector = ts.data.Collector(policy, train_envs, ts.data.VectorReplayBuffer(20000, 10), exploration_noise=True)
test_collector = ts.data.Collector(policy, test_envs, exploration_noise=True)


In [6]:
result = ts.trainer.offpolicy_trainer(
    policy, 
    train_collector, 
    test_collector,
    max_epoch=10, 
    step_per_epoch=10000, 
    step_per_collect=10,
    update_per_step=0.1, 
    episode_per_test=100,
    batch_size=64,
    train_fn=lambda epoch, env_step: policy.set_eps(0.1),
    test_fn=lambda epoch, env_step: policy.set_eps(0.05),
    stop_fn=lambda mean_rewards: mean_rewards >= env.spec.reward_threshold)
print(f'Finished training! Use {result["duration"]}')

  f"n_step={n_step} is not a multiple of #env ({self.env_num}), "
Epoch #1:  47%|####7     | 4736/10000 [00:03<00:04, 1195.54it/s, env_step=4736, len=200, n/ep=1, n/st=16, rew=200.00]

Finished training! Use 4.10s





In [7]:
policy.eval()
policy.set_eps(0.05)
collector = ts.data.Collector(policy, env, exploration_noise=True)
collector.collect(n_episode=1, render=1 / 35)



{'n/ep': 1,
 'n/st': 200,
 'rews': array([200.]),
 'lens': array([200]),
 'idxs': array([0])}

In [8]:
writer = SummaryWriter('log/dqn')
logger = BasicLogger(writer)

In [9]:
result

{'test_step': 22037,
 'test_episode': 200,
 'test_time': '0.91s',
 'test_speed': '24227.64 step/s',
 'best_reward': 198.07,
 'best_result': '198.07 ± 4.76',
 'duration': '4.10s',
 'train_time/model': '2.52s',
 'train_step': 4736,
 'train_episode': 233,
 'train_time/collector': '0.67s',
 'train_speed': '1484.76 step/s'}