In [1]:
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
from dataloader import train_dl, val_dl
from unet import UNet, ConvBlock
from loss_functions import DiceLoss, FocalTverskyLoss, dice_score
from tqdm import tqdm
import matplotlib.pyplot as plt
import json
import time
import os
from tqdm import tqdm
import ipywidgets as widgets
from ipywidgets import interact_manual, interact, IntSlider
import SimpleITK as sitk

AUGMENTATION:  True


In [2]:
folder = '.'
device = torch.device("cpu")
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code -- only load checkpoints that come from a trusted source.
# map_location remaps tensors that were saved from another device (e.g. a GPU)
# onto `device`, so the checkpoint also loads on a machine without CUDA.
# .eval() returns the module itself; it switches layers such as dropout and
# batch norm (if the network has any) to inference behaviour, since this
# notebook only runs predictions.
model = torch.load(os.path.join(folder, 'model.pth'), map_location=device).to(device).eval()

In [3]:
# Pull the first batch from the validation loader and move both tensors
# (inputs x, targets y) to the configured device.
first_batch = iter(val_dl)
x, y = next(first_batch)
x = x.to(device)
y = y.to(device)
# Show the batch dimensions of inputs and targets.
print(x.shape)
print(y.shape)
print(y.shape)

torch.Size([8, 4, 160, 256, 256])
torch.Size([8, 1, 160, 256, 256])


In [4]:
# The prediction is only inspected and plotted, never back-propagated, so run
# the forward pass without recording the autograd graph. This avoids keeping
# the intermediate activations of the whole batch alive.
with torch.no_grad():
    pred = model(x)
print(pred.shape)

torch.Size([8, 1, 160, 256, 256])


In [5]:
mri_types = ['t1', 't1ce', 't2', 'flair', 'segm', 'pred']

# Interactive side-by-side slice viewer for the batch held in the globals
# x (inputs), y (target) and pred (model output).
# The callbacks read those globals when their button is pressed, so the
# viewer has to be used before x, y and pred are deleted further down.
@interact_manual
def choose_patient(patient=range(x.shape[0])):
    """Pick one sample of the batch, then two volumes of it to compare.

    patient: index into the first (batch) axis of x, y and pred. The choices
        are derived from the batch size when this cell is executed, so a
        batch with fewer or more samples than 8 is handled as well.
    """

    def to_cpu(volume):
        # Drop autograd history and make sure the data is on the CPU so that
        # matplotlib can display it.
        return volume.detach().to('cpu')

    # Channel order of x as used by this notebook: 0=t1, 1=t1ce, 2=t2, 3=flair.
    input_channels = ('t1', 't1ce', 't2', 'flair')
    images = {name: to_cpu(x[patient, channel, :, :, :])
              for channel, name in enumerate(input_channels)}
    images['segm'] = to_cpu(y[patient, 0, :, :, :])
    images['pred'] = to_cpu(pred[patient, 0, :, :, :])

    @interact_manual
    def choose_mri_type1(mri_type1=mri_types):

        @interact_manual
        def choose_mri_type2(mri_type2=mri_types):
            image1 = images[mri_type1]
            image2 = images[mri_type2]

            # Size the slider from the data instead of hard-coding it, so every
            # slice along the first axis is reachable and no index can fall
            # outside either volume. The slider starts at the middle slice.
            depth = min(image1.shape[0], image2.shape[0])
            slider = IntSlider(value=depth // 2, min=0, max=depth - 1, step=1,
                               orientation='horizontal',
                               continuous_update=True, readout=True)

            @interact
            def show_slice(coord_z=slider):
                # One row, two panels: left = first choice, right = second.
                fig, axes = plt.subplots(1, 2, figsize=(8, 8))
                panels = ((mri_type1, image1), (mri_type2, image2))
                for ax, (name, volume) in zip(axes, panels):
                    ax.imshow(volume[coord_z], cmap='gray')
                    ax.set_title(name)
                    ax.axis('off')
                plt.show()

interactive(children=(Dropdown(description='patient', options=(0, 1, 2, 3, 4, 5, 6, 7), value=0), Button(descr…

In [6]:
# The demo batch and its prediction are no longer needed: unbind the names so
# the tensors can be reclaimed before the full evaluation runs.
del pred
del y
del x
# Hand cached blocks of the CUDA allocator back to the driver.
torch.cuda.empty_cache()

In [7]:
def evaluate(dl):
    """Return the mean of the per-batch Dice scores of the global `model`.

    Every batch of `dl` is moved to the global `device`, passed through the
    model, and scored on the CPU with `dice_score(y, pred)` (imported from
    loss_functions; how it reduces over the batch is not visible here).
    The result is the unweighted mean over batches, so a smaller final batch
    counts as much as a full one.

    Args:
        dl: iterable yielding (x, y) tensor pairs, e.g. val_dl or train_dl.

    Returns:
        sum of the per-batch scores divided by the number of batches.

    Raises:
        ValueError: if `dl` yields no batches.
    """
    # Evaluate in inference mode and restore the previous mode afterwards so
    # the function leaves the model the way it found it.
    was_training = model.training
    model.eval()

    dscs = []
    try:
        # No gradients are needed for scoring; skipping the autograd graph
        # keeps memory use per batch down.
        with torch.no_grad():
            for x, y in tqdm(dl):
                x, y = x.to(device), y.to(device)
                pred = model(x).to('cpu')
                y = y.detach().to('cpu')
                dscs.append(dice_score(y, pred))

                # Release the batch before the next one is fetched.
                del x, y, pred
                torch.cuda.empty_cache()
    finally:
        model.train(was_training)

    if not dscs:
        raise ValueError('evaluate() received a loader that yielded no batches')

    mean_dsc = sum(dscs) / len(dscs)
    return mean_dsc

In [8]:
# Training-set evaluation is disabled; uncomment the two lines below to run it.
# mean_dsc = evaluate(train_dl)
# print('train dsc: ', mean_dsc)

# Mean per-batch Dice score over the validation loader.
mean_dsc = evaluate(dl=val_dl)
print('val dsc: ', mean_dsc)

100%|██████████| 13/13 [02:20<00:00, 10.82s/it]

val dsc:  0.6745324514492814





# Save validation predictions to HDF5

In [3]:
from dataloader import val_ds
import h5py

In [15]:
# Write one dataset per validation sample to an HDF5 file. Datasets are keyed
# by the sample's index in val_ds (as a string); the patient id returned by
# val_ds is attached as an attribute so a prediction can be traced back.
file = 'val_preds.hdf5'

# Predictions only: make sure the model is in inference mode.
model.eval()

# The `with` block closes the file on exit (also on error), so no explicit
# close() is needed. no_grad avoids recording an autograd graph per sample.
with h5py.File(file, 'w') as hf, torch.no_grad():
    for i in tqdm(range(len(val_ds))):
        x, y, patient_id = val_ds[i]
        x_tensor = torch.from_numpy(x).to(device)
        # Add a leading batch axis for the model and strip it again afterwards.
        pred = model(x_tensor[None, ...])[0].to('cpu').numpy()
        dataset = hf.create_dataset(str(i), data=pred)
        # str() keeps this independent of the concrete type of patient_id.
        dataset.attrs['patient_id'] = str(patient_id)

100%|██████████| 100/100 [02:36<00:00,  1.57s/it]
