# Coco Manager


* Reads the annotations from a JSON file (e.g. `instances_train2017.json`)

* Reads the images either from disk (e.g. from the `COCO/train2017/` folder),
or by downloading them from the `coco_url`s provided in the annotations JSON file

* Selects all of the Coco classes (81 classes total), or restricts the dataset to a subset of classes to fit the project (e.g. if `classes = ['person']`, it keeps only the images that contain a person and drops the bounding boxes and masks of every other class from them)

* Scales the images to have a particular shape, gathers images into batches

* Shows the image alongside the semantic masks of the objects

* Fits easily into a `DataLoader` (see `Usage_Coco_Manager.ipynb`)

In [None]:

import torch
import torchvision as tv
import torch.nn as nn

import numpy as np
from pycocotools.coco import COCO
from PIL import Image

import matplotlib.pyplot as plt
import matplotlib.patches as mpatches

import itertools
import os
from typing import Dict, Tuple, List

from skimage.morphology import disk, closing
from skimage import io


In [None]:

class Coco_Manager( torch.utils.data.Dataset ):
    
    def __init__( self,
                 local_imgs_path:   str=None,
                 path_ann:          str=None,
                 coco:              COCO=None,
                 output_size:       Tuple[int]=None):
        
        self.version = 1.3
        
        self.local_imgs_path = local_imgs_path
        self.coco = coco if coco else COCO( path_ann )
        self.ids = list(sorted(self.coco.imgs.keys()))
        print( f'(Full Set) Images found: {len(self.ids)}' )

        self.output_size = output_size
        self.possible_colors = self.create_colors()
        np.random.shuffle(self.possible_colors)


    '''Generate colors for the masks'''
    def create_colors( self ) -> List[Tuple[int]]:
        perm = lambda choices, n: list(itertools.product(*[choices]*n))
        colors = perm([64*i for i in range(4)], 3) 
        return [list(rgb) for rgb in colors[1:]]


    '''Choose a subset of image IDs'''
    def make_subset( self,
                    classes: List,
                    intersection: bool=False ) -> List:

        catIds = self.coco.getCatIds(catNms=classes)

        if intersection:
            true_ids = self.coco.getImgIds(catIds=catIds)
        else:
            true_ids = []
            for catId in catIds:
                true_ids = true_ids + self.coco.getImgIds(catIds=[catId])
            true_ids = list(set(true_ids))
                
        len_new = len(true_ids)
        len_old = len(self.ids)
        
        print( f'(Subset) Images found: {len_new}/{len_old} ({round(100*len_new/len_old, 2)}%)' )

        self.ids = true_ids
        self.catIds = catIds
        self.classes = classes
    

    '''Get an image from coco_url, if images themselves are not on disk'''
    def download_image( self, id: int ) -> Tuple[np.array, List]:
        
        imgIds = self.coco.getImgIds(imgIds = id)
        try:
            img = self.coco.loadImgs(imgIds[np.random.randint(0,len(imgIds))])[0]
            annIds = self.coco.getAnnIds( id )
            ann = self.coco.loadAnns( annIds )
        except:
            print( f"No image found with ID: {id}" )
            return None, None
        for i in range(len(ann)):
            ann[i].update({'category_name': self.coco.loadCats(ann[i]['category_id'])})

        img = np.array(io.imread(img['coco_url']))
        return img, ann



    '''DataLoader outputs Tuple[ image, dictionary with annotations ]'''
    def collate_fn( self, data: Tuple[ torch.Tensor, Dict ]
                  ) -> Tuple[ torch.Tensor, List[Dict] ]:
        
        imgs = [sample[0] for sample in data]
        anns = [sample[1] for sample in data]
        imgs = torch.stack( imgs )
        return {'imgs':imgs, 'anns':anns}
     
    
    def __len__( self ):
        return len( self.ids )


    def transform( self, img ) -> Tuple[torch.Tensor, Tuple[int]]:

        '''PIL Image to pytorch Tensor'''
        img = torch.Tensor(img)
        
        '''Pad the image if it's smaller than output_size'''
        if self.output_size != None:
            target_w, target_h = self.output_size
            w, h = img.size()[0], img.size()[1]           
            w_pad = (target_w - w) 
            h_pad = (target_h - h)
            crop_origin = (0, 0)
            if (w_pad > 0) or (h_pad > 0):
                img, crop_origin = self.center_padding( img,
                                          w_pad if (w_pad>0) else 0,
                                          h_pad if (h_pad>0) else 0)
            
            if (w_pad < 0) or (h_pad < 0):                
                img = self.resize(img)
                ratio = (target_w/w, target_h/h)
            else:
                ratio = (1, 1)
        
        return img, crop_origin, ratio



    '''Reshape the image to fit the output_size'''
    def resize( self, img: torch.Tensor ) -> torch.Tensor:
        '''Rescale a channels-last tensor to `self.output_size`.

        The tensor is viewed channels-first for torchvision's `Resize`
        (nearest-neighbour interpolation) and viewed channels-last again
        before being returned.
        '''
        resizer = tv.transforms.Resize(
            size=self.output_size,
            interpolation=tv.transforms.InterpolationMode.NEAREST)

        # (axis 0, axis 1, channels) -> (channels, axis 0, axis 1)
        channels_first = img.permute(2, 0, 1)
        scaled = resizer(channels_first)
        # ... and back to (axis 0, axis 1, channels)
        return scaled.permute(1, 2, 0)


    '''Add zero-padding to match output_size'''
    def center_padding( self,
                        img:  torch.Tensor,
                        w_pad:    int,
                        h_pad:    int   ) -> torch.Tensor:

        left_pad = w_pad // 2
        right_pad = w_pad - left_pad
        top_pad = h_pad // 2
        bottom_pad = h_pad - top_pad

        imgpad = nn.functional.pad(
            input=img,
            pad=(0, 0, top_pad, bottom_pad, left_pad, right_pad),
            mode="constant",
            value = 0 )
        # print( f'Added zero-padding: {img.size()} -> {imgpad.size()}' )
        return imgpad, (top_pad, left_pad)



    '''Fetch the image and annotations, transform and add to the batch'''
    def __getitem__( self, index: int ):
        
        '''Dataset ID to Coco ID: works in one of two modes: with local images and without'''
        img_id = self.ids[index]
        if self.local_imgs_path == None:
            img, coco_annotations = self.download_image( img_id )
        else:
            ann_ids = self.coco.getAnnIds( imgIds=img_id )
            coco_annotations = self.coco.loadAnns(ann_ids)
            path = self.coco.loadImgs(img_id)[0]['file_name']
            img = Image.open(os.path.join(self.local_imgs_path, path))

        '''Image Transformations'''
        img, crop_origin, ratio = self.transform( img )
        img /= 255

        '''If the subset is used, erase all unnecessary annotations'''
        if self.catIds:
            coco_annotations = [obj for obj in coco_annotations if obj['category_id'] in self.catIds ]

        '''Object Labels'''
        num_objs = len(coco_annotations)
        for i in range(num_objs):
            coco_annotations[i].update({'category_name': self.coco.loadCats(coco_annotations[i]['category_id'])})
        obj_labels = [ ob['category_name'][0]['name'] for ob in coco_annotations ]

        '''Bounding Boxes and Masks'''
        '''Coco [xmin, ymin, width, height] -> torch [xmin, ymin, xmax, ymax]'''
        boxes = []
        masks = []
        
        for i in range(num_objs):

            box = self.translate_box(
                coco_annotations[i]['bbox'],
                crop_origin,
                ratio )
            
            mask = self.coco.annToMask(coco_annotations[i])
            masks.append( mask )
            boxes.append( box )
            

        '''Bounding boxes areas'''
        areas = [ coco_annotations[i]['area'] for i in range(num_objs) ]


        '''Merge binary masks of the same class
        for example: 224x224x3 image with 2 classes -> np.array([224, 224, 2])
        From the greatest mask to the smallest'''
        stack_shape = (masks[0].shape[0], masks[0].shape[1], len(self.classes))
        masks_stack = np.zeros(stack_shape, dtype='uint8')

        for i in reversed(np.argsort(np.array(areas))):
            label = obj_labels[i]
            mask = masks[i]

            layer = self.classes.index( label )
            masks_stack[ :, :, layer ] |= mask
            
        masks_stack, _, _ = self.transform( masks_stack )

        '''Image Annotations Dictionary'''
        annotation = {}
        annotation["boxes"] = torch.as_tensor(boxes, dtype=torch.float32)
        annotation["masks"] = masks_stack
        annotation["labels"] = obj_labels
        annotation["coco_id"] = torch.tensor([img_id])
        annotation["areas"] = torch.as_tensor(areas, dtype=torch.float32)
        
        # ???
        # print( f'COCO: img.size() = {img.size()}' )
        # if img.size()[0] == 1:
        #     img = torch.stack([img for _ in range(3)])

        # print( f'Number of Objects: {num_objs}' )
        return img, annotation




    '''Shows the image and the semantic masks'''
    def show( self, img: np.array, anns: Dict ) -> None:
        '''Plot the image with its boxes next to a color-coded class mask.

        Args:
            img: image as returned by `__getitem__` (tensor) or an array.
            anns: annotation dictionary with the keys `'boxes'`, `'labels'`,
                `'masks'` and `'coco_id'`, as built by `__getitem__`.
        '''
        # Same fallback as the full set: every category of the index.
        classes = getattr(self, 'classes', None)
        if classes is None:
            classes = [self.coco.cats[cat_id]['name'] for cat_id in sorted(self.coco.cats)]

        with plt.style.context('dark_background'):

            _, ax = plt.subplots(ncols=2, figsize=(18, 15), dpi=110)

            bboxes = anns['boxes']
            labels = anns['labels']

            # One color per class; the palette is reused cyclically when there
            # are more classes than colors instead of running out.
            colors = {
                label: np.array(rgb)
                for label, rgb in zip(classes, itertools.cycle(self.possible_colors)) }

            # Draw the boxes on the original image
            default_color = 'lawngreen'
            for bbox, label in zip(bboxes.tolist(), labels):
                xmin, ymin, _, _, width, height = bbox

                # Shift the label color to brighter hues
                color = (colors[label] / 255 + 1.5) / (1.5+1)

                rect = mpatches.Rectangle(
                    (xmin, ymin), width, height, fill=False, edgecolor=default_color, linewidth=2)
                circle = mpatches.Circle(
                    (xmin + width // 2, ymin + height // 2), color=default_color)
                ax[0].add_patch( rect )
                ax[0].add_patch( circle )
                ax[0].text( xmin, ymin-5, label, fontdict={'color':default_color} )
                ax[1].text( xmin, ymin-5, label, fontdict={'color':color} )

            # Draw image
            img = img.numpy() if torch.is_tensor(img) else np.asarray(img)
            ax[0].imshow( img )
            ax[0].set_title(f'Original Image (Coco ID: {anns["coco_id"].item()})', y=1.04)

            # Fill the empty canvas with masks. Accumulate in a wider integer
            # type and saturate at 255, so overlapping classes cannot wrap
            # around the uint8 range.
            canvas = np.zeros((img.shape[0], img.shape[1], 3), dtype=np.uint16)
            masks = anns['masks']

            for layer, label in enumerate(classes):
                mask = np.asarray(masks[:, :, layer])
                canvas += (mask[:, :, None] * colors[label]).astype(np.uint16)

            canvas = np.minimum(canvas, 255).astype(np.uint8)

            ax[1].imshow(canvas)
            ax[1].set_title('Mask from COCO annotations', y=1.04)



    def translate_box( self, org_bbox, crop_origin, ratio ):
        
        x0, y0 = crop_origin
        ratio_h, ratio_w = ratio

        xmin = int((org_bbox[0] + x0) * ratio_w)
        ymin = int((org_bbox[1] + y0) * ratio_h)

        width  = int(org_bbox[2] * ratio_w)
        height = int(org_bbox[3] * ratio_h)
        xmax   = xmin + width
        ymax   = ymin + height

        return [xmin, ymin, xmax, ymax, width, height]
