
# Notebook for U-Net and Attention U-Net models

___

## PAPER: Comparison of Conventional Machine Learning and Convolutional Deep Learning models for Seagrass Mapping using Satellite Imagery

#### Antonio Mederos-Barrera (mederosbarrera.antonio@gmail.com)

### Imports

In [None]:
# General
import os
import sys
# import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import torchvision
import torchvision.transforms as transforms
# from torchvision.transforms import CenterCrop
from tqdm.notebook import tqdm
# import matplotlib
import matplotlib.pyplot as plt
import random
import pandas as pd
import seaborn as sns
from sklearn.metrics import confusion_matrix
import scipy.io
from scipy.stats import mode
from torch.optim.lr_scheduler import ExponentialLR

# Utils
sys.path.append(os.path.abspath('.'))
from utils.class_weights import obtain_class_weights
from utils.dice_loss_function import dice_loss
from utils.early_stopping import EarlyStopping
from utils.error_metrics import error_metrics
from utils.load_data import find_files, mean_std_dataset, data_train_test, SegmentationDataset
from utils.map_estimation_functions import mode_no_zeros, numpy_to_torch, obtain_mode_no_zeros, predict_model, sliding_window
from utils.model_initialization import initialize_weights_kaiming
from models.unet import UNet, DoubleConv, Down, Up, OutConv

### Hyperparameters

In [None]:
# GPU use: run on CUDA when it is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print("DEVICE:", DEVICE)

# Flag forwarded to the data-loading helper; enabled only when running on CUDA.
PIN_MEMORY = DEVICE == "cuda"

# Reproducibility
def set_manual_seed(seed):
    """Seed every random number generator used by this notebook.

    PyTorch (CPU), NumPy and Python's ``random`` module are always seeded.
    When CUDA is available, the CUDA generators are seeded too and cuDNN is
    switched to deterministic mode with benchmarking disabled.

    Args:
        seed: Seed value applied to all generators.
    """
    # Same order as always: torch first, then NumPy, then the stdlib.
    for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
        seed_fn(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
        
# Random seed for reproducibility
SEED = 0
set_manual_seed(SEED)

# Batch size (can be "all" to use all the training and test dataset)
BATCH_SIZE = "all"

# Epochs (upper limit in case Early stopping does not converge)
NUM_EPOCHS = 2000

# Patience of Early Stopping
PATIENCE_EARLYSTOPPING = 140

# Image size (first and second dimension) and image channels (third dimension)
INPUT_WIDTH, INPUT_HEIGHT = 70, 70
NUM_CHANNELS = 8

# Path to the dataset
IMAGE_DATASET_PATH = "C:/Path/To/Data"

# Number of classes
NUM_CLASSES = 5

# Vector images for the test dataset
VECTOR_IDX_TEST = [6, 10, 9, 7, 13, 17, 19, 25, 29, 34, 37, 42, 46, 50]

# Results folder and the file-name prefixes derived from it
PATH = "C:/Path/To/Save/Results/"
MODEL_PATH = f"{PATH}unet_"
PLOT_PATH = f"{PATH}plot_"
NPY_PATH = f"{PATH}lossfunc_"

# Models hyperparameters. Every entry is ordered as
# [lossFunc, LR, L2, ExpSchedGamma, AttGates].
#
# NOTE(review): the learning rate of Model1 was missing in the original cell
# (`["BCE", , 0.01, ...]`), which is a syntax error. 0.001 (the default
# learning rate of torch.optim.Adam) is used here as a placeholder so the
# cell runs. TODO: confirm the value actually used in the experiments.
Model1 = ["BCE", 0.001, 0.01, None, False]     # U-Net (BCE)
Model2 = ["GDL", 0.001, 0.01, 0.999, False]    # U-Net (GDL)
Model3 = ["GDL", 0.0001, 0.00001, None, True]  # Attention U-Net
Models = [Model1, Model2, Model3]

# Estimation
IMG_PATH = 'C:Path/To/Cmomplete/WorldView/Image.mat'  # Image

# Slide windows
window_size = (70, 70)
stride = 5

# Colors (one row per class, indexed by the predicted class value)
mymap = np.array((
    (0, 0, 0),        # Land
    (139, 0, 0),      # Red algae
    (0, 139, 1),      # C. nodosa (sebadal)
    (142, 87, 2),     # Rock
    (255, 255, 0),    # Sand
))

### Training and test

In [None]:

def split_params_for_weight_decay(model):
    """Partition the trainable parameters of ``model`` for weight decay.

    A parameter is excluded from weight decay when its name (compared in
    lower case) contains "bn", "norm", "bias" or "att". Parameters with
    ``requires_grad == False`` are left out of both groups.

    Args:
        model: Object exposing ``named_parameters()``.

    Returns:
        Tuple ``(decay, no_decay)`` of parameter lists, each in the order
        yielded by ``named_parameters()``.
    """
    excluded_markers = ("bn", "norm", "bias", "att")
    groups = {False: [], True: []}  # key: "is excluded from weight decay"

    for name, param in model.named_parameters():
        if param.requires_grad:
            lowered = name.lower()
            excluded = any(marker in lowered for marker in excluded_markers)
            groups[excluded].append(param)

    return groups[False], groups[True]

def plot_error_curve(matrix_data_train_test, limit=None):
    """Plot the training and test loss curves.

    Args:
        matrix_data_train_test: 2-D array indexed as ``[epoch, 0]`` for the
            training loss and ``[epoch, 1]`` for the test loss.
        limit: Optional upper bound. Values above it are drawn at ``limit``;
            the input array itself is left untouched.
    """
    curves = matrix_data_train_test
    if limit is not None:
        # np.minimum builds a new array. The previous in-place assignment
        # also clipped the caller's data, because callers pass views
        # (slices) of the loss matrix.
        curves = np.minimum(matrix_data_train_test, limit)

    plt.style.use("ggplot")
    plt.figure()
    plt.plot(curves[:, 0], label="Training loss")
    plt.plot(curves[:, 1], label="Test loss")
    plt.xlabel("Epoch #")
    plt.ylabel("Loss")
    plt.legend(loc="upper right")
    plt.show()
    

def delete_file(directory, start, end):
    """Delete the regular files in ``directory`` matching a prefix and suffix.

    Only entries whose name starts with ``start`` and ends with ``end`` are
    considered, and only regular files are removed. A failed removal is
    reported on stdout and does not stop the scan.

    Args:
        directory: Folder whose entries are scanned (not recursive).
        start: Required file-name prefix.
        end: Required file-name suffix.
    """
    for entry_name in os.listdir(directory):
        name_matches = entry_name.startswith(start) and entry_name.endswith(end)
        if not name_matches:
            continue

        candidate = os.path.join(directory, entry_name)
        if not os.path.isfile(candidate):
            continue

        try:
            os.remove(candidate)
        except OSError as e:
            print(f"Error deleting the file {candidate}: {e}")

# Build the train/test loaders and datasets once up front; the training
# loader is then used to derive the per-class weights.
(
    trainLoader,
    testLoader,
    trainDS,
    testDS,
    BATCH_SIZE_TRAIN,
    BATCH_SIZE_TEST,
) = data_train_test(
    IMAGE_DATASET_PATH,
    INPUT_WIDTH,
    INPUT_HEIGHT,
    NUM_CLASSES,
    BATCH_SIZE,
    PIN_MEMORY,
    VECTOR_IDX_TEST,
)
class_weights = obtain_class_weights(trainLoader, NUM_CLASSES, DEVICE)


def _run_tag(loss_name, lr_value, scheduler_gamma, reg_L2_value, epoch, use_attention_gates):
    """Return the hyperparameter suffix shared by every file saved for a run."""
    return (f"{loss_name}_LR-{lr_value}_ExpSchedGamma-{scheduler_gamma}"
            f"_L2-{reg_L2_value}_EPOCHS-{epoch}_EarlyStopping_AG-{use_attention_gates}")


def _plot_learning_rates(lr_history, title):
    """Plot the learning rate recorded at each epoch under ``title``."""
    # plt.figure() has to be called; the original referenced `plt.figure`
    # without parentheses, so no new figure was explicitly created.
    plt.figure()
    plt.style.use("ggplot")
    plt.plot(lr_history)
    plt.xlabel('Epochs #')
    plt.ylabel('Learning Rate')
    plt.title(title)
    plt.grid(True)
    plt.show()


# Train every configured model in turn.
for indx_print, mdl_hyp in enumerate(Models, start=1):

    # Hyperparameters (lossFunc, LR, L2, ExpSchedGamma, AttGates)
    loss_name, lr_value, reg_L2_value, scheduler_gamma, use_attention_gates = mdl_hyp[:5]
    if loss_name == "BCE":
        lossFunc = nn.BCEWithLogitsLoss()
    else:
        lossFunc = dice_loss(class_weights=class_weights)

    # Initial print
    print("[",indx_print,"of",len(Models),"]")

    # Random seed (reset so every model starts from the same RNG state)
    set_manual_seed(SEED)

    # Model
    if use_attention_gates:
        mdl = UNet(n_channels=NUM_CHANNELS,
                   n_classes=NUM_CLASSES,
                   use_attention_gates=True)
    else:
        mdl = UNet(n_channels=NUM_CHANNELS,
                   n_classes=NUM_CLASSES)
    initialize_weights_kaiming(mdl)
    mdl.to(DEVICE)

    # Optimizer: with L2 regularization, weight decay is applied only to the
    # "decay" group returned by split_params_for_weight_decay.
    if reg_L2_value is None:
        opt = Adam(mdl.parameters(), lr=lr_value)
    else:
        decay, no_decay = split_params_for_weight_decay(mdl)
        opt = Adam(
            [
                {"params": decay, "weight_decay": reg_L2_value},
                {"params": no_decay, "weight_decay": 0.0}
            ],
            lr=lr_value
        )

    # Scheduler LR (reset per model so nothing leaks from the previous run)
    scheduler = None
    lrs_values = []
    if scheduler_gamma is not None:
        scheduler = ExponentialLR(opt, gamma=scheduler_gamma)

    # Data (rebuilt after re-seeding)
    trainLoader, testLoader, trainDS, testDS, BATCH_SIZE_TRAIN, BATCH_SIZE_TEST = data_train_test(
        IMAGE_DATASET_PATH,
        INPUT_WIDTH,
        INPUT_HEIGHT,
        NUM_CLASSES,
        BATCH_SIZE,
        PIN_MEMORY,
        VECTOR_IDX_TEST)

    # Steps: number of batches actually iterated per epoch. len(loader) stays
    # correct when the dataset size is not a multiple of the batch size,
    # whereas len(dataset) // batch_size can under-count (or be zero).
    trainSteps = len(trainLoader)
    testSteps = len(testLoader)

    # Early Stopping
    early_stopping = EarlyStopping(patience=PATIENCE_EARLYSTOPPING,
                                   path_base=MODEL_PATH,
                                   mode='min',
                                   loss_func=loss_name,
                                   lr_value=lr_value,
                                   scheduler_gamma=scheduler_gamma,
                                   verbose=False,
                                   min_delta=0.00001,
                                   reg_L2_value=reg_L2_value,
                                   use_attention_gates=use_attention_gates)

    # Training and test epochs; column 0 = training loss, column 1 = test loss
    matrix_data_train_test = np.zeros((NUM_EPOCHS, 2))
    for e in tqdm(range(NUM_EPOCHS)):

        # Train
        mdl.train()
        totalTrainLoss = 0
        totalTestLoss = 0
        for x, y in trainLoader:
            opt.zero_grad()
            x, y = x.to(DEVICE), y.to(DEVICE)
            pred = mdl(x)
            loss = lossFunc(pred, y)
            loss.backward()
            opt.step()
            # Detach before accumulating so the autograd graph of each batch
            # is not kept alive through the running total.
            totalTrainLoss += loss.detach()

        # Test
        with torch.no_grad():
            mdl.eval()
            for x, y in testLoader:
                x, y = x.to(DEVICE), y.to(DEVICE)
                pred = mdl(x)
                val_loss = lossFunc(pred, y)
                totalTestLoss += val_loss

        # Training and test losses
        avgTrainLoss = totalTrainLoss / trainSteps
        avgTestLoss = totalTestLoss / testSteps
        matrix_data_train_test[e, 0] = avgTrainLoss.item()
        matrix_data_train_test[e, 1] = avgTestLoss.item()

        # Update of Scheduler LR
        if scheduler is not None:
            lrs_values.append(scheduler.get_last_lr()[0])
            scheduler.step()

        # Early Stopping
        early_stopping.epoch = e
        early_stopping.lr = lr_value
        early_stopping(avgTestLoss, mdl)
        if early_stopping.early_stop:

            best_epoch = int(e-PATIENCE_EARLYSTOPPING)
            print("Early stopping in the epoch:", best_epoch)
            run_tag = _run_tag(loss_name, lr_value, scheduler_gamma,
                               reg_L2_value, best_epoch, use_attention_gates)

            # Save (done before plotting so the stored curves are the raw ones)
            NPY_PATH_ = NPY_PATH + run_tag + ".npy"
            np.save(NPY_PATH_, matrix_data_train_test[:best_epoch+1,:])
            NPY_PATH_ALL_ = NPY_PATH + run_tag + "_ALL.npy"
            np.save(NPY_PATH_ALL_, matrix_data_train_test)

            # Error Curves
            MODEL_PATH_ = MODEL_PATH + run_tag + ".pth"
            print(MODEL_PATH_)
            plot_error_curve(matrix_data_train_test[:best_epoch+1,:])
            plot_error_curve(matrix_data_train_test[:best_epoch+1,:], limit=1.0)

            if scheduler_gamma is not None:
                # LRs values (Schedulers)
                lrs_values_np = np.array(lrs_values)
                _plot_learning_rates(
                    lrs_values_np[:best_epoch+1],
                    "Learning Rate Exponential Scheduler. Gamma: "+str(scheduler_gamma))

            break

    # Early stopping is not triggered
    if not early_stopping.early_stop:
        print("Early stopping is not triggered")
        plot_error_curve(matrix_data_train_test)
        plot_error_curve(matrix_data_train_test, limit=1.0)

        if scheduler_gamma is not None:
            # LRs values (Schedulers)
            _plot_learning_rates(
                lrs_values,
                "Learning Rate Exponential Schedule. Gamma:"+str(scheduler_gamma))

        # best_epoch is only assigned when early stopping fires, so the
        # original code raised NameError here (or reused the previous model's
        # value). The index of the last trained epoch is used instead.
        best_epoch = NUM_EPOCHS - 1
        MODEL_PATH_LAST_ = MODEL_PATH + _run_tag(loss_name, lr_value, scheduler_gamma,
                                                 reg_L2_value, best_epoch,
                                                 use_attention_gates) + ".pth"
        # Save the whole module rather than its state_dict: the evaluation and
        # map-estimation cells call .to()/.eval() directly on whatever
        # torch.load returns for every "unet_*.pth" file in PATH.
        torch.save(mdl, MODEL_PATH_LAST_)

### Error metrics

In [None]:

# Saved models: every ".pth" file in PATH whose name starts with the
# file-name part of MODEL_PATH.
FINAL_MODELS = [
    PATH + file_name
    for file_name in os.listdir(PATH)
    if file_name.endswith(".pth") and file_name.startswith(MODEL_PATH[len(PATH):])
]

# One CSV path per model: 'unet_' -> 'csv_' and '.pth' -> '.csv'.
CSV_PATH = [
    model_file.replace('unet_', 'csv_').replace('.pth', '.csv')
    for model_file in FINAL_MODELS
]

def _with_mean_without_land(values):
    """Return ``values`` with one summary entry appended.

    The summary is the mean of every class except index 0 (land, which is not
    considered in the ML models). If any entry equals -1 the summary is -1 as
    well (presumably -1 marks an undefined metric; confirm in error_metrics).
    """
    summary = -1 if -1 in values else np.mean(values[1:])
    return np.concatenate((values, [summary]))


# Evaluate every saved model on the test loader and write one CSV per model.
for indx_model, model in enumerate(FINAL_MODELS):

    print("[",indx_model+1,"of",len(FINAL_MODELS),"]")

    # Model
    # NOTE: weights_only=False unpickles arbitrary Python objects, so only
    # load checkpoint files produced by this notebook. map_location lets a
    # checkpoint saved on the GPU be opened on a CPU-only machine.
    mdl = torch.load(model, map_location=DEVICE, weights_only=False)
    mdl.to(DEVICE)
    print(model)

    # Metrics initialization: one row per metric, in the order returned by
    # error_metrics (accuracy, precision, recall, F1, IoU, Dice).
    metric_totals = np.zeros((6, NUM_CLASSES))

    # Test error metrics
    with torch.no_grad():
        mdl.eval()
        testSteps = 0
        for x, y in testLoader:
            x, y = x.to(DEVICE), y.to(DEVICE)
            logits = mdl(x)
            pred = torch.argmax(logits, dim=1)
            # The one-hot width must be the size of the dimension the argmax
            # was taken over (the class dimension of the prediction). The
            # original used x.shape[1], a dimension of the *input* tensor.
            pred = F.one_hot(pred, num_classes=logits.shape[1])
            pred = pred.permute(0, 3, 1, 2)
            y = y.cpu().detach().numpy()
            pred = pred.cpu().detach().numpy()

            # Rows of metric_totals are views, so += accumulates in place.
            for total, value in zip(metric_totals, error_metrics(y, pred)):
                total += value

            testSteps += 1

    # Average over the evaluated batches
    metric_totals /= testSteps

    # Values and mean without land class (not considered in ML models)
    accuracy, precision, recall, f1, iou, dice = (
        _with_mean_without_land(row) for row in metric_totals
    )
    print("Metrics (Acc, Prec, Recall, F1, IoU, Dice):\n",
          accuracy,'\n',
          precision,'\n',
          recall,'\n',
          f1,'\n',
          iou,'\n',
          dice,'\n')

    # Save metrics in CSV file
    metrics_dict = {
        'Accuracy': accuracy,
        'Precision': precision,
        'Recall': recall,
        'F1-score': f1,
        'IoU': iou,
        'Dice Coefficient': dice
    }
    columns = ['Land', 'Red algae', 'C. nodosa', 'Rock', 'Sand', 'Mean (without land)']
    df_metrics = pd.DataFrame(metrics_dict, index=columns).transpose()
    df_metrics.to_csv(CSV_PATH[indx_model], sep=';', index=True)


### Map estimation

In [None]:

# Saved models: every ".pth" file in PATH whose name starts with the
# file-name part of MODEL_PATH.
FINAL_MODELS = [
    PATH + file_name
    for file_name in os.listdir(PATH)
    if file_name.endswith(".pth") and file_name.startswith(MODEL_PATH[len(PATH):])
]

# One output path per model: 'unet_' -> 'plot_' and '.pth' -> '.mat'.
SAVE_PATH = [
    model_file.replace('unet_', 'plot_').replace('.pth', '.mat')
    for model_file in FINAL_MODELS
]

# Load the complete image from the .mat file. Keys containing "__" are
# skipped (loadmat adds metadata entries such as __header__); the first
# remaining variable is taken as the image.
mat_data_image = scipy.io.loadmat(IMG_PATH)
filtered_keys_image = [key for key in mat_data_image if "__" not in key]
if not filtered_keys_image:
    # Fail with a clear message instead of an IndexError on [0].
    raise ValueError(f"No data variable found in {IMG_PATH}")
image = mat_data_image[filtered_keys_image[0]]

# Dataset statistics, then conversion of the image to a tensor using them.
mean, std = mean_std_dataset(
    IMAGE_DATASET_PATH,
    INPUT_WIDTH,
    INPUT_HEIGHT,
    NUM_CLASSES,
    BATCH_SIZE,
    PIN_MEMORY,
)
image = numpy_to_torch(image, mean, std)

# Estimation: predict the full map with every saved model, show it and save it.
# tqdm wraps the list (not the enumerate object) so the bar knows the total.
for indx_model, path_model in enumerate(tqdm(FINAL_MODELS)):

    print("[",indx_model+1,"of",len(FINAL_MODELS),"]")

    # Model
    # NOTE: weights_only=False unpickles arbitrary Python objects, so only
    # load checkpoint files produced by this notebook. map_location lets a
    # checkpoint saved on the GPU be opened on a CPU-only machine.
    mdl = torch.load(path_model, map_location=DEVICE, weights_only=False)
    mdl.to(DEVICE)
    # Force inference mode. Whether sliding_window does this itself is not
    # visible from this notebook; calling eval() again is harmless.
    mdl.eval()

    # Slide window
    pred_image = sliding_window(image, window_size, stride, mdl, DEVICE)

    # Mode without zeros values
    pred_image_mode = obtain_mode_no_zeros(pred_image)

    # Presentation: each predicted class value is mapped to its color in mymap
    plt.figure(figsize=(15, 15))
    plt.title(SAVE_PATH[indx_model])
    plt.imshow(mymap[pred_image_mode.astype(int)])
    plt.show()

    # Save
    scipy.io.savemat(SAVE_PATH[indx_model], {'seabedmap_graciosa_unet':pred_image_mode})