## Notebook Overview
- Load the hybrid checkpoint and rebuild the seeded validation split used for evaluation.
- Compare the classical-only mask (Classical A + B) with the full hybrid mask (Classical A + B + Quantum) on accuracy, latency, and confusion matrices.
- Capture per-head confidence and saliency visualisations for sample images to support benchmarking and publication figures.

## Notebook Overview
- Load the hybrid checkpoint and rebuild the evaluation data split.
- Compare classical-only vs hybrid masks across accuracy, latency, and confusion matrices.
- Capture sample confidence visualisations to support benchmarking and publication figures.

In [None]:
# Install runtime dependencies (no-op if already present)
!pip install -q kagglehub qiskit qiskit_machine_learning > /dev/null 2>&1

In [None]:
# Imports
import os
import time
import pickle
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from torchvision import transforms

from sklearn.metrics import classification_report, confusion_matrix

from hybrid_components import (
    DiabeticRetinopathyDataset,
    ClassicalDRModel,
    tensor_to_numpy_image,
    create_confusion_plot,
    create_head_confidence_figure,
    generate_saliency_overlay
)

In [None]:
# Configuration
# Evaluation settings: dataset, loader sizing, seeded validation split, checkpoint and device.
CONFIG = dict(
    dataset="sovitrath",
    batch_size=32,
    num_workers=2,
    val_fraction=0.2,
    split_seed=42,
    checkpoint_path="complete_checkpoint.pth",
    device="cuda" if torch.cuda.is_available() else "cpu",
    # Head-activation masks, ordered (Classical A, Classical B, Quantum).
    class_mask_scenarios=dict(
        classical_only=[1, 1, 0],
        hybrid_full=[1, 1, 1],
    ),
)

# Deterministic evaluation preprocessing: fixed 224x224 resize, tensor conversion,
# then per-channel normalisation with the ImageNet mean/std.
EVAL_TRANSFORM = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

In [None]:
def compute_accuracy(preds, targets):
    """Return the fraction of positions where ``preds`` equals ``targets``.

    Accepts Python sequences, NumPy arrays or torch tensors (on any device).

    Args:
        preds: Predicted labels.
        targets: Ground-truth labels, same shape as ``preds``.

    Returns:
        Accuracy as a Python float; 0.0 when ``targets`` is empty.

    Raises:
        ValueError: if ``preds`` and ``targets`` have different shapes.
    """
    def _as_array(values):
        # Tensors may live on an accelerator; bring them to host memory before NumPy conversion.
        if isinstance(values, torch.Tensor):
            return values.detach().cpu().numpy()
        return np.asarray(values)

    pred_arr = _as_array(preds)
    target_arr = _as_array(targets)
    if pred_arr.shape != target_arr.shape:
        raise ValueError(
            f"preds and targets must have the same shape, got {pred_arr.shape} and {target_arr.shape}"
        )
    if target_arr.size == 0:
        return 0.0
    return float((pred_arr == target_arr).sum() / target_arr.size)

# Maps each reported head name to the model output key that holds its logits.
_EVAL_HEAD_OUTPUTS = {
    'Classical A': 'output_a',
    'Classical B': 'output_b',
    'Quantum': 'output_c',
    'Ensemble': 'final_output',
}


def _first_batch_samples(images, labels, outputs, batch_preds, limit=2):
    """Package up to ``limit`` samples of one batch with per-head softmax probabilities and predictions."""
    # Only the kept rows are ever visualised, so softmax just those.
    head_probs = {
        head: F.softmax(outputs[key][:limit], dim=1)
        for head, key in _EVAL_HEAD_OUTPUTS.items()
    }
    samples = []
    for idx in range(min(limit, images.size(0))):
        samples.append({
            'image': images[idx].detach().cpu(),
            'head_probs': {head: probs[idx].detach().cpu().numpy() for head, probs in head_probs.items()},
            'target': labels[idx].item(),
            'predictions': {head: preds[idx].item() for head, preds in batch_preds.items()},
        })
    return samples


def evaluate_with_mask(model, dataloader, device, mask, scenario_name, class_names):
    """Evaluate every head of ``model`` on ``dataloader`` under a head-activation mask.

    Args:
        model: Model called as ``model(images, return_all=True, active_mask=...)``;
            the result must provide ``output_a``, ``output_b``, ``output_c`` and
            ``final_output`` logits plus a ``latencies`` mapping.
        dataloader: Iterable of ``(images, labels, paths)`` batches.
        device: Device that images, labels and the mask tensor are moved to.
        mask: Sequence whose first three entries enable (when > 0) Classical A,
            Classical B and Quantum respectively. The ensemble is always scored.
        scenario_name: Label shown in the progress bar.
        class_names: Class labels, index-aligned with the model's output logits.

    Returns:
        dict with:
            'acc': accuracies keyed 'a', 'b', 'c', 'ensemble'. Inactive heads are
                NaN; active heads report 0.0 if the loader produced no samples.
            'latencies': per-batch averages of the model-reported latencies,
                in whatever unit the model reports.
            'confusion': head name -> sklearn confusion matrix over all class
                indices, or None for inactive heads.
            'report': sklearn classification report for the ensemble.
            'sample_visuals': up to two samples from the first batch, each with
                'image', 'head_probs', 'target', 'predictions' and 'saliency_figure'.
    """
    model.eval()
    mask_tensor = torch.tensor(mask, device=device, dtype=torch.float32)
    # Read the mask back to the host once rather than comparing device tensors every batch.
    active_flags = (mask_tensor > 0).tolist()
    head_active = {
        'Classical A': active_flags[0],
        'Classical B': active_flags[1],
        'Quantum': active_flags[2],
        'Ensemble': True,
    }
    label_ids = list(range(len(class_names)))

    total_samples = 0
    num_batches = 0
    correct_counts = {head: 0 for head in _EVAL_HEAD_OUTPUTS}
    predictions = {head: [] for head in _EVAL_HEAD_OUTPUTS}
    latencies = {"classical_a": 0.0, "classical_b": 0.0, "quantum": 0.0, "ensemble": 0.0}
    targets = []
    sample_visuals = []
    for batch_idx, (images, labels, _paths) in enumerate(tqdm(dataloader, desc=f'Evaluating [{scenario_name}]')):
        num_batches += 1
        images = images.to(device)
        labels = labels.to(device)
        with torch.no_grad():
            outputs = model(images, return_all=True, active_mask=mask_tensor)
        batch_preds = {
            head: torch.argmax(outputs[key], dim=1)
            for head, key in _EVAL_HEAD_OUTPUTS.items()
        }
        total_samples += labels.size(0)
        targets.extend(labels.cpu().tolist())
        for head, preds in batch_preds.items():
            predictions[head].extend(preds.cpu().tolist())
            if head_active[head]:
                correct_counts[head] += (preds == labels).sum().item()
        for key in latencies:
            latencies[key] += outputs['latencies'].get(key, 0.0)
        if batch_idx == 0:
            sample_visuals = _first_batch_samples(images, labels, outputs, batch_preds)

    latencies = {key: value / max(num_batches, 1) for key, value in latencies.items()}

    def _accuracy(head):
        if not head_active[head]:
            return float('nan')
        # Report 0.0 when no samples were seen instead of dividing by zero.
        return correct_counts[head] / total_samples if total_samples > 0 else 0.0

    targets_np = np.array(targets)
    results = {
        'acc': {
            'a': _accuracy('Classical A'),
            'b': _accuracy('Classical B'),
            'c': _accuracy('Quantum'),
            'ensemble': _accuracy('Ensemble'),
        },
        'latencies': latencies,
        'confusion': {
            head: confusion_matrix(targets_np, predictions[head], labels=label_ids) if head_active[head] else None
            for head in _EVAL_HEAD_OUTPUTS
        },
        # Explicit labels keep target_names aligned even when a class never occurs in the split;
        # without them sklearn raises on the length mismatch.
        'report': classification_report(
            targets_np, predictions['Ensemble'], labels=label_ids,
            target_names=class_names, zero_division=0,
        ),
        'sample_visuals': [],
    }
    for sample in sample_visuals:
        sample['saliency_figure'] = generate_saliency_overlay(
            model, sample['image'], sample['predictions']['Ensemble'], device, mask_tensor
        )
        results['sample_visuals'].append(sample)
    return results

In [None]:
# Load checkpoint and prepare dataloader
checkpoint_path = Path(CONFIG['checkpoint_path'])
if not checkpoint_path.exists():
    raise FileNotFoundError(f'Checkpoint not found: {checkpoint_path}')
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code;
# only point CONFIG['checkpoint_path'] at checkpoints from a trusted source.
checkpoint = torch.load(checkpoint_path, map_location=CONFIG['device'])
model_info = checkpoint['model_info']
model = ClassicalDRModel(
    encoder_type=model_info['encoder_type'],
    num_classes=model_info['num_classes'],
    compressed_dim=model_info['compressed_dim']
)
model.load_state_dict(checkpoint['model_state_dict'])
model = model.to(CONFIG['device'])
# Fall back to stringified indices when the checkpoint stores no class names.
class_names = model_info.get('classes', [str(i) for i in range(model_info['num_classes'])])
model.classes = class_names
model.enable_latency_tracking(True)
model.eval()

dataset = DiabeticRetinopathyDataset(dataset_type=CONFIG['dataset'], mode='benchmark', transform=EVAL_TRANSFORM)
if len(dataset) == 0:
    # Fail fast: an empty split would otherwise only surface as an obscure error inside evaluation.
    raise ValueError(f"Dataset '{CONFIG['dataset']}' (mode='benchmark') contains no samples to evaluate")
val_size = int(CONFIG['val_fraction'] * len(dataset))
train_size = len(dataset) - val_size
if val_size == 0:
    # Too few samples to hold out val_fraction: evaluate on every sample instead.
    val_size = len(dataset)
    train_size = 0
# Seeded generator so the held-out subset is reproducible across runs
# (presumably matching the training notebook's split - confirm).
generator = torch.Generator().manual_seed(CONFIG['split_seed'])
if train_size > 0:
    _, val_subset = random_split(dataset, [train_size, val_size], generator=generator)
else:
    val_subset = dataset
val_loader = DataLoader(val_subset, batch_size=CONFIG['batch_size'], shuffle=False, num_workers=CONFIG['num_workers'])
print(f'Validation samples: {len(val_subset)}')

scenario_metrics = {}
rows = []
for scenario_name, mask in CONFIG['class_mask_scenarios'].items():
    metrics = evaluate_with_mask(model, val_loader, CONFIG['device'], mask, scenario_name, class_names)
    scenario_metrics[scenario_name] = metrics
    latencies = metrics['latencies']
    rows.append({
        'scenario': scenario_name,
        'ensemble_accuracy': metrics['acc']['ensemble'],
        'classical_a_accuracy': metrics['acc']['a'],
        'classical_b_accuracy': metrics['acc']['b'],
        'quantum_accuracy': metrics['acc']['c'],
        # Latencies are assumed to be reported in seconds; shown here in milliseconds.
        'latency_classical_a_ms': latencies['classical_a'] * 1e3,
        'latency_classical_b_ms': latencies['classical_b'] * 1e3,
        'latency_quantum_ms': latencies['quantum'] * 1e3,
        # The ensemble latency is tracked by evaluate_with_mask too; report it alongside the heads.
        'latency_ensemble_ms': latencies['ensemble'] * 1e3,
    })
results_df = pd.DataFrame(rows)
display(results_df)

In [None]:
# Confusion matrices
# One plot per active head (inactive heads carry None), followed by the ensemble report.
for scenario_name, metrics in scenario_metrics.items():
    print(f'=== {scenario_name} ===')
    plotted = ((head, matrix) for head, matrix in metrics['confusion'].items() if matrix is not None)
    for head_name, cm in plotted:
        fig = create_confusion_plot(cm, class_names, f'{scenario_name} - {head_name}')
        display(fig)
        plt.close(fig)
    print(metrics['report'])

In [None]:
# Sample confidence and saliency views
# For each scenario, show the first captured sample's per-head confidence figure,
# then its saliency overlay when one was produced.
for scenario_name, metrics in scenario_metrics.items():
    samples = metrics.get('sample_visuals', [])
    if samples:
        sample = samples[0]
        title = f"{scenario_name} | target={class_names[sample['target']]}"
        fig = create_head_confidence_figure(sample['image'], sample['head_probs'], class_names, title)
        display(fig)
        plt.close(fig)
        saliency = sample.get('saliency_figure')
        if saliency is not None:
            display(saliency)
            plt.close(saliency)

In [None]:
# Install runtime dependencies (no-op if already present)
!pip install -q kagglehub qiskit qiskit_machine_learning > /dev/null 2>&1

In [None]:
# Imports
import os
import time
import pickle
import warnings
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms, models

from sklearn.metrics import classification_report, confusion_matrix

import kagglehub
from qiskit.circuit.library import RealAmplitudes, ZZFeatureMap
from qiskit.primitives import Sampler
from qiskit_machine_learning.neural_networks import SamplerQNN
from qiskit_machine_learning.connectors import TorchConnector

In [None]:
# Configuration
# Evaluation settings: dataset, loader sizing, seeded validation split, checkpoint and device.
CONFIG = dict(
    dataset="sovitrath",
    batch_size=32,
    num_workers=2,
    val_fraction=0.2,
    split_seed=42,
    checkpoint_path="complete_checkpoint.pth",
    device="cuda" if torch.cuda.is_available() else "cpu",
    # Head-activation masks, ordered (Classical A, Classical B, Quantum).
    class_mask_scenarios=dict(
        classical_only=[1, 1, 0],
        hybrid_full=[1, 1, 1],
    ),
)

# Deterministic evaluation preprocessing: fixed 224x224 resize, tensor conversion,
# then per-channel normalisation with the ImageNet mean/std.
EVAL_TRANSFORM = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Same ImageNet statistics as tensors shaped (3, 1, 1), so they broadcast over (3, H, W) images.
IMAGENET_MEAN, IMAGENET_STD = (
    torch.tensor(stats).view(3, 1, 1)
    for stats in ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
)

# NOTE(review): this silences every warning process-wide (including library deprecation
# notices); consider narrowing by category/module if something needs to surface.
warnings.filterwarnings(action="ignore")

In [None]:
# Model definitions (mirrors training notebook)
class VisionEncoder(nn.Module):
    """Image backbone that emits a 2048-d feature vector per image.

    ``encoder_type='vit'`` uses torchvision's ViT-B/16 with its ``heads`` replaced
    by an identity and a learned Linear(768 -> 2048) projection on top. Any other
    value uses ResNet-50 up to global average pooling plus flatten, with an
    identity projection.

    Attribute names (``encoder``, ``projection``) are kept as-is because
    state_dict keys are derived from them.
    """

    def __init__(self, encoder_type='vit', pretrained=True):
        super().__init__()
        self.encoder_type = encoder_type
        if encoder_type != 'vit':
            backbone = models.resnet50(pretrained=pretrained)
            stem = [backbone.conv1, backbone.bn1, backbone.relu, backbone.maxpool]
            stages = [backbone.layer1, backbone.layer2, backbone.layer3, backbone.layer4]
            self.encoder = nn.Sequential(*stem, *stages, nn.AdaptiveAvgPool2d((1, 1)), nn.Flatten())
            self.projection = nn.Identity()
            return
        self.encoder = models.vit_b_16(pretrained=pretrained)
        # Remove the classifier so the encoder returns 768-d features for the projection below.
        self.encoder.heads = nn.Identity()
        self.projection = nn.Linear(768, 2048)

    def forward(self, x):
        features = self.encoder(x)
        return self.projection(features) if self.encoder_type == 'vit' else features

class CompressionModule(nn.Module):
    """Bottleneck MLP that maps ``input_dim`` features to a ``compressed_dim`` vector.

    Two Linear -> BatchNorm1d -> ReLU -> Dropout(0.3) stages (512, then 128 units)
    are followed by a final Linear. The layer order inside ``compressor`` is fixed,
    because state_dict keys are ``compressor.<index>.*``.
    """

    def __init__(self, input_dim=2048, compressed_dim=30):
        super().__init__()
        stages = []
        fan_in = input_dim
        for width in (512, 128):
            stages.extend([
                nn.Linear(fan_in, width),
                nn.BatchNorm1d(width),
                nn.ReLU(),
                nn.Dropout(0.3),
            ])
            fan_in = width
        stages.append(nn.Linear(fan_in, compressed_dim))
        self.compressor = nn.Sequential(*stages)

    def forward(self, x):
        compressed = self.compressor(x)
        return compressed

class ClassicalHeadA(nn.Module):
    """Deep MLP classifier over full-width (default 2048-d) features.

    Three Linear -> BatchNorm1d -> ReLU -> Dropout stages (1024/0.4, 512/0.4,
    256/0.3) are followed by a Linear producing ``num_classes`` logits. The layer
    order inside ``head`` is fixed, because state_dict keys are ``head.<index>.*``.
    """

    def __init__(self, input_dim=2048, num_classes=5):
        super().__init__()
        blocks = []
        fan_in = input_dim
        for width, drop in ((1024, 0.4), (512, 0.4), (256, 0.3)):
            blocks += [nn.Linear(fan_in, width), nn.BatchNorm1d(width), nn.ReLU(), nn.Dropout(drop)]
            fan_in = width
        blocks.append(nn.Linear(fan_in, num_classes))
        self.head = nn.Sequential(*blocks)

    def forward(self, x):
        logits = self.head(x)
        return logits

class ClassicalHeadB(nn.Module):
    """Compact MLP classifier over the compressed (default 30-d) features.

    Two Linear -> BatchNorm1d -> ReLU -> Dropout stages (64/0.3, 32/0.2) are
    followed by a Linear producing ``num_classes`` logits. The layer order inside
    ``head`` is fixed, because state_dict keys are ``head.<index>.*``.
    """

    def __init__(self, input_dim=30, num_classes=5):
        super().__init__()
        blocks = []
        fan_in = input_dim
        for width, drop in ((64, 0.3), (32, 0.2)):
            blocks += [nn.Linear(fan_in, width), nn.BatchNorm1d(width), nn.ReLU(), nn.Dropout(drop)]
            fan_in = width
        blocks.append(nn.Linear(fan_in, num_classes))
        self.head = nn.Sequential(*blocks)

    def forward(self, x):
        logits = self.head(x)
        return logits

class QuantumClassificationHead(nn.Module):
    """Hybrid classifier head built around a Qiskit sampler-based quantum neural network.

    Pipeline: Linear(input_dim -> num_qubits), then a ZZFeatureMap (reps=2) composed
    with a RealAmplitudes ansatz (reps=2). This is wrapped as a SamplerQNN and exposed
    to PyTorch through TorchConnector. Its 2**num_qubits-dimensional output
    (presumably measurement-outcome probabilities - confirm) feeds a small MLP that
    produces ``num_classes`` logits.

    The ansatz weights start at zero. The circuit is always evaluated on the CPU in
    float64, and its output is cast back to float32 and moved to the MLP's device.

    Args:
        input_dim: Width of the incoming feature vector.
        num_classes: Number of output logits.
        num_qubits: Circuit width; also the size of the projected input.
        shots: Shot count passed to the Qiskit ``Sampler`` options.
    """

    def __init__(self, input_dim=32, num_classes=5, num_qubits=4, shots=1024):
        super().__init__()
        self.num_qubits = num_qubits
        self.num_classes = num_classes
        self.q_device = torch.device('cpu')
        self.input_projection = nn.Linear(input_dim, num_qubits)

        encoding = ZZFeatureMap(num_qubits, reps=2)
        variational = RealAmplitudes(num_qubits, reps=2)
        full_circuit = encoding.compose(variational)
        self.qnn = SamplerQNN(
            sampler=Sampler(options={'shots': shots}),
            circuit=full_circuit,
            input_params=encoding.parameters,
            weight_params=variational.parameters,
            sparse=False,
            input_gradients=True,  # lets gradients flow back into input_projection
        )
        zero_weights = torch.zeros(self.qnn.num_weights, dtype=torch.double)
        self.q_layer = TorchConnector(self.qnn, initial_weights=zero_weights)

        qnn_width = 2 ** num_qubits
        self.post_process = nn.Sequential(
            nn.Linear(qnn_width, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, num_classes),
        )

    def forward(self, x):
        angles = self.input_projection(x)
        # The Qiskit connector runs on the CPU in double precision.
        raw = self.q_layer(angles.to(self.q_device, dtype=torch.double))
        raw = raw.to(dtype=torch.float32)
        first_param = next(self.post_process.parameters(), None)
        destination = angles.device if first_param is None else first_param.device
        return self.post_process(raw.to(destination))

class DynamicEnsemble(nn.Module):
    def __init__(self, num_heads=3, init_temp=1.0):
        super().__init__()
        self.num_heads = num_heads
        self.base_weights = nn.Parameter(torch.ones(num_heads) / num_heads)
        self.temperature = nn.Parameter(torch.tensor(init_temp))
        self.uncertainty_scales = nn.Parameter(torch.ones(num_heads))
    def forward(self, head_outputs, uncertainties=None, active_mask=None):
        if active_mask is None:
            mask = torch.ones(len(head_outputs), device=head_outputs[0].device)
        else:
            mask = active_mask.to(head_outputs[0].device).float()
        weights = F.softmax(self.base_weights / self.temperature, dim=0)
        weights = weights * mask
        if weights.sum() <= 0:
            weights = torch.ones_like(weights) * mask
        weights = weights / (weights.sum() + 1e-8)
        if uncertainties is not None: