In [1]:
# %% ------------------------------------ imports
import math
import numpy as np
from glob import glob
from functools import partial
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import librosa

# ==========================
# Hyperparams (same semantics)
# ==========================
# Lower bound on the signal rate; DDIMTorch.diffusion_schedule takes acos() of it as the end angle.
min_signal_rate = 0.02
# Upper bound on the signal rate; acos() of it is the start angle of the schedule.
max_signal_rate = 0.95
# Decay factor of the exponential moving average applied to the EMA network in DDIMTorch.training_step.
ema = 0.999




In [2]:
# %% ------------------------------------ losses / metrics (PyTorch)

def spectral_norm_diff(pred: torch.Tensor, real: torch.Tensor) -> torch.Tensor:
    """
    Mean relative difference between the per-sample L2 norms of two batches.

    Every sample is flattened to a vector before its norm is taken, so the
    memory layout (NCHW or NHWC) is irrelevant: the L2 norm does not depend on
    the order of the elements. No layout conversion (and no copy) is needed.
    NOTE: named 'spectral_norm_diff' to avoid clashing with
    torch.nn.utils.spectral_norm.

    Args:
        pred: predicted batch, shape (B, ...).
        real: reference batch, shape (B, ...).

    Returns:
        Scalar tensor: mean over the batch of | ||real|| - ||pred|| | / ||real||.
    """
    # The 1e-6 offset keeps the division finite for all-zero samples.
    norm_real = torch.norm(real.flatten(1), dim=1) + 1e-6
    norm_pred = torch.norm(pred.flatten(1), dim=1) + 1e-6
    return torch.mean(torch.abs(norm_real - norm_pred) / norm_real)


def time_derivative_loss(pred: torch.Tensor, real: torch.Tensor, window: int = 1) -> torch.Tensor:
    """
    MSE between the finite differences of `pred` and `real` along dim 2.

    Inputs are treated as NCHW; a 4-D input whose dim 1 is not 1, 2 or 3 is
    assumed to be NHWC and permuted to NCHW first. Differences are taken
    along H (dim 2 after alignment), which this function treats as 'time'.
    NOTE(review): _load_mdct_from_file lays spectrograms out as
    (bins, frames), which would put frames on W - confirm the intended axis.

    Args:
        pred: predicted batch, 4-D.
        real: reference batch, same shape as `pred`.
        window: distance (in rows) between the two samples of each difference.

    Returns:
        Scalar tensor with the mean squared error of the differences.

    Raises:
        ValueError: if `window` is < 1 or not smaller than the size of dim 2
            (the slices would be mismatched or empty, giving a NaN loss).
    """
    if window < 1:
        raise ValueError(f"window must be >= 1, got {window}")
    if pred.dim() == 4 and pred.shape[1] not in (1, 2, 3):  # NHWC -> NCHW
        pred = pred.permute(0, 3, 1, 2)
        real = real.permute(0, 3, 1, 2)
    if window >= pred.shape[2]:
        raise ValueError(
            f"window ({window}) must be smaller than the differenced dimension ({pred.shape[2]})"
        )
    # pred, real: (B, C, H, W)
    real_dx = real[:, :, :-window, :] - real[:, :, window:, :]
    pred_dx = pred[:, :, :-window, :] - pred[:, :, window:, :]
    return F.mse_loss(pred_dx, real_dx)




In [3]:
def _load_mdct_from_file(file: str, idx: int, rate: int = 10_000, feats: int = 256, duration: float = 3.3) -> np.ndarray:
    """
    Turn one audio segment into a fixed-size (feats, feats//2) coefficient map.

    The segment starting at `idx` seconds is loaded at `rate` Hz, zero-padded
    or cut to `rate * duration` samples, split into frames of 2*feats samples
    with a hop of `feats` (50% overlap), sine-windowed, and transformed with
    an orthonormal DCT-IV of which the first `feats` coefficients are kept.
    NOTE(review): a textbook MDCT folds the 2N-sample frame down to N samples
    before the DCT-IV; here a length-2N DCT-IV is truncated instead - confirm
    this approximation is intended.
    """
    samples, _ = _safe_audio_load(file, sr=rate, offset=idx, duration=duration)

    # Fixed-length buffer: short clips are zero-padded, long ones truncated.
    n_samples = int(rate * duration)
    padded = np.zeros(n_samples, dtype=np.float32)
    n_copy = min(len(samples), n_samples)
    padded[:n_copy] = samples[:n_copy]

    frame_len = 2 * feats
    framed = librosa.util.frame(padded, frame_length=frame_len, hop_length=feats)
    # Guard against a (n_frames, frame_len) layout; we need (frame_len, n_frames).
    if framed.shape[0] != frame_len and framed.shape[1] == frame_len:
        framed = framed.swapaxes(0, 1)

    sine_win = np.sin(np.pi / frame_len * (np.arange(frame_len) + 0.5)).astype(np.float32)
    windowed = framed * sine_win[:, None]

    # DCT-IV down the frame axis; only the first `feats` rows are kept.
    coeffs = dct(windowed, type=4, norm="ortho", axis=0)[:feats, :]

    # Copy into a zero canvas so the result always has shape (feats, feats//2).
    rows, cols = feats, feats // 2
    canvas = np.zeros((rows, cols), dtype=np.float32)
    keep_r = min(rows, coeffs.shape[0])
    keep_c = min(cols, coeffs.shape[1])
    canvas[:keep_r, :keep_c] = coeffs[:keep_r, :keep_c]
    return canvas


def _is_decodable(path: str) -> bool:
    """
    Quick pre-check to skip obviously undecodable files.
    """
    try:
        import torchaudio
        torchaudio.info(path)
        return True
    except Exception:
        pass
    try:
        import soundfile as sf
        with sf.SoundFile(path):
            return True
    except Exception:
        pass
    try:
        from scipy.io import wavfile
        wavfile.read(path)
        return True
    except Exception:
        return False


# %% =============================== dataset / dataloader

class FilesMDCTDataset(Dataset):
    """
    Map-style dataset of coefficient maps computed on the fly from audio files.

    Every decodable file matched by `glob_location` is paired with each of
    the offsets 0, hop_size, ..., (total_seconds - 1) * hop_size, and each
    (file, offset) pair is one item. Items are float32 tensors of shape
    (1, max_feats, mdct_feats // 2).
    """

    def __init__(
        self,
        glob_location: str,
        total_seconds: int = 2,
        out_len: float = 3.3,
        hop_size: int = 1,
        max_feats: int = 2048,
        batch_scale: float = 1.0,
        rate: int = 10_000,
        mdct_feats: int = 256,
    ):
        super().__init__()
        candidates = glob(glob_location, recursive=True)
        self.files = [path for path in candidates if _is_decodable(path)]
        if not self.files:
            raise RuntimeError("No decodable audio files found. Check backends or path.")
        # Offset-major order: all files at offset 0, then all files at the next offset, ...
        self.pairs = [
            (path, second * hop_size)
            for second in range(total_seconds)
            for path in self.files
        ]
        self.rate = rate
        self.out_len = out_len
        self.mdct_feats = mdct_feats
        self.max_feats = max_feats
        self.scale = batch_scale

    def __len__(self) -> int:
        """Number of (file, offset) pairs."""
        return len(self.pairs)

    def __getitem__(self, i: int) -> torch.Tensor:
        """Compute the coefficient map of pair `i` and return it as (1, H, W)."""
        path, offset = self.pairs[i]
        spec = _load_mdct_from_file(
            path, offset, rate=self.rate, feats=self.mdct_feats, duration=self.out_len
        )

        # Crop / zero-pad into a (max_feats, mdct_feats // 2) canvas.
        rows, cols = self.max_feats, self.mdct_feats // 2
        canvas = np.zeros((rows, cols), dtype=np.float32)
        keep_r = min(rows, spec.shape[0])
        keep_c = min(cols, spec.shape[1])
        canvas[:keep_r, :keep_c] = spec[:keep_r, :keep_c]

        # (H, W) -> (H, W, 1), scale, then move the channel axis first: (1, H, W).
        scaled = canvas[..., None] * self.scale
        return torch.from_numpy(np.transpose(scaled, (2, 0, 1)))


def get_files_dataloader(
    glob_location: str,
    total_seconds: int = 2,
    out_len: float = 3.3,
    hop_size: int = 1,
    max_feats: int = 256,
    batch_size: int = 16,
    shuffle_size: int = 1000,  # not used; kept for parity
    scale: float = 1.0,
    rate: int = 10_000,
    mdct_feats: int = 256,
    num_workers: int = 0,
    pin_memory: bool = False,
) -> DataLoader:
    """
    Build a shuffled DataLoader over a FilesMDCTDataset.

    The dataset arguments are forwarded unchanged (`scale` is passed as
    `batch_scale`). Incomplete final batches are dropped. `shuffle_size` is
    accepted but ignored.
    """
    dataset_kwargs = dict(
        glob_location=glob_location,
        total_seconds=total_seconds,
        out_len=out_len,
        hop_size=hop_size,
        max_feats=max_feats,
        batch_scale=scale,
        rate=rate,
        mdct_feats=mdct_feats,
    )
    return DataLoader(
        FilesMDCTDataset(**dataset_kwargs),
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,  # 0 is safer in notebooks
        pin_memory=pin_memory,
        drop_last=True,
    )



In [4]:

# %% ------------------------------------ normalization module (adaptable)

class AdaptiveNormalizer(nn.Module):
    """
    Global mean/std normalizer whose statistics are estimated from data
    (like Keras Normalization).

    Call .adapt(dataloader, max_batches=...) once before training. Until then
    the buffers hold the identity statistics (mean 0, std 1).
    """
    def __init__(self, eps: float = 1e-6):
        super().__init__()
        # Buffers: saved in state_dict and moved by .to(device), but not trained.
        self.register_buffer("mean", torch.zeros(1, 1, 1))
        self.register_buffer("std", torch.ones(1, 1, 1))
        self.eps = eps

    @torch.no_grad()
    def adapt(self, dataloader, max_batches: int = 512, device="cpu", show_progress: bool = True):
        """
        Estimate a single scalar mean/std over all elements of up to
        `max_batches` batches and store them in the buffers on `device`.

        Args:
            dataloader: iterable yielding tensors (any shape).
            max_batches: upper bound on the number of batches consumed.
            device: device the batches are moved to and the buffers end up on.
            show_progress: wrap the loop in a tqdm progress bar.

        If the dataloader yields no batch, a warning is issued and the
        current statistics are left untouched.
        """
        steps = range(max_batches)
        if show_progress:
            # Imported lazily so the progress bar is an optional dependency here.
            from tqdm.auto import tqdm
            steps = tqdm(steps, desc="adapting normalizer", dynamic_ncols=True, leave=False)

        iterator = iter(dataloader)
        seen = 0
        mean = 0.0
        M2 = 0.0
        for _ in steps:
            try:
                x = next(iterator)
            except StopIteration:
                break
            x = x.to(device)  # (B, 1, H, W)
            seen += x.numel()
            # Batched Welford update: M2 accumulates sum((x - old_mean) * (x - new_mean)).
            delta = x - mean
            mean = mean + delta.sum() / seen
            M2 = M2 + (delta * (x - mean)).sum()

        if seen == 0:
            # Without data `mean`/`M2` are still Python floats; keep the old statistics.
            import warnings
            warnings.warn("AdaptiveNormalizer.adapt received no data; statistics left unchanged.")
            return

        var = M2 / max(1, seen - 1)  # unbiased estimate
        std = torch.sqrt(var + self.eps)
        self.mean = torch.tensor([[[mean.item()]]], device=device)
        self.std = torch.tensor([[[std.item()]]], device=device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize x (B, C, H, W) with the stored scalar statistics."""
        return (x - self.mean) / (self.std + self.eps)

    def denormalize(self, x: torch.Tensor) -> torch.Tensor:
        """Exact inverse of forward()."""
        return x * (self.std + self.eps) + self.mean




In [5]:
# %% ------------------------------------ UNet-like noise predictor (with FiLM time conditioning)

class FiLMBlock(nn.Module):
    """
    Two 3x3 conv + GroupNorm + SiLU stages with additive time conditioning
    and optional single-head self-attention over all spatial positions.

    The time embedding is projected to two per-channel offsets: `gamma` is
    added before the first SiLU and `beta` after it (both are additive).
    """

    def __init__(self, in_ch, out_ch, t_dim, use_attn=False):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)
        self.gn1 = nn.GroupNorm(num_groups=8, num_channels=out_ch)
        self.to_gamma = nn.Linear(t_dim, out_ch)
        self.to_beta = nn.Linear(t_dim, out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1)
        self.gn2 = nn.GroupNorm(num_groups=8, num_channels=out_ch)
        self.use_attn = use_attn
        if self.use_attn:
            # 1x1 projections for queries, keys and values.
            for proj_name in ("q", "k", "v"):
                setattr(self, proj_name, nn.Conv2d(out_ch, out_ch, kernel_size=1))

    def _self_attention(self, feat):
        """Dot-product attention over the H*W positions of `feat` (B, C, H, W)."""
        batch, chans, height, width = feat.shape
        queries = self.q(feat).flatten(2).transpose(1, 2)  # (B, HW, C)
        keys = self.k(feat).flatten(2)                     # (B, C, HW)
        values = self.v(feat).flatten(2).transpose(1, 2)   # (B, HW, C)
        weights = torch.softmax(queries @ keys / math.sqrt(chans), dim=-1)  # (B, HW, HW)
        mixed = weights @ values                           # (B, HW, C)
        return mixed.transpose(1, 2).reshape(batch, chans, height, width)

    def forward(self, x, t_emb):
        # Per-channel offsets broadcast over the spatial dims: (B, C, 1, 1).
        pre_shift = self.to_gamma(t_emb).unsqueeze(-1).unsqueeze(-1)
        post_shift = self.to_beta(t_emb).unsqueeze(-1).unsqueeze(-1)

        h = F.silu(self.gn1(self.conv1(x)) + pre_shift) + post_shift
        h = F.silu(self.gn2(self.conv2(h)))

        if not self.use_attn:
            return h
        # Residual connection around the attention output.
        return h + self._self_attention(h)


class TimeEmbed(nn.Module):
    """Two-layer SiLU MLP that lifts a scalar conditioning value to `dim` features."""

    def __init__(self, dim=128):
        super().__init__()
        self.fc1 = nn.Linear(1, dim)
        self.fc2 = nn.Linear(dim, dim)

    def forward(self, t):
        """t: (B, 1, 1, 1) or (B, 1); returns (B, dim)."""
        # A 4-D input is collapsed to (B, 1) so it can feed the first Linear.
        scalar_t = t.view(t.size(0), 1) if t.dim() == 4 else t
        hidden = F.silu(self.fc1(scalar_t))
        return F.silu(self.fc2(hidden))


class UNetNoisePredictor(nn.Module):
    """
    UNet-style noise predictor built from time-conditioned FiLMBlocks.

    Encoder: one stack of `block_depth` blocks per entry of `widths`, with a
    stride-2 convolution between consecutive levels. Decoder: nearest-neighbour
    upsampling + 3x3 conv, concatenation with the matching encoder output,
    then another stack of `block_depth` blocks. Input and output both have a
    single channel.

    Args:
        widths: channel count of each resolution level (coarsest last).
        block_depth: number of FiLMBlocks per level (>= 1).
        attention: enable self-attention in the mid block and in the levels
            with index >= len(widths) // 2.
        dim1, dim2: stored on the instance; not used by the network itself.
        t_dim: size of the time embedding.
    """

    def __init__(self, widths, block_depth, attention=False, dim1=256, dim2=128, t_dim=128):
        super().__init__()
        if block_depth < 1:
            raise ValueError(f"block_depth must be >= 1, got {block_depth}")
        self.dim1, self.dim2 = dim1, dim2
        self.t_embed = TimeEmbed(t_dim)

        chs = list(widths)
        attn_from = len(chs) // 2
        self.skips_out = []  # unused by this class; kept for backward compatibility

        # Down
        in_ch = 1
        downs = nn.ModuleList()
        self.down_blocks = nn.ModuleList()
        for i, ch in enumerate(chs):
            use_attn = attention and (i >= attn_from)
            blocks = nn.ModuleList()
            for _ in range(block_depth):
                blocks.append(FiLMBlock(in_ch, ch, t_dim, use_attn=use_attn))
                in_ch = ch
            self.down_blocks.append(blocks)
            if i < len(chs) - 1:
                downs.append(nn.Conv2d(ch, ch, 3, stride=2, padding=1))
        self.downs = downs

        # Mid
        self.mid = FiLMBlock(chs[-1], chs[-1], t_dim, use_attn=attention)

        # Up
        self.up_convs = nn.ModuleList()
        self.up_blocks = nn.ModuleList()
        for i, ch in reversed(list(enumerate(chs[:-1]))):
            self.up_convs.append(nn.Sequential(
                nn.Upsample(scale_factor=2, mode="nearest"),
                nn.Conv2d(chs[i+1], ch, 3, padding=1),
            ))
            use_attn = attention and (i >= attn_from)
            blocks = nn.ModuleList()
            # Only the first block of a level sees the concatenated skip (2*ch
            # channels); the following ones receive the ch-channel output of
            # their predecessor.
            block_in = ch * 2
            for _ in range(block_depth):
                blocks.append(FiLMBlock(block_in, ch, t_dim, use_attn=use_attn))
                block_in = ch
            self.up_blocks.append(blocks)

        self.out = nn.Conv2d(chs[0], 1, 1)

    def forward(self, x, t_in):
        """
        x: (B,1,H,W); t_in: (B,1,1,1) or (B,1) with values in [0,1] (we pass noise_rate**2)
        Returns the predicted noise, shape (B,1,H,W).
        """
        t_emb = self.t_embed(t_in)

        # Down: remember the output of every level for the skip connections.
        skips = []
        h = x
        for i, blocks in enumerate(self.down_blocks):
            for blk in blocks:
                h = blk(h, t_emb)
            skips.append(h)
            if i < len(self.downs):
                h = self.downs[i](h)

        # Mid
        h = self.mid(h, t_emb)

        # Up: the coarsest level has no decoder stage, hence skips[:-1].
        for up, blocks, skip in zip(self.up_convs, self.up_blocks, reversed(skips[:-1])):
            h = up(h)
            if h.shape[-2:] != skip.shape[-2:]:
                # Odd sizes shrink to ceil(n/2) on the way down, so x2 upsampling
                # can overshoot by one; resize to the skip before concatenating.
                h = F.interpolate(h, size=skip.shape[-2:], mode="nearest")
            h = torch.cat([h, skip], dim=1)
            for blk in blocks:
                h = blk(h, t_emb)

        return self.out(h)




In [6]:
# %% ------------------------------------ DDIM model (PyTorch)

class DDIMTorch(nn.Module):
    """
    DDIM-style diffusion model: a trainable noise-prediction network, an
    exponential-moving-average (EMA) copy used for evaluation and sampling,
    and a data normalizer.

    Args:
        widths, block_depth, attention, dim1, dim2: forwarded to
            UNetNoisePredictor for both networks.
        device: requested device; falls back to "cpu" when CUDA is unavailable.

    Attributes:
        spec_mod, dx_mod: weights of the auxiliary spectral-norm and
            finite-difference terms in the training loss (0.0 disables them).
    """

    def __init__(self, widths, block_depth, attention=False, dim1=256, dim2=128, device="cuda"):
        super().__init__()
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")

        # Keep the normalizer buffers on the same device as the networks so the
        # model also works before / without calling normalizer.adapt(device=...).
        self.normalizer = AdaptiveNormalizer().to(self.device)
        self.network = UNetNoisePredictor(widths, block_depth, attention=attention, dim1=dim1, dim2=dim2).to(self.device)
        self.ema_network = UNetNoisePredictor(widths, block_depth, attention=attention, dim1=dim1, dim2=dim2).to(self.device)
        self.ema_network.load_state_dict(self.network.state_dict())
        # The EMA copy is only ever updated by averaging, never by gradients.
        self.ema_network.requires_grad_(False)
        self.spec_mod = 0.0
        self.dx_mod = 0.0

        self.mse = nn.MSELoss()

    @torch.no_grad()
    def diffusion_schedule(self, diffusion_times):
        """
        Map diffusion times in [0, 1], shape (B,1,1,1), to (noise_rates, signal_rates).

        The angle is interpolated linearly between acos(max_signal_rate) and
        acos(min_signal_rate); signal = cos(angle) and noise = sin(angle), so
        signal**2 + noise**2 == 1.
        """
        start_angle = math.acos(max_signal_rate)
        end_angle = math.acos(min_signal_rate)
        diffusion_angles = start_angle + diffusion_times * (end_angle - start_angle)
        signal_rates = torch.cos(diffusion_angles)
        noise_rates = torch.sin(diffusion_angles)
        return noise_rates, signal_rates

    def denoise(self, noisy_data, noise_rates, signal_rates, training=True):
        """
        Predict the noise in `noisy_data` and derive the implied clean data.

        `training=True` uses the trainable network, otherwise the EMA copy.
        Returns (pred_noises, pred_data).
        """
        net = self.network if training else self.ema_network
        # The network is conditioned on the noise variance (noise_rates**2).
        cond = (noise_rates ** 2).to(noisy_data.dtype)
        pred_noises = net(noisy_data, cond)
        # Invert noisy = signal * data + noise * eps; 1e-6 guards the division.
        pred_data = (noisy_data - noise_rates * pred_noises) / (signal_rates + 1e-6)
        return pred_noises, pred_data

    @torch.no_grad()
    def reverse_diffusion(self, initial_noise, diffusion_steps: int):
        """
        Run `diffusion_steps` deterministic denoising steps from t=1 down to
        t=0 with the EMA network and return the last clean-data prediction.

        Raises:
            ValueError: if `diffusion_steps` is < 1.
        """
        if diffusion_steps < 1:
            raise ValueError(f"diffusion_steps must be >= 1, got {diffusion_steps}")
        B = initial_noise.size(0)
        step_size = 1.0 / diffusion_steps
        next_noisy = initial_noise
        for step in tqdm(range(diffusion_steps)):
            noisy = next_noisy
            t = torch.ones((B, 1, 1, 1), device=self.device) - step * step_size
            noise_rates, signal_rates = self.diffusion_schedule(t)
            pred_noises, pred_data = self.denoise(noisy, noise_rates, signal_rates, training=False)
            # Re-mix the predictions at the next (lower) noise level.
            t_next = t - step_size
            next_noise_rates, next_signal_rates = self.diffusion_schedule(t_next)
            next_noisy = next_signal_rates * pred_data + next_noise_rates * pred_noises
        return pred_data

    @torch.no_grad()
    def generate(self, num_examples, shape, diffusion_steps):
        """
        Sample `num_examples` items of per-item `shape` (C, H, W), undo the
        normalization and clamp the result to [-128, 128].
        """
        initial_noise = torch.randn((num_examples, *shape), device=self.device)
        generated = self.reverse_diffusion(initial_noise, diffusion_steps)
        denorm = self.normalizer.denormalize(generated)
        return torch.clamp(denorm, -128.0, 128.0)

    def _get_losses(self, y_true, y_pred):
        """Return (mse, spectral-norm difference, finite-difference loss)."""
        l = self.mse(y_pred, y_true)
        s = spectral_norm_diff(y_pred, y_true)
        d = time_derivative_loss(y_pred, y_true)
        return l, s, d

    def training_step(self, batch, optimizer):
        """
        One optimization step on the noise-prediction loss followed by an EMA
        update. Returns a dict of Python floats with noise ("n_*") and data
        ("d_*") metrics.
        """
        batch = batch.to(self.device)  # (B,1,H,W)

        data = self.normalizer(batch)
        noises = torch.randn_like(data)

        B = data.size(0)
        diffusion_times = torch.rand((B, 1, 1, 1), device=self.device)
        noise_rates, signal_rates = self.diffusion_schedule(diffusion_times)
        noisy = signal_rates * data + noise_rates * noises

        pred_noises, pred_data = self.denoise(noisy, noise_rates, signal_rates, training=True)
        noise_loss, noise_spec, noise_dx = self._get_losses(noises, pred_noises)
        # Data-space losses are reported only, so keep them out of the autograd graph.
        with torch.no_grad():
            data_loss, data_spec, data_dx = self._get_losses(data, pred_data)

        total_noise_loss = noise_loss + self.spec_mod * noise_spec + self.dx_mod * noise_dx

        optimizer.zero_grad(set_to_none=True)
        total_noise_loss.backward()
        optimizer.step()

        # EMA: ema_param <- ema * ema_param + (1 - ema) * param
        with torch.no_grad():
            for p, q in zip(self.network.parameters(), self.ema_network.parameters()):
                q.mul_(ema).add_(p, alpha=1.0 - ema)

        metrics = {
            "n_loss": noise_loss.detach().item(),
            "d_loss": data_loss.detach().item(),
            "n_spec": noise_spec.detach().item(),
            "d_spec": data_spec.detach().item(),
            "n_dx": noise_dx.detach().item(),
            "d_dx": data_dx.detach().item(),
            "n_total": total_noise_loss.detach().item(),
            "d_total": (data_loss + self.spec_mod * data_spec + self.dx_mod * data_dx).detach().item(),
        }
        return metrics

    @torch.no_grad()
    def eval_step(self, batch):
        """Noise/data MSE of the EMA network on one batch at random diffusion times."""
        batch = batch.to(self.device)
        data = self.normalizer(batch)
        noises = torch.randn_like(data)

        B = data.size(0)
        diffusion_times = torch.rand((B, 1, 1, 1), device=self.device)
        noise_rates, signal_rates = self.diffusion_schedule(diffusion_times)
        noisy = signal_rates * data + noise_rates * noises

        pred_noises, pred_data = self.denoise(noisy, noise_rates, signal_rates, training=False)
        noise_loss = self.mse(noises, pred_noises)
        data_loss = self.mse(data, pred_data)
        return {"n_loss": noise_loss.item(), "d_loss": data_loss.item()}




In [7]:
# %% ------------------------------------ build dataloader (parity with your TF dataset)

def get_files_dataloader(
    glob_location,
    total_seconds=2,
    out_len=3.3,
    hop_size=1,
    max_feats=256,
    batch_size=16,
    shuffle_size=1000,  # not used exactly the same as TF shuffle; DataLoader shuffle=True is typical
    scale=1.0,
    rate=10_000,
    mdct_feats=256,
    num_workers=4,
    pin_memory=True,
):
    """
    Build a shuffled DataLoader over a FilesMDCTDataset.

    The dataset arguments are forwarded unchanged (`scale` is passed as
    `batch_scale`); incomplete final batches are dropped and `shuffle_size`
    is ignored. NOTE: this redefines the function of the same name from an
    earlier cell, with different defaults for num_workers and pin_memory.
    """
    dataset_kwargs = dict(
        glob_location=glob_location,
        total_seconds=total_seconds,
        out_len=out_len,
        hop_size=hop_size,
        max_feats=max_feats,
        batch_scale=scale,
        rate=rate,
        mdct_feats=mdct_feats,
    )
    loader_kwargs = dict(
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=pin_memory,
        drop_last=True,
    )
    return DataLoader(FilesMDCTDataset(**dataset_kwargs), **loader_kwargs)




In [8]:
# %% ------------------------------------ example usage
from tqdm.auto import tqdm

if __name__ == "__main__":
    # dataset (mirrors your call)
    loader = get_files_dataloader(
        "data/REAL_audio/*.wav",
        out_len=3.3,
        max_feats=256,
        total_seconds=26,
        scale=1.0,
        batch_size=16,
        mdct_feats=256,
        rate=10_000,
        num_workers=0,
        pin_memory=False,
    )

    # peek a batch to get shape; fail early with a clear message instead of a
    # NameError on `shape` when the loader yields nothing (drop_last=True
    # discards a final batch smaller than batch_size)
    first_batch = next(iter(loader), None)
    if first_batch is None:
        raise RuntimeError("DataLoader produced no batches; need at least batch_size items.")
    shape = first_batch.shape  # torch.Size([B, 1, H, W])
    print("batch shape:", shape)

    # build model
    model = DDIMTorch(
        widths=[128, 128, 128, 128],
        block_depth=2,
        attention=True,
        dim1=shape[2],
        dim2=shape[3],
        device="cuda",
    )

    # adapt normalizer (like Keras Normalization.adapt)
    model.normalizer.adapt(loader, max_batches=256, device=model.device, show_progress=True)

    # optimizer: only the trainable network; the EMA copy is updated by averaging
    opt = torch.optim.AdamW(model.network.parameters(), lr=2e-4)

    # optional: enable your auxiliary loss terms
    model.spec_mod = 0.0
    model.dx_mod = 0.0

    # tiny training loop sketch
    model.train()
    for epoch in range(1):
        pbar = tqdm(loader, desc=f"epoch {epoch}", dynamic_ncols=True, leave=True)
        for batch in pbar:
            metrics = model.training_step(batch, optimizer=opt)
            pbar.set_postfix({
                "n_total": f"{metrics['n_total']:.4f}",
                "d_loss": f"{metrics['d_loss']:.4f}"
            })

    # sampling example
    with torch.no_grad():
        gen = model.generate(num_examples=4, shape=(1, shape[2], shape[3]), diffusion_steps=50)
        print("generated:", gen.shape)  # (4, 1, H, W)


  from .autonotebook import tqdm as notebook_tqdm


NameError: name '_safe_audio_load' is not defined

In [10]:
# %% =============================== imports
import math
import numpy as np
from glob import glob
from functools import partial
from typing import Tuple

from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import librosa

try:
    from scipy.fft import dct           # modern SciPy
except ImportError:
    from scipy.fftpack import dct       # fallback


# ==========================
# Hyperparams (same semantics as your TF version)
# ==========================
# Lower bound on the signal rate; DDIMTorch.diffusion_schedule takes acos() of it as the end angle.
min_signal_rate = 0.02
# Upper bound on the signal rate; acos() of it is the start angle of the schedule.
max_signal_rate = 0.95
# Decay factor of the exponential moving average applied to the EMA network in DDIMTorch.training_step.
ema = 0.999


# %% =============================== losses / metrics (PyTorch)

def spectral_norm_diff(pred: torch.Tensor, real: torch.Tensor) -> torch.Tensor:
    """
    Mean relative difference between the per-sample L2 norms of two batches.

    Every sample is flattened to a vector before its norm is taken, so the
    memory layout (NCHW or NHWC) is irrelevant: the L2 norm does not depend on
    the order of the elements. No layout conversion (and no copy) is needed.

    Args:
        pred: predicted batch, shape (B, ...).
        real: reference batch, shape (B, ...).

    Returns:
        Scalar tensor: mean over the batch of | ||real|| - ||pred|| | / ||real||.
    """
    # The 1e-6 offset keeps the division finite for all-zero samples.
    norm_real = torch.norm(real.flatten(1), dim=1) + 1e-6
    norm_pred = torch.norm(pred.flatten(1), dim=1) + 1e-6
    return torch.mean(torch.abs(norm_real - norm_pred) / norm_real)


def time_derivative_loss(pred: torch.Tensor, real: torch.Tensor, window: int = 1) -> torch.Tensor:
    """
    MSE between the finite differences of `pred` and `real` along dim 2.

    Inputs are treated as NCHW; a 4-D input whose dim 1 is not 1, 2 or 3 is
    assumed to be NHWC and permuted to NCHW first. Differences are taken
    along H (dim 2 after alignment), which this function treats as 'time'.
    NOTE(review): _load_mdct_from_file lays spectrograms out as
    (bins, frames), which would put frames on W - confirm the intended axis.

    Args:
        pred: predicted batch, 4-D.
        real: reference batch, same shape as `pred`.
        window: distance (in rows) between the two samples of each difference.

    Returns:
        Scalar tensor with the mean squared error of the differences.

    Raises:
        ValueError: if `window` is < 1 or not smaller than the size of dim 2
            (the slices would be mismatched or empty, giving a NaN loss).
    """
    if window < 1:
        raise ValueError(f"window must be >= 1, got {window}")
    if pred.dim() == 4 and pred.shape[1] not in (1, 2, 3):  # NHWC -> NCHW
        pred = pred.permute(0, 3, 1, 2)
        real = real.permute(0, 3, 1, 2)
    if window >= pred.shape[2]:
        raise ValueError(
            f"window ({window}) must be smaller than the differenced dimension ({pred.shape[2]})"
        )
    real_dx = real[:, :, :-window, :] - real[:, :, window:, :]
    pred_dx = pred[:, :, :-window, :] - pred[:, :, window:, :]
    return F.mse_loss(pred_dx, real_dx)


# %% =============================== robust audio load + MDCT via DCT-IV

def _safe_audio_load(path: str, sr: int, offset: float, duration: float) -> Tuple[np.ndarray, int]:
    """
    Try torchaudio -> soundfile -> scipy to load audio robustly.
    Returns mono float32 at target sr.
    """
    # 1) torchaudio
    try:
        import torchaudio
        wav, native_sr = torchaudio.load(path)  # shape (C, T)
        start = int(offset * native_sr)
        end = start + int(duration * native_sr)
        wav = wav[:, start:end]
        if wav.shape[0] > 1:
            wav = wav.mean(0, keepdim=False)    # mono
        else:
            wav = wav.squeeze(0)
        if native_sr != sr:
            resamp = torchaudio.transforms.Resample(native_sr, sr)
            wav = resamp(wav.unsqueeze(0)).squeeze(0)
        return wav.numpy().astype(np.float32), sr
    except Exception:
        pass

    # 2) soundfile (libsndfile)
    try:
        import soundfile as sf
        y, native_sr = sf.read(path, dtype="float32", always_2d=False)
        if y.ndim == 2:
            y = y.mean(axis=1)
        start = int(offset * native_sr)
        end = start + int(duration * native_sr)
        y = y[start:end]
        if native_sr != sr:
            y = librosa.resample(y, orig_sr=native_sr, target_sr=sr)
        return y.astype(np.float32), sr
    except Exception:
        pass

    # 3) scipy (PCM-only)
    try:
        from scipy.io import wavfile
        native_sr, y = wavfile.read(path)
        if y.ndim == 2:
            y = y.mean(axis=1)
        y = y.astype(np.float32)
        # normalize if integer PCM range
        if y.max() > 1.0 or y.min() < -1.0:
            maxv = float(np.max(np.abs(y)) or 1.0)
            y = y / maxv
        start = int(offset * native_sr)
        end = start + int(duration * native_sr)
        y = y[start:end]
        if native_sr != sr:
            y = librosa.resample(y, orig_sr=native_sr, target_sr=sr)
        return y.astype(np.float32), sr
    except Exception as e:
        raise RuntimeError(f"Failed to decode {path}: {e}")


def _load_mdct_from_file(file: str, idx: int, rate: int = 10_000, feats: int = 256, duration: float = 3.3) -> np.ndarray:
    """
    Turn one audio segment into a fixed-size (feats, feats//2) coefficient map.

    The segment starting at `idx` seconds is loaded at `rate` Hz, zero-padded
    or cut to `rate * duration` samples, split into frames of 2*feats samples
    with a hop of `feats` (50% overlap), sine-windowed, and transformed with
    an orthonormal DCT-IV of which the first `feats` coefficients are kept.
    NOTE(review): a textbook MDCT folds the 2N-sample frame down to N samples
    before the DCT-IV; here a length-2N DCT-IV is truncated instead - confirm
    this approximation is intended.
    """
    samples, _ = _safe_audio_load(file, sr=rate, offset=idx, duration=duration)

    # Fixed-length buffer: short clips are zero-padded, long ones truncated.
    n_samples = int(rate * duration)
    padded = np.zeros(n_samples, dtype=np.float32)
    n_copy = min(len(samples), n_samples)
    padded[:n_copy] = samples[:n_copy]

    frame_len = 2 * feats
    framed = librosa.util.frame(padded, frame_length=frame_len, hop_length=feats)
    # Guard against a (n_frames, frame_len) layout; we need (frame_len, n_frames).
    if framed.shape[0] != frame_len and framed.shape[1] == frame_len:
        framed = framed.swapaxes(0, 1)

    sine_win = np.sin(np.pi / frame_len * (np.arange(frame_len) + 0.5)).astype(np.float32)
    windowed = framed * sine_win[:, None]

    # DCT-IV down the frame axis; only the first `feats` rows are kept.
    coeffs = dct(windowed, type=4, norm="ortho", axis=0)[:feats, :]

    # Copy into a zero canvas so the result always has shape (feats, feats//2).
    rows, cols = feats, feats // 2
    canvas = np.zeros((rows, cols), dtype=np.float32)
    keep_r = min(rows, coeffs.shape[0])
    keep_c = min(cols, coeffs.shape[1])
    canvas[:keep_r, :keep_c] = coeffs[:keep_r, :keep_c]
    return canvas


def _is_decodable(path: str) -> bool:
    """
    Quick pre-check to skip obviously undecodable files.
    """
    try:
        import torchaudio
        torchaudio.info(path)
        return True
    except Exception:
        pass
    try:
        import soundfile as sf
        with sf.SoundFile(path):
            return True
    except Exception:
        pass
    try:
        from scipy.io import wavfile
        wavfile.read(path)
        return True
    except Exception:
        return False


# %% =============================== dataset / dataloader

class FilesMDCTDataset(Dataset):
    """
    Map-style dataset of coefficient maps computed on the fly from audio files.

    Every decodable file matched by `glob_location` is paired with each of
    the offsets 0, hop_size, ..., (total_seconds - 1) * hop_size, and each
    (file, offset) pair is one item. Items are float32 tensors of shape
    (1, max_feats, mdct_feats // 2).
    """

    def __init__(
        self,
        glob_location: str,
        total_seconds: int = 2,
        out_len: float = 3.3,
        hop_size: int = 1,
        max_feats: int = 2048,
        batch_scale: float = 1.0,
        rate: int = 10_000,
        mdct_feats: int = 256,
    ):
        super().__init__()
        candidates = glob(glob_location, recursive=True)
        self.files = [path for path in candidates if _is_decodable(path)]
        if not self.files:
            raise RuntimeError("No decodable audio files found. Check backends or path.")
        # Offset-major order: all files at offset 0, then all files at the next offset, ...
        self.pairs = [
            (path, second * hop_size)
            for second in range(total_seconds)
            for path in self.files
        ]
        self.rate = rate
        self.out_len = out_len
        self.mdct_feats = mdct_feats
        self.max_feats = max_feats
        self.scale = batch_scale

    def __len__(self) -> int:
        """Number of (file, offset) pairs."""
        return len(self.pairs)

    def __getitem__(self, i: int) -> torch.Tensor:
        """Compute the coefficient map of pair `i` and return it as (1, H, W)."""
        path, offset = self.pairs[i]
        spec = _load_mdct_from_file(
            path, offset, rate=self.rate, feats=self.mdct_feats, duration=self.out_len
        )

        # Crop / zero-pad into a (max_feats, mdct_feats // 2) canvas.
        rows, cols = self.max_feats, self.mdct_feats // 2
        canvas = np.zeros((rows, cols), dtype=np.float32)
        keep_r = min(rows, spec.shape[0])
        keep_c = min(cols, spec.shape[1])
        canvas[:keep_r, :keep_c] = spec[:keep_r, :keep_c]

        # (H, W) -> (H, W, 1), scale, then move the channel axis first: (1, H, W).
        scaled = canvas[..., None] * self.scale
        return torch.from_numpy(np.transpose(scaled, (2, 0, 1)))


def get_files_dataloader(
    glob_location: str,
    total_seconds: int = 2,
    out_len: float = 3.3,
    hop_size: int = 1,
    max_feats: int = 256,
    batch_size: int = 16,
    shuffle_size: int = 1000,  # not used; kept for parity
    scale: float = 1.0,
    rate: int = 10_000,
    mdct_feats: int = 256,
    num_workers: int = 0,
    pin_memory: bool = False,
) -> DataLoader:
    """
    Build a shuffled DataLoader over a FilesMDCTDataset.

    The dataset arguments are forwarded unchanged (`scale` is passed as
    `batch_scale`). Incomplete final batches are dropped. `shuffle_size` is
    accepted but ignored.
    """
    dataset_kwargs = dict(
        glob_location=glob_location,
        total_seconds=total_seconds,
        out_len=out_len,
        hop_size=hop_size,
        max_feats=max_feats,
        batch_scale=scale,
        rate=rate,
        mdct_feats=mdct_feats,
    )
    return DataLoader(
        FilesMDCTDataset(**dataset_kwargs),
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,  # 0 is safer in notebooks
        pin_memory=pin_memory,
        drop_last=True,
    )


# %% =============================== normalization (adapt with tqdm)

class AdaptiveNormalizer(nn.Module):
    """
    Global mean/std normalizer whose statistics are estimated from data
    (like Keras Normalization).

    Call .adapt(dataloader, max_batches=...) once before training. Until then
    the buffers hold the identity statistics (mean 0, std 1).
    """
    def __init__(self, eps: float = 1e-6):
        super().__init__()
        # Buffers: saved in state_dict and moved by .to(device), but not trained.
        self.register_buffer("mean", torch.zeros(1, 1, 1))
        self.register_buffer("std", torch.ones(1, 1, 1))
        self.eps = eps

    @torch.no_grad()
    def adapt(self, dataloader, max_batches: int = 512, device: str = "cpu", show_progress: bool = True):
        """
        Estimate a single scalar mean/std over all elements of up to
        `max_batches` batches and store them in the buffers on `device`.

        Args:
            dataloader: iterable yielding tensors (any shape).
            max_batches: upper bound on the number of batches consumed.
            device: device the batches are moved to and the buffers end up on.
            show_progress: wrap the loop in a tqdm progress bar.

        If the dataloader yields no batch, a warning is issued and the
        current statistics are left untouched (instead of silently setting
        std to sqrt(eps)).
        """
        iterator = iter(dataloader)
        rng = range(max_batches)
        if show_progress:
            rng = tqdm(rng, desc="adapting normalizer", dynamic_ncols=True, leave=False)

        # Streaming combination of per-batch statistics, accumulated in Python floats.
        n = 0
        mean = 0.0
        M2 = 0.0

        for _ in rng:
            try:
                x = next(iterator)  # (B, 1, H, W)
            except StopIteration:
                break
            x = x.to(device).float()
            # Batch statistics over all elements, to obtain one global mean/std.
            batch_mean = x.mean().item()
            batch_var = x.var(unbiased=False).item()
            batch_count = x.numel()

            # Merge (n, mean, M2) with the batch: the last term accounts for
            # the shift between the running mean and the batch mean.
            n_total = n + batch_count
            delta = batch_mean - mean
            mean = mean + delta * (batch_count / max(1, n_total))
            M2 = M2 + batch_var * batch_count + (delta ** 2) * n * batch_count / max(1, n_total)
            n = n_total

        if n == 0:
            import warnings
            warnings.warn("AdaptiveNormalizer.adapt received no data; statistics left unchanged.")
            return

        var = M2 / max(1, n - 1)  # unbiased estimate
        std = math.sqrt(max(var, 0.0) + self.eps)
        self.mean = torch.tensor([[[mean]]], device=device)
        self.std = torch.tensor([[[std]]], device=device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize x (B, C, H, W) with the stored scalar statistics."""
        return (x - self.mean) / (self.std + self.eps)

    def denormalize(self, x: torch.Tensor) -> torch.Tensor:
        """Exact inverse of forward()."""
        return x * (self.std + self.eps) + self.mean


# %% =============================== UNet-like noise predictor (FiLM time conditioning)

class FiLMBlock(nn.Module):
    """
    Two 3x3 conv + GroupNorm stages conditioned on a time embedding, with an
    optional single-head spatial self-attention residual.

    Args:
        in_ch: input channels.
        out_ch: output channels (GroupNorm uses 8 groups, so this must be a
            multiple of 8).
        t_dim: size of the time embedding fed to ``forward``.
        use_attn: add self-attention over all H*W positions after the convs.

    NOTE(review): ``gamma`` is *added* before the activation and ``beta`` after
    it; canonical FiLM scales by gamma instead. Kept as-is to preserve the
    model's behaviour — confirm this is intended.
    """
    def __init__(self, in_ch: int, out_ch: int, t_dim: int, use_attn: bool = False):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.gn1 = nn.GroupNorm(8, out_ch)
        self.to_gamma = nn.Linear(t_dim, out_ch)
        self.to_beta = nn.Linear(t_dim, out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.gn2 = nn.GroupNorm(8, out_ch)
        self.use_attn = use_attn
        if use_attn:
            # 1x1 projections for query / key / value
            self.q = nn.Conv2d(out_ch, out_ch, 1)
            self.k = nn.Conv2d(out_ch, out_ch, 1)
            self.v = nn.Conv2d(out_ch, out_ch, 1)

    def forward(self, x: torch.Tensor, t_emb: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: feature map of shape (B, in_ch, H, W).
            t_emb: time embedding of shape (B, t_dim).

        Returns:
            Tensor of shape (B, out_ch, H, W).
        """
        h = self.conv1(x)
        h = self.gn1(h)
        gamma = self.to_gamma(t_emb).unsqueeze(-1).unsqueeze(-1)  # (B, C, 1, 1)
        beta = self.to_beta(t_emb).unsqueeze(-1).unsqueeze(-1)
        h = h + gamma
        h = F.silu(h)
        h = h + beta

        h = self.conv2(h)
        h = self.gn2(h)
        h = F.silu(h)

        if self.use_attn:
            B, C, H, W = h.shape
            # Tokens are spatial positions; add a singleton head dim: (B, 1, HW, C)
            q = self.q(h).flatten(2).transpose(1, 2).unsqueeze(1)
            k = self.k(h).flatten(2).transpose(1, 2).unsqueeze(1)
            v = self.v(h).flatten(2).transpose(1, 2).unsqueeze(1)
            # softmax(q k^T / sqrt(C)) v — the default scale is 1/sqrt(last dim) = 1/sqrt(C).
            # Letting PyTorch dispatch the attention allows fused kernels that
            # need not build the full (B, HW, HW) matrix explicitly.
            h_attn = F.scaled_dot_product_attention(q, k, v)  # (B, 1, HW, C)
            h_attn = h_attn.squeeze(1).transpose(1, 2).reshape(B, C, H, W)
            h = h + h_attn
        return h


class TimeEmbed(nn.Module):
    """Two-layer SiLU MLP that lifts a scalar conditioning value to ``dim`` features."""

    def __init__(self, dim: int = 128):
        super().__init__()
        self.fc1 = nn.Linear(1, dim)
        self.fc2 = nn.Linear(dim, dim)

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        """
        Args:
            t: conditioning scalar per sample, shaped (B, 1, 1, 1) or (B, 1).

        Returns:
            Embedding of shape (B, dim).
        """
        # Collapse the broadcast-style (B, 1, 1, 1) layout to (B, 1) for the Linear layer.
        scalar = t.view(t.size(0), 1) if t.dim() == 4 else t
        hidden = F.silu(self.fc1(scalar))
        return F.silu(self.fc2(hidden))


class UNetNoisePredictor(nn.Module):
    """
    UNet-style noise predictor built from time-conditioned ``FiLMBlock``s.

    Args:
        widths: channel count per resolution level (first entry is the
            full-resolution level). ``len(widths) - 1`` stride-2 downsamples
            are applied.
        block_depth: number of ``FiLMBlock``s per level on both paths.
        attention: enable self-attention in the levels with index
            ``>= len(widths) // 2`` and in the middle block.
        dim1, dim2: stored on the module only; not used by the layers here.
        t_dim: width of the time embedding.
    """
    def __init__(self, widths, block_depth, attention=False, dim1=256, dim2=128, t_dim=128):
        super().__init__()
        self.dim1, self.dim2 = dim1, dim2
        self.t_embed = TimeEmbed(t_dim)

        chs = widths
        in_ch = 1

        # -------- Down path
        self.down_blocks = nn.ModuleList()
        self.downs = nn.ModuleList()
        for i, ch in enumerate(chs):
            blocks = nn.ModuleList()
            for _ in range(block_depth):
                blocks.append(FiLMBlock(in_ch, ch, t_dim, use_attn=attention and (i >= len(chs)//2)))
                in_ch = ch  # after first block, in_ch == ch, so subsequent are ch->ch
            self.down_blocks.append(blocks)
            if i < len(chs) - 1:
                # stride-2 conv halves H and W (rounding up for odd sizes)
                self.downs.append(nn.Conv2d(ch, ch, 3, stride=2, padding=1))

        # -------- Mid
        self.mid = FiLMBlock(chs[-1], chs[-1], t_dim, use_attn=attention)

        # -------- Up path
        self.up_convs = nn.ModuleList()
        self.up_blocks = nn.ModuleList()
        for i, ch in reversed(list(enumerate(chs[:-1]))):
            # upsample + reduce channels from chs[i+1] -> ch
            self.up_convs.append(nn.Sequential(
                nn.Upsample(scale_factor=2, mode="nearest"),
                nn.Conv2d(chs[i+1], ch, 3, padding=1),
            ))

            # First block sees concat([h, skip]) => 2*ch in, ch out
            blocks = nn.ModuleList()
            if block_depth >= 1:
                blocks.append(FiLMBlock(2 * ch, ch, t_dim, use_attn=attention and (i >= len(chs)//2)))
            # Remaining blocks are ch -> ch
            for _ in range(block_depth - 1):
                blocks.append(FiLMBlock(ch, ch, t_dim, use_attn=attention and (i >= len(chs)//2)))
            self.up_blocks.append(blocks)

        self.out = nn.Conv2d(chs[0], 1, 1)

    def forward(self, x: torch.Tensor, t_in: torch.Tensor) -> torch.Tensor:
        """
        x: (B,1,H,W); t_in: (B,1,1,1) or (B,1) with values in [0,1]

        Returns a tensor with the same shape as ``x``. H and W do not need to
        be divisible by ``2 ** (len(widths) - 1)``: upsampled maps are cropped
        to the matching skip connection.
        """
        t_emb = self.t_embed(t_in)

        # Down
        skips = []
        h = x
        for i, blocks in enumerate(self.down_blocks):
            for blk in blocks:
                h = blk(h, t_emb)
            skips.append(h)
            if i < len(self.downs):
                h = self.downs[i](h)

        # Mid
        h = self.mid(h, t_emb)

        # Up (the deepest level's skip is not concatenated; it fed the mid block)
        for up, blocks, skip in zip(self.up_convs, self.up_blocks, reversed(skips[:-1])):
            upsample, reduce = up[0], up[1]
            h = upsample(h)
            if h.shape[-2:] != skip.shape[-2:]:
                # Odd H/W: the down conv produced ceil(H/2), so x2 upsampling
                # overshoots by one row/column. Crop so the concat lines up.
                h = h[..., :skip.shape[-2], :skip.shape[-1]]
            h = reduce(h)                   # -> (B, ch, H, W)
            h = torch.cat([h, skip], dim=1) # -> (B, 2*ch, H, W)
            for blk in blocks:
                h = blk(h, t_emb)           # first block expects 2*ch in, rest ch in

        return self.out(h)


# %% =============================== DDIM model (PyTorch)

class DDIMTorch(nn.Module):
    """
    DDIM-style diffusion model: a trainable noise predictor, an EMA copy of it
    used for evaluation/sampling, and an input normalizer.

    Args:
        widths, block_depth, attention, dim1, dim2: forwarded to
            ``UNetNoisePredictor`` (for both the trained and the EMA network).
        device: requested device; falls back to CPU when CUDA is unavailable.

    Attributes:
        spec_mod, dx_mod: weights of the auxiliary spectral-norm and
            time-derivative terms added to the noise loss (0.0 disables them).
    """
    def __init__(self, widths, block_depth, attention=False, dim1=256, dim2=128, device="cuda"):
        super().__init__()
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")

        # Keep the normalizer on the model device so it works on device batches
        # even if adapt() is skipped or run with a different device argument.
        self.normalizer = AdaptiveNormalizer().to(self.device)
        self.network = UNetNoisePredictor(widths, block_depth, attention=attention, dim1=dim1, dim2=dim2).to(self.device)
        self.ema_network = UNetNoisePredictor(widths, block_depth, attention=attention, dim1=dim1, dim2=dim2).to(self.device)
        self.ema_network.load_state_dict(self.network.state_dict())
        # The EMA copy is only ever updated in-place under no_grad; it never needs gradients.
        self.ema_network.requires_grad_(False)
        self.spec_mod = 0.0
        self.dx_mod = 0.0

        self.mse = nn.MSELoss()

    @torch.no_grad()
    def diffusion_schedule(self, diffusion_times: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Map diffusion times to (noise_rates, signal_rates).

        Times in [0, 1] are mapped linearly to an angle between
        ``acos(max_signal_rate)`` and ``acos(min_signal_rate)``; the signal and
        noise rates are the cosine and sine of that angle, so their squares
        sum to 1. Output shapes match ``diffusion_times`` (e.g. (B,1,1,1)).
        """
        start_angle = math.acos(max_signal_rate)
        end_angle = math.acos(min_signal_rate)
        diffusion_angles = start_angle + diffusion_times * (end_angle - start_angle)
        signal_rates = torch.cos(diffusion_angles)
        noise_rates = torch.sin(diffusion_angles)
        return noise_rates, signal_rates

    def denoise(self, noisy_data: torch.Tensor, noise_rates: torch.Tensor, signal_rates: torch.Tensor, training: bool = True):
        """
        Predict the noise component and the implied clean data.

        Uses the trained network when ``training`` is True and the EMA network
        otherwise. The network is conditioned on the noise variance
        (``noise_rates ** 2``).

        Returns:
            (pred_noises, pred_data), both shaped like ``noisy_data``.
        """
        net = self.network if training else self.ema_network
        cond = (noise_rates ** 2).to(noisy_data.dtype)
        pred_noises = net(noisy_data, cond)
        # Invert noisy = signal * data + noise * eps; the epsilon guards the division.
        pred_data = (noisy_data - noise_rates * pred_noises) / (signal_rates + 1e-6)
        return pred_noises, pred_data

    @torch.no_grad()
    def reverse_diffusion(self, initial_noise: torch.Tensor, diffusion_steps: int) -> torch.Tensor:
        """
        Deterministic sampling from t=1 towards t=0 in ``diffusion_steps`` equal steps.

        Returns the clean-data estimate of the final step (still in normalized space).

        Raises:
            ValueError: if ``diffusion_steps`` is smaller than 1.
        """
        if diffusion_steps < 1:
            raise ValueError(f"diffusion_steps must be >= 1, got {diffusion_steps}")

        B = initial_noise.size(0)
        step_size = 1.0 / diffusion_steps
        next_noisy = initial_noise
        pred_data = initial_noise
        for step in tqdm(range(diffusion_steps), desc="sampling", dynamic_ncols=True, leave=False):
            noisy = next_noisy
            # Build the time tensor next to the sample it conditions
            t = torch.ones((B, 1, 1, 1), device=initial_noise.device) - step * step_size
            noise_rates, signal_rates = self.diffusion_schedule(t)
            pred_noises, pred_data = self.denoise(noisy, noise_rates, signal_rates, training=False)
            # Re-mix the predicted components at the next (lower) noise level
            t_next = t - step_size
            next_noise_rates, next_signal_rates = self.diffusion_schedule(t_next)
            next_noisy = next_signal_rates * pred_data + next_noise_rates * pred_noises
        return pred_data

    @torch.no_grad()
    def generate(self, num_examples: int, shape: tuple[int, int, int], diffusion_steps: int) -> torch.Tensor:
        """
        Sample ``num_examples`` items of per-example ``shape`` (C, H, W).

        The result is denormalized and clamped to [-128, 128].
        """
        initial_noise = torch.randn((num_examples, *shape), device=self.device)
        generated = self.reverse_diffusion(initial_noise, diffusion_steps)
        denorm = self.normalizer.denormalize(generated)
        return torch.clamp(denorm, -128.0, 128.0)

    def _get_losses(self, y_true: torch.Tensor, y_pred: torch.Tensor):
        """Return (mse, spectral-norm difference, time-derivative loss) for a prediction."""
        l = self.mse(y_pred, y_true)
        s = spectral_norm_diff(y_pred, y_true)
        d = time_derivative_loss(y_pred, y_true)
        return l, s, d

    def training_step(self, batch: torch.Tensor, optimizer: torch.optim.Optimizer):
        """
        Run one optimization step on ``batch`` (B,1,H,W) and update the EMA weights.

        Only the noise-space loss is optimized; the data-space losses are
        reported as metrics.

        Returns:
            dict of python floats with keys ``n_loss``, ``d_loss``, ``n_spec``,
            ``d_spec``, ``n_dx``, ``d_dx``, ``n_total``, ``d_total``.
        """
        batch = batch.to(self.device).float()  # (B,1,H,W)

        data = self.normalizer(batch)
        noises = torch.randn_like(data)

        B = data.size(0)
        diffusion_times = torch.rand((B, 1, 1, 1), device=self.device)
        noise_rates, signal_rates = self.diffusion_schedule(diffusion_times)
        noisy = signal_rates * data + noise_rates * noises

        pred_noises, pred_data = self.denoise(noisy, noise_rates, signal_rates, training=True)
        noise_loss, noise_spec, noise_dx = self._get_losses(noises, pred_noises)

        # Data-space losses are metrics only: compute them without autograd
        # tracking so they add no graph/memory to the backward pass.
        with torch.no_grad():
            data_loss, data_spec, data_dx = self._get_losses(data, pred_data.detach())

        total_noise_loss = noise_loss + self.spec_mod * noise_spec + self.dx_mod * noise_dx

        optimizer.zero_grad(set_to_none=True)
        total_noise_loss.backward()
        optimizer.step()

        # EMA
        with torch.no_grad():
            for p, q in zip(self.network.parameters(), self.ema_network.parameters()):
                q.mul_(ema).add_(p, alpha=1.0 - ema)

        metrics = {
            "n_loss": noise_loss.detach().item(),
            "d_loss": data_loss.detach().item(),
            "n_spec": noise_spec.detach().item(),
            "d_spec": data_spec.detach().item(),
            "n_dx": noise_dx.detach().item(),
            "d_dx": data_dx.detach().item(),
            "n_total": total_noise_loss.detach().item(),
            "d_total": (data_loss + self.spec_mod * data_spec + self.dx_mod * data_dx).detach().item(),
        }
        return metrics

    @torch.no_grad()
    def eval_step(self, batch: torch.Tensor):
        """
        Evaluate the EMA network on ``batch`` at random diffusion times.

        Returns:
            dict with the noise-space and data-space MSE (``n_loss``, ``d_loss``).
        """
        batch = batch.to(self.device).float()
        data = self.normalizer(batch)
        noises = torch.randn_like(data)

        B = data.size(0)
        diffusion_times = torch.rand((B, 1, 1, 1), device=self.device)
        noise_rates, signal_rates = self.diffusion_schedule(diffusion_times)
        noisy = signal_rates * data + noise_rates * noises

        pred_noises, pred_data = self.denoise(noisy, noise_rates, signal_rates, training=False)
        noise_loss = self.mse(noises, pred_noises)
        data_loss = self.mse(data, pred_data)
        return {"n_loss": noise_loss.item(), "d_loss": data_loss.item()}


# %% =============================== example usage (notebook-friendly)

if __name__ == "__main__":
    # NOTE: the recorded run of this cell stopped with a CUDA out-of-memory
    # error on the first training step using the settings below (batches of
    # 16 x 1 x 256 x 128, four 128-wide levels, attention enabled). If that
    # recurs, lower batch_size / widths or set attention=False.

    # dataset
    loader = get_files_dataloader(
        "data/REAL_audio/*.wav",   # adjust your glob if needed
        out_len=3.3,
        max_feats=256,
        total_seconds=26,
        scale=1.0,
        batch_size=16,
        mdct_feats=256,
        rate=10_000,
        num_workers=0,             # notebook-safe
        pin_memory=False,          # notebook-safe
    )

    # peek a batch to get shape; fail with a clear message if there is no data
    batch = next(iter(loader), None)
    if batch is None:
        raise RuntimeError("dataloader yielded no batches; check the audio glob pattern")
    shape = batch.shape  # (B, 1, H, W)
    print("batch shape:", shape)

    # model
    model = DDIMTorch(
        widths=[128, 128, 128, 128],
        block_depth=2,
        attention=True,
        dim1=shape[2],
        dim2=shape[3],
        device="cuda",
    )

    # adapt normalizer with progress bar (statistics are placed on the model device)
    model.normalizer.adapt(loader, max_batches=256, device=str(model.device), show_progress=True)

    # optimizer: only the trained network; the EMA copy is updated inside training_step
    opt = torch.optim.AdamW(model.network.parameters(), lr=2e-4)

    # optional aux terms
    model.spec_mod = 0.0
    model.dx_mod = 0.0

    # tiny training loop with tqdm
    model.train()
    for epoch in range(1):
        pbar = tqdm(loader, desc=f"epoch {epoch}", dynamic_ncols=True, leave=True)
        for batch in pbar:
            metrics = model.training_step(batch, optimizer=opt)
            pbar.set_postfix({
                "n_total": f"{metrics['n_total']:.4f}",
                "d_loss": f"{metrics['d_loss']:.4f}",
            })

    # sampling example with progress bar (generate() already runs under torch.no_grad)
    gen = model.generate(num_examples=4, shape=(1, shape[2], shape[3]), diffusion_steps=50)
    print("generated:", gen.shape)  # (4, 1, H, W)


batch shape: torch.Size([16, 1, 256, 128])


epoch 0:   0%|          | 0/1623 [00:17<?, ?it/s]                     


AcceleratorError: CUDA error: out of memory
Search for `cudaErrorMemoryAllocation' in https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html for more information.
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
