Setup and imports

In [None]:
# Cell 1
from utils import read_video, save_video, measure_distance, convert_pixel_to_meters
from trackers.player_tracker import PlayerTracker
from trackers.ball_tracker import BallTracker
from trackers.court_line_detector import CourtLineDetector
from trackers.mini_court import MiniCourt
from highlight_exporter import HighlightExporter

# Inference device shared by every model, so it only has to change in one place.
# NOTE(review): switching to e.g. 'cpu' assumes the tracker classes accept that
# device string — TODO confirm.
DEVICE = 'cuda'

player_tracker = PlayerTracker(model_path='weights/yolov5su.pt', device=DEVICE)
ball_tracker = BallTracker(model_path='weights/yolov5s_best_weights.pt', device=DEVICE)
court_detector = CourtLineDetector(model_path='weights/keypoints_model.pth', device=DEVICE)


Read the input video and initialise the mini court

In [None]:
INPUT_VIDEO_PATH = 'input_videos/input_video.mp4'

frames, fps = read_video(INPUT_VIDEO_PATH)
# Fail fast with a clear message. An empty read would otherwise surface as an
# IndexError on frames[0] below. A zero/missing fps would only show up later, as a
# division by zero when shot durations are computed.
if len(frames) == 0:
    raise ValueError(f'No frames read from {INPUT_VIDEO_PATH!r}')
if not fps or fps <= 0:
    raise ValueError(f'Invalid fps {fps!r} read from {INPUT_VIDEO_PATH!r}')

# The mini court is sized from the shape of the first frame.
mini = MiniCourt(frames[0].shape)

Run detections and filter to the two players

In [None]:
# Run each model over the full list of frames (players, then ball, then court
# keypoints). The player detections are unfiltered at this point.
player_dets_all, ball_centers, court_kpts = (
    player_tracker.detect_frames(frames),
    ball_tracker.detect_frames(frames),
    court_detector.predict_frames(frames),
)

# Narrow the player detections down to the two players, using the court
# keypoints to decide which detections to keep.
player_dets = player_tracker.choose_and_filter_players(court_kpts, player_dets_all)

Compute shot and run speeds, annotate the video and export highlights

In [None]:
import pandas as pd
from copy import deepcopy
import constants
import cv2

# Exporter whose metadata is written at the end of this cell.
exp = HighlightExporter(output_dir='output', fps=fps)

# Interpolate missing ball positions, then detect the frames where the ball is hit.
# The result goes into a new name so the detection cell's `ball_centers` is never
# overwritten. Re-running this cell therefore always starts from the raw detections.
ball_centers_interp = ball_tracker.interpolate_ball_positions(ball_centers)
ball_shot_frames = ball_tracker.get_ball_shot_frames(ball_centers_interp)

# Cumulative per-player stats: one snapshot per shot, seeded with an all-zero
# snapshot at frame 0 so that every frame has stats after forward-filling.
player_stats_data = [{
    'frame': 0,
    'p1_shots': 0, 'p1_total_shot_speed': 0.0, 'p1_last_shot_speed': 0.0,
    'p2_shots': 0, 'p2_total_shot_speed': 0.0, 'p2_last_shot_speed': 0.0,
    'p1_total_run_speed': 0.0, 'p1_last_run_speed': 0.0,
    'p2_total_run_speed': 0.0, 'p2_last_run_speed': 0.0,
}]

# Map player and ball positions onto the mini court for all frames.
players_mini_full, ball_mini_full = mini.map_positions(player_dets, ball_centers_interp, court_kpts)

total_frames = len(frames)

# Each consecutive pair of shot frames (start, end) is treated as one shot:
#   - ball speed is measured over the interval;
#   - the player closest to the ball at `start` is taken as the hitter;
#   - the other player's displacement over the interval gives their run speed.
for start, end in zip(ball_shot_frames, ball_shot_frames[1:]):
    if end <= start:
        # Duplicate or out-of-order shot frames would make the duration zero
        # (ZeroDivisionError) or negative, so skip them.
        continue
    duration_s = (end - start) / fps

    # NOTE(review): every conversion below pairs COURT_LENGTH_M (metres) with
    # mini.court_width (pixels). If court_width is the mini court's width rather
    # than its length, the reference dimensions differ and all speeds are
    # mis-scaled — confirm which real-world constant matches mini.court_width.
    b_start = ball_mini_full[start]
    b_end = ball_mini_full[end]
    dist_ball_m = convert_pixel_to_meters(
        measure_distance(b_start, b_end), constants.COURT_LENGTH_M, mini.court_width
    )
    speed_ball = (dist_ball_m / duration_s) * 3.6  # m/s -> km/h

    # Hitter = player id whose mini-court position is nearest the ball at `start`.
    p_pos = players_mini_full[start]
    hitter = min(p_pos, key=lambda pid: measure_distance(p_pos[pid], b_start))
    # The stats keys are p1_*/p2_*, so player ids are assumed to be exactly 1 and 2
    # (presumably ensured by choose_and_filter_players — TODO confirm).
    opp = 1 if hitter == 2 else 2

    dist_run_m = convert_pixel_to_meters(
        measure_distance(players_mini_full[start][opp], players_mini_full[end][opp]),
        constants.COURT_LENGTH_M,
        mini.court_width,
    )
    speed_run = (dist_run_m / duration_s) * 3.6  # m/s -> km/h

    # New cumulative snapshot. A shallow copy is enough because every value
    # in the stats dict is a plain number.
    stats = dict(player_stats_data[-1])
    stats['frame'] = start
    stats[f'p{hitter}_shots'] += 1
    stats[f'p{hitter}_total_shot_speed'] += speed_ball
    stats[f'p{hitter}_last_shot_speed'] = speed_ball
    stats[f'p{opp}_total_run_speed'] += speed_run
    stats[f'p{opp}_last_run_speed'] = speed_run
    player_stats_data.append(stats)

# Expand the per-shot snapshots into one row per video frame. Each frame carries
# the latest snapshot at or before it (forward fill).
# The seed snapshot and a shot detected at frame 0 both have frame == 0, and
# unsorted shot frames can also repeat a frame. Keeping only the last (most
# cumulative) snapshot per frame stops the left merge from duplicating rows,
# which would shift stats_df.iloc[idx] away from frame idx.
stats_snapshots = pd.DataFrame(player_stats_data).drop_duplicates(subset='frame', keep='last')
frames_df = pd.DataFrame({'frame': list(range(total_frames))})
stats_df = frames_df.merge(stats_snapshots, on='frame', how='left').ffill()
assert len(stats_df) == total_frames, 'stats_df must have exactly one row per frame'

# Averages. .replace(0, 1) avoids division by zero before a player's first sample;
# the total is still 0 then, so the average reads 0.
# Shot speed is accumulated per hitter, so it is divided by the player's own shot count.
stats_df['p1_avg_shot_speed'] = stats_df['p1_total_shot_speed'] / stats_df['p1_shots'].replace(0, 1)
stats_df['p2_avg_shot_speed'] = stats_df['p2_total_shot_speed'] / stats_df['p2_shots'].replace(0, 1)
# Run speed is accumulated for the *opponent* of each hitter. A player's
# run-speed sample count is therefore the OTHER player's shot count.
stats_df['p1_avg_run_speed'] = stats_df['p1_total_run_speed'] / stats_df['p2_shots'].replace(0, 1)
stats_df['p2_avg_run_speed'] = stats_df['p2_total_run_speed'] / stats_df['p1_shots'].replace(0, 1)

# Draw the running shot counts on a copy of every input video frame.
# - stats_df has one row per frame index, so row idx belongs to frames[idx].
# - The frames come from `frames`: the HighlightExporter is never handed any
#   frames in this notebook.
# - cv2.putText draws in place, so copying keeps `frames` unmodified and lets
#   this cell be re-run without stacking text.
annotated_frames = []
for idx, frame in enumerate(frames):
    row = stats_df.iloc[idx]
    # Assumes numpy image arrays (frames[0].shape is used to build the mini court).
    canvas = frame.copy()
    cv2.putText(canvas, f"P1 shots: {int(row.p1_shots)}", (10, 40), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255,255,255), 2)
    cv2.putText(canvas, f"P2 shots: {int(row.p2_shots)}", (10, 60), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255,255,255), 2)
    annotated_frames.append(canvas)

# Finalize export: write the annotated video and the exporter's metadata file.
save_video(annotated_frames, 'output/annotated.mp4', fps)
exp.save_metadata('output/highlights.json')