In [2]:
from torch.utils.data import Dataset, DataLoader, random_split
import torch
from PIL import Image
import os
import numpy as np
from os.path import splitext
import matplotlib.pyplot as plt
from multiprocessing import Pool
from functools import partial
from tqdm.auto import tqdm
import warnings
warnings.filterwarnings('ignore')

In [3]:
dataset_path= '../data/tiles/formatted'

In [4]:
def load_image(filename):
    ext = splitext(filename)[1]
    if ext == '.npy':
        return Image.fromarray(np.load(filename))
    elif ext in ['.pt', '.pth']:
        return Image.fromarray(torch.load(filename).numpy())
    else:
        return Image.open(filename)

In [21]:
class ImageDataset(Dataset):
    def __init__(self, dataset_path) -> None:
        files = os.listdir(dataset_path)
        self.images_path = []
        self.masks_path = []
        for file in files:
            if file.split('_')[1]=='tif':
                self.images_path.append(os.path.join(dataset_path,file))
                self.masks_path.append(os.path.join(dataset_path,file.replace('tif','shp')))
                
    def __len__(self):
        return len(self.images_path)
    
    @staticmethod
    def preprocess(mask_values, pil_img, scale, is_mask):
        w, h = pil_img.size
        newW, newH = int(scale * w), int(scale * h)
        assert newW > 0 and newH > 0, 'Scale is too small, resized images would have no pixel'
        pil_img = pil_img.resize((newW, newH), resample=Image.NEAREST if is_mask else Image.BICUBIC)
        img = np.asarray(pil_img)
        if img.ndim == 2:
            img = img[np.newaxis, ...]
        else:
            img = img.transpose((2, 0, 1))
        if (img > 1).any():
            img = img / 255.0
        return img
        
    def __getitem__(self, idx):
        #return {
        #    'image': self.images_path[idx],
        #    'mask': self.masks_path[idx]
        #}
        mask = load_image(self.images_path[idx])
        img = load_image(self.masks_path[idx])
        
        assert img.size == mask.size, \
            f'Image and mask should be the same size, but are {img.size} and {mask.size}'

        self.scale = 1.0
        #self.scale = 0.2
        
        img = self.preprocess(1, img, self.scale, is_mask=False)
        mask = self.preprocess(1, mask, self.scale, is_mask=True)
        
        return {
            'image': torch.as_tensor(img.copy()).float().contiguous(),
            'mask': torch.as_tensor(mask.copy()).long().contiguous()
        }
    
class ImageNameDataset(Dataset):
    def __init__(self, dataset_path) -> None:
        files = os.listdir(dataset_path)
        self.images_path = []
        self.masks_path = []
        for file in files:
            if file.split('_')[1]=='tif':
                self.images_path.append(os.path.join(dataset_path,file))
                self.masks_path.append(os.path.join(dataset_path,file.replace('tif','shp')))
                
    def __len__(self):
        return len(self.images_path)
    
    def __getitem__(self, idx):
        return {
            'image': self.images_path[idx],
            'mask': self.masks_path[idx]
        }

In [22]:
dataset = ImageNameDataset(dataset_path)
data_loader = DataLoader(dataset, batch_size=1)

for batch in data_loader:
    print(batch)

{'image': ['../data/tiles/formatted/tile_tif_0_2.png'], 'mask': ['../data/tiles/formatted/tile_shp_0_2.png']}
{'image': ['../data/tiles/formatted/tile_tif_2_0.png'], 'mask': ['../data/tiles/formatted/tile_shp_2_0.png']}
{'image': ['../data/tiles/formatted/tile_tif_2_1.png'], 'mask': ['../data/tiles/formatted/tile_shp_2_1.png']}
{'image': ['../data/tiles/formatted/tile_tif_0_3.png'], 'mask': ['../data/tiles/formatted/tile_shp_0_3.png']}
{'image': ['../data/tiles/formatted/tile_tif_0_1.png'], 'mask': ['../data/tiles/formatted/tile_shp_0_1.png']}
{'image': ['../data/tiles/formatted/tile_tif_2_3.png'], 'mask': ['../data/tiles/formatted/tile_shp_2_3.png']}
{'image': ['../data/tiles/formatted/tile_tif_2_2.png'], 'mask': ['../data/tiles/formatted/tile_shp_2_2.png']}
{'image': ['../data/tiles/formatted/tile_tif_0_0.png'], 'mask': ['../data/tiles/formatted/tile_shp_0_0.png']}
{'image': ['../data/tiles/formatted/tile_tif_0_4.png'], 'mask': ['../data/tiles/formatted/tile_shp_0_4.png']}
{'image': 

In [23]:
val_percent = 0.2

n_val = int(len(dataset) * val_percent)
n_train = len(dataset) - n_val
train_set, val_set = random_split(dataset, [n_train, n_val], generator=torch.Generator().manual_seed(0))

print(len(train_set),len(val_set))


12 3


In [51]:
batch_size = 3
print()
#num_workers = os.cpu_count() - not working . . .

loader_args = dict(batch_size=batch_size, pin_memory=True)#, num_workers=os.cpu_count(), pin_memory=True)
train_loader = DataLoader(train_set, shuffle=True, **loader_args)
val_loader = DataLoader(val_set, shuffle=False, **loader_args)




In [49]:
print('train_set:')
for batch in train_loader:
    print(batch)
    
print('val_set:')
for batch in val_loader:
    print(batch)

train_set:
{'image': ['../data/tiles/formatted/tile_tif_0_2.png', '../data/tiles/formatted/tile_tif_2_3.png', '../data/tiles/formatted/tile_tif_1_3.png'], 'mask': ['../data/tiles/formatted/tile_shp_0_2.png', '../data/tiles/formatted/tile_shp_2_3.png', '../data/tiles/formatted/tile_shp_1_3.png']}
{'image': ['../data/tiles/formatted/tile_tif_0_1.png', '../data/tiles/formatted/tile_tif_1_4.png', '../data/tiles/formatted/tile_tif_1_2.png'], 'mask': ['../data/tiles/formatted/tile_shp_0_1.png', '../data/tiles/formatted/tile_shp_1_4.png', '../data/tiles/formatted/tile_shp_1_2.png']}
{'image': ['../data/tiles/formatted/tile_tif_2_1.png', '../data/tiles/formatted/tile_tif_2_2.png', '../data/tiles/formatted/tile_tif_0_3.png'], 'mask': ['../data/tiles/formatted/tile_shp_2_1.png', '../data/tiles/formatted/tile_shp_2_2.png', '../data/tiles/formatted/tile_shp_0_3.png']}
{'image': ['../data/tiles/formatted/tile_tif_0_4.png', '../data/tiles/formatted/tile_tif_2_0.png', '../data/tiles/formatted/tile_ti