## Download the pre-processed data from the previous data-preparation notebook

In [4]:
# Download Ground Truth
!mkdir Dataset
!wget  https://zenodo.org/records/12751419/files/XDataPrichit.npy?download=1  -O Dataset/Xdata.npy
!wget  https://zenodo.org/records/12751419/files/YDataPrichit.npy?download=1  -O Dataset/Ydata.npy

mkdir: Dataset: File exists
zsh:1: no matches found: https://zenodo.org/records/12751419/files/XDataPrichit.npy?download=1
zsh:1: no matches found: https://zenodo.org/records/12751419/files/YDataPrichit.npy?download=1


In [8]:
# import libraries
import os

# segmentation_models selects its Keras backend from SM_FRAMEWORK at import
# time, so set it before any Keras / segmentation_models import.
os.environ["SM_FRAMEWORK"] = "tf.keras"

import numpy as np
import matplotlib.pyplot as plt
import sklearn
import sklearn.metrics  # used by the evaluation cells; `import sklearn` alone need not load this submodule
from sklearn.model_selection import train_test_split
import keras
import tensorflow as tf
import segmentation_models as sm
from segmentation_models import Unet

keras.backend.set_image_data_format("channels_last")

# Seed NumPy and TensorFlow so weight initialisation and shuffling are reproducible.
SEED = 42
np.random.seed(SEED)
tf.random.set_seed(SEED)

ModuleNotFoundError: No module named 'tensorflow'

In [None]:
# Load the arrays produced by the data-preparation notebook and move axis 1 to
# the end, giving the channels_last layout configured for the model
# (presumably the files are stored as (N, C, H, W) — TODO confirm).
Xdata = np.load("Dataset/Xdata.npy").transpose(0, 2, 3, 1)
Ydata = np.load("Dataset/Ydata.npy").transpose(0, 2, 3, 1)
# Inputs and masks are paired by index, so their sample counts must agree.
assert len(Xdata) == len(Ydata), "input and mask arrays have different sample counts"
print(f"the shape of input image matrix is {Xdata.shape}")
print(f"the shape of ground-truth matrix is {Ydata.shape}")

In [None]:
# Visual sanity check: one input channel next to its ground-truth mask.
n = 1354  # sample number
channel = 6  # which input channel to display
fig, ax = plt.subplots(1, 2, figsize=(12, 6))
ax[0].imshow(Xdata[n, :, :, channel])
ax[1].imshow(Ydata[n, :, :, 0])
ax[0].set_title(f"Input, channel {channel} (sample {n})")
ax[1].set_title(f"Ground-truth mask (sample {n})")

for a in ax:
    a.ticklabel_format(useOffset=False, style="plain")
plt.show()

In [None]:
# Hold out 30% of the samples for testing; the fixed seed keeps the split reproducible.
TEST_FRACTION = 0.3
SPLIT_SEED = 42
X_train, X_test, y_train, y_test = train_test_split(
    Xdata, Ydata, random_state=SPLIT_SEED, test_size=TEST_FRACTION
)

In [None]:
# Report the shape of each split (inputs and masks, train and test).
for split_label, split_array in (
    ("XTrain", X_train),
    ("XTest", X_test),
    ("YTrain", y_train),
    ("YTest", y_test),
):
    print(f" Size of {split_label} is {split_array.shape}")

### Define and train the model

In [None]:
# U-Net with a ResNet-34 encoder, no pretrained encoder weights, one sigmoid
# output channel, and 256x256x9 inputs.
# API reference: https://segmentation-models.readthedocs.io/en/latest/tutorial.html
unet_kwargs = dict(
    backbone_name="resnet34",
    encoder_weights=None,
    input_shape=(256, 256, 9),
    classes=1,
    activation="sigmoid",
)
model = Unet(**unet_kwargs)

In [None]:
# Configure the model for training.
# Binary cross-entropy pairs with the single-channel sigmoid output of the U-Net.
# IoU uses BinaryIoU, which thresholds the sigmoid probabilities at 0.5 and
# averages IoU over background (0) and foreground (1). MeanIoU instead expects
# integer class predictions and would truncate every probability below 1.0 to
# class 0, so it reported a meaningless value here.
model.compile(
    optimizer=keras.optimizers.Adam(1e-4),
    loss="binary_crossentropy",
    metrics=[
        "accuracy",
        tf.keras.metrics.AUC(),
        tf.keras.metrics.BinaryIoU(target_class_ids=[0, 1], threshold=0.5),
        tf.keras.metrics.Precision(),
        tf.keras.metrics.Recall(),
    ],
)


def trainmodel(
    model,
    xdata,
    ydata,
    epochs=10,
    batch_size=32,
    filepath="checkpointMaping",
    class_weights=(1.0, 5.0),
    validation_split=0.2,
):
    """Fit ``model`` on (xdata, ydata), checkpointing the best model by val_loss.

    Keras rejects ``class_weight`` for targets with three or more dimensions,
    so ``class_weight={0: 1, 1: 5}`` cannot be used with per-pixel masks. The
    same weighting is applied as a per-pixel ``sample_weight`` instead: training
    pixels whose mask value is > 0.5 get ``class_weights[1]``, all others get
    ``class_weights[0]``. As with Keras' own ``validation_split``, the last
    ``validation_split`` fraction of the samples (before shuffling) is held out;
    it is passed unweighted, so val_loss is not affected by the class weights.

    Args:
        model: a compiled Keras model.
        xdata: input images; first axis indexes samples.
        ydata: masks; the first three axes are (sample, row, col) and any
            trailing axes must have size 1 (single-channel masks).
        epochs: number of training epochs.
        batch_size: mini-batch size.
        filepath: where ModelCheckpoint saves the best model.
        class_weights: (background, foreground) per-pixel loss weights.
        validation_split: fraction of trailing samples used for validation.

    Returns:
        The History object returned by ``model.fit``.
    """
    xdata = np.asarray(xdata)
    ydata = np.asarray(ydata)

    # Same split rule Keras uses for `validation_split`: last fraction, unshuffled.
    split_at = int(len(xdata) * (1.0 - validation_split))
    x_train, x_val = xdata[:split_at], xdata[split_at:]
    y_train, y_val = ydata[:split_at], ydata[split_at:]

    # One weight per pixel, shaped like the per-pixel loss (sample, row, col);
    # reshape raises if the masks have more than one channel.
    pixel_mask = y_train.reshape(y_train.shape[:3])
    background_weight, foreground_weight = class_weights
    sample_weight = np.where(
        pixel_mask > 0.5, foreground_weight, background_weight
    ).astype(np.float32)

    model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
        filepath,
        monitor="val_loss",
        verbose=0,
        save_best_only=True,
        save_weights_only=False,
        mode="min",
        save_freq="epoch",
        options=None,
    )
    return model.fit(
        x=x_train,
        y=y_train,
        sample_weight=sample_weight,
        epochs=epochs,
        batch_size=batch_size,
        validation_data=(x_val, y_val) if len(x_val) else None,
        verbose=1,
        callbacks=[model_checkpoint_callback],
    )

In [67]:
# Train on float32 copies of the training split. The loading cell already moved
# the mask channel to the last axis (y_train is 4-D), so only add a channel axis
# when the masks are 3-D; unconditional expand_dims produced 5-D targets.
y_train_f32 = np.asarray(y_train, dtype=np.float32)
if y_train_f32.ndim == 3:
    y_train_f32 = y_train_f32[..., np.newaxis]
history = trainmodel(model, np.asarray(X_train, dtype=np.float32), y_train_f32)

In [None]:
# Generate per-pixel probability maps for every image in the held-out test split
# (X_test from train_test_split, not the validation subset used during fit).
# Cast to float32 to match the dtype used for training.
val_preds = model.predict(np.asarray(X_test, dtype=np.float32))

In [None]:
# Pixel-wise accuracy on the test split at a 0.5 probability threshold.
# Threshold into a NEW array: `preds = val_preds` only aliases it, and the old
# in-place thresholding overwrote the probabilities that later cells plot.
preds = (val_preds > 0.50).astype(val_preds.dtype)
sklearn.metrics.accuracy_score(y_test.flatten(), preds.flatten())

In [None]:
# Pixel-wise F1 score (foreground = 1) on the test split at a 0.5 threshold.
# Build a new binary array so val_preds keeps its probabilities.
preds = (val_preds > 0.50).astype(val_preds.dtype)
sklearn.metrics.f1_score(y_test.flatten(), preds.flatten())

In [None]:
# Compare the predicted foreground probability with the ground-truth mask for one test sample.
n = 235  # sample number
fig, ax = plt.subplots(1, 2, figsize=(12, 6))
# The model has a single sigmoid output channel: plot channel 0 as a 2-D heat
# map over the full probability range (the old `[:, :, :3].transpose((0, 1, 2))`
# was a no-op transpose of a 1-channel slice, and vmax=0.5 saturated every
# pixel above 0.5).
im1 = ax[0].imshow(val_preds[n, :, :, 0], vmin=0, vmax=1, cmap="plasma")
im2 = ax[1].imshow(y_test[n, :, :, 0], cmap="plasma")
ax[0].set_title(f"Predicted probability (test sample {n})")
ax[1].set_title(f"Ground truth (test sample {n})")
for a in ax:
    a.ticklabel_format(useOffset=False, style="plain")

fig.colorbar(im1, ax=ax[0])
fig.colorbar(im2, ax=ax[1]);