In [None]:
# Install environment and agent
!pip install highway-env
!pip install --upgrade sympy torch

### Papers
[Proximal Policy Optimization Algorithms](https://arxiv.org/pdf/1707.06347)

[Towards Delivering a Coherent Self-Contained Explanation of Proximal Policy Optimization](https://fse.studenttheses.ub.rug.nl/25709/1/mAI_2021_BickD.pdf)

### Tutorial
[Hugging Face Deep RL Course: Proximal Policy Optimization (PPO)](https://huggingface.co/learn/deep-rl-course/unit8/introduction)

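## Policy Networks

Two actor-critic networks for PPO. Both return unnormalized action logits together with a state-value estimate:
- an MLP variant for flat observation vectors;
- a CNN variant for stacked image observations of shape `(stack, height, width)`.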

In [None]:
import gymnasium as gym
import highway_env
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

def init_layer(layer, gain=np.sqrt(2), bias_const=0.0, std=None):
    """Orthogonally initialize a layer's weights and set its bias to a constant.

    Args:
        layer: Module exposing a ``weight`` tensor and an optional ``bias``
            (e.g. ``nn.Linear`` or ``nn.Conv2d``).
        gain: Scaling factor passed to ``nn.init.orthogonal_``. The networks in
            this notebook use the default for hidden layers, ``0.01`` for the
            policy head and ``1.0`` for the value head.
        bias_const: Value written into every element of ``layer.bias``.
        std: Alias for ``gain``, accepted so that calls written as
            ``init_layer(..., std=...)`` keep working. When given, it
            overrides ``gain``.

    Returns:
        The same ``layer`` object, initialized in place, so the call can be
        nested directly inside ``nn.Sequential(...)``.
    """
    if std is not None:
        gain = std
    nn.init.orthogonal_(layer.weight, gain)
    # Layers built with bias=False have ``bias is None``; nothing to set then.
    if layer.bias is not None:
        nn.init.constant_(layer.bias, bias_const)
    return layer


# Policy network (MLP)
class MLPPolicyNetwork(nn.Module):
    """Actor-critic MLP with separate (unshared) actor and critic trunks.

    Each trunk is two ``Linear`` hidden layers of width ``h1_nodes`` with
    ``Tanh`` activations, followed by a linear output head.

    Args:
        in_states: Size of the flat observation vector (last input dimension).
        h1_nodes: Width of both hidden layers in each trunk.
        out_actions: Number of logits produced by the actor head
            (presumably one per discrete action).
    """

    def __init__(self, in_states, h1_nodes, out_actions):
        super().__init__()

        # Actor trunk. The small gain on the output layer keeps the initial
        # logits close to zero.
        self.actor = nn.Sequential(
            init_layer(nn.Linear(in_states, h1_nodes)),
            nn.Tanh(),
            init_layer(nn.Linear(h1_nodes, h1_nodes)),
            nn.Tanh(),
            init_layer(nn.Linear(h1_nodes, out_actions), gain=0.01),
        )
        # Critic trunk with a unit-gain scalar value head.
        self.critic = nn.Sequential(
            init_layer(nn.Linear(in_states, h1_nodes)),
            nn.Tanh(),
            init_layer(nn.Linear(h1_nodes, h1_nodes)),
            nn.Tanh(),
            init_layer(nn.Linear(h1_nodes, 1), gain=1.0),
        )

    def forward(self, x):
        """Compute policy logits and the value estimate for ``x``.

        Args:
            x: Float tensor whose last dimension is ``in_states``.

        Returns:
            A ``(logits, value)`` tuple. ``logits`` has last dimension
            ``out_actions`` (unnormalized); ``value`` has last dimension 1.
        """
        logits = self.actor(x)
        value = self.critic(x)
        return logits, value


class CNNPolicyNetwork(nn.Module):
    """Actor-critic CNN: a shared convolutional trunk feeding two linear heads.

    Args:
        input_shape: ``(stack, height, width)`` of a single observation;
            ``stack`` is used as the number of input channels.
        num_actions: Number of logits produced by the actor head.
    """

    def __init__(self, input_shape, num_actions):
        super().__init__()
        stack, height, width = input_shape

        self.shared_conv = nn.Sequential(
            nn.Conv2d(stack, 16, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=2),
            nn.ReLU(),
        )

        # Infer the flattened feature size by pushing one dummy observation
        # through the trunk. With batch size 1, numel() equals the per-sample
        # feature count, so the heads adapt to any (height, width).
        with torch.no_grad():
            dummy_obs = torch.zeros(1, stack, height, width)
            conv_size = self.shared_conv(dummy_obs).numel()

        # Policy head: small gain keeps the initial logits close to zero.
        self.actor_fc = init_layer(nn.Linear(conv_size, num_actions), gain=0.01)
        # Value head: unit gain.
        self.critic_fc = init_layer(nn.Linear(conv_size, 1), gain=1.0)

    def forward(self, x):
        """Compute policy logits and the value estimate for a batch of images.

        Args:
            x: Batched 4-D tensor ``(batch, stack, height, width)``. Its dtype
                must match the conv weights (float32 by default), so integer
                image observations need converting first — TODO confirm how
                callers preprocess observations.

        Returns:
            A ``(logits, value)`` tuple with shapes ``(batch, num_actions)``
            and ``(batch, 1)``.
        """
        feats = self.shared_conv(x)
        # Keep the batch dimension and flatten channels and spatial dims.
        feats = torch.flatten(feats, start_dim=1)

        logits = self.actor_fc(feats)
        value = self.critic_fc(feats)
        return logits, value