# SHERPA (GridSearch) hyperparameter tuning

## General setup

In [None]:
import sys, shutil
from utils.setup import SetupSherpa
from neural_networks.models import generate_input_list
from neural_networks.models import generate_output_list

argv = sys.argv[1:]
# argv = ["-c", "nn_config/# argv = ["-c", "nn_config/SHERPA_threshold_GridSearch/cfg_SHERPA_GridSearch_flnt.yml"]
config_file = argv[-1]

setup = SetupSherpa(argv)

input_list        = generate_input_list(setup)
output_list       = generate_output_list(setup)
setup.output_list = output_list
spcam_outputs     = setup.spcam_outputs
children_idx_levs = setup.children_idx_levs

## SHERPA

In [None]:
import sherpa
from datetime import datetime
from pathlib  import Path
import tensorflow as tf
from tensorflow.keras.callbacks import LearningRateScheduler, EarlyStopping, ModelCheckpoint
from neural_networks.cbrain.learning_rate_schedule import LRUpdate
from neural_networks.cbrain.utils import load_pickle
from neural_networks.data_generator import build_train_generator,build_valid_generator
from neural_networks.cbrain.data_generator import DataGenerator
from neural_networks.cbrain.save_weights import save_norm
from neural_networks.models import generate_model_sherpa
from neural_networks.models import get_parents_sherpa

## Parameters & ranges
https://parameter-sherpa.readthedocs.io/en/latest/gettingstarted/guide.html

In [None]:
alg        = sherpa.algorithms.GridSearch()
hp_space = {
    'num_layers':setup.sherpa_num_layers,
     'num_nodes':setup.sherpa_num_nodes,
    'thresholds':setup.thresholds,
}
parameters = sherpa.Parameter.grid(hp_space)

## Study 

In [None]:
study = sherpa.Study(parameters=parameters,
                     algorithm=alg,
                     lower_is_better=True,
                     dashboard_port=None)

In [None]:
date = datetime.today().strftime('%Y%m%d')

for j, output in enumerate(output_list):        

    main_path = Path(
        "{nn_sherpa_path}/{date}_{output}/".format(
            nn_sherpa_path=setup.nn_sherpa_path,date=date,output=output
        )
    )
    Path(main_path).mkdir(parents=True, exist_ok=True)
    shutil.copyfile(config_file, Path(main_path,config_file.split('/')[-1]))
    
    print(f"Output: {output}, id: {j}")
    setup.output        = output
    setup.spcam_outputs = [iVar for iVar in spcam_outputs if iVar.value in output]
    setup.children_idx_levs = [[iLev,iId] for iLev,iId in children_idx_levs \
                                   if str(iLev)[:3] in output]
    inputs    = [False,input_list][setup.nn_type == 'SingleNN']
    pc_alpha  = False
    threshold = False
    
    for trial in study:
        
        # Hyperparameters
        if setup.nn_type == 'CausalSingleNN':
            setup.pc_alpha    = [setup.sherpa_pc_alphas]
#             setup.threshold   = [setup.sherpa_thresholds]
        setup.thresholds = [trial.parameters['thresholds']]
        setup.num_layers = trial.parameters['num_layers']
        setup.num_nodes  = trial.parameters['num_nodes']
        setup.n_trial    = trial.id

        # Causal links?
        if setup.nn_type == 'CausalSingleNN':
            inputs, pc_alpha, threshold = get_parents_sherpa(setup)
            print(f"\n CausalNN with pc_alpha: {pc_alpha}; threshold: {threshold}")
        print(inputs)

#       Create the model
        model = generate_model_sherpa(setup,
                                      parents=inputs,
                                      pc_alpha=pc_alpha,
                                      threshold=threshold)
            
        input_vars_dict  = model.input_vars_dict
        output_vars_dict = model.output_vars_dict
        
        path = Path(str(main_path)+"/{id_trial}/".format(id_trial=trial.id))
        path = model.get_path(path)
        Path(path).mkdir(parents=True, exist_ok=True)
        
        print(f"\nTrial ({trial.id}) summary: thr-{threshold}, {setup.num_layers} layers & {setup.num_nodes} nodes")
        
        with build_train_generator(
            input_vars_dict, output_vars_dict, setup
        ) as train_gen, build_valid_generator(
            input_vars_dict, output_vars_dict, setup
        ) as valid_gen:

            lrs = LearningRateScheduler(
                LRUpdate(init_lr=setup.init_lr, step=setup.step_lr, divide=setup.divide_lr)
            )
            early_stop = EarlyStopping(monitor="loss", patience=setup.train_patience)
            checkpoint = ModelCheckpoint(
                str(path),
                save_best_only=True, 
                monitor='loss',
                mode='min'
            )
            callbacks  = [lrs, early_stop, checkpoint]
            
            # Train model
            init_epochs = 0; epochs = setup.epochs
            for i in range(init_epochs, epochs):
                print(f"initial_epoch: {i}, epochs: {i+1}")
                model.model.fit(train_gen, callbacks=callbacks, initial_epoch=i, epochs=i+1)
                loss, metric = model.model.evaluate(valid_gen)
                print("Validation mse: ", metric)
                study.add_observation(trial=trial, iteration=i,
                                      objective=metric,
                                      context={'loss': loss})
            
            study.finalize(trial=trial)
            # Save trial
            # Save trial, weights & input list
            filename = model.get_filename()
            print(f"Saving model at: ", Path(path, f"{filename}_model.h5"))
            model.model.save(Path(path, f"{filename}_model.h5"))
            model.model.save_weights(str(Path(path, f"{filename}_weights.h5")))
            model.save_input_list(path, filename)
            if trial.id == 1:
                save_norm(
                    input_transform=train_gen.input_transform,
                    output_transform=train_gen.output_transform,
                    save_dir=str(main_path),
                    filename=filename,
                )

            # Save study
            study.save(output_dir=Path(main_path))
                
        print()
    print()

# # Save study
# Path(main_path).mkdir(parents=True, exist_ok=True)
# study.save(output_dir=Path(main_path))

## Load the study

study.load_dashboard(Path(path))

study.get_best_result()