## In this notebook, we train the Super-Resolution module.
- The training progress can be viewed in the TensorBoard section of W&B.

In [None]:




import datetime
import os
from csbdeep.io import load_training_data
from csbdeep.utils import axes_dict, plot_some,plot_history
import matplotlib.pyplot as plt
from model_DFCAN import DFCAN
from loss_functions import mse_ssim, mse_ssim_psnr
import tensorflow as tf
from pathlib import Path
from tensorflow.keras.optimizers import Adam
from tensorflow.keras import callbacks
from tensorflow.keras.models import load_model
from wandb.integration.keras import WandbMetricsLogger, WandbEvalCallback
from tensorflow.keras.callbacks import TensorBoard
from skimage.metrics import peak_signal_noise_ratio as psnr, mean_squared_error as mse, structural_similarity as ssim
import numpy as np
import wandb
    
# Print the TensorFlow version so the run records which release was used.
print(tf.__version__)


In [None]:
# Set up the data/model locations, the output folder and experiment logging.
wandb.login()
root_dir = '../F-actin'
# Directory that holds the saved super-resolution model.
model_dir = Path(root_dir) / 'SRModel'
model_dir.mkdir(exist_ok=True)
train_data_file = f'{root_dir}/Train/SR/augmented_F-actin_02_SR_big.npz'
# One time-stamped TensorBoard log folder per run.
log_dir = "logs/fitSR/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")

############## for saving the results of the training ################
output_dir = Path.cwd() / 'SR_Model_plots_and_results'
output_dir.mkdir(exist_ok=True)

# Start the W&B run; sync_tensorboard=True is passed so that the training
# progress written for TensorBoard can be viewed from W&B.
wandb.init(
    project="F-actin_SR",
    name="SR_train_MSE_SSIM_ep_2400_b_64",
    config=tf.compat.v1.flags.FLAGS,
    sync_tensorboard=True,
)

# TensorBoard callback that writes its event files into the W&B run directory.
tensorboard_callback = TensorBoard(log_dir=wandb.run.dir)

# NOTE(review): a bare `tf.function(jit_compile=True)` call used to sit here.
# Called without a function it only builds a decorator that was then thrown
# away, so it never enabled XLA. It has been removed; if XLA compilation is
# wanted, enable it explicitly (e.g. `tf.config.optimizer.set_jit(True)`).

### Load the training data and define the model and parameters

In [None]:
# Train the DFCAN super-resolution model on a single GPU: load the patches,
# build/compile (or resume) the model, fit it, and save the loss curves.
strategy = tf.distribute.MirroredStrategy(['GPU:0'])

with strategy.scope():
    ################  define the train  parameters  ################

    init_lr = 1e-4
    batch_size = 64
    epochs = 2400
    beta_1 = 0.9
    beta_2 = 0.999
    scale_gt = 2.0
    lr_decay_factor = 0.75  # Learning rate decay factor

    (X, Y), (X_val, Y_val), axes = load_training_data(train_data_file, validation_split=0.1, verbose=True)
    print()
    print()
    print('Information about SR training data')
    print(f"X_shape :  {X.shape} ,\nX_dtype : {X.dtype}   Y_shape: {Y.shape}\nY_dtype : {Y.dtype}   ,\nX_val : {X_val.shape} ,\nY_val : {Y_val.shape}")
    print()

    ############### preprocess the data to fit into the model ################
    def preprocess_data(X, Y):
        """Drop the trailing singleton axis and move axis 1 to the end.

        Both inputs are squeezed on their last axis and then permuted with
        ``[0, 2, 3, 1]``, so they must be 5-D with a last axis of size 1.
        Presumably this turns (N, C, H, W, 1) into channels-last
        (N, H, W, C) -- TODO confirm against the contents of the .npz file.

        Returns the transformed ``(X, Y)`` pair as TensorFlow tensors.
        """
        X = tf.squeeze(X, axis=-1)
        Y = tf.squeeze(Y, axis=-1)
        X = tf.transpose(X, perm=[0, 2, 3, 1])
        Y = tf.transpose(Y, perm=[0, 2, 3, 1])
        return X, Y

    X, Y = preprocess_data(X, Y)
    X_val, Y_val = preprocess_data(X_val, Y_val)

    train_dataset = (X, Y)
    val_dataset = (X_val, Y_val)

    print(f'after preprocessing teh dta in batch chunks. X : {X.shape} Y : {Y.shape} X_val : {X_val.shape}  Y_val : {Y_val.shape}')

    ################  plot some of the training data  ################
    plt.figure(figsize=(12, 6))
    # Show five patches so the figure matches its title (six were plotted
    # before). The source patches are permuted back with [0, 3, 1, 2] for
    # plotting only.
    plot_some(tf.transpose(X_val[:5], perm=[0, 3, 1, 2]), Y_val[:5])
    plt.suptitle('5 example validation patches (top row: source, bottom row: target)')
    plt.savefig(f'{output_dir}/SR_train_image_F-Actin_02_2400.png', bbox_inches='tight')

    total_data, height, width, channels = X.shape
    print(f'total_data,  height, width, channels : {total_data,  height, width, channels}')
    valid_data = val_dataset

    ################  build and compile the model  ################
    Trainingmodel = DFCAN((height, width, channels), scale=scale_gt)
    optimizer = Adam(learning_rate=init_lr, beta_1=beta_1, beta_2=beta_2)
    Trainingmodel.compile(loss=mse_ssim, optimizer=optimizer)

    # Rebinds the name set in the setup cell: this callback logs to log_dir.
    tensorboard_callback = callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)

    # Multiply the learning rate by lr_decay_factor when val_loss has not
    # improved by min_delta for 15 epochs, never going below init_lr * 0.1.
    lrate = callbacks.ReduceLROnPlateau(
        monitor='val_loss',
        factor=lr_decay_factor,
        patience=15,
        mode='auto',
        min_delta=1e-4,
        cooldown=0,
        min_lr=init_lr * 0.1,
        verbose=1,
    )

    hrate = callbacks.History()

    # Save the full model (not only weights) to model_dir on improvement.
    # NOTE(review): this monitors the training "loss", not "val_loss" --
    # confirm that is the intended criterion for "best".
    srate = callbacks.ModelCheckpoint(
        str(model_dir),
        monitor="loss",
        save_best_only=True,
        save_weights_only=False,
        mode="auto",
    )

    ################  load the model if it exists  ################
    # A non-empty model_dir means an earlier run saved a model: resume from
    # it. mse_ssim is registered so the custom loss can be deserialized.
    if os.listdir(model_dir):
        with tf.keras.utils.custom_object_scope({'mse_ssim': mse_ssim}):
            print(f'Loading model from {model_dir}')
            Trainingmodel = load_model(model_dir)

    history = Trainingmodel.fit(
        X,
        Y,
        batch_size=batch_size,
        epochs=epochs,
        validation_data=val_dataset,
        shuffle=True,
        callbacks=[lrate, hrate, srate, tensorboard_callback],
    )

    # NOTE(review): this writes the final-epoch model to the same path the
    # checkpoint callback uses, replacing the "best" checkpoint -- confirm.
    Trainingmodel.save(model_dir)

    print(f'hisitry :: {history}')
    print(sorted(list(history.history.keys())))
    plt.figure(figsize=(16, 5))
    plot_history(history, ['loss', 'val_loss'])
    # Save right after plotting: the new empty figure that used to be opened
    # here became the current figure, so the saved PNG was blank.
    plt.savefig(f'{output_dir}/SR_train_image_F-Actin_02_2400_history.png', bbox_inches='tight')


### Prediction from the model

In [None]:

# Predict on the first five validation patches and save a figure comparing
# source, ground truth and prediction.
n_shown = 5
val_inputs = X_val[:n_shown]
val_targets = Y_val[:n_shown]

_P = Trainingmodel.predict(val_inputs)

# Only the source patches are permuted (last axis moved to position 1)
# before being handed to plot_some.
sources_for_plot = tf.transpose(val_inputs, perm=[0, 3, 1, 2])
plot_some(sources_for_plot, val_targets, _P, pmax=99.5)

figure_title = (
    '5 example validation patches\n'
    'top row: input (source),  '
    'middle row: target (ground truth),  '
    'bottom row: predicted from source'
)
plt.suptitle(figure_title)

predictions_png = f'{output_dir}/SR_train_image_F-Actin_02_2400_predictions.png'
plt.savefig(predictions_png, bbox_inches='tight')