In [None]:
import torch
import torch.nn.functional as F
import h5py
import numpy as np
import matplotlib.pyplot as plt
import os
from scipy.special import softmax
import sys
import model_zoo
import data_freiburg_numpy_to_hdf5
from utils import make_dir_safely, normalize_image
from losses import compute_dice

In [None]:
# Bern data: the HDF5 files below were produced beforehand by Bern_numpy_to_hdf5.py.
# The files are not closed here because the datasets taken from them are read lazily.
basepath = "/usr/bmicnas02/data-biwi-01/jeremy_students/data/inselspital/kady"
bern_tr = h5py.File(os.path.join(basepath, 'bern_images_and_labels_from_101_to_104.hdf5'), 'r')
bern_vl = h5py.File(os.path.join(basepath, 'bern_images_and_labels_from_105_to_106.hdf5'), 'r')
images_tr, labels_tr = bern_tr['images_train'], bern_tr['labels_train']
images_vl, labels_vl = bern_vl['images_validation'], bern_vl['labels_validation']

## Load the trained model

In [None]:
# Load the saved model. These settings are interpolated into `model_path` below
# to locate the run's log directory.
loss = "dice"
out_channels = 2
in_channels = 4
run = 1
note = '_full_run_fine_tune_Bern_lr_0.5e-3_scheduler_e50'
cut_z = 3
da = 0.0
model_path = f'/usr/bmicnas02/data-biwi-01/jeremy_students/lschlyter/CNN-segmentation/logdir/unet3d_da_{da}nchannels{in_channels}_r{run}_loss_{loss}_cut_z_{cut_z}{note}'
# Alternative runs evaluated previously:
#model_path = f'/usr/bmicnas02/data-biwi-01/jeremy_students/lschlyter/CNN-segmentation/logdir/unet3d_da_{da}nchannels{in_channels}_r_phase_dice_cut_z_5_debug'
#model_path = "/usr/bmicnas02/data-biwi-01/jeremy_students/lschlyter/CNN-segmentation/logdir/unet3d_da_0.0nchannels1_r_DEBUG_LOSS_phase_dice_cut_z_0"

# os.listdir returns entries in arbitrary order, so "last listed" is not a
# meaningful choice; take the most recently written 'best' checkpoint instead.
best_candidates = [os.path.join(model_path, name) for name in os.listdir(model_path) if 'best' in name]
if not best_candidates:
    raise FileNotFoundError(f"No checkpoint containing 'best' found in {model_path}")
best_model_path = max(best_candidates, key=os.path.getmtime)
print(best_model_path)

model = model_zoo.UNet(in_channels, out_channels)
model.load_state_dict(torch.load(best_model_path, map_location=torch.device('cpu')))
model.eval();

In [None]:
# Remove the outermost z-slices of every patient before evaluation.
def cut_z_slices(images, labels, n_cut, n_slices=32):
    """Drop the first and last `n_cut` z-slices of every patient.

    The slices of all patients are stacked along axis 0, with `n_slices`
    consecutive entries per patient.

    :param images: array indexable along axis 0 (numpy array or h5py dataset)
    :param labels: array with the same number of entries as `images`
    :param n_cut: number of slices removed at each end of every patient's stack
    :param n_slices: number of z-slices per patient (32 for the Bern data)
    :return: (images, labels) restricted to the kept slices
    :raises ValueError: if the entry count is not a multiple of `n_slices`,
        or if `n_cut` is negative or would remove every slice
    """
    n_data = images.shape[0]
    if n_data % n_slices != 0:
        raise ValueError(f"{n_data} entries is not a multiple of {n_slices} slices per patient")
    if n_cut < 0 or 2 * n_cut >= n_slices:
        raise ValueError(f"n_cut={n_cut} is invalid for {n_slices} slices per patient")
    # One row per patient, one column per z-slice.
    index_shaped = np.arange(n_data).reshape(-1, n_slices)
    # Explicit stop index: the former `n_cut:-n_cut` is empty when n_cut == 0.
    # Row-major flattening keeps the indices increasing, as h5py indexing requires.
    index_keep = index_shaped[:, n_cut:n_slices - n_cut].ravel()
    return images[index_keep], labels[index_keep]

# Drop the outer z-slices of every patient in both splits (skipped when cut_z is 0).
if cut_z:
    images_tr, labels_tr = cut_z_slices(images_tr, labels_tr, n_cut=cut_z)
    images_vl, labels_vl = cut_z_slices(images_vl, labels_vl, n_cut=cut_z)

In [None]:
tuple(arr.shape for arr in (images_tr, labels_tr, images_vl, labels_vl))

### Compute Dice scores

In [None]:
def iterate_minibatches_validation(images, labels, batch_size, n_input_channels=None, shuffle=True):
    """Yield (X, y) mini-batches containing exactly `batch_size` samples each.

    Samples are optionally shuffled before batching. Trailing samples that do
    not fill a whole batch are dropped.

    :param images: input images, indexable along axis 0 with channels on the last axis
    :param labels: labels, same length as `images`
    :param batch_size: number of samples per batch
    :param n_input_channels: channel count the model expects; when 1, only
        channel 1 of the last axis is kept. Defaults to the notebook-level
        `in_channels` for backward compatibility.
    :param shuffle: shuffle samples (global np.random state) before batching;
        pass False for a reproducible evaluation
    :return: generator of (X, y) mini-batches
    """
    assert len(images) == len(labels)

    if n_input_channels is None:
        # Fall back to the notebook configuration (previously an implicit global read).
        n_input_channels = in_channels

    n_images = images.shape[0]
    order = np.arange(n_images)
    if shuffle:
        np.random.shuffle(order)

    # Stop early enough that every batch is full; a partial last batch is skipped.
    for start in range(0, n_images - batch_size + 1, batch_size):
        # HDF5 requires indices to be in increasing order
        batch_indices = np.sort(order[start:start + batch_size])

        X = images[batch_indices, ...]
        y = labels[batch_indices, ...]

        # A single-channel model only receives channel 1 of the input.
        if n_input_channels == 1:
            X = X[..., 1:2]

        yield X, y


In [None]:
def dice_score_bern(images_set, labels_set, batch_size = 4):
    """Sum the mean Dice score of the notebook-level `model` over full mini-batches.

    Batches come from `iterate_minibatches_validation`. Each batch is scored with
    `compute_dice`, and its second return value (the mean Dice) is accumulated.
    This function also relies on the notebook globals `model` and `out_channels`.

    :param images_set: images with channels on the last axis
        (presumably (n, x, y, t, channel); TODO confirm)
    :param labels_set: integer label maps, one per image
    :param batch_size: samples per batch
    :return: (sum of per-batch mean Dice, zero-based index of the last evaluated
        batch); callers average with ``dice_score / (n_batch + 1)``
    :raises ValueError: if no full batch could be evaluated (fewer samples than
        `batch_size`), which previously surfaced as an UnboundLocalError
    """
    model.eval()
    dice_score = 0
    n_evaluated = 0
    with torch.no_grad():
        for inputs, labels in iterate_minibatches_validation(images_set, labels_set, batch_size=batch_size):
            # Skipped batches must not be counted, otherwise the caller's average is biased.
            if labels.shape[0] < batch_size:
                continue

            # (batch, x, y, t, channel) -> (batch, channel, x, y, t)
            inputs = torch.from_numpy(inputs).permute(0, 4, 1, 2, 3)
            # Integer labels -> one-hot with the class axis moved to position 1.
            labels = F.one_hot(torch.from_numpy(labels).long(), num_classes=out_channels)
            labels = labels.permute(0, 4, 1, 2, 3)

            logits = model(inputs.float())
            _, mean_dice, _ = compute_dice(logits, labels)

            dice_score += mean_dice
            n_evaluated += 1

    if n_evaluated == 0:
        raise ValueError(f"No full batch of size {batch_size} could be formed from the given data")
    return dice_score, n_evaluated - 1
        

In [None]:
tuple(arr.shape for arr in (images_vl, labels_vl, images_tr, labels_tr))

In [None]:
print("Training dice")
dice_score_tr, n_batch_tr = dice_score_bern(images_set=images_tr, labels_set=labels_tr, batch_size=4)
# n_batch_tr is the zero-based index of the last batch, hence the +1.
print("Total average dice score: ", dice_score_tr / (1 + n_batch_tr))

In [None]:
dice_score_tr / (1 + n_batch_tr)

In [None]:
print("Validation dice")
dice_score_vl, n_batch_vl = dice_score_bern(images_set=images_vl, labels_set=labels_vl, batch_size=4)
# n_batch_vl is the zero-based index of the last batch, hence the +1.
print("Total average dice score: ", dice_score_vl / (1 + n_batch_vl))

In [None]:
dice_score_vl / (1 + n_batch_vl)

## Visualization outputs

In [None]:
# Split name used for the visualisation outputs: "training" or "validation".
viz_on = "validation"
best_model_path

In [None]:
print(len(images_vl))

In [None]:
# Predict the second half of the validation entries, one sample per forward pass.
n_vl = images_vl.shape[0]
print(n_vl)
indexes = np.arange(n_vl // 2, n_vl, 1)

inputs = images_vl[indexes]
labels = labels_vl[indexes]
print(inputs.shape, labels.shape)
# Move the channel axis (last) to position 1: (batch, ..., channel) -> (batch, channel, ...)
inputs = torch.from_numpy(inputs).permute(0, 4, 1, 2, 3)
labels = torch.from_numpy(labels)
print(inputs.shape, labels.shape)

model.eval()
with torch.no_grad():
    # .float() matches the dtype handling used in dice_score_bern.
    pred_chunks = [model(inputs[i:i + 1].float()).numpy() for i in range(inputs.shape[0])]
# Stack once and compute the labels once: re-stacking and re-running softmax/argmax
# on every iteration was quadratic in the number of samples.
preds = np.concatenate(pred_chunks, axis=0)
prediction = softmax(preds, axis=1).argmax(axis=1)


In [None]:
# Move the stacked-sample axis from position 0 to position 2:
# (sample, a, b, c) -> (a, b, sample, c).
prediction_axes = prediction.transpose(1, 2, 0, 3)

# NOTE(review): this file name is hard-coded and differs from `note` (it lacks
# "_scheduler"); confirm it names the intended run.
np.save('prediction_validation_0_full_run_fine_tune_Bern_lr_0.5e-3_e50.npy', prediction_axes)

In [None]:
# Predict a few evenly spaced z-slices of one patient from the split selected by `viz_on`.
patient_n = 0
if viz_on == "training":
    images_viz, labels_viz = images_tr, labels_tr
elif viz_on == "validation":
    images_viz, labels_viz = images_vl, labels_vl
else:
    raise ValueError(f"viz_on must be 'training' or 'validation', got {viz_on!r}")

# After cut_z_slices every patient occupies 32 - 2*cut_z consecutive entries.
n_slices_per_patient = 32 - 2 * cut_z
# NOTE(review): the step uses 32 - cut_z (not n_slices_per_patient), as in the
# original cell; with cut_z = 3 this selects the 9 slices 0, 3, ..., 24.
indexes_z_slices = np.arange(0, n_slices_per_patient, (32 - cut_z) // 8)
indexes = n_slices_per_patient * patient_n + indexes_z_slices
inputs = images_viz[indexes]
labels = labels_viz[indexes]

print(inputs.shape, labels.shape)
# Move the channel axis (last) to position 1: (batch, ..., channel) -> (batch, channel, ...)
inputs = torch.from_numpy(inputs).permute(0, 4, 1, 2, 3)
labels = torch.from_numpy(labels)
print(inputs.shape, labels.shape)

# One forward pass per selected slice. The loop length follows the number of
# selected slices instead of a hard-coded 9, which broke for other cut_z values.
model.eval()
with torch.no_grad():
    slice_logits = [model(inputs[i:i + 1].float()).numpy() for i in range(inputs.shape[0])]
preds = np.concatenate(slice_logits, axis=0)
prediction = softmax(preds, axis=1).argmax(axis=1)

In [None]:
%matplotlib inline
def presentation_viz_by_patient(input, gt, pred, time, save_path, note_save, indexes, n_channels = 4):
    """Plot every input channel, the prediction and the ground truth of several slices at one time index.

    The figure has one column per entry of `indexes` and ``n_channels + 2`` rows:
    one row per input channel, then the prediction row, then the ground-truth row.

    :param input: inputs indexed as input[slice, channel, :, :, time]
    :param gt: ground-truth labels indexed as gt[slice, :, :, time]
    :param pred: predicted labels indexed as pred[slice, :, :, time]
    :param time: index of the last axis to display
    :param save_path: directory in which the PNG is written (created if needed);
        when None the figure is only shown
    :param note_save: file-name prefix of the saved PNG
    :param indexes: z-slice numbers; used in the subplot titles, and their count
        sets the number of columns
    :param n_channels: number of input channels to display
    """
    # Create directory if it does not exist
    if save_path is not None:
        make_dir_safely(save_path)

    n_cols = len(indexes)
    fig_height = 18 if n_channels == 4 else 7
    fig, axs = plt.subplots(2 + n_channels, n_cols, figsize = (18, fig_height))
    ax = axs.reshape(-1)
    axes_index = 0

    # Rows 0 .. n_channels-1: one row per input channel.
    for chan in range(n_channels):
        for i, z_slice in enumerate(indexes):
            ax[axes_index].imshow(input[i, chan, :, :, time])
            ax[axes_index].set_title(f"z_{z_slice}_ch:{chan}_t_{time}", fontsize = 10)
            axes_index += 1
    # Prediction row.
    for i, z_slice in enumerate(indexes):
        ax[axes_index].imshow(pred[i, :, :, time])
        ax[axes_index].set_title(f"z_{z_slice}_pred_t_{time}", fontsize = 10)
        axes_index += 1
    # Ground-truth row.
    for i, z_slice in enumerate(indexes):
        ax[axes_index].imshow(gt[i, :, :, time])
        ax[axes_index].set_title(f"z_{z_slice}_gt_t_{time}", fontsize = 10)
        axes_index += 1

    if save_path is not None:
        # Join with a separator: plain concatenation wrote e.g. ".../validationpatient_0_t_3.png"
        # next to the created directory instead of inside it.
        plt.savefig(os.path.join(save_path, f"{note_save}_t_{time}.png"), bbox_inches='tight')
    plt.show()
    
        

In [None]:
save_viz = f"{model_path}/results/visualization/notebook_viz/{viz_on}"

In [None]:
save_viz  # displayed for reference; passed as save_path to presentation_viz_by_patient below

In [None]:
presentation_viz_by_patient(
    input=inputs,
    gt=labels,
    pred=prediction,
    time=3,
    indexes=indexes_z_slices,
    n_channels=in_channels,
    save_path=save_viz,
    note_save="patient_0",
)