In [None]:
รง# Install PyTorch (ensure CUDA support for video processing)
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124

# Install SAM 3 (Hypothetical repo based on Nov 2025 release)
!pip install git+https://github.com/facebookresearch/sam3.git

# Install OpenCV and Matplotlib for visualization
!pip install opencv-python matplotlib

Looking in indexes: https://download.pytorch.org/whl/cu124
[31mERROR: Could not find a version that satisfies the requirement torch (from versions: none)[0m[31m
[0m[31mERROR: No matching distribution found for torch[0m[31m
[0mCollecting git+https://github.com/facebookresearch/sam3.git
  Cloning https://github.com/facebookresearch/sam3.git to /private/var/folders/7_/ppvl_3r15_5fhytbh4trswq40000gn/T/pip-req-build-d48ovl1h
  Running command git clone --filter=blob:none --quiet https://github.com/facebookresearch/sam3.git /private/var/folders/7_/ppvl_3r15_5fhytbh4trswq40000gn/T/pip-req-build-d48ovl1h
^C
[31mERROR: Operation cancelled by user[0m[31m
[0mCollecting opencv-python
  Downloading opencv_python-4.12.0.88-cp37-abi3-macosx_13_0_x86_64.whl.metadata (19 kB)
Collecting numpy<2.3.0,>=2 (from opencv-python)
  Downloading numpy-2.2.6-cp312-cp312-macosx_14_0_x86_64.whl.metadata (62 kB)
INFO: pip is looking at multiple versions of contourpy to determine which version

In [None]:
import torch
import cv2
import numpy as np
import os
from sam3 import build_sam3, SAM3Predictor
from pathlib import Path
from tqdm import tqdm

# --- CONFIGURATION ---
# Input video plus the two output folders that receive the YOLO image/label pairs.
VIDEO_PATH = "assets/game_footage_01.mp4"
OUTPUT_DIR = "training_data/labels"
IMAGES_DIR = "training_data/images"

# Text prompt -> YOLO class id. The id is the prompt's position in this list,
# so reordering the list renumbers the classes.
CLASS_MAP = {
    prompt: class_id
    for class_id, prompt in enumerate(["hockey player", "puck", "referee"])
}

# --- MODEL LOADING ---
# Offline auto-labeling favours accuracy over speed, hence the 'Large' checkpoint.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print(f"Loading SAM 3 on {device}...")
sam3_model = build_sam3(checkpoint="checkpoints/sam3_large.pth")
sam3_model = sam3_model.to(device)
predictor = SAM3Predictor(sam3_model)

def masks_to_yolo_boxes(masks_np):
    """
    Convert binary masks to YOLO (x_center, y_center, w, h) boxes, vectorized over masks.

    The box tightly encloses every set pixel. Pixel bounds are inclusive, so a mask
    spanning columns x_min..x_max is (x_max - x_min + 1) pixels wide and its right
    edge lies at x_max + 1.

    Args:
        masks_np (np.ndarray): Array of shape (N, H, W), or a single (H, W) mask.
            Non-zero / True pixels belong to the object.

    Returns:
        np.ndarray: Float array of shape (N, 4) with coordinates normalized to [0, 1].
            Rows for empty masks are [0, 0, 0, 0], so row i always matches mask i.

    Raises:
        ValueError: if the input is neither 2-D nor 3-D.
    """
    masks = np.asarray(masks_np, dtype=bool)
    if masks.ndim == 2:
        masks = masks[np.newaxis, ...]
    if masks.ndim != 3:
        raise ValueError(f"expected masks of shape (N, H, W) or (H, W), got {masks.shape}")

    n, h, w = masks.shape
    if n == 0:
        return np.empty((0, 4))

    # Project each mask onto the Y and X axes: which rows / columns contain any pixel.
    rows = masks.any(axis=2)  # (N, H)
    cols = masks.any(axis=1)  # (N, W)
    nonempty = rows.any(axis=1)  # (N,)

    # argmax returns the first True index. On the reversed axis it finds the last one.
    y_min = rows.argmax(axis=1)
    y_max = h - 1 - rows[:, ::-1].argmax(axis=1)
    x_min = cols.argmax(axis=1)
    x_max = w - 1 - cols[:, ::-1].argmax(axis=1)

    # Inclusive bounds -> +1 so single-pixel objects get a non-zero size.
    box_w = x_max - x_min + 1
    box_h = y_max - y_min + 1

    boxes = np.stack(
        [
            (x_min + box_w / 2) / w,
            (y_min + box_h / 2) / h,
            box_w / w,
            box_h / h,
        ],
        axis=1,
    )
    # argmax on an all-False row yields garbage; keep the zero placeholder instead.
    boxes[~nonempty] = 0.0
    return boxes

def _frame_label_lines(frame_data, class_map):
    """Turn one frame's SAM 3 output into YOLO label lines ("class xc yc w h").

    Args:
        frame_data: mapping of text prompt -> mask logits for this frame. Assumed to be
            torch tensors of shape (N, H, W), possibly with extra singleton dims —
            TODO confirm against the SAM 3 API.
        class_map (dict): text prompt -> YOLO class id.

    Returns:
        list[str]: one line per kept box, coordinates normalized with 6 decimals.
    """
    lines = []
    for text_prompt, masks_logits in frame_data.items():
        class_id = class_map[text_prompt]

        # Threshold at 0 (p > 0.5 if these are sigmoid logits).
        masks_binary = (masks_logits > 0.0).cpu().numpy().squeeze()
        # squeeze() also drops the instance axis when exactly one object was found.
        if masks_binary.ndim == 2:
            masks_binary = masks_binary[np.newaxis, ...]

        for x_c, y_c, w_n, h_n in masks_to_yolo_boxes(masks_binary):
            # Skips empty-mask placeholders and near-zero-area noise.
            if w_n < 0.001 or h_n < 0.001:
                continue
            lines.append(f"{class_id} {x_c:.6f} {y_c:.6f} {w_n:.6f} {h_n:.6f}")
    return lines


def generate_training_data(video_path, class_map, output_dir=None, images_dir=None):
    """Auto-label a video with SAM 3 text prompts and write YOLO image/label pairs.

    For every frame index in SAM 3's output, the matching video frame is saved as
    `<images_dir>/<video stem>_<frame:06d>.jpg`. If any box survives filtering, the
    labels go to `<output_dir>/<video stem>_<frame:06d>.txt`. Frames without labels
    keep only their image.

    Args:
        video_path (str | Path): video to label.
        class_map (dict): text prompt -> YOLO class id. The keys are sent as prompts.
        output_dir: label folder. Defaults to the module-level OUTPUT_DIR.
        images_dir: image folder. Defaults to the module-level IMAGES_DIR.

    Raises:
        FileNotFoundError: if OpenCV cannot open `video_path`.
        OSError: if a frame image cannot be written.
    """
    output_dir = Path(OUTPUT_DIR if output_dir is None else output_dir)
    images_dir = Path(IMAGES_DIR if images_dir is None else images_dir)
    stem = Path(video_path).stem

    # 1. Initialize Video State
    cap = cv2.VideoCapture(str(video_path))
    if not cap.isOpened():
        raise FileNotFoundError(f"Could not open video: {video_path}")

    try:
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        inference_state = predictor.init_state(video_path=video_path)

        prompts = list(class_map.keys())
        print(f"Processing {total_frames} frames for classes: {prompts}")

        # 2. PROMPT: text-based concept tracking across the whole video.
        # 'batch_size' controls VRAM usage.
        video_output = predictor.propagate_in_video(
            inference_state,
            text_prompts=prompts,
            batch_size=8,
        )

        # 3. Serialize Data
        output_dir.mkdir(parents=True, exist_ok=True)
        images_dir.mkdir(parents=True, exist_ok=True)

        next_idx = 0  # index of the frame the next cap.read() would return
        ordered = sorted(video_output.items(), key=lambda item: item[0])
        for frame_idx, frame_data in tqdm(ordered, total=len(ordered)):
            # Advance the capture to frame_idx so image and labels stay paired even
            # when SAM 3 omits frames. grab() skips frames without decoding them.
            while next_idx < frame_idx and cap.grab():
                next_idx += 1
            ok, frame_img = cap.read() if next_idx == frame_idx else (False, None)
            if not ok:
                break  # video ended before this frame index
            next_idx += 1

            name = f"{stem}_{frame_idx:06d}"

            # YOLO training needs the image next to its label file.
            image_path = images_dir / f"{name}.jpg"
            if not cv2.imwrite(str(image_path), frame_img):
                raise OSError(f"Failed to write frame image: {image_path}")

            label_lines = _frame_label_lines(frame_data, class_map)
            if label_lines:
                (output_dir / f"{name}.txt").write_text("\n".join(label_lines) + "\n")
    finally:
        cap.release()

    print("Auto-labeling complete.")

# --- EXECUTE ---
# In a notebook kernel __name__ is always "__main__", so this runs whenever the cell
# runs. The guard only matters if the cell is exported to a .py module and imported.
if __name__ == "__main__":
    generate_training_data(video_path=VIDEO_PATH, class_map=CLASS_MAP)

In [None]:
import matplotlib.pyplot as plt

def visualize_sample(image_dir, label_dir, num_samples=3):
    """Display up to `num_samples` labeled frames with their YOLO boxes drawn on top.

    Images are taken in sorted filename order. Images with no matching label file,
    or that cannot be read, are skipped and do not count toward `num_samples`.

    Args:
        image_dir: folder containing `.jpg` frames.
        label_dir: folder of YOLO `.txt` files named after each image's stem.
        num_samples (int): maximum number of figures to show.
    """
    shown = 0
    for img_path in sorted(Path(image_dir).glob("*.jpg")):
        if shown >= num_samples:
            break

        label_path = Path(label_dir) / f"{img_path.stem}.txt"
        if not label_path.exists():
            continue

        img_bgr = cv2.imread(str(img_path))
        if img_bgr is None:  # unreadable or corrupt image file
            continue
        img = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
        h, w = img.shape[:2]

        for line in label_path.read_text().splitlines():
            parts = line.split()
            if not parts:
                continue  # tolerate blank lines
            cls = int(float(parts[0]))
            xc, yc, wn, hn = map(float, parts[1:5])

            # Denormalize center/size back to pixel corners.
            x1 = int((xc - wn / 2) * w)
            y1 = int((yc - hn / 2) * h)
            x2 = int((xc + wn / 2) * w)
            y2 = int((yc + hn / 2) * h)

            # img is RGB at this point: green for class 0 (player), red for the rest.
            color = (0, 255, 0) if cls == 0 else (255, 0, 0)
            cv2.rectangle(img, (x1, y1), (x2, y2), color, 2)

        fig, ax = plt.subplots(figsize=(10, 6))
        ax.imshow(img)
        ax.set_title(f"Sample: {img_path.name}")
        ax.axis("off")
        plt.show()
        plt.close(fig)  # avoid accumulating open figures across the loop
        shown += 1

# Spot-check a few auto-labeled frames against their generated boxes.
visualize_sample(image_dir=IMAGES_DIR, label_dir=OUTPUT_DIR)