# JointDiT — End-to-End Guide

This notebook walks you through **latent caching → training (Stage-A/B) → inference** using the repo scripts.
You can run cells sequentially or copy/paste commands into your terminal.

In [None]:
import os
import pathlib
import platform

import torch

print('Python:', platform.python_version())
print('CUDA available:', torch.cuda.is_available())
if torch.cuda.is_available():
    print('GPU:', torch.cuda.get_device_name(0))
print('Repo root:', pathlib.Path('.').resolve())

## 1) Cache latents (Day 2)
Make sure your raw videos exist under the paths referenced by `configs/day02_cache.yaml`.
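Before caching, it can help to check that every path in the config actually exists. The sketch below assumes you have already loaded the YAML into a nested dict. The helper names `iter_path_entries` and `find_missing_paths`, and the key-suffix heuristic (`dir`, `path`, `root`), are hypothetical, because the real layout of `configs/day02_cache.yaml` is not shown here.

```python
from pathlib import Path
from typing import Any, Iterator

# Hypothetical heuristic: string values under keys ending in one of these
# suffixes are treated as filesystem paths. Adjust to the real config layout.
PATH_KEY_SUFFIXES = ("dir", "path", "root")


def iter_path_entries(cfg: Any, prefix: str = "") -> Iterator[tuple[str, str]]:
    """Yield (dotted_key, value) for string values whose key looks like a path."""
    if isinstance(cfg, dict):
        for key, value in cfg.items():
            dotted = f"{prefix}.{key}" if prefix else str(key)
            if isinstance(value, str) and str(key).lower().endswith(PATH_KEY_SUFFIXES):
                yield dotted, value
            else:
                # Recurse into nested dicts/lists; other scalars yield nothing.
                yield from iter_path_entries(value, dotted)
    elif isinstance(cfg, list):
        for i, item in enumerate(cfg):
            yield from iter_path_entries(item, f"{prefix}[{i}]")


def find_missing_paths(cfg: dict) -> list[tuple[str, str]]:
    """Return (dotted_key, path) pairs whose path does not exist on disk."""
    return [(k, v) for k, v in iter_path_entries(cfg) if not Path(v).exists()]
```

In the notebook you could load the config with PyYAML (assumed installed), e.g. `find_missing_paths(yaml.safe_load(open('configs/day02_cache.yaml')))`, and fix any reported entries before running the caching cell.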

In [None]:
%%bash
set -euo pipefail
source .venv/bin/activate
PYTHONPATH=. python scripts/data/cache_latents.py --cfg configs/day02_cache.yaml --split train
PYTHONPATH=. python scripts/data/cache_latents.py --cfg configs/day02_cache.yaml --split val

## 2) Choose VRAM profile (exports)
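Set the memory-related environment variables below. The values target a GPU with about 48 GB of VRAM; tweak them later if you run out of memory or have headroom. Because `%%bash` cells run as child processes of the kernel, variables set through `os.environ` here are inherited by the training and inference cells that follow.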

In [None]:
import os

# ~48 GB profile (safe); adjust as needed.
# max_split_size_mb stops PyTorch's CUDA caching allocator from splitting
# blocks larger than this size (in MB), which can reduce fragmentation.
profile = {
    'PYTORCH_CUDA_ALLOC_CONF': 'max_split_size_mb:128',
    'JOINTDIT_MAX_T': '6',
    'JOINTDIT_Q_CHUNK_V': '64',
    'JOINTDIT_Q_CHUNK_A': '0',
    'JOINTDIT_KV_DOWNSAMPLE': '8',
    'FORCE_KEEP_USER_ENVS': '1',
}
os.environ.update(profile)
print({k: os.environ[k] for k in profile})
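If you would rather let values you exported before launching Jupyter take precedence over the profile, one option is a helper that only fills in keys that are not already set. The sketch below is hypothetical: `apply_profile` and `child_sees` are not part of the repo. `child_sees` spawns a fresh Python process to confirm that a child process, such as a `%%bash` cell, actually sees a variable.

```python
import os
import subprocess
import sys
from typing import MutableMapping


def apply_profile(profile: dict[str, str],
                  env: MutableMapping[str, str] = os.environ,
                  override: bool = False) -> dict[str, str]:
    """Write profile values into `env`, keeping pre-existing values unless override=True.

    Returns the effective value of every key in the profile.
    """
    for key, value in profile.items():
        if override or key not in env:
            env[key] = value
    return {key: env[key] for key in profile}


def child_sees(key: str) -> str:
    """Return `key` as seen by a freshly spawned child process ('' if unset)."""
    result = subprocess.run(
        [sys.executable, "-c", f"import os; print(os.environ.get({key!r}, ''))"],
        capture_output=True, text=True, check=True,
    )
    return result.stdout.strip()
```

For example, `apply_profile({'JOINTDIT_MAX_T': '6'})` leaves an already-exported `JOINTDIT_MAX_T` untouched, and `child_sees('JOINTDIT_MAX_T')` reports the value a `%%bash` cell will see.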

## 3) Train — Stage A (Day 5)
A short smoke run (25 steps); increase `--max-steps` for real training.

In [None]:
%%bash
set -euo pipefail
source .venv/bin/activate
PYTHONPATH=. python scripts/train/train_stage_a.py \
  --cfg configs/day05_train.yaml \
  --max-steps 25 \
  --ckpt-suffix nbA \
  --log-suffix nbA
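If you prefer to drive training from Python instead of `%%bash`, a minimal sketch is to build the same command line as a list and pass it to `subprocess.run`. It uses only the flags shown in the cell above; `build_train_cmd` and `run_train` are hypothetical helper names, and the sketch assumes the notebook's working directory is the repo root. Note that `sys.executable` is the kernel's interpreter, which matches `.venv` only if the kernel was started from that environment.

```python
import os
import subprocess
import sys


def build_train_cmd(script: str, cfg: str, max_steps: int, suffix: str) -> list[str]:
    """Assemble a training command using the same flags as the %%bash cell."""
    return [
        sys.executable, script,
        "--cfg", cfg,
        "--max-steps", str(max_steps),
        "--ckpt-suffix", suffix,
        "--log-suffix", suffix,
    ]


def run_train(cmd: list[str]) -> None:
    """Run the command with PYTHONPATH=. so the repo root is importable."""
    env = {**os.environ, "PYTHONPATH": "."}
    subprocess.run(cmd, env=env, check=True)  # raises CalledProcessError on failure
```

Usage: `run_train(build_train_cmd('scripts/train/train_stage_a.py', 'configs/day05_train.yaml', 25, 'nbA'))`. Passing a list (rather than a shell string) avoids quoting issues with paths.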

## 4) Train — Stage B (Day 7)
Stage B fine-tunes the experts together with the in/out layers. The 100-step run below is short; raise `--max-steps` for a full run.

In [None]:
%%bash
set -euo pipefail
source .venv/bin/activate
PYTHONPATH=. python scripts/train/train_stage_b.py \
  --cfg configs/day07_trainB.yaml \
  --max-steps 100 \
  --ckpt-suffix nbB \
  --log-suffix nbB
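Stage B updates only part of the network. One common way to express this kind of selective fine-tuning is to match parameter names against patterns and freeze everything else. The sketch below shows only the selection logic on plain name strings; the patterns (`*experts.*`, `in_proj.*`, `out_proj.*`) are hypothetical placeholders, not JointDiT's actual module names, and what `train_stage_b.py` really does may differ.

```python
from fnmatch import fnmatchcase
from typing import Iterable

# Hypothetical glob patterns for the parameters trained in Stage B.
TRAINABLE_PATTERNS = ("*experts.*", "in_proj.*", "out_proj.*")


def split_trainable(names: Iterable[str],
                    patterns: Iterable[str] = TRAINABLE_PATTERNS) -> tuple[list[str], list[str]]:
    """Partition parameter names into (trainable, frozen), preserving order."""
    patterns = tuple(patterns)
    trainable: list[str] = []
    frozen: list[str] = []
    for name in names:
        # fnmatchcase is case-sensitive on every OS; '*' also matches dots.
        if any(fnmatchcase(name, p) for p in patterns):
            trainable.append(name)
        else:
            frozen.append(name)
    return trainable, frozen
```

With a PyTorch model you would pass `[n for n, _ in model.named_parameters()]` and call `p.requires_grad_(False)` on every parameter whose name lands in the frozen list.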

## 5) Inference (Day 6 sampler)
To sample from your Stage-B checkpoint, or to change seeds and sampling steps, edit `configs/day06_infer.yaml` before running the cell.
The script writes MP4 (video) and WAV (audio) files under `outputs/day06/`.
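To point `configs/day06_infer.yaml` at your newest Stage-B checkpoint you first need its path. A minimal sketch, assuming checkpoints are files somewhere under a directory you pass in, that the `--ckpt-suffix` value (`nbB` above) appears in the filename, and that they use one of the listed extensions (all three are guesses about the repo): pick the most recently modified match. `latest_checkpoint` is a hypothetical helper.

```python
from pathlib import Path
from typing import Optional

# Hypothetical checkpoint extensions; adjust to whatever the training scripts write.
CKPT_EXTS = (".pt", ".pth", ".ckpt", ".safetensors")


def latest_checkpoint(root: str, suffix: str) -> Optional[Path]:
    """Return the newest file under `root` whose name contains `suffix`, or None."""
    candidates = [
        p for p in Path(root).rglob("*")
        if p.is_file() and suffix in p.name and p.suffix in CKPT_EXTS
    ]
    # Newest by modification time; None when nothing matches.
    return max(candidates, key=lambda p: p.stat().st_mtime, default=None)
```

Print the result, e.g. `latest_checkpoint('<your ckpt dir>', 'nbB')`, and copy the path into the inference config.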

In [None]:
%%bash
set -euo pipefail
source .venv/bin/activate
PYTHONPATH=. python scripts/infer/infer_joint.py --cfg configs/day06_infer.yaml

## 6) Inspect outputs

In [None]:
import pathlib

out = pathlib.Path('outputs/day06')
sorted(str(p) for p in out.glob('*'))[:20]
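Since the sampler writes MP4 and WAV files, a quick sanity check is to pair them up and read each WAV's duration with the standard-library `wave` module. This sketch assumes each sample's video and audio share a filename stem, which is a guess about the naming scheme; `summarize_outputs` is a hypothetical helper. `wave` only parses integer PCM WAV, so float WAVs are reported with a `None` duration.

```python
import wave
from pathlib import Path


def wav_duration_s(path: Path) -> float:
    """Duration of a PCM WAV file in seconds (frames / sample rate)."""
    with wave.open(str(path), "rb") as w:
        return w.getnframes() / w.getframerate()


def summarize_outputs(out_dir: str) -> dict[str, dict]:
    """Group outputs by stem; report which of MP4/WAV exist and the WAV length."""
    summary: dict[str, dict] = {}
    for p in sorted(Path(out_dir).glob("*")):
        ext = p.suffix.lower()
        if ext not in (".mp4", ".wav"):
            continue
        entry = summary.setdefault(p.stem, {"mp4": False, "wav": False, "wav_s": None})
        if ext == ".mp4":
            entry["mp4"] = True
        else:
            entry["wav"] = True
            try:
                entry["wav_s"] = round(wav_duration_s(p), 3)
            except wave.Error:
                entry["wav_s"] = None  # e.g. float WAV, which `wave` cannot parse
    return summary
```

Usage: `for stem, info in summarize_outputs('outputs/day06').items(): print(stem, info)`. A stem with `mp4` but no `wav` (or vice versa) suggests an incomplete sample.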