Importing required dependencies

In [1]:
import os
import random

import cv2
import numpy as np
import skimage.io as io
from pycocotools.coco import COCO
%matplotlib inline

Configuration: adjust these variables to your needs (the defaults also work as-is). Note that `image_size` is passed to `cv2.resize`, which interprets it as (width, height).

In [2]:
# Output resolution handed to cv2.resize, i.e. (width, height).
image_size = (224, 224)
# COCO category names to keep; None would select every image in a split.
classes = ['laptop', 'tv']
# Dataset splits to export, processed in this order.
mode = ['val', 'train']
# Dataset root containing the annotations/ and images/ subfolders.
folder = './COCOdataset2017'

Filter the COCO images down to those containing at least one of the categories listed in the `classes` variable (or keep every image when `classes` is None).

In [3]:
def filterDataset(folder, classes=None, mode='train', seed=None):
    """Load a COCO split and return the images that contain any of ``classes``.

    Parameters
    ----------
    folder : str
        Dataset root; annotations are read from
        ``<folder>/annotations/instances_<mode>.json``.
    classes : iterable of str or None
        Category names. An image is kept if it contains at least one of them.
        ``None`` keeps every image of the split.
    mode : str
        Split name, e.g. ``'train'`` or ``'val'``.
    seed : int or None
        Seed for the shuffle. ``None`` (the default) shuffles with the global
        ``random`` state.

    Returns
    -------
    tuple
        ``(images, dataset_size, coco)``: the shuffled, de-duplicated image
        records, their count, and the loaded COCO API object.

    Raises
    ------
    ValueError
        If a name in ``classes`` matches no category. Without this check the
        empty category-id list would be treated by ``getImgIds`` as "no
        filter" and every image of the split would be returned.
    """
    # initialize COCO api for instance annotations
    annFile = '{}/annotations/instances_{}.json'.format(folder, mode)
    coco = COCO(annFile)

    if classes is not None:
        images = []
        # Query each class separately so the result is the union of the
        # per-class image sets (a single multi-id query would intersect them).
        for className in classes:
            catIds = coco.getCatIds(catNms=className)
            if not catIds:
                raise ValueError('Unknown COCO category: {!r}'.format(className))
            images += coco.loadImgs(coco.getImgIds(catIds=catIds))
    else:
        images = coco.loadImgs(coco.getImgIds())

    # An image containing several requested classes appears once per class;
    # keep the first record per image id (O(n) instead of O(n^2) list scans).
    unique_by_id = {}
    for img in images:
        unique_by_id.setdefault(img['id'], img)
    unique_images = list(unique_by_id.values())

    if seed is None:
        random.shuffle(unique_images)
    else:
        random.Random(seed).shuffle(unique_images)

    return unique_images, len(unique_images), coco

Helper functions for creating training data

In [4]:
def getClassName(classID, cats):
    """Look up a category name by id.

    ``cats`` is a sequence of dicts carrying ``'id'`` and ``'name'`` keys.
    The name of the first entry whose ``'id'`` equals ``classID`` is
    returned; ``None`` is returned when nothing matches.
    """
    matches = (cat['name'] for cat in cats if cat['id'] == classID)
    return next(matches, None)

def getImage(imageObj, img_folder, input_image_size):
    """Read an image, scale it to [0, 1], resize it and return it with 3 channels.

    Parameters
    ----------
    imageObj : dict
        Image record; only ``imageObj['file_name']`` is used.
    img_folder : str
        Directory containing the image file.
    input_image_size : tuple of int
        Passed to ``cv2.resize`` as ``dsize``, i.e. ``(width, height)``.

    Returns
    -------
    numpy.ndarray
        Float array of shape ``(height, width, 3)``.
    """
    # Read and normalize an image. Dividing by 255.0 assumes 8-bit pixel
    # data -- NOTE(review): 16-bit inputs would not land in [0, 1].
    train_img = io.imread(img_folder + '/' + imageObj['file_name']) / 255.0
    # Resize
    train_img = cv2.resize(train_img, input_image_size)

    if train_img.ndim == 2:
        # Grayscale image: replicate the single channel.
        gray = train_img
    elif train_img.shape[2] >= 3:
        # RGB passes through unchanged; RGBA drops its alpha channel instead
        # of being stacked into a 4-D array.
        return train_img[..., :3]
    else:
        # One channel, or gray + alpha: keep only the intensity channel.
        gray = train_img[..., 0]
    return np.stack((gray,) * 3, axis=-1)
        
def getBinaryMask(imageObj, coco, catIds, input_image_size):
    """Build one binary mask covering every annotation of ``catIds`` in an image.

    Parameters
    ----------
    imageObj : dict
        Image record; only ``imageObj['id']`` is used.
    coco : COCO
        Loaded COCO API object used to fetch and rasterize annotations.
    catIds : list of int
        Category ids whose annotations are merged (crowd annotations included).
    input_image_size : tuple of int
        Passed to ``cv2.resize`` as ``dsize``, i.e. ``(width, height)``.

    Returns
    -------
    numpy.ndarray
        Float array of shape ``(height, width)`` holding 0.0 / 1.0, matching
        the spatial layout of the resized image.
    """
    width, height = input_image_size
    annIds = coco.getAnnIds(imageObj['id'], catIds=catIds, iscrowd=None)
    anns = coco.loadAnns(annIds)
    # cv2.resize takes dsize as (width, height) but returns arrays shaped
    # (height, width), so the accumulator must be allocated as (rows, cols).
    train_mask = np.zeros((height, width))
    for ann in anns:
        new_mask = cv2.resize(coco.annToMask(ann), input_image_size)
        # Threshold because resizing may cause extraneous values.
        train_mask = np.maximum(train_mask, (new_mask >= 0.5).astype(train_mask.dtype))
    return train_mask

def dataGeneratorCoco(images, coco, classes, folder,
                      batch_size=4, input_image_size=(224, 224), mode='train'):
    """Write resized images and their binary masks for the first ``batch_size`` records.

    Despite its name this is not a Python generator: it writes files to disk
    and returns ``None``.

    Source images are read from ``<folder>/images/<mode>``. Each resized RGB
    image is written to ``<folder>/images/<mode>_img/<file_name>`` and the
    union mask of all ``classes`` annotations to
    ``<folder>/images/<mode>_mask/<file_name>``. Output folders are created
    if missing.

    Parameters
    ----------
    images : list of dict
        Image records (need ``'id'`` and ``'file_name'``).
    coco : COCO
        Loaded COCO API object for the split.
    classes : list of str
        Category names whose annotations form the mask.
    folder : str
        Dataset root.
    batch_size : int
        Maximum number of records to export; fewer are written if ``images``
        is shorter.
    input_image_size : tuple of int
        ``(width, height)`` target size, as used by ``cv2.resize``.
    mode : str
        Split name, e.g. ``'train'`` or ``'val'``.

    Raises
    ------
    IOError
        If ``cv2.imwrite`` reports that a file could not be written.
    """
    img_folder = '{}/images/{}'.format(folder, mode)
    # Output folders are per split so validation data never lands in the
    # training folders.
    out_img_dir = os.path.join(folder, 'images', '{}_img'.format(mode))
    out_mask_dir = os.path.join(folder, 'images', '{}_mask'.format(mode))
    # cv2.imwrite silently returns False for a missing directory.
    os.makedirs(out_img_dir, exist_ok=True)
    os.makedirs(out_mask_dir, exist_ok=True)
    catIds = coco.getCatIds(catNms=classes)

    # Slicing (rather than range(batch_size)) tolerates batch_size > len(images).
    for imageObj in images[:batch_size]:
        train_img = getImage(imageObj, img_folder, input_image_size)
        # getImage yields RGB floats in [0, 1]; imwrite expects 8-bit BGR.
        img_u8 = np.clip(np.rint(train_img * 255), 0, 255).astype(np.uint8)
        img_path = os.path.join(out_img_dir, imageObj['file_name'])
        if not cv2.imwrite(img_path, cv2.cvtColor(img_u8, cv2.COLOR_RGB2BGR)):
            raise IOError('cv2.imwrite failed for {}'.format(img_path))

        train_mask = getBinaryMask(imageObj, coco, catIds, input_image_size)
        # NOTE(review): if file_name ends in .jpg, the mask is JPEG-compressed
        # and may contain values other than 0/255 near edges; consider PNG.
        mask_path = os.path.join(out_mask_dir, imageObj['file_name'])
        if not cv2.imwrite(mask_path, (train_mask * 255).astype(np.uint8)):
            raise IOError('cv2.imwrite failed for {}'.format(mask_path))
        

Finally, call the functions above to write the resized images and their binary masks (at the size given by `image_size`) for every split listed in `mode`.

In [5]:

# The loop variable is `split`, not `mode`: rebinding `mode` to its last
# element ('train') would make a re-run of this cell iterate over the
# characters of that string instead of the configured split names.
for split in mode:
    images, dataset_size, coco = filterDataset(folder, classes, split)
    # Export every filtered image of this split.
    batch_size = dataset_size
    dataGeneratorCoco(images, coco, classes, folder, batch_size, image_size, split)

loading annotations into memory...
Done (t=0.77s)
creating index...
index created!
loading annotations into memory...
Done (t=20.55s)
creating index...
index created!
