In [1]:
import sys

sys.path.insert(0, '..')

import torch
import numpy as np
import torchvision
from torchvision import transforms
from pyhandle.dataset.dataloader import TorchLoader
from pyhandle.net.intermediate import IntermediateNetwork

from net.ssd import SSD300, MultiBoxLoss
from utils.obj_utils import cxcy_to_xy, cxcy_to_gcxgcy, xy_to_cxcy

device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 

In [2]:
coco_root = '/home/sixigma/workplace/meow/coco_data/'
coco = torchvision.datasets.coco.CocoDetection(root=coco_root + 'train/train2017/', annFile=coco_root + 'annotations/instances_train2017.json')

loading annotations into memory...
Done (t=9.37s)
creating index...
index created!


In [3]:
resnet = IntermediateNetwork('resnet50', [5, 6])
ssd_net = SSD300(resnet, 80)

In [4]:
priors = cxcy_to_xy(ssd_net.priors_cxcy)
multibox = MultiBoxLoss(priors)
boxes = []
labels = []
width, height = coco[1][0].size
for obj in range(len(coco[1][1])):
    # coco bounding box format [top left x position, top left y position, width, height]
    box = coco[1][1][obj]['bbox']
    box = [box[0] / width, box[1] / height, box[0] / width + box[2] / width, box[1] / height + box[3] / height]
    boxes.append(box)
    labels.append(coco[1][1][obj]['category_id'])
t_boxes = torch.FloatTensor([boxes]).to(device)
t_labels = torch.FloatTensor([labels]).to(device)



In [5]:
t_boxes, t_labels

(tensor([[[0.6024, 0.1409, 0.9383, 0.8385],
          [0.0828, 0.8368, 0.2891, 0.9664]]], device='cuda:0'),
 tensor([[25., 25.]], device='cuda:0'))

In [7]:
image = np.array(transforms.Resize((300, 300))((coco[0][0])))
t_image = torch.from_numpy(image).permute(2, 0, 1).float().cuda()
locs, cls = ssd_net(transforms.Normalize(0, 255)(t_image).expand([1, -1, -1, -1]))
locs = locs.to(device)
cls = cls.to(device)

In [8]:
loss = multibox(locs, cls, t_boxes, t_labels)
loss

tensor(30.1980, device='cuda:0', grad_fn=<AddBackward0>)

In [9]:
# Get intersection bounding box left_top and right_down coordinate
lower_bounds = torch.max(t_boxes[0, :, :2].unsqueeze(1), priors[:, :2].unsqueeze(0))
upper_bounds = torch.min(t_boxes[0, :, 2:].unsqueeze(1), priors[:, 2:].unsqueeze(0))

# Get intersection bounding box width and height
intersection_dims = torch.clamp(upper_bounds - lower_bounds, min=0)

# Get volumn of intersection bounding box
intersection_area = intersection_dims[:, :, 0] * intersection_dims[:, :, 1]

# 
areas_set_1 = (t_boxes[0, :, 2] - t_boxes[0, :, 0]) * (t_boxes[0, :, 3] - t_boxes[0, :, 1])
areas_set_2 = (priors[:, 2] - priors[:, 0]) * (priors[:, 3] - priors[:, 1])

union = areas_set_1.unsqueeze(1) + areas_set_2.unsqueeze(0) - intersection_area
iou = intersection_area / union  # shape (n_obj, 8732)

overlap_for_each_prior, object_for_each_prior = iou.max(dim=0)
_, prior_for_each_object = iou.max(dim=1) # (N_o)

In [10]:
# The fallowing 2 lines would ensure all objects in this image would map to a prior
# Then, assign each object to the corresponding maximum-overlap-prior. (This fixes 1.)
object_for_each_prior[prior_for_each_object] = torch.LongTensor(range(len(boxes))).to(device)
# To ensure these priors qualify, artificially give them an overlap of greater than 0.5. (This fixes 2.)
overlap_for_each_prior[prior_for_each_object] = 1.

# Labels for each prior
label_for_each_prior = t_labels[0][object_for_each_prior]  # (8732)
# Set priors whose overlaps with objects are less than the threshold to be background (no object)
label_for_each_prior[overlap_for_each_prior < 0.5] = 0  # (8732)

# Store
true_classes = label_for_each_prior

# Encode center-size object coordinates into the form we regressed predicted boxes to
offset_locs = cxcy_to_gcxgcy(xy_to_cxcy(t_boxes[0][object_for_each_prior]), ssd_net.priors_cxcy)  # (8732, 4)

# Identify priors that are positive (object/non-background)
positive_priors = true_classes != 0  # (N, 8732)

# LOCALIZATION LOSS

# Localization loss is computed only over positive (non-background) priors
loc_loss = torch.nn.SmoothL1Loss()(locs[0][positive_priors], offset_locs[positive_priors])  # (), scalar