In [None]:
"""
Matched version of your "gcifar100" script — UPDATED to match the FIRST code's
choices for: CNN model, preprocess, normalization, resize->patch, LR config, dataloaders, etc.

✅ Change in THIS revision (requested):
- WAvgL baseline is now implemented like your CIFAR-10/MVCP script:
    * Learn view weights from per-view p-values using multinomial LR on p-only features
    * Convert LR coefficients -> per-view importance via Frobenius norm per (L×L) block
    * Use weights to form weighted average of per-view p-values: P_wavg(x,y)=sum_k w_k p_k(x,y)
    * (No extra stage-3 recalibration for WAvgL, matching the other script’s baseline style)

Everything else remains as in your matched version:
- Backbone: ResNet-18 (ImageNet-pretrained optional)
- Preprocess: resize FULL image first, then split into patches (resize->patch)
- Normalization: ImageNet (default) or CIFAR-100
- Dataloader knobs: batch_size/num_workers/pin_memory/persistent_workers
- Optimizer for per-view CNN: Adam with cfg.lr
- LogisticRegression config: multinomial + lbfgs + max_iter=cfg.max_iter_lr + random_state=seed
- Still runs: CF / MinPV / Fisher / AdjF / WAvgL
- Still outputs raw CSV + aggregate CSV + old-style LaTeX mean(std) table(s)
"""

from __future__ import annotations

import os
import warnings
from dataclasses import dataclass
from typing import Dict, List, Tuple

import numpy as np
import pandas as pd
from scipy.stats import chi2

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.transforms import Resize
from torchvision.models import resnet18, ResNet18_Weights

from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

# NOTE(review): this silences *every* warning for the whole process, including
# sklearn ConvergenceWarning from the max_iter-limited LogisticRegression fits
# and (presumably, on newer sklearn) the deprecation warning for the
# `multi_class` argument used below; consider narrowing the filter.
warnings.filterwarnings("ignore")

# IMPORTANT: match ablation script default behavior
# (cuDNN autotuning off; set_all_seeds re-applies this for every simulation)
torch.backends.cudnn.benchmark = False


# =============================================================================
# Config (MATCH first code's training + preprocess defaults)
# =============================================================================

@dataclass
class MatchedCIFARConfig:
    """
    All knobs for one matched CIFAR-100 run (defaults MATCH the first code).

    Construction validates the fractions, K values, alpha and normalization
    name, so a typo fails immediately instead of surfacing after hours of
    training (or, for `normalization`, silently falling back to ImageNet
    statistics inside PatchesDataset).
    """

    # core
    alpha: float = 0.1                   # miscoverage level; sets are {y : p(y) > alpha}
    Ks: Tuple[int, ...] = (2, 3, 4, 6)   # experiments over K (number of patch views)
    num_classes: int = 100
    num_simulations: int = 4

    # training (match your first code)
    epochs_per_view: int = 150
    lr: float = 1e-3                     # Adam learning rate for each per-view CNN
    batch_size: int = 1024
    max_iter_lr: int = 2000              # max_iter for every LogisticRegression fit
    train_seed_base: int = 41            # simulation `sim` uses seed train_seed_base + sim

    # dataloader (match first code pattern)
    num_workers: int = 4

    # data split fractions (match first code)
    train_frac: float = 0.5              # share of CIFAR train used to fit the view models
    cal_frac_of_temp: float = 0.3        # share of the remainder used for per-view calibration
    fuse_train_frac_of_rest: float = 0.7 # share of what is left used to fit fusers (rest: fusion-cal)

    # ablation-style nt sweep
    nt_fracs: Tuple[float, ...] = (1.0,) # fractions of fusion-train used to fit the learned fusers

    # ResNet / preprocessing (match first code)
    pretrained_backbone: bool = True
    normalization: str = "imagenet"  # "imagenet" or "cifar100"
    full_image_size: int = 128       # resize BEFORE patching (resize->patch)

    # outputs
    out_dir: str = "matched_outputs"
    out_csv: str = "matched_cifar100_results.csv"

    def __post_init__(self) -> None:
        """Reject configurations the experiment cannot run correctly.

        Raises:
            ValueError: on an out-of-range fraction/alpha, empty or non-positive
                Ks, empty nt_fracs, fewer than one simulation, or an unknown
                normalization name.
        """
        if not 0.0 < self.alpha < 1.0:
            raise ValueError(f"alpha must be in (0, 1), got {self.alpha}")
        if self.num_simulations < 1:
            raise ValueError(f"num_simulations must be >= 1, got {self.num_simulations}")
        if not self.Ks or any(not isinstance(k, (int, np.integer)) or k < 1 for k in self.Ks):
            raise ValueError(f"Ks must be a non-empty tuple of positive ints, got {self.Ks}")
        for name in ("train_frac", "cal_frac_of_temp", "fuse_train_frac_of_rest"):
            value = getattr(self, name)
            # train_test_split needs a non-empty piece on both sides
            if not 0.0 < value < 1.0:
                raise ValueError(f"{name} must be in (0, 1), got {value}")
        if not self.nt_fracs or any(not 0.0 < f <= 1.0 for f in self.nt_fracs):
            raise ValueError(f"nt_fracs must be a non-empty tuple of values in (0, 1], got {self.nt_fracs}")
        if self.normalization.lower() not in ("imagenet", "cifar100"):
            raise ValueError(
                f"normalization must be 'imagenet' or 'cifar100', got {self.normalization!r}"
            )


# =============================================================================
# Seeds / device
# =============================================================================

# Module-level compute device read by train_model, compute_nonconformity_scores
# and per_view_pvalues_and_probs: the current CUDA device if CUDA is available,
# otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


def set_all_seeds(seed: int) -> None:
    """Seed every RNG the experiment may touch so a simulation is reproducible.

    Seeds Python's ``random`` module, NumPy's global RNG, and torch's CPU and
    (when available) CUDA generators. It also keeps cuDNN autotuning disabled,
    because benchmark mode may select different kernels from run to run.

    Args:
        seed: Integer seed applied to all generators.
    """
    import random  # local import: the file-level import block does not include it

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False


# =============================================================================
# Data (match first code: numpy in [0,1])
# =============================================================================

# keep CIFAR as raw tensors/numpy; normalize in PatchesDataset (like first code)
# NOTE(review): `transform` only applies to items fetched through the dataset
# objects; the arrays below read `.data` directly, so ToTensor is never used.
transform = transforms.Compose([transforms.ToTensor()])

# download=True fetches CIFAR-100 into ./data if it is not already there.
train_dataset = datasets.CIFAR100(root="./data", train=True, download=True, transform=transform)
test_dataset  = datasets.CIFAR100(root="./data", train=False, download=True, transform=transform)

# Raw pixel arrays scaled to float32 in [0, 1] (channels-last, which is what the
# permute in PatchesDataset.__getitem__ expects); labels become int arrays.
X_train_full = train_dataset.data.astype(np.float32) / 255.0
Y_train_full = np.array(train_dataset.targets, dtype=int)

X_test_full  = test_dataset.data.astype(np.float32) / 255.0
Y_test_full  = np.array(test_dataset.targets, dtype=int)


# =============================================================================
# Multi-view (patch) utilities (MATCH first code)
# =============================================================================

def split_image_into_k_patches(image: torch.Tensor, k: int) -> List[torch.Tensor]:
    """
    Cut a (C, H, W) image into k views, using its CURRENT height and width.

    k == 4: a 2x2 grid of (H//2, W//2) tiles in row-major order
            (top-left, top-right, bottom-left, bottom-right); when H or W is
            odd, the last row/column is not included in any tile.
    other k: k vertical strips across the width. The first W % k strips are
             one pixel wider, so together the strips cover every column.
    """
    _, height, width = image.shape

    if k != 4:
        narrow, extra = divmod(width, k)
        strips: List[torch.Tensor] = []
        left = 0
        for strip_idx in range(k):
            right = left + narrow + int(strip_idx < extra)
            strips.append(image[:, :, left:right])
            left = right
        return strips

    half_h, half_w = height // 2, width // 2
    return [
        image[:, row * half_h : (row + 1) * half_h, col * half_w : (col + 1) * half_w]
        for row in range(2)
        for col in range(2)
    ]


class PatchesDataset(torch.utils.data.Dataset):
    """
    One patch view of each image, prepared the same way as the first code.

    Per item:
      1. take the (H, W, 3) float image in [0, 1] and move channels first -> (3, H, W)
      2. resize the WHOLE image to (full_image_size, full_image_size)
      3. split it into k patches and keep patch number `view`
      4. normalize that patch with ImageNet or CIFAR-100 channel statistics

    `normalization` is compared case-insensitively. Any value other than
    "cifar100" uses the ImageNet statistics.
    """

    def __init__(
        self,
        images: np.ndarray,
        labels: np.ndarray,
        k: int,
        view: int,
        normalization: str = "imagenet",
        full_image_size: int = 128,
    ):
        self.images = images
        self.labels = labels
        self.k = k
        self.view = view

        self.resize = Resize((full_image_size, full_image_size))

        use_cifar_stats = normalization.lower() == "cifar100"
        channel_mean = (0.5071, 0.4867, 0.4408) if use_cifar_stats else (0.485, 0.456, 0.406)
        channel_std = (0.2675, 0.2565, 0.2761) if use_cifar_stats else (0.229, 0.224, 0.225)
        self.normalize = transforms.Normalize(mean=channel_mean, std=channel_std)

    def __len__(self) -> int:
        return len(self.labels)

    def __getitem__(self, idx: int):
        chw = torch.from_numpy(self.images[idx]).permute(2, 0, 1).contiguous()
        resized = self.resize(chw)  # whole image is resized before it is split
        chosen = split_image_into_k_patches(resized, self.k)[self.view]
        return self.normalize(chosen), int(self.labels[idx])


# =============================================================================
# ResNet-18 per view (MATCH first code)
# =============================================================================

class PredictorCNN(nn.Module):
    """
    ResNet-18 whose classification head is replaced by a `num_classes`-way
    linear layer.

    With `pretrained=True` the backbone starts from torchvision's
    IMAGENET1K_V1 weights; otherwise it is randomly initialized.
    """

    def __init__(self, num_classes: int = 100, pretrained: bool = True):
        super().__init__()
        self.model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1 if pretrained else None)
        # swap the backbone's default head for one with the same input width
        self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)

    def forward(self, x):
        return self.model(x)


def train_model(model: nn.Module, train_loader, num_epochs: int, lr: float):
    """
    Fit `model` with cross-entropy on `train_loader`.

    MATCHES the first code: Adam, no weight decay, no LR schedule.

    Args:
        model: network to train; moved in place to the module-level `device`.
        train_loader: iterable of (inputs, labels) batches. Labels may be a
            tensor or anything `torch.tensor` can convert to int64.
        num_epochs: number of full passes over `train_loader`.
        lr: Adam learning rate.

    Returns:
        The same `model` object, on `device` (in train mode if num_epochs > 0).
    """
    # Move to the target device BEFORE building the optimizer, as the
    # torch.optim docs recommend, so the optimizer is constructed over the
    # parameters that will actually be updated.
    model.to(device)
    crit = nn.CrossEntropyLoss()
    opt = optim.Adam(model.parameters(), lr=lr)

    for ep in range(num_epochs):
        model.train()
        for xb, yb in train_loader:
            xb = xb.to(device, non_blocking=True)
            if torch.is_tensor(yb):
                yb = yb.to(device, non_blocking=True)
            else:
                yb = torch.tensor(yb, dtype=torch.long, device=device)
            opt.zero_grad(set_to_none=True)
            loss = crit(model(xb), yb)
            loss.backward()
            opt.step()

        if (ep + 1) % 25 == 0:
            print(f"  epoch {ep+1}/{num_epochs}")

    return model


# =============================================================================
# Conformal utilities (MATCH first code)
# =============================================================================

def compute_nonconformity_scores(model: nn.Module, loader) -> Tuple[np.ndarray, np.ndarray]:
    """
    Score every (x, y) in `loader` with s = 1 - softmax(model(x))[y].

    Returns:
        (scores, labels): a float array of scores and an int array of the
        matching labels, both in loader order.
    """
    model.eval()
    score_chunks, label_chunks = [], []
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(device, non_blocking=True)
            if torch.is_tensor(targets):
                targets_dev = targets.to(device, non_blocking=True)
            else:
                targets_dev = torch.tensor(targets, dtype=torch.long, device=device)

            probabilities = F.softmax(model(inputs), dim=1)
            rows = torch.arange(probabilities.size(0), device=probabilities.device)
            true_class_prob = probabilities[rows, targets_dev]
            score_chunks.extend((1.0 - true_class_prob).detach().cpu().numpy())

            # labels are kept from the original (host-side) batch
            if torch.is_tensor(targets):
                label_chunks.extend(targets.cpu().numpy())
            else:
                label_chunks.extend(np.asarray(targets))

    return np.asarray(score_chunks, float), np.asarray(label_chunks, int)


def classwise_scores(scores: np.ndarray, labels: np.ndarray, L: int) -> Dict[int, np.ndarray]:
    """
    Group calibration scores by their label.

    Returns a dict with one float64 array per class 0..L-1 (empty when the
    class has no samples). A label outside 0..L-1 raises KeyError.
    """
    grouped: Dict[int, List[float]] = {cls: [] for cls in range(L)}
    for score, label in zip(scores, labels):
        grouped[int(label)].append(float(score))
    return {cls: np.asarray(bucket, float) for cls, bucket in grouped.items()}


def per_view_pvalues_and_probs(
    model: nn.Module, class_scores: Dict[int, np.ndarray], loader, L: int
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Run one view's model over `loader` and turn its softmax outputs into
    class-conditional conformal p-values.

    For candidate label y with calibration scores cal_y (size m):
        p(x, y) = (1 + #{s in cal_y : s >= 1 - prob(x, y)}) / (m + 1)
    and p(x, y) = 1 when class y has no calibration scores.

    Returns:
      pvals: (n, L) float32
      probs: (n, L) float32 softmax probabilities, in loader order
    """
    model.eval()
    with torch.no_grad():
        batches = [
            F.softmax(model(xb.to(device, non_blocking=True)), dim=1)
            .detach()
            .cpu()
            .numpy()
            .astype(np.float32)
            for xb, _ in loader
        ]
    probs_all = np.vstack(batches)

    pvals = np.zeros((probs_all.shape[0], L), dtype=np.float32)
    for y in range(L):
        cal = class_scores.get(y, np.array([], dtype=np.float32))
        if cal.size == 0:
            pvals[:, y] = 1.0
            continue
        test_scores = 1.0 - probs_all[:, y]
        n_at_least = np.sum(cal[:, None] >= test_scores[None, :], axis=0)
        pvals[:, y] = (1.0 + n_at_least) / (len(cal) + 1.0)

    return pvals, probs_all


def fused_class_cal_scores(y_cal: np.ndarray, fused_probs_cal: np.ndarray, L: int) -> Dict[int, np.ndarray]:
    """
    Class-wise calibration scores for the fused (CF) classifier.

    Each calibration point scores s_i = 1 - fused_probs_cal[i, y_cal[i]]. The
    scores are grouped by true label into one float64 array per class 0..L-1,
    with an empty array for classes that do not appear.
    """
    true_class_probs = fused_probs_cal[np.arange(len(y_cal)), y_cal]
    scores = 1.0 - true_class_probs
    by_class: Dict[int, list] = {c: [] for c in range(L)}
    for score, label in zip(scores, y_cal):
        by_class[int(label)].append(float(score))
    return {c: np.asarray(vals, float) for c, vals in by_class.items()}


def fused_p_values_from_cal(fused_probs: np.ndarray, cal_class_scores: Dict[int, np.ndarray]) -> np.ndarray:
    """
    Conformal p-values for the fused classifier on new points.

    p(x, y) = (1 + #{s in cal_y : s >= 1 - fused_probs[x, y]}) / (|cal_y| + 1),
    or 1 when class y has no calibration scores.

    Returns a float32 array with the same (n, L) shape as `fused_probs`.
    """
    n_points, n_labels = fused_probs.shape
    pv = np.zeros((n_points, n_labels), dtype=np.float32)
    for label in range(n_labels):
        cal = cal_class_scores.get(label, np.array([], dtype=np.float32))
        if cal.size:
            new_scores = 1.0 - fused_probs[:, label]
            ge = np.sum(cal[:, None] >= new_scores[None, :], axis=0)
            pv[:, label] = (1.0 + ge) / (len(cal) + 1.0)
        else:
            pv[:, label] = 1.0
    return pv


def evaluate_sets(P: np.ndarray, y_true: np.ndarray, alpha: float) -> Tuple[float, float]:
    """
    Empirical coverage and mean size of the prediction sets {y : P[i, y] > alpha}.

    Returns:
        (coverage, avg_size): the fraction of rows whose true label is in the
        set, and the average number of labels per set.
    """
    in_set = P > alpha
    rows = np.arange(len(y_true))
    coverage = np.mean(in_set[rows, y_true])
    avg_size = np.mean(in_set.sum(axis=1))
    return float(coverage), float(avg_size)


# =============================================================================
# Fusion feature builders / p-value fusion baselines
# =============================================================================

def build_fusion_features(pvals_list: List[np.ndarray], probs_list: List[np.ndarray]) -> np.ndarray:
    """
    Lay the per-view features side by side as [p_1 | prob_1 | p_2 | prob_2 | ...].

    Each view contributes its p-value columns followed by its probability
    columns. Views are taken in the order of `pvals_list`.
    """
    columns: List[np.ndarray] = []
    for view in range(len(pvals_list)):
        columns.extend((pvals_list[view], probs_list[view]))
    return np.hstack(columns)


def fisher_fusion(P_all: np.ndarray) -> np.ndarray:
    """
    Fisher's method, applied independently to every (sample, label) cell.

    T = -2 * sum_k log p_k, and the combined p-value is P(chi2_{2K} >= T).

    Args:
        P_all: (K, n, L) per-view p-values. Values are clipped to [1e-12, 1]
            before the log so that zeros do not produce infinities.

    Returns:
        (n, L) array of combined p-values.
    """
    eps = 1e-12
    p = np.clip(P_all, eps, 1.0)
    T = -2 * np.sum(np.log(p), axis=0)
    df = 2 * P_all.shape[0]
    # Survival function instead of 1 - cdf: 1 - cdf cancels to exactly 0 once
    # the combined p-value drops below float64 resolution near 1.
    return chi2.sf(T, df=df)


def adjusted_fisher_fusion(P_train: np.ndarray, y_train: np.ndarray, P_test: np.ndarray, L: int) -> np.ndarray:
    """
    Moment-matched (class-conditional) adjusted Fisher combination.

    For each label y, the Fisher statistic T = -2 sum_k log p_k is approximated
    by c_y * chi2_{f_y}. The null mean is kept at 2K, and the variance is
    estimated from the training points whose label is y:
        var_T = sum of the (K x K) sample covariance of -2 log p_k(x, y)
        f_y   = 8 K^2 / var_T,    c_y = var_T / (4 K)
    A class with fewer than 5 training points, or with a non-finite or
    non-positive var_T, uses var_T = 4K. That gives c_y = 1 and f_y = 2K,
    i.e. the standard Fisher combination.

    Args:
        P_train: (K, n_train, L) per-view p-values of the training points.
        y_train: (n_train,) training labels.
        P_test: (K, n_test, L) per-view p-values to combine; assumed to use the
            same K views as P_train.
        L: number of labels.

    Returns:
        (n_test, L) float32 combined p-values.
    """
    K = P_train.shape[0]
    eps = 1e-12
    # Fisher statistic of every test cell, computed once for all labels
    T_test = -2 * np.sum(np.log(np.clip(P_test, eps, 1.0)), axis=0)  # (n_test, L)
    out = np.zeros((P_test.shape[1], L), dtype=np.float32)
    independent_var = 4.0 * K  # Var[T] if the K views were independent

    for y in range(L):
        idx = np.flatnonzero(y_train == y)
        var_T = independent_var
        if idx.size >= 5:
            W = -2 * np.log(np.clip(P_train[:, idx, y], eps, 1.0))  # (K, n_y)
            Wc = W - W.mean(axis=1, keepdims=True)
            Sigma = (Wc @ Wc.T) / max(W.shape[1] - 1, 1)
            var_T = float(np.sum(Sigma))
            if not np.isfinite(var_T) or var_T <= 0:
                var_T = independent_var

        f_y = (8.0 * K * K) / var_T
        c_y = var_T / (4 * K)
        # survival function avoids 1 - cdf cancelling to 0 for tiny p-values
        out[:, y] = chi2.sf(T_test[:, y] / c_y, df=f_y)

    return out


# =============================================================================
# Subsampling helper (stratified) for nt_fracs
# =============================================================================

def stratified_subsample_indices(y: np.ndarray, n_sub: int, seed: int) -> np.ndarray:
    """
    Pick `n_sub` row indices of `y`, preserving class proportions when possible.

    n_sub is clamped to [1, len(y)]. If that covers every row, all indices are
    returned in order. Otherwise a stratified split is used. When stratification
    is infeasible, a seeded random subset without replacement is returned
    instead. Stratification is infeasible when some class has fewer than 2
    members, or when either side of the split would be smaller than the number
    of classes; train_test_split raises ValueError in those cases.

    Args:
        y: (n,) labels.
        n_sub: requested subset size.
        seed: random_state / RNG seed.

    Returns:
        int array of selected indices (length n_sub after clamping).
    """
    n = len(y)
    n_sub = int(max(1, min(n_sub, n)))
    idx_all = np.arange(n, dtype=int)
    if n_sub >= n:
        return idx_all

    _, class_counts = np.unique(y, return_counts=True)
    n_classes = class_counts.size
    can_stratify = (
        class_counts.min() >= 2
        and n_sub >= n_classes
        and (n - n_sub) >= n_classes
    )
    if not can_stratify:
        rng = np.random.default_rng(seed)
        return np.asarray(rng.choice(n, size=n_sub, replace=False), dtype=int)

    idx_sub, _ = train_test_split(idx_all, train_size=n_sub, stratify=y, random_state=seed)
    return np.asarray(idx_sub, dtype=int)


# =============================================================================
# WAvgL baseline (UPDATED): learn weights from p-values via multinomial LR (CIFAR-10 style)
# =============================================================================

def weighted_average_fusion(P_all: np.ndarray, weights: np.ndarray) -> np.ndarray:
    """
    Weighted sum of per-view p-values: P(x, y) = sum_k w_k * p_k(x, y).

    P_all: (K, n, L), weights: (K,) -> (n, L). Both inputs are cast to float32
    before contracting over the view axis.
    """
    w32 = weights.astype(np.float32)
    p32 = P_all.astype(np.float32)
    return np.tensordot(w32, p32, axes=(0, 0))


def learn_view_weights_from_pvals(
    pv_train_concat: np.ndarray,
    y_train: np.ndarray,
    K: int,
    L: int,
    max_iter: int,
    seed: int,
) -> np.ndarray:
    """
    Derive one non-negative weight per view from a multinomial logistic
    regression fitted on p-value features only.

    `pv_train_concat` is expected to be (n, K*L): the L p-value columns of
    view 0, then those of view 1, and so on. The importance of view k is the
    Frobenius norm of coef_[:, k*L:(k+1)*L]. Importances are floored at 1e-12
    and then normalized to sum to 1.

    Returns:
        (K,) float32 view weights.
    """
    clf = LogisticRegression(
        multi_class="multinomial",
        solver="lbfgs",
        max_iter=max_iter,
        random_state=seed,
    )
    clf.fit(pv_train_concat, y_train)

    coef = clf.coef_
    importance = np.array(
        [np.linalg.norm(coef[:, v * L : (v + 1) * L], ord="fro") for v in range(K)],
        dtype=np.float64,
    )
    importance = np.maximum(importance, 1e-12)
    return (importance / importance.sum()).astype(np.float32)


# =============================================================================
# Old-style LaTeX table writer (mean(std))
# =============================================================================

def _fmt_mean_std(mean: float, std: float, scale: float = 1.0) -> str:
    if not np.isfinite(std):
        std = 0.0
    return f"{mean*scale:.2f} ({std*scale:.2f})"


def save_old_style_latex_table(df_raw: pd.DataFrame, out_path: str, nt_frac: float) -> str:
    """
    Write (and return) a booktabs LaTeX table of "mean (std)" over simulations.

    The table has one row per K and uses only the rows of `df_raw` whose
    nt_frac equals `nt_frac`.

    columns:
      K | CF Cov | MinPV Cov | Fisher Cov | AdjF Cov | WAvgL Cov
        | CF Set | MinPV Set | Fisher Set | AdjF Set | WAvgL Set
    Coverage is shown in percent (x100); set sizes are shown as-is.

    Raises:
        ValueError: if no rows match `nt_frac`.
    """
    subset = df_raw[df_raw["nt_frac"] == float(nt_frac)].copy()
    if subset.empty:
        raise ValueError(f"No rows found for nt_frac={nt_frac}")

    # (display name, results-column suffix) for each method, in table order
    methods = [("CF", "cf"), ("MinPV", "minpv"), ("Fisher", "fisher"), ("AdjF", "adjF"), ("WAvgL", "wavgl")]
    metric_cols = [f"cov_{suffix}" for _, suffix in methods] + [f"size_{suffix}" for _, suffix in methods]

    by_k = subset.groupby("K")[metric_cols]
    means = by_k.mean()
    stds = by_k.std()

    header = ["K"] + [f"{name} Cov" for name, _ in methods] + [f"{name} Set" for name, _ in methods]
    records = []
    for k_val in sorted(means.index.tolist()):
        rec = {"K": int(k_val)}
        for name, suffix in methods:
            col = f"cov_{suffix}"
            rec[f"{name} Cov"] = _fmt_mean_std(means.loc[k_val, col], stds.loc[k_val, col], scale=100.0)
        for name, suffix in methods:
            col = f"size_{suffix}"
            rec[f"{name} Set"] = _fmt_mean_std(means.loc[k_val, col], stds.loc[k_val, col], scale=1.0)
        records.append(rec)

    table = pd.DataFrame(records, columns=header)
    latex = table.to_latex(
        index=False,
        escape=False,
        column_format="r" + "l" * 10,
        booktabs=True,
    )

    with open(out_path, "w", encoding="utf-8") as fh:
        fh.write(latex)

    return latex


# =============================================================================
# Main experiment (same experiment logic; now matches first code's model+preproc)
# =============================================================================

def _make_patch_loader(
    images: np.ndarray,
    labels: np.ndarray,
    K: int,
    view: int,
    cfg: MatchedCIFARConfig,
    pin: bool,
    *,
    shuffle: bool,
    persistent: bool,
):
    """Build a DataLoader over patch `view` of a K-way split, using cfg's
    preprocessing (normalization, full_image_size) and loader knobs."""
    dataset = PatchesDataset(
        images, labels, K, view,
        normalization=cfg.normalization,
        full_image_size=cfg.full_image_size,
    )
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=cfg.batch_size,
        shuffle=shuffle,
        num_workers=cfg.num_workers,
        pin_memory=pin,
        # DataLoader rejects persistent_workers=True when num_workers == 0
        persistent_workers=bool(persistent and cfg.num_workers > 0),
    )


def _predict_proba_full(clf, X: np.ndarray, L: int) -> np.ndarray:
    """predict_proba of `clf`, laid out as (n, L) so that column j is class j.

    sklearn orders predict_proba columns by clf.classes_, and that list leaves
    out any class absent from the training labels. This scatters the columns
    into the full label space (absent classes get probability 0), so column
    index == label index. When all L classes were seen, the output equals
    predict_proba.
    """
    raw = clf.predict_proba(X)
    full = np.zeros((raw.shape[0], L), dtype=raw.dtype)
    full[:, np.asarray(clf.classes_, dtype=int)] = raw
    return full


def run_matched(cfg: MatchedCIFARConfig) -> pd.DataFrame:
    """
    Run the full matched CIFAR-100 experiment and write its outputs.

    Each simulation uses seed = cfg.train_seed_base + sim. It splits the
    CIFAR-100 train set into view-model training (train_frac), per-view
    calibration (cal_frac_of_temp of the remainder), fusion-train
    (fuse_train_frac_of_rest of what is left) and fusion-cal (the rest). The
    official test split is used for evaluation.

    For each K in cfg.Ks, one ResNet-18 is trained per patch view and per-view
    conformal p-values are computed. The fusion rules CF, MinPV, Fisher, AdjF
    and WAvgL are then evaluated for every nt_frac in cfg.nt_fracs, the
    fraction of fusion-train used to fit the learned fusers.

    Side effects: creates cfg.out_dir, writes the raw CSV, an aggregate
    mean/std CSV and one LaTeX table per nt_frac, and prints progress.

    Returns:
        DataFrame with one row per (sim, K, nt_frac).
    """
    os.makedirs(cfg.out_dir, exist_ok=True)
    rows = []

    pin = bool(torch.cuda.is_available())

    for sim in range(cfg.num_simulations):
        print(f"\n=== Simulation {sim+1}/{cfg.num_simulations} ===")
        seed = cfg.train_seed_base + sim
        set_all_seeds(seed)

        # EXACT split structure (match first code)
        X_trP, X_tmp, y_trP, y_tmp = train_test_split(
            X_train_full,
            Y_train_full,
            test_size=1 - cfg.train_frac,
            stratify=Y_train_full,
            random_state=seed,
        )
        X_cal, X_rest, y_cal, y_rest = train_test_split(
            X_tmp,
            y_tmp,
            test_size=1 - cfg.cal_frac_of_temp,
            stratify=y_tmp,
            random_state=seed,
        )
        X_fuse_tr_full, X_fuse_cal, y_fuse_tr_full, y_fuse_cal = train_test_split(
            X_rest,
            y_rest,
            test_size=1 - cfg.fuse_train_frac_of_rest,
            stratify=y_rest,
            random_state=seed,
        )

        X_te, y_te = X_test_full, Y_test_full

        for K in cfg.Ks:
            print(f"\n  -> K = {K}")
            # split_image_into_k_patches always yields K patches (a 2x2 grid when K == 4)
            num_views = K

            # Only the training loader is iterated more than once, so only it
            # keeps persistent workers. The single-pass evaluation loaders shut
            # their worker processes down as soon as they are exhausted.
            loaders = {}
            for v in range(num_views):
                loaders[v] = dict(
                    train=_make_patch_loader(X_trP, y_trP, K, v, cfg, pin, shuffle=True, persistent=True),
                    cal=_make_patch_loader(X_cal, y_cal, K, v, cfg, pin, shuffle=False, persistent=False),
                    ftr=_make_patch_loader(
                        X_fuse_tr_full, y_fuse_tr_full, K, v, cfg, pin, shuffle=False, persistent=False
                    ),
                    fcal=_make_patch_loader(X_fuse_cal, y_fuse_cal, K, v, cfg, pin, shuffle=False, persistent=False),
                    te=_make_patch_loader(X_te, y_te, K, v, cfg, pin, shuffle=False, persistent=False),
                )

            # Train per-view predictors
            models: List[PredictorCNN] = []
            cal_classwise: List[Dict[int, np.ndarray]] = []
            for v in range(num_views):
                print(f"    [View {v+1}/{num_views}] training ResNet-18...")
                m = PredictorCNN(num_classes=cfg.num_classes, pretrained=cfg.pretrained_backbone)
                m = train_model(m, loaders[v]["train"], num_epochs=cfg.epochs_per_view, lr=cfg.lr)
                # drop this view's training loader so its persistent workers can be shut down
                loaders[v].pop("train")
                models.append(m)

                sc, lab = compute_nonconformity_scores(m, loaders[v]["cal"])
                cal_classwise.append(classwise_scores(sc, lab, cfg.num_classes))

            # Per-view p/probs on fusion-train FULL, fusion-cal, and test
            pv_tr_full, pr_tr_full = [], []
            pv_fcal, pr_fcal = [], []
            pv_te, pr_te = [], []

            for v in range(num_views):
                p, pr = per_view_pvalues_and_probs(models[v], cal_classwise[v], loaders[v]["ftr"], cfg.num_classes)
                pv_tr_full.append(p); pr_tr_full.append(pr)

                p, pr = per_view_pvalues_and_probs(models[v], cal_classwise[v], loaders[v]["fcal"], cfg.num_classes)
                pv_fcal.append(p); pr_fcal.append(pr)

                p, pr = per_view_pvalues_and_probs(models[v], cal_classwise[v], loaders[v]["te"], cfg.num_classes)
                pv_te.append(p); pr_te.append(pr)

            P_test = np.stack(pv_te, axis=0).astype(np.float32)  # (num_views, n_test, L)

            # Fixed features for CF (LR) on fusion-cal and test
            X_fcal_feat = build_fusion_features(pv_fcal, pr_fcal)
            X_test_feat = build_fusion_features(pv_te, pr_te)

            n_t_full = len(y_fuse_tr_full)

            # free GPU (same idea as first code)
            for m in models:
                m.to("cpu")
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            for frac in cfg.nt_fracs:
                n_sub = int(round(frac * n_t_full))
                idx_sub = stratified_subsample_indices(
                    y_fuse_tr_full, n_sub,
                    seed=seed + int(frac * 1000)
                )

                y_tr_sub = y_fuse_tr_full[idx_sub]
                pv_tr_sub = [pv[idx_sub] for pv in pv_tr_full]
                pr_tr_sub = [pr[idx_sub] for pr in pr_tr_full]

                # =========================
                # CF (LR on [p;prob]) + stage-3 recalibration
                # =========================
                X_ftr_sub = build_fusion_features(pv_tr_sub, pr_tr_sub)

                fusion_lr = LogisticRegression(
                    max_iter=cfg.max_iter_lr,
                    multi_class="multinomial",
                    solver="lbfgs",
                    random_state=seed,
                )
                fusion_lr.fit(X_ftr_sub, y_tr_sub)

                # label-aligned probabilities (robust to classes missing from the subset)
                fused_probs_cal = _predict_proba_full(fusion_lr, X_fcal_feat, cfg.num_classes)
                fused_cal_scores = fused_class_cal_scores(y_fuse_cal, fused_probs_cal, cfg.num_classes)

                fused_probs_test = _predict_proba_full(fusion_lr, X_test_feat, cfg.num_classes)
                P_cf = fused_p_values_from_cal(fused_probs_test, fused_cal_scores)
                cov_cf, size_cf = evaluate_sets(P_cf, y_te, cfg.alpha)

                # =========================
                # MinPV (Bonferroni min-p): min(1, K * min_k p_k)
                # =========================
                P_min = np.min(P_test, axis=0)  # (n_test, L)
                P_minpv = np.minimum(1.0, float(num_views) * P_min)
                cov_minpv, size_minpv = evaluate_sets(P_minpv, y_te, cfg.alpha)

                # =========================
                # Fisher (standard)
                # =========================
                P_fisher = fisher_fusion(P_test)
                cov_fisher, size_fisher = evaluate_sets(P_fisher, y_te, cfg.alpha)

                # =========================
                # AdjF (moment-matched Fisher baseline) — depends on subset
                # =========================
                P_train_sub = np.stack(pv_tr_sub, axis=0).astype(np.float32)  # (num_views, n_sub, L)
                P_adjF = adjusted_fisher_fusion(P_train_sub, y_tr_sub, P_test, cfg.num_classes)
                cov_adjF, size_adjF = evaluate_sets(P_adjF, y_te, cfg.alpha)

                # =========================
                # WAvgL (UPDATED): learn weights from p-values only, then weighted avg of p-values
                # =========================
                pv_tr_concat = np.concatenate(pv_tr_sub, axis=1).astype(np.float32)  # (n_sub, num_views*L)
                w = learn_view_weights_from_pvals(
                    pv_train_concat=pv_tr_concat,
                    y_train=y_tr_sub,
                    K=num_views,
                    L=cfg.num_classes,
                    max_iter=cfg.max_iter_lr,
                    seed=seed + 777 + int(frac * 1000),
                )
                P_wavgL = weighted_average_fusion(P_test, w)  # (n_test, L)
                P_wavgL = np.clip(P_wavgL, 0.0, 1.0)
                cov_wavgL, size_wavgL = evaluate_sets(P_wavgL, y_te, cfg.alpha)

                rows.append(dict(
                    sim=int(sim),
                    seed=int(seed),
                    K=int(K),
                    alpha=float(cfg.alpha),
                    nt_frac=float(frac),
                    nt_sub=int(len(idx_sub)),
                    nt_full=int(n_t_full),

                    cov_cf=float(cov_cf),
                    size_cf=float(size_cf),

                    cov_minpv=float(cov_minpv),
                    size_minpv=float(size_minpv),

                    cov_fisher=float(cov_fisher),
                    size_fisher=float(size_fisher),

                    cov_adjF=float(cov_adjF),
                    size_adjF=float(size_adjF),

                    cov_wavgl=float(cov_wavgL),
                    size_wavgl=float(size_wavgL),
                ))

                print(
                    f"    nt_frac={frac:.2f} (n={len(idx_sub):5d}) | "
                    f"CF: cov={cov_cf:.3f}, size={size_cf:.2f} | "
                    f"MinPV: cov={cov_minpv:.3f}, size={size_minpv:.2f} | "
                    f"Fisher: cov={cov_fisher:.3f}, size={size_fisher:.2f} | "
                    f"AdjF: cov={cov_adjF:.3f}, size={size_adjF:.2f} | "
                    f"WAvgL(p): cov={cov_wavgL:.3f}, size={size_wavgL:.2f}"
                )

    df = pd.DataFrame(rows)

    # Save raw rows
    out_path = os.path.join(cfg.out_dir, cfg.out_csv)
    df.to_csv(out_path, index=False)
    print(f"\nSaved raw rows: {out_path}")

    # Save numeric aggregate CSV (mean/std)
    agg = df.groupby(["K", "nt_frac"]).agg(
        cov_cf_mean=("cov_cf", "mean"), cov_cf_std=("cov_cf", "std"),
        size_cf_mean=("size_cf", "mean"), size_cf_std=("size_cf", "std"),

        cov_minpv_mean=("cov_minpv", "mean"), cov_minpv_std=("cov_minpv", "std"),
        size_minpv_mean=("size_minpv", "mean"), size_minpv_std=("size_minpv", "std"),

        cov_fisher_mean=("cov_fisher", "mean"), cov_fisher_std=("cov_fisher", "std"),
        size_fisher_mean=("size_fisher", "mean"), size_fisher_std=("size_fisher", "std"),

        cov_adjF_mean=("cov_adjF", "mean"), cov_adjF_std=("cov_adjF", "std"),
        size_adjF_mean=("size_adjF", "mean"), size_adjF_std=("size_adjF", "std"),

        cov_wavgl_mean=("cov_wavgl", "mean"), cov_wavgl_std=("cov_wavgl", "std"),
        size_wavgl_mean=("size_wavgl", "mean"), size_wavgl_std=("size_wavgl", "std"),
    ).reset_index()

    agg_path = os.path.join(cfg.out_dir, "matched_cifar100_agg.csv")
    agg.to_csv(agg_path, index=False)
    print(f"Saved aggregate CSV: {agg_path}")

    # Save + print old-style LaTeX tables (one per nt_frac)
    for frac in cfg.nt_fracs:
        frac_tag = str(frac).replace(".", "p")
        tex_path = os.path.join(cfg.out_dir, f"matched_cifar100_table_nt{frac_tag}.tex")
        latex = save_old_style_latex_table(df, tex_path, nt_frac=float(frac))
        print(f"\nSaved LaTeX table: {tex_path}\n")
        print(latex)

    return df


if __name__ == "__main__":
    # Guard the entry point. DataLoader workers started with the "spawn"
    # method (the default on Windows/macOS) re-import the main module, and
    # without this guard each worker would launch the experiment again. In a
    # notebook __name__ is "__main__", so the cell still runs as before.
    cfg = MatchedCIFARConfig(
        Ks=(2, 3, 4, 6),
        nt_fracs=(1.0,),
        # defaults match your FIRST code style
    )
    df = run_matched(cfg)
    print("\nHead:")
    print(df.head())


Using device: cuda

=== Simulation 1/4 ===

  -> K = 2
    [View 1/2] training ResNet-18...
  epoch 25/150
  epoch 50/150
  epoch 75/150
  epoch 100/150
  epoch 125/150
  epoch 150/150
    [View 2/2] training ResNet-18...
  epoch 25/150
  epoch 50/150
  epoch 75/150
  epoch 100/150
  epoch 125/150
  epoch 150/150
    nt_frac=1.00 (n=12249) | CF: cov=0.911, size=5.42 | MinPV: cov=0.926, size=11.54 | Fisher: cov=0.886, size=4.77 | AdjF: cov=0.892, size=5.06 | WAvgL(p): cov=0.959, size=11.84

  -> K = 3
    [View 1/3] training ResNet-18...
  epoch 25/150
  epoch 50/150
  epoch 75/150
  epoch 100/150
  epoch 125/150
  epoch 150/150
    [View 2/3] training ResNet-18...
  epoch 25/150
  epoch 50/150
  epoch 75/150
  epoch 100/150
  epoch 125/150
  epoch 150/150
    [View 3/3] training ResNet-18...
  epoch 25/150
  epoch 50/150
  epoch 75/150
  epoch 100/150
  epoch 125/150
  epoch 150/150
    nt_frac=1.00 (n=12249) | CF: cov=0.904, size=7.72 | MinPV: cov=0.929, size=17.85 | Fisher: cov=0.875


  -> K = 2
    [View 1/2] training ResNet-18...
  epoch 25/150
  epoch 50/150
  epoch 75/150
  epoch 100/150
  epoch 125/150
  epoch 150/150
    [View 2/2] training ResNet-18...
  epoch 25/150
  epoch 50/150
  epoch 75/150
  epoch 100/150
  epoch 125/150
  epoch 150/150
    nt_frac=1.00 (n=12249) | CF: cov=0.909, size=5.38 | MinPV: cov=0.922, size=11.72 | Fisher: cov=0.881, size=4.23 | AdjF: cov=0.889, size=4.40 | WAvgL(p): cov=0.955, size=10.14

  -> K = 3
    [View 1/3] training ResNet-18...
  epoch 25/150
  epoch 50/150
  epoch 75/150
  epoch 100/150
  epoch 125/150
  epoch 150/150
    [View 2/3] training ResNet-18...
  epoch 25/150
  epoch 50/150
  epoch 75/150
  epoch 100/150
  epoch 125/150
  epoch 150/150
    [View 3/3] training ResNet-18...
  epoch 25/150
  epoch 50/150
  epoch 75/150
  epoch 100/150
  epoch 125/150
  epoch 150/150
    nt_frac=1.00 (n=12249) | CF: cov=0.911, size=8.54 | MinPV: cov=0.927, size=17.31 | Fisher: cov=0.876, size=6.09 | AdjF: cov=0.897, size=7.17 | 