## Check if installed

In [1]:
# The import only succeeds when tianshou is installed; printing the version
# records which release this notebook was run against.
import tianshou as ts
print(ts.__version__)

0.4.10


## Make an environment

In [2]:
import gym

def _make_cartpole():
    """Return a fresh CartPole-v0 instance (a cart carrying a pole along a track)."""
    return gym.make("CartPole-v0")


# Single environment, used for inspecting the spaces and for watching the agent.
env = _make_cartpole()

# Vectorized environments: 10 copies for training and 100 copies for testing.
# Each list entry is a factory that DummyVectorEnv calls to build one environment.
train_envs = ts.env.DummyVectorEnv([_make_cartpole for _ in range(10)])
test_envs = ts.env.DummyVectorEnv([_make_cartpole for _ in range(100)])

In [3]:
# EnvPool could be used instead to speed up environment stepping.
# NOTE: could not get envpool to install under Anaconda, so it is left disabled.
#import envpool
#train_envs = envpool.make_gym("CartPole-v0", num_envs=10)
#test_envs = envpool.make_gym("CartPole-v0", num_envs=100)

## Build the network

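**Old code that did not work: `action_shape` was computed from `env.observation_space` instead of `env.action_space`. It was replaced by the version in the next cell, copy-pasted from the website.**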
import torch, numpy as np
from torch import nn

class Net(nn.Module):
    def __init__(self, state_shape, action_shape):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(np.prod(state_shape), 128), nn.ReLU(inplace=True),
            nn.Linear(128, 128), nn.ReLU(inplace=True),
            nn.Linear(128, 128), nn.ReLU(inplace=True),
            nn.Linear(128, np.prod(action_shape)),
            )
            
    def forward(self, obs, state=None, info={}):
        if not isinstance(obs, torch.Tensor):
            obs = torch.tensor(obs, dtype=torch.float)  
        batch = obs.shape[0]
        logits = self.model(obs.view(batch, -1))
        return logits, state  
    
state_shape = env.observation_space.shape or env.observation_space.n
action_shape = env.observation_space.shape or env.observation_space.n
net = Net(state_shape, action_shape)
optim = torch.optim.Adam(net.parameters(), lr=1e-3)

In [3]:
import torch, numpy as np
from torch import nn

class Net(nn.Module):
    """Fully connected network mapping a flattened observation to one output per action.

    The network has three hidden layers of 128 units, each followed by a ReLU.

    Args:
        state_shape: Shape of a single observation (tuple of ints) or its
            total size (int).
        action_shape: Shape of the action space (tuple of ints) or the number
            of discrete actions (int).
    """

    def __init__(self, state_shape, action_shape):
        super().__init__()
        # np.prod returns a NumPy scalar; nn.Linear expects plain Python ints.
        in_features = int(np.prod(state_shape))
        out_features = int(np.prod(action_shape))
        self.model = nn.Sequential(
            nn.Linear(in_features, 128), nn.ReLU(inplace=True),
            nn.Linear(128, 128), nn.ReLU(inplace=True),
            nn.Linear(128, 128), nn.ReLU(inplace=True),
            nn.Linear(128, out_features),
        )

    def forward(self, obs, state=None, info=None):
        """Compute the raw network outputs (logits) for a batch of observations.

        Args:
            obs: Batch of observations with the batch dimension first. Anything
                that is not already a ``torch.Tensor`` is converted to a float
                tensor.
            state: Passed through unchanged and returned as the second element.
            info: Accepted for interface compatibility; not used here. Defaults
                to ``None`` rather than a shared mutable ``{}``.

        Returns:
            A ``(logits, state)`` tuple where ``logits`` has one row per
            observation in the batch.
        """
        if not isinstance(obs, torch.Tensor):
            obs = torch.tensor(obs, dtype=torch.float)
        batch = obs.shape[0]
        # Flatten everything after the batch dimension before the first Linear layer.
        logits = self.model(obs.view(batch, -1))
        return logits, state

# Use the space's `.shape` when it is non-empty, otherwise fall back to `.n`
# (presumably the discrete case, where `.shape` is empty).
state_shape = env.observation_space.shape or env.observation_space.n
action_shape = env.action_space.shape or env.action_space.n

# Build the network for these sizes and an Adam optimizer over its parameters.
net = Net(state_shape=state_shape, action_shape=action_shape)
optim = torch.optim.Adam(params=net.parameters(), lr=1e-3)

## Setup policy

In [4]:
# Build the DQN policy from the network `net` and the optimizer `optim` defined above.
policy = ts.policy.DQNPolicy(
    net,
    optim,
    discount_factor=0.9,
    estimation_step=3,
    target_update_freq=320,
)

## Setup collector
Collector is a Tianshou concept.

"Allows the policy to interact with different types of environments conveniently."

Number of buffers should be the number of environments.

In [5]:
# The replay buffer needs one sub-buffer per training environment, so the count
# is taken from `train_envs` instead of repeating the literal 10 used to build it.
train_collector = ts.data.Collector(
    policy,
    train_envs,
    ts.data.VectorReplayBuffer(20000, buffer_num=len(train_envs)),
    exploration_noise=True,
)
# No buffer is passed for the test collector.
test_collector = ts.data.Collector(policy, test_envs, exploration_noise=True)

## Train policy with a Trainer

Several options for trainer. DQN is an off-policy algorithm, so `offpolicy_trainer()` is used. It stops training when `stop_fn` condition is reached.

In [6]:
from torch.utils.tensorboard import SummaryWriter
from tianshou.utils import TensorboardLogger
#writer = SummaryWriter("log/dqn")
#logger = TensorboardLogger(writer)  # Gets massive warnings if passed into result

load_model = True  # True: reuse the weights saved in "dqn.pth"; False: train a new model

if not load_model:
    result = ts.trainer.offpolicy_trainer(
        policy, train_collector, test_collector,
        max_epoch=10, step_per_epoch=10000, step_per_collect=10,
        update_per_step=0.1, episode_per_test=100, batch_size=64,
        # Exploration epsilon: 0.1 while training, 0.05 while testing.
        train_fn=lambda epoch, env_step: policy.set_eps(0.1),
        test_fn=lambda epoch, env_step: policy.set_eps(0.05),
        # Stop as soon as the mean reward reaches the environment's own threshold.
        stop_fn=lambda mean_rewards: mean_rewards >= env.spec.reward_threshold)
    print(f'Finished training! Use {result["duration"]}')

## Save and load policy

In [8]:
# Only write the weights to disk when a new model was trained in this run.
if not load_model:
    torch.save(policy.state_dict(), "dqn.pth")

In [9]:
# Restore previously saved weights instead of training from scratch.
if load_model:
    saved_state = torch.load("dqn.pth")
    policy.load_state_dict(saved_state)

## Watch performance

In [10]:
# Put the policy in evaluation mode and set its epsilon to 0.05.
policy.eval()
policy.set_eps(0.05)

# Run one rendered episode on the single (non-vectorized) environment.
collector = ts.data.Collector(policy=policy, env=env, exploration_noise=True)
# WARNING: the kernel crashes once the animation finishes.
collector.collect(n_episode=1, render=1 / 35)  # 1/35 s between frames, i.e. 35 fps



{'n/ep': 1,
 'n/st': 200,
 'rews': array([200.]),
 'lens': array([200]),
 'idxs': array([0]),
 'rew': 200.0,
 'len': 200.0,
 'rew_std': 0.0,
 'len_std': 0.0}