In [9]:
# Imports
import time
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
import os
import numpy as np
from qiskit import QuantumCircuit
from qiskit.circuit import Parameter
from qiskit_machine_learning.neural_networks import SamplerQNN
from qiskit_machine_learning.connectors import TorchConnector
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

# Set device
# Prefer the Intel XPU backend when this PyTorch build exposes it. Builds that
# ship without a `torch.xpu` attribute would raise AttributeError on a direct
# `torch.xpu.is_available()` call, so the attribute is looked up defensively
# and the notebook falls back to the CPU instead of failing in its first cell.
_xpu_backend = getattr(torch, 'xpu', None)
if _xpu_backend is not None and _xpu_backend.is_available():
    device = torch.device('xpu')
    print("XPU (Intel GPU) is available!")
else:
    device = torch.device('cpu')
    print("XPU not available, using CPU")

print(f"Device set to: {device}")

XPU (Intel GPU) is available!
Device set to: xpu


In [10]:
# Select architectures to train on.
# The position of a name in this list is used as its integer class label when
# the images are loaded.
SELECTED_CLASSES = ['dome(inner)', 'dome(outer)', 'gargoyle', 'stained_glass']

# Dataset path.
# The original location stays as the default; set the ARCHITECTURE_DATA_DIR
# environment variable to run the notebook on another machine without editing
# this cell.
DATA_DIR = os.environ.get(
    'ARCHITECTURE_DATA_DIR',
    '/home/advik/Quantum/Mini Project/architecture_dataset_32x32',
)

In [11]:
# Load dataset and create train/validation split
from sklearn.model_selection import train_test_split

class ArchitectureDataset(Dataset):
    """In-memory dataset that pairs each stored image with its label.

    Args:
        data: Indexable collection of images.
        labels: Indexable collection of labels, aligned index-for-index with ``data``.
        transform: Optional callable applied to an image each time it is fetched.
    """

    def __init__(self, data, labels, transform=None):
        self.data, self.labels, self.transform = data, labels, transform

    def __len__(self):
        # The dataset length is the number of stored images.
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for ``idx``; the image is transformed when a transform is set."""
        sample, target = self.data[idx], self.labels[idx]
        return (self.transform(sample) if self.transform else sample), target

# Define transforms: convert each PIL image to a tensor, then normalise every
# channel with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Load images and labels
train_images = []
train_labels = []

train_dir = os.path.join(DATA_DIR, 'train')

for class_idx, class_name in enumerate(SELECTED_CLASSES):
    class_path = os.path.join(train_dir, class_name)

    # os.listdir() returns entries in arbitrary, filesystem-dependent order.
    # Sorting fixes the order of train_images, so the random_state=42 split
    # below selects the same files on every machine.
    for img_name in sorted(os.listdir(class_path)):
        img_path = os.path.join(class_path, img_name)
        if not os.path.isfile(img_path):
            continue  # ignore stray sub-directories inside a class folder
        # convert() returns a separate in-memory image, so the file handle can
        # be released as soon as the conversion is done.
        with Image.open(img_path) as img_file:
            img = img_file.convert('RGB')
        train_images.append(img)
        train_labels.append(class_idx)

print(f"Total images loaded: {len(train_images)}")
print(f"Classes: {SELECTED_CLASSES}")
print(f"Images per class: {[train_labels.count(i) for i in range(len(SELECTED_CLASSES))]}")

# Split train into train and validation (80/20), keeping class proportions
# equal in both parts (stratify) and the split reproducible (random_state).
train_imgs, val_imgs, train_lbls, val_lbls = train_test_split(
    train_images, train_labels, test_size=0.2, random_state=42, stratify=train_labels
)

print(f"\nTrain set: {len(train_imgs)} images")
print(f"Validation set: {len(val_imgs)} images")

# Create datasets
train_dataset = ArchitectureDataset(train_imgs, train_lbls, transform=transform)
val_dataset = ArchitectureDataset(val_imgs, val_lbls, transform=transform)

# Create dataloaders - REDUCED BATCH SIZE for quantum computing
batch_size = 8  # Smaller batches = faster quantum gradient computation
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

print(f"\nTrain batches: {len(train_loader)}")
print(f"Validation batches: {len(val_loader)}")
print(f"âš¡ Batch size reduced to {batch_size} for faster quantum gradient computation")

Total images loaded: 4324
Classes: ['dome(inner)', 'dome(outer)', 'gargoyle', 'stained_glass']
Images per class: [589, 1175, 1562, 998]

Train set: 3459 images
Validation set: 865 images

Train batches: 433
Validation batches: 109
âš¡ Batch size reduced to 8 for faster quantum gradient computation


In [12]:
# ==============================================================================
# Quantum Circuits for QKV Projections
# ==============================================================================
# These parameterized quantum circuits replace the classical Linear layers
# that normally compute Q, K, V in transformer self-attention.
#
# Design choices (aligned with proposal):
#   - 4 qubits  -> 2^4 = 16 output dimensions (probability distribution)
#   - 2 variational layers -> 24 trainable parameters per circuit
#   - Angle encoding (RY) for classical-to-quantum data input
#   - Circular CNOT entanglement for inter-qubit correlations
#   - SamplerQNN + TorchConnector for full PyTorch gradient integration
# ==============================================================================

# Circuit hyperparameters shared by the Q, K and V circuits.
#   n_qubits = 4 -> output dim 2^4 = 16
#   n_layers = 2 -> shallow circuit: avoids barren plateaus, faster gradients
n_qubits, n_layers = 4, 2

def create_qkv_quantum_circuit(name='Q'):
    """
    Construct the variational circuit behind one projection (Q, K or V) and
    return it wrapped as a PyTorch module.

    The circuit width and depth come from the module-level ``n_qubits`` and
    ``n_layers``.

    Circuit layout:
      1. Data encoding: one RY gate per qubit, each driven by an input parameter.
      2. ``n_layers`` repetitions of: RY on every qubit, RZ on every qubit, a
         ring of CNOTs (i -> i+1, then last -> 0), RX on every qubit.

    Args:
        name: Prefix for the parameter names (``{name}_x_{i}`` for inputs,
              ``{name}_w_{i}`` for trainable weights).

    Returns:
        A TorchConnector around a SamplerQNN built from the circuit, with
        input gradients enabled. The attention layer in this notebook treats
        its output as a vector of length 2^n_qubits.
    """
    # 3 rotations (RY, RZ, RX) per qubit per layer.
    total_weights = n_qubits * n_layers * 3
    input_params = [Parameter(f'{name}_x_{q}') for q in range(n_qubits)]
    weight_params = [Parameter(f'{name}_w_{k}') for k in range(total_weights)]

    qc = QuantumCircuit(n_qubits)

    # === Data encoding layer ===
    for qubit, angle in enumerate(input_params):
        qc.ry(angle, qubit)

    # === Variational layers ===
    # Weights are consumed strictly in creation order.
    next_weight = iter(weight_params)
    for _ in range(n_layers):
        for qubit in range(n_qubits):
            qc.ry(next(next_weight), qubit)
        for qubit in range(n_qubits):
            qc.rz(next(next_weight), qubit)
        # Ring of CNOTs: nearest neighbours first, then close the loop.
        for control in range(n_qubits - 1):
            qc.cx(control, control + 1)
        qc.cx(n_qubits - 1, 0)
        # RX rotations follow the entangling ring.
        for qubit in range(n_qubits):
            qc.rx(next(next_weight), qubit)

    # input_gradients=True lets gradients flow back into the classical layers
    # that feed this circuit.
    sampler_qnn = SamplerQNN(
        circuit=qc,
        input_params=input_params,
        weight_params=weight_params,
        input_gradients=True,
    )
    return TorchConnector(sampler_qnn)

# Build three independent circuits - query, key and value - in that order.
separator = "=" * 60
print("Creating Quantum QKV Projection Circuits...")
print(separator)

quantum_q, quantum_k, quantum_v = (
    create_qkv_quantum_circuit(prefix) for prefix in ('Q', 'K', 'V')
)

# Report the configuration every circuit was built with.
print(f"Qubits per circuit:       {n_qubits}")
print(f"Variational layers:       {n_layers}")
print(f"Trainable params/circuit: {n_qubits * n_layers * 3}")
print(f"Output dim per circuit:   2^{n_qubits} = {2**n_qubits}")
print(f"Total quantum circuits:   3 (Q, K, V)")
print(separator)

No gradient function provided, creating a gradient function. If your Sampler requires transpilation, please provide a pass manager.
No gradient function provided, creating a gradient function. If your Sampler requires transpilation, please provide a pass manager.
No gradient function provided, creating a gradient function. If your Sampler requires transpilation, please provide a pass manager.


Creating Quantum QKV Projection Circuits...
Qubits per circuit:       4
Variational layers:       2
Trainable params/circuit: 24
Output dim per circuit:   2^4 = 16
Total quantum circuits:   3 (Q, K, V)


In [13]:
# ==============================================================================
# Hybrid Quantum-Classical Vision Transformer
# ==============================================================================
# Architecture (matching proposal Figure 1):
#
#   [Input 32x32 RGB] 
#       -> Patch Embedding (Classical: Conv2d)
#       -> Positional Encoding (Classical: learnable)
#       -> Transformer Block 1 (HYBRID):
#           - Self-Attention with Quantum QKV projections
#             * pre-projection:  Linear embed_dim -> n_qubits  (Classical)
#             * Q, K, V circuits: SamplerQNN via TorchConnector (QUANTUM)
#             * post-projection: Linear quantum_out -> embed_dim (Classical)
#             * attention scores: softmax(QK^T / sqrt(d))       (Classical)
#             * weighted values:  attn @ V                      (Classical)
#           - Feed-Forward Network                              (Classical)
#           - LayerNorm + Residual connections                  (Classical)
#       -> Transformer Block 2 (CLASSICAL):
#           - Standard self-attention (Linear QKV)
#           - Feed-Forward Network
#           - LayerNorm + Residual connections
#       -> Classification Head (Classical: MLP)
#
# Key performance choices:
#   - patch_size=16 -> 4 patches (instead of 16 with patch_size=8), reducing quantum circuit calls by 4×
#   - Only Block 1 uses quantum QKV; Block 2 is fully classical
#   - 4 qubits -> 16 output dims per circuit (not 256)
# ==============================================================================


class PatchEmbedding(nn.Module):
    """Classical patch embedding.

    A convolution whose kernel size equals its stride turns every
    non-overlapping ``patch_size`` x ``patch_size`` patch into one
    ``embed_dim``-sized token.
    """
    def __init__(self, img_size=32, patch_size=8, in_channels=3, embed_dim=64):
        super().__init__()
        patches_per_side = img_size // patch_size
        self.n_patches = patches_per_side ** 2
        self.proj = nn.Conv2d(
            in_channels,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """Map images ``(B, C, H, W)`` to tokens ``(B, n_patches, embed_dim)``."""
        feature_map = self.proj(x)        # (B, embed_dim, H/P, W/P)
        tokens = feature_map.flatten(2)   # (B, embed_dim, n_patches)
        return tokens.transpose(1, 2)     # (B, n_patches, embed_dim)


class QuantumQKVAttention(nn.Module):
    """
    Self-attention where Q, K, V projections run on quantum circuits.
    Everything else (attention scores, weighted sum) is classical.

    Flow per token:
      embed_dim -[Linear]-> n_qubits -[Quantum Circuit]-> 2^n_qubits -[Linear]-> embed_dim

    Args:
        quantum_q, quantum_k, quantum_v: Modules mapping a ``(N, n_qubits)``
            tensor to a ``(N, 2^n_qubits)`` tensor (the Q/K/V circuits).
        embed_dim: Token embedding size; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights and to
            the output projection.

    Raises:
        ValueError: If ``embed_dim`` is not divisible by ``num_heads``.
    """
    def __init__(self, quantum_q, quantum_k, quantum_v,
                 embed_dim=64, num_heads=4, dropout=0.1):
        super().__init__()
        # Fail at construction time rather than with an opaque reshape error
        # during the first forward pass.
        if embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})"
            )
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        # Snapshot the circuit width now: forward() must keep matching the
        # Linear layers created below even if the notebook-level `n_qubits`
        # is reassigned after this module has been built.
        self.n_qubits = n_qubits
        self.quantum_out_dim = 2 ** self.n_qubits  # 16 for 4 qubits

        # Classical pre-projection: compress to quantum input size
        self.pre_q = nn.Linear(embed_dim, self.n_qubits)
        self.pre_k = nn.Linear(embed_dim, self.n_qubits)
        self.pre_v = nn.Linear(embed_dim, self.n_qubits)

        # Quantum QKV circuits
        self.quantum_q = quantum_q
        self.quantum_k = quantum_k
        self.quantum_v = quantum_v

        # Classical post-projection: expand quantum output to embed_dim
        self.post_q = nn.Linear(self.quantum_out_dim, embed_dim)
        self.post_k = nn.Linear(self.quantum_out_dim, embed_dim)
        self.post_v = nn.Linear(self.quantum_out_dim, embed_dim)

        # Output projection
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)
        self.scale = self.head_dim ** -0.5

    def _quantum_project(self, x, pre, circuit, post):
        """Run one projection (Q, K or V) and split the result into heads.

        Returns a tensor of shape ``(B, num_heads, S, head_dim)``.
        """
        B, S, _ = x.shape
        # tanh bounds the pre-projection to (-1, 1); scaling by pi yields
        # rotation angles in (-pi, pi) for the circuit's angle encoding.
        angles = torch.tanh(pre(x)) * np.pi
        # The circuit works on a flat batch of tokens: (B*S, n_qubits).
        circuit_out = circuit(angles.reshape(-1, self.n_qubits))
        projected = post(circuit_out.reshape(B, S, self.quantum_out_dim))
        return projected.reshape(B, S, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x):
        """Apply quantum-projected multi-head self-attention to ``x`` of shape ``(B, S, embed_dim)``."""
        B, S, _ = x.shape

        # ---- QUANTUM: QKV projections through parameterized circuits ----
        Q = self._quantum_project(x, self.pre_q, self.quantum_q, self.post_q)
        K = self._quantum_project(x, self.pre_k, self.quantum_k, self.post_k)
        V = self._quantum_project(x, self.pre_v, self.quantum_v, self.post_v)

        # ---- CLASSICAL: Standard multi-head attention computation ----
        attn = (Q @ K.transpose(-2, -1)) * self.scale
        attn = torch.softmax(attn, dim=-1)
        attn = self.dropout(attn)

        # Merge the heads back into a single embed_dim-sized vector per token.
        out = (attn @ V).transpose(1, 2).contiguous().reshape(B, S, self.embed_dim)
        return self.dropout(self.out_proj(out))


class ClassicalAttention(nn.Module):
    """Purely classical multi-head self-attention.

    A single fused Linear layer produces queries, keys and values; scaled
    dot-product attention is computed per head and the heads are merged by an
    output projection. The same dropout module is applied to the attention
    weights and to the projected output.
    """
    def __init__(self, embed_dim=64, num_heads=4, dropout=0.1):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.qkv = nn.Linear(embed_dim, 3 * embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Attend over ``x`` of shape ``(B, S, D)`` and return a tensor of the same shape."""
        batch, seq_len, dim = x.shape

        # Fused projection, then split into (3, B, heads, S, head_dim).
        fused = self.qkv(x).reshape(batch, seq_len, 3, self.num_heads, self.head_dim)
        queries, keys, values = fused.permute(2, 0, 3, 1, 4).unbind(0)

        scores = torch.matmul(queries, keys.transpose(-2, -1)) * self.scale
        weights = self.dropout(torch.softmax(scores, dim=-1))

        # (B, heads, S, head_dim) -> (B, S, D)
        context = torch.matmul(weights, values).transpose(1, 2).contiguous()
        merged = context.reshape(batch, seq_len, dim)
        return self.dropout(self.out_proj(merged))


class TransformerBlock(nn.Module):
    """Pre-norm transformer encoder block with an injected attention module.

    Computes ``x + attn(norm1(x))`` followed by ``x + mlp(norm2(x))``.

    Args:
        attention: Module applied to the normalised tokens (quantum or classical).
        embed_dim: Token embedding size.
        mlp_ratio: Hidden width of the feed-forward network as a multiple of ``embed_dim``.
        dropout: Dropout probability inside the feed-forward network.
    """
    def __init__(self, attention, embed_dim=64, mlp_ratio=2.0, dropout=0.1):
        super().__init__()
        hidden_dim = int(embed_dim * mlp_ratio)

        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = attention
        self.norm2 = nn.LayerNorm(embed_dim)

        # Feed-forward network: expand, non-linearity, contract.
        ffn_layers = [
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, embed_dim),
            nn.Dropout(dropout),
        ]
        self.mlp = nn.Sequential(*ffn_layers)

    def forward(self, x):
        """Apply the attention and feed-forward sub-layers, each with a residual connection."""
        attended = x + self.attn(self.norm1(x))
        return attended + self.mlp(self.norm2(attended))


class HybridViT_QNN(nn.Module):
    """
    Two-block vision transformer in which only the first block is quantum-assisted.

    Pipeline: patch embedding -> learnable positional embedding + dropout ->
    hybrid block (quantum Q/K/V projections) -> classical block -> LayerNorm ->
    mean over patches -> MLP classification head.

    Args:
        quantum_q, quantum_k, quantum_v: Quantum projection modules used by the
            first block's attention.
        n_classes: Number of output logits.
        img_size, patch_size: Forwarded to the patch embedding.
        embed_dim: Token embedding size used throughout.
        num_heads: Attention heads in both blocks.
    """
    def __init__(self, quantum_q, quantum_k, quantum_v, n_classes=4,
                 img_size=32, patch_size=8, embed_dim=64, num_heads=4):
        super().__init__()

        # --- Tokenisation (classical) ---
        self.patch_embed = PatchEmbedding(img_size, patch_size, 3, embed_dim)
        token_count = self.patch_embed.n_patches

        # One learnable position vector per patch, initialised with small noise.
        self.pos_embed = nn.Parameter(torch.randn(1, token_count, embed_dim) * 0.02)
        self.pos_drop = nn.Dropout(0.1)

        # --- Block 1: quantum QKV projections, classical everything else ---
        self.hybrid_block = TransformerBlock(
            QuantumQKVAttention(quantum_q, quantum_k, quantum_v, embed_dim, num_heads),
            embed_dim,
        )

        # --- Block 2: fully classical ---
        self.classical_block = TransformerBlock(
            ClassicalAttention(embed_dim, num_heads),
            embed_dim,
        )

        self.norm = nn.LayerNorm(embed_dim)

        # --- Classification head (classical MLP) ---
        self.head = nn.Sequential(
            nn.Linear(embed_dim, 128),
            nn.GELU(),
            nn.Dropout(0.2),
            nn.Linear(128, n_classes),
        )

    def forward(self, x):
        """Return class logits for a batch of images."""
        # (B, n_patches, embed_dim) tokens with positional information added.
        tokens = self.pos_drop(self.patch_embed(x) + self.pos_embed)

        # Block 1 (quantum QKV) followed by block 2 (classical).
        tokens = self.classical_block(self.hybrid_block(tokens))

        # Global average pooling over the patch dimension.
        pooled = self.norm(tokens).mean(dim=1)
        return self.head(pooled)


# ==============================================================================
# Instantiate the model
# ==============================================================================
# Constructor arguments are collected in one place so the summary below can
# reuse them.
vit_config = dict(
    quantum_q=quantum_q,
    quantum_k=quantum_k,
    quantum_v=quantum_v,
    n_classes=len(SELECTED_CLASSES),
    img_size=32,
    patch_size=16,    # 16x16 patches -> 4 patches (4x fewer than with 8x8 patches)
    embed_dim=64,
    num_heads=4,
)
model = HybridViT_QNN(**vit_config).to(device)

# Gather the figures first, then print the architecture summary.
n_patches = (vit_config['img_size'] // vit_config['patch_size']) ** 2
total_params = sum(p.numel() for p in model.parameters())
# The three circuits are built identically, so one count is scaled by three.
quantum_params = 3 * sum(p.numel() for p in quantum_q.parameters())
classical_params = total_params - quantum_params

heavy_rule = "=" * 60
print(heavy_rule)
print("HYBRID QUANTUM-CLASSICAL VISION TRANSFORMER")
print(heavy_rule)
print(f"Patch size:            16x16 -> {n_patches} patches per image")
print(f"Embedding dim:         64")
print(f"Attention heads:       4 (head_dim=16)")
print(f"Block 1 (Hybrid):      Quantum QKV ({n_qubits}q, {2**n_qubits} output) + Classical attention")
print(f"Block 2 (Classical):   Standard self-attention")
print(f"Classification:        MLP (64 -> 128 -> {len(SELECTED_CLASSES)})")
print("-" * 60)
print(f"Total parameters:      {total_params:,}")
print(f"  Quantum parameters:  {quantum_params:,} (Q + K + V circuits)")
print(f"  Classical parameters:{classical_params:,}")
print(f"Quantum evals/image:   {n_patches} patches x 3 circuits = {n_patches * 3} âš¡")
print(f"Device:                {device}")
print(f"ðŸš€ 4Ã— SPEEDUP from larger patches (4 vs 16 patches)")
print(heavy_rule)

HYBRID QUANTUM-CLASSICAL VISION TRANSFORMER
Patch size:            16x16 -> 4 patches per image
Embedding dim:         64
Attention heads:       4 (head_dim=16)
Block 1 (Hybrid):      Quantum QKV (4q, 16 output) + Classical attention
Block 2 (Classical):   Standard self-attention
Classification:        MLP (64 -> 128 -> 4)
------------------------------------------------------------
Total parameters:      117,016
  Quantum parameters:  72 (Q + K + V circuits)
  Classical parameters:116,944
Quantum evals/image:   4 patches x 3 circuits = 12 âš¡
Device:                xpu
ðŸš€ 4Ã— SPEEDUP from larger patches (4 vs 16 patches)


In [14]:
# ==============================================================================
# Training Setup
# ==============================================================================

# Loss on the same device as the model.
criterion = nn.CrossEntropyLoss().to(device)

# Lower LR for quantum stability
optimizer = optim.Adam(model.parameters(), lr=5e-4)

# Halve the LR (factor=0.5) once the value passed to scheduler.step() has
# stopped decreasing for longer than `patience` steps.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2)

def train_epoch(model, loader, criterion, optimizer, pbar=None):
    """Run one optimisation pass over ``loader``.

    Every batch is moved to the module-level ``device``, a forward/backward
    pass is performed, gradients are clipped to a global norm of 1.0 and the
    optimizer takes one step.

    Args:
        model: Network to train (switched to training mode).
        loader: Iterable yielding ``(images, labels)`` batches.
        criterion: Loss function called as ``criterion(outputs, labels)``.
        optimizer: Optimizer updating ``model``'s parameters.
        pbar: Optional progress bar; receives ``set_postfix`` and ``update(1)``
            once per batch.

    Returns:
        ``(batch_losses, batch_accs)``: per-batch loss values and per-batch
        accuracies in percent, in iteration order.
    """
    model.train()
    batch_losses, batch_accs = [], []

    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()

        # Gradient clipping: stabilizes quantum parameter shift gradients
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        # Per-batch bookkeeping
        loss_value = loss.item()
        predicted = outputs.max(1)[1]
        n_correct = predicted.eq(labels).sum().item()
        batch_acc = 100. * n_correct / labels.size(0)
        batch_losses.append(loss_value)
        batch_accs.append(batch_acc)

        if pbar is None:
            continue
        pbar.set_postfix({
            'loss': f'{loss_value:.4f}',
            'acc': f'{batch_acc:.1f}%'
        })
        pbar.update(1)

    return batch_losses, batch_accs

def validate(model, loader, criterion, pbar=None):
    """Evaluate ``model`` on ``loader`` without computing gradients.

    Every batch is moved to the module-level ``device`` before the forward pass.

    Args:
        model: Network to evaluate (switched to evaluation mode).
        loader: Iterable yielding ``(images, labels)`` batches.
        criterion: Loss function called as ``criterion(outputs, labels)``.
        pbar: Optional progress bar; receives ``set_postfix`` and ``update(1)``
            once per batch.

    Returns:
        ``(batch_losses, batch_accs)``: per-batch loss values and per-batch
        accuracies in percent, in iteration order.
    """
    model.eval()
    batch_losses, batch_accs = [], []

    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss_value = criterion(outputs, labels).item()

            predicted = outputs.max(1)[1]
            n_correct = predicted.eq(labels).sum().item()
            batch_acc = 100. * n_correct / labels.size(0)

            batch_losses.append(loss_value)
            batch_accs.append(batch_acc)

            if pbar is None:
                continue
            pbar.set_postfix({
                'loss': f'{loss_value:.4f}',
                'acc': f'{batch_acc:.1f}%'
            })
            pbar.update(1)

    return batch_losses, batch_accs

# Echo the training configuration defined in this cell, one line per setting.
for summary_line in (
    "Training functions ready!",
    f"Optimizer:  Adam (lr={optimizer.param_groups[0]['lr']})",
    f"Scheduler:  ReduceLROnPlateau (factor=0.5, patience=2)",
    f"Grad clip:  max_norm=1.0",
    f"Progress:   tqdm with batch-level live plotting",
):
    print(summary_line)

Training functions ready!
Optimizer:  Adam (lr=0.0005)
Scheduler:  ReduceLROnPlateau (factor=0.5, patience=2)
Grad clip:  max_norm=1.0
Progress:   tqdm with batch-level live plotting


In [None]:
# ==============================================================================
# Training Loop with REAL-TIME Batch-Level Live Plotting
# ==============================================================================

n_epochs = 10
update_plot_every = 5  # Update plot every N batches for smooth visualization

# Store ALL batch-level metrics for high-resolution plotting
all_train_losses = []
all_train_accs = []
all_val_losses = []
all_val_accs = []
epoch_boundaries = [0]  # Track where epochs start in batch indices
epoch_times = []

print(f"ðŸš€ Starting training for {n_epochs} epochs...")
print(f"Train batches: {len(train_loader)}, Val batches: {len(val_loader)}")
print(f"Quantum circuit calls per batch: {(32//16)**2 * 3} = {(32//16)**2 * 3} (4 patches Ã— 3 circuits)")
print(f"Plot updates: Every {update_plot_every} batches")
print()

%matplotlib inline
from IPython import display

fig, axes = plt.subplots(2, 2, figsize=(15, 10))
ax1, ax2, ax3, ax4 = axes.flatten()

total_batches_processed = 0

for epoch in range(n_epochs):
    epoch_start = time.time()
    
    print(f"\n{'='*60}")
    print(f"Epoch {epoch+1}/{n_epochs}")
    print(f"{'='*60}")
    
    # ============ TRAINING ============
    print("ðŸ”µ Training...")
    train_pbar = tqdm(total=len(train_loader), desc=f"Epoch {epoch+1} [Train]", 
                      ncols=100, leave=False)
    
    batch_train_losses, batch_train_accs = train_epoch(
        model, train_loader, criterion, optimizer, train_pbar
    )
    train_pbar.close()
    
    all_train_losses.extend(batch_train_losses)
    all_train_accs.extend(batch_train_accs)
    total_batches_processed += len(batch_train_losses)
    
    # ============ VALIDATION ============
    print("ðŸ”´ Validating...")
    val_pbar = tqdm(total=len(val_loader), desc=f"Epoch {epoch+1} [Val]  ", 
                    ncols=100, leave=False)
    
    batch_val_losses, batch_val_accs = validate(
        model, val_loader, criterion, val_pbar
    )
    val_pbar.close()
    
    all_val_losses.extend(batch_val_losses)
    all_val_accs.extend(batch_val_accs)
    epoch_boundaries.append(total_batches_processed)
    
    # Step the LR scheduler
    val_loss_avg = np.mean(batch_val_losses)
    scheduler.step(val_loss_avg)
    
    epoch_time = time.time() - epoch_start
    epoch_times.append(epoch_time)
    
    # ============ EPOCH SUMMARY ============
    train_loss_avg = np.mean(batch_train_losses)
    train_acc_avg = np.mean(batch_train_accs)
    val_acc_avg = np.mean(batch_val_accs)
    
    print(f"\nðŸ“Š Epoch {epoch+1} Summary:")
    print(f"  Train Loss: {train_loss_avg:.4f} | Train Acc: {train_acc_avg:.2f}%")
    print(f"  Val Loss:   {val_loss_avg:.4f} | Val Acc:   {val_acc_avg:.2f}%")
    print(f"  Time:       {epoch_time:.1f}s | LR: {optimizer.param_groups[0]['lr']:.6f}")
    
    # ============ UPDATE PLOTS ============
    ax1.clear()
    ax2.clear()
    ax3.clear()
    ax4.clear()
    
    train_batch_indices = list(range(len(all_train_losses)))
    val_batch_indices = list(range(len(all_val_losses)))
    
    # Plot 1: Batch-level training loss (high resolution)
    ax1.plot(train_batch_indices, all_train_losses, 'b-', alpha=0.3, linewidth=0.5)
    # Moving average for trend
    if len(all_train_losses) > 10:
        window = min(20, len(all_train_losses) // 5)
        train_loss_smooth = np.convolve(all_train_losses, 
                                         np.ones(window)/window, mode='valid')
        ax1.plot(range(window-1, len(all_train_losses)), train_loss_smooth, 
                 'b-', linewidth=2, label='Train Loss (smoothed)')
    ax1.set_xlabel('Training Batch')
    ax1.set_ylabel('Loss')
    ax1.set_title('Batch-Level Training Loss', fontweight='bold')
    ax1.grid(True, alpha=0.3)
    ax1.legend()
    
    # Plot 2: Batch-level validation loss
    ax2.plot(val_batch_indices, all_val_losses, 'r-', alpha=0.5, linewidth=1)
    if len(all_val_losses) > 10:
        window = min(20, len(all_val_losses) // 5)
        val_loss_smooth = np.convolve(all_val_losses, 
                                       np.ones(window)/window, mode='valid')
        ax2.plot(range(window-1, len(all_val_losses)), val_loss_smooth, 
                 'r-', linewidth=2, label='Val Loss (smoothed)')
    ax2.set_xlabel('Validation Batch')
    ax2.set_ylabel('Loss')
    ax2.set_title('Batch-Level Validation Loss', fontweight='bold')
    ax2.grid(True, alpha=0.3)
    ax2.legend()
    
    # Plot 3: Batch-level training accuracy
    ax3.plot(train_batch_indices, all_train_accs, 'b-', alpha=0.3, linewidth=0.5)
    if len(all_train_accs) > 10:
        window = min(20, len(all_train_accs) // 5)
        train_acc_smooth = np.convolve(all_train_accs, 
                                        np.ones(window)/window, mode='valid')
        ax3.plot(range(window-1, len(all_train_accs)), train_acc_smooth, 
                 'b-', linewidth=2, label='Train Acc (smoothed)')
    ax3.set_xlabel('Training Batch')
    ax3.set_ylabel('Accuracy (%)')
    ax3.set_title('Batch-Level Training Accuracy', fontweight='bold')
    ax3.grid(True, alpha=0.3)
    ax3.legend()
    
    # Plot 4: Batch-level validation accuracy
    ax4.plot(val_batch_indices, all_val_accs, 'r-', alpha=0.5, linewidth=1)
    if len(all_val_accs) > 10:
        window = min(20, len(all_val_accs) // 5)
        val_acc_smooth = np.convolve(all_val_accs, 
                                      np.ones(window)/window, mode='valid')
        ax4.plot(range(window-1, len(all_val_accs)), val_acc_smooth, 
                 'r-', linewidth=2, label='Val Acc (smoothed)')
    ax4.set_xlabel('Validation Batch')
    ax4.set_ylabel('Accuracy (%)')
    ax4.set_title('Batch-Level Validation Accuracy', fontweight='bold')
    ax4.grid(True, alpha=0.3)
    ax4.legend()
    
    # Draw epoch boundaries on training plots
    for boundary in epoch_boundaries[1:-1]:  # Skip first (0) and last (current)
        ax1.axvline(x=boundary, color='gray', linestyle='--', alpha=0.3, linewidth=1)
        ax3.axvline(x=boundary, color='gray', linestyle='--', alpha=0.3, linewidth=1)
    
    plt.tight_layout()
    display.clear_output(wait=True)
    display.display(fig)

# ============ FINAL SUMMARY ============
total_time = sum(epoch_times)
# Average over the epochs that were actually timed; max(..., 1) keeps the
# division safe when no epoch ran (e.g. n_epochs == 0).
completed_epochs = len(epoch_times)
print(f"\n{'='*60}")
print(f"âœ… TRAINING COMPLETE!")
print(f"{'='*60}")
print(f"Total time:        {total_time:.1f}s ({total_time/60:.1f} min)")
print(f"Avg time/epoch:    {total_time/max(completed_epochs, 1):.1f}s")

# Calculate epoch-level metrics for final summary.
# all_val_accs holds one entry per validation batch, appended epoch after
# epoch, so epoch i occupies the i-th slice of len(val_loader) entries.
val_batches_per_epoch = len(val_loader)
epoch_val_accs = []
for epoch_idx in range(len(epoch_boundaries) - 1):
    start_batch = epoch_idx * val_batches_per_epoch
    end_batch = start_batch + val_batches_per_epoch
    # Only complete epochs contribute.
    if end_batch <= len(all_val_accs):
        epoch_val_accs.append(np.mean(all_val_accs[start_batch:end_batch]))

if epoch_val_accs:
    best_val_acc = max(epoch_val_accs)
    best_epoch = epoch_val_accs.index(best_val_acc) + 1
    print(f"Best Val Accuracy: {best_val_acc:.2f}% (epoch {best_epoch})")
    print(f"Final Val Acc:     {epoch_val_accs[-1]:.2f}%")

print(f"\nðŸŽ¯ Total batches:   {len(all_train_losses)} training, {len(all_val_losses)} validation")
print(f"ðŸš€ Speedup factor:  4Ã— (from 16Ã—3=48 â†’ 4Ã—3=12 quantum calls/image)")
print(f"{'='*60}")

ðŸš€ Starting training for 10 epochs...
Train batches: 433, Val batches: 109
Quantum circuit calls per batch: 12 = 12 (4 patches Ã— 3 circuits)
Plot updates: Every 5 batches


Epoch 1/10
ðŸ”µ Training...


Epoch 1 [Train]:  13%|â–ˆâ–ˆâ–Œ                | 58/433 [13:32<1:26:05, 13.77s/it, loss=1.0305, acc=37.5%]