In [5]:
# %% [markdown]
# Real-only Music GAN (WGAN-GP) — Train on REAL_AUDIO only and generate new music
# Paste these cells into your `music_gan_realism.ipynb` (or run as-is in a new notebook).
# This version removes any need for AI/MP3 folders and fixes glob + typos.

# %%
# !pip install torch torchaudio librosa soundfile numpy scikit-learn tqdm --quiet

# %%
import os, random, math, io, warnings
import glob  # <-- use module form; call as glob.glob(...)
from pathlib import Path
from typing import List, Tuple

import numpy as np
import soundfile as sf
import librosa
import librosa.feature
import librosa.display
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm

warnings.filterwarnings("ignore", category=UserWarning)

# ----------------------------
# Config (edit as needed)
# ----------------------------
SEED = 17
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
REAL_AUDIO_DIR = "data/REAL_audio"  # <-- only this is required now

SR = 22050
N_FFT = 1024
HOP = 256
WIN = 1024
N_MELS = 128
FMIN = 20
FMAX = 8000

# frames per training window (mel time steps)
FRAMES = 256  # ≈ FRAMES*HOP/SR seconds (here ~2.97 s); raise if you have VRAM
WINDOW_SEC = FRAMES * HOP / SR

BATCH = 32
EPOCHS = 50
LR_G = 2e-4
LR_D = 2e-4
BETAS = (0.5, 0.9)
LAMBDA_GP = 10.0
N_CRITIC = 5  # D steps per G step

# Snapshot & sampling controls (to reduce between-epoch stalls)
SAVE_EVERY_EPOCHS = 5      # set to 1 to save every epoch
SAMPLES_PER_SNAPSHOT = 4   # number of WAVs to render per snapshot
SAMPLE_EVERY_EPOCHS = 5    # set to 1 to sample every epoch; None to disable
GRIFFIN_LIM_ITERS = 24     # 64 sounds a bit nicer but is much slower

SAVE_DIR = "runs/real_only_gan"
SAMPLES_DIR = f"{SAVE_DIR}/samples"
CKPT_DIR = f"{SAVE_DIR}/ckpts"
os.makedirs(SAMPLES_DIR, exist_ok=True)
os.makedirs(CKPT_DIR, exist_ok=True)

# ----------------------------
# Utils: audio <-> mel
# ----------------------------
EPS = 1e-7

def wav_to_logmel(y: np.ndarray) -> np.ndarray:
    S = librosa.feature.melspectrogram(
        y=y, sr=SR, n_fft=N_FFT, hop_length=HOP, win_length=WIN,
        n_mels=N_MELS, power=2.0, fmin=FMIN, fmax=FMAX,
    )
    S = np.maximum(S, EPS)
    logS = np.log(S)
    return logS  # natural log of power mel

def logmel_to_wav(logS: np.ndarray, length: int | None = None) -> np.ndarray:
    S = np.exp(logS)  # back to power mel
    y = librosa.feature.inverse.mel_to_audio(
        M=S, sr=SR, n_fft=N_FFT, hop_length=HOP, win_length=WIN,
        fmin=FMIN, fmax=FMAX, power=2.0, n_iter=GRIFFIN_LIM_ITERS
    )
    if length is not None and len(y) > length:
        y = y[:length]
    return y

# ----------------------------
# Dataset — REAL audio only
# ----------------------------
class RealMelDataset(Dataset):
    def __init__(self, root: str, frames: int = FRAMES):
        exts = (".wav", ".flac", ".mp3", ".ogg", ".m4a", ".aiff", ".aif")
        self.files = [p for p in glob.glob(os.path.join(root, "**", "*"), recursive=True)
                      if p.lower().endswith(exts)]
        self.frames = frames
        self.valid = []
        for p in self.files:
            try:
                info = sf.info(p)
                if info.samplerate <= 0 or info.frames < SR * 2:
                    continue
                # quick probe: try reading a tiny slice
                y, _ = librosa.load(p, sr=SR, mono=True, duration=1.0)
                if y is None or len(y) == 0:
                    continue
                self.valid.append(p)
            except Exception:
                continue
        if not self.valid:
            raise RuntimeError(f"No valid audio found under {root}. Checked {len(self.files)} files.")

    def __len__(self):
        return len(self.valid)

    def _load_segment(self, path: str):
        y, sr = librosa.load(path, sr=SR, mono=True)
        target_len = self.frames * HOP + WIN
        if len(y) < target_len:
            y = np.pad(y, (0, target_len - len(y)))
        max_start = max(0, len(y) - target_len)
        start = 0 if max_start == 0 else np.random.randint(0, max_start + 1)
        segment = y[start:start + target_len]
        return segment

    def __getitem__(self, idx):
        # robust loop: if a file breaks in a worker, try a different one
        for _ in range(5):
            path = self.valid[idx % len(self.valid)]
            try:
                segment = self._load_segment(path)
                logmel = wav_to_logmel(segment)
                if logmel.shape[1] < self.frames:
                    pad_w = self.frames - logmel.shape[1]
                    logmel = np.pad(logmel, ((0, 0), (0, pad_w)), mode='edge')
                logmel = logmel[:, :self.frames]
                x = torch.from_numpy(logmel).float().unsqueeze(0)
                return x
            except Exception:
                # pick a random fallback index
                idx = np.random.randint(0, len(self.valid))
                continue
        # last resort: return zeros to avoid crashing the worker
        x = torch.zeros(1, N_MELS, self.frames, dtype=torch.float32)
        return x

# ----------------------------
# Models: discriminator & generator on mel space
# ----------------------------
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        ch = [1, 32, 64, 128, 256]
        self.net = nn.Sequential(
            nn.Conv2d(ch[0], ch[1], 4, 2, 1), nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ch[1], ch[2], 4, 2, 1), nn.InstanceNorm2d(ch[2], affine=True), nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ch[2], ch[3], 4, 2, 1), nn.InstanceNorm2d(ch[3], affine=True), nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ch[3], ch[4], 4, 2, 1), nn.InstanceNorm2d(ch[4], affine=True), nn.LeakyReLU(0.2, inplace=True),
        )
        self.head = nn.Conv2d(ch[4], 1, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        h = self.net(x)
        h = F.adaptive_avg_pool2d(h, (1, 1))
        score = self.head(h).view(x.size(0))
        return score

class Generator(nn.Module):
    def __init__(self, z_dim=128, out_h=N_MELS, out_w=FRAMES):
        super().__init__()
        self.out_h = out_h
        self.out_w = out_w
        base = 256
        self.fc = nn.Sequential(
            nn.Linear(z_dim, base*8), nn.ReLU(True),
            nn.Linear(base*8, base*16), nn.ReLU(True),
        )
        self.start_h, self.start_w = max(1, out_h//16), max(1, out_w//16)
        self.proj = nn.Linear(base*16, 256 * self.start_h * self.start_w)
        self.up = nn.Sequential(
            nn.ConvTranspose2d(256, 128, 4, 2, 1), nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, 4, 2, 1), nn.ReLU(True),
            nn.ConvTranspose2d(64, 32, 4, 2, 1), nn.ReLU(True),
            nn.ConvTranspose2d(32, 16, 4, 2, 1), nn.ReLU(True),
            nn.Conv2d(16, 1, 3, 1, 1),
        )

    def forward(self, z):
        h = self.fc(z)
        h = self.proj(h)
        h = h.view(z.size(0), 256, self.start_h, self.start_w)
        x = self.up(h)
        # crop/resize to exact target if off-by-one due to strides
        if x.size(2) != self.out_h or x.size(3) != self.out_w:
            x = F.interpolate(x, size=(self.out_h, self.out_w), mode="bilinear", align_corners=False)
        return x

# ----------------------------
# WGAN-GP training (REAL only)
# ----------------------------

def gradient_penalty(D, real, fake):
    bsz = real.size(0)
    eps = torch.rand(bsz, 1, 1, 1, device=real.device)
    inter = eps * real + (1 - eps) * fake
    inter.requires_grad_(True)
    d_inter = D(inter)
    grads = torch.autograd.grad(
        outputs=d_inter, inputs=inter,
        grad_outputs=torch.ones_like(d_inter),
        create_graph=True, retain_graph=True, only_inputs=True
    )[0]
    grads = grads.view(bsz, -1)
    gp = ((grads.norm(2, dim=1) - 1.0) ** 2).mean()
    return gp

@torch.no_grad()
def sample_mels(G, n=8, z_dim=128):
    G.eval()
    z = torch.randn(n, z_dim, device=DEVICE)
    mels = G(z).cpu().numpy()  # [n,1,N_MELS,FRAMES]
    return mels

@torch.no_grad()
def save_audio_samples(G, step_tag: str, n=SAMPLES_PER_SNAPSHOT, z_dim=128):
    mels = sample_mels(G, n=n, z_dim=z_dim)
    for i, mel in enumerate(mels):
        logmel = mel[0]
        y = logmel_to_wav(logmel)
        y = np.clip(y, -1.0, 1.0)
        sf.write(f"{SAMPLES_DIR}/sample_{step_tag}_{i:02d}.wav", y, SR)


def train_real_only_wgan_gp(resume_path: str | None = None):
    torch.manual_seed(SEED); np.random.seed(SEED); random.seed(SEED)
    ds = RealMelDataset(REAL_AUDIO_DIR, frames=FRAMES)
    dl = DataLoader(ds, batch_size=BATCH, shuffle=True, num_workers=0, drop_last=True, pin_memory=False, persistent_workers=False)

    D = Discriminator().to(DEVICE)
    G = Generator(z_dim=128, out_h=N_MELS, out_w=FRAMES).to(DEVICE)

    if resume_path and os.path.exists(resume_path):
        G.load_state_dict(torch.load(resume_path, map_location=DEVICE))
        print(f"[Resume] Loaded generator weights from {resume_path}")

    optD = torch.optim.Adam(D.parameters(), lr=LR_D, betas=BETAS)
    optG = torch.optim.Adam(G.parameters(), lr=LR_G, betas=BETAS)

    global_step = 0
    for epoch in range(1, EPOCHS + 1):
        pbar = tqdm(dl, desc=f"Epoch {epoch}/{EPOCHS}")
        for real in pbar:
            real = real.to(DEVICE)
            bsz = real.size(0)
            # --------------------
            # (1) Update D N_CRITIC times
            # --------------------
            for _ in range(N_CRITIC):
                z = torch.randn(bsz, 128, device=DEVICE)
                fake = G(z).detach()
                d_real = D(real)
                d_fake = D(fake)
                gp = gradient_penalty(D, real, fake)
                lossD = (d_fake - d_real).mean() + LAMBDA_GP * gp
                optD.zero_grad(set_to_none=True)
                lossD.backward()
                optD.step()

            # --------------------
            # (2) Update G once
            # --------------------
            z = torch.randn(bsz, 128, device=DEVICE)
            fake = G(z)
            lossG = -D(fake).mean()
            optG.zero_grad(set_to_none=True)
            lossG.backward()
            optG.step()

            global_step += 1
            if global_step % 100 == 0:
                pbar.set_postfix({"lossD": float(lossD.item()), "lossG": float(lossG.item())})
        # end epoch — heavy I/O moved behind frequency gates
        if (epoch % SAVE_EVERY_EPOCHS) == 0 or epoch == EPOCHS:
            torch.save(G.state_dict(), f"{CKPT_DIR}/G_epoch{epoch:03d}.pt")
            torch.save(D.state_dict(), f"{CKPT_DIR}/D_epoch{epoch:03d}.pt")
        if SAMPLE_EVERY_EPOCHS is not None and ((epoch % SAMPLE_EVERY_EPOCHS) == 0 or epoch == EPOCHS):
            save_audio_samples(G, step_tag=f"e{epoch}", n=SAMPLES_PER_SNAPSHOT)
    return G, D

# %% [markdown]
# ---- Quickstart ----
# 1) Put your real audio under REAL_AUDIO_DIR (any extension supported by soundfile/librosa)
# 2) Run: G, D = train_real_only_wgan_gp()
# 3) Generate new music snippets at any time: save_audio_samples(G, step_tag="manual", n=8)
#    Files will appear under SAMPLES_DIR.
# 4) (Optional) Resume training: G, D = train_real_only_wgan_gp(resume_path="runs/real_only_gan/ckpts/G_epoch010.pt")

# %%
if __name__ == "__main__":
    print(f"Device: {DEVICE}")
    print(f"Training windows: ~{WINDOW_SEC:.2f}s | Mel: {N_MELS}x{FRAMES}")
    # Uncomment to train from script:
    # G, D = train_real_only_wgan_gp()
    pass


Device: cuda
Training windows: ~2.97s | Mel: 128x256


In [6]:
G, D = train_real_only_wgan_gp()
save_audio_samples(G, step_tag="manual", n=8)


Epoch 1/50: 100%|██████████| 31/31 [00:24<00:00,  1.25it/s]
Epoch 2/50: 100%|██████████| 31/31 [00:24<00:00,  1.24it/s]
Epoch 3/50: 100%|██████████| 31/31 [00:24<00:00,  1.24it/s]
Epoch 4/50: 100%|██████████| 31/31 [00:25<00:00,  1.23it/s, lossD=1.46, lossG=-0.00453]
Epoch 5/50: 100%|██████████| 31/31 [00:24<00:00,  1.24it/s]
Epoch 6/50: 100%|██████████| 31/31 [00:24<00:00,  1.24it/s]
Epoch 7/50: 100%|██████████| 31/31 [00:25<00:00,  1.23it/s, lossD=-8.99, lossG=3.98]
Epoch 8/50: 100%|██████████| 31/31 [00:24<00:00,  1.24it/s]
Epoch 9/50: 100%|██████████| 31/31 [00:24<00:00,  1.24it/s]
Epoch 10/50: 100%|██████████| 31/31 [00:25<00:00,  1.24it/s, lossD=-23.4, lossG=12.3]
Epoch 11/50: 100%|██████████| 31/31 [00:25<00:00,  1.23it/s]
Epoch 12/50: 100%|██████████| 31/31 [00:25<00:00,  1.24it/s]
Epoch 13/50: 100%|██████████| 31/31 [00:25<00:00,  1.23it/s, lossD=-35.7, lossG=16.9]
Epoch 14/50: 100%|██████████| 31/31 [00:24<00:00,  1.24it/s]
Epoch 15/50: 100%|██████████| 31/31 [00:24<00:00,  1

# Model -> DDIM

In [7]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

import os

import tensorflow as tf
import tensorflow_addons as tfa
from tensorflow import keras
from tensorflow.keras import layers
import math
import matplotlib.pyplot as plt
from tqdm import tqdm
import librosa 
from glob import glob

import random
from functools import partial
import warnings
import IPython.display as ipd
from tensorflow.keras import mixed_precision
warnings.filterwarnings('ignore')


In [10]:
import torch
print(torch.cuda.is_available())   # should be True
print(torch.cuda.get_device_name(0))  # GPU name


True
NVIDIA GeForce RTX 4050 Laptop GPU


In [11]:
music_files = glob("data/REAL_audio/**/*.wav", recursive=True)
music_files[:10]


['data/REAL_audio\\blues.00000.wav',
 'data/REAL_audio\\blues.00001.wav',
 'data/REAL_audio\\blues.00002.wav',
 'data/REAL_audio\\blues.00003.wav',
 'data/REAL_audio\\blues.00004.wav',
 'data/REAL_audio\\blues.00005.wav',
 'data/REAL_audio\\blues.00006.wav',
 'data/REAL_audio\\blues.00007.wav',
 'data/REAL_audio\\blues.00008.wav',
 'data/REAL_audio\\blues.00009.wav']

In [13]:
min_signal_rate = 0.02
max_signal_rate = 0.95
ema = 0.999

def spectral_norm(pred, real):
    """Calculate difference in spectral norm between two batches of spectrograms."""
    norm_real = tf.norm(real, axis=(1,2)) + 1e-6
    norm_pred = tf.norm(pred, axis=(1,2)) + 1e-6
    return tf.reduce_mean(tf.abs(norm_real - norm_pred) / norm_real)

def time_derivative(pred, real, window=1):
    real_derivative = real[:, :-window, :, :] - real[:, window:, :, :]
    pred_derivative = pred[:, :-window, :, :] - pred[:, window:, :, :]
    return tf.reduce_mean(tf.keras.losses.MSE(real_derivative, pred_derivative))



class DDIM(keras.Model):
    """DDIM model modified from this tutorial: https://keras.io/examples/generative/ddim/"""
    
    def __init__(self, widths, block_depth, attention=False, dim1=256, dim2=128):
        super().__init__()

        self.normalizer = layers.Normalization(axis=(2,3))
        self.network = get_network(widths, block_depth, attention=attention, dim1=dim1, dim2=dim2)
        self.ema_network = keras.models.clone_model(self.network)
        self.spec_mod = 0
        self.dx_mod = 0

    def compile(self, **kwargs):
        super().compile(**kwargs)

        self.noise_loss_tracker = keras.metrics.Mean(name="n_loss")
        self.data_loss_tracker = keras.metrics.Mean(name="d_loss")
        
        self.noise_spec_tracker = keras.metrics.Mean(name="n_spec")
        self.data_spec_tracker = keras.metrics.Mean(name="d_spec")
        
        self.noise_dx_tracker = keras.metrics.Mean(name="n_dx")
        self.data_dx_tracker = keras.metrics.Mean(name="d_dx")
        
        self.noise_total_tracker = keras.metrics.Mean(name="n_total")
        self.data_total_tracker = keras.metrics.Mean(name="d_total")
        self.spec_mod = 0
        self.dx_mod = 0
    @property
    def metrics(self):
        return [
            self.noise_loss_tracker, 
            self.data_loss_tracker,
            
            self.noise_spec_tracker,
            self.data_spec_tracker,
            
            self.noise_dx_tracker,
            self.data_dx_tracker,
            
            self.noise_total_tracker,
            self.data_total_tracker
        ]
    
    def update_trackers(self, n_l, n_s, n_d, d_l, d_s, d_d):
        """Update all loss trackers."""
        n_t = n_l + n_s + n_d
        d_t = d_l + d_s + d_d
        
        for loss, tracker in zip(
            [n_l, n_s, n_d, n_t, d_l, d_s, d_d, d_t], 
            [
                self.noise_loss_tracker, self.noise_spec_tracker, self.noise_dx_tracker, self.noise_total_tracker,
                self.data_loss_tracker, self.data_spec_tracker, self.data_dx_tracker, self.data_total_tracker
            ]
        ):
            tracker.update_state(loss)
            
    def get_losses(self, y_true, y_pred):
        """Get losses for model."""
        return (
            tf.reduce_mean(
                self.loss(y_pred, y_true)
            ), spectral_norm(
                y_pred, y_true
            ), time_derivative(
                y_pred, y_true
            )
        )

    def denormalize(self, data):
        data = self.normalizer.mean + data * self.normalizer.variance**0.5
        return tf.clip_by_value(data, -128.0, 128.0)

    def diffusion_schedule(self, diffusion_times):
        start_angle = tf.acos(max_signal_rate)
        end_angle = tf.acos(min_signal_rate)
        diffusion_angles = start_angle + diffusion_times * (end_angle - start_angle)
        signal_rates = tf.cos(diffusion_angles)
        noise_rates = tf.sin(diffusion_angles)
        return noise_rates, signal_rates

    def denoise(self, noisy_data, noise_rates, signal_rates, training):
        if training:
            network = self.network
        else:
            network = self.ema_network
        pred_noises = network([noisy_data, noise_rates**2], training=training)
        pred_data = (noisy_data - noise_rates * pred_noises) / signal_rates

        return pred_noises, pred_data

    def reverse_diffusion(self, initial_noise, diffusion_steps):
        num_examples = tf.shape(initial_noise)[0]
        step_size = 1.0 / diffusion_steps

        # important line:
        # at the first sampling step, the "noisy data" is pure noise
        # but its signal rate is assumed to be nonzero (min_signal_rate)
        next_noisy_data = initial_noise
        for step in tqdm(range(diffusion_steps)):
            noisy_data = next_noisy_data

            # separate the current noisy data to its components
            diffusion_times = tf.ones((num_examples, 1, 1, 1)) - step * step_size
            noise_rates, signal_rates = self.diffusion_schedule(diffusion_times)
            pred_noises, pred_data = self.denoise(
                noisy_data, noise_rates, signal_rates, training=False
            )
            # network used in eval mode

            # remix the predicted components using the next signal and noise rates
            next_diffusion_times = diffusion_times - step_size
            next_noise_rates, next_signal_rates = self.diffusion_schedule(
                next_diffusion_times
            )
            next_noisy_data = (
                next_signal_rates * pred_data + next_noise_rates * pred_noises
            )
            # this new noisy data will be used in the next step

        return pred_data

    def generate(self, num_examples, shape, diffusion_steps):
        # noise -> data -> denormalized data
        initial_noise = tf.random.normal(shape=(num_examples, shape[0], shape[1], shape[2]))
        generated_data = self.reverse_diffusion(initial_noise, diffusion_steps)
        generated_data = self.denormalize(generated_data)
        return generated_data

    def train_step(self, data):
        batch_size = tf.shape(data)[0]
        # normalize data to have standard deviation of 1, like the noises
        data = self.normalizer(data, training=True)
        noises = tf.random.normal(shape=tf.shape(data))

        # sample uniform random diffusion times
        diffusion_times = tf.random.uniform(
            shape=(batch_size, 1, 1, 1), minval=0.0, maxval=1.0
        )
        noise_rates, signal_rates = self.diffusion_schedule(diffusion_times)
        noise_rates = noise_rates
        signal_rates = signal_rates
        # mix the data with noises accordingly
        noisy_data = signal_rates * data + noise_rates * noises

        with tf.GradientTape() as tape:
            # train the network to separate noisy data to their components
            pred_noises, pred_data = self.denoise(
                noisy_data, noise_rates, signal_rates, training=True
            )

            noise_loss, noise_spec, noise_dx = self.get_losses(noises, pred_noises) #safe_reduce_mean(self.loss(noises, pred_noises))  # used for training
            total_noise_loss = tf.reduce_sum([
                noise_loss, 
                self.spec_mod*noise_spec, 
                self.dx_mod*noise_dx
            ])
            data_loss, data_spec, data_dx = self.get_losses(data, pred_data) #safe_reduce_mean(self.loss(data, pred_data))  # only used as metric
        gradients = tape.gradient(noise_loss, self.network.trainable_weights)
        self.optimizer.apply_gradients(zip(gradients, self.network.trainable_weights))

        self.update_trackers(
            noise_loss, noise_spec, noise_dx,
            data_loss, data_spec, data_dx
        )

        # track the exponential moving averages of weights
        for weight, ema_weight in zip(self.network.weights, self.ema_network.weights):
            ema_weight.assign(ema * ema_weight + (1 - ema) * weight)

        # KID is not measured during the training phase for computational efficiency
        return {m.name: m.result() for m in self.metrics}

    def test_step(self, data):
        # normalize data to have standard deviation of 1, like the noises
        batch_size = tf.shape(data)[0]
        
        data = self.normalizer(data, training=False)
        noises = tf.random.normal(shape=tf.shape(data))

        # sample uniform random diffusion times
        diffusion_times = tf.random.uniform(
            shape=(batch_size, 1, 1, 1), minval=0.0, maxval=1.0
        )
        noise_rates, signal_rates = self.diffusion_schedule(diffusion_times)
        # mix the data with noises accordingly
        noisy_data = signal_rates * data + noise_rates * noises

        # use the network to separate noisy data to their components
        pred_noises, pred_data = self.denoise(
            noisy_data, noise_rates, signal_rates, training=False
        )

        noise_loss = self.loss(noises, pred_noises)
        data_loss = self.loss(data, pred_data)

        self.data_loss_tracker.update_state(data_loss)
        self.noise_loss_tracker.update_state(noise_loss)

        return {m.name: m.result() for m in self.metrics}




In [14]:
def load_at_interval(x, rate=10_000, feats=256, duration=3.3):
    """Load music from file at some offset. Return MDCT spectrogram of that data"""
    file = x[0].numpy().decode()
    idx = x[1].numpy()
    audio, sr = librosa.load(file, duration=duration, sr=rate, offset=idx)
    audio_fill = np.zeros(int(rate*duration), dtype=np.float32)
    audio_fill[:len(audio)] = audio
    spec = tf.signal.mdct(audio_fill, feats)
    return spec

def load_audio(x,y, rate=10_000, mdct_feats=256, duration=3.3):
    """TF function for loading MDCT spectrogram from file."""
    out = tf.py_function(lambda x,y: load_at_interval( 
        (x,y), rate=rate, feats=mdct_feats, duration=duration
    ), inp=[x,y], Tout=tf.float32)
    return out

def get_files_dataset(
        glob_location,
        total_seconds=2,
        out_len = 3.3,
        hop_size=1,
        max_feats = 2048,
        batch_size=4,
        shuffer_size=1000,
        scale=1,
        rate=10_000,
        mdct_feats=256
    ):
    """Get file dataset loader for a glob of audio files."""
    
    files = glob(
        glob_location,
        recursive=True
    )
    
#     files = [file for file in files if file not in exclude]
    
    def file_list_generator():
        for _ in range(total_seconds):
            for file in files:
                yield file, _*hop_size
                
    load_fn = partial(load_audio, duration=out_len, rate=rate, mdct_feats=mdct_feats)
                
    dg =tf.data.Dataset.from_generator(file_list_generator, output_signature = (
        tf.TensorSpec(shape=(), dtype=tf.string), 
        tf.TensorSpec(shape=(), dtype=tf.int32))).shuffle(shuffer_size).map(
            load_fn, num_parallel_calls=tf.data.AUTOTUNE
        ).map(
            lambda x: tf.expand_dims(x, -1)[:max_feats, :, :]*scale
        ).map(
            lambda x: tf.ensure_shape(x, (max_feats, mdct_feats//2, 1))
        ).batch(batch_size).prefetch(tf.data.AUTOTUNE)
    
    return dg

In [15]:
dataset = get_files_dataset(
    "data/REAL_audio/**/*.wav", 
    out_len=3.3, 
    max_feats=256, 
    total_seconds=26, 
    scale=1,
    batch_size=16
)


In [16]:
for test_batch in dataset.take(1):
    shape = test_batch.shape

print(shape)


(16, 256, 128, 1)


In [19]:
num_total_examples = (len(music_files) * 26) // shape[0]
num_total_examples

1625

In [20]:
def block(x, t_emb, ch, use_attn=False):
    # First conv
    h = layers.Conv2D(ch, 3, padding="same")(x)
    h = layers.GroupNormalization(groups=8)(h)
    # FiLM-like conditioning with time embedding
    gamma = layers.Dense(ch)(t_emb)
    beta  = layers.Dense(ch)(t_emb)
    h = layers.Add()([h, gamma])  # shift
    h = layers.Activation("swish")(h)
    h = layers.Add()([h, beta])   # bias

    # Second conv
    h = layers.Conv2D(ch, 3, padding="same")(h)
    h = layers.GroupNormalization(groups=8)(h)
    h = layers.Activation("swish")(h)

    # Optional lightweight self-attention
    if use_attn:
        q = layers.Conv2D(ch, 1, padding="same")(h)
        k = layers.Conv2D(ch, 1, padding="same")(h)
        v = layers.Conv2D(ch, 1, padding="same")(h)
        attn = layers.Softmax(axis=-1)(
            layers.Lambda(lambda x: tf.matmul(
                tf.reshape(x[0], [tf.shape(x[0])[0], -1, tf.shape(x[0])[-1]]),
                tf.transpose(tf.reshape(x[1], [tf.shape(x[1])[0], -1, tf.shape(x[1])[-1]]), [0, 2, 1])
            ))([q, k])
        )
        v_flat = layers.Lambda(lambda x: tf.reshape(x, [tf.shape(x)[0], -1, tf.shape(x)[-1]]))(v)
        h_attn = layers.Lambda(lambda x: tf.reshape(
            tf.matmul(x[0], x[1]),
            [tf.shape(x[0])[0], tf.shape(h)[1], tf.shape(h)[2], tf.shape(h)[-1]]
        ))([attn, v_flat])
        h = layers.Add()([h, h_attn])

    return h

def time_embedding(t, emb_dim=128):
    # t is shape (B,1,1,1). Flatten -> MLP -> (B, emb_dim)
    x = layers.Flatten()(t)
    x = layers.Dense(emb_dim, activation="swish")(x)
    x = layers.Dense(emb_dim, activation="swish")(x)
    # reshape to (B,1,1,emb_dim) so Dense layers above can broadcast
    return layers.Reshape((1, 1, emb_dim))(x)

def get_network(widths, block_depth, attention=False, dim1=256, dim2=128):
    """U-Net–ish noise predictor conditioned on diffusion time."""
    x_in = layers.Input(shape=(dim1, dim2, 1))
    t_in = layers.Input(shape=(1, 1, 1))
    t_emb = time_embedding(t_in, emb_dim=128)

    # Down path
    skips = []
    x = x_in
    for i, ch in enumerate(widths):
        for _ in range(block_depth):
            x = block(x, t_emb, ch, use_attn=attention and (i >= len(widths)//2))
        skips.append(x)
        if i < len(widths) - 1:
            x = layers.Conv2D(ch, 3, strides=2, padding="same")(x)

    # Mid
    x = block(x, t_emb, widths[-1], use_attn=attention)

    # Up path
    for i, ch in reversed(list(enumerate(widths[:-1]))):
        x = layers.UpSampling2D()(x)
        x = layers.Conv2D(ch, 3, padding="same")(x)
        x = layers.Concatenate()([x, skips[i]])
        for _ in range(block_depth):
            x = block(x, t_emb, ch, use_attn=attention and (i >= len(widths)//2))

    # Output noise prediction
    out = layers.Conv2D(1, 1, padding="same")(x)
    return keras.Model([x_in, t_in], out, name="ddim_unet")


In [21]:
model = DDIM(widths = [128, 128, 128, 128], block_depth = 2, 
             attention=True, dim1=shape[1], dim2=shape[2])

TypeError: Could not build a TypeSpec for KerasTensor(type_spec=TensorSpec(shape=(None, 64, 32, 128), dtype=tf.float32, name=None), name='tf.reshape/Reshape:0', description="created by layer 'tf.reshape'") of unsupported type <class 'keras.src.engine.keras_tensor.KerasTensor'>.

In [None]:
model.normalizer.adapt(dataset, steps=10)


In [None]:
model.compile(
    loss=tf.keras.losses.MSE,
    optimizer= tfa.optimizers.AdamW(
        learning_rate = 3e-4,
        weight_decay = 1e-4
    )
)


In [None]:
dataset = dataset.cache()


In [None]:
history = model.fit(dataset.repeat(), steps_per_epoch=num_total_examples, epochs=1)


In [None]:
model.spec_mod = 1
model.dx_mod = 1


In [None]:
history = model.fit(dataset.repeat(), steps_per_epoch=num_total_examples, epochs=100)


In [None]:
specs = model.generate(8, shape[1:], 1000)


In [None]:
for i in range(4):
    plt.pcolormesh(np.log(np.abs(test_batch[i, :, :, 0].numpy().T)))
    plt.colorbar()
    plt.title(f"Real example {i+1}")
    plt.show()
    ipd.display(ipd.Audio(tf.signal.inverse_mdct(test_batch[i, :, :, 0]), rate=10_000))


In [None]:
for i in range(len(specs)):
    plt.pcolormesh(np.log(np.abs(specs[i, :, :, 0].numpy().T)))
    plt.colorbar()
    plt.title(f"Generated example {i+1}")
    plt.show()
    ipd.display(ipd.Audio(tf.signal.inverse_mdct(tf.cast(specs[i, :, :, 0], tf.float32)), rate=10_000))
