In [1]:
# Basic Imports
from glob import glob
import re
import os
import numpy as np
import random
from tqdm import tqdm

# Torch Imports
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import ToPILImage

# Image utils imports
from PIL import Image
import albumentations as A
from albumentations.pytorch import ToTensorV2
import matplotlib.pyplot as plt

In [2]:
# Load config file and get potsdam data path
import yaml
config_path = '/home/tu/tu_tu/tu_zxmav84/DS_Project/modules/config.yml'
with open(config_path, 'r') as f:
    config = yaml.load(f, Loader=yaml.FullLoader)
orthophoto_dir = config['data']['orthophotos']

### Path Helper Functions

In [3]:
def get_file_paths():
    """
    Retrieves the file paths of image and mask files in the 'Patched' directory.

    Returns:
        dict: A dictionary containing the file paths of image and mask files.
              The dictionary is integer-indexed for easy sampling using an integer index.
              The keys are integers representing the index, and the values are tuples
              containing the image file path and the corresponding mask file path.

    """
    mask_files = glob(orthophoto_dir + '/labeling_subset/final_masks/*.tif')
    image_files = glob(orthophoto_dir + '/labeling_subset/images/*.tif')

    print(f"Indexing files in orthophotos/labeling_subset... \nFound: \t {len(image_files)} Images \n\t {len(mask_files)} Mask")

    # Get base name of all files and create dict with image and mask file paths
    pattern = '\d+_+\d+_patch_\d{1,2}_\d{1,2}'
    patch_base_names = [re.search(pattern, mask_files[i]).group(0) for i in range(len(mask_files))]
    # The dictionary is integer-indexed to allow the dataset __getitem__ class to sample using an integer idx
    path = orthophoto_dir + '/labeling_subset'
    file_paths = {i:(path+'/images/'+name+'.tif',path+'/final_masks/'+name+'.tif') for i, name in enumerate(patch_base_names)}
    return file_paths


In [4]:
def train_test_split(file_paths:dict, test_size:float=0.2):
    """
    Splits a dictionary of file paths into training and test sets.

    Args:
        file_paths (dict): A dictionary containing the file paths of image and mask files.
                           The keys are integers representing the index, and the values are tuples
                           containing the image file path and the corresponding mask file path.
        test_size (float, optional): The proportion of the dataset to include in the test set.
                                     Default is 0.2 (20% of the dataset).

    Returns:
        tuple: A tuple containing two dictionaries representing the training and test sets.
               Each dictionary is integer-indexed for easy sampling using an integer index.
               The keys are integers representing the index, and the values are tuples
               containing the image file path and the corresponding mask file path.
    """
    from sklearn.model_selection import train_test_split

    train_keys, test_keys = train_test_split(list(file_paths.keys()), test_size=test_size)
    train_dict = {i:file_paths[key] for i,key in enumerate(train_keys)}
    test_dict = {i:file_paths[key] for i,key in enumerate(test_keys)}
    print(f"Length of all files: {len(file_paths)}")
    print(f"Length of train ({len(train_dict)}) and test ({len(test_dict)}): {len(train_dict)+len(test_dict)}")
    return train_dict, test_dict

### Dataloader

In [5]:
class MunichTuningDataset(Dataset):
    def __init__(self, file_paths, transform=None):
        self.file_paths = file_paths
        self.transform = transform

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        image_filepath = self.file_paths[idx][0]
        mask_filepath = self.file_paths[idx][1]

        image = np.array(Image.open(image_filepath))
        label_mask = np.array(Image.open(mask_filepath))
        # Convert RGB mask to label mask
    
        if self.transform is not None:
            transformed = self.transform(image=image, mask=label_mask)
            transformed_image = transformed['image']
            transformed_label_mask = transformed['mask']
            return transformed_image, transformed_label_mask
        else:
            return image, label_mask

In [6]:
# Define transforms to be used in the Training
train_transform = A.Compose(
    [
        A.Resize(height=512, width=512),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.ColorJitter(p=0.25),
        #A.RandomCrop(500, 500),
        A.Normalize(
            mean = [0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        ),
        ToTensorV2(),
    ],
)

test_transform = A.Compose(
    [
        A.Resize(height=512, width=512),
        A.Normalize(
            mean = [0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        ),
        ToTensorV2(),
    ],
)

In [7]:
def get_munich_tuning_loaders(batch_size=2):
    file_paths = get_file_paths()
    train_dict, test_dict = train_test_split(file_paths, test_size=0.2)
    BATCH_SIZE = 2
    train_loader = DataLoader(MunichTuningDataset(train_dict, transform=None), 
                            batch_size = BATCH_SIZE, 
                            num_workers = 2)
    test_loader = DataLoader(MunichTuningDataset(test_dict, transform=test_transform),
                            batch_size = BATCH_SIZE, 
                            num_workers = 2)
    print(f"Length of train loader: {len(train_loader)}; Length of test loader: {len(test_loader)} with batch size {BATCH_SIZE}")

    return train_loader, test_loader


In [8]:
train_loader, test_loader = get_munich_tuning_loaders(batch_size=8)

Indexing files in orthophotos/labeling_subset... 
Found: 	 150 Images 
	 83 Mask
Length of all files: 83
Length of train (66) and test (17): 83
Length of train loader: 33; Length of test loader: 9 with batch size 2


### Test DataLoader Output

In [9]:
RGB_classes = [
       (0, 0, 0), # ignore
       (255, 255, 225), # impervious
       (255,  0, 255), # building
       (255, 200, 0), # low vegetation
       (0,  0,  255), # water
       (0, 130, 0)] # trees
Label_classes = [
       "ignore",
       "impervious",
       "building",
       "low vegetation",
       "water",
       "trees"]

idx2label = {key: value for key, value in enumerate(Label_classes)}

# Create a dictionary to translate a mask to a rgb tensor
idx2rgb = {key: value for key, value in enumerate(RGB_classes)}
rgb2idx = {v: k for k, v in idx2rgb.items()}

# Dict to map from label to rgb
rgb2label = dict(zip(Label_classes, RGB_classes))

In [10]:
import numpy as np
import matplotlib.patches as mpatches

# iterate over the train_loader
for i, (images, masks) in enumerate(train_loader):
    # stop after the first batch
    #if i > 1:
    #    break

    batch_size = images.shape[0]

    for j in range(batch_size):
        # select the j-th image and mask from the batch
        image = images[j].numpy()
        mask = masks[j].numpy()

        # PyTorch dataloaders usually return images in (C, H, W) format,
        # so we need to transpose this to (H, W, C) for matplotlib to display it correctly
        if image.shape[0] == 3:
            image = image.transpose((1, 2, 0))

        # Create an empty RGB mask
        mask_rgb = np.zeros((mask.shape[0], mask.shape[1], 3), dtype=np.uint8)

        # Map each label index to its RGB equivalent
        for idx, rgb in idx2rgb.items():
            mask_rgb[mask == idx] = rgb
        
        # Plotting
        fig, ax = plt.subplots(1, 2, figsize=(12, 6))

        ax[0].imshow(image)  # display the image
        ax[0].set_title(f'Image {j+1}')
        ax[0].axis('off')

        ax[1].imshow(mask_rgb)  # display the mask
        ax[1].set_title(f'Label Mask {j+1}')
        ax[1].axis('off')

        # Create a list of patches for the legend
        patches = [mpatches.Patch(color=np.array(rgb)/255., label=label) for label, rgb in rgb2label.items()]

        # Create legend
        plt.legend(handles=patches, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0., title="Classes")

        plt.show()
