# Quantum-Classical Hybrid Autoencoder for Medical Image Classification

This notebook implements a hybrid quantum-classical model for chest X-ray classification.

## Installation (Run if packages missing)

In [4]:
 #Uncomment and run if you need to install packages
 %pip install torch torchvision numpy pennylane medmnist

Collecting torch
  Downloading torch-2.9.0-cp310-cp310-win_amd64.whl.metadata (30 kB)
Collecting torchvision
  Downloading torchvision-0.24.0-cp310-cp310-win_amd64.whl.metadata (5.9 kB)
Collecting pennylane
  Using cached pennylane-0.42.3-py3-none-any.whl.metadata (11 kB)
Collecting medmnist
  Using cached medmnist-3.0.2-py3-none-any.whl.metadata (14 kB)
Collecting filelock (from torch)
  Using cached filelock-3.20.0-py3-none-any.whl.metadata (2.1 kB)
Collecting sympy>=1.13.3 (from torch)
  Using cached sympy-1.14.0-py3-none-any.whl.metadata (12 kB)
Collecting networkx>=2.5.1 (from torch)
  Downloading networkx-3.4.2-py3-none-any.whl.metadata (6.3 kB)
Collecting jinja2 (from torch)
  Using cached jinja2-3.1.6-py3-none-any.whl.metadata (2.9 kB)
Collecting fsspec>=0.8.5 (from torch)
  Downloading fsspec-2025.10.0-py3-none-any.whl.metadata (10 kB)
Collecting scipy (from pennylane)
  Using cached scipy-1.15.3-cp310-cp310-win_amd64.whl.metadata (60 kB)
Collecting rustworkx>=0.14.0 (from pen

## Imports

In [5]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import pennylane as qml
from pennylane.qnn import TorchLayer

print("All imports successful!")
print(f"PyTorch version: {torch.__version__}")
print(f"PennyLane version: {qml.__version__}")

All imports successful!
PyTorch version: 2.9.0+cpu
PennyLane version: 0.42.3


## Settings

In [26]:
n_qubits = 4                 # set desired number of qubits
n_layers = 2                 # set desired number of quantum layers
latent_dim = 2**n_qubits     # latent dimension = 2^n_qubits
img_size = 64              # set desired image size (28, 64, 128, 224)
batch_size = 10

print(f"Configuration:")
print(f"  Qubits: {n_qubits}")
print(f"  Quantum layers: {n_layers}")
print(f"  Latent dimension: {latent_dim}")
print(f"  Image size: {img_size}x{img_size}")
print(f"  Batch size: {batch_size}")

Configuration:
  Qubits: 4
  Quantum layers: 2
  Latent dimension: 16
  Image size: 64x64
  Batch size: 10


## Load Data from preprocess.py

We import the data loaders from our preprocess.py file to ensure consistency.

In [18]:
# Import data loaders from preprocess.py
from preprocess import train_loader, test_loader, train_dataset, test_dataset

print(f"Data loaded successfully!")
print(f"Training samples: {len(train_dataset)}")
print(f"Test samples: {len(test_dataset)}")
print(f"Training batches: {len(train_loader)}")
print(f"Test batches: {len(test_loader)}")

Data loaded successfully!
Training samples: 78468
Test samples: 22433
Training batches: 2452
Test batches: 702


Filter Out Double-Labels

In [27]:
from torch.utils.data import Subset, DataLoader

print("Filtering for single-label images only...")

# Filter training set
train_single_label_indices = []
for i in range(len(train_dataset)):
    _, label = train_dataset[i]
    if label.sum() == 1:  # Only one condition present
        train_single_label_indices.append(i)

train_dataset = Subset(train_dataset, train_single_label_indices)
train_loader = DataLoader(train_dataset, batch_size=10, shuffle=True, drop_last=True)

# Filter test set
test_single_label_indices = []
for i in range(len(test_dataset)):
    _, label = test_dataset[i]
    if label.sum() == 1:
        test_single_label_indices.append(i)

test_dataset = Subset(test_dataset, test_single_label_indices)
test_loader = DataLoader(test_dataset, batch_size=10, shuffle=False, drop_last=False)

print(f"✓ Filtered Training: {len(train_dataset)} single-label images ({len(train_loader)} batches)")
print(f"✓ Filtered Test: {len(test_dataset)} single-label images ({len(test_loader)} batches)")


Filtering for single-label images only...
✓ Filtered Training: 21602 single-label images (2160 batches)
✓ Filtered Test: 6259 single-label images (626 batches)


## Classical Autoencoder

In [None]:
class Encoder(nn.Module):
    def __init__(self, latent_dim, img_size):
        """
        Encoder produces a `latent_dim`-dimensional vector from input images.

        Args:
            latent_dim (int): dimensionality of the encoder output. Must be provided
                by the notebook settings (e.g. `latent_dim = 2**n_qubits`).
            img_size (int): input image size (H=W). Must be provided by the settings
                cell so the model internals compute shapes correctly.
        """
        super().__init__()
        assert isinstance(latent_dim, int) and latent_dim > 0, "latent_dim must be a positive int"
        assert isinstance(img_size, int) and img_size > 0, "img_size must be a positive int"

        self.h8 = img_size // 8

        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(128, latent_dim)
        )
        
        # Initialize encoder weights properly
        for m in self.encoder.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        z = self.encoder(x)  # (B,latent_dim)
        return z  # Return unnormalized z

## Quantum Circuit

In [21]:
import numpy as np

# ---- quantum device ----
dev = qml.device("default.qubit", wires=n_qubits)

@qml.qnode(dev, interface="torch", diff_method="parameter-shift")
def qnode(inputs, weights):
    # inputs: shape (2**n_qubits,), NOT batched - single sample
    qml.AmplitudeEmbedding(inputs, wires=range(n_qubits), normalize=True, pad_with=0.0)
    qml.StronglyEntanglingLayers(weights, wires=range(n_qubits))
    return [qml.expval(qml.PauliZ(i)) for i in range(n_qubits)]

class QuantumHead(nn.Module):
    def __init__(self, n_layers, n_qubits, n_classes, latent_dim):
        super().__init__()
        self.n_qubits = n_qubits
        self.encoder_fc = nn.Linear(latent_dim, 2**n_qubits)
        
        # Better weight initialization
        nn.init.xavier_uniform_(self.encoder_fc.weight, gain=0.01)  # Even smaller gain
        nn.init.zeros_(self.encoder_fc.bias)
        
        # Initialize quantum weights manually
        self.q_weights = nn.Parameter(torch.randn(n_layers, n_qubits, 3) * 0.001)  # Very small initialization
        
        # Readout layer
        self.readout = nn.Linear(n_qubits, n_classes)

    def forward(self, h):
        # h: (B, latent_dim)
        batch_size = h.shape[0]
        
        # Check input for NaN
        if torch.isnan(h).any() or torch.isinf(h).any():
            print("⚠️ NaN/Inf in input h!")
            h = torch.nan_to_num(h, nan=0.0, posinf=1.0, neginf=-1.0)
        
        z = self.encoder_fc(h)  # (B, 2**n_qubits)
        
        # Aggressive NaN/Inf replacement
        z = torch.nan_to_num(z, nan=0.0, posinf=1.0, neginf=-1.0)
        
        # Clamp BEFORE normalization
        z = torch.clamp(z, min=-5, max=5)
        
        # Normalize per row with extreme safety
        norm = torch.sqrt(torch.sum(z**2, dim=1, keepdim=True) + 1e-10)
        
        # If norm is too small, replace entire row with uniform vector
        small_norm_mask = norm.squeeze() < 1e-6
        if small_norm_mask.any():
            print(f"⚠️ {small_norm_mask.sum().item()} rows with near-zero norm, fixing...")
            uniform_vec = torch.ones(2**self.n_qubits, device=z.device) / np.sqrt(2**self.n_qubits)
            z[small_norm_mask] = uniform_vec
            norm = torch.sqrt(torch.sum(z**2, dim=1, keepdim=True) + 1e-10)
        
        z_normalized = z / norm
        
        # Final verification
        z_normalized = torch.nan_to_num(z_normalized, nan=0.0, posinf=0.0, neginf=0.0)
        
        # Re-normalize after NaN replacement
        final_norm = torch.sqrt(torch.sum(z_normalized**2, dim=1, keepdim=True) + 1e-10)
        z_normalized = z_normalized / final_norm
        
        # Manual batching - process one sample at a time
        results = []
        for i in range(batch_size):
            single_input = z_normalized[i].detach().clone()
            
            # Triple check this specific sample
            if torch.isnan(single_input).any() or torch.isinf(single_input).any():
                print(f"⚠️ Sample {i} still has NaN/Inf! Using uniform vector.")
                single_input = torch.ones(2**self.n_qubits, device=z.device) / np.sqrt(2**self.n_qubits)
            
            # Verify norm is close to 1
            sample_norm = torch.sqrt(torch.sum(single_input**2))
            if abs(sample_norm - 1.0) > 0.1:
                print(f"⚠️ Sample {i} norm is {sample_norm.item():.4f}, renormalizing...")
                single_input = single_input / (sample_norm + 1e-10)
            
            try:
                # Run quantum circuit on single sample
                expvals = qnode(single_input, self.q_weights)
                expvals_tensor = torch.stack(expvals).float()
            except Exception as e:
                print(f"⚠️ Quantum circuit failed for sample {i}: {e}")
                # Return zeros if circuit fails
                expvals_tensor = torch.zeros(self.n_qubits, device=z.device)
            
            results.append(expvals_tensor)
        
        # Stack all results back into batch
        expvals_batch = torch.stack(results)  # (B, n_qubits)
        
        # Final readout layer
        return self.readout(expvals_batch)

print(f"Quantum device initialized: {dev}")
print("QuantumHead with EXTREME safety checks!")

Quantum device initialized: <default.qubit device (wires=6) at 0x1f167570940>
QuantumHead with EXTREME safety checks!


## Hybrid Quantum-Classical Model

In [22]:
class HybridQML(nn.Module):
    def __init__(self, img_size, latent_dim, n_classes=14):
        super().__init__()
        self.enc = Encoder(latent_dim=latent_dim, img_size=img_size)
        self.qhead = QuantumHead(n_layers=n_layers, n_qubits=n_qubits, n_classes=n_classes, latent_dim=latent_dim)

    def forward(self, x, return_recon=False):
        h = self.enc(x)
        out = self.qhead(h)
        return out

## Training Setup

In [23]:
import time
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

model = HybridQML(img_size=img_size, latent_dim=latent_dim, n_classes=14).to(device)

optimizer = torch.optim.AdamW(
    model.parameters(), 
    lr=5e-5,
    weight_decay=1e-5,
    eps=1e-7
)

clf_criterion = nn.BCEWithLogitsLoss()

print("✓ Model recreated with dtype fix!")
print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")

def train_one_epoch(epoch):
    model.train()
    total, n = 0.0, 0
    window_start = time.time()
    
    for batch_idx, (imgs, labels) in enumerate(train_loader):
        imgs = imgs.to(device)
        labels = labels.float().to(device)

        logits = model(imgs, return_recon=True)
        loss = clf_criterion(logits, labels)

        optimizer.zero_grad()
        loss.backward()
        
        # GRADIENT CLIPPING - prevents NaN!
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        
        optimizer.step()

        total += loss.item() * imgs.size(0)
        n += imgs.size(0)
        

        avg_size = 10
        if batch_idx % avg_size == 0:  
            # measure total time for the last 10 batches
            elapsed = time.time() - window_start
            avg_time = elapsed / avg_size if batch_idx != 0 else elapsed
            print(
                f"  Batch {batch_idx+1}/{len(train_loader)}, "
                f"Avg Loss: {total/n:.4f}, "
                f"Avg Time per batch: {avg_time:.4f}s"
            )
            window_start = time.time()  # reset window
            

    print(f"Epoch {epoch}: train loss = {total/n:.4f}")

@torch.no_grad()
def evaluate():
    model.eval()
    total, n = 0.0, 0
    for imgs, labels in test_loader:
        imgs = imgs.to(device)
        labels = labels.float().to(device)
        logits = model(imgs)
        loss = clf_criterion(logits, labels)
        total += loss.item() * imgs.size(0)
        n += imgs.size(0)
    print(f"Test BCEWithLogits loss = {total/n:.4f}")

print("Training functions defined successfully!")
print(f"Model has {sum(p.numel() for p in model.parameters())} parameters")

Using device: cpu
✓ Model recreated with dtype fix!
Parameters: 105,222
Training functions defined successfully!
Model has 105222 parameters


## Run Training

In [24]:
print("Starting training...\n")

for epoch in range(1, 3):
    train_one_epoch(epoch)
    evaluate()
    print()

print("Training complete!")

Starting training...

  Batch 1/2160, Avg Loss: 0.6846, Avg Time per batch: 9.9829s
  Batch 11/2160, Avg Loss: 0.6798, Avg Time per batch: 10.7231s


KeyboardInterrupt: 

In [None]:
# Save the trained model
torch.save(model.state_dict(), "hybrid_qml_model.pth")