In [1]:
import torch
from ultralytics import YOLO
import cv2
import numpy as np
from torchvision import models, transforms
from scipy.spatial.distance import cosine
from tqdm import tqdm
import os

In [2]:
#device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cpu


In [3]:
#load pretrained model best.pt
model = YOLO(r'D:\Liat_ai\best.pt')  

#load resnet18 feature extractor
resnet = models.resnet18(pretrained=True)
resnet.fc = torch.nn.Identity()
resnet = resnet.to(device).eval()

#transform functions
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((128, 64)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])



In [4]:
#detection function
def detect_players(frame):
    results = model(frame)[0]
    detections = []
    for box in results.boxes:
        cls_id = int(box.cls.item())
        conf = float(box.conf.item())
        if conf > 0.5:
            x1, y1, x2, y2 = map(int, box.xyxy[0])
            detections.append((x1, y1, x2, y2, conf))
    return detections


In [5]:
player_embeddings = {}
next_player_id = 0
SIMILARITY_THRESHOLD = 0.6

#input video
cap = cv2.VideoCapture(r'D:\Liat_ai\15sec_input_720p.mp4')
if not cap.isOpened():
    raise ValueError("Error: video not opened!")

width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
fps = cap.get(cv2.CAP_PROP_FPS)

#output video
out = cv2.VideoWriter(r'D:\Liat_ai\output.mp4',
                      cv2.VideoWriter_fourcc(*'mp4v'),
                      fps, (width, height))


In [6]:
#frame by frame reidentification
frame_num = 0
while True:
    ret, frame = cap.read()
    if not ret:
        break

    detections = detect_players(frame)
    current_frame_ids = []

    for (x1, y1, x2, y2, conf) in detections:
        crop = frame[y1:y2, x1:x2]
        if crop.size == 0:
            continue

        input_tensor = transform(crop).unsqueeze(0).to(device)
        with torch.no_grad():
            feature = resnet(input_tensor).squeeze().cpu().numpy()

        
        best_pid = None
        best_similarity = 1

        for pid, prev_feat in player_embeddings.items():
            similarity = cosine(prev_feat, feature)
            if similarity < best_similarity:
                best_similarity = similarity
                best_pid = pid

        if best_similarity < SIMILARITY_THRESHOLD:
            current_frame_ids.append((x1, y1, x2, y2, best_pid))
            player_embeddings[best_pid] = feature
        else:
            player_embeddings[next_player_id] = feature
            current_frame_ids.append((x1, y1, x2, y2, next_player_id))
            next_player_id += 1

    
    for (x1, y1, x2, y2, pid) in current_frame_ids:
        cv2.rectangle(frame, (x1, y1), (x2, y2), (0,255,0), 2)
        cv2.putText(frame, f'Player {pid}', (x1, y1 - 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0,255,0), 2)

    out.write(frame)
    frame_num += 1
    if frame_num % 10 == 0:
        print(f"Processed {frame_num} frames...")

cap.release()
out.release()
print("Video saved.")



0: 384x640 1 ball, 16 players, 2 referees, 804.2ms
Speed: 5.4ms preprocess, 804.2ms inference, 1.8ms postprocess per image at shape (1, 3, 384, 640)

0: 384x640 18 players, 2 referees, 880.3ms
Speed: 2.1ms preprocess, 880.3ms inference, 0.9ms postprocess per image at shape (1, 3, 384, 640)

0: 384x640 1 ball, 16 players, 2 referees, 716.6ms
Speed: 1.9ms preprocess, 716.6ms inference, 1.6ms postprocess per image at shape (1, 3, 384, 640)

0: 384x640 1 ball, 14 players, 2 referees, 792.9ms
Speed: 1.8ms preprocess, 792.9ms inference, 0.9ms postprocess per image at shape (1, 3, 384, 640)

0: 384x640 1 ball, 14 players, 2 referees, 755.1ms
Speed: 1.7ms preprocess, 755.1ms inference, 0.9ms postprocess per image at shape (1, 3, 384, 640)

0: 384x640 1 ball, 16 players, 2 referees, 1153.0ms
Speed: 1.8ms preprocess, 1153.0ms inference, 1.3ms postprocess per image at shape (1, 3, 384, 640)

0: 384x640 15 players, 2 referees, 1294.8ms
Speed: 3.1ms preprocess, 1294.8ms inference, 1.0ms postproces