### Goal of this pipeline:
The goal of this pipeline is simple. So far, we have trained two models that detect:
1. Players, referees, balls, and goalkeepers
2. The number on the jersey

Now, the idea is to combine these two models to produce a video in which the players are tracked and each one is labelled with their jersey number as their ID.

In [None]:
import cv2
import torch
import numpy as np
from ultralytics import YOLO
from pathlib import Path
import supervision as sv
from typing import List, Dict, Tuple
from collections import defaultdict

In [None]:
## Define the model for jersey number recognizer
class JerseyNumberNet(nn.Module):
    """ResNet-18 based classifier that maps a jersey crop to class logits.

    The backbone's final fully connected layer is replaced by a small head:
    ``Linear(in_features, 512) -> ReLU -> Dropout(0.2) -> Linear(512, num_classes)``.
    The head keeps the same ``nn.Sequential`` layout regardless of
    ``num_classes``, so state-dict key names are stable.

    NOTE(review): ``nn`` and ``models`` are not imported in the visible cells;
    presumably ``torch.nn`` and ``torchvision.models`` -- confirm.

    Args:
        pretrained: Forwarded unchanged to ``models.resnet18``.
        num_classes: Number of output logits. Defaults to 100, the value that
            was previously hard-coded (presumably one class per jersey number
            0-99 -- confirm against the training labels).
    """

    def __init__(self, pretrained=True, num_classes=100):
        super().__init__()

        # Use ResNet18 as the base model
        self.backbone = models.resnet18(pretrained=pretrained)

        # Replace the classification head; the input width is read from the
        # original head so it always matches the backbone's feature size.
        num_features = self.backbone.fc.in_features
        self.backbone.fc = nn.Sequential(
            nn.Linear(num_features, 512),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        """Return the raw (pre-softmax) class logits for the batch ``x``."""
        return self.backbone(x)
    
class JerseyNumberRecognizer:
    """Wrap ``JerseyNumberNet`` with checkpoint loading, training and inference.

    NOTE(review): ``nn``, ``optim``, ``transforms`` and ``Image`` are not
    imported in the visible cells; presumably ``torch.nn``, ``torch.optim``,
    ``torchvision.transforms`` and ``PIL.Image`` -- confirm.

    Args:
        model_path: Optional path to a saved ``state_dict``. When given, the
            weights are loaded and the model is switched to eval mode.
        device: Device the model and all input tensors are moved to.
    """

    def __init__(self, model_path = None, device='cuda' if torch.cuda.is_available() else 'cpu'):
        self.device = device
        # Only request the pretrained backbone weights when no checkpoint is
        # supplied: a (strict) load_state_dict overwrites every parameter, so
        # fetching them first would be wasted work.
        self.model = JerseyNumberNet(pretrained=not model_path).to(device)

        if model_path:
            # map_location lets a checkpoint saved on a GPU load on a CPU-only
            # host. NOTE: torch.load unpickles the file -- only load
            # checkpoints from a trusted source.
            state_dict = torch.load(model_path, map_location=device)
            self.model.load_state_dict(state_dict)
            self.model.eval()

        # Preprocessing applied to every crop in predict(): resize, convert to
        # a tensor, then normalise with the commonly used ImageNet channel
        # statistics.
        self.transform = transforms.Compose([
            transforms.Resize((128, 64)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def train(self, train_loader, val_loader, epochs=5, save_path='best_model.pth'):
        """Train the model and keep the checkpoint with the best validation accuracy.

        Uses Adam (lr=0.001) with a StepLR schedule (x0.1 every 2 epochs) and
        cross-entropy loss. Training labels are passed to the loss unchanged.
        Validation labels may be either 1-D class indices or 2-D one-hot /
        probability rows; the latter are reduced to indices with ``argmax``.

        Args:
            train_loader: Iterable of ``(inputs, labels)`` training batches.
            val_loader: Iterable of ``(inputs, labels)`` validation batches.
                If it yields nothing, the accuracy is reported as 0.0 and no
                checkpoint is written for that epoch.
            epochs: Number of passes over ``train_loader``.
            save_path: Where the best ``state_dict`` is written.
        """
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(self.model.parameters(), lr=0.001)
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.1)

        best_accuracy = 0

        # Debug aid: push one batch through the model and report tensor
        # shapes. No gradients are needed for this check.
        with torch.no_grad():
            for inputs, labels in train_loader:
                inputs, labels = inputs.to(self.device), labels.to(self.device)
                print(f"Inputs shape: {inputs.shape}")
                print(f"Labels shape: {labels.shape}")

                result = self.model(inputs)
                print(f"Result shape: {result.shape}")
                break

        for epoch in range(epochs):
            self.model.train()
            running_loss = 0.0
            for i, (inputs, labels) in enumerate(train_loader):
                inputs, labels = inputs.to(self.device), labels.to(self.device)

                optimizer.zero_grad()

                outputs = self.model(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                running_loss += loss.item()
                # Report the mean loss once every 100 batches.
                if i % 100 == 99:
                    print(f"[{epoch + 1}, {i + 1}] loss: {running_loss / 100}")
                    running_loss = 0.0

            self.model.eval()
            correct = 0
            total = 0
            with torch.no_grad():
                for inputs, labels in val_loader:
                    inputs, labels = inputs.to(self.device), labels.to(self.device)

                    outputs = self.model(inputs)

                    # One-hot (or probability) rows become class indices;
                    # labels that are already 1-D indices are used as-is.
                    if labels.dim() > 1:
                        labels = torch.argmax(labels, dim=1)

                    predicted = torch.argmax(outputs, dim=1)

                    total += labels.size(0)
                    print(f"labels: {labels}")
                    print(f"predicted: {predicted}")
                    correct += (predicted == labels).sum().item()
                    print(f"Total: {total}, Correct: {correct} for epoch {epoch + 1} in validation")

            # Guard against an empty validation loader.
            accuracy = correct / total if total else 0.0
            print(f"Accuracy after epoch {epoch + 1}: {accuracy}")

            if accuracy > best_accuracy:
                best_accuracy = accuracy
                torch.save(self.model.state_dict(), save_path)

            scheduler.step()
        print(f"Training finished. Best accuracy: {best_accuracy}")

    def predict(self, jersey_crop: np.ndarray) -> Tuple[int, float]:
        """Predict the jersey number shown in a single crop.

        Args:
            jersey_crop: Image array in OpenCV's BGR channel order (it is
                converted to RGB before preprocessing).

        Returns:
            ``(prediction, confidence)``: the arg-max class index and its
            softmax probability as a plain Python float.
        """
        self.model.eval()

        image = Image.fromarray(cv2.cvtColor(jersey_crop, cv2.COLOR_BGR2RGB))
        # unsqueeze(0) adds the batch dimension the model expects.
        batch = self.transform(image).unsqueeze(0).to(self.device)

        with torch.no_grad():
            logits = self.model(batch)
            probabilities = torch.softmax(logits, dim=1)
            prediction = int(torch.argmax(logits, dim=1).item())
            confidence = float(probabilities[0, prediction].item())

        return prediction, confidence
    

In [None]:
# Weights file name for the object detector. Note that the example-usage cell
# visible in this file passes its own path to PlayerJerseyTracker rather than
# reading this variable.
object_detector_weights = 'yolov5s.pt'

In [None]:
import cv2
import torch
import numpy as np
from ultralytics import YOLO
from pathlib import Path
import supervision as sv
from typing import List, Dict, Tuple
from collections import defaultdict

class PlayerJerseyTracker:
    """Detect players, track them across frames and label them by jersey number.

    Each frame goes through: YOLO detection -> NMS -> keep the player class ->
    ByteTrack -> jersey-number classification on a crop of every tracked box.
    Per-track predictions are accumulated in ``player_history`` and the most
    frequent confident prediction is used as the track's jersey number.
    """

    # Vertical band of a player box (as fractions of its height, measured from
    # the top) that is cropped and passed to the jersey-number classifier.
    JERSEY_TOP_FRACTION = 0.15
    JERSEY_BOTTOM_FRACTION = 0.5

    def __init__(
        self,
        object_detector_weights: str,
        jersey_model_weights: str,
        conf_threshold: float = 0.5,
        jersey_conf_threshold: float = 0.7,
        device: str = 'cuda' if torch.cuda.is_available() else 'cpu',
        player_class_id: int = 1
    ):
        """Initialize the player tracking pipeline with jersey recognition.

        Args:
            object_detector_weights: Path/name of the YOLO weights file.
            jersey_model_weights: Path to the jersey classifier ``state_dict``.
            conf_threshold: Minimum detection confidence passed to the detector.
            jersey_conf_threshold: Minimum classifier confidence for a jersey
                prediction to count towards a track's most common number.
            device: Device handed to the jersey recognizer.
            player_class_id: Detector class id treated as "player". Defaults
                to 1, the value previously hard-coded -- verify it against the
                detector's class list.
        """
        self.device = device
        self.conf_threshold = conf_threshold
        self.jersey_conf_threshold = jersey_conf_threshold
        self.player_class_id = player_class_id

        # Initialize models
        self.object_detector = YOLO(object_detector_weights)
        self.jersey_recognizer = JerseyNumberRecognizer(
            model_path=jersey_model_weights,
            device=device
        )

        # Initialize tracker
        self.tracker = sv.ByteTrack()
        self.tracker.reset()

        # Per-track history, created lazily the first time a track id is seen.
        # 'most_common_number' stays -1 until a confident prediction arrives.
        self.player_history = defaultdict(lambda: {
            'jersey_numbers': [],
            'confidences': [],
            'frames_visible': 0,
            'last_seen': 0,
            'most_common_number': -1
        })

        # Initialize annotators
        self.box_annotator = sv.BoxAnnotator(
            color=sv.ColorPalette.default(),
            thickness=2
        )
        self.label_annotator = sv.LabelAnnotator(
            color=sv.ColorPalette.default(),
            text_color=sv.Color.black(),
            text_position=sv.Position.BOTTOM_CENTER
        )

    def get_jersey_number(self, frame: np.ndarray, bbox: List[int]) -> Tuple[int, float]:
        """Crop the jersey region of a player box and classify its number.

        Args:
            frame: Full video frame (indexed as ``frame[y, x]``).
            bbox: Player box as ``(x1, y1, x2, y2)``; values are cast to int.

        Returns:
            ``(number, confidence)`` from the jersey recognizer, or
            ``(-1, 0.0)`` when the crop is empty after clamping to the frame.
        """
        x1, y1, x2, y2 = map(int, bbox)
        box_height = y2 - y1

        # Extract the upper-body band, clamped to the frame bounds.
        top = max(0, y1 + int(box_height * self.JERSEY_TOP_FRACTION))
        bottom = min(frame.shape[0], y1 + int(box_height * self.JERSEY_BOTTOM_FRACTION))
        left = max(0, x1)
        right = min(frame.shape[1], x2)

        jersey_crop = frame[top:bottom, left:right]

        if jersey_crop.size == 0:
            return -1, 0.0

        return self.jersey_recognizer.predict(jersey_crop)

    def update_player_history(self, track_id: int, jersey_number: int,
                            confidence: float, frame_id: int):
        """Record one observation for a track and refresh its jersey number.

        Every observation is appended to the history. The track's
        ``most_common_number`` is only recomputed when the new observation's
        confidence exceeds ``jersey_conf_threshold``, and only observations
        above that threshold are counted.
        """
        history = self.player_history[track_id]
        history['jersey_numbers'].append(jersey_number)
        history['confidences'].append(confidence)
        history['frames_visible'] += 1
        history['last_seen'] = frame_id

        # A low-confidence observation cannot change the vote, so skip it.
        if confidence <= self.jersey_conf_threshold:
            return

        from collections import Counter
        votes = Counter(
            number
            for number, conf in zip(history['jersey_numbers'], history['confidences'])
            if conf > self.jersey_conf_threshold
        )
        if votes:
            history['most_common_number'] = votes.most_common(1)[0][0]

    def process_frame(self, frame: np.ndarray, frame_id: int) -> "Tuple[sv.Detections, np.ndarray]":
        """Detect, track and annotate the players in a single frame.

        Args:
            frame: Video frame to process; it is not modified.
            frame_id: Index of the frame, stored as ``last_seen`` in the history.

        Returns:
            ``(tracked_players, annotated_frame)``: the tracked player
            detections and a copy of ``frame`` with boxes and labels drawn.
            When the detections object exposes a ``data`` dict, each track's
            current jersey number (-1 if unknown) is stored under
            ``data["jersey_number"]``.
        """
        # Run detection
        results = self.object_detector(frame, conf=self.conf_threshold)[0]
        detections = sv.Detections.from_ultralytics(results)

        # Apply NMS
        detections = detections.with_nms(threshold=0.5)

        # Keep only the player class before tracking.
        player_detections = detections[detections.class_id == self.player_class_id]

        # Update tracker
        tracked_players = self.tracker.update_with_detections(player_detections)

        # The tracker may report no ids at all; treat that as "nothing to label".
        tracker_ids = tracked_players.tracker_id
        if tracker_ids is None:
            tracker_ids = []

        labels = []
        jersey_numbers = []
        for track_id, bbox in zip(tracker_ids, tracked_players.xyxy):
            number, conf = self.get_jersey_number(frame, bbox)
            self.update_player_history(track_id, number, conf, frame_id)

            jersey_num = self.player_history[track_id]['most_common_number']
            jersey_numbers.append(jersey_num)
            labels.append(
                f"#{track_id} " + (f"Jersey #{jersey_num}" if jersey_num != -1 else "")
            )

        labels_match = len(labels) == len(tracked_players)

        # Attach the jersey numbers to the detections through the `data`
        # mapping (the previous `add_annotation` call is not part of the
        # Detections API and raised on every confidently recognised player).
        if labels_match and isinstance(getattr(tracked_players, "data", None), dict):
            tracked_players.data["jersey_number"] = np.array(jersey_numbers, dtype=int)

        # Draw on a copy so the caller's frame stays untouched.
        annotated_frame = self.box_annotator.annotate(
            scene=frame.copy(),
            detections=tracked_players
        )

        # Labels can only be drawn when there is exactly one per detection.
        if labels_match:
            annotated_frame = self.label_annotator.annotate(
                scene=annotated_frame,
                detections=tracked_players,
                labels=labels
            )

        return tracked_players, annotated_frame

    def process_video(
        self,
        source_path: str,
        output_path: str,
        start_frame: int = 0,
        end_frame: int = None
    ):
        """Process a video with player tracking and jersey number recognition.

        Frames in ``[start_frame, end_frame)`` are annotated and written to
        ``output_path``. Frames before ``start_frame`` are still read from the
        source but skipped. If a frame fails to process, the error is printed
        and the unannotated frame is written instead, so the output keeps one
        frame per processed input frame.

        Args:
            source_path: Input video path.
            output_path: Output video path.
            start_frame: Index of the first frame to process.
            end_frame: Index one past the last frame to process; ``None``
                means the total frame count reported for the source video.

        Returns:
            The accumulated ``player_history`` mapping (track id -> history).
        """
        # Get video info
        video_info = sv.VideoInfo.from_video_path(source_path)

        if end_frame is None:
            end_frame = video_info.total_frames

        # Initialize video writer
        video_sink = sv.VideoSink(
            output_path,
            video_info
        )

        # Process frames
        frame_generator = sv.get_video_frames_generator(source_path)

        with video_sink:
            for frame_idx, frame in enumerate(frame_generator):
                if frame_idx < start_frame:
                    continue
                if frame_idx >= end_frame:
                    break

                try:
                    _, output_frame = self.process_frame(frame, frame_idx)
                except Exception as e:
                    print(f"Error processing frame {frame_idx}: {e}")
                    # Fall back to the raw frame rather than dropping it.
                    output_frame = frame

                try:
                    video_sink.write_frame(output_frame)
                except Exception as e:
                    print(f"Error processing frame {frame_idx}: {e}")
                    continue

                # Show progress
                if frame_idx % 30 == 0:
                    print(f"Processed frame {frame_idx}/{end_frame}")

        return self.player_history


In [None]:

# Example usage
if __name__ == "__main__":
    # Build the pipeline from the detector and jersey-classifier checkpoints.
    tracker_settings = dict(
        object_detector_weights="path/to/yolo_weights.pt",
        jersey_model_weights="path/to/jersey_model.pt",
        conf_threshold=0.5,
        jersey_conf_threshold=0.7,
    )
    tracker = PlayerJerseyTracker(**tracker_settings)

    # Run the whole video through the pipeline; the return value is the
    # per-track history accumulated while processing.
    history = tracker.process_video(
        source_path="input_video.mp4",
        output_path="output_video.mp4",
    )

    # Summarise every track that ended up with a jersey number.
    print("\nPlayer Tracking Statistics:")
    for track_id, data in history.items():
        jersey = data['most_common_number']
        if jersey == -1:
            # No confident jersey prediction for this track: nothing to report.
            continue

        mean_confidence = np.mean(data['confidences'])
        print(f"\nPlayer #{track_id}")
        print(f"Jersey Number: {jersey}")
        print(f"Frames Visible: {data['frames_visible']}")
        print(f"Average Confidence: {mean_confidence:.2f}")