# Object Detection with SSD

## Importing the libraries

In [1]:
import torch
import cv2
import imageio
# SSD model is taken from [https://github.com/amdegroot/ssd.pytorch] and adapted for torch 1.11.0
from data import BaseTransform, VOC_CLASSES as labelmap
from ssd import build_ssd

## Detection Function

In [2]:
# Frame by frame detection
def detect(frame, net, transform, threshold=0.3):
    """Run the detector on one frame and draw a labelled box per confident hit.

    Args:
        frame: Image array; its first two dimensions are read as
            (height, width). Boxes and labels are drawn directly onto it.
        net: Callable that maps the batched input tensor to a detections
            tensor indexed as [batch, class, slot, (score, x0, y0, x1, y1)].
        transform: Callable whose return value is indexable; element 0 must
            be a NumPy array holding the preprocessed frame.
        threshold: Minimum score a detection needs to be drawn. Defaults to
            0.3, the value that used to be hard-coded.

    Returns:
        The same ``frame`` object, with the annotations drawn on it.
    """
    height, width = frame.shape[:2]
    frame_t = transform(frame)[0]
    # permute reorders axes (last axis first); it does not swap colour channels.
    # Assumes frame_t is H x W x C so the result is channels-first -- TODO confirm.
    x = torch.from_numpy(frame_t).permute(2, 0, 1)
    x = x.unsqueeze(0)  # add the leading batch axis the network call expects
    with torch.no_grad():
        y = net(x)  # forward pass without building an autograd graph
    detections = y.detach()
    # Box coordinates are multiplied by the frame size, so they are presumably
    # normalised to [0, 1] by the network.
    scale = torch.Tensor([width, height, width, height])
    num_classes = detections.size(1)
    num_slots = detections.size(2)
    # Start at 1: labels are looked up with labelmap[i - 1], so index 0 has no
    # label of its own and would wrap around to the last entry of labelmap.
    for i in range(1, num_classes):
        j = 0
        # The loop stops at the first slot below the threshold, which relies on
        # slots being sorted by descending score (assumed). The bound on j keeps
        # a class whose slots are all confident from indexing past the slot axis.
        while j < num_slots and detections[0, i, j, 0] >= threshold:
            pt = (detections[0, i, j, 1:] * scale).numpy()
            top_left = (int(pt[0]), int(pt[1]))
            bottom_right = (int(pt[2]), int(pt[3]))
            cv2.rectangle(frame, top_left, bottom_right, (153, 0, 0), 2)
            cv2.putText(frame, labelmap[i - 1], top_left,
                        cv2.FONT_HERSHEY_SIMPLEX, 2, (255, 255, 255), 2, cv2.LINE_AA)
            j += 1
    return frame

## SSD Neural Network

Download pretrained SSD weights (ssd300_mAP_77.43_v2.pth) from [this link](https://s3.amazonaws.com/amdegroot-models/ssd300_mAP_77.43_v2.pth).

In [3]:
# Build the network in 'test' mode, then fill it with the pretrained weights.
net = build_ssd('test')
# map_location='cpu' loads every tensor onto the CPU, same as the identity
# lambda, so the checkpoint also loads on machines without a GPU.
net.load_state_dict(
    torch.load('ssd300_mAP_77.43_v2.pth', map_location='cpu')
)
# Preprocessing takes the network's input size and the three per-channel mean
# values; the tuple ordering must match the channel order BaseTransform
# expects -- verify against its implementation.
transform = BaseTransform(
    net.size,
    (104/256.0, 117/256.0, 123/256.0),
)

## Object Detection on a Video

In [4]:
# Switch the network to evaluation mode once; eval() returns the module itself,
# so calling it again for every frame was redundant.
net.eval()

# Context managers close the reader and finalise the output file even if
# detection fails part-way through the video.
with imageio.get_reader('video.mp4') as reader:
    fps = reader.get_meta_data()['fps']  # keep the source frame rate in the output
    # macro_block_size = 1 is carried over unchanged; per the imageio ffmpeg
    # docs it stops frames being resized to a multiple of the default block size.
    with imageio.get_writer('output.mp4', fps = fps, macro_block_size = 1) as writer:
        for frame in reader:
            writer.append_data(detect(frame, net, transform))

