In [None]:
import os
import random
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split, WeightedRandomSampler
from torchvision import datasets, transforms, models
import matplotlib.pyplot as plt
from tqdm import tqdm
from sklearn.metrics import (
    confusion_matrix, 
    ConfusionMatrixDisplay, 
    classification_report,
    balanced_accuracy_score,
    precision_recall_fscore_support
)

# Custom imports
from utils.gpu_utils import CheckGPU, CheckCUDA, CheckGPUBrief, get_device
from utils.guava_dataset import (
    GuavaDataset, 
    load_guava_info, 
    print_guava_summary,
    get_class_labels
)
from utils.dataset_counter import CountDataset, PrintClassBalance

# Reaching this line means every import above (third-party and utils.*) succeeded.
print("‚úÖ All libraries and custom modules imported successfully!")

#### Detect Available GPU, Device Details, CUDA, and cuDNN

In [None]:
# From utils.gpu_utils
# Report GPU and CUDA/cuDNN diagnostics via the project helpers (implemented in utils.gpu_utils).
# CheckCUDA() is the cell's last expression, so Jupyter also displays its return value, if any.
CheckGPU()
CheckCUDA()

### Global Configuration Variables

In [None]:
# ============================
# GLOBAL CONFIGURATION
# ============================
# Every knob for data loading, preprocessing and reproducibility lives in this cell.

# --- Dataset layout: Train/ and Test/ sub-folders under DATASET_DIR ---
DATASET_DIR = "../dataset"
TRAIN_DIR, TEST_DIR = (os.path.join(DATASET_DIR, split) for split in ("Train", "Test"))

# Fraction of Train/ held out as the validation set; the rest is used for training.
VALIDATION_SPLIT = 0.15

# --- Feature switches ---
USE_AUGMENTATION = True       # random crop/flip/rotation/color jitter on training images
USE_WEIGHTED_SAMPLER = True   # oversample minority classes to counter class imbalance

# --- Normalization statistics ---
# ImageNet statistics match what the pretrained backbones were trained with.
USE_IMAGENET_NORM = True
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Dataset-specific statistics, used when USE_IMAGENET_NORM is False.
# NOTE(review): these are placeholders; nothing in this notebook computes real values.
GUAVA_MEAN = [0.5, 0.5, 0.5]
GUAVA_STD = [0.25, 0.25, 0.25]

NORMALIZE_MEAN, NORMALIZE_STD = (
    (IMAGENET_MEAN, IMAGENET_STD) if USE_IMAGENET_NORM else (GUAVA_MEAN, GUAVA_STD)
)
print(
    "üìä Using ImageNet normalization values (recommended for transfer learning)"
    if USE_IMAGENET_NORM
    else "üìä Using Guava-specific normalization values"
)

# --- Input pipeline ---
IMG_HEIGHT = IMG_WIDTH = 224
BATCH_SIZE = 32  # reduce if the GPU runs out of memory

# Filled in from the dataset folders by the analysis cell; None means "not loaded yet".
NUM_CLASSES = None

# Backbone: 'resnet50' or 'efficientnet_b3'
MODEL_ARCH = 'resnet50'

# Checkpoints and figures are written here.
MODEL_SAVE_DIR = "../models"
os.makedirs(MODEL_SAVE_DIR, exist_ok=True)

# --- Reproducibility: seed every RNG the notebook touches, in a fixed order ---
SEED = 42
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(SEED)

print(f"\n‚úÖ Global configuration set successfully!")
for _config_line in (
    f"   Dataset: {DATASET_DIR}",
    f"   Model Architecture: {MODEL_ARCH.upper()}",
    f"   Image Size: {IMG_HEIGHT}x{IMG_WIDTH}",
    f"   Batch Size: {BATCH_SIZE}",
    f"   Validation Split: {VALIDATION_SPLIT*100:.0f}%",
    f"   Augmentation: {'ENABLED' if USE_AUGMENTATION else 'DISABLED'}",
    f"   Weighted Sampler: {'ENABLED' if USE_WEIGHTED_SAMPLER else 'DISABLED'}",
):
    print(_config_line)

### Guava Dataset Analysis & Information

Load and analyze the Guava Ripeness dataset structure, class distribution, and statistics.

In [None]:
# Summarise the dataset on disk and derive NUM_CLASSES / CLASS_NAMES from it.
# Later cells check `NUM_CLASSES is None` to detect that this step did not succeed.
print("üîç Loading Guava dataset information...")

if os.path.exists(TRAIN_DIR):
    dataset_info = load_guava_info(DATASET_DIR)
    print_guava_summary(dataset_info)

    # Class count and names as reported by utils.guava_dataset.load_guava_info.
    NUM_CLASSES = dataset_info['num_classes']
    CLASS_NAMES = dataset_info['classes']

    print(f"\n‚úÖ Dataset information loaded successfully!")
    print(f"   Detected {NUM_CLASSES} classes: {CLASS_NAMES}")

    # Per-split details via utils.dataset_counter.CountDataset.
    _rule = "\n" + "=" * 60
    print(_rule)
    print("üìÅ TRAINING SET DETAILS")
    train_info = CountDataset(TRAIN_DIR)

    if os.path.exists(TEST_DIR):
        print(_rule)
        print("üìÅ TEST SET DETAILS")
        test_info = CountDataset(TEST_DIR)
else:
    # Without Train/ nothing downstream can run; show the expected folder layout instead.
    print(f"‚ùå Training directory not found: {TRAIN_DIR}")
    print("\nüìù Please organize your dataset as follows:")
    print("   dataset/")
    print("   ‚îú‚îÄ‚îÄ Train/")
    print("   ‚îÇ   ‚îú‚îÄ‚îÄ day_01/  (or your age class names)")
    print("   ‚îÇ   ‚îú‚îÄ‚îÄ day_02/")
    print("   ‚îÇ   ‚îî‚îÄ‚îÄ ...")
    print("   ‚îî‚îÄ‚îÄ Test/")
    print("       ‚îú‚îÄ‚îÄ day_01/")
    print("       ‚îî‚îÄ‚îÄ ...")

### Visualize Sample Images from Dataset

Display sample guava images from each class to understand the data better.

In [None]:
from PIL import Image

print("üñºÔ∏è  Displaying sample guava images from dataset...")

if NUM_CLASSES is None:
    print("‚ùå Dataset not loaded. Please run the previous cell first.")
else:
    # One random training image per class, for at most the first 12 classes.
    num_samples = min(NUM_CLASSES, 12)
    cols = min(4, num_samples)
    rows = (num_samples + cols - 1) // cols  # ceiling division

    # squeeze=False always returns a 2-D array of Axes, so a single panel needs no special case.
    fig, axes = plt.subplots(rows, cols, figsize=(4*cols, 4*rows), squeeze=False)
    axes = axes.ravel()

    # A private RNG seeded from SEED makes the shown images reproducible across re-runs
    # without advancing the global `random` state seeded in the configuration cell.
    sample_rng = random.Random(SEED)
    image_exts = {'.jpg', '.jpeg', '.png', '.bmp'}

    for idx, class_name in enumerate(CLASS_NAMES[:num_samples]):
        class_dir = os.path.join(TRAIN_DIR, class_name)

        # Sorted because os.listdir order is filesystem-dependent, which would defeat the seed.
        images = sorted(f for f in os.listdir(class_dir)
                        if os.path.splitext(f)[1].lower() in image_exts)

        if images:
            sample_img = sample_rng.choice(images)
            img_path = os.path.join(class_dir, sample_img)
            # The context manager closes the file; convert() returns an independent, loaded copy.
            with Image.open(img_path) as src_img:
                img = src_img.convert('RGB')

            axes[idx].imshow(img)
            axes[idx].axis('off')
            axes[idx].set_title(f'{class_name}\n{img.size[0]}x{img.size[1]}px',
                                fontsize=10, fontweight='bold')
        else:
            axes[idx].text(0.5, 0.5, 'No images', ha='center', va='center')
            axes[idx].axis('off')

    # Hide grid cells beyond the last class.
    for idx in range(num_samples, len(axes)):
        axes[idx].axis('off')

    fig.suptitle('Sample Guava Images by Ripeness Class', fontsize=16, fontweight='bold')
    fig.tight_layout()
    plt.show()

    print("‚úÖ Sample visualization complete!")

### Data Augmentation & Transforms

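Define the transforms for training (with optional augmentation) and for validation/test (deterministic resize + normalization).
The training augmentations are tailored to fruit images. Color jitter, and the hue shift in particular, perturbs color, which is a primary ripeness cue, so tune it carefully.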

In [None]:
# Deterministic preprocessing for validation and test: resize straight to the model input size.
val_test_transforms = transforms.Compose([
    transforms.Resize((IMG_HEIGHT, IMG_WIDTH)),
    transforms.ToTensor(),
    transforms.Normalize(mean=NORMALIZE_MEAN, std=NORMALIZE_STD),
])

if not USE_AUGMENTATION:
    # Same deterministic pipeline as validation/test, built as a separate object.
    train_transforms = transforms.Compose([
        transforms.Resize((IMG_HEIGHT, IMG_WIDTH)),
        transforms.ToTensor(),
        transforms.Normalize(mean=NORMALIZE_MEAN, std=NORMALIZE_STD),
    ])
    print("‚ö†Ô∏è  Training augmentation DISABLED")
else:
    # Stage 1 operates on PIL images: resize 32 px beyond the target so RandomCrop can shift the
    # framing, then apply geometric and photometric jitter.
    # NOTE(review): color is the main ripeness cue. ColorJitter(hue=0.1) shifts hue by up to
    # +/-0.1 of the color wheel, which may push a sample toward a neighbouring ripeness class.
    # NOTE(review): training crops come from a (H+32, W+32) resize, while validation/test resize
    # straight to (H, W), so objects appear at a slightly different scale at evaluation time.
    pil_stage = [
        transforms.Resize((IMG_HEIGHT + 32, IMG_WIDTH + 32)),
        transforms.RandomCrop((IMG_HEIGHT, IMG_WIDTH)),
        transforms.RandomRotation(30),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.3),
        transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.4, hue=0.1),
        transforms.RandomPerspective(distortion_scale=0.2, p=0.3),
    ]
    # Stage 2 operates on tensors: convert, normalize, then erase a random patch
    # (p=0.1, simulating occlusion).
    tensor_stage = [
        transforms.ToTensor(),
        transforms.Normalize(mean=NORMALIZE_MEAN, std=NORMALIZE_STD),
        transforms.RandomErasing(p=0.1, scale=(0.02, 0.1)),
    ]
    train_transforms = transforms.Compose(pil_stage + tensor_stage)

    for _aug_msg in (
        "‚úÖ Training augmentation ENABLED",
        "   - Rotation: ¬±30¬∞",
        "   - Random crop: Yes",
        "   - Horizontal/Vertical flip",
        "   - Color jitter: brightness/contrast/saturation/hue",
        "   - Perspective distortion",
        "   - Random erasing",
    ):
        print(_aug_msg)

print(f"\n‚úÖ Transforms defined successfully!")
print(f"   Target size: {IMG_HEIGHT}x{IMG_WIDTH}")
print(f"   Normalization: Mean={NORMALIZE_MEAN}, Std={NORMALIZE_STD}")

### Load Guava Datasets

Load the training and test sets.
Create a validation set by randomly splitting off `VALIDATION_SPLIT` of the training data.

In [None]:
if NUM_CLASSES is None:
    print("‚ùå Dataset not loaded. Please run the dataset analysis cell first.")
else:
    # The full Train/ folder is loaded without a transform. After the train/validation split,
    # each subset is wrapped with its own transform (TransformDataset in the DataLoader cell).
    print("üìÇ Loading Guava training dataset...")
    full_train_dataset = GuavaDataset(
        root_dir=TRAIN_DIR,
        transform=None,
    )

    # Labels come from full_train_dataset.class_to_idx, while reports and checkpoints name them
    # through CLASS_NAMES (from load_guava_info). If the two orderings disagreed, every per-class
    # metric would be silently mislabelled, so flag it here.
    # NOTE(review): assumes class_to_idx is a {class_name: index} mapping, as its use as
    # `class_mapping` below suggests. Confirm in utils.guava_dataset.
    mismatched_classes = [
        name for i, name in enumerate(CLASS_NAMES)
        if full_train_dataset.class_to_idx.get(name) != i
    ]
    if mismatched_classes:
        print(f"WARNING: CLASS_NAMES order disagrees with GuavaDataset.class_to_idx for "
              f"{mismatched_classes}; per-class reports may be mislabelled.")

    # Load the test dataset if present, reusing the training class mapping so labels agree.
    if os.path.exists(TEST_DIR):
        print("üìÇ Loading Guava test dataset...")
        test_dataset = GuavaDataset(
            root_dir=TEST_DIR,
            transform=val_test_transforms,  # No augmentation for test
            class_mapping=full_train_dataset.class_to_idx  # Use same class mapping
        )
        num_test = len(test_dataset)
    else:
        print("‚ö†Ô∏è  No separate test set found. Will use validation set for testing.")
        test_dataset = None
        num_test = 0

    num_total_train = len(full_train_dataset)

    print(f"\n‚úÖ Datasets loaded successfully!")
    print(f"   Training samples (before split): {num_total_train:,}")
    print(f"   Test samples: {num_test:,}")

    # Split training into train and validation.
    val_size = int(VALIDATION_SPLIT * num_total_train)
    train_size = num_total_train - val_size
    # Fail early with a clear message. An empty split would otherwise surface much later as a
    # ZeroDivisionError inside the training/validation loops.
    if val_size == 0 or train_size == 0:
        raise ValueError(
            f"VALIDATION_SPLIT={VALIDATION_SPLIT} gives an empty split for "
            f"{num_total_train} training images (train={train_size}, val={val_size})."
        )

    # Seeded generator so the split is identical on every run.
    # NOTE(review): the split is not stratified; with imbalanced classes a small class may be
    # under-represented (or missing) in validation.
    train_subset, val_subset = random_split(
        full_train_dataset,
        [train_size, val_size],
        generator=torch.Generator().manual_seed(SEED)
    )

    print(f"\n‚úÖ Training data split completed:")
    print(f"   üîπ Train: {train_size:,} samples ({(1-VALIDATION_SPLIT)*100:.0f}%)")
    print(f"   üîπ Validation: {val_size:,} samples ({VALIDATION_SPLIT*100:.0f}%)")
    print(f"   üîπ Test: {num_test:,} samples (separate set)")

    print(f"\n‚úÖ Transforms will be assigned in DataLoader creation.")

### Compute Class Weights for Imbalanced Data

Calculate class weights to handle potential class imbalance in the Guava dataset.

In [None]:
if NUM_CLASSES is None:
    print("‚ùå Dataset not loaded.")
else:
    # Count classes over the *training subset*: that is the data the weighted sampler draws from
    # and the weighted loss is trained on. Counting the whole Train/ folder would also include the
    # held-out validation images. Labels are read from `samples`, the same source TransformDataset
    # uses, so the counts cannot diverge from the labels actually fed to the model.
    train_indices = train_subset.indices
    train_labels = [full_train_dataset.samples[idx][1] for idx in train_indices]
    class_counts = np.bincount(np.asarray(train_labels, dtype=np.int64), minlength=NUM_CLASSES)
    class_counts_dict = {i: int(count) for i, count in enumerate(class_counts)}

    print(f"üìä Class distribution in training set:")
    print(f"   Total classes: {len(class_counts)}")
    print(f"   Most populated class: {class_counts.max():,} samples")
    print(f"   Least populated class: {class_counts.min():,} samples")
    print(f"   Average per class: {class_counts.mean():.1f} samples")

    if class_counts.min() > 0:
        imbalance_ratio = class_counts.max() / class_counts.min()
        print(f"   Imbalance ratio: {imbalance_ratio:.2f}x")
    else:
        imbalance_ratio = float('inf')
        print(f"   ‚ö†Ô∏è  Warning: Some classes have 0 samples!")

    # Inverse-frequency weights, rescaled so they sum to the number of classes.
    # The epsilon keeps an empty class from dividing by zero.
    class_weights = 1.0 / (class_counts + 1e-6)
    class_weights = class_weights / class_weights.sum() * len(class_weights)

    print(f"\nüìê Class weights computed:")
    print(f"   Min weight: {class_weights.min():.4f}")
    print(f"   Max weight: {class_weights.max():.4f}")
    print(f"   Weight ratio: {class_weights.max() / class_weights.min():.2f}x")

    # Per-sample weights for WeightedRandomSampler: each training sample gets its class weight.
    if USE_WEIGHTED_SAMPLER:
        sample_weights = [class_weights[label] for label in train_labels]
        print(f"\n‚úÖ Weighted sampler initialized for {len(sample_weights):,} training samples")
    else:
        sample_weights = None
        print("\n‚ö†Ô∏è  Weighted sampler DISABLED")

### Create DataLoaders

Initialize PyTorch DataLoaders with appropriate batch size and sampling strategy.

In [None]:
# Custom wrapper datasets to apply different transforms
class TransformDataset(torch.utils.data.Dataset):
    """View over a ``torch.utils.data.Subset`` that loads images itself and applies its own transform.

    ``random_split`` returns Subsets sharing one underlying dataset, so a transform attached to
    that dataset would apply to every split alike. This wrapper lets the train and validation
    splits carry different transforms.

    Args:
        subset: a Subset whose ``.dataset.samples`` holds ``(image_path, label)`` pairs and whose
            ``.indices`` map subset positions to dataset positions.
        transform: callable applied to the RGB PIL image; if falsy, the PIL image is returned as-is.
    """

    def __init__(self, subset, transform):
        self.subset = subset
        self.transform = transform

    def __len__(self):
        return len(self.subset)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for position ``idx`` of the wrapped subset."""
        # Local import keeps this class independent of the earlier cell that imported PIL.
        from PIL import Image

        source_position = self.subset.indices[idx]
        img_path, label = self.subset.dataset.samples[source_position]
        rgb_image = Image.open(img_path).convert('RGB')
        if not self.transform:
            return rgb_image, label
        return self.transform(rgb_image), label


if NUM_CLASSES is None:
    print("‚ùå Dataset not loaded.")
else:
    # Settings shared by every loader. num_workers=0 keeps the notebook Windows-compatible;
    # pinned host memory is only requested when a CUDA device is available.
    loader_kwargs = dict(
        batch_size=BATCH_SIZE,
        num_workers=0,
        pin_memory=torch.cuda.is_available(),
    )

    # Each split gets its own transform via the wrapper (augmentation for train only).
    train_dataset_wrapped = TransformDataset(train_subset, train_transforms)
    val_dataset_wrapped = TransformDataset(val_subset, val_test_transforms)

    # Training loader: class-balanced sampling (with replacement) or a plain shuffle.
    if USE_WEIGHTED_SAMPLER and sample_weights is not None:
        sampler = WeightedRandomSampler(
            weights=sample_weights,
            num_samples=len(sample_weights),
            replacement=True,
        )
        train_loader = DataLoader(train_dataset_wrapped, sampler=sampler, **loader_kwargs)
        print(f"‚úÖ Train DataLoader: {len(train_subset):,} samples with WeightedRandomSampler")
    else:
        train_loader = DataLoader(train_dataset_wrapped, shuffle=True, **loader_kwargs)
        print(f"‚úÖ Train DataLoader: {len(train_subset):,} samples with shuffle=True")

    val_loader = DataLoader(val_dataset_wrapped, shuffle=False, **loader_kwargs)
    print(f"‚úÖ Validation DataLoader: {len(val_subset):,} samples")

    # Fall back to the validation loader when there is no separate Test/ folder.
    if test_dataset is None:
        test_loader = val_loader
        print("‚ö†Ô∏è  Using validation set as test set")
    else:
        test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)
        print(f"‚úÖ Test DataLoader: {len(test_dataset):,} samples")

    print("\nüì¶ Batch configuration:")
    print(f"   Batch size: {BATCH_SIZE}")
    for _split_name, _split_loader in (
        ("Train", train_loader),
        ("Validation", val_loader),
        ("Test", test_loader),
    ):
        print(f"   {_split_name} batches: {len(_split_loader)}")

### Preprocessing Summary

In [None]:
if NUM_CLASSES is not None:
    # Assemble the whole report first, then emit it with a single print.
    _rule = "=" * 80
    summary_lines = [
        "\n" + _rule,
        "üìä PREPROCESSING SUMMARY",
        _rule,
        f"{'Dataset:':<30} Guava Ripeness (Age Classification)",
        f"{'Total Classes:':<30} {NUM_CLASSES}",
        f"{'Training Samples:':<30} {train_size:,}",
        f"{'Validation Samples:':<30} {val_size:,}",
        f"{'Test Samples:':<30} {num_test:,}",
        f"{'Total Samples:':<30} {train_size + val_size + num_test:,}",
        f"\n{'Image Processing:':<30}",
        f"  {'- Target Size:':<28} {IMG_HEIGHT}x{IMG_WIDTH} pixels",
        f"  {'- Normalization:':<28} {'ImageNet' if USE_IMAGENET_NORM else 'Guava'}",
        f"\n{'Augmentation:':<30} {'ENABLED' if USE_AUGMENTATION else 'DISABLED'}",
    ]
    if USE_AUGMENTATION:
        summary_lines.append("  - Rotation, Flip, Color Jitter, Perspective, Erasing")
    summary_lines.append(f"\n{'Class Balancing:':<30}")
    summary_lines.append(
        f"  {'- Weighted Sampling:':<28} {'ENABLED' if USE_WEIGHTED_SAMPLER else 'DISABLED'}"
    )
    # imbalance_ratio exists only after the class-weights cell has run; it is inf when a class is empty.
    if 'imbalance_ratio' in globals() and imbalance_ratio != float('inf'):
        summary_lines.append(f"  {'- Class Imbalance Ratio:':<28} {imbalance_ratio:.2f}x")
    summary_lines += [
        f"\n{'Batch Configuration:':<30}",
        f"  {'- Batch Size:':<28} {BATCH_SIZE}",
        f"  {'- Train Batches/Epoch:':<28} {len(train_loader)}",
        f"  {'- Val Batches/Epoch:':<28} {len(val_loader)}",
        f"  {'- Test Batches:':<28} {len(test_loader)}",
        _rule,
        "‚úÖ Preprocessing complete! Ready for model training.\n",
    ]
    print("\n".join(summary_lines))

---
## Model Training Configuration

In [None]:
# ============================
# MODEL TRAINING CONFIGURATION
# ============================

LEARNING_RATE = 0.0001
MAX_EPOCHS = 30               # upper bound; early stopping may end training sooner
WEIGHT_DECAY = 1e-4           # weight decay passed to the AdamW optimizer
DROPOUT_RATE = 0.4            # dropout in front of the final Linear layer of the classifier head
EARLY_STOPPING_PATIENCE = 10  # epochs without a validation-accuracy improvement before stopping
MAX_GRAD_NORM = 1.0           # gradient-norm clipping threshold applied before each optimizer step

# Per-epoch metrics appended by the training loop (module-level so later cells can plot them).
# NOTE(review): the top-5 accuracy lists are declared but never filled in this notebook.
history = {
    metric: []
    for metric in (
        "train_loss",
        "val_loss",
        "train_acc",
        "val_acc",
        "train_top5_acc",
        "val_top5_acc",
        "learning_rates",
    )
}

print("üéØ Training Configuration:")
for _setting, _value in (
    ("Model", MODEL_ARCH.upper()),
    ("Learning Rate", LEARNING_RATE),
    ("Weight Decay", WEIGHT_DECAY),
    ("Dropout Rate", DROPOUT_RATE),
    ("Max Epochs", MAX_EPOCHS),
    ("Early Stopping Patience", EARLY_STOPPING_PATIENCE),
    ("Classes", NUM_CLASSES),
):
    print(f"   {_setting}: {_value}")

### Load Pretrained Model Architecture

In [None]:
# Build the ImageNet-pretrained backbone selected by MODEL_ARCH, replace its classification head
# with Dropout + Linear sized for NUM_CLASSES, and move it to the compute device.
CheckCUDA()
device = get_device()
print(f"\nüñ•Ô∏è  Using device: {device}")

if NUM_CLASSES is not None:
    if MODEL_ARCH == 'resnet50':
        model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
        print("\n‚úÖ Pre-trained ResNet50 loaded (weights: IMAGENET1K_V1)")
        # ResNet's head is the single Linear layer `fc`.
        in_features = model.fc.in_features
        model.fc = nn.Sequential(nn.Dropout(DROPOUT_RATE), nn.Linear(in_features, NUM_CLASSES))
    elif MODEL_ARCH == 'efficientnet_b3':
        model = models.efficientnet_b3(weights=models.EfficientNet_B3_Weights.IMAGENET1K_V1)
        print("‚úÖ Pre-trained EfficientNet-B3 loaded (weights: IMAGENET1K_V1)")
        # EfficientNet's head is `classifier`; its element 1 is the Linear layer being replaced.
        in_features = model.classifier[1].in_features
        model.classifier = nn.Sequential(
            nn.Dropout(DROPOUT_RATE, inplace=True),
            nn.Linear(in_features, NUM_CLASSES),
        )
    else:
        raise ValueError(f"Unknown model architecture: {MODEL_ARCH}")
    print(f"‚úÖ Classifier replaced: {in_features} ‚Üí Dropout({DROPOUT_RATE}) ‚Üí {NUM_CLASSES} classes")

    model = model.to(device)

    # This cell freezes nothing, so the trainable count is expected to equal the total.
    param_sizes = [(p.numel(), p.requires_grad) for p in model.parameters()]
    total_params = sum(size for size, _ in param_sizes)
    trainable_params = sum(size for size, trainable in param_sizes if trainable)

    print(f"\nüìä Model Statistics:")
    print(f"   Total parameters: {total_params:,}")
    print(f"   Trainable parameters: {trainable_params:,}")
    print(f"   Non-trainable parameters: {total_params - trainable_params:,}")
else:
    print("‚ùå NUM_CLASSES not set. Please run dataset analysis first.")

### Define Loss Function and Optimizer

In [None]:
if NUM_CLASSES is not None:
    # Loss. With the weighted sampler enabled, the loss is also weighted by the inverse-frequency
    # class weights computed earlier.
    # NOTE(review): the sampler already rebalances the classes seen per epoch. Weighting the loss
    # as well corrects the imbalance twice and shifts emphasis further toward minority classes;
    # consider using only one of the two mechanisms.
    if not USE_WEIGHTED_SAMPLER:
        criterion = nn.CrossEntropyLoss()
        print("‚úÖ CrossEntropyLoss (unweighted)")
    else:
        class_weights_tensor = torch.tensor(class_weights, dtype=torch.float32).to(device)
        criterion = nn.CrossEntropyLoss(weight=class_weights_tensor)
        print("‚úÖ CrossEntropyLoss with class weights")

    # AdamW over all model parameters as a single parameter group.
    optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
    print(f"‚úÖ AdamW optimizer (lr={LEARNING_RATE}, weight_decay={WEIGHT_DECAY})")

    # Halve the LR after 3 epochs without improvement in the monitored (minimised) metric;
    # the training loop passes validation loss to scheduler.step().
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)
    print("‚úÖ ReduceLROnPlateau scheduler (factor=0.5, patience=3)")

### Training Functions

In [None]:
def train_epoch(model, loader, criterion, optimizer, device, max_grad_norm=None):
    """Run one optimisation pass over ``loader``.

    Args:
        model: network to train; it is switched to train mode.
        loader: iterable of ``(images, labels)`` batches.
        criterion: loss callable ``criterion(outputs, labels)``.
        optimizer: optimizer that updates ``model``'s parameters.
        device: device each batch is moved to.
        max_grad_norm: gradient-norm clipping threshold. ``None`` (default) uses the notebook-level
            ``MAX_GRAD_NORM``, read at call time so edits to the config cell take effect.

    Returns:
        ``(epoch_loss, epoch_acc)``: the per-batch loss weighted by batch size and averaged over all
        samples, and the top-1 accuracy in percent.

    Raises:
        ValueError: if ``loader`` yields no samples (instead of an opaque ZeroDivisionError).
    """
    if max_grad_norm is None:
        max_grad_norm = MAX_GRAD_NORM
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    pbar = tqdm(loader, desc="Training", leave=False)
    for images, labels in pbar:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()

        # Clip before stepping to guard against exploding gradients.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

        optimizer.step()

        # Assumes criterion averages over the batch (default reduction='mean'); re-weighting by
        # batch size makes the epoch loss a per-sample mean even if the last batch is smaller.
        batch_size = labels.size(0)
        running_loss += loss.item() * batch_size
        predicted = outputs.argmax(dim=1)
        total += batch_size
        correct += predicted.eq(labels).sum().item()

        pbar.set_postfix({'loss': f'{loss.item():.4f}', 'acc': f'{100.*correct/total:.2f}%'})

    if total == 0:
        raise ValueError("train_epoch: loader yielded no samples (is the training split empty?)")

    epoch_loss = running_loss / total
    epoch_acc = 100. * correct / total

    return epoch_loss, epoch_acc


def validate_epoch(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without updating it.

    Args:
        model: network to evaluate; it is switched to eval mode and run under ``torch.no_grad()``.
        loader: iterable of ``(images, labels)`` batches.
        criterion: loss callable ``criterion(outputs, labels)``.
        device: device each batch is moved to.

    Returns:
        ``(epoch_loss, epoch_acc)``: the per-batch loss weighted by batch size and averaged over all
        samples, and the top-1 accuracy in percent.

    Raises:
        ValueError: if ``loader`` yields no samples (instead of an opaque ZeroDivisionError).
    """
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        pbar = tqdm(loader, desc="Validation", leave=False)
        for images, labels in pbar:
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            # Assumes criterion averages over the batch (default reduction='mean').
            batch_size = labels.size(0)
            running_loss += loss.item() * batch_size
            predicted = outputs.argmax(dim=1)
            total += batch_size
            correct += predicted.eq(labels).sum().item()

            pbar.set_postfix({'loss': f'{loss.item():.4f}', 'acc': f'{100.*correct/total:.2f}%'})

    if total == 0:
        raise ValueError("validate_epoch: loader yielded no samples (is the validation split empty?)")

    epoch_loss = running_loss / total
    epoch_acc = 100. * correct / total

    return epoch_loss, epoch_acc


print("‚úÖ Training functions defined!")

### Training Loop

In [None]:
# Main training loop. Each epoch runs one training pass and one validation pass, steps the LR
# scheduler on validation loss, checkpoints on a validation-accuracy improvement, and stops early
# after EARLY_STOPPING_PATIENCE epochs without improvement.
# Uses model, train_loader, val_loader, criterion, optimizer, scheduler and device from earlier cells.
# NOTE(review): `history` is created in the training-configuration cell and is not reset here, and
# model/optimizer are not re-initialised, so re-running this cell continues training and appends
# to the existing curves.
if NUM_CLASSES is None:
    print("‚ùå Cannot train. Dataset not loaded.")
else:
    print("\n" + "="*80)
    print("üöÄ STARTING TRAINING")
    print("="*80)
    
    # Best-so-far tracking used for checkpointing and early stopping.
    best_val_acc = 0.0
    best_epoch = 0
    patience_counter = 0
    
    for epoch in range(MAX_EPOCHS):
        print(f"\nüìÖ Epoch {epoch+1}/{MAX_EPOCHS}")
        print("-" * 40)
        
        # Train
        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
        
        # Validate
        val_loss, val_acc = validate_epoch(model, val_loader, criterion, device)
        
        # Update scheduler
        # ReduceLROnPlateau monitors validation loss, while checkpointing and early stopping
        # below key off validation accuracy.
        scheduler.step(val_loss)
        # The optimizer is built from model.parameters() as a single group, so group 0 holds the LR.
        current_lr = optimizer.param_groups[0]['lr']
        
        # Record history
        # (the train_top5_acc / val_top5_acc keys of `history` are not filled here)
        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        history['train_acc'].append(train_acc)
        history['val_acc'].append(val_acc)
        history['learning_rates'].append(current_lr)
        
        # Print epoch summary
        print(f"   Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%")
        print(f"   Val Loss:   {val_loss:.4f} | Val Acc:   {val_acc:.2f}%")
        print(f"   LR: {current_lr:.6f}")
        
        # Save best model
        # Strictly greater: matching the best accuracy counts as "no improvement".
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_epoch = epoch + 1
            patience_counter = 0
            
            # Save model
            # A new file is written on every improvement; earlier checkpoints are not deleted.
            # The evaluation cell parses the VAL<acc> part of this filename to rank checkpoints.
            model_name = f"Guava_{MODEL_ARCH}_E{epoch+1}_VAL{val_acc:.2f}.pth"
            model_path = os.path.join(MODEL_SAVE_DIR, model_name)
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_acc': val_acc,
                'val_loss': val_loss,
                'class_names': CLASS_NAMES,
                'num_classes': NUM_CLASSES,
                'model_arch': MODEL_ARCH
            }, model_path)
            print(f"   üíæ Best model saved: {model_name}")
        else:
            patience_counter += 1
            print(f"   ‚è≥ No improvement ({patience_counter}/{EARLY_STOPPING_PATIENCE})")
        
        # Early stopping
        if patience_counter >= EARLY_STOPPING_PATIENCE:
            print(f"\nüõë Early stopping triggered at epoch {epoch+1}")
            break
    
    print("\n" + "="*80)
    print("‚úÖ TRAINING COMPLETE")
    print(f"   Best Validation Accuracy: {best_val_acc:.2f}% (Epoch {best_epoch})")
    print("="*80)

### Training Visualization

In [None]:
if len(history['train_loss']) > 0:
    # 1-based epoch numbers, so the x-axis matches the "Epoch N/..." log lines and the E<N> tag
    # in checkpoint filenames.
    epochs = range(1, len(history['train_loss']) + 1)
    fig, axes = plt.subplots(1, 3, figsize=(18, 5))

    # The loss and accuracy panels share the same train-vs-validation layout.
    for ax, metric, short_name, ylabel, title in (
        (axes[0], 'loss', 'Loss', 'Loss', 'Training and Validation Loss'),
        (axes[1], 'acc', 'Acc', 'Accuracy (%)', 'Training and Validation Accuracy'),
    ):
        ax.plot(epochs, history[f'train_{metric}'], label=f'Train {short_name}', marker='o')
        ax.plot(epochs, history[f'val_{metric}'], label=f'Val {short_name}', marker='s')
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.legend()
        ax.grid(True)

    # Learning rate on a log scale, since ReduceLROnPlateau changes it multiplicatively.
    axes[2].plot(epochs, history['learning_rates'], marker='o', color='green')
    axes[2].set_xlabel('Epoch')
    axes[2].set_ylabel('Learning Rate')
    axes[2].set_title('Learning Rate Schedule')
    axes[2].set_yscale('log')
    axes[2].grid(True)

    fig.tight_layout()
    curves_path = os.path.join(MODEL_SAVE_DIR, 'training_curves.png')
    fig.savefig(curves_path, dpi=150)
    plt.show()

    # Report the path actually written (MODEL_SAVE_DIR), not a hard-coded one.
    print(f"‚úÖ Training curves saved to {curves_path}")
else:
    print("‚ö†Ô∏è No training history to visualize.")

### Model Evaluation on Test Set

In [None]:
if NUM_CLASSES is not None and len(history['train_loss']) > 0:
    print("\n" + "="*80)
    print("üìä MODEL EVALUATION ON TEST SET")
    print("="*80)

    def _val_acc_from_name(fname):
        """Return the validation accuracy encoded as ``..._VAL<acc>.pth`` (-inf if unparsable)."""
        try:
            return float(fname.rsplit('VAL', 1)[-1][:-len('.pth')])
        except ValueError:
            return float('-inf')

    # Restore the best checkpoint before evaluating. Only files written for the current
    # MODEL_ARCH are candidates, because another architecture's state dict cannot load into `model`.
    # They are tried best-first by the accuracy in the filename. A checkpoint is used only if its
    # stored architecture and class list match this run; a different class list means a
    # differently sized head.
    arch_prefix = f"Guava_{MODEL_ARCH}_"
    best_model_files = sorted(
        (f for f in os.listdir(MODEL_SAVE_DIR) if f.startswith(arch_prefix) and f.endswith('.pth')),
        key=_val_acc_from_name,
        reverse=True,
    )
    for candidate in best_model_files:
        candidate_path = os.path.join(MODEL_SAVE_DIR, candidate)
        checkpoint = torch.load(candidate_path, map_location=device)
        if checkpoint.get('model_arch') != MODEL_ARCH or checkpoint.get('class_names') != CLASS_NAMES:
            continue
        model.load_state_dict(checkpoint['model_state_dict'])
        best_model_path = candidate_path
        print(f"‚úÖ Loaded best model: {candidate}")
        break
    else:
        print("WARNING: no compatible checkpoint found; evaluating the current in-memory weights.")

    # Collect predictions over the whole test loader.
    model.eval()
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for images, labels in tqdm(test_loader, desc="Testing"):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            predicted = outputs.argmax(dim=1)

            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    all_preds = np.array(all_preds)
    all_labels = np.array(all_labels)

    # Overall and class-balanced accuracy, in percent.
    test_acc = 100. * (all_preds == all_labels).sum() / len(all_labels)
    balanced_acc = 100. * balanced_accuracy_score(all_labels, all_preds)

    print(f"\nüìà Test Results:")
    print(f"   Overall Accuracy: {test_acc:.2f}%")
    print(f"   Balanced Accuracy: {balanced_acc:.2f}%")

    # Pin the label set to every class index, so the report and the matrix keep one row per class
    # even when a class is absent from the test set. Otherwise target_names / display_labels
    # no longer line up and scikit-learn raises.
    class_ids = list(range(NUM_CLASSES))

    print(f"\nüìã Classification Report:")
    print(classification_report(all_labels, all_preds, labels=class_ids,
                                target_names=CLASS_NAMES, zero_division=0))

    cm = confusion_matrix(all_labels, all_preds, labels=class_ids)
    # Draw into an explicitly sized Axes. ConfusionMatrixDisplay.plot() without `ax` opens its
    # own figure, which would leave a blank 10x8 figure behind.
    fig, ax = plt.subplots(figsize=(10, 8))
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=CLASS_NAMES)
    disp.plot(ax=ax, cmap='Blues', values_format='d', xticks_rotation=45)
    plt.setp(ax.get_xticklabels(), ha='right')
    ax.set_title('Confusion Matrix - Test Set')
    fig.tight_layout()
    cm_path = os.path.join(MODEL_SAVE_DIR, 'confusion_matrix.png')
    fig.savefig(cm_path, dpi=150)
    plt.show()

    # Report the path actually written (MODEL_SAVE_DIR), not a hard-coded one.
    print(f"\n‚úÖ Confusion matrix saved to {cm_path}")
    print("="*80)
else:
    print("‚ö†Ô∏è No model to evaluate. Please train the model first.")

---
## Export Model for Inference

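Save the final model in a format suitable for inference (e.g., for the Vue frontend). The export contains whatever weights `model` currently holds (the best checkpoint if the evaluation cell above reloaded it), together with the class names, input size, and normalization statistics needed to reproduce preprocessing.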

In [None]:
if NUM_CLASSES is not None:
    # NOTE(review): this saves whatever weights `model` currently holds. They are the best
    # checkpoint only if the evaluation cell ran and reloaded it; otherwise they are the weights
    # left by the last training epoch.
    inference_model_path = os.path.join(MODEL_SAVE_DIR, f"guava_classifier_{MODEL_ARCH}_final.pth")

    # Weights plus the metadata needed to rebuild the network and reproduce preprocessing.
    inference_bundle = dict(
        model_state_dict=model.state_dict(),
        class_names=CLASS_NAMES,
        num_classes=NUM_CLASSES,
        model_arch=MODEL_ARCH,
        img_size=(IMG_HEIGHT, IMG_WIDTH),
        normalize_mean=NORMALIZE_MEAN,
        normalize_std=NORMALIZE_STD,
    )
    torch.save(inference_bundle, inference_model_path)

    print(f"‚úÖ Inference model saved: {inference_model_path}")
    print("\nüì¶ Model includes:")
    for _included in (
        "   - Model weights",
        f"   - Class names: {CLASS_NAMES}",
        f"   - Image size: {IMG_HEIGHT}x{IMG_WIDTH}",
        "   - Normalization values",
    ):
        print(_included)