In [2]:
import argparse
import gym
import numpy as np
from itertools import count

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical

args = {
    'gamma': 0.99,
    'seed': 543,
    'render': True,
    'log_interval': 200
}

env = gym.make('CartPole-v1')
env.seed(args['seed'])
torch.manual_seed(args['seed'])


class Policy(nn.Module):
    def __init__(self):
        super(Policy, self).__init__()
        self.affine1 = nn.Linear(4, 128)
        self.affine2 = nn.Linear(128, 2)

        self.saved_log_probs = []
        self.rewards = []

    def forward(self, x):
        x = F.relu(self.affine1(x))
        action_scores = self.affine2(x)
        return F.softmax(action_scores, dim=1)


policy = Policy()
optimizer = optim.Adam(policy.parameters(), lr=1e-2)
eps = np.finfo(np.float32).eps.item()


def select_action(state):
    state = torch.from_numpy(state).float().unsqueeze(0)
    probs = policy(state)
    m = Categorical(probs)
    action = m.sample()
    policy.saved_log_probs.append(m.log_prob(action))
    return action.item()


def finish_episode():
    R = 0
    policy_loss = []
    rewards = []
    for r in policy.rewards[::-1]:
        R = r + args['gamma'] * R
        rewards.insert(0, R)
    rewards = torch.tensor(rewards)
    rewards = (rewards - rewards.mean()) / (rewards.std() + eps)
    for log_prob, reward in zip(policy.saved_log_probs, rewards):
        policy_loss.append(-log_prob * reward)
    optimizer.zero_grad()
    policy_loss = torch.cat(policy_loss).sum()
    policy_loss.backward()
    optimizer.step()
    del policy.rewards[:]
    del policy.saved_log_probs[:]


def main():
    running_reward = 10
    for i_episode in count(1):
        state = env.reset()
        for t in range(10000):  # Don't infinite loop while learning
            action = select_action(state)
            state, reward, done, _ = env.step(action)
            if args['render']:
                env.render()
            policy.rewards.append(reward)
            if done:
                break

        running_reward = running_reward * 0.99 + t * 0.01
        finish_episode()
        if i_episode % args['log_interval'] == 0:
            print('Episode {}\tLast length: {:5d}\tAverage length: {:.2f}'.format(
                i_episode, t, running_reward))
        if running_reward > env.spec.reward_threshold:
            print("Solved! Running reward is now {} and "
                  "the last episode runs to {} time steps!".format(running_reward, t))
            break


main()


Episode 200	Last length:     9	Average length: 10.18
Episode 400	Last length:     9	Average length: 8.58
Episode 600	Last length:     9	Average length: 8.32
Episode 800	Last length:     8	Average length: 8.49
Episode 1000	Last length:     8	Average length: 8.36
Episode 1200	Last length:     9	Average length: 8.37
Episode 1400	Last length:     9	Average length: 8.48
Episode 1600	Last length:     8	Average length: 8.41
Episode 1800	Last length:     8	Average length: 8.39
Episode 2000	Last length:     8	Average length: 8.41
Episode 2200	Last length:     7	Average length: 8.40
Episode 2400	Last length:     9	Average length: 8.33
Episode 2600	Last length:     8	Average length: 8.53
Episode 2800	Last length:     8	Average length: 13.64
Episode 3000	Last length:     8	Average length: 9.11
Episode 3200	Last length:     9	Average length: 8.50
Episode 3400	Last length:     7	Average length: 8.40
Episode 3600	Last length:     9	Average length: 8.41
Episode 3800	Last length:     9	Average length: 

Episode 30800	Last length:     9	Average length: 8.38
Episode 31000	Last length:     8	Average length: 8.42
Episode 31200	Last length:     8	Average length: 8.27
Episode 31400	Last length:     9	Average length: 8.30
Episode 31600	Last length:     9	Average length: 8.32
Episode 31800	Last length:    10	Average length: 8.43
Episode 32000	Last length:     8	Average length: 8.46
Episode 32200	Last length:     7	Average length: 8.29
Episode 32400	Last length:     8	Average length: 8.36
Episode 32600	Last length:     7	Average length: 8.37
Episode 32800	Last length:     9	Average length: 8.35
Episode 33000	Last length:     9	Average length: 8.39
Episode 33200	Last length:     8	Average length: 8.34
Episode 33400	Last length:     7	Average length: 8.30
Episode 33600	Last length:     9	Average length: 8.33
Episode 33800	Last length:     9	Average length: 8.36
Episode 34000	Last length:     9	Average length: 8.36
Episode 34200	Last length:     9	Average length: 8.38
Episode 34400	Last length:  

Episode 61200	Last length:     8	Average length: 8.35
Episode 61400	Last length:     8	Average length: 8.30
Episode 61600	Last length:     8	Average length: 8.22
Episode 61800	Last length:     9	Average length: 8.37
Episode 62000	Last length:     9	Average length: 8.37
Episode 62200	Last length:     9	Average length: 8.34
Episode 62400	Last length:     7	Average length: 8.27
Episode 62600	Last length:     7	Average length: 8.36
Episode 62800	Last length:     9	Average length: 8.38
Episode 63000	Last length:     9	Average length: 8.46
Episode 63200	Last length:     9	Average length: 8.42
Episode 63400	Last length:     8	Average length: 8.44
Episode 63600	Last length:     9	Average length: 8.33
Episode 63800	Last length:     9	Average length: 8.40
Episode 64000	Last length:     9	Average length: 8.31
Episode 64200	Last length:     8	Average length: 8.40
Episode 64400	Last length:     9	Average length: 8.31
Episode 64600	Last length:     7	Average length: 8.32
Episode 64800	Last length:  

Episode 91600	Last length:     8	Average length: 8.35
Episode 91800	Last length:     7	Average length: 8.36
Episode 92000	Last length:     7	Average length: 8.34
Episode 92200	Last length:     9	Average length: 8.36
Episode 92400	Last length:     8	Average length: 8.37
Episode 92600	Last length:     9	Average length: 8.31
Episode 92800	Last length:     8	Average length: 8.28
Episode 93000	Last length:     9	Average length: 8.36
Episode 93200	Last length:     8	Average length: 8.37
Episode 93400	Last length:     8	Average length: 8.24
Episode 93600	Last length:     7	Average length: 8.22
Episode 93800	Last length:     9	Average length: 8.40
Episode 94000	Last length:     8	Average length: 8.37
Episode 94200	Last length:     8	Average length: 8.47
Episode 94400	Last length:     8	Average length: 8.40
Episode 94600	Last length:     8	Average length: 8.42
Episode 94800	Last length:     8	Average length: 8.35
Episode 95000	Last length:     9	Average length: 8.32
Episode 95200	Last length:  

Episode 121600	Last length:     8	Average length: 8.38
Episode 121800	Last length:     9	Average length: 8.35
Episode 122000	Last length:     8	Average length: 8.40
Episode 122200	Last length:     8	Average length: 8.30
Episode 122400	Last length:     9	Average length: 8.31
Episode 122600	Last length:     8	Average length: 8.31
Episode 122800	Last length:     8	Average length: 8.36
Episode 123000	Last length:     7	Average length: 8.26
Episode 123200	Last length:     8	Average length: 8.33
Episode 123400	Last length:     8	Average length: 8.40
Episode 123600	Last length:     9	Average length: 8.42
Episode 123800	Last length:     9	Average length: 8.39
Episode 124000	Last length:     9	Average length: 8.39
Episode 124200	Last length:     9	Average length: 8.37
Episode 124400	Last length:     9	Average length: 8.30
Episode 124600	Last length:     9	Average length: 8.26
Episode 124800	Last length:     9	Average length: 8.24
Episode 125000	Last length:     8	Average length: 8.32
Episode 12

Episode 151400	Last length:     8	Average length: 8.47
Episode 151600	Last length:     9	Average length: 8.34
Episode 151800	Last length:     9	Average length: 8.32
Episode 152000	Last length:     7	Average length: 8.35
Episode 152200	Last length:     9	Average length: 8.38
Episode 152400	Last length:     9	Average length: 8.28
Episode 152600	Last length:     9	Average length: 8.42
Episode 152800	Last length:     8	Average length: 8.28
Episode 153000	Last length:     8	Average length: 8.36
Episode 153200	Last length:     8	Average length: 8.40
Episode 153400	Last length:     8	Average length: 8.34
Episode 153600	Last length:     8	Average length: 8.25
Episode 153800	Last length:     8	Average length: 8.37
Episode 154000	Last length:     9	Average length: 8.30
Episode 154200	Last length:     7	Average length: 8.19
Episode 154400	Last length:     8	Average length: 8.36
Episode 154600	Last length:     8	Average length: 8.29
Episode 154800	Last length:     9	Average length: 8.43
Episode 15

Episode 181200	Last length:     8	Average length: 8.36
Episode 181400	Last length:     9	Average length: 8.28
Episode 181600	Last length:     8	Average length: 8.28
Episode 181800	Last length:     9	Average length: 8.30
Episode 182000	Last length:     8	Average length: 8.32
Episode 182200	Last length:     8	Average length: 8.34
Episode 182400	Last length:     8	Average length: 8.32
Episode 182600	Last length:     9	Average length: 8.46
Episode 182800	Last length:     9	Average length: 8.33
Episode 183000	Last length:     9	Average length: 8.34
Episode 183200	Last length:     9	Average length: 8.41
Episode 183400	Last length:     8	Average length: 8.34
Episode 183600	Last length:     9	Average length: 8.38
Episode 183800	Last length:     7	Average length: 8.32
Episode 184000	Last length:     8	Average length: 8.34
Episode 184200	Last length:     9	Average length: 8.29
Episode 184400	Last length:     8	Average length: 8.31
Episode 184600	Last length:     9	Average length: 8.26
Episode 18

Episode 211000	Last length:    10	Average length: 8.41
Episode 211200	Last length:     7	Average length: 8.39
Episode 211400	Last length:     7	Average length: 8.29
Episode 211600	Last length:     9	Average length: 8.43
Episode 211800	Last length:     8	Average length: 8.43
Episode 212000	Last length:     7	Average length: 8.35
Episode 212200	Last length:     8	Average length: 8.29
Episode 212400	Last length:     8	Average length: 8.26
Episode 212600	Last length:     8	Average length: 8.33
Episode 212800	Last length:    10	Average length: 8.36
Episode 213000	Last length:     8	Average length: 8.39
Episode 213200	Last length:     9	Average length: 8.37
Episode 213400	Last length:     7	Average length: 8.20
Episode 213600	Last length:     7	Average length: 8.34
Episode 213800	Last length:     8	Average length: 8.36
Episode 214000	Last length:     8	Average length: 8.42
Episode 214200	Last length:     9	Average length: 8.32
Episode 214400	Last length:     8	Average length: 8.28
Episode 21

Episode 240800	Last length:     8	Average length: 8.43
Episode 241000	Last length:     8	Average length: 8.39
Episode 241200	Last length:     8	Average length: 8.37
Episode 241400	Last length:     9	Average length: 8.38
Episode 241600	Last length:     9	Average length: 8.42
Episode 241800	Last length:     9	Average length: 8.41
Episode 242000	Last length:     8	Average length: 8.39
Episode 242200	Last length:     8	Average length: 8.32
Episode 242400	Last length:     8	Average length: 8.46
Episode 242600	Last length:     9	Average length: 8.36
Episode 242800	Last length:     9	Average length: 8.35
Episode 243000	Last length:    10	Average length: 8.46
Episode 243200	Last length:     8	Average length: 8.43
Episode 243400	Last length:     8	Average length: 8.32
Episode 243600	Last length:     8	Average length: 8.36
Episode 243800	Last length:     8	Average length: 8.45
Episode 244000	Last length:     7	Average length: 8.46
Episode 244200	Last length:     8	Average length: 8.34
Episode 24

Episode 270600	Last length:     8	Average length: 8.43
Episode 270800	Last length:     9	Average length: 8.37
Episode 271000	Last length:     8	Average length: 8.29
Episode 271200	Last length:     9	Average length: 8.34
Episode 271400	Last length:     8	Average length: 8.39
Episode 271600	Last length:     8	Average length: 8.34
Episode 271800	Last length:     9	Average length: 8.35
Episode 272000	Last length:     8	Average length: 8.24
Episode 272200	Last length:     8	Average length: 8.38
Episode 272400	Last length:     8	Average length: 8.31
Episode 272600	Last length:     9	Average length: 8.35
Episode 272800	Last length:     8	Average length: 8.36
Episode 273000	Last length:     8	Average length: 8.33
Episode 273200	Last length:     8	Average length: 8.37
Episode 273400	Last length:     8	Average length: 8.34
Episode 273600	Last length:     9	Average length: 8.38
Episode 273800	Last length:     8	Average length: 8.32
Episode 274000	Last length:     9	Average length: 8.37
Episode 27

Episode 300400	Last length:     9	Average length: 8.38
Episode 300600	Last length:     8	Average length: 8.33
Episode 300800	Last length:     8	Average length: 8.33
Episode 301000	Last length:     9	Average length: 8.41
Episode 301200	Last length:     8	Average length: 8.35
Episode 301400	Last length:     8	Average length: 8.30
Episode 301600	Last length:     9	Average length: 8.31
Episode 301800	Last length:     8	Average length: 8.33
Episode 302000	Last length:     9	Average length: 8.43
Episode 302200	Last length:     9	Average length: 8.35
Episode 302400	Last length:    10	Average length: 8.39
Episode 302600	Last length:     9	Average length: 8.37
Episode 302800	Last length:     8	Average length: 8.33
Episode 303000	Last length:     8	Average length: 8.32
Episode 303200	Last length:     9	Average length: 8.29
Episode 303400	Last length:     7	Average length: 8.33
Episode 303600	Last length:     9	Average length: 8.49
Episode 303800	Last length:     9	Average length: 8.38
Episode 30

Episode 330200	Last length:     9	Average length: 8.24
Episode 330400	Last length:     9	Average length: 8.28
Episode 330600	Last length:     8	Average length: 8.31
Episode 330800	Last length:     7	Average length: 8.46
Episode 331000	Last length:     8	Average length: 8.33
Episode 331200	Last length:     8	Average length: 8.35
Episode 331400	Last length:    10	Average length: 8.32
Episode 331600	Last length:     7	Average length: 8.31
Episode 331800	Last length:     9	Average length: 8.44
Episode 332000	Last length:     9	Average length: 8.37
Episode 332200	Last length:     9	Average length: 8.41
Episode 332400	Last length:     9	Average length: 8.32
Episode 332600	Last length:     9	Average length: 8.38
Episode 332800	Last length:     8	Average length: 8.33
Episode 333000	Last length:     8	Average length: 8.32
Episode 333200	Last length:     8	Average length: 8.36
Episode 333400	Last length:     9	Average length: 8.44
Episode 333600	Last length:     8	Average length: 8.35
Episode 33

KeyboardInterrupt: 