# **PER (Prioritized Experience Replay) QR-DQN**

Prioritized Experience Replay (PER) improves the sample efficiency of reinforcement learning by prioritizing experiences that are expected to provide more informative gradients. In standard experience replay, all transitions are stored and sampled uniformly; with PER, transitions are prioritized by the magnitude of their temporal-difference (TD) error. The larger the TD error, the more important that experience is for training, since it indicates that the model's prediction for that state-action pair was significantly off. By sampling high-error transitions with higher probability, PER lets the agent learn more effectively from the most surprising or informative experiences. Because this non-uniform sampling biases the updates, PER scales each sampled loss by an importance-sampling weight to correct for it.

In [None]:
# 1) Install compatible versions (clean), then *force restart* the Python process.
#    - NumPy 2.0.2 plays nicely with Gymnasium 0.29.1 and MinAtar 1.0.15.
#    - We *don't* touch TensorFlow/numba; we just won't import them.
#    - `%pip` is an IPython/Jupyter magic, so this cell only runs inside a notebook kernel.
%pip -q install --upgrade --force-reinstall "numpy==2.0.2" gymnasium==0.29.1 minatar==1.0.15

import os, sys  # NOTE(review): `sys` is not used in this cell
print("Restarting runtime now to finalize NumPy ABI ...")
# Send signal 9 (SIGKILL) to this kernel process so the next cell starts in a fresh
# interpreter that imports the newly installed NumPy. Relies on the notebook host
# (e.g. Colab) restarting the kernel; nothing after this line runs.
os.kill(os.getpid(), 9)  # <-- forces a clean interpreter (same as Runtime->Restart)


[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m60.9/60.9 kB[0m [31m3.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m19.2/19.2 MB[0m [31m93.4 MB/s[0m eta [36m0:00:00[0m
[?25h[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
google-colab 1.0.0 requires pandas==2.2.2, but you have pandas 2.3.3 which is incompatible.
dask-cudf-cu12 25.6.0 requires pandas<2.2.4dev0,>=2.0, but you have pandas 2.3.3 which is incompatible.
cudf-cu12 25.6.0 requires pandas<2.2.4dev0,>=2.0, but you have pandas 2.3.3 which is incompatible.
datasets 4.0.0 requires fsspec[http]<=2025.3.0,>=2023.1.0, but you have fsspec 2025.9.0 which is incompatible.
dopamine-rl 4.1.2 requires gymnasium>=1.0.0, but you have gymnasium 0.29.1 which is incompatible.
gcsfs 2025.3.0 requires fsspec==2025.3.0, but you have fsspec 2025.9.0 

In [None]:
import numpy as np, gymnasium as gym
from minatar import Environment as MinAtarBaseEnv
# Confirm which versions were actually loaded after the restart.
print("NumPy:", np.__version__, "| Gymnasium:", gym.__version__)

# quick smoke test (no training yet)
env = MinAtarBaseEnv("breakout", sticky_action_prob=0.1)
# NOTE(review): seed() is called after reset(), so the state read below may not come
# from the seeded RNG; only its shape is inspected here, so that is harmless.
env.reset(); env.seed(123)
s = env.state()   # (H,W,C) — channel-last, as printed in the output below
print("MinAtar Breakout state shape (H,W,C):", s.shape)


NumPy: 2.0.2 | Gymnasium: 0.29.1
MinAtar Breakout state shape (H,W,C): (10, 10, 4)


Imports, config, and helpers

In [None]:
import os, math, copy, random, collections
from dataclasses import dataclass

import numpy as np
import torch, torch.nn as nn
import torch.nn.functional as F

# --------------------------
# Repro & device
# --------------------------
def set_seed(seed: int):
    """Seed Python's `random`, NumPy's global RNG and torch (plus every CUDA device when present)."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

# Prefer the GPU when torch can see one; otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --------------------------
# Epsilon schedule
# --------------------------
def epsilon_by_step(step, cfg):
    """
    Linearly anneal epsilon from ``cfg.eps_start`` to ``cfg.eps_final`` over
    ``cfg.eps_decay`` steps, then hold it at ``cfg.eps_final``.
    """
    horizon = max(1, cfg.eps_decay)   # guard against a zero/negative decay length
    progress = step / horizon
    if progress > 1.0:
        progress = 1.0
    return cfg.eps_start + progress * (cfg.eps_final - cfg.eps_start)

# --------------------------
# Config
# --------------------------
@dataclass
class Config:
    """Hyper-parameters and environment selection for the QR-DQN / PER experiments."""
    # RL
    gamma: float = 0.99          # discount factor of the 1-step TD target
    n_quantiles: int = 51        # quantile outputs per action
    hidden: int = 128            # width of the hidden fully connected layer(s)
    lr: float = 1e-3             # Adam learning rate
    adam_eps: float = 1e-8       # Adam epsilon
    grad_clip: float = 10.0      # max global gradient norm (clip_grad_norm_)

    # training
    total_steps: int = 200_000   # environment steps per run
    buffer_size: int = 100_000   # replay capacity
    batch_size: int = 64
    learn_start: int = 1_000     # no gradient updates until the buffer holds this many transitions
    target_tau: int = 1_000      # hard target-network copy every `target_tau` steps (a period, not a Polyak rate)

    # exploration
    eps_start: float = 1.0
    eps_final: float = 0.01
    eps_decay: int = 100_000     # steps over which epsilon decays linearly to eps_final

    # seeds
    base_seed: int = 0           # default run seed; evaluation seeds are offsets from it

    # env selection
    env_kind: str = "cartpole"   # "cartpole" or "minatar"
    game: str = "breakout"       # MinAtar game name
    sticky: float = 0.1          # MinAtar sticky-action prob

# Handy printer
def summarize(name, scores: np.ndarray):
    """Print mean, median, sample std (ddof=1), count and the rounded per-seed scores."""
    arr = np.array(scores, dtype=np.float32)
    fields = (
        f"mean={arr.mean():.2f}",
        f"median={np.median(arr):.2f}",
        f"std={arr.std(ddof=1):.2f}",
        f"n={len(arr)}",
        f"scores={arr.round(1)}",
    )
    print(f"{name}: " + "  ".join(fields))


Environments (Gymnasium CartPole + MinAtar wrapper)

In [None]:
# import gymnasium as gym
# from minatar import Environment as MinAtarBaseEnv

# # ------------- Gym CartPole -------------
# def make_env_cartpole(seed: int):
#     env = gym.make("CartPole-v1")
#     env.reset(seed=seed)
#     return env

# # ------------- MinAtar (Gym-like) -------------
# class ActionSpace:
#     def __init__(self, n): self.n = n

# class MinAtarGymLike:
#     """
#     Wrap MinAtar to feel like Gymnasium:
#       - reset(seed) -> (obs, info)
#       - step(a)     -> (obs, reward, terminated, truncated, info)
#       - action_space.n
#       - observation: float32 (C,H,W) in {0,1}
#     """
#     def __init__(self, game="breakout", sticky=0.1, seed=0):
#         self.env = MinAtarBaseEnv(game, sticky_action_prob=sticky)
#         self._seed = seed
#         self.env.reset()      # MinAtar does not take seed in ctor
#         self.env.seed(seed)   # separate seeding call
#         self.action_space = ActionSpace(self.env.num_actions())
#         s = self.env.state()  # (H,W,C)
#         H, W, C = s.shape
#         self.shape = (C, H, W)

#     def reset(self, seed=None):
#         if seed is not None:
#             self._seed = seed
#             self.env.seed(seed)
#         self.env.reset()
#         s = self.env.state()            # (H,W,C)
#         s = np.transpose(s, (2,0,1)).astype(np.float32)
#         return s, {}

#     def step(self, a):
#         r, done = self.env.act(a)       # MinAtar returns (reward, terminal)
#         s = self.env.state()
#         s = np.transpose(s, (2,0,1)).astype(np.float32)
#         return s, float(r), bool(done), False, {}

#     def close(self): pass

# def make_env_minatar(seed: int, game="breakout", sticky=0.1):
#     env = MinAtarGymLike(game=game, sticky=sticky, seed=seed)
#     return env

# # ------------- Unified factory -------------
# def make_env(cfg: Config, seed: int):
#     if cfg.env_kind == "cartpole":
#         return make_env_cartpole(seed)
#     elif cfg.env_kind == "minatar":
#         return make_env_minatar(seed, game=cfg.game, sticky=cfg.sticky)
#     else:
#         raise ValueError(f"Unknown env_kind: {cfg.env_kind}")


import numpy as np
import gymnasium as gym
from minatar import Environment as MinAtarBaseEnv

def make_env_cartpole(seed: int):
    """Create Gymnasium ``CartPole-v1`` and seed it through an initial ``reset(seed=...)``."""
    cartpole = gym.make("CartPole-v1")
    cartpole.reset(seed=seed)
    return cartpole

class ActionSpace:
    """Minimal stand-in for a discrete Gym action space: only exposes ``n``."""

    def __init__(self, n):
        self.n = n  # number of discrete actions

class MinAtarGymLike:
    """
    Wrap MinAtar to feel like Gymnasium:
      - reset(seed) -> (obs, info)
      - step(a)     -> (obs, reward, terminated, truncated, info)
      - action_space.n
      - observation_space: ``gym.spaces.Box`` of shape (C,H,W) with bounds [0, 1]
      - observation: float32 (C,H,W) in {0,1}

    Exposing ``observation_space`` lets generic helpers that read
    ``env.observation_space.shape`` (to choose the network) work with this
    wrapper instead of failing with ``AttributeError``.
    This wrapper never reports truncation: ``truncated`` is always False.
    """
    def __init__(self, game="breakout", sticky=0.1, seed=0):
        self.env = MinAtarBaseEnv(game, sticky_action_prob=sticky)
        self._seed = seed
        # Seed before the first reset so that reset can already use the seeded RNG
        # (MinAtar is seeded through a separate seed() call).
        self.env.seed(seed)
        self.env.reset()
        self.action_space = ActionSpace(self.env.num_actions())
        H, W, C = self.env.state().shape       # MinAtar state layout is (H,W,C)
        self.shape = (C, H, W)
        self.observation_space = gym.spaces.Box(
            low=0.0, high=1.0, shape=self.shape, dtype=np.float32
        )

    def _obs(self):
        """Current MinAtar state transposed to channel-first float32 (C,H,W)."""
        return np.transpose(self.env.state(), (2, 0, 1)).astype(np.float32)

    def reset(self, seed=None):
        if seed is not None:
            self._seed = seed
            self.env.seed(seed)
        self.env.reset()
        return self._obs(), {}

    def step(self, a):
        r, done = self.env.act(a)              # MinAtar returns (reward, terminal)
        return self._obs(), float(r), bool(done), False, {}

    def close(self):
        pass  # MinAtar holds no resources that need releasing here

def make_env_minatar(seed: int, game="breakout", sticky=0.1):
    """Build a Gym-like MinAtar env for ``game`` with sticky-action prob ``sticky``, seeded with ``seed``."""
    env = MinAtarGymLike(game=game, sticky=sticky, seed=seed)
    return env

@dataclass
class Config:
    """
    Experiment configuration: environment selection plus QR-DQN hyper-parameters.

    This cell re-defines ``Config``, so it must carry every field that
    ``make_env``, ``train_qrdqn``, ``greedy_eval`` and the run cells read
    (e.g. ``Config(env_kind="cartpole", total_steps=50_000, ...)``).
    The environment fields come first, so the positional form
    ``Config(env_kind, game, sticky)`` still works. ``env_kind`` still
    defaults to ``"minatar"``.
    """
    # env selection
    env_kind: str = "minatar"    # "cartpole" or "minatar"
    game: str = "breakout"       # MinAtar game name
    sticky: float = 0.1          # MinAtar sticky-action prob

    # RL
    gamma: float = 0.99          # discount factor of the 1-step TD target
    n_quantiles: int = 51        # quantile outputs per action
    hidden: int = 128            # hidden layer width
    lr: float = 1e-3             # Adam learning rate
    adam_eps: float = 1e-8       # Adam epsilon
    grad_clip: float = 10.0      # max global gradient norm

    # training
    total_steps: int = 200_000   # environment steps per run
    buffer_size: int = 100_000   # replay capacity
    batch_size: int = 64
    learn_start: int = 1_000     # no updates until the buffer holds this many transitions
    target_tau: int = 1_000      # hard target-network copy period (in steps)

    # exploration
    eps_start: float = 1.0
    eps_final: float = 0.01
    eps_decay: int = 100_000     # steps of linear epsilon decay

    # seeds
    base_seed: int = 0

def make_env(cfg: Config, seed: int):
    """
    Construct the environment selected by ``cfg.env_kind``.

    "cartpole" -> Gymnasium CartPole-v1; "minatar" -> MinAtar ``cfg.game`` with
    sticky-action prob ``cfg.sticky``. Anything else raises ValueError.
    """
    kind = cfg.env_kind
    if kind == "minatar":
        return make_env_minatar(seed, game=cfg.game, sticky=cfg.sticky)
    if kind == "cartpole":
        return make_env_cartpole(seed)
    raise ValueError(f"Unknown env_kind: {cfg.env_kind}")


QR-DQN networks (MLP for vector obs, Conv for MinAtar), action helper, quantile grid

In [None]:
# --------- Quantile support points (fixed midpoints) ----------
def quantile_midpoints(N: int, device=None):
    """Return the N quantile midpoints tau_i = (2i + 1) / (2N), i = 0..N-1, as float32."""
    numer = 2.0 * torch.arange(N, dtype=torch.float32, device=device) + 1.0
    return numer / (2.0 * N)

# --------- Models ----------
class QRDQN_MLP(nn.Module):
    """Two-hidden-layer MLP mapping flat observations to [B, n_actions, n_quantiles] quantile values."""

    def __init__(self, obs_dim, n_actions, n_quantiles=51, hidden=128):
        super().__init__()
        self.n_actions = n_actions
        self.n_quantiles = n_quantiles
        layers = [
            nn.Linear(obs_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_actions * n_quantiles),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        # A single unbatched observation gets a leading batch axis of size 1.
        batch = x.unsqueeze(0) if x.dim() == 1 else x
        flat = self.net(batch)
        return flat.view(flat.shape[0], self.n_actions, self.n_quantiles)

class QRDQN_Conv_MinAtar(nn.Module):
    """
    Two 3x3 conv layers + MLP head mapping (C,H,W) grids to [B, n_actions, n_quantiles].

    ``H`` and ``W`` default to 10x10. The head's input size is derived from them
    instead of being hard-coded, so other grid sizes also work. Both convs use
    stride 1 / padding 1, which keeps the H x W spatial size.
    """
    def __init__(self, C, n_actions, n_quantiles=51, hidden=128, H=10, W=10):
        super().__init__()
        self.n_actions = n_actions
        self.n_quantiles = n_quantiles
        self.torso = nn.Sequential(
            nn.Conv2d(C, 16, kernel_size=3, stride=1, padding=1), nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1), nn.ReLU(),
        )
        self.head = nn.Sequential(
            nn.Linear(32 * H * W, hidden), nn.ReLU(),
            nn.Linear(hidden, n_actions * n_quantiles)
        )

    def forward(self, x):
        # A single unbatched (C,H,W) observation gets a batch axis of size 1.
        if x.dim() == 3: x = x.unsqueeze(0)
        z = self.torso(x)
        z = torch.flatten(z, 1)
        y = self.head(z)
        B = y.shape[0]
        return y.view(B, self.n_actions, self.n_quantiles)

def build_model_for_env(env, cfg: Config):
    """
    Build a QR-DQN network sized for ``env`` and move it to ``device``.

    The observation shape comes from ``env.observation_space.shape`` when the env
    has one. Envs without an ``observation_space`` attribute (or with a ``shape``
    of None) are probed with ``env.reset()`` instead; note that this fallback
    resets the env.

    A 1-D shape selects ``QRDQN_MLP``; a 3-D (C,H,W) shape selects
    ``QRDQN_Conv_MinAtar``. ``cfg.n_quantiles`` and ``cfg.hidden`` size the network.

    Raises:
        ValueError: for any other observation rank.
    """
    # getattr rather than `env.observation_space`: plain attribute access raises
    # AttributeError for wrappers that define no observation_space.
    obs_space = getattr(env, "observation_space", None)
    shape = getattr(obs_space, "shape", None)
    if shape is None:
        s0, _ = env.reset()
        shape = np.array(s0).shape

    if len(shape) == 1:
        obs_dim = int(shape[0])
        nA = env.action_space.n
        net = QRDQN_MLP(obs_dim, nA, cfg.n_quantiles, cfg.hidden)
    elif len(shape) == 3:
        C = int(shape[0])            # channel-first (C,H,W)
        nA = env.action_space.n
        net = QRDQN_Conv_MinAtar(C, nA, cfg.n_quantiles, cfg.hidden)
    else:
        raise ValueError(f"Unsupported observation shape: {shape}")
    return net.to(device)

# --------- Action helper (QR mean-greedy with ε) ----------
@torch.no_grad()
def act_epsilon_greedy(state, online, epsilon: float):
    """
    Epsilon-greedy action for a QR-DQN network.

    With probability ``epsilon`` return a uniform random action in
    [0, online.n_actions). Otherwise return the action whose predicted quantiles
    have the largest mean.
    """
    if np.random.rand() < epsilon:
        return np.random.randint(0, online.n_actions)
    obs = torch.tensor(state, dtype=torch.float32, device=device)
    batch = obs.unsqueeze(0) if obs.dim() == 1 else obs
    q_mean = online(batch).mean(2)          # [B, A]: mean over quantiles
    return int(q_mean.argmax(1).item())


Losses (pinball & Huber-quantile) and per-sample loss for PER

In [None]:
def _quantile_pair_terms(q_pred, q_tgt, taus):
    """Pairwise errors u[b, j, i] = q_tgt[b, j] - q_pred[b, i] and weights |tau_i - 1{u < 0}|."""
    u = q_tgt.unsqueeze(2) - q_pred.unsqueeze(1)        # [B, N_tgt, N_pred]
    tau = taus.view(1, 1, -1)                            # one tau per predicted quantile
    weight = torch.abs(tau - (u.detach() < 0).float())
    return u, weight


def _huber(u, kappa):
    """Element-wise Huber loss with threshold ``kappa`` (> 0)."""
    abs_u = u.abs()
    return torch.where(abs_u <= kappa, 0.5 * u ** 2, kappa * (abs_u - 0.5 * kappa))


def pinball_loss(q_pred, q_tgt, taus):
    """
    QR-DQN quantile (pinball) loss, averaged over the batch and all quantile pairs.

    Every target quantile j is compared with every predicted quantile i,
    as QR-DQN's loss requires. Pairing only i with i (element-wise) would
    regress each quantile onto its matching target point instead of onto the
    tau_i-quantile of the target distribution.

    q_pred, q_tgt: [B, N]; taus: [N] (midpoint fractions of the predicted quantiles).
    Returns a scalar tensor.
    """
    u, weight = _quantile_pair_terms(q_pred, q_tgt, taus)
    return (weight * u.abs()).mean()


def huber_quantile_loss(q_pred, q_tgt, taus, kappa=1.0):
    """
    Quantile Huber loss |tau_i - 1{u<0}| * Huber_kappa(u) / kappa over all (j, i) pairs.

    As kappa -> 0 this tends to the pinball loss, which is used for kappa <= 0
    (avoids dividing by zero). Returns a scalar tensor.
    """
    u, weight = _quantile_pair_terms(q_pred, q_tgt, taus)
    if kappa <= 0.0:
        return (weight * u.abs()).mean()
    return (weight * _huber(u, kappa) / kappa).mean()


# Per-sample version for PER weighting
def qr_loss_per_sample(q_pred, q_tgt, taus, kappa=1.0, use_huber=True):
    """
    Per-sample quantile loss [B] (mean over all target x predicted quantile pairs).

    Uses the quantile Huber loss when ``use_huber`` and ``kappa > 0``, otherwise
    the pinball loss. The batch mean of the result equals ``pinball_loss`` /
    ``huber_quantile_loss``, so IS weights can be applied per sample before reduction.
    """
    u, weight = _quantile_pair_terms(q_pred, q_tgt, taus)
    if use_huber and kappa > 0.0:
        base = _huber(u, kappa) / kappa
    else:
        base = u.abs()
    return (weight * base).mean(dim=(1, 2))


Replay buffers (Uniform + PER)

In [None]:
# --------- Uniform replay (for A/B, optional) ----------
class ReplayUniform:
    """
    Uniform FIFO replay buffer (baseline for A/B against PER).

    Transitions (s, a, r, ns, d) are stored in a Python list used as a ring
    buffer: once ``capacity`` is reached the oldest entry is overwritten.
    A list gives O(1) random indexing; ``deque[i]`` is O(n) away from the ends,
    which makes sampling from a large deque slow.
    """
    def __init__(self, capacity: int):
        if capacity < 1:
            raise ValueError(f"capacity must be >= 1, got {capacity}")
        self.capacity = int(capacity)
        self.data = []
        self._next = 0          # slot overwritten next once the buffer is full

    def push(self, s, a, r, ns, d):
        # Copy the arrays so later in-place edits by the caller cannot alter stored data.
        item = (s.copy(), int(a), float(r), ns.copy(), float(d))
        if len(self.data) < self.capacity:
            self.data.append(item)
        else:
            self.data[self._next] = item
        self._next = (self._next + 1) % self.capacity

    def sample(self, B: int):
        """Return (indices, batch, weights): uniform draws with replacement, weights all 1."""
        idx = np.random.randint(0, len(self.data), size=B)
        batch = [self.data[i] for i in idx]
        return idx, batch, np.ones(B, dtype=np.float32)

    def update_priorities(self, idxs, new_p):
        """No-op: uniform replay has no priorities (kept for interface parity with PER)."""

    def __len__(self):
        return len(self.data)

# --------- PER (SumTree) ----------
class SumTree:
    """
    Binary sum-tree over leaf priorities for O(log n) proportional sampling.

    Leaves live at ``tree[n : 2n]``, where ``n`` is the smallest power of two
    >= ``capacity``. Each internal node ``i`` stores ``tree[2i] + tree[2i+1]``,
    so ``tree[1]`` is the total priority mass.

    Only the first ``capacity`` leaves are ever written. The buffer therefore
    holds at most ``capacity`` items; the padding leaves keep priority 0 and
    are never sampled.
    """
    def __init__(self, capacity: int):
        if capacity < 1:
            raise ValueError(f"capacity must be >= 1, got {capacity}")
        self.capacity = int(capacity)
        self.n = 1
        while self.n < self.capacity: self.n <<= 1
        self.size = 0
        self.write = 0
        self.data = [None] * self.n
        # float64 reduces rounding drift between a parent sum and its children.
        self.tree = np.zeros(2 * self.n, dtype=np.float64)

    def total(self) -> float:
        return float(self.tree[1])

    def _update(self, idx: int, p: float):
        i = idx + self.n
        self.tree[i] = p
        i //= 2
        while i >= 1:
            self.tree[i] = self.tree[2*i] + self.tree[2*i+1]
            i //= 2

    def add(self, p: float, data):
        """Store ``data`` with priority ``p``, overwriting the oldest item when full."""
        self.data[self.write] = data
        self._update(self.write, p)
        self.write = (self.write + 1) % self.capacity
        self.size = min(self.size + 1, self.capacity)

    def sample(self, mass: float):
        """Return (leaf index, data, priority) of the leaf whose cumulative range contains ``mass``."""
        i = 1
        while i < self.n:
            left, right = 2 * i, 2 * i + 1
            # Descend left when the mass falls inside the left subtree or the right
            # subtree is empty. Rounding can make `mass` slightly exceed a subtree's
            # sum; this keeps the walk off unused zero-priority leaves (data None).
            if self.tree[left] > 0.0 and (mass <= self.tree[left] or self.tree[right] <= 0.0):
                i = left
            else:
                mass -= self.tree[left]
                i = right
        idx = i - self.n
        return idx, self.data[idx], float(self.tree[i])

    def update(self, idx: int, p: float):
        self._update(idx, p)

class PERReplay:
    """
    Proportional prioritized replay on top of ``SumTree``.

    - New transitions get the largest priority seen so far (``max_p``).
    - ``update_priorities`` stores ``(|p| + eps) ** alpha``.
    - ``sample`` draws one item from each of B equal slices of the total priority
      mass and returns IS weights ``(N * P(i)) ** -beta`` normalised by the batch
      max. ``beta`` is annealed linearly from ``beta0`` to ``beta1`` over
      ``steps`` calls to ``sample``.
    """
    def __init__(self, capacity, alpha=0.6, beta0=0.4, beta1=1.0, steps=1_000_000, eps=1e-6):
        self.tree = SumTree(capacity)
        self.alpha = float(alpha)
        self.beta0, self.beta1 = float(beta0), float(beta1)
        self.steps = int(steps)
        self.step = 0
        self.eps = float(eps)
        self.max_p = 1.0

    def push(self, s,a,r,ns,d):
        self.tree.add(self.max_p, (s.copy(), int(a), float(r), ns.copy(), float(d)))

    def sample(self, B):
        total = self.tree.total()
        if self.tree.size == 0 or total <= 0.0:
            raise ValueError("PERReplay.sample called on an empty buffer")
        self.step += 1
        beta = self.beta0 + (self.beta1 - self.beta0) * min(1.0, self.step / max(1, self.steps))
        # The B slices tile exactly [0, total]. Padding the total (e.g. +1e-8) would
        # let draws overshoot the real mass and land on empty leaves whose data is None.
        seg = total / B
        idxs, batch, probs = [], [], []
        for i in range(B):
            # np.random.uniform may return the upper bound through rounding; clamp it.
            mass = min(np.random.uniform(seg * i, seg * (i + 1)), total)
            idx, data, p = self.tree.sample(mass)
            idxs.append(idx); batch.append(data); probs.append(p)
        probs = np.asarray(probs, dtype=np.float32) / total
        N = max(1, self.tree.size)
        w = (N * probs) ** (-beta)
        w /= (w.max() + 1e-8)
        return idxs, batch, w.astype(np.float32)

    def update_priorities(self, idxs, new_p):
        """Set priority (|p| + eps) ** alpha for each sampled index and track the running max."""
        for i, p in zip(idxs, new_p):
            p_alpha = (abs(float(p)) + self.eps) ** self.alpha
            self.tree.update(i, p_alpha)
            if p_alpha > self.max_p: self.max_p = p_alpha

    def __len__(self): return self.tree.size

class PERBuffer:
    """
    Array-backed prioritized replay with preallocated NumPy storage.

    Raw priorities live in ``P``, clipped to [p_min, p_max]. The ``alpha``
    exponent is applied only at sampling time. A transition pushed without an
    explicit priority inherits the current maximum (1.0 when empty). ``sample``
    draws without replacement and returns torch tensors on ``device`` plus IS
    weights normalised by their batch max.
    """
    def __init__(self, capacity, obs_shape, alpha=0.6, eps_prio=1e-6, p_min=1e-3, p_max=10.0):
        self.capacity, self.alpha = capacity, alpha
        self.eps_prio, self.p_min, self.p_max = eps_prio, p_min, p_max
        self.ptr, self.full = 0, False

        frame = (capacity,) + obs_shape
        self.S  = np.zeros(frame, dtype=np.float32)
        self.NS = np.zeros(frame, dtype=np.float32)
        self.A  = np.zeros((capacity,), dtype=np.int64)
        self.R  = np.zeros((capacity,), dtype=np.float32)
        self.D  = np.zeros((capacity,), dtype=np.float32)
        self.P  = np.ones((capacity,), dtype=np.float32)

    def __len__(self):
        if self.full:
            return self.capacity
        return self.ptr

    def push(self, s, a, r, ns, d, prio=None):
        slot = self.ptr
        self.S[slot], self.A[slot], self.R[slot] = s, a, r
        self.NS[slot], self.D[slot] = ns, d
        if prio is None:
            filled = len(self)
            prio = float(self.P[:filled].max()) if filled > 0 else 1.0
        self.P[slot] = np.clip(prio, self.p_min, self.p_max)

        self.ptr += 1
        if self.ptr >= self.capacity:
            self.ptr = 0
            self.full = True

    def sample(self, B, beta=0.4):
        n = len(self)
        prio_a = self.P[:n] ** self.alpha
        probs = prio_a / prio_a.sum()
        idx = np.random.choice(n, size=B, p=probs, replace=False)

        # Importance-sampling weights, rescaled so the largest one is 1.
        w = (n * probs[idx]) ** (-beta)
        w = (w / w.max()).astype(np.float32)

        def on_device(arr, dtype):
            return torch.tensor(arr, dtype=dtype, device=device)

        return (
            idx,
            on_device(self.S[idx], torch.float32),
            on_device(self.A[idx], torch.int64),
            on_device(self.R[idx], torch.float32),
            on_device(self.NS[idx], torch.float32),
            on_device(self.D[idx], torch.float32),
            on_device(w, torch.float32),
        )

    def update_priorities(self, idx, new_prio):
        prio = np.abs(np.asarray(new_prio, dtype=np.float32)) + self.eps_prio
        self.P[idx] = np.clip(prio, self.p_min, self.p_max)


Training (n=1), evaluation, and multi-seed runner

In [None]:
@torch.no_grad()
def greedy_eval(env_maker, cfg: Config, online_net, episodes=10, seed_offset=2025, max_steps=5000):
    """
    Evaluate ``online_net`` greedily (epsilon = 0) on freshly built environments.

    Episode ``ep`` runs in ``env_maker(cfg, cfg.base_seed + seed_offset + ep)``.
    It stops on termination, on truncation, or after ``max_steps`` steps
    (default 5000).

    Returns:
        (mean, sample_std, scores).
        - ``sample_std`` uses ddof=1 and is NaN when fewer than two episodes ran;
          it is computed explicitly, so NumPy emits no degrees-of-freedom warning.
        - ``mean`` is NaN when ``episodes`` is 0.
        - ``scores`` is the list of per-episode returns.
    """
    def run_one(ep):
        env = env_maker(cfg, cfg.base_seed + seed_offset + ep)
        try:
            s, _ = env.reset()
            total = 0.0
            for _ in range(max_steps):
                a = act_epsilon_greedy(s, online_net, 0.0)
                s, r, term, trunc, _ = env.step(a)
                total += r
                if term or trunc:
                    break
        finally:
            # Release the env even if the episode raises; closing is best-effort.
            try:
                env.close()
            except Exception:
                pass
        return total

    scores = [run_one(ep) for ep in range(episodes)]
    arr = np.asarray(scores, dtype=np.float32)
    mean = float(arr.mean()) if arr.size > 0 else float("nan")
    std = float(arr.std(ddof=1)) if arr.size > 1 else float("nan")
    return mean, std, scores

def train_qrdqn(run_name: str, cfg: Config, use_per=True, loss_mode="pinball", seed=None,
                ckpt_dir="/content"):
    """
    Train QR-DQN with 1-step targets and Double-DQN action selection on ``make_env(cfg, seed)``.

    Args:
        run_name: label used in logs and as the checkpoint file name.
        cfg: hyper-parameters (gamma, lr, adam_eps, n_quantiles, hidden, buffer_size,
            batch_size, learn_start, target_tau, total_steps, grad_clip, eps_*, base_seed).
        use_per: sample from ``PERReplay`` if True, else from ``ReplayUniform``.
        loss_mode: "pinball" (quantile loss) or "huber"/"huber1" (quantile Huber loss,
            kappa=1). The same loss is used with and without PER.
        seed: RNG/env seed; defaults to ``cfg.base_seed``.
        ckpt_dir: directory for ``{run_name}.pt`` (Colab's ``/content`` by default).

    Returns:
        dict with keys "name", "ckpt" and "mean_eval" (10-episode greedy mean).

    Raises:
        ValueError: for an unknown ``loss_mode``.
    """
    if loss_mode not in ("pinball", "huber", "huber1"):
        raise ValueError(f"Unknown loss_mode: {loss_mode!r} (expected 'pinball', 'huber' or 'huber1')")
    use_huber = loss_mode != "pinball"

    seed = cfg.base_seed if seed is None else seed
    set_seed(seed)
    env = make_env(cfg, seed)
    try:
        online = build_model_for_env(env, cfg)
        target = build_model_for_env(env, cfg)
        target.load_state_dict(online.state_dict())
        opt = torch.optim.Adam(online.parameters(), lr=cfg.lr, eps=cfg.adam_eps)
        taus = quantile_midpoints(cfg.n_quantiles, device=device)

        # choose replay
        if use_per:
            rb = PERReplay(cfg.buffer_size, alpha=0.6, beta0=0.4, beta1=1.0, steps=cfg.total_steps, eps=1e-6)
        else:
            rb = ReplayUniform(cfg.buffer_size)

        # prime
        s, _ = env.reset(seed=seed)
        ep_return = 0.0
        recent_returns = []

        print(f"=== Training {run_name} | {'PER' if use_per else 'Uniform'} | loss={loss_mode} | n=1 ===")
        for step in range(1, cfg.total_steps + 1):
            eps = epsilon_by_step(step, cfg)
            a = act_epsilon_greedy(s, online, eps)
            ns, r, term, trunc, _ = env.step(a)
            # Only a true termination stops bootstrapping. A time-limit truncation
            # (e.g. CartPole-v1's step cap) still has a valid successor value.
            rb.push(s, a, r, ns, float(term))
            s = ns
            ep_return += r

            if len(rb) >= cfg.learn_start:
                idxs, batch, w = rb.sample(cfg.batch_size)
                bs, ba, br, bns, bd = zip(*batch)
                s_b  = torch.tensor(np.stack(bs),  dtype=torch.float32, device=device)
                ns_b = torch.tensor(np.stack(bns), dtype=torch.float32, device=device)
                a_b  = torch.tensor(ba, dtype=torch.int64, device=device)
                r_b  = torch.tensor(br, dtype=torch.float32, device=device)
                d_b  = torch.tensor(bd, dtype=torch.float32, device=device)
                w_b  = torch.tensor(w,  dtype=torch.float32, device=device)

                # forward: quantiles of the taken action, [B, N]
                q_all = online(s_b)
                B, A, N = q_all.shape
                q_sel = q_all.gather(1, a_b.view(B, 1, 1).expand(B, 1, N)).squeeze(1)
                with torch.no_grad():
                    # Double DQN: online net picks the next action, target net evaluates it.
                    next_a = online(ns_b).mean(2).argmax(1, keepdim=True)
                    next_q_target = target(ns_b).gather(1, next_a.unsqueeze(-1).expand(-1, -1, N)).squeeze(1)
                    tgt = r_b.unsqueeze(1) + (1.0 - d_b.unsqueeze(1)) * cfg.gamma * next_q_target

                # Per-sample loss so IS weights (all ones for uniform replay) apply before reduction.
                per_elem = qr_loss_per_sample(q_sel, tgt, taus, kappa=1.0, use_huber=use_huber)
                loss = (w_b * per_elem).mean()

                opt.zero_grad(set_to_none=True)
                loss.backward()
                nn.utils.clip_grad_norm_(online.parameters(), cfg.grad_clip)
                opt.step()

                if use_per:
                    rb.update_priorities(idxs, per_elem.detach().cpu().numpy())

            # Hard target sync every `target_tau` steps.
            if step % cfg.target_tau == 0:
                target.load_state_dict(online.state_dict())

            if term or trunc:
                recent_returns.append(ep_return)
                s, _ = env.reset()
                ep_return = 0.0

            if step % 5000 == 0:
                last10 = np.mean(recent_returns[-10:]) if recent_returns else 0.0
                print(f"[{run_name}] step {step:7d} | buffer={len(rb):6d} | eps={eps:.3f} | recent(10)={last10:.1f}")

        # save & quick eval (10 eps)
        ckpt = os.path.join(ckpt_dir, f"{run_name}.pt")
        torch.save({"model": online.state_dict()}, ckpt)
        mean_eval, _, _ = greedy_eval(make_env, cfg, online, episodes=10, seed_offset=2025)
        print(f"Saved: {ckpt} | Mean greedy eval (10 eps): {mean_eval:.2f}")
    finally:
        try:
            env.close()
        except Exception:
            pass  # closing the training env is best-effort
    return {"name": run_name, "ckpt": ckpt, "mean_eval": mean_eval}

def run_multi_seed(tag, cfg: Config, seeds, use_per=True, loss_mode="pinball"):
    """
    Train one run per seed (named ``f"{tag}_seed{seed}"``).

    Returns a float32 array of greedy-eval means and the list of checkpoint paths,
    both in seed order.
    """
    results = [
        train_qrdqn(f"{tag}_seed{seed}", cfg, use_per=use_per, loss_mode=loss_mode, seed=seed)
        for seed in seeds
    ]
    scores = np.array([res["mean_eval"] for res in results], dtype=np.float32)
    ckpts = [res["ckpt"] for res in results]
    return scores, ckpts


CartPole with PER vs Uniform (quick sanity)

In [None]:
# CartPole sanity check
# Short runs (50k env steps each) to compare PER vs uniform replay on an easy task.
cfg_cp = Config(
    env_kind="cartpole",
    total_steps=50_000,
    buffer_size=20_000,
    batch_size=64,
    learn_start=1_000,
    target_tau=1_000,
    eps_start=1.0, eps_final=0.01, eps_decay=25_000,
    n_quantiles=51, hidden=128, lr=1e-3, gamma=0.99,
    base_seed=0
)

# Each seed trains one PER run and one uniform run with otherwise identical settings.
seeds = [7,8,9]

print("=== CartPole-v1 — Multi-seed (PER n=1 vs Uniform n=1) ===")
per_scores, per_ckpts = run_multi_seed("cp_per_pinball", cfg_cp, seeds, use_per=True,  loss_mode="pinball")
uni_scores, uni_ckpts = run_multi_seed("cp_uni_pinball", cfg_cp, seeds, use_per=False, loss_mode="pinball")

# One line per variant: mean / median / std of the per-seed greedy-eval means.
summarize("PER (pinball)", per_scores)
summarize("Uniform (pinball)", uni_scores)


=== CartPole-v1 — Multi-seed (PER n=1 vs Uniform n=1) ===
=== Training cp_per_pinball_seed7 | PER | loss=pinball | n=1 ===
[cp_per_pinball_seed7] step    5000 | buffer=  5000 | eps=0.802 | recent(10)=25.2
[cp_per_pinball_seed7] step   10000 | buffer= 10000 | eps=0.604 | recent(10)=36.4
[cp_per_pinball_seed7] step   15000 | buffer= 15000 | eps=0.406 | recent(10)=121.5
[cp_per_pinball_seed7] step   20000 | buffer= 20000 | eps=0.208 | recent(10)=200.5
[cp_per_pinball_seed7] step   25000 | buffer= 25000 | eps=0.010 | recent(10)=438.0
[cp_per_pinball_seed7] step   30000 | buffer= 30000 | eps=0.010 | recent(10)=457.7
[cp_per_pinball_seed7] step   35000 | buffer= 32768 | eps=0.010 | recent(10)=200.2
[cp_per_pinball_seed7] step   40000 | buffer= 32768 | eps=0.010 | recent(10)=181.7
[cp_per_pinball_seed7] step   45000 | buffer= 32768 | eps=0.010 | recent(10)=18.1
[cp_per_pinball_seed7] step   50000 | buffer= 32768 | eps=0.010 | recent(10)=9.6
Saved: /content/cp_per_pinball_seed7.pt | Mean greed

MinAtar Breakout with PER (you can add Uniform for A/B)

In [None]:
# MinAtar Breakout
# Longer runs (200k env steps per seed); `cfg_b` is reused by the Asterix and reload cells.
cfg_b = Config(
    env_kind="minatar",
    game="breakout",
    sticky=0.1,
    total_steps=200_000,
    buffer_size=100_000,
    batch_size=64,
    learn_start=5_000,
    target_tau=2_000,
    eps_start=1.0, eps_final=0.01, eps_decay=100_000,
    n_quantiles=51, hidden=128, lr=1e-3, gamma=0.99,
    base_seed=0
)

seeds = [7,8,9]

print("=== MinAtar Breakout — Multi-seed (PER n=1) ===")
b_per_scores, b_per_ckpts = run_multi_seed("minatar_breakout_per", cfg_b, seeds, use_per=True, loss_mode="pinball")
summarize("Breakout PER (pinball)", b_per_scores)

# Optional uniform A/B:
# NOTE: this block runs unconditionally; comment it out to skip the uniform baseline.
print("=== MinAtar Breakout — Multi-seed (Uniform n=1) ===")
b_uni_scores, b_uni_ckpts = run_multi_seed("minatar_breakout_uni", cfg_b, seeds, use_per=False, loss_mode="pinball")
summarize("Breakout Uniform (pinball)", b_uni_scores)


=== MinAtar Breakout — Multi-seed (PER n=1) ===


AttributeError: 'MinAtarGymLike' object has no attribute 'observation_space'

MinAtar Asterix with PER

In [None]:
# Same settings as Breakout; deep copy so changing the game leaves cfg_b untouched.
cfg_a = copy.deepcopy(cfg_b)
cfg_a.game = "asterix"

# `seeds` is reused from the Breakout cell.
print("=== MinAtar Asterix — Multi-seed (PER n=1) ===")
a_per_scores, a_per_ckpts = run_multi_seed("minatar_asterix_per", cfg_a, seeds, use_per=True, loss_mode="pinball")
summarize("Asterix PER (pinball)", a_per_scores)


Reload & evaluate any checkpoint (more episodes)

In [None]:
@torch.no_grad()
def load_and_eval(ckpt_path: str, cfg: Config, episodes=100):
    """
    Rebuild the network for ``cfg``'s environment, load ``ckpt_path`` and run a
    longer greedy evaluation.

    The checkpoint is expected to be ``{"model": state_dict}`` as written by
    ``train_qrdqn``. It is loaded with ``weights_only=True`` so only tensors and
    plain containers are unpickled.

    Returns:
        (mean, std, scores) from ``greedy_eval`` with seed offset 24680.
    """
    env = make_env(cfg, cfg.base_seed + 13579)
    try:
        # The env is only needed to infer the observation/action shapes.
        net = build_model_for_env(env, cfg)
    finally:
        try:
            env.close()
        except Exception:
            pass  # best-effort cleanup
    payload = torch.load(ckpt_path, map_location=device, weights_only=True)
    net.load_state_dict(payload["model"])

    m, s, scores = greedy_eval(make_env, cfg, net, episodes=episodes, seed_offset=24680)
    print(f"[Reload check] {os.path.basename(ckpt_path)}: mean={m:.2f} ± {s:.2f} over {episodes} eps")
    return m, s, scores

# Example: re-evaluate the first Breakout PER checkpoint over 300 greedy episodes.
# Requires `b_per_ckpts` and `cfg_b` from the Breakout cell.
load_and_eval(b_per_ckpts[0], cfg_b, episodes=300)


# **Full implementation of QR-DQN (n=1) with fixed PER (priorities from the scalar TD error on mean Q, importance-sampling weights applied per sample before reduction, clipped priorities), plus the uniform baseline**

In [None]:
import math, os, random, time
from dataclasses import dataclass
from typing import Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

@dataclass
class Config:
    """Environment selection and QR-DQN / PER hyper-parameters for the fixed-PER runs."""
    # env
    env_kind: str = "cartpole"   # "cartpole" | "minatar"
    game: str    = "breakout"    # for MinAtar
    sticky: float = 0.1          # for MinAtar
    base_seed: int = 7           # default run seed

    # algo
    total_steps: int = 50_000    # environment steps per run
    buffer_size: int = 32_768    # replay capacity
    batch_size: int  = 128
    gamma: float     = 0.99      # discount factor
    n_quantiles: int = 51        # quantile outputs per action
    hidden: int      = 128       # hidden layer width
    lr: float        = 3e-4      # Adam learning rate
    adam_eps: float  = 1e-8      # Adam epsilon
    target_tau: int  = 2_000     # target-network sync period in steps (a period, not a Polyak rate)
    grad_clip: float = 10.0      # max global gradient norm

    # exploration
    eps_start: float = 1.0
    eps_final: float = 0.01
    eps_decay: int   = 25_000    # steps of linear epsilon decay

    # learning start
    learn_start: int = 2_000     # minimum buffer size before updates begin

def set_global_seeds(seed: int):
    """Seed `random`, NumPy's global RNG and torch (plus every CUDA device when available)."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)


Device: cpu


Quantile utils + loss (per-sample)

In [None]:
def quantile_midpoints(n: int, device=None):
    """Return tau_i = (i + 0.5) / n for i = 0..n-1 (float32): midpoints of n equal bins of [0, 1]."""
    idx = torch.arange(n, dtype=torch.float32, device=device)
    return idx.add(0.5).div(n)

def epsilon_by_step(step: int, cfg: Config):
    """Epsilon for ``step``: linear from eps_start to eps_final over eps_decay steps, constant afterwards."""
    frac = step / max(1, cfg.eps_decay)
    frac = frac if frac < 1.0 else 1.0
    span = cfg.eps_final - cfg.eps_start
    return cfg.eps_start + frac * span

def pinball_loss(q_pred: torch.Tensor,
                 q_targ: torch.Tensor,
                 taus: torch.Tensor,
                 reduce: str = "mean") -> torch.Tensor:
    """
    Pairwise quantile-regression (pinball) loss, i.e. the kappa = 0 case.

    Every target quantile j is compared with every predicted quantile i:
    u_ji = q_targ[:, j] - q_pred[:, i] is weighted by tau_i when u_ji >= 0
    and by (1 - tau_i) when u_ji < 0. The weighted |u| is then averaged over
    both quantile axes.

    q_pred, q_targ: [B, N]; taus: [N] (one per predicted quantile).
    reduce='none' -> per-sample loss [B]; reduce='mean' -> scalar batch mean;
    any other value raises ValueError.
    """
    _, n_pred = q_pred.shape
    u = q_targ[:, :, None] - q_pred[:, None, :]          # [B, N_targ, N_pred]
    tau_row = taus.view(1, 1, n_pred)
    weight = torch.where(u < 0.0, 1.0 - tau_row, tau_row)
    per_sample = (weight * u.abs()).mean(dim=(1, 2))     # [B]
    if reduce == "none":
        return per_sample
    if reduce == "mean":
        return per_sample.mean()
    raise ValueError("reduce must be 'none' or 'mean'")


Models (MLP for CartPole, Conv for MinAtar)

In [None]:
class QRDQN(nn.Module):
    """MLP head (CartPole): maps [B, obs_dim] observations to [B, A, N] quantiles."""
    def __init__(self, obs_dim: int, n_actions: int, n_quantiles: int, hidden: int):
        super().__init__()
        self.n_actions = n_actions
        self.n_quantiles = n_quantiles
        self.net = nn.Sequential(*[
            nn.Linear(obs_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_actions * n_quantiles),
        ])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        flat = self.net(x)
        # -1 infers the batch size (an unbatched input becomes a batch of one).
        return flat.view(-1, self.n_actions, self.n_quantiles)

class QRDQNConv(nn.Module):
    """
    Simple Conv torso for MinAtar (C,H,W in {0,1}). Outputs [B, A, N].

    ``H`` and ``W`` default to 10x10. The head's input size is derived from them
    instead of being hard-coded, so other grid sizes also work. Both 3x3 convs
    use stride 1 / padding 1, which preserves H x W.
    """
    def __init__(self, C: int, n_actions: int, n_quantiles: int, hidden: int=128,
                 H: int = 10, W: int = 10):
        super().__init__()
        self.n_actions   = n_actions
        self.n_quantiles = n_quantiles
        self.torso = nn.Sequential(
            nn.Conv2d(C, 32, kernel_size=3, stride=1, padding=1), nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1), nn.ReLU()
        )
        self.head = nn.Sequential(
            nn.Linear(32 * H * W, hidden), nn.ReLU(),
            nn.Linear(hidden, n_actions * n_quantiles)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.torso(x)                  # expects batched [B, C, H, W]
        h = h.view(h.size(0), -1)
        z = self.head(h)
        return z.view(-1, self.n_actions, self.n_quantiles)


Replay buffers (Uniform + PER with fixes)

In [None]:
# --- Uniform replay ---

class Replay:
    """
    Uniform ring-buffer replay backed by preallocated NumPy arrays.

    Once ``capacity`` transitions have been pushed, new ones overwrite the oldest.
    ``sample`` draws indices uniformly with replacement and returns torch tensors
    on ``device``.
    """
    def __init__(self, capacity: int, obs_shape: Tuple[int, ...]):
        self.capacity = capacity
        self.ptr, self.full = 0, False
        frame = (capacity,) + obs_shape
        self.S  = np.zeros(frame, dtype=np.float32)
        self.NS = np.zeros(frame, dtype=np.float32)
        self.A  = np.zeros(capacity, dtype=np.int64)
        self.R  = np.zeros(capacity, dtype=np.float32)
        self.D  = np.zeros(capacity, dtype=np.float32)

    def __len__(self):
        if self.full:
            return self.capacity
        return self.ptr

    def push(self, s, a, r, ns, d):
        slot = self.ptr
        self.S[slot], self.A[slot], self.R[slot] = s, a, r
        self.NS[slot], self.D[slot] = ns, d
        self.ptr = slot + 1
        if self.ptr >= self.capacity:
            self.ptr, self.full = 0, True

    def sample(self, B: int):
        idx = np.random.randint(0, len(self), size=B)

        def on_device(arr, dtype):
            return torch.tensor(arr, dtype=dtype, device=device)

        return (
            idx,
            on_device(self.S[idx], torch.float32),
            on_device(self.A[idx], torch.int64),
            on_device(self.R[idx], torch.float32),
            on_device(self.NS[idx], torch.float32),
            on_device(self.D[idx], torch.float32),
        )

# --- Prioritized replay (fixed) ---

class PERBuffer:
    """
    Proportional prioritized replay buffer (circular, fixed capacity).

    Priorities from scalar TD-error on mean-Q.
    - ``update_priorities`` stores ``|td| + eps_prio`` clipped to
      ``[p_min, p_max]``; an explicit ``prio`` passed to ``push`` is only
      clipped.
    - New transitions default to the current maximum stored priority, so
      they are favoured until their first TD error is known.
    - ``sample`` draws indices i.i.d. *with replacement* from
      ``P(i) = p_i**alpha / sum_k p_k**alpha`` and returns IS weights
      ``(n * P(i))**(-beta)`` normalised by the batch maximum; those weights
      are the correct correction only for draws from exactly that
      distribution. IS weights are meant to be applied per-sample before
      reduction.
    """
    def __init__(self, capacity, obs_shape, alpha=0.6, eps_prio=1e-6, p_min=1e-3, p_max=10.0):
        self.capacity = capacity
        self.alpha    = alpha      # priority exponent (0 = uniform sampling)
        self.eps_prio = eps_prio   # keeps zero-TD transitions sampleable
        self.p_min    = p_min
        self.p_max    = p_max
        self.ptr = 0
        self.full = False

        self.S  = np.zeros((capacity,)+obs_shape, dtype=np.float32)
        self.A  = np.zeros(capacity, dtype=np.int64)
        self.R  = np.zeros(capacity, dtype=np.float32)
        self.NS = np.zeros((capacity,)+obs_shape, dtype=np.float32)
        self.D  = np.zeros(capacity, dtype=np.float32)
        self.P  = np.ones(capacity, dtype=np.float32)   # raw (un-exponentiated) priorities

    def __len__(self):
        return self.capacity if self.full else self.ptr

    def push(self, s,a,r,ns,d, prio=None):
        """Write one transition at the write pointer (overwriting the oldest when full)."""
        i = self.ptr
        self.S[i], self.A[i], self.R[i], self.NS[i], self.D[i] = s, a, r, ns, d
        if prio is None:
            prio = float(self.P[:len(self)].max() if len(self) > 0 else 1.0)
        self.P[i] = np.clip(prio, self.p_min, self.p_max)
        self.ptr += 1
        if self.ptr >= self.capacity:
            self.ptr = 0
            self.full = True

    def sample(self, B, beta=0.4):
        """Sample ``B`` transitions proportionally to ``priority**alpha``.

        Returns ``(idx, s, a, r, ns, d, w)``; ``idx`` is a NumPy index array
        (may contain repeats) to pass back to ``update_priorities``, the rest
        are tensors on ``device``; ``w`` are IS weights with max 1.

        Raises:
            ValueError: if the buffer is empty.
        """
        n = len(self)
        if n == 0:
            raise ValueError("cannot sample from an empty PERBuffer")
        # float64 so the normalised probabilities pass np.random.choice's sum-to-1 check.
        scaled = self.P[:n].astype(np.float64) ** self.alpha
        probs  = scaled / scaled.sum()
        # With replacement: the IS correction below assumes every draw comes
        # from `probs`; replace=False changes the inclusion probabilities
        # (biasing the correction) and also fails whenever n < B.
        idx    = np.random.choice(n, size=B, p=probs, replace=True)
        w = (n * probs[idx]) ** (-beta)
        w = w / w.max()
        s  = torch.tensor(self.S[idx],  dtype=torch.float32, device=device)
        ns = torch.tensor(self.NS[idx], dtype=torch.float32, device=device)
        a  = torch.tensor(self.A[idx],  dtype=torch.int64,   device=device)
        r  = torch.tensor(self.R[idx],  dtype=torch.float32, device=device)
        d  = torch.tensor(self.D[idx],  dtype=torch.float32, device=device)
        w_t= torch.tensor(w,            dtype=torch.float32, device=device)
        return idx, s, a, r, ns, d, w_t

    def update_priorities(self, idx, new_prio):
        """Set priorities of ``idx`` to ``clip(|new_prio| + eps_prio, p_min, p_max)``.

        With repeated indices in ``idx`` the last value written wins.
        """
        new_prio = np.asarray(new_prio, dtype=np.float32)
        new_prio = np.abs(new_prio) + self.eps_prio
        new_prio = np.clip(new_prio, self.p_min, self.p_max)
        self.P[idx] = new_prio


Environments (CartPole + MinAtar wrapper) + factory

In [None]:
import gymnasium as gym
from minatar import Environment as MinAtarEnv

def make_env_cartpole(seed: int):
    """Create a Gymnasium ``CartPole-v1`` env, seeded through an initial reset."""
    cartpole = gym.make("CartPole-v1")
    _ = cartpole.reset(seed=seed)  # seeds the env; the first observation is discarded
    return cartpole

class ActionSpace:
    """Minimal action-space object exposing only ``n``, the number of actions.

    Lets Gymnasium-style code read ``env.action_space.n`` from wrapped envs.
    """

    def __init__(self, n):
        self.n = n

class MinAtarGymLike:
    """
    Adapter giving a MinAtar ``Environment`` a Gymnasium-like interface.

    - ``reset(seed=None) -> (obs, info)``; a non-None seed reseeds MinAtar first.
    - ``step(a) -> (obs, reward, terminated, truncated, info)``; ``truncated``
      is always False (this wrapper imposes no time limit).
    - ``action_space.n`` is MinAtar's ``num_actions()``.
    - Observations are MinAtar's (H, W, C) state transposed to (C, H, W)
      and cast to float32 (values in {0, 1}).
    """
    def __init__(self, game="breakout", sticky=0.1, seed=0):
        self.env = MinAtarEnv(game, sticky_action_prob=sticky)
        self.env.reset()
        self.env.seed(seed)
        self.action_space = ActionSpace(self.env.num_actions())
        height, width, channels = self.env.state().shape
        self.shape = (channels, height, width)

    def _observe(self):
        """Current MinAtar state as a float32 (C, H, W) array."""
        return np.transpose(self.env.state(), (2, 0, 1)).astype(np.float32)

    def reset(self, seed=None):
        if seed is not None:
            self.env.seed(seed)
        self.env.reset()
        return self._observe(), {}

    def step(self, a):
        reward, terminal = self.env.act(a)
        obs = self._observe()
        return obs, float(reward), bool(terminal), False, {}

    def close(self):
        return None

def make_env(cfg: Config, seed: int):
    """Build the environment selected by ``cfg.env_kind`` ("cartpole" or "minatar").

    Raises:
        ValueError: for any other ``env_kind``.
    """
    kind = cfg.env_kind
    if kind == "minatar":
        return MinAtarGymLike(cfg.game, cfg.sticky, seed)
    if kind == "cartpole":
        return make_env_cartpole(seed)
    raise ValueError(f"Unknown env_kind: {cfg.env_kind}")


Policy, action, and eval helpers

In [None]:
@torch.no_grad()
def act_epsilon_greedy(state, net, epsilon: float):
    """Choose an action epsilon-greedily from a quantile network.

    With probability ``epsilon`` a uniformly random action in
    ``[0, net.n_actions)`` is returned; otherwise the action with the largest
    mean over quantiles of ``net(state)``.

    Args:
        state: one observation (array-like); a batch axis of 1 is added here.
        net: module returning (batch, n_actions, n_quantiles) and exposing
            ``n_actions``.
        epsilon: exploration probability.

    Returns:
        The chosen action as a Python ``int`` on both paths.
    """
    if np.random.rand() < epsilon:
        return int(np.random.randint(0, net.n_actions))
    # Flat (CartPole) and image (MinAtar) observations are treated the same:
    # just prepend the batch axis.
    s = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)
    q = net(s)
    return int(q.mean(2).argmax(1).item())    # greedy w.r.t. mean-over-quantiles

@torch.no_grad()
def greedy_rollout(env, net, max_steps=5000):
    """Play one greedy (epsilon = 0) episode and return its undiscounted return.

    The episode ends on termination, truncation, or after ``max_steps`` steps;
    ``env`` is reset without a seed.
    """
    obs, _ = env.reset()
    episode_return = 0.0
    for _step in range(max_steps):
        obs, reward, terminated, truncated, _ = env.step(act_epsilon_greedy(obs, net, 0.0))
        episode_return += reward
        if terminated or truncated:
            break
    return float(episode_return)

@torch.no_grad()
def eval_model_greedy(net, cfg: Config, episodes=10, seed_base=12_345, max_steps=5000):
    """Evaluate ``net`` greedily (epsilon = 0) on a fresh environment.

    Episode ``i`` is reset with seed ``seed_base + i`` (falling back to an
    unseeded reset if ``reset`` rejects the ``seed`` keyword) and runs until
    termination, truncation, or ``max_steps`` steps.

    Returns:
        ``(mean, std, scores)``: mean and sample standard deviation (ddof=1;
        reported as 0.0 for a single episode) of the undiscounted returns,
        plus the list of per-episode returns.

    Raises:
        ValueError: if ``episodes`` < 1.
    """
    if episodes < 1:
        raise ValueError(f"episodes must be >= 1, got {episodes}")
    env = make_env(cfg, seed_base)
    scores = []
    try:
        for i in range(episodes):
            try:
                s, _ = env.reset(seed=seed_base + i)
            except TypeError:   # env.reset() without a `seed` keyword
                s, _ = env.reset()
            total = 0.0
            for _ in range(max_steps):
                a = act_epsilon_greedy(s, net, 0.0)
                s, r, term, trunc, _ = env.step(a)
                total += r
                if term or trunc:
                    break
            scores.append(float(total))
    finally:
        env.close()   # release the env even if an episode raises
    arr = np.array(scores, dtype=np.float32)
    # ddof=1 is undefined for one sample (NumPy would return nan + a warning).
    std = float(arr.std(ddof=1)) if arr.size > 1 else 0.0
    return float(arr.mean()), std, scores

Training — Uniform vs PER (fixed) (n=1)

In [None]:
def build_net_for_env(cfg: Config, env):
    """Create a fresh QR-DQN sized for ``env`` and moved to ``device``.

    CartPole gets the MLP ``QRDQN`` (input size from ``observation_space``).
    Any other ``env_kind`` gets ``QRDQNConv``; its (C, H, W) input layout is
    read from the observation of one ``env.reset()`` call, so ``env`` is reset
    as a side effect.

    Returns:
        ``(net, obs_shape, n_actions)``.
    """
    n_actions = env.action_space.n
    if cfg.env_kind != "cartpole":
        first_obs, _ = env.reset()
        channels, height, width = first_obs.shape
        conv_net = QRDQNConv(channels, n_actions, cfg.n_quantiles, cfg.hidden)
        return conv_net.to(device), (channels, height, width), n_actions
    obs_dim = env.observation_space.shape[0]
    mlp_net = QRDQN(obs_dim, n_actions, cfg.n_quantiles, cfg.hidden)
    return mlp_net.to(device), (obs_dim,), n_actions

def train_uniform(run_name: str, cfg: Config, seed=0, log_eval_every=5_000,
                  ckpt_dir: str = "/content"):
    """Train QR-DQN with double-DQN targets and uniform experience replay.

    Args:
        run_name: label used in progress lines and the checkpoint file name.
        cfg: run configuration (env, optimiser, replay and schedule settings).
        seed: offset added to ``cfg.base_seed`` for global RNGs and the env.
        log_eval_every: print a progress line every this many env steps
            (no evaluation is run at these points).
        ckpt_dir: directory receiving ``{run_name}.pt`` (default ``/content``).

    Returns:
        ``{"ckpt": path, "mean10": mean greedy return over 10 eval episodes}``.
    """
    set_global_seeds(cfg.base_seed + seed)
    env = make_env(cfg, cfg.base_seed + seed)
    online, obs_shape, _ = build_net_for_env(cfg, env)
    target = build_net_for_env(cfg, env)[0]
    target.load_state_dict(online.state_dict())

    opt  = torch.optim.Adam(online.parameters(), lr=cfg.lr, eps=cfg.adam_eps)
    taus = quantile_midpoints(cfg.n_quantiles, device=device)

    rb = Replay(cfg.buffer_size, obs_shape)

    s, _ = env.reset()
    for step in range(1, cfg.total_steps + 1):
        eps = epsilon_by_step(step, cfg)
        a = act_epsilon_greedy(s, online, eps)
        ns, r, term, trunc, _ = env.step(a)
        # Only true termination stops bootstrapping. A time-limit truncation
        # (e.g. CartPole-v1's 500-step cap) is not a terminal state, so its
        # target must still include gamma * Z(s', a'); storing `term or trunc`
        # here would wrongly zero it.
        rb.push(s, a, r, ns, float(term))
        s = ns

        if len(rb) >= cfg.learn_start:
            _, bs, ba, br, bns, bd = rb.sample(cfg.batch_size)
            with torch.no_grad():
                # Double DQN: the online net picks a*, the target net supplies its quantiles.
                next_a = online(bns).mean(2).argmax(1, keepdim=True)
                next_q_target = target(bns).gather(
                    1, next_a.unsqueeze(-1).expand(-1, -1, cfg.n_quantiles)
                ).squeeze(1)
                target_q = br.unsqueeze(1) + (1.0 - bd.unsqueeze(1)) * cfg.gamma * next_q_target

            q_all = online(bs)
            q_chosen = q_all.gather(1, ba.view(-1, 1, 1).expand(-1, 1, cfg.n_quantiles)).squeeze(1)

            loss = pinball_loss(q_chosen, target_q, taus, reduce="mean")
            opt.zero_grad(set_to_none=True)
            loss.backward()
            nn.utils.clip_grad_norm_(online.parameters(), cfg.grad_clip)
            opt.step()

        # Hard target-network sync every cfg.target_tau env steps.
        if step % cfg.target_tau == 0:
            target.load_state_dict(online.state_dict())

        if term or trunc:
            s, _ = env.reset()

        if step % log_eval_every == 0:
            print(f"[{run_name}] step {step:7d} | buffer={len(rb):6d} | eps={eps:.3f}")

    mean10, _, _ = eval_model_greedy(online, cfg, episodes=10)
    path = f"{ckpt_dir}/{run_name}.pt"
    torch.save(online.state_dict(), path)
    env.close()
    print(f"Saved: {path} | Mean greedy eval (10 eps): {mean10:.2f}")
    return {"ckpt": path, "mean10": mean10}

def train_per_fixed(run_name: str, cfg: Config, seed=0,
                    alpha=0.6, beta_start=0.4, beta_end=1.0,
                    log_eval_every=5_000, ckpt_dir: str = "/content"):
    """Train QR-DQN with double-DQN targets and proportional prioritized replay.

    Minibatches come from ``PERBuffer`` with priority exponent ``alpha``. The
    IS exponent beta is annealed linearly from ``beta_start`` to ``beta_end``
    over ``cfg.total_steps``. Each sample's quantile loss is scaled by its IS
    weight before averaging, and sampled priorities are refreshed with the
    absolute TD error of the mean-over-quantiles Q-value.

    Args:
        run_name: label used in progress lines and the checkpoint file name.
        cfg: run configuration (env, optimiser, replay and schedule settings).
        seed: offset added to ``cfg.base_seed`` for global RNGs and the env.
        alpha, beta_start, beta_end: PER hyper-parameters.
        log_eval_every: print a progress line every this many env steps
            (no evaluation is run at these points).
        ckpt_dir: directory receiving ``{run_name}.pt`` (default ``/content``).

    Returns:
        ``{"ckpt": path, "mean10": mean greedy return over 10 eval episodes}``.
    """
    set_global_seeds(cfg.base_seed + seed)
    env = make_env(cfg, cfg.base_seed + seed)
    online, obs_shape, _ = build_net_for_env(cfg, env)
    target = build_net_for_env(cfg, env)[0]
    target.load_state_dict(online.state_dict())

    opt  = torch.optim.Adam(online.parameters(), lr=cfg.lr, eps=cfg.adam_eps)
    taus = quantile_midpoints(cfg.n_quantiles, device=device)

    rb = PERBuffer(cfg.buffer_size, obs_shape, alpha=alpha, eps_prio=1e-6, p_min=1e-3, p_max=10.0)

    def beta_at(step):
        # Linear IS-exponent schedule reaching beta_end at the final step.
        t = min(1.0, step / max(1, cfg.total_steps))
        return beta_start + t * (beta_end - beta_start)

    s, _ = env.reset()
    for step in range(1, cfg.total_steps + 1):
        eps = epsilon_by_step(step, cfg)
        a = act_epsilon_greedy(s, online, eps)
        ns, r, term, trunc, _ = env.step(a)
        # Only true termination stops bootstrapping. A time-limit truncation
        # (e.g. CartPole-v1's 500-step cap) is not a terminal state, so its
        # target must still include gamma * Z(s', a'); storing `term or trunc`
        # here would wrongly zero it.
        rb.push(s, a, r, ns, float(term))
        s = ns

        if len(rb) >= cfg.learn_start:
            beta = beta_at(step)
            idx, bs, ba, br, bns, bd, is_w = rb.sample(cfg.batch_size, beta=beta)

            with torch.no_grad():
                # Double DQN: the online net picks a*, the target net supplies its quantiles.
                next_a = online(bns).mean(2).argmax(1, keepdim=True)
                next_q_target = target(bns).gather(
                    1, next_a.unsqueeze(-1).expand(-1, -1, cfg.n_quantiles)
                ).squeeze(1)
                target_q = br.unsqueeze(1) + (1.0 - bd.unsqueeze(1)) * cfg.gamma * next_q_target

            q_all = online(bs)
            q_chosen = q_all.gather(1, ba.view(-1, 1, 1).expand(-1, 1, cfg.n_quantiles)).squeeze(1)

            # New priorities: |TD error| of the mean-over-quantiles Q-value.
            with torch.no_grad():
                td_err_mean = (target_q.mean(dim=1) - q_chosen.mean(dim=1)).abs().cpu().numpy()

            # NOTE(review): assumes pinball_loss(..., reduce='none') returns one
            # loss per sample, shape (B,), so is_w weights whole samples — confirm.
            per_sample = pinball_loss(q_chosen, target_q, taus, reduce='none')
            loss = (is_w * per_sample).mean()

            opt.zero_grad(set_to_none=True)
            loss.backward()
            nn.utils.clip_grad_norm_(online.parameters(), cfg.grad_clip)
            opt.step()

            rb.update_priorities(idx, td_err_mean)

        # Hard target-network sync every cfg.target_tau env steps. Before
        # learn_start the online net is unchanged, so an early sync is a no-op.
        if step % cfg.target_tau == 0:
            target.load_state_dict(online.state_dict())

        if term or trunc:
            s, _ = env.reset()

        if step % log_eval_every == 0:
            print(f"[{run_name}] step {step:7d} | buffer={len(rb):6d} | eps={eps:.3f}")

    mean10, _, _ = eval_model_greedy(online, cfg, episodes=10)
    path = f"{ckpt_dir}/{run_name}.pt"
    torch.save(online.state_dict(), path)
    env.close()
    print(f"Saved: {path} | Mean greedy eval (10 eps): {mean10:.2f}")
    return {"ckpt": path, "mean10": mean10}

Simple runner (CartPole + MinAtar examples)

In [None]:
# --- runs ---
# Each environment is trained twice with the same seed offset (7): first with
# prioritized replay, then with uniform replay. Each call returns a dict with
# the checkpoint path ("ckpt") and a 10-episode greedy mean ("mean10").

# CartPole settings
cfg_cp = Config(
    env_kind="cartpole",
    total_steps=50_000,
    buffer_size=32_768,
    batch_size=128,
    eps_start=1.0,
    eps_final=0.01,
    eps_decay=25_000,
    target_tau=2_000,
)

print("=== CartPole-v1 — PER (fixed) vs Uniform (n=1) ===")
per_out = train_per_fixed("cp_per_fixed_pinball_seed7", cfg_cp, seed=7)
uni_out = train_uniform("cp_uniform_pinball_seed7", cfg_cp, seed=7)

# MinAtar settings (Breakout)
cfg_ma = Config(
    env_kind="minatar",
    game="breakout",
    sticky=0.1,
    total_steps=200_000,
    buffer_size=100_000,
    batch_size=128,
    eps_start=1.0,
    eps_final=0.01,
    eps_decay=100_000,
    target_tau=2_000,
)

print("\n=== MinAtar Breakout — PER (fixed) vs Uniform (n=1) ===")
per_b = train_per_fixed("minatar_breakout_per_seed7", cfg_ma, seed=7)
uni_b = train_uniform("minatar_breakout_uni_seed7", cfg_ma, seed=7)


=== CartPole-v1 — PER (fixed) vs Uniform (n=1) ===
[cp_per_fixed_pinball_seed7] step    5000 | buffer=  5000 | eps=0.802
[cp_per_fixed_pinball_seed7] step   10000 | buffer= 10000 | eps=0.604
[cp_per_fixed_pinball_seed7] step   15000 | buffer= 15000 | eps=0.406
[cp_per_fixed_pinball_seed7] step   20000 | buffer= 20000 | eps=0.208
[cp_per_fixed_pinball_seed7] step   25000 | buffer= 25000 | eps=0.010
[cp_per_fixed_pinball_seed7] step   30000 | buffer= 30000 | eps=0.010
[cp_per_fixed_pinball_seed7] step   35000 | buffer= 32768 | eps=0.010
[cp_per_fixed_pinball_seed7] step   40000 | buffer= 32768 | eps=0.010
[cp_per_fixed_pinball_seed7] step   45000 | buffer= 32768 | eps=0.010
[cp_per_fixed_pinball_seed7] step   50000 | buffer= 32768 | eps=0.010
Saved: /content/cp_per_fixed_pinball_seed7.pt | Mean greedy eval (10 eps): 294.30
[cp_uniform_pinball_seed7] step    5000 | buffer=  5000 | eps=0.802
[cp_uniform_pinball_seed7] step   10000 | buffer= 10000 | eps=0.604
[cp_uniform_pinball_seed7] step

Evaluate any saved checkpoint (greedy)

In [None]:
# --- Reload & evaluate helper (greedy ε=0) ---

def load_net_and_eval(ckpt_path: str, cfg: Config, episodes=100, seed_base=54_321):
    """Load a saved ``state_dict`` checkpoint and evaluate it greedily.

    The architecture is rebuilt from ``cfg``. A throw-away env is created only
    to infer observation/action shapes and is always closed. The weights are
    then loaded and ``eval_model_greedy`` runs ``episodes`` episodes seeded
    from ``seed_base``. Prints a one-line summary.

    Returns:
        ``(mean, std, scores)`` as returned by ``eval_model_greedy``.
    """
    env = make_env(cfg, seed_base)
    try:
        net, _, _ = build_net_for_env(cfg, env)
    finally:
        env.close()
    # The checkpoint is a plain state_dict, so refuse to unpickle arbitrary
    # objects (weights_only=True is also the default from PyTorch 2.6 on).
    payload = torch.load(ckpt_path, map_location=device, weights_only=True)
    net.load_state_dict(payload)
    m, s, scores = eval_model_greedy(net, cfg, episodes=episodes, seed_base=seed_base)
    print(f"[EVAL] {os.path.basename(ckpt_path)}: mean={m:.2f} ± {s:.2f}  (n={episodes})")
    return m, s, scores

# Re-evaluate every saved checkpoint on 1000 fresh greedy episodes.
# CartPole uses load_net_and_eval's default seed_base; MinAtar uses 65_432.
load_net_and_eval(ckpt_path="/content/cp_per_fixed_pinball_seed7.pt", cfg=cfg_cp, episodes=1000)
load_net_and_eval(ckpt_path="/content/cp_uniform_pinball_seed7.pt", cfg=cfg_cp, episodes=1000)

load_net_and_eval(ckpt_path="/content/minatar_breakout_per_seed7.pt", cfg=cfg_ma,
                  episodes=1000, seed_base=65_432)
load_net_and_eval(ckpt_path="/content/minatar_breakout_uni_seed7.pt", cfg=cfg_ma,
                  episodes=1000, seed_base=65_432)


[EVAL] cp_per_fixed_pinball_seed7.pt: mean=284.27 ± 114.73  (n=1000)
[EVAL] cp_uniform_pinball_seed7.pt: mean=162.24 ± 5.88  (n=1000)
[EVAL] minatar_breakout_per_seed7.pt: mean=10.35 ± 5.70  (n=1000)
[EVAL] minatar_breakout_uni_seed7.pt: mean=13.05 ± 8.21  (n=1000)


(13.053999900817871,
 8.2127046585083,
 [17.0,
  23.0,
  24.0,
  7.0,
  24.0,
  3.0,
  1.0,
  17.0,
  1.0,
  5.0,
  3.0,
  24.0,
  12.0,
  26.0,
  23.0,
  1.0,
  36.0,
  14.0,
  11.0,
  6.0,
  17.0,
  17.0,
  22.0,
  6.0,
  0.0,
  4.0,
  13.0,
  26.0,
  11.0,
  11.0,
  5.0,
  3.0,
  5.0,
  14.0,
  11.0,
  11.0,
  7.0,
  21.0,
  1.0,
  5.0,
  25.0,
  2.0,
  7.0,
  3.0,
  25.0,
  21.0,
  7.0,
  27.0,
  22.0,
  12.0,
  24.0,
  3.0,
  1.0,
  25.0,
  28.0,
  0.0,
  13.0,
  9.0,
  8.0,
  10.0,
  3.0,
  8.0,
  11.0,
  18.0,
  14.0,
  12.0,
  12.0,
  13.0,
  6.0,
  7.0,
  24.0,
  14.0,
  12.0,
  12.0,
  6.0,
  18.0,
  21.0,
  26.0,
  12.0,
  18.0,
  5.0,
  13.0,
  14.0,
  22.0,
  14.0,
  7.0,
  2.0,
  17.0,
  8.0,
  4.0,
  26.0,
  2.0,
  12.0,
  7.0,
  23.0,
  0.0,
  0.0,
  1.0,
  16.0,
  14.0,
  9.0,
  7.0,
  8.0,
  8.0,
  5.0,
  22.0,
  0.0,
  10.0,
  26.0,
  7.0,
  3.0,
  26.0,
  7.0,
  14.0,
  5.0,
  7.0,
  22.0,
  24.0,
  5.0,
  12.0,
  11.0,
  17.0,
  12.0,
  3.0,
  3.0,
  4.0,
  3.0,
  