In [None]:
# Notebook setup cell: uses IPython `%pip` magics, so it is not plain Python.
# Install dependencies
# NOTE(review): the next cell imports `pytorch_lightning`, while this line installs the
# `lightning` package; confirm `pytorch_lightning` is also available in this environment
# (e.g. pulled in by the modified pnpl package installed below).
%pip install -q mne_bids lightning torchmetrics scikit-learn plotly ipywidgets neptune

# Set up base path for dataset and related files
# NOTE(review): main() in the next cell uses data_path = "dataset", not base_path;
# confirm which location is intended.
base_path = "./libribrain"

# Install pnpl from local modified package
# (the path is relative to the notebook's working directory)
%pip install ../modified-pnpl/pnpl

# Remember to set the NEPTUNE_API_TOKEN and NEPTUNE_PROJECT environment variables
# before running the next cell; if either is missing, make_neptune_logger() skips
# Neptune logging and returns None.

In [None]:
#!/usr/bin/env python3
"""
Watson keyword detection on LibriBrain (MEG)
- Oversampling-only training (no class-weighting)
- Focal loss + pairwise ranking aux loss
- Temporal attention pooling backbone
- Robust, PR-friendly validation/test diagnostics
- Warmup + cosine LR; optional Neptune logging
- Fast cached label indexing for balanced sampler
"""

from __future__ import annotations
import os, math, random, json, hashlib, csv
from dataclasses import dataclass
from typing import Optional, Tuple, List, Iterator

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset, BatchSampler

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor, EarlyStopping

# ----------------- Utilities -----------------
def collate_label_only_xy(batch):  # picklable; used by index scan
    """Collate a list of ``(x, y)`` pairs into a plain list of ``int`` labels.

    Lives at module level so DataLoader worker processes can pickle it; the
    inputs ``x`` are dropped.
    """
    labels = []
    for _sample, label in batch:
        labels.append(int(label))
    return labels

def _cpu(t):
    """Return ``t`` detached from autograd, cast to float32 and moved to the CPU."""
    detached = t.detach()
    return detached.float().cpu()

# ---------------- Neptune (optional) ----------------
def make_neptune_logger(run_name: str | None = None):
    api_key, project = os.getenv("NEPTUNE_API_TOKEN"), os.getenv("NEPTUNE_PROJECT")
    if not api_key or not project:
        print("Neptune: env vars not found -> skipping Neptune logging.")
        return None
    try:
        from lightning.pytorch.loggers import NeptuneLogger as _BaseNeptuneLogger
    except Exception:
        from pytorch_lightning.loggers import NeptuneLogger as _BaseNeptuneLogger

    class CleanNeptuneLogger(_BaseNeptuneLogger):
        def log_metrics(self, metrics, step: int | None = None):
            filt = {k: v for k, v in metrics.items() if k != "epoch" and not k.endswith("/epoch")}
            super().log_metrics(filt, step=None)

    logger = CleanNeptuneLogger(
        api_key=api_key, project=project, name=run_name,
        tags=["libribrain", "watson", "meg", "keyword-detection"],
        prefix="training/", log_model_checkpoints=False,
    )
    print("Neptune: ✅ enabled.")
    return logger

# ---------------- Model ----------------
class ResNetBlock1D(nn.Module):
    """Residual 1-D conv block: ``x + Conv1x1(ELU(Conv3(ELU(x))))``.

    Both convolutions keep the channel count, and the kernel-3 conv uses
    stride 1 with padding 1, so the output length equals the input length and
    the residual add is shape-safe.

    Args:
        channels: number of input (and output) channels.
    """

    def __init__(self, channels: int = 128):
        super().__init__()
        # Explicit padding 1 ("same" length for k=3, stride 1). The former probe
        # `'same' in nn.Conv1d.__init__.__code__.co_varnames` searched parameter
        # *names*, never matched, and always fell back to 1 anyway.
        # Layer order is unchanged so existing state_dict keys (net.1, net.3) still load.
        self.net = nn.Sequential(
            nn.ELU(), nn.Conv1d(channels, channels, 3, 1, 1),
            nn.ELU(), nn.Conv1d(channels, channels, 1, 1, 0),
        )

    def forward(self, x):
        return x + self.net(x)

class SpeechDetectionNet(nn.Module):
    """Conv trunk + temporal attention pooling.

    Maps ``x`` of shape ``(N, in_channels, T)`` to one logit per example,
    shape ``(N,)``. Per-time-step logits are combined with softmax attention
    weights over the time axis.

    Time-axis bookkeeping: the kernel-50 / stride-25 conv yields
    ``T1 = (T - 50) // 25 + 1`` steps, and the unpadded kernel-4 head yields
    ``T' = T1 - 3``. ``T`` must therefore be at least 125 samples.

    Args:
        in_channels: number of input channels.
        lse_temperature: accepted for API compatibility only; unused.
    """

    def __init__(self, in_channels: int = 306, lse_temperature: float = 0.5):  # lse_temperature kept for API compat
        super().__init__()
        # Explicit padding 3 keeps the length for the odd, stride-1 kernel-7 convs.
        # The former `'same' in nn.Conv1d.__init__.__code__.co_varnames` probe looked
        # for a *parameter name* 'same', never matched, and always used 3.
        self.trunk = nn.Sequential(
            nn.Conv1d(in_channels, 128, 7, 1, 3),
            ResNetBlock1D(128),
            nn.ELU(),
            nn.Conv1d(128, 128, 50, 25, 0),  # downsample time
            nn.ELU(),
            nn.Conv1d(128, 128, 7, 1, 3),
            nn.ELU(),
        )
        self.head = nn.Sequential(nn.Conv1d(128, 512, 4, 1, 0), nn.ReLU(), nn.Dropout(0.5))
        self.logits_t = nn.Conv1d(512, 1, 1, 1, 0)  # per-time-step logit
        self.attn_t = nn.Conv1d(512, 1, 1, 1, 0)    # per-time-step attention score

    def forward(self, x):
        h = self.head(self.trunk(x))                    # (N, 512, T')
        logit_t = self.logits_t(h)                      # (N, 1, T')
        attn = torch.softmax(self.attn_t(h), dim=-1)    # weights sum to 1 over time
        return (logit_t * attn).sum(dim=-1).squeeze(1)  # (N,)

# ---------------- Losses + metrics helpers ----------------
@dataclass
class OptimConfig:
    """Optimizer, LR-schedule and train-time augmentation settings for WatsonKeywordPL."""
    lr: float = 1e-4                  # AdamW learning rate
    weight_decay: float = 1e-4        # AdamW weight decay
    max_time_shift: int = 4           # max |circular time shift| (samples) per training example; 0 disables
    noise_std: float = 0.01           # std of additive Gaussian noise during training; 0 disables
    warmup_epochs: int = 1            # warmup length, in epochs, when cosine_after_warmup is True
    cosine_after_warmup: bool = True  # True: warmup + cosine LambdaLR; False: ReduceLROnPlateau on val_auprc

class FocalLoss(nn.Module):
    """Binary focal loss on raw logits, averaged over all elements.

    Per element: ``alpha_t * (1 - p_t) ** gamma * BCE(logit, target)``, where
    ``p_t`` is the predicted probability of the true class. ``alpha_t`` is
    ``alpha`` where ``target == 1`` and ``1 - alpha`` everywhere else.

    Args:
        alpha: weight for positive targets (negatives get ``1 - alpha``).
        gamma: focusing exponent; 0 reduces this to alpha-weighted BCE.
    """

    def __init__(self, alpha: float = 0.95, gamma: float = 2.0):
        super().__init__()
        self.alpha = float(alpha)
        self.gamma = float(gamma)

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        is_pos = targets == 1
        bce = nn.functional.binary_cross_entropy_with_logits(
            logits, targets.float(), reduction='none'
        )
        prob = torch.sigmoid(logits)
        p_true = torch.where(is_pos, prob, 1 - prob)
        pos_w = logits.new_tensor(self.alpha)
        neg_w = logits.new_tensor(1 - self.alpha)
        weight = torch.where(is_pos, pos_w, neg_w)
        modulator = (1 - p_true).pow(self.gamma)
        return (weight * modulator * bce).mean()

try:
    from torchmetrics.classification import BinaryAccuracy, BinaryAveragePrecision, BinaryAUROC
except Exception:
    from torchmetrics import Accuracy as BinaryAccuracy                 # type: ignore
    from torchmetrics import AveragePrecision as BinaryAveragePrecision # type: ignore
    from torchmetrics import AUROC as BinaryAUROC                       # type: ignore

# ---------------- LightningModule ----------------
class WatsonKeywordPL(pl.LightningModule):
    """Lightning wrapper for binary "watson" keyword detection on MEG windows.

    Training minimises ``focal + pairwise_lambda * pairwise``. The first term is
    a focal loss (alpha=0.95, gamma=2). The second is a pairwise logistic
    ranking loss over (positive, negative) logit pairs drawn from the same
    batch. Class balance comes from the batch sampler, not from loss weighting.

    Validation/test keep per-sample probabilities on CPU and report
    ranking-oriented diagnostics at epoch end:
    - AUPRC and AUROC;
    - R-precision;
    - precision/recall at k = M, 2M, 5M (M = number of positives);
    - the best-F1 threshold and its confusion counts;
    - macro-F1 at 0.5;
    - recall at precision >= 0.90.

    Args:
        in_channels: input channels passed to ``SpeechDetectionNet``.
        pos_weight: stored in the ``pos_weight_tensor`` buffer for API /
            checkpoint compatibility; not used by any loss.
        opt: optimizer, LR-schedule and augmentation settings.
        lse_temperature: forwarded to ``SpeechDetectionNet``.
        pairwise_lambda: weight of the pairwise ranking term.
    """

    def __init__(self, in_channels: int = 306, pos_weight: float = 1.0,
                 opt: OptimConfig = OptimConfig(), lse_temperature: float = 0.5,
                 pairwise_lambda: float = 0.5):
        # NOTE(review): the `opt` default is a single OptimConfig instance shared by
        # all calls; it is only read here, so never mutate `self.hparams.opt` in place.
        super().__init__()
        self.save_hyperparameters()
        self.model = SpeechDetectionNet(in_channels, lse_temperature=lse_temperature)
        # Not used by the loss (training is oversampling-only); kept for API/checkpoint compat.
        self.register_buffer("pos_weight_tensor", torch.tensor([pos_weight], dtype=torch.float32))
        self.criterion = FocalLoss(alpha=0.95, gamma=2.0)
        self.pairwise_lambda = float(pairwise_lambda)

        # Epoch-level torchmetrics aggregates (ranking-friendly).
        self.train_acc = BinaryAccuracy()
        self.val_acc = BinaryAccuracy()
        self.test_acc = BinaryAccuracy()
        self.val_auprc = BinaryAveragePrecision()
        self.test_auprc = BinaryAveragePrecision()
        self.val_auroc = BinaryAUROC()
        self.test_auroc = BinaryAUROC()

        # Per-sample CPU buffers used by the epoch-end diagnostics.
        self._val_probs: List[torch.Tensor] = []
        self._val_labels: List[torch.Tensor] = []
        self._test_probs: List[torch.Tensor] = []
        self._test_labels: List[torch.Tensor] = []
        self._train_pos = 0
        self._train_total = 0
        self._val_pos = 0
        self._val_total = 0

    # ---- small helpers ----
    def _as_metric(self, value) -> torch.Tensor:
        """Wrap a Python scalar as a float32 tensor on this module's device for ``self.log``.

        Integer values (e.g. confusion counts) are converted to float here, because
        Lightning warns about and converts non-float logged tensors.
        """
        return torch.as_tensor(float(value), dtype=torch.float32, device=self.device)

    @staticmethod
    def _pairwise_logistic_loss(logits: torch.Tensor, labels: torch.Tensor, max_pairs: int = 4096) -> torch.Tensor:
        """Mean ``softplus(-(s_pos - s_neg))`` over randomly sampled (positive, negative) pairs.

        Draws ``min(max_pairs, P * N)`` index pairs uniformly with replacement.
        Returns a scalar zero when the batch lacks either class.
        """
        pos_idx = (labels == 1).nonzero(as_tuple=False).view(-1)
        neg_idx = (labels == 0).nonzero(as_tuple=False).view(-1)
        if pos_idx.numel() == 0 or neg_idx.numel() == 0:
            return logits.new_zeros(())
        num_pairs = min(max_pairs, int(pos_idx.numel()) * int(neg_idx.numel()))
        pi = pos_idx[torch.randint(0, pos_idx.numel(), (num_pairs,), device=logits.device)]
        ni = neg_idx[torch.randint(0, neg_idx.numel(), (num_pairs,), device=logits.device)]
        return torch.nn.functional.softplus(-(logits[pi] - logits[ni])).mean()

    @staticmethod
    def _rprecision(probs: torch.Tensor, labels: torch.Tensor) -> float:
        """Precision among the top-M scores, M = number of positives (0.0 when M == 0)."""
        m = int(labels.sum().item())
        if m <= 0:
            return 0.0
        k = min(m, probs.numel())
        top = torch.topk(probs, k=k, largest=True).indices
        return float(labels[top].float().mean().item())

    @staticmethod
    def _precision_recall_at_k(probs: torch.Tensor, labels: torch.Tensor, k: int) -> Tuple[float, float]:
        """Precision and recall of the top-``k`` scores (k clamped to [1, len]); (0, 0) if empty."""
        if probs.numel() == 0:
            # topk(k=1) on an empty tensor would raise.
            return 0.0, 0.0
        k = max(1, min(k, probs.numel()))
        topk = torch.topk(probs, k=k, largest=True).indices
        tp = labels[topk].sum().item()
        prec = tp / k
        rec = tp / max(1, int(labels.sum().item()))
        return float(prec), float(rec)

    @staticmethod
    def _best_f1(probs: torch.Tensor, labels: torch.Tensor) -> Tuple[float, float, Tuple[int, int, int, int]]:
        """Best F1 over "top-k are positive" cut-offs.

        Returns ``(best_f1, threshold, (tp, fp, tn, fn))``. ``threshold`` is the
        score of the last included item. When scores are tied, thresholding with
        ``>=`` may include more than k items.
        """
        N = probs.numel()
        if N == 0:
            return 0.0, 0.5, (0, 0, 0, 0)
        sort_idx = torch.argsort(probs, descending=True)
        y = labels[sort_idx].to(torch.int32)
        cum_tp = torch.cumsum(y, dim=0)
        ks = torch.arange(1, N + 1, device=probs.device)
        n_pos = int(labels.sum().item())
        precision = cum_tp / ks
        recall = cum_tp / max(1, n_pos)  # clamp only to avoid division by zero
        denom = precision + recall
        f1 = torch.where(denom > 0, 2 * precision * recall / denom, torch.zeros_like(denom))
        i = int(torch.argmax(f1).item())
        best_f1 = float(f1[i].item())
        thr = float(probs[sort_idx[i]].item())
        k = i + 1
        tp = int(cum_tp[i].item())
        fp = k - tp
        # Use the true positive count (not the clamped one) so FN/TN stay correct
        # when the split contains no positives.
        fn = n_pos - tp
        tn = N - k - fn
        return best_f1, thr, (tp, fp, tn, fn)

    @staticmethod
    def _f1_macro_at_threshold(probs: torch.Tensor, labels: torch.Tensor, threshold: float = 0.5) -> float:
        """Mean of positive-class and negative-class F1 when predicting ``probs >= threshold``."""
        if probs.numel() == 0:
            return 0.0
        preds, lab = (probs >= threshold).to(torch.int32), labels.to(torch.int32)
        tp = int(((preds == 1) & (lab == 1)).sum().item())
        fp = int(((preds == 1) & (lab == 0)).sum().item())
        fn = int(((preds == 0) & (lab == 1)).sum().item())
        tn = int(((preds == 0) & (lab == 0)).sum().item())

        def _f1(p, r):
            return 0.0 if (p + r) == 0 else (2 * p * r) / (p + r)

        prec_pos = tp / max(1, tp + fp); rec_pos = tp / max(1, tp + fn)
        prec_neg = tn / max(1, tn + fn); rec_neg = tn / max(1, tn + fp)
        return float((_f1(prec_pos, rec_pos) + _f1(prec_neg, rec_neg)) / 2.0)

    @staticmethod
    def _recall_at_precision(probs: torch.Tensor, labels: torch.Tensor, min_precision: float = 0.9) -> float:
        """Highest recall over top-k cut-offs whose precision is >= ``min_precision`` (else 0.0)."""
        N = probs.numel()
        if N == 0:
            return 0.0
        sort_idx = torch.argsort(probs, descending=True)
        y = labels[sort_idx].to(torch.int32)
        cum_tp = torch.cumsum(y, dim=0)
        ks = torch.arange(1, N + 1, device=probs.device)
        precision = cum_tp / ks
        total_pos = max(1, int(labels.sum().item()))
        recall = cum_tp / total_pos
        mask = precision >= min_precision
        return float(recall[mask].max().item()) if mask.any() else 0.0

    # ---- Lightning required ----
    def forward(self, x):
        return self.model(x)

    def _augment(self, x):
        """Train-time augmentation: per-sample circular time shift plus Gaussian noise.

        A no-op in eval mode. Returns a new tensor; the incoming batch is not
        modified in place.
        """
        if not self.training:
            return x
        smax = self.hparams.opt.max_time_shift
        if smax and smax > 0:
            n, t = x.size(0), x.size(-1)
            shifts = torch.randint(-smax, smax + 1, (n,), device=x.device)
            # Vectorised equivalent of torch.roll(x[i], shifts[i], dims=-1) for every i:
            # out[..., j] = x[..., (j - shift) mod T].
            src = torch.remainder(
                torch.arange(t, device=x.device).unsqueeze(0) - shifts.unsqueeze(1), t
            )  # (N, T)
            src = src.view(n, *([1] * (x.dim() - 2)), t).expand_as(x)
            x = torch.gather(x, -1, src)
        sigma = self.hparams.opt.noise_std
        return x + torch.randn_like(x) * sigma if (sigma and sigma > 0) else x

    def _bce_unweighted(self, logits, y):
        """Plain mean BCE-with-logits, used as the reported val/test loss."""
        return nn.functional.binary_cross_entropy_with_logits(logits.float(), y.float())

    def training_step(self, batch, _):
        x, y = batch
        self._train_pos += int(y.sum()); self._train_total += int(y.numel())
        logits = self(self._augment(x)).float()
        focal = self.criterion(logits, y.float())
        # Not detached: the ranking term must contribute gradients. A detached copy
        # only added a constant to the loss and never trained anything.
        pairwise = self._pairwise_logistic_loss(logits, y)
        loss = focal + self.pairwise_lambda * pairwise
        probs = torch.sigmoid(logits.detach())
        self.train_acc.update(probs, y)
        self.log_dict({
            "train_loss": loss, "train_focal": focal, "train_pairwise": pairwise,
            "train_pos_frac": y.float().mean(),
        }, on_step=True, on_epoch=True, prog_bar=False)
        return loss

    def on_train_epoch_end(self):
        self.log("train_acc", self.train_acc.compute(), on_step=False, on_epoch=True, prog_bar=True)
        if self._train_total > 0:
            self.log("train_pos_fraction_epoch", float(self._train_pos) / float(self._train_total),
                     on_step=False, on_epoch=True)
        self._train_pos = 0; self._train_total = 0
        self.train_acc.reset()

    def on_validation_epoch_start(self):
        self._val_pos = 0; self._val_total = 0
        self._val_probs.clear(); self._val_labels.clear()

    def validation_step(self, batch, _):
        x, y = batch
        logits = self(x)
        probs = torch.sigmoid(logits.float())
        self.val_acc.update(probs, y); self.val_auprc.update(probs, y); self.val_auroc.update(probs, y)
        self._val_pos += int(y.sum()); self._val_total += int(y.numel())
        self._val_probs.append(_cpu(probs)); self._val_labels.append(_cpu(y).int())
        self.log("val_loss", self._bce_unweighted(logits, y), on_step=False, on_epoch=True)

    def on_validation_epoch_end(self):
        metric = self._as_metric
        base_rate = self._val_pos / self._val_total if self._val_total > 0 else 0.0
        both_classes = 0 < self._val_pos < self._val_total
        val_acc = self.val_acc.compute()
        # AUPRC needs at least one positive and AUROC needs both classes (which can
        # fail when only a few validation batches run); report chance level instead.
        val_auprc = self.val_auprc.compute() if self._val_pos > 0 else metric(base_rate)
        val_auroc = self.val_auroc.compute() if both_classes else metric(0.5)

        probs = torch.cat(self._val_probs, dim=0) if self._val_probs else torch.empty(0)
        labels = torch.cat(self._val_labels, dim=0) if self._val_labels else torch.empty(0, dtype=torch.int64)
        if probs.numel() != labels.numel():
            probs = probs[:labels.numel()]

        rprec = self._rprecision(probs, labels)
        m = int(labels.sum().item()) if labels.numel() > 0 else 1
        prec_m, rec_m = self._precision_recall_at_k(probs, labels, max(1, m))
        prec_2m, rec_2m = self._precision_recall_at_k(probs, labels, max(1, 2 * m))
        prec_5m, rec_5m = self._precision_recall_at_k(probs, labels, max(1, 5 * m))
        best_f1, best_thr, (tp, fp, tn, fn) = self._best_f1(probs, labels)
        rec_at_p90 = self._recall_at_precision(probs, labels, 0.90)
        f1_macro_05 = self._f1_macro_at_threshold(probs, labels, 0.5)

        print(
            f"[VAL] base_rate={base_rate:.6f}  AUPRC={float(val_auprc):.4f}  AUROC={float(val_auroc):.4f}  "
            f"RPrec={rprec:.4f}  BestF1={best_f1:.4f} @thr={best_thr:.4f}  "
            f"F1-macro@0.5={f1_macro_05:.4f}  Rec@P>=0.90={rec_at_p90:.4f}  "
            f"Conf(TP/FP/TN/FN)={tp}/{fp}/{tn}/{fn}"
        )

        self.log_dict({
            "val_acc": val_acc, "val_auprc": val_auprc, "val_auroc": val_auroc,
            "val_pos_rate": metric(base_rate),
            "val_random_auprc": metric(base_rate),
            "val_rprecision": metric(rprec),
            "val_precision_at_M": metric(prec_m),
            "val_recall_at_M": metric(rec_m),
            "val_precision_at_2M": metric(prec_2m),
            "val_recall_at_2M": metric(rec_2m),
            "val_precision_at_5M": metric(prec_5m),
            "val_recall_at_5M": metric(rec_5m),
            "val_best_f1": metric(best_f1),
            "val_best_f1_threshold": metric(best_thr),
            "val_macro_f1@0.5": metric(f1_macro_05),
            "val_recall_at_precision_0.90": metric(rec_at_p90),
            "val_tp_bestf1": metric(tp),
            "val_fp_bestf1": metric(fp),
            "val_tn_bestf1": metric(tn),
            "val_fn_bestf1": metric(fn),
        }, on_step=False, on_epoch=True, prog_bar=True)

        self.val_acc.reset(); self.val_auprc.reset(); self.val_auroc.reset()
        self._val_probs.clear(); self._val_labels.clear()

    def on_test_epoch_start(self):
        self._test_probs.clear(); self._test_labels.clear()

    def test_step(self, batch, _):
        x, y = batch
        logits = self(x)
        probs = torch.sigmoid(logits.float())
        self.test_acc.update(probs, y); self.test_auprc.update(probs, y); self.test_auroc.update(probs, y)
        self._test_probs.append(_cpu(probs)); self._test_labels.append(_cpu(y).int())
        self.log("test_loss", self._bce_unweighted(logits, y), on_step=False, on_epoch=True)

    def on_test_epoch_end(self):
        metric = self._as_metric
        probs = torch.cat(self._test_probs, dim=0) if self._test_probs else torch.empty(0)
        labels = torch.cat(self._test_labels, dim=0) if self._test_labels else torch.empty(0, dtype=torch.int64)
        base_rate = float(labels.float().mean().item()) if labels.numel() > 0 else 0.0
        n_pos = labels.sum().item()
        both_classes = 0 < n_pos < labels.numel()
        try:
            test_auprc = self.test_auprc.compute()
        except Exception:
            test_auprc = metric(base_rate)
        try:
            test_auroc = self.test_auroc.compute() if both_classes else metric(0.5)
        except Exception:
            test_auroc = metric(0.5)

        rprec = self._rprecision(probs, labels)
        m = int(n_pos) if labels.numel() > 0 else 1
        prec_m, rec_m = self._precision_recall_at_k(probs, labels, max(1, m))
        best_f1, best_thr, (tp, fp, tn, fn) = self._best_f1(probs, labels)
        rec_at_p90 = self._recall_at_precision(probs, labels, 0.90)
        f1_macro_05 = self._f1_macro_at_threshold(probs, labels, 0.5)

        print(
            f"[TEST] base_rate={base_rate:.6f}  AUPRC={float(test_auprc):.4f}  AUROC={float(test_auroc):.4f}  "
            f"RPrec={rprec:.4f}  BestF1={best_f1:.4f} @thr={best_thr:.4f}  "
            f"F1-macro@0.5={f1_macro_05:.4f}  Rec@P>=0.90={rec_at_p90:.4f}  "
            f"Conf(TP/FP/TN/FN)={tp}/{fp}/{tn}/{fn}"
        )

        self.log_dict({
            "test_acc": self.test_acc.compute(), "test_auprc": test_auprc, "test_auroc": test_auroc,
            "test_rprecision": metric(rprec),
            "test_precision_at_M": metric(prec_m),
            "test_recall_at_M": metric(rec_m),
            "test_best_f1": metric(best_f1),
            "test_best_f1_threshold": metric(best_thr),
            "test_f1_macro@0.5": metric(f1_macro_05),
            "test_recall_at_precision_0.90": metric(rec_at_p90),
        }, on_step=False, on_epoch=True, prog_bar=True)

        # Save predictions as CSV (index, label, probability); best-effort only.
        try:
            os.makedirs("playground", exist_ok=True)
            out_path = os.path.join("playground", "test_predictions.csv")
            with open(out_path, "w", newline="", encoding="utf-8") as f:
                w = csv.writer(f)
                w.writerow(["index", "label", "probability"])
                w.writerows([i, int(y), float(p)]
                            for i, (p, y) in enumerate(zip(probs.tolist(), labels.tolist())))
            print(f"Saved test predictions to {out_path}")
        except Exception as e:
            print(f"Failed to save test predictions CSV: {e}")

        self.test_acc.reset(); self.test_auprc.reset(); self.test_auroc.reset()
        self._test_probs.clear(); self._test_labels.clear()

    def _estimated_total_steps(self) -> Optional[int]:
        """Total optimizer steps Lightning plans for this fit, or None if unknown or infinite."""
        try:
            est = float(self.trainer.estimated_stepping_batches)
        except Exception:
            return None
        return int(est) if math.isfinite(est) and est >= 1 else None

    def configure_optimizers(self):
        """AdamW plus either linear-warmup + cosine decay, or ReduceLROnPlateau on val_auprc."""
        cfg = self.hparams.opt
        opt = torch.optim.AdamW(self.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
        if not cfg.cosine_after_warmup:
            sch = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, mode="max", factor=0.5, patience=2)
            return {"optimizer": opt, "lr_scheduler": {"scheduler": sch, "monitor": "val_auprc"}}

        total_epochs = getattr(self.trainer, "max_epochs", 30) or 30
        warm_epochs = max(0, int(cfg.warmup_epochs))
        total_steps = self._estimated_total_steps()

        if total_steps is None:
            # Step count unknown: fall back to an epoch-granular schedule.
            def lr_lambda_epoch(epoch):
                if epoch < warm_epochs:
                    return (epoch + 1) / max(1, warm_epochs)
                t = (epoch - warm_epochs) / max(1, total_epochs - warm_epochs)
                return 0.5 * (1 + math.cos(math.pi * t))
            return {"optimizer": opt,
                    "lr_scheduler": {"scheduler": torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda_epoch)}}

        # Step the schedule per optimizer step. With per-epoch stepping and
        # warmup_epochs=1, lambda(0) was already 1.0, i.e. no warmup at all.
        warm_steps = min(total_steps, int(round(total_steps * warm_epochs / max(1, total_epochs))))

        def lr_lambda_step(step):
            if step < warm_steps:
                return (step + 1) / warm_steps
            t = (step - warm_steps) / max(1, total_steps - warm_steps)
            return 0.5 * (1 + math.cos(math.pi * min(1.0, t)))

        return {"optimizer": opt,
                "lr_scheduler": {"scheduler": torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda_step),
                                 "interval": "step"}}

# ---------------- Data ----------------
from pnpl.datasets.libribrain2025.word_dataset import LibriBrainWord
from pnpl.datasets.libribrain2025.constants import RUN_KEYS
try:
    from pnpl.datasets.libribrain2025.base import LibriBrainBase
except Exception:
    LibriBrainBase = None

class BalancedBatchSampler(BatchSampler):
    """Yield index batches with a fixed positive/negative mix by oversampling positives.

    Each batch holds ``n_pos = max(1, round(batch_size * pos_fraction))``
    positive indices and ``batch_size - n_pos`` negative indices, shuffled
    together.

    Each pool is walked in a random order and reshuffled when exhausted, so the
    (rare) positives repeat many times per epoch. If a pool is smaller than its
    per-batch quota, that quota is drawn with replacement instead.

    One epoch is exactly ``len(self)`` batches, i.e.
    ``max(1, (len(pos_idx) + len(neg_idx)) // batch_size)``. The iterator
    stops there, so plain ``for batch in loader`` loops terminate.

    Raises:
        ValueError: if ``pos_fraction`` is not in (0, 1), ``pos_idx`` is empty,
            ``batch_size`` < 1, or negatives are required but ``neg_idx`` is empty.
    """

    def __init__(self, pos_idx: List[int], neg_idx: List[int], batch_size: int, pos_fraction: float = 0.1):
        if not 0.0 < pos_fraction < 1.0:
            raise ValueError(f"pos_fraction must be in (0, 1), got {pos_fraction}")
        if len(pos_idx) == 0:
            raise ValueError("pos_idx must contain at least one index")
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        self.p_idx, self.n_idx = pos_idx, neg_idx
        self.batch_size = batch_size
        self.n_pos = max(1, int(round(batch_size * pos_fraction)))
        self.n_neg = batch_size - self.n_pos
        if self.n_neg > 0 and len(neg_idx) == 0:
            raise ValueError("neg_idx must be non-empty when batches need negatives")
        total = len(pos_idx) + len(neg_idx)
        self._epoch_len = max(1, total // batch_size)

    @staticmethod
    def _take(pool: List[int], cursor: int, count: int) -> Tuple[List[int], int]:
        """Take ``count`` indices from ``pool`` at ``cursor``; reshuffle ``pool`` in place when exhausted."""
        if count <= 0:
            return [], cursor
        if len(pool) < count:
            # Pool too small to fill its quota without repeats: sample with replacement.
            return random.choices(pool, k=count), cursor
        if cursor + count > len(pool):
            random.shuffle(pool)
            cursor = 0
        return pool[cursor:cursor + count], cursor + count

    def __iter__(self) -> Iterator[List[int]]:
        pos_pool, neg_pool = self.p_idx[:], self.n_idx[:]
        random.shuffle(pos_pool)
        random.shuffle(neg_pool)
        pi = ni = 0
        for _ in range(self._epoch_len):
            pos_part, pi = self._take(pos_pool, pi, self.n_pos)
            neg_part, ni = self._take(neg_pool, ni, self.n_neg)
            batch = pos_part + neg_part
            random.shuffle(batch)
            yield batch

    def __len__(self) -> int:
        return self._epoch_len

class LibriBrainWordDataModule(pl.LightningDataModule):
    """DataModule for "watson" keyword detection on LibriBrain word windows.

    Responsibilities:
    - Build train/validation/test ``LibriBrainWord`` datasets. Val/test are
      standardized with the train dataset's ``channel_means``/``channel_stds``
      when it exposes them.
    - Cache the train positive/negative index lists under
      ``<data_path>/_indices``.
    - Serve training batches through ``BalancedBatchSampler``; val/test are
      unshuffled.

    ``setup()`` is idempotent: the first call builds everything and later calls
    (e.g. by the Trainer after a manual ``dm.setup()``) are no-ops.
    """

    # Label-defining dataset settings. They are also part of the index-cache key,
    # so changing them cannot silently reuse stale positive/negative indices.
    KEYWORD = "watson"
    POSITIVE_BUFFER = 0.25
    _SCAN_REPORT_EVERY = 50_000

    def __init__(self, data_path: str, tmin: float=-0.1, tmax: float=0.8, batch_size: int=256,
                 num_workers: int=4, pin_memory: bool=True, standardize_train: bool=True,
                 target_pos_fraction: float = 0.10,
                 val_run_override: Optional[Tuple[str,str,str,str]] = None,
                 test_run_override: Optional[Tuple[str,str,str,str]] = None):
        super().__init__()
        self.data_path, self.tmin, self.tmax = data_path, tmin, tmax
        self.batch_size, self.num_workers, self.pin_memory = batch_size, num_workers, pin_memory
        self.standardize_train = standardize_train
        self.target_pos_fraction = target_pos_fraction
        self.val_run_override = val_run_override; self.test_run_override = test_run_override
        self._train_sampler: Optional[BalancedBatchSampler] = None

    def _available_runs(self):
        """Return Sherlock run keys whose events TSV and serialised .h5 can be obtained.

        Candidates are the ``RUN_KEYS`` entries whose task starts with
        ``'Sherlock'``. If ``LibriBrainBase`` is importable, both files of each
        candidate are passed to ``LibriBrainBase.ensure_file_download``
        (presumably downloading missing files; confirm). Runs for which that
        raises are dropped, best-effort. If none succeed, all candidates are
        returned.
        """
        candidates = [rk for rk in RUN_KEYS if rk[2].startswith('Sherlock')]
        if LibriBrainBase is None:
            return candidates

        def _events_path(su, se, ta, ru):
            return os.path.join(self.data_path, ta, "derivatives", "events",
                                f"sub-{su}_ses-{se}_task-{ta}_run-{ru}_events.tsv")

        def _h5_path(su, se, ta, ru):
            return os.path.join(self.data_path, ta, "derivatives", "serialised",
                                f"sub-{su}_ses-{se}_task-{ta}_run-{ru}_proc-bads+headpos+sss+notch+bp+ds_meg.h5")

        avail = []
        for s, se, t, r in candidates:
            try:
                LibriBrainBase.ensure_file_download(_events_path(s, se, t, r), data_path=self.data_path)
                LibriBrainBase.ensure_file_download(_h5_path(s, se, t, r), data_path=self.data_path)
                avail.append((s, se, t, r))
            except Exception:
                pass  # best-effort: skip runs that cannot be fetched
        return avail or candidates

    def _hash_key(self, train_runs: List[Tuple[str,str,str,str]]) -> str:
        """16-hex-char key for the index cache, derived from window, runs and label settings."""
        payload = {
            "tmin": self.tmin, "tmax": self.tmax,
            "runs": sorted("_".join(x) for x in train_runs),
            # Labels depend on these too; without them a changed buffer would reuse stale indices.
            "keyword": self.KEYWORD, "positive_buffer": self.POSITIVE_BUFFER,
        }
        return hashlib.sha256(json.dumps(payload, sort_keys=True).encode("utf-8")).hexdigest()[:16]

    def _cache_paths(self, key: str):  # single file for pos/neg index cache
        """Return ``<data_path>/_indices/watson_<key>.pt``, creating the directory if needed."""
        cache_dir = os.path.join(self.data_path, "_indices")
        os.makedirs(cache_dir, exist_ok=True)
        return os.path.join(cache_dir, f"watson_{key}.pt")

    @staticmethod
    def _dataset_len(ds) -> Optional[int]:
        """``len(ds)`` or ``None`` when the dataset does not define a length."""
        try:
            return len(ds)
        except TypeError:
            return None

    def _try_dataset_labels_fast(self, ds: Dataset) -> Optional[List[int]]:
        """Read labels from a common label attribute on ``ds`` without loading samples.

        Tries, in order: ``labels``, ``y``, ``targets``, ``_labels``, ``_y``,
        ``_targets``. An attribute is accepted only if it converts to ints and
        (when ``len(ds)`` is known) has exactly one entry per sample. Returns
        ``None`` if no attribute qualifies.
        """
        expected = self._dataset_len(ds)
        for attr in ("labels", "y", "targets", "_labels", "_y", "_targets"):
            if not hasattr(ds, attr):
                continue
            raw = getattr(ds, attr)
            try:
                labels = raw.view(-1).cpu().int().tolist() if torch.is_tensor(raw) else [int(v) for v in raw]
            except Exception:
                continue
            # Guard against same-named attributes that are not per-sample labels.
            if expected is not None and len(labels) != expected:
                continue
            return labels
        return None

    def _build_pos_neg_indices(self, ds: Dataset, cache_file: str) -> Tuple[List[int], List[int]]:
        """Return ``(positive indices, negative indices)`` for ``ds``.

        Resolution order:
        1. the cache file, if its total entry count matches ``len(ds)``;
        2. a label attribute on the dataset (``_try_dataset_labels_fast``);
        3. a full pass over ``ds`` with a label-only collate function, falling
           back to a single process if the multi-worker pass fails.
        Results of (2) and (3) are written to ``cache_file``.
        """
        expected = self._dataset_len(ds)
        if os.path.exists(cache_file):
            obj = torch.load(cache_file, map_location="cpu")
            pos_idx = list(map(int, obj["pos_idx"])); neg_idx = list(map(int, obj["neg_idx"]))
            cached_total = len(pos_idx) + len(neg_idx)
            if expected is None or cached_total == expected:
                print(f"Index cache: loaded {len(pos_idx)} positives / {cached_total} total.")
                return pos_idx, neg_idx
            print(f"Index cache: {cached_total} cached entries vs {expected} samples -> rebuilding.")

        lbls = self._try_dataset_labels_fast(ds)
        if lbls is not None:
            pos_idx = [i for i, y in enumerate(lbls) if int(y) == 1]
            neg_idx = [i for i, y in enumerate(lbls) if int(y) == 0]
            print(f"Index fast-path: found {len(pos_idx)} positives / {len(lbls)} total.")
            torch.save({"pos_idx": pos_idx, "neg_idx": neg_idx}, cache_file)
            return pos_idx, neg_idx

        def _scan(num_workers: int) -> Tuple[List[int], List[int]]:
            print(f"Scanning training labels to build balanced sampler (num_workers={num_workers})…")
            pos_idx, neg_idx, idx = [], [], 0
            # One-shot pass: no persistent workers to keep alive afterwards.
            loader = DataLoader(ds, batch_size=2048, shuffle=False,
                                num_workers=num_workers, pin_memory=False,
                                persistent_workers=False,
                                prefetch_factor=2 if num_workers > 0 else None,
                                collate_fn=collate_label_only_xy)
            next_report = self._SCAN_REPORT_EVERY
            for ys in loader:
                for y in ys:
                    (pos_idx if y == 1 else neg_idx).append(idx)
                    idx += 1
                # Threshold-based progress: `idx % 50000 == 0` almost never held
                # because idx advances in steps of 2048.
                if idx >= next_report:
                    print(f"… scanned {idx} samples")
                    while next_report <= idx:
                        next_report += self._SCAN_REPORT_EVERY
            return pos_idx, neg_idx

        try:
            pos_idx, neg_idx = _scan(self.num_workers)
        except Exception as e:
            print(f"[Label scan] parallel scan failed ({type(e).__name__}: {e}). Falling back to single-process.")
            pos_idx, neg_idx = _scan(0)

        total = len(pos_idx) + len(neg_idx); frac = len(pos_idx) / max(1, total)
        print(f"Found {len(pos_idx)} positives / {total} total ({frac:.6f}).")
        torch.save({"pos_idx": pos_idx, "neg_idx": neg_idx}, cache_file)
        return pos_idx, neg_idx

    def setup(self, stage: Optional[str]=None):
        """Build datasets, label indices and the balanced train sampler (first call only).

        Raises:
            RuntimeError: if the training split contains no positive samples.
        """
        if self._train_sampler is not None:
            return  # already set up (e.g. main() calls setup() before trainer.fit() does)

        all_runs = list(self._available_runs())
        self.val_run  = self.val_run_override  or ('0','12','Sherlock4','1')
        self.test_run = self.test_run_override or ('0','12','Sherlock5','1')
        # NOTE(review): train_runs (and the val/test runs) only feed the cache key; the
        # datasets below select data via partition names instead. Confirm these agree.
        train_runs = [rk for rk in all_runs if rk not in (self.val_run, self.test_run)]

        # NOTE(review): self.tmin / self.tmax are not passed to LibriBrainWord, so the
        # dataset's own default window applies; confirm whether it accepts tmin/tmax.
        common = dict(keyword_detection=self.KEYWORD, preload_files=False,
                      include_info=False, positive_buffer=self.POSITIVE_BUFFER)
        self.train_ds = LibriBrainWord(self.data_path, partition="train",
                                       standardize=self.standardize_train, **common)
        # Reuse training channel statistics for val/test standardization when available.
        train_stats = dict(channel_means=getattr(self.train_ds, "channel_means", None),
                           channel_stds=getattr(self.train_ds, "channel_stds", None))
        self.val_ds = LibriBrainWord(self.data_path, partition="validation",
                                     standardize=True, **common, **train_stats)
        self.test_ds = LibriBrainWord(self.data_path, partition="test",
                                      standardize=True, **common, **train_stats)

        key = self._hash_key(train_runs)
        cache_file = self._cache_paths(key)
        pos_idx, neg_idx = self._build_pos_neg_indices(self.train_ds, cache_file)
        if len(pos_idx) == 0:
            raise RuntimeError("No positive samples found in training set; cannot build balanced sampler.")
        self._pos_idx, self._neg_idx = pos_idx, neg_idx
        self._train_sampler = BalancedBatchSampler(pos_idx, neg_idx, batch_size=self.batch_size,
                                                   pos_fraction=self.target_pos_fraction)

    def estimate_label_stats(self, sample: int = 200_000) -> Tuple[int,int]:
        """Return ``(positives, total)`` for the training set.

        After ``setup()``, the counts come from the cached indices. If
        ``sample`` is smaller than the dataset, both numbers are scaled to a
        total of ``sample``, with at least 1 positive.

        Otherwise, labels are read for up to 20,000 random samples via
        ``self.train_ds[i]``. That path requires ``self.train_ds`` to exist and
        consumes the global ``random`` state.
        """
        if hasattr(self, "_pos_idx") and hasattr(self, "_neg_idx"):
            pos, total = len(self._pos_idx), len(self._pos_idx) + len(self._neg_idx)
            if sample < total and total > 0:
                frac = sample / total
                pos = max(1, int(round(pos * frac)))
                total = sample
            return pos, total
        n = len(self.train_ds); k = min(20_000, n)
        idxs = random.sample(range(n), k=k)
        pos = sum(int(self.train_ds[i][1]) for i in idxs)
        return pos, k

    def _loader_kwargs(self) -> dict:
        """Worker / pinning options shared by the train, val and test dataloaders."""
        workers = self.num_workers
        return dict(num_workers=workers, pin_memory=self.pin_memory,
                    persistent_workers=(workers > 0),
                    prefetch_factor=2 if workers > 0 else None)

    def train_dataloader(self):
        if self._train_sampler is None:
            raise RuntimeError("train_dataloader() called before setup(); no balanced sampler available.")
        return DataLoader(self.train_ds, batch_sampler=self._train_sampler, **self._loader_kwargs())

    def val_dataloader(self):
        return DataLoader(self.val_ds, batch_size=self.batch_size, shuffle=False, **self._loader_kwargs())

    def test_dataloader(self):
        return DataLoader(self.test_ds, batch_size=self.batch_size, shuffle=False, **self._loader_kwargs())

# ---------------- Train ----------------
def main():
    """Train the Watson keyword detector on LibriBrain MEG and evaluate it.

    Steps:
      1. Build and set up ``LibriBrainWordDataModule``, whose balanced sampler
         targets ``target_pos_fraction`` positives per batch.
      2. Print the label statistics and config (and log them to Neptune when
         configured).
      3. Train ``WatsonKeywordPL`` with checkpointing and early stopping on
         ``val_auprc``.
      4. Optionally upload the best checkpoint to Neptune.
      5. Run the test loop on the best-``val_auprc`` checkpoint.

    Uses ``AUTO_WORKERS`` and ``choose_precision`` from the speed-knobs cell,
    which must have been executed before this function is called.
    """
    data_path   = "dataset"
    tmin, tmax  = 0, 0.85
    epochs      = 30
    batch_size  = 256
    lr          = 1e-4
    num_workers = AUTO_WORKERS
    precision   = choose_precision()
    devices     = 1
    target_pos_fraction = 0.05
    lse_temperature     = 0.5

    # NOTE(review): these run keys are only logged to Neptune below. They are not
    # passed to the datamodule, which picks its own validation/test runs; keep the
    # two in sync.
    VAL_RUN  = ('0','12','Sherlock4','1')
    TEST_RUN = ('0','12','Sherlock5','1')

    pl.seed_everything(42, workers=True)
    torch.set_float32_matmul_precision('high')

    dm = LibriBrainWordDataModule(
        data_path, tmin, tmax, batch_size, num_workers,
        standardize_train=True, target_pos_fraction=target_pos_fraction
    )
    dm.setup()

    pos, total = dm.estimate_label_stats(sample=200_000)
    # Guard against an empty training split instead of raising ZeroDivisionError.
    base_rate = pos / total if total > 0 else 0.0
    pos_weight = 1.0  # oversampling-only: class balance comes from the sampler, not the loss
    print(f"[Label stats] pos={pos} total={total}  π={base_rate:.6f}  "
          f"target_p={target_pos_fraction:.2f}  pos_weight_eff={pos_weight:.1f}")
    print(f"[Config] window=({tmin:.2f},{tmax:.2f})  sampler_pos={target_pos_fraction:.2f}  "
          f"loss=focal(0.95,2.0)+pairwise(0.5)  pooling=attention")

    neptune_logger = make_neptune_logger(run_name="watson-meg")
    if neptune_logger:
        neptune_logger.log_hyperparams({
            "data_path": data_path, "tmin": tmin, "tmax": tmax,
            "batch_size": batch_size, "lr": lr, "precision": precision,
            "base_rate": base_rate, "target_pos_fraction": target_pos_fraction,
            "loss": "focal(alpha=0.95,gamma=2.0)+pairwise(lambda=0.5)", "pooling": "temporal_attention",
            "pos_weight_eff": pos_weight,
            "val_run": "_".join(VAL_RUN), "test_run": "_".join(TEST_RUN),
            "schedule": "warmup+cosine",
        })

    model = WatsonKeywordPL(
        in_channels=306, pos_weight=pos_weight,
        opt=OptimConfig(lr=lr, weight_decay=1e-4, max_time_shift=4, noise_std=0.01,
                        warmup_epochs=1, cosine_after_warmup=True),
        lse_temperature=lse_temperature
    )

    ckpt_cb = ModelCheckpoint(monitor="val_auprc", mode="max", save_top_k=1, filename="best-val-auprc")
    callbacks = [ckpt_cb, EarlyStopping(monitor="val_auprc", mode="max", patience=6, min_delta=5e-4),
                 LearningRateMonitor(logging_interval="epoch")]

    trainer = pl.Trainer(
        max_epochs=epochs, precision=precision, devices=devices,
        accelerator="gpu" if torch.cuda.is_available() else "cpu",
        callbacks=callbacks, logger=neptune_logger,
        log_every_n_steps=25, gradient_clip_val=1.0,
    )

    trainer.fit(model, datamodule=dm)

    if neptune_logger and ckpt_cb.best_model_path:
        try:
            neptune_logger.experiment["artifacts/checkpoints/best"].upload(ckpt_cb.best_model_path)
            print(f"Neptune: uploaded best checkpoint -> {ckpt_cb.best_model_path}")
        except Exception as e:
            print(f"Neptune: failed to upload checkpoint: {e}")

    # Evaluate the best validation checkpoint, not the last-epoch weights: early
    # stopping (patience=6) can end training several epochs after the best epoch.
    # With no saved checkpoint (empty best_model_path), the in-memory weights are tested.
    trainer.test(model, datamodule=dm, ckpt_path=ckpt_cb.best_model_path or None)

# if __name__ == "__main__":
#     main()

In [None]:
# %% Speed knobs (precision + dataloader tuning + TF32)
import os, time, itertools
import torch

def choose_precision():
    """Pick a Lightning ``precision`` string for the current hardware.

    Returns ``"32-true"`` without CUDA, ``"bf16-mixed"`` on GPUs with native
    bfloat16 support, and ``"16-mixed"`` (FP16 autocast) on other GPUs.
    """
    if torch.cuda.is_available():
        return "bf16-mixed" if torch.cuda.is_bf16_supported() else "16-mixed"
    return "32-true"

# Global speed flags (they matter on Ampere/Ada-class GPUs):
# - let cuDNN benchmark and cache the fastest kernels for the input shapes seen,
# - allow TF32 tensor-core math in cuDNN ops and in CUDA matmuls.
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True

USE_COMPILE = False  # Potentially useful - we didn't use it in our experiments.


def maybe_compile(model: torch.nn.Module) -> torch.nn.Module:
    """Return ``torch.compile(model, mode="max-autotune")`` when enabled, else *model*.

    Compilation is attempted only if ``USE_COMPILE`` is set, CUDA is available,
    and this torch build provides ``torch.compile``. If the ``torch.compile``
    call itself raises, the error is printed and the original model is returned.
    NOTE(review): torch.compile is lazy, so most compilation errors surface on the
    first forward pass rather than here.
    """
    can_compile = USE_COMPILE and torch.cuda.is_available() and hasattr(torch, "compile")
    if not can_compile:
        return model
    try:
        return torch.compile(model, mode="max-autotune")
    except Exception as e:
        print(f"[compile] fallback (disabled): {e}")
        return model

# Tuned DataLoader kwargs, kept separate so the base datamodule is not modified.
# NOTE(review): not referenced in the visible code. PyTorch rejects
# persistent_workers=True / a prefetch_factor when num_workers == 0.
DLOADER_KW = dict(
    pin_memory=torch.cuda.is_available(),
    persistent_workers=True,  # reuse worker processes between epochs
    prefetch_factor=4,        # batches queued ahead per worker
)

# Worker-count heuristic: half the logical CPUs (8 assumed if unknown),
# clamped to the range [4, 14].
CPU_COUNT = os.cpu_count() or 8
AUTO_WORKERS = min(max(CPU_COUNT // 2, 4), 14)
print("Num Workers: ", AUTO_WORKERS)

# --- tiny probe to see if you’re data-bound (optional) ---
def dataloader_probe(dm, max_batches=50):
    """Measure samples/sec of the train loader without running a full epoch.

    Pulls three warm-up batches from ``dm.train_dataloader()`` so worker start-up
    is excluded, then times up to *max_batches* further batches. If the loader
    runs out during the timed phase, the batches actually seen are reported.

    The batch size is read from ``batch[0].shape[0]`` for tuple/list batches
    (``batch.shape[0]`` otherwise). Objects without ``.shape`` count as one sample.

    Returns:
        The measured samples/sec, or ``None`` when nothing could be measured.
        Errors are printed and swallowed, since this is a best-effort diagnostic.
    """
    try:
        it = iter(dm.train_dataloader())
        # Warm-up: the first batches are dominated by worker start-up.
        for _ in range(3):
            next(it)
        seen = 0
        n_batches = 0
        t0 = time.perf_counter()
        # islice stops cleanly if the loader is exhausted mid-measurement.
        for batch in itertools.islice(it, max_batches):
            x = batch[0] if isinstance(batch, (tuple, list)) else batch
            seen += x.shape[0] if hasattr(x, "shape") else 1
            n_batches += 1
        dt = time.perf_counter() - t0
    except StopIteration:
        print("[probe] train loader exhausted during warm-up; too few batches to measure.")
        return None
    except Exception as e:
        print(f"[probe] probe skipped: {e}")
        return None

    if n_batches == 0:
        print("[probe] no batches left after warm-up; nothing measured.")
        return None
    # Guard only against a zero interval. The old max(1, dt) clamped the time to
    # >= 1 s, which under-reported fast loaders.
    rate = seen / max(dt, 1e-9)
    print(f"[probe] ~{rate:.1f} samples/sec over {n_batches} batches")
    return rate


In [None]:
# %% Buffer-length sweep — impact of positive_buffer & negative_buffer
# Logs + artifacts: playground/buffer/neg=XX_pos=YY/seedZ/
import os, csv, json, shutil, gc, random, hashlib, glob
from typing import List, Tuple, Dict, Optional

import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor, EarlyStopping

# Reuse from previous cells:
# - WatsonKeywordPL, OptimConfig, BalancedBatchSampler
# - LibriBrainWordDataModule base class and LibriBrainWord / RUN_KEYS / LibriBrainBase imports

# ----------------- config -----------------
# Sweep grid: each (negative_buffer, positive_buffer) pair is trained once per seed.
NEG_BUFFERS = [0.00, 0.05, 0.10, 0.15, 0.20]
POS_BUFFERS = [0.00, 0.05, 0.10, 0.15, 0.20, 0.25, 0.30]
SEEDS = [1, 2, 3]

# All sweep artifacts live under this directory (created eagerly).
BASE_DIR = os.path.join("playground", "buffer")
os.makedirs(BASE_DIR, exist_ok=True)

# ----------------- DataModule subclass with buffer control -----------------
class BufferSweepDM(LibriBrainWordDataModule):
    """LibriBrain word datamodule with configurable positive/negative label buffers.

    ``pos_buf`` / ``neg_buf`` are forwarded as ``positive_buffer`` /
    ``negative_buffer`` to every ``LibriBrainWord`` split. They are also part of
    the label-index cache key, so indices computed for one buffer setting are
    never reused for another.
    """

    def __init__(self, data_path: str, tmin: float, tmax: float, batch_size: int, num_workers: int,
                 target_pos_fraction: float, pos_buf: float, neg_buf: float,
                 pin_memory: bool=True, standardize_train: bool=True,
                 val_run_override: Optional[Tuple[str,str,str,str]] = None,
                 test_run_override: Optional[Tuple[str,str,str,str]] = None):
        super().__init__(data_path, tmin, tmax, batch_size, num_workers, pin_memory,
                         standardize_train, target_pos_fraction, val_run_override, test_run_override)
        self.pos_buf, self.neg_buf = float(pos_buf), float(neg_buf)

    def _hash_key(self, train_runs: List[Tuple[str,str,str,str]]) -> str:
        """Return a 16-hex-char digest of window, buffers and training runs (cache key)."""
        payload = {
            "neg_buf": self.neg_buf,
            "pos_buf": self.pos_buf,
            "runs": sorted("_".join(run) for run in train_runs),
            "tmax": self.tmax,
            "tmin": self.tmin,
        }
        blob = json.dumps(payload, sort_keys=True).encode("utf-8")
        return hashlib.sha256(blob).hexdigest()[:16]

    def setup(self, stage: Optional[str]=None):
        """Build the train/val/test datasets and the balanced train sampler.

        ``stage`` is ignored: every call rebuilds everything.
        """
        all_runs = list(self._available_runs())
        self.val_run = self.val_run_override or ('0','12','Sherlock4','1')
        self.test_run = self.test_run_override or ('0','12','Sherlock5','1')
        held_out = (self.val_run, self.test_run)
        train_runs = [rk for rk in all_runs if rk not in held_out]

        # NOTE(review): tmin/tmax and the val/test run choice only enter the cache
        # key. The datasets below are selected by partition name and are not given
        # a time window; confirm this matches the base datamodule's behaviour.
        shared = dict(keyword_detection="watson", preload_files=False, include_info=False,
                      positive_buffer=self.pos_buf, negative_buffer=self.neg_buf)
        self.train_ds = LibriBrainWord(self.data_path, partition="train",
                                       standardize=self.standardize_train, **shared)

        def _eval_split(partition):
            # Evaluation splits are always standardized with the training stats
            # (None when the train dataset exposes no such attributes).
            return LibriBrainWord(self.data_path, partition=partition, standardize=True,
                                  channel_means=getattr(self.train_ds, "channel_means", None),
                                  channel_stds=getattr(self.train_ds, "channel_stds", None),
                                  **shared)

        self.val_ds = _eval_split("validation")
        self.test_ds = _eval_split("test")

        cache_file = self._cache_paths(self._hash_key(train_runs))
        pos_idx, neg_idx = self._build_pos_neg_indices(self.train_ds, cache_file)
        if len(pos_idx) == 0:
            raise RuntimeError("No positive samples found in training set; cannot build balanced sampler.")

        self._pos_idx, self._neg_idx = pos_idx, neg_idx
        self._train_sampler = BalancedBatchSampler(pos_idx, neg_idx,
                                                   batch_size=self.batch_size,
                                                   pos_fraction=self.target_pos_fraction)

    def train_dataloader(self):
        """Print sampler/buffer diagnostics, then delegate to the base train loader."""
        sampler = self._train_sampler
        if isinstance(sampler, BalancedBatchSampler):
            try:
                steps = len(sampler)
            except Exception:
                steps = "?"
            print(f"[train loader] buffers: pos={self.pos_buf:.2f}, neg={self.neg_buf:.2f}  "
                  f"|P|={len(sampler.p_idx)} |N|={len(sampler.n_idx)}  "
                  f"pos_fraction={self.target_pos_fraction:.3f} batch={self.batch_size} steps/epoch={steps}")
        return super().train_dataloader()

# ----------------- helpers -----------------
def ensure_dir(p: str) -> str:
    """Create directory *p* (including parents) if missing and return *p* unchanged."""
    os.makedirs(p, exist_ok=True)
    return p

def tag_float(x: float) -> str:
    """Format *x* with two decimals and a path-safe ``p`` separator (0.10 -> ``"0p10"``)."""
    two_dp = format(x, ".2f")
    return two_dp.replace(".", "p")

def csv_nonempty(path: str, min_rows: int = 1) -> bool:
    """Return True if *path* is a CSV with a header row and at least *min_rows* data rows.

    A missing, empty or unreadable file yields False. ``min_rows <= 0`` only
    requires the header to be present.
    """
    if not os.path.exists(path):
        return False
    try:
        with open(path, "r", newline="") as f:  # newline="" as the csv module requires
            reader = csv.reader(f)
            if next(reader, None) is None:  # no header: empty file
                return False
            if min_rows <= 0:
                return True
            for count, _ in enumerate(reader, start=1):
                if count >= min_rows:
                    return True
    except (OSError, UnicodeError, csv.Error):
        # Treat unreadable/corrupt files as "not done" so the caller regenerates them.
        return False
    return False

def seed_dir_path(dst_cfg_dir: str, seed: int) -> str:
    """Per-seed folder inside a config dir: ``<cfg>/seed<seed>``."""
    return os.path.join(dst_cfg_dir, "seed" + str(seed))

def seed_preds_path(dst_cfg_dir: str, seed: int) -> str:
    """Full per-seed test predictions: ``<cfg>/seed<seed>/test_predictions_seed<seed>.csv``."""
    return os.path.join(dst_cfg_dir, f"seed{seed}", f"test_predictions_seed{seed}.csv")

def seed_probs_path(dst_cfg_dir: str, seed: int) -> str:
    """Per-seed probabilities: ``<cfg>/seed<seed>/test_probs_seed<seed>.csv``."""
    return os.path.join(dst_cfg_dir, f"seed{seed}", f"test_probs_seed{seed}.csv")

def labels_path(dst_cfg_dir: str) -> str:
    """Shared (seed-independent) test labels: ``<cfg>/test_labels.csv``."""
    return os.path.join(dst_cfg_dir, "test_labels.csv")

def find_best_ckpt(seed_dir: str) -> Optional[str]:
    """Return a ``best-val-auprc-s*`` checkpoint found under *seed_dir*, or None.

    Without an explicit ``dirpath``, Lightning's ModelCheckpoint saves into a
    ``checkpoints/`` subfolder of either ``default_root_dir`` or the logger's
    versioned log dir (e.g. ``lightning_logs/version_N/checkpoints``). The
    search is therefore recursive. Files inside a ``best-val-auprc-s*/``
    directory are preferred, then recursive name matches; each group is in
    sorted order and duplicates are dropped.
    """
    nested = sorted(glob.glob(os.path.join(seed_dir, "best-val-auprc-s*", "*.ckpt")))
    anywhere = sorted(glob.glob(os.path.join(seed_dir, "**", "best-val-auprc-s*.ckpt"),
                                recursive=True))
    cands = list(dict.fromkeys(nested + anywhere))  # de-dupe, keep priority order
    return cands[0] if cands else None

def _write_two_column_csv(path: str, header: Tuple[str, str], rows) -> None:
    """Write *header* and then each ``(a, b)`` pair from *rows* to a new CSV at *path*."""
    with open(path, "w", newline="") as f:
        w = csv.writer(f)
        w.writerow(header)
        w.writerows(rows)


def split_predictions_csv(src_csv: str, dst_seed_dir: str, dst_cfg_dir: str, seed: int):
    """Split a full predictions CSV into per-seed probabilities and shared labels.

    *src_csv* must have ``index``, ``label`` and ``probability`` columns. Writes:

    * ``<dst_seed_dir>/test_probs_seed{seed}.csv`` (``index, probability``), always;
    * ``<dst_cfg_dir>/test_labels.csv`` (``index, label``), only if it is missing;
    * ``<dst_seed_dir>/test_labels.csv`` for seed 1 only, as a convenience copy.
      It is written when the config-level labels were just created or when the
      copy itself is missing, even if another seed wrote the labels first.
    """
    indices, labels, probs = [], [], []
    with open(src_csv, "r", newline="") as f:
        for row in csv.DictReader(f):
            indices.append(int(row["index"]))
            labels.append(int(row["label"]))
            probs.append(float(row["probability"]))

    os.makedirs(dst_seed_dir, exist_ok=True)
    _write_two_column_csv(os.path.join(dst_seed_dir, f"test_probs_seed{seed}.csv"),
                          ("index", "probability"), zip(indices, probs))

    os.makedirs(dst_cfg_dir, exist_ok=True)
    cfg_labels = os.path.join(dst_cfg_dir, "test_labels.csv")
    wrote_cfg_labels = False
    if not os.path.exists(cfg_labels):
        _write_two_column_csv(cfg_labels, ("index", "label"), zip(indices, labels))
        wrote_cfg_labels = True

    seed_labels = os.path.join(dst_seed_dir, "test_labels.csv")
    if seed == 1 and (wrote_cfg_labels or not os.path.exists(seed_labels)):
        _write_two_column_csv(seed_labels, ("index", "label"), zip(indices, labels))

def write_combined_probs(dst_cfg_dir: str, seeds: Optional[List[int]] = None):
    """Merge per-seed test probabilities of one config into ``test_probs_all_seeds.csv``.

    Reads ``test_labels.csv`` and each seed's ``test_probs_seed{s}.csv`` under
    *dst_cfg_dir*. It writes the columns ``index, label, prob_seed<s>...``, one
    probability column per seed in *seeds* order. Indices missing from a seed's
    file are written as NaN. If the labels file or any seed's probs file is
    missing, a note is printed and nothing is written.

    Args:
        dst_cfg_dir: Buffer-config output directory.
        seeds: Seeds to combine; defaults to the module-level ``SEEDS``. The
            columns were previously hard-coded to seeds 1-3.
    """
    seeds = list(SEEDS if seeds is None else seeds)
    lab_path = labels_path(dst_cfg_dir)
    if not os.path.exists(lab_path):
        print(f"[combine] labels not found for {dst_cfg_dir} — skipping")
        return
    with open(lab_path, "r", newline="") as f:
        labels = [(int(r["index"]), int(r["label"])) for r in csv.DictReader(f)]

    all_probs: Dict[int, Dict[int, float]] = {}
    for s in seeds:
        p_path = seed_probs_path(dst_cfg_dir, s)
        if not os.path.exists(p_path):
            print(f"[combine] missing probs for seed{s} in {dst_cfg_dir} — skipping combine")
            return
        with open(p_path, "r", newline="") as f:
            all_probs[s] = {int(r["index"]): float(r["probability"]) for r in csv.DictReader(f)}

    combined_path = os.path.join(dst_cfg_dir, "test_probs_all_seeds.csv")
    with open(combined_path, "w", newline="") as f:
        w = csv.writer(f)
        w.writerow(["index", "label"] + [f"prob_seed{s}" for s in seeds])
        for idx, lab in labels:
            w.writerow([idx, lab] + [all_probs[s].get(idx, float("nan")) for s in seeds])
    print(f"[combine] wrote {combined_path}")

def record_metrics(row_path: str, row: Dict[str, float], header: List[str]):
    """Append *row* to the CSV at *row_path*, with columns in *header* order.

    The header line is written first whenever the file is missing or empty, so
    a zero-byte file (e.g. left by an interrupted run) still gets a header.
    Keys of *row* not in *header* are ignored. Header columns absent from *row*
    are written as empty strings.
    """
    needs_header = not os.path.exists(row_path) or os.path.getsize(row_path) == 0
    with open(row_path, "a", newline="") as f:
        w = csv.writer(f)
        if needs_header:
            w.writerow(header)
        w.writerow([row.get(k, "") for k in header])

def seed_artifacts_ok(dst_cfg_dir: str, seed: int) -> bool:
    """Tell whether *seed* is finished for this config.

    A seed counts as done when its full per-seed predictions CSV exists and has
    at least one data row. Derived probs/labels files may still be missing;
    callers backfill them separately.
    """
    return csv_nonempty(seed_preds_path(dst_cfg_dir, seed), min_rows=1)

def backfill_splits_if_missing(dst_cfg_dir: str, seed: int):
    """Regenerate per-seed probs and/or shared labels from the full predictions CSV.

    Does nothing when the seed's full predictions file is absent, or when both
    ``test_probs_seed{seed}.csv`` and the config's ``test_labels.csv`` already exist.
    """
    full_preds = seed_preds_path(dst_cfg_dir, seed)
    if os.path.exists(full_preds):
        probs_missing = not os.path.exists(seed_probs_path(dst_cfg_dir, seed))
        labels_missing = not os.path.exists(labels_path(dst_cfg_dir))
        if probs_missing or labels_missing:
            print(f"[seed {seed}] backfilling splits from {full_preds} "
                  f"(probs missing? {probs_missing}, labels missing? {labels_missing})")
            split_predictions_csv(full_preds, seed_dir_path(dst_cfg_dir, seed), dst_cfg_dir, seed)

# ----------------- runner -----------------
def _build_sweep_model(lr: float, lse_temperature: float):
    """Return a freshly initialised WatsonKeywordPL with the sweep's fixed settings."""
    return WatsonKeywordPL(
        in_channels=306, pos_weight=1.0,
        opt=OptimConfig(lr=lr, weight_decay=1e-4, max_time_shift=4, noise_std=0.01,
                        warmup_epochs=1, cosine_after_warmup=True),
        lse_temperature=lse_temperature,
    )


def _record_seed_metrics(summary_csv: str, results: Dict[str, float], pos_buf: float,
                         neg_buf: float, seed: int, dm) -> None:
    """Append one row to *summary_csv*: buffers, seed, test metrics and train P/N counts."""
    metric_keys = [
        "test_acc", "test_auprc", "test_auroc",
        "test_best_f1", "test_best_f1_threshold",
        "test_rprecision", "test_precision_at_M", "test_recall_at_M",
        "test_f1_macro@0.5", "test_recall_at_precision_0.90",
    ]
    header = ["positive_buffer", "negative_buffer", "seed", *metric_keys,
              "train_pos_count", "train_neg_count"]
    row = {"positive_buffer": pos_buf, "negative_buffer": neg_buf, "seed": seed}
    row.update({k: results.get(k, "") for k in metric_keys})
    row["train_pos_count"] = len(getattr(dm, "_pos_idx", []))
    row["train_neg_count"] = len(getattr(dm, "_neg_idx", []))
    record_metrics(summary_csv, row, header)


def run_one_config(pos_buf: float, neg_buf: float):
    """Train and evaluate every seed in ``SEEDS`` for one (pos_buf, neg_buf) setting.

    Outputs go to ``BASE_DIR/neg=<neg>_pos=<pos>/``:
      - a manifest;
      - per-seed folders holding full predictions and probabilities;
      - a shared labels file;
      - a combined multi-seed probabilities CSV;
      - ``metrics_summary.csv``.

    The sweep is resumable. Seeds whose full predictions already exist are
    skipped, with their derived split files backfilled. A seed that has a saved
    best checkpoint but no predictions is re-tested from that checkpoint instead
    of being retrained.
    """
    cfg_tag = f"neg={tag_float(neg_buf)}_pos={tag_float(pos_buf)}"
    dst_cfg_dir = ensure_dir(os.path.join(BASE_DIR, cfg_tag))
    summary_csv = os.path.join(dst_cfg_dir, "metrics_summary.csv")

    print(f"\n================  Buffers: pos={pos_buf:.2f}  neg={neg_buf:.2f}  ================\n")
    with open(os.path.join(dst_cfg_dir, "manifest.json"), "w") as f:
        json.dump({"positive_buffer": pos_buf, "negative_buffer": neg_buf}, f, indent=2)

    # Short-circuit: if every seed has its full predictions, just backfill and combine.
    if all(seed_artifacts_ok(dst_cfg_dir, s) for s in SEEDS):
        print(f"[skip] all seeds already complete for {cfg_tag}. Ensuring splits & combined...")
        for seed in SEEDS:
            backfill_splits_if_missing(dst_cfg_dir, seed)
        write_combined_probs(dst_cfg_dir)
        return

    # Hyperparameters shared by every seed of this config.
    data_path   = "dataset"
    tmin, tmax  = 0.0, 0.85
    epochs      = 30
    batch_size  = 2048
    lr          = 1e-4
    num_workers = AUTO_WORKERS
    # Same choice as main(): bf16 where supported, otherwise FP16 or FP32. A
    # hard-coded "bf16-mixed" fails on GPUs without native bfloat16.
    precision   = choose_precision()
    devices     = 1
    target_pos_fraction = 0.05
    lse_temperature     = 0.5
    accelerator = "gpu" if torch.cuda.is_available() else "cpu"
    # Scratch file the LightningModule writes during test(); copied per seed below.
    src_preds_global = os.path.join("playground", "test_predictions.csv")

    for seed in SEEDS:
        torch.cuda.empty_cache(); gc.collect()
        seed_dir = ensure_dir(seed_dir_path(dst_cfg_dir, seed))
        dst_seed_preds = seed_preds_path(dst_cfg_dir, seed)
        print(f"\n--- seed {seed} ---")

        # Already done: just make sure the derived splits exist.
        if seed_artifacts_ok(dst_cfg_dir, seed):
            print(f"[skip] found existing predictions for seed {seed}: {dst_seed_preds}")
            backfill_splits_if_missing(dst_cfg_dir, seed)
            continue

        pl.seed_everything(seed, workers=True)
        torch.set_float32_matmul_precision('high')

        dm = BufferSweepDM(
            data_path=data_path, tmin=tmin, tmax=tmax, batch_size=batch_size, num_workers=num_workers,
            target_pos_fraction=target_pos_fraction, pos_buf=pos_buf, neg_buf=neg_buf,
            pin_memory=True, standardize_train=True
        )
        dm.setup()

        # Drop predictions left by a previous config/seed. Otherwise a failed
        # write during this seed's test would silently copy stale results.
        if os.path.exists(src_preds_global):
            os.remove(src_preds_global)

        model = _build_sweep_model(lr, lse_temperature)
        existing_best_ckpt = find_best_ckpt(seed_dir)
        if existing_best_ckpt and not os.path.exists(dst_seed_preds):
            # Test-only fast path. A fresh Trainer has no attached module, so a
            # freshly built model is passed and the checkpoint weights are loaded
            # into it via ckpt_path.
            print(f"[fast-path] best checkpoint found for seed {seed} -> test-only: {existing_best_ckpt}")
            trainer = pl.Trainer(
                precision=precision,
                devices=devices,
                accelerator=accelerator,
                logger=None,
                default_root_dir=seed_dir,
                log_every_n_steps=25,
            )
            results = trainer.test(model=model, datamodule=dm, ckpt_path=existing_best_ckpt)
        else:
            ckpt_cb = ModelCheckpoint(monitor="val_auprc", mode="max", save_top_k=1,
                                      filename=f"best-val-auprc-s{seed}")
            callbacks = [
                ckpt_cb,
                EarlyStopping(monitor="val_auprc", mode="max", patience=6, min_delta=5e-4),
                LearningRateMonitor(logging_interval="epoch"),
            ]
            trainer = pl.Trainer(
                max_epochs=epochs,
                precision=precision,
                devices=devices,
                accelerator=accelerator,
                callbacks=callbacks,
                logger=None,
                log_every_n_steps=25,
                gradient_clip_val=1.0,
                default_root_dir=seed_dir,
            )
            trainer.fit(model, datamodule=dm)
            # Test the best checkpoint. If none was saved, test the in-memory weights.
            results = trainer.test(model=model, datamodule=dm,
                                   ckpt_path=ckpt_cb.best_model_path or None)
        results = results[0] if isinstance(results, list) and results else {}

        # ---- save ALL prediction CSVs under playground/buffer/... ----
        if os.path.exists(src_preds_global):
            shutil.copy2(src_preds_global, dst_seed_preds)
            print(f"[seed {seed}] saved full predictions -> {dst_seed_preds}")
        elif not os.path.exists(dst_seed_preds):
            # The user can place/rename a predictions file at dst_seed_preds and re-run to backfill.
            print(f"[seed {seed}] WARNING: expected predictions at {src_preds_global} not found "
                  f"and no existing {dst_seed_preds}. Cannot backfill splits for this seed yet.")

        # Per-seed probs + labels (once per config) next to the full predictions.
        if os.path.exists(dst_seed_preds):
            split_predictions_csv(dst_seed_preds, seed_dir, dst_cfg_dir, seed)

        # Metrics row per seed (only when this run produced test results).
        if results:
            _record_seed_metrics(summary_csv, results, pos_buf, neg_buf, seed, dm)

        # Release the trainer/model/datamodule (and any loader workers) before the next seed.
        del trainer, model, dm
        torch.cuda.empty_cache(); gc.collect()

    # Combined probs for the config (prints a skip note if any seed is still missing).
    write_combined_probs(dst_cfg_dir)

# ----------------- go! -----------------
# Visit every buffer config, largest buffers first: NEG is the outer loop and
# POS the inner one. Drop the reversed() calls to sweep in ascending order.
sweep_order = [(nb, pb) for nb in reversed(NEG_BUFFERS) for pb in reversed(POS_BUFFERS)]
total = len(sweep_order)
for k, (nb, pb) in enumerate(sweep_order, start=1):
    print(f"[{k}/{total}] pos={pb:.2f}  neg={nb:.2f}")
    run_one_config(pos_buf=pb, neg_buf=nb)

print("\n✅ Done. All outputs are under:", BASE_DIR)
print("Each config folder contains: seed folders, full test_predictions per seed, per-seed probs, labels, combined probs, and a metrics_summary.csv (for runs done in this session).")
