In [1]:
import os
import sys

sys.path.append(os.path.join(os.getcwd(), ".."))
import numpy as np
from tqdm import tqdm

import torch
from torchvision import datasets

from source.constants import CITYSCAPES_PATH

In [2]:
# Regression task: Define a function to compute the number of pixels for a specific class
def compute_class_pixel_count(targets, class_index):
    """
    Compute the number of pixels belonging to a specific class in the segmentation maps.

    Args:
        targets (Tensor): Segmentation mask(s): either a batch (B, H, W) or a
            single mask (H, W), which is treated as a batch of size 1.
        class_index (int): Index of the class to count pixels for.
    Returns:
        Tensor: Number of pixels belonging to the specified class for each image
            in the batch (B,). A single (H, W) mask yields a tensor of shape (1,).
    """
    if targets.dim() == 2:
        # Promote a single mask to a batch of one so the reduction below is uniform.
        targets = targets.unsqueeze(0)
    # reshape rather than view: the boolean mask can inherit a non-contiguous
    # layout from `targets` (e.g. a transposed/permuted mask), where view() raises.
    return (targets == class_index).reshape(targets.size(0), -1).sum(dim=1)

In [3]:
# Label ids of the classes whose pixel counts are used as regression targets.
# A nested list groups several label ids into one target (their counts are summed).
#   ground: 6, road: 7  (grouped)
#   building: 11
#   sky: 23
#   car: 26
#   vegetation: 21
target_classes = [[6, 7], 11, 23, 26, 21]


# * UTILITY
def get_tile(x, i, j, tile_size=256):
    """
    Return tile (i, j) of a regular grid of square tiles laid over `x`.

    Args:
        x: Array/tensor whose first two dimensions are tiled (rows, columns);
            any trailing dimensions are kept as they are.
        i (int): Tile row index.
        j (int): Tile column index.
        tile_size (int): Side length of a tile in pixels. Defaults to 256.
    Returns:
        The slice of `x` covering rows [tile_size * i, tile_size * (i + 1)) and
        columns [tile_size * j, tile_size * (j + 1)).
    """
    row_start, col_start = tile_size * i, tile_size * j
    return x[row_start : row_start + tile_size, col_start : col_start + tile_size]


def center_crop(x, crop_size=224, tile_size=256):
    """
    Return the centered crop_size x crop_size window of a tile_size x tile_size tile.

    Args:
        x: Array/tensor whose first two dimensions are the tile's rows and columns.
        crop_size (int): Side length of the crop. Defaults to 224.
        tile_size (int): Side length of the tile being cropped. Defaults to 256.
    Returns:
        The slice of `x` covering rows and columns [offset, offset + crop_size),
        where offset = (tile_size - crop_size) // 2 (16 with the defaults, i.e. x[16:240, 16:240]).
    """
    offset = (tile_size - crop_size) // 2
    return x[offset : offset + crop_size, offset : offset + crop_size]

In [4]:
# Load the Cityscapes "train" split (fine annotations, semantic label maps).
# Both the image and its label map are converted PIL -> numpy -> torch tensor.
cityscapes_dataset = datasets.Cityscapes(
    root=CITYSCAPES_PATH,
    split="train",
    mode="fine",
    target_type="semantic",
    transform=lambda img: torch.tensor(np.array(img), dtype=torch.uint8),
    # going through a numpy array first is needed for the conversion
    target_transform=lambda seg: torch.tensor(np.array(seg), dtype=torch.int32),
)

images, targets = [], []

for image, target in tqdm(cityscapes_dataset):
    # Walk the 4 x 8 grid of tiles row by row (i = tile row, j = tile column).
    for tile_index in range(4 * 8):
        i, j = divmod(tile_index, 8)

        # Image patch: center crop of the tile, moved to channels-first layout.
        images.append(center_crop(get_tile(image, i, j)).permute(2, 0, 1))

        # Regression target: one pixel count per entry of target_classes.
        target_patch = center_crop(get_tile(target, i, j))
        class_pixels = []
        for class_indices in target_classes:
            # A bare int is a single class; a list is a group counted together.
            group = [class_indices] if isinstance(class_indices, int) else class_indices
            count = 0
            for class_index in group:
                count = count + compute_class_pixel_count(target_patch, class_index)
            class_pixels.append(count)
        targets.append(torch.tensor(class_pixels, dtype=torch.int32))

  0%|          | 0/2975 [00:00<?, ?it/s]

100%|██████████| 2975/2975 [03:17<00:00, 15.07it/s]


In [5]:
# Keep a random subset of N patches for the training set.
# Sampling is done WITHOUT replacement so no patch appears twice in the saved
# data, and with a fixed seed so the subset is the same on every run.
N = 10000
SEED = 0
rng = np.random.default_rng(SEED)
# Never request more patches than exist (choice without replacement would raise).
indices = rng.choice(len(images), size=min(N, len(images)), replace=False)
images = [images[i] for i in indices]
targets = [targets[i] for i in indices]

In [6]:
# Turn the lists of per-patch tensors into one tensor each (new leading dimension).
images, targets = torch.stack(images, dim=0), torch.stack(targets, dim=0)

# Sanity check: report shapes and dtypes before writing to disk.
print(*(stacked.shape for stacked in (images, targets)))
print(*(stacked.dtype for stacked in (images, targets)))

# Persist both tensors next to the dataset.
train_images_path = os.path.join(CITYSCAPES_PATH, "train_images.pt")
train_targets_path = os.path.join(CITYSCAPES_PATH, "train_targets.pt")
torch.save(images, train_images_path)
torch.save(targets, train_targets_path)

torch.Size([10000, 3, 224, 224]) torch.Size([10000, 5])
torch.uint8 torch.int32


In [7]:
# Load the Cityscapes "val" split (fine annotations, semantic label maps).
# Both the image and its label map are converted PIL -> numpy -> torch tensor.
cityscapes_dataset = datasets.Cityscapes(
    root=CITYSCAPES_PATH,
    split="val",
    mode="fine",
    target_type="semantic",
    transform=lambda img: torch.tensor(np.array(img), dtype=torch.uint8),
    # going through a numpy array first is needed for the conversion
    target_transform=lambda seg: torch.tensor(np.array(seg), dtype=torch.int32),
)

images, targets = [], []

for image, target in tqdm(cityscapes_dataset):
    # Walk the 4 x 8 grid of tiles row by row (i = tile row, j = tile column).
    for tile_index in range(4 * 8):
        i, j = divmod(tile_index, 8)

        # Image patch: center crop of the tile, moved to channels-first layout.
        images.append(center_crop(get_tile(image, i, j)).permute(2, 0, 1))

        # Regression target: one pixel count per entry of target_classes.
        target_patch = center_crop(get_tile(target, i, j))
        class_pixels = []
        for class_indices in target_classes:
            # A bare int is a single class; a list is a group counted together.
            group = [class_indices] if isinstance(class_indices, int) else class_indices
            count = 0
            for class_index in group:
                count = count + compute_class_pixel_count(target_patch, class_index)
            class_pixels.append(count)
        targets.append(torch.tensor(class_pixels, dtype=torch.int32))

  2%|▏         | 12/500 [00:00<00:31, 15.32it/s]

100%|██████████| 500/500 [00:33<00:00, 14.76it/s]


In [8]:
# Turn the lists of per-patch tensors into one tensor each (new leading dimension).
images, targets = torch.stack(images, dim=0), torch.stack(targets, dim=0)

# Sanity check: report shapes and dtypes before writing to disk.
print(*(stacked.shape for stacked in (images, targets)))
print(*(stacked.dtype for stacked in (images, targets)))

# Persist both tensors next to the dataset; all patches are kept (no subsampling here).
test_images_path = os.path.join(CITYSCAPES_PATH, "test_images.pt")
test_targets_path = os.path.join(CITYSCAPES_PATH, "test_targets.pt")
torch.save(images, test_images_path)
torch.save(targets, test_targets_path)

torch.Size([16000, 3, 224, 224]) torch.Size([16000, 5])
torch.uint8 torch.int32
