## Download the data

In [None]:
# -p keeps the cell re-runnable when Dataset/ already exists; the URLs are quoted
# so the shell never treats the "?" in the query string as a glob character.
!mkdir -p Dataset
!wget "https://zenodo.org/records/7189381/files/trainX.npy?download=1" -O Dataset/Xdata.npy
!wget "https://zenodo.org/records/7189381/files/trainY.npy?download=1" -O Dataset/Ydata.npy

## Import libraries and load input images and target segmentation masks

In [None]:
import numpy as np
import tensorflow as tf
import keras
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import sklearn

In [None]:
# Read the image stack and the matching mask stack saved by the download cell.
Xdata, Ydata = map(np.load, ("Dataset/Xdata.npy", "Dataset/Ydata.npy"))
print(f"the shape of input image matrix is {Xdata.shape}")

## What does one input image and corresponding segmentation mask look like?

In [None]:
n = 289  # sample number
fig, ax = plt.subplots(1, 2, figsize=(12, 6))

# Left: the first three channels of the input, rendered as a colour image.
# Right: channel 0 of the corresponding mask.
ax[0].imshow(Xdata[n, :, :, :3])
ax[1].imshow(Ydata[n, :, :, 0])
ax[0].set_title(f"Input image, first 3 channels (sample {n})")
ax[1].set_title(f"Segmentation mask (sample {n})")

# Plain pixel coordinates on both panels (no offset / scientific notation).
for panel in ax:
    panel.ticklabel_format(useOffset=False, style="plain")

## Prepare the dataset to train, validate, and test the model

In [None]:
# Hold out 30% of the samples for testing; the fixed seed makes the split repeatable.
split = train_test_split(Xdata, Ydata, random_state=42, test_size=0.3)
X_train, X_test, y_train, y_test = split

## Prepare U-Net model

In [None]:
import tensorflow as tf
from tensorflow.keras import models
from tensorflow.keras.layers import *
from tensorflow.keras.activations import *
from tensorflow.keras.models import Sequential


def down_block(x, filters, use_maxpool=True):
    """Encoder stage: two (3x3 Conv -> BatchNorm -> LeakyReLU) layers.

    Args:
        x: Input Keras tensor.
        filters: Number of filters used by both convolutions.
        use_maxpool: When truthy, also downsample the result with max pooling.

    Returns:
        ``(pooled, features)`` when ``use_maxpool`` is truthy, where ``features``
        is the pre-pooling tensor to be used as a skip connection; otherwise
        just ``features``. Callers must unpack according to the flag they pass.
    """
    for _ in range(2):
        x = Conv2D(filters, 3, padding="same")(x)
        x = BatchNormalization()(x)
        x = LeakyReLU()(x)

    if not use_maxpool:
        return x
    return MaxPooling2D(strides=(2, 2))(x), x


def up_block(x, y, filters):
    """Decoder stage: upsample ``x``, merge it with skip tensor ``y``, then refine.

    Args:
        x: Tensor coming from the previous (coarser) stage.
        y: Skip-connection tensor concatenated with the upsampled ``x`` on axis 3.
        filters: Number of filters used by both convolutions.

    Returns:
        The tensor produced by two (3x3 Conv -> BatchNorm -> LeakyReLU) layers
        applied to the concatenation.
    """
    upsampled = UpSampling2D()(x)
    merged = Concatenate(axis=3)([upsampled, y])
    for _ in range(2):
        merged = Conv2D(filters, 3, padding="same")(merged)
        merged = BatchNormalization()(merged)
        merged = LeakyReLU()(merged)
    return merged


def get_model(input_size=(256, 256, 3), *, classes, dropout):
    """Build a U-Net with four encoder stages, a bottleneck and four decoder stages.

    Args:
        input_size: Shape of one input sample passed to ``Input(shape=...)``.
            The encoder pools four times and the decoder upsamples four times,
            so height and width are expected to be divisible by 16 for the skip
            concatenations to line up (assuming the default 2x pool/upsample).
        classes: Number of output channels of the final 1x1 convolution.
        dropout: Dropout rate applied just before the output convolution.

    Returns:
        A Keras ``Model`` named ``"unet"`` whose output uses a sigmoid
        activation. The model summary is printed as a side effect.
    """
    # Named so that the builtins `filter` and `input` are not shadowed.
    encoder_filters = [64, 128, 256, 512]
    bottleneck_filters = 1024

    inputs = Input(shape=input_size)

    # Encode: keep each stage's pre-pooling features for the skip connections.
    x = inputs
    skips = []
    for n_filters in encoder_filters:
        x, skip = down_block(x, n_filters)
        skips.append(skip)

    x = down_block(x, bottleneck_filters, use_maxpool=False)

    # Decode: walk the encoder stages in reverse, deepest skip first.
    for n_filters, skip in zip(reversed(encoder_filters), reversed(skips)):
        x = up_block(x, skip, n_filters)

    x = Dropout(dropout)(x)
    output = Conv2D(classes, 1, activation="sigmoid")(x)

    model = models.Model(inputs, output, name="unet")
    model.summary()
    return model


# Build the model: 128x128 inputs with 4 channels, one sigmoid output channel.
# get_model() already prints the layer summary, so it is not printed again here.
model = get_model(input_size=(128, 128, 4), classes=1, dropout=0.2)

## Train the model

In [None]:
# Compile model for training.
# The optimizer comes from tf.keras, the same namespace the layers and metrics use.
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    loss="binary_crossentropy",
    metrics=[
        "accuracy",
        tf.keras.metrics.AUC(),
        # The model emits sigmoid probabilities; BinaryIoU thresholds them at 0.5
        # before computing IoU and averages over classes 0 and 1. MeanIoU expects
        # class indices and does not threshold. The name keeps the history key
        # that MeanIoU produced.
        tf.keras.metrics.BinaryIoU(
            target_class_ids=(0, 1), threshold=0.5, name="mean_io_u"
        ),
        tf.keras.metrics.Precision(),
        tf.keras.metrics.Recall(),
    ],
)


def trainmodel(
    model,
    xdata,
    ydata,
    *,
    epochs=10,
    batch_size=32,
    validation_split=0.2,
    checkpoint_path="checkpointMaping",
    class_weight=None,
):
    """Fit ``model`` on ``xdata``/``ydata`` and checkpoint the best epoch.

    Args:
        model: A compiled Keras model.
        xdata: Array of input images.
        ydata: Array of binary masks (values compared against 0.5).
        epochs: Number of training epochs.
        batch_size: Mini-batch size.
        validation_split: Fraction of the data held out for validation.
        checkpoint_path: Where the best model (lowest ``val_loss``) is saved.
        class_weight: Mapping ``{0: w_background, 1: w_foreground}`` giving the
            loss weight of each pixel class. Defaults to ``{0: 1, 1: 5}``.

    Returns:
        The ``History`` object returned by ``model.fit``.
    """
    if class_weight is None:
        class_weight = {0: 1, 1: 5}

    # `fit(class_weight=...)` does not support targets with 3+ dimensions, so
    # the class weights are expanded into an explicit weight for every pixel.
    pixel_weights = np.where(
        np.asarray(ydata) > 0.5, class_weight[1], class_weight[0]
    ).astype(np.float32)
    # Drop trailing singleton (channel) axes so the weights line up with the
    # per-pixel loss, which has no channel axis.
    while pixel_weights.ndim > 3 and pixel_weights.shape[-1] == 1:
        pixel_weights = pixel_weights[..., 0]

    model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
        checkpoint_path,
        monitor="val_loss",
        verbose=0,
        save_best_only=True,
        save_weights_only=False,
        mode="min",
        save_freq="epoch",
    )

    hist = model.fit(
        x=xdata,
        y=ydata,
        sample_weight=pixel_weights,
        epochs=epochs,
        batch_size=batch_size,
        # Keras holds out the LAST fraction of the samples (taken before any
        # shuffling); the same validation samples are used at every epoch.
        validation_split=validation_split,
        verbose=1,
        callbacks=[model_checkpoint_callback],
    )
    return hist

In [None]:
# Keras trains on float32 arrays.
X_train_f32 = np.asarray(X_train, dtype=np.float32)
y_train_f32 = np.asarray(y_train, dtype=np.float32)

# The masks are indexed elsewhere in this notebook as (sample, row, col, channel),
# so a channel axis is only added when it is actually missing; adding one
# unconditionally would yield 5-D targets for a model with 4-D output.
if y_train_f32.ndim == 3:
    y_train_f32 = y_train_f32[..., np.newaxis]

# Keep the History object so the training curves can be inspected later.
history = trainmodel(model, X_train_f32, y_train_f32)

## Visualize predictions

In [None]:
# Generate predictions for all images in the held-out test set
# (cast to float32, the same dtype the training inputs are cast to).
val_preds = model.predict(np.asarray(X_test, dtype=np.float32))

In [None]:
# Threshold into a NEW array. `preds = val_preds` would only alias the
# predictions, and thresholding the alias in place would overwrite the
# probabilities that later cells still read from `val_preds`.
preds = (val_preds > 0.50).astype(val_preds.dtype)
sklearn.metrics.accuracy_score(y_test.ravel(), preds.ravel())

In [None]:
# Threshold into a NEW array so the probabilities stored in `val_preds`
# are left untouched (plain assignment would alias, not copy).
preds = (val_preds > 0.50).astype(val_preds.dtype)
sklearn.metrics.f1_score(y_test.ravel(), preds.ravel())

In [None]:
n = 235  # sample number
fig, ax = plt.subplots(1, 2, figsize=(12, 6))

# The model is built with a single output channel, so channel 0 is indexed
# directly, mirroring how the mask is indexed.
im1 = ax[0].imshow(val_preds[n, :, :, 0], vmin=0, vmax=0.5, cmap="plasma")
im2 = ax[1].imshow(y_test[n, :, :, 0], cmap="plasma")
ax[0].set_title(f"Model prediction (test sample {n})")
ax[1].set_title(f"Ground-truth mask (test sample {n})")

# Plain pixel coordinates on both panels (no offset / scientific notation).
for panel in ax:
    panel.ticklabel_format(useOffset=False, style="plain")

fig.colorbar(im1, ax=ax[0])
fig.colorbar(im2, ax=ax[1]);