
# Music Realism Scoring with WGAN-GP (Log-Mel, Codec Parity)

This notebook trains a **WGAN-GP on real music only** (using **log-mel spectrograms**) and then uses the **discriminator** as a **realism score** for new tracks (AI vs. real).  
It handles **codec parity** (WAV ↔ MP3 round-trip) to prevent shortcut learning, and fixes input shapes to **[1, 128, 256]** (channels, mels, frames).


In [1]:
import torch
print("PyTorch:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("Device count:", torch.cuda.device_count())
    print("GPU 0:", torch.cuda.get_device_name(0))


PyTorch: 2.9.0+cu128
CUDA available: True
Device count: 1
GPU 0: NVIDIA GeForce RTX 4050 Laptop GPU


In [32]:
import os
import glob  # instead of: from glob import glob


SEED = 17
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
REAL_AUDIO_DIR = "data/REAL_audio" # <-- only this is required now


SR = 22050
N_FFT = 1024
HOP = 256
WIN = 1024
N_MELS = 128
FMIN = 20
FMAX = 8000


# frames per training window (mel time steps)
FRAMES = 256 # ≈ FRAMES*HOP/SR seconds (here ~2.97 s); raise if you have VRAM
WINDOW_SEC = FRAMES * HOP / SR


BATCH = 32
EPOCHS = 50
LR_G = 2e-4
LR_D = 2e-4
BETAS = (0.5, 0.9)
LAMBDA_GP = 10.0
N_CRITIC = 5 # D steps per G step


SAVE_DIR = "runs/real_only_gan"
SAMPLES_DIR = f"{SAVE_DIR}/samples"
CKPT_DIR = f"{SAVE_DIR}/ckpts"
os.makedirs(SAMPLES_DIR, exist_ok=True)
os.makedirs(CKPT_DIR, exist_ok=True)

In [33]:

import os, sys, math, random, shutil, tempfile, subprocess, warnings, glob
from pathlib import Path
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import tqdm
import torchaudio
from torchaudio.transforms import MelSpectrogram, AmplitudeToDB

# Optional deps
try:
    import pyloudnorm as pyln
    HAVE_PYL = True
except Exception:
    HAVE_PYL = False

try:
    import soundfile as sf
    HAVE_SF = True
except Exception:
    HAVE_SF = False


# sklearn is optional for metrics
try:
    from sklearn.metrics import roc_auc_score, average_precision_score
    HAVE_SK = True
except Exception:
    HAVE_SK = False

# Check FFmpeg availability (for MP3 round-trip)
def _have_ffmpeg():
    try:
        subprocess.run(["ffmpeg", "-version"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=False)
        return True
    except Exception:
        return False

HAVE_FFMPEG = _have_ffmpeg()

# Device selection
if DEVICE == "cuda" and not torch.cuda.is_available():
    DEVICE = "cpu"
print(f"Device: {DEVICE} | pyloudnorm: {HAVE_PYL} | ffmpeg: {HAVE_FFMPEG} | sklearn: {HAVE_SK}")
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED); 
if DEVICE == "cuda":
    torch.cuda.manual_seed_all(SEED)


Device: cuda | pyloudnorm: True | ffmpeg: True | sklearn: True


In [34]:
import librosa

try:
    import soundfile as sf
    HAVE_SF = True
except Exception:
    HAVE_SF = False


In [26]:
# from pathlib import Path

# from pathlib import Path
# import numpy as np
# import torch

# def load_audio(path, sr=SR):
#     """
#     Robust loader that avoids TorchCodec crashes:
#     1) For WAV/AIFF/FLAC: try soundfile (libsndfile).
#     2) For everything (incl. MP3/M4A): try librosa (audioread/ffmpeg).
#     3) Last resort: torchaudio.load.
#     Returns mono float tensor at target SR.
#     """
#     ext = Path(path).suffix.lower()

#     # 1) Prefer soundfile for lossless containers
#     if HAVE_SF and ext in [".wav", ".aiff", ".aif", ".flac", ".ogg"]:
#         try:
#             y, file_sr = sf.read(str(path), dtype="float32", always_2d=False)
#             if y.ndim == 2:
#                 y = y.mean(axis=1)
#             wav_t = torch.from_numpy(y)
#             if file_sr != SR:
#                 wav_t = torchaudio.functional.resample(wav_t.unsqueeze(0), file_sr, SR).squeeze(0)
#             return wav_t.contiguous()
#         except Exception:
#             pass

#     # 2) Librosa for anything (mp3, m4a, wav, etc.)
#     try:
#         y, file_sr = librosa.load(str(path), sr=None, mono=True)
#         y = y.astype(np.float32, copy=False)
#         wav_t = torch.from_numpy(y)
#         if file_sr != SR:
#             wav_t = torchaudio.functional.resample(wav_t.unsqueeze(0), file_sr, SR).squeeze(0)
#         return wav_t.contiguous()
#     except Exception:
#         pass

#     # 3) Fallback to torchaudio (may require TorchCodec)
#     wav, file_sr = torchaudio.load(str(path))  # [C, T]
#     wav_t = wav.mean(dim=0)
#     if file_sr != SR:
#         wav_t = torchaudio.functional.resample(wav_t.unsqueeze(0), file_sr, SR).squeeze(0)
#     return wav_t.contiguous()

# def lufs_normalize(wav_t, sr=SR, target_lufs=TARGET_LUFS):
#     if HAVE_PYL:
#         y = wav_t.detach().cpu().numpy().astype(np.float32)
#         meter = pyln.Meter(sr)
#         try:
#             loud = meter.integrated_loudness(y)
#             gain_db = target_lufs - loud
#             gain = 10 ** (gain_db / 20.0)
#             y = np.clip(y * gain, -1.0, 1.0)
#             return torch.from_numpy(y)
#         except Exception:
#             pass
#     # Fallback: simple peak normalization
#     peak = wav_t.abs().max().item()
#     if peak > 0:
#         wav_t = wav_t / peak
#     return wav_t

# def mp3_roundtrip(wav_t, sr=SR, bitrate=MP3_BITRATE):
#     if not HAVE_FFMPEG:
#         return wav_t
#     try:
#         with tempfile.TemporaryDirectory() as td:
#             wav_path = os.path.join(td, "tmp.wav")
#             mp3_path = os.path.join(td, "tmp.mp3")
#             torchaudio.save(wav_path, wav_t.unsqueeze(0), sr)
#             cmd = ["ffmpeg", "-y", "-hide_banner", "-loglevel", "error",
#                    "-i", wav_path, "-b:a", bitrate, mp3_path]
#             res = subprocess.run(cmd, check=False)
#             if res.returncode != 0 or not os.path.exists(mp3_path):
#                 # fall back gracefully
#                 return wav_t
#             wav2, sr2 = torchaudio.load(mp3_path)
#             if sr2 != sr:
#                 wav2 = torchaudio.functional.resample(wav2, sr2, sr)
#             return torch.mean(wav2, dim=0)
#     except Exception:
#         return wav_t

# # Log-mel transforms
# _mel = MelSpectrogram(
#     sample_rate=SR, n_fft=N_FFT, hop_length=HOP,
#     n_mels=N_MELS, f_min=FMIN, f_max=FMAX, center=True, power=2.0
# )
# _to_db = AmplitudeToDB(stype="power")

# def to_logmel(wav_t):
#     # wav_t: [T] float32/float64
#     x = wav_t.unsqueeze(0)  # [1, T]
#     mel = _mel(x)           # [1, mels, frames]
#     mel_db = _to_db(mel)    # [1, mels, frames]
#     # min-max scale to [-1, 1] per-sample
#     m = mel_db.amin(dim=(1,2), keepdim=True)
#     M = mel_db.amax(dim=(1,2), keepdim=True)
#     mel_n = (mel_db - m) / (M - m + 1e-9)
#     mel_n = mel_n * 2.0 - 1.0
#     return mel_n   # [1, mels, frames] range ~ [-1,1]

# def fix_frames(spec_1mT, frames=FRAMES):
#     # spec_1mT: [1, mels, T]
#     T = spec_1mT.shape[-1]
#     if T == frames:
#         return spec_1mT
#     spec = F.interpolate(spec_1mT.unsqueeze(0), size=(N_MELS, frames), mode="bilinear", align_corners=False)
#     return spec.squeeze(0)  # [1, mels, frames]

# def preprocess_file(path, codec_parity=False):
#     wav = load_audio(path, sr=SR)
#     if codec_parity:
#         wav = mp3_roundtrip(wav, sr=SR)
#     wav = lufs_normalize(wav, sr=SR, target_lufs=TARGET_LUFS)
#     # Extract a center window of ~WIN_SECS before mel (or pad if short)
#     N_SAMPLES = int(SR * WIN_SECS)
#     if wav.numel() < N_SAMPLES:
#         wav = F.pad(wav, (0, N_SAMPLES - wav.numel()))
#     else:
#         start = (wav.numel() - N_SAMPLES) // 2
#         wav = wav[start:start+N_SAMPLES]
#     mel = to_logmel(wav)          # [1, 128, T']
#     mel = fix_frames(mel, FRAMES) # [1, 128, 256]
#     return mel


In [35]:
import librosa

EPS = 1e-7


def wav_to_logmel(y: np.ndarray) -> np.ndarray:
    S = librosa.feature.melspectrogram(
        y=y, sr=SR, n_fft=N_FFT, hop_length=HOP, win_length=WIN,
        n_mels=N_MELS, power=2.0, fmin=FMIN, fmax=FMAX,
        )
    S = np.maximum(S, EPS)
    logS = np.log(S)
    return logS # natural log of power mel


def logmel_to_wav(logS: np.ndarray, length: int | None = None) -> np.ndarray:
    S = np.exp(logS) # back to power mel
    y = librosa.feature.inverse.mel_to_audio(
        M=S, sr=SR, n_fft=N_FFT, hop_length=HOP, win_length=WIN,
        fmin=FMIN, fmax=FMAX, power=2.0, n_iter=64
        )
    if length is not None and len(y) > length:
        y = y[:length]
    return y

In [36]:
class RealMelDataset(Dataset):
    def __init__(self, root: str, frames: int = FRAMES):
        exts = (".wav", ".flac", ".mp3", ".ogg", ".m4a", ".aiff", ".aif")
        self.self.files = [p for p in glob.glob(os.path.join(root, "**", "*"), recursive=True)
              if p.lower().endswith(exts)]
        self.frames = frames
        self._index = []
        for p in self.files:
            try:
                info = sf.info(p)
                if info.samplerate <= 0 or info.frames < SR * 2:
                    continue
                self._index.append(p)
            except Exception:
                continue
        assert len(self._index) > 0, f"No valid audio found under {root}"

    def __len__(self):
        return len(self._index)

    def __getitem__(self, idx):
        path = self._index[idx]
        y, sr = librosa.load(path, sr=SR, mono=True)
        if len(y) < (self.frames+2)*HOP:
            # pad if short
            pad = (self.frames+2)*HOP - len(y)
            y = np.pad(y, (0, pad))
        # random crop (around frames)
        max_start = max(0, len(y) - (self.frames+2)*HOP)
        start = 0 if max_start == 0 else np.random.randint(0, max_start+1)
        y = y[start:start + (self.frames*HOP + WIN)]
        logmel = wav_to_logmel(y)
        # ensure exact frame length
        if logmel.shape[1] < self.frames:
            pad_w = self.frames - logmel.shape[1]
            logmel = np.pad(logmel, ((0,0),(0,pad_w)), mode='edge')
        logmel = logmel[:, :self.frames]
        x = torch.from_numpy(logmel).float().unsqueeze(0)  # [1, n_mels, frames]
        return x



In [37]:
# ----------------------------
# Models: simple conv discriminator & generator on mel space
# ----------------------------
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        ch = [1, 32, 64, 128, 256]
        self.net = nn.Sequential(
            nn.Conv2d(ch[0], ch[1], 4, 2, 1), nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ch[1], ch[2], 4, 2, 1), nn.InstanceNorm2d(ch[2], affine=True), nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ch[2], ch[3], 4, 2, 1), nn.InstanceNorm2d(ch[3], affine=True), nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ch[3], ch[4], 4, 2, 1), nn.InstanceNorm2d(ch[4], affine=True), nn.LeakyReLU(0.2, inplace=True),
        )
        # project to scalar
        self.out = nn.Sequential(
            nn.Conv2d(ch[4], 1, kernel_size=4, stride=1, padding=0),
        )

    def forward(self, x):
        h = self.net(x)
        # adaptive pooling to handle variable dims (if frames change)
        h = F.adaptive_avg_pool2d(h, (1,1))
        score = self.out(h).view(x.size(0))
        return score

class Generator(nn.Module):
    def __init__(self, z_dim=128, out_h=N_MELS, out_w=FRAMES):
        super().__init__()
        self.out_h = out_h
        self.out_w = out_w
        base = 256
        self.fc = nn.Sequential(
            nn.Linear(z_dim, base*8), nn.ReLU(True),
            nn.Linear(base*8, base*16), nn.ReLU(True),
        )
        # reshape to a small 2D map and upsample with convTranspose
        self.start_h, self.start_w = out_h//16, out_w//16
        self.proj = nn.Linear(base*16, 256 * self.start_h * self.start_w)
        self.up = nn.Sequential(
            nn.ConvTranspose2d(256, 128, 4, 2, 1), nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, 4, 2, 1), nn.ReLU(True),
            nn.ConvTranspose2d(64, 32, 4, 2, 1), nn.ReLU(True),
            nn.ConvTranspose2d(32, 16, 4, 2, 1), nn.ReLU(True),
            nn.Conv2d(16, 1, 3, 1, 1),
        )

    def forward(self, z):
        h = self.fc(z)
        h = self.proj(h)
        h = h.view(z.size(0), 256, self.start_h, self.start_w)
        x = self.up(h)
        # output unconstrained log-mel; optionally clamp for stability
        return x


In [38]:
# ----------------------------
# WGAN-GP training (REAL only)
# ----------------------------

def gradient_penalty(D, real, fake):
    bsz = real.size(0)
    eps = torch.rand(bsz, 1, 1, 1, device=real.device)
    inter = eps * real + (1 - eps) * fake
    inter.requires_grad_(True)
    d_inter = D(inter)
    grads = torch.autograd.grad(
        outputs=d_inter, inputs=inter,
        grad_outputs=torch.ones_like(d_inter),
        create_graph=True, retain_graph=True, only_inputs=True
    )[0]
    grads = grads.view(bsz, -1)
    gp = ((grads.norm(2, dim=1) - 1.0) ** 2).mean()
    return gp

@torch.no_grad()
def sample_mels(G, n=8, z_dim=128):
    G.eval()
    z = torch.randn(n, z_dim, device=DEVICE)
    mels = G(z).cpu().numpy()  # [n,1,N_MELS,FRAMES]
    return mels

@torch.no_grad()
def save_audio_samples(G, step_tag: str, n=4, z_dim=128):
    mels = sample_mels(G, n=n, z_dim=z_dim)
    for i, mel in enumerate(mels):
        logmel = mel[0]
        y = logmel_to_wav(logmel)
        y = np.clip(y, -1.0, 1.0)
        sf.write(f"{SAMPLES_DIR}/sample_{step_tag}_{i:02d}.wav", y, SR)


def train_real_only_wgan_gp():
    torch.manual_seed(SEED); np.random.seed(SEED); random.seed(SEED)
    ds = RealMelDataset(REAL_AUDIO_DIR, frames=FRAMES)
    dl = DataLoader(ds, batch_size=BATCH, shuffle=True, num_workers=2, drop_last=True)

    D = Discriminator().to(DEVICE)
    G = Generator(z_dim=128, out_h=N_MELS, out_w=FRAMES).to(DEVICE)
    optD = torch.optim.Adam(D.parameters(), lr=LR_D, betas=BETAS)
    optG = torch.optim.Adam(G.parameters(), lr=LR_G, betas=BETAS)

    global_step = 0
    for epoch in range(1, EPOCHS+1):
        pbar = tqdm(dl, desc=f"Epoch {epoch}/{EPOCHS}")
        for real in pbar:
            real = real.to(DEVICE)
            bsz = real.size(0)
            # --------------------
            # (1) Update D
            # --------------------
            for _ in range(N_CRITIC):
                z = torch.randn(bsz, 128, device=DEVICE)
                fake = G(z).detach()
                d_real = D(real)
                d_fake = D(fake)
                gp = gradient_penalty(D, real, fake)
                lossD = (d_fake - d_real).mean() + LAMBDA_GP * gp
                optD.zero_grad(set_to_none=True)
                lossD.backward()
                optD.step()

            # --------------------
            # (2) Update G
            # --------------------
            z = torch.randn(bsz, 128, device=DEVICE)
            fake = G(z)
            lossG = -D(fake).mean()
            optG.zero_grad(set_to_none=True)
            lossG.backward()
            optG.step()

            global_step += 1
            if global_step % 100 == 0:
                pbar.set_postfix({"lossD": float(lossD.item()), "lossG": float(lossG.item())})
            if global_step % 1000 == 0:
                save_audio_samples(G, step_tag=f"e{epoch}_s{global_step}", n=4)
        # end epoch
        torch.save(G.state_dict(), f"{CKPT_DIR}/G_epoch{epoch:03d}.pt")
        torch.save(D.state_dict(), f"{CKPT_DIR}/D_epoch{epoch:03d}.pt")
        save_audio_samples(G, step_tag=f"e{epoch}", n=8)
    return G, D

# %% [markdown]
# ---- Quickstart ----
# 1) Put your real audio under REAL_AUDIO_DIR (any extension supported by soundfile/librosa)
# 2) Run: G, D = train_real_only_wgan_gp()
# 3) Generate new music snippets at any time: save_audio_samples(G, step_tag="manual", n=8)
#    Files will appear under SAMPLES_DIR.

# %%
if __name__ == "__main__":
    print(f"Device: {DEVICE}")
    print(f"Training windows: ~{WINDOW_SEC:.2f}s | Mel: {N_MELS}x{FRAMES}")
    # Uncomment to train from script:
    # G, D = train_real_only_wgan_gp()
    pass


Device: cuda
Training windows: ~2.97s | Mel: 128x256


In [39]:
G, D = train_real_only_wgan_gp()
save_audio_samples(G, step_tag="manual", n=8)

AttributeError: 'RealMelDataset' object has no attribute 'self'

In [None]:

# def spectral_conv2d(in_ch, out_ch, k, s, p):
#     return nn.utils.spectral_norm(nn.Conv2d(in_ch, out_ch, k, s, p))

# class Discriminator(nn.Module):
#     """PatchGAN-style: accepts [B,1,128,256] and outputs [B,1] scalar scores via global pooling."""
#     def __init__(self, use_spectral_norm=True):
#         super().__init__()
#         Conv = spectral_conv2d if use_spectral_norm else nn.Conv2d
#         chs = [1, 32, 64, 128, 256, 256]
#         self.net = nn.Sequential(
#             Conv(chs[0], chs[1],  (3,3), (2,2), (1,1)), nn.LeakyReLU(0.2, inplace=True),
#             Conv(chs[1], chs[2],  (3,3), (2,2), (1,1)), nn.LeakyReLU(0.2, inplace=True),
#             Conv(chs[2], chs[3],  (3,3), (2,2), (1,1)), nn.LeakyReLU(0.2, inplace=True),
#             Conv(chs[3], chs[4],  (3,3), (2,2), (1,1)), nn.LeakyReLU(0.2, inplace=True),
#             Conv(chs[4], chs[5],  (3,3), (2,2), (1,1)), nn.LeakyReLU(0.2, inplace=True),
#         )
#         self.head = nn.Linear(chs[5], 1)

#     def forward(self, x):
#         # x: [B,1,128,256]
#         feat = self.net(x)                     # [B,C,h,w]
#         feat = feat.mean(dim=(2,3))            # global average pool -> [B,C]
#         out = self.head(feat)                  # [B,1]
#         return out

# class MelEncoder(nn.Module):
#     """Encode [B,1,128,256] -> [B,256,8,16] feature map (to match G's bottleneck)."""
#     def __init__(self):
#         super().__init__()
#         chs = [1, 32, 64, 128, 256]
#         self.net = nn.Sequential(
#             nn.Conv2d(chs[0], chs[1], 3, 2, 1), nn.LeakyReLU(0.2, inplace=True),
#             nn.Conv2d(chs[1], chs[2], 3, 2, 1), nn.LeakyReLU(0.2, inplace=True),
#             nn.Conv2d(chs[2], chs[3], 3, 2, 1), nn.LeakyReLU(0.2, inplace=True),
#             nn.Conv2d(chs[3], chs[4], 3, 2, 1), nn.LeakyReLU(0.2, inplace=True),
#         )
#     def forward(self, x):  # x: [B,1,128,256]
#         return self.net(x)  # [B,256,8,16]

# class CondGenerator(nn.Module):
#     """
#     Conditional G: takes encoder(feature_of_real) + noise; outputs [B,1,128,256] mel.
#     """
#     def __init__(self, z_dim=128):
#         super().__init__()
#         self.z_dim = z_dim
#         self.enc_proj = nn.Conv2d(256, 256, 1)
#         self.z_fc = nn.Linear(z_dim, 256*8*16)

#         def block(in_ch, out_ch):
#             return nn.Sequential(
#                 nn.ConvTranspose2d(in_ch, out_ch, 4, 2, 1),
#                 nn.BatchNorm2d(out_ch),
#                 nn.ReLU(inplace=True)
#             )

#         self.up = nn.Sequential(
#             block(512, 256),   # (enc 256 + noise 256) -> 256
#             block(256, 128),
#             block(128, 64),
#             block(64, 32),
#             nn.ConvTranspose2d(32, 1, 3, 1, 1),
#             nn.Tanh()
#         )

#     def forward(self, enc_feat, z):
#         # enc_feat: [B,256,8,16], z: [B, z_dim]
#         enc = self.enc_proj(enc_feat)
#         z = self.z_fc(z).view(z.size(0), 256, 8, 16)
#         x = torch.cat([enc, z], dim=1)  # [B,512,8,16]
#         x = self.up(x)                  # [B,1,~128,~256]
#         x = F.interpolate(x, size=(N_MELS, FRAMES), mode="bilinear", align_corners=False)
#         return x


In [None]:

# def grad_penalty(D, real, fake, device):
#     B = real.size(0)
#     eps = torch.rand(B, 1, 1, 1, device=device)
#     x_hat = eps * real + (1 - eps) * fake
#     x_hat.requires_grad_(True)
#     d_hat = D(x_hat)
#     ones = torch.ones_like(d_hat, device=device)
#     grads = torch.autograd.grad(
#         outputs=d_hat, inputs=x_hat, grad_outputs=ones,
#         create_graph=True, retain_graph=True, only_inputs=True
#     )[0]
#     gp = ((grads.view(B, -1).norm(2, dim=1) - 1.0) ** 2).mean()
#     return gp

# from tqdm.auto import tqdm

# def train_wgan_gp(train_loader, epochs=EPOCHS, z_dim=Z_DIM, device=DEVICE):
#     if train_loader is None:
#         print("No training data. Populate REAL_WAV_DIR with .wav files and rerun.")
#         return None, None

#     E = MelEncoder().to(device)
#     G = CondGenerator(z_dim=Z_DIM).to(device)
#     D = Discriminator().to(device)

#     optD = torch.optim.Adam(D.parameters(), lr=LR_D, betas=(BETA1, BETA2))
#     optG = torch.optim.Adam(G.parameters(), lr=LR_G, betas=(BETA1, BETA2))

#     global_step = 0
#     for epoch in range(1, epochs + 1):
#         pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{epochs}", leave=True)
#         for mel in pbar:
#             mel = mel.to(device)
#             B = mel.size(0)
#             z = torch.randn(B, Z_DIM, device=device)

#             with torch.no_grad():
#                 enc = E(mel)

#             fake = G(enc, z)

#             # ===== Train D (N_CRITIC steps) =====
#             loss_D_val = 0.0
#             for _ in range(N_CRITIC):
#                 z = torch.randn(B, z_dim, device=device)
#                 fake = G(z).detach()
#                 d_real = D(mel).mean()
#                 d_fake = D(fake.detach()).mean()
#                 gp = grad_penalty(D, mel, fake.detach(), device)
#                 loss_D = (d_fake - d_real) + LAMBDA_GP * gp

#                 optD.zero_grad(set_to_none=True)
#                 loss_D.backward()
#                 optD.step()
#                 loss_D_val = loss_D.item()

#             # ===== Train G (1 step) =====
#             z = torch.randn(B, z_dim, device=device)
#             fake = G(z)
#             loss_G = -D(fake).mean()
#             optG.zero_grad(set_to_none=True)
#             loss_G.backward()
#             optG.step()
#             d_fake_for_G = D(fake).mean()
#             lambda_rec = 10.0  # tweak 1–20; higher = closer to input, lower = freer stylization
#             rec_loss = F.l1_loss(fake, mel)
#             loss_G = -d_fake_for_G + lambda_rec * rec_loss

#             optG.zero_grad(set_to_none=True)
#             loss_G.backward()
#             optG.step()


#             # tqdm stats
#             pbar.set_postfix({
#                 "D": f"{loss_D_val:.3f}",
#                 "G": f"{loss_G.item():.3f}",
#                 "d_real": f"{d_real.item():.3f}",
#                 "d_fake": f"{d_fake.item():.3f}"
#             })
#             global_step += 1

#     return G, D
# # Quick dry-run (skips if no data)
# # G, D = train_wgan_gp(train_loader, epochs=1)
# # torch.save(D.state_dict(), "D.pth"); torch.save(G.state_dict(), "G.pth")


  from .autonotebook import tqdm as notebook_tqdm


In [None]:

# @torch.no_grad()
# def score_windows_with_D(D, mels, device=DEVICE):
#     D.eval()
#     scores = []
#     for i in range(0, len(mels), 16):
#         batch = torch.stack(mels[i:i+16], dim=0).to(device)  # [B,1,128,256]
#         s = D(batch).squeeze(1).detach().cpu().numpy()
#         scores.extend(s.tolist())
#     return float(np.mean(scores)), scores

# def slice_track_to_mels(path, codec_parity_for_real=False, step_secs=10.0):
#     """Return list of [1,128,256] mel windows for an entire track.
#     - For real WAVs, pass codec_parity_for_real=True to round-trip MP3.
#     - For AI MP3s, leave False (already MP3).
#     """
#     wav = load_audio(path, sr=SR)
#     if codec_parity_for_real:
#         wav = mp3_roundtrip(wav, sr=SR)
#     wav = lufs_normalize(wav, sr=SR, target_lufs=TARGET_LUFS)

#     N = int(SR * WIN_SECS)
#     step = int(SR * step_secs)
#     if wav.numel() < N:
#         wav = F.pad(wav, (0, N - wav.numel()))

#     mels = []
#     for start in range(0, max(1, wav.numel() - N + 1), step):
#         seg = wav[start:start+N]
#         mel = to_logmel(seg)
#         mel = fix_frames(mel, FRAMES)
#         mels.append(mel)
#     return mels  # list of [1,128,256]

# def score_folder(D, folder, exts, real_folder=False, max_files=None):
#     files = sorted([p for ext in exts for p in Path(folder).rglob(f"*{ext}")])
#     if max_files is not None:
#         files = files[:max_files]

#     kept_files, kept_scores = [], []
#     for p in files:
#         try:
#             codec_parity = real_folder  # True for real WAVs, False for AI MP3s
#             mels = slice_track_to_mels(p, codec_parity_for_real=codec_parity, step_secs=WIN_SECS)
#             if not mels:
#                 continue
#             mean_s, _ = score_windows_with_D(D, mels)
#             kept_files.append(p)
#             kept_scores.append(mean_s)
#         except Exception as e:
#             # Skip problematic file; keep evaluation running
#             print(f"[skip] {p} -> {e}")
#             continue

#     return kept_files, np.array(kept_scores, dtype=np.float32)


In [None]:

# def evaluate_discriminator(D, real_dir=REAL_WAV_DIR, ai_dir=AI_MP3_DIR, max_files=MAX_EVAL_FILES_PER_CLASS):
#     if not HAVE_SK:
#         print("scikit-learn not found — skipping ROC-AUC/PR-AUC. You can 'pip install scikit-learn' and rerun.")
#         return None

#     real_files, real_scores = score_folder(D, real_dir, [".wav", ".WAV"], real_folder=True, max_files=max_files)
#     ai_files,   ai_scores   = score_folder(D, ai_dir,   [".mp3", ".MP3"], real_folder=False, max_files=max_files)

#     y_true = np.array([1]*len(real_scores) + [0]*len(ai_scores), dtype=np.int32)
#     y_pred = np.concatenate([real_scores, ai_scores], axis=0)

#     roc = roc_auc_score(y_true, y_pred)
#     pr  = average_precision_score(y_true, y_pred)
#     print(f"ROC-AUC: {roc:.4f} | PR-AUC: {pr:.4f}")
#     return {
#         "roc_auc": float(roc),
#         "pr_auc": float(pr),
#         "real_scores": real_scores,
#         "ai_scores": ai_scores,
#         "real_files": [str(p) for p in real_files],
#         "ai_files": [str(p) for p in ai_files],
#     }



## Quickstart

1. Put your data here (or change the config):
   - `data/real_wav/**.wav`
   - `data/ai_mp3/**.mp3`

2. Run training:


In [None]:

# # Train the GAN on REAL only (WGAN-GP). Increase EPOCHS later.
# G, D = train_wgan_gp(train_loader, epochs=EPOCHS)

# # Save checkpoints
# if D is not None and G is not None:
#     torch.save(D.state_dict(), "D.pth")
#     torch.save(G.state_dict(), "G.pth")
#     print("Saved D.pth and G.pth")


Epoch 1/5:   0%|          | 0/124 [00:02<?, ?it/s]


TypeError: CondGenerator.forward() missing 1 required positional argument: 'z'


3. Score folders and compute metrics:


In [None]:

# if 'D' in globals() and D is not None:
#     _ = evaluate_discriminator(D, real_dir=REAL_WAV_DIR, ai_dir=AI_MP3_DIR, max_files=MAX_EVAL_FILES_PER_CLASS)
# else:
#     print("Train (or load) a Discriminator first.")


  from pkg_resources import resource_filename


ROC-AUC: 0.1468 | PR-AUC: 0.3437



### Load Discriminator later & score single files


In [None]:

# def load_discriminator(path="D.pth", device=DEVICE):
#     D = Discriminator(use_spectral_norm=True).to(device)
#     sd = torch.load(path, map_location=device)
#     D.load_state_dict(sd)
#     D.eval()
#     return D

# # Example single-file scoring (edit the paths):
# D = load_discriminator("D.pth")
# mels = slice_track_to_mels("data/AI_audio/-0Gj8-vB1q4_1.mp3", codec_parity_for_real=False, step_secs=WIN_SECS)
# mean_score, window_scores = score_windows_with_D(D, mels)
# print("Mean realism score:", mean_score)


Mean realism score: -14.58976697921753



## Notes & Tips

- **Codec parity matters**: we round-trip real WAVs through MP3 (192 kbps) during training and when *scoring* real files.  
  If `ffmpeg` isn't available, the notebook will silently skip parity (you can install ffmpeg to enable it).

- **Loudness normalization**: LUFS if `pyloudnorm` is installed, else peak-normalization fallback.

- **Input shape fixed**: mel specs resized to `[1, 128, 256]`. You can increase `FRAMES` (e.g., 512) if your GPU has room.

- **Stability**: We use **WGAN-GP** and **spectral norm** on the discriminator.

- **Sanity check**: Before full training, try a few batches to ensure losses move and `d_real > d_fake` early on.

- **Evaluation**: The **mean discriminator score** per track is your realism score. With scikit-learn installed, we report **ROC-AUC** and **PR-AUC**.

- **Next steps**: try **PCEN** instead of dB, add light augmentations, or move to longer windows, then compare metrics.
