# Train and reproduce models

## Imports

In [None]:
from solver import *  # import ADI solver
from generator import *  # import training / validation data generator
from DiffusionNet import dfn # import DiffusionNet model

import numpy as np # general array operations
import ray # parallel processing for data generation
from tqdm.notebook import tqdm # display progress bar 

import h5py # for saving in HDF5 format
from tensorflow.keras.models import Model  # machine learning library
from tensorflow.keras.optimizers import * # machine learning library
from tensorflow.keras.layers import * # machine learning library
import tensorflow

from ipywidgets import *  # import widgets
import warnings
# Silence every warning category for the rest of the kernel session
# (e.g. library deprecation notices cluttering training output).
# NOTE(review): this also hides warnings that could point at real problems;
# consider narrowing the filter to specific categories/modules.
warnings.filterwarnings("ignore")

In [None]:
# Start the Ray runtime used for parallel data generation (see the ray import).
# ignore_reinit_error=True lets this notebook cell be re-executed in the same
# kernel without Ray raising because it has already been initialised.
ray.init(ignore_reinit_error=True)

## Train

In [None]:
# Choices offered by the interactive training form below.
g = [12, 24, 48, 96, 192]  # grid sizes N (an N x N grid)
s = [10, 100]  # values forwarded as S to generate_training_validation
b = [100, 1000, 10_000]  # values forwarded as B to generate_training_validation


def _build_callbacks(name):
    """Return the Keras callbacks for the training run identified by *name*.

    The output directories are created up front so the CSV log
    (``ReproducedLogs/{name}.log``) and the per-epoch checkpoint
    (``ReproducedModels/{name}.h5``) are writable on a fresh checkout.
    The returned list order is: LR scheduler, early stopping, CSV logger,
    checkpoint.
    """
    import os

    os.makedirs('ReproducedLogs', exist_ok=True)
    os.makedirs('ReproducedModels', exist_ok=True)

    # NOTE(review): all callbacks watch the *training* loss even though
    # validation data is supplied; kept as-is to reproduce the original runs.
    csv_logger = tensorflow.keras.callbacks.CSVLogger(f'ReproducedLogs/{name}.log')
    early_stopping = tensorflow.keras.callbacks.EarlyStopping(
        monitor='loss',
        min_delta=5e-5,
        patience=5,
        verbose=1,
        mode='auto',
        baseline=None,
        restore_best_weights=False,
    )
    reduce_lr_callback = tensorflow.keras.callbacks.ReduceLROnPlateau(
        monitor='loss',
        factor=0.5,
        patience=3,
        verbose=1,
        cooldown=1,
        min_delta=1e-4,
        min_lr=1e-8,
    )
    # save_best_only=False: the checkpoint file is overwritten after every epoch.
    model_checkpoint_callback = tensorflow.keras.callbacks.ModelCheckpoint(
        f'ReproducedModels/{name}.h5',
        monitor='loss',
        verbose=1,
        save_best_only=False,
        save_weights_only=False,
        mode='auto',
        save_freq='epoch',
    )
    return [reduce_lr_callback, early_stopping, csv_logger, model_checkpoint_callback]


@interact(grid_size=g, step=s, batches=b, train=False)
def grid_step_choice(grid_size, step, batches, train=False):
    """Generate data for the chosen configuration and train a DiffusionNet model.

    Rendered as an ipywidgets form (dropdowns for ``grid_size``, ``step`` and
    ``batches`` plus a ``train`` checkbox). ``interact`` re-invokes this
    function on every widget change, so nothing runs until ``train`` is
    ticked; changing a dropdown while it is ticked starts a new run.

    Parameters
    ----------
    grid_size : int
        Grid size N, forwarded as ``N``; also used in the run name ``NxN``.
    step : int
        Forwarded as ``S`` to ``generate_training_validation``.
    batches : int
        Forwarded as ``B`` to ``generate_training_validation``.
    train : bool
        When False the function returns immediately.
    """
    name = f'step={step} {grid_size}x{grid_size}'

    if not train:
        return

    print('Generating training and validation data..')
    train_data_batches, train_data_batches_bias, validation_data_batches = (
        generate_training_validation(N=grid_size, S=step, B=batches)
    )
    # Merge once so the generator and steps_per_epoch see the same batches.
    # NOTE(review): on duplicate keys the bias batches overwrite regular ones;
    # confirm the two key sets are disjoint.
    training_batches = {**train_data_batches, **train_data_batches_bias}

    print('Generating the model..')
    model = dfn()

    print('\nTraining..')
    callbacks = _build_callbacks(name)

    model.fit(
        data_generator(training_batches),
        validation_data=data_generator(validation_data_batches),
        steps_per_epoch=len(training_batches),
        validation_steps=len(validation_data_batches),
        verbose=1,
        epochs=100,
        callbacks=callbacks,
    )