In [43]:
# %pip (rather than !pip) installs into the environment of the running kernel.
%pip install -q ultralytics supervision scikit-learn

In [44]:
import os
import time
import numpy as np
from ultralytics import YOLO
import supervision as sv
from scipy.optimize import linear_sum_assignment
from sklearn.cluster import KMeans
import cv2
from IPython.display import HTML

In [45]:
# Input clip that is analysed and annotated.
# NOTE(review): the /content/drive prefix presumably requires Google Drive to be
# mounted before these cells run; no mount cell is visible here - confirm.
VIDEO_PATH = '/content/drive/MyDrive/stealth/15sec_input_720p.mp4'

# Weights file handed to YOLO() to build the detector.
MODEL_PATH = '/content/drive/MyDrive/stealth/best.pt'

# Destination of the annotated video produced by the OpenCV VideoWriter.
OUTPUT_PATH = '/content/output.mp4'

In [46]:
# Bounding-box colours used when drawing each entity type on the output video.
# These are display colours only; the jersey colours used for classification
# are estimated from the video itself by identify_entity_colors().
TEAM_1_COLOR = sv.Color.BLUE
TEAM_2_COLOR = sv.Color.GREEN
GK_COLOR = sv.Color.YELLOW
REFEREE_COLOR = sv.Color(r=255, g=0, b=255)  # magenta

In [47]:
class PlayerProfile:
    """Appearance record for one uniquely numbered player.

    Attributes:
        player_id: Permanent player number shown in the label.
        team_id: Team the player was assigned to.
        hist: Colour histogram stored when the profile was created; used to
            recognise the player again after the tracker loses them.
        tracker_id: Tracker id currently bound to this player, or ``None``
            while no active track is bound.
        frames_lost: Number of consecutive frames without a bound track.
    """

    def __init__(self, player_id, team_id, hist):
        # Identity and appearance supplied by the caller.
        self.player_id, self.team_id, self.hist = player_id, team_id, hist
        # A fresh profile is not bound to any track and has never been lost.
        self.tracker_id, self.frames_lost = None, 0

In [48]:
def get_shoulder_to_knee_color(frame, bbox):
    """Return the mean colour of the torso + shorts band of a bounding box.

    The band spans from 15% to 85% of the box height (roughly shoulders to
    knees), across the full box width.

    Args:
        frame: Image array indexed as ``frame[row, column, ...]``.
        bbox: ``(x1, y1, x2, y2)`` box corners; values are truncated to int.

    Returns:
        Per-channel mean of the band as a NumPy array, or ``[0, 0, 0]`` when
        the band does not overlap the frame.
    """
    frame_h, frame_w = frame.shape[:2]
    x1, y1, x2, y2 = map(int, bbox)

    # ROI from ~shoulders (15%) to ~knees (85%)
    box_h = y2 - y1
    band_top = y1 + int(box_h * 0.15)
    band_bottom = y1 + int(box_h * 0.85)

    # Clamp to the frame. A negative coordinate would otherwise be read by
    # NumPy as "count from the end" and sample the wrong pixels (or nothing).
    x1, x2 = (min(max(v, 0), frame_w) for v in (x1, x2))
    band_top, band_bottom = (min(max(v, 0), frame_h) for v in (band_top, band_bottom))

    roi = frame[band_top:band_bottom, x1:x2]
    if roi.size == 0:
        return np.array([0, 0, 0])

    return np.mean(roi, axis=(0, 1))

In [49]:
def calculate_full_body_histogram(frame, bbox):
    """Return a flattened 8x8x8 HSV colour histogram of a bounding box.

    Args:
        frame: BGR image array (converted with ``cv2.COLOR_BGR2HSV``).
        bbox: ``(x1, y1, x2, y2)`` box corners; values are truncated to int.

    Returns:
        A float32 vector of length 512, min-max normalised to [0, 1]. An
        all-zero vector is returned when the box does not overlap the frame.
    """
    frame_h, frame_w = frame.shape[:2]
    x1, y1, x2, y2 = map(int, bbox)

    # Clamp to the frame so negative coordinates cannot wrap around when slicing.
    x1, x2 = (min(max(v, 0), frame_w) for v in (x1, x2))
    y1, y2 = (min(max(v, 0), frame_h) for v in (y1, y2))

    roi = frame[y1:y2, x1:x2]
    if roi.size == 0:
        # float32 like cv2.calcHist output, so the fallback can be passed to
        # cv2.compareHist together with real histograms.
        return np.zeros(512, dtype=np.float32)

    hsv_roi = cv2.cvtColor(roi, cv2.COLOR_BGR2HSV)
    # 8 bins per channel; OpenCV hue spans [0, 180), saturation/value [0, 256).
    hist = cv2.calcHist([hsv_roi], [0, 1, 2], None, [8, 8, 8], [0, 180, 0, 256, 0, 256])
    cv2.normalize(hist, hist, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX)
    return hist.flatten()

In [50]:
def setup_video_writer(cap, output_path):
    """Create a VideoWriter matching the frame rate and size of *cap*.

    Args:
        cap: Opened ``cv2.VideoCapture`` whose properties are copied.
        output_path: Path of the video file to write (``mp4v`` codec).

    Returns:
        A ``cv2.VideoWriter`` configured with the source FPS and frame size.
    """
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')

    # Keep the frame rate as a float: truncating e.g. 29.97 to 29 would make
    # the output play at a different speed than the source.
    fps = cap.get(cv2.CAP_PROP_FPS)
    if not fps > 0:
        # Source did not report a usable rate (0, negative or NaN).
        fps = 30.0

    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

    return cv2.VideoWriter(output_path, fourcc, fps, (width, height))

In [51]:
def identify_entity_colors(video_path, model, frame_limit=100,
                           player_class_id=2, goalkeeper_class_id=1, referee_class_id=3):
    """Estimate reference kit colours from the first frames of a video.

    Runs the detector on up to ``frame_limit`` frames, collects the
    shoulder-to-knee mean colour of every detection, then clusters the player
    colours into two teams and averages the goalkeeper and referee colours.

    Args:
        video_path: Path of the video to sample.
        model: Detector called as ``model(frame, verbose=False)``.
        frame_limit: Maximum number of frames to analyse.
        player_class_id: Detector class id of outfield players.
        goalkeeper_class_id: Detector class id of goalkeepers.
        referee_class_id: Detector class id of referees.

    Returns:
        ``(team1_color, team2_color, goalkeeper_color, referee_color)``, each a
        3-element array in the channel order of the decoded frames. Fixed
        defaults are used for any category without enough detections.
    """
    cap = cv2.VideoCapture(video_path)
    player_colors, gk_colors, referee_colors = [], [], []
    samples_by_class = (
        (player_class_id, player_colors),
        (goalkeeper_class_id, gk_colors),
        (referee_class_id, referee_colors),
    )
    frames_processed = 0

    print("Starting color identification for all entities...")

    try:
        while cap.isOpened() and frames_processed < frame_limit:
            ret, frame = cap.read()
            if not ret:
                break

            results = model(frame, verbose=False)[0]
            detections = sv.Detections.from_ultralytics(results)

            # Collect one colour sample per detection, grouped by class.
            for class_id, samples in samples_by_class:
                for bbox in detections[detections.class_id == class_id].xyxy:
                    samples.append(get_shoulder_to_knee_color(frame, bbox))

            frames_processed += 1
    finally:
        # Release the capture even if detection fails part-way through.
        cap.release()

    # --- Team Color Clustering ---
    if len(player_colors) < 2:
        print("Warning: Not enough players detected. Using default team colors.")
        team1_color, team2_color = np.array([255, 255, 255]), np.array([200, 0, 0])
    else:
        kmeans_teams = KMeans(n_clusters=2, random_state=0, n_init='auto').fit(player_colors)
        team1_color, team2_color = kmeans_teams.cluster_centers_
        print("Team colors identified.")

    # --- Goalkeeper & Referee Color Averaging ---
    if not gk_colors:
        print("Warning: No goalkeepers detected for color analysis. Using a default bright color.")
        goalkeeper_color = np.array([0, 255, 255])
    else:
        goalkeeper_color = np.mean(gk_colors, axis=0)
        print("Goalkeeper color identified.")

    if not referee_colors:
        print("Warning: No referees detected for color analysis. Using a default dark color.")
        referee_color = np.array([50, 50, 50])
    else:
        referee_color = np.mean(referee_colors, axis=0)
        print("Referee color identified.")

    return team1_color, team2_color, goalkeeper_color, referee_color


In [52]:
def assign_new_player_id(team_id, is_goalkeeper, assigned_ids_team1, assigned_ids_team2):
    """Reserve the next free player number and return it with its team.

    Numbering scheme: team 1 outfield players use 1-10 and its goalkeeper 11;
    team 2 outfield players use 12-21 and its goalkeeper 22.

    Goalkeepers take the slot of the guessed team first and fall back to the
    other team's slot when it is already taken, so at most two goalkeepers
    are ever numbered.

    Args:
        team_id: Guessed team; ``1`` selects team 1, any other value team 2.
        is_goalkeeper: Whether the entity was classified as a goalkeeper.
        assigned_ids_team1: Set of numbers already used by team 1 (mutated).
        assigned_ids_team2: Set of numbers already used by team 2 (mutated).

    Returns:
        ``(player_id, team_id)`` of the reserved number, or ``(None, None)``
        when no suitable number is left.
    """
    # team -> (set of used ids, goalkeeper id, outfield id range)
    pools = {
        1: (assigned_ids_team1, 11, range(1, 11)),
        2: (assigned_ids_team2, 22, range(12, 22)),
    }
    preferred_team = 1 if team_id == 1 else 2

    if is_goalkeeper:
        other_team = 2 if preferred_team == 1 else 1
        for team in (preferred_team, other_team):
            used_ids, goalkeeper_id, _ = pools[team]
            if goalkeeper_id not in used_ids:
                used_ids.add(goalkeeper_id)
                return goalkeeper_id, team
        return None, None  # Both goalkeeper slots are taken

    used_ids, _, outfield_ids = pools[preferred_team]
    for player_id in outfield_ids:
        if player_id not in used_ids:
            used_ids.add(player_id)
            return player_id, preferred_team

    return None, None  # No available IDs

In [53]:
# Load the detector from MODEL_PATH. The same instance is used for the colour
# identification pass and for the per-frame detection in the main loop.
model = YOLO(MODEL_PATH)

In [54]:
# Sample the opening frames of the clip (frame_limit defaults to 100) to get
# one reference colour each for team 1, team 2, goalkeepers and referees.
# The main loop compares every new track's kit colour against these.
team1_jersey, team2_jersey, gk_jersey, referee_jersey = identify_entity_colors(VIDEO_PATH, model)

Starting color identification for all entities...
Team colors identified.
Referee color identified.


In [55]:
cap = cv2.VideoCapture(VIDEO_PATH)
video_writer = setup_video_writer(cap, OUTPUT_PATH)

# Configure the tracker with the clip's real frame rate instead of assuming
# 30 fps; fall back to 30 only when the container does not report a rate.
source_fps = cap.get(cv2.CAP_PROP_FPS)
tracker = sv.ByteTrack(frame_rate=int(round(source_fps)) if source_fps > 0 else 30)

In [56]:
# Annotators
# One box annotator per entity type, each drawing in that type's fixed colour;
# the main loop passes every annotator only the detections of its own type.
team1_annotator = sv.BoxAnnotator(color=TEAM_1_COLOR)
team2_annotator = sv.BoxAnnotator(color=TEAM_2_COLOR)
gk_annotator = sv.BoxAnnotator(color=GK_COLOR)
referee_annotator = sv.BoxAnnotator(color=REFEREE_COLOR)
# A single label annotator is shared by all entity types (white text, placed
# at the top centre of each box).
label_annotator = sv.LabelAnnotator(text_position=sv.Position.TOP_CENTER, text_color=sv.Color.WHITE, text_scale=0.6)

In [57]:
# --- State Management ---
# These containers are mutated by the main processing loop and persist across
# frames; they start empty.
player_profiles = {}  # player_id -> PlayerProfile
tracker_to_player_id = {} # tracker_id -> player_id
entity_type_map = {}    # tracker_id -> 'team1', 'team2', 'goalkeeper', or 'referee'
# Player numbers already handed out per team (filled by assign_new_player_id:
# team 1 uses 1-11, team 2 uses 12-22).
assigned_team1_ids = set()
assigned_team2_ids = set()

In [58]:
# Main processing loop: detect -> track -> re-identify / classify -> annotate -> write.
#
# Consumes `cap` frame by frame and mutates the state containers
# (player_profiles, tracker_to_player_id, entity_type_map, assigned_team*_ids).

# Detector class ids kept for tracking (goalkeeper, player, referee).
TRACKED_CLASS_IDS = [1, 2, 3]
# Largest (1 - histogram correlation) at which an unassigned track is accepted
# as the reappearance of a lost player.
REID_MAX_COST = 0.4

total_frames = 0
start_time = time.time()

try:
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        total_frames += 1

        results = model(frame, verbose=False)[0]
        detections = sv.Detections.from_ultralytics(results)
        detections = detections[np.isin(detections.class_id, TRACKED_CLASS_IDS)]
        tracked_detections = tracker.update_with_detections(detections)
        annotated_frame = frame.copy()

        active_tracker_ids = set(tracked_detections.tracker_id)

        # Update 'frames_lost' for all known player profiles
        for profile in player_profiles.values():
            if profile.tracker_id not in active_tracker_ids:
                profile.frames_lost += 1
                profile.tracker_id = None

        # --- Entity Classification, Re-Identification and ID Assignment ---
        unassigned_trackers = []
        for i in range(len(tracked_detections)):
            tracker_id = tracked_detections.tracker_id[i]

            # Only process trackers that haven't been assigned a permanent ID
            if tracker_id not in entity_type_map:
                unassigned_trackers.append(tracked_detections[i])
            # It's a known player, update their state
            elif tracker_id in tracker_to_player_id:
                known_profile = player_profiles[tracker_to_player_id[tracker_id]]
                known_profile.frames_lost = 0
                # Re-bind the profile to this track: the id may have been
                # cleared above if the tracker dropped and later resumed it.
                known_profile.tracker_id = tracker_id

        # Re-identification logic for unassigned trackers against lost players
        if unassigned_trackers:
            lost_profiles = [p for p in player_profiles.values() if p.frames_lost > 0]

            if lost_profiles:
                unassigned_hists = [calculate_full_body_histogram(frame, det.xyxy[0]) for det in unassigned_trackers]
                lost_hists = [p.hist for p in lost_profiles]
                # cost = 1 - correlation, so identical histograms cost 0.
                cost_matrix = np.array([
                    [1 - cv2.compareHist(uh, lh, cv2.HISTCMP_CORREL) for lh in lost_hists]
                    for uh in unassigned_hists
                ])
                # Optimal one-to-one matching between new tracks and lost players.
                row_ind, col_ind = linear_sum_assignment(cost_matrix)

                reidentified_indices = set()
                for r, c in zip(row_ind, col_ind):
                    if cost_matrix[r, c] < REID_MAX_COST:
                        matched_profile = lost_profiles[c]
                        tracker_id = unassigned_trackers[r].tracker_id[0]

                        tracker_to_player_id[tracker_id] = matched_profile.player_id
                        entity_type_map[tracker_id] = 'goalkeeper' if matched_profile.player_id in [11, 22] else f"team{matched_profile.team_id}"
                        matched_profile.tracker_id = tracker_id
                        matched_profile.frames_lost = 0
                        reidentified_indices.add(r)

                unassigned_trackers = [ut for i, ut in enumerate(unassigned_trackers) if i not in reidentified_indices]

        # Classify and assign IDs to truly new entities
        for new_detection in unassigned_trackers:
            tracker_id = new_detection.tracker_id[0]
            bbox = new_detection.xyxy[0]

            # Classify by shoulder-to-knee color profile: nearest reference colour wins
            kit_color = get_shoulder_to_knee_color(frame, bbox)
            dist_t1 = np.linalg.norm(kit_color - team1_jersey)
            dist_t2 = np.linalg.norm(kit_color - team2_jersey)
            dist_gk = np.linalg.norm(kit_color - gk_jersey)
            dist_ref = np.linalg.norm(kit_color - referee_jersey)
            min_dist = min(dist_t1, dist_t2, dist_gk, dist_ref)

            if min_dist == dist_ref:
                entity_type_map[tracker_id] = 'referee'
                continue

            is_goalkeeper = (min_dist == dist_gk)
            # For goalkeepers the team guess is the nearer of the two team colours.
            if min_dist == dist_t1 or (is_goalkeeper and dist_t1 < dist_t2):
                team_id_guess = 1
            else:
                team_id_guess = 2

            player_id, team_id = assign_new_player_id(team_id_guess, is_goalkeeper, assigned_team1_ids, assigned_team2_ids)

            # With no free number the track stays unclassified and is
            # reconsidered on the next frame.
            if player_id is not None:
                if is_goalkeeper:
                    entity_type_map[tracker_id] = 'goalkeeper'
                else:
                    entity_type_map[tracker_id] = f"team{team_id}"

                hist = calculate_full_body_histogram(frame, bbox)
                new_profile = PlayerProfile(player_id, team_id, hist)
                new_profile.tracker_id = tracker_id
                player_profiles[player_id] = new_profile
                tracker_to_player_id[tracker_id] = player_id

        # --- Annotation Logic ---
        entity_types = [entity_type_map.get(tid) for tid in tracked_detections.tracker_id]
        for entity_type, box_annotator in (
            ('team1', team1_annotator),
            ('team2', team2_annotator),
            ('goalkeeper', gk_annotator),
            ('referee', referee_annotator),
        ):
            type_mask = np.array([t == entity_type for t in entity_types], dtype=bool)
            annotated_frame = box_annotator.annotate(scene=annotated_frame, detections=tracked_detections[type_mask])

        labels = []
        for tracker_id in tracked_detections.tracker_id:
            if entity_type_map.get(tracker_id) == 'referee':
                labels.append("Referee")
            else:
                labels.append(f"P{tracker_to_player_id.get(tracker_id, '?')}")

        annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=tracked_detections, labels=labels)

        video_writer.write(annotated_frame)
finally:
    # Always release, so the output file is finalised and the capture closed
    # even if a frame fails to process.
    cap.release()
    video_writer.release()

end_time = time.time()

In [59]:
# --- Metrics ---
# Wall-clock duration of the processing loop and the resulting throughput.
processing_time = end_time - start_time
if processing_time > 0:
    avg_fps = total_frames / processing_time
else:
    # Guard against a zero-length interval.
    avg_fps = 0

print(
    "\n--- Processing Metrics ---",
    f"Total frames processed: {total_frames}",
    f"Total processing time: {processing_time:.2f} seconds",
    f"Average FPS: {avg_fps:.2f}",
    "--------------------------\n",
    sep="\n",
)


--- Processing Metrics ---
Total frames processed: 375
Total processing time: 24.17 seconds
Average FPS: 15.51
--------------------------



In [60]:
def display_video(path):
    """Return an HTML element that plays the MP4 at *path* inline.

    The whole file is read and embedded in the page as a base64 data URL.

    Args:
        path: Path of the MP4 file to embed.

    Returns:
        An ``IPython.display.HTML`` object containing a ``<video>`` tag.
    """
    # Imported locally so the function does not depend on an import that
    # lives in a different cell having been executed first.
    from base64 import b64encode

    # Context manager closes the file handle as soon as the bytes are read.
    with open(path, 'rb') as video_file:
        mp4 = video_file.read()

    data_url = "data:video/mp4;base64," + b64encode(mp4).decode()
    return HTML("""
    <video width=800 controls>
          <source src="%s" type="video/mp4">
    </video>
    """ % data_url)

In [61]:
# Re-encode the annotated video at {OUTPUT_PATH} to an H.264 MP4 (output_display.mp4) for inline display
!ffmpeg -y -i {OUTPUT_PATH} -vcodec libx264 output_display.mp4 -hide_banner -loglevel error

In [62]:
from base64 import b64encode
# Show the H.264 re-encoded video inline in the notebook output.
display(display_video('output_display.mp4'))