In [31]:
import gym
import ptan
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from typing import Optional

import torch
import torch.nn as nn 
import torch.nn.functional as F 
import torch.optim as optim 

In [32]:
# --- Hyperparameters -------------------------------------------------------
# Optimisation
LEARNING_RATE = 1e-3   # step size handed to the Adam optimizer
BATCH_SIZE = 8         # transitions accumulated before each gradient step
ENTROPY_BETA = 1e-2    # weight of the entropy term added to the policy loss

# Return estimation (both are forwarded to the ptan experience source)
GAMMA = 0.99           # discount factor
REWARD_STEPS = 10      # steps_count: how many steps each experience item spans

In [33]:
class PGN(nn.Module):
    """Policy-gradient network: a two-layer MLP that maps observations to logits.

    The output is the raw result of the final linear layer (no softmax is
    applied here), one value per action.

    Args:
        input_size: number of features in a single observation.
        n_actions: number of discrete actions, i.e. the width of the output.
    """

    def __init__(self, input_size, n_actions):
        super().__init__()

        # Layers are created in forward order so that parameter
        # initialisation order and the Sequential indices (0, 1, 2) are fixed.
        layers = [
            nn.Linear(input_size, 128),
            nn.ReLU(),
            nn.Linear(128, n_actions),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return the un-normalised action scores for the batch ``x``."""
        logits = self.net(x)
        return logits

In [34]:
def smooth(old: Optional[float], val: float, alpha: float = 0.95) -> float:
    """Exponentially smooth a scalar series.

    Args:
        old: previous smoothed value, or ``None`` when there is no history yet.
        val: the newest raw observation.
        alpha: weight kept on the previous smoothed value; ``1 - alpha`` goes
            to the new observation.

    Returns:
        ``val`` unchanged when ``old`` is ``None``; otherwise the blend
        ``alpha * old + (1 - alpha) * val``.
    """
    if old is not None:
        return alpha * old + (1 - alpha) * val
    # First sample: nothing to blend with, so the series starts at val.
    return val

In [35]:
# Environment to train on, and a TensorBoard writer for the training metrics.
env = gym.make("CartPole-v0")
writer = SummaryWriter(comment="-cartpole-pg")

# Size the policy network from the environment itself: one input per
# observation component, one output per discrete action.
net = PGN(
    input_size=env.observation_space.shape[0],
    n_actions=env.action_space.n,
)
print(net)

PGN(
  (net): Sequential(
    (0): Linear(in_features=4, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=2, bias=True)
  )
)


In [36]:
# Agent that picks actions using `net`. The network returns raw logits, so the
# agent is asked to apply softmax itself (apply_softmax=True); observations go
# through ptan's float32 preprocessor before reaching the network.
agent = ptan.agent.PolicyAgent(
    net,
    preprocessor=ptan.agent.float32_preprocessor,
    apply_softmax=True,
)

# Experience stream consumed by the training loop; the discount factor and
# step count come from the hyperparameter cell.
exp_source = ptan.experience.ExperienceSourceFirstLast(
    env,
    agent,
    gamma=GAMMA,
    steps_count=REWARD_STEPS,
)

optimizer = optim.Adam(params=net.parameters(), lr=LEARNING_RATE)

In [37]:
# --- Mutable training state (updated in place by the training loop) --------

# Counters and running sums.
step_idx = 0
done_episodes = 0
reward_sum = 0.0

# Per-episode total rewards; the loop averages the last 100 of them.
total_rewards = []
# NOTE(review): step_rewards is not referenced by the cells shown here.
step_rewards = []

# Exponentially smoothed metrics; None means "no sample seen yet".
bs_smoothed = None
entropy = None
l_entropy = None
l_policy = None
l_total = None

# Accumulators for the batch currently being collected.
batch_states = []
batch_actions = []
batch_scales = []

In [38]:
# Policy-gradient training loop.
#
# Each item from `exp_source` contributes one (state, action, scale) sample,
# where scale = reward - baseline and the baseline is the running mean of all
# rewards seen so far. Every BATCH_SIZE samples the network takes one Adam
# step on  policy_loss + entropy_loss  and diagnostics go to TensorBoard.
# Training stops once the mean total reward of the last 100 episodes
# exceeds 195.
#
# The writer is closed in `finally` so that event files are flushed even when
# the cell is interrupted or raises.
try:
    for step_idx, exp in enumerate(exp_source):
        reward_sum += exp.reward
        # Running mean of every reward observed so far, used as the baseline.
        baseline = reward_sum / (step_idx + 1)
        writer.add_scalar("baseline", baseline, step_idx)
        batch_states.append(exp.state)
        batch_actions.append(int(exp.action))
        batch_scales.append(exp.reward - baseline)

        # Handle episodes that finished since the previous step.
        new_rewards = exp_source.pop_total_rewards()
        if new_rewards:
            done_episodes += 1
            reward = new_rewards[0]
            total_rewards.append(reward)
            mean_rewards = float(np.mean(total_rewards[-100:]))
            print("%d: reward: %6.2f, mean_100: %6.2f, episodes: %d"
                  % (step_idx, reward, mean_rewards, done_episodes))
            writer.add_scalar("reward", reward, step_idx)
            writer.add_scalar("reward_100", mean_rewards, step_idx)
            writer.add_scalar("episodes", done_episodes, step_idx)
            if mean_rewards > 195:
                print("Solved in %d steps and %d episodes!"
                      % (step_idx, done_episodes))
                break

        # Keep collecting until a full batch is available.
        if len(batch_states) < BATCH_SIZE:
            continue

        batch_len = len(batch_states)
        # Stack into one ndarray first: building a tensor straight from a
        # list of ndarrays goes through a slow element-by-element path.
        states_v = torch.as_tensor(np.asarray(batch_states), dtype=torch.float32)
        batch_actions_t = torch.as_tensor(batch_actions, dtype=torch.long)
        batch_scale_v = torch.as_tensor(batch_scales, dtype=torch.float32)

        optimizer.zero_grad()
        logits_v = net(states_v)
        log_prob_v = F.log_softmax(logits_v, dim=1)
        # Log-probability of the action actually taken, weighted by its scale.
        # Rows are indexed by the real batch length, not the constant.
        log_prob_actions_v = batch_scale_v * log_prob_v[range(batch_len), batch_actions_t]
        loss_policy_v = -log_prob_actions_v.mean()

        # Entropy bonus: the negative sign on the loss rewards higher entropy.
        prob_v = F.softmax(logits_v, dim=1)
        entropy_v = -(prob_v * log_prob_v).sum(dim=1).mean()
        entropy_loss_v = -ENTROPY_BETA * entropy_v
        loss_v = loss_policy_v + entropy_loss_v

        loss_v.backward()
        optimizer.step()

        # KL divergence between the policy before and after the update.
        # It is a diagnostic only, so no autograd graph is built for it.
        with torch.no_grad():
            new_logits_v = net(states_v)
            new_prob_v = F.softmax(new_logits_v, dim=1)
            kl_div_v = -((new_prob_v / prob_v).log() * prob_v).sum(dim=1).mean()
        writer.add_scalar("kl", kl_div_v.item(), step_idx)

        # Gradient statistics: largest absolute entry, and the per-parameter
        # RMS gradient averaged over all parameter tensors.
        grad_max = 0.0
        grad_means = 0.0
        grad_count = 0
        for p in net.parameters():
            grad_max = max(grad_max, p.grad.abs().max().item())
            grad_means += (p.grad ** 2).mean().sqrt().item()
            grad_count += 1

        bs_smoothed = smooth(bs_smoothed, float(np.mean(batch_scales)))
        entropy = smooth(entropy, entropy_v.item())
        l_entropy = smooth(l_entropy, entropy_loss_v.item())
        l_policy = smooth(l_policy, loss_policy_v.item())
        l_total = smooth(l_total, loss_v.item())

        # "baseline" was already logged for this step at the top of the loop,
        # so it is not written a second time here.
        writer.add_scalar("entropy", entropy, step_idx)
        writer.add_scalar("loss_entropy", l_entropy, step_idx)
        writer.add_scalar("loss_policy", l_policy, step_idx)
        writer.add_scalar("loss_total", l_total, step_idx)
        # Tag corrected from the typo "grad_12" to "grad_l2".
        writer.add_scalar("grad_l2", grad_means / grad_count, step_idx)
        writer.add_scalar("grad_max", grad_max, step_idx)
        writer.add_scalar("batch_scales", bs_smoothed, step_idx)

        batch_states.clear()
        batch_actions.clear()
        batch_scales.clear()
finally:
    writer.close()

22: reward:  22.00, mean_100:  22.00, episodes: 1
34: reward:  12.00, mean_100:  17.00, episodes: 2
53: reward:  19.00, mean_100:  17.67, episodes: 3
69: reward:  16.00, mean_100:  17.25, episodes: 4
84: reward:  15.00, mean_100:  16.80, episodes: 5
101: reward:  17.00, mean_100:  16.83, episodes: 6
110: reward:   9.00, mean_100:  15.71, episodes: 7
121: reward:  11.00, mean_100:  15.12, episodes: 8
132: reward:  11.00, mean_100:  14.67, episodes: 9
188: reward:  56.00, mean_100:  18.80, episodes: 10
239: reward:  51.00, mean_100:  21.73, episodes: 11
297: reward:  58.00, mean_100:  24.75, episodes: 12
326: reward:  29.00, mean_100:  25.08, episodes: 13
356: reward:  30.00, mean_100:  25.43, episodes: 14
377: reward:  21.00, mean_100:  25.13, episodes: 15
435: reward:  58.00, mean_100:  27.19, episodes: 16
462: reward:  27.00, mean_100:  27.18, episodes: 17
480: reward:  18.00, mean_100:  26.67, episodes: 18
501: reward:  21.00, mean_100:  26.37, episodes: 19
538: reward:  37.00, mean_