In [1]:
import torch
import torchvision
from torchvision.datasets import VisionDataset
from pycocotools.coco import COCO
import os
import cv2
from copy import deepcopy
import albumentations as A
import numpy as np
from albumentations.pytorch import ToTensorV2

In [2]:
def transform_image(train=False):
  """Build the albumentations pipeline for an image and its COCO-format boxes.

  Args:
    train: Selects the training pipeline. Every train-only augmentation is
      currently disabled, so both settings produce the same steps.

  Returns:
    An ``A.Compose`` that resizes the image, converts it with ``ToTensorV2``
    and is configured with ``bbox_params`` in ``'coco'`` format.
  """
  # Train-only augmentations; all of them are switched off at the moment.
  train_only_steps = [
      # A.HorizontalFlip(p=0.3),
      # A.VerticalFlip(p=0.3),
      # A.RandomBrightnessContrast(p=0.1),
      # A.ColorJitter(p=0.1),
  ]

  # NOTE(review): A.Resize takes (height, width, interpolation) positionally,
  # so the trailing 3 is an interpolation flag rather than a channel count,
  # and the output is 640 high by 480 wide -- confirm both are intended.
  steps = [A.Resize(640, 480, 3)]
  if train:
    steps.extend(train_only_steps)
  steps.append(ToTensorV2())

  return A.Compose(steps, bbox_params=A.BboxParams(format='coco'))

In [3]:
from typing import Any


class COCODataset(VisionDataset):
    """COCO person-keypoints dataset yielding fixed-size images and rescaled keypoints.

    Annotations are read from ``<root>/annotations/person_keypoints_val2017.json``
    and images from ``<root>/images/val2017``. Images without any annotation are
    skipped. Each item is resized to ``TARGET_WIDTH`` x ``TARGET_HEIGHT`` and the
    keypoint coordinates are scaled by the same factors.

    NOTE(review): ``split`` is stored but does not select which files are read;
    both paths are fixed to ``val2017``.
    """

    # Output image size in pixels; keypoint coordinates are rescaled to match it.
    TARGET_WIDTH = 640
    TARGET_HEIGHT = 480

    def __init__(self, root, split='train', transforms=None, transform=None, target_transform=None):
        """Index the annotation file and collect the ids of annotated images.

        Args:
            root: Dataset directory containing ``annotations/`` and ``images/``.
            split: Stored as ``self.split``; not used to choose files.
            transforms: Forwarded to ``VisionDataset``.
            transform: Forwarded to ``VisionDataset``.
            target_transform: Forwarded to ``VisionDataset``.
        """
        # The base class stores `transforms` (wrapping transform/target_transform
        # when those are given), so `self.transforms` is not reassigned here;
        # doing so would discard that wrapper.
        super().__init__(root, transforms, transform, target_transform)
        self.root = root
        self.json = os.path.join(root, 'annotations/person_keypoints_val2017.json')

        # Read annotations
        self.coco = COCO(self.json)
        self.split = split

        # Keep, in ascending id order, only the images that have annotations.
        image_ids = sorted(img['id'] for img in self.coco.imgs.values())
        self.ids = [
            image_id for image_id in image_ids
            if self.coco.getAnnIds(imgIds=image_id)
        ]

    def _load_image(self, id):
        """Read the image with the given COCO image id from disk.

        Args:
            id: COCO image id, a key of ``self.coco.imgs``.

        Returns:
            The array returned by ``cv2.imread`` (OpenCV loads colour images
            in BGR channel order).

        Raises:
            FileNotFoundError: If OpenCV cannot read the file.
        """
        path = os.path.join(self.root, 'images/val2017', self.coco.imgs[id]['file_name'])
        image = cv2.imread(path)
        # cv2.imread signals failure by returning None instead of raising.
        if image is None:
            raise FileNotFoundError(f"Could not read image: {path}")

        return image

    def __getitem__(self, index: int):
        """Return the resized image and rescaled keypoints for one sample.

        Args:
            index: Position in ``self.ids``.

        Returns:
            A tuple ``(torch_image, image, all_keypoints)``:
            ``torch_image`` is ``image`` converted with torchvision's
            ``ToTensor``; ``image`` is the resized array from OpenCV;
            ``all_keypoints`` holds, per annotation, a list of 3-element
            arrays whose first two entries are scaled to the resized image.

        Raises:
            FileNotFoundError: If the image file cannot be read.
        """
        image_id = self.ids[index]
        image = self._load_image(image_id)

        # Scale factors depend only on the image, so compute them once.
        height, width = image.shape[:2]
        width_scale_factor = self.TARGET_WIDTH / width
        height_scale_factor = self.TARGET_HEIGHT / height

        # Load anns data
        all_keypoints = []
        for annotation in self.coco.loadAnns(self.coco.getAnnIds(imgIds=image_id)):
            flat_keypoints = annotation['keypoints']
            # Split the flat list into consecutive groups of three values.
            points = np.array_split(flat_keypoints, len(flat_keypoints) // 3)
            # NOTE(review): if the annotation stores integers, these arrays have
            # an integer dtype and the assignments below truncate the scaled
            # coordinates to whole pixels -- confirm that is acceptable.
            for point in points:
                point[0] = point[0] * width_scale_factor
                point[1] = point[1] * height_scale_factor
            all_keypoints.append(points)

        # cv2.resize takes the target size as (width, height).
        image = cv2.resize(image, (self.TARGET_WIDTH, self.TARGET_HEIGHT))

        torch_image = torchvision.transforms.ToTensor()(image)
        return torch_image, image, all_keypoints

    # returns number of available images
    def __len__(self):
        """Return the number of images that have at least one annotation."""
        return len(self.ids)
        