# DEXI YOLO Training Tutorial - Local Environment

This notebook walks you through the complete process of training a custom YOLO object detection model for drone detection on your local machine. We'll be working with 6 classes: bird, dog, cat, motorcycle, car, and truck.

## 📋 Table of Contents
1. [Environment Setup](#environment-setup)
2. [Dataset Exploration](#dataset-exploration) 
3. [Data Augmentation](#data-augmentation)
4. [YOLO Training](#yolo-training)
5. [Results Analysis](#results-analysis)
6. [Model Testing](#model-testing)
7. [ONNX Conversion](#onnx-conversion)

---

## 1. Environment Setup

First, let's make sure we have all the required packages installed and import the necessary libraries.

**Prerequisites:**
- Python 3.8+
- Virtual environment activated
- All dependencies installed: `pip install -r requirements.txt`
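
Before installing anything, it can help to confirm the prerequisites above from Python. The sketch below is only an illustration: the module list is taken from this notebook's import cell (the contents of `requirements.txt` are not shown here and may differ), and `missing_modules` and `in_virtualenv` are hypothetical helper names.

```python
import sys
from importlib.util import find_spec

# Import names used by this notebook's import cell (an assumption about
# what requirements.txt provides; the file itself may list more packages).
REQUIRED_MODULES = ["cv2", "numpy", "matplotlib", "yaml", "ultralytics", "torch", "pandas"]


def missing_modules(names):
    """Return the top-level module names that cannot be found in this environment."""
    return [name for name in names if find_spec(name) is None]


def in_virtualenv():
    """True inside a venv/virtualenv (conda environments are not detected this way)."""
    return sys.prefix != getattr(sys, "base_prefix", sys.prefix)


version_ok = sys.version_info >= (3, 8)
print(f"Python {sys.version_info.major}.{sys.version_info.minor}: {'OK' if version_ok else 'needs 3.8+'}")
print(f"Virtual environment active: {in_virtualenv()}")

missing = missing_modules(REQUIRED_MODULES)
print("Missing modules:", ", ".join(missing) if missing else "none")
```

If any modules are reported missing, running the install cell below should resolve them.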

In [None]:
# Install required packages (only run if packages are not installed)
# Make sure you've activated your virtual environment first!
%pip install -r requirements.txt

In [None]:
# Import required libraries and check hardware acceleration
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import yaml
from ultralytics import YOLO
import torch
from IPython.display import Image, display
import pandas as pd
from glob import glob

# Set matplotlib style for better plots
plt.style.use('default')
plt.rcParams['figure.figsize'] = (12, 8)

print(f"🔧 System Information:")
print(f"   PyTorch version: {torch.__version__}")
print(f"   CUDA available: {torch.cuda.is_available()}")
print(f"   MPS available: {torch.backends.mps.is_available() if hasattr(torch.backends, 'mps') else 'Not available'}")
print(f"   MPS built: {torch.backends.mps.is_built() if hasattr(torch.backends, 'mps') else 'Not available'}")

# 🚀 Device selection: prefer NVIDIA CUDA, then Apple Silicon MPS, then CPU
if torch.cuda.is_available():
    device = 'cuda'
    gpu_name = torch.cuda.get_device_name(0)
    print(f"\n🚀 Using NVIDIA GPU: {gpu_name}")
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
    device = 'mps'
    print(f"\n⚡ Using Apple Silicon MPS acceleration!")
    print(f"   Expect roughly a 4-8x speedup over CPU (varies by chip and model size)")
    print(f"   Estimated training time: 30-45 minutes")
    
    # Clear the MPS cache (torch.mps.empty_cache requires PyTorch 2.0+)
    if hasattr(torch, 'mps') and hasattr(torch.mps, 'empty_cache'):
        torch.mps.empty_cache()
else:
    device = 'cpu'
    print(f"\n💻 Using CPU (slower but works everywhere)")
    print(f"   Expected training time: 2-4 hours")
    print(f"   💡 For faster training, use GPU or Apple Silicon Mac")

print(f"\n✅ Selected device: {device}")

## 2. Dataset Exploration

Let's explore our dataset structure and examine the original images before augmentation.
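
One quick way to summarize the folder before plotting anything is to count files per extension. This is a minimal sketch using only the standard library; `count_images_by_suffix` is a hypothetical helper, and the `train/original_image` path is the one used throughout this notebook.

```python
from collections import Counter
from pathlib import Path

IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png", ".bmp"}


def count_images_by_suffix(directory):
    """Count image files directly inside `directory`, keyed by lowercase suffix."""
    directory = Path(directory)
    if not directory.is_dir():
        return Counter()
    return Counter(
        p.suffix.lower()
        for p in directory.iterdir()
        if p.is_file() and p.suffix.lower() in IMAGE_SUFFIXES
    )


counts = count_images_by_suffix("train/original_image")
print(dict(counts) if counts else "No images found (is the path correct?)")
```

Comparing suffixes in lowercase means `photo.JPG` and `photo.jpg` are counted under the same key.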

In [None]:
# Find and display our original images
original_images_path = Path('train/original_image')
image_extensions = ['*.jpg', '*.jpeg', '*.png', '*.bmp']

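# Find all image files (a set avoids duplicates where glob matching is case-insensitive)
all_images = set()
for ext in image_extensions:
    all_images.update(original_images_path.glob(ext))
    all_images.update(original_images_path.glob(ext.upper()))
all_images = sorted(all_images)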

print(f"📸 Found {len(all_images)} images in the dataset")

# Display original images
if all_images:
    fig, axes = plt.subplots(2, 3, figsize=(15, 10))
    fig.suptitle('Original Images', fontsize=16, fontweight='bold')
    
    for idx, img_path in enumerate(all_images[:6]):
        row, col = divmod(idx, 3)
        
        # Load image; cv2.imread returns None if the file cannot be decoded
        img = cv2.imread(str(img_path))
        if img is None:
            axes[row, col].set_title(f"{img_path.name}\n(could not be read)", fontweight='bold')
            axes[row, col].axis('off')
            continue
        img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        
        axes[row, col].imshow(img_rgb)
        axes[row, col].set_title(f"{img_path.stem}\n{img.shape[1]}x{img.shape[0]}px", fontweight='bold')
        axes[row, col].axis('off')
    
    # Hide unused subplots
    for idx in range(len(all_images), 6):
        row = idx // 3
        col = idx % 3
        axes[row, col].axis('off')
    
    plt.tight_layout()
    plt.show()
else:
    print("❌ No images found in the train/original_image directory")

## 3. Data Augmentation

Now we'll use our custom augmentation script to create multiple variations of each base image. This is crucial for training a robust YOLO model as it helps the model generalize better to different conditions.
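
To make the idea concrete, the sketch below draws one random set of parameters from the ranges printed by the overview cell in this section, and applies a contrast and brightness change to a single pixel value. It is not the implementation in `augment_dataset.py`, whose internals are not shown in this notebook; the function names and the order of operations (contrast first, then brightness) are assumptions.

```python
import random


def sample_augmentation_params(rng):
    """Draw one random set of augmentation parameters from the tutorial's ranges."""
    return {
        "rotation_deg": rng.uniform(0.0, 360.0),   # rotation: 0-360 degrees
        "scale": rng.uniform(0.25, 1.3),           # scaling: 0.25x to 1.3x
        "brightness": rng.uniform(-30.0, 30.0),    # brightness offset: -30 to +30
        "contrast": rng.uniform(0.7, 1.3),         # contrast factor: 0.7x to 1.3x
        "add_noise": rng.random() < 0.20,          # noise on about 20% of images
        "add_blur": rng.random() < 0.15,           # blur on about 15% of images
    }


def adjust_pixel(value, contrast, brightness):
    """Scale by `contrast`, add `brightness`, then clip to the 0..255 range."""
    return max(0, min(255, round(contrast * value + brightness)))


rng = random.Random(42)  # fixed seed so the draw is reproducible
params = sample_augmentation_params(rng)
for name, value in params.items():
    print(f"{name}: {value}")
print("pixel 128 ->", adjust_pixel(128, params["contrast"], params["brightness"]))
```

Clipping matters because 8-bit images cannot store values outside 0 to 255; without it, bright pixels would wrap around or overflow.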

In [None]:
# Let's examine our augmentation script first
print("🔧 Data Augmentation Script Overview:")
print("")
print("Our augmentation script applies the following transformations:")
print("• 🔄 Rotation: 0-360 degrees (random)")
print("• 📏 Scaling: 0.25x to 1.3x (random)")
print("• ☀️ Brightness: -30 to +30 (random)")
print("• 🌈 Contrast: 0.7x to 1.3x (random)")
print("• 📻 Noise: Added 20% of the time")
print("• 🌫️ Blur: Applied 15% of the time")
print("")
print("Each transformation creates realistic variations that help the model")
print("learn to detect objects under different conditions.")

In [None]:
# Set augmentation parameters
AUGMENTATIONS_PER_IMAGE = 150  # Adjust this number as needed
INPUT_DIR = "train/original_image"      # Directory with original images
OUTPUT_DIR = "train"           # Output directory for augmented dataset

print(f"⚙️ Augmentation Configuration:")
print(f"Input directory: {INPUT_DIR}")
print(f"Output directory: {OUTPUT_DIR}")
print(f"Augmentations per image: {AUGMENTATIONS_PER_IMAGE}")
print(f"Expected total images: {len(all_images) * AUGMENTATIONS_PER_IMAGE}")

In [None]:
# Run the augmentation process
print("🚀 Starting data augmentation...")
print("This may take a few minutes depending on the number of augmentations.")

# Import and use our augmentation class
from augment_dataset import YOLODatasetAugmenter

# Create augmenter instance
augmenter = YOLODatasetAugmenter(INPUT_DIR, OUTPUT_DIR)

# Run augmentation
augmenter.augment_all_images(AUGMENTATIONS_PER_IMAGE)

print("\n✅ Data augmentation completed!")

In [None]:
# Verify augmentation results
train_images_dir = Path('train/images')
train_labels_dir = Path('train/labels')

# Count generated files
augmented_images = list(train_images_dir.glob('*_[0-9][0-9][0-9].jpg'))
augmented_labels = list(train_labels_dir.glob('*_[0-9][0-9][0-9].txt'))
# Originals are the .jpg files without a three-digit augmentation suffix
augmented_set = set(augmented_images)
original_images = [f for f in train_images_dir.glob('*.jpg') if f not in augmented_set]

print(f"📊 Augmentation Results:")
print(f"Original images: {len(original_images)}")
print(f"Augmented images: {len(augmented_images)}")
print(f"Total images: {len(list(train_images_dir.glob('*.jpg')))}")
print(f"Label files: {len(augmented_labels)}")
print(f"")
print(f"Images per class:")
for class_name in ['bird', 'dog', 'cat', 'motorcycle', 'car', 'truck']:
    class_images = len(list(train_images_dir.glob(f'{class_name}_*.jpg')))
    print(f"  {class_name}: {class_images} images")

In [None]:
# Display a few examples of augmented images
print("🖼️ Sample Augmented Images:")

fig, axes = plt.subplots(3, 4, figsize=(16, 12))
fig.suptitle('Sample Augmented Images (Showing Transformations)', fontsize=16, fontweight='bold')

# Show examples from each class
classes = ['bird', 'dog', 'cat', 'motorcycle', 'car', 'truck']
sample_count = 0

for class_idx, class_name in enumerate(classes):
    class_images = list(train_images_dir.glob(f'{class_name}_*.jpg'))[:2]  # Get first 2 augmented images
    
    for img_idx, img_path in enumerate(class_images):
        if sample_count >= 12:  # 3x4 grid
            break
            
        row = sample_count // 4
        col = sample_count % 4
        
        img = cv2.imread(str(img_path))
        img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        
        axes[row, col].imshow(img_rgb)
        axes[row, col].set_title(f"{class_name.title()}\n{img_path.name}", fontweight='bold', fontsize=10)
        axes[row, col].axis('off')
        
        sample_count += 1
    
    if sample_count >= 12:
        break

# Hide unused subplots
for idx in range(sample_count, 12):
    row = idx // 4
    col = idx % 4
    axes[row, col].axis('off')

plt.tight_layout()
plt.show()

## 4. YOLO Training

Now that the augmented dataset is ready, let's train the YOLO model. We'll use YOLOv8 from Ultralytics, a widely used real-time object detection architecture, starting from pre-trained weights.
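
Training reads its data locations and class names from `dataset.yaml`, which this notebook references but does not show. The sketch below generates text in the general shape Ultralytics expects (`path`, `train`, `val`, `names`). The class order, the `val/images` folder, and the `build_dataset_yaml` helper are assumptions for illustration; the class IDs must match the IDs used in your label files.

```python
def build_dataset_yaml(class_names, train_dir, val_dir, root="."):
    """Return the text of an Ultralytics-style dataset YAML file."""
    lines = [
        f"path: {root}",        # dataset root; train/val are relative to it
        f"train: {train_dir}",  # folder of training images
        f"val: {val_dir}",      # folder of validation images
        "names:",
    ]
    lines += [f"  {idx}: {name}" for idx, name in enumerate(class_names)]
    return "\n".join(lines) + "\n"


# Assumed class order; it must agree with the class IDs in train/labels/*.txt.
CLASS_NAMES = ["bird", "dog", "cat", "motorcycle", "car", "truck"]

# 'val/images' is a hypothetical folder of held-out images.
yaml_text = build_dataset_yaml(CLASS_NAMES, "train/images", "val/images")
print(yaml_text)
```

The sketch only prints the text, so it cannot overwrite an existing `dataset.yaml`.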

In [None]:
# Training configuration
TRAINING_CONFIG = {
    'model_size': 'n',        # Options: 'n', 's', 'm', 'l', 'x' (nano to extra-large)
    'epochs': 100,            # Number of training epochs
    'imgsz': 640,             # Image size for training
    'batch_size': 16,         # Batch size (adjust based on your GPU memory)
    'device': device,         # Device determined earlier
}

print("🎯 Training Configuration:")
for key, value in TRAINING_CONFIG.items():
    print(f"  {key}: {value}")
    
print("\n💡 Model Size Guide:")
print("  • 'n' (nano): Fastest, smallest, good for mobile/edge devices")
print("  • 's' (small): Good balance of speed and accuracy")
print("  • 'm' (medium): Better accuracy, moderate speed")
print("  • 'l' (large): High accuracy, slower inference")
print("  • 'x' (extra-large): Highest accuracy, slowest inference")

In [None]:
# Initialize the YOLO model
model_name = f"yolov8{TRAINING_CONFIG['model_size']}.pt"
print(f"🤖 Loading {model_name} model...")

# Load pre-trained YOLO model
model = YOLO(model_name)

print(f"✅ Model loaded successfully!")
print(f"   Model: YOLOv8{TRAINING_CONFIG['model_size']}")
print(f"   Parameters: {sum(p.numel() for p in model.model.parameters()):,}")
if os.path.exists(model_name):  # weights may be cached outside the working directory
    print(f"   Size on disk: {os.path.getsize(model_name) / (1024*1024):.1f} MB")

In [None]:
# Start training
print("🚀 Starting YOLO training...")
print("This will take some time. You can monitor the progress below.")
print("Training logs and checkpoints will be saved in 'runs/detect/drone_detection/'")
print("   (Ultralytics appends a number, e.g. 'drone_detection2', if that folder already exists)")
print("💡 Validation uses the images listed under 'val' in dataset.yaml; YOLO does not split the training set automatically")

# Train the model using the parameters from train_yolo.py
results = model.train(
    data='dataset.yaml',
    epochs=TRAINING_CONFIG['epochs'],
    imgsz=TRAINING_CONFIG['imgsz'],
    batch=TRAINING_CONFIG['batch_size'],
    device=TRAINING_CONFIG['device'],
    project='runs/detect',
    name='drone_detection',
    save_period=10,  # Save checkpoint every 10 epochs
    patience=20,     # Early stopping patience
    
    # `fraction` only limits how much of the training set is used.
    # It does NOT create a validation split; validation data comes from 'val' in dataset.yaml.
    fraction=0.8,    # Train on 80% of the training images
    
    # Augmentation settings (additional to our pre-generated augmentations)
    hsv_h=0.015,     # Hue augmentation
    hsv_s=0.7,       # Saturation augmentation  
    hsv_v=0.4,       # Value augmentation
    degrees=0,       # Don't add rotation (we already did this)
    translate=0.1,   # Translation augmentation
    scale=0.1,       # Additional scale augmentation
    shear=0.1,       # Shear augmentation
    perspective=0.0, # Perspective augmentation
    flipud=0.0,      # No vertical flip (objects have orientation)
    fliplr=0.0,      # No horizontal flip (for consistency)
    mosaic=0.8,      # Mosaic augmentation probability
    mixup=0.1,       # Mixup augmentation probability
    
    # Optimization
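    optimizer='AdamW',
    lr0=0.01,        # Initial learning rate (Ultralytics suggests about 1e-3 for Adam-family optimizers; lower this if training is unstable)
    lrf=0.1,         # Final learning rate (lr0 * lrf)
    momentum=0.937,  # Used as beta1 when the optimizer is AdamW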
    weight_decay=0.0005,
    warmup_epochs=3,
    warmup_momentum=0.8,
    warmup_bias_lr=0.1,
    
    # Other settings
    box=7.5,         # Box loss gain
    cls=0.5,         # Class loss gain
    dfl=1.5,         # DFL loss gain
    verbose=True,
)

print("\n🎉 Training completed!")

## 5. Results Analysis

Let's analyze the training results and visualize the model's performance.
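
Besides the plots, Ultralytics writes per-epoch metrics to a `results.csv` file in the run folder. The sketch below finds the epoch with the highest value of one metric using only the standard library. The column name `metrics/mAP50-95(B)` matches recent YOLOv8 releases but may differ in your version, so check the file's header; `best_epoch` is a hypothetical helper.

```python
import csv
from pathlib import Path


def best_epoch(csv_path, metric="metrics/mAP50-95(B)"):
    """Return (epoch, value) for the row with the highest `metric`, or None."""
    csv_path = Path(csv_path)
    if not csv_path.exists():
        return None
    with csv_path.open(newline="") as f:
        # Some versions pad headers and values with spaces, so strip both.
        rows = [
            {k.strip(): (v or "").strip() for k, v in row.items() if k is not None}
            for row in csv.DictReader(f)
        ]
    rows = [r for r in rows if r.get(metric)]
    if not rows:
        return None
    best = max(rows, key=lambda r: float(r[metric]))
    return int(float(best["epoch"])), float(best[metric])


# Adjust the run folder if Ultralytics added a numeric suffix to it.
print(best_epoch("runs/detect/drone_detection/results.csv"))
```

The call prints `None` when the file or the column is missing, which makes it safe to run before training has finished.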

In [None]:
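# Locate the most recent training run. Ultralytics appends a number to the
# run name when 'drone_detection' already exists (e.g. 'drone_detection4').
run_dirs = sorted(Path('runs/detect').glob('drone_detection*'), key=lambda p: p.stat().st_mtime)
results_dir = run_dirs[-1] if run_dirs else Path('runs/detect/drone_detection')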
if results_dir.exists():
    print(f"📂 Training results saved in: {results_dir}")
    print(f"📁 Contents:")
    for item in sorted(results_dir.iterdir()):
        if item.is_file():
            print(f"  📄 {item.name}")
        else:
            print(f"  📁 {item.name}/")
else:
    print("❌ Training results not found. Make sure training completed successfully.")

In [None]:
# Display training curves if available
results_image_path = results_dir / 'results.png'
if results_image_path.exists():
    print("📈 Training Results:")
    display(Image(str(results_image_path)))
else:
    print("📈 Training curves not found. They should be available after training completes.")

In [None]:
# Display confusion matrix if available
confusion_matrix_path = results_dir / 'confusion_matrix.png'
if confusion_matrix_path.exists():
    print("🎯 Confusion Matrix:")
    display(Image(str(confusion_matrix_path)))
else:
    print("🎯 Confusion matrix not found.")

In [None]:
# Load and display validation results
val_batch_path = results_dir / 'val_batch0_labels.jpg'
if val_batch_path.exists():
    print("🔍 Validation Batch with Ground Truth Labels:")
    display(Image(str(val_batch_path)))
    
val_pred_path = results_dir / 'val_batch0_pred.jpg'
if val_pred_path.exists():
    print("\n🤖 Validation Batch with Model Predictions:")
    display(Image(str(val_pred_path)))

if not val_batch_path.exists() and not val_pred_path.exists():
    print("🔍 Validation images not found.")

In [None]:
# Load the trained model and run validation
best_model_path = results_dir / 'weights' / 'best.pt'
if best_model_path.exists():
    print(f"🏆 Loading best model: {best_model_path}")
    trained_model = YOLO(str(best_model_path))
    
    # Run validation
    print("\n🔬 Running final validation...")
    val_results = trained_model.val()
    
    # Print key metrics
    print("\n📊 Final Model Metrics:")
    if hasattr(val_results, 'box'):
        metrics = val_results.box
        print(f"  mAP@0.5: {metrics.map50:.3f}")
        print(f"  mAP@0.5:0.95: {metrics.map:.3f}")
        print(f"  Precision: {metrics.mp:.3f}")
        print(f"  Recall: {metrics.mr:.3f}")
else:
    print("❌ Best model not found. Training may not have completed successfully.")

## 6. Model Testing

Let's run the trained model on a few sample images. These samples come from the training set, so treat the results as a sanity check rather than an unbiased measure of accuracy.
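
For a fairer check, some images can be held out from training. The sketch below shows one simple, reproducible way to split a list of image paths into training and validation subsets; `split_train_val` is a hypothetical helper and the 20% share is an assumption. Ultralytics accepts text files listing image paths for the `train` and `val` entries of a dataset YAML, so lists like these could be written out and referenced there.

```python
import random
from pathlib import Path


def split_train_val(image_paths, val_fraction=0.2, seed=0):
    """Shuffle deterministically and return (train_paths, val_paths)."""
    paths = sorted(image_paths)   # sort first so the result does not depend on input order
    random.Random(seed).shuffle(paths)
    n_val = int(round(len(paths) * val_fraction))
    return paths[n_val:], paths[:n_val]


# 'train/images' is the augmented image folder used in this notebook.
all_paths = [str(p) for p in Path("train/images").glob("*.jpg")]
train_paths, val_paths = split_train_val(all_paths)
print(f"train: {len(train_paths)} images, val: {len(val_paths)} images")
```

Because every augmented image derives from a small set of base photos, a split like this still measures performance on near-duplicates; new photos remain the more trustworthy test.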

In [None]:
# Test the model on some training images
if 'trained_model' in locals():
    print("🧪 Testing the trained model on sample images...")
    
    # Get some test images
    test_images = list(train_images_dir.glob('*_001.jpg'))[:6]  # First augmented image of each class
    
    if test_images:
        fig, axes = plt.subplots(2, 3, figsize=(18, 12))
        fig.suptitle('Model Predictions on Sample Images', fontsize=16, fontweight='bold')
        
        for idx, img_path in enumerate(test_images[:6]):
            row = idx // 3
            col = idx % 3
            
            # Run inference
            results = trained_model(str(img_path), verbose=False)
            
            # Get the annotated image
            annotated_img = results[0].plot()
            annotated_img_rgb = cv2.cvtColor(annotated_img, cv2.COLOR_BGR2RGB)
            
            axes[row, col].imshow(annotated_img_rgb)
            axes[row, col].set_title(f"{img_path.stem}", fontweight='bold')
            axes[row, col].axis('off')
        
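        # Hide unused subplots when fewer than 6 test images were found
        for ax in axes.ravel()[len(test_images):]:
            ax.axis('off')
        
        plt.tight_layout()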
        plt.show()
    else:
        print("❌ No test images found.")
else:
    print("❌ Trained model not available. Please complete the training step first.")

In [None]:
# Function to test model on a custom image
def test_on_custom_image(image_path, confidence_threshold=0.5):
    """Test the model on a custom image"""
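    # Inside a function, locals() holds only the function's own variables,
    # so look for the notebook-level model in globals() instead
    if 'trained_model' not in globals():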
        print("❌ Trained model not available. Please complete the training step first.")
        return
    
    if not Path(image_path).exists():
        print(f"❌ Image not found: {image_path}")
        return
    
    print(f"🔍 Testing on: {image_path}")
    
    # Run inference
    results = trained_model(image_path, conf=confidence_threshold, verbose=False)
    
    # Display results
    annotated_img = results[0].plot()
    annotated_img_rgb = cv2.cvtColor(annotated_img, cv2.COLOR_BGR2RGB)
    
    plt.figure(figsize=(12, 8))
    plt.imshow(annotated_img_rgb)
    plt.title(f'Detection Results - {Path(image_path).name}', fontweight='bold', fontsize=14)
    plt.axis('off')
    plt.tight_layout()
    plt.show()
    
    # Load dataset config to get class names
    with open('dataset.yaml', 'r') as f:
        dataset_config = yaml.safe_load(f)
    
    # Print detection details
    if len(results[0].boxes) > 0:
        print("\n🎯 Detections:")
        for i, box in enumerate(results[0].boxes):
            class_id = int(box.cls[0])
            confidence = float(box.conf[0])
            class_name = dataset_config['names'][class_id]
            print(f"  {i+1}. {class_name} (confidence: {confidence:.3f})")
    else:
        print("\n❌ No objects detected")

# Example usage (uncomment and modify the path to test on your own images)
# test_on_custom_image('path/to/your/test/image.jpg', confidence_threshold=0.3)

## 7. ONNX Conversion for Deployment

Now let's convert our trained PyTorch model to ONNX format for efficient deployment on devices like Raspberry Pi.
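
A square 320x320 model input and a 320x240 camera frame have different aspect ratios, so the frame has to be resized and padded before inference. The sketch below works through that arithmetic, assuming letterbox-style preprocessing (scale to fit, then pad evenly), which is the approach Ultralytics uses by default at inference time. Whether `convert_to_onnx.py` or the Pi-side code does the same is not shown in this notebook, and both function names are hypothetical.

```python
def letterbox_params(src_w, src_h, dst=320):
    """Return (scale, new_w, new_h, pad_left, pad_top) for a centered letterbox."""
    scale = min(dst / src_w, dst / src_h)          # fit the longer side
    new_w, new_h = round(src_w * scale), round(src_h * scale)
    pad_left = (dst - new_w) // 2                  # horizontal padding on the left
    pad_top = (dst - new_h) // 2                   # vertical padding on the top
    return scale, new_w, new_h, pad_left, pad_top


def to_frame_coords(x, y, scale, pad_left, pad_top):
    """Map a point from model-input pixels back to original frame pixels."""
    return (x - pad_left) / scale, (y - pad_top) / scale


scale, new_w, new_h, pad_left, pad_top = letterbox_params(320, 240)
print(f"scale={scale}, resized={new_w}x{new_h}, pad_left={pad_left}, pad_top={pad_top}")
print("model (160, 160) -> frame", to_frame_coords(160, 160, scale, pad_left, pad_top))
```

For a 320x240 frame the scale is 1.0 and 40 rows of padding go above and below the image, so detections from the ONNX model need the vertical offset removed before being drawn on the original frame.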

In [None]:
# Convert model to ONNX format optimized for Pi deployment
from convert_to_onnx import convert_model_to_onnx

if 'best_model_path' in locals() and best_model_path.exists():
    print("🚀 Converting trained model to ONNX...")
    
    # Convert with 320x320 input size (optimal for Pi camera 320x240)
    onnx_path = convert_model_to_onnx(
        model_path=str(best_model_path),
        imgsz=320,  # Optimized for Pi camera
        half=False,  # Keep FP32 for better Pi compatibility
        simplify=True
    )
    
    if onnx_path:
        print(f"\n✅ ONNX conversion successful!")
        print(f"📁 ONNX model saved: {onnx_path}")
        print(f"🥧 Ready for Raspberry Pi deployment!")
        
        # Test the ONNX model
        try:
            import onnxruntime as ort
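            # Pass the path as a string and name the provider explicitly
            session = ort.InferenceSession(str(onnx_path), providers=['CPUExecutionProvider'])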
            print(f"\n🔍 ONNX Model Info:")
            print(f"   Input shape: {session.get_inputs()[0].shape}")
            print(f"   Output shape: {session.get_outputs()[0].shape}")
            print(f"   Providers: {ort.get_available_providers()}")
        except ImportError:
            print(f"\n⚠️  Install onnxruntime to test: pip install onnxruntime")
    else:
        print("❌ ONNX conversion failed")
else:
    print("❌ No trained model found. Please complete training first.")

## 🎉 Congratulations!

You've successfully completed the YOLO training tutorial! Here's what you've accomplished:

### ✅ What You've Done:
1. **Environment Setup**: Configured all required libraries and dependencies
2. **Dataset Exploration**: Examined the original dataset structure and images
3. **Data Augmentation**: Generated hundreds of augmented training images with various transformations
4. **YOLO Training**: Trained a custom YOLOv8 model on your 6-class dataset
5. **Results Analysis**: Evaluated model performance with metrics and visualizations
6. **Model Testing**: Tested the trained model on sample images
7. **ONNX Conversion**: Converted the model for Pi deployment

### 📁 Important Files Created:
- `runs/detect/drone_detection/weights/best.pt` - Your best trained model (the run folder gets a numeric suffix, such as `drone_detection2`, on repeated runs)
- `runs/detect/drone_detection/weights/last.pt` - Last checkpoint
- `runs/detect/drone_detection/results.png` - Training curves
- `train/images/` - Augmented training images
- `train/labels/` - Corresponding YOLO format labels

### 🚀 Next Steps:
1. **Export to Other Formats**: Beyond ONNX, convert to formats such as TensorRT for deployment
2. **Create Validation Set**: Prepare a separate validation dataset, built from images not used in training, for final testing
3. **Optimize for Deployment**: Experiment with different model sizes and quantization
4. **Real-world Testing**: Test on actual drone footage or real-world scenarios

### 📚 Additional Resources:
- [Ultralytics YOLOv8 Documentation](https://docs.ultralytics.com/)
- [YOLO Model Export Guide](https://docs.ultralytics.com/modes/export/)
- [Advanced Training Techniques](https://docs.ultralytics.com/modes/train/)

---

**Happy detecting! 🎯🤖**

In [None]:
# Final summary
print("🎊 TUTORIAL COMPLETED SUCCESSFULLY! 🎊")
print("")
print("Your YOLO model is now trained and ready to use!")
print(f"Best model saved at: runs/detect/drone_detection/weights/best.pt")
print("")
print("To use your model in a Python script:")
print("")
print("```python")
print("from ultralytics import YOLO")
print("")
print("# Load your trained model")
print("model = YOLO('runs/detect/drone_detection/weights/best.pt')")
print("")
print("# Run inference on an image")
print("results = model('path/to/your/image.jpg')")
print("")
print("# Show results")
print("results[0].show()")
print("```")