
# Training – Baseline PhaseDiff Model 


In [None]:
%run 00_shared_utils.ipynb


In [None]:
from __future__ import annotations

import math
import os
import random

import torch
from torch.utils.data import Dataset, DataLoader
import torchaudio

# --- Core setup ---
# Pick the GPU when available, then seed RNGs before any module or DataLoader
# is built (set_seed comes from 00_shared_utils; presumably seeds python/torch
# RNGs — TODO confirm). Statement order in this cell matters for that reason.
device = "cuda" if torch.cuda.is_available() else "cpu"
set_seed(7)

# STFT parameters: 512-point FFT, hop of 80 samples (5 ms at the 16 kHz rate
# used in cfg below), analysis window length equal to n_fft.
n_fft = 512
hop = 80
win = n_fft

# NOTE(review): MF_WIN, TB_WIN and mag_thresh_db are not referenced in any of
# the cells visible here. Confirm whether they were meant to be used (e.g.
# mag_thresh_db to mask silent bins in the phase loss) or can be removed.
MF_WIN = 3
TB_WIN = 2
mag_thresh_db = -60  # silence filtering threshold

cfg = STFTCfg(sr=16000, n_fft=n_fft, hop=hop)
# NOTE(review): stft_small appears to use the same n_fft/hop/win as `stft`
# below and is not used later in this notebook; it still allocates a module on
# `device`. Likely redundant — confirm nothing else reads it before removing.
stft_small = CausalSTFT(n_fft=n_fft, hop_length=hop, win_length=win).to(device)

# STFT module passed to compute_mag_fpd_bpd in the training loop, built from cfg
# (positional args presumably n_fft, hop_length, win_length — TODO confirm).
stft = CausalSTFT(cfg.n_fft, cfg.hop, cfg.n_fft).to(device)
print("Training baseline model with", cfg.__dict__)

# Build the training DataLoader (seconds=3, batch_size=6). Because `dl` is
# always defined here, the fallback branch of the dataset cell only runs if
# this line is skipped.
subset_paths = load_subset_paths("subset_paths.txt")
dl = make_dataloader(subset_paths, sr=cfg.sr, seconds=3, batch_size=6)


In [None]:
# --- Dataset ---
# Reuse the DataLoader built in the setup cell if it is already in memory;
# otherwise build a minimal fallback from `subset_paths` (e.g. when the setup
# cell's make_dataloader call was skipped in this session).
if "dl" in globals():
    print("Using existing dataloader (len:", len(dl), ")")
else:
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if "subset_paths" not in globals():
        raise RuntimeError("Need subset_paths or dl from previous session.")

    class ListAudioDataset(Dataset):
        """Random fixed-length mono crops drawn from a list of audio files.

        Each item is a 1-D tensor of ``int(seconds * sr)`` samples. Files are
        down-mixed to mono by averaging channels, resampled to ``sr`` when their
        native rate differs, and tiled when shorter than the crop length.

        Raises:
            ValueError: if a file decodes to zero samples.
        """

        def __init__(self, paths, sr=16000, seconds=3):
            self.paths = paths
            self.sr = sr
            self.samples = int(seconds * sr)

        def __len__(self):
            return len(self.paths)

        def __getitem__(self, idx):
            p = self.paths[idx]
            wav, sr = torchaudio.load(p)
            # torchaudio.load returns (channels, frames); average to (1, frames).
            wav = wav.mean(0, keepdim=True)
            if sr != self.sr:
                wav = torchaudio.functional.resample(wav, sr, self.sr)
            n = wav.shape[-1]
            if n == 0:
                # Tiling below would divide by zero; fail with the offending path.
                raise ValueError(f"Empty audio file: {p}")
            if n < self.samples:
                # Tile enough copies (ceil division) to cover one full crop.
                reps = (self.samples + n - 1) // n
                wav = wav.repeat(1, reps)
            # Uniformly random crop start; python `random` is seeded by set_seed
            # (presumably) and num_workers=0 keeps it in the main process.
            start = random.randint(0, wav.shape[-1] - self.samples)
            return wav[:, start:start + self.samples].squeeze(0)

    ds = ListAudioDataset(subset_paths, sr=cfg.sr, seconds=3)
    dl = DataLoader(ds, batch_size=6, shuffle=True, drop_last=True, num_workers=0)
    print("Created new dataloader with", len(dl), "batches.")

In [None]:

# --- Model & Optimizer ---
# Number of full passes over `dl` performed by the training cell.
epochs = 8

# Baseline PhaseDiff predictor with use_film=False, placed on the training device.
model = PhaseDiffPredictionModel(
    n_fft=cfg.n_fft,
    hop_length=cfg.hop,
    use_film=False,
).to(device)

# Adam over every model parameter at a fixed learning rate of 1e-3.
opt = torch.optim.Adam(params=model.parameters(), lr=1e-3)

In [None]:


def vm(a, b):
    """Return the negative mean cosine of the elementwise difference ``a - b``.

    Used as a phase loss: it equals -1 when ``a`` and ``b`` agree everywhere and
    +1 when they differ by pi everywhere. Because cosine is 2*pi-periodic, the
    loss is unaffected by phase wrapping. The result is a 0-dim tensor.
    """
    cos_diff = torch.cos(a - b)
    return cos_diff.mean().neg()


In [None]:
# Train for `epochs` epochs. Each step predicts forward/backward phase
# differences (FPD/BPD) from the magnitude spectrogram and minimises the
# cosine phase loss `vm` against targets computed from the clean waveform.
for ep in range(1, epochs + 1):
    model.train()
    total_fpd = total_bpd = 0.0
    steps = 0
    for wav_cpu in dl:
        wav = wav_cpu.to(device)
        # Targets come from the ground-truth waveform; no gradients needed.
        with torch.no_grad():
            mag, fpd_tgt, bpd_tgt = compute_mag_fpd_bpd(
                wav, cfg.n_fft, cfg.hop, stft
            )  # mag: (B, F, T)

        # Augmentation: half of the time (when there are more than 3 frames),
        # feed a magnitude altered by inpaint_k_between_pairs_linear with
        # k = 1 or 2 (equal probability). Presumably this linearly interpolates
        # k frames between anchor frames — TODO confirm in shared utils.
        mag_in = mag
        if random.random() < 0.5 and mag.shape[-1] > 3:
            k = 2 if random.random() < 0.5 else 1
            mag_in = inpaint_k_between_pairs_linear(mag, k=k)

        opt.zero_grad(set_to_none=True)
        fpd_pred, bpd_pred, _ = model(mag=mag_in)

        # NOTE(review): the loss covers every time-frequency bin, including
        # silent ones; mag_thresh_db from the setup cell is currently unused.
        loss_f = vm(fpd_tgt, fpd_pred)
        loss_b = vm(bpd_tgt, bpd_pred)
        loss = loss_f + loss_b
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        opt.step()

        total_fpd += float(loss_f.detach())
        total_bpd += float(loss_b.detach())
        steps += 1

    if steps == 0:
        # Otherwise the averages below would raise an opaque ZeroDivisionError.
        raise RuntimeError(
            "DataLoader yielded no batches; check subset_paths and batch_size "
            "(e.g. fewer files than batch_size with drop_last=True)."
        )
    print(f"[Epoch {ep:02d}] FPD {total_fpd/steps:+.4f} | BPD {total_bpd/steps:+.4f}")

In [None]:
# --- Save checkpoint ---
# Single source of truth for the output path so the log message cannot drift
# from the file actually written.
ckpt_path = "checkpoints/baseline_og.pth"
# Make sure the target directory exists; harmless if save_ckpt also creates it.
os.makedirs(os.path.dirname(ckpt_path), exist_ok=True)
save_ckpt(model, ckpt_path, cfg)
# "(+.json)" reflects the original note that a JSON sidecar is also written
# (presumably cfg metadata from save_ckpt — TODO confirm its exact filename).
print(f"✅ Baseline model saved to {ckpt_path} (+.json)")
