<div align="center">
  <a href="http://www.sharif.edu/">
    <img src="https://cdn.freebiesupply.com/logos/large/2x/sharif-logo-png-transparent.png" alt="SUT Logo" width="140">
  </a>
  
  # Sharif University of Technology
  ### Electrical Engineering Department

  ## Signals and Systems
  #### *Final Project - Spring 2025*
</div>

---

<div align="center">
  <h1>
    <b>Object Tracker</b>
  </h1>
  <p>
    An object tracking system that runs YOLO detection periodically and uses classical OpenCV trackers (KCF, CSRT, MOSSE) to follow the detected objects in between.
  </p>
</div>

<br>

| Professor                  |
| :-------------------------: |
| Dr. Mohammad Mehdi Mojahedian |

<br>

| Contributors              |
| :-----------------------: |
| **Amirreza Mousavi** |
| **Mahdi Falahi** |
| **Zahra Miladipour** |

---

# 0: Imports

In [None]:
import time
from queue import Empty, Full, Queue
from threading import Thread

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
from ultralytics import YOLO

# 1: Object Detection

## Preparing Models

In [None]:
# Run the model on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f"ObjectTracker using device: {device}")

# Load the weights from ./yolo11n.pt and move the model to the chosen device.
model = YOLO('./yolo11n.pt').to(device)

## ObjectTracker Class

In [None]:
class ObjectTracker:
    """Detect objects with a YOLO model and follow them with OpenCV trackers.

    The detector runs on every ``detect_interval``-th frame (starting with the
    first one). Each detection gets its own OpenCV legacy tracker. On the
    frames in between, only those trackers are updated.
    """

    def __init__(self, model, tracker_type='KCF', detect_interval=80, conf_threshold=0.5):
        """Configure the tracker.

        Args:
            model: YOLO model. It is called as ``model(frame, verbose=False, conf=...)``
                and its ``names`` mapping translates class ids into labels.
            tracker_type: one of 'CSRT', 'KCF' or 'MOSSE'.
            detect_interval: run the detector once every this many frames.
            conf_threshold: detections with confidence <= this value are dropped.

        Raises:
            ValueError: on an unknown ``tracker_type`` or a ``detect_interval`` below 1.
        """
        # process_frame() takes frame_idx modulo this value, so 0 would crash there.
        if detect_interval < 1:
            raise ValueError(f"detect_interval must be >= 1, got {detect_interval}")

        self.model = model
        self.detect_interval = detect_interval
        self.conf_threshold = conf_threshold
        self.frame_idx = 0
        # Dicts as returned by detect(), plus 'tracker' and 'color' entries.
        self.tracked_objects = []

        self.tracker_constructors = {
            'CSRT': cv2.legacy.TrackerCSRT_create,
            'KCF': cv2.legacy.TrackerKCF_create,
            'MOSSE': cv2.legacy.TrackerMOSSE_create
        }
        if tracker_type not in self.tracker_constructors:
            raise ValueError(f"Invalid tracker type: {tracker_type}. Choose from {list(self.tracker_constructors.keys())}")
        self.tracker_type = tracker_type
        print(f"Using tracker: {self.tracker_type}")

    def process_frame(self, frame):
        """Advance by one frame and return an annotated copy of it.

        On detection frames (``frame_idx % detect_interval == 0``) all trackers
        are rebuilt from fresh detections. On other frames, each existing
        tracker is updated.
        """
        if self.frame_idx % self.detect_interval == 0:
            self._redetect(frame)
        else:
            self._update_trackers(frame)

        self.frame_idx += 1
        return self.draw_boxes(frame, self.tracked_objects)

    def _redetect(self, frame):
        """Run the detector and replace all trackers with one per new detection."""
        detections = self.detect(frame)
        self.tracked_objects = []  # Reset trackers

        for obj_data in detections:
            tracker = self.tracker_constructors[self.tracker_type]()
            # OpenCV trackers take (x, y, width, height) boxes.
            bbox = (obj_data['x1'], obj_data['y1'],
                    obj_data['x2'] - obj_data['x1'], obj_data['y2'] - obj_data['y1'])
            tracker.init(frame, bbox)

            obj_data['tracker'] = tracker
            # Fix the colour for as long as this object is tracked, so the box
            # does not change colour on every frame.
            obj_data['color'] = np.random.uniform(0, 255, 3).tolist()
            self.tracked_objects.append(obj_data)

    def _update_trackers(self, frame):
        """Move each tracked box to its tracker's new estimate for ``frame``."""
        for obj in self.tracked_objects:
            success, bbox = obj['tracker'].update(frame)
            if success:
                x, y, w, h = bbox
                obj['x1'], obj['y1'] = int(x), int(y)
                obj['x2'], obj['y2'] = int(x + w), int(y + h)
            # A failed update leaves the last known box in place until the next detection frame.

    def detect(self, frame):
        """Run the model on ``frame`` and keep detections above ``conf_threshold``.

        Returns:
            A list of dicts with keys 'class_name', 'x1', 'y1', 'x2', 'y2' and 'conf'.
            The coordinates are ints; 'conf' is a float.
        """
        # Pass the threshold to the model as well. Otherwise Ultralytics applies
        # its own default confidence cut-off first, which would hide boxes
        # whenever conf_threshold is set lower than that default.
        results = self.model(frame, verbose=False, conf=self.conf_threshold)[0]
        detections = []
        for box in results.boxes:
            conf = box.conf[0].item()
            if conf > self.conf_threshold:
                class_name = self.model.names[int(box.cls[0].item())]
                coords = box.xyxy[0].tolist()
                detections.append({
                    'class_name': class_name,
                    'x1': int(coords[0]), 'y1': int(coords[1]),
                    'x2': int(coords[2]), 'y2': int(coords[3]),
                    'conf': conf
                })
        return detections

    @staticmethod
    def _text_color(box_color):
        """Return black or white, whichever is more readable on ``box_color`` (BGR)."""
        # Rec. 601 luma weights, listed in B, G, R order to match OpenCV colours.
        brightness = float(np.dot(box_color, [0.114, 0.587, 0.299]))
        return (0, 0, 0) if brightness > 128 else (255, 255, 255)

    @staticmethod
    def _label_band(y1, text_h):
        """Return the (top, bottom) y-coordinates of a box's label background.

        The band normally sits directly above the box. If that would start
        above the frame (the box touches the top edge), it is placed just
        inside the box instead.
        """
        band_h = text_h + 8
        if y1 - band_h >= 0:
            return y1 - band_h, y1
        return y1, y1 + band_h

    def draw_boxes(self, frame, objects, boxes_color='random'):
        """Return a copy of ``frame`` with a labelled box drawn for each object.

        Args:
            frame: BGR image. It is not modified.
            objects: dicts with 'x1', 'y1', 'x2', 'y2', 'class_name' and 'conf'.
                An optional 'color' entry fixes that object's colour.
            boxes_color: a BGR colour used for every box, or 'random' to use each
                object's stored 'color' (or a fresh random colour if it has none).
        """
        frame_copy = frame.copy()
        font = cv2.FONT_HERSHEY_SIMPLEX
        use_random = isinstance(boxes_color, str) and boxes_color == 'random'

        for obj in objects:
            x1, y1, x2, y2 = obj['x1'], obj['y1'], obj['x2'], obj['y2']
            # Use double quotes outside: reusing ' inside the f-string is a
            # SyntaxError before Python 3.12.
            label = f"{obj['class_name']} {obj['conf']:.2f}"
            if use_random:
                box_color = obj.get('color') or np.random.uniform(0, 255, 3).tolist()
            else:
                box_color = boxes_color

            # Scale the label with the box width so it stays legible on large boxes.
            box_w = x2 - x1
            font_scale = max(0.5, box_w / 200)
            font_thickness = max(1, int(box_w / 150))
            (tw, th), _ = cv2.getTextSize(label, font, font_scale, font_thickness)

            text_color = self._text_color(box_color)
            band_top, band_bottom = self._label_band(y1, th)

            cv2.rectangle(frame_copy, (x1, band_top), (x1 + tw, band_bottom), box_color, -1)
            cv2.rectangle(frame_copy, (x1, y1), (x2, y2), box_color, 3)
            cv2.putText(frame_copy, label, (x1, band_bottom - 4), font, font_scale,
                        text_color, font_thickness, cv2.LINE_AA)

        return frame_copy

    def plot_image(self, frame, size_mult=1.0, frame_title=False, axis=False):
        """Show a BGR ``frame`` with matplotlib.

        The figure is sized at 1 inch per 100 pixels, times ``size_mult``.
        ``frame_title`` is used as the title unless it is False. ``axis`` is
        passed to ``plt.axis``.
        """
        h, w = frame.shape[:2]
        base_figsize = (w / 100, h / 100)
        fig_w, fig_h = base_figsize[0]*size_mult, base_figsize[1]*size_mult
        plt.figure(figsize=(fig_w, fig_h))
        # matplotlib expects RGB channel order; OpenCV images are BGR.
        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        plt.imshow(frame_rgb)
        plt.axis(axis)
        if frame_title != False:
            plt.title(frame_title)
        plt.show()

## Sample Detection Test

In [None]:
test_img_path = './assets/images/0001.jpg'
test_img = cv2.imread(test_img_path)
# cv2.imread returns None instead of raising when the file is missing or unreadable,
# which would otherwise surface as an obscure error inside the model call.
if test_img is None:
    raise FileNotFoundError(f"Could not read test image: {test_img_path}")

test_tracker = ObjectTracker(model)
test_img_detections = test_tracker.detect(test_img)
test_tracker.plot_image(test_tracker.draw_boxes(test_img, test_img_detections))

# 2: Tracking

## Video Playback

In [None]:
class VideoReader:
    """Read frames from a video source on a background thread.

    Frames are pushed into a small bounded queue, so decoding can run ahead
    of, and concurrently with, processing/display in the caller's thread.
    """

    def __init__(self, source_path):
        """Open ``source_path`` with ``cv2.VideoCapture``.

        The reader thread is created here but only started by ``start()``.

        Raises:
            IOError: if OpenCV cannot open the source.
        """
        self.stream = cv2.VideoCapture(source_path)
        if not self.stream.isOpened():
            raise IOError(f"Could not open video source: {source_path}")
        self.q = Queue(maxsize=2)
        # Set by stop() to request shutdown, and by the reader thread itself
        # once the stream has no more frames.
        self.stopped = False
        self.thread = Thread(target=self.update, args=(), daemon=True)

    def start(self):
        """Start the background reader thread."""
        self.thread.start()

    def update(self):
        """Reader-thread loop: decode frames and queue ``(frame, frame_pos)`` pairs.

        ``frame_pos`` is the capture's CAP_PROP_POS_FRAMES value right after
        the frame was read. The loop ends when the stream runs out of frames
        or when stop() is called.
        """
        try:
            while not self.stopped:
                ret, frame = self.stream.read()
                if not ret:
                    break
                frame_pos = self.stream.get(cv2.CAP_PROP_POS_FRAMES)
                if not self._put((frame, frame_pos)):
                    break
        finally:
            # Set only after the last frame is queued, so read() can drain the
            # queue safely once it sees this flag. Do not call stop() here:
            # it joins self.thread, and a thread cannot join itself
            # (RuntimeError).
            self.stopped = True

    def _put(self, item):
        """Queue ``item``, blocking while the queue is full.

        Wakes every 0.1 s to notice a concurrent stop(), instead of
        busy-spinning on ``q.full()``. Returns False if stopped before the
        item could be queued.
        """
        while not self.stopped:
            try:
                self.q.put(item, timeout=0.1)
                return True
            except Full:
                continue
        return False

    def read(self, timeout=1.0):
        """Return the next ``(frame, frame_pos)`` pair, or ``(None, None)`` if none is available.

        Returns ``(None, None)`` as soon as the reader has finished and the
        queue is drained. It also returns ``(None, None)`` after ``timeout``
        seconds without a new frame.
        """
        deadline = time.monotonic() + timeout
        while True:
            try:
                return self.q.get(timeout=0.05)
            except Empty:
                if self.stopped:
                    # The reader sets `stopped` only after its last put, so any
                    # remaining frame is already visible here.
                    try:
                        return self.q.get_nowait()
                    except Empty:
                        return None, None
                if time.monotonic() >= deadline:
                    return None, None

    def stop(self):
        """Stop the reader thread (waiting up to 0.5 s) and release the capture."""
        self.stopped = True
        if self.thread.is_alive():
            self.thread.join(timeout=0.5)
        self.stream.release()


class VideoPlayer:
    """Play a video in an OpenCV window, optionally running a tracker on every frame."""

    def __init__(self, source, size_multiplier=1.0, playback_speed=1.0, window_title="Video Playback"):
        """Open ``source`` and create a resizable display window.

        Args:
            source: anything ``cv2.VideoCapture`` accepts (passed to VideoReader).
            size_multiplier: initial window size relative to the frame resolution.
            playback_speed: multiplier on the original FPS (2.0 plays twice as fast).
            window_title: title of the OpenCV window.

        Raises:
            ValueError: if ``playback_speed`` is not positive.
            IOError: if the source cannot be opened (raised by VideoReader).
        """
        # Validate before opening the stream so nothing is leaked on bad input;
        # the delay computation below divides by this value.
        if playback_speed <= 0:
            raise ValueError(f"playback_speed must be positive, got {playback_speed}")

        self.reader = VideoReader(source)
        self.window_title = window_title

        # Get video properties from the underlying stream
        self.fps = self.reader.stream.get(cv2.CAP_PROP_FPS)
        self.frame_width = int(self.reader.stream.get(cv2.CAP_PROP_FRAME_WIDTH))
        self.frame_height = int(self.reader.stream.get(cv2.CAP_PROP_FRAME_HEIGHT))
        # May be 0 or negative for sources without a known length.
        self.total_frames = self.reader.stream.get(cv2.CAP_PROP_FRAME_COUNT)

        # Target per-frame delay from the original FPS and the requested speed;
        # fall back to ~30 FPS when the source reports no FPS.
        self.target_delay_ms = (1000 / self.fps) / playback_speed if self.fps > 0 else 33

        # Create and resize the display window
        cv2.namedWindow(self.window_title, cv2.WINDOW_NORMAL)
        cv2.resizeWindow(
            self.window_title,
            int(self.frame_width * size_multiplier),
            int(self.frame_height * size_multiplier)
        )
        print("--- Video Player Initialized ---")
        print(f"  Resolution: {self.frame_width}x{self.frame_height}")
        print(f"  Original FPS: {self.fps:.2f}")
        print(f"  Total Frames: {self.total_frames}")
        print(f"  Playback Speed: {playback_speed}x")
        print(f"  Target Delay: {self.target_delay_ms:.2f} ms")
        print("--------------------------------")

    def play(self, tracker=None):
        """Show frames until the stream ends or 'q' is pressed, then release resources.

        Args:
            tracker: object with a ``process_frame(frame)`` method returning the
                frame to display (e.g. ObjectTracker). If None, the raw frames
                are shown.
        """
        print("Starting playback... Press 'q' to quit.")
        self.reader.start()
        last_time = time.perf_counter()

        # try/finally so the reader thread and windows are released even if
        # tracking raises or the user interrupts the loop.
        try:
            while True:
                start_time = time.perf_counter()

                frame, current_frame_pos = self.reader.read()
                # (None, None) means the stream ended or the reader timed out.
                if frame is None:
                    print("Video stream ended.")
                    break

                processed_frame = tracker.process_frame(frame) if tracker is not None else frame

                # Display FPS measured between consecutive frames.
                now = time.perf_counter()
                elapsed = now - last_time
                last_time = now
                fps = 1.0 / elapsed if elapsed > 0 else 0.0
                cv2.putText(processed_frame, f"FPS: {fps:.2f}", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)

                cv2.imshow(self.window_title, processed_frame)

                # Sleep only for what remains of the target frame delay (at
                # least 1 ms, so the window still processes events).
                processing_time_ms = (time.perf_counter() - start_time) * 1000
                real_delay_ms = int(self.target_delay_ms - processing_time_ms)

                wait_key = cv2.waitKey(max(1, real_delay_ms))
                if wait_key & 0xFF == ord('q'):
                    print("Playback stopped by user.")
                    break

                # Only trust the frame count when the source reports one;
                # otherwise the reader's end-of-stream signal ends playback.
                if self.total_frames > 0 and current_frame_pos >= self.total_frames:
                    print('Video Playback finished.')
                    break
        finally:
            self.release()

    def release(self):
        """Stops the reader thread and closes all OpenCV windows."""
        print("Releasing resources...")
        self.reader.stop()
        cv2.destroyAllWindows()
        # Pump the event loop a few times; some GUI backends only close windows
        # after processing pending events.
        for _ in range(5):
            cv2.waitKey(1)

## Test Playback

In [None]:
# You can change these parameters
VIDEO_PATH = './assets/footage/car1.mp4'  # Make sure this path is correct
MODEL_PATH = './yolo11n.pt'  # NOTE: not used below; `model` is loaded in the "Preparing Models" cell
PLAYBACK_SPEED = 1.5  # Play at 1.5x speed
WINDOW_SIZE = 1  # Initial window size as a multiple of the video resolution (1 = 100%)

try:
    tracker = ObjectTracker(model, 'CSRT')
    player = VideoPlayer(
        source=VIDEO_PATH,
        playback_speed=PLAYBACK_SPEED,
        size_multiplier=WINDOW_SIZE,
        window_title="High-Performance Player"
    )
    player.play(tracker)
except IOError as e:
    print(e)
except Exception as e:
    # Print the traceback as well; the message alone does not show where the error happened.
    import traceback
    print(f"An unexpected error occurred: {e}")
    traceback.print_exc()