Functions for image loading and showing results

In [1]:
import tifffile as tifff

def load_acrobat_image(datafolder, case, stain, pyr_idx, train_val_str):
    if train_val_str.lower() == 'train':
        image_path = datafolder + os.path.sep + str(case) + '_' + stain + '.tiff'
    elif train_val_str.lower() == 'val':
        image_path = datafolder + os.path.sep + str(case) + '_' + stain + '_val.tiff'
    elif train_val_str.lower() == 'test':
        image_path = datafolder + os.path.sep + str(case) + '_' + stain + '_test.tiff'
    else: 
        ValueError("Three options for the last input: 'train' 'val' 'test'")
    image = tifff.imread(image_path, key=pyr_idx)
    return(image)

def show_aligned_images(moved_image_channel, target_image_channel,alpha=0.8):
    target_image_rgb = np.transpose(np.stack((target_image_channel,target_image_channel,target_image_channel)),axes=[1,2,0])
    target_image_rgb = target_image_rgb.astype(np.uint8)
    aligned_images_rgb = target_image_rgb
    aligned_images_rgb[:,:,1] = alpha*moved_image_channel + (1-alpha)*aligned_images_rgb[:,:,1]
    aligned_images_rgb = aligned_images_rgb.astype(np.uint8)
    return(aligned_images_rgb)

Functions for image preprocessing, tissue segmentation and SIFT extraction

In [2]:
import cv2

def round_up_to_odd(f):
    return np.ceil(f) // 2 * 2 + 1

def segment_he_tissue(he_channel, res, sigma=25, crop_ratio=0.1):
    clahe = cv2.createCLAHE(clipLimit=0.2, tileGridSize=(round(100/res), round(100/res)))
    eq = clahe.apply(he_channel)
    
    ksize = sigma*4
    ksize = round_up_to_odd(ksize/res)
    sigma = ksize/4
    gf = cv2.GaussianBlur(eq,(int(ksize),int(ksize)),sigma)
    
    start_row = int(crop_ratio*gf.shape[0])
    end_row = int(gf.shape[0]-crop_ratio*gf.shape[0])+1
    start_col = int(crop_ratio*gf.shape[1])
    end_col = int(gf.shape[1]-crop_ratio*gf.shape[1])+1
    gf_crop = gf[start_row:end_row, start_col:end_col]
        
    pixel_vals = gf_crop.reshape((-1,1)) 
    pixel_vals = np.float32(pixel_vals)
    criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 100, 0.85) #criteria

    k = 3 # Choosing number of cluster
    _, labels, centers = cv2.kmeans(pixel_vals, k, None, criteria, 10, cv2.KMEANS_RANDOM_CENTERS) 
    
    bg_label = labels == np.where(np.sum(centers,axis=1) == np.amin(np.sum(centers,axis=1)))
    bg_mask = bg_label.reshape((gf_crop.shape[0:2])) # reshape data into the original image dimensions
    bg_mask_full = cv2.copyMakeBorder(np.uint8(bg_mask), start_row, gf.shape[0] - gf_crop.shape[0] - start_row, start_col, gf.shape[1] - gf_crop.shape[1] - start_col, cv2.BORDER_CONSTANT, None, value = 1)
    
    tissue_mask_full = bg_mask_full == 0

    return(tissue_mask_full)

def segment_ihc_tissue(ihc_channel, res, sigma=25, crop_ratio=0.1):
    clahe = cv2.createCLAHE(clipLimit=2, tileGridSize=(round(100/res), round(100/res)))
    eq = clahe.apply(ihc_channel)
    
    ksize = sigma*4
    ksize = round_up_to_odd(ksize/res)
    sigma = ksize/4
    gf = 255-cv2.GaussianBlur(eq,(int(ksize),int(ksize)),sigma)
    
    start_row = int(crop_ratio*gf.shape[0])
    end_row = int(gf.shape[0]-crop_ratio*gf.shape[0])+1
    start_col = int(crop_ratio*gf.shape[1])
    end_col = int(gf.shape[1]-crop_ratio*gf.shape[1])+1
    gf_crop = gf[start_row:end_row, start_col:end_col]
        
    pixel_vals = gf_crop.reshape((-1,1)) 
    pixel_vals = np.float32(pixel_vals)
    criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 100, 0.85) #criteria
    
    k = 3 
    _, labels, centers = cv2.kmeans(pixel_vals, k, None, criteria, 10, cv2.KMEANS_RANDOM_CENTERS) 
    
    bg_label = labels == np.where(np.sum(centers,axis=1) == np.amin(np.sum(centers,axis=1)))
    bg_mask = bg_label.reshape((gf_crop.shape[0:2])) # reshape data into the original image dimensions
    bg_mask_full = cv2.copyMakeBorder(np.uint8(bg_mask), start_row, gf.shape[0] - gf_crop.shape[0] - start_row, start_col, gf.shape[1] - gf_crop.shape[1] - start_col, cv2.BORDER_CONSTANT, None, value = 1)
    
    tissue_mask_full = bg_mask_full == 0
    
    return(tissue_mask_full)

def sift_he(he_lab, res, sigma=10, crop_ratio=0.06):
    he_channel_seg = he_lab[:,:,1].copy()
    he_tissue_mask = segment_he_tissue(he_channel_seg, res, sigma=25, crop_ratio=0.06)
    
    he_channel_sift = he_lab[:,:,0].copy()
    
    clahe = cv2.createCLAHE(clipLimit=0.2, tileGridSize=(round(100/res), round(100/res)))
    he_channel_sift_eq = clahe.apply(he_channel_sift)
    
    sigma = 5
    ksize = sigma*4
    ksize = round_up_to_odd(ksize/res)
    sigma = ksize/4
    he_channel_sift_gf = 255-cv2.GaussianBlur(he_channel_sift_eq,(int(ksize),int(ksize)),sigma)
    
    start_row = int(crop_ratio*he_channel_sift_gf.shape[0])
    end_row = int(he_channel_sift_gf.shape[0]-crop_ratio*he_channel_sift_gf.shape[0])+1
    start_col = int(crop_ratio*he_channel_sift_gf.shape[1])
    end_col = int(he_channel_sift_gf.shape[1]-crop_ratio*he_channel_sift_gf.shape[1])+1
    he_channel_sift_gf[0:start_row, :] = np.median(he_channel_sift_gf[he_tissue_mask==0])
    he_channel_sift_gf[end_row:, :] = np.median(he_channel_sift_gf[he_tissue_mask==0])
    he_channel_sift_gf[:, 0:start_col] = np.median(he_channel_sift_gf[he_tissue_mask==0])
    he_channel_sift_gf[:, end_col:] = np.median(he_channel_sift_gf[he_tissue_mask==0])

    sift = cv2.SIFT_create()
    he_kp, he_des = sift.detectAndCompute(he_channel_sift_gf,None)
    
    he_show_keypoints = cv2.drawKeypoints(he_channel_sift_gf,he_kp,np.array([]))
    return(he_show_keypoints,he_kp,he_des,he_tissue_mask,he_channel_sift_gf)

def sift_ihc(ihc_hsv, res, sigma=10, crop_ratio=0.06):
    ihc_channel_seg = ihc_hsv[:,:,2].copy()
    ihc_tissue_mask = segment_ihc_tissue(ihc_channel_seg, res, sigma=25, crop_ratio=0.06)
    
    ihc_channel_sift = ihc_hsv[:,:,2].copy()
    
    clahe = cv2.createCLAHE(clipLimit=2, tileGridSize=(round(100/res), round(100/res)))
    ihc_channel_sift_eq = clahe.apply(ihc_channel_sift)
    
    sigma = 5
    ksize = sigma*4
    ksize = round_up_to_odd(ksize/res)
    sigma = ksize/4
    ihc_channel_sift_gf = 255-cv2.GaussianBlur(ihc_channel_sift_eq,(int(ksize),int(ksize)),sigma)
    
    start_row = int(crop_ratio*ihc_channel_sift_gf.shape[0])
    end_row = int(ihc_channel_sift_gf.shape[0]-crop_ratio*ihc_channel_sift_gf.shape[0])+1
    start_col = int(crop_ratio*ihc_channel_sift_gf.shape[1])
    end_col = int(ihc_channel_sift_gf.shape[1]-crop_ratio*ihc_channel_sift_gf.shape[1])+1
    ihc_channel_sift_gf[0:start_row, :] = np.median(ihc_channel_sift_gf[ihc_tissue_mask==0])
    ihc_channel_sift_gf[end_row:, :] = np.median(ihc_channel_sift_gf[ihc_tissue_mask==0])
    ihc_channel_sift_gf[:, 0:start_col] = np.median(ihc_channel_sift_gf[ihc_tissue_mask==0])
    ihc_channel_sift_gf[:, end_col:] = np.median(ihc_channel_sift_gf[ihc_tissue_mask==0])

    sift = cv2.SIFT_create()
    ihc_kp,ihc_des = sift.detectAndCompute(ihc_channel_sift_gf,None)
    
    ihc_show_keypoints = cv2.drawKeypoints(ihc_channel_sift_gf,ihc_kp,np.array([]))
    return(ihc_show_keypoints,ihc_kp,ihc_des,ihc_tissue_mask,ihc_channel_sift_gf)

Matching keypoints and ransac with DICE

In [3]:
import cv2
import numpy as np
import tqdm

def dice(pred, true, k = 1):
    intersection = np.sum(pred[true==k]) * 2.0
    dice = intersection / (np.sum(pred) + np.sum(true))
    return dice

def umeyama(P, Q):
    assert P.shape == Q.shape
    n, dim = P.shape

    centeredP = P - P.mean(axis=0)
    centeredQ = Q - Q.mean(axis=0)

    C = np.dot(np.transpose(centeredP), centeredQ) / n

    V, S, W = np.linalg.svd(C)
    d = (np.linalg.det(V) * np.linalg.det(W)) < 0.0

    if d:
        S[-1] = -S[-1]
        V[:, -1] = -V[:, -1]

    R = np.dot(V, W)

    varP = np.var(P, axis=0).sum()
    c = 1/varP * np.sum(S) # scale factor

    t = Q.mean(axis=0) - P.mean(axis=0).dot(c*R)

    return c, R, t

def matching_keypoints(source_image,target_image,source_des,target_des,source_kp,target_kp,lowe_ratio=0.9):
    FLANN_INDEX_KDTREE = 0
    index_params = dict(algorithm = FLANN_INDEX_KDTREE, trees = 5)
    search_params = dict(checks = 50)
    flann = cv2.FlannBasedMatcher(index_params, search_params)
    matches = flann.knnMatch(source_des,target_des,k=2)
    
    # Lowe's Ratio test
    good = []
    for m, n in matches:
        if m.distance < lowe_ratio*n.distance:
            good.append(m)
    
    good_source_kp = np.float32([ source_kp[m.queryIdx].pt for m in good ]).reshape(-1, 2)
    good_target_kp = np.float32([ target_kp[m.trainIdx].pt for m in good ]).reshape(-1, 2)
        
    n_good = len(good)
    good_source_kp_cv2 = [cv2.KeyPoint(point[0], point[1], 1) for point in good_source_kp]
    good_target_kp_cv2 = [cv2.KeyPoint(point[0], point[1], 1) for point in good_target_kp]
    placeholder_matches = [cv2.DMatch(idx, idx, 1) for idx in range(n_good)]
    show_matching_keypoints = cv2.drawMatches(source_image, good_source_kp_cv2, target_image, good_target_kp_cv2, placeholder_matches, None)
    return(good_source_kp, good_target_kp,show_matching_keypoints)

def dransac(source_kp, target_kp, source_tissue_mask, target_tissue_mask, proc_source_image, proc_target_image, source_image, target_image, epochs=30000, err_threshold=50, dice_weight=0.5, ninliers_weight=0.5):
    cols = target_tissue_mask.shape[1]
    rows = target_tissue_mask.shape[0]
    
    tot_indices = []
    tot_dice = []
    tot_mi = []
    tot_ninliers = []
    
    nsamples_iter = 3
    for itera in tqdm.tqdm(range(epochs)):
        indices = np.random.permutation(len(source_kp))[0:nsamples_iter]
        curr_source_kp = source_kp[indices,:]
        curr_target_kp = target_kp[indices,:]
    
        c, R, t = umeyama(curr_source_kp, curr_target_kp)

        if c > 0.95 and c < 1.05:
            tot_indices.append(indices)
            alpha = c*R[0,0]
            beta = c*R[0,1]
            M = np.float32([[alpha, -beta, t[0]],[beta, alpha, t[1]]])
            source_tissue_mask_reg = cv2.warpAffine(np.float32(source_tissue_mask), M, (cols,rows))
            source_kp_reg = source_kp.dot(c * R) + t

            err = np.sqrt(np.sum((source_kp_reg - target_kp) ** 2,axis=1))
            idx_inlier = np.where(err<err_threshold)
            n_inlier = len(idx_inlier[0])
            tot_ninliers.append(n_inlier)        

            curr_dice = dice(source_tissue_mask_reg,target_tissue_mask)
            tot_dice.append(curr_dice) 

    tot_ninliers = np.array(tot_ninliers)
    tot_ninliers_n = (tot_ninliers - tot_ninliers.min()) / (tot_ninliers.max() - tot_ninliers.min())
    
    tot_dice = np.array(tot_dice)
    tot_dice_n = (tot_dice - tot_dice.min()) / (tot_dice.max() - tot_dice.min())
    
    obj_func = dice_weight*tot_dice_n**2 + ninliers_weight*tot_ninliers_n**2

    best_iter = np.argmax(obj_func)
    best_ninliers = tot_ninliers[best_iter]
    best_dice = tot_dice[best_iter]
    
    print('Best iteration: ' + str(best_iter) +
    ' Number of inliers: ' + str(best_ninliers) +
    ' Dice: ' + str(best_dice))  

    best_source_kp = source_kp[tot_indices[best_iter],:]
    best_target_kp = target_kp[tot_indices[best_iter],:]
    
    c, R, t = umeyama(best_source_kp, best_target_kp)
    source_kp_reg = source_kp.dot(c * R) + t
    err = np.sqrt(np.sum((source_kp_reg - target_kp) ** 2,axis=1))
    idx_best_inlier = np.where(err<err_threshold)
    best_n_inlier = len(idx_best_inlier[0])
    best_inlier_source_kp = source_kp[idx_best_inlier]
    best_inlier_target_kp = target_kp[idx_best_inlier]
    c, R, t = umeyama(best_inlier_source_kp, best_inlier_target_kp)
    alpha = c*R[0,0]
    beta = c*R[0,1]
    best_M = np.float32([[alpha, -beta, t[0]],[beta, alpha, t[1]]])

    best_inlier_source_kp_cv2 = [cv2.KeyPoint(point[0], point[1], 1) for point in best_inlier_source_kp]
    best_inlier_target_kp_cv2 = [cv2.KeyPoint(point[0], point[1], 1) for point in best_inlier_target_kp]
    placeholder_matches = [cv2.DMatch(idx, idx, 1) for idx in range(best_n_inlier)]
    show_best_matching_keypoints = cv2.drawMatches(source_image, best_inlier_source_kp_cv2, target_image, best_inlier_target_kp_cv2, placeholder_matches, None)
    return(best_M,best_ninliers,best_dice,show_best_matching_keypoints)

INR models  

In [4]:
import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

class Siren(nn.Module):
    """This is a dense neural network with sine activation functions.

    Arguments:
    layers -- ([*int]) amount of nodes in each layer of the network, e.g. [2, 16, 16, 1]
    gpu -- (boolean) use GPU when True, CPU when False
    weight_init -- (boolean) use special weight initialization if True
    omega -- (float) parameter used in the forward function
    """

    def __init__(self, layers, gpu=False, weight_init=True, omega=30):
        """Initialize the network."""

        super(Siren, self).__init__()
        
        self.n_layers = len(layers) - 1
        self.omega = omega

        # Make the layers
        self.layers = []
        for i in range(self.n_layers):
            self.layers.append(nn.Linear(layers[i], layers[i + 1]))

            # Weight Initialization
            if weight_init:
                with torch.no_grad():
                    if i == 0:
                        self.layers[-1].weight.uniform_(-1 / layers[i],
                                                        1 / layers[i])
                    else:
                        self.layers[-1].weight.uniform_(-np.sqrt(6 / layers[i]) / self.omega,
                                                        np.sqrt(6 / layers[i]) / self.omega)

        # Combine all layers to one model
        self.layers = nn.Sequential(*self.layers)

    def forward(self, x):
        """The forward function of the network."""
        # Perform relu on all layers except for the last one
        for layer in self.layers[:-1]:
            z = layer(x)            
            x = torch.sin(self.omega * z)

        # Propagate through final layer and return the output
        x = self.layers[-1](x)    
        return x 
    
    
class BaseNet(nn.Module):
    def __init__(self, layers):
        """Initialize the network."""

        super(BaseNet, self).__init__()        
        
        self.n_layers = len(layers) - 1

        # Make the layers
        self.layers = []
        for i in range(self.n_layers):
            self.layers.append(nn.Linear(layers[i], layers[i + 1]))
            
        # Combine all layers to one model
        self.layers = nn.Sequential(*self.layers)

    def forward(self, x):
        """The forward function of the network."""
        # Perform relu on all layers except for the last one
        for layer in self.layers[:-1]:
            x = torch.nn.functional.relu(layer(x))

        # Propagate through final layer and return the output
        return self.layers[-1](x)       

NCC Metric functions

In [5]:
import math
from typing import Optional, Tuple

from torch import Tensor
from torch.nn.modules.loss import _Loss

class StableStd(torch.autograd.Function):
    @staticmethod
    def forward(ctx, tensor):
        assert tensor.numel() > 1
        ctx.tensor = tensor.detach()
        res = torch.std(tensor).detach()
        ctx.result = res.detach()
        return res

    @staticmethod
    def backward(ctx, grad_output):
        tensor = ctx.tensor.detach()
        result = ctx.result.detach()
        e = 1e-6
        assert tensor.numel() > 1
        return (
            (2.0 / (tensor.numel() - 1.0))
            * (grad_output.detach() / (result.detach() * 2 + e))
            * (tensor.detach() - tensor.mean().detach())
        )


stablestd = StableStd.apply

def ncc(x1, x2, e=1e-10):
    assert x1.shape == x2.shape, "Inputs are not of similar shape"
    x1 = x1.view(-1, 25)
    x2 = x2.view(-1, 25)
    cc = ((x1 - x1.mean(dim=1)[:, None]) * (x2 - x2.mean(dim=1)[:, None])).mean(dim=1)
    std = x1.std(dim=1) * x2.std(dim=1)
    ncc = torch.mean(cc/(std+e))
    return ncc

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

class NCC(_Loss):
    def __init__(self, use_mask: bool = False):
        super().__init__()
        self.forward = self.metric

    def metric(self, fixed: Tensor, warped: Tensor) -> Tensor:
        return -ncc(fixed, warped)

Functions for the learning models

In [6]:
class LearningModel():
    """This class contains functions which are useful for all the learning models."""

    def __init__(self, **kwargs):
        """Initialize the learning model."""

        # Set all default arguments in a dict: self.args
        self.setDefaultArguments()

        # Check if all kwargs keys are valid (this checks for typos)
        assert all(kwarg in self.args.keys() for kwarg in kwargs)

        # Parse important argument from kwargs
        self.epochs = kwargs['epochs'] if 'epochs' in kwargs else self.args['epochs']
        self.log_interval = kwargs['log_interval'] if 'log_interval' in kwargs else self.args['log_interval']
        self.gpu = kwargs['gpu'] if 'gpu' in kwargs else self.args['gpu']
        self.lr = kwargs['lr'] if 'lr' in kwargs else self.args['lr']
        self.momentum = kwargs['momentum'] if 'momentum' in kwargs else self.args['momentum']
        self.optimizer_arg = kwargs['optimizer'] if 'optimizer' in kwargs else self.args['optimizer']
        self.loss_function_arg = kwargs['loss_function'] if 'loss_function' in kwargs else self.args['loss_function']
        self.layers = kwargs['layers'] if 'layers' in kwargs else self.args['layers']
        self.weight_init = kwargs['weight_init'] if 'weight_init' in kwargs else self.args['weight_init']
        self.omega = kwargs['omega'] if 'omega' in kwargs else self.args['omega']
        self.save_folder = kwargs['save_folder'] if 'save_folder' in kwargs else self.args['save_folder']
        self.gabor_scale = kwargs['gabor_scale'] if 'gabor_scale' in kwargs else self.args['gabor_scale']

        # Parse other arguments from kwargs
        self.verbose = kwargs['verbose'] if 'verbose' in kwargs else self.args['verbose']

        # Make folder for output
        if not self.save_folder == '' and not os.path.isdir(self.save_folder):
            os.mkdir(self.save_folder)

        # Add slash to divide folder and filename
        self.save_folder += '/'

        # Make loss list to save losses
        self.loss_list = [0 for _ in range(self.epochs)]
        self.data_loss_list = [0 for _ in range(self.epochs)]

        # Set seed
        torch.manual_seed(self.args['seed'])

        # Load network
        self.network_from_file = kwargs['network'] if 'network' in kwargs else self.args['network']
        if self.network_from_file is None:
            self.network = BaseNet(self.layers) 
            # self.network = Siren(self.layers, self.gpu, self.weight_init, self.omega)            
        else:
            self.network = torch.load(self.network_from_file)
            if self.gpu:
                self.network.cuda()

        # Choose the optimizer
        if self.optimizer_arg.lower() == 'sgd':
            self.optimizer = optim.SGD(self.network.parameters(), lr=self.lr, momentum=self.momentum)

        elif self.optimizer_arg.lower() == 'adam':
            self.optimizer = optim.Adam(self.network.parameters(), lr=self.lr) # Adam(self.network.parameters(), lr=self.lr)

        elif self.optimizer_arg.lower() == 'adadelta':
            self.optimizer = optim.Adadelta(self.network.parameters(), lr=self.lr)

        elif self.optimizer_arg.lower() == 'rmsprop':
            self.optimizer = optim.RMSprop(self.network.parameters(), lr=self.lr)            
            
        else:
            self.optimizer = optim.SGD(self.network.parameters(), lr=self.lr, momentum=self.momentum)
            print('WARNING: ' + str(self.optimizer_arg) + ' not recognized as optimizer, picked SGD instead')

        self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, self.epochs // 3, gamma=0.5)
        
        # Choose the loss function
        if self.loss_function_arg.lower() == 'mse':
            self.criterion = nn.MSELoss()

        elif self.loss_function_arg.lower() == 'l1':
            self.criterion = nn.L1Loss()

        elif self.loss_function_arg.lower() == 'ncc':
            self.criterion = NCC()            
            
        elif self.loss_function_arg.lower() == 'smoothl1':
            self.criterion = nn.SmoothL1Loss(beta=0.2)
            
        elif self.loss_function_arg.lower() == 'huber':
            self.criterion = nn.HuberLoss()            

        else:
            self.criterion = nn.MSELoss()
            print('WARNING: ' + str(self.loss_function_arg) + ' not recognized as loss function, picked MSE instead')

        # Move variables to GPU
        if self.gpu:
            self.network.cuda()

    def cuda(self):
        """Move the model to the GPU."""

        self.gpu = True
        self.network.cuda()
        
    def divergence(self, input_coords, output):
        """Compute the divergence of the output wrt the input."""

        div = 0
        for i in range(output.shape[-1]):
            div += torch.autograd.grad(output[..., i], input_coords, torch.ones_like(output[..., i]), create_graph=True)[0][..., i:i + 1]
        return div

    def gradient(self, input_coords, output, grad_outputs=None):
        """Compute the gradient of the output wrt the input."""

        grad_outputs = torch.ones_like(output)
        grad = torch.autograd.grad(output, [input_coords], grad_outputs=grad_outputs, create_graph=True)[0]
        return grad

    def laplace(self, input_coords, output, grad_outputs=None):
        """Compute the laplacian of the output wrt the input."""

        grad = self.gradient(input_coords, output)

        return self.divergence(input_coords, grad)

    def makeCoordinateSlice(self, dims=(28, 28), dimension=0, slice_pos=0, gpu=True):
        """Make a coordinate tensor."""

        dims = list(dims)

        coordinate_tensor = [torch.linspace(-1, 1, dims[i]) for i in range(2)]
        coordinate_tensor = torch.meshgrid(*coordinate_tensor)
        coordinate_tensor = torch.stack(coordinate_tensor, dim=2)
        coordinate_tensor = coordinate_tensor.view([np.prod(dims), 2])

        # Move to GPU if necessary
        if self.gpu and gpu:
            coordinate_tensor = coordinate_tensor.cuda()

        return coordinate_tensor

    def makeCoordinateTensor(self, dims=(28, 28, 28), gpu=True):
        """Make a coordinate tensor."""

        coordinate_tensor = [torch.linspace(-1, 1, dims[i]) for i in range(3)]
        coordinate_tensor = torch.meshgrid(*coordinate_tensor)
        coordinate_tensor = torch.stack(coordinate_tensor, dim=3)
        coordinate_tensor = coordinate_tensor.view([np.prod(dims), 3])

        # Move to GPU if necessary
        if self.gpu and gpu:
            coordinate_tensor = coordinate_tensor.cuda()

        return coordinate_tensor
    
    def makeMaskedCoordinateTensor(self, mask, dims=(28, 28, 28), gpu=True):
        """Make a coordinate tensor."""

        coordinate_tensor = [torch.linspace(-1, 1, dims[i]) for i in range(2)]
        coordinate_tensor = torch.meshgrid(*coordinate_tensor)
        coordinate_tensor = torch.stack(coordinate_tensor, dim=2)
        coordinate_tensor = coordinate_tensor.view([np.prod(dims), 2])
        coordinate_tensor = coordinate_tensor[mask.flatten() > 0, :]

        # Move to GPU if necessary
        if self.gpu and gpu:
            coordinate_tensor = coordinate_tensor.cuda()

        return coordinate_tensor  
    
    def setDefaultArguments(self):
        """Set default arguments."""

        self.args = {}

        self.args['network'] = None

        self.args['epochs'] = 200
        self.args['log_interval'] = 250
        self.args['verbose'] = True
        self.args['save_folder'] = 'output'

        self.args['gpu'] = torch.cuda.is_available()
        self.args['optimizer'] = 'Adam'
        self.args['loss_function'] = 'MSE'
        self.args['lr'] = 0.001
        self.args['momentum'] = 0.5

        self.args['layers'] = [2, 32, 32, 32, 1]
        self.args['positional_encoding'] = False
        self.args['weight_init'] = True
        self.args['omega'] = 30

        self.args['seed'] = 1 # 13 # 2 # 1
        
        self.args['offset'] = [0, 0, 0]
        self.args['gabor_scale'] = 1.0
        
    def train(self, epochs=None, red_blue=False):
        """Train the network."""
        # Determine epochs
        if epochs is None:
            epochs = self.epochs

        # Set seed
        torch.manual_seed(self.args['seed'])

        # Extend lost_list if necessary
        if not len(self.loss_list) == epochs:
            self.loss_list = [0 for _ in range(epochs)]
            self.data_loss_list = [0 for _ in range(epochs)]

        # Perform training iterations
        for i in tqdm.tqdm(range(epochs)):
            self.trainingIteration(i)

Functions for registering two images using the INR

In [7]:
from skimage.util import compare_images
      
class ImplicitRegistrator(LearningModel):
    """This is a class for registrating implicitly represented images."""

    def __call__(self, coordinate_tensor=None, output_shape=(28, 28), dimension=0, slice_pos=0):
        """Return the image-values for the given input-coordinates."""

        output_size = np.prod(output_shape)

        # Use standard coordinate tensor if none is given
        if coordinate_tensor is None:
            coordinate_tensor = self.makeCoordinateSlice(output_shape, dimension, slice_pos)

        output = self.network(coordinate_tensor)

        # Shift coordinates by 1/n * v
        coord_temp = torch.add(output, coordinate_tensor)
        
        output = coord_temp

        transformed_image = self.transformNoAdd(coord_temp)
        return transformed_image.cpu().detach().numpy()

    def __init__(self, moving_image, fixed_image, **kwargs):
        """Initialize the registrator."""

        # Initialize standard learning model
        super().__init__(**kwargs)

        # Parse arguments from kwargs
        self.mask = kwargs['mask'] if 'mask' in kwargs else self.args['mask']
       
        # Velocity integration steps
        self.velocity_steps = kwargs['velocity_steps'] if 'velocity_steps' in kwargs else self.args['velocity_steps']

        # Parse regularization kwargs
        self.output_regularization = kwargs['output_regularization'] if 'output_regularization' in kwargs else self.args['output_regularization']
        self.alpha_output = kwargs['alpha_output'] if 'alpha_output' in kwargs else self.args['alpha_output']
        self.reg_norm_output = kwargs['reg_norm_output'] if 'reg_norm_output' in kwargs else self.args['reg_norm_output']

        self.jacobian_regularization = kwargs['jacobian_regularization'] if 'jacobian_regularization' in kwargs else self.args['jacobian_regularization']
        self.alpha_jacobian = kwargs['alpha_jacobian'] if 'alpha_jacobian' in kwargs else self.args['alpha_jacobian']
        self.reg_norm_jacobian = kwargs['reg_norm_jacobian'] if 'reg_norm_jacobian' in kwargs else self.args['reg_norm_jacobian']

        self.diffusion_regularization = kwargs['diffusion_regularization'] if 'diffusion_regularization' in kwargs else self.args['diffusion_regularization']
        self.alpha_diffusion = kwargs['alpha_diffusion'] if 'alpha_diffusion' in kwargs else self.args['alpha_diffusion']

        self.elastic_regularization = kwargs['elastic_regularization'] if 'elastic_regularization' in kwargs else self.args['elastic_regularization']
        self.alpha_elastic = kwargs['alpha_elastic'] if 'alpha_elastic' in kwargs else self.args['alpha_elastic']

        self.quadratic_regularization = kwargs['quadratic_regularization'] if 'quadratic_regularization' in kwargs else self.args['quadratic_regularization']
        self.alpha_quadratic = kwargs['alpha_quadratic'] if 'alpha_quadratic' in kwargs else self.args['alpha_quadratic']

        self.ogden_regularization = kwargs['ogden_regularization'] if 'ogden_regularization' in kwargs else self.args['ogden_regularization']
        self.alpha_ogden = kwargs['alpha_ogden'] if 'alpha_ogden' in kwargs else self.args['alpha_ogden']

        self.hyper_regularization = kwargs['hyper_regularization'] if 'hyper_regularization' in kwargs else self.args['hyper_regularization']
        self.alpha_hyper = kwargs['alpha_hyper'] if 'alpha_hyper' in kwargs else self.args['alpha_hyper']

        self.time_regularization = kwargs['time_regularization'] if 'time_regularization' in kwargs else self.args['time_regularization']
        self.alpha_time = kwargs['alpha_time'] if 'alpha_time' in kwargs else self.args['alpha_time']

        self.bending_regularization = kwargs['bending_regularization'] if 'bending_regularization' in kwargs else self.args['bending_regularization']
        self.alpha_bending = kwargs['alpha_bending'] if 'alpha_bending' in kwargs else self.args['alpha_bending']        
        
        # Set seed
        torch.manual_seed(self.args['seed'])

        # Parse arguments from kwargs
        self.image_shape = kwargs['image_shape'] if 'image_shape' in kwargs else self.args['image_shape']
        self.batch_size = kwargs['batch_size'] if 'batch_size' in kwargs else self.args['batch_size']

        # Initialization
        self.moving_image = moving_image
        self.fixed_image = fixed_image
        
        self.image_shape = (int(self.fixed_image.shape[0]), int(self.fixed_image.shape[1]))
        self.possible_coordinate_tensor = self.makeMaskedCoordinateTensor(self.mask, self.fixed_image.shape)       
        ncc_width = 0.025 #  (1-epoch/self.epochs)*0.3+0.0001        
        coordinate_tensor_loc = [torch.linspace(-1*ncc_width, ncc_width, 5) for i in range(2)]
        coordinate_tensor_loc = torch.meshgrid(*coordinate_tensor_loc)
        coordinate_tensor_loc = torch.stack(coordinate_tensor_loc, dim=2)
        coordinate_tensor_loc = coordinate_tensor_loc.view([25, 2])
        coordinate_tensor_loc = torch.tile(coordinate_tensor_loc, (int(self.batch_size/25), 1))
        self.coordinate_tensor_loc = coordinate_tensor_loc.cuda()    
    
    
        if self.gpu:
            self.moving_image = self.moving_image.cuda()
            self.fixed_image = self.fixed_image.cuda()

        # Make coordinate_slice and image tensor for showing progress
        if self.verbose:
            self.coordinate_slice = self.makeCoordinateSlice(self.image_shape, 1, 0)
            self.fixed_image_tensor = self.interpolate(self.fixed_image, self.coordinate_slice)
            self.fixed_image_tensor = self.fixed_image_tensor.detach().cpu().view(self.image_shape)
            self.fixed_image_tensor_masked = self.interpolate(self.fixed_image  * torch.from_numpy(self.mask).cuda(), self.coordinate_slice)
            self.fixed_image_tensor_masked = self.fixed_image_tensor_masked.detach().cpu().view(self.image_shape)            

        # Move variables to GPU
        if self.gpu:
            self.moving_image.cuda()
            self.fixed_image.cuda()
        
        self.offset = torch.FloatTensor(kwargs['offset']).cuda()

    def computeDiffusionLoss(self, input_coords, output, batch_size=None):
        """Compute the diffustion regularization loss."""

        if batch_size is None:
            batch_size = self.batch_size

        jacobian_matrix = self.computeJacobianMatrix(input_coords, output, add_identity=False)

        loss = torch.linalg.norm(jacobian_matrix, dim=2)
        loss = torch.pow(loss, 2)
        loss = torch.sum(loss)

        return 0.5 * loss / batch_size

    def computeElasticLoss(self, input_coords, output, batch_size=None, nu=1, mu=1):
        """Compute the (linear) elastic regularization loss."""

        if batch_size is None:
            batch_size = self.batch_size

        jacobian_matrix = self.computeJacobianMatrix(input_coords, output, add_identity=False)

        V = jacobian_matrix + torch.transpose(jacobian_matrix, 1, 2)
        V = 0.5 * V

        loss = nu * torch.sum(torch.pow(self.computeTrace(V), 2))
        loss += mu * self.computeTraceSum(torch.matmul(V, V))

        return loss / batch_size

    def computeHyperElasticLoss(self, input_coords, output, batch_size=None, alpha_l=1, alpha_a=1, alpha_v=1):
        """Compute the hyper-elastic regularization loss."""

        if batch_size is None:
            batch_size = self.batch_size

        grad_u = self.computeJacobianMatrix(input_coords, output, add_identity=False)
        grad_y = self.computeJacobianMatrix(input_coords, output, add_identity=True)  # This is slow, faster to infer from grad_u

        # Compute length loss
        length_loss = torch.linalg.norm(grad_u, dim=(1, 2))
        length_loss = torch.pow(length_loss, 2)
        length_loss = torch.sum(length_loss)
        length_loss = 0.5 * alpha_l * length_loss

        # Compute cofactor matrices for the area loss
        cofactors = torch.zeros(batch_size, 3, 3)

        # Compute elements of cofactor matrices one by one (Ugliest solution ever?)
        cofactors[:, 0, 0] = torch.det(grad_y[:, 1:, 1:])
        cofactors[:, 0, 1] = torch.det(grad_y[:, 1:, 0::2])
        cofactors[:, 0, 2] = torch.det(grad_y[:, 1:, :2])
        cofactors[:, 1, 0] = torch.det(grad_y[:, 0::2, 1:])
        cofactors[:, 1, 1] = torch.det(grad_y[:, 0::2, 0::2])
        cofactors[:, 1, 2] = torch.det(grad_y[:, 0::2, :2])
        cofactors[:, 2, 0] = torch.det(grad_y[:, :2, 1:])
        cofactors[:, 2, 1] = torch.det(grad_y[:, :2, 0::2])
        cofactors[:, 2, 2] = torch.det(grad_y[:, :2, :2])

        # Compute area loss
        area_loss = torch.pow(cofactors, 2)
        area_loss = torch.sum(area_loss, dim=1)
        area_loss = area_loss - 1
        area_loss = torch.maximum(area_loss, torch.zeros_like(area_loss))
        area_loss = torch.pow(area_loss, 2)
        area_loss = torch.sum(area_loss)  # sum over dimension 1 and then 0
        area_loss = alpha_a * area_loss

        # Compute volume loss
        volume_loss = torch.det(grad_y)
        volume_loss = torch.mul(torch.pow(volume_loss - 1, 4), torch.pow(volume_loss, -2))
        volume_loss = torch.sum(volume_loss)
        volume_loss = alpha_v * volume_loss

        # Compute total loss
        loss = length_loss + area_loss + volume_loss

        return loss / batch_size

    def computeBendingEnergy(self, input_coords, output, batch_size=None):
        """Compute the bending energy."""
        
        if batch_size is None:
            batch_size = self.batch_size
        
        jacobian_matrix = self.computeJacobianMatrix(input_coords, output, add_identity=False)

        dx_xy = torch.zeros(input_coords.shape[0], 2, 2)
        dy_xy = torch.zeros(input_coords.shape[0], 2, 2)
        for i in range(2):
            dx_xy[:, i, :] = self.gradient(input_coords, jacobian_matrix[:, i, 0])
            dy_xy[:, i, :] = self.gradient(input_coords, jacobian_matrix[:, i, 1])

        dx_xy = torch.square(dx_xy)
        dy_xy = torch.square(dy_xy)

        loss = torch.mean(dx_xy[:, :, 0]) + torch.mean(dy_xy[:, :, 1])
        loss += 2 * torch.mean(dx_xyz[:, :, 1]) + 2 * torch.mean(dx_xyz[:, :, 2]) + torch.mean(dy_xyz[:, :, 2])

        return loss / batch_size
    
    def computeJacobianLoss(self, input_coords, output, batch_size=None):
        """Compute the jacobian regularization loss."""

        # Compute Jacobian matrices
        jac = self.computeJacobianMatrix(input_coords, output)

        # Compute determinants and take norm
        loss = torch.det(jac) - 1
        loss = torch.linalg.norm(loss, self.reg_norm_jacobian)

        return loss / self.batch_size

    def computeJacobianMatrix(self, input_coords, output, add_identity=True):
        """Compute the Jacobian matrix of the output wrt the input."""

        jacobian_matrix = torch.zeros(input_coords.shape[0], 2, 2)
        for i in range(2):
            jacobian_matrix[:, i, :] = self.gradient(input_coords, output[:, i])
            if add_identity:
                jacobian_matrix[:, i, i] += torch.ones_like(jacobian_matrix[:, i, i])
        return jacobian_matrix      

    def computeOgdenLoss(self, input_coords, output, batch_size=None, alpha_l=1, alpha_a=1, alpha_v=1):
        """Compute the Ogden regularization loss."""

        if batch_size is None:
            batch_size = self.batch_size

        grad_u = self.computeJacobianMatrix(input_coords, output, add_identity=False)
        grad_y = self.computeJacobianMatrix(input_coords, output, add_identity=True)  # This is slow, faster to infer from grad_u

        length_loss = torch.linalg.norm(grad_u, dim=(1, 2))
        length_loss = torch.pow(length_loss, 2)
        length_loss = torch.sum(length_loss)
        length_loss = 0.5 * alpha_l * length_loss

        # Compute cofactor matrices for the area loss
        cofactors = torch.zeros(batch_size, 3, 3)

        # Compute elements of cofactor matrices one by one (Ugliest solution ever?)
        cofactors[:, 0, 0] = torch.det(grad_y[:, 1:, 1:])
        cofactors[:, 0, 1] = torch.det(grad_y[:, 1:, 0::2])
        cofactors[:, 0, 2] = torch.det(grad_y[:, 1:, :2])
        cofactors[:, 1, 0] = torch.det(grad_y[:, 0::2, 1:])
        cofactors[:, 1, 1] = torch.det(grad_y[:, 0::2, 0::2])
        cofactors[:, 1, 2] = torch.det(grad_y[:, 0::2, :2])
        cofactors[:, 2, 0] = torch.det(grad_y[:, :2, 1:])
        cofactors[:, 2, 1] = torch.det(grad_y[:, :2, 0::2])
        cofactors[:, 2, 2] = torch.det(grad_y[:, :2, :2])

        # Compute area loss
        area_loss = torch.linalg.norm(cofactors, dim=(1, 2))
        area_loss = torch.pow(area_loss, 2)
        area_loss = torch.sum(area_loss)
        area_loss = alpha_a * area_loss

        # Compute volume loss
        volume_loss = torch.det(grad_y)
        volume_loss = torch.abs(volume_loss)  # To prevent taking log of negative number
        volume_loss = torch.pow(volume_loss, 2) - torch.log(volume_loss)
        volume_loss = torch.sum(volume_loss)
        volume_loss = alpha_v * volume_loss

        # Compute total loss
        loss = length_loss + area_loss + volume_loss

        return loss / batch_size

    def computeQuadraticElasticLoss(self, input_coords, output, batch_size=None, nu=1, mu=1):
        """Compute the quadratic elastic regularization loss."""

        if batch_size is None:
            batch_size = self.batch_size

        jacobian_matrix = self.computeJacobianMatrix(input_coords, output, add_identity=False)
        jacobian_matrix_transpose = torch.transpose(jacobian_matrix, 1, 2)

        E = jacobian_matrix + jacobian_matrix_transpose + torch.matmul(jacobian_matrix_transpose, jacobian_matrix)
        E = 0.5 * E

        loss = nu * torch.sum(torch.pow(self.computeTrace(E), 2))
        loss += mu * self.computeTraceSum(torch.matmul(E, E))

        return loss / batch_size

    def computeTrace(self, tensor, n=3):
        """Compute the traces of an array of nxn matrices."""

        return torch.sum(torch.mul(tensor, torch.eye(n)), (1, 2))

    def computeTraceSum(self, tensor, n=3):
        """Compute the sum of the traces of an array of nxn matrices."""

        return torch.sum(torch.mul(tensor, torch.eye(n)))

    def cuda(self):
        """Move the model to the GPU."""

        # Standard variables
        super().cuda()

        # Variables specific to this class
        self.moving_image.cuda()
        self.fixed_image.cuda()
        
    def getBendingForZ(self, size=(28, 28), dimension=0, z=0.0):
        output_size = np.prod(size)
        coordinate_slice = self.makeCoordinateSlice(size, dimension, z).requires_grad_(True)
        output = self.network(coordinate_slice)   
        jacobian_matrix = self.computeJacobianMatrix(coordinate_slice, output, add_identity=False)

        dx_xyz = torch.zeros(coordinate_slice.shape[0], 3, 3)
        dy_xyz = torch.zeros(coordinate_slice.shape[0], 3, 3)
        dz_xyz = torch.zeros(coordinate_slice.shape[0], 3, 3)
        for i in range(3):
            dx_xyz[:, i, :] = self.gradient(coordinate_slice, jacobian_matrix[:, i, 0])
            dy_xyz[:, i, :] = self.gradient(coordinate_slice, jacobian_matrix[:, i, 1])
            dz_xyz[:, i, :] = self.gradient(coordinate_slice, jacobian_matrix[:, i, 2])

        dx_xyz = torch.square(dx_xyz)
        dy_xyz = torch.square(dy_xyz)
        dz_xyz = torch.square(dz_xyz)

        bending = dx_xyz[:, :, 0] + dy_xyz[:, :, 1] + dz_xyz[:, :, 2]
        bending += 2 * dx_xyz[:, :, 1] + 2 * dx_xyz[:, :, 2] + dy_xyz[:, :, 2]

        return bending
               
    def getJacobianForZ(self, size=(28, 28), dimension=0, z=0.0):
        """Get the Jacobian Matrix and return its determinant."""

        output_size = np.prod(size)
        coordinate_slice = self.makeCoordinateSlice(size, dimension, z).requires_grad_(True)
        output = self.network(coordinate_slice)
        jacobian_matrix = self.computeJacobianMatrix(coordinate_slice, output)
        return torch.det(jacobian_matrix), output
        
        
    def getJacobian(self, size=(28, 28), dimension=0):
        """Get the Jacobian Matrix and return its determinant."""

        jacobian_matrix = self.getJacobianMatrix(size, dimension)
        determinant = torch.det(jacobian_matrix)
        return determinant

    def getJacobianMatrix(self, size=(28, 28), dimension=0):
        """Compute the jacobian matrix at points on a grid of size: self.jacobian_size.

        computeJacobianMatrix is used during training when input and output are already availible.
        getJacobianMatrix is used outside of training (e.g. for post-training visulaization).
        """

        output_size = np.prod(size)

        coordinate_slice = self.makeCoordinateSlice(size, dimension).requires_grad_(True)
        
        print(coordinate_slice)
        time_vector = torch.zeros(output_size, 1)
        if self.gpu:
            time_vector = time_vector.cuda()

        # output = self.network(torch.cat((coordinate_slice, time_vector), 1))
        output = self.network(coordinate_slice)

        output = output / self.velocity_steps

        # Shift coordinates by 1/n * v
        coord_temp = torch.add(output, coordinate_slice)

        # Velocity field integration
        for t in range(self.velocity_steps - 1):
            time_vector = ((t + 1) / self.velocity_steps) * torch.ones(output_size, 1)
            if self.gpu:
                time_vector = time_vector.cuda()

            output = self.network(torch.cat((coord_temp, time_vector), 1))

            output = output / self.velocity_steps

            # Shift coordinates by 1/n * v
            coord_temp = torch.add(output, coord_temp)

        output = torch.subtract(coord_temp, coordinate_slice)

        jacobian_matrix = self.computeJacobianMatrix(coordinate_slice, output)
        return jacobian_matrix

    def getJacobianMatrixHist(self, size=(28, 28), dimension=0, batchsize=5000):
        """Compute the jacobian matrix at points on a grid of size: self.jacobian_size.

        computeJacobianMatrix is used during training when input and output are already availible.
        getJacobianMatrix is used outside of training (e.g. for post-training visulaization).
        """

        output_size = np.prod(size)

        coordinate_slice = torch.rand(batchsize, 3).requires_grad_(True) * 2 - 1
        coordinate_slice = coordinate_slice.cuda()

        time_vector = torch.zeros(batchsize, 1)
        if self.gpu:
            time_vector = time_vector.cuda()

        output = self.network(torch.cat((coordinate_slice, time_vector), 1))

        output = output / self.velocity_steps

        # Shift coordinates by 1/n * v
        coord_temp = torch.add(output, coordinate_slice)

        # Velocity field integration
        for t in range(self.velocity_steps - 1):
            time_vector = ((t + 1) / self.velocity_steps) * torch.ones(output_size, 1)
            if self.gpu:
                time_vector = time_vector.cuda()

            output = self.network(torch.cat((coord_temp, time_vector), 1))

            output = output / self.velocity_steps

            # Shift coordinates by 1/n * v
            coord_temp = torch.add(output, coord_temp)

        output = torch.subtract(coord_temp, coordinate_slice)

        jacobian_matrix = self.computeJacobianMatrix(coordinate_slice, output)
        determinant = torch.det(jacobian_matrix).detach().cpu().numpy()

        plt.figure(figsize=(8, 8))
        plt.hist(determinant)
        plt.show()
        return determinant
    
    def interpolate(self, input_array, coordinates):
        return faster_bilinear_interpolation(input_array, coordinates[:, 0], coordinates[:, 1])
    
    def setDefaultArguments(self):
        """Set default arguments."""

        # Inherit default arguments from standard learning model
        super().setDefaultArguments()

        # Define the value of arguments
        self.args['mask'] = None
        self.args['mask_2'] = None

        self.args['method'] = 1

        self.args['lr'] = 0.0001
        self.args['batch_size'] = 5000
        self.args['layers'] = [4, 64, 64, 64, 3] # [3, 64, 64, 64, 3]
        self.args['gif_name'] = 'ImplicitRegistrator.gif'
        self.args['velocity_steps'] = 1

        # Define argument defaults specific to this class
        self.args['output_regularization'] = False
        self.args['alpha_output'] = 0.2
        self.args['reg_norm_output'] = 1

        self.args['jacobian_regularization'] = False
        self.args['alpha_jacobian'] = 0.085
        self.args['reg_norm_jacobian'] = 1

        self.args['diffusion_regularization'] = False
        self.args['alpha_diffusion'] = 0.01

        self.args['elastic_regularization'] = False
        self.args['alpha_elastic'] = 0.01

        self.args['quadratic_regularization'] = False
        self.args['alpha_quadratic'] = 0.01

        self.args['ogden_regularization'] = False
        self.args['alpha_ogden'] = 0.01

        self.args['hyper_regularization'] = False
        self.args['alpha_hyper'] = 0.01

        self.args['time_regularization'] = False
        self.args['alpha_time'] = 0.01

        self.args['bending_regularization'] = False
        self.args['alpha_bending'] = 0.05

        self.args['image_shape'] = (200, 200)

        self.args['inbetween_images'] = None
        self.args['inbetween_alpha'] = 0.1
        
    def train(self, epochs=None, red_blue=True):
        """Train the network.

        This function inherits the function from the super-class, but with different default arguments.
        """

        super().train(epochs, red_blue)

    def trainingIteration(self, epoch):
        """Perform one iteration of training."""

        # Reset the gradient
        self.network.train()

        loss = 0
        
        indices = torch.randperm(self.possible_coordinate_tensor.shape[0], device='cuda')[:int(self.batch_size/25)]
        coordinate_tensor = self.possible_coordinate_tensor[indices, :]
        coordinate_tensor = torch.repeat_interleave(coordinate_tensor, torch.tensor(np.ones(int(self.batch_size/25))*25).int().cuda(), dim=0)
                
        coordinate_tensor = coordinate_tensor + self.coordinate_tensor_loc
        # print(coordinate_tensor)
        # print(coordinate_tensor_loc.shape)        
        
        # coordinate_tensor = coordinate_tensor[mask.flatten() > 0, :]        
        coordinate_tensor = coordinate_tensor.requires_grad_(True) 

#         if self.gpu:
#             coordinate_tensor = coordinate_tensor.cuda()

        output = self.network(coordinate_tensor) + self.offset
        coord_temp = torch.add(output, coordinate_tensor)
        output = coord_temp
        
        transformed_image = self.transformNoAdd(coord_temp)
        fixed_image = self.interpolate(self.fixed_image, coordinate_tensor)

        # Compute the loss
        loss += self.criterion(transformed_image, fixed_image)

        # Store the value of the data loss
        if self.verbose:
            self.data_loss_list[epoch] = loss.detach().cpu().numpy()

        # Relativation of output
        output_rel = torch.subtract(output, coordinate_tensor)

        # Regularization
        if self.output_regularization:
            loss += self.alpha_output * torch.linalg.norm(output_rel, self.reg_norm_output) / (2 * self.batch_size)
        if self.jacobian_regularization: #  and epoch > self.epochs//2:
            loss += self.alpha_jacobian * self.computeJacobianLoss(coordinate_tensor, output_rel)
        if self.diffusion_regularization:
            loss += self.alpha_diffusion * self.computeDiffusionLoss(coordinate_tensor, output_rel)
        if self.elastic_regularization:
            loss += self.alpha_elastic * self.computeElasticLoss(coordinate_tensor, output_rel)
        if self.quadratic_regularization:
            loss += self.alpha_quadratic * self.computeQuadraticElasticLoss(coordinate_tensor, output_rel)
        if self.ogden_regularization:
            loss += self.alpha_ogden * self.computeOgdenLoss(coordinate_tensor, output_rel)
        if self.hyper_regularization:
            loss += self.alpha_hyper * self.computeHyperElasticLoss(coordinate_tensor, output_rel)
        if self.bending_regularization:
            loss += self.alpha_bending * self.computeBendingEnergy(coordinate_tensor, output_rel)            

        # Perform the backpropagation and update the parameters accordingly
        # self.optimizer.zero_grad()
        for param in self.network.parameters():
            param.grad = None        
        loss.backward()
        self.optimizer.step()

        # Regularization scheduler
        #         if epoch == self.epochs // 3:
        #             self.alpha_hyper *= 0.1
        # self.scheduler.step()

        # Store the value of the total loss
        if self.verbose:
            self.loss_list[epoch] = loss.detach().cpu().numpy()
        
#         # Print Logs
#         if (epoch % self.log_interval == 0 or epoch == self.epochs - 1):
#             # self.saveNetwork('network_' + str(epoch) + '.pt')
#             if self.verbose:
#                 with torch.no_grad():
#                     output = self.network(self.coordinate_slice)
#                     transformed_image = self.transform(output, self.coordinate_slice)
#                     self.printLogs(epoch, loss, transformed_image, output, coordinate_tensor.detach().cpu().numpy())

    def transform(self, transformation, coordinate_tensor=None, moving_image=None, reshape=False):
        """Transform moving image given a transformation."""

        # If no specific coordinate tensor is given use the standard one of 28x28
        if coordinate_tensor is None:
            coordinate_tensor = self.coordinate_tensor

        # If no moving image is given use the standard one
        if moving_image is None:
            moving_image = self.moving_image

        # From relative to absolute
        transformation = torch.add(transformation, coordinate_tensor)
        return self.interpolate(moving_image, transformation)

    def transformNoAdd(self, transformation, moving_image=None, reshape=False):
        """Transform moving image given a transformation."""

        # If no moving image is given use the standard one
        if moving_image is None:
            moving_image = self.moving_image
        # print('GET MOVING')
        return self.interpolate(moving_image, transformation)
    
    def plotDeterminant(self, dims=(100, 100), clip=False, epsilon=0.00001, colormap='seismic', dimension=0):
        """Plots the determinant of the jacobian matrix of the transform."""

        jacobian = self.getJacobian(dims, dimension)

        # Reshape determinant
        jacobian = jacobian.detach().numpy().reshape(dims)

        print('Percentage negative {}'.format(np.sum(jacobian < 0)/np.prod(jacobian.shape)))
        # Make figure
        plt.figure(figsize=(8, 8))

        # Clip values
        if clip:
            jacobian[jacobian < -1 - epsilon] = - 1 - epsilon
            jacobian[jacobian > 1 + epsilon] = 1 + epsilon
            plt.imshow(jacobian, cmap=colormap, norm=plt.Normalize(vmin=-1 - epsilon, vmax=1 + epsilon))
        else:
            maxval = np.max(np.abs(jacobian))
            plt.imshow(jacobian, cmap=colormap, norm=plt.Normalize(vmin=-maxval, vmax=maxval))

        # Plotting
        plt.title("Jacobian", fontsize=20)
        plt.yticks([]), plt.xticks([])
        plt.colorbar()
        plt.show()
        
    def fillPixel(self, dims, dimension = 1):
        coordinate_slice = self.makeCoordinateSlice(dims, dimension=dimension).requires_grad_(True)
        output = self.network(coordinate_slice)
        output = torch.cat((coordinate_slice, output), 1)
        print(output.shape)
        return output
        
    def plotDifferenceImage(self, output_shape=(500, 500), dimension=0, slice_pos=0, mask_overlay=False):
        """Plot an image of the difference between fixed image and transformed moving image."""

        moving_image = self(output_shape=output_shape, dimension=dimension, slice_pos=slice_pos)
        fixed_image = self.interpolate(self.fixed_image, self.makeCoordinateSlice(output_shape, dimension, slice_pos))
        fixed_image = fixed_image.detach().cpu().view(self.output_shape)

        difference_image = np.abs(moving_image - fixed_image)

        if mask_overlay:
            difference_image_rgb = np.zeros(output_shape[0], output_shape[1], 3)
            difference_image_rgb[:, :, 0] = difference_image

            # Transformed version of mask         FIX THIS
            mask = np.random.rand(*output_shape)
            difference_image[:, :, 2] = mask

        plt.figure(figsize=(8, 8))

        # Imshow  rgb/grayscale   depending on   mask_overlay
        if mask_overlay:
            plt.imshow(difference_image, cmap='gray', norm=plt.Normalize(vmin=0, vmax=1))
        else:
            plt.imshow(difference_image, norm=plt.Normalize(vmin=0, vmax=1))

        plt.title('Difference image', fontsize=20)
        plt.xticks([]), plt.yticks([])
        plt.show()
    
    def printLogs(self, epoch, loss, transformed_image, output, locations):
        """Print the progress of the training."""

        # Make figure and axis
        fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 5))

        # Print Loss
        print("-" * 10 + "  epoch: " + str(epoch) + "  " + "-" * 10)
        loss = loss.detach().cpu().numpy()
        print("Loss: " + str(loss) + '   PSNR: ' + str(-10 * np.log10(loss)))

        # Reshape transformed image
        transformed_image_plot = torch.FloatTensor(transformed_image.cpu().detach().numpy().reshape(*self.image_shape))

        # Compute difference image
        difference_image = torch.abs(torch.subtract(transformed_image_plot, self.fixed_image_tensor_masked))

        # Plot interpolated, fixed and difference image
        transformed_result = transformed_image.detach().cpu().numpy().reshape(*self.image_shape)
        fixed_ref = self.fixed_image_tensor_masked.numpy()
        
        ax1.imshow(transformed_result, cmap='gray') # , norm=plt.Normalize(vmin=0, vmax=1))
        ax1.set_title("Deformed Image", fontsize=20)
        # ax1.set_yticks([]), ax1.set_xticks([])
        
        ax2.imshow(fixed_ref, cmap='gray') # , norm=plt.Normalize(vmin=0, vmax=1))
        ax2.set_title("Target", fontsize=20)
        # ax2.set_yticks([]), ax2.set_xticks([])

        ax3.imshow(difference_image.detach().cpu().numpy().reshape(*self.image_shape), cmap='seismic') # , norm=plt.Normalize(vmin=0, vmax=1))
        ax3.set_title("Difference", fontsize=20)
        # ax3.set_yticks([]), ax3.set_xticks([])

        plt.show()
        
        plt.figure()
        plt.scatter(locations[:, 0], locations[:, 1])
        
        output_normed = np.linalg.norm(output.detach().cpu().numpy(), axis=1)

        plt.figure()
        plt.imshow(output_normed.reshape(*self.image_shape))
        plt.show()    
        
        plt.figure(figsize=(20,20))       
        
        comp_registered = compare_images(fixed_ref, transformed_result, method='checkerboard', n_tiles=(4, 4))
        plt.imshow(comp_registered, cmap='gray')
        plt.show()
        
        plt.figure(figsize=(20,20))       
        
        comp_registered = compare_images(fixed_ref, transformed_result, method='diff')
        plt.imshow(comp_registered, cmap='gray')
        plt.show()
                
        plt.figure(figsize=(20,20))       
        
        alpha = 0.6
        target_image_rgb = np.transpose(np.stack((fixed_ref, fixed_ref, fixed_ref)) ,axes=[1,2,0]) * 255
        target_image_rgb = target_image_rgb.astype(np.uint8)
        aligned_images_rgb = target_image_rgb
        aligned_images_rgb[:,:,1] = alpha*transformed_result*255 + (1-alpha)*aligned_images_rgb[:,:,1]
        aligned_images_rgb = aligned_images_rgb.astype(np.uint8)
        # plt.imshow(aligned_images_rgb)        
        plt.imshow(target_image_rgb)
        
        # plt.colorbar()
        plt.axis('off')
        plt.tight_layout()
        plt.show()

Bilinear interpolation function

In [8]:
def faster_bilinear_interpolation(input_array, x_indices, y_indices):
    
    x_indices = (x_indices + 1) * (input_array.shape[0]-1) * 0.5
    y_indices = (y_indices + 1) * (input_array.shape[1]-1) * 0.5
    
    x0 = torch.floor(x_indices.detach()).to(torch.long)
    y0 = torch.floor(y_indices.detach()).to(torch.long)
    x1 = x0 + 1
    y1 = y0 + 1

    x0 = torch.clamp(x0, 0, input_array.shape[0] - 1)
    y0 = torch.clamp(y0, 0, input_array.shape[1] - 1)
    x1 = torch.clamp(x1, 0, input_array.shape[0] - 1)
    y1 = torch.clamp(y1, 0, input_array.shape[1] - 1)

    x = x_indices - x0
    y = y_indices - y0

    output = input_array[x0, y0] * (1 - x) * (1 - y) + input_array[x1, y0] * x * (1 - y) + input_array[x0, y1] * (1 - x) * y + input_array[x1, y1] * x * y
    return output

Load CSV, load images, perform registration, update and save CSV

In [None]:
import os
import pandas as pd
import shutil
import cv2 
import skimage.transform
import scipy.ndimage as scnd

import matplotlib.pyplot as plt

csvpath = r'D:\ACROBAT Data\acrobat_validation_points_public_1_of_1.csv'
datafolder = r'D:\ACROBAT Data\acrobat_validation_pyramid_1_of_1'
output_csvpath = r'D:\ACROBAT Data\output\csv_output.csv'
output_image_path = r'D:\ACROBAT Data\output\images'

shutil.copy(csvpath,output_csvpath)
csvfile = pd.read_csv(csvpath)

train_val_str = 'test' # train, val or test
pyr = 0
original_scale = 10 # the scale of the loaded images
anon_ids = csvfile.anon_id

for idq in np.unique(anon_ids):
    row_idxs = np.where(anon_ids == idq)
    target_stain = 'HE'
    source_stain = csvfile.ihc_antibody[row_idxs[0][0]]

    res_he_10X = csvfile.mpp_he_10X[row_idxs[0][0]]
    res_ihc_10X = csvfile.mpp_ihc_10X[row_idxs[0][0]]

    pp_scale = 1 # desired scale for preprocessing
    pp_res_he = res_he_10X*(10/pp_scale)
    pp_res_ihc = res_ihc_10X*(10/pp_scale)

    ihc_x_um = csvfile.ihc_x[row_idxs[0]]
    ihc_y_um = csvfile.ihc_y[row_idxs[0]]
    ihc_x = ihc_x_um/res_ihc_10X
    ihc_y = ihc_y_um/res_ihc_10X

    pp_ihc_x = ihc_x_um/pp_res_ihc
    pp_ihc_y = ihc_y_um/pp_res_ihc
    pp_ihc_points = np.column_stack((pp_ihc_x,pp_ihc_y))

    csvfile.he_x[row_idxs[0]] = ihc_x_um
    csvfile.he_y[row_idxs[0]] = ihc_y_um

    target_image = load_acrobat_image(datafolder, str(idq), target_stain, pyr, train_val_str)
    source_image = load_acrobat_image(datafolder, str(idq), source_stain, pyr, train_val_str)

    # Image rescaling
    h,w,s = target_image.shape
    h //= int(original_scale/pp_scale)
    w //= int(original_scale/pp_scale)
    pp_target_image = cv2.resize(target_image, dsize=(w, h), interpolation=cv2.INTER_LINEAR)

    h,w,s = source_image.shape
    h //= int(original_scale/pp_scale)
    w //= int(original_scale/pp_scale)
    pp_source_image = cv2.resize(source_image, dsize=(w, h), interpolation=cv2.INTER_LINEAR)

    # Colorspace conversion
    pp_target_lab = cv2.cvtColor(pp_target_image,cv2.COLOR_RGB2Lab)
    target_show_keypoints,target_kp,target_des,target_tissue_mask,processed_target_channel = sift_he(pp_target_lab, pp_res_he)
    target_tissue_mask = target_tissue_mask.astype(np.uint8)

    pp_source_hsv = cv2.cvtColor(pp_source_image,cv2.COLOR_RGB2HSV)
    source_show_keypoints,source_kp,source_des,source_tissue_mask,processed_source_channel = sift_ihc(pp_source_hsv, pp_res_ihc)

    # SIFT matching
    good_source_kp, good_target_kp,show_matching_keypoints = matching_keypoints(pp_source_image,pp_target_image,source_des,target_des,source_kp,target_kp,lowe_ratio=0.95)
    
    # RANSAC
    best_M,best_ninliers,best_dice,show_best_matching_keypoints = dransac(good_source_kp, good_target_kp, source_tissue_mask, target_tissue_mask, processed_source_channel, processed_target_channel, pp_source_image, pp_target_image,epochs = 100000)

    # Apply rigid transformation to the IHC image
    pp_cols = pp_target_image.shape[1]
    pp_rows = pp_target_image.shape[0]
    processed_source_channel_rigid = cv2.warpAffine(processed_source_channel, best_M, (pp_cols,pp_rows))

    # Apply rigid transformation to the IHC points and update the CSV file
    pp_ihc_points_rigid = cv2.transform(np.array([pp_ihc_points]),best_M)
    ihc_points_rigid_um = pp_ihc_points_rigid*pp_res_ihc
    ihc_x_rigid_um = ihc_points_rigid_um[0][:,0]
    ihc_y_rigid_um = ihc_points_rigid_um[0][:,1]
    csvfile.he_x[row_idxs[0]] = ihc_x_rigid_um
    csvfile.he_y[row_idxs[0]] = ihc_y_rigid_um
    csvfile.to_csv(output_csvpath, index=False)

    # Debug figures for rigid transformation
    aligned_images_rgb = show_aligned_images(processed_source_channel_rigid, processed_target_channel,alpha=0.6)

    fig = plt.figure(figsize = (20,20))
    plt.imshow(show_best_matching_keypoints)
    plt.axis('off')

    fig = plt.figure(figsize = (20,20))
    plt.imshow(aligned_images_rgb)
    plt.axis('off')
    
    # Image preprocessing before INR
    inr_target_image = processed_target_channel/255.0
    inr_target_image = torch.FloatTensor(inr_target_image)

    inr_source_image = processed_source_channel_rigid/255.0
    inr_source_image = torch.FloatTensor(inr_source_image)

    # Dilate tissue mask of HE (helps the INR)
    inr_target_mask = np.clip(target_tissue_mask, 0, 1)
    inr_target_mask = scnd.binary_dilation(inr_target_mask, np.ones((51, 51)))

    # If needed, save the input images before the INR
    # np.save(output_image_path + os.path.sep + str(idq) + '_HE_1x', processed_target_channel)
    # np.save(output_image_path + os.path.sep + str(idq) + '_' + source_stain + '_rigid_1x', processed_source_channel_rigid)
    # np.save(output_image_path + os.path.sep + str(idq) + '_HE_mask_1x', target_tissue_mask)


    # INR parameters
    kwargs = {}
    kwargs['layers'] = [2, 256, 256, 256, 2] 
    kwargs['verbose'] = True 
    kwargs['optimizer'] = 'adam' 
    kwargs['lr'] = 0.00001
    kwargs['batch_size'] = 250 * 25  
    kwargs['hyper_regularization'] = False 
    kwargs['time_regularization'] = False
    kwargs['elastic_regularization'] = False 
    kwargs['quadratic_regularization'] = False
    kwargs['jacobian_regularization'] = False  
    kwargs['bending_regularization'] = False
    kwargs['epochs'] = 25000
    kwargs['log_interval'] = kwargs['epochs']//4
    kwargs['omega'] = 16 
    kwargs['alpha_hyper'] = 0.25
    kwargs['alpha_quadratic'] = 0.1
    kwargs['alpha_jacobian'] = 0.1
    kwargs['alpha_elastic'] = 0.5
    kwargs['alpha_bending'] = 10
    kwargs['velocity_steps'] = 1
    kwargs['save_folder'] = r'D:\ACROBAT Data\csv_rigid_val\output_inr\{}'.format(str(idq))
    kwargs['mask'] = inr_target_mask
    kwargs['loss_function'] = 'ncc'
    kwargs['offset'] = [0, 0]
    kwargs['gabor_scale'] = 32

    ImpReg = ImplicitRegistrator(inr_source_image, inr_target_image, **kwargs)
    ImpReg.train()
    
    # Apply INR transformation to the IHC moving image (using batches of coordinates to avoid out of GPU memory)
    output_shape = (inr_target_image.shape[0],inr_target_image.shape[1])
    grid_tensor = ImpReg.makeCoordinateSlice(ImpReg.image_shape, 1, 0)
    inr_source_image_reg = np.zeros(grid_tensor.shape[0])
    
    forward_batch_size = 10000;
    index = 0
    for grid_batch in torch.split(grid_tensor, forward_batch_size):
        inr_source_image_reg[index:index + forward_batch_size] = ImpReg(grid_batch)
        index = index + forward_batch_size
    inr_source_image_reg = inr_source_image_reg.reshape(output_shape)

    # Apply INR transformation to the IHC points
    inr_res_ihc = pp_res_ihc
    inr_ihc_x = pp_ihc_points_rigid[0][:,0]
    inr_ihc_y = pp_ihc_points_rigid[0][:,1]
    inr_ihc_points = np.column_stack((inr_ihc_x,inr_ihc_y))

    # From absolute to relative here 
    inr_ihc_points_rel = inr_ihc_points.copy()
    inr_ihc_points_rel[:,1] = (2 * inr_ihc_points[:,0]) / (inr_target_image.shape[1] - 1) - 1
    inr_ihc_points_rel[:,0] = (2 * inr_ihc_points[:,1]) / (inr_target_image.shape[0] - 1) - 1

    # Move points
    inr_ihc_points_rel_tensor = torch.from_numpy(np.float32(inr_ihc_points_rel))
    inr_ihc_points_rel_tensor = inr_ihc_points_rel_tensor.cuda()

    output = ImpReg.network(inr_ihc_points_rel_tensor)

    inr_ihc_points_rel_tensor_reg = torch.subtract(output, inr_ihc_points_rel_tensor)
    inr_ihc_points_tensor_reg = inr_ihc_points_rel_tensor_reg.clone()

    # and then back to absolute
    inr_ihc_points_tensor_reg[:,0] = inr_target_image.shape[0]/2*(-inr_ihc_points_rel_tensor_reg[:,0] + 1)
    inr_ihc_points_tensor_reg[:,1] = inr_target_image.shape[1]/2*(-inr_ihc_points_rel_tensor_reg[:,1] + 1)

    inr_ihc_points_reg = inr_ihc_points_tensor_reg.cpu().detach().numpy()
    inr_ihc_points_reg_x = inr_ihc_points_reg[:,1]
    inr_ihc_points_reg_y = inr_ihc_points_reg[:,0]
    inr_ihc_points_reg_x_um = inr_ihc_points_reg_x*inr_res_ihc
    inr_ihc_points_reg_y_um = inr_ihc_points_reg_y*inr_res_ihc
    
    # INR transformation figures
    fig = plt.figure(figsize=(20,20))
    ax = fig.add_subplot(2, 1, 1)
    aligned_images_rgb = show_aligned_images(inr_source_image*255, inr_target_image*255,alpha=0.6)
    ax.imshow(aligned_images_rgb)
    ax.axis('off')
    ax = fig.add_subplot(2, 1, 2)
    aligned_images_rgb = show_aligned_images(inr_source_image_reg*255, inr_target_image*255,alpha=0.6)
    ax.imshow(aligned_images_rgb)
    ax.axis('off')
    
    # Update the CSV file
    csvfile.he_x[row_idxs[0]] = inr_ihc_points_reg_x_um
    csvfile.he_y[row_idxs[0]] = inr_ihc_points_reg_y_um
    csvfile.to_csv(output_csvpath, index=False)

A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  csvfile.he_x[row_idxs[0]] = ihc_x_um
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  csvfile.he_y[row_idxs[0]] = ihc_y_um
  5%|███▌                                                                       | 4789/100000 [00:11<04:06, 386.58it/s]