In [None]:
# Install environment and agent
!pip install highway-env
!pip install --upgrade sympy torch

### Papers
[Proximal Policy Optimization Algorithms](https://arxiv.org/pdf/1707.06347)

[Towards Delivering a Coherent Self-Contained Explanation of Proximal Policy Optimization](https://fse.studenttheses.ub.rug.nl/25709/1/mAI_2021_BickD.pdf)

### Tutorial
[Hugging Face Deep RL Course: Proximal Policy Optimization (PPO)](https://huggingface.co/learn/deep-rl-course/unit8/introduction)

## Policy Network

In [None]:
import gymnasium as gym
import highway_env
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

# Policy network (MLP)
class MLPPolicyNetwork(nn.Module):
    """Actor-critic MLP with two independent (non-shared) heads.

    Both the actor and the critic are ``Linear -> ReLU -> Linear`` stacks
    reading the same input.

    Args:
        in_states: size of the last dimension of the input ``x``.
        h1_nodes: width of the single hidden layer in each head.
        out_actions: number of outputs produced by the actor head.

    ``forward(x)`` returns ``(policy, value)``: the raw actor output (no
    softmax applied) with last dimension ``out_actions``, and the critic
    output with last dimension 1.
    """

    def __init__(self, in_states, h1_nodes, out_actions):
        super().__init__()

        def two_layer(out_features):
            # A fresh Linear -> ReLU -> Linear stack, so the heads share no weights.
            return nn.Sequential(
                nn.Linear(in_states, h1_nodes),
                nn.ReLU(),
                nn.Linear(h1_nodes, out_features),
            )

        # Build the actor before the critic so parameter initialisation order
        # (and therefore seeded weights) stays the same.
        self.actor = two_layer(out_actions)
        self.critic = two_layer(1)

    def forward(self, x):
        """Run both heads on ``x`` and return ``(policy, value)``."""
        return self.actor(x), self.critic(x)


class CNNPolicyNetwork(nn.Module):
    """Actor-critic CNN with separate convolutional trunks for actor and critic.

    Args:
        input_shape: ``(stack, height, width)`` of a single observation
            (e.g. a stack of greyscale frames). ``stack`` is used as the
            number of input channels of the first convolution.
        num_actions: number of outputs produced by the actor head.

    ``forward(x)`` returns ``(policy, value)``: the raw actor output (no
    softmax applied) of shape ``(batch, num_actions)`` and the critic output
    of shape ``(batch, 1)``.

    Raises:
        RuntimeError: (from PyTorch) if ``height``/``width`` are too small
            for the convolution stack to produce a non-empty feature map.
    """

    def __init__(self, input_shape, num_actions):
        super().__init__()
        # Greyscale observation is (stack, height, width); stacked frames act as channels.
        stack, height, width = input_shape

        # Independent (non-shared) feature extractors for actor and critic.
        self.actor_cn = self._conv_trunk(stack)
        self.critic_cn = self._conv_trunk(stack)

        # Probe each trunk with a dummy (1, stack, height, width) input to get the
        # flattened feature size; more robust than deriving it by hand.
        with torch.no_grad():
            probe = torch.zeros(1, stack, height, width)
            actor_size = self.actor_cn(probe).numel()
            critic_size = self.critic_cn(probe).numel()

        self.actor = nn.Linear(actor_size, num_actions)  # Actor head
        self.critic = nn.Linear(critic_size, 1)  # Critic head

    @staticmethod
    def _conv_trunk(in_channels):
        """Build one Conv(16) -> Conv(32) -> Conv(64) feature stack with ReLUs."""
        return nn.Sequential(
            nn.Conv2d(in_channels, 16, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=2),
            nn.ReLU(),
        )

    def forward(self, x):
        """Return ``(policy, value)`` for a batch of observations.

        ``x`` is ``(batch, stack, height, width)``; a single unbatched
        ``(stack, height, width)`` observation is also accepted and treated
        as a batch of one. NOTE(review): the input must already be a float
        tensor matching the weights' dtype; raw uint8 frames need conversion
        by the caller.
        """
        if x.dim() == 3:
            # Without a batch dim, flatten(start_dim=1) would yield the wrong feature size.
            x = x.unsqueeze(0)

        actor_feats = torch.flatten(self.actor_cn(x), start_dim=1)
        critic_feats = torch.flatten(self.critic_cn(x), start_dim=1)

        policy = self.actor(actor_feats)
        value = self.critic(critic_feats)
        return policy, value