In [None]:
from pathlib import Path
import random
from ultralytics.models import YOLO
import torch
import yaml
import cv2
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec

random.seed(42)

In [None]:
from utils import DisplayPath
Path = DisplayPath

## Step 2: Configure Paths & Hyperparameters

In [None]:
# Locations shared by every later cell: the prepared dataset (produced by
# e2e_data_prep.ipynb) and the directory that receives training runs.
YOLO_DATASET = Path("datasets/ready/full_dataset")
RUNS_DIR = Path("runs/segment")

# Fail fast when the data-preparation notebook has not been run yet.
if not YOLO_DATASET.exists():
    raise FileNotFoundError(f"Dataset not found at {YOLO_DATASET}. Run e2e_data_prep.ipynb first!")

# Show the dataset root, then each split directory under an indented heading.
print("Dataset:")
YOLO_DATASET.display()
for split_name in ('train', 'val', 'test'):
    print(f"  {split_name.capitalize()}:")
    (YOLO_DATASET / split_name).display()

In [None]:
# Training hyperparameters consumed by the model.train(...) call in Step 5.
EPOCHS = 50
BATCH_SIZE = 16
IMG_SIZE = 640
# Weights file handed to YOLO(...) when the pretrained model is loaded.
model_type = "yolo11n-seg.pt"
# Use the first CUDA device when PyTorch reports one, otherwise fall back to CPU.
DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Report the selected device; GPU name and CUDA version are only printed when CUDA is available.
print(f"Device: {DEVICE}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"CUDA: {torch.version.cuda}")

In [None]:
# REDUCED augmentation configuration for YOLO training.
# Because the trashcans are heavily pre-augmented beforehand, the global
# augmentation is lowered to avoid over-augmenting the red balls and humans.
# The dict is splatted into model.train(**AUG_CONFIG) in Step 5.
AUG_CONFIG = {
    'hsv_h': 0.010,  # Hue augmentation (reduced from 0.015)
    'hsv_s': 0.5,    # Saturation (reduced from 0.7)
    'hsv_v': 0.3,    # Value (reduced from 0.4)
    'degrees': 5.0,   # Rotation (reduced from 10.0)
    'translate': 0.05, # Translation (reduced from 0.1)
    'scale': 0.3,     # Scaling (reduced from 0.5)
    'shear': 0.0,     # Shearing
    'perspective': 0.0, # Perspective
    'flipud': 0.0,    # Vertical flip
    'fliplr': 0.5,    # Horizontal flip (kept)
    'mosaic': 0.5,    # Mosaic augmentation (reduced from 1.0)
    'mixup': 0.0,     # Mixup augmentation
    'copy_paste': 0.3, # New: copy-paste augmentation for rare classes
}

# User-facing reminder (French): global augmentation is reduced because the
# trashcans are pre-augmented before training.
print("‚ö†Ô∏è  Augmentation globale R√âDUITE pour √©viter la sur-augmentation")
print("   Les trashcans sont pr√©-augment√©es massivement avant le training")

## Step 3: Verify Dataset Structure

Dataset is already prepared by e2e_data_prep.ipynb

In [None]:
# Count images and label files per split and abort when the dataset is empty.
print("="*60)
print("DATASET VERIFICATION")
print("="*60)

splits = ['train', 'val', 'test']
stats = {}

for split in splits:
    img_dir = YOLO_DATASET / split / "images"
    lbl_dir = YOLO_DATASET / split / "labels"

    # A split missing either sub-directory is recorded as empty and reported.
    if not (img_dir.exists() and lbl_dir.exists()):
        stats[split] = {'images': 0, 'labels': 0}
        print(f"{split.upper():5s}: Missing!")
        continue

    # Every entry in images/ is counted; only *.txt files count as labels.
    num_images = sum(1 for _entry in img_dir.glob("*"))
    num_labels = sum(1 for _entry in lbl_dir.glob("*.txt"))
    stats[split] = {'images': num_images, 'labels': num_labels}
    print(f"{split.upper():5s}: {num_images:4d} images, {num_labels:4d} labels")

total_images = sum(entry['images'] for entry in stats.values())
total_labels = sum(entry['labels'] for entry in stats.values())

print(f"{'TOTAL':5s}: {total_images:4d} images, {total_labels:4d} labels")
print("="*60)

if total_images == 0:
    raise RuntimeError("No dataset found! Run e2e_data_prep.ipynb to create the dataset.")

## Step 3.5: Analyze Class Distribution

Check the distribution of classes in the training set to identify imbalances

In [None]:
# Analyze class distribution in training set
def count_class_instances(split, dataset_root=None):
    """Count annotated instances of each class in one dataset split.

    Reads every ``*.txt`` file in ``<dataset_root>/<split>/labels`` and takes
    the first whitespace-separated token of each non-empty line as the class
    id (parsed with ``int(float(token))`` so both ``2`` and ``2.0`` work).

    Args:
        split: Name of the split sub-directory, e.g. ``"train"`` or ``"val"``.
        dataset_root: Optional dataset directory. Defaults to the notebook's
            ``YOLO_DATASET`` when omitted.

    Returns:
        dict mapping class id -> instance count. Ids 0, 1 and 2 (red ball,
        human, trashcan) are always present; any other id found in the label
        files is added as an extra key.
    """
    root = YOLO_DATASET if dataset_root is None else Path(dataset_root)
    label_dir = root / split / "labels"
    class_counts = {0: 0, 1: 0, 2: 0}  # red ball, human, trashcan

    for label_file in label_dir.glob("*.txt"):
        try:
            with open(label_file, 'r') as f:
                lines = f.read().splitlines()
        except (OSError, UnicodeDecodeError) as exc:
            # Report unreadable files instead of hiding them behind a bare except.
            print(f"  Warning: could not read {label_file.name}: {exc}")
            continue

        for line in lines:
            parts = line.split()
            if not parts:
                continue
            try:
                class_id = int(float(parts[0]))
            except (ValueError, OverflowError):
                # Skip only the malformed line; the remaining lines of the
                # file are still counted.
                print(f"  Warning: skipping malformed line in {label_file.name}: {line!r}")
                continue
            class_counts[class_id] = class_counts.get(class_id, 0) + 1

    return class_counts

# Print per-class instance counts and an imbalance ratio for train and val.
print("="*60)
print("CLASS DISTRIBUTION ANALYSIS")
print("="*60)

class_names = {0: 'Red Ball', 1: 'Human', 2: 'Trashcan'}

for split in ['train', 'val']:
    counts = count_class_instances(split)
    total = sum(counts.values())

    if total > 0:
        print(f"\n{split.upper()}:")
        for class_id, count in counts.items():
            # count_class_instances may return ids outside 0-2; label them
            # instead of failing with a KeyError.
            label = class_names.get(class_id, f"Unknown {class_id}")
            percentage = count / total * 100
            print(f"  {label:12s} (class {class_id}): {count:5d} instances ({percentage:5.1f}%)")
        print(f"  {'TOTAL':12s}           : {total:5d} instances")

        # Calculate imbalance ratio (largest class vs. smallest non-empty class)
        if counts[2] > 0:  # If trashcans exist
            max_count = max(counts.values())
            min_count = min(v for v in counts.values() if v > 0)
            imbalance_ratio = max_count / min_count
            print(f"  Imbalance ratio: {imbalance_ratio:.1f}x")

            if imbalance_ratio > 10:
                # The message no longer claims a specific cls value: the
                # training cell sets its own loss weights.
                print("  ‚ö†Ô∏è  High class imbalance detected! Consider augmenting the rare class or raising the cls loss weight.")
        else:
            print(f"  ‚ö†Ô∏è  WARNING: No trashcans in {split} set!")
    else:
        # Make an empty or missing split visible instead of printing nothing.
        print(f"\n{split.upper()}: no annotations found")

print("\n" + "="*60)

## Step 3.6: Augment Trashcan Training Data

Since trashcans are under-represented, create multiple augmented copies to balance the dataset

**Strategy:**
- 🔴 Red balls & 👤 Humans: Already abundant → Light augmentation during training
- 🗑️ Trashcans: Rare class → **Massive pre-augmentation** (15x copies with strong transforms)
- Each trashcan image gets 15 augmented variants with:
  - Random rotations (±90°, ±20°)
  - Horizontal/vertical flips
  - Brightness/contrast/hue adjustments
  - Scaling & translation
  - Gaussian noise
  
**Benefits:**
- Balances class distribution without over-augmenting dominant classes
- Preserves label quality through geometric transform propagation
- Creates diverse trashcan appearances for better generalization

In [None]:
# Import augmentation utilities (project-local module)
from src.augmentation import augment_trashcan_dataset

# Execute augmentation
# NOTE(review): presumably the helper writes augmented copies into the dataset
# directory, so re-running this cell may add another batch of copies —
# TODO confirm whether augment_trashcan_dataset is idempotent.
AUGMENT_TRASHCANS = True  # Set to False to skip augmentation
NUM_AUGMENTATIONS = 15     # Number of augmented copies per trashcan image
AUG_STRENGTH = 'strong'    # 'strong' or 'light'

if AUGMENT_TRASHCANS:
    # TODO: refactor name for more general use
    # The return value is compared against 0 below, so it is treated as a
    # count of newly created images.
    new_images = augment_trashcan_dataset(
        dataset_path=YOLO_DATASET,
        num_augmentations=NUM_AUGMENTATIONS,
        aug_strength=AUG_STRENGTH
    )
    
    if new_images > 0:
        print(f"\nüí° Tip: Re-run the class distribution analysis cell to see the updated statistics!")
else:
    print("Trashcan augmentation skipped (AUGMENT_TRASHCANS = False)")

## Step 4: Create YOLO Configuration File

In [None]:
# Class name -> class index used in the label files.
classes = {
    'red ball': 0,
    'human': 1,
    'trashcan': 2
}

# The `names` list below is positional (entry i names class i), so the indices
# must be exactly 0..N-1.
if sorted(classes.values()) != list(range(len(classes))):
    raise ValueError(f"Class indices must be contiguous and start at 0, got {classes}")

config = {
    'path': str(YOLO_DATASET.absolute()),
    'train': 'train/images',
    'val': 'val/images',
    'nc': len(classes),
    # Order names by their declared index rather than by dict insertion order,
    # so reordering the `classes` literal cannot silently swap class names.
    'names': [name for name, _idx in sorted(classes.items(), key=lambda item: item[1])],
}

# Write the dataset description consumed by model.train(data=...).
config_path = YOLO_DATASET / 'data.yaml'
with open(config_path, 'w') as f:
    yaml.dump(config, f, default_flow_style=False)

print(f"‚úì Configuration saved: {config_path}")
print("Dataset structure:")
YOLO_DATASET.display()
print("  Train:")
(YOLO_DATASET / 'train').display()
print("  Val:")
(YOLO_DATASET / 'val').display()
print("  Test:")
(YOLO_DATASET / 'test').display()

## Step 4.5: Find Trashcan Images for Monitoring

Identify validation images containing trashcans (class 2) to monitor training progress

In [None]:
def _label_line_class_id(line):
    """Return the class id at the start of a label line, or None if absent/malformed."""
    parts = line.split()
    if not parts:
        return None
    try:
        return int(float(parts[0]))
    except (ValueError, OverflowError):
        return None


def find_trashcan_images(val_labels_dir, val_images_dir, max_images=8, class_id=2):
    """Find validation images whose label file contains a given class.

    Label files are scanned in sorted order so the selection is reproducible
    between runs. The class id is parsed the same way as in
    ``count_class_instances`` (``int(float(token))``), so ``2``, ``2.0`` and
    tab-separated lines are all recognised.

    Args:
        val_labels_dir: Path-like directory object (must support ``.glob``)
            holding the ``*.txt`` label files.
        val_images_dir: Path-like directory object holding the images; the
            image is looked up by the label file's stem plus a known extension.
        max_images: Maximum number of image paths to return.
        class_id: Class id to look for. Defaults to 2 (trashcan).

    Returns:
        list[str]: Paths of matching images, at most ``max_images`` long.
        Label files with no corresponding image are skipped.
    """
    trashcan_images = []
    if max_images <= 0:
        return trashcan_images

    # Scan all label files (sorted: glob order is otherwise filesystem-dependent)
    label_files = sorted(val_labels_dir.glob("*.txt"))
    print(f"Scanning {len(label_files)} label files for trashcans...")

    for label_file in label_files:
        try:
            with open(label_file, 'r') as f:
                has_trashcan = any(_label_line_class_id(line) == class_id for line in f)
        except (OSError, UnicodeDecodeError) as exc:
            # Report unreadable files instead of swallowing every exception.
            print(f"  Warning: skipping unreadable label file {label_file.name}: {exc}")
            continue

        if not has_trashcan:
            continue

        # Find the corresponding image, trying the common extensions in turn
        for ext in ['.jpg', '.jpeg', '.png', '.JPG', '.JPEG', '.PNG']:
            img_path = val_images_dir / f"{label_file.stem}{ext}"
            if img_path.exists():
                trashcan_images.append(str(img_path))
                print(f"  ‚úì Found trashcan in: {img_path.name}")
                break

        if len(trashcan_images) >= max_images:
            break

    return trashcan_images

# Pick up to eight validation images containing trashcans; the monitoring
# callback defined later renders predictions on exactly these files.
val_images_dir = YOLO_DATASET / 'val' / 'images'
val_labels_dir = YOLO_DATASET / 'val' / 'labels'

TRASHCAN_MONITOR_IMAGES = find_trashcan_images(
    val_labels_dir,
    val_images_dir,
    max_images=8,
)

# Summary of the selection, framed by separator rules.
print(f"\n{'='*60}")
print(f"Selected {len(TRASHCAN_MONITOR_IMAGES)} images for trashcan monitoring:")
for img_path in TRASHCAN_MONITOR_IMAGES:
    print(f"  - {Path(img_path).name}")
print(f"{'='*60}\n")

if not TRASHCAN_MONITOR_IMAGES:
    print("‚ö†Ô∏è  Warning: No trashcan images found in validation set!")

## Step 4.6: Define Trashcan Monitoring Callback

Create a callback that visualizes trashcan segmentation progress at each epoch

In [None]:
def create_trashcan_monitor_callback(model, monitor_images, output_dir, project_name):
    """Build an ``on_train_epoch_end`` callback that renders predictions on fixed images.

    At the end of each epoch (from epoch index 5 onwards) the callback loads
    the most recent checkpoint written by the trainer, predicts on every
    monitor image, arranges the plotted results in a grid and saves it as
    ``<output_dir>/epoch_XXX.jpg``.

    Args:
        model: Accepted for interface compatibility; not used by the callback,
            which loads weights from the trainer's checkpoint files instead.
        monitor_images: Sequence of image paths to visualise each epoch.
        output_dir: Directory for the saved figures (created if missing).
        project_name: Label included in the figure title.

    Returns:
        Callable taking the trainer object, suitable for
        ``model.add_callback('on_train_epoch_end', ...)``.
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    def _latest_checkpoint(trainer):
        """Return the newest checkpoint path that exists on disk, or None."""
        # Prefer last.pt: best.pt only changes when fitness improves, so
        # rendering from it would not show epoch-to-epoch progress.
        for attr in ('last', 'best'):
            candidate = getattr(trainer, attr, None)
            if candidate is not None and Path(candidate).exists():
                return candidate
        return None

    def on_train_epoch_end(trainer):
        """Called at the end of each training epoch"""
        if len(monitor_images) == 0:
            return

        epoch = trainer.epoch

        # Skip first few epochs to save time
        if epoch < 5:
            return

        checkpoint = _latest_checkpoint(trainer)
        if checkpoint is None:
            print(f"Trashcan monitoring skipped for epoch {epoch}: no checkpoint saved yet")
            return

        print(f"\n{'='*60}")
        print(f"Generating trashcan monitoring visualizations for epoch {epoch}...")
        print(f"{'='*60}")

        # Load the checkpoint once per epoch rather than once per image.
        # Inference runs on this separate model, so the trainer's own model
        # (and its train/eval mode) is left untouched.
        # NOTE(review): the trainer presumably writes checkpoints after this
        # hook fires, so the picture may lag the title by one epoch — confirm
        # against the Ultralytics trainer loop.
        temp_model = YOLO(checkpoint)

        # Create figure with subplots
        n_images = len(monitor_images)
        n_cols = min(3, n_images)
        n_rows = (n_images + n_cols - 1) // n_cols

        fig = plt.figure(figsize=(6 * n_cols, 5 * n_rows))
        try:
            gs = GridSpec(n_rows, n_cols, figure=fig, hspace=0.3, wspace=0.3)

            with torch.no_grad():
                for idx, img_path in enumerate(monitor_images):
                    try:
                        results = temp_model.predict(img_path, conf=0.25, verbose=False)

                        # Get the plotted image and convert BGR to RGB for matplotlib
                        result = results[0]
                        plotted_img = cv2.cvtColor(result.plot(), cv2.COLOR_BGR2RGB)

                        # Add to subplot
                        ax = fig.add_subplot(gs[idx // n_cols, idx % n_cols])
                        ax.imshow(plotted_img)
                        ax.axis('off')

                        # Count detections by class
                        if result.masks is not None and len(result.boxes) > 0:
                            classes_detected = result.boxes.cls.cpu().numpy()
                            n_balls = int((classes_detected == 0).sum())
                            n_humans = int((classes_detected == 1).sum())
                            n_trashcans = int((classes_detected == 2).sum())

                            title = f"{Path(img_path).name}\n"
                            title += f"Balls: {n_balls} | Humans: {n_humans} | Trashcans: {n_trashcans}"
                        else:
                            title = f"{Path(img_path).name}\nNo detections"

                        ax.set_title(title, fontsize=10)

                    except Exception as e:
                        # Best effort: one failing image must not abort training.
                        print(f"  ‚ö†Ô∏è  Error processing {Path(img_path).name}: {e}")

            # Add main title
            fig.suptitle(f"Trashcan Monitoring - Epoch {epoch} - {project_name}",
                         fontsize=16, fontweight='bold', y=0.995)

            # Save this specific figure (not whatever pyplot considers current)
            output_file = output_dir / f"epoch_{epoch:03d}.jpg"
            fig.savefig(output_file, dpi=150, bbox_inches='tight')
        finally:
            # Always release the figure, even if plotting or saving failed.
            plt.close(fig)

        print(f"‚úì Saved monitoring visualization:")
        output_dir.display()
        print(f"{'='*60}\n")

    return on_train_epoch_end

## Step 5: Train Model with Trashcan Monitoring

Train YOLOv11 with:
- Data augmentation on train set
- Checkpoints saved for best model
- Validation after each epoch
- **Custom callback to monitor trashcan segmentation progress**

In [None]:
# Load pretrained model from the weights file named in `model_type` (configured in Step 2)
model = YOLO(model_type)

In [None]:
# Run name: used as the sub-directory of RUNS_DIR for weights and monitoring output.
project_name = 'ball_person_trashcan_model_v3'

In [None]:
# Setup trashcan monitoring: figures are written under <RUNS_DIR>/<project_name>/trashcan_monitor
monitor_output_dir = RUNS_DIR / project_name / 'trashcan_monitor'
print(f"Trashcan monitoring output: {monitor_output_dir}")

# Build the callback from the monitor images selected in Step 4.5
callback_fn = create_trashcan_monitor_callback(
    model=model,
    monitor_images=TRASHCAN_MONITOR_IMAGES,
    output_dir=monitor_output_dir,
    project_name=project_name
)

# Register it so the trainer invokes it at the end of every training epoch.
# NOTE(review): re-running this cell on the same `model` object presumably
# registers the callback a second time — confirm add_callback semantics.
model.add_callback('on_train_epoch_end', callback_fn)
print("‚úì Trashcan monitoring callback registered")

In [None]:
# Train model
# head_idx is the index of the first module whose class name contains
# 'Detect' or 'Segment'; if none is found it falls back to the last module.
# Every module index below head_idx is passed to `freeze`, so only the head
# (and anything after it) is meant to be trained.
head_idx = next((i for i, m in enumerate(model.model.model) if 'Detect' in m.__class__.__name__ or 'Segment' in m.__class__.__name__), len(model.model.model) - 1)

results = model.train(
    data=str(config_path),
    epochs=EPOCHS,
    freeze=list(range(head_idx)),
    batch=BATCH_SIZE,
    imgsz=IMG_SIZE,
    device=DEVICE,
    project=str(RUNS_DIR),
    name=project_name,
    exist_ok=True,
    
    # Checkpointing
    save=True,
    save_period=5,  # Save every 5 epochs
    
    # Validation
    val=True,
    
    # Data augmentation (REDUCED - trashcans are heavily pre-augmented)
    **AUG_CONFIG,
    
    # Optimizer
    optimizer='Adam',
    lr0=0.001,
    lrf=0.01,
    momentum=0.937,
    weight_decay=0.0005,
    
    # Loss weights - adjusted for a dataset with augmented trashcans.
    # With the massive trashcan augmentation, cls can be lowered.
    box=7.5,
    cls=1.0,      # Reduced from 20.0 to 1.0 because trashcans are now well represented
    dfl=1.5,
    
    # Other
    patience=20,  # Early stopping
    workers=8,
    verbose=True
)


## Step 6: Load Best Model & Evaluate

In [None]:
# Reload the best checkpoint from the training run directory; from here on
# `model` refers to these weights rather than the in-memory training model.
best_model_path = RUNS_DIR / project_name / 'weights' / 'best.pt'
best_model_path.display()
model = YOLO(best_model_path)

## Step 7: Evaluate Results

In [None]:
# Validation metrics
metrics = model.val()

print("\n" + "="*60)
print("VALIDATION METRICS")
print("="*60)
print(f"Box mAP50: {metrics.box.map50:.4f}")
print(f"Box mAP50-95: {metrics.box.map:.4f}")
print(f"Mask mAP50: {metrics.seg.map50:.4f}")
print(f"Mask mAP50-95: {metrics.seg.map:.4f}")

# Per-class metrics
print("\n" + "="*60)
print("PER-CLASS METRICS (Segmentation)")
print("="*60)
class_names = ['red ball', 'human', 'trashcan']

# The Ultralytics Metric object exposes per-class AP through `ap50` / `ap`,
# which are aligned with `ap_class_index` (the class ids that have results).
# The previously probed `map50_per_class` / `map_per_class` attributes do not
# exist, so the old hasattr fallback printed 0.0000 for every class.
seg_metrics = metrics.seg
per_class_ap = {
    int(cls_id): (float(seg_metrics.ap50[pos]), float(seg_metrics.ap[pos]))
    for pos, cls_id in enumerate(seg_metrics.ap_class_index)
}

for i, class_name in enumerate(class_names):
    if i in per_class_ap:
        map50, map_val = per_class_ap[i]
        print(f"{class_name:12s}: mAP50={map50:.4f}, mAP50-95={map_val:.4f}")
    else:
        # Class absent from the validation results (e.g. no instances in val).
        print(f"{class_name:12s}: metrics not available")
print("="*60)

In [None]:
# Locate the artefacts produced by the training run.
model_dir = RUNS_DIR / project_name
best_model = model_dir / 'weights' / 'best.pt'
last_model = model_dir / 'weights' / 'last.pt'

# Print each heading followed by the display of the matching path.
for heading, location in (
    ("Best model: ", best_model),
    ("Last model: ", last_model),
    ("Results: ", model_dir),
    ("\nTrashcan monitoring visualizations: ", model_dir / 'trashcan_monitor'),
):
    print(heading)
    location.display()

## Step 8: Visualize Trashcan Progress Evolution

Create a comparison showing how trashcan segmentation improved over epochs

In [None]:
# Enumerate the per-epoch figures written by the monitoring callback.
monitor_dir = RUNS_DIR / project_name / 'trashcan_monitor'

if not monitor_dir.exists():
    print("No monitoring visualizations found.")
else:
    # Zero-padded epoch numbers make lexical order equal chronological order.
    viz_files = sorted(monitor_dir.glob("epoch_*.jpg"))
    print(f"Found {len(viz_files)} monitoring visualizations:")
    for viz_file in viz_files:
        print(f"  - {viz_file.name}")

    if viz_files:
        print(f"\nüí° Tip: Open the images in {monitor_dir} to see how trashcan segmentation evolved!")
        print(f"   You can use an image viewer or VS Code to flip through them chronologically.")

## Step 9: Test on Sample Images (Optional)

In [None]:
# Test on validation images (sample from val set)
# Sort for a reproducible sample and keep only image files, so stray entries
# (hidden files, caches) are never passed to the model.
test_images = sorted(
    p for p in (YOLO_DATASET / "val" / "images").glob("*")
    if p.suffix.lower() in {'.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff', '.webp'}
)[:10]

# Send the annotated images to a known folder inside the run directory, so the
# location printed at the end matches where the files are written.
predictions_dir = RUNS_DIR / project_name / 'predictions'

print(f"Testing on {len(test_images)} sample images...")

for img_path in test_images:
    results = model.predict(
        str(img_path),
        save=True,
        conf=0.25,
        project=str(RUNS_DIR / project_name),
        name='predictions',
        exist_ok=True,  # reuse the same folder for every image and every re-run
    )
    print(f"  ‚úì {img_path.name}")

print(f"\nResults saved to: {predictions_dir}")