In [1]:
import math
import random
from dataclasses import dataclass

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from procgen import ProcgenEnv
from torch.distributions import Categorical
from tqdm.auto import tqdm


@dataclass
class Config:
    env_name: str = "leaper"
    num_envs: int = 32
    num_levels: int = 100
    start_level: int = 0
    total_timesteps: int = 200_000
    rollout_length: int = 256
    update_epochs: int = 3
    minibatch_size: int = 2048
    gamma: float = 0.999
    gae_lambda: float = 0.95
    clip_coef: float = 0.2
    ent_coef: float = 0.01
    vf_coef: float = 0.5
    max_grad_norm: float = 0.5
    learning_rate: float = 5e-4
    seed: int = 1
    device: str = "cuda" if torch.cuda.is_available() else "cpu"


def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def make_env(cfg: Config):
    return ProcgenEnv(
        num_envs=cfg.num_envs,
        env_name=cfg.env_name,
        num_levels=cfg.num_levels,
        start_level=cfg.start_level,
        distribution_mode="easy",
        rand_seed=cfg.seed
    )


def get_obs(x):
    if isinstance(x, dict):
        for k in ("obs", "observation", "rgb"):
            if k in x: return x[k]
        raise KeyError("observation key missing")
    return x


class RandomShift(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x


def orthogonal_init(m, gain=1.0):
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        nn.init.orthogonal_(m.weight, gain=gain)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.)


class SEBlock(nn.Module):
    def __init__(self, channels, reduction=8):
        super().__init__()
        self.fc1 = nn.Conv2d(channels, channels // reduction, 1)
        self.fc2 = nn.Conv2d(channels // reduction, channels, 1)

    def forward(self, x):
        s = F.adaptive_avg_pool2d(x, 1)
        s = F.relu(self.fc1(s), inplace=True)
        s = torch.sigmoid(self.fc2(s))
        return x * s


class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
        self.se = SEBlock(channels)

    def forward(self, x):
        h = F.relu(self.conv1(x), inplace=True)
        h = self.conv2(h)
        h = self.se(h)
        return F.relu(x + h, inplace=True)


class CNNPolicy(nn.Module):
    def __init__(self, in_ch, num_actions, h, w):
        super().__init__()
        base = 64
        self.stem = nn.Sequential(
            nn.Conv2d(in_ch, 32, 3, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, base, 3, stride=2, padding=1),
            nn.ReLU(inplace=True),
        )
        self.trunk = nn.Sequential(
            ResidualBlock(base),
            ResidualBlock(base),
            nn.Conv2d(base, base, 3, stride=2, padding=1),
            nn.ReLU(inplace=True),
            ResidualBlock(base),
        )
        with torch.no_grad():
            dummy = torch.zeros(1, in_ch, h, w)
            n = self.trunk(self.stem(dummy)).numel()
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(n, 512),
            nn.ReLU(inplace=True),
        )
        self.pi = nn.Linear(512, num_actions)
        self.v = nn.Linear(512, 1)
        self.apply(lambda m: orthogonal_init(m, gain=math.sqrt(2)))
        orthogonal_init(self.pi, gain=0.01)
        orthogonal_init(self.v, gain=1.0)

    def forward(self, x):
        z = self.fc(self.trunk(self.stem(x)))
        return self.pi(z), self.v(z).squeeze(-1)


def explained_variance(y_pred, y_true):
    vy = torch.var(y_true)
    return (1 - torch.var(y_true - y_pred) / (vy + 1e-8)).item()


class PPO:
    def __init__(self, cfg: Config):
        set_seed(cfg.seed)
        self.cfg = cfg
        self.env = make_env(cfg)
        obs0 = get_obs(self.env.reset())
        self.N = obs0.shape[0]
        self.H, self.W, self.C = obs0.shape[1:4]
        self.num_actions = self.env.action_space.n if hasattr(self.env, "action_space") else 15
        self.net = CNNPolicy(self.C, self.num_actions, self.H, self.W).to(cfg.device)
        self.opt = torch.optim.Adam(self.net.parameters(), lr=cfg.learning_rate, eps=1e-5)
        self.aug = RandomShift()
        self.last_obs = obs0
        self.global_step = 0
        self.metrics = {"update": [], "mean_return": [], "mean_ep_len": [], "policy_loss": [], "value_loss": [],
                        "entropy": [], "approx_kl": [], "explained_var": []}

    def _prep(self, obs_np):
        x = torch.from_numpy(obs_np).to(self.cfg.device).float() / 255.0
        return x.permute(0, 3, 1, 2).contiguous()

    def _policy(self, x):
        x = self.aug(x)
        logits, value = self.net(x)
        return Categorical(logits=logits), value

    def _step_env(self, a_np):
        o, r, d, info = self.env.step(a_np)
        return get_obs(o), r, d, info

    def collect(self):
        T, N = self.cfg.rollout_length, self.cfg.num_envs
        obs = self.last_obs
        obs_buf = torch.zeros((T, N, self.C, self.H, self.W), device=self.cfg.device)
        act_buf = torch.zeros((T, N), device=self.cfg.device, dtype=torch.long)
        logp_buf = torch.zeros((T, N), device=self.cfg.device)
        rew_buf = torch.zeros((T, N), device=self.cfg.device)
        done_buf = torch.zeros((T, N), device=self.cfg.device)
        val_buf = torch.zeros((T, N), device=self.cfg.device)
        ep_returns = np.zeros(N, dtype=np.float32)
        ep_lengths = np.zeros(N, dtype=np.int32)
        ep_return_log = []
        ep_len_log = []
        self.net.eval()
        for t in range(T):
            self.global_step += N
            x = self._prep(obs)
            with torch.no_grad():
                dist, v = self._policy(x)
                a = dist.sample()
                lp = dist.log_prob(a)
            next_obs, r, d, _ = self._step_env(a.cpu().numpy())
            obs_buf[t] = x
            act_buf[t] = a
            logp_buf[t] = lp
            rew_buf[t] = torch.from_numpy(r).to(self.cfg.device)
            done_buf[t] = torch.from_numpy(d.astype(np.float32)).to(self.cfg.device)
            val_buf[t] = v
            ep_returns += r
            ep_lengths += 1
            for i in range(N):
                if d[i]:
                    ep_return_log.append(ep_returns[i])
                    ep_len_log.append(ep_lengths[i])
                    ep_returns[i] = 0.0
                    ep_lengths[i] = 0
            obs = next_obs
        with torch.no_grad():
            x_last = self._prep(obs)
            self.net.eval()
            _, next_v = self._policy(x_last)
        adv = torch.zeros_like(rew_buf)
        lastgaelam = torch.zeros((N,), device=self.cfg.device)
        for t in reversed(range(T)):
            nextnonterm = 1.0 - done_buf[t]
            nextv = val_buf[t + 1] if t < T - 1 else next_v
            delta = rew_buf[t] + self.cfg.gamma * nextv * nextnonterm - val_buf[t]
            lastgaelam = delta + self.cfg.gamma * self.cfg.gae_lambda * nextnonterm * lastgaelam
            adv[t] = lastgaelam
        ret = adv + val_buf
        self.last_obs = obs
        mean_return = float(np.mean(ep_return_log)) if len(ep_return_log) > 0 else float(ep_returns.mean())
        mean_ep_len = float(np.mean(ep_len_log)) if len(ep_len_log) > 0 else float(ep_lengths.mean())
        return obs_buf, act_buf, logp_buf, adv, ret, mean_return, mean_ep_len

    def update(self, obs_buf, act_buf, logp_buf, adv_buf, ret_buf):
        T, N = obs_buf.shape[:2]
        B = T * N
        obs = obs_buf.reshape(B, self.C, self.H, self.W)
        act = act_buf.reshape(B)
        old_logp = logp_buf.reshape(B)
        adv = adv_buf.reshape(B)
        ret = ret_buf.reshape(B)
        adv = (adv - adv.mean()) / (adv.std(unbiased=False) + 1e-8)
        inds = np.arange(B)
        self.net.train()
        ploss_acc = []
        vloss_acc = []
        ent_acc = []
        kl_acc = []
        ev_acc = []
        for _ in range(self.cfg.update_epochs):
            np.random.shuffle(inds)
            for s in range(0, B, self.cfg.minibatch_size):
                mb = inds[s:s + self.cfg.minibatch_size]
                dist, v = self._policy(obs[mb])
                new_logp = dist.log_prob(act[mb])
                entropy = dist.entropy().mean()
                ratio = (new_logp - old_logp[mb]).exp()
                pg_loss = torch.max(
                    -adv[mb] * ratio,
                    -adv[mb] * torch.clamp(ratio, 1.0 - self.cfg.clip_coef, 1.0 + self.cfg.clip_coef)
                ).mean()
                v_pred_clipped = v.detach() + (v - v.detach()).clamp(-0.2, 0.2)
                v_loss_unclipped = F.mse_loss(v, ret[mb], reduction="none")
                v_loss_clipped = F.mse_loss(v_pred_clipped, ret[mb], reduction="none")
                v_loss = torch.max(v_loss_unclipped, v_loss_clipped).mean()
                loss = pg_loss + self.cfg.vf_coef * v_loss - self.cfg.ent_coef * entropy
                self.opt.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(self.net.parameters(), self.cfg.max_grad_norm)
                self.opt.step()
                with torch.no_grad():
                    approx_kl = (old_logp[mb] - new_logp).mean().clamp_min(0).item()
                    ev = explained_variance(v.detach(), ret[mb])
                ploss_acc.append(pg_loss.item())
                vloss_acc.append(v_loss.item())
                ent_acc.append(entropy.item())
                kl_acc.append(approx_kl)
                ev_acc.append(ev)
        return np.mean(ploss_acc), np.mean(vloss_acc), np.mean(ent_acc), np.mean(kl_acc), np.mean(ev_acc)

    def train(self):
        num_updates = self.cfg.total_timesteps // (self.cfg.num_envs * self.cfg.rollout_length)
        print(f"Device: {self.cfg.device}")
        print(
            f"Env: {self.cfg.env_name} | num_envs: {self.cfg.num_envs} | num_levels: {self.cfg.num_levels} | start_level: {self.cfg.start_level}")
        print(f"Obs shape: (N={self.N}, H={self.H}, W={self.W}, C={self.C})")
        print(f"Num actions: {self.num_actions}")
        print(f"Total updates: {num_updates} | Samples per update: {self.cfg.num_envs * self.cfg.rollout_length}")
        pbar = tqdm(range(num_updates), desc="PPO")
        for ui in pbar:
            obs_buf, act_buf, logp_buf, adv, ret, mean_return, mean_ep_len = self.collect()
            p_loss, v_loss, entropy, kl, ev = self.update(obs_buf, act_buf, logp_buf, adv, ret)
            self.metrics["update"].append(ui)
            self.metrics["mean_return"].append(mean_return)
            self.metrics["mean_ep_len"].append(mean_ep_len)
            self.metrics["policy_loss"].append(p_loss)
            self.metrics["value_loss"].append(v_loss)
            self.metrics["entropy"].append(entropy)
            self.metrics["approx_kl"].append(kl)
            self.metrics["explained_var"].append(ev)
            pbar.set_postfix(ret=f"{mean_return:.2f}", len=f"{mean_ep_len:.1f}", p=f"{p_loss:.3f}", v=f"{v_loss:.3f}",
                             ent=f"{entropy:.3f}", kl=f"{kl:.3f}", ev=f"{ev:.3f}")
        self.plot_metrics()

    def plot_metrics(self):
        u = self.metrics["update"]
        plt.figure(figsize=(12, 8))
        plt.subplot(2, 2, 1)
        plt.plot(u, self.metrics["mean_return"])
        plt.title("Mean episodic return")
        plt.subplot(2, 2, 2)
        plt.plot(u, self.metrics["mean_ep_len"])
        plt.title("Mean episodic length")
        plt.subplot(2, 2, 3)
        plt.plot(u, self.metrics["policy_loss"], label="policy")
        plt.plot(u, self.metrics["value_loss"], label="value")
        plt.title("Losses")
        plt.legend()
        plt.subplot(2, 2, 4)
        plt.plot(u, self.metrics["entropy"], label="entropy")
        plt.plot(u, self.metrics["approx_kl"], label="approx KL")
        plt.plot(u, self.metrics["explained_var"], label="explained var")
        plt.title("Diagnostics")
        plt.legend()
        plt.tight_layout()
        plt.show()


if __name__ == "__main__":
    cfg = Config()
    agent = PPO(cfg)
    agent.train()
    torch.save(agent.net.state_dict(), f"ppo_{cfg.env_name}.pt")


Gym has been unmaintained since 2022 and does not support NumPy 2.0 amongst other critical functionality.
Please upgrade to Gymnasium, the maintained drop-in replacement of Gym, or contact the authors of your software and request that they upgrade.
Users of this version of Gym should be able to simply replace 'import gym' with 'import gymnasium as gym' in the vast majority of cases.
See the migration guide at https://gymnasium.farama.org/introduction/migration_guide/ for additional information.
  from .autonotebook import tqdm as notebook_tqdm


Device: cpu
Env: leaper | num_envs: 32 | num_levels: 100 | start_level: 0
Obs shape: (N=32, H=64, W=64, C=3)
Num actions: 15
Total updates: 24 | Samples per update: 8192


PPO:   0%|          | 0/24 [00:09<?, ?it/s]


KeyboardInterrupt: 

In [2]:
import time
import cv2
import torch

cfg = Config()
agent = PPO(cfg)
agent.net.load_state_dict(torch.load(f"ppo_{cfg.env_name}.pt", map_location=cfg.device))
agent.net.eval()

env = make_env(cfg)
obs = get_obs(env.reset())

cv2.namedWindow("procgen", cv2.WINDOW_NORMAL)

for step in range(1000):
    x = agent._prep(obs)
    with torch.no_grad():
        dist, _ = agent._policy(x)
        action = dist.sample().cpu().numpy()
    o, r, d, info = env.step(action)
    obs = get_obs(o)

    frame = obs[0][:, :, ::-1]
    frame = cv2.resize(frame, (800, 600), interpolation=cv2.INTER_NEAREST)

    cv2.imshow("procgen", frame)
    time.sleep(0.03)
    if cv2.waitKey(1) & 0xFF == 27:  # Esc to quit
        break

cv2.destroyAllWindows()


  agent.net.load_state_dict(torch.load(f"ppo_{cfg.env_name}.pt", map_location=cfg.device))
