In [1]:
"""
Stable end-to-end pipeline:  WAV → log-Mel → global-norm → CRNN
----------------------------------------------------------------
• Works with librosa ≥ 0.10
• No warnings, no -inf / +inf, no OverflowError (silent clips OK)
• Folder layout:
      cleaned_dataset/
          train/{questions,others}/*.wav
          test/{questions,others}/*.wav
----------------------------------------------------------------
"""

import os, random, warnings, math, numpy as np, torch, librosa
from pathlib import Path
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import accuracy_score, f1_score
import torch.nn as nn

# ─── hyper-params & reproducibility ──────────────────────────────────────
# Dataset root; list_wavs() reads <ROOT>/<split>/{others,questions}.
ROOT = Path("cleaned_dataset")
SR = 16_000  # sample rate (Hz) passed to librosa.load; clips are resampled to it
N_FFT = 1024  # FFT size in samples; log_mel() also pads shorter clips up to this length
HOP = 256  # hop length in samples between spectrogram frames
N_MELS = 128  # number of Mel bands (height of every feature map)
TOP_DB = 80.0  # dynamic-range clamp (dB) used by power_to_db and the inf replacement
BATCH_SIZE = 32  # clips per mini-batch for both the train and test loaders
EPOCHS = 20  # number of train/eval passes performed by main()
LR = 1e-3  # learning rate handed to Adam
# Prefer the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SEED = 42
# Seed the three RNGs this script relies on (stdlib, NumPy, PyTorch) at import time.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)


# ─── robust helpers ──────────────────────────────────────────────────────
def safe_load(path: Path) -> np.ndarray:
    """Load ``path`` as a mono float32 waveform resampled to ``SR``.

    If the file decodes to zero samples, 100 ms of silence is returned
    instead, so the result is never empty.
    """
    signal, _rate = librosa.load(str(path), sr=SR, mono=True)
    if signal.size > 0:
        return signal.astype(np.float32)
    # Nothing decoded: substitute 0.1 s of zeros at the target rate.
    return np.zeros(int(0.1 * SR), dtype=np.float32)


def log_mel(y: np.ndarray) -> np.ndarray:
    """Turn a waveform into a float32 log-Mel spectrogram in dB.

    The result has shape ``[n_mels, T]``. Clips shorter than one FFT
    window are zero-padded on the right to ``N_FFT`` samples first. The
    dB scale is taken relative to the clip's own maximum (``ref=np.max``)
    and limited to ``TOP_DB`` of dynamic range; any infinities that
    survive are replaced so the output contains finite values only.
    """
    shortfall = N_FFT - y.size
    if shortfall > 0:
        y = np.pad(y, (0, shortfall), mode="constant")

    mel_power = librosa.feature.melspectrogram(
        y=y,
        sr=SR,
        n_fft=N_FFT,
        hop_length=HOP,
        n_mels=N_MELS,
        power=2.0,
        center=True,
    )

    # amin keeps log10 away from zero; top_db clamps the floor of the range.
    mel_db = librosa.power_to_db(mel_power, ref=np.max, amin=1e-10, top_db=TOP_DB)
    # Belt and braces: map -inf to the floor and +inf to the 0 dB ceiling.
    mel_db = np.nan_to_num(mel_db, neginf=-TOP_DB, posinf=0.0)
    return mel_db.astype(np.float32)


# ─── dataset & dataloader ────────────────────────────────────────────────
def list_wavs(split: str, root=None):
    """Collect ``(wav_path, label)`` pairs for one dataset split.

    Files under ``<root>/<split>/others`` get label 0 and files under
    ``<root>/<split>/questions`` get label 1; both folders are searched
    recursively for ``*.wav``.

    Args:
        split: Sub-folder name of the split, e.g. ``"train"`` or ``"test"``.
        root: Optional dataset root. Defaults to the module-level ``ROOT``.

    Returns:
        A list of ``(Path, int)`` tuples, shuffled with the ``random``
        module's global generator.
    """
    base = (ROOT if root is None else Path(root)) / split
    pairs = []
    for label, folder in enumerate(("others", "questions")):
        # rglob yields entries in filesystem order, which differs between
        # machines. Sorting first makes the seeded shuffle below reproducible.
        wavs = sorted((base / folder).rglob("*.wav"))
        pairs.extend((wav, label) for wav in wavs)
    random.shuffle(pairs)
    return pairs


def compute_global_cmvn(paths):
    """Compute the global mean and standard deviation of log-Mel values.

    Every file in ``paths`` is loaded and converted with ``log_mel``; the
    statistics are taken over all time-frequency bins of all files.

    Args:
        paths: Iterable of ``(wav_path, label)`` pairs; the label is ignored.

    Returns:
        ``(mean, std)`` as plain Python floats.

    Raises:
        ValueError: If ``paths`` yields no data, so no statistics exist.
    """
    total, total_sq, count = 0.0, 0.0, 0
    for wav_path, _label in paths:
        # Accumulate in float64: the features are float32, and summing many
        # values (and their squares) at that precision loses digits fast.
        m = log_mel(safe_load(wav_path)).astype(np.float64)
        total += float(m.sum())
        total_sq += float((m * m).sum())
        count += m.size
    if count == 0:
        raise ValueError("compute_global_cmvn: no audio frames found (empty file list?)")
    mean = total / count
    # E[x^2] - E[x]^2 can dip slightly below zero through cancellation;
    # clamp so math.sqrt never sees a negative argument.
    variance = max(total_sq / count - mean * mean, 0.0)
    std = math.sqrt(variance + 1e-12)
    return mean, std


class MelDataset(Dataset):
    """Map-style dataset yielding normalised log-Mel features and labels.

    Features are computed from the WAV file on every access and then
    standardised with the supplied global ``mean`` and ``std``.
    """

    def __init__(self, file_label_pairs, mean, std):
        # Sequence of (path, label) pairs plus the normalisation statistics.
        self.items = file_label_pairs
        self.mean = mean
        self.std = std

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        wav_path, target = self.items[idx]
        feats = log_mel(safe_load(wav_path))
        # Small epsilon guards against a zero standard deviation.
        feats = (feats - self.mean) / (self.std + 1e-6)
        # Add a leading channel axis: (n_mels, T) -> (1, n_mels, T).
        return torch.from_numpy(feats).unsqueeze(0), target


def pad_collate(batch):
    """Collate variable-length feature tensors into one zero-padded batch.

    Args:
        batch: Sequence of ``(features, label)`` pairs. All feature tensors
            must agree on every axis except the last (time) one.

    Returns:
        ``(padded, labels)``. ``padded`` is a float32 tensor of shape
        ``(B, *feature_shape[:-1], T_max)``, with each item copied into the
        start of the time axis and the remainder left at 0.0. ``labels`` is
        an int64 tensor of shape ``(B,)``.
    """
    xs, ys = zip(*batch)
    t_max = max(x.shape[-1] for x in xs)
    # Take channel/frequency sizes from the data, not from a module
    # constant, so the collate function works for any feature layout.
    lead_shape = tuple(xs[0].shape[:-1])
    padded = torch.zeros(len(xs), *lead_shape, t_max, dtype=torch.float32)
    for i, x in enumerate(xs):
        padded[i, ..., : x.shape[-1]] = x
    return padded, torch.tensor(ys, dtype=torch.long)


def get_loaders():
    """Build the train and test ``DataLoader`` objects.

    Normalisation statistics are computed from the training files only
    and shared by both datasets. The training loader reshuffles every
    epoch; the test loader keeps a fixed order.

    Returns:
        ``(train_loader, test_loader)``.
    """
    splits = {name: list_wavs(name) for name in ("train", "test")}
    mean, std = compute_global_cmvn(splits["train"])

    loaders = []
    for name, reshuffle in (("train", True), ("test", False)):
        dataset = MelDataset(splits[name], mean, std)
        loaders.append(
            DataLoader(
                dataset,
                batch_size=BATCH_SIZE,
                shuffle=reshuffle,
                collate_fn=pad_collate,
            )
        )
    return tuple(loaders)


# ─── simple CRNN ─────────────────────────────────────────────────────────
class CRNN(nn.Module):
    """Small convolutional-recurrent classifier over log-Mel spectrograms.

    Two 3x3 conv + ReLU + 2x2 max-pool stages shrink both the frequency
    and time axes by a factor of four. The pooled feature map is then read
    along time by a single-layer LSTM, and the output at the final time
    step is projected to two class logits.

    Args:
        n_mels: Number of Mel bands in the input. Defaults to the
            module-level ``N_MELS``. Must be at least 4 so that something
            remains after the two pooling stages.

    NOTE(review): ``forward`` classifies from the last time step of the
    sequence it is given. If a batch has been right-padded to a common
    length, that step lies in the padding for the shorter clips.
    """

    def __init__(self, n_mels=None):
        super().__init__()
        if n_mels is None:
            n_mels = N_MELS
        if n_mels < 4:
            raise ValueError("n_mels must be >= 4 (two 2x2 poolings are applied)")
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 16, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 32, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        # 32 channels times the pooled frequency bins form one LSTM input frame.
        self.lstm = nn.LSTM(32 * (n_mels // 4), 64, batch_first=True)
        self.fc = nn.Linear(64, 2)

    def forward(self, x):
        """Map ``x`` of shape ``(B, 1, n_mels, T)`` to logits of shape ``(B, 2)``."""
        x = self.cnn(x)  # (B, 32, n_mels // 4, T // 4)
        B, C, M, T = x.shape
        # Put time second and flatten channel x frequency into one feature axis.
        x = x.permute(0, 3, 1, 2).reshape(B, T, C * M)  # (B, T, features)
        h, _ = self.lstm(x)
        return self.fc(h[:, -1])  # final time step -> logits


# ─── training / evaluation ───────────────────────────────────────────────
def run_epoch(model, loader, optim=None):
    """Run one pass over ``loader``, training if an optimizer is given.

    Args:
        model: Network to train or evaluate.
        loader: Iterable of ``(inputs, labels)`` batches.
        optim: Optimizer. When ``None`` the pass is evaluation-only: the
            model is put in eval mode and no gradients are tracked.

    Returns:
        ``(mean_loss, accuracy, macro_f1)`` over all samples seen.
    """
    training = optim is not None
    model.train(training)  # train(False) is the same as eval()
    crit = nn.CrossEntropyLoss()
    y_true, y_pred = [], []
    total_loss, n_seen = 0.0, 0
    # Autograd is only needed when weights are updated; turning it off for
    # evaluation avoids building graphs and saves memory.
    with torch.set_grad_enabled(training):
        for X, y in loader:
            X, y = X.to(DEVICE), y.to(DEVICE)
            out = model(X)
            loss = crit(out, y)
            if training:
                optim.zero_grad()
                loss.backward()
                optim.step()
            batch_size = y.size(0)
            # The criterion returns a batch mean; re-weight by batch size so
            # a smaller final batch does not skew the epoch average.
            total_loss += loss.item() * batch_size
            n_seen += batch_size
            # Plain ints for sklearn, rather than lists of 0-d tensors.
            y_true.extend(y.tolist())
            y_pred.extend(out.argmax(1).tolist())
    acc = accuracy_score(y_true, y_pred)
    f1 = f1_score(y_true, y_pred, average="macro")
    return total_loss / max(n_seen, 1), acc, f1


def main():
    """Train the CRNN for ``EPOCHS`` epochs and keep the best checkpoint.

    After each epoch the train and test accuracy / macro-F1 are printed,
    and the weights are written to ``best_crnn.pt`` whenever the test
    macro-F1 beats the best value so far.

    NOTE(review): the checkpoint is chosen on the test split, so the
    reported best score is not an unbiased estimate.
    """
    train_loader, test_loader = get_loaders()
    model = CRNN().to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=LR)

    best_f1 = 0.0
    for epoch in range(1, EPOCHS + 1):
        # run_epoch returns (loss, acc, f1); the loss is not reported here.
        tr_acc, tr_f1 = run_epoch(model, train_loader, optimizer)[1:]
        te_acc, te_f1 = run_epoch(model, test_loader)[1:]
        print(
            f"ep{epoch:02d}  train {tr_acc:.3f}/{tr_f1:.3f}  "
            f"test {te_acc:.3f}/{te_f1:.3f}"
        )
        if te_f1 <= best_f1:
            continue
        best_f1 = te_f1
        torch.save(model.state_dict(), "best_crnn.pt")


# Run the training loop only when executed as a script, not on import.
if __name__ == "__main__":
    main()

ep01  train 0.599/0.574  test 0.746/0.583
ep02  train 0.733/0.725  test 0.822/0.713
ep03  train 0.795/0.786  test 0.875/0.765
ep04  train 0.818/0.810  test 0.865/0.760
ep05  train 0.828/0.821  test 0.858/0.760
ep06  train 0.851/0.846  test 0.787/0.702
ep07  train 0.864/0.859  test 0.857/0.763
ep08  train 0.871/0.867  test 0.814/0.720
ep09  train 0.893/0.890  test 0.718/0.643
ep10  train 0.901/0.898  test 0.887/0.789
ep11  train 0.917/0.915  test 0.831/0.738
ep12  train 0.927/0.926  test 0.795/0.704
ep13  train 0.937/0.936  test 0.809/0.716
ep14  train 0.946/0.945  test 0.741/0.663
ep15  train 0.944/0.943  test 0.842/0.742
ep16  train 0.963/0.962  test 0.798/0.703
ep17  train 0.963/0.963  test 0.820/0.726
ep18  train 0.964/0.964  test 0.834/0.737
ep19  train 0.967/0.967  test 0.824/0.727
ep20  train 0.972/0.971  test 0.824/0.729
