<a href="https://colab.research.google.com/github/asceznyk/notebooks/blob/main/ppo_cartpole.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import time
import numpy as np
import scipy.signal as signal

import gym

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.distributions.categorical import Categorical


In [None]:
env = gym.make("CartPole-v0")
# Length of the observation vector and size of the discrete action set;
# these size the networks and the rollout buffer.
dim_obs, n_actions = env.observation_space.shape[0], env.action_space.n

In [None]:
## networks

def mlp(layers, out_dim, activation=nn.Tanh):
    """Build a feed-forward network from a list of hidden-layer shapes.

    Args:
        layers: sequence of ``(in_features, out_features)`` pairs, one per
            hidden layer. Each hidden ``nn.Linear`` is followed by a fresh
            ``activation()`` module.
        out_dim: width of the final (un-activated) linear output layer.
        activation: zero-argument callable returning an ``nn.Module``,
            instantiated once per hidden layer (default ``nn.Tanh``).

    Returns:
        ``nn.Sequential`` laid out as
        ``[Linear, act, Linear, act, ..., Linear(last_hidden, out_dim)]``.

    Raises:
        ValueError: if ``layers`` is empty; the output layer's input width
            is taken from the last hidden layer, so at least one is needed.
    """
    layers = list(layers)
    if not layers:
        raise ValueError(
            "mlp() needs at least one hidden layer to infer the output "
            "layer's input size"
        )
    modules = []
    for u_in, u_out in layers:
        modules += [nn.Linear(u_in, u_out), activation()]
    # Output layer takes the out_features of the last hidden layer.
    modules.append(nn.Linear(layers[-1][1], out_dim))
    return nn.Sequential(*modules)

class Actor(nn.Module):
    """Policy network: maps observations to unnormalized action logits.

    Args:
        n_actions: number of discrete actions (width of the output layer).
        hid_layers: hidden-layer ``(in_features, out_features)`` pairs,
            forwarded to ``mlp``.
    """

    def __init__(self, n_actions, hid_layers):
        super().__init__()
        self.net = mlp(hid_layers, n_actions)

    def forward(self, x):
        """Return logits whose last dimension has size ``n_actions``."""
        logits = self.net(x)
        return logits

class Critic(nn.Module):
    """State-value network V(s).

    ``forward`` drops the trailing size-1 output dimension, so a batch of N
    observations yields values of shape (N,) and a single observation yields
    a 0-d tensor. This keeps the output aligned with 1-D return targets:
    an (N, 1) output minus an (N,) target would silently broadcast to an
    (N, N) matrix inside an MSE loss.

    Args:
        hid_layers: hidden-layer ``(in_features, out_features)`` pairs,
            forwarded to ``mlp``.
    """

    def __init__(self, hid_layers):
        super().__init__()
        self.net = mlp(hid_layers, 1)

    def forward(self, x):
        return self.net(x).squeeze(-1)


In [None]:
## the buffer

def discount_cum(x, d):
    """Discounted suffix sums: ``out[t] = sum_k d**k * x[t + k]`` along axis 0.

    Implemented as a first-order IIR filter run over the reversed sequence,
    then reversed back.
    """
    backwards = x[::-1]
    filtered = signal.lfilter([1], [1, float(-d)], backwards, axis=0)
    return filtered[::-1]

class Buffer:
    """Fixed-capacity storage for one epoch of on-policy experience.

    ``store`` appends transitions. ``terminate`` closes the current
    trajectory segment and fills in its GAE-Lambda advantages and discounted
    returns. ``get`` returns every array (with advantages normalized) and
    rewinds the buffer for the next epoch.

    Args:
        dim_obs: length of each observation vector.
        size: number of transitions the buffer holds.
        gamma: discount factor.
        lam: GAE lambda.
    """

    def __init__(self, dim_obs, size, gamma=0.99, lam=0.95):
        self.obs_buf = np.zeros((size, dim_obs), dtype=np.float32)
        self.act_buf = np.zeros(size, dtype=np.int64)
        self.adv_buf = np.zeros(size, dtype=np.float32)
        self.return_buf = np.zeros(size, dtype=np.float32)
        self.reward_buf = np.zeros(size, dtype=np.float32)
        self.value_buf = np.zeros(size, dtype=np.float32)
        self.logp_buf = np.zeros(size, dtype=np.float32)
        self.gamma, self.lam = gamma, lam
        # pointer: next write index.
        # t_start: first index of the still-open trajectory segment.
        self.pointer, self.t_start = 0, 0

    def terminate(self, last_val=0):
        """Finish the trajectory segment ``[t_start, pointer)``.

        Args:
            last_val: value used to bootstrap past the segment's last step.
                Pass 0 for a terminal state, otherwise a value estimate for
                the next state. May be a scalar or a size-1 array, since
                ``np.append`` flattens it.
        """
        t_slice = slice(self.t_start, self.pointer)
        rewards = np.append(self.reward_buf[t_slice], last_val)
        values = np.append(self.value_buf[t_slice], last_val)
        # One-step TD residuals; their (gamma*lam)-discounted suffix sums
        # are the GAE-Lambda advantages.
        delta = rewards[:-1] + self.gamma * values[1:] - values[:-1]
        self.adv_buf[t_slice] = discount_cum(delta, self.gamma * self.lam)
        # Rewards-to-go bootstrapped with last_val; drop the appended entry.
        self.return_buf[t_slice] = discount_cum(rewards, self.gamma)[:-1]
        self.t_start = self.pointer

    def store(self, obs, act, reward, logp, value):
        """Append one transition (NumPy raises IndexError once full)."""
        self.obs_buf[self.pointer] = obs
        self.act_buf[self.pointer] = act
        self.reward_buf[self.pointer] = reward
        self.logp_buf[self.pointer] = logp
        self.value_buf[self.pointer] = value
        self.pointer += 1

    def get(self):
        """Return ``(obs, act, adv, return, logp)`` arrays and reset the buffer.

        Advantages are shifted to zero mean and scaled to unit standard
        deviation. A small epsilon keeps the division finite when all
        advantages are equal.

        Raises:
            RuntimeError: if the buffer is not completely filled; the
                unfilled tail would otherwise contribute stale data.
        """
        capacity = len(self.act_buf)
        if self.pointer != capacity:
            raise RuntimeError(
                f"Buffer.get() called with {self.pointer}/{capacity} "
                "transitions stored; fill the buffer completely first"
            )
        self.pointer, self.t_start = 0, 0
        adv_mean, adv_std = np.mean(self.adv_buf), np.std(self.adv_buf)
        self.adv_buf = (self.adv_buf - adv_mean) / (adv_std + 1e-8)
        return (
            self.obs_buf,
            self.act_buf,
            self.adv_buf,
            self.return_buf,
            self.logp_buf,
        )


In [None]:
## helper

def to_torch(*args):
    """Lazily convert each NumPy array in ``args`` to a tensor.

    ``torch.from_numpy`` shares memory with the source array.
    """
    return (torch.from_numpy(arr) for arr in args)

def to_numpy(*args):
    """Lazily detach each tensor in ``args`` and convert it to a NumPy array."""
    return (tensor.detach().numpy() for tensor in args)

def sample_action(actor, obs):
    """Run ``actor`` on ``obs`` and draw an action from the categorical policy.

    Returns:
        ``(logits, action)``: the raw actor output and the sampled action
        index tensor.
    """
    logits = actor(obs)
    probs = F.softmax(logits, dim=-1)
    action = Categorical(probs).sample()
    return logits, action

def log_p(logits, act_buf):
    """Log-probability of the chosen actions under a categorical policy.

    The number of actions is read from ``logits.shape[-1]`` rather than a
    module-level global, so this works for any action-space size.

    Args:
        logits: unnormalized action scores, actions along the last dimension.
        act_buf: integer (int64) action indices whose shape equals
            ``logits.shape[:-1]``; a 0-d index for a single 1-D logits vector.

    Returns:
        Tensor of shape ``logits.shape[:-1]`` holding
        ``log_softmax(logits)[..., act]``; differentiable w.r.t. ``logits``.
    """
    log_probs = F.log_softmax(logits, dim=-1)
    return log_probs.gather(-1, act_buf.unsqueeze(-1)).squeeze(-1)

In [None]:
## hyperparameters of the PPO algorithm

# Training schedule: number of epochs and environment steps per epoch.
epochs = 30
steps_per_epoch = 4000

# Return / advantage estimation: discount factor and GAE lambda.
gamma = 0.99
lam = 0.97

# Policy updates: clip range of the probability ratio, learning rate,
# maximum gradient steps per epoch, and the KL threshold for early stopping.
eps = 0.2
lr_policy = 3e-4
policy_iterations = 80
target_kl = 0.01

# Value-function regression: learning rate and gradient steps per epoch.
lr_value = 1e-3
value_iterations = 80

# Hidden layers as (in_features, out_features) pairs.
hidden_sizes = [(dim_obs, 64), (64, 64)]

# NOTE(review): not read by the training loop in this notebook.
render = False


In [None]:
## inits

actor = Actor(n_actions, hidden_sizes)
critic = Critic(hidden_sizes)

# Pass the discount / GAE hyperparameters explicitly; otherwise Buffer falls
# back to its own defaults and the configured `lam` is silently ignored.
buffer = Buffer(dim_obs, steps_per_epoch, gamma=gamma, lam=lam)
policy_opt = torch.optim.Adam(actor.parameters(), lr=lr_policy)
value_opt = torch.optim.Adam(critic.parameters(), lr=lr_value)

obs, ep_ret, ep_len = env.reset(), 0, 0

In [None]:
## typical training
for e in range(epochs):
    sum_return = 0
    sum_len = 0
    num_ep = 0  # episodes that actually ended (done=True) this epoch

    actor.eval()
    critic.eval()
    # Rollout phase: nothing here is backpropagated, so skip building
    # autograd graphs for every environment step.
    with torch.no_grad():
        for t in range(steps_per_epoch):
            # Force float32 to match the network weights regardless of the
            # dtype the environment returns.
            obs_t = torch.as_tensor(obs, dtype=torch.float32)
            logit, action = sample_action(actor, obs_t)
            obs_, reward, done, _ = env.step(action.item())
            ep_ret += reward
            ep_len += 1

            buffer.store(
                obs,
                action.item(),
                reward,
                log_p(logit, action).item(),
                critic(obs_t).item(),
            )
            obs = obs_

            epoch_ended = t == steps_per_epoch - 1
            if done or epoch_ended:
                # Bootstrap with V(s') only when the trajectory was cut off
                # by the epoch boundary rather than ending on its own.
                last_val = 0 if done else critic(
                    torch.as_tensor(obs, dtype=torch.float32)
                ).item()
                buffer.terminate(last_val)
                # Only completed episodes enter the return/length statistics;
                # a truncated tail would bias the means downward.
                if done:
                    sum_return += ep_ret
                    sum_len += ep_len
                    num_ep += 1
                obs, ep_ret, ep_len = env.reset(), 0, 0

    obs_buf, act_buf, adv_buf, return_buf, logp_buf = to_torch(*buffer.get())

    actor.train()
    critic.train()

    for i in range(policy_iterations):
        # Clipped surrogate objective: the clipped term equals
        # clip(ratio, 1-eps, 1+eps) * adv written per advantage sign.
        ratio = (log_p(actor(obs_buf), act_buf) - logp_buf).exp()
        clipped = torch.where(adv_buf > 0, (1 + eps) * adv_buf, (1 - eps) * adv_buf)
        policy_loss = -torch.minimum(ratio * adv_buf, clipped).mean()

        policy_opt.zero_grad()
        policy_loss.backward()
        policy_opt.step()

        # Approximate KL between the rollout policy and the updated policy;
        # stop updating once the policy has moved too far.
        with torch.no_grad():
            kl = (logp_buf - log_p(actor(obs_buf), act_buf)).mean().item()
        if kl > 1.5 * target_kl:
            break

    for j in range(value_iterations):
        # Flatten the critic output: an (N, 1) prediction minus (N,) returns
        # would broadcast to an (N, N) matrix and corrupt the MSE.
        value_loss = ((critic(obs_buf).reshape(-1) - return_buf) ** 2).mean()

        value_opt.zero_grad()
        value_loss.backward()
        value_opt.step()

    print(f"Total episodes: {num_ep}")
    if num_ep:
        print(
            f"Epoch {e+1}: Mean return: {sum_return / num_ep}, Mean length: {sum_len / num_ep}"
        )
    else:
        print(f"Epoch {e+1}: no episode finished during this epoch")