In [1]:
import numpy as np
import random
import torch
import torch.utils.data as data
from PIL import Image
from PIL import Image, ImageFont, ImageDraw, ImageEnhance

import os
import json
import os.path
import colorsys


# Shorthand float sentinels: positive infinity and not-a-number.
inf, nan = float("inf"), float("nan")

In [2]:
# Root of the local COCO 2017 download. Set the COCO_DATA_DIR environment
# variable to run the notebook on another machine; when it is unset the
# original location is used. Joining with "" guarantees a trailing separator,
# which the string concatenations below rely on.
data_dir = os.path.join(
    os.environ.get("COCO_DATA_DIR", "/home/aravind/dataset/"), "")
ann_dir = data_dir + "annotations/panoptic/"

# Training split: images, panoptic segmentation PNGs and the annotation JSON.
train_img_dir = data_dir + "train2017/"
train_seg_dir = ann_dir + "panoptic_train2017/"
train_ann_json = ann_dir + "panoptic_train2017.json"

# Validation split, laid out the same way.
val_img_dir = data_dir + "val2017/"
val_seg_dir = ann_dir + "panoptic_val2017/"
val_ann_json = ann_dir + "panoptic_val2017.json"

In [3]:
# Parse the validation panoptic annotation file into a dict.
with open(val_ann_json) as f:
    val_ann = json.load(f)
# The training annotations are left unloaded; enable them with:
# with open(train_ann_json) as f:
#     train_ann = json.load(f)



In [4]:
# config to train
# TODO: check Config is correct


class ProposalConfig:
    """Static settings for the proposal model and its COCO panoptic data.

    Class attributes hold the tunable constants; ``__init__`` derives the
    values that depend on them (image size, batch size, category count and
    the indices of the ignored categories).
    """

    NAME = "InSegm"
    GPU_COUNT = 1
    # online training
    IMAGES_PER_GPU = 16
    STEPS_PER_EPOCH = 100
    NUM_WORKERS = 16
    PIN_MEMORY = True
    VALIDATION_STEPS = 20

    # Category names; a name's position in this list is the index used to
    # look categories up (e.g. for IGNORE_CAT_IDS below).
    CAT_NAMES = ['BG'] + [
        'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
        'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign',
        'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep',
        'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella',
        'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard',
        'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard',
        'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork',
        'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange',
        'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair',
        'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv',
        'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave',
        'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase',
        'scissors', 'teddy bear', 'hair drier', 'toothbrush', 'banner',
        'blanket', 'bridge', 'cardboard', 'counter', 'curtain', 'door-stuff',
        'floor-wood', 'flower', 'fruit', 'gravel', 'house', 'light',
        'mirror-stuff', 'net', 'pillow', 'platform', 'playingfield',
        'railroad', 'river', 'road', 'roof', 'sand', 'sea', 'shelf', 'snow',
        'stairs', 'tent', 'towel', 'wall-brick', 'wall-stone', 'wall-tile',
        'wall-wood', 'water-other', 'window-blind', 'window-other',
        'tree-merged', 'fence-merged', 'ceiling-merged', 'sky-other-merged',
        'cabinet-merged', 'table-merged', 'floor-other-merged',
        'pavement-merged', 'mountain-merged', 'grass-merged', 'dirt-merged',
        'paper-merged', 'food-other-merged', 'building-other-merged',
        'rock-merged', 'wall-other-merged', 'rug-merged'
    ]
    # Category ids, listed in the same order as CAT_NAMES (0 pairs with 'BG').
    CAT_IDS = [0] + [
        1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21,
        22, 23, 24, 25, 27, 28, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42,
        43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61,
        62, 63, 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84,
        85, 86, 87, 88, 89, 90, 92, 93, 95, 100, 107, 109, 112, 118, 119, 122,
        125, 128, 130, 133, 138, 141, 144, 145, 147, 148, 149, 151, 154, 155,
        156, 159, 161, 166, 168, 171, 175, 176, 177, 178, 180, 181, 184, 185,
        186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199,
        200
    ]
    IGNORE_CAT_NAMES = ['BG']
    # Per-channel mean / std, shaped (1, 1, 3) so they broadcast over the
    # last axis of an image array.
    MEAN_PIXEL = np.array(
        [0.485, 0.456, 0.406], dtype=np.float32).reshape(1, 1, -1)
    STD_PIXEL = np.array(
        [0.229, 0.224, 0.225], dtype=np.float32).reshape(1, 1, -1)
    GRID_SHAPE = 14
    IMPULSE_SHAPE = (32, 32)
    MIN_AREA = 1
    MAX_AREA = inf
    MIN_INTERSECTION = 1

    def __init__(self):
        """Compute the settings that are derived from the class constants."""
        # The image is square: 32 pixels per grid cell along each side.
        side = 32 * self.GRID_SHAPE
        self.WIDTH = side
        self.HEIGHT = side
        self.BATCH_SIZE = self.IMAGES_PER_GPU * self.GPU_COUNT
        self.IMAGE_SHAPE = (self.WIDTH, self.HEIGHT, 3)
        # 133 + 1 in panoptic
        self.NUM_CATS = len(self.CAT_NAMES)
        # Positions in CAT_NAMES of every ignored category name.
        self.IGNORE_CAT_IDS = list(
            map(self.CAT_NAMES.index, self.IGNORE_CAT_NAMES))

    def display(self):
        """Print every non-dunder, non-callable attribute with its value."""
        print("\nConfigurations:")
        for attr in dir(self):
            if attr.startswith("__"):
                continue
            value = getattr(self, attr)
            if callable(value):
                continue
            print(f"{attr:30} {value}")
        print("\n")

In [5]:
class CocoDetection(data.Dataset):
    """COCO panoptic dataset returning ``(image, instance_masks, cat_ids)``.

    Each item is read from disk, filtered (crowd segments and ignored
    categories are dropped), stuff segments are split into separate
    "islands", and image and masks are resized/padded to
    ``config.WIDTH x config.HEIGHT``.

    Args:
        img_dir: directory holding the RGB images.
        seg_dir: directory holding the panoptic segmentation PNGs.
        ann: parsed panoptic annotation JSON with ``'annotations'`` and
            ``'images'`` lists.
        config: object providing ``CAT_IDS``, ``NUM_CATS``,
            ``IGNORE_CAT_IDS``, ``WIDTH`` and ``HEIGHT``.
    """

    def __init__(self, img_dir, seg_dir, ann, config):
        self.img_dir = img_dir
        self.seg_dir = seg_dir
        self.coco_data = self.index_annotations(ann)
        self.config = config
        self.catMap = self.build_cat_map()

    def index_annotations(self, ann):
        """Join annotation and image records by COCO image id.

        Returns a list with one dict per annotated image containing
        ``segments_info``, ``segments_file``, ``image_id`` and ``image_file``.
        Raises ``KeyError`` if ``ann['images']`` lists an image that has no
        entry in ``ann['annotations']``.
        """
        # create map with coco image index as key
        d = {}
        for i in ann['annotations']:
            coco_index = i['image_id']
            d[coco_index] = {'segments_info': i['segments_info'],
                             'segments_file': i['file_name'],
                             'image_id': i['image_id']}
        for i in ann['images']:
            d[i['id']]['image_file'] = i['file_name']

        return list(d.values())

    # coco category ids remapped to contiguous range(133+1)
    def build_cat_map(self):
        """Map each id in ``config.CAT_IDS`` to its position in that list."""
        config = self.config
        coco_cat_ids = config.CAT_IDS
        return {coco_cat_ids[i]: i for i in range(config.NUM_CATS)}

    def __getitem__(self, index):
        """Return ``(img, instance_masks, cat_ids)`` for sample ``index``.

        Note: a sample without any usable instance is returned with empty
        mask / id arrays; no replacement sample is drawn.
        """
        # 0. read coco data as is
        sample = self.load_data(index)

        # 1. remove unwanted data
        # 2. fixed resolution.
        # 3. split stuff islands into different instances
        # 4. Data Augmentation: skipped for now
        sample = self.standardize_data(*sample)

        # 5. Target generation:
        return self.generate_targets(*sample)

    def load_data(self, index):
        """Read one sample from disk at its original resolution.

        Returns:
            img: the image as an ``(H, W, 3)`` uint8 array.
            instance_masks: ``(N, H, W)`` array of 0/1 masks, one for every
                segment that is not a crowd and not in an ignored category.
            cat_ids: length-``N`` array of remapped category ids (see
                ``build_cat_map``).
        """
        ann = self.coco_data[index]

        img = Image.open(
            os.path.join(self.img_dir, ann['image_file'])).convert('RGB')
        img = np.array(img)

        coco_seg = Image.open(os.path.join(
            self.seg_dir, ann['segments_file'])).convert('RGB')
        coco_seg = np.array(coco_seg, dtype=np.uint8)
        seg_id = self.rgb2id(coco_seg)

        ignore_cat_ids = set(self.config.IGNORE_CAT_IDS)
        instance_masks = []
        cat_ids = []
        for s in segments_info_iter(ann) if False else ann['segments_info']:
            cat_id = self.catMap[s['category_id']]
            if s['iscrowd'] == 1 or cat_id in ignore_cat_ids:
                continue
            # Build the mask only for segments that are actually kept.
            instance_masks.append(np.where(seg_id == s['id'], 1, 0))
            cat_ids.append(cat_id)

        cat_ids = np.array(cat_ids, dtype=int)
        if instance_masks:
            instance_masks = np.array(instance_masks)
        else:
            # Keep a well-formed (0, H, W) array when nothing survives.
            instance_masks = np.zeros((0,) + seg_id.shape, dtype=int)

        return img, instance_masks, cat_ids

    def standardize_data(self, img, instance_masks, cat_ids):
        """Split stuff segments into islands, then resize image and masks."""
        instance_masks, cat_ids = self.split_stuff_islands(
            instance_masks, cat_ids)
        img, instance_masks = self.resize_data(img, instance_masks)

        # img, instance_masks, cat_ids = self.data_augment(img, instance_masks, cat_ids)
        return img, instance_masks, cat_ids

    def generate_targets(self, img, instance_masks, cat_ids):
        """Placeholder for target generation; currently returns its inputs."""
        return img, instance_masks, cat_ids

    def rgb2id(self, color):
        """Decode an ``(H, W, 3)`` panoptic PNG into per-pixel segment ids.

        The id is ``R + 256 * G + 256**2 * B``. uint8 input is widened to
        int32 first: the multipliers 256 and 65536 do not fit in uint8, which
        NumPy 2 rejects with ``OverflowError`` instead of silently upcasting.
        """
        color = np.asarray(color)
        if color.dtype == np.uint8:
            color = color.astype(np.int32)
        return color[:, :, 0] + 256 * color[:, :, 1] + 256 * 256 * color[:, :, 2]

    def split_stuff_islands(self, instance_masks, cat_ids):
        """Split every stuff mask into one mask per connected island.

        Masks whose remapped id is <= 80 (the "thing" categories) are passed
        through unchanged and come first in the output. Each remaining mask
        is dilated with a 16x16 box filter so nearby fragments merge, labelled
        with 8-connectivity, and emitted as one mask per island, each carrying
        the original category id.

        NOTE(review): the emitted islands are regions of the *dilated* mask,
        not of the original mask - confirm that is intended.
        """
        from scipy.ndimage import label, convolve

        thing_idx = np.nonzero(cat_ids <= 80)
        stuff_idx = np.nonzero(cat_ids > 80)

        thing_ids = cat_ids[thing_idx]
        stuff_ids = cat_ids[stuff_idx]
        thing_masks = instance_masks[thing_idx]
        stuff_masks = instance_masks[stuff_idx]

        if stuff_ids.shape[0] == 0:
            return thing_masks, thing_ids

        lp_filter = np.ones((16, 16))
        connectivity = np.ones((3, 3))

        # Collect the pieces and concatenate once instead of re-copying the
        # growing arrays on every iteration.
        mask_chunks = [thing_masks]
        id_chunks = [thing_ids]
        for mask, stuff_id in zip(stuff_masks, stuff_ids):
            mask = convolve(mask, lp_filter, mode='constant', cval=0.0)
            mask = np.where(mask != 0, 1, 0)

            labelled_islands, num_islands = label(mask, structure=connectivity)
            if num_islands == 0:
                # An all-zero mask yields no island; an empty 1-D array could
                # not be concatenated with the (N, H, W) masks.
                continue
            islands = np.array([
                np.where(labelled_islands == i + 1, 1, 0)
                for i in range(num_islands)
            ])
            mask_chunks.append(islands)
            id_chunks.append(np.array([stuff_id] * num_islands))

        return np.concatenate(mask_chunks, 0), np.concatenate(id_chunks, 0)

    def resize_data(self, img, instance_masks):
        """Resize/pad the image and every mask to ``config.WIDTH x HEIGHT``."""
        config = self.config

        w, h = config.WIDTH, config.HEIGHT
        img = self.resize_image(img, (w, h), "RGB")
        resized = [self.resize_image(m, (w, h), "L") for m in instance_masks]
        if resized:
            instance_masks = np.array(resized)
        else:
            # No instances: keep the (0, h, w) layout of the non-empty case.
            instance_masks = np.zeros((0, h, w), dtype=np.uint8)

        return img, instance_masks

    def resize_image(self, img, size, mode):
        """Shrink ``img`` to fit ``size`` and centre it on a black canvas.

        Args:
            img: array convertible to uint8, laid out for PIL mode ``mode``.
            size: target ``(width, height)``.
            mode: ``"RGB"`` (bicubic resampling) or ``"L"`` (nearest
                neighbour, used for masks).

        ``Image.thumbnail`` keeps the aspect ratio and never enlarges, so
        inputs smaller than ``size`` are only padded.
        """
        interpolation = {"RGB": Image.BICUBIC, "L": Image.NEAREST}[mode]
        img_obj = Image.fromarray(img.astype(np.uint8), mode)
        img_obj.thumbnail(size, interpolation)

        (w, h) = img_obj.size
        padded_img = Image.new(mode, size, "black")
        padded_img.paste(img_obj, ((size[0] - w) // 2, (size[1] - h) // 2))

        return np.array(padded_img)

    def __len__(self):
        """Number of annotated images."""
        return len(self.coco_data)

In [6]:
# Build the configuration and the validation dataset used by the cells below.
config = ProposalConfig()
val_dataset = CocoDetection(
    img_dir=val_img_dir,
    seg_dir=val_seg_dir,
    ann=val_ann,
    config=config,
)

In [7]:
def random_colors(N, bright=True):
    """Return ``N`` RGB tuples with evenly spaced hues, in shuffled order.

    Every colour has full saturation; its HSV value is 1.0 when ``bright``
    is true and 0.7 otherwise. Components are floats in [0, 1].
    """
    value = 1.0 if bright else 0.7
    palette = [colorsys.hsv_to_rgb(step / N, 1, value) for step in range(N)]
    random.shuffle(palette)
    return palette


def apply_mask(image, mask, color, alpha=0.5):
    """Alpha-blend ``color`` into ``image`` wherever ``mask`` equals 1.

    ``color`` holds three components that are scaled by 255 before blending.
    The first three channels of ``image`` are overwritten in place and the
    same array is returned.
    """
    selected = (mask == 1)
    for channel in range(3):
        plane = image[:, :, channel]
        blended = plane * (1 - alpha) + alpha * color[channel] * 255
        image[:, :, channel] = np.where(selected, blended, plane)
    return image

def extract_bbox(mask):
    """Return ``(y1, x1, y2, x2)`` bounding the non-zero entries of ``mask``.

    ``(y1, x1)`` is the top-left non-zero position; ``(y2, x2)`` lies one
    past the bottom-right one, so the box is half-open.
    """
    nonzero = np.nonzero(mask)
    rows, cols = nonzero[0], nonzero[1]
    top, left = rows.min(), cols.min()
    bottom, right = rows.max() + 1, cols.max() + 1
    return top, left, bottom, right

def create_labelled_image(img,mask,class_name):
    """Return ``img`` as a PIL image with ``class_name`` written top-left.

    The text is drawn in white, at size 14 with the font file
    ``./data/Aaargh.ttf``, over a black box spanning (0, 0)-(40, 20).
    ``mask`` is accepted but not used.
    """
    canvas = Image.fromarray(img.astype(np.uint8))
    label_font = ImageFont.truetype("./data/Aaargh.ttf",14)
    pen = ImageDraw.Draw(canvas)
    # y1, x1, y2, x2 = extract_bbox(masks)
    pen.rectangle(((0, 0), (40, 20)), fill="black")
    pen.text((5, 5), class_name, font=label_font, fill=(255,255,255))
    return canvas
def visualize_targets(img, masks, class_response, base_impulse, config):
    """Save one labelled overlay per target to ``./results/<i>.png``.

    Args:
        img: normalised image with the channel axis first; it is moved to the
            last axis and de-normalised with ``config.STD_PIXEL`` /
            ``config.MEAN_PIXEL``. The caller's array is not modified.
        masks: indexable by target number; ``masks[i]`` is overlaid where it
            equals 1.
        class_response: array whose argmax over axis 0, flattened, gives the
            class index of each target.
        base_impulse: indexable like ``masks``; overlaid in a second colour.
        config: provides ``STD_PIXEL``, ``MEAN_PIXEL`` and the class names
            (``CLASS_NAMES`` if present, otherwise ``CAT_NAMES``).
    """
    # ProposalConfig exposes CAT_NAMES; keep honouring CLASS_NAMES for
    # configs that define it.
    class_names = getattr(config, "CLASS_NAMES", None)
    if class_names is None:
        class_names = config.CAT_NAMES

    # De-normalise out of place: np.moveaxis returns a view, so in-place
    # arithmetic would silently rescale the caller's image.
    img = (np.moveaxis(img, 0, 2) * config.STD_PIXEL + config.MEAN_PIXEL) * 255
    # Keep values in the displayable range so the later uint8 cast cannot wrap.
    img = np.clip(img, 0, 255)

    class_ids = np.argmax(class_response,0).reshape(-1)
    N = class_ids.shape[0]
    response_colors = random_colors(N)
    impulse_colors = random_colors(N)

    os.makedirs("./results", exist_ok=True)
    for i in range(N):
        masked_img = img.copy()
        masked_img = apply_mask(masked_img, masks[i], response_colors[i])
        masked_img = apply_mask(masked_img, base_impulse[i], impulse_colors[i])
        masked_img = create_labelled_image(masked_img, masks[i], class_names[class_ids[i]])
        masked_img.save("./results/"+str(i)+".png","PNG")
        
def visualize_coco_data(img, masks, cat_ids, config):
    """Save one labelled mask overlay per instance to ``./results/<i>.png``.

    Args:
        img: normalised image with the channel axis first; it is moved to the
            last axis and de-normalised with ``config.STD_PIXEL`` /
            ``config.MEAN_PIXEL``. The caller's array is not modified.
            NOTE(review): ``CocoDetection`` currently yields channel-last
            uint8 images - confirm which layout callers pass here.
        masks: indexable by instance number; ``masks[i]`` is overlaid where
            it equals 1.
        cat_ids: category index of each instance, used to look up its label.
        config: provides ``STD_PIXEL``, ``MEAN_PIXEL`` and the category names
            (``CLASS_NAMES`` if present, otherwise ``CAT_NAMES``).
    """
    # ProposalConfig exposes CAT_NAMES; keep honouring CLASS_NAMES for
    # configs that define it.
    class_names = getattr(config, "CLASS_NAMES", None)
    if class_names is None:
        class_names = config.CAT_NAMES

    # De-normalise out of place: np.moveaxis returns a view, so in-place
    # arithmetic would silently rescale the caller's image.
    img = (np.moveaxis(img, 0, 2) * config.STD_PIXEL + config.MEAN_PIXEL) * 255
    # Keep values in the displayable range so the later uint8 cast cannot wrap.
    img = np.clip(img, 0, 255)

    # Labels come straight from cat_ids; this function receives no class
    # response or impulse maps, so only the instance mask is overlaid.
    N = len(cat_ids)
    mask_colors = random_colors(N)

    os.makedirs("./results", exist_ok=True)
    for i in range(N):
        masked_img = img.copy()
        masked_img = apply_mask(masked_img, masks[i], mask_colors[i])
        masked_img = create_labelled_image(masked_img, masks[i], class_names[cat_ids[i]])
        masked_img.save("./results/"+str(i)+".png","PNG")


In [17]:
# Draw a random validation sample and run it through the full pipeline
# (val_dataset[index]); set `index` by hand, or call
# val_dataset.load_data(index), to inspect a specific or unprocessed sample.
index = random.choice(range(len(val_dataset)))
print(index)
img, instance_masks, cat_ids = val_dataset[index]
Image.fromarray(img, "RGB").show()
print(cat_ids)
# Show every mask with its pixel count and category name; add a test on
# config.CAT_NAMES[cat_ids[i]] here to look at a single category only.
for i, mask in enumerate(instance_masks):
    Image.fromarray(mask.astype(np.uint8) * 255, "L").show()
    print(np.sum(mask))
    print(config.CAT_NAMES[cat_ids[i]])

1089
[  1   1   1  35  36  98  98  98  98 126 126 126 126 126]
33169
person
27356
person
26293
person
932
baseball bat
495
baseball glove
33133
playingfield
1459
playingfield
298
playingfield
33718
playingfield
551
grass-merged
8657
grass-merged
4834
grass-merged
1235
grass-merged
2392
grass-merged


In [44]:
# index = random.choice(list(range(len(val_dataset))))
# Inspect one fixed sample exactly as load_data() reads it, i.e. before
# island splitting and resizing. Swap in random.choice(...) or
# val_dataset[index] to look at a random or fully processed sample instead.
index = 5
img, instance_masks, cat_ids = val_dataset.load_data(index)
Image.fromarray(img, "RGB").show()
print(cat_ids)
# Show every mask with its pixel count and category name; add a test on
# config.CAT_NAMES[cat_ids[i]] here to look at a single category only.
for i, mask in enumerate(instance_masks):
    Image.fromarray(mask.astype(np.uint8) * 255, "L").show()
    print(np.sum(mask))
    print(config.CAT_NAMES[cat_ids[i]])

[  1  31 106 120]
27486
person
3828
skis
213773
snow
26340
sky-other-merged


In [39]:
# some statistics
s = 0
for i in range(len(val_dataset)):
    _,_,cat_ids = val_dataset.load_data(i)
    s += cat_ids.shape[0]
print(s)

KeyboardInterrupt: 