## Step 1: Setup & Imports

In [1]:
from pathlib import Path
import shutil
from PIL import Image
import numpy as np
import cv2
import torch
from tqdm import tqdm
from src.detection import GroundingDINODetector
from src.segmentation import FastSAMSegmenter
from src.pipeline import img_pipeline

/home/tonino/projects/ball segmentation/.venv/lib/python3.12/site-packages/inference/models/utils.py:411: ModelDependencyMissing: Your `inference` configuration does not support SAM3 model. Install SAM3 dependencies and set CORE_MODEL_SAM3_ENABLED to True.


## Step 3: Load Models

- **Detector**: Grounding DINO (text prompts: "red ball" & "human")
- **Segmenter**: SAM via `SAMSegmenter` (bbox-guided segmentation)
- **Device**: `cuda:0` if a CUDA GPU is available, otherwise CPU

In [2]:
# Download Grounding DINO weights if needed.
# The checkpoint is streamed into a sibling ``.part`` file and only renamed to
# its final name once the whole body has been written. An interrupted or failed
# download therefore never leaves a truncated ``.pth`` behind that the
# ``exists()`` check below would later accept as a valid checkpoint.
import requests

CHECKPOINT_URL = "https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth"
checkpoint_path = Path("models/pretrained/groundingdino_swint_ogc.pth")

if not checkpoint_path.exists():
    print("Downloading Grounding DINO checkpoint (~693 MB)...")
    checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
    partial_path = checkpoint_path.with_name(checkpoint_path.name + ".part")
    try:
        # `timeout` bounds connect/read stalls; `raise_for_status` prevents an
        # HTTP error page (e.g. 404) from being saved as the checkpoint.
        with requests.get(CHECKPOINT_URL, stream=True, timeout=60) as response:
            response.raise_for_status()
            total_size = int(response.headers.get('content-length', 0))
            with open(partial_path, 'wb') as f:
                downloaded = 0
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
                    downloaded += len(chunk)
                    if total_size > 0:
                        percent = (downloaded / total_size) * 100
                        print(f"\rProgress: {percent:.1f}%", end='')
    except BaseException:
        # Also catches KeyboardInterrupt (stopping the cell): discard the
        # partial file so the next run retries instead of trusting it.
        partial_path.unlink(missing_ok=True)
        raise
    # Atomic on the same filesystem: the final path only ever holds a full file.
    partial_path.replace(checkpoint_path)
    print(f"\n✓ Downloaded to {checkpoint_path}")
else:
    print(f"✓ Checkpoint found: {checkpoint_path}")

✓ Checkpoint found: models/pretrained/groundingdino_swint_ogc.pth


In [3]:
# Use the first CUDA GPU when torch can see one; fall back to the CPU otherwise.
if torch.cuda.is_available():
    DEVICE = 'cuda:0'
else:
    DEVICE = 'cpu'

In [4]:
# Grounding DINO detector: a single detector handles both the "red ball" and the
# "human" prompts. The SAM segmenter below is used through `segment_bbox`
# (bbox-guided segmentation).
from src.segmentation import SAMSegmenter

# NOTE: 0.5 / 0.5 is *stricter* than the Grounding DINO reference defaults
# (box 0.35 / text 0.25): fewer false-positive boxes at the cost of recall.
# Lower these if objects are being missed.
BOX_THRESHOLD = 0.5
TEXT_THRESHOLD = 0.5

detector = GroundingDINODetector(
    model_checkpoint_path=str(checkpoint_path),
    box_threshold=BOX_THRESHOLD,
    text_threshold=TEXT_THRESHOLD,
    device=DEVICE,
)

# NOTE(review): SAMSegmenter() is not given DEVICE — confirm it picks the same
# device as the detector.
segmenter = SAMSegmenter()

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


final text_encoder_type: bert-base-uncased


## Step 2: Configure Paths

#### Input folders — `IRL_RAW` holds the real-life photos used for validation

In [5]:
# Raw input folders (relative to the notebook's working directory).
yahoo_balls_raw: Path = Path("datasets/cleaned/red_balls_human_yahoo_jpg")  # cleaned Yahoo red-ball / human images
IRL_RAW: Path = Path("datasets/raw/IRL_validation_pictures")  # real-life validation photos

class DatasetPaths:
    """Directory layout for one preprocessed dataset, with a ``ready/`` subfolder for final outputs.

    Layout under ``base / dataset_name``::

        detection/       detection visualizations
        segmentation/    segmentation visualizations
        empty/           extra intermediate output folder
        ready/images/    final dataset images
        ready/labels/    final dataset labels

    Args:
        dataset_name: Name of the dataset subfolder under ``base``.
        base: Parent directory of all preprocessed datasets (``Path`` or ``str``).
        create: If True, create every directory. Existing directories are
            left untouched, so re-running is safe.

    Attributes:
        root: ``base / dataset_name``.
        det_path, seg_path, empty_path: intermediate output folders.
        ready_path, images_path, label_path: final dataset folders.
    """

    def __init__(self, dataset_name: str, base=Path("datasets/preprocessed"), create=True):
        # Path(...) lets callers pass a plain string as well.
        project = Path(base) / dataset_name
        self.root = project

        # Intermediate outputs (visualizations)
        self.det_path = project / "detection"
        self.seg_path = project / "segmentation"
        self.empty_path = project / "empty"

        # Ready folder with images and labels
        self.ready_path = project / "ready"
        self.images_path = self.ready_path / "images"
        self.label_path = self.ready_path / "labels"

        if create:
            # parents=True also creates `root` and `ready/`, so only the leaf
            # directories need to be listed here.
            for path in (self.det_path, self.seg_path, self.empty_path,
                         self.images_path, self.label_path):
                path.mkdir(parents=True, exist_ok=True)

class ReadyDatasetPaths:
    """Final dataset layout (``images/``, ``labels/``) under ``base / dataset_name``.

    DEPRECATED: use ``DatasetPaths(...).ready_path`` (with its ``images_path`` /
    ``label_path``) instead. Instantiating this class emits a ``DeprecationWarning``.

    Args:
        dataset_name: Name of the dataset subfolder under ``base``.
        base: Parent directory of ready datasets (``Path`` or ``str``).
        create: If True, create ``root``, ``images`` and ``labels`` (idempotent).
    """

    def __init__(self, dataset_name: str, base=Path("datasets/ready"), create=True):
        import warnings

        # stacklevel=2 points the warning at the caller, not at this line.
        warnings.warn(
            "ReadyDatasetPaths is deprecated; use DatasetPaths(...).ready_path "
            "(images_path / label_path) instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        self.root = Path(base) / dataset_name
        self.images = self.root / "images"
        self.labels = self.root / "labels"
        if create:
            # parents=True on the leaves also creates `root`.
            self.images.mkdir(parents=True, exist_ok=True)
            self.labels.mkdir(parents=True, exist_ok=True)

# Preprocessing folders (detection/, segmentation/, empty/, ready/images, ready/labels)
# for the Yahoo red-ball + human set; the directories are created right away.
yahoo_human_balls_dataset = DatasetPaths(dataset_name="yahoo_human_balls", create=True)

#### Image lists — the Yahoo set is the one used for mass training data

In [6]:
# Get image lists.
# Extensions are matched case-insensitively on the file suffix. Each file is
# therefore listed exactly once, even on case-insensitive filesystems, where
# globbing both "*.jpg" and "*.JPG" returns the same file twice. Mixed-case
# suffixes like ".Jpg" are also caught. Results are sorted so the processing
# order is reproducible across runs and machines.
IMAGE_SUFFIXES = (".jpg", ".jpeg")


def list_images(folder: Path, suffixes=IMAGE_SUFFIXES) -> list[Path]:
    """Return the files directly inside `folder` whose lower-cased suffix is in `suffixes`, sorted.

    A missing `folder` yields an empty list (same as `Path.glob`).
    """
    wanted = {s.lower() for s in suffixes}
    return sorted(p for p in folder.glob("*") if p.is_file() and p.suffix.lower() in wanted)


irl_img_paths = list_images(IRL_RAW)
# Only ".jpg" files were taken from the Yahoo folder before; keep that (now case-insensitive).
yahoo_balls_img_paths = list_images(yahoo_balls_raw, suffixes=(".jpg",))
len(irl_img_paths), len(yahoo_balls_img_paths)

(34, 598)

## Step 4: Initialize Models & Test Detection

## Step 4: Process Images (Multi-Class Pipeline)

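Process each image through the unified detection + segmentation pipeline using a multi-class text prompt:
- **Prompt format**: `"red ball . human"` (class phrases separated by `' . '`)
- **Output structure** (under `datasets/preprocessed/<dataset_name>/`):
  - `detection/` - All detected objects with colored bboxes
  - `segmentation/` - All segmented objects with colored masks
  - `empty/` - Passed to the pipeline as `empty_dir` (presumably images with no detections)
  - `ready/images/` - Images of the final dataset
  - `ready/labels/{label_name}/` - YOLO polygon format, one subdirectory per label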

In [None]:
# Pick the image list and its matching output layout together so they cannot drift apart.
raw_images, dataset_to_use = yahoo_balls_img_paths, yahoo_human_balls_dataset

# Multi-class Grounding DINO prompt: one phrase per class, separated by " . ".
text_prompt = "red ball . human"

In [8]:
# Process all images with the multi-class pipeline.


def detect_all_labels(img_path):
    """Run the Grounding DINO detector on one image with the multi-class prompt, grouped by label."""
    return detector.detect(img_path, text_prompt=text_prompt, return_all_by_label=True)


for img_path in tqdm(raw_images, desc="Processing images"):
    img_pipeline(
        img_path,
        detect_fn=detect_all_labels,
        segment_fn=segmenter.segment_bbox,
        det_output_dir=dataset_to_use.det_path,
        seg_output_dir=dataset_to_use.seg_path,
        txt_output_dir=dataset_to_use.label_path,
        empty_dir=dataset_to_use.empty_path,
        images_output_dir=dataset_to_use.images_path,
    )

print("\n✓ Multi-class segmentation complete!")
print(f"  - Detection visualizations: {dataset_to_use.det_path}")
print(f"  - Segmentation visualizations: {dataset_to_use.seg_path}")
print(f"  - Ready dataset: {dataset_to_use.ready_path}")
print(f"    - Images: {dataset_to_use.images_path}")
print("    - Labels by class:")

# Per-class stats: one label subdirectory per prompt phrase, named by the
# lower-cased phrase (presumably the same key img_pipeline uses — verify).
prompts = [p.strip() for p in text_prompt.split('.') if p.strip()]
for prompt in prompts:
    label_key = prompt.lower()
    label_subdir = dataset_to_use.label_path / label_key
    # A class with no label directory is reported as 0 instead of being
    # silently omitted (glob on a missing directory yields nothing).
    num_files = sum(1 for _ in label_subdir.glob("*.txt"))
    print(f"      - {label_key}: {num_files} files")

  return fn(*args, **kwargs)
  with torch.cuda.amp.autocast(enabled=False):
Processing images: 100%|██████████| 598/598 [10:28<00:00,  1.05s/it]


✓ Multi-class segmentation complete!
  - Detection visualizations: datasets/preprocessed/yahoo_human_balls/detection
  - Segmentation visualizations: datasets/preprocessed/yahoo_human_balls/segmentation
  - Ready dataset: datasets/preprocessed/yahoo_human_balls/ready
    - Images: datasets/preprocessed/yahoo_human_balls/ready/images
    - Labels by class:
      - red ball: 456 files
      - human: 457 files



