# Lab 3 Sample Gallery

Generate qualitative samples from any Lab 3 run checkpoint.

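This notebook was set up for `automatedruns20`; by default it auto-selects the newest `runN` folder (one containing `run_state.json`) under `saves2/lab3_synthesis`.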
It exports:
- generated WAV files
- optional WAV files of the real source clips (controlled by `WRITE_REAL_AUDIO`)
- mel comparison plots
- `generation_summary.csv`


In [None]:
from pathlib import Path
import sys
import json
import random
import shutil
import importlib

# Make the `src` package of lab 3 importable: the notebook may be launched either
# from inside the `lab 3` directory or from the repository root.
cwd = Path.cwd()
_lab3_root = None
if (cwd / 'src').exists() and (cwd / 'run_lab3.py').exists():
    _lab3_root = cwd  # already inside `lab 3`
elif (cwd / 'lab 3' / 'src').exists():
    _lab3_root = cwd / 'lab 3'  # at the repository root
if _lab3_root is not None:
    sys.path.insert(0, str(_lab3_root))

import numpy as np
import pandas as pd
import torch
import librosa
import matplotlib.pyplot as plt

try:
    import soundfile as sf
    HAS_SF = True
except Exception:
    HAS_SF = False

from IPython.display import Audio, display

from src.lab3_data import load_cache, stratified_split_indices
from src.lab3_models import ReconstructionDecoder
from src.lab3_bridge import FrozenLab1Encoder, denormalize_log_mel
from src.lab3_train import load_target_centroids, build_condition_bank
from src.lab3_sampling import export_posttrain_samples

# Bind `torch._utils` as an attribute of `torch` when it is not already exposed.
# NOTE(review): presumably a compatibility shim for loading older checkpoints — confirm.
if not hasattr(torch, '_utils'):
    setattr(torch, '_utils', importlib.import_module('torch._utils'))

for _label, _value in (('torch:', torch.__version__), ('soundfile_available:', HAS_SF)):
    print(_label, _value)


## Config

Set `RUN_DIR` to any run directory under `saves2/lab3_synthesis`, or leave it as `None` to auto-select the newest `runN` folder.


In [None]:
# Locate the lab 3 directory and the repository root from the launch directory.
cwd = Path.cwd()
if (cwd / 'src').exists() and (cwd / 'run_lab3.py').exists():
    # launched from inside `lab 3`
    LAB3_DIR, REPO_ROOT = cwd, cwd.parent
elif (cwd / 'lab 3' / 'src').exists():
    # launched from the repository root
    LAB3_DIR, REPO_ROOT = cwd / 'lab 3', cwd
else:
    raise RuntimeError('Run this notebook from repo root or from lab 3 directory.')

# --- run selection ---
SAVES_ROOT = REPO_ROOT / 'saves2' / 'lab3_synthesis'
RUN_DIR = None  # explicit run directory, or None to auto-select the newest runN
CHECKPOINT_NAME = 'stage2_latest.pt'

# --- output ---
OUTPUT_TAG = 'posttrain_samples'
OVERWRITE_OUTPUT = False
WRITE_REAL_AUDIO = True

# --- sampling ---
N_GENERATION_SAMPLES = 100
VAL_RATIO = 0.15
SEED = 328
TARGET_MODE = 'balanced_random'  # one of: 'balanced_random', 'round_robin', 'random'
TARGET_GENRE_ORDER = ['hiphop_xtc', 'lofi_hh_lfbb', 'baroque_classical', 'cc0_other']

# --- synthesis / hardware ---
GL_ITERS = 48  # Griffin-Lim iterations used when exporting audio
DEVICE = 'auto'  # 'auto' selects cuda when available, otherwise cpu

def find_latest_runN(saves_root: Path):
    """Return the ``run<N>`` directory with the largest N under *saves_root*.

    Only directories whose name is ``run`` followed by digits and that contain
    a ``run_state.json`` file are considered. Returns ``None`` when
    *saves_root* does not exist or holds no such directory. Among directories
    with equal N, the one seen last during iteration wins.
    """
    if not saves_root.exists():
        return None
    best_num = None
    best_dir = None
    for entry in saves_root.iterdir():
        suffix = entry.name[3:]
        looks_like_run = entry.is_dir() and entry.name.startswith('run') and suffix.isdigit()
        if not (looks_like_run and (entry / 'run_state.json').exists()):
            continue
        num = int(suffix)
        # `>=` keeps the last of equal N, like a stable sort followed by [-1]
        if best_num is None or num >= best_num:
            best_num, best_dir = num, entry
    return best_dir

# Resolve RUN_DIR: auto-select the newest runN when unset, otherwise normalize the
# user-provided value to a Path.
if RUN_DIR is None:
    latest = find_latest_runN(SAVES_ROOT)
    if latest is None:
        raise FileNotFoundError(f'No runN folders with run_state.json under {SAVES_ROOT}')
    RUN_DIR = latest
else:
    # Accept a str as well as a Path: later cells build paths with `RUN_DIR / ...`,
    # which raises TypeError on a plain string.
    RUN_DIR = Path(RUN_DIR).expanduser()

print('LAB3_DIR:', LAB3_DIR)
print('RUN_DIR:', RUN_DIR)


In [None]:
# Paths to the artifacts this notebook needs from the selected run.
run_state_path = RUN_DIR / 'run_state.json'
cache_dir = RUN_DIR / 'cache'
ckpt_path = RUN_DIR / 'checkpoints' / CHECKPOINT_NAME

# Fail early, with a specific message, if any required artifact is missing.
for _required, _what in (
    (run_state_path, 'Missing run_state.json'),
    (cache_dir, 'Missing cache dir'),
    (ckpt_path, 'Missing checkpoint'),
):
    if not _required.exists():
        raise FileNotFoundError(f'{_what}: {_required}')

run_state = json.loads(run_state_path.read_text(encoding='utf-8'))
run_cfg = run_state.get('config') if isinstance(run_state, dict) else None
if not isinstance(run_cfg, dict):
    # `config` may be absent, null, or malformed; fall back to the defaults used
    # when reading individual keys so later `run_cfg.get(...)` calls stay valid.
    run_cfg = {}

# Load the run's cached index table, feature arrays and genre vocabulary,
# then build the index -> genre lookup.
idx_df, arrays, genre_to_idx = load_cache(cache_dir)
idx_to_genre = dict(zip(genre_to_idx.values(), genre_to_idx.keys()))

def _cfg_value(key, default):
    """Return run_cfg[key], or *default* when the key is absent, null, or an empty string."""
    value = run_cfg.get(key)
    return default if value is None or value == '' else value


def _parse_int_tuple(value) -> tuple:
    """Parse '3,7,11' or [3, 7, 11] into (3, 7, 11), skipping blank entries."""
    parts = value if isinstance(value, (list, tuple)) else str(value).split(',')
    return tuple(int(str(p).strip()) for p in parts if str(p).strip())


# Upstream artifacts: paths recorded in the run config, with repo defaults as fallback.
lab1_ckpt = Path(_cfg_value('lab1_checkpoint', REPO_ROOT / 'saves' / 'lab1_run_combo_af_gate_exit_v2' / 'latest.pt'))
lab2_centroids = Path(_cfg_value('lab2_centroids_json', REPO_ROOT / 'saves' / 'lab2_calibration' / 'target_centroids.json'))

# Generator architecture options; a null value no longer turns into the string 'None'.
g_norm = str(_cfg_value('generator_norm', 'instance'))
g_upsample = str(_cfg_value('generator_upsample', 'transpose'))
g_sn = bool(run_cfg.get('generator_spectral_norm', False))
g_mrf = bool(run_cfg.get('generator_mrf', False))
# Kernels may be stored as a comma-separated string or as a JSON list.
g_mrf_kernels = _parse_int_tuple(_cfg_value('generator_mrf_kernels', '3,7,11'))

# Resolve the compute device ('auto' -> cuda when available, else cpu).
device_t = ('cuda' if torch.cuda.is_available() else 'cpu') if DEVICE == 'auto' else DEVICE

# Frozen Lab 1 encoder; its config also provides the audio sample rate.
encoder = FrozenLab1Encoder(lab1_ckpt, device=device_t)
sr = int(encoder.cfg.sample_rate)

# Per-genre conditioning vectors built from the Lab 2 target centroids.
centroids = load_target_centroids(lab2_centroids)
cond_bank = build_condition_bank(genre_to_idx, centroids)
cond_bank = cond_bank.to(device_t)

# Build the generator with the architecture recorded in the run config; input and
# output sizes come from the cached arrays and the condition bank.
mel_shape = arrays['mel_norm'].shape
G = ReconstructionDecoder(
    zc_dim=int(arrays['z_content'].shape[1]),
    cond_dim=int(cond_bank.shape[1]),
    n_mels=int(mel_shape[1]),
    n_frames=int(mel_shape[2]),
    norm=g_norm,
    upsample=g_upsample,
    spectral_norm=g_sn,
    mrf=g_mrf,
    mrf_kernels=g_mrf_kernels,
).to(device_t)

# Load only checkpoint tensors whose name and shape match this architecture, so a
# partially mismatching checkpoint still loads what it can.
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
payload = torch.load(ckpt_path, map_location='cpu')
incoming = payload.get('generator', {}) if isinstance(payload, dict) else {}
current = G.state_dict()
filtered = {
    k: v for k, v in incoming.items()
    if k in current and tuple(v.shape) == tuple(current[k].shape)
}
dropped = [k for k in incoming if k not in filtered]
if not filtered:
    # Sampling from an untouched (randomly initialised) generator would silently
    # produce meaningless audio, so stop here instead.
    raise RuntimeError(
        f'No generator weights in {ckpt_path} match the current model '
        f'({len(incoming)} incoming keys, {len(dropped)} dropped); '
        'check CHECKPOINT_NAME and the run config.'
    )
missing, unexpected = G.load_state_dict(filtered, strict=False)

G.eval()

print('[load] device:', device_t)
print('[load] generator_norm:', g_norm, 'upsample:', g_upsample, 'mrf:', g_mrf, 'sn:', g_sn)
print('[load] genres:', list(genre_to_idx.keys()))
print('[load] checkpoint keys kept:', len(filtered), 'dropped:', len(dropped))
if dropped:
    print('[load] dropped keys (first 8):', dropped[:8])
if missing:
    # These model entries were not in the filtered checkpoint and keep their init values.
    print('[load] model keys not loaded (first 8):', list(missing)[:8])

# Stratified validation split; if it comes back empty, use every cached item instead.
train_idx, val_idx = stratified_split_indices(arrays['genre_idx'], val_ratio=VAL_RATIO, seed=SEED)
if not len(val_idx):
    val_idx = np.arange(len(arrays['genre_idx']))

# Seeded draw of validation rows; in this notebook `picked_idx` is only used for
# the count printed below (the export cell receives `val_idx` and the seed).
n_pick = min(int(N_GENERATION_SAMPLES), len(val_idx))
rng = np.random.default_rng(SEED)
picked_idx = rng.choice(val_idx, size=n_pick, replace=False)

# Mapping from Lab 3 genre name to Lab 1 source-class index, if the run recorded one.
if isinstance(run_state, dict):
    source_map = run_state.get('genre_to_lab1_source_idx', {})
else:
    source_map = {}

print('[load] selected samples:', len(picked_idx))
print('[load] source-map available for genres:', source_map)


In [None]:
# dB range constants (lower and upper bound) for mel spectrograms.
MEL_DB_MIN, MEL_DB_MAX = -80.0, 0.0

def mel_norm_to_db_np(mel_norm_np: np.ndarray) -> np.ndarray:
    """Map one normalized log-mel array back to dB with `denormalize_log_mel`.

    A leading batch axis is added for the torch helper and squeezed off again;
    the result is returned as a float32 NumPy array on the CPU.
    """
    batched = torch.from_numpy(mel_norm_np).unsqueeze(0)
    mel_db = denormalize_log_mel(batched).squeeze(0)
    return mel_db.cpu().numpy().astype(np.float32)

def mel_db_to_audio(
    mel_db: np.ndarray,
    sr: int,
    gl_iters: int,
    *,
    n_fft: int = 1024,
    hop_length: int = 256,
    win_length: int = 1024,
    fmin: float = 20,
    fmax: 'float | None' = None,
) -> np.ndarray:
    """Invert a dB mel spectrogram to a peak-normalized waveform with Griffin-Lim.

    Args:
        mel_db: mel spectrogram in dB; converted to power with
            ``librosa.db_to_power`` (reference 1.0).
        sr: sample rate in Hz.
        gl_iters: number of Griffin-Lim iterations.
        n_fft, hop_length, win_length, fmin: STFT/mel parameters. Defaults keep
            the previously hard-coded values. NOTE(review): these should match
            the feature extraction that produced ``mel_db`` — confirm against
            the Lab 1 encoder config.
        fmax: upper mel frequency; ``None`` means ``sr // 2``.

    Returns:
        float32 waveform scaled so its peak magnitude is ~1.0; an all-zero
        signal is returned unscaled.
    """
    mel_power = librosa.db_to_power(mel_db)
    y = librosa.feature.inverse.mel_to_audio(
        mel_power,
        sr=sr,
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        fmin=fmin,
        fmax=sr // 2 if fmax is None else fmax,
        n_iter=int(gl_iters),
    )
    # Compute the peak once (and guard the empty case, where np.max would raise).
    peak = np.max(np.abs(y)) if y.size else 0.0
    if peak > 0:
        y = y / (peak + 1e-8)
    return y.astype(np.float32)

def choose_target_genre(
    sample_i: int,
    src_genre: str,
    all_genres: list[str],
    rng: np.random.Generator,
    *,
    target_mode: 'str | None' = None,
    genre_order: 'list[str] | None' = None,
) -> str:
    """Pick the target genre for the ``sample_i``-th generated sample.

    Candidates are the distinct entries of *genre_order* (default: the global
    ``TARGET_GENRE_ORDER``) that appear in *all_genres*, kept in that order. If
    none match, the distinct entries of *all_genres* are used instead.

    With mode ``'random'`` a candidate is drawn with *rng*. Any other mode
    (including ``'balanced_random'``) cycles through the candidates by
    ``sample_i``. When there is more than one candidate, the source genre is
    never returned: a random pick is redrawn from the other candidates, and a
    cyclic pick advances by one.

    Args:
        sample_i: index of the sample being generated.
        src_genre: genre of the source clip.
        all_genres: genres available in the run.
        rng: generator used for the ``'random'`` mode.
        target_mode: overrides the global ``TARGET_MODE`` when given.
        genre_order: overrides the global ``TARGET_GENRE_ORDER`` when given.
    """
    mode = TARGET_MODE if target_mode is None else target_mode
    order = TARGET_GENRE_ORDER if genre_order is None else genre_order

    # dict.fromkeys dedupes while keeping order; duplicates would otherwise let the
    # "advance by one" fallback land on the source genre again.
    candidates = list(dict.fromkeys(g for g in order if g in all_genres))
    if not candidates:
        candidates = list(dict.fromkeys(all_genres))

    if mode == 'random':
        tgt = str(rng.choice(candidates))
    else:
        # NOTE(review): 'balanced_random' is handled as plain round-robin here;
        # confirm this matches how export_posttrain_samples interprets it.
        tgt = candidates[sample_i % len(candidates)]

    if len(candidates) > 1 and tgt == src_genre:
        if mode == 'random':
            others = [g for g in candidates if g != src_genre]
            tgt = str(rng.choice(others))
        else:
            tgt = candidates[(sample_i + 1) % len(candidates)]
    return tgt

def cosine_np(a: np.ndarray, b: np.ndarray) -> float:
    """Cosine similarity of two vectors as a Python float.

    A small epsilon is added to each norm, so a zero vector yields 0.0 rather
    than a division error.
    """
    denom = (np.linalg.norm(a) + 1e-8) * (np.linalg.norm(b) + 1e-8)
    return float(np.dot(a, b) / denom)

def style_eval_for_target(fake_db: np.ndarray, target_genre: str) -> tuple[str, float]:
    """Score a generated dB mel with the frozen Lab 1 style head.

    Returns ``(predicted_source_name, confidence)``. ``confidence`` is the
    softmax probability of the Lab 1 source mapped to *target_genre* via the
    global ``source_map``. When no valid mapping exists, it falls back to the
    highest class probability, which is then not target-specific.
    """
    with torch.no_grad():
        batch = torch.from_numpy(fake_db).unsqueeze(0).to(device_t)
        style_logits = encoder.model(batch, grl_lambda=0.0)['style_logits'][0]
        probs = torch.softmax(style_logits, dim=0).detach().cpu().numpy()

    mapped = int(source_map.get(target_genre, -1)) if isinstance(source_map, dict) else -1
    conf = float(probs[mapped]) if 0 <= mapped < len(probs) else float(np.max(probs))

    top = int(np.argmax(probs))
    source_by_idx = {idx: name for name, idx in encoder.source_to_idx.items()}
    return source_by_idx.get(top, str(top)), conf


In [None]:
# Export generated (and optionally real) audio plus a summary CSV for this run.
export_dir = RUN_DIR / 'samples' / OUTPUT_TAG
if OVERWRITE_OUTPUT and export_dir.exists():
    shutil.rmtree(export_dir)  # start from an empty output folder

if isinstance(source_map, dict):
    genre_to_source_idx = {str(g): int(v) for g, v in source_map.items()}
else:
    genre_to_source_idx = None

export_kwargs = dict(
    generator=G,
    frozen_encoder=encoder,
    arrays=arrays,
    index_df=idx_df,
    genre_to_idx=genre_to_idx,
    cond_bank=cond_bank,
    out_dir=export_dir,
    val_idx=val_idx,
    n_samples=int(N_GENERATION_SAMPLES),
    target_mode=str(TARGET_MODE),
    griffin_lim_iters=int(GL_ITERS),
    seed=int(SEED),
    device=device_t,
    genre_to_source_idx=genre_to_source_idx,
    write_real_audio=bool(WRITE_REAL_AUDIO),
)
sample_info = export_posttrain_samples(**export_kwargs)

print('[export]', sample_info)
summary_csv = Path(sample_info['summary_csv'])
samples_df = pd.read_csv(summary_csv)
print('[export] rows:', len(samples_df))
samples_df.head(10)


In [None]:
# Summarize the exported samples from the CSV on disk (works without re-running export).
summary_csv = RUN_DIR / 'samples' / OUTPUT_TAG / 'generation_summary.csv'
if not summary_csv.exists():
    raise FileNotFoundError(summary_csv)

df = pd.read_csv(summary_csv)
has_style_conf = 'style_conf_target' in df.columns
print('Rows:', len(df))
if len(df):
    print('Mean MPS:', float(df['mps_cosine'].mean()))
    if has_style_conf:
        # Rows without a style score (NaN) are excluded, matching the per-genre means
        # below; counting them as 0.0 would make the two summaries disagree.
        n_missing = int(df['style_conf_target'].isna().sum())
        print('Mean style_conf_target:', float(df['style_conf_target'].mean()), f'(missing: {n_missing})')

display(df.groupby(['source_genre', 'target_genre']).size().rename('count').reset_index())
display(df.groupby('target_genre')['mps_cosine'].mean().rename('mean_mps').reset_index())
if has_style_conf:
    display(df.groupby('target_genre')['style_conf_target'].mean().rename('mean_style_conf').reset_index())


In [None]:
# Listen to the first few exported pairs (real reference vs generated).
summary_csv = RUN_DIR / 'samples' / OUTPUT_TAG / 'generation_summary.csv'
if not summary_csv.exists():
    raise FileNotFoundError(summary_csv)
df = pd.read_csv(summary_csv)
preview_n = min(3, len(df))


def _existing_wav(value):
    """Return *value* as a Path if it is a non-empty string naming an existing file, else None."""
    if isinstance(value, str) and value:
        path = Path(value)
        if path.exists():
            return path
    return None


for i, (_, r) in enumerate(df.head(preview_n).iterrows()):
    conf = r.get('style_conf_target', float('nan'))
    print(f"[{i}] {r['source_genre']} -> {r['target_genre']} | mps={r['mps_cosine']:.3f} | style_conf={conf:.3f}")
    # `.get` tolerates a CSV without a real/fake wav column instead of raising KeyError.
    real_wav = _existing_wav(r.get('real_wav'))
    fake_wav = _existing_wav(r.get('fake_wav'))
    if real_wav is not None:
        print('Real:')
        display(Audio(filename=str(real_wav)))
    if fake_wav is not None:
        print('Generated:')
        display(Audio(filename=str(fake_wav)))
