In [65]:
import PIL
import os
import logging
import pickle as pk
from collections import defaultdict

import ssl

# NOTE(review): this swaps the default HTTPS context factory for the *unverified* one,
# i.e. TLS certificate verification is turned off for every stdlib HTTPS request made
# by this process. Presumably added so the pretrained weights can be downloaded behind
# a broken certificate chain -- confirm, and prefer fixing the CA bundle instead.
ssl._create_default_https_context = ssl._create_unverified_context

import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from patchify import patchify,unpatchify

# Root logger at DEBUG so the INFO progress messages emitted while patching are shown.
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

# Keep the third-party loggers quiet; only their warnings and errors get through.
logging.getLogger("PIL").setLevel(logging.WARNING)
logging.getLogger("matplotlib").setLevel(logging.WARNING)

# Raise Pillow's pixel-count limit (its decompression-bomb guard) so large maps can be opened.
# NOTE(review): only `import PIL` appears above; `PIL.Image` is presumably loaded as a side
# effect of the matplotlib import -- confirm, or import PIL.Image explicitly.
PIL.Image.MAX_IMAGE_PIXELS = 933120000

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as T
from torchvision.transforms import InterpolationMode
from skimage.measure import block_reduce
from torch.utils.data import Dataset,DataLoader

from torchvision.models.resnet import resnet50, ResNet50_Weights

In [2]:
class MapPatch():
    """
    A square patch cut out of a map image, together with where it came from.
    ----------------------------------------------------------------------------------------
    :param patch: the pixel data of the patch (an array, as produced by patchify)
    :param patch_index: the (row, column) position of the patch in the patch grid
    :param origin_map: the file name of the map the patch was cut from
    ----------------------------------------------------------------------------------------
    Instances define __eq__ without __hash__, so they cannot be used as dict keys or set members.
    """
    def __init__(self, patch, patch_index, origin_map):
        self.patch = patch
        self.patch_index = patch_index
        self.origin_map = origin_map

    @staticmethod
    def get_map_patches(file_name, patch_width, map_transformer = None, verbose = True):
        """
        Load an image, optionally transform it, and split it into non-overlapping square patches.
        ----------------------------------------------------------------------------------------
        :param file_name: path of the image to open with PIL
        :param patch_width: side length (in pixels) of each patch; also used as the patch stride
        :param map_transformer: optional callable applied to the image array before patching
        :param verbose: if True, log the transformation and the number of generated patches
        ----------------------------------------------------------------------------------------
        :return a tuple (image array after transformation, patch array returned by patchify).
                The patch size is (patch_width, patch_width, 3), so the image handed to patchify
                is expected to have 3 channels in its last axis.
        """
        # np.array forces the pixel data to be read, so the file can be closed straight away.
        with PIL.Image.open(file_name) as tif_map:
            tif_map_np = np.array(tif_map)

        # The transformation must run whether or not we log it; `verbose` only controls the message.
        if map_transformer is not None:
            if verbose:
                # functools.partial objects and other callables may not carry a __name__.
                transformer_name = getattr(map_transformer, "__name__", repr(map_transformer))
                logging.info(f"Applying transformation {transformer_name} to {file_name}")
            tif_map_np = map_transformer(tif_map_np)

        tif_map_patches = patchify(image = tif_map_np,
                                   patch_size = (patch_width, patch_width, 3),
                                   step = patch_width)

        if verbose:
            logging.info(f"{np.prod(tif_map_patches.shape[:2]):,} patches from {file_name} generated with shape {tif_map_patches.shape}")

        return tif_map_np, tif_map_patches

    @staticmethod
    def get_map_patch_list(file_name, patch_width, map_transformer = None, verbose = True):
        """
        Same as get_map_patches, but returns a flat list of MapPatch objects in row-major order.
        ----------------------------------------------------------------------------------------
        :return a list of MapPatch, one per (row, column) cell of the patch grid
        """
        _, tif_map_patches = MapPatch.get_map_patches(file_name,
                                                      patch_width,
                                                      map_transformer = map_transformer,
                                                      verbose = verbose)
        n_rows, n_cols = tif_map_patches.shape[0], tif_map_patches.shape[1]

        # Index 0 on the third axis drops the singleton channel-block dimension added by patchify.
        return [MapPatch(tif_map_patches[i,j,0], patch_index = (i,j), origin_map = file_name)
                for i in range(n_rows)
                for j in range(n_cols)]

    def show(self, verbose = True):
        """
        Display the patch with matplotlib; if verbose, the title states its index and origin map.
        """
        fig, ax = plt.subplots()
        ax.imshow(self.patch)

        if verbose:
            ax.set_title(f"Patch at {self.patch_index} from {self.origin_map}.")

        plt.show()

    def __eq__(self, other):
        """
        Two patches are equal when they share index and origin map and their pixels are (numerically) close.
        """
        # Let Python fall back to its default comparison for foreign types instead of raising AttributeError.
        if not isinstance(other, MapPatch):
            return NotImplemented

        if self.patch_index != other.patch_index or self.origin_map != other.origin_map:
            return False

        own_pixels, other_pixels = np.asarray(self.patch), np.asarray(other.patch)

        # Differently shaped patches can never be equal; checking first avoids broadcasting errors.
        if own_pixels.shape != other_pixels.shape:
            return False

        return bool(np.all(np.isclose(own_pixels, other_pixels)))

In [3]:
class PatchDataset(Dataset):
    """
    Dataset of patches in which every item is the pair (patch, patch.origin_map).
    ----------------------------------------------------------------------------------------
    :param patches: a list of patch objects exposing an `origin_map` attribute
    """
    def __init__(self, patches):
        self.patches = patches

    def __len__(self):
        return len(self.patches)

    def __getitem__(self, i):
        """
        Return (patch, origin_map) for an integer index, or a list of such pairs for a slice.
        """
        if isinstance(i, slice):
            # slice.indices resolves None, negative bounds, negative steps and out-of-range
            # bounds exactly like list slicing does (the manual defaults mishandled all of them).
            return [(self.patches[j], self.patches[j].origin_map)
                    for j in range(*i.indices(len(self.patches)))]

        return (self.patches[i], self.patches[i].origin_map)

    @classmethod
    def from_dir(cls, directory, file_ext, patch_width, map_transformer = None):
        """
        Build a dataset from every file in `directory` that has the requested extension.
        ----------------------------------------------------------------------------------------
        :param directory: directory to scan (not recursive)
        :param file_ext: "tif" to cut patches out of images, "pk" to load pickled patch lists
        :param patch_width: side length of the patches (only used for "tif")
        :param map_transformer: optional callable applied to each map before patching (only "tif")
        ----------------------------------------------------------------------------------------
        :return a dataset with the patches of all matching files; empty if file_ext is invalid
        """
        patches = []

        if file_ext not in ("tif", "pk"):
            print(f"{file_ext} is an invalid file format. Require tif or pk.")
            return cls(patches)

        # os.listdir returns entries in arbitrary order; sort so the dataset order is reproducible.
        for file in sorted(os.listdir(directory)):
            if not file.endswith(file_ext):
                continue

            file_name = f"{directory}/{file}"
            logging.info(f"Fetching patches from {file_name}.")

            if file_ext == "tif":
                patches.extend(MapPatch.get_map_patch_list(file_name = file_name,
                                                           patch_width = patch_width,
                                                           map_transformer = map_transformer,
                                                           verbose = True))
            else:
                # NOTE(review): unpickling can execute arbitrary code; only load .pk files you trust.
                with open(file_name, "rb") as f:
                    # pickle needs the open file object, not the path string.
                    patches.extend(pk.load(f))

        return cls(patches)

    def to_pickle(self, file_name = None):
        """
        Pickle the list of patches to "<file_name>.pk".
        ----------------------------------------------------------------------------------------
        :param file_name: output path without the ".pk" extension
        :raises ValueError: if file_name is None (this would otherwise silently write "None.pk")
        """
        if file_name is None:
            raise ValueError("file_name must be provided.")

        with open(f"{file_name}.pk", "wb") as f:
            pk.dump(self.patches, f)

In [4]:
class CLPatchDataset(Dataset):
    """
    Dataset of patches for contrastive learning: every item is the pair
    (patch, list of the other patches that share its patch_index).
    ----------------------------------------------------------------------------------------
    :param patches: a list of patch objects exposing `patch_index` and supporting `!=`
    """
    def __init__(self, patches):
        self.patches = patches
        self.patch_dict = self.__get_patch_dict(self.patches)

    def __len__(self):
        return len(self.patches)

    def __getitem__(self, i):
        """
        Return (patch, matching patches) for an integer index, or a list of such pairs for a slice.
        """
        if isinstance(i, slice):
            # slice.indices resolves None, negative bounds, negative steps and out-of-range
            # bounds exactly like list slicing does (the manual defaults mishandled all of them).
            return [(self.patches[j], self.__get_matching_patches(self.patches[j]))
                    for j in range(*i.indices(len(self.patches)))]

        return (self.patches[i], self.__get_matching_patches(self.patches[i]))

    @classmethod
    def from_dir(cls, directory, file_ext, patch_width, map_transformer = None):
        """
        Build a dataset from every file in `directory` that has the requested extension.
        ----------------------------------------------------------------------------------------
        :param directory: directory to scan (not recursive)
        :param file_ext: "tif" to cut patches out of images, "pk" to load pickled patch lists
        :param patch_width: side length of the patches (only used for "tif")
        :param map_transformer: optional callable applied to each map before patching (only "tif")
        ----------------------------------------------------------------------------------------
        :return a dataset with the patches of all matching files; empty if file_ext is invalid
        """
        patches = []

        if file_ext not in ("tif", "pk"):
            print(f"{file_ext} is an invalid file format. Require tif or pk.")
            return cls(patches)

        # os.listdir returns entries in arbitrary order; sort so the dataset order is reproducible.
        for file in sorted(os.listdir(directory)):
            if not file.endswith(file_ext):
                continue

            file_name = f"{directory}/{file}"
            logging.info(f"Fetching patches from {file_name}.")

            if file_ext == "tif":
                patches.extend(MapPatch.get_map_patch_list(file_name = file_name,
                                                           patch_width = patch_width,
                                                           map_transformer = map_transformer,
                                                           verbose = True))
            else:
                # NOTE(review): unpickling can execute arbitrary code; only load .pk files you trust.
                with open(file_name, "rb") as f:
                    # pickle needs the open file object, not the path string.
                    patches.extend(pk.load(f))

        return cls(patches)

    def __get_patch_dict(self, patches):
        """
        Group the patches by patch_index: {patch_index: [patches with that index, in input order]}.
        """
        patch_dict = {}

        for patch in patches:
            patch_dict.setdefault(patch.patch_index, []).append(patch)

        return patch_dict

    def __get_matching_patches(self, patch):
        """
        Return the patches that share `patch`'s index but do not compare equal to it.
        """
        return [match_patch for match_patch in self.patch_dict[patch.patch_index] if match_patch != patch]

    def to_pickle(self, file_name = None):
        """
        Pickle the list of patches to "<file_name>.pk".
        ----------------------------------------------------------------------------------------
        :param file_name: output path without the ".pk" extension
        :raises ValueError: if file_name is None (this would otherwise silently write "None.pk")
        """
        if file_name is None:
            raise ValueError("file_name must be provided.")

        with open(f"{file_name}.pk", "wb") as f:
            pk.dump(self.patches, f)

In [5]:
# transformations to apply to the map

def max_pooler(img, kernel_size):
    """
    Reduce img with np.max over blocks of kernel_size x kernel_size along the first two axes
    (block size 1 along the third axis).
    """
    window = (kernel_size, kernel_size, 1)
    return block_reduce(img, block_size = window, func = np.max)

def min_pooler(img, kernel_size):
    """
    Reduce img with np.min over blocks of kernel_size x kernel_size along the first two axes
    (block size 1 along the third axis).
    ----------------------------------------------------------------------------------------
    :param img: a 3-dimensional array (the block size has three entries)
    :param kernel_size: side length of the pooling window
    ----------------------------------------------------------------------------------------
    :return the min-pooled array
    """
    # block_reduce pads the image with `cval` (default 0) when its shape is not a multiple of
    # the block size. A 0 pad would win every minimum along the padded edges, so pad with the
    # largest value in the image instead: it can never lower a block's minimum.
    return block_reduce(img, block_size = (kernel_size, kernel_size,1), func = np.min, cval = np.max(img))

def med_reduce(x, axis):
    """
    Median of x along `axis`, cast to int32.
    """
    medians = np.median(x, axis)
    return medians.astype(np.int32)

def med_pooler(img, kernel_size):
    """
    Reduce img with med_reduce over blocks of kernel_size x kernel_size along the first two axes
    (block size 1 along the third axis).
    """
    window = (kernel_size, kernel_size, 1)
    return block_reduce(img, block_size = window, func = med_reduce)

def mean_reduce(x, axis):
    """
    Mean of x along `axis`, cast to int32.
    """
    means = np.mean(x, axis)
    return means.astype(np.int32)

def mean_pooler(img, kernel_size):
    """
    Reduce img with mean_reduce over blocks of kernel_size x kernel_size along the first two axes
    (block size 1 along the third axis).
    """
    window = (kernel_size, kernel_size, 1)
    return block_reduce(img, block_size = window, func = mean_reduce)

def torch_downsample(img, kernel_size, interpolation = InterpolationMode.BILINEAR):
    """
    Shrink img by a factor of kernel_size along its first two axes using torchvision's Resize.
    ----------------------------------------------------------------------------------------
    :param img: an array whose last axis is moved to the front before resizing
    :param kernel_size: integer factor by which the first two axes are divided (floor division)
    :param interpolation: the InterpolationMode handed to T.Resize
    ----------------------------------------------------------------------------------------
    :return the resized array, with the moved axis back in last position, cast to int
    """
    height, width = img.shape[0], img.shape[1]
    target_size = (height//kernel_size, width//kernel_size)

    # last axis to the front, then wrap in a Tensor for Resize
    channels_first = torch.Tensor(np.moveaxis(img, -1, 0))

    resizer = T.Resize(target_size, interpolation=interpolation)
    resized = resizer(channels_first).numpy()

    # restore the original axis order
    return np.moveaxis(resized, 0, -1).astype(int)

def bilinear_interpolator(img, kernel_size):
    """
    Downsample img by kernel_size via torch_downsample with bilinear interpolation.
    """
    mode = InterpolationMode.BILINEAR
    return torch_downsample(img, kernel_size, interpolation = mode)

def bicubic_interpolator(img, kernel_size):
    """
    Downsample img by kernel_size via torch_downsample with bicubic interpolation.
    """
    mode = InterpolationMode.BICUBIC
    return torch_downsample(img, kernel_size, interpolation = mode)

In [6]:
def bilinear_interpolator_4x4(img):
    """
    Single-argument wrapper: bilinear_interpolator with kernel_size fixed to 4.
    """
    return bilinear_interpolator(img, kernel_size = 4)

# Cut every .tif map under "data" into 64-pixel patches, after a 4x bilinear downsample.
patch_dataset = CLPatchDataset.from_dir(
    "data",
    file_ext = "tif",
    patch_width = 64,
    map_transformer = bilinear_interpolator_4x4,
)
                           

INFO:root:Fetching patches from data/map_3.tif.
INFO:root:Applying transformation bilinear_interpolator_4x4 to data/map_3.tif
INFO:root:2,301 patches from data/map_3.tif generated with shape (39, 59, 1, 64, 64, 3)
INFO:root:Fetching patches from data/map_2.tif.
INFO:root:Applying transformation bilinear_interpolator_4x4 to data/map_2.tif
INFO:root:2,301 patches from data/map_2.tif generated with shape (39, 59, 1, 64, 64, 3)
INFO:root:Fetching patches from data/map_1.tif.
INFO:root:Applying transformation bilinear_interpolator_4x4 to data/map_1.tif
INFO:root:2,301 patches from data/map_1.tif generated with shape (39, 59, 1, 64, 64, 3)
INFO:root:Fetching patches from data/map_4.tif.
INFO:root:Applying transformation bilinear_interpolator_4x4 to data/map_4.tif
INFO:root:2,301 patches from data/map_4.tif generated with shape (39, 59, 1, 64, 64, 3)


In [54]:
class MapCLNN(nn.Module):
    """
    Contrastive-learning network: a pretrained ResNet-50 followed by a projection head
    (Linear -> BatchNorm1d -> ReLU -> Linear), trained with the NT-XENT loss.
    """
    def __init__(self):
        super().__init__()

        # model hyperparameters
        self.MAX_PIXEL_VALUE = 255
        self.RESNET_DIM = 224
        self.RESNET_OUTPUT_DIM = 1000
        self.HIDDEN_DIM = 500
        self.OUTPUT_DIM = 100

        # resnet for encoding input images
        self.resnet = resnet50(weights=ResNet50_Weights.DEFAULT)

        # network layers
        self.lin_hidden = nn.Linear(in_features=self.RESNET_OUTPUT_DIM, out_features=self.HIDDEN_DIM, bias = False)
        self.batch_norm = nn.BatchNorm1d(num_features = self.HIDDEN_DIM)
        self.relu = nn.ReLU()
        self.lin_output = nn.Linear(in_features = self.HIDDEN_DIM, out_features = self.OUTPUT_DIM, bias = False)

        # define optimiser (created by compile_optimiser)
        self.optimiser = None

    def img_to_resnet(self, img, dim = None):
        """
        Convert image into the desired format for ResNet.
        The image must have width and height of at least self.RESNET_DIM, with RGB values between 0 and 1.
        Moreover, it must be normalised, by using a mean of [0.485, 0.456, 0.406] and a standard deviation of [0.229, 0.224, 0.225]
        --------------------------------------------------------------------------------------------------------------------------------
        :param img: a numpy nd.array or a Tensor, with 3 colour channels stored in the last dimension and values
                    between 0 and self.MAX_PIXEL_VALUE. Either a single image (3 dimensions) or a batch (4 dimensions).
        :param dim: the desired dimension of the image (if we want to resize img before feeding it to ResNet).
                    This should be at least self.RESNET_DIM.
        --------------------------------------------------------------------------------------------------------------------------------
        :return a float32 Tensor, with the colour channels moved in front of the spatial dimensions, normalised to be used by ResNet.
        """

        # accept both arrays and tensors (a DataLoader collates batches into tensors)
        if torch.is_tensor(img):
            norm_img = img
        else:
            # from_numpy rejects some strided views, so make the array contiguous first
            norm_img = torch.from_numpy(np.ascontiguousarray(img))

        # put the colour channel in front (after the batch dimension, if there is one)
        if norm_img.dim() == 3:
            norm_img = torch.movedim(norm_img, -1, 0)
        else:
            norm_img = torch.movedim(norm_img, -1, 1)

        # normalise into range [0,1]; cast to float32 so the dtype matches the network weights
        norm_img = norm_img.float()/self.MAX_PIXEL_VALUE

        # resize
        if dim is not None:
            assert dim >= self.RESNET_DIM, f"Provided dimension {dim} is less than the required for RESNET ({self.RESNET_DIM})"
            norm_img = T.Resize(dim)(norm_img)
        else:
            norm_img = T.Resize(self.RESNET_DIM)(norm_img)

        # normalise mean and variance
        mean = torch.Tensor([0.485, 0.456, 0.406])
        std = torch.Tensor([0.229, 0.224, 0.225])

        return T.Normalize(mean = mean, std = std)(norm_img)

    def contrastive_loss(self, z_batch, tau):
        """
        Computes the contrastive loss (NT-XENT) for a mini-batch of augmented samples.
        --------------------------------------------------------------------------------------------------------
        z_batch: a (N,K) Tensor, with rows as embedding vectors. N must be even and at least 2.
                 We expect that z_batch[2k] and z_batch[2k+1], 0 <= k < N/2, correspond to a positive sample pair
        tau: temperature parameter for NT-XENT loss
        --------------------------------------------------------------------------------------------------------
        return: a scalar Tensor, the loss averaged over the N rows of z_batch
        raises: ValueError if N is odd or smaller than 2 (some row would have no positive partner)
        """
        N = len(z_batch)

        if N < 2 or N % 2 != 0:
            raise ValueError(f"z_batch must contain an even number (>= 2) of rows, got {N}.")

        # normalise to have unit length rows
        norm_z_batch = nn.functional.normalize(z_batch)

        # compute similarity & apply factor of tau
        sim_batch = (norm_z_batch @ norm_z_batch.T)/tau

        # mask the diagonal with -inf, so a sample is never considered its own candidate,
        # whatever the scale 1/tau of the other similarities
        self_mask = torch.eye(N, dtype = torch.bool, device = sim_batch.device)
        sim_batch = sim_batch.masked_fill(self_mask, float("-inf"))

        # generate labels
        # z_batch[2k] should be similar to z_batch[2k+1] (since these will be the positive pair)
        # hence, labels have the form [1,0,3,2,...,N-1,N-2]; built on the same device as the similarities
        idx = torch.arange(N, device = sim_batch.device)
        labels = idx + 1 - 2 * (idx % 2)

        # return the NT-XENT loss (mean over rows, i.e. 1/N * sum)
        return nn.functional.cross_entropy(sim_batch, labels, reduction = "mean")

    def forward(self, x):
        """
        A forward pass through the network: preprocessing, ResNet, then the projection head.
        A single image (3 dimensions) is treated as a batch of one.
        """
        res_x = self.img_to_resnet(x)

        if res_x.dim() == 3:
            res_x = res_x.unsqueeze(0)

        features = self.resnet(res_x)
        hidden = self.relu(self.batch_norm(self.lin_hidden(features)))

        return self.lin_output(hidden)

    def compile_optimiser(self, **kwargs):
        """
        Sets the optimiser parameters: creates an Adam optimiser over all parameters, forwarding kwargs to it.
        """
        self.optimiser = optim.Adam(self.parameters(), **kwargs)

    def train(self, dataloader = True, tau = None, epochs = None, *, mode = None):
        """
        Trains the network, or switches training mode when used like nn.Module.train.
        --------------------------------------------------------------------------------------------------------
        This method shadows nn.Module.train, which nn.Module.eval() (and any parent module) calls as
        self.train(<bool>). To keep those working, a boolean first argument (or the `mode` keyword)
        is forwarded to nn.Module.train and the module itself is returned.
        --------------------------------------------------------------------------------------------------------
        dataloader: an iterable yielding pairs (x_1, x_2) of input batches, where x_1[k] and x_2[k]
                    form a positive pair
        tau: temperature parameter for the NT-XENT loss
        epochs: the number of passes over dataloader (an int)
        --------------------------------------------------------------------------------------------------------
        raises: TypeError if tau or epochs is missing; RuntimeError if compile_optimiser was not called
        """
        # nn.Module.train(mode) / nn.Module.eval() compatibility
        if mode is not None:
            return super().train(mode)
        if isinstance(dataloader, bool):
            return super().train(dataloader)

        if tau is None or epochs is None:
            raise TypeError("train() requires both tau and epochs when a dataloader is given.")
        if self.optimiser is None:
            raise RuntimeError("No optimiser set; call compile_optimiser() before train().")

        # make sure batch norm and the ResNet are in training mode
        super().train(True)

        # report roughly ten times over the run; max() avoids a modulo by zero when epochs < 10
        report_every = max(1, epochs // 10)

        for epoch in range(epochs):
            loss = None

            for x_1,x_2 in dataloader:
                self.optimiser.zero_grad()

                z_1 = self(x_1)
                z_2 = self(x_2)

                # interleave the two views, so that rows 2k and 2k+1 are a positive pair
                z_batch = torch.stack((z_1,z_2), dim = 1).reshape(-1, self.OUTPUT_DIM)
                loss = self.contrastive_loss(z_batch, tau = tau)

                loss.backward()
                self.optimiser.step()

            # loss stays None when the dataloader yielded nothing
            if loss is not None and epoch % report_every == 0:
                print(f"Epoch {epoch + 1} ---- NT-XENT = {loss.item()}")