In [None]:
from __future__ import annotations

import os
import warnings
from dataclasses import dataclass
from typing import Dict, List, Tuple
import copy

import numpy as np
import pandas as pd
from scipy.stats import chi2
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.transforms import Resize

import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
import scienceplots

from tqdm.auto import tqdm, trange

# Silence every warning process-wide.
# NOTE(review): this also hides deprecation warnings (e.g. from scikit-learn or
# torch) that would flag API drift; consider narrowing the filter.
warnings.filterwarnings("ignore")
# Let cuDNN benchmark and cache conv algorithms; inputs here have a fixed size
# because every patch is resized to 32x32 in PatchesDataset.
torch.backends.cudnn.benchmark = True

# =============================================================================
# Config
# =============================================================================

@dataclass
class CIFARConfig:
    """Hyper-parameters for the multi-view conformal CIFAR-100 benchmark."""

    # core
    alpha: float = 0.1                  # miscoverage level; sets keep classes with p-value > alpha
    Ks: Tuple[int, ...] = (2, 3, 4, 6)  # numbers of views (patches) to sweep over
    num_classes: int = 100
    num_simulations: int = 10   # independent repetitions (seed = train_seed_base + sim); lower while debugging

    # training (SIMPLE)
    epochs_per_view: int = 40  # maximum epochs per view model (early stopping may end sooner)
    lr: float = 3e-4           # Adam learning rate
    batch_size: int = 256
    max_iter_lr: int = 1000    # max_iter for the scikit-learn LogisticRegression models
    train_seed_base: int = 41  # per-simulation seed = train_seed_base + simulation index
    eval_interval: int = 5     # print eval every N epochs (early stop still checks every epoch)

    # early stopping (monitored on the calibration split's cross-entropy)
    early_stop_patience: int = 5        # epochs without improvement before stopping
    early_stop_min_delta: float = 1e-3  # minimum CE decrease that counts as an improvement

    # data split fractions (all stratified, applied to the CIFAR-100 training set)
    train_frac: float = 0.6                # fraction used to train the per-view models
    cal_frac_of_temp: float = 0.1          # fraction of the remaining data used for per-view calibration
    fuse_train_frac_of_rest: float = 0.5   # fraction of what is left used to fit the fusion model (rest: fusion calibration)

    # MVCP
    mvcp_M: int = 100          # number of random projection directions
    mvcp_eps: float = 1e-8     # lower floor applied to the per-direction quantiles
    mvcp_bs_iters: int = 30    # bisection iterations for the per-direction level beta

# =============================================================================
# Torch / Data
# =============================================================================

# Use the first CUDA device when available; every model/tensor below is moved here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Per-channel CIFAR-100 mean/std used by Normalize (critical for healthy training)
CIFAR100_MEAN = (0.5071, 0.4867, 0.4408)
CIFAR100_STD  = (0.2675, 0.2565, 0.2761)

# Augmenting pipeline for per-view training loaders (PatchesDataset with train_like=True).
# Input is a PIL image of the resized 32x32 patch.
train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR100_MEAN, CIFAR100_STD),
])
# Deterministic pipeline for calibration / fusion / test loaders.
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR100_MEAN, CIFAR100_STD),
])

# Load once. transform=None: the raw arrays are extracted below and PatchesDataset
# applies the patching and transforms itself.
train_dataset = datasets.CIFAR100(root="./data", train=True, download=True, transform=None)
test_dataset  = datasets.CIFAR100(root="./data", train=False, download=True, transform=None)

# Presumably (N, 32, 32, 3) channel-last images (PatchesDataset transposes with
# (2, 0, 1)); /255 rescales to [0, 1] assuming 8-bit source pixels — TODO confirm.
# NOTE(review): the float32 copy is ~4x the memory of the raw 8-bit data.
X_train_full = train_dataset.data.astype(np.float32) / 255.0
Y_train_full = np.array(train_dataset.targets)
X_test_full  = test_dataset.data.astype(np.float32) / 255.0
Y_test_full  = np.array(test_dataset.targets)

# =============================================================================
# Multi-view (patch) utilities
# =============================================================================

def split_image_into_k_patches(image: torch.Tensor, k: int) -> List[torch.Tensor]:
    """Split a CHW image into ``k`` non-overlapping views.

    For ``k == 4`` the image is cut into a 2x2 grid (top-left, top-right,
    bottom-left, bottom-right). For any other ``k`` it is cut into ``k`` vertical
    stripes; when the width is not divisible by ``k`` the first ``W % k`` stripes
    are one column wider. The returned patches are slices (views) of ``image``.

    Args:
        image: tensor of shape ``(C, H, W)``.
        k: number of views; must be positive.

    Returns:
        List of ``k`` tensors whose union covers every pixel of ``image``.

    Raises:
        ValueError: if ``k`` is not positive.
    """
    if k <= 0:
        raise ValueError(f"k must be positive, got {k}")
    _, H, W = image.shape

    if k == 4:  # 2x2 grid; the split point is derived from the actual size (16 for 32x32)
        h_mid, w_mid = H // 2, W // 2
        row_spans = [(0, h_mid), (h_mid, H)]
        col_spans = [(0, w_mid), (w_mid, W)]
        return [
            image[:, r0:r1, c0:c1]
            for r0, r1 in row_spans
            for c0, c1 in col_spans
        ]

    # vertical stripes
    base_width, remainder = divmod(W, k)
    patches, start = [], 0
    for idx in range(k):
        width = base_width + (1 if idx < remainder else 0)
        patches.append(image[:, :, start:start + width])
        start += width
    return patches

class PatchesDataset(torch.utils.data.Dataset):
    """Dataset yielding one view (patch) of each image, resized to 32x32 and transformed.

    Args:
        images: array where ``images[idx]`` is a channel-last image (it is
            transposed with ``(2, 0, 1)`` to CHW); values are expected in ``[0, 1]``
            — TODO confirm for new callers.
        labels: integer class labels aligned with ``images``.
        k: number of views each image is split into (see ``split_image_into_k_patches``).
        view: index of the patch this dataset yields; must satisfy ``0 <= view < k``.
        train_like: use the augmenting ``train_transform`` instead of ``test_transform``.

    Raises:
        ValueError: if ``view`` is outside ``[0, k)``.
    """

    def __init__(self, images: np.ndarray, labels: np.ndarray, k: int, view: int, train_like: bool=False):
        # Fail at construction instead of inside a DataLoader worker on first access.
        if not 0 <= view < k:
            raise ValueError(f"view must be in [0, {k}), got {view}")
        self.images = images
        self.labels = labels
        self.k = k
        self.view = view
        self.resize = Resize((32, 32))
        self.transform = train_transform if train_like else test_transform
        self.to_pil = transforms.ToPILImage()

    def __len__(self) -> int:
        return len(self.labels)

    def __getitem__(self, idx: int):
        """Return ``(transformed_patch, label)`` for item ``idx``."""
        img = self.images[idx].transpose((2, 0, 1))   # HWC -> CHW, (3,32,32) for CIFAR
        img = torch.tensor(img, dtype=torch.float32)
        patch = self.resize(split_image_into_k_patches(img, self.k)[self.view])
        # torchvision's ToPILImage converts float tensors by scaling with 255 and
        # casting to uint8, which truncates (e.g. 0.99999 -> 0) and biases pixels
        # down by up to one grey level. Quantise with explicit rounding and pass
        # a uint8 tensor instead.
        patch_u8 = patch.clamp(0.0, 1.0).mul(255.0).round().to(torch.uint8)
        patch = self.transform(self.to_pil(patch_u8))
        return patch, int(self.labels[idx])

# =============================================================================
# Simple ResNet-18 for CIFAR
# =============================================================================

class BasicBlock(nn.Module):
    """ResNet basic block: two 3x3 conv+BN layers plus a residual shortcut.

    The shortcut is the identity unless the block changes the stride or the
    channel count, in which case a 1x1 conv + BN projection aligns the shapes.
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        # Submodules are created in a fixed order: it determines the state_dict
        # key order and the RNG draws consumed by weight initialisation.
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        needs_projection = stride != 1 or in_planes != planes
        self.downsample = (
            nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes),
            )
            if needs_projection
            else None
        )

    def forward(self, x):
        residual = self.relu(self.bn1(self.conv1(x)))
        residual = self.bn2(self.conv2(residual))
        shortcut = x if self.downsample is None else self.downsample(x)
        return self.relu(residual + shortcut)

def _make_layer(in_planes, planes, num_blocks, stride):
    """Stack ``num_blocks`` BasicBlocks; only the first one may change stride/width."""
    head = BasicBlock(in_planes, planes, stride)
    tail = [BasicBlock(planes, planes, 1) for _ in range(num_blocks - 1)]
    return nn.Sequential(head, *tail)

class PredictorCNN(nn.Module):
    """CIFAR-style ResNet-18 classifier (3x3 stride-1 stem, no max-pool, 4 stages x 2 blocks).

    Args:
        num_classes: size of the output logit vector.
        width: channel count of the first stage; stages use 1x, 2x, 4x, 8x ``width``.
        p_drop: dropout probability applied to the pooled features before the
            final linear layer. ``0`` (the default) disables dropout entirely.
    """

    def __init__(self, num_classes=100, width=64, p_drop=0.0):
        super().__init__()
        self.conv1 = nn.Conv2d(3, width, 3, 1, 1, bias=False)
        self.bn1   = nn.BatchNorm2d(width)
        self.relu  = nn.ReLU(inplace=True)
        self.layer1 = _make_layer(width,   width,   2, 1)   # 32x32 for 32x32 input
        self.layer2 = _make_layer(width,   2*width, 2, 2)   # 16x16
        self.layer3 = _make_layer(2*width, 4*width, 2, 2)   # 8x8
        self.layer4 = _make_layer(4*width, 8*width, 2, 2)   # 4x4
        self.pool  = nn.AdaptiveAvgPool2d(1)
        self.fc    = nn.Linear(8*width, num_classes)
        # p_drop was previously accepted but ignored. Dropout has no parameters, so
        # state_dicts are unchanged; Identity keeps the p_drop=0 graph identical.
        self.dropout = nn.Dropout(p_drop) if p_drop > 0 else nn.Identity()

    def forward(self, x):
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.layer1(x); x = self.layer2(x); x = self.layer3(x); x = self.layer4(x)
        x = self.pool(x)
        return self.fc(self.dropout(torch.flatten(x, 1)))

# =============================================================================
# Eval + SIMPLE training + EARLY STOP
# =============================================================================

@torch.no_grad()
def eval_ce_acc(model: nn.Module, loader, device: torch.device) -> Tuple[float, float]:
    """Return ``(mean cross-entropy, accuracy)`` of ``model`` over ``loader``.

    The model is switched to eval mode (and left there). An empty loader
    yields ``(0.0, 0.0)``.
    """
    model.eval()
    total_nll = 0.0
    hits = 0
    seen = 0
    for inputs, targets in loader:
        inputs = inputs.to(device)
        targets = targets.to(device, dtype=torch.long)
        scores = model(inputs)
        total_nll += F.cross_entropy(scores, targets, reduction="sum").item()
        hits += scores.argmax(dim=1).eq(targets).sum().item()
        seen += targets.numel()
    denom = max(seen, 1)
    return total_nll / denom, hits / denom

def train_model(
    model: nn.Module,
    train_loader,
    num_epochs: int,
    lr: float,
    desc: str | None = None,
    eval_loader=None,
    eval_interval: int = 5,
    early_stop_patience: int = 7,
    early_stop_min_delta: float = 1e-3,
    weight_decay: float = 1e-4,
):
    """Train ``model`` with Adam + cross-entropy, optionally early-stopping on ``eval_loader``.

    Args:
        model: network to train; it is moved to the module-level ``device``.
        train_loader: iterable of ``(inputs, labels)`` batches.
        num_epochs: maximum number of epochs.
        lr: Adam learning rate.
        desc: label used in progress bars and log lines.
        eval_loader: if given, its mean cross-entropy is computed after every epoch;
            training stops after ``early_stop_patience`` epochs without an
            improvement larger than ``early_stop_min_delta``, and the best weights
            are restored before returning.
        eval_interval: log train/eval metrics every this many epochs (and on the
            final epoch). Early stopping is still checked every epoch.
        early_stop_patience: epochs without improvement before stopping.
        early_stop_min_delta: minimum CE decrease counted as an improvement.
        weight_decay: Adam weight decay (L2 penalty).

    Returns:
        The trained model (same object as ``model``).
    """
    crit = nn.CrossEntropyLoss()
    opt  = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

    model.to(device)
    tag = desc or "Train"
    interval = max(eval_interval, 1)

    best_state = copy.deepcopy(model.state_dict())
    best_cal_ce = float("inf")
    epochs_no_improve = 0

    epoch_iter = trange(num_epochs, desc=desc or "Training", leave=False)
    for ep in epoch_iter:
        model.train()
        nll, n_correct, n_total = 0.0, 0, 0

        for xb, yb in tqdm(train_loader, total=len(train_loader), desc=f"Epoch {ep+1}/{num_epochs}", leave=False):
            xb = xb.to(device)
            yb = yb.to(device, dtype=torch.long)
            opt.zero_grad(set_to_none=True)
            logits = model(xb)
            loss = crit(logits, yb)
            loss.backward()
            opt.step()
            nll += loss.item() * yb.numel()
            n_correct += (logits.detach().argmax(1) == yb).sum().item()
            n_total += yb.numel()

        train_ce = nll / max(n_total, 1)
        train_acc = n_correct / max(n_total, 1)
        report = ((ep + 1) % interval) == 0 or ep == num_epochs - 1
        if report:
            tqdm.write(f"[{tag}] Epoch {ep+1}/{num_epochs}  Train CE: {train_ce:.4f}  Acc: {train_acc:.4f}")

        if eval_loader is None:
            continue

        # ---- Early stopping check EVERY epoch on the eval split ----
        cal_ce, cal_acc = eval_ce_acc(model, eval_loader, device)
        if report:  # same cadence as the train log, including the final epoch
            tqdm.write(f"[{tag}] >>> Eval @ epoch {ep+1}:  Cal CE: {cal_ce:.4f}  Acc: {cal_acc:.4f}")

        if cal_ce < best_cal_ce - early_stop_min_delta:
            best_cal_ce = cal_ce
            best_state = copy.deepcopy(model.state_dict())
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1

        if epochs_no_improve >= early_stop_patience:
            tqdm.write(f"[{tag}] Early stopping at epoch {ep+1}. "
                       f"Best Cal CE: {best_cal_ce:.4f} (restoring best weights).")
            break

    # Restore the best weights exactly once, whether or not training stopped early.
    if eval_loader is not None:
        model.load_state_dict(best_state)
    return model

# =============================================================================
# Conformal utilities (unchanged)
# =============================================================================

def compute_nonconformity_scores(model: nn.Module, loader) -> Tuple[np.ndarray, np.ndarray]:
    """Score ``1 - p(true label)`` for every example in ``loader``.

    Args:
        model: classifier producing logits; it is put in eval mode.
        loader: iterable of ``(inputs, labels)`` batches; labels may be a tensor
            or any array-like of ints.

    Returns:
        ``(scores, labels)``: float64 scores and int labels in loader order; both
        are empty when the loader yields no batches.
    """
    model.eval()
    score_chunks, label_chunks = [], []
    with torch.no_grad():
        for xb, yb in tqdm(loader, total=len(loader), desc="Cal: nonconformity", leave=False):
            probs = F.softmax(model(xb.to(device)), dim=1)
            # as_tensor avoids the copy-construct warning/copy that torch.tensor()
            # incurs when yb is already a tensor, and still accepts plain lists.
            y = torch.as_tensor(yb, dtype=torch.long, device=probs.device)
            true_p = probs[torch.arange(probs.size(0), device=probs.device), y]
            score_chunks.append((1 - true_p).cpu().numpy())
            label_chunks.append(y.cpu().numpy())
    if not score_chunks:
        return np.asarray([], float), np.asarray([], int)
    return np.concatenate(score_chunks).astype(float), np.concatenate(label_chunks).astype(int)

def classwise_scores(scores: np.ndarray, labels: np.ndarray, L: int) -> Dict[int, np.ndarray]:
    """Group nonconformity ``scores`` by class label.

    Returns a dict with one float64 array per class ``0..L-1`` (empty for classes
    that never occur). A label outside ``0..L-1`` raises ``KeyError``.
    """
    grouped = {cls: [] for cls in range(L)}
    for label, score in zip(labels, scores):
        grouped[int(label)].append(float(score))
    return {cls: np.array(vals, dtype=float) for cls, vals in grouped.items()}

def per_view_pvalues_and_probs(model: nn.Module, class_scores: Dict[int, np.ndarray], loader, L: int) -> Tuple[np.ndarray, np.ndarray]:
    """Class-conditional conformal p-values and softmax probabilities for one view.

    For each class ``y`` the p-value of a test point with score ``s = 1 - p_y`` is
    ``(1 + #{calibration scores of class y >= s}) / (n_y + 1)``; a class without
    calibration scores gets p-value 1.

    Returns:
        ``(pvals, probs)``, both of shape ``(n_examples, L)`` in loader order.
    """
    model.eval()
    with torch.no_grad():
        batches = [
            F.softmax(model(xb.to(device)), dim=1).detach().cpu().numpy()
            for xb, _ in tqdm(loader, total=len(loader), desc="Infer: probs", leave=False)
        ]
    probs = np.vstack(batches)

    pvals = np.zeros((probs.shape[0], L))
    for label in range(L):
        ref = class_scores.get(label, np.array([]))
        if ref.size == 0:
            pvals[:, label] = 1.0
            continue
        test_scores = 1 - probs[:, label]
        n_ge = (ref[:, None] >= test_scores[None, :]).sum(axis=0)
        pvals[:, label] = (1 + n_ge) / (len(ref) + 1)
    return pvals, probs

@torch.no_grad()
def compute_probs(model: nn.Module, loader) -> np.ndarray:
    """Return stacked softmax probabilities of ``model`` over ``loader`` (eval mode)."""
    model.eval()
    chunks = []
    for inputs, _labels in tqdm(loader, total=len(loader), desc="Infer: probs", leave=False):
        logits = model(inputs.to(device))
        chunks.append(F.softmax(logits, dim=1).detach().cpu().numpy())
    return np.vstack(chunks)

# =============================================================================
# Fusion utilities (unchanged)
# =============================================================================

def build_fusion_features(pvals_list: List[np.ndarray], probs_list: List[np.ndarray]) -> np.ndarray:
    """Concatenate ``[pvals_k, probs_k]`` for every view k along the feature axis."""
    per_view = []
    for k, pv in enumerate(pvals_list):
        per_view.append(np.hstack([pv, probs_list[k]]))
    return np.hstack(per_view)

def min_p_value_fusion(P_all: np.ndarray) -> np.ndarray:
    """Bonferroni-corrected minimum p-value across views.

    ``P_all`` has the view axis first; returns ``K * min_k p_k`` with that axis
    removed. Results are not clipped, so they can exceed 1 (harmless for
    ``P > alpha`` thresholding).
    """
    n_views = P_all.shape[0]
    smallest = P_all.min(axis=0)
    return n_views * smallest

def fisher_fusion(P_all: np.ndarray) -> np.ndarray:
    """Combine per-view p-values with Fisher's method (assumes independent views).

    Args:
        P_all: p-values with the view axis first, shape ``(K, ...)``.

    Returns:
        Combined p-values with the view axis removed: the upper tail of
        ``T = -2 * sum_k log p_k`` under a chi-squared law with ``2K`` degrees of
        freedom.
    """
    eps = 1e-12
    p = np.clip(P_all, eps, 1.0)  # guard log(0)
    T = -2 * np.sum(np.log(p), axis=0)
    df = 2 * P_all.shape[0]
    # chi2.sf evaluates the upper tail directly; 1 - chi2.cdf rounds to exactly 0
    # once the combined p-value drops below ~1e-16.
    return chi2.sf(T, df=df)

def adjusted_fisher_fusion(P_train: np.ndarray, y_train: np.ndarray, P_test: np.ndarray, L: int) -> np.ndarray:
    """Dependence-adjusted Fisher combination (Brown-style moment matching), per class.

    For each class ``y`` the covariance of ``W_k = -2 log p_k`` across views is
    estimated from training points labelled ``y``. ``T = sum_k W_k`` is then
    approximated by ``c * chi2(f)`` with ``E[T] = 2K`` and ``Var[T] = sum(Sigma)``,
    i.e. ``c = Var/(4K)`` and ``f = 8K^2/Var``. Classes with fewer than 5 training
    points, or a non-positive/non-finite variance estimate, fall back to the
    independent case (plain Fisher).

    Args:
        P_train: per-view training p-values, shape ``(K, n_train, L)``.
        y_train: training labels, length ``n_train``.
        P_test: per-view test p-values, shape ``(K, n_test, L)``.
        L: number of classes.

    Returns:
        Combined p-values of shape ``(n_test, L)``.
    """
    K = P_train.shape[0]
    n_test = P_test.shape[1]
    eps = 1e-12
    y_train = np.asarray(y_train)

    # Test statistic per (test point, class), computed once for every class.
    T_test = -2 * np.sum(np.log(np.clip(P_test, eps, 1.0)), axis=0)   # (n_test, L)
    df_indep = 2 * P_test.shape[0]

    out = np.zeros((n_test, L))
    for y in range(L):
        idx = np.where(y_train == y)[0]
        if idx.size < 5:
            # Too few samples to estimate dependence: plain Fisher for this class.
            out[:, y] = chi2.sf(T_test[:, y], df=df_indep)
            continue
        W = -2 * np.log(np.clip(P_train[:, idx, y], eps, 1.0))   # (K, n_y)
        Wc = W - W.mean(axis=1, keepdims=True)
        Sigma = (Wc @ Wc.T) / max(W.shape[1] - 1, 1)
        var_T = float(np.sum(Sigma))
        if not np.isfinite(var_T) or var_T <= 0:
            var_T = 4.0 * K   # variance of T under independence
        f_y = (8.0 * K * K) / var_T
        c_y = var_T / (4 * K)
        # sf keeps precision for tiny p-values where 1 - cdf underflows to 0.
        out[:, y] = chi2.sf(T_test[:, y] / c_y, df=f_y)
    return out

def weighted_average_fusion(P_all: np.ndarray, weights: np.ndarray) -> np.ndarray:
    """Weighted average of per-view p-values: ``sum_k weights[k] * P_all[k]``."""
    view_axis = 0
    # Contract the weight vector against the leading (view) axis of P_all.
    return np.tensordot(weights, P_all, axes=(view_axis, view_axis))

def learn_view_weights_from_pvals(pv_train_concat: np.ndarray, y_train: np.ndarray, K: int, L: int, max_iter: int, seed: int) -> np.ndarray:
    """Learn per-view fusion weights from a logistic regression on concatenated p-values.

    A logistic regression is fit on ``pv_train_concat`` (the K per-view blocks of L
    p-value columns, concatenated). Each view's importance is the Frobenius norm of
    its coefficient block; importances are floored at 1e-12 and normalised to sum to 1.

    Returns:
        Array of ``K`` non-negative weights summing to 1.
    """
    # The lbfgs solver already fits a multinomial model when there are more than two
    # classes; the explicit multi_class="multinomial" argument is deprecated in
    # scikit-learn >= 1.5 and scheduled for removal.
    clf = LogisticRegression(solver="lbfgs", max_iter=max_iter, random_state=seed)
    clf.fit(pv_train_concat, y_train)
    coef = clf.coef_
    importances = np.array(
        [np.linalg.norm(coef[:, k * L:(k + 1) * L], ord="fro") for k in range(K)],
        dtype=float,
    )
    importances = np.maximum(importances, 1e-12)   # keep every view strictly positive
    return importances / importances.sum()

def fused_class_cal_scores(y_cal: np.ndarray, fused_probs_cal: np.ndarray, L: int) -> Dict[int, np.ndarray]:
    """Group fused calibration nonconformity scores ``1 - p(true label)`` by class.

    Returns one float64 array per class ``0..L-1`` (empty for absent classes).
    """
    true_probs = fused_probs_cal[np.arange(len(y_cal)), y_cal]
    per_class = {label: [] for label in range(L)}
    for score, label in zip(1 - true_probs, y_cal):
        per_class[int(label)].append(float(score))
    return {label: np.asarray(vals, float) for label, vals in per_class.items()}

def fused_p_values_from_cal(fused_probs: np.ndarray, cal_class_scores: Dict[int, np.ndarray]) -> np.ndarray:
    """Class-conditional conformal p-values for fused probabilities.

    For class ``y`` and test score ``s = 1 - p_y`` the p-value is
    ``(1 + #{calibration scores of y >= s}) / (n_y + 1)``; classes without
    calibration scores get p-value 1.
    """
    n_rows, n_classes = fused_probs.shape
    pvals = np.zeros((n_rows, n_classes))
    for cls in range(n_classes):
        ref = cal_class_scores.get(cls, np.array([]))
        if ref.size == 0:
            pvals[:, cls] = 1.0
            continue
        test_scores = 1 - fused_probs[:, cls]
        at_least = (ref[:, None] >= test_scores[None, :]).sum(axis=0)
        pvals[:, cls] = (at_least + 1) / (len(ref) + 1)
    return pvals

def evaluate_sets(P: np.ndarray, y_true: np.ndarray, alpha: float) -> Tuple[float, float]:
    """Coverage and mean size of the prediction sets ``{y : P[i, y] > alpha}``."""
    in_set = P > alpha
    rows = np.arange(len(y_true))
    coverage = in_set[rows, y_true].mean()
    avg_size = in_set.sum(axis=1).mean()
    return float(coverage), float(avg_size)

# =============================================================================
# MVCP
# =============================================================================

def _empirical_quantile_conservative(vals: np.ndarray, q: float) -> float:
    """Conformal-style upper quantile: the ceil((n+1)q)-th smallest value.

    The rank is clamped to ``[1, n]``, so levels above ``n/(n+1)`` return the
    maximum rather than infinity. Empty input raises ``IndexError``.
    """
    ordered = np.sort(vals)
    n = len(ordered)
    rank = int(np.ceil((n + 1) * q))
    rank = min(max(rank, 1), n)
    return ordered[rank - 1]

def mvcp_quantile_envelope(S: np.ndarray, alpha: float, M: int, eps: float, bs_iters: int, rng: np.random.RandomState):
    """Fit the MVCP envelope on multi-view nonconformity scores.

    ``S`` (shape ``(N_dc, K)``) is split in half:
      1. ``M`` random unit directions in the non-negative orthant are drawn. On the
         first half, every direction gets the conservative ``(1 - beta)`` quantile
         of the projected scores (floored at ``eps``). ``beta`` is found by bisection
         in ``[alpha/M, alpha]`` as the largest level whose joint (all-directions)
         coverage on the first half still exceeds ``1 - alpha - 1/(n1 + 1)``.
      2. On the second half, the per-point maximum ratio ``proj / quantile`` is
         computed, and its conservative ``(1 - alpha)`` quantile ``t_hat`` rescales
         all direction quantiles (split-conformal calibration).

    Returns:
        ``(U, q_hats)`` with ``U`` of shape ``(M, K)`` (unit directions) and
        ``q_hats`` of shape ``(M,)``; a score vector ``s`` is inside the envelope
        iff ``U @ s <= q_hats`` elementwise.
    """
    N_dc, K = S.shape
    idx1, idx2 = train_test_split(np.arange(N_dc), test_size=0.5, random_state=rng.randint(0, 10_000))
    S1, S2 = S[idx1], S[idx2]

    V = np.abs(rng.randn(M, K))
    norms = np.linalg.norm(V, axis=1, ord=2)[:, None]; norms[norms == 0] = 1.0
    U = V / norms

    # Projections of the first half onto every direction. They do not depend on beta,
    # so compute (and sort) them once. The same matrix is used both for the quantiles
    # and for the coverage check, so the two agree bit-for-bit.
    proj1 = S1 @ U.T                        # (n1, M)
    proj1_sorted = np.sort(proj1, axis=0)
    n1 = proj1.shape[0]
    target = 1.0 - alpha - 1.0 / (n1 + 1)

    def _direction_quantiles(beta: float) -> np.ndarray:
        # Column-wise version of _empirical_quantile_conservative(proj, 1 - beta).
        k = int(np.ceil((n1 + 1) * (1 - beta)))
        k = min(max(k, 1), n1)
        return np.maximum(proj1_sorted[k - 1], eps)

    beta_lo, beta_hi = alpha / M, alpha
    for _ in range(bs_iters):
        beta = 0.5 * (beta_lo + beta_hi)
        qms = _direction_quantiles(beta)
        cov = np.mean(np.all(proj1 <= qms[None, :], axis=1))
        if cov > target:
            beta_lo = beta   # still over-covering: a larger beta (tighter quantiles) is allowed
        else:
            beta_hi = beta

    qms = _direction_quantiles(beta_lo)

    ratios = (S2 @ U.T) / qms[None, :]
    t_stars = np.max(ratios, axis=1)
    t_hat = _empirical_quantile_conservative(t_stars, 1 - alpha)
    q_hats = t_hat * qms
    return U, q_hats

def mvcp_predict_sets(probs_te: List[np.ndarray], y_te: np.ndarray, U: np.ndarray, q_hats: np.ndarray,
                      num_classes: int, alpha: float, chunk_size: int = 512) -> Tuple[float, float]:
    """Coverage and mean size of MVCP prediction sets on a test set.

    Class ``y`` is in the set for test point ``i`` iff its multi-view score vector
    ``s = (1 - probs_te[k][i, y])_k`` satisfies ``U @ s <= q_hats`` for every direction.

    Args:
        probs_te: K arrays of per-view probabilities, each ``(N_te, num_classes)``.
        y_te: true labels, length ``N_te``.
        U: directions, shape ``(M, K)``.
        q_hats: envelope thresholds, shape ``(M,)``.
        num_classes: number of classes L.
        alpha: unused here (the level is already built into ``q_hats``); kept for
            interface compatibility.
        chunk_size: test rows processed at once. This bounds the temporary
            ``(chunk, L, M)`` projection tensor instead of materialising the full
            ``(N_te, L, M)`` array (~800 MB of float64 for 10k x 100 x 100).

    Returns:
        ``(coverage, average_set_size)``.

    Raises:
        ValueError: if ``chunk_size`` is not positive.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    y_te = np.asarray(y_te)
    N_te = len(y_te); K = len(probs_te); L = num_classes
    S_test = np.zeros((N_te, L, K), dtype=np.float64)
    for k in range(K):
        S_test[:, :, k] = 1.0 - probs_te[k]

    threshold = q_hats[None, None, :]
    in_Q = np.empty((N_te, L), dtype=bool)
    for start in range(0, N_te, chunk_size):
        stop = start + chunk_size
        # einsum computes each output element independently, so chunking over rows
        # gives exactly the same projections as a single full-size call.
        proj = np.einsum('nyk,mk->nym', S_test[start:stop], U)
        in_Q[start:stop] = np.all(proj <= threshold, axis=2)

    cov = float(np.mean(in_Q[np.arange(N_te), y_te]))
    size = float(np.mean(np.sum(in_Q, axis=1)))
    return cov, size


# =============================================================================
# Main experiment
# =============================================================================

def run_experiments(cfg: CIFARConfig):
    """Run every (simulation, K) configuration and collect per-method metrics.

    For each simulation (seed ``cfg.train_seed_base + sim``), the CIFAR-100 training
    set is split with stratification into:
      * per-view model training data (``cfg.train_frac``),
      * a per-view calibration split (``cfg.cal_frac_of_temp`` of the remainder),
      * fusion-train / fusion-calibration splits (``cfg.fuse_train_frac_of_rest`` of
        what is left goes to fusion-train).
    The official CIFAR-100 test split is the evaluation set.

    For each K in ``cfg.Ks``, one ``PredictorCNN`` per view is trained (early
    stopping on the calibration split). Per-view conformal p-values and softmax
    probabilities then feed conformal fusion (logistic regression on [p, prob]
    features), min-p, Fisher, adjusted Fisher, learned weighted averaging and MVCP.

    Returns:
        ``(df_cov, df_size, df_acc)``: coverage in percent and average set size per
        method, plus a reference accuracy (argmax of view-averaged probabilities),
        with one row per (Sim, K).
    """
    results_cov, results_size, results_acc = [], [], []

    loader_kwargs = dict(
        batch_size=cfg.batch_size,
        pin_memory=True,
        # os.cpu_count() returns None when the CPU count cannot be determined.
        num_workers=max(1, (os.cpu_count() or 2) // 2),
        persistent_workers=True
    )

    def make_loader(X, y, K, view, train_like):
        # Training loaders augment and shuffle. Every other loader is deterministic and
        # unshuffled, so its outputs line up row-for-row with its label array.
        return torch.utils.data.DataLoader(
            PatchesDataset(X, y, K, view, train_like=train_like),
            shuffle=train_like,
            **loader_kwargs,
        )

    for sim in trange(cfg.num_simulations, desc="Simulations"):
        print(f"\n=== Simulation {sim+1}/{cfg.num_simulations} ===")
        seed = cfg.train_seed_base + sim
        rng = np.random.RandomState(seed)

        # Splits: train | cal | fuse-train | fuse-cal (all stratified)
        X_trP, X_tmp, y_trP, y_tmp = train_test_split(
            X_train_full, Y_train_full, test_size=1 - cfg.train_frac, stratify=Y_train_full, random_state=seed
        )
        X_cal, X_rest, y_cal, y_rest = train_test_split(
            X_tmp, y_tmp, test_size=1 - cfg.cal_frac_of_temp, stratify=y_tmp, random_state=seed
        )
        X_fuse_tr, X_fuse_cal, y_fuse_tr, y_fuse_cal = train_test_split(
            X_rest, y_rest, test_size=1 - cfg.fuse_train_frac_of_rest, stratify=y_rest, random_state=seed
        )
        X_te, y_te = X_test_full, Y_test_full

        for K in tqdm(cfg.Ks, desc=f"Sim {sim+1}: K sweep", leave=False):
            print(f"\n  -> K = {K}")
            # split_image_into_k_patches yields exactly K patches (a 2x2 grid when K == 4).
            num_views = K

            # DataLoaders per view
            loaders = {}
            for v in trange(num_views, desc=f"K={K} build loaders", leave=False):
                loaders[v] = dict(
                    train=make_loader(X_trP, y_trP, K, v, True),
                    cal=make_loader(X_cal, y_cal, K, v, False),
                    ftr=make_loader(X_fuse_tr, y_fuse_tr, K, v, False),
                    fcal=make_loader(X_fuse_cal, y_fuse_cal, K, v, False),
                    te=make_loader(X_te, y_te, K, v, False),
                )

            # Train per-view models + per-class calibration scores.
            # NOTE(review): the calibration split is also the early-stopping split, so
            # the conformal scores come from data that influenced model selection; a
            # separate validation split would keep calibration strictly exchangeable.
            models, cal_classwise = [], []
            for v in trange(num_views, desc=f"K={K} train views", leave=False):
                m = PredictorCNN(num_classes=cfg.num_classes, width=64)
                m = train_model(
                    m, loaders[v]["train"],
                    num_epochs=cfg.epochs_per_view, lr=cfg.lr,
                    desc=f"View {v+1}/{num_views}",
                    eval_loader=loaders[v]["cal"],
                    eval_interval=cfg.eval_interval,
                    early_stop_patience=cfg.early_stop_patience,
                    early_stop_min_delta=cfg.early_stop_min_delta,
                )
                models.append(m)
                sc, lab = compute_nonconformity_scores(m, loaders[v]["cal"])
                cal_classwise.append(classwise_scores(sc, lab, cfg.num_classes))

            # Per-view p-values / probabilities for fusion-train, fusion-cal and test
            pv_tr, pr_tr = [], []
            pv_cal, pr_cal = [], []
            pv_te,  pr_te  = [], []
            for v in trange(num_views, desc=f"K={K} infer views", leave=False):
                p, pr = per_view_pvalues_and_probs(models[v], cal_classwise[v], loaders[v]["ftr"], cfg.num_classes)
                pv_tr.append(p); pr_tr.append(pr)
                p, pr = per_view_pvalues_and_probs(models[v], cal_classwise[v], loaders[v]["fcal"], cfg.num_classes)
                pv_cal.append(p); pr_cal.append(pr)
                p, pr = per_view_pvalues_and_probs(models[v], cal_classwise[v], loaders[v]["te"],  cfg.num_classes)
                pv_te.append(p);  pr_te.append(pr)

            # Fusion LR on [p, prob] (lbfgs is multinomial for > 2 classes; the
            # explicit multi_class argument is deprecated in scikit-learn >= 1.5).
            X_ftr = build_fusion_features(pv_tr, pr_tr)
            fusion_lr = LogisticRegression(max_iter=cfg.max_iter_lr, solver="lbfgs", random_state=seed)
            fusion_lr.fit(X_ftr, y_fuse_tr)

            # Conformal fusion p-values
            X_fcal = build_fusion_features(pv_cal, pr_cal)
            fused_probs_cal = fusion_lr.predict_proba(X_fcal)
            fused_cal_scores = fused_class_cal_scores(y_fuse_cal, fused_probs_cal, cfg.num_classes)
            X_ftest = build_fusion_features(pv_te, pr_te)
            fused_probs_test = fusion_lr.predict_proba(X_ftest)
            P_cf = fused_p_values_from_cal(fused_probs_test, fused_cal_scores)

            # Baselines with per-view p-values
            P_train = np.stack(pv_tr, axis=0)
            P_test  = np.stack(pv_te, axis=0)
            P_min    = min_p_value_fusion(P_test)
            P_fisher = fisher_fusion(P_test)
            P_adjF   = adjusted_fisher_fusion(P_train, y_fuse_tr, P_test, cfg.num_classes)

            pv_tr_concat = np.concatenate(pv_tr, axis=1)
            w_learned = learn_view_weights_from_pvals(pv_tr_concat, y_fuse_tr, num_views, cfg.num_classes, cfg.max_iter_lr, seed)
            P_wavgL = weighted_average_fusion(P_test, w_learned)

            # MVCP envelope on fuse-train + fuse-cal. The per-view probabilities for
            # these points were already computed above (same deterministic transform,
            # same order), so reuse them instead of building new loaders, copying the
            # images and re-running inference.
            y_dc = np.concatenate((y_fuse_tr, y_fuse_cal))
            probs_dc = [np.concatenate((pr_tr[v], pr_cal[v]), axis=0) for v in range(num_views)]
            N_dc = len(y_dc)
            S = np.zeros((N_dc, num_views), dtype=np.float64)
            for k in range(num_views):
                S[:, k] = 1.0 - probs_dc[k][np.arange(N_dc), y_dc]
            U, q_hats = mvcp_quantile_envelope(S, cfg.alpha, cfg.mvcp_M, cfg.mvcp_eps, cfg.mvcp_bs_iters, rng)

            # Evaluate MVCP on test (pr_te already holds the per-view test probabilities)
            cov_mvcp, size_mvcp = mvcp_predict_sets(pr_te, y_te, U, q_hats, cfg.num_classes, cfg.alpha)

            # Metrics
            cov_cf,   set_cf   = evaluate_sets(P_cf,     y_te, cfg.alpha)
            cov_min,  set_min  = evaluate_sets(P_min,    y_te, cfg.alpha)
            cov_fi,   set_fi   = evaluate_sets(P_fisher, y_te, cfg.alpha)
            cov_afi,  set_afi  = evaluate_sets(P_adjF,   y_te, cfg.alpha)
            cov_wl,   set_wl   = evaluate_sets(P_wavgL,  y_te, cfg.alpha)

            results_cov.append({"Sim": sim, "K": K, "MVCP": cov_mvcp*100, "Conformal Fusion": cov_cf*100,
                                "Min p-Value": cov_min*100, "Fisher": cov_fi*100,
                                "Adjusted Fisher": cov_afi*100, "Weighted Averaging": cov_wl*100})
            results_size.append({"Sim": sim, "K": K, "MVCP": size_mvcp, "Conformal Fusion": set_cf,
                                 "Min p-Value": set_min, "Fisher": set_fi,
                                 "Adjusted Fisher": set_afi, "Weighted Averaging": set_wl})

            # Reference accuracy using average of per-view probs
            avg_probs = np.mean(np.stack(pr_te, axis=0), axis=0)
            acc_ref = accuracy_score(y_te, np.argmax(avg_probs, axis=1)) * 100
            results_acc.append({"Sim": sim, "K": K, "Reference Acc (avg probs)": acc_ref})

    return pd.DataFrame(results_cov), pd.DataFrame(results_size), pd.DataFrame(results_acc)

# =============================================================================
# Save/print tables
# =============================================================================

def summarize_table(df: pd.DataFrame, methods: List[str], metric_name: str) -> pd.DataFrame:
    """Per-K summary with one ``"mean (std)"`` string column per method.

    Returns a frame with columns ``["K", "Metric", *methods]``, one row per K
    (sorted); ``Metric`` is the constant ``metric_name``.
    """
    stats = df.groupby("K").agg({m: ["mean", "std"] for m in methods})
    summary = pd.DataFrame({"K": stats.index.to_numpy()})
    summary.insert(1, "Metric", metric_name)
    for m in methods:
        means = stats[(m, "mean")].to_numpy()
        stds = stats[(m, "std")].to_numpy()
        summary[m] = [f"{mu:.2f} ({sd:.2f})" for mu, sd in zip(means, stds)]
    return summary

def save_tables(df_cov: pd.DataFrame, df_size: pd.DataFrame, df_acc: pd.DataFrame):
    """Write summary tables (CSV and LaTeX) to the working directory.

    Produces the coverage and set-size summaries (``"mean (std)"`` per method and
    K), a merged coverage/set-size table with abbreviated column names, and the
    mean reference accuracy per K.
    """
    methods = ["MVCP", "Conformal Fusion", "Min p-Value", "Fisher", "Adjusted Fisher", "Weighted Averaging"]
    abbrev = {
        "MVCP": "MVCP",
        "Conformal Fusion": "CF",
        "Min p-Value": "MinPV",
        "Fisher": "Fisher",
        "Adjusted Fisher": "AdjF",
        "Weighted Averaging": "WAvgL",
    }

    sum_cov = summarize_table(df_cov, methods, "Coverage (%)")
    sum_set = summarize_table(df_size, methods, "Average Set Size")
    per_metric = (
        (sum_cov, "fcifar100_summary_coverage"),
        (sum_set, "fcifar100_summary_setsize"),
    )
    # Both CSVs first, then both LaTeX files.
    for table, stem in per_metric:
        table.to_csv(f"{stem}.csv", index=False)
    for table, stem in per_metric:
        with open(f"{stem}.tex", "w") as fh:
            fh.write(table.to_latex(index=False, escape=False))

    # Side-by-side table: coverage columns followed by set-size columns, joined on K.
    cov_comp = sum_cov.drop(columns=["Metric"]).rename(columns={m: f"{abbrev[m]} Cov" for m in methods})
    set_comp = sum_set.drop(columns=["Metric"]).rename(columns={m: f"{abbrev[m]} Set" for m in methods})
    final = cov_comp.merge(set_comp, on="K").sort_values("K")
    final.to_csv("fcifar100_summary_final.csv", index=False)
    with open("fcifar100_summary_final.tex", "w") as fh:
        fh.write(final.to_latex(index=False, escape=False))

    acc_means = df_acc.groupby("K").mean(numeric_only=True).reset_index()
    acc_means.to_csv("fcifar100_accuracy_summary.csv", index=False)
    with open("fcifar100_accuracy_summary.tex", "w") as fh:
        fh.write(acc_means.to_latex(index=False, float_format="%.2f"))
    print("\nSaved: coverage/setsize/final/accuracy CSV + TEX files.")

# =============================================================================
# Entry
# =============================================================================

def _plot_coverage_boxplots(coverage_df):
    """Draw, save and show a box plot of coverage grouped by K, one box per method.

    ``coverage_df`` is melted on ``K``/``Sim`` with the six method columns as values,
    so it is expected to be the wide per-simulation coverage frame from
    ``run_experiments``. Saves ``fcifar100_coverage_boxplots.png`` (300 dpi) and
    ``fcifar100_coverage_boxplots.pdf``.
    """
    # Method -> box colour; the dict's insertion order doubles as the hue order.
    method_colors = {
        'MVCP': 'purple',
        'Conformal Fusion': 'blue',
        'Min p-Value': 'red',
        'Fisher': 'green',
        'Adjusted Fisher': 'cyan',
        'Weighted Averaging': 'orange',
    }
    hue_order = list(method_colors)

    # SciencePlots IEEE style without requiring a LaTeX install; embed TrueType
    # (Type 42) fonts in PDF/PS output.
    plt.style.use(['science', 'ieee', 'no-latex'])
    matplotlib.rcParams.update({'pdf.fonttype': 42, 'ps.fonttype': 42})
    sns.set_context("paper", font_scale=1.2)

    fig, ax = plt.subplots(figsize=(5.9, 4.4))
    long_df = coverage_df.melt(
        id_vars=['K', 'Sim'],
        value_vars=hue_order,
        var_name='Method',
        value_name='Coverage',
    )
    sns.boxplot(
        data=long_df, x='K', y='Coverage', hue='Method',
        hue_order=hue_order, palette=method_colors, ax=ax,
    )
    ax.set_title('Coverage vs. Number of Views on CIFAR-100')
    ax.set_xlabel('K')
    ax.set_ylabel('Coverage (%)')
    ax.legend(loc='lower left')
    sns.despine(right=True)
    fig.tight_layout()

    fig.savefig('fcifar100_coverage_boxplots.png', dpi=300)
    fig.savefig('fcifar100_coverage_boxplots.pdf')
    plt.show()


def main():
    """Run the CIFAR-100 experiments, save the summary tables, and plot coverage.

    Uses the default ``CIFARConfig``. Previews the first raw coverage and set-size
    rows, writes the tables via ``save_tables``, then draws the coverage box plot.
    """
    config = CIFARConfig()
    coverage_df, setsize_df, accuracy_df = run_experiments(config)

    previews = (
        ("\n=== Coverage (raw rows) ===", coverage_df),
        ("\n=== Set Size (raw rows) ===", setsize_df),
    )
    for header, frame in previews:
        print(header)
        print(frame.head())

    save_tables(coverage_df, setsize_df, accuracy_df)

    # Optional plot
    _plot_coverage_boxplots(coverage_df)

if __name__ == "__main__":
    main()


Using device: cuda
Files already downloaded and verified
Files already downloaded and verified


Simulations:   0%|          | 0/5 [00:00<?, ?it/s]


=== Simulation 1/5 ===


Sim 1: K sweep:   0%|          | 0/4 [00:00<?, ?it/s]


  -> K = 2


K=2 build loaders:   0%|          | 0/2 [00:00<?, ?it/s]

K=2 train views:   0%|          | 0/2 [00:00<?, ?it/s]

View 1/2:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 1/40:   0%|          | 0/118 [00:06<?, ?it/s]

Epoch 2/40:   0%|          | 0/118 [00:00<?, ?it/s]

Epoch 3/40:   0%|          | 0/118 [00:00<?, ?it/s]

Epoch 4/40:   0%|          | 0/118 [00:00<?, ?it/s]

Epoch 5/40:   0%|          | 0/118 [00:00<?, ?it/s]

[View 1/2] Epoch 5/40  Train CE: 2.8422  Acc: 0.2862
[View 1/2] >>> Eval @ epoch 5:  Cal CE: 3.0516  Acc: 0.2675


Epoch 6/40:   0%|          | 0/118 [00:00<?, ?it/s]

Epoch 7/40:   0%|          | 0/118 [00:00<?, ?it/s]

Epoch 8/40:   0%|          | 0/118 [00:00<?, ?it/s]

Epoch 9/40:   0%|          | 0/118 [00:00<?, ?it/s]

Epoch 10/40:   0%|          | 0/118 [00:00<?, ?it/s]

[View 1/2] Epoch 10/40  Train CE: 2.1226  Acc: 0.4361
[View 1/2] >>> Eval @ epoch 10:  Cal CE: 2.5827  Acc: 0.3565


Epoch 11/40:   0%|          | 0/118 [00:00<?, ?it/s]

Epoch 12/40:   0%|          | 0/118 [00:00<?, ?it/s]

Epoch 13/40:   0%|          | 0/118 [00:00<?, ?it/s]

Epoch 14/40:   0%|          | 0/118 [00:00<?, ?it/s]

Epoch 15/40:   0%|          | 0/118 [00:00<?, ?it/s]

[View 1/2] Epoch 15/40  Train CE: 1.6408  Acc: 0.5457
[View 1/2] >>> Eval @ epoch 15:  Cal CE: 2.2744  Acc: 0.4325


Epoch 16/40:   0%|          | 0/118 [00:00<?, ?it/s]

Epoch 17/40:   0%|          | 0/118 [00:00<?, ?it/s]

Epoch 18/40:   0%|          | 0/118 [00:00<?, ?it/s]

Epoch 19/40:   0%|          | 0/118 [00:00<?, ?it/s]

Epoch 20/40:   0%|          | 0/118 [00:00<?, ?it/s]

[View 1/2] Epoch 20/40  Train CE: 1.2343  Acc: 0.6479
[View 1/2] >>> Eval @ epoch 20:  Cal CE: 2.3458  Acc: 0.4345


Epoch 21/40:   0%|          | 0/118 [00:00<?, ?it/s]

Epoch 22/40:   0%|          | 0/118 [00:00<?, ?it/s]

Epoch 23/40:   0%|          | 0/118 [00:00<?, ?it/s]

Epoch 24/40:   0%|          | 0/118 [00:00<?, ?it/s]

[View 1/2] Early stopping at epoch 24. Best Cal CE: 2.1676 (restoring best weights).


Cal: nonconformity:   0%|          | 0/8 [00:00<?, ?it/s]

View 2/2:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 1/40:   0%|          | 0/118 [00:02<?, ?it/s]

Epoch 2/40:   0%|          | 0/118 [00:00<?, ?it/s]

Epoch 3/40:   0%|          | 0/118 [00:00<?, ?it/s]

Epoch 4/40:   0%|          | 0/118 [00:00<?, ?it/s]

Epoch 5/40:   0%|          | 0/118 [00:00<?, ?it/s]

[View 2/2] Epoch 5/40  Train CE: 2.7950  Acc: 0.2957
[View 2/2] >>> Eval @ epoch 5:  Cal CE: 2.8570  Acc: 0.2785


Epoch 6/40:   0%|          | 0/118 [00:00<?, ?it/s]

Epoch 7/40:   0%|          | 0/118 [00:00<?, ?it/s]

Epoch 8/40:   0%|          | 0/118 [00:00<?, ?it/s]

Epoch 9/40:   0%|          | 0/118 [00:00<?, ?it/s]

Epoch 10/40:   0%|          | 0/118 [00:00<?, ?it/s]

[View 2/2] Epoch 10/40  Train CE: 2.0893  Acc: 0.4410
[View 2/2] >>> Eval @ epoch 10:  Cal CE: 2.4421  Acc: 0.3680


Epoch 11/40:   0%|          | 0/118 [00:00<?, ?it/s]

Epoch 12/40:   0%|          | 0/118 [00:00<?, ?it/s]

Epoch 13/40:   0%|          | 0/118 [00:00<?, ?it/s]

Epoch 14/40:   0%|          | 0/118 [00:00<?, ?it/s]

Epoch 15/40:   0%|          | 0/118 [00:00<?, ?it/s]

[View 2/2] Epoch 15/40  Train CE: 1.6291  Acc: 0.5490
[View 2/2] >>> Eval @ epoch 15:  Cal CE: 2.2059  Acc: 0.4280


Epoch 16/40:   0%|          | 0/118 [00:00<?, ?it/s]

Epoch 17/40:   0%|          | 0/118 [00:00<?, ?it/s]

Epoch 18/40:   0%|          | 0/118 [00:00<?, ?it/s]

Epoch 19/40:   0%|          | 0/118 [00:00<?, ?it/s]

Epoch 20/40:   0%|          | 0/118 [00:00<?, ?it/s]

[View 2/2] Epoch 20/40  Train CE: 1.2405  Acc: 0.6441
[View 2/2] >>> Eval @ epoch 20:  Cal CE: 2.2331  Acc: 0.4495
[View 2/2] Early stopping at epoch 20. Best Cal CE: 2.2059 (restoring best weights).


Cal: nonconformity:   0%|          | 0/8 [00:00<?, ?it/s]

K=2 infer views:   0%|          | 0/2 [00:00<?, ?it/s]

Infer: probs:   0%|          | 0/36 [00:02<?, ?it/s]

Infer: probs:   0%|          | 0/36 [00:02<?, ?it/s]

Infer: probs:   0%|          | 0/40 [00:02<?, ?it/s]

Infer: probs:   0%|          | 0/36 [00:02<?, ?it/s]

Infer: probs:   0%|          | 0/36 [00:02<?, ?it/s]

Infer: probs:   0%|          | 0/40 [00:02<?, ?it/s]

Infer: probs:   0%|          | 0/71 [00:03<?, ?it/s]

Infer: probs:   0%|          | 0/71 [00:03<?, ?it/s]

Infer: probs:   0%|          | 0/40 [00:00<?, ?it/s]

Infer: probs:   0%|          | 0/40 [00:00<?, ?it/s]


  -> K = 3


K=3 build loaders:   0%|          | 0/3 [00:00<?, ?it/s]

K=3 train views:   0%|          | 0/3 [00:00<?, ?it/s]

View 1/3:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 1/40:   0%|          | 0/118 [00:03<?, ?it/s]

Epoch 2/40:   0%|          | 0/118 [00:00<?, ?it/s]

Epoch 3/40:   0%|          | 0/118 [00:00<?, ?it/s]

Epoch 4/40:   0%|          | 0/118 [00:00<?, ?it/s]

Epoch 5/40:   0%|          | 0/118 [00:00<?, ?it/s]

[View 1/3] Epoch 15/40  Train CE: 2.0887  Acc: 0.4445
[View 1/3] >>> Eval @ epoch 15:  Cal CE: 2.7068  Acc: 0.3445


Epoch 16/40:   0%|          | 0/118 [00:00<?, ?it/s]

Epoch 17/40:   0%|          | 0/118 [00:00<?, ?it/s]

Epoch 18/40:   0%|          | 0/118 [00:00<?, ?it/s]

Epoch 19/40:   0%|          | 0/118 [00:00<?, ?it/s]

Epoch 20/40:   0%|          | 0/118 [00:00<?, ?it/s]

[View 1/3] Epoch 20/40  Train CE: 1.6839  Acc: 0.5430
[View 1/3] >>> Eval @ epoch 20:  Cal CE: 2.7195  Acc: 0.3510


Epoch 21/40:   0%|          | 0/118 [00:00<?, ?it/s]

Epoch 22/40:   0%|          | 0/118 [00:00<?, ?it/s]

[View 1/3] Early stopping at epoch 22. Best Cal CE: 2.6380 (restoring best weights).


Cal: nonconformity:   0%|          | 0/8 [00:00<?, ?it/s]

View 2/3:   0%|          | 0/40 [00:00<?, ?it/s]

Epoch 1/40:   0%|          | 0/118 [00:03<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x152e363e9280>
Traceback (most recent call last):
  File "/home/siahkali/.local/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "/home/siahkali/.local/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
    if w.is_alive():
  File "/apps/cent7/jupyterhub/lib/python3.9/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x152e363e9280>
Traceback (most recent call last):
  File "/home/siahkali/.local/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "/home/siahkali/.local/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1587, in


if w.is_alive():    
  File "/home/siahkali/.local/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
    <function _MultiProcessingDataLoaderIter.__del__ at 0x152e363e9280>
AssertionErrorself._shutdown_workers()
self._shutdown_workers():     
  File "/home/siahkali/.local/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
  File "/apps/cent7/jupyterhub/lib/python3.9/multiprocessing/process.py", line 160, in is_alive
    if w.is_alive():Traceback (most recent call last):
can only test a child process  File "/apps/cent7/jupyterhub/lib/python3.9/multiprocessing/process.py", line 160, in is_alive

  File "/home/siahkali/.local/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1604, in __del__
if w.is_alive():      File "/apps/cent7/jupyterhub/lib/python3.9/multiprocessing/process.py", line 160, in is_alive

        
assert self._parent_pid == os.getpid(), 'can only test a child process'   

  File "/home/siahkali/.local/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    if w.is_alive():
  File "/apps/cent7/jupyterhub/lib/python3.9/multiprocessing/process.py", line 160, in is_alive
    

      File "/home/siahkali/.local/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
  File "/apps/cent7/jupyterhub/lib/python3.9/multiprocessing/process.py", line 160, in is_alive
self._shutdown_workers()self._shutdown_workers()    
    if w.is_alive():


    Exception ignored in:   File "/home/siahkali/.local/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
    Traceback (most recent call last):
assert self._parent_pid == os.getpid(), 'can only test a child process'  File "/home/siahkali/.local/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
  File "/apps/cent7/jupyterhub/lib/python3.9/multiprocessing/process.py", line 160,

      File "/home/siahkali/.local/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1604, in __del__
self._shutdown_workers()
self._shutdown_workers()
if w.is_alive():        if w.is_alive():      File "/home/siahkali/.local/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers

assert self._parent_pid == os.getpid(), 'can only test a child process'self._shutdown_workers()
  File "/apps/cent7/jupyterhub/lib/python3.9/multiprocessing/process.py", line 160, in is_alive
  File "/home/siahkali/.local/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers

        Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x152e363e9280>    
  File "/apps/cent7/jupyterhub/lib/python3.9/multiprocessing/process.py", line 160, in is_alive
AssertionErrorif w.is_alive():  File "/home/siahkali/.local/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_worker

  File "/home/siahkali/.local/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1604, in __del__
<function _MultiProcessingDataLoaderIter.__del__ at 0x152e363e9280>
can only test a child process
      File "/home/siahkali/.local/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1604, in __del__
AssertionError  File "/apps/cent7/jupyterhub/lib/python3.9/multiprocessing/process.py", line 160, in is_alive
        assert self._parent_pid == os.getpid(), 'can only test a child process'
Traceback (most recent call last):
:     
  File "/home/siahkali/.local/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1604, in __del__
can only test a child process    AssertionErrorself._shutdown_workers()
self._shutdown_workers()assert self._parent_pid == os.getpid(), 'can only test a child process'self._shutdown_workers()    

AssertionError  File "/home/siahkali/.local/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_w

Exception ignored in: self._shutdown_workers()  File "/home/siahkali/.local/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1604, in __del__
<function _MultiProcessingDataLoaderIter.__del__ at 0x152e363e9280>    

    
  File "/apps/cent7/jupyterhub/lib/python3.9/multiprocessing/process.py", line 160, in is_alive
self._shutdown_workers()  File "/home/siahkali/.local/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers

self._shutdown_workers()Exception ignored in:   File "/home/siahkali/.local/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
Traceback (most recent call last):
        <function _MultiProcessingDataLoaderIter.__del__ at 0x152e363e9280>    if w.is_alive():Exception ignored in: Exception ignored in:   File "/home/siahkali/.local/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1604, in __del__

<function _MultiProcessingDataLoaderIter.__del__ at 0x152e363e