In [None]:
import cv2
import torch
import numpy as np
from ultralytics import YOLO
from scipy.optimize import linear_sum_assignment
from torchreid.utils import FeatureExtractor
import json
from tqdm import tqdm
import matplotlib.pyplot as plt
from IPython.display import Video, display
import os

class PlayerTracker:
    """Detect and track objects in a video with a YOLO model and ByteTrack.

    Every class the model predicts is tracked unless ``classes`` is passed to
    :meth:`track_players`. The run log shows this model's labels include
    'ball' and 'referee' as well as 'player' and 'goalkeeper'.
    """

    def __init__(self, model_path):
        """Load the YOLO weights at ``model_path`` and configure ByteTrack."""
        self.model = YOLO(model_path)
        print("Class labels:", self.model.names)  # Debug: check class labels
        self.tracker_args = {
            'tracker': 'bytetrack.yaml',
            # With persist=True ultralytics reuses its tracker objects across
            # track() calls, so track_players resets them before each video.
            'persist': True,
            'verbose': False,
            'stream': True,
        }

    def _reset_trackers(self):
        """Reset any tracker left on the model's predictor by an earlier call.

        Without this, the tracker's track lists and ID counter carry over from
        the previous video. The run log shows the tacticam IDs continuing where
        the broadcast IDs stopped, so the first frames of a new video could be
        associated with tracks from a different camera.
        """
        predictor = getattr(self.model, "predictor", None)  # None before the first call
        for trk in getattr(predictor, "trackers", None) or []:
            reset = getattr(trk, "reset", None)
            if callable(reset):
                reset()

    def track_players(self, video_path, conf=0.3, classes=None):
        """Track objects through ``video_path``.

        Args:
            video_path: Path to the input video.
            conf: Detection confidence threshold passed to ``YOLO.track``.
            classes: Optional list of class indices to keep (e.g. ``[1, 2]``
                for goalkeeper and player with this model's labels). ``None``
                tracks every class.

        Returns:
            ``(tracks, (width, height), fps)`` where ``tracks`` maps each
            integer track ID to ``{'boxes': [[x1, y1, x2, y2], ...],
            'frames': [frame_idx, ...], 'features': []}``; ``boxes[i]`` is the
            box seen in frame ``frames[i]``.
        """
        cap = cv2.VideoCapture(video_path)
        fps = cap.get(cv2.CAP_PROP_FPS)
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        cap.release()

        self._reset_trackers()
        track_kwargs = dict(self.tracker_args)
        if classes is not None:
            track_kwargs['classes'] = classes

        results = self.model.track(source=video_path, conf=conf, **track_kwargs)

        tracks = {}
        for frame_idx, r in enumerate(tqdm(results, total=total_frames, desc="Tracking")):
            if r.boxes.id is None:
                continue  # the tracker assigned no IDs in this frame

            boxes = r.boxes.xyxy.cpu().numpy()
            track_ids = r.boxes.id.cpu().numpy().astype(int)

            for (x1, y1, x2, y2), track_id in zip(boxes, track_ids):
                track = tracks.setdefault(
                    int(track_id), {'boxes': [], 'frames': [], 'features': []}
                )
                track['boxes'].append([x1, y1, x2, y2])
                track['frames'].append(frame_idx)

        return tracks, (width, height), fps

class FeatureExtractorWrapper:
    """Compute a torchreid appearance embedding for one bounding box of a frame.

    Args:
        model_name: torchreid architecture name.
        model_path: Path to ReID weights. With the default ``''`` the run log
            shows ImageNet-pretrained weights being loaded, which were not
            trained for re-identification; pass a ReID checkpoint for better
            cross-camera matching.
        device: Torch device string; defaults to ``'cuda'`` when available,
            otherwise ``'cpu'``.
    """

    def __init__(self, model_name='osnet_x1_0', model_path='', device=None):
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.model = FeatureExtractor(
            model_name=model_name,
            model_path=model_path,
            device=device,
        )

    def extract_features(self, frame, box):
        """Return the embedding of ``frame`` cropped to ``box``, or ``None``.

        Args:
            frame: BGR image array (as returned by ``cv2.VideoCapture.read``).
            box: ``(x1, y1, x2, y2)`` pixel coordinates; values are truncated
                to ints and clipped to the frame.

        Returns:
            A 1-D numpy array, or ``None`` when the clipped box is empty.
        """
        x1, y1, x2, y2 = map(int, box)
        h, w = frame.shape[:2]
        x1, y1 = max(0, x1), max(0, y1)
        x2, y2 = min(w, x2), min(h, y2)

        if x2 <= x1 or y2 <= y1:
            return None  # box lies entirely outside the frame or is degenerate

        crop = frame[y1:y2, x1:x2]
        crop = cv2.cvtColor(crop, cv2.COLOR_BGR2RGB)
        # NOTE(review): torchreid's FeatureExtractor presumably applies its own
        # resize/normalisation, so this (width=128, height=256) resize may be
        # redundant -- confirm before removing.
        crop = cv2.resize(crop, (128, 256))
        features = self.model([crop])[0]
        return features.cpu().numpy()

def process_video(video_path, tracker, extractor):
    """Track ``video_path`` and attach a mean appearance embedding to every track.

    Args:
        video_path: Path to the input video.
        tracker: Object with ``track_players(video_path)`` returning
            ``(tracks, resolution, fps)``.
        extractor: Object with ``extract_features(frame, box)`` returning an
            array or ``None``.

    Returns:
        ``(tracks, resolution, fps)`` with each ``tracks[id]['features']`` set
        to the mean of that track's per-frame embeddings, or ``None`` when no
        embedding could be extracted.
    """
    print(f"Processing {video_path}")
    tracks, resolution, fps = tracker.track_players(video_path)

    # Index detections by frame so the video is decoded once, front to back.
    # Seeking per detection (CAP_PROP_POS_FRAMES) is slow and on some codecs
    # lands on a neighbouring frame, misaligning crops with their boxes.
    detections_by_frame = {}
    for track_id, data in tracks.items():
        for frame_idx, box in zip(data['frames'], data['boxes']):
            detections_by_frame.setdefault(frame_idx, []).append((track_id, box))

    features_by_track = {track_id: [] for track_id in tracks}
    last_frame = max(detections_by_frame, default=-1)

    cap = cv2.VideoCapture(video_path)
    try:
        for frame_idx in tqdm(range(last_frame + 1), desc="Extracting features"):
            ret, frame = cap.read()
            if not ret:
                break  # remaining frames are unreadable; their tracks keep what they have
            for track_id, box in detections_by_frame.get(frame_idx, ()):
                feat = extractor.extract_features(frame, box)
                if feat is not None:
                    features_by_track[track_id].append(feat)
    finally:
        cap.release()

    for track_id, feats in features_by_track.items():
        tracks[track_id]['features'] = np.mean(feats, axis=0) if feats else None

    return tracks, resolution, fps

def match_players(tracks1, tracks2, threshold=0.5):
    """Match tracks across two videos by cosine similarity of their features.

    Pairs are chosen one-to-one with the Hungarian algorithm
    (``linear_sum_assignment``) to maximise total similarity; only pairs whose
    similarity is strictly above ``threshold`` are kept. Tracks whose
    ``'features'`` is ``None`` are ignored.

    Args:
        tracks1: ``{track_id: {'features': array | None, ...}}`` for video 1.
        tracks2: Same structure for video 2.
        threshold: Minimum cosine similarity (exclusive) for a match.

    Returns:
        ``{tracks2_id: tracks1_id}`` with plain ``int`` keys and values.
    """
    ids1 = [tid for tid, t in tracks1.items() if t['features'] is not None]
    ids2 = [tid for tid, t in tracks2.items() if t['features'] is not None]

    if not ids1 or not ids2:
        return {}

    feats1 = np.stack([np.asarray(tracks1[tid]['features'], dtype=float).ravel() for tid in ids1])
    feats2 = np.stack([np.asarray(tracks2[tid]['features'], dtype=float).ravel() for tid in ids2])

    # L2-normalise rows; the epsilon keeps all-zero vectors at similarity 0
    # instead of producing NaN, which linear_sum_assignment rejects.
    eps = 1e-12
    feats1 = feats1 / np.maximum(np.linalg.norm(feats1, axis=1, keepdims=True), eps)
    feats2 = feats2 / np.maximum(np.linalg.norm(feats2, axis=1, keepdims=True), eps)
    sim_matrix = feats1 @ feats2.T

    # Negate: linear_sum_assignment minimises cost.
    row_idx, col_idx = linear_sum_assignment(-sim_matrix)
    return {
        int(ids2[j]): int(ids1[i])
        for i, j in zip(row_idx, col_idx)
        if sim_matrix[i, j] > threshold
    }

def visualize_results(video_path, tracks, mapping, output_path):
    """Render ``video_path`` with every track's box and ID label to ``output_path``.

    Tracks present in ``mapping`` are drawn green (BGR ``(0, 255, 0)``) and
    labelled ``mapping[track_id]``; all others are drawn red (BGR
    ``(0, 0, 255)``) with their own ID.

    Args:
        video_path: Source video.
        tracks: ``{track_id: {'frames': [...], 'boxes': [...], ...}}``.
        mapping: ``{track_id: display_id}`` for matched tracks.
        output_path: Destination file, written with the 'mp4v' codec.
    """
    # Precompute frame -> [(track_id, box)] once; scanning each track's frame
    # list per rendered frame was quadratic in video length. Insertion order
    # follows `tracks`, preserving the drawing order, and only the first box
    # per (track, frame) is kept.
    boxes_by_frame = {}
    for track_id, data in tracks.items():
        seen = set()
        for frame_idx, box in zip(data['frames'], data['boxes']):
            if frame_idx in seen:
                continue
            seen.add(frame_idx)
            boxes_by_frame.setdefault(frame_idx, []).append((track_id, box))

    cap = cv2.VideoCapture(video_path)
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = cap.get(cv2.CAP_PROP_FPS)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))

    try:
        for frame_idx in tqdm(range(total_frames), desc="Rendering video"):
            ret, frame = cap.read()
            if not ret:
                break

            for track_id, box in boxes_by_frame.get(frame_idx, ()):
                x1, y1, x2, y2 = map(int, box)
                is_mapped = track_id in mapping
                display_id = mapping.get(track_id, track_id)
                color = (0, 255, 0) if is_mapped else (0, 0, 255)

                cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
                # Keep the label inside the frame for boxes touching the top edge.
                cv2.putText(frame, f"ID: {display_id}", (x1, max(y1 - 10, 15)),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.7, color, 2)

            out.write(frame)
    finally:
        cap.release()
        out.release()

def main(broadcast_path="videos/broadcast.mp4",
         tacticam_path="videos/tacticam.mp4",
         model_path="best.pt"):
    """Run tracking, feature extraction, cross-camera matching and rendering.

    Writes ``player_mapping.json`` (``{tacticam_id: broadcast_id}`` with string
    keys), ``broadcast_result.mp4`` and ``tacticam_result.mp4``.

    Args:
        broadcast_path: Broadcast-camera video.
        tacticam_path: Tactical-camera video.
        model_path: YOLO weights for ``PlayerTracker``.

    Returns:
        ``{tacticam_id: broadcast_id}`` for matched players.
    """
    tracker = PlayerTracker(model_path)
    extractor = FeatureExtractorWrapper()

    print("\nProcessing broadcast video...")
    broadcast_tracks, _, _ = process_video(broadcast_path, tracker, extractor)

    print("\nProcessing tacticam video...")
    tacticam_tracks, _, _ = process_video(tacticam_path, tracker, extractor)

    print("\nMatching players...")
    mapping = match_players(broadcast_tracks, tacticam_tracks)
    print(f"Found {len(mapping)} player matches")

    json_mapping = {str(k): int(v) for k, v in mapping.items()}
    with open("player_mapping.json", "w", encoding="utf-8") as f:
        json.dump(json_mapping, f)

    # Label both videos with broadcast IDs so a matched player shows the same
    # number in each output. Previously the broadcast video received the
    # inverted mapping and displayed *tacticam* IDs, while the tacticam video
    # displayed broadcast IDs.
    broadcast_display = {b_id: b_id for b_id in mapping.values()}

    print("\nGenerating broadcast result video...")
    visualize_results(broadcast_path, broadcast_tracks,
                      broadcast_display, "broadcast_result.mp4")

    print("\nGenerating tacticam result video...")
    visualize_results(tacticam_path, tacticam_tracks,
                      mapping, "tacticam_result.mp4")

    return mapping

if __name__ == "__main__":
    # Script entry point: run the whole pipeline, then print a short summary.
    print("Starting player re-identification...")
    final_mapping = main()
    summary = (
        "\nPlayer Mapping Complete!",
        f"Mapping: {final_mapping}",
        "Result videos saved: broadcast_result.mp4, tacticam_result.mp4",
    )
    print(*summary, sep="\n")


Starting player re-identification...
Class labels: {0: 'ball', 1: 'goalkeeper', 2: 'player', 3: 'referee'}
Successfully loaded imagenet pretrained weights from "C:\Users\addur/.cache\torch\checkpoints\osnet_x1_0_imagenet.pth"
** The following layers are discarded due to unmatched keys or layer size: ['classifier.weight', 'classifier.bias']
Model: osnet_x1_0
- params: 2,193,616
- flops: 978,878,352

Processing broadcast video...
Processing videos/broadcast.mp4


Tracking: 100%|██████████| 132/132 [01:01<00:00,  2.14it/s]
Extracting features: 100%|██████████| 35/35 [02:05<00:00,  3.60s/it]



Processing tacticam video...
Processing videos/tacticam.mp4


Tracking: 100%|██████████| 201/201 [01:33<00:00,  2.15it/s]
Extracting features: 100%|██████████| 46/46 [07:37<00:00,  9.94s/it]



Matching players...
Found 35 player matches

Generating broadcast result video...


Rendering video: 100%|██████████| 132/132 [00:02<00:00, 59.17it/s]



Generating tacticam result video...


Rendering video: 100%|██████████| 201/201 [00:03<00:00, 55.60it/s]


Player Mapping Complete!
Mapping: {252: 1, 187: 2, 292: 3, 179: 13, 182: 14, 189: 15, 193: 16, 181: 17, 174: 18, 190: 19, 173: 20, 214: 21, 188: 22, 184: 23, 195: 24, 234: 28, 218: 30, 185: 38, 207: 45, 245: 46, 280: 48, 197: 55, 180: 68, 175: 76, 232: 92, 178: 102, 236: 127, 192: 129, 204: 145, 237: 149, 219: 156, 176: 157, 177: 166, 183: 167, 294: 169}
Result videos saved: broadcast_result.mp4, tacticam_result.mp4



