# Proximal Policy Optimization (PPO)

In [None]:
!apt-get install -y xvfb

!pip install gym==0.23.1 \
    pytorch-lightning==1.6 \
    pyvirtualdisplay

!pip install -U brax==0.0.12 jax==0.3.14 jaxlib==0.3.14+cuda11.cudnn82 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

import warnings 
warnings.filterwarnings('ignore')

#### Setup virtual display

In [2]:
from pyvirtualdisplay import Display
Display(visible=False, size=(1400, 900)).start()

<pyvirtualdisplay.display.Display at 0x7fdbc2ef4c70>

#### Import the necessary code libraries

In [3]:
import copy
import torch
import random
import gym
import matplotlib
import functools
import itertools
import math

import numpy as np
import matplotlib.pyplot as plt

import torch.nn.functional as F

from collections import deque, namedtuple
from IPython.display import HTML
from base64 import b64encode

from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data.dataset import IterableDataset
from torch.optim import AdamW

from torch.distributions import Normal

from pytorch_lightning import LightningModule, Trainer

import brax.v1
from brax.v1 import envs
from brax.v1.envs import to_torch
from brax.v1.io import html

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
num_gpus = torch.cuda.device_count()

v = torch.ones(1, device='cuda')

In [4]:
@torch.no_grad()
def create_video(env, episode_length, policy=None):
  qp_array = []
  state = env.reset()
  for i in range(episode_length):
    if policy:
      loc, scale = policy(state)
      sample = torch.normal(loc, scale)
      action = torch.tanh(sample)
    else:
      action = env.action_space.sample()
    state, _, _, _ = env.step(action)
    qp_array.append(env.unwrapped._state.qp)
  return HTML(html.render(env.unwrapped._env.sys, qp_array))


@torch.no_grad()
def test_agent(env, episode_length, policy, episodes=10):

  ep_returns = []
  for ep in range(episodes):
    state = env.reset()
    done = False
    ep_ret = 0.0

    while not done:
      loc, scale = policy(state)
      sample = torch.normal(loc, scale)
      action = torch.tanh(sample)   
      state, reward, done, info = env.step(action)
      ep_ret += reward.item()

    ep_returns.append(ep_ret)

  return sum(ep_returns) / episodes  

#### Create the policy

In [5]:
class GradientPolicy(nn.Module):

  def __init__(self, in_features, out_dims, hidden_size=128):
    super().__init__()
    self.fc1 = nn.Linear(in_features, hidden_size)
    self.fc2 = nn.Linear(hidden_size, hidden_size)
    self.fc_mu = nn.Linear(hidden_size, out_dims)
    self.fc_std = nn.Linear(hidden_size, out_dims)

  def forward(self, x):
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    loc = self.fc_mu(x)
    loc = torch.tanh(loc)
    scale = self.fc_std(x)
    scale = F.softplus(scale) + 0.001
    return loc, scale

#### Create the value network

In [6]:
class ValueNet(nn.Module):

  def __init__(self, in_features, hidden_size=128):
    super().__init__()
    self.fc1 = nn.Linear(in_features, hidden_size)
    self.fc2 = nn.Linear(hidden_size, hidden_size)
    self.fc3 = nn.Linear(hidden_size, 1)

  def forward(self, x):
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    x = self.fc3(x)
    return x

#### Create the environment

In [7]:
class RunningMeanStd:
    # https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
    def __init__(self, epsilon=1e-4, shape=()):
        self.mean = torch.zeros(shape, dtype=torch.float32).to(device)
        self.var = torch.ones(shape, dtype=torch.float32).to(device)
        self.count = epsilon

    def update(self, x):
        batch_mean = torch.mean(x, dim=0)
        batch_var = torch.var(x, dim=0)
        batch_count = x.shape[0]
        self.update_from_moments(batch_mean, batch_var, batch_count)

    def update_from_moments(self, batch_mean, batch_var, batch_count):
        self.mean, self.var, self.count = update_mean_var_count_from_moments(
            self.mean, self.var, self.count, batch_mean, batch_var, batch_count
        )


def update_mean_var_count_from_moments(
    mean, var, count, batch_mean, batch_var, batch_count
):
    delta = batch_mean - mean
    tot_count = count + batch_count

    new_mean = mean + delta * batch_count / tot_count
    m_a = var * count
    m_b = batch_var * batch_count
    M2 = m_a + m_b + torch.square(delta) * count * batch_count / tot_count
    new_var = M2 / tot_count
    new_count = tot_count

    return new_mean, new_var, new_count


class NormalizeObservation(gym.core.Wrapper):

    def __init__(self, env, epsilon=1e-8):
        super().__init__(env)
        self.num_envs = getattr(env, "num_envs", 1)
        self.obs_rms = RunningMeanStd(shape=self.observation_space.shape[-1])
        self.epsilon = epsilon

    def step(self, action):
        obs, rews, dones, infos = self.env.step(action)
        obs = self.normalize(obs)
        return obs, rews, dones, infos

    def reset(self, **kwargs):
        return_info = kwargs.get("return_info", False)
        if return_info:
            obs, info = self.env.reset(**kwargs)
        else:
            obs = self.env.reset(**kwargs)
        obs = self.normalize(obs)
        if not return_info:
            return obs
        else:
            return obs, info

    def normalize(self, obs):
        self.obs_rms.update(obs)
        return (obs - self.obs_rms.mean) / torch.sqrt(self.obs_rms.var + self.epsilon)

In [8]:
entry_point = functools.partial(envs.create_gym_env, env_name='ant')
gym.register('brax-ant-v0', entry_point=entry_point)

In [9]:
env = gym.make("brax-ant-v0", episode_length=1000)
env = to_torch.JaxToTorchWrapper(env, device=device)
create_video(env, 1000)

  logger.warn(
  logger.warn(
  logger.warn(
  logger.deprecation(
  if not isinstance(done, (bool, np.bool8)):
  logger.warn(
  logger.warn(
  logger.warn("Casting input x to numpy array.")
  logger.warn(


In [10]:
def create_env(env_name, num_envs=256, episode_length=1000):
  env = gym.make(env_name, batch_size=num_envs, episode_length=episode_length)
  env = to_torch.JaxToTorchWrapper(env, device=device)
  env = NormalizeObservation(env)
  return env

In [11]:
env = create_env('brax-ant-v0', num_envs=10)
obs = env.reset()
print(obs)
print("Num envs: ", obs.shape[0], "Obs dimentions: ", obs.shape[1])
print(obs.shape)

  logger.warn(
  logger.warn(
  logger.warn(


tensor([[ 2.6840e-01,  2.2386e-03,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         -3.8985e-02, -6.6084e-01, -3.0037e-01, -1.0137e+00, -6.8261e-01,
          1.1341e+00, -5.5921e-01,  5.2999e-01,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  6.8554e-01,
          1.0177e+00,  7.0795e-01, -3.6630e-01,  5.3200e-01, -1.4538e+00,
          1.1369e+00, -3.5128e-02,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         -7.6877e-02,  0.0000e+00,  0.0000e+00,  0.0000e+00,  1.7143e-01,
         -1.8722e-01, -3.1063e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00, -1.8475e-01,  1.7777e-01, -2.2249e-01,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00, -9.4708e-03,  9.3989e-03,
          0.0000e+00,  0.0000e+00,  0.

In [13]:
env.action_space

Box(-1.0, 1.0, (10, 8), float32)

In [14]:
obs, reward, done, info = env.step(env.action_space.sample())

  logger.deprecation(
  if not isinstance(done, (bool, np.bool8)):
  logger.warn(
  logger.warn(
  logger.warn("Casting input x to numpy array.")
  logger.warn(


In [15]:
print(obs)
print(reward)
print(done)
print(info)

tensor([[-4.2192e-01,  2.7125e-01,  7.4431e-02,  7.2217e-01,  4.3458e-01,
          2.1812e-01, -1.2240e+00, -5.1481e-03, -7.1066e-01, -1.2513e+00,
          8.2206e-01, -9.4428e-01, -1.4762e-01,  1.4865e+00, -9.9853e-02,
         -1.0603e+00,  1.4437e-01,  5.7022e-01,  4.3548e-01, -3.9661e-01,
          9.7564e-01, -2.9750e-01,  4.8157e-01,  1.0842e+00, -1.9722e-01,
          9.2837e-01,  5.0500e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00, -3.1799e-01, -3.1799e-01,
          3.1755e-01,  0.0000e+00,  0.0000e+00,  0.0000e+00,  1.0181e+00,
          2.1233e+00, -2.0021e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          4.2205e-02,  2.9145e-01,  2.4061e-01,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  1.5433e-02,  4.0989e-01,  5.0200e-01,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00, -2.3342e-01,  3.1088e-01,
          1.6654e-01,  0.0000e+00,  0.

In [16]:
info.keys()

dict_keys(['distance_from_origin', 'first_obs', 'first_qp', 'forward_reward', 'reward_contact', 'reward_ctrl', 'reward_forward', 'reward_survive', 'steps', 'truncation', 'x_position', 'x_velocity', 'y_position', 'y_velocity'])

#### Create the dataset

In [12]:
class RLDataset(IterableDataset):

  def __init__(self, env, policy, samples_per_epoch, epoch_repeat):#how many times we want to repeat the samples of a given epoch
    self.env = env
    self.policy = policy
    self.samples_per_epoch = samples_per_epoch
    self.epoch_repeat = epoch_repeat
    self.obs = self.env.reset()

  @torch.no_grad()
  def __iter__(self):
    transitions = []
    for step in range(self.samples_per_epoch):
      loc, scale = self.policy(self.obs)
      action = torch.normal(loc, scale)
      next_obs, reward, done, info = self.env.step(action)
      transitions.append((self.obs, loc, scale, action, reward, done, next_obs))
      self.obs = next_obs
    
    #num_samples is the number of rows that we want our tensors of obs_b, loc_b, scale_b, action_b, reward_b, done_b, next_obs_b to have
    num_samples = self.env.num_envs * self.samples_per_epoch
    reshape_fn = lambda x: x.view(num_samples, -1)
    batch = map(torch.stack, zip(*transitions))#map function takes a function(torch.stack) and a collection(*transition)
    #what are we doing here?We are taking a list of transitions and we are using the * operator to decompose the contents of the transitions list
    #then out of each of the elements from this list of transitions we are going to take using the zip function, the first one that is the observation and we'll group them together
    #then the second one, then the third and so on and so forth. And by applying to stack function to a group of observations, we will get a tensor of observations when we apply 
    #and when we apply the stack function to the group of loc, we will get a tensor of the mean action values and so on and so forth
    obs_b, loc_b, scale_b, action_b, reward_b, done_b, next_obs_b = map(reshape_fn, batch)#(num_samples, num_envs, feature_dims) -> (num_samples*num_envs, feature_dims)

    #the following is different
    #so for a given number of times, we are going to pass, the same observations to our traning step to extract all the learning that we can from a small sample
    for repeat in range(self.epoch_repeat):
      idx = list(range(num_samples))
      random.shuffle(idx)

      for i in idx:
        yield obs_b[i], loc_b[i], scale_b[i], action_b[i], reward_b[i], done_b[i], next_obs_b[i]

#### Create the PPO algorithm

In [18]:
class PPO(LightningModule):

  def __init__(self, env_name, num_envs=2048, episode_length=1000,
               batch_size=1024, hidden_size=256, samples_per_epoch=5,
               epoch_repeat=8, policy_lr=1e-4, value_lr=1e-3, gamma=0.97,
               epsilon=0.3, entropy_coef=0.1, optim=AdamW):
    
    super().__init__()
    self.automatic_optimization = False  # Disable automatic optimization
    #for using more than two optimisers
    self.env = create_env(env_name, num_envs=num_envs, episode_length=episode_length)
    test_env = gym.make(env_name, episode_length=episode_length)
    test_env = to_torch.JaxToTorchWrapper(test_env, device=device)#thndis is very important
    self.test_env = NormalizeObservation(test_env)
    self.test_env.obs_rms = self.env.obs_rms#share the normalization metrics

    obs_size = self.env.observation_space.shape[1]
    action_dims = self.env.action_space.shape[1]

    self.policy = GradientPolicy(obs_size, action_dims, hidden_size)
    self.value_net = ValueNet(obs_size, hidden_size)
    self.target_value_net = copy.deepcopy(self.value_net)

    self.dataset = RLDataset(self.env, self.policy, samples_per_epoch, epoch_repeat)#this
    #tiny little parameterm epoch_repeat

    self.save_hyperparameters()
    self.videos = []
  
  def configure_optimizers(self):
    value_opt = self.hparams.optim(self.value_net.parameters(), lr=self.hparams.value_lr)
    policy_opt = self.hparams.optim(self.policy.parameters(), lr=self.hparams.policy_lr)
    return value_opt, policy_opt

  def train_dataloader(self):
    return DataLoader(dataset=self.dataset, batch_size=self.hparams.batch_size)

  def training_step(self, batch, batch_idx):
    obs_b, loc_b, scale_b, action_b, reward_b, done_b, next_obs_b = batch
    value_opt, policy_opt = self.optimizers()
    state_values = self.value_net(obs_b)

    with torch.no_grad():
      next_state_values = self.target_value_net(next_obs_b)
      next_state_values[done_b.bool()] = 0.0
      target = reward_b + self.hparams.gamma * next_state_values
    
    # if optimizer_idx == 0:
    value_loss = F.smooth_l1_loss(state_values, target)
    self.log("episode/Value Loss", value_loss)
    value_opt.zero_grad()
    value_loss.backward()
    value_opt.step()
      # return loss

    # elif optimizer_idx == 1:
    #.detach() because although the target is not part of a computation graph, but state_values are part of a computation graph and we don't wish advantages 
    #to become part of that computation graph, so that is why, we are using the detach() method on advantages
    advantages = (target - state_values).detach()
    #initially new_loc, new_scale will be same as loc_b, scale_b , but as we train multiple epochs on the same batch, new_loc & new_scale start to diverege from collected 
    new_loc, new_scale = self.policy(obs_b)
    #since we are going to see the same samples a given number of times, and we are going to update our policy, after every update, the action distributions produced by our policy will change
    #and that is why we are recomputing them and calling them new_loc, new_scale
    dist = Normal(new_loc, new_scale)
    #creating new_distribution
    #focus, we are calculating the log_prob of the same action, under different probability distribution
    #first, below
    log_prob = dist.log_prob(action_b).sum(dim=-1, keepdim=True)

    prev_dist = Normal(loc_b, scale_b)
    #second, here
    prev_log_prob = prev_dist.log_prob(action_b).sum(dim=-1, keepdim=True)
    #The following three variables are just code implementation of them as mentioned in the pseudocode, not much to understand
    rho = torch.exp(log_prob - prev_log_prob)
    #one of the two possible loss(revise the algorithm if you can't understand)
    surrogate_1 = rho * advantages
    surrogate_2 = rho.clip(1 - self.hparams.epsilon, 1 + self.hparams.epsilon) * advantages

    policy_loss = - torch.minimum(surrogate_1, surrogate_2)
    entropy = dist.entropy().sum(dim=-1, keepdim=True)
    true_policy_loss = policy_loss - self.hparams.entropy_coef * entropy
    policy_opt.zero_grad()
# The .mean() operation is applied to true_policy_loss before calling .backward() because true_policy_loss is typically a tensor with more than one element. PyTorch requires that gradients be computed on scalar values (tensors with a single element), not on multi-element tensors. 
# Here's a more detailed explanation:
# Understanding the Need for .mean()
# Multi-element Tensors:
# The true_policy_loss is computed for each element in the batch, resulting in a tensor of losses. This tensor will have one loss value per sample in the batch.
# Backpropagating through a multi-element tensor without reducing it to a scalar can lead to ambiguity, as it is not clear how to aggregate the gradients across the different elements.
# Gradient Calculation:
# In order to perform backpropagation, PyTorch needs a scalar loss value. Taking the mean (or sum) of the true_policy_loss tensor reduces it to a single scalar value.
# This scalar loss value is then used to compute gradients via the .backward() method, ensuring that the gradients are correctly aggregated across the batch.
# Gradient Consistency:
# Taking the mean of the losses before backpropagation ensures consistent gradient magnitudes, regardless of batch size. This helps maintain stable learning rates and training dynamics.
# Example Code
# In your training_step method, the true_policy_loss tensor is created by combining the policy loss and entropy penalty for each sample in the batch. To perform backpropagation correctly, you take the mean of this tensor before calling .backward():
    true_policy_loss.mean().backward()
    policy_opt.step()

    self.log("episode/Policy Loss", policy_loss.mean())
    self.log("episode/Entropy", entropy.mean())
    self.log("episode/Reward", reward_b.mean())
      # return loss.mean()

  def on_train_epoch_end(self):
    #updating the target network
    self.target_value_net.load_state_dict(self.value_net.state_dict())

    if self.current_epoch % 10 == 0:
      average_return = test_agent(self.test_env, self.hparams.episode_length, self.policy, episodes=1)
      self.log("episode/Average Return", average_return)
    
    if self.current_epoch % 50 == 0:
      video = create_video(self.test_env, self.hparams.episode_length, policy=self.policy)
      self.videos.append(video)

#### Purge logs and run the visualization tool (Tensorboard)

In [None]:
# Start tensorboard.
!rm -r /content/lightning_logs/
!rm -r /content/videos/
%load_ext tensorboard
%tensorboard --logdir /content/lightning_logs/

#### Train the policy

In [19]:
algo = PPO("brax-ant-v0")

trainer = Trainer(
                #   gpus=num_gpus,
                  max_epochs=500)

trainer.fit(algo)

  logger.warn(
  logger.warn(
  logger.warn(
  logger.warn(
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name             | Type           | Params | Mode 
------------------------------------------------------------
0 | policy           | GradientPolicy | 92.4 K | train
1 | value_net        | ValueNet       | 88.6 K | train
2 | target_value_net | ValueNet       | 88.6 K | train
------------------------------------------------------------
269 K     Trainable params
0         Non-trainable params
269 K     Total params
1.078     Total estimated model params size (MB)
/home/akhters/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance

Training: |          | 0/? [00:00<?, ?it/s]

/home/akhters/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...


In [None]:
len(algo.videos)

In [None]:
algo.videos[9]