In [12]:
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical

import gym

from ignite.engine import Engine, Events
from torch.utils.tensorboard import SummaryWriter

In [2]:
class Policy(nn.Module):
    def __init__(self):
        super(Policy, self).__init__()
        self.affine1 = nn.Linear(4, 128)
        self.affine2 = nn.Linear(128, 2)

        self.saved_log_probs = []
        self.rewards = []

    def forward(self, x):
        x = F.relu(self.affine1(x))
        action_scores = self.affine2(x)
        return F.softmax(action_scores, dim=1)

In [5]:
seed = 543
gamma = 0.99
log_interval = 100

env = gym.make("CartPole-v0")
env.seed(seed)
torch.manual_seed(seed)

<torch._C.Generator at 0x283d84540b0>

In [None]:
model = Policy()
optimizer = optim.Adam(model.parameters(), lr=1e-2)
eps = np.finfo(np.float32).eps.item()
timesteps = list(range(10000))
writer = SummaryWriter('logs/reinforce')


def select_action(model, observation):
    state = torch.from_numpy(observation).float().unsqueeze(0)
    probs = model(state)
    m = Categorical(probs)
    action = m.sample()
    model.saved_log_probs.append(m.log_prob(action))
    return action.item()


def finish_episode(model, optimizer, gamma, eps):
    R = 0
    policy_loss = []
    rewards = []
    for r in model.rewards[::-1]:
        R = r + gamma * R
        rewards.insert(0, R)
    rewards = torch.tensor(rewards)
    rewards = (rewards - rewards.mean()) / (rewards.std() + eps)
    for log_prob, reward in zip(model.saved_log_probs, rewards):
        policy_loss.append(-log_prob * reward)
    optimizer.zero_grad()
    policy_loss = torch.cat(policy_loss).sum()
    policy_loss.backward()
    optimizer.step()
    del model.rewards[:]
    del model.saved_log_probs[:]

def run_single_timestep(engine, timestep):
    observation = engine.state.observation
    action = select_action(model, observation)
    engine.state.observation, reward, done, _ = env.step(action)
#     if args.render:
#         env.render()
    model.rewards.append(reward)

    if done:
        engine.terminate_epoch()
        engine.state.timestep = timestep

trainer = Engine(run_single_timestep)

In [14]:
EPISODE_STARTED = Events.EPOCH_STARTED
EPISODE_COMPLETED = Events.EPOCH_COMPLETED

@trainer.on(Events.STARTED)
def initialize(engine):
    engine.state.running_reward = 10

@trainer.on(EPISODE_STARTED)
def reset_environment_state(engine):
    engine.state.observation = env.reset()

@trainer.on(EPISODE_COMPLETED)
def update_model(engine):
    t = engine.state.timestep
    engine.state.running_reward = engine.state.running_reward * 0.99 + t * 0.01
    finish_episode(model, optimizer, gamma, eps)

@trainer.on(EPISODE_COMPLETED(every=log_interval))
def log_episode(engine):
    i_episode = engine.state.epoch
    print(
        "Episode {}\tLast length: {:5d}\tAverage length: {:.2f}".format(
            i_episode, engine.state.timestep, engine.state.running_reward
        )
    )
    
@trainer.on(EPISODE_COMPLETED(every=10))
def log_episode_to_tensorboard(engine):
    i_episode = engine.state.epoch
    writer.add_scalar('running reward', engine.state.running_reward, i_episode)

@trainer.on(EPISODE_COMPLETED)
def should_finish_training(engine):
    running_reward = engine.state.running_reward
    if running_reward > env.spec.reward_threshold:
        print(
            "Solved! Running reward is now {} and "
            "the last episode runs to {} time steps!".format(running_reward, engine.state.timestep)
        )
        engine.should_terminate = True

In [15]:
trainer.run(timesteps, max_epochs=10000)

Episode 100	Last length:    75	Average length: 47.54
Episode 200	Last length:    31	Average length: 95.45
Episode 300	Last length:   199	Average length: 144.25
Episode 400	Last length:   199	Average length: 178.96
Episode 500	Last length:   199	Average length: 182.50
Episode 600	Last length:   199	Average length: 191.39
Solved! Running reward is now 195.0011948966364 and the last episode runs to 199 time steps!


State:
	iteration: 106356
	epoch: 664
	epoch_length: 10000
	max_epochs: 10000
	output: <class 'NoneType'>
	batch: 199
	metrics: <class 'dict'>
	dataloader: <class 'list'>
	seed: <class 'NoneType'>
	times: <class 'dict'>
	running_reward: 195.0011948966364
	observation: <class 'numpy.ndarray'>
	timestep: 199