In [None]:
import os
from typing import Optional
from collections import namedtuple, deque
import math
import random
import logging

import tqdm
import tqdm.notebook as tqdm_notebook
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.categorical import Categorical
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.dataset import IterableDataset
from torchmetrics import MeanMetric

from lightning.pytorch import LightningModule, loggers as pl_loggers
from lightning.pytorch.trainer import Trainer
from lightning.pytorch import seed_everything

In [None]:
# Configure the 'lightning.pytorch' logger so that records at INFO level and above are emitted.
logger = logging.getLogger('lightning.pytorch')
logger.setLevel(logging.INFO)

# Data

In [None]:
# One stored transition. Field names double as the keys of the dicts produced by
# `ReplayBuffer.sample`.
Experience = namedtuple('Experience', ['observations', 'actions', 'values', 'returns', 'advantages', 'log_probs'])


class ReplayBuffer(object):
    """Fixed-capacity FIFO container of `Experience` tuples with uniform random sampling.

    Once `capacity` items are stored, appending a new item evicts the oldest one
    (behaviour of `collections.deque(maxlen=...)`).
    """

    def __init__(self, capacity: int = 1000):
        """Create an empty buffer holding at most `capacity` experiences."""
        self.buffer = deque(maxlen=capacity)

    def __len__(self):
        """Return the number of experiences currently stored."""
        return len(self.buffer)

    def append(self, experience):
        """Store one experience, evicting the oldest one if the buffer is full."""
        self.buffer.append(experience)

    def sample(self, sample_size: int = 1):
        """Lazily yield `sample_size` uniformly drawn experiences as dicts.

        Each yielded dict maps every name in `Experience._fields` to the
        corresponding attribute of the drawn experience.

        Sampling is done without replacement whenever the buffer holds at least
        `sample_size` items (including the case where exactly the whole buffer
        is requested); replacement is used only when more items are requested
        than are stored.

        Raises:
            ValueError: raised by `np.random.choice` on first iteration if the
                buffer is empty and `sample_size` is positive.
        """
        # Fall back to replacement only when the request cannot be satisfied
        # with distinct items; asking for exactly len(self) items must return
        # every stored experience once.
        replace = sample_size > len(self)
        indices = np.random.choice(len(self), sample_size, replace=replace)
        # collate experiences
        for idx in indices:
            yield {key: getattr(self.buffer[idx], key) for key in Experience._fields}

    def clear(self):
        """Remove every stored experience."""
        self.buffer.clear()


class RLDataset(IterableDataset):
    """Iterable dataset that streams randomly sampled experiences from a replay buffer.

    Every pass over the dataset asks the buffer for a new sample of
    `sample_step_num` experiences.
    """

    def __init__(self, buffer: ReplayBuffer, sample_step_num: int = 1):
        """Remember the backing buffer and how many experiences to draw per pass."""
        self.sample_step_num = sample_step_num
        self.buffer = buffer

    def __len__(self):
        """Return the number of experiences produced by one pass."""
        return self.sample_step_num

    def __iter__(self):
        """Yield the experiences of one buffer sample, one at a time."""
        yield from self.buffer.sample(self.sample_step_num)


class FakeDataset(Dataset):
    """Map-style dataset that exposes an indexable container unchanged."""

    def __init__(self, data):
        """Keep a reference to `data`; it must support `len()` and indexing."""
        self.data = data

    def __len__(self):
        """Return the size of the wrapped container."""
        size = len(self.data)
        return size

    def __getitem__(self, index):
        """Return the element of the wrapped container stored at `index`."""
        item = self.data[index]
        return item


# Loss

In [None]:
def policy_loss_factory(clip_coef: float):
    """Build the clipped-surrogate policy loss for a fixed clipping coefficient.

    Args:
        clip_coef: the probability ratio is clamped to `[1 - clip_coef, 1 + clip_coef]`.

    Returns:
        A callable `policy_loss(advantages, ratio)` returning a scalar tensor.
    """
    def policy_loss(advantages: torch.Tensor, ratio: torch.Tensor):
        """Mean of the element-wise maximum of the unclipped and clipped negative surrogates."""
        bounded_ratio = ratio.clamp(1 - clip_coef, 1 + clip_coef)
        unclipped_term = -advantages * ratio
        clipped_term = -advantages * bounded_ratio
        # Taking the larger loss term is the pessimistic choice between the two surrogates.
        return torch.maximum(unclipped_term, clipped_term).mean()
    return policy_loss


def value_loss_factory(clip_coef: float, clip_vloss: bool, vf_coef: float):
    """Build the value-function loss.

    Args:
        clip_coef: maximum absolute change of the prediction relative to the
            old value; only used when `clip_vloss` is true.
        clip_vloss: whether to clamp the new prediction around the old value.
        vf_coef: multiplier applied to the mean-squared error.

    Returns:
        A callable `value_loss(new_values, old_values, returns)` returning a scalar tensor.
    """
    def value_loss(new_values: torch.Tensor, old_values: torch.Tensor, returns: torch.Tensor):
        """Scaled MSE between the (optionally clipped) value prediction and the returns."""
        # The critic output is flattened to 1-D before being compared.
        flat_new = new_values.view(-1)
        if clip_vloss:
            bounded_step = torch.clamp(flat_new - old_values, -clip_coef, clip_coef)
            prediction = old_values + bounded_step
        else:
            prediction = flat_new
        return vf_coef * F.mse_loss(prediction, returns)
    return value_loss


def entropy_loss_factory(ent_coef: float):
    """Build the entropy bonus term.

    Args:
        ent_coef: weight of the entropy term.

    Returns:
        A callable `entropy_loss(entropy)` returning the negated mean entropy
        multiplied by `ent_coef`.
    """
    def entropy_loss(entropy: torch.Tensor):
        """Negated mean entropy scaled by the configured coefficient."""
        mean_entropy = entropy.mean()
        return -mean_entropy * ent_coef
    return entropy_loss

# Model

In [None]:
from copy import deepcopy
from os import truncate


class PPOModuleBase(nn.Module):
    """Configurable multi-layer perceptron shared by the actor and the critic.

    `mlp_cfg` keys read here:
        'channels': list of hidden layer widths (required).
        'use_layer_norm': insert `nn.LayerNorm` after each hidden linear layer (default False).
        'act_func': name of an activation class in `torch.nn` (default 'ReLU').
    """

    def __init__(self, mlp_cfg: dict, inp_channels: int, out_channels: int, *args, **kwargs):
        """Build the network mapping `inp_channels` features to `out_channels` outputs."""
        super().__init__()

        self.net = self._create_mlp(mlp_cfg, inp_channels, out_channels)

    def _create_mlp(self, mlp_cfg, inp_channels, out_channels):
        """Assemble hidden blocks (Linear -> norm/identity -> activation) followed by a plain Linear head."""
        widths = [inp_channels] + mlp_cfg['channels'] + [out_channels]
        with_norm = mlp_cfg.get('use_layer_norm', False)
        activation_name = mlp_cfg.get('act_func', 'ReLU')

        hidden_blocks = []
        # Pair every layer width with the next one, leaving the last pair for the output head.
        for fan_in, fan_out in zip(widths[:-2], widths[1:-1]):
            linear = nn.Linear(fan_in, fan_out, bias=True)
            norm = nn.LayerNorm(fan_out) if with_norm else nn.Identity()
            activation = getattr(nn, activation_name)()
            hidden_blocks.append(nn.Sequential(linear, norm, activation))

        # The head has no normalisation or activation.
        head = nn.Linear(widths[-2], widths[-1], bias=True)
        return nn.Sequential(*hidden_blocks, head)

    def forward(self, x: torch.Tensor):
        """Apply the MLP to `x`."""
        return self.net(x)


class PPOActor(PPOModuleBase):
    """Policy network producing one output per discrete action.

    The input width is the product of the single-environment observation shape
    and the output width is `envs.single_action_space.n`.
    """

    def __init__(self, mlp_cfg: dict, envs: gym.vector.SyncVectorEnv):
        """Size the MLP from the observation and action spaces of `envs`."""
        obs_features = math.prod(envs.single_observation_space.shape)
        action_count = envs.single_action_space.n
        super().__init__(mlp_cfg, obs_features, action_count)


class PPOCritic(PPOModuleBase):
    """Value network producing a single scalar output per input row.

    The input width is the product of the single-environment observation shape.
    """

    def __init__(self, mlp_cfg: dict, envs: gym.vector.SyncVectorEnv):
        """Size the MLP from the observation space of `envs`, with one output unit."""
        obs_features = math.prod(envs.single_observation_space.shape)
        super().__init__(mlp_cfg, obs_features, 1)


class PPOModel(nn.Module):
    """Actor/critic pair for discrete-action PPO plus a GAE helper."""

    def __init__(self, envs: gym.vector.SyncVectorEnv, actor_cfg: dict, critic_cfg: dict, *args, **kwargs):
        """Instantiate the actor and the critic from their configs, both sized from `envs`."""
        super().__init__()

        self.actor: PPOActor = PPOActor(envs=envs, **actor_cfg)
        self.critic: PPOCritic = PPOCritic(envs=envs, **critic_cfg)

    def get_action(self, obs: torch.Tensor, act: Optional[torch.Tensor] = None, greedy: bool = False):
        """Choose or evaluate actions for `obs`.

        Returns:
            If `greedy` is true: the arg-max action indices only.
            Otherwise: a tuple `(action, log_prob, entropy)`, where `action` is
            `act` when given, else a sample from the categorical policy.
        """
        logits = self.actor(obs)
        if not greedy:
            policy = Categorical(logits=logits)
            chosen = policy.sample() if act is None else act
            return chosen, policy.log_prob(chosen), policy.entropy()

        action_probs = F.softmax(logits, dim=-1)
        return torch.argmax(action_probs, dim=-1)

    def get_value(self, obs: torch.Tensor):
        """Return the critic output for `obs`."""
        return self.critic(obs)

    def forward(self, obs: torch.Tensor, act: torch.Tensor = None, greedy: bool = False):
        """Run the actor (and, unless `greedy`, the critic) on `obs`.

        Returns:
            Greedy mode: the arg-max actions.
            Otherwise: `(action, log_prob, entropy, value)`.
        """
        if greedy:
            return self.get_action(obs, act, greedy)

        act, log_prob, entropy = self.get_action(obs, act)
        return act, log_prob, entropy, self.get_value(obs)

    @torch.no_grad()
    def estimate_returns_and_advantages(
            self,
            rewards: torch.Tensor,
            values: torch.Tensor,
            dones: torch.Tensor,
            next_obs: torch.Tensor,
            next_done: torch.Tensor,
            num_steps: int,
            gamma: float,
            gae_lambda: float):
        """Compute generalised advantage estimates and the matching returns.

        `rewards`, `values` and `dones` are indexed by time step along their
        first axis; `next_obs` / `next_done` describe the state following the
        last stored step and are used to bootstrap the final value.

        Returns:
            `(returns, advantages)` where `returns = advantages + values`.
        """
        bootstrap_values = self.get_value(next_obs).reshape(1, -1)
        advantages = torch.zeros_like(rewards)
        running_gae = 0
        final_step = num_steps - 1
        # Walk backwards so each step can reuse the accumulated estimate of its successor.
        for t in range(final_step, -1, -1):
            if t == final_step:
                done_after, value_after = next_done, bootstrap_values
            else:
                done_after, value_after = dones[t + 1], values[t + 1]
            # Mask that zeroes the bootstrap and the carried estimate when the next state starts a new episode.
            alive = torch.logical_not(done_after)
            td_error = rewards[t] + gamma * value_after * alive - values[t]
            running_gae = td_error + gamma * gae_lambda * alive * running_gae
            advantages[t] = running_gae
        returns = advantages + values
        return returns, advantages


class PPOLightning(LightningModule):
    """LightningModule that trains a `PPOModel` on a vectorised Gymnasium environment.

    A rollout is collected with the current policy in `_load_data`, flattened
    into `Experience` tuples stored in a `ReplayBuffer`, and consumed by
    `training_step` through `train_dataloader`. Every
    `running_cfg['update_interval']` epochs the buffer is cleared and a new
    rollout is collected at the start of the following epoch.

    Args:
        env_cfg: keyword arguments for `gym.make`, plus a `'type'` entry with the environment id.
        data_cfg: keys read here are `'total_timestep_n'` (steps per rollout),
            `'sample_timestep_n'` (experiences drawn per epoch) and `'batch_size'`
            (used both as the number of parallel training environments and as
            the DataLoader batch size).
        ppo_cfg: keyword arguments forwarded to `PPOModel` in addition to `envs`.
        loss_cfg: keyword arguments for the three loss factories, keyed by
            `'policy_loss'`, `'value_loss'` and `'entropy_loss'`.
        optim_cfg: `'lr'` is used as the Adam learning rate.
        running_cfg: `'seed'`, `'log_root'`, `'gamma'`, `'gae_lambda'`,
            `'normalize_advantages'`, `'update_interval'` and `'update_steps'`.
        **torchmetrics_kwargs: forwarded to every `MeanMetric`.

    Note:
        The dict defaults are shared between instances; this class only reads
        them, so they must not be mutated by callers either.
    """

    def __init__(
            self,
            env_cfg: dict,
            data_cfg: dict,
            ppo_cfg: dict,
            loss_cfg: dict = {
                'policy_loss': {'clip_coef': 0.2},
                'value_loss': {'clip_coef': 0.2, 'clip_vloss': False, 'vf_coef': 1.0},
                'entropy_loss': {'ent_coef': 0.0}},
            optim_cfg: dict = {'lr': 1e-4},
            running_cfg: dict = {
                'seed': 42,
                'log_root': '',
                'gamma': 0.99,
                'gae_lambda': 0.95,
                'normalize_advantages': False,
                'update_interval': 1,
                'update_steps': 10},
            **torchmetrics_kwargs):
        super().__init__()
        self.save_hyperparameters()

        self._load_loss(loss_cfg)
        self.train_envs = self._load_env(env_cfg, data_cfg,
                                         os.path.join(self.hparams.running_cfg.get('log_root', ''), 'video', 'train'),
                                         self.hparams.running_cfg.get('seed', 42))
        # Validation always uses a single environment.
        self.val_env = self._load_env(self.hparams.env_cfg, {'batch_size': 1},
                                      os.path.join(self.hparams.running_cfg.get('log_root', ''), 'video', 'val'),
                                      self.hparams.running_cfg.get('seed', 42))

        self.replay_buffer = self._load_replay_buffer(data_cfg)
        self.ppo_model: PPOModel = PPOModel(envs=self.train_envs, **ppo_cfg)

        # Running means of the loss terms, reported once per epoch.
        self.avg_pg_loss = MeanMetric(**torchmetrics_kwargs)
        self.avg_value_loss = MeanMetric(**torchmetrics_kwargs)
        self.avg_ent_loss = MeanMetric(**torchmetrics_kwargs)

        # Set at the end of an epoch to request a fresh rollout at the start of the next one.
        self._reset_data = False

    def forward(self, obs: torch.Tensor, act: Optional[torch.Tensor] = None, *args, **kwargs):
        """Delegate to the wrapped `PPOModel`."""
        return self.ppo_model(obs, act, *args, **kwargs)

    def on_fit_start(self):
        """Allocate the per-step rollout tensors on the module's device and collect the first rollout."""
        num_steps = self.hparams.data_cfg.get('total_timestep_n', 1000)

        self.observation_shape = self.train_envs.single_observation_space.shape
        self.action_shape = self.train_envs.single_action_space.shape

        # Every buffer is laid out as (step, env, ...).
        self.observations_buffer = torch.zeros((num_steps, self.train_envs.num_envs) + self.observation_shape,
                                               device=self.device)
        self.actions_buffer = torch.zeros((num_steps, self.train_envs.num_envs) + self.action_shape, device=self.device)
        self.log_probs_buffer = torch.zeros((num_steps, self.train_envs.num_envs), device=self.device)
        self.rewards_buffer = torch.zeros((num_steps, self.train_envs.num_envs), device=self.device)
        self.dones_buffer = torch.zeros((num_steps, self.train_envs.num_envs), device=self.device)
        self.values_buffer = torch.zeros((num_steps, self.train_envs.num_envs), device=self.device)

        self._load_data()

    def on_train_epoch_start(self):
        """Collect a new rollout if the previous epoch requested one."""
        if self._reset_data:
            self._load_data()
            self._reset_data = False

    def training_step(self, batch: dict[str, torch.Tensor]):
        """Compute the PPO loss (policy + entropy + value) for one batch of experiences.

        Args:
            batch: collated dict with the `Experience` field names as keys.

        Returns:
            Scalar tensor: the sum of the three loss terms.
        """
        _, new_log_prob, entropy, new_value = self(batch['observations'], batch['actions'])
        log_ratio = new_log_prob - batch['log_probs']
        ratio = torch.exp(log_ratio)

        # policy loss
        advantages = batch['advantages']
        # Normalisation needs at least two elements: the standard deviation of a
        # single value is NaN and would poison the loss.
        if self.hparams.running_cfg.get('normalize_advantages', False) and advantages.numel() > 1:
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
        # Feed the (possibly normalised) advantages to the loss; passing
        # batch['advantages'] here would silently discard the normalisation.
        p_loss = self.policy_loss(advantages, ratio)

        # value loss
        v_loss = self.value_loss(new_value, batch['values'], batch['returns'])

        # entropy loss
        e_loss = self.entropy_loss(entropy)

        # update metrics
        self.avg_pg_loss.update(p_loss)
        self.avg_value_loss.update(v_loss)
        self.avg_ent_loss.update(e_loss)

        return p_loss + e_loss + v_loss

    def validation_step(self, *args, **kwargs):
        """Play one greedy episode in the validation environment.

        Returns:
            Dict with a dummy `'loss'`, the accumulated reward under
            `'val/reward'` and the episode length under `'val/step'`.
        """
        # NOTE(review): the returned values are not passed to `self.log` here, so they
        # presumably never reach the logger — confirm whether that is intended.
        step = 0
        done = False
        cumulative_rew = 0
        next_obs = torch.tensor(self.val_env.reset(seed=self.hparams.running_cfg.get('seed', 42))[0], device=self.device)
        while not done:
            action = self(next_obs, greedy=True)
            next_obs, reward, terminated, truncated, _ = self.val_env.step(action.cpu().numpy())
            # The validation env holds a single sub-environment, so these are one-element arrays.
            done = terminated or truncated
            cumulative_rew += reward
            next_obs = torch.tensor(next_obs, device=self.device)
            step += 1

        return {'loss': 0, 'val/reward': cumulative_rew, 'val/step': step}

    def on_train_epoch_end(self):
        """Schedule a new rollout when due, then log and reset the epoch's mean losses."""
        update_interval = self.hparams.running_cfg.get('update_interval', 1)
        if (self.current_epoch + 1) % update_interval == 0:
            self.replay_buffer.clear()
            self._reset_data = True

        self.logger.log_metrics(
            {
                "Loss/policy_loss": self.avg_pg_loss.compute(),
                "Loss/value_loss": self.avg_value_loss.compute(),
                "Loss/entropy_loss": self.avg_ent_loss.compute(),
            },
            self.global_step)
        self.reset_metrics()

    def on_fit_end(self):
        """Close the training and validation environments."""
        self.train_envs.close()
        self.val_env.close()

    def reset_metrics(self):
        """Reset the three running loss means."""
        self.avg_pg_loss.reset()
        self.avg_value_loss.reset()
        self.avg_ent_loss.reset()

    def configure_optimizers(self):
        """Return an Adam optimiser over all parameters with the configured learning rate."""
        lr = self.hparams.optim_cfg.get('lr', 1e-4)
        return torch.optim.Adam(self.parameters(), lr=lr, eps=1e-4)

    def _load_loss(self, loss_cfg: dict):
        """Instantiate the policy, value and entropy loss callables from `loss_cfg`."""
        self.policy_loss = policy_loss_factory(**loss_cfg['policy_loss'])
        self.value_loss = value_loss_factory(**loss_cfg['value_loss'])
        self.entropy_loss = entropy_loss_factory(**loss_cfg['entropy_loss'])

    @torch.no_grad()
    def _load_data(self):
        """Collect one rollout with the current policy and fill the replay buffer.

        Steps the training environments `total_timestep_n` times, records
        observations, actions, log-probabilities, rewards, done flags and value
        estimates, logs mean episode reward/length, computes returns and
        advantages, and appends one `Experience` per (step, env) pair.
        """
        self.eval()

        env_eps = 0
        env_rew = 0
        env_eps_len = 0

        num_steps = self.hparams.data_cfg.get('total_timestep_n', 1000)

        gamma = self.hparams.running_cfg.get('gamma', 0.99)
        gae_lambda = self.hparams.running_cfg.get('gae_lambda', 0.95)

        next_observations = torch.tensor(self.train_envs.reset()[0], device=self.device)
        next_dones = torch.zeros(self.train_envs.num_envs, device=self.device)
        for step in range(0, num_steps):
            self.observations_buffer[step] = next_observations
            self.dones_buffer[step] = next_dones

            actions, log_probs, _, values = self(next_observations)
            self.values_buffer[step] = values.flatten()
            self.actions_buffer[step] = actions
            self.log_probs_buffer[step] = log_probs

            next_observations, rewards, dones, truncateds, info = self.train_envs.step(actions.cpu().numpy())
            # Termination and truncation are both treated as the end of an episode.
            dones = torch.logical_or(torch.tensor(dones), torch.tensor(truncateds))
            self.rewards_buffer[step] = torch.tensor(rewards.astype(np.float32), device=self.device).view(-1)

            next_observations = torch.tensor(next_observations, device=self.device)
            next_dones = dones.to(self.device)

            episode = info.get('episode', None)
            if episode:
                # NOTE(review): Gymnasium vector envs are expected to expose a boolean
                # '_episode' mask marking which sub-environments reported statistics on
                # this step — confirm for the installed version. When the mask is present,
                # only those entries are counted; otherwise every entry is counted.
                finished = info.get('_episode', None)
                for env_idx, (r, l) in enumerate(zip(episode['r'], episode['l'])):
                    if finished is not None and not finished[env_idx]:
                        continue
                    env_eps += 1
                    env_rew += r
                    env_eps_len += l

        # The small epsilon avoids a division by zero when no episode finished during the rollout.
        self.logger.log_metrics({"env/mean_episodes_reward": env_rew / (env_eps + 1e-8),
                                 "env/mean_episodes_length": env_eps_len / (env_eps + 1e-8)},
                                self.current_epoch)

        returns, advantages = self.ppo_model.estimate_returns_and_advantages(self.rewards_buffer, self.values_buffer,
                                                                             self.dones_buffer, next_observations,
                                                                             next_dones, num_steps, gamma, gae_lambda)

        # Merge the (step, env) axes so that every row is one transition.
        obs_data = self.observations_buffer.reshape((-1,) + self.observation_shape)
        log_prob_data = self.log_probs_buffer.reshape(-1)
        act_data = self.actions_buffer.reshape((-1,) + self.action_shape)
        adv_data = advantages.reshape(-1)
        ret_data = returns.reshape(-1)
        val_data = self.values_buffer.reshape(-1)
        for obs, log_prob, act, adv, ret, val in zip(obs_data, log_prob_data, act_data, adv_data, ret_data, val_data):
            self.replay_buffer.append(Experience(
                observations=obs,
                actions=act,
                values=val,
                returns=ret,
                advantages=adv,
                log_probs=log_prob))

        self.train()

    @classmethod
    def _load_env(cls, env_cfg: dict, data_cfg: dict, log_root: Optional[str] = None,
                  seed: int = 42) -> gym.vector.SyncVectorEnv:
        """Create a synchronous vector environment.

        Args:
            env_cfg: `gym.make` keyword arguments plus a `'type'` entry with the
                environment id; the dict is copied, not modified.
            data_cfg: `'batch_size'` gives the number of sub-environments (default 1).
            log_root: if truthy, the first sub-environment records videos into this directory.
            seed: seed applied to the action and observation spaces of every sub-environment.
        """
        def make_env(env_type, env_cfg: dict, idx: int):
            def thunk():
                env = gym.make(env_type, **env_cfg)
                env = gym.wrappers.RecordEpisodeStatistics(env)
                # Only the first sub-environment records video.
                if idx == 0 and log_root:
                    env = gym.wrappers.RecordVideo(env, log_root, disable_logger=True, )
                env.action_space.seed(seed)
                env.observation_space.seed(seed)
                return env
            return thunk

        curr_env_cfg = deepcopy(env_cfg)
        batch_size = data_cfg.get('batch_size', 1)
        env_type = curr_env_cfg.pop('type')
        envs = gym.vector.SyncVectorEnv([make_env(env_type, curr_env_cfg, idx) for idx in range(batch_size)])
        return envs

    @classmethod
    def _load_replay_buffer(cls, data_cfg: dict) -> ReplayBuffer:
        """Create a replay buffer large enough to hold one complete rollout.

        A rollout contributes one experience per (step, env) pair, i.e.
        `total_timestep_n * batch_size` entries. A capacity of only
        `total_timestep_n` would make the buffer evict all but the last
        `total_timestep_n` transitions of every rollout.
        """
        total_timestep_n = data_cfg.get('total_timestep_n', 1000)
        num_envs = data_cfg.get('batch_size', 1)
        buffer = ReplayBuffer(total_timestep_n * num_envs)
        return buffer

    def train_dataloader(self):
        """Return a DataLoader that streams randomly sampled experiences from the replay buffer.

        The number of experiences per epoch is at least
        `update_steps * batch_size`, so every epoch runs at least `update_steps`
        full batches.
        """
        batch_size = self.hparams.data_cfg.get('batch_size', 1)
        sample_timestep_n = max(self.hparams.data_cfg['sample_timestep_n'],
                                self.hparams.running_cfg['update_steps'] * batch_size)
        dataset = RLDataset(self.replay_buffer, sample_timestep_n)
        return DataLoader(dataset, batch_size=batch_size)

    def val_dataloader(self):
        """Return a one-item DataLoader; it only triggers `validation_step`, which ignores the batch."""
        dataset = FakeDataset(list(range(1)))
        return DataLoader(dataset, batch_size=1)

# Train & Test

In [None]:
env_type = 'CartPole-v1'
log_root = 'logs'

tb_logger = pl_loggers.TensorBoardLogger(save_dir=log_root, name=env_type)

model_cfg = dict(
    env_cfg=dict(
        type=env_type,
        render_mode="rgb_array"),
    data_cfg=dict(
        total_timestep_n=2048,
        sample_timestep_n=1024,
        batch_size=16),
    ppo_cfg=dict(
        actor_cfg=dict(
            mlp_cfg=dict(
                channels=[64, 64],
                use_layer_norm=True,
                act_func='ReLU')),
        critic_cfg=dict(
            mlp_cfg=dict(
                channels=[64, 64],
                use_layer_norm=True,
                act_func='ReLU'))),
    loss_cfg=dict(
        policy_loss=dict(clip_coef=0.2),
        value_loss=dict(
            clip_coef=0.2,
            clip_vloss=False,
            vf_coef=1.0),
        entropy_loss=dict(ent_coef=0.0)),
    optim_cfg=dict(lr=1e-4),
    running_cfg=dict(
        seed=42,
        log_root=tb_logger.log_dir,
        gamma=0.99,
        gae_lambda=0.95,
        normalize_advantages=False,
        update_interval=10,
        update_steps=10),
)

# Seed the global RNGs before the model is built so that weight initialisation,
# action sampling and replay-buffer sampling use the configured seed.
seed_everything(model_cfg['running_cfg']['seed'])

model = PPOLightning(**model_cfg)
# Validation runs every 100 training batches, independent of epoch boundaries.
trainer = Trainer(max_steps=5000, logger=tb_logger, val_check_interval=100, check_val_every_n_epoch=None)
trainer.fit(model)