In [None]:
import cv2
import torch
import torchvision
import torchvision.transforms as transforms
from PIL import Image
import os
import numpy as np
from torchvision.models.detection import maskrcnn_resnet50_fpn
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor

input_video = "raw-traffic-footage/drone-2.mp4"

# Scratch directory that receives one JPEG per frame of the input video.
frames_dir = "extracted_frames"
os.makedirs(frames_dir, exist_ok=True)

video = cv2.VideoCapture(input_video)

# VideoCapture does not raise when the file is missing or cannot be decoded.
# Without this check the script would read zero frames, write an empty output
# video and still report success.
if not video.isOpened():
    raise IOError(f"Could not open input video: {input_video}")

# The output writer reuses the source video's frame rate and resolution.
frame_rate = video.get(cv2.CAP_PROP_FPS)
frame_width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
frame_height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))

output_video = "output_video.mp4"
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
output = cv2.VideoWriter(output_video, fourcc, frame_rate, (frame_width, frame_height))

frame_count = 0

try:
    while True:
        ret, frame = video.read()

        # read() reports ret == False once no further frame can be decoded.
        if not ret:
            break

        # Zero-padded index so that a lexicographic sort of the file names
        # reproduces the frame order.
        frame_path = os.path.join(frames_dir, f"frame_{frame_count:05d}.jpg")

        # imwrite signals failure through its return value; a silently dropped
        # frame would shorten the output video.
        if not cv2.imwrite(frame_path, frame):
            raise IOError(f"Could not write frame: {frame_path}")

        frame_count += 1
finally:
    # Release the capture handle even when extraction fails part-way.
    video.release()

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Mask R-CNN with a ResNet-50 FPN backbone, loaded with torchvision's default
# weights, placed on the selected device and switched to evaluation mode.
model = maskrcnn_resnet50_fpn(
    weights=torchvision.models.detection.MaskRCNN_ResNet50_FPN_Weights.DEFAULT
)
model = model.to(device)
model.eval()

# Class names indexed by the integer label id that the detector emits
# (detect_objects looks names up as coco_labels[label]). The 'N/A' entries are
# placeholders that keep every later name aligned with its id.
coco_labels = [
    # ids 0-9
    "__background__", "person", "bicycle", "car", "motorcycle",
    "airplane", "bus", "train", "truck", "boat",
    # ids 10-19
    "traffic light", "fire hydrant", "N/A", "stop sign", "parking meter",
    "bench", "bird", "cat", "dog", "horse",
    # ids 20-29
    "sheep", "cow", "elephant", "bear", "zebra",
    "giraffe", "N/A", "backpack", "umbrella", "N/A",
    # ids 30-39
    "N/A", "handbag", "tie", "suitcase", "frisbee",
    "skis", "snowboard", "sports ball", "kite", "baseball bat",
    # ids 40-49
    "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle",
    "N/A", "wine glass", "cup", "fork", "knife",
    # ids 50-59
    "spoon", "bowl", "banana", "apple", "sandwich",
    "orange", "broccoli", "carrot", "hot dog", "pizza",
    # ids 60-69
    "donut", "cake", "chair", "couch", "potted plant",
    "bed", "N/A", "dining table", "N/A", "N/A",
    # ids 70-79
    "toilet", "N/A", "tv", "laptop", "mouse",
    "remote", "keyboard", "cell phone", "microwave", "oven",
    # ids 80-90
    "toaster", "sink", "refrigerator", "N/A", "book",
    "clock", "vase", "scissors", "teddy bear", "hair drier",
    "toothbrush",
]

# Conversion applied to each image before it is batched and sent to the model.
transform = transforms.Compose([transforms.ToTensor()])

def detect_objects(image, score_threshold=0.5):
    """Run the detection model on one image and return the confident detections.

    Args:
        image: Any input accepted by the module-level ``transform``; the
            caller in this file passes a PIL image.
        score_threshold: Minimum score a detection needs to be returned.
            Defaults to 0.5.

    Returns:
        A list of ``(box, label_name, score, mask)`` tuples, one per kept
        detection, in the order the model produced them. ``box`` and ``mask``
        are NumPy arrays, ``label_name`` is the ``coco_labels`` entry for the
        predicted label id, and ``score`` is a NumPy scalar.
    """
    # Add a batch dimension: the model is called with a batch of one image.
    image_tensor = transform(image).unsqueeze(0).to(device)

    with torch.no_grad():
        prediction = model(image_tensor)[0]

    # Filter while the tensors are still on the device, so the masks of
    # rejected detections are never copied to host memory.
    keep = prediction['scores'] >= score_threshold

    boxes = prediction['boxes'][keep].cpu().numpy()
    labels = prediction['labels'][keep].cpu().numpy()
    scores = prediction['scores'][keep].cpu().numpy()
    masks = prediction['masks'][keep].cpu().numpy()

    return [
        (box, coco_labels[label], score, mask)
        for box, label, score, mask in zip(boxes, labels, scores, masks)
    ]

def preprocess_image(image_path, target_size=(800, 800)):
    """Load an image as RGB and fit it into a fixed-size, zero-padded canvas.

    The image is scaled uniformly (aspect ratio preserved) so that it fits
    inside ``target_size`` and is placed in the top-left corner of the canvas;
    the remaining area stays zero.

    Args:
        image_path: Path of the image file to load.
        target_size: ``(height, width)`` of the returned canvas.

    Returns:
        ``(padded_image, scale_x, scale_y)``. ``padded_image`` is a uint8
        array of shape ``(target_size[0], target_size[1], 3)``. Both scale
        values are the same uniform factor that was applied to the image.

    Raises:
        IOError: If the image cannot be read from ``image_path``.
    """
    image = cv2.imread(image_path)
    # imread returns None instead of raising when the file is missing or
    # undecodable; fail here with the offending path rather than inside
    # cvtColor with an opaque error.
    if image is None:
        raise IOError(f"Could not read image: {image_path}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    h, w = image.shape[:2]
    scale = min(target_size[0] / h, target_size[1] / w)
    # cv2.resize expects (width, height). Clamp to one pixel so an extreme
    # aspect ratio cannot truncate a dimension to zero.
    new_size = (max(1, int(w * scale)), max(1, int(h * scale)))

    resized_image = cv2.resize(image, new_size)

    padded_image = np.zeros((target_size[0], target_size[1], 3), dtype=np.uint8)
    padded_image[:resized_image.shape[0], :resized_image.shape[1], :] = resized_image

    return padded_image, scale, scale

def visualize_detections(image_path, detected_objects, scale_x, scale_y, confidence_threshold=0.7):
    """Draw boxes and captions for confident detections onto the original image.

    Args:
        image_path: Path of the image to draw on; it is re-read from disk.
        detected_objects: Iterable of ``(box, label, score, mask)`` tuples,
            where ``box`` is ``(xmin, ymin, xmax, ymax)``. The mask is ignored.
        scale_x: Horizontal factor the box coordinates are divided by to map
            them onto the image read from ``image_path``.
        scale_y: Vertical factor, used the same way.
        confidence_threshold: Detections scoring below this are not drawn.

    Returns:
        The image (as read by OpenCV) with a green rectangle and a
        ``"<label>: <score>"`` caption for every drawn detection.

    Raises:
        IOError: If the image cannot be read from ``image_path``.
    """
    image_with_detections = cv2.imread(image_path)
    # imread returns None instead of raising on an unreadable file.
    if image_with_detections is None:
        raise IOError(f"Could not read image: {image_path}")

    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.9
    thickness = 2
    color = (0, 255, 0)

    for box, label, score, _ in detected_objects:
        if score < confidence_threshold:
            continue

        xmin, ymin, xmax, ymax = box

        left = int(xmin / scale_x)
        top = int(ymin / scale_y)
        right = int(xmax / scale_x)
        bottom = int(ymax / scale_y)

        cv2.rectangle(image_with_detections, (left, top), (right, bottom), color, thickness)

        caption = f"{label}: {score:.2f}"
        (_text_width, text_height), _baseline = cv2.getTextSize(caption, font, font_scale, thickness)

        # putText anchors the text at its bottom-left corner. Place the
        # caption above the box when it fits, otherwise just inside the box's
        # top edge so it is not drawn off-frame.
        caption_y = top - 10
        if caption_y < text_height:
            caption_y = top + text_height + 10

        cv2.putText(image_with_detections, caption, (left, caption_y),
                    font, font_scale, color, thickness)

    return image_with_detections

import shutil

# Run detection on every extracted frame and append the annotated frame to the
# output video.
# NOTE: sorted() orders the file names lexicographically, which matches frame
# order only while the zero-padded frame index stays within five digits.
try:
    for frame_file in sorted(os.listdir(frames_dir)):
        frame_path = os.path.join(frames_dir, frame_file)

        # Letterboxed RGB copy for the model, plus the scale needed to map the
        # predicted boxes back onto the original frame.
        preprocessed_image, scale_x, scale_y = preprocess_image(frame_path)

        pil_image = Image.fromarray(preprocessed_image)

        detected_objects = detect_objects(pil_image)

        # Boxes are drawn on the original frame re-read from disk, so the
        # written frame keeps the original resolution.
        image_with_detections = visualize_detections(frame_path, detected_objects, scale_x, scale_y)

        output.write(image_with_detections)
finally:
    # Always finalize the output file and remove the scratch frames, even if
    # a frame fails. Every file in frames_dir is processed, so frames left
    # behind by a failed run would otherwise be picked up by the next run.
    output.release()
    shutil.rmtree(frames_dir)

print("Video processing completed.")