In [1]:
import numpy as np
import random
import torch
import torch.utils.data as data
from PIL import Image
from PIL import Image, ImageFont, ImageDraw, ImageEnhance

import os
import json
import os.path
import colorsys


inf = float('inf')
nan = float('nan')

In [2]:
data_dir = "/home/aravind/dataset/"
ann_dir = data_dir + "annotations/panoptic/"

train_img_dir = data_dir + "train2017/"
train_seg_dir = ann_dir + "panoptic_train2017/"
train_ann_json = ann_dir + "panoptic_train2017.json"

val_img_dir = data_dir + "val2017/"
val_seg_dir = ann_dir + "panoptic_val2017/"
val_ann_json = ann_dir + "panoptic_val2017.json"

In [3]:
with open(val_ann_json,"r") as f:
    val_ann = json.load(f)
# with open(train_ann_json,"r") as f:
#     train_ann = json.load(f)



In [4]:
# config to train
# TODO: check Config is correct


class ProposalConfig():
    NAME = "InSegm"
    GPU_COUNT = 1
    # online training
    IMAGES_PER_GPU = 16
    STEPS_PER_EPOCH = 100
    NUM_WORKERS = 16
    PIN_MEMORY = True
    VALIDATION_STEPS = 20

    CAT_NAMES = ['BG'] + [
        'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
        'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign',
        'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep',
        'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella',
        'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard',
        'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard',
        'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork',
        'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange',
        'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair',
        'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv',
        'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave',
        'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase',
        'scissors', 'teddy bear', 'hair drier', 'toothbrush', 'banner',
        'blanket', 'bridge', 'cardboard', 'counter', 'curtain', 'door-stuff',
        'floor-wood', 'flower', 'fruit', 'gravel', 'house', 'light',
        'mirror-stuff', 'net', 'pillow', 'platform', 'playingfield',
        'railroad', 'river', 'road', 'roof', 'sand', 'sea', 'shelf', 'snow',
        'stairs', 'tent', 'towel', 'wall-brick', 'wall-stone', 'wall-tile',
        'wall-wood', 'water-other', 'window-blind', 'window-other',
        'tree-merged', 'fence-merged', 'ceiling-merged', 'sky-other-merged',
        'cabinet-merged', 'table-merged', 'floor-other-merged',
        'pavement-merged', 'mountain-merged', 'grass-merged', 'dirt-merged',
        'paper-merged', 'food-other-merged', 'building-other-merged',
        'rock-merged', 'wall-other-merged', 'rug-merged'
    ]
    CAT_IDS = [0] + [
        1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21,
        22, 23, 24, 25, 27, 28, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42,
        43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61,
        62, 63, 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84,
        85, 86, 87, 88, 89, 90, 92, 93, 95, 100, 107, 109, 112, 118, 119, 122,
        125, 128, 130, 133, 138, 141, 144, 145, 147, 148, 149, 151, 154, 155,
        156, 159, 161, 166, 168, 171, 175, 176, 177, 178, 180, 181, 184, 185,
        186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199,
        200
    ]
    IGNORE_CAT_NAMES = ['BG']+['bear']
    MEAN_PIXEL = np.array(
        [0.485, 0.456, 0.406], dtype=np.float32).reshape(1, 1, -1)
    STD_PIXEL = np.array(
        [0.229, 0.224, 0.225], dtype=np.float32).reshape(1, 1, -1)
    GRID_SHAPE = 14
    IMPULSE_SHAPE = (32, 32)
    MIN_AREA = 32*32
    MIN_INTERSECTION = 1
    
    def __init__(self):
        self.WIDTH = 32 * self.GRID_SHAPE
        self.HEIGHT = 32 * self.GRID_SHAPE
        self.BATCH_SIZE = self.IMAGES_PER_GPU * self.GPU_COUNT
        self.IMAGE_SHAPE = (self.WIDTH, self.HEIGHT, 3)
        # 133 + 1 in panoptic
        self.NUM_CATS = len(self.CAT_NAMES)
        self.IGNORE_CAT_IDS = [self.CAT_NAMES.index(c) for c in self.IGNORE_CAT_NAMES]
    def display(self):
        """Display Configuration values."""
        print("\nConfigurations:")
        for a in dir(self):
            if not a.startswith("__") and not callable(getattr(self, a)):
                print("{:30} {}".format(a, getattr(self, a)))
        print("\n")

In [5]:
class CocoDetection(data.Dataset):
    def __init__(self, img_dir, seg_dir, ann, config):
        self.img_dir = img_dir
        self.seg_dir = seg_dir
        self.coco_data = self.index_annotations(ann)
        self.config = config
        self.catMap = self.build_cat_map()

    def index_annotations(self, ann):
        # create map with coco image index as key
        d = {}
        for i in ann['annotations']:
            coco_index = i['image_id']
            d[coco_index] = {'segments_info': i['segments_info'],
                             'segments_file': i['file_name'],
                             'image_id': i['image_id']}
        for i in ann['images']:
            coco_index = i['id']
            image_file = i['file_name']
            d[coco_index]['image_file'] = image_file

        return list(d.values())

    # coco category ids remapped to contigous range(133+1)
    def build_cat_map(self):
        config = self.config
        coco_cat_ids = config.CAT_IDS
        catMap = {}
        for i in range(config.NUM_CATS):
            catMap[coco_cat_ids[i]] = i
        return catMap

    def __getitem__(self, index):

        # 0. read coco data as is
        data = self.load_data(index)

        # 1. remove unwanted class data
        # 2. fixed resolution. 
        # 3. split stuff islands into different instances
        # 4. Data Augmentation: skipped for now
        data = self.standardize_data(*data)
        
        # 4. Target generation:
        return self.generate_targets(*data)

    def load_data(self, index):
        coco_data = self.coco_data
        config = self.config

        ann = coco_data[index]
        image_id = ann['image_id']
        segments_info = ann['segments_info']
        segments_file = ann['segments_file']
        image_file = ann['image_file']

        img = Image.open(os.path.join(self.img_dir, image_file)).convert('RGB')
        img = np.array(img)
        
        instance_masks = []
        cat_ids = []

        coco_seg = Image.open(os.path.join(
            self.seg_dir, segments_file)).convert('RGB')
        coco_seg = np.array(coco_seg, dtype=np.uint8)
        seg_id = self.rgb2id(coco_seg)
        
        for s in segments_info:
            mask = np.where(seg_id == s['id'], 1, 0)
            if s['iscrowd'] == 0 and s['area'] > config.MIN_AREA:
                instance_masks.append(mask)
                print(s['area'])
                cat_ids.append(self.catMap[s['category_id']])

        # add bg class, bg mask for unannotated regions
        cat_ids.append(0)
        cat_ids = np.array(cat_ids)
        instance_masks = np.array(instance_masks)
        bg_mask = np.where(np.sum(instance_masks,0) == 0, 1 ,0)
        bg_mask = bg_mask[None,:]
        instance_masks = np.concatenate([instance_masks, bg_mask], 0)
        
        return img, instance_masks, cat_ids
    
    def standardize_data(self, img, instance_masks, cat_ids):
        instance_masks, cat_ids = self.remove_ignored_cats(instance_masks, cat_ids)
        instance_masks, cat_ids = self.split_stuff_islands(instance_masks, cat_ids)
        img, instance_masks = self.resize_data(img, instance_masks)

        # img, instance_masks, cat_ids = self.data_augment(img, instance_masks, cat_ids)
        return img, instance_masks, cat_ids
        
    def generate_targets(self, img, instance_masks, cat_ids):
        return img, instance_masks, cat_ids
    
    def rgb2id(self, color):
        return color[:, :, 0] + 256 * color[:, :, 1] + 256 * 256 * color[:, :, 2]
    
    def remove_ignored_cats(self, instance_masks, cat_ids):
        config = self.config
        ignore_cat_ids = np.array(config.IGNORE_CAT_IDS)
        
        # get indices of ignored cats in the current 
        # set of cat_ids instances we need to ignore
        ignore_idx = np.nonzero(ignore_cat_ids[:,None] == cat_ids)[1]
        
        # delete these masks, cat_ids
        instance_masks = np.delete(instance_masks, ignore_idx, 0)
        cat_ids = np.delete(cat_ids, ignore_idx, 0)
        
        return instance_masks, cat_ids
    
    def split_stuff_islands(self, instance_masks, cat_ids):
        return instance_masks, cat_ids
    
    def resize_data(self, img, instance_masks):
        config = self.config
        
        w, h = config.WIDTH, config.HEIGHT
        img = self.resize_image(img, (w, h), "RGB")
        instance_masks = np.array(
            [self.resize_image(m, (w, h), "L") for m in instance_masks])
        
        return img, instance_masks
    
    def resize_image(self, img, size, mode):
        interpolation = {"RGB": Image.BICUBIC, "L": Image.NEAREST}[mode]
        img_obj = Image.fromarray(img.astype(np.uint8), mode)
        img_obj.thumbnail(size, interpolation)

        (w, h) = img_obj.size
        padded_img = Image.new(mode, size, "black")
        padded_img.paste(img_obj, ((size[0] - w) // 2, (size[1] - h) // 2))

        return np.array(padded_img)

    def statistics(self):
        for i in range(len(self)):
            img, instance_masks, class_ids = self.load_data(i)
        
    def __len__(self):
        return len(self.coco_data)

In [6]:
config = ProposalConfig()
val_dataset = CocoDetection(val_img_dir, val_seg_dir, val_ann, config)

In [7]:
index = random.choice(list(range(len(val_dataset))))
img, instance_masks, cat_ids = val_dataset[index]
Image.fromarray(img, "RGB").show()
print(instance_masks.shape)
for i in range(instance_masks.shape[0]):
    Image.fromarray((instance_masks[i].astype(np.uint8))*255,"L").show()
    print(config.CAT_NAMES[cat_ids[i]])

1158
26891
8497
25990
45008
34310
83645
(7, 448, 448)
giraffe
giraffe
giraffe
tree-merged
sky-other-merged
grass-merged
dirt-merged
