In [1]:
import einops
import numpy as np
import pufferlib, pufferlib.vector
from pufferlib.environments import atari

In [2]:
# Atari game to train on; also used to build the W&B project name and the checkpoint filename.
game = "pong"
# Number of environment copies stepped in parallel by the vectorized env.
num_envs = 12

In [3]:
# Build the vectorized training environment: `num_envs` copies of the selected Atari game,
# run through pufferlib's Multiprocessing backend, each created with a framestack of 4.
vecenv = pufferlib.vector.make(
    atari.env_creator(game),
    num_envs=num_envs,
    backend=pufferlib.vector.Multiprocessing,
    env_kwargs={"framestack": 4},
)

A.L.E: Arcade Learning Environment (version 0.9.0+750d7f9)
[Powered by Stella]
Process Process-12:
Process Process-6:
Process Process-4:
Process Process-10:
Process Process-11:
Process Process-2:
Process Process-3:
Process Process-1:
Process Process-9:
Process Process-5:
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
  File "/home/fitti/.conda/envs/puffer/lib/python3.11/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
Traceback (most recent call last):
  File "/home/fitti/.conda/envs/puffer/lib/python3.11/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/home/fitti/.conda/envs/puffer/lib/python3.11/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/hom

In [4]:
import torch
from torch import nn

In [5]:
class PPONetwork(nn.Module):
    """Actor-critic network: a shared convolutional trunk feeding a policy head and a value head.

    The trunk is three stages of three 3x3 convolutions (padding 1), each stage ending in a
    2x2 max-pool, followed by a flatten. Both heads are single linear layers on the flattened
    features.

    Args:
        obs_shape: Channels-first observation shape ``(c, h, w)`` used to size the first
            convolution and to infer the flattened feature count.
        n_actions: Number of discrete actions (width of the policy logits).
    """

    def __init__(self, obs_shape, n_actions):
        super().__init__()
        self.convs = nn.Sequential(
            nn.Conv2d(in_channels=obs_shape[0], out_channels=16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding=1),
            nn.MaxPool2d(2, 2),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 128, 3, padding=1),
            nn.MaxPool2d(2, 2),
            nn.ReLU(),
            nn.Conv2d(128, 128, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 64, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.MaxPool2d(2, 2),
            nn.ReLU(),
            nn.Flatten(),
        )

        # Probe the trunk with one all-zero observation to learn the flattened feature size.
        # no_grad: this pass is only for shape inference, so no autograd graph is needed.
        with torch.no_grad():
            dummy_in = torch.zeros(obs_shape).unsqueeze(0)
            n_flattened = self.convs(dummy_in).shape[1]

        self.policy_head = nn.Linear(n_flattened, n_actions)
        self.value_head = nn.Linear(n_flattened, 1)

    def forward(self, x):
        """Compute policy logits and state values for a batch of observations.

        Args:
            x: Batch of observations laid out as ``(b, w, c, h)``; either a tensor of any
                numeric dtype or anything ``np.array`` can convert (e.g. a numpy array).

        Returns:
            Tuple ``(logits, value)`` with shapes ``(b, n_actions)`` and ``(b, 1)``.
        """
        device = next(self.convs.parameters()).device
        if not isinstance(x, torch.Tensor):
            x = torch.tensor(np.array(x), dtype=torch.float32, device=device)
        # Cast as well as move: integer (e.g. uint8 frame) tensors would otherwise reach the
        # float convolution weights unchanged and raise a dtype mismatch.
        x = x.to(device=device, dtype=torch.float32)
        # (b, w, c, h) -> (b, c, h, w), i.e. the einops pattern "b w c h -> b c h w".
        x = x.permute(0, 2, 3, 1)
        # NOTE(review): inputs are fed unscaled; if frames are 0-255 pixel values, consider
        # normalizing before the trunk — confirm against the env's observation range.

        x = self.convs(x)
        logits = self.policy_head(x)
        value = self.value_head(x)

        return logits, value

In [6]:
import random

In [7]:
def select_actions(logits):
    """Sample one action per row of ``logits`` from the categorical policy they define.

    Args:
        logits: Unnormalized action scores, shape ``[batch_size, n_actions]``.

    Returns:
        Tuple ``(actions, log_probs)``, both of shape ``[batch_size]``: the sampled action
        indices and the log-probability of each sampled action under the policy.
    """
    policy = torch.distributions.Categorical(logits=logits)
    sampled = policy.sample()
    return sampled, policy.log_prob(sampled)

In [8]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Select the 'high' float32 matmul precision mode (per the PyTorch docs this permits faster,
# reduced-precision internal formats for float32 matrix multiplications).
torch.set_float32_matmul_precision('high')

In [9]:
from torch import optim

from itertools import count
from tqdm.auto import tqdm

import wandb

In [10]:
def train(
    max_batches = 8_000,
    batch_size = 128,
    minibatch_size = 32,
    num_epochs = 10,
    learning_rate = 1e-4,
    gamma = 0.99,
    lmbd = 0.95,
    value_loss_coef = 0.5,
    clip_param = 0.2,
    leave_bar = True,
):
    """Train a fresh ``PPONetwork`` with clipped PPO on the notebook-level ``vecenv``.

    Reads the notebook globals ``vecenv``, ``num_envs``, ``obs_shape``, ``n_actions``,
    ``device`` and ``game``. Each batch collects ``batch_size`` steps from every env,
    computes GAE advantages, then runs ``num_epochs`` passes of minibatch updates.
    Training can be interrupted with Ctrl-C / kernel interrupt; the current weights are
    still saved and returned.

    Args:
        max_batches: Number of rollout/update cycles.
        batch_size: Timesteps collected per env in each rollout.
        minibatch_size: Timesteps per optimizer step; must divide ``batch_size``.
        num_epochs: Passes over each rollout.
        learning_rate: Adam learning rate.
        gamma: Discount factor.
        lmbd: GAE lambda.
        value_loss_coef: Weight of the value loss in the total loss.
        clip_param: PPO ratio clipping range.
        leave_bar: Whether the tqdm progress bar stays visible when done.

    Returns:
        The trained network, switched to eval mode.
    """
    from pathlib import Path  # local import: only needed to create the checkpoint folder

    assert batch_size % minibatch_size == 0
    num_minibatches = batch_size // minibatch_size
    net = PPONetwork(obs_shape, n_actions).to(device)
    print("Compiling model...", end="\r")
    net.compile()
    print("Model compiled!   ")
    optimizer = optim.Adam(net.parameters(), lr=learning_rate)

    #if wandb.run is not None:
        #wandb.watch(net, log="all")

    # Rollout buffers are allocated once and fully overwritten every batch. States and
    # values carry one extra slot for the bootstrap step that follows the rollout.
    batch_states = torch.zeros(
        (batch_size + 1, num_envs) + vecenv.single_observation_space.shape,
        device=device,
    )
    batch_values = torch.zeros((batch_size + 1, num_envs), device=device)
    batch_actions = torch.zeros((batch_size, num_envs), dtype=torch.long, device=device)
    batch_log_probs = torch.zeros((batch_size, num_envs), device=device)
    batch_rewards = torch.zeros((batch_size, num_envs), device=device)
    batch_dones = torch.zeros((batch_size, num_envs), dtype=torch.bool, device=device)
    batch_advantages = torch.zeros((batch_size, num_envs), device=device)
    batch_td_targets = torch.zeros((batch_size, num_envs), device=device)

    obs, _ = vecenv.reset()
    try:
        for batch in tqdm(range(max_batches), leave=leave_bar, desc="Batches"):
            # Collect one rollout for each env
            for t in range(batch_size):
                with torch.no_grad():
                    logits, values = net(obs)
                    actions, log_probs = select_actions(logits)

                next_obs, rewards, dones, truncateds, infos = vecenv.step(actions.cpu().numpy())

                # Only log when a run is active, and only infos that carry a finished
                # episode's return (not every info dict is guaranteed to have the key).
                if wandb.run is not None:
                    for info in infos:
                        if "episode_return" in info:
                            wandb.log(data={"return": info["episode_return"]})

                batch_states[t] = torch.from_numpy(obs.copy())
                # squeeze(-1) rather than squeeze(): keeps the env axis when num_envs == 1.
                batch_values[t] = values.squeeze(-1)
                batch_actions[t] = actions
                batch_log_probs[t] = log_probs
                batch_rewards[t] = torch.from_numpy(rewards.copy())
                # A truncated episode ends the trajectory just like a terminated one, so
                # neither values nor advantages may flow across it.
                batch_dones[t] = torch.from_numpy(np.logical_or(dones, truncateds))

                obs = next_obs.copy()

            # Append the final states and values for bootstrapping the last step
            with torch.no_grad():
                _, final_values = net(obs)
                batch_states[batch_size] = torch.from_numpy(obs.copy())
                batch_values[batch_size] = final_values.squeeze(-1)

            # Calculate GAE advantages and one-step TD value targets, walking backwards
            advantages = torch.zeros(num_envs, device=device)
            for t in reversed(range(batch_size)):
                not_done = ~batch_dones[t]
                next_values = batch_values[t + 1] * not_done
                batch_td_targets[t] = batch_rewards[t] + gamma * next_values
                td_error = batch_td_targets[t] - batch_values[t]
                advantages = td_error + gamma * lmbd * (advantages * not_done)
                batch_advantages[t] = advantages

            # Normalize over the whole rollout. Normalizing each timestep separately (over
            # only num_envs samples) would erase the differences between timesteps.
            norm_advantages = (batch_advantages - batch_advantages.mean()) / (
                batch_advantages.std() + 1e-8
            )

            # Perform updates: one optimizer step per minibatch of timesteps
            for epoch in range(num_epochs):
                indices = torch.randperm(batch_size)
                for minibatch in range(num_minibatches):
                    start = minibatch * minibatch_size
                    end = start + minibatch_size

                    optimizer.zero_grad()
                    # Gradients are accumulated one timestep at a time so the forward pass
                    # never holds more than num_envs observations in memory.
                    for t in indices[start:end].tolist():
                        logits, values = net(batch_states[t])
                        dist = torch.distributions.Categorical(logits=logits)
                        log_probs = dist.log_prob(batch_actions[t])
                        ratio = (log_probs - batch_log_probs[t]).exp()

                        # values is (num_envs, 1); drop the trailing axis so the difference
                        # is elementwise instead of broadcasting to (num_envs, num_envs).
                        value_error = batch_td_targets[t] - values.squeeze(-1)
                        value_loss = value_loss_coef * value_error.pow(2).mean()

                        surr1 = ratio * norm_advantages[t]
                        surr2 = torch.clamp(ratio, 1.0 - clip_param, 1.0 + clip_param) * norm_advantages[t]
                        policy_loss = -torch.min(surr1, surr2).mean()

                        # Scale so the accumulated gradient is the minibatch mean.
                        loss = (policy_loss + value_loss) / minibatch_size
                        loss.backward()
                    optimizer.step()

    except KeyboardInterrupt:
        print("Training halted manually.")

    # Finalize the run
    model_path = f"./models/PPO_scratch_{game}_latest.pt"
    Path(model_path).parent.mkdir(parents=True, exist_ok=True)
    torch.save(net.state_dict(), model_path)
    if wandb.run is not None:
        wandb.log_model(model_path, "latest")
        wandb.unwatch()
    net.eval()

    return net

In [11]:
# Reorder the env's single-observation shape from axes (0, 1, 2) to (1, 2, 0) so the
# network is sized for a channels-first (c, h, w) input.
# NOTE(review): this assumes the env's native layout is (w, c, h) — confirm.
obs_shape = vecenv.single_observation_space.shape
obs_shape = (obs_shape[1], obs_shape[2], obs_shape[0])
# Size of the discrete action space.
n_actions = vecenv.single_action_space.n

obs_shape, n_actions

((4, 105, 80), 6)

In [12]:
# Hyperparameters for this run; each one is passed straight through to train() below.
max_batches = 2_400
batch_size = 128
minibatch_size = 32
num_epochs = 3
learning_rate = 2.5e-4
gamma = 0.99
lmbd = 0.95
value_loss_coef = 0.25
clip_param = 0.2
leave_bar = True

# W&B project name and the configuration recorded with the run.
# Note: max_batches and leave_bar are not part of the logged config.
project = f"PPO-scratch-{game}"
config = {
    "num_envs": num_envs,
    "batch_size": batch_size,
    "minibatch_size": minibatch_size,
    "num_epochs": num_epochs,
    "learning_rate": learning_rate,
    "gamma": gamma,
    "lambda": lmbd,
    "value_loss_coef": value_loss_coef,
    "clip_param": clip_param,
}

# Start the W&B run before training so train() can log to it.
wandb.init(
    project=project,
    config=config,
)
    
# Train and keep the returned network for the evaluation cells.
net = train(
    max_batches=max_batches,
    batch_size=batch_size,
    minibatch_size=minibatch_size,
    num_epochs=num_epochs,
    learning_rate=learning_rate,
    gamma=gamma,
    lmbd=lmbd,
    value_loss_coef=value_loss_coef,
    clip_param=clip_param,
    leave_bar=leave_bar,
)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mfitti[0m. Use [1m`wandb login --relogin`[0m to force relogin


Model compiled!   


Batches:   0%|          | 0/2400 [00:00<?, ?it/s]

  return node.target(*args, **kwargs)
  return target(*args, **kwargs)
  return func(*args, **kwargs)
W0222 07:40:07.010000 12336 site-packages/torch/_inductor/utils.py:1137] [0/0] Not enough SMs to use max_autotune_gemm mode


Training halted manually.


In [13]:
import gymnasium as gym
import ale_py
# Single evaluation environment, rendered as RGB arrays so episodes can be recorded.
# Note: the env id is hard-coded to Pong and does not follow the `game` setting.
env = gym.make("ALE/Pong-v5", render_mode="rgb_array")

from datetime import datetime

# Record videos into a fresh, timestamped folder named after the W&B project.
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
video_folder = f"./videos/{project}_{timestamp}"
env = gym.wrappers.RecordVideo(env, video_folder)

# Preprocess observations: grayscale, resize to 105x80, stack the last 4 frames.
# NOTE(review): presumably meant to mirror the training env's preprocessing — confirm.
env = gym.wrappers.GrayScaleObservation(env=env)
env = gym.wrappers.ResizeObservation(env=env, shape=(105, 80))
env = gym.wrappers.FrameStack(env=env, num_stack=4)

# Sanity-check the layout fed to the network: move the last axis to the front and add a
# leading batch axis, then print the resulting shape.
ob, _ = env.reset()
ob = np.expand_dims(np.array(ob).transpose(2, 0, 1), 0)
print(ob.shape)

(1, 80, 4, 105)


In [14]:
import gymnasium as gym
import numpy as np
from datetime import datetime
from pathlib import Path

#timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
#video_folder = f"./videos/PPO_scratch_Breakout_{timestamp}"

#env = gym.make("", render_mode="rgb_array")
#env = gym.wrappers.RecordVideo(env, video_folder)

# Play 9 evaluation episodes with the trained policy, printing each episode's return.
# The env is closed in `finally` so recordings are flushed even if an episode fails,
# and before the newest video is looked up.
try:
    for episode in range(9):
        ob, _ = env.reset()
        # Same layout as in the shape check: last axis first, plus a leading batch axis.
        ob = np.expand_dims(np.array(ob).transpose(2, 0, 1), 0)
        ret = 0
        done, truncated = False, False
        while not (done or truncated):
            # Inference only: skip autograd bookkeeping.
            with torch.no_grad():
                logits, _ = net(ob)
                actions, *_ = select_actions(logits)
            # Hand the env a plain Python int rather than a 0-d array.
            ob, reward, done, truncated, _ = env.step(int(actions[0]))
            ob = np.expand_dims(np.array(ob).transpose(2, 0, 1), 0)
            ret += reward

        print(ret)
finally:
    env.close()

if wandb.run is not None:
    # Upload the most recently written recording, if any was produced.
    latest_video = max(
        Path(video_folder).glob("*.mp4"),
        key=lambda x: x.stat().st_mtime,
        default=None,
    )
    if latest_video is not None:
        wandb.log({
            "video": wandb.Video(str(latest_video))
        })
    wandb.finish()

  return node.target(*args, **kwargs)
  return target(*args, **kwargs)
  return func(*args, **kwargs)


MoviePy - Building video /home/fitti/projects/puffer/videos/PPO-scratch-pong_20250222_082927/rl-video-episode-0.mp4.
MoviePy - Writing video /home/fitti/projects/puffer/videos/PPO-scratch-pong_20250222_082927/rl-video-episode-0.mp4



                                                              

MoviePy - Done !
MoviePy - video ready /home/fitti/projects/puffer/videos/PPO-scratch-pong_20250222_082927/rl-video-episode-0.mp4
-21.0




MoviePy - Building video /home/fitti/projects/puffer/videos/PPO-scratch-pong_20250222_082927/rl-video-episode-1.mp4.
MoviePy - Writing video /home/fitti/projects/puffer/videos/PPO-scratch-pong_20250222_082927/rl-video-episode-1.mp4



                                                              

MoviePy - Done !
MoviePy - video ready /home/fitti/projects/puffer/videos/PPO-scratch-pong_20250222_082927/rl-video-episode-1.mp4
-21.0




-21.0
-21.0
-21.0
-21.0
-21.0
-21.0
MoviePy - Building video /home/fitti/projects/puffer/videos/PPO-scratch-pong_20250222_082927/rl-video-episode-8.mp4.
MoviePy - Writing video /home/fitti/projects/puffer/videos/PPO-scratch-pong_20250222_082927/rl-video-episode-8.mp4



                                                              

MoviePy - Done !
MoviePy - video ready /home/fitti/projects/puffer/videos/PPO-scratch-pong_20250222_082927/rl-video-episode-8.mp4
-21.0




0,1
return,█▁▁▁▁▁▁█▁▁█▁▁▁▁▁▁▁▁▁█▁▁██▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
return,-21
