In [4]:
!pip install deep-sort-realtime

Collecting deep-sort-realtime
  Downloading deep_sort_realtime-1.3.2-py3-none-any.whl.metadata (12 kB)
Downloading deep_sort_realtime-1.3.2-py3-none-any.whl (8.4 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m8.4/8.4 MB[0m [31m54.9 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hInstalling collected packages: deep-sort-realtime
Successfully installed deep-sort-realtime-1.3.2


In [5]:
import cv2
import torch
from torchvision.models.detection import fasterrcnn_resnet50_fpn
from torchvision.transforms import functional as F
from deep_sort_realtime.deepsort_tracker import DeepSort
import os
from tqdm import tqdm

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [6]:
#Detector and Tracker classes for a uniform interface
class Detector:
    """Common interface for object detectors.

    Subclasses override ``getDetections`` to turn one video frame into a list
    of detections.
    """

    def __init__(self):
        pass

    def getDetections(self, frame):
        """Return the detections found in ``frame``.

        Raises:
            NotImplementedError: always; subclasses must override this method.
        """
        # Failing loudly beats silently returning None, which would only
        # surface later as a confusing error when the result is iterated.
        raise NotImplementedError(
            f"{type(self).__name__} must implement getDetections(frame)"
        )

class Tracker:
    """Common interface for multi-object trackers.

    Subclasses override ``getTrackedObjects`` to associate the detections of
    one frame with tracks.
    """

    def __init__(self):
        pass

    def getTrackedObjects(self, detections, frame):
        """Return the tracked objects for ``detections`` found in ``frame``.

        Raises:
            NotImplementedError: always; subclasses must override this method.
        """
        # Failing loudly beats silently returning None, which would only
        # surface later as a confusing error when the result is iterated.
        raise NotImplementedError(
            f"{type(self).__name__} must implement getTrackedObjects(detections, frame)"
        )

In [7]:
# Detectors : 
class DetectorFasterRCNN(Detector):
    """Vehicle detector backed by torchvision's pretrained Faster R-CNN (ResNet-50 FPN).

    Args:
        score_threshold: detections with a score at or below this value are dropped.
        vehicle_classes: label ids that are kept; all other labels are dropped.
    """

    def __init__(self, score_threshold=0.8, vehicle_classes=(2, 3, 4, 6)):
        self.model = fasterrcnn_resnet50_fpn(pretrained = True)
        self.model.eval()
        self.model.to(device)
        # NOTE(review): in the COCO category ids used by this model, 2 is
        # 'bicycle', 3 'car', 4 'motorcycle', 6 'bus' and 8 'truck' -- confirm
        # that omitting 8 (and including 2) is intended.
        self.vehicle_classes = list(vehicle_classes)
        self.score_threshold = score_threshold

    def getDetections(self, frame):
        """Detect vehicles in one frame.

        Args:
            frame: HxWx3 uint8 image in OpenCV's BGR channel order, as returned
                by ``cv2.VideoCapture.read``.

        Returns:
            A list of ``(box, score, label)`` tuples, where ``box`` is the
            model's ``[x1, y1, x2, y2]`` array, for every detection whose score
            exceeds the threshold and whose label is a configured vehicle class.
        """
        # The torchvision detection models expect RGB input, while OpenCV
        # delivers BGR; feeding BGR directly degrades detection quality.
        rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        frame_tensor = F.to_tensor(rgb_frame).unsqueeze(0).to(device)
        with torch.no_grad():
            outputs = self.model(frame_tensor)
        boxes = outputs[0]['boxes'].cpu().numpy()
        scores = outputs[0]['scores'].cpu().numpy()
        labels = outputs[0]['labels'].cpu().numpy()
        detections = [
            (box, score, label)
            for box, score, label in zip(boxes, scores, labels)
            if score > self.score_threshold and label in self.vehicle_classes
        ]
        return detections

class DetectorYOLO(Detector):
    """Placeholder YOLO detector: no model is loaded and no detections are produced."""
    def __init__(self):
        # No model is loaded; the attribute is only initialised to None.
        self.model = None
    
    def getDetections(self, frame):
        """Return an empty list regardless of ``frame``."""
        return []
    
# Module-level detector instances; constructing DetectorFasterRCNN builds the
# pretrained model and moves it to `device`.
detectorFasterRCNN = DetectorFasterRCNN()
detectorYOLO = DetectorYOLO()

Downloading: "https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth" to /root/.cache/torch/hub/checkpoints/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth
100%|██████████| 160M/160M [00:01<00:00, 155MB/s]  


In [9]:
# Trackers
class TrackerSORT(Tracker):
    """Placeholder for a SORT-based tracker; no tracking backend is wired in yet."""

    def __init__(self):
        # No backend is configured; assign a callable taking
        # (detections, frame) to enable this tracker.
        self.tracker = None

    def getTrackedObjects(self, detections, frame):
        """Delegate to the configured backend.

        Raises:
            NotImplementedError: if no backend has been assigned to
                ``self.tracker``.
        """
        # Without this guard the call below fails with an opaque
        # "'NoneType' object is not callable" TypeError.
        if self.tracker is None:
            raise NotImplementedError("TrackerSORT has no tracking backend configured")
        return self.tracker(detections, frame)
    
class TrackerDeepSORT(Tracker):
    """Tracker adapter around ``deep_sort_realtime``'s ``DeepSort``."""

    def __init__(self):
        self.tracker = DeepSort()

    def getTrackedObjects(self, detections, frame):
        """Update the tracker with one frame's detections.

        Args:
            detections: iterable of ``(box, score, label)`` tuples where
                ``box`` holds corner coordinates ``[x1, y1, x2, y2]``.
            frame: the frame the detections came from; passed to DeepSort,
                which crops it to compute appearance embeddings.

        Returns:
            The list of confirmed tracks after the update.
        """
        # DeepSort expects each detection as ([left, top, width, height],
        # confidence, class). Passing corner boxes unchanged makes it read
        # x2/y2 as width/height, which corrupts both the crops and the tracks.
        raw_detections = []
        for box, score, label in detections:
            left, top, right, bottom = (float(v) for v in box[:4])
            raw_detections.append(
                ([left, top, right - left, bottom - top], float(score), label)
            )

        tracks = self.tracker.update_tracks(raw_detections, frame = frame)
        # Tentative tracks have not yet been matched often enough to be
        # trusted; the library's documented usage skips them.
        return [track for track in tracks if track.is_confirmed()]
    
# Module-level tracker instances, one per tracker class defined above.
trackerSORT = TrackerSORT()
trackerDeepSORT = TrackerDeepSORT()

In [10]:
class VehicleTracker:
    """Detect, track and count vehicles that touch a vertical line in a video.

    Args:
        detector: object exposing ``getDetections(frame)``.
        tracker: object exposing ``getTrackedObjects(detections, frame)`` whose
            results provide ``track_id`` and ``to_ltrb()``.
    """

    def __init__(self, detector, tracker):
        self.detector = detector
        self.tracker = tracker

    def writeVideo(self, frames, output_file='output.mp4', fps=30):
        """Encode ``frames`` into an mp4v video file.

        Args:
            frames: any iterable of equally sized image arrays (a list or a
                generator); it is consumed lazily so frames need not all be
                held in memory.
            output_file: path of the video to create.
            fps: frame rate of the written video.

        Raises:
            ValueError: if ``frames`` yields no frame at all.
        """
        frame_iter = iter(frames)
        first_frame = next(frame_iter, None)
        if first_frame is None:
            # The frame size is taken from the first frame, so an empty input
            # cannot configure a writer.
            raise ValueError("writeVideo received no frames to write")

        height, width = first_frame.shape[:2]
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_file, fourcc, fps, (width, height))
        try:
            out.write(first_frame)
            for frame in frame_iter:
                out.write(frame)
        finally:
            # Release even if producing or writing a frame fails, so the file
            # handle is not leaked.
            out.release()

    def putText(self, frame, text, top_left, bottom_right):
        """Draw ``text`` in white on a filled black rectangle, in place on ``frame``."""
        cv2.rectangle(frame, top_left, bottom_right, (0, 0, 0), thickness=cv2.FILLED)
        cv2.putText(frame, text, (top_left[0] + 15, top_left[1] + 35), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)

    def _countFrames(self, videofile):
        """Count the readable frames of ``videofile`` by decoding it once."""
        cap = cv2.VideoCapture(videofile)
        try:
            if not cap.isOpened():
                raise IOError(f"Could not open video file: {videofile}")
            framect = 0
            while True:
                ret, _ = cap.read()
                if not ret:
                    break
                framect += 1
            return framect
        finally:
            cap.release()

    def _annotatedFrames(self, cap, framect, xline):
        """Yield one annotated copy of each frame read from ``cap``."""
        # Ids of the tracked vehicles whose box has touched the line; the
        # displayed count is the size of this set.
        intersectedIds = set()

        for _ in tqdm(range(framect)):
            ret, frame = cap.read()
            if not ret:
                break

            detections = self.detector.getDetections(frame)
            tracked_objects = self.tracker.getTrackedObjects(detections, frame)

            video_frame = frame.copy()
            linecolor = (252, 227, 3)
            for obj in tracked_objects:
                track_id = obj.track_id
                # Plain Python ints keep the OpenCV drawing calls happy
                # regardless of the array dtype returned by the tracker.
                left, top, right, bottom = (int(v) for v in obj.to_ltrb())
                cv2.rectangle(video_frame, (left, top), (right, bottom), (int(track_id), 255, int(track_id)), 2)
                cv2.putText(video_frame, str(track_id), (left, top - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 2)

                # A vehicle is counted the first time its box spans the line;
                # the line is drawn white on that frame.
                if left <= xline <= right and track_id not in intersectedIds:
                    intersectedIds.add(track_id)
                    linecolor = (255, 255, 255)

            # Drawing the line and adding the count
            h, w = video_frame.shape[:2]
            cv2.line(video_frame, (xline, 0), (xline, h - 1), linecolor, 2)
            self.putText(video_frame, f'Count: {len(intersectedIds)}', (w - 200, 50), (w - 45, 100))

            yield video_frame

    def getVideo(self, videofile, outputfile, xline=500):
        """Run detection and tracking on ``videofile`` and write the annotated video.

        Args:
            videofile: path of the input video.
            outputfile: path of the annotated video to create.
            xline: x-coordinate of the vertical counting line; a track is
                counted once its bounding box spans this column.

        Raises:
            IOError: if ``videofile`` cannot be opened.
            ValueError: if the video contains no readable frame.
        """
        framect = self._countFrames(videofile)
        print(f"Processing {framect} frames")

        cap = cv2.VideoCapture(videofile)
        try:
            # Keep the source frame rate so the output plays at the original
            # speed; fall back to 30 when the container does not report one
            # (the comparison is also False for NaN).
            fps = cap.get(cv2.CAP_PROP_FPS)
            if not fps > 0:
                fps = 30

            # Frames are streamed to the writer one at a time instead of being
            # accumulated in a list, so memory use does not grow with length.
            annotated = self._annotatedFrames(cap, framect, xline)
            self.writeVideo(annotated, output_file = outputfile, fps = fps)
        finally:
            cap.release()

# Pipeline instance combining the Faster R-CNN detector with the DeepSORT tracker.
vehicleTracker = VehicleTracker(detectorFasterRCNN, trackerDeepSORT)         

In [11]:
# Process the input clip and write the annotated result to output1.mp4.
vehicleTracker.getVideo('/kaggle/input/intersectiondata01/vid2.mp4', "output1.mp4")

Processing 1775 frames


100%|██████████| 1775/1775 [06:11<00:00,  4.77it/s]
