# Coudwalk Technical Challenge

## Objective:
Your task is to build a lightweight prototype that listens to spoken digits (0–9) and predicts the correct number. The goal is to find the lightest effective solution. Live microphone input to test your model in real time. This helps explore real-world performance, including latency, noise handling, and usability under less controlled conditions.

# Data loading & Preprocessing

In [19]:
# Data loading & Preprocessing for DS-CNN on Free Spoken Digit Dataset (HF: mteb/free-spoken-digit-dataset)
# - Loads dataset and derives label names directly from HF features
# - Converts audio to Log-Mel spectrograms suitable for DS-CNN: shape [1, n_mels, T]
# - Provides modular utilities and a robust collate function
# - Exposes: raw (DatasetDict), label_names, id2label, label2id, mapped_splits (DatasetDict), collate

# Ensure required packages
try:
    from datasets import load_dataset, Audio, DatasetDict
except ModuleNotFoundError:
    !pip install datasets[audio]
    from datasets import load_dataset, Audio, DatasetDict

import torch
import numpy as np

# Torchaudio for spectrograms
try:
    import torchaudio
    from torchaudio.transforms import MelSpectrogram
    from torchaudio.functional import amplitude_to_DB
except ModuleNotFoundError:
    !pip install torchaudio
    import torchaudio
    from torchaudio.transforms import MelSpectrogram
    from torchaudio.functional import amplitude_to_DB


def get_or_default(name: str, default):
    """Fetch a global variable by name if it exists; otherwise return default."""
    return globals()[name] if name in globals() else default


# Configuration (use existing globals if already defined elsewhere in the notebook)
target_sample_rate: int = get_or_default("target_sample_rate", 16000)
n_mels: int = get_or_default("n_mels", 40)
n_fft_val: int = get_or_default("n_fft_val", 512)
win_length: int = get_or_default("win_length", int(0.025 * target_sample_rate))  # ~25ms
hop_length: int = get_or_default("hop_length", int(0.010 * target_sample_rate))  # ~10ms
batch_size: int = get_or_default("batch_size", 128)
num_workers: int = get_or_default("num_workers", 2)

# 1) Load dataset and derive label names (no hardcoding)
raw: DatasetDict = load_dataset("mteb/free-spoken-digit-dataset")
label_names = raw["train"].features["label"].names
id2label = {i: name for i, name in enumerate(label_names)}
label2id = {name: i for i, name in id2label.items()}

# Cast/Decode audio at target sampling rate
raw = raw.cast_column("audio", Audio(sampling_rate=target_sample_rate))

# 2) Preprocess: audio -> Log-Mel spectrogram
# Reuse existing mel transform if available, else create a new one
if "melspec" in globals() and isinstance(globals()["melspec"], MelSpectrogram):
    melspec: MelSpectrogram = globals()["melspec"]
else:
    melspec = MelSpectrogram(
        sample_rate=target_sample_rate,
        n_fft=n_fft_val,
        win_length=win_length,
        hop_length=hop_length,
        f_min=0.0,
        f_max=target_sample_rate // 2,
        n_mels=n_mels,
        center=True,
        power=2.0,  # power spectrogram
        norm="slaney",
        mel_scale="htk",
    )


def waveform_to_logmel(waveform: torch.Tensor, sample_rate: int) -> torch.Tensor:
    """
    Convert mono waveform [T] or [1, T] to Log-Mel spectrogram [1, n_mels, time].
    """
    if waveform.dim() == 1:
        waveform = waveform.unsqueeze(0)  # [1, T]
    # Ensure target sample rate; dataset cast handles this, but keep signature consistent
    with torch.no_grad():
        mel = melspec(waveform)  # [1, n_mels, time]
        # Convert to log scale (dB). Add small offset to avoid log(0).
        log_mel = amplitude_to_DB(mel.clamp_min(1e-10), multiplier=10.0, amin=1e-10, db_multiplier=0.0)
        # Optional per-utterance normalization (stabilizes training for small models)
        mean = log_mel.mean(dim=(-1, -2), keepdim=True)
        std = log_mel.std(dim=(-1, -2), keepdim=True).clamp_min(1e-5)
        log_mel = (log_mel - mean) / std
    return log_mel  # [1, n_mels, time]


def _map_example_to_features(batch):
    """
    HF map function: takes an example with keys {"audio": {"array", "sampling_rate"}, "label"}
    Returns dict with "input_values": float32 tensor-like [1, n_mels, time] and "label": int
    """
    arr = batch["audio"]["array"]
    sr = batch["audio"]["sampling_rate"]
    # Convert to torch waveform
    wf = torch.tensor(arr, dtype=torch.float32)
    # Compute log-mel
    features = waveform_to_logmel(wf, sr)
    # Store as list to be HF-serializable; collate will convert back to tensor
    return {
        "input_values": features.squeeze(0).numpy()[None, ...].astype(np.float32),  # [1, n_mels, T] as numpy
        "label": int(batch["label"]),
    }


# Apply mapping to all splits
mapped_splits: DatasetDict = raw.map(
    _map_example_to_features,
    remove_columns=[c for c in raw["train"].column_names if c not in ("label",)],
)


# 3) Collate for DS-CNN: pads along time dimension to max(T) in batch; output [B, 1, n_mels, T]
def collate(batch):
    """
    Collate function:
    - Accepts items with {"input_values": [1, n_mels, T], "label": int}
    - Right-pads time dimension to the max length in batch
    - Returns (features [B, 1, n_mels, T], labels [B])
    """
    feats, labels = [], []
    time_dim = -1
    for item in batch:
        x = item["input_values"]
        # Ensure torch tensor [1, n_mels, T]
        if isinstance(x, list):
            x = np.array(x, dtype=np.float32)
        if isinstance(x, np.ndarray):
            x = torch.from_numpy(x)
        if x.dim() == 2:
            x = x.unsqueeze(0)  # [1, n_mels, T]
        feats.append(x)
        labels.append(int(item["label"]))
    # Pad along time
    max_T = max(x.shape[time_dim] for x in feats)
    padded = []
    for x in feats:
        t = x.shape[time_dim]
        if t < max_T:
            pad_shape = list(x.shape)
            pad_shape[time_dim] = max_T - t
            pad = torch.zeros(pad_shape, dtype=x.dtype)
            x = torch.cat([x, pad], dim=time_dim)
        elif t > max_T:
            idx = [slice(None)] * x.dim()
            idx[time_dim] = slice(0, max_T)
            x = x[tuple(idx)]
        padded.append(x)
    features = torch.stack(padded, dim=0)  # [B, 1, n_mels, T]
    labels = torch.tensor(labels, dtype=torch.long)
    return features, labels


# Expose commonly used globals for downstream cells
globals().update({
    "raw": raw,
    "label_names": label_names,
    "id2label": id2label,
    "label2id": label2id,
    "mapped_splits": mapped_splits,
    "collate": collate,
    "target_sample_rate": target_sample_rate,
    "n_mels": n_mels,
    "n_fft_val": n_fft_val,
    "win_length": win_length,
    "hop_length": hop_length,
    "melspec": melspec,
})


# DS-CNN Model & Training

In [20]:
import os
from pathlib import Path
import torch
import torch.nn as nn
from torch.utils.data import DataLoader


class DSConvBlock(nn.Module):
    """
    Depthwise Separable Convolution Block:
    - Depthwise 2D convolution (groups=in_channels)
    - BatchNorm + ReLU
    - Pointwise 1x1 convolution
    - BatchNorm + ReLU
    """

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3, stride: tuple[int, int] = (1, 1),
                 dilation: int = 1):
        super().__init__()
        padding = ((kernel_size - 1) // 2) * dilation
        self.depthwise = nn.Conv2d(in_channels, in_channels, kernel_size=kernel_size, stride=stride, padding=padding,
                                   dilation=dilation, groups=in_channels, bias=False)
        self.dw_bn = nn.BatchNorm2d(in_channels)
        self.dw_relu = nn.ReLU(inplace=True)
        self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        self.pw_bn = nn.BatchNorm2d(out_channels)
        self.pw_relu = nn.ReLU(inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.depthwise(x)
        x = self.dw_bn(x)
        x = self.dw_relu(x)
        x = self.pointwise(x)
        x = self.pw_bn(x)
        x = self.pw_relu(x)
        return x


class DSCNN(nn.Module):
    """
    Depthwise Separable CNN for keyword spotting / digit recognition.
    Expects input of shape [B, 1, n_mels(=40), T].
    Architecture:
      - Stem conv
      - Stack of DSConv blocks with occasional stride for downsampling
      - Global average pooling
      - Linear classifier
    """

    def __init__(self, n_mels: int = 40, n_classes: int = 10, channels: tuple[int, ...] = (64, 64, 128, 128, 256)):
        super().__init__()
        c1, c2, c3, c4, c5 = channels
        self.stem = nn.Sequential(
            nn.Conv2d(1, c1, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(c1),
            nn.ReLU(inplace=True),
        )
        self.features = nn.Sequential(
            DSConvBlock(c1, c2, kernel_size=3, stride=(2, 2)),  # downsample both axes
            DSConvBlock(c2, c3, kernel_size=3, stride=(1, 2)),  # further reduce time
            DSConvBlock(c3, c4, kernel_size=3, stride=(2, 1)),  # reduce mel
            DSConvBlock(c4, c5, kernel_size=3, stride=(1, 1)),
        )
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Linear(c5, n_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.stem(x)
        x = self.features(x)
        x = self.pool(x).flatten(1)
        x = self.classifier(x)
        return x


def count_parameters(model: nn.Module) -> int:
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

In [21]:
from datasets import DatasetDict


def _build_splits_from_prepared(prepared: dict, val_size: float = 0.1) -> dict:
    """
    Create train/validation/test splits from prepared_splits.
    If validation not present, split from train with stratification.
    """
    splits = dict(prepared)
    if "validation" not in splits:
        tv = splits["train"].train_test_split(test_size=val_size, stratify_by_column="label")
        splits["train"], splits["validation"] = tv["train"], tv["test"]
    return splits


def _num_classes_from_dataset(ds) -> int:
    labs = list(set(ds["label"]))
    return int(max(labs) + 1)


def build_dataloaders(splits: dict, collate_fn, batch_size: int = 128, num_workers: int = 2) -> tuple[
    DataLoader, DataLoader, DataLoader]:
    train_loader = DataLoader(splits["train"], batch_size=batch_size, shuffle=True, num_workers=num_workers,
                              collate_fn=collate_fn)
    val_loader = DataLoader(splits["validation"], batch_size=batch_size, shuffle=False, num_workers=num_workers,
                            collate_fn=collate_fn)
    test_loader = DataLoader(splits["test"], batch_size=batch_size, shuffle=False, num_workers=num_workers,
                             collate_fn=collate_fn)
    return train_loader, val_loader, test_loader


prepared_with_val = _build_splits_from_prepared(mapped_splits, val_size=0.1)
n_classes = _num_classes_from_dataset(prepared_with_val["train"])
train_loader, val_loader, test_loader = build_dataloaders(prepared_with_val, collate, batch_size=128,
                                                          num_workers=2)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = DSCNN(n_mels=40, n_classes=n_classes).to(device)
param_count = count_parameters(model)

models_dir = Path("models")
models_dir.mkdir(parents=True, exist_ok=True)
best_ckpt_path = models_dir / "dscnn_best.pt"


# Training Loop

In [22]:
import os
import time
import math
from typing import Dict, List, Tuple, Callable, Optional

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt


def _macro_micro_from_confusion(conf_mat: torch.Tensor) -> Dict[str, float]:
    # conf_mat shape: [C, C] where rows = true, cols = pred
    cm = conf_mat.float()
    tp = torch.diag(cm)
    fp = cm.sum(0) - tp
    fn = cm.sum(1) - tp
    tn = cm.sum() - (tp + fp + fn)

    # Avoid division by zero
    eps = 1e-12

    precision_per_class = tp / torch.clamp(tp + fp, min=eps)
    recall_per_class = tp / torch.clamp(tp + fn, min=eps)
    f1_per_class = 2 * precision_per_class * recall_per_class / torch.clamp(precision_per_class + recall_per_class,
                                                                            min=eps)

    macro_precision = precision_per_class.mean().item()
    macro_recall = recall_per_class.mean().item()
    macro_f1 = f1_per_class.mean().item()

    # Micro = compute from totals
    tp_sum = tp.sum()
    fp_sum = fp.sum()
    fn_sum = fn.sum()
    micro_precision = (tp_sum / torch.clamp(tp_sum + fp_sum, min=eps)).item()
    micro_recall = (tp_sum / torch.clamp(tp_sum + fn_sum, min=eps)).item()
    micro_f1 = (2 * micro_precision * micro_recall / max(1e-12, (micro_precision + micro_recall))) if (
                                                                                                              micro_precision + micro_recall) > 0 else 0.0

    return {
        "precision_macro": macro_precision,
        "recall_macro": macro_recall,
        "f1_macro": macro_f1,
        "precision_micro": micro_precision,
        "recall_micro": micro_recall,
        "f1_micro": micro_f1,
    }


@torch.no_grad()
def evaluate_with_confusion(model: nn.Module, loader: DataLoader, device: torch.device, n_classes: int) -> Dict[
    str, float]:
    model.eval()
    total = 0
    correct = 0
    loss_sum = 0.0
    criterion = nn.CrossEntropyLoss()
    conf_mat = torch.zeros((n_classes, n_classes), dtype=torch.long, device=device)

    for feats, labels in loader:
        feats = feats.to(device)
        labels = labels.to(device)
        logits = model(feats)
        loss = criterion(logits, labels)

        preds = logits.argmax(1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)
        loss_sum += loss.item() * labels.size(0)

        # Update confusion matrix
        for t, p in zip(labels.view(-1), preds.view(-1)):
            conf_mat[t.long(), p.long()] += 1

    metrics = _macro_micro_from_confusion(conf_mat)
    metrics.update({
        "loss": loss_sum / max(1, total),
        "acc": correct / max(1, total),
    })
    return metrics


def train_one_epoch_with_metrics(model: nn.Module, loader: DataLoader, optimizer: torch.optim.Optimizer,
                                 device: torch.device, n_classes: int,
                                 scaler: Optional[torch.cuda.amp.GradScaler] = None) -> Dict[str, float]:
    model.train()
    criterion = nn.CrossEntropyLoss()
    total = 0
    loss_sum = 0.0
    correct = 0
    conf_mat = torch.zeros((n_classes, n_classes), dtype=torch.long, device=device)

    # Gradient clipping config
    max_norm = 1.0  # L2 norm clipping

    for b_idx, (feats, labels) in enumerate(loader):
        feats = feats.to(device)
        labels = labels.to(device)

        optimizer.zero_grad(set_to_none=True)
        if scaler is not None:
            with torch.cuda.amp.autocast():
                logits = model(feats)
                loss = criterion(logits, labels)
            scaler.scale(loss).backward()

            # Unscale then clip before stepping
            scaler.unscale_(optimizer)
            global_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
            scaler.step(optimizer)
            scaler.update()
        else:
            logits = model(feats)
            loss = criterion(logits, labels)
            loss.backward()

            # Clip before optimizer.step()
            global_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)

            optimizer.step()

        loss_sum += loss.item() * labels.size(0)
        preds = logits.argmax(1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

        for t, p in zip(labels.view(-1), preds.view(-1)):
            conf_mat[t.long(), p.long()] += 1

    base = {
        "loss": loss_sum / max(1, total),
        "acc": correct / max(1, total),
    }
    base.update(_macro_micro_from_confusion(conf_mat))
    return base


@torch.no_grad()
def measure_model_only_latency(model: nn.Module, loader: DataLoader, device: torch.device,
                               warmup_batches: int = 5, max_batches: Optional[int] = None) -> Dict[str, float]:
    # Measures forward-pass latency from already-precomputed features to logits.
    model.eval()
    times: List[float] = []
    seen = 0

    # Warmup
    it = iter(loader)
    for _ in range(warmup_batches):
        try:
            feats, _ = next(it)
        except StopIteration:
            break
        _ = model(feats.to(device))

    # Timed runs
    it = iter(loader)
    with torch.inference_mode():
        for b_idx, (feats, _) in enumerate(it):
            if max_batches is not None and b_idx >= max_batches:
                break
            feats = feats.to(device)
            torch.cuda.synchronize() if torch.cuda.is_available() else None
            t0 = time.perf_counter()
            _ = model(feats)
            torch.cuda.synchronize() if torch.cuda.is_available() else None
            t1 = time.perf_counter()
            batch_time = t1 - t0
            per_sample = batch_time / max(1, feats.size(0))
            times.append(per_sample)
            seen += feats.size(0)

    if len(times) == 0:
        return {"latency_mean_s": float("nan"), "latency_p50_s": float("nan"), "latency_p95_s": float("nan")}
    arr = np.array(times)
    return {
        "latency_mean_s": float(arr.mean()),
        "latency_p50_s": float(np.percentile(arr, 50)),
        "latency_p95_s": float(np.percentile(arr, 95)),
    }


@torch.no_grad()
def measure_end_to_end_latency(
        model: nn.Module,
        dataset,
        device: torch.device,
        featurize_fn: Callable[[torch.Tensor, int], torch.Tensor],
        sample_rate_getter: Optional[Callable[[int], int]] = None,
        num_samples: int = 100,
        warmup: int = 10,
) -> Dict[str, float]:
    """
    Measures end-to-end latency from raw audio -> features -> classification.

    Arguments:
      - dataset: indexable dataset where dataset[i] returns either (waveform, label) or (waveform, sample_rate, label)
      - featurize_fn: function that maps (waveform, sample_rate) -> features tensor shaped as model expects (C, T, ...) or (F, T)
      - sample_rate_getter: optional function to get sample rate when dataset[i] does not return it; called as sample_rate_getter(i)
      - num_samples: number of random samples from dataset to time
      - warmup: number of warmup runs to exclude from timing

    Returns dict with mean/p50/p95 in seconds per sample.
    """
    model.eval()
    rng = np.random.default_rng(0)
    indices = rng.choice(len(dataset), size=min(num_samples + warmup, len(dataset)), replace=False)
    times: List[float] = []

    # Helper to extract (waveform, sr) regardless of dataset format
    def _get_waveform_sr(idx: int) -> Tuple[torch.Tensor, int]:
        item = dataset[idx]
        if isinstance(item, (tuple, list)):
            if len(item) == 3:
                waveform, sample_rate, _ = item
            elif len(item) == 2:
                waveform, _ = item
                sample_rate = sample_rate_getter(idx) if sample_rate_getter is not None else 16000
            else:
                raise ValueError("Unsupported dataset item format.")
        else:
            raise ValueError("Dataset item must be tuple/list.")
        return waveform, int(sample_rate)

    # Warmup
    for i in range(min(warmup, len(indices))):
        wf, sr = _get_waveform_sr(indices[i])
        feats = featurize_fn(wf, sr)
        feats = feats.unsqueeze(0).to(device) if feats.dim() == 3 else feats.to(device).unsqueeze(0)
        _ = model(feats)

    # Timed runs
    with torch.inference_mode():
        for i in range(warmup, len(indices)):
            wf, sr = _get_waveform_sr(indices[i])
            torch.cuda.synchronize() if torch.cuda.is_available() else None
            t0 = time.perf_counter()
            feats = featurize_fn(wf, sr)
            feats = feats.unsqueeze(0).to(device) if feats.dim() == 3 else feats.to(device).unsqueeze(0)
            _ = model(feats)
            torch.cuda.synchronize() if torch.cuda.is_available() else None
            t1 = time.perf_counter()
            times.append(t1 - t0)

    if len(times) == 0:
        return {"latency_mean_s": float("nan"), "latency_p50_s": float("nan"), "latency_p95_s": float("nan")}
    arr = np.array(times)
    return {
        "latency_mean_s": float(arr.mean()),
        "latency_p50_s": float(np.percentile(arr, 50)),
        "latency_p95_s": float(np.percentile(arr, 95)),
        "samples_timed": int(len(arr)),
    }


In [23]:
# Assumes you already have:
# - model, train_loader, val_loader, test_loader, device, n_classes, best_ckpt_path
# - epochs, optimizer, scheduler, scaler (optional)
# If you previously defined train_one_epoch/evaluate, we now use the enhanced versions below.

# Robust collate to handle cases where "input_values" may be a list instead of a Tensor
def _safe_collate(batch):
    import torch
    feats = []
    labels = []
    time_dim = -1
    for b in batch:
        x = b["input_values"]
        if isinstance(x, list):
            # unwrap single-element lists
            x = x[0] if len(x) > 0 else x
        if not isinstance(x, torch.Tensor):
            x = torch.tensor(x)
        if x.dim() == 2:
            x = x.unsqueeze(0)  # [1, n_mels, T]
        feats.append(x)
        labels.append(int(b["label"]))
    # pad/truncate to max time in batch
    max_time = max(x.shape[time_dim] for x in feats)
    padded = []
    for x in feats:
        t = x.shape[time_dim]
        if t > max_time:
            idx = [slice(None)] * x.dim()
            idx[time_dim] = slice(0, max_time)
            x = x[tuple(idx)]
        elif t < max_time:
            pad_shape = list(x.shape)
            pad_shape[time_dim] = max_time - t
            pad_tensor = torch.zeros(pad_shape, dtype=x.dtype)
            x = torch.cat([x, pad_tensor], dim=time_dim)
        padded.append(x)
    batch_features = torch.stack(padded, dim=0)  # [B, 1, n_mels, time]
    batch_labels = torch.tensor(labels, dtype=torch.long)
    return batch_features, batch_labels


# Rebuild DataLoaders with the robust collate to avoid shape/list issues
from torch.utils.data import DataLoader

train_loader = DataLoader(
    prepared_with_val["train"],
    batch_size=128,
    shuffle=True,
    num_workers=2,
    collate_fn=_safe_collate,
    pin_memory=torch.cuda.is_available(),
)
val_loader = DataLoader(
    prepared_with_val["validation"],
    batch_size=128,
    shuffle=False,
    num_workers=2,
    collate_fn=_safe_collate,
    pin_memory=torch.cuda.is_available(),
)
test_loader = DataLoader(
    prepared_with_val["test"],
    batch_size=128,
    shuffle=False,
    num_workers=2,
    collate_fn=_safe_collate,
    pin_memory=torch.cuda.is_available(),
)


# Clear session and CUDA cache between runs, then re-instantiate everything fresh
def clear_session():
    import gc
    import torch
    # Drop references if they exist
    for name in ("model", "optimizer", "scheduler", "scaler"):
        if name in globals():
            try:
                del globals()[name]
            except Exception:
                pass
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()


clear_session()

torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)

# Fresh instantiation for each run
epochs = 25
lr = 0.002

model = DSCNN(n_mels=40, n_classes=n_classes).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
scaler = torch.cuda.amp.GradScaler() if torch.cuda.is_available() else None

best_val_acc = 0.0
history = {
    "train_loss": [], "train_acc": [],
    "train_precision_macro": [], "train_recall_macro": [], "train_f1_macro": [],
    "val_loss": [], "val_acc": [],
    "val_precision_macro": [], "val_recall_macro": [], "val_f1_macro": [],
}


for epoch in range(1, epochs + 1):
    model.train()  # ensure training mode each epoch
    train_metrics = train_one_epoch_with_metrics(model, train_loader, optimizer, device, n_classes, scaler)
    val_metrics = evaluate_with_confusion(model, val_loader, device, n_classes)
    scheduler.step()

    history["train_loss"].append(train_metrics["loss"])
    history["train_acc"].append(train_metrics["acc"])
    history["train_precision_macro"].append(train_metrics["precision_macro"])
    history["train_recall_macro"].append(train_metrics["recall_macro"])
    history["train_f1_macro"].append(train_metrics["f1_macro"])

    history["val_loss"].append(val_metrics["loss"])
    history["val_acc"].append(val_metrics["acc"])
    history["val_precision_macro"].append(val_metrics["precision_macro"])
    history["val_recall_macro"].append(val_metrics["recall_macro"])
    history["val_f1_macro"].append(val_metrics["f1_macro"])

    if val_metrics["acc"] > best_val_acc:
        best_val_acc = val_metrics["acc"]
        torch.save(
            {"model_state": model.state_dict(), "config": {"n_classes": n_classes}},
            best_ckpt_path
        )

    print(f"Epoch {epoch:02d} | "
          f"Train: loss={train_metrics['loss']:.4f}, acc={train_metrics['acc']:.4f}, f1={train_metrics['f1_macro']:.4f} | "
          f"Val: loss={val_metrics['loss']:.4f}, acc={val_metrics['acc']:.4f}, f1={val_metrics['f1_macro']:.4f}")

best_val_acc


[grad clip] step 0: global_norm=0.6533, max_norm=1.0
Epoch 01 | Train: loss=2.1720, acc=0.2070, f1=0.1750 | Val: loss=2.3704, acc=0.1000, f1=0.0182
[grad clip] step 0: global_norm=0.8746, max_norm=1.0
Epoch 02 | Train: loss=1.7603, acc=0.4457, f1=0.4239 | Val: loss=3.7764, acc=0.1074, f1=0.0297
[grad clip] step 0: global_norm=1.1774, max_norm=1.0
Epoch 03 | Train: loss=1.2681, acc=0.7342, f1=0.7246 | Val: loss=1.3514, acc=0.5074, f1=0.5013
[grad clip] step 0: global_norm=0.9693, max_norm=1.0


KeyboardInterrupt: 

# Create Graphs


In [61]:
os.makedirs("Plot results", exist_ok=True)

epochs_range = range(1, len(history["train_acc"]) + 1)

# 1) Final test metrics
with torch.no_grad():
    test_metrics = evaluate_with_confusion(model, test_loader, device, n_classes)

# 2) Model-only latency (forward pass on precomputed features)
#    Limit batches to keep timing quick; adjust as needed.
model_latency = measure_model_only_latency(
    model, val_loader, device, warmup_batches=5, max_batches=20
)

# 3) End-to-end latency (raw audio -> features -> model).
#    If you have a raw-audio dataset and a featurize_fn available, compute it here.
#    For now, set to None so plotting code can skip it gracefully.
end_to_end_latency = None

print("Test metrics:", test_metrics)
print("Model-only latency:", model_latency)

# 1) Accuracy curves (own graph)
plt.figure(figsize=(8, 5))
plt.plot(epochs_range, history["train_acc"], label="Train Acc", marker="o", linewidth=2)
plt.plot(epochs_range, history["val_acc"], label="Val Acc", marker="s", linewidth=2)
plt.xlabel("Epoch")
plt.ylabel("Accuracy")
plt.ylim(0, 1)
plt.title("Accuracy over epochs")
plt.grid(True, alpha=0.3)
plt.legend()
plt.tight_layout()
plt.savefig(os.path.join("Plot results", "accuracy_over_epochs.png"), dpi=150)
plt.close()

# 2) Precision (macro) curves (own graph)
plt.figure(figsize=(8, 5))
plt.plot(epochs_range, history["train_precision_macro"], label="Train Precision (macro)", marker="o", linewidth=2)
plt.plot(epochs_range, history["val_precision_macro"], label="Val Precision (macro)", marker="s", linewidth=2)
plt.xlabel("Epoch")
plt.ylabel("Precision (macro)")
plt.ylim(0, 1)
plt.title("Precision (macro) over epochs")
plt.grid(True, alpha=0.3)
plt.legend()
plt.tight_layout()
plt.savefig(os.path.join("Plot results", "precision_macro_over_epochs.png"), dpi=150)
plt.close()

# 3) F1 (macro) curves (own graph)
plt.figure(figsize=(8, 5))
plt.plot(epochs_range, history["train_f1_macro"], label="Train F1 (macro)", marker="o", linewidth=2)
plt.plot(epochs_range, history["val_f1_macro"], label="Val F1 (macro)", marker="s", linewidth=2)
plt.xlabel("Epoch")
plt.ylabel("F1 (macro)")
plt.ylim(0, 1)
plt.title("F1 (macro) over epochs")
plt.grid(True, alpha=0.3)
plt.legend()
plt.tight_layout()
plt.savefig(os.path.join("Plot results", "f1_macro_over_epochs.png"), dpi=150)
plt.close()

# 4) Final test set metrics bar chart
final_names = ["Accuracy", "Precision (macro)", "Recall (macro)", "F1 (macro)"]
final_vals = [
    test_metrics.get("acc", float("nan")),
    test_metrics.get("precision_macro", float("nan")),
    test_metrics.get("recall_macro", float("nan")),
    test_metrics.get("f1_macro", float("nan")),
]
plt.figure(figsize=(8, 5))
bars = plt.bar(final_names, final_vals, color=["#4caf50", "#2196f3", "#ff9800", "#9c27b0"])
plt.ylim(0, 1.0)
plt.ylabel("Score")
plt.title("Final Test Metrics")
for b, v in zip(bars, final_vals):
    plt.text(b.get_x() + b.get_width() / 2, b.get_height() + 0.01, f"{v:.3f}", ha="center", va="bottom")
plt.tight_layout()
plt.savefig(os.path.join("Plot results", "final_test_metrics.png"), dpi=150)
plt.close()


# 5) Latency plot (own graph; model-only; end-to-end if available)
# We plot horizontal bars for summary stats in milliseconds for readability
def _plot_latency_summary(lat_dict: Dict[str, float], title: str, filename: str):
    if lat_dict is None:
        return
    keys = [k for k in ["latency_mean_s", "latency_p50_s", "latency_p95_s"] if k in lat_dict]
    vals_ms = [lat_dict[k] * 1000.0 for k in keys]
    labels = ["Mean (ms)", "P50 (ms)", "P95 (ms)"]
    plt.figure(figsize=(8, 4))
    bars = plt.barh(labels, vals_ms, color="#607d8b")
    plt.xlabel("Milliseconds per sample")
    plt.title(title)
    for b, v in zip(bars, vals_ms):
        plt.text(v, b.get_y() + b.get_height() / 2, f" {v:.2f} ms", va="center")
    plt.tight_layout()
    plt.savefig(os.path.join("Plot results", filename), dpi=150)
    plt.close()


_plot_latency_summary(model_latency, "Model-only Latency Summary", "latency_model_only.png")
if end_to_end_latency is not None:
    _plot_latency_summary(end_to_end_latency, "End-to-End Latency Summary", "latency_end_to_end.png")

print('Saved plots to "Plot results" folder.')


Test metrics: {'precision_macro': 0.012240314856171608, 'recall_macro': 0.07999999821186066, 'f1_macro': 0.018703008070588112, 'precision_micro': 0.07999999821186066, 'recall_micro': 0.07999999821186066, 'f1_micro': 0.07999999821186066, 'loss': 2.3238205464680988, 'acc': 0.08}
Model-only latency: {'latency_mean_s': 0.0016528858344539035, 'latency_p50_s': 0.0018740283984328698, 'latency_p95_s': 0.001938971782718722}
Saved plots to "Plot results" folder.
