In [1]:
import numpy as np
from PIL import Image
from torchvision.transforms import PILToTensor, ToTensor, Compose, Resize, ToPILImage
from src.datasets.utils.ResizeToDivisibleBy32 import ResizeToDivisibleBy32
from src.datasets.utils.ConvertUAVidMasks import ConvertUAVidMasks

## Custom transform

In [2]:
UAVID_SINGLE_MASK_PATH = r"D:\__repos\aerial_segmentation\data\UAVidSemanticSegmentationDataset\train\train\seq5\Labels\000400.png"

In [3]:
color_classes_dict = {
    (0, 0, 0): 0,  # Clutter
    (128, 0, 0): 1,  # Building
    (128, 64, 128): 2,  # Road
    (192, 0, 192): 3,  # Static_Car
    (0, 128, 0): 4,  # Tree
    (128, 128, 0): 5,  # Vegetation
    (64, 64, 0): 6,  # Human
    (64, 0, 128): 7,  # Moving_Car
}

In [4]:
convert_uavid_masks = ConvertUAVidMasks(color_classes_dict)
pil_to_tensor_transform = PILToTensor()
to_pil_transform = ToPILImage()

In [5]:
mask = Image.open(UAVID_SINGLE_MASK_PATH)

In [6]:
print(mask.getextrema())
print(mask.size)
print(mask.mode)

((0, 192), (0, 128), (0, 192))
(3840, 2160)
RGB


In [7]:
pil_to_tensor_transform = PILToTensor()
torch_mask = pil_to_tensor_transform(mask)

In [8]:
print(torch_mask.shape)
print(torch_mask.min())
print(torch_mask.max())

torch.Size([3, 2160, 3840])
tensor(0, dtype=torch.uint8)
tensor(192, dtype=torch.uint8)


In [10]:
if mask.mode != "P":
    convert_uavid_masks_transform = ConvertUAVidMasks(color_classes_dict)
    torch_mask = convert_uavid_masks_transform(torch_mask)

In [11]:
print(torch_mask.shape)
print(torch_mask.min())
print(torch_mask.max())

torch.Size([2160, 3840])
tensor(0, dtype=torch.uint8)
tensor(6, dtype=torch.uint8)


## Sanity check dataloader class

In [12]:
from torch.utils.data import DataLoader
from torchvision import transforms

from src.datasets.UAVidSemanticSegmentationDataset import (
    UAVidSemanticSegmentationDataset,
)

In [13]:
UAVID_DATASET_PATH = "data/UAVidSemanticSegmentationDataset"
IMAGE_SIZE = 576

In [14]:
example_dataset = UAVidSemanticSegmentationDataset(
    UAVID_DATASET_PATH,
    # transforms=[Compose([Resize(IMAGE_SIZE), ResizeToDivisibleBy32()])]
)
print(len(example_dataset))

200


In [15]:
example_loader = DataLoader(example_dataset, batch_size=1, shuffle=True)

In [16]:
for images, masks in example_loader:
    print(images.shape)
    print(masks.shape)
    break

torch.Size([1, 3, 2160, 4096])
torch.Size([1, 3, 2160, 4096])
