# ConvGRU Optuna Study

In [3]:
import os
import optuna
from tsai.optuna import *
import papermill as pm
from tsai.optuna import run_optuna_study
from fastcore.basics import *
from optuna.distributions import *
from optuna.samplers import TPESampler
import wandb

In [13]:
# Settings for the whole study. The outer container is an AttrDict so the
# cells below can read entries as attributes (e.g. config.n_trials).
config = AttrDict({
    'study_name': 'general_study',
    'study_type': 'bayesian',
    'n_trials': 10,
    # Path to the notebook to be executed (resolved against the current working directory)
    'train_nb': f'{os.getcwd()}/nbs/convgru.ipynb',
    # Parameters sampled by Optuna, keyed by parameter name.
    # Keep this a normal dict, not an AttrDict.
    'search_space': {
        "convgru.attn": CategoricalDistribution([True, False]),
        "convgru.blur": CategoricalDistribution([True, False]),
        "convgru.coords_conv": CategoricalDistribution([True, False]),
        "convgru.norm": CategoricalDistribution(['batch', None]),
        "convgru.strategy": CategoricalDistribution(['zero', 'encoder']),
    },
    # Parameters that stay fixed across trials and are not part of the
    # search space. Keep this a normal dict as well.
    'extra_params': {
        "n_epoch": 10,
        "bs": 64,
        "lookback": 4,
        "horizon": 4,
        "mmap": False,
        "normalize": True,
        "sel_steps": None,
        "stride": 1,
        "wandb.enabled": True,
        "wandb.log_learner": False,
        "wandb.mode": 'offline',
        "wandb.group": 'general_study_runs',
    },
    # Settings of the W&B run that tracks the study itself
    'wandb_optuna': AttrDict({
        'enabled': False,
        'mode': 'offline',
    }),
})

In [5]:
def create_objective(train_nb, search_space, extra_params=None):
    """
        Create objective function to be minimized by Optuna.
        Inputs:
            train_nb: path to the notebook to be used for training
            search_space: dictionary with the parameters to be optimized
            extra_params: dictionary with the extra parameters to be passed to the training notebook
        Output:
            valid_loss: validation loss
    """
    def objective(trial:optuna.Trial):
        # Define the parameters to be passed to the training notebook through papermill
        pm_parameters = {}
        for k,v in search_space.items():
            pm_parameters['config.' + k] = trial._suggest(k, v)

        # Add the extra parameters to the dictionary. The key of every parameter 
        # must be 'config.<param_name>'
        if extra_params is not None:
            for k,v in extra_params.items():
                pm_parameters['config.' + k] = v

        # Call the training notebook using papermill (don't print the output)
        stdout_file = open('tmp/pm_stdout.txt', 'w')
        stderr_file = open('tmp/pm_stderr.txt', 'w')

        pm.execute_notebook(
            train_nb,
            './tmp/pm_output.ipynb',
            parameters = pm_parameters,
            stdout_file = stdout_file,
            stderr_file = stderr_file
        )

        # Close the output files
        stdout_file.close()
        stderr_file.close()

        # Get the output value of interest from the source notebook
        %store -r valid_loss
        return valid_loss

    return objective

In [None]:
# Build the per-trial objective from the configured notebook and parameters
obj = create_objective(
    train_nb=config.train_nb,
    search_space=config.search_space,
    extra_params=config.extra_params,
)
# Run the study; the value returned by the objective is minimized
study = run_optuna_study(
    obj,
    study_type=config.study_type,
    direction='minimize',
    path='./tmp',
    study_name=config.study_name,
    n_trials=config.n_trials,
)

In [None]:
# Open a W&B run for the study only when tracking is enabled in the config;
# otherwise leave `run` as None so the cells below skip the logging
if config.wandb_optuna.enabled:
    run = wandb.init(
        config=config,
        mode=config.wandb_optuna.mode,
        job_type='optuna-study',
    )
else:
    run = None

In [None]:
if run is not None:
    # Best parameters together with the best value and the trial that reached it
    run.log({
        **study.best_params,
        'best_value': study.best_value,
        'best_trial_number': study.best_trial.number,
    })
    # Upload the study file found under ./tmp as an artifact
    run.log_artifact(f'./tmp/{config.study_name}.pkl', type='optuna_study')
    # One Optuna figure per entry, each produced by optuna.visualization.plot_<name>
    run.log({
        name: getattr(optuna.visualization, f'plot_{name}')(study)
        for name in (
            'contour',
            'edf',
            'intermediate_values',
            'optimization_history',
            'parallel_coordinate',
            'param_importances',
            'slice',
        )
    })

In [None]:
if run is not None:
    # Close the W&B run opened for the study (run is None when W&B tracking is disabled)
    run.finish()