In [14]:
%pip install torch torchvision opencv-python pillow numpy matplotlib deep-sort-realtime

Note: you may need to restart the kernel to use updated packages.


In [21]:
import os
import cv2
import torch
import torchvision.transforms as transforms
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import numpy as np
from scipy.optimize import linear_sum_assignment

In [20]:
VIDEO_NAME = 'DSC_2414.MOV'
# Build paths with os.path so they work on every OS (a literal backslash is only
# a path separator on Windows); splitext works for any video extension.
video_path = os.path.join("tracking_rukomet", VIDEO_NAME)
output_txt_path = os.path.join(
    "tracking_rukomet",
    "predictions",
    os.path.splitext(VIDEO_NAME)[0] + "_siameseNN.txt",
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# COCO-pretrained Faster R-CNN used as the person detector. Move it to `device`
# so it lives where the input tensors are sent; leaving it on the CPU while the
# inputs go to CUDA raises a device-mismatch error.
detector = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights="DEFAULT")
detector.to(device)
detector.eval()

In [None]:
def get_siamese_network():
    """Build the shared-weight embedding network used to compare person crops.

    Returns:
        An ImageNet-pretrained ResNet-18 whose final ``fc`` layer is replaced by
        ``nn.Identity()``, moved to the global ``device`` and put in eval mode,
        so calling it returns backbone features rather than class logits.
    """
    class SiameseNetwork(nn.Module):
        """Single-branch encoder: every crop is embedded with the same weights."""

        def __init__(self):
            super().__init__()
            # `weights="DEFAULT"` replaces the deprecated `pretrained=True` flag
            # and mirrors how the detector is loaded elsewhere in this notebook.
            self.feature_extractor = models.resnet18(weights="DEFAULT")
            # Drop the classifier head so the pooled features are returned.
            self.feature_extractor.fc = nn.Identity()

        def forward(self, x):
            return self.feature_extractor(x)

    return SiameseNetwork().to(device).eval()

snn = get_siamese_network()

# Preprocessing for person crops fed to the embedder: resize to 224x224, convert
# to a tensor and normalise with the ImageNet channel mean/std.
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

cap = cv2.VideoCapture(video_path)

# os.path.dirname() returns "" when the path has no directory component, and
# os.makedirs("") raises, so only create the directory when there is one.
output_dir = os.path.dirname(output_txt_path)
if output_dir:
    os.makedirs(output_dir, exist_ok=True)

# Tracking state read and updated by match_detections():
# person_id -> latest embedding, plus the next id to hand out.
trackers = {} 
next_person_id = 0

def detect_people(frame, score_threshold=0.6):
    """Detect people in one video frame with the Faster R-CNN ``detector``.

    Args:
        frame: BGR image as returned by ``cv2.VideoCapture.read()``.
        score_threshold: keep detections whose score is strictly greater than this.

    Returns:
        List of ``(x1, y1, x2, y2, score)`` tuples; box corners truncated to ints.
    """
    # OpenCV delivers BGR, but torchvision's detection models expect RGB input.
    rgb = np.ascontiguousarray(frame[:, :, ::-1])
    # Send the input to wherever the detector's weights actually live, so a model
    # that was never moved to `device` does not cause a CPU/CUDA mismatch.
    model_device = next(detector.parameters()).device
    img_tensor = transforms.functional.to_tensor(rgb).unsqueeze(0).to(model_device)
    with torch.no_grad():
        preds = detector(img_tensor)[0]

    boxes = preds["boxes"].cpu().numpy()
    scores = preds["scores"].cpu().numpy()
    labels = preds["labels"].cpu().numpy()

    detections = []
    for box, score, label in zip(boxes, scores, labels):
        # Label 1 is "person" in torchvision's COCO category list.
        if score > score_threshold and label == 1:
            x1, y1, x2, y2 = map(int, box)
            detections.append((x1, y1, x2, y2, score))
    return detections

def extract_embedding(frame, box):
    """Embed the image region inside ``box`` with the Siamese network ``snn``.

    Args:
        frame: BGR image (as read by OpenCV) to crop from.
        box: ``(x1, y1, x2, y2)`` pixel corners of the region.

    Returns:
        1-D numpy array: the flattened network output for the crop.
    """
    height, width = frame.shape[:2]
    x1, y1, x2, y2 = (int(v) for v in box)
    # Clamp to the frame and keep at least one pixel per side; an empty crop
    # would make ToPILImage/Resize fail.
    x1 = min(max(x1, 0), width - 1)
    y1 = min(max(y1, 0), height - 1)
    x2 = min(max(x2, x1 + 1), width)
    y2 = min(max(y2, y1 + 1), height)
    # OpenCV frames are BGR; the ImageNet normalisation in `transform` assumes RGB.
    crop = np.ascontiguousarray(frame[y1:y2, x1:x2, ::-1])
    img_tensor = transform(crop).unsqueeze(0).to(device)
    with torch.no_grad():
        embedding = snn(img_tensor).cpu().numpy()
    return embedding.flatten()

def _cosine_cost_matrix(prev_embeddings, curr_embeddings, eps=1e-12):
    """Return ``1 - cosine_similarity`` for every (previous, current) embedding pair."""
    prev = np.asarray(prev_embeddings, dtype=np.float64)
    curr = np.asarray(curr_embeddings, dtype=np.float64)
    # Normalise rows once; eps keeps an all-zero embedding from dividing by zero.
    prev = prev / np.maximum(np.linalg.norm(prev, axis=1, keepdims=True), eps)
    curr = curr / np.maximum(np.linalg.norm(curr, axis=1, keepdims=True), eps)
    return 1.0 - prev @ curr.T


def match_detections(previous_trackers, current_detections, frame, max_cost=0.5):
    """Assign track ids to the current detections by appearance similarity.

    Each detection is embedded once. Previous tracks and detections are paired
    with the Hungarian algorithm on a cosine-distance cost matrix. Pairs cheaper
    than ``max_cost`` keep the previous id; every other detection gets a fresh id
    from the global ``next_person_id``.

    Side effect: the embedding of every yielded id is written into the global
    ``trackers`` dict (the caller reads it back from there).

    Args:
        previous_trackers: mapping person_id -> embedding from earlier frames.
        current_detections: sequence of ``(x1, y1, x2, y2, score)`` tuples.
        frame: image the detections were found in (passed to extract_embedding).
        max_cost: a match is accepted only if its cosine distance is below this.

    Yields:
        ``(person_id, detection)`` pairs: matched detections first, then new ones.
    """
    global next_person_id

    current_embeddings = [extract_embedding(frame, det[:4]) for det in current_detections]
    if not current_embeddings:
        return

    matched = set()
    if previous_trackers:
        # Snapshot ids/embeddings before `trackers` is updated below.
        prev_ids = list(previous_trackers.keys())
        cost_matrix = _cosine_cost_matrix(list(previous_trackers.values()), current_embeddings)
        row_ind, col_ind = linear_sum_assignment(cost_matrix)
        for i, j in zip(row_ind, col_ind):
            if cost_matrix[i, j] < max_cost:
                person_id = prev_ids[i]
                trackers[person_id] = current_embeddings[j]
                matched.add(j)
                yield person_id, current_detections[j]

    # Unmatched detections (or all of them on the first frame) start new tracks;
    # reuse the embedding computed above instead of extracting it again.
    for j, det in enumerate(current_detections):
        if j in matched:
            continue
        person_id = next_person_id
        next_person_id += 1
        trackers[person_id] = current_embeddings[j]
        yield person_id, det

# Reset tracking state so re-running this cell alone reproduces a fresh run
# instead of continuing with stale embeddings and ids from a previous run.
trackers = {}
next_person_id = 0

# `cap` is released at the end of every run; reopen it if this cell is re-run.
if not cap.isOpened():
    cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
    raise OSError(f"Could not open video file: {video_path}")

try:
    with open(output_txt_path, "w") as f:
        frame_id = 0

        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            frame_id += 1
            detections = detect_people(frame)

            # Keep only identities yielded for this frame; match_detections writes
            # each yielded id's embedding into the global `trackers`.
            new_trackers = {}
            for person_id, (x1, y1, x2, y2, score) in match_detections(trackers, detections, frame):
                w, h = x2 - x1, y2 - y1
                new_trackers[person_id] = trackers[person_id]

                cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
                cv2.putText(frame, str(person_id), (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)

                # Row layout: frame, id, left, top, width, height, 1, -1, -1, -1
                # (appears to follow the MOTChallenge text format — confirm with the evaluator).
                f.write(f"{frame_id},{person_id},{x1},{y1},{w},{h},1,-1,-1,-1\n")

            trackers = new_trackers
            cv2.imshow("Siamese Tracker", frame)

            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
finally:
    # Release the capture and close the preview window even if tracking raises.
    cap.release()
    cv2.destroyAllWindows()

print(f"Tracking results saved to: {output_txt_path}")
