SORT (Simple Online and Realtime Tracking) Tracker

Importing Libraries

In [None]:
import cv2
import numpy as np
from sort import Sort  
import torchvision
import torch
from torchvision import transforms
from PIL import Image
import random

Load the detection model and define the lookup dictionaries

In [None]:
# Load a COCO-pretrained Faster R-CNN (ResNet-50 FPN) detector. The weights enum
# is passed explicitly: ``weights=True`` is only accepted through torchvision's
# deprecated ``pretrained`` compatibility shim and triggers a warning.
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(
    weights=torchvision.models.detection.FasterRCNN_ResNet50_FPN_Weights.DEFAULT
)
# Eval mode: the detector returns predictions instead of training losses.
model.eval()

# Convert a PIL image into a CHW float tensor scaled to [0, 1], the input
# format torchvision detection models expect.
transform = transforms.Compose([
    transforms.ToTensor(),
])

# Track ID -> drawing color (filled lazily by get_color_for_id).
id_color_map = {}
# Track ID -> list of past bounding-box centre points.
trajectories = {}


Get a consistent color for each track ID

In [None]:
def get_color_for_id(track_id):
    """Return the drawing color for ``track_id``, creating it on first request.

    Colors are cached in the module-level ``id_color_map`` so that a track keeps
    the same color across frames. A newly created color is a list of three
    independent random integers in [0, 255].
    """
    try:
        return id_color_map[track_id]
    except KeyError:
        fresh_color = [random.randint(0, 255) for _ in range(3)]
        id_color_map[track_id] = fresh_color
        return fresh_color

Detect Objects in a Frame

In [None]:
def detect_objects_in_frame(frame, score_threshold=0.5, target_labels=(3, 10)):
    """Run the detector on one video frame and keep only the selected classes.

    Args:
        frame: Image as returned by ``cv2.VideoCapture.read`` (BGR channel order).
        score_threshold: Detections whose score is not strictly greater than
            this value are discarded. Defaults to 0.5.
        target_labels: Label ids to keep. The default ``(3, 10)`` is car and
            traffic light in torchvision's COCO label mapping.

    Returns:
        ``(boxes, confidences)``, two parallel lists. ``boxes`` holds integer
        ``[xmin, ymin, xmax, ymax]`` lists and ``confidences`` holds float scores.
    """
    # OpenCV frames are BGR; torchvision detection models expect RGB input.
    img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    img_tensor = transform(img)

    # Detection models take a list of 3-D image tensors and return one
    # prediction dict per image.
    with torch.no_grad():
        predictions = model([img_tensor])[0]

    wanted_labels = set(target_labels)
    boxes = []
    confidences = []
    for box, score, label in zip(predictions['boxes'], predictions['scores'], predictions['labels']):
        if score.item() > score_threshold and label.item() in wanted_labels:
            # int() truncates the float corner coordinates toward zero.
            boxes.append(box.int().tolist())
            confidences.append(score.item())

    return boxes, confidences

In [None]:
def _draw_track(frame, box, track_id, trail):
    """Draw one track's bounding box, ID label and trail onto ``frame`` in place.

    Args:
        frame: Image array OpenCV draws on (modified in place).
        box: Integer ``(x1, y1, x2, y2)`` corners of the track's box.
        track_id: Integer track ID; it also selects the drawing color.
        trail: Ordered list of integer ``(x, y)`` points, oldest first.
    """
    x1, y1, x2, y2 = box
    color = get_color_for_id(track_id)
    cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
    cv2.putText(frame, f'ID {track_id}', (x1, y1 - 10),
                cv2.FONT_HERSHEY_SIMPLEX, 0.75, color, 2)
    # Join each pair of consecutive centre points with a line segment.
    for start, end in zip(trail, trail[1:]):
        cv2.line(frame, start, end, color, 2)


def process_video(input_video_path, output_video_path):
    """Detect, track and annotate objects in a video and write an annotated copy.

    Each frame goes through ``detect_objects_in_frame``. The detections are fed
    to a SORT tracker, and every track SORT reports is drawn with its box, its
    ID and the path of its box centres. A track's path is discarded as soon as
    SORT stops reporting that track.

    Args:
        input_video_path: Path of the video to read.
        output_video_path: Path of the annotated video to write ('mp4v' codec).

    Returns:
        None. Prints an error and returns early if the input or output video
        cannot be opened.
    """
    cap = cv2.VideoCapture(input_video_path)
    if not cap.isOpened():
        print("Error opening video file.")
        return

    fps = cap.get(cv2.CAP_PROP_FPS)
    if fps <= 0:
        # Some containers/streams report 0 FPS, which can keep VideoWriter from
        # opening. NOTE(review): 30 is an assumed default, confirm for your inputs.
        fps = 30.0
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

    # 'mp4v' (MPEG-4 Part 2) is a codec suitable for .mp4 output files.
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_video_path, fourcc, fps, (width, height))
    if not out.isOpened():
        print("Error opening output video file.")
        cap.release()
        return

    tracker = Sort()
    # Track ID -> list of past box centres. Local, so repeated calls don't share state.
    trajectories = {}

    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    frame_count = 0

    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            frame_count += 1

            boxes, confidences = detect_objects_in_frame(frame)
            # SORT expects an (N, 5) array of [x1, y1, x2, y2, score]. The reshape
            # keeps the (0, 5) shape SORT requires on frames with no detections.
            detections = np.array(
                [[*box, conf] for box, conf in zip(boxes, confidences)],
                dtype=float,
            ).reshape(-1, 5)
            tracks = tracker.update(detections)

            active_ids = set()
            for x1, y1, x2, y2, raw_id in tracks:
                track_id = int(raw_id)
                active_ids.add(track_id)
                box = (int(x1), int(y1), int(x2), int(y2))
                center = (box[0] + (box[2] - box[0]) // 2,
                          box[1] + (box[3] - box[1]) // 2)
                trajectories.setdefault(track_id, []).append(center)
                _draw_track(frame, box, track_id, trajectories[track_id])

            # Drop the trail of any track SORT did not report in this frame.
            for stale_id in trajectories.keys() - active_ids:
                del trajectories[stale_id]

            out.write(frame)
            if frame_count % 5 == 0:
                print(f'Processing frame {frame_count}/{total_frames}')
    finally:
        cap.release()
        out.release()

In [None]:
# Run the pipeline only when executed as the main program. A notebook kernel also
# runs as "__main__", so importing this code elsewhere won't start processing.
if __name__ == "__main__":
    input_video_path = 'FaisalTown.mp4'
    output_video_path = 'output_trajectory.mp4'
    process_video(input_video_path, output_video_path)