### Train the Denoising Module
- In this notebook, we train the denoising (DN) module on F-actin data.
- The F-actin image patches are first passed through the Super-Resolution (SR) module.
- From the SR output, we extract the primary features (PFE branch).
- From the input image, we calculate the Moiré pattern features (MPE branch).
- The PFE and MPE branches are concatenated together for the GT branch.


In [None]:
import datetime
from csbdeep.io import load_training_data
from csbdeep.utils import axes_dict, plot_some, plot_history
import matplotlib.pyplot as plt
from models import Denoiser, Train_RDL_Denoising
from loss_functions import mse_ssim, mse_ssim_psnr 
import tensorflow as tf
import os
from pathlib import Path
from tensorflow.keras.optimizers import Adam
from tensorflow.keras import callbacks
from tensorflow.keras.models import load_model
import numpy as np
from tensorflow.keras.callbacks import TensorBoard
from wandb.integration.keras import WandbMetricsLogger, WandbEvalCallback
import numpy as np

######  login WANDB #########
import wandb



In [None]:
# Log in to Weights & Biases and start the run; sync_tensorboard=True asks wandb to
# pick up the TensorBoard logs written during training.
wandb.login()
# NOTE(review): the run name encodes batch size 64 ("b_64"), but batch_size is set to 32
# in the training cell — confirm which value is intended and keep the name in sync.
wandb.init(project="F-actin_DN", name="DN_train_MSE_SSIM_ep_1000_b_64", config=tf.compat.v1.flags.FLAGS, sync_tensorboard=True)
tensorboard_callback = TensorBoard(log_dir=wandb.run.dir)

# XLA: calling `tf.function(jit_compile=True)` without a function only builds a decorator,
# so on its own it enables nothing. To use XLA, decorate a function with it, or enable it
# globally with `tf.config.optimizer.set_jit(True)`.
#
# Optional debugging switches:
# tf.config.run_functions_eagerly(True)
# tf.debugging.enable_check_numerics()

gpus = tf.config.list_physical_devices('GPU')
print(f'These are the GPUs available and will be used :: \n {gpus}')

In [None]:
root_dir = '../F-actin'
# Denoiser model directory; created (with any missing parents) if it does not exist.
den_model_dir = Path(root_dir) / 'DNModel'
# Directory of the pre-trained SR model. It is an input, so it is deliberately NOT
# created here: a freshly created, empty directory can never hold a usable model.
sr_model_dir = Path(root_dir) / 'SRModel_700_ready'
den_model_dir.mkdir(parents=True, exist_ok=True)
train_data_file = f'{root_dir}/Train/DN/augmented_F-actin_02_DN.npz'
# Timestamped log directory, one per run.
log_dir = "logs/fitDN/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
# Figures and results are saved into this folder.
output_dir = Path.cwd() / 'DN_Model_plots_and_results'
output_dir.mkdir(parents=True, exist_ok=True)


### Load the training data, define the parameters, and train the model

In [None]:
# for multi GOu training
strategy = tf.distribute.MirroredStrategy(devices=["/gpu:0"])

with strategy.scope(): 

    ############### load the data ################
    (X,Y), (X_val,Y_val), axes = load_training_data(train_data_file, validation_split=0.1, verbose=True)
    print('information about DN model Training data')
    print(f'X.shape : {X.shape} ,\n Y.shape : {Y.shape} ,\n X_val.shape : {X_val.shape} ,\n Y_val.shape : {Y_val.shape} ,\naxes : {axes} ')

    c = axes_dict(axes)['C']
    n_channel_in, n_channel_out = X.shape[c], Y.shape[c]
    print(f'n_channel_in : {n_channel_in} , n_channel_out : {n_channel_out} ')
    
    ############# reshape the data into the layout the model consumes ################
    def preprocess_data(X, Y):
        """Drop the trailing size-1 axis of X and Y, then move axis 1 to the end.

        The inputs must be rank 5 with a size-1 last axis: after the squeeze, the
        4-entry permutation [0, 2, 3, 1] is applied. Presumably this maps
        (N, C, H, W, 1) -> (N, H, W, C); TODO confirm the axis meaning.

        Returns the transformed (X, Y) pair.
        """
        squeezed = [tf.squeeze(t, axis=-1) for t in (X, Y)]
        X_out, Y_out = (tf.transpose(t, perm=[0, 2, 3, 1]) for t in squeezed)
        return X_out, Y_out


    # Apply the same reshaping to the training and validation splits.
    X, Y = preprocess_data(X, Y)
    X_val, Y_val = preprocess_data(X_val, Y_val)
    print('after preprocessing the data \n')
    print(f'X.shape : {X.shape} , Y.shape : {Y.shape} , X_val.shape : {X_val.shape} , Y_val.shape : {Y_val.shape} , axes : {axes} ')

    ################# plot some validation patches ################
    # perm [0, 3, 1, 2] is the inverse of preprocess_data's [0, 2, 3, 1], so the preview
    # tensors are shown in their pre-preprocessing axis order.
    plt.figure(figsize=(12,5))
    preview_source, preview_target = [tf.transpose(t[:5], perm=[0, 3, 1, 2]) for t in (X_val, Y_val)]
    plot_some(preview_source, preview_target)
    plt.suptitle('5 example validation patches (top row: source, bottom row: target)')
    plt.savefig(f'{output_dir}/DN_train_image_F-actin_02.png', bbox_inches='tight')

    ################### Define the Parameters ############################
    # Optimiser / schedule settings.
    init_lr = 1e-4
    lr_decay_factor = .75  # learning-rate decay factor

    batch_size = 32
    epochs = 1000
    beta_1 = 0.9
    beta_2 = 0.999

    # Optical / SIM settings, forwarded to the training wrapper through `parameters`.
    wavelength = 0.488
    excNA = 1.35
    dx = 62.6e-3
    dy = dx
    dxy = dx
    scale_gt = 2.0
    setupNUM = 0
    # Note the extra division by 2 relative to wavelength/excNA.
    space = wavelength/excNA/2
    k0mod = 1 / space
    napodize = 10
    nphases = 3
    ndirs = 3
    sigma_x = 0.5
    sigma_y = 0.5
    recalcarrays = 2
    ifshowmodamp = 0
    otf_path = 'TIRF488_cam1_0_z30_OTF2d.mrc'  # the OTF from the RDL-SIM package
    norders = int((nphases + 1) / 2)

    # (k0angle_c, k0angle_g) for each supported setupNUM. An unsupported value fails
    # here with a clear message instead of a NameError further down.
    k0angles_by_setup = {
        0: ([1.48, 2.5272, 3.5744], [0.0908, -0.9564, -2.0036]),
        1: ([-1.66, -0.6128, 0.4344], [3.2269, 2.1797, 1.1325]),
        2: ([1.5708, 2.618, 3.6652], [0, -1.0472, -2.0944]),
    }
    if setupNUM not in k0angles_by_setup:
        raise ValueError(f'Unsupported setupNUM={setupNUM}; expected one of {sorted(k0angles_by_setup)}')
    k0angle_c, k0angle_g = k0angles_by_setup[setupNUM]

    total_data,  height, width, channels = X.shape
    print(f'\n\n total_data,  height, width, channels : {total_data,  height, width, channels} \n\n')

    ########### define parameter dictionary ################
    # 'scale' and 'scale_gt' both carry scale_gt; both keys are kept for whichever
    # consumer reads them.
    parameters = {
        'Ny': height,
        'Nx': width,
        'lr_decay_factor': lr_decay_factor,
        'wavelength': wavelength,
        'excNA': excNA,
        'ndirs': ndirs,
        'nphases': nphases,
        'init_lr': init_lr,
        'ifshowmodamp': ifshowmodamp,
        'batch_size': batch_size,
        'epochs': epochs,
        'beta_1': beta_1,
        'beta_2': beta_2,
        'scale_gt': scale_gt,
        'setupNUM': setupNUM,
        'k0angle_c': k0angle_c,
        'k0angle_g': k0angle_g,
        'recalcarrays': recalcarrays,
        'dxy': dxy,
        'space': space,
        'k0mod': k0mod,
        'norders': norders,
        'napodize': napodize,
        'scale': scale_gt,
        'sigma_x': sigma_x,
        'sigma_y': sigma_y,
        'log_dir': log_dir,
        'den_model_dir': den_model_dir,
        'sr_model_dir': sr_model_dir,
        'otf_path': otf_path,
        'results_path': output_dir,
    }
    ########### load the pre-trained SR (DFCAN) model; it must exist before denoiser training ############
    # Explicit raise rather than `assert`: the check always runs (even under `python -O`)
    # and names the missing path. Without it, `Trainingmodel_dfcan` would be undefined
    # when the denoiser wrapper is built.
    sr_model_path = Path(sr_model_dir)
    if not sr_model_path.is_dir() or not any(sr_model_path.iterdir()):
        raise FileNotFoundError(
            f'DFCAN model has to be trained before training RDL denoiser: no model found in {sr_model_path}')

    # Register the custom loss under its name while loading (presumably the SR model
    # was compiled with mse_ssim — confirm).
    with tf.keras.utils.custom_object_scope({'mse_ssim': mse_ssim}):
        print(f'Loading model from {sr_model_dir}')
        Trainingmodel_dfcan = load_model(sr_model_dir)


    ############### define the DN model and compile the model ################
    # NOTE(review): the denoiser input uses `nphases` as its channel count, while the
    # training arrays carry `channels` channels — confirm these agree for this dataset.
    Trainingmodel_denoise = Denoiser((height, width, nphases))
    optimizer = Adam(learning_rate=init_lr, beta_1=beta_1, beta_2=beta_2)
    Trainingmodel_denoise.compile(loss=mse_ssim, optimizer=optimizer)

    # These callbacks are created but not passed to `fit` below; wire them in if
    # Train_RDL_Denoising.fit accepts callbacks.
    tensorboard_callback = callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)
    hrate = callbacks.History()

    ## Wrap the SR and denoising models and train the RDL denoiser
    rdl_denoising = Train_RDL_Denoising(
                        srmodel=Trainingmodel_dfcan,
                        denmodel=Trainingmodel_denoise,
                        loss_fn=mse_ssim,
                        optimizer=optimizer,
                        parameters=parameters)

    # Training / validation pairs built from the preprocessed arrays.
    # NOTE(review): the (input, target) tuple format is assumed — confirm against
    # Train_RDL_Denoising.fit.
    data = (X, Y)
    data_val = (X_val, Y_val)
    rdl_denoising.fit(data=data, data_val=data_val)




### Prediction from the model

In [None]:

# Run the trained RDL denoiser on the first five validation patches and save a
# source / target / prediction comparison figure.
n_preview = 5
plt.figure(figsize=(12,7))

_P = rdl_denoising.predict(X_val[:n_preview])
print(f'P.shape : {_P.shape} ')

# perm [0, 3, 1, 2] undoes the [0, 2, 3, 1] permutation applied by preprocess_data.
panels = [tf.transpose(t, perm=[0, 3, 1, 2]) for t in (X_val[:n_preview], Y_val[:n_preview], _P)]
plot_some(*panels, pmax=99.5)
plt.suptitle('5 example validation patches\n'
             'top row: input (source),  '
             'middle row: target (ground truth),  '
             'bottom row: predicted from source')
plt.savefig(f'{output_dir}/DN_train_image_prediction.png', bbox_inches='tight')