## Setup

In [None]:
!nvidia-smi -L || true
import gymnasium as gym
import ale_py
import torch
print("torch", torch.__version__, "cuda available?", torch.cuda.is_available())


GPU 0: Tesla T4 (UUID: GPU-d1f6db03-7639-26ea-209a-ca5e45df7754)
torch 2.9.0+cu126 cuda available? True


In [None]:
from google.colab import drive
drive.mount("/content/drive")


Mounted at /content/drive


### Clone git repo into Colab
I like this move

In [None]:
import os, pathlib

ROOT = "/content/drive/MyDrive/Colab_Notebooks"
REPO_DIR = f"{ROOT}/tg_smn"

pathlib.Path(ROOT).mkdir(parents=True, exist_ok=True)

%cd $ROOT
if not os.path.exists(REPO_DIR):
    !git clone https://github.com/RespectableGlioma/tg_smn.git
else:
    %cd $REPO_DIR
    !git pull


/content/drive/MyDrive/Colab_Notebooks
/content/drive/MyDrive/Colab_Notebooks/tg_smn
Already up to date.


In [None]:
!pip -q install -U pip setuptools wheel
!pip -q install "gymnasium[atari]" ale-py opencv-python pillow tqdm

# Optional: if you want to use the repo's packaging as well (not required for the world model script):
%cd $REPO_DIR
!pip -q install -e .


[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/1.8 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 MB[0m [31m86.3 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/1.1 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m75.6 MB/s[0m eta [36m0:00:00[0m
[?25h[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
ipython 7.34.0 requires jedi>=0.16, which is not installed.[0m[31m
[0m/content/drive/MyDrive/Colab_Notebooks/tg_smn
  Installing build dependencies ... [?25l[?25hdone
  Checking if build backend supports build_editable ... [?25l[?25hdone
  Getting requirements to build editable ... [?25l[?25hdone
  Preparing editable metadata (pyproject.toml

### Quick Test

In [None]:
# Per ale-py docs: not strictly required, but can help IDEs / env registration :contentReference[oaicite:5]{index=5}
gym.register_envs(ale_py)

env_id = "ALE/Pong-v5"
env = gym.make(env_id, frameskip=1, repeat_action_probability=0.0)
obs, info = env.reset()
print("Loaded", env_id, "| obs type:", type(obs), "| obs shape:", getattr(obs, "shape", None))
for _ in range(5):
    obs, r, term, trunc, info = env.step(env.action_space.sample())
print("Step ok.")
env.close()


Loaded ALE/Pong-v5 | obs type: <class 'numpy.ndarray'> | obs shape: (210, 160, 3)
Step ok.


### Setup rest of codebase

In [None]:
import pathlib
world_dir = pathlib.Path(REPO_DIR) / "world_models"
world_dir.mkdir(parents=True, exist_ok=True)
print("world_models dir:", world_dir)


world_models dir: /content/drive/MyDrive/Colab_Notebooks/tg_smn/world_models


In [None]:
%%writefile /content/drive/MyDrive/Colab_Notebooks/tg_smn/world_models/ale_rssm_causal_stochastic_v2.py
"""
Dreamer-style RSSM with explicit split:
  - (h,z): causal core (temporal, used for transitions)
  - u    : stochastic envelope (nuisance, i.i.d., not used in transitions)

Improvements vs v1:
  1) Foreground/motion-weighted reconstruction so the model must track small moving objects (ball).
  2) Overshooting (multi-step prior rollout loss) so the transition learns "rules".
  3) Log RAW KL (before free-nats clamp) to detect posterior collapse.
  4) Optional detach of z into noisy decoder to force u to capture nuisance.

Usage:
  python world_models/ale_rssm_causal_stochastic_v2.py --env_id ALE/Pong-v5 --repeat_action_probability 0.0
"""
import argparse
from dataclasses import dataclass
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm


# -------------------------
# Utils
# -------------------------
def set_seed(seed: int):
    import random
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

def onehot(a: torch.Tensor, n: int) -> torch.Tensor:
    return F.one_hot(a, num_classes=n).float()

def augment_obs(clean01: torch.Tensor, noise_std: float = 0.05) -> torch.Tensor:
    b = torch.empty((clean01.shape[0], 1, 1, 1), device=clean01.device).uniform_(0.6, 1.4)
    y = clean01 * b
    if noise_std > 0:
        y = y + noise_std * torch.randn_like(y)
    return torch.clamp(y, 0.0, 1.0)

def kl_diag_gaussian(mu_q, logstd_q, mu_p, logstd_p):
    std_q = torch.exp(logstd_q)
    std_p = torch.exp(logstd_p)
    var_q = std_q * std_q
    var_p = std_p * std_p
    kl = (logstd_p - logstd_q) + (var_q + (mu_q - mu_p) ** 2) / (2.0 * var_p) - 0.5
    return torch.sum(kl, dim=-1)  # [B]

def kl_std_normal(mu, logstd):
    var = torch.exp(2.0 * logstd)
    kl = 0.5 * (var + mu**2 - 1.0 - torch.log(var + 1e-8))
    return torch.sum(kl, dim=-1)  # [B]

def reparam(mu, logstd):
    std = torch.exp(logstd)
    eps = torch.randn_like(std)
    return mu + eps * std

def weighted_mse(pred: torch.Tensor, tgt: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # pred/tgt/w: [B,1,H,W]
    err = (pred - tgt) ** 2
    num = torch.mean(err * w)
    den = torch.mean(w) + 1e-8
    return num / den

def save_grid_png(path: Path, top: np.ndarray, bottom: np.ndarray, pad: int = 2):
    """
    Save 2-row strip image (grayscale) comparing top vs bottom.
    top/bottom: [K,H,W] float in [0,1]
    """
    from PIL import Image
    assert top.shape == bottom.shape
    K, H, W = top.shape
    canvas = np.zeros((2*H + pad, K*W + (K-1)*pad), dtype=np.uint8)
    def u8(x): return (np.clip(x,0,1)*255).astype(np.uint8)
    for i in range(K):
        x0 = i*(W+pad)
        canvas[0:H, x0:x0+W] = u8(top[i])
        canvas[H+pad:H+pad+H, x0:x0+W] = u8(bottom[i])
    Image.fromarray(canvas).save(str(path))


# -------------------------
# ALE env + dataset
# -------------------------
def make_atari_env(env_id: str, seed: int, frame_skip: int, repeat_action_prob: float,
                   screen_size: int, noop_max: int):
    import gymnasium as gym
    import ale_py
    gym.register_envs(ale_py)

    env = gym.make(env_id, frameskip=1, repeat_action_probability=repeat_action_prob)
    env = gym.wrappers.AtariPreprocessing(
        env,
        noop_max=noop_max,
        frame_skip=frame_skip,
        screen_size=screen_size,
        grayscale_obs=True,
        grayscale_newaxis=False,
        scale_obs=False,  # uint8
        terminal_on_life_loss=False,
    )
    env.reset(seed=seed)
    return env

def collect_dataset(env, collect_steps: int):
    obs0, _ = env.reset()
    obs0 = np.asarray(obs0, dtype=np.uint8)
    H, W = obs0.shape

    obs = np.zeros((collect_steps + 1, H, W), dtype=np.uint8)
    act = np.zeros((collect_steps,), dtype=np.int64)
    done = np.zeros((collect_steps,), dtype=np.bool_)

    obs[0] = obs0
    o = obs0
    for t in tqdm(range(collect_steps), desc="Collect"):
        a = env.action_space.sample()
        o2, _r, terminated, truncated, _info = env.step(a)
        d = bool(terminated or truncated)
        o2 = np.asarray(o2, dtype=np.uint8)

        act[t] = a
        done[t] = d
        obs[t+1] = o2

        if d:
            o, _ = env.reset()
            o = np.asarray(o, dtype=np.uint8)
            obs[t+1] = o
        else:
            o = o2

    return obs, act, done

def valid_starts_from_dones(done: np.ndarray, seq_len: int):
    N = done.shape[0]
    if N < seq_len:
        return np.array([], dtype=np.int64)
    d = done.astype(np.int32)
    cs = np.concatenate([[0], np.cumsum(d)])  # length N+1
    win = cs[seq_len:] - cs[:-seq_len]
    return np.where(win == 0)[0].astype(np.int64)


# -------------------------
# RSSM modules
# -------------------------
class ObsEncoder(nn.Module):
    def __init__(self, in_ch: int = 1, embed_dim: int = 1024):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, 32, 4, stride=2, padding=1), nn.ReLU(),   # 64->32
            nn.Conv2d(32, 64, 4, stride=2, padding=1), nn.ReLU(),      # 32->16
            nn.Conv2d(64, 128, 4, stride=2, padding=1), nn.ReLU(),     # 16->8
            nn.Conv2d(128, 256, 4, stride=2, padding=1), nn.ReLU(),    # 8->4
        )
        self.fc = nn.Linear(256 * 4 * 4, embed_dim)

    def forward(self, o: torch.Tensor) -> torch.Tensor:
        h = self.conv(o).reshape(o.shape[0], -1)
        return F.relu(self.fc(h))

class GaussianHead(nn.Module):
    def __init__(self, in_dim: int, out_dim: int, min_logstd: float = -5.0, max_logstd: float = 2.0):
        super().__init__()
        self.mu = nn.Linear(in_dim, out_dim)
        self.logstd = nn.Linear(in_dim, out_dim)
        self.min_logstd = min_logstd
        self.max_logstd = max_logstd

    def forward(self, x):
        mu = self.mu(x)
        logstd = torch.clamp(self.logstd(x), self.min_logstd, self.max_logstd)
        return mu, logstd

class RSSM(nn.Module):
    def __init__(self, action_dim: int, h_dim: int, z_dim: int, embed_dim: int):
        super().__init__()
        self.action_dim = action_dim
        self.h_dim = h_dim
        self.z_dim = z_dim
        self.gru = nn.GRUCell(z_dim + action_dim, h_dim)
        self.prior_head = GaussianHead(h_dim, z_dim)
        self.post_mlp = nn.Sequential(
            nn.Linear(h_dim + embed_dim, 512), nn.ReLU(),
            nn.Linear(512, 512), nn.ReLU(),
        )
        self.post_head = GaussianHead(512, z_dim)

    def init_state(self, batch: int, device: torch.device):
        h = torch.zeros((batch, self.h_dim), device=device)
        z = torch.zeros((batch, self.z_dim), device=device)
        return h, z

    def prior(self, h: torch.Tensor):
        return self.prior_head(h)

    def posterior(self, h: torch.Tensor, e: torch.Tensor):
        x = self.post_mlp(torch.cat([h, e], dim=-1))
        return self.post_head(x)

    def step(self, h: torch.Tensor, z: torch.Tensor, a_oh: torch.Tensor):
        return self.gru(torch.cat([z, a_oh], dim=-1), h)

class DecoderClean(nn.Module):
    def __init__(self, feat_dim: int):
        super().__init__()
        self.fc = nn.Linear(feat_dim, 256*4*4)
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1), nn.ReLU(),
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1), nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1), nn.ReLU(),
            nn.ConvTranspose2d(32, 16, 4, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(16, 1, 3, padding=1),
            nn.Sigmoid(),
        )
    def forward(self, feat: torch.Tensor) -> torch.Tensor:
        h = F.relu(self.fc(feat)).reshape(feat.shape[0], 256, 4, 4)
        return self.deconv(h)

class DecoderNoisy(nn.Module):
    def __init__(self, feat_dim: int):
        super().__init__()
        self.fc = nn.Linear(feat_dim, 256*4*4)
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1), nn.ReLU(),
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1), nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1), nn.ReLU(),
            nn.ConvTranspose2d(32, 16, 4, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(16, 1, 3, padding=1),
            nn.Sigmoid(),
        )
    def forward(self, feat: torch.Tensor) -> torch.Tensor:
        h = F.relu(self.fc(feat)).reshape(feat.shape[0], 256, 4, 4)
        return self.deconv(h)

class UPredictor(nn.Module):
    def __init__(self, embed_dim: int, u_dim: int):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, 512), nn.ReLU(),
            nn.Linear(512, 512), nn.ReLU(),
        )
        self.head = GaussianHead(512, u_dim)
    def forward(self, e: torch.Tensor):
        return self.head(self.mlp(e))


# -------------------------
# Hparams
# -------------------------
@dataclass
class HParams:
    screen_size: int = 64
    h_dim: int = 200
    z_dim: int = 32
    u_dim: int = 16
    embed_dim: int = 1024
    seq_len: int = 50
    batch: int = 32

    lr: float = 2e-4
    grad_clip: float = 100.0

    # weights
    beta_clean: float = 1.0
    beta_noisy: float = 0.5
    beta_kl_z: float = 1.0
    beta_kl_u: float = 0.2
    beta_inv: float = 0.5
    beta_overshoot: float = 1.0

    free_nats_z: float = 3.0
    free_nats_u: float = 1.0

    noise_std: float = 0.05

    # foreground weighting
    alpha_motion: float = 20.0
    motion_thr: float = 0.03
    alpha_bright: float = 5.0
    bright_thr: float = 0.7

    # overshoot
    overshoot_k: int = 10


@torch.no_grad()
def eval_rollout(enc_obs, rssm, dec_clean, obs_seq_u8, act_seq, device, action_dim: int, K: int = 16):
    obs01 = torch.from_numpy(obs_seq_u8[:K+1]).to(device).float() / 255.0
    obs01 = obs01.unsqueeze(1)  # [K+1,1,H,W]
    act = torch.from_numpy(act_seq[:K]).to(device).long()

    h, _z = rssm.init_state(batch=1, device=device)
    e0 = enc_obs(obs01[0:1])
    mu_q, _ = rssm.posterior(h, e0)
    z = mu_q

    preds = [dec_clean(torch.cat([h, z], dim=-1))[0,0].detach().cpu().numpy()]
    for t in range(K):
        h = rssm.step(h, z, onehot(act[t:t+1], action_dim))
        mu_p, _ = rssm.prior(h)
        z = mu_p
        preds.append(dec_clean(torch.cat([h, z], dim=-1))[0,0].detach().cpu().numpy())

    gt = obs01[:,0].detach().cpu().numpy()
    pr = np.stack(preds, axis=0)
    return gt, pr


def main():
    p = argparse.ArgumentParser()
    p.add_argument("--env_id", type=str, default="ALE/Pong-v5")
    p.add_argument("--seed", type=int, default=0)
    p.add_argument("--collect_steps", type=int, default=100_000)
    p.add_argument("--frame_skip", type=int, default=4)
    p.add_argument("--repeat_action_probability", type=float, default=0.0)
    p.add_argument("--screen_size", type=int, default=64)
    p.add_argument("--noop_max", type=int, default=0)

    p.add_argument("--seq_len", type=int, default=50)
    p.add_argument("--batch", type=int, default=32)
    p.add_argument("--train_steps", type=int, default=20_000)
    p.add_argument("--lr", type=float, default=2e-4)
    p.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
    p.add_argument("--outdir", type=str, default="outputs_rssm_causal_stochastic_v2")
    p.add_argument("--eval_every", type=int, default=2000)

    # v2 knobs
    p.add_argument("--overshoot_k", type=int, default=10)
    p.add_argument("--beta_overshoot", type=float, default=1.0)
    p.add_argument("--alpha_motion", type=float, default=20.0)
    p.add_argument("--motion_thr", type=float, default=0.03)
    p.add_argument("--alpha_bright", type=float, default=5.0)
    p.add_argument("--bright_thr", type=float, default=0.7)
    p.add_argument("--detach_z_in_noisy", type=int, default=1, help="1 detaches z from noisy decoder to force u to carry nuisance")
    args = p.parse_args()

    set_seed(args.seed)
    device = torch.device(args.device)
    outdir = Path(args.outdir)
    outdir.mkdir(parents=True, exist_ok=True)

    env = make_atari_env(
        env_id=args.env_id,
        seed=args.seed,
        frame_skip=args.frame_skip,
        repeat_action_prob=args.repeat_action_probability,
        screen_size=args.screen_size,
        noop_max=args.noop_max,
    )
    action_dim = env.action_space.n
    print("Env:", args.env_id, "| action_dim:", action_dim, "| obs:", env.observation_space)

    obs_u8, act, done = collect_dataset(env, args.collect_steps)
    starts = valid_starts_from_dones(done, args.seq_len)
    if len(starts) == 0:
        raise RuntimeError("No valid sequences found (seq_len too large for collected episodes).")

    hp = HParams(
        screen_size=args.screen_size,
        seq_len=args.seq_len,
        batch=args.batch,
        lr=args.lr,
        overshoot_k=args.overshoot_k,
        beta_overshoot=args.beta_overshoot,
        alpha_motion=args.alpha_motion,
        motion_thr=args.motion_thr,
        alpha_bright=args.alpha_bright,
        bright_thr=args.bright_thr,
    )

    enc_obs = ObsEncoder(in_ch=1, embed_dim=hp.embed_dim).to(device)
    rssm = RSSM(action_dim=action_dim, h_dim=hp.h_dim, z_dim=hp.z_dim, embed_dim=hp.embed_dim).to(device)
    u_pred = UPredictor(embed_dim=hp.embed_dim, u_dim=hp.u_dim).to(device)
    dec_clean = DecoderClean(feat_dim=hp.h_dim + hp.z_dim).to(device)
    dec_noisy = DecoderNoisy(feat_dim=hp.h_dim + hp.z_dim + hp.u_dim).to(device)

    opt = torch.optim.Adam(
        list(enc_obs.parameters()) + list(rssm.parameters()) + list(u_pred.parameters()) +
        list(dec_clean.parameters()) + list(dec_noisy.parameters()),
        lr=hp.lr
    )

    H, W = obs_u8.shape[1], obs_u8.shape[2]

    def sample_batch():
        idx = np.random.choice(starts, size=hp.batch, replace=True)
        obs_seq = np.stack([obs_u8[i:i+hp.seq_len+1] for i in idx], axis=0)   # [B,L+1,H,W]
        act_seq = np.stack([act[i:i+hp.seq_len] for i in idx], axis=0)        # [B,L]
        return obs_seq, act_seq

    # fixed eval slice
    eval_i = int(starts[len(starts)//2])
    eval_obs_seq = obs_u8[eval_i:eval_i+hp.seq_len+1]
    eval_act_seq = act[eval_i:eval_i+hp.seq_len]

    for step in range(1, args.train_steps + 1):
        obs_seq_u8, act_seq_np = sample_batch()

        obs_seq = torch.from_numpy(obs_seq_u8).to(device).float() / 255.0     # [B,L+1,H,W]
        obs_seq = obs_seq.unsqueeze(2)                                        # [B,L+1,1,H,W]
        act_seq = torch.from_numpy(act_seq_np).to(device).long()              # [B,L]

        B, Lp1 = obs_seq.shape[0], obs_seq.shape[1]
        L = Lp1 - 1

        # Foreground/motion weights computed from CLEAN sequence
        # delta_t = |o_t - o_{t-1}| (t>=1)
        delta = torch.zeros_like(obs_seq)
        delta[:, 1:] = torch.abs(obs_seq[:, 1:] - obs_seq[:, :-1])
        w = 1.0 \
            + hp.alpha_motion * (delta > hp.motion_thr).float() \
            + hp.alpha_bright * (obs_seq > hp.bright_thr).float()
        # w shape: [B,L+1,1,H,W]

        # Noisy augmented obs for posterior + noisy decoder target
        obs_flat = obs_seq.reshape(B*(L+1), 1, H, W)
        obs_noisy_flat = augment_obs(obs_flat, noise_std=hp.noise_std)
        obs_noisy = obs_noisy_flat.reshape(B, L+1, 1, H, W)

        # Embed (posterior uses noisy)
        e = enc_obs(obs_noisy_flat).reshape(B, L+1, -1)

        h, _z0 = rssm.init_state(B, device=device)

        recon_clean = 0.0
        recon_noisy = 0.0
        kl_z = 0.0
        kl_u = 0.0
        raw_klz = 0.0
        raw_klu = 0.0

        # invariance on posterior mean z at t=0 under two augmentations
        obs0 = obs_seq[:, 0]  # [B,1,H,W]
        e1 = enc_obs(augment_obs(obs0, noise_std=hp.noise_std))
        e2 = enc_obs(augment_obs(obs0, noise_std=hp.noise_std))
        mu1, _ = rssm.posterior(h, e1)
        mu2, _ = rssm.posterior(h, e2)
        inv_loss = F.mse_loss(mu1, mu2)

        # --- main unroll (posterior decoding) ---
        z = None
        for t in range(L+1):
            mu_p, logstd_p = rssm.prior(h)
            mu_q, logstd_q = rssm.posterior(h, e[:, t])
            z = reparam(mu_q, logstd_q)

            mu_u, logstd_u = u_pred(e[:, t])
            u = reparam(mu_u, logstd_u)

            # clean recon (weighted)
            pred_clean = dec_clean(torch.cat([h, z], dim=-1))
            tgt_clean = obs_seq[:, t]
            recon_clean = recon_clean + weighted_mse(pred_clean, tgt_clean, w[:, t])

            # noisy recon (optionally detach z so u must explain nuisance)
            z_for_noisy = z.detach() if args.detach_z_in_noisy == 1 else z
            pred_noisy = dec_noisy(torch.cat([h, z_for_noisy, u], dim=-1))
            tgt_noisy = obs_noisy[:, t]
            recon_noisy = recon_noisy + F.mse_loss(pred_noisy, tgt_noisy)

            # KLs
            klz_t_raw = kl_diag_gaussian(mu_q, logstd_q, mu_p, logstd_p)  # [B]
            klu_t_raw = kl_std_normal(mu_u, logstd_u)                     # [B]
            raw_klz = raw_klz + torch.mean(klz_t_raw)
            raw_klu = raw_klu + torch.mean(klu_t_raw)

            klz_t = torch.clamp(klz_t_raw, min=hp.free_nats_z)
            klu_t = torch.clamp(klu_t_raw, min=hp.free_nats_u)
            kl_z = kl_z + torch.mean(klz_t)
            kl_u = kl_u + torch.mean(klu_t)

            if t < L:
                h = rssm.step(h, z, onehot(act_seq[:, t], action_dim))

        recon_clean = recon_clean / (L+1)
        recon_noisy = recon_noisy / (L+1)
        kl_z = kl_z / (L+1)
        kl_u = kl_u / (L+1)
        raw_klz = raw_klz / (L+1)
        raw_klu = raw_klu / (L+1)

        # --- overshooting loss (prior rollout from t=0) ---
        overshoot = 0.0
        if hp.overshoot_k > 0:
            h_os, _ = rssm.init_state(B, device=device)
            mu_q0, logstd_q0 = rssm.posterior(h_os, e[:, 0])
            z_os = reparam(mu_q0, logstd_q0)
            # predict 1..K
            K = min(hp.overshoot_k, L)
            for k in range(1, K+1):
                h_os = rssm.step(h_os, z_os, onehot(act_seq[:, k-1], action_dim))
                mu_p, logstd_p = rssm.prior(h_os)
                z_os = reparam(mu_p, logstd_p)
                pred = dec_clean(torch.cat([h_os, z_os], dim=-1))
                tgt = obs_seq[:, k]
                overshoot = overshoot + weighted_mse(pred, tgt, w[:, k])
            overshoot = overshoot / max(1, K)

        loss = (
            hp.beta_clean * recon_clean +
            hp.beta_noisy * recon_noisy +
            hp.beta_kl_z * kl_z +
            hp.beta_kl_u * kl_u +
            hp.beta_inv * inv_loss +
            hp.beta_overshoot * overshoot
        )

        opt.zero_grad(set_to_none=True)
        loss.backward()
        nn.utils.clip_grad_norm_(
            list(enc_obs.parameters()) + list(rssm.parameters()) + list(u_pred.parameters()) +
            list(dec_clean.parameters()) + list(dec_noisy.parameters()),
            hp.grad_clip
        )
        opt.step()

        if step % 200 == 0:
            print(
                f"step {step:06d} | loss {loss.item():.4f} "
                f"| clean {recon_clean.item():.4f} | noisy {recon_noisy.item():.4f} "
                f"| overshoot {float(overshoot):.4f} "
                f"| klz {kl_z.item():.3f} (raw {raw_klz.item():.3f}) "
                f"| klu {kl_u.item():.3f} (raw {raw_klu.item():.3f}) "
                f"| inv {inv_loss.item():.4f}"
            )

        if step % args.eval_every == 0:
            enc_obs.eval(); rssm.eval(); dec_clean.eval()
            gt, pr = eval_rollout(enc_obs, rssm, dec_clean, eval_obs_seq, eval_act_seq, device, action_dim, K=16)
            save_grid_png(outdir / f"rollout_clean_gt_vs_pred_step{step}.png", gt, pr)
            print(f"[EVAL step {step}] saved rollout png -> {outdir}")
            enc_obs.train(); rssm.train(); dec_clean.train()

    print("Done. Outputs in:", outdir)


if __name__ == "__main__":
    main()


Writing /content/drive/MyDrive/Colab_Notebooks/tg_smn/world_models/ale_rssm_causal_stochastic_v2.py


### Test dreamer style model

In [None]:
# %cd $REPO_DIR

# OUTDIR = "/content/drive/MyDrive/Colab_Notebooks/tg_smn/outputs_rssm_causal_stochastic"

# !python world_models/ale_rssm_causal_stochastic.py \
#   --env_id "ALE/Pong-v5" \
#   --repeat_action_probability 0.0 \
#   --screen_size 64 \
#   --frame_skip 4 \
#   --collect_steps 20000 \
#   --train_steps 3000 \
#   --eval_every 1000 \
#   --outdir "$OUTDIR"


## Actually Run Experiments

Optional variants:

Breakout:

--env_id "ALE/Breakout-v5"


True transition stochasticity (sticky actions):

--repeat_action_probability 0.25


Sticky actions are the standard way ALE injects stochasticity.

In [None]:
%cd /content/drive/MyDrive/Colab_Notebooks/tg_smn

OUTDIR = "/content/drive/MyDrive/Colab_Notebooks/tg_smn/outputs_rssm_causal_stochastic_v2"

!python world_models/ale_rssm_causal_stochastic_v2.py \
  --env_id "ALE/Pong-v5" \
  --repeat_action_probability 0.0 \
  --screen_size 64 \
  --frame_skip 4 \
  --collect_steps 100000 \
  --train_steps 20000 \
  --eval_every 2000 \
  --overshoot_k 10 \
  --beta_overshoot 1.0 \
  --alpha_motion 25.0 \
  --motion_thr 0.03 \
  --alpha_bright 6.0 \
  --bright_thr 0.70 \
  --detach_z_in_noisy 1 \
  --outdir "$OUTDIR"


/content/drive/MyDrive/Colab_Notebooks/tg_smn
A.L.E: Arcade Learning Environment (version 0.11.2+ecc1138)
[Powered by Stella]
Env: ALE/Pong-v5 | action_dim: 6 | obs: Box(0, 255, (64, 64), uint8)
Collect: 100% 100000/100000 [01:49<00:00, 917.41it/s]
Consider using tensor.detach() first. (Triggered internally at /pytorch/torch/csrc/autograd/generated/python_variable_methods.cpp:836.)
  f"| overshoot {float(overshoot):.4f} "
step 000200 | loss 3.2112 | clean 0.0026 | noisy 0.0119 | overshoot 0.0026 | klz 3.000 (raw 0.888) | klu 1.000 (raw 0.226) | inv 0.0000
step 000400 | loss 3.2093 | clean 0.0017 | noisy 0.0117 | overshoot 0.0018 | klz 3.000 (raw 1.154) | klu 1.000 (raw 0.075) | inv 0.0000
step 000600 | loss 3.2089 | clean 0.0015 | noisy 0.0116 | overshoot 0.0016 | klz 3.000 (raw 0.627) | klu 1.000 (raw 0.013) | inv 0.0000
step 000800 | loss 3.2086 | clean 0.0015 | noisy 0.0112 | overshoot 0.0015 | klz 3.000 (raw 0.479) | klu 1.000 (raw 0.023) | inv 0.0000
step 001000 | loss 3.2084 | cl

## Visualize Results

In [1]:
from glob import glob
from PIL import Image

OUTDIR = "/content/drive/MyDrive/Colab_Notebooks/tg_smn/outputs_rssm_causal_stochastic"

rollouts = sorted(glob(f"{OUTDIR}/rollout_clean_gt_vs_pred_step*.png"))
samples  = sorted(glob(f"{OUTDIR}/nuisance_samples_step*.png"))

print("Latest rollout:", rollouts[-1] if rollouts else None)
print("Latest nuisance:", samples[-1] if samples else None)

if rollouts:
    display(Image.open(rollouts[-1]))
if samples:
    display(Image.open(samples[-1]))


Latest rollout: None
Latest nuisance: None
