### Goal of this pipeline:
The goal of this pipeline is simple. So far, we have trained two models to detect:
1. Players, referees, balls, and goalkeepers
2. The number on each player's jersey

Now, the idea is to combine these two models to produce a video in which the players are tracked and labeled with their jersey numbers as their IDs.

In [None]:
import os
import cv2
import torch
import numpy as np
import supervision as sv
from ultralytics import YOLO
from pathlib import Path
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

from classifier import JerseyNumberRecognizer
from tracker import PlayerTracker, BallTracker

## Configuration and Model Loading

In [None]:
# ---------------------------------------------------------------------------
# File locations -- edit these before running the notebook.
# ---------------------------------------------------------------------------
YOLO_WEIGHTS = 'path/to/yolov8_weights.pt'   # object detector weights
JERSEY_WEIGHTS = 'path/to/jersey_model.pt'   # jersey-number recognizer weights
INPUT_VIDEO = 'path/to/input.mp4'            # video to annotate
OUTPUT_VIDEO = 'path/to/output.mp4'          # where the annotated video is written

# ---------------------------------------------------------------------------
# Detector class indices.
# NOTE(review): presumably these must match the label order the YOLO model
# was trained with -- confirm against the training dataset config.
# ---------------------------------------------------------------------------
BALL_CLASS_ID = 0
PLAYER_CLASS_ID = 1
GOALKEEPER_CLASS_ID = 2
REFEREE_CLASS_ID = 3

# Detection thresholds forwarded to YOLO as `conf` / `iou` in process_frame.
CONF_THRESHOLD = 0.5
IOU_THRESHOLD = 0.3

# ---------------------------------------------------------------------------
# Models and trackers (shared, stateful objects used by every frame).
# ---------------------------------------------------------------------------
detector = YOLO(YOLO_WEIGHTS)

# num_classes: number of distinct jersey classes in the training dataset.
jersey_recognizer = JerseyNumberRecognizer(model_path=JERSEY_WEIGHTS, num_classes=45)

player_tracker = PlayerTracker()
ball_tracker = BallTracker()

## Helper Functions

In [None]:
def _row_tracker_id(detections, i):
    """Return the tracker id of row ``i`` as a plain ``int``, or ``None`` if untracked."""
    if detections.tracker_id is None:
        return None
    # int() so the id is hashable and usable as a dict key (a numpy array is not).
    return int(detections.tracker_id[i])


def _draw_with_player_tracker(annotated_frame, detections, class_name):
    """Draw every row of ``detections`` via ``player_tracker`` with the given class label."""
    for i in range(len(detections)):
        detection_info = {
            'bbox': detections.xyxy[i],
            'track_id': _row_tracker_id(detections, i),
            'class': class_name,
        }
        annotated_frame = player_tracker.draw_annotations(annotated_frame, detection_info)
    return annotated_frame


def process_frame(frame: np.ndarray, frame_id: int):
    """Detect, track, identify and annotate everything in a single frame.

    Args:
        frame: The video frame as read by OpenCV. It is not modified; annotations
            are drawn on a copy.
        frame_id: Index of this frame in the video. It is forwarded to
            ``player_tracker.update_history``.

    Returns:
        The annotated copy of ``frame``.
    """
    # `.track(persist=True)` runs detection AND keeps the tracker state on
    # `detector` between calls, so boxes get stable ids across frames.
    # A plain `detector(frame)` call leaves `tracker_id` as None, and every
    # player would then be merged into one history entry.
    results = detector.track(
        frame,
        persist=True,
        conf=CONF_THRESHOLD,
        iou=IOU_THRESHOLD,
        verbose=False,
    )[0]

    detections = sv.Detections.from_ultralytics(results)
    detections = detections.with_nms(threshold=0.5)

    # Split detections by class (boolean-mask indexing keeps tracker ids aligned).
    players = detections[detections.class_id == PLAYER_CLASS_ID]
    goalkeepers = detections[detections.class_id == GOALKEEPER_CLASS_ID]
    referees = detections[detections.class_id == REFEREE_CLASS_ID]
    balls = detections[detections.class_id == BALL_CLASS_ID]

    annotated_frame = frame.copy()

    # Players: record a jersey prediction for the track, then draw.
    # Rows are accessed by index because iterating `sv.Detections` yields plain
    # tuples, which have no `.xyxy` / `.tracker_id` attributes.
    for i in range(len(players)):
        bbox = players.xyxy[i]
        track_id = _row_tracker_id(players, i)
        if track_id is not None:
            # Without a track id, the jersey vote cannot be attributed to anyone.
            jersey_info = jersey_recognizer.predict(frame, bbox)
            player_tracker.update_history(track_id, jersey_info, frame_id)

        annotated_frame = player_tracker.draw_annotations(
            annotated_frame,
            {'bbox': bbox, 'track_id': track_id, 'class': 'player'},
        )

    annotated_frame = _draw_with_player_tracker(annotated_frame, goalkeepers, 'goalkeeper')
    annotated_frame = _draw_with_player_tracker(annotated_frame, referees, 'referee')

    for i in range(len(balls)):
        annotated_frame = ball_tracker.draw_annotations(annotated_frame, balls.xyxy[i])

    return annotated_frame

def process_video(input_path: str, output_path: str, start_frame: int = 0, end_frame: int = None, show: bool = True):
    """Run ``process_frame`` over a video and write the annotated result.

    Args:
        input_path: Path of the video to read.
        output_path: Path of the annotated video to write (``mp4v`` codec).
        start_frame: Index of the first frame to process.
        end_frame: Exclusive index at which to stop. ``None`` means "until the
            end". Values past the reported frame count are clamped.
        show: Show a live preview window; press ``q`` to stop early. If no
            display is available, the preview is disabled and processing continues.

    Returns:
        ``player_tracker.track_history`` after processing.

    Raises:
        IOError: If the input video cannot be opened or the output cannot be created.
    """
    cap = cv2.VideoCapture(input_path)
    if not cap.isOpened():
        raise IOError(f"Could not open input video: {input_path}")

    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    # Keep FPS as a float: int() would turn 29.97 into 29 and make the output drift.
    # Some containers report 0 FPS; fall back to 30 so VideoWriter still works.
    fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

    if total_frames > 0:
        end_frame = total_frames if end_frame is None else min(end_frame, total_frames)
    # When the frame count is unknown (<= 0) and end_frame is None, read until
    # the source is exhausted.
    pbar_total = None if end_frame is None else max(end_frame - start_frame, 0)

    out = cv2.VideoWriter(
        output_path,
        cv2.VideoWriter_fourcc(*'mp4v'),
        fps,
        (width, height)
    )
    if not out.isOpened():
        cap.release()
        raise IOError(f"Could not create output video: {output_path}")

    cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
    pbar = tqdm(total=pbar_total, desc='Processing video')

    frame_idx = start_frame
    try:
        while end_frame is None or frame_idx < end_frame:
            ret, frame = cap.read()
            if not ret:
                break

            try:
                annotated_frame = process_frame(frame, frame_idx)
            except Exception as e:
                # Best-effort per frame: report it and write the raw frame so the
                # output keeps the same length/timing as the input.
                print(f"Error processing frame {frame_idx}: {e}")
                annotated_frame = frame

            out.write(annotated_frame)

            # Periodically show tracking statistics in the progress bar.
            if frame_idx % 30 == 0:
                history = player_tracker.track_history
                pbar.set_postfix({
                    'players': len(history),
                    'identified': sum(
                        1 for p in history.values()
                        if p['most_common_number'] != -1
                    )
                })

            if show:
                try:
                    cv2.imshow('Tracking', annotated_frame)
                    key = cv2.waitKey(1) & 0xFF
                except cv2.error as e:
                    # Headless / notebook environments have no GUI backend.
                    print(f"Preview unavailable, disabling display: {e}")
                    show = False
                else:
                    if key == ord('q'):
                        break

            frame_idx += 1
            pbar.update(1)
    finally:
        # Always release OS resources, even on error or KeyboardInterrupt.
        cap.release()
        out.release()
        if show:
            cv2.destroyAllWindows()
        pbar.close()

    return player_tracker.track_history

## Run Tracking

In [None]:
# Run the full pipeline from the first frame to the end of the clip.
track_history = process_video(INPUT_VIDEO, OUTPUT_VIDEO, start_frame=0)

# --- Summary ----------------------------------------------------------------
print("\nTracking Statistics:")
print(f"Total players tracked: {len(track_history)}")

# A track counts as "identified" when its most common jersey number is not -1.
identified_players = []
for track_id, data in track_history.items():
    if data['most_common_number'] != -1:
        identified_players.append((track_id, data))

print(f"Players with identified numbers: {len(identified_players)}")
print("\nPlayer Details:")
for track_id, data in identified_players:
    print(f"Player #{track_id}:")
    print(f"  Jersey Number: {data['most_common_number']}")

    # Map the team index to a display label (0 -> A, 1 -> B, anything else -> Unknown).
    team = data['team']
    if team == 0:
        team_label = 'A'
    elif team == 1:
        team_label = 'B'
    else:
        team_label = 'Unknown'
    print(f"  Team: {team_label}")

    print(f"  Frames Visible: {data['frames_visible']}")

## Visualization Helpers (Optional)

In [None]:
def plot_player_trajectories(track_history, frame_shape):
    """Plot each tracked player's movement path in image coordinates.

    Args:
        track_history: Mapping of track id -> per-track dict. Reads the
            ``'positions'``, ``'team'`` and ``'most_common_number'`` keys.
            ``'positions'`` is assumed to be a sequence of (x, y) points;
            TODO confirm against ``PlayerTracker``.
        frame_shape: Frame shape ``(height, width[, channels])``, used to set
            the axis limits.

    Returns:
        The matplotlib ``Figure``, so callers can save it (e.g. ``fig.savefig(...)``).
    """
    fig, ax = plt.subplots(figsize=(12, 8))

    for track_id, data in track_history.items():
        # np.asarray + shape checks work for both lists and ndarrays.
        # `if ndarray:` would raise "truth value of an array is ambiguous".
        positions = np.asarray(data['positions'], dtype=float)
        if positions.ndim != 2 or positions.shape[0] == 0 or positions.shape[1] < 2:
            continue

        color = 'r' if data['team'] == 0 else 'b' if data['team'] == 1 else 'gray'

        # Full trajectory.
        ax.plot(positions[:, 0], positions[:, 1], color=color, alpha=0.5)

        # Final position, labelled with track id and jersey number.
        ax.scatter(
            positions[-1, 0],
            positions[-1, 1],
            color=color,
            s=100,
            label=f"Player {track_id} (#{data['most_common_number']})"
        )

    ax.set_xlim(0, frame_shape[1])
    ax.set_ylim(frame_shape[0], 0)  # Invert Y axis to match image coordinates
    ax.set_title('Player Trajectories')

    # Only add a legend when something was plotted; otherwise matplotlib warns.
    handles, _ = ax.get_legend_handles_labels()
    if handles:
        ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')

    fig.tight_layout()
    plt.show()
    return fig