# Day 7 — Stage-B fine-tune (experts) + Joint Inference

**Goal**: Freeze most of JointDiT, unfreeze a few expert blocks + I/O projections, train briefly (Stage-B), then run inference with the new checkpoint.

**What this notebook does**
1) Sanity: set memory-friendly env vars (same knobs as earlier days).
2) Train **Stage-B** for a small number of steps (or skip this and point the paths cell at an existing Stage-B checkpoint).
3) Generate a Day-7 inference config pointing to your Stage-B checkpoint and decode MP4/WAV.
4) Troubleshooting notes for common hiccups.

**Assumptions**
- Project root = `/workspace/jointdit`
- You’ve already cached Day-2 latents and completed Day-5 training once.
- Day-6 inference worked (decoders in place).


In [None]:
# Paths (edit if you changed layout)
REPO = "/workspace/jointdit"
CKPT_STAGEB_DIR = f"{REPO}/checkpoints/day07_stage_b_d7"                # where we'll save a Stage-B run
CKPT_STAGEB_FILE = f"{CKPT_STAGEB_DIR}/ckpt_step_001000.pt"             # must be a step that was actually saved (see MAX_STEPS below)
CKPT_NOJOINT_DIR = f"{REPO}/checkpoints/day07_ablate_nojoint_d7nojoint" # optional ablation ckpt
CKPT_NOJOINT_FILE = f"{CKPT_NOJOINT_DIR}/ckpt_step_001000.pt"
OUT_DIR = f"{REPO}/outputs/day07"

print(REPO, CKPT_STAGEB_DIR, OUT_DIR)

## 1) Memory knobs (kept from earlier days)
Adjust these to your VRAM. `%env` sets the variables in the kernel's environment, so the `!` shell commands below inherit them.

In [None]:
%env PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:128
%env JOINTDIT_Q_CHUNK_V=128
%env JOINTDIT_Q_CHUNK_A=0
%env JOINTDIT_KV_DOWNSAMPLE=4
%env JOINTDIT_MAX_T=6
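The knobs above trade speed for peak memory. As a rough guide, assuming `JOINTDIT_Q_CHUNK_V` is the number of video queries processed per chunk (0 meaning no chunking) and `JOINTDIT_KV_DOWNSAMPLE` divides the key/value sequence length, the largest attention-score block scales as in this sketch. The tokens-per-frame and head counts are hypothetical, not JointDiT's real sizes.

```python
def attn_score_bytes(n_q, n_kv, heads, q_chunk=0, kv_downsample=1, bytes_per_el=2):
    """Size in bytes of one attention-score block (heads x query rows x key cols).

    Assumptions (not taken from JointDiT's code): q_chunk > 0 means queries are
    processed in sequential chunks of that many rows, q_chunk == 0 means no
    chunking, and kv_downsample shrinks the key/value length by that factor.
    """
    rows = n_q if q_chunk <= 0 else min(q_chunk, n_q)
    cols = -(-n_kv // kv_downsample)  # ceiling division
    return heads * rows * cols * bytes_per_el


# Hypothetical sizes: 6 frames (JOINTDIT_MAX_T=6) x 1024 tokens/frame, 16 heads, fp16 scores.
n_tokens = 6 * 1024
full = attn_score_bytes(n_tokens, n_tokens, heads=16)
chunked = attn_score_bytes(n_tokens, n_tokens, heads=16, q_chunk=128, kv_downsample=4)
print(f"no chunking:            {full / 2**30:.3f} GiB")
print(f"Q_CHUNK=128, KV_DOWN=4: {chunked / 2**20:.1f} MiB")
```

With these made-up sizes the block shrinks from 1.125 GiB to 6.0 MiB. Query chunking lowers the peak, not the total work, since every chunk still runs; `JOINTDIT_MAX_T` enters through the token count, so its effect is quadratic.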

## 2) Stage-B training (experts)
Uses `configs/day07_trainB.yaml`. Raise `MAX_STEPS` (passed as `--max-steps`) for a longer run with more adaptation.
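The Goal above describes Stage-B as freezing most of JointDiT and unfreezing a few expert blocks plus the I/O projections. That selection is presumably handled inside `train_stage_b.py` and its config, which aren't shown here. Below is a minimal name-based sketch of the idea using `fnmatch`; the patterns are hypothetical placeholders, not JointDiT's actual parameter names.

```python
import fnmatch

# Hypothetical patterns: inspect model.named_parameters() for the real names.
TRAINABLE_PATTERNS = ("blocks.*.experts.*", "proj_in.*", "proj_out.*")


def is_trainable(param_name, patterns=TRAINABLE_PATTERNS):
    """True if the parameter name matches any unfreeze pattern ('*' also matches dots)."""
    return any(fnmatch.fnmatchcase(param_name, pat) for pat in patterns)


def split_params(names, patterns=TRAINABLE_PATTERNS):
    """Partition parameter names into (trainable, frozen) lists, preserving order."""
    trainable = [n for n in names if is_trainable(n, patterns)]
    frozen = [n for n in names if not is_trainable(n, patterns)]
    return trainable, frozen


# With a real torch model you would apply it like this (not executed here):
#   for name, p in model.named_parameters():
#       p.requires_grad_(is_trainable(name))
#   optimizer = torch.optim.AdamW([p for p in model.parameters() if p.requires_grad], lr=...)

names = [
    "blocks.0.attn.qkv.weight",
    "blocks.0.experts.1.fc.weight",
    "proj_in.weight",
    "proj_out.bias",
    "time_embed.0.weight",
]
trainable, frozen = split_params(names)
print("trainable:", trainable)
print("frozen:   ", frozen)
```

Handing only the trainable parameters to the optimizer also keeps optimizer state (such as AdamW's moment buffers) off the frozen weights, which saves memory.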

In [None]:
import pathlib

pathlib.Path(CKPT_STAGEB_DIR).mkdir(parents=True, exist_ok=True)

# Quick smoke test: 100 steps. Use 1000+ for real adaptation, and make sure
# CKPT_STAGEB_FILE (paths cell) names a step that this run actually saved.
MAX_STEPS = 100

# AMP is used only if the config sets runtime.dtype: fp16; otherwise training is pure fp32
! . {REPO}/.venv/bin/activate && \
  PYTHONPATH={REPO} python {REPO}/scripts/train/train_stage_b.py \
  --cfg {REPO}/configs/day07_trainB.yaml \
  --max-steps {MAX_STEPS} \
  --log-suffix d7 \
  --ckpt-suffix d7
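With `MAX_STEPS = 100`, the `ckpt_step_001000.pt` named in the paths cell may not exist after the smoke run. A small helper can pick the highest-step file instead. It assumes the trainer writes `ckpt_step_<step>.pt` files into `CKPT_STAGEB_DIR`, as the paths cell suggests; how often it saves is set by the training config, which isn't shown here.

```python
import pathlib
import re

# Naming scheme assumed from the paths cell, e.g. ckpt_step_001000.pt
CKPT_RE = re.compile(r"^ckpt_step_(\d+)\.pt$")


def latest_checkpoint(ckpt_dir):
    """Return the Path of the highest-step checkpoint in ckpt_dir, or None if there is none."""
    ckpt_dir = pathlib.Path(ckpt_dir)
    if not ckpt_dir.is_dir():
        return None
    best_step, best_path = -1, None
    for path in ckpt_dir.iterdir():
        m = CKPT_RE.match(path.name)
        # Compare steps numerically so unpadded names would still order correctly
        if m and int(m.group(1)) > best_step:
            best_step, best_path = int(m.group(1)), path
    return best_path
```

In the notebook you could call `latest_checkpoint(CKPT_STAGEB_DIR)`, check that the result isn't `None`, and assign `str(...)` of it to `CKPT_STAGEB_FILE` before generating the inference config.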

## 3) Create Day-7 inference config pointing to your Stage-B checkpoint
We clone Day-6’s infer config and just swap in the new ckpt + output dir.

In [None]:
import pathlib
import shutil

import yaml

src = f"{REPO}/configs/day06_infer.yaml"
dst = f"{REPO}/configs/day07_infer.yaml"
shutil.copy(src, dst)

with open(dst, "r") as f:
    y = yaml.safe_load(f)

# pick which checkpoint to use (Stage-B real or ablation)
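USE_NO_JOINT = False    # set True to test the no-joint ablation
y["ckpt"] = CKPT_NOJOINT_FILE if USE_NO_JOINT else CKPT_STAGEB_FILE
y["out_dir"] = OUT_DIR

# Fail early if that checkpoint was never saved (e.g. MAX_STEPS < 1000)
if not pathlib.Path(y["ckpt"]).is_file():
    raise FileNotFoundError(f"Checkpoint not found: {y['ckpt']}")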

with open(dst, "w") as f:
    yaml.safe_dump(y, f, sort_keys=False)

print("Wrote:", dst)
print("ckpt:", y["ckpt"])

## 4) Inference (Joint)
This runs joint generation with the Day-7 config and writes the decoded MP4/WAV files into `outputs/day07/…`.

In [None]:
! . {REPO}/.venv/bin/activate && \
  PYTHONPATH={REPO} python {REPO}/scripts/infer/infer_joint.py \
  --cfg {REPO}/configs/day07_infer.yaml

## 5) List outputs

In [None]:
! ls -lh {OUT_DIR} || true

## Troubleshooting
- **“No Stage-B ckpt found” in `make day7-infer`**: either update `configs/day07_infer.yaml: ckpt:` to your file
  (e.g. `/workspace/jointdit/checkpoints/day07_stage_b_d7/ckpt_step_001000.pt`) or symlink it into the path the Makefile expects.
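- **GradScaler: “Attempting to unscale FP16 gradients.”**: PyTorch raises this when the parameters being optimized (and therefore their gradients) are stored in fp16. AMP expects fp32 weights with `torch.autocast` running the forward pass in fp16, so don't cast the model itself to half precision while using `GradScaler`. Set `runtime.dtype` to `fp16` only if you want AMP; otherwise training runs in fp32 and skips `scaler.unscale_()`.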
- **OOM**: lower `JOINTDIT_MAX_T` to 4, use smaller video query chunks (lower `JOINTDIT_Q_CHUNK_V`, e.g. 64), or increase `JOINTDIT_KV_DOWNSAMPLE`.
- **Silent WAV**: after a short Stage-B run or with few sampling steps, the generated mel can be too weak for Griffin-Lim inversion to produce audible sound. Raise the sampling `steps` (40–60), lower guidance (e.g. 1.0), or switch to a proper vocoder later.
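To tell a truly silent WAV from a merely quiet one before changing settings, you can measure the peak sample level with the standard-library `wave` module. This sketch assumes `infer_joint.py` writes 16-bit PCM WAV (other sample widths are rejected); interleaved multi-channel data works because only the maximum matters. The 1e-3 threshold (about -60 dBFS) is an arbitrary choice.

```python
import array
import sys
import wave


def wav_peak(path):
    """Peak absolute sample of a 16-bit PCM WAV, normalized so full scale is 1.0."""
    with wave.open(str(path), "rb") as wf:
        if wf.getsampwidth() != 2:
            raise ValueError(f"{path}: expected 16-bit PCM, got {8 * wf.getsampwidth()}-bit")
        frames = wf.readframes(wf.getnframes())
    samples = array.array("h", frames)
    if sys.byteorder == "big":
        samples.byteswap()  # WAV sample data is little-endian
    if not samples:
        return 0.0
    return max(abs(s) for s in samples) / 32768.0


def is_silent(path, threshold=1e-3):
    """True if the peak level is below threshold (1e-3 is roughly -60 dBFS)."""
    return wav_peak(path) < threshold
```

For example, loop over `pathlib.Path(OUT_DIR).glob("*.wav")` and print each file's `wav_peak`. A peak near zero points at the generated mel itself rather than playback or mixing.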
