## Step 1: Setup & Imports

In [None]:
from pathlib import Path
import shutil
from PIL import Image
import numpy as np
import cv2
from ultralytics.models import YOLO
import torch
from tqdm import tqdm

from src.detection import RoboflowDetector
from src.segmentation import SAMSegmenter  # SAM 2.1 instead of FastSAM
from src.pipeline import img_pipeline

## Step 2: Configure Paths & Device

In [None]:
# Input path
IRL_RAW = Path("datasets/raw/IRL_validation_pictures")

# Output paths
IRL_READY = Path("datasets/ready/IRL_dataset_sam2")
IRL_IMAGES = IRL_READY / "images"
IRL_LABELS = IRL_READY / "labels"

# Intermediate outputs for ball detection+segmentation
BALL_DET_OUTPUT = Path("datasets/preprocessed/irl_balls_sam2/detection")
BALL_SEG_OUTPUT = Path("datasets/preprocessed/irl_balls_sam2/segmentation")
BALL_TXT_OUTPUT = Path("datasets/preprocessed/irl_balls_sam2/labels")

# Create directories
for dir_path in [IRL_IMAGES, IRL_LABELS, BALL_DET_OUTPUT, BALL_SEG_OUTPUT, BALL_TXT_OUTPUT]:
    dir_path.mkdir(parents=True, exist_ok=True)

# Device
DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'

print(f"Device: {DEVICE}")
print(f"Input: {IRL_RAW}")
print(f"Output: {IRL_READY}")


def list_images(directory, suffixes=(".jpg", ".jpeg")):
    """Return the image files directly inside ``directory``, sorted by path.

    The suffix is matched case-insensitively in a single pass. Globbing
    "*.jpg" and "*.JPG" separately would return every file twice on
    case-insensitive filesystems (Windows, default macOS). That would
    double-process images and inflate the statistics. Sorting keeps the
    processing order, and hence any "first image" sample, reproducible.
    A missing directory yields an empty list, as ``Path.glob`` does.

    Args:
        directory: Directory to scan (str or Path); not recursive.
        suffixes: Accepted file extensions, including the leading dot.

    Returns:
        Sorted list of ``Path`` objects.
    """
    wanted = {s.lower() for s in suffixes}
    return sorted(
        p for p in Path(directory).glob("*")
        if p.is_file() and p.suffix.lower() in wanted
    )


# Get image list
img_paths = list_images(IRL_RAW)
print(f"Found {len(img_paths)} images")

## Step 3: Load Models

- **Ball**: Roboflow detector (red-ball-detection-new/1) + **SAM 2.1** segmenter
- **Person**: YOLO-seg pretrained (yolo11n-seg.pt)

In [None]:
# Ball detection + segmentation pipeline (Roboflow + SAM 2.1).
# The Roboflow model ID is hardcoded in RoboflowDetector (exposed as
# DEFAULT_MODEL_ID, printed below); the API key is loaded from .env.
ball_detector = RoboflowDetector()
ball_segmenter = SAMSegmenter()  # Uses SAM 2.1 (sam2.1_b.pt)

# Person model (pretrained YOLO-seg)
PERSON_MODEL_PATH = Path('models/pretrained/yolo11n-seg.pt')
person_model = YOLO(str(PERSON_MODEL_PATH))

# Status lines: the check marks were previously mis-encoded (mojibake).
print(f"✓ Ball detector (Roboflow): {ball_detector.DEFAULT_MODEL_ID}")
print("✓ Ball segmenter: SAM 2.1 (sam2.1_b.pt)")
print(f"✓ Person model: {PERSON_MODEL_PATH}")

## Step 4: Segment Balls (Roboflow → SAM 2.1 → YOLO txt)

Process each image through the detection+segmentation pipeline using SAM 2.1 with bbox prompts.

**Note:** SAM 2.1 is more accurate but slower than FastSAM.

In [None]:
print(f"Processing {len(img_paths)} images for ball segmentation...")
print("Pipeline: Roboflow detection → SAM 2.1 segmentation → YOLO txt")
print()

for img_path in tqdm(img_paths, desc="Ball segmentation (SAM 2.1)"):
    # Step 5 reads BALL_TXT_OUTPUT / "<stem>.txt". Remove any copy left over
    # from a previous run first. If img_pipeline writes no file for an image
    # (presumably when nothing is detected — confirm), a stale label would
    # otherwise be silently reused.
    (BALL_TXT_OUTPUT / (img_path.stem + '.txt')).unlink(missing_ok=True)

    # Detect balls, segment them with bbox prompts and write the per-image
    # detection, segmentation and txt outputs to the three directories.
    img_pipeline(
        img_path,
        detect_fn=ball_detector.detect,
        segment_fn=ball_segmenter.segment_bbox,
        det_output_dir=BALL_DET_OUTPUT,
        seg_output_dir=BALL_SEG_OUTPUT,
        txt_output_dir=BALL_TXT_OUTPUT,
        mode="bbox"
    )

print("✓ Ball segmentation complete!")

## Step 5: Combine Ball + Person Masks

- Parse ball polygons from txt files
- Segment persons with YOLO-seg
- Combine into PNG masks (ball=0, person=1)
- **Ball has priority** over person in overlapping regions

In [None]:
# Configuration
CONF_THRESHOLD = 0.5  # Person confidence threshold

# Pixel values written into the PNG label masks (ball=0, person=1).
# NOTE(review): background used to be 0 as well. Ball pixels were therefore
# indistinguishable from background, and the person pass (which only paints
# "background" pixels) overwrote them, so ball priority never took effect.
# Background now gets its own value. 255 is an assumption: confirm it matches
# what the training loader expects (num_classes / ignore_index).
BALL_ID = 0
PERSON_ID = 1
BACKGROUND_ID = 255


def rasterize_yolo_polygons(txt_path, mask, value):
    """Fill every polygon of a YOLO-seg label file into ``mask`` in place.

    Each line is read as ``class_id x1 y1 x2 y2 ...`` with coordinates
    normalised by image width/height. The class id is ignored, and every
    polygon is painted with ``value``. Lines with fewer than three points or
    an odd number of coordinates are skipped.

    Args:
        txt_path: Path of the label file.
        mask: 2-D uint8 array of shape (H, W), modified in place.
        value: Pixel value to paint.

    Returns:
        Number of polygons drawn.
    """
    h, w = mask.shape
    drawn = 0
    with open(txt_path, 'r', encoding='utf-8') as f:
        for line in f:
            coords = line.split()[1:]
            # Need at least 3 (x, y) pairs; an odd count would leave a
            # dangling x and previously raised IndexError.
            if len(coords) < 6 or len(coords) % 2:
                continue
            # Denormalise: column 0 scales by width, column 1 by height.
            xy = np.array([float(c) for c in coords]).reshape(-1, 2) * (w, h)
            points = np.rint(xy).astype(np.int32)
            cv2.fillPoly(mask, [points], int(value))
            drawn += 1
    return drawn


def paint_person_masks(result, mask, value, background):
    """Paint the instance masks of one ultralytics result into ``mask``.

    Pixels are written only where ``mask`` still holds ``background``, so
    anything painted earlier (balls) keeps priority.

    Args:
        result: One element of the list returned by ``YOLO.predict``.
        mask: 2-D uint8 array of shape (H, W), modified in place.
        value: Pixel value for person pixels.
        background: Pixel value that marks still-unlabelled pixels.

    Returns:
        Number of instance masks processed (0 when ``result.masks`` is None).
    """
    if result.masks is None:
        return 0
    h, w = mask.shape
    count = 0
    for instance in result.masks.data:
        binary = (instance.cpu().numpy() > 0.5).astype(np.uint8)
        # Safety net only: with retina_masks=True the ultralytics docs state
        # that masks already match the original image size.
        if binary.shape != (h, w):
            binary = cv2.resize(binary, (w, h), interpolation=cv2.INTER_NEAREST)
        mask[(mask == background) & (binary == 1)] = value
        count += 1
    return count


# Statistics
stats = {"total": 0, "with_ball": 0, "with_person": 0, "empty": 0}

print("Combining ball + person masks...")
print(f"Confidence threshold (person): {CONF_THRESHOLD}")
print("Class priority: ball > person")
print()

for img_path in tqdm(img_paths, desc="Combining masks"):
    stats["total"] += 1

    # Only the dimensions are needed; close the file handle immediately.
    # NOTE(review): PIL reports the stored size without applying EXIF
    # orientation, whereas OpenCV-based readers (as used by ultralytics) may
    # rotate. For rotated phone photos the two can disagree. Confirm the IRL
    # images carry no orientation tag, or normalise them beforehand.
    with Image.open(img_path) as img:
        w, h = img.size

    # Start from all-background. Balls are painted first so they win overlaps.
    combined_mask = np.full((h, w), BACKGROUND_ID, dtype=np.uint8)

    # --- 1. Ball polygons from the Step 4 txt output (if it exists) ---
    ball_txt_path = BALL_TXT_OUTPUT / (img_path.stem + '.txt')
    n_balls = 0
    if ball_txt_path.exists():
        n_balls = rasterize_yolo_polygons(ball_txt_path, combined_mask, BALL_ID)
    if n_balls:
        # Count only images where at least one polygon was actually drawn.
        stats["with_ball"] += 1

    # --- 2. Segment persons ---
    person_results = person_model.predict(
        str(img_path),
        classes=[0],  # Person class in COCO
        conf=CONF_THRESHOLD,
        device=DEVICE,
        # Masks at original image resolution. The default letterboxed
        # inference-size masks misalign when naively resized.
        retina_masks=True,
        verbose=False,
    )
    n_persons = paint_person_masks(person_results[0], combined_mask, PERSON_ID, BACKGROUND_ID)
    if n_persons:
        stats["with_person"] += 1

    # Track empty images
    if not (n_balls or n_persons):
        stats["empty"] += 1

    # Save mask (even if empty); a 2-D uint8 array is stored as mode "L".
    Image.fromarray(combined_mask).save(IRL_LABELS / (img_path.stem + '.png'))

    # Copy original image
    shutil.copy(img_path, IRL_IMAGES / img_path.name)

print("\n✓ Processing complete!")

## Step 6: Display Statistics

In [None]:
def pct(count, total):
    """Return ``count`` as a percentage of ``total``, or 0.0 when ``total`` is 0."""
    return 100.0 * count / total if total else 0.0


total = stats['total']
print("=" * 60)
print("📊 DATASET STATISTICS (SAM 2.1)")
print("=" * 60)
print(f"Total images processed:    {total}")
print(f"Images with ball(s):       {stats['with_ball']} ({pct(stats['with_ball'], total):.1f}%)")
print(f"Images with person(s):     {stats['with_person']} ({pct(stats['with_person'], total):.1f}%)")
print(f"Images with no detections: {stats['empty']} ({pct(stats['empty'], total):.1f}%)")
print("=" * 60)

# Verify dataset consistency. Besides equal counts, every image must have a
# mask with the same stem and vice versa. This catches stale or mismatched
# files left in the output folders by earlier runs.
image_files = [p for p in IRL_IMAGES.glob("*") if p.is_file()]
label_files = list(IRL_LABELS.glob("*.png"))
num_images = len(image_files)
num_labels = len(label_files)
image_stems = {p.stem for p in image_files}
label_stems = {p.stem for p in label_files}
missing_labels = sorted(image_stems - label_stems)
orphan_labels = sorted(label_stems - image_stems)
consistent = num_images == num_labels and not missing_labels and not orphan_labels

print("\n✓ Dataset consistency check:")
print(f"  Images: {num_images}")
print(f"  Labels: {num_labels}")
print(f"  Match: {'✓ YES' if consistent else '✗ NO'}")
if missing_labels:
    print(f"  Images without a mask (first 10): {missing_labels[:10]}")
if orphan_labels:
    print(f"  Masks without an image (first 10): {orphan_labels[:10]}")

print(f"\n✓ Dataset ready at: {IRL_READY}")
print(f"  - images/  ({num_images} files)")
print(f"  - labels/  ({num_labels} .png masks)")

## Bonus: Visualize SAM 2.1 vs FastSAM Comparison

Compare segmentation quality between the two approaches.

In [None]:
import matplotlib.pyplot as plt

# Labels produced by the FastSAM variant of this notebook (optional).
FASTSAM_LABELS = Path("datasets/ready/IRL_dataset/labels")

# Shared display settings for label masks. Nearest interpolation keeps
# class boundaries crisp instead of blending neighbouring label values.
MASK_STYLE = {"cmap": "tab10", "vmin": 0, "vmax": 9, "interpolation": "nearest"}


def load_array(path):
    """Load an image file into a NumPy array and close the file handle."""
    with Image.open(path) as im:
        return np.array(im)


if not img_paths:
    print("No input images found; nothing to visualize.")
else:
    # Select a sample image
    sample_img_path = img_paths[0]
    print(f"Sample: {sample_img_path.name}")

    img = load_array(sample_img_path)
    sam2_mask = load_array(IRL_LABELS / (sample_img_path.stem + '.png'))
    fastsam_mask_path = FASTSAM_LABELS / (sample_img_path.stem + '.png')

    # (title, data, imshow kwargs) for each panel, left to right.
    panels = [("Original Image", img, {})]
    if fastsam_mask_path.exists():
        panels.append(("FastSAM Mask", load_array(fastsam_mask_path), MASK_STYLE))
    else:
        print("FastSAM masks not found. Run data_preparation_clean.ipynb first for comparison.")
    panels.append(("SAM 2.1 Mask", sam2_mask, MASK_STYLE))

    fig, axes = plt.subplots(1, len(panels), figsize=(6 * len(panels), 6))
    for ax, (title, data, style) in zip(axes, panels):
        ax.imshow(data, **style)
        ax.set_title(title)
        ax.axis('off')

    plt.tight_layout()
    plt.show()