In [5]:
# CIFAR-10 (K=4) fusion-architecture + input-feature ablation:
# Compare Stage-2 fusion model:
#   (A) Multinomial LR on [p; prob]          (your default)
#   (B) 2-hidden-layer MLP on [p; prob]
#   (C) Multinomial LR on p-only
# Everything else matches: per-view CNNs, splits, seeds, epochs, lr, batch size, sims, alpha.
#
# Paste-and-run. Produces a small summary table (mean (std)) over simulations.

from __future__ import annotations

import warnings
from dataclasses import dataclass
from typing import Dict, List, Tuple

import numpy as np
import pandas as pd
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from torchvision import datasets, transforms

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

warnings.filterwarnings("ignore")

# =============================================================================
# Config (match your existing params)
# =============================================================================

@dataclass
class Cfg:
    """Experiment configuration; `Cfg()` with no arguments gives the baseline setup."""

    alpha: float = 0.1               # miscoverage level: a label enters the set when its p-value > alpha
    K: int = 4                       # number of patches (views) each image is split into
    L: int = 10                      # number of classes
    num_simulations: int = 10        # repetitions; simulation i runs with seed seed_base + i

    # per-view CNN training
    epochs_per_view: int = 100
    lr: float = 1e-3                 # Adam learning rate (also used for the fusion MLP)
    batch_size: int = 512            # used by every per-view loader and by fusion-MLP training

    # fusion models
    max_iter_lr: int = 1000          # max_iter passed to sklearn LogisticRegression
    fusion_epochs: int = 100         # training epochs for the fusion MLP
    fusion_hidden1: int = 128        # width of the first hidden layer of the fusion MLP
    fusion_hidden2: int = 128        # width of the second hidden layer of the fusion MLP
    fusion_weight_decay: float = 1e-4  # weight_decay passed to Adam for the fusion MLP

    # splits: each fraction applies to what is left after the previous split
    train_frac: float = 0.5               # share of the CIFAR-10 training set used to train the view CNNs
    cal_frac_of_temp: float = 0.3         # share of the remainder used to calibrate the per-view scores
    fuse_train_frac_of_rest: float = 0.7  # share of what is left used to fit the fusion model;
                                          # the rest calibrates the fused scores

    seed_base: int = 42

cfg = Cfg()

# =============================================================================
# Device / determinism
# =============================================================================

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

def set_all_seeds(seed: int):
    """Seed NumPy and PyTorch (CPU, plus every CUDA device when available).

    Also switches cuDNN benchmark mode off.
    """
    for seed_fn in (np.random.seed, torch.manual_seed):
        seed_fn(seed)
    cuda_present = torch.cuda.is_available()
    if cuda_present:
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False

# =============================================================================
# Data
# =============================================================================

# CIFAR-10 train/test sets, stored under ./data (download=True fetches them if needed).
transform = transforms.Compose([transforms.ToTensor()])
train_dataset = datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
test_dataset  = datasets.CIFAR10(root="./data", train=False, download=True, transform=transform)

# The rest of the script works on the raw `.data` / `.targets` attributes instead of
# indexing the dataset objects, so the pixel scaling is done here by dividing by 255.
# NOTE(review): `.data` is presumably uint8 with layout (N, 32, 32, 3) — PatchesDataset
# transposes each image with (2, 0, 1) — confirm against torchvision.
X_train_full = train_dataset.data.astype(np.float32) / 255.0
Y_train_full = np.array(train_dataset.targets, dtype=int)
X_test_full  = test_dataset.data.astype(np.float32) / 255.0
Y_test_full  = np.array(test_dataset.targets, dtype=int)

# =============================================================================
# Multi-view patching (same as your code)
# =============================================================================

def split_image_into_k_patches(image: torch.Tensor, k: int) -> List[torch.Tensor]:
    """Split a (C, H, W) image into ``k`` patches (views).

    * ``k == 4``: a 2x2 grid of (H // 2, W // 2) tiles in row-major order
      (top-left, top-right, bottom-left, bottom-right). For odd H or W the
      trailing row/column is dropped so that all four tiles share one shape.
    * any other ``k``: ``k`` full-height vertical strips; when W is not
      divisible by ``k`` the first ``W % k`` strips are one pixel wider.

    The returned patches are views into ``image`` (no copy).

    Raises:
        ValueError: if ``k`` is smaller than 1.
    """
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")

    _, height, width = image.shape
    if k == 4:
        # Tile size is derived from the image instead of assuming 32x32 inputs.
        tile_h, tile_w = height // 2, width // 2
        return [
            image[:, r * tile_h:(r + 1) * tile_h, c * tile_w:(c + 1) * tile_w]
            for r in range(2)
            for c in range(2)
        ]

    base_width, remainder = divmod(width, k)
    patches, start = [], 0
    for idx in range(k):
        stop = start + base_width + (1 if idx < remainder else 0)
        patches.append(image[:, :, start:stop])
        start = stop
    return patches

class PatchesDataset(torch.utils.data.Dataset):
    """Dataset that serves a single patch (``view``) of every image.

    ``images`` is indexed per sample and each sample is transposed with
    (2, 0, 1) before splitting, i.e. samples are stored channels-last.
    """

    def __init__(self, images: np.ndarray, labels: np.ndarray, k: int, view: int):
        self.images, self.labels = images, labels
        self.k, self.view = k, view

    def __len__(self) -> int:
        return len(self.labels)

    def __getitem__(self, idx: int):
        # channels-last -> channels-first, then keep only the requested view
        chw = torch.tensor(self.images[idx].transpose((2, 0, 1)), dtype=torch.float32)
        patch = split_image_into_k_patches(chw, self.k)[self.view]
        return patch, int(self.labels[idx])

# =============================================================================
# Per-view CNN (same as your code)
# =============================================================================

class PredictorCNN(nn.Module):
    """Two conv+pool stages followed by a 128-unit hidden layer and a linear classifier.

    Args:
        num_classes: number of output logits.
        input_hw: optional ``(height, width)`` of the input patches. When given,
            the fully connected head is built in ``__init__`` so that
            ``model.parameters()`` is complete right after construction.
            When omitted (the original behaviour) the head is created on the
            first ``forward`` call, sized from the tensor that reaches it.

    WARNING: with the lazy head, ``fc1``/``fc2`` do not exist until the first
    forward pass. An optimizer constructed from ``model.parameters()`` before
    that point never sees them and therefore never updates them; build the
    optimizer after the first forward, or pass ``input_hw``.
    """

    def __init__(self, num_classes=10, input_hw=None):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.pool  = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.fc1   = None
        self.fc2   = None
        self.num_classes = num_classes
        if input_hw is not None:
            height, width = input_hw
            # two 2x2 max-pools shrink each spatial side by a factor of 4 (floored)
            self._build_head(64 * (height // 4) * (width // 4))

    def _build_head(self, flat_dim):
        """Create the fully connected head for ``flat_dim`` flattened conv features."""
        self.fc1 = nn.Linear(flat_dim, 128)
        self.fc2 = nn.Linear(128, self.num_classes)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        if self.fc1 is None:
            _, c, h, w = x.shape
            self._build_head(c * h * w)
            # lazily created layers must follow the activations' device
            self.fc1.to(x.device)
            self.fc2.to(x.device)
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)

def train_view_model(model: nn.Module, train_loader, num_epochs: int, lr: float):
    """Train ``model`` with Adam and cross-entropy on ``train_loader``; return it.

    The optimizer is created only after the first forward pass. PredictorCNN
    builds its fully connected layers inside ``forward``; an optimizer built
    earlier would not contain those parameters and they would stay at their
    random initialisation for the whole run.

    Progress is printed every 25 epochs.
    """
    crit = nn.CrossEntropyLoss()
    model.to(device)
    opt = None
    for ep in range(num_epochs):
        model.train()
        for xb, yb in train_loader:
            xb = xb.to(device)
            # the default collate already yields a tensor; as_tensor avoids re-wrapping it
            yb = torch.as_tensor(yb, dtype=torch.long, device=device)
            logits = model(xb)
            if opt is None:
                # every (possibly lazily created) layer exists now
                opt = optim.Adam(model.parameters(), lr=lr)
            opt.zero_grad()
            loss = crit(logits, yb)
            loss.backward()
            opt.step()
        if (ep + 1) % 25 == 0:
            print(f"      view-epoch {ep+1}/{num_epochs}")
    return model

# =============================================================================
# Conformal utilities
# =============================================================================

@torch.no_grad()
def compute_true_label_scores(model: nn.Module, loader) -> Tuple[np.ndarray, np.ndarray]:
    """Return nonconformity scores ``1 - softmax[true label]`` and the labels.

    Both arrays follow the iteration order of ``loader``. An empty loader
    yields two empty arrays.
    """
    model.eval()
    scores, labels = [], []
    for xb, yb in loader:
        # as_tensor accepts the collated label tensor as-is instead of copy-constructing it
        yb = torch.as_tensor(yb, dtype=torch.long)
        probs = F.softmax(model(xb.to(device)), dim=1)
        rows = torch.arange(probs.size(0), device=probs.device)
        true_p = probs[rows, yb.to(probs.device)]
        scores.append((1.0 - true_p).detach().cpu().numpy())
        labels.append(yb.cpu().numpy())
    if not scores:
        return np.empty(0, dtype=np.float32), np.empty(0, dtype=np.int64)
    return np.concatenate(scores, axis=0), np.concatenate(labels, axis=0)

def classwise_scores(scores: np.ndarray, labels: np.ndarray, L: int) -> Dict[int, np.ndarray]:
    """Group nonconformity scores by class label.

    Returns a dict with one float64 array per class in ``range(L)`` (empty for
    classes that do not occur); within a class the input order is preserved.

    Raises:
        ValueError: if ``scores`` and ``labels`` differ in shape (previously the
            longer input was silently truncated).
        KeyError: if a label falls outside ``range(L)``.
    """
    scores = np.asarray(scores, dtype=float)
    labels = np.asarray(labels).astype(int)
    if scores.shape != labels.shape:
        raise ValueError(
            f"scores and labels must have the same shape, got {scores.shape} and {labels.shape}"
        )
    out_of_range = (labels < 0) | (labels >= L)
    if out_of_range.any():
        raise KeyError(int(labels[out_of_range][0]))
    # boolean masks replace the per-sample Python loop
    return {c: scores[labels == c] for c in range(L)}

@torch.no_grad()
def per_view_pvalues_and_probs(model: nn.Module, cal_scores: Dict[int, np.ndarray], loader, L: int) -> Tuple[np.ndarray, np.ndarray]:
    """Compute class-conditional conformal p-values and softmax outputs for ``loader``.

    For every sample and candidate class ``y`` the score is ``1 - prob[y]`` and
        p = (1 + #{calibration scores of class y >= score}) / (n_cal_y + 1).
    Classes without calibration scores get p = 1.

    Returns:
        ``(pvals, probs)``, both of shape (n, L).
    """
    model.eval()
    probs_all = np.vstack([
        F.softmax(model(xb.to(device)), dim=1).detach().cpu().numpy()
        for xb, _ in loader
    ])  # (n, L)

    pvals = np.ones((probs_all.shape[0], L), dtype=float)
    for y in range(L):
        cal = np.asarray(cal_scores.get(y, ()), dtype=float)
        if cal.size == 0:
            continue
        # Sorting once and binary-searching gives the ">=" counts without
        # materialising an (n_cal, n) comparison matrix.
        n_smaller = np.searchsorted(np.sort(cal), 1.0 - probs_all[:, y], side="left")
        pvals[:, y] = (1.0 + (cal.size - n_smaller)) / (cal.size + 1.0)
    return pvals, probs_all

def build_fusion_features_full(pvals_list: List[np.ndarray], probs_list: List[np.ndarray]) -> np.ndarray:
    """Concatenate, view by view, the p-value block followed by the probability block.

    With K views of L classes each the result has shape (n, K*2L).
    """
    pieces = []
    for view_idx in range(len(pvals_list)):
        pieces.append(pvals_list[view_idx])
        pieces.append(probs_list[view_idx])
    return np.hstack(pieces)

def build_fusion_features_p_only(pvals_list: List[np.ndarray]) -> np.ndarray:
    """Place the per-view p-value blocks side by side: K views of L classes => (n, K*L)."""
    stacked = np.hstack(pvals_list)
    return stacked

def fused_class_cal_scores(y_cal: np.ndarray, fused_probs_cal: np.ndarray, L: int) -> Dict[int, np.ndarray]:
    """Per-class calibration scores ``1 - fused_prob[true label]``.

    Returns one float array per class in ``range(L)``; classes absent from
    ``y_cal`` map to an empty array.
    """
    row_ids = np.arange(len(y_cal))
    nonconformity = 1.0 - fused_probs_cal[row_ids, y_cal]

    grouped = {cls: [] for cls in range(L)}
    for label, score in zip(y_cal, nonconformity):
        grouped[int(label)].append(float(score))

    per_class = {}
    for cls, values in grouped.items():
        per_class[cls] = np.asarray(values, float)
    return per_class

def fused_p_values_from_cal(fused_probs: np.ndarray, cal_class_scores: Dict[int, np.ndarray]) -> np.ndarray:
    n, L = fused_probs.shape
    out = np.zeros((n, L), dtype=float)
    for y in range(L):
        cal = cal_class_scores.get(y, np.array([], dtype=float))
        if cal.size == 0:
            out[:, y] = 1.0
        else:
            s_test = 1.0 - fused_probs[:, y]
            counts = np.sum(cal[:, None] >= s_test[None, :], axis=0)
            out[:, y] = (1.0 + counts) / (len(cal) + 1.0)
    return out

def eval_sets_from_pvals(P: np.ndarray, y_true: np.ndarray, alpha: float) -> Tuple[float, float]:
    """Return ``(coverage, average set size)`` of the sets ``{y : P[i, y] > alpha}``.

    Coverage is the fraction of samples whose true label is in the set.
    """
    in_set = P > alpha
    sample_ids = np.arange(len(y_true))
    coverage = in_set[sample_ids, y_true].mean()
    avg_size = in_set.sum(axis=1).mean()
    return float(coverage), float(avg_size)

# =============================================================================
# Fusion MLP (2 hidden layers) for [p; prob]
# =============================================================================

class FusionMLP(nn.Module):
    """Fully connected network d_in -> d_h1 -> d_h2 -> d_out; returns raw logits."""

    def __init__(self, d_in: int, d_h1: int, d_h2: int, d_out: int):
        super().__init__()
        # Layers are created in the order fc1, fc2, fc3 (this fixes both the
        # state-dict keys and the order in which the init RNG is consumed).
        self.fc1 = nn.Linear(in_features=d_in, out_features=d_h1)
        self.fc2 = nn.Linear(in_features=d_h1, out_features=d_h2)
        self.fc3 = nn.Linear(in_features=d_h2, out_features=d_out)

    def forward(self, x):
        hidden = x
        for layer in (self.fc1, self.fc2):
            hidden = F.relu(layer(hidden))
        return self.fc3(hidden)

def train_fusion_mlp(X: np.ndarray, y: np.ndarray, d_in: int, L: int, seed: int) -> FusionMLP:
    """Fit a FusionMLP on the fusion features ``X`` with labels ``y``.

    Reseeds all generators with ``seed`` first, then trains for
    ``cfg.fusion_epochs`` epochs using Adam (``cfg.lr``,
    ``cfg.fusion_weight_decay``) and cross-entropy on shuffled mini-batches of
    size ``cfg.batch_size``. Returns the trained model, left in train mode.
    """
    set_all_seeds(seed)
    net = FusionMLP(d_in, cfg.fusion_hidden1, cfg.fusion_hidden2, L).to(device)
    optimizer = optim.Adam(net.parameters(), lr=cfg.lr, weight_decay=cfg.fusion_weight_decay)
    loss_fn = nn.CrossEntropyLoss()

    features = torch.tensor(X, dtype=torch.float32)
    targets = torch.tensor(y, dtype=torch.long)
    batches = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(features, targets),
        batch_size=cfg.batch_size,
        shuffle=True,
    )

    net.train()
    epochs_done = 0
    while epochs_done < cfg.fusion_epochs:
        for batch_x, batch_y in batches:
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)
            optimizer.zero_grad()
            batch_loss = loss_fn(net(batch_x), batch_y)
            batch_loss.backward()
            optimizer.step()
        epochs_done += 1
    return net

@torch.no_grad()
def fusion_mlp_predict_proba(model: FusionMLP, X: np.ndarray) -> np.ndarray:
    """Return the softmax probabilities of ``model`` for every row of ``X``.

    Inference runs in eval mode, in order, in batches of 4096 rows.
    """
    model.eval()
    batches = torch.utils.data.DataLoader(
        torch.tensor(X, dtype=torch.float32),
        batch_size=4096,
        shuffle=False,
    )
    chunks = [
        F.softmax(model(batch.to(device)), dim=1).detach().cpu().numpy()
        for batch in batches
    ]
    return np.vstack(chunks)

# =============================================================================
# One simulation for K=4, returning metrics for:
#   LR([p;prob]) vs MLP([p;prob]) vs LR(p-only)
# =============================================================================

def _fit_fusion_lr(X: np.ndarray, y: np.ndarray, seed: int) -> LogisticRegression:
    """Fit the logistic-regression fusion model on fusion features ``X``."""
    # `multi_class="multinomial"` is no longer passed: the argument is deprecated
    # since scikit-learn 1.5, and lbfgs already fits a multinomial model for
    # more than two classes.
    model = LogisticRegression(
        max_iter=cfg.max_iter_lr,
        solver="lbfgs",
        random_state=seed,
    )
    model.fit(X, y)
    return model


def _conformal_metrics(probs_cal: np.ndarray, y_cal: np.ndarray,
                       probs_test: np.ndarray, y_test: np.ndarray) -> Tuple[float, float]:
    """Calibrate on fused calibration probabilities; return test (coverage, set size)."""
    cal_scores = fused_class_cal_scores(y_cal, probs_cal, cfg.L)
    pvals = fused_p_values_from_cal(probs_test, cal_scores)
    return eval_sets_from_pvals(pvals, y_test, cfg.alpha)


def run_one_sim(sim: int) -> Dict[str, float]:
    """Run one simulation and return coverage / set size of the three fusion variants.

    Steps: split the training data (per-view training, per-view calibration,
    fusion training, fusion calibration), train one CNN per view, turn the
    per-view outputs into fusion features, then fit and conformalize
    (A) LR on [p; prob], (B) the MLP on [p; prob] and (C) LR on p only.
    All three are evaluated on the CIFAR-10 test set at level ``cfg.alpha``.

    Returns:
        Dict with the simulation index plus coverage (in %) and average set
        size for each variant.
    """
    seed = cfg.seed_base + sim
    set_all_seeds(seed)

    # Splits (match your script)
    X_trP, X_tmp, y_trP, y_tmp = train_test_split(
        X_train_full, Y_train_full,
        test_size=1 - cfg.train_frac,
        stratify=Y_train_full,
        random_state=seed,
    )
    X_cal, X_rest, y_cal, y_rest = train_test_split(
        X_tmp, y_tmp,
        test_size=1 - cfg.cal_frac_of_temp,
        stratify=y_tmp,
        random_state=seed,
    )
    X_fuse_tr, X_fuse_cal, y_fuse_tr, y_fuse_cal = train_test_split(
        X_rest, y_rest,
        test_size=1 - cfg.fuse_train_frac_of_rest,
        stratify=y_rest,
        random_state=seed,
    )

    # One view per patch: follows cfg.K instead of a hard-coded 4.
    num_views = cfg.K
    split_data = {
        "train": (X_trP, y_trP),
        "cal": (X_cal, y_cal),
        "ftr": (X_fuse_tr, y_fuse_tr),
        "fcal": (X_fuse_cal, y_fuse_cal),
        "te": (X_test_full, Y_test_full),
    }
    loaders = {
        v: {
            name: torch.utils.data.DataLoader(
                PatchesDataset(X, y, cfg.K, v),
                batch_size=cfg.batch_size,
                shuffle=(name == "train"),  # only the CNN training loader is shuffled
            )
            for name, (X, y) in split_data.items()
        }
        for v in range(num_views)
    }

    # Train per-view CNNs + cal1 classwise score sets
    models, cal_classwise = [], []
    print(f"\n=== Sim {sim+1}/{cfg.num_simulations} (seed={seed}) ===")
    for v in range(num_views):
        print(f"  [View {v+1}/{num_views}] training...")
        m = PredictorCNN(num_classes=cfg.L)
        m = train_view_model(m, loaders[v]["train"], num_epochs=cfg.epochs_per_view, lr=cfg.lr)
        models.append(m)
        sc, lab = compute_true_label_scores(m, loaders[v]["cal"])
        cal_classwise.append(classwise_scores(sc, lab, cfg.L))

    # Per-view p-values / probabilities for fusion-train, fusion-cal2 and test
    fusion_splits = ("ftr", "fcal", "te")
    pvals = {name: [] for name in fusion_splits}
    probs = {name: [] for name in fusion_splits}
    for v in range(num_views):
        for name in fusion_splits:
            p, pr = per_view_pvalues_and_probs(models[v], cal_classwise[v], loaders[v][name], cfg.L)
            pvals[name].append(p)
            probs[name].append(pr)

    # Features
    X_full = {name: build_fusion_features_full(pvals[name], probs[name]) for name in fusion_splits}
    X_ponly = {name: build_fusion_features_p_only(pvals[name]) for name in fusion_splits}

    # (A) Fusion = multinomial LR on [p;prob], then conformalize
    fusion_lr_full = _fit_fusion_lr(X_full["ftr"], y_fuse_tr, seed)
    cov_lr_full, set_lr_full = _conformal_metrics(
        fusion_lr_full.predict_proba(X_full["fcal"]), y_fuse_cal,
        fusion_lr_full.predict_proba(X_full["te"]), Y_test_full,
    )

    # (B) Fusion = 2-hidden-layer MLP on [p;prob], conformalize
    fusion_mlp = train_fusion_mlp(X_full["ftr"], y_fuse_tr, d_in=X_full["ftr"].shape[1], L=cfg.L, seed=seed)
    cov_mlp_full, set_mlp_full = _conformal_metrics(
        fusion_mlp_predict_proba(fusion_mlp, X_full["fcal"]), y_fuse_cal,
        fusion_mlp_predict_proba(fusion_mlp, X_full["te"]), Y_test_full,
    )

    # (C) Fusion = multinomial LR on p-only, then conformalize
    fusion_lr_p = _fit_fusion_lr(X_ponly["ftr"], y_fuse_tr, seed)
    cov_lr_p, set_lr_p = _conformal_metrics(
        fusion_lr_p.predict_proba(X_ponly["fcal"]), y_fuse_cal,
        fusion_lr_p.predict_proba(X_ponly["te"]), Y_test_full,
    )

    return {
        "sim": sim,
        "LR_full_cov(%)": 100.0 * cov_lr_full,
        "LR_full_set": set_lr_full,
        "MLP_full_cov(%)": 100.0 * cov_mlp_full,
        "MLP_full_set": set_mlp_full,
        "LR_p_cov(%)": 100.0 * cov_lr_p,
        "LR_p_set": set_lr_p,
    }

# =============================================================================
# Run all sims + print summary table
# =============================================================================

# One result dict per simulation, collected into a table with one row each.
rows = [run_one_sim(sim) for sim in range(cfg.num_simulations)]

df = pd.DataFrame(rows)

def mean_std_str(x: pd.Series) -> str:
    """Format a series as "mean (std)" with two decimals; std is the sample std (ddof=1)."""
    center = x.mean()
    spread = x.std(ddof=1)
    return "{:.2f} ({:.2f})".format(center, spread)

# (display label, coverage column, set-size column) for each fusion variant
method_columns = [
    (r"Fusion LR [$\Pi;p$]", "LR_full_cov(%)", "LR_full_set"),
    (r"Fusion 2-layer MLP [$\Pi;p$]", "MLP_full_cov(%)", "MLP_full_set"),
    (r"Fusion LR [$p$ only]", "LR_p_cov(%)", "LR_p_set"),
]

summary = pd.DataFrame({
    "Method": [label for label, _, _ in method_columns],
    "Coverage (%)": [mean_std_str(df[cov_col]) for _, cov_col, _ in method_columns],
    "Avg set size": [mean_std_str(df[size_col]) for _, _, size_col in method_columns],
})

for heading, table in (
    ("\n=== Per-simulation raw results ===", df),
    ("\n=== Summary over simulations (mean (std)) ===", summary),
):
    print(heading)
    print(table)

# Optional: save
summary.to_csv("fusion_arch_input_ablation_k4_summary.csv", index=False)
df.to_csv("fusion_arch_input_ablation_k4_raw.csv", index=False)
print("\nSaved: fusion_arch_input_ablation_k4_summary.csv, fusion_arch_input_ablation_k4_raw.csv")


Using device: cuda
Files already downloaded and verified
Files already downloaded and verified

=== Sim 1/10 (seed=42) ===
  [View 1/4] training...
      view-epoch 25/100
      view-epoch 50/100
      view-epoch 75/100
      view-epoch 100/100
  [View 2/4] training...
      view-epoch 25/100
      view-epoch 50/100
      view-epoch 75/100
      view-epoch 100/100
  [View 3/4] training...
      view-epoch 25/100
      view-epoch 50/100
      view-epoch 75/100
      view-epoch 100/100
  [View 4/4] training...
      view-epoch 25/100
      view-epoch 50/100
      view-epoch 75/100
      view-epoch 100/100

=== Sim 2/10 (seed=43) ===
  [View 1/4] training...
      view-epoch 25/100
      view-epoch 50/100
      view-epoch 75/100
      view-epoch 100/100
  [View 2/4] training...
      view-epoch 25/100
      view-epoch 50/100
      view-epoch 75/100
      view-epoch 100/100
  [View 3/4] training...
      view-epoch 25/100
      view-epoch 50/100
      view-epoch 75/100
      view-epoch 100

In [4]:
# CIFAR-10 (K=4) fusion-architecture ablation:
# Compare Stage-2 fusion model = (A) multinomial LR vs (B) 2-hidden-layer MLP
# Everything else matches: per-view CNNs, splits, seeds, epochs, lr, batch size, sims, alpha.

from __future__ import annotations

import os
import warnings
from dataclasses import dataclass
from typing import Dict, List, Tuple

import numpy as np
import pandas as pd
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from torchvision import datasets, transforms

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

warnings.filterwarnings("ignore")

# =============================================================================
# Config (match your existing params)
# =============================================================================

@dataclass
class Cfg:
    """Experiment configuration; `Cfg()` with no arguments gives the baseline setup."""

    alpha: float = 0.1               # miscoverage level: a label enters the set when its p-value > alpha
    K: int = 4                       # number of patches (views) each image is split into
    L: int = 10                      # number of classes
    num_simulations: int = 10        # repetitions; simulation i runs with seed seed_base + i

    # per-view CNN training
    epochs_per_view: int = 100
    lr: float = 1e-3                 # Adam learning rate (also used for the fusion MLP)
    batch_size: int = 512            # used by every per-view loader and by fusion-MLP training

    # fusion models
    max_iter_lr: int = 1000          # sklearn LR iterations
    fusion_epochs: int = 100         # training epochs for the fusion MLP
    fusion_hidden1: int = 128        # width of the first hidden layer of the fusion MLP
    fusion_hidden2: int = 128        # width of the second hidden layer of the fusion MLP
    fusion_weight_decay: float = 1e-4  # weight_decay passed to Adam for the fusion MLP

    # splits: each fraction applies to what is left after the previous split
    train_frac: float = 0.5               # share of the CIFAR-10 training set used to train the view CNNs
    cal_frac_of_temp: float = 0.3         # share of the remainder used to calibrate the per-view scores
    fuse_train_frac_of_rest: float = 0.7  # share of what is left used to fit the fusion model;
                                          # the rest calibrates the fused scores

    seed_base: int = 42

cfg = Cfg()

# =============================================================================
# Device / determinism
# =============================================================================

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

def set_all_seeds(seed: int):
    """Seed NumPy and PyTorch (CPU, plus every CUDA device when available).

    Also switches cuDNN benchmark mode off; full cuDNN determinism is
    deliberately not forced, to avoid the associated slowdown.
    """
    for seed_fn in (np.random.seed, torch.manual_seed):
        seed_fn(seed)
    cuda_present = torch.cuda.is_available()
    if cuda_present:
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False

# =============================================================================
# Data
# =============================================================================

# CIFAR-10 train/test sets, stored under ./data (download=True fetches them if needed).
transform = transforms.Compose([transforms.ToTensor()])
train_dataset = datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
test_dataset  = datasets.CIFAR10(root="./data", train=False, download=True, transform=transform)

# The rest of the script works on the raw `.data` / `.targets` attributes instead of
# indexing the dataset objects, so the pixel scaling is done here by dividing by 255.
# NOTE(review): `.data` is presumably uint8 with layout (N, 32, 32, 3) — PatchesDataset
# transposes each image with (2, 0, 1) — confirm against torchvision.
X_train_full = train_dataset.data.astype(np.float32) / 255.0
Y_train_full = np.array(train_dataset.targets, dtype=int)
X_test_full  = test_dataset.data.astype(np.float32) / 255.0
Y_test_full  = np.array(test_dataset.targets, dtype=int)

# =============================================================================
# Multi-view patching (same as your code)
# =============================================================================

def split_image_into_k_patches(image: torch.Tensor, k: int) -> List[torch.Tensor]:
    """Split a (C, H, W) image into ``k`` patches (views).

    * ``k == 4``: a 2x2 grid of (H // 2, W // 2) tiles in row-major order
      (top-left, top-right, bottom-left, bottom-right). For odd H or W the
      trailing row/column is dropped so that all four tiles share one shape.
    * any other ``k``: ``k`` full-height vertical strips; when W is not
      divisible by ``k`` the first ``W % k`` strips are one pixel wider.

    The returned patches are views into ``image`` (no copy).

    Raises:
        ValueError: if ``k`` is smaller than 1.
    """
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")

    _, height, width = image.shape
    if k == 4:
        # Tile size is derived from the image instead of assuming 32x32 inputs.
        tile_h, tile_w = height // 2, width // 2
        return [
            image[:, r * tile_h:(r + 1) * tile_h, c * tile_w:(c + 1) * tile_w]
            for r in range(2)
            for c in range(2)
        ]

    base_width, remainder = divmod(width, k)
    patches, start = [], 0
    for idx in range(k):
        stop = start + base_width + (1 if idx < remainder else 0)
        patches.append(image[:, :, start:stop])
        start = stop
    return patches

class PatchesDataset(torch.utils.data.Dataset):
    """Dataset that serves a single patch (``view``) of every image.

    ``images`` is indexed per sample and each sample is transposed with
    (2, 0, 1) before splitting, i.e. samples are stored channels-last.
    """

    def __init__(self, images: np.ndarray, labels: np.ndarray, k: int, view: int):
        self.images, self.labels = images, labels
        self.k, self.view = k, view

    def __len__(self) -> int:
        return len(self.labels)

    def __getitem__(self, idx: int):
        # channels-last -> channels-first, then keep only the requested view
        chw = torch.tensor(self.images[idx].transpose((2, 0, 1)), dtype=torch.float32)
        patch = split_image_into_k_patches(chw, self.k)[self.view]
        return patch, int(self.labels[idx])

# =============================================================================
# Per-view CNN (same as your code)
# =============================================================================

class PredictorCNN(nn.Module):
    """Two conv+pool stages followed by a 128-unit hidden layer and a linear classifier.

    Args:
        num_classes: number of output logits.
        input_hw: optional ``(height, width)`` of the input patches. When given,
            the fully connected head is built in ``__init__`` so that
            ``model.parameters()`` is complete right after construction.
            When omitted (the original behaviour) the head is created on the
            first ``forward`` call, sized from the tensor that reaches it.

    WARNING: with the lazy head, ``fc1``/``fc2`` do not exist until the first
    forward pass. An optimizer constructed from ``model.parameters()`` before
    that point never sees them and therefore never updates them; build the
    optimizer after the first forward, or pass ``input_hw``.
    """

    def __init__(self, num_classes=10, input_hw=None):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.pool  = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.fc1   = None
        self.fc2   = None
        self.num_classes = num_classes
        if input_hw is not None:
            height, width = input_hw
            # two 2x2 max-pools shrink each spatial side by a factor of 4 (floored)
            self._build_head(64 * (height // 4) * (width // 4))

    def _build_head(self, flat_dim):
        """Create the fully connected head for ``flat_dim`` flattened conv features."""
        self.fc1 = nn.Linear(flat_dim, 128)
        self.fc2 = nn.Linear(128, self.num_classes)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        if self.fc1 is None:
            _, c, h, w = x.shape
            self._build_head(c * h * w)
            # lazily created layers must follow the activations' device
            self.fc1.to(x.device)
            self.fc2.to(x.device)
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)

def train_view_model(model: nn.Module, train_loader, num_epochs: int, lr: float):
    """Train ``model`` with Adam and cross-entropy on ``train_loader``; return it.

    The optimizer is created only after the first forward pass. PredictorCNN
    builds its fully connected layers inside ``forward``; an optimizer built
    earlier would not contain those parameters and they would stay at their
    random initialisation for the whole run.

    Progress is printed every 25 epochs.
    """
    crit = nn.CrossEntropyLoss()
    model.to(device)
    opt = None
    for ep in range(num_epochs):
        model.train()
        for xb, yb in train_loader:
            xb = xb.to(device)
            # the default collate already yields a tensor; as_tensor avoids re-wrapping it
            yb = torch.as_tensor(yb, dtype=torch.long, device=device)
            logits = model(xb)
            if opt is None:
                # every (possibly lazily created) layer exists now
                opt = optim.Adam(model.parameters(), lr=lr)
            opt.zero_grad()
            loss = crit(logits, yb)
            loss.backward()
            opt.step()
        if (ep + 1) % 25 == 0:
            print(f"      view-epoch {ep+1}/{num_epochs}")
    return model

# =============================================================================
# Conformal utilities (same logic as your code)
# =============================================================================

@torch.no_grad()
def compute_true_label_scores(model: nn.Module, loader) -> Tuple[np.ndarray, np.ndarray]:
    """Return nonconformity scores ``1 - softmax[true label]`` and the labels.

    Both arrays follow the iteration order of ``loader``. An empty loader
    yields two empty arrays.
    """
    model.eval()
    scores, labels = [], []
    for xb, yb in loader:
        # as_tensor accepts the collated label tensor as-is instead of copy-constructing it
        yb = torch.as_tensor(yb, dtype=torch.long)
        probs = F.softmax(model(xb.to(device)), dim=1)
        rows = torch.arange(probs.size(0), device=probs.device)
        true_p = probs[rows, yb.to(probs.device)]
        scores.append((1.0 - true_p).detach().cpu().numpy())
        labels.append(yb.cpu().numpy())
    if not scores:
        return np.empty(0, dtype=np.float32), np.empty(0, dtype=np.int64)
    return np.concatenate(scores, axis=0), np.concatenate(labels, axis=0)

def classwise_scores(scores: np.ndarray, labels: np.ndarray, L: int) -> Dict[int, np.ndarray]:
    """Group nonconformity scores by class label.

    Returns a dict with one float64 array per class in ``range(L)`` (empty for
    classes that do not occur); within a class the input order is preserved.

    Raises:
        ValueError: if ``scores`` and ``labels`` differ in shape (previously the
            longer input was silently truncated).
        KeyError: if a label falls outside ``range(L)``.
    """
    scores = np.asarray(scores, dtype=float)
    labels = np.asarray(labels).astype(int)
    if scores.shape != labels.shape:
        raise ValueError(
            f"scores and labels must have the same shape, got {scores.shape} and {labels.shape}"
        )
    out_of_range = (labels < 0) | (labels >= L)
    if out_of_range.any():
        raise KeyError(int(labels[out_of_range][0]))
    # boolean masks replace the per-sample Python loop
    return {c: scores[labels == c] for c in range(L)}

@torch.no_grad()
def per_view_pvalues_and_probs(model: nn.Module, cal_scores: Dict[int, np.ndarray], loader, L: int) -> Tuple[np.ndarray, np.ndarray]:
    """Compute class-conditional conformal p-values and softmax outputs for ``loader``.

    For every sample and candidate class ``y`` the score is ``1 - prob[y]`` and
        p = (1 + #{calibration scores of class y >= score}) / (n_cal_y + 1).
    Classes without calibration scores get p = 1.

    Returns:
        ``(pvals, probs)``, both of shape (n, L).
    """
    model.eval()
    probs_all = np.vstack([
        F.softmax(model(xb.to(device)), dim=1).detach().cpu().numpy()
        for xb, _ in loader
    ])  # (n, L)

    pvals = np.ones((probs_all.shape[0], L), dtype=float)
    for y in range(L):
        cal = np.asarray(cal_scores.get(y, ()), dtype=float)
        if cal.size == 0:
            continue
        # Sorting once and binary-searching gives the ">=" counts without
        # materialising an (n_cal, n) comparison matrix.
        n_smaller = np.searchsorted(np.sort(cal), 1.0 - probs_all[:, y], side="left")
        pvals[:, y] = (1.0 + (cal.size - n_smaller)) / (cal.size + 1.0)
    return pvals, probs_all

def build_fusion_features(pvals_list: List[np.ndarray], probs_list: List[np.ndarray]) -> np.ndarray:
    """Concatenate, view by view, the p-value block followed by the probability block."""
    pieces = []
    for view_idx in range(len(pvals_list)):
        pieces.append(pvals_list[view_idx])
        pieces.append(probs_list[view_idx])
    return np.hstack(pieces)

def fused_class_cal_scores(y_cal: np.ndarray, fused_probs_cal: np.ndarray, L: int) -> Dict[int, np.ndarray]:
    """Per-class calibration scores ``1 - fused_prob[true label]``.

    Returns one float array per class in ``range(L)``; classes absent from
    ``y_cal`` map to an empty array.
    """
    row_ids = np.arange(len(y_cal))
    nonconformity = 1.0 - fused_probs_cal[row_ids, y_cal]

    grouped = {cls: [] for cls in range(L)}
    for label, score in zip(y_cal, nonconformity):
        grouped[int(label)].append(float(score))

    per_class = {}
    for cls, values in grouped.items():
        per_class[cls] = np.asarray(values, float)
    return per_class

def fused_p_values_from_cal(fused_probs: np.ndarray, cal_class_scores: Dict[int, np.ndarray]) -> np.ndarray:
    n, L = fused_probs.shape
    out = np.zeros((n, L), dtype=float)
    for y in range(L):
        cal = cal_class_scores.get(y, np.array([], dtype=float))
        if cal.size == 0:
            out[:, y] = 1.0
        else:
            s_test = 1.0 - fused_probs[:, y]
            counts = np.sum(cal[:, None] >= s_test[None, :], axis=0)
            out[:, y] = (1.0 + counts) / (len(cal) + 1.0)
    return out

def eval_sets_from_pvals(P: np.ndarray, y_true: np.ndarray, alpha: float) -> Tuple[float, float]:
    """Return ``(coverage, average set size)`` of the sets ``{y : P[i, y] > alpha}``.

    Coverage is the fraction of samples whose true label is in the set.
    """
    in_set = P > alpha
    sample_ids = np.arange(len(y_true))
    coverage = in_set[sample_ids, y_true].mean()
    avg_size = in_set.sum(axis=1).mean()
    return float(coverage), float(avg_size)

# =============================================================================
# Fusion MLP (2 hidden layers)
# =============================================================================

class FusionMLP(nn.Module):
    """Fully connected network d_in -> d_h1 -> d_h2 -> d_out; returns raw logits."""

    def __init__(self, d_in: int, d_h1: int, d_h2: int, d_out: int):
        super().__init__()
        # Layers are created in the order fc1, fc2, fc3 (this fixes both the
        # state-dict keys and the order in which the init RNG is consumed).
        self.fc1 = nn.Linear(in_features=d_in, out_features=d_h1)
        self.fc2 = nn.Linear(in_features=d_h1, out_features=d_h2)
        self.fc3 = nn.Linear(in_features=d_h2, out_features=d_out)

    def forward(self, x):
        hidden = x
        for layer in (self.fc1, self.fc2):
            hidden = F.relu(layer(hidden))
        return self.fc3(hidden)

def train_fusion_mlp(X: np.ndarray, y: np.ndarray, d_in: int, L: int, seed: int) -> FusionMLP:
    """Fit a FusionMLP on the fusion features ``X`` with labels ``y``.

    Reseeds all generators with ``seed`` first, then trains for
    ``cfg.fusion_epochs`` epochs using Adam (``cfg.lr``,
    ``cfg.fusion_weight_decay``) and cross-entropy on shuffled mini-batches of
    size ``cfg.batch_size``. Returns the trained model, left in train mode.
    """
    set_all_seeds(seed)
    net = FusionMLP(d_in, cfg.fusion_hidden1, cfg.fusion_hidden2, L).to(device)
    optimizer = optim.Adam(net.parameters(), lr=cfg.lr, weight_decay=cfg.fusion_weight_decay)
    loss_fn = nn.CrossEntropyLoss()

    # torch dataset
    features = torch.tensor(X, dtype=torch.float32)
    targets = torch.tensor(y, dtype=torch.long)
    batches = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(features, targets),
        batch_size=cfg.batch_size,
        shuffle=True,
    )

    net.train()
    epochs_done = 0
    while epochs_done < cfg.fusion_epochs:
        for batch_x, batch_y in batches:
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)
            optimizer.zero_grad()
            batch_loss = loss_fn(net(batch_x), batch_y)
            batch_loss.backward()
            optimizer.step()
        epochs_done += 1
    return net

@torch.no_grad()
def fusion_mlp_predict_proba(model: FusionMLP, X: np.ndarray) -> np.ndarray:
    """Return the softmax probabilities of ``model`` for every row of ``X``.

    Inference runs in eval mode, in order, in batches of 4096 rows.
    """
    model.eval()
    batches = torch.utils.data.DataLoader(
        torch.tensor(X, dtype=torch.float32),
        batch_size=4096,
        shuffle=False,
    )
    chunks = [
        F.softmax(model(batch.to(device)), dim=1).detach().cpu().numpy()
        for batch in batches
    ]
    return np.vstack(chunks)

# =============================================================================
# One simulation for K=4, returning metrics for LR vs MLP
# =============================================================================

def run_one_sim(sim: int) -> Dict[str, float]:
    """Run one simulation and report coverage and set size for both fusion models.

    Steps, all driven by ``seed = cfg.seed_base + sim``:
      1. Split the training pool three times (stratified) into per-view
         training, per-view calibration, fusion-training and
         fusion-calibration parts.
      2. Train one ``PredictorCNN`` per view and collect class-wise
         calibration scores for it.
      3. Build fusion features for the fusion-train, fusion-calibration and
         test sets from the per-view outputs.
      4. Fit (A) a logistic-regression fusion model and (B) a ``FusionMLP``
         on the same features, and push each through the same
         calibrate -> p-value -> evaluate pipeline at level ``cfg.alpha``.

    Args:
        sim: Zero-based simulation index.

    Returns:
        Dict with keys ``sim``, ``LR_cov(%)``, ``LR_set``, ``MLP_cov(%)``
        and ``MLP_set``; coverages are scaled by 100.
    """
    seed = cfg.seed_base + sim
    set_all_seeds(seed)

    def stratified_split(X_pool, y_pool, first_frac):
        # First part receives ``first_frac`` of the pool; same seed every time.
        return train_test_split(
            X_pool, y_pool,
            test_size=1 - first_frac,
            stratify=y_pool,
            random_state=seed,
        )

    # Splits (match your script)
    X_trP, X_tmp, y_trP, y_tmp = stratified_split(X_train_full, Y_train_full, cfg.train_frac)
    X_cal, X_rest, y_cal, y_rest = stratified_split(X_tmp, y_tmp, cfg.cal_frac_of_temp)
    X_fuse_tr, X_fuse_cal, y_fuse_tr, y_fuse_cal = stratified_split(
        X_rest, y_rest, cfg.fuse_train_frac_of_rest
    )

    # K=4 -> 4 views
    num_views = 4
    # (loader key, images, labels, shuffle?) -- only CNN training is shuffled.
    split_specs = (
        ("train", X_trP, y_trP, True),
        ("cal", X_cal, y_cal, False),
        ("ftr", X_fuse_tr, y_fuse_tr, False),
        ("fcal", X_fuse_cal, y_fuse_cal, False),
        ("te", X_test_full, Y_test_full, False),
    )
    loaders = {
        v: {
            key: torch.utils.data.DataLoader(
                PatchesDataset(X_part, y_part, cfg.K, v),
                batch_size=cfg.batch_size,
                shuffle=do_shuffle,
            )
            for key, X_part, y_part, do_shuffle in split_specs
        }
        for v in range(num_views)
    }

    # Train per-view CNNs + cal1 classwise score sets
    models, cal_classwise = [], []
    print(f"\n=== Sim {sim+1}/{cfg.num_simulations} (seed={seed}) ===")
    for v in range(num_views):
        print(f"  [View {v+1}/{num_views}] training...")
        view_model = train_view_model(
            PredictorCNN(num_classes=cfg.L),
            loaders[v]["train"],
            num_epochs=cfg.epochs_per_view,
            lr=cfg.lr,
        )
        models.append(view_model)
        scores, labels = compute_true_label_scores(view_model, loaders[v]["cal"])
        cal_classwise.append(classwise_scores(scores, labels, cfg.L))

    # Per-view p/probs for fusion-train, fusion-cal2, test.
    # Each entry holds (list of per-view p-values, list of per-view probs).
    per_split = {key: ([], []) for key in ("ftr", "fcal", "te")}
    for v in range(num_views):
        for key, (pvals_by_view, probs_by_view) in per_split.items():
            p, pr = per_view_pvalues_and_probs(
                models[v], cal_classwise[v], loaders[v][key], cfg.L
            )
            pvals_by_view.append(p)
            probs_by_view.append(pr)

    # Fusion features
    X_ftr = build_fusion_features(*per_split["ftr"])
    X_fcal2 = build_fusion_features(*per_split["fcal"])
    X_ftest = build_fusion_features(*per_split["te"])
    d_in = X_ftr.shape[1]

    def conformal_eval(predict_proba):
        # Shared Stage-2 pipeline: calibrate on fusion-cal, score the test set.
        cal_probs = predict_proba(X_fcal2)
        cal_scores = fused_class_cal_scores(y_fuse_cal, cal_probs, cfg.L)
        test_probs = predict_proba(X_ftest)
        fused_pvals = fused_p_values_from_cal(test_probs, cal_scores)
        return eval_sets_from_pvals(fused_pvals, Y_test_full, cfg.alpha)

    # ---------------------------------------------------------
    # (A) Fusion = multinomial LR on [p;prob], then conformalize
    # ---------------------------------------------------------
    fusion_lr = LogisticRegression(
        max_iter=cfg.max_iter_lr,
        multi_class="multinomial",
        solver="lbfgs",
        random_state=seed,
    )
    fusion_lr.fit(X_ftr, y_fuse_tr)
    cov_lr, set_lr = conformal_eval(fusion_lr.predict_proba)

    # ---------------------------------------------------------
    # (B) Fusion = 2-hidden-layer MLP on [p;prob], conformalize
    # ---------------------------------------------------------
    fusion_mlp = train_fusion_mlp(X_ftr, y_fuse_tr, d_in=d_in, L=cfg.L, seed=seed)
    cov_mlp, set_mlp = conformal_eval(
        lambda feats: fusion_mlp_predict_proba(fusion_mlp, feats)
    )

    return {
        "sim": sim,
        "LR_cov(%)": 100.0 * cov_lr,
        "LR_set": set_lr,
        "MLP_cov(%)": 100.0 * cov_mlp,
        "MLP_set": set_mlp,
    }

# =============================================================================
# Run all sims + print small table
# =============================================================================

# One metrics dict per simulation. Kept as an explicit loop so that rows
# gathered so far remain available if a long run is interrupted.
rows: List[Dict[str, float]] = []
for sim in range(cfg.num_simulations):
    sim_metrics = run_one_sim(sim)
    rows.append(sim_metrics)

# Raw results table: one row per simulation.
df = pd.DataFrame(rows)

def mean_std_str(x: pd.Series) -> str:
    """Format a series as ``"mean (std)"`` with two decimals each.

    The standard deviation is the sample one (``ddof=1``).
    """
    center = x.mean()
    spread = x.std(ddof=1)
    return f"{center:.2f} ({spread:.2f})"

# Summary table: one row per fusion model, each cell "mean (std)" over sims.
coverage_cells = [mean_std_str(df["LR_cov(%)"]), mean_std_str(df["MLP_cov(%)"])]
set_size_cells = [mean_std_str(df["LR_set"]), mean_std_str(df["MLP_set"])]
summary = pd.DataFrame({
    "Method": ["Fusion LR", "Fusion 2-layer MLP"],
    "Coverage (%)": coverage_cells,
    "Avg set size": set_size_cells,
})

# Show the raw per-simulation table first, then the aggregated one.
for banner, table in (
    ("\n=== Per-simulation raw results ===", df),
    ("\n=== Summary over simulations (mean (std)) ===", summary),
):
    print(banner)
    print(table)

# Optional: save both tables as CSV in the working directory (no index column).
summary.to_csv("fusion_arch_ablation_k4_summary.csv", index=False)
df.to_csv("fusion_arch_ablation_k4_raw.csv", index=False)
print("\nSaved: fusion_arch_ablation_k4_summary.csv, fusion_arch_ablation_k4_raw.csv")


Using device: cuda
Files already downloaded and verified
Files already downloaded and verified

=== Sim 1/10 (seed=42) ===
  [View 1/4] training...
      view-epoch 25/100
      view-epoch 50/100
      view-epoch 75/100
      view-epoch 100/100
  [View 2/4] training...
      view-epoch 25/100
      view-epoch 50/100
      view-epoch 75/100
      view-epoch 100/100
  [View 3/4] training...
      view-epoch 25/100
      view-epoch 50/100
      view-epoch 75/100
      view-epoch 100/100
  [View 4/4] training...
      view-epoch 25/100
      view-epoch 50/100
      view-epoch 75/100
      view-epoch 100/100

=== Sim 2/10 (seed=43) ===
  [View 1/4] training...
      view-epoch 25/100
      view-epoch 50/100
      view-epoch 75/100
      view-epoch 100/100
  [View 2/4] training...
      view-epoch 25/100
      view-epoch 50/100
      view-epoch 75/100
      view-epoch 100/100
  [View 3/4] training...
      view-epoch 25/100
      view-epoch 50/100
      view-epoch 75/100
      view-epoch 100