In [None]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

import os
import torch as t
from utils.config import opt
from model import FasterRCNNVGG16
from trainer import FasterRCNNTrainer
from data.util import  read_image
from utils.vis_tool import vis_bbox
from utils import array_tool as at
import numpy as np
import pickle

In [None]:
# Build the VGG16-based Faster R-CNN detector and move it to the GPU.
faster_rcnn = FasterRCNNVGG16()
faster_rcnn = faster_rcnn.cuda()

# Wrap the detector in its trainer (used below for checkpoint loading and
# for access to the detector via `trainer.faster_rcnn`), also on the GPU.
trainer = FasterRCNNTrainer(faster_rcnn)
trainer = trainer.cuda()

In [None]:
# List the contents of the pedestrians/cyclists checkpoint directory so the
# checkpoint file name used in the next cell can be picked from it.
%ls ./checkpoints/pedestrians_cyclists/

In [None]:
# Checkpoint file to load into the trainer, and the directory the saved
# Mahalanobis features are read from.
checkpoint_path = './checkpoints/pedestrians_cyclists/fasterrcnn_04162045_0.4105527015235588'
mahalanobis_dir = './checkpoints/pedestrians_cyclists'

trainer.load(checkpoint_path)
trainer.faster_rcnn.load_mahalanobis_features(save_dir=mahalanobis_dir)

In [None]:
# Read the demo picture and turn it into a tensor with a leading batch axis
# of size one, which is the form the predict calls below are given.
demo_array = read_image('misc/demo.jpg')
img = t.from_numpy(demo_array).unsqueeze(0)

In [None]:
# Run the standard detector on the demo image.
_bboxes, _labels, _scores = trainer.faster_rcnn.predict(img, visualize=True)

# Convert the first (only) image and its detections to numpy for plotting;
# labels and scores are flattened to 1-D.
demo_image_np = at.tonumpy(img[0])
demo_boxes_np = at.tonumpy(_bboxes[0])
demo_labels_np = at.tonumpy(_labels[0]).reshape(-1)
demo_scores_np = at.tonumpy(_scores[0]).reshape(-1)

vis_bbox(demo_image_np, demo_boxes_np, demo_labels_np, demo_scores_np)

In [None]:
# Same demo image, this time scored through the Mahalanobis prediction path.
_bboxes, _labels, _scores = trainer.faster_rcnn.predict_mahalanobis(img, visualize=True)

# Convert the first (only) image and its detections to numpy for plotting;
# labels and scores are flattened to 1-D.
maha_image_np = at.tonumpy(img[0])
maha_boxes_np = at.tonumpy(_bboxes[0])
maha_labels_np = at.tonumpy(_labels[0]).reshape(-1)
maha_scores_np = at.tonumpy(_scores[0]).reshape(-1)

vis_bbox(maha_image_np, maha_boxes_np, maha_labels_np, maha_scores_np)

# Forward pass, compute losses, backpropagate, perturb the input, forward pass again

In [None]:
from data.dataset import Dataset, TestDataset
from torch.utils import data as data_

# Root of the VOC-format KITTI 2D dataset. Set the KITTI_VOC_DATA_DIR
# environment variable to point the notebook at another copy of the data;
# without it the original workstation path is used, so existing setups keep
# working unchanged.
opt.voc_data_dir = os.environ.get(
    "KITTI_VOC_DATA_DIR",
    "/media/tadenoud/DATADisk/datasets/kitti_2d/VOC2012/",
)

# Validation split, served one image per batch and in a fixed order
# (shuffle=False) so repeated runs iterate over the images identically.
testset = TestDataset(opt, split="val")
test_dataloader = data_.DataLoader(
    testset,
    batch_size=1,
    num_workers=opt.test_num_workers,
    shuffle=False,
    pin_memory=True,
)

In [None]:
from utils.vis_tool import vis_bbox, vis_image
from tqdm import tqdm
import torch
from utils import array_tool as at

# Perturbation magnitude passed to predict_mahalanobis in the evaluation loop.
EPSILON = 0.0005

# Per-channel offsets with shape (3, 1, 1) so they broadcast across the
# spatial axes of a channel-first image.
# NOTE(review): presumably the per-channel mean subtracted during
# preprocessing -- confirm against the dataset transform.
mean = np.array([[[122.7717]], [[115.9465]], [[102.9801]]])


def unnormalize_img(img, mean=mean):
    """Return ``img`` with the per-channel ``mean`` added back.

    Args:
        img: channel-first image array that broadcasts against ``mean``.
        mean: per-channel offsets; defaults to the module-level ``mean``.

    Returns:
        The element-wise sum ``img + mean``.
    """
    restored = img + mean
    return restored


from itertools import islice

# Number of validation images to preview and the perturbation size used for
# the preview.
# NOTE(review): PREVIEW_EPSILON is far larger than EPSILON, presumably so the
# perturbation is visible to the eye -- confirm this is intended.
N_PREVIEW = 3
PREVIEW_EPSILON = 10

# Show the perturbed version of the first N_PREVIEW validation images.
# islice stops after N_PREVIEW batches, so no manual counter is needed and the
# progress bar total matches the number of images actually processed.
for imgs, sizes, _, _, _ in tqdm(islice(test_dataloader, N_PREVIEW), total=N_PREVIEW):
    # imgs is indexed as (batch, channel, height, width); sizes[0] is the
    # height entry of the per-image size the loader yields alongside it.
    H = imgs.shape[2]
    o_H = sizes[0]
    scale = o_H / H

    perturbed = trainer.faster_rcnn.input_perturbation(imgs, scale, epsilon=PREVIEW_EPSILON)
    vis_image(unnormalize_img(at.tonumpy(perturbed[0])))

In [None]:
from model.utils.bbox_tools import bbox_iou

def iou(bbox, label, gt_boxes, gt_labels, iou_thresh=0.5):
    """Match one predicted box against the ground truth of its image.

    Args:
        bbox: a single predicted box; it is wrapped into a one-element batch
            before being handed to ``bbox_iou``.
        label: class label predicted for ``bbox``.
        gt_boxes: ground-truth boxes of the image, as accepted by ``bbox_iou``.
        gt_labels: ground-truth labels, indexable in the same order as
            ``gt_boxes``.
        iou_thresh: accepted for call compatibility.
            NOTE(review): this value is not used -- any positive overlap
            counts. Confirm whether ``best_iou >= iou_thresh`` was intended.

    Returns:
        ``(best_iou, class_correct)``: the highest IoU between ``bbox`` and
        any ground-truth box, and whether that overlap is positive and the
        best-overlapping ground-truth box carries ``label``. For an image
        without ground-truth boxes the result is ``(0.0, False)``.
    """
    # np.max / np.argmax raise on an empty array, so an image without any
    # ground-truth box is reported as "no overlap, not correct".
    if len(gt_boxes) == 0:
        return 0.0, False

    overlaps = bbox_iou(np.array([bbox]), gt_boxes)
    # overlaps has a single row (one query box), so the flat argmax is the
    # index of the best-matching ground-truth box.
    best_idx = np.argmax(overlaps)
    best_iou = np.max(overlaps)
    best_label = gt_labels[best_idx]

    class_correct = bool(best_iou > 0 and best_label == label)
    return best_iou, class_correct

# Mahalanobis with input perturbation

In [None]:
IOU_THRESH = 0.5

# One entry per predicted box, accumulated over the whole validation set and
# consumed by the scatter plot further down.
scores = []
ious = []
class_corrects = []

for imgs, sizes, gt_bboxes, gt_labels, _ in tqdm(test_dataloader, total=opt.test_num):
    # The loader yields each size component as a one-element batch; unpack
    # them into plain Python numbers. A separate name is used so the loop
    # variable `sizes` is not overwritten.
    orig_size = [sizes[0][0].item(), sizes[1][0].item()]
    pred_bboxes, pred_labels, pred_scores = trainer.faster_rcnn.predict_mahalanobis(
        imgs, [orig_size], perturbation=EPSILON)

    # The ground truth is the same for every prediction of this image, so it
    # is converted to numpy once instead of once per predicted box.
    gt_bboxes_np = at.tonumpy(gt_bboxes)[0]
    gt_labels_np = at.tonumpy(gt_labels)[0]

    # Record best IoU and class correctness for each predicted box.
    for pred_bbox, pred_label, pred_score in zip(pred_bboxes[0], pred_labels[0], pred_scores[0]):
        if len(pred_bbox) == 0:
            # A box without coordinates cannot be matched; skip it instead of
            # passing it on to iou().
            print('Skipping empty predicted box')
            continue
        best_iou, class_correct = iou(pred_bbox,
                                      pred_label,
                                      gt_bboxes_np,
                                      gt_labels_np,
                                      iou_thresh=IOU_THRESH)
        scores.append(pred_score)
        ious.append(best_iou)
        class_corrects.append(class_correct)

In [None]:
# Scratch check: pack the entries of the most recent `sizes` into one tensor.
# np.asarray + np.atleast_1d accept both tensors and plain Python numbers, so
# the cell works with whichever form `sizes` currently holds instead of
# requiring every entry to have a .numpy() method.
[at.totensor(np.concatenate([np.atleast_1d(np.asarray(s)) for s in sizes]))]

In [None]:
import matplotlib.pyplot as plt
# Score from predict_mahalanobis vs. best IoU with ground truth, one point per
# predicted box; the colour encodes whether the predicted class was correct.
fig, ax = plt.subplots()
points = ax.scatter(scores, ious, c=class_corrects, cmap='viridis', s=5)
ax.set_xlabel("Mahalanobis Distance")
ax.set_ylabel("IoU with GT")
ax.set_title("Mahalanobis distance vs. IoU with ground truth")
# Without a colour bar the meaning of the two point colours is not readable
# from the figure alone.
fig.colorbar(points, ax=ax, label="Class correct (1) / incorrect (0)")
plt.show()