In [None]:
# %% ------------------------- Cell 1: Imports & Setup ------------------------
import sys
from pathlib import Path
import torch, torch.nn as nn, torch.nn.functional as F
import torchaudio, numpy as np, random
from tqdm.notebook import tqdm

# Seed every random number generator the notebook draws from so that a
# Restart & Run All gives repeatable results.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

# The repository root is taken to be two directories above the current
# working directory; its `src` folder is put first on sys.path so the
# project-local `utils` package can be imported below.
current_dir = Path.cwd()
repo_root = current_dir.parent.parent
sys.path.insert(0, str(repo_root / "src"))
print("Repo root:", repo_root)

# Use the GPU whenever CUDA is available, otherwise fall back to the CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)
print("Using device:", device)

# helper utils
from utils.audio_dataset_loader import (
    load_ears_dataset, load_wham_dataset, load_noizeus_dataset,
    create_audio_pairs, preprocess_audio
)

In [None]:
class TinyGRUVAD(nn.Module):
    """Light, strictly causal GRU-based voice-activity detector.

    Per-frame pipeline: depthwise 3-tap temporal convolution -> LayerNorm over
    the feature axis -> single-layer GRU -> dropout -> linear -> sigmoid.
    With the default sizes (input_dim=32, hidden_dim=16) the model has about
    2.6 k parameters.

    Causality: the convolution is padded on the left only, so the output at
    frame t depends on input frames t-2..t and never on future frames.
    (A symmetric ``padding=1`` would make frame t depend on frame t+1.)
    Parameter shapes and state_dict keys are the same as with symmetric
    padding, but weights trained with symmetric padding see a one-frame shift
    in the conv alignment and should be retrained.

    Args:
        input_dim: number of features per frame.
        hidden_dim: GRU hidden size.
        dropout: dropout probability applied to the GRU outputs.
    """
    def __init__(self, input_dim=32, hidden_dim=16, dropout=0.1):
        super().__init__()
        # Depthwise (groups=input_dim) conv along time. padding=0 because the
        # causal left padding is applied explicitly in forward().
        self.pre = nn.Conv1d(input_dim, input_dim, kernel_size=3, padding=0, groups=input_dim)
        self.norm = nn.LayerNorm(input_dim)
        self.gru = nn.GRU(input_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, 1)
        self.drop = nn.Dropout(dropout)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x, h=None):
        """Return per-frame speech probabilities and the final GRU state.

        Args:
            x: features shaped (B, T, F).
            h: optional initial GRU hidden state; None starts from zeros.

        Returns:
            p: probabilities in [0, 1], shaped (B, T, 1).
            h: final GRU hidden state, to be passed back in for the next chunk.

        Note: only the GRU state is carried between calls. The convolution's
        left context is zero-padded at the start of every call.
        """
        # x: (B,T,F)
        x = x.transpose(1, 2)                       # (B,F,T)
        left_ctx = self.pre.kernel_size[0] - 1      # frames of past context
        x = F.pad(x, (left_ctx, 0))                 # pad the time axis on the left only
        x = self.pre(x).transpose(1, 2)             # local TDNN-like conv, back to (B,T,F)
        x = self.norm(x)
        out, h = self.gru(x, h)
        out = self.drop(out)
        p = self.sigmoid(self.fc(out))
        return p, h


In [None]:
# %% ------------------------- Cell 3: Load Datasets --------------------------
max_pairs = 3000


def _build_split(mode, noise_cap):
    """Load noise and clean file lists for `mode` and pair them.

    Noise comes from `load_wham_dataset` (limited to `noise_cap` files) and
    clean speech from `load_ears_dataset`; `create_audio_pairs` combines them.
    Returns (noise_files, clean_files, pairs).
    """
    noise = load_wham_dataset(repo_root, mode=mode, max_files=noise_cap)
    clean = load_ears_dataset(repo_root, mode=mode)
    return noise, clean, create_audio_pairs(noise, clean)


# Training split: noise list capped at `max_pairs` files.
noise_files, clean_files, train_pairs = _build_split("train", max_pairs)
print(f"Train pairs: {len(train_pairs)}")

# Validation split: noise list capped at 900 files.
noise_val, clean_val, val_pairs = _build_split("validation", 900)
print(f"Val pairs: {len(val_pairs)}")


In [None]:

# %% ------------------------- Cell 4: Feature Extraction ---------------------
def mix_and_extract(noisy_wave, clean_wave, fs, n_bands=16, frame_len=0.008, hop_len=0.004, device=device):
    """Compute log-mel + delta features of the noisy signal and frame-wise VAD labels.

    Args:
        noisy_wave: noisy waveform tensor. If it has more than one dimension,
            only index 0 along the first dimension is used.
        clean_wave: clean reference waveform, handled the same way.
        fs: sample rate in Hz.
        n_bands: number of mel bands.
        frame_len: STFT window length in seconds.
        hop_len: STFT hop in seconds.
        device: device on which the computation runs and the outputs live.

    Returns:
        feats: (1, T, 2*n_bands) tensor holding log-mel energies followed by
            their first-order backward difference (zero for the first frame).
        labels: (1, T, 1) float tensor, 1.0 where the clean mel energy of a
            frame exceeds 20% of the noisy mel energy, else 0.0.
    """
    if noisy_wave.dim() > 1:
        noisy_wave = noisy_wave[0]
    if clean_wave.dim() > 1:
        clean_wave = clean_wave[0]
    n_fft, hop = int(fs * frame_len), int(fs * hop_len)
    win = torch.hann_window(n_fft).to(device)

    def pspec(w):
        # Power spectrogram, shaped (n_fft//2+1, T) for a 1-D waveform.
        spec = torch.stft(w.to(device), n_fft, hop, window=win, return_complex=True)
        return spec.abs() ** 2

    mel = torchaudio.transforms.MelScale(n_mels=n_bands, sample_rate=fs, n_stft=n_fft // 2 + 1).to(device)
    mel_noisy = mel(pspec(noisy_wave))   # (n_bands, T), unfloored
    mel_clean = mel(pspec(clean_wave))

    # Features: the floor (clamp plus the extra epsilon inside the log) is kept
    # exactly as before so feature values do not change for existing models.
    logN = torch.log(mel_noisy.clamp_min(1e-8).T + 1e-8)
    delta = torch.zeros_like(logN)
    delta[1:] = logN[1:] - logN[:-1]     # backward difference only: no look-ahead
    feats = torch.cat([logN, delta], 1).unsqueeze(0)

    # Labels use the UNFLOORED energies. With the 1e-8 floor applied to both
    # signals, a frame that is silent in both gives a ratio of about 0.94 and
    # would be labelled as speech; unfloored, it gives 0 -> non-speech.
    ratio = mel_clean.T.sum(1) / (mel_noisy.T.sum(1) + 1e-8)
    labels = (ratio > 0.2).float().unsqueeze(1).unsqueeze(0)
    return feats.to(device), labels.to(device)


In [None]:

# %% ------------------------- Cell 5: Dataset & Loader -----------------------
class LiveMixDataset(torch.utils.data.Dataset):
    """Dataset that builds a noisy mixture on the fly for each file pair.

    Every access draws a new SNR uniformly from `snr_range`, so the same
    index yields a different mixture each time it is read.

    Args:
        pairs: indexable collection of (noise_path, clean_path) entries.
        target_sr: sample rate handed to `preprocess_audio`.
        snr_range: (low, high) bounds for the random SNR.
        device: device passed on to `mix_and_extract`.
    """

    def __init__(self, pairs, target_sr=16_000, snr_range=(-5, 10), device="cpu"):
        self.pairs = pairs
        self.sr = target_sr
        self.range = snr_range
        self.device = device

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        noise_path, clean_path = self.pairs[idx]
        snr = random.uniform(*self.range)
        # preprocess_audio takes the clean path first, then the noise path.
        # NOTE(review): the 4-tuple is read as (clean, noise, noisy, rate)
        # based on the names used at the call sites; confirm against the helper.
        clean_wave, _noise_wave, noisy_wave, rate = preprocess_audio(
            Path(clean_path), Path(noise_path), self.sr, snr, None
        )
        feats, labs = mix_and_extract(noisy_wave, clean_wave, rate, device=self.device)
        # Drop the leading batch axis; batching is left to the DataLoader.
        return feats.squeeze(0), labs.squeeze(0)

def collate_pad(batch):
    """Zero-pad variable-length (features, labels) items to a common length.

    Args:
        batch: sequence of (feats, labs) pairs, feats shaped (T_i, F) and
            labs shaped (T_i, 1).

    Returns:
        X: (B, Tmax, F) features, zero beyond each item's own length.
        Y: (B, Tmax, 1) labels, zero beyond each item's own length.
        lengths: 1-D tensor with the original T_i of every item.
    """
    feats, labs = zip(*batch)
    lengths = [f.size(0) for f in feats]
    n_items = len(batch)
    t_max = max(lengths)
    n_feat = feats[0].size(1)

    X = torch.zeros(n_items, t_max, n_feat)
    Y = torch.zeros(n_items, t_max, 1)
    for row in range(n_items):
        X[row, :lengths[row]] = feats[row]
        Y[row, :labs[row].size(0)] = labs[row]
    return X, Y, torch.tensor(lengths)

# Training split: one clip per batch, reshuffled every epoch. With
# batch_size=1 the default collate is used, so collate_pad is not wired in.
train_ds = LiveMixDataset(pairs=train_pairs, device=device)
train_dl = torch.utils.data.DataLoader(dataset=train_ds, batch_size=1, shuffle=True)

# Validation split: one clip per batch, fixed order.
val_ds = LiveMixDataset(pairs=val_pairs, device=device)
val_dl = torch.utils.data.DataLoader(dataset=val_ds, batch_size=1, shuffle=False)


In [None]:
# %% ------------------------- Cell 6: Training (clean live view) -------------------------------
from sklearn.metrics import precision_recall_fscore_support, roc_auc_score
from IPython.display import clear_output, display
import matplotlib.pyplot as plt

# Run-level knobs (previously repeated as literals through the cell).
N_EPOCHS = 50     # maximum number of epochs
LR_FLOOR = 1e-5   # stop once the scheduler has driven the LR below this

vad = TinyGRUVAD(32, 16).to(device)
opt = torch.optim.AdamW(vad.parameters(), lr=1e-3, weight_decay=1e-6)
crit = nn.BCELoss()  # TinyGRUVAD.forward already ends in a sigmoid
sched = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, mode="min", factor=0.5, patience=3)

# The checkpoint with the best validation F1 is written here and reloaded
# once training finishes.
bestF1, bestPath = -1.0, repo_root / "models" / "tiny_vad_best.pth"
bestPath.parent.mkdir(parents=True, exist_ok=True)

# prepare live plot (axes are decorated on every redraw, because cla() wipes
# titles, labels and grid)
plt.ion()
fig, (ax_loss, ax_f1) = plt.subplots(1, 2, figsize=(12, 5))
train_losses, val_losses, f1_scores = [], [], []

for epoch in range(1, N_EPOCHS + 1):
    # ---------------- Train ----------------
    vad.train(); Ltr = 0.0
    for x, y in tqdm(train_dl, desc=f"Epoch {epoch}/{N_EPOCHS}", leave=False):
        x, y = x.to(device), y.to(device)
        p, _ = vad(x)
        loss = crit(p, y)
        opt.zero_grad(); loss.backward(); opt.step()
        Ltr += loss.item()
    Ltr /= len(train_dl)
    train_losses.append(Ltr)

    # ---------------- Validate ----------------
    # NOTE(review): val_ds is a LiveMixDataset, which draws a new random SNR on
    # every access, so these metrics are not computed on a fixed set of
    # mixtures from one epoch to the next.
    vad.eval(); Lval = 0.0; P, L = [], []
    with torch.no_grad():
        for x, y in val_dl:
            x, y = x.to(device), y.to(device)
            p, _ = vad(x)
            Lval += crit(p, y).item()
            P.append(p.cpu().numpy().ravel()); L.append(y.cpu().numpy().ravel())
    Lval /= len(val_dl)
    val_losses.append(Lval)

    probs, labels = np.concatenate(P), np.concatenate(L).astype(int)
    pred = (probs >= 0.5).astype(int)
    prec, rec, f1, _ = precision_recall_fscore_support(labels, pred, average="binary", zero_division=0)
    # roc_auc_score raises ValueError when only one class is present in the
    # labels; report NaN instead of aborting the run.
    auc = roc_auc_score(labels, probs) if np.unique(labels).size > 1 else float("nan")
    f1_scores.append(f1)
    sched.step(Lval)  # LR is reduced on a validation-loss plateau

    # ---------------- Save & Display ----------------
    if f1 > bestF1:
        bestF1 = f1
        torch.save(vad.state_dict(), bestPath)
        status = f"[INFO] new best (F1={f1:.3f})"
    else:
        status = ""

    clear_output(wait=True)           # <-- clear console output
    print(f"Epoch {epoch}/{N_EPOCHS}  TL={Ltr:.3f}  VL={Lval:.3f}  F1={f1:.3f}  AUC={auc:.3f}  {status}")

    # update live plots; x values are 1-based epoch numbers
    epochs_done = range(1, len(train_losses) + 1)
    ax_loss.cla(); ax_f1.cla()
    ax_loss.plot(epochs_done, train_losses, "b-", label="Train Loss")
    ax_loss.plot(epochs_done, val_losses, "r-", label="Val Loss")
    ax_loss.set_title("Loss"); ax_loss.set_xlabel("Epoch"); ax_loss.set_ylabel("BCE Loss")
    ax_loss.legend(); ax_loss.grid(True, alpha=0.3)
    ax_f1.plot(epochs_done, f1_scores, "g-", label="Val F1")
    ax_f1.set_title("Validation F1"); ax_f1.set_xlabel("Epoch"); ax_f1.set_ylabel("F1 Score")
    ax_f1.legend(); ax_f1.grid(True, alpha=0.3)
    display(fig)
    plt.pause(0.001)

    # Read the current LR from the optimizer (public API) rather than from the
    # scheduler's private `_last_lr` attribute.
    if opt.param_groups[0]["lr"] < LR_FLOOR:
        print("[INFO] Early stop (triggered by LR floor)")
        break

plt.ioff()
vad.load_state_dict(torch.load(bestPath, map_location=device))
print(f"Loaded best model (F1={bestF1:.3f})")


In [None]:

# %% ------------------------- Cell 8: Quick Validation Plot ------------------
import matplotlib.pyplot as plt
# Take the first validation pair and mix it with 0 passed in the SNR position.
exN, exC = val_pairs[0]
cw, nw, noisy, fs = preprocess_audio(Path(exC), Path(exN), 16_000, 0)

# use 8 ms frames with 4 ms hop for the digital filter timing
f, l = mix_and_extract(noisy, cw, fs, frame_len=0.008, hop_len=0.004)

# Run the model in eval mode without gradient tracking.
vad.eval()
with torch.no_grad():
    p, _ = vad(f.to(device))

# Flatten predictions and labels to 1-D NumPy arrays for plotting.
p = p.squeeze().cpu().numpy()
l = l.squeeze().cpu().numpy()

# Overlay the predicted speech probability on the reference labels.
plt.figure(figsize=(10, 4))
plt.plot(p, label="P(speech)")
plt.plot(l, label="Label", alpha=.6)
plt.legend()
plt.grid()
plt.show()
