In [None]:
import os
from glob import glob
import cv2
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt
from unet.model.architecture import get_unet_model
from unet.model.preprocessing import load_data

# Load the data

In [None]:
# Training images and their segmentation masks from the project loader.
Xtrain, ytrain = load_data()
# Images and masks are paired purely by position (the example cell picks the same
# index from both), so the two collections must be the same length.
assert len(Xtrain) == len(ytrain), (
    "load_data returned {} images but {} masks".format(len(Xtrain), len(ytrain)))

# Load the model and make predictions

In [None]:
# Locate the most recent checkpoint written during training.
checkpoint_dir = '../model/unet_saved_model/'
latest = tf.train.latest_checkpoint(checkpoint_dir)
# latest_checkpoint returns None when the directory holds no checkpoint; fail here
# with an actionable message instead of handing None to load_weights.
if latest is None:
    raise FileNotFoundError(
        "No checkpoint found in {}".format(os.path.abspath(checkpoint_dir)))
print("Latest file: {}".format(latest))

# Rebuild the architecture and restore the trained weights into it.
# NOTE(review): dropout_rate/batchnorm presumably must match the configuration the
# checkpoint was trained with for the weights to load — TODO confirm.
model = get_unet_model(dropout_rate=0., batchnorm=False)
model.load_weights(latest)

In [None]:
# Training sample to visualise; change this to inspect a different example.
EXAMPLE_INDEX = 4
# The `:1` slice keeps only the first channel but preserves the channel axis, so
# each result is 3-D with a trailing axis of size 1 (presumably (H, W, 1), the
# layout the model expects — TODO confirm against get_unet_model).
example_image = Xtrain[EXAMPLE_INDEX][:, :, :1]
example_target = ytrain[EXAMPLE_INDEX][:, :, :1]

In [None]:
# Add a leading batch axis of size 1 and divide by 255 before predicting.
# NOTE(review): assumes load_data yields raw 0-255 pixel values and that training
# used the same scaling — TODO confirm.
batch = example_image[np.newaxis] / 255
prediction = model.predict(batch)[0]  # unwrap the single sample from the batch

# Binarise the first output channel: 255.0 where the score exceeds the threshold,
# 0.0 elsewhere (float64, so it can be shown next to the raw prediction).
mask_threshold = 0.5
prediction_mask = np.where(prediction[:, :, 0] > mask_threshold, 255.0, 0.0)

In [None]:
fig, (ax1, ax2, ax3, ax4) = plt.subplots(1, 4, figsize=(25, 10))

# Colour limits are pinned on the mask/probability panels. With imshow's default
# autoscaling, a constant image (e.g. an all-foreground mask) has vmin == vmax and
# is drawn entirely black, indistinguishable from an empty mask.
panels = [
    # The input is continuous data, so autoscaling is appropriate here.
    ("Input image", example_image[:, :, 0], dict(cmap='seismic')),
    # The target's scale (0/1 vs 0/255) isn't visible here; max(1, ...) gives a
    # sensible upper limit for either convention and for an all-zero target.
    ("True target/mask", example_target[:, :, 0],
     dict(cmap='gray', vmin=0, vmax=max(1, example_target.max()))),
    # NOTE(review): assumes the model outputs probabilities in [0, 1] (it is
    # thresholded at 0.5) — TODO confirm the final activation.
    ("Raw prediction", prediction[:, :, 0],
     dict(cmap='gray', vmin=0., vmax=1.)),
    # prediction_mask is built to contain only 0 and 255.
    ("Prediction mask (threshold 0.5)", prediction_mask,
     dict(cmap='gray', vmin=0, vmax=255)),
]
for ax, (title, image, style) in zip((ax1, ax2, ax3, ax4), panels):
    ax.set_title(title, fontsize=14)
    ax.imshow(image, interpolation='bilinear', **style)
plt.show()