<a href="https://colab.research.google.com/github/AnumandlaS/Soccer-Player-Reidentification/blob/main/Reidentification(in_a_single_feed)_using_reid.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install ultralytics torchreid torch torchvision numpy opencv-python

Collecting ultralytics
  Downloading ultralytics-8.3.167-py3-none-any.whl.metadata (37 kB)
Collecting torchreid
  Downloading torchreid-0.2.5.tar.gz (92 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m92.7/92.7 kB[0m [31m4.6 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting ultralytics-thop>=2.0.0 (from ultralytics)
  Downloading ultralytics_thop-2.0.14-py3-none-any.whl.metadata (9.4 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nv

In [5]:
import cv2
import numpy as np
from ultralytics import YOLO
import torch
import torchreid
from scipy.spatial.distance import cosine
from collections import defaultdict

# Load YOLOv8 model (custom weights used for detection + tracking)
model = YOLO('/content/best.pt')

# Load ReID model (osnet_x1_0 pre-trained)
reid_model = torchreid.models.build_model(
    name='osnet_x1_0',
    num_classes=1,  # Not used for feature extraction
    pretrained=True
)
reid_model.eval()
# Run the ReID network on the GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    reid_model.cuda()
else:
    reid_model.cpu()

# Initialize feature database and ID mapping
feature_db = defaultdict(list)  # Stores {track_id: [(frame_num, embedding)]}
global_id_counter = 0  # Last global ID handed out; incremented for each new identity
global_id_map = {}  # Maps track_id to global_id for reidentification

# Cosine similarity threshold for reidentification
SIMILARITY_THRESHOLD = 0.8

# Load video
video_path = '/content/15sec_input_720p.mp4'
cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
    # Fail loudly instead of silently producing an empty output video.
    raise IOError(f'Could not open input video: {video_path}')

fps = cap.get(cv2.CAP_PROP_FPS)
if not fps > 0:
    # Some containers report 0 (or NaN) FPS; a writer created with that value
    # produces an unplayable file, so fall back to a conventional frame rate.
    fps = 30.0
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

# Initialize video writer
out = cv2.VideoWriter('output_reid.mp4', cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
if not out.isOpened():
    cap.release()
    raise IOError('Could not open output video for writing: output_reid.mp4')

def preprocess_image(img, size=(128, 256)):
    """Turn a BGR crop into a normalized, batched tensor for the ReID model.

    Args:
        img: Image array in OpenCV's BGR, height x width x channel layout.
        size: Target size handed to ``cv2.resize`` (which takes width, height).

    Returns:
        A ``torch.float32`` tensor of shape ``(1, 3, H, W)`` in RGB order,
        scaled to [0, 1] and then normalized per channel with the mean/std
        constants below (the standard ImageNet statistics).
    """
    # Shaped (3, 1, 1) so they broadcast over a channel-first image.
    channel_mean = np.array([0.485, 0.456, 0.406]).reshape(3, 1, 1)
    channel_std = np.array([0.229, 0.224, 0.225]).reshape(3, 1, 1)

    rgb = cv2.cvtColor(cv2.resize(img, size), cv2.COLOR_BGR2RGB)
    chw = rgb.transpose((2, 0, 1)) / 255.0  # HWC -> CHW, then scale to [0, 1]
    normalized = (chw - channel_mean) / channel_std

    # Leading batch dimension of 1 so the model sees a single-image batch.
    return torch.tensor(normalized, dtype=torch.float32).unsqueeze(0)

def extract_reid_features(img):
    """Compute the ReID embedding for one person crop.

    Args:
        img: BGR image crop, passed through ``preprocess_image``.

    Returns:
        The output of ``reid_model`` for the crop as a flat 1-D NumPy array.
    """
    batch = preprocess_image(img)
    if torch.cuda.is_available():
        # Move the input to the GPU when CUDA is available.
        batch = batch.cuda()

    # Pure inference: no gradients needed.
    with torch.no_grad():
        embedding = reid_model(batch)

    return embedding.cpu().numpy().flatten()

def match_features(new_features, frame_num):
    """Resolve an embedding to a global ID using the stored feature database.

    The embedding is compared (cosine similarity) with every embedding stored
    for every track in ``feature_db``; the first track reaching the highest
    similarity wins.

    Args:
        new_features: 1-D embedding of the detection to identify.
        frame_num: Index of the current frame. Accepted for the caller's
            convenience; not used by the matching itself.

    Returns:
        The global ID mapped to the best-matching track when its similarity
        exceeds ``SIMILARITY_THRESHOLD``; otherwise a newly allocated global
        ID (``global_id_counter`` is incremented as a side effect).
    """
    global global_id_counter

    top_track, top_score = None, -1
    stored = (
        (tid, emb)
        for tid, history in feature_db.items()
        for _frame, emb in history
    )
    for tid, emb in stored:
        # scipy's `cosine` is a distance; convert it to a similarity.
        score = 1 - cosine(new_features, emb)
        if score > top_score:
            top_track, top_score = tid, score

    if top_score > SIMILARITY_THRESHOLD:
        return global_id_map[top_track]

    # Nothing similar enough (or an empty database): allocate a fresh identity.
    global_id_counter += 1
    return global_id_counter

# Main loop: detect/track people in every frame, attach a stable global ID to
# each track via ReID, draw the result and write the annotated frame.
ret = True
frame_num = 0
try:
    while True:
        ret, frame = cap.read()
        if not ret:
            break

        # Detect and track people
        results = model.track(frame, persist=True)
        detections = results[0].boxes
        boxes = detections.xyxy.cpu().numpy()  # Bounding boxes
        track_ids = detections.id.cpu().numpy() if detections.id is not None else []

        frame_h, frame_w = frame.shape[:2]

        # Global IDs already held by tracks visible in this frame. A newly
        # appearing track must not be re-identified as one of them: the same
        # player cannot be in two places at once.
        active_global_ids = {
            global_id_map[int(t)] for t in track_ids if int(t) in global_id_map
        }

        # Process each detected person
        for i, box in enumerate(boxes):
            x1, y1, x2, y2 = map(int, box)
            # Clamp to the frame; a negative coordinate would otherwise wrap
            # around when slicing and yield a wrong or empty crop.
            x1, y1 = max(x1, 0), max(y1, 0)
            x2, y2 = min(x2, frame_w), min(y2, frame_h)
            track_id = int(track_ids[i]) if i < len(track_ids) else -1

            # Crop person image
            person_img = frame[y1:y2, x1:x2]
            if person_img.size == 0:
                continue

            if track_id == -1:
                # No tracker ID for this detection: draw it, but do not show
                # an identity that has not actually been allocated.
                label = 'ID: ?'
            else:
                # Extract ReID features
                features = extract_reid_features(person_img)

                if track_id in global_id_map:
                    # The tracker already follows this person; keep the
                    # identity instead of re-matching (which could flip it to
                    # a look-alike or mint a new ID on a low-similarity frame).
                    global_id = global_id_map[track_id]
                else:
                    # New track: try to re-identify against stored features.
                    global_id = match_features(features, frame_num)
                    if global_id in active_global_ids:
                        # Best match is someone currently visible under
                        # another track, so this must be a different person.
                        global_id_counter += 1
                        global_id = global_id_counter
                    global_id_map[track_id] = global_id
                    active_global_ids.add(global_id)

                feature_db[track_id].append((frame_num, features))
                label = f'ID: {global_id}'

            # Draw bounding box and global ID
            cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
            cv2.putText(frame, label, (x1, y1-10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)

        # Write frame to output video
        out.write(frame)
        frame_num += 1
finally:
    # Release resources even if detection/ReID raises mid-video, so the output
    # file is finalized and the capture handle is not leaked.
    cap.release()
    out.release()
    cv2.destroyAllWindows()

Successfully loaded imagenet pretrained weights from "/root/.cache/torch/checkpoints/osnet_x1_0_imagenet.pth"
** The following layers are discarded due to unmatched keys or layer size: ['classifier.weight', 'classifier.bias']

0: 384x640 1 ball, 16 players, 2 referees, 68.8ms
Speed: 2.6ms preprocess, 68.8ms inference, 2.3ms postprocess per image at shape (1, 3, 384, 640)

0: 384x640 16 players, 2 referees, 43.6ms
Speed: 2.2ms preprocess, 43.6ms inference, 1.8ms postprocess per image at shape (1, 3, 384, 640)

0: 384x640 15 players, 2 referees, 43.7ms
Speed: 3.5ms preprocess, 43.7ms inference, 2.0ms postprocess per image at shape (1, 3, 384, 640)

0: 384x640 14 players, 2 referees, 43.7ms
Speed: 2.2ms preprocess, 43.7ms inference, 1.5ms postprocess per image at shape (1, 3, 384, 640)

0: 384x640 14 players, 2 referees, 50.6ms
Speed: 2.2ms preprocess, 50.6ms inference, 1.5ms postprocess per image at shape (1, 3, 384, 640)

0: 384x640 14 players, 2 referees, 52.9ms
Speed: 3.1ms preprocess