In [None]:
pip install jupyterlab ipykernel numpy pandas matplotlib tensorboard soundfile torch torchvision torchaudio

In [None]:
# ==============================
# 0. SETUP: IMPORTS, CONFIG, DEVICE, DIRECTORIES
# ==============================

# If you're in COLAB, uncomment and run this once:
# !pip install soundfile torchaudio --quiet

import os
import random
import time
from pathlib import Path

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

import torchaudio
import soundfile as sf

import numpy as np
from sklearn.metrics import confusion_matrix, classification_report
import matplotlib.pyplot as plt

# ----- Device selection -----
# Preference order: CUDA GPU, then Apple-Silicon MPS, then CPU.
# The MPS backend module is looked up defensively so that a torch build
# without `torch.backends.mps` falls through to CPU instead of raising AttributeError.
_mps_backend = getattr(torch.backends, "mps", None)

if torch.cuda.is_available():
    device = torch.device("cuda")
elif _mps_backend is not None and _mps_backend.is_available():
    device = torch.device("mps")  # Apple Silicon GPU
else:
    device = torch.device("cpu")
print("Using device:", device)

# ----- Reproducibility -----
# Noise mixing and SpecAugment draw from `random`; weight init and DataLoader
# shuffling draw from torch. Seed them all once so a fresh "Run All" is repeatable.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# ----- Global audio / feature config -----
SAMPLE_RATE = 16_000          # 16 kHz standard for speech
CLIP_DURATION_SEC = 1.0       # 1-second clips
CLIP_NUM_SAMPLES = int(SAMPLE_RATE * CLIP_DURATION_SEC)  # 16000 samples

N_MELS = 64       # number of mel filterbanks / MFCC coefficients
N_FFT = 512       # FFT size
HOP_LENGTH = 160  # ~10ms hop
WIN_LENGTH = 400  # ~25ms window

# Commands for 10-class KWS + "unknown"
TARGET_COMMANDS = [
    "yes", "no", "up", "down", "left", "right", "on", "off", "stop", "go"
]
ALL_CLASSES = TARGET_COMMANDS + ["unknown"]     # total 11 classes
NUM_CLASSES = len(ALL_CLASSES)
CLASS_TO_INDEX = {c: i for i, c in enumerate(ALL_CLASSES)}

# ----- Directory layout -----
# Every dataset folder is derived from ROOT_DATA so relocating the data only
# requires changing this one constant.
ROOT_DATA = "./data"
SPEECH_COMMANDS_ROOT = os.path.join(ROOT_DATA, "speech_commands")
NOISE_DIR = os.path.join(ROOT_DATA, "noise")
RAVDESS_ROOT = os.path.join(ROOT_DATA, "ravdess", "audio_speech_actors_01-24")  # adjust if your folder is different

# RAVDESS is not created here: it has to be provided by the user.
os.makedirs(ROOT_DATA, exist_ok=True)
os.makedirs(SPEECH_COMMANDS_ROOT, exist_ok=True)
os.makedirs(NOISE_DIR, exist_ok=True)

print("CWD:", os.getcwd())
print("Data root:", ROOT_DATA)
print("SpeechCommands root:", SPEECH_COMMANDS_ROOT)
print("Noise dir:", NOISE_DIR)
print("RAVDESS dir:", RAVDESS_ROOT)


In [None]:
# ==============================
# 1. BASIC AUDIO UTILITIES + SAFE torchaudio.load
# ==============================

def pad_or_trim(waveform: torch.Tensor, target_num_samples: int) -> torch.Tensor:
    """
    Return a mono waveform that is exactly ``target_num_samples`` long.

    A 1-D input gains a leading channel axis; a 2-D input with several
    channels is averaged down to one. The time axis is then cut at the end
    when too long, or right-padded with zeros when too short.

    waveform: [T], [1, T] or [C, T]
    returns:  [1, target_num_samples]
    """
    # Collapse to a single channel laid out as [1, T]
    if waveform.ndim == 1:
        waveform = waveform.unsqueeze(0)
    elif waveform.ndim == 2 and waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    # Positive -> samples missing, negative -> samples to drop, zero -> already right
    missing = target_num_samples - waveform.shape[-1]
    if missing > 0:
        return torch.nn.functional.pad(waveform, (0, missing))
    if missing < 0:
        return waveform[..., :target_num_samples]
    return waveform


def compute_rms(x: torch.Tensor, eps: float = 1e-8) -> float:
    """
    Root-mean-square level of a waveform, as a Python float.

    ``eps`` is added to the mean square *before* the square root, so the
    result is never exactly zero (an all-zero input yields sqrt(eps)).

    x: [1, T]
    returns: scalar rms float
    """
    mean_square = x.pow(2).mean()
    return mean_square.add(eps).sqrt().item()


def mix_with_noise(clean: torch.Tensor, noise: torch.Tensor, snr_db: float) -> torch.Tensor:
    """
    Mix clean speech with noise at a desired SNR (in dB).

    The noise is reduced to mono, tiled if shorter than the clean clip, and a
    random window of the clean clip's length is cropped from it (one draw from
    the global ``random`` module). The window is scaled so that
    rms(clean) / rms(scaled noise) matches ``snr_db`` and added to ``clean``;
    the sum is clamped to [-1, 1].

    If the noise is empty or the cropped window is completely silent there is
    nothing to mix and ``clean`` itself is returned unchanged.

    clean: [1, T]
    noise: [C, T_noise], [1, T_noise] or [T_noise]
    snr_db: target signal-to-noise ratio in decibels
    returns: [1, T] noisy waveform
    """
    # Ensure mono, laid out as [1, T_noise]
    if noise.ndim == 2 and noise.shape[0] > 1:
        noise = noise.mean(dim=0, keepdim=True)
    elif noise.ndim == 1:
        noise = noise.unsqueeze(0)

    T = clean.shape[-1]
    noise_len = noise.shape[-1]

    # A zero-length recording cannot be tiled (it would divide by zero below).
    if noise_len == 0:
        return clean

    # Ensure noise is long enough. If too short, repeat it.
    if noise_len < T:
        repeats = (T // noise_len) + 1
        noise = noise.repeat(1, repeats)

    # Randomly crop a T-length segment of the noise
    start = random.randint(0, noise.shape[-1] - T)
    noise_segment = noise[..., start:start + T]

    # Degenerate noise (silent) -> return clean.
    # compute_rms() adds eps under the square root and therefore never returns
    # exactly 0, so silence has to be detected on the raw samples.
    if noise_segment.numel() == 0 or noise_segment.abs().max().item() == 0:
        return clean

    # Keep the arithmetic on the clean clip's dtype/device
    noise_segment = noise_segment.to(dtype=clean.dtype, device=clean.device)

    # Compute RMS levels of clean and noise
    rms_clean = compute_rms(clean)
    rms_noise = compute_rms(noise_segment)

    # SNR in dB: SNR = 20 * log10(rms_clean / rms_noise_scaled)
    snr_linear = 10 ** (snr_db / 20.0)
    k = rms_clean / (rms_noise * snr_linear)
    noise_scaled = k * noise_segment

    mixed = clean + noise_scaled
    mixed = torch.clamp(mixed, -1.0, 1.0)
    return mixed


# ----- Monkey-patch torchaudio.load to use soundfile backend -----
def safe_torchaudio_load(path,
                         frame_offset=0,
                         num_frames=-1,
                         normalize=True,
                         channels_first=True,
                         format=None,
                         buffer_size=4096,
                         backend=None):
    """
    Stand-in for torchaudio.load that decodes with soundfile instead.

    The whole file is read as float32, then optionally cut to the window
    [frame_offset, frame_offset + num_frames). The signature mirrors the
    keyword arguments listed above, but ``normalize``, ``format``,
    ``buffer_size`` and ``backend`` are accepted and ignored.

    returns: (waveform, sample_rate); waveform is [C, T] when
             ``channels_first`` is true, otherwise [T, C]
    """
    samples, sample_rate = sf.read(path, dtype="float32", always_2d=True)  # [T, C]
    audio = torch.from_numpy(samples)

    # Only slice when the caller asked for an offset or an explicit frame count
    wants_window = frame_offset > 0 or num_frames > -1
    if wants_window:
        stop = frame_offset + num_frames if num_frames >= 0 else None
        audio = audio[frame_offset:stop]

    # No rescaling step: the data was already requested as float32 above
    if channels_first:
        return audio.transpose(0, 1), sample_rate  # [C, T]
    return audio, sample_rate


# Every later torchaudio.load(...) call in this kernel now goes through soundfile.
torchaudio.load = safe_torchaudio_load
print("Patched torchaudio.load to use soundfile backend.")


In [None]:
# ==============================
# 2. FEATURE EXTRACTION: LOG-MEL + MFCC
# ==============================

# Mel spectrogram front-end shared by the log-Mel feature path.
# Window / hop / FFT sizes come from the global config cell.
mel_transform = torchaudio.transforms.MelSpectrogram(
    n_mels=N_MELS,
    n_fft=N_FFT,
    win_length=WIN_LENGTH,
    hop_length=HOP_LENGTH,
    sample_rate=SAMPLE_RATE,
)

# Maps the mel spectrogram onto a decibel (log-like) scale
amp_to_db = torchaudio.transforms.AmplitudeToDB()

# MFCC front-end. n_mfcc equals N_MELS on purpose so that both feature types
# have the same "frequency" height and can feed the same CNN.
mfcc_transform = torchaudio.transforms.MFCC(
    n_mfcc=N_MELS,
    sample_rate=SAMPLE_RATE,
    melkwargs=dict(
        mel_scale="htk",
        n_mels=N_MELS,
        n_fft=N_FFT,
        win_length=WIN_LENGTH,
        hop_length=HOP_LENGTH,
    ),
)


def waveform_to_features(waveform: torch.Tensor, mode: str = "logmel") -> torch.Tensor:
    """
    Turn a waveform into a 2-D time-frequency representation.

    The input is cast to float32 and passed, with gradients disabled, through
    either the mel-spectrogram + dB pipeline ("logmel") or the MFCC transform
    ("mfcc"), both defined at module level.

    waveform: [1, T]
    mode: "logmel" or "mfcc"
    returns: [1, N_MELS, T_frames]
    raises: ValueError for any other ``mode``
    """
    samples = waveform.to(torch.float32)

    with torch.no_grad():
        if mode == "mfcc":
            return mfcc_transform(samples)             # [1, n_mfcc, time]
        if mode == "logmel":
            return amp_to_db(mel_transform(samples))   # [1, n_mels, time]
        raise ValueError(f"Unknown feature mode: {mode}")


In [None]:
# ==============================
# 3. SPECAUGMENT (TIME + FREQUENCY MASKING)
# ==============================

def spec_augment(
    feat: torch.Tensor,
    max_time_mask: int = 20,
    max_freq_mask: int = 8,
    num_time_masks: int = 2,
    num_freq_masks: int = 2,
    mask_value: float = 0.0,
) -> torch.Tensor:
    """
    Simple SpecAugment: overwrite random time and frequency bands.

    Works on a copy; the input tensor is never modified. The last axis is
    treated as time and the second-to-last as frequency, so both a single
    example [1, N_MELS, T] and a batch [B, 1, N_MELS, T] are accepted (a batch
    shares the same masks across its examples).

    feat: [..., N_MELS, T]
    max_time_mask / max_freq_mask: largest width of one time / frequency band
    num_time_masks / num_freq_masks: how many bands of each kind to draw
    mask_value: value written into masked cells. The default 0.0 keeps the
        previous behaviour; for dB-scaled features a caller may prefer e.g.
        ``feat.min().item()`` so the mask reads as "no energy".
    returns: augmented feature of same shape
    """
    out = feat.clone()
    n_mels, T = out.shape[-2], out.shape[-1]

    # Time masks
    for _ in range(num_time_masks):
        t = random.randint(0, max_time_mask)
        # Skip empty bands and bands that would cover the whole time axis
        if t == 0 or T - t <= 0:
            continue
        t0 = random.randint(0, T - t)
        out[..., t0:t0 + t] = mask_value

    # Frequency masks
    for _ in range(num_freq_masks):
        f = random.randint(0, max_freq_mask)
        # Skip empty bands and bands that would cover every frequency bin
        if f == 0 or n_mels - f <= 0:
            continue
        f0 = random.randint(0, n_mels - f)
        out[..., f0:f0 + f, :] = mask_value

    return out


In [None]:
# ==============================
# 4. LOAD NOISE FILES
# ==============================

def load_noise_waveforms(noise_dir: str, sample_rate: int = SAMPLE_RATE):
    """
    Load all .wav files in noise_dir, convert to mono and resample.

    Files are visited in sorted order so the returned list (and therefore any
    seeded ``random.choice`` over it) is the same on every machine. Each clip
    is resampled to ``sample_rate`` when needed and has its mean subtracted.
    Files that fail to load, or that contain no samples, are reported and
    skipped instead of aborting the whole scan.

    noise_dir: directory that is searched (non-recursively) for "*.wav"
    sample_rate: target sampling rate in Hz
    returns: list of waveforms, each [1, T_noise]; empty if the directory is
             missing or holds no usable file
    """
    noise_waveforms = []
    noise_dir_path = Path(noise_dir)
    if not noise_dir_path.exists():
        print(f"[WARN] Noise directory {noise_dir} does not exist.")
        return noise_waveforms

    # glob() yields files in filesystem order; sort for reproducibility
    for wav_path in sorted(noise_dir_path.glob("*.wav")):
        try:
            wav, sr = torchaudio.load(str(wav_path))

            # Convert to mono if multi-channel (target length == own length,
            # so nothing is padded or trimmed here)
            wav = pad_or_trim(wav, wav.shape[-1])  # ensure [1, T]

            # Resample if needed
            if sr != sample_rate:
                wav = torchaudio.functional.resample(wav, sr, sample_rate)

            # A zero-length clip cannot be tiled or cropped when mixing
            if wav.shape[-1] == 0:
                print(f"[WARN] Skipping empty noise file {wav_path}")
                continue

            # Remove DC offset (center around 0)
            wav = wav - wav.mean()
            noise_waveforms.append(wav)
        except Exception as e:
            # Best effort: one unreadable file should not stop the scan
            print(f"[WARN] Failed to load noise file {wav_path}: {e}")

    print(f"Loaded {len(noise_waveforms)} noise files from {noise_dir}")
    return noise_waveforms


# Noise bank shared by the training and noisy-test datasets below.
NOISE_WAVEFORMS = load_noise_waveforms(NOISE_DIR, SAMPLE_RATE)
if not NOISE_WAVEFORMS:
    print("NOISE_WAVEFORMS is empty – KWS noise training will still run but effectively be clean.")


In [None]:
# ==============================
# 5. SPEECH COMMANDS DATASET + COLLATE + LOADERS
# ==============================

class GuardianSpeechCommands(Dataset):
    """
    Wrapper around torchaudio.datasets.SPEECHCOMMANDS with:
      - 11-class label mapping (10 commands + 'unknown')
      - optional noise mixing (for KWS multi-condition training)
      - log-Mel or MFCC feature extraction
      - optional SpecAugment (for log-Mel)

    Each item is ``(feat, y)``: a feature tensor [1, F, T_frames] and an
    integer class index.
    """

    def __init__(
        self,
        root: str,
        subset: str,                 # "training", "validation", "testing"
        target_commands,
        all_classes,
        noise_waveforms=None,        # list of noise waveforms or None
        p_noise: float = 0.7,        # probability of mixing noise during training
        snr_db_range=(0, 20),        # SNR range for noise mixing
        sample_rate: int = SAMPLE_RATE,
        feature_mode: str = "logmel",    # "logmel" or "mfcc"
        use_specaugment: bool = False,   # apply SpecAugment (for logmel only)
    ):
        """
        root: directory the Speech Commands archive is downloaded to / read from
        subset: which official split to expose
        target_commands: labels that keep their own class
        all_classes: ordered class names; must contain "unknown"
        noise_waveforms: noise clips to mix in; None or empty disables mixing
        p_noise: per-item probability of mixing noise
        snr_db_range: (low, high) bounds for the uniformly drawn SNR in dB
        sample_rate: rate every clip is resampled to
        feature_mode: passed to waveform_to_features
        use_specaugment: enable SpecAugment; has no effect unless
            feature_mode == "logmel"
        """
        super().__init__()
        assert subset in ["training", "validation", "testing"]

        self.base = torchaudio.datasets.SPEECHCOMMANDS(
            root=root,
            download=True,
            url="speech_commands_v0.02",
            subset=subset,
        )
        self.target_commands = set(target_commands)
        self.all_classes = all_classes
        self.class_to_index = {c: i for i, c in enumerate(all_classes)}

        self.noise_waveforms = noise_waveforms if noise_waveforms is not None else []
        self.p_noise = p_noise
        self.snr_db_range = snr_db_range
        self.sample_rate = sample_rate
        # Fixed clip length in samples, derived from THIS dataset's rate so a
        # non-default sample_rate still yields CLIP_DURATION_SEC-long clips.
        self.clip_num_samples = int(sample_rate * CLIP_DURATION_SEC)

        self.feature_mode = feature_mode
        self.use_specaugment = use_specaugment

    def __len__(self):
        """Number of utterances in the wrapped split."""
        return len(self.base)

    def _map_label(self, label: str) -> int:
        """
        Map raw label string to:
        - its own index if in TARGET_COMMANDS
        - "unknown" class index otherwise
        """
        if label in self.target_commands:
            return self.class_to_index[label]
        else:
            return self.class_to_index["unknown"]

    def __getitem__(self, idx):
        """Load utterance ``idx`` and return ``(features, class_index)``."""
        # base[idx] returns: (waveform, sample_rate, label, speaker_id, utterance_number)
        waveform, sr, label, speaker_id, utt_number = self.base[idx]

        # Target length == own length: only forces mono [1, T], no pad/trim yet
        waveform = pad_or_trim(waveform, waveform.shape[-1])  # ensure [1, T]

        # Resample to desired SAMPLE_RATE if needed
        if sr != self.sample_rate:
            waveform = torchaudio.functional.resample(waveform, sr, self.sample_rate)

        # Pad/trim to fixed length (1 second)
        waveform = pad_or_trim(waveform, self.clip_num_samples)

        # Optional noise mixing (mainly for training).
        # With an empty noise list the RNG is not touched at all.
        if self.noise_waveforms and random.random() < self.p_noise:
            snr_db = random.uniform(*self.snr_db_range)
            noise = random.choice(self.noise_waveforms)
            waveform = mix_with_noise(waveform, noise, snr_db)

        # Convert to time-frequency features (log-Mel or MFCC)
        feat = waveform_to_features(waveform, mode=self.feature_mode)  # [1, F, T]

        # Optional SpecAugment (for training log-Mel), applied to half the items
        if self.use_specaugment and self.feature_mode == "logmel":
            if random.random() < 0.5:
                feat = spec_augment(feat)

        # Map label string to integer class index
        y = self._map_label(label)

        return feat, y


def guardian_collate_fn(batch):
    """
    Collate (feature, label) pairs whose time axis may differ in length.

    Every feature is right-padded with zeros along its last axis up to the
    longest one in the batch, then the features are stacked.

    batch: list of tuples (feat [1,F,T_i], label)
    returns:
      X: [B, 1, F, T_max]
      y: [B] (torch.long)
    """
    feats, labels = zip(*batch)
    longest = max(f.shape[-1] for f in feats)

    # Features already at the longest length are passed through as-is
    aligned = [
        f if f.shape[-1] >= longest
        else torch.nn.functional.pad(f, (0, longest - f.shape[-1]))
        for f in feats
    ]

    X = torch.stack(aligned, dim=0)             # [B, 1, F, T_max]
    y = torch.tensor(labels, dtype=torch.long)  # [B]
    return X, y


def make_loaders(
    feature_mode="logmel",
    train_with_noise=False,
    use_specaugment=False,
    batch_size=64,
    num_workers=0,
):
    """
    Build the train / validation / test DataLoaders for SpeechCommands.

    Only the training split is ever augmented; validation and test are always
    clean and never use SpecAugment.

    - feature_mode: "logmel" or "mfcc", forwarded to every dataset.
    - train_with_noise: if True, training set uses noise mixing. Falls back to
      clean training when NOISE_WAVEFORMS is empty.
    - use_specaugment: if True, SpecAugment applied on training features (log-Mel only).
    - batch_size / num_workers: forwarded to every DataLoader.

    returns: (train_loader, val_loader, test_loader)
    """
    def _dataset(subset, noise, prob, snr, specaug):
        # One SpeechCommands split with the given augmentation settings
        return GuardianSpeechCommands(
            root=SPEECH_COMMANDS_ROOT,
            subset=subset,
            target_commands=TARGET_COMMANDS,
            all_classes=ALL_CLASSES,
            noise_waveforms=noise,
            p_noise=prob,
            snr_db_range=snr,
            feature_mode=feature_mode,
            use_specaugment=specaug,
        )

    # Training-time noise settings (disabled when no noise is available)
    if train_with_noise and len(NOISE_WAVEFORMS) > 0:
        train_noise, train_prob, train_snr = NOISE_WAVEFORMS, 0.7, (0, 20)
    else:
        train_noise, train_prob, train_snr = None, 0.0, (0, 0)

    train_dataset = _dataset("training", train_noise, train_prob, train_snr, use_specaugment)
    val_dataset = _dataset("validation", None, 0.0, (0, 0), False)
    test_dataset = _dataset("testing", None, 0.0, (0, 0), False)

    pin_mem = torch.cuda.is_available()  # only relevant for CUDA

    def _loader(dataset, shuffle):
        # Only the training loader shuffles
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            pin_memory=pin_mem,
            collate_fn=guardian_collate_fn,
        )

    train_loader = _loader(train_dataset, True)
    val_loader = _loader(val_dataset, False)
    test_loader = _loader(test_dataset, False)

    print(f"[make_loaders] feature_mode={feature_mode}, "
          f"train_with_noise={train_with_noise}, use_specaugment={use_specaugment}")
    for caption, dataset in (
        ("  Train examples:", train_dataset),
        ("  Val   examples:", val_dataset),
        ("  Test  examples:", test_dataset),
    ):
        print(caption, len(dataset))

    return train_loader, val_loader, test_loader


In [None]:
# ==============================
# 7. MODEL: SHARED ENCODER + COMMAND HEAD
# ==============================

class SharedAudioEncoder(nn.Module):
    """
    Three-stage CNN that turns a [B, 1, F, T] feature map into a [B, 128]
    embedding. Every stage is Conv3x3 -> BatchNorm -> ReLU -> MaxPool2x2 and
    the stages widen the channels 1 -> 32 -> 64 -> 128. The final feature map
    is averaged over frequency and time.

    ``n_mels`` is accepted by the constructor but not used by any layer.
    """

    def __init__(self, n_mels: int = N_MELS):
        super().__init__()
        self.conv_block1 = self._conv_stage(1, 32)
        self.conv_block2 = self._conv_stage(32, 64)
        self.conv_block3 = self._conv_stage(64, 128)

        self.embedding_dim = 128  # after global average pooling

    @staticmethod
    def _conv_stage(in_channels: int, out_channels: int) -> nn.Sequential:
        """One Conv-BN-ReLU-MaxPool stage; the pool halves both spatial axes."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3), padding=(1, 1)),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=(2, 2)),
        )

    def forward(self, x):
        # x: [B, 1, F, T] -> [B,32,F/2,T/2] -> [B,64,F/4,T/4] -> [B,128,F/8,T/8]
        for stage in (self.conv_block1, self.conv_block2, self.conv_block3):
            x = stage(x)
        # Global average pooling over freq and time -> [B,128]
        return x.mean(dim=[2, 3])


class CommandHead(nn.Module):
    """
    Classification head on top of the encoder embedding:
    Linear(embedding_dim, 128) -> ReLU -> Dropout(0.3) -> Linear(128, num_classes).
    Returns raw logits of shape [B, num_classes].
    """

    def __init__(self, embedding_dim: int, num_classes: int):
        super().__init__()
        layers = [
            nn.Linear(embedding_dim, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.3),
            nn.Linear(128, num_classes),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, z):
        # z: [B, embedding_dim]
        logits = self.net(z)
        return logits


class GuardianVoiceModel(nn.Module):
    """
    Keyword-spotting network: a SharedAudioEncoder followed by a CommandHead.
    Input is [B, 1, F, T]; output is command logits [B, num_command_classes].
    """

    def __init__(self, num_command_classes: int):
        super().__init__()
        encoder = SharedAudioEncoder(n_mels=N_MELS)
        self.encoder = encoder
        # The head's input width follows whatever the encoder reports
        self.command_head = CommandHead(encoder.embedding_dim, num_command_classes)

    def forward(self, x):
        # [B,1,F,T] -> encoder -> [B,128] -> head -> [B,num_classes]
        return self.command_head(self.encoder(x))


In [None]:
# ==============================
# 8. TRAINING & EVALUATION HELPERS
# ==============================

def train_one_epoch(model, loader, optimizer, criterion, device):
    """
    One training epoch over loader.

    Puts the model in train mode, then for every batch: moves it to
    ``device``, runs forward/backward and steps the optimizer. The reported
    loss is the per-example average (each batch loss is weighted by its batch
    size) and accuracy is computed from the same forward pass.

    model: network returning logits [B, num_classes]
    loader: iterable of (X, y) batches
    optimizer / criterion: torch optimizer and loss
    device: device batches are moved to
    Returns (avg_loss, accuracy); (nan, 0.0) if the loader yields no examples
    instead of raising ZeroDivisionError.
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for X, y in loader:
        X = X.to(device)
        y = y.to(device)

        optimizer.zero_grad()
        logits = model(X)
        loss = criterion(logits, y)
        loss.backward()
        optimizer.step()

        # Undo the batch averaging so the epoch average is over examples
        running_loss += loss.item() * X.size(0)
        preds = logits.argmax(dim=1)
        correct += (preds == y).sum().item()
        total += y.size(0)

    if total == 0:
        # Empty loader: the loss is undefined and nothing was classified
        return float("nan"), 0.0

    return running_loss / total, correct / total


@torch.no_grad()
def evaluate(model, loader, criterion, device):
    """
    Evaluate model on loader.

    Puts the model in eval mode (it is left in eval mode afterwards) and runs
    every batch with gradients disabled. The reported loss is the per-example
    average: each batch loss is weighted by its batch size.

    model: network returning logits [B, num_classes]
    loader: iterable of (X, y) batches
    criterion: loss function
    device: device batches are moved to
    Returns (avg_loss, accuracy); (nan, 0.0) if the loader yields no examples
    instead of raising ZeroDivisionError.
    """
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0

    for X, y in loader:
        X = X.to(device)
        y = y.to(device)

        logits = model(X)
        loss = criterion(logits, y)

        # Undo the batch averaging so the final average is over examples
        running_loss += loss.item() * X.size(0)
        preds = logits.argmax(dim=1)
        correct += (preds == y).sum().item()
        total += y.size(0)

    if total == 0:
        # Empty loader: the loss is undefined and nothing was classified
        return float("nan"), 0.0

    return running_loss / total, correct / total


In [None]:
# ==============================
# 9. BUILD NOISY TEST LOADERS (FOR SNR EVAL)
# ==============================

def build_noisy_test_loader(
    feature_mode: str,
    snr_db: float,
    batch_size: int = 64,
    seed=0,
):
    """
    Create a test loader where all examples are mixed with noise
    at a fixed SNR.

    The noise clip and crop position for item ``i`` are drawn from an RNG
    seeded with ``(seed, i)``, so every pass over the loader - and therefore
    every model evaluated on it - sees exactly the same noisy audio. The
    global ``random`` state is saved and restored around each item, so
    building or iterating this loader does not disturb other random draws.
    This relies on the loader running in the main process (num_workers=0).

    feature_mode: "logmel" or "mfcc"
    snr_db: signal-to-noise ratio applied to every example
    batch_size: DataLoader batch size
    seed: base seed for the per-item noise; pass None to draw fresh noise on
        every access (the previous, non-repeatable behaviour)
    returns: DataLoader over the noisy "testing" split
    """
    if not NOISE_WAVEFORMS:
        # The dataset skips mixing when it has no noise clips
        print(f"[WARN] NOISE_WAVEFORMS is empty: the SNR={snr_db} dB test loader "
              f"will contain CLEAN audio.")

    class _SeededNoisyTest(GuardianSpeechCommands):
        """Test split whose random noise draw is a pure function of the index."""

        def __getitem__(self, idx):
            if seed is None:
                return super().__getitem__(idx)
            saved_state = random.getstate()
            random.seed(f"{seed}:{idx}")
            try:
                return super().__getitem__(idx)
            finally:
                random.setstate(saved_state)

    noisy_test_ds = _SeededNoisyTest(
        root=SPEECH_COMMANDS_ROOT,
        subset="testing",
        target_commands=TARGET_COMMANDS,
        all_classes=ALL_CLASSES,
        noise_waveforms=NOISE_WAVEFORMS,
        p_noise=1.0,                      # always mix noise
        snr_db_range=(snr_db, snr_db),    # fixed SNR
        sample_rate=SAMPLE_RATE,
        feature_mode=feature_mode,
        use_specaugment=False,
    )
    noisy_test_loader = DataLoader(
        noisy_test_ds,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=guardian_collate_fn,
        num_workers=0,
        pin_memory=torch.cuda.is_available(),
    )
    print(f"Built noisy test loader for feature_mode={feature_mode}, "
          f"SNR={snr_db} dB, size={len(noisy_test_ds)}")
    return noisy_test_loader


In [None]:
# ==============================
# 10.1 EXP 1: LOG-MEL, CLEAN-ONLY TRAINING
# ==============================

# Cross-entropy over the NUM_CLASSES command classes; this same criterion
# object is reused by every later KWS experiment cell.
criterion_kws = nn.CrossEntropyLoss()

# Baseline data: log-Mel features, no noise mixing, no SpecAugment.
train_loader_clean, val_loader_logmel_clean, test_loader_logmel_clean = make_loaders(
    feature_mode="logmel",
    train_with_noise=False,
    use_specaugment=False,
    batch_size=64,
    num_workers=0,
)

# Fresh model + Adam optimizer for this experiment only.
model_clean = GuardianVoiceModel(num_command_classes=NUM_CLASSES).to(device)
optimizer_clean = torch.optim.Adam(model_clean.parameters(), lr=1e-3)

num_epochs_clean = 5

# Train for a fixed number of epochs, reporting train and validation metrics
# after each one. No checkpoint selection: the last epoch's weights are kept.
for epoch in range(1, num_epochs_clean + 1):
    train_loss_c, train_acc_c = train_one_epoch(
        model_clean, train_loader_clean, optimizer_clean, criterion_kws, device
    )
    val_loss_c, val_acc_c = evaluate(model_clean, val_loader_logmel_clean, criterion_kws, device)
    print(
        f"[CLEAN-LOGMEL] Epoch {epoch:02d}: "
        f"Train loss {train_loss_c:.4f}, acc {train_acc_c*100:.2f}% | "
        f"Val loss {val_loss_c:.4f}, acc {val_acc_c*100:.2f}%"
    )

# Final score on the clean test split; clean_test_acc_c is read again by the
# accuracy-vs-SNR cell (10.4).
clean_test_loss_c, clean_test_acc_c = evaluate(
    model_clean, test_loader_logmel_clean, criterion_kws, device
)
print(
    f"[CLEAN-LOGMEL] Final CLEAN test: loss {clean_test_loss_c:.4f}, "
    f"acc {clean_test_acc_c*100:.2f}%"
)


In [None]:
# ==============================
# 10.2 EXP 2: LOG-MEL, NOISE-TRAINED (MULTI-CONDITION)
# ==============================

# Multi-condition data: only the TRAINING split gets noise mixed in.
# make_loaders always builds validation/test without noise, so despite their
# names val_loader_logmel_noise / test_loader_logmel_noise are clean splits.
# If NOISE_WAVEFORMS is empty, make_loaders falls back to clean training.
train_loader_noise, val_loader_logmel_noise, test_loader_logmel_noise = make_loaders(
    feature_mode="logmel",
    train_with_noise=True,
    use_specaugment=False,
    batch_size=64,
    num_workers=0,
)

# Fresh model + Adam optimizer for this experiment only.
model_noise = GuardianVoiceModel(num_command_classes=NUM_CLASSES).to(device)
optimizer_noise = torch.optim.Adam(model_noise.parameters(), lr=1e-3)

num_epochs_noise = 5

# Fixed-length training run; criterion_kws comes from cell 10.1.
for epoch in range(1, num_epochs_noise + 1):
    train_loss_n, train_acc_n = train_one_epoch(
        model_noise, train_loader_noise, optimizer_noise, criterion_kws, device
    )
    val_loss_n, val_acc_n = evaluate(
        model_noise, val_loader_logmel_noise, criterion_kws, device
    )
    print(f"[NOISE-LOGMEL] Epoch {epoch:02d}: "
          f"Train loss {train_loss_n:.4f}, acc {train_acc_n*100:.2f}% | "
          f"Val loss {val_loss_n:.4f}, acc {val_acc_n*100:.2f}%")

# Final score on the clean test split; clean_test_acc_n is read again by the
# accuracy-vs-SNR cell (10.4).
clean_test_loss_n, clean_test_acc_n = evaluate(
    model_noise, test_loader_logmel_noise, criterion_kws, device
)
print(f"[NOISE-LOGMEL] Final CLEAN test: loss {clean_test_loss_n:.4f}, "
      f"acc {clean_test_acc_n*100:.2f}%")


In [None]:
# ==============================
# 10.3 EXP 3: LOG-MEL, NOISE + SPECAUGMENT
# ==============================

# Noise mixing AND SpecAugment on the training split; validation and test are
# built clean and un-augmented by make_loaders.
train_loader_spec, val_loader_logmel_spec, test_loader_logmel_spec = make_loaders(
    feature_mode="logmel",
    train_with_noise=True,
    use_specaugment=True,
    batch_size=64,
    num_workers=0,
)

# Fresh model + Adam optimizer for this experiment only.
model_spec = GuardianVoiceModel(num_command_classes=NUM_CLASSES).to(device)
optimizer_spec = torch.optim.Adam(model_spec.parameters(), lr=1e-3)

num_epochs_spec = 5

# Fixed-length training run; criterion_kws comes from cell 10.1.
for epoch in range(1, num_epochs_spec + 1):
    train_loss_s, train_acc_s = train_one_epoch(
        model_spec, train_loader_spec, optimizer_spec, criterion_kws, device
    )
    val_loss_s, val_acc_s = evaluate(
        model_spec, val_loader_logmel_spec, criterion_kws, device
    )
    print(f"[SPECAUG-LOGMEL] Epoch {epoch:02d}: "
          f"Train loss {train_loss_s:.4f}, acc {train_acc_s*100:.2f}% | "
          f"Val loss {val_loss_s:.4f}, acc {val_acc_s*100:.2f}%")

# Final score on the clean test split. model_spec, test_loader_logmel_spec and
# clean_test_acc_s are all read again by later cells (10.4, 11, 13).
clean_test_loss_s, clean_test_acc_s = evaluate(
    model_spec, test_loader_logmel_spec, criterion_kws, device
)
print(f"[SPECAUG-LOGMEL] Final CLEAN test: loss {clean_test_loss_s:.4f}, "
      f"acc {clean_test_acc_s*100:.2f}%")


In [None]:
# ==============================
# 10.4 BUILD NOISY TEST LOADERS AND EVAL (LOG-MEL)
# ==============================

# SNR levels (dB) to test at; one noisy test loader is built per level and
# shared by all three log-Mel models. `snrs` is reused by the MFCC cell (10.5).
snrs = [20.0, 10.0, 0.0]
noisy_loaders_logmel = {snr: build_noisy_test_loader("logmel", snr) for snr in snrs}

# Evaluate all three log-Mel models (clean, noise-trained, specaug) on noisy test sets

# Each dict maps a condition to test accuracy in [0, 1]: the key "Clean" holds
# the clean-test accuracy from the training cells, the float keys are SNRs.
clean_accuracies = {"Clean": clean_test_acc_c}
noise_accuracies = {"Clean": clean_test_acc_n}
specaug_accuracies = {"Clean": clean_test_acc_s}

for snr in snrs:
    # evaluate() returns (loss, accuracy); only the accuracy is kept here
    _, acc_c = evaluate(model_clean, noisy_loaders_logmel[snr], criterion_kws, device)
    _, acc_n = evaluate(model_noise, noisy_loaders_logmel[snr], criterion_kws, device)
    _, acc_s = evaluate(model_spec, noisy_loaders_logmel[snr], criterion_kws, device)
    clean_accuracies[snr] = acc_c
    noise_accuracies[snr] = acc_n
    specaug_accuracies[snr] = acc_s

print("\n[CLEAN-LOGMEL] Acc vs SNR:", clean_accuracies)
print("[NOISE-LOGMEL] Acc vs SNR:", noise_accuracies)
print("[SPECAUG-LOGMEL] Acc vs SNR:", specaug_accuracies)


In [None]:
# ==============================
# 10.5  MFCC, NOISE-TRAINED (FEATURE COMPARISON)
# ==============================

# Same recipe as the noise-trained log-Mel experiment (10.2) but with MFCC
# features, so the two runs differ only in the front-end.
train_loader_mfcc, val_loader_mfcc, test_loader_mfcc = make_loaders(
    feature_mode="mfcc",
    train_with_noise=True,
    use_specaugment=False,
    batch_size=64,
    num_workers=0,
)

# Fresh model + Adam optimizer for this experiment only.
model_mfcc = GuardianVoiceModel(num_command_classes=NUM_CLASSES).to(device)
optimizer_mfcc = torch.optim.Adam(model_mfcc.parameters(), lr=1e-3)

num_epochs_mfcc = 5

# Fixed-length training run; criterion_kws comes from cell 10.1.
for epoch in range(1, num_epochs_mfcc + 1):
    train_loss_m, train_acc_m = train_one_epoch(
        model_mfcc, train_loader_mfcc, optimizer_mfcc, criterion_kws, device
    )
    val_loss_m, val_acc_m = evaluate(
        model_mfcc, val_loader_mfcc, criterion_kws, device
    )
    print(f"[MFCC-NOISE] Epoch {epoch:02d}: "
          f"Train loss {train_loss_m:.4f}, acc {train_acc_m*100:.2f}% | "
          f"Val loss {val_loss_m:.4f}, acc {val_acc_m*100:.2f}%")

# Final score on the clean (un-mixed) MFCC test split.
clean_test_loss_m, clean_test_acc_m = evaluate(
    model_mfcc, test_loader_mfcc, criterion_kws, device
)
print(f"[MFCC-NOISE] Final CLEAN test: loss {clean_test_loss_m:.4f}, "
      f"acc {clean_test_acc_m*100:.2f}%")

# Noisy MFCC loaders for SNR eval
# (`snrs` is defined in cell 10.4, which therefore has to run first)
noisy_loaders_mfcc = {snr: build_noisy_test_loader("mfcc", snr) for snr in snrs}

# Same layout as the log-Mel accuracy dicts: "Clean" plus one entry per SNR.
mfcc_accuracies = {"Clean": clean_test_acc_m}
for snr in snrs:
    _, acc_m = evaluate(model_mfcc, noisy_loaders_mfcc[snr], criterion_kws, device)
    mfcc_accuracies[snr] = acc_m

print("[MFCC-NOISE] Acc vs SNR:", mfcc_accuracies)


In [None]:
# ==============================
# 11. CONFUSION MATRIX / CLASSIFICATION REPORT (BEST MODEL)
# ==============================

# Per-class analysis of the log-Mel + noise + SpecAugment model (model_spec).
# NOTE(review): this cell assumes model_spec is the best model; confirm against
# the accuracies printed by the experiment cells before relying on it.
model_spec.eval()
all_y = []
all_pred = []

with torch.no_grad():
    for X, y in test_loader_logmel_spec:
        X = X.to(device)
        y = y.to(device)

        logits = model_spec(X)
        preds = logits.argmax(dim=1)

        all_y.append(y.cpu().numpy())
        all_pred.append(preds.cpu().numpy())

all_y = np.concatenate(all_y)
all_pred = np.concatenate(all_pred)

# Pin the label set explicitly: row/column i of the matrix and line i of the
# report then always correspond to ALL_CLASSES[i], even if some class never
# occurs in the labels or predictions (without `labels`, classification_report
# raises when the observed classes do not match target_names).
class_ids = list(range(NUM_CLASSES))

print("Confusion matrix (SpecAug log-Mel model on clean test):")
print(confusion_matrix(all_y, all_pred, labels=class_ids))

print("\nClassification report (SpecAug log-Mel model on clean test):")
print(classification_report(all_y, all_pred, labels=class_ids, target_names=ALL_CLASSES))


In [None]:
# ==============================
# 12. PLOT: ACCURACY VS SNR (LOG-MEL MODELS)
# ==============================

# Conditions on the x-axis: the clean test set followed by every SNR that was
# actually evaluated in cell 10.4. Deriving them from `snrs` keeps the plot in
# step with the accuracy dicts if the SNR list is ever changed.
plot_conditions = ["Clean"] + list(snrs)
snr_labels = ["Clean"] + [f"{snr:g} dB" for snr in snrs]

# Accuracies are stored as fractions in [0, 1]; convert to percent for plotting.
acc_clean_only = [clean_accuracies[cond] * 100 for cond in plot_conditions]
acc_noise_trained = [noise_accuracies[cond] * 100 for cond in plot_conditions]
acc_specaug = [specaug_accuracies[cond] * 100 for cond in plot_conditions]

plt.figure()
plt.plot(snr_labels, acc_clean_only, marker="o", label="Clean-only (log-Mel)")
plt.plot(snr_labels, acc_noise_trained, marker="o", label="Noise-trained (log-Mel)")
plt.plot(snr_labels, acc_specaug, marker="o", label="Noise+SpecAug (log-Mel)")
plt.xlabel("Condition")
plt.ylabel("Accuracy (%)")
plt.title("GuardianDrive-Voice: Accuracy vs SNR")
plt.grid(True)
plt.legend()
plt.show()


In [None]:
# ==============================
# 13. GUARDIANDRIVE DEMO ON YOUR OWN WAVs
# ==============================

def load_wav_and_features(path, mode="logmel", sample_rate=SAMPLE_RATE):
    """
    Read one audio file and prepare it exactly like a dataset item.

    Steps: load, reduce to mono, resample to ``sample_rate`` if the file's
    rate differs, pad/trim to CLIP_NUM_SAMPLES, then extract features.

    path: audio file to read
    mode: "logmel" or "mfcc"
    sample_rate: target sampling rate in Hz
    returns: (waveform [1, CLIP_NUM_SAMPLES], features [1,F,T])
    """
    audio, native_rate = torchaudio.load(path)

    # Target length == own length: only collapses channels to [1, T]
    audio = pad_or_trim(audio, audio.shape[-1])

    if native_rate != sample_rate:
        audio = torchaudio.functional.resample(audio, native_rate, sample_rate)

    # Fixed-length clip, same as the training data
    audio = pad_or_trim(audio, CLIP_NUM_SAMPLES)

    features = waveform_to_features(audio, mode=mode)  # [1,F,T]
    return audio, features


def guardian_predict(model, feat_tensor, class_names=None):
    """
    Run model on a single feature tensor.

    The model is switched to eval mode (and left there) and run without
    gradients; the logits of the first batch element are turned into
    probabilities with a softmax.

    model: network returning logits [1, num_classes]
    feat_tensor: [1,1,F,T] on device
    class_names: optional sequence mapping class index -> label. Defaults to
        ALL_CLASSES, so existing calls are unchanged; pass it explicitly for a
        model trained on a different label set.
    returns: (pred_class, confidence, full_probs) - the winning label, its
        probability as a float, and all probabilities as a numpy array
    """
    names = ALL_CLASSES if class_names is None else class_names

    model.eval()
    with torch.no_grad():
        logits = model(feat_tensor)
        probs = torch.softmax(logits, dim=1)[0]
        conf, idx = torch.max(probs, dim=0)
        pred_class = names[idx.item()]
        return pred_class, conf.item(), probs.cpu().numpy()


def demo_on_file(path, model, noise_waveforms=None, snr_db=None, feature_mode="logmel"):
    """
    Run GuardianDrive-Voice model on:
      - clean version
      - optional noisy version (if noise_waveforms and snr_db provided)

    Prints the predicted class and confidence for each version; returns None.

    path: audio file to classify
    model: trained network taking [1,1,F,T] features
    noise_waveforms: list of noise clips; one is picked at random
    snr_db: SNR in dB for the noisy version
    feature_mode: "logmel" (default, as before) or "mfcc". Must match the
        features the model was trained on, e.g. "mfcc" for model_mfcc.
    """
    print(f"\n=== GuardianDrive Demo on: {path} ===")

    # Clean
    clean_wav, clean_feat = load_wav_and_features(path, mode=feature_mode)
    clean_X = clean_feat.unsqueeze(0).to(device)  # [1,1,F,T]
    pred_clean, conf_clean, _ = guardian_predict(model, clean_X)
    print(f"Clean           -> predicted: {pred_clean:>7s}, confidence: {conf_clean*100:5.2f}%")

    # Noisy version (if noise + SNR specified)
    if noise_waveforms and snr_db is not None:
        noise = random.choice(noise_waveforms)
        # Mix on the clip's own dtype/device so the addition cannot mismatch
        noise = noise.to(dtype=clean_wav.dtype, device=clean_wav.device)
        noisy_wav = mix_with_noise(clean_wav, noise, snr_db)
        noisy_feat = waveform_to_features(noisy_wav, mode=feature_mode)
        noisy_X = noisy_feat.unsqueeze(0).to(device)
        pred_noisy, conf_noisy, _ = guardian_predict(model, noisy_X)
        print(f"Noise @ {snr_db:>4.1f} dB -> predicted: {pred_noisy:>7s}, confidence: {conf_noisy*100:5.2f}%")


# Use BEST model: log-Mel + noise + SpecAugment
# NOTE(review): "best" is assumed here, not selected from the measured
# accuracies - confirm against the experiment output.
guardian_model = model_spec

# Adjust these paths to wherever your demo wavs are
# The list is empty by default, so the loop below does nothing until at least
# one path is added.
DEMO_FILES = [
    # "/path/to/guardian_help.wav",
    # "/path/to/call_ambulance.wav",
    # "/path/to/im_not_okay.wav",
]

# Each file is classified clean and again with random noise mixed in at 10 dB SNR.
for fname in DEMO_FILES:
    demo_on_file(fname, guardian_model, noise_waveforms=NOISE_WAVEFORMS, snr_db=10.0)


In [None]:
# ==============================
# 14.a RAVDESS DISTRESS DATASET (SPEAKER-INDEPENDENT)
# ==============================

class RAVDESSDistressDataset(Dataset):
    """
    Binary distress classification on RAVDESS speech clips.

    Labels:
      y = 0 -> non-distress (neutral, calm, happy, sad, disgust, surprised)
      y = 1 -> distress (angry, fearful)

    Speaker-independent split (actor id is parsed from the "Actor_XX" folder):
      Actors  1–18 -> train
      Actors 19–21 -> val
      Actors 22–24 -> test

    Each item is `(feat, y)`: `feat` is the output of `waveform_to_features`
    for a single-channel clip of CLIP_NUM_SAMPLES samples, `y` is a Python int.

    NOTE(review): recordings are cut/padded to CLIP_NUM_SAMPLES by
    `pad_or_trim`; confirm which part of the utterance survives that crop.
    """

    def __init__(
        self,
        root: str,              # path like "./data/ravdess/audio_speech_actors_01-24"
        split: str = "train",   # "train" | "val" | "test"
        feature_mode: str = "logmel",
        sample_rate: int = SAMPLE_RATE,
    ):
        """
        Index every `*.wav` below `root` that belongs to `split`.

        Files whose parent folder is not of the form "<prefix>_<int>" or whose
        actor id is outside 1–24 are skipped. Nothing is decoded here; audio is
        loaded lazily in `__getitem__`.
        """
        super().__init__()
        assert split in ["train", "val", "test"]
        self.root = Path(root)
        self.split = split
        self.feature_mode = feature_mode
        self.sample_rate = sample_rate

        self.train_actors = set(range(1, 19))
        self.val_actors   = set(range(19, 22))
        self.test_actors  = set(range(22, 25))

        self.files = []
        self.labels = []

        # Sorted so the index -> file mapping is stable between runs.
        all_wavs = sorted(self.root.rglob("*.wav"))

        if len(all_wavs) == 0:
            print(f"[WARN] No .wav files found under {root}. Check the path.")
        else:
            print(f"[RAVDESS] Found total wav files: {len(all_wavs)}")

        for wav_path in all_wavs:
            actor_folder = wav_path.parent.name      # "Actor_01"
            try:
                actor_id = int(actor_folder.split("_")[1])
            except (IndexError, ValueError):
                # Not an "Actor_XX" folder -> not part of the split scheme.
                continue

            if actor_id in self.train_actors:
                actor_split = "train"
            elif actor_id in self.val_actors:
                actor_split = "val"
            elif actor_id in self.test_actors:
                actor_split = "test"
            else:
                continue

            if actor_split != split:
                continue

            label = self._emotion_to_distress(wav_path.name)
            self.files.append(wav_path)
            self.labels.append(label)

        print(f"[RAVDESSDistressDataset] split={split}, size={len(self.files)}")

    def _emotion_to_distress(self, filename: str):
        """
        Map a RAVDESS file name to the binary distress label.

        RAVDESS file names look like "03-01-08-01-02-01-01.wav"; the 3rd
        dash-separated field is the emotion id.

        Mapping:
          distress (1)      -> {angry (05), fearful (06)}
          non-distress (0)  -> everything else

        Names that do not follow the convention (fewer than three fields, or a
        non-numeric emotion field) fall back to 0 instead of raising, so one
        stray file cannot abort dataset construction.
        """
        stem = filename.split(".")[0]
        parts = stem.split("-")
        if len(parts) < 3:
            return 0
        try:
            emotion_id = int(parts[2])
        except ValueError:
            # Same fallback as for too-short names.
            return 0

        if emotion_id in {5, 6}:  # angry, fearful
            return 1
        else:
            return 0

    def __len__(self):
        """Number of clips in this split."""
        return len(self.files)

    def __getitem__(self, idx):
        """Load clip `idx`, bring it to the target rate/length and return `(feat, y)`."""
        wav_path = self.files[idx]
        y = self.labels[idx]

        waveform, sr = torchaudio.load(str(wav_path))

        # Downmix multi-channel recordings so every item really is [1, T];
        # a no-op for files that are already mono.
        if waveform.ndim == 2 and waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)

        # Target length equals the current length, so this cannot change the
        # duration. Kept because pad_or_trim may also normalise the tensor
        # layout -- TODO confirm and drop if it is a pure no-op.
        waveform = pad_or_trim(waveform, waveform.shape[-1])

        if sr != self.sample_rate:
            waveform = torchaudio.functional.resample(waveform, sr, self.sample_rate)

        waveform = pad_or_trim(waveform, CLIP_NUM_SAMPLES)  # [1, 16000]

        feat = waveform_to_features(waveform, mode=self.feature_mode)  # [1, F, T]

        return feat, y


In [None]:
# ==============================
# 14b. DISTRESS DATALOADERS
# ==============================

# Quick check that RAVDESS root exists
print("RAVDESS root exists:", os.path.isdir(RAVDESS_ROOT))

# One dataset per speaker-independent split; all three use log-Mel features.
dist_train_ds = RAVDESSDistressDataset(RAVDESS_ROOT, split="train", feature_mode="logmel")
dist_val_ds   = RAVDESSDistressDataset(RAVDESS_ROOT, split="val",   feature_mode="logmel")
dist_test_ds  = RAVDESSDistressDataset(RAVDESS_ROOT, split="test",  feature_mode="logmel")

# Only the training loader shuffles; val/test keep dataset order.
# pin_memory is enabled only when CUDA is available.
dist_train_loader = DataLoader(
    dist_train_ds, batch_size=64, shuffle=True,
    collate_fn=guardian_collate_fn, num_workers=0,
    pin_memory=torch.cuda.is_available(),
)
dist_val_loader = DataLoader(
    dist_val_ds, batch_size=64, shuffle=False,
    collate_fn=guardian_collate_fn, num_workers=0,
    pin_memory=torch.cuda.is_available(),
)
dist_test_loader = DataLoader(
    dist_test_ds, batch_size=64, shuffle=False,
    collate_fn=guardian_collate_fn, num_workers=0,
    pin_memory=torch.cuda.is_available(),
)

# Sanity check: one batch
# NOTE(review): presumably fails when the train split is empty (e.g. wrong RAVDESS_ROOT) -- confirm.
batch_X_dist, batch_y_dist = next(iter(dist_train_loader))
print("Distress batch X shape:", batch_X_dist.shape)  # [B,1,F,T]
print("Distress batch y shape:", batch_y_dist.shape)  # [B]


In [None]:
# ==============================
# 15. DISTRESS MODEL (SINGLE-TASK)
# ==============================

class DistressHead(nn.Module):
    """
    Two-layer MLP classifier applied to an encoder embedding.

    Layout: Linear(embedding_dim -> 64) -> ReLU -> Dropout(0.3)
            -> Linear(64 -> num_classes)
    """
    def __init__(self, embedding_dim: int, num_classes: int = 2):
        super().__init__()
        hidden_dim = 64
        layers = [
            nn.Linear(embedding_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.3),
            nn.Linear(hidden_dim, num_classes),
        ]
        # Attribute name `net` is kept so state_dict keys stay "net.<idx>.*".
        self.net = nn.Sequential(*layers)

    def forward(self, z):
        """Map embeddings `z` to unnormalised class logits."""
        logits = self.net(z)
        return logits


class DistressVoiceModel(nn.Module):
    """
    Single-task distress classifier: a SharedAudioEncoder (the encoder class
    also used for KWS) followed by a 2-class DistressHead.
    """
    def __init__(self):
        super().__init__()
        encoder = SharedAudioEncoder(n_mels=N_MELS)
        self.encoder = encoder
        self.distress_head = DistressHead(encoder.embedding_dim, num_classes=2)

    def forward(self, x):
        """Encode `x` and return the distress logits ([B, 2])."""
        embedding = self.encoder(x)               # [B, 128]
        return self.distress_head(embedding)      # [B, 2]


In [None]:
# ==============================
# 16. TRAIN DISTRESS MODEL (SINGLE-TASK) + CONFUSION MATRIX
# ==============================

# Fresh single-task distress model trained from scratch on RAVDESS.
dist_model = DistressVoiceModel().to(device)

criterion_dist = nn.CrossEntropyLoss()
optimizer_dist = torch.optim.Adam(dist_model.parameters(), lr=1e-3)

num_epochs_dist = 5   # can increase to 10–15 later

for epoch in range(1, num_epochs_dist + 1):
    train_loss_d, train_acc_d = train_one_epoch(
        dist_model, dist_train_loader, optimizer_dist, criterion_dist, device
    )
    val_loss_d, val_acc_d = evaluate(
        dist_model, dist_val_loader, criterion_dist, device
    )

    print(f"[DISTRESS] Epoch {epoch:02d}: "
          f"Train loss {train_loss_d:.4f}, acc {train_acc_d*100:.2f}% | "
          f"Val loss {val_loss_d:.4f}, acc {val_acc_d*100:.2f}%")

# Held-out speakers: evaluated once, after the last epoch.
test_loss_d, test_acc_d = evaluate(
    dist_model, dist_test_loader, criterion_dist, device
)
print(f"[DISTRESS] Final Test loss {test_loss_d:.4f}, acc {test_acc_d*100:.2f}%")

# Confusion matrix / F1
dist_model.eval()

all_y = []
all_pred = []

with torch.no_grad():
    for X, y in dist_test_loader:
        X = X.to(device)
        y = y.to(device)

        logits = dist_model(X)
        preds = logits.argmax(dim=1)

        all_y.append(y.cpu().numpy())
        all_pred.append(preds.cpu().numpy())

all_y = np.concatenate(all_y)
all_pred = np.concatenate(all_pred)

# `labels=[0, 1]` pins the class set: without it scikit-learn infers the labels
# from the data, so a class missing from both y and predictions would shrink
# the matrix and make the 2-entry `target_names` raise a ValueError.
print("Confusion matrix (distress model):")
print(confusion_matrix(all_y, all_pred, labels=[0, 1]))

print("\nClassification report (distress model):")
print(classification_report(all_y, all_pred,
                            labels=[0, 1],
                            target_names=["non_distress", "distress"]))


In [None]:
# ==============================
# 17. MULTI-TASK MODEL (KWS + DISTRESS)
# ==============================

class MultiTaskGuardianModel(nn.Module):
    """
    One shared audio encoder feeding two classification heads:
      - command_head  : keyword spotting over `num_command_classes`
                        (11-way in this notebook: yes, no, up, ..., unknown)
      - distress_head : `num_distress_classes`-way (non_distress, distress)
    """
    def __init__(self, num_command_classes: int, num_distress_classes: int = 2):
        super().__init__()
        self.encoder = SharedAudioEncoder(n_mels=N_MELS)
        # Both heads consume the same embedding width.
        emb_dim = self.encoder.embedding_dim
        self.command_head = CommandHead(emb_dim, num_command_classes)
        self.distress_head = DistressHead(emb_dim, num_distress_classes)

    def forward(self, x):
        """Return `(cmd_logits, dist_logits)` computed from one shared embedding."""
        shared = self.encoder(x)                     # [B, 128]
        return (
            self.command_head(shared),               # [B, 11]
            self.distress_head(shared),              # [B, 2]
        )


In [None]:
# ==============================
# 17a. MULTI-TASK TRAIN / EVAL HELPERS
# ==============================

def _iter_paired_batches(kws_loader, dist_loader, cycle_shorter: bool):
    """Yield `(kws_batch, dist_batch)` pairs, optionally restarting the shorter loader."""
    if not cycle_shorter:
        # Legacy behaviour: stop as soon as either loader runs out.
        yield from zip(kws_loader, dist_loader)
        return

    end = object()  # sentinel that can never be a real batch
    kws_iter, dist_iter = iter(kws_loader), iter(dist_loader)
    kws_done = dist_done = False

    while True:
        kws_batch = next(kws_iter, end)
        dist_batch = next(dist_iter, end)
        kws_done = kws_done or kws_batch is end
        dist_done = dist_done or dist_batch is end

        # The epoch is over once each loader has been fully consumed once.
        if kws_done and dist_done:
            return

        # Re-create the iterator (instead of caching batches) so a shuffling
        # loader reshuffles on every restart.
        if kws_batch is end:
            kws_iter = iter(kws_loader)
            kws_batch = next(kws_iter, end)
            if kws_batch is end:      # loader is empty
                return
        if dist_batch is end:
            dist_iter = iter(dist_loader)
            dist_batch = next(dist_iter, end)
            if dist_batch is end:     # loader is empty
                return

        yield kws_batch, dist_batch


def train_multitask_epoch(
    model,
    kws_loader,            # DataLoader for KWS (SpeechCommands)
    dist_loader,           # DataLoader for Distress (RAVDESS)
    optimizer,
    cmd_criterion,
    dist_criterion,
    device,
    lambda_dist: float = 1.0,
    cycle_shorter: bool = True,
):
    """
    One epoch of multi-task training.

    Each step draws one KWS batch and one distress batch, runs the model on
    each, and optimises `loss_cmd + lambda_dist * loss_dist` with a single
    optimizer step.

    Args:
        model: module whose forward returns `(cmd_logits, dist_logits)`.
        kws_loader: iterable of `(X, y)` keyword-spotting batches.
        dist_loader: iterable of `(X, y)` distress batches.
        optimizer: optimizer over `model`'s parameters.
        cmd_criterion: loss applied to the command logits.
        dist_criterion: loss applied to the distress logits.
        device: device the batches are moved to.
        lambda_dist: weight of the distress loss in the combined objective.
        cycle_shorter: when True (default) the shorter loader is restarted so
            the epoch covers every batch of the longer loader. When False the
            epoch ends with the shorter loader (plain `zip`), which silently
            skips the remainder of the longer one.

    Returns:
        `(avg_cmd_loss, cmd_acc, avg_dist_loss, dist_acc)`, each averaged over
        the samples actually processed for that task (0.0 if none were).
    """
    model.train()

    total_cmd_loss = 0.0
    total_dist_loss = 0.0
    total_cmd_correct = 0
    total_dist_correct = 0
    total_cmd_samples = 0
    total_dist_samples = 0

    for (X_cmd, y_cmd), (X_dist, y_dist) in _iter_paired_batches(
        kws_loader, dist_loader, cycle_shorter
    ):
        X_cmd = X_cmd.to(device)
        y_cmd = y_cmd.to(device)
        X_dist = X_dist.to(device)
        y_dist = y_dist.to(device)

        optimizer.zero_grad()

        # KWS batch: only the command head's output is used.
        cmd_logits, _ = model(X_cmd)
        loss_cmd = cmd_criterion(cmd_logits, y_cmd)
        cmd_preds = cmd_logits.argmax(dim=1)
        cmd_correct = (cmd_preds == y_cmd).sum().item()

        # Distress batch: only the distress head's output is used.
        _, dist_logits = model(X_dist)
        loss_dist = dist_criterion(dist_logits, y_dist)
        dist_preds = dist_logits.argmax(dim=1)
        dist_correct = (dist_preds == y_dist).sum().item()

        # Combined loss
        loss = loss_cmd + lambda_dist * loss_dist
        loss.backward()
        optimizer.step()

        batch_cmd_size = y_cmd.size(0)
        batch_dist_size = y_dist.size(0)

        # Weight by batch size so a smaller final batch is averaged correctly.
        total_cmd_loss += loss_cmd.item() * batch_cmd_size
        total_dist_loss += loss_dist.item() * batch_dist_size

        total_cmd_correct += cmd_correct
        total_dist_correct += dist_correct

        total_cmd_samples += batch_cmd_size
        total_dist_samples += batch_dist_size

    # max(1, ...) avoids a ZeroDivisionError when no batch was processed.
    avg_cmd_loss = total_cmd_loss / max(1, total_cmd_samples)
    avg_dist_loss = total_dist_loss / max(1, total_dist_samples)

    cmd_acc = total_cmd_correct / max(1, total_cmd_samples)
    dist_acc = total_dist_correct / max(1, total_dist_samples)

    return avg_cmd_loss, cmd_acc, avg_dist_loss, dist_acc


@torch.no_grad()
def evaluate_kws_multitask(model, loader, device):
    """
    Accuracy of the command (KWS) head of a two-headed model.

    Puts `model` in eval mode, runs every `(X, y)` batch of `loader` on
    `device` and returns correct / total (0.0 for an empty loader).
    """
    model.eval()
    n_hits, n_seen = 0, 0

    for inputs, targets in loader:
        inputs, targets = inputs.to(device), targets.to(device)
        command_logits, _ = model(inputs)   # distress logits are ignored
        n_hits += (command_logits.argmax(dim=1) == targets).sum().item()
        n_seen += targets.size(0)

    return n_hits / max(1, n_seen)


@torch.no_grad()
def evaluate_distress_multitask(model, loader, device):
    """
    Accuracy of the distress head of a two-headed model.

    Puts `model` in eval mode, runs every `(X, y)` batch of `loader` on
    `device` and returns correct / total (0.0 for an empty loader).
    """
    model.eval()
    n_hits, n_seen = 0, 0

    for inputs, targets in loader:
        inputs, targets = inputs.to(device), targets.to(device)
        _, distress_logits = model(inputs)   # command logits are ignored
        n_hits += (distress_logits.argmax(dim=1) == targets).sum().item()
        n_seen += targets.size(0)

    return n_hits / max(1, n_seen)


In [None]:
# ==============================
# 17b. RUN MULTI-TASK TRAINING
# ==============================

# Reuse KWS loaders from EXP 2 (noise-trained log-Mel)
# These three loaders are not created in this cell; they must already exist in the kernel.
kws_train_loader_mt = train_loader_noise
kws_val_loader_mt   = val_loader_logmel_noise
kws_test_loader_mt  = test_loader_logmel_noise

# New, untrained shared-encoder model with one head per task.
multitask_model = MultiTaskGuardianModel(
    num_command_classes=NUM_CLASSES,
    num_distress_classes=2,
).to(device)

cmd_criterion = nn.CrossEntropyLoss()
dist_criterion = nn.CrossEntropyLoss()
optimizer_mt = torch.optim.Adam(multitask_model.parameters(), lr=1e-3)

num_epochs_mt = 5
lambda_dist = 1.0   # weight of the distress loss in the combined objective

for epoch in range(1, num_epochs_mt + 1):
    cmd_loss, cmd_acc, dist_loss, dist_acc = train_multitask_epoch(
        multitask_model,
        kws_train_loader_mt,
        dist_train_loader,
        optimizer_mt,
        cmd_criterion,
        dist_criterion,
        device,
        lambda_dist=lambda_dist,
    )

    # Validation is run per task, each on its own loader.
    cmd_val_acc = evaluate_kws_multitask(multitask_model, kws_val_loader_mt, device)
    dist_val_acc = evaluate_distress_multitask(multitask_model, dist_val_loader, device)

    print(f"[MULTI-TASK] Epoch {epoch:02d}: "
          f"KWS train loss {cmd_loss:.4f}, train acc {cmd_acc*100:.2f}% | "
          f"Distress train loss {dist_loss:.4f}, train acc {dist_acc*100:.2f}% || "
          f"KWS val acc {cmd_val_acc*100:.2f}% | Distress val acc {dist_val_acc*100:.2f}%")

# Test accuracy is measured once, with the weights from the final epoch.
kws_test_acc_mt = evaluate_kws_multitask(multitask_model, kws_test_loader_mt, device)
dist_test_acc_mt = evaluate_distress_multitask(multitask_model, dist_test_loader, device)

print(f"[MULTI-TASK] Final KWS test acc: {kws_test_acc_mt*100:.2f}%")
print(f"[MULTI-TASK] Final Distress test acc: {dist_test_acc_mt*100:.2f}%")


In [None]:
# ==============================
# 18. REAL-TIME AUDIO LAYER + DEPLOYMENT + VISUALIZATION (UPDATED)
#     - Robust noise loading (recursive) 
#     - Front-end conditioning (bandpass + AGC)
#     - VAD + endpointing (energy-based)
#     - Streaming inference simulation + wake-gate
#     - Export (TorchScript + ONNX) + quantization demo
#     - Latency profiling (CUDA + MPS sync) + model size
#     - Better visualizations (raw + normalized confusion matrices, VAD energy plot)
# ==============================

import time

# ---------- 18.1 Robust noise loader (override) + reload ----------
def load_noise_waveforms(noise_dir: str, sample_rate: int = SAMPLE_RATE):
    """
    Robust noise loader.

    - searches `noise_dir` recursively
    - matches the `.wav` suffix case-insensitively
    - visits files in sorted path order, so the returned list (and therefore
      any seeded `random.choice` over it) is reproducible across filesystems
    - converts to mono, resamples to `sample_rate`, removes the DC offset
    - skips files that fail to load or contain no samples

    Args:
        noise_dir: directory to scan.
        sample_rate: target sample rate in Hz.

    Returns:
        list of waveforms, each [1, T_noise]; empty if the directory is
        missing or holds no usable file.
    """
    noise_waveforms = []
    noise_dir_path = Path(noise_dir)

    if not noise_dir_path.exists():
        print(f"[18][WARN] Noise directory does not exist: {noise_dir}")
        return noise_waveforms

    # rglob order is filesystem-dependent; sort for determinism.
    candidates = sorted(
        p for p in noise_dir_path.rglob("*")
        if p.is_file() and p.suffix.lower() == ".wav"
    )

    for wav_path in candidates:
        try:
            wav, sr = torchaudio.load(str(wav_path))

            # An empty file would turn `wav.mean()` below into NaN.
            if wav.numel() == 0:
                print(f"[18][WARN] Skipping empty noise file {wav_path}")
                continue

            # force mono [1, T]
            if wav.ndim == 2 and wav.shape[0] > 1:
                wav = wav.mean(dim=0, keepdim=True)
            elif wav.ndim == 1:
                wav = wav.unsqueeze(0)

            if sr != sample_rate:
                wav = torchaudio.functional.resample(wav, sr, sample_rate)

            wav = wav - wav.mean()  # remove DC offset
            noise_waveforms.append(wav)
        except Exception as e:
            # Best effort: one unreadable file must not abort the whole scan.
            print(f"[18][WARN] Failed noise file {wav_path}: {e}")

    print(f"[18] Loaded {len(noise_waveforms)} noise files from {noise_dir} (recursive)")
    return noise_waveforms

# Rebuild the global noise pool with the recursive loader; it is picked up by
# streaming demos / noisy test loaders created AFTER this point.
NOISE_WAVEFORMS = load_noise_waveforms(NOISE_DIR, sample_rate=SAMPLE_RATE)
if not NOISE_WAVEFORMS:
    print("[18][WARN] NOISE_WAVEFORMS is empty. Noise mixing will be disabled in demos/loaders built after this.")


# ---------- 18.2 Front-end conditioning: bandpass + AGC ----------
def _to_mono_1ch(wav: torch.Tensor) -> torch.Tensor:
    """Ensure mono [1, T]."""
    if wav.ndim == 2 and wav.shape[0] > 1:
        wav = wav.mean(dim=0, keepdim=True)
    elif wav.ndim == 1:
        wav = wav.unsqueeze(0)
    return wav

def agc_rms_normalize(wav: torch.Tensor, target_dbfs: float = -20.0, eps: float = 1e-8):
    """
    Simple AGC: scale the signal so its RMS matches `target_dbfs`.

    The input is assumed to lie in [-1, 1]; the scaled result is clamped back
    into that range. `eps` is added under the square root so a silent input
    does not divide by zero.
    """
    mono = _to_mono_1ch(wav)
    current_rms = torch.sqrt(torch.mean(mono ** 2) + eps)
    desired_rms = 10 ** (target_dbfs / 20.0)   # dBFS -> linear amplitude
    scaled = mono * (desired_rms / current_rms)
    return torch.clamp(scaled, -1.0, 1.0)

def simple_bandpass_condition(wav: torch.Tensor, sr: int = SAMPLE_RATE):
    """
    Cheap speech-band conditioning (NOT a real denoiser).

    Applies two biquad filters in sequence:
      - highpass at 80 Hz
      - lowpass  at 7.6 kHz
    """
    mono = _to_mono_1ch(wav)
    highpassed = torchaudio.functional.highpass_biquad(mono, sr, cutoff_freq=80.0)
    return torchaudio.functional.lowpass_biquad(highpassed, sr, cutoff_freq=7600.0)

def frontend_process(wav: torch.Tensor, sr: int = SAMPLE_RATE,
                     do_bandpass: bool = True, do_agc: bool = True):
    """
    Front-end conditioning chain: mono -> optional band-pass -> optional AGC.

    The AGC stage always targets -20 dBFS.
    """
    conditioned = _to_mono_1ch(wav)
    conditioned = simple_bandpass_condition(conditioned, sr) if do_bandpass else conditioned
    conditioned = agc_rms_normalize(conditioned, target_dbfs=-20.0) if do_agc else conditioned
    return conditioned


# ---------- 18.3 VAD + endpointing (energy-based) ----------
def vad_energy_segments(
    wav: torch.Tensor,
    sr: int = SAMPLE_RATE,
    frame_ms: float = 30.0,
    hop_ms: float = 10.0,
    energy_threshold_ratio: float = 2.5,
    min_speech_ms: float = 200.0,
    hangover_ms: float = 300.0,
):
    """
    Energy-based voice activity detection with onset confirmation and hangover.

    A frame counts as speech when its mean squared amplitude exceeds
    `median_frame_energy * energy_threshold_ratio` (robust-ish). A segment
    opens after `min_speech_ms` of consecutive speech frames and closes after
    `hangover_ms` of consecutive non-speech frames; the closing hangover is
    included in the segment.

    Args:
        wav: waveform; reduced to mono via `_to_mono_1ch`.
        sr: sample rate of `wav` in Hz.
        frame_ms: analysis frame length in milliseconds.
        hop_ms: hop between frame starts in milliseconds.
        energy_threshold_ratio: multiplier applied to the median frame energy.
        min_speech_ms: speech duration required before a segment opens.
        hangover_ms: silence duration required before a segment closes.

    Returns:
        `(segments, meta)`:
          - segments: list of `(start_sample, end_sample)` tuples.
          - meta: dict with "energies", "idxs" (frame start samples), "thr",
            "frame", "hop", "frame_ms", "hop_ms" and "sr"; `None` when the
            signal is shorter than one frame or the frame/hop size is zero
            (in which case `segments` is `[]`).
    """
    wav = _to_mono_1ch(wav)
    x = wav[0]

    frame = int(sr * frame_ms / 1000.0)
    hop = int(sr * hop_ms / 1000.0)
    if frame <= 0 or hop <= 0 or x.numel() < frame:
        return [], None

    # Per-frame energy over every full frame.
    energies, idxs = [], []
    for i in range(0, x.numel() - frame + 1, hop):
        chunk = x[i:i+frame]
        energies.append((chunk ** 2).mean().item())
        idxs.append(i)

    # Small offset keeps the threshold positive for digitally silent input.
    med = float(np.median(energies)) + 1e-12
    thr = med * energy_threshold_ratio

    min_speech_frames = max(1, int(min_speech_ms / hop_ms))
    hangover_frames = max(1, int(hangover_ms / hop_ms))

    segments = []
    in_speech = False
    speech_start = None
    speech_frames = 0
    silence_frames = 0

    for k, e in enumerate(energies):
        is_speech = e > thr

        if not in_speech:
            if is_speech:
                speech_frames += 1
                if speech_frames >= min_speech_frames:
                    in_speech = True
                    # Back-date the start to the first frame of the run.
                    speech_start = idxs[max(0, k - min_speech_frames + 1)]
                    silence_frames = 0
            else:
                # The run of speech frames was interrupted: start over.
                speech_frames = 0
        else:
            if is_speech:
                silence_frames = 0
            else:
                silence_frames += 1
                if silence_frames >= hangover_frames:
                    speech_end = idxs[k] + frame
                    segments.append((speech_start, min(speech_end, x.numel())))
                    in_speech = False
                    speech_start = None
                    speech_frames = 0
                    silence_frames = 0

    # Signal ended while still inside a segment: close it at the last sample.
    if in_speech and speech_start is not None:
        segments.append((speech_start, x.numel()))

    meta = {
        "energies": np.array(energies, dtype=np.float32),
        "idxs": np.array(idxs, dtype=np.int64),
        "thr": float(thr),
        "frame": frame,
        "hop": hop,
        "frame_ms": frame_ms,
        "hop_ms": hop_ms,
        # Recorded so consumers can convert "idxs" to seconds without
        # assuming the global sample rate.
        "sr": sr,
    }
    return segments, meta


# ---------- 18.4 Visualization utilities ----------
def plot_confusion_matrix(cm: np.ndarray, class_names, title: str, normalize: bool = False):
    """
    Draw a confusion matrix as a heat map (rows = true, columns = predicted).

    If normalize=True each row is rescaled to percentages of its true class
    and " (Normalized %)" is appended to the title.
    """
    values = cm.astype(np.float32)
    heading = title

    if normalize:
        # Tiny offset keeps all-zero rows from dividing by zero.
        per_row = values.sum(axis=1, keepdims=True) + 1e-12
        values = (values / per_row) * 100.0  # percent
        heading = title + " (Normalized %)"

    fig, ax = plt.subplots(figsize=(9, 7))
    image = ax.imshow(values, interpolation="nearest")
    ax.set_title(heading)
    fig.colorbar(image, ax=ax)

    positions = np.arange(len(class_names))
    ax.set_xticks(positions)
    ax.set_xticklabels(class_names, rotation=45, ha="right")
    ax.set_yticks(positions)
    ax.set_yticklabels(class_names)
    ax.set_xlabel("Predicted")
    ax.set_ylabel("True")

    fig.tight_layout()
    plt.show()

def plot_vad_segments(wav: torch.Tensor, sr: int, segments, title="VAD Segments"):
    """
    Plot the waveform over time and shade every `(start_sample, end_sample)`
    span in `segments`.
    """
    samples = _to_mono_1ch(wav)[0].cpu().numpy()
    seconds = np.arange(len(samples)) / sr

    fig, ax = plt.subplots(figsize=(12, 3))
    ax.plot(seconds, samples)
    for seg_start, seg_end in segments:
        ax.axvspan(seg_start / sr, seg_end / sr, alpha=0.25)
    ax.set_title(title)
    ax.set_xlabel("Time (s)")
    ax.set_ylabel("Amplitude")
    fig.tight_layout()
    plt.show()

def plot_vad_energy(meta, title="VAD Frame Energy"):
    """
    Plot per-frame energy against time with the VAD threshold as a dashed line.

    Args:
        meta: dict with "energies", "idxs" (frame start samples) and "thr",
            or `None`, in which case a notice is printed and nothing is drawn.
            An optional "sr" entry gives the sample rate used to convert
            "idxs" to seconds; the global SAMPLE_RATE is the fallback.
        title: figure title.
    """
    if meta is None:
        print(" No VAD meta to plot.")
        return
    energies = meta["energies"]
    idxs = meta["idxs"]
    thr = meta["thr"]

    # Use the rate the VAD actually ran at when it is recorded; assuming the
    # global rate would mis-scale the time axis for other sample rates.
    sr = meta.get("sr", SAMPLE_RATE)

    times = idxs / sr
    plt.figure(figsize=(12, 3))
    plt.plot(times, energies)
    plt.axhline(thr, linestyle="--", linewidth=1)
    plt.title(title)
    plt.xlabel("Time (s)")
    plt.ylabel("Frame Energy")
    plt.tight_layout()
    plt.show()

def plot_logmel(feat: torch.Tensor, title="Log-Mel"):
    """
    Show the first channel of a feature tensor as an image
    (x = frames, y = mel bins, low bins at the bottom).
    """
    # feat: [1, F, T]
    spectrogram = feat[0].detach().cpu().numpy()

    fig, ax = plt.subplots(figsize=(10, 3))
    image = ax.imshow(spectrogram, aspect="auto", origin="lower")
    ax.set_title(title)
    ax.set_xlabel("Frames")
    ax.set_ylabel("Mel bins")
    fig.colorbar(image, ax=ax)
    fig.tight_layout()
    plt.show()


# ---------- 18.5 Streaming inference simulation + wake-gate ----------
def streaming_kws_demo(
    wav_path: str,
    model: nn.Module,
    wake_gate_class: str = "go",
    wake_prob_thresh: float = 0.85,
    step_ms: float = 100.0,
    window_sec: float = 1.0,
    do_frontend: bool = True,
    do_viz: bool = True,
):
    """
    Offline streaming demo:
      - front-end conditioning
      - VAD segmentation
      - slide `window_sec` windows (hop `step_ms`) through VAD speech segments;
        a segment shorter than one window is scored once, zero-padded on the
        right to the window length, instead of being skipped
      - trigger if predicted class == wake_gate_class with confidence >= threshold

    NOTE: This is NOT a trained wake-word detector. It's a demo-friendly gate.

    Args:
        wav_path: audio file to analyse.
        model: classifier returning logits of shape [B, num_classes].
        wake_gate_class: class name (must be in ALL_CLASSES) that fires the gate.
        wake_prob_thresh: minimum softmax confidence for a trigger.
        step_ms: hop between consecutive windows, in milliseconds.
        window_sec: window length in seconds; must be positive.
        do_frontend: apply `frontend_process` before VAD and inference.
        do_viz: draw the VAD and trigger plots.

    Returns:
        list of dicts {"time": float, "pred": str, "conf": float}; empty when
        the file is missing, the wake class is unknown or nothing triggered.

    Raises:
        ValueError: if `window_sec` does not yield at least one sample.
    """
    if not os.path.isfile(wav_path):
        print(f" File not found: {wav_path}")
        return []

    wav, sr = torchaudio.load(wav_path)
    wav = _to_mono_1ch(wav)

    if sr != SAMPLE_RATE:
        wav = torchaudio.functional.resample(wav, sr, SAMPLE_RATE)
        sr = SAMPLE_RATE

    if do_frontend:
        wav = frontend_process(wav, sr)

    segments, meta = vad_energy_segments(wav, sr)
    print(f" VAD segments found: {len(segments)}")

    if do_viz:
        plot_vad_segments(wav, sr, segments, title=f"VAD Segments: {os.path.basename(wav_path)}")
        plot_vad_energy(meta, title=f"VAD Energy: {os.path.basename(wav_path)}")

    wake_idx = CLASS_TO_INDEX.get(wake_gate_class, None)
    if wake_idx is None:
        print(f" wake_gate_class '{wake_gate_class}' not in ALL_CLASSES.")
        return []

    win = int(sr * window_sec)
    if win <= 0:
        raise ValueError(f"window_sec must be positive, got {window_sec}")
    # A zero (or negative) hop would never advance the window.
    step = max(1, int(sr * step_ms / 1000.0))

    model.eval()
    triggers = []  # dicts: {"time": float, "pred": str, "conf": float}

    for (s0, s1) in segments:
        if s1 - s0 < win:
            # Short utterance (typical for a single keyword): score it once.
            starts = [s0]
        else:
            starts = range(s0, s1 - win + 1, step)

        for pos in starts:
            chunk = wav[:, pos:min(pos + win, s1)]
            missing = win - chunk.shape[-1]
            if missing > 0:
                # Right-pad with silence so the model sees a full window.
                chunk = nn.functional.pad(chunk, (0, missing))

            feat = waveform_to_features(chunk, mode="logmel")  # [1,F,T]
            X = feat.unsqueeze(0).to(device)                   # [1,1,F,T]

            with torch.no_grad():
                logits = model(X)
                probs = torch.softmax(logits, dim=1)[0]
                conf, pred = torch.max(probs, dim=0)

            pred_name = ALL_CLASSES[pred.item()]
            if pred.item() == wake_idx and conf.item() >= wake_prob_thresh:
                tsec = pos / sr
                triggers.append({"time": tsec, "pred": pred_name, "conf": conf.item()})
                print(f"[WAKE-GATE] {tsec:6.2f}s -> {pred_name} ({conf.item()*100:5.1f}%)")

    if do_viz:
        if len(triggers) > 0:
            times = [d["time"] for d in triggers]
            confs = [d["conf"] * 100 for d in triggers]
            plt.figure(figsize=(10, 2.5))
            plt.scatter(times, confs)
            plt.title("Wake-Gate Triggers Over Time")
            plt.xlabel("Time (s)")
            plt.ylabel("Confidence (%)")
            plt.grid(True)
            plt.tight_layout()
            plt.show()
        else:
            print(" No wake-gate triggers at current threshold.")

    return triggers


# ---------- 18.6 Deployment: quantization + export ----------
def quantize_for_cpu(model: nn.Module):
    """
    Dynamic quantization demo: affects Linear layers only (limited gains here).

    Works on a deep copy, so the caller's model keeps its device and its
    train/eval mode (`nn.Module.to` and `.eval()` modify a module in place).

    Args:
        model: model to quantize; left untouched.

    Returns:
        A new CPU model in eval mode whose `nn.Linear` layers are dynamically
        quantized to int8.
    """
    import copy  # local import: only needed by the deployment helpers

    model_cpu = copy.deepcopy(model).to("cpu").eval()
    qmodel = torch.quantization.quantize_dynamic(
        model_cpu,
        {nn.Linear},
        dtype=torch.qint8
    )
    return qmodel

def export_torchscript(model: nn.Module, out_path="guardian_voice_kws.pt"):
    """
    Trace `model` on CPU with a dummy [1, 1, N_MELS, 101] input and save the
    TorchScript module to `out_path`.

    The trace is taken from a deep copy, so the caller's model keeps its
    device and train/eval mode (`nn.Module.to` / `.eval()` act in place).
    """
    import copy  # local import: only needed by the deployment helpers

    model_cpu = copy.deepcopy(model).to("cpu").eval()
    example = torch.randn(1, 1, N_MELS, 101)
    traced = torch.jit.trace(model_cpu, example)
    traced.save(out_path)
    print(f" Saved TorchScript: {out_path}")

def export_onnx(model: nn.Module, out_path="guardian_voice_kws.onnx"):
    """
    Export `model` to ONNX (opset 17) using a dummy [1, 1, N_MELS, 101] input.

    The batch axis of input/output and the time axis of the input are marked
    dynamic. The export runs on a deep copy, so the caller's model keeps its
    device and train/eval mode. Export failures are reported as a warning
    instead of being raised (best-effort demo step).
    """
    import copy  # local import: only needed by the deployment helpers

    model_cpu = copy.deepcopy(model).to("cpu").eval()
    example = torch.randn(1, 1, N_MELS, 101)
    try:
        torch.onnx.export(
            model_cpu,
            example,
            out_path,
            input_names=["x"],
            output_names=["logits"],
            dynamic_axes={"x": {0: "batch", 3: "time"}, "logits": {0: "batch"}},
            opset_version=17
        )
        print(f" Saved ONNX: {out_path}")
    except Exception as e:
        print(f"[WARN] ONNX export failed: {e}")


# ---------- 18.7 Profiling: latency + model size ----------
def _count_params(model: nn.Module) -> int:
    return sum(p.numel() for p in model.parameters())

def _model_size_mb_state_dict(model: nn.Module) -> float:
    """
    Rough size: sum of state_dict tensor bytes / MB.
    """
    total_bytes = 0
    sd = model.state_dict()
    for k, v in sd.items():
        if torch.is_tensor(v):
            total_bytes += v.numel() * v.element_size()
    return total_bytes / (1024 ** 2)

def _sync_device(device: torch.device) -> None:
    """Wait for queued work on an asynchronous backend (CUDA / MPS); no-op on CPU."""
    if device.type == "cuda":
        torch.cuda.synchronize()
    elif device.type == "mps":
        torch.mps.synchronize()


def benchmark_forward(model: nn.Module, device: torch.device, reps: int = 300, warmup: int = 50):
    """
    Forward-only latency benchmark on a dummy [1, 1, N_MELS, 101] input.

    Runs `warmup` untimed passes, then times `reps` passes with device
    synchronisation before and after the timed loop (CUDA and MPS execute
    asynchronously). All passes run under `torch.no_grad()` so the numbers
    reflect inference rather than autograd bookkeeping. On CUDA the peak
    allocated memory during the timed loop is reported as well.

    Side effect: `model` is moved to `device` and put in eval mode.

    Prints the average latency, the parameter count and the state-dict size;
    returns None.
    """
    model = model.to(device).eval()
    x = torch.randn(1, 1, N_MELS, 101, device=device)
    on_cuda = device.type == "cuda"

    with torch.no_grad():
        # warmup
        for _ in range(warmup):
            model(x)

        if on_cuda:
            torch.cuda.reset_peak_memory_stats()
        _sync_device(device)
        t0 = time.perf_counter()
        for _ in range(reps):
            model(x)
        _sync_device(device)
        t1 = time.perf_counter()

    # max(1, reps) avoids a ZeroDivisionError for reps == 0.
    ms = (t1 - t0) * 1000 / max(1, reps)

    if on_cuda:
        mem = torch.cuda.max_memory_allocated() / (1024**2)
        print(f"[CUDA] forward avg: {ms:.3f} ms | peak mem: {mem:.1f} MB")
    elif device.type == "mps":
        print(f"[MPS] forward avg: {ms:.3f} ms")
    else:
        print(f"[CPU] forward avg: {ms:.3f} ms")

    print(f" Params: {_count_params(model):,} | StateDict size: {_model_size_mb_state_dict(model):.2f} MB")

def benchmark_end_to_end_kws(model: nn.Module, device: torch.device, reps: int = 50):
    """
    End-to-end: waveform -> features -> forward.
    This is closer to real streaming cost than forward-only.

    A synthetic 1-second noise clip is conditioned once by `frontend_process`
    (untimed). Each of the `reps` timed iterations then covers feature
    extraction, the host-to-device copy and one forward pass, followed by a
    device sync on CUDA/MPS. The loop runs under `torch.no_grad()` so autograd
    overhead is not counted.

    Side effect: `model` is moved to `device` and put in eval mode.

    Prints the mean and 95th-percentile latency in ms; returns None.
    """
    model = model.to(device).eval()
    wav = torch.randn(1, CLIP_NUM_SAMPLES) * 0.05  # 1 sec synthetic "audio"
    wav = frontend_process(wav, SAMPLE_RATE)       # front-end on CPU
    wav = wav.to("cpu")

    # time features on CPU + forward on device
    times = []
    with torch.no_grad():
        for _ in range(reps):
            t0 = time.perf_counter()
            feat = waveform_to_features(wav, mode="logmel")        # [1,F,T] CPU
            X = feat.unsqueeze(0).to(device)                      # [1,1,F,T]
            _ = model(X)                                          # forward
            # sync if needed
            if device.type == "cuda":
                torch.cuda.synchronize()
            elif device.type == "mps":
                torch.mps.synchronize()
            t1 = time.perf_counter()
            times.append((t1 - t0) * 1000)

    print(f" End-to-end (feat+forward) avg: {np.mean(times):.2f} ms | p95: {np.percentile(times, 95):.2f} ms")


# ---------- 18.8 Safety checks (non-production heuristics) ----------
def simple_replay_risk_checks(wav: torch.Tensor, sr: int = SAMPLE_RATE):
    """
    NON-PRODUCTION heuristics:
      - clipping rate (fraction of samples with |x| > 0.98)
      - energy coefficient-of-variation over 30 ms frames with a 10 ms hop
        (low variability can indicate replay/processed audio)

    Args:
        wav: waveform; reduced to mono via `_to_mono_1ch`.
        sr: sample rate in Hz.

    Returns:
        list of human-readable warning strings (empty when nothing is flagged,
        including for an empty signal).
    """
    wav = _to_mono_1ch(wav)
    x = wav[0]

    if x.numel() == 0:
        # Nothing to measure; avoids NaN statistics from empty reductions.
        return []

    clip_rate = (x.abs() > 0.98).float().mean().item()

    # max(1, ...) keeps very low sample rates from producing a zero frame/hop.
    frame = max(1, int(sr * 0.03))
    hop = max(1, int(sr * 0.01))
    energies = []
    # `- frame + 1` includes the last full frame; a signal shorter than one
    # frame is still analysed as a single partial frame.
    for i in range(0, max(1, x.numel() - frame + 1), hop):
        chunk = x[i:i+frame]
        energies.append((chunk**2).mean().item())
    energies = np.array(energies) if len(energies) else np.array([0.0])

    e_std = float(np.std(energies))
    e_mean = float(np.mean(energies) + 1e-12)  # offset guards against division by zero
    cv = e_std / e_mean

    warnings = []
    if clip_rate > 0.01:
        warnings.append(f"High clipping rate: {clip_rate*100:.2f}%")
    if cv < 0.15:
        warnings.append(f"Low energy variability (CV={cv:.2f}) — possible replay/processed audio")
    return warnings


# ---------- 18.9 Better confusion-matrix visualizations (raw + normalized) ----------
# KWS CM plots (SpecAug best model)
try:
    model_spec.eval()
    all_y_kws, all_pred_kws = [], []

    with torch.no_grad():
        for X, y in test_loader_logmel_spec:
            X = X.to(device)
            y = y.to(device)
            logits = model_spec(X)
            preds = logits.argmax(dim=1)
            all_y_kws.append(y.cpu().numpy())
            all_pred_kws.append(preds.cpu().numpy())

    all_y_kws = np.concatenate(all_y_kws)
    all_pred_kws = np.concatenate(all_pred_kws)
    # Pin the label set so the matrix always has one row/column per entry of
    # ALL_CLASSES; otherwise a class absent from both targets and predictions
    # would shrink the matrix and misalign the tick labels.
    cm_kws = confusion_matrix(all_y_kws, all_pred_kws, labels=np.arange(len(ALL_CLASSES)))

    plot_confusion_matrix(cm_kws, ALL_CLASSES, "KWS Confusion Matrix (SpecAug Log-Mel)", normalize=False)
    plot_confusion_matrix(cm_kws, ALL_CLASSES, "KWS Confusion Matrix (SpecAug Log-Mel)", normalize=True)

except Exception as e:
    # Best effort: this cell may be run without the KWS experiment in memory.
    print(" Skipping KWS CM plots (model_spec/test_loader not available):", e)

# Distress CM plots (dist_model)
try:
    dist_model.eval()
    all_y_d, all_pred_d = [], []

    with torch.no_grad():
        for X, y in dist_test_loader:
            X = X.to(device)
            y = y.to(device)
            logits = dist_model(X)
            preds = logits.argmax(dim=1)
            all_y_d.append(y.cpu().numpy())
            all_pred_d.append(preds.cpu().numpy())

    all_y_d = np.concatenate(all_y_d)
    all_pred_d = np.concatenate(all_pred_d)
    # Pin both classes so the matrix is always 2x2 and lines up with the two
    # tick labels, even if one class never occurs in targets or predictions.
    cm_d = confusion_matrix(all_y_d, all_pred_d, labels=[0, 1])

    plot_confusion_matrix(cm_d, ["non_distress", "distress"], "Distress Confusion Matrix", normalize=False)
    plot_confusion_matrix(cm_d, ["non_distress", "distress"], "Distress Confusion Matrix", normalize=True)

except Exception as e:
    # Best effort: this cell may be run without the distress experiment in memory.
    print(" Skipping Distress CM plots (dist_model/dist_test_loader not available):", e)


# ---------- 18.10 Quick run: benchmark + optional exports ----------
try:
    print("\n Forward benchmark (guardian_model if defined, else model_spec)...")
    # Prefer the demo alias when that cell has been run; otherwise use model_spec directly.
    _m = guardian_model if "guardian_model" in globals() else model_spec
    # Both benchmarks call model.to(device).eval() on the model they are given.
    benchmark_forward(_m, device)
    benchmark_end_to_end_kws(_m, device)

    # Optional exports
    # export_torchscript(_m, "guardian_voice_kws.pt")
    # export_onnx(_m, "guardian_voice_kws.onnx")

    # Optional quantization demo (CPU only)
    # q_m = quantize_for_cpu(_m)
    # print("[18] Quantized model (CPU) ready.")

except Exception as e:
    # Best effort: any failure (including a missing model) is reported, not raised.
    print(" Benchmark/export section skipped:", e)