In [None]:
import os
from glob import glob
import cv2
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt
from unet.model.constants import *
from unet.model.architecture import get_unet_model
from unet.model.preprocessing import load_data

# Load the data

In [None]:
# Load the training inputs and their target masks (used below as the "true" targets).
(Xtrain, ytrain) = load_data()

# Load the model and compare predictions from each epoch's checkpoint

In [None]:
def predict_from_checkpoint(img, epoch=0, threshold=0.5, net=None):
    """Load the weights saved for ``epoch`` and predict on a single image.

    The checkpoint is loaded into the model in place, so after this call the
    model holds the weights for ``epoch``.

    Args:
        img (Union[list, np.ndarray]): One input image with raw pixel values.
            It is divided by 255 before prediction (assumes a 0-255 input
            range — confirm against the training preprocessing).
        epoch (int): Epoch number substituted into ``CHECKPOINT_PATH``.
        threshold (float): Cut-off applied to channel 0 of the raw prediction
            when building the binary mask. Defaults to 0.5.
        net (keras.Model, optional): Model to load the weights into and predict
            with. When omitted, the notebook-level ``model`` is used.

    Returns:
        tuple[np.ndarray, np.ndarray]: ``(prediction, prediction_mask)``.
        ``prediction`` is the raw model output for ``img`` with the batch axis
        removed. ``prediction_mask`` is a float array shaped like
        ``prediction[:, :, 0]`` that holds 255 where channel 0 exceeds
        ``threshold`` and 0 elsewhere.
    """
    # Fall back to the notebook-level model only when no explicit model is given.
    net = model if net is None else net
    net.load_weights(CHECKPOINT_PATH.format(epoch=epoch))

    # Wrap the single image in a batch of size 1, rescale, and unwrap the output.
    batch = np.array([img]) / 255
    # verbose=0: avoid a progress bar per call when looping over many checkpoints.
    prediction = net.predict(batch, verbose=0)[0]

    prediction_mask = np.where(prediction[:, :, 0] > threshold, 255.0, 0.0)
    return prediction, prediction_mask


def plot_true_with_prediction(ax1, ax2, ax3, ax4, true_img, true_target, raw_prediction, predicted_target, threshold=0.5):
    """Plot an input image, its true mask and the model's prediction side by side.

    Draws four panels:
    1. the input image (channel 0)
    2. true target (channel 0)
    3. raw prediction, no threshold applied (channel 0)
    4. prediction mask

    Args:
        ax1 (AxesSubplot): Axes for the input image.
        ax2 (AxesSubplot): Axes for the true target/mask.
        ax3 (AxesSubplot): Axes for the raw prediction.
        ax4 (AxesSubplot): Axes for the thresholded prediction mask.
        true_img (np.ndarray): Input image; indexed as ``[:, :, 0]``, so it
            needs at least three axes.
        true_target (np.ndarray): True mask; indexed as ``[:, :, 0]``.
        raw_prediction (np.ndarray): Raw model output; indexed as ``[:, :, 0]``.
        predicted_target (np.ndarray): 2-D thresholded mask, shown as-is.
        threshold (float): Threshold shown in the mask panel's title. It is
            display-only and should match the value used to build
            ``predicted_target``. Defaults to 0.5.
    """
    # Use the arguments rather than the notebook globals so the function plots
    # whatever it is given.
    ax1.set_title("Input image", fontsize=14)
    ax1.imshow(true_img[:, :, 0], cmap='seismic', interpolation='bilinear')
    ax2.set_title("True target/mask", fontsize=14)
    ax2.imshow(true_target[:, :, 0], cmap='gray', interpolation='bilinear')
    ax3.set_title("Raw prediction", fontsize=14)
    ax3.imshow(raw_prediction[:, :, 0], cmap='gray', interpolation='bilinear')
    ax4.set_title(f"Prediction mask (threshold {threshold})", fontsize=14)
    ax4.imshow(predicted_target, cmap='gray', interpolation='bilinear')

In [None]:
# Build the U-Net with dropout disabled and no batch normalization. The layer
# layout presumably has to match the one the checkpoints were saved from so
# load_weights succeeds — confirm against the training setup.
model = get_unet_model(batchnorm=False, dropout_rate=0.0)

In [None]:
# Index of the training sample used to compare predictions across checkpoints.
SAMPLE_INDEX = 5
# Keep only the first channel while preserving a trailing channel axis of size 1.
example_image = Xtrain[SAMPLE_INDEX][:, :, :1]
example_target = ytrain[SAMPLE_INDEX][:, :, :1]

In [None]:
# One row per checkpoint: epochs 0 .. nr_rows - 1 (assumes a checkpoint file
# exists for each of those epochs).
nr_rows = 11
# squeeze=False keeps `axes` 2-D, so the row unpacking below also works for nr_rows == 1.
fig, axes = plt.subplots(nr_rows, 4, figsize=(25, 7 * nr_rows), squeeze=False)

for epoch, (ax1, ax2, ax3, ax4) in enumerate(axes):
    prediction, prediction_mask = predict_from_checkpoint(example_image, epoch=epoch)
    plot_true_with_prediction(ax1, ax2, ax3, ax4, example_image, example_target, prediction, prediction_mask)
    # Label each row so the figure shows which checkpoint produced it.
    ax1.set_ylabel(f"Epoch {epoch}", fontsize=14)