# Lab 3 Sample Gallery (Codec + Mel)

This notebook is now codec-aware and defaults to the codec run path.

Use it to:
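- load an existing codec run from `lab3_codec_transfer` (a `lab3_synthesis` mel run can be pointed at, but the cells below currently stop with an error for non-codec runs)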
- optionally generate fresh samples from checkpoint
- listen to **source chunk vs generated** side-by-side
- compare quick similarity metrics (wave cosine / MFCC cosine).

In [None]:
from pathlib import Path
import sys
import json
import random
import importlib

import numpy as np
import pandas as pd
import torch
import librosa
import matplotlib.pyplot as plt

try:
    import soundfile as sf
    HAS_SF = True
except Exception:
    HAS_SF = False

from IPython.display import Audio, display

# Locate the lab directory and the repository root from the launch directory.
# Two layouts are accepted: launched inside the lab directory itself (it has
# `src/` plus one of the run scripts), or launched from the repo root (it has
# a `lab 3/src/` folder).
cwd = Path.cwd()
if (cwd / 'src').exists() and any(
    (cwd / script).exists() for script in ('run_lab3.py', 'run_lab3_codec.py')
):
    LAB3_DIR, REPO_ROOT = cwd, cwd.parent
elif (cwd / 'lab 3' / 'src').exists():
    LAB3_DIR, REPO_ROOT = cwd / 'lab 3', cwd
else:
    raise RuntimeError('Run from repo root or lab 3 directory.')

# Put the lab directory first on sys.path so `src.*` imports resolve.
if str(LAB3_DIR) not in sys.path:
    sys.path.insert(0, str(LAB3_DIR))

# Attach torch._utils explicitly when the attribute is missing on the torch module.
if not hasattr(torch, '_utils'):
    torch._utils = importlib.import_module('torch._utils')

import subprocess
import importlib.util

def _ensure_pkg(mod_name: str, pip_name: str | None = None):
    if importlib.util.find_spec(mod_name) is None:
        subprocess.check_call([sys.executable, '-m', 'pip', 'install', pip_name or mod_name])

# Make sure these two packages are importable before the project imports
# below (pip-installs them when missing).
# NOTE(review): presumably required by the codec modules — confirm.
_ensure_pkg('transformers')
_ensure_pkg('sentencepiece')

# Project-local modules; they resolve because LAB3_DIR was inserted into
# sys.path earlier in this cell.
from src.lab3_data import load_cache, stratified_split_indices, stratified_group_split_indices
from src.lab3_codec_data import load_codec_cache
from src.lab3_codec_models import CodecLatentTranslator
from src.lab3_codec_bridge import FrozenEncodec
from src.lab3_codec_train import build_style_centroid_bank, build_style_exemplar_bank
from src.lab3_sampling import resolve_next_run_name

# Quick environment report: torch version and whether soundfile imported.
print('torch:', torch.__version__)
print('soundfile_available:', HAS_SF)

## Config

In [None]:
# Defaults now target codec pipeline run1005
# RUN_KIND: 'auto' is resolved in the next cell from the files present in
# RUN_DIR; a later cell raises for anything that does not resolve to 'codec'.
RUN_KIND = 'codec'   # 'codec' | 'mel' | 'auto'

# Root folders scanned for runN directories when RUN_DIR is None.
CODEC_RUNS_ROOT = REPO_ROOT / 'saves2' / 'lab3_codec_transfer'
MEL_RUNS_ROOT = REPO_ROOT / 'saves2' / 'lab3_synthesis'

RUN_DIR = CODEC_RUNS_ROOT / 'run1005'  # set None to auto-pick latest runN in selected root
# Outputs go to RUN_DIR / 'samples' / OUTPUT_TAG.
OUTPUT_TAG = 'posttrain_samples'

# Generation runs only when GENERATE_NEW_SAMPLES is True and either
# FORCE_REGENERATE is True or no generation_summary.csv exists yet.
GENERATE_NEW_SAMPLES = False
FORCE_REGENERATE = False
# Number of clips produced by one generation pass.
N_GENERATION_SAMPLES = 50
# VAL_RATIO / SEED are fallbacks used when the run config has no
# 'val_ratio' / 'seed'; SEED also seeds the notebook RNG and preview sampling.
VAL_RATIO = 0.15
SEED = 328

# Style conditioning: any mode other than 'centroid' or 'exemplar' is treated as 'mix'.
CODEC_COND_MODE = 'mix'   # centroid | exemplar | mix
# In 'mix' mode: weight on the genre centroid; (1 - alpha) goes to the exemplar.
CODEC_COND_ALPHA = 0.35
# Std of the Gaussian noise added to the style vector before normalisation (0 disables it).
CODEC_STYLE_JITTER_STD = 0.03

# Preview cell: show PREVIEW_N rows, sampled randomly or taken from the top.
PREVIEW_RANDOM = True
PREVIEW_N = 8
# When True, the matching source chunk is written next to each generated clip.
EXPORT_SOURCE_PAIRS = True

# 'auto' picks 'cuda' when torch reports it available, else 'cpu'; any other value is used as given.
DEVICE = 'auto'

def _latest_run(root: Path):
    if not root.exists():
        return None
    cands = []
    for d in root.iterdir():
        if d.is_dir() and d.name.startswith('run') and d.name[3:].isdigit() and (d / 'run_state.json').exists():
            cands.append((int(d.name[3:]), d))
    if not cands:
        return None
    cands.sort(key=lambda x: x[0])
    return cands[-1][1]

In [None]:
# Resolve the run directory, load its saved state, and decide the run kind.
if RUN_DIR is None:
    root = CODEC_RUNS_ROOT if RUN_KIND in ('codec', 'auto') else MEL_RUNS_ROOT
    RUN_DIR = _latest_run(root)
else:
    # Accept a plain string path as well as a Path.
    RUN_DIR = Path(RUN_DIR)

if RUN_DIR is None or not RUN_DIR.exists():
    raise FileNotFoundError(f'RUN_DIR not found: {RUN_DIR}')

run_state_path = RUN_DIR / 'run_state.json'
if not run_state_path.exists():
    raise FileNotFoundError(run_state_path)

run_state = json.loads(run_state_path.read_text(encoding='utf-8'))
cfg = run_state.get('config', {}) if isinstance(run_state, dict) else {}
if not isinstance(cfg, dict):
    # e.g. "config": null in the JSON; later cells call cfg.get(...).
    cfg = {}

if RUN_KIND == 'auto':
    # A codec run is recognised by its codec cache plus any stage checkpoint.
    # The checkpoint loader further down falls back stage3 -> stage2 -> stage1,
    # so requiring stage3 here would mislabel runs that stopped earlier.
    has_codec_cache = (RUN_DIR / 'cache' / 'codec_cache_arrays.npz').exists()
    has_stage_ckpt = any(
        (RUN_DIR / 'checkpoints' / ckpt_name).exists()
        for ckpt_name in ('stage3_latest.pt', 'stage2_latest.pt', 'stage1_latest.pt')
    )
    run_kind = 'codec' if (has_codec_cache and has_stage_ckpt) else 'mel'
else:
    run_kind = RUN_KIND

print('RUN_DIR:', RUN_DIR)
print('run_kind:', run_kind)
# Same dict guard as for cfg: a non-dict run_state has no .get().
print('run_name:', run_state.get('run_name', RUN_DIR.name) if isinstance(run_state, dict) else RUN_DIR.name)

In [None]:
# Seeded NumPy generator shared by the sampling code in later cells.
rng = np.random.default_rng(int(SEED))

# 'auto' becomes 'cuda' when torch reports it available, otherwise 'cpu';
# any other DEVICE value is passed through untouched.
device_t = DEVICE
if device_t == 'auto':
    device_t = 'cuda' if torch.cuda.is_available() else 'cpu'

# Output folder for this tag (created straight away) and the CSV indexing the clips.
samples_dir = RUN_DIR.joinpath('samples', OUTPUT_TAG)
samples_dir.mkdir(parents=True, exist_ok=True)
summary_csv = samples_dir.joinpath('generation_summary.csv')

# Everything after this point assumes a codec run.
if not run_kind == 'codec':
    raise RuntimeError('This notebook is now intended for codec runs. Set RUN_KIND=codec or auto with codec run dir.')

# Read the cached codec arrays and index table for this run.
cache_dir = RUN_DIR / 'cache'
idx_df, arrays, genre_to_idx, cache_meta = load_codec_cache(cache_dir)
n_genres = len(genre_to_idx)
# Reverse lookup: genre index -> genre name.
idx_to_genre = {int(genre_idx): str(genre_name) for genre_name, genre_idx in genre_to_idx.items()}

# Train/val split used for sampling. Ratio and seed come from the run config,
# falling back to the notebook constants when the config lacks them.
_split_kwargs = {
    'val_ratio': float(cfg.get('val_ratio', VAL_RATIO)),
    'seed': int(cfg.get('seed', SEED)),
}
if 'track_id' in idx_df.columns and bool(cfg.get('split_by_track', True)):
    # Group-aware split keyed on track_id.
    train_idx, val_idx = stratified_group_split_indices(
        arrays['genre_idx'],
        idx_df['track_id'].astype(str).to_numpy(),
        **_split_kwargs,
    )
else:
    train_idx, val_idx = stratified_split_indices(arrays['genre_idx'], **_split_kwargs)

# Style banks: the centroid bank is built from every cached row, the exemplar
# bank from the training rows only.
style_centroid_bank = build_style_centroid_bank(
    arrays['z_style'],
    arrays['genre_idx'],
    n_genres=n_genres,
).to(device_t)
style_exemplar_bank = build_style_exemplar_bank(
    arrays['z_style'][train_idx],
    arrays['genre_idx'][train_idx],
    n_genres=n_genres,
)

# Load codec model + checkpoint
# The codec wrapper is configured from the run config; the literals below are
# the fallbacks used when a key is absent.
codec = FrozenEncodec(
    model_id=str(cfg.get('codec_model_id', 'facebook/encodec_24khz')),
    bandwidth=float(cfg.get('codec_bandwidth', 6.0)),
    chunk_seconds=float(cfg.get('codec_chunk_seconds', 5.0)),
    device=device_t,
)

# Translator network. Channel / latent sizes are read from axis 1 of the
# cached arrays; the remaining hyper-parameters come from the run config.
G = CodecLatentTranslator(
    in_channels=int(arrays['q_emb'].shape[1]),
    z_content_dim=int(arrays['z_content'].shape[1]),
    z_style_dim=int(arrays['z_style'].shape[1]),
    hidden_channels=int(cfg.get('translator_hidden_channels', 256)),
    n_blocks=int(cfg.get('translator_blocks', 10)),
    noise_dim=int(cfg.get('translator_noise_dim', 32)),
    residual_scale=float(cfg.get('translator_residual_scale', 0.5)),
).to(device_t)

# Use the first checkpoint that exists, tried in stage3 -> stage2 -> stage1 order.
ckpt_candidates = [RUN_DIR / 'checkpoints' / 'stage3_latest.pt', RUN_DIR / 'checkpoints' / 'stage2_latest.pt', RUN_DIR / 'checkpoints' / 'stage1_latest.pt']
ckpt_path = next((p for p in ckpt_candidates if p.exists()), None)
if ckpt_path is None:
    raise FileNotFoundError('No stage checkpoint found in run checkpoints/')

# NOTE(review): weights_only=False allows arbitrary objects to be unpickled;
# only load checkpoints from a trusted source.
# The TypeError fallback is presumably for torch versions whose torch.load
# does not accept `weights_only` — confirm.
try:
    payload = torch.load(str(ckpt_path), map_location='cpu', weights_only=False)
except TypeError:
    payload = torch.load(str(ckpt_path), map_location='cpu')
# The payload must carry a 'generator' state dict; strict=True rejects any
# missing or unexpected keys.
G.load_state_dict(payload['generator'], strict=True)
G.eval()

print('Loaded checkpoint:', ckpt_path.name)
print('cache rows:', len(idx_df), 'val rows:', len(val_idx), 'genres:', genre_to_idx)

In [None]:
def _pick_target(src_g: int, n_genres: int, i: int) -> int:
    # balanced random-like by cycling + anti-clash
    tgt = int((i + src_g + 1) % n_genres)
    if tgt == int(src_g):
        tgt = int((tgt + 1) % n_genres)
    return tgt

def _z_tgt_for_genre(tgt_g: int):
    """Build the normalised target-style vector for genre index *tgt_g*.

    Combines the genre centroid with one randomly drawn exemplar according to
    ``CODEC_COND_MODE`` ('centroid', 'exemplar', anything else = weighted mix),
    optionally adds Gaussian jitter, and normalises along the last dimension.
    The exemplar is drawn from the shared ``rng`` whenever the genre has
    exemplars, regardless of mode; without exemplars the centroid stands in.
    """
    centroid = style_centroid_bank[tgt_g:tgt_g + 1].to(device_t).float()

    exemplars = style_exemplar_bank.get(int(tgt_g))
    if exemplars is not None and int(exemplars.shape[0]) != 0:
        pick = int(rng.integers(0, int(exemplars.shape[0])))
        exemplar = exemplars[pick:pick + 1].to(device_t).float()
    else:
        # No exemplars for this genre: fall back to the centroid.
        exemplar = centroid

    if CODEC_COND_MODE == 'centroid':
        style = centroid
    elif CODEC_COND_MODE == 'exemplar':
        style = exemplar
    else:
        alpha = float(CODEC_COND_ALPHA)
        style = alpha * centroid + (1.0 - alpha) * exemplar

    jitter_std = float(CODEC_STYLE_JITTER_STD)
    if jitter_std > 0.0:
        style = style + torch.randn_like(style) * jitter_std
    return torch.nn.functional.normalize(style, dim=-1)

# Optionally synthesise fresh translated clips and write generation_summary.csv.
# Runs only when generation is requested and either a regeneration is forced
# or no summary exists yet; otherwise the existing summary is reused.
if GENERATE_NEW_SAMPLES and (FORCE_REGENERATE or not summary_csv.exists()):
    # Fail early with a clear message: `sf` is undefined when the soundfile
    # import failed, and rng.integers(0, 0) raises on an empty split.
    if not HAS_SF:
        raise RuntimeError('soundfile is required to write generated samples; install it with `pip install soundfile`.')
    if len(val_idx) == 0:
        raise RuntimeError('Validation split is empty; no source chunks available for generation.')

    out_sr = int(codec.cfg.sample_rate)
    rows = []
    for i in range(int(N_GENERATION_SAMPLES)):
        # Source chunk: a uniformly drawn row of the validation split.
        ridx = int(val_idx[int(rng.integers(0, len(val_idx)))])
        src_g = int(arrays['genre_idx'][ridx])
        tgt_g = _pick_target(src_g, n_genres=n_genres, i=i)

        # Keep a leading batch axis of size 1 via the ridx:ridx+1 slice.
        q_src = torch.from_numpy(arrays['q_emb'][ridx:ridx+1]).to(device_t).float()
        zc = torch.from_numpy(arrays['z_content'][ridx:ridx+1]).to(device_t).float()
        z_tgt = _z_tgt_for_genre(tgt_g)

        with torch.no_grad():
            q_hat = G(q_src=q_src, z_content=zc, z_style_tgt=z_tgt)
            wav = codec.decode_embeddings(q_hat)[0, 0].detach().cpu().numpy().astype(np.float32)
        # Peak-normalise; the epsilon avoids dividing by zero on silent output.
        wav = wav / (np.max(np.abs(wav)) + 1e-8)

        out_wav = samples_dir / f'sample_{i:04d}_src{src_g}_tgt{tgt_g}.wav'
        sf.write(str(out_wav), wav, out_sr)

        rows.append({
            'sample_id': int(i),
            'cache_row': int(ridx),
            'source_genre_idx': int(src_g),
            'target_genre_idx': int(tgt_g),
            'source_genre': idx_to_genre.get(int(src_g), str(src_g)),
            'target_genre': idx_to_genre.get(int(tgt_g), str(tgt_g)),
            'fake_wav': str(out_wav),
        })

    pd.DataFrame(rows).to_csv(summary_csv, index=False)
    print('Generated samples ->', summary_csv)
else:
    print('Using existing sample summary:', summary_csv)

# Later cells read the summary, so stop here if it is missing.
if not summary_csv.exists():
    raise FileNotFoundError(f'Missing generation summary: {summary_csv}')

In [None]:
def _cosine(a, b) -> float:
    """Cosine similarity of two 1-D arrays; the epsilon keeps zero vectors finite."""
    return float(np.dot(a, b) / ((np.linalg.norm(a) * np.linalg.norm(b)) + 1e-12))


# Pair every generated clip with its source chunk and compute similarity metrics.
df = pd.read_csv(summary_csv)
# Back-fill genre names for summaries that only carry genre indices.
if 'source_genre' not in df.columns and 'source_genre_idx' in df.columns:
    df['source_genre'] = df['source_genre_idx'].map(lambda x: idx_to_genre.get(int(x), str(x)))
if 'target_genre' not in df.columns and 'target_genre_idx' in df.columns:
    df['target_genre'] = df['target_genre_idx'].map(lambda x: idx_to_genre.get(int(x), str(x)))

pair_dir = samples_dir / 'source_pairs'
pair_dir.mkdir(parents=True, exist_ok=True)

chunk_seconds = float(cfg.get('codec_chunk_seconds', 5.0))
sr = int(codec.cfg.sample_rate)

# Exporting source chunks needs soundfile; without it `sf` is undefined, so
# skip the export (metrics are still computed) instead of failing mid-loop.
export_pairs = bool(EXPORT_SOURCE_PAIRS) and HAS_SF
if EXPORT_SOURCE_PAIRS and not HAS_SF:
    print('soundfile not available: skipping source pair export.')

src_wavs = []
wave_cos = []
mfcc_cos = []

for _, r in df.iterrows():
    ridx = int(r['cache_row'])
    fake_path = Path(r['fake_wav'])
    # Cache index row for this sample: source file path and chunk start offset.
    meta = idx_df.iloc[ridx]
    src_path = Path(str(meta['path']))
    start_sec = float(meta.get('start_sec', 0.0))

    y_src, _ = librosa.load(str(src_path), sr=sr, mono=True, offset=max(0.0, start_sec), duration=max(0.2, chunk_seconds))
    y_fake, _ = librosa.load(str(fake_path), sr=sr, mono=True)
    # Compare over the common length only.
    n = min(len(y_src), len(y_fake))
    if n <= 1:
        # Not enough audio to compare: keep the row, mark metrics as missing.
        src_wavs.append('')
        wave_cos.append(np.nan)
        mfcc_cos.append(np.nan)
        continue

    y_src = y_src[:n].astype(np.float32)
    y_fake = y_fake[:n].astype(np.float32)

    src_out = pair_dir / f"sample_{int(r['sample_id']):04d}_source.wav"
    if export_pairs:
        sf.write(str(src_out), y_src, sr)
        src_wavs.append(str(src_out))
    else:
        src_wavs.append('')

    # Sample-wise waveform cosine.
    wave_cos.append(_cosine(y_src, y_fake))

    # Cosine between time-averaged 13-coefficient MFCC vectors.
    m1 = librosa.feature.mfcc(y=y_src, sr=sr, n_mfcc=13).mean(axis=1)
    m2 = librosa.feature.mfcc(y=y_fake, sr=sr, n_mfcc=13).mean(axis=1)
    mfcc_cos.append(_cosine(m1, m2))

df['source_wav'] = src_wavs
df['wave_cosine_src_fake'] = wave_cos
df['mfcc_cosine_src_fake'] = mfcc_cos

compare_csv = samples_dir / 'generation_compare_with_source.csv'
df.to_csv(compare_csv, index=False)
print('compare csv:', compare_csv)
print('rows:', len(df))
display(df[['sample_id','source_genre','target_genre','wave_cosine_src_fake','mfcc_cosine_src_fake']].head(12))

In [None]:
# Overall mean of each similarity metric.
for label, column in (
    ('Mean wave cosine (src vs fake):', 'wave_cosine_src_fake'),
    ('Mean MFCC cosine (src vs fake):', 'mfcc_cosine_src_fake'),
):
    print(label, float(df[column].mean()))

# How many samples exist for each (source, target) genre pair.
pair_counts = df.groupby(['source_genre','target_genre']).size().rename('count').reset_index()
display(pair_counts)

# Mean of both metrics per target genre.
per_target_means = df.groupby('target_genre')[['wave_cosine_src_fake','mfcc_cosine_src_fake']].mean().reset_index()
display(per_target_means)

In [None]:
# Pick up to PREVIEW_N rows (a seeded random sample, or the first rows) and
# render the source and generated clip of each one as audio players.
n_preview = min(int(PREVIEW_N), len(df))
view_df = df.copy()
if PREVIEW_RANDOM:
    view_df = view_df.sample(n_preview, random_state=int(SEED))
else:
    view_df = view_df.head(n_preview)
view_df = view_df.reset_index(drop=True)

for i, r in view_df.iterrows():
    sid = int(r['sample_id'])
    print(f"[{i}] sample={sid} | {r['source_genre']} -> {r['target_genre']} | wave_cos={r['wave_cosine_src_fake']:.3f} | mfcc_cos={r['mfcc_cosine_src_fake']:.3f}")

    # Source first, then the generated clip; a clip is shown only when its
    # path is a non-empty string pointing at an existing file.
    for heading, column in (('Source chunk:', 'source_wav'), ('Generated:', 'fake_wav')):
        raw = r.get(column, '')
        if not (isinstance(raw, str) and raw):
            continue
        clip = Path(raw)
        if clip.exists():
            print(heading)
            display(Audio(filename=str(clip)))
    print('-' * 70)