In [15]:
import numpy as np
from patchify import patchify,unpatchify
import PIL
import matplotlib.pyplot as plt
import os
import logging
import pickle as pk

logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

logging.getLogger("PIL").setLevel(logging.WARNING)
logging.getLogger("matplotlib").setLevel(logging.WARNING)

PIL.Image.MAX_IMAGE_PIXELS = 933120000

import torch
from torch.utils.data import Dataset
import torchvision.transforms as T
from torchvision.transforms import InterpolationMode
from skimage.measure import block_reduce

In [11]:
class MapPatch():
    def __init__(self, patch, patch_index, origin_map):
        self.patch = patch
        self.patch_index = patch_index
        self.origin_map = origin_map
        
    @staticmethod
    def get_map_patches(file_name, patch_width, map_transformer = None, verbose = True):
        tif_map = PIL.Image.open(file_name)
        tif_map_np = np.array(tif_map)
        
        if map_transformer is not None and verbose:
            logging.info(f"Applying transformation {map_transformer.__name__} to {file_name}")
            tif_map_np = map_transformer(tif_map_np)
        
        tif_map_patches = patchify(image = tif_map_np, 
                                   patch_size = (patch_width, patch_width, 3),
                                   step = patch_width)

        if verbose:
            logging.info(f"{np.prod(tif_map_patches.shape[:2]):,} patches from {file_name} generated with shape {tif_map_patches.shape}")

        return tif_map_np, tif_map_patches
    
    @staticmethod
    def get_map_patch_list(file_name, patch_width, map_transformer = None, verbose = True):
        _, tif_map_patches = MapPatch.get_map_patches(file_name, 
                                                      patch_width, 
                                                      map_transformer = map_transformer, 
                                                      verbose = verbose)
        patches = []
        
        for i in range(tif_map_patches.shape[0]):
            for j in range(tif_map_patches.shape[1]):
                patches.append(MapPatch(tif_map_patches[i,j,0], patch_index = (i,j), origin_map = file_name))
                
        return patches
    
    def show(self, verbose = True):
        fig, ax = plt.subplots()
        ax.imshow(self.patch)
        
        if verbose:
            ax.set_title(f"Patch at {self.patch_index} from {self.origin_map}.")
            
        plt.show()

In [12]:
class PatchDataset(Dataset):
    def __init__(self, patches):
        self.patches = patches
    
    def __len__(self):
        return len(self.patches)
    
    def __getitem__(self, i):
        if isinstance(i, slice):
            start = i.start if i.start else 0
            stop = i.stop if i.stop else len(self.patches)
            step = i.step if i.step else 1
            
            return [(self.patches[j], self.patches[j].origin_map) for j in range(start, stop, step)]
        
        return (self.patches[i], self.patches[i].origin_map)
    
    @classmethod
    def from_dir(cls, directory, file_ext, patch_width, map_transformer = None):
        patches = []
        
        if file_ext == "tif":
            for file in os.listdir(directory):
                if file.endswith("tif"):
                    file_name = f"{directory}/{file}"
                    logging.info(f"Fetching patches from {file_name}.")
                    patches.extend(MapPatch.get_map_patch_list(file_name = file_name, 
                                                               patch_width = patch_width, 
                                                               map_transformer = map_transformer,
                                                               verbose = True))
        elif file_ext == "pk":
            for file in os.listdir(directory):
                if file.endswith("pk"):
                    file_name = f"{directory}/{file}"
                    logging.info(f"Fetching patches from {file_name}.")
                    with open(file_name, "rb") as f:
                        patches.extend(pk.load(file_name))
        else:
            print(f"{file_ext} is an invalid file format. Require tif or pk.")
            
        return cls(patches)
    
    def to_pickle(self, file_name = None):
        with open(f"{file_name}.pk", "wb") as f:
            pk.dump(self.patches, f)

In [13]:
# transformations to apply to the map

def max_pooler(img, kernel_size):
    return block_reduce(img, block_size = (kernel_size, kernel_size,1), func = np.max)

def min_pooler(img, kernel_size):
    return block_reduce(img, block_size = (kernel_size, kernel_size,1), func = np.min)

def med_reduce(x, axis):
    return np.median(x,axis).astype(np.int32)

def med_pooler(img, kernel_size):
    return block_reduce(img, block_size = (kernel_size, kernel_size,1), func = med_reduce)

def mean_reduce(x, axis):
    return np.mean(x,axis).astype(np.int32)

def mean_pooler(img, kernel_size):
    return block_reduce(img, block_size = (kernel_size, kernel_size,1), func = mean_reduce)

def torch_downsample(img, kernel_size, interpolation = InterpolationMode.BILINEAR):
    size = img.shape
    
    new_size = (size[0]//kernel_size, size[1]//kernel_size)
    
    tensor_img = np.moveaxis(img, -1, 0)
    tensor_img = torch.Tensor(tensor_img)
    
    resized_map = T.Resize(new_size, interpolation=interpolation)(tensor_img)
    
    resized_map = resized_map.numpy()
    resized_map = np.moveaxis(resized_map, 0, -1)
    
    return resized_map.astype(int)

def bilinear_interpolator(img, kernel_size):
    return torch_downsample(img, kernel_size, interpolation = InterpolationMode.BILINEAR)

def bicubic_interpolator(img, kernel_size):
    return torch_downsample(img, kernel_size, interpolation = InterpolationMode.BICUBIC)

In [16]:
def bilinear_interpolator_4x4(img):
    return bilinear_interpolator(img, 4)

pd = PatchDataset.from_dir("data", 
                           file_ext = "tif", 
                           patch_width = 64, 
                           map_transformer= bilinear_interpolator_4x4)
                           

INFO:root:Fetching patches from data/map_3.tif.
INFO:root:Applying transformation bilinear_interpolator_4x4 to data/map_3.tif
INFO:root:2,301 patches from data/map_3.tif generated with shape (39, 59, 1, 64, 64, 3)
INFO:root:Fetching patches from data/map_2.tif.
INFO:root:Applying transformation bilinear_interpolator_4x4 to data/map_2.tif
INFO:root:2,301 patches from data/map_2.tif generated with shape (39, 59, 1, 64, 64, 3)
INFO:root:Fetching patches from data/map_1.tif.
INFO:root:Applying transformation bilinear_interpolator_4x4 to data/map_1.tif
INFO:root:2,301 patches from data/map_1.tif generated with shape (39, 59, 1, 64, 64, 3)
INFO:root:Fetching patches from data/map_4.tif.
INFO:root:Applying transformation bilinear_interpolator_4x4 to data/map_4.tif
INFO:root:2,301 patches from data/map_4.tif generated with shape (39, 59, 1, 64, 64, 3)


for i in range(4):
    patch = pd[2301*i][0]
    img = PIL.Image.fromarray(np.uint8(patch.patch))
    img.save(f"data/patch_samples/patch_{patch.origin_map[5:-4]}.png")