In [None]:
# Install third-party dependencies used by the runner in the next cell
%pip install -q mne_bids lightning torchmetrics scikit-learn plotly ipywidgets neptune

# Base directory for the LibriBrain dataset and related files
base_path = "./libribrain"

# Install pnpl from a locally modified checkout (path is relative to the notebook's working directory)
%pip install ../modified-pnpl/pnpl

# Neptune logging is optional: set the NEPTUNE_API_TOKEN and NEPTUNE_PROJECT environment
# variables before running the next cell to enable it; if either is unset, Neptune is skipped.

In [None]:
#!/usr/bin/env python3
# Keyword spotting on LibriBrain MEG — operationalized runner (scaling sweep)
# - Adds train_fraction to use a subset of the train set
# - Sweeps fractions: 0.05, 0.10, 0.20, 0.40, 0.60, 0.80, 1.00
# - Logs hours_per_epoch and hours_total (samples-consumed * window_seconds)
# - Leaves validation/test sets unchanged

from __future__ import annotations
import os, math, random, json, hashlib, time, platform, subprocess
from dataclasses import dataclass
from typing import Optional, Tuple, List, Iterator
from pathlib import Path
from datetime import datetime

import matplotlib.pyplot as plt
from sklearn.metrics import precision_recall_curve, average_precision_score

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset, BatchSampler

import lightning as pl
from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor, EarlyStopping
from lightning.pytorch.loggers import CSVLogger

# ============================== Sensor mask (optional) ==============================

SENSORS_SPEECH_MASK = [
    18, 20, 22, 23, 45,
    120, 138, 140, 142, 143, 145, 146, 147, 149,
    175, 176, 177, 179, 180, 198,
    271, 272, 275,
]

class ChannelMaskedDataset(torch.utils.data.Dataset):
    """Wrap a dataset of ``(x, y)`` pairs, keeping only a subset of channels of ``x``.

    ``x`` is indexed along its first axis with ``channel_indices`` (defaults to
    ``SENSORS_SPEECH_MASK``). If ``y`` is a tensor with at least one dimension, it
    is reduced to its middle element along the first axis.
    """
    def __init__(self, base_dataset, channel_indices=None):
        self.base_dataset = base_dataset
        if channel_indices is None:
            channel_indices = SENSORS_SPEECH_MASK
        self.channel_indices = channel_indices

    def __len__(self):
        return len(self.base_dataset)

    def __getitem__(self, idx):
        sample, label = self.base_dataset[idx]
        if torch.is_tensor(label) and label.ndim > 0:
            # Collapse a per-timestep label sequence to the label at its centre.
            label = label[label.shape[0] // 2]
        return sample[self.channel_indices], label

# ----------------- small helper for fast label scans -----------------

def collate_label_only_xy(batch):  # list[(x,y)] -> list[int]
    """Collate a batch of ``(x, y)`` pairs into just the labels, as Python ints.

    The inputs ``x`` are discarded, which keeps label-only scans cheap.
    """
    return list(map(int, (target for _sample, target in batch)))

# ============================== Neptune (optional) ==============================

def make_neptune_logger(run_name: str | None = None):
    api_key = os.getenv("NEPTUNE_API_TOKEN")
    project = os.getenv("NEPTUNE_PROJECT")
    if not api_key or not project:
        print("Neptune: env vars not found -> skipping Neptune logging.")
        return None
    try:
        from lightning.pytorch.loggers import NeptuneLogger as _BaseNeptuneLogger
    except Exception:
        from pytorch_lightning.loggers import NeptuneLogger as _BaseNeptuneLogger

    class CleanNeptuneLogger(_BaseNeptuneLogger):
        def log_metrics(self, metrics, step: int | None = None):
            filt = {k: v for k, v in metrics.items() if k != "epoch" and not k.endswith("/epoch")}
            super().log_metrics(filt, step=None)

    logger = CleanNeptuneLogger(
        api_key=api_key, project=project, name=run_name,
        tags=["libribrain", "watson", "meg", "keyword-detection", "generalization", "scaling"],
        prefix="training/", log_model_checkpoints=False,
    )
    print("Neptune: ✅ enabled.")
    return logger

# ============================== Model bits ==============================

@dataclass
class OptimConfig:
    """Optimizer, LR-schedule and augmentation settings for WatsonKeywordPL.

    Fields (in positional order):
        lr: AdamW learning rate.
        weight_decay: AdamW weight decay.
        warmup_epochs: epochs of linear LR warm-up before the cosine phase.
        cosine_after_warmup: if True, cosine decay after warm-up; otherwise
            ReduceLROnPlateau monitoring ``val_auprc``.
        noise_std: std of additive Gaussian noise applied as training augmentation.
    """
    lr: float = 0.00025
    weight_decay: float = 0.0001
    warmup_epochs: int = 2
    cosine_after_warmup: bool = True
    noise_std: float = 0.01

class FocalLoss(nn.Module):
    """Binary focal loss on logits, mean-reduced over all elements.

    Per element: ``a * (1 - p_t) ** gamma * BCE(logit, target)``, where ``p_t`` is the
    predicted probability of the true class and ``a`` is ``alpha`` when the target
    equals 1 and ``1 - alpha`` otherwise.
    """
    def __init__(self, alpha: float = 0.95, gamma: float = 2.0):
        super().__init__()
        self.alpha = float(alpha)
        self.gamma = float(gamma)

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        is_pos = targets == 1
        bce = nn.functional.binary_cross_entropy_with_logits(logits, targets.float(), reduction='none')
        prob = torch.sigmoid(logits)
        p_true = torch.where(is_pos, prob, 1 - prob)
        like = dict(device=logits.device, dtype=logits.dtype)
        weight = torch.where(
            is_pos,
            torch.as_tensor(self.alpha, **like),
            torch.as_tensor(1 - self.alpha, **like),
        )
        modulator = (1 - p_true).pow(self.gamma)
        return (weight * modulator * bce).mean()

class ChannelSE(nn.Module):
    """Squeeze-and-excitation gate over the channel axis of a (B, C, T) tensor.

    Each channel is rescaled by a weight in (0, 1), computed from its time-averaged
    activation through a bottleneck MLP (C -> max(1, C // reduction) -> C).
    """
    def __init__(self, channels: int, reduction: int = 4):
        super().__init__()
        hidden = max(1, channels // reduction)
        # Attribute names and the Sequential layout define the state_dict keys; keep them.
        self.avg = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Sequential(
            nn.Linear(channels, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channels, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        squeezed = self.avg(x).squeeze(-1)   # (B, C): mean over time
        gate = self.fc(squeezed)             # (B, C): per-channel weights
        return x * gate.unsqueeze(-1)        # broadcast the gate across time

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch-first sequence.

    Even feature columns hold ``sin(pos * w_i)`` and odd columns ``cos(pos * w_i)``,
    with ``w_i = 10000 ** (-2i / d_model)``.

    Args:
        d_model: feature size (last dim) of the sequence. Odd sizes are supported;
            the final column then carries only a sine term.
        max_len: longest sequence length covered by the precomputed table.
    """
    def __init__(self, d_model: int, max_len: int = 4000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len).unsqueeze(1)                                    # (max_len, 1)
        div = torch.exp(torch.arange(0, d_model, 2) * (-math.log(1e4) / d_model))    # (ceil(d/2),)
        pe[:, 0::2] = torch.sin(pos * div)
        # There are only floor(d/2) odd columns; trimming div keeps odd d_model from
        # failing on a shape mismatch (no-op for even d_model).
        pe[:, 1::2] = torch.cos(pos * div[: d_model // 2])
        self.register_buffer("pe", pe.unsqueeze(0))  # (1, max_len, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x + pe[:T]`` for ``x`` of shape (B, T, d_model), T <= max_len."""
        return x + self.pe[:, :x.size(1), :]

def _gn(c: int, groups: int = 16):
    g = max(1, min(groups, c))
    return nn.GroupNorm(num_groups=g, num_channels=c)

class EnhancedKeywordNet(nn.Module):
    """
    SE-gated 1-D Conv → Transformer encoder → temporal attention pool.
    GroupNorm for domain robustness (vs BN).

    Args:
        in_channels: number of input channels C.
        d_model: width of the conv stem and the transformer.
        n_heads: attention heads per transformer layer.
        n_layers: number of transformer encoder layers.
        dropout: dropout inside the transformer and before the output heads.

    forward(x): x of shape (B, C, T) -> one logit per example, shape (B,).
    """
    def __init__(self, in_channels: int, d_model: int = 128, n_heads: int = 4, n_layers: int = 2, dropout: float = 0.15):
        super().__init__()
        self.se = ChannelSE(in_channels)

        # Two stride-2, kernel-7 convs downsample time by ~4x. PyTorch does not accept
        # padding='same' for strided convolutions, so explicit symmetric padding is used.
        k7pad = 3
        self.stem = nn.Sequential(
            nn.Conv1d(in_channels, d_model, 7, 2, k7pad, bias=False),
            _gn(d_model), nn.ELU(),
            nn.Conv1d(d_model, d_model, 7, 2, k7pad, bias=False),
            _gn(d_model), nn.ELU()
        )

        enc = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=n_heads,
            dim_feedforward=d_model * 4, dropout=dropout, batch_first=True
        )
        self.pe          = PositionalEncoding(d_model)
        self.transformer = nn.TransformerEncoder(enc, num_layers=n_layers)
        self.dropout_h   = nn.Dropout(dropout)

        # Per-timestep logit and per-timestep attention score; combined in forward().
        self.to_logit_t  = nn.Linear(d_model, 1)
        self.to_attn_t   = nn.Linear(d_model, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B,C,T) -> logits: (B,)
        x = self.se(x)
        h = self.stem(x)              # (B,d,T')
        h = h.permute(0, 2, 1)        # (B,T',d)
        h = self.transformer(self.pe(h))
        h = self.dropout_h(h)
        logit_t = self.to_logit_t(h).transpose(1, 2)       # (B,1,T')
        attn_t  = torch.softmax(self.to_attn_t(h).transpose(1, 2), dim=-1)  # weights over T' sum to 1
        # Attention-weighted average of the per-timestep logits.
        return (logit_t * attn_t).sum(-1).squeeze(1)

# ============================== EMA helper ==============================

class EMAHelper:
    """Exponential moving average (EMA) of a module's trainable parameters.

    ``update`` blends the current parameters into ``shadow``. ``apply_to`` swaps the
    shadow weights into the module after backing up the live weights, and ``restore``
    copies the backed-up live weights back.
    """
    def __init__(self, module: nn.Module, decay: float = 0.999):
        self.decay = decay
        self.shadow = {n: p.detach().clone()
                       for n, p in module.named_parameters() if p.requires_grad}
        self.backup = None  # live weights saved by apply_to(); None when not applied

    @torch.no_grad()
    def update(self, module: nn.Module):
        """shadow = decay * shadow + (1 - decay) * param, for every trainable parameter."""
        for n, p in module.named_parameters():
            if not p.requires_grad: continue
            s = self.shadow[n]
            if s.device != p.device:
                # The module may have been moved after the helper was created.
                s = self.shadow[n] = s.to(p.device)
            s.mul_(self.decay).add_(p.detach(), alpha=1 - self.decay)

    @torch.no_grad()
    def apply_to(self, module: nn.Module):
        """Copy the shadow weights into ``module``, keeping a backup of the live weights.

        Calling this again before ``restore`` keeps the original backup. Otherwise the
        EMA weights would be saved as "live" weights and the real ones would be lost.
        """
        if self.backup is None:
            self.backup = {n: p.detach().clone()
                           for n, p in module.named_parameters() if p.requires_grad}
        for n, p in module.named_parameters():
            if not p.requires_grad: continue
            p.copy_(self.shadow[n])

    @torch.no_grad()
    def restore(self, module: nn.Module):
        """Copy the backed-up live weights back into ``module`` (no-op if nothing applied)."""
        if self.backup is None: return
        for n, p in module.named_parameters():
            if not p.requires_grad: continue
            p.copy_(self.backup[n])
        self.backup = None

# ============================== Metrics + module ==============================

try:
    from torchmetrics.classification import BinaryAccuracy, BinaryAveragePrecision, BinaryAUROC
except Exception:
    from torchmetrics import Accuracy as BinaryAccuracy                 # type: ignore
    from torchmetrics import AveragePrecision as BinaryAveragePrecision # type: ignore
    from torchmetrics import AUROC as BinaryAUROC                       # type: ignore

class WatsonKeywordPL(pl.LightningModule):
    """
    Lightning module for binary keyword detection on MEG windows.

    Training objective: focal loss + ``pairwise_lambda`` * pairwise logistic ranking loss.
    Validation and test run with the EMA shadow weights swapped into the backbone.
    After each validation epoch a scalar temperature is fitted on prior-corrected
    validation logits. Test probabilities are ``sigmoid((logit + prior_bias) / temp_scale)``;
    validation probabilities are the uncorrected ``sigmoid(logit)``.

    Args:
        in_channels: number of input channels passed to ``EnhancedKeywordNet``.
        opt: optimizer / augmentation settings (``OptimConfig``).
        loss_mode: stored on the module; not read by the training step in this class.
        pairwise_lambda: weight of the pairwise ranking term.
        pi_train: effective positive rate in training batches (set by the sampler).
        pi_target: base rate the calibrated test probabilities should reflect.

    NOTE(review): the EMA shadow lives only in memory (it is not written to checkpoints).
    Checkpoints taken after validation therefore hold the raw (non-EMA) weights. Testing
    after reloading a checkpoint in the same process re-applies the in-memory shadow from
    the last training step — confirm this is intended.
    """
    def __init__(
        self,
        in_channels: int,
        opt: OptimConfig,
        loss_mode: str = "focal_pairwise",
        pairwise_lambda: float = 1.0,
        # priors for correction at inference:
        pi_train: float = 0.10,      # effective positive rate in training batches (sampler)
        pi_target: float = 0.003,    # true base rate in the wild / dataset
    ):
        super().__init__()
        self.save_hyperparameters()
        self.model = EnhancedKeywordNet(in_channels)

        self.criterion = FocalLoss(alpha=0.95, gamma=2.0)
        self.loss_mode = loss_mode
        self.pairwise_lambda = float(pairwise_lambda)

        # torchmetrics
        self.train_acc = BinaryAccuracy()
        self.val_acc   = BinaryAccuracy()
        self.test_acc  = BinaryAccuracy()
        self.val_auprc = BinaryAveragePrecision()
        self.test_auprc= BinaryAveragePrecision()
        self.val_auroc = BinaryAUROC()
        self.test_auroc= BinaryAUROC()

        # buffers for diagnostics & calibration
        self._val_probs: List[torch.Tensor] = []; self._val_labels: List[torch.Tensor] = []
        self._val_logits: List[torch.Tensor] = []
        self._test_probs: List[torch.Tensor] = []; self._test_labels: List[torch.Tensor] = []
        self.register_buffer("temp_scale", torch.tensor(1.0))  # temperature for logits at inference
        # Logit offset that converts training-batch odds (pi_train) to target odds (pi_target).
        self.prior_bias = math.log((pi_target / (1 - pi_target)) / (pi_train / (1 - pi_train)))
        self._ema: Optional[EMAHelper] = None

        # exported at test end (for CSV)
        self.last_test_probs: Optional[torch.Tensor] = None
        self.last_test_labels: Optional[torch.Tensor] = None

    # ---------- helpers ----------
    @staticmethod
    def _pairwise_logistic_loss(logits: torch.Tensor, labels: torch.Tensor, max_pairs: int = 4096) -> torch.Tensor:
        """Mean softplus(-(s_pos - s_neg)) over up to ``max_pairs`` randomly drawn (pos, neg) pairs.

        Returns a zero scalar if the batch lacks either class.
        """
        pos_idx = (labels == 1).nonzero(as_tuple=False).view(-1)
        neg_idx = (labels == 0).nonzero(as_tuple=False).view(-1)
        if pos_idx.numel() == 0 or neg_idx.numel() == 0: return logits.new_zeros(())
        num_pairs = min(max_pairs, int(pos_idx.numel()) * int(neg_idx.numel()))
        pi = pos_idx[torch.randint(0, pos_idx.numel(), (num_pairs,), device=logits.device)]
        ni = neg_idx[torch.randint(0, neg_idx.numel(), (num_pairs,), device=logits.device)]
        diff = logits[pi] - logits[ni]
        return torch.nn.functional.softplus(-diff).mean()

    @staticmethod
    def _rprecision(probs: torch.Tensor, labels: torch.Tensor) -> float:
        """Precision among the top-M scores, where M is the number of positives (0.0 if none)."""
        m = int(labels.sum().item())
        if m <= 0: return 0.0
        k = min(m, probs.numel())
        topk = torch.topk(probs, k=k, largest=True).indices
        return float(labels[topk].float().mean().item())

    @staticmethod
    def _precision_recall_at_k(probs: torch.Tensor, labels: torch.Tensor, k: int) -> Tuple[float,float]:
        """Precision and recall of the top-k scores (k clamped to [1, N]); (0.0, 0.0) for empty input."""
        if probs.numel() == 0: return 0.0, 0.0
        k = max(1, min(k, probs.numel()))
        topk = torch.topk(probs, k=k, largest=True).indices
        tp = labels[topk].sum().item()
        prec = tp / k
        total_pos = max(1, int(labels.sum().item()))
        rec = tp / total_pos
        return float(prec), float(rec)

    @staticmethod
    def _best_f1(probs: torch.Tensor, labels: torch.Tensor) -> Tuple[float, float, Tuple[int,int,int,int]]:
        """Best F1 over all top-k cut-offs of the ranked scores.

        Returns:
            (best_f1, threshold, (tp, fp, tn, fn)), where threshold is the score at the
            best cut-off. Empty input gives (0.0, 0.5, (0, 0, 0, 0)).
        """
        N = probs.numel()
        if N == 0: return 0.0, 0.5, (0,0,0,0)
        sort_idx = torch.argsort(probs, descending=True)
        sorted_labels = labels[sort_idx].to(torch.int32)
        cum_tp = torch.cumsum(sorted_labels, dim=0)
        ks = torch.arange(1, N+1, device=probs.device)
        precision = cum_tp / ks
        n_pos = int(labels.sum().item())
        total_pos = max(1, n_pos)  # guards the recall division only
        recall = cum_tp / total_pos
        denom = precision + recall
        f1 = torch.where(denom > 0, 2 * precision * recall / denom, torch.zeros_like(denom))
        best_idx = int(torch.argmax(f1).item())
        best_f1 = float(f1[best_idx].item())
        thr = float(probs[sort_idx[best_idx]].item())
        k = best_idx + 1
        # Confusion counts use the true positive count (not the clamped divisor).
        tp = int(cum_tp[best_idx].item()); fp = int(k - tp); fn = int(n_pos - tp); tn = int(N - k - fn)
        return best_f1, thr, (tp, fp, tn, fn)

    @staticmethod
    def _recall_at_precision(probs: torch.Tensor, labels: torch.Tensor, min_precision: float = 0.90) -> float:
        """Highest recall over top-k cut-offs whose precision is >= ``min_precision`` (0.0 if none)."""
        N = probs.numel()
        if N == 0: return 0.0
        sort_idx = torch.argsort(probs, descending=True)
        sorted_labels = labels[sort_idx].to(torch.int32)
        cum_tp = torch.cumsum(sorted_labels, dim=0)
        ks = torch.arange(1, N+1, device=probs.device)
        precision = cum_tp / ks
        total_pos = max(1, int(labels.sum().item()))
        recall = cum_tp / total_pos
        mask = precision >= min_precision
        return float(recall[mask].max().item()) if mask.any() else 0.0

    @staticmethod
    def _macro_f1_balanced_acc(probs: torch.Tensor, labels: torch.Tensor, thr: float = 0.5):
        """Macro F1 (mean of positive- and negative-class F1) and balanced accuracy at ``thr``.

        Returns:
            (f1_macro, balanced_accuracy, (tp, fp, tn, fn)).
        """
        preds = (probs >= thr).to(torch.int32)
        labels_i = labels.to(torch.int32)
        tp = int(((preds == 1) & (labels_i == 1)).sum().item())
        fp = int(((preds == 1) & (labels_i == 0)).sum().item())
        fn = int(((preds == 0) & (labels_i == 1)).sum().item())
        tn = int(((preds == 0) & (labels_i == 0)).sum().item())
        # F1 for pos and neg
        def _f1(tp_, fp_, fn_):
            p = tp_ / max(1, tp_ + fp_)
            r = tp_ / max(1, tp_ + fn_)
            d = p + r
            return float((2 * p * r / d) if d > 0 else 0.0)
        f1_pos = _f1(tp, fp, fn)
        f1_neg = _f1(tn, fn, fp)
        f1_macro = 0.5 * (f1_pos + f1_neg)
        # Balanced accuracy
        tpr = tp / max(1, tp + fn)
        tnr = tn / max(1, tn + fp)
        bal_acc = 0.5 * (tpr + tnr)
        return float(f1_macro), float(bal_acc), (tp, fp, tn, fn)

    @staticmethod
    def _brier(probs: torch.Tensor, labels: torch.Tensor) -> float:
        """Mean squared error between probabilities and 0/1 labels."""
        diff = probs.float() - labels.float()
        return float((diff * diff).mean().item())

    def _save_pr_curve(self, probs: torch.Tensor, labels: torch.Tensor, out_path: Path, title: str = "Validation") -> bool:
        """Plot a precision-recall curve (titled with its AP) to ``out_path``.

        Returns True on success. On failure it prints the error and returns False.
        The figure is always closed.
        """
        fig = None
        try:
            y_true = labels.cpu().numpy()
            y_score = probs.cpu().numpy()
            precision, recall, _ = precision_recall_curve(y_true, y_score)
            ap = average_precision_score(y_true, y_score)
            out_path.parent.mkdir(parents=True, exist_ok=True)
            fig = plt.figure(figsize=(6, 4))
            plt.step(recall, precision, where="post")
            plt.xlabel("Recall"); plt.ylabel("Precision")
            plt.ylim(0.0, 1.05); plt.xlim(0.0, 1.0)
            plt.title(f"{title} PR curve (AP={ap:.4f})")
            plt.grid(True, alpha=0.3)
            plt.tight_layout()
            plt.savefig(str(out_path), bbox_inches="tight")
            return True
        except Exception as e:
            print(f"[PR] Failed to save PR curve to {out_path}: {e}")
            return False
        finally:
            if fig is not None:
                plt.close(fig)

    def _neptune_run(self):
        """Return the run handle of the first Neptune logger attached to the trainer, or None.

        Only loggers with a class named like '*Neptune*' in their MRO are considered, so
        another logger's ``experiment`` (e.g. a CSV writer) is never mistaken for a run.
        """
        for lg in (getattr(self.trainer, "loggers", []) or []):
            if not any("Neptune" in cls.__name__ for cls in type(lg).__mro__):
                continue
            handle = getattr(lg, "experiment", None)
            if handle is None:
                handle = getattr(lg, "run", None)
            if handle is not None:
                return handle
        return None

    def _upload_plot_to_neptune(self, key: str, path: Path) -> None:
        """Best-effort upload of an image file to the Neptune run under ``key``."""
        try:
            run = self._neptune_run()
            if run is not None:
                run[key].upload(str(path))
        except Exception as e:
            print(f"[PR] Neptune upload skipped: {e}")

    # ---------- Lightning lifecycle ----------
    def forward(self, x): return self.model(x)

    def on_fit_start(self):
        # EMA over backbone for stabler eval
        self._ema = EMAHelper(self.model, decay=0.999)

    def training_step(self, batch, _):
        x, y = batch
        x = self._augment(x)
        logits = self(x)

        focal = self.criterion(logits.float(), y.float())
        pairwise = self._pairwise_logistic_loss(logits, y)
        loss = focal + self.pairwise_lambda * pairwise

        probs = torch.sigmoid(logits.float())
        self.train_acc.update(probs, y)

        # metrics/logs
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=False)
        self.log("train_focal", focal, on_step=True, on_epoch=True)
        self.log("train_pairwise", pairwise, on_step=True, on_epoch=True)
        self.log("train_pos_frac_batch", y.float().mean(), on_step=True, on_epoch=False, prog_bar=True)

        return loss

    def on_train_batch_end(self, outputs, batch, batch_idx):
        # Update the EMA after the optimizer step so it tracks the freshly updated weights
        # (an after-backward hook would see pre-step weights).
        # NOTE(review): with gradient accumulation this runs once per micro-batch — confirm.
        if self._ema is not None:
            self._ema.update(self.model)

    def on_train_epoch_end(self):
        self.log("train_acc", self.train_acc.compute(), on_step=False, on_epoch=True, prog_bar=True)
        self.train_acc.reset()

    def on_validation_epoch_start(self):
        if self._ema is not None:
            self._ema.apply_to(self.model)
        self._val_probs.clear(); self._val_labels.clear(); self._val_logits.clear()

    def validation_step(self, batch, _):
        x, y = batch
        logits = self(x)
        probs = torch.sigmoid(logits.float())

        # accumulate
        self.val_acc.update(probs, y); self.val_auprc.update(probs, y); self.val_auroc.update(probs, y)
        self._val_probs.append(probs.detach().float().cpu())
        self._val_labels.append(y.detach().int().cpu())
        self._val_logits.append(logits.detach().float().cpu())

    def _fit_temperature_on_val(self, logits: torch.Tensor, labels: torch.Tensor):
        """Fit ``self.temp_scale`` by minimizing BCE of ``(logits + prior_bias) / T`` on validation data.

        T is parameterized as ``softplus(t_raw) + 1e-4`` and optimized with L-BFGS; the
        result is clamped to [1e-3, 100]. On failure T is reset to 1.0.

        Lightning calls validation hooks with gradients disabled (no_grad/inference mode),
        so gradient tracking is re-enabled explicitly here. Without that, ``backward()``
        always fails and T never moves from 1.0.
        """
        with torch.inference_mode(False), torch.enable_grad():
            # Tensors created under inference mode cannot be saved for backward; cloning
            # outside inference mode yields ordinary tensors.
            z = logits.detach().clone().float() + self.prior_bias
            y = labels.detach().clone().float()
            t_raw = torch.zeros((), device=z.device, requires_grad=True)  # T = softplus(t_raw) + eps
            opt = torch.optim.LBFGS([t_raw], lr=0.5, max_iter=50, line_search_fn="strong_wolfe")

            def closure():
                opt.zero_grad()
                T = torch.nn.functional.softplus(t_raw) + 1e-4
                loss = nn.functional.binary_cross_entropy_with_logits(z / T, y)
                loss.backward()
                return loss
            try:
                opt.step(closure)
                with torch.no_grad():
                    T = torch.nn.functional.softplus(t_raw) + 1e-4
                    self.temp_scale.copy_(T.clamp(1e-3, 100.0))
            except Exception as e:
                print(f"[Calib] Temperature fit failed: {e}. Keeping T=1.0.")
                self.temp_scale.copy_(torch.tensor(1.0, device=self.device))

    def on_validation_epoch_end(self):
        probs = torch.cat(self._val_probs, dim=0) if self._val_probs else torch.empty(0)
        labels= torch.cat(self._val_labels, dim=0) if self._val_labels else torch.empty(0, dtype=torch.int64)
        logits= torch.cat(self._val_logits, dim=0) if self._val_logits else torch.empty(0)

        base_rate = float(labels.float().mean().item()) if labels.numel() > 0 else 0.0
        both_classes = (labels.sum().item() > 0) and (labels.sum().item() < labels.numel())

        # AUPRC/AUROC are undefined without both classes; fall back to chance levels.
        try:  val_auprc = self.val_auprc.compute()
        except Exception: val_auprc = torch.as_tensor(base_rate, device=self.device)
        try:  val_auroc = self.val_auroc.compute() if both_classes else torch.as_tensor(0.5, device=self.device)
        except Exception: val_auroc = torch.as_tensor(0.5, device=self.device)
        val_acc = self.val_acc.compute()

        rprec = self._rprecision(probs, labels)
        m = int(labels.sum().item()) if labels.numel() > 0 else 0
        prec_m, rec_m = self._precision_recall_at_k(probs, labels, max(1, m))
        prec_2m, rec_2m = self._precision_recall_at_k(probs, labels, max(1, 2*m))
        prec_5m, rec_5m = self._precision_recall_at_k(probs, labels, max(1, 5*m))
        best_f1, best_thr, (tp, fp, tn, fn) = self._best_f1(probs, labels)
        rec_at_p90 = self._recall_at_precision(probs, labels, 0.90)

        print(f"[VAL] base_rate={base_rate:.6f}  AUPRC={float(val_auprc):.4f}  AUROC={float(val_auroc):.4f}  "
              f"RPrec={rprec:.4f}  BestF1={best_f1:.4f} @thr={best_thr:.4f}  "
              f"Rec@P>=0.90={rec_at_p90:.4f}  Conf(TP/FP/TN/FN)={tp}/{fp}/{tn}/{fn}")

        # log
        self.log("val_acc", val_acc, on_step=False, on_epoch=True)
        self.log("val_auprc", val_auprc, on_step=False, on_epoch=True, prog_bar=True)
        self.log("val_auroc", val_auroc, on_step=False, on_epoch=True)
        self.log("val_pos_rate", torch.as_tensor(base_rate, device=self.device), on_step=False, on_epoch=True)
        self.log("val_random_auprc", torch.as_tensor(base_rate, device=self.device), on_step=False, on_epoch=True)
        self.log("val_rprecision", torch.as_tensor(rprec, device=self.device), on_step=False, on_epoch=True, prog_bar=True)
        self.log("val_precision_at_M", torch.as_tensor(prec_m, device=self.device), on_step=False, on_epoch=True)
        self.log("val_recall_at_M", torch.as_tensor(rec_m, device=self.device), on_step=False, on_epoch=True)
        self.log("val_precision_at_2M", torch.as_tensor(prec_2m, device=self.device), on_step=False, on_epoch=True)
        self.log("val_recall_at_2M", torch.as_tensor(rec_2m, device=self.device), on_step=False, on_epoch=True)
        self.log("val_precision_at_5M", torch.as_tensor(prec_5m, device=self.device), on_step=False, on_epoch=True)
        self.log("val_recall_at_5M", torch.as_tensor(rec_5m, device=self.device), on_step=False, on_epoch=True)
        self.log("val_best_f1", torch.as_tensor(best_f1, device=self.device), on_step=False, on_epoch=True, prog_bar=True)
        self.log("val_best_f1_threshold", torch.as_tensor(best_thr, device=self.device), on_step=False, on_epoch=True)
        self.log("val_recall_at_precision_0.90", torch.as_tensor(rec_at_p90, device=self.device), on_step=False, on_epoch=True, prog_bar=True)

        # extra metrics
        f1_macro, bal_acc, _ = self._macro_f1_balanced_acc(probs, labels, thr=best_thr)
        brier = self._brier(probs, labels)
        self.log("val_f1_macro", torch.as_tensor(f1_macro, device=self.device), on_step=False, on_epoch=True, prog_bar=True)
        self.log("val_bal_acc", torch.as_tensor(bal_acc, device=self.device), on_step=False, on_epoch=True)
        self.log("val_brier", torch.as_tensor(brier, device=self.device), on_step=False, on_epoch=True)

        # PR curve figure (+ optional Neptune upload)
        try:
            out_path = Path(self.trainer.default_root_dir) / "artifacts" / "val_pr_curve.png"
            if self._save_pr_curve(probs, labels, out_path, title="Validation"):
                print(f"Saved validation PR curve -> {out_path}")
                self._upload_plot_to_neptune("artifacts/plots/val_pr_curve", out_path)
        except Exception as e:
            print(f"[PR] Failed to generate validation PR curve: {e}")

        # temperature scaling on validation (after logs)
        if logits.numel() > 0 and labels.sum().item() > 0:
            self._fit_temperature_on_val(logits.to(self.device), labels.to(self.device))

        # resets
        self.val_acc.reset(); self.val_auprc.reset(); self.val_auroc.reset()
        if self._ema is not None:
            self._ema.restore(self.model)

    def on_test_epoch_start(self):
        if self._ema is not None:
            self._ema.apply_to(self.model)
        self._test_probs.clear(); self._test_labels.clear()

    def test_step(self, batch, _):
        x, y = batch
        logits = self(x)
        # prior correction + temperature scaling for calibrated inference
        z = logits + self.prior_bias
        z = z / self.temp_scale.clamp(1e-3, 100.0)
        probs = torch.sigmoid(z.float())

        self.test_acc.update(probs, y); self.test_auprc.update(probs, y); self.test_auroc.update(probs, y)
        self._test_probs.append(probs.detach().float().cpu())
        self._test_labels.append(y.detach().int().cpu())

    def on_test_epoch_end(self):
        probs = torch.cat(self._test_probs, dim=0) if self._test_probs else torch.empty(0)
        labels= torch.cat(self._test_labels, dim=0) if self._test_labels else torch.empty(0, dtype=torch.int64)
        base_rate = float(labels.float().mean().item()) if labels.numel() > 0 else 0.0
        both_classes = (labels.sum().item() > 0) and (labels.sum().item() < labels.numel())

        try:  test_auprc = self.test_auprc.compute()
        except Exception: test_auprc = torch.as_tensor(base_rate, device=self.device)
        try:  test_auroc = self.test_auroc.compute() if both_classes else torch.as_tensor(0.5, device=self.device)
        except Exception: test_auroc = torch.as_tensor(0.5, device=self.device)

        rprec = self._rprecision(probs, labels)
        m = int(labels.sum().item()) if labels.numel() > 0 else 0
        prec_m, rec_m = self._precision_recall_at_k(probs, labels, max(1, m))
        best_f1, best_thr, (tp, fp, tn, fn) = self._best_f1(probs, labels)
        rec_at_p90 = self._recall_at_precision(probs, labels, 0.90)

        print(f"[TEST] base_rate={base_rate:.6f}  AUPRC={float(test_auprc):.4f}  AUROC={float(test_auroc):.4f}  "
              f"RPrec={rprec:.4f}  BestF1={best_f1:.4f} @thr={best_thr:.4f}  "
              f"Rec@P>=0.90={rec_at_p90:.4f}  Conf=~(TP/FP/TN/FN)={tp}/{fp}/{tn}/{fn}")

        self.log("test_acc", self.test_acc.compute(), on_step=False, on_epoch=True)
        self.log("test_auprc", test_auprc, on_step=False, on_epoch=True, prog_bar=True)
        self.log("test_auroc", test_auroc, on_step=False, on_epoch=True)
        self.log("test_rprecision", torch.as_tensor(rprec, device=self.device), on_step=False, on_epoch=True, prog_bar=True)
        self.log("test_precision_at_M", torch.as_tensor(prec_m, device=self.device), on_step=False, on_epoch=True)
        self.log("test_recall_at_M", torch.as_tensor(rec_m, device=self.device), on_step=False, on_epoch=True)

        # extra metrics (test)
        f1_macro, bal_acc, _ = self._macro_f1_balanced_acc(probs, labels, thr=best_thr)
        brier = self._brier(probs, labels)
        self.log("test_f1_macro", torch.as_tensor(f1_macro, device=self.device), on_step=False, on_epoch=True, prog_bar=True)
        self.log("test_bal_acc", torch.as_tensor(bal_acc, device=self.device), on_step=False, on_epoch=True)
        self.log("test_brier", torch.as_tensor(brier, device=self.device), on_step=False, on_epoch=True)

        # PR curve figure (test, + optional Neptune upload)
        try:
            out_path = Path(self.trainer.default_root_dir) / "artifacts" / "test_pr_curve.png"
            if self._save_pr_curve(probs, labels, out_path, title="Test"):
                print(f"Saved test PR curve -> {out_path}")
                self._upload_plot_to_neptune("artifacts/plots/test_pr_curve", out_path)
        except Exception as e:
            print(f"[PR] Failed to generate test PR curve: {e}")

        # export for CSV in main()
        self.last_test_probs = probs.cpu()
        self.last_test_labels = labels.cpu()

        self.test_acc.reset(); self.test_auprc.reset(); self.test_auroc.reset()
        if self._ema is not None:
            self._ema.restore(self.model)

    # ---------- aug ----------
    def _augment(self, x: torch.Tensor) -> torch.Tensor:
        """Training-only augmentation of a (B, C, T) batch.

        Applies a per-sample circular time shift (±5% of T), 20% channel dropout, a
        per-sample gain in [0.9, 1.1], and Gaussian noise with std ``opt.noise_std``.
        The shift is applied in place on ``x``.
        """
        if not self.training: return x
        B, C, T = x.shape

        # ±5% circular time-shift
        max_s = int(0.05 * T)
        if max_s > 0:
            shifts = torch.randint(-max_s, max_s + 1, (B,), device=x.device)
            for i, s in enumerate(shifts):
                if int(s) != 0:
                    x[i] = torch.roll(x[i], int(s), dims=-1)

        # Channel dropout (simulate dead sensors) 20%
        drop_mask = (torch.rand(B, C, 1, device=x.device) > 0.20).float()
        x = x * drop_mask

        # Gain jitter
        gains = torch.empty(B, 1, 1, device=x.device).uniform_(0.9, 1.1)
        x = x * gains

        # Gaussian noise
        sigma = self.hparams.opt.noise_std
        if sigma and sigma > 0: x = x + torch.randn_like(x) * sigma
        return x

    # ---------- optim ----------
    def configure_optimizers(self):
        """AdamW with linear warm-up + cosine decay (per epoch), or ReduceLROnPlateau on val_auprc."""
        opt = torch.optim.AdamW(self.parameters(), lr=self.hparams.opt.lr, weight_decay=self.hparams.opt.weight_decay)
        total_epochs = getattr(self.trainer, "max_epochs", 50) or 50
        if total_epochs < 0:
            # max_epochs=-1 means "no epoch limit"; without a horizon the cosine would cycle.
            total_epochs = 50
        warm = max(0, int(self.hparams.opt.warmup_epochs))
        if self.hparams.opt.cosine_after_warmup:
            def lr_lambda(epoch):
                if epoch < warm: return (epoch + 1) / max(1, warm)
                t = (epoch - warm) / max(1, total_epochs - warm)
                return 0.5 * (1 + math.cos(math.pi * t))
            sch = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda)
            return {"optimizer": opt, "lr_scheduler": {"scheduler": sch}}
        else:
            sch = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, mode="max", factor=0.5, patience=2)
            return {"optimizer": opt, "lr_scheduler": {"scheduler": sch, "monitor": "val_auprc"}}

# ============================== Data ==============================

from pnpl.datasets.libribrain2025.word_dataset import LibriBrainWord
from pnpl.datasets.libribrain2025.constants import RUN_KEYS
try:
    from pnpl.datasets.libribrain2025.base import LibriBrainBase
except Exception:
    LibriBrainBase = None

class BalancedBatchSampler(BatchSampler):
    """Yield index batches with a fixed positive/negative composition (positives oversampled).

    Each batch holds ``n_pos = max(1, round(batch_size * pos_fraction))`` positive indices
    and ``batch_size - n_pos`` negatives, shuffled together. Positives and negatives come
    from independently shuffled pools that are reshuffled when exhausted. When a pool holds
    fewer items than one batch needs from it, items are drawn with replacement, so every
    batch keeps its size and composition. One iteration yields exactly ``len(self)``
    batches (``(len(pos_idx) + len(neg_idx)) // batch_size``, at least 1).
    """
    def __init__(self, pos_idx: List[int], neg_idx: List[int], batch_size: int, pos_fraction: float = 0.1):
        assert 0.0 < pos_fraction < 1.0
        assert len(pos_idx) > 0 and len(neg_idx) > 0, "BalancedBatchSampler requires at least one pos and neg."
        self.p_idx, self.n_idx = pos_idx, neg_idx
        self.batch_size = batch_size
        self.n_pos = max(1, int(round(batch_size * pos_fraction)))
        self.n_neg = self.batch_size - self.n_pos
        total = len(pos_idx) + len(neg_idx)
        self._epoch_len = max(1, total // batch_size)

    @staticmethod
    def _draw(pool: List[int], cursor: int, n: int) -> Tuple[List[int], int]:
        """Take ``n`` items from the shuffled ``pool`` starting at ``cursor``; returns (items, new_cursor)."""
        if n <= 0:
            return [], cursor
        if len(pool) < n:
            # Too few unique items for one batch: sample with replacement to keep the batch full.
            return random.choices(pool, k=n), cursor
        if cursor + n > len(pool):
            random.shuffle(pool)
            cursor = 0
        return pool[cursor:cursor + n], cursor + n

    def __iter__(self) -> Iterator[List[int]]:
        p_shuf = list(self.p_idx); n_shuf = list(self.n_idx)
        random.shuffle(p_shuf); random.shuffle(n_shuf)
        pi = ni = 0
        # Finite pass consistent with __len__ (the previous infinite generator only worked
        # with consumers that stop after len(self) batches).
        for _ in range(self._epoch_len):
            pos, pi = self._draw(p_shuf, pi, self.n_pos)
            neg, ni = self._draw(n_shuf, ni, self.n_neg)
            batch = pos + neg
            random.shuffle(batch)
            yield batch

    def __len__(self) -> int: return self._epoch_len

class LibriBrainWordDataModule(pl.LightningDataModule):
    """Lightning DataModule for LibriBrain "watson" keyword spotting.

    ``setup`` builds the ``LibriBrainWord`` train / validation / test datasets
    (validation and test reuse the training split's ``channel_means`` /
    ``channel_stds`` when present). It optionally wraps them in
    :class:`ChannelMaskedDataset`, and derives cached positive / negative index lists
    for the training split. It then applies a deterministic, class-stratified
    ``train_fraction`` sub-sample and prepares the training sampler.

    NOTE(review): ``tmin`` / ``tmax`` and the run selection are not forwarded to
    ``LibriBrainWord`` (those arguments are commented out in ``setup``). They only
    feed the index-cache key and the fallback window length.
    """

    def __init__(
        self, data_path: str, tmin: float=-0.1, tmax: float=0.8, batch_size: int=256,
        num_workers: int=4, pin_memory: bool=True, standardize_train: bool=True,
        target_pos_fraction: float = 0.10,
        val_run_override: Optional[Tuple[str,str,str,str]] = None,
        test_run_override: Optional[Tuple[str,str,str,str]] = None,
        apply_channel_mask: bool = False,
        channel_indices: Optional[List[int]] = None,
        use_balanced_sampling: bool = True,
        train_fraction: float = 1.0,
        fraction_seed: int = 1234,
    ):
        super().__init__()
        self.data_path, self.tmin, self.tmax = data_path, tmin, tmax
        self.batch_size, self.num_workers, self.pin_memory = batch_size, num_workers, pin_memory
        self.standardize_train = standardize_train
        self.target_pos_fraction = target_pos_fraction
        self.val_run_override = val_run_override
        self.test_run_override = test_run_override
        self.apply_channel_mask = apply_channel_mask
        self.channel_indices = channel_indices
        self.use_balanced_sampling = use_balanced_sampling
        self.train_fraction = float(train_fraction)
        self.fraction_seed = int(fraction_seed)

        self.val_run: Optional[Tuple[str,str,str,str]] = None
        self.test_run: Optional[Tuple[str,str,str,str]] = None
        self.train_ds: Optional[Dataset] = None
        self.val_ds: Optional[Dataset] = None
        self.test_ds: Optional[Dataset] = None
        self._train_sampler: Optional[BalancedBatchSampler] = None

        # full training index lists, split by label
        self._pos_idx: List[int] = []
        self._neg_idx: List[int] = []
        # effective (possibly sub-sampled) index lists actually used for training
        self._eff_pos_idx: List[int] = []
        self._eff_neg_idx: List[int] = []
        self._eff_total: int = 0
        # set at the end of a successful ``setup`` so repeated calls are no-ops
        self._is_setup: bool = False

    def _available_runs(self):
        """Return the Sherlock run keys whose events and H5 files can be fetched.

        Falls back to every Sherlock run in ``RUN_KEYS`` when ``LibriBrainBase``
        could not be imported, or when no run could be fetched.
        """
        cands = [rk for rk in RUN_KEYS if rk[2].startswith('Sherlock')]
        if LibriBrainBase is None:
            return cands

        def _events_path(su, se, ta, ru):
            return os.path.join(self.data_path, ta, "derivatives", "events",
                                f"sub-{su}_ses-{se}_task-{ta}_run-{ru}_events.tsv")

        def _h5_path(su, se, ta, ru):
            return os.path.join(self.data_path, ta, "derivatives", "serialised",
                                f"sub-{su}_ses-{se}_task-{ta}_run-{ru}_proc-bads+headpos+sss+notch+bp+ds_meg.h5")

        avail = []
        for s, se, t, r in cands:
            try:
                LibriBrainBase.ensure_file_download(_events_path(s, se, t, r), data_path=self.data_path)
                LibriBrainBase.ensure_file_download(_h5_path(s, se, t, r), data_path=self.data_path)
                avail.append((s, se, t, r))
            except Exception:
                # Best effort: a run that cannot be fetched is simply left out.
                pass
        return avail or cands

    def _hash_key(self, train_runs: List[Tuple[str,str,str,str]]) -> str:
        """Short, stable key for the index cache derived from window and run list."""
        m = hashlib.sha256()
        payload = {"tmin": self.tmin, "tmax": self.tmax, "runs": sorted("_".join(rk) for rk in train_runs)}
        m.update(json.dumps(payload, sort_keys=True).encode("utf-8"))
        return m.hexdigest()[:16]

    def _cache_paths(self, key: str):
        """Return the index-cache file path for ``key``, creating its directory."""
        cache_dir = os.path.join(self.data_path, "_indices")
        os.makedirs(cache_dir, exist_ok=True)
        return os.path.join(cache_dir, f"watson_{key}.pt")

    def _try_dataset_labels_fast(self, ds: Dataset) -> Optional[List[int]]:
        """Read binary labels straight from a dataset attribute, if it has a usable one.

        An attribute is accepted only if it converts to ints, has one entry per
        sample (``len(ds)``), and contains nothing but 0/1. A vocabulary- or
        multi-class array would otherwise be mistaken for keyword labels, and its
        non-0/1 samples silently dropped. Returns None when no attribute qualifies.
        """
        try:
            expected = len(ds)
        except TypeError:
            expected = None
        for attr in ("labels", "y", "targets", "_labels", "_y", "_targets"):
            if not hasattr(ds, attr):
                continue
            lab = getattr(ds, attr)
            try:
                if torch.is_tensor(lab):
                    values = lab.reshape(-1).cpu().int().tolist()
                else:
                    values = [int(v) for v in lab]
            except Exception:
                continue
            if expected is not None and len(values) != expected:
                continue
            if not set(values) <= {0, 1}:
                continue
            return values
        return None

    @staticmethod
    def _load_index_cache(cache_file: str, n: int) -> Optional[Tuple[List[int], List[int]]]:
        """Load cached (pos, neg) index lists, or return None if missing, stale or unreadable."""
        if not os.path.exists(cache_file):
            return None
        try:
            obj = torch.load(cache_file, map_location="cpu")
            pos_idx = [int(i) for i in obj["pos_idx"]]
            neg_idx = [int(i) for i in obj["neg_idx"]]
            total = len(pos_idx) + len(neg_idx)
            # The cache key does not capture everything that defines the dataset, so a
            # cache that does not cover exactly this dataset is stale. Rebuild it rather
            # than training on a partial or mislabelled index set.
            if total != n:
                raise ValueError(f"cache covers {total} samples but dataset has {n}")
            if len(pos_idx) == 0 or len(neg_idx) == 0:
                raise ValueError("empty index lists after validation")
            if min(min(pos_idx), min(neg_idx)) < 0 or max(max(pos_idx), max(neg_idx)) >= n:
                raise ValueError("cached indices out of range")
            print(f"Index cache: loaded {len(pos_idx)} positives / {total} total.")
            return pos_idx, neg_idx
        except Exception as e:
            print(f"Index cache invalid ({type(e).__name__}: {e}); rebuilding indices…")
            return None

    @staticmethod
    def _save_index_cache(cache_file: str, pos_idx: List[int], neg_idx: List[int]) -> None:
        """Persist index lists; a failed write is reported and only costs a rebuild later."""
        try:
            torch.save({"pos_idx": pos_idx, "neg_idx": neg_idx}, cache_file)
        except Exception as e:
            print(f"Index cache: could not write {cache_file} ({type(e).__name__}: {e}).")

    @staticmethod
    def _scan_labels(ds: Dataset, num_workers: int, report_every: int = 50_000) -> Tuple[List[int], List[int]]:
        """Iterate ``ds`` once and split sample indices by label (``== 1`` is positive)."""
        print(f"Scanning training labels (num_workers={num_workers})…")
        pos_idx, neg_idx = [], []
        loader = DataLoader(
            ds, batch_size=2048, shuffle=False,
            num_workers=num_workers, pin_memory=False,
            # single pass: nothing to gain from keeping workers alive afterwards
            persistent_workers=False,
            prefetch_factor=2 if num_workers > 0 else None,
            collate_fn=collate_label_only_xy,
        )
        idx = 0
        next_report = report_every
        for ys in loader:
            for y in ys:
                (pos_idx if y == 1 else neg_idx).append(idx)
                idx += 1
            # idx advances by a whole batch at a time, so an exact ``idx % report_every``
            # test almost never fires; report whenever a threshold has been crossed.
            if idx >= next_report:
                print(f"… scanned {idx} samples")
                next_report = (idx // report_every + 1) * report_every
        return pos_idx, neg_idx

    def _build_pos_neg_indices(self, ds: Dataset, cache_file: str) -> Tuple[List[int], List[int]]:
        """Return (positive, negative) sample indices of ``ds`` for the keyword label.

        Resolution order:
        1. the on-disk cache, only if it covers exactly ``len(ds)`` in-range indices;
        2. a validated label attribute on the dataset;
        3. a full pass over the dataset, parallel first and single-process on failure.

        Freshly built lists are written back to ``cache_file`` (best effort).
        """
        cached = self._load_index_cache(cache_file, len(ds))
        if cached is not None:
            return cached

        lbls = self._try_dataset_labels_fast(ds)
        if lbls is not None:
            pos_idx = [i for i, y in enumerate(lbls) if int(y) == 1]
            neg_idx = [i for i, y in enumerate(lbls) if int(y) == 0]
            print(f"Index fast-path: found {len(pos_idx)} positives / {len(lbls)} total.")
            self._save_index_cache(cache_file, pos_idx, neg_idx)
            return pos_idx, neg_idx

        try:
            pos_idx, neg_idx = self._scan_labels(ds, self.num_workers)
        except Exception as e:
            print(f"[Label scan] parallel failed ({type(e).__name__}: {e}). Falling back to single-process.")
            pos_idx, neg_idx = self._scan_labels(ds, 0)

        total = len(pos_idx) + len(neg_idx)
        print(f"Found {len(pos_idx)} positives / {total} total ({len(pos_idx)/max(1,total):.6f}).")
        self._save_index_cache(cache_file, pos_idx, neg_idx)
        return pos_idx, neg_idx

    def _subsample_indices(self, pos_idx: List[int], neg_idx: List[int], frac: float, seed: int) -> Tuple[List[int], List[int]]:
        """Deterministically keep ``frac`` of each class (stratified, seeded by ``seed``).

        ``frac`` is clamped to [0.0001, 1.0]; values >= 0.9999 return full copies.
        Each class keeps at least one index.
        """
        frac = float(max(0.0001, min(1.0, frac)))
        if frac >= 0.9999:  # full
            return pos_idx[:], neg_idx[:]
        rng = random.Random(seed)
        k_pos = max(1, int(round(len(pos_idx) * frac)))
        k_neg = max(1, int(round(len(neg_idx) * frac)))
        pos_sub = rng.sample(pos_idx, k_pos) if k_pos < len(pos_idx) else pos_idx[:]
        neg_sub = rng.sample(neg_idx, k_neg) if k_neg < len(neg_idx) else neg_idx[:]
        return pos_sub, neg_sub

    def _infer_window_seconds(self) -> float:
        """Best-effort length, in seconds, of one sample's window.

        Checks each dataset in order (train, val, test), unwrapping
        ``ChannelMaskedDataset``, and tries in turn:
        - a ``tmax - tmin`` attribute pair;
        - a ``window_seconds`` / ``win_seconds`` / ``segment_seconds`` attribute;
        - the first sample's last-axis length divided by ``sfreq`` / ``sample_rate``.

        Falls back to the configured ``tmax - tmin``.
        """
        for ds in (self.train_ds, self.val_ds, self.test_ds):
            if ds is None:
                continue
            # The channel-mask wrapper only selects channels; the timing metadata lives
            # on the wrapped dataset and would otherwise never be found.
            if isinstance(ds, ChannelMaskedDataset):
                ds = ds.base_dataset
            if hasattr(ds, "tmin") and hasattr(ds, "tmax"):
                try:
                    return float(getattr(ds, "tmax") - getattr(ds, "tmin"))
                except Exception:
                    pass
            for attr in ("window_seconds", "win_seconds", "segment_seconds"):
                if hasattr(ds, attr):
                    try:
                        return float(getattr(ds, attr))
                    except Exception:
                        pass
            # sfreq heuristic
            if hasattr(ds, "sfreq") or hasattr(ds, "sample_rate"):
                try:
                    x0, _ = ds[0]
                    sr = float(getattr(ds, "sfreq", getattr(ds, "sample_rate", 250.0)))
                    return float(x0.shape[-1]) / sr
                except Exception:
                    pass
        # fallback to config window
        return float(self.tmax - self.tmin)

    def setup(self, stage: Optional[str]=None):
        """Build datasets, training indices and the training sampler.

        ``stage`` is accepted for Lightning compatibility but ignored: all three
        splits are always prepared. The method is idempotent. ``main`` calls it
        directly and the trainer calls it again for fit and test; later calls return
        immediately instead of rebuilding datasets and re-probing every run's files.

        Raises:
            RuntimeError: if the training split lacks positives or negatives, or if
                the ``train_fraction`` sub-sample is empty.
        """
        if self._is_setup:
            return

        all_runs = list(self._available_runs())

        # Fixed splits
        self.val_run  = self.val_run_override  or ('0','12','Sherlock4','1')
        self.test_run = self.test_run_override or ('0','12','Sherlock5','1')

        # exclude val/test run from training (only used for the index-cache key)
        train_runs = [rk for rk in all_runs if rk not in (self.val_run, self.test_run)]

        # datasets
        self.train_ds = LibriBrainWord(
            self.data_path,
            partition="train",
            # include_run_keys=train_runs,
            # tmin=self.tmin,
            # tmax=self.tmax,
            keyword_detection="watson",
            preload_files=False, include_info=False,
            standardize=self.standardize_train
        )
        self.val_ds = LibriBrainWord(
            self.data_path,
            partition="validation",
            # include_run_keys=[self.val_run],
            # tmin=self.tmin,
            # tmax=self.tmax,
            keyword_detection="watson",
            preload_files=False,
            include_info=False,
            standardize=True,
            channel_means=getattr(self.train_ds, "channel_means", None),
            channel_stds=getattr(self.train_ds, "channel_stds", None),
        )
        self.test_ds = LibriBrainWord(
            self.data_path,
            partition="test",
            # include_run_keys=[self.test_run],
            # tmin=self.tmin,
            # tmax=self.tmax,
            keyword_detection="watson",
            preload_files=False, include_info=False,
            standardize=True,
            channel_means=getattr(self.train_ds, "channel_means", None),
            channel_stds=getattr(self.train_ds, "channel_stds", None),
        )

        # optional channel mask wrapping
        if self.apply_channel_mask:
            ch_idx = self.channel_indices if self.channel_indices is not None else SENSORS_SPEECH_MASK
            self.train_ds = ChannelMaskedDataset(self.train_ds, channel_indices=ch_idx)
            self.val_ds   = ChannelMaskedDataset(self.val_ds,   channel_indices=ch_idx)
            self.test_ds  = ChannelMaskedDataset(self.test_ds,  channel_indices=ch_idx)

        # derive in_channels from the first training sample (306 if that fails)
        try:
            x0, _ = self.train_ds[0]
            self.num_channels = int(x0.shape[0])
        except Exception:
            self.num_channels = 306

        key = self._hash_key(train_runs)
        cache_file = self._cache_paths(key)
        pos_idx, neg_idx = self._build_pos_neg_indices(self.train_ds, cache_file)
        if len(pos_idx) == 0 or len(neg_idx) == 0:
            raise RuntimeError("Need both positive and negative samples in training set.")
        self._pos_idx = pos_idx; self._neg_idx = neg_idx

        # apply train_fraction sub-sampling (deterministic, per class)
        self._eff_pos_idx, self._eff_neg_idx = self._subsample_indices(
            self._pos_idx, self._neg_idx, self.train_fraction, self.fraction_seed
        )
        self._eff_total = len(self._eff_pos_idx) + len(self._eff_neg_idx)
        if self._eff_total == 0:
            raise RuntimeError("train_fraction produced an empty effective train set.")

        # sampler
        if self.use_balanced_sampling:
            self._train_sampler = BalancedBatchSampler(
                self._eff_pos_idx, self._eff_neg_idx,
                batch_size=self.batch_size, pos_fraction=self.target_pos_fraction
            )
        else:
            self._train_sampler = None

        # handy stats
        self.window_seconds = self._infer_window_seconds()
        self._is_setup = True

    def estimate_label_stats(self, sample: int = 200_000) -> Tuple[int,int]:
        """Return ``(positives, total)`` describing the full training split's base rate.

        With index lists available (after ``setup``), counts are taken from them and,
        when ``sample`` is smaller than the split, rescaled to a total of ``sample``
        (at least one positive). Before ``setup``, up to ``min(sample, 20_000)``
        random samples are labelled directly.
        """
        pos_idx = getattr(self, "_pos_idx", None) or []
        neg_idx = getattr(self, "_neg_idx", None) or []
        if pos_idx or neg_idx:
            pos = len(pos_idx)
            total = len(pos_idx) + len(neg_idx)
            if sample < total and total > 0:
                frac = sample / total
                pos = max(1, int(round(pos * frac))); total = sample
            return pos, total
        n = len(self.train_ds)
        k = max(0, min(sample, 20_000, n))
        idxs = random.sample(range(n), k=k)
        pos = sum(int(self.train_ds[i][1]) for i in idxs)
        return pos, k

    def _loader_kwargs(self) -> dict:
        """DataLoader options shared by the train / val / test loaders."""
        return dict(
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            persistent_workers=(self.num_workers > 0),
            prefetch_factor=2 if self.num_workers > 0 else None,
        )

    def train_dataloader(self):
        """Training loader that honours ``train_fraction`` in both sampling modes.

        With balanced sampling, the :class:`BalancedBatchSampler` built in ``setup``
        drives batching. Otherwise a shuffled loader is used. When ``train_fraction``
        removed samples, that loader is restricted to the effective sub-sample.
        """
        if self._train_sampler is not None:
            return DataLoader(self.train_ds, batch_sampler=self._train_sampler, **self._loader_kwargs())
        n_full = len(self._pos_idx) + len(self._neg_idx)
        if 0 < self._eff_total < n_full:
            eff_idx = sorted(self._eff_pos_idx + self._eff_neg_idx)
            return DataLoader(
                self.train_ds, batch_size=self.batch_size,
                sampler=torch.utils.data.SubsetRandomSampler(eff_idx),
                **self._loader_kwargs(),
            )
        return DataLoader(self.train_ds, batch_size=self.batch_size, shuffle=True, **self._loader_kwargs())

    def val_dataloader(self):
        """Validation loader, in dataset order."""
        return DataLoader(self.val_ds, batch_size=self.batch_size, shuffle=False, **self._loader_kwargs())

    def test_dataloader(self):
        """Test loader, in dataset order."""
        return DataLoader(self.test_ds, batch_size=self.batch_size, shuffle=False, **self._loader_kwargs())

# ============================== Utility: run dirs & metadata ==============================

def get_git_commit() -> str:
    """Return the short hash of the current git ``HEAD``, or ``"n/a"`` if unavailable.

    git's stderr is discarded. Without this, running outside a repository would
    print git's error message once for every run of the sweep.
    """
    try:
        out = subprocess.check_output(
            ["git", "rev-parse", "--short", "HEAD"],
            stderr=subprocess.DEVNULL,
        )
        return out.decode().strip()
    except Exception:
        # No git binary, not a repository, or undecodable output: best effort only.
        return "n/a"

def system_info() -> dict:
    """Collect interpreter, platform, library and GPU details for run provenance.

    Returns a JSON-serialisable dict. ``time_utc`` is a naive ISO-8601 UTC
    timestamp with a trailing ``"Z"``.
    """
    from datetime import timezone  # module level only imports the ``datetime`` class

    cuda_ok = torch.cuda.is_available()
    # ``datetime.utcnow()`` is deprecated since Python 3.12. Build the same
    # "<naive iso>Z" string from an aware UTC timestamp instead.
    now_utc = datetime.now(timezone.utc).replace(tzinfo=None)
    return {
        "python": platform.python_version(),
        "platform": platform.platform(),
        "torch": torch.__version__,
        "pytorch_lightning": pl.__version__,
        "cuda_available": cuda_ok,
        "cuda_device_count": torch.cuda.device_count(),
        "cuda_name": torch.cuda.get_device_name(0) if cuda_ok else "cpu",
        "git_commit": get_git_commit(),
        "time_utc": now_utc.isoformat() + "Z",
    }

# ============================== Train / Test (SCALING SWEEP) ==============================

def _neptune_run_handle(neptune_logger):
    """Return the Neptune run behind a Lightning Neptune logger, or ``None``.

    The logger's attribute name varies between versions, so both ``experiment``
    and ``run`` are tried. Access errors are printed, not raised.
    """
    if not neptune_logger:
        return None
    try:
        return getattr(neptune_logger, "experiment", None) or getattr(neptune_logger, "run", None)
    except Exception as e:
        print(f"Neptune: could not access run: {e}")
        return None


def _write_test_probabilities_csv(csv_path: Path, probs, labels) -> bool:
    """Write ``index,prob,label`` rows for the test split to ``csv_path``.

    ``probs`` / ``labels`` are the model's ``last_test_probs`` / ``last_test_labels``.
    They are used via ``.numel()`` / ``.tolist()``, so presumably 1-D tensors —
    TODO confirm. Returns False, writing nothing, when either is missing or their
    sizes differ.
    """
    import csv

    if probs is None or labels is None or probs.numel() != labels.numel():
        return False
    with csv_path.open("w", newline="") as f:
        w = csv.writer(f)
        w.writerow(["index", "prob", "label"])  # prob is post-calibration (prior+temperature)
        for i, (p, y) in enumerate(zip(probs.tolist(), labels.tolist())):
            w.writerow([i, float(p), int(y)])
    return True


def _write_sweep_summary(sweep_dir: Path, rows: List[dict]) -> None:
    """Write the sweep-level summary as JSON and, when there are rows, as CSV."""
    import csv

    sweep_dir.mkdir(parents=True, exist_ok=True)
    (sweep_dir / "scaling_summary.json").write_text(json.dumps(rows, indent=2))
    if not rows:
        print("[Sweep] No runs recorded; CSV summary not written.")
        return
    try:
        with (sweep_dir / "scaling_summary.csv").open("w", newline="") as f:
            w = csv.DictWriter(f, fieldnames=list(rows[0].keys()))
            w.writeheader()
            w.writerows(rows)
        print(f"[Sweep] Wrote summary -> {sweep_dir / 'scaling_summary.csv'}")
    except Exception as e:
        print(f"[Sweep] Could not write CSV summary: {e}")


def main():
    """Run the train-fraction scaling sweep for LibriBrain "watson" keyword spotting.

    For each fraction in ``FRACTIONS``:
    - build a data module using that share of the training set (validation/test
      unchanged);
    - train ``WatsonKeywordPL`` with checkpointing and early stopping on ``val_auprc``;
    - evaluate on the test split;
    - write config, system info, metrics and test probabilities under
      ``$OUT_DIR/watson_meg_<utc timestamp>_fracNN``.

    A sweep-level summary (JSON + CSV) goes to
    ``$OUT_DIR/watson_meg_<utc timestamp>_SWEEP_SUMMARY``.

    Environment: ``DATA_PATH`` (default ``"dataset"``) and ``OUT_DIR`` (default
    ``"runs"``). Neptune logging is used only when ``make_neptune_logger`` returns
    a logger.
    """
    import gc
    from datetime import timezone

    # ---- core hyper-params (unchanged defaults) ----
    data_path   = os.getenv("DATA_PATH", "dataset")
    tmin, tmax  = -0.05, 1
    epochs      = 10
    batch_size  = 256
    lr          = 2.5e-4
    num_workers = 4
    precision   = "bf16-mixed"
    devices     = 1

    # Sampling & priors
    target_pos_fraction = 0.10   # balanced batch pos-rate
    VAL_RUN  = ('0','12','Sherlock4','1')   # fixed for reproducibility
    TEST_RUN = ('0','12','Sherlock5','1')

    # Fractions to sweep
    FRACTIONS = [0.05, 0.10, 0.20, 0.40, 0.60, 0.80, 1.00]

    torch.set_float32_matmul_precision('high')

    # ---- run root ----
    run_root = Path(os.getenv("OUT_DIR", "runs"))
    ts = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")

    scaling_summary = []

    for frac in FRACTIONS:
        # Re-seed for every fraction so each run is reproducible on its own, instead
        # of depending on the RNG state left behind by the previous fractions.
        pl.seed_everything(42, workers=True)

        tag = f"frac{int(round(frac*100)):02d}"
        run_dir = run_root / f"watson_meg_{ts}_{tag}"
        ckpt_dir = run_dir / "checkpoints"
        art_dir  = run_dir / "artifacts"
        for d in (ckpt_dir, art_dir):
            d.mkdir(parents=True, exist_ok=True)

        # ---- data (per fraction) ----
        dm = LibriBrainWordDataModule(
            data_path, tmin, tmax, batch_size, num_workers,
            standardize_train=True, target_pos_fraction=target_pos_fraction,
            val_run_override=VAL_RUN, test_run_override=TEST_RUN,
            apply_channel_mask=False, channel_indices=None,
            use_balanced_sampling=True,
            train_fraction=frac, fraction_seed=42,
        )
        dm.setup()

        # Estimate base prevalence from TRAIN universe (not balanced, full train universe)
        pos, total = dm.estimate_label_stats(sample=200_000)
        base_rate = max(1, pos) / max(2, total)  # guard tiny denominators
        print(f"[Label stats] pos={pos} total={total}  π={base_rate:.6f}  sampler_pos={target_pos_fraction:.2f}  (train_fraction={frac:.2f})")

        # ---- compute hours accounting ----
        # Ask the training loader itself, so the count matches what the trainer will
        # iterate (balanced-sampler length, or the shuffled loader's batch count).
        batches_per_epoch = len(dm.train_dataloader())
        samples_per_epoch = batches_per_epoch * dm.batch_size
        window_seconds = getattr(dm, "window_seconds", float(tmax - tmin))
        hours_per_epoch = (samples_per_epoch * window_seconds) / 3600.0
        hours_total = hours_per_epoch * epochs

        # Log-odds offset between the target prior (base_rate) and the sampler's
        # training prior; clamping keeps the log finite in degenerate cases.
        br = min(max(base_rate, 1e-12), 1.0 - 1e-12)
        prior_bias = float(math.log((br / (1 - br)) / (target_pos_fraction / (1 - target_pos_fraction))))

        # ---- loggers (per fraction) ----
        csv_logger = CSVLogger(save_dir=str(run_dir), name="pl_logs", version="")
        neptune_logger = make_neptune_logger(run_name=f"watson-meg-generalization-{tag}")
        loggers = [csv_logger] + ([neptune_logger] if neptune_logger else [])

        # Neptune hyperparams
        if neptune_logger:
            neptune_logger.log_hyperparams({
                "data_path": data_path, "tmin": tmin, "tmax": tmax,
                "batch_size": batch_size, "lr": lr, "precision": precision,
                "base_rate": base_rate, "target_pos_fraction": target_pos_fraction,
                "loss": "focal(alpha=0.95,gamma=2.0)+pairwise(lambda=1.0)",
                "pooling": "temporal_attention",
                "schedule": "warmup+cosine",
                "val_run": "_".join(VAL_RUN), "test_run": "_".join(TEST_RUN),
                "norm": "GroupNorm", "ema": "0.999",
                "train_fraction": frac,
                "window_seconds": window_seconds,
                "batches_per_epoch": batches_per_epoch,
                "samples_per_epoch": samples_per_epoch,
                "hours_per_epoch": hours_per_epoch,
                "hours_total": hours_total,
                "prior_bias": prior_bias,
            })

        # persist run config & system info locally
        (run_dir / "config.json").write_text(json.dumps({
            "data_path": data_path, "tmin": tmin, "tmax": tmax,
            "batch_size": batch_size, "lr": lr, "precision": precision,
            "epochs": epochs, "devices": devices,
            "val_run": VAL_RUN, "test_run": TEST_RUN,
            "target_pos_fraction": target_pos_fraction,
            "train_fraction": frac,
            "window_seconds": window_seconds,
            "batches_per_epoch": batches_per_epoch,
            "samples_per_epoch": samples_per_epoch,
            "hours_per_epoch": hours_per_epoch,
            "hours_total": hours_total,
        }, indent=2))
        (run_dir / "system_info.json").write_text(json.dumps(system_info(), indent=2))

        print(f"[Scaling] {tag} -> batches/epoch={batches_per_epoch}  samples/epoch={samples_per_epoch}  "
              f"window_s={window_seconds:.3f}  hours/epoch={hours_per_epoch:.3f}  hours_total={hours_total:.3f}")

        # ---- model ----
        model = WatsonKeywordPL(
            in_channels=getattr(dm, "num_channels", 306),
            opt=OptimConfig(lr=lr, weight_decay=1e-4, warmup_epochs=2, cosine_after_warmup=True, noise_std=0.01),
            loss_mode="focal_pairwise", pairwise_lambda=1.0,
            pi_train=target_pos_fraction, pi_target=base_rate
        )

        # ---- callbacks ----
        ckpt_cb = ModelCheckpoint(monitor="val_auprc", mode="max", save_top_k=1,
                                  filename=f"best-val-auprc-ema-calib-{tag}", dirpath=str(ckpt_dir))
        callbacks = [
            ckpt_cb,
            EarlyStopping(monitor="val_auprc", mode="max", patience=3, min_delta=1e-3),
            LearningRateMonitor(logging_interval="epoch"),
        ]

        # ---- trainer ----
        trainer = pl.Trainer(
            max_epochs=epochs,
            precision=precision,
            devices=devices,
            accelerator="gpu" if torch.cuda.is_available() else "cpu",
            callbacks=callbacks,
            logger=loggers,
            log_every_n_steps=25,
            gradient_clip_val=1.0,
            detect_anomaly=False,
            default_root_dir=str(run_dir),
        )

        # ---- train ----
        t0 = time.time()
        trainer.fit(model, datamodule=dm)
        fit_secs = time.time() - t0

        # save "last" checkpoint (state after training) locally
        last_ckpt_path = ckpt_dir / f"last_{tag}.ckpt"
        trainer.save_checkpoint(str(last_ckpt_path))

        # optionally upload best ckpt to Neptune (still saved locally regardless)
        if ckpt_cb.best_model_path:
            run_handle = _neptune_run_handle(neptune_logger)
            if run_handle is not None:
                try:
                    run_handle["artifacts/checkpoints/best"].upload(ckpt_cb.best_model_path)
                    print(f"Neptune: uploaded best checkpoint -> {ckpt_cb.best_model_path}")
                except Exception as e:
                    print(f"Neptune: failed to upload checkpoint: {e}")

        # ---- test ----
        # NOTE(review): ``trainer.test(model, ...)`` evaluates the in-memory weights
        # left after fit (the last epoch run, e.g. after early stopping), not the best
        # ``val_auprc`` checkpoint. Pass ``ckpt_path="best"`` if that is intended.
        t1 = time.time()
        trainer.test(model, datamodule=dm)
        test_secs = time.time() - t1

        # ---- export test probabilities CSV ----
        probs = getattr(model, "last_test_probs", None)
        labels = getattr(model, "last_test_labels", None)
        csv_path = art_dir / f"test_probabilities_{tag}.csv"
        if _write_test_probabilities_csv(csv_path, probs, labels):
            print(f"Saved test probabilities -> {csv_path}")
        else:
            print("[WARN] Test probabilities unavailable; CSV not written.")

        # ---- write a compact metrics summary ----
        summary = {
            "train_fraction": frac,
            "fit_seconds": round(fit_secs, 2),
            "test_seconds": round(test_secs, 2),
            "best_ckpt": ckpt_cb.best_model_path,
            "last_ckpt": str(last_ckpt_path),
            "window_seconds": window_seconds,
            "batches_per_epoch": batches_per_epoch,
            "samples_per_epoch": samples_per_epoch,
            "hours_per_epoch": hours_per_epoch,
            "hours_total": hours_total,
        }
        try:
            final_metrics = {k: (v.item() if hasattr(v, 'item') else v)
                             for k, v in trainer.callback_metrics.items()}
            summary["final_metrics"] = final_metrics
        except Exception:
            pass
        (run_dir / "metrics_summary.json").write_text(json.dumps(summary, indent=2))

        print("\n=== RUN ARTIFACTS ===")
        print(f"Run dir: {run_dir}")
        print(f" - Best checkpoint: {ckpt_cb.best_model_path or 'n/a'}")
        print(f" - Last checkpoint: {last_ckpt_path}")
        print(f" - CSV metrics (per-epoch): {Path(csv_logger.log_dir) / 'metrics.csv'}")
        print(f" - Test probabilities: {csv_path if csv_path.exists() else 'n/a'}\n")

        # Collect for a global summary
        auprc = summary.get("final_metrics", {}).get("test_auprc", None)
        auroc = summary.get("final_metrics", {}).get("test_auroc", None)
        scaling_summary.append({
            "fraction": frac,
            "hours_total": hours_total,
            "hours_per_epoch": hours_per_epoch,
            "batches_per_epoch": batches_per_epoch,
            "samples_per_epoch": samples_per_epoch,
            "test_auprc": float(auprc) if auprc is not None else None,
            "test_auroc": float(auroc) if auroc is not None else None,
            "run_dir": str(run_dir),
        })

        # Neptune does not stop runs automatically in interactive sessions. Stop this
        # fraction's run explicitly so the sweep does not leave one open run per fraction.
        run_handle = _neptune_run_handle(neptune_logger)
        if run_handle is not None and hasattr(run_handle, "stop"):
            try:
                run_handle.stop()
            except Exception as e:
                print(f"Neptune: failed to stop run: {e}")

        # free memory between runs
        del trainer, model, dm
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    # write a top-level sweep summary next to the timestamped run dirs of this sweep
    _write_sweep_summary(run_root / f"watson_meg_{ts}_SWEEP_SUMMARY", scaling_summary)


if __name__ == "__main__":
    main()
