In [1]:
import json
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torchvision.transforms as transforms
from torch.utils.data import Dataset
from torchvision.datasets import Cityscapes

In [2]:
# Remove stray hidden files (e.g. macOS ".DS_Store") that sit directly inside the
# train split folder. Done with pathlib instead of `!rm .../.*` so the cell does not
# fail when nothing matches (zsh aborts on an unmatched glob) and also runs on Windows.
for hidden_path in Path("./cityscapes/leftImg8bit/train").glob(".*"):
    # Like plain `rm`, only delete regular files and leave hidden directories alone.
    if hidden_path.is_file():
        hidden_path.unlink()


In [108]:
import os
from PIL import Image


# Per-channel statistics written on the 0-255 pixel scale and divided down to the
# [0, 1] scale; MEAN and STD are what the Normalize steps of this notebook receive.
# NOTE(review): presumably measured on the Cityscapes training images, in R, G, B
# order -- confirm against wherever these numbers were computed.
MEAN = np.array(
    [72.55410438, 81.93415236, 71.4297832]
) / 255.0
STD = np.array(
    [51.04788791, 51.76003371, 50.94766331]
) / 255.0

# ImageNet statistics, already on the [0, 1] scale.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


class CustomCityScapes(Dataset):
    """Cityscapes dataset pairing each left-camera frame with its gtFine annotations.

    Directory layout read by the constructor::

        <root_dir>/leftImg8bit/<split>/<city>/<id>_leftImg8bit.png
        <root_dir>/gtFine/<split>/<city>/<id>_gtFine_<category>.<ext>

    where ``<category>`` is one of ``labelIds``, ``color``, ``instanceIds`` or
    ``polygons``. Files that do not follow these naming patterns (for example
    ``.DS_Store``) are ignored.

    Args:
        root_dir: Folder that contains ``leftImg8bit`` and ``gtFine``.
        split: Name of the sub-folder to read, e.g. ``"train"``.
        target_type: One target name or a sequence of names out of
            ``"semantic"``, ``"color"``, ``"instance"``, ``"polygons"``
            (``"polygon"`` is accepted as an alias).
        image_transforms: Optional callable applied to the opened PIL image.
        mask_transforms: Optional callable applied to every image annotation.
            Polygon annotations are JSON and are returned as parsed objects
            without going through this callable.

    Raises:
        KeyError: If a target name is unknown, or an image has no annotation
            file for one of the requested targets.
    """

    IMAGE_SUFFIX = "_leftImg8bit.png"
    ANNOT_SEPARATOR = "_gtFine_"
    # Public target name -> category token used in the annotation file names.
    TARGET_TYPE_MAP = {
        'semantic': 'labelIds',
        'color': 'color',
        'polygons': 'polygons',
        'polygon': 'polygons',
        'instance': 'instanceIds',
    }

    def __init__(self, root_dir, split="train", target_type=("semantic",), image_transforms=None, mask_transforms=None):
        self.image_transforms = image_transforms
        self.mask_transforms = mask_transforms
        # e.g. ./cityscapes/leftImg8bit/train and ./cityscapes/gtFine/train
        self.images_split_path = os.path.join(root_dir, "leftImg8bit", split)
        self.annot_split_path = os.path.join(root_dir, "gtFine", split)

        # Keep only real camera frames; any other file would turn into a bogus
        # image id with no matching annotation.
        self.image_paths_only = sorted(
            path
            for path in self._iter_files(self.images_split_path)
            if path.endswith(self.IMAGE_SUFFIX)
        )

        # ".../zurich/zurich_000069_000019_leftImg8bit.png" -> "zurich_000069_000019"
        self.image_ids = [
            os.path.basename(path)[:-len(self.IMAGE_SUFFIX)]
            for path in self.image_paths_only
        ]
        # "zurich_000069_000019" -> ".../zurich/zurich_000069_000019_leftImg8bit.png"
        self.image_paths = dict(zip(self.image_ids, self.image_paths_only))

        # A bare string would otherwise be iterated character by character.
        if isinstance(target_type, str):
            target_type = [target_type]
        self.target_type = [self.TARGET_TYPE_MAP[target] for target in target_type]

        # Walk the annotation tree once, then slice it per requested category.
        self.annot_path_local = self._scan_annotations()
        per_type_paths = [self._paths_for(type_) for type_ in self.target_type]
        if per_type_paths:
            # One tuple per image, holding one path per requested target.
            self.annot_paths = list(zip(*per_type_paths))
        else:
            # No targets requested: still provide one (empty) entry per image.
            self.annot_paths = [() for _ in self.image_ids]

    def __iter__(self):
        for index_ in range(len(self)):
            yield self[index_]

    def __getitem__(self, index):
        """Return ``(image, annotations)`` for the image at position ``index``.

        ``annotations`` is a list with one entry per requested target type, in
        the order the types were given to the constructor.
        """
        img = Image.open(self.image_paths_only[index])
        if self.image_transforms:
            img = self.image_transforms(img)

        annots = [self._load_annotation(path) for path in self.annot_paths[index]]
        return img, annots

    def __repr__(self):
        return "Custom Cityscapes dataset"

    def __len__(self):
        return len(self.image_paths)

    def get_source_image_paths(self, id_=False):
        """Return the ``{image_id: path}`` mapping, or one path when ``id_`` is given."""
        if not id_:
            return self.image_paths
        return self.image_paths[id_]

    def get_image(self, id_):
        """Open and return the untransformed image that belongs to ``id_``."""
        return Image.open(self.image_paths[id_])

    def get_annotation_paths(self, type_):
        """Rescan the annotation folder and return the ``type_`` paths.

        Args:
            type_: File-name category: ``"labelIds"``, ``"color"``,
                ``"instanceIds"`` or ``"polygons"``.

        Returns:
            A list of paths aligned with ``self.image_ids``.

        Raises:
            KeyError: If ``type_`` is not a known category or an image id has
                no annotation of that category.
        """
        self.annot_path_local = self._scan_annotations()
        return self._paths_for(type_)

    @staticmethod
    def _iter_files(split_path):
        """Yield the path of every file found below the sub-folders of ``split_path``."""
        for city in os.listdir(split_path):
            # os.walk yields nothing when ``city`` is a plain file, so files
            # lying directly in the split folder are skipped.
            for dir_path, _, files in os.walk(os.path.join(split_path, city)):
                for file in files:
                    yield os.path.join(dir_path, file)

    def _scan_annotations(self):
        """Index the split's annotation files as ``{category: {image_id: path}}``."""
        index = {'polygons': {}, 'color': {}, 'labelIds': {}, 'instanceIds': {}}
        for file_path in self._iter_files(self.annot_split_path):
            stem = os.path.splitext(os.path.basename(file_path))[0]
            # "zurich_000069_000019_gtFine_labelIds" -> ("zurich_000069_000019", "labelIds")
            file_id, separator, category = stem.partition(self.ANNOT_SEPARATOR)
            # Skip anything that is not a recognised annotation file.
            if separator and category in index:
                index[category][file_id] = file_path
        return index

    def _paths_for(self, type_):
        """Look up the already indexed ``type_`` path of every image id."""
        paths_by_id = self.annot_path_local[type_]
        return [paths_by_id[id_] for id_ in self.image_ids]

    def _load_annotation(self, path):
        """Load one annotation: parsed JSON for polygon files, a PIL image otherwise."""
        if path.endswith(".json"):
            with open(path, encoding="utf-8") as handle:
                return json.load(handle)

        annot = Image.open(path)
        if self.mask_transforms:
            annot = self.mask_transforms(annot)
        return annot


# Dataset class ends here

def crop_eco_vehicle(image, bottom=-180):
    """Return ``image`` cropped to full width with its lower edge moved by ``bottom`` pixels.

    Args:
        image: Object exposing ``size`` as ``(width, height)`` and a
            ``crop(box)`` method, such as a PIL image.
        bottom: Offset added to the image height to obtain the lower crop
            edge; the default of -180 drops the lowest 180 rows.
            NOTE(review): presumably this strip is the recording vehicle's
            hood -- confirm.

    Returns:
        Whatever ``image.crop`` returns for the box ``(0, 0, width, height + bottom)``.
    """
    right_edge, full_height = image.size
    lower_edge = full_height + bottom
    crop_box = (0, 0, right_edge, lower_edge)
    return image.crop(crop_box)


def _crop_and_resize_steps():
    """Build the geometric steps every pipeline starts with: bottom crop, then resize to 256x512."""
    return [
        transforms.Lambda(crop_eco_vehicle),
        transforms.Resize((256, 512), interpolation=transforms.InterpolationMode.NEAREST),
    ]


def _normalized_tensor_steps():
    """Build the closing steps of the image pipelines: tensor conversion, then MEAN/STD normalisation."""
    return [
        transforms.ToTensor(),
        transforms.Normalize(mean=MEAN, std=STD),
    ]


# Training images: geometry, random photometric jitter, then normalised tensor.
train_transforms = transforms.Compose(
    _crop_and_resize_steps()
    + [transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3)]
    + _normalized_tensor_steps()
)

# Evaluation images: same as training but without the random jitter.
eval_transforms = transforms.Compose(
    _crop_and_resize_steps() + _normalized_tensor_steps()
)

# Annotation masks: same geometry as the images so pixels stay aligned, no normalisation.
# NOTE(review): ToTensor rescales 8-bit images to [0, 1], so the label ids of the
# semantic mask presumably come out divided by 255 -- confirm downstream code expects that.
mask_transforms = transforms.Compose(
    _crop_and_resize_steps() + [transforms.ToTensor()]
)

dataset = CustomCityScapes(
    "./cityscapes",
    split="train",
    target_type=["color", "semantic"],
    image_transforms=train_transforms,
    mask_transforms=mask_transforms,
)
# dataset = CustomCityScapes("./cityscapes", split="train", target_type=["color", "semantic"], image_transforms=None, mask_transforms=None)

In [110]:
# Checking the images

# plt.figure(figsize=(10, 10))
# subplot_idx = 1
# for i in range(5):
#     im, (color, sem) = dataset[i]
#     im = np.array(im).transpose(1, 2, 0)
#     # Image
#     plt.subplot(5, 3, subplot_idx)
#     subplot_idx += 1
#     plt.imshow(im)
#     plt.axis("off")
#     # Color info
#     plt.subplot(5, 3, subplot_idx)
#     subplot_idx += 1
#     plt.imshow(color)
#     plt.axis("off")
#     # Semantic info
#     plt.subplot(5, 3, subplot_idx)
#     subplot_idx += 1
#     plt.imshow(sem, cmap="gray")
#     plt.axis("off")

# plt.tight_layout()