In [None]:
from pycocotools.coco import COCO

# Location of the local COCO 2017 copy and the split to inspect.
dataDir = '/home/delong/workspace/dataset/coco2017'
dataType = 'val2017'
# Instance-annotation JSON for the chosen split, e.g. .../annotations/instances_val2017.json
annFile = f'{dataDir}/annotations/instances_{dataType}.json'
# Build the pycocotools index over that annotation file.
coco = COCO(annFile)

In [None]:
import torch
from torch.utils.data import Dataset
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from pycocotools.coco import COCO

class COCODataset(Dataset):
    """Map-style dataset over a COCO instance-annotation split.

    Each item is ``(file_name, masks, labels)`` for one image: the image's
    file name as stored in the annotation JSON, a stack of per-annotation
    binary masks, and the matching category ids.
    """

    def __init__(self, dataDir, dataType):
        """Load ``<dataDir>/annotations/instances_<dataType>.json``.

        Args:
            dataDir: Root directory of the COCO dataset.
            dataType: Split name, e.g. ``'val2017'``.
        """
        self.dataDir = dataDir
        self.dataType = dataType
        self.coco = COCO('{}/annotations/instances_{}.json'.format(dataDir, dataType))
        # Sorted image ids give a deterministic index -> image mapping.
        self.ids = list(sorted(self.coco.imgs.keys()))

    def __getitem__(self, index):
        """Return ``(file_name, masks, labels)`` for the image at position *index*.

        ``masks`` stacks one binary mask per annotation along axis 0, and
        ``labels`` holds the matching ``category_id`` values as int64.
        An image with no annotations yields ``masks`` of shape
        ``(0, height, width)`` and an empty ``labels`` array instead of raising.
        """
        coco = self.coco
        img_id = self.ids[index]
        img_info = coco.loadImgs(img_id)[0]
        annotations = coco.loadAnns(coco.getAnnIds(imgIds=img_id))

        labels = np.array([ann['category_id'] for ann in annotations], dtype=np.int64)
        if annotations:
            masks = np.stack([coco.annToMask(ann) for ann in annotations], axis=0)
        else:
            # np.stack([]) raises ValueError, so build an empty stack with the
            # image's spatial size. uint8 is assumed to match annToMask output.
            masks = np.zeros((0, img_info['height'], img_info['width']), dtype=np.uint8)

        return img_info['file_name'], masks, labels

    def __len__(self):
        """Number of images in the split."""
        return len(self.ids)

# Create the dataset
dataDir = '/home/delong/workspace/dataset/coco2017'
dataType = 'val2017'
dataset = COCODataset(dataDir, dataType)

# Map category_id -> category name. Query the dataset's own COCO index rather
# than the `coco` object from the earlier cell, so this cell runs on its own.
id_to_name = {
    category['id']: category['name']
    for category in dataset.coco.loadCats(dataset.coco.getCatIds())
}
print(id_to_name)

# Visualize 5 randomly chosen samples: the image next to the sum of its instance masks.
for _ in range(5):
    path, masks, labels = dataset[np.random.randint(0, len(dataset))]
    # Open inside `with` so the file handle is released. convert() loads the
    # pixels into memory and shows grayscale images as RGB, not colormapped.
    with Image.open(f"{dataDir}/images/{dataType}/{path}") as img:
        rgb = img.convert('RGB')
    plt.figure(figsize=(10, 5))
    plt.subplot(1, 2, 1)
    plt.imshow(rgb)
    plt.subplot(1, 2, 2)
    # Overlapping instances add up, so brighter pixels mean more masks cover them.
    plt.imshow(np.sum(masks, axis=0))
    plt.show()

    print([id_to_name[label] for label in labels])
    