In [1]:
# CIFAR-10 multi-view conformal fusion + MVCP (their method)
# - Per-view CNNs produce classwise probabilities
# - Baselines: Min-p, Fisher, Adjusted Fisher, Weighted Avg (learned)
# - Late fusion: LR on [p, prob]; conformalize fused scores
# - Added: MVCP (quantile-envelope over multivariate per-view scores)
# - Produces mean (std) tables across simulations & K; saves CSV/LaTeX

from __future__ import annotations

import warnings
from dataclasses import dataclass
from typing import Dict, List, Tuple

import numpy as np
import pandas as pd
from scipy.stats import chi2
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

# Silence ALL Python warnings for the rest of the run.
# NOTE(review): this also hides deprecation/usage warnings raised by the torch and
# scikit-learn calls below; consider narrowing the filter (category/module) instead.
warnings.filterwarnings("ignore")

# =============================================================================
# Config
# =============================================================================

@dataclass
class CIFARConfig:
    """All tunable knobs of the CIFAR-10 multi-view conformal experiment.

    Fields are grouped as: experiment core, per-view CNN training, data split
    fractions, and parameters of the MVCP comparison method.
    """

    # ---- experiment core ----
    alpha: float = 0.1                      # miscoverage level (target coverage 1 - alpha)
    Ks: Tuple[int, ...] = (2, 3, 4, 6)      # numbers of views (patches) to evaluate
    num_classes: int = 10
    num_simulations: int = 10               # repetitions with different split seeds

    # ---- per-view CNN training ----
    epochs_per_view: int = 100              # dominant runtime cost; lower for quick runs
    lr: float = 1e-3                        # Adam learning rate
    batch_size: int = 512
    max_iter_lr: int = 1000                 # sklearn LogisticRegression max_iter (fusion & weights)
    train_seed_base: int = 42               # split seed of simulation s is train_seed_base + s

    # ---- data split fractions ----
    train_frac: float = 0.5                 # share of the full train set used to fit view predictors
    cal_frac_of_temp: float = 0.3           # share of the remainder used for per-view calibration
    fuse_train_frac_of_rest: float = 0.7    # what is left is split into fusion-train / fusion-cal

    # ---- MVCP (comparison method) ----
    mvcp_env_frac: float = 0.2              # share of fusion-cal used to fit the envelope S_C^(1)
    mvcp_num_dirs: int = 32                 # number M of projection directions
    mvcp_eps: float = 5e-3                  # coverage tolerance around 1 - alpha on S_C^(1)
    mvcp_rng_seed: int = 2024               # base seed for direction sampling / shuffling

# =============================================================================
# Torch / Data
# =============================================================================

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Global seeding (torch first, then NumPy) so the whole run is repeatable.
torch.manual_seed(0)
np.random.seed(0)

transform = transforms.Compose([transforms.ToTensor()])

# Download (if needed) both CIFAR-10 splits exactly once: train first, then test.
train_dataset, test_dataset = (
    datasets.CIFAR10(root="./data", train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

# The experiment works on the raw uint8 HWC arrays (the `transform` above is only
# used when indexing the datasets), rescaled to float32 in [0, 1].
X_train_full = np.asarray(train_dataset.data, dtype=np.float32) / 255.0   # (50000, 32, 32, 3)
Y_train_full = np.asarray(train_dataset.targets)
X_test_full = np.asarray(test_dataset.data, dtype=np.float32) / 255.0     # (10000, 32, 32, 3)
Y_test_full = np.asarray(test_dataset.targets)

# =============================================================================
# Multi-view (patch) utilities
# =============================================================================

def split_image_into_k_patches(image: torch.Tensor, k: int) -> List[torch.Tensor]:
    """Split a (C, H, W) image into ``k`` views.

    - ``k == 4``: a 2x2 grid of quadrants in row-major order (top-left, top-right,
      bottom-left, bottom-right). For odd H/W the lower/right halves get the extra row/column.
    - otherwise: ``k`` full-height vertical stripes; when ``W`` is not divisible by
      ``k`` the first ``W % k`` stripes are one column wider.

    Args:
        image: tensor of shape (C, H, W).
        k: number of views, must be >= 1.

    Returns:
        List of ``k`` tensors sliced from ``image`` (basic slicing, so they are views).

    Raises:
        ValueError: if ``k < 1``.
    """
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")
    _, H, W = image.shape
    if k == 4:
        # Derive the grid from the actual size (previously hard-coded 16 px, which
        # silently produced wrong/empty patches for anything but 32x32 inputs).
        h_mid, w_mid = H // 2, W // 2
        row_spans = ((0, h_mid), (h_mid, H))
        col_spans = ((0, w_mid), (w_mid, W))
        return [image[:, r0:r1, c0:c1] for r0, r1 in row_spans for c0, c1 in col_spans]

    base_width, remainder = divmod(W, k)
    patches, start = [], 0
    for idx in range(k):
        width = base_width + (1 if idx < remainder else 0)
        patches.append(image[:, :, start:start + width])
        start += width
    return patches

class PatchesDataset(torch.utils.data.Dataset):
    """Map-style dataset returning one view (patch) of each image plus its label.

    Args:
        images: array whose items ``images[idx]`` are HWC images (transposed to CHW here).
        labels: array of integer class labels aligned with ``images``.
        k: number of views each image is split into.
        view: index of the view this dataset returns.
    """

    def __init__(self, images: np.ndarray, labels: np.ndarray, k: int, view: int):
        self.images, self.labels = images, labels
        self.k, self.view = k, view

    def __len__(self) -> int:
        return len(self.labels)

    def __getitem__(self, idx: int):
        chw = torch.tensor(self.images[idx].transpose((2, 0, 1)), dtype=torch.float32)
        selected = split_image_into_k_patches(chw, self.k)[self.view]
        return selected, int(self.labels[idx])

# =============================================================================
# Simple CNN per view
# =============================================================================

class PredictorCNN(nn.Module):
    """Small CNN classifier for one image view; ``forward`` returns raw logits.

    Architecture: conv(3->32, 3x3, pad 1) -> ReLU -> maxpool 2 -> conv(32->64, 3x3, pad 1)
    -> ReLU -> maxpool 2 -> flatten -> linear(->128) -> ReLU -> linear(->num_classes).

    The fully connected head depends on the spatial size of the view. If ``input_hw``
    is given, it is built eagerly in ``__init__``. Otherwise it is created on the first
    forward pass, as before.

    NOTE: with lazy creation, ``fc1``/``fc2`` do not exist yet when ``model.parameters()``
    is first read. An optimizer constructed before the first forward pass therefore never
    updates them. Pass ``input_hw`` or run one forward pass before building the optimizer.

    Args:
        num_classes: number of output logits.
        input_hw: optional (height, width) of the input views, used to size the head eagerly.
    """

    def __init__(self, num_classes=10, input_hw=None):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.fc1 = None
        self.fc2 = None
        self.num_classes = num_classes
        if input_hw is not None:
            h, w = input_hw
            # padding=1 convs keep the size; each 2x2 max-pool floors it by a factor 2.
            self._build_head(64 * (h // 2 // 2) * (w // 2 // 2))

    def _build_head(self, in_features, device=None):
        """Create fc1 (in_features -> 128) and fc2 (128 -> num_classes), optionally on ``device``."""
        fc1 = nn.Linear(in_features, 128)
        fc2 = nn.Linear(128, self.num_classes)
        if device is not None:
            fc1, fc2 = fc1.to(device), fc2.to(device)
        self.fc1, self.fc2 = fc1, fc2

    def forward_features(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        return x

    def forward(self, x):
        x = self.forward_features(x)
        if self.fc1 is None:
            # Lazy path: size the head from the first batch actually seen.
            _, c, h, w = x.shape
            self._build_head(c * h * w, x.device)
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)

def train_model(model: nn.Module, train_loader, num_epochs=100, lr=1e-3):
    """Train ``model`` with Adam on cross-entropy loss.

    Args:
        model: classifier returning logits of shape (batch, num_classes).
        train_loader: iterable of ``(inputs, integer labels)`` batches.
        num_epochs: number of passes over ``train_loader``.
        lr: Adam learning rate.

    Returns:
        The same ``model`` instance, moved to the module-level ``device``.
    """
    # Move to the target device BEFORE building the optimizer.
    model.to(device)
    # Dry-run one batch so that layers created lazily on the first forward pass
    # (PredictorCNN's fc1/fc2) exist when the optimizer collects parameters.
    # Previously the optimizer was built first, so those layers were never trained.
    first_batch = next(iter(train_loader), None)
    if first_batch is not None:
        with torch.no_grad():
            model(first_batch[0].to(device))

    crit = nn.CrossEntropyLoss()
    opt = optim.Adam(model.parameters(), lr=lr)
    for ep in range(num_epochs):
        model.train()
        for xb, yb in train_loader:
            xb = xb.to(device)
            # as_tensor: no copy-construct warning when yb is already a tensor
            yb = torch.as_tensor(yb, dtype=torch.long).to(device)
            opt.zero_grad()
            loss = crit(model(xb), yb)
            loss.backward()
            opt.step()
        if (ep + 1) % 25 == 0:
            print(f"  epoch {ep+1}/{num_epochs}")
    return model

# =============================================================================
# Conformal utilities (torch models)
# =============================================================================

def compute_nonconformity_scores(model: nn.Module, loader) -> Tuple[np.ndarray, np.ndarray]:
    """Compute nonconformity scores ``1 - softmax(model(x))[y]`` over ``loader``.

    Args:
        model: classifier returning logits of shape (batch, num_classes).
        loader: iterable of ``(inputs, integer labels)`` batches.

    Returns:
        ``(scores, labels)``: a float array of scores and an int array of the matching
        labels, both in loader order (empty arrays for an empty loader).
    """
    model.eval()
    score_chunks, label_chunks = [], []
    with torch.no_grad():
        for xb, yb in loader:
            probs = F.softmax(model(xb.to(device)), dim=1)
            # as_tensor avoids the torch.tensor(tensor) copy-construct warning
            y = torch.as_tensor(yb, dtype=torch.long, device=probs.device)
            true_p = probs.gather(1, y.unsqueeze(1)).squeeze(1)
            score_chunks.append((1 - true_p).cpu().numpy())
            label_chunks.append(y.cpu().numpy())
    if not score_chunks:
        return np.asarray([], float), np.asarray([], int)
    return np.concatenate(score_chunks).astype(float), np.concatenate(label_chunks).astype(int)

def classwise_scores(scores: np.ndarray, labels: np.ndarray, L: int) -> Dict[int, np.ndarray]:
    """Group calibration scores by their label.

    Returns a dict with one entry per class ``0..L-1`` (possibly empty) mapping to a
    float array of that class's scores, in input order. A label outside ``0..L-1`` raises KeyError.
    """
    per_class: Dict[int, List[float]] = {cls: [] for cls in range(L)}
    for label, score in zip(labels, scores):
        per_class[int(label)].append(float(score))
    return {cls: np.array(vals, dtype=float) for cls, vals in per_class.items()}

def per_view_pvalues_and_probs(
    model: nn.Module, class_scores: Dict[int, np.ndarray], loader, L: int
) -> Tuple[np.ndarray, np.ndarray]:
    """Return p-values (n,L) and probs (n,L) for a single view.

    The p-value of candidate label ``y`` is the classwise conformal p-value
    ``(1 + #{cal scores of class y >= 1 - prob_y}) / (n_y + 1)``; classes without
    calibration scores get p-value 1.
    """
    model.eval()
    with torch.no_grad():
        batches = [
            F.softmax(model(xb.to(device)), dim=1).detach().cpu().numpy()
            for xb, _ in loader
        ]
    probs_all = np.vstack(batches)  # (n, L)

    pvals = np.zeros((probs_all.shape[0], L))
    for cls in range(L):
        ref = class_scores.get(cls, np.array([]))
        if ref.size == 0:
            pvals[:, cls] = 1.0
            continue
        test_scores = 1 - probs_all[:, cls]
        n_at_least = (ref[:, None] >= test_scores[None, :]).sum(axis=0)
        pvals[:, cls] = (n_at_least + 1) / (len(ref) + 1)
    return pvals, probs_all

# =============================================================================
# Fusion utilities
# =============================================================================

def build_fusion_features(pvals_list: List[np.ndarray], probs_list: List[np.ndarray]) -> np.ndarray:
    """Concatenate per-view features column-wise as [p_1, prob_1, p_2, prob_2, ...].

    Each view contributes its (n, L) p-values followed by its (n, L) probabilities,
    giving an (n, K*2L) feature matrix.
    """
    columns = []
    for view_idx, pv in enumerate(pvals_list):
        columns.extend((pv, probs_list[view_idx]))
    return np.hstack(columns)

def min_p_value_fusion(P_all: np.ndarray) -> np.ndarray:
    """Bonferroni (min-p) combination: ``min(1, K * min_k p_k^y)``.

    Args:
        P_all: per-view p-values of shape (K, n, L).

    Returns:
        Combined p-values of shape (n, L), within [0, 1].
    """
    K = P_all.shape[0]
    # Cap at 1 so the output is a valid p-value. Thresholding at any alpha < 1
    # yields exactly the same prediction sets as the uncapped value.
    return np.minimum(K * np.min(P_all, axis=0), 1.0)

def fisher_fusion(P_all: np.ndarray) -> np.ndarray:
    """Standard Fisher combination of per-view p-values.

    ``T = -2 * sum_k log p_k`` is referred to a chi-square with ``2K`` degrees of freedom.
    p-values are clipped to ``[1e-12, 1]`` before taking logs.

    Args:
        P_all: per-view p-values of shape (K, n, L).

    Returns:
        Combined p-values ``P(chi2_{2K} >= T)`` of shape (n, L).
    """
    eps = 1e-12
    p = np.clip(P_all, eps, 1.0)
    T = -2.0 * np.sum(np.log(p), axis=0)
    # Survival function instead of 1 - cdf: avoids catastrophic cancellation
    # (1 - cdf rounds to 0 for very small combined p-values).
    return chi2.sf(T, df=2 * P_all.shape[0])

def adjusted_fisher_fusion(P_train: np.ndarray, y_train: np.ndarray, P_test: np.ndarray, L: int) -> np.ndarray:
    """Moment-matched (Brown-style) Fisher combination, fitted per class.

    For each class ``y``, the variance of ``T_y = sum_k -2 log p_k^y`` is estimated from
    the training rows whose label is ``y``. ``T`` is then treated as ``c * chi2_f`` with
    ``c * f = 2K`` (the null mean) and ``2 c^2 f = Var[T]``, i.e. ``f = 8K^2 / Var[T]`` and
    ``c = Var[T] / (4K)``. With fewer than 5 class rows, or a non-finite/non-positive
    variance estimate, the independence value ``Var[T] = 4K`` is used, which is exactly
    the standard Fisher test (``c = 1``, ``f = 2K``).

    Args:
        P_train: per-view training p-values, shape (K, n_train, L).
        y_train: training labels, shape (n_train,).
        P_test: per-view test p-values, shape (K, n_test, L).
        L: number of classes.

    Returns:
        Combined p-values of shape (n_test, L).
    """
    K = P_train.shape[0]
    eps = 1e-12
    out = np.zeros((P_test.shape[1], L))
    for y in range(L):
        idx = np.flatnonzero(y_train == y)
        if idx.size < 5:
            var_T = 4.0 * K  # too few samples: fall back to plain Fisher
        else:
            W = -2 * np.log(np.clip(P_train[:, idx, y], eps, 1.0))   # (K, n_y)
            Wc = W - W.mean(axis=1, keepdims=True)
            Sigma = (Wc @ Wc.T) / max(W.shape[1] - 1, 1)              # (K, K)
            var_T = float(np.sum(Sigma))
            if not np.isfinite(var_T) or var_T <= 0:
                var_T = 4.0 * K
        f_y = 8.0 * K * K / var_T
        c_y = var_T / (4.0 * K)

        T_t = -2 * np.sum(np.log(np.clip(P_test[:, :, y], eps, 1.0)), axis=0)
        # sf instead of 1 - cdf to keep precision for small p-values
        out[:, y] = chi2.sf(T_t / c_y, df=f_y)
    return out

def weighted_average_fusion(P_all: np.ndarray, weights: np.ndarray) -> np.ndarray:
    """Weighted average of per-view p-values, ``sum_k w_k p_k^y``.

    Args:
        P_all: per-view p-values of shape (K, n, L).
        weights: view weights of shape (K,).

    Returns:
        Array of shape (n, L).
    """
    # Contract the single axis of `weights` with the leading (view) axis of P_all.
    return np.tensordot(weights, P_all, axes=1)

def learn_view_weights_from_pvals(pv_train_concat: np.ndarray, y_train: np.ndarray, K: int, L: int, max_iter: int, seed: int) -> np.ndarray:
    """Learn per-view weights from a logistic regression on p-value features.

    Fits an lbfgs LogisticRegression on the (n, K*L) p-value matrix. The importance of
    view ``k`` is the Frobenius norm of its coefficient block ``coef_[:, k*L:(k+1)*L]``.
    Importances are floored at 1e-12 and normalized to sum to 1.

    Returns:
        Array of shape (K,) with non-negative weights summing to 1.
    """
    # `multi_class="multinomial"` is deprecated in scikit-learn (since 1.5) and scheduled
    # for removal. With lbfgs and more than two classes, the multinomial loss is already the default.
    lr = LogisticRegression(solver="lbfgs", max_iter=max_iter, random_state=seed)
    lr.fit(pv_train_concat, y_train)
    B = lr.coef_  # (n_classes, K*L); a single row for binary targets
    imps = np.array(
        [np.linalg.norm(B[:, k * L:(k + 1) * L], ord="fro") for k in range(K)],
        dtype=float,
    )
    w = np.maximum(imps, 1e-12)
    return w / w.sum()

# =============================================================================
# Conformalization of the fused (late-fusion LR) model
# =============================================================================

def fused_class_cal_scores(y_cal: np.ndarray, fused_probs_cal: np.ndarray, L: int) -> Dict[int, np.ndarray]:
    """Classwise calibration scores of the fused predictor.

    Each calibration row contributes ``s = 1 - fused_probs_cal[i, y_i]`` to the bucket of
    its true class. All classes ``0..L-1`` are present in the result, possibly empty.
    """
    rows = np.arange(len(y_cal))
    true_label_scores = 1 - fused_probs_cal[rows, y_cal]
    buckets: Dict[int, List[float]] = {cls: [] for cls in range(L)}
    for score, label in zip(true_label_scores, y_cal):
        buckets[int(label)].append(float(score))
    return {cls: np.array(vals, dtype=float) for cls, vals in buckets.items()}

def fused_p_values_from_cal(fused_probs: np.ndarray, cal_class_scores: Dict[int, np.ndarray]) -> np.ndarray:
    """Classwise conformal p-values for the fused predictor.

    For candidate label ``y`` with test score ``s = 1 - fused_probs[:, y]``:
    ``p = (1 + #{cal scores of class y >= s}) / (n_y + 1)``. Classes with no calibration
    scores (missing or empty) get p-value 1.
    """
    n_rows, n_cls = fused_probs.shape
    result = np.zeros((n_rows, n_cls))
    for cls in range(n_cls):
        ref = cal_class_scores.get(cls, np.array([]))
        if ref.size == 0:
            result[:, cls] = 1.0
            continue
        test_scores = 1 - fused_probs[:, cls]
        n_at_least = (ref[:, None] >= test_scores[None, :]).sum(axis=0)
        result[:, cls] = (n_at_least + 1) / (len(ref) + 1)
    return result

# =============================================================================
# MVCP (their method): multivariate quantile envelope over per-view scores
# =============================================================================

def _mvcp_sample_directions(K: int, M: int, rng: np.random.RandomState) -> np.ndarray:
    """Uniform on the positive orthant of the unit (K-1)-sphere."""
    V = rng.randn(M, K)
    V = np.abs(V)
    V = V / np.linalg.norm(V, axis=1, keepdims=True)
    return V  # shape (M, K)

def _mvcp_quantile(a: np.ndarray, q: float) -> float:
    """Quantile helper with guard."""
    q = min(max(q, 0.0), 1.0)
    # Use method='higher' to avoid deprecated 'interpolation='
    return float(np.quantile(a, q, method="higher"))

def _mvcp_build_envelope(S1: np.ndarray, alpha: float, M: int, eps: float, rng: np.random.RandomState) -> Tuple[np.ndarray, np.ndarray]:
    """Fit the MVCP half-space envelope on the first calibration split.

    Samples M nonnegative unit directions and projects the (n1, K) score vectors onto
    them. It then bisects (at most 30 steps) a per-direction tail level
    ``beta in [alpha / M, alpha]``. Each threshold ``q_m`` is the (1 - beta) quantile of
    projection m. The search stops once the fraction of S1 rows inside every half-space
    ``{s : u_m . s <= q_m}`` is within ``eps`` of ``1 - alpha``.

    Returns:
        ``(U, q_e)``: directions of shape (M, K) and thresholds of shape (M,) at the final beta.
    """
    _, K = S1.shape
    U = _mvcp_sample_directions(K, M, rng)  # (M, K)
    Proj = S1 @ U.T                          # (n1, M)

    def thresholds(level):
        return np.array([_mvcp_quantile(Proj[:, m], 1.0 - level) for m in range(M)])

    target = 1.0 - alpha
    lo, hi = alpha / max(M, 1), alpha
    for _ in range(30):
        beta = 0.5 * (lo + hi)
        cov = (Proj <= thresholds(beta)[None, :]).all(axis=1).mean()
        if cov > target + eps:
            lo = beta    # over-covering: raise beta -> tighter thresholds
        elif cov < target - eps:
            hi = beta    # under-covering: lower beta -> looser thresholds
        else:
            break
    return U, thresholds(beta)

def _mvcp_adjust_scale(S2: np.ndarray, U: np.ndarray, q_e: np.ndarray, alpha: float) -> float:
    """Compute t*(s) = max_m u_m^T s / q_e,m on S2, then (1 - alpha)-quantile for scaling."""
    Proj2 = S2 @ U.T  # (n2, M)
    qsafe = np.maximum(q_e, 1e-12)
    t_vals = np.max(Proj2 / qsafe[None, :], axis=1)
    b_t = _mvcp_quantile(t_vals, 1.0 - alpha)
    return float(max(b_t, 1e-12))

def mvcp_fit(S_scores: np.ndarray, alpha: float, env_frac: float, M: int, eps: float, seed: int) -> Tuple[np.ndarray, np.ndarray, float]:
    """Fit the MVCP envelope and its conformal scale.

    Args:
        S_scores: (n_cal, K) per-view scores for the TRUE labels on the fuse-cal set.
        alpha: miscoverage level.
        env_frac: share of rows (at least one) used to build the envelope. The remaining rows calibrate the scale.
        M: number of projection directions.
        eps: coverage tolerance for the envelope search.
        seed: seed of the RandomState used for the shuffle and then for direction sampling.

    Returns:
        ``(U, q_e, b_t)``: directions, per-direction thresholds and scale.

    NOTE(review): when the envelope split takes every row (n1 >= n), the same rows are
    reused for the scale, which breaks the split-conformal independence assumption.
    """
    rng = np.random.RandomState(seed)
    n_total = S_scores.shape[0]
    n_env = max(1, int(np.floor(env_frac * n_total)))
    order = rng.permutation(n_total)  # random split of the calibration rows
    env_rows, scale_rows = order[:n_env], order[n_env:]
    S1 = S_scores[env_rows]
    S2 = S_scores[scale_rows] if n_env < n_total else S_scores[env_rows]
    U, q_e = _mvcp_build_envelope(S1, alpha, M, eps, rng)
    return U, q_e, _mvcp_adjust_scale(S2, U, q_e, alpha)

def mvcp_predict_sets(pr_views: List[np.ndarray], U: np.ndarray, q_e: np.ndarray, b_t: float) -> np.ndarray:
    """Build MVCP prediction sets from per-view test probabilities.

    A class ``y`` is included for a sample when its per-view score vector
    ``s = (1 - p_1[y], ..., 1 - p_K[y])`` satisfies ``u_m . s <= b_t * q_e[m]`` for every
    direction ``m``.

    Args:
        pr_views: K arrays of shape (n, L) with per-view class probabilities.
        U: (M, K) directions; q_e: (M,) thresholds; b_t: scale factor.

    Returns:
        Boolean array of shape (n, L); True where the class is in the set.
    """
    n, L = pr_views[0].shape
    stacked = np.stack(pr_views, axis=0)  # (K, n, L)
    U_T = U.T                             # (K, M)
    limits = q_e * b_t                    # (M,)

    included = np.zeros((n, L), dtype=bool)
    for cls in range(L):
        scores = (1.0 - stacked[:, :, cls]).T  # (n, K)
        included[:, cls] = np.all(scores @ U_T <= limits[None, :], axis=1)
    return included

# =============================================================================
# Metrics & tables
# =============================================================================

def evaluate_sets(P: np.ndarray, y_true: np.ndarray, alpha: float) -> Tuple[float, float]:
    """Coverage and mean size of the sets ``{y : P[i, y] > alpha}``.

    Returns:
        ``(coverage, avg_size)``: the fraction of rows whose true label is in its set,
        and the average number of labels per set.
    """
    included = P > alpha
    hits = included[np.arange(len(y_true)), y_true]
    sizes = included.sum(axis=1)
    return float(hits.mean()), float(sizes.mean())

def evaluate_sets_from_bool(C_bool: np.ndarray, y_true: np.ndarray) -> Tuple[float, float]:
    """Coverage and mean size of explicitly given boolean prediction sets (n, L)."""
    hits = C_bool[np.arange(len(y_true)), y_true]
    sizes = C_bool.sum(axis=1)
    return float(hits.mean()), float(sizes.mean())

def summarize_table(df: pd.DataFrame, methods: List[str], metric_name: str) -> pd.DataFrame:
    """Per-K "mean (std)" summary of each method column.

    Returns a DataFrame with columns ``["K", "Metric"] + methods``. Each method cell is
    formatted as ``"{mean:.2f} ({std:.2f})"`` over the rows that share that K (sample std).
    """
    stats = df.groupby("K").agg({m: ["mean", "std"] for m in methods})
    stats.columns = [f"{a}_{b}" for a, b in stats.columns]
    stats = stats.reset_index()
    for m in methods:
        means, stds = stats[f"{m}_mean"], stats[f"{m}_std"]
        stats[m] = [f"{mu:.2f} ({sd:.2f})" for mu, sd in zip(means, stds)]
    stats.insert(1, "Metric", metric_name)
    return stats[["K", "Metric"] + methods]

# =============================================================================
# Main experiment
# =============================================================================

def run_experiments(cfg: CIFARConfig):
    """Run the CIFAR-10 multi-view conformal experiment.

    For each simulation (split seed ``cfg.train_seed_base + sim``) and each K in ``cfg.Ks``:
      1. split the training set into predictor-train / per-view calibration /
         fusion-train / fusion-cal (all stratified),
      2. train one PredictorCNN per view and compute classwise calibration scores,
      3. evaluate conformal late fusion, Min-p, Fisher, adjusted Fisher, learned weighted
         average and MVCP prediction sets on the CIFAR-10 test set.

    Returns:
        ``(df_cov, df_size, df_acc)``: per-(sim, K) coverage in percent, average set size,
        and reference accuracy (argmax of the mean per-view probabilities).
    """
    results_cov, results_size, results_acc = [], [], []

    for sim in range(cfg.num_simulations):
        print(f"\n=== Simulation {sim+1}/{cfg.num_simulations} ===")
        seed = cfg.train_seed_base + sim

        # Splits (train for view predictors, then cal/fusion splits)
        X_trP, X_tmp, y_trP, y_tmp = train_test_split(
            X_train_full, Y_train_full, test_size=1 - cfg.train_frac, stratify=Y_train_full, random_state=seed
        )
        # Split X_tmp into the per-view calibration set and the fusion pool
        X_cal, X_rest, y_cal, y_rest = train_test_split(
            X_tmp, y_tmp, test_size=1 - cfg.cal_frac_of_temp, stratify=y_tmp, random_state=seed
        )
        X_fuse_tr, X_fuse_cal, y_fuse_tr, y_fuse_cal = train_test_split(
            X_rest, y_rest, test_size=1 - cfg.fuse_train_frac_of_rest, stratify=y_rest, random_state=seed
        )
        X_te, y_te = X_test_full, Y_test_full

        for K in cfg.Ks:
            print(f"\n  -> K = {K}")
            num_views = K  # split_image_into_k_patches yields exactly K patches (2x2 grid when K == 4)

            # Build loaders per view
            loaders = {}
            for v in range(num_views):
                tr_loader   = torch.utils.data.DataLoader(PatchesDataset(X_trP,      y_trP,      K, v), batch_size=cfg.batch_size, shuffle=True)
                cal_loader  = torch.utils.data.DataLoader(PatchesDataset(X_cal,      y_cal,      K, v), batch_size=cfg.batch_size, shuffle=False)
                ftr_loader  = torch.utils.data.DataLoader(PatchesDataset(X_fuse_tr,  y_fuse_tr,  K, v), batch_size=cfg.batch_size, shuffle=False)
                fcal_loader = torch.utils.data.DataLoader(PatchesDataset(X_fuse_cal, y_fuse_cal, K, v), batch_size=cfg.batch_size, shuffle=False)
                te_loader   = torch.utils.data.DataLoader(PatchesDataset(X_te,       y_te,       K, v), batch_size=cfg.batch_size, shuffle=False)
                loaders[v] = dict(train=tr_loader, cal=cal_loader, ftr=ftr_loader, fcal=fcal_loader, te=te_loader)

            # Train per-view CNNs and collect classwise calibration scores
            models, cal_classwise = [], []
            for v in range(num_views):
                print(f"    [View {v+1}/{num_views}] training...")
                m = PredictorCNN(num_classes=cfg.num_classes)
                m = train_model(m, loaders[v]["train"], num_epochs=cfg.epochs_per_view, lr=cfg.lr)
                models.append(m)
                sc, lab = compute_nonconformity_scores(m, loaders[v]["cal"])
                cal_classwise.append(classwise_scores(sc, lab, cfg.num_classes))

            # Per-view p-values / probs for fusion-train, fusion-cal and test
            pv_tr, pr_tr = [], []
            pv_cal, pr_cal = [], []
            pv_te, pr_te = [], []
            for v in range(num_views):
                p, pr = per_view_pvalues_and_probs(models[v], cal_classwise[v], loaders[v]["ftr"], cfg.num_classes)
                pv_tr.append(p); pr_tr.append(pr)
                p, pr = per_view_pvalues_and_probs(models[v], cal_classwise[v], loaders[v]["fcal"], cfg.num_classes)
                pv_cal.append(p); pr_cal.append(pr)
                p, pr = per_view_pvalues_and_probs(models[v], cal_classwise[v], loaders[v]["te"], cfg.num_classes)
                pv_te.append(p); pr_te.append(pr)

            # -----------------------------
            # Late-fusion LR (ours)
            # -----------------------------
            # `multi_class="multinomial"` is deprecated in scikit-learn; lbfgs with
            # more than two classes already uses the multinomial loss by default.
            X_ftr = build_fusion_features(pv_tr, pr_tr)
            fusion_lr = LogisticRegression(max_iter=cfg.max_iter_lr, solver="lbfgs", random_state=seed)
            fusion_lr.fit(X_ftr, y_fuse_tr)
            # NOTE(review): predict_proba columns follow fusion_lr.classes_; indexing them by
            # label assumes every class 0..L-1 occurs in y_fuse_tr (true for stratified CIFAR splits).

            # Fused calibration probs + classwise calibration scores
            X_fcal = build_fusion_features(pv_cal, pr_cal)
            fused_probs_cal = fusion_lr.predict_proba(X_fcal)
            fused_cal_scores = fused_class_cal_scores(y_fuse_cal, fused_probs_cal, cfg.num_classes)

            # Fused test probs -> conformal-fused p-values
            X_ftest = build_fusion_features(pv_te, pr_te)
            fused_probs_test = fusion_lr.predict_proba(X_ftest)
            P_cf = fused_p_values_from_cal(fused_probs_test, fused_cal_scores)

            # Baselines (stack per-view p-values)
            P_train = np.stack(pv_tr, axis=0)   # (K, n_tr, L)
            P_test = np.stack(pv_te, axis=0)    # (K, n_te, L)

            P_min = min_p_value_fusion(P_test)
            P_fisher = fisher_fusion(P_test)
            P_adjF = adjusted_fisher_fusion(P_train, y_fuse_tr, P_test, cfg.num_classes)

            # Learned weighted average from p-value-only features (K*L columns)
            pv_tr_concat = np.concatenate(pv_tr, axis=1)  # (n_tr, K*L)
            w_learned = learn_view_weights_from_pvals(pv_tr_concat, y_fuse_tr, num_views, cfg.num_classes, cfg.max_iter_lr, seed)
            P_wavgL = weighted_average_fusion(P_test, w_learned)

            # -----------------------------
            # MVCP (their method)
            # -----------------------------
            # TRUE-label multi-view scores on fuse-cal: S[i, v] = 1 - p_v(y_i | x_i)
            rows = np.arange(len(y_fuse_cal))
            S_scores = np.column_stack([1.0 - pr_cal[v][rows, y_fuse_cal] for v in range(num_views)])
            U, q_e, b_t = mvcp_fit(
                S_scores,
                alpha=cfg.alpha,
                env_frac=cfg.mvcp_env_frac,
                M=cfg.mvcp_num_dirs,
                eps=cfg.mvcp_eps,
                # Distinct seed per (sim, K); `base + sim + K` collided, e.g. (1, 3) vs (2, 2).
                seed=cfg.mvcp_rng_seed + 1000 * sim + K,
            )
            C_mvcp = mvcp_predict_sets(pr_te, U, q_e, b_t)  # (n_test, L) boolean

            # Metrics
            cov_cf, set_cf = evaluate_sets(P_cf, y_te, cfg.alpha)
            cov_min, set_min = evaluate_sets(P_min, y_te, cfg.alpha)
            cov_fi, set_fi = evaluate_sets(P_fisher, y_te, cfg.alpha)
            cov_afi, set_afi = evaluate_sets(P_adjF, y_te, cfg.alpha)
            cov_wl, set_wl = evaluate_sets(P_wavgL, y_te, cfg.alpha)
            cov_mv, set_mv = evaluate_sets_from_bool(C_mvcp, y_te)

            results_cov.append({
                "Sim": sim, "K": K,
                "Conformal Fusion": cov_cf * 100,
                "Min p-Value": cov_min * 100,
                "Fisher": cov_fi * 100,
                "Adjusted Fisher": cov_afi * 100,
                "Weighted Avg (learned)": cov_wl * 100,
                "MVCP (theirs)": cov_mv * 100,
            })
            results_size.append({
                "Sim": sim, "K": K,
                "Conformal Fusion": set_cf,
                "Min p-Value": set_min,
                "Fisher": set_fi,
                "Adjusted Fisher": set_afi,
                "Weighted Avg (learned)": set_wl,
                "MVCP (theirs)": set_mv,
            })

            # Reference accuracy: argmax of the simple average of per-view probs
            avg_probs = np.mean(np.stack(pr_te, axis=0), axis=0)
            acc_ref = accuracy_score(y_te, np.argmax(avg_probs, axis=1)) * 100
            results_acc.append({"Sim": sim, "K": K, "Reference Acc (avg probs)": acc_ref})

    return pd.DataFrame(results_cov), pd.DataFrame(results_size), pd.DataFrame(results_acc)

# =============================================================================
# Save/print tables (same style as synthetic)
# =============================================================================

def save_tables(df_cov: pd.DataFrame, df_size: pd.DataFrame, df_acc: pd.DataFrame):
    """Summarize per-simulation results and write CSV + LaTeX tables to the working directory.

    Writes coverage and set-size summaries, a combined side-by-side table with short
    method names, and per-K mean reference accuracy. Then prints the list of files.
    """
    short_names = {
        "Conformal Fusion": "CF",
        "Min p-Value": "MinPV",
        "Fisher": "Fisher",
        "Adjusted Fisher": "AdjF",
        "Weighted Avg (learned)": "WAvgL",
        "MVCP (theirs)": "MVCP",
    }
    methods = list(short_names)
    sum_cov = summarize_table(df_cov, methods, "Coverage (%)")
    sum_set = summarize_table(df_size, methods, "Average Set Size")

    def write_pair(frame, csv_path, tex_path, **latex_kwargs):
        frame.to_csv(csv_path, index=False)
        with open(tex_path, "w") as fh:
            fh.write(frame.to_latex(index=False, **latex_kwargs))

    # Per-metric summaries (CIFAR version)
    write_pair(sum_cov, "cifar_summary_coverage.csv", "cifar_summary_coverage.tex", escape=False)
    write_pair(sum_set, "cifar_summary_setsize.csv", "cifar_summary_setsize.tex", escape=False)

    # Compact side-by-side table: "<short> Cov" and "<short> Set" columns, merged on K
    cov_comp = sum_cov.drop(columns=["Metric"]).rename(
        columns={long: f"{short} Cov" for long, short in short_names.items()}
    )
    set_comp = sum_set.drop(columns=["Metric"]).rename(
        columns={long: f"{short} Set" for long, short in short_names.items()}
    )
    final = cov_comp.merge(set_comp, on="K").sort_values("K")
    write_pair(final, "cifar_summary_final.csv", "cifar_summary_final.tex", escape=False)

    # Mean reference accuracy per K
    acc_means = df_acc.groupby("K").mean(numeric_only=True).reset_index()
    write_pair(acc_means, "cifar_accuracy_summary.csv", "cifar_accuracy_summary.tex", float_format="%.2f")

    print("\nSaved:")
    print("  cifar_summary_coverage.csv / .tex")
    print("  cifar_summary_setsize.csv  / .tex")
    print("  cifar_summary_final.csv    / .tex")
    print("  cifar_accuracy_summary.csv / .tex")

# =============================================================================
# Entry
# =============================================================================

def main():
    """Run all simulations with the default config, preview raw rows, then save summary tables."""
    cfg = CIFARConfig()
    df_cov, df_size, df_acc = run_experiments(cfg)
    previews = (
        ("\n=== Coverage (raw rows) ===", df_cov),
        ("\n=== Set Size (raw rows) ===", df_size),
    )
    for title, frame in previews:
        print(title)
        print(frame.head())
    save_tables(df_cov, df_size, df_acc)

if __name__ == "__main__":
    main()


Using device: cuda
Files already downloaded and verified
Files already downloaded and verified

=== Simulation 1/10 ===

  -> K = 2
    [View 1/2] training...
  epoch 25/100
  epoch 50/100
  epoch 75/100
  epoch 100/100
    [View 2/2] training...
  epoch 25/100
  epoch 50/100
  epoch 75/100
  epoch 100/100

  -> K = 3
    [View 1/3] training...
  epoch 25/100
  epoch 50/100
  epoch 75/100
  epoch 100/100
    [View 2/3] training...
  epoch 25/100
  epoch 50/100
  epoch 75/100
  epoch 100/100
    [View 3/3] training...
  epoch 25/100
  epoch 50/100
  epoch 75/100
  epoch 100/100

  -> K = 4
    [View 1/4] training...
  epoch 25/100
  epoch 50/100
  epoch 75/100
  epoch 100/100
    [View 2/4] training...
  epoch 25/100
  epoch 50/100
  epoch 75/100
  epoch 25/100
  epoch 50/100
  epoch 75/100
  epoch 100/100

  -> K = 6
    [View 1/6] training...
  epoch 25/100
  epoch 50/100
  epoch 75/100
  epoch 100/100
    [View 2/6] training...
  epoch 25/100
  epoch 50/100
  epoch 75/100
  epoch 100

  epoch 25/100
  epoch 50/100
  epoch 75/100
  epoch 100/100
    [View 4/4] training...
  epoch 25/100
  epoch 50/100
  epoch 75/100
  epoch 100/100

  -> K = 6
    [View 1/6] training...
  epoch 25/100
  epoch 50/100
  epoch 75/100
  epoch 100/100
    [View 2/6] training...
  epoch 25/100
  epoch 50/100
  epoch 75/100
  epoch 100/100
    [View 3/6] training...
  epoch 25/100
  epoch 50/100
  epoch 75/100
  epoch 100/100
    [View 4/6] training...
  epoch 25/100
  epoch 50/100
  epoch 75/100
  epoch 100/100
    [View 5/6] training...
  epoch 25/100
  epoch 50/100
  epoch 75/100
  epoch 100/100

=== Simulation 8/10 ===

  -> K = 2
    [View 1/2] training...
  epoch 25/100
  epoch 50/100
  epoch 75/100
  epoch 100/100
    [View 2/2] training...
  epoch 25/100
  epoch 50/100
  epoch 75/100
  epoch 100/100

  -> K = 3
    [View 1/3] training...
  epoch 25/100
  epoch 50/100
  epoch 75/100
  epoch 100/100
    [View 2/3] training...
  epoch 25/100
  epoch 50/100
  epoch 75/100
  epoch 100/10