In [1]:
import gym
import ptan
import numpy as np
from tensorboardX import SummaryWriter

import torch
import torch.nn as nn
import torch.optim as optim

In [2]:
# Discount factor: multiplies the bootstrapped next-state value in calc_target
# and is also passed as `gamma` to the ptan experience source.
GAMMA = 0.99
# Learning rate handed to the Adam optimizer.
LEARNING_RATE = 0.01
# Number of transitions sampled from the replay buffer for each optimisation step.
BATCH_SIZE = 8

# Exploration schedule used by the training loop:
#   epsilon = max(EPSILON_STOP, EPSILON_START - step_idx / EPSILON_STEPS)
# i.e. a linear decrease from EPSILON_START that is clamped at EPSILON_STOP.
EPSILON_START = 1.0
EPSILON_STOP = 0.02
EPSILON_STEPS = 5000

# Second argument to ptan's ExperienceReplayBuffer
# (presumably the maximum number of stored transitions -- confirm against ptan).
REPLAY_BUFFER = 50000

In [3]:
class DQN(nn.Module):
    """Small fully connected network: observation vector in, one value per action out.

    Architecture: Linear(input_size -> 128) -> ReLU -> Linear(128 -> n_actions).

    Args:
        input_size: number of features in a single observation.
        n_actions: number of outputs produced (one per discrete action).
    """

    def __init__(self, input_size, n_actions):
        super().__init__()

        hidden_size = 128
        layers = [
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, n_actions),
        ]
        # Keep the attribute name `net` so state-dict keys stay "net.<idx>.*".
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Run `x` through the stacked layers and return the resulting tensor."""
        q_values = self.net(x)
        return q_values

In [4]:
def calc_target(net, local_reward, next_state, gamma=None):
    """Compute the one-step bootstrapped target for a single transition.

    Args:
        net: callable model; given a float32 tensor holding one state with a
            leading batch dimension it must return a 2-D tensor whose dim 1
            indexes actions.
        local_reward: reward recorded for the transition.
        next_state: state reached after the transition (array-like), or None
            when there is no next state to bootstrap from.
        gamma: optional discount factor; when None the module-level GAMMA
            is used.

    Returns:
        `local_reward` unchanged when `next_state` is None, otherwise
        `local_reward + gamma * max_a net(next_state)[a]` as a Python float.
    """
    if next_state is None:
        return local_reward

    discount = GAMMA if gamma is None else gamma

    # Only a scalar is extracted from the forward pass, so skip building an
    # autograd graph for it.
    with torch.no_grad():
        # as_tensor + unsqueeze avoids the slow list-of-ndarrays conversion
        # while still producing a batch of exactly one state.
        state_v = torch.as_tensor(next_state, dtype=torch.float32).unsqueeze(0)
        best_q = net(state_v).max(dim=1)[0].item()

    return local_reward + discount * best_q

In [5]:
if __name__ == "__main__":
    # Train a DQN on CartPole-v0 from a replay buffer, one environment step and
    # one optimisation step per loop iteration, until the mean reward of the
    # last (up to) 100 finished episodes exceeds 195.
    env = gym.make("CartPole-v0")
    writer = SummaryWriter(comment="-cartpole-dqn")

    net = DQN(env.observation_space.shape[0], env.action_space.n)
    print(net)

    selector = ptan.actions.EpsilonGreedyActionSelector(epsilon=EPSILON_START)
    agent = ptan.agent.DQNAgent(net, selector, preprocessor=ptan.agent.float32_preprocessor)
    exp_source = ptan.experience.ExperienceSourceFirstLast(env, agent, gamma=GAMMA)
    replay_buffer = ptan.experience.ExperienceReplayBuffer(exp_source, REPLAY_BUFFER)

    optimizer = optim.Adam(net.parameters(), lr=LEARNING_RATE)
    mse_loss = nn.MSELoss()

    total_rewards = []
    step_idx = 0
    done_episodes = 0
    solved = False

    # try/finally guarantees the TensorBoard writer is flushed and the
    # environment released even when the (potentially very long) run is
    # interrupted or raises.
    try:
        while not solved:
            step_idx += 1
            # Linear epsilon decay, clamped at EPSILON_STOP.
            selector.epsilon = max(EPSILON_STOP, EPSILON_START - step_idx / EPSILON_STEPS)
            replay_buffer.populate(1)

            # Wait until a full batch can be sampled.
            if len(replay_buffer) < BATCH_SIZE:
                continue

            # sample batch
            batch = replay_buffer.sample(BATCH_SIZE)
            batch_states = [exp.state for exp in batch]
            batch_actions = [exp.action for exp in batch]
            # NOTE(review): targets are computed from the same network that is
            # being optimised (no separate target network).
            batch_targets = [calc_target(net, exp.reward, exp.last_state)
                             for exp in batch]

            # train
            optimizer.zero_grad()
            # Stack into one float32 array first; building a tensor straight
            # from a list of arrays is slow.
            states_v = torch.as_tensor(np.asarray(batch_states, dtype=np.float32))
            net_q_v = net(states_v)
            # Start from the network's own predictions and overwrite only the
            # entries of the actions actually taken, so the MSE error is zero
            # for every other action.
            target_q = net_q_v.detach().numpy().copy()
            target_q[range(BATCH_SIZE), batch_actions] = batch_targets
            target_q_v = torch.tensor(target_q)
            loss_v = mse_loss(net_q_v, target_q_v)
            loss_v.backward()
            optimizer.step()

            # handle new rewards: process every reward returned, not just the
            # first, so that no finished episode is dropped from the stats.
            for reward in exp_source.pop_total_rewards():
                done_episodes += 1
                total_rewards.append(reward)
                mean_rewards = float(np.mean(total_rewards[-100:]))
                print("%d: reward: %6.2f, mean_100: %6.2f, epsilon: %.2f, episodes: %d" % (
                    step_idx, reward, mean_rewards, selector.epsilon, done_episodes))
                writer.add_scalar("reward", reward, step_idx)
                writer.add_scalar("reward_100", mean_rewards, step_idx)
                writer.add_scalar("epsilon", selector.epsilon, step_idx)
                writer.add_scalar("episodes", done_episodes, step_idx)
                if mean_rewards > 195:
                    print("Solved in %d steps and %d episodes!" % (step_idx, done_episodes))
                    solved = True
                    break
    finally:
        writer.close()
        env.close()

DQN(
  (net): Sequential(
    (0): Linear(in_features=4, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=2, bias=True)
  )
)
65: reward:  63.00, mean_100:  63.00, epsilon: 0.99, episodes: 1
88: reward:  22.00, mean_100:  42.50, epsilon: 0.98, episodes: 2
121: reward:  32.00, mean_100:  39.00, epsilon: 0.98, episodes: 3
137: reward:  15.00, mean_100:  33.00, epsilon: 0.97, episodes: 4
152: reward:  14.00, mean_100:  29.20, epsilon: 0.97, episodes: 5
188: reward:  35.00, mean_100:  30.17, epsilon: 0.96, episodes: 6
214: reward:  25.00, mean_100:  29.43, epsilon: 0.96, episodes: 7
240: reward:  25.00, mean_100:  28.88, epsilon: 0.95, episodes: 8
261: reward:  20.00, mean_100:  27.89, epsilon: 0.95, episodes: 9
306: reward:  44.00, mean_100:  29.50, epsilon: 0.94, episodes: 10
319: reward:  12.00, mean_100:  27.91, epsilon: 0.94, episodes: 11
332: reward:  12.00, mean_100:  26.58, epsilon: 0.93, episodes: 12
347: reward:  14.00, mean_100:  25.62, 

6962: reward: 183.00, mean_100:  63.57, epsilon: 0.02, episodes: 120
7163: reward: 200.00, mean_100:  65.25, epsilon: 0.02, episodes: 121
7364: reward: 200.00, mean_100:  67.13, epsilon: 0.02, episodes: 122
7565: reward: 200.00, mean_100:  69.00, epsilon: 0.02, episodes: 123
7766: reward: 200.00, mean_100:  70.89, epsilon: 0.02, episodes: 124
7967: reward: 200.00, mean_100:  72.66, epsilon: 0.02, episodes: 125
8168: reward: 200.00, mean_100:  74.51, epsilon: 0.02, episodes: 126
8369: reward: 200.00, mean_100:  76.38, epsilon: 0.02, episodes: 127
8570: reward: 200.00, mean_100:  78.11, epsilon: 0.02, episodes: 128
8771: reward: 200.00, mean_100:  79.54, epsilon: 0.02, episodes: 129
8972: reward: 200.00, mean_100:  81.41, epsilon: 0.02, episodes: 130
9168: reward: 195.00, mean_100:  83.26, epsilon: 0.02, episodes: 131
9369: reward: 200.00, mean_100:  85.06, epsilon: 0.02, episodes: 132
9570: reward: 200.00, mean_100:  86.88, epsilon: 0.02, episodes: 133
9771: reward: 200.00, mean_100:  8

29958: reward: 200.00, mean_100: 192.83, epsilon: 0.02, episodes: 238
30159: reward: 200.00, mean_100: 192.83, epsilon: 0.02, episodes: 239
30360: reward: 200.00, mean_100: 192.83, epsilon: 0.02, episodes: 240
30561: reward: 200.00, mean_100: 192.83, epsilon: 0.02, episodes: 241
30762: reward: 200.00, mean_100: 192.83, epsilon: 0.02, episodes: 242
30963: reward: 200.00, mean_100: 192.83, epsilon: 0.02, episodes: 243
31154: reward: 190.00, mean_100: 192.73, epsilon: 0.02, episodes: 244
31355: reward: 200.00, mean_100: 192.73, epsilon: 0.02, episodes: 245
31556: reward: 200.00, mean_100: 192.73, epsilon: 0.02, episodes: 246
31757: reward: 200.00, mean_100: 192.73, epsilon: 0.02, episodes: 247
31958: reward: 200.00, mean_100: 192.73, epsilon: 0.02, episodes: 248
32139: reward: 180.00, mean_100: 192.53, epsilon: 0.02, episodes: 249
32340: reward: 200.00, mean_100: 192.53, epsilon: 0.02, episodes: 250
32541: reward: 200.00, mean_100: 192.53, epsilon: 0.02, episodes: 251
32742: reward: 200.0

51591: reward: 200.00, mean_100: 179.45, epsilon: 0.02, episodes: 356
51792: reward: 200.00, mean_100: 179.86, epsilon: 0.02, episodes: 357
51993: reward: 200.00, mean_100: 179.86, epsilon: 0.02, episodes: 358
52194: reward: 200.00, mean_100: 179.86, epsilon: 0.02, episodes: 359
52395: reward: 200.00, mean_100: 179.86, epsilon: 0.02, episodes: 360
52596: reward: 200.00, mean_100: 179.91, epsilon: 0.02, episodes: 361
52797: reward: 200.00, mean_100: 179.91, epsilon: 0.02, episodes: 362
52998: reward: 200.00, mean_100: 180.26, epsilon: 0.02, episodes: 363
53199: reward: 200.00, mean_100: 181.08, epsilon: 0.02, episodes: 364
53400: reward: 200.00, mean_100: 181.08, epsilon: 0.02, episodes: 365
53562: reward: 161.00, mean_100: 181.03, epsilon: 0.02, episodes: 366
53763: reward: 200.00, mean_100: 181.03, epsilon: 0.02, episodes: 367
53964: reward: 200.00, mean_100: 181.03, epsilon: 0.02, episodes: 368
54165: reward: 200.00, mean_100: 181.03, epsilon: 0.02, episodes: 369
54344: reward: 178.0

56868: reward:  21.00, mean_100:  19.61, epsilon: 0.02, episodes: 478
56894: reward:  25.00, mean_100:  17.86, epsilon: 0.02, episodes: 479
56917: reward:  22.00, mean_100:  16.68, epsilon: 0.02, episodes: 480
56940: reward:  22.00, mean_100:  15.52, epsilon: 0.02, episodes: 481
56955: reward:  14.00, mean_100:  13.66, epsilon: 0.02, episodes: 482
56969: reward:  13.00, mean_100:  12.88, epsilon: 0.02, episodes: 483
56985: reward:  15.00, mean_100:  11.94, epsilon: 0.02, episodes: 484
57010: reward:  24.00, mean_100:  12.10, epsilon: 0.02, episodes: 485
57037: reward:  26.00, mean_100:  12.26, epsilon: 0.02, episodes: 486
57062: reward:  24.00, mean_100:  12.39, epsilon: 0.02, episodes: 487
57077: reward:  14.00, mean_100:  12.44, epsilon: 0.02, episodes: 488
57094: reward:  16.00, mean_100:  12.51, epsilon: 0.02, episodes: 489
57112: reward:  17.00, mean_100:  12.58, epsilon: 0.02, episodes: 490
57134: reward:  21.00, mean_100:  12.70, epsilon: 0.02, episodes: 491
57148: reward:  13.0

70339: reward: 200.00, mean_100: 130.29, epsilon: 0.02, episodes: 596
70540: reward: 200.00, mean_100: 132.16, epsilon: 0.02, episodes: 597
70700: reward: 159.00, mean_100: 133.62, epsilon: 0.02, episodes: 598
70901: reward: 200.00, mean_100: 135.41, epsilon: 0.02, episodes: 599
71010: reward: 108.00, mean_100: 136.19, epsilon: 0.02, episodes: 600
71165: reward: 154.00, mean_100: 137.44, epsilon: 0.02, episodes: 601
71322: reward: 156.00, mean_100: 138.75, epsilon: 0.02, episodes: 602
71520: reward: 197.00, mean_100: 140.48, epsilon: 0.02, episodes: 603
71646: reward: 125.00, mean_100: 141.53, epsilon: 0.02, episodes: 604
71701: reward:  54.00, mean_100: 141.88, epsilon: 0.02, episodes: 605
71712: reward:  10.00, mean_100: 141.78, epsilon: 0.02, episodes: 606
71726: reward:  13.00, mean_100: 141.66, epsilon: 0.02, episodes: 607
71737: reward:  10.00, mean_100: 141.53, epsilon: 0.02, episodes: 608
71879: reward: 141.00, mean_100: 142.73, epsilon: 0.02, episodes: 609
72080: reward: 200.0

90318: reward: 144.00, mean_100: 173.92, epsilon: 0.02, episodes: 714
90519: reward: 200.00, mean_100: 173.92, epsilon: 0.02, episodes: 715
90677: reward: 157.00, mean_100: 174.14, epsilon: 0.02, episodes: 716
90878: reward: 200.00, mean_100: 174.65, epsilon: 0.02, episodes: 717
91062: reward: 183.00, mean_100: 174.65, epsilon: 0.02, episodes: 718
91263: reward: 200.00, mean_100: 174.65, epsilon: 0.02, episodes: 719
91464: reward: 200.00, mean_100: 175.17, epsilon: 0.02, episodes: 720
91665: reward: 200.00, mean_100: 175.17, epsilon: 0.02, episodes: 721
91799: reward: 133.00, mean_100: 174.50, epsilon: 0.02, episodes: 722
92000: reward: 200.00, mean_100: 174.50, epsilon: 0.02, episodes: 723
92201: reward: 200.00, mean_100: 174.50, epsilon: 0.02, episodes: 724
92389: reward: 187.00, mean_100: 174.37, epsilon: 0.02, episodes: 725
92487: reward:  97.00, mean_100: 173.34, epsilon: 0.02, episodes: 726
92498: reward:  10.00, mean_100: 171.44, epsilon: 0.02, episodes: 727
92509: reward:  10.0

98458: reward: 200.00, mean_100:  58.04, epsilon: 0.02, episodes: 832
98597: reward: 138.00, mean_100:  59.33, epsilon: 0.02, episodes: 833
98764: reward: 166.00, mean_100:  60.88, epsilon: 0.02, episodes: 834
98928: reward: 163.00, mean_100:  62.41, epsilon: 0.02, episodes: 835
99129: reward: 200.00, mean_100:  64.32, epsilon: 0.02, episodes: 836
99330: reward: 200.00, mean_100:  66.23, epsilon: 0.02, episodes: 837
99531: reward: 200.00, mean_100:  68.13, epsilon: 0.02, episodes: 838
99732: reward: 200.00, mean_100:  70.03, epsilon: 0.02, episodes: 839
99857: reward: 124.00, mean_100:  71.18, epsilon: 0.02, episodes: 840
100019: reward: 161.00, mean_100:  72.69, epsilon: 0.02, episodes: 841
100180: reward: 160.00, mean_100:  74.20, epsilon: 0.02, episodes: 842
100381: reward: 200.00, mean_100:  76.12, epsilon: 0.02, episodes: 843
100513: reward: 131.00, mean_100:  77.35, epsilon: 0.02, episodes: 844
100714: reward: 200.00, mean_100:  79.26, epsilon: 0.02, episodes: 845
100778: reward:

117006: reward: 180.00, mean_100: 157.83, epsilon: 0.02, episodes: 948
117207: reward: 200.00, mean_100: 159.19, epsilon: 0.02, episodes: 949
117408: reward: 200.00, mean_100: 159.57, epsilon: 0.02, episodes: 950
117598: reward: 189.00, mean_100: 160.50, epsilon: 0.02, episodes: 951
117761: reward: 162.00, mean_100: 160.39, epsilon: 0.02, episodes: 952
117899: reward: 137.00, mean_100: 160.58, epsilon: 0.02, episodes: 953
118100: reward: 200.00, mean_100: 161.52, epsilon: 0.02, episodes: 954
118301: reward: 200.00, mean_100: 161.52, epsilon: 0.02, episodes: 955
118502: reward: 200.00, mean_100: 161.52, epsilon: 0.02, episodes: 956
118703: reward: 200.00, mean_100: 161.52, epsilon: 0.02, episodes: 957
118884: reward: 180.00, mean_100: 162.32, epsilon: 0.02, episodes: 958
118997: reward: 112.00, mean_100: 161.44, epsilon: 0.02, episodes: 959
119198: reward: 200.00, mean_100: 161.44, epsilon: 0.02, episodes: 960
119380: reward: 181.00, mean_100: 161.57, epsilon: 0.02, episodes: 961
119550

136530: reward: 147.00, mean_100: 166.79, epsilon: 0.02, episodes: 1063
136731: reward: 200.00, mean_100: 166.79, epsilon: 0.02, episodes: 1064
136932: reward: 200.00, mean_100: 166.81, epsilon: 0.02, episodes: 1065
137133: reward: 200.00, mean_100: 166.81, epsilon: 0.02, episodes: 1066
137334: reward: 200.00, mean_100: 166.81, epsilon: 0.02, episodes: 1067
137535: reward: 200.00, mean_100: 167.01, epsilon: 0.02, episodes: 1068
137736: reward: 200.00, mean_100: 167.01, epsilon: 0.02, episodes: 1069
137937: reward: 200.00, mean_100: 167.01, epsilon: 0.02, episodes: 1070
138138: reward: 200.00, mean_100: 167.01, epsilon: 0.02, episodes: 1071
138339: reward: 200.00, mean_100: 167.38, epsilon: 0.02, episodes: 1072
138515: reward: 175.00, mean_100: 167.80, epsilon: 0.02, episodes: 1073
138658: reward: 142.00, mean_100: 167.75, epsilon: 0.02, episodes: 1074
138858: reward: 199.00, mean_100: 167.74, epsilon: 0.02, episodes: 1075
139059: reward: 200.00, mean_100: 167.98, epsilon: 0.02, episode

140321: reward:   8.00, mean_100:   9.43, epsilon: 0.02, episodes: 1178
140333: reward:  11.00, mean_100:   9.44, epsilon: 0.02, episodes: 1179
140342: reward:   8.00, mean_100:   9.41, epsilon: 0.02, episodes: 1180
140354: reward:  11.00, mean_100:   9.43, epsilon: 0.02, episodes: 1181
140365: reward:  10.00, mean_100:   9.44, epsilon: 0.02, episodes: 1182
140376: reward:  10.00, mean_100:   9.43, epsilon: 0.02, episodes: 1183
140387: reward:  10.00, mean_100:   9.45, epsilon: 0.02, episodes: 1184
140399: reward:  11.00, mean_100:   9.46, epsilon: 0.02, episodes: 1185
140410: reward:  10.00, mean_100:   9.47, epsilon: 0.02, episodes: 1186
140421: reward:  10.00, mean_100:   9.49, epsilon: 0.02, episodes: 1187
140431: reward:   9.00, mean_100:   9.48, epsilon: 0.02, episodes: 1188
140441: reward:   9.00, mean_100:   9.48, epsilon: 0.02, episodes: 1189
140452: reward:  10.00, mean_100:   9.49, epsilon: 0.02, episodes: 1190
140464: reward:  11.00, mean_100:   9.40, epsilon: 0.02, episode

146768: reward:  68.00, mean_100:  61.95, epsilon: 0.02, episodes: 1292
146832: reward:  63.00, mean_100:  62.48, epsilon: 0.02, episodes: 1293
146869: reward:  36.00, mean_100:  62.75, epsilon: 0.02, episodes: 1294
146920: reward:  50.00, mean_100:  63.15, epsilon: 0.02, episodes: 1295
146991: reward:  70.00, mean_100:  63.75, epsilon: 0.02, episodes: 1296
147062: reward:  70.00, mean_100:  64.36, epsilon: 0.02, episodes: 1297
147103: reward:  40.00, mean_100:  64.65, epsilon: 0.02, episodes: 1298
147172: reward:  68.00, mean_100:  65.24, epsilon: 0.02, episodes: 1299
147251: reward:  78.00, mean_100:  65.92, epsilon: 0.02, episodes: 1300
147300: reward:  48.00, mean_100:  66.32, epsilon: 0.02, episodes: 1301
147367: reward:  66.00, mean_100:  66.89, epsilon: 0.02, episodes: 1302
147452: reward:  84.00, mean_100:  67.63, epsilon: 0.02, episodes: 1303
147502: reward:  49.00, mean_100:  68.03, epsilon: 0.02, episodes: 1304
147555: reward:  52.00, mean_100:  68.46, epsilon: 0.02, episode

154907: reward: 142.00, mean_100:  71.34, epsilon: 0.02, episodes: 1407
155007: reward:  99.00, mean_100:  72.03, epsilon: 0.02, episodes: 1408
155103: reward:  95.00, mean_100:  72.76, epsilon: 0.02, episodes: 1409
155157: reward:  53.00, mean_100:  72.96, epsilon: 0.02, episodes: 1410
155263: reward: 105.00, mean_100:  73.69, epsilon: 0.02, episodes: 1411
155409: reward: 145.00, mean_100:  74.79, epsilon: 0.02, episodes: 1412
155610: reward: 200.00, mean_100:  76.25, epsilon: 0.02, episodes: 1413
155811: reward: 200.00, mean_100:  77.52, epsilon: 0.02, episodes: 1414
156012: reward: 200.00, mean_100:  78.67, epsilon: 0.02, episodes: 1415
156184: reward: 171.00, mean_100:  79.08, epsilon: 0.02, episodes: 1416
156385: reward: 200.00, mean_100:  80.72, epsilon: 0.02, episodes: 1417
156504: reward: 118.00, mean_100:  81.44, epsilon: 0.02, episodes: 1418
156592: reward:  87.00, mean_100:  81.93, epsilon: 0.02, episodes: 1419
156722: reward: 129.00, mean_100:  82.52, epsilon: 0.02, episode

166631: reward: 144.00, mean_100:  97.83, epsilon: 0.02, episodes: 1521
166832: reward: 200.00, mean_100:  99.04, epsilon: 0.02, episodes: 1522
167010: reward: 177.00, mean_100:  99.72, epsilon: 0.02, episodes: 1523
167158: reward: 147.00, mean_100:  99.19, epsilon: 0.02, episodes: 1524
167263: reward: 104.00, mean_100:  98.23, epsilon: 0.02, episodes: 1525
167361: reward:  97.00, mean_100:  98.18, epsilon: 0.02, episodes: 1526
167473: reward: 111.00, mean_100:  97.29, epsilon: 0.02, episodes: 1527
167585: reward: 111.00, mean_100:  97.07, epsilon: 0.02, episodes: 1528
167605: reward:  19.00, mean_100:  96.24, epsilon: 0.02, episodes: 1529
167621: reward:  15.00, mean_100:  96.15, epsilon: 0.02, episodes: 1530
167634: reward:  12.00, mean_100:  95.15, epsilon: 0.02, episodes: 1531
167650: reward:  15.00, mean_100:  95.02, epsilon: 0.02, episodes: 1532
167667: reward:  16.00, mean_100:  94.29, epsilon: 0.02, episodes: 1533
167685: reward:  17.00, mean_100:  92.46, epsilon: 0.02, episode

184137: reward: 200.00, mean_100: 163.32, epsilon: 0.02, episodes: 1635
184338: reward: 200.00, mean_100: 164.29, epsilon: 0.02, episodes: 1636
184539: reward: 200.00, mean_100: 166.13, epsilon: 0.02, episodes: 1637
184740: reward: 200.00, mean_100: 168.00, epsilon: 0.02, episodes: 1638
184941: reward: 200.00, mean_100: 169.86, epsilon: 0.02, episodes: 1639
185142: reward: 200.00, mean_100: 171.62, epsilon: 0.02, episodes: 1640
185343: reward: 200.00, mean_100: 173.49, epsilon: 0.02, episodes: 1641
185544: reward: 200.00, mean_100: 175.38, epsilon: 0.02, episodes: 1642
185729: reward: 184.00, mean_100: 176.94, epsilon: 0.02, episodes: 1643
185911: reward: 181.00, mean_100: 178.59, epsilon: 0.02, episodes: 1644
186112: reward: 200.00, mean_100: 180.43, epsilon: 0.02, episodes: 1645
186276: reward: 163.00, mean_100: 181.92, epsilon: 0.02, episodes: 1646
186477: reward: 200.00, mean_100: 183.75, epsilon: 0.02, episodes: 1647
186678: reward: 200.00, mean_100: 184.44, epsilon: 0.02, episode