In [None]:
from typing import Optional
from collections import namedtuple, deque
import math
import random

import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.categorical import Categorical
from torch.utils.data import DataLoader
from torch.utils.data.dataset import IterableDataset
from torchmetrics import MeanMetric

from lightning.pytorch import LightningModule
from lightning.pytorch.trainer import Trainer

In [36]:

# One stored transition; ``ReplayBuffer.sample`` turns each record into a dict
# keyed by these field names.
Experience = namedtuple('Experience', ['state', 'action', 'reward', 'next_state', 'done', 'log_prob'])


class ReplayBuffer(object):
    """Bounded FIFO store of ``Experience`` records with uniform random sampling.

    Once ``capacity`` records are held, appending a new one silently evicts the
    oldest (``deque(maxlen=...)`` semantics).
    """

    def __init__(self, capacity: int = 1000):
        self.buffer = deque(maxlen=capacity)

    def __len__(self):
        """Return the number of records currently stored."""
        return len(self.buffer)

    def append(self, experience):
        """Store one record, evicting the oldest if the buffer is full."""
        self.buffer.append(experience)

    def sample(self, sample_size: int = 1):
        """Lazily yield ``sample_size`` uniformly drawn records as dicts.

        Records are drawn without replacement whenever the buffer holds at
        least ``sample_size`` of them; replacement is used only when more
        samples are requested than are stored.

        This is a generator: the indices are drawn (and ``numpy`` raises
        ``ValueError`` for an empty buffer) on the first ``next()`` call,
        not when ``sample`` is called.

        Yields:
            dict mapping every ``Experience`` field name to the stored value.
        """
        stored = len(self)
        # Replacement is only needed when the request exceeds what is stored;
        # asking for exactly ``stored`` items must not produce duplicates.
        replace = sample_size > stored
        indices = np.random.choice(stored, sample_size, replace=replace)
        # collate experiences
        for idx in indices:
            yield {key: getattr(self.buffer[idx], key) for key in Experience._fields}

    def clear(self):
        """Drop every stored record."""
        self.buffer.clear()


class RLDataset(IterableDataset):
    """Iterable dataset that streams randomly sampled records from a replay buffer.

    Args:
        buffer: object whose ``sample(n)`` method yields the records to stream.
        sample_step_num: number of records requested from the buffer per pass.
    """

    def __init__(self, buffer: ReplayBuffer, sample_step_num: int = 1):
        self.sample_step_num = sample_step_num
        self.buffer = buffer

    def __iter__(self):
        # Each pass over the dataset asks the buffer for a fresh random draw.
        yield from self.buffer.sample(self.sample_step_num)


In [None]:
def policy_loss(advantages: torch.Tensor, ratio: torch.Tensor, cliip_coef: float):
    """Clipped surrogate policy loss, averaged over all elements.

    Args:
        advantages: advantage estimates.
        ratio: new-policy / old-policy probability ratios.
        cliip_coef: clipping half-width; ``ratio`` is clamped to
            ``[1 - cliip_coef, 1 + cliip_coef]`` for the clipped term.

    Returns:
        Scalar tensor: mean of the element-wise larger of the unclipped and
        clipped losses.
    """
    lower, upper = 1 - cliip_coef, 1 + cliip_coef
    neg_advantages = -advantages
    unclipped = neg_advantages * ratio
    clipped = neg_advantages * ratio.clamp(lower, upper)
    # Taking the larger loss keeps the more pessimistic of the two objectives.
    worst_case = torch.max(unclipped, clipped)
    return worst_case.mean()


def value_loss(new_values: torch.Tensor, old_values: torch.Tensor, returns: torch.Tensor, clip_coef: float,
               clip_vloss: bool, vf_coef: float):
    """Scaled mean-squared-error value loss, optionally on a clipped prediction.

    All tensor arguments are flattened to 1-D before use so that a trailing
    singleton dimension (e.g. critic output of shape ``(N, 1)`` against
    returns of shape ``(N,)``) cannot silently broadcast to ``(N, N)``.

    Args:
        new_values: value predictions from the current critic.
        old_values: value predictions recorded at rollout time; only read when
            ``clip_vloss`` is true.
        returns: regression targets.
        clip_coef: maximum absolute change allowed from ``old_values`` when
            ``clip_vloss`` is true.
        clip_vloss: if true, the prediction is ``old_values`` plus the clamped
            difference ``new_values - old_values``.
        vf_coef: multiplier applied to the MSE.

    Returns:
        Scalar tensor ``vf_coef * mse(prediction, returns)``.
    """
    # reshape (not view) so non-contiguous inputs are accepted as well.
    new_values = new_values.reshape(-1)
    returns = returns.reshape(-1)
    if not clip_vloss:
        values_pred = new_values
    else:
        old_values = old_values.reshape(-1)
        values_pred = old_values + torch.clamp(new_values - old_values, -clip_coef, clip_coef)
    return vf_coef * F.mse_loss(values_pred, returns)


def entropy_loss(entropy: torch.Tensor, ent_coef: float):
    """Return the negated mean entropy scaled by ``ent_coef``.

    Minimising this term therefore increases the mean entropy.
    """
    mean_entropy = entropy.mean()
    return torch.neg(mean_entropy) * ent_coef


class Preprocessor(nn.Module):
    """Hook for transforming a sampled batch before the loss is computed.

    The base implementation is a pass-through. Constructor arguments are
    accepted and ignored so that any config can be forwarded to it.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()

    def forward(self, batch: dict[str, torch.Tensor]):
        """Return ``batch`` unchanged.

        Previously this returned ``None``, which made every consumer that
        indexes the result (e.g. ``batch['obs']``) fail with a ``TypeError``.
        """
        return batch


class PPOModuleBase(nn.Module):
    """Configurable MLP used as the backbone of the actor and critic.

    Args:
        mlp_cfg: dict with
            ``'channels'`` (required): iterable of hidden-layer widths;
            ``'use_layer_norm'`` (default ``False``): add ``nn.LayerNorm``
            after every hidden linear layer;
            ``'act_func'`` (default ``'ReLU'``): name of an activation class
            in ``torch.nn``, instantiated with no arguments.
        inp_channels: input feature size.
        out_channels: output feature size.
    """

    def __init__(self, mlp_cfg: dict, inp_channels: int, out_channels: int, *args, **kwargs):
        super().__init__()

        self.net = self._create_mlp(mlp_cfg, inp_channels, out_channels)

    def _create_mlp(self, mlp_cfg, inp_channels, out_channels):
        """Build the stack of ``Linear -> (LayerNorm) -> activation`` blocks.

        The final block is a bare linear projection. Normalising or activating
        the output would clamp logits/values (e.g. ReLU forces them to be
        non-negative) and ``LayerNorm`` over a single output feature collapses
        it to a constant.
        """
        # Unpacking accepts any iterable of hidden sizes (list or tuple).
        channels = [inp_channels, *mlp_cfg['channels'], out_channels]
        use_layer_norm = mlp_cfg.get('use_layer_norm', False)
        act_cls = getattr(nn, mlp_cfg.get('act_func', 'ReLU'))

        output_idx = len(channels) - 2
        blocks = []
        for idx, (in_chn, out_chn) in enumerate(zip(channels[:-1], channels[1:])):
            is_output = idx == output_idx
            # Every block keeps three sub-modules (Identity placeholders where
            # a stage is disabled) so sub-module indices stay uniform.
            # NOTE(review): all linear layers are bias-free, as in the original;
            # confirm this is intended when ``use_layer_norm`` is off.
            blocks.append(nn.Sequential(
                nn.Linear(in_chn, out_chn, bias=False),
                nn.LayerNorm(out_chn) if use_layer_norm and not is_output else nn.Identity(),
                nn.Identity() if is_output else act_cls(),
            ))
        return nn.Sequential(*blocks)

    def forward(self, x: torch.Tensor):
        """Apply the MLP to ``x`` (last dimension must equal ``inp_channels``)."""
        return self.net(x)


class PPOActor(PPOModuleBase):
    """Policy network: one output per discrete action of the environment.

    The input width is the product of the single-environment observation
    shape; the output width is ``envs.single_action_space.n``.
    """

    def __init__(self, mlp_cfg: dict, envs: gym.vector.SyncVectorEnv):
        obs_dim = math.prod(envs.single_observation_space.shape)
        num_actions = envs.single_action_space.n
        super().__init__(mlp_cfg, inp_channels=obs_dim, out_channels=num_actions)


class PPOCritic(PPOModuleBase):
    """Value network: a single scalar output per observation.

    The input width is the product of the single-environment observation
    shape.
    """

    def __init__(self, mlp_cfg: dict, envs: gym.vector.SyncVectorEnv):
        obs_dim = math.prod(envs.single_observation_space.shape)
        super().__init__(mlp_cfg, inp_channels=obs_dim, out_channels=1)


class PPOModule(nn.Module):
    """Container for the environment / actor / critic configuration.

    Args:
        env_cfg: environment config; ``_create_env`` expects a ``'type'`` key
            (the id passed to ``gym.make``), all other keys are forwarded to
            ``gym.make`` as keyword arguments.
        actor_cfg: stored as-is.
        critic_cfg: stored as-is.
        data_cfg: optional; ``_create_env`` reads ``data_cfg['batch_size']``
            as the number of parallel environments. Added with a ``None``
            default because ``_create_env`` previously read an attribute that
            was never assigned.
    """

    def __init__(self, env_cfg: dict, actor_cfg: dict, critic_cfg: dict, data_cfg: Optional[dict] = None):
        super().__init__()

        self.env_cfg = env_cfg
        self.actor_cfg = actor_cfg
        self.critic_cfg = critic_cfg
        self.data_cfg = data_cfg

    def _create_env(self, env_cfg: dict) -> gym.vector.SyncVectorEnv:
        """Build a synchronous vector env with ``data_cfg['batch_size']`` copies.

        Raises:
            ValueError: if no ``data_cfg`` was supplied to the constructor.
        """
        if self.data_cfg is None:
            raise ValueError("data_cfg with a 'batch_size' entry is required to create the environments")
        num_envs = self.data_cfg['batch_size']

        # Work on a copy so the caller's config keeps its 'type' entry.
        make_kwargs = dict(env_cfg)
        env_id = make_kwargs.pop('type')

        def make_env():
            return gym.make(env_id, **make_kwargs)

        # SyncVectorEnv takes environment *factories*, not environment instances.
        return gym.vector.SyncVectorEnv([make_env for _ in range(num_envs)])



class PPO(LightningModule):
    """Proximal Policy Optimisation agent for discrete action spaces.

    At the start of every training epoch the current policy is rolled out in a
    vectorised environment and the transitions are stored in a replay buffer;
    ``training_step`` then optimises the clipped policy loss, the value loss
    and the entropy bonus on batches served by ``train_dataloader``. The buffer
    is cleared at the end of the epoch.

    Args:
        env_cfg: environment config; ``'type'`` is the id given to
            ``gym.make``, remaining keys are forwarded as keyword arguments.
        preprocessor_cfg: keyword arguments for ``Preprocessor``.
        actor_cfg: keyword arguments for ``PPOActor`` (besides ``envs``).
        critic_cfg: keyword arguments for ``PPOCritic`` (besides ``envs``).
        data_cfg: must provide ``'batch_size'`` (also used as the number of
            parallel environments), ``'buffer_size'`` and ``'num_steps'``.
        optim_cfg: ``'type'`` names a class in ``torch.optim``; remaining keys
            are forwarded to its constructor.
        torchmetric_cfg: keyword arguments for each ``MeanMetric``.
        normalize_advantages: standardise advantages per batch before the
            policy loss.
        clip_coef: clipping range for the policy ratio (and for the value
            prediction when ``clip_vloss`` is set).
        clip_vloss: use the clipped value prediction in the value loss.
        vf_coef: weight of the value loss.
        ent_coef: weight of the entropy loss.
    """

    def __init__(
            self,
            env_cfg: dict,
            preprocessor_cfg: dict,
            actor_cfg: dict,
            critic_cfg: dict,
            data_cfg: dict,
            optim_cfg: dict,
            torchmetric_cfg: dict,
            normalize_advantages: bool = True,
            clip_coef: float = 0.2,
            clip_vloss: bool = False,
            vf_coef: float = 1.0,
            ent_coef: float = 0.0,

            *args, **kwargs):
        super().__init__()
        self.save_hyperparameters()

        self.env_cfg = env_cfg
        self.actor_cfg = actor_cfg
        self.critic_cfg = critic_cfg
        self.data_cfg = data_cfg
        self.optim_cfg = optim_cfg

        # Must come after ``self.data_cfg`` is set: the env count is read from it.
        self.envs: gym.vector.SyncVectorEnv = self._create_env(env_cfg)

        self.preprocessor = Preprocessor(**preprocessor_cfg)
        self.actor = PPOActor(**actor_cfg, envs=self.envs)
        self.critic = PPOCritic(**critic_cfg, envs=self.envs)

        # Running means of each loss term, logged once per epoch.
        self.avg_p_loss = MeanMetric(**torchmetric_cfg)
        self.avg_v_loss = MeanMetric(**torchmetric_cfg)
        self.avg_e_loss = MeanMetric(**torchmetric_cfg)

        self.normalize_advantages = normalize_advantages
        self.clip_coef = clip_coef
        self.clip_vloss = clip_vloss
        self.vf_coef = vf_coef
        self.ent_coef = ent_coef

        self.buffer = ReplayBuffer(data_cfg['buffer_size'])

    def get_action(self, obs: torch.Tensor, action: Optional[torch.Tensor] = None):
        """Sample (or score) an action under the current policy.

        Args:
            obs: observations fed to the actor.
            action: if given, it is scored instead of sampling a new one.

        Returns:
            Tuple ``(action, log_prob(action), entropy)`` of the categorical
            distribution defined by the actor's logits.
        """
        logits = self.actor(obs)
        distribution = Categorical(logits=logits)
        if action is None:
            action = distribution.sample()
        return action, distribution.log_prob(action), distribution.entropy()

    def get_greedy_action(self, obs: torch.Tensor):
        """Return the most probable action for each observation."""
        logits = self.actor(obs)
        probs = F.softmax(logits, dim=-1)
        return torch.argmax(probs, dim=-1)

    def get_value(self, obs: torch.Tensor):
        """Return the critic's output for ``obs``."""
        return self.critic(obs)

    def get_action_and_value(self, obs: torch.Tensor, action: Optional[torch.Tensor] = None):
        """Return ``(action, log_prob, entropy, value)`` for ``obs``."""
        action, log_prob, entropy = self.get_action(obs, action)
        value = self.get_value(obs)
        return action, log_prob, entropy, value

    def forward(self, obs: torch.Tensor, action: Optional[torch.Tensor] = None):
        """Alias for :meth:`get_action_and_value`."""
        return self.get_action_and_value(obs, action)

    @torch.no_grad()
    def estimate_returns_and_advantages(self,
                                        rewards: torch.Tensor,
                                        values: torch.Tensor,
                                        dones: torch.Tensor,
                                        next_obs: torch.Tensor,
                                        next_done: torch.Tensor,
                                        num_steps: int,
                                        gamma: float,
                                        gae_lambda: float):
        """Compute returns and advantages with a backward lambda-weighted recursion.

        ``rewards``, ``values`` and ``dones`` are indexed by time step on
        their first dimension (assumes shape ``(num_steps, num_envs)`` --
        TODO confirm against the caller).

        Args:
            rewards: per-step rewards.
            values: per-step critic predictions.
            dones: per-step done flags; ``dones[t + 1]`` masks the bootstrap
                from step ``t + 1``.
            next_obs: observation following the last step, used to bootstrap.
            next_done: done flag following the last step.
            num_steps: number of time steps to process.
            gamma: discount factor.
            gae_lambda: weight of the accumulated advantage term.

        Returns:
            Tuple ``(returns, advantages)`` where ``returns = advantages + values``.
        """
        next_value = self.get_value(next_obs).reshape(1, -1)
        advantages = torch.zeros_like(rewards)
        last_gae_lam = 0
        for t in reversed(range(num_steps)):
            if t == num_steps - 1:
                # The last step bootstraps from the value of ``next_obs``.
                next_non_terminal = torch.logical_not(next_done)
                next_values = next_value
            else:
                next_non_terminal = torch.logical_not(dones[t + 1])
                next_values = values[t + 1]
            # One-step TD error; the mask removes the bootstrap across episode ends.
            delta = rewards[t] + gamma * next_values * next_non_terminal - values[t]
            advantages[t] = last_gae_lam = delta + gamma * gae_lambda * next_non_terminal * last_gae_lam
        returns = advantages + values
        return returns, advantages

    def on_train_epoch_start(self):
        """Roll out the current policy for ``data_cfg['num_steps']`` steps.

        Every step appends one ``Experience`` (holding one entry per parallel
        environment) to the replay buffer. ``done`` is true when the
        environment reported either termination or truncation.
        """
        num_steps = self.data_cfg['num_steps']
        self.eval()

        nxt_obs = torch.tensor(self.envs.reset()[0], device=self.device)
        for _ in range(num_steps):
            obs = nxt_obs

            with torch.no_grad():
                action, log_prob, _ = self.get_action(obs)

            nxt_obs, reward, terminated, truncated, _info = self.envs.step(action.cpu().numpy())
            # ``step`` returns numpy arrays; convert so the next policy call
            # and the stored record both receive tensors.
            nxt_obs = torch.tensor(nxt_obs, device=self.device)
            done = torch.logical_or(torch.tensor(terminated, device=self.device),
                                    torch.tensor(truncated, device=self.device))
            reward = torch.tensor(reward, dtype=torch.float32, device=self.device).view(-1)

            # All six Experience fields must be supplied, including log_prob.
            curr_batch = Experience(obs, action, reward, nxt_obs, done, log_prob)
            self.buffer.append(curr_batch)

        # Restore training mode for the optimisation steps that follow.
        self.train()

    def training_step(self, batch: dict[str, torch.Tensor]):
        """Compute the total PPO loss for one batch.

        After preprocessing, the batch is read through the keys ``'obs'``,
        ``'actions'``, ``'log_probs'``, ``'advantages'``, ``'values'`` and
        ``'returns'``.
        NOTE(review): these are not the ``Experience`` field names yielded by
        the replay buffer; presumably the preprocessor is responsible for
        producing them -- confirm.

        Returns:
            Sum of the policy, value and entropy losses.
        """
        batch = self.preprocessor(batch)

        _, new_log_prob, entropy, new_value = self(batch['obs'], batch['actions'].long())
        log_ratio = new_log_prob - batch['log_probs']
        ratio = log_ratio.exp()

        advantages = batch['advantages']
        if self.normalize_advantages:
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        # policy loss (uses the normalised advantages when normalisation is on)
        p_loss = policy_loss(advantages, ratio, self.clip_coef)

        # value loss
        v_loss = value_loss(new_value, batch['values'], batch['returns'], self.clip_coef, self.clip_vloss, self.vf_coef)

        # entropy loss
        e_loss = entropy_loss(entropy, self.ent_coef)

        # update metrics
        self.avg_p_loss.update(p_loss)
        self.avg_v_loss.update(v_loss)
        self.avg_e_loss.update(e_loss)

        return p_loss + v_loss + e_loss

    def on_train_epoch_end(self, global_step: Optional[int] = None):
        """Log the epoch-averaged losses, then reset the metrics and the buffer.

        Args:
            global_step: step to log under. Defaults to the module's own
                ``global_step`` because the trainer invokes this hook without
                arguments.
        """
        if global_step is None:
            global_step = self.global_step

        # ``self.logger`` is None when the trainer runs without a logger.
        if self.logger is not None:
            self.logger.log_metrics(
                {
                    "Loss/policy_loss": self.avg_p_loss.compute(),
                    "Loss/value_loss": self.avg_v_loss.compute(),
                    "Loss/entropy_loss": self.avg_e_loss.compute(),
                },
                global_step)

        self.avg_p_loss.reset()
        self.avg_v_loss.reset()
        self.avg_e_loss.reset()

        self.buffer.clear()

    def configure_optimizers(self):
        """Instantiate the optimiser named by ``optim_cfg['type']``."""
        # Copy first: popping from ``self.optim_cfg`` would make a second call fail.
        optim_kwargs = dict(self.optim_cfg)
        optim_type = optim_kwargs.pop('type')
        return getattr(torch.optim, optim_type)(self.parameters(), **optim_kwargs)

    def _create_env(self, env_cfg: dict) -> gym.vector.SyncVectorEnv:
        """Build a synchronous vector env with ``data_cfg['batch_size']`` copies."""
        num_envs = self.data_cfg['batch_size']

        # Work on a copy so the caller's config (also captured by
        # ``save_hyperparameters``) keeps its 'type' entry.
        make_kwargs = dict(env_cfg)
        env_id = make_kwargs.pop('type')

        def make_env():
            return gym.make(env_id, **make_kwargs)

        # SyncVectorEnv takes environment *factories*, not environment instances.
        envs = gym.vector.SyncVectorEnv([make_env for _ in range(num_envs)])

        return envs

    def train_dataloader(self):
        """Serve batches of ``data_cfg['batch_size']`` records sampled from the buffer."""
        dataset = RLDataset(self.buffer, self.data_cfg['batch_size'])
        dataloader = DataLoader(dataset, batch_size=self.data_cfg['batch_size'])
        return dataloader

    def get_device(self, batch):
        """Return the device of the first element of ``batch``."""
        return batch[0].device
