# Lab 3 - Reconstruction Decoder (Notebook Runner)

This notebook is configured to run Lab 3 end-to-end from top to bottom:

1. Configure run variables (single config cell).
2. Launch staged training (`stage1` then `stage2`) using `run_lab3.py`.
3. Load audit outputs.
4. Generate qualitative samples (mel plots + WAV files) for direct inspection.

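Use `RUN_MODE='fresh'` to start a new run or `RUN_MODE='resume'` to continue the existing run in `RESUME_DIR`. In fresh mode, leaving `RUN_NAME` empty lets training auto-create numbered run folders (`run1`, `run2`, ...).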


In [None]:
from pathlib import Path
import json
import shutil
import subprocess
import os

import numpy as np
import pandas as pd
import torch
import librosa
import matplotlib.pyplot as plt

try:
    import soundfile as sf
    HAS_SF = True
except Exception:
    HAS_SF = False

from IPython.display import Audio, display

from src.lab3_data import load_cache, stratified_split_indices
from src.lab3_models import ReconstructionDecoder
from src.lab3_bridge import FrozenLab1Encoder, denormalize_log_mel
from src.lab3_train import load_target_centroids, build_condition_bank
from src.lab3_sampling import export_posttrain_samples
import importlib

# Some notebook kernels end up with a partially initialized torch module that
# lacks the `_utils` submodule attribute; attach it explicitly when absent.
try:
    torch._utils
except AttributeError:
    torch._utils = importlib.import_module('torch._utils')
print('torch:', torch.__version__, 'file:', torch.__file__)



## Run Config

Set everything here, then run all cells.


In [None]:
# -----------------------------
# Core run controls
# -----------------------------
# All paths are derived from the kernel's working directory: it is used as the
# Lab 3 folder (and as the cwd for run_lab3.py), with the repository root one
# level above it.
LAB3_DIR = Path.cwd()
REPO_ROOT = LAB3_DIR.parent
SAVES_ROOT = REPO_ROOT.joinpath('saves2', 'lab3_synthesis')

def find_latest_run_dir(saves_root: Path):
    candidates = []
    for d in saves_root.iterdir() if saves_root.exists() else []:
        if not d.is_dir():
            continue
        rs = d / 'run_state.json'
        if rs.exists():
            candidates.append(d)
    if not candidates:
        return None
    return max(candidates, key=lambda p: (p / 'run_state.json').stat().st_mtime)

# RUN_MODE is forwarded to run_lab3.py as --mode and also decides how OUT_DIR
# is derived at the end of this cell.
RUN_MODE = 'fresh'  # 'fresh' or 'resume'
RUN_NAME = ''               # empty => strict auto numbered folders: run1, run2, ...
RESUME_DIR = SAVES_ROOT / 'run1'  # used only when RUN_MODE='resume'
CLEAN_START = False          # fresh mode only: delete an existing SAVES_ROOT/RUN_NAME folder first (ignored when RUN_NAME is empty)

# Optional cache reuse (speeds up iteration)
REUSE_CACHE_DIR = None  # Path or None; when set it is forwarded as --reuse-cache-dir

# -----------------------------
# Data/model paths
# -----------------------------
# Forwarded as --manifests-root, --lab1-checkpoint and --lab2-centroids-json.
# The two checkpoint/centroid paths are also the fallbacks used by the
# sample-export cell when run_state.json does not record them.
MANIFESTS_ROOT = Path('Z:/DataSets/_lab1_manifests')
LAB1_CHECKPOINT = REPO_ROOT / 'saves' / 'lab1_run_combo_af_gate_exit_v2' / 'latest.pt'
LAB2_TARGET_CENTROIDS = REPO_ROOT / 'saves' / 'lab2_calibration' / 'lab2_20260211_015118_lda_cleanup_v2' / 'target_centroids.json'

# -----------------------------
# Training scale
# -----------------------------
# Every value in this section is forwarded to run_lab3.py as the CLI flag of
# the same name (lower-cased, underscores replaced by dashes); their exact
# meaning is defined by that script.
SMOKE = False                # True appends --smoke
PER_GENRE_SAMPLES = 800
CHUNKS_PER_TRACK = 4
CHUNK_SAMPLING = 'uniform'
MIN_START_SEC = 0.0
MAX_START_SEC = None         # number of seconds, or None for no explicit value
SPLIT_BY_TRACK = True        # selects --split-by-track vs --no-split-by-track
STAGE2_COND_MODE = 'mix'
STAGE2_COND_ALPHA_START = 0.8
STAGE2_COND_ALPHA_END = 0.4
STAGE2_COND_EXEMPLAR_NOISE_STD = 0.03
STAGE2_TARGET_BALANCE = True  # selects --stage2-target-balance vs --no-stage2-target-balance
# VAL_RATIO and SEED are also reused by the sample-export cell to recompute the
# train/validation split.
VAL_RATIO = 0.15
N_FRAMES = 256
BATCH_SIZE = 32
NUM_WORKERS = 0
SEED = 328
DEVICE = 'auto'              # 'auto' | 'cuda' | 'cpu'
GENERATOR_NORM = 'instance'  # 'instance' | 'batch'
GENERATOR_SPECTRAL_NORM = False  # True adds --generator-spectral-norm
GENERATOR_UPSAMPLE = 'transpose'  # 'transpose' | 'pixelshuffle'
DISCRIMINATOR_ARCH = 'single'  # 'single' | 'multiscale'
DISCRIMINATOR_SCALES = 3

STAGE1_EPOCHS = 20
STAGE2_EPOCHS = 60
MAX_BATCHES_PER_EPOCH = None # int or None; forwarded as --max-batches-per-epoch only when not None

# -----------------------------
# Loss weights
# -----------------------------
# Optimizer learning rates and loss-term weights, forwarded one-to-one as CLI
# flags. Weights set to 0.0 below are still passed explicitly.
LR_G = 2e-4
LR_D = 2e-4
ADV_WEIGHT = 0.8
RECON_WEIGHT = 8.0
CONTENT_WEIGHT = 3.0
STYLE_WEIGHT = 10.0
CONTINUITY_WEIGHT = 1.0
MRSTFT_WEIGHT = 0.0
STAGE1_MRSTFT_WEIGHT = 0.0
STAGE2_MRSTFT_WEIGHT = 0.0
# ';'-separated groups of three comma-separated integers; presumably one STFT
# resolution per group - confirm the field order against run_lab3.py.
MRSTFT_RESOLUTIONS = '64,16,64;128,32,128;256,64,256'
FLATNESS_WEIGHT = 0.2
FEATURE_MATCH_WEIGHT = 0.0
PERCEPTUAL_WEIGHT = 0.0
STYLE_HINGE_WEIGHT = 0.0
CONTRASTIVE_WEIGHT = 0.0
BATCH_INFONCE_DIV_WEIGHT = 0.0
DIVERSITY_WEIGHT = 0.0
TIMBRE_BALANCE_WEIGHT = 0.0
LOWMID_RECON_WEIGHT = 0.0
SPECTRAL_TILT_WEIGHT = 0.0
ZCR_PROXY_WEIGHT = 0.0
STYLE_MID_WEIGHT = 0.0
HF_MUZZLE_WEIGHT = 0.0
HIGHPASS_ANCHOR_WEIGHT = 0.0
MEL_DIVERSITY_WEIGHT = 0.0
TARGET_PROFILE_WEIGHT = 0.0
# Stage-2 specific schedules and hyper-parameters (forwarded as --stage2-* flags).
STAGE2_D_LR_MULT = 0.5
STAGE2_CONTENT_START = 0.5
STAGE2_CONTENT_END = 0.2
STAGE2_STYLE_LABEL_SMOOTHING = 0.1
STAGE2_STYLE_ONLY_WARMUP_EPOCHS = 4
STAGE2_G_LR_WARMUP_EPOCHS = 5
STAGE2_G_LR_START_MULT = 0.3
STAGE2_COND_NOISE_STD = 0.0
STAGE2_STYLE_JITTER_STD = 0.0
STAGE2_STYLE_HINGE_TARGET_CONF = 0.85
STAGE2_ADAPTIVE_CONTENT = False  # True appends --stage2-adaptive-content
STAGE2_ADAPTIVE_CONTENT_LOW = 0.0
STAGE2_ADAPTIVE_CONTENT_HIGH = 0.4
STAGE2_ADAPTIVE_CONF_LOW = 0.30
STAGE2_ADAPTIVE_CONF_HIGH = 0.45
STAGE2_STYLE_CRITIC_LR = 2e-4
STAGE2_CONTRASTIVE_TEMP = 0.10
STAGE2_BATCH_INFONCE_TEMP = 0.15
STAGE2_DIVERSITY_MARGIN = 0.90
STAGE2_DIVERSITY_MAX_PAIRS = 128
STAGE2_MEL_DIVERSITY_MARGIN = 0.60
STAGE2_MEL_DIVERSITY_MAX_PAIRS = 192
STAGE2_STYLE_LOWPASS_KEEP_BINS = 80
STAGE2_STYLE_LOWPASS_CUTOFF_HZ = None  # number or None
STAGE2_STYLE_MID_LOW_BIN = 8
STAGE2_STYLE_MID_HIGH_BIN = 56
STAGE2_LOWMID_SPLIT_BIN = 80
STAGE2_LOWMID_GAIN = 7.0
STAGE2_HIGH_GAIN = 0.45
STAGE2_SPECTRAL_TILT_MAX_RATIO = 0.78
STAGE2_ZCR_PROXY_TARGET_MAX = 0.20
STAGE2_STYLE_THAW_LAST_EPOCHS = 0
STAGE2_STYLE_THAW_LR = 1e-6
STAGE2_STYLE_THAW_SCOPE = 'style_head'
RESET_STAGE2_OUT_LAYER = False  # True appends --reset-stage2-out-layer
# Label values forwarded as --d-real-label, --d-fake-label and --g-real-label.
D_REAL_LABEL = 1.0
D_FAKE_LABEL = 0.0
G_REAL_LABEL = 1.0

# -----------------------------
# Exit thresholds
# -----------------------------
# Forwarded as --mps-threshold, --sf-threshold and --eval-max-batches.
MPS_THRESHOLD = 0.90
SF_THRESHOLD = 0.85
EVAL_MAX_BATCHES = 30

# -----------------------------
# Generation preview settings
# -----------------------------
# NOTE(review): N_GENERATION_SAMPLES and TARGET_GENRE_ORDER are not referenced
# by any other cell visible in this notebook - confirm whether they are still needed.
N_GENERATION_SAMPLES = 6
TARGET_GENRE_ORDER = ['baroque_classical', 'hiphop_xtc', 'lofi_hh_lfbb', 'cc0_other']
# The POSTTRAIN_* values and GL_ITERS drive both the CLI's --auto-sample-export
# options and the notebook's own sample-export cell.
POSTTRAIN_SAMPLE_EXPORT_TAG = 'posttrain_samples'  # sub-folder name under OUT_DIR/samples
POSTTRAIN_SAMPLE_COUNT = 100
POSTTRAIN_SAMPLE_TARGET_MODE = 'balanced_random'
POSTTRAIN_SAMPLE_WRITE_REAL = True
GL_ITERS = 64  # Griffin-Lim iterations (forwarded as --sample-griffin-lim-iters)

# Derive the run directory from the controls above. Validate RUN_MODE first so
# the remaining branches only have to distinguish the two legal modes.
if RUN_MODE not in ('fresh', 'resume'):
    raise ValueError("RUN_MODE must be 'fresh' or 'resume'")

if RUN_MODE == 'resume':
    if RESUME_DIR is None:
        raise ValueError("RUN_MODE='resume' requires RESUME_DIR")
    OUT_DIR = Path(RESUME_DIR)
elif RUN_NAME:
    OUT_DIR = SAVES_ROOT / RUN_NAME
else:
    # No explicit name yet: point at the saves root until a run folder exists.
    OUT_DIR = SAVES_ROOT

# Bare expression so the notebook displays the resolved path.
OUT_DIR










In [None]:
# Prepare the run directory before training is launched.
if RUN_MODE == 'fresh':
    if RUN_NAME == '':
        print('[setup] RUN_NAME empty: training will auto-create numbered run directories (run1, run2, ...).')
    # Cleanup only ever targets a named run folder; requiring a non-empty
    # RUN_NAME keeps OUT_DIR from being SAVES_ROOT itself here.
    if CLEAN_START and RUN_NAME and OUT_DIR.exists():
        print(f'[setup] removing existing run dir: {OUT_DIR}')
        shutil.rmtree(OUT_DIR)
elif RUN_MODE == 'resume':
    OUT_DIR.mkdir(parents=True, exist_ok=True)

print('[setup] target run_dir =', OUT_DIR)



## Launch Training Pipeline

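This runs `run_lab3.py` as a subprocess with the config above, streams its log output into the cell, and raises an error if the process exits with a non-zero status.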


In [None]:
# Build the run_lab3.py command line from the config cell, run it as a
# subprocess and stream its output into the notebook.
import sys  # cell-local import: only needed for sys.executable below

cmd = [
    # Use the kernel's own interpreter so the subprocess runs in the same
    # environment as this notebook; '-u' keeps the child's output unbuffered
    # so log lines show up while training is still running.
    sys.executable or 'python', '-u', 'run_lab3.py',
    '--mode', RUN_MODE,
    '--out-root', str(SAVES_ROOT),
    '--strict-run-naming',
    '--manifests-root', str(MANIFESTS_ROOT),
    '--lab1-checkpoint', str(LAB1_CHECKPOINT),
    '--lab2-centroids-json', str(LAB2_TARGET_CENTROIDS),
    '--per-genre-samples', str(PER_GENRE_SAMPLES),
    '--chunks-per-track', str(CHUNKS_PER_TRACK),
    '--chunk-sampling', str(CHUNK_SAMPLING),
    '--min-start-sec', str(MIN_START_SEC),
    '--split-by-track' if SPLIT_BY_TRACK else '--no-split-by-track',
    '--seed', str(SEED),
    '--val-ratio', str(VAL_RATIO),
    '--n-frames', str(N_FRAMES),
    '--batch-size', str(BATCH_SIZE),
    '--num-workers', str(NUM_WORKERS),
    '--generator-norm', str(GENERATOR_NORM),
    '--generator-upsample', str(GENERATOR_UPSAMPLE),
    '--discriminator-arch', str(DISCRIMINATOR_ARCH),
    '--discriminator-scales', str(DISCRIMINATOR_SCALES),
    '--stage1-epochs', str(STAGE1_EPOCHS),
    '--stage2-epochs', str(STAGE2_EPOCHS),
    '--lr-g', str(LR_G),
    '--lr-d', str(LR_D),
    '--adv-weight', str(ADV_WEIGHT),
    '--recon-weight', str(RECON_WEIGHT),
    '--content-weight', str(CONTENT_WEIGHT),
    '--style-weight', str(STYLE_WEIGHT),
    '--continuity-weight', str(CONTINUITY_WEIGHT),
    '--mrstft-weight', str(MRSTFT_WEIGHT),
    '--stage1-mrstft-weight', str(STAGE1_MRSTFT_WEIGHT),
    '--stage2-mrstft-weight', str(STAGE2_MRSTFT_WEIGHT),
    '--mrstft-resolutions', str(MRSTFT_RESOLUTIONS),
    '--flatness-weight', str(FLATNESS_WEIGHT),
    '--feature-match-weight', str(FEATURE_MATCH_WEIGHT),
    '--perceptual-weight', str(PERCEPTUAL_WEIGHT),
    '--style-hinge-weight', str(STYLE_HINGE_WEIGHT),
    '--contrastive-weight', str(CONTRASTIVE_WEIGHT),
    '--batch-infonce-div-weight', str(BATCH_INFONCE_DIV_WEIGHT),
    '--diversity-weight', str(DIVERSITY_WEIGHT),
    '--timbre-balance-weight', str(TIMBRE_BALANCE_WEIGHT),
    '--lowmid-recon-weight', str(LOWMID_RECON_WEIGHT),
    '--spectral-tilt-weight', str(SPECTRAL_TILT_WEIGHT),
    '--zcr-proxy-weight', str(ZCR_PROXY_WEIGHT),
    '--style-mid-weight', str(STYLE_MID_WEIGHT),
    '--hf-muzzle-weight', str(HF_MUZZLE_WEIGHT),
    '--highpass-anchor-weight', str(HIGHPASS_ANCHOR_WEIGHT),
    '--mel-diversity-weight', str(MEL_DIVERSITY_WEIGHT),
    '--target-profile-weight', str(TARGET_PROFILE_WEIGHT),
    '--stage2-d-lr-mult', str(STAGE2_D_LR_MULT),
    '--stage2-content-start', str(STAGE2_CONTENT_START),
    '--stage2-content-end', str(STAGE2_CONTENT_END),
    '--stage2-style-label-smoothing', str(STAGE2_STYLE_LABEL_SMOOTHING),
    '--stage2-style-only-warmup-epochs', str(STAGE2_STYLE_ONLY_WARMUP_EPOCHS),
    '--stage2-g-lr-warmup-epochs', str(STAGE2_G_LR_WARMUP_EPOCHS),
    '--stage2-g-lr-start-mult', str(STAGE2_G_LR_START_MULT),
    '--stage2-cond-noise-std', str(STAGE2_COND_NOISE_STD),
    '--stage2-cond-mode', str(STAGE2_COND_MODE),
    '--stage2-cond-alpha-start', str(STAGE2_COND_ALPHA_START),
    '--stage2-cond-alpha-end', str(STAGE2_COND_ALPHA_END),
    '--stage2-cond-exemplar-noise-std', str(STAGE2_COND_EXEMPLAR_NOISE_STD),
    '--stage2-target-balance' if STAGE2_TARGET_BALANCE else '--no-stage2-target-balance',
    '--stage2-style-jitter-std', str(STAGE2_STYLE_JITTER_STD),
    '--stage2-style-hinge-target-conf', str(STAGE2_STYLE_HINGE_TARGET_CONF),
    '--stage2-adaptive-content-low', str(STAGE2_ADAPTIVE_CONTENT_LOW),
    '--stage2-adaptive-content-high', str(STAGE2_ADAPTIVE_CONTENT_HIGH),
    '--stage2-adaptive-conf-low', str(STAGE2_ADAPTIVE_CONF_LOW),
    '--stage2-adaptive-conf-high', str(STAGE2_ADAPTIVE_CONF_HIGH),
    '--stage2-style-critic-lr', str(STAGE2_STYLE_CRITIC_LR),
    '--stage2-contrastive-temp', str(STAGE2_CONTRASTIVE_TEMP),
    '--stage2-batch-infonce-temp', str(STAGE2_BATCH_INFONCE_TEMP),
    '--stage2-diversity-margin', str(STAGE2_DIVERSITY_MARGIN),
    '--stage2-diversity-max-pairs', str(STAGE2_DIVERSITY_MAX_PAIRS),
    '--stage2-mel-diversity-margin', str(STAGE2_MEL_DIVERSITY_MARGIN),
    '--stage2-mel-diversity-max-pairs', str(STAGE2_MEL_DIVERSITY_MAX_PAIRS),
    '--stage2-style-lowpass-keep-bins', str(STAGE2_STYLE_LOWPASS_KEEP_BINS),
    '--stage2-style-mid-low-bin', str(STAGE2_STYLE_MID_LOW_BIN),
    '--stage2-style-mid-high-bin', str(STAGE2_STYLE_MID_HIGH_BIN),
    '--stage2-lowmid-split-bin', str(STAGE2_LOWMID_SPLIT_BIN),
    '--stage2-lowmid-gain', str(STAGE2_LOWMID_GAIN),
    '--stage2-high-gain', str(STAGE2_HIGH_GAIN),
    '--stage2-spectral-tilt-max-ratio', str(STAGE2_SPECTRAL_TILT_MAX_RATIO),
    '--stage2-zcr-proxy-target-max', str(STAGE2_ZCR_PROXY_TARGET_MAX),
    '--stage2-style-thaw-last-epochs', str(STAGE2_STYLE_THAW_LAST_EPOCHS),
    '--stage2-style-thaw-lr', str(STAGE2_STYLE_THAW_LR),
    '--stage2-style-thaw-scope', str(STAGE2_STYLE_THAW_SCOPE),
    '--mps-threshold', str(MPS_THRESHOLD),
    '--sf-threshold', str(SF_THRESHOLD),
    '--eval-max-batches', str(EVAL_MAX_BATCHES),
    '--auto-sample-export',
    '--sample-count', str(POSTTRAIN_SAMPLE_COUNT),
    '--sample-target-mode', str(POSTTRAIN_SAMPLE_TARGET_MODE),
    '--sample-griffin-lim-iters', str(GL_ITERS),
    '--sample-export-tag', str(POSTTRAIN_SAMPLE_EXPORT_TAG),
    '--sample-write-real-audio' if POSTTRAIN_SAMPLE_WRITE_REAL else '--no-sample-write-real-audio',
    '--device', str(DEVICE),
    '--d-real-label', str(D_REAL_LABEL),
    '--d-fake-label', str(D_FAKE_LABEL),
    '--g-real-label', str(G_REAL_LABEL),
]

# Settings that may be None: emit the flag together with its value, or not at
# all. Passing a flag without its value (or the literal string 'None') would be
# rejected by a standard argparse option that expects one argument.
# NOTE(review): omitting a flag relies on run_lab3.py's default for it being
# equivalent to "unset" - confirm for --max-start-sec and
# --stage2-style-lowpass-cutoff-hz.
optional_valued = [
    ('--max-start-sec', MAX_START_SEC),
    ('--stage2-style-lowpass-cutoff-hz', STAGE2_STYLE_LOWPASS_CUTOFF_HZ),
    ('--reuse-cache-dir', REUSE_CACHE_DIR),
    ('--max-batches-per-epoch', MAX_BATCHES_PER_EPOCH),
]
for flag, value in optional_valued:
    if value is not None:
        cmd.extend([flag, str(value)])

# An empty RUN_NAME means "let training pick run1, run2, ..." so no --run-name
# is sent at all in that case.
if RUN_MODE == 'fresh' and RUN_NAME:
    cmd.extend(['--run-name', RUN_NAME])
if RUN_MODE == 'resume':
    cmd.extend(['--resume-dir', str(OUT_DIR)])

# Plain on/off switches that are only present when enabled.
switches = [
    ('--generator-spectral-norm', GENERATOR_SPECTRAL_NORM),
    ('--smoke', SMOKE),
    ('--stage2-adaptive-content', STAGE2_ADAPTIVE_CONTENT),
    ('--reset-stage2-out-layer', RESET_STAGE2_OUT_LAYER),
]
cmd.extend(flag for flag, enabled in switches if enabled)

print(' '.join(cmd))
# Stream training logs directly to notebook output. The context manager closes
# the pipe and reaps the process even if the cell is interrupted.
with subprocess.Popen(
    cmd,
    cwd=str(LAB3_DIR),
    stdout=subprocess.PIPE,
    stderr=subprocess.STDOUT,  # interleave stderr with stdout in one stream
    text=True,
    bufsize=1,
) as p:
    for line in p.stdout:
        print(line, end='')
    ret = p.wait()
if ret != 0:
    raise subprocess.CalledProcessError(ret, cmd)
# In fresh auto-run mode, resolve OUT_DIR to the created runN folder.
if RUN_MODE == 'fresh' and RUN_NAME == '':
    latest = find_latest_run_dir(SAVES_ROOT)
    if latest is None:
        raise FileNotFoundError('No run directory with run_state.json found under SAVES_ROOT after training.')
    OUT_DIR = latest
    print(f'[post-run] resolved run_dir = {OUT_DIR}')




## Load Run Outputs


In [None]:
def resolve_run_dir_for_metrics(out_dir: Path, saves_root: Path, resume_dir: Path | None = None) -> Path:
    if (out_dir / 'run_state.json').exists():
        return out_dir

    if resume_dir is not None:
        resume_path = Path(resume_dir)
        if (resume_path / 'run_state.json').exists():
            return resume_path

    latest = find_latest_run_dir(saves_root)
    if latest is not None:
        print(f'[metrics] OUT_DIR has no run_state.json, using latest run dir: {latest}')
        return latest

    raise FileNotFoundError(
        f"No run_state.json found. Checked OUT_DIR={out_dir} and no run dirs under {saves_root}."
    )

# Settle on the run directory to inspect, then load its state, audit and
# training history. Only run_state.json is mandatory; the other two files fall
# back to empty containers when absent.
OUT_DIR = resolve_run_dir_for_metrics(
    OUT_DIR,
    SAVES_ROOT,
    RESUME_DIR if RUN_MODE == 'resume' else None,
)
run_state_path, audit_path, history_path = (
    OUT_DIR / filename
    for filename in ('run_state.json', 'lab3_exit_audit.json', 'history.csv')
)

run_state = json.loads(run_state_path.read_text(encoding='utf-8'))

if audit_path.exists():
    audit = json.loads(audit_path.read_text(encoding='utf-8'))
else:
    audit = {}

if history_path.exists():
    history = pd.read_csv(history_path)
else:
    history = pd.DataFrame()

print('[resolved_out_dir]', OUT_DIR)

# Progress flags recorded in run_state.json (missing keys print as null).
stage_flags = ('stage_cache_done', 'stage1_done', 'stage2_done', 'eval_done', 'lab3_done')
print('[run_state]')
print(json.dumps({flag: run_state.get(flag) for flag in stage_flags}, indent=2))

print('[audit]')
print(json.dumps(audit, indent=2))

# Bare expression so the notebook renders the last rows of the history table.
history.tail(10)



SyntaxError: unterminated string literal (detected at line 38) (50314973.py, line 38)

In [None]:
# Optional quick look at the loss curves: one panel per loss column, one line
# per training stage. Skipped entirely when no history rows were loaded.
if len(history) > 0:
    fig, (gen_ax, disc_ax) = plt.subplots(1, 2, figsize=(12, 4))
    panels = (
        (gen_ax, 'loss_g', 'Generator Loss'),
        (disc_ax, 'loss_d', 'Discriminator Loss'),
    )
    for stage_name, stage_rows in history.groupby('stage'):
        for panel_ax, loss_column, _ in panels:
            panel_ax.plot(stage_rows['epoch'], stage_rows[loss_column], marker='o', label=stage_name)
    for panel_ax, _, panel_title in panels:
        panel_ax.set_title(panel_title)
        panel_ax.set_xlabel('Epoch')
        panel_ax.grid(True, alpha=0.3)
        panel_ax.legend()
    plt.tight_layout()
    plt.show()


## Generation Samples (Direct Results)

This cell loads the trained Stage 2 checkpoint and creates side-by-side outputs:
- real mel / generated mel plots
- reconstructed WAVs from mel via Griffin-Lim
- a summary CSV with source/target genres and quick content similarity


In [None]:
# Post-train sample export using the shared CLI helper.
# Rebuilds the generator from the run's recorded config, loads the Stage 2
# checkpoint and writes samples under OUT_DIR/samples/<export tag>.
cache_dir = OUT_DIR / 'cache'
idx_df, arrays, genre_to_idx = load_cache(cache_dir)

# Parse run_state.json once; both the training config and the genre/source
# mapping come from it. Anything missing or malformed degrades to empty dicts
# so the notebook-level settings below are used as fallbacks.
run_state_path = OUT_DIR / 'run_state.json'
rs = json.loads(run_state_path.read_text(encoding='utf-8')) if run_state_path.exists() else {}
if not isinstance(rs, dict):
    rs = {}
run_cfg = rs.get('config') or {}
if not isinstance(run_cfg, dict):
    run_cfg = {}
source_map = rs.get('genre_to_lab1_source_idx', {}) or {}

# Prefer what the run recorded over the current notebook values so the
# generator is rebuilt with the architecture it was trained with.
lab1_ckpt_for_samples = Path(run_cfg.get('lab1_checkpoint', str(LAB1_CHECKPOINT)))
lab2_centroids_for_samples = Path(run_cfg.get('lab2_centroids_json', str(LAB2_TARGET_CENTROIDS)))
g_norm = str(run_cfg.get('generator_norm', globals().get('GENERATOR_NORM', 'instance')))
g_upsample = str(run_cfg.get('generator_upsample', globals().get('GENERATOR_UPSAMPLE', 'transpose')))
g_sn = bool(run_cfg.get('generator_spectral_norm', globals().get('GENERATOR_SPECTRAL_NORM', False)))
g_mrf = bool(run_cfg.get('generator_mrf', False))
g_mrf_kernels = tuple(int(x.strip()) for x in str(run_cfg.get('generator_mrf_kernels', '3,7,11')).split(',') if x.strip())

# 'auto' resolves to CUDA when available, otherwise CPU.
if DEVICE == 'auto':
    device_t = 'cuda' if torch.cuda.is_available() else 'cpu'
else:
    device_t = DEVICE

encoder = FrozenLab1Encoder(lab1_ckpt_for_samples, device=device_t)
centroids = load_target_centroids(lab2_centroids_for_samples)
cond_bank = build_condition_bank(genre_to_idx, centroids).to(device_t)

G = ReconstructionDecoder(
    zc_dim=arrays['z_content'].shape[1],
    cond_dim=cond_bank.shape[1],
    n_mels=arrays['mel_norm'].shape[1],
    n_frames=arrays['mel_norm'].shape[2],
    norm=g_norm,
    upsample=g_upsample,
    spectral_norm=g_sn,
    mrf=g_mrf,
    mrf_kernels=g_mrf_kernels,
).to(device_t)
ckpt = torch.load(OUT_DIR / 'checkpoints' / 'stage2_latest.pt', map_location='cpu')
incoming = ckpt.get('generator', {})
current = G.state_dict()
# Best-effort load: keep only tensors whose name and shape match the rebuilt
# generator, so an architecture mismatch does not abort the preview.
filtered = {k: v for k, v in incoming.items() if (k in current and tuple(v.shape) == tuple(current[k].shape))}
G.load_state_dict(filtered, strict=False)
G.eval()

# Make a partial load visible: tensors that were not loaded keep their
# initial values, which would silently degrade every exported sample.
skipped_keys = sorted(k for k in incoming if k not in filtered)
unloaded_keys = sorted(k for k in current if k not in filtered)
if not filtered:
    print('[sample-export] WARNING: no generator weights were loaded from the checkpoint; '
          'samples will come from an untrained generator.')
elif skipped_keys or unloaded_keys:
    print(f'[sample-export] WARNING: partial generator load: {len(skipped_keys)} checkpoint tensor(s) '
          f'skipped (unknown name or shape mismatch), {len(unloaded_keys)} generator tensor(s) left at '
          'their initial values.')
    if skipped_keys:
        print('[sample-export]   skipped (first 10):', skipped_keys[:10])
    if unloaded_keys:
        print('[sample-export]   not loaded (first 10):', unloaded_keys[:10])

# NOTE(review): the split is recomputed here from the notebook's VAL_RATIO and
# SEED; confirm it matches the split used during training (e.g. when the run
# was trained with --split-by-track or different values).
train_idx, val_idx = stratified_split_indices(arrays['genre_idx'], val_ratio=VAL_RATIO, seed=SEED)
if len(val_idx) == 0:
    # No validation rows: fall back to the first rows of the cache.
    val_idx = np.arange(min(POSTTRAIN_SAMPLE_COUNT, len(arrays['genre_idx'])))

sample_out = OUT_DIR / 'samples' / POSTTRAIN_SAMPLE_EXPORT_TAG
sample_info = export_posttrain_samples(
    generator=G,
    frozen_encoder=encoder,
    arrays=arrays,
    index_df=idx_df,
    genre_to_idx=genre_to_idx,
    cond_bank=cond_bank,
    out_dir=sample_out,
    val_idx=val_idx,
    n_samples=int(POSTTRAIN_SAMPLE_COUNT),
    target_mode=str(POSTTRAIN_SAMPLE_TARGET_MODE),
    griffin_lim_iters=int(GL_ITERS),
    seed=int(SEED),
    device=device_t,
    genre_to_source_idx={str(g): int(v) for g, v in source_map.items()} if isinstance(source_map, dict) else None,
    write_real_audio=bool(POSTTRAIN_SAMPLE_WRITE_REAL),
)
print('[sample-export]', sample_info)
samples_df = pd.read_csv(sample_out / 'generation_summary.csv')
# Bare expression so the notebook renders the first rows of the summary.
samples_df.head(10)


In [None]:
# Preview the first exported sample inline. Every step is optional: nothing
# happens when the summary CSV, a column, or a WAV file is missing.
summary_csv = OUT_DIR / 'samples' / POSTTRAIN_SAMPLE_EXPORT_TAG / 'generation_summary.csv'
if summary_csv.exists():
    gen = pd.read_csv(summary_csv)
    if len(gen):
        first_row = gen.iloc[0]
        cols = [c for c in ['source_genre', 'target_genre', 'mps_cosine'] if c in gen.columns]
        print(first_row[cols])
        # .get() returns None for an absent column instead of raising
        # KeyError, matching the tolerant column handling used for `cols`.
        fake_wav = first_row.get('fake_wav')
        real_wav = first_row.get('real_wav')
        # The isinstance check also filters out NaN cells read from the CSV.
        for label, wav_path in (('Real preview:', real_wav), ('Generated preview:', fake_wav)):
            if isinstance(wav_path, str) and wav_path and Path(wav_path).exists():
                print(label)
                display(Audio(filename=wav_path))


## Notes

- If style fidelity remains low in early runs, that is expected for short training.
- Use resume mode to continue training from the same run folder.
- Generation audio uses Griffin-Lim from mel and is only for qualitative sanity checks.
