```
Author: Ehsan Kamalinejad (EK)
Created: 2023-02-25
```

# PPO Training

This notebook is a basic implementation of reinforcement learning (RL) training through proximal policy optimization (PPO). PPO is the default RL training for many problems at OpenAI and has great performance while very flexible. Here, the focus is education and simplicity.

Here, we will learn how to train a RL agent from scratch through PPO. This will be useful when we do (reinforcement learning with human feedback) RLHF to fine-tune language models in future lectures.

Pre-requisites:
- Intermediate level familiarity with Python and PyTorch.
- Intermediate level familiarity with general concepts in machine learning (ML) and gradient based optimization.
- Basic familiarity with concepts in RL such as environments, rewards, agents, etc.

Dependencies:
```
pip install torch --extra-index-url https://download.pytorch.org/whl/cu116
pip install moviepy omegaconf matplotlib
pip install gym==0.26.2
pip install git+https://github.com/carlosluis/stable-baselines3@fix_tests
pip install gym[classic_control] gym[atari] gym[accept-rom-license] gym[other]
```

References:
- [EK's Video Lecture](https://www.youtube.com/watch?v=3uvnoVjM8nY) This is the lecture where we did a deep dive into the theory of PPO.
- [OpenAI PPO Repo](https://github.com/openai/baselines/blob/master/baselines/ppo2/runner.py) This is helpful as a reference for further implementations.
- [PPO Paper](https://arxiv.org/abs/1707.06347) This is the original paper that introduced PPO.
- [Sergey Levine UC Berkley CS285](http://rail.eecs.berkeley.edu/deeprlcourse/) This is a complete course in RL.
- [Pieter Abbeel mini-course](https://www.youtube.com/watch?v=2GwBez0D20A&list=PLwRJQ4m4UJjNymuBM9RdmB3Z9N5-0IlY0) This is a mini-course focusing on TRPO, PPO, DDPG and model free RL.
- [OpenAI Documentation on RL](https://spinningup.openai.com/en/latest/index.html) THis is OpenAI documentation on RL and parts of our code was borrowed from here.
- [labml.ai](https://nn.labml.ai/) This repo contains popular papers with their annotated PyTorch implementations.
- [cleanrl](https://github.com/vwxyzjn/cleanrl) This repo has clean implementations of RL algorithms and parts of our code was borrowed from here.

In [1]:
# !pip install torch --extra-index-url https://download.pytorch.org/whl/cu116
!pip install moviepy omegaconf matplotlib
!pip install gym==0.26.2
!pip install git+https://github.com/carlosluis/stable-baselines3@fix_tests
!pip install gym[classic_control] gym[atari] gym[accept-rom-license] gym[other]

Collecting omegaconf
  Obtaining dependency information for omegaconf from https://files.pythonhosted.org/packages/e3/94/1843518e420fa3ed6919835845df698c7e27e183cb997394e4a670973a65/omegaconf-2.3.0-py3-none-any.whl.metadata
  Downloading omegaconf-2.3.0-py3-none-any.whl.metadata (3.9 kB)
Collecting antlr4-python3-runtime==4.9.* (from omegaconf)
  Using cached antlr4-python3-runtime-4.9.3.tar.gz (117 kB)
  Preparing metadata (setup.py) ... [?25ldone
Using cached omegaconf-2.3.0-py3-none-any.whl (79 kB)
Building wheels for collected packages: antlr4-python3-runtime
  Building wheel for antlr4-python3-runtime (setup.py) ... [?25ldone
[?25h  Created wheel for antlr4-python3-runtime: filename=antlr4_python3_runtime-4.9.3-py3-none-any.whl size=144554 sha256=dee7a7ea43993b6873a76982ff4e38d940859e768a2f5dc739d10be4b80323a4
  Stored in directory: /Users/cohlem/Library/Caches/pip/wheels/1a/97/32/461f837398029ad76911109f07047fde1d7b661a147c7c56d1
Successfully built antlr4-python3-runtime
Insta

Collecting importlib-metadata~=4.13 (from stable_baselines3==2.0.0a0)
  Obtaining dependency information for importlib-metadata~=4.13 from https://files.pythonhosted.org/packages/d0/98/c277899f5aa21f6e6946e1c83f2af650cbfee982763ffb91db07ff7d3a13/importlib_metadata-4.13.0-py3-none-any.whl.metadata
  Downloading importlib_metadata-4.13.0-py3-none-any.whl.metadata (4.9 kB)
Downloading importlib_metadata-4.13.0-py3-none-any.whl (23 kB)
Building wheels for collected packages: stable_baselines3
  Building wheel for stable_baselines3 (pyproject.toml) ... [?25ldone
[?25h  Created wheel for stable_baselines3: filename=stable_baselines3-2.0.0a0-py3-none-any.whl size=174675 sha256=67d7e9f421066cf3a81cbb1e571345da13ae11bfb581cf1a862984484c478aa2
  Stored in directory: /private/var/folders/pv/2c84c3jx48v8485vx92jq6zc0000gn/T/pip-ephem-wheel-cache-8icna2yc/wheels/cc/3f/36/c2107a7756801b30e445eb0eea4eccce687ca61229763e2a3f
Successfully built stable_baselines3
Installing collected packages: importli

Downloading ale_py-0.8.1-cp311-cp311-macosx_11_0_arm64.whl (1.0 MB)
[2K   [38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.0/1.0 MB[0m [31m3.2 MB/s[0m eta [36m0:00:00[0mm eta [36m0:00:01[0m0:01[0mm
[?25hDownloading opencv_python-4.11.0.86-cp37-abi3-macosx_13_0_arm64.whl (37.3 MB)
[2K   [38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m37.3/37.3 MB[0m [31m8.4 MB/s[0m eta [36m0:00:00[0mm eta [36m0:00:01[0m[36m0:00:01[0m
[?25hDownloading AutoROM-0.4.2-py3-none-any.whl (16 kB)
Downloading importlib_resources-6.5.2-py3-none-any.whl (37 kB)
Building wheels for collected packages: AutoROM.accept-rom-license
  Building wheel for AutoROM.accept-rom-license (pyproject.toml) ... [?25ldone
[?25h  Created wheel for AutoROM.accept-rom-license: filename=autorom_accept_rom_license-0.6.1-py3-none-any.whl size=446709 sha256=81134d21c8b4ff11e2975b101c4da4e52086564918ba89b3bda3ffa635dd1d42
  Stored in directory: /Users/cohlem/Library/Caches/pip/

In [2]:
import time
import random
import numpy as np
import matplotlib.pylab as plt
plt.style.use('dark_background')
from tqdm.notebook import tqdm
from omegaconf import DictConfig

import gymnasium as gym

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader


from IPython.display import Video

## Setup

In [3]:
seed = 2023
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.backends.cudnn.deterministic = True

In [4]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [5]:
configs = {
    # experiment arguments
    "exp_name": "cartpole",
    "gym_id": "CartPole-v1", # the id of from OpenAI gym
    # training arguments
    "learning_rate": 1e-3, # the learning rate of the optimizer
    "total_timesteps": 1000000, # total timesteps of the training
    "max_grad_norm": 0.5, # the maximum norm allowed for the gradient
    # PPO parameters
    "num_trajcts": 32, # N
    "max_trajects_length": 64, # T
    "gamma": 0.99, # gamma
    "gae_lambda":0.95, # lambda for the generalized advantage estimation
    "num_minibatches": 2, # number of mibibatches used in each gradient
    "update_epochs": 2, # number of full rollout storage creations
    "clip_epsilon": 0.2, # the surrogate clipping coefficient
    "ent_coef": 0.01, # entroy coefficient controlling the exploration factor C2
    "vf_coef": 0.5, # value function controlling value estimation importance C1
    # visualization and print parameters
    "num_returns_to_average": 3, # how many episodes to use for printing average return
    "num_episodes_to_average": 23, # how many episodes to use for smoothing of the return diagram
    }

# batch_size is the size of the flatten sequences when trajcts are flatten
configs['batch_size'] = int(configs['num_trajcts'] * configs['max_trajects_length'])
# number of samples used in each gradient
configs['minibatch_size'] = int(configs['batch_size'] // configs['num_minibatches'])

configs = DictConfig(configs)

run_name = f"{configs.gym_id}__{configs.exp_name}__{seed}__{int(time.time())}"

In [53]:
12

12

In [55]:
sample = torch.zeros((2,4))

In [60]:
sample[0, 1:4] = 1

In [61]:
sample

tensor([[0., 1., 1., 1.],
        [0., 0., 0., 0.]])

## Env

`envs` is a set of parallel environments each holding a random initiali `state` and accepts an `action` to change and return its new state.

In [5]:
# create an env with random state
def make_env_func(gym_id, seed, idx, run_name, capture_video=False):
    def env_fun():
        env = gym.make(gym_id, render_mode="rgb_array")
        env = gym.wrappers.RecordEpisodeStatistics(env)
        if capture_video:
            # initiate the video capture if not already initiated
            if idx == 0:
                # wrapper to create the video of the performance
                env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
        env.action_space.seed(seed)
        env.observation_space.seed(seed)
        return env

    return env_fun

In [6]:
# create N envs
envs = []
for i in range(configs.num_trajcts):
    envs.append( make_env_func(configs.gym_id, seed + i, i, run_name) )
envs = gym.vector.SyncVectorEnv(envs)
envs

SyncVectorEnv(num_envs=32)

## Model

A simple fully connected model that gets a state and has two methods:
- `agent.value_func(state)` gets a state and returns the estimated expected total future rewards from that state $V_{\theta}(s)$.
- `agent.policy(state)` gets a state and returns next `action`, `log_prob` of actions, the `entropy` and `value`.

In [7]:
class FCBlock(nn.Module):
    """A generic fully connected residual block with good setup"""
    def __init__(self, embd_dim, dropout=0.2):
        super().__init__()
        self.block = nn.Sequential(
            nn.LayerNorm(embd_dim),
            nn.GELU(),
            nn.Linear(embd_dim, 4*embd_dim),
            nn.GELU(),
            nn.Linear(4*embd_dim, embd_dim),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return x + self.block(x)


class Agent(nn.Module):
    """an agent that creates actions and estimates values"""
    def __init__(self, env_observation_dim, action_space_dim, embd_dim=64, num_blocks=2):
        super().__init__()
        self.embedding_layer = nn.Linear(env_observation_dim, embd_dim)
        self.shared_layers = nn.Sequential(*[FCBlock(embd_dim=embd_dim) for _ in range(num_blocks)])
        self.value_head = nn.Linear(embd_dim, 1)
        self.policy_head = nn.Linear(embd_dim, action_space_dim)
        # orthogonal initialization with a hi entropy for more exploration at the start
        torch.nn.init.orthogonal_(self.policy_head.weight, 0.01)

    def value_func(self, state):
        hidden = self.shared_layers(self.embedding_layer(state))
        value = self.value_head(hidden)
        return value

    def policy(self, state, action=None):
        hidden = self.shared_layers(self.embedding_layer(state))
        logits = self.policy_head(hidden)
        # PyTorch categorical class helpful for sampling and probability calculations
        probs = Categorical(logits=logits)
        if action is None:
            action = probs.sample()
        return action, probs.log_prob(action), probs.entropy(), self.value_head(hidden)


In [8]:
import torch
from torch.distributions.categorical import Categorical

# Define a categorical distribution with specified probabilities
probs = torch.tensor([0.1, 0.2, 0.3, 0.4])  # Probabilities for 4 outcomes
dist = Categorical(probs)

# Sample from the distribution
sample = dist.sample()
print("Sampled outcome:", sample.item())

# Calculate the log probability of an outcome
log_prob = dist.log_prob(torch.tensor(2))  # Log probability of outcome 2
print("Log probability of outcome 2:", log_prob.item())

# Calculate the entropy of the distribution
entropy = dist.entropy()
print("Entropy of the distribution:", entropy.item())


Sampled outcome: 3
Log probability of outcome 2: -1.2039728164672852
Entropy of the distribution: 1.2798542976379395


### Generalized Advantage Estimation

In [9]:
def gae(
    cur_observation,  # the current state when advantages will be calculated
    rewards,          # rewards collected from trajectories of shape [num_trajcts, max_trajects_length]
    dones,            # binary marker of end of trajectories of shape [num_trajcts, max_trajects_length]
    values            # value estimates collected over trajectories of shape [num_trajcts, max_trajects_length]
):
    """
    Generalized Advantage Estimation (gae) estimating advantage of a particular trajecotry
    vs the expected return starting from a state
    """
    advantages = torch.zeros((configs.num_trajcts, configs.max_trajects_length))
    last_advantage = 0

    # the value after the last step
    with torch.no_grad():
        last_value = agent.value_func(cur_observation).reshape(1, -1)
#         last_value = agent.value_func(cur_observation)


    # reverse recursive to calculate advantages based on the delta formula
    for t in reversed(range(configs.max_trajects_length)):
        # mask if episode completed after step t
        

        mask = 1.0 - dones[:, t]
        last_value = last_value * mask
        
        last_advantage = last_advantage * mask
        delta = rewards[:, t] + configs.gamma * last_value - values[:, t]
        last_advantage = delta + configs.gamma * configs.gae_lambda * last_advantage

        
        advantages[:, t] = last_advantage
        last_value = values[:, t]

    advantages = advantages.to(device)
    returns = advantages + values

    return advantages, returns

### Creating Rollout Storage

In [10]:
def create_rollout(
    envs,            # parallel envs creating trajectories
    cur_observation, # starting observation of shape [num_trajcts, observation_dim]
    cur_done,        # current termination status of shape [num_trajcts,]
    all_returns      # a list to track returns
):
    """
    rollout phase: create parallel trajectories and store them in the rollout storage
    """

    # cache empty tensors to store the rollouts
    observations = torch.zeros((configs.num_trajcts, configs.max_trajects_length) +
                               envs.single_observation_space.shape).to(device)
    actions = torch.zeros((configs.num_trajcts, configs.max_trajects_length) +
                          envs.single_action_space.shape).to(device)
    logprobs = torch.zeros((configs.num_trajcts, configs.max_trajects_length)).to(device)
    rewards = torch.zeros((configs.num_trajcts, configs.max_trajects_length)).to(device)
    dones = torch.zeros((configs.num_trajcts, configs.max_trajects_length)).to(device)
    values = torch.zeros((configs.num_trajcts, configs.max_trajects_length)).to(device)

    for t in range(configs.max_trajects_length):
        observations[:,t] = cur_observation
        dones[:,t] = cur_done

        # give observation to the model and collect action, logprobs of actions, entropy and value
        with torch.no_grad():
            action, logprob, entropy, value = agent.policy(cur_observation)
        values[:,t] = value.flatten()
        actions[:,t] = action
        logprobs[:,t] = logprob

        # apply the action to the env and collect observation and reward
        cur_observation, reward, cur_done, _, info = envs.step(action.cpu().numpy())
        rewards[:,t] = torch.tensor(reward).to(device).view(-1)
        cur_observation = torch.Tensor(cur_observation).to(device)
        cur_done = torch.Tensor(cur_done).to(device)

        # if an episode ended store its total reward for progress report
#         if info:
#             for item in info['final_info']:
#                 if item and "episode" in item.keys():
#                     all_returns.append(item['episode']['r'])
#                     break

    # create the rollout storage
    rollout = {
        'cur_observation': cur_observation,
        'cur_done': cur_done,
        'observations': observations,
        'actions': actions,
        'logprobs': logprobs,
        'values': values,
        'dones': dones,
        'rewards': rewards
    }

    return rollout

In [11]:
class Storage(Dataset):
    def __init__(self, rollout, advantages, returns, envs):
        # fill in the storage and flatten the parallel trajectories
        self.observations = rollout['observations'].reshape((-1,) + envs.single_observation_space.shape)
        self.logprobs = rollout['logprobs'].reshape(-1)
        self.actions = rollout['actions'].reshape((-1,) + envs.single_action_space.shape).long()
        self.advantages = advantages.reshape(-1)
        self.returns = returns.reshape(-1)

    def __getitem__(self, ix: int):
        item = [
            self.observations[ix],
            self.logprobs[ix],
            self.actions[ix],
            self.advantages[ix],
            self.returns[ix]
        ]
        return item

    def __len__(self) -> int:
        return len(self.observations)

### Loss Functions

In [12]:
def loss_clip(
    mb_oldlogporb,     # old logprob of mini batch actions collected during the rollout
    mb_newlogprob,     # new logprob of mini batch actions created by the new policy
    mb_advantages      # mini batch of advantages collected during the the rollout
):
    """
    policy loss with clipping to control gradients
    """
    ratio = torch.exp(mb_newlogprob - mb_oldlogporb)
    policy_loss = -mb_advantages * ratio
    # clipped policy gradient loss enforces closeness
    clipped_loss = -mb_advantages * torch.clamp(ratio, 1 - configs.clip_epsilon, 1 + configs.clip_epsilon)
    pessimistic_loss = torch.max(policy_loss, clipped_loss).mean()
    return pessimistic_loss


def loss_vf(
    mb_oldreturns,  # mini batch of old returns collected during the rollout
    mb_newvalues    # minibach of values calculated by the new value function
):
    """
    enforcing the value function to give more accurate estimates of returns
    """
    mb_newvalues = mb_newvalues.view(-1)
    loss = 0.5 * ((mb_newvalues - mb_oldreturns) ** 2).mean()
    return loss

## Training

In [13]:
agent = Agent(
    env_observation_dim=envs.single_observation_space.shape[0],
    action_space_dim=envs.single_action_space.n
).to(device)

optimizer = optim.Adam(agent.parameters(), lr=configs.learning_rate)

In [98]:
all_returns = []

# initialize the game
cur_observation = torch.Tensor(envs.reset()[0]).to(device)
cur_done = torch.zeros(configs.num_trajcts).to(device)

In [99]:
configs.minibatch_size

1024

In [100]:
# track returns
all_returns = []

# initialize the game
cur_observation = torch.Tensor(envs.reset()[0]).to(device)
cur_done = torch.zeros(configs.num_trajcts).to(device)

# progress bar
num_updates = configs.total_timesteps // configs.batch_size
progress_bar = tqdm(total=num_updates)

for update in range(1, num_updates + 1):


    ##############################################
    # Phase 1: rollout creation

    # parallel envs creating trajectories
    rollout = create_rollout(envs, cur_observation, cur_done, all_returns)

    cur_done = rollout['cur_done']
    cur_observation = rollout['cur_observation']
    rewards = rollout['rewards']
    dones = rollout['dones']
    values = rollout['values']

    # calculating advantages
    advantages, returns = gae(cur_observation, rewards, dones, values)

    # a dataset containing the rollouts
    dataset = Storage(rollout, advantages, returns, envs)

    print(len(dataset))
    # a standard dataloader made out of current storage
    trainloader = DataLoader(dataset, batch_size=configs.minibatch_size, shuffle=True)


    ##############################################
    # Phase 2: model update

    # linearly shrink the lr from the initial lr to zero
    frac = 1.0 - (update - 1.0) / num_updates
    optimizer.param_groups[0]["lr"] = frac * configs.learning_rate

    # training loop
    for epoch in range(configs.update_epochs):
        for batch in trainloader:
            mb_observations, mb_logprobs, mb_actions, mb_advantages, mb_returns = batch

            # we calculate the distribution of actions through the updated model revisiting the old trajectories
            _, mb_newlogprob, mb_entropy, mb_newvalues = agent.policy(mb_observations, mb_actions)

            policy_loss = loss_clip(mb_logprobs, mb_newlogprob, mb_advantages)

            value_loss = loss_vf(mb_returns, mb_newvalues)

            # average entory of the action space
            entropy_loss = mb_entropy.mean()

            # full weighted loss
            loss = policy_loss - configs.ent_coef * entropy_loss + configs.vf_coef * value_loss

            print(loss)
            optimizer.zero_grad()
            loss.backward()

            # extra clipping of the gradients to avoid overshoots
            nn.utils.clip_grad_norm_(agent.parameters(), configs.max_grad_norm)
            optimizer.step()

    # progress bar
    if len(all_returns) > configs.num_returns_to_average:
        progress_bar.set_description(f"episode return: {np.mean(all_returns[-configs.num_returns_to_average:]):.2f}")
        progress_bar.refresh()
        progress_bar.update()
envs.close()

  0%|          | 0/488 [00:00<?, ?it/s]

2048
tensor(12.5200, grad_fn=<AddBackward0>)
tensor(11.2932, grad_fn=<AddBackward0>)
tensor(9.5838, grad_fn=<AddBackward0>)
tensor(8.5303, grad_fn=<AddBackward0>)
2048
tensor(10.4960, grad_fn=<AddBackward0>)
tensor(8.8961, grad_fn=<AddBackward0>)
tensor(7.8095, grad_fn=<AddBackward0>)
tensor(5.8538, grad_fn=<AddBackward0>)
2048
tensor(7.6547, grad_fn=<AddBackward0>)
tensor(6.2848, grad_fn=<AddBackward0>)
tensor(4.8548, grad_fn=<AddBackward0>)
tensor(3.2112, grad_fn=<AddBackward0>)
2048
tensor(8.4795, grad_fn=<AddBackward0>)
tensor(6.2684, grad_fn=<AddBackward0>)
tensor(4.6960, grad_fn=<AddBackward0>)
tensor(3.9718, grad_fn=<AddBackward0>)
2048
tensor(10.1310, grad_fn=<AddBackward0>)
tensor(8.5069, grad_fn=<AddBackward0>)
tensor(7.3593, grad_fn=<AddBackward0>)
tensor(6.1121, grad_fn=<AddBackward0>)
2048
tensor(12.7728, grad_fn=<AddBackward0>)
tensor(12.6545, grad_fn=<AddBackward0>)
tensor(13.1728, grad_fn=<AddBackward0>)
tensor(13.1289, grad_fn=<AddBackward0>)
2048
tensor(16.9630, grad_

tensor(21.3628, grad_fn=<AddBackward0>)
tensor(22.6316, grad_fn=<AddBackward0>)
tensor(19.8418, grad_fn=<AddBackward0>)
2048
tensor(22.0911, grad_fn=<AddBackward0>)
tensor(24.0254, grad_fn=<AddBackward0>)
tensor(22.8915, grad_fn=<AddBackward0>)
tensor(23.1287, grad_fn=<AddBackward0>)
2048
tensor(58.3377, grad_fn=<AddBackward0>)
tensor(71.5483, grad_fn=<AddBackward0>)
tensor(66.2170, grad_fn=<AddBackward0>)
tensor(60.9500, grad_fn=<AddBackward0>)
2048
tensor(23.4455, grad_fn=<AddBackward0>)
tensor(24.4350, grad_fn=<AddBackward0>)
tensor(21.1959, grad_fn=<AddBackward0>)
tensor(25.1333, grad_fn=<AddBackward0>)
2048
tensor(26.7439, grad_fn=<AddBackward0>)
tensor(26.3652, grad_fn=<AddBackward0>)
tensor(25.0380, grad_fn=<AddBackward0>)
tensor(27.3200, grad_fn=<AddBackward0>)
2048
tensor(20.2687, grad_fn=<AddBackward0>)
tensor(19.8037, grad_fn=<AddBackward0>)
tensor(18.2750, grad_fn=<AddBackward0>)
tensor(22.1706, grad_fn=<AddBackward0>)
2048
tensor(16.5483, grad_fn=<AddBackward0>)
tensor(16.

tensor(20.2355, grad_fn=<AddBackward0>)
tensor(24.4227, grad_fn=<AddBackward0>)
tensor(20.6883, grad_fn=<AddBackward0>)
2048
tensor(22.6260, grad_fn=<AddBackward0>)
tensor(19.9253, grad_fn=<AddBackward0>)
tensor(22.2208, grad_fn=<AddBackward0>)
tensor(19.7763, grad_fn=<AddBackward0>)
2048
tensor(8.2520, grad_fn=<AddBackward0>)
tensor(7.2097, grad_fn=<AddBackward0>)
tensor(7.1564, grad_fn=<AddBackward0>)
tensor(9.2074, grad_fn=<AddBackward0>)
2048
tensor(21.7018, grad_fn=<AddBackward0>)
tensor(22.4772, grad_fn=<AddBackward0>)
tensor(19.5793, grad_fn=<AddBackward0>)
tensor(25.5296, grad_fn=<AddBackward0>)
2048
tensor(21.5828, grad_fn=<AddBackward0>)
tensor(21.0041, grad_fn=<AddBackward0>)
tensor(19.0576, grad_fn=<AddBackward0>)
tensor(23.6924, grad_fn=<AddBackward0>)
2048
tensor(9.5945, grad_fn=<AddBackward0>)
tensor(9.6387, grad_fn=<AddBackward0>)
tensor(9.3784, grad_fn=<AddBackward0>)
tensor(8.7270, grad_fn=<AddBackward0>)
2048
tensor(24.3616, grad_fn=<AddBackward0>)
tensor(21.1731, gr

tensor(14.0100, grad_fn=<AddBackward0>)
tensor(15.7087, grad_fn=<AddBackward0>)
tensor(15.9030, grad_fn=<AddBackward0>)
tensor(14.0642, grad_fn=<AddBackward0>)
2048
tensor(14.7521, grad_fn=<AddBackward0>)
tensor(13.8309, grad_fn=<AddBackward0>)
tensor(15.8963, grad_fn=<AddBackward0>)
tensor(10.9333, grad_fn=<AddBackward0>)
2048
tensor(13.7706, grad_fn=<AddBackward0>)
tensor(10.6997, grad_fn=<AddBackward0>)
tensor(12.5791, grad_fn=<AddBackward0>)
tensor(12.1680, grad_fn=<AddBackward0>)
2048
tensor(13.8878, grad_fn=<AddBackward0>)
tensor(26.6002, grad_fn=<AddBackward0>)
tensor(17.0177, grad_fn=<AddBackward0>)
tensor(20.6987, grad_fn=<AddBackward0>)
2048
tensor(25.7322, grad_fn=<AddBackward0>)
tensor(25.9837, grad_fn=<AddBackward0>)
tensor(27.2851, grad_fn=<AddBackward0>)
tensor(23.1378, grad_fn=<AddBackward0>)
2048
tensor(13.1275, grad_fn=<AddBackward0>)
tensor(15.5818, grad_fn=<AddBackward0>)
tensor(15.1690, grad_fn=<AddBackward0>)
tensor(13.5136, grad_fn=<AddBackward0>)
2048
tensor(4.5

tensor(1.2496, grad_fn=<AddBackward0>)
tensor(1.2400, grad_fn=<AddBackward0>)
tensor(1.2760, grad_fn=<AddBackward0>)
tensor(1.0972, grad_fn=<AddBackward0>)
2048
tensor(15.2377, grad_fn=<AddBackward0>)
tensor(15.3091, grad_fn=<AddBackward0>)
tensor(16.0453, grad_fn=<AddBackward0>)
tensor(14.0682, grad_fn=<AddBackward0>)
2048
tensor(0.8019, grad_fn=<AddBackward0>)
tensor(1.0387, grad_fn=<AddBackward0>)
tensor(0.7754, grad_fn=<AddBackward0>)
tensor(0.9002, grad_fn=<AddBackward0>)
2048
tensor(0.6668, grad_fn=<AddBackward0>)
tensor(0.5950, grad_fn=<AddBackward0>)
tensor(0.6236, grad_fn=<AddBackward0>)
tensor(0.6637, grad_fn=<AddBackward0>)
2048
tensor(0.7542, grad_fn=<AddBackward0>)
tensor(0.7417, grad_fn=<AddBackward0>)
tensor(0.6082, grad_fn=<AddBackward0>)
tensor(0.8413, grad_fn=<AddBackward0>)
2048
tensor(1.0393, grad_fn=<AddBackward0>)
tensor(1.2188, grad_fn=<AddBackward0>)
tensor(0.8979, grad_fn=<AddBackward0>)
tensor(0.8894, grad_fn=<AddBackward0>)
2048
tensor(0.7934, grad_fn=<AddBac

tensor(-0.0195, grad_fn=<AddBackward0>)
tensor(-0.0386, grad_fn=<AddBackward0>)
tensor(-0.0094, grad_fn=<AddBackward0>)
2048
tensor(0.2359, grad_fn=<AddBackward0>)
tensor(0.3033, grad_fn=<AddBackward0>)
tensor(0.2817, grad_fn=<AddBackward0>)
tensor(0.2603, grad_fn=<AddBackward0>)
2048
tensor(0.0829, grad_fn=<AddBackward0>)
tensor(0.1515, grad_fn=<AddBackward0>)
tensor(0.2263, grad_fn=<AddBackward0>)
tensor(0.0727, grad_fn=<AddBackward0>)
2048
tensor(7.8275, grad_fn=<AddBackward0>)
tensor(12.3001, grad_fn=<AddBackward0>)
tensor(10.9493, grad_fn=<AddBackward0>)
tensor(9.0408, grad_fn=<AddBackward0>)
2048
tensor(0.3022, grad_fn=<AddBackward0>)
tensor(0.3086, grad_fn=<AddBackward0>)
tensor(0.3618, grad_fn=<AddBackward0>)
tensor(0.2499, grad_fn=<AddBackward0>)
2048
tensor(0.2615, grad_fn=<AddBackward0>)
tensor(0.3377, grad_fn=<AddBackward0>)
tensor(0.2497, grad_fn=<AddBackward0>)
tensor(0.3298, grad_fn=<AddBackward0>)
2048
tensor(0.1834, grad_fn=<AddBackward0>)
tensor(0.2196, grad_fn=<AddBa

2048
tensor(10.2622, grad_fn=<AddBackward0>)
tensor(10.7355, grad_fn=<AddBackward0>)
tensor(10.8788, grad_fn=<AddBackward0>)
tensor(9.8857, grad_fn=<AddBackward0>)
2048
tensor(0.0706, grad_fn=<AddBackward0>)
tensor(0.0869, grad_fn=<AddBackward0>)
tensor(0.0718, grad_fn=<AddBackward0>)
tensor(0.0376, grad_fn=<AddBackward0>)
2048
tensor(0.0730, grad_fn=<AddBackward0>)
tensor(0.0502, grad_fn=<AddBackward0>)
tensor(0.0647, grad_fn=<AddBackward0>)
tensor(0.0620, grad_fn=<AddBackward0>)
2048
tensor(2.0295, grad_fn=<AddBackward0>)
tensor(4.0413, grad_fn=<AddBackward0>)
tensor(1.7708, grad_fn=<AddBackward0>)
tensor(4.4228, grad_fn=<AddBackward0>)
2048
tensor(0.0758, grad_fn=<AddBackward0>)
tensor(0.0166, grad_fn=<AddBackward0>)
tensor(0.0296, grad_fn=<AddBackward0>)
tensor(0.1117, grad_fn=<AddBackward0>)
2048
tensor(0.1735, grad_fn=<AddBackward0>)
tensor(0.1007, grad_fn=<AddBackward0>)
tensor(0.0972, grad_fn=<AddBackward0>)
tensor(0.1117, grad_fn=<AddBackward0>)
2048
tensor(0.0912, grad_fn=<Ad

tensor(11.9302, grad_fn=<AddBackward0>)
tensor(7.7533, grad_fn=<AddBackward0>)
tensor(11.3252, grad_fn=<AddBackward0>)
2048
tensor(32.6305, grad_fn=<AddBackward0>)
tensor(37.4447, grad_fn=<AddBackward0>)
tensor(32.6611, grad_fn=<AddBackward0>)
tensor(35.5773, grad_fn=<AddBackward0>)
2048
tensor(33.2096, grad_fn=<AddBackward0>)
tensor(32.7352, grad_fn=<AddBackward0>)
tensor(27.5702, grad_fn=<AddBackward0>)
tensor(36.5906, grad_fn=<AddBackward0>)
2048
tensor(27.3151, grad_fn=<AddBackward0>)
tensor(37.0020, grad_fn=<AddBackward0>)
tensor(34.4508, grad_fn=<AddBackward0>)
tensor(27.0527, grad_fn=<AddBackward0>)
2048
tensor(23.8567, grad_fn=<AddBackward0>)
tensor(22.7242, grad_fn=<AddBackward0>)
tensor(23.5324, grad_fn=<AddBackward0>)
tensor(21.8429, grad_fn=<AddBackward0>)
2048
tensor(36.1844, grad_fn=<AddBackward0>)
tensor(29.9729, grad_fn=<AddBackward0>)
tensor(31.4542, grad_fn=<AddBackward0>)
tensor(31.7128, grad_fn=<AddBackward0>)
2048
tensor(16.8462, grad_fn=<AddBackward0>)
tensor(19.8

tensor(0.0839, grad_fn=<AddBackward0>)
tensor(0.0498, grad_fn=<AddBackward0>)
tensor(0.1125, grad_fn=<AddBackward0>)
2048
tensor(6.7691, grad_fn=<AddBackward0>)
tensor(8.5676, grad_fn=<AddBackward0>)
tensor(9.6110, grad_fn=<AddBackward0>)
tensor(5.7045, grad_fn=<AddBackward0>)
2048
tensor(0.1525, grad_fn=<AddBackward0>)
tensor(0.0555, grad_fn=<AddBackward0>)
tensor(0.1061, grad_fn=<AddBackward0>)
tensor(0.0801, grad_fn=<AddBackward0>)
2048
tensor(0.0848, grad_fn=<AddBackward0>)
tensor(0.1274, grad_fn=<AddBackward0>)
tensor(0.0939, grad_fn=<AddBackward0>)
tensor(0.0888, grad_fn=<AddBackward0>)
2048
tensor(0.1853, grad_fn=<AddBackward0>)
tensor(0.1890, grad_fn=<AddBackward0>)
tensor(0.1968, grad_fn=<AddBackward0>)
tensor(0.1624, grad_fn=<AddBackward0>)
2048
tensor(0.0500, grad_fn=<AddBackward0>)
tensor(0.0106, grad_fn=<AddBackward0>)
tensor(0.0603, grad_fn=<AddBackward0>)
tensor(-0.0049, grad_fn=<AddBackward0>)
2048
tensor(0.0537, grad_fn=<AddBackward0>)
tensor(0.0650, grad_fn=<AddBackwa

tensor(0.1110, grad_fn=<AddBackward0>)
tensor(0.1726, grad_fn=<AddBackward0>)
tensor(0.0779, grad_fn=<AddBackward0>)
2048
tensor(0.2092, grad_fn=<AddBackward0>)
tensor(0.1715, grad_fn=<AddBackward0>)
tensor(0.1985, grad_fn=<AddBackward0>)
tensor(0.1749, grad_fn=<AddBackward0>)
2048
tensor(0.1410, grad_fn=<AddBackward0>)
tensor(0.2559, grad_fn=<AddBackward0>)
tensor(0.1898, grad_fn=<AddBackward0>)
tensor(0.2119, grad_fn=<AddBackward0>)
2048
tensor(0.0666, grad_fn=<AddBackward0>)
tensor(0.1149, grad_fn=<AddBackward0>)
tensor(0.1048, grad_fn=<AddBackward0>)
tensor(0.0857, grad_fn=<AddBackward0>)
2048
tensor(0.1386, grad_fn=<AddBackward0>)
tensor(0.1368, grad_fn=<AddBackward0>)
tensor(0.1661, grad_fn=<AddBackward0>)
tensor(0.1296, grad_fn=<AddBackward0>)
2048
tensor(0.1235, grad_fn=<AddBackward0>)
tensor(0.1350, grad_fn=<AddBackward0>)
tensor(0.1128, grad_fn=<AddBackward0>)
tensor(0.1517, grad_fn=<AddBackward0>)
2048
tensor(0.0967, grad_fn=<AddBackward0>)
tensor(0.0930, grad_fn=<AddBackwar

# Analysis

In [26]:
if not len(all_returns)%configs.num_episodes_to_average==0:
    all_returns_truncated = np.array(all_returns[:-(len(all_returns)%configs.num_episodes_to_average)])
else:
    all_returns_truncated = all_returns
all_returns_smoothed = np.average(all_returns_truncated.reshape(-1, configs.num_episodes_to_average), axis=1)
print('mean reward:', np.mean(all_returns_smoothed))
print('std reward:', np.std(all_returns_smoothed))
print('max reward:', np.max(all_returns_smoothed))
print('converge mean reward:', np.mean(all_returns_smoothed[-1]))
plt.plot(all_returns_smoothed);


AttributeError: 'list' object has no attribute 'reshape'

## Inference

In [101]:
# create a test env
test_env = make_env_func(configs.gym_id, seed, 0, 'inference', True)()

# use the trained agent to run through the env till it terminates this is an eposide
observation, _ = test_env.reset()
observation = torch.unsqueeze(torch.tensor(observation),dim=0).to(device)
for _ in range(500):
    action, _, _, _ = agent.policy(observation)
    action = action.cpu().item()
    observation, reward, done, _, info = test_env.step(action)
    observation = torch.unsqueeze(torch.tensor(observation),dim=0).to(device)
    if done:
        break
test_env.close()

Video('/content/videos/inference/rl-video-episode-0.mp4', embed=True)

  logger.warn(


In [30]:
Video('/content/videos/inference/rl-video-episode-0.mp4', embed=True)

### BreakOut

You can use similar PPO setup to solve much more complex problems. The problem `"gym_id": "BreakoutNoFrameskip-v4"` is left as an exercise for you. The `env` and `agent` definitions are provided here but the rest are left to you. With the correct setup (very similar to what we did in CartPole), you should be able to get to respectable scores above 700. Even perfect score is possible with a tuned hyper-parameters with the same setup. Note that this is a much more complex problem and you will need to increase the `total_timesteps` by at least a factor of 30 to get to good results. This will take several hours.

```python
# extra imports
from stable_baselines3.common.atari_wrappers import (
    ClipRewardEnv,
    EpisodicLifeEnv,
    FireResetEnv,
    MaxAndSkipEnv,
    NoopResetEnv,
)


# env definition
def make_env_func(gym_id, seed, idx, run_name, capture_video=False):
    def env_fun():
        env = gym.make(gym_id, render_mode='rgb_array')
        env = gym.wrappers.RecordEpisodeStatistics(env)
        if capture_video:
            # initiate the video capture if not already initiated
            if idx == 0:
                # wrapper to create the video of the performance
                env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
        env = NoopResetEnv(env, noop_max=30)
        env = MaxAndSkipEnv(env, skip=4)
        env = EpisodicLifeEnv(env)
        if "FIRE" in env.unwrapped.get_action_meanings():
            env = FireResetEnv(env)
        env = ClipRewardEnv(env)
        env = gym.wrappers.ResizeObservation(env, (84, 84))
        env = gym.wrappers.GrayScaleObservation(env)
        env = gym.wrappers.FrameStack(env, 4)
        env.seed(seed)
        env.action_space.seed(seed)
        env.observation_space.seed(seed)
        return env
    return env_fun


# Model defnition
def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    torch.nn.init.orthogonal_(layer.weight, std)
    torch.nn.init.constant_(layer.bias, bias_const)
    return layer

class Agent(nn.Module):
    def __init__(self, envs):
        super(Agent, self).__init__()
        self.network = nn.Sequential(
            layer_init(nn.Conv2d(4, 32, 8, stride=4)),
            nn.ReLU(),
            layer_init(nn.Conv2d(32, 64, 4, stride=2)),
            nn.ReLU(),
            layer_init(nn.Conv2d(64, 64, 3, stride=1)),
            nn.AdaptiveMaxPool2d(output_size=(1, 1)),
            nn.Flatten(),
            layer_init(nn.Linear(64, 512)),
            nn.ReLU(),
        )
        self.value_head = layer_init(nn.Linear(512, 1), std=1)
        self.policy_head = layer_init(nn.Linear(512, envs.single_action_space.n), std=0.01)

    def value_func(self, x):
        return self.value_head(self.network(x / 255.0))

    def policy(self, x, action=None):
        hidden = self.network(x / 255.0)
        logits = self.policy_head(hidden)
        probs = Categorical(logits=logits)
        if action is None:
            action = probs.sample()
        return action, probs.log_prob(action), probs.entropy(), self.value_head(hidden)
```

## Complete

In [9]:
# create an env with random state
def make_env_func(gym_id, seed, idx, run_name, capture_video=False):
    def env_fun():
        env = gym.make(gym_id, render_mode="rgb_array")
        env = gym.wrappers.RecordEpisodeStatistics(env)
        if capture_video:
            # initiate the video capture if not already initiated
            if idx == 0:
                # wrapper to create the video of the performance
                env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
        env.action_space.seed(seed)
        env.observation_space.seed(seed)
        return env

    return env_fun

In [20]:


# create N envs
envs = []
for i in range(configs.num_trajcts):
    envs.append( make_env_func(configs.gym_id, seed + i, i, run_name) )
envs = gym.vector.SyncVectorEnv(envs)

# start the environment
cur_observation = envs.reset()[0]

class FCBlock(nn.Module):
    """A generic fully connected residual block with good setup"""
    def __init__(self, embd_dim, dropout=0.2):
        super().__init__()
        self.block = nn.Sequential(
            nn.LayerNorm(embd_dim),
            nn.GELU(),
            nn.Linear(embd_dim, 4*embd_dim),
            nn.GELU(),
            nn.Linear(4*embd_dim, embd_dim),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return x + self.block(x)



class Agent(nn.Module):
    """an agent that creates actions and estimates values"""
    def __init__(self, env_observation_dim, action_space_dim, embd_dim=64, num_blocks=2):
        super().__init__()
        self.embedding_layer = nn.Linear(env_observation_dim, embd_dim)
        self.shared_layers = nn.Sequential(*[FCBlock(embd_dim=embd_dim) for _ in range(num_blocks)])
        self.value_head = nn.Linear(embd_dim, 1)
        self.policy_head = nn.Linear(embd_dim, action_space_dim)
        # orthogonal initialization with a hi entropy for more exploration at the start
        torch.nn.init.orthogonal_(self.policy_head.weight, 0.01)

    def value_func(self, state):
        hidden = self.shared_layers(self.embedding_layer(state))
        value = self.value_head(hidden)
        return value

    def policy(self, state, action=None):
        hidden = self.shared_layers(self.embedding_layer(state))
        logits = self.policy_head(hidden)
        # PyTorch categorical class helpful for sampling and probability calculations
        probs = Categorical(logits=logits)
        if action is None:
            action = probs.sample()
        return action, probs.log_prob(action), probs.entropy(), self.value_head(hidden)

#     "num_trajcts": 32, # N
#     "max_trajects_length": 64, # T

def create_rollout(envs, cur_observation, cur_done, agent):

    observations = torch.zeros((cur_observation.shape[0], configs['max_trajects_length'],envs.single_observation_space.shape[0] ))
    actions = torch.zeros((cur_observation.shape[0], configs['max_trajects_length']) + envs.single_action_space.shape)
    dones =  torch.zeros((cur_observation.shape[0], configs['max_trajects_length']))
    rewards =  torch.zeros((cur_observation.shape[0], configs['max_trajects_length']))
    values = torch.zeros((cur_observation.shape[0], configs['max_trajects_length']))
    advantages = torch.zeros((cur_observation.shape[0], configs['max_trajects_length']))
    logprobs = torch.zeros((cur_observation.shape[0], configs['max_trajects_length']))
     
    
    for t in range(configs['max_trajects_length']):
        # get the policy
        with torch.no_grad():
            action, logprob,_,value = agent.policy(cur_observation)
            
        observations[:,t,:] = cur_observation
        actions[:,t] = action
        dones[:,t] = cur_done
        logprobs[:,t] = logprob
        
        cur_observation, cur_reward, cur_done,_,_ = envs.step(action.cpu().numpy())
        cur_observation = torch.tensor(cur_observation)
        cur_reward = torch.tensor(cur_reward)
        cur_done = torch.tensor(cur_done)
        
        rewards[:,t] = torch.tensor(cur_reward)

        values[:,t] = value.squeeze()
       
    # Advantage is approximated reverse recursively
#     advantage = gae(observations,dones, rewards,values, advantages)
    
    return {
        "cur_observation": cur_observation,
        "observations": observations,
        "actions" : actions,
        "dones" : dones,
        "rewards" : rewards,
        "values" : values,
        "advantages" : advantages,
        "logprobs" : logprobs
    }

agent = Agent(
    env_observation_dim=envs.single_observation_space.shape[0],
    action_space_dim=envs.single_action_space.n
).to(device)

optimizer = optim.Adam(agent.parameters(), lr=configs.learning_rate)


def gae(rewards, values, dones, cur_observation, agent):
    last_advantage = 0
    with torch.no_grad():
        last_value = agent.value_func(cur_observation).reshape(1,-1)
    advantages = torch.zeros_like(values)
    
    for t in reversed(range(configs['max_trajects_length'])):
        mask = 1.0 - dones[:,t]
        last_advantage = mask*last_advantage
        last_value = mask*last_value
        delta = rewards[:,t] + configs['gamma']*last_value - values[:,t]
        advantages[:,t] = delta + configs['gae_lambda']*configs['gamma']*last_advantage
        last_value = values[:,t]
        last_advantage = advantages[:,t]
    
    returns = advantages + values
        
    return advantages, returns


def loss_clip(
    mb_oldlogporb,     # old logprob of mini batch actions collected during the rollout
    mb_newlogprob,     # new logprob of mini batch actions created by the new policy
    mb_advantages      # mini batch of advantages collected during the the rollout
):
    """
    policy loss with clipping to control gradients
    """
    ratio = torch.exp(mb_newlogprob - mb_oldlogporb)
    policy_loss = -mb_advantages * ratio
    # clipped policy gradient loss enforces closeness
    clipped_loss = -mb_advantages * torch.clamp(ratio, 1 - configs.clip_epsilon, 1 + configs.clip_epsilon)
    pessimistic_loss = torch.max(policy_loss, clipped_loss).mean()
    return pessimistic_loss


def loss_vf(
    mb_oldreturns,  # mini batch of old returns collected during the rollout
    mb_newvalues    # minibach of values calculated by the new value function
):
    """
    enforcing the value function to give more accurate estimates of returns
    """
    mb_newvalues = mb_newvalues.view(-1)
    loss = 0.5 * ((mb_newvalues - mb_oldreturns) ** 2).mean()
    return loss


class Storage(Dataset):
    def __init__(self, rollout, advantages, returns, envs):
        # fill in the storage and flatten the parallel trajectories
        self.observations = rollout['observations'].reshape((-1,) + envs.single_observation_space.shape)
        self.logprobs = rollout['logprobs'].reshape(-1)
        self.actions = rollout['actions'].reshape((-1,) + envs.single_action_space.shape).long()
        self.advantages = advantages.reshape(-1)
        self.returns = returns.reshape(-1)

    def __getitem__(self, ix: int):
        item = [
            self.observations[ix],
            self.logprobs[ix],
            self.actions[ix],
            self.advantages[ix],
            self.returns[ix]
        ]
        return item

    def __len__(self) -> int:
        return len(self.observations)
    
import torch
from torch.distributions.categorical import Categorical
# Create the environment
envs = []
for i in range(configs.num_trajcts):
    envs.append( make_env_func(configs.gym_id, seed + i, i, run_name) )
envs = gym.vector.SyncVectorEnv(envs)
cur_observation = torch.tensor(envs.reset()[0])
cur_done = torch.tensor(torch.zeros(configs.num_trajcts))


for i in range(int((configs.total_timesteps/configs.batch_size))):

    frac = 1.0 - (i - 1.0) / (configs.total_timesteps/configs.batch_size)
    optimizer.param_groups[0]["lr"] = frac * configs.learning_rate
    # Phase 1: Create rollout
    rollouts = create_rollout(envs,cur_observation,cur_done, agent)
    advantages,returns = gae(rollouts['rewards'], rollouts['values'], rollouts['dones'], rollouts['cur_observation'], agent)
#     rollout, advantages, returns, envs
    dataset = Storage(rollouts, advantages, returns, envs)
    dataloader = DataLoader(dataset, batch_size=configs.minibatch_size, shuffle=True)
    
    # Phase 2: Update
    for j in range(configs.update_epochs):
        for data in dataloader: # mini_batch
            mb_ob,mb_logprobs, mb_actions, mb_advantages, mb_returns = data

            new_actions, new_logprobs, new_entropy, new_values = agent.policy(mb_ob, mb_actions)
            
#             print(f'old logprobs: {mb_logprobs[0]}, new_logprobs: {new_logprobs[0]}, old advantage: {mb_advantages[0]}, old return: {mb_returns[0]}, new values: {new_values[0]} ')
            
#             print(new_logprobs[0] / mb_logprobs[0]*mb_advantages[0])
            c_loss = loss_clip(mb_logprobs, new_logprobs, mb_advantages)
            vf_loss = loss_vf(mb_returns, new_values)
            entropy = new_entropy.mean()
            
            # maximize the PPO clip loss, minimize the value loss and maximize the entropy(exploration)
            # since pytorch's optimizer are configured to do gradient descent under the hood, i.e W - lr* deltaW
            # we need to multiply the loss the with negative(-) if we need to do gradient ascent
            # we already multiplied it in the clip loss function and we multiply it in entropy.
            
            loss = c_loss + configs.vf_coef*vf_loss - configs.ent_coef*entropy
            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(agent.parameters(), configs.max_grad_norm)
            optimizer.step()
            
            print('----------------------------------------------------------')
            print(f'TOTAL LOSS, {loss}, CLIP loss: {c_loss}, value function loss: {vf_loss}, entropy loss: {entropy}')


  cur_done = torch.tensor(torch.zeros(configs.num_trajcts))
  rewards[:,t] = torch.tensor(cur_reward)


----------------------------------------------------------
TOTAL LOSS, 10.348729133605957, CLIP loss: -7.379720687866211, value function loss: 35.470760345458984, entropy loss: 0.6930763721466064
----------------------------------------------------------
TOTAL LOSS, 8.243270874023438, CLIP loss: -7.169240951538086, value function loss: 30.838886260986328, entropy loss: 0.6930961012840271
----------------------------------------------------------
TOTAL LOSS, 7.106352806091309, CLIP loss: -7.216591835021973, value function loss: 28.659748077392578, entropy loss: 0.6929196715354919
----------------------------------------------------------
TOTAL LOSS, 6.165403842926025, CLIP loss: -7.365848541259766, value function loss: 27.076356887817383, entropy loss: 0.6925910711288452
----------------------------------------------------------
TOTAL LOSS, 10.36641788482666, CLIP loss: -7.085920810699463, value function loss: 34.91851806640625, entropy loss: 0.6920415163040161
-------------------------

----------------------------------------------------------
TOTAL LOSS, 15.944039344787598, CLIP loss: -0.7599526047706604, value function loss: 33.42106628417969, entropy loss: 0.6541071534156799
----------------------------------------------------------
TOTAL LOSS, 15.573347091674805, CLIP loss: -1.3486592769622803, value function loss: 33.85710144042969, entropy loss: 0.6543787717819214
----------------------------------------------------------
TOTAL LOSS, 15.704325675964355, CLIP loss: -1.150803804397583, value function loss: 33.72331237792969, entropy loss: 0.6526653170585632
----------------------------------------------------------
TOTAL LOSS, 14.707890510559082, CLIP loss: -0.9192712306976318, value function loss: 31.267274856567383, entropy loss: 0.6475383639335632
----------------------------------------------------------
TOTAL LOSS, 17.696914672851562, CLIP loss: -1.1446342468261719, value function loss: 37.696044921875, entropy loss: 0.6473550796508789
----------------------

----------------------------------------------------------
TOTAL LOSS, 17.09221839904785, CLIP loss: -2.036982297897339, value function loss: 38.2701301574707, entropy loss: 0.586458683013916
----------------------------------------------------------
TOTAL LOSS, 20.072446823120117, CLIP loss: -1.3047281503677368, value function loss: 42.76608657836914, entropy loss: 0.5868306159973145
----------------------------------------------------------
TOTAL LOSS, 17.104907989501953, CLIP loss: -1.9703298807144165, value function loss: 38.1622314453125, entropy loss: 0.5877543091773987
----------------------------------------------------------
TOTAL LOSS, 14.065033912658691, CLIP loss: -2.565640449523926, value function loss: 33.27303695678711, entropy loss: 0.5844131708145142
----------------------------------------------------------
TOTAL LOSS, 14.577067375183105, CLIP loss: -2.1921632289886475, value function loss: 33.55002975463867, entropy loss: 0.5784040689468384
--------------------------

----------------------------------------------------------
TOTAL LOSS, 14.473060607910156, CLIP loss: -1.9907119274139404, value function loss: 32.93883514404297, entropy loss: 0.5644832849502563
----------------------------------------------------------
TOTAL LOSS, 13.008715629577637, CLIP loss: -2.196770668029785, value function loss: 30.422138214111328, entropy loss: 0.5583270788192749
----------------------------------------------------------
TOTAL LOSS, 13.38146686553955, CLIP loss: -2.1364383697509766, value function loss: 31.0470027923584, entropy loss: 0.5595700740814209
----------------------------------------------------------
TOTAL LOSS, 15.919726371765137, CLIP loss: -1.0064388513565063, value function loss: 33.86322021484375, entropy loss: 0.5444563031196594
----------------------------------------------------------
TOTAL LOSS, 17.620643615722656, CLIP loss: -0.7727962136268616, value function loss: 36.79756164550781, entropy loss: 0.5340909957885742
----------------------

----------------------------------------------------------
TOTAL LOSS, 20.304222106933594, CLIP loss: -2.2673513889312744, value function loss: 45.1539421081543, entropy loss: 0.5397313237190247
----------------------------------------------------------
TOTAL LOSS, 18.790607452392578, CLIP loss: -1.234362244606018, value function loss: 40.06085205078125, entropy loss: 0.5456689596176147
----------------------------------------------------------
TOTAL LOSS, 19.702848434448242, CLIP loss: -1.6493875980377197, value function loss: 42.71504592895508, entropy loss: 0.5286309719085693
----------------------------------------------------------
TOTAL LOSS, 19.168212890625, CLIP loss: -1.7113147974014282, value function loss: 41.769813537597656, entropy loss: 0.537956178188324
----------------------------------------------------------
TOTAL LOSS, 26.68878746032715, CLIP loss: 1.075148582458496, value function loss: 51.23793411254883, entropy loss: 0.5326797962188721
----------------------------

----------------------------------------------------------
TOTAL LOSS, 13.547203063964844, CLIP loss: -1.7149505615234375, value function loss: 30.53507423400879, entropy loss: 0.5383638739585876
----------------------------------------------------------
TOTAL LOSS, 10.808194160461426, CLIP loss: -1.379618525505066, value function loss: 24.386444091796875, entropy loss: 0.5408948063850403
----------------------------------------------------------
TOTAL LOSS, 11.190704345703125, CLIP loss: -1.5530836582183838, value function loss: 25.49832534790039, entropy loss: 0.5375353097915649
----------------------------------------------------------
TOTAL LOSS, 11.545846939086914, CLIP loss: -1.7130708694458008, value function loss: 26.52831268310547, entropy loss: 0.5238770246505737
----------------------------------------------------------
TOTAL LOSS, 35.4311637878418, CLIP loss: 1.3872828483581543, value function loss: 68.09848022460938, entropy loss: 0.5360649228096008
-----------------------

----------------------------------------------------------
TOTAL LOSS, 31.97437858581543, CLIP loss: 1.8271782398223877, value function loss: 60.30443572998047, entropy loss: 0.501842200756073
----------------------------------------------------------
TOTAL LOSS, 15.374748229980469, CLIP loss: -0.39568787813186646, value function loss: 31.551197052001953, entropy loss: 0.5162529945373535
----------------------------------------------------------
TOTAL LOSS, 13.539871215820312, CLIP loss: -0.729843020439148, value function loss: 28.549800872802734, entropy loss: 0.5185727477073669
----------------------------------------------------------
TOTAL LOSS, 15.606609344482422, CLIP loss: -0.6577704548835754, value function loss: 32.53902053833008, entropy loss: 0.5130565762519836
----------------------------------------------------------
TOTAL LOSS, 13.480276107788086, CLIP loss: -0.5340601205825806, value function loss: 28.03879737854004, entropy loss: 0.5062121748924255
---------------------

----------------------------------------------------------
TOTAL LOSS, 23.271333694458008, CLIP loss: 1.9538993835449219, value function loss: 42.645408630371094, entropy loss: 0.5269914865493774
----------------------------------------------------------
TOTAL LOSS, 23.85462188720703, CLIP loss: 1.4169095754623413, value function loss: 44.88593673706055, entropy loss: 0.5257580876350403
----------------------------------------------------------
TOTAL LOSS, 24.834487915039062, CLIP loss: 1.752379298210144, value function loss: 46.17464828491211, entropy loss: 0.5214694142341614
----------------------------------------------------------
TOTAL LOSS, 22.080320358276367, CLIP loss: 1.7696585655212402, value function loss: 40.63182830810547, entropy loss: 0.5253265500068665
----------------------------------------------------------
TOTAL LOSS, 21.931297302246094, CLIP loss: -2.3502345085144043, value function loss: 48.573875427246094, entropy loss: 0.5404579639434814
------------------------

----------------------------------------------------------
TOTAL LOSS, 25.260541915893555, CLIP loss: 0.8755590319633484, value function loss: 48.78072738647461, entropy loss: 0.5379893779754639
----------------------------------------------------------
TOTAL LOSS, 18.70567512512207, CLIP loss: 0.22237159311771393, value function loss: 36.97721481323242, entropy loss: 0.5303789973258972
----------------------------------------------------------
TOTAL LOSS, 26.319747924804688, CLIP loss: 1.2545777559280396, value function loss: 50.140594482421875, entropy loss: 0.5126360058784485
----------------------------------------------------------
TOTAL LOSS, 39.015480041503906, CLIP loss: 1.7176216840744019, value function loss: 74.60643005371094, entropy loss: 0.5357197523117065
----------------------------------------------------------
TOTAL LOSS, 28.932193756103516, CLIP loss: 1.8005945682525635, value function loss: 54.273963928222656, entropy loss: 0.5382097959518433
-----------------------

----------------------------------------------------------
TOTAL LOSS, 12.06381893157959, CLIP loss: -1.3303947448730469, value function loss: 26.79884910583496, entropy loss: 0.5210886001586914
----------------------------------------------------------
TOTAL LOSS, 18.845308303833008, CLIP loss: -0.7274315357208252, value function loss: 39.15598678588867, entropy loss: 0.5252132415771484
----------------------------------------------------------
TOTAL LOSS, 11.33300495147705, CLIP loss: -1.1372562646865845, value function loss: 24.950929641723633, entropy loss: 0.520368218421936
----------------------------------------------------------
TOTAL LOSS, 6.707208156585693, CLIP loss: -0.6645680665969849, value function loss: 14.754429817199707, entropy loss: 0.5438849925994873
----------------------------------------------------------
TOTAL LOSS, 7.23068904876709, CLIP loss: -0.7044165730476379, value function loss: 15.880660057067871, entropy loss: 0.522445023059845
------------------------

----------------------------------------------------------
TOTAL LOSS, 4.817870140075684, CLIP loss: -0.8062571287155151, value function loss: 11.25914192199707, entropy loss: 0.5443556904792786
----------------------------------------------------------
TOTAL LOSS, 4.377161979675293, CLIP loss: -0.9243490099906921, value function loss: 10.613869667053223, entropy loss: 0.5424070358276367
----------------------------------------------------------
TOTAL LOSS, 5.119307041168213, CLIP loss: -0.9254929423332214, value function loss: 12.100252151489258, entropy loss: 0.5326243042945862
----------------------------------------------------------
TOTAL LOSS, 5.135594844818115, CLIP loss: -0.8511530160903931, value function loss: 11.984164237976074, entropy loss: 0.5334216952323914
----------------------------------------------------------
TOTAL LOSS, 15.570378303527832, CLIP loss: 0.22129572927951813, value function loss: 30.709110260009766, entropy loss: 0.5472034215927124
--------------------

----------------------------------------------------------
TOTAL LOSS, 13.31335163116455, CLIP loss: 0.11068557947874069, value function loss: 26.416194915771484, entropy loss: 0.5431405901908875
----------------------------------------------------------
TOTAL LOSS, 12.677388191223145, CLIP loss: -0.26840612292289734, value function loss: 25.90256690979004, entropy loss: 0.5489805340766907
----------------------------------------------------------
TOTAL LOSS, 11.487421035766602, CLIP loss: 0.33376166224479675, value function loss: 22.3183536529541, entropy loss: 0.5516571998596191
----------------------------------------------------------
TOTAL LOSS, 6.2927141189575195, CLIP loss: -0.15124888718128204, value function loss: 12.898979187011719, entropy loss: 0.5526514053344727
----------------------------------------------------------
TOTAL LOSS, 7.885219573974609, CLIP loss: 0.10177735984325409, value function loss: 15.577875137329102, entropy loss: 0.549545407295227
-------------------

----------------------------------------------------------
TOTAL LOSS, 3.213731288909912, CLIP loss: -0.8245779871940613, value function loss: 8.087639808654785, entropy loss: 0.5510681867599487
----------------------------------------------------------
TOTAL LOSS, 3.379998207092285, CLIP loss: -0.8284464478492737, value function loss: 8.428186416625977, entropy loss: 0.56485915184021
----------------------------------------------------------
TOTAL LOSS, 3.3677775859832764, CLIP loss: -0.8444634079933167, value function loss: 8.435688018798828, entropy loss: 0.5602984428405762
----------------------------------------------------------
TOTAL LOSS, 2.7888360023498535, CLIP loss: -0.7644861340522766, value function loss: 7.117700576782227, entropy loss: 0.5528097748756409
----------------------------------------------------------
TOTAL LOSS, 4.300012111663818, CLIP loss: -0.15707087516784668, value function loss: 8.925188064575195, entropy loss: 0.551080584526062
-------------------------

----------------------------------------------------------
TOTAL LOSS, 2.0045926570892334, CLIP loss: -0.6226097345352173, value function loss: 5.266026496887207, entropy loss: 0.581096351146698
----------------------------------------------------------
TOTAL LOSS, 1.5650800466537476, CLIP loss: -0.6308144330978394, value function loss: 4.403470516204834, entropy loss: 0.5840815901756287
----------------------------------------------------------
TOTAL LOSS, 1.7571097612380981, CLIP loss: -0.6069133877754211, value function loss: 4.7397260665893555, entropy loss: 0.5839906930923462
----------------------------------------------------------
TOTAL LOSS, 1.7577015161514282, CLIP loss: -0.658875048160553, value function loss: 4.844621181488037, entropy loss: 0.5734144449234009
----------------------------------------------------------
TOTAL LOSS, 2.518303871154785, CLIP loss: -0.3098110556602478, value function loss: 5.667874813079834, entropy loss: 0.5822440385818481
----------------------

----------------------------------------------------------
TOTAL LOSS, 0.9210206866264343, CLIP loss: -1.0293656587600708, value function loss: 3.9121837615966797, entropy loss: 0.5705526471138
----------------------------------------------------------
TOTAL LOSS, 1.0864850282669067, CLIP loss: -0.9783027172088623, value function loss: 4.1410417556762695, entropy loss: 0.5733184814453125
----------------------------------------------------------
TOTAL LOSS, 0.7521970868110657, CLIP loss: -1.03639817237854, value function loss: 3.5885915756225586, entropy loss: 0.5700546503067017
----------------------------------------------------------
TOTAL LOSS, 0.9090870022773743, CLIP loss: -0.9570214152336121, value function loss: 3.7436187267303467, entropy loss: 0.5700938105583191
----------------------------------------------------------
TOTAL LOSS, 1.4798142910003662, CLIP loss: -0.4636365473270416, value function loss: 3.8980820178985596, entropy loss: 0.5590185523033142
--------------------

----------------------------------------------------------
TOTAL LOSS, 0.8922325968742371, CLIP loss: -0.6789537668228149, value function loss: 3.153604030609131, entropy loss: 0.5615654587745667
----------------------------------------------------------
TOTAL LOSS, 0.8878152966499329, CLIP loss: -0.7177057862281799, value function loss: 3.2222156524658203, entropy loss: 0.5586742758750916
----------------------------------------------------------
TOTAL LOSS, 0.7335907816886902, CLIP loss: -0.7329829931259155, value function loss: 2.944417715072632, entropy loss: 0.5635064840316772
----------------------------------------------------------
TOTAL LOSS, 1.3089015483856201, CLIP loss: -0.23340529203414917, value function loss: 3.096039295196533, entropy loss: 0.5712694525718689
----------------------------------------------------------
TOTAL LOSS, 1.1974060535430908, CLIP loss: -0.3524321913719177, value function loss: 3.1111273765563965, entropy loss: 0.5725374817848206
-----------------

----------------------------------------------------------
TOTAL LOSS, 8.388026237487793, CLIP loss: 0.22258919477462769, value function loss: 16.34204864501953, entropy loss: 0.5587421655654907
----------------------------------------------------------
TOTAL LOSS, 0.5813453197479248, CLIP loss: -0.37937766313552856, value function loss: 1.9328736066818237, entropy loss: 0.5713826417922974
----------------------------------------------------------
TOTAL LOSS, 0.41400063037872314, CLIP loss: -0.5275816917419434, value function loss: 1.8944288492202759, entropy loss: 0.5632100105285645
----------------------------------------------------------
TOTAL LOSS, 0.3997592628002167, CLIP loss: -0.48419561982154846, value function loss: 1.77925443649292, entropy loss: 0.5672340989112854
----------------------------------------------------------
TOTAL LOSS, 0.5773085355758667, CLIP loss: -0.40214380621910095, value function loss: 1.9700920581817627, entropy loss: 0.5593629479408264
---------------

----------------------------------------------------------
TOTAL LOSS, 0.6823768615722656, CLIP loss: -0.242786705493927, value function loss: 1.8615598678588867, entropy loss: 0.5616395473480225
----------------------------------------------------------
TOTAL LOSS, 0.45126959681510925, CLIP loss: -0.3117990493774414, value function loss: 1.5373446941375732, entropy loss: 0.5603702664375305
----------------------------------------------------------
TOTAL LOSS, 0.6599277853965759, CLIP loss: -0.1217915266752243, value function loss: 1.5745291709899902, entropy loss: 0.5545251965522766
----------------------------------------------------------
TOTAL LOSS, 0.8645492196083069, CLIP loss: 0.05730564147233963, value function loss: 1.625673770904541, entropy loss: 0.5593277215957642
----------------------------------------------------------
TOTAL LOSS, 0.7371755838394165, CLIP loss: 0.027869824320077896, value function loss: 1.4293925762176514, entropy loss: 0.5390530824661255
---------------

----------------------------------------------------------
TOTAL LOSS, 1.143189549446106, CLIP loss: -0.39681369066238403, value function loss: 3.091308832168579, entropy loss: 0.5651097893714905
----------------------------------------------------------
TOTAL LOSS, 1.3181462287902832, CLIP loss: -0.5226204991340637, value function loss: 3.692849636077881, entropy loss: 0.5658109784126282
----------------------------------------------------------
TOTAL LOSS, 1.1701232194900513, CLIP loss: -0.5098294019699097, value function loss: 3.371147632598877, entropy loss: 0.5621245503425598
----------------------------------------------------------
TOTAL LOSS, 1.2310956716537476, CLIP loss: -0.38758864998817444, value function loss: 3.24877667427063, entropy loss: 0.570409893989563
----------------------------------------------------------
TOTAL LOSS, 1.0207579135894775, CLIP loss: -0.15625013411045074, value function loss: 2.365349054336548, entropy loss: 0.5666534900665283
--------------------

----------------------------------------------------------
TOTAL LOSS, 0.23994752764701843, CLIP loss: -0.38179290294647217, value function loss: 1.2546186447143555, entropy loss: 0.5568892359733582
----------------------------------------------------------
TOTAL LOSS, 0.35764631628990173, CLIP loss: -0.3298061788082123, value function loss: 1.386011004447937, entropy loss: 0.5553008317947388
----------------------------------------------------------
TOTAL LOSS, 0.2920587658882141, CLIP loss: -0.39160025119781494, value function loss: 1.3783751726150513, entropy loss: 0.552856981754303
----------------------------------------------------------
TOTAL LOSS, 0.35904741287231445, CLIP loss: -0.3337554335594177, value function loss: 1.3966984748840332, entropy loss: 0.5546377301216125
----------------------------------------------------------
TOTAL LOSS, 0.14790624380111694, CLIP loss: -0.5208295583724976, value function loss: 1.3486250638961792, entropy loss: 0.5576723217964172
-----------

----------------------------------------------------------
TOTAL LOSS, 0.3857693076133728, CLIP loss: -0.1514054834842682, value function loss: 1.0853055715560913, entropy loss: 0.5477980375289917
----------------------------------------------------------
TOTAL LOSS, 0.330841988325119, CLIP loss: -0.19502517580986023, value function loss: 1.0626311302185059, entropy loss: 0.544841468334198
----------------------------------------------------------
TOTAL LOSS, 0.3077254593372345, CLIP loss: -0.14352962374687195, value function loss: 0.9134889841079712, entropy loss: 0.5489410758018494
----------------------------------------------------------
TOTAL LOSS, 0.11502068489789963, CLIP loss: -0.4174817204475403, value function loss: 1.0761388540267944, entropy loss: 0.5567024946212769
----------------------------------------------------------
TOTAL LOSS, 0.24853961169719696, CLIP loss: -0.2877844572067261, value function loss: 1.0837301015853882, entropy loss: 0.5540982484817505
-------------

----------------------------------------------------------
TOTAL LOSS, 0.21973274648189545, CLIP loss: -0.19552457332611084, value function loss: 0.8415960073471069, entropy loss: 0.554067850112915
----------------------------------------------------------
TOTAL LOSS, 0.23514294624328613, CLIP loss: -0.22886061668395996, value function loss: 0.9391130208969116, entropy loss: 0.5552940964698792
----------------------------------------------------------
TOTAL LOSS, 0.25324904918670654, CLIP loss: -0.18809257447719574, value function loss: 0.8937844038009644, entropy loss: 0.5550556778907776
----------------------------------------------------------
TOTAL LOSS, 0.22143465280532837, CLIP loss: -0.23693612217903137, value function loss: 0.9277254343032837, entropy loss: 0.5491940379142761
----------------------------------------------------------
TOTAL LOSS, 0.5340598821640015, CLIP loss: -0.1606820523738861, value function loss: 1.4004621505737305, entropy loss: 0.5489085912704468
--------

----------------------------------------------------------
TOTAL LOSS, 0.3550002872943878, CLIP loss: -0.12973271310329437, value function loss: 0.9806684851646423, entropy loss: 0.5601251721382141
----------------------------------------------------------
TOTAL LOSS, 0.331819623708725, CLIP loss: -0.13343849778175354, value function loss: 0.9417628645896912, entropy loss: 0.5623316168785095
----------------------------------------------------------
TOTAL LOSS, 0.29054945707321167, CLIP loss: -0.05605361983180046, value function loss: 0.70436692237854, entropy loss: 0.5580384731292725
----------------------------------------------------------
TOTAL LOSS, 0.32999926805496216, CLIP loss: -0.07937155663967133, value function loss: 0.8298364877700806, entropy loss: 0.5547419190406799
----------------------------------------------------------
TOTAL LOSS, 0.3269295394420624, CLIP loss: -0.05951415374875069, value function loss: 0.7841062545776367, entropy loss: 0.5609410405158997
-----------

----------------------------------------------------------
TOTAL LOSS, 0.1161443442106247, CLIP loss: -0.3156648874282837, value function loss: 0.8746157288551331, entropy loss: 0.5498631596565247
----------------------------------------------------------
TOTAL LOSS, 0.14344768226146698, CLIP loss: -0.34528428316116333, value function loss: 0.9885331392288208, entropy loss: 0.553460419178009
----------------------------------------------------------
TOTAL LOSS, 0.13215503096580505, CLIP loss: -0.3127520978450775, value function loss: 0.9008748531341553, entropy loss: 0.5530300736427307
----------------------------------------------------------
TOTAL LOSS, 0.08719012886285782, CLIP loss: -0.35609710216522217, value function loss: 0.8975728750228882, entropy loss: 0.5499204993247986
----------------------------------------------------------
TOTAL LOSS, 3.9470200538635254, CLIP loss: 0.0012893006205558777, value function loss: 7.902414321899414, entropy loss: 0.547649085521698
-----------

----------------------------------------------------------
TOTAL LOSS, 0.20772093534469604, CLIP loss: -0.20103034377098083, value function loss: 0.8287639617919922, entropy loss: 0.5630707144737244
----------------------------------------------------------
TOTAL LOSS, 0.16686859726905823, CLIP loss: -0.21643410623073578, value function loss: 0.7778431177139282, entropy loss: 0.5618863105773926
----------------------------------------------------------
TOTAL LOSS, 0.15334858000278473, CLIP loss: -0.18456904590129852, value function loss: 0.687069296836853, entropy loss: 0.5617027878761292
----------------------------------------------------------
TOTAL LOSS, 0.22892743349075317, CLIP loss: -0.23988650739192963, value function loss: 0.9487720727920532, entropy loss: 0.5572097897529602
----------------------------------------------------------
TOTAL LOSS, 0.20315828919410706, CLIP loss: -0.2018665373325348, value function loss: 0.8212928175926208, entropy loss: 0.5621580481529236
-------

----------------------------------------------------------
TOTAL LOSS, 0.316604346036911, CLIP loss: -0.043676719069480896, value function loss: 0.7316178679466248, entropy loss: 0.5527870655059814
----------------------------------------------------------
TOTAL LOSS, 0.40714529156684875, CLIP loss: 0.062092483043670654, value function loss: 0.7010537385940552, entropy loss: 0.5474061965942383
----------------------------------------------------------
TOTAL LOSS, 0.4047847390174866, CLIP loss: 0.06847520172595978, value function loss: 0.6837100982666016, entropy loss: 0.5545501112937927
----------------------------------------------------------
TOTAL LOSS, 0.37996163964271545, CLIP loss: 0.03965182602405548, value function loss: 0.6918618083000183, entropy loss: 0.5621103644371033
----------------------------------------------------------
TOTAL LOSS, 0.41897428035736084, CLIP loss: 0.0726928859949112, value function loss: 0.703630805015564, entropy loss: 0.5533980131149292
------------

----------------------------------------------------------
TOTAL LOSS, 0.12799686193466187, CLIP loss: -0.1979384571313858, value function loss: 0.6630125045776367, entropy loss: 0.5570931434631348
----------------------------------------------------------
TOTAL LOSS, 0.15809233486652374, CLIP loss: -0.18347126245498657, value function loss: 0.6944301128387451, entropy loss: 0.5651459693908691
----------------------------------------------------------
TOTAL LOSS, 0.1814979910850525, CLIP loss: -0.14523479342460632, value function loss: 0.6645259857177734, entropy loss: 0.5530203580856323
----------------------------------------------------------
TOTAL LOSS, -0.004994330462068319, CLIP loss: -0.3529685139656067, value function loss: 0.7069524526596069, entropy loss: 0.5502042770385742
----------------------------------------------------------
TOTAL LOSS, 0.12621814012527466, CLIP loss: -0.2156035453081131, value function loss: 0.6950333714485168, entropy loss: 0.5694998502731323
-------

----------------------------------------------------------
TOTAL LOSS, 0.37603452801704407, CLIP loss: 0.003348606638610363, value function loss: 0.7566673755645752, entropy loss: 0.5647791028022766
----------------------------------------------------------
TOTAL LOSS, 0.23910976946353912, CLIP loss: -0.024910928681492805, value function loss: 0.5392034649848938, entropy loss: 0.5581031441688538
----------------------------------------------------------
TOTAL LOSS, 0.2527894675731659, CLIP loss: -0.01479150727391243, value function loss: 0.5462074279785156, entropy loss: 0.5522732734680176
----------------------------------------------------------
TOTAL LOSS, 0.26551511883735657, CLIP loss: 0.0021994542330503464, value function loss: 0.5377550721168518, entropy loss: 0.5561872124671936
----------------------------------------------------------
TOTAL LOSS, 0.3363780975341797, CLIP loss: 0.019544919952750206, value function loss: 0.6447423696517944, entropy loss: 0.5538027286529541
-----

----------------------------------------------------------
TOTAL LOSS, 0.2217530757188797, CLIP loss: -0.06331127136945724, value function loss: 0.5809364914894104, entropy loss: 0.5403898358345032
----------------------------------------------------------
TOTAL LOSS, 0.27028873562812805, CLIP loss: -0.01642189361155033, value function loss: 0.5844573378562927, entropy loss: 0.5518062114715576
----------------------------------------------------------
TOTAL LOSS, 0.2707225978374481, CLIP loss: -0.01633983850479126, value function loss: 0.5848239660263062, entropy loss: 0.5349539518356323
----------------------------------------------------------
TOTAL LOSS, 0.28249233961105347, CLIP loss: -0.038648948073387146, value function loss: 0.653236985206604, entropy loss: 0.5477190613746643
----------------------------------------------------------
TOTAL LOSS, 0.3270037770271301, CLIP loss: 0.008666357025504112, value function loss: 0.6474550366401672, entropy loss: 0.5390120148658752
--------

----------------------------------------------------------
TOTAL LOSS, 0.40628737211227417, CLIP loss: 0.09103633463382721, value function loss: 0.6414951682090759, entropy loss: 0.5496566891670227
----------------------------------------------------------
TOTAL LOSS, 0.47398775815963745, CLIP loss: 0.15115073323249817, value function loss: 0.6568976044654846, entropy loss: 0.5611789226531982
----------------------------------------------------------
TOTAL LOSS, 0.3859315514564514, CLIP loss: 0.10805582255125046, value function loss: 0.5667450428009033, entropy loss: 0.5496793985366821
----------------------------------------------------------
TOTAL LOSS, 0.42531347274780273, CLIP loss: 0.13655324280261993, value function loss: 0.5885863900184631, entropy loss: 0.5532981157302856
----------------------------------------------------------
TOTAL LOSS, 0.432545006275177, CLIP loss: 0.13003632426261902, value function loss: 0.6161856651306152, entropy loss: 0.5584152936935425
-------------

----------------------------------------------------------
TOTAL LOSS, 10.573843002319336, CLIP loss: 0.8245144486427307, value function loss: 19.50962257385254, entropy loss: 0.5482480525970459
----------------------------------------------------------
TOTAL LOSS, 9.550298690795898, CLIP loss: 0.7393528819084167, value function loss: 17.6328182220459, entropy loss: 0.546333909034729
----------------------------------------------------------
TOTAL LOSS, 12.064103126525879, CLIP loss: 0.764017641544342, value function loss: 22.611167907714844, entropy loss: 0.5499112010002136
----------------------------------------------------------
TOTAL LOSS, 6.179846286773682, CLIP loss: 0.0531463623046875, value function loss: 12.264387130737305, entropy loss: 0.549344539642334
----------------------------------------------------------
TOTAL LOSS, 1.9148509502410889, CLIP loss: -0.0956399068236351, value function loss: 4.032140731811523, entropy loss: 0.5579445362091064
----------------------------

----------------------------------------------------------
TOTAL LOSS, 0.2381795346736908, CLIP loss: -0.02337023615837097, value function loss: 0.5338900685310364, entropy loss: 0.5395259261131287
----------------------------------------------------------
TOTAL LOSS, 0.2043628692626953, CLIP loss: -0.047317493706941605, value function loss: 0.5140931606292725, entropy loss: 0.5366215109825134
----------------------------------------------------------
TOTAL LOSS, 0.2809968888759613, CLIP loss: 0.009188580326735973, value function loss: 0.5542681813240051, entropy loss: 0.5325789451599121
----------------------------------------------------------
TOTAL LOSS, 0.19381439685821533, CLIP loss: -0.06271721422672272, value function loss: 0.5238732695579529, entropy loss: 0.5405026078224182
----------------------------------------------------------
TOTAL LOSS, 0.16449427604675293, CLIP loss: -0.07077673077583313, value function loss: 0.48155349493026733, entropy loss: 0.5505746006965637
------

----------------------------------------------------------
TOTAL LOSS, 9.222637176513672, CLIP loss: 0.7374346256256104, value function loss: 16.981319427490234, entropy loss: 0.5457155704498291
----------------------------------------------------------
TOTAL LOSS, 13.076682090759277, CLIP loss: 0.8621704578399658, value function loss: 24.43996810913086, entropy loss: 0.5472187399864197
----------------------------------------------------------
TOTAL LOSS, 8.668578147888184, CLIP loss: 0.5781742334365845, value function loss: 16.191646575927734, entropy loss: 0.5419491529464722
----------------------------------------------------------
TOTAL LOSS, 13.122359275817871, CLIP loss: 0.9330447316169739, value function loss: 24.389389038085938, entropy loss: 0.5379262566566467
----------------------------------------------------------
TOTAL LOSS, 0.08812152594327927, CLIP loss: -0.18668211996555328, value function loss: 0.5603882670402527, entropy loss: 0.5390490889549255
--------------------

----------------------------------------------------------
TOTAL LOSS, 0.10761310905218124, CLIP loss: -0.14434516429901123, value function loss: 0.514574408531189, entropy loss: 0.5328933596611023
----------------------------------------------------------
TOTAL LOSS, 0.15387727320194244, CLIP loss: -0.1165228933095932, value function loss: 0.5514669418334961, entropy loss: 0.533329963684082
----------------------------------------------------------
TOTAL LOSS, 0.1512361466884613, CLIP loss: -0.09889239817857742, value function loss: 0.5112665295600891, entropy loss: 0.5504730939865112
----------------------------------------------------------
TOTAL LOSS, 0.0935131162405014, CLIP loss: -0.10504322499036789, value function loss: 0.4080440402030945, entropy loss: 0.5465676784515381
----------------------------------------------------------
TOTAL LOSS, 0.15184791386127472, CLIP loss: -0.08077477663755417, value function loss: 0.4762708842754364, entropy loss: 0.5512760877609253
----------

----------------------------------------------------------
TOTAL LOSS, 0.2832171618938446, CLIP loss: 0.05057377368211746, value function loss: 0.4761630892753601, entropy loss: 0.5438150763511658
----------------------------------------------------------
TOTAL LOSS, 0.35388198494911194, CLIP loss: 0.10177113860845566, value function loss: 0.5151351690292358, entropy loss: 0.5456751585006714
----------------------------------------------------------
TOTAL LOSS, 0.30772286653518677, CLIP loss: 0.0695905014872551, value function loss: 0.48714667558670044, entropy loss: 0.5440967679023743
----------------------------------------------------------
TOTAL LOSS, 0.285528302192688, CLIP loss: 0.011191491037607193, value function loss: 0.5596216320991516, entropy loss: 0.5473993420600891
----------------------------------------------------------
TOTAL LOSS, 0.23483379185199738, CLIP loss: 0.007908366620540619, value function loss: 0.4649111330509186, entropy loss: 0.553013026714325
------------

----------------------------------------------------------
TOTAL LOSS, 0.12078201770782471, CLIP loss: -0.14324037730693817, value function loss: 0.5392515659332275, entropy loss: 0.5603391528129578
----------------------------------------------------------
TOTAL LOSS, -0.01691386103630066, CLIP loss: -0.22638101875782013, value function loss: 0.43011474609375, entropy loss: 0.5590215921401978
----------------------------------------------------------
TOTAL LOSS, 0.04324973374605179, CLIP loss: -0.17100213468074799, value function loss: 0.4396243691444397, entropy loss: 0.5560315847396851
----------------------------------------------------------
TOTAL LOSS, 0.019325731322169304, CLIP loss: -0.20237240195274353, value function loss: 0.4545637369155884, entropy loss: 0.5583735108375549
----------------------------------------------------------
TOTAL LOSS, 12.439684867858887, CLIP loss: 0.6060082912445068, value function loss: 23.67843246459961, entropy loss: 0.5539476871490479
---------

----------------------------------------------------------
TOTAL LOSS, 0.11628977209329605, CLIP loss: -0.0850125178694725, value function loss: 0.4137776494026184, entropy loss: 0.5586536526679993
----------------------------------------------------------
TOTAL LOSS, 0.11463254690170288, CLIP loss: -0.12356901168823242, value function loss: 0.4876473546028137, entropy loss: 0.5622116327285767
----------------------------------------------------------
TOTAL LOSS, 0.1385117471218109, CLIP loss: -0.11448600143194199, value function loss: 0.5171678066253662, entropy loss: 0.5586148500442505
----------------------------------------------------------
TOTAL LOSS, 0.11341542750597, CLIP loss: -0.09472344070672989, value function loss: 0.42744266986846924, entropy loss: 0.5582464337348938
----------------------------------------------------------
TOTAL LOSS, 0.18763767182826996, CLIP loss: -0.06735549122095108, value function loss: 0.5209812521934509, entropy loss: 0.5497464537620544
---------

----------------------------------------------------------
TOTAL LOSS, 0.08588914573192596, CLIP loss: -0.1574002355337143, value function loss: 0.4975477159023285, entropy loss: 0.5484473705291748
----------------------------------------------------------
TOTAL LOSS, 0.13995210826396942, CLIP loss: -0.1072314903140068, value function loss: 0.5055055022239685, entropy loss: 0.5569148063659668
----------------------------------------------------------
TOTAL LOSS, 0.21204492449760437, CLIP loss: -0.0022776350378990173, value function loss: 0.43974897265434265, entropy loss: 0.5551937222480774
----------------------------------------------------------
TOTAL LOSS, 0.16728003323078156, CLIP loss: -0.052608393132686615, value function loss: 0.4507475793361664, entropy loss: 0.5485376119613647
----------------------------------------------------------
TOTAL LOSS, 0.19814634323120117, CLIP loss: -0.05034065991640091, value function loss: 0.5079765319824219, entropy loss: 0.5501270294189453
---

----------------------------------------------------------
TOTAL LOSS, 0.3122633099555969, CLIP loss: 0.10640712082386017, value function loss: 0.42261776328086853, entropy loss: 0.5452685356140137
----------------------------------------------------------
TOTAL LOSS, 0.2851129472255707, CLIP loss: 0.054041437804698944, value function loss: 0.4727911353111267, entropy loss: 0.5324057340621948
----------------------------------------------------------
TOTAL LOSS, 0.2516605854034424, CLIP loss: 0.07880187779664993, value function loss: 0.356458455324173, entropy loss: 0.5370516180992126
----------------------------------------------------------
TOTAL LOSS, 0.32060039043426514, CLIP loss: 0.08307275921106339, value function loss: 0.4859403371810913, entropy loss: 0.5442535877227783
----------------------------------------------------------
TOTAL LOSS, 0.2962338328361511, CLIP loss: 0.0864626094698906, value function loss: 0.43042120337486267, entropy loss: 0.5439359545707703
-------------

----------------------------------------------------------
TOTAL LOSS, 13.011927604675293, CLIP loss: 1.9577810764312744, value function loss: 22.118919372558594, entropy loss: 0.5313268899917603
----------------------------------------------------------
TOTAL LOSS, 0.03403346240520477, CLIP loss: -0.21150213479995728, value function loss: 0.5018171072006226, entropy loss: 0.5372956395149231
----------------------------------------------------------
TOTAL LOSS, -0.026439659297466278, CLIP loss: -0.27586957812309265, value function loss: 0.5097503066062927, entropy loss: 0.5445233583450317
----------------------------------------------------------
TOTAL LOSS, 0.020777767524123192, CLIP loss: -0.21532823145389557, value function loss: 0.4830196797847748, entropy loss: 0.54038405418396
----------------------------------------------------------
TOTAL LOSS, -0.024001626297831535, CLIP loss: -0.2564989924430847, value function loss: 0.47591349482536316, entropy loss: 0.5459381341934204
-----

----------------------------------------------------------
TOTAL LOSS, 9.945066452026367, CLIP loss: 0.650600016117096, value function loss: 18.599903106689453, entropy loss: 0.5485504865646362
----------------------------------------------------------
TOTAL LOSS, 0.10063993185758591, CLIP loss: -0.09575165063142776, value function loss: 0.4035267233848572, entropy loss: 0.5371780395507812
----------------------------------------------------------
TOTAL LOSS, 0.15937146544456482, CLIP loss: -0.057165637612342834, value function loss: 0.4440588057041168, entropy loss: 0.549229621887207
----------------------------------------------------------
TOTAL LOSS, 0.13728155195713043, CLIP loss: -0.0728650614619255, value function loss: 0.43115997314453125, entropy loss: 0.5433366298675537
----------------------------------------------------------
TOTAL LOSS, 0.16389863193035126, CLIP loss: -0.08908496797084808, value function loss: 0.5166406631469727, entropy loss: 0.5336726903915405
----------

----------------------------------------------------------
TOTAL LOSS, 0.1761694699525833, CLIP loss: -0.03005749173462391, value function loss: 0.42310455441474915, entropy loss: 0.5325312614440918
----------------------------------------------------------
TOTAL LOSS, 0.20497910678386688, CLIP loss: 0.0046586692333221436, value function loss: 0.41146421432495117, entropy loss: 0.5411669015884399
----------------------------------------------------------
TOTAL LOSS, 0.15268145501613617, CLIP loss: -0.037881240248680115, value function loss: 0.3919435143470764, entropy loss: 0.5409059524536133
----------------------------------------------------------
TOTAL LOSS, 0.18079273402690887, CLIP loss: 0.0101285669952631, value function loss: 0.35199233889579773, entropy loss: 0.5332001447677612
----------------------------------------------------------
TOTAL LOSS, 0.17114481329917908, CLIP loss: -0.037939008325338364, value function loss: 0.4291887581348419, entropy loss: 0.5510554909706116
--

----------------------------------------------------------
TOTAL LOSS, 0.16579407453536987, CLIP loss: -0.05761966109275818, value function loss: 0.45766568183898926, entropy loss: 0.54190993309021
----------------------------------------------------------
TOTAL LOSS, 0.23363091051578522, CLIP loss: 0.0608680434525013, value function loss: 0.3563547134399414, entropy loss: 0.5414479970932007
----------------------------------------------------------
TOTAL LOSS, 0.2666631042957306, CLIP loss: 0.05133798345923424, value function loss: 0.441510945558548, entropy loss: 0.5430347919464111
----------------------------------------------------------
TOTAL LOSS, 0.2894027531147003, CLIP loss: 0.08983919024467468, value function loss: 0.4098123610019684, entropy loss: 0.5342628955841064
----------------------------------------------------------
TOTAL LOSS, 0.18460537493228912, CLIP loss: 0.021160468459129333, value function loss: 0.33772972226142883, entropy loss: 0.5419955253601074
------------

----------------------------------------------------------
TOTAL LOSS, 0.22180813550949097, CLIP loss: 0.019760170951485634, value function loss: 0.41468098759651184, entropy loss: 0.5292534232139587
----------------------------------------------------------
TOTAL LOSS, 0.18204975128173828, CLIP loss: -0.03261308744549751, value function loss: 0.44014379382133484, entropy loss: 0.5409061312675476
----------------------------------------------------------
TOTAL LOSS, 0.16234055161476135, CLIP loss: -0.024120137095451355, value function loss: 0.38379746675491333, entropy loss: 0.543804407119751
----------------------------------------------------------
TOTAL LOSS, 0.1867704689502716, CLIP loss: -0.0028853067196905613, value function loss: 0.3900943100452423, entropy loss: 0.5391373634338379
----------------------------------------------------------
TOTAL LOSS, 0.14214996993541718, CLIP loss: -0.04056426137685776, value function loss: 0.3763887882232666, entropy loss: 0.5480149984359741
-

----------------------------------------------------------
TOTAL LOSS, 0.10775760561227798, CLIP loss: -0.08655533939599991, value function loss: 0.3991553783416748, entropy loss: 0.52647465467453
----------------------------------------------------------
TOTAL LOSS, 0.1484443098306656, CLIP loss: -0.049621663987636566, value function loss: 0.4069461226463318, entropy loss: 0.5407078266143799
----------------------------------------------------------
TOTAL LOSS, 0.21052229404449463, CLIP loss: -0.002876847982406616, value function loss: 0.4375477731227875, entropy loss: 0.5374749898910522
----------------------------------------------------------
TOTAL LOSS, 0.06140885502099991, CLIP loss: -0.1331312656402588, value function loss: 0.3997878134250641, entropy loss: 0.5353787541389465
----------------------------------------------------------
TOTAL LOSS, 0.21687918901443481, CLIP loss: 0.015298177488148212, value function loss: 0.41405168175697327, entropy loss: 0.5444822311401367
------

In [8]:
# create an env with random state
def make_env_func(gym_id, seed, idx, run_name, capture_video=False):
    def env_fun():
        env = gym.make(gym_id, render_mode="rgb_array")
        env = gym.wrappers.RecordEpisodeStatistics(env)
        if capture_video:
            # initiate the video capture if not already initiated
            if idx == 0:
                # wrapper to create the video of the performance
                env = gym.wrapper
                s.RecordVideo(env, f"videos/{run_name}")
        env.action_space.seed(seed)
        env.observation_space.seed(seed)
        return env

    return env_fun

In [9]:
# create N envs
envs = []
for i in range(configs.num_trajcts):
    envs.append( make_env_func(configs.gym_id, seed + i, i, run_name) )
envs = gym.vector.SyncVectorEnv(envs)
envs

SyncVectorEnv(num_envs=32)

In [10]:
# start the environment
cur_observation = envs.reset()[0]

In [11]:
class FCBlock(nn.Module):
    """A generic fully connected residual block with good setup"""
    def __init__(self, embd_dim, dropout=0.2):
        super().__init__()
        self.block = nn.Sequential(
            nn.LayerNorm(embd_dim),
            nn.GELU(),
            nn.Linear(embd_dim, 4*embd_dim),
            nn.GELU(),
            nn.Linear(4*embd_dim, embd_dim),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return x + self.block(x)



class Agent(nn.Module):
    """an agent that creates actions and estimates values"""
    def __init__(self, env_observation_dim, action_space_dim, embd_dim=64, num_blocks=2):
        super().__init__()
        self.embedding_layer = nn.Linear(env_observation_dim, embd_dim)
        self.shared_layers = nn.Sequential(*[FCBlock(embd_dim=embd_dim) for _ in range(num_blocks)])
        self.value_head = nn.Linear(embd_dim, 1)
        self.policy_head = nn.Linear(embd_dim, action_space_dim)
        # orthogonal initialization with a hi entropy for more exploration at the start
        torch.nn.init.orthogonal_(self.policy_head.weight, 0.01)

    def value_func(self, state):
        hidden = self.shared_layers(self.embedding_layer(state))
        value = self.value_head(hidden)
        return value

    def policy(self, state, action=None):
        hidden = self.shared_layers(self.embedding_layer(state))
        logits = self.policy_head(hidden)
        # PyTorch categorical class helpful for sampling and probability calculations
        probs = Categorical(logits=logits)
        if action is None:
            action = probs.sample()
        return action, probs.log_prob(action), probs.entropy(), self.value_head(hidden)


In [12]:
#     "num_trajcts": 32, # N
#     "max_trajects_length": 64, # T

def create_rollout(envs, cur_observation, cur_done, agent):

    observations = torch.zeros((cur_observation.shape[0], configs['max_trajects_length'],envs.single_observation_space.shape[0] ))
    actions = torch.zeros((cur_observation.shape[0], configs['max_trajects_length']) + envs.single_action_space.shape)
    dones =  torch.zeros((cur_observation.shape[0], configs['max_trajects_length']))
    rewards =  torch.zeros((cur_observation.shape[0], configs['max_trajects_length']))
    values = torch.zeros((cur_observation.shape[0], configs['max_trajects_length']))
    advantages = torch.zeros((cur_observation.shape[0], configs['max_trajects_length']))
    logprobs = torch.zeros((cur_observation.shape[0], configs['max_trajects_length']))
     
    
    for t in range(configs['max_trajects_length']):
        # get the policy
        with torch.no_grad():
            action, logprobb,_,value = agent.policy(cur_observation)
            
        observations[:,t,:] = cur_observation
        actions[:,t] = action
        dones[:,t] = cur_done
        
        cur_observation, cur_reward, cur_done,_,_ = envs.step(action.cpu().numpy())
        cur_observation = torch.tensor(cur_observation)
        cur_reward = torch.tensor(cur_reward)
        cur_done = torch.tensor(cur_done)
        
        rewards[:,t] = torch.tensor(cur_reward)

        values[:,t] = value.squeeze()
       
    # Advantage is approximated reverse recursively
#     advantage = gae(observations,dones, rewards,values, advantages)
    
    return {
        "cur_observation": cur_observation,
        "observations": observations,
        "actions" : actions,
        "dones" : dones,
        "rewards" : rewards,
        "values" : values,
        "advantages" : advantages,
        "logprobs" : logprobs
    }

In [13]:
agent = Agent(
    env_observation_dim=envs.single_observation_space.shape[0],
    action_space_dim=envs.single_action_space.n
).to(device)

optimizer = optim.Adam(agent.parameters(), lr=configs.learning_rate)

In [14]:
rollouts = create_rollout(envs, torch.tensor(cur_observation), torch.tensor(cur_done), agent)


NameError: name 'cur_done' is not defined

In [15]:
cur_observation = rollouts['cur_observation']
values = rollouts['values']
dones = rollouts['dones']
rewards = rollouts['rewards']


NameError: name 'rollouts' is not defined

In [16]:
def gae(rewards, values, dones, cur_observation, agent):
    last_advantage = 0
    with torch.no_grad():
        last_value = agent.value_func(cur_observation).reshape(1,-1)
    advantages = torch.zeros_like(values)
    
    for t in reversed(range(configs['max_trajects_length'])):
        mask = 1.0 - dones[:,t]
        last_advantage = mask*last_advantage
        last_value = mask*last_value
        delta = rewards[:,t] + configs['gamma']*last_value - values[:,t]
        advantages[:,t] = delta + configs['gae_lambda']*configs['gamma']*last_advantage
        last_value = values[:,t]
        last_advantage = advantages[:,t]
    
    returns = advantages + values
        
    return advantages, returns

In [17]:
advantages,returns = gae(rewards, values, dones, cur_observation, agent)

NameError: name 'rewards' is not defined

In [18]:
def gae(rewards, values, dones, cur_observation, agent):
    last_advantage = 0
    with torch.no_grad():
        last_value = agent.value_func(cur_observation).reshape(1,-1)
    advantages = torch.zeros_like(values)
    
    for t in reversed(range(configs['max_trajects_length'])):
        mask = 1.0 - dones[:,t]
        last_advantage = mask*last_advantage
        last_value = mask*last_value
        delta = rewards[:,t] + configs['gamma']*last_value - values[:,t]
        advantages[:,t] = delta + configs['gae_lambda']*configs['gamma']*last_advantage
        last_value = values[:,t]
        last_advantage = advantages[:,t]
    
    returns = advantages + values
        
    return advantages, returns


def loss_clip(
    mb_oldlogporb,     # old logprob of mini batch actions collected during the rollout
    mb_newlogprob,     # new logprob of mini batch actions created by the new policy
    mb_advantages      # mini batch of advantages collected during the the rollout
):
    """
    policy loss with clipping to control gradients
    """
    ratio = torch.exp(mb_newlogprob - mb_oldlogporb)
    policy_loss = -mb_advantages * ratio
    # clipped policy gradient loss enforces closeness
    clipped_loss = -mb_advantages * torch.clamp(ratio, 1 - configs.clip_epsilon, 1 + configs.clip_epsilon)
    pessimistic_loss = torch.max(policy_loss, clipped_loss).mean()
    return pessimistic_loss


def loss_vf(
    mb_oldreturns,  # mini batch of old returns collected during the rollout
    mb_newvalues    # minibach of values calculated by the new value function
):
    """
    enforcing the value function to give more accurate estimates of returns
    """
    mb_newvalues = mb_newvalues.view(-1)
    loss = 0.5 * ((mb_newvalues - mb_oldreturns) ** 2).mean()
    return loss


class Storage(Dataset):
    def __init__(self, rollout, advantages, returns, envs):
        # fill in the storage and flatten the parallel trajectories
        self.observations = rollout['observations'].reshape((-1,) + envs.single_observation_space.shape)
        self.logprobs = rollout['logprobs'].reshape(-1)
        self.actions = rollout['actions'].reshape((-1,) + envs.single_action_space.shape).long()
        self.advantages = advantages.reshape(-1)
        self.returns = returns.reshape(-1)

    def __getitem__(self, ix: int):
        item = [
            self.observations[ix],
            self.logprobs[ix],
            self.actions[ix],
            self.advantages[ix],
            self.returns[ix]
        ]
        return item

    def __len__(self) -> int:
        return len(self.observations)
    
import torch
from torch.distributions.categorical import Categorical
# Create the environment
envs = []
for i in range(configs.num_trajcts):
    envs.append( make_env_func(configs.gym_id, seed + i, i, run_name) )
envs = gym.vector.SyncVectorEnv(envs)
cur_observation = torch.tensor(envs.reset()[0])
cur_done = torch.tensor(torch.zeros(configs.num_trajcts))


for i in range(200):

    frac = 1.0 - (i - 1.0) / (configs.total_timesteps/configs.batch_size)
    optimizer.param_groups[0]["lr"] = frac * configs.learning_rate
    # Phase 1: Create rollout
    rollouts = create_rollout(envs,cur_observation,cur_done, agent)
    advantages,returns = gae(rollouts['rewards'], rollouts['values'], rollouts['dones'], rollouts['cur_observation'], agent)
#     rollout, advantages, returns, envs
    dataset = Storage(rollouts, advantages, returns, envs)
    dataloader = DataLoader(dataset, batch_size=configs.minibatch_size, shuffle=True)
    
    # Phase 2: Update
    for j in range(configs.update_epochs):
        for data in dataloader: # mini_batch
            mb_ob,mb_logprobs, mb_actions, mb_advantages, mb_returns = data
            
            new_actions, new_logprobs, new_entropy, new_values = agent.policy(mb_ob, mb_logprobs)
            
            c_loss = loss_clip(mb_logprobs, new_logprobs, mb_advantages)
            vf_loss = loss_vf(mb_returns, new_values)
            entropy = new_entropy.mean()
            
            # maximize the PPO clip loss, minimize the value loss and maximize the entropy(exploration)
            # since pytorch's optimizer are configured to do gradient descent under the hood, i.e W - lr* deltaW
            # we need to multiply the loss the with negative(-) if we need to do gradient ascent
            # we already multiplied it in the clip loss function and we multiply it in entropy.
            
            loss = c_loss + configs.vf_coef*vf_loss - configs.ent_coef*entropy
            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(agent.parameters(), configs.max_grad_norm)
            optimizer.step()
            
            print(loss)
        

In [19]:
configs

{'exp_name': 'cartpole', 'gym_id': 'CartPole-v1', 'learning_rate': 0.001, 'total_timesteps': 1000000, 'max_grad_norm': 0.5, 'num_trajcts': 32, 'max_trajects_length': 64, 'gamma': 0.99, 'gae_lambda': 0.95, 'num_minibatches': 2, 'update_epochs': 2, 'clip_epsilon': 0.2, 'ent_coef': 0.01, 'vf_coef': 0.5, 'num_returns_to_average': 3, 'num_episodes_to_average': 23, 'batch_size': 2048, 'minibatch_size': 1024}

In [20]:
def loss_clip(
    mb_oldlogporb,     # old logprob of mini batch actions collected during the rollout
    mb_newlogprob,     # new logprob of mini batch actions created by the new policy
    mb_advantages      # mini batch of advantages collected during the the rollout
):
    """
    policy loss with clipping to control gradients
    """
    ratio = torch.exp(mb_newlogprob - mb_oldlogporb)
    policy_loss = -mb_advantages * ratio
    # clipped policy gradient loss enforces closeness
    clipped_loss = -mb_advantages * torch.clamp(ratio, 1 - configs.clip_epsilon, 1 + configs.clip_epsilon)
    pessimistic_loss = torch.max(policy_loss, clipped_loss).mean()
    return pessimistic_loss


def loss_vf(
    mb_oldreturns,  # mini batch of old returns collected during the rollout
    mb_newvalues    # minibach of values calculated by the new value function
):
    """
    enforcing the value function to give more accurate estimates of returns
    """
    mb_newvalues = mb_newvalues.view(-1)
    loss = 0.5 * ((mb_newvalues - mb_oldreturns) ** 2).mean()
    return loss

In [21]:
configs

{'exp_name': 'cartpole', 'gym_id': 'CartPole-v1', 'learning_rate': 0.001, 'total_timesteps': 1000000, 'max_grad_norm': 0.5, 'num_trajcts': 32, 'max_trajects_length': 64, 'gamma': 0.99, 'gae_lambda': 0.95, 'num_minibatches': 2, 'update_epochs': 2, 'clip_epsilon': 0.2, 'ent_coef': 0.01, 'vf_coef': 0.5, 'num_returns_to_average': 3, 'num_episodes_to_average': 23, 'batch_size': 2048, 'minibatch_size': 1024}

In [1]:
import torch
from torch.distributions.categorical import Categorical
# Create the environment
envs = []
for i in range(configs.num_trajcts):
    envs.append( make_env_func(configs.gym_id, seed + i, i, run_name) )
envs = gym.vector.SyncVectorEnv(envs)
cur_observation = torch.tensor(envs.reset()[0])
cur_done = torch.tensor(torch.zeros(configs.num_trajcts))


for i in range(200):

    frac = 1.0 - (i - 1.0) / (configs.total_timesteps/configs.batch_size)
    optimizer.param_groups[0]["lr"] = frac * configs.learning_rate
    # Phase 1: Create rollout
    rollouts = create_rollout(envs,cur_observation,cur_done, agent)
    advantages,returns = gae(rollouts['rewards'], rollouts['values'], rollouts['dones'], rollouts['cur_observation'], agent)
#     rollout, advantages, returns, envs
    dataset = Storage(rollouts, advantages, returns, envs)
    dataloader = DataLoader(dataset, batch_size=configs.minibatch_size, shuffle=True)
    
    # Phase 2: Update
    for j in range(configs.update_epochs):
        for data in dataloader: # mini_batch
            mb_ob,mb_logprobs, mb_actions, mb_advantages, mb_returns = data
            
            new_actions, new_logprobs, new_entropy, new_values = agent.policy(mb_ob, mb_logprobs)
            
            print(mb_logprobs[0], new_logprobs[0], mb_advantage[0])
            
            print(new_logprobs[0] / mb_logprobs[0]*mb_advanrage[0])
            c_loss = loss_clip(mb_logprobs, new_logprobs, mb_advantages)
            vf_loss = loss_vf(mb_returns, new_values)
            entropy = new_entropy.mean()
            
            # maximize the PPO clip loss, minimize the value loss and maximize the entropy(exploration)
            # since pytorch's optimizer are configured to do gradient descent under the hood, i.e W - lr* deltaW
            # we need to multiply the loss the with negative(-) if we need to do gradient ascent
            # we already multiplied it in the clip loss function and we multiply it in entropy.
            
            loss = c_loss + configs.vf_coef*vf_loss - configs.ent_coef*entropy
            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(agent.parameters(), configs.max_grad_norm)
            optimizer.step()
            
#             print(loss)
            
            break
        break
    break
        

NameError: name 'configs' is not defined

In [32]:
# create a test env
test_env = make_env_func(configs.gym_id, seed, 0, 'inference', True)()

# use the trained agent to run through the env till it terminates this is an eposide
observation, _ = test_env.reset()
observation = torch.unsqueeze(torch.tensor(observation),dim=0).to(device)
for _ in range(500):
    action, _, _, _ = agent.policy(observation)
    action = action.cpu().item()
    observation, reward, done, _, info = test_env.step(action)
    observation = torch.unsqueeze(torch.tensor(observation),dim=0).to(device)
    if done:
        break
test_env.close()

Video('/content/videos/inference/rl-video-episode-0.mp4', embed=True)

  logger.warn(


In [393]:
configs

{'exp_name': 'cartpole', 'gym_id': 'CartPole-v1', 'learning_rate': 0.001, 'total_timesteps': 1000000, 'max_grad_norm': 0.5, 'num_trajcts': 32, 'max_trajects_length': 64, 'gamma': 0.99, 'gae_lambda': 0.95, 'num_minibatches': 2, 'update_epochs': 2, 'clip_epsilon': 0.2, 'ent_coef': 0.01, 'vf_coef': 0.5, 'num_returns_to_average': 3, 'num_episodes_to_average': 23, 'batch_size': 2048, 'minibatch_size': 1024}

In [22]:
sample = torch.randn((2,3))

In [23]:
sample

tensor([[-0.6010, -1.7952, -0.5128],
        [-0.2789,  0.2727, -1.3774]])

In [25]:
sample.mean(dim=1)

tensor([-0.9696, -0.4612])

In [26]:
import torch

# Example: One embedding (1D tensor) and a list of other embeddings (2D tensor)
one_embedding = torch.tensor([1.0, 2.0, 3.0])
list_of_embeddings = torch.tensor([[4.0, 5.0, 6.0],
                                    [7.0, 8.0, 9.0],
                                    [1.0, 0.0, 0.0]])

# Normalize the one embedding
one_embedding_norm = one_embedding / one_embedding.norm()

# Normalize the list of embeddings
list_of_embeddings_norm = list_of_embeddings / list_of_embeddings.norm(dim=1, keepdim=True)

# Compute cosine similarities (dot product of normalized vectors)
cos_similarities = torch.mm(list_of_embeddings_norm, one_embedding_norm.unsqueeze(1)).squeeze()

# Print cosine similarities
print(cos_similarities)


tensor([0.9746, 0.9594, 0.2673])


In [27]:
list_of_embeddings.norm(dim=1, keepdim=True)

tensor([[ 8.7750],
        [13.9284],
        [ 1.0000]])

In [29]:
cos_similarities.argmax()

tensor(0)

In [33]:
# Number of top scores to retrieve (k)
k = 5

# Use torch.topk to get the top-k values and their indices
topk_values, topk_indices = torch.topk(cos_similarities, k)

RuntimeError: selected index k out of range

tensor([0.9746, 0.9594])

In [32]:
topk_indices

tensor([0, 1])

In [42]:
ls = [*list_of_embeddings[0], torch.tensor([0])]

In [44]:
torch.tensor(ls)

tensor([4., 5., 6., 0.])

In [49]:
one = torch.randn(2,3)
print(one)
new_one = one.view(-1)
new_one


tensor([[-0.0150, -0.1535, -0.1722],
        [ 0.9797, -1.4121, -1.5100]])


tensor([-0.0150, -0.1535, -0.1722,  0.9797, -1.4121, -1.5100])

In [50]:
two = torch.randn(2,3)
print(two)
new_two = two.view(-1)

new_two

tensor([[-0.9105, -1.8257,  1.2258],
        [-0.3211, -0.8037, -1.5332]])


tensor([-0.9105, -1.8257,  1.2258, -0.3211, -0.8037, -1.5332])

In [51]:
torch.cat((new_one, new_two), dim=-1)

tensor([-0.0150, -0.1535, -0.1722,  0.9797, -1.4121, -1.5100, -0.9105, -1.8257,
         1.2258, -0.3211, -0.8037, -1.5332])

In [52]:
import torch

# Original tensor
tensor = torch.tensor([-0.0150, -0.1535, -0.1722,  0.9797, -1.4121, -1.5100, 
                       -0.9105, -1.8257, 1.2258, -0.3211, -0.8037, -1.5332])

# Duplicate across 5 rows
duplicated_tensor = tensor.repeat(5, 1)

# Print the result
print(duplicated_tensor)


tensor([[-0.0150, -0.1535, -0.1722,  0.9797, -1.4121, -1.5100, -0.9105, -1.8257,
          1.2258, -0.3211, -0.8037, -1.5332],
        [-0.0150, -0.1535, -0.1722,  0.9797, -1.4121, -1.5100, -0.9105, -1.8257,
          1.2258, -0.3211, -0.8037, -1.5332],
        [-0.0150, -0.1535, -0.1722,  0.9797, -1.4121, -1.5100, -0.9105, -1.8257,
          1.2258, -0.3211, -0.8037, -1.5332],
        [-0.0150, -0.1535, -0.1722,  0.9797, -1.4121, -1.5100, -0.9105, -1.8257,
          1.2258, -0.3211, -0.8037, -1.5332],
        [-0.0150, -0.1535, -0.1722,  0.9797, -1.4121, -1.5100, -0.9105, -1.8257,
          1.2258, -0.3211, -0.8037, -1.5332]])
