# Space Invaders: PPO + Rainbow 

This notebook contains two independent agents for Atari Space Invaders:

- **PPO (on-policy)**: vectorized rollouts + GAE + clipped surrogate objective  
- **Rainbow DQN (off-policy)**: NoisyNet + Dueling + C51 + PER + n-step returns

They share environment-construction utilities and evaluation helpers, but their training loops are separate. PPO evaluation uses its own environment builder (`make_ppo_eval_env`).


In [None]:
# ---- Runtime / Device ----
import os
import time
import csv
import random
from collections import deque

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import gymnasium as gym
import ale_py

from gymnasium.wrappers import AtariPreprocessing, FrameStackObservation, ResizeObservation, GrayscaleObservation

# Make the ALE/... environment ids known to gymnasium.make.
gym.register_envs(ale_py)

# Use CUDA when PyTorch can see a GPU, otherwise run on the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Report what was detected so notebook output records the hardware used.
print("torch.cuda.is_available():", torch.cuda.is_available())
if torch.cuda.is_available():
    print("GPU:", torch.cuda.get_device_name(0))
print("DEVICE:", DEVICE)


## Shared Atari environment utilities
We standardise environment construction so that both agents are compared fairly.

In [None]:
def make_atari_env(
    seed: int = 0,
    *,
    episodic_life: bool = False,
    sticky_actions: bool = True,
    frame_skip: int = 4,
    frame_stack: int = 4,
    render_mode=None,
):
    """Create ALE/SpaceInvaders-v5 with standard Atari preprocessing.

    - The ALE-level ``frameskip`` is 1; frame skipping (``frame_skip``) is done
      by AtariPreprocessing instead.
    - ``sticky_actions`` switches ``repeat_action_probability`` between 0.25 and 0.0.
    - AtariPreprocessing applies a no-op reset (``noop_max=30``) and produces
      84x84 grayscale uint8 frames (``scale_obs=False``). It optionally ends an
      episode on life loss (``episodic_life``).
    - FrameStackObservation stacks the last ``frame_stack`` frames, so
      observations have shape ``(frame_stack, 84, 84)``.

    The returned env has already been reset once with ``seed``, and its action
    space is seeded with the same value.

    Note: PPOCNNPolicy, RainbowDQN and ``fix_obs`` in this notebook assume
    ``frame_stack == 4``.
    """
    env = gym.make(
        "ALE/SpaceInvaders-v5",
        frameskip=1,
        repeat_action_probability=0.25 if sticky_actions else 0.0,
        full_action_space=False,
        render_mode=render_mode,
    )

    env = AtariPreprocessing(
        env,
        noop_max=30,
        frame_skip=frame_skip,
        screen_size=84,
        terminal_on_life_loss=episodic_life,
        grayscale_obs=True,
        grayscale_newaxis=False,
        scale_obs=False,
    )

    env = FrameStackObservation(env, frame_stack)
    env.reset(seed=seed)
    # Seed the action space too, so any action_space.sample() is reproducible.
    env.action_space.seed(seed)

    return env


def fix_obs(obs: np.ndarray) -> np.ndarray:
    """Coerce one stacked observation to channel-first ``(4, 84, 84)`` layout.

    Inputs are first converted with ``np.asarray``, so lists and other
    array-likes get the same layout handling as ndarrays.

    Handled layouts:
    - ``(4, 84, 84)``: returned as-is (an ndarray input is returned unchanged).
    - ``(4, 84, 84, 1)``: the trailing singleton axis is dropped.
    - ``(84, 84, 4)``: channel-last stack, transposed to channel-first.
    - ``(84, 84)``: a single frame, repeated 4 times along a new leading axis.

    Any other shape is returned as an ndarray without reshaping. The dtype is
    never changed; callers cast (e.g. ``.astype(np.uint8)``) when they need to.
    """
    arr = np.asarray(obs)

    if arr.shape == (4, 84, 84, 1):
        arr = arr[..., 0]

    if arr.shape == (84, 84, 4):
        return np.transpose(arr, (2, 0, 1))

    if arr.shape == (84, 84):
        return np.repeat(arr[None, ...], 4, axis=0)

    # (4, 84, 84) and any unrecognised shape fall through unchanged.
    return arr


## PPO Agent

In [None]:
# PPO model
class PPOCNNPolicy(nn.Module):
    """Convolutional actor-critic for 4 stacked 84x84 frames.

    A shared trunk (three conv layers plus a 512-unit hidden layer) feeds two
    linear heads: policy logits of size ``action_dim`` and a scalar value.
    """

    def __init__(self, action_dim: int):
        super().__init__()

        # 4 input channels (stacked frames); an 84x84 input ends as 64 x 7 x 7.
        conv_stack = [
            nn.Conv2d(4, 32, 8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, stride=1),
            nn.ReLU(),
        ]
        self.conv = nn.Sequential(*conv_stack)

        hidden_units = 512
        self.fc = nn.Sequential(nn.Linear(64 * 7 * 7, hidden_units), nn.ReLU())

        self.policy = nn.Linear(hidden_units, action_dim)
        self.value = nn.Linear(hidden_units, 1)

    def forward(self, x: torch.Tensor):
        """Return ``(logits, value)`` for a batch of raw 0-255 pixel frames.

        The input is divided by 255 here, so callers may pass uint8-valued
        float tensors (or uint8 tensors) directly.
        """
        scaled = x / 255.0
        feature_map = self.conv(scaled)
        flat = feature_map.reshape(feature_map.shape[0], -1)
        hidden = self.fc(flat)
        logits = self.policy(hidden)
        value = self.value(hidden)
        return logits, value


In [None]:
# ---- PPO: Training / Evaluation ----
def _ppo_gae(rewards, values, dones, next_value, gamma, gae_lambda):
    """Compute GAE(lambda) advantages and value targets for one (T, N) rollout.

    Args:
        rewards: (T, N) rewards; row t is the reward for the action taken at step t.
        values: (T, N) value estimates recorded during the rollout.
        dones: (T, N) booleans; dones[t] is True when the episode ended after
            step t, which stops bootstrapping across that boundary.
        next_value: (N,) value estimate for the observation after step T-1.
        gamma: Discount factor.
        gae_lambda: GAE lambda.

    Returns:
        (advantages, returns) as (T, N) float64 arrays, with returns = advantages + values.
    """
    num_steps = rewards.shape[0]
    advantages = np.zeros(rewards.shape, dtype=np.float64)
    last_gae = 0.0
    for t in reversed(range(num_steps)):
        next_nonterminal = 1.0 - dones[t].astype(np.float64)
        next_values = next_value if t == num_steps - 1 else values[t + 1]
        delta = rewards[t] + gamma * next_values * next_nonterminal - values[t]
        last_gae = delta + gamma * gae_lambda * next_nonterminal * last_gae
        advantages[t] = last_gae
    return advantages, advantages + values


def train_ppo(config: dict):
    """Train PPO on Space Invaders.

    Notes:
    - PPO is on-policy: do not reuse experience across updates.
    - Rewards are clipped to {-1,0,1} for stability, but the raw episode score
      is tracked separately.
    - Gymnasium 1.x vector envs auto-reset on the step() *after* an episode
      ends: that call ignores the action and returns the reset observation with
      reward 0 and no done flag. Such steps are not real transitions, so they
      are excluded from advantage normalisation and from the loss.
    - Rollout observations are kept on the device as uint8. PPOCNNPolicy
      divides by 255 internally, so this matches float32 storage numerically at
      a quarter of the memory.

    Config keys (defaults): num_envs (8), num_steps (512), total_steps (10M),
    learning_rate (2.5e-4), gamma (0.99), gae_lambda (0.95), clip_epsilon (0.1),
    epochs (4), minibatch_size (512), entropy_coef (0.01), value_coef (0.5),
    linear_lr_decay (True), max_grad_norm (None = no gradient clipping),
    checkpoint_path, save_every (1M), save_template, load_path.

    Returns:
        Path of the final checkpoint.
    """
    num_envs = int(config.get("num_envs", 8))
    num_steps = int(config.get("num_steps", 512))
    total_steps = int(config.get("total_steps", 10_000_000))

    learning_rate = float(config.get("learning_rate", 2.5e-4))
    gamma = float(config.get("gamma", 0.99))

    gae_lambda = float(config.get("gae_lambda", 0.95))
    clip_epsilon = float(config.get("clip_epsilon", 0.1))

    epochs = int(config.get("epochs", 4))
    minibatch_size = int(config.get("minibatch_size", 512))

    entropy_coef = float(config.get("entropy_coef", 0.01))
    value_coef = float(config.get("value_coef", 0.5))
    linear_lr_decay = bool(config.get("linear_lr_decay", True))
    max_grad_norm = config.get("max_grad_norm", None)

    checkpoint_path = str(config.get("checkpoint_path", "space_invaders_final_ppo.pth"))
    save_every = int(config.get("save_every", 1_000_000))
    save_template = str(config.get("save_template", "space_invaders_ppo_{}M.pth"))
    load_path = config.get("load_path", None)

    # Vector env - use SyncVectorEnv for Windows compatibility
    envs = gym.vector.SyncVectorEnv(
        [lambda i=i: make_atari_env(seed=i, episodic_life=False, sticky_actions=False) for i in range(num_envs)]
    )
    try:
        obs, _ = envs.reset()
        action_dim = envs.single_action_space.n

        model = PPOCNNPolicy(action_dim).to(DEVICE)
        optimizer = optim.Adam(model.parameters(), lr=learning_rate, eps=1e-5)

        # Load
        if load_path and os.path.exists(load_path):
            print(f"[PPO] Loading weights from {load_path}")
            model.load_state_dict(torch.load(load_path, map_location=DEVICE))

        elif os.path.exists(checkpoint_path):
            print(f"[PPO] Loading weights from {checkpoint_path}")
            model.load_state_dict(torch.load(checkpoint_path, map_location=DEVICE))

        else:
            print("[PPO] No checkpoint found. Starting from scratch.")

        episode_scores = np.zeros(num_envs, dtype=np.float32)
        recent_scores = deque(maxlen=100)
        # prev_dones[i] is True when env i finished an episode on the previous
        # step, so its next step() is an autoreset rather than a real transition.
        # It persists across rollouts because an episode can end on the last step.
        prev_dones = np.zeros(num_envs, dtype=bool)

        global_step = 0
        start_time = time.time()

        while global_step < total_steps:
            if linear_lr_decay:
                frac = 1.0 - (global_step / max(1, total_steps))
                optimizer.param_groups[0]["lr"] = learning_rate * frac

            obs_buf, act_buf, logp_buf, rew_buf, done_buf, val_buf, valid_buf = [], [], [], [], [], [], []

            for _ in range(num_steps):
                global_step += num_envs

                with torch.no_grad():
                    obs_t = torch.as_tensor(obs, dtype=torch.float32, device=DEVICE)
                    logits, values = model(obs_t)
                    dist = torch.distributions.Categorical(logits=logits)
                    actions = dist.sample()
                    logp = dist.log_prob(actions)

                actions_np = actions.cpu().numpy()
                next_obs, rewards, terminations, truncations, _ = envs.step(actions_np)
                dones = np.logical_or(terminations, truncations)

                # Raw (unclipped) score tracking
                episode_scores += rewards
                for i in np.flatnonzero(dones):
                    recent_scores.append(float(episode_scores[i]))
                    episode_scores[i] = 0.0

                obs_buf.append(obs)
                act_buf.append(actions_np)
                logp_buf.append(logp.cpu().numpy())
                rew_buf.append(np.sign(rewards))
                done_buf.append(dones)
                # squeeze(-1) keeps the env axis even when num_envs == 1.
                val_buf.append(values.squeeze(-1).cpu().numpy())
                valid_buf.append(~prev_dones)

                prev_dones = dones
                obs = next_obs

                if global_step % save_every < num_envs:
                    save_name = save_template.format(global_step // 1_000_000)
                    torch.save(model.state_dict(), save_name)
                    torch.save(model.state_dict(), checkpoint_path)
                    print(f"[PPO] Saved: {save_name}")

            # GAE
            with torch.no_grad():
                next_val = (
                    model(torch.as_tensor(obs, dtype=torch.float32, device=DEVICE))[1]
                    .squeeze(-1)
                    .cpu()
                    .numpy()
                )

            adv, returns = _ppo_gae(
                np.asarray(rew_buf, dtype=np.float64),
                np.asarray(val_buf, dtype=np.float64),
                np.asarray(done_buf, dtype=bool),
                next_val,
                gamma,
                gae_lambda,
            )

            # Only real (non-autoreset) transitions enter normalisation and the loss.
            valid_idx = np.flatnonzero(np.asarray(valid_buf, dtype=bool).reshape(-1))
            if valid_idx.size == 0:
                continue

            flat_adv = adv.reshape(-1)
            valid_adv = flat_adv[valid_idx]
            flat_adv = (flat_adv - valid_adv.mean()) / (valid_adv.std() + 1e-8)

            b_obs = torch.as_tensor(np.asarray(obs_buf, dtype=np.uint8), device=DEVICE).reshape(-1, 4, 84, 84)
            b_act = torch.as_tensor(np.asarray(act_buf).reshape(-1), device=DEVICE)
            b_logp = torch.as_tensor(np.asarray(logp_buf).reshape(-1), dtype=torch.float32, device=DEVICE)
            b_adv = torch.as_tensor(flat_adv, dtype=torch.float32, device=DEVICE)
            b_ret = torch.as_tensor(returns.reshape(-1), dtype=torch.float32, device=DEVICE)

            for _ in range(epochs):
                np.random.shuffle(valid_idx)

                for start in range(0, valid_idx.size, minibatch_size):
                    mb = torch.as_tensor(valid_idx[start:start + minibatch_size], dtype=torch.long, device=DEVICE)

                    logits, values = model(b_obs[mb])
                    dist = torch.distributions.Categorical(logits=logits)
                    new_logp = dist.log_prob(b_act[mb])
                    entropy = dist.entropy().mean()

                    mb_adv = b_adv[mb]
                    ratio = (new_logp - b_logp[mb]).exp()
                    pg_loss = -torch.min(
                        ratio * mb_adv,
                        torch.clamp(ratio, 1 - clip_epsilon, 1 + clip_epsilon) * mb_adv,
                    ).mean()
                    v_loss = ((values.squeeze(-1) - b_ret[mb]) ** 2).mean()
                    loss = pg_loss + value_coef * v_loss - entropy_coef * entropy

                    optimizer.zero_grad()
                    loss.backward()
                    if max_grad_norm is not None:
                        nn.utils.clip_grad_norm_(model.parameters(), float(max_grad_norm))
                    optimizer.step()

            elapsed = time.time() - start_time
            fps = int(global_step / max(1e-8, elapsed))
            avg_score = float(np.mean(recent_scores)) if recent_scores else 0.0
            lr_now = optimizer.param_groups[0]["lr"]

            print(f"[PPO] Step {global_step}/{total_steps} | AvgScore(100): {avg_score:.0f} | FPS: {fps} | LR: {lr_now:.2e}")

        torch.save(model.state_dict(), checkpoint_path)
    finally:
        envs.close()

    print(f"[PPO] Training complete. Saved to {checkpoint_path}")
    return checkpoint_path


def get_noop_action(env) -> int:
    """Return the index of the action named ``"NOOP"``, or 0 as a fallback.

    Names come from ``env.unwrapped.get_action_meanings()``. Any failure during
    the lookup, or the absence of a ``"NOOP"`` entry, yields 0.
    """
    try:
        action_names = env.unwrapped.get_action_meanings()
        return action_names.index("NOOP") if "NOOP" in action_names else 0
    except Exception:
        # Best effort: envs without action meanings fall back to index 0.
        return 0


def print_eval_summary(agent_name: str, returns: list):
    """Print summary statistics for a list of per-episode evaluation scores.

    Args:
        agent_name: Label used in the header (e.g. "PPO").
        returns: Per-episode raw scores in run order (run numbers are 1-based).

    Returns:
        dict of the computed metrics with keys: runs, mean, median, std, min,
        max, q25, q75, iqr, coef_var, std_err, ci_low, ci_high, best_run,
        worst_run, rate_200, rate_500, rate_1000, rate_1500. For an empty
        ``returns`` no statistics are computed and ``{"runs": 0}`` is returned.

    Notes:
        "Std deviation" is the population std (ddof=0), matching the std
        returned by the evaluate_* functions. The standard error and the 95% CI
        of the mean use the sample std (ddof=1), the conventional choice for a
        CI on the mean.
    """
    scores = np.asarray(returns, dtype=np.float64).reshape(-1)
    n = int(scores.size)

    header = f"========== {agent_name} SUMMARY METRICS =========="
    footer = "=" * len(header)

    if n == 0:
        print(header)
        print(f"{'Runs:':<22}0")
        print("No evaluation runs to summarize.")
        print(footer)
        return {"runs": 0}

    mean_score = float(np.mean(scores))
    std_score = float(np.std(scores))
    q25, q75 = (float(q) for q in np.percentile(scores, [25, 75]))
    std_err = float(np.std(scores, ddof=1) / np.sqrt(n)) if n > 1 else 0.0

    metrics = {
        "runs": n,
        "mean": mean_score,
        "median": float(np.median(scores)),
        "std": std_score,
        "min": float(np.min(scores)),
        "max": float(np.max(scores)),
        "q25": q25,
        "q75": q75,
        "iqr": q75 - q25,
        "coef_var": std_score / mean_score if mean_score != 0 else 0.0,
        "std_err": std_err,
        # 95% CI for the mean (normal approximation)
        "ci_low": mean_score - 1.96 * std_err,
        "ci_high": mean_score + 1.96 * std_err,
        "best_run": int(np.argmax(scores)) + 1,
        "worst_run": int(np.argmin(scores)) + 1,
    }
    thresholds = (200, 500, 1000, 1500)
    for threshold in thresholds:
        metrics[f"rate_{threshold}"] = 100.0 * float(np.sum(scores >= threshold)) / n

    print(header)
    print(f"{'Runs:':<22}{n}")
    print(f"{'Mean score:':<22}{metrics['mean']:.2f}")
    print(f"{'Median score:':<22}{metrics['median']:.2f}")
    print(f"{'Std deviation:':<22}{metrics['std']:.2f}")
    print(f"{'Min score:':<22}{metrics['min']:.1f}")
    print(f"{'Max score:':<22}{metrics['max']:.1f}")
    print(f"{'25th percentile:':<22}{metrics['q25']:.2f}")
    print(f"{'75th percentile:':<22}{metrics['q75']:.2f}")
    print(f"{'IQR:':<22}{metrics['iqr']:.2f}")
    print(f"{'Coeff. of variation:':<22}{metrics['coef_var']:.3f}")
    print(f"{'Standard error:':<22}{metrics['std_err']:.2f}")
    print(f"{'95% CI (mean):':<22}[{metrics['ci_low']:.2f}, {metrics['ci_high']:.2f}]")
    print(f"{'Best run #:':<22}{metrics['best_run']}")
    print(f"{'Worst run #:':<22}{metrics['worst_run']}")
    print()
    print("--- Success Rates ---")
    for threshold in thresholds:
        rate = metrics[f"rate_{threshold}"]
        print(f"Score >= {threshold}: {rate:.2f}%")
    print(footer)
    return metrics


def make_ppo_eval_env(sticky_action_prob: float = 0.25):
    """Build the PPO evaluation environment with configurable sticky actions.

    Wrapper stack, applied in order (following the final_eval.py setup):
      1. ALE/SpaceInvaders-v5 with frameskip=4 done inside ALE and
         ``repeat_action_probability=sticky_action_prob``;
      2. ResizeObservation to 84x84;
      3. GrayscaleObservation;
      4. FrameStackObservation over the last 4 frames.

    NOTE(review): PPO is trained on make_atari_env, which uses ALE frameskip=1
    and lets AtariPreprocessing do skipping, resizing and grayscaling. This
    stack preprocesses frames differently. Confirm the train/eval mismatch is
    intended.
    """
    base_env = gym.make(
        "ALE/SpaceInvaders-v5",
        frameskip=4,
        repeat_action_probability=sticky_action_prob,
        render_mode=None,
    )
    resized = ResizeObservation(base_env, (84, 84))
    gray = GrayscaleObservation(resized)
    return FrameStackObservation(gray, stack_size=4)


def obs_to_tensor(obs, device) -> torch.Tensor:
    """Turn one observation into a float32 batch of size 1 on ``device``.

    A 4-D array with a trailing singleton axis (e.g. ``(4, 84, 84, 1)``) has
    that axis dropped first; otherwise the result is ``(1, *obs.shape)``.
    """
    array = np.asarray(obs)
    if array.ndim == 4 and array.shape[-1] == 1:
        array = array[..., 0]
    batch = torch.tensor(array, dtype=torch.float32, device=device)
    return batch[None]


def evaluate_ppo(
    checkpoint_path: str,
    episodes: int = 100,
    seed: int = 123,
    max_random_noops: int = 30,
    sticky_action_prob: float = 0.25,
    verbose: bool = False,
):
    """Evaluate a PPO checkpoint greedily with sticky keys and random no-op starts.

    Episode ``ep`` resets the environment with ``seed + ep`` (as
    ``evaluate_rainbow`` does). It then takes a random number of NOOP actions
    drawn from a generator seeded with ``seed``, so repeated runs see the same
    environment randomness. Actions are the argmax of the policy logits.

    Args:
        checkpoint_path: Path to the model checkpoint.
        episodes: Number of evaluation episodes.
        seed: Random seed for reproducibility (env resets and no-op counts).
        max_random_noops: Maximum number of random no-op actions at episode start.
        sticky_action_prob: Probability of repeating previous action (sticky keys).
        verbose: If True, print per-episode scores and summary statistics.

    Returns:
        Tuple of (mean_score, std_score) over raw episode returns
        (population std).
    """
    env = make_ppo_eval_env(sticky_action_prob=sticky_action_prob)
    try:
        action_dim = env.action_space.n
        noop_action = get_noop_action(env)

        model = PPOCNNPolicy(action_dim).to(DEVICE)
        model.load_state_dict(torch.load(checkpoint_path, map_location=DEVICE))
        model.eval()

        rng = np.random.default_rng(seed)
        returns = []

        if verbose:
            print(f"[PPO EVAL] Starting evaluation: {episodes} episodes, max_noops={max_random_noops}")
            print(f"[PPO EVAL] Sticky Keys: {sticky_action_prob > 0} (Prob: {sticky_action_prob})")
            print(f"[PPO EVAL] Checkpoint: {checkpoint_path}")
            print("-" * 50)

        for ep in range(episodes):
            # Seed every reset so the env's RNG (including sticky actions) is reproducible.
            obs, _ = env.reset(seed=seed + ep)

            # Random no-op actions at episode start
            n_noops = int(rng.integers(0, max_random_noops + 1))
            terminated = truncated = False
            for _ in range(n_noops):
                obs, _, terminated, truncated, _ = env.step(noop_action)
                if terminated or truncated:
                    break

            done = terminated or truncated
            total = 0.0

            while not done:
                obs_t = obs_to_tensor(obs, DEVICE)
                with torch.no_grad():
                    logits, _ = model(obs_t)
                    action = int(torch.argmax(logits, dim=1).item())
                obs, r, terminated, truncated, _ = env.step(action)
                done = bool(terminated or truncated)
                total += float(r)

            returns.append(total)

            if verbose:
                print(f"[PPO EVAL] Episode {ep + 1}/{episodes} | Score: {total:.0f} | NoOps: {n_noops}")
    finally:
        env.close()

    mean_score = float(np.mean(returns))
    std_score = float(np.std(returns))

    if verbose:
        print()
        print_eval_summary("PPO", returns)

    return mean_score, std_score

## Rainbow DQN Agent

In [None]:
# ---- Rainbow Hyperparams ----
# Output directory for the Rainbow CSV training log (created on import).
SAVE_DIR = "./rainbow_space_invaders_log"
os.makedirs(SAVE_DIR, exist_ok=True)

# -- Frame accounting --
# One agent step advances each of NUM_ENVS_RB envs by FRAME_SKIP emulator
# frames, so the step-denominated values below are derived from frame budgets.
NUM_ENVS_RB = 4
FRAME_SKIP = 4
FRAMES_PER_STEP = NUM_ENVS_RB * FRAME_SKIP
TOTAL_FRAMES = 32_000_000
TOTAL_STEPS = TOTAL_FRAMES // FRAMES_PER_STEP
TARGET_UPDATE = 32_000 // FRAMES_PER_STEP      # in agent steps
LEARNING_START = 80_000 // FRAMES_PER_STEP     # in agent steps

# -- Network / optimiser --
NOISY_SIGMA_INIT = 0.5
NUM_ATOMS = 51
V_MIN = -10
V_MAX = 10
LEARNING_RATE_RB = 6.25e-5
ADAM_EPS = 1.5e-4

# -- Returns / replay --
N_STEP = 3
BATCH_SIZE = 64
GAMMA = 0.99
REPLAY_SIZE = 1_000_000

# -- Prioritised replay --
PER_ALPHA = 0.5
PER_BETA_START = 0.4
PER_BETA_END = 1.0
PER_BETA_ANNEAL_STEPS = max(1, TOTAL_STEPS)
PER_MAX_PRIORITY = 1.0

# Write the CSV header only for a fresh log; existing logs are appended to.
csv_path = os.path.join(SAVE_DIR, "training_log.csv")
if not os.path.exists(csv_path):
    with open(csv_path, "w", newline="") as f:
        csv.writer(f).writerow(["step", "episode_reward", "avg_reward_100",
                                "loss", "q_mean", "q_max", "eval_mean", "eval_std"])


In [None]:
# ---- Noisy Linear ----
class NoisyLinear(nn.Module):
    """Linear layer with learnable Gaussian parameter noise (NoisyNet).

    Weight noise is factorised: an input-sized and an output-sized normal
    sample are each passed through f(x) = sign(x) * sqrt(|x|) and combined by
    outer product. Bias noise uses its own raw normal sample. Call
    ``reset_noise()`` to draw new samples; set ``noisy = False`` to use only
    the mean parameters.
    """

    def __init__(self, inp, outp):
        super().__init__()
        self.in_features = inp
        self.out_features = outp
        # When False, forward() ignores sigma/epsilon and uses weight_mu/bias_mu.
        self.noisy = True

        # Registration order fixes state_dict keys; keep it stable for checkpoints.
        self.weight_mu = nn.Parameter(torch.empty(outp, inp))
        self.weight_sigma = nn.Parameter(torch.empty(outp, inp))
        self.register_buffer("weight_epsilon_i", torch.empty(inp))
        self.register_buffer("weight_epsilon_j", torch.empty(outp))

        self.bias_mu = nn.Parameter(torch.empty(outp))
        self.bias_sigma = nn.Parameter(torch.empty(outp))
        self.register_buffer("bias_epsilon", torch.empty(outp))

        self.reset_parameters()
        self.reset_noise()

    def reset_parameters(self):
        """Init means ~ U(-1/sqrt(in), 1/sqrt(in)) and sigmas to NOISY_SIGMA_INIT/sqrt(in)."""
        root_fan_in = np.sqrt(self.in_features)
        mu_bound = 1 / root_fan_in
        sigma_start = NOISY_SIGMA_INIT / root_fan_in
        with torch.no_grad():
            self.weight_mu.uniform_(-mu_bound, mu_bound)
            self.weight_sigma.fill_(sigma_start)
            self.bias_mu.uniform_(-mu_bound, mu_bound)
            self.bias_sigma.fill_(sigma_start)

    @staticmethod
    def _f(x):
        """Noise shaping function f(x) = sign(x) * sqrt(|x|)."""
        return x.sign() * x.abs().sqrt()

    def reset_noise(self):
        """Draw fresh standard-normal samples into all epsilon buffers."""
        for noise_buffer in (self.weight_epsilon_i, self.weight_epsilon_j, self.bias_epsilon):
            noise_buffer.normal_()

    def forward(self, x):
        """Apply the layer; noise is added only while ``self.noisy`` is True."""
        if not self.noisy:
            return nn.functional.linear(x, self.weight_mu, self.bias_mu)

        eps_in = self._f(self.weight_epsilon_i)
        eps_out = self._f(self.weight_epsilon_j)
        # Outer product f(eps_out) x f(eps_in) -> (out_features, in_features).
        weight_noise = eps_out[:, None] * eps_in[None, :]
        noisy_weight = self.weight_mu + self.weight_sigma * weight_noise
        # Bias noise uses bias_epsilon directly, without f(.).
        noisy_bias = self.bias_mu + self.bias_sigma * self.bias_epsilon
        return nn.functional.linear(x, noisy_weight, noisy_bias)


In [None]:
# ---- Rainbow Network (Dueling + Noisy + C51) ----
class RainbowDQN(nn.Module):
    """Dueling C51 Q-network whose fully connected layers are NoisyLinear.

    ``forward`` maps a batch of raw 0-255 stacked frames to per-action
    categorical distributions over ``num_atoms`` return atoms spanning
    [v_min, v_max]. The output has shape ``(batch, action_dim, num_atoms)``
    and is softmax-normalised over the atoms.
    """

    def __init__(self, action_dim, num_atoms=NUM_ATOMS, v_min=V_MIN, v_max=V_MAX):
        super().__init__()
        self.num_atoms = num_atoms
        self.action_dim = action_dim
        self.v_min = v_min
        self.v_max = v_max
        # Fixed atom locations, used to turn distributions into expected Q-values.
        self.register_buffer("support", torch.linspace(v_min, v_max, num_atoms))

        # 4 stacked frames in; flattened to 64*7*7 features for the noisy head.
        torso = [
            nn.Conv2d(4, 32, 8, 4), nn.ReLU(),
            nn.Conv2d(32, 64, 4, 2), nn.ReLU(),
            nn.Conv2d(64, 64, 3, 1), nn.ReLU(),
            nn.Flatten(),
        ]
        self.feature = nn.Sequential(*torso)
        self.fc = nn.Sequential(NoisyLinear(64 * 7 * 7, 512), nn.ReLU())
        self.value = NoisyLinear(512, num_atoms)
        self.advantage = NoisyLinear(512, action_dim * num_atoms)

    def _noisy_layers(self):
        """Yield every NoisyLinear submodule in module-registration order."""
        return (module for module in self.modules() if isinstance(module, NoisyLinear))

    def set_noise(self, noisy: bool):
        """Enable or disable parameter noise in every NoisyLinear layer."""
        for layer in self._noisy_layers():
            layer.noisy = noisy

    def reset_noise(self):
        """Resample the noise of every NoisyLinear layer."""
        for layer in self._noisy_layers():
            layer.reset_noise()

    def forward(self, x):
        hidden = self.fc(self.feature(x / 255.0))
        value = self.value(hidden).view(-1, 1, self.num_atoms)
        advantage = self.advantage(hidden).view(-1, self.action_dim, self.num_atoms)
        # Dueling aggregation per atom: V + A - mean over actions of A.
        atom_logits = value + advantage - advantage.mean(1, keepdim=True)
        return torch.softmax(atom_logits, dim=2)


In [None]:
# ---- PER (SumTree + Prioritized Replay Buffer) ----
class SumTree:
    """Binary sum tree over ``capacity`` leaf priorities for proportional sampling.

    Stored as a flat array of ``2 * capacity - 1`` nodes: node i has children
    2i+1 and 2i+2, and the last ``capacity`` slots are leaves (leaf
    ``capacity - 1 + j`` holds the priority of ``data[j]``). Every internal node
    stores the sum of its two children, so ``tree[0]`` is the total priority.
    Update and lookup walk one root-to-leaf path.
    """

    def __init__(self, capacity):
        self.capacity = int(capacity)
        if self.capacity < 1:
            raise ValueError("SumTree capacity must be >= 1")
        # float64: with ~1e6 leaves, float32 partial sums lose precision and
        # prefix-sum lookups can be misrouted.
        self.tree = np.zeros(2 * self.capacity - 1, dtype=np.float64)
        self.data = [None] * self.capacity
        self.write = 0  # next data slot to overwrite (ring buffer)
        self.n_entries = 0

    def _retrieve(self, idx, s):
        """Descend from node ``idx`` to the leaf whose cumulative range contains ``s``."""
        size = len(self.tree)
        while True:
            left = 2 * idx + 1
            if left >= size:
                return idx
            if s <= self.tree[left]:
                idx = left
            else:
                s -= self.tree[left]
                idx = left + 1

    def total(self):
        """Sum of all leaf priorities."""
        return float(self.tree[0])

    def add(self, p, data):
        """Store ``data`` with priority ``p``, overwriting the oldest slot when full."""
        idx = self.write + self.capacity - 1
        self.data[self.write] = data
        self.update(idx, p)
        self.write = (self.write + 1) % self.capacity
        self.n_entries = min(self.n_entries + 1, self.capacity)

    def update(self, idx, p):
        """Set leaf ``idx`` (a tree index) to priority ``p`` and refresh its ancestors.

        Parents are recomputed from their children rather than adjusted by a
        delta, so rounding errors do not accumulate over many updates. The root
        (idx 0, the only node when capacity == 1) has no parent to update.
        """
        self.tree[idx] = p
        while idx > 0:
            idx = (idx - 1) // 2
            left = 2 * idx + 1
            self.tree[idx] = self.tree[left] + self.tree[left + 1]

    def get(self, s):
        """Return ``(tree_idx, priority, data)`` for the leaf covering prefix sum ``s``."""
        idx = self._retrieve(0, s)
        data_idx = idx - self.capacity + 1
        return idx, float(self.tree[idx]), self.data[data_idx]


class PrioritizedReplayBuffer:
    """Proportional prioritised experience replay backed by a SumTree.

    Leaves hold ``priority ** alpha``. Sampling is stratified: one uniform draw
    per equal slice of the total mass. Importance weights
    ``(N * P(i) + 1e-8) ** -beta`` are normalised by the batch maximum and
    returned as a float32 tensor on DEVICE.
    """

    def __init__(self, size=REPLAY_SIZE, alpha=PER_ALPHA):
        self.alpha = alpha
        self.capacity = int(size)
        self.tree = SumTree(self.capacity)
        # New transitions are inserted with the largest priority seen so far.
        self.max_priority = 1.0

    def add(self, transition):
        """Insert ``transition`` with priority ``max_priority ** alpha``."""
        self.tree.add((self.max_priority ** self.alpha), transition)

    def sample(self, batch_size, beta):
        """Draw ``batch_size`` transitions.

        Returns ``(batch, tree_idxs, weights)``, or ``(None, None, None)`` when
        the tree is empty or any draw lands on a slot that holds no data.
        """
        total = self.tree.total()
        if total <= 0 or self.tree.n_entries == 0:
            return None, None, None

        width = total / batch_size
        picked, leaf_ids, leaf_priorities = [], [], []
        for k in range(batch_size):
            low, high = width * k, width * (k + 1)
            leaf, priority, item = self.tree.get(random.uniform(low, high))
            if item is None:
                return None, None, None
            picked.append(item)
            leaf_ids.append(leaf)
            leaf_priorities.append(priority)

        sample_probs = np.array(leaf_priorities, dtype=np.float32) / (total + 1e-8)
        is_weights = (self.tree.n_entries * sample_probs + 1e-8) ** (-beta)
        is_weights = is_weights / (is_weights.max() + 1e-8)
        return picked, leaf_ids, torch.tensor(is_weights, device=DEVICE, dtype=torch.float32)

    def update(self, tree_idxs, priorities):
        """Store new priorities, clipped to [1e-6, PER_MAX_PRIORITY], for sampled leaves.

        NOTE(review): PER_MAX_PRIORITY is 1.0 in this notebook, so any priority
        >= 1 is stored as 1 and max_priority never rises above its initial 1.0.
        If the caller passes per-sample C51 cross-entropy losses (as returned
        by rainbow_loss), these can exceed 1. Confirm the cap is intended,
        because it can push sampling towards uniform.
        """
        for leaf, raw_priority in zip(tree_idxs, priorities):
            clipped = min(max(float(raw_priority), 1e-6), PER_MAX_PRIORITY)
            self.tree.update(int(leaf), (clipped ** self.alpha))
            self.max_priority = max(self.max_priority, clipped)


In [None]:
# ---- N-step buffer ----
class NStepBuffer:
    """Per-environment FIFO that folds 1-step transitions into n-step ones.

    Each emitted tuple is ``(s_0, a_0, R, s_k, done_k)`` with
    ``R = sum_{i<k} gamma**i * r_i``. Here k is ``n``, or smaller when a done
    transition is reached first; accumulation stops at the first done.
    """

    def __init__(self, n=N_STEP, gamma=GAMMA):
        self.n = n
        self.gamma = gamma
        self.buf = deque()

    def reset(self):
        """Drop all queued transitions."""
        self.buf.clear()

    def add(self, s, a, r, ns, done):
        """Queue one 1-step transition."""
        self.buf.append((s, a, r, ns, done))

    def _compute_nstep(self, k):
        """Fold up to the first ``k`` queued transitions, starting at the oldest."""
        discounted = 0.0
        last_next_state = None
        hit_done = False
        for offset in range(k):
            _, _, reward, next_state, is_done = self.buf[offset]
            discounted += (self.gamma ** offset) * float(reward)
            last_next_state = next_state
            if is_done:
                hit_done = True
                break
        first_state, first_action = self.buf[0][0], self.buf[0][1]
        return first_state, first_action, discounted, last_next_state, hit_done

    def pop_nstep_if_ready(self):
        """Once ``n`` transitions are queued, emit the oldest one's n-step tuple; else None."""
        if len(self.buf) >= self.n:
            ready = self._compute_nstep(self.n)
            self.buf.popleft()
            return ready
        return None

    def flush_all(self):
        """Emit n-step tuples for every queued entry (shorter horizons at the tail) and empty the queue."""
        emitted = []
        while self.buf:
            emitted.append(self._compute_nstep(min(self.n, len(self.buf))))
            self.buf.popleft()
        return emitted


In [None]:
# ---- C51 Projection + Loss ----
def c51_projection(next_dist, rewards, dones, support, v_min, v_max, num_atoms, gamma):
    """Project the Bellman-shifted distribution back onto the fixed C51 support.

    Each atom z_j of ``next_dist`` (shape (B, num_atoms)) moves to
    ``r + (1 - done) * gamma * z_j``, clipped to [v_min, v_max]. Its mass is
    split linearly between the two neighbouring support atoms, or assigned
    whole when it lands exactly on an atom. Rows are renormalised to sum to 1.
    """
    spacing = (v_max - v_min) / (num_atoms - 1)
    last_atom = num_atoms - 1

    shifted = rewards[:, None] + (1.0 - dones[:, None]) * (gamma * support[None, :])
    shifted = shifted.clamp(v_min, v_max)
    position = (shifted - v_min) / spacing
    lower = position.floor().long()
    upper = position.ceil().long()

    projected = torch.zeros_like(next_dist)

    # Atoms landing exactly on a support point: all mass goes to that point.
    on_grid = lower == upper
    projected.scatter_add_(1, lower.clamp(0, last_atom), next_dist * on_grid.float())

    # Atoms between two support points: linear interpolation of the mass.
    between = ~on_grid
    if between.any():
        rows = between.nonzero(as_tuple=True)[0]
        lo = lower[between].clamp(0, last_atom)
        hi = upper[between].clamp(0, last_atom)
        pos = position[between]
        mass = next_dist[between]

        projected.index_put_((rows, lo), mass * (hi.float() - pos), accumulate=True)
        projected.index_put_((rows, hi), mass * (pos - lo.float()), accumulate=True)

    return projected / (projected.sum(dim=1, keepdim=True) + 1e-8)


def rainbow_loss(model, target, batch, weights, n_step=N_STEP, gamma=GAMMA):
    """Importance-weighted C51 n-step loss with double-DQN action selection.

    Args:
        model: Online network; its distributions are trained.
        target: Target network; evaluates the online network's greedy next action.
        batch: Sequence of (state, action, n_step_return, next_state, done) tuples.
        weights: Per-sample importance weights (tensor on DEVICE).
        n_step, gamma: The bootstrap discount is ``gamma ** n_step``.

    Fresh noise is drawn for both networks once per call. The online network
    uses the same noise for next-state action selection and for the
    current-state forward pass.

    Returns:
        (weighted mean loss, per-sample cross-entropy as a numpy array,
         online-network distributions for the batch states).
    """
    raw_states, raw_actions, raw_rewards, raw_next_states, raw_dones = zip(*batch)
    states = torch.from_numpy(np.stack(raw_states)).to(DEVICE).float()
    next_states = torch.from_numpy(np.stack(raw_next_states)).to(DEVICE).float()
    actions = torch.tensor(raw_actions, device=DEVICE, dtype=torch.long)
    rewards = torch.tensor(raw_rewards, device=DEVICE, dtype=torch.float32)
    dones = torch.tensor(raw_dones, device=DEVICE, dtype=torch.float32)
    rows = torch.arange(len(batch), device=DEVICE)

    bootstrap_gamma = gamma ** n_step

    model.reset_noise()
    target.reset_noise()

    with torch.no_grad():
        # Online network picks the greedy next action from expected Q-values...
        online_next = model(next_states)
        best_next = (online_next * model.support).sum(2).argmax(1)
        # ...and the target network supplies that action's return distribution.
        next_dist = target(next_states)[rows, best_next]

        proj = c51_projection(
            next_dist=next_dist,
            rewards=rewards,
            dones=dones,
            support=model.support,
            v_min=model.v_min,
            v_max=model.v_max,
            num_atoms=model.num_atoms,
            gamma=bootstrap_gamma,
        )

    dist_all = model(states)
    chosen = dist_all[rows, actions]
    per_sample = -(proj * torch.log(chosen + 1e-8)).sum(1)
    weighted = (per_sample * weights).mean()
    return weighted, per_sample.detach().cpu().numpy(), dist_all


In [None]:
# ---- Rainbow: Evaluation ----
def evaluate_rainbow(
    model: RainbowDQN,
    episodes: int = 30,
    seed: int = 123,
    max_random_noops: int = 30,
    verbose: bool = False,
):
    """Evaluate a Rainbow DQN greedily (noise disabled) with random no-op starts.

    The environment is ``make_atari_env(seed, episodic_life=False,
    sticky_actions=True)``. Episode ``ep`` resets with ``seed + ep`` and then
    takes a random number of NOOP actions drawn from a generator seeded with
    ``seed``.

    NOTE(review): make_atari_env's AtariPreprocessing is built with
    noop_max=30, so its reset no-ops come on top of the random no-ops here.
    Confirm this double no-op start is intended.

    The model's train/eval mode and every NoisyLinear ``noisy`` flag are
    restored to their pre-call values on exit, including when an exception is
    raised. The environment is always closed.

    Args:
        model: The RainbowDQN model to evaluate.
        episodes: Number of evaluation episodes.
        seed: Random seed for reproducibility.
        max_random_noops: Maximum number of random no-op actions at episode start.
        verbose: If True, print per-episode scores and summary statistics.

    Returns:
        Tuple of (mean_score, std_score) over raw episode returns (population std).
    """
    env = make_atari_env(seed=seed, episodic_life=False, sticky_actions=True)

    was_training = model.training
    noisy_layers = [m for m in model.modules() if isinstance(m, NoisyLinear)]
    prev_noise = [layer.noisy for layer in noisy_layers]

    model.eval()
    model.set_noise(False)
    try:
        noop_action = get_noop_action(env)
        rng = np.random.default_rng(seed)
        returns = []

        if verbose:
            print(f"[RAINBOW EVAL] Starting evaluation: {episodes} episodes, max_noops={max_random_noops}")
            print("-" * 50)

        for ep in range(episodes):
            obs, _ = env.reset(seed=seed + ep)
            obs = fix_obs(obs).astype(np.uint8)

            # Random no-op actions at episode start
            n_noops = int(rng.integers(0, max_random_noops + 1))
            terminated = truncated = False
            for _ in range(n_noops):
                obs, _, terminated, truncated, _ = env.step(noop_action)
                obs = fix_obs(obs).astype(np.uint8)
                if terminated or truncated:
                    break

            done = terminated or truncated
            total_reward = 0.0

            while not done:
                obs_t = torch.from_numpy(obs[None]).to(DEVICE).float()
                with torch.no_grad():
                    dist = model(obs_t)
                    q = (dist * model.support).sum(2)
                    action = int(q.argmax(1).item())

                obs, reward, terminated, truncated, _ = env.step(action)
                obs = fix_obs(obs).astype(np.uint8)
                done = bool(terminated or truncated)
                total_reward += float(reward)

            returns.append(total_reward)

            if verbose:
                print(f"[RAINBOW EVAL] Episode {ep + 1}/{episodes} | Score: {total_reward:.0f} | NoOps: {n_noops}")
    finally:
        env.close()
        for layer, flag in zip(noisy_layers, prev_noise):
            layer.noisy = flag
        model.train(was_training)

    mean_score = float(np.mean(returns))
    std_score = float(np.std(returns))

    if verbose:
        print()
        print_eval_summary("RAINBOW", returns)

    return mean_score, std_score


In [None]:
# ---- Rainbow: Training ----
def train_rainbow(
    *,
    total_steps: int = TOTAL_STEPS,
    num_envs: int = NUM_ENVS_RB,
    eval_interval_steps: int = 1_000_000 // FRAMES_PER_STEP,
    log_interval: int = 10_000,
    csv_path=None,
):
    """Train a Rainbow DQN agent on vectorised Space Invaders environments.

    Every iteration resamples the NoisyNet noise, takes the greedy action w.r.t.
    the expected Q-values of the predicted return distribution in each of the
    ``num_envs`` sub-environments, pushes the resulting transitions through
    per-env n-step buffers into a prioritized replay buffer and, once
    ``LEARNING_START`` iterations have passed, performs one gradient update.

    Args:
        total_steps: Number of vectorised environment steps (loop iterations).
        num_envs: Number of sub-environments in the ``SyncVectorEnv``.
        eval_interval_steps: Run ``evaluate_rainbow`` (30 episodes) every this
            many iterations.
        log_interval: Every this many iterations, print progress, append a CSV
            row and possibly save the best model.
        csv_path: File the training log is appended to (no header is written
            here). Defaults to a module-level ``csv_path`` if one is defined,
            otherwise ``SAVE_DIR/rainbow_log.csv``.

    Returns:
        Path of the best-model checkpoint, ``SAVE_DIR/rainbow_best.pth``. The file
        is only written once at least 10 episodes have finished at a logging step.
    """
    from contextlib import closing

    if csv_path is None:
        # Earlier versions read a module-level ``csv_path``; keep honouring it
        # when present, otherwise log next to the checkpoints.
        csv_path = globals().get("csv_path") or os.path.join(SAVE_DIR, "rainbow_log.csv")
    best_path = os.path.join(SAVE_DIR, "rainbow_best.pth")

    env_fns = [
        lambda i=i: make_atari_env(seed=i, episodic_life=True, sticky_actions=True)
        for i in range(num_envs)
    ]

    # closing()/with guarantee the vector env and the CSV log are released even
    # if training raises or is interrupted from the notebook.
    with closing(gym.vector.SyncVectorEnv(env_fns)) as env, open(csv_path, "a", newline="") as csv_file:
        csv_writer = csv.writer(csv_file)

        obs, _ = env.reset()
        obs = np.stack([fix_obs(o) for o in obs]).astype(np.uint8)
        action_dim = env.single_action_space.n

        model = RainbowDQN(action_dim).to(DEVICE)
        target = RainbowDQN(action_dim).to(DEVICE)
        target.load_state_dict(model.state_dict())

        opt = optim.Adam(model.parameters(), lr=LEARNING_RATE_RB, eps=ADAM_EPS)
        buffer = PrioritizedReplayBuffer(size=REPLAY_SIZE, alpha=PER_ALPHA)
        nbufs = [NStepBuffer(n=N_STEP, gamma=GAMMA) for _ in range(num_envs)]

        # NOTE(review): with episodic_life=True, `dones` also fires on life loss,
        # so these "episode" returns may be per-life rather than full-game scores.
        rewards100 = deque(maxlen=100)
        episode_rewards = np.zeros(num_envs, dtype=np.float32)
        episode_count = 0

        beta = PER_BETA_START
        beta_increment = (PER_BETA_END - PER_BETA_START) / PER_BETA_ANNEAL_STEPS

        last_loss = 0.0
        avg100 = 0.0
        q_mean = 0.0
        q_max = 0.0
        last_eval_mean = 0.0
        last_eval_std = 0.0
        best_avg = -float("inf")

        # Gymnasium >= 1.0 vector envs default to "next-step" autoreset: the step
        # after a terminal one resets that sub-env, ignores its action and returns
        # the reset observation with reward 0. That step is not a real transition
        # (terminal frame -> first frame of a new episode) and must not be stored.
        # If the env reports a different autoreset mode, nothing is skipped.
        autoreset_mode = (getattr(env, "metadata", None) or {}).get("autoreset_mode")
        next_step_autoreset = getattr(autoreset_mode, "name", "NEXT_STEP") == "NEXT_STEP"
        skip_transition = np.zeros(num_envs, dtype=bool)

        for step in range(1, total_steps + 1):
            # Action selection: fresh NoisyNet noise, greedy on expected Q.
            model.reset_noise()
            with torch.no_grad():
                obs_t = torch.from_numpy(obs).to(DEVICE).float()
                dist = model(obs_t)
                q = (dist * model.support).sum(2)
                actions = q.argmax(1).cpu().numpy()

            # Env step
            next_obs, rewards, terminated, truncated, _ = env.step(actions)
            dones = np.logical_or(terminated, truncated)
            next_obs = np.stack([fix_obs(o) for o in next_obs]).astype(np.uint8)

            # Store transitions / episode tracking
            for i in range(num_envs):
                if skip_transition[i]:
                    # This step only reset sub-env i (see autoreset note above).
                    continue

                r_clip = float(np.clip(rewards[i], -1.0, 1.0))  # train on clipped rewards
                done = bool(dones[i])

                nbufs[i].add(obs[i], int(actions[i]), r_clip, next_obs[i], done)
                episode_rewards[i] += float(rewards[i])  # log unclipped score

                out = nbufs[i].pop_nstep_if_ready()
                if out is not None:
                    buffer.add(out)

                if done:
                    # Drain the partial n-step returns left at episode end.
                    for t in nbufs[i].flush_all():
                        buffer.add(t)

                    rewards100.append(float(episode_rewards[i]))
                    episode_rewards[i] = 0.0
                    episode_count += 1

            # Sub-envs that just finished get reset on the next env.step().
            skip_transition = np.logical_and(dones, next_step_autoreset)
            obs = next_obs

            # Train update
            if step > LEARNING_START and buffer.tree.n_entries >= BATCH_SIZE:
                beta = min(PER_BETA_END, beta + beta_increment)
                batch, tree_idxs, weights = buffer.sample(BATCH_SIZE, beta)

                if batch is not None:
                    loss, per_sample_loss, dist_all = rainbow_loss(model, target, batch, weights)
                    last_loss = float(loss.item())

                    opt.zero_grad()
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 10.0)
                    opt.step()

                    with torch.no_grad():
                        q_vals = (dist_all * model.support).sum(2)
                        q_mean = float(q_vals.mean().item())
                        q_max = float(q_vals.max().item())

                    # Small epsilon keeps every priority strictly positive.
                    buffer.update(tree_idxs, per_sample_loss + 1e-6)

            # Target update
            if step % TARGET_UPDATE == 0:
                target.load_state_dict(model.state_dict())

            # Eval
            if step % eval_interval_steps == 0:
                last_eval_mean, last_eval_std = evaluate_rainbow(model, episodes=30, seed=123)
                print(f"[RNBW][EVAL] Step {step} | Mean {last_eval_mean:.1f} | Std {last_eval_std:.1f}")

            # Logging
            if step % log_interval == 0:
                avg100 = float(np.mean(rewards100)) if rewards100 else 0.0
                if episode_count > 0:
                    print(f"[RNBW] Step {step}/{total_steps} | Ep {episode_count} | Avg100 {avg100:.1f} | Loss {last_loss:.4f}")

                # Columns: step, last episode return, avg100, loss, q_mean, q_max,
                # last eval mean, last eval std.
                csv_writer.writerow([
                    step,
                    rewards100[-1] if rewards100 else 0.0,
                    avg100,
                    last_loss,
                    q_mean,
                    q_max,
                    last_eval_mean,
                    last_eval_std,
                ])
                csv_file.flush()

                # Best model checkpointing (needs >= 10 finished episodes).
                if avg100 > best_avg and len(rewards100) >= 10:
                    best_avg = avg100
                    torch.save(
                        {
                            "step": step,
                            "model_state_dict": model.state_dict(),
                            "target_state_dict": target.state_dict(),
                            "avg_reward_100": avg100,
                        },
                        best_path,
                    )

            # Periodic checkpoint
            if step % 1_000_000 == 0:
                torch.save(
                    {
                        "step": step,
                        "model_state_dict": model.state_dict(),
                        "target_state_dict": target.state_dict(),
                        "optimizer_state_dict": opt.state_dict(),
                    },
                    os.path.join(SAVE_DIR, f"rainbow_step_{step}.pth"),
                )
                print(f"[RNBW] Saved checkpoint: rainbow_step_{step}.pth")

    print("[RNBW] Training done")
    return best_path


## Running code

In [None]:
# Default PPO hyper-parameters; the (commented-out) training cell passes this
# dict to `train_ppo`. Key order is kept stable for readable printing.
PPO_DEFAULT_CONFIG = dict(
    # Rollout collection
    num_envs=8,
    num_steps=512,
    total_steps=100_000_000,
    # Optimisation / advantage estimation
    learning_rate=2.5e-4,
    gamma=0.99,
    gae_lambda=0.95,
    clip_epsilon=0.1,
    epochs=4,
    minibatch_size=512,
    entropy_coef=0.01,
    value_coef=0.5,
    linear_lr_decay=True,
    # Checkpointing
    checkpoint_path="space_invaders_final_ppo.pth",
    save_every=1_000_000,
    save_template="space_invaders_ppo_{}M.pth",
    load_path=None,
)

# Size of Space Invaders' minimal ALE action set.
SPACE_INVADERS_ACTION_DIM = len(("NOOP", "FIRE", "RIGHT", "LEFT", "RIGHTFIRE", "LEFTFIRE"))

def compare_agents(
    *,
    ppo_ckpt: str = "./pretrained_models/space_invaders_ppo_124M.pth",
    rainbow_ckpt: str = "./pretrained_models/space_invaders_rainbow_best.pth",
    episodes: int = 100,
    max_random_noops: int = 30,
    verbose: bool = False,
):
    """Compare PPO and Rainbow agents under the same evaluation protocol.

    Args:
        ppo_ckpt: Path to the PPO checkpoint (passed to ``evaluate_ppo``).
            ``None`` skips the PPO evaluation.
        rainbow_ckpt: Path to the Rainbow checkpoint. ``None`` selects
            ``SAVE_DIR/rainbow_best.pth``, the file written by ``train_rainbow``.
        episodes: Number of evaluation episodes per agent.
        max_random_noops: Maximum random no-op actions at episode start.
        verbose: If True, print per-episode scores and summary statistics.

    Returns:
        Dictionary ``{"PPO": (mean, std), "Rainbow": (mean, std)}``. An agent's
        entry is ``None`` when its checkpoint file does not exist; a warning is
        printed in that case.
    """
    results = {}

    # PPO: evaluate_ppo loads the checkpoint itself.
    if ppo_ckpt is not None and os.path.exists(ppo_ckpt):
        m, s = evaluate_ppo(
            ppo_ckpt,
            episodes=episodes,
            max_random_noops=max_random_noops,
            verbose=verbose,
        )
        results["PPO"] = (m, s)
    else:
        print(f"[COMPARE] PPO checkpoint not found, skipping: {ppo_ckpt}")
        results["PPO"] = None

    # Rainbow: rebuild the network and load its weights before evaluating.
    if rainbow_ckpt is None:
        rainbow_ckpt = os.path.join(SAVE_DIR, "rainbow_best.pth")

    if os.path.exists(rainbow_ckpt):
        chk = torch.load(rainbow_ckpt, map_location=DEVICE)
        model = RainbowDQN(SPACE_INVADERS_ACTION_DIM).to(DEVICE)
        model.load_state_dict(chk["model_state_dict"])
        m, s = evaluate_rainbow(
            model,
            episodes=episodes,
            max_random_noops=max_random_noops,
            verbose=verbose,
        )
        results["Rainbow"] = (m, s)
    else:
        print(f"[COMPARE] Rainbow checkpoint not found, skipping: {rainbow_ckpt}")
        results["Rainbow"] = None

    return results

## Train PPO

In [None]:
# Pretrained models can be found in the pretrained_models folder

#ppo_trained = train_ppo(PPO_DEFAULT_CONFIG)

## Train Rainbow

In [None]:
# Full Rainbow training takes a long time; this call is limited to 200k steps as a proof of concept.
# Pretrained models can be found in the pretrained_models folder

#rainbow_trained = train_rainbow(total_steps=200_000)

## Compare the two models

Each evaluation returns the mean score and its standard deviation for an agent over the specified number of episodes.

In [None]:
# Evaluate PPO: score the pretrained checkpoint over 1000 episodes and report
# mean +/- std. Flip verbose to False to silence per-episode logging.
ppo_mean, ppo_std = evaluate_ppo(checkpoint_path="./pretrained_models/space_invaders_ppo_124M.pth", episodes=1000, verbose=True)

print("PPO:", ppo_mean, "+/-", ppo_std)


In [None]:
# Evaluate Rainbow: load the pretrained checkpoint and rebuild the network with
# the same action-space size constant that compare_agents uses.
rainbow_ckpt = "./pretrained_models/space_invaders_rainbow_best.pth"
chk = torch.load(rainbow_ckpt, map_location=DEVICE)

rainbow_model = RainbowDQN(action_dim=SPACE_INVADERS_ACTION_DIM).to(DEVICE)
rainbow_model.load_state_dict(chk["model_state_dict"])

rainbow_mean, rainbow_std = evaluate_rainbow(
    model=rainbow_model,
    episodes=1000,
    verbose=True,   # toggle logging here
)

print("Rainbow:", rainbow_mean, "+/-", rainbow_std)

In [None]:
results = compare_agents(episodes=100)
print(results)