"""
Crop Image Classification Pipeline

This script implements a complete pipeline for training, validating and testing
image classification models on crop disease datasets.
"""

# In [None]:
# ===== Section: Import Libraries =====
import os
import random
import time
import json
import logging
import warnings
import shutil
import argparse
from datetime import datetime
from pathlib import Path
from collections import defaultdict

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.auto import tqdm  # For progress bars
import wandb  # For experiment tracking

# Configure logging: INFO and above to ./training.log and to stderr
# (logging.basicConfig is a no-op if the root logger already has handlers).
_LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
_LOG_HANDLERS = [logging.FileHandler("training.log"), logging.StreamHandler()]
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT, handlers=_LOG_HANDLERS)
logger = logging.getLogger(__name__)

# Suppress every Python warning process-wide (including library deprecation notices).
warnings.filterwarnings("ignore")

# PyTorch imports
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split, WeightedRandomSampler
from torch.optim.lr_scheduler import ReduceLROnPlateau, CosineAnnealingWarmRestarts, OneCycleLR
import torchvision
from torchvision import datasets, models, transforms
import torchvision.transforms.functional as TF
from torch.utils.tensorboard import SummaryWriter

# Advanced augmentation
try:
    import albumentations as A
    from albumentations.pytorch import ToTensorV2
    ALBUMENTATIONS_AVAILABLE = True
except ImportError:
    print("Albumentations not available. Install with: pip install albumentations")
    ALBUMENTATIONS_AVAILABLE = False

# scikit-learn for evaluation
from sklearn.metrics import (
    classification_report, confusion_matrix, accuracy_score,
    precision_recall_fscore_support, roc_auc_score
)
from sklearn.utils.class_weight import compute_class_weight

# For visualization
import seaborn as sns
from PIL import Image, ImageEnhance

# Mixed precision training
from torch.cuda.amp import GradScaler, autocast

# In [None]:
# Set random seed for reproducibility
def set_seed(seed=42):
    """Seed every random number generator this pipeline relies on.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch (CPU and
    all CUDA devices). It also switches cuDNN to deterministic, non-benchmarking
    mode and exports ``PYTHONHASHSEED``. That variable is read at interpreter
    start-up, so it only affects Python subprocesses launched afterwards.

    Args:
        seed: integer seed applied to all generators.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    os.environ['PYTHONHASHSEED'] = str(seed)


set_seed()


# ===== Section: Enhanced Configuration Settings =====
class Config:
    """Global settings for the whole pipeline, stored as class attributes.

    Settings are read as ``Config.<name>``; the class is never instantiated.
    ``device`` and ``experiment_name`` are evaluated once, when this class body
    runs at import time. ``vars(Config)`` is handed to ``wandb.init`` and is
    stored in checkpoint payloads.
    """
    # Paths (relative paths resolve against the current working directory)
    data_dir = "./dataset"
    output_dir = "./models"
    results_dir = "./results"
    checkpoint_dir = "./checkpoints"
    tensorboard_dir = "./runs"

    # Dataset settings. load_and_split_dataset() uses int(ratio * N) for the
    # train and val splits and gives the test split the remainder, so
    # test_ratio is not read there.
    train_ratio = 0.7
    val_ratio = 0.15
    test_ratio = 0.15
    use_weighted_sampling = True  # Class-balancing WeightedRandomSampler for the train loader

    # Model settings
    model_type = "resnet101"  # Options: 'custom_cnn', 'resnet18', 'resnet50', 'resnet101', 'efficientnet_b0', 'vit_b_16'
    pretrained = True
    freeze_backbone = False  # Freeze a pretrained backbone (train only the head) at the start
    unfreeze_epoch = 10  # 0-based epoch index at which train_model unfreezes a frozen backbone

    # Training settings
    batch_size = 32
    num_epochs = 50
    learning_rate = 0.001  # Head LR; pretrained backbones get 0.1x this in create_optimizer_and_scheduler
    weight_decay = 1e-4
    gradient_clip_val = 1.0  # Max gradient norm; a value <= 0 disables clipping

    # Advanced training options
    use_mixed_precision = True  # Automatic Mixed Precision (train phase only)
    use_cosine_annealing = True  # Cosine annealing with warm restarts; False -> ReduceLROnPlateau
    warmup_epochs = 5  # NOTE(review): not referenced by the training code in this section - confirm it is used
    label_smoothing = 0.1  # > 0 selects LabelSmoothingCrossEntropy, otherwise nn.CrossEntropyLoss

    # Early stopping and checkpointing
    patience = 10  # Epochs without val-accuracy improvement before early stopping
    save_top_k = 3  # Number of best checkpoints kept on disk by ModelCheckpoint
    monitor_metric = 'val_acc'  # ModelCheckpoint's ranking key; train_model decides "best" by val accuracy regardless
    checkpoint_every = 5  # Also save a checkpoint every N epochs

    # Data augmentation (Albumentations pipelines; only used when the library is installed)
    use_advanced_augmentation = True
    augmentation_strength = 'medium'  # 'light', 'medium', 'heavy'

    # Image settings. Normalization uses the standard ImageNet channel statistics
    # expected by torchvision's pretrained weights.
    img_size = 224
    normalize_mean = [0.485, 0.456, 0.406]
    normalize_std = [0.229, 0.224, 0.225]

    # Device settings (device is chosen once, at import time)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    num_workers = 4
    pin_memory = True

    # Experiment tracking
    use_wandb = False  # Set to True to use Weights & Biases
    use_tensorboard = True
    # Built at import time from model_type's value at that moment; changing
    # Config.model_type later does not update it.
    experiment_name = f"crop_disease_{model_type}_{datetime.now().strftime('%Y%m%d_%H%M%S')}"

    # Resume training
    resume_from_checkpoint = None  # Path to checkpoint to resume from


# Create output directories
for directory in (Config.output_dir, Config.results_dir, Config.checkpoint_dir, Config.tensorboard_dir):
    os.makedirs(directory, exist_ok=True)

# Initialize experiment tracking.
# vars(Config) is a read-only mappingproxy that also contains dunder entries
# (__module__, __dict__, __weakref__, ...). Pass wandb a plain dict with only the
# public settings, and stringify torch.device so it serializes cleanly.
if Config.use_wandb:
    try:
        wandb.init(
            project="crop-disease-classification",
            name=Config.experiment_name,
            config={
                key: str(value) if isinstance(value, torch.device) else value
                for key, value in vars(Config).items()
                if not key.startswith("_") and not callable(value)
            },
        )
    except Exception as e:
        # Best effort: tracking is optional, so keep training without it.
        print(f"Failed to initialize wandb: {e}")
        Config.use_wandb = False


# In [None]:
# ===== Section: Enhanced Data Preparation =====

class AlbumentationsTransform:
    """Adapter that lets an Albumentations pipeline act as a torchvision transform.

    torchvision datasets call ``transform(img)`` with a PIL image. Albumentations
    pipelines instead expect ``pipeline(image=<ndarray>)`` and return a dict.
    This wrapper converts the image to a NumPy array, runs the pipeline and
    returns the dict's ``'image'`` entry.
    """

    def __init__(self, transform):
        # Any callable accepting ``image=`` and returning a dict with 'image'.
        self.transform = transform

    def __call__(self, img):
        result = self.transform(image=np.array(img))
        return result['image']


def get_advanced_transforms():
    """Build Albumentations-based ``(train_transform, val_test_transform)``.

    Falls back to ``get_data_transforms()`` (torchvision) when Albumentations is
    not installed. The training pipeline strength comes from
    ``Config.augmentation_strength``. Both returned callables take a PIL image
    and return a normalized ``Config.img_size`` x ``Config.img_size`` tensor
    (via ``ToTensorV2``).

    Raises:
        ValueError: if ``Config.augmentation_strength`` is not 'light',
            'medium' or 'heavy'.
    """
    if not ALBUMENTATIONS_AVAILABLE:
        return get_data_transforms()  # Fallback to basic transforms

    size = Config.img_size
    strength = Config.augmentation_strength

    # NOTE(review): A.Cutout, GaussNoise(var_limit=...), ElasticTransform(alpha_affine=...)
    # and CoarseDropout(max_holes=...) use argument names that newer Albumentations
    # releases deprecate or remove - confirm against the installed version.
    if strength == 'light':
        train_transform = A.Compose([
            A.Resize(size + 32, size + 32),
            A.RandomCrop(size, size),
            A.HorizontalFlip(p=0.5),
            A.Rotate(limit=15, p=0.5),
            A.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.05, p=0.5),
            A.Normalize(Config.normalize_mean, Config.normalize_std),
            ToTensorV2()
        ])
    elif strength == 'medium':
        train_transform = A.Compose([
            A.Resize(size + 32, size + 32),
            A.RandomCrop(size, size),
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.2),
            A.Rotate(limit=20, p=0.7),
            A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1, p=0.7),
            A.OneOf([
                A.GaussNoise(var_limit=(10.0, 50.0), p=1.0),
                A.GaussianBlur(blur_limit=(1, 3), p=1.0),
                A.MotionBlur(blur_limit=3, p=1.0),
            ], p=0.3),
            A.OneOf([
                A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=1.0),
                A.HueSaturationValue(hue_shift_limit=20, sat_shift_limit=30, val_shift_limit=20, p=1.0),
            ], p=0.5),
            A.Cutout(num_holes=8, max_h_size=8, max_w_size=8, fill_value=0, p=0.3),
            A.Normalize(Config.normalize_mean, Config.normalize_std),
            ToTensorV2()
        ])
    elif strength == 'heavy':
        train_transform = A.Compose([
            A.Resize(size + 64, size + 64),
            # RandomScale changes the spatial size, so it must run *before* the
            # crop. Placed after the crop, it produced images of varying size
            # that the DataLoader cannot stack into one batch.
            A.RandomScale(scale_limit=0.2, p=0.5),
            # Down-scaling by up to 20% can drop below the crop size when
            # img_size > 256; pad so RandomCrop always fits (no-op otherwise).
            A.PadIfNeeded(min_height=size, min_width=size),
            A.RandomCrop(size, size),
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.3),
            A.Rotate(limit=30, p=0.8),
            A.OneOf([
                A.ElasticTransform(alpha=1, sigma=50, alpha_affine=50, p=1.0),
                A.GridDistortion(p=1.0),
                A.OpticalDistortion(distort_limit=0.1, shift_limit=0.1, p=1.0),
            ], p=0.3),
            A.OneOf([
                A.GaussNoise(var_limit=(10.0, 50.0), p=1.0),
                A.GaussianBlur(blur_limit=(1, 5), p=1.0),
                A.MotionBlur(blur_limit=5, p=1.0),
            ], p=0.5),
            A.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.2, p=0.8),
            A.Cutout(num_holes=16, max_h_size=16, max_w_size=16, fill_value=0, p=0.5),
            A.CoarseDropout(max_holes=8, max_height=32, max_width=32, p=0.3),
            A.Normalize(Config.normalize_mean, Config.normalize_std),
            ToTensorV2()
        ])
    else:
        # Previously any unrecognised value silently selected 'heavy'.
        raise ValueError(
            f"Unknown augmentation_strength {strength!r}; "
            "expected 'light', 'medium' or 'heavy'"
        )

    # Validation/Test transforms: deterministic resize + centre crop
    val_test_transform = A.Compose([
        A.Resize(size + 32, size + 32),
        A.CenterCrop(size, size),
        A.Normalize(Config.normalize_mean, Config.normalize_std),
        ToTensorV2()
    ])

    return AlbumentationsTransform(train_transform), AlbumentationsTransform(val_test_transform)


def get_data_transforms():
    """Build the torchvision ``(train_transform, val_test_transform)`` pair.

    Both pipelines resize the shorter image side to ``Config.img_size + 32`` and
    produce a normalized ``Config.img_size`` square tensor.

    Training adds:
      - random crop, horizontal/vertical flips and rotation;
      - colour jitter and a small random affine shift/scale;
      - random erasing, applied after normalization.

    Evaluation uses a centre crop only.
    """
    resize_to = Config.img_size + 32
    crop_to = Config.img_size

    train_steps = [
        transforms.Resize(resize_to),
        transforms.RandomCrop(crop_to),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.2),
        transforms.RandomRotation(20),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1)),
        transforms.ToTensor(),
        transforms.Normalize(Config.normalize_mean, Config.normalize_std),
        # RandomErasing works on tensors, hence it comes after ToTensor/Normalize.
        transforms.RandomErasing(p=0.3, scale=(0.02, 0.33), ratio=(0.3, 3.3)),
    ]
    eval_steps = [
        transforms.Resize(resize_to),
        transforms.CenterCrop(crop_to),
        transforms.ToTensor(),
        transforms.Normalize(Config.normalize_mean, Config.normalize_std),
    ]

    return transforms.Compose(train_steps), transforms.Compose(eval_steps)


def compute_class_weights(dataset):
    """Compute inverse-frequency ("balanced") weights for each class.

    ``weight[c] = total / (num_present_classes * count[c])``, so rare classes
    receive larger weights.

    Args:
        dataset: map-style dataset yielding ``(sample, label)`` pairs with
            non-negative integer labels. A ``Subset`` over an ImageFolder-style
            dataset (one with ``.targets`` and no ``target_transform``) is read
            through its cached targets, so no image is decoded or augmented.

    Returns:
        float32 tensor indexed by class id: ``weights[c]`` is the weight of
        class ``c``, as expected by ``nn.CrossEntropyLoss(weight=...)``. Its
        length is ``max_label + 1``; ids that never occur get weight 0. Empty
        input gives an empty tensor.
    """
    base = getattr(dataset, 'dataset', None)
    if (hasattr(dataset, 'indices') and hasattr(base, 'targets')
            and getattr(base, 'target_transform', None) is None):
        labels = [base.targets[i] for i in dataset.indices]
    else:
        labels = [label for _, label in dataset]

    class_counts = defaultdict(int)
    for label in labels:
        class_counts[label] += 1

    if not class_counts:
        return torch.FloatTensor([])

    # Index by class id, not by first-seen (dict insertion) order, so that
    # weights[c] really belongs to class c.
    weights = torch.zeros(max(class_counts) + 1, dtype=torch.float32)
    num_present = len(class_counts)
    for class_id, count in class_counts.items():
        weights[class_id] = len(labels) / (num_present * count)
    return weights


def create_weighted_sampler(dataset):
    """Create a class-balancing ``WeightedRandomSampler`` for ``dataset``.

    Each sample is weighted by ``total / (num_classes * count[its class])``.
    Drawing ``len(dataset)`` samples with replacement therefore yields roughly
    class-balanced epochs.

    Args:
        dataset: map-style dataset yielding ``(sample, label)`` pairs. A
            ``Subset`` over an ImageFolder-style dataset (one with ``.targets``
            and no ``target_transform``) is read through its cached targets, so
            the training images are not decoded and augmented just to collect
            labels.

    Returns:
        ``WeightedRandomSampler`` with one weight per sample, in dataset order.
    """
    base = getattr(dataset, 'dataset', None)
    if (hasattr(dataset, 'indices') and hasattr(base, 'targets')
            and getattr(base, 'target_transform', None) is None):
        labels = [base.targets[i] for i in dataset.indices]
    else:
        labels = [label for _, label in dataset]

    class_counts = defaultdict(int)
    for label in labels:
        class_counts[label] += 1

    total_samples = len(labels)
    num_classes = len(class_counts)
    class_weights = {
        class_id: total_samples / (num_classes * count)
        for class_id, count in class_counts.items()
    }

    # Assign each sample the weight of its class
    sample_weights = [class_weights[label] for label in labels]

    return WeightedRandomSampler(
        weights=sample_weights,
        num_samples=len(sample_weights),
        replacement=True
    )


def analyze_dataset_distribution(dataset, class_names):
    """Plot and print the per-class sample counts of ``dataset``.

    Saves a bar chart to ``<Config.results_dir>/class_distribution.png``,
    displays it, and prints the total plus average/min/max per-class counts.

    Args:
        dataset: map-style dataset of ``(sample, label)`` pairs. A ``Subset``
            over an ImageFolder-style dataset (one with ``.targets`` and no
            ``target_transform``) is read from its cached targets, so no image
            is loaded or augmented.
        class_names: sequence mapping class index -> display name.

    Returns:
        ``defaultdict(int)`` mapping class index -> sample count. For an empty
        dataset it is empty and nothing is plotted.
    """
    base = getattr(dataset, 'dataset', None)
    if (hasattr(dataset, 'indices') and hasattr(base, 'targets')
            and getattr(base, 'target_transform', None) is None):
        labels = [base.targets[i] for i in dataset.indices]
    else:
        labels = [label for _, label in dataset]

    class_counts = defaultdict(int)
    for label in labels:
        class_counts[label] += 1

    if not class_counts:
        # Avoid ZeroDivisionError / min() of an empty list below.
        print("Dataset is empty - no class distribution to analyze")
        return class_counts

    class_ids = sorted(class_counts)
    classes = [class_names[i] for i in class_ids]
    counts = [class_counts[i] for i in class_ids]

    # Create visualization
    fig = plt.figure(figsize=(15, 8))
    plt.bar(range(len(classes)), counts)
    plt.xlabel('Classes')
    plt.ylabel('Number of Samples')
    plt.title('Dataset Class Distribution')
    plt.xticks(range(len(classes)), classes, rotation=45, ha='right')
    plt.tight_layout()
    plt.savefig(os.path.join(Config.results_dir, 'class_distribution.png'))
    plt.show()
    # Close explicitly so repeated calls do not accumulate open pyplot figures.
    plt.close(fig)

    # Print statistics
    total = sum(counts)
    min_idx = counts.index(min(counts))
    max_idx = counts.index(max(counts))
    print(f"Total samples: {total}")
    print(f"Number of classes: {len(classes)}")
    print(f"Average samples per class: {total/len(classes):.1f}")
    print(f"Min samples: {counts[min_idx]} ({classes[min_idx]})")
    print(f"Max samples: {counts[max_idx]} ({classes[max_idx]})")

    return class_counts


def load_and_split_dataset():
    """Load the image-folder dataset and split it into train/val/test subsets.

    ``Config.data_dir`` is loaded twice: once with the training transform and
    once with the deterministic validation/test transform. A single seeded
    ``random_split`` picks the indices. The validation and test subsets then
    reuse exactly those indices on the eval-transform view, so the three
    subsets are disjoint and only the training subset is augmented.

    Returns:
        ``(train_dataset, val_dataset, test_dataset, class_names, num_classes)``

    Raises:
        RuntimeError: if the two scans of ``Config.data_dir`` disagree (the
            directory changed between them). The shared indices would then
            point at different images.
    """
    # Get transforms
    if Config.use_advanced_augmentation and ALBUMENTATIONS_AVAILABLE:
        train_transform, val_test_transform = get_advanced_transforms()
    else:
        train_transform, val_test_transform = get_data_transforms()

    # Two views of the same folder: augmented (train) and deterministic (val/test)
    full_dataset = datasets.ImageFolder(root=Config.data_dir, transform=train_transform)
    val_test_dataset = datasets.ImageFolder(root=Config.data_dir, transform=val_test_transform)
    if full_dataset.samples != val_test_dataset.samples:
        raise RuntimeError(
            f"Dataset at {Config.data_dir} changed while loading; "
            "train and val/test views would be misaligned"
        )

    # Get class names and count
    class_names = full_dataset.classes
    num_classes = len(class_names)

    print(f"Found {num_classes} classes: {class_names[:5]}{'...' if len(class_names) > 5 else ''}")

    # Calculate splits. The test split takes the remainder (Config.test_ratio
    # is not read here), so int() truncation never drops samples.
    dataset_size = len(full_dataset)
    train_size = int(Config.train_ratio * dataset_size)
    val_size = int(Config.val_ratio * dataset_size)
    test_size = dataset_size - train_size - val_size

    print(f"Total dataset size: {dataset_size}")
    print(f"Training set size: {train_size}")
    print(f"Validation set size: {val_size}")
    print(f"Testing set size: {test_size}")

    # One seeded split decides every index
    train_dataset, val_split, test_split = random_split(
        full_dataset,
        [train_size, val_size, test_size],
        generator=torch.Generator().manual_seed(42),
    )

    # Same indices, eval-time transforms
    val_dataset = torch.utils.data.Subset(val_test_dataset, val_split.indices)
    test_dataset = torch.utils.data.Subset(val_test_dataset, test_split.indices)

    # Analyze dataset distribution
    print("\nAnalyzing dataset distribution...")
    analyze_dataset_distribution(train_dataset, class_names)

    return (
        train_dataset,
        val_dataset,
        test_dataset,
        class_names,
        num_classes,
    )


def create_dataloaders(train_dataset, val_dataset, test_dataset):
    """Wrap the train/val/test datasets in DataLoaders.

    All three loaders share ``Config.batch_size``, ``Config.num_workers`` and
    ``Config.pin_memory``. Validation and test loaders keep dataset order.

    The training loader drops its last incomplete batch. When
    ``Config.use_weighted_sampling`` is set, it draws samples through a
    class-balancing sampler instead of shuffling; DataLoader does not allow
    both at once.

    Returns:
        ``(train_loader, val_loader, test_loader)``
    """
    shared = {
        'batch_size': Config.batch_size,
        'num_workers': Config.num_workers,
        'pin_memory': Config.pin_memory,
    }

    if Config.use_weighted_sampling:
        train_sampler = create_weighted_sampler(train_dataset)
        shuffle_train = False  # mutually exclusive with a sampler
        print("Using weighted sampling for imbalanced dataset")
    else:
        train_sampler = None
        shuffle_train = True

    train_loader = DataLoader(
        train_dataset,
        shuffle=shuffle_train,
        sampler=train_sampler,
        drop_last=True,  # For batch norm stability
        **shared,
    )
    val_loader = DataLoader(val_dataset, shuffle=False, **shared)
    test_loader = DataLoader(test_dataset, shuffle=False, **shared)

    return train_loader, val_loader, test_loader

# In [None]:
# ===== Section: Enhanced Model Architectures =====

class CustomCNN(nn.Module):
    """Plain (non-residual) CNN trained from scratch.

    Architecture, as built below:
      - Three convolutional blocks, widening channels 3 -> 64 -> 128 -> 256.
      - Each block has two 3x3 Conv -> BatchNorm -> ReLU layers, then 2x2
        max-pooling and channel dropout (Dropout2d).
      - The final feature map is summarised by global average pooling and
        global max pooling, concatenated into a 512-dim vector.
      - A two-layer MLP head produces the logits.

    There are no residual (skip) connections and no attention mechanism.

    Args:
        num_classes: number of output logits.
        dropout_rate: dropout probability used twice in the classifier head.
    """

    def __init__(self, num_classes, dropout_rate=0.5):
        super(CustomCNN, self).__init__()
        
        # First convolutional block: 3 -> 64 channels, spatial size halved
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Dropout2d(0.1),
        )

        # Second convolutional block: 64 -> 128 channels, spatial size halved
        self.conv2 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Dropout2d(0.2),
        )

        # Third convolutional block: 128 -> 256 channels, spatial size halved
        self.conv3 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Dropout2d(0.3),
        )

        # Global average + max pooling (concatenated in forward); no attention is involved
        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
        self.global_max_pool = nn.AdaptiveMaxPool2d(1)
        
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Dropout(dropout_rate),
            nn.Linear(512, 512),  # 256 * 2 (avg + max pooling)
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(512),
            nn.Dropout(dropout_rate),
            nn.Linear(512, num_classes),
        )
        
        # Initialize weights
        self._initialize_weights()

    def _initialize_weights(self):
        """Kaiming-normal conv weights, unit/zero BatchNorm2d, N(0, 0.01) Linear weights.

        BatchNorm1d is not matched here and keeps PyTorch's default init.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Return class logits of shape ``(N, num_classes)``.

        ``x`` must have 3 input channels (first Conv2d is 3 -> 64); presumably
        ``(N, 3, H, W)`` normalized images - confirm against callers.
        """
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        
        # Global pooling: (N, 256, 1, 1) each, concatenated to (N, 512, 1, 1)
        avg_pool = self.global_avg_pool(x)
        max_pool = self.global_max_pool(x)
        x = torch.cat([avg_pool, max_pool], dim=1)
        
        return self.classifier(x)


def _dropout_linear_head(in_features, num_classes):
    """Return the Dropout(0.5) -> Linear head used for the CNN backbones."""
    return nn.Sequential(
        nn.Dropout(0.5),
        nn.Linear(in_features, num_classes)
    )


def get_model(model_type, num_classes, pretrained=True):
    """Build the classifier named by ``model_type`` with ``num_classes`` outputs.

    Supported types:
      - 'custom_cnn' builds ``CustomCNN``.
      - 'resnet18', 'resnet50', 'resnet101' and 'efficientnet_b0' replace the
        original head with Dropout(0.5) + Linear.
      - 'vit_b_16' replaces only the final Linear inside ``heads``.

    With ``pretrained`` truthy, torchvision's ImageNet weights are loaded.

    When ``Config.freeze_backbone`` is set for a pretrained torchvision model,
    every parameter is frozen except those of its head module. The head is the
    first of ``fc``, ``classifier`` or ``heads`` that the model has.

    Raises:
        ValueError: for an unsupported ``model_type``, raised after the
            "Creating ..." message has been printed.
    """
    print(f"Creating {model_type} model...")

    # model_type -> (constructor, getter for its pretrained weights enum).
    # Weights are looked up lazily, only when actually requested.
    backbones = {
        "resnet18": (models.resnet18, lambda: models.ResNet18_Weights.IMAGENET1K_V1),
        "resnet50": (models.resnet50, lambda: models.ResNet50_Weights.IMAGENET1K_V2),
        "resnet101": (models.resnet101, lambda: models.ResNet101_Weights.IMAGENET1K_V2),
        "efficientnet_b0": (models.efficientnet_b0, lambda: models.EfficientNet_B0_Weights.IMAGENET1K_V1),
        "vit_b_16": (models.vit_b_16, lambda: models.ViT_B_16_Weights.IMAGENET1K_V1),
    }

    if model_type == "custom_cnn":
        model = CustomCNN(num_classes)
    elif model_type in backbones:
        constructor, default_weights = backbones[model_type]
        model = constructor(weights=default_weights() if pretrained else None)
        if model_type.startswith("resnet"):
            model.fc = _dropout_linear_head(model.fc.in_features, num_classes)
        elif model_type == "efficientnet_b0":
            model.classifier = _dropout_linear_head(model.classifier[1].in_features, num_classes)
        else:  # vit_b_16
            model.heads.head = nn.Linear(model.heads.head.in_features, num_classes)
    else:
        raise ValueError(f"Model type {model_type} not supported")

    if Config.freeze_backbone and pretrained and model_type != "custom_cnn":
        model.requires_grad_(False)
        head = next(
            (getattr(model, attr) for attr in ("fc", "classifier", "heads") if hasattr(model, attr)),
            None,
        )
        if head is not None:
            head.requires_grad_(True)

    return model


def unfreeze_backbone(model, model_type):
    """Make every parameter of a pretrained model trainable again.

    ``CustomCNN`` is never frozen, so for ``model_type == "custom_cnn"`` this
    does nothing (and prints nothing).
    """
    if model_type != "custom_cnn":
        model.requires_grad_(True)
        print("Backbone unfrozen - all parameters are now trainable")


# In [None]:
# ===== Section: Enhanced Training Functions =====

class ModelCheckpoint:
    """Save training checkpoints and keep only the top-k best ones on disk.

    Every ``save_checkpoint`` call writes ``checkpoint_epoch_{epoch}.pth`` into
    ``checkpoint_dir``. When flagged as best, the call also overwrites
    ``best_model.pth`` and registers the file in the top-k bookkeeping. That
    bookkeeping deletes best-checkpoint files that fall out of the top
    ``save_top_k``, even if they were also periodic checkpoints.
    Scores are ranked descending when ``monitor`` contains 'acc', ascending
    otherwise.
    """

    def __init__(self, checkpoint_dir, monitor='val_acc', save_top_k=3):
        self.checkpoint_dir = Path(checkpoint_dir)
        # Ensure the target exists even when used outside this script's setup
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
        self.monitor = monitor
        self.save_top_k = save_top_k
        self.best_scores = []
        self.best_paths = []

    @staticmethod
    def _config_snapshot():
        """Return a plain, picklable dict of the public ``Config`` settings.

        ``vars(Config)`` is a ``mappingproxy``, which pickle (and therefore
        ``torch.save``) cannot serialize. It also carries dunder entries such
        as the ``__dict__``/``__weakref__`` descriptors, so only public,
        non-callable attributes are copied. ``torch.device`` is stored as its
        string form so the checkpoint contains only builtin types.
        """
        return {
            key: str(value) if isinstance(value, torch.device) else value
            for key, value in vars(Config).items()
            if not key.startswith('_') and not callable(value)
        }

    def save_checkpoint(self, model, optimizer, scheduler, epoch, metrics, is_best=False):
        """Write the checkpoint for ``epoch`` and return its path.

        Args:
            metrics: dict of epoch metrics; must contain ``self.monitor`` when
                ``is_best`` is True.
            is_best: also write ``best_model.pth`` and update the top-k set.
        """
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict() if scheduler else None,
            'metrics': metrics,
            'config': self._config_snapshot(),
        }

        # Save regular checkpoint
        checkpoint_path = self.checkpoint_dir / f'checkpoint_epoch_{epoch}.pth'
        torch.save(checkpoint, checkpoint_path)

        # Save best model
        if is_best:
            best_path = self.checkpoint_dir / 'best_model.pth'
            torch.save(checkpoint, best_path)
            print(f"New best model saved: {self.monitor} = {metrics[self.monitor]:.4f}")

            # Manage top-k models
            self._manage_top_k_models(checkpoint_path, metrics[self.monitor])

        return checkpoint_path

    def _manage_top_k_models(self, checkpoint_path, score):
        """Record ``(score, checkpoint_path)`` and delete files outside the top k."""
        self.best_scores.append(score)
        self.best_paths.append(checkpoint_path)

        # Sort by score (descending for accuracy, ascending for loss)
        is_accuracy = 'acc' in self.monitor.lower()
        sorted_pairs = sorted(zip(self.best_scores, self.best_paths),
                              key=lambda x: x[0], reverse=is_accuracy)

        # Keep only top-k
        if len(sorted_pairs) > self.save_top_k:
            to_remove = sorted_pairs[self.save_top_k:]
            for _, path in to_remove:
                if path.exists():
                    path.unlink()

            # Update lists
            sorted_pairs = sorted_pairs[:self.save_top_k]
            self.best_scores, self.best_paths = zip(*sorted_pairs)
            self.best_scores = list(self.best_scores)
            self.best_paths = list(self.best_paths)


def load_checkpoint(checkpoint_path, model, optimizer=None, scheduler=None):
    """Restore model (and optionally optimizer/scheduler) state from a file.

    The file is expected to hold the dict written by
    ``ModelCheckpoint.save_checkpoint``. Tensors are mapped onto
    ``Config.device``. Optimizer and scheduler states are restored only when
    the respective object is given and the checkpoint holds a non-empty state
    for it.

    Returns:
        ``(epoch, metrics)``: the saved epoch index and the saved metrics
        dict (``{}`` if absent).
    """
    state = torch.load(checkpoint_path, map_location=Config.device)
    model.load_state_dict(state['model_state_dict'])

    if optimizer and 'optimizer_state_dict' in state:
        optimizer.load_state_dict(state['optimizer_state_dict'])

    saved_scheduler_state = state.get('scheduler_state_dict')
    if scheduler and saved_scheduler_state:
        scheduler.load_state_dict(saved_scheduler_state)

    return state['epoch'], state.get('metrics', {})


class LabelSmoothingCrossEntropy(nn.Module):
    """Cross-entropy against label-smoothed targets.

    The true class gets probability ``1 - smoothing``. The remaining
    ``smoothing`` mass is spread evenly over the other ``C - 1`` classes.
    Returns the batch mean of the per-sample loss.
    """

    def __init__(self, smoothing=0.1):
        super().__init__()
        self.smoothing = smoothing

    def forward(self, pred, target):
        """``pred``: logits with classes on the last dim; ``target``: class indices."""
        num_classes = pred.size(-1)
        hard = torch.zeros_like(pred).scatter(1, target.view(-1, 1), 1)
        soft = hard * (1 - self.smoothing) + (1 - hard) * self.smoothing / (num_classes - 1)
        per_sample = -(soft * F.log_softmax(pred, dim=-1)).sum(dim=-1)
        return per_sample.mean()


def create_optimizer_and_scheduler(model, train_loader):
    """Create optimizer and learning rate scheduler"""
    # Different learning rates for different parts of the model
    if Config.model_type != "custom_cnn" and Config.pretrained:
        # Separate parameters for backbone and classifier
        backbone_params = []
        classifier_params = []
        
        for name, param in model.named_parameters():
            if 'fc' in name or 'classifier' in name or 'heads' in name:
                classifier_params.append(param)
            else:
                backbone_params.append(param)
        
        # Use lower learning rate for backbone
        optimizer = optim.AdamW([
            {'params': backbone_params, 'lr': Config.learning_rate * 0.1},
            {'params': classifier_params, 'lr': Config.learning_rate}
        ], weight_decay=Config.weight_decay)
    else:
        optimizer = optim.AdamW(
            model.parameters(),
            lr=Config.learning_rate,
            weight_decay=Config.weight_decay
        )
    
    # Learning rate scheduler
    if Config.use_cosine_annealing:
        scheduler = CosineAnnealingWarmRestarts(
            optimizer,
            T_0=len(train_loader) * 10,  # Restart every 10 epochs
            T_mult=2,
            eta_min=Config.learning_rate * 0.001
        )
    else:
        scheduler = ReduceLROnPlateau(
            optimizer,
            mode='max',
            factor=0.5,
            patience=5,
            verbose=True
        )
    
    return optimizer, scheduler


def train_model(model, dataloaders, num_epochs=None):
    """Train ``model`` with validation, checkpointing and early stopping.

    Each epoch has two phases:
      - ``'train'``: forward and backward pass, optional AMP and gradient
        clipping, optimizer step, and a per-batch cosine scheduler step.
      - ``'val'``: forward pass only.

    Validation accuracy drives best-model tracking, checkpointing,
    ``ReduceLROnPlateau`` (when cosine annealing is off) and early stopping
    after ``Config.patience`` epochs without improvement.

    Args:
        model: ``nn.Module`` to train; it is moved to ``Config.device``.
        dataloaders: dict with ``'train'`` and ``'val'`` DataLoaders yielding
            ``(inputs, labels)`` batches.
        num_epochs: total number of epochs (default ``Config.num_epochs``).
            When resuming, training continues after the checkpoint's epoch.

    Returns:
        ``(model, history)``. ``model`` holds the weights of ``best_model.pth``
        if this run (or the run it resumed) saved one. ``history`` maps
        ``'train_loss'``, ``'val_loss'``, ``'train_acc'`` and ``'val_acc'`` to
        per-epoch lists.

    Raises:
        RuntimeError: if a phase's DataLoader yields no batches.
    """
    if num_epochs is None:
        num_epochs = Config.num_epochs

    since = time.time()
    device = Config.device
    model = model.to(device)

    # Initialize tracking
    history = {'train_loss': [], 'val_loss': [], 'train_acc': [], 'val_acc': []}
    best_acc = 0.0
    start_epoch = 0
    # True once this run has written best_model.pth. The checkpoint directory
    # is shared across runs; without this flag the final reload could pick up
    # a stale (possibly different-architecture) file from an earlier run.
    best_saved = False

    # Create optimizer and scheduler
    optimizer, scheduler = create_optimizer_and_scheduler(model, dataloaders['train'])

    # Loss function with optional label smoothing
    if Config.label_smoothing > 0:
        criterion = LabelSmoothingCrossEntropy(smoothing=Config.label_smoothing)
    else:
        criterion = nn.CrossEntropyLoss()

    # Mixed precision scaler (None when AMP is disabled)
    scaler = GradScaler() if Config.use_mixed_precision else None

    # Checkpointing
    checkpoint_manager = ModelCheckpoint(
        Config.checkpoint_dir,
        monitor=Config.monitor_metric,
        save_top_k=Config.save_top_k
    )

    # TensorBoard logging
    writer = SummaryWriter(Config.tensorboard_dir) if Config.use_tensorboard else None

    # Resume from checkpoint if specified
    if Config.resume_from_checkpoint:
        print(f"Resuming from checkpoint: {Config.resume_from_checkpoint}")
        start_epoch, metrics = load_checkpoint(
            Config.resume_from_checkpoint, model, optimizer, scheduler
        )
        # NOTE: this is the resumed checkpoint's val_acc, not necessarily the
        # best accuracy reached before the interruption.
        best_acc = metrics.get('val_acc', 0.0)
        start_epoch += 1

    # Count parameters
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Model has {total_params:,} parameters, {trainable_params:,} are trainable")

    # Early stopping
    epochs_since_improvement = 0

    try:
        for epoch in range(start_epoch, num_epochs):
            print(f'Epoch {epoch + 1}/{num_epochs}')
            print('-' * 50)

            # Unfreeze backbone if specified
            if Config.freeze_backbone and epoch == Config.unfreeze_epoch:
                unfreeze_backbone(model, Config.model_type)
                # Recreate optimizer with all parameters
                optimizer, scheduler = create_optimizer_and_scheduler(model, dataloaders['train'])

            epoch_metrics = {}

            # Training and validation phases
            for phase in ['train', 'val']:
                if phase == 'train':
                    model.train()
                else:
                    model.eval()

                running_loss = 0.0
                running_corrects = 0
                num_samples = 0

                # Progress bar
                pbar = tqdm(dataloaders[phase], desc=f'{phase.capitalize()}')

                for batch_idx, (inputs, labels) in enumerate(pbar):
                    inputs = inputs.to(device, non_blocking=True)
                    labels = labels.to(device, non_blocking=True)

                    optimizer.zero_grad()

                    # Forward pass (mixed precision only in the train phase)
                    with torch.set_grad_enabled(phase == 'train'):
                        if Config.use_mixed_precision and phase == 'train':
                            with autocast():
                                outputs = model(inputs)
                                loss = criterion(outputs, labels)

                            scaler.scale(loss).backward()

                            # Gradients must be unscaled before clipping
                            if Config.gradient_clip_val > 0:
                                scaler.unscale_(optimizer)
                                torch.nn.utils.clip_grad_norm_(model.parameters(), Config.gradient_clip_val)

                            scaler.step(optimizer)
                            scaler.update()
                        else:
                            outputs = model(inputs)
                            loss = criterion(outputs, labels)

                            if phase == 'train':
                                loss.backward()

                                # Gradient clipping
                                if Config.gradient_clip_val > 0:
                                    torch.nn.utils.clip_grad_norm_(model.parameters(), Config.gradient_clip_val)

                                optimizer.step()

                    # Statistics
                    _, preds = torch.max(outputs, 1)
                    running_loss += loss.item() * inputs.size(0)
                    running_corrects += torch.sum(preds == labels.data)
                    num_samples += inputs.size(0)

                    # Update progress bar
                    current_acc = running_corrects.double() / num_samples
                    pbar.set_postfix({
                        'Loss': f'{running_loss/num_samples:.4f}',
                        'Acc': f'{current_acc:.4f}'
                    })

                    # Per-batch cosine annealing step with a fractional epoch
                    if Config.use_cosine_annealing and phase == 'train' and isinstance(scheduler, CosineAnnealingWarmRestarts):
                        scheduler.step(epoch + batch_idx / len(dataloaders[phase]))

                if num_samples == 0:
                    # Otherwise the epoch statistics below fail with an opaque error.
                    raise RuntimeError(
                        f"The '{phase}' DataLoader yielded no batches (e.g. drop_last=True "
                        "with fewer samples than batch_size); cannot compute epoch metrics."
                    )

                # Epoch statistics
                epoch_loss = running_loss / num_samples
                epoch_acc = running_corrects.double() / num_samples

                print(f'{phase.capitalize()} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

                # Store metrics
                epoch_metrics[f'{phase}_loss'] = epoch_loss
                epoch_metrics[f'{phase}_acc'] = epoch_acc.item()
                history[f'{phase}_loss'].append(epoch_loss)
                history[f'{phase}_acc'].append(epoch_acc.item())

                # Log to TensorBoard
                if writer:
                    writer.add_scalar(f'{phase}/Loss', epoch_loss, epoch)
                    writer.add_scalar(f'{phase}/Accuracy', epoch_acc, epoch)

                # Update scheduler for ReduceLROnPlateau
                if phase == 'val' and not Config.use_cosine_annealing:
                    scheduler.step(epoch_acc)

            # Check for improvement (always judged by validation accuracy)
            current_val_acc = epoch_metrics['val_acc']
            is_best = current_val_acc > best_acc

            if is_best:
                best_acc = current_val_acc
                epochs_since_improvement = 0
            else:
                epochs_since_improvement += 1

            # Save checkpoint periodically and on every improvement
            if (epoch + 1) % Config.checkpoint_every == 0 or is_best:
                checkpoint_manager.save_checkpoint(
                    model, optimizer, scheduler, epoch, epoch_metrics, is_best
                )
                if is_best:
                    best_saved = True

            # Log to wandb
            if Config.use_wandb:
                wandb.log(epoch_metrics, step=epoch)

            # Early stopping
            if epochs_since_improvement >= Config.patience:
                print(f'Early stopping triggered after {epoch + 1} epochs!')
                break

            print(f'Best val Acc so far: {best_acc:.4f}')
            print()

    except KeyboardInterrupt:
        print('Training interrupted by user')

    finally:
        if writer:
            writer.close()

    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best val Acc: {best_acc:.4f}')

    # Reload the best weights, but only if they belong to this training run
    best_model_path = os.path.join(Config.checkpoint_dir, 'best_model.pth')
    if (best_saved or Config.resume_from_checkpoint) and os.path.exists(best_model_path):
        load_checkpoint(best_model_path, model)
        print('Loaded best model weights')

    return model, history


# In [None]:
# ===== Section: Enhanced Evaluation Functions =====
def evaluate_model(model, test_loader, class_names, save_results=True):
    """Run ``model`` over ``test_loader`` and report classification metrics.

    All metrics are computed against the full label set
    ``0 .. len(class_names) - 1``. A class that is absent from the test split
    (or never predicted) therefore still gets its own row and column. It no
    longer shrinks the confusion matrix, shifts per-class indices, or makes
    ``classification_report`` reject ``target_names``.

    Args:
        model: Classifier returning one logit per class. It is put in eval
            mode and moved to ``Config.device``.
        test_loader: Iterable of ``(inputs, labels)`` batches.
        class_names: Class names indexed by label id.
        save_results: If True, write ``evaluation_results.json`` into
            ``Config.results_dir`` (created if missing).

    Returns:
        Tuple ``(accuracy, conf_matrix, report, all_probs)``:
        ``conf_matrix`` is ``len(class_names) x len(class_names)``,
        ``report`` is sklearn's ``output_dict`` classification report, and
        ``all_probs`` is a list of per-sample softmax probability arrays.
    """
    model.eval()
    model = model.to(Config.device)
    class_names = list(class_names)
    label_ids = list(range(len(class_names)))

    all_preds = []
    all_labels = []
    all_probs = []

    # No gradient calculation needed for evaluation
    with torch.no_grad():
        for inputs, labels in tqdm(test_loader, desc='Evaluating'):
            inputs = inputs.to(Config.device)

            outputs = model(inputs)
            probabilities = F.softmax(outputs, dim=1)
            _, preds = torch.max(outputs, 1)

            all_preds.extend(preds.cpu().numpy())
            # Labels are only needed on the host; no device round-trip.
            all_labels.extend(labels.cpu().numpy())
            all_probs.extend(probabilities.cpu().numpy())

    accuracy = accuracy_score(all_labels, all_preds)
    conf_matrix = confusion_matrix(all_labels, all_preds, labels=label_ids)

    report = classification_report(
        all_labels, all_preds,
        labels=label_ids, target_names=class_names,
        digits=4, output_dict=True, zero_division=0,
    )

    # Per-class metrics, aligned with class_names by construction.
    precision, recall, f1, _ = precision_recall_fscore_support(
        all_labels, all_preds, labels=label_ids, average=None, zero_division=0
    )

    # The report already holds the macro / weighted averages over the same labels.
    macro_f1 = report['macro avg']['f1-score']
    weighted_f1 = report['weighted avg']['f1-score']

    print(f"Test Accuracy: {accuracy:.4f}")
    print(f"Macro F1-Score: {macro_f1:.4f}")
    print(f"Weighted F1-Score: {weighted_f1:.4f}")
    print("\nPer-class metrics:")
    for i, class_name in enumerate(class_names):
        print(f"{class_name:30} - Precision: {precision[i]:.4f}, Recall: {recall[i]:.4f}, F1: {f1[i]:.4f}")

    if save_results:
        os.makedirs(Config.results_dir, exist_ok=True)
        results = {
            'accuracy': float(accuracy),
            'macro_f1': float(macro_f1),
            'weighted_f1': float(weighted_f1),
            'classification_report': report,
            'confusion_matrix': conf_matrix.tolist(),
            'class_names': class_names,
        }

        with open(os.path.join(Config.results_dir, 'evaluation_results.json'), 'w') as f:
            json.dump(results, f, indent=4)

    return accuracy, conf_matrix, report, all_probs


def compute_class_performance(all_labels, all_preds, class_names):
    """Compute one-vs-rest precision, recall, F1 and support for every class.

    Args:
        all_labels: Sequence of true integer label ids.
        all_preds: Sequence of predicted integer label ids.
        class_names: Class names indexed by label id. The ids
            ``0 .. len(class_names) - 1`` are the ones reported.

    Returns:
        Dict mapping each class name to ``{'precision', 'recall',
        'f1_score', 'support'}``. Metrics that are undefined (no true or no
        predicted samples of that class) are reported as 0.
    """
    if len(class_names) == 0:
        return {}

    # One call with explicit labels and average=None yields the same
    # per-class values as binarising the data for each class and calling
    # sklearn once per metric. It scans the data once instead of
    # 3 * len(class_names) times.
    precision, recall, f1, support = precision_recall_fscore_support(
        all_labels,
        all_preds,
        labels=list(range(len(class_names))),
        average=None,
        zero_division=0,
    )

    return {
        class_name: {
            'precision': precision[i],
            'recall': recall[i],
            'f1_score': f1[i],
            'support': int(support[i]),
        }
        for i, class_name in enumerate(class_names)
    }


In [None]:
# ===== Section: Enhanced Visualization Functions =====
def plot_training_history(history, save_path=None):
    """Draw a 2x2 dashboard of the training run.

    Panels: train/val loss, train/val accuracy, learning rate (when
    ``history['lr']`` is non-empty), and validation accuracy with the best
    epoch annotated.

    Args:
        history: Mapping with per-epoch lists under ``train_loss``,
            ``val_loss``, ``train_acc`` and ``val_acc``. An optional ``lr``
            list feeds the learning-rate panel.
        save_path: If given, the figure is also written there at 300 dpi.
    """
    _, axes = plt.subplots(2, 2, figsize=(15, 12))
    loss_ax, acc_ax = axes[0, 0], axes[0, 1]
    lr_ax, val_ax = axes[1, 0], axes[1, 1]

    # The top row shares one layout: a training curve vs. its validation curve.
    paired_panels = (
        (loss_ax, "loss", "Loss", "Training and Validation Loss"),
        (acc_ax, "acc", "Accuracy", "Training and Validation Accuracy"),
    )
    for ax, key, metric, title in paired_panels:
        ax.plot(history[f"train_{key}"], label=f"Training {metric}", linewidth=2)
        ax.plot(history[f"val_{key}"], label=f"Validation {metric}", linewidth=2)
        ax.set_title(title, fontsize=14, fontweight='bold')
        ax.set_xlabel("Epochs")
        ax.set_ylabel(metric)
        ax.legend()
        ax.grid(True, alpha=0.3)

    learning_rates = history.get('lr', [])
    if len(learning_rates) == 0:
        lr_ax.text(0.5, 0.5, 'Learning Rate\nNot Tracked', ha='center', va='center', transform=lr_ax.transAxes)
    else:
        lr_ax.plot(learning_rates, label="Learning Rate", linewidth=2, color='orange')
        lr_ax.set_title("Learning Rate Schedule", fontsize=14, fontweight='bold')
        lr_ax.set_xlabel("Epochs")
        lr_ax.set_ylabel("Learning Rate")
        lr_ax.set_yscale('log')
        lr_ax.grid(True, alpha=0.3)

    val_acc = history["val_acc"]
    if len(val_acc) > 0:
        val_ax.plot(val_acc, label="Validation Accuracy", linewidth=2, color='green')
        val_ax.set_title("Validation Accuracy (Detailed)", fontsize=14, fontweight='bold')
        val_ax.set_xlabel("Epochs")
        val_ax.set_ylabel("Accuracy")
        val_ax.grid(True, alpha=0.3)

        # Point an arrow at the first epoch that reached the peak accuracy.
        peak = max(val_acc)
        peak_epoch = val_acc.index(peak)
        val_ax.annotate(
            f'Best: {peak:.4f}',
            xy=(peak_epoch, peak),
            xytext=(peak_epoch + 2, peak - 0.02),
            arrowprops=dict(arrowstyle='->', color='red'),
        )

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches="tight")
        print(f"Training history plot saved to {save_path}")

    plt.show()


def plot_confusion_matrix(conf_matrix, class_names, save_path=None, figsize=(12, 10)):
    """Plot a row-normalised confusion matrix as an annotated heatmap.

    Each row is divided by its total. A cell is therefore the fraction of
    samples of the true (row) class that were predicted as the column class.
    Cells whose raw count is 0 are masked (drawn blank).

    Args:
        conf_matrix: Square count matrix (rows = true, columns = predicted).
            Any array-like is accepted, including the nested list stored
            under ``confusion_matrix`` in ``evaluation_results.json``.
        class_names: Tick labels, in label-id order.
        save_path: If given, the figure is also saved there at 300 dpi.
        figsize: Matplotlib figure size in inches.
    """
    counts = np.asarray(conf_matrix, dtype=float)
    row_totals = counts.sum(axis=1, keepdims=True)
    # A row with no samples would be 0/0 = NaN. Leave it at 0 instead; every
    # cell of such a row is masked below anyway.
    norm_conf_matrix = np.divide(
        counts, row_totals, out=np.zeros_like(counts), where=row_totals != 0
    )
    mask = counts == 0

    plt.figure(figsize=figsize)
    sns.heatmap(
        norm_conf_matrix,
        annot=True,
        cmap="Blues",
        fmt=".3f",
        xticklabels=class_names,
        yticklabels=class_names,
        cbar_kws={'label': 'Normalized Count'},
        mask=mask
    )

    plt.title("Normalized Confusion Matrix", fontsize=16, fontweight='bold', pad=20)
    plt.ylabel("True Label", fontsize=12)
    plt.xlabel("Predicted Label", fontsize=12)
    plt.xticks(rotation=45, ha="right")
    plt.yticks(rotation=0)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches="tight")
        print(f"Confusion matrix saved to {save_path}")

    plt.show()


def plot_class_distribution(class_counts, class_names, save_path=None):
    """Draw a bar chart of samples per class, largest class first.

    Each bar is coloured along the viridis colormap and labelled with its
    count.

    Args:
        class_counts: Number of samples per class, aligned with ``class_names``.
        class_names: Class names used as x tick labels.
        save_path: If given, the figure is also written there at 300 dpi.
    """
    plt.figure(figsize=(15, 8))

    # Pair each name with its count, then rank by count (descending).
    ranked = sorted(zip(class_names, class_counts), key=lambda pair: pair[1], reverse=True)
    names, counts = zip(*ranked)

    positions = range(len(names))
    palette = plt.cm.viridis(np.linspace(0, 1, len(names)))
    bars = plt.bar(positions, counts, color=palette)

    plt.xlabel('Classes', fontsize=12)
    plt.ylabel('Number of Samples', fontsize=12)
    plt.title('Dataset Class Distribution (Sorted by Count)', fontsize=14, fontweight='bold')
    plt.xticks(positions, names, rotation=45, ha='right')

    # Lift each count label 1% of the tallest bar above its own bar.
    label_offset = max(counts) * 0.01
    for bar, count in zip(bars, counts):
        x_center = bar.get_x() + bar.get_width() / 2
        plt.text(x_center, bar.get_height() + label_offset,
                 str(count), ha='center', va='bottom', fontsize=10)

    plt.grid(axis='y', alpha=0.3)
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches="tight")
        print(f"Class distribution plot saved to {save_path}")

    plt.show()


def visualize_model_predictions(model, dataset, class_names, num_samples=8, save_path=None):
    """Show random ``dataset`` samples with the model's prediction and confidence.

    Titles are green when the prediction matches the true label and red
    otherwise. The grid has 4 columns and as many rows as ``num_samples``
    needs. The default of 8 keeps the 2x4 layout, and larger values are no
    longer silently truncated to 8.

    Args:
        model: Classifier returning one logit per class. It is put in eval
            mode and moved to ``Config.device``.
        dataset: Indexable dataset yielding ``(image_tensor, label)``. Images
            are presumably normalised with ``Config.normalize_mean`` /
            ``Config.normalize_std`` (they are de-normalised with those for
            display) — TODO confirm against the dataset transforms.
        class_names: Class names indexed by label id.
        num_samples: Number of random samples to display.
        save_path: If given, the figure is also saved there at 300 dpi.
    """
    model.eval()
    device = Config.device
    model = model.to(device)

    indices = torch.randperm(len(dataset))[:num_samples]

    n_cols = 4
    n_rows = max(1, (len(indices) + n_cols - 1) // n_cols)
    # squeeze=False keeps `axes` 2-D even for a single row, so flatten() always works.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(20, 5 * n_rows), squeeze=False)
    axes = axes.flatten()

    # De-normalisation statistics are loop-invariant.
    mean = torch.tensor(Config.normalize_mean)
    std = torch.tensor(Config.normalize_std)

    for ax, idx in zip(axes, indices):
        image, true_label = dataset[int(idx)]
        true_idx = int(true_label)

        # Move 3-channel channels-first images to HWC for imshow, then de-normalise.
        img_display = image.detach().cpu().clone()
        if img_display.shape[0] == 3:
            img_display = img_display.permute(1, 2, 0)
        img_display = torch.clamp(img_display * std + mean, 0, 1)

        with torch.no_grad():
            probabilities = F.softmax(model(image.unsqueeze(0).to(device)), dim=1)
            confidence, predicted = torch.max(probabilities, 1)
        pred_idx = predicted.item()

        ax.imshow(img_display.numpy())
        ax.axis('off')

        # Color code: green for correct, red for incorrect
        color = 'green' if pred_idx == true_idx else 'red'
        title = f'True: {class_names[true_idx]}\nPred: {class_names[pred_idx]}\nConf: {confidence.item():.3f}'
        ax.set_title(title, color=color, fontsize=10, fontweight='bold')

    # Hide unused subplots
    for ax in axes[len(indices):]:
        ax.axis('off')

    plt.suptitle('Model Predictions with Confidence Scores', fontsize=16, fontweight='bold')
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches="tight")
        print(f"Prediction visualization saved to {save_path}")

    plt.show()

In [None]:
# ===== Section: Prediction Demo =====
def predict_random_images(model, test_dataset, class_names, num_images=5):
    """Plot random test images, one per row, titled with true vs. predicted class.

    Titles are green for correct and red for incorrect predictions. Images are
    de-normalised with ``Config.normalize_mean`` / ``Config.normalize_std``
    before display, so they render with natural colours instead of clipped,
    normalised tensor values. This presumes the dataset transforms used the
    same statistics — TODO confirm. The figure is saved to
    ``Config.results_dir/sample_predictions.png`` (directory created if
    missing).

    Args:
        model: Classifier returning one logit per class. It is put in eval
            mode and moved to ``Config.device``.
        test_dataset: Indexable dataset yielding channels-first
            ``(image_tensor, label)`` pairs.
        class_names: Class names indexed by label id.
        num_images: Number of random images to show.
    """
    model.eval()
    model = model.to(Config.device)

    indices = torch.randperm(len(test_dataset))[:num_images]

    # Shape (C, 1, 1) so the statistics broadcast over a (C, H, W) image.
    mean = torch.tensor(Config.normalize_mean).view(-1, 1, 1)
    std = torch.tensor(Config.normalize_std).view(-1, 1, 1)

    plt.figure(figsize=(15, 3 * num_images))

    for i, idx in enumerate(indices):
        image, label = test_dataset[int(idx)]
        label = int(label)

        # Undo normalisation for display and clamp to the valid [0, 1] range.
        image_for_display = torch.clamp(image.detach().cpu() * std + mean, 0, 1)

        with torch.no_grad():
            output = model(image.unsqueeze(0).to(Config.device))
            _, predicted = torch.max(output, 1)

        predicted_class = class_names[predicted.item()]
        true_class = class_names[label]
        is_correct = predicted.item() == label

        plt.subplot(num_images, 1, i + 1)
        plt.imshow(image_for_display.permute(1, 2, 0).numpy())
        plt.axis("off")
        plt.title(
            f"True: {true_class} | Predicted: {predicted_class}",
            color="green" if is_correct else "red",
        )

    plt.tight_layout()
    os.makedirs(Config.results_dir, exist_ok=True)
    plt.savefig(os.path.join(Config.results_dir, "sample_predictions.png"))
    plt.show()

In [None]:
# ===== Section: Enhanced Main Execution =====
def _json_safe_config():
    """Return Config's public, non-callable settings as a JSON- and pickle-safe dict.

    ``vars(Config)`` is unsuitable here. On a class it is a read-only
    ``mappingproxy``, which neither ``json`` nor ``pickle`` can serialise, and
    it can contain non-JSON objects such as a device. Plain scalars (and
    lists/tuples of them) are kept as-is; anything else is stored as
    ``str(value)``.
    """
    scalar_types = (bool, int, float, str, type(None))
    snapshot = {}
    for key in dir(Config):
        if key.startswith('_'):
            continue
        value = getattr(Config, key)
        if callable(value):
            continue
        if isinstance(value, scalar_types):
            snapshot[key] = value
        elif isinstance(value, (list, tuple)) and all(isinstance(v, scalar_types) for v in value):
            snapshot[key] = list(value)
        else:
            snapshot[key] = str(value)
    return snapshot


def _format_last(values, spec='.4f'):
    """Format the last entry of a per-epoch metric list, or return 'N/A' if it is empty."""
    return format(float(values[-1]), spec) if len(values) else 'N/A'


def _format_classification_report(report):
    """Render sklearn's ``output_dict`` classification report as a fixed-width text table."""
    if not report:
        return 'N/A'
    rows = [f"{'':>30} {'precision':>10} {'recall':>10} {'f1-score':>10} {'support':>10}", '']
    for name, scores in report.items():
        if isinstance(scores, dict):
            rows.append(
                f"{name:>30} {scores['precision']:>10.4f} {scores['recall']:>10.4f} "
                f"{scores['f1-score']:>10.4f} {int(scores['support']):>10d}"
            )
        else:
            # The scalar 'accuracy' entry has no precision/recall columns.
            rows.append(f"{name:>30} {'':>10} {'':>10} {float(scores):>10.4f}")
    return '\n'.join(rows)


def main():
    """Run the end-to-end pipeline: data, model, training, evaluation, reporting.

    Steps:
    1. Build the train/val/test splits and loaders.
    2. Create the model.
    3. Train it. ``train_model`` reloads ``best_model.pth`` when that
       checkpoint exists.
    4. Evaluate on the test set.
    5. Plot the history, confusion matrix and sample predictions.
    6. Save a full checkpoint and a bare state dict into ``Config.output_dir``.
    7. Write a Markdown report and a JSON history into ``Config.results_dir``.

    Returns:
        ``(model, history, accuracy)``: the trained model, the per-epoch
        history dict, and the test accuracy.
    """
    print("🚀 Starting Enhanced Crop Disease Classification Training Pipeline")
    print("=" * 70)

    # Print configuration
    print("\n📋 Configuration:")
    print(f"   Model: {Config.model_type}")
    print(f"   Pretrained: {Config.pretrained}")
    print(f"   Image size: {Config.img_size}x{Config.img_size}")
    print(f"   Batch size: {Config.batch_size}")
    print(f"   Learning rate: {Config.learning_rate}")
    print(f"   Epochs: {Config.num_epochs}")
    print(f"   Device: {Config.device}")
    print(f"   Mixed precision: {Config.use_mixed_precision}")
    print(f"   Advanced augmentation: {Config.use_advanced_augmentation}")

    # Everything below writes into these directories; make sure they exist.
    os.makedirs(Config.output_dir, exist_ok=True)
    os.makedirs(Config.results_dir, exist_ok=True)

    # 1. Load and prepare data
    print("\n📂 Preparing datasets...")
    train_dataset, val_dataset, test_dataset, class_names, num_classes = load_and_split_dataset()

    train_loader, val_loader, test_loader = create_dataloaders(
        train_dataset, val_dataset, test_dataset
    )
    dataloaders = {"train": train_loader, "val": val_loader}

    # 2. Create model
    print(f"\n🏗️  Creating {Config.model_type} model...")
    model = get_model(Config.model_type, num_classes, Config.pretrained)
    model = model.to(Config.device)

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"   Total parameters: {total_params:,}")
    print(f"   Trainable parameters: {trainable_params:,}")
    # Assumes 4 bytes per parameter (float32).
    print(f"   Model size: {total_params * 4 / 1024**2:.2f} MB")

    # 3. Train the model
    print(f"\n🏋️  Training {Config.model_type} model...")
    start_time = time.time()

    model, history = train_model(model, dataloaders, Config.num_epochs)

    training_time = time.time() - start_time
    print(f"\n⏱️  Training completed in {training_time//3600:.0f}h {(training_time%3600)//60:.0f}m {training_time%60:.0f}s")

    # 4. Evaluate on test set (report is sklearn's output_dict classification report)
    print("\n📊 Evaluating on test set...")
    accuracy, conf_matrix, report, all_probs = evaluate_model(model, test_loader, class_names)

    # 5. Create comprehensive visualizations
    print("\n📈 Creating visualizations...")

    plot_training_history(
        history,
        save_path=os.path.join(Config.results_dir, f"{Config.model_type}_training_history.png")
    )

    plot_confusion_matrix(
        conf_matrix,
        class_names,
        save_path=os.path.join(Config.results_dir, f"{Config.model_type}_confusion_matrix.png")
    )

    visualize_model_predictions(
        model, test_dataset, class_names,
        save_path=os.path.join(Config.results_dir, f"{Config.model_type}_predictions.png")
    )

    # 6. Save models and results
    print("\n💾 Saving results...")
    config_snapshot = _json_safe_config()

    final_model_path = os.path.join(Config.output_dir, f"{Config.model_type}_final.pth")
    torch.save({
        'model_state_dict': model.state_dict(),
        'class_names': class_names,
        'num_classes': num_classes,
        'config': config_snapshot,
        # Store a plain Python float (sklearn may return a NumPy scalar).
        'accuracy': float(accuracy),
        'training_history': history
    }, final_model_path)
    print(f"   Final model saved to {final_model_path}")

    # Bare state dict (weights only) for production use
    production_model_path = os.path.join(Config.output_dir, f"{Config.model_type}_production.pth")
    torch.save(model.state_dict(), production_model_path)
    print(f"   Production model saved to {production_model_path}")

    # 7. Generate comprehensive report
    print("\n📝 Generating comprehensive report...")

    best_val_acc = max(history['val_acc']) if history['val_acc'] else 0
    best_epoch = history['val_acc'].index(best_val_acc) + 1 if history['val_acc'] else 0

    report_content = f"""# 🌱 Crop Disease Classification - Training Report

## 📋 Experiment Configuration
- **Model Architecture**: {Config.model_type}
- **Pretrained**: {Config.pretrained}
- **Image Size**: {Config.img_size}x{Config.img_size}
- **Batch Size**: {Config.batch_size}
- **Learning Rate**: {Config.learning_rate}
- **Weight Decay**: {Config.weight_decay}
- **Epochs Trained**: {len(history['train_loss']) if history['train_loss'] else 0}
- **Early Stopping Patience**: {Config.patience}
- **Mixed Precision**: {Config.use_mixed_precision}
- **Advanced Augmentation**: {Config.use_advanced_augmentation}
- **Augmentation Strength**: {Config.augmentation_strength}

## 📊 Dataset Information
- **Total Classes**: {num_classes}
- **Training Samples**: {len(train_dataset)}
- **Validation Samples**: {len(val_dataset)}
- **Test Samples**: {len(test_dataset)}

## 🏆 Performance Results
- **Best Validation Accuracy**: {best_val_acc:.4f} (Epoch {best_epoch})
- **Final Test Accuracy**: {accuracy:.4f}
- **Training Time**: {training_time//3600:.0f}h {(training_time%3600)//60:.0f}m {training_time%60:.0f}s

## 🔧 Model Information
- **Total Parameters**: {total_params:,}
- **Trainable Parameters**: {trainable_params:,}
- **Model Size**: {total_params * 4 / 1024**2:.2f} MB

## 📈 Training Progress
- **Final Training Loss**: {_format_last(history['train_loss'])}
- **Final Validation Loss**: {_format_last(history['val_loss'])}
- **Final Training Accuracy**: {_format_last(history['train_acc'])}
- **Final Validation Accuracy**: {_format_last(history['val_acc'])}

## 🎯 Detailed Classification Results
```
{_format_classification_report(report)}
```

## 📁 Generated Files
- Model checkpoint: `{final_model_path}`
- Production model: `{production_model_path}`
- Training history: `training_history.json`
- Evaluation results: `evaluation_results.json`
- Visualizations: `{Config.results_dir}/`

## ⚙️ Reproduction Commands
```bash
# Install dependencies
pip install torch torchvision albumentations wandb tensorboard

# Run training
python model_training.py
```

---
*Report generated on {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}*
"""

    report_path = os.path.join(Config.results_dir, "training_report.md")
    with open(report_path, "w") as f:
        f.write(report_content)
    print(f"   Comprehensive report saved to {report_path}")

    # Save training history as JSON
    history_path = os.path.join(Config.results_dir, "training_history.json")
    history_dict = {
        "train_loss": [float(val) for val in history["train_loss"]],
        "val_loss": [float(val) for val in history["val_loss"]],
        "train_acc": [float(val) for val in history["train_acc"]],
        "val_acc": [float(val) for val in history["val_acc"]],
        "config": config_snapshot,
        "class_names": class_names,
        "num_classes": num_classes
    }

    with open(history_path, "w") as f:
        json.dump(history_dict, f, indent=4)
    print(f"   Training history saved to {history_path}")

    # 8. Final summary
    print("\n🎉 Training Pipeline Completed Successfully!")
    print(f"   📊 Best validation accuracy: {best_val_acc:.4f}")
    print(f"   🧪 Final test accuracy: {accuracy:.4f}")
    print(f"   💾 Models saved in: {Config.output_dir}")
    print(f"   📈 Results saved in: {Config.results_dir}")

    if Config.use_tensorboard:
        print(f"   📊 TensorBoard logs: {Config.tensorboard_dir}")
        print(f"   💡 View with: tensorboard --logdir {Config.tensorboard_dir}")

    print("\n" + "=" * 70)

    return model, history, accuracy


# ===== Utility Functions =====
def resume_training(checkpoint_path, data_dir=None):
    """Point Config at a checkpoint (and optionally a data dir), then rerun ``main()``.

    This only sets ``Config.resume_from_checkpoint``. Whether training
    actually restarts from that file depends on code outside this function
    reading it. NOTE(review): confirm the training loop honours it.

    Args:
        checkpoint_path: Checkpoint to resume from.
        data_dir: Optional override for ``Config.data_dir``. Ignored if falsy.

    Returns:
        Whatever ``main()`` returns.
    """
    Config.resume_from_checkpoint = checkpoint_path
    if data_dir:
        Config.data_dir = data_dir

    print(f"Resuming training from: {checkpoint_path}")
    pipeline_result = main()
    return pipeline_result


def evaluate_saved_model(model_path, data_dir=None):
    """Load weights from ``model_path`` and evaluate them on the test split.

    Two formats are accepted:
    - a full checkpoint dict with a ``'model_state_dict'`` entry (as written
      to ``*_final.pth`` by ``main``);
    - a bare state dict (the ``*_production.pth`` file).

    The architecture is rebuilt from ``Config.model_type`` without
    pretrained weights.

    Args:
        model_path: Path to the saved ``.pth`` file. ``torch.load`` unpickles
            it, so only load files from trusted sources.
        data_dir: Optional override for ``Config.data_dir``. Ignored if falsy.

    Returns:
        ``(accuracy, conf_matrix, report)`` as produced by ``evaluate_model``.
    """
    if data_dir:
        Config.data_dir = data_dir

    # Build the loaders the same way main() does. create_dataloaders is only
    # ever called with real datasets there, so avoid handing it None for the
    # train/val splits.
    train_dataset, val_dataset, test_dataset, class_names, num_classes = load_and_split_dataset()
    _, _, test_loader = create_dataloaders(train_dataset, val_dataset, test_dataset)

    checkpoint = torch.load(model_path, map_location=Config.device)
    model = get_model(Config.model_type, num_classes, False)

    # Full checkpoints nest the weights; production files are bare state dicts.
    state_dict = checkpoint['model_state_dict'] if 'model_state_dict' in checkpoint else checkpoint
    model.load_state_dict(state_dict)

    accuracy, conf_matrix, report, _ = evaluate_model(model, test_loader, class_names)
    return accuracy, conf_matrix, report


# Entry point: run the full pipeline when executed as a script. A notebook
# kernel also sets __name__ to "__main__", so running this cell executes it
# too. Merely importing the module does not.
if __name__ == "__main__":
    main()
