In [None]:
"""
Improved Hybrid VQC Ensemble
- Frozen ResNet18 backbone
- 512->4 projector (tanh)
- 4-qubit VQC (diverse depth & entanglers)
- 4->2 linear head (CrossEntropy with label smoothing)
- Bagging x5 + Soft-vote (with temperature calibration)
- Optional stacking on calibrated logits
- Optional feature caching (precompute 512-d features)
"""

import os, random, copy
from pathlib import Path
from typing import List, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, RandomSampler, TensorDataset
from torchvision import datasets, transforms
from torchvision.models import resnet18
from tqdm.auto import tqdm
import numpy as np

# --- Quantum ---
import pennylane as qml

# ----------------------------
# Config & global constants
# ----------------------------
SEED = 123
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)

ROOT = Path("./")
TRAIN_ROOT = ROOT / "_bin_dataset" / "train"
VAL_ROOT   = ROOT / "_bin_dataset" / "val"

ART_DIR = ROOT / "artifacts" / "ensemble"
ART_DIR.mkdir(parents=True, exist_ok=True)

IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD  = [0.229, 0.224, 0.225]
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

N_MODELS = 5
N_EPOCHS = 8
BATCH_TRAIN = 32
BATCH_VAL = 64
LABEL_SMOOTH = 0.05
USE_FEATURE_CACHE = True
CLIP_NORM = 1.0

print(f"Device: {device}")
print(f"Artifacts: {ART_DIR}")


  from .autonotebook import tqdm as notebook_tqdm


Device: cuda
Artifacts: artifacts\ensemble


In [2]:
# ----------------------------
# 1) Load frozen backbone
# ----------------------------
def load_frozen_backbone() -> nn.Module:
    ckpt_backbone = ROOT / "outputs" / "resnet18_backbone_only.pt"
    if not ckpt_backbone.exists():
        alt = ROOT / "resnet18_finetuned.pt"
        if alt.exists():
            ckpt_backbone = alt
        else:
            raise FileNotFoundError("Missing backbone weights (resnet18_backbone_only.pt or resnet18_finetuned.pt).")

    backbone = resnet18(weights=None)
    state = torch.load(ckpt_backbone, map_location="cpu")
    if isinstance(state, dict) and "state_dict" in state:
        state = state["state_dict"]
    # Load, ignore classifier keys
    missing, unexpected = backbone.load_state_dict(state, strict=False)
    # print("missing:", missing, "unexpected:", unexpected)

    backbone.fc = nn.Identity()   # remove head
    backbone.eval().to(device)
    for p in backbone.parameters():
        p.requires_grad_(False)
    return backbone

BACKBONE = load_frozen_backbone()
print("Frozen backbone loaded. Output features: 512")

# ----------------------------
# 2) Model components
# ----------------------------
class L512to4(nn.Module):
    def __init__(self, in_dim=512, hidden_dim=4):
        super().__init__()
        self.fc = nn.Linear(in_dim, hidden_dim)
        self.act = nn.Tanh()
    def forward(self, z):
        return self.act(self.fc(z))

# VQC diversity helpers
n_qubits = 4
from math import pi as PI

def make_entangler(kind: str):
    if kind == "ladder":
        return [(1,2), (0,1), (2,3)]
    if kind == "ring":
        return [(0,1), (1,2), (2,3), (3,0)]
    # random subset
    pairs = [(0,1),(1,2),(2,3),(0,2),(1,3)]
    random.shuffle(pairs)
    return pairs[:4]

class QuantumLayer(nn.Module):
    def __init__(self, depth=None, pairs=None, pattern=None, shots=None):
        super().__init__()
        self.depth = int(depth) if depth is not None else random.choice([6])
        self.pairs = pairs if pairs is not None else make_entangler(pattern or random.choice(["ladder","ring","rand"]))
        self.weights = nn.Parameter(0.01 * torch.randn(self.depth, n_qubits))
        self.dev = qml.device("default.qubit", wires=n_qubits, shots=shots)

        def circuit(x, w):
            for q in range(n_qubits):
                qml.Hadamard(wires=q)
                qml.RY(PI * x[q] / 2.0, wires=q)
            for l in range(self.depth):
                for q in range(n_qubits):
                    qml.RY(w[l, q], wires=q)
                for a,b in self.pairs:
                    qml.CNOT(wires=[a,b])
            return [qml.expval(qml.PauliZ(q)) for q in range(n_qubits)]

        self.qnode = qml.QNode(circuit, self.dev, interface="torch", diff_method="best")

    def forward(self, x4_batch: torch.Tensor) -> torch.Tensor:
        outs = []
        for i in range(x4_batch.shape[0]):
            y = self.qnode(x4_batch[i], self.weights)
            if not isinstance(y, torch.Tensor):
                y = torch.stack(y)
            outs.append(y)
        return torch.stack(outs, dim=0).to(torch.float32)


class L4to2(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)
    def forward(self, z4):
        return self.fc(z4)

class HybridModel(nn.Module):
    """
    Accepts either images [B,3,224,224] OR precomputed features [B,512].
    """
    def __init__(self, backbone, proj, q_layer, head):
        super().__init__()
        self.backbone = backbone    # frozen
        self.proj = proj
        self.q_layer = q_layer
        self.head = head

    def forward(self, x):
        if x.dim() == 2 and x.size(1) == 512:
            z512 = x.to(device)  # precomputed
        else:
            with torch.no_grad():
                z512 = self.backbone(x)
        x4 = self.proj(z512)
        zq = self.q_layer(x4)
        logits = self.head(zq)
        return logits

def create_new_model(depth=None, pairs=None):
    local_backbone = copy.deepcopy(BACKBONE)
    proj = L512to4(512, n_qubits)
    q_layer = QuantumLayer(depth=depth, pairs=pairs)  # <- can fix shape here
    head = L4to2()
    return HybridModel(local_backbone, proj, q_layer, head).to(device)


Frozen backbone loaded. Output features: 512


In [3]:
# ----------------------------
# 3) Data & transforms
# ----------------------------
tfm = transforms.Compose([
    transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

train_ds = datasets.ImageFolder(TRAIN_ROOT, transform=tfm)
val_ds   = datasets.ImageFolder(VAL_ROOT,   transform=tfm)

def precompute_features(dataset) -> TensorDataset:
    loader = DataLoader(dataset, batch_size=BATCH_VAL, shuffle=False, num_workers=0)
    feats, labels = [], []
    with torch.no_grad():
        for imgs, y in tqdm(loader, desc="Precompute backbone features"):
            z = BACKBONE(imgs.to(device))     # frozen backbone
            feats.append(z.cpu())
            labels.append(y)
    Z = torch.cat(feats).contiguous()
    Y = torch.cat(labels).contiguous()
    return TensorDataset(Z, Y)

if USE_FEATURE_CACHE:
    train_feat = precompute_features(train_ds)
    val_feat   = precompute_features(val_ds)
    # Bagged loaders from features
    n_samples = len(train_feat)
    train_loaders = []
    for i in range(N_MODELS):
        sampler = RandomSampler(train_feat, replacement=True, num_samples=n_samples,
                                generator=torch.Generator().manual_seed(SEED + i))
        train_loaders.append(DataLoader(train_feat, batch_size=BATCH_TRAIN, sampler=sampler, num_workers=0))
    val_loader = DataLoader(val_feat, batch_size=BATCH_VAL, shuffle=False, num_workers=0)
else:
    # End-to-end (slower)
    n_samples = len(train_ds)
    train_loaders = []
    for i in range(N_MODELS):
        sampler = RandomSampler(train_ds, replacement=True, num_samples=n_samples,
                                generator=torch.Generator().manual_seed(SEED + i))
        train_loaders.append(DataLoader(train_ds, batch_size=BATCH_TRAIN, sampler=sampler, num_workers=0))
    val_loader = DataLoader(val_ds, batch_size=BATCH_VAL, shuffle=False, num_workers=0)

print(f"Dataset: train={len(train_ds)}, val={len(val_ds)}")
print(f"Created {len(train_loaders)} bagged train loaders; 1 validation loader.")


Precompute backbone features:   0%|          | 0/8 [00:00<?, ?it/s]

Precompute backbone features: 100%|██████████| 8/8 [00:01<00:00,  4.59it/s]
Precompute backbone features: 100%|██████████| 2/2 [00:00<00:00,  5.32it/s]


Dataset: train=480, val=120
Created 5 bagged train loaders; 1 validation loader.


In [None]:
# ----------------------------
# 4) Train loop + utilities
# ----------------------------
def accuracy_from_logits(logits, labels):
    preds = logits.argmax(dim=1)
    return (preds == labels).float().mean().item()

def run_epoch(model, loader, optimizer, scheduler, ep, train=True):
    model.train(train)
    loss_sum, correct, total = 0.0, 0, 0
    bar = tqdm(loader, desc=f"Epoch {ep:02d} ({'Train' if train else 'Val'})")
    crit = nn.CrossEntropyLoss(label_smoothing=LABEL_SMOOTH if train else 0.0)
    for xb, yb in bar:
        xb, yb = xb.to(device), yb.to(device)
        if train:
            optimizer.zero_grad(set_to_none=True)
        with torch.set_grad_enabled(train):
            logits = model(xb)
            loss = crit(logits, yb)
        if train:
            loss.backward()
            nn.utils.clip_grad_norm_(model.q_layer.parameters(), CLIP_NORM)
            optimizer.step()
        loss_sum += loss.item() * xb.size(0)
        acc = (logits.argmax(1) == yb).sum().item()
        correct += acc
        total += xb.size(0)
        bar.set_postfix(loss=loss_sum/total, acc=correct/total)
    if train and scheduler is not None:
        scheduler.step()
    return loss_sum/total, correct/total

@torch.no_grad()
def collect_logits(model, loader) -> Tuple[torch.Tensor, torch.Tensor]:
    all_logits, all_y = [], []
    for xb, yb in loader:
        xb = xb.to(device)
        logits = model(xb)
        all_logits.append(logits.cpu())
        all_y.append(yb)
    return torch.cat(all_logits), torch.cat(all_y)

# Temperature calibration per model on the validation set
class TempScale(nn.Module):
    def __init__(self):
        super().__init__()
        self.t = nn.Parameter(torch.ones(1))
    def forward(self, logits):
        return logits / self.t.clamp_min(1e-3)

def fit_temperature(model, val_loader) -> float:
    model.eval()
    logits, y = collect_logits(model, val_loader)
    ts = TempScale().to(logits.device)
    opt = torch.optim.Adam(ts.parameters(), lr=0.05)
    crit = nn.CrossEntropyLoss()
    for _ in range(300):
        opt.zero_grad()
        loss = crit(ts(logits), y)
        loss.backward()
        opt.step()
    return float(ts.t.detach().cpu().item())


In [5]:
# ----------------------------
# 5) Train ensemble
# ----------------------------
print(f"\nStarting training for {N_MODELS} models...")
cal_temps = []  # temperature per model for calibration

for k in range(N_MODELS):
    print(f"\n{'='*30}\nTraining Model {k+1}/{N_MODELS}\n{'='*30}")
    torch.manual_seed(SEED + k)
    model_k = create_new_model()

    optimizer = torch.optim.Adam([
        {"params": model_k.proj.parameters(), "lr": 1e-3},
        {"params": model_k.q_layer.parameters(), "lr": 1e-2},
        {"params": model_k.head.parameters(), "lr": 1e-3},
    ])
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=N_EPOCHS)

    train_loader_k = train_loaders[k]

    best_va, best_sd = -1.0, None
    patience, bad = 3, 0

    for ep in range(1, N_EPOCHS + 1):
        trL, trA = run_epoch(model_k, train_loader_k, optimizer, scheduler, ep, train=True)
        vaL, vaA = run_epoch(model_k, val_loader, optimizer, scheduler, ep, train=False)
        print(f"[Model {k+1} Ep {ep:02d}] Train L/A: {trL:.4f}/{trA:.3f} | Val L/A: {vaL:.4f}/{vaA:.3f}")
        if vaA > best_va:
            best_va = vaA
            best_sd = copy.deepcopy(model_k.state_dict())
            bad = 0
        else:
            bad += 1
            if bad >= patience:
                print("Early stopping.")
                break

    save_path = ART_DIR / f"model_{k}.pt"
    if best_sd is not None:
        torch.save(best_sd, save_path)
        print(f"Saved best weights to {save_path}")
    # Save best checkpoint
    # after torch.save(best_sd, save_path)
    cfg_path = ART_DIR / f"model_{k}_cfg.pt"
    torch.save({"depth": model_k.q_layer.depth, "pairs": model_k.q_layer.pairs}, cfg_path)
    print(f"Saved model cfg to {cfg_path}: depth={model_k.q_layer.depth}, pairs={model_k.q_layer.pairs}")


    # Temperature calibration on validation set
    t_k = fit_temperature(model_k, val_loader)
    cal_temps.append(t_k)
    torch.save({"temperature": t_k}, ART_DIR / f"model_{k}_temp.pt")
    print(f"Calibrated temperature t={t_k:.3f} saved.")

print("\nEnsemble training complete.")


Starting training for 5 models...

Training Model 1/5


Epoch 01 (Train): 100%|██████████| 15/15 [00:26<00:00,  1.75s/it, acc=0.912, loss=0.391]
Epoch 01 (Val): 100%|██████████| 2/2 [00:03<00:00,  1.71s/it, acc=0.983, loss=0.209]


[Model 1 Ep 01] Train L/A: 0.3906/0.912 | Val L/A: 0.2088/0.983


Epoch 02 (Train): 100%|██████████| 15/15 [00:25<00:00,  1.69s/it, acc=0.996, loss=0.199]
Epoch 02 (Val): 100%|██████████| 2/2 [00:03<00:00,  1.67s/it, acc=0.983, loss=0.151]


[Model 1 Ep 02] Train L/A: 0.1990/0.996 | Val L/A: 0.1509/0.983


Epoch 03 (Train): 100%|██████████| 15/15 [00:26<00:00,  1.77s/it, acc=1, loss=0.176]
Epoch 03 (Val): 100%|██████████| 2/2 [00:03<00:00,  1.61s/it, acc=0.992, loss=0.139]


[Model 1 Ep 03] Train L/A: 0.1759/1.000 | Val L/A: 0.1386/0.992


Epoch 04 (Train): 100%|██████████| 15/15 [00:28<00:00,  1.93s/it, acc=1, loss=0.168]
Epoch 04 (Val): 100%|██████████| 2/2 [00:04<00:00,  2.06s/it, acc=0.992, loss=0.13]


[Model 1 Ep 04] Train L/A: 0.1683/1.000 | Val L/A: 0.1296/0.992


Epoch 05 (Train): 100%|██████████| 15/15 [00:26<00:00,  1.76s/it, acc=0.998, loss=0.165]
Epoch 05 (Val): 100%|██████████| 2/2 [00:03<00:00,  1.56s/it, acc=0.992, loss=0.124]


[Model 1 Ep 05] Train L/A: 0.1651/0.998 | Val L/A: 0.1240/0.992


Epoch 06 (Train): 100%|██████████| 15/15 [00:25<00:00,  1.70s/it, acc=1, loss=0.158]
Epoch 06 (Val): 100%|██████████| 2/2 [00:03<00:00,  1.55s/it, acc=0.992, loss=0.121]


[Model 1 Ep 06] Train L/A: 0.1580/1.000 | Val L/A: 0.1211/0.992
Early stopping.
Saved best weights to artifacts\ensemble\model_0.pt
Saved model cfg to artifacts\ensemble\model_0_cfg.pt: depth=6, pairs=[(0, 1), (1, 2), (2, 3), (3, 0)]
Calibrated temperature t=0.457 saved.

Training Model 2/5


Epoch 01 (Train): 100%|██████████| 15/15 [00:26<00:00,  1.78s/it, acc=0.89, loss=0.36]  
Epoch 01 (Val): 100%|██████████| 2/2 [00:03<00:00,  1.61s/it, acc=0.992, loss=0.251]


[Model 2 Ep 01] Train L/A: 0.3600/0.890 | Val L/A: 0.2511/0.992


Epoch 02 (Train): 100%|██████████| 15/15 [00:26<00:00,  1.75s/it, acc=0.996, loss=0.267]
Epoch 02 (Val): 100%|██████████| 2/2 [00:03<00:00,  1.77s/it, acc=0.992, loss=0.224]


[Model 2 Ep 02] Train L/A: 0.2673/0.996 | Val L/A: 0.2237/0.992


Epoch 03 (Train): 100%|██████████| 15/15 [00:28<00:00,  1.91s/it, acc=0.998, loss=0.251]
Epoch 03 (Val): 100%|██████████| 2/2 [00:03<00:00,  1.66s/it, acc=0.992, loss=0.204]


[Model 2 Ep 03] Train L/A: 0.2513/0.998 | Val L/A: 0.2042/0.992


Epoch 04 (Train): 100%|██████████| 15/15 [00:31<00:00,  2.10s/it, acc=1, loss=0.226]
Epoch 04 (Val): 100%|██████████| 2/2 [00:03<00:00,  1.72s/it, acc=0.992, loss=0.192]


[Model 2 Ep 04] Train L/A: 0.2261/1.000 | Val L/A: 0.1918/0.992
Early stopping.
Saved best weights to artifacts\ensemble\model_1.pt
Saved model cfg to artifacts\ensemble\model_1_cfg.pt: depth=6, pairs=[(0, 1), (1, 2), (2, 3), (3, 0)]
Calibrated temperature t=0.282 saved.

Training Model 3/5


Epoch 01 (Train): 100%|██████████| 15/15 [00:26<00:00,  1.75s/it, acc=0.883, loss=0.438]
Epoch 01 (Val): 100%|██████████| 2/2 [00:03<00:00,  1.53s/it, acc=0.983, loss=0.307]


[Model 3 Ep 01] Train L/A: 0.4381/0.883 | Val L/A: 0.3072/0.983


Epoch 02 (Train): 100%|██████████| 15/15 [00:26<00:00,  1.76s/it, acc=1, loss=0.311]
Epoch 02 (Val): 100%|██████████| 2/2 [00:03<00:00,  1.57s/it, acc=0.983, loss=0.281]


[Model 3 Ep 02] Train L/A: 0.3110/1.000 | Val L/A: 0.2809/0.983


Epoch 03 (Train): 100%|██████████| 15/15 [00:26<00:00,  1.80s/it, acc=1, loss=0.288]
Epoch 03 (Val): 100%|██████████| 2/2 [00:03<00:00,  1.53s/it, acc=0.983, loss=0.259]


[Model 3 Ep 03] Train L/A: 0.2879/1.000 | Val L/A: 0.2592/0.983


Epoch 04 (Train): 100%|██████████| 15/15 [00:26<00:00,  1.77s/it, acc=1, loss=0.269]
Epoch 04 (Val): 100%|██████████| 2/2 [00:03<00:00,  1.73s/it, acc=0.983, loss=0.244]


[Model 3 Ep 04] Train L/A: 0.2694/1.000 | Val L/A: 0.2438/0.983
Early stopping.
Saved best weights to artifacts\ensemble\model_2.pt
Saved model cfg to artifacts\ensemble\model_2_cfg.pt: depth=6, pairs=[(1, 2), (0, 1), (2, 3)]
Calibrated temperature t=0.280 saved.

Training Model 4/5


Epoch 01 (Train): 100%|██████████| 15/15 [00:27<00:00,  1.85s/it, acc=0.952, loss=0.327]
Epoch 01 (Val): 100%|██████████| 2/2 [00:03<00:00,  1.69s/it, acc=0.983, loss=0.195]


[Model 4 Ep 01] Train L/A: 0.3268/0.952 | Val L/A: 0.1947/0.983


Epoch 02 (Train): 100%|██████████| 15/15 [00:29<00:00,  1.94s/it, acc=0.992, loss=0.22] 
Epoch 02 (Val): 100%|██████████| 2/2 [00:03<00:00,  1.95s/it, acc=0.992, loss=0.176]


[Model 4 Ep 02] Train L/A: 0.2202/0.992 | Val L/A: 0.1757/0.992


Epoch 03 (Train): 100%|██████████| 15/15 [00:29<00:00,  1.94s/it, acc=0.998, loss=0.196]
Epoch 03 (Val): 100%|██████████| 2/2 [00:03<00:00,  1.88s/it, acc=0.992, loss=0.161]


[Model 4 Ep 03] Train L/A: 0.1963/0.998 | Val L/A: 0.1614/0.992


Epoch 04 (Train): 100%|██████████| 15/15 [00:28<00:00,  1.93s/it, acc=0.998, loss=0.191]
Epoch 04 (Val): 100%|██████████| 2/2 [00:03<00:00,  1.69s/it, acc=0.992, loss=0.152]


[Model 4 Ep 04] Train L/A: 0.1911/0.998 | Val L/A: 0.1522/0.992


Epoch 05 (Train): 100%|██████████| 15/15 [00:32<00:00,  2.19s/it, acc=0.998, loss=0.184]
Epoch 05 (Val): 100%|██████████| 2/2 [00:03<00:00,  1.77s/it, acc=0.992, loss=0.146]


[Model 4 Ep 05] Train L/A: 0.1835/0.998 | Val L/A: 0.1462/0.992
Early stopping.
Saved best weights to artifacts\ensemble\model_3.pt
Saved model cfg to artifacts\ensemble\model_3_cfg.pt: depth=6, pairs=[(0, 1), (1, 2), (2, 3), (3, 0)]
Calibrated temperature t=0.369 saved.

Training Model 5/5


Epoch 01 (Train): 100%|██████████| 15/15 [00:28<00:00,  1.93s/it, acc=0.779, loss=0.46] 
Epoch 01 (Val): 100%|██████████| 2/2 [00:03<00:00,  1.67s/it, acc=0.983, loss=0.26]


[Model 5 Ep 01] Train L/A: 0.4604/0.779 | Val L/A: 0.2598/0.983


Epoch 02 (Train): 100%|██████████| 15/15 [00:28<00:00,  1.92s/it, acc=0.994, loss=0.271]
Epoch 02 (Val): 100%|██████████| 2/2 [00:03<00:00,  1.80s/it, acc=0.992, loss=0.234]


[Model 5 Ep 02] Train L/A: 0.2715/0.994 | Val L/A: 0.2336/0.992


Epoch 03 (Train): 100%|██████████| 15/15 [00:29<00:00,  1.94s/it, acc=0.998, loss=0.248]
Epoch 03 (Val): 100%|██████████| 2/2 [00:03<00:00,  1.95s/it, acc=0.992, loss=0.214]


[Model 5 Ep 03] Train L/A: 0.2482/0.998 | Val L/A: 0.2139/0.992


Epoch 04 (Train): 100%|██████████| 15/15 [00:30<00:00,  2.02s/it, acc=0.998, loss=0.235]
Epoch 04 (Val): 100%|██████████| 2/2 [00:03<00:00,  1.84s/it, acc=0.992, loss=0.202]


[Model 5 Ep 04] Train L/A: 0.2353/0.998 | Val L/A: 0.2020/0.992


Epoch 05 (Train): 100%|██████████| 15/15 [00:38<00:00,  2.55s/it, acc=1, loss=0.223]
Epoch 05 (Val): 100%|██████████| 2/2 [00:04<00:00,  2.24s/it, acc=0.992, loss=0.194]


[Model 5 Ep 05] Train L/A: 0.2228/1.000 | Val L/A: 0.1941/0.992
Early stopping.
Saved best weights to artifacts\ensemble\model_4.pt
Saved model cfg to artifacts\ensemble\model_4_cfg.pt: depth=6, pairs=[(0, 1), (1, 2), (2, 3), (3, 0)]
Calibrated temperature t=0.319 saved.

Ensemble training complete.


In [None]:
def load_base_models():
    models, temps = [], []
    for k in range(N_MODELS):
        path = ART_DIR / f"model_{k}.pt"
        if not path.exists():
            print(f"Missing {path}, skipping.")
            continue

        # 1) peek state_dict to recover depth from q_layer.weights shape
        sd = torch.load(path, map_location="cpu")
        if "q_layer.weights" in sd:
            depth = sd["q_layer.weights"].shape[0]
        else:
            depth = next(v.shape[0] for k2, v in sd.items() if k2.endswith("q_layer.weights"))

        # 2) try to read saved entangler pairs
        cfg_path = ART_DIR / f"model_{k}_cfg.pt"
        pairs = None
        if cfg_path.exists():
            cfg = torch.load(cfg_path, map_location="cpu")
            pairs = cfg.get("pairs", None)

        # 3) instantiate with matching depth/pairs, then load
        m = create_new_model(depth=depth, pairs=pairs)
        m.load_state_dict(sd, strict=True)
        m.eval().to(device)
        models.append(m)

        # load temperature if present
        tpath = ART_DIR / f"model_{k}_temp.pt"
        if tpath.exists():
            temps.append(float(torch.load(tpath, map_location="cpu")["temperature"]))
        else:
            temps.append(1.0)
    return models, temps


base_models, base_temps = load_base_models()
print(f"Loaded {len(base_models)} base models for ensemble.")

@torch.no_grad()
def predict_soft_vote(x, models_list, temps_list, return_probs=False):
    """x: image tensor [1,3,224,224] or features [1,512]"""
    probs = []
    for m, t in zip(models_list, temps_list):
        logits = m(x.to(device)) / max(t, 1e-3)
        probs.append(F.softmax(logits, dim=1))
    avg = torch.mean(torch.stack(probs), dim=0)
    pred_idx = int(avg.argmax(dim=1).item())
    conf = float(avg.max(dim=1).values.item())
    if return_probs:
        return pred_idx, conf, avg.cpu().numpy()
    return pred_idx, conf

# Quick single-image test
TEST_IMG = ROOT / "test1" / "1.jpg"
if TEST_IMG.exists() and len(base_models) > 0:
    from PIL import Image
    tf_single = tfm
    img = Image.open(TEST_IMG).convert("RGB")
    xb = tf_single(img).unsqueeze(0)
    pred_idx, conf, avg_probs = predict_soft_vote(xb, base_models, base_temps, True)
    classes = val_ds.classes
    print(f"\n[Soft-vote] {TEST_IMG.name}: {classes[pred_idx]}  (conf {conf:.4f})")
    # print("avg probs:", avg_probs)

# ----------------------------
# 7) Stacking on calibrated logits
# ----------------------------
@torch.no_grad()
def create_meta_dataset(models_list, temps_list, loader) -> Tuple[torch.Tensor, torch.Tensor]:
    print("Creating meta-dataset (calibrated logits on validation set)...")
    Xs, Ys = [], []
    for xb, yb in tqdm(loader, desc="Meta-features"):
        xb = xb.to(device)
        per_model = []
        for m, t in zip(models_list, temps_list):
            l = m(xb) / max(t, 1e-3)       # calibrated logits
            per_model.append(l)            # [B,2]
        Xb = torch.cat(per_model, dim=1)   # [B, 2*N_MODELS]
        Xs.append(Xb.cpu())
        Ys.append(yb.cpu())
    return torch.cat(Xs), torch.cat(Ys)

if len(base_models) > 0:
    X_meta, y_meta = create_meta_dataset(base_models, base_temps, val_loader)
    meta_ds = TensorDataset(X_meta, y_meta)
    meta_loader = DataLoader(meta_ds, batch_size=64, shuffle=True, num_workers=0)

    meta_learner = nn.Linear(2 * len(base_models), 2).to(device)
    opt_meta = torch.optim.Adam(meta_learner.parameters(), lr=1e-3)
    crit_meta = nn.CrossEntropyLoss()

    print("Training meta-learner on calibrated logits...")
    meta_learner.train()
    for ep in range(1, 11):
        loss_sum, n = 0.0, 0
        for Xb, yb in meta_loader:
            Xb, yb = Xb.to(device), yb.to(device)
            opt_meta.zero_grad(set_to_none=True)
            out = meta_learner(Xb)
            loss = crit_meta(out, yb)
            loss.backward()
            opt_meta.step()
            loss_sum += loss.item() * Xb.size(0); n += Xb.size(0)
        if ep == 1 or ep % 2 == 0:
            print(f"[Meta Ep {ep:02d}] loss={loss_sum/n:.4f}")
    meta_learner.eval()
    torch.save(meta_learner.state_dict(), ART_DIR / "meta_learner.pt")
    print("Meta-learner saved.")

    class StackedEnsemble(nn.Module):
        def __init__(self, bases, temps, meta):
            super().__init__()
            self.bases = nn.ModuleList(bases)
            for m in self.bases:
                for p in m.parameters(): p.requires_grad = False
            self.temps = [float(t) for t in temps]
            self.meta = meta.eval()
        @torch.no_grad()
        def forward(self, x):
            feats = []
            for m, t in zip(self.bases, self.temps):
                l = m(x.to(device)) / max(t, 1e-3)
                feats.append(l)
            meta_in = torch.cat(feats, dim=1)
            return self.meta(meta_in)

    stacked = StackedEnsemble(base_models, base_temps, meta_learner).to(device)

    # Optional single-image check
    if TEST_IMG.exists():
        from PIL import Image
        xb = tfm(Image.open(TEST_IMG).convert("RGB")).unsqueeze(0)
        with torch.no_grad():
            logits = stacked(xb)
            probs = F.softmax(logits, dim=1)
            pred = int(probs.argmax(1).item()); conf = float(probs.max(1).values.item())
        print(f"\n[Stacked]  {TEST_IMG.name}: {val_ds.classes[pred]}  (conf {conf:.4f})")

Loaded 5 base models for ensemble.

[Soft-vote] 1.jpg: Positive  (conf 0.7348)
Creating meta-dataset (calibrated logits on validation set)...


Meta-features: 100%|██████████| 2/2 [00:22<00:00, 11.11s/it]


Training meta-learner on calibrated logits...
[Meta Ep 01] loss=0.3625
[Meta Ep 02] loss=0.3373
[Meta Ep 04] loss=0.2905
[Meta Ep 06] loss=0.2509
[Meta Ep 08] loss=0.2182
[Meta Ep 10] loss=0.1907
Meta-learner saved.

[Stacked]  1.jpg: Positive  (conf 0.6601)
