# 🌱 FasalVaidya: YOLOv8 Multi-Crop Nutrient Deficiency Classification

## 📋 Overview

This notebook trains a **YOLOv8 Classification** model to identify nutrient deficiencies across **9 crops**, using **43 classes** in total (deficiency classes plus healthy/control classes).

### 🎯 Why YOLOv8 for Classification?

- ⚡ **Fast inference** (typically <10ms on GPU, <50ms on mobile; actual latency depends on hardware and export format)
- 🎯 **Strong accuracy** for image classification
- 📱 **Easy export** to ONNX, TensorFlow Lite, CoreML, TensorRT
- 🔧 **Simple API** with the Ultralytics framework
- 📊 **Built-in augmentation** and training optimization

### 🌾 Supported Crops (9 total, 43 classes)

| Category | Crops | Classes |
|----------|-------|---------|
| **Cereals** | Rice, Wheat, Maize | 11 |
| **Commercial** | Banana, Coffee | 7 |
| **Vegetables** | Ashgourd, EggPlant, Snakegourd, Bittergourd | 25 |

### 📊 YOLOv8 Classification Models

| Model | Size | Accuracy | Speed (CPU) | Speed (GPU) |
|-------|------|----------|-------------|-------------|
| YOLOv8n-cls | 5.3MB | Good | 12ms | 0.6ms |
| YOLOv8s-cls | 11.4MB | Better | 23ms | 0.9ms |
| YOLOv8m-cls | 36.6MB | Best | 85ms | 2.0ms |

We'll use **YOLOv8s-cls** for the best balance of accuracy and speed.

---

## 🔧 Section 1: Environment Setup

In [None]:
# ==========================================
# 📦 Install Ultralytics YOLOv8 and dependencies
# ==========================================

# Quote the requirement so the shell does not treat ">=" as an output redirect
!pip install "ultralytics>=8.2.0" --quiet
!pip install opencv-python-headless pillow matplotlib seaborn tqdm pandas scikit-learn --quiet

# Verify installation
import ultralytics
print(f"✅ Ultralytics version: {ultralytics.__version__}")

from ultralytics import YOLO
print("✅ YOLOv8 imported successfully!")

In [None]:
# ==========================================
# 🖥️ Check GPU availability
# ==========================================

import torch
import os
import shutil
from pathlib import Path

# GPU Check
print("=" * 50)
print("🖥️ Hardware Check")
print("=" * 50)

if torch.cuda.is_available():
    gpu_name = torch.cuda.get_device_name(0)
    gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"✅ GPU: {gpu_name}")
    print(f"✅ GPU Memory: {gpu_memory:.1f} GB")
    print(f"✅ CUDA Version: {torch.version.cuda}")
    DEVICE = 'cuda'
else:
    print("⚠️ No GPU detected - training will be slow!")
    DEVICE = 'cpu'

print(f"✅ PyTorch Version: {torch.__version__}")
print(f"✅ Device: {DEVICE}")

## 📂 Section 2: Mount Google Drive & Configure Paths

In [None]:
# ==========================================
# 📁 Mount Google Drive
# ==========================================

from google.colab import drive
drive.mount('/content/drive')

print("✅ Google Drive mounted!")

# Verify mount
!ls /content/drive/MyDrive/ | head -5

In [None]:
# ==========================================
# 🌾 Define Crop Dataset Paths (Same as EfficientNet notebook)
# ==========================================

# Base paths in Google Drive
GDRIVE_BASE = "/content/drive/MyDrive"

# Dataset configurations - paths relative to GDRIVE_BASE
CROP_DATASETS = {
    'rice': 'PlantNutrientDeficiency_MobileCapture/Rice',
    'wheat': 'PlantNutrientDeficiency_MobileCapture/Wheat',
    'maize': 'maize_nutrient_deficiency',
    'banana': 'Banana_Nutrient_Deficiency_Dataset',
    'coffee': 'Coffee_Nutrient_Deficiency_dataset',
    'ashgourd': 'VegetableLeaves/Ashgourd',
    'eggplant': 'VegetableLeaves/EggPlant',
    'snakegourd': 'VegetableLeaves/Snakegourd',
    'bittergourd': 'VegetableLeaves/Bittergourd',
}

# Class renaming map - standardize class names across all datasets
CLASS_RENAME_MAP = {
    'rice': {
        'Nitrogen(N)': 'rice_nitrogen',
        'Phosphorus(P)': 'rice_phosphorus',
        'Potassium(K)': 'rice_potassium',
    },
    'wheat': {
        'control': 'wheat_control',
        'deficiency': 'wheat_deficiency',
    },
    'maize': {
        'ALL Present': 'maize_all_present',
        'ALLAB': 'maize_allab',
        'KAB': 'maize_kab',
        'NAB': 'maize_nab',
        'PAB': 'maize_pab',
        'ZNAB': 'maize_znab',
    },
    'banana': {
        'healthy': 'banana_healthy',
        'magnesium': 'banana_magnesium',
        'potassium': 'banana_potassium',
    },
    'coffee': {
        'healthy': 'coffee_healthy',
        'nitrogen-N': 'coffee_nitrogen_n',
        'phosphorus-P': 'coffee_phosphorus_p',
        'potasium-K': 'coffee_potassium_k',
    },
    'ashgourd': {
        'Deficiency of Boron': 'ashgourd_boron_deficiency',
        'Deficiency of Iron': 'ashgourd_iron_deficiency',
        'Deficiency of Manganese': 'ashgourd_manganese_deficiency',
        'Deficiency of Molybdenum': 'ashgourd_molybdenum_deficiency',
        'Deficiency of Nitrogen': 'ashgourd_nitrogen_deficiency',
        'Deficiency of Potassium': 'ashgourd_potassium_deficiency',
        'Healthy': 'ashgourd_healthy',
    },
    'eggplant': {
        'Deficiency of Magnesium': 'eggplant_magnesium_deficiency',
        'Deficiency of Nitrogen': 'eggplant_nitrogen_deficiency',
        'Deficiency of Potassium': 'eggplant_potassium_deficiency',
        'Healthy': 'eggplant_healthy',
    },
    'snakegourd': {
        'Deficiency of Copper': 'snakegourd_copper_deficiency',
        'Deficiency of Molybdenum': 'snakegourd_molybdenum_deficiency',
        'Deficiency of Nitrogen': 'snakegourd_nitrogen_deficiency',
        'Deficiency of Potassium': 'snakegourd_potassium_deficiency',
        'Healthy': 'snakegourd_healthy',
    },
    'bittergourd': {
        'Deficiency of Boron': 'bittergourd_boron_deficiency',
        'Deficiency of Calcium': 'bittergourd_calcium_deficiency',
        'Deficiency of Copper': 'bittergourd_copper_deficiency',
        'Deficiency of Iron': 'bittergourd_iron_deficiency',
        'Deficiency of Manganese': 'bittergourd_manganese_deficiency',
        'Deficiency of Nitrogen': 'bittergourd_nitrogen_deficiency',
        'Deficiency of Potassium': 'bittergourd_potassium_deficiency',
        'Deficiency of Sulfur': 'bittergourd_sulfur_deficiency',
        'Healthy': 'bittergourd_healthy',
    }
}

# Training configuration
CONFIG = {
    'img_size': 224,           # YOLOv8 classification default
    'batch_size': 32,          # Adjust based on GPU memory
    'epochs': 50,              # Training epochs
    'min_samples': 150,        # Minimum images per class
    'max_samples': 400,        # Maximum images per class
    'train_split': 0.8,        # 80% train, 20% val
    'model_variant': 'yolov8s-cls',  # Small model for balance of speed/accuracy
    'patience': 10,            # Early stopping patience
}

# Output paths
YOLO_DATASET_DIR = Path('/content/yolo_dataset')  # YOLO format dataset
OUTPUT_DIR = Path('/content/yolo_output')          # Training outputs
FINAL_MODEL_DIR = Path(f'{GDRIVE_BASE}/FasalVaidya_YOLOv8_Model')

print("✅ Configuration loaded!")
print(f"📊 Total crops: {len(CROP_DATASETS)}")
total_classes = sum(len(v) for v in CLASS_RENAME_MAP.values())
print(f"📊 Total classes: {total_classes}")

## 🔍 Section 3: Dataset Discovery & Analysis

In [None]:
# ==========================================
# 🔍 Discover and analyze all datasets
# ==========================================

from collections import defaultdict

def discover_datasets():
    """Discover all available datasets and their class distributions."""
    
    dataset_info = {}
    all_classes = []
    
    print("=" * 70)
    print("🔍 DATASET DISCOVERY")
    print("=" * 70)

    for crop, rel_path in CROP_DATASETS.items():
        full_path = Path(GDRIVE_BASE) / rel_path

        if not full_path.exists():
            print(f"⚠️  {crop.upper()}: Path not found - {full_path}")
            continue

        # Find all subdirectories (classes)
        class_dirs = [d for d in full_path.iterdir() if d.is_dir()]

        if not class_dirs:
            print(f"⚠️  {crop.upper()}: No class subdirectories found")
            continue

        crop_info = {
            'path': full_path,
            'classes': {},
            'total_images': 0
        }

        print(f"\n🌱 {crop.upper()} ({full_path})")
        print("-" * 50)
        
        rename_map = CLASS_RENAME_MAP.get(crop, {})
        
        for class_dir in sorted(class_dirs):
            original_name = class_dir.name
            
            # Get standardized name or create one
            if original_name in rename_map:
                std_name = rename_map[original_name]
            else:
                # Auto-generate standardized name
                std_name = f"{crop}_{original_name.lower().replace(' ', '_').replace('-', '_')}"
            
            # Count images
            img_count = len(list(class_dir.glob('*.jpg'))) + \
                        len(list(class_dir.glob('*.jpeg'))) + \
                        len(list(class_dir.glob('*.png'))) + \
                        len(list(class_dir.glob('*.JPG'))) + \
                        len(list(class_dir.glob('*.JPEG'))) + \
                        len(list(class_dir.glob('*.PNG')))
            
            crop_info['classes'][original_name] = {
                'std_name': std_name,
                'count': img_count,
                'path': class_dir
            }
            crop_info['total_images'] += img_count
            all_classes.append(std_name)
            
            status = "✅" if img_count >= CONFIG['min_samples'] else "⚠️"
            print(f"  {status} {original_name:35s} → {std_name:40s} | {img_count:4d} images")

        dataset_info[crop] = crop_info
        print(f"  📊 Subtotal: {crop_info['total_images']} images across {len(crop_info['classes'])} classes")

    print("\n" + "=" * 70)
    print("📊 SUMMARY")
    print("=" * 70)
    total = sum(d['total_images'] for d in dataset_info.values())
    print(f"✅ Found {len(dataset_info)} crops")
    print(f"✅ Total classes: {len(all_classes)}")
    print(f"✅ Total images: {total:,}")
    
    return dataset_info, sorted(set(all_classes))

dataset_info, all_class_names = discover_datasets()
print(f"\n📋 All classes ({len(all_class_names)}):")
for i, name in enumerate(all_class_names, 1):
    print(f"  {i:2d}. {name}")
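
The fallback naming rule used during discovery can be tried on its own. The sketch below is a minimal standalone version of that rule; `standardize_class_name` is a hypothetical helper name that does not exist in the notebook, and `sample_map` is only a one-entry excerpt of `CLASS_RENAME_MAP` for illustration.

```python
def standardize_class_name(crop, original_name, rename_map):
    """Return the mapped class name, or build one from the folder name.

    Hypothetical helper mirroring the fallback rule in discover_datasets().
    """
    if original_name in rename_map:
        return rename_map[original_name]
    # Fallback: lowercase, then turn spaces and hyphens into underscores
    cleaned = original_name.lower().replace(' ', '_').replace('-', '_')
    return f"{crop}_{cleaned}"


sample_map = {'Nitrogen(N)': 'rice_nitrogen'}  # small excerpt for illustration
print(standardize_class_name('rice', 'Nitrogen(N)', sample_map))
print(standardize_class_name('rice', 'Zinc-Deficiency Leaf', sample_map))
```

Note that the fallback leaves other punctuation untouched, so an unmapped folder such as `Iron(Fe)` would keep its parentheses in the generated class name.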

## 📊 Section 4: Visualize Dataset Distribution

In [None]:
# ==========================================
# 📊 Visualize class distribution
# ==========================================

import matplotlib.pyplot as plt
import seaborn as sns

def visualize_distribution(dataset_info):
    """Create visualizations of dataset distribution."""
    
    # Collect all class data
    class_data = []
    for crop, info in dataset_info.items():
        for orig_name, class_info in info['classes'].items():
            class_data.append({
                'crop': crop,
                'class': class_info['std_name'],
                'count': class_info['count']
            })
    
    import pandas as pd
    df = pd.DataFrame(class_data)
    
    # Figure 1: Class distribution
    fig, axes = plt.subplots(2, 2, figsize=(16, 12))
    
    # 1. Bar chart of all classes
    ax1 = axes[0, 0]
    bars = ax1.barh(df['class'], df['count'], color=[plt.cm.tab20(i % 20) for i in range(len(df))])
    ax1.axvline(x=CONFIG['min_samples'], color='r', linestyle='--', label=f"Min: {CONFIG['min_samples']}")
    ax1.axvline(x=CONFIG['max_samples'], color='g', linestyle='--', label=f"Max: {CONFIG['max_samples']}")
    ax1.set_xlabel('Number of Images')
    ax1.set_ylabel('Class')
    ax1.set_title('📊 Image Count per Class')
    ax1.legend()
    ax1.tick_params(axis='y', labelsize=7)
    
    # 2. Per-crop totals
    ax2 = axes[0, 1]
    crop_totals = df.groupby('crop')['count'].sum().sort_values(ascending=True)
    ax2.barh(crop_totals.index, crop_totals.values, color=plt.cm.Pastel1(range(len(crop_totals))))
    ax2.set_xlabel('Number of Images')
    ax2.set_title('🌾 Images per Crop')
    for i, v in enumerate(crop_totals.values):
        ax2.text(v + 10, i, str(v), va='center')
    
    # 3. Classes per crop
    ax3 = axes[1, 0]
    crop_classes = df.groupby('crop').size().sort_values(ascending=True)
    ax3.barh(crop_classes.index, crop_classes.values, color=plt.cm.Pastel2(range(len(crop_classes))))
    ax3.set_xlabel('Number of Classes')
    ax3.set_title('📋 Classes per Crop')
    for i, v in enumerate(crop_classes.values):
        ax3.text(v + 0.1, i, str(v), va='center')
    
    # 4. Distribution histogram
    ax4 = axes[1, 1]
    ax4.hist(df['count'], bins=20, edgecolor='black', alpha=0.7)
    ax4.axvline(x=CONFIG['min_samples'], color='r', linestyle='--', label=f"Min threshold: {CONFIG['min_samples']}")
    ax4.axvline(x=CONFIG['max_samples'], color='g', linestyle='--', label=f"Max threshold: {CONFIG['max_samples']}")
    ax4.axvline(x=df['count'].mean(), color='blue', linestyle='-', label=f"Mean: {df['count'].mean():.0f}")
    ax4.set_xlabel('Number of Images')
    ax4.set_ylabel('Frequency')
    ax4.set_title('📈 Distribution of Class Sizes')
    ax4.legend()
    
    plt.tight_layout()
    plt.savefig('/content/dataset_distribution.png', dpi=150, bbox_inches='tight')
    plt.show()
    
    # Print summary stats
    print("\n📊 Dataset Statistics:")
    print(f"  • Total classes: {len(df)}")
    print(f"  • Total images: {df['count'].sum():,}")
    print(f"  • Mean images/class: {df['count'].mean():.1f}")
    print(f"  • Min images/class: {df['count'].min()}")
    print(f"  • Max images/class: {df['count'].max()}")

    under_min = len(df[df['count'] < CONFIG['min_samples']])
    over_max = len(df[df['count'] > CONFIG['max_samples']])
    print(f"\n⚠️ Classes below minimum ({CONFIG['min_samples']}): {under_min}")
    print(f"📈 Classes above maximum ({CONFIG['max_samples']}): {over_max}")
    
    return df

class_df = visualize_distribution(dataset_info)

## ⚖️ Section 5: Create Balanced YOLO Dataset

YOLOv8 classification expects the following directory structure:
```
dataset/
├── train/
│   ├── class1/
│   │   ├── img1.jpg
│   │   └── img2.jpg
│   └── class2/
│       └── ...
└── val/
    ├── class1/
    └── class2/
```
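
Before training, it can help to confirm that a dataset folder actually follows this layout. The following is a minimal sketch, not part of the notebook: `check_cls_layout` is a hypothetical helper that only checks that both splits exist and contain the same class folders, and the demo runs on a temporary directory rather than the real dataset.

```python
from pathlib import Path
import tempfile


def check_cls_layout(root):
    """Return (ok, problems) for a train/val class-folder layout.

    Hypothetical helper: confirms both splits exist and share the same class folders.
    """
    root = Path(root)
    problems = []
    split_classes = {}
    for split in ('train', 'val'):
        split_dir = root / split
        if not split_dir.is_dir():
            problems.append(f"missing split directory: {split}")
            continue
        split_classes[split] = {d.name for d in split_dir.iterdir() if d.is_dir()}
    if len(split_classes) == 2:
        only_train = split_classes['train'] - split_classes['val']
        only_val = split_classes['val'] - split_classes['train']
        if only_train:
            problems.append(f"classes only in train: {sorted(only_train)}")
        if only_val:
            problems.append(f"classes only in val: {sorted(only_val)}")
    return (not problems), problems


# Demo on a throwaway directory tree
with tempfile.TemporaryDirectory() as tmp:
    for split in ('train', 'val'):
        for cls in ('class1', 'class2'):
            (Path(tmp) / split / cls).mkdir(parents=True)
    demo_ok, demo_problems = check_cls_layout(tmp)
    print(demo_ok, demo_problems)
```

In this notebook the same check could be pointed at `YOLO_DATASET_DIR` once the dataset has been built.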

In [None]:
# ==========================================
# ⚖️ Create balanced YOLO-format dataset
# ==========================================

import random
from PIL import Image
from tqdm.auto import tqdm

def create_yolo_dataset(dataset_info, all_classes, output_dir, 
                        min_samples=150, max_samples=400, train_split=0.8):
    """
    Create a balanced YOLO classification dataset.
    
    - Undersample classes with > max_samples
    - Augment classes with < min_samples
    - Split into train/val
    """
    
    output_dir = Path(output_dir)
    train_dir = output_dir / 'train'
    val_dir = output_dir / 'val'
    
    # Clean previous dataset
    if output_dir.exists():
        shutil.rmtree(output_dir)
    
    # Create directories
    for class_name in all_classes:
        (train_dir / class_name).mkdir(parents=True, exist_ok=True)
        (val_dir / class_name).mkdir(parents=True, exist_ok=True)
    
    print("=" * 70)
    print("⚖️ CREATING BALANCED YOLO DATASET")
    print("=" * 70)
    print(f"📁 Output: {output_dir}")
    print(f"📊 Min samples/class: {min_samples}")
    print(f"📊 Max samples/class: {max_samples}")
    print(f"📊 Train/Val split: {train_split:.0%}/{1-train_split:.0%}")
    print()
    
    stats = {'train': defaultdict(int), 'val': defaultdict(int), 'augmented': 0}
    
    for crop, info in tqdm(dataset_info.items(), desc="Processing crops"):
        for orig_name, class_info in info['classes'].items():
            std_name = class_info['std_name']
            src_path = class_info['path']
            
            # Get all image files
            images = list(src_path.glob('*.jpg')) + \
                     list(src_path.glob('*.jpeg')) + \
                     list(src_path.glob('*.png')) + \
                     list(src_path.glob('*.JPG')) + \
                     list(src_path.glob('*.JPEG')) + \
                     list(src_path.glob('*.PNG'))
            
            if not images:
                continue
            
            # Shuffle images; sort first and use a fixed seed so the split is reproducible
            images.sort()
            random.Random(42).shuffle(images)
            
            # Balance: undersample or prepare for augmentation
            count = len(images)
            
            if count > max_samples:
                # Undersample
                images = images[:max_samples]
            
            # Split into train/val
            split_idx = int(len(images) * train_split)
            train_images = images[:split_idx]
            val_images = images[split_idx:]
            
            # Copy validation images (no augmentation)
            for img_path in val_images:
                dst = val_dir / std_name / f"{std_name}_{img_path.name}"
                shutil.copy2(img_path, dst)
                stats['val'][std_name] += 1
            
            # Process training images
            train_target = int(min_samples * train_split) if count < min_samples else len(train_images)
            
            for i, img_path in enumerate(train_images):
                dst = train_dir / std_name / f"{std_name}_{img_path.name}"
                shutil.copy2(img_path, dst)
                stats['train'][std_name] += 1
            
            # Augment if needed
            if count < min_samples:
                needed = train_target - len(train_images)
                if needed > 0:
                    augment_images(train_images, train_dir / std_name, std_name, needed, stats)
    
    # Print summary
    print("\n" + "=" * 70)
    print("📊 DATASET SUMMARY")
    print("=" * 70)

    total_train = sum(stats['train'].values())
    total_val = sum(stats['val'].values())

    print(f"✅ Train images: {total_train:,}")
    print(f"✅ Val images: {total_val:,}")
    print(f"✅ Total images: {total_train + total_val:,}")
    print(f"📈 Augmented images: {stats['augmented']:,}")
    
    return output_dir, stats


def augment_images(source_images, output_dir, class_name, count, stats):
    """Create augmented copies of images."""
    from PIL import ImageEnhance, ImageFilter
    
    augmentations = [
        lambda img: img.transpose(Image.FLIP_LEFT_RIGHT),
        lambda img: img.rotate(15, fillcolor=(255, 255, 255)),
        lambda img: img.rotate(-15, fillcolor=(255, 255, 255)),
        lambda img: ImageEnhance.Brightness(img).enhance(1.2),
        lambda img: ImageEnhance.Brightness(img).enhance(0.8),
        lambda img: ImageEnhance.Contrast(img).enhance(1.2),
        lambda img: img.filter(ImageFilter.GaussianBlur(radius=1)),
    ]
    
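    if not source_images:
        # Nothing to augment from (e.g. a tiny class whose images all went to val)
        print(f"⚠️ {class_name}: no training images to augment, skipping")
        return

    created = 0
    idx = 0

    while created < count:
        src_img_path = source_images[idx % len(source_images)]
        aug_fn = random.choice(augmentations)

        try:
            with Image.open(src_img_path) as img:
                img = img.convert('RGB')
                augmented = aug_fn(img)

                dst_name = f"{class_name}_aug_{created}_{src_img_path.stem}.jpg"
                augmented.save(output_dir / dst_name, 'JPEG', quality=95)

                created += 1
                stats['train'][class_name] += 1
                stats['augmented'] += 1
        except Exception as e:
            # Report unreadable images instead of silently ignoring them
            print(f"⚠️ Could not augment {src_img_path.name}: {e}")

        idx += 1
        if idx > count * 2:  # Safety break
            break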

# Create the dataset
yolo_dataset_path, dataset_stats = create_yolo_dataset(
    dataset_info, 
    all_class_names,
    YOLO_DATASET_DIR,
    min_samples=CONFIG['min_samples'],
    max_samples=CONFIG['max_samples'],
    train_split=CONFIG['train_split']
)
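
The balancing arithmetic above can be checked without touching any files. This is a minimal sketch under the notebook's default settings; `plan_class_balance` is a hypothetical helper that only mirrors the counting logic in `create_yolo_dataset` and does not copy or augment anything.

```python
def plan_class_balance(count, min_samples=150, max_samples=400, train_split=0.8):
    """Return how many images a class contributes to each split.

    Hypothetical helper that mirrors the arithmetic in create_yolo_dataset().
    """
    kept = min(count, max_samples)          # undersample large classes
    n_train = int(kept * train_split)       # originals copied to train
    n_val = kept - n_train                  # originals copied to val
    n_aug = 0
    if count < min_samples:
        # small classes are topped up to int(min_samples * train_split) training images
        n_aug = max(0, int(min_samples * train_split) - n_train)
    return {'train': n_train, 'val': n_val, 'augmented': n_aug}


for n in (1000, 200, 100):
    print(n, plan_class_balance(n))
```

With the defaults, a class of 100 images gets 80 original plus 40 augmented training images but only 20 validation images, because validation images are never augmented.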

In [None]:
# ==========================================
# 📋 Verify dataset structure
# ==========================================

def verify_yolo_dataset(dataset_path):
    """Verify the YOLO dataset structure and print statistics."""
    
    dataset_path = Path(dataset_path)
    
    print("=" * 70)
    print("📋 YOLO DATASET VERIFICATION")
    print("=" * 70)

    for split in ['train', 'val']:
        split_path = dataset_path / split

        if not split_path.exists():
            print(f"⚠️ {split} directory not found!")
            continue

        classes = sorted([d.name for d in split_path.iterdir() if d.is_dir()])

        print(f"\n📁 {split.upper()} ({len(classes)} classes)")
        print("-" * 50)

        total = 0
        class_counts = {}
        for class_name in classes:
            class_path = split_path / class_name
            # Match extensions case-insensitively: copied files keep names such as IMG_1.JPG
            count = sum(1 for p in class_path.iterdir()
                        if p.suffix.lower() in {'.jpg', '.jpeg', '.png'})
            class_counts[class_name] = count
            total += count

        # Print the 10 largest classes
        for class_name, count in sorted(class_counts.items(), key=lambda x: x[1], reverse=True)[:10]:
            print(f"  {class_name:45s}: {count:4d} images")

        if len(classes) > 10:
            print(f"  ... and {len(classes) - 10} more classes")

        print(f"\n  📊 Total: {total:,} images")
    
    # Save class names file
    classes_file = dataset_path / 'classes.txt'
    train_classes = sorted([d.name for d in (dataset_path / 'train').iterdir() if d.is_dir()])
    with open(classes_file, 'w') as f:
        for cls in train_classes:
            f.write(f"{cls}\n")
    print(f"\n✅ Saved class names to {classes_file}")

    return train_classes

yolo_classes = verify_yolo_dataset(YOLO_DATASET_DIR)
print(f"\n📊 Total classes for training: {len(yolo_classes)}")

## 🏋️ Section 6: Train YOLOv8 Classification Model

### Model Options:
- `yolov8n-cls`: Nano - fastest, smallest (5.3MB)
- `yolov8s-cls`: Small - best balance (11.4MB) ⬅️ **Recommended**
- `yolov8m-cls`: Medium - more accurate (36.6MB)
- `yolov8l-cls`: Large - higher accuracy (83.3MB)
- `yolov8x-cls`: XLarge - highest accuracy (136.0MB)

In [None]:
# ==========================================
# 🏋️ Initialize and train YOLOv8 Classification Model
# ==========================================

from ultralytics import YOLO

# Create output directory
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Model selection
MODEL_VARIANT = CONFIG['model_variant']  # 'yolov8s-cls' recommended

print("=" * 70)
print("🏋️ YOLOv8 CLASSIFICATION TRAINING")
print("=" * 70)
print(f"📦 Model: {MODEL_VARIANT}")
print(f"📁 Dataset: {YOLO_DATASET_DIR}")
print(f"📊 Image size: {CONFIG['img_size']}x{CONFIG['img_size']}")
print(f"📊 Batch size: {CONFIG['batch_size']}")
print(f"📊 Epochs: {CONFIG['epochs']}")
print(f"📊 Early stopping patience: {CONFIG['patience']}")
print()

# Load pretrained YOLOv8 classification model
model = YOLO(f'{MODEL_VARIANT}.pt')

print(f"✅ Loaded {MODEL_VARIANT} pretrained model")
print("📊 Model info:")
model.info()  # logs the layer/parameter summary itself

In [None]:
# ==========================================
# 🚀 Start Training
# ==========================================

import time

start_time = time.time()

# Train the model
# YOLOv8 automatically handles:
# - Data augmentation (flips, rotations, color jitter)
# - Learning rate scheduling
# - Best model checkpointing
# - Early stopping

results = model.train(
    data=str(YOLO_DATASET_DIR),     # Path to dataset
    epochs=CONFIG['epochs'],         # Number of epochs
    imgsz=CONFIG['img_size'],        # Image size
    batch=CONFIG['batch_size'],      # Batch size
    patience=CONFIG['patience'],     # Early stopping patience
    device=0 if DEVICE == 'cuda' else 'cpu',  # GPU or CPU
    project=str(OUTPUT_DIR),         # Output directory
    name='fasalvaidya_yolov8',       # Run name
    exist_ok=True,                   # Overwrite existing
    pretrained=True,                 # Use pretrained weights
    optimizer='AdamW',               # Optimizer
    lr0=0.001,                       # Initial learning rate
    lrf=0.01,                        # Final LR (lr0 * lrf)
    momentum=0.937,                  # Momentum
    weight_decay=0.0005,             # Weight decay
    warmup_epochs=3,                 # Warmup epochs
    warmup_momentum=0.8,             # Warmup momentum
    warmup_bias_lr=0.1,              # Warmup bias LR
    close_mosaic=0,                  # Disable mosaic (for classification)
    amp=True,                        # Mixed precision
    fraction=1.0,                    # Dataset fraction
    seed=42,                         # Random seed
    verbose=True,                    # Verbose output
    plots=True,                      # Generate plots
)

elapsed = time.time() - start_time
print(f"\n✅ Training completed in {elapsed/60:.1f} minutes ({elapsed/3600:.2f} hours)")
print(f"📁 Results saved to: {OUTPUT_DIR / 'fasalvaidya_yolov8'}")
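
To make the `lr0` / `lrf` comment concrete, the sketch below computes the learning rate over training with the values used above. It assumes a linear decay from `lr0` down to `lr0 * lrf` and ignores the warmup epochs; the trainer's actual schedule may differ by version or settings, and `linear_lr` is a hypothetical helper, not an Ultralytics function.

```python
def linear_lr(epoch, epochs=50, lr0=0.001, lrf=0.01):
    """Learning rate at a given epoch under a linear decay from lr0 to lr0 * lrf.

    Assumption: linear decay, warmup ignored.
    """
    factor = max(1 - epoch / epochs, 0) * (1.0 - lrf) + lrf
    return lr0 * factor


for epoch in (0, 25, 50):
    print(f"epoch {epoch:2d}: lr = {linear_lr(epoch):.6f}")
```

Under these assumptions the rate starts at 0.001 and ends near 0.00001, which is what the `lr0 * lrf` comment on `lrf` refers to.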

## 📈 Section 7: Visualize Training Results

In [None]:
# ==========================================
# 📈 Display training curves and results
# ==========================================

from IPython.display import Image as IPImage, display
import matplotlib.pyplot as plt
from PIL import Image

# Training results directory
results_dir = OUTPUT_DIR / 'fasalvaidya_yolov8'

print("=" * 70)
print("📈 TRAINING RESULTS")
print("=" * 70)

# Display training curves
curves_to_show = ['results.png', 'confusion_matrix.png', 'confusion_matrix_normalized.png']

for curve_name in curves_to_show:
    curve_path = results_dir / curve_name
    if curve_path.exists():
        print(f"\n📊 {curve_name}:")
        display(IPImage(filename=str(curve_path), width=800))
    else:
        print(f"⚠️ {curve_name} not found")

# Print metrics from CSV if available
results_csv = results_dir / 'results.csv'
if results_csv.exists():
    import pandas as pd
    df = pd.read_csv(results_csv)
    df.columns = df.columns.str.strip()  # Clean column names
    
    print("\n📊 Training Metrics (last 5 epochs):")
    print(df.tail().to_string())

    # Best epoch
    if 'metrics/accuracy_top1' in df.columns:
        best_idx = df['metrics/accuracy_top1'].idxmax()
        print(f"\n🏆 Best Top-1 Accuracy: {df.loc[best_idx, 'metrics/accuracy_top1']:.4f} (epoch {best_idx + 1})")
    if 'metrics/accuracy_top5' in df.columns:
        print(f"🏆 Best Top-5 Accuracy: {df['metrics/accuracy_top5'].max():.4f}")

## 🧪 Section 8: Evaluate Model on Validation Set

In [None]:
# ==========================================
# 🧪 Run validation with the best model
# ==========================================

# Load best model
best_model_path = results_dir / 'weights' / 'best.pt'

if best_model_path.exists():
    print(f"✅ Loading best model from: {best_model_path}")
    best_model = YOLO(str(best_model_path))

    # Run validation
    print("\n🧪 Running validation...")
    val_results = best_model.val(
        data=str(YOLO_DATASET_DIR),
        split='val',
        imgsz=CONFIG['img_size'],
        batch=CONFIG['batch_size'],
        verbose=True
    )

    print("\n" + "=" * 70)
    print("📊 VALIDATION RESULTS")
    print("=" * 70)
    print(f"✅ Top-1 Accuracy: {val_results.top1:.4f} ({val_results.top1*100:.2f}%)")
    print(f"✅ Top-5 Accuracy: {val_results.top5:.4f} ({val_results.top5*100:.2f}%)")

else:
    print(f"⚠️ Best model not found at {best_model_path}")

In [None]:
# ==========================================
# 📊 Per-class accuracy analysis
# ==========================================

from sklearn.metrics import classification_report, confusion_matrix
import numpy as np

def analyze_per_class_performance(model, dataset_path, class_names):
    """Analyze per-class performance with detailed metrics."""
    
    val_dir = Path(dataset_path) / 'val'
    
    y_true = []
    y_pred = []
    confidences = []
    
    print("🔍 Running predictions on validation set...")

    image_exts = {'.jpg', '.jpeg', '.png'}

    for class_idx, class_name in enumerate(tqdm(class_names)):
        class_dir = val_dir / class_name
        if not class_dir.exists():
            continue

        # Match extensions case-insensitively so files such as IMG_1.JPG are included
        images = [p for p in class_dir.iterdir() if p.suffix.lower() in image_exts]

        for img_path in images:
            try:
                results = model.predict(str(img_path), verbose=False)
                pred_class = results[0].probs.top1
                confidence = results[0].probs.top1conf.item()

                y_true.append(class_idx)
                y_pred.append(pred_class)
                confidences.append(confidence)
            except Exception as e:
                # Report failures instead of silently dropping the image
                print(f"⚠️ Prediction failed for {img_path.name}: {e}")

    # Classification report
    print("\n" + "=" * 70)
    print("📊 CLASSIFICATION REPORT")
    print("=" * 70)

    # Pass explicit labels so the report still works if some class never appears
    report = classification_report(
        y_true, y_pred,
        labels=list(range(len(class_names))),
        target_names=class_names,
        output_dict=True,
        zero_division=0,
    )
    
    # Print per-class metrics
    print(f"\n{'Class':<45} {'Precision':>10} {'Recall':>10} {'F1-Score':>10} {'Support':>10}")
    print("-" * 90)
    
    class_metrics = []
    for class_name in class_names:
        if class_name in report:
            m = report[class_name]
            class_metrics.append({
                'name': class_name,
                'precision': m['precision'],
                'recall': m['recall'],
                'f1': m['f1-score'],
                'support': m['support']
            })
            print(f"{class_name:<45} {m['precision']:>10.3f} {m['recall']:>10.3f} {m['f1-score']:>10.3f} {m['support']:>10.0f}")
    
    # Summary
    print("-" * 90)
    print(f"{'Macro Avg':<45} {report['macro avg']['precision']:>10.3f} {report['macro avg']['recall']:>10.3f} {report['macro avg']['f1-score']:>10.3f}")
    print(f"{'Weighted Avg':<45} {report['weighted avg']['precision']:>10.3f} {report['weighted avg']['recall']:>10.3f} {report['weighted avg']['f1-score']:>10.3f}")
    
    # Find worst performing classes
    print("\n⚠️ Bottom 5 Classes by F1-Score:")
    sorted_classes = sorted(class_metrics, key=lambda x: x['f1'])
    for i, m in enumerate(sorted_classes[:5], 1):
        print(f"  {i}. {m['name']}: F1={m['f1']:.3f}, Precision={m['precision']:.3f}, Recall={m['recall']:.3f}")

    # Average confidence
    print(f"\n📊 Average Prediction Confidence: {np.mean(confidences):.3f}")
    
    return report, y_true, y_pred

# Run analysis
if best_model_path.exists():
    report, y_true, y_pred = analyze_per_class_performance(best_model, YOLO_DATASET_DIR, yolo_classes)
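
For readers unfamiliar with the columns in the report, the sketch below shows how precision, recall and F1 are derived for a single class from the true and predicted label lists. It is a plain-Python illustration on toy labels; `per_class_prf` is a hypothetical helper, and the notebook itself relies on scikit-learn's `classification_report` for these numbers.

```python
def per_class_prf(y_true, y_pred, cls):
    """Precision, recall and F1 for one class index, computed from raw label lists."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == cls and p == cls)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t != cls and p == cls)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == cls and p != cls)
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0
    return precision, recall, f1


# Toy labels: class 0 has 3 true samples, 2 of them predicted correctly
toy_true = [0, 0, 0, 1, 1, 2]
toy_pred = [0, 0, 1, 1, 0, 2]
print(per_class_prf(toy_true, toy_pred, cls=0))
```

A class with low recall is being missed by the model, while low precision means other classes are being mistaken for it; the "Bottom 5" list above is a quick way to find both cases.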

## 🖼️ Section 9: Test Predictions on Sample Images

In [None]:
# ==========================================
# 🖼️ Visualize predictions on random samples
# ==========================================

def visualize_predictions(model, dataset_path, class_names, num_samples=12):
    """Show predictions on random validation samples."""
    
    val_dir = Path(dataset_path) / 'val'
    
    # Collect random images from different classes
    all_images = []
    for class_name in class_names:
        class_dir = val_dir / class_name
        if class_dir.exists():
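            # Include PNGs and upper-case extensions, matching what the dataset builder copies
            images = [p for p in class_dir.iterdir()
                      if p.suffix.lower() in {'.jpg', '.jpeg', '.png'}]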
            for img in images[:3]:  # Max 3 per class
                all_images.append((img, class_name))
    
    # Random sample
    random.shuffle(all_images)
    samples = all_images[:num_samples]
    
    # Create figure (squeeze=False keeps axes 2-D for any grid shape)
    cols = 4
    rows = (num_samples + cols - 1) // cols
    fig, axes = plt.subplots(rows, cols, figsize=(16, 4*rows), squeeze=False)
    axes = axes.flatten()
    
    for idx, (img_path, true_class) in enumerate(samples):
        # Predict
        results = model.predict(str(img_path), verbose=False)
        pred_class_idx = results[0].probs.top1
        pred_class = class_names[pred_class_idx]
        confidence = results[0].probs.top1conf.item()
        
        # Load image
        img = Image.open(img_path).convert('RGB')
        
        # Plot
        ax = axes[idx]
        ax.imshow(img)
        
        correct = pred_class == true_class
        color = 'green' if correct else 'red'
        symbol = '✅' if correct else '❌'
        
        ax.set_title(f"{symbol} Pred: {pred_class}\n(True: {true_class})\nConf: {confidence:.2%}", 
                     fontsize=9, color=color)
        ax.axis('off')
    
    # Hide empty subplots
    for idx in range(len(samples), len(axes)):
        axes[idx].axis('off')
    
    plt.tight_layout()
    plt.savefig(str(OUTPUT_DIR / 'fasalvaidya_yolov8' / 'sample_predictions.png'), dpi=150, bbox_inches='tight')
    plt.show()

# Show predictions
if best_model_path.exists():
    print("🖼️ Sample Predictions:")
    visualize_predictions(best_model, YOLO_DATASET_DIR, yolo_classes, num_samples=12)

## 📦 Section 10: Export Model to Multiple Formats

YOLOv8 supports export to:
- **ONNX**: Cross-platform, web deployment
- **TensorFlow Lite**: Mobile (Android/iOS)
- **CoreML**: iOS native
- **TensorRT**: NVIDIA GPU optimization
- **OpenVINO**: Intel CPU optimization
- **NCNN**: Mobile/embedded
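
The formats listed above correspond to `format=` strings passed to `model.export()`. The sketch below loops over several targets and records failures instead of stopping at the first one. `EXPORT_FORMATS` and `export_all` are hypothetical helpers, not part of Ultralytics, and `model` is assumed to be any object with an Ultralytics-style `export(format=..., imgsz=...)` method.

```python
# Deployment target -> format string for model.export(format=...)
EXPORT_FORMATS = {
    "onnx": "onnx",
    "tflite": "tflite",
    "coreml": "coreml",
    "tensorrt": "engine",   # TensorRT exports use the "engine" format string
    "openvino": "openvino",
    "ncnn": "ncnn",
}

def export_all(model, targets, imgsz):
    """Export `model` to each target; return (exported paths, error messages)."""
    exported, failed = {}, {}
    for target in targets:
        try:
            exported[target] = model.export(format=EXPORT_FORMATS[target], imgsz=imgsz)
        except Exception as exc:  # missing toolchain, unsupported platform, unknown target
            failed[target] = str(exc)
    return exported, failed

# Hypothetical usage with this notebook's variables:
# exported, failed = export_all(best_model, ["onnx", "tflite"], CONFIG['img_size'])
```

Collecting failures in a dictionary makes it easy to report which targets need extra dependencies without losing the exports that did succeed.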

In [None]:
# ==========================================
# 📦 Export to ONNX format
# ==========================================

print("=" * 70)
print("📦 EXPORTING MODELS")
print("=" * 70)

if best_model_path.exists():
    # Export to ONNX
    print("\n🔄 Exporting to ONNX...")
    onnx_path = best_model.export(
        format='onnx',
        imgsz=CONFIG['img_size'],
        half=False,  # FP32 for compatibility
        simplify=True,
        opset=12,
    )
    print(f"✅ ONNX model saved: {onnx_path}")
    
    # Get file size
    onnx_size = Path(onnx_path).stat().st_size / (1024 * 1024)
    print(f"📊 ONNX model size: {onnx_size:.2f} MB")
else:
    print("⚠️ Best model not found, skipping export")

In [None]:
# ==========================================
# 📱 Export to TensorFlow Lite (for mobile)
# ==========================================

if best_model_path.exists():
    print("\n🔄 Exporting to TensorFlow Lite...")
    
    try:
        tflite_path = best_model.export(
            format='tflite',
            imgsz=CONFIG['img_size'],
            half=False,  # FP32
        )
        print(f"✅ TFLite model saved: {tflite_path}")
        
        # Get file size
        tflite_size = Path(tflite_path).stat().st_size / (1024 * 1024)
        print(f"📊 TFLite model size: {tflite_size:.2f} MB")
        
    except Exception as e:
        print(f"⚠️ TFLite export failed: {e}")
        print("   This is common on some systems. Try the following:")
        # Quote the requirement so the shell does not treat '>' as a redirect
        print('   !pip install "tensorflow>=2.10.0"')

In [None]:
# ==========================================
# 🍎 Export to CoreML (for iOS)
# ==========================================

if best_model_path.exists():
    print("\n🔄 Exporting to CoreML (iOS)...")
    
    try:
        coreml_path = best_model.export(
            format='coreml',
            imgsz=CONFIG['img_size'],
            half=False,
        )
        print(f"✅ CoreML model saved: {coreml_path}")
        
    except Exception as e:
        print(f"⚠️ CoreML export failed: {e}")
        print("   Install coremltools: pip install coremltools")

## 💾 Section 11: Save Final Model and Metadata

In [None]:
# ==========================================
# 💾 Save model and metadata to Google Drive
# ==========================================

import json
from datetime import datetime

# Create final model directory
FINAL_MODEL_DIR.mkdir(parents=True, exist_ok=True)

print("=" * 70)
print("💾 SAVING FINAL MODEL")
print("=" * 70)
print(f"📁 Destination: {FINAL_MODEL_DIR}")

# Copy best model weights
if best_model_path.exists():
    shutil.copy2(best_model_path, FINAL_MODEL_DIR / 'best.pt')
    print("✅ Copied best.pt")

# Copy last model weights
last_model_path = results_dir / 'weights' / 'last.pt'
if last_model_path.exists():
    shutil.copy2(last_model_path, FINAL_MODEL_DIR / 'last.pt')
    print("✅ Copied last.pt")

# Copy ONNX model
onnx_model = results_dir / 'weights' / 'best.onnx'
if onnx_model.exists():
    shutil.copy2(onnx_model, FINAL_MODEL_DIR / 'fasalvaidya_yolov8.onnx')
    print("✅ Copied ONNX model")

# Copy TFLite model. Search recursively: the exporter may place the .tflite
# file in a subfolder of weights/ rather than in weights/ itself. If several
# files are found, prefer one with 'float32' in its name to match half=False.
tflite_files = sorted((results_dir / 'weights').rglob('*.tflite'),
                      key=lambda p: 'float32' not in p.name)
for tflite_file in tflite_files[:1]:
    shutil.copy2(tflite_file, FINAL_MODEL_DIR / 'fasalvaidya_yolov8.tflite')
    print(f"✅ Copied TFLite model ({tflite_file.name})")

# Save class labels
with open(FINAL_MODEL_DIR / 'labels.txt', 'w') as f:
    for class_name in yolo_classes:
        f.write(f"{class_name}\n")
print(f"✅ Saved labels.txt ({len(yolo_classes)} classes)")

# Save metadata
metadata = {
    'model_name': 'FasalVaidya YOLOv8 Classification',
    'model_variant': CONFIG['model_variant'],
    'version': '1.0.0',
    'created_at': datetime.now().isoformat(),
    'framework': 'Ultralytics YOLOv8',
    'ultralytics_version': ultralytics.__version__,
    'pytorch_version': torch.__version__,
    'input_shape': [1, 3, CONFIG['img_size'], CONFIG['img_size']],
    'image_size': CONFIG['img_size'],
    'num_classes': len(yolo_classes),
    'classes': yolo_classes,
    'crops': list(CROP_DATASETS.keys()),
    'training_config': CONFIG,
    'class_mapping': CLASS_RENAME_MAP,
    'metrics': {
        'top1_accuracy': float(val_results.top1) if 'val_results' in dir() else None,
        'top5_accuracy': float(val_results.top5) if 'val_results' in dir() else None,
    },
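    # Formats this notebook attempts; TFLite and CoreML exports can fail, so check the package contents
    'export_formats': ['pt', 'onnx', 'tflite', 'coreml'],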
}

with open(FINAL_MODEL_DIR / 'metadata.json', 'w') as f:
    # default=str keeps the dump from failing if CONFIG holds non-JSON types such as Path
    json.dump(metadata, f, indent=2, default=str)
print("✅ Saved metadata.json")

# Copy training results
if (results_dir / 'results.csv').exists():
    shutil.copy2(results_dir / 'results.csv', FINAL_MODEL_DIR / 'training_results.csv')
    print("✅ Copied training_results.csv")

# Copy result plots (confusion matrix and any other PNGs in the run directory)
result_images = list(results_dir.glob('*.png'))
for img_file in result_images:
    shutil.copy2(img_file, FINAL_MODEL_DIR / img_file.name)
print(f"✅ Copied {len(result_images)} result images")

print("\n📦 Final model package contents:")
for f in sorted(FINAL_MODEL_DIR.iterdir()):
    size = f.stat().st_size / (1024 * 1024) if f.is_file() else 0
    print(f"  📄 {f.name}: {size:.2f} MB" if size > 0.01 else f"  📄 {f.name}")
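
Consumers of the ONNX or TFLite files need `labels.txt` to turn an output index back into a class name. The sketch below shows one way a downstream script might load the package and check that the two files agree. `load_labels` and `check_package` are hypothetical helpers, and the demo uses placeholder class names in a temporary directory, not the real package.

```python
import json
import tempfile
from pathlib import Path

def load_labels(model_dir):
    """Read labels.txt (one class name per line) into an index-ordered list."""
    text = (Path(model_dir) / "labels.txt").read_text(encoding="utf-8")
    return [line.strip() for line in text.splitlines() if line.strip()]

def check_package(model_dir):
    """Verify that labels.txt agrees with metadata.json on the class count."""
    labels = load_labels(model_dir)
    metadata = json.loads((Path(model_dir) / "metadata.json").read_text(encoding="utf-8"))
    if metadata["num_classes"] != len(labels):
        raise ValueError("labels.txt and metadata.json disagree on class count")
    return labels, metadata

# Demo with a throwaway package containing two placeholder class names
with tempfile.TemporaryDirectory() as tmp:
    Path(tmp, "labels.txt").write_text("class_a\nclass_b\n", encoding="utf-8")
    Path(tmp, "metadata.json").write_text(json.dumps({"num_classes": 2}), encoding="utf-8")
    labels, metadata = check_package(tmp)
```

Against the real package, the call would be `check_package(FINAL_MODEL_DIR)`.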

## 🔍 Section 12: Validate Exported Models

In [None]:
# ==========================================
# 🔍 Validate ONNX Model
# ==========================================

print("=" * 70)
print("🔍 VALIDATING EXPORTED MODELS")
print("=" * 70)

# Test ONNX model
onnx_path = FINAL_MODEL_DIR / 'fasalvaidya_yolov8.onnx'
if onnx_path.exists():
    print("\n📦 Testing ONNX model...")
    
    try:
        import onnxruntime as ort
        
        # Create session
        session = ort.InferenceSession(str(onnx_path))
        
        # Get input info
        input_info = session.get_inputs()[0]
        print(f"  ✅ Input name: {input_info.name}")
        print(f"  ✅ Input shape: {input_info.shape}")
        print(f"  ✅ Input type: {input_info.type}")
        
        # Get output info
        output_info = session.get_outputs()[0]
        print(f"  ✅ Output name: {output_info.name}")
        print(f"  ✅ Output shape: {output_info.shape}")
        
        # Test inference with dummy data
        import numpy as np
        dummy_input = np.random.randn(1, 3, CONFIG['img_size'], CONFIG['img_size']).astype(np.float32)
        outputs = session.run(None, {input_info.name: dummy_input})
        print(f"  ✅ Test inference successful!")
        print(f"  ✅ Output shape: {outputs[0].shape}")
        
    except ImportError:
        print("  ⚠️ onnxruntime not installed. Run: pip install onnxruntime")
    except Exception as e:
        print(f"  ❌ ONNX validation failed: {e}")
else:
    print("⚠️ ONNX model not found")

In [None]:
# ==========================================
# 🔍 Validate TFLite Model
# ==========================================

tflite_path = FINAL_MODEL_DIR / 'fasalvaidya_yolov8.tflite'
if tflite_path.exists():
    print("\n📱 Testing TFLite model...")
    
    try:
        import tensorflow as tf
        
        # Load TFLite model
        interpreter = tf.lite.Interpreter(model_path=str(tflite_path))
        interpreter.allocate_tensors()
        
        # Get input/output details
        input_details = interpreter.get_input_details()
        output_details = interpreter.get_output_details()
        
        print(f"  ✅ Input shape: {input_details[0]['shape']}")
        print(f"  ✅ Input dtype: {input_details[0]['dtype']}")
        print(f"  ✅ Output shape: {output_details[0]['shape']}")
        
        # Test inference
        input_shape = input_details[0]['shape']
        test_input = np.random.randn(*input_shape).astype(np.float32)
        interpreter.set_tensor(input_details[0]['index'], test_input)
        interpreter.invoke()
        output = interpreter.get_tensor(output_details[0]['index'])
        
        print(f"  ✅ Test inference successful!")
        print(f"  ✅ Output shape: {output.shape}")
        
    except ImportError:
        print("  ⚠️ TensorFlow not installed")
    except Exception as e:
        print(f"  ❌ TFLite validation failed: {e}")
else:
    print("⚠️ TFLite model not found")
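
The validation cells above only confirm that inference runs and that the output has the expected shape. To read a prediction out of that output, the score vector has to be ranked and mapped to class names. The sketch below does that, assuming the model emits one score per class in the same order as `labels.txt`. `top_k` is a hypothetical helper and the scores are placeholders. Whether the scores are already probabilities depends on the export; the ranking works either way.

```python
import numpy as np

def top_k(scores, labels, k=3):
    """Return the k highest-scoring (label, score) pairs from a 1-D score vector."""
    scores = np.asarray(scores, dtype=np.float32).ravel()
    order = np.argsort(scores)[::-1][:k]  # indices sorted by descending score
    return [(labels[i], float(scores[i])) for i in order]

# Placeholder labels and scores standing in for outputs[0][0] from the ONNX session
demo = top_k([0.1, 0.7, 0.2], ["class_a", "class_b", "class_c"], k=2)
```

With the ONNX session from the validation cell, the call might look like `top_k(outputs[0][0], yolo_classes)`, though a dummy random input will not produce a meaningful ranking.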

## ⚡ Section 13: Benchmark Inference Speed

In [None]:
# ==========================================
# ⚡ Benchmark inference speed
# ==========================================

import time
import numpy as np

def benchmark_model(model_path, model_type, num_runs=100):
    """Benchmark inference speed for different model formats."""
    
    # Create dummy image
    dummy_img = np.random.randint(0, 255, (CONFIG['img_size'], CONFIG['img_size'], 3), dtype=np.uint8)
    
    # Warmup runs
    warmup = 10
    
    if model_type == 'pytorch':
        model = YOLO(str(model_path))
        # Warmup
        for _ in range(warmup):
            model.predict(dummy_img, verbose=False)
        
        # Benchmark
        times = []
        for _ in range(num_runs):
            start = time.perf_counter()
            model.predict(dummy_img, verbose=False)
            times.append((time.perf_counter() - start) * 1000)  # ms
        
        return np.mean(times), np.std(times)
    
    elif model_type == 'onnx':
        import onnxruntime as ort
        session = ort.InferenceSession(str(model_path))
        input_name = session.get_inputs()[0].name
        
        # Prepare input
        img = dummy_img.transpose(2, 0, 1).astype(np.float32) / 255.0
        img = np.expand_dims(img, 0)
        
        # Warmup
        for _ in range(warmup):
            session.run(None, {input_name: img})
        
        # Benchmark
        times = []
        for _ in range(num_runs):
            start = time.perf_counter()
            session.run(None, {input_name: img})
            times.append((time.perf_counter() - start) * 1000)
        
        return np.mean(times), np.std(times)
    
    return None, None

NUM_RUNS = 100

print("=" * 70)
print("⚡ INFERENCE SPEED BENCHMARK")
print("=" * 70)
print(f"📊 Running {NUM_RUNS} iterations per model\n")
# Note: the PyTorch timing covers the full model.predict() call (preprocessing,
# inference, postprocessing); the ONNX timing covers only session.run() on a
# prepared tensor, so the two numbers are not directly comparable.

# Benchmark PyTorch model
pt_path = FINAL_MODEL_DIR / 'best.pt'
if pt_path.exists():
    mean_time, std_time = benchmark_model(pt_path, 'pytorch', num_runs=NUM_RUNS)
    print(f"🔥 PyTorch (.pt):  {mean_time:.2f} ± {std_time:.2f} ms  ({1000/mean_time:.1f} FPS)")

# Benchmark ONNX model
onnx_path = FINAL_MODEL_DIR / 'fasalvaidya_yolov8.onnx'
if onnx_path.exists():
    try:
        mean_time, std_time = benchmark_model(onnx_path, 'onnx', num_runs=NUM_RUNS)
        print(f"📦 ONNX:           {mean_time:.2f} ± {std_time:.2f} ms  ({1000/mean_time:.1f} FPS)")
    except Exception as e:
        print(f"⚠️ ONNX benchmark failed: {e}")

print("\n💡 Note: TFLite benchmarking is best done on actual mobile devices")
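
Mean and standard deviation can be skewed by a few slow iterations, for example when another process briefly takes the CPU. As a sketch of an alternative summary, the helper below reports the median and 95th-percentile latency for any callable. `time_callable` is a hypothetical name, and the workload is a stand-in, not a model.

```python
import time
import numpy as np

def time_callable(fn, num_runs=100, warmup=10):
    """Time fn() and return mean, median, and 95th-percentile latency in ms."""
    for _ in range(warmup):  # warmup runs are not measured
        fn()
    times = np.empty(num_runs)
    for i in range(num_runs):
        start = time.perf_counter()
        fn()
        times[i] = (time.perf_counter() - start) * 1000  # seconds -> ms
    return {
        "mean_ms": float(times.mean()),
        "median_ms": float(np.median(times)),
        "p95_ms": float(np.percentile(times, 95)),
    }

# Stand-in workload; swap in e.g. lambda: session.run(None, {input_name: img})
stats = time_callable(lambda: sum(range(1000)), num_runs=20, warmup=2)
```

The median describes typical latency, while the 95th percentile gives a sense of the slow tail that a user would occasionally experience.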

## 📥 Section 14: Download Model Package

In [None]:
# ==========================================
# 📥 Create downloadable ZIP archive
# ==========================================

import zipfile

# Create ZIP file
zip_filename = f"FasalVaidya_YOLOv8_Model_{datetime.now().strftime('%Y%m%d_%H%M%S')}.zip"
zip_path = Path('/content') / zip_filename

print("=" * 70)
print("📥 CREATING DOWNLOAD PACKAGE")
print("=" * 70)

with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
    for file in FINAL_MODEL_DIR.rglob('*'):
        if file.is_file():
            arcname = file.relative_to(FINAL_MODEL_DIR)
            zipf.write(file, arcname)
            print(f"  📄 Added: {arcname}")

zip_size = zip_path.stat().st_size / (1024 * 1024)
print(f"\n✅ Created: {zip_path}")
print(f"📊 Size: {zip_size:.2f} MB")

# Download link (for Colab)
try:
    from google.colab import files
    print("\n📥 Click below to download:")
    files.download(str(zip_path))
except ImportError:
    print("\n💡 Download from: " + str(zip_path))

print(f"\n📁 Model also saved to Google Drive at:")
print(f"   {FINAL_MODEL_DIR}")
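
Before sharing the archive, it can be worth confirming that it is intact and contains the files a consumer needs. The sketch below uses the standard-library `zipfile.ZipFile.testzip()` for the integrity check. `verify_zip` is a hypothetical helper, the required file names are taken from the package saved above, and the demo archive is a throwaway containing only a placeholder `labels.txt`.

```python
import tempfile
import zipfile
from pathlib import Path

def verify_zip(zip_path, required=("best.pt", "labels.txt", "metadata.json")):
    """Check archive integrity and report which required files are missing."""
    with zipfile.ZipFile(zip_path) as zf:
        corrupt = zf.testzip()  # name of the first bad member, or None
        names = set(zf.namelist())
    missing = [name for name in required if name not in names]
    return corrupt, missing

# Demo: an archive that holds only labels.txt
with tempfile.TemporaryDirectory() as tmp:
    demo_zip = Path(tmp) / "demo.zip"
    with zipfile.ZipFile(demo_zip, "w") as zf:
        zf.writestr("labels.txt", "class_a\n")
    corrupt, missing = verify_zip(demo_zip)
```

For the real package, `verify_zip(zip_path)` should return `None` and an empty list if every member is readable and all three required files are present.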

## 📋 Section 15: Summary & Next Steps

In [None]:
# ==========================================
# 📋 Print final summary
# ==========================================

print("=" * 70)
print("🎉 TRAINING COMPLETE - SUMMARY")
print("=" * 70)

print(f"""
📦 Model Information:
   • Architecture: {CONFIG['model_variant']}
   • Framework: Ultralytics YOLOv8 {ultralytics.__version__}
   • Number of Classes: {len(yolo_classes)}
   • Input Size: {CONFIG['img_size']}x{CONFIG['img_size']}

🌾 Crops Covered:
   • Rice, Wheat, Maize (Cereals)
   • Banana, Coffee (Commercial)
   • Ashgourd, Eggplant, Snakegourd, Bittergourd (Vegetables)

📊 Training Configuration:
   • Epochs: {CONFIG['epochs']}
   • Batch Size: {CONFIG['batch_size']}
   • Optimizer: AdamW
   • Learning Rate: 0.001 → 0.00001

📁 Output Files:
   • best.pt - PyTorch weights (best validation accuracy)
   • last.pt - PyTorch weights (final epoch)
   • fasalvaidya_yolov8.onnx - ONNX format
   • fasalvaidya_yolov8.tflite - TensorFlow Lite
   • labels.txt - Class labels
   • metadata.json - Model metadata

🚀 Next Steps:
   1. Integrate into FasalVaidya backend (ml/inference_yolov8.py)
   2. Test on mobile devices
   3. Compare accuracy with EfficientNet-B0 model
   4. Deploy to production

📍 Model saved to Google Drive:
   {FINAL_MODEL_DIR}
""")

# Print accuracy if available
if 'val_results' in dir():
    print(f"📊 Final Metrics:")
    print(f"   • Top-1 Accuracy: {val_results.top1*100:.2f}%")
    print(f"   • Top-5 Accuracy: {val_results.top5*100:.2f}%")

---

## 📚 Appendix A: Using the Trained Model

### Python Inference Example:

```python
from ultralytics import YOLO

# Load the trained model
model = YOLO('path/to/best.pt')

# Predict on an image
results = model.predict('leaf_image.jpg')

# Get top prediction
top_class = results[0].probs.top1              # index of the highest-scoring class
top_conf = results[0].probs.top1conf.item()    # confidence as a Python float
class_name = results[0].names[top_class]

print(f"Prediction: {class_name} ({top_conf:.2%})")
```

### Mobile Integration (React Native with ONNX):

```javascript
import * as ort from 'onnxruntime-react-native';

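// Load ONNX model
const session = await ort.InferenceSession.create(modelPath);

// Run inference. `inputTensor` must be a float32 ort.Tensor shaped [1, 3, H, W],
// preprocessed the same way as during training. The key must match the model's
// input name (the ONNX validation cell prints it); 'images' is assumed here.
const feeds = { 'images': inputTensor };
const results = await session.run(feeds);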
```

---

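## 🔧 Appendix B: Hyperparameter Tuning (Optional)

The cell below contains a hyperparameter search using Ultralytics' built-in tuner. It is disabled by default (the code is wrapped in a string); remove the surrounding triple quotes to run it: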

In [None]:
# ==========================================
# 🔧 Hyperparameter Tuning (Optional - takes longer)
# ==========================================

# The tuning code below is wrapped in a triple-quoted string so it does not run.
# Remove the surrounding triple quotes to search for learning rate, momentum, etc.

"""
from ultralytics import YOLO

# Load base model
tune_model = YOLO('yolov8s-cls.pt')

# Run tuning
tune_results = tune_model.tune(
    data=str(YOLO_DATASET_DIR),
    epochs=30,
    iterations=50,  # Number of search iterations
    optimizer='AdamW',
    plots=True,
    save=True,
    val=True
)

# The tuner writes its results, including best_hyperparameters.yaml, to the
# tuning run directory. The return value varies by Ultralytics version and
# may be None, so print it as-is rather than accessing attributes on it.
print("Tuning finished. Return value:")
print(tune_results)
"""

print("💡 Hyperparameter tuning is disabled by default.")
print("   Remove the triple quotes around the code above to run tuning (takes 2-4 hours).")