# PPO (Proximal Policy Optimization)

- Original paper : Proximal Policy Optimization Algorithms
- https://arxiv.org/pdf/1707.06347.pdf

## Install packages

In [None]:
!python3 -m pip install pixyz
!python3 -m pip install gym-notebook-wrapper
!apt update && apt install xvfb

## Import modules

In [None]:
import gnwrapper
import gym

import numpy as np

import torch
from torch import nn

from pixyz import distributions as dists
from pixyz.models import Model
from pixyz.losses.losses import Loss, LossSelfOperator
from pixyz.losses import Entropy, MinLoss, ValueLoss, Parameter
from pixyz.utils import print_latex

import sympy

In [None]:
class RolloutBuffer:
    """Plain-list storage for the data gathered while rolling out a policy.

    One list is kept per quantity (actions, states, log-probabilities,
    rewards, state values and terminal flags).
    """

    def __init__(self):
        """Create the buffer with every list empty."""
        self.actions = []
        self.states = []
        self.logprobs = []
        self.rewards = []
        self.state_values = []
        self.is_terminals = []

    def clear(self):
        """Empty every list in place; the list objects themselves are kept."""
        storages = (
            self.actions,
            self.states,
            self.logprobs,
            self.rewards,
            self.state_values,
            self.is_terminals,
        )
        for storage in storages:
            storage.clear()

## Define loss

In [None]:
class Ratio(Loss):
    """Probability ratio ``p / q`` of two distributions, expressed as a loss.

    The value is computed as ``exp(log p - log q)``. The log-probability of
    ``q`` is detached before the subtraction, so no gradient flows into ``q``.
    """

    def __init__(self, p, q, sum_features=False, feature_dims=None):
        """Store both distributions and the options forwarded to ``log_prob``.

        Args:
            p: Distribution used as the numerator.
            q: Distribution used as the denominator.
            sum_features: Passed unchanged to ``log_prob`` of ``p`` and ``q``.
            feature_dims: Passed unchanged to ``log_prob`` of ``p`` and ``q``.
        """
        input_vars = p.var + p.input_var + q.var + q.input_var
        super().__init__(input_vars)

        self.sum_features = sum_features
        self.feature_dims = feature_dims

        self.p = p
        self.q = q

    @property
    def _symbol(self):
        """Symbol of the form ``\\frac{p}{q}`` built from both prob texts."""
        return sympy.Symbol(f"\\frac{{{self.p.prob_text}}}{{{self.q.prob_text}}}")

    def forward(self, x_dict={}, **kwargs):
        """Evaluate the ratio on ``x_dict``.

        Returns:
            A tuple ``(ratio, {})``; the second element is always an empty dict.
        """
        # Evaluate p first and q second, each with identical log_prob options.
        log_p, log_q = [
            dist.log_prob(sum_features=self.sum_features, feature_dims=self.feature_dims, **kwargs).eval(x_dict)
            for dist in (self.p, self.q)
        ]

        return torch.exp(log_p - log_q.detach()), {}

In [None]:
class ClipLoss(LossSelfOperator):
    """Clamp the value of a wrapped loss into the range ``[min, max]``."""

    def __init__(self, loss1, min, max):
        """Wrap ``loss1`` and remember the clamping bounds.

        Args:
            loss1: Loss whose evaluated value is clamped.
            min: Lower bound handed to ``torch.clamp``.
            max: Upper bound handed to ``torch.clamp``.
        """
        super().__init__(loss1)

        self.min, self.max = min, max

    @property
    def _symbol(self):
        """Symbol of the form ``clip(<loss text>, min, max)``."""
        return sympy.Symbol(f"clip({self.loss1.loss_text}, {self.min}, {self.max})")

    def forward(self, x_dict={}, **kwargs):
        """Evaluate the wrapped loss and clamp the result.

        Returns:
            A tuple of the clamped value and the dict returned by the wrapped
            loss.
        """
        raw_value, out_dict = self.loss1(x_dict, **kwargs)

        return torch.clamp(raw_value, self.min, self.max), out_dict

In [None]:
class MSELoss(Loss):
    """Mean squared error between two entries of the input dict."""

    def __init__(self, var1, var2):
        """Remember which two variables are compared.

        Args:
            var1: Key of the first tensor (passed as the first argument to
                ``nn.MSELoss``).
            var2: Key of the second tensor (passed as the second argument).
        """
        super().__init__([var1, var2])

        self.var1, self.var2 = var1, var2

        self.MSELoss = nn.MSELoss()

    @property
    def _symbol(self):
        """Symbol of the form ``MSE(var1,var2)``."""
        return sympy.Symbol(f"MSE({self.var1},{self.var2})")

    def forward(self, x_dict={}, **kwargs):
        """Look up both variables in ``x_dict`` and return ``(mse, {})``."""
        first = x_dict[self.var1]
        second = x_dict[self.var2]

        return self.MSELoss(first, second), {}

## Define models

In [None]:
class Actor(dists.Categorical):
    """Categorical policy over the action ``a`` conditioned on the state ``s``.

    A single linear layer followed by a sigmoid produces one value per action.
    NOTE(review): the sigmoid outputs are not normalised to sum to one here;
    presumably the underlying categorical distribution normalises ``probs`` --
    confirm.
    """

    def __init__(self, state_dim, action_dim, name):
        """Build the network.

        Args:
            state_dim: Number of input features of the linear layer.
            action_dim: Number of output features (one per action).
            name: Name given to the distribution.
        """
        super().__init__(var=["a"], cond_var=["s"], name=name)

        layers = [nn.Linear(state_dim, action_dim), nn.Sigmoid()]
        self.output_probs = nn.Sequential(*layers)

    def forward(self, s):
        """Return the distribution parameters ``{"probs": ...}`` for state ``s``."""
        return {"probs": self.output_probs(s)}

In [None]:
class Critic(dists.Normal):
    """Normal distribution over the state value ``v`` conditioned on ``s``.

    The mean is produced by a small MLP; the scale is the constant ``1.``.
    """

    def __init__(self, state_dim):
        """Build the shared trunk and the output heads.

        Args:
            state_dim: Number of input features of the first linear layer.
        """
        super().__init__(var=["v"], cond_var=["s"])

        # The attribute name (including its spelling) is kept as is so the
        # state-dict keys do not change.
        self.backborn = nn.Sequential(
            nn.Linear(state_dim, 4),
            nn.Sigmoid(),
            nn.Linear(4, 4),
            nn.SiLU(),
        )

        self.output_loc = nn.Sequential(
            nn.Linear(4, 1),
        )

        # This head is not used by forward(), which returns a fixed scale.
        # It is kept so the module's parameters and state-dict keys stay the
        # same as before.
        self.output_scale = nn.Sequential(
            nn.Linear(4, 1),
            nn.Softplus()
        )

    def forward(self, s):
        """Return ``{"loc": ..., "scale": 1.}`` for state ``s``.

        The scale head is intentionally not evaluated: its output was never
        part of the returned parameters, so running it was wasted work.
        """
        h = self.backborn(s)
        loc = self.output_loc(h)

        return {"loc": loc, "scale": 1.}

In [None]:
class PPO(Model):
    """PPO agent with a clipped surrogate objective, built on a pixyz ``Model``.

    The loss is ``mean(0.5 * MSE(v, r) - min(ratio * A, clip(ratio) * A)
    - 0.01 * entropy(actor))``. Rollout data is collected in ``self.buffer``
    by ``select_action`` (states, actions, state values) and by the caller
    (rewards, terminal flags), then consumed by ``update``.
    """

    def __init__(self, actor, actor_old, critic, gamma, eps_clip, K_epochs, device):
        """Store hyper-parameters, move the networks and build the loss.

        Args:
            actor: Policy that is trained.
            actor_old: Policy used for sampling actions; its weights are
                overwritten with those of ``actor``.
            critic: State-value distribution.
            gamma: Discount factor used for the discounted rewards.
            eps_clip: Half-width of the ratio clipping range around 1.
            K_epochs: Number of training steps run per ``update`` call.
            device: Device the networks and the rollout tensors are moved to.
        """

        ##############################
        #      Hyper parameters      #
        ##############################
        self.gamma = gamma
        self.eps_clip = eps_clip
        self.K_epochs = K_epochs
        self.device = device

        #################################
        #      Actor-Critic models      #
        #################################
        self.actor = actor.to(self.device)
        self.actor_old = actor_old.to(self.device)
        self.actor_old.load_state_dict(self.actor.state_dict())
        self.critic = critic.to(self.device)

        ###########################
        #      Loss function      #
        ###########################
        advantage = Parameter("\\hat{A}")
        ratio = Ratio(self.actor, self.actor_old)
        clip = ClipLoss(ratio, 1-eps_clip, 1+eps_clip)

        ppo_loss = MinLoss(ratio*advantage, clip*advantage)
        value_loss = ValueLoss(0.5)*MSELoss("v", "r")
        entropy = ValueLoss(0.01)*Entropy(self.actor)

        loss_func = (value_loss - ppo_loss - entropy).mean()

        #########################
        #      Setup model      #
        #########################
        super().__init__(loss=loss_func, distributions=[self.actor, self.critic], retain_graph=True)

        self.buffer = RolloutBuffer()

    def select_action(self, state):
        """Sample an action from the old policy and record the step.

        The state (moved to ``self.device``), the sampled action and a
        sampled state value are appended to ``self.buffer``. Rewards and
        terminal flags are NOT stored here; the caller appends them.

        Args:
            state: State tensor passed to the actor and the critic.

        Returns:
            The sampled action as a NumPy array.
        """

        # Everything below runs without autograd tracking, so the sampled
        # tensors carry no graph.
        with torch.no_grad():
            state = state.to(self.device)
            action = self.actor_old.sample({"s": state})["a"]
            state_val = self.critic.sample({"s": state})["v"]

            self.buffer.states.append(state)
            self.buffer.actions.append(action)
            self.buffer.state_values.append(state_val)

            return action.cpu().numpy()

    def get_discount_reward(self):
        """Compute normalised discounted rewards from the buffer.

        The running discounted sum is reset to zero at every step flagged as
        terminal. The result is centred on its mean and divided by its
        standard deviation (plus ``1e-7``). With a single stored reward the
        division is skipped, because the sample standard deviation of one
        value is NaN.

        Returns:
            A detached float32 tensor on ``self.device`` with one entry per
            stored reward, in rollout order.
        """

        returns = []
        running_return = 0.

        # Walk the rollout backwards so each return builds on the following one.
        for reward, is_terminal in zip(reversed(self.buffer.rewards), reversed(self.buffer.is_terminals)):
            if is_terminal:
                running_return = 0.

            running_return = reward + self.gamma * running_return
            returns.append(running_return)

        # Appending and reversing once is linear; inserting at the front of
        # the list on every step was quadratic.
        returns.reverse()

        returns = torch.tensor(returns, dtype=torch.float32, device=self.device)
        returns = returns - returns.mean()
        if returns.numel() > 1:
            returns = returns / (returns.std() + 1e-7)

        return returns.detach()

    def update(self):
        """Run ``K_epochs`` training steps on the buffered rollout.

        Prints the mean training loss, copies the trained weights into the
        old policy and clears the buffer.
        """

        # Calculate discount rewards
        rewards = self.get_discount_reward()

        old_states = torch.squeeze(
            torch.stack(self.buffer.states, dim=0)).detach().to(self.device)
        old_actions = torch.squeeze(
            torch.stack(self.buffer.actions, dim=0)).detach().to(self.device)
        old_state_values = torch.squeeze(
            torch.stack(self.buffer.state_values, dim=0)).detach().to(self.device)

        # calculate advantages (both operands are already detached)
        advantages = rewards - old_state_values

        total_loss = 0.

        for _ in range(self.K_epochs):
            # Sample state values for the stored states with the current critic
            state_values = self.critic.sample({"s": old_states})["v"]

            # match state_values tensor dimensions with rewards tensor
            state_values = torch.squeeze(state_values)

            loss = self.train({
                "s": old_states,
                "a": old_actions,
                "\\hat{A}": advantages,
                "v": state_values,
                "r": rewards})

            # Detach before accumulating: summing the raw loss tensors would
            # keep the autograd graph of every epoch alive until the end of
            # the loop.
            total_loss += loss.detach()

        print("Train Loss", (total_loss/self.K_epochs).cpu().detach().numpy())

        # Copy new weights into old policy
        self.actor_old.load_state_dict(self.actor.state_dict())

        # clear buffer
        self.buffer.clear()

In [None]:
# Networks: state_dim=4, action_dim=2. The two actors differ only in name.
actor = Actor(4, 2, "\\pi")
actor_old = Actor(4, 2, "\\pi_{old}")
critic = Critic(4)

# Discount factor, clipping half-width and training steps per update.
gamma, eps_clip, K_epochs = 0.98, 0.1, 500

# Use the GPU when one is available.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
# Build the PPO agent from the networks and hyper-parameters defined above.
ppo = PPO(actor, actor_old, critic, gamma, eps_clip, K_epochs, device)

# Show the distributions, loss function and optimizer held by the model.
print(ppo)

Distributions (for training):
  \pi(a|s), p(v|s)
Loss function:
  mean \left(- 0.01 H \left[ {\pi(a|s)} \right] + 0.5 MSE(v,r) - min \left(\frac{\pi(a|s)}{\pi_{old}(a|s)} \hat{A}, \hat{A} clip(\frac{\pi(a|s)}{\pi_{old}(a|s)}, 0.9, 1.1)\right) \right)
Optimizer:
  Adam (
  Parameter Group 0
      amsgrad: False
      betas: (0.9, 0.999)
      capturable: False
      differentiable: False
      eps: 1e-08
      foreach: None
      fused: None
      lr: 0.001
      maximize: False
      weight_decay: 0
  )


In [None]:
# Display the model as a rendered math object.
print_latex(ppo)

<IPython.core.display.Math object>

In [None]:
env_name = "CartPole-v1"
env = gnwrapper.LoopAnimation(gym.make(env_name))


# One iteration per episode; the policy is updated after every 10th episode.
for i in range(1, 50000):
    state = env.reset()
    steps = 0
    rewards = 0.
    done = False

    # Roll out one episode until the environment reports termination.
    while not done:
        steps += 1

        # The sampled action is reduced to its arg-max index for the environment.
        action = ppo.select_action(torch.from_numpy(state))
        _action = np.argmax(action)
        state, reward, done, _ = env.step(_action)
        rewards += reward

        # select_action() stored state/action/value; add reward and terminal flag
        ppo.buffer.rewards.append(reward)
        ppo.buffer.is_terminals.append(done)

    print("Reward", rewards)

    if i % 10 != 0:
        continue

    ppo.update()
    print("###########################################")

Reward 11.0
Reward 14.0
Reward 25.0
Reward 19.0
Reward 26.0


  deprecation(
  deprecation(


Reward 13.0
Reward 15.0
Reward 10.0
Reward 16.0
Reward 14.0
Train Loss 0.71923715
###########################################
Reward 28.0
Reward 11.0
Reward 15.0
Reward 24.0
Reward 20.0
Reward 34.0
Reward 17.0
Reward 17.0
Reward 45.0
Reward 15.0
Train Loss 0.6763395
###########################################
Reward 16.0
Reward 15.0
Reward 23.0
Reward 28.0
Reward 23.0
Reward 17.0
Reward 14.0
Reward 20.0
Reward 15.0
Reward 36.0
Train Loss 0.6565022
###########################################
Reward 38.0
Reward 25.0
Reward 92.0
Reward 26.0
Reward 36.0
Reward 59.0
Reward 32.0
Reward 21.0
Reward 20.0
Reward 38.0
Train Loss 0.7915196
###########################################
Reward 44.0
Reward 52.0
Reward 94.0
Reward 38.0
Reward 38.0
Reward 50.0
Reward 79.0
Reward 24.0
Reward 37.0
Reward 36.0
Train Loss 0.7029697
###########################################
Reward 15.0
Reward 26.0
Reward 26.0
Reward 13.0
Reward 46.0
Reward 19.0
Reward 21.0
Reward 12.0
Reward 22.0
Reward 25.0
Train Loss 0.5

KeyboardInterrupt: ignored

In [None]:
env = gnwrapper.LoopAnimation(gym.make('CartPole-v1'))

# select_action() appends every state/action/value to ppo.buffer, but this
# cell stores no rewards or terminal flags. Remember the current sizes so the
# evaluation entries can be dropped afterwards; otherwise a later
# ppo.update() would see lists of different lengths.
rollout_lists = (ppo.buffer.states, ppo.buffer.actions, ppo.buffer.state_values)
sizes_before = [len(items) for items in rollout_lists]

state = env.reset()
done = False

# Play a single episode with the current policy, rendering every step.
while not done:

    action = ppo.select_action(torch.from_numpy(state))
    _action = np.argmax(action)
    state, reward, done, _ = env.step(_action)

    env.render()

# Drop only what this episode added; earlier buffer content is untouched.
for items, size in zip(rollout_lists, sizes_before):
    del items[size:]

env.close()
env.display()

If you want to render in human mode, initialize the environment in this way: gym.make('EnvName', render_mode='human') and don't call the render method.
See here for more information: https://www.gymlibrary.ml/content/api/[0m
  deprecation(
