# 03 – Transformer Fine-Tuning (Linear Probe, One-Hot Path)

This notebook trains a **linear probe** on top of a frozen encoder for binary regulatory DNA classification.

**Design**
- **Input format**: one-hot tensors `(N, L, 4)` saved in `data/processed/{train,val,test}.npz` with keys `X` and `y`.
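- **Projection**: a small `Linear(4 → hidden)` is applied **only** on the one-hot path, i.e. when no tokenizer is loaded. On that path the projected one-hot features are mean-pooled and classified directly; the encoder itself is not called.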
- **Encoder**: optional Hugging Face model selected via `MODEL_NAME`; if `MODEL_NAME` is empty or loading fails, a small placeholder encoder (`TinyEnc`) is used for quick tests.
- **Safety**: DataLoaders use `num_workers=0` (notebook-safe); tensors are moved to device inside the training loop.
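- **Outputs**: per-epoch checkpoints in `results/checkpoints/`, best model in `results/probe_best.pt`, test metrics appended to `results/metrics.csv`, per-epoch metrics appended to `results/epoch_metrics.csv`, plus `results/preproc_summary.csv` (label counts) and `results/run_config.json` (run settings).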


In [76]:
import os, sys, math, random, time
from pathlib import Path
from typing import Dict, Any

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import roc_auc_score, average_precision_score
from tqdm.auto import tqdm

try:
    from transformers import AutoModel, AutoTokenizer, AutoConfig, get_linear_schedule_with_warmup
except Exception:
    AutoModel = AutoTokenizer = AutoConfig = get_linear_schedule_with_warmup = None

# Seed every random number generator the notebook uses so repeated runs are comparable.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Prefer the GPU when one is visible (seeding all CUDA generators too); otherwise run on CPU.
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')
print('Device:', DEVICE)


Device: cpu


## Config

In [77]:
# --- Filesystem layout (output folders are created on first run) ---
PROC = Path('data/processed')
RESULTS = Path('results')
RESULTS.mkdir(parents=True, exist_ok=True)
CKPT_DIR = RESULTS / 'checkpoints'
CKPT_DIR.mkdir(parents=True, exist_ok=True)

# --- Encoder selection ---
# PRETRAINED_MODEL is read from the environment; an empty value skips Hugging Face
# entirely and the notebook falls back to TinyEnc for one-hot runs.
MODEL_NAME = os.environ.get('PRETRAINED_MODEL', '')
FULL_FINETUNE = False   # True unfreezes the encoder
MAX_LEN = 512           # only consulted when a tokenizer is in use

# --- Optimisation settings ---
EPOCHS = 3
BATCH_TRAIN = 64
BATCH_EVAL = 128
HEAD_LR = 1e-3
ENCODER_LR = 1e-5
WEIGHT_DECAY = 1e-2


## Data loading (.npz with X/y)

In [78]:
class SequenceDataset(Dataset):
    """In-memory dataset backed by one ``.npz`` split file.

    Two key layouts are accepted: ``X``/``y`` (this project's one-hot format,
    X=(N,L,4), y=(N,)) or ``inputs``/``labels`` (text inputs, not used on the
    one-hot path). Both arrays are read fully into memory at construction time
    and the archive is closed before ``__init__`` returns.

    Raises:
        ValueError: if the file contains neither key pair.
        AssertionError: if inputs and labels differ in length.
    """

    def __init__(self, path: str):
        # NOTE(review): allow_pickle=True is kept so object-dtype arrays still load,
        # but unpickling can execute arbitrary code - only open trusted files.
        z = np.load(path, allow_pickle=True)
        try:
            if 'X' in z and 'y' in z:
                self.inputs, self.labels = z['X'], z['y']
            elif 'inputs' in z and 'labels' in z:
                # Fallback: text inputs (not used in one-hot path)
                self.inputs, self.labels = z['inputs'], z['labels']
            else:
                raise ValueError(f'Unrecognized schema in {path}. Expected X/y.')
        finally:
            # An .npz archive keeps its file handle open until closed explicitly;
            # the arrays above are already materialised, so release it here.
            close = getattr(z, 'close', None)
            if close is not None:
                close()
        assert len(self.inputs) == len(self.labels), 'length mismatch'

    def __len__(self):
        """Number of examples in the split."""
        return len(self.labels)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Return ``{'input': <raw array element>, 'label': <python int>}`` for one example."""
        return {'input': self.inputs[idx], 'label': int(self.labels[idx])}

# Resolve where each split lives; a split whose file is missing is represented by None.
train_path, val_path, test_path = (PROC / f'{split}.npz' for split in ('train', 'val', 'test'))

train_ds = None if not train_path.exists() else SequenceDataset(str(train_path))
val_ds   = None if not val_path.exists()   else SequenceDataset(str(val_path))
test_ds  = None if not test_path.exists()  else SequenceDataset(str(test_path))

# One-line size report (a missing or empty split prints as None).
print('Datasets:',
      *(f"{tag}={len(ds) if ds else None}"
        for tag, ds in (('train', train_ds), ('val', val_ds), ('test', test_ds))))


Datasets: train=24 val=13 test=13


In [79]:
# Print label class counts for quick debug (why AUROC may be degenerate)
import numpy as np
import csv

def get_labels(ds, path):
    """Return the integer label vector of a split, or ``None`` when unavailable.

    Args:
        ds: dataset object for the split, or ``None`` if the split was not loaded.
            Its ``labels`` attribute is used when present.
        path: location of the split's ``.npz`` file; only read when ``ds`` has no
            ``labels`` attribute. The file is searched for ``y`` first, then ``labels``.

    Returns:
        ``np.ndarray`` of ints, or ``None`` if ``ds`` is ``None`` or the labels
        cannot be read from ``path`` (missing file, missing keys, unreadable archive).
    """
    if ds is None:
        return None

    labels = getattr(ds, 'labels', None)
    if labels is None:
        try:
            # Index the archive directly rather than via ``.get`` so this does not
            # depend on the archive object implementing the Mapping interface, and
            # so a file with neither key lands in the except branch below.
            with np.load(path, allow_pickle=True) as z:
                labels = z['y'] if 'y' in z else z['labels']
        except Exception:
            # Best effort: this helper only feeds a debug summary.
            return None
    return np.asarray(labels).astype(int)

train_labels = get_labels(train_ds, train_path)
val_labels   = get_labels(val_ds, val_path)
test_labels  = get_labels(test_ds, test_path)

splits = dict(train=train_labels, val=val_labels, test=test_labels)

# The largest class index seen in any non-empty split fixes the number of CSV columns.
max_class = max([0] + [int(arr.max()) for arr in splits.values()
                       if arr is not None and arr.size > 0])
ncols = max_class + 1

# One (split name, per-class counts) row per split; a missing split counts as all zeros.
rows = [
    (name,
     [0] * ncols if labels is None
     else [int(x) for x in np.bincount(labels, minlength=ncols)])
    for name, labels in splits.items()
]

# Console report, one line per (split, class).
for name, counts in rows:
    print(*(f"{name} class {i}: {c}" for i, c in enumerate(counts)), sep='\n')

# CSV summary; the file is rewritten from scratch on every run.
out_path = RESULTS / 'preproc_summary.csv'
header = ['split', *(f'class_{i}' for i in range(ncols))]
with open(out_path, 'w', newline='') as f:
    w = csv.writer(f)
    w.writerows([header] + [[name, *counts] for name, counts in rows])

print(f"Wrote label counts to {out_path}")

# Positive-class weight for BCEWithLogitsLoss (neg/pos ratio). It stays None unless the
# training labels contain exactly two distinct values and at least one label equal to 1.
# Later cells read this module-level name.
pos_weight = None
try:
    if train_labels is not None:
        uniq = np.unique(train_labels)
        if uniq.size != 2:
            print('pos_weight only computed for binary problems; found classes:', uniq)
        else:
            neg, pos = (int(np.count_nonzero(train_labels == cls)) for cls in (0, 1))
            if pos <= 0:
                print('No positive examples in train set; pos_weight not set')
            else:
                pos_weight = torch.tensor(neg / pos, dtype=torch.float32).to(DEVICE)
                print(f"Computed pos_weight={pos_weight.item():.4f} (neg={neg}, pos={pos})")
except Exception as e:
    # Best effort: the loss simply runs unweighted when this fails.
    print('Could not compute pos_weight:', e)


train class 0: 12
train class 1: 12
val class 0: 6
val class 1: 7
test class 0: 7
test class 1: 6
Wrote label counts to results/preproc_summary.csv
Computed pos_weight=1.0000 (neg=12, pos=12)


## Collate function (one-hot aware, tokenizer-aware)

In [80]:
def make_collate_fn(tokenizer=None, max_length=512):
    """Build a DataLoader ``collate_fn`` that adapts to the kind of input it receives.

    Each batch element is a dict with ``'input'`` and ``'label'``. The returned
    batch always carries ``'label'`` as a float tensor, plus one of:

    * ``input_ids`` / ``attention_mask`` - string inputs with a tokenizer, or 1-D id arrays
      (zero-padded on the right; positions equal to 0 are masked out);
    * ``embeddings`` - 2-D inputs whose last dimension is 4, stacked to (B, L, 4);
    * ``input`` - anything else, passed through as the raw list.
    """

    def _tokenize(texts, labels):
        # Tokenizer path (not used in one-hot, but supported).
        enc = tokenizer(list(texts), padding='longest', truncation=True,
                        max_length=max_length, return_tensors='pt')
        if 'attention_mask' not in enc:
            enc['attention_mask'] = torch.ones_like(enc['input_ids'])
        return {'input_ids': enc['input_ids'],
                'attention_mask': enc['attention_mask'],
                'label': labels}

    def _pad_ids(seqs, labels):
        # Right-pad variable-length id vectors with zeros up to the longest one.
        width = max(seq.shape[0] for seq in seqs)
        ids = torch.zeros((len(seqs), width), dtype=torch.long)
        for row, seq in enumerate(seqs):
            ids[row, :seq.shape[0]] = seq
        return {'input_ids': ids, 'attention_mask': (ids != 0).long(), 'label': labels}

    def collate(batch):
        labels = torch.tensor([item['label'] for item in batch], dtype=torch.float)
        raw = [item['input'] for item in batch]

        if tokenizer is not None and isinstance(raw[0], (str, bytes)):
            return _tokenize(raw, labels)

        tensors = [torch.as_tensor(x) for x in raw]
        head = tensors[0]
        if head.ndim == 2 and head.shape[-1] == 4:
            # One-hot arrays: (L,4) -> (B,L,4)
            return {'embeddings': torch.stack(tensors), 'label': labels}
        if head.ndim == 1:
            return _pad_ids(tensors, labels)

        # Unknown layout: hand the raw inputs back untouched.
        return {'input': raw, 'label': labels}

    return collate


## Encoder and Linear Probe

In [81]:
import warnings

class LinearProbe(nn.Module):
    def __init__(self, encoder: nn.Module, encoder_dim: int, freeze_encoder: bool = True, proj_in_features: int | None = None):
        """encoder: HF model or TinyEnc that returns last_hidden_state (B,L,H)
        encoder_dim: H
        proj_in_features: if not None, constructs a projection from proj_in_features -> encoder_dim
                         and expects embeddings of shape (B,L,proj_in_features)
        """
        super().__init__()
        self.encoder = encoder
        if freeze_encoder and self.encoder is not None:
            for p in self.encoder.parameters():
                p.requires_grad = False

        self.proj = nn.Linear(proj_in_features, encoder_dim).to(DEVICE) if proj_in_features is not None else None
        if self.proj is not None:
            # simple, stable init
            nn.init.xavier_uniform_(self.proj.weight)
            if self.proj.bias is not None:
                nn.init.zeros_(self.proj.bias)

        self.classifier = nn.Linear(encoder_dim, 1)

    def forward(self, input_ids=None, attention_mask=None, embeddings=None):
        if embeddings is None:
            # Hugging Face encoder path
            out = self.encoder(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
            last_hidden = out.last_hidden_state
        else:
            last_hidden = embeddings  # expected (B,L,proj_in_features) if proj exists

        # apply proj if present (applies to last dimension)
        if self.proj is not None:
            last_hidden = last_hidden.to(self.proj.weight.device).float()
            last_hidden = self.proj(last_hidden)  # (B,L,H)

        pooled = last_hidden.mean(dim=1)  # (B,H)
        return self.classifier(pooled).squeeze(-1)


def load_encoder(model_name: str, default_dim: int = 128):
    """Load a Hugging Face encoder, or hand back the TinyEnc placeholder.

    Returns a ``(tokenizer, encoder, hidden_dim)`` triple. The tokenizer is ``None``
    whenever the placeholder is returned, which happens when ``model_name`` is empty,
    when ``transformers`` could not be imported, or when loading raises. A missing
    ``triton`` module gets its own, more detailed message.
    """

    class TinyEnc(nn.Module):
        """Parameter-free stand-in: its 'hidden states' are freshly sampled noise."""

        def __init__(self, dim=default_dim):
            super().__init__()
            self.dim = dim

        def forward(self, input_ids=None, attention_mask=None, return_dict=True):
            if input_ids is None:
                B, L = 8, 256
            else:
                B, L = input_ids.shape[0], input_ids.shape[1]
            hidden = torch.randn(B, L, self.dim, device=DEVICE)
            return type('X', (), {'last_hidden_state': hidden})

    def _placeholder():
        # Shared exit for every route that ends up without a real encoder.
        return None, TinyEnc(default_dim).to(DEVICE), default_dim

    if not model_name:
        return _placeholder()

    if AutoModel is None:
        warnings.warn("transformers package not available; using TinyEnc fallback.")
        return _placeholder()

    # NOTE(review): trust_remote_code=True allows code shipped with the model repository
    # to run while loading - only point MODEL_NAME at repositories you trust.
    hub_kwargs = {'trust_remote_code': True}
    try:
        tok = AutoTokenizer.from_pretrained(model_name, **hub_kwargs)
        enc = AutoModel.from_pretrained(model_name, **hub_kwargs)
        cfg = AutoConfig.from_pretrained(model_name, **hub_kwargs)
        print('Loaded HF encoder:', model_name)
        return tok, enc.to(DEVICE), getattr(cfg, 'hidden_size', default_dim)
    except ModuleNotFoundError as e:
        if 'triton' in str(e).lower():
            print(
                "HF model requires 'triton' which is not installed in this environment.\n"
                "Options:\n"
                "  1) Set PRETRAINED_MODEL='' (or None) to use the TinyEnc fallback for quick, notebook-safe runs.\n"
                "  2) Install triton following official instructions if you need the HF encoder (may require matching CUDA/toolkit).\n"
                "     Example: pip install triton  # verify CUDA/toolkit compatibility first.\n"
                "Note: Installing triton can be CUDA/version specific; avoid automatic installs inside the notebook."
            )
        else:
            print('HF load failed with ModuleNotFoundError:', e)
    except Exception as e:
        print('HF load failed, using tiny placeholder. Error:', e)
    return _placeholder()

# Resolve the encoder, then build the probe. The 4 -> hidden projection is only
# enabled when no tokenizer came back (one-hot path).
tokenizer, encoder, encoder_dim = load_encoder(MODEL_NAME)
if tokenizer is None:
    proj_in = 4
else:
    proj_in = None
probe = LinearProbe(
    encoder,
    encoder_dim=encoder_dim,
    freeze_encoder=not FULL_FINETUNE,
    proj_in_features=proj_in,
)
probe = probe.to(DEVICE)

# Reset any legacy global PROJ left in a long-running kernel so inputs are never projected twice.
PROJ = None

print('Encoder dim:', encoder_dim, '| PROJ:', 'disabled' if proj_in is None else 'enabled')


HF load failed, using tiny placeholder. Error: replace/with-safetensors-checkpoint is not a local folder and is not a valid model identifier listed on 'https://huggingface.co/models'
If this is a private repository, make sure to pass a token having permission to this repo either by logging in with `hf auth login` or by passing `token=<your_token>`
Encoder dim: 128 | PROJ: enabled


## DataLoaders (notebook-safe)

In [82]:
if train_ds is None:
    train_loader = val_loader = test_loader = None
    print('No datasets found.')
else:
    collate = make_collate_fn(tokenizer, MAX_LEN)

    # When training labels are available, draw examples with replacement using weights
    # inversely proportional to class frequency. Any failure falls back to plain shuffling.
    sampler = None
    try:
        if globals().get('train_labels') is not None:
            counts = np.bincount(train_labels)
            class_weights = 1.0 / (counts + 1e-8)  # epsilon avoids dividing by zero for empty classes
            sample_weights = class_weights[train_labels]
            sampler = torch.utils.data.WeightedRandomSampler(
                weights=sample_weights.tolist(),
                num_samples=len(sample_weights),
                replacement=True,
            )
    except Exception as e:
        print('Could not build sampler:', e)
        sampler = None

    # num_workers=0 everywhere keeps loading in the notebook process.
    train_loader = DataLoader(
        train_ds,
        batch_size=BATCH_TRAIN,
        sampler=sampler,
        shuffle=sampler is None,
        collate_fn=collate,
        num_workers=0,
    )
    val_loader = None
    if val_ds:
        val_loader = DataLoader(val_ds, batch_size=BATCH_EVAL, shuffle=False,
                                collate_fn=collate, num_workers=0)
    test_loader = None
    if test_ds:
        test_loader = DataLoader(test_ds, batch_size=BATCH_EVAL, shuffle=False,
                                 collate_fn=collate, num_workers=0)

    # Smoke test: pull one batch and show the shape (or type) of every field.
    try:
        b = next(iter(train_loader))
        print({k: (v.shape if hasattr(v, 'shape') else type(v)) for k, v in b.items()})
    except Exception as e:
        print('Smoke test failed:', e)


{'embeddings': torch.Size([24, 2000, 4]), 'label': torch.Size([24])}


## Train/Eval helpers

In [83]:
# Loss: BCE-with-logits, weighted by the module-level pos_weight when one was computed
# earlier; any problem constructing the weighted loss falls back to the unweighted one.
try:
    if globals().get('pos_weight') is not None:
        criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
    else:
        criterion = nn.BCEWithLogitsLoss()
except Exception:
    criterion = nn.BCEWithLogitsLoss()

def forward_with_optional_proj(model, batch):
    """Run ``model`` on one collated batch after moving its inputs to ``DEVICE``.

    A batch carrying ``'embeddings'`` is passed as float embeddings (any projection is
    the model's own business; no global PROJ is consulted). Otherwise the batch must
    carry ``'input_ids'`` and ``'attention_mask'``.
    """
    if 'embeddings' not in batch:
        return model(
            input_ids=batch['input_ids'].to(DEVICE),
            attention_mask=batch['attention_mask'].to(DEVICE),
        )
    features = batch['embeddings'].to(DEVICE).float()  # (B,L,4) or (B,L,proj_in)
    return model(embeddings=features)

def train_epoch(model, loader, optimizer, scheduler=None):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    The scheduler, when given, is stepped once per batch. Returns NaN if the loader
    yields no batches.
    """
    model.train()
    batch_losses = []
    for batch in tqdm(loader, leave=False):
        targets = batch['label'].to(DEVICE).float()
        loss = criterion(forward_with_optional_proj(model, batch), targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if scheduler is not None:
            scheduler.step()

        batch_losses.append(loss.item())

    if not batch_losses:
        return float('nan')
    return float(np.mean(batch_losses))

@torch.no_grad()
def eval_epoch(model, loader):
    """Evaluate ``model`` on ``loader``; return ``(mean_loss, AUROC, PR-AUC)``.

    All three values are NaN when the loader is empty. AUROC and PR-AUC are NaN
    (instead of raising) when fewer than two classes are present in the labels.
    """
    model.eval()
    batch_losses, targets, probs = [], [], []
    for batch in tqdm(loader, leave=False):
        y = batch['label'].to(DEVICE).float()
        logits = forward_with_optional_proj(model, batch)
        batch_losses.append(criterion(logits, y).item())
        targets.append(y.detach().cpu().numpy())
        probs.append(torch.sigmoid(logits).detach().cpu().numpy())

    if len(targets) == 0:
        return float('nan'), float('nan'), float('nan')

    # Flatten so the metric functions see 1-D vectors.
    y_all = np.concatenate(targets).ravel()
    p_all = np.concatenate(probs).ravel()

    if np.unique(y_all).size >= 2:
        auroc = roc_auc_score(y_all, p_all)
        prauc = average_precision_score(y_all, p_all)
    else:
        auroc = prauc = float('nan')
    return float(np.mean(batch_losses)), auroc, prauc

def save_ckpt(model, path):
    """Write ``model``'s state dict (weights only, not the module object) to ``path``."""
    weights = model.state_dict()
    torch.save(weights, path)


## Train the linear probe

In [84]:
if train_loader is not None:
    def build_optimizer(probe, head_lr=HEAD_LR, enc_lr=ENCODER_LR, weight_decay=WEIGHT_DECAY, full_finetune=FULL_FINETUNE):
        """Create an AdamW optimizer with one parameter group per trainable part.

        Groups, in order: the classifier (``head_lr``), the projection if the probe has
        one (``head_lr``), and - only when ``full_finetune`` is set - the encoder
        parameters that still require gradients (``enc_lr``). Every group uses the same
        ``weight_decay``.
        """
        def group(parameters, lr):
            return {'params': parameters, 'lr': lr, 'weight_decay': weight_decay}

        groups = [group(probe.classifier.parameters(), head_lr)]

        proj = getattr(probe, 'proj', None)
        if proj is not None:
            groups.append(group(proj.parameters(), head_lr))

        if full_finetune:
            trainable = [p for p in probe.encoder.parameters() if p.requires_grad]
            if trainable:
                groups.append(group(trainable, enc_lr))

        return torch.optim.AdamW(groups)

    optimizer = build_optimizer(probe)

    # Early stopping training loop
    def train_one_epoch(model, loader, optimizer, device=DEVICE):
        """Run one training pass with gradient clipping; return the mean batch loss.

        Args:
            model: module being trained; its gradients are clipped to a total norm of 1.0.
            loader: iterable of collated batches (each with a ``'label'`` entry).
            optimizer: optimizer stepped once per batch.
            device: where the labels are moved. NOTE(review): the inputs are moved by
                ``forward_with_optional_proj``, which uses the global ``DEVICE`` - keep
                the two in agreement.

        Returns:
            Mean of the per-batch losses, or NaN when the loader yields nothing.
        """
        model.train()
        losses = []
        for batch in loader:
            y = batch['label'].to(device).float()
            logits = forward_with_optional_proj(model, batch)
            loss = criterion(logits, y)
            optimizer.zero_grad()
            loss.backward()
            # Gradient clipping for stability on small/imbalanced splits. Clip the model
            # that was passed in (not a notebook global), and let real errors surface
            # instead of silently skipping the clip.
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            losses.append(loss.item())
        return float(np.mean(losses)) if losses else float('nan')

    @torch.no_grad()
    def eval_metrics(model, loader, device=DEVICE):
        """Return ``(mean_loss, AUROC, average_precision)`` for ``model`` on ``loader``.

        Everything is NaN when ``loader`` is ``None`` or empty; the two ranking metrics
        are NaN when the labels contain a single class.
        """
        all_nan = (float('nan'), float('nan'), float('nan'))
        if loader is None:
            return all_nan

        model.eval()
        batch_losses, targets, probs = [], [], []
        for batch in loader:
            y = batch['label'].to(device).float()
            logits = forward_with_optional_proj(model, batch)
            batch_losses.append(float(criterion(logits, y)))
            targets.append(y.detach().cpu().numpy())
            probs.append(torch.sigmoid(logits).detach().cpu().numpy())

        if not targets:
            return all_nan

        y_all, p_all = np.concatenate(targets), np.concatenate(probs)
        mean_loss = float(np.mean(batch_losses))
        if np.unique(y_all).size > 1:
            return mean_loss, roc_auc_score(y_all, p_all), average_precision_score(y_all, p_all)
        return mean_loss, float('nan'), float('nan')

    def fit_with_early_stopping(model, train_loader, val_loader, optimizer, epochs=EPOCHS, patience=5, ckpt_dir=CKPT_DIR, results_dir=RESULTS):
        """Train for up to ``epochs`` epochs, keeping the weights with the best validation AUROC.

        Per epoch: train, evaluate, save ``probe_epoch{ep}.pt`` under ``ckpt_dir``, and,
        if validation AUROC strictly improved, save ``probe_best.pt`` under
        ``results_dir``. A row is also appended to ``epoch_metrics.csv`` under
        ``results_dir``. Training stops early after ``patience`` consecutive epochs
        without improvement.

        An epoch whose validation AUROC is NaN never counts as an improvement, so no
        ``probe_best.pt`` is written by a run in which AUROC is undefined throughout.

        Returns:
            ``(best, history)`` - ``best`` is a dict with ``epoch``, ``val_auc``,
            ``val_ap``; ``history`` is a list of per-epoch metric dicts.
        """
        ckpt_dir.mkdir(parents=True, exist_ok=True)
        results_dir.mkdir(parents=True, exist_ok=True)
        # All artefacts of this call live under results_dir (the per-epoch log included).
        epoch_log_path = results_dir / "epoch_metrics.csv"

        best = {'epoch': -1, 'val_auc': -np.inf, 'val_ap': 0.0}
        wait = 0
        history = []
        for ep in range(epochs):
            t0 = time.time()
            tr_loss = train_one_epoch(model, train_loader, optimizer)
            va_loss, va_auc, va_ap = eval_metrics(model, val_loader)
            dt = time.time() - t0
            ep_path = ckpt_dir / f'probe_epoch{ep}.pt'
            save_ckpt(model, ep_path)
            improved = not np.isnan(va_auc) and va_auc > best['val_auc']
            if improved:
                best = {'epoch': ep, 'val_auc': va_auc, 'val_ap': va_ap}
                save_ckpt(model, results_dir / 'probe_best.pt')
                wait = 0
            else:
                wait += 1
            print(f"Epoch {ep:02d} | train_loss={tr_loss:.4f} val_loss={va_loss:.4f} AUROC={va_auc:.4f} PR-AUC={va_ap:.4f} [{dt:.1f}s] {'*' if improved else ''}")
            # Append per-epoch metrics to CSV for later plotting/analysis. Logging is
            # best effort, but a failure is reported rather than silently dropped.
            try:
                write_header = not epoch_log_path.exists()
                with open(epoch_log_path, "a", newline="") as f:
                    w = csv.writer(f)
                    if write_header:
                        w.writerow(["epoch","train_loss","val_loss","val_auroc","val_prauc"])
                    w.writerow([ep, f"{tr_loss:.6f}", f"{va_loss:.6f}", f"{va_auc:.6f}", f"{va_ap:.6f}"])
            except Exception as e:
                print('Could not append epoch metrics:', e)
            history.append({'epoch': ep, 'train_loss': tr_loss, 'val_loss': va_loss, 'val_auc': va_auc, 'val_ap': va_ap})
            if wait >= patience:
                print(f"Early stopping (no val AUROC improvement in {patience} epochs).")
                break
        print('Best validation:', best)
        return best, history
    # Run training
    best, history = fit_with_early_stopping(probe, train_loader, val_loader, optimizer, epochs=EPOCHS, patience=5, ckpt_dir=CKPT_DIR, results_dir=RESULTS)

    # Save run config snapshot for reproducibility. This is best effort (training has
    # already finished), but a failed write is reported instead of being hidden.
    try:
        import json  # `time` is already imported at the top of the notebook
        cfg = {
            "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
            "seed": int(SEED),
            "device": str(DEVICE),
            "model_name": MODEL_NAME or "TinyEnc(one-hot)",
            "full_finetune": bool(FULL_FINETUNE),
            "epochs": int(EPOCHS),
            "batch_train": int(BATCH_TRAIN),
            "batch_eval": int(BATCH_EVAL),
            "head_lr": float(HEAD_LR),
            "encoder_lr": float(ENCODER_LR),
            "weight_decay": float(WEIGHT_DECAY),
            "train_n": len(train_ds) if train_ds else 0,
            "val_n": len(val_ds) if val_ds else 0,
            "test_n": len(test_ds) if test_ds else 0
        }
        with open(RESULTS / "run_config.json", "w") as f:
            json.dump(cfg, f, indent=2)
    except Exception as e:
        print('Could not write run_config.json:', e)

Epoch 00 | train_loss=0.6990 val_loss=0.6958 AUROC=0.5000 PR-AUC=0.5385 [0.0s] *
Epoch 01 | train_loss=0.6959 val_loss=0.6954 AUROC=0.5000 PR-AUC=0.5385 [0.0s] 
Epoch 02 | train_loss=0.6956 val_loss=0.6951 AUROC=0.5000 PR-AUC=0.5385 [0.0s] 
Best validation: {'epoch': 0, 'val_auc': 0.5, 'val_ap': 0.5384615384615384}


## Final evaluation on test set & logging

In [85]:
@torch.no_grad()
def collect_scores(model, loader, device=DEVICE):
    """Return ``(labels, probabilities)`` as NumPy arrays gathered over ``loader``.

    Probabilities are the sigmoid of the model's logits. Returns ``(None, None)`` when
    the loader yields no batches.
    """
    model.eval()
    pairs = []
    for batch in loader:
        y = batch['label'].to(device).float()
        prob = torch.sigmoid(forward_with_optional_proj(model, batch))
        pairs.append((y.detach().cpu().numpy(), prob.detach().cpu().numpy()))

    if not pairs:
        return None, None
    ys, ps = zip(*pairs)
    return np.concatenate(ys), np.concatenate(ps)

def bootstrap_ci(metric_fn, y, p, n_boot=1000, seed=42):
    """Percentile-bootstrap estimate of a metric and its 95% interval.

    Args:
        metric_fn: callable ``metric_fn(y_sample, p_sample) -> float``.
        y: labels (any array-like; converted with ``np.asarray``).
        p: scores aligned with ``y`` (any array-like).
        n_boot: number of resamples drawn with replacement.
        seed: seed for the resampling generator.

    Returns:
        ``(mean, (lo, hi))`` - the mean of the metric over usable resamples and the
        2.5th / 97.5th percentiles. Resamples on which ``metric_fn`` raises or returns
        NaN (e.g. a resample containing a single class) are discarded. If the input is
        empty or no resample is usable, the result is ``(nan, (nan, nan))``.
    """
    # Plain lists cannot be fancy-indexed with an index array; without this conversion
    # every resample would fail inside the try below and the result would silently be NaN.
    y = np.asarray(y)
    p = np.asarray(p)
    n = len(y)
    if n == 0 or n_boot <= 0:
        return np.nan, (np.nan, np.nan)

    rng = np.random.default_rng(seed)
    idx = np.arange(n)
    vals = np.full(n_boot, np.nan, dtype=np.float64)
    for b in range(n_boot):
        j = rng.choice(idx, size=n, replace=True)
        try:
            vals[b] = metric_fn(y[j], p[j])
        except Exception:
            # Metric undefined on this resample: leave the slot as NaN and drop it below.
            continue

    vals = vals[~np.isnan(vals)]
    if vals.size == 0:
        return np.nan, (np.nan, np.nan)
    lo, hi = np.percentile(vals, [2.5, 97.5])
    return float(np.mean(vals)), (float(lo), float(hi))

# Score the test split with the best checkpoint when one exists on disk;
# otherwise the probe is evaluated with whatever weights it currently holds.
best_path = RESULTS / 'probe_best.pt'
if best_path.exists():
    probe.load_state_dict(torch.load(best_path, map_location=DEVICE))

test_auc, test_ap = float('nan'), float('nan')
if test_loader is None:
    print('No test loader available.')
else:
    y_test, p_test = collect_scores(probe, test_loader)
    if y_test is not None:
        # Point estimates (left as NaN when only one class is present)
        if np.unique(y_test).size > 1:
            test_auc = roc_auc_score(y_test, p_test)
            test_ap = average_precision_score(y_test, p_test)

        # Bootstrap CIs
        auc_mean, (auc_lo, auc_hi) = bootstrap_ci(roc_auc_score, y_test, p_test, n_boot=1000, seed=SEED)
        ap_mean,  (ap_lo,  ap_hi)  = bootstrap_ci(average_precision_score, y_test, p_test, n_boot=1000, seed=SEED)
        print(f"TEST: AUROC={test_auc:.4f}  (boot mean {auc_mean:.4f}, 95% CI [{auc_lo:.4f},{auc_hi:.4f}])")
        print(f"      PR-AUC={test_ap:.4f} (boot mean {ap_mean:.4f}, 95% CI [{ap_lo:.4f},{ap_hi:.4f}])")

        # Append one row to results/metrics.csv; the header is written only for a new file.
        metrics_path = RESULTS / 'metrics.csv'
        write_header = not metrics_path.exists()
        with open(metrics_path, 'a', newline='') as f:
            w = csv.writer(f)
            if write_header:
                w.writerow(['run_name', 'AUROC', 'PR_AUC', 'AUROC_CI_low', 'AUROC_CI_high',
                            'PR_AUC_CI_low', 'PR_AUC_CI_high', 'split', 'best_epoch'])
            formatted = [f'{value:.4f}' for value in (test_auc, test_ap, auc_lo, auc_hi, ap_lo, ap_hi)]
            w.writerow(['linear_probe', *formatted, 'test', best.get('epoch', -1)])


TEST: AUROC=0.5000  (boot mean 0.5000, 95% CI [0.5000,0.5000])
      PR-AUC=0.4615 (boot mean 0.4597, 95% CI [0.2288,0.7692])


## Self-check: batch signature

In [86]:
if train_loader is not None:
    # Pull one training batch and list each field with its shape and dtype
    # (or just its Python type when the value has no shape).
    b = next(iter(train_loader))
    print('Batch signature:')
    for key, value in b.items():
        described = (value.shape, value.dtype) if hasattr(value, 'shape') else (type(value),)
        print(' ', key, *described)


Batch signature:
  embeddings torch.Size([24, 2000, 4]) torch.float32
  label torch.Size([24]) torch.float32
