## Imports, utilities, and baseline checkpoint

In [None]:
import math
import os
import random
from pathlib import Path

import torch
import torch.nn as nn
import torchaudio  


%run 00_shared_utils.ipynb

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)

# --- Load baseline model as TEACHER ------------------------------------------
# The teacher is frozen for distillation: eval() switches any BatchNorm/Dropout
# layers to inference behaviour, and requires_grad_(False) disables gradients on
# every parameter so only the student gets optimised.

BASELINE_CKPT = "checkpoints/baseline.pth"

baseline_model, cfg = load_ckpt(BASELINE_CKPT, device=device, use_film=False)
baseline_model.eval().requires_grad_(False)

print("Loaded baseline teacher from:", BASELINE_CKPT)
print("Baseline cfg:", cfg.__dict__)


## Positional encoding + StretchAware v2 model

In [None]:
# --- sinusoidal positional encoding (time axis) ------------------------

def sinusoidal_pe(T: int, dim: int, device=None) -> torch.Tensor:
    """
    Standard Transformer-style sinusoidal positional encoding over time.

    Even rows hold sin(t * w_k) and odd rows hold cos(t * w_k), with
    w_k = 10000 ** (-2k / dim).

    Args:
        T: number of time steps (columns).
        dim: number of encoding channels (rows). May be odd; the last sin row
            then has no cos partner.
        device: device on which the tensor is created.

    Returns:
        float32 tensor of shape (dim, T). All zeros (possibly empty) when
        dim == 0 or T == 0.
    """
    pe = torch.zeros(dim, T, device=device)
    if dim == 0 or T == 0:
        # Nothing to fill; also avoids the division by zero in the frequency term.
        return pe

    position = torch.arange(T, device=device, dtype=torch.float32).unsqueeze(1)  # (T,1)
    div_term = torch.exp(
        torch.arange(0, dim, 2, device=device, dtype=torch.float32)
        * (-math.log(10000.0) / dim)
    )  # (ceil(dim/2),)
    angles = (position * div_term).transpose(0, 1)  # (ceil(dim/2), T)

    pe[0::2, :] = torch.sin(angles)
    # For odd dim there is one fewer cos row than sin row, so trim the angles.
    pe[1::2, :] = torch.cos(angles[: dim // 2])
    return pe  # (dim, T)


# --- StretchAwarePhaseDiff v2 (teacher-student) ------------------------------

class StretchAwarePhaseDiffV2(nn.Module):
    """
    Stretch-aware phase-difference predictor trained to mimic the baseline (teacher)
    on spectrograms warped by a stretch factor z.

    Input:
        mag_z : (B, F, T)   -- warped linear magnitude (synthetic stretch)
        z     : (B,), (B,1), or a single value -- stretch factor per sample;
                a single value is broadcast to the whole batch

    Extra features:
        - sinusoidal time positional encoding (non-learnable)
        - log(z) as a global channel

    Output:
        - fpd_pred: (B, F-1, T)
        - bpd_pred: (B, F,   T)
    """

    # Time-axis dilation schedule of the conv trunk. Layer i uses
    # _DILATIONS[i % len(_DILATIONS)], so num_layers > 6 is honoured instead of
    # being silently truncated to 6 layers.
    _DILATIONS = (1, 2, 4, 8, 16, 1)

    def __init__(
        self,
        n_fft: int,
        hop_length: int,
        pe_dim: int = 16,
        hidden_channels: int = 64,
        num_layers: int = 6,
        k_f: int = 3,
        k_t: int = 5,
    ):
        """
        Args:
            n_fft: FFT size (stored on the module; not used by forward).
            hop_length: STFT hop (stored on the module; not used by forward).
            pe_dim: number of sinusoidal positional-encoding channels.
            hidden_channels: width of the conv trunk.
            num_layers: number of (Conv2d, GELU, BatchNorm2d) trunk layers.
            k_f: frequency kernel size; must be odd so F is preserved.
            k_t: time kernel size; must be odd so T is preserved.

        Raises:
            ValueError: if k_f or k_t is even ("same" padding is impossible).
        """
        super().__init__()
        if k_f % 2 == 0 or k_t % 2 == 0:
            # With padding k//2 an even kernel grows the axis by one (times the
            # dilation) per layer, breaking the (F, T) output contract.
            raise ValueError(
                f"k_f and k_t must be odd to keep (F, T) unchanged, got k_f={k_f}, k_t={k_t}"
            )

        self.n_fft = n_fft
        self.hop_length = hop_length
        self.pe_dim = pe_dim

        # logmag (1) + PE (pe_dim) + z-channel (1)
        in_channels = 1 + pe_dim + 1

        self.in_proj = nn.Conv2d(in_channels, hidden_channels, kernel_size=1)

        # Non-causal temporal conv stack with dilations to increase receptive field.
        blocks = []
        dilations = [self._DILATIONS[i % len(self._DILATIONS)] for i in range(num_layers)]
        for d in dilations:
            blocks.append(
                nn.Conv2d(
                    hidden_channels,
                    hidden_channels,
                    kernel_size=(k_f, k_t),
                    padding=(k_f // 2, d * (k_t // 2)),   # symmetric "same" padding
                    dilation=(1, d),
                )
            )
            blocks.append(nn.GELU())
            blocks.append(nn.BatchNorm2d(hidden_channels))
        self.net = nn.Sequential(*blocks)

        # Separate heads for BPD and FPD (1 channel each)
        self.out_bpd = nn.Conv2d(hidden_channels, 1, kernel_size=1)
        self.out_fpd = nn.Conv2d(hidden_channels, 1, kernel_size=1)

        self._init_weights()

    def _init_weights(self):
        """Small-normal init for conv/linear weights, zero biases."""
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.normal_(m.weight, mean=0.0, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, mag: torch.Tensor, z: torch.Tensor):
        """
        mag : (B, F, T)   linear magnitude (warped by z)
        z   : (B,), (B,1), or a single value (tensor or float) of stretch factors;
              a single value is broadcast to all B samples

        Returns:
            fpd_pred: (B, F-1, T)
            bpd_pred: (B, F,   T)

        Raises:
            ValueError: if z has neither 1 nor B elements.
        """
        B, Freq, T = mag.shape
        device = mag.device

        # 1) log-magnitude base channel
        logmag = mag.clamp_min(1e-6).log10()
        x = logmag.unsqueeze(1)  # (B,1,F,T)

        # 2) sinusoidal positional encoding over time (cast to the input dtype)
        pe = sinusoidal_pe(T, self.pe_dim, device=device).to(dtype=logmag.dtype)  # (pe_dim, T)
        pe = pe.unsqueeze(0).unsqueeze(2).expand(B, self.pe_dim, Freq, T)  # (B,pe_dim,F,T)

        # 3) stretch factor channel: use log(z) broadcast over F,T.
        # Move z onto mag's device/dtype so cat() and the conv weights agree.
        z_t = torch.as_tensor(z, device=device, dtype=logmag.dtype)
        z_flat = z_t.reshape(-1)
        if z_flat.numel() == 1:
            z_flat = z_flat.expand(B)
        elif z_flat.numel() != B:
            raise ValueError(
                f"z must have 1 or B={B} elements, got shape {tuple(z_t.shape)}"
            )
        z_norm = torch.log(z_flat + 1e-6)  # (B,)
        z_img = z_norm.reshape(B, 1, 1, 1).expand(B, 1, Freq, T)  # (B,1,F,T)

        feats = torch.cat([x, pe, z_img], dim=1)  # (B,1+pe_dim+1,F,T)

        # 4) Conv trunk
        h = self.in_proj(feats)
        h = self.net(h)

        # 5) Heads
        bpd = self.out_bpd(h).squeeze(1)       # (B,F,T)
        fpd_full = self.out_fpd(h).squeeze(1)  # (B,F,T)
        fpd = fpd_full[:, :-1, :]              # (B,F-1,T) - match teacher target shape

        return fpd, bpd


## Config, data, model, stretch factors

In [None]:
# --- Core setup ---------------------------------------------------------------

set_seed(7)

# Reuse the STFT settings stored in the baseline checkpoint so the student is
# trained on exactly the same time-frequency grid as the teacher.
sr, n_fft, hop = cfg.sr, cfg.n_fft, cfg.hop

stft = CausalSTFT(n_fft, hop, n_fft).to(device)
print("Training StretchAwarePhaseDiff v2 (teacher) with cfg:", cfg.__dict__)

subset_paths = load_subset_paths("subset_paths.txt")
print("Total files in subset_paths:", len(subset_paths))

# Debug option: uncomment to train on a small subset before the full run.
# N_FILES = 200
# subset_paths = subset_paths[:N_FILES]
# print("Using first", len(subset_paths), "files for v2 training")

dl = make_dataloader(
    subset_paths,
    sr=sr,
    seconds=3.0,     # short clips keep each step fast
    batch_size=6,    # tune to trade memory for throughput
    num_workers=0,
    shuffle=True,
    drop_last=True,
)

# Student model: STFT geometry from the teacher plus the stretch-aware trunk settings.
student_hparams = dict(pe_dim=16, hidden_channels=64, num_layers=6, k_f=3, k_t=5)
student = StretchAwarePhaseDiffV2(n_fft=n_fft, hop_length=hop, **student_hparams).to(device)

opt = torch.optim.Adam(student.parameters(), lr=1e-3)
epochs = 8

n_params = sum(param.numel() for param in student.parameters())
print(f"Student model parameters: {n_params / 1e6:.2f} M")

# Candidate stretch factors; the training loop draws one per batch.
stretch_factors = [1.0, 1.5, 2.0]
print("Stretch factors:", stretch_factors)


## Teacher-student training loop 

In [None]:
# --- Training loop: student mimics baseline on stretched spectrograms -----------

# Output location for the student checkpoint; the directory is created if missing.
Path("checkpoints").mkdir(parents=True, exist_ok=True)
ckpt_path = "checkpoints/stretchaware_teacher.pth"

def vm_loss(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Return the negative mean cosine of the phase error ``a - b``.

    The value lies in [-1, 1] and reaches -1 when ``a`` and ``b`` agree modulo
    2*pi, so the loss is insensitive to phase wrapping.
    """
    phase_err = a - b
    return phase_err.cos().mean().neg()

for ep in range(1, epochs + 1):
    student.train()
    total_fpd = 0.0
    total_bpd = 0.0
    steps = 0

    for batch_idx, wav_cpu in enumerate(dl):
        wav = wav_cpu.to(device)  # (B, N)
        B, N = wav.shape

        # 1) compute magnitude of original audio (no gradients needed)
        with torch.no_grad():
            mag_orig, _, _ = compute_mag_fpd_bpd(wav, n_fft, hop, stft)  # (B,F,T)

        # 2) sample one stretch factor z shared by the whole batch
        z_val = random.choice(stretch_factors)
        z_batch = torch.full((B,), float(z_val), device=device)  # (B,)

        # 3) warp magnitude along time to simulate stretch
        mag_z = logmag_stretch_then_match_T(mag_orig, z_val)  # (B,F,T)

        # 4) frozen baseline TEACHER: FPD/BPD targets on the warped magnitude
        with torch.no_grad():
            out_teacher = baseline_model(mag=mag_z)
            # baseline_model is expected to return (fpd_pred, bpd_pred, aux)
            if not isinstance(out_teacher, (tuple, list)) or len(out_teacher) < 2:
                raise RuntimeError("Unexpected baseline_model output format")
            fpd_teacher, bpd_teacher = out_teacher[0], out_teacher[1]

        # 5) STUDENT forward on same mag_z + explicit z
        fpd_student, bpd_student = student(mag_z, z_batch)

        # 6) losses (teacher vs student)
        loss_f = vm_loss(fpd_teacher, fpd_student)
        loss_b = vm_loss(bpd_teacher, bpd_student)
        loss = loss_f + loss_b

        opt.zero_grad(set_to_none=True)
        loss.backward()
        opt.step()

        total_fpd += loss_f.item()
        total_bpd += loss_b.item()
        steps += 1

        if batch_idx % 20 == 0:
            print(f"  [ep {ep}] batch {batch_idx}, loss_f={loss_f.item():.4f}, loss_b={loss_b.item():.4f}, z={z_val}")

    if steps == 0:
        # Fail loudly instead of hitting ZeroDivisionError below (and instead of
        # saving an untrained student as if training had happened).
        raise RuntimeError(
            "DataLoader yielded no batches (empty dataset, or fewer items than "
            "batch_size with drop_last=True); check subset_paths / batch_size."
        )

    mean_fpd = total_fpd / steps
    mean_bpd = total_bpd / steps
    print(f"[Epoch {ep}/{epochs}] mean loss_f={mean_fpd:.4f}, mean loss_b={mean_bpd:.4f}")

    # Overwritten every epoch, so the file always holds the latest student.
    torch.save({
        "state_dict": student.state_dict(),
        "cfg": cfg.__dict__,
        "stretch_factors": stretch_factors,
        "epoch": ep,
        "note": "StretchAwarePhaseDiff v2 (teacher-student on warped mags)",
    }, ckpt_path)

print("Training done. Final checkpoint saved to", ckpt_path)
