# PPO to play Atari Pong

### The goal of this project work is to implement a PPO algorithm able to learn to play the atari game Pong and reaching a level where it can consistently win, the implementation will try to follow the original implementation from Jhon Schulman.

In [1]:
!pip install gymnasium



In [2]:
!pip install ale-py

Collecting ale-py
  Downloading ale_py-0.10.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (7.6 kB)
Downloading ale_py-0.10.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.1/2.1 MB[0m [31m35.4 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: ale-py
Successfully installed ale-py-0.10.1


In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions.categorical import Categorical
import random
import numpy as np
import gymnasium as gym
from dataclasses import dataclass
import time
import wandb
from torch.utils.tensorboard import SummaryWriter
from PIL import Image

import ale_py

gym.register_envs(ale_py)

  File "/opt/conda/lib/python3.10/site-packages/gymnasium/envs/registration.py", line 594, in load_plugin_envs
    fn()
  File "/opt/conda/lib/python3.10/site-packages/shimmy/registration.py", line 304, in register_gymnasium_envs
    _register_atari_envs()
  File "/opt/conda/lib/python3.10/site-packages/shimmy/registration.py", line 205, in _register_atari_envs
    import ale_py
  File "/opt/conda/lib/python3.10/site-packages/ale_py/__init__.py", line 68, in <module>
    register_v0_v4_envs()
  File "/opt/conda/lib/python3.10/site-packages/ale_py/registration.py", line 178, in register_v0_v4_envs
    _register_rom_configs(legacy_games, obs_types, versions)
  File "/opt/conda/lib/python3.10/site-packages/ale_py/registration.py", line 63, in _register_rom_configs
    gymnasium.register(
AttributeError: partially initialized module 'gymnasium' has no attribute 'register' (most likely due to a circular import)
[0m
  logger.warn(f"plugin: {plugin.value} raised {traceback.format_exc()}")


Define the Actor-Critic architecture, the CNN backbone to process the input images is shared and each component have a MLP head that will generate the actions and the critic-value.

In [4]:
def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    torch.nn.init.orthogonal_(layer.weight, std)
    torch.nn.init.constant_(layer.bias, bias_const)
    return layer


class Agent(nn.Module):
    def __init__(self):
        super(Agent, self).__init__()
        self.network = nn.Sequential(
            layer_init(nn.Conv2d(4, 32, 8, stride=4)),
            nn.ReLU(),
            layer_init(nn.Conv2d(32, 64, 4, stride=2)),
            nn.ReLU(),
            layer_init(nn.Conv2d(64, 64, 3, stride=1)),
            nn.ReLU(),
            nn.Flatten(),
            layer_init(nn.Linear(64 * 6 * 6, 512)),
            nn.ReLU(),
        )
        self.actor = layer_init(nn.Linear(512, 2), std=0.01)
        self.critic = layer_init(nn.Linear(512, 1), std=1)

    def get_value(self, x):
        return self.critic(self.network(x))

    # Here we sample the action randomly from the predicted distribution to maintain exploration
    def get_action_and_value(self, x, action=None):
        hidden = self.network(x)
        logits = self.actor(hidden)
        probs = Categorical(logits=logits) # No need for softmax because we specify logits
        if action is None:
            action = probs.sample()
        return action, probs.log_prob(action), probs.entropy(), self.critic(hidden)

    # Here we get the best action from the predicted distribution
    def act(self, x):
        hidden = self.network(x)
        logits = self.actor(hidden)
        return torch.argmax(logits).item()

Define some basic utils function to process the image given by the env into b&n and crop it and dave gif from a list of states.

In [5]:
def preprocess_pong(state):
    if len(state.shape) == 3:
      state = state[35:195]
      state = state[::2, ::2, 0]  # downsample by factor of 2
      state[state == 144] = 0  # erase background (background type 1)
      state[state == 109] = 0  # erase background (background type 2)
      state[state != 0] = 1  # everything else (paddles, ball) just set to 1
    else:
      state = state[:, 35:195]
      state = state[:, ::2, ::2, 0]  # downsample by factor of 2
      state[state == 144] = 0  # erase background (background type 1)
      state[state == 109] = 0  # erase background (background type 2)
      state[state != 0] = 1  # everything else (paddles, ball) just set to 1

    return state.astype(np.int16)


def save_gif_from_np(images, path, duration=100):
    pil_images = [Image.fromarray((img*255).astype(np.int8)) for img in images]

    # Save as a GIF
    pil_images[0].save(
        path,
        save_all=True,
        append_images=pil_images[1:],  # Add remaining frames
        duration=100,  # Duration between frames in milliseconds
        loop=0  # Loop forever
    )


Define the hyperparameters that will bu used by the PPo algorithm

In [6]:
@dataclass
class Config:
    wandb_project_name = "ppo"
    seed = 42
    num_envs = 8
    gym_id = "ALE/Pong-v5"
    learning_rate = 2.5e-4
    num_steps = 512
    num_minibatches = 4
    total_timesteps = 5000000
    batch_size = int(num_envs * num_steps)
    minibatch_size = int(batch_size // num_minibatches)
    anneal_lr = True
    gamma = 0.99
    gae_lambda = 0.95
    update_epochs = 4
    clip_coef = 0.1
    ent_coef = 0.01
    vf_coef = 0.5
    max_grad_norm = 0.5
    test_every = int(total_timesteps // batch_size // 20)

config = Config()

Define the actual PPO algorithm

In [7]:
def train():
    # Initialize the memory buffers
    obs = torch.zeros((config.num_steps, config.num_envs) + (4, w, h)).to(device)
    actions = torch.zeros((config.num_steps, config.num_envs) + envs.single_action_space.shape).to(device)
    logprobs = torch.zeros((config.num_steps, config.num_envs)).to(device)
    rewards = torch.zeros((config.num_steps, config.num_envs)).to(device)
    dones = torch.zeros((config.num_steps, config.num_envs)).to(device)
    values = torch.zeros((config.num_steps, config.num_envs)).to(device)
    ep_rewards = torch.zeros(config.num_envs).to(device)
    last_4_obs_buffer = torch.zeros(config.num_envs, 4, w, h).to(device)
    terminated_rw = []

    # Game initialization
    global_step = 0
    start_time = time.time()
    next_obs = torch.Tensor(preprocess_pong(envs.reset()[0])).to(device).unsqueeze(1)  # reset return obs, info
    last_4_obs_buffer = torch.cat((last_4_obs_buffer[:, 1:], next_obs), dim=1)
    next_obs = last_4_obs_buffer
    next_done = torch.zeros(config.num_envs).to(device)
    num_updates = int(config.total_timesteps // config.batch_size)

    _ = envs.step(torch.ones(config.num_envs, dtype=torch.long)) # FIRE operation to make the game start

    print(f"Number of policy iteration: {num_updates}")

    for update in range(1, num_updates + 1):
        if config.anneal_lr:
            frac = 1.0 - (update - 1.0) / num_updates
            lrnow = frac * config.learning_rate
            lrnow = max(lrnow, 2e-4)
            optimizer.param_groups[0]["lr"] = lrnow

        # Rollout: we gather experience using the policy
        for step in range(0, config.num_steps):
            global_step += 1 * config.num_envs
            obs[step] = next_obs
            dones[step] = next_done

            with torch.no_grad():
                action, logprob, _, value = agent.get_action_and_value(next_obs)
                values[step] = value.flatten()

            actions[step] = action
            logprobs[step] = logprob

            next_obs, reward, done, truncated, info = envs.step(action.cpu().numpy()+2)
            rewards[step] = torch.tensor(reward).to(device).view(-1)
            ep_rewards += rewards[step]
            next_obs = torch.Tensor(preprocess_pong(next_obs)).to(device).unsqueeze(1)
            last_4_obs_buffer = torch.cat((last_4_obs_buffer[:, 1:], next_obs), dim=1)
            next_obs = last_4_obs_buffer
            next_done = torch.logical_or(torch.Tensor(done).to(device), torch.Tensor(truncated).to(device)).int()
            terminated_rw.extend(ep_rewards[next_done.bool()].tolist())
            ep_rewards *= (1-next_done)

        # Calculate the advantages and bootstrap value if not done
        with torch.no_grad():
            next_value = agent.get_value(next_obs).reshape(1, -1)

            advantages = torch.zeros_like(rewards).to(device)
            lastgaelam = 0
            for t in reversed(range(config.num_steps)):
                if t == config.num_steps - 1:
                    nextnonterminal = 1.0 - next_done
                    nextvalues = next_value
                else:
                    nextnonterminal = 1.0 - dones[t + 1]
                    nextvalues = values[t + 1]
                delta = rewards[t] + config.gamma * nextvalues * nextnonterminal - values[t]
                advantages[t] = lastgaelam = delta + config.gamma * config.gae_lambda * nextnonterminal * lastgaelam
            returns = advantages + values

        # Adjust shapes
        b_obs = obs.reshape((-1,) + (4, w, h))
        b_logprobs = logprobs.reshape(-1)
        b_actions = actions.reshape((-1,) + envs.single_action_space.shape)
        b_advantages = advantages.reshape(-1)
        b_returns = returns.reshape(-1)
        b_values = values.reshape(-1)

        b_inds = np.arange(config.batch_size)
        clipfracs = []

        # Policy update using the experience just played
        for epoch in range(config.update_epochs):
            np.random.shuffle(b_inds)
            for start in range(0, config.batch_size, config.minibatch_size):
                end = start + config.minibatch_size
                mb_inds = b_inds[start:end]

                _, newlogprob, entropy, newvalue = agent.get_action_and_value(b_obs[mb_inds], b_actions.long()[mb_inds])
                logratio = newlogprob - b_logprobs[mb_inds]
                ratio = logratio.exp()

                with torch.no_grad():
                    # calculate approx_kl http://joschu.net/blog/kl-approx.html
                    old_approx_kl = (-logratio).mean()
                    approx_kl = ((ratio - 1) - logratio).mean()
                    clipfracs += [((ratio - 1.0).abs() > config.clip_coef).float().mean().item()]

                mb_advantages = b_advantages[mb_inds]
                # Normalize
                mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8)

                # Policy loss
                pg_loss1 = -mb_advantages * ratio
                pg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - config.clip_coef, 1 + config.clip_coef)
                pg_loss = torch.max(pg_loss1, pg_loss2).mean()

                # Value loss
                newvalue = newvalue.view(-1)
                v_loss_unclipped = (newvalue - b_returns[mb_inds]) ** 2
                v_clipped = b_values[mb_inds] + torch.clamp(
                    newvalue - b_values[mb_inds],
                    -config.clip_coef,
                    config.clip_coef,
                )
                v_loss_clipped = (v_clipped - b_returns[mb_inds]) ** 2
                v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
                v_loss = 0.5 * v_loss_max.mean()

                # Final loss
                entropy_loss = entropy.mean()
                loss = pg_loss - config.ent_coef * entropy_loss + v_loss * config.vf_coef

                # Network optimization
                optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(agent.parameters(), config.max_grad_norm)
                optimizer.step()

        # Print and log data
        if update % config.test_every == 0:
            print(f"Test reward at update {update}: {test(update)}")

        if len(ep_rewards) > 10:
            mean_rw = torch.mean(torch.tensor(terminated_rw[-10:]))
        else:
            mean_rw = torch.mean(torch.tensor(terminated_rw))

        writer.add_scalar("rewards/mean_reward", mean_rw, global_step)
        writer.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], global_step)
        writer.add_scalar("losses/value_loss", v_loss.item(), global_step)
        writer.add_scalar("losses/policy_loss", pg_loss.item(), global_step)
        writer.add_scalar("losses/entropy", entropy_loss.item(), global_step)
        writer.add_scalar("losses/old_approx_kl", old_approx_kl.item(), global_step)
        writer.add_scalar("losses/approx_kl", approx_kl.item(), global_step)
        writer.add_scalar("losses/clipfrac", np.mean(clipfracs), global_step)

    print(f"Final reward at update {update}: {test()}")
    envs.close()
    writer.close()

The test function is uesd to evaluate a given policy, it choses the best moves according to the network instead of sampling from the distribution.

In [8]:
def test(update=None):
    # Create the environment
    env = gym.make(config.gym_id)
    last_4_states = torch.zeros(1, 4, w, h).to(device)
    done = False
    terminated = False
    total_reward = 0
    state_list = []

    # Reset the environment to get the initial state
    state, info = env.reset()

    state = preprocess_pong(state)
    state_list.append(state)
    state = torch.tensor(state).to(device)
    last_4_states = torch.cat((last_4_states[:, 1:], state.unsqueeze(0).unsqueeze(0)), dim=1)

    _ = env.step(1) # We need to do a FIRE operation to make the game starts

    while not (done or terminated):
        # Let the agent decide the action
        with torch.no_grad():
            action = agent.act(last_4_states)

        # Take the action in the environment
        next_state, reward, done, terminated, info = env.step(action+2)
        # Update the total reward
        total_reward += reward
        # Transition to the next state
        next_state = preprocess_pong(next_state)
        state_list.append(next_state)
        next_state = torch.tensor(next_state).to(device)
        last_4_states = torch.cat((last_4_states[:, 1:], next_state.unsqueeze(0).unsqueeze(0)), dim=1)

    game_name = config.gym_id[4:-3]
    save_gif_from_np(state_list, f"{game_name}_{update}.gif")

    return total_reward

Initalize the network and optimizer that will be used in the traninig of the policy

In [9]:
wandb.login(key="a128cb8af0ead22256607ab843b3c7e4e4dd4c48")

run_name = config.gym_id[4:-3]
wandb.init(
    project=config.wandb_project_name,
    sync_tensorboard=True,
    config=vars(config),
    name=run_name,
    monitor_gym=True,
    save_code=True,
)
writer = SummaryWriter(f"runs/{run_name}")
writer.add_text(
    "hyperparameters",
    "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(config).items()])),
)

random.seed(config.seed)
np.random.seed(config.seed)
torch.manual_seed(config.seed)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

envs = gym.make_vec(config.gym_id, num_envs=config.num_envs, vectorization_mode="sync")

w, h = preprocess_pong(envs.reset()[0]).shape[1:]

agent = Agent().to(device)
optimizer = optim.Adam(agent.parameters(), lr=config.learning_rate, eps=1e-5)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mgiovannijgrotto[0m ([33mgiovannijgrotto-universit-di-bologna[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Tracking run with wandb version 0.18.7
[34m[1mwandb[0m: Run data is saved locally in [35m[1m/kaggle/working/wandb/run-20241209_171344-t65gkvt2[0m
[34m[1mwandb[0m: Run [1m`wandb offline`[0m to turn off syncing.
[34m[1mwandb[0m: Syncing run [33mPong[0m
[34m[1mwandb[0m: ⭐️ View project at [34m[4mhttps://wandb.ai/giovannijgrotto-universit-di-bologna/ppo[0m
[34m[1mwandb[0m: 🚀 View run at [34m[4mhttps://wandb.ai/giovannijgrotto-universit-di-bologna/ppo/runs/t65gkvt2[0m
A.L.E

Let PPO train and visualize the progress

In [10]:
test()
train()

Number of policy iteration: 1220
Test reward at update 61: -17.0
Test reward at update 122: -11.0
Test reward at update 183: 3.0
Test reward at update 244: 15.0
Test reward at update 305: 12.0
Test reward at update 366: 3.0
Test reward at update 427: 12.0
Test reward at update 488: 18.0
Test reward at update 549: 15.0
Test reward at update 610: 20.0
Test reward at update 671: 20.0
Test reward at update 732: 17.0
Test reward at update 793: 15.0
Test reward at update 854: 20.0
Test reward at update 915: 18.0
Test reward at update 976: 20.0
Test reward at update 1037: 17.0
Test reward at update 1098: 21.0
Test reward at update 1159: 16.0
Test reward at update 1220: 19.0
Final reward at update 1220: 21.0
