![title](https://proper-parallel6-dull-outside-car.github.io/libribrain-keyword/tutorial/img/task-graphic.png)

# 🕵️ Tutorial: Neural Keyword Detection on Sherlock Holmes Stories (LibriBrain)

> ⚠️ In Google Colab, make sure to switch to a `GPU` runtime via `Runtime -> Change runtime type`.

This tutorial covers keyword detection from MEG data using the LibriBrain dataset, which contains **over 50 hours** of MEG recordings from a single participant listening to Sherlock Holmes audiobooks. It is inspired by (and occasionally reuses small code snippets from) the existing LibriBrain tutorials on the [official website](https://neural-processing-lab.github.io/2025-libribrain-competition/participate/).

## Contents
Here's what we'll do in the next hour:
1. **Setup** — install dependencies and set up configuration
2. **Dataset Structure** — have a look at the pre-processed MEG recordings and event files
3. **Problem Formulation** — understand the problem and associated challenges
4. **Model Architecture** — look at one possible solution approach,...
5. **Training Strategy** — ... and understand how to get the most out of it
6. **Evaluation** — then, evaluate on standard metrics - AUPRC and false alarms/hour at fixed recall

Let's get started! 🔎

## I. Setup
Run the cell below as-is. It will install all required dependencies and prepare the environment. Notice the lines at the bottom of the cell? Those install a modified variant of the `pnpl` library we have customized for neural keyword detection. The `pnpl` library makes our life much easier here - it automatically downloads the dataset into the correct folder structure, provides dataset classes and more.

In [None]:
# Install dependencies
%pip install -q mne_bids lightning torchmetrics scikit-learn plotly ipywidgets neptune

BASE_PATH = "./kws"
try:
    import google.colab  # Colab runtime
    IN_COLAB = True
    BASE_PATH = "/content"
    print("Running in Colab")
except ImportError:
    IN_COLAB = False
    print("Not running in Colab")

# Anonymous GitHub repo will be replaced after review process has concluded.
!git clone https://github.com/Proper-Parallel6-Dull-Outside-Car/libribrain-keyword.git
!cd libribrain-keyword/modified-pnpl/pnpl && pip install -q .

In [None]:
"""
Configuration Parameters.

Each parameter below affects the keyword detection pipeline we'll build.
For now, just run this cell as-is - you can come back in the future and play around with it!
"""

CONFIG = {
    # === DATASET CONFIGURATION ===
    "data_path": f"{BASE_PATH}/data/",

    # Target Keyword Selection
    "keyword": "watson",              # Choose your target word (case-insensitive)
                                      # Popular options: "watson", "holmes", "sherlock", "the"
                                      # Tip: Common words have more examples but lower precision

    # Word Filtering
    "min_word_len": 3,                # Skip very short words (reduce noise)
    "max_word_len": None,             # Optional: limit to shorter words for faster processing

    # Temporal Window Configuration
    "tmin": None,                     # Auto-computed: 0 - negative_buffer
    "tmax": None,                     # Auto-computed: keyword_duration + positive_buffer
    "negative_buffer": 0.10,          # Pre-onset context (100ms captures anticipatory activity)
    "positive_buffer": 0.30,          # Post-offset context (300ms captures completion responses)
                                      # Trade-off: More context vs. computational cost

    # Signal Preprocessing
    "standardize": True,              # Z-score normalization per channel (recommended)
    "clipping_boundary": 10.0,        # Outlier clipping (prevents extreme values)

    # === DATA SPLITS ===
    # Strategy: Use minimal subset for fast experimentation
    # For full performance, expand to more sessions/books
    "train_run_keys": [("0","1","Sherlock1","1"), ("0","3","Sherlock1","1"), ("0","5","Sherlock1","1"), ("0","12","Sherlock4","1"), ("0","12","Sherlock6","1")],
    "val_partition": "validation",     # Automatic keyword-aware validation selection
    "test_partition": "test",         # Automatic keyword-aware test selection

    # === DATALOADER SETTINGS ===
    "batch_size": 64,                 # Balance: memory usage vs. gradient stability
    "num_workers": 2 if IN_COLAB else 0,  # Parallel data loading (adjust for your system)

    # === MODEL ARCHITECTURE ===
    "model_dim": 128,                 # Hidden dimension size (trade-off: capacity vs. speed)
    "dropout": 0.4,                   # Regularization strength (combat overfitting)
    "lstm_layers": 2,                 # Temporal processing depth
    "bi_directional": False,          # Bidirectional LSTM (doubles parameters)

    # === OPTIMIZATION ===
    "learning_rate": 1e-3,            # Step size (too high: instability, too low: slow convergence)
    "weight_decay": 0.01,             # L2 regularization (prevent overfitting)
    "smoothing": 0.1,                 # Label smoothing (makes model less confident)

    # === TRAINING CONTROL ===
    "max_epochs": 8,                  # Maximum training epochs
    "early_patience": 6,              # Stop if validation doesn't improve
}

print("Configuration loaded! Key settings:")
print(f"🎯 Target keyword: '{CONFIG['keyword']}'")
print(f"⏱️  Window: {CONFIG['negative_buffer']}s before onset → {CONFIG['positive_buffer']}s after offset")
print(f"🧠 Model dimension: {CONFIG['model_dim']}")
print(f"📊 Batch size: {CONFIG['batch_size']}")

## II. Dataset Structure

Next, let's explore the dataset! LibriBrain uses two complementary file types that work together to provide both neural signals and their annotations, which we will use as labels:

**🧠 1. HDF5 Files (.h5)**: The Neural Signal Container
- **data**: MEG sensor readings (306 channels × timesteps)
- **times**: Precise timestamps for each sample (250Hz sampling rate)
- **Preprocessing**: Signals are cleaned, filtered, and standardised
- **Size**: Large files (~100MB+) containing continuous recordings

**📝 2. Event Files (.tsv)**: The Annotation Layer  
- **Timing**: Word and phoneme onset times aligned to MEG
- **Labels**: Text content ('segment' column) for each time window
- **Categories**: 'word', 'phoneme', and 'silence' event types
- **Precision**: Millisecond-accurate alignment with neural data

Naming and folder structure are based on the [BIDS](https://bids-specification.readthedocs.io/en/stable/) specification:
- **Subject**: Single participant (always 'sub-0')
- **Sessions**: Book chapters (1-12+ per story)
- **Tasks**: Different Books (Sherlock1, Sherlock2,... up to Sherlock7)
- **Runs**: Recording sessions per task
- **Preprocessing**: The HDF5 file will also contain a "preprocessing string" at the end, indicating what processing has been done on the raw data.

That means that you may find the two files
- sub-0_ses-10_task-Sherlock1_run-1_proc-bads+headpos+sss+notch+bp+ds_meg.h5
- sub-0_ses-10_task-Sherlock1_run-1_events.tsv

in folders 'Sherlock1/derivatives/serialised/' and 'Sherlock1/derivatives/events/' respectively. Those two together represent a single session.

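The naming scheme above can be sketched as a small helper that maps a run key to both file paths. This is only an illustration: `run_key_to_paths` and `PROC` are hypothetical names, the preprocessing string is copied from the example file name, and the `pnpl` library resolves these paths for you.

```python
from pathlib import Path

# Preprocessing string copied from the example file name above.
PROC = "bads+headpos+sss+notch+bp+ds"

def run_key_to_paths(data_path, run_key, proc=PROC):
    """Map a (subject, session, task, run) key to its MEG and events files."""
    subject, session, task, run = run_key
    stem = f"sub-{subject}_ses-{session}_task-{task}_run-{run}"
    derivatives = Path(data_path) / task / "derivatives"
    meg_file = derivatives / "serialised" / f"{stem}_proc-{proc}_meg.h5"
    events_file = derivatives / "events" / f"{stem}_events.tsv"
    return meg_file, events_file

meg_file, events_file = run_key_to_paths("./kws/data", ("0", "10", "Sherlock1", "1"))
print(meg_file.name)     # the .h5 file named in the list above
print(events_file.name)  # the matching events.tsv file
```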
#### Let's explore
Below, we will examine these two example files. Here are some things to pay attention to:

For the `pnpl` package:
- How it automatically downloads both the data and the labels given a single "run key"

For the H5 file:
- Matrix shape: 306 sensors × time samples
- Sampling rate: 250Hz (4ms resolution)
- Time-locked: Besides the raw data (under the `data` key), the file contains a `times` key that functions as a timestamp. It increases by 4 ms per sample.

For the events file:
- Word-level timing in the 'segment' column
- Different event types: word vs. phoneme vs. silence

In [None]:
import os, h5py, pandas as pd
from pathlib import Path

from pnpl.datasets import LibriBrainWord
from torch.utils.data import DataLoader

# Loading a single session to force download
example_data = LibriBrainWord(
  data_path=CONFIG['data_path'],
  include_run_keys = [("0","1","Sherlock1","1")]
)
# Conditionally set num_workers to avoid multiprocessing issues (try increasing if performance is problematic)
example_loader = DataLoader(example_data, batch_size=CONFIG['batch_size'], shuffle=True, num_workers=CONFIG['num_workers'])


hdf5_file_path = f"{CONFIG['data_path']}/Sherlock1/derivatives/serialised/sub-0_ses-1_task-Sherlock1_run-1_proc-bads+headpos+sss+notch+bp+ds_meg.h5"
events_path = f"{CONFIG['data_path']}/Sherlock1/derivatives/events/sub-0_ses-1_task-Sherlock1_run-1_events.tsv"

# Quick look at H5 files (MEG data)
with h5py.File(hdf5_file_path, 'r') as f:
    data = f["data"]
    times = f["times"]
    sfreq = f.attrs["sample_frequency"]
    print("HDF5 datasets:", list(f.keys()))
    print("data shape (channels x time):", data.shape)
    print("sfreq:", sfreq)
    print("times[0:5] (s):", times[:5])

# Quick look at events (labels)
df = pd.read_csv(events_path, sep='\t')
print("\nEvents columns:", list(df.columns))
print("First rows:\n", df.head())


## III. Problem Formulation

Now that we understand our data, let's talk about the task. We'll try to detect a target keyword in continuous MEG using short windows around word onsets. Fundamentally, that is an _event-referenced binary classification problem_ - we're given a sample (`T` timepoints of MEG data across `306` channels) and are trying to predict a value between 0 and 1 representing the probability that our keyword occurs during that timeframe.

> **How long is `T`?**
>
> The length of each sample equals the duration of the longest instance of the chosen keyword (for a 1 s word, each sample spans 250 Hz × 1 s = 250 timepoints). Optionally, you can provide a `negative_buffer` or `positive_buffer`, so that the sample starts before the keyword onset or extends beyond the keyword duration.

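As a rough sketch of that arithmetic, the window length can be computed from the keyword duration and the two buffers. The helper `window_samples` is hypothetical, and rounding to the nearest sample is an assumption (the dataset class may handle edge samples differently). For a 1 s keyword it gives 250 timepoints without buffers and 350 with the default buffers of 0.10 s and 0.30 s.

```python
SFREQ = 250  # Hz, the sampling rate of the pre-processed recordings

def window_samples(keyword_duration, negative_buffer=0.0, positive_buffer=0.0, sfreq=SFREQ):
    """Number of MEG timepoints T in one window around a word onset."""
    tmin = 0.0 - negative_buffer               # window start relative to onset (s)
    tmax = keyword_duration + positive_buffer  # window end relative to onset (s)
    return round((tmax - tmin) * sfreq)        # assumption: round to nearest sample

print(window_samples(1.0))              # 1 s keyword, no buffers
print(window_samples(1.0, 0.10, 0.30))  # same keyword with the default buffers
```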
**There are two primary challenges making our lives harder:**

### 1. Class Imbalance
Unfortunately, there are _a lot_ of different words in the dataset. Even the most common word, "the", makes up just 5.5% of all words. Given that the distribution within the dataset is [Zipfian](https://en.wikipedia.org/wiki/Zipf%27s_law), most words are a lot rarer than that:

![Word frequency chart](https://proper-parallel6-dull-outside-car.github.io/libribrain-keyword/tutorial/img/word_frequency_chart.png)

The keyword we'll be looking at today is **"Watson"**, which appears just **608 times** in the entire dataset (**0.1189%**).

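To see why this imbalance matters for evaluation, here is a small sketch using only the prevalence quoted above. It shows that plain accuracy is nearly meaningless at this base rate, and that a random scorer's average precision (the quantity summarised by AUPRC) sits close to the prevalence. The labels and scores are synthetic and `average_precision` is an illustrative helper, not the metric implementation used later in the tutorial.

```python
import random

PREVALENCE = 0.1189 / 100  # share of "Watson" among all words, from the text above

# A classifier that always answers "not the keyword" is almost always right.
always_negative_accuracy = 1.0 - PREVALENCE
print(f"Accuracy of always predicting 'no keyword': {always_negative_accuracy:.4%}")

def average_precision(labels, scores):
    """Mean of precision@k taken at the rank k of every positive example."""
    ranked = sorted(zip(scores, labels), key=lambda pair: -pair[0])
    hits, precisions = 0, []
    for rank, (_, label) in enumerate(ranked, start=1):
        if label == 1:
            hits += 1
            precisions.append(hits / rank)
    return sum(precisions) / max(1, len(precisions))

# Random scores on synthetic labels: average precision lands near the prevalence.
rng = random.Random(0)
labels = [1 if rng.random() < PREVALENCE else 0 for _ in range(200_000)]
scores = [rng.random() for _ in labels]
ap = average_precision(labels, scores)
print(f"Average precision of random scores: {ap:.4f} (prevalence: {PREVALENCE:.4f})")
```

Any useful detector therefore has to be judged against a chance level of roughly 0.001, not 0.5.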
### 2. Signal-to-Noise Ratio
Compared to traditional audio, or even to invasive brain decoding, non-invasive brain data is _a lot_ more noisy. While we do get nice data from 306 [MEG](https://en.wikipedia.org/wiki/Magnetoencephalography) channels, all of these sensors were placed _outside_ the participant's head. Have a look at the visualisation below (which is adapted from the [2025 LibriBrain Competition Speech Detection Tutorial](https://neural-processing-lab.github.io/2025-libribrain-competition/links/speech-colab)). It's impossible to tell with the naked eye when _speech_ starts - let alone tell apart specific words.


In [None]:
import pandas as pd
import numpy as np
import h5py
import matplotlib.pyplot as plt


def plot_meg_and_labels(hdf5_file_path, tsv_file_path, start_time, end_time, title=None):
    # ---- Load MEG (data: channels x samples; times: seconds) ----
    with h5py.File(hdf5_file_path, "r") as f:
        meg = f["data"][:]    # (C, T)
        times = f["times"][:] # (T,)

    if times.size < 2:
        raise ValueError("Not enough time points in 'times' to determine a window.")

    # clamp window to available time range
    t0 = max(float(times[0]), float(start_time))
    t1 = min(float(times[-1]), float(end_time))
    if not (t0 < t1):
        raise ValueError(f"Empty window after clamping: [{t0}, {t1}]")

    # slice by time using searchsorted (avoids sampling-frequency math)
    i0 = int(np.searchsorted(times, t0, side="left"))
    i1 = int(np.searchsorted(times, t1, side="right"))
    seg = meg[:, i0:i1]
    tseg = times[i0:i1]

    # ---- Load and prep TSV ----
    tsv = pd.read_csv(tsv_file_path, sep="\t")
    # keep only rows with usable timing + kind
    tsv = tsv.dropna(subset=["timemeg", "kind"]).copy()
    tsv["timemeg"] = tsv["timemeg"].astype(float)

    # rows inside the window
    win = tsv[(tsv["timemeg"] >= t0) & (tsv["timemeg"] <= t1)].copy()

    # determine label at t0 from the last event before the window
    prev = tsv[tsv["timemeg"] < t0].tail(1)
    start_label = int(prev["kind"].isin(["word", "phoneme"]).iloc[0]) if not prev.empty else 0
    end_label = start_label if win.empty else int(win["kind"].iloc[-1] in ("word", "phoneme"))

    # build a simple step series (0 = silence, 1 = speech event)
    step_times = np.concatenate(([t0], win["timemeg"].to_numpy(), [t1]))
    step_vals = np.concatenate(([start_label],
                                win["kind"].isin(["word", "phoneme"]).astype(int).to_numpy(),
                                [end_label]))

    # ---- Plot ----
    fig, (ax0, ax1) = plt.subplots(2, 1, figsize=(14, 8),
                                   gridspec_kw={"height_ratios": [2, 1]},
                                   sharex=True)

    ax0.plot(tseg, seg.T, alpha=0.5)
    ax0.set_title(title or f"MEG {t0:.2f}s–{t1:.2f}s")
    ax0.set_ylabel("Amplitude")
    ax0.grid(True, alpha=0.3)

    ax1.plot(step_times, step_vals, drawstyle="steps-post", linewidth=2)
    # annotate word onsets only (kept simple)
    for _, r in win.iterrows():
        if r["kind"] == "word":
            ax1.text(float(r["timemeg"]), 1.22, str(r.get("segment", "")),
                     fontsize=9, ha="center", va="bottom", rotation=0)

    ax1.set_ylim(-0.2, 1.5)
    ax1.set_xlim(t0, t1)
    ax1.set_ylabel("Speech / Silence")
    ax1.set_xlabel("Time (s)")
    ax1.set_title("Labels")
    ax1.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()


tsv_file_path  = f"{CONFIG['data_path']}/Sherlock1/derivatives/events/sub-0_ses-1_task-Sherlock1_run-1_events.tsv"
hdf5_file_path = f"{CONFIG['data_path']}/Sherlock1/derivatives/serialised/sub-0_ses-1_task-Sherlock1_run-1_proc-bads+headpos+sss+notch+bp+ds_meg.h5"
plot_meg_and_labels(hdf5_file_path, tsv_file_path, start_time=55, end_time=56,
                    title="Transition silence → speech")


## IV. Model Architecture

Now that we understand the challenges we're facing, let's start building a solution!

### Step 1: Training data
So far, we've only played around with a single session. The first step, then, is to acquire more data. As we may be somewhat compute-limited by the constraints of Google Colab, we might not want to work with _all 52 hours_ of the dataset. In the default configuration, we have pre-defined some sessions as a `train` set (see `train_run_keys` in your config above). The `LibriBrainWord` dataset class chooses the `validation` and `test` sets automatically, using either the LibriBrain default split or, if the chosen keyword does not appear in it, the highest-prevalence session in the dataset. Our default keyword, `Watson`, does not appear in the default split, which is why you will see the cell download all available label files. Because of this dynamic logic, we need the additional code below to check that we are not accidentally training on a session that was dynamically picked as the validation or test set. Execute the cell below to download the data and set up the datasets.

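A minimal sketch of such a leakage check, assuming run keys are `(subject, session, task, run)` tuples as in the config. The helper `find_leaked_runs` and the example picks are hypothetical and only illustrate the idea of comparing the training runs against the dynamically chosen ones.

```python
def find_leaked_runs(train_run_keys, held_out_run_keys):
    """Return the held-out run keys that also appear in the training set."""
    train = {tuple(rk) for rk in train_run_keys}
    return [tuple(rk) for rk in held_out_run_keys if tuple(rk) in train]

train_keys = [("0", "1", "Sherlock1", "1"), ("0", "3", "Sherlock1", "1")]
# Hypothetical picks, for illustration only: the first one overlaps with training.
picked_keys = [("0", "3", "Sherlock1", "1"), ("0", "11", "Sherlock2", "1")]

leaked = find_leaked_runs(train_keys, picked_keys)
print("Leaked runs:", leaked)  # a non-empty list means train and eval overlap
```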
In [None]:
# Build datasets and loaders (uses pnpl LibriBrainWord)
from pnpl.datasets.libribrain2025.word_dataset import LibriBrainWord
from pnpl.datasets.libribrain2025.base import LibriBrainBase
from pnpl.datasets.libribrain2025.constants import RUN_KEYS, VALIDATION_RUN_KEYS, TEST_RUN_KEYS
from torch.utils.data import DataLoader
import os, pandas as pd

# Train subset: user-chosen specific sessions to keep runtime light
train_ds = LibriBrainWord(
    data_path=CONFIG["data_path"],
    include_run_keys=CONFIG["train_run_keys"],
    keyword_detection=CONFIG["keyword"],
    min_word_length=CONFIG["min_word_len"],
    max_word_length=CONFIG["max_word_len"],
    tmin=CONFIG["tmin"],
    tmax=CONFIG["tmax"],
    negative_buffer=CONFIG["negative_buffer"],
    positive_buffer=CONFIG["positive_buffer"],
    standardize=CONFIG["standardize"],
    clipping_boundary=CONFIG["clipping_boundary"],
    preload_files=True,
)

# Explicitly choose validation/test runs by scanning events.tsv for the keyword

def _events_path(data_path, rk):
    s, se, t, r = rk
    return os.path.join(data_path, t, "derivatives", "events", f"sub-{s}_ses-{se}_task-{t}_run-{r}_events.tsv")

def _count_keyword(data_path, rk, keyword, min_len=None, max_len=None):
    f = _events_path(data_path, rk)
    try:
        LibriBrainBase.ensure_file_download(f, data_path=data_path)
    except Exception:
        if not os.path.exists(f):
            return 0, 0
    try:
        df = pd.read_csv(f, sep="\t")
    except Exception:
        return 0, 0
    if "kind" in df.columns:
        df = df[df["kind"] == "word"].copy()
    if "segment" not in df.columns:
        return 0, 0
    seg = df["segment"].astype(str).str.strip()
    if min_len and min_len > 1:
        seg = seg[seg.str.len() >= min_len]
    if max_len:
        seg = seg[seg.str.len() <= max_len]
    total = int(seg.shape[0])
    pos = int(seg.str.lower().eq(str(keyword).lower()).sum()) if total > 0 else 0
    return pos, total

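def _choose_val_test(data_path, keyword, min_len=None, max_len=None):
    # Never pick a training run as validation/test (avoids train/eval leakage)
    train_rks = {tuple(rk) for rk in CONFIG["train_run_keys"]}
    cands = [tuple(rk) for rk in RUN_KEYS if tuple(rk) not in train_rks]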
    scored = []  # (prev, pos, total, rk)
    for rk in cands:
        pos, total = _count_keyword(data_path, rk, keyword, min_len, max_len)
        if total > 0 and pos > 0:
            prev = pos / total
            scored.append((prev, pos, total, rk))
    if not scored:
        return VALIDATION_RUN_KEYS[0], TEST_RUN_KEYS[0]
    scored.sort(key=lambda x: (-x[0], -x[1]))
    val_rk = scored[0][3]
    pool = [s for s in scored if s[3] != val_rk]
    if not pool:
        return val_rk, scored[0][3]
    pool.sort(key=lambda x: (0 if x[3] in set(TEST_RUN_KEYS) else 1, -x[0], -x[1]))
    test_rk = pool[0][3]
    return val_rk, test_rk

val_rk, test_rk = _choose_val_test(
    data_path=CONFIG["data_path"],
    keyword=CONFIG["keyword"],
    min_len=CONFIG["min_word_len"],
    max_len=CONFIG["max_word_len"],
)
print("Chosen VAL:", val_rk, "| TEST:", test_rk)

val_ds = LibriBrainWord(
    data_path=CONFIG["data_path"],
    include_run_keys=[val_rk],
    keyword_detection=CONFIG["keyword"],
    min_word_length=CONFIG["min_word_len"],
    max_word_length=CONFIG["max_word_len"],
    tmin=CONFIG["tmin"],
    tmax=CONFIG["tmax"],
    negative_buffer=CONFIG["negative_buffer"],
    positive_buffer=CONFIG["positive_buffer"],
    standardize=CONFIG["standardize"],
    clipping_boundary=CONFIG["clipping_boundary"],
    preload_files=True,
)

test_ds = LibriBrainWord(
    data_path=CONFIG["data_path"],
    include_run_keys=[test_rk],
    keyword_detection=CONFIG["keyword"],
    min_word_length=CONFIG["min_word_len"],
    max_word_length=CONFIG["max_word_len"],
    tmin=CONFIG["tmin"],
    tmax=CONFIG["tmax"],
    negative_buffer=CONFIG["negative_buffer"],
    positive_buffer=CONFIG["positive_buffer"],
    standardize=CONFIG["standardize"],
    clipping_boundary=CONFIG["clipping_boundary"],
    preload_files=True,
)

train_loader = DataLoader(train_ds, batch_size=CONFIG["batch_size"], shuffle=True, num_workers=CONFIG["num_workers"])
val_loader = DataLoader(val_ds, batch_size=CONFIG["batch_size"], shuffle=False, num_workers=CONFIG["num_workers"])
test_loader = DataLoader(test_ds, batch_size=CONFIG["batch_size"], shuffle=False, num_workers=CONFIG["num_workers"])

print("Training samples:", len(train_ds))
print("Validation samples:", len(val_ds))
print("Test samples:", len(test_ds))


### Step 2: Model Architecture
While high-quality training data is extremely important, it also matters what type of model you train with it. Our model addresses the challenges we identified through three key components:

1. **Spatial-Temporal Processing** — efficiently handle high-dimensional MEG data
2. **Attention-Based Pooling** — focus on discriminative time points within each window  
3. **Adapted Training Strategy** — specialised losses and sampling for rare positive examples

#### Architecture Overview

![Model Architecture Diagram](https://proper-parallel6-dull-outside-car.github.io/libribrain-keyword/tutorial/img/model-architecture.png)

#### Component Details
Let's walk through the things that make the model work!

**🏗️ Convolutional Trunk**

A lightweight Conv1D front-end projects the 306 MEG channels into a 128-dimensional feature space that already mixes spatial information across sensors. Stacked residual blocks plus a stride-2 downsampling expand the temporal receptive field and denoise/compress the sequence, leaving compact features that preserve event timing for the attention stage.

In [None]:
# Demo: Convolutional Trunk - 306→128 + residual blocks + downsampling (stride 2)
import torch, torch.nn as nn, torch.nn.functional as F
import numpy as np, matplotlib.pyplot as plt

torch.manual_seed(0)

class ResidualBlock(nn.Module):
    def __init__(self, ch, k=3, dilation=1):
        super().__init__()
        p = (k - 1) // 2 * dilation
        self.conv1 = nn.Conv1d(ch, ch, k, padding=p, dilation=dilation)
        self.elu   = nn.ELU(inplace=True)
        self.conv2 = nn.Conv1d(ch, ch, k, padding=p, dilation=dilation)
        self.short = nn.Identity()
        nn.init.kaiming_normal_(self.conv1.weight, nonlinearity='linear')
        nn.init.zeros_(self.conv2.weight)

    def forward(self, x):
        h = self.conv1(x)
        h = self.elu(h)
        h = self.conv2(h)
        return self.elu(h + self.short(x))

class ConvTrunk(nn.Module):
    def __init__(self, c_in=306, c_mid=128, k=7, n_blocks=2, stride=2):
        super().__init__()
        p = (k - 1) // 2
        self.front = nn.Conv1d(c_in, c_mid, k, padding=p)
        self.blocks = nn.Sequential(*[ResidualBlock(c_mid, k=3) for _ in range(n_blocks)])
        self.down   = nn.Conv1d(c_mid, c_mid, 3, padding=1, stride=stride)
        self.act    = nn.ELU(inplace=True)

    def forward(self, x):  # (B, 306, T) -> (B, 128, ceil(T/2))
        x = self.front(x)
        x = self.blocks(x)
        x = self.down(x)
        return self.act(x)

# Dummy MEG-like input: batch of 8, 306 chans, 250 timepoints (~1s@250Hz)
B, C, T = 8, 306, 250
x = torch.randn(B, C, T)
trunk = ConvTrunk()

with torch.no_grad():
    y = trunk(x)

print(f"Input shape:  {tuple(x.shape)}")
print(f"Output shape: {tuple(y.shape)}  (306→128 channels; 250→{y.shape[-1]} time)")

# Quick visualization: channel-energy compression (sum of squares across time)
x_energy = (x**2).sum(dim=-1).mean(dim=0).numpy()       # (306,)
y_energy = (y**2).sum(dim=-1).mean(dim=0).numpy()       # (128,)

fig, ax = plt.subplots(figsize=(7,3))
ax.plot(np.log1p(np.sort(x_energy))[::-1], label="Input (306) energy")
ax.plot(np.log1p(np.sort(y_energy))[::-1], label="Trunk (128) energy")
ax.set_title("Convolutional trunk compresses channels while retaining structure")
ax.set_xlabel("Channel (sorted by energy)")
ax.set_ylabel("log(1+energy)")
ax.legend(); ax.grid(alpha=0.3); plt.tight_layout(); plt.show()

**🧠 Temporal Attention**

A parallel attention head scores each time step and forms a softmax weight over the window, letting the model concentrate probability mass on brief, informative bursts (e.g., around keyword onsets) while down-weighting idle/noisy segments. This adaptive, interpretable pooling handles variable latency/duration better than mean or max pooling, which tend to be too diffuse and too brittle, respectively.

In [None]:
# Demo: Temporal Attention - create a synthetic "event" so attention has something real to lock onto
import torch, torch.nn as nn
import numpy as np, matplotlib.pyplot as plt

torch.manual_seed(0)
np.random.seed(0)

class TemporalAttentionHead(nn.Module):
    def __init__(self, c_in=128):
        super().__init__()
        self.logit_head = nn.Conv1d(c_in, 1, kernel_size=1)
        self.attn_head  = nn.Conv1d(c_in, 1, kernel_size=1)

    def forward(self, h):      # h: (B, C, T')
        logit_t = self.logit_head(h)             # (B, 1, T')
        attn_t  = self.attn_head(h)              # (B, 1, T')
        attn_w  = torch.softmax(attn_t, -1)      # across time
        pooled  = (attn_w * logit_t).sum(-1)     # (B, 1)
        return pooled.squeeze(1), logit_t, attn_w

# --- synth trunk activations with a localized event on a subset of channels
B, C, Tprime = 1, 128, 160
t0, sigma, amp, active_ch = 95, 6.0, 2.5, 16

h = torch.randn(B, C, Tprime) * 0.4
t = torch.arange(Tprime).float()
event = torch.exp(-0.5 * ((t - t0) / sigma) ** 2)  # Gaussian bump
sel = torch.randperm(C)[:active_ch]
h[:, sel, :] += amp * event  # inject event pattern on a few channels

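# Note: the head is randomly initialised and untrained here, so how closely its
# attention follows the injected event depends on the seed; the plot is illustrative.
head = TemporalAttentionHead(C)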
with torch.no_grad():
    pooled, logit_t, attn_w = head(h)

# Compare pooling strategies on the same per-time logits
lt = logit_t[0, 0]                  # (T')
aw = attn_w[0, 0]
mean_pool = lt.mean()
max_pool  = lt.max()
attn_pool = (aw * lt).sum()

# Effective support (how concentrated attention is): smaller = sharper
eff_support = 1.0 / float((aw**2).sum())

print(f"Pooled logits — mean: {mean_pool:.3f} | max: {max_pool:.3f} | attention: {attn_pool:.3f}")
print(f"Attention effective support (1/sum(w^2)) = {eff_support:.1f} time-steps (lower = more focused)")

# Visualize: per-time logits + attention weights, with the event region shaded
fig, ax = plt.subplots(figsize=(7.5,3.4))
ax.plot(lt.numpy(), label="per-time logits")
ax2 = ax.twinx()
ax2.plot(aw.numpy(), alpha=0.7, label="attention weights")
ax.axvspan(int(t0 - 3*sigma), int(t0 + 3*sigma), alpha=0.15, lw=0, label="event region")
ax.set_title("Temporal attention locks onto the event-like region")
ax.set_xlabel("Time index")
ax.set_ylabel("Logit"); ax2.set_ylabel("Attention weight")
ax.grid(alpha=0.3)
lines, labels = ax.get_legend_handles_labels()
lines2, labels2 = ax2.get_legend_handles_labels()
ax.legend(lines+lines2, labels+labels2, loc="upper left")
plt.tight_layout(); plt.show()

**⚖️ Focal Loss (α=0.95, γ=2.0)**

Focal loss rescales BCE by the modulating factor $(1-p_t)^\gamma$ and a class weight $\alpha$, so the ocean of easy negatives contributes almost nothing while positives and near-miss negatives dominate the gradient. Setting $\alpha{=}0.95,\ \gamma{=}2.0$ is geared towards the <1% base rate: it helps prevent “always negative” collapse and improves ranking without aggressive oversampling.

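Written out, with $p$ the predicted probability and $y$ the label, the standard focal-loss form (which the demo cell below also follows) is:

$$
p_t = \begin{cases} p & y = 1 \\ 1 - p & y = 0 \end{cases}
\qquad
\alpha_t = \begin{cases} \alpha & y = 1 \\ 1 - \alpha & y = 0 \end{cases}
\qquad
\mathrm{FL}(p_t) = -\,\alpha_t\,(1 - p_t)^{\gamma}\,\log p_t
$$

Since plain BCE is $-\log p_t$, the factor in front shows how strongly each example is rescaled. Two illustrative cases with $\alpha = 0.95$ and $\gamma = 2$:

$$
\text{easy negative } (p = 0.05):\ (1-\alpha)\,(1-p_t)^{\gamma} = 0.05 \cdot 0.05^{2} = 1.25 \times 10^{-4}
\qquad
\text{hard positive } (p = 0.3):\ \alpha\,(1-p_t)^{\gamma} = 0.95 \cdot 0.7^{2} \approx 0.466
$$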
In [None]:
# Demo: Focal loss vs BCE - visualize which examples contribute to the loss
import torch, torch.nn.functional as F
import numpy as np, matplotlib.pyplot as plt

torch.manual_seed(0); np.random.seed(0)

def binary_focal_loss_with_logits(logits, targets, alpha=0.95, gamma=2.0, reduction='none'):
    p = torch.sigmoid(logits)
    ce = F.binary_cross_entropy_with_logits(logits, targets.float(), reduction='none')
    pt = p*targets + (1-p)*(1-targets)
    alpha_t = alpha*targets + (1-alpha)*(1-targets)
    loss = alpha_t * (1-pt).clamp_min(0).pow(gamma) * ce
    return loss if reduction=='none' else loss.mean()

# Simulate a heavily imbalanced batch: ~0.8% positives
N = 12000
pos_n = max(1, int(0.008 * N))
y = torch.zeros(N, dtype=torch.long)
pos_idx = torch.randperm(N)[:pos_n]
y[pos_idx] = 1

# Build logits:
# - many easy negatives (very < 0)
# - some medium negatives
# - a small slice of hard negatives (~ around 0)
# - positives somewhat > 0 with noise
logits = torch.randn(N) * 0.5 - 1.6
neg_mask = (y == 0)
neg_ids = neg_mask.nonzero(as_tuple=True)[0]
hard_ids = neg_ids[torch.randperm(len(neg_ids))[:max(50, N//100)]]
mid_ids  = neg_ids[torch.randperm(len(neg_ids))[:N//30]]
logits[mid_ids] += 0.9
logits[hard_ids] = torch.randn(len(hard_ids)) * 0.25  # centered near 0 (hard)
logits[pos_idx]  = torch.randn(len(pos_idx)) * 0.5 + 1.2

# Per-example losses
bce_per   = F.binary_cross_entropy_with_logits(logits, y.float(), reduction='none').detach()
focal_per = binary_focal_loss_with_logits(logits, y, alpha=0.95, gamma=2.0, reduction='none').detach()

# Group examples
def group_of(i):
    if y[i] == 1: return "positives"
    z = logits[i].item()
    if z < -1.0: return "easy neg"
    if -1.0 <= z < -0.3: return "medium neg"
    if -0.3 <= z < 0.5:  return "hard neg"
    return "borderline neg"

groups = ["positives","hard neg","borderline neg","medium neg","easy neg"]
bce_contrib = {g:0.0 for g in groups}
foc_contrib = {g:0.0 for g in groups}

for i in range(N):
    g = group_of(i)
    bce_contrib[g] += float(bce_per[i])
    foc_contrib[g] += float(focal_per[i])

bce_total  = sum(bce_contrib.values())
foc_total  = sum(foc_contrib.values())
bce_share  = [100.0 * bce_contrib[g] / bce_total for g in groups]
foc_share  = [100.0 * foc_contrib[g] / foc_total for g in groups]

print("BCE vs Focal — share of total loss by group (%):")
for g, b, f in zip(groups, bce_share, foc_share):
    print(f"  {g:13s}  BCE {b:6.2f}%   |   Focal {f:6.2f}%")

# Plot: contribution shares side-by-side
x = np.arange(len(groups))
w = 0.38
fig, ax = plt.subplots(figsize=(7.5,3.6))
ax.bar(x - w/2, bce_share, width=w, label="BCE")
ax.bar(x + w/2, foc_share, width=w, label="Focal (α=0.95, γ=2)")
ax.set_xticks(x); ax.set_xticklabels(groups, rotation=15)
ax.set_ylabel("Share of total loss (%)")
ax.set_title("Focal loss down-weights easy negatives, emphasizes positives & hard negatives")
ax.grid(axis='y', alpha=0.3)
ax.legend(loc="upper right")
plt.tight_layout(); plt.show()

**🎯 Pairwise Ranking Loss**

A pairwise logistic term compares each positive to sampled in-batch negatives and penalizes inversions (when $s_{+}\!\le\!s_{-}$), directly improving the **ordering** that drives precision–recall. Complementing the focal loss, it widens the margin to hard negatives and yields more stable thresholds in the low-false-alarm regime, which is the regime that matters for keyword spotting (KWS).

In [None]:
# Demo: Pairwise logistic ranking loss — changes ordering → PR improves
import torch, numpy as np, matplotlib.pyplot as plt
from sklearn.metrics import precision_recall_curve, average_precision_score

torch.manual_seed(0); np.random.seed(0)

def pairwise_logistic_loss(scores: torch.Tensor, targets: torch.Tensor, num_neg: int = 24):
    """
    scores:  (N,) float tensor (can require_grad)
    targets: (N,) {0,1} long tensor
    Loss = mean(log(1 + exp(-(s_pos - s_neg)))) over sampled pos/neg pairs.
    """
    pos = (targets == 1).nonzero(as_tuple=True)[0]
    neg = (targets == 0).nonzero(as_tuple=True)[0]
    if len(pos) == 0 or len(neg) == 0:
        return scores.new_tensor(0.0), scores.new_tensor([0.0])

    pairs_pos, pairs_neg = [], []
    for p in pos:
        sel = neg[torch.randint(0, len(neg), (min(num_neg, len(neg)),))]
        pairs_pos.append(p.repeat(len(sel)))
        pairs_neg.append(sel)
    pos_idx = torch.cat(pairs_pos); neg_idx = torch.cat(pairs_neg)
    margins = scores[pos_idx] - scores[neg_idx]
    loss = torch.nn.functional.softplus(-margins).mean()  # numerically stable log(1 + exp(-m))
    return loss, margins

# --- 1) Synthetic batch with overlap (ranking can improve)
N, pos_n = 1500, 24
y = torch.zeros(N, dtype=torch.long)
pos_idx = torch.randperm(N)[:pos_n]; y[pos_idx] = 1

scores0 = (torch.randn(N) * 0.8 - 0.1).to(torch.float32)
scores0[pos_idx] += 0.4  # slight lift for positives

# Metrics BEFORE
loss0, margins0 = pairwise_logistic_loss(scores0.detach(), y, num_neg=24)
ap0 = average_precision_score(y.numpy(), scores0.detach().numpy())

# --- 2) Optimize scores directly (Adam)
scores = torch.nn.Parameter(scores0.detach().clone())
opt = torch.optim.Adam([scores], lr=0.02)
K = 50
for _ in range(K):
    opt.zero_grad(set_to_none=True)
    loss, _ = pairwise_logistic_loss(scores, y, num_neg=24)
    loss.backward()
    opt.step()

scores1 = scores.detach()
loss1, margins1 = pairwise_logistic_loss(scores1, y, num_neg=24)
ap1 = average_precision_score(y.numpy(), scores1.numpy())

print(f"Pairwise loss  — before: {float(loss0):.4f} | after: {float(loss1):.4f}")
print(f"Avg margin s+−s− — before: {float(margins0.mean()):.3f} | after: {float(margins1.mean()):.3f}")
print(f"AUPRC (AP)     — before: {ap0:.3f} | after: {ap1:.3f}")

# --- 3) Visuals: PR curves + score distributions
prec0, rec0, _ = precision_recall_curve(y.numpy(), scores0.numpy())
prec1, rec1, _ = precision_recall_curve(y.numpy(), scores1.numpy())

fig, ax = plt.subplots(figsize=(7.2,4.0))
ax.plot(rec0, prec0, label=f"Before (AP={ap0:.3f})")
ax.plot(rec1, prec1, label=f"After {K} steps (AP={ap1:.3f})")
ax.set_title("Pairwise ranking improves ordering → better PR")
ax.set_xlabel("Recall"); ax.set_ylabel("Precision")
ax.grid(alpha=0.3); ax.legend(loc="lower left")
plt.tight_layout(); plt.show()

fig, ax = plt.subplots(1,2, figsize=(10,3.6), sharey=True)
ax[0].hist(scores0[y==0].numpy(), bins=40, alpha=0.85, label="neg")
ax[0].hist(scores0[y==1].numpy(), bins=20, alpha=0.85, label="pos")
ax[0].set_title("Scores before"); ax[0].legend()

ax[1].hist(scores1[y==0].numpy(), bins=40, alpha=0.85, label="neg")
ax[1].hist(scores1[y==1].numpy(), bins=20, alpha=0.85, label="pos")
ax[1].set_title("Scores after"); ax[1].legend()
for a in ax: a.grid(alpha=0.3); a.set_xlabel("score")
ax[0].set_ylabel("count")
plt.tight_layout(); plt.show()


**📊 Balanced Sampling**

We build training batches with a target ~10% positive rate—by pulling in all/most positives and subsampling negatives—so gradients aren’t starved by all-negative minibatches and the model learns rare-event features faster. Crucially, evaluation stays on the natural class prior, so AUPRC and false-alarms-per-hour reflect real-world conditions while training remains stable.
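
To see why this matters, a quick back-of-the-envelope calculation helps. The sketch below assumes a natural positive rate of about 0.1% and a batch size of 64; both numbers are illustrative assumptions, not values measured in this run.

```python
# Illustrative numbers only: prevalence and batch size are assumptions.
prevalence = 0.001   # assumed natural positive rate (~0.1%)
batch_size = 64      # assumed batch size
target_pos = 0.10    # balanced-sampling target from the text

# Probability that a randomly drawn batch contains no positive at all
p_all_negative = (1.0 - prevalence) ** batch_size

# Expected number of positives per batch under each scheme
expected_pos_natural = prevalence * batch_size
expected_pos_balanced = max(1, round(target_pos * batch_size))

print(f"P(batch has no positives), natural sampling: {p_all_negative:.3f}")
print(f"Expected positives per batch, natural:  {expected_pos_natural:.3f}")
print(f"Expected positives per batch, balanced: {expected_pos_balanced}")
```

Under these assumptions, roughly 94% of naturally sampled batches would contain no keyword at all, whereas a balanced batch carries about 6 positives.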

In [None]:
# Ultra-fast one-cell demo: natural vs balanced composition
# - Exact if train_ds exposes cached labels (labels/y/targets)
# - Otherwise uses a tiny random sample to estimate natural prevalence
# - Balanced: uses your global 'sampler' for 5 batches if present; else simulates ~10% positives

import time, numpy as np, matplotlib.pyplot as plt

assert 'train_ds' in globals(), "Missing train_ds. Run your dataset setup first."

B = int(CONFIG.get("batch_size", 64))
TARGET_POS = 0.10
NUM_ITERS = 5  # tiny, to keep it snappy
rng = np.random.default_rng(0)

# ---------- Try to get labels without touching MEG arrays ----------
def _get_label_array_if_any(ds):
    for name in ("labels", "y", "targets", "labels_np", "label_array"):
        if hasattr(ds, name):
            arr = getattr(ds, name)
            try:
                return np.asarray(arr, dtype=np.int64)
            except Exception:
                pass
    return None

labels_arr = _get_label_array_if_any(train_ds)
N = len(train_ds)

# ---------- Natural prevalence ----------
if labels_arr is not None and len(labels_arr) == N:
    # Fast path: exact, zero window loads
    nat_pos = int((labels_arr == 1).sum()); nat_tot = N
    nat_frac = nat_pos / max(1, nat_tot)
    nat_mode = "exact (cached labels)"
else:
    # Fast estimate: small random sample (min 512 or 0.2% of N)
    sample_n = int(max(512, 0.002 * N))
    sample_idx = rng.choice(np.arange(N), size=min(sample_n, N), replace=False)
    pos_hits = 0
    t0 = time.time()
    for j in sample_idx:
        yj = train_ds[j][1]  # accesses only a tiny subset
        pos_hits += int(yj)
    took = time.time() - t0
    nat_frac = pos_hits / max(1, len(sample_idx))
    nat_pos = int(round(nat_frac * N)); nat_tot = N
    nat_mode = f"estimated from {len(sample_idx)} samples ({took:.2f}s)"

# ---------- Balanced observation ----------
obs_pos = obs_tot = 0
start = time.time()
if 'sampler' in globals():
    # Light touch: just a few batches
    it = iter(sampler)
    for _ in range(NUM_ITERS):
        try:
            batch_indices = next(it)
        except StopIteration:
            break
        # Only read labels; keep count small for speed
        ys = [int(train_ds[j][1]) for j in batch_indices]
        obs_pos += sum(ys); obs_tot += len(ys)
    bal_mode = f"observed via global sampler ({NUM_ITERS} batches)"
else:
    # No sampler? Simulate balanced batches ~10% positives (no dataset access)
    for _ in range(NUM_ITERS):
        k_pos = int(round(TARGET_POS * B))
        k_neg = max(0, B - k_pos)
        obs_pos += k_pos; obs_tot += (k_pos + k_neg)
    bal_mode = f"simulated at TARGET_POS={TARGET_POS:.2f}"

elapsed = time.time() - start
obs_frac = (obs_pos / obs_tot) if obs_tot else TARGET_POS

# ---------- Plot ----------
labels_txt = ["Natural", "Balanced"]
pos_fracs = [nat_frac, obs_frac]
neg_fracs = [1 - nat_frac, 1 - obs_frac]

fig, ax = plt.subplots(figsize=(6, 4))
ax.bar(labels_txt, neg_fracs, label="negatives", color="#c7d4e8")
ax.bar(labels_txt, pos_fracs, bottom=neg_fracs, label="positives", color="#4c72b0")
ax.set_ylabel("Fraction in batch")
ax.set_title("Class composition: natural vs balanced")
ax.legend(loc="upper right")
ax.grid(axis='y', alpha=0.3)
plt.tight_layout(); plt.show()

print(f"Natural pos frac: {nat_frac:.5f}  | Balanced pos frac (observed): {obs_frac:.2f}")
print(f"Counts — Train: N≈{nat_tot:,} (pos≈{nat_pos:,}, neg≈{nat_tot - nat_pos:,})")
print(f"[natural] mode: {nat_mode}")
print(f"[balanced] mode: {bal_mode}  | elapsed={elapsed:.3f}s  | batch_size={B}")


Now, let's **train the model**! We define the model architecture, sampler, losses, and a PyTorch Lightning module (for convenience), then start the actual training run. Depending on your GPU, this may take **~30 minutes**.

In [None]:
import math, random
import torch
import torch.nn as nn
from dataclasses import dataclass
from torch.utils.data import DataLoader, Dataset, BatchSampler

# Metrics/figures deps
import numpy as np
import matplotlib.pyplot as plt

# Lightning
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger, CSVLogger

# ----------------------------- Model blocks -----------------------------
class ResNetBlock1D(nn.Module):
    def __init__(self, channels: int = 128):
        super().__init__()
        # Explicit padding that preserves length for kernel_size=3, stride=1
        # (equivalent to padding='same', and works on older PyTorch versions)
        pad3 = 1
        self.net = nn.Sequential(
            nn.ELU(), nn.Conv1d(channels, channels, 3, 1, pad3),
            nn.ELU(), nn.Conv1d(channels, channels, 1, 1, 0),
        )
    def forward(self, x):
        return x + self.net(x)

class SpeechDetectionNet(nn.Module):
    def __init__(self, in_channels: int = 306, lse_temperature: float = 0.5):
        super().__init__()
        # Explicit padding that preserves length for kernel_size=7, stride=1
        # (equivalent to padding='same', and works on older PyTorch versions)
        pad7 = 3
        self.trunk = nn.Sequential(
            nn.Conv1d(in_channels, 128, 7, 1, pad7),
            ResNetBlock1D(128),
            nn.ELU(),
            nn.Conv1d(128, 128, 50, 25, 0),
            nn.ELU(),
            nn.Conv1d(128, 128, 7, 1, pad7),
            nn.ELU(),
        )
        self.head = nn.Sequential(nn.Conv1d(128, 512, 4, 1, 0), nn.ReLU(), nn.Dropout(0.5))
        self.logits_t = nn.Conv1d(512, 1, 1, 1, 0)
        self.attn_t   = nn.Conv1d(512, 1, 1, 1, 0)
    def forward(self, x):
        h = self.head(self.trunk(x))
        logit_t = self.logits_t(h)
        attn = torch.softmax(self.attn_t(h), dim=-1)
        return (logit_t * attn).sum(dim=-1).squeeze(1)

@dataclass
class OptimConfig:
    lr: float = 1e-4
    weight_decay: float = 1e-4
    max_time_shift: int = 4
    noise_std: float = 0.01

class FocalLoss(nn.Module):
    def __init__(self, alpha: float = 0.95, gamma: float = 2.0):
        super().__init__(); self.alpha=float(alpha); self.gamma=float(gamma)
    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        ce = nn.functional.binary_cross_entropy_with_logits(logits, targets.float(), reduction='none')
        p = torch.sigmoid(logits); pt = torch.where(targets == 1, p, 1 - p)
        alpha_t = torch.where(targets == 1, logits.new_tensor(self.alpha), logits.new_tensor(1 - self.alpha))
        return (alpha_t * (1 - pt).pow(self.gamma) * ce).mean()

def pairwise_logistic_loss(logits: torch.Tensor, labels: torch.Tensor, max_pairs: int = 4096) -> torch.Tensor:
    pos_idx = (labels == 1).nonzero(as_tuple=False).view(-1)
    neg_idx = (labels == 0).nonzero(as_tuple=False).view(-1)
    if pos_idx.numel() == 0 or neg_idx.numel() == 0: return logits.new_zeros(())
    num_pairs = min(max_pairs, int(pos_idx.numel()) * int(neg_idx.numel()))
    pi = pos_idx[torch.randint(0, pos_idx.numel(), (num_pairs,), device=logits.device)]
    ni = neg_idx[torch.randint(0, neg_idx.numel(), (num_pairs,), device=logits.device)]
    return torch.nn.functional.softplus(-(logits[pi] - logits[ni])).mean()

class BalancedBatchSampler(BatchSampler):
    def __init__(self, pos_idx, neg_idx, batch_size: int, pos_fraction: float = 0.1):
        assert 0.0 < pos_fraction < 1.0 and len(pos_idx) > 0
        self.p_idx, self.n_idx = list(pos_idx), list(neg_idx)
        self.batch_size = batch_size
        self.n_pos = max(1, int(round(batch_size * pos_fraction)))
        self.n_neg = batch_size - self.n_pos
        total = len(self.p_idx) + len(self.n_idx)
        self._epoch_len = max(1, total // batch_size)
    def __iter__(self):
        p, n = self.p_idx[:], self.n_idx[:]
        random.shuffle(p); random.shuffle(n); pi = ni = 0
        for _ in range(self._epoch_len):  # finite epoch, consistent with __len__
            if pi + self.n_pos > len(p): random.shuffle(p); pi = 0
            if ni + self.n_neg > len(n): random.shuffle(n); ni = 0
            batch = p[pi:pi+self.n_pos] + n[ni:ni+self.n_neg]
            pi += self.n_pos; ni += self.n_neg
            random.shuffle(batch); yield batch
    def __len__(self): return self._epoch_len

# --------------------- Lightning module with rich metrics ---------------------
class KeywordDetectorPL(pl.LightningModule):
    def __init__(self, in_channels: int = 306, opt: OptimConfig = OptimConfig(), pairwise_lambda: float = 0.5):
        super().__init__()
        self.save_hyperparameters()
        self.model = SpeechDetectionNet(in_channels)
        self.criterion = FocalLoss(alpha=0.95, gamma=2.0)
        self.pairwise_lambda = float(pairwise_lambda)

        self._val_probs, self._val_targets = [], []
        self._test_probs, self._test_targets = [], []

    def forward(self, x): return self.model(x)

    def _augment(self, x):
        if not self.training: return x
        smax = self.hparams.opt.max_time_shift
        if smax and smax > 0:
            x = x.clone()  # avoid mutating the caller's batch in place
            shifts = torch.randint(-smax, smax + 1, (x.size(0),), device=x.device)
            for i, sh in enumerate(shifts):
                if int(sh) != 0: x[i] = torch.roll(x[i], int(sh), dims=-1)
        sigma = self.hparams.opt.noise_std
        return x + torch.randn_like(x) * sigma if (sigma and sigma > 0) else x

    # -- steps --
    def training_step(self, batch, _):
        x, y = batch
        logits = self(self._augment(x))
        focal = self.criterion(logits.float(), y.float())
        # Keep the graph attached so the ranking term actually contributes gradients
        pairwise = pairwise_logistic_loss(logits.float(), y)
        loss = focal + self.pairwise_lambda * pairwise
        self.log_dict(
            {"train/loss": loss, "train/focal": focal, "train/pairwise": pairwise},
            on_step=False, on_epoch=True, prog_bar=False
        )
        return loss

    def validation_step(self, batch, _):
        x, y = batch
        logits = self(x)
        loss = nn.functional.binary_cross_entropy_with_logits(logits.float(), y.float())
        probs = torch.sigmoid(logits.detach()).float().view(-1).cpu()
        self._val_probs.append(probs)
        self._val_targets.append(y.detach().float().view(-1).cpu())
        self.log("val/loss", loss, on_step=False, on_epoch=True, prog_bar=True)

    def test_step(self, batch, _):
        x, y = batch
        logits = self(x)
        loss = nn.functional.binary_cross_entropy_with_logits(logits.float(), y.float())
        probs = torch.sigmoid(logits.detach()).float().view(-1).cpu()
        self._test_probs.append(probs)
        self._test_targets.append(y.detach().float().view(-1).cpu())
        self.log("test/loss", loss, on_step=False, on_epoch=True, prog_bar=True)

    # -- epoch hooks --
    def on_validation_epoch_start(self):
        self._val_probs, self._val_targets = [], []

    def on_validation_epoch_end(self):
        if len(self._val_probs) == 0: return
        probs = torch.cat(self._val_probs).numpy()
        y = torch.cat(self._val_targets).numpy().astype(np.int64)
        metrics, figs = self._compute_all_metrics(y, probs)
        for k, v in metrics.items():
            if np.isnan(v): continue
            self.log(f"val/{k}", float(v), prog_bar=(k in ["auprc","roc_auc","best_f1","uplift"]))
        self._log_figs_tensorboard(figs, split="val")

    def on_test_epoch_start(self):
        self._test_probs, self._test_targets = [], []

    def on_test_epoch_end(self):
        if len(self._test_probs) == 0: return
        probs = torch.cat(self._test_probs).numpy()
        y = torch.cat(self._test_targets).numpy().astype(np.int64)
        metrics, figs = self._compute_all_metrics(y, probs)
        for k, v in metrics.items():
            if np.isnan(v): continue
            self.log(f"test/{k}", float(v), prog_bar=(k in ["auprc","roc_auc","best_f1","uplift"]))
        self._log_figs_tensorboard(figs, split="test")

    # -- optim --
    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=self.hparams.opt.lr, weight_decay=self.hparams.opt.weight_decay)

    # ------------------------ metric helpers ------------------------
    @staticmethod
    def _precision_recall_ap(y_true: np.ndarray, y_score: np.ndarray):
        order = np.argsort(-y_score, kind="mergesort")
        y_sorted = y_true[order]
        tp = np.cumsum(y_sorted)
        fp = np.cumsum(1 - y_sorted)
        P = tp[-1] if tp.size else 0

        if P == 0:
            recall = np.array([0.0, 1.0])
            precision = np.array([1.0, 0.0])
            ap = np.nan
            return precision, recall, ap

        recall = tp / P
        precision = tp / np.maximum(tp + fp, 1)

        recall = np.concatenate(([0.0], recall))
        precision = np.concatenate(([1.0], precision))

        for i in range(precision.size - 2, -1, -1):
            precision[i] = max(precision[i], precision[i + 1])
        ap = np.sum((recall[1:] - recall[:-1]) * precision[1:])
        return precision, recall, float(ap)

    @staticmethod
    def _roc_auc_pairwise(y_true: np.ndarray, y_score: np.ndarray, max_pairs: int = 2_000_000):
        P = int(y_true.sum()); N = int((1 - y_true).sum())
        if P == 0 or N == 0: return np.nan
        pos = y_score[y_true == 1]
        neg = y_score[y_true == 0]
        total = P * N
        if total <= max_pairs:
            gt = (pos[:, None] > neg[None, :]).mean()
            eq = (pos[:, None] == neg[None, :]).mean()
            return float(gt + 0.5 * eq)
        rng = np.random.default_rng(42)
        pi = rng.integers(0, P, size=max_pairs)
        ni = rng.integers(0, N, size=max_pairs)
        ps = pos[pi]; ns = neg[ni]
        gt = (ps > ns).mean(); eq = (ps == ns).mean()
        return float(gt + 0.5 * eq)

    @staticmethod
    def _threshold_metrics(y_true: np.ndarray, y_score: np.ndarray, thresh: float):
        yhat = (y_score >= thresh).astype(np.int64)
        tp = int(np.sum((yhat == 1) & (y_true == 1)))
        fp = int(np.sum((yhat == 1) & (y_true == 0)))
        fn = int(np.sum((yhat == 0) & (y_true == 1)))
        tn = int(np.sum((yhat == 0) & (y_true == 0)))
        prec = tp / (tp + fp) if (tp + fp) > 0 else 0.0
        rec = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        f1 = 2 * prec * rec / (prec + rec) if (prec + rec) > 0 else 0.0
        return dict(tp=tp, fp=fp, fn=fn, tn=tn, precision=prec, recall=rec, f1=f1)

    @staticmethod
    def _best_f1(y_true: np.ndarray, y_score: np.ndarray):
        if y_true.sum() == 0 or (1 - y_true).sum() == 0:
            return dict(threshold=np.nan, precision=np.nan, recall=np.nan, f1=np.nan)
        qs = np.unique(np.percentile(y_score, np.linspace(0, 100, 201)))
        best = {"f1": -1.0, "threshold": 0.5, "precision": 0.0, "recall": 0.0}
        for t in qs:
            m = KeywordDetectorPL._threshold_metrics(y_true, y_score, float(t))
            if m["f1"] > best["f1"]:
                best.update(m); best["threshold"] = float(t)
        return best

    def _compute_all_metrics(self, y_true: np.ndarray, y_score: np.ndarray):
        prevalence = float(y_true.mean()) if y_true.size else np.nan
        brier = float(np.mean((y_score - y_true) ** 2)) if y_true.size else np.nan
        precision, recall, ap = self._precision_recall_ap(y_true, y_score)
        roc_auc = self._roc_auc_pairwise(y_true, y_score)

        at_05 = self._threshold_metrics(y_true, y_score, 0.5)
        best = self._best_f1(y_true, y_score)

        uplift = (ap / prevalence) if (prevalence and prevalence > 0 and ap == ap) else np.nan

        metrics = {
            "auprc": ap if ap == ap else np.nan,
            "roc_auc": roc_auc,
            "brier": brier,
            "prevalence": prevalence,
            "uplift": uplift,
            "prec_at_0.5": at_05["precision"],
            "recall_at_0.5": at_05["recall"],
            "f1_at_0.5": at_05["f1"],
            "best_f1": best["f1"],
            "best_f1_threshold": best["threshold"],
            "best_f1_precision": best["precision"],
            "best_f1_recall": best["recall"],
            "n_samples": float(y_true.size),
            "n_pos": float(y_true.sum()),
            "n_neg": float((1 - y_true).sum()),
        }

        figs = {}
        # PR
        fig_pr = plt.figure()
        plt.plot(recall, precision)
        plt.xlabel("Recall"); plt.ylabel("Precision")
        title = f"PR (AP={metrics['auprc']:.4f})" if metrics["auprc"] == metrics["auprc"] else "PR"
        plt.title(title)
        plt.grid(True, alpha=0.3)
        figs["pr"] = fig_pr

        # ROC (if valid)
        if not np.isnan(roc_auc):
            thr = np.unique(np.percentile(y_score, np.linspace(0, 100, 201)))
            tprs, fprs = [], []
            P = max(1, int(y_true.sum())); N = max(1, int((1 - y_true).sum()))
            for t in thr:
                yhat = (y_score >= t).astype(np.int64)
                tp = np.sum((yhat == 1) & (y_true == 1))
                fp = np.sum((yhat == 1) & (y_true == 0))
                tprs.append(tp / P); fprs.append(fp / N)
            fig_roc = plt.figure()
            plt.plot(fprs, tprs)
            plt.plot([0, 1], [0, 1], linestyle="--", linewidth=1)
            plt.xlabel("FPR"); plt.ylabel("TPR")
            plt.title(f"ROC (AUC={roc_auc:.4f})")
            plt.grid(True, alpha=0.3)
            figs["roc"] = fig_roc

        return metrics, figs

    def _log_figs_tensorboard(self, figs: dict, split: str):
        writers = []
        try:
            candidates = []
            if isinstance(self.logger, TensorBoardLogger):
                candidates = [self.logger]
            elif hasattr(self.trainer, "loggers") and self.trainer.loggers:
                candidates = [lg for lg in self.trainer.loggers if isinstance(lg, TensorBoardLogger)]
            elif self.logger is not None:
                candidates = [self.logger] if isinstance(self.logger, TensorBoardLogger) else []
            for lg in candidates:
                writers.append(lg.experiment)
        except Exception:
            pass

        if not writers:
            for f in figs.values(): plt.close(f)
            return

        global_step = int(self.current_epoch)
        for name, fig in figs.items():
            for w in writers:
                try:
                    w.add_figure(f"{name}/{split}", fig, global_step=global_step, close=True)
                except Exception:
                    try:
                        w.add_figure(f"{name}/{split}", fig, global_step=global_step)
                    except Exception:
                        pass
            plt.close(fig)

# ------------------------ Sampler + DataLoaders ------------------------
# Assumes train_ds/val_ds/test_ds (and val_loader/test_loader) are defined elsewhere.
pos_idx, neg_idx = [], []
for i in range(len(train_ds)):
    _, y = train_ds[i]
    (pos_idx if int(y)==1 else neg_idx).append(i)
print(f"Found {len(pos_idx)} positives / {len(pos_idx)+len(neg_idx)} total in train")

sampler = BalancedBatchSampler(pos_idx, neg_idx, batch_size=CONFIG["batch_size"], pos_fraction=0.10)
train_loader_bal = DataLoader(train_ds, batch_sampler=sampler, num_workers=CONFIG["num_workers"])

# ------------------------------ Training --------------------------------
print("🚀 Starting training...")

# CSV logger always; add TB in Colab for curves
csv_logger = CSVLogger(save_dir=f"{BASE_PATH}/lightning_logs", name="kws", version=None)
logger = csv_logger
if IN_COLAB:
    tb_logger = TensorBoardLogger(save_dir=f"{BASE_PATH}/lightning_logs", name="tb", version=None)
    logger = [csv_logger, tb_logger]

# TensorBoard in Colab (optional)
if IN_COLAB:
    try:
        from google.colab import output
        %load_ext tensorboard
        %tensorboard --logdir {BASE_PATH}/lightning_logs
        print("📈 TensorBoard launched! Check the output above.")
    except Exception:
        print("📊 TensorBoard setup failed; continuing without it.")

model = KeywordDetectorPL(
    in_channels=train_ds[0][0].shape[0],
    opt=OptimConfig(
        lr=1e-4,
        weight_decay=1e-4,
        max_time_shift=4,
        noise_std=0.01
    )
)

trainer = pl.Trainer(
    devices="auto",
    max_epochs=CONFIG["max_epochs"],
    callbacks=[EarlyStopping(monitor="val/loss", mode="min", patience=CONFIG["early_patience"], verbose=True)],
    logger=logger,
    log_every_n_steps=50,
    check_val_every_n_epoch=1
)

pl.seed_everything(42)

print(f"📊 Training on {len(train_ds)} samples with balanced sampling")
print(f"✅ Validation on {len(val_ds)} samples")
print(f"🔬 Testing on {len(test_ds)} samples")

trainer.fit(model, train_loader_bal, val_loader)

print("\n🎯 Final evaluation on test set:")
trainer.test(model, dataloaders=test_loader)


## VI. Evaluation

Evaluating keyword detection models requires careful metric selection and interpretation. Traditional metrics can be misleading on imbalanced data—below are the ones that matter and how to interpret them for our keyword.

### Why Standard Metrics Fail

Accuracy is misleading on imbalanced data: always predicting "no" is correct over 99% of the time, yet it never detects a single keyword.

F1, on the other hand, can be dominated by precision when positive prevalence is very low.
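
A minimal sketch of the accuracy trap, using synthetic labels. The prevalence of 0.1% and the number of windows are assumptions chosen for illustration; no real LibriBrain data is involved.

```python
import numpy as np

rng = np.random.default_rng(0)
n_windows = 100_000   # assumed number of evaluation windows
prevalence = 0.001    # assumed keyword rate (~0.1%)

# Synthetic ground truth with roughly 0.1% positives
y_true = (rng.random(n_windows) < prevalence).astype(int)

# Trivial model: never predict the keyword
y_pred = np.zeros_like(y_true)

accuracy = float((y_pred == y_true).mean())
tp = int(((y_pred == 1) & (y_true == 1)).sum())
recall = tp / max(1, int(y_true.sum()))

print(f"accuracy={accuracy:.4f}  recall={recall:.1f}")
```

The trivial model scores near-perfect accuracy while its recall is zero, which is why accuracy is not reported as a headline metric here.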

### Metrics That Actually Matter
For the Neural Keyword Spotting task, we standardize two key evaluation dimensions:

**Threshold-free Metrics**
AUPRC (area under the precision–recall curve):
- Baseline equals positive class prevalence. For Watson, p≈0.001 (0.1%).
- Aim for values clearly above 0.001; improvements of 2–10× over chance are meaningful.

Precision–Recall trade-off:
- Precision: fraction of predicted keywords that are correct (controls false alarms)
- Recall: fraction of true keywords detected

AUROC (secondary):
- Useful for architecture comparison, but optimistic under heavy imbalance.
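
The claim that chance-level AUPRC equals prevalence can be checked numerically. The sketch below uses synthetic labels and scores; the helper `average_precision`, the sample size, and the strength of the "informative" signal are all assumptions for illustration. It implements step-wise average precision with plain NumPy and does not handle tied scores specially.

```python
import numpy as np

def average_precision(y_true, y_score):
    """Step-wise AP: mean of the precision values at each positive's rank."""
    order = np.argsort(-y_score, kind="mergesort")
    y_sorted = y_true[order]
    tp = np.cumsum(y_sorted)
    precision = tp / np.arange(1, len(y_sorted) + 1)
    return float((precision * y_sorted).sum() / max(1, y_sorted.sum()))

rng = np.random.default_rng(0)
n, prevalence = 200_000, 0.001            # assumed sizes for illustration
y = (rng.random(n) < prevalence).astype(int)

random_scores = rng.random(n)             # uninformative model
informative_scores = random_scores + 0.5 * y  # hypothetical model with real signal

ap_rand = average_precision(y, random_scores)
ap_info = average_precision(y, informative_scores)

print(f"prevalence     = {y.mean():.5f}")
print(f"AP random      = {ap_rand:.5f}  (uplift {ap_rand / y.mean():.1f}x)")
print(f"AP informative = {ap_info:.5f}  (uplift {ap_info / y.mean():.1f}x)")
```

Reporting the uplift (AP divided by prevalence) alongside the raw AP makes results comparable across keywords with different base rates.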


**User-facing Deployment Metrics**
False alarms per hour (FA/h):
- Target <10 FA/h. Compute as (FP / total_seconds) × 3600.

Operating point selection:
- Choose threshold on validation to meet FA/h or precision targets; report test results at that threshold.
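
One way to turn these two rules into code is sketched below. The function names, the synthetic scores, and the two-hour coverage per split are assumptions for illustration; in the notebook you would pass your own validation/test labels, scores, and the evaluation coverage in seconds.

```python
import numpy as np

def false_alarms_per_hour(y_true, y_score, threshold, total_seconds):
    """FA/h = (FP / total_seconds) * 3600 at the given threshold."""
    fp = int(((y_score >= threshold) & (y_true == 0)).sum())
    return fp * 3600.0 / total_seconds

def pick_threshold_for_fa_target(y_true, y_score, total_seconds, max_fa_per_hour):
    """Lowest threshold whose FA/h on this split stays within the target."""
    allowed_fp = int(np.floor(max_fa_per_hour * total_seconds / 3600.0))
    neg_desc = np.sort(y_score[y_true == 0])[::-1]  # negative scores, descending
    if allowed_fp >= len(neg_desc):
        return float("-inf")  # even firing on every negative meets the target
    # Just above the (allowed_fp + 1)-th highest negative score
    return float(np.nextafter(neg_desc[allowed_fp], np.inf))

rng = np.random.default_rng(0)

def make_split(n, prevalence=0.001):
    """Synthetic split: positives get a mean shift of +2 (assumed signal)."""
    y = (rng.random(n) < prevalence).astype(int)
    s = rng.normal(0.0, 1.0, n) + 2.0 * y
    return y, s

val_seconds = test_seconds = 2 * 3600.0  # assumed coverage per split
y_val, s_val = make_split(50_000)
y_test, s_test = make_split(50_000)

# 1) Choose the operating point on validation only
thr = pick_threshold_for_fa_target(y_val, s_val, val_seconds, max_fa_per_hour=10)
fa_val = false_alarms_per_hour(y_val, s_val, thr, val_seconds)

# 2) Report test metrics at that frozen threshold
fa_test = false_alarms_per_hour(y_test, s_test, thr, test_seconds)
tp_test = int(((s_test >= thr) & (y_test == 1)).sum())
recall_test = tp_test / max(1, int(y_test.sum()))

print(f"threshold={thr:.3f}  FA/h val={fa_val:.1f}  FA/h test={fa_test:.1f}  recall test={recall_test:.2f}")
```

Freezing the threshold before touching the test split keeps the reported FA/h honest: the test value may land slightly above or below the target, and that gap is itself informative.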

### Broad Performance Interpretation
- Chance: prevalence (the fraction of words that are our keyword)
- 2-5x chance: modest improvement
- 5-10x chance: reasonable for this task
- Over 10x chance: strong for this dataset

Let's see how we stack up:

In [None]:
# --- Metrics & curves + structured logs & nicer visuals ---
import json
from sklearn.metrics import (
    precision_recall_curve, roc_curve, auc, confusion_matrix,
    precision_score, recall_score, f1_score, accuracy_score
)
import numpy as np, torch

# Run preds
model.eval()
device = next(model.parameters()).device
all_probs, all_true = [], []
with torch.no_grad():
    for xb, yb in test_loader:
        xb = xb.to(device, non_blocking=True)
        logits = model(xb)
        probs = torch.sigmoid(logits).detach().cpu().float().view(-1)
        all_probs.append(probs)
        all_true.append(yb.cpu().int().view(-1))

y_prob = torch.cat(all_probs).numpy().astype(float)
y_true = torch.cat(all_true).numpy().astype(int)


# PR / ROC
prec, rec, thr = precision_recall_curve(y_true, y_prob)
pr_auc = auc(rec, prec)

try:
    fpr, tpr, roc_thr = roc_curve(y_true, y_prob)
    roc_auc = auc(fpr, tpr)
except Exception:
    fpr, tpr, roc_thr = None, None, None
    roc_auc = float('nan')

prevalence = float(y_true.mean())
uplift = (pr_auc / prevalence) if prevalence > 0 else float('nan')

# Operating point @ threshold 0.5
tau = 0.5
y_pred = (y_prob >= tau).astype(int)
cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
tn, fp, fn, tp = cm.ravel()

# Rates @ 0.5
prec05 = precision_score(y_true, y_pred, zero_division=0)
rec05  = recall_score(y_true, y_pred, zero_division=0)
f105   = f1_score(y_true, y_pred, zero_division=0)
acc05  = accuracy_score(y_true, y_pred)
spec05 = tn / (tn + fp) if (tn + fp) > 0 else float('nan')
fpr05  = 1 - spec05
tpr05  = rec05

# --- Human-readable summary ---
print("\n=== Test Summary ===")
print(f"Samples                      : {len(y_true):,} "
      f"(pos={int(y_true.sum()):,}, neg={int((1-y_true).sum()):,})")
print(f"Prevalence (chance AUPRC)    : {prevalence:.6f}")
print(f"AUPRC                        : {pr_auc:.4f}  | uplift vs. chance: {uplift:.2f}×")
print(f"AUROC                        : {roc_auc:.4f}" if not np.isnan(roc_auc) else "AUROC                        : n/a")
print(f"Threshold τ                  : {tau:.2f}")
print(f"Precision@τ                  : {prec05:.4f}")
print(f"Recall@τ (TPR)               : {rec05:.4f}")
print(f"Specificity@τ (TNR)          : {spec05:.4f}")
print(f"FPR@τ                        : {fpr05:.4f}")
print(f"F1@τ                         : {f105:.4f}")
print(f"Accuracy@τ                   : {acc05:.4f}")

# --- Curves with operating point marked ---
fig, ax = plt.subplots(1, 2, figsize=(12, 4.2))

# PR
ax[0].plot(rec, prec, lw=1.8, label=f'PR curve (AUC={pr_auc:.3f})')
ax[0].axhline(y=prevalence, color='r', linestyle='--', alpha=0.7, label=f'Chance ({prevalence:.3f})')
ax[0].scatter(rec05, prec05, s=40, marker='o', edgecolor='k', zorder=5,
              label=f'@ τ={tau:.2f}  (P={prec05:.2f}, R={rec05:.2f})')
ax[0].set_xlabel("Recall")
ax[0].set_ylabel("Precision")
ax[0].set_title("Precision–Recall")
ax[0].legend()
ax[0].grid(True, alpha=0.3)

# ROC (if available)
if fpr is not None:
    ax[1].plot(fpr, tpr, lw=1.8, label=f'ROC curve (AUC={roc_auc:.3f})')
    ax[1].plot([0, 1], [0, 1], 'k--', alpha=0.5, label='Chance (0.5)')
    ax[1].scatter(fpr05, tpr05, s=40, marker='o', edgecolor='k', zorder=5,
                  label=f'@ τ={tau:.2f}  (FPR={fpr05:.2f}, TPR={tpr05:.2f})')
    ax[1].set_xlabel("False Positive Rate")
    ax[1].set_ylabel("True Positive Rate")
    ax[1].set_title("ROC")
    ax[1].legend()
    ax[1].grid(True, alpha=0.3)
else:
    ax[1].axis('off')

plt.tight_layout()
plt.show()

# --- Colorful + annotated confusion matrix (counts + row-normalized %) ---
def plot_confusion_matrix_annotated(cm, class_names=("Negative", "Positive"), cmap="viridis"):
    cm = cm.astype(np.float64)
    # Row-normalized for percentages
    row_sums = cm.sum(axis=1, keepdims=True)
    cm_norm = np.divide(cm, row_sums, out=np.zeros_like(cm), where=row_sums != 0)

    fig, ax = plt.subplots(figsize=(5.5, 4.6))
    im = ax.imshow(cm_norm, interpolation='nearest', cmap=cmap)
    cbar = plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
    cbar.ax.set_ylabel('Row-normalized fraction', rotation=270, va='bottom')

    ax.set_xticks([0, 1], labels=[f"Pred {n}" for n in class_names])
    ax.set_yticks([0, 1], labels=[f"True {n}" for n in class_names])
    ax.set_xlabel("Predicted")
    ax.set_ylabel("True")
    ax.set_title("Confusion Matrix (counts + row %)")
    ax.grid(False)

    # Annotate each cell with "xx% (count)"
    for i in range(2):
        for j in range(2):
            pct = f"{(cm_norm[i, j] * 100):.1f}%"
            cnt = f"{int(cm[i, j]):,}"
            text = f"{pct}\n({cnt})"
            ax.text(j, i, text, ha="center", va="center", fontsize=11,
                    color="white" if cm_norm[i, j] > 0.5 else "black")

    plt.tight_layout()
    plt.show()

print("\nConfusion Matrix (rows=true, cols=pred):")
print(cm)
print(f"True Negatives: {tn}, False Positives: {fp}")
print(f"False Negatives: {fn}, True Positives: {tp}")

plot_confusion_matrix_annotated(cm, class_names=("Negative", "Positive"), cmap="viridis")


Even at an AUPRC of 0.01, the model is, frankly, unusable in practice; the tiny uplift over chance in the PR curve's top-left makes that obvious. Still, it's **significantly above chance**, which matters: it signals that there's real information in the MEG that a better model or training regime could exploit.
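
If you want to back up an "above chance" claim with a number, a label-permutation test is one common option. The sketch below is an illustration on synthetic data, not the procedure used to produce the results above; `average_precision` and `permutation_pvalue` are hypothetical helpers, and in the notebook you would pass your own test labels and scores in place of the synthetic `y` and `s`.

```python
import numpy as np

def average_precision(y_true, y_score):
    """Step-wise AP: mean of the precision values at each positive's rank."""
    order = np.argsort(-y_score, kind="mergesort")
    y_sorted = y_true[order]
    tp = np.cumsum(y_sorted)
    precision = tp / np.arange(1, len(y_sorted) + 1)
    return float((precision * y_sorted).sum() / max(1, y_sorted.sum()))

def permutation_pvalue(y_true, y_score, n_perm=200, seed=0):
    """Compare the observed AP to APs obtained after shuffling the labels."""
    rng = np.random.default_rng(seed)
    observed = average_precision(y_true, y_score)
    null = np.array([
        average_precision(rng.permutation(y_true), y_score)
        for _ in range(n_perm)
    ])
    # Add-one correction so the p-value is never exactly zero
    p_value = (1 + int((null >= observed).sum())) / (1 + n_perm)
    return observed, null, float(p_value)

# Synthetic stand-in for test labels/scores (assumed sizes and signal strength)
rng = np.random.default_rng(1)
y = (rng.random(20_000) < 0.005).astype(int)
s = rng.normal(0.0, 1.0, y.size) + 2.0 * y

observed, null, p_value = permutation_pvalue(y, s)
print(f"observed AP={observed:.4f}  null mean={null.mean():.4f}  p={p_value:.4f}")
```

With `n_perm` permutations the smallest reportable p-value is `1 / (n_perm + 1)`, so increase `n_perm` if you need finer resolution.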

To make this more actionable, we complement AUPRC with **false alarms per hour (FA/h) at fixed recall**. On the next chart, you want your model to be on the top-left. We report FA/h at a few recall targets and compare against two baselines: a **random (permuted-scores)** chance model and the **always-negative** trivial model. This turns “a small AUPRC” into an operational question: *at recall 0.2/0.4/0.6, how many false alarms per hour would you actually get—and is your performance better than nothing?*

In [None]:
# === False Alarms per Hour (FA/h) @ fixed recall ===
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import precision_recall_curve, confusion_matrix

# y_true: 0/1 labels; y_prob: model probabilities for the same windows
y_true_np = np.asarray(y_true).astype(int)
y_prob_np = np.asarray(y_prob).astype(float)
N = len(y_true_np)

# -------------------- Infer evaluation coverage hours from CONFIG / dataset --------------------
def infer_window_seconds(CONFIG, ds=None):
    # 1) If dataset exposes exact window span, prefer that
    for attr_pair in [
        ("window_seconds", None),
        ("tmin", "tmax"),            # many datasets expose tmin/tmax
        ("window_tmin", "window_tmax"),
    ]:
        a, b = attr_pair
        if ds is not None and hasattr(ds, a) and (b is None or hasattr(ds, b)):
            if b is None:
                return float(getattr(ds, a))
            return float(getattr(ds, b) - getattr(ds, a))
    # 2) Otherwise: derive from CONFIG buffers (keyword duration varies, but buffers dominate)
    neg = float(CONFIG.get("negative_buffer", 0.0))
    pos = float(CONFIG.get("positive_buffer", 0.0))
    # If tmin/tmax explicitly set in CONFIG, use that
    tmin = CONFIG.get("tmin", None)
    tmax = CONFIG.get("tmax", None)
    if tmin is not None and tmax is not None:
        return float(tmax - tmin)
    # Default: buffers only (conservative)
    return neg + pos

def infer_stride_seconds(CONFIG, ds=None, window_seconds=None):
    # Prefer dataset-provided stride/hop attributes if present
    for name in ["stride_seconds", "hop_seconds", "hop_s", "step_seconds"]:
        if ds is not None and hasattr(ds, name):
            return float(getattr(ds, name))
    # If CONFIG ever includes stride_seconds, use it
    if "stride_seconds" in CONFIG:
        return float(CONFIG["stride_seconds"])
    # Otherwise assume 50% overlap (robust default for tutorials)
    return (window_seconds or infer_window_seconds(CONFIG, ds)) / 2.0

def infer_hours_total(CONFIG, ds=None, y_len=None):
    # If dataset carries explicit coverage (best case), use it
    for name in ["coverage_seconds", "total_seconds", "eval_seconds"]:
        if ds is not None and hasattr(ds, name):
            return float(getattr(ds, name)) / 3600.0, f"{name} from dataset"
    # If dataset has an index with per-window [t0,t1], compute the union duration
    idx = None
    for name in ["index_df", "windows_df", "index"]:
        if ds is not None and hasattr(ds, name):
            idx = getattr(ds, name)
            break
    if idx is not None:
        # Try the most likely column name pairs
        for lcol, rcol in [("t0", "t1"), ("start_s", "end_s"), ("left_s", "right_s")]:
            if lcol in idx and rcol in idx:
                intervals = np.array(idx[[lcol, rcol]], dtype=float)
                # union of intervals
                order = np.argsort(intervals[:,0])
                intervals = intervals[order]
                total = 0.0
                cur_l, cur_r = intervals[0]
                for l, r in intervals[1:]:
                    if l <= cur_r:
                        cur_r = max(cur_r, r)
                    else:
                        total += (cur_r - cur_l)
                        cur_l, cur_r = l, r
                total += (cur_r - cur_l)
                return total / 3600.0, "union of window intervals from dataset index"
    # Otherwise: estimate from window + stride implied by CONFIG
    ws = infer_window_seconds(CONFIG, ds)
    st = infer_stride_seconds(CONFIG, ds, ws)
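    n_windows = int(y_len or 0)
    # With no windows there is no coverage; otherwise n windows span (n - 1) strides plus one window
    seconds = max(0.0, (n_windows - 1) * st + ws) if n_windows > 0 else 0.0
    return seconds / 3600.0, "estimated from window length and stride (stride defaults to 50% overlap)"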

# Try to find a test dataset object if it's around
_test_ds = None
for cand in ["test_ds", "test_dataset", "test_loader"]:
    if cand in globals():
        obj = globals()[cand]
        # A DataLoader wraps its dataset in `.dataset`; a plain dataset is used as-is
        _test_ds = getattr(obj, "dataset", obj)
        break

hours_total, hours_source = infer_hours_total(CONFIG, ds=_test_ds, y_len=N)
window_seconds = infer_window_seconds(CONFIG, ds=_test_ds)
stride_seconds = infer_stride_seconds(CONFIG, ds=_test_ds, window_seconds=window_seconds)
print(f"[coverage] hours_total={hours_total:.3f}h  (source: {hours_source})  "
      f"| window={window_seconds:.3f}s  stride≈{stride_seconds:.3f}s")

# -------------------- Core helpers --------------------
def _metrics_at_threshold(y_true, y_prob, tau, hours_total):
    y_pred = (y_prob >= tau).astype(int)
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0,1]).ravel()
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    prec   = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    fah    = fp / max(hours_total, 1e-12)
    return dict(threshold=float(tau), tp=int(tp), fp=int(fp), fn=int(fn), tn=int(tn),
                recall=recall, precision=prec, fa_per_hour=fah)

def _threshold_for_recall(y_true, y_prob, recall_target):
    # Map recall target → threshold via PR curve
    prec, rec, thr = precision_recall_curve(y_true, y_prob)
    if len(thr) == 0:
        return 0.5
    rec_seg = rec[:-1]              # rec[i] pairs with thr[i]; the final recall entry (0) has no threshold
    rec_inc = rec_seg[::-1]         # ascending recall, as np.interp expects
    thr_inc = thr[::-1]
    r = float(np.clip(recall_target, rec_inc.min(), rec_inc.max()))
    return float(np.interp(r, rec_inc, thr_inc))

def _fa_curve(y_true, y_prob, hours_total):
    _, rec, thr = precision_recall_curve(y_true, y_prob)
    if len(thr) == 0:
        return np.array([0.0]), np.array([0.0]), np.array([0.5])
    all_thr = np.unique(np.concatenate([[-np.inf], thr, [np.inf]]))
    recalls, fahs = [], []
    for t in all_thr:
        m = _metrics_at_threshold(y_true, y_prob, t, hours_total)
        recalls.append(m["recall"])
        fahs.append(m["fa_per_hour"])
    return np.array(recalls), np.array(fahs), all_thr

# -------------------- Baselines --------------------
rng = np.random.default_rng(0)
y_prob_perm = rng.permutation(y_prob_np)  # random (same distribution; broken ordering)
always_neg_point = dict(recall=0.0, fa_per_hour=0.0)

# -------------------- FA/h at fixed recall targets --------------------
recall_targets = [0.20, 0.40, 0.60]

rows = []
rows_baseline = []
for r in recall_targets:
    tau = _threshold_for_recall(y_true_np, y_prob_np, r)
    rows.append({"recall_target": r, **_metrics_at_threshold(y_true_np, y_prob_np, tau, hours_total)})

    tau_b = _threshold_for_recall(y_true_np, y_prob_perm, r)
    rows_baseline.append({"recall_target": r, **_metrics_at_threshold(y_true_np, y_prob_perm, tau_b, hours_total)})

def _row_fmt(r):
    return (f"r*={r['recall_target']:.2f} | τ={r['threshold']:.3f} | "
            f"FA/h={r['fa_per_hour']:.2f} | P={r['precision']:.3f} | R={r['recall']:.3f}  "
            f"| TP={r['tp']}, FP={r['fp']}, FN={r['fn']}, TN={r['tn']}")

print("\n=== FA/h at fixed recall (model) ===")
for r in rows: print(_row_fmt(r))
print("\n--- Baseline: random (permuted scores) ---")
for r in rows_baseline: print(_row_fmt(r))
print("\n--- Trivial baseline: always negative ---")
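_n_pos = int(np.sum(y_true_np == 1))  # every positive is missed by this baseline
_n_neg = int(np.sum(y_true_np == 0))  # every negative is correctly rejected
print(f"r*=0.00 | τ=n/a  | FA/h=0.00 | P=0.000 | R=0.000  | TP=0, FP=0, FN={_n_pos}, TN={_n_neg}")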

# -------------------- Plot: Recall vs FA/h (axes flipped) --------------------
rec_m, fah_m, _ = _fa_curve(y_true_np, y_prob_np, hours_total)
rec_b, fah_b, _ = _fa_curve(y_true_np, y_prob_perm, hours_total)

fig, ax = plt.subplots(figsize=(7.6, 4.4))
ax.plot(fah_m, rec_m, label="Model")
ax.plot(fah_b, rec_b, label="Random baseline (permuted)")
ax.scatter(always_neg_point["fa_per_hour"], always_neg_point["recall"], marker="x", s=60, label="Always negative")

for r in rows:
    ax.scatter(r["fa_per_hour"], r["recall"], s=36)
    ax.annotate(f"r*={r['recall_target']:.2f}\nτ={r['threshold']:.2f}",
                (r["fa_per_hour"], r["recall"]), textcoords="offset points", xytext=(6,6))

ax.set_xlabel("False alarms per hour (FA/h)")
ax.set_ylabel("Recall")
ax.set_title("Recall vs False Alarms per Hour")
ax.grid(True, alpha=0.3)
ax.legend(loc="lower right")
plt.tight_layout(); plt.show()


## VII. Next Steps

Congratulations! You've successfully built and trained a neural keyword spotting model! This final section suggests experiments and research directions for pushing performance further.

### 🎯 Immediate Experiments

Here are some things you can try right now, in this notebook:

**Keyword Selection**
```python
# Try different keywords with varying properties:
common_words = ["the", "and", "of", "to"]        # High frequency, low precision
medium_words = ["holmes", "watson", "sherlock"]   # Medium frequency, higher precision  
rare_words = ["magnifying", "deduction", "pipe"]  # Low frequency, potentially high precision
```
*Research question*: How does keyword frequency affect detection difficulty?
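
One way to start on this question is to count how often each candidate keyword occurs before training anything. The sketch below assumes you already have the transcript as a plain list of word strings. The `keyword_rates` helper and the toy `words` list are hypothetical and not part of `pnpl`.

```python
from collections import Counter

def keyword_rates(words, candidates):
    """Return {keyword: fraction of all words} for each candidate keyword."""
    counts = Counter(w.lower() for w in words)
    total = sum(counts.values())
    if total == 0:
        return {k: 0.0 for k in candidates}
    return {k: counts[k.lower()] / total for k in candidates}

# Toy transcript (hypothetical); replace with the words from your event files.
words = "the game is afoot said Holmes to Watson and the doctor".split()
rates = keyword_rates(words, ["the", "holmes", "pipe"])
for kw, rate in rates.items():
    print(f"{kw:>8}: {rate:.3f}")
```

A keyword with a very low rate yields few positive windows, so expect noisier metrics for it.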

**Sample Length Optimisation**
```python
# Experiment with different context windows:
short_context = {"negative_buffer": 0.05, "positive_buffer": 0.15}  # 200 ms of buffer in total
medium_context = {"negative_buffer": 0.10, "positive_buffer": 0.30} # 400 ms of buffer in total
long_context = {"negative_buffer": 0.20, "positive_buffer": 0.50}   # 700 ms of buffer in total
```
*Research question*: What's the optimal temporal context for keyword detection?
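
To compare these settings on equal footing, it can help to translate each one into a number of samples. The sketch below assumes that a window spans the negative buffer, the keyword itself, and the positive buffer. The keyword duration and sampling rate (`KEYWORD_SECONDS`, `SFREQ_HZ`) are assumed values, and `window_samples` is a hypothetical helper. Check both numbers against your dataset before relying on the output.

```python
SFREQ_HZ = 250          # assumed sampling rate; verify against your recordings
KEYWORD_SECONDS = 0.30  # hypothetical average keyword duration

contexts = {
    "short":  {"negative_buffer": 0.05, "positive_buffer": 0.15},
    "medium": {"negative_buffer": 0.10, "positive_buffer": 0.30},
    "long":   {"negative_buffer": 0.20, "positive_buffer": 0.50},
}

def window_samples(cfg, keyword_seconds=KEYWORD_SECONDS, sfreq=SFREQ_HZ):
    """Samples per window, assuming buffers are added around the keyword."""
    seconds = cfg["negative_buffer"] + keyword_seconds + cfg["positive_buffer"]
    return round(seconds * sfreq)

sizes = {name: window_samples(cfg) for name, cfg in contexts.items()}
for name, n in sizes.items():
    print(f"{name:>6}: {n} samples")
```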

**Architecture Modifications**
```python
# Test different model configurations:
configs = [
    {"model_dim": 64, "dropout": 0.3},   # Smaller model
    {"model_dim": 256, "dropout": 0.5},  # Larger model
]
```
*Research question*: Which architectural choices matter most for performance?
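
A small sweep keeps such comparisons fair by changing only the listed keys and holding everything else fixed. The following is a minimal sketch: `expand_configs` and the values in `base_config` are hypothetical, and in the notebook you would pass the real `CONFIG` and call your own training code inside the loop.

```python
import copy

def expand_configs(base, overrides):
    """Yield one full config per override dict, leaving `base` untouched."""
    for override in overrides:
        cfg = copy.deepcopy(base)
        cfg.update(override)
        yield cfg

# Hypothetical base values; in the notebook you would pass the real CONFIG.
base_config = {"model_dim": 128, "dropout": 0.4, "learning_rate": 1e-3}
sweep = [
    {"model_dim": 64, "dropout": 0.3},
    {"model_dim": 256, "dropout": 0.5},
]

runs = list(expand_configs(base_config, sweep))
for cfg in runs:
    print(cfg)  # train and evaluate with each cfg here
```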

### 🧠 Further Research Directions

If you are looking for larger research projects, these directions might be interesting:

**1. Multi-keyword Detection**
- *Challenge*: Multi-label classification with extreme imbalance
- *Benefit*: More practical for real BCI applications

**2. Cross-session Generalisation**
- *Challenge*: Neural patterns drift over time and sessions
- *Metric*: Performance degradation over time

**3. Real-time Implementation**
- *Challenge*: Causal processing, latency constraints
- *Metric*: Detection latency vs. accuracy trade-off
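
For the multi-keyword direction, one common way to cope with extreme imbalance is to weight positives per keyword by the negative-to-positive ratio, which is the kind of value the `pos_weight` argument of PyTorch's `BCEWithLogitsLoss` accepts. The sketch below computes such weights with NumPy on a toy label matrix. The labels, the `positive_class_weights` helper, and the `max_weight` cap are illustrative assumptions, not values from this tutorial.

```python
import numpy as np

def positive_class_weights(labels, max_weight=100.0):
    """Per-keyword weight = (#negatives / #positives), capped at max_weight.

    labels: array of shape (n_windows, n_keywords) with 0/1 entries.
    """
    labels = np.asarray(labels, dtype=float)
    n_pos = labels.sum(axis=0)
    n_neg = labels.shape[0] - n_pos
    # Keywords with no positives get the cap instead of a division by zero.
    weights = np.where(n_pos > 0, n_neg / np.maximum(n_pos, 1.0), max_weight)
    return np.minimum(weights, max_weight)

# Toy labels: 8 windows, 3 keywords (hypothetical data).
y = np.array([
    [1, 0, 0],
    [0, 0, 0],
    [0, 1, 0],
    [0, 0, 0],
    [1, 0, 0],
    [0, 0, 0],
    [0, 0, 0],
    [0, 0, 0],
])
weights = positive_class_weights(y)
print(weights)
```

The cap is a design choice: without it, a keyword that appears once in many thousands of windows would dominate the loss.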


### Conclusion

We hope you enjoyed the tutorial! If you have feedback or comments, please let us know at [REDACTED]. We believe that keyword spotting (KWS) has the potential to be the first practically useful application of non-invasive speech BCIs. Even a 1-bit channel might drastically improve patients' quality of life, so we wish you the best of luck in exploring it! 🧠✨
