# Initial Imports

In [1]:
from __future__ import print_function
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets
import tifffile
from mpl_toolkits.axes_grid1 import make_axes_locatable
from tqdm import tqdm
import math

# Load Model

In [None]:
# Ask for the saved model location; trim whitespace/quotes that often come along with a pasted path.
model_path = input('Enter the file path for your trained 3D U-Net model \n').strip().strip('"\'')
model = tf.keras.models.load_model(model_path)

# Load Data

In [None]:
seed = 42
# Call np.random.seed; assigning to it would replace the function with an int and leave the RNG unseeded.
np.random.seed(seed)


IMG_WIDTH = int(input('Please input the image x dimension \n'))
IMG_HEIGHT = int(input('Please input the image y dimension \n'))
IMG_DEPTH = int(input('Please input the image z dimension \n'))
IMG_CHANNELS = int(input('Please input the number of channels in the images \n'))
num_train_imgs = int(input('Please input the number of training images \n'))
num_test_imgs = int(input('Please input the number of testing images \n'))

# Each volume is allocated as (depth, height, width); IMG_CHANNELS is not used for allocation here.
vol_shape = (IMG_DEPTH, IMG_HEIGHT, IMG_WIDTH)
x_train = np.zeros((num_train_imgs, *vol_shape), dtype=np.uint8)
# Builtin `bool` instead of the `np.bool` alias, which was removed in NumPy 1.24.
# NOTE(review): training masks are stored as bool but test masks as uint8 - confirm which dtype the masks need.
y_train = np.zeros((num_train_imgs, *vol_shape), dtype=bool)

x_test = np.zeros((num_test_imgs, *vol_shape), dtype=np.uint8)
y_test = np.zeros((num_test_imgs, *vol_shape), dtype=np.uint8)

# Use your own file paths. Each file is read as '<prefix><3-digit zero-padded index>.tif'.
# NOTE(review): the test prefixes are the same placeholders as the training ones, so as written the
# test arrays reload training volumes 000, 001, ... - point them at your held-out data.
TRAIN_IMG_PREFIX = 'file_path_to_cropped_imgs_'
TRAIN_MASK_PREFIX = 'file_path_to_cropped_masks_'
TEST_IMG_PREFIX = 'file_path_to_cropped_imgs_'
TEST_MASK_PREFIX = 'file_path_to_cropped_masks_'

for idx in tqdm(range(num_train_imgs), desc='train'):
    x_train[idx] = tifffile.imread(f'{TRAIN_IMG_PREFIX}{idx:03}.tif')
    y_train[idx] = tifffile.imread(f'{TRAIN_MASK_PREFIX}{idx:03}.tif')

for idx in tqdm(range(num_test_imgs), desc='test'):
    x_test[idx] = tifffile.imread(f'{TEST_IMG_PREFIX}{idx:03}.tif')
    y_test[idx] = tifffile.imread(f'{TEST_MASK_PREFIX}{idx:03}.tif')
print('Finished Loading Data!')

# Using Model to Predict on Data

In [None]:
# Use this block for training images
img_number = int(input('Please input which training image you would like to run a prediction on \n'))
# Require a non-negative index: a negative one would silently wrap around to the end of x_train.
if 0 <= img_number < len(x_train):
    img = x_train[img_number]
    gt = y_train[img_number]
    # Add batch and channel axes: (D, H, W) -> (1, D, H, W, 1), the layout passed to model.predict.
    pd = model.predict(img[np.newaxis, ..., np.newaxis], verbose=1)
    print('Prediction Complete!')
else:
    print('This number is out of range for the x_train array size.')

In [None]:
# Use this block for testing images
img_number2 = int(input('Please input which testing image you would like to run a prediction on \n'))
# Validate the index entered in THIS cell (img_number2) against x_test; negative indices are rejected
# because they would silently wrap around to the end of the array.
if 0 <= img_number2 < len(x_test):
    img = x_test[img_number2]
    gt = y_test[img_number2]
    # Add batch and channel axes: (D, H, W) -> (1, D, H, W, 1), the layout passed to model.predict.
    pd = model.predict(img[np.newaxis, ..., np.newaxis], verbose=1)
    print('Prediction Complete!')
else:
    print('This number is out of range for the x_test array size.')

# Preprocessing the Prediction

In [None]:
def pd_preprocessing(pd, channel=1):
    """Extract one output channel of a U-Net prediction as a 3D volume.

    Args:
        pd: Array returned by ``model.predict``. Accepted forms:
            - batched ``(batch, D, H, W, C)``, of which only the first sample is used;
            - a single ``(D, H, W, C)`` volume.
        channel: Index along the last (channel) axis to keep. Defaults to 1,
            the channel this notebook plots as the predicted label.

    Returns:
        The ``(D, H, W)`` array ``pd[..., channel]`` for the selected sample.

    Raises:
        ValueError: If ``pd`` is neither 4D nor 5D.
    """
    label = np.asarray(pd)
    # Index the batch/channel axes explicitly rather than np.squeeze, which would also
    # drop any spatial axis of length 1 (e.g. a single-slice volume) and break the layout.
    if label.ndim == 5:
        label = label[0]
    if label.ndim != 4:
        raise ValueError(
            'Expected a (batch, D, H, W, C) or (D, H, W, C) prediction, '
            'got shape {}'.format(label.shape))
    return label[..., channel]

# Reduce the raw prediction to its channel-1 volume, which the slice viewer displays as the predicted label.
label = pd_preprocessing(pd=pd)
print('Done')

# Testing Model Prediction 

In [None]:
# Run to view an interactive slider plot to view all slices of raw image, prediction, and GT label
def normal_plot(x):
    """Plot slice ``x`` of the raw image, predicted label and GT label side by side.

    Reads the notebook globals ``img``, ``label`` and ``gt`` (set by the
    prediction and preprocessing cells) and indexes each along its first axis.

    Args:
        x: Slice index along the first axis of each volume.

    Returns:
        ``x`` unchanged; ipywidgets ``interact`` displays this value below the plot.
    """
    fig = plt.figure(figsize=(22, 17))
    fig.suptitle('U-Net Output', fontsize=40)
    panels = (
        (231, 'Raw Image', img),
        (232, 'Predicted Label', label),
        (233, 'GT Label', gt),
    )
    for position, title, volume in panels:
        ax = fig.add_subplot(position)
        ax.set_title(title, fontsize=30)
        image = ax.imshow(volume[x], interpolation='none')
        # Put each colorbar on its own axis appended to the right of its panel.
        cax = make_axes_locatable(ax).append_axes('right', size='5%', pad=0.05)
        fig.colorbar(image, cax=cax, orientation='vertical')
    # Explicitly render the figure inside interact's output area on every slider change.
    plt.show()
    return x
# Bound the slider by the shortest volume so every panel can index slice x; `;` hides the returned function repr.
interact(normal_plot, x=(0, min(len(img), len(label), len(gt)) - 1));
