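This notebook localizes objects with a sum filter: it classifies an image, builds the class activation map (CAM) heatmap for the top-scoring class, applies a sum filter to that heatmap, and thresholds the result to obtain a bounding box, which is drawn next to the ground-truth boxes.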

In [1]:
import torch
import os
from Model import Model
from torchvision import transforms
from torch.autograd import Variable
from torch.nn import functional as F
import cv2
import numpy as np
from PIL import Image
import xml.etree.ElementTree as ET
from utils import find_central_point, sum_filter 
from utils import find_corner, returnHeatmap
# The 20 PASCAL VOC object categories. Their order must match the classifier's
# output indices, because predictions are decoded as classes[idx].
classes = np.array((
    'aeroplane bicycle bird boat bottle bus car cat chair cow '
    'diningtable dog horse motorbike person pottedplant sheep sofa '
    'train tvmonitor'
).split())

# Input image and its VOC-style XML annotation (ground-truth boxes).
image_path = './example_images/000071.jpg'
annotation_path = './example_images/000071.xml'

In [2]:
def parse_rec(filename):
    """Extract the object annotations from a PASCAL VOC-style XML file.

    Args:
        filename: path or file object accepted by ``xml.etree.ElementTree.parse``.

    Returns:
        A list with one dict per ``<object>`` element, keys in this order:
        ``'name'`` (str), ``'pose'`` (str, or None when the tag is absent),
        ``'truncated'`` and ``'difficult'`` (int, 0 when the tag is absent),
        ``'bbox'`` (``[xmin, ymin, xmax, ymax]`` as ints).

    Raises:
        xml.etree.ElementTree.ParseError: if the file is not well-formed XML.
        AttributeError: if a required tag (``name``, ``bndbox`` or one of its
            coordinate tags) is missing.
    """
    def _optional_text(node, tag, default=None):
        # pose/truncated/difficult are optional in many VOC-style exporters.
        child = node.find(tag)
        if child is None or child.text is None:
            return default
        return child.text

    tree = ET.parse(filename)
    objects = []
    for obj in tree.findall('object'):
        bbox = obj.find('bndbox')
        objects.append({
            'name': obj.find('name').text,
            'pose': _optional_text(obj, 'pose'),
            'truncated': int(_optional_text(obj, 'truncated', '0')),
            'difficult': int(_optional_text(obj, 'difficult', '0')),
            # Some annotation tools write sub-pixel coordinates ("273.5");
            # int(float(...)) accepts both integer and decimal text.
            'bbox': [int(float(bbox.find(tag).text))
                     for tag in ('xmin', 'ymin', 'xmax', 'ymax')],
        })

    return objects

In [5]:
def _load_cam_model(weights):
    """Build ``Model``, load ``weights`` and hook its ``features`` submodule.

    Returns:
        ``(net, feature_maps, weight_softmax)``. ``feature_maps`` is a list to
        which the forward hook appends the ``features`` output (as a numpy
        array) on every forward pass. ``weight_softmax`` is the squeezed
        second-to-last parameter of the model -- presumably the final
        classifier's weight used to build the CAM (TODO confirm against Model).
    """
    net = Model()
    # map_location='cpu' lets checkpoints saved on a GPU load on CPU-only hosts.
    net.load_state_dict(torch.load(weights, map_location='cpu'))
    net.eval()

    feature_maps = []

    def hook_feature(_module, _inputs, output):
        # Keep the activation of the hooked `features` module for the heatmap.
        feature_maps.append(output.detach().cpu().numpy())

    features_module = net._modules.get('features')
    if features_module is None:
        raise AttributeError("Model has no 'features' submodule to hook")
    features_module.register_forward_hook(hook_feature)

    weight_softmax = np.squeeze(list(net.parameters())[-2].detach().cpu().numpy())
    return net, feature_maps, weight_softmax


def _classify_image(net, image_file):
    """Run ``net`` on ``image_file``; return ``(probs, idx)`` sorted by descending probability."""
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    # convert('RGB') guarantees the 3 channels Normalize expects (grayscale,
    # palette or RGBA inputs would otherwise fail); `with` closes the file.
    with Image.open(image_file) as img_pil:
        img_tensor = preprocess(img_pil.convert('RGB'))
    # Inference only: no autograd graph is needed (replaces deprecated Variable).
    with torch.no_grad():
        logit = net(img_tensor.unsqueeze(0))
    h_x = F.softmax(logit, dim=1).squeeze()
    probs, idx = h_x.sort(0, True)
    return probs.numpy(), idx.numpy()


def _write_image(path, image):
    """Write ``image`` with OpenCV, raising instead of silently returning False."""
    if not cv2.imwrite(path, image):
        raise OSError('cv2.imwrite could not write {}'.format(path))


def objectdetection(input, threshod=130, annotation=None,
                    weights='weights/epoch204densenet.pth', out_dir='tmp'):
    """Classify an image and localize the top-1 class from its sum-filtered CAM.

    Pipeline: classify the image with ``Model``, then build the heatmap for
    the top-1 class with ``returnHeatmap`` and resize it to the image size.
    Pass the heatmap through ``sum_filter``, take the box corners from
    ``find_corner`` at ``threshod``, and save heatmap visualizations. Finally,
    draw the predicted box and label and the ground-truth boxes over the image.

    Output files in ``out_dir`` (created if missing), prefixed with the
    predicted class name: ``_heatmap_ori.jpg``, ``_heatmap_sum.jpg``,
    ``_heatmap_only_ori.jpg``, ``_heatmap_only_sum.jpg`` and ``.jpg`` (boxes).

    Args:
        input: path of the image (name kept for backward compatibility).
        threshod: threshold passed to ``find_corner`` on the sum-filtered
            heatmap (spelling kept so existing keyword calls still work).
        annotation: VOC XML file with the ground truth; defaults to the
            module-level ``annotation_path``.
        weights: path of the ``Model`` state dict.
        out_dir: directory that receives the output images.

    Returns:
        ``([max_x, max_y], resized_heat_map)``: the peak location reported by
        ``sum_filter`` and the heatmap resized to the image size.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``input``.
        OSError: if an output image cannot be written.
    """
    image_file = input
    if annotation is None:
        annotation = annotation_path

    # Read with OpenCV first so an unreadable image fails before the model loads.
    img = cv2.imread(image_file)
    if img is None:
        raise FileNotFoundError('OpenCV could not read image {}'.format(image_file))
    height, width = img.shape[:2]

    net, feature_maps, weight_softmax = _load_cam_model(weights)
    probs, idx = _classify_image(net, image_file)
    top_idx = idx[0]
    top_class = classes[top_idx]
    print('{:.3f} -> {}'.format(probs[0], top_class))
    print('index of class', top_idx)

    # Heatmap for the top-1 class, resized to the image (cv2.resize takes (w, h)).
    heatmap = returnHeatmap(feature_maps[0], weight_softmax, [top_idx])
    resized_heat_map = cv2.resize(heatmap[0], (width, height))
    sum_heat_map, maxi_loc = sum_filter(resized_heat_map)
    # Peak location from sum_filter: not used for the box, but returned.
    # NOTE(review): its axis order (row/col vs x/y) is defined by sum_filter.
    max_x, max_y = maxi_loc

    # Top-left and bottom-right corners. The two coordinates of each corner are
    # swapped before drawing, presumably converting find_corner's (row, col)
    # into cv2's (x, y) -- confirm against find_corner.
    first_corner, second_corner = find_corner(sum_heat_map, threshod)
    top_left = [int(first_corner[1]), int(first_corner[0])]
    bottom_right = [int(second_corner[1]), int(second_corner[0])]
    print('corner', top_left, bottom_right)

    os.makedirs(out_dir, exist_ok=True)
    prefix = os.path.join(out_dir, top_class)
    color_ori = cv2.applyColorMap(resized_heat_map, cv2.COLORMAP_JET)
    color_sum = cv2.applyColorMap(sum_heat_map, cv2.COLORMAP_JET)
    _write_image(prefix + '_heatmap_ori.jpg', color_ori * 0.3 + img * 0.5)
    _write_image(prefix + '_heatmap_sum.jpg', color_sum * 0.3 + img * 0.5)
    _write_image(prefix + '_heatmap_only_ori.jpg', resized_heat_map)
    _write_image(prefix + '_heatmap_only_sum.jpg', sum_heat_map)

    # Predicted box (white) and score label (green) on a black canvas.
    pred_layer = np.zeros((height, width, 3))
    cv2.rectangle(pred_layer, tuple(top_left), tuple(bottom_right), (255, 255, 255))
    cv2.putText(pred_layer, '{:.3f}, {}'.format(probs[0], top_class),
                (top_left[0], top_left[1] + 10), cv2.FONT_HERSHEY_SIMPLEX,
                0.5, (0, 255, 0), 0)

    # Ground-truth boxes, drawn with color (0, 0, 255) (red in OpenCV's BGR order).
    anno = parse_rec(annotation)
    print('annotation', anno)
    gt_layer = np.zeros((height, width, 3))
    for item in anno:
        xmin, ymin, xmax, ymax = (int(v) for v in item['bbox'])
        cv2.rectangle(gt_layer, (xmin, ymin), (xmax, ymax), (0, 0, 255))

    # Overlay both box layers on the image by addition.
    _write_image(prefix + '.jpg', pred_layer + img + gt_layer)

    print('end')
    return [max_x, max_y], resized_heat_map

In [6]:
if __name__ == '__main__':
    # Run the full pipeline on the example image; the peak location and the
    # resized heatmap stay in the kernel namespace for later cells.
    detection_result = objectdetection(image_path)
    central_point, heat_map = detection_result

0.763 -> car
index of class 6
(375, 500)
corner [56, 43] [453, 339]
annotation [{'name': 'car', 'pose': 'Left', 'truncated': 0, 'difficult': 0, 'bbox': [61, 75, 443, 274]}]
end
