A 26M-parameter JEPA plays a 2-player turn-based contact-push puzzle by planning in latent space — no physics simulator on the model's side at inference.
3-frame pixel context (224×224)
│
▼
ViT-Tiny encoder (trained from scratch, unfrozen, 192-dim)
│
▼
Projector MLP(192→2048→192) + BatchNorm1d
│
Action encoder Linear(frameskip·2 → 192)
│
▼
AR causal Transformer predictor
(AdaLN, 6 layers, 16 heads, mlp_dim=2048)
│
├─► predicted embedding ─► SIGReg loss
│ (Sketch Isotropic Gaussian Reg.)
│ weight 0.09
│
└─► State head MLP(192→256→256→8) ─► (p1xy, p2xy, T_xy, T_θ)
(DexWM joint, λ=10)
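The dataflow above can be sketched shape-by-shape. This is a minimal numpy mock, not the real modules — random matrices stand in for the trained weights, and the predictor is mocked as a token mean — but the tensor shapes follow the diagram:

```python
import numpy as np

B, H, D = 4, 3, 192          # batch, history frames, embed dim (ViT-Tiny)
FRAMESKIP = 5                # actions per turn

rng = np.random.default_rng(0)

# Encoder: 3 context frames -> one 192-dim latent per frame (mocked).
frame_latents = rng.normal(size=(B, H, D))

# Projector MLP 192 -> 2048 -> 192 (mocked with random weights + ReLU).
W1 = rng.normal(size=(D, 2048)) * 0.01
W2 = rng.normal(size=(2048, D)) * 0.01
projected = np.maximum(frame_latents @ W1, 0) @ W2       # (B, H, 192)

# Action encoder: Linear(frameskip*2 -> 192); actions are per-step (dx, dy).
actions = rng.normal(size=(B, FRAMESKIP * 2))
Wa = rng.normal(size=(FRAMESKIP * 2, D)) * 0.01
action_emb = actions @ Wa                                # (B, 192)

# Predictor consumes [context latents ; action embedding] and emits the
# next latent (mocked here as a mean over tokens).
tokens = np.concatenate([projected, action_emb[:, None, :]], axis=1)
pred = tokens.mean(axis=1)                               # (B, 192)

# State head MLP 192 -> 256 -> 256 -> 8 (8 outputs as in the diagram).
Ws1 = rng.normal(size=(D, 256)) * 0.01
Ws2 = rng.normal(size=(256, 256)) * 0.01
Ws3 = rng.normal(size=(256, 8)) * 0.01
state = np.maximum(np.maximum(pred @ Ws1, 0) @ Ws2, 0) @ Ws3
print(pred.shape, state.shape)   # (4, 192) (4, 8)
```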
Two agents (amber/cyan), a T-shaped block, and two goals on opposite sides of the board. You push the T toward your goal. The opponent pushes toward its own. 5 turns each, 5 environment steps per turn — whoever has the T closer to their goal at match-end wins.
The opponent is a JEPA trained pixel-to-embedding from scratch (LeWM recipe). Each turn it does CEM in pure embedding space:
- Encode the last 3 frames into latents.
- Sample 96 candidate action sequences from a Gaussian.
- Roll each one through the AR transformer.
- Decode the predicted final state via the joint head.
- Score on ‖agent − T‖ + 2·‖T − goal‖.
- Keep the top 15% as elites, refit the Gaussian, repeat 6 times.
- Execute the best plan via classical pymunk.
No simulator on the model's side. Pymunk only runs to render the chosen action; the reasoning is pure latent-space CEM.
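The planning loop above can be sketched on a toy stand-in (numpy only). Here `score` is a dummy quadratic in place of the real latent rollout plus state-head decode; the population size, elite fraction, and iteration count match the numbers listed:

```python
import numpy as np

HORIZON, DIM = 5, 2          # 5 env steps per turn, (dx, dy) thrust per step
POP, ELITE_FRAC, ITERS = 96, 0.15, 6

def score(plan, goal=np.array([1.0, 0.5])):
    # Stand-in for: roll latents through the AR predictor, decode the final
    # state via the joint head, then cost = ||agent - T|| + 2*||T - goal||.
    final = plan.reshape(HORIZON, DIM).sum(axis=0)   # toy "final position"
    return np.linalg.norm(final - goal)

def cem_plan(rng):
    mu = np.zeros(HORIZON * DIM)
    sigma = np.ones(HORIZON * DIM)
    n_elite = max(1, int(POP * ELITE_FRAC))
    for _ in range(ITERS):
        cands = rng.normal(mu, sigma, size=(POP, HORIZON * DIM))
        costs = np.array([score(c) for c in cands])
        elites = cands[np.argsort(costs)[:n_elite]]   # keep top 15%
        mu, sigma = elites.mean(axis=0), elites.std(axis=0) + 1e-6
    return mu   # best plan -> handed to pymunk for execution

rng = np.random.default_rng(0)
plan = cem_plan(rng)
print(score(plan))   # converges near 0 on this toy cost
```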
The whole inference loop runs on CPU at interactive rates (~2 s per turn). Live demo: https://sotoalt.dev/experiments/relay.html.
git clone https://github.com/SotoAlt/relay.git
cd relay
pip install torch torchvision timm einops pymunk pygame opencv-python-headless \
shapely fastapi 'uvicorn[standard]' numpy pillow h5py
# Download the trained checkpoint from HuggingFace
huggingface-cli download SotoAlt/relay relay_stage1_v9_trackE_ep02_uhead.pt \
--local-dir checkpoints/
# Run the server
PYTHONPATH=. python -m world_model.infer_relay \
--port 8800 --device cpu \
--checkpoint-v9 checkpoints/relay_stage1_v9_trackE_ep02_uhead.pt \
--model-execute-jepa
# Open http://localhost:8800/ in your browser

The full recipe (data gen plus 100 epochs of joint Phase A on one GPU) runs on a RunPod RTX 4090 (~10 hr, ~$6) or an RTX 5090 community instance (~3 hr, ~$2.50).
# 1. Generate training data — 12,500 episodes × 40 steps × 5 regimes (~25 min CPU)
# contact_push 40%, approach_no_push 25%, null_thrust 15%,
# near_t 10%, far_from_t 10%. 500K frames at 224×224 RGB, LZF H5.
python -m scripts.gen_pymunk_v7_episodes \
--stage 1 --episodes 12500 --steps 40 \
--output data/relay_stage1.h5
# 2. Joint Phase A — LeWM losses + DexWM-style state head, λ=10
python -m scripts.train_relay_v9_lewm \
--h5 data/relay_stage1.h5 \
--output checkpoints/relay_v9_joint.pt \
--epochs 100 --batch 128 --frameskip 5 --history 3 \
--lambda-sigreg 0.09 --lambda-state 10 \
    --num-workers 4 --device cuda

--num-workers 4 is the right setting on Ada and earlier GPUs (RTX 4090, A40, …). Use --num-workers 0 only on Blackwell (RTX 5090 community), where SDPA flash-attn races multi-process DataLoaders.
In a 2-player turn-based game, the model has to learn which ball belongs to which thrust. That depends on how training data is structured.
The first version of relay alternated which ball was acting on every single frame. But at gameplay time, each player acts for a full 5-frame turn. The temporal structure didn't match — and the model never cleanly attributed actions to balls. Cross-agent leakage on every prediction.
The fix is one line in the data generator:
who = 'p1' if ((t // 5) % 2 == 0) else 'p2'  # 5-frame chunks, not alternating

Plus retraining the state-head decoder on uniformly random scenes (not rollouts), so it sees the full latent geometry the planner queries rather than just the spawn-region prior. Position decode error dropped from ~45 px to ~10 px on agents, and average planning error dropped from 54 px to 47 px.
Lesson: in turn-based games, training data has to match the temporal structure of gameplay.
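The difference between the two attribution schemes is easy to see side by side. A minimal illustration (pure Python; `actor_per_frame` and `actor_per_turn` are hypothetical names for the old and fixed data-gen logic):

```python
def actor_per_frame(t):
    # Old data gen: ownership flips every frame -- mismatched with gameplay.
    return 'p1' if (t % 2 == 0) else 'p2'

def actor_per_turn(t, turn_len=5):
    # Fixed data gen: ownership flips every 5-frame chunk, like real turns.
    return 'p1' if ((t // turn_len) % 2 == 0) else 'p2'

print([actor_per_frame(t) for t in range(10)])
# alternates every frame: p1, p2, p1, p2, ...
print([actor_per_turn(t) for t in range(10)])
# five p1 frames, then five p2 frames -- matches the 5-step turn structure
```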
See docs/JOURNAL.md for the full story — including v0–v9 evolution,
the failed CEM cost-shaping experiments, the SDPA-on-Blackwell ghost
debug, and the joint state-head retrain that closed the calibration
gap.
Phase A joint training, λ_state=10, 5-epoch fine-tune validation:
| metric | pre-joint v9 | post-joint v9 (5 ep ft) |
|---|---|---|
| val_pred | 0.0071 | 0.036 |
| val_state | — (Phase B: 0.032) | 0.007 (4.6× better) |
| pymunk calibration probe agent_mae | ~113 px | ~54 px (−52%) |
The trade is raw embedding accuracy for action-accurate embeddings: the state head's gradient pulls the predictor away from the distribution-mean optimum that vanilla LeWM converges to.
Match-play (model-execute-jepa, 12 × 4 opponent policies):
| opponent | win % | net progress |
|---|---|---|
| random | 92% | +126 |
| passive | 100% | +149 |
| chase_t | 0% | -67 |
| adversarial | 25% | -10 |
Random and passive opponents are dominated. Against chase_t the model still loses the push war: JEPA's smaller predicted pushes can't keep up with pymunk's real ones. Against the adversarial policy it plays roughly even.
The model itself follows the LeWorldModel paper byte-for-byte:
- ViT-Tiny encoder (timm vit_tiny_patch14_224, trained from scratch, unfrozen)
- AR causal Transformer predictor with AdaLN (6 layers, 16 heads)
- Two-term loss: next-embedding MSE + SIGReg, the regularizer that prevents embedding collapse without the usual stabilization knobs (stop-grad, EMA targets, warmups, prediction heads, masking ratios, auxiliary decoders); all of that collapses into one weight, 0.09.
On top: a DexWM-style joint state head (small MLP 192→256→256→8 GELU) trained alongside the predictor — its gradient flows back through the projector and into the encoder, forcing them to preserve task-relevant info. λ=10, not the paper's 100 — our absolute loss scales differ; λ=100 crushes the predictor.
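The loss composition can be sketched as follows. Note the SIGReg term here is a simplified stand-in — a plain moment-matching penalty toward a zero-mean isotropic Gaussian — not the paper's sketched statistical test; `joint_loss` and the helper names are illustrative, not the repo's API:

```python
import numpy as np

LAMBDA_SIGREG, LAMBDA_STATE = 0.09, 10.0   # weights from the recipe above

def isotropic_gaussian_penalty(z):
    # Simplified stand-in for SIGReg: push the embedding batch toward zero
    # mean and unit isotropic covariance (the real SIGReg sketches 1-D tests).
    mean_term = (z.mean(axis=0) ** 2).mean()
    cov = np.cov(z, rowvar=False)
    cov_term = ((cov - np.eye(z.shape[1])) ** 2).mean()
    return mean_term + cov_term

def joint_loss(pred_emb, target_emb, pred_state, target_state):
    pred_mse = ((pred_emb - target_emb) ** 2).mean()       # next-embedding MSE
    sigreg = isotropic_gaussian_penalty(pred_emb)          # weight 0.09
    state_mse = ((pred_state - target_state) ** 2).mean()  # DexWM head, lambda=10
    return pred_mse + LAMBDA_SIGREG * sigreg + LAMBDA_STATE * state_mse

rng = np.random.default_rng(0)
z, zt = rng.normal(size=(64, 192)), rng.normal(size=(64, 192))
s, st = rng.normal(size=(64, 8)), rng.normal(size=(64, 8))
print(joint_loss(z, zt, s, st))
```

In the real training loop the state-head gradient flows back through the projector into the encoder, which is what forces task-relevant information to survive in the latents.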
world_model/
infer_relay.py FastAPI + WebSocket inference server
relay_planner_v9.py v9 CEM planner (latent space)
relay_state_clip.py Iterative agent-T overlap fix
envs/relay.py Pymunk env (execution only, never used by planner)
envs/relay_stages.py Stage geometry + start positions
scripts/
train_relay_v9_lewm.py Joint Phase A (LeWM + DexWM head, the canonical recipe)
client/relay/
index.html Browser UI (canvas + side panel + telemetry)
main.js WebSocket client + CEM dream-field animation
docs/
JOURNAL.md Full research journal — v0 → v9, fixes and failures
- Live demo: https://sotoalt.dev/experiments/relay.html
- Hugging Face model: https://huggingface.co/SotoAlt/relay
- Production deploy artifacts: SotoAlt/relay-deploy
- LeWorldModel paper: arxiv 2603.19312 (Maes, Le Lidec, Scieur, LeCun, Balestriero, 2026)
- DexWM paper: arxiv 2512.13644
MIT.