In [1]:
import gym
import tianshou as ts
import torch
from tianshou.utils.net.common import Net

In [2]:
# Gym id of the task; every environment constructor in this notebook reads it.
envName = "CartPole-v1"
# Standalone instance used to query the observation/action spaces and the env spec.
env = gym.make(id=envName)

In [3]:
train_envs = ts.env.DummyVectorEnv([lambda: gym.make(envName) for _ in range(10)])
test_envs = ts.env.DummyVectorEnv([lambda: gym.make(envName, render_mode="human") for _ in range(100)])

In [4]:
from tianshou.utils.net.common import Net
state_shape = env.observation_space.shape or env.observation_space.n
action_shape = env.action_space.shape or env.action_space.n
net = Net(state_shape, action_shape, hidden_sizes=(128,128,128,128))
optim = torch.optim.Adam(net.parameters(), lr=1e-3)

In [5]:
policy = ts.policy.DQNPolicy(net, optim, discount_factor=0.9, estimation_step=3, target_update_freq=320)

In [6]:
train_collector = ts.data.Collector(policy, train_envs, ts.data.VectorReplayBuffer(20000, 10), exploration_noise=True)
test_collector = ts.data.Collector(policy, train_envs, exploration_noise=True)

In [7]:
# Registered spec for the env; its reward_threshold is the stop criterion used in the trainer cell.
env.spec

EnvSpec(id='CartPole-v1', entry_point='gym.envs.classic_control.cartpole:CartPoleEnv', reward_threshold=475.0, nondeterministic=False, max_episode_steps=500, order_enforce=True, autoreset=False, disable_env_checker=False, apply_api_compatibility=False, kwargs={}, namespace=None, name='CartPole', version=1)

In [8]:
# Target score: training stops once the mean test reward reaches this value (see stop_fn).
env.spec.reward_threshold

475.0

In [9]:
# Off-policy DQN training. Each iteration collects 10 env steps and performs
# 10 * 0.1 = 1 gradient update (update_per_step is updates per collected step).
# The policy is evaluated on 100 test episodes per epoch, and training stops early
# once the mean test reward reaches the env's reward threshold.
result = ts.trainer.offpolicy_trainer(
    policy,
    train_collector,
    test_collector,
    max_epoch=20,
    step_per_epoch=10000,
    step_per_collect=10,
    update_per_step=0.1,
    batch_size=64,
    episode_per_test=100,
    # Epsilon-greedy exploration rate: 0.1 while training, 0.05 while evaluating.
    train_fn=lambda epoch, env_step: policy.set_eps(0.1),
    test_fn=lambda epoch, env_step: policy.set_eps(0.05),
    stop_fn=lambda mean_rewards: mean_rewards >= env.spec.reward_threshold,
)
print(f"Finished training! Use {result['duration']}")

Epoch #1: 10001it [00:08, 1147.10it/s, env_step=10000, len=186, loss=0.310, n/ep=0, n/st=10, rew=186.00]                           


Epoch #1: test_reward: 237.130000 ± 51.760923, best_reward: 237.130000 ± 51.760923 in #1


Epoch #2: 10001it [00:07, 1252.16it/s, env_step=20000, len=146, loss=0.366, n/ep=0, n/st=10, rew=146.00]                           


Epoch #2: test_reward: 267.290000 ± 61.752457, best_reward: 267.290000 ± 61.752457 in #2


Epoch #3: 10001it [00:08, 1129.22it/s, env_step=30000, len=188, loss=0.069, n/ep=0, n/st=10, rew=188.00]                          


Epoch #3: test_reward: 189.390000 ± 21.418634, best_reward: 267.290000 ± 61.752457 in #2


Epoch #4: 10001it [00:11, 876.85it/s, env_step=40000, len=165, loss=0.029, n/ep=0, n/st=10, rew=165.00]                            


Epoch #4: test_reward: 207.150000 ± 36.487087, best_reward: 267.290000 ± 61.752457 in #2


Epoch #5: 10001it [00:07, 1326.48it/s, env_step=50000, len=182, loss=0.077, n/ep=0, n/st=10, rew=182.00]                           


Epoch #5: test_reward: 142.660000 ± 7.922399, best_reward: 267.290000 ± 61.752457 in #2


Epoch #6: 10001it [00:07, 1305.26it/s, env_step=60000, len=201, loss=0.039, n/ep=0, n/st=10, rew=201.00]                           


Epoch #6: test_reward: 175.550000 ± 16.113581, best_reward: 267.290000 ± 61.752457 in #2


Epoch #7: 10001it [00:07, 1279.39it/s, env_step=70000, len=207, loss=0.026, n/ep=0, n/st=10, rew=207.00]                           


Epoch #7: test_reward: 175.860000 ± 16.971753, best_reward: 267.290000 ± 61.752457 in #2


Epoch #8: 10001it [00:07, 1286.96it/s, env_step=80000, len=199, loss=0.028, n/ep=0, n/st=10, rew=199.00]                           


Epoch #8: test_reward: 199.850000 ± 14.148056, best_reward: 267.290000 ± 61.752457 in #2


Epoch #9: 10001it [00:13, 763.81it/s, env_step=90000, len=476, loss=0.031, n/ep=0, n/st=10, rew=476.00]                           


Epoch #9: test_reward: 241.460000 ± 25.836184, best_reward: 267.290000 ± 61.752457 in #2


Epoch #10: 10001it [00:16, 619.61it/s, env_step=100000, len=1229, loss=0.097, n/ep=0, n/st=10, rew=1229.00]                           


Epoch #10: test_reward: 305.130000 ± 67.663233, best_reward: 305.130000 ± 67.663233 in #10


Epoch #11: 10001it [00:17, 583.01it/s, env_step=110000, len=566, loss=0.111, n/ep=0, n/st=10, rew=566.00]                             


Epoch #11: test_reward: 158.230000 ± 8.707302, best_reward: 305.130000 ± 67.663233 in #10


Epoch #12: 10001it [00:15, 639.12it/s, env_step=120000, len=868, loss=0.118, n/ep=0, n/st=10, rew=868.50]                             


Epoch #12: test_reward: 155.290000 ± 7.511718, best_reward: 305.130000 ± 67.663233 in #10


Epoch #13: 10001it [00:15, 635.66it/s, env_step=130000, len=1112, loss=0.101, n/ep=0, n/st=10, rew=1112.00]                           


Epoch #13: test_reward: 131.870000 ± 4.564329, best_reward: 305.130000 ± 67.663233 in #10


Epoch #14:  72%|#######1  | 7170/10000 [00:17<00:06, 417.27it/s, env_step=137170, len=1278, n/ep=1, n/st=10, rew=1278.00]             

Finished training! Use 186.48s





In [10]:
# Summary returned by offpolicy_trainer: timings, step/episode counts and best test reward.
result

{'duration': '186.48s',
 'train_time/model': '86.97s',
 'test_step': 857772,
 'test_episode': 5400,
 'test_time': '77.86s',
 'test_speed': '11017.17 step/s',
 'best_reward': 500.0,
 'best_result': '500.00 ± 0.00',
 'train_step': 137170,
 'train_episode': 932,
 'train_time/collector': '21.65s',
 'train_speed': '1262.85 step/s'}

In [11]:
# Watch the trained agent play in an on-screen window.
policy.eval()
# eps only takes effect when the collector applies exploration noise; with
# exploration_noise=False below the agent is greedy, so make that explicit with eps=0.
policy.set_eps(0.0)
render_env = gym.make(envName, render_mode="human")
collector = ts.data.Collector(policy, render_env, exploration_noise=False)
try:
    playback_stats = collector.collect(n_episode=10, render=1 / 60)
finally:
    # Close the render window even if playback is interrupted (e.g. KeyboardInterrupt).
    render_env.close()
playback_stats



KeyboardInterrupt: 