In [2]:
# %% [markdown]
# # PPO from scratch on CarRacing-v3 (Gymnasium, Py3.12/Colab)
# - Installs minimal deps (Box2D), works headless via rgb_array (no display needed)
# - Simple CNN Actor-Critic + tanh-Gaussian PPO
# - Frame stacking for better dynamics (default K=4)
# - Saves a short evaluation video at the end

# %% [code]
!apt -yq install swig >/dev/null
!pip -q install "gymnasium[box2d]" "torch" "torchvision" "opencv-python-headless" "imageio" "imageio-ffmpeg" "tqdm" "einops" numpy

import os, math, random, time, imageio, numpy as np
import gymnasium as gym
import torch, torch.nn as nn, torch.nn.functional as F
from torch.distributions import Normal
from collections import deque
from tqdm.auto import trange
import cv2




[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m374.4/374.4 kB[0m [31m1.5 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.9/1.9 MB[0m [31m13.3 MB/s[0m eta [36m0:00:00[0m
[?25h  Building wheel for box2d-py (setup.py) ... [?25l[?25hdone


In [4]:

print("Gymnasium:", gym.__version__)
print("Torch:", torch.__version__)

# ---------- Helpers (Gymnasium API) ----------
def reset_env(env, seed=None):
    obs, info = env.reset(seed=seed)
    return obs, info

def step_env(env, action):
    obs, r, terminated, truncated, info = env.step(action)
    return obs, r, (terminated or truncated), info

# ---------- Env ----------
# CarRacing-v3: obs = (96,96,3) uint8, action = [steer∈[-1,1], gas∈[0,1], brake∈[0,1]]
env = gym.make("CarRacing-v3", render_mode="rgb_array", domain_randomize=False, continuous=True)
print("Obs space:", env.observation_space, "| Act space:", env.action_space)

# ---------- Preprocess & Frame Stack ----------
def preprocess(obs, out=84):
    img = cv2.resize(obs, (out, out), interpolation=cv2.INTER_AREA)
    # keep RGB; normalize to [0,1]
    return torch.from_numpy(img).float().permute(2,0,1) / 255.0  # [3,H,W]

class FrameStacker:
    def __init__(self, k=4):
        self.k = k
        self.frames = deque(maxlen=k)
    def reset(self, obs_t):
        self.frames.clear()
        for _ in range(self.k):
            self.frames.append(obs_t.clone())
        return torch.cat(list(self.frames), dim=0)  # [3k,H,W]
    def step(self, obs_t):
        self.frames.append(obs_t)
        return torch.cat(list(self.frames), dim=0)

# ---------- CNN Actor-Critic ----------
class CNNActorCritic(nn.Module):
    def __init__(self, in_ch, action_space):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, 32, 8, 4), nn.ReLU(inplace=True),
            nn.Conv2d(32, 64, 4, 2),   nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 3, 1),   nn.ReLU(inplace=True),
        )
        with torch.no_grad():
            n_flat = self.conv(torch.zeros(1, in_ch, 84, 84)).view(1,-1).shape[1]
        self.fc = nn.Sequential(nn.Linear(n_flat, 256), nn.ReLU(inplace=True))

        self.act_dim = int(np.prod(action_space.shape))  # should be 3
        self.mu = nn.Linear(256, self.act_dim)
        # higher steer exploration, tiny brake exploration
        self.log_std = nn.Parameter(torch.tensor([-0.3, -0.7, -2.0]))  # [steer, gas, brake]
        self.v = nn.Linear(256, 1)

        # init so: steer ~0, gas ~0.38, brake ~0.12 at start
        nn.init.zeros_(self.mu.weight)
        self.mu.bias.data = torch.tensor([0.0, -0.5, -2.0])

    def encode(self, x):
        z = self.conv(x); z = z.view(z.size(0), -1); return self.fc(z)

    def forward(self, x):
        z = self.encode(x)
        mu = self.mu(z)
        std = self.log_std.exp().expand_as(mu)
        dist = Normal(mu, std)
        v = self.v(z).squeeze(-1)
        return dist, v

    def _squash(self, a_pre):
        steer_pre, gas_pre, brake_pre = a_pre[...,0:1], a_pre[...,1:2], a_pre[...,2:3]
        steer = torch.tanh(steer_pre)        # [-1,1]
        gas   = torch.sigmoid(gas_pre)       # [0,1]
        brake = torch.sigmoid(brake_pre)     # [0,1]
        return torch.cat([steer, gas, brake], dim=-1)

    def act(self, x):
        dist, v = self.forward(x)
        a_pre = dist.rsample()
        a_env = self._squash(a_pre)

        # change-of-variables corrections
        sp, gp, bp = a_pre[...,0], a_pre[...,1], a_pre[...,2]
        steer_corr = torch.log(1 - torch.tanh(sp)**2 + 1e-8)
        s = torch.sigmoid(gp); gas_corr = torch.log(s*(1 - s) + 1e-8)
        b = torch.sigmoid(bp); brake_corr = torch.log(b*(1 - b) + 1e-8)
        logp = dist.log_prob(a_pre).sum(-1) + steer_corr + gas_corr + brake_corr
        return a_env, logp, v

    def eval_action_logp_v(self, x, a_env):
        # invert env actions -> pre-activation
        steer = torch.clamp(a_env[...,0], -0.999, 0.999)
        gas   = torch.clamp(a_env[...,1], 1e-6, 1-1e-6)
        brake = torch.clamp(a_env[...,2], 1e-6, 1-1e-6)

        sp = torch.atanh(steer)
        gp = torch.log(gas) - torch.log(1 - gas)      # logit
        bp = torch.log(brake) - torch.log(1 - brake)  # logit
        a_pre = torch.stack([sp, gp, bp], dim=-1)

        dist, v = self.forward(x)
        base = dist.log_prob(a_pre).sum(-1)
        corr = torch.log(1 - torch.tanh(sp)**2 + 1e-8) \
             + torch.log(gas*(1 - gas) + 1e-8) \
             + torch.log(brake*(1 - brake) + 1e-8)
        logp = base + corr
        return logp, v

# ---------- Config ----------
cfg = dict(
    seed=42,
    frame_stack=4,
    total_steps=150_000,     # bump for better results (e.g., 1e6)
    rollout_len=2048,
    minibatch_size=256,
    ppo_epochs=8,
    gamma=0.99,
    lam=0.95,
    clip_ratio=0.2,
    lr=3e-4,
    vf_coef=0.5,
    ent_coef=0.01,
    max_grad_norm=0.5,
    log_interval=5_000,
)

random.seed(cfg["seed"]); np.random.seed(cfg["seed"]); torch.manual_seed(cfg["seed"])
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# ---------- Policy / Optim ----------
in_ch = 3 * cfg["frame_stack"]
policy = CNNActorCritic(in_ch, env.action_space).to(device)
optimizer = torch.optim.Adam(policy.parameters(), lr=cfg["lr"])


# ---------- Buffer ----------
class RolloutBuffer:
    def __init__(self, n, obs_shape, act_dim):
        self.obs = torch.zeros((n, *obs_shape), dtype=torch.float32)
        self.act = torch.zeros((n, act_dim), dtype=torch.float32)
        self.logp = torch.zeros(n, dtype=torch.float32)
        self.rew  = torch.zeros(n, dtype=torch.float32)
        self.done = torch.zeros(n, dtype=torch.float32)
        self.val  = torch.zeros(n, dtype=torch.float32)
        self.ptr, self.n = 0, n
    def add(self, o, a, lp, r, d, v):
        self.obs[self.ptr].copy_(o.cpu())
        self.act[self.ptr].copy_(a.detach().cpu())
        self.logp[self.ptr] = lp.detach().cpu()
        self.rew[self.ptr]  = r
        self.done[self.ptr] = float(d)
        self.val[self.ptr]  = v.detach().cpu()
        self.ptr += 1
    def gae(self, gamma, lam, last_v):
        adv = torch.zeros_like(self.rew)
        last = 0.0
        for t in reversed(range(self.ptr)):
            if t == self.ptr - 1:
                next_nonterm, next_v = 1.0 - self.done[t], last_v
            else:
                next_nonterm, next_v = 1.0 - self.done[t+1], self.val[t+1]
            delta = self.rew[t] + gamma * next_v * next_nonterm - self.val[t]
            last = delta + gamma * lam * next_nonterm * last
            adv[t] = last
        ret = adv + self.val[:self.ptr]
        adv = (adv - adv.mean()) / (adv.std() + 1e-8)
        return adv[:self.ptr], ret[:self.ptr]

# ---------- Training ----------
fs = FrameStacker(cfg["frame_stack"])
obs_raw, _ = reset_env(env, seed=cfg["seed"])
obs_t = preprocess(obs_raw)
state = fs.reset(obs_t)
state = state.to(device)

episode_returns, ep_ret = [], 0.0
global_step = 0
pbar = trange(cfg["total_steps"], desc="Training PPO (CarRacing)", leave=True, miniters=cfg["log_interval"])

while global_step < cfg["total_steps"]:
    buf = RolloutBuffer(cfg["rollout_len"], (in_ch, 84, 84), int(np.prod(env.action_space.shape)))
    for _ in range(cfg["rollout_len"]):
        with torch.no_grad():
            a_env, logp, v = policy.act(state.unsqueeze(0))
        a_np = a_env.squeeze(0).cpu().numpy().astype(np.float32)
        next_obs_raw, r, done, info = step_env(env, a_np)
        ep_ret += r

        # store
        buf.add(state, torch.from_numpy(a_np), logp.squeeze(0), r, done, v.squeeze(0))

        # next state
        next_obs_t = preprocess(next_obs_raw)
        state = fs.step(next_obs_t).to(device)

        if done:
            episode_returns.append(ep_ret); ep_ret = 0.0
            next_obs_raw, _ = reset_env(env)
            next_obs_t = preprocess(next_obs_raw)
            state = fs.reset(next_obs_t).to(device)

        global_step += 1
        if global_step % cfg["log_interval"] == 0:
            last10 = np.mean(episode_returns[-10:]) if episode_returns else 0.0
            pbar.set_postfix(steps=global_step, avg_return=f"{last10:.1f}")
            pbar.update(cfg["log_interval"])

    with torch.no_grad():
        _, last_v = policy.forward(state.unsqueeze(0))
        last_v = last_v.squeeze(0).detach().cpu()
    adv, ret = buf.gae(cfg["gamma"], cfg["lam"], last_v)

    obs_mb = buf.obs[:buf.ptr].to(device)
    act_mb = buf.act[:buf.ptr].to(device)
    old_logp = buf.logp[:buf.ptr].to(device)
    adv_mb = adv.to(device); ret_mb = ret.to(device)

    n = obs_mb.shape[0]
    idxs = np.arange(n)
    for _ in range(cfg["ppo_epochs"]):
        np.random.shuffle(idxs)
        for s in range(0, n, cfg["minibatch_size"]):
            mb = idxs[s:s+cfg["minibatch_size"]]
            new_logp, v_pred = policy.eval_action_logp_v(obs_mb[mb], act_mb[mb])
            ratio = torch.exp(new_logp - old_logp[mb])
            surr1 = ratio * adv_mb[mb]
            surr2 = torch.clamp(ratio, 1-cfg["clip_ratio"], 1+cfg["clip_ratio"]) * adv_mb[mb]
            policy_loss = -torch.min(surr1, surr2).mean()
            v_loss = F.mse_loss(v_pred, ret_mb[mb])
            dist, _ = policy.forward(obs_mb[mb])
            entropy = dist.entropy().sum(-1).mean()

            loss = policy_loss + cfg["vf_coef"]*v_loss - cfg["ent_coef"]*entropy
            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            nn.utils.clip_grad_norm_(policy.parameters(), cfg["max_grad_norm"])
            optimizer.step()

print("Episodes:", len(episode_returns))
print("Recent returns:", [round(x,1) for x in episode_returns[-10:]])

# ---------- Evaluation + Video ----------
def eval_and_record(env, policy, steps=1000, fname="ppo_carracing_demo.mp4", fps=30):
    frames, ep_ret = [], 0.0
    obs_raw, _ = reset_env(env, seed=123)
    fs = FrameStacker(cfg["frame_stack"])
    state = fs.reset(preprocess(obs_raw)).to(device)
    policy.eval()
    with torch.no_grad():
        for t in range(steps):
            a_env, _, _ = policy.act(state.unsqueeze(0))
            obs_raw, r, done, _ = step_env(env, a_env.squeeze(0).cpu().numpy().astype(np.float32))
            ep_ret += r
            frame = env.render()  # rgb_array
            frames.append(frame)
            state = fs.step(preprocess(obs_raw)).to(device)
            if done:
                break
    imageio.mimsave(fname, frames, fps=fps)
    print(f"Eval return: {ep_ret:.1f} | saved {fname}")

eval_and_record(env, policy, steps=800)


Gymnasium: 1.2.0
Torch: 2.8.0+cu126
Obs space: Box(0, 255, (96, 96, 3), uint8) | Act space: Box([-1.  0.  0.], 1.0, (3,), float32)
Device: cuda


Training PPO (CarRacing):   0%|          | 0/150000 [00:00<?, ?it/s]

Episodes: 157
Recent returns: [157.4, 151.8, 38.5, -35.1, 161.2, 158.8, 94.4, 140.0, 95.9, 174.9]




Eval return: 100.1 | saved ppo_carracing_demo.mp4


In [5]:
with torch.no_grad():
    a, _, _ = policy.act(state.unsqueeze(0))
    print("sample action:", a.squeeze(0).cpu().numpy())


sample action: [-0.31121773  0.41215616  0.1455702 ]
