In [1]:
# pip install opencv-python torch torchvision numpy matplotlib ultralytics

In [2]:
from scipy.spatial import distance as dist
from collections import OrderedDict
import numpy as np
import cv2
from ultralytics import YOLO
from collections import defaultdict
import matplotlib.pyplot as plt

In [3]:
class CentroidTracker:
    def __init__(self, max_disappeared=50):
        self.next_object_id = 0
        self.objects = OrderedDict()
        self.disappeared = OrderedDict()
        self.max_disappeared = max_disappeared

    def register(self, centroid):
        self.objects[self.next_object_id] = centroid
        self.disappeared[self.next_object_id] = 0
        self.next_object_id += 1

    def deregister(self, object_id):
        del self.objects[object_id]
        del self.disappeared[object_id]

    def update(self, rects):
        if len(rects) == 0:
            for object_id in list(self.disappeared.keys()):
                self.disappeared[object_id] += 1
                if self.disappeared[object_id] > self.max_disappeared:
                    self.deregister(object_id)
            return self.objects

        input_centroids = np.zeros((len(rects), 2), dtype="int")
        for (i, (x1, y1, x2, y2)) in enumerate(rects):
            cX = int((x1 + x2) / 2.0)
            cY = int((y1 + y2) / 2.0)
            input_centroids[i] = (cX, cY)

        if len(self.objects) == 0:
            for i in range(0, len(input_centroids)):
                self.register(input_centroids[i])
        else:
            object_ids = list(self.objects.keys())
            object_centroids = list(self.objects.values())

            D = dist.cdist(np.array(object_centroids), input_centroids)
            rows = D.min(axis=1).argsort()
            cols = D.argmin(axis=1)[rows]

            used_rows, used_cols = set(), set()

            for (row, col) in zip(rows, cols):
                if row in used_rows or col in used_cols:
                    continue
                object_id = object_ids[row]
                self.objects[object_id] = input_centroids[col]
                self.disappeared[object_id] = 0
                used_rows.add(row)
                used_cols.add(col)

            unused_rows = set(range(0, D.shape[0])).difference(used_rows)
            unused_cols = set(range(0, D.shape[1])).difference(used_cols)

            for row in unused_rows:
                object_id = object_ids[row]
                self.disappeared[object_id] += 1
                if self.disappeared[object_id] > self.max_disappeared:
                    self.deregister(object_id)

            for col in unused_cols:
                self.register(input_centroids[col])

        return self.objects

In [4]:
# Initialize model and tracker
model = YOLO("yolov8n.pt")
ct = CentroidTracker()

In [5]:
vehicle_classes = {2: "car", 3: "motorcycle", 5: "bus", 7: "truck"}

In [6]:
cap = cv2.VideoCapture("traffic.mp4")

In [7]:
frame_count = 0
total_unique_ids = set()
class_id_map = {}

In [8]:
# Initialize a dictionary to hold the total count for each vehicle type
vehicle_counts = defaultdict(int)

In [9]:
while True:
    ret, frame = cap.read()
    if not ret:
        break
    frame_count += 1

    results = model(frame)[0]

    rects = []
    boxes_info = []

    for box in results.boxes:
        cls_id = int(box.cls)
        if cls_id in vehicle_classes:
            x1, y1, x2, y2 = map(int, box.xyxy[0])
            rects.append((x1, y1, x2, y2))
            boxes_info.append(((x1, y1, x2, y2), cls_id))

    objects = ct.update(rects)

    per_frame_stats = defaultdict(int)

    for object_id, centroid in objects.items():
        if object_id < len(boxes_info):
            (x1, y1, x2, y2), cls_id = boxes_info[object_id]
            class_name = vehicle_classes[cls_id]
            per_frame_stats[class_name] += 1
            total_unique_ids.add(object_id)
            class_id_map[object_id] = class_name

            # Draw box and ID
            cv2.rectangle(frame, (x1, y1), (x2, y2), (255, 0, 0), 2)
            cv2.putText(frame, f"ID:{object_id} {class_name}", (x1, y1 - 5),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 0), 2)

    # Update the total counts across all frames
    for vehicle_type, count in per_frame_stats.items():
        vehicle_counts[vehicle_type] += count

    # Display counts
    y_offset = 20
    for cls, count in per_frame_stats.items():
        cv2.putText(frame, f"{cls}: {count}", (10, y_offset),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 255), 2)
        y_offset += 25

    cv2.putText(frame, f"Unique Vehicles: {len(set(class_id_map.keys()))}", (10, y_offset),
                cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)

    resized_frame = cv2.resize(frame, (800, 450))  # or (640, 360)
    cv2.imshow("Vehicle Detection", resized_frame)
    
    if cv2.waitKey(1) == ord("q"):
        break

cap.release()
cv2.destroyAllWindows()


0: 384x640 10 cars, 1 bus, 3 trucks, 89.2ms
Speed: 6.6ms preprocess, 89.2ms inference, 1.8ms postprocess per image at shape (1, 3, 384, 640)

0: 384x640 12 cars, 2 trucks, 73.6ms
Speed: 5.8ms preprocess, 73.6ms inference, 1.3ms postprocess per image at shape (1, 3, 384, 640)

0: 384x640 10 cars, 1 bus, 3 trucks, 56.4ms
Speed: 2.2ms preprocess, 56.4ms inference, 1.1ms postprocess per image at shape (1, 3, 384, 640)

0: 384x640 9 cars, 1 bus, 4 trucks, 53.5ms
Speed: 2.4ms preprocess, 53.5ms inference, 0.9ms postprocess per image at shape (1, 3, 384, 640)

0: 384x640 8 cars, 3 trucks, 53.0ms
Speed: 4.2ms preprocess, 53.0ms inference, 1.0ms postprocess per image at shape (1, 3, 384, 640)

0: 384x640 6 cars, 3 trucks, 52.2ms
Speed: 2.3ms preprocess, 52.2ms inference, 1.2ms postprocess per image at shape (1, 3, 384, 640)

0: 384x640 9 cars, 1 bus, 3 trucks, 57.1ms
Speed: 2.2ms preprocess, 57.1ms inference, 0.9ms postprocess per image at shape (1, 3, 384, 640)

0: 384x640 8 cars, 1 bus, 4 tr

In [10]:
# Print total unique vehicles
print(f"Total unique vehicles detected: {len(total_unique_ids)}")

Total unique vehicles detected: 14
