In [19]:
import numpy as np
import nibabel as nib

import matplotlib.pyplot as plt
from matplotlib import cm
import matplotlib.animation as anim
import matplotlib.patches as mpatches
import matplotlib.gridspec as gridspec
import torch.nn.functional as F
import torch

from model.gradcam import GradCAMpp, load_vgg16, init_gcpp
from model.preprocessing import transform_datapoint
import visualization.contours as contours

In [2]:
def contour_mask(img):
    """Return the contour mask computed from the first item of a batched image.

    The last entry along the first axis of ``img[0]`` (a 2-D plane) is passed
    through ``contours.rgb_to_grayscale`` and then ``contours.contour_matrix_mask``;
    the latter's result is returned unchanged.

    Args:
        img: indexable batch whose first element converts (via ``np.array``) to
            an array with at least three dimensions. Presumably the (1, C, H, W)
            output of ``transform_datapoint`` — TODO confirm.
    """
    first_item = np.array(img[0])
    last_plane = first_item[-1, :, :]
    return contours.contour_matrix_mask(contours.rgb_to_grayscale(last_plane))

In [3]:
# Trained classifier weights.
model_path = '../trained_model.pt'

# One BraTS 2020 training case (patient 350): four MRI modalities plus the
# manual segmentation, all sharing the same filename prefix.
_case_prefix = '../brain350/BraTS20_Training_350_'
sample_filename1 = f'{_case_prefix}flair.nii'
sample_filename_mask = f'{_case_prefix}seg.nii'
sample_filename2 = f'{_case_prefix}t1.nii'
sample_filename3 = f'{_case_prefix}t2.nii'
sample_filename4 = f'{_case_prefix}t1ce.nii'

# Normalisation statistics handed to transform_datapoint
# (presumably per-channel — TODO confirm against the training pipeline).
mean_path = 'data/mean.npy'
std_path = 'data/std.npy'

# Index of the slice to visualise along the third axis of each volume.
# NOTE(review): this name shadows the built-in `slice` for the rest of the
# notebook; consider renaming it (e.g. SLICE_IDX) everywhere it is used.
slice = 80

In [4]:
# Load images: read each NIfTI volume fully into memory as an array.
def _load_volume(path):
    """Load the NIfTI file at `path` and materialise its voxel data as an array."""
    return np.asanyarray(nib.load(path).dataobj)

sample_img1 = _load_volume(sample_filename1)      # FLAIR
sample_img2 = _load_volume(sample_filename2)      # T1
sample_img3 = _load_volume(sample_filename3)      # T2
sample_img4 = _load_volume(sample_filename4)      # T1 contrast-enhanced
sample_mask = _load_volume(sample_filename_mask)  # manual segmentation labels

In [5]:
# Load trained model and GradCam++
# `load_vgg16` (project helper) presumably restores the trained VGG16 network
# from `model_path`; `init_gcpp` builds the Grad-CAM++ wrapper around it, which
# is later called as `gcpp(image)` to produce the saliency map.
vgg = load_vgg16(model_path)
gcpp = init_gcpp(vgg)

In [None]:
# Find GradCam++ mask: build the network input for the selected slice.
# Take the chosen slice from each modality and stack them channel-last into an
# (H, W, 4) array, channel order FLAIR, T1, T2, T1ce. Slicing before stacking
# avoids materialising a copy of all four full volumes just to keep one slice.
# NOTE(review): this expects the un-rotated volumes. If the plotting cell that
# rebinds sample_img* to np.rot90(...) has already run, re-running this cell
# feeds rotated slices to the model.
modality_volumes = (sample_img1, sample_img2, sample_img3, sample_img4)
image_slice = np.stack([vol[:, :, slice] for vol in modality_volumes], axis=-1)

mean = np.load(mean_path)
std = np.load(std_path)

# Project helper: normalises with mean/std and converts to the model's input format.
image = transform_datapoint(image_slice, mean, std)

In [None]:
# Remove the parts outside the brain: compute the contour mask from the same
# network input and multiply it element-wise into the Grad-CAM++ map
# (presumably 0 outside / 1 inside the brain outline — TODO confirm).
c_mask = contour_mask(image)
saliency, _ = gcpp(image)  # second return value is not used here
gcpp_mask = saliency.numpy() * c_mask

In [22]:
# Resize the low-resolution Grad-CAM++ map to the in-plane size of the MRI
# volumes so it overlays the slice pixel-for-pixel.
# F.interpolate is the non-deprecated replacement for F.upsample (same arguments).
if np.ndim(gcpp_mask) != 4:
    # Bilinear interpolation needs an (N, C, H, W) input. A 2-D array here means
    # this cell has already run once; re-run the Grad-CAM++ cell above first.
    raise ValueError(
        f"expected a 4-D (N, C, H, W) Grad-CAM++ map, got shape {tuple(np.shape(gcpp_mask))}"
    )
target_size = tuple(sample_img1.shape[:2])  # (240, 240) for this BraTS case
gcpp_mask = F.interpolate(
    torch.as_tensor(gcpp_mask, dtype=torch.float32),
    size=target_size,
    mode='bilinear',
    align_corners=False,
)
# Drop the batch and channel axes, leaving an (H, W) array for plotting.
gcpp_mask = gcpp_mask[0, 0, :].numpy()

In [24]:
# Rotate for plotting: apply np.rot90 (default axes=(0, 1), i.e. the in-plane
# axes) to every array that is drawn later.
# NOTE(review): the names are rebound to the rotated arrays, so running this
# cell twice rotates twice; restart from the loading cell when re-running.
(sample_img1, sample_img2, sample_img3, sample_img4, sample_mask, gcpp_mask) = [
    np.rot90(arr)
    for arr in (sample_img1, sample_img2, sample_img3, sample_img4, sample_mask, gcpp_mask)
]

# Split the segmentation into one binary mask per tumour sub-region.
# Label values used here: 1 = necrotic / non-enhancing tumour core (NCR/NET),
# 2 = peritumoral edema (ED), 4 = GD-enhancing tumour (ET).
# Each mask is 1 where the label matches and 0 elsewhere, in the segmentation's
# own dtype. Testing for equality is used instead of copying the volume and
# zeroing the *other* known labels. With the copy approach, any unexpected label
# value (e.g. 3) would survive into all three masks.
mask_NCR_NET = (sample_mask == 1).astype(sample_mask.dtype)
mask_ED = (sample_mask == 2).astype(sample_mask.dtype)
mask_ET = (sample_mask == 4).astype(sample_mask.dtype)

In [111]:
# A 20x10-inch canvas split into a 2-row x 4-column grid; the second row is
# 1.5 times as tall as the first.
row_heights = [1, 1.5]
fig = plt.figure(figsize=(20, 10))
gs = gridspec.GridSpec(2, 4, height_ratios=row_heights)


<Figure size 1440x720 with 0 Axes>

In [112]:
##  Varying density along a streamline
#ax0 = fig.add_subplot(gs[0, 0])
#flair = ax0.imshow(sample_img1[:,:,slice], cmap='bone')
#ax0.set_title("FLAIR", fontsize=18, weight='bold', y=-0.2)
#fig.colorbar(flair)
#
##  Varying density along a streamline
#ax1 = fig.add_subplot(gs[0, 1])
#t1 = ax1.imshow(sample_img2[:,:,slice], cmap='bone')
#ax1.set_title("T1", fontsize=18, weight='bold', y=-0.2)
#fig.colorbar(t1)
#
##  Varying density along a streamline
#ax2 = fig.add_subplot(gs[0, 2])
#t2 = ax2.imshow(sample_img3[:,:,slice], cmap='bone')
#ax2.set_title("T2", fontsize=18, weight='bold', y=-0.2)
#fig.colorbar(t2)
#
##  Varying density along a streamline
#ax3 = fig.add_subplot(gs[0, 3])
#t1ce = ax3.imshow(sample_img4[:,:,slice], cmap='bone')
#ax3.set_title("T1 contrast", fontsize=18, weight='bold', y=-0.2)
#fig.colorbar(t1ce)

# Overlay panel: FLAIR slice with the three manual tumour labels and the
# Grad-CAM++ saliency drawn on top, spanning the middle two columns of row 1.
ax4 = fig.add_subplot(gs[1, 1:3])

# Saliency values at or below this threshold stay transparent so the anatomy shows through.
GCPP_THRESHOLD = 0.001

def _nonzero_only(arr):
    """Return `arr` as a masked array with its zero entries hidden."""
    return np.ma.masked_where(arr == 0, arr)

l1 = ax4.imshow(sample_img1[:, :, slice], cmap='bone', alpha=1)
l2 = ax4.imshow(_nonzero_only(mask_NCR_NET[:, :, slice]), cmap='spring', alpha=1)
l3 = ax4.imshow(_nonzero_only(mask_ED[:, :, slice]), cmap='autumn', alpha=1)
l4 = ax4.imshow(_nonzero_only(mask_ET[:, :, slice]), cmap='winter', alpha=1)
l5 = ax4.imshow(np.ma.masked_where(gcpp_mask <= GCPP_THRESHOLD, gcpp_mask), cmap='Greens', alpha=0.75)

# Caption slot below the axes (y in axes coordinates); currently left empty.
ax4.set_title("", fontsize=20, weight='bold', y=-0.1)

Text(0.5, -0.1, '')

In [113]:
# Hide ticks and frame on the overlay panel. Only `ax4` is created in this
# notebook (the single-modality panels ax0-ax3 are commented out), so it is
# the only axis referenced here; naming ax0-ax3 raises NameError on a fresh kernel.
ax4.set_axis_off()

# Legend swatches: each label overlay's colour at value 1, plus a fixed green
# for the Grad-CAM++ saliency (drawn with the 'Greens' colormap).
colors = [im.cmap(im.norm(1)) for im in (l2, l3, l4)]
colors.append("Green")
labels = ['Non-Enhancing tumor core', 'Peritumoral Edema ', 'GD-enhancing tumor', 'Grad-Cam++ prediction']
patches = [mpatches.Patch(color=color, label=label) for color, label in zip(colors, labels)]

# Set to True to draw the legend on the figure; the output filename follows the choice.
SHOW_LEGEND = False
if SHOW_LEGEND:
    fig.legend(handles=patches, loc='center right', borderaxespad=0.5, bbox_to_anchor=(0.85, 0.35),
               fontsize='xx-large', title='Mask Labels', title_fontsize=18,
               edgecolor="black", facecolor='#c5c6c7')
out_name = "slice_with_gradcam.png" if SHOW_LEGEND else "slice_with_gradcam_no_legend.png"

fig.savefig(out_name, format="png", pad_inches=0.4, transparent=False, bbox_inches='tight')
#fig.savefig("data_sample.png", format="png",  pad_inches=0.2, transparent=False, bbox_inches='tight')
#fig.savefig("data_sample.svg", format="svg",  pad_inches=0.2, transparent=False, bbox_inches='tight')