In [9]:
import gym
import ptan
import numpy as np
from torch.utils.tensorboard import SummaryWriter

import torch
import torch.nn as nn 
import torch.nn.functional as F 
import torch.optim as optim

In [10]:
# Hyperparameters for REINFORCE-with-baseline on CartPole.
GAMMA: float = 0.99            # discount factor (used by calc_qvals and the experience source)
LEARNING_RATE: float = 1e-2    # Adam step size
EPISODES_TO_TRAIN: int = 4     # complete episodes gathered before each gradient update

In [11]:
class PGN(nn.Module):
    """Policy-gradient network: a one-hidden-layer MLP from observations to action logits.

    The output is raw (unnormalised) logits; no softmax is applied here, so
    callers apply softmax / log_softmax themselves.

    Args:
        input_size: number of features in one observation vector.
        n_actions: number of discrete actions (size of the output layer).
        hidden_size: width of the hidden layer (default 128, the original size).
    """

    def __init__(self, input_size, n_actions, hidden_size=128):
        super().__init__()

        self.net = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, n_actions),
        )

    def forward(self, x):
        """Return action logits for a batch of observations ``x``."""
        return self.net(x)

In [12]:
def calc_qvals(rewards, gamma=None):
    """Discounted returns for one episode, centred on their mean (a simple baseline).

    Walks the rewards backwards accumulating G_t = r_t + gamma * G_{t+1}, then
    subtracts the mean of those returns from every element. The policy gradient
    is therefore weighted by an advantage-like quantity rather than the raw return.

    Args:
        rewards: sequence of per-step rewards of a single episode, in time order.
        gamma: discount factor; when None, the module-level ``GAMMA`` is used.

    Returns:
        List of baseline-subtracted returns, same length and order as ``rewards``.
        An empty ``rewards`` yields an empty list.
    """
    if gamma is None:
        gamma = GAMMA
    if len(rewards) == 0:
        # np.mean([]) would warn and produce nan; there is nothing to return anyway.
        return []

    res = []
    sum_r = 0.0
    for r in reversed(rewards):
        sum_r = sum_r * gamma + r
        res.append(sum_r)
    res.reverse()

    mean_q = np.mean(res)
    return [q - mean_q for q in res]

In [13]:
# Environment, TensorBoard writer, and a policy network sized to the env's spaces.
env = gym.make("CartPole-v0")
writer = SummaryWriter(comment="-cartpole-reinforce-baseline")

obs_size = env.observation_space.shape[0]
n_actions = env.action_space.n
net = PGN(obs_size, n_actions)
print(net)

PGN(
  (net): Sequential(
    (0): Linear(in_features=4, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=2, bias=True)
  )
)


In [14]:
# PGN returns raw logits, so the agent is asked to apply the softmax itself.
agent = ptan.agent.PolicyAgent(
    net,
    preprocessor=ptan.agent.float32_preprocessor,
    apply_softmax=True,
)
# The experience source drives `env` with `agent`; the training loop reads
# .state / .action / .reward / .last_state from each item it yields.
exp_source = ptan.experience.ExperienceSourceFirstLast(env, agent, gamma=GAMMA)
optimizer = optim.Adam(params=net.parameters(), lr=LEARNING_RATE)


In [15]:
# Bookkeeping shared with the training loop.
total_rewards = []          # total reward of each finished episode (from pop_total_rewards)
step_idx = 0
done_episodes = 0

batch_episodes = 0          # complete episodes accumulated since the last update
batch_states = []
batch_actions = []
batch_qvals = []

cur_states = []             # transitions of the episode currently in progress
cur_actions = []
cur_rewards = []

In [16]:
# Training loop: play episodes with the current policy. After every
# EPISODES_TO_TRAIN complete episodes, take one REINFORCE gradient step, using
# the mean-baselined discounted returns from calc_qvals as per-step weights.
SOLVED_MEAN_REWARD = 195   # stop once the mean over the last 100 episodes exceeds this

try:
    for step_idx, exp in enumerate(exp_source):
        cur_states.append(exp.state)
        cur_actions.append(int(exp.action))
        cur_rewards.append(exp.reward)

        # last_state is None marks the final transition of an episode: move the
        # finished episode (with its baselined returns) into the training batch.
        if exp.last_state is None:
            batch_states.extend(cur_states)
            batch_actions.extend(cur_actions)
            batch_qvals.extend(calc_qvals(cur_rewards))
            cur_states.clear()
            cur_actions.clear()
            cur_rewards.clear()
            batch_episodes += 1

        # Handle every episode reward popped this step (not only the first),
        # so none is silently dropped from the statistics.
        new_rewards = exp_source.pop_total_rewards()
        solved = False
        for reward in new_rewards:
            done_episodes += 1
            total_rewards.append(reward)
            mean_reward = float(np.mean(total_rewards[-100:]))  # window matches "mean_100"
            print("%d: reward: %6.2f, mean_100: %6.2f, episodes: %d"
                  % (step_idx, reward, mean_reward, done_episodes))
            writer.add_scalar("reward", reward, step_idx)
            writer.add_scalar("reward_100", mean_reward, step_idx)
            writer.add_scalar("episodes", done_episodes, step_idx)
            solved = solved or mean_reward > SOLVED_MEAN_REWARD
        if solved:
            print("Solved in %d steps and %d episodes!"
                  % (step_idx, done_episodes))
            break

        if batch_episodes < EPISODES_TO_TRAIN:
            continue

        # Stack through numpy first: building a tensor directly from a list of
        # ndarrays is PyTorch's slow path (and it warns about it).
        states_v = torch.as_tensor(np.asarray(batch_states), dtype=torch.float32)
        batch_actions_t = torch.as_tensor(batch_actions, dtype=torch.long)
        batch_qvals_v = torch.as_tensor(np.asarray(batch_qvals), dtype=torch.float32)

        optimizer.zero_grad()
        logits_v = net(states_v)
        log_prob_v = F.log_softmax(logits_v, dim=1)
        # Log-probability of the action actually taken at each step, weighted by its return.
        log_prob_actions_v = batch_qvals_v * log_prob_v[torch.arange(len(batch_states)), batch_actions_t]
        loss_v = -log_prob_actions_v.mean()

        loss_v.backward()
        optimizer.step()

        batch_episodes = 0
        batch_states.clear()
        batch_actions.clear()
        batch_qvals.clear()
finally:
    # Close the writer even if the cell is interrupted or raises, so buffered
    # TensorBoard events are flushed.
    writer.close()

28: reward:  28.00, mean_100:  28.00, episodes: 1
37: reward:   9.00, mean_100:  18.50, episodes: 2
46: reward:   9.00, mean_100:  15.33, episodes: 3
58: reward:  12.00, mean_100:  14.50, episodes: 4
70: reward:  12.00, mean_100:  14.00, episodes: 5
111: reward:  41.00, mean_100:  18.50, episodes: 6
128: reward:  17.00, mean_100:  18.29, episodes: 7
147: reward:  19.00, mean_100:  18.38, episodes: 8
160: reward:  13.00, mean_100:  17.78, episodes: 9
177: reward:  17.00, mean_100:  17.70, episodes: 10
221: reward:  44.00, mean_100:  20.09, episodes: 11
256: reward:  35.00, mean_100:  21.33, episodes: 12
288: reward:  32.00, mean_100:  22.15, episodes: 13
356: reward:  68.00, mean_100:  25.43, episodes: 14
368: reward:  12.00, mean_100:  24.53, episodes: 15
450: reward:  82.00, mean_100:  28.12, episodes: 16
483: reward:  33.00, mean_100:  28.41, episodes: 17
502: reward:  19.00, mean_100:  27.89, episodes: 18
594: reward:  92.00, mean_100:  31.26, episodes: 19
622: reward:  28.00, mean_