In [None]:
import os
import torch
import typing
from typing import List
import pandas as pd
import numpy as np
from PIL import Image
from pycocotools.coco import COCO
import matplotlib.pyplot as plt
import cv2
import torchvision
from torchvision.transforms import transforms as transforms

%matplotlib inline

In [None]:
# One seed shared by every random number generator the notebook relies on,
# so that a fresh "Restart & Run All" reproduces the same results.
RANDOM_SEED = 42

np.random.seed(RANDOM_SEED)
# PyTorch keeps its own RNG state, independent of NumPy's; seed it as well.
torch.manual_seed(RANDOM_SEED)

# Training images directory and the COCO-format annotation file describing them.
TRAIN_IMAGES_PATH = 'data/public_training_set_release_2.0/images/'
TRAIN_LABELS = 'data/public_training_set_release_2.0/annotations.json'

# Data Exploration

In [None]:
def read_image(path: str) -> np.ndarray:
    '''
    Load the image stored at `path` and return its pixel data as a NumPy array.

    Args:
        path: location of the image file on disk
    return:
        np.ndarray holding the decoded pixels
    '''
    # Image.open is lazy and keeps the file open; the `with` block releases the
    # handle as soon as the pixels have been copied into the array.
    with Image.open(path) as img:
        return np.array(img)

def show_image_coco(image_id: int, coco_labels: COCO, with_mask: bool = True) -> None:
    '''
    Plot one training image, optionally with its COCO annotations drawn on top.

    Args:
        image_id: id of the image inside `coco_labels`
        coco_labels: COCO annotation index the image and its annotations are read from
        with_mask: when True the annotations of the image are overlaid on the plot
    return:
        None, the image is drawn on the current matplotlib axes
    '''
    im_info = coco_labels.loadImgs(image_id)[0]

    # The image is shown in both modes, so read and draw it once up front.
    image = read_image(os.path.join(TRAIN_IMAGES_PATH, im_info['file_name']))
    plt.imshow(image)

    if with_mask:
        # Use the index passed in by the caller instead of a notebook-level
        # global, so the function works with any COCO object and in any cell order.
        coco_labels.showAnns(coco_labels.imgToAnns[image_id])

In [None]:
# Build the COCO annotation index for the training set.
labels = COCO(annotation_file=TRAIN_LABELS)

In [None]:
# Exploration aid: list the attributes and methods available on the COCO index.
dir(labels)

In [None]:
# Image ids returned by the COCO index (no filter arguments are passed).
img_ids = labels.getImgIds()
# 184135 -- presumably an image id noted during exploration; TODO confirm
# Display the annotation records stored for the second id in the list.
labels.imgToAnns[img_ids[1]]

In [None]:
# Number of category ids present in the annotation file.
len(labels.getCatIds())

In [None]:
# Visual sanity check: draw one sample image together with its annotations.
show_image_coco(image_id=img_ids[66], coco_labels=labels, with_mask=True)

# Dataset

In [None]:
class FoodDataset(torch.utils.data.Dataset):
    '''
    Map-style dataset that pairs every image of a COCO annotation file with an
    instance-segmentation target dictionary.

    Args:
        img_path: directory containing the image files named in the annotations
        coco_ds_path: path to the COCO-format annotation JSON
        trans: optional joint transform, called as `trans(img, target)`
        apply_trans: the transform is only applied when this is True; it defaults
            to False, which keeps the behaviour of returning untransformed samples
    '''

    def __init__(self, img_path: str, coco_ds_path: str,
                 trans: typing.Optional[typing.Callable] = None,
                 apply_trans: bool = False):
        self.img_path = img_path
        self.trans = trans
        self.apply_trans = apply_trans

        self.coco_ds = COCO(coco_ds_path)
        # Sorted so that a given index always maps to the same image.
        self.img_ids = sorted(self.coco_ds.getImgIds())

    def __getitem__(self, idx: int) -> typing.Tuple["Image.Image", dict]:
        '''
        Args:
            idx: index of sample
        return:
            tuple (img, target) containing:
            - img: PIL image converted to RGB
            - target (dict) containing:
                - boxes:    FloatTensor[N, 4], N being the number of instances, each
                  box in [x0, y0, x1, y1] format, ranging from 0 to W and 0 to H;
                - labels:   Int64Tensor[N], class label (currently always 1);
                - image_id: Int64Tensor[1], index of the sample in this dataset;
                - area:     FloatTensor[N], `area` field of each annotation;
                - iscrowd:  Int64Tensor[N], currently always 0;
                - masks:    UInt8Tensor[N, H, W], segmentation maps;
        '''
        img_id = self.img_ids[idx]
        img_obj = self.coco_ds.loadImgs(img_id)[0]
        anns_obj = self.coco_ds.loadAnns(self.coco_ds.getAnnIds(imgIds=img_id))

        # Image.open is lazy and keeps the file open; convert() decodes the pixels
        # into an independent image, so the handle can be released right away.
        with Image.open(os.path.join(self.img_path, img_obj['file_name'])) as img_file:
            img = img_file.convert('RGB')

        width, height = img.size
        target = self._build_target(anns_obj, idx, height, width)

        if self.trans is not None and self.apply_trans:
            img, target = self.trans(img, target)

        return img, target

    def _build_target(self, anns_obj: list, idx: int, height: int, width: int) -> dict:
        '''Turn the annotation records of one image into the target tensors.'''
        n_objs = len(anns_obj)

        # COCO stores boxes as [x, y, width, height]; convert them to the
        # [x0, y0, x1, y1] corner format promised to the consumer.
        # reshape(n_objs, 4) keeps a well-formed (0, 4) tensor for empty images.
        boxes = torch.as_tensor([ann['bbox'] for ann in anns_obj],
                                dtype=torch.float32).reshape(n_objs, 4)
        boxes[:, 2:] += boxes[:, :2]

        if n_objs:
            # Stack into one ndarray first: building a tensor straight from a
            # list of ndarrays is very slow.
            masks = torch.as_tensor(
                np.stack([self.coco_ds.annToMask(ann) for ann in anns_obj]),
                dtype=torch.uint8)
        else:
            masks = torch.zeros((0, height, width), dtype=torch.uint8)

        target = {}
        target["boxes"] = boxes
        # NOTE(review): every instance gets label 1 and ann['category_id'] is
        # ignored -- confirm whether class-agnostic targets are intended.
        target["labels"] = torch.ones(n_objs, dtype=torch.int64)
        target["masks"] = masks
        target["image_id"] = torch.tensor([idx])
        target["area"] = torch.as_tensor([ann['area'] for ann in anns_obj],
                                         dtype=torch.float32)
        # NOTE(review): the annotations' own `iscrowd` field is not consulted.
        target["iscrowd"] = torch.zeros(n_objs, dtype=torch.int64)
        return target

    def __len__(self) -> int:
        # One sample per image id (the previous `self.imgs` attribute never existed).
        return len(self.img_ids)

In [None]:
# Transform pipeline handed to the dataset: a single ToTensor step.
trans = transforms.Compose([transforms.ToTensor()])

# Dataset over the training images and their COCO annotations.
a = FoodDataset(img_path=TRAIN_IMAGES_PATH, coco_ds_path=TRAIN_LABELS, trans=trans)

In [None]:
# Fetch the first (image, target) sample as a smoke test of the dataset.
c = a[0]

In [None]:
# Pick the computation device: GPU when available, otherwise CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Pretrained Mask R-CNN with 91 output classes.
# The original note "498" is presumably the class count of the food dataset -- TODO confirm.
model = torchvision.models.detection.maskrcnn_resnet50_fpn(
    pretrained=True,
    progress=True,
    num_classes=91,
)

# Move the model onto the computation device and switch it to eval mode.
model.to(device)
model.eval()

In [None]:

# Image of the first dataset sample (the target is not needed for inference).
data = a[0][0]

# Assumes `data` converts to an (H, W, C) array with values in 0-255.
# Scale to [0, 1] and move channels first: (H, W, C) -> (C, H, W).
# permute keeps height and width in order; transposing axes 0 and 2 would
# swap them and feed the model a transposed image.
image = (
    torch.from_numpy(np.array(data) / 255.0)
    .permute(2, 0, 1)
    .unsqueeze(0)  # add the batch dimension
    .float()
    .to(device)
)

# Inference only: skip autograd bookkeeping.
with torch.no_grad():
    predictions = model(image)

predictions

In [None]:
# Inspect the shape of the tensor that was fed to the model.
image.shape

In [None]:
# Leftover exploration: displays the repr of the `transforms` name imported at the top.
transforms