# Load libraries 

In [None]:
#!/usr/bin/env python3.7

import sys

# Silence every Python warning unless the interpreter was started with
# explicit -W options (in that case sys.warnoptions is non-empty and the
# user's choice is respected).
if not sys.warnoptions:
    import warnings
    warnings.simplefilter("ignore")


from osgeo import gdal
import tensorflow as tf
import numpy as np
from matplotlib import pyplot as plt
import tensorflow.keras.backend as K
import numpy as np
from sklearn.utils.class_weight import compute_class_weight
from tensorflow.python.keras import Model
from tensorflow.python.keras.layers import Input, Conv2D, MaxPooling2D, Conv2DTranspose, Concatenate, Dropout
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import Loss
from tensorflow.keras.callbacks import ReduceLROnPlateau
from segmentation_models.metrics import iou_score


# Target tile size in pixels, for resizing if needed.
# NOTE(review): SIZE_X / SIZE_Y are not referenced in the visible cells; the
# model's input size is taken from X_train.shape in the training cell.
SIZE_X = 256 
SIZE_Y = 256
# Number of segmentation classes; used for one-hot encoding, np.bincount in
# compute_weights and the U-Net softmax output layer.
n_classes=3 #Number of classes for segmentation

# Load datasets from directory

In [None]:
# Load every .tif file in a directory into a single stacked array
def load_data(directory):
    """
    Load multi-band .tif files from a directory into a NumPy array.

    Files are read in sorted filename order, so two directories whose files
    share names (e.g. images and their label masks) load in matching order.

    Args:
        directory (str): Path to the directory containing .tif files.

    Returns:
        tuple: ``(stacked_array, filenames)``. ``stacked_array`` is a 4D NumPy
        array ``(num_files, height, width, bands)`` and ``filenames`` lists the
        files in the order they were loaded. Returns ``(None, [])`` when no
        .tif file could be loaded.

    Raises:
        ValueError: if the loaded files do not all have the same shape.
    """
    # `os` is imported locally because the notebook's import cell never
    # imports it; relying on a module-level name raised NameError here.
    import os

    tif_files = sorted(f for f in os.listdir(directory) if f.endswith('.tif'))

    arrays = []
    filenames = []

    for tif_file in tif_files:
        file_path = os.path.join(directory, tif_file)

        dataset = gdal.Open(file_path)
        if dataset is None:
            # Unreadable file: report it and skip it rather than abort the load.
            print(f"Failed to load {file_path}")
            continue

        # GDAL band indices are 1-based; stack the bands along the last axis.
        bands = [dataset.GetRasterBand(i).ReadAsArray()
                 for i in range(1, dataset.RasterCount + 1)]
        arrays.append(np.stack(bands, axis=-1))
        filenames.append(tif_file)

        # Release the dataset handle (GDAL closes a dataset when its last
        # reference is dropped).
        dataset = None

    if not arrays:
        print("No .tif files loaded.")
        return None, []

    # np.stack would fail with an opaque message on mismatched tiles; name
    # the offending shapes instead.
    shapes = sorted({a.shape for a in arrays})
    if len(shapes) > 1:
        raise ValueError(
            f"Cannot stack .tif files of differing shapes {shapes} from {directory}"
        )

    stacked_array = np.stack(arrays, axis=0)
    print(f"Loaded {len(arrays)} .tif files into array of shape {stacked_array.shape}")
    return stacked_array, filenames

In [None]:
# load_data returns (array, filenames); only the stacked array is kept.
# Masks and images are paired by index, which relies on both directories
# holding matching filenames (load_data sorts by name) — TODO confirm.
# NOTE(review): if nothing loads, load_data returns (None, []) and the
# variables below become None.
train_masks = load_data("path to training labels/")
y_train  = train_masks[0]

train_images = load_data("path to train images/")
X_train = train_images[0]

In [None]:
# Inspect the distinct label values in the training masks; to_categorical
# below expects them to lie in 0..n_classes-1.
np.unique(y_train)

In [None]:
# Validation images and masks, loaded the same way as the training data
# (index [0] keeps only the stacked array).
val_images = load_data("path to validation images/")
X_val = val_images[0]

val_masks = load_data("path to Validation labels/")
y_val = val_masks[0]

In [None]:
# Drop the (array, filenames) tuples; the arrays themselves remain reachable
# through X_train / y_train / X_val / y_val.
del train_images, val_masks, val_images, train_masks 

## Data augmentation

In [None]:
# Rotate images and masks together by a multiple of 90° (90°, 180°, 270°)
def rotate_image_stack(image_stack, mask, k):
    """
    Rotate the image stack and mask by k * 90 degrees (counter-clockwise).

    ``tf.image.rot90`` rotates the two spatial axes of a 3-D ``(H, W, C)``
    image or a 4-D ``(N, H, W, C)`` batch. A mask that lacks the trailing
    channel axis (e.g. ``(N, H, W)`` next to a 4-D image stack) would be read
    as a single ``(H, W, C)`` image and rotated across the wrong axes. Such a
    mask therefore gets a temporary channel axis, which is removed after the
    rotation.

    Args:
        image_stack: 4D tensor of shape (samples, height, width, channels).
        mask: Label tensor aligned with ``image_stack``, either with a channel
            axis (same rank as ``image_stack``) or without one (rank one lower).
        k: Number of 90-degree rotations (1 for 90°, 2 for 180°, 3 for 270°).

    Returns:
        Tuple ``(rotated_image_stack, rotated_mask)``; the mask keeps its
        original rank.
    """
    image_stack = tf.convert_to_tensor(image_stack)
    mask = tf.convert_to_tensor(mask)

    rotated_images = tf.image.rot90(image_stack, k=k)

    if mask.shape.rank == image_stack.shape.rank - 1:
        # Mask has no channel axis: add one so rot90 sees the same layout as
        # the images, then drop it again.
        rotated_mask = tf.squeeze(tf.image.rot90(mask[..., tf.newaxis], k=k), axis=-1)
    else:
        rotated_mask = tf.image.rot90(mask, k=k)

    return rotated_images, rotated_mask

In [None]:
# Join an original dataset with any number of augmented (images, masks) pairs
def stack_datasets(original_images, original_masks, *augmented_datasets):
    """
    Concatenate the original dataset with augmented copies along axis 0.

    Args:
        original_images: Original image stack; axis 0 is the sample axis.
        original_masks: Original masks aligned with ``original_images``.
        *augmented_datasets: ``(images, masks)`` pairs, appended in order.

    Returns:
        Tuple ``(images, masks)`` with the originals first, followed by each
        augmented pair in the order given.
    """
    image_parts = [original_images]
    mask_parts = [original_masks]
    for pair in augmented_datasets:
        image_parts.append(pair[0])
        mask_parts.append(pair[1])

    return tf.concat(image_parts, axis=0), tf.concat(mask_parts, axis=0)

In [None]:
# Mirror an image stack and its mask left-to-right with probability 0.5
def random_flipHZ(image_stack, mask):
    """
    Horizontally flip ``image_stack`` and ``mask`` together, or neither.

    A single scalar draw from ``tf.random.uniform`` decides for the whole
    input, so images and masks always stay aligned.

    Returns:
        Tuple ``(image_stack, mask)``, either both flipped or both unchanged.
    """
    should_flip = tf.random.uniform(()) > 0.5
    if should_flip:
        return tf.image.flip_left_right(image_stack), tf.image.flip_left_right(mask)
    return image_stack, mask

# Mirror an image stack and its mask top-to-bottom with probability 0.5
def random_flipVT(image_stack, mask):
    """
    Vertically flip ``image_stack`` and ``mask`` together, or neither.

    A single scalar draw from ``tf.random.uniform`` decides for the whole
    input, so images and masks always stay aligned.

    Returns:
        Tuple ``(image_stack, mask)``, either both flipped or both unchanged.
    """
    should_flip = tf.random.uniform(()) > 0.5
    if should_flip:
        return tf.image.flip_up_down(image_stack), tf.image.flip_up_down(mask)
    return image_stack, mask

In [None]:
# Append exactly one augmented (images, masks) pair to an existing dataset
def stack_datasets2(original_images, original_masks, augmented_images, augmented_masks):
    """
    Concatenate one augmented dataset after the original along axis 0.

    Returns:
        Tuple ``(images, masks)`` with the original samples first.
    """
    pairs = (
        (original_images, augmented_images),
        (original_masks, augmented_masks),
    )
    return tuple(tf.concat([base, extra], axis=0) for base, extra in pairs)

In [None]:
# Convert to TensorFlow tensors for the tf.image ops below: images as
# float32, integer class masks as uint16.
image_stack = tf.convert_to_tensor(X_train, dtype=tf.float32)
mask = tf.convert_to_tensor(y_train, dtype=tf.uint16)

# Apply different rotations to create augmented datasets; images and masks
# are rotated together so labels stay aligned.
rotated_90, mask_90 = rotate_image_stack(image_stack, mask, k=1)  # 90°
rotated_180, mask_180 = rotate_image_stack(image_stack, mask, k=2)  # 180°
rotated_270, mask_270 = rotate_image_stack(image_stack, mask, k=3)  # 270°

In [None]:
# Stack the original dataset with all rotated copies. Resulting sample order
# along axis 0: original, 90°, 180°, 270°. X_train / y_train are now tf
# tensors; image_stack / mask still hold the un-rotated originals.
X_train, y_train = stack_datasets(
    image_stack, mask,
    (rotated_90, mask_90),
    (rotated_180, mask_180),
    (rotated_270, mask_270)
)

In [None]:
# Create flipped copies of the ORIGINAL (un-rotated) training stack.
# The flips are applied unconditionally. random_flipVT / random_flipHZ decide
# with one coin toss for the whole batch, so about half the time they return
# the input unchanged. The "augmented" block appended to X_train would then
# be an exact duplicate of the original samples.
# The mask is 4-D here (load_data always stacks bands on the last axis), so
# the tf.image flips act on each mask exactly as on each image.
augmented_imageVT = tf.image.flip_up_down(image_stack)  # vertical flip
augmented_maskVT = tf.image.flip_up_down(mask)
augmented_imageHZ = tf.image.flip_left_right(image_stack)  # horizontal flip
augmented_maskHZ = tf.image.flip_left_right(mask)

In [None]:
# Free the per-rotation tensors and the un-rotated originals; their data now
# lives inside X_train / y_train and the flipped copies.
del rotated_90, mask_90, image_stack, mask,rotated_180, mask_180,rotated_270, mask_270

In [None]:
# Append the vertically flipped copy, then the horizontally flipped copy.
# Final sample order: original, 90°, 180°, 270°, vertical flip, horizontal flip.
X_train, y_train = stack_datasets2(X_train, y_train, augmented_imageVT, augmented_maskVT)

X_train, y_train = stack_datasets2(X_train, y_train, augmented_imageHZ, augmented_maskHZ)

In [None]:
# The flipped copies are now part of X_train / y_train.
del augmented_imageVT, augmented_maskVT, augmented_imageHZ, augmented_maskHZ

In [None]:
# Back to NumPy for scikit-learn's StandardScaler (y_train stays a tf tensor
# until the one-hot cell).
X_train = np.array(X_train)

## Standardisation

In [None]:
from sklearn.preprocessing import StandardScaler

# Flatten to (num_pixels, channels) so StandardScaler learns one mean/std per
# channel (band) across all pixels of all samples.
# NOTE: `width` / `height` are only used to reshape back afterwards, so the
# naming of axes 1 and 2 does not matter.
samples, width, height, channels = X_train.shape
X_train_reshaped = X_train.reshape(-1, channels)  # Shape: (samples*width*height, channels)

In [None]:
from joblib import dump, load
# Initialize StandardScaler
scaler = StandardScaler()

# Fit per-channel mean/std on the (augmented) training pixels and standardize them.
X_train_normalized = scaler.fit_transform(X_train_reshaped)  # Normalizes each channel

# Persist the fitted scaler so inference can reuse the same standardization
# (uncomment to write it with joblib).
#dump(scaler, 'scalerparameters_4model_predictions.joblib')

In [None]:
num_samples, width2, height2, num_channels = X_val.shape
X_val_reshaped = X_val.reshape(-1, num_channels)

# Standardize validation pixels with the scaler fitted on the training data
# (transform only, no refit, so no validation statistics leak into training).
X_val_normalized = scaler.transform(X_val_reshaped)

# Restore the original 4-D layouts.
X_train = X_train_normalized.reshape(samples, width, height, channels)
X_val = X_val_normalized.reshape(num_samples, width2, height2, num_channels)

In [None]:
# Release the flattened intermediate arrays.
del X_train_normalized, X_val_normalized, X_val_reshaped, X_train_reshaped

## One-hot encoding of labels

In [None]:
#from keras.utils import to_categorical

# One-hot encode the integer masks into n_classes channels to match the
# softmax output. NOTE(review): masks come from load_data with a trailing
# band axis (size 1 for single-band masks); presumably to_categorical drops
# it, giving (N, H, W, n_classes). Confirm against the model output shape.
y_train_cat = tf.keras.utils.to_categorical(y_train, num_classes=n_classes)
y_val_cat = tf.keras.utils.to_categorical(y_val,num_classes=n_classes)

# Integer labels as NumPy, consumed by compute_weights below.
y_train = np.array(y_train)

## Class weight calculation

In [None]:
from sklearn.utils.class_weight import compute_class_weight

# Compute class weights dynamically and halve the weight of the majority class
def compute_weights(y_true):
    """
    Compute 'balanced' per-class weights from integer labels, then halve the
    weight of the most frequent class.

    Args:
        y_true: Array-like of integer class labels (any shape; flattened).
            This argument is now actually used; the previous version silently
            read the global ``y_train`` instead.

    Returns:
        dict: ``{class_index: weight}`` for every class index from 0 up to
        ``max(n_classes, max_label + 1) - 1``. Classes absent from ``y_true``
        get weight 0.0. Weights are indexed by class label, not by position
        among the present classes.
    """
    y_flat = np.asarray(y_true).reshape(-1).astype(np.int64)

    class_counts = np.bincount(y_flat, minlength=n_classes)

    # compute_class_weight only accepts classes that occur in y, and returns
    # weights in the order of `present`, not indexed by label.
    present = np.unique(y_flat)
    present_weights = compute_class_weight(class_weight='balanced', classes=present, y=y_flat)

    print("Class weights before adjustment:", present_weights)

    # Scatter into a label-indexed array so a missing class cannot shift
    # later classes onto the wrong weight.
    class_weights = np.zeros(len(class_counts), dtype=np.float64)
    class_weights[present] = present_weights

    # Halve the weight of the majority class
    majority_class_index = int(np.argmax(class_counts))
    class_weights[majority_class_index] /= 2
    print("Class weights After adjustment:", class_weights)

    return {i: class_weights[i] for i in range(len(class_counts))}


def normalize_class_weights(class_weights):
    """
    Rescale class weights so they sum to 1.

    Args:
        class_weights: Sequence of per-class weights.

    Returns:
        np.ndarray: float32 array of the weights divided by their sum.

    Raises:
        ValueError: if the weights sum to zero.
    """
    weights = np.asarray(class_weights, dtype=np.float32)
    weight_sum = weights.sum()
    if weight_sum != 0:
        return weights / weight_sum
    raise ValueError("Sum of class weights cannot be zero.")


# Compute class weights dynamically
class_weights = compute_weights(y_train)

# Dict -> list ordered by class index, then scaled to sum to 1. The result is
# passed as the focal-loss `alpha` when the model is compiled.
class_weights = [class_weights[k] for k in sorted(class_weights.keys())]
normalized_weights = normalize_class_weights(class_weights)

## Loss functions and metrics

In [None]:
# Soft Dice loss for multi-class segmentation
def dice_loss(y_true, y_pred, smooth=1e-6):
    """
    Return ``1 - Dice coefficient`` between ``y_true`` and ``y_pred``.

    Both tensors are flattened, so this is one global overlap score across all
    samples, pixels and classes (not a per-class average). ``smooth`` is added
    to numerator and denominator to avoid division by zero.
    """
    truth = K.flatten(y_true)
    probs = K.flatten(y_pred)

    overlap = K.sum(truth * probs)
    denominator = K.sum(truth) + K.sum(probs) + smooth
    return 1 - (2. * overlap + smooth) / denominator

# A few useful metrics and losses

def jacard_coef(y_true, y_pred):
    """
    Smoothed Jaccard index (IoU) between flattened ``y_true`` and ``y_pred``.

    Computed as ``(intersection + 1) / (sum_true + sum_pred - intersection + 1)``
    over all elements at once; the +1 terms keep empty inputs finite.
    """
    truth = K.flatten(y_true)
    probs = K.flatten(y_pred)

    overlap = K.sum(truth * probs)
    union = K.sum(truth) + K.sum(probs) - overlap
    return (overlap + 1.0) / (union + 1.0)

def focal_loss(y_true, y_pred, gamma=2.0, alpha=0.25):
    """
    Focal Loss for multiclass segmentation, averaged over all positions.

    Per position: sum over classes of
    ``alpha * (1 - p) ** gamma * (-y_true * log(p))``, where ``p`` is
    ``y_pred`` clipped to ``[K.epsilon(), 1 - K.epsilon()]``.

    Args:
        y_true: One-hot encoded ground truth, classes on the last axis.
        y_pred: Predicted probabilities, same shape as ``y_true``.
        gamma: Focusing exponent; larger values down-weight easy examples.
        alpha: Balancing factor, either a scalar or (as this notebook passes)
            a per-class vector that broadcasts over the last axis.

    Returns:
        Scalar loss tensor.
    """
    eps = K.epsilon()
    probs = tf.clip_by_value(y_pred, eps, 1. - eps)

    ce = -y_true * tf.math.log(probs)
    modulating = tf.pow(1 - probs, gamma)
    per_class = alpha * modulating * ce

    return tf.reduce_mean(tf.reduce_sum(per_class, axis=-1))

def combined_dice_focal_loss(gamma=2.0, alpha=0.25, weight_dice=0.5, weight_focal=0.5, smooth=1e-6):
    """
    Build a Keras loss ``weight_dice * dice_loss + weight_focal * focal_loss``.

    Args:
        gamma: Focusing parameter forwarded to focal_loss.
        alpha: Balancing factor (scalar or per-class vector) for focal_loss.
        weight_dice: Multiplier on the Dice term.
        weight_focal: Multiplier on the focal term.
        smooth: Smoothing constant forwarded to dice_loss.

    Returns:
        A callable ``loss(y_true, y_pred)`` suitable for ``model.compile``.
    """
    def loss(y_true, y_pred):
        dice_term = weight_dice * dice_loss(y_true, y_pred, smooth=smooth)
        focal_term = weight_focal * focal_loss(y_true, y_pred, gamma=gamma, alpha=alpha)
        return dice_term + focal_term

    return loss

## Model development

In [None]:
def _unet_conv_block(x, filters, dropout=None):
    """Two 3x3 ReLU convs (he_normal init, 'same' padding), optional dropout between them."""
    x = tf.keras.layers.Conv2D(filters, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(x)
    if dropout:
        x = tf.keras.layers.Dropout(dropout)(x)
    x = tf.keras.layers.Conv2D(filters, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(x)
    return x


def unet_model(IMG_WIDTH, IMG_HIGHT, IMG_CHANNELS, num_classes=None):
    """
    Build an (uncompiled) 4-level U-Net for multi-class segmentation.

    NOTE: despite their names, the first two arguments become the first two
    dimensions of the Input shape in the order given. The training cell
    passes ``(IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS)``.

    Architecture: four encoder levels (32/64/128/256 filters, dropout 0.25
    between the two convs, 2x2 max-pooling), a 512-filter bottleneck, and four
    decoder levels (2x2 transposed conv, skip concatenation, two convs),
    followed by a 1x1 softmax conv.

    Args:
        IMG_WIDTH: Size of the first spatial input axis.
        IMG_HIGHT: Size of the second spatial input axis.
        IMG_CHANNELS: Number of input channels (bands).
        num_classes: Number of output classes; defaults to the module-level
            ``n_classes``.

    Returns:
        tf.keras.Model mapping ``(dim1, dim2, IMG_CHANNELS)`` inputs to
        per-pixel class probabilities ``(dim1, dim2, num_classes)``.

    Raises:
        ValueError: if a known spatial size is not divisible by 16. The four
            poolings would otherwise make the skip concatenations fail with
            a shape mismatch.
    """
    if num_classes is None:
        num_classes = n_classes

    for dim in (IMG_WIDTH, IMG_HIGHT):
        if dim is not None and dim % 16 != 0:
            raise ValueError(
                f"U-Net spatial input sizes must be divisible by 16, got {(IMG_WIDTH, IMG_HIGHT)}"
            )

    inputs = tf.keras.layers.Input((IMG_WIDTH, IMG_HIGHT, IMG_CHANNELS))

    # Contraction path: keep each level's output for the skip connections.
    skips = []
    x = inputs
    for filters in (32, 64, 128, 256):
        x = _unet_conv_block(x, filters, dropout=0.25)
        skips.append(x)
        x = tf.keras.layers.MaxPooling2D((2, 2))(x)

    # Bottleneck
    x = _unet_conv_block(x, 512, dropout=0.25)

    # Expansive path: upsample, concatenate the matching encoder level, convolve.
    for filters, skip in zip((256, 128, 64, 32), reversed(skips)):
        x = tf.keras.layers.Conv2DTranspose(filters, (2, 2), strides=(2, 2), padding='same')(x)
        x = tf.keras.layers.concatenate([x, skip])
        x = _unet_conv_block(x, filters)

    outputs = tf.keras.layers.Conv2D(num_classes, (1, 1), activation='softmax')(x)

    return tf.keras.Model(inputs=[inputs], outputs=[outputs])

In [None]:
#######################################
# Parameters for model

IMG_HEIGHT = X_train.shape[1]
IMG_WIDTH  = X_train.shape[2]
IMG_CHANNELS = X_train.shape[3]
num_labels = 3  #Multiclass
input_shape = (IMG_HEIGHT,IMG_WIDTH,IMG_CHANNELS)

model = unet_model(IMG_HEIGHT,IMG_WIDTH,IMG_CHANNELS)

# `learning_rate` is the documented keyword. `lr` is a deprecated alias in
# tf.keras and is not accepted by Keras 3 optimizers.
optimizer = Adam(learning_rate=1e-3)

# Weighted Dice + focal loss. `alpha` is the per-class weight vector from the
# class-weight cell and is forwarded to focal_loss.
loss_fn = combined_dice_focal_loss(gamma=2.0, alpha=normalized_weights, weight_dice=0.6, weight_focal=0.4)

model.compile(optimizer=optimizer, loss=loss_fn, metrics=['accuracy', jacard_coef, iou_score])

# Define learning rate reduction callback
reduce_lr = ReduceLROnPlateau(
    monitor='val_loss',  # Metric to monitor
    factor=0.1,         # Factor by which to reduce (new_lr = lr * factor)
    patience=5,         # Number of epochs with no improvement
    verbose=1           # Print message when LR reduces
)

# shuffle=True: X_train is laid out as contiguous blocks (original, 90°,
# 180°, 270°, vertical flip, horizontal flip). Without shuffling, every batch
# is drawn from a single augmentation block and the order never changes
# between epochs.
results = model.fit(X_train, y_train_cat, batch_size=16, verbose=1, epochs=30,
                    validation_data=(X_val, y_val_cat),
                    callbacks=[reduce_lr],
                    shuffle=True)

In [None]:
'''Prediction over the validation dataset'''
# Per-pixel class probabilities from the softmax output.
pred_test = model.predict(X_val)

# Collapse the class axis to the most probable class index per pixel.
pred_test = np.argmax(pred_test, axis=-1)

In [None]:
#save model for predictions
# NOTE(review): the model was compiled with a custom loss closure and custom
# metrics (jacard_coef, iou_score). Reloading this file will likely need
# `custom_objects=` or `compile=False`; confirm.
model.save('UNeTASMMonitoring_30_epochs_FocalLossAndDiceLoss.hdf5')