In [21]:
import torch
import cv2
import numpy as np
import glob
import math
import os

import torch
import torch.utils.data as torch_data
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F

In [31]:
class GTiffDataset(torch_data.Dataset):
    def __init__(self, root_dir, tile_size = 256, stride = 256, debug = False, transform=None):
        """
        Args:
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.tile_size = tile_size
        self.stride = stride
        self.root_dir = root_dir
        self.transform = transform
        self.debug = debug
        self.images, self.masks = self.read_dir()
        
    def get_tiles(self, image, mask):
        i_tiles, m_tiles = [], []
        width = image.shape[1] - image.shape[1]%self.stride
        height = image.shape[0] - image.shape[0]%self.stride
        print(height, width, image.shape)
                
        for i in range(height, self.stride):
            for j in range(width, self.stride):
                if i+self.tile_size > image.shape[0]:
                    break
                img_tile = image[
                    i:i+self.tile_size, 
                    j:j+self.tile_size
                ]
                mask_tile = image[
                    i:i+self.tile_size, 
                    j:j+self.tile_size
                ]
                i_tiles.append(img_tile)
                m_tiles.append(mask_tile)
                
                if self.debug:
                    # Debugging the tiles
                    cv2.imwrite("debug_" + str(i) + "_" + str(j) + "_img.png", img_tile)
                    cv2.imwrite("debug_" + str(i) + "_" + str(j) + "_mask.png", mask_tile)
        return i_tiles, m_tiles

    def read_dir(self):
        tiles = [[], []]
        images = sorted(glob.glob(self.root_dir + '/' + '*_4326_cropped.png'))
        masks = sorted(glob.glob(self.root_dir + '/' + '*_mask_4326.tif'))
        for i, m in zip(images, masks):
            image = cv2.imread(i)
            mask = cv2.imread(m, 0)
            i_tiles, m_tiles = self.get_tiles(image, mask)
            for im, ma in zip(i_tiles, m_tiles):
                tiles[0].append(im)
                tiles[1].append(ma)
        
        return tiles
    
    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        sample = [self.images[idx], self.masks[idx]]
        if self.transform:
            sample = self.transform(sample)
        return sample

In [32]:
gtiffdataset = GTiffDataset('../data/pre-processed/dryvalleys/QB02', stride=128)

128 1920 (200, 1945, 3)
26240 15360 (26296, 15402, 3)
640 8192 (750, 8235, 3)
512 768 (622, 808, 3)
2176 15360 (2300, 15414, 3)
2432 15360 (2441, 15449, 3)
