<a href="https://colab.research.google.com/github/arzhrd/Basketball-Player-Detail-Using-Computer-Vision/blob/main/Basketball_AI.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# 1. Configure API Keys and check for GPU
import os
from google.colab import userdata
from pathlib import Path
import torch

# Load API keys from Colab secrets
try:
    os.environ["HF_TOKEN"] = userdata.get("HF_TOKEN")
    os.environ["ROBOFLOW_API_KEY"] = userdata.get("ROBOFLOW_API_KEY")
    print("API keys loaded successfully.")
except Exception as e:
    print("Could not load API keys. Please set them up in Colab Secrets (🔑).")
    print("Required secrets: 'HF_TOKEN' and 'ROBOFLOW_API_KEY'")

# Check for GPU
!nvidia-smi

# Set home directory
HOME = Path.cwd()
print("HOME:", HOME)

# Set ONNX provider to use GPU
os.environ["ONNXRUNTIME_EXECUTION_PROVIDERS"] = "[CUDAExecutionProvider]"

# 2. Install SAM2 (Segment Anything Model 2) for tracking
!git clone https://github.com/Gy920/segment-anything-2-real-time.git
%cd {HOME}/segment-anything-2-real-time
!pip install -e . -q
!python setup.py build_ext --inplace
!(cd checkpoints && bash download_ckpts.sh)
%cd {HOME}

# 3. Install all other required Python packages
!pip install -q gdown inference-gpu supervision transformers num2words
!pip install -q git+https://github.com/roboflow/sports.git@feat/basketball
!pip install -q flash-attn --no-build-isolation

print("\n✅ All installations are complete.")

API keys loaded successfully.
/bin/bash: line 1: nvidia-smi: command not found
HOME: /content
Cloning into 'segment-anything-2-real-time'...
remote: Enumerating objects: 406, done.[K
remote: Counting objects: 100% (93/93), done.[K
remote: Compressing objects: 100% (37/37), done.[K
remote: Total 406 (delta 65), reused 56 (delta 56), pack-reused 313 (from 2)[K
Receiving objects: 100% (406/406), 79.43 MiB | 30.52 MiB/s, done.
Resolving deltas: 100% (91/91), done.
/content/segment-anything-2-real-time
  Installing build dependencies ... [?25l[?25hdone
  Checking if build backend supports build_editable ... [?25l[?25hdone
  Getting requirements to build editable ... [?25l[?25hdone
  Preparing editable metadata (pyproject.toml) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m42.2/42.2 kB[0m [31m1.9 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m 

In [2]:
# 1. Download sample videos and fonts
SOURCE_VIDEO_DIRECTORY = HOME / "source"
!gdown -q https://drive.google.com/drive/folders/1eDJYqQ77Fytz15tKGdJCMeYSgmoQ-2-H -O {SOURCE_VIDEO_DIRECTORY} --folder
!gdown -q https://drive.google.com/drive/folders/1RBjpI5Xleb58lujeusxH0W5zYMMA4ytO -O {HOME / "fonts"} --folder
print("Sample videos and fonts downloaded.")

# 2. Define the source video path you want to process
# You can change the filename to process a different clip from the `source` directory
SOURCE_VIDEO_PATH = SOURCE_VIDEO_DIRECTORY / "boston-celtics-new-york-knicks-game-1-q1-04.28-04.20.mp4"

# 3. Define Team Rosters and Colors
TEAM_ROSTERS = {
  "New York Knicks": {
    "55": "Hukporti", "1": "Payne", "0": "Wright", "11": "Brunson", "3": "Hart",
    "32": "Towns", "44": "Shamet", "25": "Bridges", "2": "McBride",
    "23": "Robinson", "8": "Anunoby", "4": "Dadiet", "5": "Achiuwa", "13": "Kolek"
  },
  "Boston Celtics": {
    "42": "Horford", "55": "Scheierman", "9": "White", "20": "Davison",
    "7": "Brown", "0": "Tatum", "27": "Walsh", "4": "Holiday", "8": "Porzingis",
    "40": "Kornet", "88": "Queta", "11": "Pritchard", "30": "Hauser",
    "12": "Craig", "26": "Tillman"
  }
}

TEAM_COLORS = {
    "New York Knicks": "#006BB6",
    "Boston Celtics": "#007A33"
}

Sample videos and fonts downloaded.


In [None]:
import supervision as sv
from inference import get_model
import torch

# 1. Load Player and Number Detection Model (RF-DETR)
PLAYER_DETECTION_MODEL_ID = "basketball-player-detection-3-ycjdo/4"
PLAYER_DETECTION_MODEL = get_model(model_id=PLAYER_DETECTION_MODEL_ID)

# 2. Load Player Tracking Model (SAM2.1)
# The sam2 import and the relative checkpoint path are resolved from inside the cloned repo,
# so we must switch into it first. os.chdir + try/finally guarantees we always return to HOME,
# even if loading fails. Avoid `%cd {HOME} # comment`: IPython's %cd treats a trailing
# comment as part of the path, so the change-back silently fails.
SAM2_REPO_DIR = HOME / "segment-anything-2-real-time"
SAM2_CHECKPOINT = "checkpoints/sam2.1_hiera_large.pt"
SAM2_CONFIG = "configs/sam2.1/sam2.1_hiera_l.yaml"
os.chdir(SAM2_REPO_DIR)
try:
    from sam2.build_sam import build_sam2_camera_predictor
    sam_predictor = build_sam2_camera_predictor(SAM2_CONFIG, SAM2_CHECKPOINT)
finally:
    os.chdir(HOME)

# 3. Load Jersey Number Recognition Model (SmolVLM2)
NUMBER_RECOGNITION_MODEL_ID = "basketball-jersey-numbers-ocr/3"
NUMBER_RECOGNITION_MODEL = get_model(model_id=NUMBER_RECOGNITION_MODEL_ID)
NUMBER_RECOGNITION_MODEL_PROMPT = "Read the number."

print("\n✅ All models loaded successfully.")

In [None]:
import numpy as np
import supervision as sv
from tqdm import tqdm
from sports.common.team import TeamClassifier

# Detector class ids treated as "player" detections in this notebook.
PLAYER_CLASS_IDS = [3, 4, 5, 6, 7]

def shrink_boxes(xyxy: np.ndarray, scale: float) -> np.ndarray:
    """Scale every (x1, y1, x2, y2) box about its own center.

    Called with scale < 1 to keep only the central region of a player box,
    which is where the jersey is.

    Args:
        xyxy: array with one (x1, y1, x2, y2) row per box.
        scale: factor applied to both width and height; centers are unchanged.

    Returns:
        An (N, 4) array of rescaled boxes, in the same order as the input.
    """
    top_left = xyxy[:, :2]
    bottom_right = xyxy[:, 2:4]
    centers = (top_left + bottom_right) / 2
    half_extent = (bottom_right - top_left) * scale / 2
    return np.concatenate([centers - half_extent, centers + half_extent], axis=1)

# 1. Collect player crops from all videos to build a training set.
#    Every 30th frame is used, and each player box is shrunk to its central (jersey) region.
crops = []
for video_path in sv.list_files_with_extensions(SOURCE_VIDEO_DIRECTORY, extensions=["mp4"]):
    frame_generator = sv.get_video_frames_generator(source_path=str(video_path), stride=30)
    for frame in tqdm(frame_generator, desc=f"Processing {video_path.name}"):
        result = PLAYER_DETECTION_MODEL.infer(frame, confidence=0.4, iou_threshold=0.9, class_agnostic_nms=True)[0]
        detections = sv.Detections.from_inference(result)
        detections = detections[np.isin(detections.class_id, PLAYER_CLASS_IDS)]
        boxes = shrink_boxes(xyxy=detections.xyxy, scale=0.4)
        crops.extend(sv.crop_image(frame, box) for box in boxes)

# Fail early with a readable message instead of an obscure error from fitting on no data.
if not crops:
    raise RuntimeError(
        f"No player crops were collected from {SOURCE_VIDEO_DIRECTORY}; "
        "cannot train the team classifier."
    )

# 2. Train the team classifier and predict teams for the collected crops
team_classifier = TeamClassifier(device="cuda")
team_classifier.fit(crops)
teams = team_classifier.predict(crops)

# 3. Display the results of clustering for manual verification
team_0 = [crop for crop, team in zip(crops, teams) if team == 0]
team_1 = [crop for crop, team in zip(crops, teams) if team == 1]

print("--- CLUSTER 0 ---")
sv.plot_images_grid(images=team_0[:20], grid_size=(2, 10), size=(10, 2))
print("\n--- CLUSTER 1 ---")
sv.plot_images_grid(images=team_1[:20], grid_size=(2, 10), size=(10, 2))

In [None]:
# 4. MANUALLY ASSIGN TEAM NAMES BASED ON THE GRIDS ABOVE
# The classifier's cluster ids do not carry team identity. Look at the Cluster 0 and
# Cluster 1 grids: if Cluster 0 shows Boston Celtics players, set SWAP_TEAMS = True.
SWAP_TEAMS = False

TEAM_NAMES = {
    0: "New York Knicks",
    1: "Boston Celtics",
}
if SWAP_TEAMS:
    TEAM_NAMES = {0: TEAM_NAMES[1], 1: TEAM_NAMES[0]}

print("Team names assigned:")
print(f"Cluster 0 -> {TEAM_NAMES[0]}")
print(f"Cluster 1 -> {TEAM_NAMES[1]}")

In [None]:
# This cell contains helper classes and functions needed for the main loop.
# It includes the PropertyValidator for confirming numbers and functions for mask conversion.
from typing import Dict, List, Optional, Union, Iterable, Tuple
import cv2

def xyxy_to_mask(boxes: np.ndarray, resolution_wh: Tuple[int, int]) -> np.ndarray:
    """Rasterize boxes into filled boolean masks.

    Coordinates are truncated with int(); x2/y2 are treated as inclusive pixel
    indices. Each box is clipped to the frame, and a box that ends up empty
    (e.g. lying fully outside the frame) yields an all-False mask.

    Args:
        boxes: one (x1, y1, x2, y2) row per box.
        resolution_wh: frame size as (width, height).

    Returns:
        A bool array of shape (len(boxes), height, width).
    """
    width, height = resolution_wh
    masks = np.zeros((boxes.shape[0], height, width), dtype=bool)
    for index, (left, top, right, bottom) in enumerate(boxes):
        left = max(0, int(left))
        top = max(0, int(top))
        right = min(width - 1, int(right))
        bottom = min(height - 1, int(bottom))
        if right < left or bottom < top:
            continue  # nothing of this box survives clipping
        masks[index, top:bottom + 1, left:right + 1] = True
    return masks

def coords_above_threshold(matrix: np.ndarray, threshold: float) -> List[Tuple[int, int]]:
    """Return the (row, col) indices of entries strictly greater than `threshold`.

    Pairs are ordered by their value, highest first. Ties keep row-major order,
    because Python's sort is stable.

    Args:
        matrix: a 2-D array, or anything `np.asarray` can turn into one
            (e.g. nested lists).
        threshold: exclusive lower bound.

    Returns:
        A list of (row, col) tuples of Python ints; empty if nothing qualifies.
    """
    # Convert once and use the converted array for both thresholding and sorting.
    # Indexing the raw input with [r, c] would fail for nested lists.
    values = np.asarray(matrix)
    rows, cols = np.where(values > threshold)
    pairs = list(zip(rows.tolist(), cols.tolist()))
    pairs.sort(key=lambda rc: values[rc[0], rc[1]], reverse=True)
    return pairs

def filter_segments_by_distance(mask: np.ndarray, distance_threshold: float = 300) -> np.ndarray:
    """Drop connected components that lie far from the mask's largest component.

    Args:
        mask: a 2-D mask; it is converted to uint8 for OpenCV.
        distance_threshold: maximum Euclidean distance, in pixels, between a
            component's centroid and the largest component's centroid for that
            component to be kept.

    Returns:
        A bool mask made of the 8-connected components that pass the distance
        test (always including the largest one). If the mask has no foreground
        component, a copy of the input is returned unchanged.
    """
    num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(
        mask.astype(np.uint8), connectivity=8
    )
    if num_labels <= 1:
        # Only the background label (0) exists: nothing to filter.
        return mask.copy()
    # Skip label 0 (background) when looking for the largest component by area.
    largest_label = 1 + np.argmax(stats[1:, cv2.CC_STAT_AREA])
    anchor = centroids[largest_label]
    kept_labels = [
        label
        for label in range(1, num_labels)
        if np.linalg.norm(centroids[label] - anchor) <= distance_threshold
    ]
    return np.isin(labels, kept_labels)

Value = Union[int, str, None]
class PropertyValidator:
    """Accept a per-tracker property only after it is read identically N times in a row.

    Raw values are normalized to stripped strings. A None or blank value resets
    that tracker's streak. Once a tracker's value has been validated it is frozen,
    and later updates for that tracker are ignored.
    """

    def __init__(self, n_consecutive: int):
        # Number of identical consecutive observations required for validation.
        self.n = n_consecutive
        self._streak: Dict[int, int] = {}               # current run length per tracker
        self._last: Dict[int, Optional[str]] = {}       # most recent non-empty value per tracker
        self._validated: Dict[int, Optional[str]] = {}  # frozen, confirmed value per tracker

    def _normalize(self, value: Value) -> Optional[str]:
        """Convert a raw value to a stripped string, or None if it is missing/blank."""
        if value is None:
            return None
        text = str(value).strip()
        return text or None

    def update(self, tracker_ids: List[int], values: List[Value]):
        """Feed one observation per tracker (pairs beyond the shorter list are ignored)."""
        for tracker_id, raw_value in zip(tracker_ids, values):
            if self._validated.get(tracker_id) is not None:
                continue  # already confirmed; never overwritten
            candidate = self._normalize(raw_value)
            if candidate is None:
                self._streak[tracker_id] = 0
                continue
            if candidate == self._last.get(tracker_id):
                streak = self._streak.get(tracker_id, 0) + 1
            else:
                streak = 1
                self._last[tracker_id] = candidate
            self._streak[tracker_id] = streak
            if streak >= self.n:
                self._validated[tracker_id] = candidate

    def get_validated(self, tracker_ids: Iterable[int]) -> List[Optional[str]]:
        """Return the validated value (or None) for each tracker id, in order."""
        return list(map(self._validated.get, tracker_ids))

print("Helper utilities defined.")

In [None]:
# This is the main processing loop. It will take a few minutes to run.
frames_history = []
detections_history = []
NUMBER_CLASS_ID = 2
# A number box must lie (almost) entirely inside a player mask to be credited to that player.
# Overlap is measured as Intersection over the Smaller area (IoS).
NUMBER_MATCH_IOS_THRESHOLD = 0.9
# Number detection + OCR is expensive, so it runs only on every Nth frame.
NUMBER_RECOGNITION_STRIDE = 5

# Initialize validators: a jersey number must be read identically 3 times before it is
# trusted, while the team predicted on the first frame is accepted immediately.
number_validator = PropertyValidator(n_consecutive=3)
team_validator = PropertyValidator(n_consecutive=1)

# 1. Process the first frame to initialize trackers
frame_generator = sv.get_video_frames_generator(str(SOURCE_VIDEO_PATH))
frame = next(frame_generator)

# Detect players and determine their teams
result = PLAYER_DETECTION_MODEL.infer(frame, confidence=0.4, iou_threshold=0.9, class_agnostic_nms=True)[0]
detections = sv.Detections.from_inference(result)
detections = detections[np.isin(detections.class_id, PLAYER_CLASS_IDS)]
if len(detections) == 0:
    raise RuntimeError(
        "No players detected on the first frame, so there is nothing to prompt SAM2 with. "
        "Try a different SOURCE_VIDEO_PATH."
    )
# Tracker ids 1..N follow detection order and double as SAM2 object ids.
TRACKER_IDS = list(range(1, len(detections) + 1))
boxes = shrink_boxes(xyxy=detections.xyxy, scale=0.4)
crops = [sv.crop_image(frame, box) for box in boxes]
TEAMS = np.array(team_classifier.predict(crops))
team_validator.update(tracker_ids=TRACKER_IDS, values=TEAMS)

# Prompt SAM2.1 tracker with the initial player detections
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
    sam_predictor.load_first_frame(frame)
    for xyxy, tracker_id in zip(detections.xyxy, TRACKER_IDS):
        sam_predictor.add_new_prompt(frame_idx=0, obj_id=tracker_id, bbox=np.array([xyxy]))

# 2. Loop through the video frames (a fresh generator, so the first frame is tracked too)
frame_generator = sv.get_video_frames_generator(str(SOURCE_VIDEO_PATH))
for index, frame in tqdm(enumerate(frame_generator), desc="Processing video"):
    frame_h, frame_w, *_ = frame.shape

    # Track players
    with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
        tracker_ids, mask_logits = sam_predictor.track(frame)
        masks = (mask_logits > 0.0).cpu().numpy()
        # Collapse all leading axes into a single "object" axis, giving (num_objects, H, W).
        # A bare .squeeze() would also drop the object axis when exactly one player is tracked
        # (presumably (1, 1, H, W) -> (H, W)), and the loop below would then iterate over
        # image rows instead of masks.
        masks = masks.reshape(-1, *masks.shape[-2:]).astype(bool)
        player_masks = np.array([filter_segments_by_distance(mask) for mask in masks])
        player_detections = sv.Detections(
            xyxy=sv.mask_to_xyxy(masks=player_masks),
            mask=player_masks,
            tracker_id=np.array(tracker_ids)
        )

    frames_history.append(frame)
    detections_history.append(player_detections)

    # Recognize numbers only every NUMBER_RECOGNITION_STRIDE frames, for efficiency
    if index % NUMBER_RECOGNITION_STRIDE != 0:
        continue

    result = PLAYER_DETECTION_MODEL.infer(frame, confidence=0.4, iou_threshold=0.9)[0]
    number_detections = sv.Detections.from_inference(result)
    number_detections = number_detections[number_detections.class_id == NUMBER_CLASS_ID]
    if len(number_detections) == 0:
        continue  # no jersey numbers visible: skip OCR and matching
    number_detections.mask = xyxy_to_mask(boxes=number_detections.xyxy, resolution_wh=(frame_w, frame_h))

    padded_boxes = sv.clip_boxes(sv.pad_boxes(xyxy=number_detections.xyxy, px=10), (frame_w, frame_h))
    number_crops = [sv.crop_image(frame, xyxy) for xyxy in padded_boxes]
    numbers = [NUMBER_RECOGNITION_MODEL.predict(crop, NUMBER_RECOGNITION_MODEL_PROMPT)[0] for crop in number_crops]

    # Match numbers to players using mask Intersection over Smaller area (IoS).
    ios = sv.mask_iou_batch(player_masks, number_detections.mask, sv.OverlapMetric.IOS)
    # Greedy one-to-one assignment: pairs come back best-overlap first, and each player and
    # each number box is used at most once per frame. Without this, a number box inside two
    # overlapping players would be credited to both, and a player matched twice would
    # advance its validation streak twice in a single frame.
    used_players, used_numbers = set(), set()
    matched_tracker_ids, matched_numbers = [], []
    for player_index, number_index in coords_above_threshold(ios, NUMBER_MATCH_IOS_THRESHOLD):
        if player_index in used_players or number_index in used_numbers:
            continue
        used_players.add(player_index)
        used_numbers.add(number_index)
        matched_tracker_ids.append(player_detections.tracker_id[player_index])
        matched_numbers.append(numbers[number_index])
    if matched_tracker_ids:
        number_validator.update(tracker_ids=matched_tracker_ids, values=matched_numbers)

print("\n✅ Video processing complete.")

NameError: name 'PropertyValidator' is not defined

In [None]:
from IPython.display import Video

TARGET_VIDEO_PATH = HOME / f"{SOURCE_VIDEO_PATH.stem}-result.mp4"
TARGET_VIDEO_COMPRESSED_PATH = HOME / f"{SOURCE_VIDEO_PATH.stem}-result-compressed.mp4"

video_info = sv.VideoInfo.from_video_path(str(SOURCE_VIDEO_PATH))

# Set up annotators with team colors
team_colors = sv.ColorPalette.from_hex([
    TEAM_COLORS[TEAM_NAMES[0]],
    TEAM_COLORS[TEAM_NAMES[1]]
])
team_mask_annotator = sv.MaskAnnotator(color=team_colors, opacity=0.5, color_lookup=sv.ColorLookup.INDEX)
team_label_annotator = sv.RichLabelAnnotator(
    font_path=f"{HOME}/fonts/Staatliches-Regular.ttf",
    font_size=40, color=team_colors, text_color=sv.Color.WHITE,
    text_position=sv.Position.BOTTOM_CENTER, text_offset=(0, 10),
    color_lookup=sv.ColorLookup.INDEX
)

# Write the annotated frames to a new video file
with sv.VideoSink(str(TARGET_VIDEO_PATH), video_info) as sink:
    for frame, detections in tqdm(zip(frames_history, detections_history), desc="Generating final video", total=len(frames_history)):
        detections = detections[detections.area > 100]
        if len(detections) == 0:
            sink.write_frame(frame)
            continue

        # Get validated team and number for each player
        teams = np.array(team_validator.get_validated(tracker_ids=detections.tracker_id)).astype(int)
        numbers = np.array(number_validator.get_validated(tracker_ids=detections.tracker_id))

        # Create labels with number and player name
        labels = []
        for number, team in zip(numbers, teams):
            if number:
                player_name = TEAM_ROSTERS[TEAM_NAMES[team]].get(number, "")
                labels.append(f"#{number} {player_name}")
            else:
                labels.append(TEAM_NAMES[team])

        # Annotate the frame
        annotated_frame = frame.copy()
        annotated_frame = team_mask_annotator.annotate(scene=annotated_frame, detections=detections, custom_color_lookup=teams)
        annotated_frame = team_label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels, custom_color_lookup=teams)
        sink.write_frame(annotated_frame)

# Compress the video for easier viewing in Colab
!ffmpeg -y -loglevel error -i {TARGET_VIDEO_PATH} -vcodec libx264 -crf 28 {TARGET_VIDEO_COMPRESSED_PATH}

# Display the final result! 🏀
Video(TARGET_VIDEO_COMPRESSED_PATH, embed=True, width=1080)

NameError: name 'HOME' is not defined