# Poke Bowl Inventory - YOLO Model Training

This notebook trains a YOLOv8 model for detecting 40 different product classes in the Poke Bowl inventory system.

## Training Improvements:
- GPU acceleration (CUDA/MPS)
- Data augmentation (rotation, scaling, translation, shear, horizontal flip, HSV color jitter); mosaic, mixup and copy-paste are disabled to avoid tensor size mismatch errors
- Optimized hyperparameters
- Extended training epochs
- Learning rate scheduling
- Early stopping with patience
- Comprehensive validation metrics
- Model checkpointing

## 1. Setup and Installation

In [None]:
# Install required packages
!pip install ultralytics opencv-python pillow matplotlib seaborn pandas numpy torch torchvision

In [None]:
# Import libraries
import os
import yaml
import torch
import cv2
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from ultralytics import YOLO
from datetime import datetime
import shutil

# Global plotting defaults shared by every figure in this notebook.
# NOTE(review): the 'seaborn-v0_8-*' style names require matplotlib >= 3.6.
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette("husl")

# Report the framework version and which accelerator, if any, is visible.
print(f"PyTorch version: {torch.__version__}")
has_cuda = torch.cuda.is_available()
print(f"CUDA available: {has_cuda}")
if has_cuda:
    print(f"CUDA version: {torch.version.cuda}")
    print(f"GPU: {torch.cuda.get_device_name(0)}")
elif torch.backends.mps.is_available():
    print("MPS (Apple Silicon) available")
else:
    print("WARNING: No GPU detected, training will be slow!")

## 2. Dataset Analysis

In [None]:
# Every dataset location is resolved relative to the notebook's working directory.
PROJECT_ROOT = Path.cwd()
DATASET_PATH = PROJECT_ROOT.joinpath('dataset', 'pokebowl_dataset')
DATA_YAML = DATASET_PATH.joinpath('data.yaml')
TRAIN_IMAGES, VAL_IMAGES = (DATASET_PATH.joinpath('images', split) for split in ('train', 'val'))
TRAIN_LABELS, VAL_LABELS = (DATASET_PATH.joinpath('labels', split) for split in ('train', 'val'))

for caption, location in (
    ("Project root", PROJECT_ROOT),
    ("Dataset path", DATASET_PATH),
    ("Data YAML", DATA_YAML),
):
    print(f"{caption}: {location}")

print("\nChecking paths...")
for caption, location in (
    ("Train images", TRAIN_IMAGES),
    ("Val images", VAL_IMAGES),
    ("Train labels", TRAIN_LABELS),
    ("Val labels", VAL_LABELS),
):
    print(f"{caption} exist: {location.exists()}")

In [None]:
# Load and display the dataset configuration (the same data.yaml used for training).
with open(DATA_YAML, 'r', encoding='utf-8') as f:
    data_config = yaml.safe_load(f)

# Ultralytics accepts `names` either as a list or as an {id: name} mapping, and
# `nc` may be omitted. Normalise to a list ordered by class id and fill in `nc`
# so the analysis cells below can index class names positionally.
raw_names = data_config['names']
if isinstance(raw_names, dict):
    raw_names = [raw_names[k] for k in sorted(raw_names)]
data_config['names'] = list(raw_names)
data_config.setdefault('nc', len(data_config['names']))
if data_config['nc'] != len(data_config['names']):
    print(f"WARNING: nc={data_config['nc']} but {len(data_config['names'])} class names are listed")

print("Dataset Configuration:")
print(f"Number of classes: {data_config['nc']}")
print("\nClass names:")
for i, name in enumerate(data_config['names']):
    print(f"  {i}: {name}")

In [None]:
# Count images and labels. YOLO datasets are not limited to .jpg, and glob
# patterns are case-sensitive on Linux, so match common image suffixes
# case-insensitively instead of only '*.jpg'.
IMAGE_SUFFIXES = {'.jpg', '.jpeg', '.png', '.bmp', '.webp', '.tif', '.tiff'}


def _list_images(folder):
    """Return the image files directly inside *folder*, sorted by path (empty if missing)."""
    return sorted(
        p for p in folder.glob('*')
        if p.is_file() and p.suffix.lower() in IMAGE_SUFFIXES
    )


train_images = _list_images(TRAIN_IMAGES)
val_images = _list_images(VAL_IMAGES)
train_labels = sorted(TRAIN_LABELS.glob('*.txt'))
val_labels = sorted(VAL_LABELS.glob('*.txt'))
total_images = len(train_images) + len(val_images)

print(f"\nDataset Statistics:")
print(f"Training images: {len(train_images)}")
print(f"Training labels: {len(train_labels)}")
print(f"Validation images: {len(val_images)}")
print(f"Validation labels: {len(val_labels)}")
print(f"Total images: {total_images}")
# Guard against a config with zero classes to avoid ZeroDivisionError.
if data_config['nc']:
    print(f"\nImages per class (average): {total_images / data_config['nc']:.1f}")

In [None]:
# Analyze class distribution
def count_class_instances(label_files, num_classes):
    """Tally how many boxes of each class appear across YOLO label files.

    Args:
        label_files: Iterable of label file paths. On every non-blank line the
            first whitespace-separated token is parsed as an integer class id.
        num_classes: Length of the returned array. Ids outside
            ``[0, num_classes)`` are ignored.

    Returns:
        Integer ``np.ndarray`` of length ``num_classes`` where entry ``k`` is
        the number of boxes labelled with class ``k``.
    """
    seen_ids = []
    for path in label_files:
        with open(path, 'r') as handle:
            seen_ids.extend(
                int(tokens[0])
                for tokens in (row.split() for row in handle)
                if tokens
            )

    in_range = [cid for cid in seen_ids if 0 <= cid < num_classes]
    return np.bincount(np.asarray(in_range, dtype=int), minlength=num_classes).astype(int)

# Per-split and combined box counts for every class id.
num_classes = data_config['nc']
train_class_counts = count_class_instances(train_labels, num_classes)
val_class_counts = count_class_instances(val_labels, num_classes)
total_class_counts = train_class_counts + val_class_counts

table_header = f"{'Class ID':<10} {'Class Name':<40} {'Train':<8} {'Val':<8} {'Total':<8}"
print("\nClass Distribution:")
print(table_header)
print("-" * 80)
for class_id, class_name in enumerate(data_config['names']):
    n_train = train_class_counts[class_id]
    n_val = val_class_counts[class_id]
    n_total = total_class_counts[class_id]
    print(f"{class_id:<10} {class_name:<40} {n_train:<8} {n_val:<8} {n_total:<8}")

# Aggregate statistics over all classes.
print(f"\nTotal instances: {total_class_counts.sum()}")
print(f"Average instances per class: {total_class_counts.mean():.1f}")
print(f"Min instances: {total_class_counts.min()}")
print(f"Max instances: {total_class_counts.max()}")

In [None]:
# Two views of class balance: stacked train/val bars per class id (left) and
# all classes ranked by total count (right). Saved to class_distribution.png.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 8))

# Left: train bars, then val bars stacked on top of them.
class_ids = np.arange(len(data_config['names']))
for split_counts, split_label, stack_base in (
    (train_class_counts, 'Train', None),
    (val_class_counts, 'Val', train_class_counts),
):
    ax1.bar(class_ids, split_counts, bottom=stack_base, label=split_label, alpha=0.7)
ax1.set_xlabel('Class ID', fontsize=12)
ax1.set_ylabel('Number of Instances', fontsize=12)
ax1.set_title('Class Distribution', fontsize=14, fontweight='bold')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Right: classes in descending order of total count (largest at row 0).
rank_order = np.argsort(total_class_counts)[::-1]
ranked_names = [data_config['names'][k] for k in rank_order]
rows = np.arange(len(ranked_names))
ax2.barh(rows, total_class_counts[rank_order], alpha=0.7)
ax2.set_yticks(rows)
ax2.set_yticklabels(ranked_names, fontsize=8)
ax2.set_xlabel('Total Instances', fontsize=12)
ax2.set_title('Classes Sorted by Frequency', fontsize=14, fontweight='bold')
ax2.grid(True, alpha=0.3, axis='x')

plt.tight_layout()
plt.savefig('class_distribution.png', dpi=150, bbox_inches='tight')
plt.show()

print("\nClass distribution plot saved as 'class_distribution.png'")

In [None]:
# Visualize sample images with annotations
def visualize_sample_images(image_files, label_files, class_names, num_samples=6):
    """Show a random subset of images with their YOLO ground-truth boxes drawn.

    Args:
        image_files: Sequence of image paths (``pathlib.Path``) to sample from.
        label_files: Iterable of YOLO ``.txt`` label paths. Each image is paired
            with the first label file whose stem matches the image's stem.
        class_names: Class names indexed by class id.
        num_samples: Maximum number of images to draw. It is capped at
            ``len(image_files)``, and images are laid out three per row.

    Side effects:
        Saves the figure to ``sample_images.png`` and calls ``plt.show()``.
        When there are no images it only prints a message.
    """
    n_show = min(num_samples, len(image_files))
    if n_show <= 0:
        print("No images available to visualize.")
        return
    samples = np.random.choice(len(image_files), n_show, replace=False)

    # Build the stem -> label lookup once (first match wins, as in a linear
    # scan) instead of searching every label file for every sampled image.
    labels_by_stem = {}
    for lbl in label_files:
        labels_by_stem.setdefault(lbl.stem, lbl)

    # Size the grid from the sample count so num_samples > 6 cannot overflow
    # a fixed 2x3 layout.
    n_cols = 3
    n_rows = (n_show + n_cols - 1) // n_cols
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(18, 6 * n_rows), squeeze=False)
    axes = axes.ravel()

    for ax, sample_idx in zip(axes, samples):
        img_path = image_files[sample_idx]
        img = cv2.imread(str(img_path))
        if img is None:
            # cv2.imread returns None (rather than raising) for unreadable files.
            ax.set_title(f"{img_path.name} (unreadable)", fontsize=10)
            ax.axis('off')
            continue
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        h, w = img.shape[:2]

        label_path = labels_by_stem.get(img_path.stem)
        if label_path is not None and label_path.exists():
            _draw_yolo_boxes(img, label_path, class_names, w, h)

        ax.imshow(img)
        ax.set_title(f"{img_path.name}", fontsize=10)
        ax.axis('off')

    # Hide grid cells left over when n_show is not a multiple of n_cols.
    for ax in axes[n_show:]:
        ax.axis('off')

    plt.tight_layout()
    plt.savefig('sample_images.png', dpi=150, bbox_inches='tight')
    plt.show()


def _draw_yolo_boxes(img, label_path, class_names, w, h):
    """Draw every box from a YOLO label file onto the RGB image *img* in place."""
    with open(label_path, 'r') as f:
        for line in f:
            parts = line.strip().split()
            if len(parts) < 5:
                continue
            class_id = int(parts[0])
            x_center, y_center, width, height = map(float, parts[1:5])

            # YOLO stores normalised centre/size; convert to pixel corners.
            x1 = int((x_center - width / 2) * w)
            y1 = int((y_center - height / 2) * h)
            x2 = int((x_center + width / 2) * w)
            y2 = int((y_center + height / 2) * h)

            color = tuple(np.random.randint(0, 255, 3).tolist())
            cv2.rectangle(img, (x1, y1), (x2, y2), color, 2)

            # Reject negative ids too: they would silently index from the end
            # of a list of names.
            if 0 <= class_id < len(class_names):
                label_text = f"{class_names[class_id]}"
            else:
                label_text = f"Class {class_id}"

            (text_w, text_h), _ = cv2.getTextSize(label_text, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
            # Keep the caption inside the image when the box touches the top edge.
            y_text = max(y1, text_h + 4)
            cv2.rectangle(img, (x1, y_text - text_h - 4), (x1 + text_w, y_text), color, -1)
            cv2.putText(img, label_text, (x1, y_text - 2), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)

# Draw six random training images with their ground-truth boxes.
print("Visualizing sample training images...")
visualize_sample_images(
    image_files=train_images,
    label_files=train_labels,
    class_names=data_config['names'],
    num_samples=6,
)
print("Sample images saved as 'sample_images.png'")

## 3. Training Configuration

In [None]:
# Determine device.
# (This cell was previously collapsed onto one line. That turned everything
# after the first comment into a comment and left `device`/`TRAINING_CONFIG`
# undefined, so it is restored to normal formatting here.)
if torch.cuda.is_available():
    device = 0  # Use first GPU
    print(f"Using GPU: {torch.cuda.get_device_name(0)}")
elif torch.backends.mps.is_available():
    device = 'mps'  # Use Apple Silicon GPU
    print("Using MPS (Apple Silicon GPU)")
else:
    device = 'cpu'
    print("WARNING: Using CPU - training will be very slow!")

# Training hyperparameters (FIXED for tensor size mismatch).
# Passed verbatim to model.train(**TRAINING_CONFIG) in the training cell.
TRAINING_CONFIG = {
    # Model
    'model': 'yolov8n.pt',  # Nano model (fastest, good for edge devices)

    # Dataset
    'data': str(DATA_YAML),

    # Training duration
    'epochs': 200,
    'patience': 50,

    # Batch and image size
    'batch': 8,  # Reduced from 16
    'imgsz': 640,

    # Device
    'device': device,
    'workers': 4,  # Reduced from 8

    # Optimization
    'optimizer': 'AdamW',
    'lr0': 0.001,
    'lrf': 0.01,
    'momentum': 0.937,
    'weight_decay': 0.0005,
    'warmup_epochs': 3.0,
    'warmup_momentum': 0.8,
    'warmup_bias_lr': 0.1,

    # Data augmentation (REDUCED to fix errors)
    'hsv_h': 0.015,
    'hsv_s': 0.7,
    'hsv_v': 0.4,
    'degrees': 5.0,      # Reduced from 10
    'translate': 0.1,    # Reduced from 0.2
    'scale': 0.5,        # Reduced from 0.9
    'shear': 2.0,        # Reduced from 5.0
    'perspective': 0.0,  # Disabled
    'flipud': 0.0,
    'fliplr': 0.5,
    'mosaic': 0.0,       # DISABLED (causes tensor issues)
    'mixup': 0.0,        # DISABLED (causes tensor issues)
    'copy_paste': 0.0,   # DISABLED (causes tensor issues)
    # NOTE(review): Ultralytics documents auto_augment and erasing as
    # classification-training options; they are likely ignored for detection.
    'auto_augment': 'randaugment',
    'erasing': 0.4,
    'close_mosaic': 10,  # no effect while mosaic is 0.0

    # Loss weights
    'box': 7.5,
    'cls': 0.5,
    'dfl': 1.5,

    # Validation
    'val': True,
    'plots': True,
    'save': True,
    'save_period': 10,

    # Other
    'cache': False,
    'amp': True,
    'pretrained': True,
    'verbose': True,
    'seed': 42,
    'deterministic': True,
    'rect': False,  # Disabled for stability

    # Project settings (run name is timestamped once, when this cell runs)
    'project': 'runs/train',
    'name': f'pokebowl_yolov8n_{datetime.now().strftime("%Y%m%d_%H%M%S")}',
    'exist_ok': True,
}

print("\nTraining Configuration:")
print("=" * 60)
for key, value in TRAINING_CONFIG.items():
    print(f"{key:<20}: {value}")
print("=" * 60)
print("\n⚠️  NOTE: mosaic, mixup and copy_paste disabled to fix tensor size mismatch error")
print("This is normal for datasets with varying annotation counts per image.")

## 4. Model Training

In [None]:
# Build the YOLO model from the checkpoint named in TRAINING_CONFIG and
# report its size.
model_spec = TRAINING_CONFIG['model']
print(f"\nInitializing YOLO model: {model_spec}")
model = YOLO(model_spec)

n_params = sum(param.numel() for param in model.model.parameters())
print("Model loaded successfully!")
print(f"Model type: {type(model.model)}")
print(f"Model parameters: {n_params:,}")

In [None]:
# Launch training with every setting from TRAINING_CONFIG, bracketed by
# banners that echo the key run parameters.
banner = "=" * 80
print(f"\n{banner}\nSTARTING TRAINING\n{banner}")
for run_info in (
    f"Training will run for up to {TRAINING_CONFIG['epochs']} epochs",
    f"Early stopping patience: {TRAINING_CONFIG['patience']} epochs",
    f"Device: {TRAINING_CONFIG['device']}",
    f"Batch size: {TRAINING_CONFIG['batch']}",
    f"Image size: {TRAINING_CONFIG['imgsz']}",
):
    print(run_info)
print(f"{banner}\n")

# Keep whatever model.train returns for this run in `results`.
results = model.train(**TRAINING_CONFIG)

print(f"\n{banner}\nTRAINING COMPLETED!\n{banner}")

## 5. Training Results Analysis

In [None]:
# Locate this run's output directory.
# Prefer the directory the trainer reports (`model.trainer.save_dir`), so this
# stays correct even if Ultralytics resolves project/name differently from a
# plain path join. Fall back to the configured location when no trainer is
# attached, e.g. when the training cell was not run in this session.
trainer = getattr(model, 'trainer', None)
trainer_dir = getattr(trainer, 'save_dir', None)
if trainer_dir:
    results_dir = Path(trainer_dir)
else:
    results_dir = Path(TRAINING_CONFIG['project']) / TRAINING_CONFIG['name']
print(f"Results directory: {results_dir}")
print(f"Results exist: {results_dir.exists()}")

if results_dir.exists():
    print("\nGenerated files:")
    for file in sorted(results_dir.glob('*')):
        if file.is_file():
            size = file.stat().st_size / (1024 * 1024)  # MB
            print(f"  {file.name:<40} {size:>8.2f} MB")

In [None]:
# Show the run's results.png summary image, if the run directory has one.
results_img = results_dir / 'results.png'
if not results_img.exists():
    print("Results plot not found!")
else:
    curves = plt.imread(str(results_img))
    plt.figure(figsize=(20, 12))
    plt.imshow(curves)
    plt.axis('off')
    plt.title('Training Results', fontsize=16, fontweight='bold')
    plt.tight_layout()
    plt.show()

In [None]:
# Render the run's confusion_matrix.png (if present) on a large square canvas.
cm_path = results_dir / 'confusion_matrix.png'
if not cm_path.exists():
    print("Confusion matrix not found!")
else:
    cm_img = plt.imread(str(cm_path))
    plt.figure(figsize=(20, 20))
    plt.imshow(cm_img)
    plt.axis('off')
    plt.title('Confusion Matrix', fontsize=16, fontweight='bold')
    plt.tight_layout()
    plt.show()

In [None]:
# Preview up to three of the val_batch*_pred.jpg images in the run directory,
# in sorted filename order.
val_batch_imgs = sorted(results_dir.glob('val_batch*_pred.jpg'))
if not val_batch_imgs:
    print("No validation batch predictions found!")
else:
    print(f"Found {len(val_batch_imgs)} validation batch prediction images\n")
    for batch_no, batch_path in enumerate(val_batch_imgs[:3]):
        batch_img = plt.imread(str(batch_path))
        plt.figure(figsize=(20, 12))
        plt.imshow(batch_img)
        plt.axis('off')
        plt.title(f'Validation Batch {batch_no} Predictions', fontsize=16, fontweight='bold')
        plt.tight_layout()
        plt.show()

## 6. Model Validation

In [None]:
# Reload the best checkpoint from this run and re-run validation on it.
best_model_path = results_dir / 'weights' / 'best.pt'
print(f"Loading best model from: {best_model_path}")

# Bound up front so later cells can check `metrics is None` instead of
# hitting a NameError when no checkpoint was produced.
metrics = None

if best_model_path.exists():
    best_model = YOLO(str(best_model_path))
    print("Best model loaded successfully!")

    # Validate with the same image size, batch size and device as training.
    # The config reduced batch from 16 to 8 (presumably for memory), so a
    # hard-coded batch=16 here could fail where training succeeded.
    print("\nRunning validation...")
    metrics = best_model.val(
        data=str(DATA_YAML),
        imgsz=TRAINING_CONFIG['imgsz'],
        batch=TRAINING_CONFIG['batch'],
        device=TRAINING_CONFIG['device'],
    )

    print("\n" + "="*80)
    print("VALIDATION METRICS")
    print("="*80)
    print(f"mAP50: {metrics.box.map50:.4f}")
    print(f"mAP50-95: {metrics.box.map:.4f}")
    print(f"Precision: {metrics.box.mp:.4f}")
    print(f"Recall: {metrics.box.mr:.4f}")
    print("="*80)
else:
    print("Best model not found!")

## 7. Test Predictions on Sample Images

In [None]:
# Qualitative check: run the best model on up to six random validation images.
if best_model_path.exists():
    n_show = min(6, len(val_images))
    # Sample indices rather than passing the Path list to numpy, so each pick
    # stays a plain pathlib.Path.
    picks = np.random.choice(len(val_images), n_show, replace=False)
    test_images = [val_images[k] for k in picks]

    fig, axes = plt.subplots(2, 3, figsize=(20, 14))
    axes = axes.ravel()

    for ax, img_path in zip(axes, test_images):
        # Use a distinct name so the training-run `results` from the training
        # cell is not overwritten.
        preds = best_model.predict(
            str(img_path), conf=0.25, iou=0.45,
            imgsz=TRAINING_CONFIG['imgsz'], device=TRAINING_CONFIG['device'],
        )

        # Results.plot() output is converted from BGR to RGB for matplotlib.
        annotated_img = cv2.cvtColor(preds[0].plot(), cv2.COLOR_BGR2RGB)

        ax.imshow(annotated_img)
        ax.set_title(f"{img_path.name}\nDetections: {len(preds[0].boxes)}", fontsize=10)
        ax.axis('off')

    # Blank out the unused grid cells when fewer than six images are available.
    for ax in axes[n_show:]:
        ax.axis('off')

    plt.suptitle('Model Predictions on Validation Images', fontsize=16, fontweight='bold')
    plt.tight_layout()
    plt.savefig('test_predictions.png', dpi=150, bbox_inches='tight')
    plt.show()
    print("Test predictions saved as 'test_predictions.png'")
else:
    print("Model not available for testing!")

## 8. Export Best Model

In [None]:
# Publish the best checkpoint as PROJECT_ROOT/best.pt. Any existing best.pt is
# first copied to a timestamped backup. The new file is then reloaded as a
# sanity check.
if not best_model_path.exists():
    print("Best model not found!")
else:
    destination = PROJECT_ROOT / 'best.pt'

    if destination.exists():
        stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        backup_path = PROJECT_ROOT / f'best_backup_{stamp}.pt'
        shutil.copy(destination, backup_path)
        print(f"Old model backed up to: {backup_path}")

    shutil.copy(best_model_path, destination)
    print(f"\nNew best model copied to: {destination}")
    size_mb = destination.stat().st_size / (1024*1024)
    print(f"Model size: {size_mb:.2f} MB")

    # Reload the copied file to confirm YOLO can open it.
    print("\nVerifying model...")
    verify_model = YOLO(str(destination))
    class_list = list(verify_model.names.values())
    print(f"Model classes: {len(verify_model.names)}")
    print(f"Model names: {class_list[:5]}...")  # Show first 5
    print("\n✓ Model verification successful!")

## 9. Training Summary

In [None]:
# Final recap: dataset, training configuration, validation metrics (when a best
# checkpoint was evaluated) and where the outputs live.
print("\n" + "="*80)
print("TRAINING SUMMARY")
print("="*80)
print(f"\nDataset:")
print(f"  - Training images: {len(train_images)}")
print(f"  - Validation images: {len(val_images)}")
print(f"  - Number of classes: {data_config['nc']}")
print(f"  - Total instances: {total_class_counts.sum()}")

print(f"\nTraining Configuration:")
print(f"  - Model: {TRAINING_CONFIG['model']}")
print(f"  - Epochs: {TRAINING_CONFIG['epochs']}")
print(f"  - Batch size: {TRAINING_CONFIG['batch']}")
print(f"  - Image size: {TRAINING_CONFIG['imgsz']}")
print(f"  - Device: {TRAINING_CONFIG['device']}")
print(f"  - Optimizer: {TRAINING_CONFIG['optimizer']}")
print(f"  - Learning rate: {TRAINING_CONFIG['lr0']}")

print(f"\nAugmentation:")
print(f"  - Mosaic: {TRAINING_CONFIG['mosaic']}")
print(f"  - Mixup: {TRAINING_CONFIG['mixup']}")
print(f"  - Copy-paste: {TRAINING_CONFIG['copy_paste']}")
print(f"  - Rotation: ±{TRAINING_CONFIG['degrees']}°")
print(f"  - Scale: ±{TRAINING_CONFIG['scale']}")
print(f"  - Translation: ±{TRAINING_CONFIG['translate']}")

# `metrics` only exists if the validation cell ran with a checkpoint present,
# so look it up defensively instead of risking a NameError.
final_metrics = globals().get('metrics')
if best_model_path.exists() and final_metrics is not None:
    print(f"\nFinal Metrics:")
    print(f"  - mAP50: {final_metrics.box.map50:.4f}")
    print(f"  - mAP50-95: {final_metrics.box.map:.4f}")
    print(f"  - Precision: {final_metrics.box.mp:.4f}")
    print(f"  - Recall: {final_metrics.box.mr:.4f}")

# The export cell only binds `destination` when best.pt existed, so rebuild
# the exported path here rather than depending on that name.
exported_model = PROJECT_ROOT / 'best.pt'
model_exported = best_model_path.exists() and exported_model.exists()

print(f"\nOutput Files:")
print(f"  - Results directory: {results_dir}")
print(f"  - Best model: {exported_model}" + ("" if model_exported else " (not exported)"))
print(f"  - Training plots: {results_dir / 'results.png'}")
print(f"  - Confusion matrix: {results_dir / 'confusion_matrix.png'}")

print("\n" + "="*80)
print("TRAINING COMPLETE! 🎉")
print("="*80)
if model_exported:
    print("\nThe new model has been saved to 'best.pt' and is ready for deployment.")
    print("You can now use this model in your Poke Bowl Inventory System.")
else:
    print("\nNo model was exported to 'best.pt' - check the training and export cells above.")

## 10. Next Steps

After training is complete:

1. **Test the model** in the actual system:
   ```bash
   cd backend
   python3 main.py
   ```

2. **Monitor performance** on real camera feed

3. **Collect more data** if needed:
   - Focus on underrepresented classes
   - Add images with different lighting conditions
   - Include various angles and distances

4. **Fine-tune** if necessary:
   - Adjust confidence threshold in `config/config.yaml`
   - Modify IoU threshold for better detection
   - Retrain with more epochs if underfitting

5. **Consider upgrading** to a larger model (yolov8s or yolov8m) if:
   - You have more GPU memory
   - You need better accuracy
   - The slower inference speed of a larger model is still acceptable for your use case