# Dice Cube Tracking with SAM 2 Video Predictor

This notebook tracks and segments **green, red, and blue dice cubes** across all demonstration videos.

**Pipeline:**
1. **Grounding DINO** - Detect dice cubes by color on the first frame
2. **SAM 2 Video Predictor** - Track and segment dice throughout the video
3. **Workspace Filtering** - Only track objects on the tabletop (ignore background)

**Color-coded tracking:**
- 🟢 Green dice
- 🔴 Red dice
- 🔵 Blue dice

In [None]:
# ============================================================
# Install Dependencies
# ============================================================
# Note: SAM2 must be installed from GitHub (not available on PyPI)
%pip install torch torchvision --index-url https://download.pytorch.org/whl/cu124 -q
%pip install transformers accelerate -q
%pip install git+https://github.com/facebookresearch/segment-anything-2.git -q
%pip install opencv-python supervision -q
%pip install matplotlib numpy Pillow tqdm -q

Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
[31mERROR: Could not find a version that satisfies the requirement segment-anything-2 (from versions: none)[0m[31m
[0m[31mERROR: No matching distribution found for segment-anything-2[0m[31m
[0mNote: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.


In [None]:
# ============================================================
# Import Libraries
# ============================================================
import os
import cv2
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from pathlib import Path
from collections import defaultdict
from tqdm import tqdm
import json
import warnings
import shutil
warnings.filterwarnings('ignore')

# Report the runtime environment, then choose the compute device used by every model below.
_cuda_available = torch.cuda.is_available()

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {_cuda_available}")
if _cuda_available:
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    _gpu_props = torch.cuda.get_device_properties(0)
    print(f"VRAM: {_gpu_props.total_memory / 1e9:.1f} GB")

DEVICE = "cuda" if _cuda_available else "cpu"
print(f"Using device: {DEVICE}")

In [None]:
# ============================================================
# Configuration
# ============================================================
VIDEO_DIR = "./demonstrations"
OUTPUT_DIR = "./output_tracking"
FRAMES_DIR = os.path.join(OUTPUT_DIR, "frames")  # Temporary JPEG frames fed to the SAM2 video predictor
OUTPUT_VIDEOS_DIR = os.path.join(OUTPUT_DIR, "videos")
TRACKING_DATA_DIR = os.path.join(OUTPUT_DIR, "tracking_data")
VIDEO_EXTENSIONS = ('.mp4', '.avi', '.mov')

# makedirs(exist_ok=True) is idempotent, so re-running this cell is safe.
for _output_subdir in (OUTPUT_DIR, FRAMES_DIR, OUTPUT_VIDEOS_DIR, TRACKING_DATA_DIR):
    os.makedirs(_output_subdir, exist_ok=True)

if not os.path.isdir(VIDEO_DIR):
    raise FileNotFoundError(f"Video directory not found: {os.path.abspath(VIDEO_DIR)}")

# Gather all video files. Extensions are matched case-insensitively so that files such as
# "clip.MP4" or "demo.MOV" (common from phones/cameras) are not silently skipped.
video_files = sorted(f for f in os.listdir(VIDEO_DIR) if f.lower().endswith(VIDEO_EXTENSIONS))
print(f"Found {len(video_files)} videos:")
for v in video_files:
    print(f"  - {v}")

In [None]:
# ============================================================
# Load Grounding DINO for Initial Detection
# ============================================================
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection

# Open-vocabulary detector used to find the dice on the first frame of each video.
GDINO_MODEL_ID = "IDEA-Research/grounding-dino-base"
gdino_processor = AutoProcessor.from_pretrained(GDINO_MODEL_ID)
gdino_model = AutoModelForZeroShotObjectDetection.from_pretrained(GDINO_MODEL_ID).to(DEVICE)
gdino_model.eval()  # inference only
print("✓ Grounding DINO loaded")

In [None]:
# ============================================================
# Load SAM 2 for Video Segmentation & Tracking
# ============================================================
from sam2.build_sam import build_sam2_video_predictor, build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

# SAM 2 (Hiera-Large) checkpoint and its matching model config name
SAM2_CHECKPOINT = "sam2_hiera_large.pt"
SAM2_CONFIG = "sam2_hiera_l.yaml"
SAM2_CHECKPOINT_URL = "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt"

# Download if not present. The file is written to a temporary ".part" path and only renamed
# once the download has completed. Otherwise an interrupted download would leave a truncated
# checkpoint that the os.path.exists() check accepts on the next run.
if not os.path.exists(SAM2_CHECKPOINT):
    print("Downloading SAM2 checkpoint...")
    import urllib.request
    _partial_checkpoint = SAM2_CHECKPOINT + ".part"
    try:
        urllib.request.urlretrieve(SAM2_CHECKPOINT_URL, _partial_checkpoint)
        os.replace(_partial_checkpoint, SAM2_CHECKPOINT)
    finally:
        # Only still present if the download or rename failed.
        if os.path.exists(_partial_checkpoint):
            os.remove(_partial_checkpoint)

# Build video predictor for tracking
sam2_video_predictor = build_sam2_video_predictor(SAM2_CONFIG, SAM2_CHECKPOINT, device=DEVICE)

# Build image predictor for initial segmentation.
# NOTE(review): no cell shown here uses `sam2_image_predictor`; it loads a second copy of the
# model weights, so consider removing it if GPU memory is tight.
sam2_model = build_sam2(SAM2_CONFIG, SAM2_CHECKPOINT, device=DEVICE)
sam2_image_predictor = SAM2ImagePredictor(sam2_model)

print("SAM 2 Video Predictor loaded")

## Detection & Filtering Functions

These functions handle:
1. **Dice detection** using Grounding DINO with color-specific prompts
2. **Tabletop/workspace filtering** to ignore objects outside the work area
3. **Color classification** to assign correct labels

In [None]:
# ============================================================
# Workspace/Tabletop Detection
# ============================================================

def detect_tabletop_region(image, method="color"):
    """
    Detect the tabletop/workspace region to filter out background objects.

    Methods:
    - "color": Detect the table by the dominant hue of coloured pixels in the lower
      half of the frame (works for solid-coloured tables)
    - "lower_half": Simple heuristic - exclude only the top 20% of the frame
    - "full": No filtering, use entire frame (also the result for any other value)

    Args:
        image: RGB image as an (H, W, 3) uint8 array.
        method: One of "color", "lower_half", "full".

    Returns: (H, W) boolean mask where True = workspace area
    """
    h, w = image.shape[:2]

    if method == "lower_half":
        # Simple heuristic: the top 20% of the frame is background, the lower 80% is table
        mask = np.zeros((h, w), dtype=bool)
        mask[int(h * 0.2):, :] = True
        return mask

    elif method == "color":
        # Saturation/value floors below which a pixel counts as white/grey/black.
        # Used both for estimating the table hue and for building the final mask.
        min_sat, min_val = 30, 50

        # Convert to HSV for better color segmentation
        hsv = cv2.cvtColor(image, cv2.COLOR_RGB2HSV)

        # Estimate the table hue from the lower half of the frame, counting only
        # sufficiently coloured pixels. White/grey/black pixels have no meaningful hue
        # (OpenCV reports 0) and would otherwise dominate the histogram.
        lower_region = hsv[int(h * 0.5):, :]
        colourful = ((lower_region[:, :, 1] >= min_sat) &
                     (lower_region[:, :, 2] >= min_val)).astype(np.uint8)
        hist = cv2.calcHist([lower_region], [0], colourful, [180], [0, 180])
        dominant_hue = int(np.argmax(hist))

        # Create mask for table color (+/-15 hue tolerance)
        lower_bound = np.array([max(0, dominant_hue - 15), min_sat, min_val])
        upper_bound = np.array([min(180, dominant_hue + 15), 255, 255])
        color_mask = cv2.inRange(hsv, lower_bound, upper_bound)

        # Spatial prior: the weight rises linearly from 0.3 (top row) to 1.0 (bottom row).
        # Only weights > 0.5 survive, so table-coloured pixels in roughly the top 29% of
        # the frame are discarded.
        spatial_weight = np.linspace(0.3, 1.0, h).reshape(-1, 1)
        spatial_mask = np.tile(spatial_weight, (1, w))

        combined = (color_mask > 0).astype(float) * spatial_mask
        mask = combined > 0.5

        # Clean up with morphology: close small holes, then remove small specks
        kernel = np.ones((20, 20), np.uint8)
        mask = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, kernel)
        mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)

        return mask.astype(bool)

    else:  # "full" (and any unrecognised method): no filtering
        return np.ones((h, w), dtype=bool)


def is_in_workspace(box, workspace_mask, threshold=0.5):
    """
    Decide whether a bounding box lies (mostly) inside the workspace.

    The box is clipped to the mask bounds. The fraction of workspace pixels inside
    the clipped box is then compared against ``threshold``.

    Args:
        box: [x1, y1, x2, y2] bounding box in pixel coordinates
        workspace_mask: Boolean mask of the workspace region (rows x cols)
        threshold: Minimum fraction of the box that must be workspace

    Returns: True if the overlap fraction is at least ``threshold``; False when the
        box is empty after clipping.
    """
    mask_h, mask_w = workspace_mask.shape[:2]
    left, top, right, bottom = (int(v) for v in box)

    # Clip to the mask so out-of-frame coordinates do not wrap or overflow
    left = max(left, 0)
    top = max(top, 0)
    right = min(right, mask_w)
    bottom = min(bottom, mask_h)

    if right <= left or bottom <= top:
        return False

    return np.mean(workspace_mask[top:bottom, left:right]) >= threshold

In [None]:
# ============================================================
# Dice Detection with Grounding DINO
# ============================================================

def normalize_dice_label(raw_label):
    """
    Map a free-form Grounding DINO label to a canonical dice class name.

    The first colour word found wins, checked in the order green, red, blue
    (e.g. "red cube" -> "red_dice"). Labels that mention "dice" or "cube" without a
    known colour become "unknown_dice". Anything else is returned unchanged.
    """
    text = raw_label.lower().strip()

    matched = next((c for c in ("green", "red", "blue") if c in text), None)
    if matched is not None:
        return f"{matched}_dice"

    if any(word in text for word in ("dice", "cube")):
        return "unknown_dice"

    return raw_label


def classify_dice_by_color(image, box):
    """
    Label a detected dice by the dominant colour inside its bounding box.

    Pixel colour is used as a more reliable signal than the Grounding DINO text label.
    The box is clipped to the image. For each colour class, the function measures the
    fraction of box pixels that fall inside that class's HSV range(s) and picks the
    best class.

    Args:
        image: RGB image as an (H, W, 3) uint8 array
        box: [x1, y1, x2, y2] bounding box in pixel coordinates

    Returns: "red_dice", "green_dice" or "blue_dice" when the best class covers more
        than 15% of the box; otherwise "unknown_dice" (also returned for empty boxes).
    """
    img_h, img_w = image.shape[:2]
    left, top, right, bottom = (int(v) for v in box)
    left, top = max(0, left), max(0, top)
    right, bottom = min(img_w, right), min(img_h, bottom)

    if right <= left or bottom <= top:
        return "unknown_dice"

    patch_hsv = cv2.cvtColor(image[top:bottom, left:right], cv2.COLOR_RGB2HSV)

    # OpenCV HSV ranges per class; red straddles hue 0, so it needs two intervals.
    hsv_ranges = {
        "red_dice": (
            (np.array([0, 100, 100]), np.array([10, 255, 255])),
            (np.array([160, 100, 100]), np.array([180, 255, 255])),
        ),
        "green_dice": (
            (np.array([35, 80, 80]), np.array([85, 255, 255])),
        ),
        "blue_dice": (
            (np.array([90, 80, 80]), np.array([130, 255, 255])),
        ),
    }

    def coverage(intervals):
        # Fraction of patch pixels inside any of the given HSV intervals
        hit = np.zeros(patch_hsv.shape[:2], dtype=np.uint8)
        for lo, hi in intervals:
            hit = cv2.bitwise_or(hit, cv2.inRange(patch_hsv, lo, hi))
        return np.count_nonzero(hit) / hit.size

    scores = {name: coverage(intervals) for name, intervals in hsv_ranges.items()}
    winner = max(scores, key=scores.get)
    return winner if scores[winner] > 0.15 else "unknown_dice"


def detect_colored_dice(image, workspace_mask=None, box_threshold=0.25, text_threshold=0.2):
    """
    Detect all colored dice in an image using Grounding DINO.

    Args:
        image: PIL Image or numpy array (RGB)
        workspace_mask: Optional boolean mask; detections mostly outside it are dropped
        box_threshold: Confidence threshold for boxes
        text_threshold: Confidence threshold for text matching

    Returns: List of dicts with keys "box" (numpy [x1, y1, x2, y2]), "score" (float),
        "label" (normalized Grounding DINO label) and "color" (from pixel analysis).
    """
    if isinstance(image, np.ndarray):
        pil_image = Image.fromarray(image)
        np_image = image
    else:
        pil_image = image
        np_image = np.array(image)

    # Text prompt for dice detection (phrases separated by " . ")
    text_prompt = "green dice . red dice . blue dice . green cube . red cube . blue cube ."

    inputs = gdino_processor(
        images=pil_image,
        text=text_prompt,
        return_tensors="pt"
    ).to(DEVICE)

    with torch.no_grad():
        outputs = gdino_model(**inputs)

    # box_threshold MUST reach the post-processor. Without it, the processor applies its
    # own default (0.25) first, and callers asking for a lower threshold (0.20 / 0.15)
    # silently get 0.25. It is passed positionally because the keyword was renamed from
    # `box_threshold` to `threshold` in newer transformers releases; the third positional
    # slot is the box threshold in both versions.
    results = gdino_processor.post_process_grounded_object_detection(
        outputs,
        inputs.input_ids,
        box_threshold,
        text_threshold=text_threshold,
        target_sizes=[pil_image.size[::-1]]
    )[0]

    boxes = results["boxes"].cpu().numpy()
    scores = results["scores"].cpu().numpy()
    # Newer transformers expose the string labels as "text_labels"; older ones use "labels".
    labels = results.get("text_labels", results["labels"])

    detections = []
    for box, score, label in zip(boxes, scores, labels):
        # Defensive re-check of the confidence threshold
        if score < box_threshold:
            continue

        # Filter by workspace if provided
        if workspace_mask is not None and not is_in_workspace(box, workspace_mask):
            continue

        # Classify color based on actual pixel values
        color = classify_dice_by_color(np_image, box)

        detections.append({
            "box": box,
            "score": float(score),
            "label": normalize_dice_label(label),
            "color": color
        })

    return detections


def apply_nms(detections, iou_threshold=0.5):
    """
    Greedy, class-agnostic non-maximum suppression over detection dicts.

    The function repeatedly keeps the highest-scoring remaining detection and discards
    every other remaining detection whose box IoU with it is >= iou_threshold.

    Args:
        detections: list of dicts with "box" ([x1, y1, x2, y2]) and "score"
        iou_threshold: IoU at or above which a lower-scoring box is suppressed

    Returns: the surviving detection dicts, highest score first.
    """
    if len(detections) == 0:
        return detections

    all_boxes = np.array([det["box"] for det in detections])
    all_scores = np.array([det["score"] for det in detections])

    # Candidate indices, best score first
    pending = all_scores.argsort()[::-1]

    survivors = []
    while pending.size:
        best, rest = pending[0], pending[1:]
        survivors.append(best)
        if rest.size == 0:
            break

        b = all_boxes[best]
        others = all_boxes[rest]

        inter_w = np.maximum(0, np.minimum(b[2], others[:, 2]) - np.maximum(b[0], others[:, 0]))
        inter_h = np.maximum(0, np.minimum(b[3], others[:, 3]) - np.maximum(b[1], others[:, 1]))
        inter = inter_w * inter_h

        area_best = (b[2] - b[0]) * (b[3] - b[1])
        area_others = (others[:, 2] - others[:, 0]) * (others[:, 3] - others[:, 1])
        iou = inter / (area_best + area_others - inter + 1e-6)

        # Keep only candidates that do not overlap the chosen box too much
        pending = rest[iou < iou_threshold]

    return [detections[k] for k in survivors]

## SAM 2 Video Tracking Pipeline

This section implements the core tracking logic:
1. Extract frames from video
2. Initialize SAM 2 with first-frame detections
3. Propagate segmentation masks through all frames
4. Generate output video with tracked dice

In [None]:
# ============================================================
# Video Frame Extraction for SAM2
# ============================================================

def extract_frames_for_sam2(video_path, output_dir, sample_rate=1):
    """
    Extract frames from a video into a directory of JPEGs for the SAM2 video predictor.

    Frames are written as consecutive zero-padded JPEGs (000000.jpg, 000001.jpg, ...),
    the layout this notebook passes to ``sam2_video_predictor.init_state``. Any existing
    ``output_dir`` is deleted first, so stale frames from a previous run cannot mix in.

    Args:
        video_path: Path to input video
        output_dir: Directory to save frames (recreated from scratch)
        sample_rate: Extract every Nth frame (1 = all frames); must be >= 1

    Returns: (frame_paths, fps, total_frames, frame_indices, (width, height))
        frame_indices[k] is the original video frame number of frame_paths[k].
        fps, total_frames and the size are the container metadata reported by OpenCV
        (fps may be 0 if the container does not report it).

    Raises:
        ValueError: if sample_rate < 1 (would otherwise be a modulo-by-zero).
        IOError: if the video cannot be opened or a frame cannot be written.
    """
    if sample_rate < 1:
        raise ValueError(f"sample_rate must be >= 1, got {sample_rate!r}")

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise IOError(f"Could not open video: {video_path}")

    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        # Clear output directory
        if os.path.exists(output_dir):
            shutil.rmtree(output_dir)
        os.makedirs(output_dir, exist_ok=True)

        frame_paths = []
        frame_indices = []
        frame_idx = 0

        while True:
            ret, frame = cap.read()
            if not ret:
                break

            if frame_idx % sample_rate == 0:
                # Sequential naming (no gaps) so the saved frames sort in temporal order
                frame_path = os.path.join(output_dir, f"{len(frame_paths):06d}.jpg")
                if not cv2.imwrite(frame_path, frame):
                    raise IOError(f"Could not write frame: {frame_path}")
                frame_paths.append(frame_path)
                frame_indices.append(frame_idx)

            frame_idx += 1
    finally:
        # Release the decoder even if extraction fails part-way
        cap.release()

    return frame_paths, fps, total_frames, frame_indices, (width, height)

In [None]:
# ============================================================
# SAM2 Video Tracking Core
# ============================================================

# Overlay colour per dice class, as (R, G, B) in 0-255.
# OpenCV drawing code must reverse these to BGR; matplotlib uses them directly (scaled to 0-1).
DICE_COLORS = dict(
    green_dice=(0, 255, 0),
    red_dice=(255, 0, 0),
    blue_dice=(0, 0, 255),
    unknown_dice=(255, 255, 0),  # yellow fallback for unclassified dice
)


def track_dice_in_video(video_path, sample_rate=2, workspace_method="lower_half"):
    """
    Track all colored dice cubes in a video using Grounding DINO + SAM2.

    Pipeline: extract (sampled) frames to JPEGs -> detect dice on the first extracted
    frame -> prompt SAM2 with one box per detection -> propagate masks through the clip.

    Args:
        video_path: Path to input video
        sample_rate: Process every Nth frame (lower = more accurate, slower)
        workspace_method: Method to detect tabletop ("lower_half", "color", "full")

    Returns: Dictionary with tracking results, or None if no frames could be read or
        no dice were detected. Keys:
        - video_name, video_path, fps, total_frames, sample_rate
        - processed_frames: number of extracted frames
        - frame_size: (width, height)
        - frame_indices: original frame number per extracted frame
        - objects: {obj_id: {"color", "initial_box"}}
        - segments: {extracted_frame_idx: {obj_id: boolean mask}}
        - frame_paths: paths of the extracted JPEGs
    """
    video_name = Path(video_path).stem
    print(f"\n{'='*60}")
    print(f"Processing: {video_name}")
    print(f"{'='*60}")

    # Per-video frame directory (recreated by extract_frames_for_sam2)
    video_frames_dir = os.path.join(FRAMES_DIR, video_name)

    # Step 1: Extract frames
    print("Step 1: Extracting frames...")
    frame_paths, fps, total_frames, frame_indices, (width, height) = \
        extract_frames_for_sam2(video_path, video_frames_dir, sample_rate)
    print(f"  - Extracted {len(frame_paths)} frames (original: {total_frames} @ {fps:.1f} FPS)")

    if not frame_paths:
        # Nothing could be decoded; frame_paths[0] below would raise IndexError.
        print("  ✗ No frames could be read, skipping video")
        return None

    # Step 2: Detect dice on first frame (JPEGs are read as BGR; the detectors expect RGB)
    print("Step 2: Detecting dice on first frame...")
    first_frame = cv2.cvtColor(cv2.imread(frame_paths[0]), cv2.COLOR_BGR2RGB)

    # Detect workspace/tabletop
    workspace_mask = detect_tabletop_region(first_frame, method=workspace_method)

    # Detect dice
    detections = detect_colored_dice(first_frame, workspace_mask, box_threshold=0.20)
    detections = apply_nms(detections, iou_threshold=0.4)

    if len(detections) == 0:
        # Retry over the whole frame with a lower confidence threshold
        print("  ⚠ No dice detected on first frame! Trying without workspace filter...")
        detections = detect_colored_dice(first_frame, None, box_threshold=0.15)
        detections = apply_nms(detections, iou_threshold=0.4)

    print(f"  - Found {len(detections)} dice:")
    for i, det in enumerate(detections):
        print(f"    [{i}] {det['color']} (conf: {det['score']:.2f})")

    if len(detections) == 0:
        print("  ✗ No dice found, skipping video")
        return None

    # Step 3: Initialize SAM2 video predictor
    print("Step 3: Initializing SAM2 video tracking...")

    with torch.inference_mode(), torch.autocast(DEVICE, dtype=torch.bfloat16):
        state = sam2_video_predictor.init_state(video_path=video_frames_dir)

        # Add each detected dice as a tracking object. Object ids are chosen as i + 1 so
        # that obj_id - 1 indexes `detections` (relied on when building `objects` below).
        object_ids = []
        object_colors = {}

        for i, det in enumerate(detections):
            obj_id = i + 1
            # Prompt SAM2 with the detection box on the first extracted frame
            sam2_video_predictor.add_new_points_or_box(
                inference_state=state,
                frame_idx=0,
                obj_id=obj_id,
                box=det["box"]
            )

            object_ids.append(obj_id)
            object_colors[obj_id] = det["color"]

        print(f"  - Initialized {len(object_ids)} objects for tracking")

        # Step 4: Propagate through video
        print("Step 4: Propagating masks through video...")

        video_segments = {}  # {frame_idx: {obj_id: boolean mask}}

        for frame_idx, obj_ids, mask_logits in sam2_video_predictor.propagate_in_video(state):
            # Positive logits mark pixels belonging to the object
            masks = (mask_logits > 0.0).cpu().numpy()
            video_segments[frame_idx] = {
                obj_id: masks[i, 0] for i, obj_id in enumerate(obj_ids)
            }

        print(f"  - Tracked across {len(video_segments)} frames")

    # Compile results
    results = {
        "video_name": video_name,
        "video_path": video_path,
        "fps": fps,
        "total_frames": total_frames,
        "processed_frames": len(frame_paths),
        "sample_rate": sample_rate,
        "frame_size": (width, height),
        "frame_indices": frame_indices,
        "objects": {
            obj_id: {
                "color": object_colors[obj_id],
                "initial_box": detections[obj_id - 1]["box"].tolist()
            }
            for obj_id in object_ids
        },
        "segments": video_segments,
        "frame_paths": frame_paths
    }

    return results

In [None]:
# ============================================================
# Visualization & Video Output
# ============================================================

def create_tracking_video(results, output_path, fallback_fps=30.0):
    """
    Create output video with segmentation masks overlaid.

    Each tracked object is drawn as a 30%-opacity fill in its dice colour, plus a 2-px
    outline and a text label at the centroid of its largest contour.

    Args:
        results: dict returned by ``track_dice_in_video`` (``None`` is ignored).
        output_path: Destination video path (written with the 'mp4v' FourCC).
        fallback_fps: Source frame rate to assume when the video metadata reports no
            usable FPS (``results["fps"]`` is 0 or NaN).

    Raises:
        IOError: if the video writer cannot be opened.
    """
    if results is None or not results["frame_paths"]:
        return

    frame_paths = results["frame_paths"]
    segments = results["segments"]
    objects = results["objects"]

    # `NaN > 0` is False, so NaN also falls back.
    source_fps = results["fps"] if results["fps"] > 0 else fallback_fps
    # Only every `sample_rate`-th frame was extracted. Play back at the reduced rate so
    # the output lasts as long as the original video.
    fps = source_fps / results["sample_rate"]

    # Get frame size
    first_frame = cv2.imread(frame_paths[0])
    height, width = first_frame.shape[:2]

    # Create video writer
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
    if not out.isOpened():
        raise IOError(f"Could not open video writer for: {output_path}")

    print(f"  Creating output video: {output_path}")

    try:
        for frame_idx, frame_path in enumerate(tqdm(frame_paths, desc="  Rendering")):
            frame = cv2.imread(frame_path)

            for obj_id, mask in segments.get(frame_idx, {}).items():
                if obj_id not in objects:
                    continue

                color_name = objects[obj_id]["color"]
                r, g, b = DICE_COLORS.get(color_name, (255, 255, 0))
                color_bgr = (b, g, r)  # DICE_COLORS is RGB; OpenCV draws in BGR

                # Semi-transparent fill: paint a copy, then blend it back at 30%
                overlay = frame.copy()
                overlay[mask] = color_bgr
                frame = cv2.addWeighted(frame, 0.7, overlay, 0.3, 0)

                # Draw contour
                contours, _ = cv2.findContours(
                    mask.astype(np.uint8),
                    cv2.RETR_EXTERNAL,
                    cv2.CHAIN_APPROX_SIMPLE
                )
                cv2.drawContours(frame, contours, -1, color_bgr, 2)

                # Add label at the centroid of the largest blob
                # (contours[0] may be a tiny stray fragment)
                if contours:
                    M = cv2.moments(max(contours, key=cv2.contourArea))
                    if M["m00"] > 0:
                        cx = int(M["m10"] / M["m00"])
                        cy = int(M["m01"] / M["m00"])
                        label = color_name.replace("_", " ").title()
                        cv2.putText(frame, label, (cx - 30, cy - 10),
                                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2)

            out.write(frame)
    finally:
        # Finalize the file even if rendering fails part-way
        out.release()

    print(f"  ✓ Saved: {output_path}")


def visualize_first_frame_detections(results, save_path=None):
    """
    Visualize the initial detections and SAM2 masks on the first processed frame.

    Left panel: the Grounding DINO boxes used to prompt SAM2, coloured and labelled by
    dice colour. Right panel: the SAM2 masks for extracted frame 0, blended 50/50 with
    the frame.

    Args:
        results: dict returned by ``track_dice_in_video`` (``None`` is ignored).
        save_path: Optional path to save the figure (150 dpi).
    """
    if results is None or not results["frame_paths"]:
        return

    frame = cv2.cvtColor(cv2.imread(results["frame_paths"][0]), cv2.COLOR_BGR2RGB)
    segments = results["segments"].get(0, {})
    objects = results["objects"]
    fallback_color = (255, 255, 0)

    fig, (ax_boxes, ax_masks) = plt.subplots(1, 2, figsize=(16, 8))

    # Original with boxes
    ax_boxes.imshow(frame)
    ax_boxes.set_title(f"Initial Detections - {results['video_name']}")

    for obj_info in objects.values():
        x1, y1, x2, y2 = obj_info["initial_box"]
        # DICE_COLORS is RGB in 0-255; matplotlib expects 0-1 floats
        color_norm = tuple(c / 255 for c in DICE_COLORS.get(obj_info["color"], fallback_color))

        ax_boxes.add_patch(plt.Rectangle(
            (x1, y1), x2 - x1, y2 - y1,
            fill=False, edgecolor=color_norm, linewidth=2
        ))
        ax_boxes.text(x1, y1 - 5, obj_info["color"].replace("_", " "),
                      color=color_norm, fontsize=10, fontweight='bold')
    ax_boxes.axis('off')

    # With segmentation masks: 50/50 blend of frame and object colour inside each mask
    mask_overlay = frame.astype(float)
    for obj_id, mask in segments.items():
        if obj_id in objects:
            color = np.asarray(DICE_COLORS.get(objects[obj_id]["color"], fallback_color), dtype=float)
            mask_overlay[mask] = mask_overlay[mask] * 0.5 + color * 0.5

    ax_masks.imshow(mask_overlay.astype(np.uint8))
    ax_masks.set_title(f"Segmentation Masks - {results['video_name']}")
    ax_masks.axis('off')

    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.show()
    # Release the figure explicitly: this function is called once per video in a loop,
    # and un-closed figures accumulate in memory on non-inline backends.
    plt.close(fig)

In [None]:
# ============================================================
# Save Tracking Data
# ============================================================

def save_tracking_data(results, output_dir):
    """
    Save tracking data (bounding box, centroid, area per object per frame) to JSON.

    The file is written to ``<output_dir>/<video_name>_tracking.json``. Frames are keyed
    by their ORIGINAL video frame number (via ``results["frame_indices"]``). Object ids
    are stored as strings because JSON object keys must be strings. Objects whose mask
    is empty in a frame are omitted from that frame.

    Args:
        results: dict returned by ``track_dice_in_video`` (``None`` is ignored).
        output_dir: Existing directory to write the JSON file into.

    Returns: the tracking-data dict that was written, or None if ``results`` is None.
        Each per-object entry contains:
        - "bbox": [x_min, y_min, x_max, y_max], inclusive pixel coordinates
        - "centroid": [x, y], the mean pixel position truncated to int
        - "area": the mask pixel count
        - "color": the dice class
    """
    if results is None:
        return None

    video_name = results["video_name"]
    objects = results["objects"]

    tracking_data = {
        "video_name": video_name,
        "fps": results["fps"],
        "total_frames": results["total_frames"],
        "sample_rate": results["sample_rate"],
        "objects": {str(obj_id): {"color": info["color"]} for obj_id, info in objects.items()},
        "frames": {}
    }

    # Per-frame tracking
    for frame_idx, masks in results["segments"].items():
        original_frame_idx = results["frame_indices"][frame_idx]
        frame_entry = {}
        tracking_data["frames"][str(original_frame_idx)] = frame_entry

        for obj_id, mask in masks.items():
            ys, xs = np.nonzero(mask)
            if xs.size == 0:  # object not visible in this frame
                continue

            frame_entry[str(obj_id)] = {
                "bbox": [int(xs.min()), int(ys.min()), int(xs.max()), int(ys.max())],
                "centroid": [int(np.mean(xs)), int(np.mean(ys))],
                "area": int(xs.size),
                "color": objects[obj_id]["color"]
            }

    # Save to JSON
    output_path = os.path.join(output_dir, f"{video_name}_tracking.json")
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(tracking_data, f, indent=2)

    print(f"  ✓ Saved tracking data: {output_path}")
    return tracking_data

## Process All Demonstration Videos

Run the tracking pipeline on all videos in the demonstrations folder.

In [None]:
# ============================================================
# Process All Videos
# ============================================================

import traceback

# Configuration
SAMPLE_RATE = 2  # Process every 2nd frame (balance speed/accuracy)
WORKSPACE_METHOD = "lower_half"  # Options: "lower_half", "color", "full"

# Per-video metadata (everything except the per-frame masks; see below)
all_results = {}

for video_file in video_files:
    video_path = os.path.join(VIDEO_DIR, video_file)
    video_name = Path(video_file).stem

    try:
        # Track dice in video
        results = track_dice_in_video(
            video_path,
            sample_rate=SAMPLE_RATE,
            workspace_method=WORKSPACE_METHOD
        )

        if results is not None:
            # Keep only lightweight metadata. "segments" holds one boolean mask per
            # processed frame per object; retaining it for every video can exhaust RAM.
            # It is persisted to the output video and the tracking JSON below.
            all_results[video_name] = {k: v for k, v in results.items() if k != "segments"}

            # Visualize first frame detections
            vis_path = os.path.join(OUTPUT_DIR, f"{video_name}_detections.png")
            visualize_first_frame_detections(results, save_path=vis_path)

            # Create output video with tracking
            output_video_path = os.path.join(OUTPUT_VIDEOS_DIR, f"{video_name}_tracked.mp4")
            create_tracking_video(results, output_video_path)

            # Save tracking data
            save_tracking_data(results, TRACKING_DATA_DIR)

    except Exception as e:
        # Best-effort batch run: report the failure and continue with the next video
        print(f"  ✗ Error processing {video_name}: {e}")
        traceback.print_exc()

print(f"\n{'='*60}")
print("Processing complete!")
print(f"{'='*60}")
print(f"Videos processed: {len(all_results)}/{len(video_files)}")
print(f"Output videos: {OUTPUT_VIDEOS_DIR}")
print(f"Tracking data: {TRACKING_DATA_DIR}")

## Summary Statistics

View a summary of all tracked dice across videos.

In [None]:
# ============================================================
# Summary Statistics
# ============================================================

print("\n" + "="*70)
print("TRACKING SUMMARY")
print("="*70)

summary_data = []
for video_name, results in all_results.items():
    objects = results["objects"]

    # Count by color (colour labels outside these keys are not counted)
    color_counts = {"green_dice": 0, "red_dice": 0, "blue_dice": 0, "unknown_dice": 0}
    for obj_info in objects.values():
        color = obj_info["color"]
        if color in color_counts:
            color_counts[color] += 1

    summary_data.append({
        "video": video_name,
        "green": color_counts["green_dice"],
        "red": color_counts["red_dice"],
        "blue": color_counts["blue_dice"],
        "unknown": color_counts["unknown_dice"],
        "total": len(objects),
        "frames": results["processed_frames"]
    })

    print(f"\n{video_name}:")
    print(f"  🟢 Green dice:   {color_counts['green_dice']}")
    print(f"  🔴 Red dice:     {color_counts['red_dice']}")
    print(f"  🔵 Blue dice:    {color_counts['blue_dice']}")
    print(f"  🟡 Unknown dice: {color_counts['unknown_dice']}")
    print(f"  📊 Total tracked: {len(objects)} objects over {results['processed_frames']} frames")

# Create summary table ("Unknown" is shown so the colour columns add up to "Total")
print("\n" + "="*70)
print("SUMMARY TABLE")
print("="*70)
print(f"{'Video':<20} {'Green':>7} {'Red':>7} {'Blue':>7} {'Unknown':>7} {'Total':>7} {'Frames':>7}")
print("-"*70)
for row in summary_data:
    print(f"{row['video']:<20} {row['green']:>7} {row['red']:>7} {row['blue']:>7} "
          f"{row['unknown']:>7} {row['total']:>7} {row['frames']:>7}")

## Single Video Test (Optional)

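Use this cell to test tracking on a single video before processing the full set. It appears after the **Process All Videos** cell, but it is meant to be run first. It samples every 5th frame, so you can quickly check that dice are detected and tracked correctly.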

In [None]:
# ============================================================
# Test on Single Video (run this first to verify setup)
# ============================================================

# Pick the first video for testing
TEST_VIDEO = video_files[0] if video_files else None

if TEST_VIDEO:
    test_video_path = os.path.join(VIDEO_DIR, TEST_VIDEO)

    # Track with a coarser sample rate (every 5th frame) for faster testing
    test_results = track_dice_in_video(
        test_video_path,
        sample_rate=5,
        workspace_method="lower_half"
    )

    if test_results is not None:
        # Visualize
        visualize_first_frame_detections(test_results)

        # Create short test video
        test_output = os.path.join(OUTPUT_VIDEOS_DIR, f"{Path(TEST_VIDEO).stem}_test.mp4")
        create_tracking_video(test_results, test_output)

        print("\n✓ Test complete! Check the output above to verify dice detection.")
        print("  If dice are detected correctly, run the 'Process All Videos' cell.")
else:
    print("No videos found in demonstrations folder!")