# Day 6 — Joint Inference (video + audio)

This notebook runs the Day-6 inference scaffold:
- Loads **JointDiT** (fp32 weights) plus cached shapes from a reference meta file
- Samples with simple ancestral steps over log-spaced sigmas
- Decodes video via the **SVD VAE (Temporal Decoder)** and audio via the **AudioLDM2 VAE**
- Inverts the mel spectrogram to a waveform with torchaudio Griffin-Lim, using louder defaults and peak normalization
- Writes MP4 + WAV, and muxes the audio into the MP4 if `ffmpeg` is present
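
The sampler's code isn't shown in this notebook. As a rough picture of what "ancestral steps over log-spaced sigmas" usually means, here is a minimal sketch assuming a k-diffusion-style Euler-ancestral update. The function names and the `sigma_max=80`, `sigma_min=0.002` range are illustrative assumptions, not values read from `infer_joint.py`, and plain lists stand in for tensors.

```python
import math
import random
from typing import Callable, List

Vec = List[float]


def log_spaced_sigmas(sigma_max: float, sigma_min: float, steps: int) -> List[float]:
    """Hypothetical helper: `steps` noise levels evenly spaced in log-space, high to low, plus a final 0."""
    if steps < 2:
        raise ValueError("need at least 2 steps")
    hi, lo = math.log(sigma_max), math.log(sigma_min)
    sigmas = [math.exp(hi + (lo - hi) * i / (steps - 1)) for i in range(steps)]
    return sigmas + [0.0]  # the last step lands on the clean estimate


def ancestral_step(x: Vec, denoised: Vec, sigma: float, sigma_next: float, rng: random.Random) -> Vec:
    """One Euler-ancestral update (k-diffusion-style noise split)."""
    if sigma_next == 0.0:
        return list(denoised)  # final step: take the model's clean estimate
    # Split sigma_next into a deterministic part (sigma_down) and fresh noise (sigma_up).
    sigma_up = min(sigma_next, math.sqrt(sigma_next**2 * (sigma**2 - sigma_next**2) / sigma**2))
    sigma_down = math.sqrt(sigma_next**2 - sigma_up**2)
    out = []
    for xi, di in zip(x, denoised):
        d = (xi - di) / sigma                  # estimated noise direction
        xi = xi + d * (sigma_down - sigma)     # Euler step down to sigma_down
        out.append(xi + rng.gauss(0.0, 1.0) * sigma_up)  # re-inject fresh noise
    return out


def sample(denoise_fn: Callable[[Vec, float], Vec], n: int, sigmas: List[float], seed: int = 0) -> Vec:
    """Start from pure noise at sigmas[0] and walk the schedule down to 0."""
    rng = random.Random(seed)
    x = [rng.gauss(0.0, 1.0) * sigmas[0] for _ in range(n)]
    for sigma, sigma_next in zip(sigmas[:-1], sigmas[1:]):
        x = ancestral_step(x, denoise_fn(x, sigma), sigma, sigma_next, rng)
    return x
```

With log spacing, consecutive sigmas shrink by a constant ratio, so the sampler spends as many steps at low noise (fine detail) as at high noise (coarse layout).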


In [None]:
# Setup & environment knobs
import os
import subprocess
import sys
from pathlib import Path

repo = Path.cwd()
assert (repo / 'scripts' / 'infer' / 'infer_joint.py').exists(), 'Run from repo root (/workspace/jointdit)'

# Memory-saving defaults (values that ran safely in Day 5/6).
# setdefault() keeps any value already present in the environment.
os.environ.setdefault('JOINTDIT_Q_CHUNK_V', '128')
os.environ.setdefault('JOINTDIT_Q_CHUNK_A', '0')
os.environ.setdefault('JOINTDIT_KV_DOWNSAMPLE', '4')
os.environ.setdefault('PYTORCH_CUDA_ALLOC_CONF', 'max_split_size_mb:128')

# Louder audio defaults (can tweak per run)
os.environ.setdefault('JOINTDIT_MEL_DB_RANGE', '70')  # 60..80 typical
os.environ.setdefault('JOINTDIT_AUDIO_GAIN_DB', '3') # extra boost (dB)

keys = [k for k in os.environ if k.startswith('JOINTDIT_') or k == 'PYTORCH_CUDA_ALLOC_CONF']
print('env:', {k: os.environ[k] for k in sorted(keys)})
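
The memory knobs above are not documented in this notebook. Going only by their names, one plausible reading (an assumption) is that `JOINTDIT_Q_CHUNK_V`/`_A` process attention queries in blocks of that size (0 = no chunking) and `JOINTDIT_KV_DOWNSAMPLE` average-pools keys and values along the token axis. The sketch below illustrates those two ideas with hypothetical helpers on plain lists: query chunking is exact, while KV pooling is an approximation that shortens the key/value sequence.

```python
import math
from typing import List

Vec = List[float]


def attention(queries: List[Vec], keys: List[Vec], values: List[Vec]) -> List[Vec]:
    """Plain softmax attention, computed one query at a time."""
    scale = 1.0 / math.sqrt(len(keys[0]))
    out = []
    for q in queries:
        scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
        m = max(scores)  # subtract the max for numerical stability
        w = [math.exp(s - m) for s in scores]
        z = sum(w)
        out.append([sum(wj * v[d] for wj, v in zip(w, values)) / z for d in range(len(values[0]))])
    return out


def chunked_attention(queries: List[Vec], keys: List[Vec], values: List[Vec], q_chunk: int) -> List[Vec]:
    """Hypothetical: attend in query blocks of q_chunk; q_chunk <= 0 disables chunking.

    In this list-based toy every query is already handled alone, so chunking saves
    nothing here; in a batched tensor version the score matrix shrinks from
    (num_queries x num_keys) to (q_chunk x num_keys) per block.
    """
    if q_chunk <= 0:
        return attention(queries, keys, values)
    out: List[Vec] = []
    for start in range(0, len(queries), q_chunk):
        out.extend(attention(queries[start:start + q_chunk], keys, values))
    return out


def downsample_kv(seq: List[Vec], factor: int) -> List[Vec]:
    """Hypothetical: average-pool a key or value sequence by `factor` along the token axis."""
    if factor <= 1:
        return seq
    pooled = []
    for start in range(0, len(seq), factor):
        block = seq[start:start + factor]
        pooled.append([sum(col) / len(block) for col in zip(*block)])
    return pooled
```

For example, `chunked_attention(q, downsample_kv(k, 4), downsample_kv(v, 4), 128)` would mirror the defaults set above under this reading.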


## Inspect Day-6 config
We point inference at `configs/day06_infer.yaml` (created during the Day-6 step).

In [None]:
print(Path('configs/day06_infer.yaml').read_text())


## Run inference
Equivalent to running `make smoke-day06`, but invoked inline for visibility.

In [None]:
cmd = [sys.executable, 'scripts/infer/infer_joint.py', '--cfg', 'configs/day06_infer.yaml']
print('Running:', ' '.join(cmd))
subprocess.run(cmd, check=True)
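
Because the setup cell uses `setdefault`, reassigning a knob there won't change a value that is already set. One way to try per-run settings without touching the notebook's environment is to pass an explicit `env` to `subprocess.run`. The helper name below is hypothetical; it reuses only the script path and `--cfg` flag shown above.

```python
import os
import sys
from typing import Dict, List, Tuple


def build_infer_call(cfg: str, **overrides: object) -> Tuple[List[str], Dict[str, str]]:
    """Hypothetical helper: return (cmd, env) for one run with per-run JOINTDIT_* overrides."""
    cmd = [sys.executable, 'scripts/infer/infer_joint.py', '--cfg', cfg]
    env = dict(os.environ)  # copy, so the notebook's own environment is left untouched
    env.update({k: str(v) for k, v in overrides.items()})  # env values must be strings
    return cmd, env
```

Usage: `cmd, env = build_infer_call('configs/day06_infer.yaml', JOINTDIT_AUDIO_GAIN_DB=6)` followed by `subprocess.run(cmd, check=True, env=env)`.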


## Outputs
MP4s and WAVs land in `outputs/day06/`. If `ffmpeg` is installed, an `*_av.mp4` with muxed audio may also be written.

In [None]:
outdir = Path('outputs/day06')
outdir.mkdir(parents=True, exist_ok=True)
print('\n'.join(sorted(str(p) for p in outdir.glob('*'))))
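
If `ffmpeg` was missing during the run, the audio can be muxed afterwards. A minimal sketch, assuming each `<stem>.mp4` has a matching `<stem>.wav` (the pairing rule is an assumption about file naming, and the helper names are hypothetical). The ffmpeg options are standard: copy the video stream, encode the WAV as AAC, and stop at the shorter input.

```python
import shutil
import subprocess
from pathlib import Path
from typing import List


def build_mux_cmd(video: Path, audio: Path, out: Path) -> List[str]:
    """ffmpeg command: video stream copied as-is, WAV encoded to AAC."""
    return ['ffmpeg', '-y', '-i', str(video), '-i', str(audio),
            '-map', '0:v:0', '-map', '1:a:0',
            '-c:v', 'copy', '-c:a', 'aac', '-shortest', str(out)]


def mux_outputs(outdir: Path) -> List[Path]:
    """Mux each <stem>.mp4 with <stem>.wav into <stem>_av.mp4; no-op without ffmpeg."""
    if shutil.which('ffmpeg') is None:
        return []  # ffmpeg not on PATH
    written = []
    for video in sorted(outdir.glob('*.mp4')):
        if video.stem.endswith('_av'):
            continue  # already muxed
        audio = video.with_suffix('.wav')
        if not audio.exists():
            continue  # no matching WAV for this clip
        out = video.with_name(video.stem + '_av.mp4')
        subprocess.run(build_mux_cmd(video, audio, out), check=True)
        written.append(out)
    return written
```

Usage: `mux_outputs(Path('outputs/day06'))` returns the list of `*_av.mp4` files it wrote.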


## Troubleshooting quickies
- **`expandable_segments` crash**: don’t set it; use `PYTORCH_CUDA_ALLOC_CONF="max_split_size_mb:128"` only.
- **Temporal VAE `decode()`**: `AutoencoderKLTemporalDecoder.decode()` needs an explicit `num_frames`; this scaffold passes `num_frames=1`.
- **dtype mismatch (Half vs Float)**: keep model & VAEs in **fp32** for inference in this scaffold.
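- **Silent WAV**: raise the loudness knobs (export them before starting Jupyter; the setup cell's `setdefault` won't override values that are already set):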
  ```bash
  export JOINTDIT_MEL_DB_RANGE=80
  export JOINTDIT_AUDIO_GAIN_DB=6
  ```
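
How the script applies these two knobs isn't shown here. As an assumption, a "dB range" commonly acts as a floor at `max - range` on the mel values (like a `top_db` clamp), and a gain in dB scales amplitude by `10 ** (dB / 20)`, so +6 dB roughly doubles it. The sketch below shows those operations plus peak normalization on plain lists; the helper names are hypothetical and the order in which the script combines them is unknown.

```python
import math
from typing import List


def clamp_db(mel_db: List[float], db_range: float) -> List[float]:
    """Floor dB values at (max - db_range), like a top_db clamp."""
    floor = max(mel_db) - db_range
    return [max(v, floor) for v in mel_db]


def db_to_gain(db: float) -> float:
    """Amplitude factor for a gain given in dB (20 * log10 convention)."""
    return 10.0 ** (db / 20.0)


def apply_gain(wave: List[float], gain_db: float, limit: float = 1.0) -> List[float]:
    """Boost by gain_db, then hard-clip to [-limit, limit] to avoid overflow."""
    g = db_to_gain(gain_db)
    return [max(-limit, min(limit, s * g)) for s in wave]


def peak_normalize(wave: List[float], peak: float = 0.99) -> List[float]:
    """Scale so the largest |sample| equals `peak`; all-zero input is returned unchanged."""
    m = max((abs(s) for s in wave), default=0.0)
    if m == 0.0:
        return list(wave)
    return [s * peak / m for s in wave]
```

Note that peak normalization applied after a gain cancels the gain, so a boost only changes the final level if it comes after (or instead of) normalization.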
