In [3]:
# Module Imports
import torch.fft as fft
import torch
from PIL import Image
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
import numpy as np
import monai
import torchvision.transforms as transforms
from pathlib import Path
import SimpleITK as sitk
import sys
sys.path.append('../')
from ipywidgets import interact, interactive, fixed, interact_manual, IntSlider
import math
import cv2

from modules.losses import mind_loss as mind

In [4]:
# Source: https://discuss.pytorch.org/t/add-label-captions-to-make-grid/42863
# Alias for the builtin: ``make_grid_with_labels`` has a keyword argument named
# ``range`` which hides the builtin inside the function body.
irange = range


def make_grid_with_labels(tensor, labels, nrow=8, limit=20, padding=2,
                          normalize=False, range=None, scale_each=False, pad_value=0):
    """Make a grid of images, optionally drawing a text caption onto every tile.

    Args:
        tensor (Tensor or list): 4D mini-batch Tensor of shape (B x C x H x W)
            or a list of images all of the same size.
        labels (list or None): one caption per image, converted with ``str``.
            ``None`` disables captioning. Images without a matching entry
            (list shorter than the batch) are left uncaptioned.
        nrow (int, optional): Number of images displayed in each row of the grid.
            Default: ``8``.
        limit (int, optional): Accepted for signature compatibility only; no
            truncation of images or labels is performed by this function.
        padding (int, optional): amount of padding. Default: ``2``.
        normalize (bool, optional): If True, shift the image to the range (0, 1),
            by the min and max values specified by :attr:`range`. Default: ``False``.
        range (tuple, optional): tuple (min, max) where min and max are numbers,
            then these numbers are used to normalize the image. By default, min and max
            are computed from the tensor.
        scale_each (bool, optional): If ``True``, scale each image in the batch of
            images separately rather than the (min, max) over all images. Default: ``False``.
        pad_value (float, optional): Value for the padded pixels. Default: ``0``.

    Returns:
        Tensor: the grid of shape (C x grid_height x grid_width). When the batch
        holds exactly one image, that image is returned as-is (C x H x W) and
        ``labels`` is not drawn.

    Raises:
        TypeError: if ``tensor`` is neither a tensor nor a list of tensors.

    Example:
        See this notebook `here <https://gist.github.com/anonymous/bf16430f7750c023141c562f3e9f2a91>`_

    """
    if not (torch.is_tensor(tensor) or
            (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
        raise TypeError('tensor or list of tensors expected, got {}'.format(type(tensor)))

    # if list of tensors, convert to a 4D mini-batch Tensor
    if isinstance(tensor, list):
        tensor = torch.stack(tensor, dim=0)

    if tensor.dim() == 2:  # single image H x W
        tensor = tensor.unsqueeze(0)
    if tensor.dim() == 3:  # single image
        if tensor.size(0) == 1:  # if single-channel, convert to 3-channel
            tensor = torch.cat((tensor, tensor, tensor), 0)
        tensor = tensor.unsqueeze(0)

    if tensor.dim() == 4 and tensor.size(1) == 1:  # single-channel images
        tensor = torch.cat((tensor, tensor, tensor), 1)

    if normalize is True:
        tensor = tensor.clone()  # avoid modifying tensor in-place
        if range is not None:
            assert isinstance(range, tuple), \
                "range has to be a tuple (min, max) if specified. min and max are numbers"

        def norm_ip(img, low, high):
            # In-place: clamp to [low, high], then rescale to roughly [0, 1].
            img.clamp_(min=low, max=high)
            img.add_(-low).div_(high - low + 1e-5)

        def norm_range(t, value_range):
            if value_range is not None:
                norm_ip(t, value_range[0], value_range[1])
            else:
                norm_ip(t, float(t.min()), float(t.max()))

        if scale_each is True:
            for t in tensor:  # loop over mini-batch dimension
                norm_range(t, range)
        else:
            norm_range(tensor, range)

    if tensor.size(0) == 1:
        return tensor.squeeze(0)

    draw_labels = labels is not None
    if draw_labels:
        # OpenCV is only required for captioning, so import it lazily.
        import cv2
        font = 1  # value 1 is cv2.FONT_HERSHEY_PLAIN in OpenCV's HersheyFonts enum
        fontScale = 2
        color = (255, 255, 255)
        thickness = 1

    # make the mini-batch of images into a grid
    nmaps = tensor.size(0)
    xmaps = min(nrow, nmaps)
    ymaps = int(math.ceil(float(nmaps) / xmaps))
    height, width = int(tensor.size(2) + padding), int(tensor.size(3) + padding)
    num_channels = tensor.size(1)
    grid = tensor.new_full((num_channels, height * ymaps + padding, width * xmaps + padding), pad_value)
    k = 0
    for y in irange(ymaps):
        for x in irange(xmaps):
            if k >= nmaps:
                break
            working_tensor = tensor[k]
            if draw_labels and k < len(labels):
                # Text origin: left edge, 30% of the tile height from the top.
                org = (0, int(working_tensor.shape[1] * 0.3))
                # detach()/cpu() so tiles that track gradients or live on another
                # device can still be converted to numpy.
                pixels = np.transpose(working_tensor.detach().cpu().numpy(), (1, 2, 0)) * 255
                # Clip before the uint8 cast: out-of-range values would otherwise wrap.
                pixels = np.ascontiguousarray(np.clip(pixels, 0, 255).astype('uint8'))
                working_image = cv2.UMat(pixels)
                image = cv2.putText(working_image, str(labels[k]), org, font,
                                    fontScale, color, thickness, cv2.LINE_AA)
                working_tensor = transforms.ToTensor()(image.get())
            grid.narrow(1, y * height + padding, height - padding) \
                .narrow(2, x * width + padding, width - padding) \
                .copy_(working_tensor)
            k = k + 1
    return grid

In [5]:
# Visualization utilities 
# Preprocessing pipeline: convert to a tensor, then normalize with mean 0.5 and std 0.5.
transform_list = [
    transforms.ToTensor(),
    transforms.Normalize((0.5, ), (0.5,)),
]
transform = transforms.Compose(transform_list)

def show(img, ax):
    """Display a 3-D image tensor with ``imshow``.

    The tensor's axes are permuted with ``(1, 2, 0)`` (first axis moved last)
    before being handed to ``imshow``.

    Args:
        img (Tensor): 3-D tensor to display.
        ax: object providing ``imshow`` (e.g. a matplotlib Axes), or ``None``
            to draw through ``plt.imshow``.
    """
    # detach()/cpu() so tensors that track gradients or live on another device
    # can be converted; for plain CPU tensors this is a no-op.
    npimg = img.detach().cpu().numpy()
    canvas = plt if ax is None else ax
    canvas.imshow(np.transpose(npimg, (1, 2, 0)))

In [6]:
# Loading images and normalization utilities
def load_images(fn1, fn2, offset=5):
    """Read one slice from each of two volumes and scale both to [-1, 1].

    Args:
        fn1 (Path): folder containing ``target.nrrd``.
        fn2 (Path): folder containing ``deformed.nrrd``.
        offset (int, optional): added to the middle index (first axis) when
            picking the slice of the deformed volume. Default: ``5``.

    Returns:
        tuple: ``(A, B)`` — the middle slice of the target volume and the
        offset slice of the deformed volume, as float tensors.
    """
    hu_low, hu_high = -1000, 2000  # lowest / highest Hounsfield unit kept

    target_volume = sitk.GetArrayFromImage(sitk.ReadImage(str(fn1 / 'target.nrrd')))
    deformed_volume = sitk.GetArrayFromImage(sitk.ReadImage(str(fn2 / 'deformed.nrrd')))

    target_slice = target_volume[target_volume.shape[0] // 2]
    deformed_slice = deformed_volume[deformed_volume.shape[0] // 2 + offset]

    prepared = []
    for slice_2d in (target_slice, deformed_slice):
        # Clip to the Hounsfield window, then map that window onto [-1, 1].
        clipped = torch.clamp(torch.Tensor(slice_2d), hu_low, hu_high)
        prepared.append(min_max_normalize(clipped, hu_low, hu_high))

    return prepared[0], prepared[1]

 
def min_max_normalize(image, min_value, max_value):
    """Linearly map ``[min_value, max_value]`` onto ``[-1, 1]``.

    Args:
        image (Tensor): values to rescale; converted to float first.
        min_value (number): input value that maps to ``-1``.
        max_value (number): input value that maps to ``1``.

    Returns:
        Tensor: float tensor with the rescaled values.
    """
    span = max_value - min_value
    unit_scaled = (image.float() - min_value) / span  # [min, max] -> [0, 1]
    return 2 * unit_scaled - 1

In [7]:
# Generalized Frequency Loss implementation
def generalized_freq_loss(fn1, fn2, offset=5):
    """Plot two image slices, their frequency magnitudes and the spectral differences.

    Shows six captioned tiles in one grid: ``A``, ``f_A``, ``B``, ``f_B``,
    the absolute difference of the spectra and their squared difference.

    Args:
        fn1 (Path): folder passed to ``load_images`` for the target image.
        fn2 (Path): folder passed to ``load_images`` for the deformed image.
        offset (int, optional): slice offset forwarded to ``load_images``.
            Default: ``5``.
    """
    A, B = load_images(fn1, fn2, offset)
    # load_images yields values in [-1, 1]; shift them to [0, 1] for display.
    A = (A + 1) / 2
    B = (B + 1) / 2

    def _squashed_spectrum(image):
        # Orthonormal FFT, shifted, reduced to its magnitude and squashed with tanh.
        spectrum = fft.fftshift(fft.fftn(image, norm="ortho"))
        return torch.tanh(torch.abs(spectrum))

    f_A = _squashed_spectrum(A)
    f_B = _squashed_spectrum(B)

    abs_diff = torch.tanh(torch.abs(f_A - f_B))
    square = torch.tanh(torch.square(f_A - f_B))

    # Keep each caption next to its image so the two can never drift out of
    # order (the captions for f_A and B used to be swapped).
    panels = [
        ('A', A),
        ('f_A', f_A),
        ('B', B),
        ('f_B', f_B),
        ('Abs difference', abs_diff),
        ('Squared difference', square),
    ]
    labels = [label for label, _ in panels]
    images = [image.unsqueeze(dim=0) for _, image in panels]  # add a channel axis

    plt.figure(figsize=(15, 15))
    show(make_grid_with_labels(images, labels), None)
    plt.axis('off')
    plt.show()

In [8]:
# Registration Loss implementation
def registration_loss(fn1, fn2, offset=0):
    """Plot two image slices next to their local normalized cross-correlation map.

    The map comes from MONAI's ``LocalNormalizedCrossCorrelationLoss`` with
    ``reduction='none'``; it is squashed with ``tanh`` and shifted to [0, 1]
    before display.

    Args:
        fn1 (Path): folder passed to ``load_images`` for the target image.
        fn2 (Path): folder passed to ``load_images`` for the deformed image.
        offset (int, optional): slice offset forwarded to ``load_images``.
            Default: ``0``.
    """
    A, B = load_images(fn1, fn2, offset)

    # Three leading singleton axes are added in front of each slice.
    # NOTE(review): presumably batch, channel and depth for the loss's 3-D default — confirm.
    ip = A.unsqueeze(dim=0).unsqueeze(dim=0).unsqueeze(dim=0)
    op = B.unsqueeze(dim=0).unsqueeze(dim=0).unsqueeze(dim=0)

    lncc = monai.losses.LocalNormalizedCrossCorrelationLoss(reduction='none')
    # Call the module itself rather than ``.forward`` so registered hooks run.
    loss_map = lncc(ip, op).squeeze()
    loss_map = (torch.tanh(loss_map) + 1) / 2  # tanh output (-1, 1) -> (0, 1)

    # load_images yields values in [-1, 1]; shift them to [0, 1] for display.
    A = (A + 1) / 2
    B = (B + 1) / 2

    # Open the figure only once everything is computed, so a failed load or
    # loss evaluation does not leave an empty figure behind.
    plt.figure(figsize=(15, 15))
    tiles = [A.unsqueeze(dim=0), B.unsqueeze(dim=0), loss_map.unsqueeze(dim=0)]
    show(make_grid_with_labels(tiles, ['A', 'B', 'Registration Loss']), None)
    plt.axis('off')

    plt.show()

In [12]:
# MIND Loss implementation
# Shared descriptor instance, created once when the cell runs.
mind_descriptor = mind.MINDDescriptor()


def mind_loss(fn1, fn2, offset=0):
    """Plot two image slices next to the norm of their MIND descriptor difference.

    Args:
        fn1 (Path): folder passed to ``load_images`` for the target image.
        fn2 (Path): folder passed to ``load_images`` for the deformed image.
        offset (int, optional): slice offset forwarded to ``load_images``.
            Default: ``0``.
    """
    plt.figure(figsize=(15, 15))
    A, B = load_images(fn1, fn2, offset)

    def _as_descriptor_batch(image):
        # Prepend three singleton axes, then fold everything in front of the
        # trailing image axes into (-1, 1, ...).
        expanded = image[None, None, None]
        return expanded.view(-1, 1, *expanded.shape[3:])

    source_batch = _as_descriptor_batch(A)
    target_batch = _as_descriptor_batch(B)

    source_feature = mind_descriptor(source_batch)
    target_feature = mind_descriptor(target_batch)

    # NOTE(review): torch.norm is called without ``p``; per the PyTorch docs that
    # is the 2-norm, although this value was previously named ``l1`` — confirm
    # which norm is intended.
    difference_norm = torch.norm(
        source_feature - target_feature, dim=1, keepdim=True
    ).detach()

    # load_images yields values in [-1, 1]; shift them to [0, 1] for display.
    tiles = [
        ((A + 1) / 2).unsqueeze(dim=0),
        ((B + 1) / 2).unsqueeze(dim=0),
        difference_norm.squeeze(dim=0),
    ]
    show(make_grid_with_labels(tiles, ['A', 'B', 'MIND Loss']), None)
    plt.axis('off')
    plt.show()
    

In [13]:
# Wrapper function to access different losses interactively
def get_representations(fn1, fn2, offset, method):
    """Run the visualization function whose name is given by ``method``.

    Args:
        fn1: first path argument, forwarded unchanged.
        fn2: second path argument, forwarded unchanged.
        offset: slice offset, forwarded unchanged.
        method (str): name of a callable defined at module level that accepts
            ``(fn1, fn2, offset)``.

    Raises:
        ValueError: if ``method`` does not name a module-level callable.
    """
    # Resolve the name in the module namespace instead of eval(): eval would
    # execute any expression placed in ``method``.
    function_call = globals().get(method)
    if not callable(function_call):
        raise ValueError('unknown loss representation: {!r}'.format(method))
    function_call(fn1, fn2, offset)

In [14]:
# Interactive access to loss representations

# Point both paths at a folder that contains the CBCT image (target.nrrd) and
# the CT image obtained after registration (deformed.nrrd).
fn1 = Path('/work/vq218944/sample_CT_data/0')
fn2 = Path('/work/vq218944/sample_CT_data/0')
offset = 0

# The slider picks the slice offset and the dropdown picks which loss to visualize.
interact(
    get_representations,
    fn1=fixed(fn1),
    fn2=fixed(fn2),
    offset=IntSlider(min=0, max=10),
    method=['generalized_freq_loss', 'registration_loss', 'mind_loss'],
)

interactive(children=(IntSlider(value=0, description='offset', max=10), Dropdown(description='method', options…

<function __main__.get_representations(fn1, fn2, offset, method)>