# Prediction for 3D U-Net model

### Installing the patchify library

In [2]:
#Use patchify to break large volumes into smaller for training 
#and also to put patches back together after prediction.
!pip install patchify



### GPU availability 

In [4]:
import tensorflow as tf

# Stop the notebook here unless TensorFlow reports a GPU at '/device:GPU:0'.
gpu_name = tf.test.gpu_device_name()
if gpu_name == '/device:GPU:0':
    print(f'Found GPU at: {gpu_name}')
else:
    raise SystemError('GPU device not found')

SystemError: GPU device not found

### Loading the four image/mask pairs (TIFF files)

In [4]:
from skimage import io
from patchify import patchify, unpatchify
import numpy as np
from matplotlib import pyplot as plt
from keras import backend as K
from tensorflow.keras.utils import to_categorical
from sklearn.model_selection import train_test_split

In [13]:
# Load the input image volumes and split each one into 3D patches.
# Each volume is 180x1024x1024 voxels (depth x height x width). It is cut into
# patches of 16x128x128 with a step of 5x128x128: consecutive patches along the
# depth axis overlap by 16 - 5 = 11 slices, while patches in the height/width
# plane do not overlap.

# IMPORTANT: change the file paths below if the tif files are not stored in the /images folder.
# All four volumes are read from the same folder (the masks likewise all live in /masks).
image_paths = [
    '/images/wt_pom1D_01_07_R3D_REF_image.tif',
    '/images/wt_pom1D_01_15_R3D_REF_image.tif',
    '/images/wt_pom1D_01_20_R3D_REF_image.tif',
    '/images/wt_pom1D_01_30_R3D_REF_image.tif',
]
patch_shape = (16, 128, 128)
patch_step = (5, 128, 128)

image = [io.imread(path) for path in image_paths]
img_patches = [patchify(volume, patch_shape, step=patch_step) for volume in image]

In [None]:
# Load the four ground-truth mask volumes (same order as the images) and split
# them into patches with exactly the same patch shape and step as the images,
# so that mask patch k lines up with image patch k.
mask_paths = [
    '/masks/wt_pom1D_01_07_R3D_REF_mask.tif',
    '/masks/wt_pom1D_01_15_R3D_REF_mask.tif',
    '/masks/wt_pom1D_01_20_R3D_REF_mask.tif',
    '/masks/wt_pom1D_01_30_R3D_REF_mask.tif',
]

mask = [io.imread(path) for path in mask_paths]

# The step must equal the one used for the images (5x128x128) for every volume;
# any other step yields a different patch grid and a different number of patches.
mask_patches = [patchify(volume, (16, 128, 128), step=(5, 128, 128)) for volume in mask]

## Data preprocessing

### Reshape the inputs

In [None]:
# Reshape the patches into an input image array and an input mask array of shape
# (N, D, H, W): N is the total number of patches over all volumes, and D, H, W are
# the depth, height and width of a patch.
# The per-volume stacks are joined with np.concatenate: `+=` on NumPy arrays would
# add them element-wise instead of appending the patches.
def _as_patch_stack(patches):
    """Reshape patchify output, whose axes 3, 4, 5 are the patch D/H/W, into an (N, D, H, W) stack."""
    return np.reshape(patches, (-1, patches.shape[3], patches.shape[4], patches.shape[5]))


input_img = np.concatenate([_as_patch_stack(p) for p in img_patches], axis=0)
input_mask = np.concatenate([_as_patch_stack(p) for p in mask_patches], axis=0)

print(input_img.shape)
print(input_mask.shape)

### Removing empty patches

In [None]:
# Drop the patches whose mask is entirely zero. idx_img holds the indices of the
# patches to KEEP (mask has at least one non-zero voxel); the same indices are
# applied to images and masks so that every pair stays aligned.
idx_img = np.flatnonzero(input_mask.any(axis=(1, 2, 3)))

input_img = input_img[idx_img]
input_mask = input_mask[idx_img]

# Number of patches kept (identical for images and masks).
print(input_img.shape[0])
print(input_mask.shape[0])

In [22]:
# Scale the pixel intensities by dividing by the maximum intensity over all kept
# patches (gives values in [0, 1] assuming non-negative intensities — TODO confirm).
train_img = input_img / input_img.max()

# Add a trailing channel axis of size 1 to match the input of the model.
train_img = np.expand_dims(train_img, axis=4)

# Binary segmentation: every mask value above 1 becomes the foreground class 1.
# np.where builds a new array; binarizing after np.expand_dims (which returns a
# view) would also overwrite input_mask in place.
train_mask = np.expand_dims(np.where(input_mask > 1, 1, input_mask), axis=4)

# One-hot encode the masks with to_categorical, using 2 classes.
n_classes = 2
train_mask_cat = to_categorical(train_mask, num_classes=n_classes)

### Random split into training and test sets

In [3]:
# 80/20 split of the patches; random_state=0 makes the shuffle reproducible.
# NOTE(review): the evaluation below is only held-out if this matches the data,
# test_size and random_state used when the model was trained — verify.
X_train, X_test, y_train, y_test = train_test_split(
    train_img,
    train_mask_cat,
    test_size=0.20,
    random_state=0,
)
for split_array in (X_train, X_test):
    print(split_array.shape)

NameError: name 'train_test_split' is not defined

### Prediction with the model we trained

In [None]:
import numpy as np
from keras.models import load_model

# Load the pre-trained model; it is only used for inference here, so it is
# loaded with compile=False.
# Change this path if the model file is stored elsewhere, or copy it into /saved_models.
model_path = '/saved_models/3dunetmodel_leaky_bs4_16x128x128.h5'
model = load_model(model_path, compile=False)

# Predict on the test patches (presumably one probability channel per class on the last axis).
y_pred = model.predict(X_test)

# Collapse the class axis (axis 4) into per-voxel label maps for prediction and ground truth.
y_pred_argmax, y_test_argmax = (np.argmax(scores, axis=4) for scores in (y_pred, y_test))

for item in (y_pred_argmax.shape, y_test_argmax.shape, np.unique(y_pred_argmax)):
    print(item)

### Mean IoU 

In [None]:
from keras.metrics import MeanIoU

# Keras' built-in mean Intersection-over-Union metric (requires TF > 2.0).
# Ground-truth labels are passed first, predictions second.
n_classes = 2
iou_metric = MeanIoU(num_classes=n_classes)
iou_metric.update_state(y_test_argmax, y_pred_argmax)
mean_iou = iou_metric.result().numpy()
print("Mean IoU =", mean_iou)

### Testing on a random test patch

In [None]:
# Pick one random test patch, predict it with the loaded model and recover its
# ground-truth label map.
import random

test_img_number = random.randint(0, len(X_test) - 1)
test_img = X_test[test_img_number]
ground_truth = y_test[test_img_number]

# The model predicts on batches, so wrap the single patch as a batch of one.
test_img_input = np.expand_dims(test_img, 0)

# Use the model loaded with load_model in the prediction cell.
test_pred = model.predict(test_img_input)
# Collapse the class axis (4) and drop the batch axis -> (D, H, W) label map.
test_prediction = np.argmax(test_pred, axis=4)[0, :, :, :]

# One-hot ground truth with classes on the last axis (3) -> (D, H, W) label map.
ground_truth_argmax = np.argmax(ground_truth, axis=3)
print(ground_truth_argmax.shape)

#### Plotting the test image, the ground-truth mask and the prediction

In [None]:
# Show one random depth slice of the test patch next to its ground-truth mask
# and the prediction. The index is named slice_idx so the built-in `slice` is
# not shadowed for the rest of the notebook.
slice_idx = random.randint(0, ground_truth_argmax.shape[0] - 1)

# (subplot position, title, 2D image, extra imshow options)
panels = [
    (231, 'Testing Image', test_img[slice_idx, :, :, 0], {'cmap': 'gray'}),
    (232, 'Testing Label', ground_truth_argmax[slice_idx, :, :], {}),
    (233, 'Prediction on test image', test_prediction[slice_idx, :, :], {}),
]

plt.figure(figsize=(12, 8))
for position, title, panel_data, imshow_kwargs in panels:
    plt.subplot(position)
    plt.title(title)
    plt.imshow(panel_data, **imshow_kwargs)
plt.show()