In [1]:
# stwm_compression_template.py
from __future__ import annotations
import math
from dataclasses import dataclass
from typing import Optional, Tuple, Dict

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from einops import rearrange


In [2]:

# -------------------------
#  Utilities
# -------------------------

def exists(x):
    """Return True unless `x` is None (falsy values like 0, '' or [] still count as existing)."""
    if x is None:
        return False
    return True

class QKNorm(nn.Module):
    """
    Parameter-free query/key normalization, applied per vector over the last dim.

    Each q and k vector is rescaled to root-mean-square 1 over its last dimension,
    i.e. to L2 norm sqrt(d). The attention in this file multiplies q.k by 1/sqrt(d),
    so the resulting logits lie in [-sqrt(d), sqrt(d)].

    Why not unit L2 norm: with ||q|| = ||k|| = 1 and the 1/sqrt(d) scale, every
    logit is confined to [-1/sqrt(d), 1/sqrt(d)] (about +/-0.125 for d = 64), so
    softmax stays close to uniform and attention can never become sharp.
    (Henry et al., 2020 pair unit-norm q/k with a learnable gain instead.)
    """
    def __init__(self, eps: float = 1e-6):
        super().__init__()
        self.eps = eps

    def _rms_normalize(self, t: torch.Tensor) -> torch.Tensor:
        # Statistics in float32 so half-precision inputs do not overflow in pow(2).
        t32 = t.float()
        inv_rms = torch.rsqrt(t32.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return (t32 * inv_rms).to(t.dtype)

    def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        q, k: tensors whose last dim is the per-head feature dim
              (MultiheadSelfAttention passes (B, heads, N, d)).
        returns: (q_normed, k_normed) with the same shapes and dtypes.
        """
        return self._rms_normalize(q), self._rms_normalize(k)


In [3]:

# Optional xFormers memory-efficient attention.
# _HAS_XFORMERS records whether `memory_efficient_attention` could be imported.
# Any Exception (not only ImportError) leaves it False, so a broken or
# mismatched install falls back to the pure-torch attention path.
_HAS_XFORMERS = False
try:
    from xformers.ops import memory_efficient_attention
except Exception:
    pass  # best-effort optional dependency: keep _HAS_XFORMERS = False
else:
    _HAS_XFORMERS = True


In [4]:

class MultiheadSelfAttention(nn.Module):
    """
    Minimal multi-head self-attention with an optional xFormers kernel.

    - Non-causal calls on CUDA tensors use `memory_efficient_attention` when
      xFormers was importable; all other calls (causal=True, CPU tensors, no
      xFormers) use the explicit torch implementation in `_torch_attention`.
    - causal=True masks keys at positions j > i (used for temporal attention).
    - The output projection and dropout are applied on every path.
    """
    def __init__(self, dim: int, heads: int, qk_norm: bool = True, dropout: float = 0.0):
        super().__init__()
        assert dim % heads == 0, f"dim ({dim}) must be divisible by heads ({heads})"
        self.dim = dim
        self.heads = heads
        self.d = dim // heads

        self.to_qkv = nn.Linear(dim, dim * 3, bias=False)
        self.proj = nn.Linear(dim, dim, bias=False)
        self.drop = nn.Dropout(dropout)
        self.qkn = QKNorm() if qk_norm else None

    def forward(self, x: torch.Tensor, causal: bool = False) -> torch.Tensor:
        """
        x: (B, N, C) with C == dim.
        causal: if True, position i may only attend to positions <= i.
        returns: (B, N, C)
        """
        q, k, v = self.to_qkv(x).chunk(3, dim=-1)  # each (B, N, C)
        # (B, heads, N, d)
        q, k, v = (rearrange(t, "b n (h d) -> b h n d", h=self.heads) for t in (q, k, v))

        if self.qkn is not None:
            q, k = self.qkn(q, k)

        # No causal bias is built for xFormers, and its fused kernels are only used on
        # CUDA tensors here; causal (temporal, short) sequences take the torch path.
        if _HAS_XFORMERS and not causal and x.is_cuda:
            # memory_efficient_attention takes (B, N, H, d)
            q_, k_, v_ = (rearrange(t, "b h n d -> b n h d") for t in (q, k, v))
            out = memory_efficient_attention(q_, k_, v_)  # (B, N, H, d)
            out = rearrange(out, "b n h d -> b n (h d)")
        else:
            out = self._torch_attention(x, q, k, v, causal=causal)

        # Previously the xFormers+causal case returned before these two layers,
        # silently dropping the output projection/dropout in that environment.
        out = self.proj(out)
        return self.drop(out)

    def _torch_attention(self, x, q, k, v, causal: bool) -> torch.Tensor:
        """
        Reference scaled dot-product attention.
        q, k, v: (B, heads, N, d); `x` is only used for its device.
        returns: (B, N, heads * d)
        """
        n, d = q.shape[-2], q.shape[-1]
        scale = 1.0 / math.sqrt(d)

        scores = torch.einsum("b h i d, b h j d -> b h i j", q, k) * scale  # (B, H, N, N)

        if causal:
            # True strictly above the diagonal == key lies in the future -> excluded.
            # The diagonal is always kept, so no row is fully masked.
            future = torch.triu(torch.ones(n, n, device=x.device, dtype=torch.bool), diagonal=1)
            scores = scores.masked_fill(future, float("-inf"))

        attn = torch.softmax(scores, dim=-1)
        out = torch.einsum("b h i j, b h j d -> b h i d", attn, v)
        return rearrange(out, "b h n d -> b n (h d)")

class FeedForward(nn.Module):
    """
    Position-wise MLP applied independently to every token:
    Linear(dim -> mult*dim) -> GELU -> Dropout -> Linear(mult*dim -> dim) -> Dropout.
    """
    def __init__(self, dim: int, mult: int = 4, dropout: float = 0.0):
        super().__init__()
        hidden_dim = mult * dim
        layers = [
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        ]
        # Kept as a single Sequential named `net` so parameter names stay net.0 / net.3.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        out = self.net(x)
        return out


In [5]:

# -------------------------
#  ST-Transformer Block
# -------------------------

class STBlock(nn.Module):
    """
    Factorized space-time transformer block with pre-LayerNorm residuals:

        x <- x + SpatialAttn(LN(x))           # within one frame, over its HW tokens
        x <- x + TemporalAttn(LN(x), causal)  # per spatial slot, over the T frames
        x <- x + FFN(LN(x))
    """
    def __init__(self, dim: int, heads: int, dropout: float = 0.0, qk_norm: bool = True):
        super().__init__()
        self.ln_s, self.ln_t, self.ln_f = (nn.LayerNorm(dim) for _ in range(3))

        def make_attention():
            return MultiheadSelfAttention(dim, heads, qk_norm=qk_norm, dropout=dropout)

        self.attn_s = make_attention()
        self.attn_t = make_attention()
        self.ff = FeedForward(dim, mult=4, dropout=dropout)

    def forward(self, x: torch.Tensor, T: int, HW: int) -> torch.Tensor:
        """
        x: (B, T*HW, C), frame-major (all HW tokens of frame 0, then frame 1, ...).
        returns: tensor of the same shape.
        """
        batch, n_tokens, _ = x.shape
        assert n_tokens == T * HW

        # Spatial mixing: fold time into the batch -> (B*T, HW, C), full attention.
        frames = rearrange(x, "b (t p) c -> (b t) p c", t=T, p=HW)
        frames = frames + self.attn_s(self.ln_s(frames), causal=False)
        x = rearrange(frames, "(b t) p c -> b (t p) c", b=batch, t=T, p=HW)

        # Temporal mixing: fold space into the batch -> (B*HW, T, C), causal over t.
        tracks = rearrange(x, "b (t p) c -> (b p) t c", t=T, p=HW)
        tracks = tracks + self.attn_t(self.ln_t(tracks), causal=True)
        x = rearrange(tracks, "(b p) t c -> b (t p) c", b=batch, p=HW, t=T)

        # Position-wise feed-forward.
        return x + self.ff(self.ln_f(x))

# -------------------------
#  State embedding (robot states)
# -------------------------

class StateEncoder(nn.Module):
    """
    Encode a robot-state sequence into one additive conditioning vector per frame.

    Pipeline (following the Revontuli description: MLP + Conv1d + additive embedding):
      per-step MLP -> two Conv1d(kernel 3, padding 1) layers over time
      -> linear resampling of the time axis from L to T (only when L != T)
      -> Linear projection to model_dim.

    s: (B, L, state_dim)   (L = 64 in this template)
    returns: (B, T, model_dim)
    """
    def __init__(self, state_dim: int, model_dim: int, conv_channels: int = 256):
        super().__init__()
        width = conv_channels
        self.mlp = nn.Sequential(
            nn.Linear(state_dim, width), nn.GELU(),
            nn.Linear(width, width), nn.GELU(),
        )
        self.conv = nn.Sequential(
            nn.Conv1d(width, width, kernel_size=3, padding=1), nn.GELU(),
            nn.Conv1d(width, width, kernel_size=3, padding=1), nn.GELU(),
        )
        self.to_model = nn.Linear(width, model_dim)

    def forward(self, s: torch.Tensor, T: int) -> torch.Tensor:
        _, seq_len, _ = s.shape  # also enforces a 3-D input
        # Conv1d and F.interpolate both operate channels-first: (B, C, L).
        feats = self.conv(self.mlp(s).transpose(1, 2))
        if seq_len != T:
            feats = F.interpolate(feats, size=T, mode="linear", align_corners=False)
        # Back to channels-last for the Linear projection -> (B, T, model_dim).
        return self.to_model(feats.transpose(1, 2))


In [6]:

# -------------------------
#  Main model
# -------------------------

class RevontuliCompressionModel(nn.Module):
    """
    Token dynamics model over a window of T_total = n_past + n_future frames,
    each frame an H x W grid of discrete tokens.

      Input:  past tokens (B, n_past, H, W) + robot states (B, L, state_dim)
      Target: future tokens (B, n_future, H, W)

    Prediction alignment: `forward_teacher_forcing` returns logits in which
    logits[:, t] is decoded from the hidden state of frame t-1. Spatial attention
    stays within a frame and temporal attention is causal, so that hidden state
    depends only on frames 0..t-1: a frame's own tokens never feed its own
    prediction. Train with teacher forcing over all frames and take the CE on
    the future frames only (logits[:, n_past:] vs z_all[:, n_past:]).
    """
    def __init__(
        self,
        vocab_size: int,
        model_dim: int = 512,
        n_layers: int = 24,
        n_heads: int = 8,
        dropout: float = 0.1,
        state_dim: int = 25,
        H: int = 32,
        W: int = 32,
        n_past: int = 3,
        n_future: int = 3,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.model_dim = model_dim
        self.H = H
        self.W = W
        self.HW = H * W
        self.n_past = n_past
        self.n_future = n_future
        self.T_total = n_past + n_future  # 6 with the defaults

        self.tok_emb = nn.Embedding(vocab_size, model_dim)
        # Learned spatial (per grid cell) and temporal (per frame) positions,
        # laid out to broadcast against a (B, T, HW, C) token tensor.
        self.pos_spatial = nn.Parameter(torch.zeros(1, 1, self.HW, model_dim))
        self.pos_temporal = nn.Parameter(torch.zeros(1, self.T_total, 1, model_dim))

        self.state_enc = StateEncoder(state_dim=state_dim, model_dim=model_dim)

        self.blocks = nn.ModuleList([
            STBlock(model_dim, n_heads, dropout=dropout, qk_norm=True)
            for _ in range(n_layers)
        ])
        self.ln_out = nn.LayerNorm(model_dim)

        # Output projection tied to the input embedding (bias-free, so a zero
        # hidden state maps to all-zero logits).
        self.to_logits = nn.Linear(model_dim, vocab_size, bias=False)
        self.to_logits.weight = self.tok_emb.weight  # embedding tying

        # nn.Embedding defaults to N(0, 1). With tied output weights and a
        # LayerNorm'd hidden state, that gives initial logits with std ~ sqrt(model_dim);
        # use the usual small init for the shared matrix instead.
        nn.init.normal_(self.tok_emb.weight, std=0.02)
        nn.init.normal_(self.pos_spatial, std=0.02)
        nn.init.normal_(self.pos_temporal, std=0.02)

    def forward_teacher_forcing(self, z_all: torch.Tensor, s: torch.Tensor) -> torch.Tensor:
        """
        Teacher-forcing forward over all T_total frames.

        z_all: (B, T_total, H, W) integer tokens (past + future ground truth)
        s:     (B, L, state_dim) robot states
        returns logits: (B, T_total, H, W, vocab), where logits[:, t] predicts frame t
                from frames < t. Frame 0 has no predecessor and gets all-zero logits.
        """
        B, T, H, W = z_all.shape
        assert T == self.T_total and H == self.H and W == self.W

        # Flatten the grid first so pos_temporal (1, T, 1, C) and
        # pos_spatial (1, 1, HW, C) broadcast against (B, T, HW, C).
        x = rearrange(self.tok_emb(z_all), "b t h w c -> b t (h w) c")
        x = x + self.pos_temporal[:, :T] + self.pos_spatial

        # State conditioning: one additive vector per frame, shared by all cells.
        cond = self.state_enc(s, T=T)  # (B, T, C)
        x = x + cond[:, :, None, :]

        # (B, T*HW, C), frame-major
        x = rearrange(x, "b t p c -> b (t p) c")
        for blk in self.blocks:
            x = blk(x, T=T, HW=self.HW)
        x = self.ln_out(x)

        x = rearrange(x, "b (t p) c -> b t p c", t=T, p=self.HW)
        # Shift by one frame: frame t is decoded from the hidden state of frame t-1,
        # so a frame's own (teacher-forced) tokens cannot leak into its prediction.
        # Shifting the hidden states (C wide) is cheaper than shifting logits (V wide).
        h_prev = torch.cat([x.new_zeros(B, 1, self.HW, self.model_dim), x[:, :-1]], dim=1)
        logits = self.to_logits(h_prev)  # (B, T, HW, vocab)
        return rearrange(logits, "b t (h w) v -> b t h w v", h=self.H, w=self.W)

    @torch.no_grad()
    def generate_greedy(self, z_past: torch.Tensor, s: torch.Tensor) -> torch.Tensor:
        """
        Greedy autoregressive generation of the n_future frames, producing a full
        H x W grid per step.

        z_past: (B, n_past, H, W) integer tokens
        s:      robot states, passed unchanged to forward_teacher_forcing
        returns z_hat_future: (B, n_future, H, W)
        """
        B, Tp, H, W = z_past.shape
        assert Tp == self.n_past and H == self.H and W == self.W

        z_ctx = z_past.clone()
        for _ in range(self.n_future):
            n_known = z_ctx.shape[1]
            # Pad the window to T_total with token-0 placeholders at indices >= n_known.
            # logits[:, n_known] is decoded from frame n_known-1, which (causal time,
            # per-frame space) only sees known frames, so placeholders cannot affect it.
            z_all = z_ctx.new_zeros(B, self.T_total, H, W)
            z_all[:, :n_known] = z_ctx

            logits = self.forward_teacher_forcing(z_all, s)  # (B, T_total, H, W, V)
            next_tokens = logits[:, n_known].argmax(dim=-1)  # (B, H, W)

            z_ctx = torch.cat([z_ctx, next_tokens[:, None]], dim=1)

        return z_ctx[:, self.n_past:self.n_past + self.n_future]


In [11]:

# -------------------------
#  Dataset Template
# -------------------------
import os, sys
sys.path.append(os.path.abspath('..'))
from src.dataset import RawTokenDataset


In [None]:

# -------------------------
#  Train / Eval loops
# -------------------------

@dataclass
class TrainConfig:
    """
    Settings for `train_one_run`.

    train_one_run reads device, lr, weight_decay, epochs and log_every.
    batch_size and num_workers are not read by train_one_run (it receives
    ready-made DataLoaders); they only matter if the caller uses them to build loaders.
    """
    vocab_size: int  # token vocabulary size (required); not read by train_one_run
    lr: float = 8e-4  # peak AdamW learning rate, reached at the end of linear warmup
    weight_decay: float = 0.05  # AdamW weight decay, applied to every model parameter
    epochs: int = 10  # passes over the training loader
    batch_size: int = 8
    num_workers: int = 2
    device: str = "cuda" if torch.cuda.is_available() else "cpu"  # evaluated once, when the class is defined
    log_every: int = 50  # print loss / lr every this many optimizer steps

def compute_loss_ce_future(logits: torch.Tensor, z_all: torch.Tensor, n_past: int = 3) -> torch.Tensor:
    """
    Mean token cross-entropy over the future frames only.

    logits: (B, T, H, W, V) -- logits[:, t] is taken as the prediction for frame t
    z_all:  (B, T, H, W)    -- integer target token ids
    n_past: number of leading context frames excluded from the loss;
            frames n_past .. T-1 are scored.
    returns: scalar loss tensor (mean over every scored token)

    Raises AssertionError if shapes disagree or no future frame remains.
    """
    B, T, H, W, V = logits.shape
    assert z_all.shape == (B, T, H, W), f"z_all shape {tuple(z_all.shape)} != {(B, T, H, W)}"
    assert 0 <= n_past < T, f"n_past={n_past} leaves no future frame to score (T={T})"

    logits_f = logits[:, n_past:]   # (B, T - n_past, H, W, V)
    target_f = z_all[:, n_past:]    # (B, T - n_past, H, W)

    return F.cross_entropy(
        logits_f.reshape(-1, V),
        target_f.reshape(-1),
        reduction="mean",
    )

@torch.no_grad()
def eval_autoreg_ce(model: RevontuliCompressionModel, loader: DataLoader, device: str) -> float:
    """
    Sanity-check metric: greedy autoregressive TOKEN ACCURACY (despite the name, not CE).

    For each (z_all, s) batch, the first model.n_past frames are fed to
    model.generate_greedy, and the generated frames are compared token-by-token
    with the next model.n_future ground-truth frames. A leaderboard-style CE
    would need the logits under an AR rollout, which this does not compute.

    Side effect: puts the model in eval mode (callers re-enable train mode).
    returns: fraction of correctly predicted tokens in [0, 1]; NaN if the
             loader yields no tokens.
    """
    model.eval()
    n_past, n_future = model.n_past, model.n_future
    total = 0
    correct = 0

    for z_all, s in loader:
        z_all = z_all.to(device)
        s = s.to(device)
        z_gt_f = z_all[:, n_past:n_past + n_future]

        z_hat_f = model.generate_greedy(z_all[:, :n_past], s)  # (B, n_future, H, W)
        total += z_gt_f.numel()
        correct += (z_hat_f == z_gt_f).sum().item()

    return correct / total if total else float("nan")

def train_one_run(model, train_loader, val_loader, cfg: TrainConfig):
    """
    Train `model` with AdamW and a linear-warmup / linear-decay LR schedule.

    Each step: teacher-forced forward, future-frame CE (compute_loss_ce_future),
    backward, grad-norm clipping at 1.0, optimizer step, scheduler step.
    After every epoch, greedy autoregressive token accuracy on `val_loader` is printed.

    model: needs forward_teacher_forcing, n_past and generate_greedy
           (e.g. RevontuliCompressionModel)
    train_loader / val_loader: iterables of (z_all, s) batches
    cfg: uses device, lr, weight_decay, epochs, log_every
    Returns None; progress is printed.
    """
    model.to(cfg.device)

    opt = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay, betas=(0.9, 0.95))

    total_steps = cfg.epochs * len(train_loader)
    # ~5% warmup clamped to [100, 2000] steps, additionally capped at 10% of the run.
    # Without the cap, runs of <= 100 steps (e.g. the dummy demo) never finish warming
    # up and never reach the decay phase. Runs of >= 1000 steps are unaffected.
    warmup_steps = min(2000, max(100, total_steps // 20), max(1, total_steps // 10))

    def lr_lambda(step):
        if step < warmup_steps:
            return step / warmup_steps  # warmup_steps >= 1 by construction
        # linear decay from 1.0 at the end of warmup to 0.0 at total_steps
        return max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))

    sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda)

    step = 0
    for epoch in range(cfg.epochs):
        model.train()
        for z_all, s in train_loader:
            z_all = z_all.to(cfg.device)
            s = s.to(cfg.device)

            logits = model.forward_teacher_forcing(z_all, s)
            loss = compute_loss_ce_future(logits, z_all, n_past=model.n_past)

            opt.zero_grad(set_to_none=True)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
            sched.step()

            if step % cfg.log_every == 0:
                print(f"[step {step:07d}] loss={loss.item():.4f} lr={sched.get_last_lr()[0]:.2e}")
            step += 1

        # quick sanity eval (also switches the model to eval mode; reset at next epoch)
        acc = eval_autoreg_ce(model, val_loader, cfg.device)
        print(f"[epoch {epoch:03d}] AR token-acc={acc:.4f}")


In [None]:

# -------------------------
#  Example usage (dummy)
# -------------------------

def main():
    """
    End-to-end smoke run on random dummy data (replace with real token/state loading).

    Builds N random 6-frame 32x32 token windows with matching (64, 25) state
    sequences, splits them 80/20, and trains a 6-layer model for 2 epochs.
    """
    torch.manual_seed(0)  # reproducible dummy data, split and initialization

    # Dummy data (replace with your real token/state loading)
    N = 64
    vocab = 4096  # <-- set to match the Cosmos tokenizer vocabulary size
    z = torch.randint(0, vocab, (N, 6, 32, 32), dtype=torch.long)
    s = torch.randn(N, 64, 25)

    # TensorDataset yields (z_i, s_i) pairs, matching the `for z_all, s in loader` loops.
    # (The previously referenced TokenStateDataset is not defined anywhere in this notebook.)
    ds = torch.utils.data.TensorDataset(z, s)
    n_train = int(N * 0.8)
    train_ds, val_ds = torch.utils.data.random_split(ds, [n_train, N - n_train])

    cfg = TrainConfig(vocab_size=vocab, epochs=2, batch_size=4, num_workers=0)
    pin = cfg.device.startswith("cuda")  # pinned memory only helps host -> GPU copies

    train_loader = DataLoader(train_ds, batch_size=cfg.batch_size, shuffle=True,
                              num_workers=cfg.num_workers, pin_memory=pin)
    val_loader = DataLoader(val_ds, batch_size=cfg.batch_size, shuffle=False,
                            num_workers=cfg.num_workers, pin_memory=pin)

    model = RevontuliCompressionModel(
        vocab_size=vocab,
        model_dim=512,
        n_layers=6,     # start small (the full configuration uses 24)
        n_heads=8,
        dropout=0.1,
        state_dim=25,
        H=32, W=32,
        n_past=3,
        n_future=3,
    )

    train_one_run(model, train_loader, val_loader, cfg)

In [None]:
# Run the dummy end-to-end demo: random data, a 6-layer model, 2 training epochs.
main()