<a href="https://colab.research.google.com/github/arzhrd/Basketball-Player-Detail-Using-Computer-Vision/blob/main/Basketball_AI.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Install the notebook's dependencies with Colab shell commands (no versions are pinned).
# NOTE(review): the cells shown in this notebook do not appear to import torch/torchvision,
# transformers, pillow or supervision directly - confirm before trimming this list.
!pip install ultralytics opencv-python-headless scikit-learn
!pip install torch torchvision
!pip install transformers pillow
!pip install supervision

Collecting ultralytics
  Downloading ultralytics-8.3.222-py3-none-any.whl.metadata (37 kB)
Collecting ultralytics-thop>=2.0.0 (from ultralytics)
  Downloading ultralytics_thop-2.0.18-py3-none-any.whl.metadata (14 kB)
Downloading ultralytics-8.3.222-py3-none-any.whl (1.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m28.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading ultralytics_thop-2.0.18-py3-none-any.whl (28 kB)
Installing collected packages: ultralytics-thop, ultralytics
Successfully installed ultralytics-8.3.222 ultralytics-thop-2.0.18
Collecting supervision
  Downloading supervision-0.26.1-py3-none-any.whl.metadata (13 kB)
Downloading supervision-0.26.1-py3-none-any.whl (207 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m207.2/207.2 kB[0m [31m8.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: supervision
Successfully installed supervision-0.26.1


In [2]:
import os
import pickle
from base64 import b64encode
from collections import defaultdict, Counter
from pathlib import Path

import cv2
import numpy as np
import pandas as pd
from IPython.display import HTML, display
from sklearn.cluster import KMeans
from ultralytics import YOLO

print("All libraries imported successfully!")

Creating new Ultralytics Settings v0.0.6 file ✅ 
View Ultralytics Settings with 'yolo settings' or at '/root/.config/Ultralytics/settings.json'
Update Settings with 'yolo settings key=value', i.e. 'yolo settings runs_dir=path/to/dir'. For help see https://docs.ultralytics.com/quickstart/#ultralytics-settings.
All libraries imported successfully!


In [3]:
from google.colab import files

# Upload your basketball video
print("Please upload your basketball video file...")
uploaded = files.upload()

# Fail with a readable message instead of an IndexError when nothing was uploaded.
if not uploaded:
    raise RuntimeError("No video file was uploaded; re-run this cell and choose a file.")

# Only the first uploaded file is used as the pipeline input.
input_video_path = list(uploaded.keys())[0]
print(f"Video uploaded: {input_video_path}")

Please upload your basketball video file...


Saving video_1.mp4 to video_1.mp4
Video uploaded: video_1.mp4


In [4]:
def read_video(video_path):
    """Load every frame of a video into memory.

    Args:
        video_path: Path handed to ``cv2.VideoCapture``.

    Returns:
        list: The frames in the order ``read()`` produced them; empty when
        the very first read fails.
    """
    capture = cv2.VideoCapture(video_path)
    collected = []
    # read() reports success first; keep pulling frames until it reports failure.
    ok, image = capture.read()
    while ok:
        collected.append(image)
        ok, image = capture.read()
    capture.release()
    return collected

def save_video(output_video_frames, output_video_path, fps=24):
    """Encode a sequence of frames to a video file using the 'mp4v' fourcc.

    Args:
        output_video_frames: Non-empty sequence of frames. The first frame's
            ``shape[1]`` / ``shape[0]`` define the output width / height.
        output_video_path: Destination file path.
        fps: Frame rate passed to the writer (default 24).

    Raises:
        ValueError: If ``output_video_frames`` is empty.
        IOError: If the video writer could not be opened for the given path.
    """
    # An empty sequence would otherwise fail with an opaque IndexError below.
    if len(output_video_frames) == 0:
        raise ValueError("save_video() needs at least one frame to write")

    height, width = output_video_frames[0].shape[:2]
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_video_path, fourcc, fps, (width, height))
    try:
        # Without this check a writer that failed to open silently writes nothing.
        if not out.isOpened():
            raise IOError(f"Could not open video writer for: {output_video_path}")
        for frame in output_video_frames:
            out.write(frame)
    finally:
        # Always release so the file is finalized even if a write fails.
        out.release()
    print(f"Video saved to: {output_video_path}")

def display_video(video_path):
    """Build an inline HTML5 player for a video file.

    The file is read fully into memory and embedded as a base64 ``data:`` URL,
    so this is only suitable for short clips.

    Args:
        video_path: Path of the video file to embed.

    Returns:
        IPython.display.HTML: Markup containing a ``<video>`` element whose
        source is the embedded file.
    """
    # Context manager so the file handle is closed deterministically.
    with open(video_path, 'rb') as video_file:
        mp4 = video_file.read()
    data_url = "data:video/mp4;base64," + b64encode(mp4).decode()
    # NOTE(review): browsers usually need H.264 inside mp4; a file written with
    # the 'mp4v' fourcc may not play inline - confirm, re-encode if needed.
    return HTML(f"""
    <video width="800" controls>
        <source src="{data_url}" type="video/mp4">
    </video>
    """)

In [5]:
def get_center_of_bbox(bbox):
    """Return the midpoint of a box as a pair of ints.

    Args:
        bbox: Four values unpacked as (x1, y1, x2, y2).

    Returns:
        tuple: (center_x, center_y), each truncated with ``int()``.
    """
    left, top, right, bottom = bbox
    center_x = int((left + right) / 2)
    center_y = int((top + bottom) / 2)
    return center_x, center_y

def get_bbox_width(bbox):
    """Return the horizontal extent of a box (third element minus first)."""
    right_edge = bbox[2]
    left_edge = bbox[0]
    return right_edge - left_edge

def measure_distance(p1, p2):
    """Return the straight-line distance between two 2-D points.

    Args:
        p1: Point indexed as (x, y).
        p2: Point indexed as (x, y).

    Returns:
        float: Square root of the summed squared coordinate differences.
    """
    delta_x = p1[0] - p2[0]
    delta_y = p1[1] - p2[1]
    return (delta_x**2 + delta_y**2)**0.5

def get_foot_position(bbox):
    """Return the bottom-center point of a box as a pair of ints.

    Args:
        bbox: Four values unpacked as (x1, y1, x2, y2).

    Returns:
        tuple: (horizontal midpoint, y2), each truncated with ``int()``.
    """
    left, _top, right, bottom = bbox
    foot_x = int((left + right) / 2)
    foot_y = int(bottom)
    return foot_x, foot_y

In [6]:
class PlayerTracker:
    """Detect and track people across video frames with an Ultralytics YOLO model."""

    def __init__(self, model_path='yolov8x.pt'):
        """Initialize player tracker with YOLO model.

        Args:
            model_path: Weights file or model name passed to ``YOLO``.
        """
        self.model = YOLO(model_path)

    def detect_frames(self, frames, read_from_stub=False, stub_path=None):
        """Detect players in every frame, optionally using a pickle cache.

        Args:
            frames: Sequence of frames to process.
            read_from_stub: When True and ``stub_path`` exists, try to reuse
                the cached detections instead of running the model.
            stub_path: Optional pickle file used to read and/or store results.

        Returns:
            list: One ``{track_id: [x1, y1, x2, y2]}`` dict per frame.
        """
        if read_from_stub and stub_path is not None and os.path.exists(stub_path):
            # SECURITY: pickle can execute arbitrary code while loading; only
            # read stub files that this notebook produced itself.
            with open(stub_path, 'rb') as f:
                cached_detections = pickle.load(f)
            # A stub recorded for a different video would silently desynchronize
            # detections and frames, so it is only reused when the lengths agree.
            if len(cached_detections) == len(frames):
                return cached_detections
            print(f"Stub '{stub_path}' does not match the number of frames; re-running detection")

        player_detections = [self.detect_frame(frame) for frame in frames]

        if stub_path is not None:
            with open(stub_path, 'wb') as f:
                pickle.dump(player_detections, f)

        return player_detections

    def detect_frame(self, frame):
        """Detect players in a single frame.

        Only boxes whose class name is "person" are kept. Boxes that carry a
        tracker id are keyed by it; boxes without one get distinct negative
        keys (-1, -2, ...) so several untracked people in the same frame do
        not overwrite each other.

        Args:
            frame: Image passed to ``model.track``.

        Returns:
            dict: ``{track_id: [x1, y1, x2, y2]}``.
        """
        results = self.model.track(frame, persist=True)[0]
        id_name_dict = results.names

        player_dict = {}
        next_untracked_id = -1
        for box in results.boxes:
            # The class id comes back as a float from tolist(); make the dict
            # lookup explicit instead of relying on 0.0 == 0 hashing.
            object_cls_id = int(box.cls.tolist()[0])
            if id_name_dict[object_cls_id] != "person":
                continue

            if box.id is not None:
                track_id = int(box.id.tolist()[0])
            else:
                track_id = next_untracked_id
                next_untracked_id -= 1

            player_dict[track_id] = box.xyxy.tolist()[0]

        return player_dict

In [7]:
class BallTracker:
    """Detect the ball in video frames with an Ultralytics YOLO model."""

    def __init__(self, model_path='yolov8x.pt'):
        """Initialize ball tracker with YOLO model.

        Args:
            model_path: Weights file or model name passed to ``YOLO``.
        """
        self.model = YOLO(model_path)

    def interpolate_ball_positions(self, ball_positions):
        """Fill in frames that have no ball box.

        Missing boxes are linearly interpolated between known ones; frames
        before the first known box are back-filled from it.

        Args:
            ball_positions: List of per-frame dicts; the ball box, when
                present, is stored under key ``1`` as [x1, y1, x2, y2].

        Returns:
            list: One ``{1: [x1, y1, x2, y2]}`` dict per frame. When no frame
            has a ball box at all there is nothing to interpolate from, and a
            list of empty dicts is returned instead.
        """
        boxes = [x.get(1, []) for x in ball_positions]

        # With zero known boxes pandas cannot even build the 4-column frame.
        if not any(len(box) == 4 for box in boxes):
            return [{} for _ in boxes]

        missing_row = [float('nan')] * 4
        rows = [box if len(box) == 4 else missing_row for box in boxes]
        df_ball_positions = pd.DataFrame(rows, columns=['x1', 'y1', 'x2', 'y2'])

        # interpolate() fills gaps and trailing frames; bfill() covers the
        # frames that precede the first detection.
        df_ball_positions = df_ball_positions.interpolate()
        df_ball_positions = df_ball_positions.bfill()

        return [{1: row} for row in df_ball_positions.to_numpy().tolist()]

    def detect_frames(self, frames, read_from_stub=False, stub_path=None):
        """Detect the ball in every frame, optionally using a pickle cache.

        Args:
            frames: Sequence of frames to process.
            read_from_stub: When True and ``stub_path`` exists, try to reuse
                the cached detections instead of running the model.
            stub_path: Optional pickle file used to read and/or store results.

        Returns:
            list: One dict per frame, ``{1: [x1, y1, x2, y2]}`` or ``{}``.
        """
        if read_from_stub and stub_path is not None and os.path.exists(stub_path):
            # SECURITY: pickle can execute arbitrary code while loading; only
            # read stub files that this notebook produced itself.
            with open(stub_path, 'rb') as f:
                cached_detections = pickle.load(f)
            # Only reuse the cache when it lines up with the frames supplied.
            if len(cached_detections) == len(frames):
                return cached_detections
            print(f"Stub '{stub_path}' does not match the number of frames; re-running detection")

        ball_detections = [self.detect_frame(frame) for frame in frames]

        if stub_path is not None:
            with open(stub_path, 'wb') as f:
                pickle.dump(ball_detections, f)

        return ball_detections

    def detect_frame(self, frame):
        """Detect the ball in a single frame.

        When several "sports ball" boxes are reported, the one with the
        highest confidence is kept (rather than whichever is listed last).

        Args:
            frame: Image passed to ``model.predict``.

        Returns:
            dict: ``{1: [x1, y1, x2, y2]}`` or ``{}`` when no ball was found.
        """
        results = self.model.predict(frame, conf=0.15)[0]
        id_name_dict = results.names

        ball_dict = {}
        best_confidence = -1.0
        for box in results.boxes:
            object_cls_id = int(box.cls.tolist()[0])
            if id_name_dict[object_cls_id] != "sports ball":
                continue

            confidence = box.conf.tolist()[0]
            if confidence > best_confidence:
                best_confidence = confidence
                ball_dict[1] = box.xyxy.tolist()[0]

        return ball_dict

In [8]:
class TeamAssigner:
    """Split players into two teams by clustering their jersey colors."""

    def __init__(self):
        """Initialize team assigner."""
        self.team_colors = {}        # team id (1 or 2) -> cluster-center color
        self.player_team_dict = {}   # player id -> team id, cached on first sight
        self.kmeans = None           # fitted by assign_team_color()

    def get_clustering_model(self, image):
        """Fit a 2-cluster KMeans model on the pixels of an image.

        Args:
            image: Array that can be reshaped to rows of 3 channel values.

        Returns:
            KMeans: The fitted model.
        """
        # Reshape the image to 2D array (one row per pixel)
        image_2d = image.reshape(-1, 3)

        # Fixed seed so repeated runs produce the same clusters.
        kmeans = KMeans(n_clusters=2, init="k-means++", n_init=1, random_state=0)
        kmeans.fit(image_2d)

        return kmeans

    def get_player_color(self, frame, bbox):
        """Estimate the jersey color inside a player's box.

        The top half of the crop is clustered into two colors. The cluster
        that dominates the four corners of that region is treated as
        background and the other cluster's center is returned.

        Args:
            frame: Image the box refers to (indexed as rows, columns, channels).
            bbox: [x1, y1, x2, y2]; clipped to the frame before cropping.

        Returns:
            numpy.ndarray: Three channel values. For crops too small to
            cluster (fewer than two pixels) the mean pixel, or zeros for an
            empty crop, is returned instead.
        """
        frame_height, frame_width = frame.shape[:2]
        # Clip so negative or out-of-frame coordinates cannot wrap around or
        # produce an empty slice by accident.
        x1 = min(max(int(bbox[0]), 0), frame_width)
        y1 = min(max(int(bbox[1]), 0), frame_height)
        x2 = min(max(int(bbox[2]), 0), frame_width)
        y2 = min(max(int(bbox[3]), 0), frame_height)
        image = frame[y1:y2, x1:x2]

        # Get the top half of the image (jersey area)
        top_half_image = image[0:int(image.shape[0]/2), :]
        if top_half_image.size == 0:
            # A box one pixel row tall has no "top half"; use the whole crop.
            top_half_image = image

        pixels = top_half_image.reshape(-1, 3)
        if len(pixels) < 2:
            # KMeans needs at least as many samples as clusters.
            return pixels.mean(axis=0) if len(pixels) else np.zeros(3)

        # Get clustering model
        kmeans = self.get_clustering_model(top_half_image)

        # Reshape the per-pixel labels back to the crop's height x width
        clustered_image = kmeans.labels_.reshape(top_half_image.shape[0], top_half_image.shape[1])

        # The label seen most often in the four corners is taken as background;
        # the player is the other of the two clusters.
        corner_clusters = [clustered_image[0, 0], clustered_image[0, -1],
                           clustered_image[-1, 0], clustered_image[-1, -1]]
        non_player_cluster = max(set(corner_clusters), key=corner_clusters.count)
        player_cluster = 1 - non_player_cluster

        return kmeans.cluster_centers_[player_cluster]

    def assign_team_color(self, frame, player_detections):
        """Fit the two team colors from the players visible in one frame.

        Args:
            frame: Image the boxes refer to.
            player_detections: ``{player_id: [x1, y1, x2, y2]}`` for that frame.

        Raises:
            ValueError: If fewer than two players are supplied, since two
                clusters cannot be fitted.
        """
        if len(player_detections) < 2:
            raise ValueError("assign_team_color() needs at least two player boxes to form two teams")

        player_colors = [self.get_player_color(frame, player_bbox)
                         for player_bbox in player_detections.values()]

        kmeans = KMeans(n_clusters=2, init="k-means++", n_init=10, random_state=0)
        kmeans.fit(player_colors)

        self.kmeans = kmeans

        self.team_colors[1] = kmeans.cluster_centers_[0]
        self.team_colors[2] = kmeans.cluster_centers_[1]

    def get_player_team(self, frame, player_bbox, player_id):
        """Return the team (1 or 2) of a player, caching the answer per id.

        Args:
            frame: Image the box refers to.
            player_bbox: [x1, y1, x2, y2] of the player in ``frame``.
            player_id: Key used for the cache.

        Returns:
            int: Team id; the cached value when this id was seen before.

        Raises:
            RuntimeError: If the id is not cached and ``assign_team_color``
                has not been called yet.
        """
        if player_id in self.player_team_dict:
            return self.player_team_dict[player_id]

        if self.kmeans is None:
            raise RuntimeError("assign_team_color() must be called before get_player_team()")

        player_color = self.get_player_color(frame, player_bbox)

        # Cluster index 0/1 maps to team 1/2; int() drops the numpy scalar type.
        team_id = int(self.kmeans.predict(player_color.reshape(1, -1))[0]) + 1

        self.player_team_dict[player_id] = team_id

        return team_id

In [9]:
class BallPossessionAssigner:
    """Decide which player, if any, is in possession of the ball in a frame."""

    def __init__(self, max_player_ball_distance=70):
        """Store the possession radius.

        Args:
            max_player_ball_distance: A player is only considered when the
                measured distance to the ball is below this value
                (default 70, the original fixed threshold).
        """
        self.max_player_ball_distance = max_player_ball_distance

    def assign_ball_to_player(self, players, ball_bbox):
        """Assign ball to the closest player within the possession radius.

        Distance is measured from each player's bottom-center point to the
        center of the ball box.
        NOTE(review): the bottom-center (feet) reference may be a poor proxy
        for a ball that is held in the hands - confirm for this use case.

        Args:
            players: ``{player_id: [x1, y1, x2, y2]}``.
            ball_bbox: [x1, y1, x2, y2] of the ball; ``None``, an empty list
                or a box containing NaN means the ball is unknown.

        Returns:
            The id of the nearest qualifying player, or -1 when the ball is
            unknown or no player is close enough.
        """
        # No usable ball box -> nobody can be in possession. (v != v is only
        # true for NaN, which int() in the helpers could not convert.)
        if ball_bbox is None or len(ball_bbox) != 4 or any(v != v for v in ball_bbox):
            return -1

        ball_position = get_center_of_bbox(ball_bbox)

        minimum_distance = 99999
        assigned_player = -1

        for player_id, player_bbox in players.items():
            player_position = get_foot_position(player_bbox)
            distance = measure_distance(player_position, ball_position)

            if distance < self.max_player_ball_distance and distance < minimum_distance:
                minimum_distance = distance
                assigned_player = player_id

        return assigned_player

In [18]:
def draw_ellipse(frame, bbox, color, track_id=None):
    """Mark a box with a partial ellipse on its bottom edge and an optional ID badge.

    Args:
        frame: Image drawn on in place.
        bbox: Box indexed as [x1, y1, x2, y2].
        color: Color used for the ellipse outline and the badge fill.
        track_id: When not None, its text is drawn in black on a filled
            40x20 rectangle placed 15 px below the bottom edge.

    Returns:
        The same ``frame`` object that was passed in.
    """
    bottom_y = int(bbox[3])
    center_x, _ = get_center_of_bbox(bbox)
    box_width = get_bbox_width(bbox)

    # Open arc (-45 deg to 235 deg) centered on the bottom edge of the box.
    cv2.ellipse(
        frame,
        center=(center_x, bottom_y),
        axes=(int(box_width), int(0.35*box_width)),
        angle=0.0,
        startAngle=-45,
        endAngle=235,
        color=color,
        thickness=2,
        lineType=cv2.LINE_4,
    )

    # Nothing more to draw without an id.
    if track_id is None:
        return frame

    badge_width, badge_height = 40, 20
    badge_left = center_x - badge_width // 2
    badge_right = center_x + badge_width // 2
    badge_top = (bottom_y - badge_height // 2) + 15
    badge_bottom = (bottom_y + badge_height // 2) + 15

    cv2.rectangle(
        frame,
        (int(badge_left), int(badge_top)),
        (int(badge_right), int(badge_bottom)),
        color,
        cv2.FILLED,
    )

    # Text origin is offset inside the badge.
    text_origin = (int(badge_left + 12), int(badge_top + 15))
    cv2.putText(frame, f"{track_id}", text_origin,
                cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 0), 2)

    return frame

def draw_triangle(frame, bbox, color):
    """Draw a filled, black-outlined triangle pointing down at the top of a box.

    Args:
        frame: Image drawn on in place.
        bbox: Box indexed as [x1, y1, x2, y2].
        color: Fill color of the triangle.

    Returns:
        The same ``frame`` object that was passed in.
    """
    tip_y = int(bbox[1])
    tip_x, _ = get_center_of_bbox(bbox)

    # Tip sits on the top edge of the box; the base is 20 px above it.
    marker = np.array([
        [tip_x, tip_y],
        [tip_x-10, tip_y-20],
        [tip_x+10, tip_y-20],
    ])

    # First pass fills the shape, second pass strokes a black border.
    for paint, thickness in ((color, cv2.FILLED), ((0, 0, 0), 2)):
        cv2.drawContours(frame, [marker], 0, paint, thickness)

    return frame

def draw_team_possession(frame, frame_num, team_possession):
    """Draw the running team-possession percentages on a frame.

    A translucent white panel is placed in the bottom-right corner. Its size
    and position are scaled from a 1920x1080 reference layout, so on a
    1920x1080 frame the panel spans (1350, 850)-(1900, 970) and on smaller or
    larger frames it stays inside the image.

    Args:
        frame: Image drawn on in place.
        frame_num: Index of this frame; entries 0..frame_num are counted.
        team_possession: Per-frame team ids (1 or 2; other values are
            ignored). Lists and arrays are both accepted.

    Returns:
        The same ``frame`` object that was passed in.
    """
    frame_height, frame_width = frame.shape[:2]
    # Scale relative to the reference layout the original coordinates assumed.
    scale = min(frame_width / 1920, frame_height / 1080)

    panel_x2 = frame_width - int(round(20 * scale))
    panel_y2 = frame_height - int(round(110 * scale))
    panel_x1 = panel_x2 - int(round(550 * scale))
    panel_y1 = panel_y2 - int(round(120 * scale))

    overlay = frame.copy()
    cv2.rectangle(overlay, (panel_x1, panel_y1), (panel_x2, panel_y2), (255, 255, 255), cv2.FILLED)
    alpha = 0.4
    cv2.addWeighted(overlay, alpha, frame, 1 - alpha, 0, frame)

    # asarray so a plain Python list works with the comparisons below.
    team_possession_till_frame = np.asarray(team_possession)[:frame_num+1]
    team_1_frames = int(np.count_nonzero(team_possession_till_frame == 1))
    team_2_frames = int(np.count_nonzero(team_possession_till_frame == 2))

    total_frames = team_1_frames + team_2_frames
    if total_frames > 0:
        team_1_percent = team_1_frames / total_frames
        team_2_percent = team_2_frames / total_frames
    else:
        # Nobody has had the ball yet; avoid dividing by zero.
        team_1_percent = 0
        team_2_percent = 0

    text_x = panel_x1 + int(round(50 * scale))
    font_scale = 1 * scale
    thickness = max(1, int(round(3 * scale)))

    cv2.putText(frame, f"Team 1 Possession: {team_1_percent*100:.1f}%",
                (text_x, panel_y1 + int(round(50 * scale))),
                cv2.FONT_HERSHEY_SIMPLEX, font_scale, (0, 0, 0), thickness)
    cv2.putText(frame, f"Team 2 Possession: {team_2_percent*100:.1f}%",
                (text_x, panel_y1 + int(round(100 * scale))),
                cv2.FONT_HERSHEY_SIMPLEX, font_scale, (0, 0, 0), thickness)

    return frame

def _is_drawable_bbox(bbox):
    """Return True when bbox holds four values and none of them is NaN."""
    return bbox is not None and len(bbox) == 4 and all(v == v for v in bbox)


def draw_annotations(video_frames, tracks, team_possession):
    """Draw player, ball and possession annotations on copies of the frames.

    Args:
        video_frames: Sequence of frames; each one is copied before drawing.
        tracks: Dict with per-frame lists under "players" and "ball".
            A player entry is either a bare [x1, y1, x2, y2] box (drawn in
            the default red) or a dict with 'bbox' and optional 'team_color'
            and 'has_ball'. A ball entry maps an id to its box.
        team_possession: Per-frame team ids forwarded to
            ``draw_team_possession``.

    Returns:
        list: The annotated frame copies, in order.
    """
    output_video_frames = []
    for frame_num, frame in enumerate(video_frames):
        frame = frame.copy()

        player_dict = tracks["players"][frame_num]
        ball_dict = tracks["ball"][frame_num]

        # Draw Players
        for track_id, player_info in player_dict.items():
            if isinstance(player_info, dict):
                actual_bbox = player_info.get('bbox')
                color = player_info.get("team_color", (0, 0, 255))
                has_ball = player_info.get('has_ball', False)
            else:
                actual_bbox = player_info
                color = (0, 0, 255)  # Default red
                has_ball = False

            # A dict without a usable 'bbox' cannot be drawn; skip it rather
            # than handing the dict itself to the drawing helpers.
            if not _is_drawable_bbox(actual_bbox):
                continue

            # Team colors may arrive as numpy arrays; pass plain floats to OpenCV.
            color = tuple(float(channel) for channel in color)

            frame = draw_ellipse(frame, actual_bbox, color, track_id)

            if has_ball:
                frame = draw_triangle(frame, actual_bbox, (0, 0, 255))

        # Draw ball
        for _, ball_bbox in ball_dict.items():
            if _is_drawable_bbox(ball_bbox):
                frame = draw_triangle(frame, ball_bbox, (0, 255, 0))

        # Draw Team Possession
        frame = draw_team_possession(frame, frame_num, team_possession)

        output_video_frames.append(frame)

    return output_video_frames

In [19]:
# Read video
print("Reading video frames...")
video_frames = read_video(input_video_path)
print(f"Total frames: {len(video_frames)}")
if not video_frames:
    raise RuntimeError(f"No frames could be read from: {input_video_path}")

# Initialize Trackers
print("\nInitializing trackers...")
player_tracker = PlayerTracker(model_path='yolov8x.pt')
ball_tracker = BallTracker(model_path='yolov8x.pt')

# Detect players and ball
print("\nDetecting players...")
player_detections = player_tracker.detect_frames(video_frames,
                                                  read_from_stub=False,
                                                  stub_path='player_detections.pkl')
print(f"Player detections completed: {len(player_detections)} frames")

print("\nDetecting ball...")
ball_detections = ball_tracker.detect_frames(video_frames,
                                             read_from_stub=False,
                                             stub_path='ball_detections.pkl')

# Interpolate ball positions
try:
    ball_detections = ball_tracker.interpolate_ball_positions(ball_detections)
    print("Ball position interpolation completed")
except Exception as error:
    # Best effort: keep the raw detections, but report the real reason
    # instead of hiding it behind a bare except.
    print(f"Ball interpolation skipped: {error}")

# Assign players to teams
print("\nAssigning teams...")
team_assigner = TeamAssigner()

# Two clusters need at least two players, so seed the team colors from the
# first frame that has them (frame 0 whenever it qualifies).
seed_frame_num = next(
    (num for num, detection in enumerate(player_detections) if len(detection) >= 2),
    None,
)
if seed_frame_num is None:
    raise RuntimeError("No frame contains at least two players; cannot assign teams")
team_assigner.assign_team_color(video_frames[seed_frame_num], player_detections[seed_frame_num])

# Build per-frame player records without mutating the raw detections, so this
# cell can be re-run. The bbox must be kept: possession and drawing need it.
player_tracks = []
for frame_num, player_detection in enumerate(player_detections):
    frame_tracks = {}
    for player_id, bbox in player_detection.items():
        team = team_assigner.get_player_team(video_frames[frame_num], bbox, player_id)
        frame_tracks[player_id] = {
            'bbox': bbox,
            'team': team,
            # Plain floats so the color can be passed straight to OpenCV.
            'team_color': tuple(float(channel) for channel in team_assigner.team_colors[team]),
            'has_ball': False,
        }
    player_tracks.append(frame_tracks)

# Assign ball possession
print("\nAssigning ball possession...")
ball_possession_assigner = BallPossessionAssigner()
team_possession = []

for frame_num, frame_tracks in enumerate(player_tracks):
    ball_bbox = ball_detections[frame_num].get(1)
    player_bboxes = {player_id: info['bbox'] for player_id, info in frame_tracks.items()}

    # Skip assignment when the ball is missing or not a finite box
    # (v != v is only true for NaN).
    ball_is_known = (
        ball_bbox is not None
        and len(ball_bbox) == 4
        and not any(v != v for v in ball_bbox)
    )
    if ball_is_known:
        assigned_player = ball_possession_assigner.assign_ball_to_player(player_bboxes, ball_bbox)
    else:
        assigned_player = -1

    if assigned_player != -1:
        frame_tracks[assigned_player]['has_ball'] = True
        team_possession.append(frame_tracks[assigned_player]['team'])
    else:
        # Nobody has the ball: carry the previous team forward (0 = none yet).
        team_possession.append(team_possession[-1] if team_possession else 0)

team_possession = np.array(team_possession)

# Prepare tracks dictionary (player entries are dicts with bbox/team/team_color/has_ball)
tracks = {
    "players": player_tracks,
    "ball": ball_detections
}

print("\nProcessing complete! Preparing output video...")

Reading video frames...
Total frames: 117

Initializing trackers...

Detecting players...

0: 384x640 10 persons, 1 sports ball, 63.5ms
Speed: 1.9ms preprocess, 63.5ms inference, 10.8ms postprocess per image at shape (1, 3, 384, 640)

0: 384x640 10 persons, 1 sports ball, 49.8ms
Speed: 2.3ms preprocess, 49.8ms inference, 12.5ms postprocess per image at shape (1, 3, 384, 640)

0: 384x640 10 persons, 1 sports ball, 33.6ms
Speed: 2.0ms preprocess, 33.6ms inference, 11.0ms postprocess per image at shape (1, 3, 384, 640)

0: 384x640 10 persons, 33.5ms
Speed: 2.0ms preprocess, 33.5ms inference, 10.9ms postprocess per image at shape (1, 3, 384, 640)

0: 384x640 11 persons, 31.1ms
Speed: 2.1ms preprocess, 31.1ms inference, 10.8ms postprocess per image at shape (1, 3, 384, 640)

0: 384x640 11 persons, 30.8ms
Speed: 3.0ms preprocess, 30.8ms inference, 10.5ms postprocess per image at shape (1, 3, 384, 640)

0: 384x640 10 persons, 31.2ms
Speed: 1.9ms preprocess, 31.2ms inference, 12.2ms postproces

In [20]:
# Draw annotations on frames
print("Drawing annotations...")
output_video_frames = draw_annotations(video_frames, tracks, team_possession)

# Reuse the source frame rate so playback speed matches the input; fall back
# to 24 when the container does not report a usable value.
source_capture = cv2.VideoCapture(input_video_path)
source_fps = source_capture.get(cv2.CAP_PROP_FPS)
source_capture.release()
output_fps = source_fps if source_fps and source_fps > 0 else 24

# Save output video
output_path = 'output_basketball_analysis.mp4'
save_video(output_video_frames, output_path, fps=output_fps)

print("\n✅ Analysis complete!")
print(f"Output saved to: {output_path}")

# Display the output video
print("\nDisplaying output video...")
display(display_video(output_path))

Drawing annotations...
Video saved to: output_basketball_analysis.mp4

✅ Analysis complete!
Output saved to: output_basketball_analysis.mp4

Displaying output video...
