# 🌱 Capstone-Lazarus: Plant Disease Detection with PyTorch & Quantum Computing

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/MadScie254/Capstone-Lazarus/blob/main/notebooks/colab_training.ipynb)

**Fast, reproducible training on resource-limited hardware with Colab scaling capability**

## 🎯 Key Features

- **Transfer Learning**: EfficientNet, ResNet, MobileNet via `timm` library
- **Mixed Precision**: AMP for memory efficiency and speed
- **Quantum Computing**: Optional PennyLane integration (experimental)
- **Production Ready**: Checkpointing, early stopping, ONNX export
- **Resource Optimized**: HP ZBook G5 compatible (16GB RAM, Quadro P2000 with 4GB VRAM)

## 📋 Training Phases

1. **Quick Test** (1 epoch, 1k samples): Validate pipeline
2. **Development** (10 epochs, full dataset): Model selection
3. **Production** (50+ epochs): Final training with quantum experiments

---

## 🚀 Getting Started

Run all cells in order. Adjust the settings in the Section 1 configuration cell (or in an optional `config.yaml` at the project root) to match your hardware.

# 1. Environment Setup and Configuration

First, we'll install all required dependencies and set up the environment for both Colab and local execution.

In [None]:
# Check if running on Colab
# Detects the runtime, mounts Drive / clones the repo on Colab, resolves
# PROJECT_ROOT, and puts PROJECT_ROOT/src on sys.path for later imports.
import os
import sys
from pathlib import Path

# The google.colab module is only already imported when running on a Colab runtime.
IN_COLAB = 'google.colab' in sys.modules

print(f"Running in Colab: {IN_COLAB}")
print(f"Python version: {sys.version}")

# For Colab: Mount Google Drive
if IN_COLAB:
    from google.colab import drive
    drive.mount('/content/drive')
    
    # Clone repository if not exists
    # NOTE: the `!` line is IPython shell-escape syntax; it only runs inside a
    # notebook kernel, not as a plain Python script.
    if not os.path.exists('/content/Capstone-Lazarus'):
        !git clone https://github.com/MadScie254/Capstone-Lazarus.git /content/Capstone-Lazarus
    
    # Change to project directory
    os.chdir('/content/Capstone-Lazarus')
    PROJECT_ROOT = Path('/content/Capstone-Lazarus')
else:
    # Local execution
    # Start from the current working directory; if it has no src/ and is the
    # notebooks/ folder, step up one level to reach the repository root.
    PROJECT_ROOT = Path().resolve()
    if not (PROJECT_ROOT / 'src').exists():
        PROJECT_ROOT = PROJECT_ROOT.parent if PROJECT_ROOT.name == 'notebooks' else PROJECT_ROOT
    
print(f"Project root: {PROJECT_ROOT}")

# Add src to path
# Lets later cells try to import project modules (data_utils_torch,
# model_factory_torch, quantum_layer) — presumably located in src/; each of
# those cells falls back to an inline implementation on ImportError.
if str(PROJECT_ROOT / 'src') not in sys.path:
    sys.path.append(str(PROJECT_ROOT / 'src'))

In [None]:
# Install required packages
import subprocess
import importlib.util

# pip distribution names whose importable module name differs (keys lowercase).
IMPORT_NAME_OVERRIDES = {
    'opencv-python': 'cv2',
    'opencv-python-headless': 'cv2',
    'pillow': 'PIL',
    'pyyaml': 'yaml',
    'scikit-learn': 'sklearn',
}


def requirement_name(requirement):
    """Return the bare distribution name of a pip requirement string.

    Strips version specifiers, extras and markers, e.g.
    ``'torch>=2.0.0'`` -> ``'torch'``, ``'pkg[extra]==1.0'`` -> ``'pkg'``.
    """
    import re

    return re.split(r'[<>=!~;\[\s]', requirement.strip(), maxsplit=1)[0]


def requirement_to_module(requirement):
    """Return the top-level module name that ``requirement`` installs."""
    dist_name = requirement_name(requirement)
    return IMPORT_NAME_OVERRIDES.get(dist_name.lower(), dist_name.replace('-', '_'))


def install_if_missing(packages):
    """Install packages if they're not available.

    For each pip requirement string, check whether its module is importable
    and run ``pip install <requirement>`` only if it is not. Only presence is
    checked; an already-installed older version is NOT upgraded to satisfy a
    ``>=`` specifier.

    Args:
        packages: Iterable of pip requirement strings (e.g. ``'torch>=2.0.0'``).

    Raises:
        subprocess.CalledProcessError: If a pip install fails.
    """
    for package in packages:
        # find_spec needs the *module* name: passing the raw requirement
        # ('torch>=2.0.0') makes it try to import a parent package named
        # 'torch>=2.0' and raise ModuleNotFoundError.
        if importlib.util.find_spec(requirement_to_module(package)) is None:
            subprocess.check_call([sys.executable, '-m', 'pip', 'install', package])
            print(f"✅ Installed {package}")
        else:
            print(f"✅ {requirement_name(package)} already available")

# Requirement groups. The first three are mandatory (a pip failure stops the
# cell); optional packages are attempted one by one so a failure only warns.
core_packages = [
    'torch>=2.0.0',
    'torchvision>=0.15.0',
    'torchaudio>=2.0.0',
    'timm>=0.9.0',
    'torchmetrics>=0.11.0'
]

data_packages = [
    'albumentations>=1.3.0',
    'opencv-python>=4.8.0',
    'Pillow>=9.5.0'
]

viz_packages = [
    'matplotlib>=3.7.0',
    'seaborn>=0.12.0',
    'tqdm>=4.65.0',
    'pyyaml>=6.0',
    'scikit-learn>=1.3.0',
    'pandas>=2.0.0',
    'numpy>=1.24.0'
]

optional_packages = [
    'wandb',  # experiment tracking
    'pennylane>=0.32.0',  # quantum computing
    'onnx>=1.14.0',  # model export
    'onnxruntime>=1.15.0'  # ONNX inference
]

_required_groups = (
    ("📦 Installing core ML packages...", core_packages),
    ("\n📦 Installing data processing packages...", data_packages),
    ("\n📦 Installing visualization packages...", viz_packages),
)
for _banner, _group in _required_groups:
    print(_banner)
    install_if_missing(_group)

print("\n📦 Installing optional packages (may fail, that's OK)...")
for package in optional_packages:
    try:
        install_if_missing([package])
    except Exception as err:
        print(f"⚠️  Optional package {package} failed to install: {err}")

print("\n✅ Package installation complete!")

In [None]:
# Load configuration
# Builds the run config from defaults, merges an optional config.yaml,
# selects the device, prints a summary, and seeds all RNGs.
import random

import numpy as np
import torch
import yaml
from pathlib import Path

# Default configuration - optimized for HP ZBook G5 and Colab
default_config = {
    'seed': 42,
    'backbone': 'tf_efficientnet_b0',  # Fast and accurate
    'num_classes': 19,  # Adjust based on your dataset
    'image_size': 224,
    
    # Hardware optimizations
    'batch_size': 16 if not IN_COLAB else 32,  # ZBook P2000 has 4GB VRAM
    'num_workers': 2 if not IN_COLAB else 4,
    'pin_memory': True,
    
    # Training settings
    # NOTE(review): Colab gets fewer epochs (5) than local (30) — confirm this
    # is intended and not inverted.
    'epochs': 5 if IN_COLAB else 30,  # Start small for testing
    'learning_rate': 1e-3,
    'weight_decay': 1e-4,
    'optimizer': 'adamw',
    'scheduler': 'onecycle',
    
    # Performance features
    'use_amp': True,  # Automatic Mixed Precision
    'gradient_accumulation_steps': 2 if not IN_COLAB else 1,
    'use_ema': True,  # Exponential Moving Average
    'ema_decay': 0.9999,
    
    # Data augmentation
    'use_augmentations': True,
    'augmentation_strength': 'medium',
    'use_class_balancing': True,
    
    # Quantum computing (experimental)
    'use_quantum': False,  # Start with classical training
    'quantum': {
        'n_qubits': 4,
        'n_layers': 3,
        'embedding_dim': 4
    },
    
    # Checkpointing and monitoring
    'save_every': 10,
    'early_stopping_patience': 15,
    'use_wandb': False,  # Enable for experiment tracking
    
    # Paths
    'data_dir': str(PROJECT_ROOT / 'data'),
    'save_dir': str(PROJECT_ROOT / 'checkpoints'),
    
    # Testing
    'quick_test': True,  # Start with subset for validation
    'test_subset_size': 1000
}


def _merge_config(base, override):
    """Recursively merge ``override`` into ``base`` in place and return ``base``.

    Nested dicts (e.g. ``quantum``) are merged key by key, so a config.yaml
    that sets only ``quantum.n_qubits`` keeps the other quantum defaults.
    """
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(base.get(key), dict):
            _merge_config(base[key], value)
        else:
            base[key] = value
    return base


# Try to load custom config if it exists
config_path = PROJECT_ROOT / 'config.yaml'
if config_path.exists():
    with open(config_path, 'r', encoding='utf-8') as f:
        # safe_load returns None for an empty file; treat that as "no overrides".
        custom_config = yaml.safe_load(f) or {}
    if not isinstance(custom_config, dict):
        raise ValueError(
            f"{config_path} must contain a YAML mapping at the top level, "
            f"got {type(custom_config).__name__}"
        )
    _merge_config(default_config, custom_config)
    print(f"✅ Loaded custom config from {config_path}")
else:
    print(f"ℹ️  Using default config. Create {config_path} to customize.")

config = default_config

# Device setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
config['device'] = str(device)

print(f"\n🔧 Configuration:")
print(f"Device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
print(f"Backbone: {config['backbone']}")
print(f"Batch size: {config['batch_size']}")
print(f"Image size: {config['image_size']}")
print(f"Mixed Precision: {config['use_amp']}")
print(f"Quantum layers: {config['use_quantum']}")
print(f"Quick test mode: {config['quick_test']}")

# Set random seeds for reproducibility: the data utilities draw from Python's
# `random` and numpy as well as torch, so seed all of them (every GPU too).
random.seed(config['seed'])
np.random.seed(config['seed'])
torch.manual_seed(config['seed'])
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(config['seed'])

# 2. Data Loading and Preprocessing with Albumentations

Implement efficient data loading with a seeded train/validation split, optional class-balanced (inverse-frequency weighted) sampling, and Albumentations-based augmentations for plant disease classification.

In [None]:
# Import data utilities
# Prefer the project's data_utils_torch module; if it is not importable,
# define equivalent helpers inline (rest of this cell's except branch).
try:
    from data_utils_torch import make_dataloaders, create_subset_loader, analyze_dataset_distribution
    print("✅ Imported local data_utils_torch")
except ImportError:
    print("⚠️  Local data_utils_torch not found. Using inline implementation.")
    
    # Inline implementation for standalone notebook
    import torch
    from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
    from torchvision import transforms
    from torchvision.datasets import ImageFolder
    import numpy as np
    from PIL import Image
    import random
    
    # Albumentations is optional: ALBUMENTATIONS_AVAILABLE gates every use of
    # `A`/`ToTensorV2` below, falling back to torchvision transforms.
    try:
        import albumentations as A
        from albumentations.pytorch import ToTensorV2
        ALBUMENTATIONS_AVAILABLE = True
    except ImportError:
        ALBUMENTATIONS_AVAILABLE = False
        print("⚠️  Albumentations not available, using torchvision transforms")
    
    class PlantDiseaseDataset(Dataset):
        """ImageFolder-backed plant disease dataset with Albumentations support.

        Class names/indices and the ``(path, label)`` sample list come from
        ``torchvision.datasets.ImageFolder`` over ``root_dir`` (one
        sub-directory per class). Images are loaded as RGB.

        Args:
            root_dir: Directory containing one sub-directory per class.
            transform: When Albumentations is in effect, a callable invoked as
                ``transform(image=ndarray)['image']``; otherwise a
                torchvision-style callable taking a PIL image. ``None`` means
                plain ``ToTensor`` with no resizing or normalization.
            use_albumentations: Request the Albumentations calling convention;
                ignored when Albumentations is not installed.
        """

        def __init__(self, root_dir, transform=None, use_albumentations=True):
            self.root_dir = Path(root_dir)
            self.use_albumentations = use_albumentations and ALBUMENTATIONS_AVAILABLE
            self.transform = transform

            # Load dataset using ImageFolder for class mapping
            self.dataset = ImageFolder(str(root_dir))
            self.classes = self.dataset.classes
            self.class_to_idx = self.dataset.class_to_idx
            self.samples = self.dataset.samples

        def __len__(self):
            return len(self.samples)

        def __getitem__(self, idx):
            """Return ``(image, label)`` for sample ``idx``."""
            img_path, label = self.samples[idx]

            # Close the file handle deterministically; convert() loads the
            # pixels into a new image, so it stays valid after the file closes.
            with Image.open(img_path) as img:
                image = img.convert('RGB')

            # Compare against None explicitly: an Albumentations Compose
            # defines __len__, so a plain truthiness test depends on its length.
            if self.transform is None:
                return transforms.ToTensor()(image), label

            if self.use_albumentations:
                # Albumentations works on numpy arrays, not PIL images.
                return self.transform(image=np.array(image))['image'], label

            # Standard torchvision transforms
            return self.transform(image), label
    
    def get_albumentations_transforms(image_size=224, split="train", strength="medium"):
        """Get Albumentations transforms.

        Every split is resized to ``image_size`` x ``image_size``, normalized
        with ImageNet mean/std and converted to a CHW tensor. For
        ``split == "train"``, augmentations are inserted after the resize.

        Args:
            image_size: Output height and width in pixels.
            split: ``"train"`` enables augmentation; any other value does not.
            strength: ``"light"``, ``"medium"``; any other value selects the
                heavy augmentation set.

        Returns:
            An ``A.Compose`` pipeline, or the torchvision fallback from
            ``get_torchvision_transforms`` when Albumentations is unavailable.
        """

        if not ALBUMENTATIONS_AVAILABLE:
            return get_torchvision_transforms(image_size, split)

        # Resize / Normalize / ToTensorV2 all default to p=1.0, so they always
        # run. The old `always_apply=True` argument was redundant and is
        # deprecated in newer Albumentations releases.
        base_transforms = [A.Resize(image_size, image_size)]

        if split == "train":
            if strength == "medium":
                aug_transforms = [
                    A.HorizontalFlip(p=0.5),
                    A.VerticalFlip(p=0.2),
                    A.Rotate(limit=25, p=0.5),
                    A.RandomBrightnessContrast(
                        brightness_limit=0.2, contrast_limit=0.2, p=0.5
                    ),
                    A.HueSaturationValue(
                        hue_shift_limit=10, sat_shift_limit=20, val_shift_limit=10, p=0.3
                    ),
                ]
            elif strength == "light":
                aug_transforms = [
                    A.HorizontalFlip(p=0.5),
                    A.Rotate(limit=15, p=0.3),
                    A.RandomBrightnessContrast(
                        brightness_limit=0.1, contrast_limit=0.1, p=0.3
                    ),
                ]
            else:  # heavy (also any unrecognized strength value)
                aug_transforms = [
                    A.HorizontalFlip(p=0.5),
                    A.VerticalFlip(p=0.3),
                    A.Rotate(limit=35, p=0.6),
                    A.RandomBrightnessContrast(
                        brightness_limit=0.3, contrast_limit=0.3, p=0.6
                    ),
                    A.HueSaturationValue(
                        hue_shift_limit=20, sat_shift_limit=30, val_shift_limit=20, p=0.5
                    ),
                    A.OneOf([
                        A.ElasticTransform(p=0.3),
                        A.GridDistortion(p=0.3),
                        A.OpticalDistortion(p=0.3),
                    ], p=0.3),
                ]

            transforms_list = base_transforms + aug_transforms
        else:
            transforms_list = base_transforms

        # Add normalization and tensor conversion
        transforms_list.extend([
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ToTensorV2(),
        ])

        return A.Compose(transforms_list)
    
    def get_torchvision_transforms(image_size=224, split="train"):
        """Build the torchvision fallback pipeline for ``split``.

        All splits resize to ``image_size`` x ``image_size``, convert to a
        tensor and apply ImageNet mean/std normalization. The ``"train"``
        split additionally applies a random horizontal flip, a random
        rotation of up to 25 degrees and color jitter, before tensor
        conversion.
        """
        pipeline = [transforms.Resize((image_size, image_size))]

        if split == "train":
            pipeline += [
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomRotation(25),
                transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
            ]

        pipeline += [
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]

        return transforms.Compose(pipeline)
    
    def _dataset_labels(dataset):
        """Return the integer label of every sample in ``dataset``, in index order.

        Supports datasets exposing ImageFolder-style ``samples`` lists of
        ``(path, label)`` pairs, and (possibly nested)
        ``torch.utils.data.Subset`` wrappers around them.
        """
        if isinstance(dataset, torch.utils.data.Subset):
            parent_labels = _dataset_labels(dataset.dataset)
            return [parent_labels[i] for i in dataset.indices]
        return [sample[1] for sample in dataset.samples]

    def create_weighted_sampler(dataset):
        """Create weighted sampler for class balancing.

        Each sample gets weight ``N / (K * count[label])`` (inverse class
        frequency), where ``N`` is the number of samples and ``K`` is
        ``max(label) + 1``. Sampling is with replacement and draws ``N``
        samples per epoch.

        Args:
            dataset: A dataset with a ``samples`` list of ``(path, label)``
                pairs, or a ``Subset`` of one (only the subset's labels count).

        Returns:
            torch.utils.data.WeightedRandomSampler
        """
        labels = _dataset_labels(dataset)
        class_counts = np.bincount(labels)

        # Calculate weights (inverse frequency). Labels absent from this
        # dataset have a zero count; their infinite weight is never looked up,
        # so just silence numpy's divide-by-zero warning.
        with np.errstate(divide='ignore'):
            class_weights = len(labels) / (len(class_counts) * class_counts)

        # Assign weight to each sample
        sample_weights = [class_weights[label] for label in labels]

        return WeightedRandomSampler(
            weights=sample_weights,
            num_samples=len(sample_weights),
            replacement=True
        )
    
    def make_dataloaders(data_dir, config, train_split=0.8, val_split=0.2):
        """Create train and validation DataLoaders.

        If ``data_dir`` contains both ``train/`` and ``val/`` they are used
        as-is. Otherwise ``data_dir`` is treated as a single ImageFolder root
        and shuffled with a RNG seeded by ``config['seed']`` (default 42). The
        first ``train_split`` fraction becomes the training set and the rest
        becomes the validation set.

        Args:
            data_dir: Dataset root directory.
            config: Notebook config dict. Reads ``image_size``, ``batch_size``
                and optionally ``use_augmentations``, ``augmentation_strength``,
                ``use_class_balancing``, ``num_workers``, ``pin_memory``,
                ``seed``.
            train_split: Fraction of samples used for training when splitting.
            val_split: Unused; the validation set is whatever remains after
                ``train_split``. Kept for backward compatibility.

        Returns:
            ``(train_loader, val_loader)``.

        Raises:
            FileNotFoundError: If ``data_dir`` does not exist.
        """
        import copy

        data_path = Path(data_dir)
        if not data_path.exists():
            raise FileNotFoundError(f"Data directory not found: {data_dir}")

        # Get transforms
        image_size = config['image_size']
        if ALBUMENTATIONS_AVAILABLE and config.get('use_augmentations', True):
            train_transform = get_albumentations_transforms(
                image_size=image_size,
                split='train',
                strength=config.get('augmentation_strength', 'medium')
            )
            val_transform = get_albumentations_transforms(image_size=image_size, split='val')
            use_albu = True
        else:
            train_transform = get_torchvision_transforms(image_size=image_size, split='train')
            val_transform = get_torchvision_transforms(image_size=image_size, split='val')
            use_albu = False

        # Check if data is already split
        train_dir = data_path / "train"
        val_dir = data_path / "val"

        if train_dir.exists() and val_dir.exists():
            print("Using existing train/val split")
            train_dataset = PlantDiseaseDataset(train_dir, transform=train_transform, use_albumentations=use_albu)
            val_dataset = PlantDiseaseDataset(val_dir, transform=val_transform, use_albumentations=use_albu)
        else:
            print("Creating train/val split from single directory")
            full_dataset = PlantDiseaseDataset(data_path, transform=None, use_albumentations=use_albu)

            # Seeded local RNG: reproducible split without reseeding the
            # global `random` module as a side effect.
            indices = list(range(len(full_dataset)))
            random.Random(config.get('seed', 42)).shuffle(indices)
            train_size = int(train_split * len(indices))

            # Give each split its own shallow copy with its own `samples` and
            # `transform`. Wrapping one shared dataset in two Subsets and
            # setting `.dataset.transform` on each would leave BOTH using the
            # val transform (last write wins), silently disabling training
            # augmentation. Keeping a real `samples` list also lets the
            # class-balancing sampler read the split's own labels.
            train_dataset = copy.copy(full_dataset)
            train_dataset.samples = [full_dataset.samples[i] for i in indices[:train_size]]
            train_dataset.transform = train_transform

            val_dataset = copy.copy(full_dataset)
            val_dataset.samples = [full_dataset.samples[i] for i in indices[train_size:]]
            val_dataset.transform = val_transform

        # Create weighted sampler for training (a sampler excludes shuffle=True)
        if config.get('use_class_balancing', True):
            train_sampler = create_weighted_sampler(train_dataset)
            shuffle = False
        else:
            train_sampler = None
            shuffle = True

        num_workers = config.get('num_workers', 4)
        loader_kwargs = dict(
            batch_size=config['batch_size'],
            num_workers=num_workers,
            pin_memory=config.get('pin_memory', True),
            # persistent_workers is only valid with worker processes.
            persistent_workers=num_workers > 0,
        )

        train_loader = DataLoader(train_dataset, sampler=train_sampler, shuffle=shuffle, **loader_kwargs)
        val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

        print(f"Train loader: {len(train_loader)} batches ({len(train_dataset)} samples)")
        print(f"Val loader: {len(val_loader)} batches ({len(val_dataset)} samples)")

        return train_loader, val_loader
    
    def create_subset_loader(data_dir, config, subset_size=1000, split="train"):
        """Create a DataLoader with a subset of data for quick testing.

        Uses ``data_dir/<split>`` as the image root when that folder exists,
        otherwise ``data_dir`` itself. Samples come from one permutation
        seeded by ``config['seed']`` (default 42). ``split == "train"`` takes
        from the front of that permutation and any other split takes from the
        back. Subsets are therefore reproducible, and train/val subsets drawn
        from the same folder do not overlap as long as their sizes sum to at
        most the dataset size.

        Args:
            data_dir: Dataset root directory.
            config: Notebook config dict (``image_size``, ``batch_size``,
                optional ``use_augmentations``, ``num_workers``, ``seed``).
            subset_size: Maximum number of samples (clamped to the dataset size).
            split: ``"train"`` enables light augmentation and shuffling.

        Returns:
            torch.utils.data.DataLoader over the subset.
        """
        if ALBUMENTATIONS_AVAILABLE and config.get('use_augmentations', True):
            transform = get_albumentations_transforms(
                image_size=config['image_size'],
                split=split,
                strength='light'  # Light for quick testing
            )
            use_albu = True
        else:
            transform = get_torchvision_transforms(
                image_size=config['image_size'],
                split=split
            )
            use_albu = False

        # With a data_dir/train + data_dir/val layout, ImageFolder on the root
        # would treat 'train'/'val' themselves as classes; use the split folder.
        root = Path(data_dir)
        if (root / split).is_dir():
            root = root / split
        dataset = PlantDiseaseDataset(root, transform=transform, use_albumentations=use_albu)

        # Seeded, split-aware subset selection (see docstring).
        subset_size = max(0, min(subset_size, len(dataset)))
        order = list(range(len(dataset)))
        random.Random(config.get('seed', 42)).shuffle(order)
        if split == 'train':
            indices = order[:subset_size]
        else:
            indices = order[len(order) - subset_size:]
        subset_dataset = torch.utils.data.Subset(dataset, indices)

        # Create loader
        loader = DataLoader(
            subset_dataset,
            batch_size=config['batch_size'],
            shuffle=(split == 'train'),
            num_workers=config.get('num_workers', 2),
            pin_memory=False
        )

        print(f"Subset loader created: {len(loader)} batches ({subset_size} samples)")
        return loader
    
    def analyze_dataset_distribution(data_dir):
        """Analyze class distribution in dataset.

        If ``data_dir`` has both ``train/`` and ``val/`` sub-folders (the
        layout ``make_dataloaders`` accepts), the training split is analyzed.
        Otherwise ImageFolder would report 'train' and 'val' as the classes.

        Args:
            data_dir: Dataset root (str or Path).

        Returns:
            dict with ``total_samples``, ``num_classes``, ``class_names``,
            ``class_counts`` (class name -> count) and ``class_percentages``
            (class name -> percent of all samples).
        """
        from collections import Counter

        root = Path(data_dir)
        if (root / 'train').is_dir() and (root / 'val').is_dir():
            root = root / 'train'

        dataset = ImageFolder(str(root))

        # Count samples per class (Counter keeps first-seen order, i.e. the
        # sorted class order produced by ImageFolder).
        label_counts = Counter(class_idx for _, class_idx in dataset.samples)
        class_counts = {dataset.classes[idx]: count for idx, count in label_counts.items()}

        total_samples = len(dataset.samples)

        return {
            'total_samples': total_samples,
            'num_classes': len(dataset.classes),
            'class_names': dataset.classes,
            'class_counts': class_counts,
            'class_percentages': {
                name: (count / total_samples) * 100
                for name, count in class_counts.items()
            }
        }

# Runs whether the project module or the inline fallback above was used.
print("✅ Data utilities ready!")

In [None]:
# Analyze dataset distribution: list top-level folders under data_dir, print
# per-class sample counts, and sync config['num_classes'] with the data.
data_path = Path(config['data_dir'])
print(f"Looking for data in: {data_path}")

if not data_path.exists():
    print("⚠️  Data directory not found!")
    if IN_COLAB:
        setup_hints = (
            "📋 For Colab users:",
            "1. Upload your dataset to Google Drive",
            "2. Update the data_dir path in config",
            "3. Or use the sample dataset from the repository",
        )
    else:
        setup_hints = (
            "📋 For local users:",
            "1. Make sure your data is in the correct directory",
            "2. Update the data_dir path in config.yaml",
            "3. Expected structure: data_dir/class1/, data_dir/class2/, etc.",
        )
    for hint in setup_hints:
        print(hint)
else:
    subdirs = [entry for entry in data_path.iterdir() if entry.is_dir()]
    print(f"Found {len(subdirs)} subdirectories:")
    for folder in subdirs[:10]:  # only the first 10
        print(f"  📁 {folder.name}")
    hidden_dirs = len(subdirs) - 10
    if hidden_dirs > 0:
        print(f"  ... and {hidden_dirs} more")

    try:
        analysis = analyze_dataset_distribution(data_path)
        print(f"\n📊 Dataset Analysis:")
        print(f"Total samples: {analysis['total_samples']:,}")
        print(f"Number of classes: {analysis['num_classes']}")

        # Five largest classes first
        print(f"\nClass distribution:")
        largest_first = sorted(
            analysis['class_counts'].items(), key=lambda item: item[1], reverse=True
        )
        for cls_name, n_samples in largest_first[:5]:
            share = analysis['class_percentages'][cls_name]
            print(f"  {cls_name}: {n_samples:,} ({share:.1f}%)")

        unseen_classes = analysis['num_classes'] - 5
        if unseen_classes > 0:
            print(f"  ... and {unseen_classes} more classes")

        # Keep the model head consistent with the classes actually present
        config['num_classes'] = analysis['num_classes']

    except Exception as e:
        print(f"⚠️  Could not analyze dataset: {e}")
        print("This might be because the data structure is different than expected.")
        print("Expected structure: data_dir/class1/, data_dir/class2/, etc.")

# 3. Model Factory Implementation with TIMM Integration

Create models using the TIMM library with pretrained weights and optional quantum layer integration.

In [None]:
# Import model factory
# On success only `create_model` comes from the project module. The inline
# fallback below additionally defines BACKBONE_MODELS, PlantDiseaseModel and
# list_available_models.
try:
    from model_factory_torch import create_model
    print("✅ Imported local model_factory_torch")
except ImportError:
    print("⚠️  Local model_factory_torch not found. Using inline implementation.")
    
    # Inline model factory implementation
    import timm
    import torch.nn as nn
    
    # Available backbone models
    # Maps the short name used in config['backbone'] to a timm model
    # identifier ('timm_name') plus descriptive metadata. Image resolution
    # for the data pipeline comes from config['image_size'], not from
    # 'input_size' here.
    BACKBONE_MODELS = {
        'tf_efficientnet_b0': {
            'timm_name': 'tf_efficientnet_b0.aa_in1k',
            'description': 'EfficientNet-B0 (5.3M params) - Fast and accurate',
            'input_size': 224,
            'params': '5.3M'
        },
        'tf_efficientnet_b1': {
            'timm_name': 'tf_efficientnet_b1.aa_in1k',
            'description': 'EfficientNet-B1 (7.8M params) - Good balance',
            'input_size': 240,
            'params': '7.8M'
        },
        'resnet18': {
            'timm_name': 'resnet18.a1_in1k',
            'description': 'ResNet-18 (11.7M params) - Classic, reliable',
            'input_size': 224,
            'params': '11.7M'
        },
        'resnet34': {
            'timm_name': 'resnet34.a1_in1k',
            'description': 'ResNet-34 (21.8M params) - More capacity',
            'input_size': 224,
            'params': '21.8M'
        },
        'mobilenetv3_small_100': {
            'timm_name': 'mobilenetv3_small_100.lamb_in1k',
            'description': 'MobileNetV3-Small (2.5M params) - Very efficient',
            'input_size': 224,
            'params': '2.5M'
        },
        'mobilenetv3_large_100': {
            'timm_name': 'mobilenetv3_large_100.ra_in1k',
            'description': 'MobileNetV3-Large (5.5M params) - Efficient',
            'input_size': 224,
            'params': '5.5M'
        }
    }
    
    class PlantDiseaseModel(nn.Module):
        """Plant Disease Classification Model with optional quantum layer.

        A timm backbone created with ``num_classes=0`` (classifier removed,
        so it returns one feature vector per image) followed by either
        dropout + linear head, or a quantum-classical hybrid head.

        Args:
            backbone_name: Key into ``BACKBONE_MODELS``.
            num_classes: Number of output logits.
            pretrained: Load pretrained timm weights for the backbone.
            quantum_layer: Optional module exposing ``embedding_dim`` and
                ``n_qubits``. When given, features are projected to
                ``embedding_dim``, passed through it, and its ``n_qubits``
                outputs are mapped to class logits.

        Raises:
            ValueError: If ``backbone_name`` is not in ``BACKBONE_MODELS``.
        """

        def __init__(self, backbone_name, num_classes, pretrained=True, quantum_layer=None):
            super().__init__()

            if backbone_name not in BACKBONE_MODELS:
                raise ValueError(f"Unknown backbone: {backbone_name}. Available: {list(BACKBONE_MODELS.keys())}")

            timm_name = BACKBONE_MODELS[backbone_name]['timm_name']

            # Create backbone model
            self.backbone = timm.create_model(
                timm_name,
                pretrained=pretrained,
                num_classes=0  # Remove classifier head
            )

            self.feature_dim = self._infer_feature_dim(BACKBONE_MODELS[backbone_name]['input_size'])

            # Create classifier head
            if quantum_layer is not None:
                # Quantum-classical hybrid
                self.classifier = nn.Sequential(
                    nn.Linear(self.feature_dim, quantum_layer.embedding_dim),
                    nn.ReLU(),
                    nn.Dropout(0.3),
                    quantum_layer,
                    nn.Linear(quantum_layer.n_qubits, num_classes)  # Quantum outputs to classes
                )
            else:
                # Standard classifier
                self.classifier = nn.Sequential(
                    nn.Dropout(0.3),
                    nn.Linear(self.feature_dim, num_classes)
                )

        def _infer_feature_dim(self, input_size):
            """Return the width of the backbone's output feature vector.

            Probes with one dummy image in eval mode. In train mode the probe
            would update the pretrained BatchNorm running statistics with
            random noise. The previous train/eval mode is restored afterwards.
            """
            was_training = self.backbone.training
            self.backbone.eval()
            try:
                with torch.no_grad():
                    features = self.backbone(torch.randn(1, 3, input_size, input_size))
            finally:
                self.backbone.train(was_training)
            return features.shape[1]

        def forward(self, x):
            features = self.backbone(x)
            return self.classifier(features)

    def create_model(config, quantum_layer=None):
        """Create a model based on configuration.

        Reads ``config['backbone']`` and ``config['num_classes']``, plus the
        optional ``config['pretrained']`` (default True, as before).

        Returns:
            PlantDiseaseModel
        """
        return PlantDiseaseModel(
            backbone_name=config['backbone'],
            num_classes=config['num_classes'],
            pretrained=config.get('pretrained', True),
            quantum_layer=quantum_layer
        )
    
    def list_available_models():
        """List all available backbone models.

        Prints each ``BACKBONE_MODELS`` entry with its description, parameter
        count and nominal input size.
        """
        print("📱 Available Backbone Models:")
        for name, info in BACKBONE_MODELS.items():
            print(f"  {name}:")
            print(f"    - {info['description']}")
            print(f"    - Parameters: {info['params']}")
            print(f"    - Input size: {info['input_size']}")
            print()

# Runs whether the project module or the inline fallback above was used.
print("✅ Model factory ready!")

In [None]:
# Show available models and create model
# list_available_models is only defined by the inline model-factory fallback;
# calling it unconditionally raised NameError when model_factory_torch imported.
if 'list_available_models' in globals():
    list_available_models()

# Create model (without quantum layer for now)
print(f"🏗️  Creating model: {config['backbone']}")
model = create_model(config, quantum_layer=None)

# Move to device
model = model.to(device)

# Model summary
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"\n📊 Model Summary:")
print(f"Backbone: {config['backbone']}")
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
# Estimate assumes 4 bytes per parameter (fp32 weights).
print(f"Model size (estimated): {total_params * 4 / (1024**2):.1f} MB")

# Smoke-test with a dummy input in eval mode, so Dropout is off and
# BatchNorm running statistics are not updated by random noise. Then
# restore train mode.
model.eval()
with torch.no_grad():
    dummy_input = torch.randn(1, 3, config['image_size'], config['image_size'], device=device)
    output = model(dummy_input)
    print(f"Output shape: {output.shape}")
    print(f"Expected shape: [batch_size, {config['num_classes']}]")
model.train()

print("✅ Model created successfully!")

# 4. Optional Quantum Layer Module (PennyLane)

Implement quantum neural network layers for experimental hybrid quantum-classical models.

In [None]:
# Quantum layer implementation (optional)
# Use the project's quantum_layer module if present, else build an inline
# PennyLane implementation. If config['use_quantum'] is set, smoke-test it and
# swap `model` for a quantum-classical hybrid.
quantum_available = False

try:
    from quantum_layer import QuantumLayer, HybridQuantumClassifier
    quantum_available = True
    print("✅ Imported local quantum_layer")
except ImportError:
    print("ℹ️  Local quantum_layer not found. Trying inline implementation.")

    try:
        import pennylane as qml
        import torch.nn as nn

        # NOTE: pennylane.numpy is deliberately not imported as `np`; that
        # would shadow the regular numpy alias used elsewhere in the notebook.

        class QuantumLayer(nn.Module):
            """Quantum Neural Network Layer using PennyLane.

            Inputs of width ``embedding_dim`` are linearly projected to
            ``n_qubits`` angles (identity when the two widths are equal). They
            are angle-embedded with RY rotations and passed through
            ``n_layers`` trainable RY/RZ layers, each followed by a CNOT chain
            (closed into a ring when ``n_qubits > 2``). The output is one
            Pauli-Z expectation per qubit, so the output width is ``n_qubits``.

            Args:
                n_qubits: Number of circuit wires (= output width).
                n_layers: Number of variational layers.
                embedding_dim: Expected width of the input features.
            """

            def __init__(self, n_qubits=4, n_layers=3, embedding_dim=4):
                super().__init__()
                self.n_qubits = n_qubits
                self.n_layers = n_layers
                self.embedding_dim = embedding_dim

                # Register the projection up front so its parameters are seen
                # by the optimizer, .to(device) and state_dict. The old version
                # created it lazily inside forward(), which none of those see.
                if embedding_dim != n_qubits:
                    self.projection = nn.Linear(embedding_dim, n_qubits)
                else:
                    self.projection = nn.Identity()

                # Create quantum device
                self.dev = qml.device('default.qubit', wires=n_qubits)

                # Define quantum circuit
                @qml.qnode(self.dev, interface='torch')
                def circuit(inputs, weights):
                    # AngleEmbedding reads features from the last axis. It is
                    # correct both for a single sample and for a broadcast
                    # batch passed in by TorchLayer, whereas indexing
                    # `inputs[i]` would select batch rows.
                    qml.AngleEmbedding(inputs, wires=range(n_qubits), rotation='Y')

                    # Variational layers
                    for layer in range(n_layers):
                        for i in range(n_qubits):
                            qml.RY(weights[layer, i, 0], wires=i)
                            qml.RZ(weights[layer, i, 1], wires=i)

                        # Entangling gates
                        for i in range(n_qubits - 1):
                            qml.CNOT(wires=[i, i + 1])
                        if n_qubits > 2:
                            qml.CNOT(wires=[n_qubits - 1, 0])  # Wrap around

                    # Measure expectations
                    return [qml.expval(qml.PauliZ(i)) for i in range(n_qubits)]

                # Create weight tensor
                weight_shapes = {"weights": (n_layers, n_qubits, 2)}
                self.qlayer = qml.qnn.TorchLayer(circuit, weight_shapes)

            def forward(self, x):
                """Map ``(..., embedding_dim)`` features to ``(..., n_qubits)`` expectations."""
                if x.shape[-1] != self.embedding_dim:
                    raise ValueError(
                        f"QuantumLayer expected last dimension {self.embedding_dim}, "
                        f"got {x.shape[-1]}"
                    )
                return self.qlayer(self.projection(x))

        quantum_available = True
        print("✅ Created inline quantum implementation")

    except ImportError as e:
        print(f"⚠️  PennyLane not available: {e}")
        print("Quantum layers will be disabled.")
        quantum_available = False

# Test quantum layer if available and enabled
if quantum_available and config.get('use_quantum', False):
    print("\n🔬 Testing Quantum Layer:")

    # Create quantum layer
    quantum_config = config['quantum']
    quantum_layer = QuantumLayer(
        n_qubits=quantum_config['n_qubits'],
        n_layers=quantum_config['n_layers'],
        embedding_dim=quantum_config['embedding_dim']
    )

    # Test with dummy input
    test_input = torch.randn(2, quantum_config['embedding_dim'])  # Batch size 2
    with torch.no_grad():
        quantum_output = quantum_layer(test_input)
        print(f"Quantum layer input shape: {test_input.shape}")
        print(f"Quantum layer output shape: {quantum_output.shape}")

    print("✅ Quantum layer test successful!")
    print("\n⚠️  Note: Quantum simulation is slow on CPU. Use small datasets for testing.")

    # Create model with quantum layer
    print("\n🏗️  Creating quantum-classical hybrid model...")
    quantum_model = create_model(config, quantum_layer=quantum_layer)
    quantum_model = quantum_model.to(device)

    # Test quantum model in eval mode (no Dropout, no BatchNorm stat updates),
    # then restore train mode.
    quantum_model.eval()
    with torch.no_grad():
        dummy_input = torch.randn(1, 3, config['image_size'], config['image_size']).to(device)
        quantum_output = quantum_model(dummy_input)
        print(f"Quantum model output shape: {quantum_output.shape}")
    quantum_model.train()

    print("✅ Quantum model created successfully!")

    # Ask user if they want to use quantum model
    use_quantum_model = True  # Set to False for classical training
    if use_quantum_model:
        model = quantum_model
        print("🚀 Using quantum-classical hybrid model for training")
    else:
        print("📱 Sticking with classical model for faster training")

else:
    if config.get('use_quantum', False):
        print("⚠️  Quantum layers requested but not available. Using classical model.")
    else:
        print("📱 Using classical model (quantum disabled in config)")

print("\n✅ Model setup complete!")

# 5. Training Loop with AMP and Advanced Features

Build the data loaders and run the training pipeline. The pipeline uses mixed-precision (AMP) training, a one-cycle learning-rate schedule, and best-model checkpointing.

In [None]:
# Create data loaders
print("📊 Creating data loaders...")

if Path(config['data_dir']).exists():
    if config.get('quick_test', True):
        # Quick test with subset
        print(f"🚀 Quick test mode: using {config['test_subset_size']} samples")
        train_loader = create_subset_loader(
            config['data_dir'],
            config,
            subset_size=config['test_subset_size'],
            split='train'
        )
        # Validation subset is a quarter of the training subset, floored at 1 so
        # a tiny test_subset_size cannot produce an empty validation loader
        # (the trainer averages validation loss over len(val_loader)).
        val_loader = create_subset_loader(
            config['data_dir'],
            config,
            subset_size=max(1, config['test_subset_size'] // 4),
            split='val'
        )
    else:
        # Full dataset training
        print("🎯 Full training mode")
        train_loader, val_loader = make_dataloaders(
            config['data_dir'],
            config
        )

    print("✅ Data loaders created:")
    print(f"  - Train: {len(train_loader)} batches")
    print(f"  - Validation: {len(val_loader)} batches")
    data_available = True
else:
    print("⚠️  Data directory not found. Skipping training for now.")
    train_loader = None
    val_loader = None
    data_available = False

# Smoke-test a single batch: tensor shapes, value range, labels, and one
# forward pass through the current model.
if data_available:
    print("\n🔍 Testing data loading...")
    try:
        first_batch = next(iter(train_loader), None)
        if first_batch is None:
            # An empty loader must count as a failure, not a silent pass.
            raise RuntimeError("training loader yielded no batches")
        images, labels = first_batch

        print("Batch 1:")
        print(f"  Images shape: {images.shape}")
        print(f"  Labels shape: {labels.shape}")
        print(f"  Image range: [{images.min():.3f}, {images.max():.3f}]")
        print(f"  Unique labels: {torch.unique(labels).tolist()}")

        # Test model forward pass
        model.eval()
        with torch.no_grad():
            images = images.to(device)
            outputs = model(images)
            print(f"  Model output shape: {outputs.shape}")
            print(f"  Output range: [{outputs.min():.3f}, {outputs.max():.3f}]")

        print("✅ Data loading test successful!")

    except Exception as e:
        print(f"❌ Data loading test failed: {e}")
        data_available = False

In [None]:
# Import training utilities
try:
    from training_torch import Trainer
    print("✅ Imported local training_torch")
except ImportError:
    print("ℹ️  Local training_torch not found. Using simplified inline implementation.")

    # Simplified training implementation for notebook
    import time
    from pathlib import Path

    import torch
    import torch.nn as nn
    import torch.optim as optim
    # accuracy_score / classification_report are kept so these names remain
    # available to other cells, even though SimpleTrainer counts accuracy itself.
    from sklearn.metrics import accuracy_score, classification_report
    from torch.cuda.amp import GradScaler, autocast
    from torch.optim.lr_scheduler import OneCycleLR
    from tqdm.auto import tqdm

    class SimpleTrainer:
        """Simplified trainer for notebook use.

        Trains with AdamW and a per-batch OneCycleLR schedule, optionally
        under CUDA mixed precision. Per-epoch loss/accuracy go to
        `self.history`, per-step learning rates to `self.lr_history`, and the
        best-validation-accuracy checkpoint is written to
        `save_dir / 'best_model.pth'`.
        """

        def __init__(self, model, train_loader, val_loader, config, save_dir="./checkpoints"):
            self.train_loader = train_loader
            self.val_loader = val_loader
            self.config = config
            self.save_dir = Path(save_dir)
            self.save_dir.mkdir(parents=True, exist_ok=True)

            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            # Batches are sent to self.device, so the model must live there too.
            # nn.Module.to() is in-place, so callers' references stay valid.
            self.model = model.to(self.device)

            # Mixed precision is only enabled on CUDA.
            self.use_amp = config.get('use_amp', True) and torch.cuda.is_available()
            if self.use_amp:
                self.scaler = GradScaler()

            self.optimizer = optim.AdamW(
                self.model.parameters(),
                lr=config['learning_rate'],
                weight_decay=config.get('weight_decay', 1e-4)
            )

            # Same default as train(), so the schedule matches the run length.
            self._scheduler_epochs = config.get('epochs', 5)
            self.scheduler = self._build_scheduler(self._scheduler_epochs)

            self.criterion = nn.CrossEntropyLoss()

            # Per-epoch metrics (keys are read by the evaluation cell).
            self.history = {
                'train_loss': [],
                'train_acc': [],
                'val_loss': [],
                'val_acc': []
            }
            # Learning rate after every scheduler step.
            self.lr_history = []

            # Best score tracking
            self.best_val_acc = 0.0

        def _build_scheduler(self, epochs):
            """Return a OneCycleLR spanning `epochs` epochs of train_loader."""
            return OneCycleLR(
                self.optimizer,
                max_lr=self.config['learning_rate'],
                epochs=epochs,
                steps_per_epoch=len(self.train_loader)
            )

        def _forward(self, data, target):
            """Compute (output, loss), under autocast when AMP is enabled."""
            if self.use_amp:
                with autocast():
                    output = self.model(data)
                    return output, self.criterion(output, target)
            output = self.model(data)
            return output, self.criterion(output, target)

        def train_epoch(self, epoch):
            """Run one training epoch; return (mean batch loss, accuracy)."""
            self.model.train()
            total_loss = 0.0
            correct = 0
            seen = 0

            pbar = tqdm(self.train_loader, desc=f"Epoch {epoch+1} [Train]")

            for data, target in pbar:
                data, target = data.to(self.device), target.to(self.device)

                self.optimizer.zero_grad()
                output, loss = self._forward(data, target)

                if self.use_amp:
                    self.scaler.scale(loss).backward()
                    self.scaler.step(self.optimizer)
                    self.scaler.update()
                else:
                    loss.backward()
                    self.optimizer.step()

                self.scheduler.step()
                current_lr = self.optimizer.param_groups[0]['lr']
                self.lr_history.append(current_lr)

                # Statistics: 1-D predictions compared directly on device.
                batch_loss = loss.item()
                total_loss += batch_loss
                correct += (output.argmax(dim=1) == target).sum().item()
                seen += target.size(0)

                # Update progress
                pbar.set_postfix({
                    'Loss': f"{batch_loss:.4f}",
                    'LR': f"{current_lr:.2e}"
                })

            avg_loss = total_loss / max(len(self.train_loader), 1)
            accuracy = correct / max(seen, 1)

            return avg_loss, accuracy

        def validate(self, epoch):
            """Evaluate on val_loader; return (mean batch loss, accuracy)."""
            self.model.eval()
            total_loss = 0.0
            correct = 0
            seen = 0

            with torch.no_grad():
                pbar = tqdm(self.val_loader, desc=f"Epoch {epoch+1} [Val]")

                for data, target in pbar:
                    data, target = data.to(self.device), target.to(self.device)
                    output, loss = self._forward(data, target)

                    batch_loss = loss.item()
                    total_loss += batch_loss
                    correct += (output.argmax(dim=1) == target).sum().item()
                    seen += target.size(0)

                    pbar.set_postfix({'Loss': f"{batch_loss:.4f}"})

            avg_loss = total_loss / max(len(self.val_loader), 1)
            accuracy = correct / max(seen, 1)

            return avg_loss, accuracy

        def train(self, epochs=None):
            """Train for `epochs` epochs (default: config['epochs'], else 5).

            Returns:
                The history dict of per-epoch 'train_loss', 'train_acc',
                'val_loss' and 'val_acc' lists.
            """
            if epochs is None:
                epochs = self.config.get('epochs', 5)

            # OneCycleLR raises once stepped past its total step count. Rebuild
            # it when this run's length differs from the one it was built for,
            # or when it has already been stepped (a repeated train() call).
            if epochs > 0 and (epochs != self._scheduler_epochs
                               or self.scheduler.last_epoch > 0):
                self.scheduler = self._build_scheduler(epochs)
                self._scheduler_epochs = epochs

            print(f"🚀 Starting training for {epochs} epochs...")
            print(f"Device: {self.device}")
            print(f"Mixed Precision: {self.use_amp}")

            for epoch in range(epochs):
                start_time = time.time()

                train_loss, train_acc = self.train_epoch(epoch)
                val_loss, val_acc = self.validate(epoch)

                self.history['train_loss'].append(train_loss)
                self.history['train_acc'].append(train_acc)
                self.history['val_loss'].append(val_loss)
                self.history['val_acc'].append(val_acc)

                # Save best model
                if val_acc > self.best_val_acc:
                    self.best_val_acc = val_acc
                    torch.save({
                        'epoch': epoch,
                        'model_state_dict': self.model.state_dict(),
                        'optimizer_state_dict': self.optimizer.state_dict(),
                        'val_acc': val_acc,
                        'config': self.config
                    }, self.save_dir / 'best_model.pth')

                epoch_time = time.time() - start_time

                print(f"Epoch {epoch+1:2d}/{epochs}: "
                      f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, "
                      f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}, "
                      f"Time: {epoch_time:.1f}s")

            print(f"\n✅ Training completed!")
            print(f"Best validation accuracy: {self.best_val_acc:.4f}")

            return self.history

    # Use simple trainer
    Trainer = SimpleTrainer

print("✅ Training utilities ready!")

In [None]:
# Start training if data is available
# Reset first so a failed or skipped run never leaves a stale True from an
# earlier execution of this cell (later cells gate on this flag).
training_completed = False

if data_available:
    print("🎯 Initializing trainer...")

    # Create trainer
    trainer = Trainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        config=config,
        save_dir=PROJECT_ROOT / 'checkpoints'
    )

    # Start training
    print("\n🚀 Starting training...")
    print("Configuration:")
    print(f"  - Epochs: {config['epochs']}")
    print(f"  - Batch size: {config['batch_size']}")
    print(f"  - Learning rate: {config['learning_rate']}")
    # use_amp / quick_test are optional: read them with the same defaults the
    # data-loader cell and SimpleTrainer use, instead of raising KeyError.
    print(f"  - Mixed precision: {config.get('use_amp', True)}")
    print(f"  - Quick test: {config.get('quick_test', True)}")

    if config.get('quick_test', True):
        print("\n⚡ Running quick test (subset training)")
        print("Set config['quick_test'] = False for full training")

    # Run training
    history = trainer.train()

    # Training completed - show results in next section
    training_completed = True

else:
    print("⚠️  Skipping training - no data available")
    print("\n📋 To run training:")
    print("1. Make sure your data is in the correct directory structure")
    print("2. Update config['data_dir'] if needed")
    print("3. Re-run this cell")

# 6. Model Evaluation and Metrics

Plot the training history and summarize the trained model's accuracy, loss, and size.

In [None]:
# Plot training history and evaluate model
# The flag is read via globals() so running this cell before the training cell
# reports "no results" instead of raising NameError. An empty history (e.g.
# epochs == 0) is treated the same way, because max() below needs data.
if globals().get('training_completed', False) and history['train_loss']:
    import matplotlib.pyplot as plt
    import seaborn as sns

    # Set style
    plt.style.use('default')
    sns.set_palette("husl")

    # Create training history plots
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))

    # Loss plot
    axes[0, 0].plot(history['train_loss'], label='Train Loss', linewidth=2)
    axes[0, 0].plot(history['val_loss'], label='Validation Loss', linewidth=2)
    axes[0, 0].set_title('Training and Validation Loss', fontsize=14, fontweight='bold')
    axes[0, 0].set_xlabel('Epoch')
    axes[0, 0].set_ylabel('Loss')
    axes[0, 0].legend()
    axes[0, 0].grid(True, alpha=0.3)

    # Accuracy plot
    axes[0, 1].plot(history['train_acc'], label='Train Accuracy', linewidth=2)
    axes[0, 1].plot(history['val_acc'], label='Validation Accuracy', linewidth=2)
    axes[0, 1].set_title('Training and Validation Accuracy', fontsize=14, fontweight='bold')
    axes[0, 1].set_xlabel('Epoch')
    axes[0, 1].set_ylabel('Accuracy')
    axes[0, 1].legend()
    axes[0, 1].grid(True, alpha=0.3)

    # Learning rate plot (only if the trainer recorded per-step LRs)
    if hasattr(trainer, 'lr_history') and trainer.lr_history:
        axes[1, 0].plot(trainer.lr_history)
        axes[1, 0].set_title('Learning Rate Schedule', fontsize=14, fontweight='bold')
        axes[1, 0].set_xlabel('Step')
        axes[1, 0].set_ylabel('Learning Rate')
        axes[1, 0].grid(True, alpha=0.3)
    else:
        axes[1, 0].text(0.5, 0.5, 'Learning Rate History\nNot Available',
                        ha='center', va='center', transform=axes[1, 0].transAxes,
                        fontsize=12, style='italic')
        axes[1, 0].set_title('Learning Rate Schedule', fontsize=14, fontweight='bold')

    # Training summary
    best_train_acc = max(history['train_acc'])
    best_val_acc = max(history['val_acc'])
    final_train_loss = history['train_loss'][-1]
    final_val_loss = history['val_loss'][-1]

    summary_text = f"""Training Summary:
    
    Best Train Accuracy: {best_train_acc:.4f}
    Best Val Accuracy: {best_val_acc:.4f}
    Final Train Loss: {final_train_loss:.4f}
    Final Val Loss: {final_val_loss:.4f}
    
    Total Epochs: {len(history['train_loss'])}
    Model: {config['backbone']}
    Batch Size: {config['batch_size']}
    Learning Rate: {config['learning_rate']}"""

    axes[1, 1].text(0.05, 0.95, summary_text, transform=axes[1, 1].transAxes,
                    fontsize=11, verticalalignment='top', fontfamily='monospace',
                    bbox=dict(boxstyle='round', facecolor='lightblue', alpha=0.8))
    axes[1, 1].set_xlim(0, 1)
    axes[1, 1].set_ylim(0, 1)
    axes[1, 1].axis('off')
    axes[1, 1].set_title('Training Summary', fontsize=14, fontweight='bold')

    plt.tight_layout()
    plt.show()

    # Print detailed results
    print(f"\n📊 Training Results Summary:")
    print(f"{'='*50}")
    print(f"Best Training Accuracy: {best_train_acc:.4f} ({best_train_acc*100:.2f}%)")
    print(f"Best Validation Accuracy: {best_val_acc:.4f} ({best_val_acc*100:.2f}%)")
    print(f"Final Training Loss: {final_train_loss:.4f}")
    print(f"Final Validation Loss: {final_val_loss:.4f}")

    # Check for overfitting (compares best values, which may come from
    # different epochs)
    gap = best_train_acc - best_val_acc
    if gap > 0.1:
        print(f"\n⚠️  Potential overfitting detected (gap: {gap:.4f})")
        print("Consider: reducing learning rate, adding regularization, or more data")
    elif gap < 0:
        print(f"\n✅ Model is generalizing well (val > train by {-gap:.4f})")
    else:
        print(f"\n✅ Good balance between training and validation performance")

    # Model complexity analysis
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    model_size_mb = total_params * 4 / (1024**2)  # Assuming float32

    print(f"\n🏗️  Model Architecture:")
    print(f"Total Parameters: {total_params:,}")
    print(f"Trainable Parameters: {trainable_params:,}")
    print(f"Estimated Model Size: {model_size_mb:.1f} MB")

    if model_size_mb > 50:
        print(f"⚠️  Model size > 50MB. Consider quantization for deployment.")
    else:
        print(f"✅ Model size suitable for mobile deployment.")

else:
    print("⚠️  No training results to display. Run training first.")

# 7. Inference Pipeline and Model Export

Export the trained model for production deployment. The exports are a PyTorch checkpoint, an ONNX graph, a dynamically quantized (int8) variant, and an inference-script template.

In [None]:
# Export trained model for deployment
# The flag is read via globals() so this cell reports instead of raising
# NameError when the training cell has not been run.
if globals().get('training_completed', False):
    print("📦 Preparing model for deployment...")

    # Load best model
    model.eval()
    checkpoint_path = trainer.save_dir / 'best_model.pth'

    if checkpoint_path.exists():
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        print(f"✅ Loaded best model (Val Acc: {checkpoint['val_acc']:.4f})")
    else:
        print("ℹ️  Using current model state (no checkpoint found)")

    # Create export directory
    export_dir = PROJECT_ROOT / 'exports'
    export_dir.mkdir(parents=True, exist_ok=True)

    # Class names: a subset loader's dataset may be a wrapper (such as
    # torch.utils.data.Subset) that keeps the real dataset in `.dataset`,
    # so fall back one level when `.classes` is missing.
    _export_dataset = train_loader.dataset
    _export_class_names = getattr(_export_dataset, 'classes', None)
    if _export_class_names is None:
        _export_class_names = getattr(getattr(_export_dataset, 'dataset', None), 'classes', None)

    # 1. Export PyTorch model
    torch_model_path = export_dir / f"plant_disease_{config['backbone']}_pytorch.pth"
    torch.save({
        'model_state_dict': model.state_dict(),
        'config': config,
        'class_names': _export_class_names,
        'model_info': {
            'backbone': config['backbone'],
            'num_classes': config['num_classes'],
            'image_size': config['image_size']
        }
    }, torch_model_path)

    model_size_mb = torch_model_path.stat().st_size / (1024**2)
    print(f"✅ PyTorch model saved: {torch_model_path.name} ({model_size_mb:.1f} MB)")

    # 2. Export to ONNX (if available)
    try:
        import onnx

        onnx_model_path = export_dir / f"plant_disease_{config['backbone']}_onnx.onnx"

        # Create dummy input
        dummy_input = torch.randn(1, 3, config['image_size'], config['image_size']).to(device)

        # Export to ONNX with a dynamic batch dimension
        torch.onnx.export(
            model,
            dummy_input,
            str(onnx_model_path),
            export_params=True,
            opset_version=11,
            do_constant_folding=True,
            input_names=['input'],
            output_names=['output'],
            dynamic_axes={
                'input': {0: 'batch_size'},
                'output': {0: 'batch_size'}
            }
        )

        onnx_size_mb = onnx_model_path.stat().st_size / (1024**2)
        print(f"✅ ONNX model saved: {onnx_model_path.name} ({onnx_size_mb:.1f} MB)")

        # Verify ONNX model
        onnx_model = onnx.load(str(onnx_model_path))
        onnx.checker.check_model(onnx_model)
        print(f"✅ ONNX model verified successfully")

    except ImportError:
        print("⚠️  ONNX not available. Skipping ONNX export.")
    except Exception as e:
        print(f"❌ ONNX export failed: {e}")

    # 3. Model quantization (experimental)
    try:
        print("\n🔬 Attempting model quantization...")

        # quantize_dynamic copies its input (inplace=False by default), but
        # model.cpu() moves the live model; it is moved back in `finally`
        # so a failed quantization cannot leave it stranded on the CPU.
        model.eval()
        model_quant = torch.quantization.quantize_dynamic(
            model.cpu(),
            {torch.nn.Linear},
            dtype=torch.qint8
        )

        # Save quantized model
        quant_model_path = export_dir / f"plant_disease_{config['backbone']}_quantized.pth"
        torch.save({
            'model_state_dict': model_quant.state_dict(),
            'config': config,
            'quantized': True
        }, quant_model_path)

        quant_size_mb = quant_model_path.stat().st_size / (1024**2)
        print(f"✅ Quantized model saved: {quant_model_path.name} ({quant_size_mb:.1f} MB)")
        print(f"📊 Compression ratio: {model_size_mb/quant_size_mb:.2f}x smaller")

    except Exception as e:
        print(f"⚠️  Quantization failed: {e}")
    finally:
        # Move model back to device
        model = model.to(device)

    # 4. Create inference function
    def create_inference_pipeline(model_path, config):
        """Create a standalone inference pipeline from an exported checkpoint.

        Rebuilds the network with the notebook's `create_model` factory, loads
        the saved weights on CPU, and returns `predict(image_path)`, which
        yields `(predicted_class, confidence)`.
        """
        import numpy as np
        import torch
        import torch.nn.functional as F
        from PIL import Image

        # Load model on CPU so predict() needs no device handling.
        checkpoint = torch.load(model_path, map_location='cpu')
        # NOTE(review): assumes the weights come from create_model(config) with
        # no quantum_layer; a quantum-hybrid export needs that layer passed too.
        pipeline_model = create_model(checkpoint['config'])
        pipeline_model.load_state_dict(checkpoint['model_state_dict'])
        pipeline_model.eval()

        def preprocess_image(image_path, image_size=224):
            """Resize, scale to [0, 1], ImageNet-normalize, and add a batch dim."""
            image = Image.open(image_path).convert('RGB')
            image = image.resize((image_size, image_size))
            image = np.array(image) / 255.0
            image = (image - np.array([0.485, 0.456, 0.406])) / np.array([0.229, 0.224, 0.225])
            image = torch.tensor(image, dtype=torch.float32).permute(2, 0, 1).unsqueeze(0)
            return image

        def predict(image_path):
            """Return (predicted_class, confidence) for one image file."""
            image = preprocess_image(image_path, config['image_size'])

            with torch.no_grad():
                output = pipeline_model(image)
                probabilities = F.softmax(output, dim=1)
                predicted_class = torch.argmax(probabilities, dim=1).item()
                confidence = probabilities[0][predicted_class].item()

            return predicted_class, confidence

        return predict

    # Save inference code template.
    # The inner docstrings are written as \"\"\" so they do not terminate this
    # triple-quoted f-string; the generated file contains plain """ docstrings.
    inference_code = f"""# Plant Disease Inference Pipeline
# Generated from Capstone-Lazarus training notebook

import torch
import torch.nn.functional as F
from PIL import Image
import numpy as np

# Model configuration
CONFIG = {config}

# Load model function
def load_model(model_path):
    checkpoint = torch.load(model_path, map_location='cpu')
    # Note: You'll need to include model_factory_torch.py for this to work
    # from model_factory_torch import create_model
    # model = create_model(checkpoint['config'])
    # model.load_state_dict(checkpoint['model_state_dict'])
    # model.eval()
    # return model
    pass

def preprocess_image(image_path, image_size={config['image_size']}):
    \"\"\"Preprocess image for inference.\"\"\"
    image = Image.open(image_path).convert('RGB')
    image = image.resize((image_size, image_size))
    image = np.array(image) / 255.0
    image = (image - np.array([0.485, 0.456, 0.406])) / np.array([0.229, 0.224, 0.225])
    image = torch.tensor(image, dtype=torch.float32).permute(2, 0, 1).unsqueeze(0)
    return image

def predict(model, image_path, class_names=None):
    \"\"\"Make prediction on image.\"\"\"
    image = preprocess_image(image_path)
    
    with torch.no_grad():
        output = model(image)
        probabilities = F.softmax(output, dim=1)
        predicted_class = torch.argmax(probabilities, dim=1).item()
        confidence = probabilities[0][predicted_class].item()
    
    result = {{
        'predicted_class': predicted_class,
        'confidence': confidence,
        'class_name': class_names[predicted_class] if class_names else None
    }}
    
    return result

# Example usage:
# model = load_model('plant_disease_{config['backbone']}_pytorch.pth')
# result = predict(model, 'path_to_image.jpg')
# print(f"Prediction: {{result['class_name']}} ({{result['confidence']:.3f}})")
"""

    inference_script_path = export_dir / 'inference_pipeline.py'
    # Explicit UTF-8 so the output does not depend on the platform's default
    # encoding (e.g. cp1252 on Windows).
    inference_script_path.write_text(inference_code, encoding='utf-8')

    print(f"\n✅ Inference pipeline saved: {inference_script_path.name}")

    # Summary
    print(f"\n📦 Export Summary:")
    print(f"Export directory: {export_dir}")
    print(f"Files created:")
    for exported_file in sorted(export_dir.iterdir()):
        if exported_file.is_file():
            size_mb = exported_file.stat().st_size / (1024**2)
            print(f"  - {exported_file.name} ({size_mb:.1f} MB)")

    print(f"\n🚀 Deployment Ready!")
    print(f"Integration tips:")
    print(f"1. Copy model files to your deployment environment")
    print(f"2. Use inference_pipeline.py as starting point")
    print(f"3. For Streamlit: integrate with existing app/streamlit_app/")
    print(f"4. For mobile: use ONNX or quantized models")

else:
    print("⚠️  No trained model to export. Complete training first.")

# 8. Colab Optimizations & Final Setup

This section provides Colab-specific optimizations for maximum training efficiency and resource utilization.

In [None]:
# Colab-specific optimizations and resource management
import sys

# The setup cell defines IN_COLAB. Accept either spelling, or detect Colab
# directly, so this cell does not raise NameError when IS_COLAB was never set.
IS_COLAB = globals().get('IS_COLAB', globals().get('IN_COLAB', 'google.colab' in sys.modules))

if IS_COLAB:
    print("🚀 Applying Colab optimizations...")

    # 1. Memory management for large datasets
    import gc
    import psutil

    def print_memory_usage():
        """Print this process's RSS and, when CUDA is available, GPU memory."""
        process = psutil.Process()
        memory_info = process.memory_info()
        memory_mb = memory_info.rss / 1024**2
        print(f"💾 Memory usage: {memory_mb:.1f} MB")

        # GPU memory if available
        if torch.cuda.is_available():
            gpu_memory = torch.cuda.memory_allocated() / 1024**2
            gpu_reserved = torch.cuda.memory_reserved() / 1024**2
            print(f"🖥️  GPU memory: {gpu_memory:.1f} MB allocated, {gpu_reserved:.1f} MB reserved")

    def cleanup_memory():
        """Run the garbage collector and release cached CUDA memory."""
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        print("🧹 Memory cleaned")

    print_memory_usage()

    # 2. Enhanced batch size scaling for T4 GPU (recommendation only; config
    #    is not modified here)
    if config.get('auto_batch_size', False) and torch.cuda.is_available():
        gpu_memory_gb = torch.cuda.get_device_properties(0).total_memory / 1024**3
        print(f"🖥️  Detected GPU memory: {gpu_memory_gb:.1f} GB")

        # Scale batch size based on GPU memory
        if gpu_memory_gb >= 15:  # T4 GPU
            recommended_batch = 32
        elif gpu_memory_gb >= 8:  # Smaller GPU
            recommended_batch = 16
        else:
            recommended_batch = 8

        print(f"📊 Recommended batch size: {recommended_batch}")
        if config['batch_size'] != recommended_batch:
            print(f"⚠️  Current batch size: {config['batch_size']}")
            print(f"Consider updating config to: batch_size: {recommended_batch}")

    # 3. Drive mounting and checkpoint management
    from google.colab import drive

    try:
        drive.mount('/content/drive')
        print("✅ Google Drive mounted")

        # Create persistent checkpoint directory
        drive_checkpoint_dir = Path('/content/drive/MyDrive/Capstone-Lazarus/checkpoints')
        drive_checkpoint_dir.mkdir(parents=True, exist_ok=True)
        print(f"📁 Checkpoint directory: {drive_checkpoint_dir}")

        # Update trainer save directory if training was initialized
        if 'trainer' in globals():
            # Copy existing checkpoints to drive
            if trainer.save_dir.exists():
                import shutil
                for checkpoint in trainer.save_dir.glob('*.pth'):
                    drive_checkpoint = drive_checkpoint_dir / checkpoint.name
                    shutil.copy2(checkpoint, drive_checkpoint)
                    print(f"📋 Copied {checkpoint.name} to drive")

            # Update save directory
            trainer.save_dir = drive_checkpoint_dir
            print("✅ Updated trainer to save checkpoints to Google Drive")

    except Exception as e:
        print(f"⚠️  Drive mount failed: {e}")
        print("Continuing with local storage (will be lost on session end)")

    # 4. Enhanced training configuration for Colab
    colab_optimizations = {
        'use_amp': True,  # Always use mixed precision in Colab
        'gradient_clip_val': 1.0,  # Prevent gradient explosion
        'gradient_accumulation_steps': max(1, 32 // config.get('batch_size', 16)),  # Simulate larger batch
        'checkpoint_every_n_epochs': 2,  # More frequent checkpointing
        'early_stopping_patience': 5,  # Reduce for faster experimentation
    }

    # Values are only filled in for keys the config does not already define.
    print("\n⚙️  Colab training optimizations:")
    for key, value in colab_optimizations.items():
        print(f"  {key}: {value}")
        if key not in config:
            config[key] = value

    # 5. Session timeout prevention
    print("\n⏰ Setting up session keepalive...")

    def keep_colab_alive():
        """Function to prevent Colab timeout (call periodically)."""
        from IPython.display import Javascript, display
        display(Javascript('''
            function ClickConnect(){
                console.log("Working");
                document.querySelector("colab-toolbar-button#connect").click()
            }
            setInterval(ClickConnect, 60000)
        '''))

    # Enable keepalive (optional - comment out if not needed)
    # keep_colab_alive()

else:
    print("💻 Local environment detected - using HP ZBook optimizations")

    # Local environment optimizations
    def print_local_specs():
        """Print CPU core/thread counts, total RAM, and GPU name/VRAM."""
        import psutil

        # CPU info
        cpu_count = psutil.cpu_count(logical=False)
        cpu_count_logical = psutil.cpu_count(logical=True)
        print(f"🔧 CPU: {cpu_count} cores ({cpu_count_logical} threads)")

        # Memory info
        memory = psutil.virtual_memory()
        memory_gb = memory.total / 1024**3
        print(f"💾 RAM: {memory_gb:.1f} GB")

        # GPU info
        if torch.cuda.is_available():
            gpu_name = torch.cuda.get_device_name(0)
            gpu_memory_gb = torch.cuda.get_device_properties(0).total_memory / 1024**3
            print(f"🖥️  GPU: {gpu_name} ({gpu_memory_gb:.1f} GB VRAM)")
        else:
            print("🖥️  GPU: Not available (CPU training)")

    print_local_specs()

# Universal optimizations
print("\n🔧 Universal optimizations:")

# Set optimal number of workers based on environment
if IS_COLAB:
    optimal_workers = 2  # Colab has limited CPU
else:
    optimal_workers = min(4, torch.multiprocessing.cpu_count())  # Local machine

print(f"👥 DataLoader workers: {optimal_workers}")
# Only overrides an existing num_workers setting; it is not added otherwise.
if 'num_workers' in config:
    config['num_workers'] = optimal_workers

# Enable deterministic training for reproducibility
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False  # Set to True for speed if input sizes are fixed
print("🎯 Deterministic training enabled")

# Set random seeds for reproducibility
import random
import numpy as np

seed = config.get('seed', 42)
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

print(f"🌱 Random seed set to: {seed}")

print("\n✅ Environment optimization complete!")

# 🎯 Usage Instructions & Complete Pipeline Summary

## 🚀 Quick Start
1. **Run all cells sequentially** - Each section builds upon the previous
2. **Monitor training** - Watch for validation accuracy improvements
3. **Export models** - Use section 7 to save models for deployment
4. **Check experiments** - View results in `experiments/` directory

## 📊 Expected Training Results
- **EfficientNet-B0**: ~90-95% validation accuracy
- **ResNet50**: ~88-93% validation accuracy  
- **MobileNet-V3**: ~87-92% validation accuracy
- **Training time**: 15-30 minutes per model (Colab T4)

## 🔧 Configuration Options

### Hardware Optimization
- **HP ZBook G5**: `batch_size: 16, use_amp: true`
- **Colab T4**: `batch_size: 32, gradient_accumulation_steps: 2`
- **CPU only**: `batch_size: 8, use_amp: false`

### Model Variants
```yaml
backbone: 'efficientnet_b0'  # Fast, accurate
backbone: 'resnet50'         # Robust, proven
backbone: 'mobilenet_v3_small'  # Mobile deployment
```

### Quantum Experiments (Optional)
```yaml
use_quantum: true
quantum:
  n_qubits: 4
  n_layers: 2
  embedding_dim: 4  # example value; required, must match the features fed to the quantum layer
```

## 🐛 Troubleshooting

### Common Issues
- **OOM Error**: Reduce batch_size or image_size
- **Slow training**: Enable AMP, check num_workers
- **Import errors**: Run pip install commands again
- **Quantum errors**: Set `use_quantum: false`

### Memory Management
```python
# Free memory if needed
import gc
import torch

gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()
```

## 📁 Output Files
```
experiments/
├── plant_disease_efficientnet_b0_20241223_143022/
│   ├── config.yaml
│   ├── best_model.pth
│   ├── training_log.csv
│   └── plots/
exports/
├── plant_disease_efficientnet_b0_pytorch.pth
├── plant_disease_efficientnet_b0_onnx.onnx
└── inference_pipeline.py
```

## 🔗 Integration with Existing Streamlit App
```python
# Copy exports/inference_pipeline.py into app/streamlit_app/, then:
import torch
from inference_pipeline import load_model, predict

model = load_model('exports/plant_disease_efficientnet_b0_pytorch.pth')
result = predict(model, uploaded_image)  # uploaded_image: the image received from the Streamlit uploader
```

In [None]:
# Final status summary.
#
# Prints a recap of the session: environment, device, PyTorch version, key
# training-config values, training status, the backbone names listed below,
# and whether the expected project source files exist under PROJECT_ROOT.
# All notebook state is read defensively through globals() so this cell still
# runs (instead of raising NameError) when earlier cells were skipped.
import sys
from pathlib import Path

import torch

_nb_state = globals()

print("🎉 Capstone-Lazarus PyTorch Training Pipeline Complete!")
print("=" * 60)

# Environment status. The setup cell defines IN_COLAB while later cells refer
# to IS_COLAB; accept either and fall back to checking sys.modules directly.
_in_colab = _nb_state.get(
    'IS_COLAB',
    _nb_state.get('IN_COLAB', 'google.colab' in sys.modules),
)
print(f"Environment: {'Google Colab' if _in_colab else 'Local (HP ZBook G5)'}")
print(f"Device: {_nb_state.get('device', 'not set')}")
print(f"PyTorch version: {torch.__version__}")

# Configuration summary: only the listed keys that are present in `config`.
key_configs = ['backbone', 'batch_size', 'learning_rate', 'epochs', 'image_size']
if 'config' in _nb_state:
    _config = _nb_state['config']
    print("\n📝 Training Configuration:")
    for key in key_configs:
        if key in _config:
            print(f"  {key}: {_config[key]}")

# Training status. `training_completed` / `best_val_acc` are presumably set by
# the training cells; if they are absent, training is reported as not started.
if _nb_state.get('training_completed'):
    print("\n✅ Training Status: COMPLETED")
    if 'best_val_acc' in _nb_state:
        print(f"  Best validation accuracy: {_nb_state['best_val_acc']:.4f}")
else:
    print("\n⏳ Training Status: Ready to start")
    print("  👆 Run the training cells above to begin")

# Backbone names listed for the user (presumably timm identifiers — confirm
# against src/model_factory_torch.py).
backbone_options = [
    'efficientnet_b0',
    'efficientnet_b1',
    'resnet50',
    'resnet34',
    'mobilenet_v3_small',
    'mobilenet_v3_large',
]
print("\n🏗️  Available Model Backbones:")
for backbone in backbone_options:
    print(f"  - {backbone}")

# File structure: report which expected project files exist under PROJECT_ROOT.
key_files = [
    'config.yaml',
    'src/model_factory_torch.py',
    'src/quantum_layer.py',
    'src/data_utils_torch.py',
    'src/training_torch.py',
]
print("\n📁 Generated Files:")
_project_root = _nb_state.get('PROJECT_ROOT')
if _project_root is not None and Path(_project_root).exists():
    for file_path in key_files:
        status = "✅" if (Path(_project_root) / file_path).exists() else "❌"
        print(f"  {status} {file_path}")

# Next steps
print("\n🚀 Next Steps:")
print("1. Modify config.yaml for your specific needs")
print("2. Run training cells to train your model")
print("3. Use exported models in your Streamlit app")
print("4. Experiment with quantum layers (optional)")
print("5. Deploy to production using ONNX models")

print("\n📚 Documentation:")
print("- Training logs: experiments/[model_name]/")
print("- Model exports: exports/")
print("- Configuration: config.yaml")
print("- Source code: src/")

print("\n⭐ Happy Training! ⭐")
print("For questions: Check README.md or experiment logs")