# Prerequisites before execution

In [None]:
# Mount Google Drive (harmless if it is already mounted); every path configured
# in the parameters cell lives under this mount point.
from google.colab import drive

drive.mount(mountpoint='/content/drive')

In [None]:
# ...:: PARAMETERS : FILL THEN RUN ALL :::...
# ...:: PATH/NAME SECTION

# Root of the Colab workspace on the mounted Drive; every directory below is built from it
baseDir = '/content/drive/My Drive/Colab Notebooks/'
# Directory containing the datasets
datasetDir = baseDir + 'datasets/'
# Name of the dataset used
datasetName = 'multiple_256_tag_float32'
# Directory containing the saved models
modelDir = baseDir + 'models/'
# File name of the saved model to load
modelName = 'unet_model_multilabels_256_meaniou_best.keras'
# Directory containing the sample images to run predictions on
sourceImageDirectory = baseDir + 'image_sample/'

# Package loading

In [None]:
# ...:: NEEDED PACKAGE ::...

import os
from time import time

import cv2
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow import keras

from google.colab.patches import cv2_imshow

# Model loading

In [None]:
def dice_loss(y_true, y_pred):
    """Smoothed Dice loss reduced over the last axis.

    This function is handed to `load_model` through `custom_objects` so the
    saved model can resolve the name 'dice_loss'.

    Args:
        y_true: ground truth; cast to float32 before use.
        y_pred: predictions, combined element-wise with the cast ground truth
            (no cast is applied to it).

    Returns:
        1 - (2 * sum(y_true * y_pred) + 1) / (sum(y_true + y_pred) + 1),
        with both sums taken along axis -1.
    """
    truth = tf.cast(y_true, tf.float32)
    # +1 smoothing on both sides avoids division by zero on empty masks
    smooth = 1
    doubled_overlap = 2 * tf.reduce_sum(truth * y_pred, axis=-1)
    total_mass = tf.reduce_sum(truth + y_pred, axis=-1)
    dice_score = (doubled_overlap + smooth) / (total_mass + smooth)
    return 1 - dice_score

# Maps the custom loss name stored with the model to its Python implementation
custom_objects = dict(dice_loss=dice_loss)

# Load the model, supplying the custom objects (presumably the model was compiled with dice_loss)
unet_model = tf.keras.models.load_model(
    os.path.join(modelDir, modelName),
    custom_objects=custom_objects,
)

# Unit prediction

In [None]:
# Palette for the argmax class map. The lowest value (presumably the -1 written
# by put_bar on black pixels) gets 'white', then one color per class index.
cmap = mcolors.ListedColormap(colors=['white', 'red', 'blue', 'green', 'yellow'])
# Optional explicit normalization, currently disabled:
# bounds = [-1.5, -0.5, 0.5, 1.5, 2.5, 3.5]
# norm = mcolors.BoundaryNorm(bounds, cmap.N)

# Two-color palette for the per-class binary masks (low -> white, high -> blue)
cmap_bin = mcolors.ListedColormap(colors=['white', 'blue'])

# Sample image (looked up in sourceImageDirectory) used for the unit prediction
file = '00dec6a.jpg'

In [None]:
# ...:: FUNCTION SECTION ::...

def create_masks(pred_mask, threshold=0.25):
  """Binarize every channel of the first batch item of a multi-label prediction.

  Args:
    pred_mask: prediction with a leading batch axis and channels on the last
      axis (presumably (batch, H, W, classes) from unet_model.predict - confirm).
    threshold: scores strictly greater than this become 1, all others 0.
      Defaults to 0.25, the value previously hard-coded.

  Returns:
    List with one int32 tensor per channel, each holding the first batch
    item's spatial dimensions plus a trailing axis of size 1.
  """
  masks = []
  for i in range(pred_mask.shape[-1]):
    mask = tf.where(pred_mask[..., i] > threshold, 1, 0)
    # Keep only the first batch item and add a channel axis for array_to_img
    masks.append(mask[0][..., tf.newaxis])
  return masks

def create_mask(pred_mask):
  """Collapse per-class scores of the first batch item into a class-index map.

  Args:
    pred_mask: prediction with a leading batch axis and class scores on the
      last axis.

  Returns:
    int64 tensor of argmax class indices for batch item 0, with a trailing
    axis of size 1 appended.
  """
  class_indices = tf.argmax(pred_mask, axis=-1)
  first_item = class_indices[0]
  return tf.expand_dims(first_item, axis=-1)

def put_bar(image, pred_mask, threshold=0.04, fill_value=-1):
  """Overwrite mask pixels wherever the input image is (near-)black.

  Args:
    image: single, unbatched image; converted with
      tf.image.convert_image_dtype to float32 before thresholding.
    pred_mask: mask to edit (e.g. the output of create_mask); must broadcast
      against `image`.
    threshold: pixels whose float intensity is strictly below this count as
      black. Defaults to 0.04, the value previously hard-coded.
    fill_value: value written into `pred_mask` at black pixels (default -1).

  Returns:
    Tensor equal to `pred_mask` except `fill_value` where the image is black.
  """
  # convert_image_dtype rescales integer images into [0, 1] so `threshold`
  # is compared on a float scale
  image = tf.image.convert_image_dtype(image, tf.float32)
  black_mask = tf.less(image, threshold)
  return tf.where(black_mask, fill_value, pred_mask)

def display_multiple(display_list):
  """Plot the input image next to one binary mask per class, one row per mask.

  Args:
    display_list: [input_image, mask_0, mask_1, ...]; every entry must be a
      3-D array/tensor accepted by tf.keras.utils.array_to_img. The number of
      rows follows the number of masks (4 for Fish/Flower/Gravel/Sugar).

  Raises:
    ValueError: if no mask follows the input image.
  """
  class_names = ["Fish", "Flower", "Gravel", "Sugar"]
  masks = display_list[1:]
  n_rows = len(masks)
  if n_rows == 0:
    raise ValueError("display_multiple needs at least one mask after the input image")

  # squeeze=False keeps `axs` 2-D even when there is a single row;
  # 5 inches per row reproduces the original (20, 20) figure for 4 classes
  fig, axs = plt.subplots(n_rows, 2, figsize=(20, 5 * n_rows), squeeze=False)

  # Convert the input image once; it is repeated in every row of column 0
  input_img = tf.keras.utils.array_to_img(display_list[0])

  for row, mask in enumerate(masks):
    axs[row, 0].imshow(input_img)
    axs[row, 0].axis("off")
    if row == 0:
      axs[row, 0].set_title("Input Image")

    axs[row, 1].imshow(tf.keras.utils.array_to_img(mask), cmap=cmap_bin)
    axs[row, 1].axis("off")
    axs[row, 1].set_title(class_names[row] if row < len(class_names) else f"Class {row}")

  plt.tight_layout()
  plt.show()


def display(display_list, titles=None):
  """Show images side by side in a single row.

  Args:
    display_list: list of 3-D arrays/tensors accepted by
      tf.keras.utils.array_to_img. Entry 0 is drawn as the input image (axis
      hidden); every later entry is drawn as a mask using `cmap`.
    titles: optional panel titles. Defaults to ["Input Image", "Predicted Mask"]
      for two panels and ["Input Image", "True Mask", "Predicted Mask"] for
      three; panels beyond the given titles get a generic one instead of
      raising IndexError.
  """
  if titles is None:
    if len(display_list) == 3:
      titles = ["Input Image", "True Mask", "Predicted Mask"]
    else:
      titles = ["Input Image", "Predicted Mask"]

  plt.figure(figsize=(15, 15))
  for i, item in enumerate(display_list):
    plt.subplot(1, len(display_list), i + 1)
    plt.title(titles[i] if i < len(titles) else f"Panel {i}")

    img = tf.keras.utils.array_to_img(item)
    if i == 0:
      plt.imshow(img)
      plt.axis("off")
    else:
      # Every mask panel (not only the second) gets the class colormap
      plt.imshow(img, cmap=cmap)
  plt.show()

def show_predictions(dataset=None, num=1):
  """Display input image, ground-truth mask and predicted class map.

  Args:
    dataset: optional batched tf.data.Dataset yielding (image, mask) pairs;
      presumably images are (batch, H, W, C) and masks (batch, H, W, classes)
      - confirm against the dataset pipeline. When None, the globals
      `sample_image` / `sample_mask` are used instead (they are not defined in
      the cells shown here, so define them before calling without a dataset).
    num: number of batches to take from `dataset`.
  """
  if dataset is not None:
    for image, mask in dataset.take(num):
      pred_mask = unet_model.predict(image)
      # Average the label channels of the first sample into one (H, W, 1) map;
      # array_to_img needs a 3-D array, so the batch axis must be dropped
      true_mask = tf.reduce_mean(tf.cast(mask[0], tf.float32), axis=-1, keepdims=True)
      # put_bar needs the image itself to locate black pixels
      display([image[0], true_mask, put_bar(image[0], create_mask(pred_mask))])
  else:
    pred_mask = unet_model.predict(sample_image[tf.newaxis, ...])
    display([sample_image, sample_mask, put_bar(sample_image, create_mask(pred_mask))])

def show_prediction(image):
  """Predict on a batched image and display its first item beside the class map.

  Args:
    image: batched model input with a leading batch axis (the cell below
      passes a (1, H, W, 1) array built from a grayscale image).
  """
  pred_mask = unet_model.predict(image)
  # Argmax class map for batch item 0, with black input pixels set to -1
  target_mask = put_bar(image[0], create_mask(pred_mask))
  display([image[0], target_mask])

In [None]:
image_path = os.path.join(sourceImageDirectory, file)
image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
# cv2.imread signals a missing/unreadable file by returning None instead of raising
if image is None:
  raise FileNotFoundError(f"Could not read image: {image_path}")
cv2_imshow(image)
# Add batch and channel axes: (H, W) -> (1, H, W, 1)
# NOTE(review): the image is fed as raw uint8 (0-255) at its native size; confirm
# this matches the preprocessing used to build the float32 training dataset.
image = np.expand_dims(image, axis=[0, -1])

# Argmax class map
show_prediction(image)

# One thresholded binary mask per class
masks = create_masks(unet_model.predict(image))
display_multiple([image[0], *masks])