# Inference

In [7]:
import os
import glob
import numpy as np
import ipywidgets as ipyw
import matplotlib.pyplot as plt
from tqdm import tqdm

import torch
import torchio as tio

from utils.utils import get_geodismaps
from models.networks import P_RNet3D
from data_loaders.transforms import get_transform

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


## load model weights

In [2]:
pnet_best_ckpt_dir = "./experiments/best_ckpts/brats3d_pnet_init_train"
pnet_best_ckpt_path = sorted(glob.glob(f"{pnet_best_ckpt_dir}/*.pt"))[-1]

rnet_best_ckpt_dir = "./experiments/best_ckpts/brats3d_rnet_init_train"
rnet_best_ckpt_path = sorted(glob.glob(f"{rnet_best_ckpt_dir}/*.pt"))[-1]

pnet = P_RNet3D(c_in=1, c_blk=16, n_classes=2).to(device)
rnet = P_RNet3D(c_in=4, c_blk=16, n_classes=2).to(device)

pnet.load_state_dict(torch.load(pnet_best_ckpt_path))
rnet.load_state_dict(torch.load(rnet_best_ckpt_path))

pnet.eval()
rnet.eval()

P_RNet3D(
  (block1): Sequential(
    (0): Conv3d(4, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (1): ReLU()
    (2): Conv3d(16, 16, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0))
    (3): ReLU()
  )
  (block1_downsample): Sequential(
    (0): Conv3d(16, 4, kernel_size=(1, 1, 1), stride=(1, 1, 1))
    (1): ReLU()
  )
  (block2): Sequential(
    (0): Conv3d(16, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(2, 2, 2), dilation=(2, 2, 2))
    (1): ReLU()
    (2): Conv3d(16, 16, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(2, 2, 0), dilation=(2, 2, 1))
    (3): ReLU()
  )
  (block2_downsample): Sequential(
    (0): Conv3d(16, 4, kernel_size=(1, 1, 1), stride=(1, 1, 1))
    (1): ReLU()
  )
  (block3): Sequential(
    (0): Conv3d(16, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(4, 4, 4), dilation=(4, 4, 4))
    (1): ReLU()
    (2): Conv3d(16, 16, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(4, 4, 0), dilation=(4, 4, 1))
    (3): 

## Inference on test datas

In [3]:
save_dir = "./results"
if not os.path.exists(save_dir):
    os.makedirs(save_dir)

In [4]:
test_dir = "./dataset/test"
test_images = glob.glob(f"{test_dir}/*/*_flair.nii.gz")
test_labels = glob.glob(f"{test_dir}/*/*_seg.nii.gz")

test_transform = get_transform("valid")
n_classes = 2
target_class = 1

for image_path, label_path in tqdm(zip(test_images, test_labels), total=len(test_images)):
    input_subject = tio.Subject(
        image = tio.ScalarImage(image_path),
        label = tio.LabelMap(label_path)
    )
    input_subject = test_transform(input_subject)
    inputs = input_subject.image.data.unsqueeze(dim=0).to(device)
    true_labels = input_subject.label.data[0, ...].unsqueeze(dim=0).to(device)

    with torch.no_grad():
        pred_logits = pnet(inputs)
        pred_labels_pnet = torch.argmax(pred_logits, dim=1)
        fore_dist_map, back_dist_map = get_geodismaps(inputs, true_labels, pred_labels_pnet)
        rnet_inputs = torch.cat([
            inputs,
            pred_labels_pnet.unsqueeze(dim=1), 
            torch.Tensor(fore_dist_map).unsqueeze(dim=1).to(device), 
            torch.Tensor(back_dist_map).unsqueeze(dim=1).to(device)
        ], dim=1)

        pred_logits = rnet(rnet_inputs)
        pred_labels_rnet = torch.argmax(pred_logits, dim=1)

        pred_onehot_pnet = torch.nn.functional.one_hot(pred_labels_pnet, n_classes).permute(0, 4, 1, 2, 3)
        pred_onehot_target_pnet = pred_onehot_pnet[:, target_class, ...]

        pred_onehot_rnet = torch.nn.functional.one_hot(pred_labels_rnet, n_classes).permute(0, 4, 1, 2, 3)
        pred_onehot_target_rnet = pred_onehot_rnet[:, target_class, ...]
    
    pred_labelmap_pnet = tio.LabelMap(
        tensor=pred_onehot_target_pnet.cpu(),
        affine=input_subject.image.affine
    )
    pred_labelmap_rnet = tio.LabelMap(
        tensor=pred_onehot_target_rnet.cpu(),
        affine=input_subject.image.affine
    )

    save_path_pnet = os.path.join(
        save_dir,
        os.path.basename(image_path).replace("_flair.nii.gz", "_pnet_pred.nii.gz")
    )
    save_path_rnet = os.path.join(
        save_dir,
        os.path.basename(image_path).replace("_flair.nii.gz", "_rnet_pred.nii.gz")
    )

    pred_labelmap_pnet.save(save_path_pnet)
    pred_labelmap_rnet.save(save_path_rnet)


100%|██████████| 20/20 [05:53<00:00, 17.68s/it]


## Visualize results

In [5]:


class ImageSliceViewer3D_with_prediction:
    
    # reference = https://github.com/esmitt/imageSliceViewer/blob/master/SliceViewer.ipynb
    
    """ 
    ImageSliceViewer3D is for viewing volumetric image slices in jupyter or
    ipython notebooks. 
    
    User can interactively change the slice plane selection for the image and 
    the slice plane being viewed. 

    Argumentss:
    Volume = 3D input image
    figsize = default(8,8), to set the size of the figure
    cmap = default('gray'), string for the matplotlib colormap. You can find 
    more matplotlib colormaps on the following link:
    https://matplotlib.org/users/colormaps.html
    
    """
    
    def __init__(self, image, label, pred_pnet, pred_rnet, overlap = False, figsize=(50,50), cmap_image='gray', cmap_label = 'viridis', cmap_pred_pnet = 'inferno', cmap_pred_rnet = 'magma'):
        self.image = image
        self.label = label 
        self.pred_pnet = pred_pnet
        self.pred_rnet = pred_rnet
        self.figsize = figsize
        self.cmap_image = cmap_image
        self.cmap_label = cmap_label
        self.cmap_pred_pnet = cmap_pred_pnet
        self.cmap_pred_rnet = cmap_pred_rnet
        self.v = [np.min(image), np.max(image)]
        self.v_label = [np.min(label), np.max(label)]
        self.v_pred_pnet = [np.min(pred_pnet), np.max(pred_pnet)]
        self.v_pred_rnet = [np.min(pred_rnet), np.max(pred_rnet)]
        self.overlap = overlap

        # Call to select slice plane
        ipyw.interact(self.views)
    
    def views(self):
        self.vol1 = np.transpose(self.image, [1,2,0])
        self.vol2 = np.rot90(np.transpose(self.image, [2,0,1]), 3) #rotate 270 degrees
        self.vol3 = np.transpose(self.image, [0,1,2])
        maxZ1 = self.vol1.shape[2] - 1
        maxZ2 = self.vol2.shape[2] - 1
        maxZ3 = self.vol3.shape[2] - 1
        
        self.vol1_label = np.transpose(self.label, [1,2,0])
        self.vol2_label = np.rot90(np.transpose(self.label, [2,0,1]), 3) #rotate 270 degrees
        self.vol3_label = np.transpose(self.label, [0,1,2])
        
        self.vol1_pred_pnet = np.transpose(self.pred_pnet, [1,2,0])
        self.vol2_pred_pnet = np.rot90(np.transpose(self.pred_pnet, [2,0,1]), 3) #rotate 270 degrees
        self.vol3_pred_pnet = np.transpose(self.pred_pnet, [0,1,2])
        
        self.vol1_pred_rnet = np.transpose(self.pred_rnet, [1,2,0])
        self.vol2_pred_rnet = np.rot90(np.transpose(self.pred_rnet, [2,0,1]), 3) #rotate 270 degrees
        self.vol3_pred_rnet = np.transpose(self.pred_rnet, [0,1,2])
        
        ipyw.interact(self.plot_slice, 
            z1=ipyw.IntSlider(min=0, max=maxZ1, step=1, continuous_update=False, 
            description='Axial:'), 
            z2=ipyw.IntSlider(min=0, max=maxZ2, step=1, continuous_update=False, 
            description='Coronal:'),
            z3=ipyw.IntSlider(min=0, max=maxZ3, step=1, continuous_update=False, 
            description='Sagittal:'))

    def plot_slice(self, z1, z2, z3):
        
        if self.overlap:

            f, ax = plt.subplots(1,3, figsize=self.figsize)
            ax[0].imshow(self.vol1[:,:,z1], cmap=plt.get_cmap(self.cmap_image), 
                vmin=self.v[0], vmax=self.v[1])
            ax[1].imshow(self.vol2[:,:,z2], cmap=plt.get_cmap(self.cmap_image), 
                vmin=self.v[0], vmax=self.v[1])
            ax[2].imshow(self.vol3[:,:,z3], cmap=plt.get_cmap(self.cmap_image), 
                vmin=self.v[0], vmax=self.v[1])
            
            ax[0].imshow(self.vol1_label[:,:,z1], cmap=plt.get_cmap(self.cmap_label),
                vmin=self.v_label[0], vmax=self.v_label[1], alpha = 0.3)
            ax[1].imshow(self.vol2_label[:,:,z2], cmap=plt.get_cmap(self.cmap_label),
                vmin=self.v_label[0], vmax=self.v_label[1], alpha = 0.3)
            ax[2].imshow(self.vol3_label[:,:,z3], cmap=plt.get_cmap(self.cmap_label),
                vmin=self.v_label[0], vmax=self.v_label[1], alpha = 0.3)
            
            ax[0].imshow(self.vol1_pred_pnet[:,:,z1], cmap=plt.get_cmap(self.cmap_pred_pnet),
                vmin=self.v_pred_pnet[0], vmax=self.v_pred_pnet[1], alpha = 0.3)
            ax[1].imshow(self.vol2_pred_pnet[:,:,z2], cmap=plt.get_cmap(self.cmap_pred_pnet),
                vmin=self.v_pred_pnet[0], vmax=self.v_pred_pnet[1], alpha = 0.3)
            ax[2].imshow(self.vol3_pred_pnet[:,:,z3], cmap=plt.get_cmap(self.cmap_pred_pnet),
                vmin=self.v_pred_pnet[0], vmax=self.v_pred_pnet[1], alpha = 0.3)

            ax[0].imshow(self.vol1_pred_rnet[:,:,z1], cmap=plt.get_cmap(self.cmap_pred_rnet),
                vmin=self.v_pred_rnet[0], vmax=self.v_pred_rnet[1], alpha = 0.3)
            ax[1].imshow(self.vol2_pred_rnet[:,:,z2], cmap=plt.get_cmap(self.cmap_pred_rnet),
                vmin=self.v_pred_rnet[0], vmax=self.v_pred_rnet[1], alpha = 0.3)
            ax[2].imshow(self.vol3_pred_rnet[:,:,z3], cmap=plt.get_cmap(self.cmap_pred_rnet),
                vmin=self.v_pred_rnet[0], vmax=self.v_pred_rnet[1], alpha = 0.3)
            
            plt.show()
            
            
        else:
            # default settings: without overlap plots
            f, ax = plt.subplots(4,3, figsize=self.figsize)
            ax[0][0].imshow(self.vol1[:,:,z1], cmap=plt.get_cmap(self.cmap_image), 
                vmin=self.v[0], vmax=self.v[1])
            ax[0][1].imshow(self.vol2[:,:,z2], cmap=plt.get_cmap(self.cmap_image), 
                vmin=self.v[0], vmax=self.v[1])
            ax[0][2].imshow(self.vol3[:,:,z3], cmap=plt.get_cmap(self.cmap_image), 
                vmin=self.v[0], vmax=self.v[1])
            
            ax[1][0].imshow(self.vol1_label[:,:,z1], cmap=plt.get_cmap(self.cmap_label),
                vmin=self.v_label[0], vmax=self.v_label[1])
            ax[1][1].imshow(self.vol2_label[:,:,z2], cmap=plt.get_cmap(self.cmap_label),
                vmin=self.v_label[0], vmax=self.v_label[1])
            ax[1][2].imshow(self.vol3_label[:,:,z3], cmap=plt.get_cmap(self.cmap_label),
                vmin=self.v_label[0], vmax=self.v_label[1])
            
            ax[2][0].imshow(self.vol1_pred_pnet[:,:,z1], cmap=plt.get_cmap(self.cmap_pred_pnet),
                vmin=self.v_pred_pnet[0], vmax=self.v_pred_pnet[1], alpha = 0.3)
            ax[2][1].imshow(self.vol2_pred_pnet[:,:,z2], cmap=plt.get_cmap(self.cmap_pred_pnet),
                vmin=self.v_pred_pnet[0], vmax=self.v_pred_pnet[1], alpha = 0.3)
            ax[2][2].imshow(self.vol3_pred_pnet[:,:,z3], cmap=plt.get_cmap(self.cmap_pred_pnet),
                vmin=self.v_pred_pnet[0], vmax=self.v_pred_pnet[1], alpha = 0.3)

            ax[3][0].imshow(self.vol1_pred_rnet[:,:,z1], cmap=plt.get_cmap(self.cmap_pred_rnet),
                vmin=self.v_pred_rnet[0], vmax=self.v_pred_rnet[1], alpha = 0.3)
            ax[3][1].imshow(self.vol2_pred_rnet[:,:,z2], cmap=plt.get_cmap(self.cmap_pred_rnet),
                vmin=self.v_pred_rnet[0], vmax=self.v_pred_rnet[1], alpha = 0.3)
            ax[3][2].imshow(self.vol3_pred_rnet[:,:,z3], cmap=plt.get_cmap(self.cmap_pred_rnet),
                vmin=self.v_pred_rnet[0], vmax=self.v_pred_rnet[1], alpha = 0.3)
            
            plt.show()

In [9]:
pnet_inferences = glob.glob(f"{save_dir}/*_pnet_pred.nii.gz")
rnet_inferences = glob.glob(f"{save_dir}/*_rnet_pred.nii.gz")
transform = tio.Compose([
    tio.ToCanonical(), 
    tio.Resample(1), 
    tio.RemapLabels({2:1, 3:1, 4:1})
])

image_path, label_path, pnet_inf_path, rnet_inf_path = test_images[5], test_labels[5], pnet_inferences[5], rnet_inferences[5]

test_subject = tio.Subject(
    image = tio.ScalarImage(image_path),
    label = tio.LabelMap(label_path),
    pnet_inf = tio.LabelMap(pnet_inf_path),
    rnet_inf = tio.LabelMap(rnet_inf_path)
)
test_subject = transform(test_subject)

image_np = test_subject.image.data.squeeze(dim=0).numpy()
label_np = test_subject.label.data.squeeze(dim=0).numpy()
pnet_inf_np = test_subject.pnet_inf.data.squeeze(dim=0).numpy()
rnet_inf_np = test_subject.rnet_inf.data.squeeze(dim=0).numpy()

ImageSliceViewer3D_with_prediction(image_np, label_np, pnet_inf_np, rnet_inf_np)

interactive(children=(Output(),), _dom_classes=('widget-interact',))

<__main__.ImageSliceViewer3D_with_prediction at 0x7f9e58465f50>