In [None]:
import numpy as np
import random

from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.image import load_img
import tensorflow as tf
import os
from PIL import Image
from keras import backend as K
from matplotlib import pyplot as plt
from skimage import transform
from skimage.io import imread, imshow, imread_collection, concatenate_images


def dice_coef(y_true, y_pred, smooth=1):
    """Return the smoothed Dice coefficient between ``y_true`` and ``y_pred``.

    Both inputs are flattened completely, so the score is pooled over every
    element of the batch rather than averaged per sample.

    Args:
        y_true: ground-truth mask tensor or array-like. Integer/bool masks are
            accepted; they are cast to ``y_pred``'s dtype before use.
        y_pred: predicted tensor or array-like (presumably probabilities in
            [0, 1] — confirm against the model's output activation).
        smooth: additive smoothing term. It keeps the ratio defined (it
            evaluates to 1) when both inputs are all zeros.

    Returns:
        Scalar tensor ``(2 * sum(t * p) + smooth) / (sum(t) + sum(p) + smooth)``.
    """
    # TensorFlow never promotes dtypes implicitly; multiplying a uint8/int mask
    # by float predictions would raise. Cast the same way Keras' own losses do.
    y_pred = tf.convert_to_tensor(y_pred)
    y_true = tf.cast(y_true, y_pred.dtype)

    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    return (2.0 * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)


def soft_dice_loss(y_true, y_pred):
    """Return the soft Dice loss, i.e. one minus the Dice coefficient.

    The value shrinks toward 0 as the overlap between ``y_true`` and
    ``y_pred`` grows, which makes it suitable for minimization. It uses
    ``dice_coef`` with its default smoothing.
    """
    overlap_score = dice_coef(y_true, y_pred)
    return 1 - overlap_score


def iou_coef(y_true, y_pred, smooth=1):
    """Return the smoothed intersection-over-union, averaged over the batch.

    IoU is computed separately for each sample (reducing over every axis
    except axis 0) and then averaged over axis 0.

    Args:
        y_true: ground-truth mask tensor or array-like with a leading batch
            axis. Integer/bool masks are cast to ``y_pred``'s dtype.
        y_pred: predicted tensor or array-like with the same shape as
            ``y_true``.
        smooth: additive smoothing term. It keeps each per-sample ratio
            defined (it evaluates to 1) when a sample is empty in both inputs.

    Returns:
        Scalar tensor: the mean over the batch of
        ``(intersection + smooth) / (union + smooth)``.
    """
    # Keras losses/metrics do not promote dtypes for us; an int mask times a
    # float prediction would raise in TensorFlow.
    y_pred = tf.convert_to_tensor(y_pred)
    y_true = tf.cast(y_true, y_pred.dtype)

    # Reduce over every non-batch axis. For rank-4 input this is exactly the
    # previously hard-coded [1, 2, 3] (presumably height, width, channels).
    axes = list(range(1, K.ndim(y_pred)))

    intersection = K.sum(K.abs(y_true * y_pred), axis=axes)
    union = K.sum(y_true, axes) + K.sum(y_pred, axes) - intersection
    return K.mean((intersection + smooth) / (union + smooth), axis=0)


# Data: the same 1% validation split of the test Images and Masks folders.
# class_mode=None makes each iterator yield only image batches (class labels are
# meaningless for segmentation). shuffle=False keeps the iterators' sorted
# filename order, so sample i of the images lines up with sample i of the masks.
# NOTE(review): this pairing assumes Images/ and Masks/ hold matching filenames
# in matching sub-folders — confirm against the dataset layout.
image_generator = tf.keras.preprocessing.image.ImageDataGenerator(validation_split=0.01)
test_images = image_generator.flow_from_directory(
    directory="../Data/Test/Images",
    subset="validation",
    class_mode=None,
    shuffle=False,
)

test_masks = image_generator.flow_from_directory(
    directory="../Data/Test/Masks",
    subset="validation",
    class_mode=None,
    shuffle=False,
)

if test_images.samples == 0:
    raise ValueError("no test images found under ../Data/Test/Images")
if test_images.samples != test_masks.samples:
    raise ValueError(
        f"{test_images.samples} test images but {test_masks.samples} masks; "
        "they must pair up one-to-one"
    )


def _stack_batches(iterator):
    """Read every batch of a Keras directory iterator into one array."""
    return np.concatenate([iterator[b] for b in range(len(iterator))], axis=0)


# Materialize the samples once, so that images, masks and predictions all share
# the same per-sample index.
images = _stack_batches(test_images)
masks = _stack_batches(test_masks)


# Custom losses/metrics the saved model references must be supplied by name.
# NOTE(review): "dice_coef_loss" and "dice_loss" are bound to dice_coef, which
# grows with overlap (a similarity, not a loss). That is harmless for predict(),
# but evaluate()/fit() on this model would report or optimize the wrong
# quantity. Confirm what the training script defined under these names.
model = load_model(
    "../Models/road_mapper_final.h5",
    custom_objects={
        "soft_dice_loss": soft_dice_loss,
        "iou_coef": iou_coef,
        "dice_coef_loss": dice_coef,
        "dice_loss": dice_coef,
    },
)

# Predict on the materialized array rather than the iterator, so prediction i
# is guaranteed to belong to images[i]. No rescale is configured on
# image_generator, so pixels are fed exactly as loaded.
predictions = model.predict(images, verbose=1)

thresh_val = 0.1
predicton_threshold = (predictions > thresh_val).astype(np.uint8)

# Show up to two distinct random samples. random.sample only draws valid
# indices; random.randint(0, len) is inclusive and could run past the end.
num_samples = min(2, len(predictions))
sample_indices = random.sample(range(len(predictions)), num_samples)

f = plt.figure(figsize=(15, 25))
for row, ix in enumerate(sample_indices):
    # Each row holds 4 panels; subplot positions are 1-based.
    first_panel = row * 4 + 1
    panels = (
        (images[ix].astype(np.uint8), "Image"),
        (np.squeeze(masks[ix][:, :, 0]), "Ground Truth"),
        (np.squeeze(predictions[ix][:, :, 0]), "Prediction"),
        (np.squeeze(predicton_threshold[ix][:, :, 0]), f"thresholded at {thresh_val}"),
    )
    for offset, (data, title) in enumerate(panels):
        f.add_subplot(num_samples, 4, first_panel + offset)
        plt.imshow(data, cmap="gray")  # cmap is ignored for the RGB image panel
        plt.title(title)
        plt.axis("off")

plt.show()
