# Quantum-Classical Hybrid Autoencoder for Medical Image Classification

This notebook implements a hybrid quantum-classical model for chest X-ray classification.

## Installation (Run if packages missing)

In [73]:
# Uncomment and run if you need to install packages
# %pip install torch torchvision numpy pennylane medmnist

## Imports

In [74]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import pennylane as qml
from pennylane.qnn import TorchLayer

print("All imports successful!")
print(f"PyTorch version: {torch.__version__}")
print(f"PennyLane version: {qml.__version__}")

All imports successful!
PyTorch version: 2.8.0
PennyLane version: 0.42.3


## Settings

In [75]:
n_qubits = 6                 # set desired number of qubits
n_layers = 2                 # set desired number of quantum layers
latent_dim = 2**n_qubits     # latent dimension = 2^n_qubits
img_size = 224               # set desired image size (28, 64, 128, 224)
batch_size = 32

print(f"Configuration:")
print(f"  Qubits: {n_qubits}")
print(f"  Quantum layers: {n_layers}")
print(f"  Latent dimension: {latent_dim}")
print(f"  Image size: {img_size}x{img_size}")
print(f"  Batch size: {batch_size}")

Configuration:
  Qubits: 6
  Quantum layers: 2
  Latent dimension: 64
  Image size: 224x224
  Batch size: 32


## Load Data from preprocess.py

We import the data loaders from our preprocess.py file to ensure consistency.

In [76]:
# Import data loaders from preprocess.py
from preprocess import train_loader, test_loader, train_dataset, test_dataset

print(f"Data loaded successfully!")
print(f"Training samples: {len(train_dataset)}")
print(f"Test samples: {len(test_dataset)}")
print(f"Training batches: {len(train_loader)}")
print(f"Test batches: {len(test_loader)}")

Data loaded successfully!
Training samples: 78468
Test samples: 22433
Training batches: 2452
Test batches: 702


Filter Out Double-Labels

In [77]:
from torch.utils.data import Subset, DataLoader

print("Filtering for single-label images only...")

# Filter training set
train_single_label_indices = []
for i in range(len(train_dataset)):
    _, label = train_dataset[i]
    if label.sum() == 1:  # Only one condition present
        train_single_label_indices.append(i)

train_dataset = Subset(train_dataset, train_single_label_indices)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, drop_last=True)

# Filter test set
test_single_label_indices = []
for i in range(len(test_dataset)):
    _, label = test_dataset[i]
    if label.sum() == 1:
        test_single_label_indices.append(i)

test_dataset = Subset(test_dataset, test_single_label_indices)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, drop_last=False)

print(f"✓ Filtered Training: {len(train_dataset)} single-label images ({len(train_loader)} batches)")
print(f"✓ Filtered Test: {len(test_dataset)} single-label images ({len(test_loader)} batches)")


Filtering for single-label images only...
✓ Filtered Training: 21602 single-label images (675 batches)
✓ Filtered Test: 6259 single-label images (196 batches)


## Classical Autoencoder

In [78]:
class Autoencoder(nn.Module):
    def __init__(self, latent_dim=64, img_size=224):
        super().__init__()
        self.h8 = img_size // 8

        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(128, latent_dim)
        )
        
        # Initialize encoder weights properly
        for m in self.encoder.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 128 * self.h8 * self.h8),
            nn.ReLU(),
            nn.Unflatten(1, (128, self.h8, self.h8)),
            nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 1, 3, stride=2, padding=1, output_padding=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        z = self.encoder(x)  # (B,latent_dim)
        # REMOVED: z_norm = z / (z.norm(dim=1, keepdim=True) + 1e-8)
        x_rec = self.decoder(z)
        return z, x_rec  # Return unnormalized z

## Quantum Circuit

In [79]:
import numpy as np

# ---- quantum device ----
dev = qml.device("default.qubit", wires=n_qubits)

@qml.qnode(dev, interface="torch", diff_method="parameter-shift")
def qnode(inputs, weights):
    # inputs: shape (2**n_qubits,), NOT batched - single sample
    qml.AmplitudeEmbedding(inputs, wires=range(n_qubits), normalize=True, pad_with=0.0)
    qml.StronglyEntanglingLayers(weights, wires=range(n_qubits))
    return [qml.expval(qml.PauliZ(i)) for i in range(n_qubits)]

class QuantumHead(nn.Module):
    def __init__(self, n_layers, n_qubits, n_classes, latent_dim):
        super().__init__()
        self.n_qubits = n_qubits
        self.encoder_fc = nn.Linear(latent_dim, 2**n_qubits)
        
        # Better weight initialization
        nn.init.xavier_uniform_(self.encoder_fc.weight, gain=0.01)  # Even smaller gain
        nn.init.zeros_(self.encoder_fc.bias)
        
        # Initialize quantum weights manually
        self.q_weights = nn.Parameter(torch.randn(n_layers, n_qubits, 3) * 0.001)  # Very small initialization
        
        # Readout layer
        self.readout = nn.Linear(n_qubits, n_classes)

    def forward(self, h):
        # h: (B, latent_dim)
        batch_size = h.shape[0]
        
        # Check input for NaN
        if torch.isnan(h).any() or torch.isinf(h).any():
            print("⚠️ NaN/Inf in input h!")
            h = torch.nan_to_num(h, nan=0.0, posinf=1.0, neginf=-1.0)
        
        z = self.encoder_fc(h)  # (B, 2**n_qubits)
        
        # Aggressive NaN/Inf replacement
        z = torch.nan_to_num(z, nan=0.0, posinf=1.0, neginf=-1.0)
        
        # Clamp BEFORE normalization
        z = torch.clamp(z, min=-5, max=5)
        
        # Normalize per row with extreme safety
        norm = torch.sqrt(torch.sum(z**2, dim=1, keepdim=True) + 1e-10)
        
        # If norm is too small, replace entire row with uniform vector
        small_norm_mask = norm.squeeze() < 1e-6
        if small_norm_mask.any():
            print(f"⚠️ {small_norm_mask.sum().item()} rows with near-zero norm, fixing...")
            uniform_vec = torch.ones(2**self.n_qubits, device=z.device) / np.sqrt(2**self.n_qubits)
            z[small_norm_mask] = uniform_vec
            norm = torch.sqrt(torch.sum(z**2, dim=1, keepdim=True) + 1e-10)
        
        z_normalized = z / norm
        
        # Final verification
        z_normalized = torch.nan_to_num(z_normalized, nan=0.0, posinf=0.0, neginf=0.0)
        
        # Re-normalize after NaN replacement
        final_norm = torch.sqrt(torch.sum(z_normalized**2, dim=1, keepdim=True) + 1e-10)
        z_normalized = z_normalized / final_norm
        
        # Manual batching - process one sample at a time
        results = []
        for i in range(batch_size):
            single_input = z_normalized[i].detach().clone()
            
            # Triple check this specific sample
            if torch.isnan(single_input).any() or torch.isinf(single_input).any():
                print(f"⚠️ Sample {i} still has NaN/Inf! Using uniform vector.")
                single_input = torch.ones(2**self.n_qubits, device=z.device) / np.sqrt(2**self.n_qubits)
            
            # Verify norm is close to 1
            sample_norm = torch.sqrt(torch.sum(single_input**2))
            if abs(sample_norm - 1.0) > 0.1:
                print(f"⚠️ Sample {i} norm is {sample_norm.item():.4f}, renormalizing...")
                single_input = single_input / (sample_norm + 1e-10)
            
            try:
                # Run quantum circuit on single sample
                expvals = qnode(single_input, self.q_weights)
                expvals_tensor = torch.stack(expvals).float()
            except Exception as e:
                print(f"⚠️ Quantum circuit failed for sample {i}: {e}")
                # Return zeros if circuit fails
                expvals_tensor = torch.zeros(self.n_qubits, device=z.device)
            
            results.append(expvals_tensor)
        
        # Stack all results back into batch
        expvals_batch = torch.stack(results)  # (B, n_qubits)
        
        # Final readout layer
        return self.readout(expvals_batch)

print(f"Quantum device initialized: {dev}")
print("QuantumHead with EXTREME safety checks!")

Quantum device initialized: <default.qubit device (wires=6) at 0x3010d4250>
QuantumHead with EXTREME safety checks!


## Hybrid Quantum-Classical Model

In [80]:
model = HybridQML(img_size=img_size, latent_dim=latent_dim, n_classes=14).to(device)

optimizer = torch.optim.AdamW(
    model.parameters(), 
    lr=5e-5,
    weight_decay=1e-5,
    eps=1e-7
)

clf_criterion = nn.BCEWithLogitsLoss()

print("✓ Model recreated with dtype fix!")
print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")

✓ Model recreated with dtype fix!
Parameters: 6,720,647


## Training Setup

In [81]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

model = HybridQML(img_size=img_size, latent_dim=latent_dim, n_classes=14, recon_weight=0.05).to(device)

clf_criterion = nn.BCEWithLogitsLoss()  # multi-label loss
recon_criterion = nn.MSELoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, weight_decay=1e-5, eps=1e-7)

def train_one_epoch(epoch):
    model.train()
    total, n = 0.0, 0
    
    for batch_idx, (imgs, labels) in enumerate(train_loader):
        imgs = imgs.to(device)
        labels = labels.float().to(device)

        logits, x_rec = model(imgs, return_recon=True)
        loss = clf_criterion(logits, labels)

        optimizer.zero_grad()
        loss.backward()
        
        # GRADIENT CLIPPING - prevents NaN!
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        
        optimizer.step()

        total += loss.item() * imgs.size(0)
        n += imgs.size(0)
        
        if batch_idx % 200 == 0:
            print(f"  Batch {batch_idx}/{len(train_loader)}, Avg Loss: {total/n:.4f}")

    print(f"Epoch {epoch}: train loss = {total/n:.4f}")

@torch.no_grad()
def evaluate():
    model.eval()
    total, n = 0.0, 0
    for imgs, labels in test_loader:
        imgs = imgs.to(device)
        labels = labels.float().to(device)
        logits = model(imgs)
        loss = clf_criterion(logits, labels)
        total += loss.item() * imgs.size(0)
        n += imgs.size(0)
    print(f"Test BCEWithLogits loss = {total/n:.4f}")

print("Training functions defined successfully!")
print(f"Model has {sum(p.numel() for p in model.parameters())} parameters")

Using device: cpu
Training functions defined successfully!
Model has 6720647 parameters


## Run Training

In [82]:
print("Starting training...\n")

for epoch in range(1, 6):
    train_one_epoch(epoch)
    evaluate()
    print()

print("Training complete!")

Starting training...

  Batch 0/675, Avg Loss: 0.7156
  Batch 200/675, Avg Loss: 0.7093
  Batch 400/675, Avg Loss: 0.7052
  Batch 600/675, Avg Loss: 0.7011
Epoch 1: train loss = 0.6997
Test BCEWithLogits loss = 0.6861

  Batch 0/675, Avg Loss: 0.6824
  Batch 200/675, Avg Loss: 0.6825
  Batch 400/675, Avg Loss: 0.6786
  Batch 600/675, Avg Loss: 0.6747
Epoch 2: train loss = 0.6733
Test BCEWithLogits loss = 0.6602

  Batch 0/675, Avg Loss: 0.6625


KeyboardInterrupt: 