* YouTube video - https://www.youtube.com/watch?v=xHf8oKd7cgU&list=PLgMYKvjKE10UZNku-Qx7-z2PEC-7KLiUn&index=2
* Proximal Policy Optimization Algorithms - https://arxiv.org/abs/1707.06347

In [4]:
import gymnasium as gym
import ale_py
gym.register_envs(ale_py)
import torch
import random
import numpy as np
import torch.nn as nn
from tqdm import tqdm
import torch.nn.functional as F
import matplotlib.pyplot as plt
from stable_baselines3.common.atari_wrappers import MaxAndSkipEnv
from torch.utils.data import DataLoader, Dataset

In [3]:
class ActorCritic(nn.Module):
    """Actor-critic network with a shared convolutional trunk.

    ``head`` maps a 4-channel image batch to a 256-dim feature vector.
    ``actor`` turns that feature into one logit per action.
    ``critic`` turns it into a single state-value estimate.
    """

    def __init__(self, nb_actions):
        super().__init__()

        # For an 84x84 input: the 8x8/stride-4 conv yields 20x20 and the
        # 4x4/stride-2 conv yields 9x9, hence 32 * 9 * 9 flattened features.
        flat_features = 32 * 9 * 9
        hidden = 256

        trunk_layers = [
            nn.Conv2d(4, 16, kernel_size=8, stride=4),
            nn.Tanh(),
            nn.Conv2d(16, 32, kernel_size=4, stride=2),
            nn.Tanh(),
            nn.Flatten(),
            nn.Linear(flat_features, hidden),
            nn.Tanh(),
        ]
        self.head = nn.Sequential(*trunk_layers)

        # Each output layer sits inside its own Sequential so parameter names
        # keep the "actor.0.*" / "critic.0.*" layout.
        self.actor = nn.Sequential(nn.Linear(hidden, nb_actions))
        self.critic = nn.Sequential(nn.Linear(hidden, 1))

    def forward(self, x):
        """Return ``(logits, value)`` for the observation batch ``x``.

        ``logits`` has ``nb_actions`` entries per sample; ``value`` has a
        trailing dimension of size 1.
        """
        features = self.head(x)
        logits = self.actor(features)
        value = self.critic(features)
        return logits, value

In [6]:
class Environments():
    """A fixed pool of Atari Breakout environments, one per PPO actor.

    For each environment the pool tracks:

    - the latest observation;
    - the remaining life count;
    - whether the most recent step cost a life;
    - the cumulative reward since the last reset.
    """

    def __init__(self, nb_actor, render_mode='human'):
        """Create and reset ``nb_actor`` environments.

        Args:
            nb_actor: number of environments (actors) to create.
            render_mode: forwarded to ``gym.make``. The default ``'human'``
                keeps the previous behaviour and opens a window per
                environment. ``None`` is presumably preferable for headless
                training; confirm for your setup.
        """
        # Must be set before get_env() runs in the comprehension below.
        self.render_mode = render_mode
        self.nb_actor = nb_actor
        self.envs = [self.get_env() for _ in range(nb_actor)]
        self.observations = [None] * nb_actor
        self.current_life = [None] * nb_actor
        self.done = [False] * nb_actor
        self.total_reward = [0] * nb_actor

        for env_id in range(nb_actor):
            self.reset_env(env_id)

    def len(self):
        """Return the number of environments (kept for existing callers)."""
        return self.nb_actor

    def __len__(self):
        """Return the number of environments, so ``len(pool)`` works."""
        return self.nb_actor

    def reset_env(self, env_id):
        """Reset environment ``env_id``, then warm it up for 1-30 random steps.

        Every warm-up step takes action 1 (presumably FIRE in Breakout's
        action set; verify). Warm-up rewards count toward ``total_reward``.
        The last warm-up observation and life count become the current ones.
        """
        self.total_reward[env_id] = 0
        self.done[env_id] = False
        self.envs[env_id].reset()

        for _ in range(random.randint(1, 30)):
            self.observations[env_id], reward, _, _, info = self.envs[env_id].step(1)
            self.total_reward[env_id] += reward
            self.current_life[env_id] = info['lives']

    def step(self, env_id, action):
        """Apply ``action`` to environment ``env_id``.

        Returns:
            ``(next_obs, reward, dead, done, info)``:

            - ``dead`` is True when the episode ended (terminated or truncated).
            - ``done`` is True when this step lost a life.
        """
        next_obs, reward, terminated, truncated, info = self.envs[env_id].step(action)
        # Treat truncation like termination, so callers that reset on `dead`
        # never keep stepping an episode the environment has already ended.
        dead = terminated or truncated
        done = info['lives'] < self.current_life[env_id]
        # self.done is a per-environment list; update only this env's slot.
        self.done[env_id] = done
        self.total_reward[env_id] += reward
        self.current_life[env_id] = info['lives']
        self.observations[env_id] = next_obs
        return next_obs, reward, dead, done, info

    def get_env(self):
        """Build one preprocessed Breakout environment.

        The pipeline is:

        1. raw frames;
        2. episode statistics;
        3. repeat each action for 4 frames, max-pooling the last two;
        4. resize to 84x84;
        5. grayscale;
        6. stack the last 4 observations.
        """
        # ALE v5 repeats each action for 4 frames by default. Disable that so
        # MaxAndSkipEnv is the only frame skipper (otherwise 16 frames pass
        # per agent step).
        env = gym.make("ALE/Breakout-v5",
                       frameskip=1,
                       render_mode=self.render_mode
                       )
        env = gym.wrappers.RecordEpisodeStatistics(env)
        # Skip and max-pool on raw frames, before stacking, so the max is
        # taken over two consecutive emulator frames rather than two stacks.
        env = MaxAndSkipEnv(env, skip=4)
        env = gym.wrappers.ResizeObservation(env, (84, 84))
        env = gym.wrappers.GrayscaleObservation(env)
        env = gym.wrappers.FrameStackObservation(env, 4)  # Called FrameStack in older gym versions
        return env

In [None]:
def PPO(
        envs,
        T=128,
        K=3,
        batch_size=32*8,
        gamma=0.99,
        device="cuda",
        gae_parameter=0.95,
        vf_coef_c1=1,
        ent_coef_c2=0.01,
        nb_iterations=40_000
    ):
    optimizer = torch.optim.Adam(actorcritic.parameters(), lr=2.5e-4)
    sheduler = torch.optim.lr_scheduler.LinearLR(
        optimizer, 
        start_factor=1,
        end_factor=0,
        total_iters=nb_iterations
    )
    max_reward = 0
    total_rewards = [[] for _ in range(envs.len())]
    smoothed_rewareds = [[] for _ in range(envs.len())]

    for iteration in tqdm(range(nb_iterations)):
        advantages = torch.zeros((envs.len(), T), dtype=torch.float32).to(device)
        buffer_states = torch.zeros((envs.len(), T, 4, 84, 84), dtype=torch.float32).to(device)
        buffer_actions = torch.zeros((envs.len(), T), dtype=torch.long).to(device)
        buffer_logprobs = torch.zeros((envs.len(), T), dtype=torch.float32).to(device)
        buffer_states_values = torch.zeros((envs.len(), T+1), dtype=torch.float32).to(device)
        buffer_rewards = torch.zeros((envs.len(), T), dtype=torch.float32).to(device)
        buffer_in_terminal = torch.zeros((envs.len(), T), dtype=torch.float32).to(device)

        for env_id in range(envs.len()):
            with torch.no_grad():
                for t in range(T):
                    obs = torch.tensor(envs.observations[env_id] / 255, dtype=torch.float32).to(device).unsqueeze(0)
                    logits, value = actorcritic(obs)
                    logits, value = logits.squeeze(0), value.squeeze(0)
                    m = torch.distributions.Categorical(logits=logits)











                    