# Training U-Net

Training a U-Net for denoising LP STEM images.

In [None]:
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

from sklearn.model_selection import train_test_split

from tensorflow.keras.layers import Input
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau

### Enter the location of your aquaDenoising folder containing the general_functions folder
sys.path.append("path/to/aquaDenoising/")

from general_functions.neural_networks.norm_patch import robustnorm, randompatch
from general_functions.neural_networks.architectures import model_unet
from general_functions.neural_networks.metrics import ssimsmape_loss, ssim, psnr

## Load Simulated Dataset 

In [None]:
# Side length of the square patches cut out of the full images.
PATCH_SIZE = 128
# Patch count handed to randompatch for every full image.
NB_PATCH = 512

In [None]:
# %%time
# Load the noiseless and the noisy simulated image stacks and give each of
# them an extra axis at position 3 (assumes the stored arrays are 3-D, so the
# new axis is a trailing one -- TODO confirm).
allshape_clean = np.load("simulated/noiseless/images/dataset")[:, :, :, np.newaxis]
allshape_SNR = np.load("simulated/noisy/images/dataset")[:, :, :, np.newaxis]

# Notebook display of the two resulting shapes.
allshape_clean.shape, allshape_SNR.shape

In [None]:
# fig, ax = plt.subplots(9, 5, figsize=[20, 36])
# for i in range(9):
#     for ii in range(5):
#         ax[i,ii].imshow(allshape_clean[ii*9+i])
#         ax[i,ii].set_title(f"{ii*9+i}")

In [None]:
### Train - Validation Split
# %%time

# The last `n_val_images` entries of each stack are held out for validation;
# everything before them is used for training.
n_val_images = 9

validation_full_img_in = allshape_SNR[-n_val_images:]
validation_full_img_out = allshape_clean[-n_val_images:]

training_full_img_in = allshape_SNR[:-n_val_images]
training_full_img_out = allshape_clean[:-n_val_images]

# Notebook display of the resulting input shapes.
training_full_img_in.shape, validation_full_img_in.shape

In [None]:
# %%time
train_seed = 813

# Draw NB_PATCH patches from every training image. The same seed is passed for
# the noisy and the clean image (presumably so that randompatch picks the same
# patch locations in both -- TODO confirm against randompatch).
noisy_patches = []
clean_patches = []
for noisy_img, clean_img in zip(training_full_img_in, training_full_img_out):
    noisy_patches.append(
        randompatch(noisy_img, NB_PATCH, patch_size=PATCH_SIZE, seed=train_seed, write_xy=False)
    )
    clean_patches.append(
        randompatch(clean_img, NB_PATCH, patch_size=PATCH_SIZE, seed=train_seed, write_xy=False)
    )

# Concatenate once instead of calling np.append inside the loop: np.append
# returns a fresh copy of the whole accumulated array on every call, so the
# copying cost grew with the square of the number of images.
# The leading empty float64 array keeps the result dtype (and the shape of an
# empty dataset) identical to the previous np.append-based construction.
empty_stack = np.zeros((0, PATCH_SIZE, PATCH_SIZE))
train_input = np.concatenate([empty_stack, *noisy_patches], axis=0)
train_output = np.concatenate([empty_stack, *clean_patches], axis=0)

# Normalise inputs and targets with different robustnorm settings
# (0.01 for the noisy inputs, 0 for the clean targets).
train_input = robustnorm(train_input, 0.01)
train_output = robustnorm(train_output, 0)

# Add the axis at position 3 so the arrays match the model's (..., 1) input.
train_input = np.expand_dims(train_input, axis=3)
train_output = np.expand_dims(train_output, axis=3)

train_input.shape, train_output.shape

In [None]:
# %%time
val_seed = 525

# Draw NB_PATCH patches from every validation image. The same seed is passed
# for the noisy and the clean image (presumably so that randompatch picks the
# same patch locations in both -- TODO confirm against randompatch).
noisy_patches = []
clean_patches = []
for noisy_img, clean_img in zip(validation_full_img_in, validation_full_img_out):
    noisy_patches.append(
        randompatch(noisy_img, NB_PATCH, patch_size=PATCH_SIZE, seed=val_seed, write_xy=False)
    )
    clean_patches.append(
        randompatch(clean_img, NB_PATCH, patch_size=PATCH_SIZE, seed=val_seed, write_xy=False)
    )

# Concatenate once instead of calling np.append inside the loop: np.append
# returns a fresh copy of the whole accumulated array on every call.
# The leading empty float64 array keeps the result dtype (and the shape of an
# empty dataset) identical to the previous np.append-based construction.
empty_stack = np.zeros((0, PATCH_SIZE, PATCH_SIZE))
val_input = np.concatenate([empty_stack, *noisy_patches], axis=0)
val_output = np.concatenate([empty_stack, *clean_patches], axis=0)

# Normalise inputs and targets with different robustnorm settings
# (0.01 for the noisy inputs, 0 for the clean targets).
val_input = robustnorm(val_input, 0.01)
val_output = robustnorm(val_output, 0)

# Add the axis at position 3 so the arrays match the model's (..., 1) input.
val_input = np.expand_dims(val_input, axis=3)
val_output = np.expand_dims(val_output, axis=3)

val_input.shape, val_output.shape

In [None]:
# Visual sanity check of the training set: each column shows one patch, with
# the network input (train_input) on the top row and the matching target
# (train_output) on the bottom row.
fig, ax = plt.subplots(2,5,figsize=(24,12))

# Hand-picked indices into train_input / train_output, one per column.
sample_patches = [2804, 351, 1461, 1593, 872]

for col, patch_img in enumerate(sample_patches):
    input_patch = train_input[patch_img,:,:,0]
    ax[0,col].imshow(input_patch, cmap="gray")
    ax[1,col].imshow(train_output[patch_img,:,:,0], cmap="gray")
    # The title reports the value range of the input patch after robustnorm.
    ax[0,col].set_title(f"min:{np.min(input_patch)} max:{np.max(input_patch)}")

In [None]:
### Training Parameters

# Identifier of this run; it selects the Training<NB> output folder.
TRAINING_NB = "01"

NUM_EPOCHS = 300
BATCH_SIZE = 4

### - - - - - - - - - -

# The input shape follows PATCH_SIZE so the network always matches the
# patches built from the dataset (previously a hard-coded 128).
input_img = Input((PATCH_SIZE, PATCH_SIZE, 1), name='img')
save_dir = f"DATA/UNet/saved_models/Training{TRAINING_NB}/"
# Make sure the output folder exists before any checkpoint is written into it.
os.makedirs(save_dir, exist_ok=True)

model = model_unet(input_img, 1, n_filters=8, layers_repetition=2, dropout=0.05)
# The key is presumably the name of the model's output layer -- confirm
# against model_unet.
loss = {"output_denoised":ssimsmape_loss}
model.compile(optimizer=Adam(),loss=loss, metrics= ["accuracy", ssim, psnr, 'mean_absolute_percentage_error'])

### Callbacks

# Keeps only the model with the lowest validation loss seen so far.
checkpoint = ModelCheckpoint(save_dir+"model_vloss_min.h5",
                             monitor='val_loss',
                             verbose=1,
                             save_best_only=True,
                             mode='min')

# Stops training once the validation loss has not improved for 15 epochs.
early_stopping_monitor = EarlyStopping(monitor = 'val_loss',
                                       patience=15,
                                       verbose=1,
                                       mode='min')
                                       # restore_best_weights = True

# Divides the learning rate by 10 after 5 epochs without improvement of the
# callback's default monitored quantity, never going below 1e-5.
Reduce_LR = ReduceLROnPlateau(factor=0.1,
                              patience=5,
                              min_lr=0.00001,
                              verbose=1)

# Saves the latest model unconditionally (save_best_only is off).
Model_checkpoint = ModelCheckpoint(filepath = save_dir+"model.h5",
                                   verbose=1,
                                   save_best_only = False)

callbacks_list = [checkpoint, early_stopping_monitor, Reduce_LR, Model_checkpoint]

In [None]:
# Train on the in-memory patch arrays and validate on the held-out patches
# after each epoch.
# `use_multiprocessing` is deliberately not passed: Keras documents it as
# applying to generator / Sequence input only, so it had no effect on NumPy
# arrays, and Keras 3 `Model.fit` no longer accepts the argument.
history = model.fit(
    train_input,
    train_output,
    batch_size=BATCH_SIZE,
    epochs=NUM_EPOCHS,
    callbacks=callbacks_list,
    validation_data=(val_input, val_output),
)

In [None]:
# Plot training and validation loss per epoch and save the figure next to the
# saved models of this run.
fig, ax = plt.subplots(figsize=(8, 6))

# The first recorded epoch is left out of both curves.
epochs_shown = history.epoch[1:]
for history_key, curve_label in (("loss", "Training"), ("val_loss", "Validation")):
    ax.plot(epochs_shown, history.history[history_key][1:], ".-", label=curve_label)

ax.set_title("Loss function", fontsize=25)
ax.set_xlabel("Epochs", fontsize=20)
ax.set_ylabel("Loss : ssimSMAPE", fontsize=20)
# The x-axis is fixed to the range 0-75, whatever the number of epochs run.
ax.set_xlim(0, 75)
ax.legend(fontsize=15)

fig.savefig(f"DATA/UNet/saved_models/Training{TRAINING_NB}/Training_Curve.png", transparent=True)

In [None]:
# Restore the checkpoint with the lowest validation loss before scoring.
model.load_weights(f"DATA/UNet/saved_models/Training{TRAINING_NB}/model_vloss_min.h5")

# Let Keras pair each score with its metric name (return_dict=True) instead of
# zipping model.metrics_names with the returned list: zip silently truncates
# when the two have different lengths, which happens on Keras 3 where
# metrics_names collapses all compiled metrics into a single entry.
metrics_dict = model.evaluate(val_input, val_output, return_dict=True)
# Kept for backward compatibility with code that reads the plain score list.
metrics = list(metrics_dict.values())

print("\n")
print(metrics_dict)
print("\n")

In [None]:
### Add entries to the dictionary to match the characteristics of the training

metrics_dict["training"] = f"Training{TRAINING_NB}"
# metrics_dict["dataset_size"] = training_size

### - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -


print("\n")
print(metrics_dict)
print("\n")

### Choose the file location where all the results are saved

allscores_file = "DATA/UNet/saved_models/compare_scores.csv"

### - - - - - - - - - - - - - - - - - - - - - - - - - - - - 

# One-row table holding the scores of this training run.
metrics_df = pd.DataFrame(metrics_dict, index=[0])

# Only the read is expected to fail (no score file yet on the first run), so
# it is the only statement kept inside the try body.
try:
    previous_scores_df = pd.read_csv(allscores_file)
except FileNotFoundError:
    # First run: the file will contain just this training's row.
    allscores_df = metrics_df
else:
    # Append this run's row to the scores already on disk.
    allscores_df = pd.concat([previous_scores_df, metrics_df], ignore_index=True)

allscores_df.to_csv(allscores_file, index=False)