In [18]:
import gymnasium as gym
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from stable_baselines3.common.atari_wrappers import (  
    ClipRewardEnv,
    EpisodicLifeEnv,
    FireResetEnv,
    MaxAndSkipEnv,
    NoopResetEnv,
)

In [84]:
#CONSTANTS
total_timesteps = 1000000
num_envs = 4
num_steps = 128
num_epochs = 4
batch_size = num_envs * num_steps
num_updates = total_timesteps // batch_size


In [20]:
def make_env(gym_id, seed, idx, capture_video, run_name):
    def env_create():
        env = gym.make(gym_id)
        env = gym.wrappers.RecordEpisodeStatistics(env)
        if capture_video:
            if idx == 0:
                env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
        env = NoopResetEnv(env, noop_max=30)
        env = MaxAndSkipEnv(env, skip=4)
        env = EpisodicLifeEnv(env)
        if "FIRE" in env.unwrapped.get_action_meanings():
            env = FireResetEnv(env)
        env = ClipRewardEnv(env)
        env = gym.wrappers.ResizeObservation(env, (84, 84))
        env = gym.wrappers.GrayScaleObservation(env)
        env = gym.wrappers.FrameStack(env, 4)
        env.unwrapped.seed(seed)
        env.action_space.seed(seed)
        env.observation_space.seed(seed)
        return env

    return env_create

In [100]:
def layer_initilization(layer, std=np.sqrt(2), bias_const=0.0):
    nn.init.orthogonal_(layer.weight, std)
    nn.init.constant_(layer.bias.data, bias_const)
    return layer

class PPOAgent(nn.Module):
    
    def __init__(self, n_actions, n_frames = 4 ):
        super(PPOAgent, self).__init__()
        
        self.conv = nn.Sequential(
            layer_initilization(nn.Conv2d(n_frames, 32, kernel_size=8, stride=4)),
            nn.ReLU(),
            layer_initilization(nn.Conv2d(32, 64, kernel_size=4, stride=2)),
            nn.ReLU(),
            layer_initilization(nn.Conv2d(64, 64, kernel_size=3, stride=1)),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(3136, 512)
        )
        
        self.actor = nn.Sequential(
            layer_initilization(nn.Linear(512, n_actions))
        )
        
        self.critic = nn.Sequential(
            layer_initilization(nn.Linear(512, 1))
        )
        
        
    def get_value(self, x):
        return self.critic(self.conv(x))
    
    def get_action(self, x, action = None):
            logits = self.actor(self.conv(x))
            dist = torch.distributions.Categorical(logits=logits)
            if action == None:
                action = dist.sample()
            value = self.critic(self.conv(x))
            return action, dist.log_prob(action), dist.entropy(), value
            

In [1]:
def rollout(envs, agent, num_steps, gamma = 0.99, gae_lambda = 0.95):
    
    
    observations = torch.zeros((num_steps , num_envs) + envs.single_observation_space.shape, dtype=torch.float32)
    actions = torch.zeros((num_steps, num_envs) + envs.single_action_space.shape, dtype=torch.int32)
    rewards = torch.zeros((num_steps, num_envs), dtype=torch.float32)
    values = torch.zeros((num_steps, num_envs), dtype=torch.float32)
    logprobs = torch.zeros((num_steps, num_envs), dtype=torch.float32)
    dones = torch.zeros((num_steps, num_envs), dtype=torch.float32)
    #truncs = torch.zeros((num_steps, num_envs), dtype=torch.float32)

    next_obs = torch.Tensor(envs.reset()[0])
    # next_done = torch.zeros(num_envs)
    # next_trunc = torch.zeros(num_envs)
    
    for step in range(num_steps):
        observations[step] = next_obs
        dones[step] = next_done
        #truncs[step] = next_trunc
        
        with torch.no_grad():
            action, logprob, _, value = agent.get_action(next_obs)
            
        actions[step] = action
        values[step] = value.view(-1)
        logprobs[step] = logprob
        
        next_obs, reward, next_done, _, _ = envs.step(action.cpu().numpy())
        next_obs = torch.Tensor(next_obs)
        rewards[step] = torch.tensor(reward)
        
    with torch.no_grad():
            next_value = agent.get_value(next_obs).reshape(1, -1)
            advantages = torch.zeros_like(rewards)
            lastgaelam = 0
            for t in reversed(range(num_steps)):
                if t == num_steps - 1:
                    nextnonterminal = 1.0 - next_done
                    nextvalues = next_value
                else:
                    nextnonterminal = 1.0 - dones[t + 1]
                    nextvalues = values[t + 1]
                delta = rewards[t] + gamma * nextvalues * nextnonterminal - values[t]
                advantages[t] = lastgaelam = delta + gamma * gae_lambda * nextnonterminal * lastgaelam
            returns = advantages + values
    
    return observations, actions, returns, values, advantages, logprobs, dones
        
    

In [2]:
def update_agent(agent, optimizer, observations, actions, returns, values, advantages, logprobs, clip_param = 0.2, vf_coef = 0.5, ent_coef = 0.01):
    
    observations = observations.view(-1, *observations.shape[2:])
    actions = actions.view(-1, *actions.shape[2:])
    returns = returns.view(-1)
    values = values.view(-1)
    advantages = advantages.view(-1)
    logprobs = logprobs.view(-1)
    
    for _ in range(num_epochs):
        for idx in range(0, observations.size(0), batch_size):
            batch_indices = slice(idx, idx + batch_size)
            obs_batch = observations[batch_indices]
            act_batch = actions[batch_indices]
            ret_batch = returns[batch_indices]
            adv_batch = advantages[batch_indices]
            norm_adv_batch = (adv_batch - adv_batch.mean()) / (adv_batch.std() + 1e-8)
            logprob_batch = logprobs[batch_indices]
            
            _, logprob, entropy, value = agent.get_action(obs_batch, act_batch)
            ratio = (logprob - logprob_batch).exp()
            policy_loss = -torch.min(ratio * norm_adv_batch, torch.clamp(ratio, 1.0 - clip_param, 1.0 + clip_param) * norm_adv_batch).mean()
            value_loss = 0.5 * (value - ret_batch).pow(2).mean()
            entropy_loss = entropy.mean()
            loss = policy_loss + value_loss * vf_coef - entropy_loss * ent_coef
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
    return loss.item(), policy_loss.item(), value_loss.item(), entropy_loss.item()
    

In [123]:
#FINAL TRAINING LOOP
envs = gym.vector.SyncVectorEnv([make_env("ALE/Pong-v5", 42, i, False, "Pong") for i in range(num_envs)])
agent = PPOAgent(envs.single_action_space.n)
optimizer = torch.optim.Adam(agent.parameters(), lr=2.5e-4, eps=1e-5)
for update in range(num_updates):
    observations, actions, returns, values, advantages, logprobs, dones = rollout(envs, agent, num_steps)
    loss, policy_loss, value_loss, entropy_loss = update_agent(agent, optimizer, observations, actions, returns, values, advantages, logprobs)
    print(f"Update: {update}, Loss: {loss}, Policy Loss: {policy_loss}, Value Loss: {value_loss}, Entropy Loss: {entropy_loss}")

  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


Update: 0, Loss: 40561.71484375, Policy Loss: -2.3283064365386963e-08, Value Loss: 81123.4296875, Entropy Loss: 3.788868672996958e-32
Update: 1, Loss: 13433.470703125, Policy Loss: -7.171183824539185e-08, Value Loss: 26866.94140625, Entropy Loss: 7.013596585026892e-18
Update: 2, Loss: 6992.53173828125, Policy Loss: 2.337619662284851e-07, Value Loss: 13985.0634765625, Entropy Loss: 1.9777660276642725e-14
Update: 3, Loss: 78.73841094970703, Policy Loss: 1.3969838619232178e-08, Value Loss: 157.47682189941406, Entropy Loss: 8.673495949240717e-11
Update: 4, Loss: 647.7161865234375, Policy Loss: 7.348135113716125e-07, Value Loss: 1295.432373046875, Entropy Loss: 9.606543608242646e-05
Update: 5, Loss: 135.92857360839844, Policy Loss: 0.1246541365981102, Value Loss: 271.6178894042969, Entropy Loss: 0.5021007061004639
Update: 6, Loss: 124.23230743408203, Policy Loss: 0.06882800906896591, Value Loss: 248.32730102539062, Entropy Loss: 0.01705171912908554
Update: 7, Loss: 55.63188171386719, Policy

KeyboardInterrupt: 

In [121]:
update

0