# Prediction notebook

----------------------
This notebook runs predictions on large images by splitting them into tiles, either without smooth blending (plain non-overlapping tiles) or with smooth blending (overlapping windows whose predictions are blended).

To do:

- Smooth blending currently only works on very small images (the input is downscaled to 1% before prediction). Why? Can this be improved?
- The predict function currently does not work with pretrained versions of the model.

In [None]:
from patchify import patchify, unpatchify
import cv2
import numpy as np
from matplotlib import pyplot as plt

In [None]:
# Predict on a large image: load it as single-channel grayscale (flag 0) and
# split it into non-overlapping square tiles.
PATCH_SIZE = 512
image_path = '../data/predictions/pred_test_2.tif'
large_image = cv2.imread(image_path, 0)
if large_image is None:
    # cv2.imread does not raise on a missing/unreadable file; it returns None.
    raise FileNotFoundError(f"Could not read image '{image_path}'")

# With step == PATCH_SIZE, patchify silently drops any remainder at the right
# and bottom edges, and unpatchify refuses a target shape that the tiles do not
# cover exactly. Crop to a whole number of tiles so both agree; the dropped
# border was never predicted anyway.
crop_h = (large_image.shape[0] // PATCH_SIZE) * PATCH_SIZE
crop_w = (large_image.shape[1] // PATCH_SIZE) * PATCH_SIZE
if crop_h == 0 or crop_w == 0:
    raise ValueError(
        f"Image {large_image.shape} is smaller than one {PATCH_SIZE}x{PATCH_SIZE} tile"
    )
large_image = large_image[:crop_h, :crop_w]

# Result shape: (n_tile_rows, n_tile_cols, PATCH_SIZE, PATCH_SIZE)
patches = patchify(large_image, (PATCH_SIZE, PATCH_SIZE), step=PATCH_SIZE)
print("Large image shape is: ", large_image.shape)
print("Patches array shape is: ", patches.shape)

In [None]:
# Preview the whole input image in grayscale on a single 9x9-inch axes.
fig, ax = plt.subplots(figsize=(9, 9))
ax.imshow(large_image, cmap='gray')

In [None]:
# Show the top-left tiles of the input image in a grid of at most 6x6.
# The grid is clamped to the number of tiles actually available, so images
# with fewer than 6 tile rows/columns no longer raise IndexError.
square = 6
n_show_rows = min(square, patches.shape[0])
n_show_cols = min(square, patches.shape[1])

plt.figure(figsize=(9, 9))
for i in range(n_show_rows):
    for j in range(n_show_cols):
        # Subplot indices are 1-based and run row by row; hide the axis ticks.
        ax = plt.subplot(n_show_rows, n_show_cols, i * n_show_cols + j + 1)
        ax.set_xticks([])
        ax.set_yticks([])
        plt.imshow(patches[i, j, :, :], cmap='gray')
plt.show()

In [None]:
# Restore the previously trained Keras model from its HDF5 checkpoint.
# compile=False: the model is not compiled after loading; model.predict
# (all this notebook uses it for) works without compilation.
from keras.models import load_model

model_path = "../models/05_27_onesample_10epochs.hdf5"
model = load_model(model_path, compile=False)

In [None]:
# Run the model over every tile and keep a per-pixel class label per tile.
# Tiles are scaled to [0, 1] and given a trailing single-channel axis, so the
# model sees (batch, h, w, 1); its output is reduced with argmax over axis 3
# (the class axis). One model.predict call handles a whole row of tiles,
# instead of one call per tile.
n_rows, n_cols, patch_h, patch_w = patches.shape  # tile size taken from the data, not hard-coded
predicted_patches_reshaped = np.empty((n_rows, n_cols, patch_h, patch_w), dtype=np.intp)
for i in range(n_rows):
    print("Now predicting on patch row", i)
    # (n_cols, h, w) -> (n_cols, h, w, 1)
    row_input = np.expand_dims(patches[i] / 255., axis=-1)
    row_prediction = model.predict(row_input)
    predicted_patches_reshaped[i] = np.argmax(row_prediction, axis=3)

# Flat (n_tiles, h, w) stack, as the original cell also produced.
predicted_patches = predicted_patches_reshaped.reshape(-1, patch_h, patch_w)

In [None]:
# Display the tile grid shape: (n_tile_rows, n_tile_cols, tile_h, tile_w).
np.shape(predicted_patches_reshaped)

In [None]:
# Load a mask image as single-channel grayscale (flag 0).
# NOTE(review): `mask` is not used by any of the cells shown here.
# cv2.imread returns None instead of raising on a missing/unreadable file;
# warn (rather than fail) so an absent mask does not stop the notebook.
import warnings

mask_path = '../data/original_data/masks/mask.tif'
mask = cv2.imread(mask_path, 0)
if mask is None:
    warnings.warn(f"Could not read mask image '{mask_path}'; `mask` is None")

In [None]:
# Show the top-left predicted label tiles in a grid of at most 6x6.
# The grid is clamped to the number of tiles actually available, so images
# with fewer than 6 tile rows/columns no longer raise IndexError.
square = 6
n_show_rows = min(square, predicted_patches_reshaped.shape[0])
n_show_cols = min(square, predicted_patches_reshaped.shape[1])

plt.figure(figsize=(9, 9))
for i in range(n_show_rows):
    for j in range(n_show_cols):
        # Subplot indices are 1-based and run row by row; hide the axis ticks.
        ax = plt.subplot(n_show_rows, n_show_cols, i * n_show_cols + j + 1)
        ax.set_xticks([])
        ax.set_yticks([])
        plt.imshow(predicted_patches_reshaped[i, j, :, :], cmap='gray')
plt.show()

In [None]:
# Stitch the predicted tiles back into one label image.
# The tiles are non-overlapping (patchify step == tile size), so swapping the
# tile-column and pixel-row axes and flattening reassembles them exactly:
#   (R, C, h, w) -> (R, h, C, w) -> (R*h, C*w)
# Unlike unpatchify(..., large_image.shape), this does not require the image
# to be a whole multiple of the tile size. Any right/bottom remainder that
# patchify dropped is simply not part of the result.
n_tile_rows, n_tile_cols, tile_h, tile_w = predicted_patches_reshaped.shape
reconstructed_image = (
    predicted_patches_reshaped
    .transpose(0, 2, 1, 3)
    .reshape(n_tile_rows * tile_h, n_tile_cols * tile_w)
)

In [None]:
# Input image and stitched tile-wise prediction, side by side in the top row
# of a 2x2 grid, both rendered in grayscale.
plt.figure(figsize=(20, 20))
for position, title, image in (
    (221, 'Original Image', large_image),
    (222, 'Prediction', reconstructed_image),
):
    plt.subplot(position)
    plt.title(title)
    plt.imshow(image, cmap='gray')
plt.show()

### Predict on large image with smooth blending

In [None]:
# Re-read the full-resolution input as single-channel grayscale (flag 0).
image_path = '../data/predictions/pred_test_2.tif'
large_image = cv2.imread(image_path, 0)
if large_image is None:
    # cv2.imread does not raise on a missing/unreadable file; it returns None.
    raise FileNotFoundError(f"Could not read image '{image_path}'")

# Scale to [0, 1] and add a trailing channel axis -> (h, w, 1), the same
# preprocessing the tiled prediction applies per tile.
large_image_scaled = large_image / 255.
large_image_scaled = np.expand_dims(large_image_scaled, axis=2)

large_image_scaled.shape


In [None]:
# Target size for the smooth-blending run, as a percentage of the original.
scale_percent = 1  # percent of original size
# int() truncates, so small images could end up with a 0-pixel side, which
# cv2.resize rejects; clamp each side to at least 1 pixel.
width = max(1, int(large_image.shape[1] * scale_percent / 100))
height = max(1, int(large_image.shape[0] * scale_percent / 100))
dim = (width, height)  # cv2.resize takes dsize as (width, height)

In [None]:
# Downscale to `dim`; INTER_AREA is OpenCV's recommended mode for shrinking.
smaller_image = cv2.resize(src=large_image, dsize=dim, interpolation=cv2.INTER_AREA)

In [None]:
# Scale to [0, 1] and append a single channel axis -> (h, w, 1), the (x, y, c)
# layout the smooth-windowing predictor takes.
smaller_image_scaled = (smaller_image / 255.)[..., np.newaxis]

In [None]:
# Report the (height, width, channels) shape passed to the smooth predictor.
resized_shape = smaller_image_scaled.shape
print('Resized Dimensions : ', resized_shape)

In [None]:
# Window edge length in pixels, and the number of predicted classes
# (presumably must equal the model's output channels — TODO confirm).
patch_size, n_classes = 512, 4

In [None]:
# Restore the same Keras checkpoint again, presumably so this section can be
# run without the tiled section above. compile=False: the model is not
# compiled after loading, which model.predict does not need.
from keras.models import load_model

model_path = "../models/05_27_onesample_10epochs.hdf5"
model = load_model(model_path, compile=False)

In [None]:
# Smooth tiled prediction. predict_img_with_smooth_windowing processes the
# whole image 8-fold by tiling small overlapping windows, and calls
# `pred_func` once with all of them stacked along a leading batch axis.
# pred_func therefore receives a 4D (batch, x, y, channels) array, which is
# exactly what Keras' model.predict accepts.
from smooth_tiled_pred import predict_img_with_smooth_windowing


def _smooth_pred_func(img_batch_subdiv):
    """Predict class scores for a batch of windows with the loaded model."""
    return model.predict(img_batch_subdiv)


predictions_smooth = predict_img_with_smooth_windowing(
    smaller_image_scaled,   # one (x, y, c) image, NOT a batch (n, x, y, c)
    window_size=patch_size,
    subdivisions=2,         # minimal overlap between windows; must be even
    nb_classes=n_classes,
    pred_func=_smooth_pred_func,
)

In [None]:
# Report the shape of the blended per-class score map.
print(np.shape(predictions_smooth))

In [None]:
# Per-pixel label: index of the highest score along axis 2 (presumably the
# nb_classes axis of the (x, y, classes) output — TODO confirm).
final_prediction = predictions_smooth.argmax(axis=2)

In [None]:
# Side-by-side comparison of the input and both label maps.
# imshow normalises each image to its own min/max by default, so a class that
# is missing from one map would shift every colour in that panel. Pinning the
# colour range to [0, n_classes - 1] gives each class the same colour in both
# panels. Note that the smooth-blended map is at the downscaled resolution
# (scale_percent), while the tiled one is at full resolution.
label_range = dict(vmin=0, vmax=n_classes - 1)

plt.figure(figsize=(20, 10))
plt.subplot(131)
plt.title('Testing Image')
plt.imshow(large_image, cmap='gray')
plt.subplot(132)
plt.title('Prediction without smooth blending')
plt.imshow(reconstructed_image, **label_range)
plt.subplot(133)
plt.title('Prediction with smooth blending')
plt.imshow(final_prediction, **label_range)
plt.show()