In [1]:
!pip install ultralytics torchvision scipy --quiet
import torch, torchvision, cv2
import numpy as np
from ultralytics import YOLO
from torchvision import transforms, models
from scipy.optimize import linear_sum_assignment
from sklearn.metrics.pairwise import cosine_similarity


[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.0/1.0 MB[0m [31m17.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.0/63.0 MB[0m [31m20.6 MB/s[0m eta [36m0:00:00[0m
[?25hCreating new Ultralytics Settings v0.0.6 file ✅ 
View Ultralytics Settings with 'yolo settings' or at '/root/.config/Ultralytics/settings.json'
Update Settings with 'yolo settings key=value', i.e. 'yolo settings runs_dir=path/to/dir'. For help see https://docs.ultralytics.com/quickstart/#ultralytics-settings.


In [3]:
# Load the fine-tuned YOLO detector weights (update the path if needed).
model = YOLO(model='/content/best.pt')


In [4]:
VIDEO_PATH = "/content/15sec_input_720p.mp4"  # Update path
PLAYER_CLASS_ID = 2  # class index of "player" in the custom YOLO model loaded above

# One entry per video frame: a list of player detections, each
# {'bbox': [x1, y1, x2, y2], 'frame': ndarray}. Frames without players get [].
frame_embeddings = []

cap = cv2.VideoCapture(VIDEO_PATH)
try:
    if not cap.isOpened():
        # Without this check a bad path silently yields an empty list and every
        # downstream cell "succeeds" on no data.
        raise FileNotFoundError(f"Could not open video: {VIDEO_PATH}")

    while True:
        ret, frame = cap.read()
        if not ret:
            break

        results = model.predict(frame, conf=0.4, iou=0.4, verbose=False)[0]
        boxes = results.boxes

        player_dets = []
        # A single copy of the frame is shared by every detection in it; copying
        # per detection multiplied memory by the number of players per frame.
        shared_frame = None
        for xyxy, cls in zip(boxes.xyxy.tolist(), boxes.cls.tolist()):
            if int(cls) != PLAYER_CLASS_ID:
                continue
            if shared_frame is None:
                shared_frame = frame.copy()
            x1, y1, x2, y2 = map(int, xyxy)
            player_dets.append({'bbox': [x1, y1, x2, y2], 'frame': shared_frame})

        frame_embeddings.append(player_dets)
finally:
    cap.release()


In [8]:
# Frozen ImageNet-pretrained ResNet-18 used as an appearance feature extractor.
# `weights=` replaces the deprecated `pretrained=True` flag (torchvision >= 0.13);
# ResNet18_Weights.DEFAULT resolves to the same IMAGENET1K_V1 checkpoint.
resnet = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
resnet.fc = torch.nn.Identity()  # drop the classifier head -> return pooled features
resnet = resnet.eval()           # inference mode (fixed BatchNorm statistics)

# Preprocessing for an H x W x 3 uint8 crop. Normalize uses the ImageNet RGB
# mean/std, so crops must be in RGB channel order.
# NOTE(review): Resize((128, 128)) squashes tall player crops to a square;
# a non-square size may preserve appearance better — confirm before changing.
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

def extract_embedding(image, bbox, model=None, preprocess=None):
    """Return an appearance embedding for the region ``bbox`` of ``image``.

    Args:
        image: frame as read by OpenCV (H x W x 3, BGR channel order).
        bbox: (x1, y1, x2, y2) pixel box. Coordinates are clamped at 0 so a
            slightly negative edge cannot wrap around through negative indexing.
        model: feature extractor applied to the batched (1, C, H, W) tensor;
            defaults to the module-level ``resnet``.
        preprocess: callable turning an H x W x 3 RGB array into a C x H x W
            tensor; defaults to the module-level ``transform``.

    Returns:
        1-D numpy array (the model output with the batch axis squeezed), or
        None when the crop is smaller than 5 px in either dimension.
    """
    model = resnet if model is None else model
    preprocess = transform if preprocess is None else preprocess

    x1, y1, x2, y2 = (max(0, int(v)) for v in bbox)
    crop = image[y1:y2, x1:x2]
    if crop.shape[0] < 5 or crop.shape[1] < 5:
        return None

    # OpenCV frames are BGR, but the ImageNet-pretrained backbone and the
    # Normalize statistics expect RGB; reverse the channel axis.
    crop_rgb = np.ascontiguousarray(crop[..., ::-1])
    with torch.no_grad():
        img_tensor = preprocess(crop_rgb).unsqueeze(0)
        embedding = model(img_tensor).squeeze(0).cpu().numpy()
    return embedding


In [10]:
from tqdm import tqdm


# Tracking parameters
MAX_AGE = 30          # consecutive unmatched frames a track survives before being dropped
DIST_THRESHOLD = 100  # assignments whose cost is >= this value are rejected

# Tracker state — re-initialised every time this cell runs.
tracked_results = []   # per-frame list of {'bbox', 'id'} outputs
active_tracks = []     # live tracks: {'id', 'bbox', 'embedding', 'age'}
player_id_counter = 0  # last issued player ID

def get_centroid(box):
    """Return the (cx, cy) centre of an (x1, y1, x2, y2) box, floor-divided."""
    left, top, right, bottom = box
    cx = (left + right) // 2
    cy = (top + bottom) // 2
    return cx, cy

# Main tracking loop

def _build_cost_matrix(tracks, detections, appearance_weight=100.0, invalid_cost=1e6):
    """Return the (len(tracks), len(detections)) association cost matrix.

    cost[i, j] = centroid distance (pixels) between track i and detection j
                 - appearance_weight * cosine similarity of their embeddings.
    Lower is better. Detections whose embedding is None get ``invalid_cost`` so
    they are effectively never matched. All pairs are computed with one
    vectorised call instead of one sklearn call per (track, detection) pair.
    """
    cost = np.full((len(tracks), len(detections)), invalid_cost, dtype=float)
    valid_cols = [j for j, det in enumerate(detections) if det['embedding'] is not None]
    if not tracks or not valid_cols:
        return cost

    track_xy = np.array([get_centroid(t['bbox']) for t in tracks], dtype=float)
    det_xy = np.array([get_centroid(detections[j]['bbox']) for j in valid_cols], dtype=float)
    dist = np.linalg.norm(track_xy[:, None, :] - det_xy[None, :, :], axis=-1)

    sim = cosine_similarity(
        np.stack([t['embedding'] for t in tracks]),
        np.stack([detections[j]['embedding'] for j in valid_cols]),
    )
    cost[:, valid_cols] = dist - appearance_weight * sim
    return cost


for frame_data in tqdm(frame_embeddings, desc="Tracking"):
    detections = frame_data
    frame_result = []
    new_tracks = []

    # Precompute an embedding for every detection in this frame (None when the
    # crop is too small to embed).
    for det in detections:
        det['embedding'] = extract_embedding(det['frame'], det['bbox'])

    matched_tracks = set()
    matched_dets = set()

    # Steps 1-2: build costs and solve the assignment with the Hungarian
    # algorithm; pairs costing DIST_THRESHOLD or more are rejected.
    if active_tracks:
        cost_matrix = _build_cost_matrix(active_tracks, detections)
        row_ind, col_ind = linear_sum_assignment(cost_matrix)

        for r, c in zip(row_ind, col_ind):
            if cost_matrix[r, c] >= DIST_THRESHOLD:
                continue
            det = detections[c]
            pid = active_tracks[r]['id']
            new_tracks.append({
                'id': pid,
                'bbox': det['bbox'],
                'embedding': det['embedding'],
                'age': 0
            })
            frame_result.append({'bbox': det['bbox'], 'id': pid})
            matched_tracks.add(r)
            matched_dets.add(c)

    # Step 3: start a new ID for every unmatched detection that has an embedding.
    for c, det in enumerate(detections):
        if c in matched_dets or det['embedding'] is None:
            continue
        player_id_counter += 1
        pid = player_id_counter
        new_tracks.append({
            'id': pid,
            'bbox': det['bbox'],
            'embedding': det['embedding'],
            'age': 0
        })
        frame_result.append({'bbox': det['bbox'], 'id': pid})

    # Step 4: age unmatched tracks (keeping their last bbox/embedding) and drop
    # those unmatched for more than MAX_AGE frames.
    for r, track in enumerate(active_tracks):
        if r in matched_tracks:
            continue
        track['age'] += 1
        if track['age'] <= MAX_AGE:
            new_tracks.append(track)

    active_tracks = new_tracks
    tracked_results.append(frame_result)


Tracking: 100%|██████████| 375/375 [01:54<00:00,  3.27it/s]


In [11]:
# Render tracked IDs onto the source video and write output.mp4.
cap = cv2.VideoCapture(VIDEO_PATH)
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
# Write at the source frame rate. The tracking run processed 375 frames of a
# "15sec" clip (~25 fps), so a hard-coded 30.0 made the output play too fast.
# Fall back to 30 fps when the container does not report a rate (get -> 0).
fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
out = cv2.VideoWriter("output.mp4", fourcc, fps, (width, height))

try:
    if not cap.isOpened():
        raise FileNotFoundError(f"Could not open video: {VIDEO_PATH}")
    if not out.isOpened():
        raise RuntimeError("Could not open output.mp4 for writing")

    for frame_tracks in tracked_results:
        ret, frame = cap.read()
        if not ret:
            break

        for player in frame_tracks:
            x1, y1, x2, y2 = player['bbox']
            pid = player['id']
            cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
            # Keep the label on-screen for boxes touching the top edge.
            label_y = max(y1 - 10, 15)
            cv2.putText(frame, f'ID: {pid}', (x1, label_y),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)

        out.write(frame)
finally:
    cap.release()
    out.release()
