In [11]:
# Import statements
import torch
import torchvision as tv
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor as faster_rcnn_predictor
import torch.utils.tensorboard
import os
import numpy as np
from PIL import Image 

In [12]:
# Dataset Sub Class
class PennFudanDataset(object):
    """Map-style dataset yielding ``(image, target)`` pairs for detection.

    ``root`` must contain two sub-directories, ``PNGImages`` and ``PedMasks``.
    Images and masks are paired by their position in the sorted directory
    listings, so both directories are expected to hold matching files.

    Args:
        root: Path of the dataset directory.
        transforms: Optional callable invoked as ``transforms(image, target)``
            that returns the (possibly modified) pair; ``None`` disables it.
    """

    # Single source of truth for the sub-directory names, so the listing done
    # in __init__ and the loading done in __getitem__ cannot drift apart.
    IMAGE_DIR = "PNGImages"
    MASK_DIR = "PedMasks"

    def __init__(self, root, transforms):
        self.root = root
        self.transforms = transforms

        # Sorted so that index i of `images` lines up with index i of `masks`.
        self.images = list(sorted(os.listdir(os.path.join(root, self.IMAGE_DIR))))
        self.masks = list(sorted(os.listdir(os.path.join(root, self.MASK_DIR))))

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns:
            ``(image, target)`` where ``image`` is the RGB image and ``target``
            is a dict with the keys ``boxes`` (float32, ``[N, 4]`` as
            ``x_min, y_min, x_max, y_max``), ``labels`` (int64 ones),
            ``masks`` (uint8, one binary mask per instance), ``image_id``,
            ``area`` (box height * box width) and ``iscrowd`` (int64 zeros).
            Both values are passed through ``self.transforms`` when it is set.
        """
        image_path = os.path.join(self.root, self.IMAGE_DIR, self.images[idx])
        mask_path = os.path.join(self.root, self.MASK_DIR, self.masks[idx])

        image = Image.open(image_path).convert("RGB")

        # The mask is deliberately not converted to RGB: each distinct pixel
        # value identifies one instance.
        mask = np.array(Image.open(mask_path))

        # np.unique returns sorted values; the first (smallest) one is dropped
        # on the assumption that it is the background id -- TODO confirm for
        # masks that contain no background pixels.
        object_ids = np.unique(mask)[1:]

        # Broadcast comparison: one boolean mask per remaining instance id.
        masks = mask == object_ids[:, None, None]

        number_objs = len(object_ids)
        boxes = []
        for index in range(number_objs):
            position = np.where(masks[index])
            boxes.append([
                int(np.min(position[1])),  # x_min
                int(np.min(position[0])),  # y_min
                int(np.max(position[1])),  # x_max
                int(np.max(position[0])),  # y_max
            ])

        # Convert before computing the area: a Python list cannot be sliced
        # with [:, k]. The reshape keeps an empty result at shape (0, 4)
        # instead of (0,), so the column slicing below is always valid.
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)

        target = {
            "boxes": boxes,
            "labels": torch.ones((number_objs, ), dtype=torch.int64),
            "masks": torch.as_tensor(masks, dtype=torch.uint8),
            "image_id": torch.tensor(idx),
            "area": (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0]),
            "iscrowd": torch.zeros((number_objs, ), dtype=torch.int64)
        }

        if self.transforms is not None:
            image, target = self.transforms(image, target)

        return image, target

    def __len__(self):
        """Return the number of files found in the image directory."""
        return len(self.images)

In [13]:
# Display the installed torchvision version. The package is imported here
# under its full name; the first cell binds the same package as `tv`.
import torchvision
torchvision.__version__


'0.5.0'