# In [None]:  (notebook cell prompt left over from export)
# %pip install procgen torch numpy tqdm matplotlib

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.categorical import Categorical
from dataclasses import dataclass
from typing import List, Dict, Tuple
from procgen import ProcgenEnv
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import random

@dataclass
class Config:
    """Hyper-parameters and environment settings for the PPO experiment.

    Every field has a default, so ``Config()`` yields a ready-to-run setup.
    """

    # --- Environment -------------------------------------------------------
    env_name: str = "chaser"            # forwarded to ProcgenEnv(env_name=...)
    num_envs_train: int = 32            # parallel environments used for rollouts
    num_envs_eval: int = 32             # parallel environments used for evaluation
    num_levels_train: int = 200         # forwarded as num_levels for the training env
    num_levels_eval: int = 200          # forwarded as num_levels for the evaluation env
    start_level_train: int = 0          # forwarded as start_level for the training env
    start_level_eval: int = 10000       # forwarded as start_level for the evaluation env
    distribution_mode: str = "easy"     # forwarded to ProcgenEnv(distribution_mode=...)

    # --- PPO ---------------------------------------------------------------
    total_timesteps: int = 1_000_000    # training budget; updates = total // (num_envs_train * rollout_length)
    rollout_length: int = 256           # environment steps per env collected before each update
    update_epochs: int = 3              # passes over each rollout batch
    minibatch_size: int = 4096          # samples per gradient step
    gamma: float = 0.999                # discount factor used in the GAE recursion
    gae_lambda: float = 0.95            # lambda used in the GAE recursion
    clip_coef: float = 0.2              # ratio is clamped to [1 - clip_coef, 1 + clip_coef]
    ent_coef: float = 0.01              # weight of the entropy bonus in the total loss
    vf_coef: float = 0.5                # weight of the value loss in the total loss
    max_grad_norm: float = 0.5          # threshold passed to clip_grad_norm_
    learning_rate: float = 5e-4         # Adam learning rate

    # --- Evaluation / bookkeeping -----------------------------------------
    eval_interval_updates: int = 20     # evaluate after every this many updates
    eval_steps_per_env: int = 512       # steps taken in each evaluation env per evaluation
    seed: int = 1                       # seeds Python/NumPy/torch; eval env uses seed + 123
    # Resolved once, when the class body is executed.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"

    # --- Augmentation ------------------------------------------------------
    use_random_shift: bool = True       # enable the RandomShift augmentation
    shift_pad: int = 4                  # padding (and max shift per direction) given to RandomShift

def set_seed(seed: int):
    """Seed Python's ``random``, NumPy and torch (CPU and all CUDA devices).

    Args:
        seed: Value handed unchanged to every generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

def make_env(num_envs: int, env_name: str, num_levels: int, start_level: int, distribution_mode: str, rand_seed: int) -> ProcgenEnv:
    """Construct a vectorised Procgen environment.

    All arguments are forwarded to ``ProcgenEnv`` under the same keyword
    names.

    Returns:
        The newly created ``ProcgenEnv``.
    """
    env_kwargs = {
        "num_envs": num_envs,
        "env_name": env_name,
        "num_levels": num_levels,
        "start_level": start_level,
        "distribution_mode": distribution_mode,
        "rand_seed": rand_seed,
    }
    return ProcgenEnv(**env_kwargs)

def get_obs(x):
    """Extract the observation from an environment ``reset``/``step`` result.

    Args:
        x: Either the observation itself or a dict wrapping it.

    Returns:
        ``x`` unchanged when it is not a dict; otherwise the value stored
        under the first of ``"obs"``, ``"observation"``, ``"rgb"`` that is
        present.

    Raises:
        KeyError: If ``x`` is a dict containing none of those keys.
    """
    if not isinstance(x, dict):
        return x

    candidate_keys = ("obs", "observation", "rgb")
    found = next((name for name in candidate_keys if name in x), None)
    if found is None:
        raise KeyError(f"reset/step returned dict without observation key: {list(x.keys())}")
    return x[found]

class RandomShift(nn.Module):
    """Random-translation augmentation for image batches.

    Each image is replicate-padded by ``pad`` pixels on every side and then
    cropped back to its original size at an independently sampled offset.
    """

    def __init__(self, pad: int = 4):
        super().__init__()
        self.pad = pad

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Shift every image in ``x`` (indexed as batch, channel, row, col).

        Returns ``x`` itself when ``pad`` is not positive; otherwise a new
        tensor with the same shape, dtype and device.
        """
        if self.pad <= 0:
            return x

        batch, channels, height, width = x.shape
        num_offsets = 2 * self.pad + 1
        padded = F.pad(x, (self.pad,) * 4, mode="replicate")

        # Row offsets are drawn before column offsets.
        row_start = torch.randint(0, num_offsets, (batch,), device=x.device)
        col_start = torch.randint(0, num_offsets, (batch,), device=x.device)

        shifted = torch.empty((batch, channels, height, width), device=x.device, dtype=x.dtype)
        offsets = zip(row_start.tolist(), col_start.tolist())
        for idx, (r0, c0) in enumerate(offsets):
            shifted[idx] = padded[idx, :, r0:r0 + height, c0:c0 + width]
        return shifted

class CNNPolicy(nn.Module):
    """Small convolutional actor-critic network.

    Three stride-2 3x3 convolutions feed a 256-unit hidden layer, which is
    shared by a policy head (``pi``) and a scalar value head (``v``). The
    hidden layer expects a flattened feature map of ``64 * 8 * 8`` values,
    which is what a 64x64 input produces after the three convolutions.
    """

    def __init__(self, in_ch: int, num_actions: int):
        super().__init__()

        # Conv/ReLU pairs are created in the same order as the channel list.
        channels = (in_ch, 32, 64, 64)
        conv_layers = []
        for c_in, c_out in zip(channels[:-1], channels[1:]):
            conv_layers.append(nn.Conv2d(c_in, c_out, 3, stride=2, padding=1))
            conv_layers.append(nn.ReLU(inplace=True))
        self.conv = nn.Sequential(*conv_layers)

        hidden = 256
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * 8 * 8, hidden),
            nn.ReLU(inplace=True),
        )
        self.pi = nn.Linear(hidden, num_actions)
        self.v = nn.Linear(hidden, 1)

    def forward(self, x):
        """Return ``(action_logits, value)`` for a batch of images ``x``."""
        features = self.fc(self.conv(x))
        logits = self.pi(features)
        value = self.v(features)
        return logits, value

class PPOAgent:
    """PPO trainer for a vectorised Procgen environment.

    Owns a training env, a separate evaluation env, the ``CNNPolicy`` model,
    its Adam optimiser and a ``metrics`` dict of per-update statistics.

    Metrics:
        ``updates``, ``train_return_*``, ``policy_loss``, ``value_loss`` and
        ``entropy`` receive one entry per update. ``eval_return_*`` receive
        one entry per evaluation, and ``eval_updates`` records the 1-based
        update index at which each evaluation was run.

    Augmentation:
        When ``cfg.use_random_shift`` is set, the random shift is applied
        once, at rollout time, and the *shifted* observation is what gets
        stored in the rollout buffer. The update then re-evaluates the policy
        on exactly those stored tensors, so the stored log-probabilities and
        the new ones refer to the same network input.
    """

    def __init__(self, cfg: "Config"):
        """Seed the RNGs, build both environments, the model and optimiser."""
        set_seed(cfg.seed)
        self.cfg = cfg
        self.env = make_env(cfg.num_envs_train, cfg.env_name, cfg.num_levels_train, cfg.start_level_train, cfg.distribution_mode, rand_seed=cfg.seed)
        self.eval_env = make_env(cfg.num_envs_eval, cfg.env_name, cfg.num_levels_eval, cfg.start_level_eval, cfg.distribution_mode, rand_seed=cfg.seed+123)
        obs0 = get_obs(self.env.reset())
        # Observations are indexed (env, row, col, channel).
        self.N = obs0.shape[0]
        self.H, self.W, self.C = obs0.shape[1], obs0.shape[2], obs0.shape[3]
        self.num_actions = self.env.action_space.n if hasattr(self.env, "action_space") else 15
        self.model = CNNPolicy(in_ch=self.C, num_actions=self.num_actions).to(cfg.device)
        self.opt = torch.optim.Adam(self.model.parameters(), lr=cfg.learning_rate, eps=1e-5)
        self.augment = RandomShift(cfg.shift_pad) if cfg.use_random_shift else None
        self.global_step = 0
        self.metrics: Dict[str, List[float]] = {
            "updates": [], "train_return_mean": [], "train_return_std": [],
            "eval_return_mean": [], "eval_return_std": [], "eval_updates": [],
            "policy_loss": [], "value_loss": [], "entropy": []
        }
        # Carried across collect() calls so rollouts continue mid-episode.
        self.last_obs = obs0

    def _prep(self, obs_np: np.ndarray) -> torch.Tensor:
        """Convert a channel-last image batch to a float, channel-first tensor in [0, 1]."""
        x = torch.from_numpy(obs_np).to(self.cfg.device).float() / 255.0
        return x.permute(0, 3, 1, 2).contiguous()

    def _maybe_augment(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the configured augmentation, or return ``x`` if none is set."""
        if self.augment is None:
            return x
        return self.augment(x)

    def _policy(self, x: torch.Tensor):
        """Run the model on ``x`` exactly as given.

        No augmentation happens here; callers decide what the network sees.

        Returns:
            ``(dist, value)``: a ``Categorical`` over actions and the value
            estimate with its trailing singleton dimension removed.
        """
        logits, value = self.model(x)
        return Categorical(logits=logits), value.squeeze(-1)

    def _env_step(self, env: "ProcgenEnv", action_np: np.ndarray):
        """Step ``env`` and normalise the result to ``(obs, rew, done, info)``.

        Accepts the 4-tuple ``(obs, rew, done, info)`` and the Gymnasium-style
        5-tuple ``(obs, rew, terminated, truncated, info)``; in the latter
        case ``done`` is ``terminated OR truncated``.

        Raises:
            RuntimeError: If ``env.step`` returns any other number of items.
        """
        out = env.step(action_np)
        if len(out) == 4:
            obs, rew, done, info = out
        elif len(out) == 5:
            obs, rew, terminated, truncated, info = out
            done = np.logical_or(terminated, truncated)
        else:
            raise RuntimeError("Unexpected ProcgenEnv.step output format")
        return get_obs(obs), rew, done, info

    def _compute_gae(self, rew_buf, done_buf, val_buf, next_value):
        """Compute GAE advantages and value targets for one rollout.

        Args:
            rew_buf, done_buf, val_buf: Tensors indexed ``[t, env]``.
                ``done_buf[t]`` is the done flag returned by the step taken
                at time ``t``.
            next_value: Value estimate, indexed ``[env]``, for the
                observation that follows the last stored step.

        Returns:
            ``(adv_buf, ret_buf)`` with the same shape as ``rew_buf``, where
            ``ret_buf = adv_buf + val_buf``.
        """
        gamma, lam = self.cfg.gamma, self.cfg.gae_lambda
        T = rew_buf.shape[0]
        adv_buf = torch.zeros_like(rew_buf)
        lastgaelam = torch.zeros_like(rew_buf[0])
        for t in reversed(range(T)):
            # A done at step t means the successor state belongs to a new
            # episode, so neither its value nor later advantages leak back.
            nextnonterm = 1.0 - done_buf[t]
            nextv = val_buf[t + 1] if t < T - 1 else next_value
            delta = rew_buf[t] + gamma * nextv * nextnonterm - val_buf[t]
            lastgaelam = delta + gamma * lam * nextnonterm * lastgaelam
            adv_buf[t] = lastgaelam
        return adv_buf, adv_buf + val_buf

    def collect(self):
        """Roll the current policy out for ``rollout_length`` steps per env.

        Side effects: advances ``global_step``, updates ``last_obs`` and
        appends to ``metrics["train_return_mean"/"train_return_std"]`` (both
        0.0 when no episode finished during the rollout).

        Returns:
            ``(obs_buf, act_buf, logp_buf, adv_buf, ret_buf)``. ``obs_buf`` is
            indexed ``[t, env, channel, row, col]`` and holds the tensors the
            policy actually acted on (after augmentation, if enabled); the
            others are indexed ``[t, env]``.
        """
        cfg = self.cfg
        T, N, device = cfg.rollout_length, cfg.num_envs_train, cfg.device
        obs = self.last_obs if self.last_obs is not None else get_obs(self.env.reset())

        obs_buf = torch.zeros((T, N, self.C, self.H, self.W), device=device)
        act_buf = torch.zeros((T, N), device=device, dtype=torch.long)
        logp_buf = torch.zeros((T, N), device=device)
        rew_buf = torch.zeros((T, N), device=device)
        done_buf = torch.zeros((T, N), device=device)
        val_buf = torch.zeros((T, N), device=device)

        ep_returns = np.zeros(N, dtype=np.float32)
        finished_returns = []

        self.model.train()
        for t in range(T):
            self.global_step += N
            with torch.no_grad():
                # Augment once and keep that exact tensor: log-probs and
                # values stored below must match what update() re-evaluates.
                x = self._maybe_augment(self._prep(obs))
                dist, value = self._policy(x)
                action = dist.sample()
                logp = dist.log_prob(action)
            next_obs, reward, done, _ = self._env_step(self.env, action.cpu().numpy())

            obs_buf[t] = x
            act_buf[t] = action
            logp_buf[t] = logp
            rew_buf[t] = torch.as_tensor(reward, dtype=torch.float32, device=device)
            done_buf[t] = torch.as_tensor(done, dtype=torch.float32, device=device)
            val_buf[t] = value

            ep_returns += reward
            done_idx = np.flatnonzero(np.asarray(done))
            finished_returns.extend(ep_returns[done_idx].tolist())
            ep_returns[done_idx] = 0.0
            obs = next_obs

        # Bootstrap value for the observation following the final step.
        with torch.no_grad():
            _, next_value = self._policy(self._maybe_augment(self._prep(obs)))

        adv_buf, ret_buf = self._compute_gae(rew_buf, done_buf, val_buf, next_value)
        self.last_obs = obs

        tr_mean = float(np.mean(finished_returns)) if finished_returns else 0.0
        tr_std = float(np.std(finished_returns)) if finished_returns else 0.0
        self.metrics["train_return_mean"].append(tr_mean)
        self.metrics["train_return_std"].append(tr_std)

        return obs_buf, act_buf, logp_buf, adv_buf, ret_buf

    def update(self, obs_buf, act_buf, logp_buf, adv_buf, ret_buf):
        """Run the clipped-PPO optimisation over one rollout.

        The rollout is flattened over time and env, advantages are
        normalised over the whole batch, and ``update_epochs`` shuffled
        passes of ``minibatch_size`` samples are performed. The mean policy
        loss, value loss and entropy over all minibatches are appended to
        ``metrics``.

        Args:
            obs_buf, act_buf, logp_buf, adv_buf, ret_buf: Buffers as returned
                by ``collect``.
        """
        cfg = self.cfg
        T, N = obs_buf.shape[0], obs_buf.shape[1]
        B = T * N
        obs = obs_buf.reshape(B, self.C, self.H, self.W)
        act = act_buf.reshape(B)
        old_logp = logp_buf.reshape(B)
        adv = adv_buf.reshape(B)
        ret = ret_buf.reshape(B)
        adv = (adv - adv.mean()) / (adv.std(unbiased=False) + 1e-8)

        inds = np.arange(B)
        pl_losses, v_losses, ents = [], [], []

        self.model.train()
        for _ in range(cfg.update_epochs):
            np.random.shuffle(inds)
            for start in range(0, B, cfg.minibatch_size):
                mb = inds[start:start + cfg.minibatch_size]
                # Stored observations are already augmented; re-augmenting
                # here would make new and old log-probs refer to different
                # inputs and bias the importance ratio.
                dist, value = self._policy(obs[mb])
                new_logp = dist.log_prob(act[mb])
                entropy = dist.entropy().mean()
                ratio = (new_logp - old_logp[mb]).exp()
                clipped_ratio = torch.clamp(ratio, 1.0 - cfg.clip_coef, 1.0 + cfg.clip_coef)
                pg_loss = torch.max(-adv[mb] * ratio, -adv[mb] * clipped_ratio).mean()
                v_loss = F.mse_loss(value, ret[mb])
                loss = pg_loss + cfg.vf_coef * v_loss - cfg.ent_coef * entropy

                self.opt.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(self.model.parameters(), cfg.max_grad_norm)
                self.opt.step()

                pl_losses.append(pg_loss.item())
                v_losses.append(v_loss.item())
                ents.append(entropy.item())

        self.metrics["policy_loss"].append(float(np.mean(pl_losses)))
        self.metrics["value_loss"].append(float(np.mean(v_losses)))
        self.metrics["entropy"].append(float(np.mean(ents)))

    def evaluate(self) -> Tuple[float, float]:
        """Evaluate the greedy (argmax) policy on the evaluation env.

        Resets the eval env, runs ``eval_steps_per_env`` steps on
        un-augmented observations and appends the result to
        ``metrics["eval_return_mean"/"eval_return_std"]``.

        Returns:
            ``(mean, std)`` of the returns of episodes that finished within
            the step budget; if none finished, of the partial returns
            accumulated so far.
        """
        self.model.eval()
        obs = get_obs(self.eval_env.reset())
        N = self.cfg.num_envs_eval
        ep_returns = np.zeros(N, dtype=np.float32)
        finished = []
        for _ in range(self.cfg.eval_steps_per_env):
            with torch.no_grad():
                logits, _ = self.model(self._prep(obs))
                action = torch.argmax(logits, dim=-1)
            obs, reward, done, _ = self._env_step(self.eval_env, action.cpu().numpy())
            ep_returns += reward
            done_idx = np.flatnonzero(np.asarray(done))
            finished.extend(ep_returns[done_idx].tolist())
            ep_returns[done_idx] = 0.0
        sample = finished if finished else ep_returns
        m = float(np.mean(sample))
        s = float(np.std(sample))
        self.metrics["eval_return_mean"].append(m)
        self.metrics["eval_return_std"].append(s)
        return m, s

    def train(self):
        """Alternate ``collect`` and ``update`` until the timestep budget is spent.

        Evaluates every ``eval_interval_updates`` updates and records the
        update index of each evaluation in ``metrics["eval_updates"]``.
        """
        cfg = self.cfg
        num_updates = cfg.total_timesteps // (cfg.num_envs_train * cfg.rollout_length)
        pbar = tqdm(range(num_updates), desc="PPO")
        for u in pbar:
            self.update(*self.collect())
            if (u + 1) % cfg.eval_interval_updates == 0:
                ev_m, _ = self.evaluate()
                self.metrics["eval_updates"].append(u + 1)
            else:
                # Between evaluations, show the most recent eval score.
                evals = self.metrics["eval_return_mean"]
                ev_m = evals[-1] if evals else 0.0
            self.metrics["updates"].append(u + 1)
            pbar.set_postfix(train=f"{self.metrics['train_return_mean'][-1]:.1f}", eval=f"{ev_m:.1f}")

def _eval_axis(metrics: Dict[str, List[float]], n_eval: int):
    """Pick x-values and an axis label for ``n_eval`` evaluation points."""
    updates = metrics["updates"]
    recorded = metrics.get("eval_updates")
    if recorded is not None and len(recorded) == n_eval:
        return list(recorded), "Update"
    if n_eval == len(updates):
        return list(updates), "Update"
    # The update index of each evaluation is unknown; number them instead
    # of guessing an interval.
    return list(range(1, n_eval + 1)), "Evaluation #"


def plot_metrics(metrics: Dict[str, List[float]], title="Procgen Chaser - Generalization"):
    """Plot train return, eval return and losses in a three-panel figure.

    Args:
        metrics: Dict with per-update lists ``updates``, ``train_return_mean``,
            ``train_return_std``, ``policy_loss``, ``value_loss`` and
            per-evaluation lists ``eval_return_mean``, ``eval_return_std``.
            Evaluations may be sparser than updates. If an ``eval_updates``
            list of matching length is present it supplies their x-positions;
            otherwise, unless there is one evaluation per update, they are
            plotted against their ordinal number.
        title: Figure super-title.
    """
    up = metrics["updates"]
    fig, axs = plt.subplots(1, 3, figsize=(14, 4))

    train_mean = np.asarray(metrics["train_return_mean"], dtype=float)
    train_std = np.asarray(metrics["train_return_std"], dtype=float)
    axs[0].plot(up, train_mean, label="train")
    axs[0].fill_between(up, train_mean - train_std, train_mean + train_std, alpha=0.2)
    axs[0].set_title("Train return")
    axs[0].set_xlabel("Update")
    axs[0].set_ylabel("Return")

    eval_mean = np.asarray(metrics["eval_return_mean"], dtype=float)
    eval_std = np.asarray(metrics["eval_return_std"], dtype=float)
    # Evaluations are not run on every update, so they need their own x-axis;
    # plotting them against `up` would fail on mismatched lengths.
    eval_x, eval_xlabel = _eval_axis(metrics, len(eval_mean))
    axs[1].plot(eval_x, eval_mean, label="eval", color="green")
    if len(eval_std) == len(eval_mean):
        axs[1].fill_between(eval_x, eval_mean - eval_std, eval_mean + eval_std,
                            color="green", alpha=0.2)
    axs[1].set_title("Eval return (unseen)")
    axs[1].set_xlabel(eval_xlabel)

    axs[2].plot(up, metrics["policy_loss"], label="pi loss", color="red")
    axs[2].plot(up, metrics["value_loss"], label="v loss", color="orange")
    axs[2].set_title("Losses")
    axs[2].set_xlabel("Update")
    axs[2].legend()

    fig.suptitle(title)
    plt.tight_layout()
    plt.show()

# Script entry point: train with the default configuration, then plot the
# collected metrics.
if __name__ == "__main__":
    cfg = Config()
    # Constructing the agent seeds the RNGs and creates both environments.
    agent = PPOAgent(cfg)
    agent.train()
    plot_metrics(agent.metrics, title="CNN + random shift (minimal)")


# Captured cell output (progress bar at export time):
# PPO:   0%|          | 0/122 [00:00<?, ?it/s]