# Day 4 — JointDiT skeleton (assembly + smoke)

**Goal:** wire up the minimal JointDiT (Input → N×JointBlocks → Output), confirm end-to-end shapes, and validate the three modes (`full`, `iso_v`, `iso_a`).

**What this notebook does**
1. Sets the working directory to the project root and prints environment details (Python, platform, torch, GPU)
2. Runs the Day-4 smoke (`scripts/smoke/day04_joint_smoke.py`)
3. (Optional) Direct import smoke: instantiates `JointDiT` and forwards dummy latents
4. Records a tiny smoke snapshot (`notebooks/day_04/smoke_log.json`) for traceability

> If you hit a dtype mismatch (`Half vs Float`), see the *Troubleshooting* cell below — we pin the model to fp16 for the smoke to match fp16 inputs.

In [None]:
import json
import os
import platform
import subprocess
import sys
import time
from pathlib import Path

ROOT = Path('/workspace/jointdit')
assert ROOT.exists(), f"Project root not found: {ROOT}"
os.chdir(ROOT)
print('cwd =', Path.cwd())
print('python =', sys.version)
print('platform =', platform.platform())

try:
    import torch
    print('torch =', torch.__version__, '| cuda =', torch.cuda.is_available())
    if torch.cuda.is_available():
        print('gpu =', torch.cuda.get_device_name(0))
except Exception as e:
    print('[warn] torch not available:', e)

## Run the Day-4 smoke script
Runs the same command that `make smoke-day04` invokes.

In [None]:
env = os.environ.copy()
env['PYTHONPATH'] = str(ROOT)
cmd = [sys.executable, 'scripts/smoke/day04_joint_smoke.py']
print('> ', ' '.join(cmd))
res = subprocess.run(cmd, env=env, capture_output=True, text=True)
print(res.stdout)
if res.returncode != 0:
    print(res.stderr)
    raise SystemExit(res.returncode)

## Optional: direct import smoke
Instantiate `JointDiT` and feed dummy latents to verify shapes without the standalone script.

In [None]:
import torch

from models.jointdit import JointDiT

device = 'cuda' if torch.cuda.is_available() else 'cpu'
use_fp16 = device == 'cuda'
dtype = torch.float16 if use_fp16 else torch.float32

# dummy latents (match smoke): video is (B, Tv, Cv, Hv, Wv), audio is (B, Ca, Ha, Wa)
B, Tv, Cv, Hv, Wv = 1, 12, 4, 40, 53
Ta, Ca, Ha, Wa = 8, 8, 20, 15  # note: Ta is not used anywhere in this notebook
v = torch.randn(B, Tv, Cv, Hv, Wv, device=device, dtype=dtype)
a = torch.randn(B, Ca, Ha, Wa, device=device, dtype=dtype)

joint = JointDiT(
    d_model=256, heads=8, ff_mult=4, dropout=0.0,
    rope_cfg={
        'video': {'type': 'rope_3d', 'theta': 10000.0},
        'audio': {'type': 'rope_2d', 'theta': 10000.0}
    },
    video_in_ch=Cv,
    audio_in_ch=Ca,
    joint_blocks=2,
    svd_slicer=None,
    aldm_slicer=None,
).to(device)
if use_fp16:
    joint = joint.half()
joint.eval()

for mode in ['full', 'iso_v', 'iso_a']:
    with torch.inference_mode():
        v_out, a_out = joint(v, a, mode=mode)
    # flag NaN/Inf in either output, not only the video branch
    bad = not (torch.isfinite(v_out).all() and torch.isfinite(a_out).all())
    print(f"mode={mode:>5}  v_out={tuple(v_out.shape)}  a_out={tuple(a_out.shape)}  nan/inf={bad}")
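The printout above relies on reading the shapes by eye. A minimal sketch of turning it into hard checks follows. `check_outputs` is a hypothetical helper, not part of the project, and it assumes the model returns latents with the same shapes as its inputs, which this notebook does not state. Plain random tensors stand in for the model outputs so the block runs without `JointDiT`.

```python
import torch


def check_outputs(v_in, a_in, v_out, a_out):
    """Raise AssertionError on a shape change or any NaN/Inf value."""
    # ASSUMPTION: the model returns latents with the same shapes as its inputs.
    assert v_out.shape == v_in.shape, f"video shape {tuple(v_out.shape)} != {tuple(v_in.shape)}"
    assert a_out.shape == a_in.shape, f"audio shape {tuple(a_out.shape)} != {tuple(a_in.shape)}"
    assert bool(torch.isfinite(v_out).all()), "video output has NaN/Inf"
    assert bool(torch.isfinite(a_out).all()), "audio output has NaN/Inf"


# Stand-in tensors with the notebook's dummy-latent shapes (no JointDiT needed).
v_demo = torch.randn(1, 12, 4, 40, 53)
a_demo = torch.randn(1, 8, 20, 15)

check_outputs(v_demo, a_demo, v_demo.clone(), a_demo.clone())  # clean outputs pass silently

a_bad = a_demo.clone()
a_bad[0, 0, 0, 0] = float("nan")  # poison one audio value
try:
    check_outputs(v_demo, a_demo, v_demo.clone(), a_bad)
    caught = False
except AssertionError as err:
    caught = True
    print("caught:", err)
```

In the notebook itself the call would be `check_outputs(v, a, v_out, a_out)` inside the mode loop, if the same-shape assumption holds for `JointDiT`.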

## Troubleshooting: dtype mismatch (`Half` vs `Float`)
If you see `RuntimeError: mat1 and mat2 must have the same dtype, but got Half and Float`, it means inputs are fp16 but model weights are fp32.

Two fixes:
1) Cast the model to half (used here): `joint = joint.half()`
2) Or keep weights fp32 and enable autocast around the forward call.
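Both fixes can be tried on a single layer without a GPU. The sketch below is illustrative only: it uses `bfloat16` on CPU as a stand-in for fp16 on CUDA (an assumption made so the example runs anywhere), and a bare `nn.Linear` in place of `JointDiT`. The exact error text may differ between PyTorch versions.

```python
import copy

import torch
from torch import nn

layer = nn.Linear(4, 2)                      # fp32 weights (PyTorch default)
x = torch.randn(3, 4, dtype=torch.bfloat16)  # low-precision input

# Reproduce the mismatch: low-precision input meets fp32 weights.
try:
    layer(x)
    mismatch_raised = False
except RuntimeError as err:
    mismatch_raised = True
    print("mismatch:", err)

# Fix 1: cast a copy of the module so its weights match the input dtype.
layer_lp = copy.deepcopy(layer).to(torch.bfloat16)
y_cast = layer_lp(x)

# Fix 2: keep fp32 weights and let autocast choose the compute dtype.
with torch.inference_mode(), torch.autocast("cpu", dtype=torch.bfloat16):
    y_autocast = layer(x)

print(y_cast.dtype, y_autocast.dtype, layer.weight.dtype)
```

Fix 1 halves the memory used by the weights. Fix 2 leaves the stored weights in fp32 and only lowers precision for the ops autocast considers safe.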

In [None]:
# Example: autocast path instead of model.half()
# Note: on CUDA, `joint` was already cast to fp16 above, so this run only confirms
# that autocast works with the model. To exercise the fp32-weights path,
# call `joint.float()` first.
from torch.amp import autocast

if device == 'cuda':
    with torch.inference_mode(), autocast('cuda', dtype=torch.float16):
        _v, _a = joint(v, a, mode='full')
    print('autocast smoke ok →', tuple(_v.shape), tuple(_a.shape))
else:
    print('CPU run: autocast not required')

## Save a tiny smoke log
Writes a JSON file with the input shapes, model config, and a timestamp so re-runs can be traced later.

In [None]:
out_dir = Path('notebooks/day_04')
out_dir.mkdir(parents=True, exist_ok=True)
log = {
    'ts': time.strftime('%Y-%m-%d %H:%M:%S'),
    'shapes': {
        'v_in': [B, Tv, Cv, Hv, Wv],
        'a_in': [B, Ca, Ha, Wa]
    },
    'config': {
        'd_model': 256, 'heads': 8, 'ff_mult': 4, 'blocks': 2,
        'rope': {'video': 'rope_3d', 'audio': 'rope_2d'}
    }
}
log_path = out_dir / 'smoke_log.json'
log_path.write_text(json.dumps(log, indent=2))
print('wrote', log_path)
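For later traceability it can help to read the snapshot back and confirm it has the expected fields. This is a minimal sketch: `load_smoke_log` and `REQUIRED_KEYS` are hypothetical names, and the block writes a sample log to a temporary directory (same schema as the cell above) so it does not touch the real `smoke_log.json`.

```python
import json
import tempfile
from pathlib import Path

REQUIRED_KEYS = {"ts", "shapes", "config"}  # top-level fields written by the cell above


def load_smoke_log(path):
    """Load a smoke log and check the fields this notebook writes."""
    loaded_log = json.loads(Path(path).read_text())
    missing = REQUIRED_KEYS - loaded_log.keys()
    if missing:
        raise ValueError(f"smoke log missing keys: {sorted(missing)}")
    # Video latents are 5-D (B, Tv, Cv, Hv, Wv); audio latents are 4-D (B, Ca, Ha, Wa).
    if len(loaded_log["shapes"]["v_in"]) != 5 or len(loaded_log["shapes"]["a_in"]) != 4:
        raise ValueError("unexpected latent rank in smoke log")
    return loaded_log


with tempfile.TemporaryDirectory() as tmp:
    sample = {
        "ts": "2024-01-01 00:00:00",  # placeholder timestamp for the demo
        "shapes": {"v_in": [1, 12, 4, 40, 53], "a_in": [1, 8, 20, 15]},
        "config": {"d_model": 256, "heads": 8, "ff_mult": 4, "blocks": 2},
    }
    sample_path = Path(tmp) / "smoke_log.json"
    sample_path.write_text(json.dumps(sample, indent=2))
    loaded = load_smoke_log(sample_path)

print(loaded["shapes"])
```

Against the real file the call would be `load_smoke_log('notebooks/day_04/smoke_log.json')`, run from the project root.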