In [1]:
import os
from typing import Optional
from collections import namedtuple, deque
import math
import random
import logging

import tqdm
import tqdm.notebook as tqdm_notebook
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.categorical import Categorical
from torch.utils.data import DataLoader
from torch.utils.data.dataset import IterableDataset
from torchmetrics import MeanMetric

from lightning.pytorch import LightningModule, loggers as pl_loggers
from lightning.pytorch.trainer import Trainer
from lightning.pytorch import seed_everything

In [2]:
# Lightning emits its status messages (device summary, model table, ...) through
# the 'lightning.pytorch' logger; keep them visible at INFO level.
logger = logging.getLogger('lightning.pytorch')
logger.setLevel('INFO')

# Data

In [3]:


# One stored transition. The field names double as the dictionary keys produced
# by ReplayBuffer.sample(), so their spelling and order are part of the contract.
Experience = namedtuple(
    'Experience',
    'observations actions values returns advantages log_probs',
)


class ReplayBuffer(object):
    """Bounded FIFO store of experiences with uniform random sampling.

    Items are kept in a ``collections.deque`` created with ``maxlen=capacity``,
    so appending to a full buffer silently discards the oldest item.

    Args:
        capacity: Maximum number of experiences retained.
    """

    def __init__(self, capacity: int = 1000):
        self.buffer = deque(maxlen=capacity)

    def __len__(self):
        """Return the number of experiences currently stored."""
        return len(self.buffer)

    def append(self, experience):
        """Store one experience, evicting the oldest one if the buffer is full."""
        self.buffer.append(experience)

    def sample(self, sample_size: int = 1):
        """Lazily yield ``sample_size`` randomly chosen experiences as dicts.

        Each yielded dict maps every name in ``Experience._fields`` to the
        corresponding attribute of the drawn item. Indices are drawn with
        ``np.random.choice``; sampling is without replacement whenever the
        buffer holds at least ``sample_size`` items and with replacement
        otherwise.

        An empty buffer yields nothing instead of failing inside
        ``np.random.choice`` (which rejects an empty population).

        Args:
            sample_size: Number of experiences to draw.

        Yields:
            dict: ``{field_name: value}`` for one sampled experience.
        """
        population = len(self)
        if population == 0:
            # Nothing to draw from; produce an empty stream.
            return

        # Replacement is only needed when more samples are requested than exist.
        replace = sample_size > population
        indices = np.random.choice(population, sample_size, replace=replace)
        for idx in indices:
            experience = self.buffer[idx]
            yield {key: getattr(experience, key) for key in Experience._fields}

    def clear(self):
        """Remove every stored experience."""
        self.buffer.clear()


class RLDataset(IterableDataset):
    """Iterable dataset that streams random draws from a replay buffer.

    Every pass over the dataset asks the buffer for ``sample_step_num`` fresh
    samples, so the content changes from one iteration to the next.

    Args:
        buffer: Source of experiences; only its ``sample`` method is used.
        sample_step_num: Number of samples requested per pass, also reported
            as the dataset length.
    """

    def __init__(self, buffer: ReplayBuffer, sample_step_num: int = 1):
        self.sample_step_num = sample_step_num
        self.buffer = buffer

    def __len__(self):
        """Return the number of samples requested per pass."""
        return self.sample_step_num

    def __iter__(self):
        """Yield the experiences produced by one ``buffer.sample`` call."""
        yield from self.buffer.sample(self.sample_step_num)


# Loss

In [4]:
def policy_loss(advantages: torch.Tensor, ratio: torch.Tensor, cliip_coef: float):
    """Clipped-surrogate PPO policy loss, averaged over all elements.

    Args:
        advantages: Advantage estimates.
        ratio: Probability ratio between the new and the old policy.
        cliip_coef: Clipping half-width; the ratio is clamped to
            ``[1 - cliip_coef, 1 + cliip_coef]``.

    Returns:
        Scalar tensor: mean of the element-wise maximum of the unclipped and
        clipped negative surrogate terms.
    """
    negated_adv = -advantages
    clipped_ratio = ratio.clamp(min=1 - cliip_coef, max=1 + cliip_coef)
    unclipped_term = negated_adv * ratio
    clipped_term = negated_adv * clipped_ratio
    # Taking the larger loss is the pessimistic (lower-bound) choice.
    return torch.maximum(unclipped_term, clipped_term).mean()


def value_loss(new_values: torch.Tensor, old_values: torch.Tensor, returns: torch.Tensor, clip_coef: float,
               clip_vloss: bool, vf_coef: float):
    """Scaled mean-squared-error value loss with optional value clipping.

    Args:
        new_values: Current critic predictions; flattened to 1-D before use.
        old_values: Critic predictions recorded at rollout time (only read
            when ``clip_vloss`` is true).
        returns: Regression targets.
        clip_coef: Maximum allowed deviation of the prediction from
            ``old_values`` when clipping is enabled.
        clip_vloss: Whether to clip the prediction around ``old_values``.
        vf_coef: Multiplier applied to the MSE.

    Returns:
        Scalar tensor ``vf_coef * mse(prediction, returns)``.
    """
    flat_values = new_values.view(-1)
    if clip_vloss:
        # Limit how far the new estimate may move away from the rollout-time one.
        bounded_step = torch.clamp(flat_values - old_values, -clip_coef, clip_coef)
        prediction = old_values + bounded_step
    else:
        prediction = flat_values
    return vf_coef * F.mse_loss(prediction, returns)


def entropy_loss(entropy: torch.Tensor, ent_coef: float):
    """Entropy bonus expressed as a loss: the negated mean entropy times ``ent_coef``.

    Args:
        entropy: Per-sample policy entropies.
        ent_coef: Weight of the entropy term.

    Returns:
        Scalar tensor; more entropy gives a lower (more negative) loss.
    """
    mean_entropy = entropy.mean()
    return -mean_entropy * ent_coef

# Model

In [5]:
class PPOModuleBase(nn.Module):
    """Configurable multi-layer perceptron shared by the actor and the critic.

    Args:
        mlp_cfg: Network description with the keys
            ``'channels'`` (list of hidden widths, required),
            ``'use_layer_norm'`` (default ``False``) and
            ``'act_func'`` (name of a class in ``torch.nn``, default ``'ReLU'``).
        inp_channels: Size of the input feature vector.
        out_channels: Size of the output vector.
        *args, **kwargs: Accepted and ignored.
    """

    def __init__(self, mlp_cfg: dict, inp_channels: int, out_channels: int, *args, **kwargs):
        super().__init__()

        self.net = self._create_mlp(mlp_cfg, inp_channels, out_channels)

    def _create_mlp(self, mlp_cfg, inp_channels, out_channels):
        """Build the MLP: hidden stages followed by a bare linear output layer.

        Each hidden stage is ``Sequential(Linear, LayerNorm | Identity, activation)``;
        the final ``Linear`` has no normalization or activation.
        """
        widths = [inp_channels] + mlp_cfg['channels'] + [out_channels]
        use_layer_norm = mlp_cfg.get('use_layer_norm', False)
        act_func = mlp_cfg.get('act_func', 'ReLU')

        stages = []
        # Pair consecutive widths, leaving the last pair for the output layer.
        for fan_in, fan_out in zip(widths[:-2], widths[1:-1]):
            linear = nn.Linear(fan_in, fan_out, bias=True)
            norm = nn.LayerNorm(fan_out) if use_layer_norm else nn.Identity()
            activation = getattr(nn, act_func)()
            stages.append(nn.Sequential(linear, norm, activation))
        stages.append(nn.Linear(widths[-2], widths[-1], bias=True))
        return nn.Sequential(*stages)

    def forward(self, x: torch.Tensor):
        """Apply the MLP to ``x``."""
        return self.net(x)


class PPOActor(PPOModuleBase):
    """Policy network producing one logit per discrete action.

    The input width is the product of the single-environment observation
    shape and the output width is ``envs.single_action_space.n``.

    Args:
        mlp_cfg: MLP description forwarded to ``PPOModuleBase``.
        envs: Vectorized environment the spaces are read from.
    """

    def __init__(self, mlp_cfg: dict, envs: gym.vector.SyncVectorEnv):
        super().__init__(
            mlp_cfg,
            math.prod(envs.single_observation_space.shape),
            envs.single_action_space.n,
        )


class PPOCritic(PPOModuleBase):
    """Value network producing a single scalar per observation.

    The input width is the product of the single-environment observation shape.

    Args:
        mlp_cfg: MLP description forwarded to ``PPOModuleBase``.
        envs: Vectorized environment the observation space is read from.
    """

    def __init__(self, mlp_cfg: dict, envs: gym.vector.SyncVectorEnv):
        super().__init__(
            mlp_cfg,
            math.prod(envs.single_observation_space.shape),
            1,
        )


class PPOModel(nn.Module):
    """Actor-critic pair with helpers for acting and advantage estimation.

    Args:
        envs: Vectorized environment used to size both networks.
        actor_cfg: Keyword arguments for ``PPOActor`` (besides ``envs``).
        critic_cfg: Keyword arguments for ``PPOCritic`` (besides ``envs``).
        *args, **kwargs: Accepted and ignored.
    """

    def __init__(self, envs: gym.vector.SyncVectorEnv, actor_cfg: dict, critic_cfg: dict, *args, **kwargs):
        super().__init__()

        self.actor: PPOActor = PPOActor(**actor_cfg, envs=envs)
        self.critic: PPOCritic = PPOCritic(**critic_cfg, envs=envs)

    def get_action(self, obs: torch.Tensor, act: Optional[torch.Tensor] = None, greedy: bool = False):
        """Choose or score actions for ``obs``.

        Args:
            obs: Observations fed to the actor.
            act: Actions to evaluate; when ``None`` they are sampled from the
                categorical policy. Ignored in greedy mode.
            greedy: If true, return only the arg-max action.

        Returns:
            In greedy mode a single tensor of actions; otherwise the tuple
            ``(action, log_prob, entropy)``.
        """
        logits = self.actor(obs)
        if greedy:
            return torch.argmax(F.softmax(logits, dim=-1), dim=-1)

        policy = Categorical(logits=logits)
        chosen = policy.sample() if act is None else act
        return chosen, policy.log_prob(chosen), policy.entropy()

    def get_value(self, obs: torch.Tensor):
        """Return the critic's output for ``obs``."""
        return self.critic(obs)

    def forward(self, obs: torch.Tensor, act: torch.Tensor = None):
        """Return ``(action, log_prob, entropy, value)`` for ``obs``."""
        chosen, log_prob, entropy = self.get_action(obs, act)
        return chosen, log_prob, entropy, self.get_value(obs)

    @torch.no_grad()
    def estimate_returns_and_advantages(
            self,
            rewards: torch.Tensor,
            values: torch.Tensor,
            dones: torch.Tensor,
            next_obs: torch.Tensor,
            next_done: torch.Tensor,
            num_steps: int,
            gamma: float,
            gae_lambda: float):
        """Compute GAE(lambda) advantages and the matching returns.

        The recursion walks the rollout backwards. For the final step the
        critic's value of ``next_obs`` and ``next_done`` provide the bootstrap;
        for every earlier step ``values[t + 1]`` and ``dones[t + 1]`` are used.

        Args:
            rewards: Per-step rewards, indexed by step along dim 0.
            values: Per-step value estimates, same indexing as ``rewards``.
            dones: Per-step done flags; ``dones[t]`` marks whether step ``t``
                starts a new episode.
            next_obs: Observation following the last rollout step.
            next_done: Done flag following the last rollout step.
            num_steps: Number of rollout steps to process.
            gamma: Discount factor.
            gae_lambda: GAE smoothing factor.

        Returns:
            ``(returns, advantages)`` where ``returns = advantages + values``.
        """
        bootstrap_values = self.get_value(next_obs).reshape(1, -1)
        advantages = torch.zeros_like(rewards)
        running_gae = 0
        final_step = num_steps - 1
        for t in range(final_step, -1, -1):
            if t == final_step:
                # Past the end of the rollout: bootstrap from the critic.
                following_done, following_values = next_done, bootstrap_values
            else:
                following_done, following_values = dones[t + 1], values[t + 1]
            next_non_terminal = torch.logical_not(following_done)
            td_error = rewards[t] + gamma * following_values * next_non_terminal - values[t]
            running_gae = td_error + gamma * gae_lambda * next_non_terminal * running_gae
            advantages[t] = running_gae
        return advantages + values, advantages


class PPOLightning(LightningModule):
    """LightningModule that alternates environment rollouts with PPO updates.

    At the start of every training epoch the current policy is rolled out in
    the vectorized environments, returns and advantages are estimated, and the
    transitions are appended to ``self.replay_buffer``. ``training_step`` then
    optimizes the clipped PPO objective on batches sampled from that buffer,
    and the buffer is cleared when the epoch ends.

    Args:
        env_cfg: Environment description. ``'type'`` is the id passed to
            ``gym.make``; every other entry is forwarded as a keyword argument.
            The dict is not modified.
        data_cfg: Data settings. ``'total_timestep_n'`` (default 1000) is the
            number of rollout steps per environment, ``'sample_timestep_n'``
            (required by ``train_dataloader``) is the number of samples drawn
            per epoch, and ``'batch_size'`` is used both as the number of
            parallel environments and as the DataLoader batch size.
        ppo_cfg: Keyword arguments for ``PPOModel`` (besides ``envs``).
        vf_coef: Weight of the value loss.
        ent_coef: Weight of the entropy loss.
        clip_coef: Clipping range for the policy ratio and, if enabled, the
            value update.
        clip_vloss: Whether the value loss is clipped.
        learning_rate: Adam learning rate.
        normalize_advantages: Whether to standardize advantages per batch.
        gamma: Discount factor.
        gae_lambda: GAE smoothing factor.
        seed: Seed applied to each environment's action and observation space.
        log_root: Directory under which ``videos`` is created for the first
            environment's recordings.
        **torchmetrics_kwargs: Forwarded to every ``MeanMetric``.
    """

    def __init__(
            self,
            env_cfg: dict,
            data_cfg: dict,
            ppo_cfg: dict,
            vf_coef: float = 1.0,
            ent_coef: float = 0.0,
            clip_coef: float = 0.2,
            clip_vloss: bool = False,
            learning_rate: float = 1e-3,
            normalize_advantages: bool = False,
            gamma: float = 0.99,
            gae_lambda: float = 0.95,
            seed=42,
            log_root='',
            **torchmetrics_kwargs):
        super().__init__()
        self.save_hyperparameters()

        self.env_cfg = env_cfg
        self.data_cfg = data_cfg
        self.ppo_cfg = ppo_cfg

        self.envs = self._load_env(env_cfg, data_cfg, self.hparams.log_root, self.hparams.seed)
        self.replay_buffer = self._load_replay_buffer(data_cfg)
        self.ppo_model: PPOModel = PPOModel(envs=self.envs, **ppo_cfg)

        self.avg_pg_loss = MeanMetric(**torchmetrics_kwargs)
        self.avg_value_loss = MeanMetric(**torchmetrics_kwargs)
        self.avg_ent_loss = MeanMetric(**torchmetrics_kwargs)

        self.vf_coef = vf_coef
        self.ent_coef = ent_coef
        self.clip_coef = clip_coef
        self.clip_vloss = clip_vloss
        self.normalize_advantages = normalize_advantages

        self.learning_rate = learning_rate

        self.gamma = gamma
        self.gae_lambda = gae_lambda

    def forward(self, obs: torch.Tensor, act: Optional[torch.Tensor] = None):
        """Delegate to ``PPOModel.forward``: ``(action, log_prob, entropy, value)``."""
        return self.ppo_model(obs, act)

    def on_fit_start(self):
        """Allocate the rollout storage and collect a first rollout.

        Every storage tensor has the leading shape ``(num_steps, num_envs)``
        and lives on ``self.device``.
        """
        num_steps = self.data_cfg.get('total_timestep_n', 1000)

        self.observation_shape = self.envs.single_observation_space.shape
        self.action_shape = self.envs.single_action_space.shape

        self.observations_buffer = torch.zeros((num_steps, self.envs.num_envs) + self.observation_shape,
                                               device=self.device)
        self.actions_buffer = torch.zeros((num_steps, self.envs.num_envs) + self.action_shape, device=self.device)
        self.log_probs_buffer = torch.zeros((num_steps, self.envs.num_envs), device=self.device)
        self.rewards_buffer = torch.zeros((num_steps, self.envs.num_envs), device=self.device)
        self.dones_buffer = torch.zeros((num_steps, self.envs.num_envs), device=self.device)
        self.values_buffer = torch.zeros((num_steps, self.envs.num_envs), device=self.device)

        # NOTE(review): presumably this pre-fills the replay buffer so the
        # dataloader has data before the first epoch hook fires - confirm.
        self.on_train_epoch_start()

    @torch.no_grad()
    def on_train_epoch_start(self):
        """Roll out the current policy and refill the replay buffer.

        The environments are reset, stepped ``total_timestep_n`` times with
        sampled actions, episode statistics are logged, GAE returns and
        advantages are computed, and one ``Experience`` per (step, env) pair
        is appended to the replay buffer.
        """
        self.eval()

        env_eps = 0
        env_rew = 0
        env_eps_len = 0

        num_steps = self.data_cfg.get('total_timestep_n', 1000)

        next_observations = torch.tensor(self.envs.reset()[0], device=self.device)
        next_dones = torch.zeros(self.envs.num_envs, device=self.device)
        for step in range(0, num_steps):
            self.observations_buffer[step] = next_observations
            self.dones_buffer[step] = next_dones

            actions, log_probs, _, values = self(next_observations)
            self.values_buffer[step] = values.flatten()
            self.actions_buffer[step] = actions
            self.log_probs_buffer[step] = log_probs

            next_observations, rewards, dones, truncateds, info = self.envs.step(actions.cpu().numpy())
            # Termination and truncation are both treated as episode ends.
            dones = torch.logical_or(torch.tensor(dones), torch.tensor(truncateds))
            self.rewards_buffer[step] = torch.tensor(rewards.astype(np.float32), device=self.device).view(-1)

            next_observations = torch.tensor(next_observations, device=self.device)
            next_dones = dones.to(self.device)

            if 'final_info' in info:
                for agent_final_info in info['final_info']:
                    if agent_final_info is not None and 'episode' in agent_final_info:
                        env_eps += 1
                        env_rew += agent_final_info['episode']['r'][0]
                        env_eps_len += agent_final_info['episode']['l'][0]

        # The trainer may run without a logger (``logger=False``).
        if self.logger is not None:
            # The small epsilon keeps the division defined when no episode finished.
            self.logger.log_metrics({"env/mean_episodes_reward": env_rew / (env_eps + 1e-8),
                                     "env/mean_episodes_length": env_eps_len / (env_eps + 1e-8)},
                                    self.current_epoch)

        returns, advantages = self.ppo_model.estimate_returns_and_advantages(self.rewards_buffer, self.values_buffer,
                                                                             self.dones_buffer, next_observations,
                                                                             next_dones, num_steps, self.gamma,
                                                                             self.gae_lambda)

        # Flatten the (step, env) axes so every transition becomes one sample.
        obs_data = self.observations_buffer.reshape((-1,) + self.observation_shape)
        log_prob_data = self.log_probs_buffer.reshape(-1)
        act_data = self.actions_buffer.reshape((-1,) + self.action_shape)
        adv_data = advantages.reshape(-1)
        ret_data = returns.reshape(-1)
        val_data = self.values_buffer.reshape(-1)
        for obs, log_prob, act, adv, ret, val in zip(obs_data, log_prob_data, act_data, adv_data, ret_data, val_data):
            self.replay_buffer.append(Experience(
                observations=obs,
                actions=act,
                values=val,
                returns=ret,
                advantages=adv,
                log_probs=log_prob))

        self.train()

    def training_step(self, batch: dict[str, torch.Tensor]):
        """Compute the PPO loss for one batch of sampled experiences.

        Args:
            batch: Collated experiences keyed by the ``Experience`` field names.

        Returns:
            Sum of the policy, entropy and value losses.
        """
        _, new_log_prob, entropy, new_value = self(batch['observations'], batch['actions'])
        log_ratio = new_log_prob - batch['log_probs']
        ratio = torch.exp(log_ratio)

        # policy loss
        advantages = batch['advantages']
        # std() of a single element is NaN, so a one-sample batch is left as is.
        if self.normalize_advantages and advantages.numel() > 1:
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
        # Use the (possibly normalized) advantages, not the raw batch entry.
        p_loss = policy_loss(advantages, ratio, self.clip_coef)

        # value loss
        v_loss = value_loss(new_value, batch['values'], batch['returns'], self.clip_coef, self.clip_vloss, self.vf_coef)

        # entropy loss
        e_loss = entropy_loss(entropy, self.ent_coef)

        # update metrics
        self.avg_pg_loss.update(p_loss)
        self.avg_value_loss.update(v_loss)
        self.avg_ent_loss.update(e_loss)

        self.log_dict({'p_loss': p_loss.item(), 'v_loss': v_loss.item(), 'e_loss': e_loss.item()},
                      sync_dist=True, prog_bar=True, on_epoch=True)

        return p_loss + e_loss + v_loss

    def on_train_epoch_end(self):
        """Drop the used rollout, log the epoch's mean losses and reset the metrics."""
        self.replay_buffer.clear()

        # The trainer may run without a logger (``logger=False``).
        if self.logger is not None:
            self.logger.log_metrics(
                {
                    "Loss/policy_loss": self.avg_pg_loss.compute(),
                    "Loss/value_loss": self.avg_value_loss.compute(),
                    "Loss/entropy_loss": self.avg_ent_loss.compute(),
                },
                self.global_step)
        self.reset_metrics()

    def on_fit_end(self):
        """Close the vectorized environments."""
        self.envs.close()

    def reset_metrics(self):
        """Reset the three running loss averages."""
        self.avg_pg_loss.reset()
        self.avg_value_loss.reset()
        self.avg_ent_loss.reset()

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters (``eps=1e-4``)."""
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate, eps=1e-4)

    @classmethod
    def _load_env(cls, env_cfg: dict, data_cfg: dict, log_root: str = None, seed: int = 42) -> gym.vector.SyncVectorEnv:
        """Build the synchronous vectorized environment.

        Args:
            env_cfg: ``'type'`` is the environment id; the remaining entries
                are keyword arguments for ``gym.make``. The dict is copied, so
                the caller's config keeps its ``'type'`` entry.
            data_cfg: ``'batch_size'`` (default 1) is the number of environments.
            log_root: If not ``None``, the first environment records videos
                into ``<log_root>/videos``.
            seed: Seed for each environment's action and observation space.

        Returns:
            A ``SyncVectorEnv`` whose members are wrapped with
            ``RecordEpisodeStatistics``.

        Raises:
            KeyError: If ``env_cfg`` has no ``'type'`` entry.
        """
        def make_env(env_type, env_cfg: dict, idx: int):
            def thunk():
                env = gym.make(env_type, **env_cfg)
                env = gym.wrappers.RecordEpisodeStatistics(env)
                if idx == 0 and log_root is not None:
                    env = gym.wrappers.RecordVideo(
                        env, os.path.join(log_root, 'videos'), disable_logger=True)
                env.action_space.seed(seed)
                env.observation_space.seed(seed)
                return env
            return thunk

        batch_size = data_cfg.get('batch_size', 1)
        # Work on a copy: popping from the caller's dict would corrupt the
        # config (and the saved hyperparameters) for any later use.
        env_kwargs = dict(env_cfg)
        env_type = env_kwargs.pop('type')
        envs = gym.vector.SyncVectorEnv([make_env(env_type, env_kwargs, idx) for idx in range(batch_size)])
        return envs

    @classmethod
    def _load_replay_buffer(cls, data_cfg: dict) -> ReplayBuffer:
        """Create a replay buffer large enough for one complete rollout.

        A rollout appends ``total_timestep_n`` steps for each of the
        ``batch_size`` environments, so the capacity is their product;
        a smaller buffer would silently evict the earliest steps.
        """
        total_timestep_n = data_cfg.get('total_timestep_n', 1000)
        num_envs = data_cfg.get('batch_size', 1)
        buffer = ReplayBuffer(total_timestep_n * num_envs)
        return buffer

    def train_dataloader(self):
        """Return a DataLoader that samples ``sample_timestep_n`` experiences per epoch."""
        dataset = RLDataset(self.replay_buffer, self.data_cfg['sample_timestep_n'])
        return DataLoader(dataset, batch_size=self.data_cfg['batch_size'])

# Train & Test

In [6]:
# seed = 42
# env_type = 'CartPole-v1'
# num_envs = 4
# num_steps = 2048
# total_timesteps = 500000
# num_updates = total_timesteps // (num_envs * num_steps)

# seed_everything(seed)

# tb_logger = pl_loggers.TensorBoardLogger(save_dir=os.path.join('logs'), name=env_type)

# envs = PPOLightning._load_env({'type': env_type, "render_mode": "rgb_array"},
#                                {'batch_size': num_envs},
#                                log_root=tb_logger.log_dir,
#                                seed=seed)

# ppo_model = PPOModel(
#     envs=envs,
#     actor_cfg=dict(
#         mlp_cfg=dict(
#             channels=[64, 64],
#             use_layer_norm=False,
#             act_func='ReLU')),
#     critic_cfg=dict(
#         mlp_cfg=dict(
#             channels=[64, 64],
#             use_layer_norm=False,
#             act_func='ReLU')))

# optimizer = torch.optim.Adam(ppo_model.parameters(), lr=1e-4)

# for name, param in ppo_model.named_parameters():
#     print(name, param.requires_grad)

In [7]:
# from tensorboard import notebook


# local_rew = 0.0
# local_ep_len = 0.0
# local_num_episodes = 0.0
# global_steps = 0.0
# update_epochs = 10
# vf_coef: float = 1.0
# ent_coef: float = 0.0
# clip_coef: float = 0.2
# clip_vloss: bool = False
# gamma = 0.99
# gae_lambda = 0.95

# obs = torch.zeros((num_steps, num_envs) + envs.single_observation_space.shape)
# actions = torch.zeros((num_steps, num_envs) + envs.single_action_space.shape)
# logprobs = torch.zeros((num_steps, num_envs))
# rewards = torch.zeros((num_steps, num_envs))
# dones = torch.zeros((num_steps, num_envs))
# values = torch.zeros((num_steps, num_envs))

# next_obs = torch.tensor(envs.reset(seed=seed)[0])
# next_done = torch.zeros(num_envs)
# for update in range(1, num_updates + 1):
#     for step in tqdm_notebook.tqdm_notebook(
#             range(num_steps), 
#             desc=f'Update num: {update}/{num_steps}, generating trajections'):
#         global_steps += num_envs
#         obs[step] = next_obs
#         dones[step] = next_done

#         with torch.no_grad():
#             ppo_model.eval()
#             action, logprob, _, value = ppo_model(next_obs)
#             values[step] = value.flatten()
#         actions[step] = action
#         logprobs[step] = logprob

#         next_obs, reward, done, truncated, info = envs.step(action.cpu().numpy())
#         done = torch.logical_or(torch.tensor(done), torch.tensor(truncated))
#         rewards[step] = torch.tensor(reward).view(-1)
#         next_obs = torch.tensor(next_obs)
#         next_done = done

#         if 'final_info' in info:
#             for i, agent_final_info in enumerate(info['final_info']):
#                 if agent_final_info is not None and "episode" in agent_final_info:
#                     local_num_episodes += 1
#                     local_rew += agent_final_info['episode']['r'][0]
#                     local_ep_len += agent_final_info['episode']['l'][0]

#     if local_num_episodes != 0:
#         tb_logger.log_metrics({
#             'rewards': local_rew / local_num_episodes, 
#             'episode_lengths': local_ep_len / local_num_episodes,
#         }, step=global_steps)

#     local_rew = 0.0
#     local_ep_len = 0.0
#     local_num_episodes = 0.0
#     returns, advantages = ppo_model.estimate_returns_and_advantages(
#         rewards, values, dones, next_obs, next_done, num_steps, gamma, gae_lambda
#     )

#     curr_data = {
#         'obs': obs.reshape((-1,) + envs.single_observation_space.shape),
#         'logprobs': logprobs.reshape(-1),
#         'actions': actions.reshape((-1,) + envs.single_action_space.shape),
#         'advantages': advantages.reshape(-1),
#         'returns': returns.reshape(-1),
#         'values': values.reshape(-1),
#     }

#     if update == 1:
#         with open('train_ppo_ipynb.pkl', 'wb') as fp:
#             import pickle as pkl
#             pkl.dump(curr_data, fp)

#     random_sampler = torch.utils.data.RandomSampler(list(range(curr_data['obs'].shape[0])))
#     sampler = torch.utils.data.BatchSampler(random_sampler, batch_size=num_envs, drop_last=False)

#     ppo_model.train()
#     for epoch in range(update_epochs):
#         per_epoch_p_loss = 0.0
#         per_epoch_v_loss = 0.0
#         per_epoch_e_loss = 0.0
#         for batch_idxs in tqdm_notebook.tqdm_notebook(
#                 sampler,
#                 desc=f'Epoch num: {epoch + 1}/{update_epochs}, updating model',
#                 leave=False):
#             _, newlogprob, entropy, newvalue = ppo_model(curr_data['obs'][batch_idxs],
#                                                          curr_data['actions'][batch_idxs].long())
#             logratio = newlogprob - curr_data['logprobs'][batch_idxs]
#             ratio = logratio.exp()

#             advantages = curr_data['advantages'][batch_idxs]

#             p_loss = policy_loss(advantages, ratio, clip_coef)
#             v_loss = value_loss(newvalue, curr_data['values'][batch_idxs], curr_data['returns'][batch_idxs],
#                                 clip_coef, clip_vloss, vf_coef)
#             e_loss = entropy_loss(entropy, ent_coef)

#             loss = p_loss + v_loss + e_loss
#             per_epoch_p_loss += p_loss.item()
#             per_epoch_v_loss += v_loss.item()
#             per_epoch_e_loss += e_loss.item()

#             optimizer.zero_grad(set_to_none=True)
#             loss.backward()
#             nn.utils.clip_grad_norm_(ppo_model.parameters(), 0.5)
#             optimizer.step()

#         tb_logger.log_metrics({
#             'losses/p_loss': per_epoch_p_loss / len(sampler),
#             'losses/v_loss': per_epoch_v_loss / len(sampler),
#             'losses/e_loss': per_epoch_e_loss / len(sampler),
#         }, step=global_steps)

In [8]:
# Seed Python, NumPy and PyTorch before any model or environment is created so
# that weight initialization and action sampling are repeatable across runs.
SEED = 42
seed_everything(SEED)

env_type = 'CartPole-v1'

model_cfg = dict(
    env_cfg=dict(
        type=env_type,
        render_mode="rgb_array"),
    data_cfg=dict(
        total_timestep_n=2048,   # rollout steps per environment and epoch
        sample_timestep_n=1024,  # samples drawn from the buffer per epoch
        batch_size=8),           # number of parallel envs and DataLoader batch size
    ppo_cfg=dict(
        actor_cfg=dict(
            mlp_cfg=dict(
                channels=[64, 64],
                use_layer_norm=True,
                act_func='ReLU')),
        critic_cfg=dict(
            mlp_cfg=dict(
                channels=[64, 64],
                use_layer_norm=True,
                act_func='ReLU'))),
    vf_coef=1.0,
    ent_coef=0.0,
    clip_coef=0.2,
    clip_vloss=True,
    normalize_advantages=False,
    gamma=0.99,
    gae_lambda=0.95,
    learning_rate=1e-4,
    seed=SEED,
)

tb_logger = pl_loggers.TensorBoardLogger(save_dir=os.path.join('logs'), name=env_type)
model = PPOLightning(**model_cfg, log_root=tb_logger.log_dir)
print(model)

trainer = Trainer(max_epochs=2 ** 16, logger=tb_logger)
trainer.fit(model)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


PPOLightning(
  (ppo_model): PPOModel(
    (actor): PPOActor(
      (net): Sequential(
        (0): Sequential(
          (0): Linear(in_features=4, out_features=64, bias=True)
          (1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
          (2): ReLU()
        )
        (1): Sequential(
          (0): Linear(in_features=64, out_features=64, bias=True)
          (1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
          (2): ReLU()
        )
        (2): Linear(in_features=64, out_features=2, bias=True)
      )
    )
    (critic): PPOCritic(
      (net): Sequential(
        (0): Sequential(
          (0): Linear(in_features=4, out_features=64, bias=True)
          (1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
          (2): ReLU()
        )
        (1): Sequential(
          (0): Linear(in_features=64, out_features=64, bias=True)
          (1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
          (2): ReLU()
        )
        (2): Linear(in_f


  | Name           | Type       | Params | Mode 
------------------------------------------------------
0 | ppo_model      | PPOModel   | 9.7 K  | train
1 | avg_pg_loss    | MeanMetric | 0      | train
2 | avg_value_loss | MeanMetric | 0      | train
3 | avg_ent_loss   | MeanMetric | 0      | train
------------------------------------------------------
9.7 K     Trainable params
0         Non-trainable params
9.7 K     Total params
0.039     Total estimated model params size (MB)
26        Modules in train mode
0         Modules in eval mode
/Users/yuhang09/miniforge3/envs/ppo/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.
/Users/yuhang09/miniforge3/envs/ppo/lib/python3.9/site-packages/lightning/pytorch/utilities/data.py:123: Your `IterableDataset

Training: |          | 0/? [00:00<?, ?it/s]


Detected KeyboardInterrupt, attempting graceful shutdown ...


NameError: name 'exit' is not defined