In [None]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from .segmentation_models.core.models import Unet
import cv2
from .segmentation_models.core.utils import predict_big_image
import os
import pandas as pd

In [None]:
class MyMeanIOU(tf.keras.metrics.MeanIoU):
    """Mean IoU metric fed with per-class scores instead of class ids.

    The model emits one score per class along the last axis, so predictions
    are reduced to the index of the highest-scoring class before the parent
    ``MeanIoU`` bookkeeping runs. ``y_true`` is forwarded unchanged.
    """

    def update_state(self, y_true, y_pred, sample_weight=None):
        predicted_class = tf.argmax(y_pred, axis=-1)
        return super().update_state(y_true, predicted_class, sample_weight)

In [None]:
class Augment(tf.keras.layers.Layer):
  """Jointly augment an (image batch, mask batch) pair for segmentation.

  Geometric transforms (horizontal flip, vertical flip, rotation) are applied
  to both images and masks; photometric transforms (contrast, brightness) are
  applied to images only.

  Each geometric transform is a pair of layers, one for images and one for
  masks, built with the same seed so they make the same random changes and
  the mask stays aligned with its image. This relies on both layers of a pair
  being called the same number of times in the same order, so map the
  dataset sequentially (no ``num_parallel_calls``).

  Different transforms get different seeds (offsets from ``seed``). By the
  same same-seed-same-draws mechanism, a single shared seed would make e.g.
  the horizontal and vertical flips always fire together, i.e. collapse into
  a plain 180-degree rotation.

  Args:
    seed: Base seed; each transform pair uses its own offset from it.
  """

  def __init__(self, seed=42):
    super().__init__()
    # Paired flips: the image and mask layers of a pair share a seed.
    self.flip_inputs_h = tf.keras.layers.RandomFlip(mode="horizontal", seed=seed)
    self.flip_labels_h = tf.keras.layers.RandomFlip(mode="horizontal", seed=seed)
    self.flip_inputs_w = tf.keras.layers.RandomFlip(mode="vertical", seed=seed + 1)
    self.flip_labels_w = tf.keras.layers.RandomFlip(mode="vertical", seed=seed + 1)

    # Paired rotation; uncovered corners are filled with a constant.
    self.rotate_inputs = tf.keras.layers.RandomRotation(
        factor=1., fill_mode='constant', seed=seed + 2)
    # Masks hold integer class ids: nearest-neighbour keeps every pixel a
    # valid id, whereas bilinear would blend neighbours into a different class
    # (e.g. halfway between 0 and 6 gives 3).
    self.rotate_labels = tf.keras.layers.RandomRotation(
        factor=1., fill_mode='constant', interpolation='nearest', seed=seed + 2)

    # Photometric jitter, images only.
    self.rnd_contrast = tf.keras.layers.RandomContrast(factor=0.05, seed=seed + 3)
    self.rnd_bright = tf.keras.layers.RandomBrightness(factor=0.05, seed=seed + 4)

  def call(self, inputs, labels):
    """Return the augmented ``(inputs, labels)`` pair.

    The image and mask layers of each pair are called in matching order so
    their random draws stay aligned.
    """
    inputs = self.flip_inputs_h(inputs)
    labels = self.flip_labels_h(labels)
    inputs = self.flip_inputs_w(inputs)
    labels = self.flip_labels_w(labels)

    inputs = self.rotate_inputs(inputs)
    labels = self.rotate_labels(labels)

    # Intensity changes apply to the image only, never to class ids.
    inputs = self.rnd_contrast(inputs)
    inputs = self.rnd_bright(inputs)

    return inputs, labels

# Each split keeps images and masks in parallel folders. Masks are assumed to
# pair with images by sorted file name (TODO confirm naming convention).
imdir_train = "./Training/Patch/Images/"
mdir_train = "./Training/Patch/Annotations/"

imdir_val = "./Validazione/Patch/Images/"
mdir_val = "./Validazione/Patch/Annotations/"

IMAGE_SIZE = (384, 384)
BATCH_SIZE = 8
DATA_SEED = 42


def _load_patches(directory, color_mode, interpolation="bilinear"):
    """Load the unlabeled image files in ``directory`` as a batched dataset.

    Images and masks are read as two independent datasets and zipped later,
    so their element order must match. Both use the loader's default
    shuffling with the same ``DATA_SEED``, which keeps them aligned only while
    the two folders hold the same number of files in corresponding order.

    Args:
        directory: Folder holding the image files.
        color_mode: ``"rgb"`` for images, ``"grayscale"`` for masks.
        interpolation: Resize method used when a file is not ``IMAGE_SIZE``.

    Returns:
        A batched ``tf.data.Dataset`` of images.
    """
    return tf.keras.utils.image_dataset_from_directory(
        directory,
        labels=None,
        batch_size=BATCH_SIZE,
        image_size=IMAGE_SIZE,
        color_mode=color_mode,
        seed=DATA_SEED,
        interpolation=interpolation,
    )


train_imgs = _load_patches(imdir_train, "rgb")
# Masks store class ids. "nearest" never blends two ids into a third one when
# a mask has to be resized, unlike the default bilinear interpolation.
train_masks = _load_patches(mdir_train, "grayscale", interpolation="nearest")
val_imgs = _load_patches(imdir_val, "rgb")
val_masks = _load_patches(mdir_val, "grayscale", interpolation="nearest")

# Dataset.zip silently truncates to the shorter input; catch a gross
# image/mask count mismatch here instead of training on misaligned pairs.
assert int(train_imgs.cardinality()) == int(train_masks.cardinality()), \
    "train: image and mask folders yield different numbers of batches"
assert int(val_imgs.cardinality()) == int(val_masks.cardinality()), \
    "val: image and mask folders yield different numbers of batches"

train_ds = tf.data.Dataset.zip((train_imgs, train_masks))
val_ds = tf.data.Dataset.zip((val_imgs, val_masks))

# Augment only the training split. The map is deliberately sequential:
# Augment keeps image and mask transforms in sync by calling same-seed layer
# pairs in lockstep, which parallel map calls could interleave.
train_ds = train_ds.map(Augment())

train_ds = train_ds.prefetch(buffer_size=tf.data.AUTOTUNE)
val_ds = val_ds.prefetch(buffer_size=tf.data.AUTOTUNE)

In [None]:
# Colour lookup table: row k is the RGB colour used to draw class id k, so
# indexing LUT with an array of class ids yields an RGB image.
LUT = np.array(
    [
        [0, 0, 0],        # 0: Background
        [244, 229, 136],  # 1: WDF
        [104, 180, 46],   # 2: Swamp
        [42, 75, 155],    # 3: Organic
        [241, 137, 24],   # 4: Sand
        [128, 192, 123],  # 5: PDF
        [106, 69, 149],   # 6: ProDelta
    ],
    dtype=np.uint8,
)

In [None]:
# Visual sanity check: each augmented training image next to its colour-coded
# mask, with a translucent overlay of the mask on the image.
for img, mask in train_ds.take(1):
    # Clip before casting: float -> uint8 of out-of-range values is undefined
    # and would show garbage pixels instead of saturated ones.
    images = np.clip(img.numpy(), 0, 255).astype(np.uint8)
    class_ids = mask.numpy().astype(np.uint8)[..., 0]

    # Use the real batch size; a hard-coded 8 raises IndexError when the
    # first batch holds fewer samples.
    for i in range(images.shape[0]):
        fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(10, 10))
        mask_rgb = LUT[class_ids[i]]
        ax1.imshow(images[i])
        ax1.imshow(mask_rgb, alpha=0.2)
        ax1.set_title('Image + mask overlay')
        ax2.imshow(mask_rgb)
        ax2.set_title('Mask')
        ax1.axis('off') ; ax2.axis('off')

In [None]:
EPOCHS = 100
starter_learning_rate = 1e-4
end_learning_rate = 5e-6

# Decay across the whole run: one optimizer step per training batch, per epoch.
decay_steps = EPOCHS * train_ds.cardinality().numpy()

# Square-root (power=0.5) polynomial decay from the starter to the end rate.
learning_rate_fn = tf.keras.optimizers.schedules.PolynomialDecay(
    initial_learning_rate=starter_learning_rate,
    decay_steps=decay_steps,
    end_learning_rate=end_learning_rate,
    power=0.5,
)

In [None]:
# Keep on disk only the weights of the epoch with the highest validation
# mean IoU; the monitored key must match the compiled metric's logged name.
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath='best_model.h5',
    monitor='val_my_mean_iou',
    mode='max',
    save_best_only=True,
    save_weights_only=True,
)

In [None]:
# U-Net with an EfficientNet-B3 backbone; 7 classes, one per LUT row.
model = Unet((384, 384, 3), backbone="efficientnetb3", classes=7, final_activation='softmax')

model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate_fn, name="Adam"),
    # The model ends in softmax, so the loss receives probabilities, not logits.
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
    # Explicit name: the checkpoint callback monitors "val_my_mean_iou".
    # An auto-derived name gets a "_1", "_2", ... suffix when this cell is
    # re-run in the same kernel, and the callback would then skip saving.
    metrics=[MyMeanIOU(num_classes=7, name='my_mean_iou')],
)

In [None]:
# Train, validating after every epoch; the checkpoint callback keeps the best weights.
history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=EPOCHS,
    callbacks=[model_checkpoint_callback],
)

In [None]:
# Persist the per-epoch loss/metric curves so they can be plotted without retraining.
hist = pd.DataFrame(data=history.history)
hist.to_csv(path_or_buf='history_effnetb3.csv', index=False)