In [1]:
from pathlib import Path
import random
from ultralytics.models import YOLO
import torch
import yaml
import cv2
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec

random.seed(42)

In [2]:
from utils import DisplayPath
Path = DisplayPath

## Step 2: Configure Paths & Hyperparameters

In [3]:
# Dataset path (created by e2e_data_prep.ipynb)
YOLO_DATASET = Path("datasets/ready/full_dataset")
RUNS_DIR = Path("runs/segment")

# Verify dataset exists
if not YOLO_DATASET.exists():
    raise FileNotFoundError(f"Dataset not found at {YOLO_DATASET}. Run e2e_data_prep.ipynb first!")

print("Dataset:")
YOLO_DATASET.display()
print("  Train:")
(YOLO_DATASET / 'train').display()
print("  Val:")
(YOLO_DATASET / 'val').display()
print("  Test:")
(YOLO_DATASET / 'test').display()

Dataset:


[datasets/ready/full_dataset](datasets/ready/full_dataset)

  Train:


[datasets/ready/full_dataset/train](datasets/ready/full_dataset/train)

  Val:


[datasets/ready/full_dataset/val](datasets/ready/full_dataset/val)

  Test:


[datasets/ready/full_dataset/test](datasets/ready/full_dataset/test)

In [4]:
EPOCHS = 50
BATCH_SIZE = 16
IMG_SIZE = 640
model_type = "yolo11n-seg.pt"
DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'

print(f"Device: {DEVICE}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"CUDA: {torch.version.cuda}")

Device: cuda:0
GPU: NVIDIA GeForce RTX 3080 Laptop GPU
CUDA: 12.8


In [5]:
# # Configuration d'augmentation R√âDUITE pour le training YOLO
# # Puisqu'on pr√©-augmente massivement les trashcans, on r√©duit l'augmentation
# # globale pour √©viter de sur-augmenter les red balls et humans
# AUG_CONFIG = {
#     'hsv_h': 0.010,  # Hue augmentation (r√©duit de 0.015)
#     'hsv_s': 0.5,    # Saturation (r√©duit de 0.7)
#     'hsv_v': 0.3,    # Value (r√©duit de 0.4)
#     'degrees': 5.0,   # Rotation (r√©duit de 10.0)
#     'translate': 0.05, # Translation (r√©duit de 0.1)
#     'scale': 0.3,     # Scaling (r√©duit de 0.5)
#     'shear': 0.0,     # Shearing
#     'perspective': 0.0, # Perspective
#     'flipud': 0.0,    # Vertical flip
#     'fliplr': 0.5,    # Horizontal flip (maintenu)
#     'mosaic': 0.5,    # Mosaic augmentation (r√©duit de 1.0)
#     'mixup': 0.0,     # Mixup augmentation
#     'copy_paste': 0.3, # üÜï Copy-paste aug pour classes rares
# }

# print("Augmentation globale R√âDUITE pour √©viter la sur-augmentation")
# print("   Les trashcans sont pr√©-augment√©es massivement avant le training")

## Step 3: Verify Dataset Structure

Dataset is already prepared by e2e_data_prep.ipynb

In [6]:
# Verify dataset structure
print("="*60)
print("DATASET VERIFICATION")
print("="*60)

splits = ['train', 'val', 'test']
stats = {}

for split in splits:
    img_dir = YOLO_DATASET / split / "images"
    lbl_dir = YOLO_DATASET / split / "labels"
    
    if img_dir.exists() and lbl_dir.exists():
        num_images = len(list(img_dir.glob("*")))
        num_labels = len(list(lbl_dir.glob("*.txt")))
        stats[split] = {'images': num_images, 'labels': num_labels}
        print(f"{split.upper():5s}: {num_images:4d} images, {num_labels:4d} labels")
    else:
        stats[split] = {'images': 0, 'labels': 0}
        print(f"{split.upper():5s}: Missing!")

total_images = sum(s['images'] for s in stats.values())
total_labels = sum(s['labels'] for s in stats.values())

print(f"{'TOTAL':5s}: {total_images:4d} images, {total_labels:4d} labels")
print("="*60)

if total_images == 0:
    raise RuntimeError("No dataset found! Run e2e_data_prep.ipynb to create the dataset.")

DATASET VERIFICATION
TRAIN: 3950 images, 3950 labels
VAL  :   47 images,   47 labels
TEST :  214 images,  214 labels
TOTAL: 4211 images, 4211 labels


## Step 3.5: Analyze Class Distribution

Check the distribution of classes in the training set to identify imbalances

In [None]:
# Analyze class distribution in training set
from src.data_utils import count_class_instances

print("="*60)
print("CLASS DISTRIBUTION ANALYSIS")
print("="*60)

class_names = {0: 'Red Ball', 1: 'Human', 2: 'Trashcan'}

for split in ['train']:
    counts = count_class_instances(YOLO_DATASET, split)
    total = sum(counts.values())

    print(f"\n{split.upper()}:")
    for class_id, count in counts.items():
        percentage = (count / total * 100) if total > 0 else 0
        print(f"  {class_names[class_id]:12s} (class {class_id}): {count:5d} instances ({percentage:5.1f}%)")
    print(f"  {'TOTAL':12s}           : {total:5d} instances")
    
    # Calculate imbalance ratio
    if counts[2] > 0:  # If trashcans exist
        max_count = max(counts.values())
        min_count = min(v for v in counts.values() if v > 0)
        imbalance_ratio = max_count / min_count
        print(f"  Imbalance ratio: {imbalance_ratio:.1f}x")

print("\n" + "="*60)

CLASS DISTRIBUTION ANALYSIS

TRAIN:
  Red Ball     (class 0): 11288 instances ( 85.9%)
  Human        (class 1):  1008 instances (  7.7%)
  Trashcan     (class 2):   840 instances (  6.4%)
  TOTAL                  : 13136 instances
  Imbalance ratio: 13.4x



In [8]:
# Import augmentation utilities
import albumentations as A
from src.augmentation import augment_class_dataset

# Define custom augmentation pipelines
trashcan_pipeline = A.Compose([  # type: ignore
    A.RandomRotate90(p=0.5),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.3),
    A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, p=0.8),
    A.HueSaturationValue(hue_shift_limit=20, sat_shift_limit=30, val_shift_limit=20, p=0.8),
    A.GaussNoise(var_limit=(10.0, 50.0), p=0.5),
    A.RandomScale(scale_limit=0.3, p=0.7),
    A.Affine(rotate=(-20, 20), translate_percent=0.1, scale=(0.8, 1.2), shear=(-10, 10), p=0.7),
], keypoint_params=A.KeypointParams(format='xy', remove_invisible=False))

# Execute augmentation
DO_AUGMENTATION = False 
if DO_AUGMENTATION:
    # Augment Trashcans (Class 2) - Strong augmentation
    print("Augmenting Trashcans...")
    augment_class_dataset(
        dataset_path=YOLO_DATASET,
        class_id=2,
        num_augmentations=20,
        aug_config=trashcan_pipeline
    )
    
    # Example: Augment Balls (Class 0) - Light augmentation 
    augment_class_dataset(
        dataset_path=YOLO_DATASET,
        class_id=0,
        num_augmentations=3,
        aug_config='light'
    )

  A.GaussNoise(var_limit=(10.0, 50.0), p=0.5),


### üîÑ Reset: Clean Augmented Images

If you accidentally ran augmentation multiple times, use this to remove all augmented copies and start fresh

In [None]:
from src.augmentation import clean_augmented_images

# Execute cleanup (set to True to clean)
CLEAN_AUGMENTED = False

if CLEAN_AUGMENTED:
    removed = clean_augmented_images(YOLO_DATASET)
    print(f"\nüí° Tip: Re-run the class distribution analysis to see updated statistics!")
else:
    print("Set CLEAN_AUGMENTED = True to remove all augmented images")

Set CLEAN_AUGMENTED = True to remove all augmented images


## Step 4: Create YOLO Configuration File

In [10]:
classes = {
    'red ball': 0,
    'human': 1,
    'trashcan': 2
}

config = {
    'path': str(YOLO_DATASET.absolute()),
    'train': 'train/images',
    'val': 'val/images',
    'nc': len(classes),
    'names': list(classes.keys())
}

config_path = YOLO_DATASET / 'data.yaml'
with open(config_path, 'w') as f:
    yaml.dump(config, f, default_flow_style=False)

print(f"‚úì Configuration saved: {config_path}")
print("Dataset structure:")
YOLO_DATASET.display()
print("  Train:")
(YOLO_DATASET / 'train').display()
print("  Val:")
(YOLO_DATASET / 'val').display()
print("  Test:")
(YOLO_DATASET / 'test').display()

‚úì Configuration saved: datasets/ready/full_dataset/data.yaml
Dataset structure:


[datasets/ready/full_dataset](datasets/ready/full_dataset)

  Train:


[datasets/ready/full_dataset/train](datasets/ready/full_dataset/train)

  Val:


[datasets/ready/full_dataset/val](datasets/ready/full_dataset/val)

  Test:


[datasets/ready/full_dataset/test](datasets/ready/full_dataset/test)

## Step 4.5: Select Monitoring Images

Select diverse validation images covering all classes to monitor training progress

In [None]:
from src.data_utils import select_diverse_monitoring_images

# Select diverse monitoring images
val_labels_dir = YOLO_DATASET / 'val' / 'labels'
val_images_dir = YOLO_DATASET / 'val' / 'images'

MONITOR_IMAGES = select_diverse_monitoring_images(
    val_labels_dir, 
    val_images_dir, 
    images_per_class=3,
    include_mixed=True
)

if len(MONITOR_IMAGES) == 0:
    print("‚ö†Ô∏è  Warning: No monitoring images found in validation set!")

SELECTING DIVERSE MONITORING IMAGES

Searching for Red Ball images...


  ‚úì Found 3 images with Red Ball
    - a6b631525ff8dc9bc26bb53e97481606.jpg
    - d634b2ebe37b5735c16949926ae2d7bb.jpg
    - 088610ac49bfde8470097a40c8b749d7.jpg

Searching for Human images...
  ‚úì Found 3 images with Human
    - a6b631525ff8dc9bc26bb53e97481606.jpg
    - d634b2ebe37b5735c16949926ae2d7bb.jpg
    - 088610ac49bfde8470097a40c8b749d7.jpg

Searching for Trashcan images...
  ‚úì Found 3 images with Trashcan
    - de96513c1cdccc864e7b7e809162d06c.jpg
    - 1383c4a58a26c7238fbae31ec8e4e660.jpg
    - f7c36eabf5a95cf548a382f4a6b49050.jpg

Searching for mixed-class images...
  ‚úì Mixed: b2dcc2d8cc39ca13b5cf2e24bf036d62.jpg (Red Ball, Human)
  ‚úì Mixed: 9666760cde1bedabf437f3dbfc95f891.jpg (Red Ball, Human)
  ‚úì Mixed: d8ce4c69557862b45c0f7c18d0d3a412.jpg (Red Ball, Human)

Selected 12 images for monitoring



## Step 4.6: Define Multi-Class Monitoring Callback

Create a callback that visualizes segmentation progress for all classes at each epoch

In [None]:
from src.visualization import create_monitoring_callback

## Step 5: Train Model with Multi-Class Monitoring

Train YOLOv11 with:
- Data augmentation on train set (reduced for abundant classes)
- Trashcan pre-augmentation for class balance
- Checkpoints saved for best model
- Validation after each epoch
- **Custom callback to monitor all classes segmentation progress**

In [13]:
# Load pretrained model
model = YOLO(model_type)

In [5]:
project_name = 'ball_person_trashcan_model_v5'

In [15]:
# Setup multi-class monitoring
monitor_output_dir = RUNS_DIR / project_name / 'training_monitor'
print(f"Multi-class monitoring output: {monitor_output_dir}")

# Add callback
callback_fn = create_monitoring_callback(
    model=model,
    monitor_images=MONITOR_IMAGES,
    output_dir=monitor_output_dir,
    project_name=project_name
)

model.add_callback('on_train_epoch_end', callback_fn)
print("‚úì Multi-class monitoring callback registered")
print(f"   Monitoring {len(MONITOR_IMAGES)} images covering all classes")

Multi-class monitoring output: runs/segment/ball_person_trashcan_model_v4/training_monitor
‚úì Multi-class monitoring callback registered
   Monitoring 12 images covering all classes


In [16]:
# Train model
head_idx = next((i for i, m in enumerate(model.model.model) if 'Detect' in m.__class__.__name__ or 'Segment' in m.__class__.__name__), len(model.model.model) - 1)

results = model.train(
    data=str(config_path),
    epochs=EPOCHS,
    freeze=list(range(head_idx)),
    batch=BATCH_SIZE,
    imgsz=IMG_SIZE,
    device=DEVICE,
    project=str(RUNS_DIR),
    name=project_name,
    exist_ok=True,
    
    # Checkpointing
    save=True,
    save_period=5,  # Save every 5 epochs
    
    # Validation
    val=True,
    
    # Data augmentation
    # **AUG_CONFIG,
    
    # Optimizer
    optimizer='Adam',
    lr0=0.001,
    lrf=0.01,
    momentum=0.937,
    weight_decay=0.0005,
    
    # Loss weights - Ajust√© pour dataset avec trashcans augment√©es
    # Avec l'augmentation massive des trashcans, on peut r√©duire cls
    box=7.5,
    cls=1.0,      # R√©duit de 20.0 √† 1.0 car trashcans maintenant bien repr√©sent√©es
    dfl=1.5,
    
    # Other
    patience=20,  # Early stopping
    workers=8,
    verbose=True
)


Ultralytics 8.3.235 üöÄ Python-3.12.10 torch-2.9.1+cu128 CUDA:0 (NVIDIA GeForce RTX 3080 Laptop GPU, 8192MiB)
[34m[1mengine/trainer: [0magnostic_nms=False, amp=True, augment=False, auto_augment=randaugment, batch=16, bgr=0.0, box=7.5, cache=False, cfg=None, classes=None, close_mosaic=10, cls=1.0, compile=False, conf=None, copy_paste=0.0, copy_paste_mode=flip, cos_lr=False, cutmix=0.0, data=datasets/ready/full_dataset/data.yaml, degrees=0.0, deterministic=True, device=0, dfl=1.5, dnn=False, dropout=0.0, dynamic=False, embed=None, epochs=50, erasing=0.4, exist_ok=True, fliplr=0.5, flipud=0.0, format=torchscript, fraction=1.0, freeze=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22], half=False, hsv_h=0.015, hsv_s=0.7, hsv_v=0.4, imgsz=640, int8=False, iou=0.7, keras=False, kobj=1.0, line_width=None, lr0=0.001, lrf=0.01, mask_ratio=4, max_det=300, mixup=0.0, mode=train, model=yolo11n-seg.pt, momentum=0.937, mosaic=1.0, multi_scale=False, name=ball_pers

## Step 6: Load Best Model & Evaluate

In [6]:
best_model_path = RUNS_DIR / project_name / 'weights' / 'best.pt'
best_model_path.display()
model = YOLO(best_model_path)

[runs/segment/ball_person_trashcan_model_v4/weights/best.pt](runs/segment/ball_person_trashcan_model_v4/weights/best.pt)

## Step 7: Evaluate Results

In [7]:
# Validation metrics
metrics = model.val()

print("\n" + "="*60)
print("VALIDATION METRICS")
print("="*60)
print(f"Box mAP50: {metrics.box.map50:.4f}")
print(f"Box mAP50-95: {metrics.box.map:.4f}")
print(f"Mask mAP50: {metrics.seg.map50:.4f}")
print(f"Mask mAP50-95: {metrics.seg.map:.4f}")

# Per-class metrics
print("\n" + "="*60)
print("PER-CLASS METRICS (Segmentation)")
print("="*60)
class_names = ['red ball', 'human', 'trashcan']
for i, class_name in enumerate(class_names):
    try:
        map50 = metrics.seg.map50_per_class[i] if hasattr(metrics.seg, 'map50_per_class') else 0
        map_val = metrics.seg.map_per_class[i] if hasattr(metrics.seg, 'map_per_class') else 0
        print(f"{class_name:12s}: mAP50={map50:.4f}, mAP50-95={map_val:.4f}")
    except:
        print(f"{class_name:12s}: metrics not available")
print("="*60)

Ultralytics 8.3.235 üöÄ Python-3.12.10 torch-2.9.1+cu128 CUDA:0 (NVIDIA GeForce RTX 3080 Laptop GPU, 8192MiB)
YOLO11n-seg summary (fused): 113 layers, 2,835,153 parameters, 0 gradients, 9.6 GFLOPs
[34m[1mval: [0mFast image access ‚úÖ (ping: 0.0¬±0.0 ms, read: 463.9¬±29.2 MB/s, size: 2696.9 KB)
[K[34m[1mval: [0mScanning /home/tonino/projects/ball segmentation/datasets/ready/full_dataset/val/labels.cache... 47 images, 0 backgrounds, 0 corrupt: 100% ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ 47/47 55.8Kit/s 0.0s
[K                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95)     Mask(P          R      mAP50  mAP50-95): 100% ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ 3/3 6.6s/it 19.9s5.5s0
                   all         47         81      0.723      0.705      0.757        0.7      0.723      0.705      0.755      0.655
              red ball         34         34      0.711      0.435      0.558      0.467      0.711      0.435      0.553      0.433
               

In [8]:
# Find best checkpoint
model_dir = RUNS_DIR / project_name
best_model = model_dir / 'weights' / 'best.pt'
last_model = model_dir / 'weights' / 'last.pt'

print(f"Best model: ")
best_model.display()
print(f"Last model: ")
last_model.display()
print(f"Results: ")
model_dir.display()
print(f"\nMulti-class monitoring visualizations: ")
(model_dir / 'training_monitor').display()

Best model: 


[runs/segment/ball_person_trashcan_model_v4/weights/best.pt](runs/segment/ball_person_trashcan_model_v4/weights/best.pt)

Last model: 


[runs/segment/ball_person_trashcan_model_v4/weights/last.pt](runs/segment/ball_person_trashcan_model_v4/weights/last.pt)

Results: 


[runs/segment/ball_person_trashcan_model_v4](runs/segment/ball_person_trashcan_model_v4)


Multi-class monitoring visualizations: 


[runs/segment/ball_person_trashcan_model_v4/training_monitor](runs/segment/ball_person_trashcan_model_v4/training_monitor)

## Step 8: Visualize Training Progress Evolution

Review monitoring visualizations showing how segmentation improved for all classes over epochs

In [None]:
# List all monitoring visualizations
monitor_dir = RUNS_DIR / project_name / 'training_monitor'

if monitor_dir.exists():
    viz_files = sorted(monitor_dir.glob("epoch_*.jpg"))
    print(f"Found {len(viz_files)} monitoring visualizations:")
    for viz_file in viz_files:
        viz_file.display()
    
    if len(viz_files) > 0:
        print(f"\nüí° Tip: Open the images in {monitor_dir} to see how segmentation evolved for all classes!")
        print(f"   You can use an image viewer or VS Code to flip through them chronologically.")
        print(f"   Each image shows detections with: Balls | Humans | Trashcans")
else:
    print("No monitoring visualizations found.")

## Step 9: Test on Sample Images (Optional)

In [10]:
# Test on validation images (sample from val set)
test_images = list((YOLO_DATASET / "val" / "images").glob("*"))

print(f"Testing on {len(test_images)} sample images...")

for img_path in test_images:
    results = model.predict(str(img_path), save=True, conf=0.1)
    print(f"  ‚úì {img_path.name}")

print(f"\nResults saved to:")
(RUNS_DIR / project_name).display()

Testing on 47 sample images...

image 1/1 /home/tonino/projects/ball segmentation/datasets/ready/full_dataset/val/images/366acb21b00b40588372736b95776fac.jpg: 640x480 1 red ball, 1 human, 31.5ms
Speed: 4.4ms preprocess, 31.5ms inference, 6.8ms postprocess per image at shape (1, 3, 640, 480)
Results saved to [1m/home/tonino/projects/ball segmentation/runs/segment/predict10[0m
  ‚úì 366acb21b00b40588372736b95776fac.jpg

image 1/1 /home/tonino/projects/ball segmentation/datasets/ready/full_dataset/val/images/0d4db7c113776bc0a401d833d556df84.jpg: 640x480 1 red ball, 1 human, 33.6ms
Speed: 2.4ms preprocess, 33.6ms inference, 6.4ms postprocess per image at shape (1, 3, 640, 480)
Results saved to [1m/home/tonino/projects/ball segmentation/runs/segment/predict10[0m
  ‚úì 0d4db7c113776bc0a401d833d556df84.jpg

image 1/1 /home/tonino/projects/ball segmentation/datasets/ready/full_dataset/val/images/a6b631525ff8dc9bc26bb53e97481606.jpg: 640x480 1 human, 12.2ms
Speed: 2.2ms preprocess, 12.2ms i

[runs/segment/ball_person_trashcan_model_v4](runs/segment/ball_person_trashcan_model_v4)