In [None]:
# Upload Input Data
# -----------------
# Parameters:

# Upload Input Data (True, False): when True, open Colab's upload widget so
# local files can be pushed into the runtime.
UPLOAD_DATA = False

# -----------------

if UPLOAD_DATA:
    # google.colab only exists inside Colab, so the import is done lazily:
    # with UPLOAD_DATA = False the cell also runs outside Colab.
    from google.colab import files

    files.upload()

# -----------------

In [None]:
# Interactive Shell
# -----------------

# Make IPython display the value of every bare expression in a cell,
# not only the value of the last one.
from IPython.core.interactiveshell import InteractiveShell

InteractiveShell.ast_node_interactivity = "all"

# -----------------

In [None]:
# Modules Import
# --------------

import os
import time
import json
import enum
import shutil
import pathlib
import random
from tqdm import tqdm

import numpy as np
import pandas as pd
from sklearn.utils import shuffle

import cv2
from PIL import Image

import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications.vgg16 import preprocess_input 

from matplotlib import cm
import matplotlib.pyplot as plt

%matplotlib inline

# --------------

In [None]:
# Dataset and Plant Names
# -----------------------

class DATASET_NAME(enum.Enum):
    """Selectable training sets.

    Every member except ALL carries the name of a folder under
    Development_Dataset/Training; ALL selects several of them at once.
    """

    ALL = "All"
    BIPBIP = "Bipbip"
    PEAD = "Pead"
    ROSEAU = "Roseau"
    WEEDELEC = "Weedelec"


class PLANT_NAME(enum.Enum):
    """Crop types; values match the per-dataset crop sub-folder names."""

    MAIS = "Mais"
    HARICOT = "Haricot"

# -----------------------

In [None]:
# Module Parameters
# -----------------
# Parameters:

# Network input size in pixels (height, width)
IMG_H, IMG_W = 256, 256

# Number of segmentation classes (label ids 0 .. NUM_CLASSES - 1)
NUM_CLASSES = 3

# Samples per training / validation batch
BATCH_SIZE = 4

# Architecture switch: True builds the UNet, False the VGG16-based model
UNET = False

# Train Model (True, False): False loads a saved checkpoint instead of training
TRAIN_MODEL = True

# Dataset(s) used for training (a DATASET_NAME member)
DATASETS_TRAINING = DATASET_NAME.BIPBIP

# -----------------

In [None]:
# Define Directories
# ------------------
# Parameters:

# Root Directory
ROOT_DIR = pathlib.Path("/content")

# ------------------

development_dataset = ROOT_DIR / "Development_Dataset"
training_path = development_dataset / "Training"

bipbip_path = training_path / "Bipbip"
bipbip_path_haricot = bipbip_path / "Haricot"
bipbip_path_mais = bipbip_path / "Mais"

roseau_path = training_path / "Roseau"
roseau_path_haricot = roseau_path / "Haricot"
roseau_path_mais = roseau_path / "Mais"

weedelec_path = training_path / "Weedelec"
weedelec_path_haricot = weedelec_path / "Haricot"
weedelec_path_mais = weedelec_path / "Mais"

pead_path = training_path / "Pead"
pead_path_haricot = pead_path / "Haricot"
pead_path_mais = pead_path / "Mais"

bipbip_path_list = [bipbip_path_haricot, bipbip_path_mais]
roseau_path_list = [roseau_path_haricot, roseau_path_mais]
pead_path_list = [pead_path_haricot, pead_path_mais]
weedelec_path_list = [weedelec_path_haricot, weedelec_path_mais]

dictionary = {
    DATASET_NAME.ALL: [bipbip_path_haricot, bipbip_path_mais, 
                       roseau_path_haricot, roseau_path_mais, 
                       weedelec_path_haricot, weedelec_path_mais],
    DATASET_NAME.BIPBIP: [bipbip_path_haricot, bipbip_path_mais],
    DATASET_NAME.ROSEAU: [roseau_path_haricot, roseau_path_mais],
    DATASET_NAME.PEAD: [pead_path_haricot, pead_path_mais],
    DATASET_NAME.WEEDELEC: [weedelec_path_haricot, weedelec_path_mais]     
}

dataset_path_list = dictionary.get(DATASETS_TRAINING)

dataset_tiles_path_list = [dataset_path / "Tiles" for dataset_path in dataset_path_list]

model_path = development_dataset / "model.h5py"

if not os.path.exists(development_dataset):
    !unzip -q /content/Development_Dataset.zip

# ------------------

In [None]:
# Random Seed
# -----------
# Parameters:

# Random Seed
SEED = 1000

# -----------

# Seed every RNG the notebook imports — Python's `random`, NumPy's global
# generator and TensorFlow — so re-running the notebook gives the same results.
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

# -----------

In [None]:
# Auxiliary Functions
# -------------------

def tiling(path, shape, output, tile_size=256, resample=None):
    """Resize every image in `path` to `shape` and save it as square tiles.

    Each file directly inside `path` is resized to `shape`, cut into a grid of
    `tile_size` x `tile_size` tiles, and every tile is written to `output` as
    "<stem>_<row>_<col>.png".

    Args:
        path: pathlib.Path of the folder holding the source images.
        shape: (width, height) target size in PIL order; both sides must be
            multiples of `tile_size`, otherwise tf.split raises.
        output: pathlib.Path of an existing folder that receives the tiles.
        tile_size: side of a square tile in pixels (default 256).
        resample: PIL resampling filter used for the resize; None keeps PIL's
            default. Label masks should use Image.NEAREST.
    """
    filenames = next(os.walk(path))[2]

    n_rows = shape[1] // tile_size  # tiles along the height axis (array axis 0)
    n_cols = shape[0] // tile_size  # tiles along the width axis (array axis 1)

    # Only forward `resample` when given, so the default path calls resize()
    # exactly as before.
    resize_kwargs = {} if resample is None else {"resample": resample}

    for image_name in tqdm(filenames):
        stem = os.path.splitext(image_name)[0]

        # The with-block releases the file handle once the resized copy exists.
        with Image.open(path / image_name) as source:
            image = source.resize(shape, **resize_kwargs)

        rgb_tensor = tf.convert_to_tensor(np.array(image))

        for i, row in enumerate(tf.split(rgb_tensor, n_rows, axis=0)):
            for j, tile in enumerate(tf.split(row, n_cols, axis=1)):
                tile_name = "{}_{}_{}.png".format(stem, i, j)
                Image.fromarray(tile.numpy()).save(output / tile_name)

def process_dataset(dataset_path, shape):
    """Tile the Images and Masks folders of one crop sub-dataset.

    Tiles are written to <dataset_path>/Tiles/Images and
    <dataset_path>/Tiles/Masks; the folders are created when missing.

    Args:
        dataset_path: pathlib.Path of a crop folder, e.g. Training/Bipbip/Haricot.
        shape: (width, height) every image is resized to before tiling.
    """
    tiles_path = dataset_path / "Tiles"
    tiles_images_path = tiles_path / "Images"
    tiles_masks_path = tiles_path / "Masks"

    # makedirs(exist_ok=True) also repairs a half-created Tiles tree, which a
    # single "does Tiles exist?" check would skip and then fail on.
    os.makedirs(tiles_images_path, exist_ok=True)
    os.makedirs(tiles_masks_path, exist_ok=True)

    tiling(dataset_path / "Images", shape, tiles_images_path)
    # Nearest-neighbour keeps every mask pixel on an exact label colour; an
    # interpolating filter would create blended colours that match no class.
    tiling(dataset_path / "Masks", shape, tiles_masks_path, resample=Image.NEAREST)

# -------------------

In [None]:
# Tiling
# ------
# Parameters:

# Force Reload Images (True, False): re-create every tile from the raw images
FORCE_RELOAD = False

# (width, height) each dataset is resized to before tiling; every side is a
# multiple of the 256-pixel tile
BIPBIP_SHAPE = (2048, 1536)
ROSEAU_SHAPE = (1024, 768)
WEEDELEC_SHAPE = (5120, 3328)
PEAD_SHAPE = (3072, 2304)

# Real Shape Parameters: (width, height) of the original images
BIPBIP_REAL_SHAPE = (2048, 1536)
ROSEAU_REAL_SHAPE = (1228, 819)
WEEDELEC_REAL_SHAPE = (5184, 3456)
PEAD_REAL_SHAPE = (3280, 2464)

# ------

if FORCE_RELOAD:
    # (crop folders, tiling size) pairs, processed in this order.
    tiling_jobs = (
        (bipbip_path_list, BIPBIP_SHAPE),
        (roseau_path_list, ROSEAU_SHAPE),
        (weedelec_path_list, WEEDELEC_SHAPE),
        (pead_path_list, PEAD_SHAPE),
    )
    for tiling_paths, tiling_shape in tiling_jobs:
        for dataset in tiling_paths:
            process_dataset(dataset, shape=tiling_shape)

# ------

In [None]:
# Data Splitting Function
# -----------------------

def split_dataset(tiles_path, percentage):
    """Write train/validation split files for the tiles of one sub-dataset.

    The tile names in <tiles_path>/Images are sorted and their stems
    (extension removed) written, newline-separated, to
    <tiles_path>/Splits/train.txt and <tiles_path>/Splits/val.txt.
    int(n * percentage) tiles go to training and the rest to validation.
    Validation tiles are spread out by taking every k-th tile (k = n // n_val)
    until the validation quota is met; every tile lands in exactly one file.

    NOTE: the following cell defines split_dataset again; the later
    definition is the one in effect after running both cells.

    Args:
        tiles_path: pathlib.Path of a "Tiles" folder containing "Images".
        percentage: fraction of tiles used for training (0..1).
    """
    split_path = tiles_path / "Splits"
    # Default permissions: a mode such as 0o666 lacks the execute bit, so a
    # non-root user could not create the split files inside the folder.
    os.makedirs(split_path, exist_ok=True)

    # Sorted so the split does not depend on os.listdir's arbitrary order.
    names = sorted(os.listdir(tiles_path / "Images"))
    stems = [os.path.splitext(name)[0] for name in names]

    n_images = len(stems)
    n_train = int(n_images * percentage)
    n_val = n_images - n_train
    # n_val == 0 (empty folder or percentage == 1) would divide by zero.
    step = max(1, n_images // n_val) if n_val > 0 else 0

    train_stems, val_stems = [], []
    for position, stem in enumerate(stems, start=1):
        wants_val = step and position % step == 0
        if wants_val and len(val_stems) < n_val:
            val_stems.append(stem)
        elif len(train_stems) < n_train:
            train_stems.append(stem)
        else:
            # Training quota reached: whatever is left belongs to validation,
            # so no tile is ever silently dropped.
            val_stems.append(stem)

    with open(split_path / "train.txt", "w") as train_file:
        train_file.write("\n".join(train_stems))
    with open(split_path / "val.txt", "w") as val_file:
        val_file.write("\n".join(val_stems))

# ----------------------------

In [None]:
# Data Splitting Function
# -----------------------

def split_dataset(tiles_path, percentage):
    """Write train/validation split files for the tiles of one sub-dataset.

    The tile names in <tiles_path>/Images are sorted and their stems
    (extension removed) written, newline-separated, to
    <tiles_path>/Splits/train.txt and <tiles_path>/Splits/val.txt.
    int(n * percentage) tiles go to training and the rest to validation.
    Validation tiles are spread out by taking every k-th tile (k = n // n_val)
    until the validation quota is met; every tile lands in exactly one file.

    NOTE: this redefines split_dataset from the previous cell.

    Args:
        tiles_path: pathlib.Path of a "Tiles" folder containing "Images".
        percentage: fraction of tiles used for training (0..1).
    """
    split_path = tiles_path / "Splits"
    os.makedirs(split_path, exist_ok=True)

    # Sorted so the split does not depend on os.listdir's arbitrary order.
    names = sorted(os.listdir(tiles_path / "Images"))
    stems = [os.path.splitext(name)[0] for name in names]

    n_images = len(stems)
    n_train = int(n_images * percentage)
    n_val = n_images - n_train
    # n_val == 0 (empty folder or percentage == 1) would divide by zero.
    step = max(1, n_images // n_val) if n_val > 0 else 0

    train_stems, val_stems = [], []
    for position, stem in enumerate(stems, start=1):
        wants_val = step and position % step == 0
        if wants_val and len(val_stems) < n_val:
            val_stems.append(stem)
        elif len(train_stems) < n_train:
            train_stems.append(stem)
        else:
            # Training quota reached: whatever is left belongs to validation,
            # so no tile is ever silently dropped.
            val_stems.append(stem)

    with open(split_path / "train.txt", "w") as train_file:
        train_file.write("\n".join(train_stems))
    with open(split_path / "val.txt", "w") as val_file:
        val_file.write("\n".join(val_stems))

# ----------------------------

In [None]:
# Train / Validation Splitting
# ----------------------------
# Parameters:

# Fraction of tiles assigned to training; the remainder goes to validation
TRAIN_P = 0.8

# ----------------------------

# Rebuild the split files from scratch for every selected Tiles folder.
for dataset_tiles_path in dataset_tiles_path_list:
    splits_path = dataset_tiles_path / "Splits"
    if splits_path.exists():
        shutil.rmtree(splits_path)
    split_dataset(dataset_tiles_path, TRAIN_P)

# ----------------------------

In [None]:
# Data Augmentation
# -----------------
# Parameters:

# Apply Data Augmentation (True, False)
APPLY_DATA_AUGMENTATION = True

# Data Augmentation Parameters (ImageDataGenerator arguments; integer shift
# ranges are interpreted by Keras as pixels)
ROTATION_RANGE = 10
WIDTH_SHIFT_RANGE = 10
HEIGHT_SHIFT_RANGE = 10
ZOOM_RANGE = 0.3
HORIZONTAL_FLIP = True
VERTICAL_FLIP = True
FILL_MODE = "reflect"

# -----------------

augmentation_kwargs = dict(rotation_range=ROTATION_RANGE,
                           width_shift_range=WIDTH_SHIFT_RANGE,
                           height_shift_range=HEIGHT_SHIFT_RANGE,
                           zoom_range=ZOOM_RANGE,
                           horizontal_flip=HORIZONTAL_FLIP,
                           vertical_flip=VERTICAL_FLIP,
                           fill_mode=FILL_MODE)

if APPLY_DATA_AUGMENTATION:
    # Image and mask generators share identical settings; CustomDataset seeds
    # both with the same value so the image and its mask stay aligned.
    img_data_gen = ImageDataGenerator(**augmentation_kwargs)
    mask_data_gen = ImageDataGenerator(**augmentation_kwargs)
else:
    # Keep both names defined so the dataset-loader cell still runs;
    # CustomDataset only augments when both generators are not None.
    img_data_gen = None
    mask_data_gen = None

# -----------------

In [None]:
# Data Generator
# --------------

def read_rgb_mask(mask_arr):
    """Convert an RGB label mask into a 2-D array of class ids.

    Colour -> class id:
        [255, 255, 255] -> 1
        [216, 67, 82]   -> 2
        everything else (including [0, 0, 0] and [216, 124, 18]) -> 0

    Args:
        mask_arr: array whose last axis holds RGB triplets, e.g. (H, W, 3).
    Returns:
        Array of shape mask_arr.shape[:2] with the same dtype as `mask_arr`.
    """
    class_ids = np.zeros(mask_arr.shape[:2], dtype=mask_arr.dtype)

    # Applied in order; the explicit zero entries mirror the label legend.
    palette = (
        ((0, 0, 0), 0),
        ((216, 124, 18), 0),
        ((255, 255, 255), 1),
        ((216, 67, 82), 2),
    )
    for colour, class_id in palette:
        matches = (mask_arr == np.array(colour)).all(axis=-1)
        class_ids[matches] = class_id

    return class_ids

class CustomDataset(tf.keras.utils.Sequence):
    """Keras Sequence yielding (image, mask) tile pairs for segmentation.

    Tile stems are read from <dir>/Splits/train.txt ("training") or
    <dir>/Splits/val.txt ("validation") of every folder in `dataset_dirs`.
    The crop folder holding each tile is recovered from the
    "<Dataset>_<crop>" prefix in its name.

    Args:
        dataset_dirs: iterable of "Tiles" folders, each containing Splits/.
        which_subset: "training" or "validation".
        img_generator, mask_generator: optional ImageDataGenerator objects;
            training samples are augmented only when both are given.
        preprocessing_function: optional callable applied to the image array.
        out_shape: [width, height] each tile is resized to (PIL order).
    """

    _SPLIT_FILES = {"training": "train.txt", "validation": "val.txt"}

    def __init__(self, dataset_dirs, which_subset, img_generator=None, mask_generator=None,
                 preprocessing_function=None, out_shape=[256, 256]):
        if which_subset not in self._SPLIT_FILES:
            raise ValueError(
                "which_subset must be 'training' or 'validation', got %r" % (which_subset,))

        subset_filenames = []
        for dataset_dir in dataset_dirs:
            subset_file = os.path.join(dataset_dir, "Splits", self._SPLIT_FILES[which_subset])
            with open(subset_file, "r") as f:
                # Blank lines (e.g. an empty split file) never become a "" stem.
                subset_filenames.extend(line.strip() for line in f if line.strip())

        self.which_subset = which_subset
        self.dataset_dirs = dataset_dirs
        self.subset_filenames = subset_filenames
        self.img_generator = img_generator
        self.mask_generator = mask_generator
        self.preprocessing_function = preprocessing_function
        self.out_shape = out_shape
        # Private RNG for per-sample augmentation seeds: a constant seed would
        # make every tile receive the very same random transform.
        self._rng = np.random.RandomState(SEED)

    def __len__(self):
        return len(self.subset_filenames)

    @staticmethod
    def _crop_dir(filename):
        """Return the crop folder (e.g. Training/Bipbip/Haricot) of a tile."""
        prefixes = (
            ("Bipbip_haricot", bipbip_path_haricot),
            ("Bipbip_mais", bipbip_path_mais),
            ("Roseau_haricot", roseau_path_haricot),
            ("Roseau_mais", roseau_path_mais),
            ("Weedelec_haricot", weedelec_path_haricot),
            ("Weedelec_mais", weedelec_path_mais),
            ("Pead_haricot", pead_path_haricot),
            ("Pead_mais", pead_path_mais),
        )
        for prefix, crop_dir in prefixes:
            if prefix in filename:
                return crop_dir
        raise ValueError("Cannot infer the dataset folder of tile %r" % (filename,))

    def __getitem__(self, index):
        curr_filename = self.subset_filenames[index]
        tiles_dir = os.path.join(self._crop_dir(curr_filename), "Tiles")

        # Resize inside the with-blocks so the files are closed immediately.
        with Image.open(os.path.join(tiles_dir, "Images", curr_filename + ".png")) as img_file:
            img = img_file.resize(self.out_shape)
        with Image.open(os.path.join(tiles_dir, "Masks", curr_filename + ".png")) as mask_file:
            # Nearest-neighbour so mask pixels keep exact label colours.
            mask = mask_file.resize(self.out_shape, resample=Image.NEAREST)

        img_arr = np.array(img)
        mask_arr = np.expand_dims(read_rgb_mask(np.array(mask)), -1)

        # Default: the mask is returned unchanged (validation, or no generators).
        out_mask = mask_arr
        augment = (self.which_subset == "training"
                   and self.img_generator is not None
                   and self.mask_generator is not None)
        if augment:
            # One fresh seed per sample, shared by image and mask so both
            # receive the same geometric transform.
            seed = int(self._rng.randint(0, 2**31 - 1))
            img_t = self.img_generator.get_random_transform(img_arr.shape, seed=seed)
            mask_t = self.mask_generator.get_random_transform(mask_arr.shape, seed=seed)
            img_arr = self.img_generator.apply_transform(img_arr, img_t)

            # Transform each class as a binary float map and truncate it back
            # to {0, 1}, so interpolation cannot invent intermediate class ids.
            out_mask = np.zeros_like(mask_arr)
            for c in np.unique(mask_arr):
                if c > 0:
                    curr_class_arr = np.float32(mask_arr == c)
                    curr_class_arr = self.mask_generator.apply_transform(curr_class_arr, mask_t)
                    out_mask += np.uint8(curr_class_arr) * c

        if self.preprocessing_function is not None:
            img_arr = self.preprocessing_function(img_arr)

        return img_arr, np.float32(out_mask)

# --------------

In [None]:
# Dataset Loader
# --------------

def _to_tf_dataset(sequence):
    """Wrap a CustomDataset in a batched, endlessly repeating tf.data pipeline."""
    dataset = tf.data.Dataset.from_generator(lambda: sequence,
                                             output_types=(tf.float32, tf.float32),
                                             output_shapes=([IMG_H, IMG_W, 3], [IMG_H, IMG_W, 1]))
    return dataset.batch(BATCH_SIZE).repeat()

# out_shape is (width, height) for PIL and must agree with the (IMG_H, IMG_W)
# output_shapes declared above; relying on the [256, 256] default breaks as
# soon as IMG_H / IMG_W are changed.
dataset_train = CustomDataset(dataset_tiles_path_list, "training",
                              img_generator=img_data_gen,
                              mask_generator=mask_data_gen,
                              preprocessing_function=preprocess_input,
                              out_shape=[IMG_W, IMG_H])
train_dataset = _to_tf_dataset(dataset_train)

dataset_valid = CustomDataset(dataset_tiles_path_list, "validation",
                              preprocessing_function=preprocess_input,
                              out_shape=[IMG_W, IMG_H])
valid_dataset = _to_tf_dataset(dataset_valid)

# --------------


In [None]:
# Data Generator Test
# -------------------

# VGG16's preprocess_input turns RGB into BGR and subtracts these per-channel
# (BGR-ordered) ImageNet means; both steps are undone before displaying.
VGG16_BGR_MEAN = np.array([103.939, 116.779, 123.68])

# Class id -> display colour (the same colours read_rgb_mask decodes).
CLASS_COLOURS = {0: [0, 0, 0], 1: [255, 255, 255], 2: [216, 67, 82]}

sample_img, target = next(iter(valid_dataset))

# First sample of the batch (validation data, so not augmented).
display_img = sample_img[0].numpy() + VGG16_BGR_MEAN
display_img = np.clip(display_img[..., ::-1], 0, 255)

target = np.array(target[0, ..., 0])
target_img = np.zeros([target.shape[0], target.shape[1], 3])
for class_id, colour in CLASS_COLOURS.items():
    target_img[target == class_id] = colour

fig, ax = plt.subplots(1, 2, figsize=(9, 9))

# `_ =` keeps the "all" interactivity mode from echoing each artist.
_ = ax[0].imshow(np.uint8(display_img))
_ = ax[0].set_title("Validation image")
_ = ax[1].imshow(np.uint8(target_img))
_ = ax[1].set_title("Target mask")

plt.show()

# -------------------

In [None]:
# Model Architecture Setup
# ------------------------

if not UNET:
    # ImageNet-pretrained VGG16 without its classifier head, used as a frozen
    # encoder by vgg_model.
    BASE_MODEL = tf.keras.applications.VGG16(
        weights="imagenet",
        include_top=False,
        input_shape=(IMG_H, IMG_W, 3),
    )

    for frozen_layer in BASE_MODEL.layers:
        frozen_layer.trainable = False

    BASE_MODEL.summary()

# ------------------------

In [None]:
# VGG Model Architecture
# ----------------------

def vgg_model(depth, start_f, num_classes):
    """Segmentation model: frozen VGG16 encoder plus an upsampling decoder.

    The decoder repeats `depth` times: bilinear x2 upsampling, a 3x3 conv
    with `start_f` filters (halved after every stage) and a ReLU. A final
    1x1 softmax conv outputs `num_classes` per-pixel probabilities.

    Relies on the module-level BASE_MODEL (built only when UNET is False).
    """
    decoder = []
    filters = start_f
    for _ in range(depth):
        decoder += [
            tf.keras.layers.UpSampling2D(2, interpolation="bilinear"),
            tf.keras.layers.Conv2D(filters=filters,
                                   kernel_size=(3, 3),
                                   strides=(1, 1),
                                   padding="same"),
            tf.keras.layers.ReLU(),
        ]
        filters //= 2

    head = tf.keras.layers.Conv2D(filters=num_classes,
                                  kernel_size=(1, 1),
                                  strides=(1, 1),
                                  padding="same",
                                  activation="softmax")

    return tf.keras.Sequential([BASE_MODEL, *decoder, head])

# ----------------------

In [None]:
# UNet Model Architecture
# -----------------------

def unet_model(depth, start_f, num_classes, dynamic_input_shape=False):
    """Build a UNet-style encoder-decoder segmentation model.

    Encoder: `depth` stages of two 3x3 conv + ReLU, a skip connection and 2x2
    max-pooling; filters start at `start_f` and double every stage.
    Bottleneck: two 3x3 conv + ReLU. Decoder: `depth` stages of bilinear x2
    upsampling, 2x2 conv + ReLU, concatenation with the matching skip and two
    3x3 conv + ReLU, halving the filters every stage. Two more 3x3 conv + ReLU
    and a 1x1 softmax conv give `num_classes` per-pixel probabilities.

    Args:
        depth: number of down/up-sampling stages.
        start_f: filters of the first encoder stage.
        num_classes: channels of the softmax output.
        dynamic_input_shape: True accepts inputs of any spatial size (each
            side must then be divisible by 2**depth for the skips to line up);
            False fixes the input to (IMG_H, IMG_W, 3). Defaults to False so
            three-argument calls work.
    """
    # Module constants are upper-case (IMG_H / IMG_W); img_h / img_w do not exist.
    input_shape = [None, None, 3] if dynamic_input_shape else [IMG_H, IMG_W, 3]
    input_layer = tf.keras.layers.Input(input_shape)

    def conv_relu(x, filters, kernel_size=(3, 3)):
        """Same-padded stride-1 Conv2D followed by a ReLU."""
        x = tf.keras.layers.Conv2D(filters=filters, kernel_size=kernel_size,
                                   strides=(1, 1), padding="same")(x)
        return tf.keras.layers.ReLU()(x)

    # Encoder
    # -------
    layers_skip_connections = []
    layer = input_layer
    for _ in range(depth):
        layer = conv_relu(layer, start_f)
        layer = conv_relu(layer, start_f)
        layers_skip_connections.append(layer)
        layer = tf.keras.layers.MaxPool2D(pool_size=(2, 2))(layer)
        start_f *= 2

    # Bottleneck
    layer = conv_relu(layer, start_f)
    layer = conv_relu(layer, start_f)
    start_f //= 2

    # Decoder: consume the skip connections from deepest to shallowest.
    # -------
    for skip in reversed(layers_skip_connections):
        layer = tf.keras.layers.UpSampling2D(size=(2, 2), interpolation="bilinear")(layer)
        layer = conv_relu(layer, start_f, kernel_size=(2, 2))
        layer = tf.keras.layers.Concatenate()([skip, layer])
        layer = conv_relu(layer, start_f)
        layer = conv_relu(layer, start_f)
        start_f //= 2

    # Prediction Layer
    # ----------------
    layer = conv_relu(layer, start_f)
    layer = conv_relu(layer, start_f)
    layer = tf.keras.layers.Conv2D(filters=num_classes, kernel_size=(1, 1), strides=(1, 1),
                                   padding="same", activation="softmax")(layer)

    return tf.keras.Model(inputs=input_layer, outputs=layer)

# -----------------------

In [None]:
# Generate Model
# --------------

if UNET:
    # dynamic_input_shape is passed explicitly so the call is valid whether or
    # not unet_model gives that parameter a default.
    MODEL = unet_model(depth=4, start_f=32, num_classes=NUM_CLASSES, dynamic_input_shape=False)
else:
    MODEL = vgg_model(depth=5, start_f=256, num_classes=NUM_CLASSES)

MODEL.summary()

# --------------

In [None]:
# Mean IoU
# --------

def meanIoU(y_true, y_pred):
    """Mean intersection-over-union over classes 1 .. NUM_CLASSES - 1.

    y_pred holds per-class scores on its last axis and is reduced with
    argmax; y_true is expected to hold class ids with a trailing singleton
    axis (as produced by the tf.data pipeline). Class 0 is skipped. 1e-7 is
    added to numerator and denominator, so a class absent from both the
    prediction and the target scores 1.
    """
    eps = 1e-7
    predicted_ids = tf.expand_dims(tf.argmax(y_pred, -1), -1)

    def class_iou(class_id):
        pred_mask = tf.cast(tf.where(predicted_ids == class_id, 1, 0), tf.float32)
        true_mask = tf.cast(tf.where(y_true == class_id, 1, 0), tf.float32)
        overlap = tf.reduce_sum(true_mask * pred_mask)
        total = tf.reduce_sum(true_mask) + tf.reduce_sum(pred_mask) - overlap
        return (overlap + eps) / (total + eps)

    return tf.reduce_mean([class_iou(c) for c in range(1, NUM_CLASSES)])

# --------

In [None]:
# Weighted Loss Function
# ----------------------

def weighted_loss(onehot_labels, logits):
    """Class-weighted cross-entropy for the segmentation models.

    Args:
        onehot_labels: targets; either class ids with a trailing singleton
            axis (..., 1), which is what this notebook's tf.data pipeline
            yields, or one-hot vectors (..., num_classes).
        logits: model output. Despite the name these are probabilities: both
            vgg_model and unet_model end with a softmax activation.
    Returns:
        Scalar: mean over pixels of the cross-entropy multiplied by the
        weight of each pixel's true class.
    """
    class_weights = tf.constant([0.026, 0.313, 0.661], dtype=tf.float32)
    num_classes = class_weights.shape[0]

    labels = tf.cast(onehot_labels, tf.float32)
    if labels.shape[-1] == 1:
        # Sparse ids -> one-hot; otherwise the weight lookup below would just
        # broadcast the id across all classes and the cross-entropy shapes
        # would not match the predictions.
        labels = tf.one_hot(tf.cast(tf.squeeze(labels, -1), tf.int32), depth=num_classes)

    # One weight per pixel: the weight of its true class.
    weights = tf.reduce_sum(class_weights * labels, axis=-1)
    # from_logits=False because the model already applies softmax.
    unweighted_losses = tf.keras.losses.categorical_crossentropy(labels, logits, from_logits=False)

    return tf.reduce_mean(unweighted_losses * weights)

# ----------------------

In [None]:
# Model Optimization
# ------------------
# Parameters:

# Early Stopping (True, False)
EARLY_STOP = True

# Weighted Loss Function (True, False)
WEIGHTED_LOSS = False

# Checkpoints (True, False)
CHECKPOINT = True

# Learning Rate
LR = 1e-5

# Validation Metrics
METRICS = ["accuracy", meanIoU]

# ------------------

# Defined unconditionally: the fitting cell loads from this path even when
# CHECKPOINT is False (e.g. TRAIN_MODEL = False).
model_checkpoint_path = development_dataset / "checkpoints" / "checkpoint-{}.ckpt".format(DATASETS_TRAINING.value)

callbacks = []

if EARLY_STOP:
    es_callback = tf.keras.callbacks.EarlyStopping(monitor="val_loss", patience=3, restore_best_weights=True)
    callbacks.append(es_callback)

if WEIGHTED_LOSS:
    loss = weighted_loss
else:
    loss = tf.keras.losses.SparseCategoricalCrossentropy()

if CHECKPOINT:
    # Saved after each epoch, keeping only the best model by val_loss.
    checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(filepath=model_checkpoint_path,
                                                             monitor="val_loss",
                                                             verbose=1,
                                                             save_freq="epoch",
                                                             save_best_only=True)
    callbacks.append(checkpoint_callback)

optimizer = tf.keras.optimizers.Adam(learning_rate=LR)

# Compile Model
MODEL.compile(optimizer=optimizer, loss=loss, metrics=METRICS)

# ------------------

In [None]:
# Model Fitting
# -------------
# Parameters:

# Number of Epochs
EPOCHS = 1

# -------------

# Both tf.data pipelines are batched and repeat() forever, so fit() must be
# told how many *batches* make one pass: without steps_per_epoch the first
# epoch never ends, and len(dataset) counts samples, not batches.
steps_per_epoch = (len(dataset_train) + BATCH_SIZE - 1) // BATCH_SIZE
validation_steps = (len(dataset_valid) + BATCH_SIZE - 1) // BATCH_SIZE

# weighted_loss must be resolvable when a model compiled with it is reloaded.
custom_objects = {"meanIoU": meanIoU, "weighted_loss": weighted_loss}

if TRAIN_MODEL:
    # Resume from the existing checkpoint, if any.
    if CHECKPOINT and os.path.exists(model_checkpoint_path):
        MODEL = load_model(model_checkpoint_path, custom_objects=custom_objects)

    MODEL.fit(x=train_dataset,
              epochs=EPOCHS,
              steps_per_epoch=steps_per_epoch,
              validation_data=valid_dataset,
              validation_steps=validation_steps,
              callbacks=callbacks)

else:
    MODEL = load_model(model_checkpoint_path, custom_objects=custom_objects)

# -------------

In [None]:
# Model Save
# ----------

# Persist the current model (freshly trained or reloaded) to model_path.
MODEL.save(model_path)

# ----------

In [None]:
# Auxiliary Functions
# -------------------

def predict(image, shape, tile_size=256):
    """Segment a full image by predicting its tiles and stitching them back.

    The image is cut into the same `tile_size` grid used for training tiles,
    every tile is preprocessed with VGG16 preprocess_input (as the training
    pipeline does), all tiles go through MODEL in one predict call, and the
    per-tile class maps are reassembled.

    Args:
        image: RGB PIL image (or array) already resized to `shape`.
        shape: (width, height) of `image`; both multiples of `tile_size`.
        tile_size: tile side in pixels (default 256); assumed to equal the
            model's input size.
    Returns:
        numpy array of shape (height, width) with the predicted class ids.
    """
    n_rows = shape[1] // tile_size
    n_cols = shape[0] // tile_size

    pixels = np.asarray(image, dtype=np.float32)
    # (H, W, C) -> (rows, tile, cols, tile, C) -> (rows * cols, tile, tile, C),
    # tiles ordered row by row.
    tiles = (pixels.reshape(n_rows, tile_size, n_cols, tile_size, -1)
                   .transpose(0, 2, 1, 3, 4)
                   .reshape(n_rows * n_cols, tile_size, tile_size, -1))

    # Training images were fed through preprocess_input; prediction must match.
    probs = MODEL.predict(x=preprocess_input(tiles))

    # Inverse of the tiling reshape: back to (H, W, num_classes).
    n_classes = probs.shape[-1]
    probs = (probs.reshape(n_rows, n_cols, tile_size, tile_size, n_classes)
                  .transpose(0, 2, 1, 3, 4)
                  .reshape(n_rows * tile_size, n_cols * tile_size, n_classes))

    return np.argmax(probs, axis=-1)

def process_output(array, shape, real_shape=None):
    """Render a class-id map as an RGB image, optionally resized.

    Args:
        array: class ids of shape (height, width); 0 -> black, 1 -> white,
            2 -> (216, 67, 82), any other id stays black.
        shape: (width, height) of `array`.
        real_shape: optional (width, height) to resize the rendering to,
            using nearest-neighbour so no new colours appear.
    Returns:
        uint8 numpy array of shape (height, width, 3), or of `real_shape`
        (transposed to rows x columns) when given.
    """
    palette = {0: [0, 0, 0], 1: [255, 255, 255], 2: [216, 67, 82]}

    rgb = np.zeros([shape[1], shape[0], 3])
    for class_id, colour in palette.items():
        rgb[array == class_id] = colour

    rendered = Image.fromarray(np.uint8(rgb), "RGB")
    if real_shape is not None:
        rendered = rendered.resize(real_shape, resample=Image.NEAREST)

    return np.array(rendered)

# -------------------

In [None]:
# Prediction Test
# ---------------
# Parameters:

# Folder with the test images
DATASET_PATH = development_dataset / "Test_Dev" / "Bipbip" / "Mais" / "Images"

# Image file name (without the .jpg extension)
IMAGE_NAME = "Bipbip_mais_im_10441"

# (width, height) the image is resized to before tiling; use the shape of the
# dataset the image comes from
SHAPE = BIPBIP_SHAPE

# ---------------

# Resize inside the with-block so the file handle is released immediately.
with Image.open(DATASET_PATH / (IMAGE_NAME + ".jpg")) as raw_image:
    image = raw_image.resize(SHAPE)

predicted_class = predict(image, SHAPE)
prediction_img = process_output(predicted_class, SHAPE)

fig, ax = plt.subplots(1, 2, figsize=(20, 20))

# `_ =` keeps the "all" interactivity mode from echoing each artist.
_ = ax[0].imshow(np.uint8(image))
_ = ax[0].set_title(IMAGE_NAME)
_ = ax[1].imshow(np.uint8(prediction_img))
_ = ax[1].set_title("Prediction")

plt.show()

# ---------------