<div align="center">
  <a href="http://www.sharif.edu/">
    <img src="https://cdn.freebiesupply.com/logos/large/2x/sharif-logo-png-transparent.png" alt="SUT Logo" width="140">
  </a>
  
  # Sharif University of Technology
  ### Electrical Engineering Department

  ## Signals and Systems
  #### *Final Project - Spring 2025*
</div>

---

<div align="center">
  <h1>
    <b>Object Tracker</b>
  </h1>
  <p>
    An interactive object-tracking system that uses YOLO for detection, OpenCV trackers (CSRT, KCF, MOSSE, MedianFlow) for frame-to-frame tracking, and a Kalman filter to smooth the tracked boxes.
  </p>
</div>

<br>

| Professor                  |
| :-------------------------: |
| Dr. Mohammad Mehdi Mojahedian |

<br>

| Contributors              |
| :-----------------------: |
| **Amirreza Mousavi** |
| **Mahdi Falahi** |
| **Zahra Miladipour** |

---

# 0: Imports

In [None]:
import cv2
import matplotlib.pyplot as plt
import numpy as np
from ultralytics import YOLO
import time
import torch
from scipy.optimize import linear_sum_assignment

# 1: Object Detection

## Preparing Models

In [None]:
# Run YOLO on the GPU when PyTorch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f"ObjectTracker using device: {device}")
# Load the detector weights from the working directory and move them to `device`.
model = YOLO('./yolo11n.pt').to(device)

## ObjectTracker Class

In [None]:
class ObjectTracker:
    """Track objects across video frames with an OpenCV tracker plus a Kalman filter.

    Each tracked object is a dict holding a ``cv2.legacy`` tracker, a
    constant-velocity Kalman filter over ``(cx, cy, w, h)``, a pixel bbox
    ``(x1, y1, x2, y2)`` and a ``lost_frames`` counter. Tracks are created only
    through :meth:`add_manual_track`; detections from ``model`` are used to
    re-anchor existing tracks, not to spawn new ones.

    Args:
        model: detector, called as ``model(frame, verbose=False)`` and expected to
            expose a ``names`` mapping (as an ultralytics ``YOLO`` model does).
        tracker_type: one of ``'CSRT'``, ``'KCF'``, ``'MOSSE'``, ``'MEDIAN_FLOW'``.
        detect_interval: run detection every this many frames (and also on any
            frame where a tracker update fails).
        conf_threshold: detections must have confidence strictly above this.
        iou_threshold: minimum IoU for a detection to re-anchor a track.
        max_lost_frames: a track is dropped once ``lost_frames`` reaches this.
        max_objects: stored but currently unused; kept for API compatibility.
        use_kalman: stored but currently unused (the Kalman filter is always
            applied); kept for API compatibility.

    Raises:
        ValueError: if ``tracker_type`` is not one of the supported names.
    """

    def __init__(self, model, tracker_type='KCF', detect_interval=48, conf_threshold=0.5, iou_threshold=0.7, max_lost_frames=5, max_objects=10, use_kalman=True):
        self.model = model
        self.detect_interval = detect_interval
        self.conf_threshold = conf_threshold
        self.iou_threshold = iou_threshold
        self.max_lost_frames = max_lost_frames
        self.max_objects = max_objects  # not read by any phase yet
        self.use_kalman = use_kalman  # not read: the Kalman filter is always used
        self.frame_idx = 0
        self.tracked_objects = []
        self.next_track_id = 0

        # Factories for the supported OpenCV legacy tracker implementations.
        self.tracker_constructors = {
            'CSRT': cv2.legacy.TrackerCSRT_create,
            'KCF': cv2.legacy.TrackerKCF_create,
            'MOSSE': cv2.legacy.TrackerMOSSE_create,
            'MEDIAN_FLOW': cv2.legacy.TrackerMedianFlow_create
        }
        if tracker_type not in self.tracker_constructors:
            raise ValueError(f"Invalid tracker type: {tracker_type}. Choose from {list(self.tracker_constructors.keys())}")
        self.tracker_type = tracker_type
        print(f"Using tracker: {self.tracker_type}")

    def process_frame(self, frame):
        """Run one predict -> update -> cleanup/detect -> draw cycle.

        Args:
            frame: the current BGR image.

        Returns:
            An annotated copy of ``frame``.
        """
        height, width = frame.shape[:2]

        self._predict_phase(height, width)
        lost_track_detected = self._update_phase(frame)
        self._cleanup_and_detect_phase(frame, self.iou_threshold, self.max_lost_frames, lost_track_detected, self.max_objects)
        annotated_frame = self._drawing_phase(frame)

        self.frame_idx += 1
        return annotated_frame

    @staticmethod
    def _kalman_box(kf):
        """Return the filter's current ``(cx, cy, w, h)`` estimate as Python floats."""
        # statePost is an (8, 1) column; flattening avoids calling int() on
        # 1-element arrays, which NumPy >= 1.25 deprecates.
        cx, cy, w, h = np.asarray(kf.statePost, dtype=np.float64).ravel()[:4]
        return float(cx), float(cy), float(w), float(h)

    def _predict_phase(self, height, width):
        """Advance every track's Kalman filter one step and refresh its bbox.

        A track whose predicted box lies entirely outside the frame is marked for
        removal (``lost_frames = max_lost_frames``). A track that is less than 10%
        visible gets one more lost frame. The stored bbox is clipped to the frame.
        """
        for obj in self.tracked_objects:
            # cv2.KalmanFilter.predict() also copies the prediction into statePost.
            obj['kf'].predict()
            cx, cy, w, h = self._kalman_box(obj['kf'])
            x1, y1 = int(cx - w / 2), int(cy - h / 2)
            x2, y2 = int(cx + w / 2), int(cy + h / 2)

            if x2 < 0 or y2 < 0 or x1 > width or y1 > height:
                obj['lost_frames'] = self.max_lost_frames
            else:
                visible_w = max(0, min(x2, width) - max(x1, 0))
                visible_h = max(0, min(y2, height) - max(y1, 0))
                if visible_w * visible_h < 0.1 * (x2 - x1) * (y2 - y1):
                    obj['lost_frames'] += 1
            obj['bbox'] = (max(0, x1), max(0, y1), min(width, x2), min(height, y2))

    def _update_phase(self, frame):
        """Update every visual tracker on ``frame`` and feed successes to the Kalman filter.

        Returns:
            bool: True if at least one tracker failed to update on this frame.
        """
        lost_track_detected = False
        for obj in self.tracked_objects:
            success, bbox = obj['tracker'].update(frame)
            if not success:
                obj['lost_frames'] += 1
                lost_track_detected = True
                continue
            # OpenCV trackers report (x, y, w, h).
            x1, y1, w, h = (int(v) for v in bbox)
            measurement = np.array([x1 + w / 2, y1 + h / 2, w, h], dtype=np.float32)
            obj['kf'].correct(measurement)
            obj['bbox'] = (x1, y1, x1 + w, y1 + h)
            obj['lost_frames'] = 0
        return lost_track_detected

    def _cleanup_and_detect_phase(self, frame, iou_threshold, max_lost_frames, lost_track_detected, max_objects):
        """Drop lost tracks, then optionally re-anchor the rest on fresh detections.

        Detection runs when a tracker failed this frame or every ``detect_interval``
        frames. Each track matched to a detection (Hungarian matching on IoU, at
        least ``iou_threshold``) gets a Kalman correction, a freshly initialised
        tracker and a reset lost counter. Unmatched detections are ignored.
        ``max_objects`` is accepted but not used.
        """
        self.tracked_objects = [t for t in self.tracked_objects if t['lost_frames'] < max_lost_frames]

        if not (lost_track_detected or self.frame_idx % self.detect_interval == 0):
            return

        detections = self.detect(frame)
        if not detections:
            return

        tracked_bboxes = [t['bbox'] for t in self.tracked_objects]
        detected_bboxes = [[d['x1'], d['y1'], d['x2'], d['y2']] for d in detections]

        iou_matrix = self._calculate_iou(tracked_bboxes, detected_bboxes)
        matched_pairs, _, _ = self._apply_matching(iou_matrix, iou_threshold)

        for t_idx, d_idx in matched_pairs:
            track = self.tracked_objects[t_idx]
            det = detections[d_idx]
            x1, y1, x2, y2 = det['x1'], det['y1'], det['x2'], det['y2']
            w, h = x2 - x1, y2 - y1
            measurement = np.array([x1 + w / 2, y1 + h / 2, w, h], dtype=np.float32)
            track['kf'].correct(measurement)
            if w > 0 and h > 0:
                # cv2.legacy trackers ignore init() once initialised (it returns
                # False), so re-anchoring needs a brand-new tracker instance.
                fresh_tracker = self.tracker_constructors[self.tracker_type]()
                # Keep the old tracker if the new one rejects the box.
                if fresh_tracker.init(frame, (x1, y1, w, h)) is not False:
                    track['tracker'] = fresh_tracker
            track['bbox'] = (x1, y1, x2, y2)
            track['lost_frames'] = 0

    def _drawing_phase(self, frame):
        """Return a copy of ``frame`` with each track's Kalman-estimated box and label."""
        objects_to_draw = []
        for obj in self.tracked_objects:
            cx, cy, w, h = self._kalman_box(obj['kf'])
            objects_to_draw.append({
                'x1': int(cx - w / 2), 'y1': int(cy - h / 2),
                'x2': int(cx + w / 2), 'y2': int(cy + h / 2),
                'class_name': f"ID-{obj['id']} {obj['class_name']}",
                'conf': obj.get('conf', 1.0), 'color': obj['color']
            })

        return self.draw_boxes(frame, objects_to_draw)

    def add_manual_track(self, frame, bbox, class_name):
        """Start tracking ``bbox`` (``x1, y1, x2, y2``) on ``frame`` under ``class_name``.

        Boxes with non-positive width or height are skipped with a warning.
        The new track gets the next sequential id and a random colour.
        """
        x1, y1, x2, y2 = [int(c) for c in bbox]
        w, h = x2 - x1, y2 - y1

        if w <= 0 or h <= 0:
            print(f"Warning: Invalid bbox {bbox}. Skipping.")
            return

        print(f"Adding new manual track for '{class_name}' at {(x1, y1, w, h)}")

        new_kf = self._create_kalman_filter()
        # Start at the box centre/size with zero velocity.
        new_kf.statePost = np.array([x1 + w/2, y1 + h/2, w, h, 0, 0, 0, 0], dtype=np.float32)

        tracker = self.tracker_constructors[self.tracker_type]()
        tracker.init(frame, (x1, y1, w, h))

        self.tracked_objects.append({
            'id': self.next_track_id, 'kf': new_kf, 'tracker': tracker,
            'class_name': class_name, 'conf': 1.0,
            'color': np.random.uniform(0, 255, 3).tolist(), 'bbox': (x1, y1, x2, y2),
            'lost_frames': 0
        })
        self.next_track_id += 1

    def detect(self, frame):
        """Run the detector on ``frame``.

        Returns:
            list[dict]: one dict per box with confidence above ``conf_threshold``,
            with keys ``class_name``, ``x1``, ``y1``, ``x2``, ``y2`` (int pixels)
            and ``conf``.
        """
        results = self.model(frame, verbose=False)[0]
        detections = []
        for box in results.boxes:
            conf = box.conf[0].item()
            if conf > self.conf_threshold:
                class_name = self.model.names[int(box.cls[0].item())]
                coords = box.xyxy[0].tolist()
                detections.append({
                    'class_name': class_name,
                    'x1': int(coords[0]), 'y1': int(coords[1]),
                    'x2': int(coords[2]), 'y2': int(coords[3]),
                    'conf': conf
                })
        return detections

    def _calculate_iou(self, bboxes1, bboxes2):
        """Pairwise IoU between two lists of ``(x1, y1, x2, y2)`` boxes.

        Returns:
            np.ndarray of shape ``(len(bboxes1), len(bboxes2))``.
        """
        bboxes1 = np.array(bboxes1)
        bboxes2 = np.array(bboxes2)
        if bboxes1.size == 0 or bboxes2.size == 0:
            return np.empty((len(bboxes1), len(bboxes2)))
        xA = np.maximum(bboxes1[:, 0][:, np.newaxis], bboxes2[:, 0])
        yA = np.maximum(bboxes1[:, 1][:, np.newaxis], bboxes2[:, 1])
        xB = np.minimum(bboxes1[:, 2][:, np.newaxis], bboxes2[:, 2])
        yB = np.minimum(bboxes1[:, 3][:, np.newaxis], bboxes2[:, 3])
        interArea = np.maximum(0, xB - xA) * np.maximum(0, yB - yA)
        boxAArea = (bboxes1[:, 2] - bboxes1[:, 0]) * (bboxes1[:, 3] - bboxes1[:, 1])
        boxBArea = (bboxes2[:, 2] - bboxes2[:, 0]) * (bboxes2[:, 3] - bboxes2[:, 1])
        # The small epsilon guards against division by zero for degenerate boxes.
        iou = interArea / (boxAArea[:, np.newaxis] + boxBArea - interArea + 1e-6)
        return iou

    def _create_kalman_filter(self):
        """Build a constant-velocity Kalman filter.

        State is ``[cx, cy, w, h, vcx, vcy, vw, vh]``; each position-like entry is
        advanced by its velocity every step. Only ``[cx, cy, w, h]`` is measured.
        """
        kf = cv2.KalmanFilter(8, 4)
        kf.transitionMatrix = np.array([[1,0,0,0,1,0,0,0],[0,1,0,0,0,1,0,0],[0,0,1,0,0,0,1,0],[0,0,0,1,0,0,0,1],[0,0,0,0,1,0,0,0],[0,0,0,0,0,1,0,0],[0,0,0,0,0,0,1,0],[0,0,0,0,0,0,0,1]], np.float32)
        kf.measurementMatrix = np.array([[1,0,0,0,0,0,0,0],[0,1,0,0,0,0,0,0],[0,0,1,0,0,0,0,0],[0,0,0,1,0,0,0,0]], np.float32)
        kf.processNoiseCov = np.eye(8, dtype=np.float32) * 0.03
        # Velocities are allowed to change 10x more than positions per step.
        kf.processNoiseCov[4:, 4:] *= 10
        kf.measurementNoiseCov = np.eye(4, dtype=np.float32) * 0.1
        return kf

    def _apply_matching(self, iou_matrix, iou_threshold):
        """Optimally assign tracks (rows) to detections (columns) by IoU.

        Returns:
            tuple: ``(matched_pairs, unmatched_track_indices,
            unmatched_detection_indices)``. Assignments with IoU below
            ``iou_threshold`` count as unmatched on both sides.
        """
        n_tracks, n_detections = iou_matrix.shape
        if iou_matrix.size == 0:
            # With no tracks or no detections, everything on the other side is unmatched.
            return [], list(range(n_tracks)), list(range(n_detections))
        cost_matrix = 1 - iou_matrix
        track_indices, detection_indices = linear_sum_assignment(cost_matrix)
        matched_pairs = []
        unmatched_track_indices = set(range(n_tracks))
        unmatched_detection_indices = set(range(n_detections))
        for t_idx, d_idx in zip(track_indices, detection_indices):
            if iou_matrix[t_idx, d_idx] >= iou_threshold:
                matched_pairs.append((int(t_idx), int(d_idx)))
                unmatched_track_indices.discard(t_idx)
                unmatched_detection_indices.discard(d_idx)
        return matched_pairs, sorted(unmatched_track_indices), sorted(unmatched_detection_indices)

    def draw_boxes(self, frame, objects, default_color='random'):
        """Draw labelled rectangles for ``objects`` on a copy of ``frame``.

        Each object needs ``x1, y1, x2, y2, class_name, conf`` and may carry a
        ``color``; a random colour is used otherwise. ``default_color`` is unused.
        """
        frame_copy = frame.copy()
        for obj in objects:
            x1, y1, x2, y2 = obj['x1'], obj['y1'], obj['x2'], obj['y2']
            label = f"{obj['class_name']} {obj['conf']:.2f}"
            box_color = obj.get('color', np.random.uniform(0, 255, 3).tolist())
            cv2.rectangle(frame_copy, (x1, y1), (x2, y2), box_color, 2)
            cv2.putText(frame_copy, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.7, box_color, 2)
        return frame_copy

## Sample Detection Test

In [None]:
# test_img = cv2.imread('./assets/images/0001.jpg')

# test_tracker = ObjectTracker(model)
# test_img_detections = test_tracker.detect(test_img)
# test_tracker.plot_image(test_tracker.draw_boxes(test_img, test_img_detections))

# 2: Tracking

## VideoPlayer class

In [None]:
class VideoPlayer:
    """Interactive OpenCV player for choosing objects and then tracking them.

    The second frame of the video is frozen in a selection mode. Detections are
    drawn in red and selections in green. Right-click toggles a box between
    the two sets, and a left-drag draws a new box whose class is then picked
    from a numeric menu. Pressing ``c`` hands the selections to the tracker and
    resumes playback, ``space`` re-enters selection mode, ``h`` toggles the help
    overlay and ``q`` quits.

    Args:
        source: anything accepted by ``cv2.VideoCapture``.
        size_multiplier: scale factor for the initial window size.
        playback_speed: playback-rate multiplier used to pace frames while
            tracking (only while processing is fast enough to keep up).
        window_title: name of the OpenCV window.

    Raises:
        IOError: if ``source`` cannot be opened.
    """

    def __init__(self, source, size_multiplier=1.0, playback_speed=1.0, window_title="Video Playback"):
        self.cap = cv2.VideoCapture(source)
        if not self.cap.isOpened():
            # Fail fast instead of silently "playing" zero frames.
            raise IOError(f"Could not open video source: {source}")
        self.window_title = window_title

        self.fps = self.cap.get(cv2.CAP_PROP_FPS)
        self.frame_width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        self.frame_height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        # Per-frame display budget while tracking. Fall back to 33 ms when the
        # source reports no usable FPS (0 or NaN) or the speed is not positive.
        if self.fps > 0 and playback_speed > 0:
            self.target_delay_ms = (1000 / self.fps) / playback_speed
        else:
            self.target_delay_ms = 33

        # --- State and Interaction Management ---
        # States: 'INITIALIZING' -> 'PAUSED_FOR_SELECTION' <-> 'PLAYING'
        self.state = 'INITIALIZING'
        self.selectable_detections = []
        self.user_selections = []
        self.is_drawing_roi = False
        self.roi_start_point = None
        self.roi_end_point = None
        self.new_manual_box = None
        self.show_help = True
        self.YOLO_CLASSES = {
            0: 'person', 1: 'bicycle', 2: 'car', 3: 'motorcycle', 4: 'airplane', 
            5: 'bus', 6: 'train', 7: 'truck', 8: 'boat', 9: 'traffic light',
            10: 'fire hydrant', 11: 'stop sign', 12: 'parking meter', 13: 'bench',
            14: 'bird', 15: 'cat', 16: 'dog', 17: 'horse', 18: 'sheep', 19: 'cow',
            20: 'other' # Added 'other' class
        }

        cv2.namedWindow(self.window_title, cv2.WINDOW_NORMAL)
        cv2.resizeWindow(self.window_title, int(self.frame_width * size_multiplier), int(self.frame_height * size_multiplier))

        cv2.setMouseCallback(self.window_title, self._mouse_callback)
        print("--- Video Player Initialized for Interactive Tracking ---")

    def _mouse_callback(self, event, x, y, flags, param):
        """Handle mouse input. Active only while paused for selection.

        A left-drag stores a new box in ``new_manual_box`` so its class can be
        chosen. A right-click first tries to deselect the most recently added
        selection under the cursor. Otherwise it selects the most recently
        listed detection under it.
        """
        if self.state != 'PAUSED_FOR_SELECTION':
            return

        if event == cv2.EVENT_LBUTTONDOWN:
            self.is_drawing_roi = True
            self.roi_start_point = (x, y)
            self.roi_end_point = (x, y)

        elif event == cv2.EVENT_MOUSEMOVE:
            if self.is_drawing_roi:
                self.roi_end_point = (x, y)

        elif event == cv2.EVENT_LBUTTONUP:
            if self.is_drawing_roi:
                self.is_drawing_roi = False
                self.roi_end_point = (x, y)
                if self.roi_start_point:
                    x1, y1 = self.roi_start_point
                    x2, y2 = self.roi_end_point
                    # Treat clicks and thin drags (<= 5 px in either direction) as accidental.
                    if abs(x2 - x1) > 5 and abs(y2 - y1) > 5:
                        self.new_manual_box = (min(x1, x2), min(y1, y2), max(x1, x2), max(y1, y2))
                self.roi_start_point = None
                self.roi_end_point = None

        elif event == cv2.EVENT_RBUTTONDOWN:
            removed_selection = False
            for i, sel in reversed(list(enumerate(self.user_selections))):
                bbox = sel.get('bbox') or (sel['x1'], sel['y1'], sel['x2'], sel['y2'])
                if bbox[0] < x < bbox[2] and bbox[1] < y < bbox[3]:
                    removed_item = self.user_selections.pop(i)
                    # Only raw detections (which carry 'x1') return to the selectable pool.
                    if 'x1' in removed_item:
                        self.selectable_detections.append(removed_item)
                    removed_selection = True
                    break

            if not removed_selection:
                for i, det in reversed(list(enumerate(self.selectable_detections))):
                    if det['x1'] < x < det['x2'] and det['y1'] < y < det['y2']:
                        self.user_selections.append(self.selectable_detections.pop(i))
                        break

    def _draw_pause_menu(self, frame):
        """Return ``frame`` with a semi-transparent help banner drawn across the top."""
        overlay = frame.copy()
        # Full-width semi-transparent background for the menu
        cv2.rectangle(overlay, (0, 0), (frame.shape[1], 210), (0, 0, 0), -1)
        frame = cv2.addWeighted(overlay, 0.7, frame, 0.3, 0)

        # Menu Title (Larger and Bolder)
        cv2.putText(frame, "PAUSED - SELECTION MODE", (25, 55), cv2.FONT_HERSHEY_TRIPLEX, 1.5, (0, 255, 255), 3)

        # Instructions (Larger and Bolder)
        cv2.putText(frame, "Mouse Controls:", (25, 105), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (255, 255, 255), 2)
        cv2.putText(frame, "- Left-Click & Drag: Draw a new box to track", (35, 135), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 255, 255), 2)
        cv2.putText(frame, "- Right-Click: Select a red box / Deselect a green box", (35, 160), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 255, 255), 2)

        cv2.putText(frame, "Keyboard: C: Confirm | H: Toggle Help | Space: Pause | Q: Quit", (25, 195), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 255, 255), 2)

        return frame

    def _get_numeric_input(self, frame):
        """Block on a full-window class menu until a valid class ID is entered.

        Digits build the ID, Backspace deletes, Enter confirms and Esc cancels.
        The currently typed ID is highlighted in green.

        Returns:
            int | None: a key of ``YOLO_CLASSES``, or ``None`` if Esc was pressed.
        """
        num_input = ""
        while True:
            frame_copy = frame.copy()
            overlay = frame_copy.copy()
            cv2.rectangle(overlay, (0, 0), (frame_copy.shape[1], frame_copy.shape[0]), (0, 0, 0), -1)
            frame_copy = cv2.addWeighted(overlay, 0.85, frame_copy, 0.15, 0)

            # num_input only ever contains digits, so int() cannot fail here.
            current_selection_id = int(num_input) if num_input else -1

            cv2.putText(frame_copy, "Enter Class ID & Press Enter:", (50, 60), cv2.FONT_HERSHEY_TRIPLEX, 1.5, (0, 255, 255), 3)
            y_offset = 120
            for i, name in self.YOLO_CLASSES.items():
                if y_offset < frame.shape[0] - 30:
                    color = (0, 255, 0) if i == current_selection_id else (255, 255, 255)
                    thickness = 3 if i == current_selection_id else 2
                    cv2.putText(frame_copy, f"{i}: {name}", (50, y_offset), cv2.FONT_HERSHEY_SIMPLEX, 1.0, color, thickness)
                    y_offset += 40

            cv2.imshow(self.window_title, frame_copy)

            # Keep only the low byte: some backends set modifier bits in the result.
            key = cv2.waitKey(0) & 0xFF
            if key in (10, 13):  # Enter (LF on some platforms, CR on others)
                if num_input and int(num_input) in self.YOLO_CLASSES:
                    return int(num_input)
                print("Error: Invalid ID. Please try again.")
                num_input = ""
            elif key in (8, 127):  # Backspace (127 on macOS)
                num_input = num_input[:-1]
            elif ord('0') <= key <= ord('9'):
                num_input += chr(key)
            elif key == 27:  # Escape
                return None

    def play(self, tracker):
        """Run the interactive loop until the video ends or ``q`` is pressed.

        While PLAYING, each frame goes through ``tracker.process_frame`` and the
        loop is paced to ``target_delay_ms``. Resources are released on exit,
        including when an exception propagates.
        """
        frame_idx = 0
        temp_frame = None  # last frame read from the capture; reused while paused

        try:
            while True:
                loop_start = time.perf_counter()
                if self.state in ['INITIALIZING', 'PLAYING']:
                    ret, frame = self.cap.read()
                    if not ret:
                        break
                    temp_frame = frame.copy()
                    frame_idx += 1
                else:
                    frame = temp_frame.copy()

                if self.state == 'INITIALIZING' and frame_idx == 2:
                    self.state = 'PAUSED_FOR_SELECTION'
                    self.selectable_detections = tracker.detect(frame)

                elif self.state == 'PLAYING':
                    processed_frame = tracker.process_frame(frame)
                    cv2.imshow(self.window_title, processed_frame)

                elif self.state == 'PAUSED_FOR_SELECTION':
                    display_frame = frame.copy()
                    if self.show_help:
                        display_frame = self._draw_pause_menu(display_frame)

                    for det in self.selectable_detections:
                        cv2.rectangle(display_frame, (det['x1'], det['y1']), (det['x2'], det['y2']), (0, 0, 255), 2)
                    for sel in self.user_selections:
                        bbox = sel.get('bbox') or (sel['x1'], sel['y1'], sel['x2'], sel['y2'])
                        cv2.rectangle(display_frame, (bbox[0], bbox[1]), (bbox[2], bbox[3]), (0, 255, 0), 3)
                    if self.is_drawing_roi and self.roi_start_point and self.roi_end_point:
                        cv2.rectangle(display_frame, self.roi_start_point, self.roi_end_point, (255, 255, 0), 2)

                    cv2.imshow(self.window_title, display_frame)

                    if self.new_manual_box:
                        class_id = self._get_numeric_input(display_frame)
                        if class_id is not None:
                            self.user_selections.append({'bbox': self.new_manual_box, 'class_name': self.YOLO_CLASSES[class_id]})
                        self.new_manual_box = None

                if self.state == 'PLAYING':
                    # Honour playback_speed: wait for whatever is left of the frame
                    # budget, but always at least 1 ms so OpenCV can pump window events.
                    elapsed_ms = (time.perf_counter() - loop_start) * 1000.0
                    wait_ms = max(1, int(self.target_delay_ms - elapsed_ms))
                else:
                    wait_ms = 20
                key = cv2.waitKey(wait_ms) & 0xFF

                if key == ord('q'):
                    break
                elif key == ord('h'):
                    self.show_help = not self.show_help
                elif key == 32 and self.state == 'PLAYING':  # Spacebar to pause
                    print("Paused. Entering selection mode...")
                    self.state = 'PAUSED_FOR_SELECTION'
                    self.selectable_detections = tracker.detect(frame)
                    self.user_selections = list(tracker.tracked_objects)

                elif key == ord('c') and self.state == 'PAUSED_FOR_SELECTION':
                    print("Selections confirmed. Resuming tracking...")
                    # Rebuild the tracker's state from scratch using the confirmed selections.
                    tracker.tracked_objects = []
                    tracker.next_track_id = 0
                    for sel in self.user_selections:
                        bbox = sel.get('bbox') or (sel['x1'], sel['y1'], sel['x2'], sel['y2'])
                        tracker.add_manual_track(temp_frame, bbox, sel['class_name'])

                    self.selectable_detections = []
                    self.user_selections = []
                    self.state = 'PLAYING'
        finally:
            self.release()

    def release(self):
        """Release the capture and close all OpenCV windows."""
        print("Releasing resources...")
        if self.cap.isOpened():
            self.cap.release()
        cv2.destroyAllWindows()
        # Extra waitKey calls let HighGUI actually process the window teardown.
        for _ in range(5):
            cv2.waitKey(1)

## Test Playback

In [None]:
import traceback

# You can change these parameters
VIDEO_PATH = './assets/footage/person4.mp4' # Make sure this path is correct
MODEL_PATH = './yolo11n.pt'  # NOTE: informational only; `model` is loaded in the model cell above
PLAYBACK_SPEED = 1.5 # Play at 1.5x speed
WINDOW_SIZE = 0.5   # Display window at 50% of original size
TRACKER_TYPE = 'CSRT'  # One of: 'CSRT', 'KCF', 'MOSSE', 'MEDIAN_FLOW'

player = None
try:
    tracker = ObjectTracker(model, TRACKER_TYPE)
    player = VideoPlayer(
        source=VIDEO_PATH,
        playback_speed=PLAYBACK_SPEED,
        size_multiplier=WINDOW_SIZE,
        window_title="High-Performance Player"
    )
    player.play(tracker)
except IOError as e:
    print(e)
except Exception as e:
    print(f"An unexpected error occurred: {e}")
    # Keep the full stack trace; a one-line message hides where it failed.
    traceback.print_exc()
finally:
    # If play() was interrupted before releasing, free the capture and windows.
    if player is not None and player.cap.isOpened():
        player.release()