In [18]:
import ipywidgets as widgets
from ipywidgets import interact, interactive
from IPython.display import display
import SimpleITK as sitk
import numpy as np
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torch
import nibabel as nib
import os
%matplotlib inline

# im_index = 363
# pred_index = 0

# im_real = nib.load(f'imagesTr/hepaticvessel_{im_index}.nii.gz').get_fdata()
# label_real = nib.load(f'labelsTr/hepaticvessel_{im_index}.nii.gz').get_fdata()
# label_pred = np.load(f'predictions/abhijit/prediction_{im_index}.npy').squeeze(0)

index_vis = 0
im_real = np.load(os.path.join('.', 'images_true', f'{index_vis}.npy'))
label_real = np.load(os.path.join('.', 'labels_true', f'{index_vis}.npy'))
label_pred = np.load(os.path.join('.', 'Ensamble', 'Abhijit', f'{index_vis}.npy')).squeeze(0)

# convert the one-hot predictions to labels (0, 1, 2)
# label_pred = np.argmax(label_pred.squeeze(0), axis=0)
# label_pred = np.argmax(label_pred.squeeze(0), axis=0)

print(im_real.shape)
print(label_real.shape)
print(label_pred.shape)
                     
                     
label_pred = torch.argmax(torch.softmax(torch.tensor(label_pred), axis=0), axis=0)
print(label_pred.shape)

height, width, depth = label_pred.shape

def visualize_depth(index=0, label=-1):
    # cmap = 'gray'
    cmap = 'viridis'
    # cmap = 'inferno'
    # cmap = 'plasma'
    
    index_depth = index;
    plt.figure(figsize=(12, 6))
    plt.title('height x width')

    plt.subplot(1, 3, 1)
    plt.imshow(im_real[:, :, index], cmap=cmap)
    plt.title('Real Image')
    plt.xticks([])
    plt.yticks([])

    plt.subplot(1, 3, 2)
    if label == 0:
        current_label = ' (Background)'
    elif label == 1:
        current_label = ' (Vessel)'
    elif label == 2:
        current_label = ' (Tumour)'
    else:
        current_label = ''
        
    plt.title(f'Real Label{current_label}')
    if label < 0:
        plt.imshow(label_real[:, :, index], cmap=cmap)
    else:
        plt.imshow(np.where(label_real[:, :, index] == label, 1, 0), cmap=cmap)
        
    plt.xticks([])
    plt.yticks([])

    plt.subplot(1, 3, 3)
    plt.title(f'Predicted Label{current_label}')
    if label < 0:
        plt.imshow(label_pred[:, :, index], cmap=cmap)
    else:
        plt.imshow(np.where(label_pred[:, :, index] == label, 1, 0), cmap=cmap)
    plt.xticks([])
    plt.yticks([])
    plt.show()

interact(visualize_depth, index=widgets.IntSlider(min=0, max=depth-1, step=1, value=0), label=widgets.IntSlider(min=-1, max=2, step=1, value=-1));

(512, 512, 134)
(512, 512, 134)
(3, 512, 512, 134)
torch.Size([512, 512, 134])


interactive(children=(IntSlider(value=0, description='index', max=133), IntSlider(value=-1, description='label…

In [14]:
def visualize_height(index=0):
    # cmap = 'gray'
    cmap = 'viridis'
    index_depth = index;
    plt.figure(figsize=(12, 6))
    plt.title('width x depth')

    plt.subplot(1, 3, 1)
    plt.imshow(im_real[index, :, :], cmap=cmap)
    plt.title('Real Image')
    plt.xticks([])
    plt.yticks([])

    plt.subplot(1, 3, 2)
    plt.title('Real Label')
    plt.imshow(label_real[index, :, :], cmap=cmap)
    plt.xticks([])
    plt.yticks([])

    plt.subplot(1, 3, 3)
    plt.title('Predicted Label')
    plt.imshow(label_pred[index, :, :], cmap=cmap)
    plt.xticks([])
    plt.yticks([])
    plt.show()

interact(visualize_height, index=widgets.IntSlider(min=0, max=depth-1, step=1, value=0));

interactive(children=(IntSlider(value=0, description='index', max=133), Output()), _dom_classes=('widget-inter…

In [15]:
def visualize_width(index=0):
    # cmap = 'gray'
    cmap = 'viridis'
    index_depth = index;
    plt.figure(figsize=(12, 6))
    plt.title('height x depth')

    plt.subplot(1, 3, 1)
    plt.imshow(im_real[:, index, :], cmap=cmap)
    plt.title('Real Image')
    plt.xticks([])
    plt.yticks([])

    plt.subplot(1, 3, 2)
    plt.title('Real Label')
    plt.imshow(label_real[:, index, :], cmap=cmap)
    plt.xticks([])
    plt.yticks([])

    plt.subplot(1, 3, 3)
    plt.title('Predicted Label')
    plt.imshow(label_pred[:, index, :], cmap=cmap)
    plt.xticks([])
    plt.yticks([])
    plt.show()

interact(visualize_width, index=widgets.IntSlider(min=0, max=depth-1, step=1, value=0));

interactive(children=(IntSlider(value=0, description='index', max=133), Output()), _dom_classes=('widget-inter…

In [164]:
class ElasticDeformation():
    """
        Classs containing Elastic Deformation Transformation
        Adapted from week 3 AML tutorial materials
    """

    def __init__(self, num_controlpoints=None, sigma=None):

        self.num_controlpoints = num_controlpoints
        self.sigma = sigma

        # Random parameters if not defined
        if self.sigma == None:
            self.sigma = np.random.uniform(low=1, high=5)

        if self.num_controlpoints == None:
            self.num_controlpoints = int(np.random.uniform(low=1, high=6))

    def apply_elastic_deformation(self, image, label):

        # We need to choose an interpolation method for our transformed image, let's just go with b-spline
        resampler = sitk.ResampleImageFilter()
        resampler.SetInterpolator(sitk.sitkBSpline)

        # Let's convert our image to an sitk image
        sitk_image = sitk.GetImageFromArray(image)

        # Specify the image to be transformed: This is the reference image
        resampler.SetReferenceImage(sitk_image)
        resampler.SetDefaultPixelValue(0)

        # Initialise the transform
        bspline_transform = self.create_elastic_deformation(image)

        # Set the transform in the initialiser
        resampler.SetTransform(bspline_transform)

        # Carry out the resampling according to the transform and the resampling method
        out_img_sitk = resampler.Execute(sitk_image)

        # Convert the image back into a python array
        out_img = sitk.GetArrayFromImage(out_img_sitk)

        # We need to choose an interpolation method for our transformed image, let's just go with b-spline
        resampler_label = sitk.ResampleImageFilter()
        resampler_label.SetInterpolator(sitk.sitkNearestNeighbor)

        # Let's convert our image to an sitk image
        sitk_label = sitk.GetImageFromArray(label)

        # Specify the image to be transformed: This is the reference image
        resampler_label.SetReferenceImage(sitk_label)
        resampler_label.SetDefaultPixelValue(0)

        # Initialise the transform
        bspline_transform = self.create_elastic_deformation(label)

        # Set the transform in the initialiser
        resampler_label.SetTransform(bspline_transform)

        # Carry out the resampling according to the transform and the resampling method
        out_label_sitk = resampler_label.Execute(sitk_label)

        # Convert the image back into a python array
        out_label = sitk.GetArrayFromImage(out_label_sitk)

        return out_img.reshape(image.shape), out_label.reshape(image.shape)

    def create_elastic_deformation(self, image):
        """
            We need to parameterise our b-spline transform
            The transform will depend on such variables as image size and sigma
            Sigma modulates the strength of the transformation
            The number of control points controls the granularity of our transform
        """
        # Create an instance of a SimpleITK image of the same size as our image
        itkimg = sitk.GetImageFromArray(np.zeros(image.shape))

        # This parameter is just a list with the number of control points per image dimensions
        trans_from_domain_mesh_size = [self.num_controlpoints] * itkimg.GetDimension()

        # We initialise the transform here: Passing the image size and the control point specifications
        bspline_transformation = sitk.BSplineTransformInitializer(itkimg, trans_from_domain_mesh_size)

        # Isolate the transform parameters: They will be all zero at this stage
        params = np.asarray(bspline_transformation.GetParameters(), dtype=float)

        # Let's initialise the transform by randomly initialising each parameter according to sigma
        params = params + np.random.randn(params.shape[0]) * self.sigma

        # Let's initialise the transform by randomly displacing each control point by a random distance (magnitude sigma)
        bspline_transformation.SetParameters(tuple(params))

        return bspline_transformation

    def __call__(self, image, label):
        if len(image.shape) == 5:
            image = image.squeeze(0).squeeze(0)
        if len(label.shape) == 5:
            label = label.squeeze(0).squeeze(0)
        image_ed, label_ed = self.apply_elastic_deformation(image, label)
        return image_ed, label_ed

In [170]:
elastic_deformation = ElasticDeformation(num_controlpoints=20, sigma=5)
im_real_ed, label_real_ed = elastic_deformation(torch.tensor(im_real).unsqueeze(0).unsqueeze(0), torch.tensor(label_real).unsqueeze(0).unsqueeze(0))

def visualize_depth(index=0, label=-1):
    # cmap = 'gray'
    cmap = 'viridis'
    index_depth = index;
    plt.figure(figsize=(12, 6))
    plt.title('height x width')

    plt.subplot(1, 3, 1)
    plt.imshow(im_real[:, :, index], cmap=cmap)
    plt.title('Real Image')
    plt.xticks([])
    plt.yticks([])

    plt.subplot(1, 3, 2)
    if label == 0:
        current_label = ' (Background)'
    elif label == 1:
        current_label = ' (Vessel)'
    elif label == 2:
        current_label = ' (Tumour)'
    else:
        current_label = ''
        
    plt.title(f'Real Label{current_label}')
    if label < 0:
        plt.imshow(label_real[:, :, index], cmap=cmap)
    else:
        plt.imshow(np.where(label_real[:, :, index] == label, 1, 0), cmap=cmap)
        
    plt.xticks([])
    plt.yticks([])

    plt.subplot(1, 3, 3)
    plt.title(f'Predicted Label{current_label}')
    if label < 0:
        plt.imshow(label_pred[:, :, index], cmap=cmap)
    else:
        plt.imshow(np.where(label_pred[:, :, index] == label, 1, 0), cmap=cmap)
    plt.xticks([])
    plt.yticks([])
    plt.show()
    
    ###########################################################################
    
    plt.figure(figsize=(12, 6))
    plt.title('height x width')

    plt.subplot(1, 3, 1)
    plt.imshow(im_real_ed[:, :, index], cmap=cmap)
    plt.title('Real Image (ED)')
    plt.xticks([])
    plt.yticks([])
    
    
    plt.subplot(1, 3, 2)
    if label == 0:
        current_label = ' (Background)'
    elif label == 1:
        current_label = ' (Vessel)'
    elif label == 2:
        current_label = ' (Tumour)'
    else:
        current_label = ''
        
    plt.title(f'Real Label{current_label} (ED)')
    if label < 0:
        plt.imshow(label_real_ed[:, :, index], cmap=cmap)
    else:
        plt.imshow(np.where(label_real_ed[:, :, index] == label, 1, 0), cmap=cmap)
        
    plt.xticks([])
    plt.yticks([])
    
    
    
interact(visualize_depth, index=widgets.IntSlider(min=0, max=depth-1, step=1, value=20), label=widgets.IntSlider(min=-1, max=2, step=1, value=-1));

interactive(children=(IntSlider(value=20, description='index', max=39), IntSlider(value=-1, description='label…

In [181]:
class AffineTransformation():
    """
        Classs containing Elastic Deformation Transformation
        Adapted from week 3 AML tutorial materials
    """

    def __init__(self, rotation=5, scale=(0.95, 1.05), shear=(0.01, 0.02), return_inverse=False, inverse_matrix=None):
        self.rotation = rotation
        self.scale = scale
        self.shear = shear
        self.return_inverse = return_inverse
        self.inverse_matrix = inverse_matrix

    def get_transformation_matrix(self):
        # apply rotation on the z-axis
        degree_rotation = torch.tensor(1, dtype=torch.float32).uniform_(-self.rotation, self.rotation)
        degree_rotation = (degree_rotation * torch.pi) / 180

        matrix_rotation = torch.zeros((1, 4, 4), dtype=torch.float32)
        matrix_rotation[0, 0, 0] = torch.cos(degree_rotation)
        matrix_rotation[0, 0, 1] = torch.sin(degree_rotation)
        matrix_rotation[0, 1, 0] = -torch.sin(degree_rotation)
        matrix_rotation[0, 1, 1] = torch.cos(degree_rotation)
        matrix_rotation[0, 2, 2] = 1
        matrix_rotation[0, 3, 3] = 1

        # apply scaling on each dimension
        matrix_scale = torch.zeros((1, 4, 4), dtype=torch.float32)
        matrix_scale[0, 0, 0] = torch.tensor(1, dtype=torch.float32).uniform_(self.scale[0], self.scale[1])
        matrix_scale[0, 1, 1] = torch.tensor(1, dtype=torch.float32).uniform_(self.scale[0], self.scale[1])
        matrix_scale[0, 2, 2] = torch.tensor(1, dtype=torch.float32).uniform_(self.scale[0], self.scale[1])
        matrix_scale[0, 3, 3] = 1
        # print(matrix_scale.shape)

        # shear
        degree_shear = torch.tensor((
            torch.tensor(1, dtype=torch.float32).uniform_(self.shear[0], self.shear[1]),
            torch.tensor(1, dtype=torch.float32).uniform_(self.shear[0], self.shear[1])
        ))

        matrix_shear = torch.zeros((1, 4, 4), dtype=torch.float32)
        matrix_shear[0, 0, 0] = 1
        matrix_shear[0, 0, 1] = degree_shear[0]
        matrix_shear[0, 1, 0] = degree_shear[1]
        matrix_shear[0, 1, 1] = 1
        matrix_shear[0, 2, 2] = 1
        matrix_shear[0, 3, 3] = 1

        # generate the combined affine transformation matrix
        self.matrix_affine = torch.matmul(matrix_shear, torch.matmul(matrix_rotation, matrix_scale))

        # generate the inverse transformation matrix
        self.matrix_affine_inv = torch.inverse(self.matrix_affine)

        # return to original coordinates
        self.matrix_affine = self.matrix_affine[:, 0:3, :]
        self.matrix_affine_inv = self.matrix_affine_inv[:, 0:3, :]

    def __call__(self, image, label):
        
        # obtain transformation matrix
        self.get_transformation_matrix()
        
        # define the affine grid and apply transformation on images and labels
        if self.return_inverse:
            grid_affine = F.affine_grid(self.matrix_affine_inv, image.shape, align_corners=False)
            trans_img = F.grid_sample(image.float(), grid_affine, padding_mode="border", align_corners=False)
            trans_label = F.grid_sample(label.float(), grid_affine, mode='nearest', padding_mode="zeros", align_corners=False)
        else:
            grid_affine = F.affine_grid(self.matrix_affine, image.shape, align_corners=False)
            trans_img = F.grid_sample(image.float(), grid_affine, padding_mode="border", align_corners=False)
            trans_label = F.grid_sample(label.float(), grid_affine, mode='nearest', padding_mode="zeros", align_corners=False)

        return trans_img, trans_label, self.matrix_affine, self.matrix_affine_inv

In [182]:
affine_transform = AffineTransformation(rotation=10, 
                                        scale=(0.90, 1.10), 
                                        shear=(0.01, 0.02)
                                       )
im_real_at, label_real_at, _, _ = affine_transform(torch.tensor(im_real).unsqueeze(0).unsqueeze(0), torch.tensor(label_real).unsqueeze(0).unsqueeze(0))
im_real_at = im_real_at.squeeze(0).squeeze(0)
label_real_at = label_real_at.squeeze(0).squeeze(0)

def visualize_depth(index=0, label=-1):
    # cmap = 'gray'
    cmap = 'viridis'
    index_depth = index;
    plt.figure(figsize=(12, 6))
    plt.title('height x width')

    plt.subplot(1, 3, 1)
    plt.imshow(im_real[:, :, index], cmap=cmap)
    plt.title('Real Image')
    plt.xticks([])
    plt.yticks([])

    plt.subplot(1, 3, 2)
    if label == 0:
        current_label = ' (Background)'
    elif label == 1:
        current_label = ' (Vessel)'
    elif label == 2:
        current_label = ' (Tumour)'
    else:
        current_label = ''
        
    plt.title(f'Real Label{current_label}')
    if label < 0:
        plt.imshow(label_real[:, :, index], cmap=cmap)
    else:
        plt.imshow(np.where(label_real[:, :, index] == label, 1, 0), cmap=cmap)
        
    plt.xticks([])
    plt.yticks([])

    plt.subplot(1, 3, 3)
    plt.title(f'Predicted Label{current_label}')
    if label < 0:
        plt.imshow(label_pred[:, :, index], cmap=cmap)
    else:
        plt.imshow(np.where(label_pred[:, :, index] == label, 1, 0), cmap=cmap)
    plt.xticks([])
    plt.yticks([])
    plt.show()
    
    ###########################################################################
    
    plt.figure(figsize=(12, 6))
    plt.title('height x width')

    plt.subplot(1, 3, 1)
    plt.imshow(im_real_at[:, :, index], cmap=cmap)
    plt.title('Real Image (ED)')
    plt.xticks([])
    plt.yticks([])
    
    
    plt.subplot(1, 3, 2)
    if label == 0:
        current_label = ' (Background)'
    elif label == 1:
        current_label = ' (Vessel)'
    elif label == 2:
        current_label = ' (Tumour)'
    else:
        current_label = ''
        
    plt.title(f'Real Label{current_label} (ED)')
    if label < 0:
        plt.imshow(label_real_at[:, :, index], cmap=cmap)
    else:
        plt.imshow(np.where(label_real_at[:, :, index] == label, 1, 0), cmap=cmap)
        
    plt.xticks([])
    plt.yticks([])
    
    
    
interact(visualize_depth, index=widgets.IntSlider(min=0, max=depth-1, step=1, value=20), label=widgets.IntSlider(min=-1, max=2, step=1, value=-1));

interactive(children=(IntSlider(value=20, description='index', max=39), IntSlider(value=-1, description='label…