In [1]:
import gym
import tianshou as ts
import torch

In [2]:
# Gym environment id shared by every env created in this notebook.
envName = "CartPole-v1"
# Reference env: its observation/action spaces size the Q-network and its
# spec supplies the reward threshold used as the training stop criterion.
env = gym.make(id=envName)

In [3]:
train_envs = ts.env.DummyVectorEnv([lambda: gym.make(envName) for _ in range(10)])
test_envs = ts.env.DummyVectorEnv([lambda: gym.make(envName, render_mode="human") for _ in range(100)])

In [4]:
from tianshou.utils.net.common import Net
state_shape = env.observation_space.shape or env.observation_space.n
action_shape = env.action_space.shape or env.action_space.n
net = Net(state_shape, action_shape, hidden_sizes=(128,128,128,128))
optim = torch.optim.Adam(net.parameters(), lr=1e-3)

In [5]:
policy = ts.policy.DQNPolicy(net, optim, discount_factor=0.9, estimation_step=3, target_update_freq=320)

In [6]:
train_collector = ts.data.Collector(policy, train_envs, ts.data.VectorReplayBuffer(20000, 10), exploration_noise=True)
test_collector = ts.data.Collector(policy, train_envs, exploration_noise=True)

In [7]:
# Inspect the env's registration spec (reward_threshold, max_episode_steps, ...).
env.spec

EnvSpec(id='CartPole-v1', entry_point='gym.envs.classic_control.cartpole:CartPoleEnv', reward_threshold=475.0, nondeterministic=False, max_episode_steps=500, order_enforce=True, autoreset=False, disable_env_checker=False, apply_api_compatibility=False, kwargs={}, namespace=None, name='CartPole', version=1)

In [8]:
# Target score: the trainer's stop_fn compares the mean test reward against this.
env.spec.reward_threshold

475.0

In [9]:
# Off-policy DQN training. Exploration eps is 0.1 while collecting training data
# and 0.05 during evaluation; training stops early once the mean test reward
# reaches the environment's reward_threshold.
def _set_train_eps(epoch, env_step):
    """Trainer hook run before each training phase: eps-greedy with eps=0.1."""
    policy.set_eps(0.1)


def _set_test_eps(epoch, env_step):
    """Trainer hook run before each test phase: eps-greedy with eps=0.05."""
    policy.set_eps(0.05)


def _reached_threshold(mean_rewards):
    """Stop criterion: True once the mean test reward hits the env's threshold."""
    return mean_rewards >= env.spec.reward_threshold


result = ts.trainer.offpolicy_trainer(
    policy,
    train_collector,
    test_collector,
    max_epoch=20,
    step_per_epoch=10000,
    # 10 env steps per collect x update_per_step=0.1 -> one gradient update per collect.
    step_per_collect=10,
    update_per_step=0.1,
    episode_per_test=100,
    batch_size=64,
    train_fn=_set_train_eps,
    test_fn=_set_test_eps,
    stop_fn=_reached_threshold,
)
print(f"Finished training! Use {result['duration']}")

Epoch #1: 10001it [00:07, 1405.44it/s, env_step=10000, len=137, loss=0.351, n/ep=0, n/st=10, rew=137.00]                           


Epoch #1: test_reward: 206.040000 ± 44.547485, best_reward: 206.040000 ± 44.547485 in #1


Epoch #2: 10001it [00:06, 1568.23it/s, env_step=20000, len=177, loss=0.386, n/ep=0, n/st=10, rew=177.00]                           


Epoch #2: test_reward: 210.830000 ± 24.216133, best_reward: 210.830000 ± 24.216133 in #2


Epoch #3: 10001it [00:06, 1564.43it/s, env_step=30000, len=227, loss=0.057, n/ep=0, n/st=10, rew=227.00]                           


Epoch #3: test_reward: 178.170000 ± 12.687045, best_reward: 210.830000 ± 24.216133 in #2


Epoch #4: 10001it [00:07, 1414.14it/s, env_step=40000, len=149, loss=0.025, n/ep=0, n/st=10, rew=149.00]                           


Epoch #4: test_reward: 176.080000 ± 11.839493, best_reward: 210.830000 ± 24.216133 in #2


Epoch #5: 10001it [00:06, 1463.88it/s, env_step=50000, len=176, loss=0.026, n/ep=0, n/st=10, rew=176.00]                           


Epoch #5: test_reward: 196.890000 ± 14.592392, best_reward: 210.830000 ± 24.216133 in #2


Epoch #6: 10001it [00:07, 1394.38it/s, env_step=60000, len=154, loss=0.025, n/ep=0, n/st=10, rew=154.00]                           


Epoch #6: test_reward: 172.770000 ± 10.669447, best_reward: 210.830000 ± 24.216133 in #2


Epoch #7: 10001it [00:06, 1591.69it/s, env_step=70000, len=158, loss=0.041, n/ep=0, n/st=10, rew=158.00]                           


Epoch #7: test_reward: 203.140000 ± 25.707205, best_reward: 210.830000 ± 24.216133 in #2


Epoch #8: 10001it [00:06, 1522.81it/s, env_step=80000, len=180, loss=0.057, n/ep=0, n/st=10, rew=180.00]                           


Epoch #8: test_reward: 167.100000 ± 7.091544, best_reward: 210.830000 ± 24.216133 in #2


Epoch #9: 10001it [00:06, 1562.25it/s, env_step=90000, len=230, loss=0.013, n/ep=0, n/st=10, rew=230.00]                           


Epoch #9: test_reward: 155.050000 ± 9.343848, best_reward: 210.830000 ± 24.216133 in #2


Epoch #10: 10001it [00:06, 1571.68it/s, env_step=100000, len=142, loss=0.034, n/ep=0, n/st=10, rew=142.00]                           


Epoch #10: test_reward: 103.350000 ± 9.334211, best_reward: 210.830000 ± 24.216133 in #2


Epoch #11: 10001it [00:06, 1562.57it/s, env_step=110000, len=40, loss=0.085, n/ep=0, n/st=10, rew=40.00]                            


Epoch #11: test_reward: 105.590000 ± 3.243748, best_reward: 210.830000 ± 24.216133 in #2


Epoch #12: 10001it [00:06, 1561.73it/s, env_step=120000, len=139, loss=0.185, n/ep=0, n/st=10, rew=139.00]                           


Epoch #12: test_reward: 83.370000 ± 34.078044, best_reward: 210.830000 ± 24.216133 in #2


Epoch #13: 10001it [00:06, 1546.68it/s, env_step=130000, len=129, loss=0.136, n/ep=0, n/st=10, rew=129.00]                           


Epoch #13: test_reward: 128.560000 ± 10.053179, best_reward: 210.830000 ± 24.216133 in #2


Epoch #14: 10001it [00:06, 1571.12it/s, env_step=140000, len=132, loss=0.084, n/ep=0, n/st=10, rew=132.00]                           


Epoch #14: test_reward: 138.270000 ± 6.581573, best_reward: 210.830000 ± 24.216133 in #2


Epoch #15: 10001it [00:06, 1571.50it/s, env_step=150000, len=203, loss=0.060, n/ep=0, n/st=10, rew=203.00]                           


Epoch #15: test_reward: 170.650000 ± 14.485424, best_reward: 210.830000 ± 24.216133 in #2


Epoch #16: 10001it [00:15, 629.76it/s, env_step=160000, len=833, loss=0.131, n/ep=0, n/st=10, rew=833.00]                           


Epoch #16: test_reward: 128.720000 ± 8.747663, best_reward: 210.830000 ± 24.216133 in #2


Epoch #17: 10001it [00:06, 1581.91it/s, env_step=170000, len=207, loss=0.056, n/ep=0, n/st=10, rew=207.00]                           


Epoch #17: test_reward: 326.030000 ± 85.559389, best_reward: 326.030000 ± 85.559389 in #17


Epoch #18: 10001it [00:07, 1395.22it/s, env_step=180000, len=179, loss=0.073, n/ep=0, n/st=10, rew=179.00]                           


Epoch #18: test_reward: 140.660000 ± 5.150184, best_reward: 326.030000 ± 85.559389 in #17


Epoch #19: 10001it [00:11, 863.19it/s, env_step=190000, len=125, loss=0.226, n/ep=0, n/st=10, rew=125.00]                            


Epoch #19: test_reward: 97.230000 ± 10.822065, best_reward: 326.030000 ± 85.559389 in #17


Epoch #20: 10001it [00:10, 991.19it/s, env_step=200000, len=552, loss=0.255, n/ep=0, n/st=10, rew=552.00]                           


Epoch #20: test_reward: 115.900000 ± 10.931148, best_reward: 326.030000 ± 85.559389 in #17
Finished training! Use 174.45s


In [10]:
# Summary dict returned by the trainer: timings, step/episode counts, best test reward.
result

{'duration': '174.45s',
 'train_time/model': '105.86s',
 'test_step': 559467,
 'test_episode': 3500,
 'test_time': '43.05s',
 'test_speed': '12995.26 step/s',
 'best_reward': 326.03,
 'best_result': '326.03 ± 85.56',
 'train_step': 200000,
 'train_episode': 1441,
 'train_time/collector': '25.54s',
 'train_speed': '1522.13 step/s'}

In [11]:
# Watch the trained agent play 10 episodes in a rendered window.
policy.eval()
# With exploration_noise=False the collector never applies the policy's
# eps-greedy noise, so the agent acts greedily; set eps to 0 so the policy's
# state agrees with that (the previous 0.05 had no effect here).
policy.set_eps(0.0)
demo_env = gym.make(envName, render_mode="human")
try:
    collector = ts.data.Collector(policy, demo_env, exploration_noise=False)
    # render=1/60: sleep ~1/60 s between rendered frames.
    demo_stats = collector.collect(n_episode=10, render=1 / 60)
finally:
    # Release the render window even if the demo is interrupted.
    demo_env.close()
demo_stats

