In [1]:
import torch
import torchvision
from torchvision.datasets import VisionDataset
from pycocotools.coco import COCO
import os
import cv2
from copy import deepcopy
import albumentations as A
import numpy as np
from albumentations.pytorch import ToTensorV2

In [2]:
def transform_image(train=False):
  """Build the Albumentations pipeline applied to an image and its boxes.

  Every pipeline starts with a resize and ends with ``ToTensorV2``. When
  ``train`` is truthy, random horizontal/vertical flips and two colour
  perturbations are inserted in between. Bounding boxes are declared to be
  in ``'coco'`` format.

  Args:
    train: Truthy to include the random augmentations; falsy (default) for
      the resize-and-convert pipeline only.

  Returns:
    The ``A.Compose`` object wrapping the selected steps.
  """
  # NOTE(review): the third positional argument of A.Resize appears to be
  # `interpolation` (3 would be cv2.INTER_AREA), not a channel count, and the
  # first two are presumably (height, width) -- confirm 640x480 is not swapped.
  steps = [A.Resize(640, 480, 3)]

  if train:
    # Augmentations are only wanted while training.
    steps += [
        A.HorizontalFlip(p=0.3),
        A.VerticalFlip(p=0.3),
        A.RandomBrightnessContrast(p=0.1),
        A.ColorJitter(p=0.1),
    ]

  steps.append(ToTensorV2())
  return A.Compose(steps, bbox_params=A.BboxParams(format='coco'))

In [3]:
from typing import Any


class COCODataset(VisionDataset):
    """COCO person-keypoints detection dataset yielding ``(image, target)`` pairs.

    Annotations are read from
    ``<root>/annotations/person_keypoints_val2017.json`` and images from
    ``<root>/images/val2017``. Images without any annotation are skipped.

    Args:
        root: Dataset root directory.
        split: Stored as ``self.split``. It does not currently select the
            files; both paths are fixed to ``val2017``.
        transforms: Optional callable invoked as
            ``transforms(image=..., bboxes=...)`` and returning a mapping with
            ``'image'`` and ``'bboxes'`` entries (e.g. an Albumentations
            ``Compose`` using COCO-format boxes). When ``None`` the image and
            boxes are used unchanged.
        transform: Forwarded to ``VisionDataset.__init__``.
        target_transform: Forwarded to ``VisionDataset.__init__``.
    """

    def __init__(self, root, split='train', transforms=None, transform=None, target_transform=None):
        super().__init__(root, transforms, transform, target_transform)
        self.root = root
        # NOTE(review): `split` is not used to pick the files; the annotation
        # and image paths are hard-coded to val2017. Kept as-is so existing
        # callers relying on the default keep loading the same data.
        self.json = os.path.join(root, 'annotations/person_keypoints_val2017.json')

        # Read annotations
        self.coco = COCO(self.json)
        self.split = split

        # Sorted image ids, keeping only images that have at least one annotation.
        all_ids = sorted(img['id'] for img in self.coco.imgs.values())
        self.ids = [img_id for img_id in all_ids if len(self.coco.getAnnIds(img_id)) > 0]

        # Use the joint (image, bboxes) callable exactly as it was passed in.
        self.transforms = transforms

    def _load_image(self, id):
        """Read the image with COCO id ``id`` and return it as an RGB HWC array.

        Raises:
            FileNotFoundError: If OpenCV cannot read the file.
        """
        path = os.path.join(self.root, 'images/val2017', self.coco.imgs[id]['file_name'])
        image = cv2.imread(path)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"Could not read image: {path}")
        # OpenCV decodes to BGR; convert so downstream code sees RGB.
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    def __getitem__(self, index: int):
        """Return the ``index``-th sample as ``(image, target)``.

        Returns:
            A tuple ``(image, target)`` where ``image`` is a channel-first
            tensor divided by 255 and ``target`` is a dict with:

            * ``'boxes'``: float32 tensor of shape ``(N, 4)`` in
              ``[x1, y1, x2, y2]`` order,
            * ``'labels'``: int64 tensor of category ids,
            * ``'image_id'``: int64 tensor with one image id per box,
            * ``'iscrowd'``: int64 tensor of crowd flags.

            All four tensors describe the same ``N`` boxes, i.e. the boxes
            that are still present after ``transforms`` ran.

        Raises:
            ValueError: If ``transforms`` changed the number of boxes and also
                stripped the per-box bookkeeping values, so labels can no
                longer be matched to boxes.
        """
        img_id = self.ids[index]
        image = self._load_image(img_id)
        anns = deepcopy(self.coco.loadAnns(self.coco.getAnnIds(img_id)))

        # Row layout: [x, y, w, h, category_id, annotation_index]. The trailing
        # index lets us tell which annotations survive the transform.
        rows = [list(ann['bbox']) + [ann['category_id'], pos] for pos, ann in enumerate(anns)]

        if self.transforms is not None:
            # Joint transform of image and boxes (e.g. Albumentations Compose).
            transformed = self.transforms(image=image, bboxes=rows)
            image = transformed['image']
            rows = transformed['bboxes']

        if all(len(row) > 5 for row in rows):
            # The index may come back as a float; round before casting.
            kept = [int(round(float(row[5]))) for row in rows]
        elif len(rows) == len(anns):
            # Extras were stripped but nothing was dropped: align by position.
            kept = list(range(len(anns)))
        else:
            raise ValueError(
                "transforms changed the number of boxes and removed the extra "
                "per-box values, so labels cannot be aligned with boxes"
            )

        # Convert [x, y, w, h] to [x1, y1, x2, y2].
        xyxy = []
        for row in rows:
            x, y, w, h = (float(v) for v in row[:4])
            xyxy.append([x, y, x + w, y + h])
        # reshape keeps the (N, 4) layout even when no box is left (N == 0).
        boxes = torch.tensor(xyxy, dtype=torch.float32).reshape(-1, 4)

        if not torch.is_tensor(image):
            # No tensor conversion happened in `transforms`: go from HWC array
            # to CHW tensor here so the return type is always a tensor.
            image = torch.from_numpy(np.ascontiguousarray(image)).permute(2, 0, 1)

        final_target = {}
        final_target['boxes'] = boxes
        final_target['labels'] = torch.tensor([anns[i]['category_id'] for i in kept], dtype=torch.int64)
        # NOTE(review): one image id per box is kept for backward
        # compatibility; some torchvision reference utilities appear to expect
        # a single id per target -- confirm against the training loop.
        final_target['image_id'] = torch.tensor([anns[i]['image_id'] for i in kept], dtype=torch.int64)
        final_target['iscrowd'] = torch.tensor([anns[i]['iscrowd'] for i in kept], dtype=torch.int64)

        return image.div(255), final_target

    def __len__(self):
        """Return the number of images that have at least one annotation."""
        return len(self.ids)
        