# SHERPA hyperparameter tuning with Asynchronous Successive Halving (a.k.a. ASHA)

Blog on ASHA: https://blog.ml.cmu.edu/2018/12/12/massively-parallel-hyperparameter-optimization/ 

## General setup

In [None]:
import sys, os, glob, shutil
from utils.setup import SetupSherpa
from neural_networks.models import generate_input_list
from neural_networks.models import generate_output_list

# Arguments passed on the command line (script name dropped).
# When running interactively, override them by hand, e.g.:
#   argv = ["-c", "nn_config/cfg_sherpa.yml"]
#   argv = ["-c", "nn_config/220322_SHERPA_ASHA/cfg_SHERPA_ASHA.yml"]
argv = sys.argv[1:]

# The configuration file path is taken to be the last argument.
config_file = argv[-1]

setup = SetupSherpa(argv)

input_list = generate_input_list(setup)
output_list = generate_output_list(setup)
setup.output_list = output_list

# Keep references to the complete lists: the tuning loop further down
# overwrites ``setup.spcam_outputs`` and ``setup.children_idx_levs`` with
# per-output subsets filtered from these.
spcam_outputs = setup.spcam_outputs
children_idx_levs = setup.children_idx_levs

## SHERPA

In [None]:
import sherpa
from datetime import datetime
from pathlib  import Path
import tensorflow as tf
from tensorflow.keras.callbacks import LearningRateScheduler, EarlyStopping, ModelCheckpoint
from tensorflow.keras.models import load_model
from neural_networks.cbrain.learning_rate_schedule import LRUpdate
from neural_networks.cbrain.utils import load_pickle
from neural_networks.data_generator import build_train_generator,build_valid_generator
from neural_networks.cbrain.data_generator import DataGenerator
from neural_networks.cbrain.save_weights import save_norm
from neural_networks.models import generate_model_sherpa
from neural_networks.models import get_parents_sherpa

## Parameters & ranges
SHERPA getting-started guide: https://parameter-sherpa.readthedocs.io/en/latest/gettingstarted/guide.html

In [None]:
# Successive-halving search algorithm. ``R`` comes from the configuration;
# the remaining settings are fixed here.
# (A previous variant used eta=3 with otherwise identical arguments.)
alg = sherpa.algorithms.SuccessiveHalving(
    r=1,
    R=setup.sherpa_num_trials,
    eta=5,
    s=0,
    max_finished_configs=1,
)

# Hyperparameters that only apply to the causal networks.
causal_par = [
    sherpa.Ordinal(par_name, par_values)
    for par_name, par_values in (
        ('pc_alphas', setup.sherpa_pc_alphas),
        ('thresholds', setup.sherpa_thresholds),
    )
]

# Hyperparameters shared by every network type; both ranges are read from the
# configuration. Earlier experiments also tuned 'init_lr' (log-scaled
# Continuous in [0.001, 0.1]) and 'divide_lr' (Continuous in [1, 2]), and used
# fixed ranges num_layers in [1, 10], num_nodes in [32, 64, 128, 256, 512].
common_par = [
    sherpa.Discrete(name='num_layers', range=setup.sherpa_num_layers),
    sherpa.Ordinal('num_nodes', setup.sherpa_num_nodes),
]

## Study 

In [None]:
# 'SingleNN' is tuned over the common hyperparameters only; every other
# network type additionally gets the causal ones.
if setup.nn_type == 'SingleNN':
    parameters = common_par
else:
    parameters = causal_par + common_par

# The objective reported to the study is minimised; no dashboard is started.
study = sherpa.Study(
    parameters=parameters,
    algorithm=alg,
    lower_is_better=True,
    dashboard_port=None,
)

In [None]:
# Main tuning loop: for every output variable, iterate over the trials
# suggested by the study, build (or reload) the corresponding model, train it
# for the number of epochs granted by the trial, report the validation metric
# after each epoch, and save model, normalisation and study state to disk.
date = datetime.today().strftime('%Y%m%d')

# Number of epochs already trained for the model stored under each ``save_to``
# id. A promoted trial resumes the model found under its ``load_from`` id, so
# its first epoch is the last epoch of the trial that saved there. For
# r=1, eta=5 this reproduces the former table {1: 0, 5: 1, 25: 6} (and
# {1: 0, 3: 1, 9: 4} for eta=3) without being tied to one particular schedule.
epochs_trained = {}

# NOTE(review): ``study`` is created once, outside this loop, so all outputs
# iterate over the same study object -- confirm whether a fresh study per
# output is intended.
for j, output in enumerate(output_list):

    # Results directory for this output: <nn_sherpa_path>/<date>_<output>/
    main_path = Path(f"{setup.nn_sherpa_path}/{date}_{output}/")
    main_path.mkdir(parents=True, exist_ok=True)
    # Keep a copy of the configuration next to the results.
    shutil.copyfile(config_file, main_path / Path(config_file).name)

    print(f"Output: {output}, id: {j}")
    # Restrict the setup to the current output variable.
    setup.output = output
    setup.spcam_outputs = [iVar for iVar in spcam_outputs if iVar.value in output]
    setup.children_idx_levs = [
        [iLev, iId] for iLev, iId in children_idx_levs if str(iLev)[:3] in output
    ]
    # 'SingleNN' uses the full input list; otherwise the inputs are fetched
    # via get_parents_sherpa() inside the trial loop.
    inputs = input_list if setup.nn_type == 'SingleNN' else False
    pc_alpha = False
    threshold = False

    for trial in study:

        resource = trial.parameters['resource']
        load_from = trial.parameters['load_from']
        save_to = trial.parameters['save_to']

        # Epoch range for this trial: continue where the loaded model stopped
        # (or start from 0 for a new model) and train ``resource`` more epochs.
        initial_epoch = 0 if load_from == "" else epochs_trained[str(load_from)]
        epochs = resource + initial_epoch

        print("-"*100)
        print(f"Trial:\t{trial.id}\nEpochs:\t{initial_epoch} to {epochs}\nParameters: {trial.parameters}\n")

        modelph_save = main_path / str(save_to)
        modelph_save.mkdir(parents=True, exist_ok=True)

        # Hyperparameters
        if setup.nn_type == 'CausalSingleNN':
            setup.pc_alpha    = [trial.parameters['pc_alphas']]
            setup.thresholds  = [trial.parameters['thresholds']]
            print(f"pc_alpha: {setup.pc_alpha}; threshold: {setup.thresholds}")
        # 'init_lr' and 'divide_lr' are currently not tuned; the values from
        # the configuration are used by the learning-rate scheduler below.
        setup.num_layers = trial.parameters['num_layers']
        setup.num_nodes  = trial.parameters['num_nodes']
        setup.n_trial    = trial.id

        # Causal links?
        # NOTE(review): the parents are fetched only while ``inputs`` is still
        # False, i.e. once per output, although pc_alpha/thresholds change per
        # trial -- confirm that later trials are meant to reuse them.
        if setup.nn_type == 'CausalSingleNN' and inputs is False:
            inputs, pc_alpha, threshold = get_parents_sherpa(setup)

        if load_from == "":
            print(f"Creating new model for trial {trial.id}...\n")
            model = generate_model_sherpa(setup,
                                          parents=inputs,
                                          pc_alpha=pc_alpha,
                                          threshold=threshold)
            input_vars_dict  = model.input_vars_dict
            output_vars_dict = model.output_vars_dict
            filename = model.get_filename()
            if trial.id == 1:
                model.save_input_list(main_path, filename)
            # From here on only the underlying model object is needed.
            model = model.model

        else:
            # Loading model
            # NOTE(review): ``filename``, ``input_vars_dict`` and
            # ``output_vars_dict`` keep the values of the most recently
            # *created* model -- verify they also match the model loaded here.
            modelph_load = main_path / str(load_from)
            modelnm_load = Path(modelph_load, f"{filename}_model.h5")
            print("Loading model from: ", str(modelnm_load), "...\n")
            model = load_model(modelnm_load)

        # Train model
        with build_train_generator(
            input_vars_dict, output_vars_dict, setup
        ) as train_gen, build_valid_generator(
            input_vars_dict, output_vars_dict, setup
        ) as valid_gen:

            lrs = LearningRateScheduler(
                LRUpdate(init_lr=setup.init_lr, step=setup.step_lr, divide=setup.divide_lr)
            )
            # NOTE(review): fit() is called for one epoch at a time below;
            # check whether the patience-based early stopping can ever trigger
            # in that setting.
            early_stop = EarlyStopping(monitor="loss", patience=setup.train_patience)
            checkpoint = ModelCheckpoint(
                str(modelph_save),
                save_best_only=True,
                monitor='loss',
                mode='min'
            )
            callbacks  = [lrs, early_stop, checkpoint]

            # Train one epoch at a time so that every epoch's validation
            # result can be reported to the study.
            for i in range(initial_epoch, epochs):
                model.fit(train_gen, callbacks=callbacks, initial_epoch=i, epochs=i+1)
                loss, metric = model.evaluate(valid_gen)
                print("Validation mse: ", metric)
                study.add_observation(trial=trial, iteration=i,
                                      objective=metric,
                                      context={'loss': loss})

            study.finalize(trial=trial)
            # Remember how far the model saved under this id has been trained.
            epochs_trained[str(save_to)] = epochs

            # Save trial & input list
            modelnm_save = Path(modelph_save, f"{filename}_model.h5")
            print("Saving model at: ", str(modelnm_save))
            model.save(modelnm_save)
            if trial.id == 1:
                save_norm(
                    input_transform=train_gen.input_transform,
                    output_transform=train_gen.output_transform,
                    save_dir=str(main_path),
                    filename=filename,
                )

            study.save(output_dir=main_path)

        print()
    print()

## Load the study

# Load the study from the directory the tuning loop last saved it to.
# (The former name ``path`` is not defined anywhere in this notebook; the
# loop writes the study to ``main_path``.)
study.load_dashboard(main_path)

study.get_best_result()