In [1]:
import os
import xml.etree.ElementTree as ET
import torch
import torchvision
import numpy as np
from collections import defaultdict
import cv2

In [2]:
# Run on the GPU when one is available, otherwise fall back to the CPU so the
# notebook still executes on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
# Object category names used by this notebook. A class id is the position of
# the name in this tuple, so the order must not change.
VOC_CLASSES = (
    'aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
    'bus', 'car', 'cat', 'chair', 'cow',
    'diningtable', 'dog', 'horse', 'motorbike', 'person',
    'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor',
)

In [4]:
def nms(bboxes, scores, threshold=0.5):
    '''
    Greedy non-maximum suppression.

    Repeatedly keeps the highest-scoring remaining box and discards every
    other remaining box whose IoU with it is greater than `threshold`.

    bboxes(tensor) [N, 4] as [x1, y1, x2, y2]
    scores(tensor) [N, ]
    threshold(float) boxes with IoU <= threshold survive a round
    return (LongTensor) indices of the kept boxes, best score first
    '''
    left, top, right, bottom = (bboxes[:, col] for col in range(4))
    box_areas = (right - left) * (bottom - top)

    # Indices of all boxes, best score first.
    candidates = scores.sort(0, descending=True)[1]
    kept = []
    while candidates.numel() > 0:
        # A single survivor of the previous round is a 0-d tensor (see the
        # squeeze() below), so it has to be read with item() instead of [0].
        best = candidates.item() if candidates.dim() == 0 else candidates[0]
        kept.append(best)
        if candidates.numel() == 1:
            break

        others = candidates[1:]
        # Intersection rectangle between the best box and each other box.
        inter_left = left[others].clamp(min=left[best])
        inter_top = top[others].clamp(min=top[best])
        inter_right = right[others].clamp(max=right[best])
        inter_bottom = bottom[others].clamp(max=bottom[best])
        inter_w = (inter_right - inter_left).clamp(min=0)
        inter_h = (inter_bottom - inter_top).clamp(min=0)
        overlap = inter_w * inter_h
        iou = overlap / (box_areas[best] + box_areas[others] - overlap)

        survivors = (iou <= threshold).nonzero().squeeze()
        if survivors.numel() == 0:
            break
        # +1 because `iou` was computed over candidates[1:].
        candidates = candidates[survivors + 1]
    return torch.LongTensor(kept)

In [5]:
def decoder(pred):
    '''
    Turn one raw network output into NMS-filtered detections.

    pred (tensor) 1 x 7 x 7 x 30. Per grid cell the 30 values are read as two
        boxes of [cx, cy, w, h, confidence] (channels 0-4 and 5-9) followed by
        the class scores (channels 10 onwards). cx/cy are offsets inside the
        cell; w/h are fractions of the whole image.
    return (tensor) box[[x1, y1, x2, y2]] label[...] confidence[...]
        Box coordinates are fractions of the image size. All three tensors are
        empty when no box is selected.

    A box is selected when its confidence is > 0.9 or equals the global
    maximum confidence; at most one box (the more confident one) is kept per
    cell. The input tensor is not modified.
    '''
    boxes_per_cell = 2
    values_per_box = 5
    class_offset = boxes_per_cell * values_per_box

    boxes = []
    cls_indexs = []
    probs = []

    pred = pred.detach().squeeze(0)
    grid_num = pred.size(0)
    cell_size = 1. / grid_num

    # Confidence of both boxes of every cell: grid x grid x 2.
    contain = torch.stack((pred[:, :, 4], pred[:, :, 9]), 2)
    mask = (contain > 0.9) | (contain == contain.max())
    # Index of the LESS confident box of each cell. Taking the min over the
    # confidences (not over the boolean mask) guarantees that, when both boxes
    # pass the threshold, the weaker one is the one that gets discarded.
    _, weaker = torch.min(contain, 2)

    for i in range(grid_num):
        for j in range(grid_num):
            mask[i, j, weaker[i, j]] = 0
            for b in range(boxes_per_cell):
                if not mask[i, j, b]:
                    continue
                start = b * values_per_box
                box = pred[i, j, start:start + 4]
                # Cell-relative centre -> image-relative centre. Computed out
                # of place so the caller's prediction tensor is left untouched.
                center = box[:2] * cell_size + torch.FloatTensor([j, i]) * cell_size
                half_size = 0.5 * box[2:]
                box_xy = torch.cat((center - half_size, center + half_size))
                _, cls_index = torch.max(pred[i, j, class_offset:], 0)
                boxes.append(box_xy.view(1, 4))
                cls_indexs.append(cls_index.reshape(1))
                probs.append(pred[i, j, start + 4].reshape(1))

    if not boxes:
        # Nothing selected (e.g. NaN confidences): torch.cat would raise on an
        # empty list, so hand back empty results instead.
        return torch.zeros((0, 4)), torch.zeros(0, dtype=torch.long), torch.zeros(0)

    boxes = torch.cat(boxes, 0)
    probs = torch.cat(probs, 0)
    cls_indexs = torch.cat(cls_indexs, 0)
    keep = nms(boxes, probs)
    return boxes[keep], cls_indexs[keep], probs[keep]

In [8]:
def _parse_voc_objects(xml_path):
    '''Return [(class_name, [xmin, ymin, xmax, ymax]), ...] read from one annotation XML file.'''
    objects = []
    for obj in ET.parse(xml_path).findall('object'):
        bndbox = obj.find('bndbox')
        # int(float(...)) also accepts coordinates written as "45.0".
        coords = [int(float(bndbox.find(tag).text))
                  for tag in ('xmin', 'ymin', 'xmax', 'ymax')]
        objects.append((obj.find('name').text, coords))
    return objects


def _write_lines_per_image(out_dir, lines_by_image):
    '''Write one "<image stem>.txt" per entry into out_dir, replacing any existing file.'''
    os.makedirs(out_dir, exist_ok=True)
    for stem, lines in lines_by_image.items():
        with open(os.path.join(out_dir, stem + '.txt'), 'w') as f:
            f.writelines(lines)


# NOTE: this function shadows the builtin eval(); the name is kept because
# other cells call it.
def eval(root):
    '''
    Run the detector over a dataset and dump ground truth and detections as text files.

    root (str) dataset directory containing 'JPEGImages' and 'Annotations'.
        Each image is matched with the annotation file of the same base name
        ('<name>.xml'); images without such a file are skipped.

    Loads the weights from 'YOLOv1_Resnet50.pth' and writes, per image:
      mAP/input/ground-truth/<name>.txt      lines "<class> <xmin> <ymin> <xmax> <ymax>"
      mAP/input/detection-results/<name>.txt lines "<class> <confidence> <x1> <y1> <x2> <y2>"
    Existing files for the same images are overwritten, so re-running the cell
    does not duplicate entries.

    Raises IOError when an image file cannot be decoded.
    '''
    image_dir = os.path.join(root, 'JPEGImages')
    annotation_dir = os.path.join(root, 'Annotations')

    # (image file name, class id) -> list of [xmin, ymin, xmax, ymax]
    target = defaultdict(list)
    image_list = []
    for image_name in sorted(os.listdir(image_dir)):
        stem = os.path.splitext(image_name)[0]
        xml_path = os.path.join(annotation_dir, stem + '.xml')
        if not os.path.isfile(xml_path):
            # Pairing by base name (instead of by sorted position) keeps images
            # and annotations aligned even when the two directories differ.
            continue
        image_list.append(image_name)
        for name, coords in _parse_voc_objects(xml_path):
            target[(image_name, VOC_CLASSES.index(name))].append(coords)

    model = torchvision.models.resnet50(pretrained=False)
    # 1470 = 7 * 7 * 30, the flattened prediction grid reshaped below.
    model.fc = torch.nn.Sequential(torch.nn.Linear(model.fc.in_features, 1470),
                                   torch.nn.Sigmoid())
    # map_location lets a checkpoint saved on another device load onto `device`.
    model.load_state_dict(torch.load('YOLOv1_Resnet50.pth', map_location=device))
    model.to(device)

    to_tensor = torchvision.transforms.ToTensor()

    print('Start evaluation!')
    model.eval()
    # class name -> list of [image file name, confidence, x1, y1, x2, y2]
    preds = defaultdict(list)
    with torch.no_grad():  # inference only: do not build the autograd graph
        for image_name in image_list:
            image_file = os.path.join(image_dir, image_name)
            image = cv2.imread(image_file)
            if image is None:
                raise IOError('cannot read image: ' + image_file)
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            h, w = image.shape[:2]

            img_to_model = to_tensor(cv2.resize(image, (448, 448)))
            img_to_model = img_to_model.unsqueeze(0).to(device)
            pred = model(img_to_model).view(-1, 7, 7, 30).cpu()

            boxes, cls_indexs, probs = decoder(pred)
            for box, cls_index, prob in zip(boxes, cls_indexs, probs):
                # Boxes are fractions of the image size; scale to pixels.
                x1 = int(box[0] * w)
                x2 = int(box[2] * w)
                y1 = int(box[1] * h)
                y2 = int(box[3] * h)
                preds[VOC_CLASSES[int(cls_index)]].append(
                    [image_name, float(prob), x1, y1, x2, y2])

    ground_truth_lines = defaultdict(list)
    for (image_name, cls), gt_boxes in target.items():
        stem = os.path.splitext(image_name)[0]
        for xmin, ymin, xmax, ymax in gt_boxes:
            ground_truth_lines[stem].append(
                '{} {} {} {} {}\n'.format(VOC_CLASSES[cls], xmin, ymin, xmax, ymax))

    detection_lines = defaultdict(list)
    for class_name, detections in preds.items():
        for image_name, prob, x1, y1, x2, y2 in detections:
            stem = os.path.splitext(image_name)[0]
            detection_lines[stem].append(
                '{} {} {} {} {} {}\n'.format(class_name, prob, x1, y1, x2, y2))

    _write_lines_per_image('mAP/input/ground-truth', ground_truth_lines)
    _write_lines_per_image('mAP/input/detection-results', detection_lines)
    print('Finish evaluation!')

In [9]:
# Run the evaluation; the result files are written below 'mAP/input/'.
# NOTE(review): the directory name suggests the training split - confirm this
# is the set the model is meant to be scored on.
eval(root='dataset/VOC2007train')

Start evaluation!
Finish evaluation!
