# Shared utilities for all Phase-Diff models

### Imports & global config helpers

In [None]:
from __future__ import annotations

import glob
import json
import math
import os
import random
from pathlib import Path
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
# `tF` alias: many functions below bind a local `F` (frequency-bin count) that shadows `F`.
import torch.nn.functional as tF
from torch import Tensor, no_grad
from torch.nn.utils import weight_norm
from torch.utils.data import Dataset, DataLoader

import torchaudio
from einops import rearrange



def set_seed(seed: int = 7, *, deterministic: bool = False):
    """Seed Python's ``random`` and PyTorch (CPU and every CUDA device).

    Args:
        seed: Seed applied to ``random``, ``torch`` and all CUDA generators.
        deterministic: When True, force deterministic cuDNN kernels and turn the
            cuDNN autotuner off. The default (False) keeps the original
            speed-oriented setting (autotuner on, non-deterministic kernels),
            under which GPU runs are not bit-reproducible even with a fixed seed.
    """
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = deterministic
    # The autotuner may pick different (non-deterministic) kernels per run.
    torch.backends.cudnn.benchmark = not deterministic




class STFTCfg:
    """Plain container for the STFT parameters shared by models and checkpoints.

    ``win_length`` defaults to ``n_fft`` when omitted (or falsy).
    """

    def __init__(self, sr=16000, n_fft=512, hop=80, win_length=None):
        self.sr = int(sr)
        self.n_fft = int(n_fft)
        self.hop = int(hop)
        self.win_length = int(win_length or n_fft)

    def __repr__(self) -> str:
        return (f"{type(self).__name__}(sr={self.sr}, n_fft={self.n_fft}, "
                f"hop={self.hop}, win_length={self.win_length})")

    def to_json(self) -> str:
        """Serialize all four fields (including ``win_length``) to a JSON string."""
        return json.dumps({
            "sr": self.sr,
            "n_fft": self.n_fft,
            "hop": self.hop,
            "win_length": self.win_length,
        })

    @staticmethod
    def from_json(s: str) -> "STFTCfg":
        """Inverse of :meth:`to_json`.

        Sidecar files written before ``win_length`` was serialized have no such
        key; they fall back to ``n_fft``, which is what was always assumed.
        """
        obj = json.loads(s)
        return STFTCfg(obj["sr"], obj["n_fft"], obj["hop"], obj.get("win_length", obj["n_fft"]))

def assert_same_cfg(a: "STFTCfg", b: "STFTCfg"):
    """Raise ``AssertionError`` unless *a* and *b* share n_fft, hop and win_length.

    The sample rate is intentionally not compared. The check uses an explicit
    ``raise`` rather than an ``assert`` statement so it still fires under
    ``python -O``.
    """
    params_a = (a.n_fft, a.hop, a.win_length)
    params_b = (b.n_fft, b.hop, b.win_length)
    if params_a != params_b:
        raise AssertionError(f"STFT params differ: {params_a} vs {params_b}")


In [2]:
# Module-level default device (CPU).
device: torch.device = torch.device("cpu")

### Causal STFT / ISTFT (√Hann dual windows)

In [None]:
class CausalSTFT(nn.Module):
    """
    Strictly-causal STFT: ``center=False`` plus a single *left* zero-pad of
    ``win_length - hop_length`` samples, so the first frame ends on the first
    hop of real input and no frame looks ahead.

    Args:
        n_fft: FFT size.
        hop_length: Frame advance in samples.
        win_length: Analysis window length.
        window_type: ``"sqrt_hann"`` (default), ``"hann"`` or ``"rectangular"``.

    Raises:
        ValueError: for an unknown ``window_type``.
    """
    def __init__(self, n_fft: int, hop_length: int, win_length: int, window_type="sqrt_hann"):
        super().__init__()
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length

        if window_type == "sqrt_hann":
            w = torch.sqrt(torch.hann_window(win_length))
        elif window_type == "hann":
            w = torch.hann_window(win_length)
        elif window_type == "rectangular":
            w = torch.ones(win_length)
        else:
            raise ValueError(f"Invalid window type: {window_type}")
        self.register_buffer("window", w)
        self.left_pad = max(0, win_length - hop_length)

    def forward(self, wav: torch.Tensor) -> torch.Tensor:
        """
        Args:
            wav: (B, L) batch, or a single (L,) signal.

        Returns:
            Complex spectrogram (B, n_fft // 2 + 1, T), or (n_fft // 2 + 1, T)
            for 1-D input.

        Raises:
            ValueError: if ``wav`` is neither 1-D nor 2-D.
        """
        unbatched = wav.dim() == 1
        if unbatched:
            wav = wav.unsqueeze(0)
        if wav.dim() != 2:
            raise ValueError("wav must be (B, L)")
        if self.left_pad:
            wav = tF.pad(wav, (self.left_pad, 0))

        # Keep the window on the input's device and in its floating precision.
        window = self.window.to(wav.device)
        if wav.is_floating_point():
            window = window.to(wav.dtype)

        spec = torch.stft(
            wav, n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length,
            window=window, center=False, return_complex=True
        )
        return spec.squeeze(0) if unbatched else spec


class CausalISTFT(nn.Module):
    """
    Dual of CausalSTFT (√Hann synthesis window, strictly causal overlap-add).

    Each frame is inverted with an inverse real FFT, multiplied by the √Hann
    synthesis window and overlap-added. Then the analysis-side left pad of
    ``win_length - hop_length`` samples is removed, and the result is divided
    by the overlap-added squared window (floored at 1e-3).

    NOTE: requires ``win_length == n_fft``, because the per-frame window
    multiply broadcasts an ``n_fft``-long frame against a ``win_length``-long
    window.
    """
    def __init__(self, n_fft: int, hop_length: int, win_length: int):
        super().__init__()
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        self.register_buffer("window", torch.sqrt(torch.hann_window(win_length)))
        self.left_pad = max(0, win_length - hop_length)

    def _overlap_add(self, frames: torch.Tensor, out_len: int) -> torch.Tensor:
        """Overlap-add (N, win_length, T) frames with stride hop_length -> (N, out_len)."""
        return tF.fold(
            frames, output_size=(1, out_len),
            kernel_size=(1, self.win_length), stride=(1, self.hop_length)
        )[:, 0, 0]

    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        """
        Args:
            spec: complex (B, F, T) or (F, T) spectrogram, F = n_fft // 2 + 1.

        Returns:
            Waveform (B, T * hop_length) or (T * hop_length,) for 2-D input
            (for ``win_length >= hop_length``).

        Raises:
            ValueError: if ``spec`` is neither 2-D nor 3-D.
        """
        unbatched = spec.dim() == 2
        if unbatched:
            spec = spec.unsqueeze(0)
        if spec.dim() != 3:
            raise ValueError("spec must be (B, F, T) complex")
        n_frames = spec.shape[-1]
        window = self.window.to(spec.device)

        frames = torch.fft.irfft(spec, n=self.n_fft, dim=1, norm="backward")
        frames = frames * window[None, :, None]

        out_len = (n_frames - 1) * self.hop_length + self.win_length
        wav = self._overlap_add(frames, out_len)                        # (B, out_len)

        # Window-envelope normalization: sum of squared windows at each sample.
        w2 = window.square()[None, :, None].expand(1, -1, n_frames)    # (1, W, T)
        env = self._overlap_add(w2, out_len)[0]                         # (out_len,)

        if self.left_pad:
            wav = wav[..., self.left_pad:]
            env = env[self.left_pad:]
        wav = wav / env.clamp_min(1e-3)
        return wav.squeeze(0) if unbatched else wav


### Angle utilities and phase-difference conversions

In [5]:
def wrap(ang: torch.Tensor, ste: bool = False) -> torch.Tensor:
    """Wrap angles (radians) into [-pi, pi).

    With ``ste=True`` the forward value is wrapped but the gradient flows
    straight through to ``ang`` as if no wrapping happened.
    """
    two_pi = 2 * math.pi
    wrapped = torch.remainder(ang + math.pi, two_pi) - math.pi
    if not ste:
        return wrapped
    # value of `wrapped`, gradient of the identity w.r.t. `ang`
    return ang + (wrapped - ang).detach()

def bpd_to_tpd(bpd: torch.Tensor, n_fft: int, hop_length: int, keep_first_frame: bool = False, squeeze: bool = True) -> torch.Tensor:
    """Convert baseband phase differences (BPD) to time phase differences (TPD).

    Adds the per-bin phase advance ``2*pi*hop_length*m/n_fft`` (m = bin index)
    back onto frames 1..T-1 and wraps the result with a straight-through
    estimator, so gradients pass unchanged. Frame 0 of ``bpd`` is always
    discarded. With ``keep_first_frame=True`` a zero frame is prepended, giving
    T frames; otherwise T-1 frames are returned.

    Accepts (F, T) or (B, F, T). For 2-D input the temporary batch axis is
    removed again when ``squeeze`` is True.
    """
    unbatched = bpd.ndim == 2
    if unbatched:
        bpd = bpd.unsqueeze(0)
    _, n_bins, _ = bpd.shape
    bins = torch.arange(n_bins, device=bpd.device).view(1, n_bins, 1)
    advance = 2 * math.pi * hop_length * bins / n_fft

    tpd = wrap(bpd[:, :, 1:] + advance, ste=True)
    if keep_first_frame:
        tpd = torch.cat([torch.zeros_like(tpd[:, :, :1]), tpd], dim=2)

    if unbatched and squeeze:
        tpd = tpd.squeeze(0)
    return tpd


### Dataset-agnostic targets (mag, FPD, BPD)

In [6]:
@torch.no_grad()
def compute_mag_fpd_bpd(
    wav: torch.Tensor, n_fft: int, hop: int, stft_mod: "CausalSTFT"
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Compute magnitude / phase-difference targets from a waveform batch.

    wav: (B, L) -> mag (B,F,T), fpd (B,F-1,T), bpd (B,F,T)

    - mag: |STFT| floored at 1e-8.
    - fpd: wrapped phase difference between adjacent frequency bins.
    - bpd: wrapped (TPD - 2*pi*hop*m/n_fft). TPD is the wrapped phase
      difference between consecutive frames, with frame 0 set to 0.
    """
    spec = stft_mod(wav)                       # (B, F, T) complex
    phase = torch.angle(spec)
    mag = spec.abs().clamp_min(1e-8)

    fpd = wrap(torch.diff(phase, dim=1))       # along frequency
    tpd = torch.zeros_like(phase)
    tpd[:, :, 1:] = wrap(torch.diff(phase, dim=2))  # along time

    _, n_bins, _ = phase.shape
    bins = torch.arange(n_bins, device=phase.device).view(1, n_bins, 1)
    bpd = wrap(tpd - 2 * math.pi * hop * bins / n_fft)
    return mag, fpd, bpd


In [None]:


@torch.no_grad()
def pv_time_stretch(
    wav: torch.Tensor,
    n_fft: int,
    hop_length: int,
    rate: float,
) -> torch.Tensor:
    """
    Phase-vocoder time-stretch (pitch-preserving) using torchaudio.transforms.TimeStretch.

    wav : (B, N) or (N,) mono tensor
    rate: duration factor, same convention as the rest of this file:
          >1.0 -> slower / longer, <1.0 -> faster / shorter.
          TimeStretch's own ``fixed_rate`` is a *speed-up* factor, so it is
          given ``1 / rate``.

    Returns:
        wav_out: (B, ceil(N * rate)) stretched waveform at the SAME sample rate
                 (1-D input comes back as (1, N_out)).

    Raises:
        ValueError: if ``rate`` is not positive.
    """
    if rate <= 0:
        raise ValueError(f"rate must be > 0, got {rate}")

    # Ensure (B, N)
    if wav.dim() == 1:
        wav = wav.unsqueeze(0)
    assert wav.dim() == 2, f"Expected 2D (B,N) wav, got {wav.shape}"

    n_samples = wav.shape[1]
    device = wav.device
    window = torch.hann_window(n_fft, device=device)

    # 1) STFT: (B, F, T) complex
    spec = torch.stft(
        wav,
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=n_fft,
        window=window,
        center=True,
        return_complex=True,
    )

    # 2) Phase-vocoder stretch. TimeStretch takes the complex (..., freq, time)
    #    spectrogram directly; a real view_as_real (..., 2) layout is not the
    #    expected input for current torchaudio.
    stretch = torchaudio.transforms.TimeStretch(
        hop_length=hop_length,
        n_freq=spec.shape[1],
        fixed_rate=1.0 / rate,
    ).to(device)
    spec_stretch = stretch(spec)             # (B, F, ~T * rate)

    # 3) ISTFT back to a waveform of ~rate times the input length
    n_out = int(math.ceil(n_samples * rate))
    return torch.istft(
        spec_stretch,
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=n_fft,
        window=window,
        center=True,
        length=n_out,
    )



### Waveform reconstruction from mag + (FPD,TPD) or (FPD,BPD)

In [None]:
def _solve_tridiag_thomas(lower, diag, upper, rhs):
    """Solve a tridiagonal system with the Thomas algorithm (no pivoting).

    All arguments are 1-D NumPy arrays of a common dtype: ``lower``/``upper``
    have length n-1 and ``diag``/``rhs`` length n. Inputs are not modified.
    """
    b = diag.copy()
    x = rhs.copy()
    n = b.shape[0]
    # forward elimination
    for i in range(1, n):
        w = lower[i - 1] / b[i - 1]
        b[i] = b[i] - w * upper[i - 1]
        x[i] = x[i] - w * x[i - 1]
    # back substitution
    x[-1] = x[-1] / b[-1]
    for i in range(n - 2, -1, -1):
        x[i] = (x[i] - upper[i] * x[i + 1]) / b[i]
    return x


@torch.no_grad()
def solve_tridiag_cpu(lower: torch.Tensor, diag: torch.Tensor, upper: torch.Tensor, rhs: torch.Tensor) -> torch.Tensor:
    """
    Solve the tridiagonal system A x = rhs on the CPU.

    A has main diagonal ``diag`` (n,), sub-diagonal ``lower`` (n-1,) and
    super-diagonal ``upper`` (n-1,). All inputs are moved to the CPU and
    promoted to one common dtype, so GPU tensors and mixed real/complex inputs
    are accepted.

    The solve uses ``scipy.linalg.solve_banded`` (banded LU with partial
    pivoting) when SciPy is importable. Otherwise it falls back to the Thomas
    algorithm, which does no pivoting and is only safe for e.g. diagonally
    dominant systems. Solver errors such as a singular matrix are no longer
    swallowed.

    Returns:
        x (n,) on ``rhs.device`` with the promoted dtype of all four inputs.
    """
    n = diag.numel()
    if n == 0:
        return rhs.clone()
    dtype = torch.promote_types(
        torch.promote_types(lower.dtype, diag.dtype),
        torch.promote_types(upper.dtype, rhs.dtype),
    )
    lo, di, up, b = (t.detach().to(device="cpu", dtype=dtype) for t in (lower, diag, upper, rhs))

    try:
        import scipy.linalg as la
    except ImportError:
        x = _solve_tridiag_thomas(lo.numpy(), di.numpy(), up.numpy(), b.numpy())
    else:
        # LAPACK banded layout: ab[u + i - j, j] = A[i, j] with u = l = 1
        ab = torch.zeros((3, n), dtype=dtype)
        ab[0, 1:] = up
        ab[1] = di
        ab[2, :-1] = lo
        x = la.solve_banded((1, 1), ab.numpy(), b.numpy(), overwrite_ab=False, overwrite_b=False)
    return torch.as_tensor(x, dtype=dtype, device=rhs.device)


@torch.no_grad()
def mag_fpd_tpd_to_waveform(mag: torch.Tensor, fpd: torch.Tensor, tpd: torch.Tensor, n_fft: int, hop_length: int, eps: float = 1e-8, squeeze: bool = False) -> torch.Tensor:
    """
    Reconstruct a waveform from STFT magnitude plus FPD and TPD.

    Frame 0: bin 0 gets phase 0 and bin k gets the cumulative sum of FPD over
    bins 0..k-1.

    Frame t >= 1: define
        u_k = (mag[k+1, t] / (mag[k, t] + eps))   * exp(1j * fpd[k, t])
        v_k = (mag[k, t]   / (mag[k, t-1] + eps)) * exp(1j * tpd[k, t]).
    The complex frame Y_t is the least-squares solution of
        Y_t[k]   ~= v_k * Y_{t-1}[k]   (time consistency, every bin k)
        Y_t[k+1] ~= u_k * Y_t[k]       (frequency consistency, k < F-1),
    i.e. of the Hermitian tridiagonal normal equations assembled below.
    Only the phase of that solution is kept; the magnitude is reset to
    ``mag``. Frames are solved sequentially, each batch item on the CPU via
    solve_tridiag_cpu, and the result is inverted with CausalISTFT.

    Args:
        mag: (B, F, T) or (F, T) linear magnitude.
        fpd: (B, F-1, T) or (F-1, T) frequency phase differences in radians.
        tpd: (B, F, T) or (F, T) time phase differences in radians; frame 0
            is never read.
        n_fft, hop_length: parameters of the inverse CausalISTFT.
        eps: added to the magnitude denominators of u and v.
        squeeze: if True and B == 1, drop the batch dimension of the output.

    Returns:
        Waveform (B, L), or (L,) when ``squeeze`` and B == 1.
    """
    if mag.ndim == 2:
        mag, fpd, tpd = mag.unsqueeze(0), fpd.unsqueeze(0), tpd.unsqueeze(0)
    B, F, T = mag.shape
    device = mag.device
    stft_hat = torch.zeros((B, F, T), dtype=torch.complex64, device=device)

    # t=0 from cumulative FPD (bin 0 phase fixed to 0)
    phase_f0 = torch.zeros((B, F), device=device)
    phase_f0[:, 1:] = torch.cumsum(fpd[:, :, 0], dim=1)
    stft_hat[:, :, 0] = mag[:, :, 0] * torch.exp(1j * phase_f0)

    for t in range(1, T):
        Y_prev = stft_hat[:, :, t - 1]
        # u_k: predicted ratio Y_t[k+1] / Y_t[k]
        ratio_u = (mag[:, 1:, t] / (mag[:, :-1, t] + eps)) * torch.exp(1j * fpd[:, :, t])  # (B, F-1)
        # v_k: predicted ratio Y_t[k] / Y_{t-1}[k]
        ratio_v = (mag[:, :,  t] / (mag[:, :,  t-1] + eps)) * torch.exp(1j * tpd[:, :, t]) # (B, F)

        # Main diagonal of the normal equations:
        #   1 (time term) + |u_k|^2 (if k < F-1) + 1 (if k > 0)
        abs_u_sq = ratio_u.abs().square()
        diag = torch.ones((B, F), dtype=torch.complex64, device=device)
        diag[:, 0] += abs_u_sq[:, 0]
        diag[:, 1:-1] += abs_u_sq[:, 1:] + 1
        diag[:, -1] += 1

        # Off-diagonals: A[k, k-1] = -u_{k-1}, A[k, k+1] = -conj(u_k) (Hermitian)
        lower = -ratio_u.clone()         # (B, F-1)
        upper = -ratio_u.conj()          # (B, F-1)
        rhs = Y_prev * ratio_v           # (B, F)

        z = torch.empty_like(rhs)
        for b in range(B):
            z[b] = solve_tridiag_cpu(lower[b].cpu(), diag[b].cpu(), upper[b].cpu(), rhs[b].cpu())
        # keep only the solved phase; the magnitude comes from `mag`
        stft_hat[:, :, t] = mag[:, :, t] * torch.exp(1j * torch.angle(z))

    istft = CausalISTFT(n_fft, hop_length, n_fft).to(device)
    y = istft(stft_hat)
    if squeeze and y.shape[0] == 1:
        y = y.squeeze(0)
    return y


@torch.no_grad()
def mag_fpd_bpd_to_waveform(mag: torch.Tensor, fpd: torch.Tensor, bpd: torch.Tensor, n_fft: int, hop_length: int, eps: float = 1e-8, squeeze: bool = False) -> torch.Tensor:
    """Reconstruct a waveform from magnitude, FPD and BPD.

    BPD is converted to TPD with bpd_to_tpd. ``keep_first_frame=True`` keeps
    the frame count equal to ``mag``'s. The result is passed to
    mag_fpd_tpd_to_waveform.

    ``eps`` and ``squeeze`` are forwarded unchanged. Their defaults equal
    mag_fpd_tpd_to_waveform's, so existing calls behave exactly as before.
    """
    tpd = bpd_to_tpd(bpd, n_fft=n_fft, hop_length=hop_length, keep_first_frame=True)
    return mag_fpd_tpd_to_waveform(mag, fpd, tpd, n_fft=n_fft, hop_length=hop_length,
                                   eps=eps, squeeze=squeeze)


In [9]:
@torch.no_grad()
def inpaint_k_between_pairs_linear(mag: torch.Tensor, k: int) -> torch.Tensor:
    """
    Overwrite the k frames between consecutive anchors with linear blends (T unchanged).

    Anchors sit at frames 0, k+1, 2(k+1), ...; only anchor pairs whose right
    anchor lies inside the sequence are used. Trailing frames without a right
    anchor are left untouched. Frame ``left + j`` becomes
    ``(1 - j/(k+1)) * mag[left] + j/(k+1) * mag[left + k + 1]``.

    mag: (B, F, T) linear magnitude. Returns a new tensor, or ``mag`` itself
    when k <= 0.
    Example: k=1 replaces frames 1, 3, 5, ... by the mean of their neighbours.
    """
    if k <= 0:
        return mag

    _, _, n_frames = mag.shape
    stride = k + 1
    result = mag.clone()

    n_pairs = max(0, (n_frames - 1) // stride)   # anchor pairs that fit
    if n_pairs == 0:
        return result

    span = n_pairs * stride
    left_anchors = mag[:, :, 0:span:stride]
    right_anchors = mag[:, :, stride:span + stride:stride]
    for offset in range(1, stride):
        weight = offset / float(stride)
        result[:, :, offset:span:stride] = (1.0 - weight) * left_anchors + weight * right_anchors
    return result


In [10]:
# --- small conv stack w/ causal time padding ---

def _causal_pad(x, k_f, k_t):
    """Zero-pad a (B, C, F, T) map for a (k_f, k_t) kernel: causal in time, 'same' in frequency.

    Time gets ``k_t - 1`` zeros on the left only, so output frame t never sees
    input frames after t. Frequency gets ``(k_f - 1) // 2`` zeros before the
    first bin and ``k_f // 2`` after the last. A stride-1 conv therefore keeps
    the number of bins even for even ``k_f``; the old symmetric ``k_f // 2``
    padding grew it by one. For odd ``k_f`` the padding is unchanged.
    """
    pad_t = (k_t - 1, 0)
    pad_f = ((k_f - 1) // 2, k_f // 2)
    return tF.pad(x, pad_t + pad_f, mode="constant", value=0.0)

class CausalConv2d(nn.Module):
    """Weight-normalised Conv2d over (B, C, F, T) maps, causal along time.

    The input is padded with ``_causal_pad`` (left-only along time) before a
    conv with no built-in padding. Kernel is (k_f, k_t), stride (1, stride_t).
    """
    def __init__(self, Cin, Cout, k_f, k_t, stride_t=1, groups=1, bias=True):
        super().__init__()
        self.k_f = k_f
        self.k_t = k_t
        base_conv = nn.Conv2d(
            Cin,
            Cout,
            kernel_size=(k_f, k_t),
            stride=(1, stride_t),
            padding=0,
            groups=groups,
            bias=bias,
        )
        self.conv = weight_norm(base_conv)

    def forward(self, x):
        padded = _causal_pad(x, self.k_f, self.k_t)
        return self.conv(padded)

class FreqGatedConv(nn.Module):
    """Gated causal conv: ``conv1(x) * sigmoid(conv2(x))``, both CausalConv2d with the same kernel."""
    def __init__(self, in_ch: int, out_ch: int, k_f: int, k_t: int):
        super().__init__()
        self.conv1 = CausalConv2d(in_ch, out_ch, k_f=k_f, k_t=k_t)
        self.conv2 = CausalConv2d(in_ch, out_ch, k_f=k_f, k_t=k_t)

    def forward(self, x):
        gate = torch.sigmoid(self.conv2(x))
        return self.conv1(x) * gate

class Stem(nn.Module):
    """Input stem mapping ``in_ch`` channels to 10.

    Pipeline: [BatchNorm if bn] -> CausalConv2d to 50 ch (k_f=3, k_t=4) ->
    LeakyReLU(0.1) -> gated 1x1 conv to 10 ch -> [BatchNorm if bn].
    """
    def __init__(self, in_ch: int, bn: bool = False):
        super().__init__()
        input_norm = nn.BatchNorm2d(in_ch) if bn else nn.Identity()
        self.net1 = nn.Sequential(
            input_norm,
            CausalConv2d(in_ch, 50, k_f=3, k_t=4, stride_t=1),
            nn.LeakyReLU(0.1, inplace=True),
        )
        self.net2 = FreqGatedConv(50, 10, k_f=1, k_t=1)
        self.norm = nn.BatchNorm2d(10) if bn else nn.Identity()

    def forward(self, x):
        hidden = self.net1(x)
        hidden = self.net2(hidden)
        return self.norm(hidden)

class BodyBlock(nn.Module):
    """One body layer: conv -> [BatchNorm if bn] -> LeakyReLU(0.1).

    With ``k_t > 1`` the conv is a CausalConv2d. Otherwise there is no temporal
    context, and a plain weight-normalised Conv2d with symmetric frequency
    padding ``(k_f - 1) // 2`` is used; this keeps F for odd ``k_f``.
    """
    def __init__(self, dim, k_f=1, k_t=1, stride_t=1, bn=False):
        super().__init__()
        if k_t <= 1:
            conv = weight_norm(nn.Conv2d(
                dim, dim,
                kernel_size=(k_f, 1),
                stride=(1, stride_t),
                padding=((k_f - 1) // 2, 0),
            ))
        else:
            conv = CausalConv2d(dim, dim, k_f=k_f, k_t=k_t, stride_t=stride_t)
        norm = nn.BatchNorm2d(dim) if bn else nn.Identity()
        self.block = nn.Sequential(conv, norm, nn.LeakyReLU(0.1, inplace=True))

    def forward(self, x):
        return self.block(x)

class Body(nn.Module):
    """Residual stack of ``depth`` BodyBlocks, ``x <- x + block(x)``.

    forward(x, z=None) returns ``(norm(x_final), [x after each block])``.
    When ``use_film`` is set and ``z`` is given, each block output is
    additionally modulated as ``y + film(y, z)``.
    NOTE(review): ``FiLM`` is not defined in the visible part of this file, so
    ``use_film=True`` presumably fails at construction; confirm where it lives.
    """
    def __init__(self, dim=10, depth=5, k_f=1, k_t=1, stride_t=1, bn=False, use_film=False, film_scale=0.05):
        super().__init__()
        blocks = [BodyBlock(dim, k_f=k_f, k_t=k_t, stride_t=stride_t, bn=bn) for _ in range(depth)]
        self.layers = nn.ModuleList(blocks)
        self.norm = nn.BatchNorm2d(dim) if bn else nn.Identity()
        self.use_film = use_film
        self.film = FiLM(dim, hidden=64, strength=film_scale) if use_film else None

    def forward(self, x, z=None):
        apply_film = self.use_film and z is not None
        intermediates = []
        for layer in self.layers:
            delta = layer(x)
            if apply_film:
                delta = delta + self.film(delta, z)
            x = x + delta
            intermediates.append(x)
        return self.norm(x), intermediates

class Head(nn.Module):
    """Output head: gated conv 20 -> 50 ch (k_f=3, k_t=1), then two 1x1 projections.

    forward returns ``(bpd, fpd)``, each with ``out_per_target`` channels.
    """
    def __init__(self, out_per_target: int):
        super().__init__()
        self.pre = FreqGatedConv(20, 50, k_f=3, k_t=1)
        self.out_bpd = weight_norm(nn.Conv2d(50, out_per_target, kernel_size=1))
        self.out_fpd = weight_norm(nn.Conv2d(50, out_per_target, kernel_size=1))

    def forward(self, x):
        shared = self.pre(x)
        bpd = self.out_bpd(shared)
        fpd = self.out_fpd(shared)
        return bpd, fpd


In [None]:
class PhaseDiffPredictionModel(nn.Module):
    """
    Baseline phase-difference predictor: FPD and BPD from STFT magnitude.

    Input is either a waveform (converted with the internal CausalSTFT) or a
    linear magnitude spectrogram (B, F, T) with F = n_fft // 2 + 1.

    Processing:
    - The magnitude is mapped to log10(max(|S|, 1e-6)) / 6.
    - The Nyquist bin is dropped.
    - The remaining F-1 bins are folded into ``freq_fold_size`` channels,
      giving (B, freq_fold_size, (F-1)/freq_fold_size, T). F-1 must be
      divisible by ``freq_fold_size`` for the rearrange to succeed.

    forward returns ``(fpd (B, F-1, T), bpd (B, F, T), fmaps=[])``. The BPD
    Nyquist bin is filled by duplicating the last predicted bin.

    NOTE(review): ``stem_merge`` is created but never used in forward
    (presumably kept so old state_dicts still load). With ``use_film=True``,
    forward references ``self.film_stem`` and ``norm_z``, neither of which is
    defined in the visible source, so that path fails as written.
    """
    def __init__(self, n_fft: int, hop_length: int, freq_fold_size: int = 2, bn: bool = False, use_film: bool = False):
        super().__init__()
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.freq_fold_size = freq_fold_size
        self.use_film = use_film

        self.stft = CausalSTFT(n_fft, hop_length, n_fft)
        self.stem_mag = Stem(in_ch=freq_fold_size, bn=bn)
        self.stem_merge = nn.Conv2d(10, 10, kernel_size=1)  # unused in forward (see class NOTE)

        # residual body (FiLM disabled here regardless of `use_film`) and output head
        self.body = Body(bn=bn, use_film=False)
        self.head = Head(out_per_target=freq_fold_size)

        self.apply(self._init)

    @staticmethod
    def _init(m):
        # N(0, 0.02) weights and zero biases for conv/linear layers.
        # NOTE(review): most convs here are wrapped with weight_norm; if that is
        # torch.nn.utils.weight_norm, `m.weight` is recomputed from
        # weight_g/weight_v before each forward, so this normal_ init may not
        # persist for them (biases are still zeroed) — confirm.
        if isinstance(m, (nn.Conv2d, nn.Conv1d, nn.Linear)):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)
            if m.bias is not None: nn.init.zeros_(m.bias)

    def forward(self, *, wav: Optional[torch.Tensor] = None, mag: Optional[torch.Tensor] = None, cond_z: Optional[torch.Tensor] = None):
        """
        Args (keyword-only, exactly one of wav / mag):
            wav: waveform batch passed to the internal CausalSTFT.
            mag: (B, F, T) linear magnitude.
            cond_z: conditioning for the FiLM path (only read when use_film).

        Returns:
            (fpd, bpd, fmaps) with fpd (B, F-1, T), bpd (B, F, T), fmaps == [].

        Raises:
            ValueError: if both or neither of wav / mag are given.
        """
        fmaps = []  # intermediate feature maps are not collected; always returned empty

        if (wav is None) == (mag is None):
            raise ValueError("Provide either wav OR mag")

        if wav is not None:
            mag_log = self.stft(wav).abs().clamp_min(1e-6).log10().unsqueeze(1) / 6.0  # (B,1,F,T)
        else:
            mag_log = mag.clamp_min(1e-6).log10().unsqueeze(1) / 6.0

        # Fold frequency bins into 'channels' = freq_fold_size
        mag_log = mag_log[:, :, :-1, :]                                   # drop Nyquist for even split
        mag_log = rearrange(mag_log, 'b 1 (f n) t -> b n f t', n=self.freq_fold_size)

        z0 = self.stem_mag(mag_log)                                       # (B,10,F,T)

        if self.use_film:
            if cond_z is None:
                cond_z = torch.ones(z0.size(0), device=z0.device)
            z0 = self.film_stem(z0, norm_z(cond_z))

        z1, _ = self.body(z0)                                             # (B,10,F,T)
        z = torch.cat([z0, z1], dim=1)                                    # (B,20,F,T)
        bpd, fpd = self.head(z)                                           # (B,2,F',T)

        # Unfold to (B, F_flat, T)
        bpd = rearrange(bpd, 'b n f t -> b (f n) t')  # expect F_flat == (F_in-1)
        fpd = rearrange(fpd, 'b n f t -> b (f n) t')

        # mag_log.shape[2] is the folded bin count here, so this recovers the original F
        F_in  = mag_log.shape[2] * self.freq_fold_size + 1  # add Nyquist we dropped
        F_fpd = F_in - 1

        # Make bpd exactly F_in by duplicating the last bin if needed
        if bpd.size(1) < F_in:
            bpd = torch.cat([bpd, bpd[:, -1:, :]], dim=1)
        elif bpd.size(1) > F_in:
            bpd = bpd[:, :F_in, :]

        # Make fpd exactly F_in-1
        if fpd.size(1) < F_fpd:
            fpd = torch.cat([fpd, fpd[:, -1:, :]], dim=1)
        elif fpd.size(1) > F_fpd:
            fpd = fpd[:, :F_fpd, :]

        return fpd, bpd, fmaps



## Loss and simple log-magnitude stretching helpers (bilinear)

In [None]:
def vm_loss(target_rad: torch.Tensor, pred_rad: torch.Tensor) -> torch.Tensor:
    """Negative mean cosine of the angle error (von Mises-style loss).

    Both inputs are angles in radians and must broadcast together. The loss
    reaches its minimum of -1 when every pair agrees modulo 2*pi.
    """
    cos_err = torch.cos(target_rad - pred_rad)
    return cos_err.mean().neg()

@torch.no_grad()
def stretch_logmag(mag: torch.Tensor, factor: float) -> torch.Tensor:
    """
    Resize a magnitude spectrogram along time by ``factor``, interpolating in log10.

    mag: (B, F, T) linear amplitude -> (B, F, ~T*factor) linear amplitude.
    Interpolation is bilinear (align_corners=False). Values are floored at 1e-8
    both before the log and after returning to the linear domain.
    """
    log_mag = torch.log10(mag.clamp_min(1e-8))[:, None]     # (B,1,F,T)
    resized = tF.interpolate(log_mag, scale_factor=(1.0, factor), mode="bilinear", align_corners=False)
    return torch.pow(10, resized[:, 0]).clamp_min(1e-8)



@torch.no_grad()
def logmag_stretch_then_match_T(mag: torch.Tensor, factor: float) -> torch.Tensor:
    """
    Stretch magnitude along time by ``factor`` in the log10 domain, then
    resample back to the original T so targets still align.

    mag: (B, F, T) linear magnitude; the result has the same shape. Both
    resizes are bilinear (align_corners=False). Values are floored at 1e-8.
    """
    _, n_freq, n_frames = mag.shape
    log_mag = mag.clamp_min(1e-8).log10().unsqueeze(1)                    # (B,1,F,T)
    stretched = tF.interpolate(log_mag, scale_factor=(1.0, factor),
                               mode="bilinear", align_corners=False)      # (B,1,F,T')
    restored = tF.interpolate(stretched, size=(n_freq, n_frames),
                              mode="bilinear", align_corners=False)       # (B,1,F,T)
    return (10 ** restored.squeeze(1)).clamp_min(1e-8)



### Checkpoint helpers

In [None]:
def save_ckpt(model: nn.Module, path: str, cfg: "STFTCfg"):
    """Save ``model.state_dict()`` to *path* and ``cfg.to_json()`` to ``path + ".json"``.

    The parent directory of *path* is created first if it does not exist.
    """
    parent_dir = os.path.dirname(path) or "."
    os.makedirs(parent_dir, exist_ok=True)
    torch.save(model.state_dict(), path)
    Path(path + ".json").write_text(cfg.to_json())

def load_ckpt(path: str, device: Optional[str] = None, use_film: bool = False, fallback: Optional["STFTCfg"] = None) -> Tuple[nn.Module, "STFTCfg"]:
    """
    Load a PhaseDiffPredictionModel checkpoint written by save_ckpt.

    STFT parameters are read from the ``path + ".json"`` sidecar when it
    exists. Otherwise *fallback* is used, or a fresh default ``STFTCfg()``
    when *fallback* is None. The old default was a single ``STFTCfg()``
    instance shared by every call.

    Args:
        path: state_dict file. A dict wrapping it under "state_dict" is also accepted.
        device: target device; defaults to "cuda" if available, else "cpu".
        use_film: forwarded to PhaseDiffPredictionModel.
        fallback: config to use when no sidecar JSON exists.

    Returns:
        (model in eval mode on *device*, cfg).

    Weights are loaded with ``strict=False`` (kept for compatibility). Any
    missing or unexpected keys are reported via ``warnings.warn`` instead of
    being silently ignored.
    """
    import warnings

    device = device or ("cuda" if torch.cuda.is_available() else "cpu")

    cfg_path = path + ".json"
    cfg = fallback if fallback is not None else STFTCfg()
    if os.path.exists(cfg_path):
        with open(cfg_path) as f:
            cfg = STFTCfg.from_json(f.read())

    model = PhaseDiffPredictionModel(n_fft=cfg.n_fft, hop_length=cfg.hop, use_film=use_film).to(device)
    # NOTE(security): torch.load unpickles the file; only load trusted
    # checkpoints (consider weights_only=True on PyTorch versions that support it).
    sd = torch.load(path, map_location=device)
    if isinstance(sd, dict) and "state_dict" in sd:
        sd = sd["state_dict"]
    result = model.load_state_dict(sd, strict=False)
    if result.missing_keys or result.unexpected_keys:
        warnings.warn(
            f"load_ckpt({path!r}): missing keys {list(result.missing_keys)}, "
            f"unexpected keys {list(result.unexpected_keys)}"
        )
    model.eval()
    return model, cfg


### Audio I/O helpers

In [14]:
def load_audio(path: str, target_sr: int) -> Tuple[torch.Tensor, int]:
    """
    Read an audio file as a mono float32 tensor at *target_sr*.

    Reading first tries torchaudio with the "soundfile" backend. On any
    exception it falls back to the ``soundfile`` package directly.
    Multi-channel audio is averaged to mono, and audio is resampled with
    torchaudio.functional.resample when the file rate differs.

    Returns (wav, sr) with wav shaped (1, T) and sr == target_sr.
    """
    try:
        wav, sr = torchaudio.load(path, backend="soundfile")
    except Exception:
        # best-effort fallback path
        import soundfile as sf
        frames, sr = sf.read(path, always_2d=True)       # (T, C)
        wav = torch.from_numpy(frames).T.contiguous()    # (C, T)

    wav = wav.float()
    if wav.size(0) > 1:
        wav = wav.mean(dim=0, keepdim=True)
    if sr == target_sr:
        return wav, sr
    return torchaudio.functional.resample(wav, sr, target_sr), target_sr

@torch.no_grad()
def match_rms(y: torch.Tensor, ref: torch.Tensor, eps=1e-8) -> torch.Tensor:
    """Scale *y* to the RMS of *ref* (per row, over the last axis), then clip to [-1, 1].

    1-D inputs are treated as a single row, so the result has at least 2 dims.
    ``eps`` is added inside each square root.
    """
    y = y.unsqueeze(0) if y.ndim == 1 else y
    ref = ref.unsqueeze(0) if ref.ndim == 1 else ref

    def _rms(x):
        return torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)

    gain = _rms(ref) / _rms(y)
    return (y * gain).clamp(-1.0, 1.0)


In [None]:
# --- Data bootstrap utilities  ---
def load_subset_paths(paths_txt="subset_paths.txt", data_dirs=None, exts=(".wav",".flac",".mp3",".ogg")):
    """
    Load the persisted file list if available; otherwise scan once and persist it.

    If *paths_txt* exists, its non-blank lines (stripped) are returned as-is.

    Otherwise the scan works as follows:
    - Every existing directory in *data_dirs* (default: ./ears, ./libritts,
      ./vctk, ./data) is searched recursively for files ending in *exts*.
    - Empty files and directories are skipped.
    - Duplicates from overlapping directories are removed.
    - The list is sorted, so it (and any split derived from it) is
      reproducible across filesystems.
    The result is written to *paths_txt*, one path per line.

    Returns: list[str] of paths as produced by glob (relative or absolute,
    following data_dirs).
    Raises: AssertionError if the scan finds no audio. It is raised explicitly,
    so it also fires under ``python -O``.
    """
    if os.path.exists(paths_txt):
        with open(paths_txt) as f:
            subset_paths = [ln.strip() for ln in f if ln.strip()]
        print(f"Loaded {len(subset_paths)} paths from {paths_txt}.")
        return subset_paths

    # we just scan once
    if data_dirs is None:
        data_dirs = ["./ears", "./libritts", "./vctk", "./data"]

    found = set()
    for root in data_dirs:
        if not os.path.isdir(root):
            continue
        for ext in exts:
            found.update(glob.glob(os.path.join(root, f"**/*{ext}"), recursive=True))

    subset_paths = sorted(p for p in found if os.path.isfile(p) and os.path.getsize(p) > 0)
    if not subset_paths:
        raise AssertionError(f"No audio found. Edit data_dirs={data_dirs} or place a {paths_txt}.")

    with open(paths_txt, "w") as f:
        f.write("\n".join(subset_paths))
    print(f"Scanned and saved {len(subset_paths)} paths to {paths_txt}.")
    return subset_paths


class ListAudioDataset(Dataset):
    """
    Random fixed-length mono crops from a list of audio files.

    Each item is built in these steps:
    - Load with torchaudio and average channels to mono.
    - Resample to ``sr`` if needed.
    - Tile clips shorter than the crop length.
    - Take a uniformly random window of ``int(seconds * sr)`` samples,
      returned as shape (N,).
    """
    def __init__(self, paths, sr=16000, seconds=3):
        self.paths = paths
        self.sr = sr
        self.samples = int(seconds * sr)

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        wav, file_sr = torchaudio.load(self.paths[idx])   # (C, L)
        mono = wav.mean(0, keepdim=True)                  # (1, L)
        if file_sr != self.sr:
            mono = torchaudio.functional.resample(mono, file_sr, self.sr)

        length = mono.shape[-1]
        if length < self.samples:
            n_tiles = -(-self.samples // length)          # ceiling division
            mono = mono.repeat(1, n_tiles)

        offset = random.randint(0, mono.shape[-1] - self.samples)
        crop = mono[:, offset:offset + self.samples]
        return crop.squeeze(0)                            # (N,)

def make_dataloader(subset_paths, *, sr=16000, seconds=3, batch_size=6, num_workers=0, shuffle=True, drop_last=True):
    """Wrap *subset_paths* in a ListAudioDataset plus DataLoader and print a one-line summary."""
    dataset = ListAudioDataset(subset_paths, sr=sr, seconds=seconds)
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        drop_last=drop_last,
        num_workers=num_workers,
    )
    print(f"DataLoader ready: {len(dataset)} files → {len(loader)} batches")
    return loader


In [None]:
# ---- Stretch-aware front-end helpers ----------------------------------------

def _interp_time_linear_1d(x_1ft: torch.Tensor, new_T: int) -> torch.Tensor:
    """
    Linearly resample a (1, F, T) float tensor along time to (1, F, new_T).

    The frequency size is passed explicitly, so only time is resized.
    Interpolation is bilinear with align_corners=False.
    """
    n_freq = x_1ft.size(1)
    as_image = x_1ft[None]                                            # (1, 1, F, T)
    resized = tF.interpolate(as_image, size=(n_freq, new_T), mode="bilinear", align_corners=False)
    return resized[0]

@torch.no_grad()
def resample_logmag_per_item(mag: torch.Tensor, z: torch.Tensor) -> list:
    """
    Resample each item's log10 magnitude along time by its own stretch factor.

    mag: (B, F, T) linear magnitude (floored at 1e-8 before the log).
    z  : (B,) stretch factors.
    Returns a list of B tuples ``(log_mag (F, T_b), T_b)`` with
    ``T_b = max(1, round(z_b * T))``.
    """
    n_items, _, n_frames = mag.shape
    log_mag = mag.clamp_min(1e-8).log10()                         # (B,F,T)
    resampled_items = []
    for b in range(n_items):
        new_len = max(1, int(round(n_frames * float(z[b].item()))))
        resampled = _interp_time_linear_1d(log_mag[b:b + 1], new_len)  # (1,F,T_b)
        resampled_items.append((resampled.squeeze(0), new_len))
    return resampled_items

def sinusoidal_positional_embedding(T: int, dim: int, device) -> torch.Tensor:
    """
    Fixed transformer-style sin/cos positional embedding for positions 0..T-1.

    The first dim/2 rows hold sin(t * w_i) and the last dim/2 rows hold
    cos(t * w_i), with w_i = 1 / 10000^(2i/dim).
    Returns a float32 tensor of shape (dim, T) on *device*.
    Raises AssertionError for odd *dim*.
    """
    assert dim % 2 == 0, "dim must be even"
    half = dim // 2
    positions = torch.arange(T, device=device, dtype=torch.float32)    # (T,)
    idx = torch.arange(half, device=device, dtype=torch.float32)       # (dim/2,)
    inv_freq = 1.0 / (10000 ** (2 * idx / dim))
    phase = inv_freq[:, None] * positions[None, :]                     # (dim/2, T)
    return torch.cat((phase.sin(), phase.cos()), dim=0)




In [None]:


def build_stretch_features(
    mag: torch.Tensor, z: torch.Tensor, *,
    freq_fold_size: int = 2, pos_dim: int = 16,
    sampler=None,
    pe_module=None,
    pe_concat_fixed: bool = False,
    pe_jitter: float = 0.0,
    pe_extrap_frac: float = 0.0,
):
    """
    Build stretch-aware model inputs from a magnitude batch and per-item stretch factors.

    For each item b:
    - The log10 magnitude is resampled along time to
      Tb = max(1, round(z[b] * T)) frames with *sampler*
      (default ``Sampler1D("linear")``).
    - The Nyquist bin is dropped.
    - The remaining bins are zero-padded to a multiple of *freq_fold_size*
      and folded into channels.
    - Two conditioning inputs are appended as channels, broadcast over every
      folded frequency row: a positional embedding of Tb frames (from
      ``pe_module`` if given, else fixed sinusoidal) and a constant map
      holding z[b].
    Items are then right-padded with zeros along time to the longest Tb.

    Args:
      mag: (B, F, T) linear magnitude.
      z: (B,) stretch factors.
      freq_fold_size: number of frequency bins folded into channels.
      pos_dim: positional-embedding channels.
      sampler: callable ``(x, T_out)``; assumed to map (1, 1, F, T) to
        (1, 1, F, T_out) — TODO confirm against Sampler1D.
      pe_module: optional callable
        ``(T_out=, device=, jitter=, extrap_frac=) -> (dim, T_out)``.
      pe_concat_fixed: with ``pe_module``, prepend ``pos_dim // 2`` sinusoidal rows.
      pe_jitter, pe_extrap_frac: forwarded to ``pe_module``.

    Returns:
      feats: (B, C_in, F_fold, T_max), where C_in = freq_fold_size + pos_dim + 1
        when the positional embedding supplies pos_dim rows.
      mag_z: (B, F, T_max) linear magnitude of the resampled items.
    """
    B, _, T = mag.shape
    device = mag.device
    if sampler is None:
        sampler = Sampler1D("linear")

    # 1) per-item resample of the log-magnitude to Tb ~= z[b] * T frames
    logmag = mag.clamp_min(1e-8).log10()  # (B,F,T)
    items = []
    T_max = 0
    for b in range(B):
        Tb = max(1, int(round(float(z[b]) * T)))
        x = logmag[b:b+1].unsqueeze(1)           # (1,1,F,T)
        x_ = sampler(x, Tb).squeeze(1)           # (1,F,Tb)
        items.append(x_.squeeze(0))              # (F,Tb)
        T_max = max(T_max, Tb)

    feats_list, mags_lin = [], []
    for b in range(B):
        m_log = items[b]                         # (F,Tb)
        Tb = m_log.size(1)

        # drop Nyquist, zero-pad bins to a multiple of freq_fold_size, fold into channels
        m_log = m_log[:-1, :]                    # (F-1,Tb)
        F_use = m_log.size(0)
        m_in = m_log.unsqueeze(0).unsqueeze(0)   # (1,1,F-1,Tb)
        pad_f = (freq_fold_size - (F_use % freq_fold_size)) % freq_fold_size
        if pad_f:
            m_in = tF.pad(m_in, (0, 0, 0, pad_f))
        m_fold = rearrange(m_in, 'b c (f n) t -> b n f t', n=freq_fold_size)  # (1,n,F_fold,Tb)
        F_fold = m_fold.size(2)

        # 2) positional embedding, (pos_dim, Tb)
        if pe_module is not None:
            pe_learn = pe_module(T_out=Tb, device=device, jitter=pe_jitter, extrap_frac=pe_extrap_frac)  # (dim,Tb)
            pe_list = [pe_learn]
            if pe_concat_fixed:
                pe_fixed = sinusoidal_positional_embedding(Tb, pos_dim//2, device)
                pe_learn = pe_learn[:pos_dim - pe_fixed.size(0), :] if pos_dim > pe_fixed.size(0) else pe_learn
                pe_list = [pe_fixed, pe_learn]
            pe = torch.cat(pe_list, dim=0)[:pos_dim, :]    # (pos_dim, Tb)
        else:
            pe = sinusoidal_positional_embedding(Tb, pos_dim, device)

        # torch.cat does not broadcast: the conditioning channels must span all
        # F_fold frequency rows, otherwise concatenation fails when F_fold > 1.
        pe = pe.to(m_fold.dtype).unsqueeze(0).unsqueeze(2).expand(-1, -1, F_fold, -1)  # (1,pos_dim,F_fold,Tb)

        # 3) constant channel holding this item's stretch factor
        z_map = (z[b:b+1].to(device=m_fold.device, dtype=m_fold.dtype)
                 .view(1, 1, 1, 1).expand(1, 1, F_fold, Tb))                   # (1,1,F_fold,Tb)

        # time-pad to T_max
        if Tb < T_max:
            pad_t = (0, T_max - Tb)
            m_fold = tF.pad(m_fold, pad_t)
            pe     = tF.pad(pe, pad_t)
            z_map  = tF.pad(z_map, pad_t)

        feats_list.append(torch.cat([m_fold, pe, z_map], dim=1))  # (1, n+pos_dim+1, F_fold, T_max)

        # full-F magnitude (linear) for inversion.
        # NOTE(review): padded frames hold log10 value 0, i.e. linear 1.0 rather
        # than the 1e-8 floor — confirm this is intended for shorter items.
        full_log = items[b]              # (F,Tb) before Nyquist drop
        if Tb < T_max:
            full_log = tF.pad(full_log.unsqueeze(0), (0, T_max - Tb)).squeeze(0)
        mags_lin.append((10.0**full_log).clamp_min(1e-8))

    feats = torch.cat(feats_list, dim=0)       # (B, C_in, F_fold, T_max)
    mag_z = torch.stack(mags_lin, dim=0)       # (B, F, T_max)
    return feats, mag_z


In [None]:

class StretchAwareModel(nn.Module):
    """
    Phase-difference predictor for stretch-aware features.

    Consumes ``feats`` (B, C_in, F_fold, T) from build_stretch_features, with
    C_in = freq_fold_size + pos_dim + 1, and predicts (fpd, bpd). Predictions
    are unfolded along frequency and extended or trimmed to the bin count of
    ``mag_for_unfold`` so they match the magnitude used for inversion.
    """
    def __init__(self, n_fft: int, hop_length: int,
                 in_channels: int,          # freq_fold_size + pos_dim + 1
                 freq_fold_size: int = 2, bn: bool = False):
        super().__init__()
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.freq_fold_size = freq_fold_size

        self.stem = Stem(in_ch=in_channels, bn=bn)
        self.body = Body(bn=bn, use_film=False)
        self.head = Head(out_per_target=freq_fold_size)

        self.apply(self._init)

    @staticmethod
    def _init(m):
        """N(0, 0.02) weights and zero biases for conv / linear layers."""
        if not isinstance(m, (nn.Conv2d, nn.Conv1d, nn.Linear)):
            return
        nn.init.normal_(m.weight, mean=0.0, std=0.02)
        if m.bias is not None:
            nn.init.zeros_(m.bias)

    @staticmethod
    def _fit_bins(x: torch.Tensor, n_bins: int) -> torch.Tensor:
        """Move x (B, F', T) towards n_bins along dim 1.

        At most one bin is appended (by repeating the last one), then the
        result is trimmed if it is too long.
        """
        if x.size(1) < n_bins:
            x = torch.cat([x, x[:, -1:, :]], dim=1)
        if x.size(1) > n_bins:
            x = x[:, :n_bins, :]
        return x

    def forward(self, feats: torch.Tensor, mag_for_unfold: torch.Tensor):
        """
        feats: (B, C_in, F_fold, T) from build_stretch_features.
        mag_for_unfold: (B, F, T) magnitude; only its size along dim 1 (F) is read.
        Returns (fpd, bpd), adjusted towards F-1 and F bins respectively.
        """
        z0 = self.stem(feats)                               # (B, 10, F_fold, T)
        z1, _ = self.body(z0)                               # (B, 10, F_fold, T)
        bpd_fold, fpd_fold = self.head(torch.cat([z0, z1], dim=1))

        # unfold channels back into frequency: (B, n, F_fold, T) -> (B, F_fold*n, T)
        bpd = rearrange(bpd_fold, 'b n f t -> b (f n) t')
        fpd = rearrange(fpd_fold, 'b n f t -> b (f n) t')

        n_bins = mag_for_unfold.size(1)
        return self._fit_bins(fpd, n_bins - 1), self._fit_bins(bpd, n_bins)


In [None]:

def pv_stretch_waveform(wav: torch.Tensor, z: float, n_fft: int, hop: int) -> torch.Tensor:
    """
    Minimal phase-vocoder time stretch (batch-safe), implemented in plain PyTorch.

    Uses the standard phase-vocoder recurrence (as in librosa's
    ``phase_vocoder`` and torchaudio's ``functional.phase_vocoder``):
    - For every synthesis frame, the magnitude is linearly interpolated
      between the two analysis frames that bracket the source position.
    - The synthesis phase starts at the first analysis phase and advances by
      the expected per-hop phase advance of each bin, plus the measured,
      wrapped deviation between those two analysis frames.

    Args:
        wav: (B, N) waveform batch.
        z: stretch factor (>1 => longer/slower, <1 => shorter/faster); must be > 0.
        n_fft: FFT size, also used as the window length.
        hop: hop size shared by analysis and synthesis.

    Returns:
        Stretched waveform (B, ~z*N), as produced by ``CausalISTFT``.

    Raises:
        ValueError: if ``z`` is not strictly positive (NaN included).
    """
    if not z > 0:  # also rejects NaN
        raise ValueError(f"stretch factor z must be > 0, got {z!r}")

    device = wav.device
    stft = CausalSTFT(n_fft, hop, n_fft).to(device)
    istft = CausalISTFT(n_fft, hop, n_fft).to(device)

    # Analysis STFT: (B, F, T) complex
    S = stft(wav)
    mag = S.abs()                # (B, F, T)
    phase = torch.angle(S)       # (B, F, T)
    B, n_bins, T = S.shape       # n_bins avoids shadowing the module-level `F` alias

    # Synthesis frame count
    T_out = max(1, int(math.ceil(T * z)))

    # Expected phase advance of bin k over one hop: omega_k * hop = 2*pi*k*hop/n_fft.
    k = torch.arange(n_bins, device=device, dtype=phase.dtype)
    phase_advance = 2.0 * math.pi * hop * k / n_fft          # (F,)

    def princarg(x):
        # Wrap to [-pi, pi).
        return torch.remainder(x + math.pi, 2 * math.pi) - math.pi

    # The accumulator starts at the first analysis phase, so output frame 0 keeps it.
    phase_acc = phase[..., 0].clone()                        # (B, F)

    mag_out = torch.zeros(B, n_bins, T_out, device=device, dtype=mag.dtype)
    phase_out = torch.zeros_like(mag_out)

    # Main PV loop (vectorized over batch + freq, iterates over output time)
    for t_out in range(T_out):
        t_src = t_out / z
        # Clamp guards against float round-off pushing t0 past the last frame.
        t0 = min(int(math.floor(t_src)), T - 1)
        t1 = min(t0 + 1, T - 1)
        frac = t_src - t0

        # Linear magnitude interpolation between the bracketing analysis frames.
        mag_out[..., t_out] = (1.0 - frac) * mag[..., t0] + frac * mag[..., t1]
        # Emit the current phase *before* advancing it.
        phase_out[..., t_out] = phase_acc

        # Deviation of the measured inter-frame phase change from the expected advance.
        dphi = princarg(phase[..., t1] - phase[..., t0] - phase_advance)
        # Advance by one full synthesis hop. Unlike the magnitude, the deviation is
        # NOT scaled by `frac`: it estimates the bin's true frequency, which does not
        # depend on where the source position falls between frames. Wrapping keeps
        # the accumulator small, so float32 precision is not lost on long signals.
        phase_acc = princarg(phase_acc + phase_advance + dphi)

    # Recompose complex STFT and invert
    S_out = mag_out * torch.exp(1j * phase_out)   # (B, F, T_out)
    return istft(S_out)


In [None]:
# ---- Angle-safe helpers ------------------------------------------------------

def _wrap_pi(x: torch.Tensor) -> torch.Tensor:
    # wrap to (-pi, pi]
    return torch.remainder(x + math.pi, 2 * math.pi) - math.pi

@torch.no_grad()
def angle_time_interpolate(x: torch.Tensor, new_T: int) -> torch.Tensor:
    """
    Resize an angle tensor along time without branch-cut artefacts.

    The angles are projected onto the unit circle as (cos, sin). Each
    component is linearly interpolated along the last axis to ``new_T``
    samples. The pair is then mapped back with ``atan2`` and wrapped with
    ``_wrap_pi``.

    Args:
        x: (B, F, T) or (B, F-1, T) angle tensor in radians.
        new_T: target number of time steps.

    Returns:
        (B, x.size(1), new_T) angles in [-pi, pi).

    Raises:
        ValueError: if ``x`` is not 3-D.
    """
    if x.dim() != 3:
        raise ValueError(f"expected a 3-D (B, F, T) angle tensor, got shape {tuple(x.shape)}")
    # Fully-qualified call: does not depend on a `tF` alias being imported, and
    # does not collide with the file's `F` alias.
    interp = torch.nn.functional.interpolate
    c = interp(torch.cos(x), size=new_T, mode="linear", align_corners=False)
    s = interp(torch.sin(x), size=new_T, mode="linear", align_corners=False)
    return _wrap_pi(torch.atan2(s, c))

@torch.no_grad()
def compute_pv_targets_per_batch(wav: torch.Tensor, z: torch.Tensor,
                                 n_fft: int, hop: int, stft_mod: CausalSTFT):
    """
    Build phase-vocoder reference targets for a batch.

    Each item ``wav[b]`` is stretched independently by its own factor ``z[b]``
    with ``pv_stretch_waveform``. The stretched signals are then zero-padded on
    the right to the longest one, so that ``compute_mag_fpd_bpd`` runs once on
    the whole batch.

    Args:
        wav: (B, N) input waveforms.
        z:   (B,) per-item stretch factors.
        n_fft, hop: STFT parameters used by the phase vocoder and passed on to
            compute_mag_fpd_bpd.
        stft_mod: STFT module forwarded to compute_mag_fpd_bpd.

    Returns:
        y_refs: list of B unpadded stretched waveforms, each (1, N_b); lengths
            may differ per item and are aligned later by the caller.
        mag_ref, fpd_ref, bpd_ref: batched features of the zero-padded
            stretches, expected shapes (B, F, T'), (B, F-1, T'), (B, F, T').
    """
    y_refs = [
        pv_stretch_waveform(wav[i:i + 1], float(z[i].item()), n_fft, hop)
        for i in range(wav.size(0))
    ]

    # Right-pad every stretch with zeros up to the longest one.
    longest = max(y.shape[-1] for y in y_refs)
    y_pad = torch.zeros(len(y_refs), longest, device=wav.device)
    for row, y in zip(y_pad, y_refs):
        row[: y.shape[-1]] = y.squeeze(0)

    mag_ref, fpd_ref, bpd_ref = compute_mag_fpd_bpd(y_pad, n_fft, hop, stft_mod)
    return y_refs, mag_ref, fpd_ref, bpd_ref


In [None]:


def mrstft_loss(y_hat: torch.Tensor, y_ref: torch.Tensor, n_ffts=(512,1024,2048), hops=(80,160,320)) -> torch.Tensor:
    """
    Multi-resolution STFT loss between an estimated and a reference waveform.

    For each ``(n_fft, hop)`` pair, a ``CausalSTFT`` (window length = n_fft) is
    applied to both signals, and two terms are added:
      * the mean L1 distance between ``log1p`` magnitudes, and
      * spectral convergence ``||Yh - Yr||_F / (||Yr||_F + 1e-8)``, computed
        over the whole batch.
    The total is averaged over the number of resolutions.

    Args:
        y_hat: estimated waveform(s); must give the same spectrogram shape as y_ref.
        y_ref: reference waveform(s).
        n_ffts: FFT sizes, one per resolution.
        hops: hop sizes, paired element-wise with ``n_ffts``.

    Returns:
        Scalar loss tensor.

    Raises:
        ValueError: if ``n_ffts`` / ``hops`` are empty or differ in length.
    """
    n_ffts, hops = tuple(n_ffts), tuple(hops)
    # zip() would silently drop unpaired entries while the average below still
    # divides by len(n_ffts); reject that (and the empty / divide-by-zero case).
    if not n_ffts or len(n_ffts) != len(hops):
        raise ValueError(
            f"n_ffts and hops must be non-empty and of equal length, got {n_ffts} and {hops}"
        )

    device = y_hat.device
    loss = 0.0
    for n, h in zip(n_ffts, hops):
        stft = CausalSTFT(n, h, n).to(device)
        Yh = stft(y_hat).abs()
        Yr = stft(y_ref).abs()
        loss += (Yh.log1p() - Yr.log1p()).abs().mean()      # log-magnitude L1
        loss += (Yh - Yr).norm() / (Yr.norm() + 1e-8)       # spectral convergence
    return loss / len(n_ffts)

def si_sdr(ref: torch.Tensor, est: torch.Tensor, eps=1e-8) -> torch.Tensor:
    """
    Scale-invariant signal-to-distortion ratio in dB, over the last axis.

    Both signals are made zero-mean. The estimate is then split into its
    projection onto the reference (the "target") and the orthogonal remainder
    (the "noise"). The result is 10*log10(||target||^2 / ||noise||^2), with
    ``eps`` added to every energy/denominator for numerical safety.
    """
    r = ref - ref.mean(dim=-1, keepdim=True)
    e = est - est.mean(dim=-1, keepdim=True)

    ref_energy = r.pow(2).sum(dim=-1, keepdim=True) + eps
    scale = (e * r).sum(dim=-1, keepdim=True) / ref_energy
    projection = scale * r
    residual = e - projection

    numerator = projection.pow(2).sum(dim=-1) + eps
    denominator = residual.pow(2).sum(dim=-1) + eps
    return 10.0 * torch.log10(numerator / denominator)

# --- Data setup: seed, STFT config, shared STFT module, file list, loader ---
set_seed(7)
cfg = STFTCfg(sr=16000, n_fft=512, hop=80)
# win_length falls back to n_fft inside STFTCfg, so this is a full-length window.
stft = CausalSTFT(
    n_fft=cfg.n_fft,
    hop_length=cfg.hop,
    win_length=cfg.win_length,
).to(device)

subset_paths = load_subset_paths("subset_paths.txt")
dl = make_dataloader(
    subset_paths,
    sr=cfg.sr,
    seconds=3,
    batch_size=6,
)



Loaded 1847 paths from subset_paths.txt.
DataLoader ready: 1847 files → 307 batches
