In [4]:
from io import BytesIO

import cv2
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
import requests
import torch
import torchvision.transforms as T
from PIL import Image
from torchvision.models.detection import FasterRCNN_ResNet50_FPN_Weights, FasterRCNN
from torchvision.models.detection import fasterrcnn_resnet50_fpn, fasterrcnn_mobilenet_v3_large_320_fpn
from torchvision.models.detection.retinanet import _COCO_CATEGORIES
from tqdm import tqdm
import torchvision
import time

In [5]:
# --- Model -----------------------------------------------------------------
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# `pretrained=True` is the deprecated spelling; the explicit COCO_V1 weights
# enum loads the same checkpoint without the deprecation warning.
model = fasterrcnn_resnet50_fpn(weights=FasterRCNN_ResNet50_FPN_Weights.COCO_V1)
model.to(device)
model.eval()  # inference mode: the model returns detections instead of losses

# Converts an HxWxC uint8 RGB frame into a CxHxW float tensor in [0, 1].
transform = T.Compose([T.ToTensor()])

# --- Video I/O -------------------------------------------------------------
video_path = './data/lange_10.mp4'
output_path = './result/lange_10_fasterrcnn_resnet50_fpn.mp4'

cap = cv2.VideoCapture(video_path)
# VideoCapture does not raise on a bad path; without this check the
# processing loop would silently handle zero frames.
if not cap.isOpened():
    raise OSError(f"Could not open input video: {video_path}")

frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
fps = cap.get(cv2.CAP_PROP_FPS)
# Frame count reported by the container metadata (may be approximate).
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

# The writer mirrors the source resolution and frame rate.
out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (frame_width, frame_height))
# VideoWriter also fails silently (e.g. missing output directory or codec),
# which would otherwise produce an empty result file after a long run.
if not out.isOpened():
    cap.release()
    raise OSError(f"Could not open output video for writing: {output_path}")


In [6]:
# Detect vehicles on every frame of `cap`, draw the detections onto the frame,
# write it to `out`, and report the average processing throughput.

# torchvision detection models emit label ids that index `_COCO_CATEGORIES`
# (where index 0 is '__background__'). The ids are looked up by name so they
# always match the label text drawn below; the previously hard-coded
# [2, 3, 5, 7] resolve to bicycle / car / airplane / train in this table.
vehicle_names = ('car', 'motorcycle', 'bus', 'truck')
vehicle_classes = [_COCO_CATEGORIES.index(name) for name in vehicle_names]

confidence_threshold = 0.5  # minimum score for a detection to be drawn
fps_counter = 0             # frames actually processed and written
fps_start_time = time.time()

try:
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            # End of stream (or a read error): stop processing.
            break

        # OpenCV decodes frames as BGR; the model expects RGB input.
        image = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        input_tensor = transform(image).to(device)
        with torch.no_grad():
            prediction = model([input_tensor])

        # Move the whole prediction to the host once per frame instead of
        # synchronising with the device for every single detection.
        detections = prediction[0]
        labels = detections['labels'].cpu().tolist()
        scores = detections['scores'].cpu().tolist()
        boxes = detections['boxes'].cpu().numpy().astype(int).tolist()

        for label, score, (x1, y1, x2, y2) in zip(labels, scores, boxes):
            if label in vehicle_classes and score >= confidence_threshold:
                cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 5)
                cv2.putText(frame, f"{_COCO_CATEGORIES[label]}:{score:.2f}", (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 2, (0, 255, 0), 5)

        out.write(frame)
        fps_counter += 1
finally:
    # Stop the clock and release the capture/writer even if inference raised,
    # so the output file is finalised and the handles are not leaked.
    fps_end_time = time.time()
    cap.release()
    out.release()
    cv2.destroyAllWindows()

# Use the number of frames actually processed rather than the container's
# reported frame count, which may not match what was really decoded.
elapsed_seconds = fps_end_time - fps_start_time
average_fps = fps_counter / elapsed_seconds if elapsed_seconds > 0 else 0.0

print("Average FPS:", average_fps)


Average FPS: 1.3157653758333472
