# 🌱 Capstone-Lazarus: Plant Disease Detection with PyTorch & Quantum Computing

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/MadScie254/Capstone-Lazarus/blob/main/notebooks/colab_training.ipynb)

**Fast, reproducible training on resource-limited hardware with Colab scaling capability**

## 🎯 Key Features

- **Transfer Learning**: EfficientNet, ResNet, MobileNet via `timm` library
- **Mixed Precision**: AMP for memory efficiency and speed
- **Quantum Computing**: Optional PennyLane integration (experimental)
- **Production Ready**: Checkpointing, early stopping, ONNX export
- **Resource Optimized**: HP ZBook G5 compatible (16GB RAM, Quadro P2000)

## 📋 Training Phases

1. **Quick Test** (1 epoch, 1k samples): Validate pipeline
2. **Development** (10 epochs, full dataset): Model selection
3. **Production** (50+ epochs): Final training with quantum experiments

---

## 🚀 Getting Started

Run all cells in order. Toggle settings in Section 1 for your hardware configuration.

# 1. Environment Setup and Configuration

First, we'll install all required dependencies and set up the environment for both Colab and local execution.

In [None]:
# Check if running on Colab
import os
import sys
from pathlib import Path

# `google.colab` is only present in sys.modules when that package has already
# been imported, which is used here as the signal that the kernel is Colab.
IN_COLAB = 'google.colab' in sys.modules

print(f"Running in Colab: {IN_COLAB}")
print(f"Python version: {sys.version}")

# For Colab: Mount Google Drive
if IN_COLAB:
    from google.colab import drive
    drive.mount('/content/drive')
    
    # Clone repository if not exists
    # NOTE: the `!` prefix is IPython shell syntax; this cell only runs inside
    # a notebook kernel, not as a plain Python script.
    if not os.path.exists('/content/Capstone-Lazarus'):
        !git clone https://github.com/MadScie254/Capstone-Lazarus.git /content/Capstone-Lazarus
    
    # Change to project directory
    os.chdir('/content/Capstone-Lazarus')
    PROJECT_ROOT = Path('/content/Capstone-Lazarus')
else:
    # Local execution: start from the current working directory
    PROJECT_ROOT = Path().resolve()
    # If there is no `src` folder here and we are inside `notebooks/`,
    # step up one level; otherwise keep the current directory as the root.
    if not (PROJECT_ROOT / 'src').exists():
        PROJECT_ROOT = PROJECT_ROOT.parent if PROJECT_ROOT.name == 'notebooks' else PROJECT_ROOT
    
print(f"Project root: {PROJECT_ROOT}")

# Add src to path (appended once, so re-running the cell does not duplicate it)
if str(PROJECT_ROOT / 'src') not in sys.path:
    sys.path.append(str(PROJECT_ROOT / 'src'))

In [None]:
# Install required packages
import subprocess
import importlib.util

def install_if_missing(packages):
    """Install each requirement with pip unless its distribution is present.

    Args:
        packages: Iterable of pip requirement strings, e.g. ``'torch>=2.0.0'``
            or ``'wandb'``.

    Raises:
        subprocess.CalledProcessError: If ``pip install`` exits non-zero.

    The presence check looks up the *distribution* name (the part before any
    version specifier, extras or marker) in the installed-package metadata.
    Checking importability instead does not work here: the import name often
    differs from the pip name (``opencv-python`` -> ``cv2``, ``Pillow`` ->
    ``PIL``), and a requirement such as ``'torch>=2.0.0'`` is not a valid
    module name at all.

    Only presence is checked; an installed version that does not satisfy the
    specifier is left untouched.
    """
    # Local imports keep this block self-contained within the cell.
    import re
    from importlib import metadata

    for package in packages:
        # Cut the requirement at the first specifier/extras/marker character.
        dist_name = re.split(r'[\s<>=!~;\[]', package.strip(), maxsplit=1)[0]
        try:
            metadata.version(dist_name)
        except metadata.PackageNotFoundError:
            subprocess.check_call([sys.executable, '-m', 'pip', 'install', package])
            print(f"✅ Installed {package}")
        else:
            print(f"✅ {dist_name} already available")

# Requirement groups, kept as separate module-level lists so later cells can
# still refer to them by name.

# Core ML packages
core_packages = [
    'torch>=2.0.0',
    'torchvision>=0.15.0',
    'torchaudio>=2.0.0',
    'timm>=0.9.0',
    'torchmetrics>=0.11.0',
]

# Data processing and augmentation
data_packages = [
    'albumentations>=1.3.0',
    'opencv-python>=4.8.0',
    'Pillow>=9.5.0',
]

# Visualization and utilities
viz_packages = [
    'matplotlib>=3.7.0',
    'seaborn>=0.12.0',
    'tqdm>=4.65.0',
    'pyyaml>=6.0',
    'scikit-learn>=1.3.0',
    'pandas>=2.0.0',
    'numpy>=1.24.0',
]

# Optional packages
optional_packages = [
    'wandb',  # For experiment tracking
    'pennylane>=0.32.0',  # For quantum computing
    'onnx>=1.14.0',  # For model export
    'onnxruntime>=1.15.0',  # For ONNX inference
]

# Required groups: a failure in any of these propagates and stops the cell.
for banner, group in (
    ("📦 Installing core ML packages...", core_packages),
    ("\n📦 Installing data processing packages...", data_packages),
    ("\n📦 Installing visualization packages...", viz_packages),
):
    print(banner)
    install_if_missing(group)

# Optional group: installed one at a time so a single failure is reported
# and the remaining packages are still attempted.
print("\n📦 Installing optional packages (may fail, that's OK)...")
for package in optional_packages:
    try:
        install_if_missing([package])
    except Exception as e:
        print(f"⚠️  Optional package {package} failed to install: {e}")

print("\n✅ Package installation complete!")

In [None]:
# Load configuration
import random

import yaml
import torch
from pathlib import Path

# Default configuration - optimized for HP ZBook G5 and Colab
default_config = {
    'seed': 42,
    'backbone': 'tf_efficientnet_b0',  # Fast and accurate
    'num_classes': 19,  # Adjust based on your dataset
    'image_size': 224,
    
    # Hardware optimizations
    'batch_size': 16 if not IN_COLAB else 32,  # ZBook P2000 has 4GB VRAM
    'num_workers': 2 if not IN_COLAB else 4,
    'pin_memory': True,
    
    # Training settings
    'epochs': 5 if IN_COLAB else 30,  # 5 epochs on Colab, 30 locally
    'learning_rate': 1e-3,
    'weight_decay': 1e-4,
    'optimizer': 'adamw',
    'scheduler': 'onecycle',
    
    # Performance features
    'use_amp': True,  # Automatic Mixed Precision
    'gradient_accumulation_steps': 2 if not IN_COLAB else 1,
    'use_ema': True,  # Exponential Moving Average
    'ema_decay': 0.9999,
    
    # Data augmentation
    'use_augmentations': True,
    'augmentation_strength': 'medium',
    'use_class_balancing': True,
    
    # Quantum computing (experimental)
    'use_quantum': False,  # Start with classical training
    'quantum': {
        'n_qubits': 4,
        'n_layers': 3,
        'embedding_dim': 4
    },
    
    # Checkpointing and monitoring
    'save_every': 10,
    'early_stopping_patience': 15,
    'use_wandb': False,  # Enable for experiment tracking
    
    # Paths
    'data_dir': str(PROJECT_ROOT / 'data'),
    'save_dir': str(PROJECT_ROOT / 'checkpoints'),
    
    # Testing
    'quick_test': True,  # Start with subset for validation
    'test_subset_size': 1000
}

# Try to load custom config if it exists
config_path = PROJECT_ROOT / 'config.yaml'
if config_path.exists():
    with open(config_path, 'r', encoding='utf-8') as f:
        # An empty YAML file parses to None; treat it as "no overrides".
        custom_config = yaml.safe_load(f) or {}
    if not isinstance(custom_config, dict):
        raise TypeError(
            f"{config_path} must contain a YAML mapping, got {type(custom_config).__name__}"
        )
    for key, value in custom_config.items():
        # Merge nested sections (e.g. 'quantum') key by key so that a partial
        # override does not discard the remaining defaults of that section.
        if isinstance(value, dict) and isinstance(default_config.get(key), dict):
            default_config[key] = {**default_config[key], **value}
        else:
            default_config[key] = value
    print(f"✅ Loaded custom config from {config_path}")
else:
    print(f"ℹ️  Using default config. Create {config_path} to customize.")

config = default_config

# Device setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
config['device'] = str(device)

print(f"\n🔧 Configuration:")
print(f"Device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
print(f"Backbone: {config['backbone']}")
print(f"Batch size: {config['batch_size']}")
print(f"Image size: {config['image_size']}")
print(f"Mixed Precision: {config['use_amp']}")
print(f"Quantum layers: {config['use_quantum']}")
print(f"Quick test mode: {config['quick_test']}")

# Set random seeds for reproducibility.
# Python's `random` is seeded too because the inline data utilities draw the
# train/val shuffle and the quick-test subset from it.
random.seed(config['seed'])
torch.manual_seed(config['seed'])
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(config['seed'])

# 2. Data Loading and Preprocessing with Albumentations

Implement efficient data loading with a seeded train/validation split, optional class-balanced sampling, and Albumentations augmentations (with a torchvision fallback) for plant disease classification.

In [None]:
# Import data utilities
try:
    from data_utils_torch import make_dataloaders, create_subset_loader, analyze_dataset_distribution
    print("✅ Imported local data_utils_torch")
except ImportError:
    print("⚠️  Local data_utils_torch not found. Using inline implementation.")
    
    # Inline implementation for standalone notebook
    import torch
    from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
    from torchvision import transforms
    from torchvision.datasets import ImageFolder
    import numpy as np
    from PIL import Image
    import random
    
    try:
        import albumentations as A
        from albumentations.pytorch import ToTensorV2
        ALBUMENTATIONS_AVAILABLE = True
    except ImportError:
        ALBUMENTATIONS_AVAILABLE = False
        print("⚠️  Albumentations not available, using torchvision transforms")
    
    class PlantDiseaseDataset(Dataset):
        """Folder-per-class image dataset that accepts either an
        Albumentations pipeline or a torchvision transform."""

        def __init__(self, root_dir, transform=None, use_albumentations=True):
            self.root_dir = Path(root_dir)
            self.transform = transform
            # The Albumentations calling convention is only used when the
            # library could actually be imported.
            self.use_albumentations = use_albumentations and ALBUMENTATIONS_AVAILABLE

            # ImageFolder does the directory scan and the class <-> index mapping.
            folder = ImageFolder(str(root_dir))
            self.dataset = folder
            self.classes = folder.classes
            self.class_to_idx = folder.class_to_idx
            self.samples = folder.samples

        def __len__(self):
            return len(self.samples)

        def __getitem__(self, idx):
            img_path, label = self.samples[idx]
            image = Image.open(img_path).convert('RGB')

            if not self.transform:
                # No pipeline configured: plain tensor conversion.
                return transforms.ToTensor()(image), label

            if self.use_albumentations:
                # Albumentations works on numpy arrays passed by keyword and
                # returns a dict.
                return self.transform(image=np.array(image))['image'], label

            # torchvision transforms take the PIL image directly.
            return self.transform(image), label
    
    def get_albumentations_transforms(image_size=224, split="train", strength="medium"):
        """Build the Albumentations pipeline for one dataset split.

        Args:
            image_size: Side length images are resized to (square).
            split: ``"train"`` adds augmentations; any other value yields
                resize + normalize + tensor conversion only.
            strength: ``"light"`` or ``"medium"``; any other value selects the
                heavy augmentation set.

        Returns:
            An ``A.Compose`` pipeline, or the torchvision fallback pipeline
            when Albumentations is not installed.
        """
        if not ALBUMENTATIONS_AVAILABLE:
            return get_torchvision_transforms(image_size, split)

        # `p=1.0` is used instead of `always_apply=True`: the latter is
        # deprecated in newer Albumentations releases, while `p` is accepted
        # by old and new versions alike.
        transforms_list = [A.Resize(image_size, image_size, p=1.0)]

        if split == "train":
            if strength == "medium":
                aug_transforms = [
                    A.HorizontalFlip(p=0.5),
                    A.VerticalFlip(p=0.2),
                    A.Rotate(limit=25, p=0.5),
                    A.RandomBrightnessContrast(
                        brightness_limit=0.2, contrast_limit=0.2, p=0.5
                    ),
                    A.HueSaturationValue(
                        hue_shift_limit=10, sat_shift_limit=20, val_shift_limit=10, p=0.3
                    ),
                ]
            elif strength == "light":
                aug_transforms = [
                    A.HorizontalFlip(p=0.5),
                    A.Rotate(limit=15, p=0.3),
                    A.RandomBrightnessContrast(
                        brightness_limit=0.1, contrast_limit=0.1, p=0.3
                    ),
                ]
            else:  # heavy
                aug_transforms = [
                    A.HorizontalFlip(p=0.5),
                    A.VerticalFlip(p=0.3),
                    A.Rotate(limit=35, p=0.6),
                    A.RandomBrightnessContrast(
                        brightness_limit=0.3, contrast_limit=0.3, p=0.6
                    ),
                    A.HueSaturationValue(
                        hue_shift_limit=20, sat_shift_limit=30, val_shift_limit=20, p=0.5
                    ),
                    # At most one geometric distortion per image.
                    A.OneOf([
                        A.ElasticTransform(p=0.3),
                        A.GridDistortion(p=0.3),
                        A.OpticalDistortion(p=0.3),
                    ], p=0.3),
                ]

            transforms_list.extend(aug_transforms)

        # Normalization (ImageNet mean/std) and tensor conversion always run last.
        transforms_list.extend([
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], p=1.0),
            ToTensorV2(p=1.0),
        ])

        return A.Compose(transforms_list)
    
    def get_torchvision_transforms(image_size=224, split="train"):
        """Build the torchvision pipeline used when Albumentations is not in play.

        Every split gets resize -> tensor -> normalize; the ``"train"`` split
        additionally gets flip, rotation and colour jitter between the resize
        and the tensor conversion.
        """
        steps = [transforms.Resize((image_size, image_size))]

        if split == "train":
            steps += [
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomRotation(25),
                transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
            ]

        steps += [
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]

        return transforms.Compose(steps)
    
    def create_weighted_sampler(dataset):
        """Create a ``WeightedRandomSampler`` that balances classes.

        Each sample is weighted by the inverse frequency of its class,
        ``num_samples / (num_class_slots * class_count)``, and sampling is done
        with replacement for ``len(dataset)`` draws per epoch.

        Args:
            dataset: Either an object exposing ``.samples`` as a sequence of
                ``(path, label)`` pairs, or a ``torch.utils.data.Subset``-like
                wrapper (``.dataset`` + ``.indices``) around such an object.

        Raises:
            TypeError: If labels cannot be located on ``dataset``.
            ValueError: If the dataset contains no samples.
        """
        if hasattr(dataset, 'samples'):
            labels = [label for _, label in dataset.samples]
        elif hasattr(dataset, 'indices') and hasattr(dataset, 'dataset'):
            # Subset: labels live on the wrapped dataset, restricted to the
            # subset's own indices (in subset order).
            base_samples = dataset.dataset.samples
            labels = [base_samples[i][1] for i in dataset.indices]
        else:
            raise TypeError(
                "create_weighted_sampler needs a dataset with `.samples` "
                "or a Subset wrapping one"
            )

        if not labels:
            raise ValueError("Cannot build a weighted sampler for an empty dataset")

        class_counts = np.bincount(labels)
        num_samples = len(labels)

        # Inverse-frequency weights. Label ids that never occur (gaps in the
        # id range, or classes missing from a subset) keep weight 0 instead of
        # triggering a division by zero; no sample refers to them anyway.
        class_weights = np.zeros(len(class_counts), dtype=np.float64)
        present = class_counts > 0
        class_weights[present] = num_samples / (len(class_counts) * class_counts[present])

        sample_weights = class_weights[np.asarray(labels)].tolist()

        return WeightedRandomSampler(
            weights=sample_weights,
            num_samples=len(sample_weights),
            replacement=True
        )
    
    def make_dataloaders(data_dir, config, train_split=0.8, val_split=0.2):
        """Create train and validation DataLoaders.

        Args:
            data_dir: Either a directory containing ``train/`` and ``val/``
                sub-folders (each with one folder per class), or a single
                folder-per-class directory that is split here.
            config: Dict read for ``image_size``, ``batch_size`` and the
                optional keys ``use_augmentations``, ``augmentation_strength``,
                ``seed``, ``use_class_balancing``, ``num_workers`` and
                ``pin_memory``.
            train_split: Fraction of samples assigned to training when the
                split is created here.
            val_split: Accepted for backward compatibility; validation always
                receives the samples not assigned to training.

        Returns:
            ``(train_loader, val_loader)``.

        Raises:
            FileNotFoundError: If ``data_dir`` does not exist.
        """
        import copy

        data_path = Path(data_dir)
        if not data_path.exists():
            raise FileNotFoundError(f"Data directory not found: {data_dir}")

        # Get transforms
        use_albu = bool(ALBUMENTATIONS_AVAILABLE and config.get('use_augmentations', True))
        if use_albu:
            train_transform = get_albumentations_transforms(
                image_size=config['image_size'],
                split='train',
                strength=config.get('augmentation_strength', 'medium')
            )
            val_transform = get_albumentations_transforms(
                image_size=config['image_size'],
                split='val'
            )
        else:
            train_transform = get_torchvision_transforms(
                image_size=config['image_size'],
                split='train'
            )
            val_transform = get_torchvision_transforms(
                image_size=config['image_size'],
                split='val'
            )

        # Check if data is already split
        train_dir = data_path / "train"
        val_dir = data_path / "val"

        if train_dir.exists() and val_dir.exists():
            print("Using existing train/val split")
            train_dataset = PlantDiseaseDataset(train_dir, transform=train_transform, use_albumentations=use_albu)
            val_dataset = PlantDiseaseDataset(val_dir, transform=val_transform, use_albumentations=use_albu)
        else:
            print("Creating train/val split from single directory")
            # Train and val must be two distinct dataset objects: with a single
            # shared object, assigning the val transform would overwrite the
            # train transform and training would run without augmentation.
            train_dataset = PlantDiseaseDataset(data_path, transform=train_transform, use_albumentations=use_albu)
            # Shallow copy reuses the directory scan; only transform and sample
            # list differ between the two.
            val_dataset = copy.copy(train_dataset)
            val_dataset.transform = val_transform

            all_samples = list(train_dataset.samples)
            indices = list(range(len(all_samples)))
            random.seed(config.get('seed', 42))
            random.shuffle(indices)

            train_size = int(train_split * len(indices))
            # Restricting `.samples` (instead of wrapping in Subset) keeps the
            # labels reachable for the class-balancing sampler below.
            train_dataset.samples = [all_samples[i] for i in indices[:train_size]]
            val_dataset.samples = [all_samples[i] for i in indices[train_size:]]

        # Create weighted sampler for training
        if config.get('use_class_balancing', True):
            train_sampler = create_weighted_sampler(train_dataset)
            shuffle = False  # DataLoader forbids shuffle=True together with a sampler
        else:
            train_sampler = None
            shuffle = True

        num_workers = config.get('num_workers', 4)
        pin_memory = config.get('pin_memory', True)
        # persistent_workers is only valid with at least one worker process.
        persistent = num_workers > 0

        # Create data loaders
        train_loader = DataLoader(
            train_dataset,
            batch_size=config['batch_size'],
            sampler=train_sampler,
            shuffle=shuffle,
            num_workers=num_workers,
            pin_memory=pin_memory,
            persistent_workers=persistent
        )

        val_loader = DataLoader(
            val_dataset,
            batch_size=config['batch_size'],
            shuffle=False,
            num_workers=num_workers,
            pin_memory=pin_memory,
            persistent_workers=persistent
        )

        print(f"Train loader: {len(train_loader)} batches ({len(train_dataset)} samples)")
        print(f"Val loader: {len(val_loader)} batches ({len(val_dataset)} samples)")

        return train_loader, val_loader
    
    def create_subset_loader(data_dir, config, subset_size=1000, split="train"):
        """Build a DataLoader over a random sample of the dataset for smoke tests.

        At most ``subset_size`` items are drawn without replacement from the
        folder-per-class dataset at ``data_dir``; shuffling is enabled only
        for the ``"train"`` split.
        """
        size = config['image_size']
        use_albu = bool(ALBUMENTATIONS_AVAILABLE and config.get('use_augmentations', True))

        if use_albu:
            # Light augmentation keeps quick tests fast.
            transform = get_albumentations_transforms(
                image_size=size, split=split, strength='light'
            )
        else:
            transform = get_torchvision_transforms(image_size=size, split=split)

        dataset = PlantDiseaseDataset(data_dir, transform=transform, use_albumentations=use_albu)

        # Never ask for more samples than the dataset holds.
        subset_size = min(subset_size, len(dataset))
        chosen = random.sample(range(len(dataset)), subset_size)

        loader = DataLoader(
            torch.utils.data.Subset(dataset, chosen),
            batch_size=config['batch_size'],
            shuffle=(split == 'train'),
            num_workers=config.get('num_workers', 2),
            pin_memory=False,
        )

        print(f"Subset loader created: {len(loader)} batches ({subset_size} samples)")
        return loader
    
    def analyze_dataset_distribution(data_dir):
        """Summarise how many samples each class folder under ``data_dir`` holds.

        Returns a dict with ``total_samples``, ``num_classes``,
        ``class_names``, ``class_counts`` (name -> count) and
        ``class_percentages`` (name -> percent of all samples).
        """
        folder = ImageFolder(data_dir)
        names = folder.classes
        total_samples = len(folder.samples)

        # Tally samples per class name.
        class_counts = {}
        for sample in folder.samples:
            name = names[sample[1]]
            if name in class_counts:
                class_counts[name] += 1
            else:
                class_counts[name] = 1

        # Share of the whole dataset, in percent.
        class_percentages = {}
        for name, count in class_counts.items():
            class_percentages[name] = (count / total_samples) * 100

        return {
            'total_samples': total_samples,
            'num_classes': len(names),
            'class_names': names,
            'class_counts': class_counts,
            'class_percentages': class_percentages,
        }

print("✅ Data utilities ready!")

In [None]:
# Analyze dataset distribution
data_path = Path(config['data_dir'])
print(f"Looking for data in: {data_path}")

if data_path.exists():
    # Check what's in the data directory
    subdirs = [d for d in data_path.iterdir() if d.is_dir()]
    print(f"Found {len(subdirs)} subdirectories:")
    for subdir in subdirs[:10]:  # Show first 10
        print(f"  📁 {subdir.name}")
    if len(subdirs) > 10:
        print(f"  ... and {len(subdirs) - 10} more")

    # With a pre-split layout (data_dir/train + data_dir/val, which the inline
    # make_dataloaders also accepts) the class folders sit one level down.
    # Analysing data_dir itself would count "train" and "val" as the classes
    # and overwrite config['num_classes'] with the number of split folders.
    analysis_root = data_path
    if (data_path / 'train').is_dir() and (data_path / 'val').is_dir():
        analysis_root = data_path / 'train'
        print(f"ℹ️  Found train/val split; analyzing {analysis_root}")

    # Try to analyze distribution
    try:
        analysis = analyze_dataset_distribution(analysis_root)
        print(f"\n📊 Dataset Analysis:")
        print(f"Total samples: {analysis['total_samples']:,}")
        print(f"Number of classes: {analysis['num_classes']}")

        # Show class distribution (five largest classes)
        print(f"\nClass distribution:")
        for class_name, count in sorted(analysis['class_counts'].items(), key=lambda x: x[1], reverse=True)[:5]:
            percentage = analysis['class_percentages'][class_name]
            print(f"  {class_name}: {count:,} ({percentage:.1f}%)")

        if analysis['num_classes'] > 5:
            print(f"  ... and {analysis['num_classes'] - 5} more classes")

        # Update config with actual number of classes
        config['num_classes'] = analysis['num_classes']

    except Exception as e:
        # Deliberately best-effort: an unexpected folder layout should not
        # stop the notebook, only leave num_classes at its configured value.
        print(f"⚠️  Could not analyze dataset: {e}")
        print("This might be because the data structure is different than expected.")
        print("Expected structure: data_dir/class1/, data_dir/class2/, etc.")

else:
    print("⚠️  Data directory not found!")
    if IN_COLAB:
        print("📋 For Colab users:")
        print("1. Upload your dataset to Google Drive")
        print("2. Update the data_dir path in config")
        print("3. Or use the sample dataset from the repository")
    else:
        print("📋 For local users:")
        print("1. Make sure your data is in the correct directory")
        print("2. Update the data_dir path in config.yaml")
        print("3. Expected structure: data_dir/class1/, data_dir/class2/, etc.")

# 3. Model Factory Implementation with TIMM Integration

Create models using the TIMM library with pretrained weights and optional quantum layer integration.

In [None]:
# Import model factory
try:
    from model_factory_torch import create_model
    print("✅ Imported local model_factory_torch")
except ImportError:
    print("⚠️  Local model_factory_torch not found. Using inline implementation.")
    
    # Inline model factory implementation
    import timm
    import torch.nn as nn
    
    # Available backbone models.
    # Keys are the short names accepted in config['backbone']; 'timm_name' is
    # the identifier handed to timm.create_model. 'description', 'params' and
    # 'input_size' are informational only (printed by list_available_models);
    # the image size actually used comes from config['image_size'].
    BACKBONE_MODELS = {
        'tf_efficientnet_b0': {
            'timm_name': 'tf_efficientnet_b0.aa_in1k',
            'description': 'EfficientNet-B0 (5.3M params) - Fast and accurate',
            'input_size': 224,
            'params': '5.3M'
        },
        'tf_efficientnet_b1': {
            'timm_name': 'tf_efficientnet_b1.aa_in1k',
            'description': 'EfficientNet-B1 (7.8M params) - Good balance',
            'input_size': 240,
            'params': '7.8M'
        },
        'resnet18': {
            'timm_name': 'resnet18.a1_in1k',
            'description': 'ResNet-18 (11.7M params) - Classic, reliable',
            'input_size': 224,
            'params': '11.7M'
        },
        'resnet34': {
            'timm_name': 'resnet34.a1_in1k',
            'description': 'ResNet-34 (21.8M params) - More capacity',
            'input_size': 224,
            'params': '21.8M'
        },
        'mobilenetv3_small_100': {
            'timm_name': 'mobilenetv3_small_100.lamb_in1k',
            'description': 'MobileNetV3-Small (2.5M params) - Very efficient',
            'input_size': 224,
            'params': '2.5M'
        },
        'mobilenetv3_large_100': {
            'timm_name': 'mobilenetv3_large_100.ra_in1k',
            'description': 'MobileNetV3-Large (5.5M params) - Efficient',
            'input_size': 224,
            'params': '5.5M'
        }
    }
    
    class PlantDiseaseModel(nn.Module):
        """Plant Disease Classification Model with optional quantum layer.

        Args:
            backbone_name: Key of ``BACKBONE_MODELS``.
            num_classes: Size of the output layer.
            pretrained: Forwarded to ``timm.create_model``.
            quantum_layer: Optional module exposing ``embedding_dim`` and
                ``n_qubits``; when given, it is inserted between a linear
                projection and the final linear classifier.

        Raises:
            ValueError: If ``backbone_name`` is not a known backbone.
        """

        def __init__(self, backbone_name, num_classes, pretrained=True, quantum_layer=None):
            super().__init__()

            if backbone_name not in BACKBONE_MODELS:
                raise ValueError(f"Unknown backbone: {backbone_name}. Available: {list(BACKBONE_MODELS.keys())}")

            timm_name = BACKBONE_MODELS[backbone_name]['timm_name']

            # Create backbone model
            self.backbone = timm.create_model(
                timm_name,
                pretrained=pretrained,
                num_classes=0  # Remove classifier head
            )

            # Get feature dimension by probing with a dummy batch.
            # The probe runs in eval mode: in train mode BatchNorm layers
            # update their running statistics even under no_grad, which would
            # blend random-noise statistics into the pretrained weights.
            was_training = self.backbone.training
            self.backbone.eval()
            with torch.no_grad():
                dummy_input = torch.randn(1, 3, 224, 224)
                features = self.backbone(dummy_input)
                self.feature_dim = features.shape[1]
            self.backbone.train(was_training)

            # Create classifier head
            if quantum_layer is not None:
                # Quantum-classical hybrid
                self.classifier = nn.Sequential(
                    nn.Linear(self.feature_dim, quantum_layer.embedding_dim),
                    nn.ReLU(),
                    nn.Dropout(0.3),
                    quantum_layer,
                    nn.Linear(quantum_layer.n_qubits, num_classes)  # Quantum outputs to classes
                )
            else:
                # Standard classifier
                self.classifier = nn.Sequential(
                    nn.Dropout(0.3),
                    nn.Linear(self.feature_dim, num_classes)
                )

        def forward(self, x):
            features = self.backbone(x)
            return self.classifier(features)

    def create_model(config, quantum_layer=None):
        """Create a model based on configuration.

        Reads ``config['backbone']`` and ``config['num_classes']``; the
        optional ``config['pretrained']`` (default ``True``) controls whether
        pretrained backbone weights are loaded.
        """
        backbone_name = config['backbone']
        num_classes = config['num_classes']

        model = PlantDiseaseModel(
            backbone_name=backbone_name,
            num_classes=num_classes,
            pretrained=config.get('pretrained', True),
            quantum_layer=quantum_layer
        )

        return model

    def list_available_models():
        """List all available backbone models."""
        print("📱 Available Backbone Models:")
        for name, info in BACKBONE_MODELS.items():
            print(f"  {name}:")
            print(f"    - {info['description']}")
            print(f"    - Parameters: {info['params']}")
            print(f"    - Input size: {info['input_size']}")
            print()

print("✅ Model factory ready!")

In [None]:
# Show available models and create model
# list_available_models is only defined by the inline factory; when the local
# model_factory_torch module was imported, only create_model is in scope.
if 'list_available_models' in globals():
    list_available_models()

# Create model (without quantum layer for now)
print(f"🏗️  Creating model: {config['backbone']}")
model = create_model(config, quantum_layer=None)

# Move to device
model = model.to(device)

# Model summary
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"\n📊 Model Summary:")
print(f"Backbone: {config['backbone']}")
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
# 4 bytes per parameter, i.e. assuming float32 weights
print(f"Model size (estimated): {total_params * 4 / (1024**2):.1f} MB")

# Test model with dummy input.
# Run the check in eval mode so dropout is disabled and BatchNorm running
# statistics are not updated with random data, then restore the prior mode.
was_training = model.training
model.eval()
with torch.no_grad():
    dummy_input = torch.randn(1, 3, config['image_size'], config['image_size']).to(device)
    output = model(dummy_input)
    print(f"Output shape: {output.shape}")
    print(f"Expected shape: [batch_size, {config['num_classes']}]")
model.train(was_training)

print("✅ Model created successfully!")

# 4. Optional Quantum Layer Module (PennyLane)

Implement quantum neural network layers for experimental hybrid quantum-classical models.

In [None]:
# Quantum layer implementation (optional)
quantum_available = False

try:
    from quantum_layer import QuantumLayer, HybridQuantumClassifier
    quantum_available = True
    print("✅ Imported local quantum_layer")
except ImportError:
    print("ℹ️  Local quantum_layer not found. Trying inline implementation.")

    try:
        import pennylane as qml
        # NOTE(review): this rebinds the notebook-global name `np` to
        # PennyLane's numpy wrapper and is not used in this cell. Kept as-is
        # in case later cells rely on it — confirm and consider removing.
        from pennylane import numpy as np
        import torch.nn as nn

        class QuantumLayer(nn.Module):
            """Quantum Neural Network Layer using PennyLane.

            Args:
                n_qubits: Number of wires; also the layer's output width
                    (one PauliZ expectation value per wire).
                n_layers: Number of variational RY/RZ + CNOT layers.
                embedding_dim: Expected input width. When it differs from
                    ``n_qubits`` a linear projection to ``n_qubits`` is added.
            """

            def __init__(self, n_qubits=4, n_layers=3, embedding_dim=4):
                super().__init__()
                self.n_qubits = n_qubits
                self.n_layers = n_layers
                self.embedding_dim = embedding_dim

                # Create the projection up front (not lazily in forward) so
                # its weights are part of model.parameters() before an
                # optimizer is built.
                self.projection = (
                    nn.Linear(embedding_dim, n_qubits)
                    if embedding_dim != n_qubits else None
                )

                # Create quantum device
                self.dev = qml.device('default.qubit', wires=n_qubits)

                # Define quantum circuit
                @qml.qnode(self.dev, interface='torch')
                def circuit(inputs, weights):
                    # Embed classical data. Index the LAST axis so the circuit
                    # works for a single sample (n_qubits,) and for a batch
                    # (batch, n_qubits) passed in one broadcasted call.
                    for i in range(n_qubits):
                        qml.RY(inputs[..., i], wires=i)

                    # Variational layers
                    for layer in range(n_layers):
                        for i in range(n_qubits):
                            qml.RY(weights[layer, i, 0], wires=i)
                            qml.RZ(weights[layer, i, 1], wires=i)

                        # Entangling gates
                        for i in range(n_qubits - 1):
                            qml.CNOT(wires=[i, i + 1])
                        if n_qubits > 2:
                            qml.CNOT(wires=[n_qubits - 1, 0])  # Wrap around

                    # Measure expectations
                    return [qml.expval(qml.PauliZ(i)) for i in range(n_qubits)]

                # Create weight tensor
                weight_shapes = {"weights": (n_layers, n_qubits, 2)}
                self.qlayer = qml.qnn.TorchLayer(circuit, weight_shapes)

            def forward(self, x):
                # Ensure input has correct dimension
                if x.shape[-1] != self.n_qubits:
                    if self.projection is None:
                        # Fallback for an input width not announced via
                        # embedding_dim; created on first use.
                        self.projection = nn.Linear(x.shape[-1], self.n_qubits).to(x.device)
                    x = self.projection(x)

                # Apply quantum circuit. The cast guards against the simulator
                # returning a different float precision than the input, which
                # the following nn.Linear would reject (defensive; confirm
                # against the installed PennyLane version).
                return self.qlayer(x).to(x.dtype)

        quantum_available = True
        print("✅ Created inline quantum implementation")

    except ImportError as e:
        print(f"⚠️  PennyLane not available: {e}")
        print("Quantum layers will be disabled.")
        quantum_available = False

# Test quantum layer if available and enabled
if quantum_available and config.get('use_quantum', False):
    print("\n🔬 Testing Quantum Layer:")

    # Create quantum layer
    quantum_config = config['quantum']
    quantum_layer = QuantumLayer(
        n_qubits=quantum_config['n_qubits'],
        n_layers=quantum_config['n_layers'],
        embedding_dim=quantum_config['embedding_dim']
    )

    # Test with dummy input
    test_input = torch.randn(2, quantum_config['embedding_dim'])  # Batch size 2
    with torch.no_grad():
        quantum_output = quantum_layer(test_input)
        print(f"Quantum layer input shape: {test_input.shape}")
        print(f"Quantum layer output shape: {quantum_output.shape}")

    print("✅ Quantum layer test successful!")
    print("\n⚠️  Note: Quantum simulation is slow on CPU. Use small datasets for testing.")

    # Create model with quantum layer
    print("\n🏗️  Creating quantum-classical hybrid model...")
    quantum_model = create_model(config, quantum_layer=quantum_layer)
    quantum_model = quantum_model.to(device)

    # Test quantum model in eval mode (no dropout, no BatchNorm statistics
    # updates from random data), then restore the previous mode.
    was_training = quantum_model.training
    quantum_model.eval()
    with torch.no_grad():
        dummy_input = torch.randn(1, 3, config['image_size'], config['image_size']).to(device)
        quantum_output = quantum_model(dummy_input)
        print(f"Quantum model output shape: {quantum_output.shape}")
    quantum_model.train(was_training)

    print("✅ Quantum model created successfully!")

    # Ask user if they want to use quantum model
    use_quantum_model = True  # Set to False for classical training
    if use_quantum_model:
        model = quantum_model
        print("🚀 Using quantum-classical hybrid model for training")
    else:
        print("📱 Sticking with classical model for faster training")

else:
    if config.get('use_quantum', False):
        print("⚠️  Quantum layers requested but not available. Using classical model.")
    else:
        print("📱 Using classical model (quantum disabled in config)")

print("\n✅ Model setup complete!")