# In [1]:
import cv2
import torch
from torchvision import models, transforms
from torchvision.models.detection import SSD300_VGG16_Weights
from PIL import Image
import random

# COCO-pretrained SSD300 detector with a VGG16 backbone, selected through
# torchvision's weights enum.
weights = SSD300_VGG16_Weights.COCO_V1

# Module.eval() returns the module itself, so the detector can be built and
# switched to inference mode in a single expression.
model = models.detection.ssd300_vgg16(weights=weights).eval()

# Preprocessing pipeline applied to every frame: its only step converts the
# PIL image into a tensor.
transform = transforms.Compose([transforms.ToTensor()])

# Class names are read from the metadata bundled with the pretrained weights.
coco_labels = weights.meta["categories"]

# Map every class name to a random 3-channel colour (each channel in 0..255)
# used when drawing that class's boxes and captions.
colors = {
    name: tuple(random.randint(0, 255) for _channel in range(3))
    for name in coco_labels
}

# Open the input video and fail fast when it cannot be read; otherwise the
# script would silently create an empty, zero-sized output file.
video_path = 'sample_video4.mp4'
cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
    cap.release()
    raise OSError(f'Could not open video file: {video_path}')

# Mirror the source geometry and frame rate so the annotated output plays back
# at the same size and speed as the input.
frame_size = (
    int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
    int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)),
)
source_fps = cap.get(cv2.CAP_PROP_FPS)
# Some containers/backends report 0 or NaN for the frame rate; fall back to
# the previous fixed value of 30 fps in that case (NaN > 0 is False).
fps = source_fps if source_fps > 0 else 30.0

# Define the codec and create the VideoWriter for the annotated output.
fourcc = cv2.VideoWriter_fourcc(*'XVID')
out = cv2.VideoWriter('SSD_output.avi', fourcc, fps, frame_size)

# Detections scoring below this confidence are not drawn.
SCORE_THRESHOLD = 0.2

# Run detection frame by frame, annotate each frame, show it and append it to
# the output video. The try/finally guarantees that the capture, the writer
# and any preview window are released even if inference or drawing raises.
try:
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            # End of stream or read failure.
            break

        # Frames are converted from BGR to RGB before being wrapped as a PIL
        # image for the preprocessing transform.
        pil_img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))

        # Add a leading batch dimension: the model is called with a batch of
        # one image.
        input_img = transform(pil_img).unsqueeze(0)

        # Inference only; no gradients are needed.
        with torch.no_grad():
            detections = model(input_img)[0]

        # Pull the results into plain Python lists once per frame instead of
        # indexing and converting tensors for every single detection.
        boxes = detections['boxes'].tolist()
        scores = detections['scores'].tolist()
        label_ids = detections['labels'].tolist()

        for box, score, label_id in zip(boxes, scores, label_ids):
            # Skip low-confidence detections before doing any lookup/drawing.
            if not SCORE_THRESHOLD <= score <= 1:
                continue

            # Truncate to plain Python ints, which OpenCV's drawing functions
            # accept on every version.
            x1, y1, x2, y2 = (int(v) for v in box)
            label = coco_labels[label_id]
            color = colors[label]

            cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
            # Keep the caption inside the frame when the box touches the top
            # edge; otherwise the text would be drawn above row 0 and clipped.
            text_y = max(y1 - 10, 15)
            cv2.putText(frame, f'{label}: {score:.2f}', (x1, text_y),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)

        # Display the annotated frame.
        cv2.imshow('Object Detection', frame)

        # Write the annotated frame to the output video.
        out.write(frame)

        # Stop early if 'q' is pressed in the preview window.
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
finally:
    # Release everything, regardless of how the loop ended.
    cap.release()
    out.release()
    cv2.destroyAllWindows()
