<a href="https://colab.research.google.com/github/asceznyk/notebooks/blob/main/ppo_cartpole.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import time
import numpy as np
import scipy.signal as signal

import gym

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.distributions.categorical import Categorical


In [None]:
# CartPole environment used (as the global ``env``) by every training run,
# plus the network I/O sizes derived from it.
env = gym.make("CartPole-v1")
n_actions = env.action_space.n  # number of discrete actions (actor output width)
dim_obs = env.observation_space.shape[0]  # observation vector length (network input width)

In [None]:
## networks

def mlp(layers, out_dim):
    """Build a LeakyReLU MLP that ends in a linear projection to ``out_dim``.

    Args:
        layers: iterable of ``(in_features, out_features)`` pairs, one per hidden
            ``nn.Linear``; each hidden layer is followed by a ``nn.LeakyReLU``.
            Must be non-empty: the last pair's ``out_features`` is the input
            width of the output layer.
        out_dim: width of the final (un-activated) ``nn.Linear``.

    Returns:
        ``nn.Sequential`` containing the hidden layers and the output layer.

    Raises:
        ValueError: if ``layers`` is empty.
    """
    layers = list(layers)  # materialise so generators work and layers[-1] is available
    if not layers:
        raise ValueError("mlp() needs at least one (in_features, out_features) hidden layer")

    modules = []
    for u_in, u_out in layers:
        modules.append(nn.Linear(u_in, u_out))
        modules.append(nn.LeakyReLU())
    # The output layer takes the width of the last hidden layer.
    modules.append(nn.Linear(layers[-1][1], out_dim))
    return nn.Sequential(*modules)

class Actor(nn.Module):
    """Policy network: maps observations to unnormalised action logits."""

    def __init__(self, n_actions, hid_layers):
        super().__init__()
        # Hidden stack from ``hid_layers``, projected to one logit per action.
        self.net = mlp(hid_layers, n_actions)

    def forward(self, x):
        """Return the logits ``self.net`` produces for ``x``."""
        logits = self.net(x)
        return logits

class Critic(nn.Module):
    """State-value network: a single output per observation."""

    def __init__(self, hid_layers):
        super().__init__()
        # Hidden stack from ``hid_layers``, projected to a width-1 output.
        self.net = mlp(hid_layers, 1)

    def forward(self, x):
        """Return ``self.net(x)``; its trailing dimension has size 1."""
        values = self.net(x)
        return values


In [None]:
## the buffer

def discount_cum(x, d):
    """Discounted cumulative sum along axis 0: ``out[t] = sum_k d**k * x[t + k]``.

    Runs the first-order IIR filter ``y[n] = x[n] + d * y[n - 1]`` over the
    reversed input, then reverses the result back into original order.
    """
    backwards = x[::-1]
    filtered = signal.lfilter([1], [1, float(-d)], backwards, axis=0)
    return filtered[::-1]

class Buffer:
    """Rollout storage for on-policy training with GAE-lambda advantages.

    Call ``store`` once per environment step, ``terminate`` at the end of each
    trajectory (true episode end or cut-off), then ``get`` to retrieve and
    clear everything collected since the previous ``get``.
    """

    def __init__(self, dim_obs, gamma=0.99, lam=0.95):
        # dim_obs is accepted for interface compatibility; storage is list-based,
        # so it is not needed.
        self.gamma, self.lam = gamma, lam
        self._clear()

    def _clear(self):
        """Drop all stored data and reset trajectory bookkeeping."""
        # Per-step data, appended by ``store``.
        self.obs_buf = []
        self.act_buf = []
        self.reward_buf = []
        self.value_buf = []
        self.logp_buf = []

        # One array per finished trajectory, appended by ``terminate``.
        self.adv_buf = []
        self.return_buf = []

        # pointer: number of stored steps; t_start: index where the open trajectory began.
        self.pointer, self.t_start = 0, 0

    def terminate(self, last_val=0):
        """Close the current trajectory and compute its advantages and returns.

        Only the steps stored since the previous ``terminate`` are used, so
        GAE never crosses a trajectory boundary.

        Args:
            last_val: value estimate for the state following the last stored
                step: 0 for a true terminal state, or the critic's estimate when
                the trajectory was cut off. Both the GAE residuals and the
                rewards-to-go are bootstrapped from it.
        """
        path = slice(self.t_start, self.pointer)
        self.t_start = self.pointer
        if path.start == path.stop:
            return  # nothing stored since the last terminate

        rewards = np.array(self.reward_buf[path])
        # np.append flattens, so per-step values of shape (1,) become a flat vector.
        values = np.append(self.value_buf[path], last_val)

        # GAE-lambda: discounted sum of one-step TD residuals.
        delta = rewards + self.gamma * values[1:] - values[:-1]
        self.adv_buf.append(discount_cum(delta, self.gamma * self.lam))

        # Rewards-to-go bootstrapped with last_val, which is dropped again afterwards.
        self.return_buf.append(
            discount_cum(np.append(rewards, last_val), self.gamma)[:-1]
        )

    def store(self, obs, act, reward, logp, value):
        """Record one environment step."""
        self.obs_buf.append(obs)
        self.act_buf.append(act)
        self.reward_buf.append(reward)
        self.logp_buf.append(logp)
        self.value_buf.append(value)
        self.pointer += 1

    def get(self):
        """Return ``(obs, act, adv, ret, logp)`` as NumPy arrays and clear the buffer.

        Advantages are shifted to zero mean and scaled to unit std over
        everything returned. The std scaling is skipped when the std is 0
        (e.g. a single-step trajectory) to avoid NaN. Call ``terminate`` for the
        last trajectory first so ``adv``/``ret`` cover every stored step.
        """
        adv_buf = np.concatenate(self.adv_buf) if self.adv_buf else np.array([])
        if adv_buf.size:
            adv_std = np.std(adv_buf)
            adv_buf = adv_buf - np.mean(adv_buf)
            if adv_std > 0:
                adv_buf = adv_buf / adv_std
        return_buf = np.concatenate(self.return_buf) if self.return_buf else np.array([])

        batch = (
            np.array(self.obs_buf),
            np.array(self.act_buf),
            adv_buf,
            return_buf,
            np.array(self.logp_buf),
        )
        self._clear()
        return batch


In [None]:
## helper

def to_torch(*args):
    """Lazily convert each NumPy array in ``args`` to a tensor sharing its memory."""
    return (torch.from_numpy(arr) for arr in args)

def to_numpy(*args):
    """Lazily detach each tensor in ``args`` from autograd and view it as a NumPy array."""
    return (tensor.detach().numpy() for tensor in args)

def sample_action(actor, obs):
    """Run ``actor`` on ``obs`` and draw an action from the softmax of its logits.

    Returns:
        ``(logits, action)``; ``action`` holds sampled integer class indices.
    """
    logits = actor(obs)
    probs = F.softmax(logits, dim=-1)
    action = Categorical(probs).sample()
    return logits, action

def log_p(logits, act_buf):
    """Log-probability of the taken actions under the categorical policy ``logits``.

    Args:
        logits: tensor whose last dimension indexes actions, e.g. ``(A,)`` for
            a single step or ``(T, A)`` for a batch.
        act_buf: int64 action indices whose shape is ``logits.shape[:-1]``.

    Returns:
        tensor of shape ``act_buf.shape`` holding ``log_softmax(logits)[..., a]``.
    """
    log_probs = F.log_softmax(logits, dim=-1)
    # gather picks the chosen action's entry directly. Unlike a one-hot product it
    # needs no global action count and does not turn -inf (masked) logits into NaN.
    return log_probs.gather(-1, act_buf.unsqueeze(-1)).squeeze(-1)

In [None]:
## hyperparameters of the PPO algorithm

# --- rollout length ---
max_ep = 500        # training episodes; the policy is updated after each one
max_steps = 1000    # step cap per episode before the trajectory is cut off

# --- objective ---
gamma = 0.99        # discount factor (Buffer's own default gamma is also 0.99)
eps = 0.2           # clip range: probability ratios limited to [1 - eps, 1 + eps]

# --- optimisation ---
lr_policy = 3e-4    # Adam learning rate for the actor
lr_value = 1e-3     # Adam learning rate for the critic
policy_iterations = 10  # max actor gradient steps per update
value_iterations = 10   # critic gradient steps per update
target_kl = 0.01    # ppo_policy stops early once approx. KL > 1.5 * target_kl

# --- architecture ---
hidden_sizes = [(dim_obs, 32), (32, 32)]  # (in, out) of each hidden Linear in actor and critic

render = False      # not read by train_agent


In [None]:
## inits

def init_agent():
    """Create a fresh actor/critic pair, a rollout buffer and their Adam optimizers.

    Returns:
        ``(actor, critic, buffer, policy_opt, value_opt)``.
    """
    actor = Actor(n_actions, hidden_sizes)
    critic = Critic(hidden_sizes)
    # Pass the configured discount so the ``gamma`` hyperparameter actually
    # reaches the GAE / rewards-to-go computation instead of Buffer's default.
    buffer = Buffer(dim_obs, gamma=gamma)
    policy_opt = torch.optim.Adam(actor.parameters(), lr=lr_policy)
    value_opt = torch.optim.Adam(critic.parameters(), lr=lr_value)
    return actor, critic, buffer, policy_opt, value_opt

In [None]:
## comparing different policy loss functions

def step(opt, loss):
    """Apply one optimisation step for ``loss``: clear gradients, backprop, update.

    Args:
        opt: a ``torch.optim`` optimizer over the parameters ``loss`` depends on.
        loss: scalar tensor to minimise.
    """
    opt.zero_grad()  # discard gradients left over from any previous step
    loss.backward()  # accumulate fresh gradients into the parameters
    opt.step()       # update parameters from those gradients

def ppo_policy(actor, opt, obs_buf, act_buf, logp_buf, adv_buf, policy_iterations=policy_iterations):
    """PPO-clip actor update with KL-based early stopping.

    Runs up to ``policy_iterations`` full-batch gradient steps on the clipped
    surrogate ``-mean(min(r * A, clip(r, 1 - eps, 1 + eps) * A))``, written in
    the equivalent ``torch.where`` form. Before each step the approximate KL
    ``mean(logp_old - logp_new)`` is checked. Once it exceeds
    ``1.5 * target_kl`` the loop stops without taking that step.

    Args:
        actor: policy network producing logits for ``obs_buf``.
        opt: optimizer over the actor's parameters.
        obs_buf: rollout observations.
        act_buf: integer actions taken in the rollout.
        logp_buf: log-probabilities of ``act_buf`` under the rollout-time policy.
        adv_buf: advantage for each step.
        policy_iterations: maximum number of gradient steps.
    """
    for _ in range(policy_iterations):
        logp = log_p(actor(obs_buf), act_buf)

        # Measure divergence from the rollout policy *before* updating, so the
        # step that would push it further past the limit is never taken.
        kl = (logp_buf - logp).mean().item()
        if kl > 1.5 * target_kl:
            break

        ratio = (logp - logp_buf).exp()
        # For A > 0 the clipped term is (1 + eps) * A, for A <= 0 it is (1 - eps) * A.
        clip_adv = torch.where(adv_buf > 0, (1 + eps) * adv_buf, (1 - eps) * adv_buf)
        policy_loss = -torch.min(ratio * adv_buf, clip_adv).mean()
        step(opt, policy_loss)

def custom_policy(actor, opt, obs_buf, act_buf, logp_buf, adv_buf, policy_iterations=policy_iterations):
    """PPO-clip actor update in the explicit ``clamp`` form, without KL early stopping.

    Takes exactly ``policy_iterations`` full-batch gradient steps on
    ``-mean(min(r * A, clamp(r, 1 - eps, 1 + eps) * A))``.

    Args:
        actor: policy network producing logits for ``obs_buf``.
        opt: optimizer over the actor's parameters.
        obs_buf: rollout observations.
        act_buf: integer actions taken in the rollout.
        logp_buf: log-probabilities of ``act_buf`` under the rollout-time policy.
        adv_buf: advantage for each step.
        policy_iterations: number of gradient steps.
    """
    for _ in range(policy_iterations):
        ratios = (log_p(actor(obs_buf), act_buf) - logp_buf).exp()
        # torch.clamp takes (min, max). With the bounds swapped it returns the
        # constant max (1 - eps), leaving the clipped term wrong and gradient-free.
        clipped = torch.clamp(ratios, 1 - eps, 1 + eps)
        policy_loss = -torch.min(ratios * adv_buf, clipped * adv_buf).mean()
        step(opt, policy_loss)


In [None]:
## typical training

def _env_reset(env):
    """Reset ``env`` and return only the initial observation.

    Works with the old gym API (``reset()`` -> obs) and with gym>=0.26
    (``reset()`` -> ``(obs, info)``).
    """
    result = env.reset()
    return result[0] if isinstance(result, tuple) else result


def _env_step(env, action):
    """Step ``env`` and return ``(obs, reward, terminated, truncated)``.

    Accepts the old 4-tuple ``(obs, reward, done, info)``, where ``done`` is
    treated as termination, and the gym>=0.26 5-tuple
    ``(obs, reward, terminated, truncated, info)``.
    """
    result = env.step(action)
    if len(result) == 5:
        obs, reward, terminated, truncated, _ = result
    else:
        obs, reward, done, _ = result
        terminated, truncated = done, False
    return obs, reward, terminated, truncated


def train_agent(policy_loss_fn):
    """Train a fresh actor/critic on the global ``env`` and return the mean episode return.

    Each of the ``max_ep`` iterations collects one episode (at most
    ``max_steps`` steps). It then calls ``policy_loss_fn`` to update the actor
    and runs ``value_iterations`` MSE steps on the critic.

    Args:
        policy_loss_fn: callable ``(actor, opt, obs, act, logp_old, adv)`` that
            updates the actor in place, e.g. ``ppo_policy`` or ``custom_policy``.

    Returns:
        mean undiscounted return over all training episodes.
    """
    actor, critic, buffer, policy_opt, value_opt = init_agent()

    actor.train()
    critic.train()

    final_rewards = []
    obs, ep_ret = _env_reset(env), 0
    for ep in range(max_ep):
        for t in range(max_steps):
            obs_t = torch.from_numpy(obs)
            # Acting only needs the values, not an autograd graph.
            with torch.no_grad():
                logit, action = sample_action(actor, obs_t)
                logp = log_p(logit, action)
                value = critic(obs_t)
            obs_, reward, terminated, truncated = _env_step(env, action.item())
            ep_ret += reward

            buffer.store(
                obs,
                action.numpy(),
                reward / 10,
                logp.numpy(),
                value.numpy(),
            )
            obs = obs_

            if terminated or truncated or t == max_steps - 1:
                # A terminal state is worth 0. A cut-off trajectory is bootstrapped
                # from the critic's estimate of the state it stopped in.
                if terminated:
                    last_val = 0
                else:
                    with torch.no_grad():
                        last_val = critic(torch.from_numpy(obs)).numpy()
                buffer.terminate(last_val)

                print(f"episode {ep}: return = {ep_ret}")

                final_rewards.append(ep_ret)
                obs, ep_ret = _env_reset(env), 0
                break

        obs_buf, act_buf, adv_buf, return_buf, logp_buf = to_torch(*buffer.get())

        policy_loss_fn(actor, policy_opt, obs_buf, act_buf, logp_buf, adv_buf)

        for _ in range(value_iterations):
            # The critic outputs shape (T, 1). Squeeze it so the error is per step,
            # not a (T, T) broadcast against the (T,) returns.
            value_loss = ((critic(obs_buf).squeeze(-1) - return_buf) ** 2).mean()
            step(value_opt, value_loss)

    return np.mean(final_rewards)

In [None]:
# Ten trials: each iteration trains a fresh agent with custom_policy, then a
# fresh one with ppo_policy, recording each run's mean episode return.
emp_custom = []
emp_ppo = []
for k in range(10):
    emp_custom.append(train_agent(custom_policy))
    emp_ppo.append(train_agent(ppo_policy))

In [None]:
# mean and std of the per-trial mean returns for custom_policy
(np.mean(emp_custom), np.std(emp_custom))

In [None]:
# mean and std of the per-trial mean returns for ppo_policy
(np.mean(emp_ppo), np.std(emp_ppo))