In [37]:
import torch
import torch.nn as nn
import supervision as sv
from ultralytics import YOLO
from tqdm import tqdm
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
import os
import argparse
import cv2
import time
# default dict
from collections import defaultdict

In [38]:
import torch
import torch.nn as nn
import torchvision.models as models
from PIL import Image
import torchvision.transforms as transforms
import numpy as np
from pathlib import Path
import cv2
import json
import torch.nn.functional as F

In [39]:

class JerseyNumberNet(nn.Module):
    """ResNet18-based classifier for jersey-number crops.

    The ResNet18 convolutional stages are reused as a feature extractor. A
    1x1-conv + sigmoid spatial gate re-weights the final (512-channel) feature
    map. After global average pooling, an MLP head
    (512 -> 512 -> ``num_classes``) produces the logits.

    Args:
        num_classes: Number of output logits. The checkpoint loaded in this
            notebook uses 55. The mapping from class index to an actual
            jersey number is not defined here -- TODO confirm against the
            training code.
        pretrained: Initialise the backbone with torchvision's ImageNet
            weights (default True, the original behaviour). Pass False when a
            full state dict is loaded right afterwards, to avoid downloading
            weights that are immediately overwritten.
    """

    def __init__(self, num_classes=55, pretrained=True):
        super(JerseyNumberNet, self).__init__()

        # ResNet18 backbone
        self.backbone = models.resnet18(pretrained=pretrained)

        # NOTE(review): this layer has exactly the same configuration as
        # ResNet18's stock conv1, so it does not change the accepted input
        # size. Its only effect is to replace the pretrained conv1 weights
        # with a fresh random init. It is kept so the module matches the
        # architecture the saved checkpoint was trained with.
        self.backbone.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)

        # Replace the ImageNet classifier with a dropout-regularised MLP head
        num_features = self.backbone.fc.in_features
        self.backbone.fc = nn.Sequential(
            nn.Linear(num_features, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes)
        )

        # Spatial attention: one gate value in [0, 1] per feature-map location
        self.attention = nn.Sequential(
            nn.Conv2d(512, 1, kernel_size=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Return logits of shape (batch, num_classes).

        Assumes ``x`` is a float tensor of shape (batch, 3, H, W), normalised
        the same way as during training -- TODO confirm.
        """
        # Backbone stem
        x = self.backbone.conv1(x)
        x = self.backbone.bn1(x)
        x = self.backbone.relu(x)
        x = self.backbone.maxpool(x)

        # Residual stages -> (batch, 512, h, w) feature map
        x = self.backbone.layer1(x)
        x = self.backbone.layer2(x)
        x = self.backbone.layer3(x)
        features = self.backbone.layer4(x)

        # Gate every spatial location by its attention weight (broadcast over channels)
        att = self.attention(features)
        features = features * att

        # Global average pooling and classification
        x = F.adaptive_avg_pool2d(features, (1, 1))
        x = torch.flatten(x, 1)
        x = self.backbone.fc(x)

        return x

In [40]:

class JerseyTracker:
    """Detect and track players across video frames and attach a jersey-number guess to each track.

    Per frame (see ``process_frame``):
      * a YOLO model detects ball / player / goalkeeper / referee boxes,
      * supervision's ByteTrack assigns persistent tracker ids,
      * a ``JerseyNumberNet`` classifies each player crop, and the
        highest-confidence prediction seen so far is remembered per tracker id.
    """

    def __init__(self, jersey_model_path, player_model_path, device='cuda'):
        """
        Args:
            jersey_model_path: Path to a ``JerseyNumberNet`` state dict (55 classes).
            player_model_path: Weights path handed to ``ultralytics.YOLO``.
            device: Torch device for the jersey classifier. Defaults to ``'cuda'``;
                pass ``'cpu'`` on machines without a GPU.
        """
        self.device = device

        # Jersey-number classifier, switched to inference mode (dropout off, BN frozen)
        self.jersey_model = self.load_jersey_model(jersey_model_path)
        self.jersey_model.eval()

        # Ball / player / goalkeeper / referee detector
        self.player_model = YOLO(player_model_path)

        # Multi-object tracker that provides tracker ids across frames
        self.tracker = sv.ByteTrack()
        self.tracker.reset()

        # tracker_id -> {'number': best predicted class so far,
        #                'confidence': its softmax confidence,
        #                'count': number of accepted predictions}
        self.jersey_tracker = defaultdict(lambda: {'number': None, 'confidence': 0, 'count': 0})

        # Class ids emitted by the detection model
        self.BALL_ID = 0
        self.PLAYER_ID = 1
        self.GOALKEEPER_ID = 2
        self.REFEREE_ID = 3

        # Annotation colour per class id (no entry for the ball)
        self.colors = {
            self.PLAYER_ID: '#00BFFF',      # Light blue for players
            self.GOALKEEPER_ID: '#FF1493',   # Pink for goalkeepers
            self.REFEREE_ID: '#FFD700'       # Gold for referees
        }

        # Classifier preprocessing: 128x128 resize + ImageNet mean/std normalisation
        self.transform = transforms.Compose([
            transforms.Resize((128, 128)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def load_jersey_model(self, model_path):
        """Build a 55-class ``JerseyNumberNet``, load ``model_path`` into it and move it to ``self.device``."""
        model = JerseyNumberNet(num_classes=55)  # must match the class count used in training
        # SECURITY NOTE(review): torch.load unpickles the file, so only load trusted
        # checkpoints. Since this is a plain state dict, weights_only=True
        # (torch>=1.13) would be a safer choice.
        model.load_state_dict(torch.load(model_path, map_location=self.device))
        model = model.to(self.device)
        return model

    @staticmethod
    def _bgr_to_rgb(crop):
        """Reverse the channel order of an HxWx3 image (BGR -> RGB) as a C-contiguous copy.

        Arrays that are not 3-channel are returned unchanged.
        """
        if crop.ndim == 3 and crop.shape[2] == 3:
            return np.ascontiguousarray(crop[:, :, ::-1])
        return crop

    def predict_jersey_number(self, crop, bgr=True):
        """Classify a single player crop.

        Args:
            crop: Image array cropped from a video frame. Frames from
                ``sv.get_video_frames_generator`` are decoded by OpenCV and are
                BGR, while the PIL/ImageNet preprocessing expects RGB. By
                default the crop is therefore converted to RGB first.
            bgr: Set to False if ``crop`` is already RGB.

        Returns:
            ``(class_index, confidence)`` from the softmax argmax. The class
            index is presumably mapped to a jersey number by the training label
            encoding -- TODO confirm. Returns ``(None, 0)`` for an empty crop
            or if inference raises.
        """
        # Degenerate boxes can yield empty crops; PIL cannot build an image from them
        if crop is None or np.asarray(crop).size == 0:
            return None, 0
        try:
            if bgr:
                crop = self._bgr_to_rgb(np.asarray(crop))
            image = Image.fromarray(crop).convert('RGB')
            image = self.transform(image).unsqueeze(0).to(self.device)

            with torch.no_grad():
                output = self.jersey_model(image)
                prob = torch.softmax(output, dim=1)
                confidence, predicted = torch.max(prob, 1)

            return predicted.item(), confidence.item()
        except Exception as e:
            # Best-effort: a failed crop must not abort the whole video
            print(f"Error in jersey prediction: {str(e)}")
            return None, 0

    def update_jersey_tracking(self, tracker_id, number, confidence):
        """Record a prediction for ``tracker_id``, keeping only the highest-confidence number.

        ``count`` is incremented on every call, whether or not the stored number changes.
        """
        entry = self.jersey_tracker[tracker_id]
        if confidence > entry['confidence']:
            entry['number'] = number
            entry['confidence'] = confidence
        entry['count'] += 1

    def resolve_goalkeepers_team_id(self, players, goalkeepers):
        """Assign each goalkeeper to team 0 or 1 by distance to each team's mean player position.

        NOTE(review): this expects ``players.class_id`` to already contain team
        ids (0/1), not the detector class ids used in ``process_frame``. There,
        every player has class_id ``PLAYER_ID`` (1), so the team-0 centroid
        would be the mean of an empty array (NaN). Because NaN comparisons are
        False, the code then falls back to team 1. It is not called anywhere in
        this class.

        Returns:
            ``np.ndarray`` with one team id per goalkeeper, in goalkeeper order.
        """
        goalkeepers_xy = goalkeepers.get_anchors_coordinates(sv.Position.BOTTOM_CENTER)
        players_xy = players.get_anchors_coordinates(sv.Position.BOTTOM_CENTER)

        team_0_centroid = players_xy[players.class_id == 0].mean(axis=0)
        team_1_centroid = players_xy[players.class_id == 1].mean(axis=0)

        goalkeepers_team_id = []
        for goalkeeper_xy in goalkeepers_xy:
            dist_0 = np.linalg.norm(goalkeeper_xy - team_0_centroid)
            dist_1 = np.linalg.norm(goalkeeper_xy - team_1_centroid)
            goalkeepers_team_id.append(0 if dist_0 < dist_1 else 1)

        return np.array(goalkeepers_team_id)

    def process_frame(self, frame, confidence=0.3, jersey_confidence=0.3):
        """Detect, track and label everything in one frame.

        Args:
            frame: Image array passed to the YOLO model and used for player crops.
            confidence: Detection confidence threshold forwarded to YOLO.
            jersey_confidence: Minimum classifier confidence for a jersey
                prediction to update a track's stored number. The default 0.3
                is the previously hard-coded value.

        Returns:
            ``(ball_detections, all_detections, all_labels)``.
            ``all_detections`` holds players, then goalkeepers, then referees,
            and ``all_labels`` has one label per detection in that same order.
        """
        result = self.player_model(frame, conf=confidence, verbose=False)[0]
        detections = sv.Detections.from_ultralytics(result)

        # Ball boxes are not tracked; they are only padded by 10 px
        ball_detections = detections[detections.class_id == self.BALL_ID]
        ball_detections.xyxy = sv.pad_boxes(xyxy=ball_detections.xyxy, px=10)

        # Class-agnostic NMS suppresses overlapping boxes even when their classes differ
        all_detections = detections[detections.class_id != self.BALL_ID]
        all_detections = all_detections.with_nms(threshold=0.5, class_agnostic=True)
        all_detections = self.tracker.update_with_detections(all_detections)

        players_detections = all_detections[all_detections.class_id == self.PLAYER_ID]
        goalkeeper_detections = all_detections[all_detections.class_id == self.GOALKEEPER_ID]
        referees_detections = all_detections[all_detections.class_id == self.REFEREE_ID]

        # Jersey numbers are only estimated for outfield players
        jersey_labels = []
        players_crops = [sv.crop_image(frame, xyxy) for xyxy in players_detections.xyxy]

        for crop, tracker_id in zip(players_crops, players_detections.tracker_id):
            number, conf = self.predict_jersey_number(crop)
            if number is not None and conf > jersey_confidence:
                self.update_jersey_tracking(tracker_id, number, conf)

            tracked_number = self.jersey_tracker[tracker_id]['number']
            if tracked_number is not None:
                jersey_labels.append(f"#{tracker_id} ({tracked_number})")
            else:
                jersey_labels.append(f"#{tracker_id}")

        goalkeeper_labels = [f"#{id} (GK)" for id in goalkeeper_detections.tracker_id]
        referee_labels = [f"#{id} (REF)" for id in referees_detections.tracker_id]

        # Merge order must match label order below
        all_detections = sv.Detections.merge([
            players_detections,
            goalkeeper_detections,
            referees_detections
        ])

        all_labels = jersey_labels + goalkeeper_labels + referee_labels

        return ball_detections, all_detections, all_labels

In [41]:

def process_video(source_video_path, target_video_path, jersey_model_path, player_model_path, confidence=0.3):
    """Process entire video with jersey tracking and player detection"""
    # Initialize video components
    video_info = sv.VideoInfo.from_video_path(source_video_path)
    video_sink = sv.VideoSink(target_video_path, video_info)
    frame_generator = sv.get_video_frames_generator(source_video_path)
    
    # Initialize tracker
    jersey_tracker = JerseyTracker(jersey_model_path, player_model_path)
    
    # Initialize annotators
    colors = [jersey_tracker.colors[jersey_tracker.PLAYER_ID],
              jersey_tracker.colors[jersey_tracker.GOALKEEPER_ID],
              jersey_tracker.colors[jersey_tracker.REFEREE_ID]]
    
    ellipse_annotator = sv.EllipseAnnotator(
        color=sv.ColorPalette.from_hex(colors),
        thickness=2
    )
    label_annotator = sv.LabelAnnotator(
        color=sv.ColorPalette.from_hex(colors),
        text_color=sv.Color.from_hex('#000000'),
        text_position=sv.Position.BOTTOM_CENTER,
    )
    triangle_annotator = sv.TriangleAnnotator(
        color=sv.Color.from_hex('#FFD700'),
        base=20, height=17
    )

    # Process video
    print("Processing video...")
    with video_sink:
        for frame in tqdm(frame_generator, total=video_info.total_frames):
            ball_detections, all_detections, all_labels = jersey_tracker.process_frame(
                frame, confidence
            )
            
            # Annotate frame
            annotated_frame = frame.copy()
            annotated_frame = ellipse_annotator.annotate(annotated_frame, all_detections)
            annotated_frame = label_annotator.annotate(
                annotated_frame, all_detections, all_labels
            )
            annotated_frame = triangle_annotator.annotate(annotated_frame, ball_detections)
            
            video_sink.write_frame(annotated_frame)


In [42]:

# Project root; override with the SOCCER_PROJECT_DIR environment variable.
# Defaults to the original local location.
PROJECT_DIR = Path(os.environ.get("SOCCER_PROJECT_DIR", "/home/sahilc/Desktop/Soccer Analysis Project"))
ASSETS_DIR = PROJECT_DIR / "Player Reidentification and Tracking" / "Assets"
DETECTION_CONFIDENCE = 0.5  # YOLO detection threshold for this run

source_video = str(PROJECT_DIR / "Experiments" / "Input" / "trimmed.mp4")
target_video = str(PROJECT_DIR / "Experiments" / "Output" / "annotated_numbers.mp4")
jersey_model = str(ASSETS_DIR / "best_jersey_model (1).pth")
player_model = str(ASSETS_DIR / "best.pt")

# OpenCV's video writer does not create missing directories
Path(target_video).parent.mkdir(parents=True, exist_ok=True)

process_video(source_video, target_video, jersey_model, player_model, confidence=DETECTION_CONFIDENCE)

Processing video...


100%|██████████| 1000/1000 [01:00<00:00, 16.54it/s]
