In [2]:
from torch.utils.data import Dataset, DataLoader
import torch
from PIL import Image
import os
import numpy as np
from os.path import splitext
import matplotlib.pyplot as plt
from multiprocessing import Pool
from functools import partial
from tqdm.auto import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
dataset_path= '../data/tiles/formatted'

In [4]:
def load_image(filename):
    ext = splitext(filename)[1]
    if ext == '.npy':
        return Image.fromarray(np.load(filename))
    elif ext in ['.pt', '.pth']:
        return Image.fromarray(torch.load(filename).numpy())
    else:
        return Image.open(filename)

In [5]:
class ImageDataset(Dataset):
    def __init__(self, dataset_path) -> None:
        files = os.listdir(dataset_path)
        self.images_path = []
        self.masks_path = []
        for file in files:
            if file.split('_')[1]=='tif':
                self.images_path.append(os.path.join(dataset_path,file))
                self.masks_path.append(os.path.join(dataset_path,file.replace('tif','shp')))
                
    def __len__(self):
        return len(self.images_path)
    
    @staticmethod
    def preprocess(mask_values, pil_img, scale, is_mask):
        w, h = pil_img.size
        newW, newH = int(scale * w), int(scale * h)
        assert newW > 0 and newH > 0, 'Scale is too small, resized images would have no pixel'
        pil_img = pil_img.resize((newW, newH), resample=Image.NEAREST if is_mask else Image.BICUBIC)
        img = np.asarray(pil_img)
        if img.ndim == 2:
            img = img[np.newaxis, ...]
        else:
            img = img.transpose((2, 0, 1))
        if (img > 1).any():
            img = img / 255.0
        return img
        
    def __getitem__(self, idx):
        #return {
        #    'image': self.images_path[idx],
        #    'mask': self.masks_path[idx]
        #}
        mask = load_image(self.images_path[idx])
        img = load_image(self.masks_path[idx])
        
        assert img.size == mask.size, \
            f'Image and mask should be the same size, but are {img.size} and {mask.size}'

        self.scale = 1.0
        
        img = self.preprocess(1, img, self.scale, is_mask=False)
        mask = self.preprocess(1, mask, self.scale, is_mask=True)
        
        return {
            'image': torch.as_tensor(img.copy()).float().contiguous(),
            'mask': torch.as_tensor(mask.copy()).long().contiguous()
        }

In [7]:
dataset = ImageDataset(dataset_path)
data_loader = DataLoader(dataset, batch_size=1)

for batch in data_loader:
    print(batch)

{'image': tensor([[[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]],

         [[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]],

         [[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]]]), 'mask': tensor([[[[  0,   0,   0,  ...,  24,  25,  24],
          [  0,   0,   0,  ...,  26,  26,  25],
          [  0,   0,   0,  ...,  28,  27,  25],
          ...,
          [131