In [None]:
#mlp model

In [1]:
from __future__ import annotations

import warnings
from dataclasses import dataclass
from typing import Dict, List, Tuple

import numpy as np
import pandas as pd
from scipy.stats import chi2
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.models import resnet18, efficientnet_b0, EfficientNet_B0_Weights
from torchvision.transforms import Resize

import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
import scienceplots

warnings.filterwarnings("ignore")

# =============================================================================
# Config
# =============================================================================

@dataclass
class CIFARConfig:
    # core
    alpha: float = 0.1
    Ks: Tuple[int, ...] = (2, 3, 4, 6)
    num_classes: int = 100
    num_simulations: int = 10

    # training
    epochs_per_view: int = 75      # adjust for runtime; your old code used 200
    lr: float = 1e-3
    batch_size: int = 8192
    max_iter_lr: int = 1000         # for sklearn LR (fusion & weight-learning)
    train_seed_base: int = 41

    # data split fractions (similar structure to synthetic)
    train_frac: float = 0.5         # predictor training from full train set
    cal_frac_of_temp: float = 0.3   # portion of temp used as calibration for per-view
    fuse_train_frac_of_rest: float = 0.7  # remaining split into fusion_train/cal

# =============================================================================
# Torch / Data
# =============================================================================

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

transform = transforms.Compose([transforms.ToTensor()])

# Load once
train_dataset = datasets.CIFAR100(root="./data", train=True, download=True, transform=transform)
test_dataset  = datasets.CIFAR100(root="./data", train=False, download=True, transform=transform)

X_train_full = train_dataset.data.astype(np.float32) / 255.0   # (50000, 32, 32, 3)
Y_train_full = np.array(train_dataset.targets)
X_test_full  = test_dataset.data.astype(np.float32) / 255.0    # (10000, 32, 32, 3)
Y_test_full  = np.array(test_dataset.targets)

# =============================================================================
# Multi-view (patch) utilities
# =============================================================================

def split_image_into_k_patches(image: torch.Tensor, k: int) -> List[torch.Tensor]:
    # image: (C, H, W) = (3,32,32)
    C, H, W = image.shape
    if k == 4:
        # 2x2 grid
        patches = []
        for i in range(2):
            for j in range(2):
                patch = image[:, i*16:(i+1)*16, j*16:(j+1)*16]
                patches.append(patch)
        return patches
    else:
        # vertical stripes
        base_width = W // k
        remainder = W % k
        patches, start = [], 0
        for idx in range(k):
            width = base_width + (1 if idx < remainder else 0)
            patch = image[:, :, start:start+width]
            patches.append(patch)
            start += width
        return patches

class PatchesDataset(torch.utils.data.Dataset):
    def __init__(self, images: np.ndarray, labels: np.ndarray, k: int, view: int):
        self.images = images
        self.labels = labels
        self.k = k
        self.view = view
        self.resize = Resize((32, 32))

    def __len__(self) -> int:
        return len(self.labels)

    def __getitem__(self, idx: int):
        img = self.images[idx].transpose((2, 0, 1))   # (3,32,32)
        img = torch.tensor(img, dtype=torch.float32)
        patches = split_image_into_k_patches(img, self.k)
        patch = patches[self.view]
        patch = self.resize(patch)
        label = int(self.labels[idx])
        return patch, label

# =============================================================================
# ResNet18 per view
# =============================================================================

# class PredictorCNN(nn.Module):
#     def __init__(self, num_classes=100):
#         super().__init__()
#         self.model = resnet18(pretrained=False, num_classes=num_classes)

#     def forward(self, x):
#         return self.model(x)
    
class PredictorCNN(nn.Module):
    def __init__(self, num_classes=100):
        super().__init__()
        # Load pretrained EfficientNet-B0 on ImageNet
        self.model = efficientnet_b0(weights=EfficientNet_B0_Weights.IMAGENET1K_V1)
        
        # Replace final classification layer
        in_features = self.model.classifier[1].in_features
        self.model.classifier[1] = nn.Linear(in_features, num_classes)

    def forward(self, x):
        return self.model(x)


def train_model(model: nn.Module, train_loader, num_epochs=100, lr=1e-3):
    crit = nn.CrossEntropyLoss()
    opt  = optim.Adam(model.parameters(), lr=lr)
    model.to(device)
    for ep in range(num_epochs):
        model.train()
        for xb, yb in train_loader:
            xb, yb = xb.to(device), torch.tensor(yb, dtype=torch.long, device=device)
            opt.zero_grad()
            loss = crit(model(xb), yb)
            loss.backward()
            opt.step()
        if (ep+1) % 25 == 0:
            print(f"  epoch {ep+1}/{num_epochs}")
    return model

# =============================================================================
# Conformal utilities (torch models)
# =============================================================================

def compute_nonconformity_scores(model: nn.Module, loader) -> Tuple[np.ndarray, np.ndarray]:
    model.eval()
    scores, labels = [], []
    with torch.no_grad():
        for xb, yb in loader:
            xb = xb.to(device)
            logits = model(xb)
            probs  = F.softmax(logits, dim=1)
            idx    = torch.arange(probs.size(0), device=probs.device)
            true_p = probs[idx, torch.tensor(yb, dtype=torch.long, device=probs.device)]
            s = (1 - true_p).detach().cpu().numpy()
            scores.extend(s)
            labels.extend(yb.numpy())
    return np.asarray(scores, float), np.asarray(labels, int)

def classwise_scores(scores: np.ndarray, labels: np.ndarray, L: int) -> Dict[int, np.ndarray]:
    out = {c: [] for c in range(L)}
    for s, y in zip(scores, labels):
        out[int(y)].append(float(s))
    return {c: np.asarray(v, float) for c, v in out.items()}

def per_view_pvalues_and_probs(
    model: nn.Module, class_scores: Dict[int, np.ndarray], loader, L: int
) -> Tuple[np.ndarray, np.ndarray]:
    """Return p-values (n,L) and probs (n,L) for a single view."""
    model.eval()
    probs_all = []
    with torch.no_grad():
        for xb, _ in loader:
            xb = xb.to(device)
            logits = model(xb)
            probs  = F.softmax(logits, dim=1).detach().cpu().numpy()
            probs_all.append(probs)
    probs_all = np.vstack(probs_all)  # (n, L)

    n = probs_all.shape[0]
    pvals = np.zeros((n, L))
    for y in range(L):
        cal = class_scores.get(y, np.array([]))
        if cal.size == 0:
            pvals[:, y] = 1.0
        else:
            s_test = 1 - probs_all[:, y]
            counts = np.sum(cal[:, None] >= s_test[None, :], axis=0)
            pvals[:, y] = (1 + counts) / (len(cal) + 1)
    return pvals, probs_all

# =============================================================================
# Fusion utilities (same as synthetic)
# =============================================================================

def build_fusion_features(pvals_list: List[np.ndarray], probs_list: List[np.ndarray]) -> np.ndarray:
    """Horizontally stack [pvals, probs] for each view -> (n, K*2L)"""
    blocks = [np.hstack([pvals_list[k], probs_list[k]]) for k in range(len(pvals_list))]
    return np.hstack(blocks)

def min_p_value_fusion(P_all: np.ndarray) -> np.ndarray:
    """K * min_k p_k^y. P_all: (K,n,L) -> (n,L)"""
    K = P_all.shape[0]
    return K * np.min(P_all, axis=0)

def fisher_fusion(P_all: np.ndarray) -> np.ndarray:
    """Standard Fisher."""
    eps = 1e-12
    p = np.clip(P_all, eps, 1.0)
    T = -2 * np.sum(np.log(p), axis=0)
    df = 2 * P_all.shape[0]
    return 1 - chi2.cdf(T, df=df)

def adjusted_fisher_fusion(P_train: np.ndarray, y_train: np.ndarray, P_test: np.ndarray, L: int) -> np.ndarray:
    """
    Moment-matched Fisher per class: fit variance of T_y = sum_k -2log p_k^y on per-class train,
    then use scaled-chi-square CDF.
    """
    K, _, _ = P_train.shape
    n_test = P_test.shape[1]
    eps = 1e-12
    out = np.zeros((n_test, L))
    for y in range(L):
        idx = np.where(y_train == y)[0]
        if idx.size < 5:
            out[:, y] = fisher_fusion(P_test)[:, y]
            continue
        P_cls = np.clip(P_train[:, idx, y], eps, 1.0)  # (K, n_y)
        W = -2 * np.log(P_cls)                          # (K, n_y)
        Wc = W - W.mean(axis=1, keepdims=True)
        Sigma = (Wc @ Wc.T) / max(W.shape[1] - 1, 1)    # (K, K)
        var_T = np.sum(Sigma)
        if not np.isfinite(var_T) or var_T <= 0:
            var_T = 4 * K
        f_y = (8.0 * K * K) / var_T
        c_y = var_T / (4 * K)

        P_t = np.clip(P_test[:, :, y], eps, 1.0)
        T_t = -2 * np.sum(np.log(P_t), axis=0)
        out[:, y] = 1 - chi2.cdf(T_t / c_y, df=f_y)
    return out

def weighted_average_fusion(P_all: np.ndarray, weights: np.ndarray) -> np.ndarray:
    """sum_k w_k p_k^y; P_all: (K,n,L), weights: (K,)"""
    return np.tensordot(weights, P_all, axes=(0, 0))

def learn_view_weights_from_pvals(pv_train_concat: np.ndarray, y_train: np.ndarray, K: int, L: int, max_iter: int, seed: int) -> np.ndarray:
    """
    Train multinomial LR on p-only features to predict y.
    Convert coef_ (L, K*L) -> view weights by Frobenius norm per (L×L) block.
    """
    lr = LogisticRegression(multi_class="multinomial", solver="lbfgs", max_iter=max_iter, random_state=seed)
    lr.fit(pv_train_concat, y_train)
    B = lr.coef_  # (L, K*L)
    imps = []
    for k in range(K):
        block = B[:, k*L:(k+1)*L]
        imps.append(np.linalg.norm(block, ord="fro"))
    w = np.array(imps, float)
    w = np.maximum(w, 1e-12)
    return w / w.sum()

# Conformalization of fused model
def fused_class_cal_scores(y_cal: np.ndarray, fused_probs_cal: np.ndarray, L: int) -> Dict[int, np.ndarray]:
    s = 1 - fused_probs_cal[np.arange(len(y_cal)), y_cal]
    out = {c: [] for c in range(L)}
    for sc, yy in zip(s, y_cal):
        out[int(yy)].append(float(sc))
    return {c: np.asarray(v, float) for c, v in out.items()}

def fused_p_values_from_cal(fused_probs: np.ndarray, cal_class_scores: Dict[int, np.ndarray]) -> np.ndarray:
    n, L = fused_probs.shape
    out = np.zeros((n, L))
    for y in range(L):
        cal = cal_class_scores.get(y, np.array([]))
        if cal.size == 0:
            out[:, y] = 1.0
        else:
            s_test = 1 - fused_probs[:, y]
            counts = np.sum(cal[:, None] >= s_test[None, :], axis=0)
            out[:, y] = (1 + counts) / (len(cal) + 1)
    return out

def evaluate_sets(P: np.ndarray, y_true: np.ndarray, alpha: float) -> Tuple[float, float]:
    C = (P > alpha)
    cov = float(np.mean(C[np.arange(len(y_true)), y_true]))
    size = float(np.mean(C.sum(axis=1)))
    return cov, size

def summarize_table(df: pd.DataFrame, methods: List[str], metric_name: str) -> pd.DataFrame:
    g = df.groupby("K").agg({m: ["mean", "std"] for m in methods})
    g.columns = [f"{a}_{b}" for a, b in g.columns]
    g = g.reset_index()
    for m in methods:
        g[m] = g.apply(lambda r: f"{r[f'{m}_mean']:.2f} ({r[f'{m}_std']:.2f})", axis=1)
    g.insert(1, "Metric", metric_name)
    return g[["K", "Metric"] + methods]

# =============================================================================
# Main experiment
# =============================================================================

def run_experiments(cfg: CIFARConfig):
    results_cov, results_size, results_acc = [], [], []

    for sim in range(cfg.num_simulations):
        print(f"\n=== Simulation {sim+1}/{cfg.num_simulations} ===")
        seed = cfg.train_seed_base + sim
        rng = np.random.RandomState(seed)

        # Splits (train for view predictors, then cal/fusion splits)
        X_trP, X_tmp, y_trP, y_tmp = train_test_split(
            X_train_full, Y_train_full, test_size=1 - cfg.train_frac, stratify=Y_train_full, random_state=seed
        )
        # We will split X_tmp into cal and fusion pools
        X_cal, X_rest, y_cal, y_rest = train_test_split(
            X_tmp, y_tmp, test_size=1 - cfg.cal_frac_of_temp, stratify=y_tmp, random_state=seed
        )
        X_fuse_tr, X_fuse_cal, y_fuse_tr, y_fuse_cal = train_test_split(
            X_rest, y_rest, test_size=1 - cfg.fuse_train_frac_of_rest, stratify=y_rest, random_state=seed
        )
        X_te, y_te = X_test_full, Y_test_full

        for K in cfg.Ks:
            print(f"\n  -> K = {K}")
            num_views = 4 if K == 4 else K

            # Build loaders per view
            loaders = {}
            for v in range(num_views):
                tr_loader   = torch.utils.data.DataLoader(PatchesDataset(X_trP,      y_trP,      K, v), batch_size=cfg.batch_size, shuffle=True)
                cal_loader  = torch.utils.data.DataLoader(PatchesDataset(X_cal,      y_cal,      K, v), batch_size=cfg.batch_size, shuffle=False)
                ftr_loader  = torch.utils.data.DataLoader(PatchesDataset(X_fuse_tr,  y_fuse_tr,  K, v), batch_size=cfg.batch_size, shuffle=False)
                fcal_loader = torch.utils.data.DataLoader(PatchesDataset(X_fuse_cal, y_fuse_cal, K, v), batch_size=cfg.batch_size, shuffle=False)
                te_loader   = torch.utils.data.DataLoader(PatchesDataset(X_te,       y_te,       K, v), batch_size=cfg.batch_size, shuffle=False)
                loaders[v] = dict(train=tr_loader, cal=cal_loader, ftr=ftr_loader, fcal=fcal_loader, te=te_loader)

            # Train per-view CNNs
            models, cal_classwise = [], []
            for v in range(num_views):
                print(f"    [View {v+1}/{num_views}] training...")
                m = PredictorCNN(num_classes=cfg.num_classes)
                m = train_model(m, loaders[v]["train"], num_epochs=cfg.epochs_per_view, lr=cfg.lr)
                models.append(m)
                sc, lab = compute_nonconformity_scores(m, loaders[v]["cal"])
                cal_classwise.append(classwise_scores(sc, lab, cfg.num_classes))

            # Per-view p/probs for fusion train/cal/test
            pv_tr, pr_tr = [], []
            pv_cal, pr_cal = [], []
            pv_te,  pr_te  = [], []
            for v in range(num_views):
                p, pr = per_view_pvalues_and_probs(models[v], cal_classwise[v], loaders[v]["ftr"], cfg.num_classes)
                pv_tr.append(p); pr_tr.append(pr)
                p, pr = per_view_pvalues_and_probs(models[v], cal_classwise[v], loaders[v]["fcal"], cfg.num_classes)
                pv_cal.append(p); pr_cal.append(pr)
                p, pr = per_view_pvalues_and_probs(models[v], cal_classwise[v], loaders[v]["te"],  cfg.num_classes)
                pv_te.append(p);  pr_te.append(pr)

            # Build fusion LR on [p, prob] from fusion-train
            X_ftr = build_fusion_features(pv_tr, pr_tr)
            fusion_lr = LogisticRegression(max_iter=cfg.max_iter_lr, multi_class="multinomial", solver="lbfgs", random_state=seed)
            fusion_lr.fit(X_ftr, y_fuse_tr)

            # Fused calibration probs + classwise cal scores
            X_fcal = build_fusion_features(pv_cal, pr_cal)
            fused_probs_cal = fusion_lr.predict_proba(X_fcal)
            fused_cal_scores = fused_class_cal_scores(y_fuse_cal, fused_probs_cal, cfg.num_classes)

            # Fused test probs
            X_ftest = build_fusion_features(pv_te, pr_te)
            fused_probs_test = fusion_lr.predict_proba(X_ftest)

            # Our conformal-fused p-values
            P_cf = fused_p_values_from_cal(fused_probs_test, fused_cal_scores)

            # Baselines (stack per-view p-values)
            P_train = np.stack(pv_tr, axis=0)   # (K, n_tr, L)
            P_test  = np.stack(pv_te, axis=0)   # (K, n_te, L)

            P_min    = min_p_value_fusion(P_test)
            P_fisher = fisher_fusion(P_test)
            P_adjF   = adjusted_fisher_fusion(P_train, y_fuse_tr, P_test, cfg.num_classes)

            # Learned weighted average from p-values-only (K*L features)
            pv_tr_concat = np.concatenate(pv_tr, axis=1)  # (n_tr, K*L)
            w_learned = learn_view_weights_from_pvals(pv_tr_concat, y_fuse_tr, num_views, cfg.num_classes, cfg.max_iter_lr, seed)
            P_wavgL = weighted_average_fusion(P_test, w_learned)

            # Metrics
            cov_cf,   set_cf   = evaluate_sets(P_cf,     y_te, cfg.alpha)
            cov_min,  set_min  = evaluate_sets(P_min,    y_te, cfg.alpha)
            cov_fi,   set_fi   = evaluate_sets(P_fisher, y_te, cfg.alpha)
            cov_afi,  set_afi  = evaluate_sets(P_adjF,   y_te, cfg.alpha)
            cov_wl,   set_wl   = evaluate_sets(P_wavgL,  y_te, cfg.alpha)

            results_cov.append({
                "Sim": sim, "K": K,
                "Conformal Fusion": cov_cf * 100,
                "Min p-Value": cov_min * 100,
                "Fisher": cov_fi * 100,
                "Adjusted Fisher": cov_afi * 100,
                "Weighted Averaging": cov_wl * 100,
            })
            results_size.append({
                "Sim": sim, "K": K,
                "Conformal Fusion": set_cf,
                "Min p-Value": set_min,
                "Fisher": set_fi,
                "Adjusted Fisher": set_afi,
                "Weighted Averaging": set_wl,
            })

            # (Optional) Reference accuracy using simple average of per-view probs
            avg_probs = np.mean(np.stack(pr_te, axis=0), axis=0)
            acc_ref = accuracy_score(y_te, np.argmax(avg_probs, axis=1)) * 100
            results_acc.append({"Sim": sim, "K": K, "Reference Acc (avg probs)": acc_ref})

    return pd.DataFrame(results_cov), pd.DataFrame(results_size), pd.DataFrame(results_acc)

# =============================================================================
# Save/print tables (same style as synthetic)
# =============================================================================

def save_tables(df_cov: pd.DataFrame, df_size: pd.DataFrame, df_acc: pd.DataFrame):
    methods = [
        "Conformal Fusion",
        "Min p-Value",
        "Fisher",
        "Adjusted Fisher",
        "Weighted Averaging",
    ]
    sum_cov = summarize_table(df_cov, methods, "Coverage (%)")
    sum_set = summarize_table(df_size, methods, "Average Set Size")

    # CSV + LaTeX (CIFAR version)
    sum_cov.to_csv("cifar100_summary_coverage.csv", index=False)
    sum_set.to_csv("cifar100_summary_setsize.csv", index=False)
    with open("cifar100_summary_coverage.tex", "w") as f:
        f.write(sum_cov.to_latex(index=False, escape=False))
    with open("cifar100_summary_setsize.tex", "w") as f:
        f.write(sum_set.to_latex(index=False, escape=False))

    # Compact side-by-side
    cov_comp = sum_cov.drop(columns=["Metric"]).rename(columns={
        "Conformal Fusion": "CF Cov",
        "Min p-Value": "MinPV Cov",
        "Fisher": "Fisher Cov",
        "Adjusted Fisher": "AdjF Cov",
        "Weighted Averaging": "WAvgL Cov",
    })
    set_comp = sum_set.drop(columns=["Metric"]).rename(columns={
        "Conformal Fusion": "CF Set",
        "Min p-Value": "MinPV Set",
        "Fisher": "Fisher Set",
        "Adjusted Fisher": "AdjF Set",
        "Weighted Averaging": "WAvgL Set",
    })
    final = cov_comp.merge(set_comp, on="K").sort_values("K")
    final.to_csv("cifar100_summary_final.csv", index=False)
    with open("cifar100_summary_final.tex", "w") as f:
        f.write(final.to_latex(index=False, escape=False))

    # Acc means (if you want)
    acc_means = df_acc.groupby("K").mean(numeric_only=True).reset_index()
    acc_means.to_csv("cifar100_accuracy_summary.csv", index=False)
    with open("cifar100_accuracy_summary.tex", "w") as f:
        f.write(acc_means.to_latex(index=False, float_format="%.2f"))

    print("\nSaved:")
    print("  cifar100_summary_coverage.csv / .tex")
    print("  cifar100_summary_setsize.csv  / .tex")
    print("  cifar100_summary_final.csv    / .tex")
    print("  cifar100_accuracy_summary.csv / .tex")

# =============================================================================
# Entry
# =============================================================================

def main():
    cfg = CIFARConfig()
    df_cov, df_size, df_acc = run_experiments(cfg)
    print("\n=== Coverage (raw rows) ===")
    print(df_cov.head())
    print("\n=== Set Size (raw rows) ===")
    print(df_size.head())
    save_tables(df_cov, df_size, df_acc)

    # Plotting
    plt.style.use(['science','ieee', 'no-latex'])

    # Avoid Type 3 fonts
    matplotlib.rcParams['pdf.fonttype'] = 42
    matplotlib.rcParams['ps.fonttype'] = 42

    sns.set_context("paper", font_scale=1.2)

    # Consistent method order/colors
    method_order = ['Conformal Fusion', 'Min p-Value', "Fisher", 'Adjusted Fisher', 'Weighted Averaging']
    palette = {'Conformal Fusion': 'blue', 'Min p-Value': 'red', "Fisher": 'green', 'Adjusted Fisher': 'cyan', 'Weighted Averaging': 'orange'}

    fig, ax = plt.subplots(figsize=(5.9, 4.4))

    # Melt df_cov for plotting
    df_cov_melt = pd.melt(df_cov, id_vars=['K', 'Sim'], value_vars=method_order, var_name='Method', value_name='Coverage')

    # Boxplot: Coverage vs K
    sns.boxplot(x='K', y='Coverage', hue='Method',
                data=df_cov_melt, hue_order=method_order, palette=palette, ax=ax)
    ax.set_title('Coverage vs. Number of Views on CIFAR-100', fontsize=15)
    ax.set_xlabel('Number of Views (K)', fontsize=13)
    ax.set_ylabel('Coverage (%)', fontsize=13)
    ax.legend(loc='lower left')

    sns.despine(right=True)

    plt.tight_layout()
    plt.savefig('cifar100_coverage_boxplots.png', dpi=600)
    plt.savefig('cifar100_coverage_boxplots.pdf')
    plt.show()

if __name__ == "__main__":
    main()

Using device: cuda
Files already downloaded and verified
Files already downloaded and verified

=== Simulation 1/10 ===

  -> K = 2
    [View 1/2] training...


OutOfMemoryError: CUDA out of memory. Tried to allocate 60.00 MiB. GPU 0 has a total capacity of 23.60 GiB of which 57.56 MiB is free. Process 290455 has 5.71 GiB memory in use. Process 291637 has 4.68 GiB memory in use. Process 295770 has 2.60 GiB memory in use. Including non-PyTorch memory, this process has 10.53 GiB memory in use. Of the allocated memory 10.22 GiB is allocated by PyTorch, and 45.33 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [None]:
#LR model

In [None]:
from __future__ import annotations

import warnings
from dataclasses import dataclass
from typing import Dict, List, Tuple

import numpy as np
import pandas as pd
from scipy.stats import chi2
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
from torchvision import datasets, transforms
from torchvision.models import resnet18, efficientnet_b0, EfficientNet_B0_Weights
from torchvision.transforms import Resize

import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
import scienceplots

warnings.filterwarnings("ignore")

# =============================================================================
# Config
# =============================================================================

@dataclass
class CIFARConfig:
    # core
    alpha: float = 0.1
    Ks: Tuple[int, ...] = (2, 3, 4, 6)
    num_classes: int = 100
    num_simulations: int = 2

    # training
    epochs_per_view: int = 100      # adjust for runtime; your old code used 200
    lr: float = 1e-3
    batch_size: int = 8192
    max_iter_lr: int = 300         # for sklearn LR (baselines/weight-learning)
    train_seed_base: int = 41

    # data split fractions (similar structure to synthetic)
    train_frac: float = 0.5         # predictor training from full train set
    cal_frac_of_temp: float = 0.3   # portion of temp used as calibration for per-view
    fuse_train_frac_of_rest: float = 0.7  # remaining split into fusion_train/cal

# =============================================================================
# Torch / Data
# =============================================================================

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

transform = transforms.Compose([transforms.ToTensor()])

# Load once
train_dataset = datasets.CIFAR100(root="./data", train=True, download=True, transform=transform)
test_dataset  = datasets.CIFAR100(root="./data", train=False, download=True, transform=transform)

X_train_full = train_dataset.data.astype(np.float32) / 255.0   # (50000, 32, 32, 3)
Y_train_full = np.array(train_dataset.targets)
X_test_full  = test_dataset.data.astype(np.float32) / 255.0    # (10000, 32, 32, 3)
Y_test_full  = np.array(test_dataset.targets)

# =============================================================================
# Multi-view (patch) utilities
# =============================================================================

def split_image_into_k_patches(image: torch.Tensor, k: int) -> List[torch.Tensor]:
    # image: (C, H, W) = (3,32,32)
    C, H, W = image.shape
    if k == 4:
        # 2x2 grid
        patches = []
        for i in range(2):
            for j in range(2):
                patch = image[:, i*16:(i+1)*16, j*16:(j+1)*16]
                patches.append(patch)
        return patches
    else:
        # vertical stripes
        base_width = W // k
        remainder = W % k
        patches, start = [], 0
        for idx in range(k):
            width = base_width + (1 if idx < remainder else 0)
            patch = image[:, :, start:start+width]
            patches.append(patch)
            start += width
        return patches

class PatchesDataset(torch.utils.data.Dataset):
    def __init__(self, images: np.ndarray, labels: np.ndarray, k: int, view: int):
        self.images = images
        self.labels = labels
        self.k = k
        self.view = view
        self.resize = Resize((32, 32))

    def __len__(self) -> int:
        return len(self.labels)

    def __getitem__(self, idx: int):
        img = self.images[idx].transpose((2, 0, 1))   # (3,32,32)
        img = torch.tensor(img, dtype=torch.float32)
        patches = split_image_into_k_patches(img, self.k)
        patch = patches[self.view]
        patch = self.resize(patch)
        label = int(self.labels[idx])
        return patch, label

# =============================================================================
# Per-view CNN (EfficientNet-B0 head)
# =============================================================================

# class PredictorCNN(nn.Module):
#     def __init__(self, num_classes=100):
#         super().__init__()
#         self.model = resnet18(pretrained=False, num_classes=num_classes)
#     def forward(self, x):
#         return self.model(x)

class PredictorCNN(nn.Module):
    def __init__(self, num_classes=100):
        super().__init__()
        # Load pretrained EfficientNet-B0 on ImageNet
        self.model = efficientnet_b0(weights=EfficientNet_B0_Weights.IMAGENET1K_V1)
        # Replace final classification layer
        in_features = self.model.classifier[1].in_features
        self.model.classifier[1] = nn.Linear(in_features, num_classes)

    def forward(self, x):
        return self.model(x)


def train_model(model: nn.Module, train_loader, num_epochs=100, lr=1e-3):
    crit = nn.CrossEntropyLoss()
    opt  = optim.Adam(model.parameters(), lr=lr)
    model.to(device)
    for ep in range(num_epochs):
        model.train()
        for xb, yb in train_loader:
            xb, yb = xb.to(device), torch.tensor(yb, dtype=torch.long, device=device)
            opt.zero_grad()
            loss = crit(model(xb), yb)
            loss.backward()
            opt.step()
        if (ep+1) % 25 == 0:
            print(f"  epoch {ep+1}/{num_epochs}")
    return model

# =============================================================================
# Conformal utilities (torch models)
# =============================================================================

def compute_nonconformity_scores(model: nn.Module, loader) -> Tuple[np.ndarray, np.ndarray]:
    model.eval()
    scores, labels = [], []
    with torch.no_grad():
        for xb, yb in loader:
            xb = xb.to(device)
            logits = model(xb)
            probs  = F.softmax(logits, dim=1)
            idx    = torch.arange(probs.size(0), device=probs.device)
            true_p = probs[idx, torch.tensor(yb, dtype=torch.long, device=probs.device)]
            s = (1 - true_p).detach().cpu().numpy()
            scores.extend(s)
            labels.extend(yb.numpy())
    return np.asarray(scores, float), np.asarray(labels, int)

def classwise_scores(scores: np.ndarray, labels: np.ndarray, L: int) -> Dict[int, np.ndarray]:
    out = {c: [] for c in range(L)}
    for s, y in zip(scores, labels):
        out[int(y)].append(float(s))
    return {c: np.asarray(v, float) for c, v in out.items()}

def per_view_pvalues_and_probs(
    model: nn.Module, class_scores: Dict[int, np.ndarray], loader, L: int
) -> Tuple[np.ndarray, np.ndarray]:
    """Return p-values (n,L) and probs (n,L) for a single view."""
    model.eval()
    probs_all = []
    with torch.no_grad():
        for xb, _ in loader:
            xb = xb.to(device)
            logits = model(xb)
            probs  = F.softmax(logits, dim=1).detach().cpu().numpy()
            probs_all.append(probs)
    probs_all = np.vstack(probs_all)  # (n, L)

    n = probs_all.shape[0]
    pvals = np.zeros((n, L))
    for y in range(L):
        cal = class_scores.get(y, np.array([]))
        if cal.size == 0:
            pvals[:, y] = 1.0
        else:
            s_test = 1 - probs_all[:, y]
            counts = np.sum(cal[:, None] >= s_test[None, :], axis=0)
            pvals[:, y] = (1 + counts) / (len(cal) + 1)
    return pvals, probs_all

# =============================================================================
# Fusion utilities (baseline + new richer fusion)
# =============================================================================

def build_fusion_features(pvals_list: List[np.ndarray], probs_list: List[np.ndarray]) -> np.ndarray:
    """Horizontally stack [pvals, probs] for each view -> (n, K*2L)"""
    blocks = [np.hstack([pvals_list[k], probs_list[k]]) for k in range(len(pvals_list))]
    return np.hstack(blocks)

def build_fusion_features_extended(pvals_list: List[np.ndarray], probs_list: List[np.ndarray]) -> np.ndarray:
    """
    For each view, stack [pvals, probs, log(probs+eps)] horizontally.
    Output shape: (n, K * 3L)
    """
    eps = 1e-12
    blocks = []
    for k in range(len(pvals_list)):
        pvals = pvals_list[k]
        probs = probs_list[k]
        logp  = np.log(np.clip(probs, eps, 1.0))
        blocks.append(np.hstack([pvals, probs, logp]))
    return np.hstack(blocks)

class FusionMLP(nn.Module):
    """
    A 3-layer MLP with GELU, dropout, batchnorm, and a light residual connection.
    Input: concatenated per-view features (p, prob, logprob). Output: logits over L classes.
    """
    def __init__(self, input_dim: int, num_classes: int, hidden1: int = 2048, hidden2: int = 1024, p_drop: float = 0.2):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden1)
        self.bn1 = nn.BatchNorm1d(hidden1)
        self.fc2 = nn.Linear(hidden1, hidden2)
        self.bn2 = nn.BatchNorm1d(hidden2)
        self.fc3 = nn.Linear(hidden2, num_classes)
        self.drop = nn.Dropout(p_drop)
        self.act  = nn.GELU()

    def forward(self, x):
        x = self.fc1(x)
        x = self.bn1(x)
        x = self.act(x)
        x = self.drop(x)

        h = self.fc2(x)
        h = self.bn2(h)
        h = self.act(h)
        h = self.drop(h)

        # Residual within hidden space
        x_proj = x
        if x_proj.shape[1] != h.shape[1]:
            x_proj = h
        h = h + x_proj

        out = self.fc3(h)
        return out

def train_fusion_mlp(X: np.ndarray, y: np.ndarray, L: int, seed: int,
                     lr: float = 3e-4, batch_size: int = 4096, epochs: int = 30,
                     weight_decay: float = 1e-4, label_smoothing: float = 0.05) -> FusionMLP:
    torch.manual_seed(seed)
    np.random.seed(seed)

    X_t = torch.tensor(X, dtype=torch.float32)
    y_t = torch.tensor(y, dtype=torch.long)

    ds = TensorDataset(X_t, y_t)
    dl = DataLoader(ds, batch_size=batch_size, shuffle=True, drop_last=False)

    model = FusionMLP(input_dim=X.shape[1], num_classes=L).to(device)
    opt = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    sched = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)
    crit = nn.CrossEntropyLoss(label_smoothing=label_smoothing)

    model.train()
    for ep in range(epochs):
        total = 0.0
        for xb, yb in dl:
            xb = xb.to(device)
            yb = yb.to(device)
            opt.zero_grad()
            logits = model(xb)
            loss = crit(logits, yb)
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), max_norm=2.0)
            opt.step()
            total += float(loss.item()) * xb.size(0)
        sched.step()
        if (ep + 1) % 5 == 0:
            print(f"    [Fusion MLP] epoch {ep+1}/{epochs}  loss={total/len(ds):.4f}")
    return model

@torch.no_grad()
def fusion_mlp_predict_proba(model: FusionMLP, X: np.ndarray) -> np.ndarray:
    model.eval()
    X_t = torch.tensor(X, dtype=torch.float32, device=device)
    logits = model(X_t)
    probs  = F.softmax(logits, dim=1).detach().cpu().numpy()
    return probs

# ---- Baseline fusion operations (unchanged) ----

def min_p_value_fusion(P_all: np.ndarray) -> np.ndarray:
    """K * min_k p_k^y. P_all: (K,n,L) -> (n,L)"""
    K = P_all.shape[0]
    return K * np.min(P_all, axis=0)

def fisher_fusion(P_all: np.ndarray) -> np.ndarray:
    """Standard Fisher."""
    eps = 1e-12
    p = np.clip(P_all, eps, 1.0)
    T = -2 * np.sum(np.log(p), axis=0)
    df = 2 * P_all.shape[0]
    return 1 - chi2.cdf(T, df=df)

def adjusted_fisher_fusion(P_train: np.ndarray, y_train: np.ndarray, P_test: np.ndarray, L: int) -> np.ndarray:
    """
    Moment-matched Fisher per class: fit variance of T_y = sum_k -2log p_k^y on per-class train,
    then use scaled-chi-square CDF.
    """
    K, _, _ = P_train.shape
    n_test = P_test.shape[1]
    eps = 1e-12
    out = np.zeros((n_test, L))
    for y in range(L):
        idx = np.where(y_train == y)[0]
        if idx.size < 5:
            out[:, y] = fisher_fusion(P_test)[:, y]
            continue
        P_cls = np.clip(P_train[:, idx, y], eps, 1.0)  # (K, n_y)
        W = -2 * np.log(P_cls)                          # (K, n_y)
        Wc = W - W.mean(axis=1, keepdims=True)
        Sigma = (Wc @ Wc.T) / max(W.shape[1] - 1, 1)    # (K, K)
        var_T = np.sum(Sigma)
        if not np.isfinite(var_T) or var_T <= 0:
            var_T = 4 * K
        f_y = (8.0 * K * K) / var_T
        c_y = var_T / (4 * K)

        P_t = np.clip(P_test[:, :, y], eps, 1.0)
        T_t = -2 * np.sum(np.log(P_t), axis=0)
        out[:, y] = 1 - chi2.cdf(T_t / c_y, df=f_y)
    return out

def weighted_average_fusion(P_all: np.ndarray, weights: np.ndarray) -> np.ndarray:
    """sum_k w_k p_k^y; P_all: (K,n,L), weights: (K,)"""
    return np.tensordot(weights, P_all, axes=(0, 0))

def learn_view_weights_from_pvals(pv_train_concat: np.ndarray, y_train: np.ndarray, K: int, L: int, max_iter: int, seed: int) -> np.ndarray:
    """
    Train multinomial LR on p-only features to predict y.
    Convert coef_ (L, K*L) -> view weights by Frobenius norm per (L×L) block.
    """
    lr = LogisticRegression(multi_class="multinomial", solver="lbfgs", max_iter=max_iter, random_state=seed)
    lr.fit(pv_train_concat, y_train)
    B = lr.coef_  # (L, K*L)
    imps = []
    for k in range(K):
        block = B[:, k*L:(k+1)*L]
        imps.append(np.linalg.norm(block, ord="fro"))
    w = np.array(imps, float)
    w = np.maximum(w, 1e-12)
    return w / w.sum()

# Conformalization of fused model
def fused_class_cal_scores(y_cal: np.ndarray, fused_probs_cal: np.ndarray, L: int) -> Dict[int, np.ndarray]:
    s = 1 - fused_probs_cal[np.arange(len(y_cal)), y_cal]
    out = {c: [] for c in range(L)}
    for sc, yy in zip(s, y_cal):
        out[int(yy)].append(float(sc))
    return {c: np.asarray(v, float) for c, v in out.items()}

def fused_p_values_from_cal(fused_probs: np.ndarray, cal_class_scores: Dict[int, np.ndarray]) -> np.ndarray:
    n, L = fused_probs.shape
    out = np.zeros((n, L))
    for y in range(L):
        cal = cal_class_scores.get(y, np.array([]))
        if cal.size == 0:
            out[:, y] = 1.0
        else:
            s_test = 1 - fused_probs[:, y]
            counts = np.sum(cal[:, None] >= s_test[None, :], axis=0)
            out[:, y] = (1 + counts) / (len(cal) + 1)
    return out

def evaluate_sets(P: np.ndarray, y_true: np.ndarray, alpha: float) -> Tuple[float, float]:
    C = (P > alpha)
    cov = float(np.mean(C[np.arange(len(y_true)), y_true]))
    size = float(np.mean(C.sum(axis=1)))
    return cov, size

def summarize_table(df: pd.DataFrame, methods: List[str], metric_name: str) -> pd.DataFrame:
    g = df.groupby("K").agg({m: ["mean", "std"] for m in methods})
    g.columns = [f"{a}_{b}" for a, b in g.columns]
    g = g.reset_index()
    for m in methods:
        g[m] = g.apply(lambda r: f"{r[f'{m}_mean']:.2f} ({r[f'{m}_std']:.2f})", axis=1)
    g.insert(1, "Metric", metric_name)
    return g[["K", "Metric"] + methods]

# =============================================================================
# Main experiment
# =============================================================================

def run_experiments(cfg: CIFARConfig):
    results_cov, results_size, results_acc = [], [], []

    for sim in range(cfg.num_simulations):
        print(f"\n=== Simulation {sim+1}/{cfg.num_simulations} ===")
        seed = cfg.train_seed_base + sim
        rng = np.random.RandomState(seed)

        # Splits (train for view predictors, then cal/fusion splits)
        X_trP, X_tmp, y_trP, y_tmp = train_test_split(
            X_train_full, Y_train_full, test_size=1 - cfg.train_frac, stratify=Y_train_full, random_state=seed
        )
        # We will split X_tmp into cal and fusion pools
        X_cal, X_rest, y_cal, y_rest = train_test_split(
            X_tmp, y_tmp, test_size=1 - cfg.cal_frac_of_temp, stratify=y_tmp, random_state=seed
        )
        X_fuse_tr, X_fuse_cal, y_fuse_tr, y_fuse_cal = train_test_split(
            X_rest, y_rest, test_size=1 - cfg.fuse_train_frac_of_rest, stratify=y_rest, random_state=seed
        )
        X_te, y_te = X_test_full, Y_test_full

        for K in cfg.Ks:
            print(f"\n  -> K = {K}")
            num_views = 4 if K == 4 else K

            # Build loaders per view
            loaders = {}
            for v in range(num_views):
                tr_loader   = torch.utils.data.DataLoader(PatchesDataset(X_trP,      y_trP,      K, v), batch_size=cfg.batch_size, shuffle=True)
                cal_loader  = torch.utils.data.DataLoader(PatchesDataset(X_cal,      y_cal,      K, v), batch_size=cfg.batch_size, shuffle=False)
                ftr_loader  = torch.utils.data.DataLoader(PatchesDataset(X_fuse_tr,  y_fuse_tr,  K, v), batch_size=cfg.batch_size, shuffle=False)
                fcal_loader = torch.utils.data.DataLoader(PatchesDataset(X_fuse_cal, y_fuse_cal, K, v), batch_size=cfg.batch_size, shuffle=False)
                te_loader   = torch.utils.data.DataLoader(PatchesDataset(X_te,       y_te,       K, v), batch_size=cfg.batch_size, shuffle=False)
                loaders[v] = dict(train=tr_loader, cal=cal_loader, ftr=ftr_loader, fcal=fcal_loader, te=te_loader)

            # Train per-view CNNs
            models, cal_classwise = [], []
            for v in range(num_views):
                print(f"    [View {v+1}/{num_views}] training...")
                m = PredictorCNN(num_classes=cfg.num_classes)
                m = train_model(m, loaders[v]["train"], num_epochs=cfg.epochs_per_view, lr=cfg.lr)
                models.append(m)
                sc, lab = compute_nonconformity_scores(m, loaders[v]["cal"])
                cal_classwise.append(classwise_scores(sc, lab, cfg.num_classes))

            # Per-view p/probs for fusion train/cal/test
            pv_tr, pr_tr = [], []
            pv_cal, pr_cal = [], []
            pv_te,  pr_te  = [], []
            for v in range(num_views):
                p, pr = per_view_pvalues_and_probs(models[v], cal_classwise[v], loaders[v]["ftr"], cfg.num_classes)
                pv_tr.append(p); pr_tr.append(pr)
                p, pr = per_view_pvalues_and_probs(models[v], cal_classwise[v], loaders[v]["fcal"], cfg.num_classes)
                pv_cal.append(p); pr_cal.append(pr)
                p, pr = per_view_pvalues_and_probs(models[v], cal_classwise[v], loaders[v]["te"],  cfg.num_classes)
                pv_te.append(p);  pr_te.append(pr)

            # ------------------ NEW: richer fusion model (MLP) ------------------
            # Build richer fusion features: [p, prob, log(prob)]
            X_ftr   = build_fusion_features_extended(pv_tr, pr_tr)
            X_fcal  = build_fusion_features_extended(pv_cal, pr_cal)
            X_ftest = build_fusion_features_extended(pv_te,  pr_te)

            # Train a deeper Fusion-MLP on fusion-train
            fusion_mlp = train_fusion_mlp(
                X_ftr, y_fuse_tr, L=cfg.num_classes, seed=seed,
                lr=3e-4, batch_size=min(cfg.batch_size, 4096), epochs=100,
                weight_decay=1e-4, label_smoothing=0.05
            )

            # Fused calibration/test probabilities from the MLP
            fused_probs_cal  = fusion_mlp_predict_proba(fusion_mlp, X_fcal)
            fused_cal_scores = fused_class_cal_scores(y_fuse_cal, fused_probs_cal, cfg.num_classes)
            fused_probs_test = fusion_mlp_predict_proba(fusion_mlp, X_ftest)
            # -------------------------------------------------------------------

            # Our conformal-fused p-values
            P_cf = fused_p_values_from_cal(fused_probs_test, fused_cal_scores)

            # Baselines (stack per-view p-values) -- unchanged
            P_train = np.stack(pv_tr, axis=0)   # (K, n_tr, L)
            P_test  = np.stack(pv_te, axis=0)   # (K, n_te, L)

            P_min    = min_p_value_fusion(P_test)
            P_fisher = fisher_fusion(P_test)
            P_adjF   = adjusted_fisher_fusion(P_train, y_fuse_tr, P_test, cfg.num_classes)

            # Learned weighted average from p-values-only (K*L features)
            pv_tr_concat = np.concatenate(pv_tr, axis=1)  # (n_tr, K*L)
            w_learned = learn_view_weights_from_pvals(pv_tr_concat, y_fuse_tr, num_views, cfg.num_classes, cfg.max_iter_lr, seed)
            P_wavgL = weighted_average_fusion(P_test, w_learned)

            # Metrics
            cov_cf,   set_cf   = evaluate_sets(P_cf,     y_te, cfg.alpha)
            cov_min,  set_min  = evaluate_sets(P_min,    y_te, cfg.alpha)
            cov_fi,   set_fi   = evaluate_sets(P_fisher, y_te, cfg.alpha)
            cov_afi,  set_afi  = evaluate_sets(P_adjF,   y_te, cfg.alpha)
            cov_wl,   set_wl   = evaluate_sets(P_wavgL,  y_te, cfg.alpha)

            results_cov.append({
                "Sim": sim, "K": K,
                "Conformal Fusion": cov_cf * 100,
                "Min p-Value": cov_min * 100,
                "Fisher": cov_fi * 100,
                "Adjusted Fisher": cov_afi * 100,
                "Weighted Averaging": cov_wl * 100,
            })
            results_size.append({
                "Sim": sim, "K": K,
                "Conformal Fusion": set_cf,
                "Min p-Value": set_min,
                "Fisher": set_fi,
                "Adjusted Fisher": set_afi,
                "Weighted Averaging": set_wl,
            })

            # (Optional) Reference accuracy using simple average of per-view probs
            avg_probs = np.mean(np.stack(pr_te, axis=0), axis=0)
            acc_ref = accuracy_score(y_te, np.argmax(avg_probs, axis=1)) * 100
            results_acc.append({"Sim": sim, "K": K, "Reference Acc (avg probs)": acc_ref})

    return pd.DataFrame(results_cov), pd.DataFrame(results_size), pd.DataFrame(results_acc)

# =============================================================================
# Save/print tables (same style as synthetic)
# =============================================================================

def save_tables(df_cov: pd.DataFrame, df_size: pd.DataFrame, df_acc: pd.DataFrame):
    methods = [
        "Conformal Fusion",
        "Min p-Value",
        "Fisher",
        "Adjusted Fisher",
        "Weighted Averaging",
    ]
    sum_cov = summarize_table(df_cov, methods, "Coverage (%)")
    sum_set = summarize_table(df_size, methods, "Average Set Size")

    # CSV + LaTeX (CIFAR version)
    sum_cov.to_csv("cifar100_summary_coverage.csv", index=False)
    sum_set.to_csv("cifar100_summary_setsize.csv", index=False)
    with open("cifar100_summary_coverage.tex", "w") as f:
        f.write(sum_cov.to_latex(index=False, escape=False))
    with open("cifar100_summary_setsize.tex", "w") as f:
        f.write(sum_set.to_latex(index=False, escape=False))

    # Compact side-by-side
    cov_comp = sum_cov.drop(columns=["Metric"]).rename(columns={
        "Conformal Fusion": "CF Cov",
        "Min p-Value": "MinPV Cov",
        "Fisher": "Fisher Cov",
        "Adjusted Fisher": "AdjF Cov",
        "Weighted Averaging": "WAvgL Cov",
    })
    set_comp = sum_set.drop(columns=["Metric"]).rename(columns={
        "Conformal Fusion": "CF Set",
        "Min p-Value": "MinPV Set",
        "Fisher": "Fisher Set",
        "Adjusted Fisher": "AdjF Set",
        "Weighted Averaging": "WAvgL Set",
    })
    final = cov_comp.merge(set_comp, on="K").sort_values("K")
    final.to_csv("cifar100_summary_final.csv", index=False)
    with open("cifar100_summary_final.tex", "w") as f:
        f.write(final.to_latex(index=False, escape=False))

    # Acc means (if you want)
    acc_means = df_acc.groupby("K").mean(numeric_only=True).reset_index()
    acc_means.to_csv("cifar100_accuracy_summary.csv", index=False)
    with open("cifar100_accuracy_summary.tex", "w") as f:
        f.write(acc_means.to_latex(index=False, float_format="%.2f"))

    print("\nSaved:")
    print("  cifar100_summary_coverage.csv / .tex")
    print("  cifar100_summary_setsize.csv  / .tex")
    print("  cifar100_summary_final.csv    / .tex")
    print("  cifar100_accuracy_summary.csv / .tex")

# =============================================================================
# Entry
# =============================================================================

def main():
    cfg = CIFARConfig()
    df_cov, df_size, df_acc = run_experiments(cfg)
    print("\n=== Coverage (raw rows) ===")
    print(df_cov.head())
    print("\n=== Set Size (raw rows) ===")
    print(df_size.head())
    save_tables(df_cov, df_size, df_acc)

    # Plotting
    plt.style.use(['science','ieee', 'no-latex'])

    # Avoid Type 3 fonts
    matplotlib.rcParams['pdf.fonttype'] = 42
    matplotlib.rcParams['ps.fonttype'] = 42

    sns.set_context("paper", font_scale=1.2)

    # Consistent method order/colors
    method_order = ['Conformal Fusion', 'Min p-Value', "Fisher", 'Adjusted Fisher', 'Weighted Averaging']
    palette = {'Conformal Fusion': 'blue', 'Min p-Value': 'red', "Fisher": 'green', 'Adjusted Fisher': 'cyan', 'Weighted Averaging': 'orange'}

    fig, ax = plt.subplots(figsize=(5.9, 4.4))

    # Melt df_cov for plotting
    df_cov_melt = pd.melt(df_cov, id_vars=['K', 'Sim'], value_vars=method_order, var_name='Method', value_name='Coverage')

    # Boxplot: Coverage vs K
    sns.boxplot(x='K', y='Coverage', hue='Method',
                data=df_cov_melt, hue_order=method_order, palette=palette, ax=ax)
    ax.set_title('Coverage vs. Number of Views on CIFAR-100', fontsize=15)
    ax.set_xlabel('Number of Views (K)', fontsize=13)
    ax.set_ylabel('Coverage (%)', fontsize=13)
    ax.legend(loc='lower left')

    sns.despine(right=True)

    plt.tight_layout()
    plt.savefig('cifar100_coverage_boxplots.png', dpi=600)
    plt.savefig('cifar100_coverage_boxplots.pdf')
    plt.show()

if __name__ == "__main__":
    main()


Using device: cuda
Files already downloaded and verified
Files already downloaded and verified

=== Simulation 1/2 ===

  -> K = 2
    [View 1/2] training...
  epoch 50/100
  epoch 75/100
  epoch 100/100
    [View 2/2] training...
  epoch 25/100
  epoch 50/100
  epoch 75/100
  epoch 100/100
    [Fusion MLP] epoch 5/100  loss=2.1790
    [Fusion MLP] epoch 10/100  loss=1.8050
    [Fusion MLP] epoch 15/100  loss=1.5435
    [Fusion MLP] epoch 20/100  loss=1.3156
    [Fusion MLP] epoch 25/100  loss=1.1286
    [Fusion MLP] epoch 30/100  loss=0.9805
    [Fusion MLP] epoch 35/100  loss=0.8706
    [Fusion MLP] epoch 40/100  loss=0.7870
    [Fusion MLP] epoch 45/100  loss=0.7344
    [Fusion MLP] epoch 50/100  loss=0.6967
    [Fusion MLP] epoch 55/100  loss=0.6672
    [Fusion MLP] epoch 60/100  loss=0.6485
    [Fusion MLP] epoch 65/100  loss=0.6353
    [Fusion MLP] epoch 70/100  loss=0.6265
    [Fusion MLP] epoch 75/100  loss=0.6185
    [Fusion MLP] epoch 80/100  loss=0.6127
    [Fusion MLP] epoc