# Sherpa hyperparameter tuning

## General setup

In [None]:
import sys
from utils.setup import SetupSherpa

# Arguments passed on the command line, without the program name.
argv = sys.argv[1:]
# For interactive runs, point at a config file explicitly instead:
# argv = ["-c", "nn_config/cfg_sherpa.yml"]

# The setup object's attributes (nn_type, epochs, nn_sherpa_path, ...) are
# read by the cells below.
setup = SetupSherpa(argv)

In [None]:
from neural_networks.models import generate_input_list
from neural_networks.models import generate_output_list
# Input and output variable lists derived from the setup.
input_list = generate_input_list(setup)
setup.output_list = output_list = generate_output_list(setup)

# Keep the full lists: the tuning loop overwrites setup.spcam_outputs and
# setup.children_idx_levs with per-output subsets filtered from these.
spcam_outputs = setup.spcam_outputs
children_idx_levs = setup.children_idx_levs

# Sherpa

## Setup
See the Sherpa getting-started guide: https://parameter-sherpa.readthedocs.io/en/latest/gettingstarted/guide.html

In [None]:
import sherpa

In [None]:
# Total number of trials the random search may hand out across all outputs.
max_num_trials = 50

# Evenly split the budget among outputs. Floor division keeps the value an
# int; the lower bound of 1 matters when there are more outputs than trials:
# a budget of 0 would never match the 1-based trial counter in the tuning
# loop, so the first output would consume every trial.
trials_per_output = max(1, max_num_trials // len(output_list))

alg = sherpa.algorithms.RandomSearch(max_num_trials=max_num_trials)

# Hyperparameters of the causal-discovery step; only added to the search
# space when setup.nn_type is not 'SingleNN'.
causal_par = [
    sherpa.Ordinal('pc_alphas', [0.001, 0.01, 0.1]),
    sherpa.Ordinal('thresholds', [.15, .2, .25]),
]

# Hyperparameters shared by every network type. The tuning loop reads them
# from each trial by name (init_lr, divide_lr, num_layers, num_nodes).
common_par = [
    sherpa.Continuous(name='init_lr', range=[0.001, 0.1], scale='log'),
    sherpa.Continuous(name='divide_lr', range=[1, 2]),
    sherpa.Discrete(name='num_layers', range=[1, 10]),
    sherpa.Ordinal('num_nodes', [32, 64, 128, 256]),
]

## Sherpa study 

In [None]:
# A plain SingleNN has no causal-discovery step, so it only searches the
# common hyperparameters; every other network type also searches the causal
# ones.
if setup.nn_type == 'SingleNN':
    parameters = common_par
else:
    parameters = causal_par + common_par

# The objective reported per trial is minimised; no dashboard is started.
study = sherpa.Study(
    parameters=parameters,
    algorithm=alg,
    lower_is_better=True,
    dashboard_port=None,
)

In [None]:
from neural_networks.cbrain.utils import load_pickle
from neural_networks.data_generator import build_train_generator,build_valid_generator
from neural_networks.cbrain.data_generator import DataGenerator
from neural_networks.models import generate_model_sherpa
from neural_networks.models import get_parents_sherpa

In [None]:
# Tune each output in turn. All outputs draw trials from the same study, so
# every output stops after `trials_per_output` trials to leave budget for the
# remaining ones.
for j, output in enumerate(output_list):

    print(f"{output}, j: {j}")

    # Narrow the setup to the current output, filtering from the full lists
    # captured before the loop (setup's own attributes are overwritten here).
    setup.output        = output
    setup.spcam_outputs = [iVar for iVar in spcam_outputs if iVar.value in output]
    print(f"{setup.spcam_outputs}")
    setup.children_idx_levs = [[iLev, iId] for iLev, iId in children_idx_levs
                               if str(iLev)[:3] in output]
    print(f"{setup.children_idx_levs}")

    # Defaults for networks without causal discovery: a SingleNN uses the
    # full input list, any other type passes False for all three arguments.
    inputs    = input_list if setup.nn_type == 'SingleNN' else False
    pc_alpha  = False
    threshold = False

    n_trial = 1
    for trial in study:
        print(f"Trial num. = {n_trial}")

        # Hyperparameters
        if setup.nn_type == 'CausalSingleNN':
            setup.pc_alpha    = [trial.parameters['pc_alphas']]
            setup.thresholds  = [trial.parameters['thresholds']]
        setup.init_lr    = trial.parameters['init_lr']
        setup.divide_lr  = trial.parameters['divide_lr']
        setup.num_layers = trial.parameters['num_layers']
        setup.num_nodes  = trial.parameters['num_nodes']
        setup.n_trial    = n_trial

        # Causal links: recompute the parents for every trial. Each trial
        # samples its own pc_alphas/thresholds, so reusing the parents of the
        # first trial would train a model that does not match the parameters
        # the study records for this trial.
        if setup.nn_type == 'CausalSingleNN':
            inputs, pc_alpha, threshold = get_parents_sherpa(setup)

        # Create the model
        model = generate_model_sherpa(setup,
                                      parents=inputs,
                                      pc_alpha=pc_alpha,
                                      threshold=threshold)
        print(f"Training {model}")

        input_vars_dict  = model.input_vars_dict
        output_vars_dict = model.output_vars_dict

        with build_train_generator(
            input_vars_dict, output_vars_dict, setup
        ) as train_gen, build_valid_generator(
            input_vars_dict, output_vars_dict, setup
        ) as valid_gen:

            # Train model: one fit call per iteration so that a validation
            # result can be reported to the study after each one.
            for i in range(setup.epochs):
                model.model.fit(train_gen)
                # NOTE(review): unpacking assumes evaluate() returns exactly
                # (loss, one metric) - confirm against the model's metrics.
                loss, metric = model.model.evaluate(valid_gen)
                study.add_observation(trial=trial, iteration=i,
                                      objective=metric,
                                      context={'loss': loss})

            study.finalize(trial=trial)
            # `>=` rather than `==`: n_trial starts at 1, so a budget below 1
            # would otherwise never match and this output would use up the
            # whole study.
            if n_trial >= trials_per_output:
                break
            n_trial += 1

        print()
    print()

## Saving

In [None]:
from datetime import datetime
from pathlib  import Path

In [None]:
# Results go to <nn_sherpa_path>/<YYYYMMDD>_<nn_type>/, created if missing.
date = datetime.today().strftime('%Y%m%d')
path = Path(f"{setup.nn_sherpa_path}/{date}_{setup.nn_type}/")
path.mkdir(parents=True, exist_ok=True)
study.save(output_dir=path)

## Load the study

study.load_dashboard(Path(path))

study.get_best_result()