In [None]:
import sys

sys.path.insert(0, '..')

import torch
import numpy as np
from torchvision import transforms

from dataset.dataloader import TorchLoader
from net.intermediate import IntermediateNetwork
from net.ssd import SSD300, MultiBoxLoss
from utils.obj_utils import cxcy_to_xy

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
import os

# Root directory of the local COCO 2017 download. Set the COCO_ROOT environment
# variable to point somewhere else; the default is the original author's path.
# os.path.join(..., '') guarantees the trailing separator, because later cells
# build sub-paths by plain string concatenation (coco_root + 'train/...').
coco_root = os.path.join(os.environ.get('COCO_ROOT', '/home/sixigma/workplace/meow/coco_data/'), '')

In [None]:
import torchvision

# COCO 2017 training split: images under train/train2017/, annotations from
# the matching instances JSON file, both located below coco_root.
coco = torchvision.datasets.CocoDetection(
    root=coco_root + 'train/train2017/',
    annFile=coco_root + 'annotations/instances_train2017.json',
)

In [None]:
# Build the detector: the project's IntermediateNetwork configured with
# 'resnet50' is handed to SSD300 as its backbone.
# The meaning of [5, 6] and 80 is defined by those classes (presumably the
# tapped backbone stages and the number of object classes — TODO confirm).
# NOTE(review): ssd_net is never moved to `device`, while a later cell moves
# its input there — confirm SSD300 handles device placement itself.
resnet = IntermediateNetwork('resnet50', [5, 6])
ssd_net = SSD300(resnet, 80)

In [None]:
# Priors converted by cxcy_to_xy; the cells below treat each row as
# (x_min, y_min, x_max, y_max). They are moved to `device` once and reused.
priors = cxcy_to_xy(ssd_net.priors_cxcy).to(device)
multibox = MultiBoxLoss(priors)

# Fetch sample 1 a single time: every coco[i] access reloads and decodes the
# image from disk, so indexing it inside the loop is needlessly slow.
sample_image, sample_annotations = coco[1]
width, height = sample_image.size

boxes = []
labels = []
for annotation in sample_annotations:
    # coco bounding box format [top left x position, top left y position, width, height]
    x_min, y_min, box_w, box_h = annotation['bbox']
    # Convert to corner form, normalized by the image size.
    boxes.append([
        x_min / width,
        y_min / height,
        x_min / width + box_w / width,
        y_min / height + box_h / height,
    ])
    labels.append(annotation['category_id'])

# Wrapping the lists once more adds a leading batch dimension of size 1.
t_boxes = torch.FloatTensor([boxes]).to(device)
# NOTE(review): labels are stored as floats, and COCO category ids are not
# contiguous (they run up to 90) while ssd_net is built with 80 classes —
# confirm that MultiBoxLoss expects raw ids in this dtype.
t_labels = torch.FloatTensor([labels]).to(device)

In [None]:
# Use sample 1 here as well: the ground-truth boxes/labels and the plot are
# built from coco[1], so feeding coco[0] to the network would compare the
# predictions for one image against the targets of another.
image = np.array(transforms.Resize((300, 300))(coco[1][0]))

# HWC uint8 array -> CHW float tensor on the target device.
t_image = torch.from_numpy(image).permute(2, 0, 1).float().to(device)

# Scale pixel values to [0, 1] (equivalent to Normalize(mean=0, std=255)) and
# add the batch dimension the network call expects.
locs, cls = ssd_net((t_image / 255.0).unsqueeze(0))
locs = locs.to(device)
cls = cls.to(device)

In [None]:
# Evaluate the MultiBoxLoss instance on the network outputs (locs, cls) against
# the ground-truth boxes and labels; the bare `loss` displays the result as the
# cell output.
loss = multibox(locs, cls, t_boxes, t_labels)
loss

In [None]:
# Shapes of the predicted locations, the ground-truth boxes and the priors.
tuple(tensor.shape for tensor in (locs, t_boxes, priors))

In [None]:
# Peek at a few rows: the first two entries of locs and priors, and up to the
# first eight ground-truth boxes (trailing full slices are implied).
locs[:, :2], priors[:2], t_boxes[:, :8]

In [None]:
# Display the ground-truth labels next to `cls`, the second output of ssd_net
# (presumably the per-prior class scores — TODO confirm).
t_labels, cls

In [None]:
# Corners of the intersection rectangle for every (object, prior) pair.
# Broadcasting: ground-truth boxes run along axis 0, priors along axis 1.
lower_bounds = torch.max(t_boxes[0][:, None, :2], priors[None, :, :2])  # top-left corner
upper_bounds = torch.min(t_boxes[0][:, None, 2:], priors[None, :, 2:])  # bottom-right corner

In [None]:
# Width and height of each intersection; a negative extent means the two
# boxes do not overlap, so it is clamped to 0.
intersection_dims = (upper_bounds - lower_bounds).clamp(min=0)

# Area of each intersection rectangle (width * height)
intersection_area = intersection_dims[..., 0] * intersection_dims[..., 1]

# Areas of the ground-truth boxes (set 1) and of the priors (set 2)
areas_set_1 = (t_boxes[0][:, 2] - t_boxes[0][:, 0]) * (t_boxes[0][:, 3] - t_boxes[0][:, 1])
areas_set_2 = (priors[:, 2] - priors[:, 0]) * (priors[:, 3] - priors[:, 1])

# Union = area A + area B - intersection, broadcast over (object, prior) pairs
union = areas_set_1[:, None] + areas_set_2[None, :] - intersection_area
iou = intersection_area / union  # shape (n_obj, 8732)

In [None]:
# For each ground-truth object: its highest IoU over all priors, and the index
# of the prior that achieves it. Both results have shape (n_obj).
overlap_for_every_obj_iou, overlap_for_every_obj_idx = torch.max(iou, dim=1)

In [None]:
import matplotlib.pyplot as plt
import matplotlib.patches as patches

def draw_boxes(ax, norm_boxes, img_width, img_height, color):
    """Outline normalized corner-form boxes on ``ax``.

    Args:
        ax: Matplotlib axes to draw on.
        norm_boxes: Iterable of ``[x_min, y_min, x_max, y_max]`` rows of plain
            numbers, expressed as fractions of the image size.
        img_width: Image width in pixels, used to scale x coordinates.
        img_height: Image height in pixels, used to scale y coordinates.
        color: Matplotlib edge color for the rectangles.
    """
    for x_min, y_min, x_max, y_max in norm_boxes:
        ax.add_patch(patches.Rectangle(
            (x_min * img_width, y_min * img_height),
            (x_max - x_min) * img_width,
            (y_max - y_min) * img_height,
            linewidth=2, edgecolor=color, facecolor='none'))


# plt.subplots creates the figure itself; a preceding plt.figure() would only
# leave an empty extra figure in the cell output.
fig, ax = plt.subplots(1, figsize=(12, 9))
ax.imshow(coco[1][0])
ax.set_title('Ground-truth boxes (red) and best-matching priors (blue)')

# Get bounding-box colors (not used by the drawing calls below)
cmap = plt.get_cmap('tab20b')
colors = [cmap(i) for i in np.linspace(0, 1, 20)]

# Ground-truth boxes
draw_boxes(ax, boxes, width, height, 'r')

# Prior with the highest IoU for each object. Convert to plain Python floats
# first: the priors live on `device`, and matplotlib cannot be relied on to
# consume tensors directly (in particular CUDA tensors).
best_priors = priors[overlap_for_every_obj_idx].detach().cpu().tolist()
draw_boxes(ax, best_priors, width, height, 'b')

plt.show()