# YOLOv5 Training on COCO128

This notebook contains the complete training pipeline for YOLOv5 on the COCO128 dataset.

## M1 Metal (MPS) Acceleration

This notebook is configured to use Apple's Metal Performance Shaders (MPS) backend for GPU acceleration on M1/M2 Macs.

**Requirements:**
- PyTorch >= 1.12.0 (MPS support)
- macOS 12.3+ (for MPS)

The notebook automatically selects the first available device, in this order:
1. **MPS** (Metal) if available (Apple Silicon)
2. **CUDA** if available (NVIDIA GPU)
3. **CPU** as fallback

## 1. Setup and Imports

In [None]:
import os
import sys
from pathlib import Path

# Put the notebook's parent directory on sys.path so the two project modules
# imported below can be resolved (presumably they live there - verify against
# the repository layout). The path is derived from the current working
# directory, so it depends on where the kernel was started.
parent_dir = Path.cwd().parent
sys.path.append(str(parent_dir))

# Star imports: later cells use torch, cv2, DataLoader, YOLODataset, collate_fn,
# YOLOv5 and YOLOLoss without importing them explicitly, so those names must be
# re-exported by these two modules (which module provides which is not visible here).
from minimal_yolov5 import *
from train_yolov5 import *
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from tqdm import tqdm
import numpy as np

# Reached only if every import above succeeded.
print("Imports successful!")

In [None]:
# Check PyTorch and MPS availability.
# torch.backends.mps only exists on PyTorch builds that know about MPS (the
# requirements at the top of the notebook list 1.12+). Look it up defensively so
# that an older build prints the upgrade hint below instead of raising
# AttributeError.
mps_backend = getattr(torch.backends, "mps", None)
mps_available = mps_backend is not None and mps_backend.is_available()
mps_built = mps_backend is not None and mps_backend.is_built()

print(f"PyTorch version: {torch.__version__}")
print(f"MPS available: {mps_available}")
print(f"MPS built: {mps_built}")

if mps_available:
    print("\n✓ MPS (Metal) acceleration is available and will be used!")
    print("  Your M1 GPU will accelerate training significantly.")
elif torch.cuda.is_available():
    # The model-setup cell prefers CUDA over CPU when MPS is missing, so the
    # message here has to say the same thing.
    print("\n✗ MPS not available. Will use CUDA.")
else:
    print("\n✗ MPS not available. Will use CPU.")
    print("  To enable MPS, upgrade PyTorch: pip install --upgrade torch torchvision")

## 2. Dataset Setup

In [None]:
# Setup paths for COCO128 dataset.
# The dataset location is machine-specific, so it can be overridden through the
# COCO128_ROOT environment variable; the previously hardcoded path stays the
# default, which keeps existing setups working unchanged.
dataset_root = Path(
    os.environ.get("COCO128_ROOT", "/Users/davide/Documents/Learning/Datasets/coco128")
).expanduser()
img_dir = dataset_root / "images" / "train2017"
label_dir = dataset_root / "labels" / "train2017"

print(f"Dataset root: {dataset_root}")
print(f"Images dir: {img_dir}")
print(f"Labels dir: {label_dir}")
print(f"\nImages exist: {img_dir.exists()}")
print(f"Labels exist: {label_dir.exists()}")

# Count files (only for directories that are actually present)
if img_dir.exists():
    num_images = len(list(img_dir.glob("*.jpg")))
    print(f"\nNumber of images: {num_images}")
if label_dir.exists():
    num_labels = len(list(label_dir.glob("*.txt")))
    print(f"Number of labels: {num_labels}")

# Point the reader at the override instead of letting a later cell fail obscurely
if not (img_dir.exists() and label_dir.exists()):
    print("\nDataset not found - set COCO128_ROOT to the folder that contains images/ and labels/")

In [None]:
# Build the COCO128 dataset, split it, and wrap both parts in dataloaders
print("Creating COCO128 dataset...")
dataset = YOLODataset(img_dir=str(img_dir), label_dir=str(label_dir), img_size=640, augment=False)

# 80/20 train/val split; the validation part takes whatever the rounding leaves over
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size

# A fixed generator seed keeps the split identical from run to run
split_generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset, [train_size, val_size], generator=split_generator
)

print(f"\nTrain set: {len(train_dataset)} images")
print(f"Val set: {len(val_dataset)} images")

# The two loaders differ only in the data they wrap and in whether they shuffle.
# num_workers stays at 0 for notebook use.
batch_size = 4
loader_kwargs = dict(batch_size=batch_size, collate_fn=collate_fn, num_workers=0)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

print(f"\nTrain batches: {len(train_loader)}")
print(f"Val batches: {len(val_loader)}")

In [None]:
# Smoke test: pull a single batch out of the training loader
print("Loading a sample batch...")
sample_iter = iter(train_loader)
imgs, targets = next(sample_iter)

print(f"\nBatch images shape: {imgs.shape}")
print(f"Batch targets shape: {targets.shape}")
print(f"Number of objects in batch: {len(targets)}")

# Peek at the leading rows; the column order is the one named in the message
print("\nFirst 5 targets (img_idx, class, x, y, w, h):")
print(targets[:5])

## 3. Model Setup

In [None]:
# Pick the compute device: MPS first, then CUDA, then CPU
if torch.backends.mps.is_available():
    device_name, device_label = "mps", "MPS (Apple M1 Metal)"
elif torch.cuda.is_available():
    device_name, device_label = "cuda", "CUDA"
else:
    device_name, device_label = "cpu", "CPU"
device = torch.device(device_name)
print(f"Using device: {device_label}")

# Build the network and move it to the selected device
print("\nCreating model...")
model = YOLOv5(num_classes=80, channels=3).to(device)

num_params = sum(param.numel() for param in model.parameters())
print(f"Model parameters: {num_params:,}")

# SGD with momentum and weight decay over all model parameters
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.937, weight_decay=0.0005)

# Cosine learning-rate schedule. T_max=100 equals the epoch count used by the
# training cell, which steps the scheduler once per epoch.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)

# The loss object is constructed from the model itself
loss_fn = YOLOLoss(model)

print("\nSetup complete!")

## 4. Training Loop

In [None]:
# Training loop: one pass over train_loader and one over val_loader per epoch,
# recording the average total loss of each in `history`.
num_epochs = 100
history = {'train_loss': [], 'val_loss': []}

# Layer types that behave differently in train and eval mode. Only these are
# switched to eval for validation (see the validation section below).
eval_only_layers = (
    torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d,
    torch.nn.Dropout, torch.nn.Dropout2d, torch.nn.Dropout3d,
)

print(f"Starting training for {num_epochs} epochs...")
print("=" * 80)

for epoch in range(num_epochs):
    print(f"\nEpoch {epoch+1}/{num_epochs}")
    print("-" * 80)
    
    # ========== Training ==========
    # model.train() also restores the layers that validation put into eval mode
    # during the previous epoch.
    model.train()
    train_loss = 0
    train_box_loss = 0
    train_obj_loss = 0
    train_cls_loss = 0
    
    pbar = tqdm(train_loader, desc='Training')
    for imgs, targets in pbar:
        # Scale pixel values to [0, 1] after moving the batch to the device
        imgs = imgs.to(device).float() / 255.0
        targets = targets.to(device)
        
        # Forward pass
        optimizer.zero_grad()
        predictions = model(imgs)
        # loss_items is indexed as (box, obj, cls) throughout this cell
        loss, loss_items = loss_fn(predictions, targets)
        
        # Backward pass
        loss.backward()
        optimizer.step()
        
        # Track losses
        train_loss += loss.item()
        train_box_loss += loss_items[0].item()
        train_obj_loss += loss_items[1].item()
        train_cls_loss += loss_items[2].item()
        
        # Update progress bar
        pbar.set_postfix({
            'loss': f'{loss.item():.4f}',
            'box': f'{loss_items[0]:.4f}',
            'obj': f'{loss_items[1]:.4f}',
            'cls': f'{loss_items[2]:.4f}'
        })
    
    # Average training losses over the number of batches
    train_loss /= len(train_loader)
    train_box_loss /= len(train_loader)
    train_obj_loss /= len(train_loader)
    train_cls_loss /= len(train_loader)
    
    # ========== Validation ==========
    # The loss function expects predictions in training format (list of tensors),
    # not inference format, so the model as a whole stays in training mode.
    # Normalisation/dropout layers are switched to eval one by one instead: in
    # training mode BatchNorm updates its running statistics on every forward
    # pass - torch.no_grad() does not prevent that - so validating in full train
    # mode would fold validation data into the weights that are later saved.
    # NOTE(review): assumes the model chooses its output format from the training
    # flag of the model/head modules, which remain in training mode - confirm
    # against minimal_yolov5.
    for layer in model.modules():
        if isinstance(layer, eval_only_layers):
            layer.eval()
    
    val_loss = 0
    val_box_loss = 0
    val_obj_loss = 0
    val_cls_loss = 0
    
    # No gradients are needed anywhere in the validation pass
    with torch.no_grad():
        for imgs, targets in tqdm(val_loader, desc='Validation'):
            imgs = imgs.to(device).float() / 255.0
            targets = targets.to(device)
            
            predictions = model(imgs)
            loss, loss_items = loss_fn(predictions, targets)
            
            val_loss += loss.item()
            val_box_loss += loss_items[0].item()
            val_obj_loss += loss_items[1].item()
            val_cls_loss += loss_items[2].item()
    
    # Leave every layer back in training mode, also after the final epoch
    model.train()
    
    # Average validation losses over the number of batches
    val_loss /= len(val_loader)
    val_box_loss /= len(val_loader)
    val_obj_loss /= len(val_loader)
    val_cls_loss /= len(val_loader)
    
    # Store history
    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)
    
    # Print epoch summary (the LR shown is the one used during this epoch,
    # because the scheduler is stepped only afterwards)
    print(f"\nEpoch {epoch+1} Summary:")
    print(f"  Train Loss: {train_loss:.4f} (box: {train_box_loss:.4f}, obj: {train_obj_loss:.4f}, cls: {train_cls_loss:.4f})")
    print(f"  Val Loss:   {val_loss:.4f} (box: {val_box_loss:.4f}, obj: {val_obj_loss:.4f}, cls: {val_cls_loss:.4f})")
    print(f"  LR: {optimizer.param_groups[0]['lr']:.6f}")
    
    # Step scheduler
    scheduler.step()

print("\n" + "=" * 80)
print("Training complete!")

## 5. Visualize Training Results

In [None]:
# Plot training history: loss curves on the left, validation-minus-training gap on the right
fig, (loss_ax, gap_ax) = plt.subplots(1, 2, figsize=(12, 5))

# Both panels share one 1-based epoch axis, so a point at x=k always means epoch k
# (previously the left panel was plotted against a 0-based list index).
epochs = range(1, len(history['train_loss']) + 1)

# Plot loss curves
loss_ax.plot(epochs, history['train_loss'], label='Train Loss', marker='o')
loss_ax.plot(epochs, history['val_loss'], label='Val Loss', marker='s')
loss_ax.set_xlabel('Epoch')
loss_ax.set_ylabel('Loss')
loss_ax.set_title('Training and Validation Loss')
loss_ax.legend()
loss_ax.grid(True)

# Plot loss difference; values above the dashed zero line mean validation loss
# exceeds training loss
diff = [v - t for t, v in zip(history['train_loss'], history['val_loss'])]
gap_ax.plot(epochs, diff, marker='o', color='purple')
gap_ax.axhline(y=0, color='r', linestyle='--', alpha=0.3)
gap_ax.set_xlabel('Epoch')
gap_ax.set_ylabel('Val Loss - Train Loss')
gap_ax.set_title('Overfitting Monitor')
gap_ax.grid(True)

fig.tight_layout()
plt.show()

# Guard the [-1] lookups: the history is empty until at least one epoch has finished
if history['train_loss'] and history['val_loss']:
    print(f"\nFinal Train Loss: {history['train_loss'][-1]:.4f}")
    print(f"Final Val Loss: {history['val_loss'][-1]:.4f}")
else:
    print("\nNo completed epochs in history - run the training cell first.")

## 6. COCO Class Names

In [None]:
# COCO dataset class names (80 classes), in class-index order.
# Laid out eight per row; the trailing comment gives the index range of each row.
COCO_CLASSES = [
    "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck",  # 0-7
    "boat", "traffic light", "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat",  # 8-15
    "dog", "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe",  # 16-23
    "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard",  # 24-31
    "sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle",  # 32-39
    "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple",  # 40-47
    "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake",  # 48-55
    "chair", "couch", "potted plant", "bed", "dining table", "toilet", "tv", "laptop",  # 56-63
    "mouse", "remote", "keyboard", "cell phone", "microwave", "oven", "toaster", "sink",  # 64-71
    "refrigerator", "book", "clock", "vase", "scissors", "teddy bear", "hair drier", "toothbrush",  # 72-79
]

class_count = len(COCO_CLASSES)
print(f"COCO dataset has {class_count} classes")
print(f"First 10 classes: {COCO_CLASSES[:10]}")

## 7. Debug Prediction Statistics

In [None]:
# Debug: summarise the objectness scores the model produces on one validation batch
model.eval()

with torch.no_grad():
    # First validation batch, scaled the same way as in the training loop
    imgs, targets = next(iter(val_loader))
    imgs = imgs.to(device).float() / 255.0
    
    # First element of the eval-mode output; column 4 is read as the objectness score.
    # NOTE(review): the earlier comment gave the shape as (N, 85) with
    # N = batch_size * 25200 - not verifiable from this notebook, confirm against the model.
    pred = model(imgs)[0]
    objectness = pred[:, 4]
    
    print(f"Predictions shape: {pred.shape}")
    print("\nObjectness confidence statistics:")
    summary = (
        ("Min", objectness.min()),
        ("Max", objectness.max()),
        ("Mean", objectness.mean()),
        ("Median", objectness.median()),
    )
    for stat_name, stat_value in summary:
        print(f"  {stat_name}: {stat_value.item():.6f}")
    
    # How many predictions clear each confidence threshold
    for thresh in (0.01, 0.05, 0.1, 0.3, 0.5):
        count = (objectness > thresh).sum().item()
        print(f"  Predictions > {thresh}: {count}")
    
    print("\nNote: If max confidence is very low (< 0.01), the model needs more training.")
    print("After proper training, you should see max confidence > 0.5")

## 8. Visualize Predictions on Validation Set

In [None]:
# Visualize predictions on validation set - showing original images with properly transformed boxes.
# Ground truth is read straight from the label files and drawn in green on the
# original image; predictions are made on the letterboxed image and mapped back
# to original pixel coordinates before being drawn in red.
model.eval()

letterbox_size = 640  # side of the square network input; same value as img_size given to YOLODataset
conf_thresh = 0.05  # Lower threshold for visualization

# Get random samples from validation set. Sampling without replacement raises when
# more samples are requested than exist, so cap the request at the size of the set.
num_samples = min(4, len(val_dataset))
indices = np.random.choice(len(val_dataset), num_samples, replace=False)

fig, axes = plt.subplots(2, 2, figsize=(15, 15))
axes = axes.flatten()

with torch.no_grad():
    for idx, ax in zip(indices, axes):
        # Map the position inside the validation subset back to the full dataset
        dataset_idx = val_dataset.indices[idx]
        img_path = dataset.img_files[dataset_idx]
        
        # Load ORIGINAL image (before letterbox). cv2.imread returns None instead of
        # raising when a file cannot be read, so fail here with a clear message.
        orig_img = cv2.imread(str(img_path))
        if orig_img is None:
            raise FileNotFoundError(f"Could not read image: {img_path}")
        orig_img = cv2.cvtColor(orig_img, cv2.COLOR_BGR2RGB)
        h_orig, w_orig = orig_img.shape[:2]
        
        # Load ORIGINAL labels (before letterbox adjustment): one "cls x y w h" row per line
        label_path = dataset.label_dir / (img_path.stem + '.txt')
        orig_labels = []
        if label_path.exists():
            with open(label_path, 'r') as f:
                for line in f:
                    fields = line.split()
                    if not fields:
                        continue  # blank line (e.g. trailing newline at end of file)
                    cls, x, y, w, h = map(float, fields)
                    orig_labels.append([cls, x, y, w, h])
        orig_labels = np.array(orig_labels) if orig_labels else np.zeros((0, 5))
        
        # Get letterboxed image and predictions
        img, _ = val_dataset[idx]
        img_tensor = img.unsqueeze(0).to(device).float() / 255.0
        pred = model(img_tensor)[0]
        
        # Keep confident predictions only, at most the 20 highest-scoring ones
        pred = pred[pred[:, 4] > conf_thresh]
        if len(pred) > 0:
            pred = pred[pred[:, 4].argsort(descending=True)][:20]
        
        # Display ORIGINAL image
        ax.imshow(orig_img)
        
        # Draw ground truth boxes in green (on original image)
        for cls, x_norm, y_norm, w_norm, h_norm in orig_labels:
            # Convert from normalized to pixel coordinates on ORIGINAL image
            x_center = x_norm * w_orig
            y_center = y_norm * h_orig
            box_w = w_norm * w_orig
            box_h = h_norm * h_orig
            
            x1 = x_center - box_w / 2
            y1 = y_center - box_h / 2
            
            rect = patches.Rectangle((x1, y1), box_w, box_h, linewidth=2, 
                                    edgecolor='green', facecolor='none', linestyle='-')
            ax.add_patch(rect)
            
            # Get class name (fall back to the raw index for ids outside the table)
            class_idx = int(cls)
            class_name = COCO_CLASSES[class_idx] if class_idx < len(COCO_CLASSES) else f'C{class_idx}'
            
            ax.text(x1, max(10, y1 - 5), f'GT: {class_name}', 
                   color='white', fontsize=8, weight='bold',
                   bbox=dict(boxstyle='round,pad=0.3', facecolor='green', alpha=0.8, edgecolor='none'))
        
        # Transform predictions back from letterbox to original coordinates.
        # NOTE(review): assumes the dataset letterboxes by scaling to fit and
        # centring with equal padding on both sides - confirm against YOLODataset.
        r = min(letterbox_size / h_orig, letterbox_size / w_orig)
        new_unpad = (int(round(w_orig * r)), int(round(h_orig * r)))
        dw = (letterbox_size - new_unpad[0]) / 2
        dh = (letterbox_size - new_unpad[1]) / 2
        
        # Draw predictions in red (mapped back to original image)
        num_drawn = 0
        for p in pred:
            # Predictions are in letterbox coordinates
            x_center, y_center, box_w, box_h = p[:4].cpu().numpy()
            
            # Remove letterbox padding, then scale back to original image size
            x_center_orig = (x_center - dw) / r
            y_center_orig = (y_center - dh) / r
            box_w_orig = box_w / r
            box_h_orig = box_h / r
            
            x1 = x_center_orig - box_w_orig / 2
            y1 = y_center_orig - box_h_orig / 2
            
            # Only draw if box is within original image bounds and not tiny
            if (x1 >= 0 and y1 >= 0 and 
                x1 + box_w_orig <= w_orig and 
                y1 + box_h_orig <= h_orig and
                box_w_orig > 5 and box_h_orig > 5):
                
                rect = patches.Rectangle((x1, y1), box_w_orig, box_h_orig, linewidth=2, 
                                        edgecolor='red', facecolor='none', linestyle='-')
                ax.add_patch(rect)
                num_drawn += 1
                
                conf = p[4].item()
                cls_idx = p[5:].argmax().item()
                class_name = COCO_CLASSES[cls_idx] if cls_idx < len(COCO_CLASSES) else f'C{cls_idx}'
                
                ax.text(x1, max(10, y1 - 5), f'{class_name}: {conf:.2f}', 
                       color='white', fontsize=8, weight='bold',
                       bbox=dict(boxstyle='round,pad=0.3', facecolor='red', alpha=0.8, edgecolor='none'))
        
        # Report what is actually visible: boxes rejected by the bounds/size filter
        # above are not drawn, so they must not be counted as shown.
        ax.set_title(f'Image {dataset_idx} - Green: GT ({len(orig_labels)} objs), '
                     f'Red: {num_drawn} of {len(pred)} Preds', 
                     fontsize=10)
        ax.axis('off')

# Blank out panels that received no image (only when fewer than 4 samples exist)
for unused_ax in axes[num_samples:]:
    unused_ax.axis('off')

plt.tight_layout()
plt.show()

print(f"\nShowing predictions with confidence > {conf_thresh}")
print("Green boxes: Ground truth (on original image)")
print("Red boxes: Model predictions (transformed back from letterbox)")
print(f"\nNote: Original images shown without letterbox padding for clarity")

## 9. Save Model

In [None]:
# Save the trained model together with everything needed to resume or inspect the run
save_path = Path("../checkpoints")
save_path.mkdir(parents=True, exist_ok=True)

# Record the number of epochs that actually finished (one history entry is appended
# per completed epoch) rather than the configured count, so an interrupted run is
# labelled correctly. A full 100-epoch run still yields "yolov5_coco128_100epochs.pt".
epochs_trained = len(history['train_loss'])
checkpoint_path = save_path / f"yolov5_coco128_{epochs_trained}epochs.pt"

torch.save({
    'epoch': epochs_trained,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    # The learning rate is driven by the scheduler, so resuming needs its state too
    'scheduler_state_dict': scheduler.state_dict(),
    'train_loss': history['train_loss'][-1],
    'val_loss': history['val_loss'][-1],
    'history': history
}, checkpoint_path)

print(f"Model saved to: {checkpoint_path}")
print("Checkpoint includes: model weights, optimizer and scheduler state, and training history")