# Setup & config

In [None]:
# ============================================
# Cell 1 – High-level setup & config
# ============================================
import os
import sys
import math
import random
import pickle
from dataclasses import dataclass, asdict
from typing import Dict, Any, Tuple, List, Optional

import numpy as np
import pandas as pd

import torch
import torch.nn as nn

from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    accuracy_score,
    f1_score,
    roc_auc_score,
    average_precision_score,
    confusion_matrix,
    precision_score,
    recall_score,
)
from sklearn.decomposition import PCA
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import make_pipeline

SEED = 42
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


def set_global_seed(seed: int = 42) -> None:
    """Seed Python's `random`, NumPy's global RNG and torch with `seed`.

    When CUDA is available, every CUDA device is seeded as well.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


# Seed everything once at import/cell-execution time.
set_global_seed(SEED)

# Short names of the individual (non-merged) datasets used in this notebook.
DATASET_NAMES = [
    "defan",
    "mmlu",
    "halueval_qa",
    "halueval_summarization",
    "halueval_dialogue",
    "psiloqa",
]

# Human-readable display names, keyed by dataset short name.
# Also contains an entry for the merged "all_data" set.
PRETTY_DATASET_NAMES = {
    "defan": "Definitive Answers",
    "mmlu": "MMLU-PRO",
    "halueval_qa": "HaluEval QA",
    "halueval_summarization": "HaluEval Summ.",
    "halueval_dialogue": "HaluEval Dialogue",
    "psiloqa": "PsiloQA",
    "all_data": "All datasets",
}

# Method keys; each dataset dict is indexed by method name
# (e.g. dataset["factcheckmate"], dataset["lap_eigvals"], dataset["icr_probe"]).
BASE_METHODS = [
    "factcheckmate",
    "lap_eigvals",
    "icr_probe",
    "llm_check",
    "attn_and_hiddn",
]

MODELS_ROOT = "models"  # all models will be saved under models/<dataset>/<method>_model.pth

# Create the models root eagerly so later saves can assume it exists.
os.makedirs(MODELS_ROOT, exist_ok=True)

In [None]:
# ============================================
# Cell 2 – Colab / repo / feature extraction helpers
# (optional – only needed if you want to regenerate the features_<dataset>.pkl files)
# ============================================
import subprocess

@dataclass
class RepoConfig:
    """Paths consumed by `setup_hallucination_repo`."""
    # Zip archive handed to `unzip`.
    zip_path: str = "halu_detector.zip"
    # Directory that is removed before unzipping; the requirements file is looked up inside it.
    repo_dir: str = "hallucination_detector"
    # Requirements file path, joined onto `repo_dir`.
    requirements_relpath: str = "requirements.txt"

def setup_hallucination_repo(cfg: RepoConfig) -> None:
    """
    Unzips the hallucination_detector repo and installs requirements.
    Adapt paths to wherever your zip is stored.

    Steps:
      1. Remove any existing `cfg.repo_dir` (best effort).
      2. Run `unzip -o cfg.zip_path` in the current working directory.
      3. Remove a stray `__MACOSX` directory if the archive produced one.
      4. `pip install -r <repo_dir>/<requirements_relpath>` using the
         current interpreter.

    Raises
    ------
    subprocess.CalledProcessError
        If `unzip` or `pip install` exits with a non-zero status.
    """
    # Local import keeps this cell self-contained; shutil replaces the
    # non-portable `rm -rf` subprocess calls.
    import shutil

    if os.path.exists(cfg.repo_dir):
        print(f"Removing existing {cfg.repo_dir}")
        # ignore_errors mirrors the previous best-effort `rm -rf` (check=False).
        shutil.rmtree(cfg.repo_dir, ignore_errors=True)

    print("Unzipping repo...")
    subprocess.run(["unzip", "-o", cfg.zip_path], check=True)
    # Clean macOS junk if present
    shutil.rmtree("__MACOSX", ignore_errors=True)

    # Install requirements
    req_path = os.path.join(cfg.repo_dir, cfg.requirements_relpath)
    print("Installing requirements from", req_path)
    subprocess.run(
        [sys.executable, "-m", "pip", "install", "-r", req_path],
        check=True,
    )
    print("Repo setup complete.")

@dataclass
class FeatureExtractionConfig:
    """Arguments forwarded by `run_feature_extraction` to the extraction CLI."""
    # Working directory (cwd) in which the extraction subprocess is launched.
    repo_dir: str = "hallucination_detector"
    # Passed as --model_name.
    model_name: str = "Qwen/Qwen2.5-72B-Instruct"
    # Passed as --dataset_name.
    dataset_name: str = "defan"
    # Passed as --output_dir.
    output_dir: str = "/content/drive/MyDrive/halu_features"
    # Passed as --max_examples (stringified).
    max_examples: int = 2
    # Passed as --batch_size (stringified).
    batch_size: int = 1
    # Passed as --max_new_tokens (stringified).
    max_new_tokens: int = 128
    extra_args: Optional[List[str]] = None  # pass any extra CLI flags

def run_feature_extraction(cfg: FeatureExtractionConfig) -> None:
    """
    Launch `python -m hallucination_detector.main ...` as a subprocess.

    The command is built from `cfg` plus the module-level DEVICE, printed,
    and then executed with `cfg.repo_dir` as the working directory. A
    non-zero exit status raises `subprocess.CalledProcessError`.
    Assumes OPENAI_API_KEY (or other needed keys) are already in the env.
    """
    flag_values = [
        ("--model_name", cfg.model_name),
        ("--dataset_name", cfg.dataset_name),
        ("--output_dir", cfg.output_dir),
        ("--max_examples", str(cfg.max_examples)),
        ("--device", DEVICE),
        ("--batch_size", str(cfg.batch_size)),
        ("--max_new_tokens", str(cfg.max_new_tokens)),
    ]

    cmd = [sys.executable, "-m", "hallucination_detector.main"]
    for flag, value in flag_values:
        cmd += [flag, value]
    # Optional caller-supplied flags go last.
    cmd += cfg.extra_args or []

    print("Running feature extraction with command:")
    print(" ".join(cmd))
    subprocess.run(cmd, cwd=cfg.repo_dir, check=True)


In [None]:
# Unzip the repo and install its requirements using the RepoConfig defaults.
setup_hallucination_repo(RepoConfig())

Removing existing hallucination_detector
Unzipping repo...
Installing requirements from hallucination_detector/requirements.txt


KeyboardInterrupt: 

# Load Pickle

In [None]:
# ============================================
# Cell 3 – Load feature .pkl files & dataset utilities
# ============================================
@dataclass
class FeaturePaths:
    """Locates feature pickles underneath a single root directory."""

    root: str = "/content/drive/MyDrive/halu_features"  # adjust to your Drive/path

    def path(self, short_name: str) -> str:
        """Return `<root>/features_<short_name>.pkl`."""
        filename = f"features_{short_name}.pkl"
        return os.path.join(self.root, filename)

def _load_pickle(path: str):
    """Open `path` in binary mode and return the unpickled object."""
    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    with open(path, mode="rb") as handle:
        payload = pickle.load(handle)
    return payload

def load_raw_feature_pickles(paths: FeaturePaths) -> Dict[str, Any]:
    """
    Load every original feature pickle found under `paths`.

    Returns a dict keyed by the pickle's short name; the dialogue and
    PsiloQA features live together under "halueval_dialogue_psiloqa".
    """
    print("Loading raw feature pickles from", paths.root)
    short_names = (
        "defan",
        "mmlu",
        "halueval_qa",
        "halueval_summarization",
        "halueval_dialogue_psiloqa",
    )
    return {name: _load_pickle(paths.path(name)) for name in short_names}

def split_dialogue_psiloqa(dialogue_psiloqa: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """
    Separate the combined HaluEval-dialogue + PsiloQA bundle.

    The sub-dataset of each row is the part of the "factcheckmate" meta
    `qid` before the first ":". The same row positions are then applied to
    every method in the bundle. Returns `(halueval_dialogue, psiloqa)`,
    each a dict keyed by method name; a missing sub-dataset raises KeyError.
    """
    qids = dialogue_psiloqa["factcheckmate"].meta["qid"].reset_index(drop=True)
    ds_names = qids.str.split(":", n=1).str[0]  # "halueval/dialogue" or "psiloqa"
    datasets = ds_names.unique()
    print("Found dialogue sub-datasets:", datasets)

    def take_rows(cmf, rows: np.ndarray):
        # Scalars are only sliced when they actually hold rows; otherwise
        # the (likely empty) frame is passed through untouched.
        has_scalars = cmf.scalars is not None and not cmf.scalars.empty
        if has_scalars:
            picked_scalars = cmf.scalars.iloc[rows].reset_index(drop=True)
        else:
            picked_scalars = cmf.scalars
        return type(cmf)(
            scalars=picked_scalars,
            tensors={key: value[rows] for key, value in cmf.tensors.items()},
            meta=cmf.meta.iloc[rows].reset_index(drop=True),
        )

    per_dataset = {}
    for ds in datasets:
        rows = np.where(ds_names == ds)[0]
        per_dataset[ds] = {
            method_name: take_rows(cmf, rows)
            for method_name, cmf in dialogue_psiloqa.items()
        }

    return per_dataset["halueval/dialogue"], per_dataset["psiloqa"]

from collections import defaultdict

def merge_feature_dicts(list_of_dicts: List[Dict[str, Any]]) -> Dict[str, Any]:
    """
    Merge multiple {method_name -> CollatedMethodFeatures} dicts along rows.

    For each method:
      * meta frames are re-indexed to the sorted union of their columns and
        concatenated (missing columns become NaN);
      * scalar frames that have columns or rows are merged the same way; if
        none qualifies, the first dataset's `scalars` is reused as-is;
      * tensors sharing a name are concatenated with `torch.cat(dim=0)`; if
        no tensors exist, the first dataset's `tensors` is reused.

    NOTE(review): a tensor or scalar frame that is present in only some of
    the inputs yields fewer rows than the merged meta; this is not checked.

    Raises
    ------
    ValueError
        If `list_of_dicts` is empty, if the inputs do not share the same
        method keys, or if same-named tensors differ in their trailing
        (non-row) dimensions.
    """
    if not list_of_dicts:
        raise ValueError("merge_feature_dicts needs at least one feature dict.")

    first = list_of_dicts[0]

    # 1) sanity: same method keys everywhere
    method_names = list(first.keys())
    expected_keys = set(method_names)
    for d in list_of_dicts[1:]:
        if set(d.keys()) != expected_keys:
            raise ValueError("Method keys differ between datasets!")

    merged: Dict[str, Any] = {}

    for m in method_names:
        parts = [d[m] for d in list_of_dicts]

        meta_list = [cmf.meta for cmf in parts]
        # Skip scalar frames that are None or completely empty (0 rows, 0 cols).
        scalars_list = [
            cmf.scalars
            for cmf in parts
            if cmf.scalars is not None
            and (len(cmf.scalars.columns) > 0 or len(cmf.scalars) > 0)
        ]
        tensor_lists: Dict[str, List[torch.Tensor]] = defaultdict(list)
        for cmf in parts:
            for tname, tval in cmf.tensors.items():
                tensor_lists[tname].append(tval)

        # --- meta with union of columns ---
        meta_cols = sorted(set().union(*(df.columns for df in meta_list)))
        meta_merged = pd.concat(
            [df.reindex(columns=meta_cols) for df in meta_list],
            ignore_index=True,
        )

        # --- scalars (possibly empty) ---
        if scalars_list:
            scalar_cols = sorted(set().union(*(df.columns for df in scalars_list)))
            scalars_merged = pd.concat(
                [df.reindex(columns=scalar_cols) for df in scalars_list],
                ignore_index=True,
            )
        else:
            # keep empty DF shape from first dataset
            scalars_merged = first[m].scalars

        # --- tensors with torch.cat along dim=0 ---
        if tensor_lists:
            tensors_merged = {}
            for tname, tlist in tensor_lists.items():
                base_shape = tlist[0].shape[1:]
                if any(t.shape[1:] != base_shape for t in tlist):
                    raise ValueError(f"Tensor {tname} shape mismatch for {m}")
                tensors_merged[tname] = torch.cat(tlist, dim=0)
        else:
            tensors_merged = first[m].tensors

        # Rebuild with the same container class as the first dataset's entry.
        cls = first[m].__class__
        merged[m] = cls(
            scalars=scalars_merged,
            tensors=tensors_merged,
            meta=meta_merged,
        )

    return merged

# --- High-level data loader (caches in memory) ---
# Module-level caches populated by `load_all_datasets` and read by `get_dataset`.
# Re-executing these assignments resets all three caches.
_ALL_DATASETS: Dict[str, Dict[str, Any]] = {}  # dataset name -> {method -> features}
_ALL_DATA_MERGED: Optional[Dict[str, Any]] = None  # row-wise merge of the base datasets
_BALANCED_ALL_DATA_MERGED: Optional[Dict[str, Any]] = None  # class-balanced version of the merge

import numpy as np
import torch
from copy import deepcopy
from typing import Dict, Literal

# assuming CollatedMethodFeatures is already defined in your codebase:
# from your_module import CollatedMethodFeatures

def _subset_collated(cf, idx):
    """
    Return a new object of `type(cf)` restricted to the row positions `idx`.

    `scalars` is only sliced when it has rows (otherwise it is passed
    through unchanged); every tensor is indexed along its first dimension;
    `meta` is sliced positionally. Sliced frames get a fresh 0..n-1 index.
    """
    picked_scalars = cf.scalars
    if picked_scalars is not None and len(picked_scalars) > 0:
        picked_scalars = picked_scalars.iloc[idx].reset_index(drop=True)

    # First dimension of each tensor is treated as the sample dimension.
    picked_tensors = {name: tensor[idx] for name, tensor in cf.tensors.items()}

    picked_meta = cf.meta.iloc[idx].reset_index(drop=True)

    return type(cf)(scalars=picked_scalars, tensors=picked_tensors, meta=picked_meta)


def balance_all_data(
    all_data: Dict[str, "CollatedMethodFeatures"],
    strategy: Literal["downsample", "upsample"] = "downsample",
    seed: int = 42,
) -> Dict[str, "CollatedMethodFeatures"]:
    """
    Build a copy of `all_data` in which every class in 'label' has the
    same number of rows.

    Parameters
    ----------
    all_data : dict
        Mapping method_name -> CollatedMethodFeatures. Labels are read from
        the first method's meta; all methods are assumed to be row-aligned.
    strategy : {"downsample", "upsample"}, default="downsample"
        - "downsample": sample each class (without replacement) down to the
          size of the smallest class.
        - "upsample": sample each class up to the size of the largest
          class, with replacement only for classes that are too small.
    seed : int, default=42
        Seed for the NumPy generator that draws and shuffles the rows.

    Returns
    -------
    dict
        Same keys as `all_data`, each value subset to the same shuffled
        row selection.

    Raises
    ------
    ValueError
        If `strategy` is not one of the two supported values.
    """
    rng = np.random.default_rng(seed)

    # Labels come from whichever method is first in the dict.
    reference = all_data[next(iter(all_data))]
    labels = np.asarray(reference.meta["label"].values)
    classes, counts = np.unique(labels, return_counts=True)

    if strategy not in ("downsample", "upsample"):
        raise ValueError(f"Unknown strategy: {strategy}")
    target_n = counts.min() if strategy == "downsample" else counts.max()

    per_class = []
    for cls in classes:
        members = np.where(labels == cls)[0]
        # Replacement is only needed when upsampling a class that is too small.
        with_replacement = strategy == "upsample" and len(members) < target_n
        per_class.append(rng.choice(members, size=target_n, replace=with_replacement))

    selection = np.concatenate(per_class)
    rng.shuffle(selection)

    # The identical row selection is applied to every method.
    return {name: _subset_collated(cf, selection) for name, cf in all_data.items()}

def load_all_datasets(paths: FeaturePaths) -> Dict[str, Dict[str, Any]]:
    """
    Load all feature pickles once and cache the result in module globals.

    On the first call this loads the raw pickles, splits the combined
    dialogue/PsiloQA bundle, builds a class-balanced "balanced_<name>"
    variant next to every base dataset, and fills the merged and
    balanced-merged caches. Later calls return the cached dict directly
    (the `paths` argument is then ignored).
    """
    global _ALL_DATASETS, _ALL_DATA_MERGED, _BALANCED_ALL_DATA_MERGED
    if _ALL_DATASETS:
        return _ALL_DATASETS

    raw = load_raw_feature_pickles(paths)
    halueval_dialogue, psiloqa = split_dialogue_psiloqa(
        raw["halueval_dialogue_psiloqa"]
    )

    base = {
        "defan": raw["defan"],
        "mmlu": raw["mmlu"],
        "halueval_qa": raw["halueval_qa"],
        "halueval_summarization": raw["halueval_summarization"],
        "halueval_dialogue": halueval_dialogue,
        "psiloqa": psiloqa,
    }

    # Each base dataset is immediately followed by its balanced variant.
    datasets = {}
    for name, features in base.items():
        datasets[name] = features
        datasets[f"balanced_{name}"] = balance_all_data(features)

    _ALL_DATASETS = datasets
    _ALL_DATA_MERGED = merge_feature_dicts(list(base.values()))
    _BALANCED_ALL_DATA_MERGED = balance_all_data(_ALL_DATA_MERGED)
    print("Datasets loaded. Keys:", list(datasets.keys()))
    return datasets

def get_dataset(name: str, paths: FeaturePaths, isBalanced=False) -> Dict[str, Any]:
    """
    Return the {method -> features} dict for `name`.

    `name` is any key produced by `load_all_datasets` (a base dataset or
    its "balanced_<name>" variant), or one of the merged sets "all_data" /
    "balanced_all_data". When `isBalanced` is true the selected dataset is
    additionally passed through `balance_all_data`.

    Raises KeyError for an unknown `name`.
    """
    global _ALL_DATA_MERGED, _BALANCED_ALL_DATA_MERGED
    datasets = load_all_datasets(paths)

    if name in ("all_data", "balanced_all_data"):
        if _ALL_DATA_MERGED is None:
            # Rebuild the merge from the base datasets only: merging every
            # value would also include the "balanced_*" copies and
            # duplicate rows.
            base = [
                features
                for key, features in datasets.items()
                if not key.startswith("balanced_")
            ]
            _ALL_DATA_MERGED = merge_feature_dicts(base)

        if name == "all_data":
            dataset = _ALL_DATA_MERGED
        else:
            if _BALANCED_ALL_DATA_MERGED is None:
                _BALANCED_ALL_DATA_MERGED = balance_all_data(_ALL_DATA_MERGED)
            dataset = _BALANCED_ALL_DATA_MERGED
    else:
        dataset = datasets[name]

    if isBalanced:
        dataset = balance_all_data(dataset)

    return dataset

In [None]:
!pip uninstall -y numpy pandas
!pip install --no-cache-dir numpy pandas

Found existing installation: numpy 1.26.4
Uninstalling numpy-1.26.4:
  Successfully uninstalled numpy-1.26.4
Found existing installation: pandas 2.3.3
Uninstalling pandas-2.3.3:
  Successfully uninstalled pandas-2.3.3
Collecting numpy
  Downloading numpy-2.3.5-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (62 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m62.1/62.1 kB[0m [31m141.0 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting pandas
  Downloading pandas-2.3.3-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl.metadata (91 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m91.2/91.2 kB[0m [31m341.7 MB/s[0m eta [36m0:00:00[0m
Downloading numpy-2.3.5-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (16.6 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m16.6/16.6 MB[0m [31m313.0 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading pandas-2.3.3-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_2

In [None]:
# Load (and cache) every dataset from the default FeaturePaths root.
datasets = load_all_datasets(FeaturePaths())

Loading raw feature pickles from /content/drive/MyDrive/halu_features
Found dialogue sub-datasets: ['halueval/dialogue' 'psiloqa']
Datasets loaded. Keys: ['defan', 'balanced_defan', 'mmlu', 'balanced_mmlu', 'halueval_qa', 'balanced_halueval_qa', 'halueval_summarization', 'balanced_halueval_summarization', 'halueval_dialogue', 'balanced_halueval_dialogue', 'psiloqa', 'balanced_psiloqa']


In [None]:
!unzip "/content/drive/MyDrive/halu_features/models/*.zip" -d "/content/models"

Archive:  /content/drive/MyDrive/halu_features/models/models_all_data.zip
   creating: /content/models/models_all_data/
  inflating: /content/models/models_all_data/attention_model.pth  
  inflating: /content/models/models_all_data/lap_eigvals_model.pth  
  inflating: /content/models/models_all_data/factcheckmate_model.pth  
  inflating: /content/models/models_all_data/icr_probe_model.pth  
  inflating: /content/models/models_all_data/llm_check_model.pth  
   creating: /content/models/models_all_data/.ipynb_checkpoints/
  inflating: /content/models/models_all_data/attn_and_hiddn_model.pth  

Archive:  /content/drive/MyDrive/halu_features/models/models_defan.zip
   creating: /content/models/models_defan/
  inflating: /content/models/models_defan/attention_model.pth  
  inflating: /content/models/models_defan/lap_eigvals_model.pth  
  inflating: /content/models/models_defan/factcheckmate_model.pth  
  inflating: /content/models/models_defan/icr_probe_model.pth  
  inflating: /content/mod

# Generic metric helpers

In [None]:
# ============================================
# Cell 4 – Generic metric helpers
# ============================================
def compute_binary_metrics(
    y_true: np.ndarray,
    y_prob: np.ndarray,
    threshold: float = 0.5,
) -> Dict[str, Any]:
    """
    Compute binary classification metrics for probabilities at a threshold.

    `y_true` is cast to int and `y_prob` to float; a sample is predicted
    positive when `y_prob >= threshold`.

    Returns a dict with: "auroc", "ap" (both NaN when sklearn rejects the
    input with ValueError, e.g. a single class in `y_true`), "acc", "f1",
    "precision", "recall" (the last three are 0 when undefined),
    "threshold", the 2x2 confusion matrix "cm" with rows/columns ordered
    [0, 1], its cells "tn"/"fp"/"fn"/"tp", and "pred_pos_rate".
    """
    y_true = np.asarray(y_true).astype(int)
    y_prob = np.asarray(y_prob).astype(float)
    y_pred = (y_prob >= threshold).astype(int)

    # Ranking metrics are undefined for degenerate inputs; report NaN then.
    try:
        auc = roc_auc_score(y_true, y_prob)
    except ValueError:
        auc = float("nan")

    try:
        ap = average_precision_score(y_true, y_prob)
    except ValueError:
        ap = float("nan")

    acc = accuracy_score(y_true, y_pred)
    f1 = f1_score(y_true, y_pred, zero_division=0)
    prec = precision_score(y_true, y_pred, zero_division=0)
    rec = recall_score(y_true, y_pred, zero_division=0)

    # Pin the label order so the matrix is always 2x2 and tn/fp/fn/tp land
    # in the right cells even when only one class occurs in the data.
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    tn, fp, fn, tp = cm.ravel()

    return {
        "auroc": auc,
        "ap": ap,
        "acc": acc,
        "f1": f1,
        "precision": prec,
        "recall": rec,
        "threshold": float(threshold),
        "cm": cm,
        "tn": int(tn),
        "fp": int(fp),
        "fn": int(fn),
        "tp": int(tp),
        # max(..., 1) avoids division by zero for empty input.
        "pred_pos_rate": float((fp + tp) / max(len(y_true), 1)),
    }

def sweep_thresholds(
    y_true: np.ndarray,
    y_prob: np.ndarray,
    thresholds: Optional[np.ndarray] = None,
    metric: str = "acc",
) -> Dict[str, Any]:
    """
    Evaluate `compute_binary_metrics` at every threshold and return the
    metrics dict with the highest `metric` value.

    When `thresholds` is None, 13 evenly spaced values from 0.2 to 0.8 are
    used. On ties the earliest threshold in the grid wins.
    """
    grid = np.linspace(0.2, 0.8, 13) if thresholds is None else thresholds
    candidates = [
        compute_binary_metrics(y_true, y_prob, threshold=float(t)) for t in grid
    ]
    return max(candidates, key=lambda m: m.get(metric, float("-inf")))


# FactCheckmate

In [None]:
import torch
from torch import nn
from sklearn.model_selection import train_test_split

# ============================================
# Cell 5 – FactCheckmate model (MLP) – train & evaluate
# ============================================
@dataclass
class FactCheckmateConfig:
    """Hyper-parameters read by `train_factcheckmate_for_dataset`."""
    # Width of the MLP's hidden layer.
    hidden_dim: int = 256
    # Dropout probability inside the MLP.
    dropout: float = 0.1
    # Adam learning rate.
    lr: float = 1e-4
    # Training batch size (val/test loaders use twice this).
    batch_size: int = 256
    # Upper bound on training epochs.
    max_epochs: int = 50
    # Epochs without validation-AUROC improvement before early stopping.
    patience: int = 6
    # Decision-threshold grid: np.linspace(threshold_min, threshold_max, threshold_steps).
    threshold_min: float = 0.2
    threshold_max: float = 0.8
    threshold_steps: int = 13
    # Seed for global RNGs and the train/val/test split.
    seed: int = SEED

class FactCheckmateMLP(nn.Module):
    """One-hidden-layer MLP (Linear -> ReLU -> Dropout -> Linear) emitting one logit per row."""

    def __init__(self, in_dim: int, hidden_dim: int = 256, dropout: float = 0.1):
        super().__init__()
        # Layer order fixes the state_dict keys (net.0.*, net.3.*).
        layers = [
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout),
            nn.Linear(hidden_dim, 1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        logits = self.net(x)
        # Drop the trailing size-1 dimension so the output holds raw logits.
        return logits.squeeze(-1)

class NPDataset(torch.utils.data.Dataset):
    """Map-style dataset over array-likes `X` and `y`, both stored as float32 tensors."""

    def __init__(self, X, y):
        features = np.asarray(X, dtype=np.float32)
        targets = np.asarray(y, dtype=np.float32)
        self.X = torch.from_numpy(features)
        self.y = torch.from_numpy(targets).float()

    def __len__(self):
        # Number of rows in X.
        return self.X.size(0)

    def __getitem__(self, idx):
        sample = self.X[idx]
        target = self.y[idx]
        return sample, target

def _factcheckmate_xy(dataset: Dict[str, Any]) -> Tuple[np.ndarray, np.ndarray]:
    """
    Extract FactCheckmate inputs from `dataset["factcheckmate"]`.

    Returns the "fcm_pooled_ffn" tensor as a float32 NumPy array and the
    meta "label" column cast to int.
    """
    features = dataset["factcheckmate"]
    pooled = features.tensors["fcm_pooled_ffn"].detach().cpu().numpy()
    X_np = pooled.astype("float32")
    y_np = features.meta["label"].astype(int).to_numpy()
    return X_np, y_np

def _balanced_train_val_test_split(
    X_np: np.ndarray,
    y_np: np.ndarray,
    seed: int = SEED,
) -> Tuple[np.ndarray, ...]:
    """
    Stratified 70/15/15 split whose training part is downsampled to equal
    numbers of label-1 and label-0 rows.

    Only the training split is balanced; validation and test keep the
    stratified class ratio. Returns
    `(X_train_b, y_train_b, X_val, y_val, X_test, y_test)`.
    """
    X_train, X_holdout, y_train, y_holdout = train_test_split(
        X_np,
        y_np,
        test_size=0.30,
        random_state=seed,
        stratify=y_np,
    )
    # Halve the 30% hold-out into validation and test.
    X_val, X_test, y_val, y_test = train_test_split(
        X_holdout,
        y_holdout,
        test_size=0.50,
        random_state=seed,
        stratify=y_holdout,
    )

    # Positives first, then negatives: the draw order fixes the RNG stream.
    by_class = [np.where(y_train == label)[0] for label in (1, 0)]
    m = min(len(rows) for rows in by_class)
    rng = np.random.default_rng(seed)
    keep = np.concatenate(
        [rng.choice(rows, size=m, replace=False) for rows in by_class]
    )
    rng.shuffle(keep)

    return X_train[keep], y_train[keep], X_val, y_val, X_test, y_test

def _evaluate_factcheckmate_loader(
    model: nn.Module,
    loader: torch.utils.data.DataLoader,
    threshold: float = 0.5,
) -> Tuple[Dict[str, Any], np.ndarray]:
    """
    Run `model` over `loader` in eval mode and score its sigmoid outputs.

    Returns `(metrics, y_prob)`, where `metrics` comes from
    `compute_binary_metrics` at `threshold` and `y_prob` holds the
    probabilities in loader order.
    """
    model.eval()
    label_batches, prob_batches = [], []
    with torch.inference_mode():
        for features, targets in loader:
            logits = model(features.to(DEVICE))
            prob_batches.append(torch.sigmoid(logits).detach().cpu().numpy())
            label_batches.append(targets.numpy())

    y_true = np.concatenate(label_batches).astype(int)
    y_prob = np.concatenate(prob_batches).astype(float)
    return compute_binary_metrics(y_true, y_prob, threshold=threshold), y_prob

def _factcheckmate_model_path(dataset_name: str) -> str:
    """
    Resolve the FactCheckmate checkpoint path for `dataset_name`.

    The first existing path among the current location and two legacy
    layouts is returned. If none exists, the current location is returned
    after making sure its parent directory exists.
    """
    default_path = os.path.join(MODELS_ROOT, dataset_name, "factcheckmate_model.pth")
    legacy_paths = [
        f"models_{dataset_name}/factcheckmate_model.pth",          # legacy
        f"models/models_{dataset_name}/factcheckmate_model.pth",   # legacy
    ]

    existing = next(
        (p for p in [default_path, *legacy_paths] if os.path.exists(p)),
        None,
    )
    if existing is not None:
        return existing

    # Nothing on disk yet: fall back to the new location.
    os.makedirs(os.path.dirname(default_path), exist_ok=True)
    return default_path

def train_factcheckmate_for_dataset(
    dataset: Dict[str, Any],
    dataset_name: str,
    cfg: FactCheckmateConfig,
) -> Dict[str, Any]:
    """
    Trains the FactCheckmate MLP on fcm_pooled_ffn for a single dataset.

    The data is split into a class-balanced training set plus validation
    and test sets. After every epoch the decision threshold with the best
    validation accuracy is picked from the configured grid; the epoch with
    the highest validation AUROC is kept as the checkpoint, and training
    stops early after `cfg.patience` epochs without AUROC improvement. The
    best checkpoint is saved to disk and evaluated on the test split at
    its own threshold.

    Returns
    -------
    dict
        "dataset", "config" (cfg as a dict), "best_state" (the saved
        checkpoint dict) and "test_metrics".

    Raises
    ------
    ValueError
        If no epoch was run (cfg.max_epochs < 1), so there is nothing to save.
    """
    set_global_seed(cfg.seed)

    X_np, y_np = _factcheckmate_xy(dataset)
    print(f"[{dataset_name}] FactCheckmate X shape: {X_np.shape}, y shape: {y_np.shape}")
    print("Positives:", int(y_np.sum()), "Negatives:", int((1 - y_np).sum()))

    X_train_b, y_train_b, X_val, y_val, X_test, y_test = _balanced_train_val_test_split(
        X_np, y_np, seed=cfg.seed
    )

    def _make_loader(X, y, batch_size, shuffle):
        # All three loaders share worker / pinning settings.
        return torch.utils.data.DataLoader(
            NPDataset(X, y),
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=2,
            pin_memory=True,
        )

    train_loader = _make_loader(X_train_b, y_train_b, cfg.batch_size, True)
    val_loader = _make_loader(X_val, y_val, 2 * cfg.batch_size, False)
    test_loader = _make_loader(X_test, y_test, 2 * cfg.batch_size, False)

    in_dim = X_np.shape[1]
    model = FactCheckmateMLP(
        in_dim=in_dim,
        hidden_dim=cfg.hidden_dim,
        dropout=cfg.dropout,
    ).to(DEVICE)

    criterion = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg.lr)

    best_val_auc = -1
    best_state = None
    no_improve = 0

    ts = np.linspace(cfg.threshold_min, cfg.threshold_max, cfg.threshold_steps)
    # The validation loader is not shuffled, so its labels are y_val in order.
    y_true_val = np.asarray(y_val).astype(int)

    for epoch in range(1, cfg.max_epochs + 1):
        model.train()
        epoch_loss = 0.0
        for xb, yb in train_loader:
            xb = xb.to(DEVICE)
            # Targets are already 1-D; squeezing would turn a final batch of
            # size 1 into a 0-d tensor and break the loss shape check.
            yb = yb.to(DEVICE)
            optimizer.zero_grad(set_to_none=True)
            logits = model(xb)
            loss = criterion(logits, yb)
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            epoch_loss += loss.item() * xb.size(0)
        epoch_loss /= len(train_loader.dataset)

        # Evaluate at 0.5 once, then sweep thresholds on the same
        # probabilities instead of re-running the model.
        val_metrics_05, y_prob_val = _evaluate_factcheckmate_loader(
            model, val_loader, threshold=0.5
        )
        val_metrics_star = sweep_thresholds(
            y_true_val, y_prob_val, thresholds=ts, metric="acc"
        )
        t_star = val_metrics_star["threshold"]

        print(
            f"[{dataset_name}][Epoch {epoch:02d}] "
            f"loss={epoch_loss:.4f} "
            f"val_auc@0.5={val_metrics_05['auroc']:.4f} "
            f"val_acc@t*={val_metrics_star['acc']:.4f} "
            f"t*={t_star:.2f}"
        )

        val_auc = val_metrics_star["auroc"]
        # Always keep the first epoch so a checkpoint exists even when the
        # AUROC is NaN; a later finite AUROC replaces a NaN best.
        improved = (
            best_state is None
            or val_auc > best_val_auc
            or (np.isnan(best_val_auc) and not np.isnan(val_auc))
        )
        if improved:
            best_val_auc = val_auc
            best_state = {
                "epoch": epoch,
                # clone(): on CPU, .cpu() returns the live parameter tensor,
                # which later optimizer steps would silently overwrite.
                "state_dict": {
                    k: v.detach().cpu().clone()
                    for k, v in model.state_dict().items()
                },
                "threshold": t_star,
                "in_dim": in_dim,
                "hidden_dim": cfg.hidden_dim,
                "dropout": cfg.dropout,
                "lr": cfg.lr,
                "seed": cfg.seed,
            }
            no_improve = 0
        else:
            no_improve += 1
            if no_improve >= cfg.patience:
                print(
                    f"Early stopping at epoch {epoch}. "
                    f"Best val auc={best_val_auc:.4f} "
                    f"with t*={best_state['threshold']:.2f}"
                )
                break

    if best_state is None:
        raise ValueError(
            "No training epoch was run (cfg.max_epochs must be >= 1); nothing to save."
        )

    # Save best model
    model_path = _factcheckmate_model_path(dataset_name)
    os.makedirs(os.path.dirname(model_path), exist_ok=True)
    torch.save(best_state, model_path)
    print("Saved FactCheckmate model to", model_path)

    # Evaluate best model on test split
    ckpt = best_state
    best_model = FactCheckmateMLP(
        in_dim=ckpt["in_dim"],
        hidden_dim=ckpt.get("hidden_dim", 256),
        dropout=ckpt.get("dropout", 0.1),
    ).to(DEVICE)
    best_model.load_state_dict(ckpt["state_dict"])
    test_metrics, _ = _evaluate_factcheckmate_loader(
        best_model, test_loader, threshold=ckpt["threshold"]
    )
    print("Test metrics (FactCheckmate):", test_metrics)

    return {
        "dataset": dataset_name,
        "config": asdict(cfg),
        "best_state": ckpt,
        "test_metrics": test_metrics,
    }

def load_factcheckmate_model(dataset_name: str) -> Tuple[FactCheckmateMLP, Dict[str, Any]]:
    """
    Load the saved FactCheckmate checkpoint for `dataset_name`.

    Returns `(model, ckpt)`: the MLP rebuilt from the checkpoint's
    dimensions (falling back to hidden_dim=256 / dropout=0.1 when those
    keys are absent), moved to DEVICE and set to eval mode, plus the raw
    checkpoint dict.
    """
    ckpt = torch.load(_factcheckmate_model_path(dataset_name), map_location=DEVICE)
    mlp_kwargs = {
        "in_dim": ckpt["in_dim"],
        "hidden_dim": ckpt.get("hidden_dim", 256),
        "dropout": ckpt.get("dropout", 0.1),
    }
    model = FactCheckmateMLP(**mlp_kwargs).to(DEVICE)
    model.load_state_dict(ckpt["state_dict"])
    model.eval()
    return model, ckpt

def evaluate_factcheckmate_on_dataset(
    model: FactCheckmateMLP,
    dataset: Dict[str, Any],
    batch_size: int = 256,
    threshold: Optional[float] = None,
) -> Dict[str, Any]:
    """
    Score `model` on every row of `dataset` and return the metrics dict.

    A `threshold` of None means 0.5 (the checkpoint's tuned threshold is
    not looked up here; pass it explicitly if wanted).
    """
    X_np, y_np = _factcheckmate_xy(dataset)
    loader = torch.utils.data.DataLoader(
        NPDataset(X_np, y_np),
        batch_size=batch_size,
        shuffle=False,
        num_workers=2,
        pin_memory=True,
    )
    effective_threshold = 0.5 if threshold is None else threshold
    metrics, _ = _evaluate_factcheckmate_loader(
        model, loader, threshold=effective_threshold
    )
    return metrics

# LapEigvals

In [None]:
# ---------- Laplacian Eigenvalues ----------

@dataclass
class LapEigvalsConfig:
    """Settings read by `train_lap_eigvals_for_dataset`."""
    # Number of PCA components fitted on the training split.
    pca_components: int = 512
    # Fraction held out from training; with the defaults the two splits give 70/15/15.
    test_size: float = 0.3
    # Passed as `test_size` of the second split, i.e. the share of the hold-out that becomes the test set.
    val_fraction_of_temp: float = 0.5
    # LogisticRegression max_iter.
    logreg_max_iter: int = 2000
    # LogisticRegression C.
    logreg_C: float = 1.0
    # LogisticRegression class_weight.
    class_weight: str = "balanced"

def _lap_xy(dataset: Dict[str, Any]) -> Tuple[np.ndarray, np.ndarray]:
    """
    Extract Laplacian-eigenvalue inputs from `dataset["lap_eigvals"]`.

    Returns the "lap_eigvals_vector" tensor as a NumPy array and the meta
    "label" column as a flat int array. Non-numeric labels are replaced by
    their categorical codes.
    """
    cmf = dataset["lap_eigvals"]
    X = cmf.tensors["lap_eigvals_vector"].detach().cpu().numpy()

    labels = cmf.meta["label"]
    if np.issubdtype(labels.dtype, np.number):
        encoded = labels
    else:
        encoded = pd.Categorical(labels).codes
    y = np.asarray(encoded, dtype=int).ravel()
    return X, y

def _lap_model_path(dataset_name: str) -> str:
    """Return MODELS_ROOT/<dataset_name>/lap_eigvals_model.pth, creating the directory if needed."""
    model_dir = os.path.join(MODELS_ROOT, dataset_name)
    os.makedirs(model_dir, exist_ok=True)
    return os.path.join(model_dir, "lap_eigvals_model.pth")

def train_lap_eigvals_for_dataset(
    dataset: Dict[str, Any],
    dataset_name: str,
    cfg: LapEigvalsConfig,
) -> Dict[str, Any]:
    """
    Fit PCA + logistic regression on the Laplacian-eigenvalue features.

    The data is split (stratified) into train / validation / test. PCA is
    fitted on the training split only, the decision threshold with the
    best validation accuracy is chosen via `sweep_thresholds`, and the
    fitted PCA, classifier, threshold and config are saved to disk.

    Returns
    -------
    dict
        "dataset", "val_metrics", "test_metrics" (both at the chosen
        threshold) and "state" (the saved dict).

    Raises
    ------
    ValueError
        If features and labels differ in length or labels are not binary.
    """
    set_global_seed(SEED)
    X, y = _lap_xy(dataset)
    if X.shape[0] != y.shape[0]:
        raise ValueError(
            f"Feature/label row mismatch: {X.shape[0]} vs {y.shape[0]}"
        )
    if len(np.unique(y)) != 2:
        raise ValueError("Expected binary labels for lap_eigvals.")

    X_tr, X_tmp, y_tr, y_tmp = train_test_split(
        X, y, test_size=cfg.test_size, stratify=y, random_state=SEED
    )
    X_val, X_te, y_val, y_te = train_test_split(
        X_tmp,
        y_tmp,
        test_size=cfg.val_fraction_of_temp,
        stratify=y_tmp,
        random_state=SEED,
    )

    # PCA cannot extract more components than min(n_samples, n_features);
    # clamp so small datasets do not fail at fit time.
    n_components = min(cfg.pca_components, X_tr.shape[0], X_tr.shape[1])
    if n_components != cfg.pca_components:
        print(
            f"[{dataset_name}] LapEigvals – reducing PCA components from "
            f"{cfg.pca_components} to {n_components} to fit the training data."
        )
    pca = PCA(n_components=n_components)
    Xtr_p = pca.fit_transform(X_tr)
    Xval_p = pca.transform(X_val)
    Xte_p = pca.transform(X_te)

    clf = LogisticRegression(
        max_iter=cfg.logreg_max_iter,
        C=cfg.logreg_C,
        class_weight=cfg.class_weight,
    )
    clf.fit(Xtr_p, y_tr)

    val_prob = clf.predict_proba(Xval_p)[:, 1]
    te_prob = clf.predict_proba(Xte_p)[:, 1]

    # The sweep already returns the full validation metrics at t*.
    val_metrics = sweep_thresholds(y_val, val_prob, metric="acc")
    t_star = val_metrics["threshold"]
    test_metrics = compute_binary_metrics(y_te, te_prob, threshold=t_star)

    print(f"[{dataset_name}] LapEigvals – Val metrics:", val_metrics)
    print(f"[{dataset_name}] LapEigvals – Test metrics:", test_metrics)

    state = {
        "pca": pca,
        "clf": clf,
        "threshold": t_star,
        "config": asdict(cfg),
    }
    model_path = _lap_model_path(dataset_name)
    torch.save(state, model_path)
    print("Saved LapEigvals model to", model_path)

    return {
        "dataset": dataset_name,
        "val_metrics": val_metrics,
        "test_metrics": test_metrics,
        "state": state,
    }

def load_lap_model(dataset_name: str):
    """Load the LapEigvals state dict saved for ``dataset_name`` onto the CPU."""
    # weights_only=False unpickles arbitrary objects (the state holds fitted
    # sklearn estimators), so only load checkpoints you produced yourself.
    return torch.load(
        _lap_model_path(dataset_name),
        map_location="cpu",
        weights_only=False,
    )

def evaluate_lap_eigvals_on_dataset(
    state: Dict[str, Any],
    dataset: Dict[str, Any],
) -> Dict[str, Any]:
    """Score every sample of ``dataset`` with a saved LapEigvals state.

    Args:
        state: Dict holding the fitted ``pca``, ``clf`` and ``threshold``.
        dataset: Feature container accepted by ``_lap_xy``.

    Returns:
        The metrics dict produced by ``compute_binary_metrics`` at the stored
        threshold.
    """
    features, labels = _lap_xy(dataset)
    pca, clf, t_star = (state[key] for key in ("pca", "clf", "threshold"))

    positive_prob = clf.predict_proba(pca.transform(features))[:, 1]
    return compute_binary_metrics(labels, positive_prob, threshold=t_star)

# ICR Probe

In [None]:
# ---------- ICR Probe (logistic over layer means) ----------

@dataclass
class ICRConfig:
    """Hyper-parameters for the ICR probe (logistic regression over layer means)."""

    # Passed as ``max_iter`` to LogisticRegression.
    logreg_max_iter: int = 2000
    # Passed as ``class_weight`` to LogisticRegression.
    class_weight: str = "balanced"
    # Fraction of all samples held out by the first train_test_split.
    test_size: float = 0.3
    # Passed as ``test_size`` of the second split, i.e. the share of the
    # held-out part that becomes the test set (the rest is validation).
    val_fraction_of_temp: float = 0.5

def icr_layer_means(df: pd.DataFrame, scores_col: str = "icr_scores", label_col: str = "label"):
    """Collapse per-token ICR scores into one mean score per layer.

    Every cell of ``df[scores_col]`` is read as a 2-D ``(layers, tokens)``
    array.  NaN and +/-inf entries are replaced by 0.0 before averaging over
    the token axis.  A cell with zero tokens produces an all-zero row instead
    of NaN, so the result can always be fed to an estimator.

    Args:
        df: Frame with one row per sample.
        scores_col: Column holding the per-sample score matrices.
        label_col: Column holding the labels (cast to ``int``).

    Returns:
        ``(X, y)`` where ``X`` is a float32 array of shape
        ``(n_samples, layers)`` and ``y`` is an int array of length
        ``n_samples``.

    Raises:
        ValueError: If ``df`` has no rows, a cell is not 2-D, or the cells
            disagree on the number of layers (raised by ``np.stack``).
    """
    rows = []
    for idx, cell in enumerate(df[scores_col].tolist()):
        scores = np.asarray(cell, dtype=np.float32)
        if scores.ndim != 2:
            raise ValueError(
                f"{scores_col!r} cell {idx} must be 2-D (layers, tokens), "
                f"got shape {scores.shape}"
            )
        scores = np.nan_to_num(scores, nan=0.0, posinf=0.0, neginf=0.0)
        if scores.shape[1] == 0:
            # Mean over an empty token axis would be NaN (plus a RuntimeWarning).
            layer_means = np.zeros(scores.shape[0], dtype=np.float32)
        else:
            layer_means = scores.mean(axis=1).astype(np.float32)
        rows.append(layer_means)

    if not rows:
        raise ValueError(f"Column {scores_col!r} has no rows to build features from")

    X = np.stack(rows, axis=0)
    y = df[label_col].astype(int).to_numpy()
    return X, y

def _icr_model_path(dataset_name: str) -> str:
    """Return the ICR-probe checkpoint path for ``dataset_name``.

    The directory ``<MODELS_ROOT>/<dataset_name>`` is created if missing.
    """
    target_dir = os.path.join(MODELS_ROOT, dataset_name)
    os.makedirs(target_dir, exist_ok=True)
    return os.path.join(target_dir, "icr_probe_model.pth")

def train_icr_probe_for_dataset(
    dataset: Dict[str, Any],
    dataset_name: str,
    cfg: ICRConfig,
) -> Dict[str, Any]:
    """Fit, evaluate and save the ICR probe for one dataset.

    Features are the per-layer means from ``icr_layer_means``.  They are
    standardised with statistics of the training part only, a logistic
    regression is fitted, and the decision threshold is the one
    ``sweep_thresholds`` selects for accuracy on the validation part.

    Args:
        dataset: Mapping whose ``"icr_probe"`` entry exposes a ``meta`` frame.
        dataset_name: Used for the checkpoint location.
        cfg: Split sizes and logistic-regression settings.

    Returns:
        Dict with keys ``dataset``, ``val_metrics``, ``test_metrics`` and
        ``state`` (``mean``, ``std``, ``clf``, ``threshold``, ``config``).
    """
    features, labels = icr_layer_means(dataset["icr_probe"].meta)

    # Split 1: training part vs. held-out part; split 2: validation vs. test.
    train_x, held_x, train_y, held_y = train_test_split(
        features, labels, test_size=cfg.test_size, stratify=labels, random_state=SEED
    )
    val_x, test_x, val_y, test_y = train_test_split(
        held_x,
        held_y,
        test_size=cfg.val_fraction_of_temp,
        stratify=held_y,
        random_state=SEED,
    )

    # Standardisation statistics come from the training part only.
    mean = train_x.mean(axis=0, keepdims=True)
    std = train_x.std(axis=0, keepdims=True)
    std[std < 1e-6] = 1.0  # leave (near-)constant features unscaled

    def _standardize(block):
        return (block - mean) / std

    clf = LogisticRegression(
        max_iter=cfg.logreg_max_iter,
        class_weight=cfg.class_weight,
    )
    clf.fit(_standardize(train_x), train_y)

    val_prob = clf.predict_proba(_standardize(val_x))[:, 1]
    te_prob = clf.predict_proba(_standardize(test_x))[:, 1]

    # Threshold is chosen on validation and re-used unchanged on test.
    t_star = sweep_thresholds(val_y, val_prob, metric="acc")["threshold"]
    val_metrics = compute_binary_metrics(val_y, val_prob, threshold=t_star)
    test_metrics = compute_binary_metrics(test_y, te_prob, threshold=t_star)

    state = {
        "mean": mean,
        "std": std,
        "clf": clf,
        "threshold": t_star,
        "config": asdict(cfg),
    }
    ckpt_path = _icr_model_path(dataset_name)
    torch.save(state, ckpt_path)
    print("Saved ICR model to", ckpt_path)

    return {
        "dataset": dataset_name,
        "val_metrics": val_metrics,
        "test_metrics": test_metrics,
        "state": state,
    }

def load_icr_model(dataset_name: str):
    """Load the ICR-probe state dict saved for ``dataset_name`` onto the CPU."""
    # weights_only=False unpickles arbitrary objects (the state holds a fitted
    # sklearn classifier), so only load checkpoints you produced yourself.
    checkpoint = _icr_model_path(dataset_name)
    return torch.load(checkpoint, map_location="cpu", weights_only=False)

def evaluate_icr_on_dataset(
    state: Dict[str, Any],
    dataset: Dict[str, Any],
) -> Dict[str, Any]:
    """Score every sample of ``dataset`` with a saved ICR-probe state.

    The features are standardised with the ``mean`` / ``std`` stored in
    ``state`` (not re-estimated on ``dataset``) and the stored threshold is
    applied.

    Returns:
        The metrics dict produced by ``compute_binary_metrics``.
    """
    features, labels = icr_layer_means(dataset["icr_probe"].meta)
    mean, std, clf, t_star = (
        state[key] for key in ("mean", "std", "clf", "threshold")
    )

    standardized = (features - mean) / std
    positive_prob = clf.predict_proba(standardized)[:, 1]
    return compute_binary_metrics(labels, positive_prob, threshold=t_star)


# LLM-Check

In [None]:
# ---------- LLM‑Check scalars (logistic pipeline) ----------

@dataclass
class LLMCheckConfig:
    """Hyper-parameters for the LLM-Check scalar detector (scaler + logistic regression)."""

    # Passed as ``max_iter`` to LogisticRegression.
    logreg_max_iter: int = 3000
    # Passed as ``C`` to LogisticRegression.
    C: float = 1.0
    # Passed as ``class_weight`` to LogisticRegression.
    class_weight: str = "balanced"
    # Fraction of samples held out by train_test_split.
    test_size: float = 0.2

def _llm_scalar_df(dataset: Dict[str, Any]) -> Tuple[pd.DataFrame, pd.Series]:
    llmc = dataset["llm_check"].scalars
    icr = dataset["icr_probe"].scalars
    features = pd.concat([llmc, icr], axis=1)
    labels = dataset["llm_check"].meta["label"]
    features = features.replace([np.inf, -np.inf], np.nan)
    features = features.fillna(features.mean())
    return features, labels

def _llm_model_path(dataset_name: str) -> str:
    """Return the LLM-Check checkpoint path for ``dataset_name``.

    The directory ``<MODELS_ROOT>/<dataset_name>`` is created if missing.
    """
    folder = os.path.join(MODELS_ROOT, dataset_name)
    os.makedirs(folder, exist_ok=True)
    return os.path.join(folder, "llm_check_model.pth")

def train_llmcheck_for_dataset(
    dataset: Dict[str, Any],
    dataset_name: str,
    cfg: LLMCheckConfig,
    val_size: Optional[float] = None,
) -> Dict[str, Any]:
    """Fit, evaluate and save the LLM-Check scalar detector for one dataset.

    The decision threshold is selected on a validation split carved out of the
    training portion, and the test split is only used for the final report.
    Selecting the threshold on the test split itself would make the reported
    test metrics optimistic.

    Args:
        dataset: Feature container accepted by ``_llm_scalar_df``.
        dataset_name: Used for the checkpoint location.
        cfg: Logistic-regression settings and ``test_size``.
        val_size: Fraction of the non-test samples used for threshold
            selection.  Defaults to ``cfg.test_size``.

    Returns:
        Dict with keys ``dataset``, ``val_metrics``, ``test_metrics`` and
        ``state`` (``pipeline``, ``threshold``, ``config``).
    """
    features, labels = _llm_scalar_df(dataset)
    if val_size is None:
        val_size = cfg.test_size

    # Split 1: keep an untouched test set.
    X_fit, X_test, y_fit, y_test = train_test_split(
        features,
        labels,
        test_size=cfg.test_size,
        stratify=labels,
        random_state=SEED,
    )
    # Split 2: validation set for threshold selection.
    X_train, X_val, y_train, y_val = train_test_split(
        X_fit,
        y_fit,
        test_size=val_size,
        stratify=y_fit,
        random_state=SEED,
    )

    pipeline = make_pipeline(
        StandardScaler(),
        LogisticRegression(
            class_weight=cfg.class_weight,
            max_iter=cfg.logreg_max_iter,
            C=cfg.C,
        ),
    )
    pipeline.fit(X_train, y_train)

    y_val_np = y_val.to_numpy()
    val_prob = pipeline.predict_proba(X_val)[:, 1]
    t_star = sweep_thresholds(y_val_np, val_prob, metric="acc")["threshold"]
    val_metrics = compute_binary_metrics(y_val_np, val_prob, threshold=t_star)

    test_prob = pipeline.predict_proba(X_test)[:, 1]
    test_metrics = compute_binary_metrics(
        y_test.to_numpy(), test_prob, threshold=t_star
    )

    state = {
        "pipeline": pipeline,
        "threshold": t_star,
        "config": asdict(cfg),
    }
    ckpt_path = _llm_model_path(dataset_name)
    torch.save(state, ckpt_path)
    print("Saved LLM‑Check model to", ckpt_path)

    return {
        "dataset": dataset_name,
        "val_metrics": val_metrics,
        "test_metrics": test_metrics,
        "state": state,
    }

def load_llmcheck_model(dataset_name: str):
    """Load the LLM-Check state dict saved for ``dataset_name`` onto the CPU."""
    # weights_only=False unpickles arbitrary objects (the state holds a fitted
    # sklearn pipeline), so only load checkpoints you produced yourself.
    checkpoint = _llm_model_path(dataset_name)
    return torch.load(checkpoint, map_location="cpu", weights_only=False)

def evaluate_llmcheck_on_dataset(
    state: Dict[str, Any],
    dataset: Dict[str, Any],
) -> Dict[str, Any]:
    """Score every sample of ``dataset`` with a saved LLM-Check state.

    Args:
        state: Dict holding the fitted ``pipeline`` and its ``threshold``.
        dataset: Feature container accepted by ``_llm_scalar_df``.

    Returns:
        The metrics dict produced by ``compute_binary_metrics`` at the stored
        threshold.
    """
    features, labels = _llm_scalar_df(dataset)
    pipeline, t_star = state["pipeline"], state["threshold"]

    positive_prob = pipeline.predict_proba(features)[:, 1]
    return compute_binary_metrics(labels.to_numpy(), positive_prob, threshold=t_star)


# Attention + Hidden

In [None]:
# ---------- Attention + Hidden scores ----------

@dataclass
class AttnHiddenConfig:
    """Hyper-parameters for the attention + hidden score detector (scaler + logistic regression)."""

    # Passed as ``max_iter`` to LogisticRegression.
    logreg_max_iter: int = 3000
    # Passed as ``C`` to LogisticRegression.
    C: float = 1.0
    # Passed as ``class_weight`` to LogisticRegression.
    class_weight: str = "balanced"
    # Fraction of samples held out by train_test_split.
    test_size: float = 0.2

def _attn_hidden_features(dataset: Dict[str, Any]) -> Tuple[np.ndarray, np.ndarray]:
    attn_scores = dataset["llm_check"].tensors["llmcheck_attn_scores"]
    hiddn_scores = dataset["llm_check"].tensors["llmcheck_hidden_scores"]
    labels = dataset["llm_check"].meta["label"].astype(int).to_numpy()

    attn_scores = np.nan_to_num(attn_scores, nan=np.nan, posinf=np.nan, neginf=np.nan)
    attn_scores = np.nan_to_num(attn_scores, nan=np.nanmean(attn_scores, axis=0))

    hiddn_scores = np.nan_to_num(hiddn_scores, nan=np.nan, posinf=np.nan, neginf=np.nan)
    hiddn_scores = np.nan_to_num(hiddn_scores, nan=np.nanmean(hiddn_scores, axis=0))

    combined = np.concatenate([attn_scores, hiddn_scores], axis=1)
    return combined.astype("float32"), labels

def _attn_model_path(dataset_name: str) -> str:
    """Return the attention + hidden checkpoint path for ``dataset_name``.

    The directory ``<MODELS_ROOT>/<dataset_name>`` is created if missing.
    """
    out_dir = os.path.join(MODELS_ROOT, dataset_name)
    os.makedirs(out_dir, exist_ok=True)
    return os.path.join(out_dir, "attn_and_hiddn_model.pth")

def train_attn_hidden_for_dataset(
    dataset: Dict[str, Any],
    dataset_name: str,
    cfg: AttnHiddenConfig,
    val_size: Optional[float] = None,
) -> Dict[str, Any]:
    """Fit, evaluate and save the attention + hidden score detector.

    The decision threshold is selected on a validation split carved out of the
    training portion, and the test split is only used for the final report.
    Selecting the threshold on the test split itself would make the reported
    test metrics optimistic.

    Args:
        dataset: Feature container accepted by ``_attn_hidden_features``.
        dataset_name: Used for the checkpoint location.
        cfg: Logistic-regression settings and ``test_size``.
        val_size: Fraction of the non-test samples used for threshold
            selection.  Defaults to ``cfg.test_size``.

    Returns:
        Dict with keys ``dataset``, ``val_metrics``, ``test_metrics`` and
        ``state`` (``pipeline``, ``threshold``, ``config``).
    """
    X, y = _attn_hidden_features(dataset)
    if val_size is None:
        val_size = cfg.test_size

    # Split 1: keep an untouched test set.
    X_fit, X_test, y_fit, y_test = train_test_split(
        X,
        y,
        test_size=cfg.test_size,
        stratify=y,
        random_state=SEED,
    )
    # Split 2: validation set for threshold selection.
    X_train, X_val, y_train, y_val = train_test_split(
        X_fit,
        y_fit,
        test_size=val_size,
        stratify=y_fit,
        random_state=SEED,
    )

    pipeline = make_pipeline(
        StandardScaler(),
        LogisticRegression(
            class_weight=cfg.class_weight,
            max_iter=cfg.logreg_max_iter,
            C=cfg.C,
        ),
    )
    pipeline.fit(X_train, y_train)

    val_prob = pipeline.predict_proba(X_val)[:, 1]
    t_star = sweep_thresholds(y_val, val_prob, metric="acc")["threshold"]
    val_metrics = compute_binary_metrics(y_val, val_prob, threshold=t_star)

    test_prob = pipeline.predict_proba(X_test)[:, 1]
    test_metrics = compute_binary_metrics(y_test, test_prob, threshold=t_star)

    state = {
        "pipeline": pipeline,
        "threshold": t_star,
        "config": asdict(cfg),
    }
    ckpt_path = _attn_model_path(dataset_name)
    torch.save(state, ckpt_path)
    print("Saved Attn+Hidden model to", ckpt_path)

    return {
        "dataset": dataset_name,
        "val_metrics": val_metrics,
        "test_metrics": test_metrics,
        "state": state,
    }

def load_attn_hidden_model(dataset_name: str):
    """Load the attention + hidden state dict saved for ``dataset_name`` onto the CPU."""
    # weights_only=False unpickles arbitrary objects (the state holds a fitted
    # sklearn pipeline), so only load checkpoints you produced yourself.
    checkpoint = _attn_model_path(dataset_name)
    return torch.load(checkpoint, map_location="cpu", weights_only=False)

def evaluate_attn_hidden_on_dataset(
    state: Dict[str, Any],
    dataset: Dict[str, Any],
) -> Dict[str, Any]:
    """Score every sample of ``dataset`` with a saved attention + hidden state.

    Args:
        state: Dict holding the fitted ``pipeline`` and its ``threshold``.
        dataset: Feature container accepted by ``_attn_hidden_features``.

    Returns:
        The metrics dict produced by ``compute_binary_metrics`` at the stored
        threshold.
    """
    features, labels = _attn_hidden_features(dataset)
    pipeline, t_star = state["pipeline"], state["threshold"]

    positive_prob = pipeline.predict_proba(features)[:, 1]
    return compute_binary_metrics(labels, positive_prob, threshold=t_star)

# Meta Detector

In [None]:
# ============================================
# Cell 7 – Meta‑ensemble over base method probabilities
# ============================================
@dataclass
class MetaConfig:
    """Settings for the meta-ensemble built on top of the base detectors."""

    # Append the cleaned LLM-Check / ICR scalar columns to the meta features.
    use_llm_scalars: bool = True
    # Append row-wise summaries of the base probabilities: mean/max/min/std plus
    # num_over_08 / num_over_06 (stored as the fraction of detectors above 0.8 / 0.6).
    use_prob_summary: bool = True  # mean/max/min/std, num_over_08, num_over_06
    # Logistic-regression meta model (train_meta_logreg).
    logreg_C: float = 1.0
    logreg_max_iter: int = 3000
    # Gradient-boosting meta model (train_meta_gb).
    gb_n_estimators: int = 300
    gb_learning_rate: float = 0.05
    gb_max_depth: int = 3
    gb_subsample: float = 0.7

def compute_base_probabilities_for_dataset(
    dataset_name: str,
    paths: FeaturePaths,
    base_model_train_source: Optional[str] = None,
) -> Dict[str, np.ndarray]:
    """Compute the positive-class probability of every base detector.

    By default the base models saved for ``dataset_name`` itself are used.
    For cross-dataset combinations pass ``base_model_train_source`` as the
    dataset the base models were trained on.

    Args:
        dataset_name: Dataset whose samples are scored.
        paths: Feature locations forwarded to ``get_dataset``.
        base_model_train_source: Dataset name under which the base models
            were saved; falls back to ``dataset_name``.

    Returns:
        Dict with the labels (``"y"``, taken from the FactCheckmate data) and
        one float32 probability vector per detector: ``fact_p``, ``lap_p``,
        ``icr_p``, ``llmc_p`` and ``attn_and_hiddn_p``.

    Raises:
        AssertionError: If the detectors disagree on the number of samples.
    """
    ds = get_dataset(dataset_name, paths)
    train_src = base_model_train_source or dataset_name

    # Load the base models trained on `train_src`.
    fcm_model, _fcm_ckpt = load_factcheckmate_model(train_src)
    lap_state = load_lap_model(train_src)
    icr_state = load_icr_model(train_src)
    llm_state = load_llmcheck_model(train_src)
    attn_state = load_attn_hidden_model(train_src)

    # --- FactCheckmate (torch model, scored in batches) ---
    # NOTE(review): assumes load_factcheckmate_model returns the model already
    # placed on DEVICE; only the inputs are moved here - confirm.
    X_fcm, y_fcm = _factcheckmate_xy(ds)
    loader = torch.utils.data.DataLoader(
        NPDataset(X_fcm, y_fcm),
        batch_size=256,
        shuffle=False,  # keep sample order aligned with the other detectors
        num_workers=2,
        pin_memory=(DEVICE == "cuda"),  # pinning only helps host-to-GPU copies
    )
    fcm_model.eval()
    fact_probs = []
    ys_fcm = []
    with torch.inference_mode():
        for xb, yb in loader:
            logits = fcm_model(xb.to(DEVICE))
            prob = torch.sigmoid(logits).detach().cpu().numpy()
            # Flatten per batch so the concatenation below always yields a 1-D
            # vector, whatever trailing shape the model emits.
            fact_probs.append(prob.reshape(-1))
            ys_fcm.append(yb.numpy().reshape(-1))
    fact_p = np.concatenate(fact_probs).astype("float32")
    y = np.concatenate(ys_fcm).astype(int)

    # --- LapEigvals: PCA projection followed by the classifier ---
    X_lap, _ = _lap_xy(ds)
    lap_p = (
        lap_state["clf"]
        .predict_proba(lap_state["pca"].transform(X_lap))[:, 1]
        .astype("float32")
    )

    # --- ICR: standardise with the statistics stored at training time ---
    X_icr, _ = icr_layer_means(ds["icr_probe"].meta)
    X_icr_n = (X_icr - icr_state["mean"]) / icr_state["std"]
    icr_p = icr_state["clf"].predict_proba(X_icr_n)[:, 1].astype("float32")

    # --- LLM-Check scalars ---
    llm_features, _ = _llm_scalar_df(ds)
    llm_p = llm_state["pipeline"].predict_proba(llm_features)[:, 1].astype("float32")

    # --- Attention + hidden scores ---
    X_ah, _ = _attn_hidden_features(ds)
    ah_p = attn_state["pipeline"].predict_proba(X_ah)[:, 1].astype("float32")

    lengths = {
        "y": len(y),
        "fact_p": len(fact_p),
        "lap_p": len(lap_p),
        "icr_p": len(icr_p),
        "llmc_p": len(llm_p),
        "attn_and_hiddn_p": len(ah_p),
    }
    assert len(set(lengths.values())) == 1, f"Sample counts differ: {lengths}"

    return {
        "y": y,
        "fact_p": fact_p,
        "lap_p": lap_p,
        "icr_p": icr_p,
        "llmc_p": llm_p,
        "attn_and_hiddn_p": ah_p,
    }

def build_meta_dataframe(
    dataset_name: str,
    paths: FeaturePaths,
    base_model_train_source: Optional[str] = None,
    cfg: Optional[MetaConfig] = None,
) -> pd.DataFrame:
    """Assemble the feature table for the meta-ensemble.

    Columns are the five base-detector probabilities and ``label``; depending
    on ``cfg`` the row-wise probability summaries and the cleaned LLM-Check /
    ICR scalar columns are appended.

    Args:
        dataset_name: Dataset whose samples are scored.
        paths: Feature locations forwarded to ``get_dataset``.
        base_model_train_source: Dataset name under which the base models
            were saved; ``None`` means ``dataset_name``.
        cfg: Feature switches.  ``None`` means a fresh ``MetaConfig()`` (a new
            instance per call instead of one instance shared through the
            signature default).

    Returns:
        One row per sample.
    """
    if cfg is None:
        cfg = MetaConfig()

    base = compute_base_probabilities_for_dataset(
        dataset_name, paths, base_model_train_source=base_model_train_source
    )

    prob_cols = ["fact_p", "lap_p", "icr_p", "llmc_p", "attn_and_hiddn_p"]
    meta_df = pd.DataFrame({col: base[col] for col in prob_cols})
    meta_df["label"] = base["y"]

    if cfg.use_prob_summary:
        probs = meta_df[prob_cols]
        meta_df["mean_p"] = probs.mean(axis=1)
        meta_df["max_p"] = probs.max(axis=1)
        meta_df["min_p"] = probs.min(axis=1)
        meta_df["std_p"] = probs.std(axis=1)
        # Despite the names these two are fractions: count / number of detectors.
        meta_df["num_over_08"] = (
            (probs > 0.8).sum(axis=1).astype("int64") / len(prob_cols)
        )
        meta_df["num_over_06"] = (
            (probs > 0.6).sum(axis=1).astype("int64") / len(prob_cols)
        )

    if cfg.use_llm_scalars:
        # Same cleaning as the LLM-Check detector itself uses.
        ds = get_dataset(dataset_name, paths)
        scalar_features, _ = _llm_scalar_df(ds)
        meta_df = pd.concat(
            [meta_df, scalar_features.reset_index(drop=True)], axis=1
        )

    return meta_df

def train_meta_logreg(
    meta_df: pd.DataFrame,
    cfg: MetaConfig,
    test_size: float = 0.2,
    val_size: float = 0.2,
) -> Tuple[Any, Dict[str, Any]]:
    """Train the logistic-regression meta model on a meta feature table.

    The decision threshold is selected (for F1) on a validation split carved
    out of the training portion, and the test split is only used for the final
    report.  Selecting the threshold on the test split itself would make the
    reported metrics optimistic.

    Args:
        meta_df: Feature table with a ``label`` column; all other columns are
            used as features.
        cfg: Supplies ``logreg_C`` and ``logreg_max_iter``.
        test_size: Fraction of all rows held out for the final report.
        val_size: Fraction of the remaining rows used for threshold selection.

    Returns:
        ``(state, test_metrics)`` where ``state`` holds ``pipeline``,
        ``threshold``, ``config`` and ``val_metrics``.
    """
    y = meta_df["label"].astype(int).to_numpy()
    X = meta_df.drop(columns=["label"])

    # Split 1: keep an untouched test set.
    X_fit, X_test, y_fit, y_test = train_test_split(
        X,
        y,
        test_size=test_size,
        stratify=y,
        random_state=SEED,
    )
    # Split 2: validation set for threshold selection.
    X_train, X_val, y_train, y_val = train_test_split(
        X_fit,
        y_fit,
        test_size=val_size,
        stratify=y_fit,
        random_state=SEED,
    )

    pipeline = make_pipeline(
        StandardScaler(),
        LogisticRegression(
            penalty="l2",
            C=cfg.logreg_C,
            class_weight="balanced",
            max_iter=cfg.logreg_max_iter,
            solver="lbfgs",
        ),
    )
    pipeline.fit(X_train, y_train)

    val_prob = pipeline.predict_proba(X_val)[:, 1]
    t_star = sweep_thresholds(y_val, val_prob, metric="f1")["threshold"]
    val_metrics = compute_binary_metrics(y_val, val_prob, threshold=t_star)

    test_prob = pipeline.predict_proba(X_test)[:, 1]
    metrics = compute_binary_metrics(y_test, test_prob, threshold=t_star)

    state = {
        "pipeline": pipeline,
        "threshold": t_star,
        "config": asdict(cfg),
        "val_metrics": val_metrics,
    }
    return state, metrics

def train_meta_gb(
    meta_df: pd.DataFrame,
    cfg: MetaConfig,
    test_size: float = 0.2,
    val_size: float = 0.2,
) -> Tuple[GradientBoostingClassifier, Dict[str, Any]]:
    """Train the gradient-boosting meta model on a meta feature table.

    The decision threshold is selected (for F1) on a validation split carved
    out of the training portion, and the test split is only used for the final
    report.  Selecting the threshold on the test split itself would make the
    reported metrics optimistic.

    Args:
        meta_df: Feature table with a ``label`` column; all other columns are
            used as float32 features.
        cfg: Supplies the ``gb_*`` hyper-parameters.
        test_size: Fraction of all rows held out for the final report.
        val_size: Fraction of the remaining rows used for threshold selection.

    Returns:
        ``(model, info)`` where ``info`` holds ``threshold``, ``metrics``
        (test split) and ``val_metrics``.
    """
    y = meta_df["label"].astype(int).to_numpy()
    X = meta_df.drop(columns=["label"]).to_numpy().astype("float32")

    # Split 1: keep an untouched test set.
    X_fit, X_test, y_fit, y_test = train_test_split(
        X,
        y,
        test_size=test_size,
        stratify=y,
        random_state=SEED,
    )
    # Split 2: validation set for threshold selection.
    X_train, X_val, y_train, y_val = train_test_split(
        X_fit,
        y_fit,
        test_size=val_size,
        stratify=y_fit,
        random_state=SEED,
    )

    gb = GradientBoostingClassifier(
        n_estimators=cfg.gb_n_estimators,
        learning_rate=cfg.gb_learning_rate,
        max_depth=cfg.gb_max_depth,
        subsample=cfg.gb_subsample,
        random_state=SEED,
    )
    gb.fit(X_train, y_train)

    val_prob = gb.predict_proba(X_val)[:, 1]
    t_star = sweep_thresholds(y_val, val_prob, metric="f1")["threshold"]
    val_metrics = compute_binary_metrics(y_val, val_prob, threshold=t_star)

    test_prob = gb.predict_proba(X_test)[:, 1]
    metrics = compute_binary_metrics(y_test, test_prob, threshold=t_star)

    return gb, {"threshold": t_star, "metrics": metrics, "val_metrics": val_metrics}


# Cross‑dataset utilities


In [None]:
# ============================================
# Cell 8 – Cross‑dataset utilities
# ============================================
def evaluate_base_method(
    method: str,
    train_dataset: str,
    test_dataset: str,
    paths: FeaturePaths,
) -> Dict[str, Any]:
    """Evaluate a saved base detector on another dataset.

    Args:
        method: One of the names in ``BASE_METHODS``.
        train_dataset: Dataset name under which the model was saved.
        test_dataset: Dataset to load and score.
        paths: Feature locations forwarded to ``get_dataset``.

    Returns:
        The metrics dict of the method-specific ``evaluate_*`` function.

    Raises:
        ValueError: If ``method`` is not a known base method (raised after
            the test dataset has been loaded).
    """
    target = get_dataset(test_dataset, paths)

    # FactCheckmate is the odd one out: its loader returns (model, checkpoint)
    # and the threshold is read from the checkpoint.
    if method == "factcheckmate":
        fcm_model, fcm_ckpt = load_factcheckmate_model(train_dataset)
        return evaluate_factcheckmate_on_dataset(
            fcm_model, target, threshold=fcm_ckpt.get("threshold", 0.5)
        )

    # Every other method follows "load state, evaluate state on dataset".
    state_based = {
        "lap_eigvals": (load_lap_model, evaluate_lap_eigvals_on_dataset),
        "icr_probe": (load_icr_model, evaluate_icr_on_dataset),
        "llm_check": (load_llmcheck_model, evaluate_llmcheck_on_dataset),
        "attn_and_hiddn": (load_attn_hidden_model, evaluate_attn_hidden_on_dataset),
    }
    if method not in state_based:
        raise ValueError(f"Unknown base method: {method}")

    load_state, evaluate_state = state_based[method]
    return evaluate_state(load_state(train_dataset), target)

def run_cross_dataset_base_experiments(
    paths: FeaturePaths,
    datasets: Optional[List[str]] = None,
    methods: Optional[List[str]] = None,
    skip_missing: bool = False,
) -> Dict[str, Dict[str, Dict[str, Dict[str, Any]]]]:
    """Evaluate every base method on every (train, test) pair with train != test.

    Args:
        paths: Feature locations forwarded to ``evaluate_base_method``.
        datasets: Dataset names to cross; defaults to ``DATASET_NAMES``.
        methods: Base methods to run; defaults to ``BASE_METHODS``.
        skip_missing: When True, a pair whose model or feature file is missing
            (``FileNotFoundError``) is reported and skipped, so one absent
            checkpoint does not discard the results already computed.  The
            default keeps the original behaviour of raising.

    Returns:
        Nested dict ``results[method][train_ds][test_ds] -> metrics dict``.
        Skipped pairs have no entry.

    Raises:
        FileNotFoundError: If a required file is missing and ``skip_missing``
            is False.
    """
    if datasets is None:
        datasets = DATASET_NAMES
    if methods is None:
        methods = BASE_METHODS

    results: Dict[str, Dict[str, Dict[str, Dict[str, Any]]]] = {}

    for method in methods:
        per_train = results[method] = {}
        for train_ds in datasets:
            per_test = per_train[train_ds] = {}
            for test_ds in datasets:
                if train_ds == test_ds:
                    continue
                print(f"Evaluating {method}: train={train_ds}, test={test_ds}")
                try:
                    per_test[test_ds] = evaluate_base_method(
                        method, train_ds, test_ds, paths
                    )
                except FileNotFoundError as err:
                    if not skip_missing:
                        raise
                    print(f"  Skipped (missing file): {err}")

    return results

def cross_results_to_dataframe(
    results: Dict[str, Dict[str, Dict[str, Dict[str, Any]]]]
) -> pd.DataFrame:
    """Flatten nested cross-dataset results into one row per (method, train, test).

    Args:
        results: ``results[method][train_ds][test_ds] -> metrics dict`` as
            returned by ``run_cross_dataset_base_experiments``.

    Returns:
        Frame with the identifying columns, the copied metric columns,
        ``total`` (``tn + fp + fn + tp``) and the display names
        ``train_pretty`` / ``test_pretty`` (NaN for names missing from
        ``PRETTY_DATASET_NAMES``).  Empty ``results`` give an empty frame
        with the same columns.

    Raises:
        KeyError: If a metrics dict lacks one of the copied keys.
    """
    id_cols = ["method", "train_ds", "test_ds"]
    metric_cols = [
        "auroc",
        "ap",
        "acc",
        "f1",
        "precision",
        "recall",
        "threshold",
        "tn",
        "fp",
        "fn",
        "tp",
        "pred_pos_rate",
    ]

    rows = []
    for method, train_dict in results.items():
        for train_ds, test_dict in train_dict.items():
            for test_ds, metrics in test_dict.items():
                row = {"method": method, "train_ds": train_ds, "test_ds": test_ds}
                row.update({key: metrics[key] for key in metric_cols})
                row["total"] = row["tn"] + row["fp"] + row["fn"] + row["tp"]
                rows.append(row)

    # Explicit columns keep the schema stable when `rows` is empty; without
    # them the column lookups below raise KeyError.
    df = pd.DataFrame(rows, columns=id_cols + metric_cols + ["total"])
    df["train_pretty"] = df["train_ds"].map(PRETTY_DATASET_NAMES)
    df["test_pretty"] = df["test_ds"].map(PRETTY_DATASET_NAMES)
    return df

# Experiments

In [None]:
# ---- Example: load the combined datasets used by the cells below ----
# (This cell only loads data; the base detectors are trained in the next cell.)
paths = FeaturePaths()  # adjust root if needed
all_data = get_dataset("all_data", paths)
balanced_all_data = get_dataset("balanced_all_data", paths)

In [None]:
# ============================================
# Cell 9 – Example usage / experiment snippets
# ============================================

# Train every base detector on the balanced combined dataset.  Each trainer
# saves its checkpoint under models/<dataset_name>/ and returns its metrics.
dataset_name = "balanced_all_data"
# FactCheckmate
fc_cfg = FactCheckmateConfig()
factcheck_all = train_factcheckmate_for_dataset(balanced_all_data, dataset_name, fc_cfg)

# LapEigvals (PCA + logistic regression)
lap_cfg = LapEigvalsConfig()
lap_all = train_lap_eigvals_for_dataset(balanced_all_data, dataset_name, lap_cfg)

# ICR probe (logistic regression over layer means)
icr_cfg = ICRConfig()
icr_all = train_icr_probe_for_dataset(balanced_all_data, dataset_name, icr_cfg)

# LLM-Check scalars
llmc_cfg = LLMCheckConfig()
llmc_all = train_llmcheck_for_dataset(balanced_all_data, dataset_name, llmc_cfg)

# Attention + hidden scores
ah_cfg = AttnHiddenConfig()
ah_all = train_attn_hidden_for_dataset(balanced_all_data, dataset_name, ah_cfg)

[balanced_all_data] FactCheckmate X shape: (78788, 2048), y shape: (78788,)
Positives: 39394 Negatives: 39394
[balanced_all_data][Epoch 01] loss=0.4452 val_auc@0.5=0.9047 val_acc@t*=0.8154 t*=0.50
[balanced_all_data][Epoch 02] loss=0.3845 val_auc@0.5=0.9109 val_acc@t*=0.8215 t*=0.55
[balanced_all_data][Epoch 03] loss=0.3750 val_auc@0.5=0.9135 val_acc@t*=0.8272 t*=0.50
[balanced_all_data][Epoch 04] loss=0.3694 val_auc@0.5=0.9151 val_acc@t*=0.8281 t*=0.45
[balanced_all_data][Epoch 05] loss=0.3659 val_auc@0.5=0.9161 val_acc@t*=0.8287 t*=0.50
[balanced_all_data][Epoch 06] loss=0.3619 val_auc@0.5=0.9166 val_acc@t*=0.8289 t*=0.45
[balanced_all_data][Epoch 07] loss=0.3597 val_auc@0.5=0.9178 val_acc@t*=0.8309 t*=0.45
[balanced_all_data][Epoch 08] loss=0.3573 val_auc@0.5=0.9186 val_acc@t*=0.8325 t*=0.45
[balanced_all_data][Epoch 09] loss=0.3540 val_auc@0.5=0.9188 val_acc@t*=0.8331 t*=0.45
[balanced_all_data][Epoch 10] loss=0.3524 val_auc@0.5=0.9192 val_acc@t*=0.8320 t*=0.45
[balanced_all_data][

In [None]:
# ---- Example: meta-ensemble on balanced_all_data ----
# `dataset_name` is set in the training cell above; the base models are the
# ones saved under "balanced_all_data".
meta_cfg = MetaConfig()
meta_df_all = build_meta_dataframe(dataset_name, paths, base_model_train_source="balanced_all_data", cfg=meta_cfg)

# Logistic-regression meta model.
meta_state_lr, meta_metrics_lr = train_meta_logreg(meta_df_all, meta_cfg)
print("Meta logistic metrics:", meta_metrics_lr)

# Gradient-boosting meta model.
gb_model, gb_info = train_meta_gb(meta_df_all, meta_cfg)
print("Meta GB metrics:", gb_info["metrics"])

Meta logistic metrics: {'auroc': np.float64(0.9236575265879289), 'ap': np.float64(0.9302076729418869), 'acc': 0.8354486609975885, 'f1': 0.8408909615266613, 'precision': 0.8139700641482538, 'recall': 0.869653509328595, 'threshold': 0.35000000000000003, 'cm': array([[6313, 1566],
       [1027, 6852]]), 'tn': 6313, 'fp': 1566, 'fn': 1027, 'tp': 6852, 'pred_pos_rate': 0.5342048483310065}
Meta GB metrics: {'auroc': np.float64(0.9244772046475696), 'ap': np.float64(0.9318459691824587), 'acc': 0.8395100901129585, 'f1': 0.8400682982356289, 'precision': 0.8371565414671036, 'recall': 0.8430003807589795, 'threshold': 0.45000000000000007, 'cm': array([[6587, 1292],
       [1237, 6642]]), 'tn': 6587, 'fp': 1292, 'fn': 1237, 'tp': 6642, 'pred_pos_rate': 0.503490290646021}


In [None]:
# ---- Example: cross-dataset base experiments ----
# Loads the checkpoints saved under models/<train_ds>/ for every dataset in
# DATASET_NAMES and every method in BASE_METHODS; a missing checkpoint raises
# FileNotFoundError.
cross_results = run_cross_dataset_base_experiments(paths)
df_cross = cross_results_to_dataframe(cross_results)
df_cross.head()

Evaluating factcheckmate: train=defan, test=mmlu
Evaluating factcheckmate: train=defan, test=halueval_qa
Evaluating factcheckmate: train=defan, test=halueval_summarization
Evaluating factcheckmate: train=defan, test=halueval_dialogue
Evaluating factcheckmate: train=defan, test=psiloqa
Evaluating factcheckmate: train=mmlu, test=defan
Evaluating factcheckmate: train=mmlu, test=halueval_qa
Evaluating factcheckmate: train=mmlu, test=halueval_summarization
Evaluating factcheckmate: train=mmlu, test=halueval_dialogue
Evaluating factcheckmate: train=mmlu, test=psiloqa
Evaluating factcheckmate: train=halueval_qa, test=defan
Evaluating factcheckmate: train=halueval_qa, test=mmlu
Evaluating factcheckmate: train=halueval_qa, test=halueval_summarization
Evaluating factcheckmate: train=halueval_qa, test=halueval_dialogue
Evaluating factcheckmate: train=halueval_qa, test=psiloqa
Evaluating factcheckmate: train=halueval_summarization, test=defan
Evaluating factcheckmate: train=halueval_summarization,

FileNotFoundError: [Errno 2] No such file or directory: 'models/defan/lap_eigvals_model.pth'

# Visualization/Plotting Helpers

In [None]:
# ============================================
# Cell 10 – Plotting helpers
# ============================================
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc, precision_recall_curve


# ---------------- Basic helpers ----------------

def plot_confusion_matrix_from_metrics(
    metrics: dict,
    class_names=("Non-hallucination", "Hallucination"),
    normalize: bool = False,
    title: str = None,
):
    """Draw the 2x2 confusion matrix stored under ``metrics["cm"]``.

    Args:
        metrics: Dict with a ``"cm"`` entry (rows = true class, columns =
            predicted class), e.g. the output of ``compute_binary_metrics``.
        class_names: Tick labels for the two classes.
        normalize: Show row-normalised rates instead of raw counts.  Rows
            that sum to zero are left as zeros.
        title: Optional axes title.
    """
    counts = np.array(metrics["cm"], dtype=float)

    shown = counts
    if normalize:
        totals = counts.sum(axis=1, keepdims=True)
        shown = counts / np.where(totals == 0, 1.0, totals)

    fig, ax = plt.subplots(figsize=(4, 4))
    image = ax.imshow(shown, aspect="equal")

    ticks = np.arange(2)
    ax.set_xticks(ticks)
    ax.set_yticks(ticks)
    ax.set_xticklabels(class_names)
    ax.set_yticklabels(class_names)
    ax.set_xlabel("Predicted")
    ax.set_ylabel("True")

    if title is not None:
        ax.set_title(title)

    # Write each cell's value at its centre (x = column, y = row).
    for row, col in np.ndindex(2, 2):
        label = f"{shown[row, col]:.2f}" if normalize else f"{int(counts[row, col])}"
        ax.text(col, row, label, ha="center", va="center")

    fig.colorbar(image, ax=ax, fraction=0.046, pad=0.04)
    plt.tight_layout()
    plt.show()


def plot_roc_pr_curves(
    y_true,
    y_prob,
    title_prefix: str = "Model",
    show_baselines: bool = True,
):
    """Plot the ROC curve and the precision-recall curve side by side.

    Args:
        y_true: Binary ground-truth labels.
        y_prob: Scores / probabilities for the positive class.
        title_prefix: Text put in front of both panel titles.
        show_baselines: Draw the chance diagonal on the ROC panel and the
            positive base rate on the PR panel.

    Both areas shown in the titles are computed with ``auc`` (trapezoidal
    rule), including the PR one.
    """
    labels = np.asarray(y_true)
    scores = np.asarray(y_prob)

    fpr, tpr, _ = roc_curve(labels, scores)
    precision, recall, _ = precision_recall_curve(labels, scores)
    roc_auc = auc(fpr, tpr)
    pr_auc = auc(recall, precision)

    fig, (roc_ax, pr_ax) = plt.subplots(1, 2, figsize=(10, 4))

    # Left panel: ROC.
    roc_ax.plot(fpr, tpr)
    if show_baselines:
        roc_ax.plot([0, 1], [0, 1], linestyle="--")
    roc_ax.set_xlabel("False Positive Rate")
    roc_ax.set_ylabel("True Positive Rate")
    roc_ax.set_title(f"{title_prefix} ROC (AUC={roc_auc:.3f})")

    # Right panel: precision-recall.
    pr_ax.plot(recall, precision)
    if show_baselines:
        pr_ax.hlines(labels.mean(), 0, 1, linestyles="--")
    pr_ax.set_xlabel("Recall")
    pr_ax.set_ylabel("Precision")
    pr_ax.set_title(f"{title_prefix} PR (AUC={pr_auc:.3f})")

    plt.tight_layout()
    plt.show()


# ---------------- Cross-dataset plots ----------------

def plot_cross_dataset_heatmap(
    df_cross: pd.DataFrame,
    metric: str = "f1",
    method: str = "factcheckmate",
    use_pretty_names: bool = True,
    annotate: bool = True,
    figsize=(7, 5),
):
    """Draw a train-dataset x test-dataset heatmap of one metric for one method.

    Args:
        df_cross: Output of ``cross_results_to_dataframe``.
        metric: Column to display, e.g. ``'f1'``, ``'acc'``, ``'auroc'``,
            ``'ap'``, ``'precision'``, ``'recall'`` or ``'pred_pos_rate'``.
        method: Value of the ``method`` column to keep.
        use_pretty_names: Use the ``*_pretty`` columns for the axes when they
            are present.
        annotate: Write each value (two decimals) into its cell; NaN cells
            stay blank.
        figsize: Figure size passed to matplotlib.

    Raises:
        ValueError: If ``df_cross`` has no rows for ``method``.
    """
    rows_for_method = df_cross[df_cross["method"] == method].copy()
    if rows_for_method.empty:
        raise ValueError(f"No rows for method={method} in df_cross")

    pretty = use_pretty_names and "train_pretty" in rows_for_method.columns
    index_col, column_col = (
        ("train_pretty", "test_pretty") if pretty else ("train_ds", "test_ds")
    )

    table = rows_for_method.pivot(index=index_col, columns=column_col, values=metric)
    n_rows, n_cols = table.shape

    fig, ax = plt.subplots(figsize=figsize)
    image = ax.imshow(table.values, aspect="auto")

    ax.set_xticks(np.arange(n_cols))
    ax.set_yticks(np.arange(n_rows))
    ax.set_xticklabels(table.columns, rotation=45, ha="right")
    ax.set_yticklabels(table.index)
    ax.set_xlabel("Test dataset")
    ax.set_ylabel("Train dataset")
    ax.set_title(f"{method} – {metric.upper()} cross-dataset")

    if annotate:
        for row, col in np.ndindex(n_rows, n_cols):
            value = table.values[row, col]
            label = "" if np.isnan(value) else f"{value:.2f}"
            ax.text(col, row, label, ha="center", va="center")

    fig.colorbar(image, ax=ax, fraction=0.046, pad=0.04)
    plt.tight_layout()
    plt.show()


def plot_cross_dataset_method_comparison(
    df_cross: pd.DataFrame,
    metric: str = "f1",
    train_ds: str = None,
    test_ds: str = None,
    use_pretty_names: bool = True,
    figsize=(7, 4),
):
    """Bar chart of the mean ``metric`` per method.

    Rows are optionally restricted to one training and/or one test dataset;
    whatever remains is averaged per method.

    Args:
        df_cross: Output of ``cross_results_to_dataframe``.
        metric: Column to average and plot.
        train_ds: Keep only rows with this ``train_ds`` (``None`` = all).
        test_ds: Keep only rows with this ``test_ds`` (``None`` = all).
        use_pretty_names: Use ``PRETTY_DATASET_NAMES`` in the title when both
            names are known.
        figsize: Figure size passed to matplotlib.

    Raises:
        ValueError: If no rows remain after filtering.
    """
    selected = df_cross.copy()
    for column, wanted in (("train_ds", train_ds), ("test_ds", test_ds)):
        if wanted is not None:
            selected = selected[selected[column] == wanted]

    if selected.empty:
        raise ValueError("No rows left after filtering by train_ds/test_ds.")

    per_method = selected.groupby("method")[metric].mean().reset_index()

    fig, ax = plt.subplots(figsize=figsize)
    positions = np.arange(len(per_method))
    ax.bar(positions, per_method[metric])
    ax.set_xticks(positions)
    ax.set_xticklabels(per_method["method"], rotation=45, ha="right")
    ax.set_ylabel(metric.upper())

    if train_ds is None or test_ds is None:
        title = f"{metric.upper()} – mean across available pairs"
    else:
        train_label, test_label = train_ds, test_ds
        if (
            use_pretty_names
            and train_ds in PRETTY_DATASET_NAMES
            and test_ds in PRETTY_DATASET_NAMES
        ):
            train_label = PRETTY_DATASET_NAMES[train_ds]
            test_label = PRETTY_DATASET_NAMES[test_ds]
        title = f"{metric.upper()} – train={train_label}, test={test_label}"

    ax.set_title(title)
    ax.set_ylim(0, 1.0)
    plt.tight_layout()
    plt.show()


# ---------------- Base vs meta ensemble ----------------

def plot_base_vs_meta_bar(
    base_metrics: dict,
    meta_metrics: dict,
    metric: str = "f1",
    title: str = None,
    figsize=(7, 4),
):
    """Bar chart of one metric for every base method plus the meta model.

    Args:
        base_metrics: Mapping ``method name -> metrics dict``.
        meta_metrics: Metrics dict of the meta model; drawn last as ``"meta"``.
        metric: Key read from every metrics dict (e.g. ``'f1'``, ``'auroc'``,
            ``'acc'``).
        title: Axes title; defaults to ``"Base vs Meta – <METRIC>"``.
        figsize: Figure size passed to matplotlib.
    """
    bar_labels = [*base_metrics, "meta"]
    bar_heights = [scores[metric] for scores in base_metrics.values()]
    bar_heights.append(meta_metrics[metric])

    fig, ax = plt.subplots(figsize=figsize)
    positions = np.arange(len(bar_labels))
    ax.bar(positions, bar_heights)
    ax.set_xticks(positions)
    ax.set_xticklabels(bar_labels, rotation=45, ha="right")
    ax.set_ylabel(metric.upper())
    ax.set_ylim(0, 1.0)

    ax.set_title(f"Base vs Meta – {metric.upper()}" if title is None else title)

    plt.tight_layout()
    plt.show()
