## Step 1: Setup & Imports

In [1]:
from pathlib import Path
import shutil
from PIL import Image
import numpy as np
import cv2
import torch
from tqdm import tqdm

from src.detection import GroundingDINODetector
from src.segmentation import FastSAMSegmenter
from src.pipeline import img_pipeline

/home/tonino/projects/ball segmentation/.venv/lib/python3.12/site-packages/inference/models/utils.py:411: ModelDependencyMissing: Your `inference` configuration does not support SAM3 model. Install SAM3 dependencies and set CORE_MODEL_SAM3_ENABLED to True.


## Step 3: Load Models

- **Detector**: Grounding DINO (text prompts: "red ball" & "human")
- **Segmenter**: FastSAM (bbox-guided segmentation)

In [2]:
# Download Grounding DINO weights if needed
from pathlib import Path
import requests

checkpoint_path = Path("models/pretrained/groundingdino_swint_ogc.pth")
if not checkpoint_path.exists():
    print("Downloading Grounding DINO checkpoint (~693 MB)...")
    checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
    url = "https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth"

    # Stream into a sibling ".part" file and rename it only once the transfer
    # succeeded. Writing straight to checkpoint_path would leave a truncated
    # file behind after an interruption, and the exists() check above would
    # then accept it on the next run.
    partial_path = checkpoint_path.with_name(checkpoint_path.name + ".part")

    # timeout=(connect, read) in seconds so a stalled connection cannot hang the cell
    with requests.get(url, stream=True, timeout=(10, 60)) as response:
        # Fail loudly on 4xx/5xx instead of saving an error page as the checkpoint
        response.raise_for_status()
        total_size = int(response.headers.get('content-length', 0))
        # With a Content-Encoding the decoded byte count differs from Content-Length
        size_is_comparable = total_size > 0 and 'content-encoding' not in response.headers

        downloaded = 0
        with open(partial_path, 'wb') as f:
            for chunk in response.iter_content(chunk_size=1024 * 1024):
                if not chunk:  # skip keep-alive chunks
                    continue
                f.write(chunk)
                downloaded += len(chunk)
                if total_size > 0:
                    percent = (downloaded / total_size) * 100
                    print(f"\rProgress: {percent:.1f}%", end='')

    if size_is_comparable and downloaded != total_size:
        partial_path.unlink(missing_ok=True)
        raise IOError(
            f"Incomplete download: got {downloaded} of {total_size} bytes from {url}"
        )

    # Same directory, so the rename cannot leave a half-written checkpoint
    partial_path.replace(checkpoint_path)
    print(f"\nâœ“ Downloaded to {checkpoint_path}")
else:
    print(f"âœ“ Checkpoint found: {checkpoint_path}")

âœ“ Checkpoint found: models/pretrained/groundingdino_swint_ogc.pth


In [3]:
# Use the first CUDA device when one is available, otherwise run on the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda:0'
else:
    DEVICE = 'cpu'

In [4]:
# One Grounding DINO detector serves both classes (balls AND persons); the
# class is chosen later through the text prompt passed to detect().
# Thresholds are deliberately low to match the online demo behavior.
# NOTE(review): the notebook text mentions FastSAM, but SAMSegmenter is what
# gets instantiated here — confirm which segmenter is intended.
from src.segmentation import SAMSegmenter

detector = GroundingDINODetector(
    device=DEVICE,
    model_checkpoint_path=str(checkpoint_path),
    box_threshold=0.20,
    text_threshold=0.15,
)
segmenter = SAMSegmenter()

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


final text_encoder_type: bert-base-uncased


## Step 2: Configure Paths & Device

#### Use this one for validation

In [5]:
# Raw input image folders

IRL_RAW = Path("datasets") / "raw" / "IRL_validation_pictures"
yahoo_balls_raw = Path("datasets") / "cleaned" / "red_balls_human_yahoo_jpg"

class DatasetPaths:
    """Manages intermediate preprocessing paths (detection, segmentation, labels).

    Args:
        dataset_name: Name of the sub-folder created under ``base``.
        base: Parent directory of all preprocessed datasets. Accepts a ``Path``
            or a plain string.
        create: When true, create ``base / dataset_name`` (and missing parents).
            The ``detection``/``segmentation``/``labels`` sub-folders are only
            computed here, not created.

    Attributes:
        root: ``base / dataset_name``.
        det_path: ``root / "detection"``.
        seg_path: ``root / "segmentation"``.
        label_path: ``root / "labels"``.
    """
    def __init__(self, dataset_name: str, base=Path("datasets/preprocessed"), create=True):
        # Coerce to Path so a string base works with the "/" operator
        self.root = Path(base) / dataset_name
        if create:
            self.root.mkdir(exist_ok=True, parents=True)
        self.det_path = self.root / "detection"
        self.seg_path = self.root / "segmentation"
        self.label_path = self.root / "labels"

class ReadyDatasetPaths:
    """Manages final dataset paths (images, labels).

    Args:
        dataset_name: Name of the sub-folder created under ``base``.
        base: Parent directory of all ready datasets. Accepts a ``Path`` or a
            plain string.
        create: When true, create the dataset root and its ``images`` and
            ``labels`` sub-folders; when false, only compute the paths.

    Attributes:
        root: ``base / dataset_name``.
        images: ``root / "images"``.
        labels: ``root / "labels"``.
    """
    def __init__(self, dataset_name: str, base=Path("datasets/ready"), create=True):
        # Coerce to Path so a string base works with the "/" operator
        self.root = Path(base) / dataset_name
        self.images = self.root / "images"
        self.labels = self.root / "labels"

        if create:
            # parents=True also creates the dataset root itself
            for folder in (self.images, self.labels):
                folder.mkdir(exist_ok=True, parents=True)

# Intermediate detection + segmentation outputs, one folder per class
irl_balls_dataset = DatasetPaths(dataset_name="irl_balls")
irl_humans_dataset = DatasetPaths(dataset_name="irl_persons")

# Final dataset (images + labels) assembled from the two folders above
irl_ready = ReadyDatasetPaths(dataset_name="IRL_dataset")

In [6]:
# Intermediate and final locations for the yahoo red-ball images
yahoo_balls_dataset = DatasetPaths(dataset_name="red_balls_human_yahoo")
yahoo_balls_ready = ReadyDatasetPaths(dataset_name="yahoo_balls_dataset")

#### Use this one for mass training data

In [7]:
# Get image list
#
# JPEG extensions are matched case-insensitively in a single directory scan.
# Globbing "*.jpg" and "*.JPG" separately lists every file twice on platforms
# where pathlib globbing is case-insensitive (Windows). Sorting makes the
# processing order reproducible between runs.
irl_img_paths = sorted(
    p for p in IRL_RAW.glob("*") if p.suffix.lower() in (".jpg", ".jpeg")
)

yahoo_balls_img_paths = sorted(yahoo_balls_raw.glob("*.jpg"))
len(irl_img_paths), len(yahoo_balls_img_paths)

(34, 598)

## Step 4: Initialize Models & Test Detection

## Step 4: Segment Balls & Persons (Grounding DINO → FastSAM → YOLO txt)

Process each image through the detection+segmentation pipeline with text prompts:
- **Balls**: prompt="red ball"
- **Persons**: prompt="human"

Output: YOLO polygon format in separate txt folders

In [8]:
# Images and intermediate output folders used by the ball segmentation run
raw_images, dataset_to_use = yahoo_balls_img_paths, yahoo_balls_dataset

In [9]:
def _detect_red_ball(image_path):
    """Run the detector on one image with the "red ball" text prompt."""
    return detector.detect(image_path, text_prompt="red ball")


# Detect, segment and export labels for every selected image
ball_progress = tqdm(raw_images, desc="Ball segmentation")
for img_path in ball_progress:
    img_pipeline(
        img_path,
        detect_fn=_detect_red_ball,
        segment_fn=segmenter.segment_bbox,
        det_output_dir=dataset_to_use.det_path,
        seg_output_dir=dataset_to_use.seg_path,
        txt_output_dir=dataset_to_use.label_path,
        empty_dir="empty_detections",
        mode="bbox",
    )

  return fn(*args, **kwargs)
  with torch.cuda.amp.autocast(enabled=False):
Ball segmentation: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 598/598 [23:20<00:00,  2.34s/it]


In [None]:

def _detect_human(image_path):
    """Run the detector on one image with the "human" text prompt."""
    return detector.detect(image_path, text_prompt="human")


# Same pipeline as for balls, writing into the person folders of the IRL set
person_progress = tqdm(irl_img_paths, desc="Person segmentation")
for img_path in person_progress:
    img_pipeline(
        img_path,
        detect_fn=_detect_human,
        segment_fn=segmenter.segment_bbox,
        det_output_dir=irl_humans_dataset.det_path,
        seg_output_dir=irl_humans_dataset.seg_path,
        txt_output_dir=irl_humans_dataset.label_path,
        empty_dir="empty_detections",
        mode="bbox",
    )

print("\nâœ“ Segmentation complete!")

## Step 5: Combine Ball + Person Masks

- Parse ball polygons from txt files (class 0)
- Parse person polygons from txt files (class 1)
- Combine into PNG masks (ball=0, person=1)
- **Ball has priority** over person in overlapping regions

In [None]:
# Pixel values written to the combined PNG mask. Class ids mirror the YOLO txt
# files (ball=0, person=1). Background needs a value of its own: with a
# zero-initialised mask, filling ball polygons with 0 changed nothing, so balls
# were indistinguishable from background and got overwritten by person pixels.
BALL_VALUE = 0
PERSON_VALUE = 1
BACKGROUND_VALUE = 255  # NOTE(review): confirm the background value the training code expects


def _load_polygons(txt_path, width, height):
    """Read YOLO polygon labels and return them in pixel coordinates.

    Each line is parsed as ``class_id x1 y1 x2 y2 ...`` with coordinates
    normalised by the image size. Lines with fewer than three points or with
    an odd number of coordinates are skipped.

    Args:
        txt_path: Path of the label file; it may not exist.
        width: Image width in pixels.
        height: Image height in pixels.

    Returns:
        A list of ``(N, 2)`` int32 arrays, one per polygon. The list is empty
        when the file is missing or holds no usable polygon.
    """
    if not txt_path.exists():
        return []

    polygons = []
    with open(txt_path, 'r') as f:
        for line in f:
            coords = line.split()[1:]  # drop the leading class id
            if len(coords) < 6 or len(coords) % 2:
                continue
            values = [float(c) for c in coords]
            points = [
                [int(x * width), int(y * height)]
                for x, y in zip(values[0::2], values[1::2])
            ]
            polygons.append(np.array(points, dtype=np.int32))
    return polygons


# Statistics
stats = {"total": 0, "with_ball": 0, "with_person": 0, "empty": 0}

print("Combining ball + person masks...")
print("Class priority: ball > person")
print()

for img_path in tqdm(irl_img_paths, desc="Combining masks"):
    stats["total"] += 1

    # Only the dimensions are needed; close the file handle right away
    with Image.open(img_path) as img:
        w, h = img.size

    label_name = img_path.stem + '.txt'
    ball_polygons = _load_polygons(irl_balls_dataset.label_path / label_name, w, h)
    person_polygons = _load_polygons(irl_humans_dataset.label_path / label_name, w, h)

    # Start from an all-background mask
    combined_mask = np.full((h, w), BACKGROUND_VALUE, dtype=np.uint8)

    # Persons first, balls last: painting balls on top gives them priority in
    # overlapping regions. Polygons are filled one at a time because a single
    # fillPoly call with several overlapping contours can leave the overlap
    # unfilled.
    for polygon in person_polygons:
        cv2.fillPoly(combined_mask, [polygon], PERSON_VALUE)
    for polygon in ball_polygons:
        cv2.fillPoly(combined_mask, [polygon], BALL_VALUE)

    # Count an image only when a usable polygon was drawn, not merely because
    # a (possibly empty) label file exists
    if ball_polygons:
        stats["with_ball"] += 1
    if person_polygons:
        stats["with_person"] += 1
    if not (ball_polygons or person_polygons):
        stats["empty"] += 1

    # Save mask (even if empty); a 2-D uint8 array is stored as 8-bit greyscale
    mask_img = Image.fromarray(combined_mask)
    mask_img.save(irl_ready.labels / (img_path.stem + '.png'))

    # Copy original image
    shutil.copy(img_path, irl_ready.images / img_path.name)

print("\nâœ“ Processing complete!")

## Step 6: Display Statistics

In [None]:
def _percent_of_total(count):
    """Return ``count`` as a percentage of processed images (0.0 when none were processed)."""
    total = stats['total']
    # Guard against ZeroDivisionError when the image list was empty
    return count / total * 100 if total else 0.0


print("=" * 60)
print("ðŸ“Š DATASET STATISTICS")
print("=" * 60)
print(f"Total images processed:    {stats['total']}")
print(f"Images with ball(s):       {stats['with_ball']} ({_percent_of_total(stats['with_ball']):.1f}%)")
print(f"Images with person(s):     {stats['with_person']} ({_percent_of_total(stats['with_person']):.1f}%)")
print(f"Images with no detections: {stats['empty']} ({_percent_of_total(stats['empty']):.1f}%)")
print("=" * 60)

# Verify dataset consistency: every copied image should have one PNG mask
num_images = len(list(irl_ready.images.glob("*")))
num_labels = len(list(irl_ready.labels.glob("*.png")))

print("\nâœ“ Dataset consistency check:")
print(f"  Images: {num_images}")
print(f"  Labels: {num_labels}")
print(f"  Match: {'âœ“ YES' if num_images == num_labels else 'âœ— NO'}")

print(f"\nâœ“ Dataset ready at: {irl_ready.root}")
print(f"  - images/  ({num_images} files)")
print(f"  - labels/  ({num_labels} .png masks)")