# Generalized Advantage Estimation (GAE)

In [1]:
!apt-get install -y xvfb

!pip install gym==0.23.1 \
    pytorch-lightning==1.6 \
    pyvirtualdisplay

!pip install -U brax==0.0.12 jax==0.3.14 jaxlib==0.3.14+cuda11.cudnn82 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

import warnings 
warnings.filterwarnings('ignore')

E: Could not open lock file /var/lib/dpkg/lock-frontend - open (13: Permission denied)
E: Unable to acquire the dpkg frontend lock (/var/lib/dpkg/lock-frontend), are you root?
  from pkg_resources import load_entry_point
Collecting gym==0.23.1
  Downloading gym-0.23.1.tar.gz (626 kB)
[K     |████████████████████████████████| 626 kB 1.2 MB/s eta 0:00:01
[?25h  Installing build dependencies ... [?25l^C
[?25hcanceled
[31mERROR: Operation cancelled by user[0m
  from pkg_resources import load_entry_point
Looking in links: https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Collecting brax==0.0.12
  Downloading brax-0.0.12-py3-none-any.whl (186 kB)
[K     |████████████████████████████████| 186 kB 1.2 MB/s eta 0:00:01
[?25hCollecting jax==0.3.14
  Downloading jax-0.3.14.tar.gz (990 kB)
[K     |████████████████████████████████| 990 kB 32.0 MB/s eta 0:00:01
[?25hCollecting jaxlib==0.3.14+cuda11.cudnn82
  Downloading https://storage.googleapis.com/jax-releases/cuda11/jax

#### Setup virtual display

In [1]:
from pyvirtualdisplay import Display
Display(visible=False, size=(1400, 900)).start()

<pyvirtualdisplay.display.Display at 0x7f2e080e4d60>

#### Import the necessary code libraries

In [1]:
import copy
import torch
import random
import gym
import matplotlib
import functools
import itertools
import math

import numpy as np
import matplotlib.pyplot as plt

import torch.nn.functional as F

from collections import deque, namedtuple
from IPython.display import HTML
from base64 import b64encode

from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data.dataset import IterableDataset
from torch.optim import AdamW

from torch.distributions import Normal

from pytorch_lightning import LightningModule, Trainer

import brax.v1
from brax.v1 import envs
from brax.v1.envs import to_torch
from brax.v1.io import html

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
num_gpus = torch.cuda.device_count()

v = torch.ones(1, device='cuda')

In [2]:
device

'cuda:0'

In [3]:
@torch.no_grad()
def create_video(env, episode_length, policy=None):
  qp_array = []
  state = env.reset()
  for i in range(episode_length):
    if policy:
      loc, scale = policy(state)
      sample = torch.normal(loc, scale)
      action = torch.tanh(sample)
    else:
      action = env.action_space.sample()
    state, _, _, _ = env.step(action)
    qp_array.append(env.unwrapped._state.qp)
  return HTML(html.render(env.unwrapped._env.sys, qp_array))


@torch.no_grad()
def test_agent(env, episode_length, policy, episodes=10):

  ep_returns = []
  for ep in range(episodes):
    state = env.reset()
    done = False
    ep_ret = 0.0

    while not done:
      loc, scale = policy(state)
      sample = torch.normal(loc, scale)
      action = torch.tanh(sample)   
      state, reward, done, info = env.step(action)
      ep_ret += reward.item()

    ep_returns.append(ep_ret)

  return sum(ep_returns) / episodes  

#### Create the policy

In [4]:
class GradientPolicy(nn.Module):

  def __init__(self, in_features, out_dims, hidden_size=128):
    super().__init__()
    self.fc1 = nn.Linear(in_features, hidden_size)
    self.fc2 = nn.Linear(hidden_size, hidden_size)
    self.fc_mu = nn.Linear(hidden_size, out_dims)
    self.fc_std = nn.Linear(hidden_size, out_dims)

  def forward(self, x):
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    loc = self.fc_mu(x)
    loc = torch.tanh(loc)
    scale = self.fc_std(x)
    scale = F.softplus(scale) + 0.001
    return loc, scale

#### Create the value network

In [5]:
class ValueNet(nn.Module):

  def __init__(self, in_features, hidden_size=128):
    super().__init__()
    self.fc1 = nn.Linear(in_features, hidden_size)
    self.fc2 = nn.Linear(hidden_size, hidden_size)
    self.fc3 = nn.Linear(hidden_size, 1)

  def forward(self, x):
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    x = self.fc3(x)
    return x

#### Create the environment

In [6]:
class RunningMeanStd:
    # https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
    def __init__(self, epsilon=1e-4, shape=()):
        self.mean = torch.zeros(shape, dtype=torch.float32).to(device)
        self.var = torch.ones(shape, dtype=torch.float32).to(device)
        self.count = epsilon

    def update(self, x):
        batch_mean = torch.mean(x, dim=0)
        batch_var = torch.var(x, dim=0)
        batch_count = x.shape[0]
        self.update_from_moments(batch_mean, batch_var, batch_count)

    def update_from_moments(self, batch_mean, batch_var, batch_count):
        self.mean, self.var, self.count = update_mean_var_count_from_moments(
            self.mean, self.var, self.count, batch_mean, batch_var, batch_count
        )


def update_mean_var_count_from_moments(
    mean, var, count, batch_mean, batch_var, batch_count
):
    delta = batch_mean - mean
    tot_count = count + batch_count

    new_mean = mean + delta * batch_count / tot_count
    m_a = var * count
    m_b = batch_var * batch_count
    M2 = m_a + m_b + torch.square(delta) * count * batch_count / tot_count
    new_var = M2 / tot_count
    new_count = tot_count

    return new_mean, new_var, new_count


class NormalizeObservation(gym.core.Wrapper):

    def __init__(self, env, epsilon=1e-8):
        super().__init__(env)
        self.num_envs = getattr(env, "num_envs", 1)
        self.obs_rms = RunningMeanStd(shape=self.observation_space.shape[-1])
        self.epsilon = epsilon

    def step(self, action):
        obs, rews, dones, infos = self.env.step(action)
        obs = self.normalize(obs)
        return obs, rews, dones, infos

    def reset(self, **kwargs):
        return_info = kwargs.get("return_info", False)
        if return_info:
            obs, info = self.env.reset(**kwargs)
        else:
            obs = self.env.reset(**kwargs)
        obs = self.normalize(obs)
        if not return_info:
            return obs
        else:
            return obs, info

    def normalize(self, obs):
        self.obs_rms.update(obs)
        return (obs - self.obs_rms.mean) / torch.sqrt(self.obs_rms.var + self.epsilon)

In [7]:
entry_point = functools.partial(envs.create_gym_env, env_name='halfcheetah')
gym.register('brax-halfcheetah-v0', entry_point=entry_point)

In [9]:
env = gym.make("brax-halfcheetah-v0", episode_length=1000)
env = to_torch.JaxToTorchWrapper(env, device=device)
create_video(env, 1000)

  logger.warn(
  logger.warn(
  logger.warn(
  logger.deprecation(
  if not isinstance(done, (bool, np.bool8)):
  logger.warn(
  logger.warn(
  logger.warn("Casting input x to numpy array.")
  logger.warn(


In [8]:
def create_env(env_name, num_envs=256, episode_length=1000):
  env = gym.make(env_name, batch_size=num_envs, episode_length=episode_length)
  env = to_torch.JaxToTorchWrapper(env, device=device)
  env = NormalizeObservation(env)
  return env

In [None]:
env = create_env('brax-halfcheetah-v0', num_envs=10)
obs = env.reset()
print("Num envs: ", obs.shape[0], "Obs dimentions: ", obs.shape[1])

In [None]:
env.action_space

In [None]:
obs, reward, done, info = env.step(env.action_space.sample())

In [None]:
info.keys()

#### Create the dataset

In [18]:
class RLDataset(IterableDataset):

  def __init__(self, env, policy, value_net, 
               samples_per_epoch, gamma, lamb, repeats):#lamb is the new parameter
    
    self.samples_per_epoch = samples_per_epoch
    self.gamma = gamma
    self.lamb = lamb
    self.repeats = repeats
    self.env = env
    self.policy = policy
    self.value_net = value_net
    self.obs = self.env.reset()
  
  @torch.no_grad()
  def __iter__(self):
    transitions = []
    for step in range(self.samples_per_epoch):
      loc, scale = self.policy(self.obs)
      action = torch.normal(loc, scale)
      next_obs, reward, done, info = self.env.step(action)
      transitions.append((self.obs, loc, scale, action, reward, done, next_obs))
      self.obs = next_obs
    
    transitions  = map(torch.stack, zip(*transitions))
    obs_b, loc_b, scale_b, action_b, reward_b, done_b, next_obs_b = transitions
    
    # print(f"This is obs_b {obs_b}")
    # print(f"This is loc_b {loc_b}")
    # print(f"This is scale_b {scale_b}")
    # print(f"This is action_b {action_b}")
    # print(f"This is reward_b {reward_b}")
    # print(f"This is done_b {done_b}")
    # print(f"This is next_obs_b {next_obs_b}")
    #right now, the shape of reward_b and done_b is (samples_per_epoch, num_envs)
    #and want to add a dimension with a single element to make them compatible with the rest (samples_per_epoch, num_envs, )
    
    reward_b = reward_b.unsqueeze(dim=-1)
    done_b = done_b.unsqueeze(dim=-1)
    # print(f"This is reward_b after modification {reward_b}")
    # print(f"This is done_b after modification {done_b}")
# This extra dimension is likely needed because the calculations for TD error, GAE, and targets involve broadcasting operations that require
# reward_b and done_b to have the same shape as other tensors like values_b and next_values_b. Without this extra dimension, the broadcasting would not work correctly, leading to shape mismatches.
# Summary
# First Code Snippet: Directly stacks and reshapes the tensors without adding an extra dimension, assuming that the shapes are already compatible.
# Second Code Snippet: Adds an extra dimension to reward_b and done_b to ensure they are compatible with other tensors for broadcasting in calculations of TD error, GAE, and targets.
# The decision to add the extra dimension depends on the specific requirements of the computations being performed. In the second snippet, it ensures that the tensors can be 
# broadcast correctly during the calculation of TD error and GAE.
# I still don't get why are we doing this reshaping of these tensors one more time extra!!

    values_b = self.value_net(obs_b)
    next_values_b = self.value_net(next_obs_b)
    # print(f"This is values_b {values_b}")
    # print(f"This is next_values_b {next_values_b}")
    #The following is the temproal difference residual td_error
    td_error_b = reward_b + (1 - done_b) * self.gamma * next_values_b - values_b
    #the following is similar to the way we calculated the return in REINFORCE algorihtm
    running_gae = torch.zeros((self.env.num_envs, 1), dtype=torch.float32, device=device)
    gae_b = torch.zeros_like(td_error_b)
    # print(f"This is td_error_b {td_error_b}")
    # print(f"This is running_gae {running_gae}")
    # print(f"this is gae_b {gae_b}")

    for row in range(self.samples_per_epoch - 1, -1, -1):
      # print(f'This is td_error_b[row] {td_error_b[row]}')
      # print(f'This is done_b[row] {done_b[row]}')
      running_gae = td_error_b[row] + (1 - done_b[row]) * self.gamma * self.lamb * running_gae
      # print(f"This is running_gae at end of of each iteration of the loop {running_gae}")
      gae_b[row] = running_gae
      # print(f"This is gae_b[row] at the end of each iteration of the loop {gae_b[row]}")
    #unlike other algorithms, here the value_net that will be passed will be the target_network itself
    #it's just like we are doing the computation for the target in the dataset itself, without having the need to call 
    #the target_network individually.
    #because, at the start of the first epoch, the traning value_net and target value_net are same
    #and after each epoch, the parameters of the trained value_net is copied to the parameters of the target value_net
    #so, basically, it means, to sample from the envrionment, it doens't matter which network we use
    #because our algorithm works in a way that at the time of sampling, traning value_net and target value_net are always same
    target_b = gae_b + values_b #but why are we estimating the target value like this?
    #remember that the estimated value of the states is the sum of rewards that we expect to obtain starting from that state
    # another way to estimate the value of the state is adding the reward obtained after following our policy plus a discounted estimate of the value of the nexd state
    #r_t + gamma * next_state_values, that is the target we have been using so far
    #but by adding the advantages of the actions taken in the trajectory to the value of the states, we are also computing
    #a more reliable estimate of the expected sum of rewards that we can use as a target 
    # print(f"This is target_b {target_b}")

    num_samples = self.samples_per_epoch * self.env.num_envs
    reshape_fn = lambda x: x.view(num_samples, -1)
    batch = [obs_b, loc_b, scale_b, action_b, reward_b, gae_b, target_b]

    obs_b, loc_b, scale_b, action_b, reward_b, gae_b, target_b = map(reshape_fn, batch)
    # print(f"This is obs_b after map modification {obs_b}")
    # print(f"This is loc_b after map modification {loc_b}")
    # print(f"This is scale_b after map modification {scale_b}")
    # print(f"This is action_b after map modification {action_b}")
    # print(f"This is reward_b after map modification {reward_b}")
    # print(f"This is gae_b after map modification {gae_b}")
    # print(f"This is target_b after map modification {target_b}")
    for repeat in range(self.repeats):
      idx = list(range(num_samples))
      random.shuffle(idx)

      for i in idx:
        yield obs_b[i], loc_b[i], scale_b[i], action_b[i], reward_b[i], gae_b[i], target_b[i]

### PLEASE DO NOT EXECUTE THE FOLLOWING CELL, IT"S FOR UNDERSTANDING THE DATA TYPE OF DIFFERENT ELEMENTS

In [17]:
# class MockPolicy:
#     def __call__(self, obs):
#         return torch.tensor([[0.5, 0.5]] * obs.shape[0])
      
# Define the RLDataset class
class RLDataset(IterableDataset):
    def __init__(self, env, policy, value_net, samples_per_epoch, gamma, lamb, repeats, device):
        self.samples_per_epoch = samples_per_epoch
        self.gamma = gamma
        self.lamb = lamb
        self.repeats = repeats
        self.env = env
        self.policy = policy
        self.value_net = value_net
        self.device = device
        # self.obs = torch.tensor(self.env.reset()[0], dtype=torch.float32).to(self.device)
        self.obs = self.env.reset().to(self.device) 
    
    @torch.no_grad()
    def __iter__(self):
        transitions = []
        for step in range(self.samples_per_epoch):
            loc, scale = self.policy(self.obs)
            action = torch.normal(loc, scale)
            next_obs, reward, done, info = self.env.step(action.cpu().numpy())
            next_obs = torch.tensor(next_obs, dtype=torch.float32).to(self.device)
            reward = torch.tensor(reward, dtype=torch.float32).to(self.device)
            done = torch.tensor(done, dtype=torch.float32).to(self.device)
            transitions.append((self.obs, loc, scale, action, reward, done, next_obs))
            self.obs = next_obs
        
        transitions = map(torch.stack, zip(*transitions))
        obs_b, loc_b, scale_b, action_b, reward_b, done_b, next_obs_b = transitions
        print(f"This is obs_b {obs_b}")
        print(f"This is loc_b {loc_b}")
        print(f"This is scale_b {scale_b}")
        print(f"This is action_b {action_b}")
        print(f"This is reward_b {reward_b}")
        print(f"This is done_b {done_b}")
        print(f"This is next_obs_b {next_obs_b}")
        
        reward_b = reward_b.unsqueeze(dim=-1)
        done_b = done_b.unsqueeze(dim=-1)
        print(f"This is reward_b after modification {reward_b}")
        print(f"This is done_b after modification {done_b}")

        values_b = self.value_net(obs_b)
        next_values_b = self.value_net(next_obs_b)
        print(f"This is values_b {values_b}")
        print(f"This is next_values_b {next_values_b}")

        td_error_b = reward_b + (1 - done_b) * self.gamma * next_values_b - values_b
        running_gae = torch.zeros((self.env.num_envs, 1), dtype=torch.float32, device=self.device)
        gae_b = torch.zeros_like(td_error_b)
        print(f"This is td_error_b {td_error_b}")
        print(f"This is running_gae {running_gae}")
        print(f"This is gae_b {gae_b}")

        for row in range(self.samples_per_epoch - 1, -1, -1):
            print(f'This is td_error_b[row] {td_error_b[row]}')
            print(f'This is done_b[row] {done_b[row]}')
            running_gae = td_error_b[row] + (1 - done_b[row]) * self.gamma * self.lamb * running_gae
            print(f"This is running_gae at end of each iteration of the loop {running_gae}")
            gae_b[row] = running_gae
            print(f"This is gae_b[row] at the end of each iteration of the loop {gae_b[row]}")

        target_b = gae_b + values_b
        print(f"This is target_b {target_b}")

        num_samples = self.samples_per_epoch * self.env.num_envs
        reshape_fn = lambda x: x.view(num_samples, -1)
        batch = [obs_b, loc_b, scale_b, action_b, reward_b, gae_b, target_b]

        obs_b, loc_b, scale_b, action_b, reward_b, gae_b, target_b = map(reshape_fn, batch)
        print(f"This is obs_b after map modification {obs_b}")
        print(f"This is loc_b after map modification {loc_b}")
        print(f"This is scale_b after map modification {scale_b}")
        print(f"This is action_b after map modification {action_b}")
        print(f"This is reward_b after map modification {reward_b}")
        print(f"This is gae_b after map modification {gae_b}")
        print(f"This is target_b after map modification {target_b}")

        for repeat in range(self.repeats):
            idx = list(range(num_samples))
            random.shuffle(idx)

            for i in idx:
                yield obs_b[i], loc_b[i], scale_b[i], action_b[i], reward_b[i], gae_b[i], target_b[i]


# class SimpleValueNet(nn.Module):
#     def __init__(self, obs_dim):
#         super(SimpleValueNet, self).__init__()
#         self.fc = nn.Linear(obs_dim, 1)
    
#     def forward(self, x):
#         return self.fc(x)
# # Define a simple policy network
# class SimplePolicy(nn.Module):
#     def __init__(self, obs_dim, action_dim):
#         super(SimplePolicy, self).__init__()
#         self.fc = nn.Linear(obs_dim, action_dim * 2)
        
#     def forward(self, x):
#         params = self.fc(x)
#         loc = params[:, :action_dim]
#         scale = F.softplus(params[:, action_dim:])
#         return loc, scale

      
# Create mock environment and policy
env = create_env('brax-halfcheetah-v0', num_envs=3)
# policy = SimplePolicy(env.observation_space.shape[1], env.action_space.shape[1])
# value_net = SimpleValueNet(env.observation_space.shape[1])
obs_size = env.observation_space.shape[1]
action_dims = env.action_space.shape[1]

policy = GradientPolicy(obs_size, action_dims, 256).to(device)
value_net = ValueNet(obs_size, 256).to(device)
# Set hyperparameters
samples_per_epoch = 5
gamma = 0.99
lamb = 0.95
repeats = 2
# Create the dataset
# dataset = RLDataset(env, policy, samples_per_epoch=5, gamma=0.99)
dataset = RLDataset(env, policy, value_net, samples_per_epoch, gamma, lamb, repeats, device=device)


# Run the __iter__ method
for data in dataset:
    print(data)

  logger.deprecation(
  if not isinstance(done, (bool, np.bool8)):
  logger.warn(
  logger.warn(
  logger.warn("Casting input x to numpy array.")
  logger.warn(
  next_obs = torch.tensor(next_obs, dtype=torch.float32).to(self.device)
  reward = torch.tensor(reward, dtype=torch.float32).to(self.device)
  done = torch.tensor(done, dtype=torch.float32).to(self.device)


This is obs_b tensor([[[ 6.5935e-01,  4.0805e-03,  0.0000e+00, -3.3213e-01, -9.4903e-01,
          -1.0375e+00, -2.6760e-01,  9.6362e-01,  1.0139e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  5.7962e-01, -6.4751e-01,  1.1491e+00,
          -1.1388e+00,  1.1314e+00,  4.6301e-02],
         [ 2.6026e-01,  4.0805e-03,  0.0000e+00, -7.1501e-01,  8.9594e-02,
           8.9902e-02, -8.3471e-01,  4.7483e-02, -9.7772e-01,  0.0000e+00,
           0.0000e+00,  0.0000e+00, -1.1496e+00,  1.1498e+00, -6.1297e-01,
           4.5863e-01, -4.1879e-01,  9.6338e-01],
         [-9.1459e-01,  4.0805e-03,  0.0000e+00,  1.0491e+00,  8.5908e-01,
           9.4782e-01,  1.1021e+00, -1.0117e+00, -3.6115e-02,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  5.7003e-01, -5.0231e-01, -5.3612e-01,
           6.8026e-01, -7.1280e-01, -1.0097e+00]],

        [[-3.1476e-01,  4.6030e-02, -4.7280e-01,  5.2663e-01,  9.2498e-01,
          -1.0981e+00,  2.5246e-01, -7.2102e-01, -8.3667e-01,  2.5838e-01,
         

#### Create PPO with generalized advantage estimation (GAE)

In [23]:
class PPO(LightningModule):

  def __init__(self, env_name, num_envs=2048, episode_length=1_000, 
               batch_size=1024, hidden_size=256, samples_per_epoch=5, 
               policy_lr=1e-4, value_lr=1e-3, epoch_repeat=8, epsilon=0.3, 
               gamma=0.99, lamb=0.95, entropy_coef=0.2, optim=AdamW):
    self.automatic_optimization = False  # Disable automatic optimization
    #for using more than two optimisers
    super().__init__()
    self.automatic_optimization = False  # Disable automatic optimization
    #for using more than two optimisers
    self.env = create_env(env_name, num_envs=num_envs, episode_length=episode_length)
    test_env = gym.make(env_name, episode_length=episode_length)
    test_env = to_torch.JaxToTorchWrapper(test_env, device=device)
    self.test_env = NormalizeObservation(test_env)
    self.test_env.obs_rms = self.env.obs_rms

    obs_size = self.env.observation_space.shape[1]
    action_dims = self.env.action_space.shape[1]

    self.policy = GradientPolicy(obs_size, action_dims, hidden_size)
    self.value_net = ValueNet(obs_size, hidden_size)
    self.target_value_net = copy.deepcopy(self.value_net)

    self.dataset = RLDataset(self.env, self.policy, self.target_value_net, 
                             samples_per_epoch, gamma, lamb, epoch_repeat)

    self.save_hyperparameters()
    self.videos = []
  
  def configure_optimizers(self):
    value_opt = self.hparams.optim(self.value_net.parameters(), lr=self.hparams.value_lr)
    policy_opt = self.hparams.optim(self.policy.parameters(), lr=self.hparams.policy_lr)
    return value_opt, policy_opt

  def train_dataloader(self):
    return DataLoader(dataset=self.dataset, batch_size=self.hparams.batch_size)

  # Training step.
  def training_step(self, batch, batch_idx):
    obs_b, loc_b, scale_b, action_b, reward_b, gae_b, target_b = batch
    value_opt, policy_opt = self.optimizers()
    state_values = self.value_net(obs_b)

    # if optimizer_idx == 0:
    value_loss = F.smooth_l1_loss(state_values, target_b)
    self.log("episode/Value Loss", value_loss)
    value_opt.zero_grad()
    value_loss.backward()
    value_opt.step()
      # return loss
    
    # elif optimizer_idx == 1:

    new_loc, new_scale = self.policy(obs_b)
    dist = Normal(new_loc, new_scale)
    log_prob = dist.log_prob(action_b).sum(dim=-1, keepdim=True)

    prev_dist = Normal(loc_b, scale_b)
    prev_log_prob = prev_dist.log_prob(action_b).sum(dim=-1, keepdim=True)

    rho_s = torch.exp(log_prob - prev_log_prob)

    surrogate_1 = rho_s * gae_b
    surrogate_2 = rho_s.clip(1 - self.hparams.epsilon, 1 + self.hparams.epsilon) * gae_b
    policy_loss = - torch.minimum(surrogate_1, surrogate_2)
    
    entropy = dist.entropy().sum(dim=-1, keepdim=True)
    true_loss = policy_loss - self.hparams.entropy_coef * entropy

    self.log("episode/Policy Loss", policy_loss.mean())
    self.log("episode/Entropy", entropy.mean())
    self.log("episode/Reward", reward_b.mean())
    policy_opt.zero_grad()
    true_loss.mean().backward()
    policy_opt.step()
      # return loss.mean()

  def on_train_epoch_end(self):
    self.target_value_net.load_state_dict(self.value_net.state_dict())

    if self.current_epoch % 10 == 0:
      average_return = test_agent(self.test_env, self.hparams.episode_length, self.policy, episodes=1)
      self.log("episode/Average Return", average_return)

    if self.current_epoch % 50 == 0:
      video = create_video(self.test_env, self.hparams.episode_length, policy=self.policy)
      self.videos.append(video)

#### Purge logs and run the visualization tool (Tensorboard)

In [None]:
# Start tensorboard.
!rm -r /content/lightning_logs/
!rm -r /content/videos/
%load_ext tensorboard
%tensorboard --logdir /content/lightning_logs/

#### Train the policy

In [24]:
algo = PPO('brax-halfcheetah-v0')

trainer = Trainer(
                # gpus=num_gpus,
                max_epochs=5000)

trainer.fit(algo)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name             | Type           | Params | Mode 
------------------------------------------------------------
0 | policy           | GradientPolicy | 73.7 K | train
1 | value_net        | ValueNet       | 70.9 K | train
2 | target_value_net | ValueNet       | 70.9 K | train
------------------------------------------------------------
215 K     Trainable params
0         Non-trainable params
215 K     Total params
0.862     Total estimated model params size (MB)
/home/akhters/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.
  logger.deprecation(
  if not isinstance(done, (bool, np.

Training: |          | 0/? [00:00<?, ?it/s]

/home/akhters/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...


In [None]:
algo.videos[-1]