In [1]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [2]:
import os
from copy import copy
from random import randint

import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as np
import torch
from PIL import Image

from faster_rcnn.utils.datasets.voc.voc import VOCDetection


In [3]:
def imshow(inp, gt_boxes=None, predict_boxes=None):
    """Show an image with ground-truth (red) and predicted (green) boxes.

    Args:
        inp: either a torch tensor laid out as (C, H, W) with values in [0, 1],
            or a numpy array already laid out as (H, W, C). A ``uint8`` numpy
            array is rescaled from [0, 255] to [0, 1] before display.
        gt_boxes: optional iterable of ``[xmin, ymin, xmax, ymax, ...]`` boxes,
            drawn with a thick red outline. May be a numpy array.
        predict_boxes: optional iterable of boxes in the same format, drawn
            with a thin green outline.
    """
    if hasattr(inp, 'numpy'):
        # torch tensor: drop autograd / device state, then CHW -> HWC.
        if hasattr(inp, 'detach'):
            inp = inp.detach().cpu()
        image = inp.numpy().transpose((1, 2, 0))
    else:
        image = np.asarray(inp)
        if image.dtype == np.uint8:
            image = image / 255.0
    image = np.clip(image, 0, 1)

    fig, ax = plt.subplots(1, figsize=(20, 10))
    ax.imshow(image)

    # Explicit `is None` checks: box lists are often numpy arrays, whose
    # truth value is ambiguous.
    for boxes, linewidth, color in ((gt_boxes, 2, 'r'), (predict_boxes, 1, 'g')):
        if boxes is None:
            continue
        for box in boxes:
            # float() avoids unsigned wrap-around when subtracting uint16 coords.
            x0, y0, x1, y1 = (float(c) for c in box[:4])
            ax.add_patch(patches.Rectangle((x0, y0), x1 - x0, y1 - y0,
                                           linewidth=linewidth, edgecolor=color,
                                           facecolor='none'))

    plt.pause(0.001)  # pause a bit so that plots are updated


In [4]:
import torch.utils.data as data
from PIL import Image, ImageDraw
import os
import os.path
import sys
from torchvision import transforms
import numpy as np
import logging
try:
    from faster_rcnn.utils.datasets.voc.string_int_label_map_pb2 import StringIntLabelMap
except Exception as e:
    from string_int_label_map_pb2 import StringIntLabelMap

from google.protobuf import text_format

if sys.version_info[0] == 2:
    import xml.etree.cElementTree as ET
else:
    import xml.etree.ElementTree as ET
# Give the root logger a default stderr handler (basicConfig does nothing if
# one is already configured) and let this module's logger pass DEBUG records,
# e.g. the IOErrors logged when an image or annotation cannot be read.
logging.basicConfig()
logger = logging.getLogger(name=__name__)
logger.setLevel(level=logging.DEBUG)


class TransformVOCDetectionAnnotation(object):
    """Convert a parsed VOC annotation into ``[xmin, ymin, xmax, ymax, name]`` lists.

    Coordinates are returned as the raw strings read from the XML. Objects
    tagged ``<difficult>1</difficult>`` are dropped unless ``keep_difficult``
    is True. Objects without a ``<difficult>`` tag are always kept.
    """

    def __init__(self, keep_difficult=False):
        self.keep_difficult = keep_difficult

    def __call__(self, target):
        """Return one ``[xmin, ymin, xmax, ymax, name]`` list per kept ``<object>``."""
        res = []
        for obj in target.iter('object'):
            if not self.keep_difficult and self._is_difficult(obj):
                continue
            name = obj.find('name').text
            bb = obj.find('bndbox')
            bndbox = [bb.find(tag).text for tag in ('xmin', 'ymin', 'xmax', 'ymax')]
            res.append(bndbox + [name])
        return res

    @staticmethod
    def _is_difficult(obj):
        """True when ``obj`` has a ``<difficult>`` child whose text is ``1``."""
        difficult = obj.find('difficult')
        if difficult is None or difficult.text is None:
            return False
        return difficult.text.strip() == '1'


class VOCDetection(data.Dataset):
    """PASCAL-VOC-style detection dataset rooted at ``<root>/VOC2007``.

    Each item is a dict ("blob") with keys:
      * ``'tensor'``     -- the RGB image as a ``uint8`` numpy array (H, W, 3)
      * ``'im_info'``    -- ``[[H, W, scale]]``; the scale is currently always 1.0
      * ``'im_name'``    -- basename of the JPEG file
      * ``'gt_classes'`` -- ``int32`` array of class indices into ``self.classes``
      * ``'boxes'``      -- float array of ``[xmin, ymin, xmax, ymax]`` rows,
                            multiplied by the scale in ``im_info``
      * ``'target'``     -- only present when ``target_transform`` was given;
                            its output for the parsed XML annotation

    ``__getitem__`` returns ``None`` when the image/annotation cannot be read,
    the annotation has no objects, or a coordinate is not numeric.
    """

    def __init__(self, root, image_set, transform=None, target_transform=None):
        self.root = root
        self.image_set = image_set
        self.target_transform = target_transform

        dataset_name = 'VOC2007'
        self._annopath = os.path.join(
            self.root, dataset_name, 'Annotations', '%s.xml')
        self._imgpath = os.path.join(
            self.root, dataset_name, 'JPEGImages', '%s.jpg')
        self._imgsetpath = os.path.join(
            self.root, dataset_name, 'ImageSets', 'Main', '%s.txt')
        self._label_map_path = os.path.join(
            self.root, dataset_name, 'pascal_label_map.pbtxt')

        with open(self._label_map_path) as f:
            label_map_string = f.read()
        label_map = StringIntLabelMap()
        try:
            # Try protobuf text format first, then fall back to binary.
            text_format.Merge(label_map_string, label_map)
        except text_format.ParseError:
            label_map.ParseFromString(label_map_string)

        # Index 0 is reserved for the background; real classes get 1..N in the
        # order they appear in the label map.
        # NOTE(review): the label-map items' own ``id`` fields are not used.
        self.label_map_dict = {'__background__': 0}
        self.classes = ['__background__']
        for class_id, item in enumerate(label_map.item, 1):
            self.label_map_dict[item.name] = class_id
            self.classes.append(item.name)

        # NOTE(review): ``self.transform`` is stored but never applied in
        # ``__getitem__``; the default is an empty (identity) Compose.
        if transform is not None:
            self.transform = transform
        else:
            self.transform = transforms.Compose([])

        with open(self._imgsetpath % self.image_set) as f:
            self.ids = self._parse_image_ids(f)

    @staticmethod
    def _parse_image_ids(lines):
        """Return the image ids listed in an ``ImageSets/Main`` file.

        Accepts both the plain ``"<id>"`` layout and the class-specific
        ``"<id> <flag>"`` layout (the id is the first token either way).
        Blank lines and lines with more than two tokens are skipped.
        """
        ids = []
        for line in lines:
            tokens = line.split()
            if len(tokens) in (1, 2):
                ids.append(tokens[0])
        return ids

    def _parse_boxes(self, target):
        """Extract ground-truth boxes and class indices from a VOC annotation.

        Returns ``(gt_boxes, gt_classes)`` as a ``uint16`` (N, 4) array and an
        ``int32`` (N,) array, or ``None`` when the annotation has no objects or
        a coordinate is missing/non-numeric. Raises ``KeyError`` for an object
        whose name is not in the label map.
        """
        boxes = []
        classes = []
        for obj in target.iter('object'):
            class_index = self.label_map_dict[obj.find('name').text]
            bb = obj.find('bndbox')
            try:
                # Parse via float so coordinates written as "48.5" are
                # truncated instead of making the whole image unusable.
                coords = [float(bb.find(tag).text)
                          for tag in ('xmin', 'ymin', 'xmax', 'ymax')]
            except (TypeError, ValueError) as e:
                logger.debug(e)
                return None
            boxes.append(coords)
            classes.append(class_index)
        if not boxes:
            return None
        return (np.array(boxes).astype(np.uint16),
                np.array(classes, dtype=np.int32))

    def __getitem__(self, index):
        img_id = self.ids[index]

        try:
            target = ET.parse(self._annopath % img_id).getroot()
            img = Image.open(self._imgpath % img_id).convert('RGB')
        except IOError as e:
            logger.debug(e)
            return None

        parsed = self._parse_boxes(target)
        if parsed is None:
            return None
        gt_boxes, gt_classes = parsed

        img = np.asarray(img, dtype=np.uint8)
        height, width = img.shape[:2]
        im_info = np.array([[float(height), float(width), 1.]])

        blobs = {
            'tensor': img,
            'im_info': im_info,
            'im_name': os.path.basename(self._imgpath % img_id),
            'gt_classes': gt_classes,
            'boxes': gt_boxes * im_info[0][2],
        }

        # Previously the transformed target was computed and thrown away.
        if self.target_transform is not None:
            blobs['target'] = self.target_transform(target)

        return blobs

    def __len__(self):
        return len(self.ids)

    def show(self, index):
        """Open item ``index`` in an image viewer with its boxes and class names drawn."""
        blobs = self.__getitem__(index)
        if blobs is None:
            logger.debug('nothing to show for index %s', index)
            return
        img = Image.fromarray(blobs['tensor'])
        draw = ImageDraw.Draw(img)
        for box, class_index in zip(blobs['boxes'], blobs['gt_classes']):
            coords = tuple(float(c) for c in box)
            draw.rectangle(coords, outline=(255, 0, 0))
            draw.text(coords[0:2], self.classes[class_index], fill=(0, 255, 0))
        img.show()




In [7]:
# Dataset root (expects <root>/VOC2007/...). Override with the VOC_ROOT
# environment variable instead of editing the notebook.
root = os.environ.get('VOC_ROOT', '/data')
ds = VOCDetection(root, 'train')
print(len(ds))



288


In [9]:
# Class index -> class name (index 0 is the background class).
id_dict = dict(enumerate(ds.classes))
print(len(id_dict))
# %-formatting prints "<index> <name>" identically on Python 2 and 3.
for class_index, class_name in sorted(id_dict.items()):
    print('%s %s' % (class_index, class_name))

2
0 __background__
1 671


In [10]:
import json

js = json.dumps(dict(enumerate(ds.classes)))

# Write the class-index -> name mapping to id.json. Mode 'w' creates the file
# if it does not exist and overwrites it otherwise; append mode would add one
# more JSON object per run and leave a file that json.load cannot parse.
with open('id.json', 'w') as fp:
    fp.write(js)

In [12]:
# Crop every ground-truth box out of each image into
# <crop_output_path>/<class name>/, plus one copy of the full image.
crop_output_path = '/data/crop/'

for key, item in enumerate(ds):
    if key % 1000 == 0:
        print(key)  # progress

    # The dataset yields None for unreadable images or box-less annotations.
    if item is None:
        continue
    im_data = item['tensor']
    boxes = item['boxes']
    gt_classes = item['gt_classes']
    inp = im_data
    im = Image.fromarray(inp)
    # Deterministic names (image stem + box index): random names collided
    # and silently overwrote earlier crops.
    stem = os.path.splitext(item['im_name'])[0]
    saved_dir = None  # class directory of the last crop actually saved
    for box_index, (box, class_index) in enumerate(zip(boxes, gt_classes)):
        try:
            output_dir = os.path.join(crop_output_path, id_dict[class_index])
            if not os.path.isdir(output_dir):
                print(output_dir)
                os.makedirs(output_dir)
            # PIL expects an integer (left, upper, right, lower) tuple.
            crop_box = tuple(int(c) for c in box)
            im.crop(crop_box).save(os.path.join(
                output_dir, '%s_%d.jpg' % (stem, box_index)))
            saved_dir = output_dir
        except Exception as e:
            # Best effort: report the failing box and continue with the rest.
            print(e)
    # Store the full image next to its crops, only if at least one was saved.
    if saved_dir is not None:
        im.save(os.path.join(saved_dir, 'origin_%s.jpg' % stem))

0
/data/crop/671


In [None]:
# Scratch: rebuild a PIL image from `inp` (the last image array left over from
# the cropping loop), cast to 8-bit unsigned ints first.
im = Image.fromarray(inp.astype(np.uint8))

In [None]:
# Preview the first sample. imshow() wants a (C, H, W) float tensor in [0, 1],
# while the dataset yields an (H, W, 3) uint8 array, so convert it here.
sample = ds[0]
assert sample is not None, 'ds[0] has no readable image or boxes'
image_tensor = torch.from_numpy(
    sample['tensor'].transpose(2, 0, 1).astype(np.float32) / 255.0)
imshow(image_tensor, sample['boxes'])

In [None]:
                # NOTE(review): this cell was a stray, over-indented fragment of
# an older cropping loop (IndentationError on Run All, undefined names).
# It is kept as a reusable, self-contained helper instead.
def save_crop(copy_image, left, top, right, bottom, image_file, output_dir):
    """Crop (left, top, right, bottom) from ``copy_image`` and save it in ``output_dir``.

    The file is named ``<n><image_file>`` with the smallest n >= 1 whose path
    does not exist yet, so earlier crops are never overwritten.
    Returns the cropped image.
    """
    crop_img = copy_image.crop((left, top, right, bottom))
    n = 1
    path = os.path.join(output_dir, str(n) + image_file)
    while os.path.exists(path):
        n += 1
        path = os.path.join(output_dir, str(n) + image_file)
    crop_img.save(path)
    return crop_img