In [132]:
import numpy as np
from patchify import patchify,unpatchify
import PIL
import matplotlib.pyplot as plt
import os
import logging
import pickle as pk

# `import PIL` alone does not load the `PIL.Image` submodule; without this line the
# `PIL.Image` references below and in MapPatch rely on another import (presumably
# matplotlib) having loaded it first. Import it explicitly.
import PIL.Image

# Root logger at DEBUG so the notebook's logging.info(...) progress messages are shown.
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

# Keep third-party libraries from flooding the output with their own low-level records.
logging.getLogger("PIL").setLevel(logging.WARNING)
logging.getLogger("matplotlib").setLevel(logging.WARNING)

# Raise Pillow's decompression-bomb pixel limit so very large map images can be opened.
# NOTE(review): value presumably chosen to exceed the largest input map — confirm.
PIL.Image.MAX_IMAGE_PIXELS = 933120000

from torch.utils.data import Dataset,DataLoader

In [133]:
class MapPatch():
    """A single square tile cut from a larger map image.

    Attributes:
        patch: pixel array for this tile (one entry of the patchified map).
        patch_index: ``(row, col)`` position of the tile in the patch grid.
        origin_map: path of the image file the tile was cut from.
    """

    def __init__(self, patch, patch_index, origin_map):
        self.patch = patch
        self.patch_index = patch_index
        self.origin_map = origin_map

    @staticmethod
    def get_map_patches(file_name, patch_width, verbose = False):
        """Load an image and split it into non-overlapping square patches.

        Args:
            file_name: path of the image to open with Pillow.
            patch_width: side length in pixels of each square patch. It is also used
                as the step, so patches do not overlap.
            verbose: if True, log how many patches were generated and their shape.

        Returns:
            ``(tif_map_np, tif_map_patches)``: the full image as a NumPy array and the
            array returned by ``patchify`` for patch size
            ``(patch_width, patch_width, 3)``.
        """
        # Context manager releases the file handle once the pixels are copied into
        # NumPy; Pillow may otherwise keep it open (e.g. for multi-frame TIFFs).
        with PIL.Image.open(file_name) as tif_map:
            tif_map_np = np.array(tif_map)

        # NOTE(review): the 3-channel patch size assumes the image has at least three
        # channels (e.g. RGB); pixels along the right/bottom edges that do not fill a
        # whole patch are presumably dropped by patchify — confirm this is intended.
        tif_map_patches = patchify(image = tif_map_np,
                                   patch_size = (patch_width, patch_width, 3),
                                   step = patch_width)

        if verbose:
            logging.info(f"{np.prod(tif_map_patches.shape[:2]):,} patches from {file_name} generated with shape {tif_map_patches.shape}")

        return tif_map_np, tif_map_patches

    @staticmethod
    def get_map_patch_list(file_name, patch_width, verbose = False):
        """Patchify ``file_name`` and wrap every patch in a ``MapPatch``.

        Args:
            file_name: path of the image to patchify; stored as each patch's origin.
            patch_width: side length in pixels of each square patch.
            verbose: forwarded to ``get_map_patches``.

        Returns:
            List of ``MapPatch`` objects in row-major grid order.
        """
        _, tif_map_patches = MapPatch.get_map_patches(file_name, patch_width, verbose = verbose)

        # Axes 0/1 index the patch grid; axis 2 indexes windows along the channel axis
        # (a single window for a 3-channel image), hence the fixed 0.
        return [MapPatch(tif_map_patches[i, j, 0], patch_index = (i, j), origin_map = file_name)
                for i, j in np.ndindex(*tif_map_patches.shape[:2])]

In [134]:
class PatchDataset(Dataset):
    """PyTorch ``Dataset`` of map patches loaded from a directory.

    Each item is a ``(patch, patch.origin_map)`` tuple. Patches are either cut from
    ``tif`` images via ``MapPatch`` or restored from ``.pk`` files written by
    ``to_pickle``.
    """

    def __init__(self, directory, file_ext = "tif", patch_width = 32):
        """
        Args:
            directory: folder scanned (non-recursively) for input files.
            file_ext: "tif" to patchify images, or "pk" to load pickled patch lists.
            patch_width: side length of the square patches; only used for "tif".

        Raises:
            ValueError: if ``file_ext`` is neither "tif" nor "pk".
        """
        self.patches = self.load_patches_from_dir(directory, file_ext, patch_width)

    def __len__(self):
        return len(self.patches)

    def __getitem__(self, i):
        """Return ``(patch, origin_map)`` for an index, or a list of them for a slice."""
        if isinstance(i, slice):
            # slice.indices resolves None, negative and out-of-range bounds (and
            # negative steps) exactly like list slicing does.
            return [(self.patches[j], self.patches[j].origin_map)
                    for j in range(*i.indices(len(self.patches)))]

        return (self.patches[i], self.patches[i].origin_map)

    @staticmethod
    def _load_pickled_patches(file_name):
        """Return the object pickled in ``file_name`` (a patch list from ``to_pickle``)."""
        # SECURITY: pickle.load can execute arbitrary code; only load trusted files.
        with open(file_name, "rb") as f:
            return pk.load(f)

    def load_patches_from_dir(self, directory, file_ext, patch_width):
        """Collect patches from every file in ``directory`` whose name ends with ``file_ext``.

        Args:
            directory: folder to scan (non-recursively).
            file_ext: "tif" (patchify each image) or "pk" (load pickled patch lists).
            patch_width: side length of square patches; only used for "tif".

        Returns:
            A flat list of patches from all matching files, in sorted file-name order.

        Raises:
            ValueError: if ``file_ext`` is neither "tif" nor "pk".
        """
        loaders = {
            "tif": lambda path: MapPatch.get_map_patch_list(file_name = path,
                                                            patch_width = patch_width,
                                                            verbose = True),
            "pk": self._load_pickled_patches,
        }
        if file_ext not in loaders:
            raise ValueError(f"{file_ext} is an invalid file format. Require tif or pk.")
        load = loaders[file_ext]

        patches = []
        # os.listdir order is arbitrary; sort so the dataset order is reproducible.
        for file in sorted(os.listdir(directory)):
            if file.endswith(file_ext):
                file_name = os.path.join(directory, file)
                logging.info(f"Fetching patches from {file_name}.")
                patches.extend(load(file_name))

        return patches

    def to_pickle(self, file_name = None):
        """Pickle the patch list to ``f"{file_name}.pk"``.

        Raises:
            ValueError: if ``file_name`` is None (previously this silently wrote ``None.pk``).
        """
        if file_name is None:
            raise ValueError("file_name is required; '.pk' is appended to it.")
        with open(f"{file_name}.pk", "wb") as f:
            pk.dump(self.patches, f)