In [None]:
from pathlib import Path
import shutil
import hashlib
import random
import cv2
import numpy as np
import torch
from PIL import Image, ImageOps
from tqdm.notebook import tqdm

# Import project modules
from src.detection import GroundingDINODetector
from src.segmentation import SAMSegmenter

# Seed Python's `random` module so that every later use of it in this notebook
# is repeatable. The value lives in a named constant so it can be changed in one place.
SEED = 42
random.seed(SEED)

In [None]:
# --- Configuration ---

# Source directory containing new images
SOURCE_DIR = Path("datasets/raw/red balls google dated")

# Target Project Directory
PROJECT_DIR = Path("datasets/ready/full_dataset")

# Class Mapping (detector label text -> YOLO class id)
CLASS_MAP = {
    "red ball": 0,
    "human": 1,
    "trashcan": 2
}

# Prefer the first CUDA device when torch can see one, otherwise fall back to the CPU.
DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'


def count_source_images(root: Path) -> int:
    """Count the candidate image files found recursively below ``root``.

    Directories are not counted, and files whose suffix is ``.Identifier``
    (compared case-insensitively) are ignored.

    Args:
        root: Directory to search.

    Returns:
        Number of matching files; 0 when ``root`` is not a directory.
    """
    if not root.is_dir():
        return 0
    # The suffix is lower-cased before the comparison, so the excluded suffix
    # has to be written in lower case too or the filter can never match.
    return sum(
        1
        for candidate in root.rglob("*")
        if candidate.is_file() and candidate.suffix.lower() != '.identifier'
    )


# Count source images (recursive)
source_count = count_source_images(SOURCE_DIR)

print(f"Device: {DEVICE}")
print(f"Source: {SOURCE_DIR} ({source_count} images found recursively)")
print(f"Project: {PROJECT_DIR}")

In [None]:
# Thresholds handed to the GroundingDINO detector wrapper.
BOX_THRESHOLD = 0.116
TEXT_THRESHOLD = 0.2

# Text prompt: every class name from CLASS_MAP, separated by " . ".
PROMPT = " . ".join(CLASS_MAP)

detector = GroundingDINODetector(
    device=DEVICE,
    box_threshold=BOX_THRESHOLD,
    text_threshold=TEXT_THRESHOLD,
)

# NOTE(review): the segmenter is built with its defaults; unlike the detector it is
# not told which device to use — confirm that SAMSegmenter picks a sensible one.
segmenter = SAMSegmenter()

In [None]:
def get_image_hash(file_path: Path) -> str:
    """Return the MD5 hex digest of a file's raw bytes.

    The digest is computed over the file exactly as stored on disk (not over
    decoded pixels) and is used as the deduplication key.

    Args:
        file_path: File to fingerprint.

    Returns:
        The 32-character lowercase hexadecimal MD5 digest.
    """
    with open(file_path, 'rb') as stream:
        return hashlib.md5(stream.read()).hexdigest()

def mask_to_polygon(mask: np.ndarray) -> list[float]:
    """Turn a binary mask into a flat, normalized polygon.

    Only the external contour with the largest area is kept. It is simplified
    with a tolerance of 0.5% of its perimeter, and each vertex is divided by the
    mask width/height.

    Args:
        mask: 2-D mask array (its shape is unpacked as ``(height, width)``).

    Returns:
        ``[x0, y0, x1, y1, ...]`` with coordinates relative to the mask size,
        or an empty list when no contour is found or the simplified contour
        has fewer than three vertices.
    """
    outlines, _hierarchy = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if len(outlines) == 0:
        return []

    # Keep the biggest outline and reduce its vertex count.
    largest = max(outlines, key=cv2.contourArea)
    tolerance = 0.005 * cv2.arcLength(largest, True)
    simplified = cv2.approxPolyDP(largest, tolerance, True)

    # Fewer than three vertices cannot form a polygon.
    if len(simplified) < 3:
        return []

    height, width = mask.shape
    # approxPolyDP output is indexed as [vertex][0] -> (x, y); flatten to x0, y0, x1, y1, ...
    return [
        value
        for px, py in simplified[:, 0]
        for value in (px / width, py / height)
    ]

def save_yolo_label(file_path: Path, labels: list[tuple[int, list[float]]]):
    """Write one label row per entry to ``file_path``, replacing any existing file.

    Each row is the class id followed by the coordinates formatted with six
    decimals, all separated by single spaces. An empty ``labels`` list produces
    an empty file.

    Args:
        file_path: Destination text file.
        labels: ``(class_id, points)`` pairs, where ``points`` is a flat list
            of coordinates.
    """
    rows = (
        "{} {}\n".format(class_id, " ".join(format(value, ".6f") for value in points))
        for class_id, points in labels
    )
    with open(file_path, 'w') as handle:
        handle.writelines(rows)

In [None]:
# --- Initialize Validation Set (One-time setup) ---
# If the project's val/images folder is missing or empty, it is filled from the
# fixed validation dataset. A populated folder is left untouched.

VAL_DIR = PROJECT_DIR / "val"
VAL_IMAGES_DIR = VAL_DIR / "images"
VAL_LABELS_DIR = VAL_DIR / "labels"

# Path to the fixed validation source dataset
FIXED_VAL_SOURCE = Path("datasets/ready/fixed_val")

if VAL_IMAGES_DIR.exists() and any(VAL_IMAGES_DIR.iterdir()):
    # Something is already there: report it and leave it alone.
    existing_val_total = len(list(VAL_IMAGES_DIR.glob('*')))
    print(f"✓ Validation set already exists ({existing_val_total} images). Preserving it.")
else:
    print(f"Validation set not found in {PROJECT_DIR}. Initializing from {FIXED_VAL_SOURCE}...")

    if not FIXED_VAL_SOURCE.exists():
        print(f"Source {FIXED_VAL_SOURCE} does not exist! Validation set is empty.")
    else:
        # Create directories
        for target_dir in (VAL_IMAGES_DIR, VAL_LABELS_DIR):
            target_dir.mkdir(parents=True, exist_ok=True)

        # Copy images (the printed count covers every entry found, the copy only
        # takes .jpg/.jpeg/.png files)
        src_images = list((FIXED_VAL_SOURCE / "images").glob("*"))
        print(f"Copying {len(src_images)} images...")
        for img_path in src_images:
            if img_path.suffix.lower() not in ('.jpg', '.jpeg', '.png'):
                continue
            shutil.copy(img_path, VAL_IMAGES_DIR / img_path.name)

        # Copy labels
        src_labels = list((FIXED_VAL_SOURCE / "labels").glob("*.txt"))
        print(f"Copying {len(src_labels)} labels...")
        for lbl_path in src_labels:
            shutil.copy(lbl_path, VAL_LABELS_DIR / lbl_path.name)

        print("✓ Validation set initialized.")

In [None]:
# --- Load Existing Project State ---
# Fingerprint every image already stored in the project so new source files
# can be compared against them.
# NOTE(review): these are hashes of the bytes stored in the project. Images added
# by the processing cell are re-saved as JPEG, so their stored bytes will generally
# not hash to the same value as the source file they came from.

splits = ['train', 'val', 'test']
existing_hashes = set()
existing_files_count = 0

print("Scanning existing project for duplicates...")

for split in splits:
    img_dir = PROJECT_DIR / split / "images"
    if not img_dir.exists():
        continue

    for img_path in img_dir.glob("*"):
        # Only files with a known image extension are fingerprinted.
        if img_path.suffix.lower() not in ('.jpg', '.jpeg', '.png'):
            continue
        img_hash = get_image_hash(img_path)
        existing_hashes.add(img_hash)
        existing_files_count += 1

print(f"Found {existing_files_count} existing images in project.")

val_entry_total = len(list((PROJECT_DIR / 'val' / 'images').glob('*')))
print(f"Validation set is protected (contains {val_entry_total} images).")

In [None]:
# --- Main Processing Loop ---
# For every new source image: skip duplicates, detect objects, segment each
# detection, and store the image plus its YOLO polygon labels in train or test.

# Create directories if they don't exist
for split in ['train', 'test']:
    (PROJECT_DIR / split / "images").mkdir(parents=True, exist_ok=True)
    (PROJECT_DIR / split / "labels").mkdir(parents=True, exist_ok=True)

# Get source images
source_images = []
if SOURCE_DIR.exists():
    # Recursive search. Keep regular files only (rglob also yields directories) and
    # drop "*.Identifier" files; the suffix is lower-cased before comparing, so the
    # excluded suffix must be lower case as well. Sorting gives a stable processing
    # order, which keeps the seeded train/test assignment repeatable.
    source_images = sorted(
        p for p in SOURCE_DIR.rglob("*")
        if p.is_file() and p.suffix.lower() != '.identifier'
    )

print(f"Found {len(source_images)} images in source directory (recursive search).")

added_count = 0
skipped_count = 0
error_count = 0

# Post-detection confidence threshold for humans (higher threshold)
HUMAN_CONF_THRESHOLD = 0.5

for img_path in tqdm(source_images, desc="Processing Images"):
    try:
        # 1. Check Deduplication
        img_hash = get_image_hash(img_path)
        target_name = f"{img_hash}.jpg" # Use hash as filename to avoid collisions

        # Imported images are re-encoded on save, so the bytes stored in the project
        # generally no longer hash to the source file's digest. The hash-based file
        # name is therefore what identifies a source that an earlier run imported.
        already_imported = any(
            (PROJECT_DIR / existing_split / "images" / target_name).exists()
            for existing_split in ('train', 'val', 'test')
        )
        if img_hash in existing_hashes or already_imported:
            skipped_count += 1
            continue
            
        # 2. Load and Convert Image
        try:
            pil_img = Image.open(img_path)
            pil_img = ImageOps.exif_transpose(pil_img) # Fix rotation
            
            # Convert to RGB (handle RGBA/P modes)
            if pil_img.mode != 'RGB':
                pil_img = pil_img.convert('RGB')
                
            img_np = np.array(pil_img)
            h, w = img_np.shape[:2]
        except Exception as e:
            print(f"Error loading {img_path.name}: {e}")
            error_count += 1
            continue

        # 3. Run Detection (GroundingDINO)
        # NOTE(review): the detector is given the file path, while the segmenter below
        # gets the EXIF-transposed PIL image. If the detector does not apply the EXIF
        # orientation itself, boxes will not line up for rotated photos — confirm.
        detections = detector.detect(img_path, PROMPT, return_all_by_label=True)
        
        yolo_labels = [] # List of (class_id, points)
        
        # 4. Run Segmentation (SAM2) & Format Labels
        for label_name, dets in detections.items():
            # Labels outside CLASS_MAP are ignored.
            if label_name not in CLASS_MAP:
                continue
                
            class_id = CLASS_MAP[label_name]
            
            for det in dets:
                # Post-detection filtering: Apply stricter threshold for humans
                if label_name == "human" and det['confidence'] < HUMAN_CONF_THRESHOLD:
                    continue
                
                bbox = det['bbox'] # [x1, y1, x2, y2]
                
                # Segment
                masks_result = segmenter.segment_bbox(pil_img, bbox)
                
                if masks_result and masks_result[0].masks is not None:
                    # Get the mask (SAM returns multiple, usually take the first/best)
                    # masks.data is (N, H, W)
                    mask = masks_result[0].masks.data[0].cpu().numpy().astype(np.uint8)
                    
                    # Convert to Polygon
                    points = mask_to_polygon(mask)
                    
                    if points:
                        yolo_labels.append((class_id, points))

        # 5. Save to Project
        # Determine split (Train vs Test) - Val is excluded for new data
        # 80% Train, 20% Test
        split = 'train' if random.random() < 0.8 else 'test'
        
        target_img_path = PROJECT_DIR / split / "images" / target_name
        target_lbl_path = PROJECT_DIR / split / "labels" / f"{img_hash}.txt"
        
        # Save Image (as JPG)
        pil_img.save(target_img_path, quality=95)
        
        # Save Labels (an image without detections gets an empty label file)
        save_yolo_label(target_lbl_path, yolo_labels)
        
        # Update state so identical source files later in this run are skipped
        existing_hashes.add(img_hash)
        added_count += 1
        
    except Exception as e:
        # Best effort: report the failure and carry on with the next image.
        print(f"Failed to process {img_path.name}: {e}")
        error_count += 1
        

print("\n" + "="*40)
print("PROCESSING COMPLETE")
print("="*40)
print(f"Added:   {added_count}")
print(f"Skipped: {skipped_count} (Duplicates)")
print(f"Errors:  {error_count}")
print(f"Total in Project: {len(existing_hashes)}")

In [None]:
# --- Generate Previews ---

PREVIEW_DIR = PROJECT_DIR / "preview"
# parents=True matches the other mkdir calls in this notebook and lets this cell
# run even when the project directory has not been created yet.
PREVIEW_DIR.mkdir(parents=True, exist_ok=True)

print(f"Generating previews in {PREVIEW_DIR}...")

# Colors for classes (BGR); the drawing code picks one by class id modulo the list length
COLORS = [(0, 0, 255), (0, 255, 0), (255, 0, 0)] # BGR

def draw_yolo_labels(img_path, label_path, output_path):
    """Draw the YOLO labels of one image and write the annotated copy to disk.

    Each label row is ``class_id`` followed by normalized coordinates. Rows with
    three or more ``x y`` pairs are drawn as closed polygons. Rows with exactly
    four values are treated as a detection box (``x_center y_center width height``).
    Any other row shape is skipped. When the label file does not exist the image
    is written unchanged.

    Args:
        img_path: Path of the image to annotate.
        label_path: Path of the matching label text file (may be missing).
        output_path: Where the annotated image is written.

    Returns:
        None. Nothing is written when the image is missing or cannot be decoded.
    """
    if not img_path.exists():
        return

    img = cv2.imread(str(img_path))
    if img is None:
        # cv2.imread reports an unreadable file by returning None rather than raising.
        print(f"Could not read {img_path.name} for preview; skipping.")
        return
    h, w = img.shape[:2]

    if label_path.exists():
        with open(label_path, 'r') as f:
            lines = f.readlines()

        for line in lines:
            parts = line.strip().split()
            if len(parts) < 2:
                continue

            class_id = int(parts[0])
            coords = [float(p) for p in parts[1:]]

            # Denormalize
            if len(coords) == 4:
                # Detection-style row: centre + size -> the four box corners.
                xc, yc, bw, bh = coords
                x1, y1 = int((xc - bw / 2) * w), int((yc - bh / 2) * h)
                x2, y2 = int((xc + bw / 2) * w), int((yc + bh / 2) * h)
                points = [[x1, y1], [x2, y1], [x2, y2], [x1, y2]]
            elif len(coords) >= 6 and len(coords) % 2 == 0:
                points = [
                    [int(coords[i] * w), int(coords[i + 1] * h)]
                    for i in range(0, len(coords), 2)
                ]
            else:
                # Odd number of values or too few points: not a drawable shape.
                continue

            pts = np.array(points, np.int32)
            pts = pts.reshape((-1, 1, 2))

            color = COLORS[class_id % len(COLORS)]
            cv2.polylines(img, [pts], True, color, 2)

            # Add label text just above the first vertex
            label_text = f"Class {class_id}"
            cv2.putText(img, label_text, (points[0][0], points[0][1]-5),
                       cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)

    cv2.imwrite(str(output_path), img)

# Sample images from each split and render a labelled preview for each sample
samples_per_split = 5

# Same extensions the rest of the notebook treats as images (matched case-insensitively)
preview_extensions = {'.jpg', '.jpeg', '.png'}

for split in ['train', 'val', 'test']:
    img_dir = PROJECT_DIR / split / "images"
    lbl_dir = PROJECT_DIR / split / "labels"

    if not img_dir.exists():
        continue

    # Sorted so that the seeded random sample does not depend on directory listing order.
    images = sorted(
        p for p in img_dir.iterdir()
        if p.is_file() and p.suffix.lower() in preview_extensions
    )
    if not images:
        continue

    # Pick random samples (all images when the split has fewer than samples_per_split)
    samples = random.sample(images, min(len(images), samples_per_split))

    for img_path in samples:
        lbl_path = lbl_dir / (img_path.stem + ".txt")
        out_path = PREVIEW_DIR / f"{split}_{img_path.name}"
        draw_yolo_labels(img_path, lbl_path, out_path)

print("✓ Previews generated.")