In [6]:
# Install the notebook's dependencies into the running kernel's environment.
# NOTE(review): versions are unpinned — pin them (pkg==x.y.z) so Restart & Run All is reproducible.
# NOTE(review): deep-sort-realtime and matplotlib are not imported in the cells shown here; drop them if unused.
%pip install torch torchvision opencv-python pillow numpy matplotlib deep-sort-realtime

Note: you may need to restart the kernel to use updated packages.


In [7]:
import os
import cv2
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.models as models
import torchvision
import numpy as np
from PIL import Image

In [8]:
# Input video and the text file its per-frame detections are written to.
# os.path.join keeps the paths portable: hard-coded "\" separators only work on Windows
# (elsewhere they become part of the file name). splitext swaps the extension whatever
# its case, instead of relying on the literal ".MOV".
VIDEO_NAME = 'DSC_2411.MOV'
DATA_DIR = "tracking_rukomet"
video_path = os.path.join(DATA_DIR, VIDEO_NAME)
output_txt_path = os.path.join(
    DATA_DIR,
    "predictions",
    os.path.splitext(VIDEO_NAME)[0] + "_siamese.txt",
)

# COCO-pretrained Faster R-CNN (ResNet-50 FPN) used as the person detector.
# Module.eval() returns the module itself, so inference mode is switched on in the same expression.
detector = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights="DEFAULT").eval()

class SiameseNetwork(nn.Module):
    """Embedding network: an ImageNet-pretrained ResNet-18 with its classifier removed.

    ``forward`` returns the backbone's features (the input to the original ``fc``
    layer) instead of class logits.
    """

    def __init__(self):
        super().__init__()
        # ``weights="DEFAULT"`` replaces the deprecated ``pretrained=True`` (for ResNet-18
        # both resolve to the IMAGENET1K_V1 checkpoint) and matches how the detector is built.
        self.feature_extractor = models.resnet18(weights="DEFAULT")
        # Replace the classification head with a pass-through so forward() yields features.
        self.feature_extractor.fc = nn.Identity()

    def forward(self, x):
        """Return the feature embedding for a batch of images ``x``."""
        return self.feature_extractor(x)

# Use the GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Embedding network on the chosen device, in inference mode (.to/.eval return the module).
snn = SiameseNetwork()
snn = snn.to(device)
snn.eval()

In [None]:
# Preprocessing for the embedding network: array -> PIL image -> 224x224 -> float tensor,
# then per-channel normalisation with the standard ImageNet mean/std.
# NOTE(review): ToPILImage treats a 3-channel array as RGB, so crops taken from OpenCV
# frames (BGR) should be converted before being passed in — confirm against callers.
transform = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

def get_embedding(image):
    """Return the embedding produced by ``snn`` for a single image.

    ``image`` goes through ``transform`` (so it must be something ``ToPILImage``
    accepts), is wrapped as a batch of one, runs on ``device`` without gradient
    tracking, and the result is returned as a NumPy array on the CPU whose first
    axis is that batch of one.
    """
    batch = transform(image)[None].to(device)  # [None] adds the leading batch axis
    with torch.no_grad():
        features = snn(batch)
    return features.cpu().numpy()

# Open the input video. Fail loudly here: otherwise the detection loop below never runs
# and an empty predictions file is written with no indication that anything went wrong.
cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
    cap.release()
    raise OSError(f"Could not open video: {video_path}")

# Create the predictions directory before the output file is opened for writing.
output_dir = os.path.dirname(output_txt_path)
if output_dir:  # os.makedirs("") raises, so skip when the path has no directory part
    os.makedirs(output_dir, exist_ok=True)

def detect_people(frame, score_threshold=0.6, person_label=1, model=None):
    """Run the person detector on one video frame.

    Args:
        frame: HxWx3 uint8 image in OpenCV's BGR channel order (as returned by
            ``cv2.VideoCapture.read``).
        score_threshold: keep detections whose score is strictly greater than this.
        person_label: class id to keep (1 is "person" in the COCO label map).
        model: detector to call; defaults to the module-level ``detector``. It is called
            with a (1, 3, H, W) float tensor in [0, 1] and must return a list whose first
            element is a dict with "boxes", "scores" and "labels" tensors.

    Returns:
        List of ``(x1, y1, x2, y2, score)`` tuples with integer pixel corners.
    """
    if model is None:
        model = detector

    # OpenCV decodes frames as BGR, but the torchvision detector expects RGB input.
    rgb = frame[..., ::-1]
    # Same conversion torchvision's to_tensor does for uint8 arrays: HWC -> CHW, scaled to [0, 1].
    # ascontiguousarray is needed because torch.from_numpy rejects the negative stride of [..., ::-1].
    chw = np.ascontiguousarray(rgb.transpose(2, 0, 1))
    img_tensor = torch.from_numpy(chw).float().div(255.0).unsqueeze(0)

    # Send the input to wherever the model's weights live (CPU unless the model was moved).
    try:
        model_device = next(model.parameters()).device
    except (AttributeError, StopIteration):
        model_device = torch.device("cpu")
    img_tensor = img_tensor.to(model_device)

    with torch.no_grad():
        preds = model(img_tensor)[0]

    boxes = preds["boxes"].cpu().numpy()
    scores = preds["scores"].cpu().numpy()
    labels = preds["labels"].cpu().numpy()

    detections = []
    for box, score, label in zip(boxes, scores, labels):
        if score > score_threshold and label == person_label:
            x1, y1, x2, y2 = map(int, box)
            detections.append((x1, y1, x2, y2, score))

    return detections

# Detection loop: run the person detector on every frame, draw the boxes in a preview
# window and write one line per detection to the predictions file.
# try/finally makes sure the capture and the preview window are released even if
# detection or drawing raises part-way through the video.
try:
    with open(output_txt_path, "w") as f:
        frame_id = 0  # incremented before use, so frame numbers start at 1

        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:  # no more frames could be read
                break

            frame_id += 1
            detections = detect_people(frame)

            for x1, y1, x2, y2, score in detections:
                w, h = x2 - x1, y2 - y1
                cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)

                # Row layout: frame, id, left, top, width, height, conf, x, y, z
                # (MOTChallenge-style text format). No tracker runs here, so id is -1;
                # conf is written as a constant 1 (the detector score is not recorded)
                # and the trailing world coordinates are -1.
                f.write(f"{frame_id},-1,{x1},{y1},{w},{h},1,-1,-1,-1\n")

            # Show the annotated frame; pressing 'q' stops processing early.
            cv2.imshow("Siamese Tracker", frame)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
finally:
    cap.release()
    cv2.destroyAllWindows()

print(f"Bounding boxes saved to: {output_txt_path}")