# Tutorial 8: Hyperparameter Optimization

To automatically tune the hyperparameters of a `synthcity` plugin so that it generates more realistic data, we can use hyperparameter optimization (HPO) algorithms such as Tree-structured Parzen Estimators (TPE), Bayesian optimization, and genetic programming. In this tutorial we use `optuna`, a widely used HPO library whose default sampler is TPE, to tune the hyperparameters of the `tvae` plugin on the diabetes dataset.
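To build intuition for TPE, here is a toy one-dimensional version. It ranks past trials, splits them into a "good" fraction `gamma` and the rest, fits a Gaussian kernel density to each group (l(x) for good, g(x) for bad), and evaluates next the candidate that maximizes l(x)/g(x). This is a simplified sketch for illustration only, not Optuna's implementation; the function names, the fixed bandwidth, and the quadratic test objective are assumptions made for the example.

```python
import math
import random


def gaussian_kde(points, bandwidth):
    """Return a 1-D Gaussian kernel density estimate over `points`."""
    def density(x):
        norm = 1.0 / (len(points) * bandwidth * math.sqrt(2 * math.pi))
        return norm * sum(math.exp(-0.5 * ((x - p) / bandwidth) ** 2) for p in points)
    return density


def tpe_minimize(objective, low, high, n_init=10, n_iter=40, gamma=0.25,
                 n_candidates=24, bandwidth=None, seed=0):
    """Toy 1-D Tree-structured Parzen Estimator that minimizes `objective`."""
    rng = random.Random(seed)
    bandwidth = bandwidth or (high - low) / 10
    history = []  # list of (x, loss) pairs

    # Warm-up: a few purely random evaluations.
    for _ in range(n_init):
        x = rng.uniform(low, high)
        history.append((x, objective(x)))

    for _ in range(n_iter):
        # Split past trials into "good" (best gamma fraction) and "bad".
        ranked = sorted(history, key=lambda h: h[1])
        n_good = max(1, int(gamma * len(ranked)))
        good = [x for x, _ in ranked[:n_good]]
        bad = [x for x, _ in ranked[n_good:]]
        l, g = gaussian_kde(good, bandwidth), gaussian_kde(bad, bandwidth)

        # Draw candidates around the good points (clamped to the bounds) and
        # evaluate the one with the highest density ratio l(x) / g(x).
        candidates = [
            min(high, max(low, rng.gauss(rng.choice(good), bandwidth)))
            for _ in range(n_candidates)
        ]
        x = max(candidates, key=lambda c: l(c) / (g(c) + 1e-12))
        history.append((x, objective(x)))

    return min(history, key=lambda h: h[1])


best_x, best_loss = tpe_minimize(lambda x: (x - 2.0) ** 2, low=-10.0, high=10.0)
print(f"best x = {best_x:.3f}, loss = {best_loss:.4f}")
```

Real TPE samplers handle many parameters, categorical choices, and adaptive bandwidths; the ratio l(x)/g(x) is the part that carries over.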

This tutorial also requires the third-party library `plotly` for the visualizations. It is not a synthcity dependency, because this tutorial is the only place it is needed, so install it alongside synthcity with `pip install plotly`.

In [None]:
!pip install synthcity
!pip install plotly

In [1]:
%load_ext autoreload
%autoreload 2

# stdlib
import sys
import warnings

# third party
import optuna
from sklearn.datasets import load_diabetes

# synthcity absolute
import synthcity.logger as log
from synthcity.plugins import Plugins
from synthcity.plugins.core.dataloader import GenericDataLoader

log.add(sink=sys.stderr, level="INFO")
warnings.filterwarnings("ignore")


## Load the dataset

In [2]:
X, y = load_diabetes(return_X_y=True, as_frame=True)
X["target"] = y
X

| | age | sex | bmi | bp | s1 | s2 | s3 | s4 | s5 | s6 | target |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0.038076 | 0.050680 | 0.061696 | 0.021872 | -0.044223 | -0.034821 | -0.043401 | -0.002592 | 0.019907 | -0.017646 | 151.0 |
| 1 | -0.001882 | -0.044642 | -0.051474 | -0.026328 | -0.008449 | -0.019163 | 0.074412 | -0.039493 | -0.068332 | -0.092204 | 75.0 |
| 2 | 0.085299 | 0.050680 | 0.044451 | -0.005670 | -0.045599 | -0.034194 | -0.032356 | -0.002592 | 0.002861 | -0.025930 | 141.0 |
| 3 | -0.089063 | -0.044642 | -0.011595 | -0.036656 | 0.012191 | 0.024991 | -0.036038 | 0.034309 | 0.022688 | -0.009362 | 206.0 |
| 4 | 0.005383 | -0.044642 | -0.036385 | 0.021872 | 0.003935 | 0.015596 | 0.008142 | -0.002592 | -0.031988 | -0.046641 | 135.0 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 437 | 0.041708 | 0.050680 | 0.019662 | 0.059744 | -0.005697 | -0.002566 | -0.028674 | -0.002592 | 0.031193 | 0.007207 | 178.0 |
| 438 | -0.005515 | 0.050680 | -0.015906 | -0.067642 | 0.049341 | 0.079165 | -0.028674 | 0.034309 | -0.018114 | 0.044485 | 104.0 |
| 439 | 0.041708 | 0.050680 | -0.015906 | 0.017293 | -0.037344 | -0.013840 | -0.024993 | -0.011080 | -0.046883 | 0.015491 | 132.0 |
| 440 | -0.045472 | -0.044642 | 0.039062 | 0.001215 | 0.016318 | 0.015283 | -0.028674 | 0.026560 | 0.044529 | -0.025930 | 220.0 |


In [3]:
loader = GenericDataLoader(
    X,
    target_column="target",
    sensitive_columns=["sex"],
)
train_loader, test_loader = loader.train(), loader.test()

## Load the plugin class

In [4]:
PLUGIN = "tvae"
plugin_cls = type(Plugins().get(PLUGIN))
plugin_cls



synthcity.plugins.generic.plugin_tvae.TVAEPlugin

## Display the hyperparameter space

In [5]:
plugin_cls.hyperparameter_space()

[IntegerDistribution(name='n_iter', data=None, random_state=0, marginal_distribution=None, low=100, high=500, step=100),
 CategoricalDistribution(name='lr', data=None, random_state=0, marginal_distribution=None, choices=[0.0001, 0.0002, 0.001]),
 IntegerDistribution(name='decoder_n_layers_hidden', data=None, random_state=0, marginal_distribution=None, low=1, high=5, step=1),
 CategoricalDistribution(name='weight_decay', data=None, random_state=0, marginal_distribution=None, choices=[0.0001, 0.001]),
 CategoricalDistribution(name='batch_size', data=None, random_state=0, marginal_distribution=None, choices=[64, 128, 256, 512]),
 IntegerDistribution(name='n_units_embedding', data=None, random_state=0, marginal_distribution=None, low=50, high=500, step=50),
 IntegerDistribution(name='decoder_n_units_hidden', data=None, random_state=0, marginal_distribution=None, low=50, high=500, step=50),
 CategoricalDistribution(name='decoder_nonlin', data=None, random_state=0, marginal_distribution=None
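Even the visible part of this space is too large to search exhaustively. Assuming inclusive bounds (as Optuna's `suggest_int` uses), an `IntegerDistribution` with `low`, `high`, and `step` covers `(high - low) // step + 1` values. The short calculation below multiplies the value counts of the seven discrete hyperparameters listed above; it ignores the truncated remainder of the space, which also contains continuous parameters such as the dropout rates.

```python
import math

# Discrete distributions copied from the (truncated) output above:
# integer ranges as (low, high, step) and categorical choices as lists.
integer_params = {
    "n_iter": (100, 500, 100),
    "decoder_n_layers_hidden": (1, 5, 1),
    "n_units_embedding": (50, 500, 50),
    "decoder_n_units_hidden": (50, 500, 50),
}
categorical_params = {
    "lr": [0.0001, 0.0002, 0.001],
    "weight_decay": [0.0001, 0.001],
    "batch_size": [64, 128, 256, 512],
}


def n_values(low, high, step):
    """Number of grid points in an inclusive integer range."""
    return (high - low) // step + 1


sizes = {name: n_values(*spec) for name, spec in integer_params.items()}
sizes.update({name: len(choices) for name, choices in categorical_params.items()})
print(sizes)
print("combinations:", math.prod(sizes.values()))  # → combinations: 60000
```

A full grid over just these seven parameters would need 60,000 training runs, which is why a sampler such as TPE that concentrates trials in promising regions is used instead.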

## Use a trial to suggest a set of hyperparameters

In [6]:
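from synthcity.utils.optuna import suggest_all

# Ask a fresh study for a trial and sample one value for every hyperparameter
trial = optuna.create_study().ask()
params = suggest_all(trial, plugin_cls.hyperparameter_space())
params['n_iter'] = 10  # override the number of training iterations for a quick demo
params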

{'n_iter': 10,
 'lr': 0.001,
 'decoder_n_layers_hidden': 1,
 'weight_decay': 0.001,
 'batch_size': 128,
 'n_units_embedding': 400,
 'decoder_n_units_hidden': 250,
 'decoder_nonlin': 'elu',
 'decoder_dropout': 0.11033838286560743,
 'encoder_n_layers_hidden': 4,
 'encoder_n_units_hidden': 300,
 'encoder_nonlin': 'relu',
 'encoder_dropout': 0.04140143497423224}
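Conceptually, `suggest_all` walks the hyperparameter space and asks the trial for one value per distribution, keyed by the distribution's name. The stdlib-only sketch below illustrates that idea. The dataclasses, `RandomTrial`, and `suggest_all_sketch` are hypothetical stand-ins, not synthcity's implementation; `RandomTrial` only mimics the signatures of Optuna's `Trial.suggest_int` and `Trial.suggest_categorical`, and draws values uniformly at random instead of from a study's sampler.

```python
import random
from dataclasses import dataclass
from typing import Any, List, Union


@dataclass
class IntegerDist:  # hypothetical stand-in for synthcity's IntegerDistribution
    name: str
    low: int
    high: int
    step: int = 1


@dataclass
class CategoricalDist:  # hypothetical stand-in for CategoricalDistribution
    name: str
    choices: List[Any]


class RandomTrial:
    """Hypothetical trial exposing Optuna-style suggest_int / suggest_categorical."""

    def __init__(self, seed: int = 0):
        self.rng = random.Random(seed)

    def suggest_int(self, name, low, high, step=1):
        return self.rng.randrange(low, high + 1, step)  # inclusive upper bound

    def suggest_categorical(self, name, choices):
        return self.rng.choice(choices)


def suggest_all_sketch(trial, space: List[Union[IntegerDist, CategoricalDist]]):
    """Map every distribution in `space` to one suggested value, keyed by name."""
    params = {}
    for dist in space:
        if isinstance(dist, IntegerDist):
            params[dist.name] = trial.suggest_int(dist.name, dist.low, dist.high, step=dist.step)
        else:
            params[dist.name] = trial.suggest_categorical(dist.name, dist.choices)
    return params


space = [
    IntegerDist("n_iter", 100, 500, 100),
    CategoricalDist("lr", [0.0001, 0.0002, 0.001]),
    CategoricalDist("batch_size", [64, 128, 256, 512]),
]
params = suggest_all_sketch(RandomTrial(seed=0), space)
params["n_iter"] = 10  # same quick-demo override as in the cell above
print(params)
```

Because the keys match the plugin's constructor arguments, the resulting dictionary can be unpacked directly into the plugin, as the tutorial does with `plugin_cls(**params)`.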

## Evaluate the plugin with the suggested hyperparameters

In [7]:
from synthcity.benchmark import Benchmarks

# Fitting the plugin directly is optional: Benchmarks.evaluate below trains its own copy
plugin = plugin_cls(**params).fit(train_loader)
report = Benchmarks.evaluate(
    [("trial", PLUGIN, params)],
    train_loader,  # Benchmarks.evaluate will split out a validation set
    repeats=1,
    metrics={"detection": ["detection_mlp"]},  # remove this argument to evaluate all metrics
)
report['trial']

100%|██████████| 10/10 [00:01<00:00,  9.66it/s]
100%|██████████| 10/10 [00:00<00:00, 11.75it/s]


| | min | max | mean | stddev | median | iqr | rounds | errors | durations | direction |
|---|---|---|---|---|---|---|---|---|---|---|
| detection.detection_mlp.mean | 0.5 | 0.5 | 0.5 | 0.0 | 0.5 | 0.0 | 1 | 0 | 1.64 | minimize |
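The detection metrics, as synthcity documents them, train a classifier to tell real rows from synthetic rows and report its AUROC. That is why the direction is `minimize`: a score of 0.5 is chance level, meaning the detector cannot separate the two. The sketch below only shows how AUROC relates to a detector's scores, using the textbook pairwise (Mann-Whitney) definition rather than the library's code; the score lists are toy inputs, not synthcity output.

```python
from typing import Sequence


def auroc(real_scores: Sequence[float], synth_scores: Sequence[float]) -> float:
    """AUROC of a detector that gives higher scores to rows it believes are synthetic.

    Equals the probability that a random synthetic row scores higher than a
    random real row, with ties counted as 1/2 (the normalized Mann-Whitney U).
    """
    wins = 0.0
    for s in synth_scores:
        for r in real_scores:
            if s > r:
                wins += 1.0
            elif s == r:
                wins += 0.5
    return wins / (len(real_scores) * len(synth_scores))


# A detector that assigns the same scores to both groups sits at chance level ...
print(auroc([0.3, 0.5, 0.7], [0.3, 0.5, 0.7]))  # → 0.5
# ... while one that separates them perfectly reaches 1.0.
print(auroc([0.1, 0.2, 0.3], [0.8, 0.9, 0.95]))  # → 1.0
```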


## Create an Optuna study and optimize the hyperparameters

In [9]:
from optuna.pruners import SuccessiveHalvingPruner
from synthcity.utils.optuna import OptunaPruning

# Pruning is disabled: Benchmarks.evaluate appears to JSON-serialize the plugin parameters,
# so passing an OptunaPruning callback through it fails with
# "TypeError: Object of type OptunaPruning is not JSON serializable".
enable_pruning = False
validation_size = 0.2
# pruner = None  # default pruner (MedianPruner)
pruner = SuccessiveHalvingPruner()

def objective(trial: optuna.Trial):
    hp_space = Plugins().get(PLUGIN).hyperparameter_space()
    hp_space[0].high = 100  # hp_space[0] is n_iter: limit it to 100 for speed-up

    params = suggest_all(trial, hp_space)
    if enable_pruning:
        params.update(
            valid_size=validation_size,
            callbacks=[OptunaPruning(trial)],
        )

    ID = f"trial_{trial.number}"
    try:
        report = Benchmarks.evaluate(
            [(ID, PLUGIN, params)],
            train_loader,
            repeats=1,
            metrics={"detection": ["detection_mlp"]},  # remove this argument to evaluate all metrics
        )
    except optuna.TrialPruned:
        raise  # let Optuna record pruned trials as pruned
    except Exception as e:  # invalid set of params: prune this trial instead of aborting the study
        print(f"{type(e).__name__}: {e}")
        print(params)
        raise optuna.TrialPruned()

    # average score across all metrics with direction="minimize"
    score = report[ID].query('direction == "minimize"')['mean'].mean()
    return score

study = optuna.create_study(direction="minimize", pruner=pruner)
study.optimize(objective, n_trials=10)
study.best_params
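The `SuccessiveHalvingPruner` configured above is based on successive halving: train many configurations on a small budget, keep the best `1/eta` fraction, and give the survivors `eta` times more budget. The sketch below is the simple synchronous textbook version, written for illustration only. Optuna's pruner is an asynchronous variant that decides per trial from intermediate values reported with `trial.report()`, so it only has an effect when something (presumably the `OptunaPruning` callback) reports them. The `toy_score` loss and the config dictionaries are hypothetical.

```python
import math
from typing import Callable, Dict, List


def successive_halving(
    configs: List[Dict],
    train_and_score: Callable[[Dict, int], float],
    min_budget: int = 1,
    eta: int = 3,
) -> Dict:
    """Synchronous successive halving that minimizes `train_and_score`.

    Each round scores every surviving config at the current budget, keeps the
    best 1/eta fraction (at least one), and multiplies the budget by eta.
    """
    survivors, budget = list(configs), min_budget
    while len(survivors) > 1:
        scores = [train_and_score(cfg, budget) for cfg in survivors]
        # Sort on the score only, so configs with tied scores keep their order.
        ranked = [cfg for _, cfg in sorted(zip(scores, survivors), key=lambda p: p[0])]
        survivors = ranked[: max(1, len(ranked) // eta)]
        budget *= eta
    return survivors[0]


# Nine hypothetical configurations cycling through three learning rates.
configs = [{"id": i, "lr": [1e-4, 2e-4, 1e-3][i % 3]} for i in range(9)]


def toy_score(cfg: Dict, budget: int) -> float:
    # Hypothetical loss: shrinks with more budget and is lowest for lr=1e-3.
    return abs(math.log10(cfg["lr"]) + 3) + 1.0 / budget


best = successive_halving(configs, toy_score)
print(best)
```

With nine configurations and `eta=3`, the first round scores all nine at budget 1 and the second scores the three survivors at budget 3, so only 12 evaluations are spent instead of nine full-budget runs.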

## Visualize the study

In [None]:
from optuna.visualization import plot_contour
from optuna.visualization import plot_edf
from optuna.visualization import plot_optimization_history
from optuna.visualization import plot_parallel_coordinate
from optuna.visualization import plot_param_importances
from optuna.visualization import plot_slice

plot_optimization_history(study)

In [None]:
# Visualize high-dimensional parameter relationships. 
plot_parallel_coordinate(study)

In [None]:
# Visualize hyperparameter relationships.
fig = plot_contour(study, params=['batch_size', 'lr', 'encoder_dropout', 'decoder_dropout'])
fig.update_layout(width=800, height=800)

In [None]:
# Visualize individual hyperparameters as slice plot.
plot_slice(study)

In [None]:
# Visualize parameter importances.
plot_param_importances(study)

In [None]:
# Learn which hyperparameters are affecting the trial duration with hyperparameter importance.
plot_param_importances(
    study, target=lambda t: t.duration.total_seconds(), target_name="duration"
)

In [None]:
# Visualize empirical distribution function of the objective.
plot_edf(study)
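`plot_edf` draws the empirical cumulative distribution function of the objective values of completed trials: for each value x, the fraction of trials that scored at most x. Because the objective here is minimized, a curve that rises early means more trials reached low detection scores. The sketch below computes the same function by hand; `empirical_cdf` is an illustrative helper, not part of Optuna, and the five objective values are hypothetical.

```python
from bisect import bisect_right
from typing import List, Tuple


def empirical_cdf(values: List[float]) -> List[Tuple[float, float]]:
    """Return (x, F(x)) pairs, where F(x) is the fraction of values <= x (values must be non-empty)."""
    ordered = sorted(values)
    n = len(ordered)
    return [(x, bisect_right(ordered, x) / n) for x in sorted(set(ordered))]


# Hypothetical objective values from five finished trials.
print(empirical_cdf([0.62, 0.55, 0.55, 0.71, 0.50]))
# → [(0.5, 0.2), (0.55, 0.6), (0.62, 0.8), (0.71, 1.0)]
```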

## Test performance of the optimized plugin

In [None]:
best_params = study.best_params
report = Benchmarks.evaluate(
    [("test", PLUGIN, best_params)],
    train_loader,
    test_loader,
    repeats=1,
    metrics={"detection": ["detection_mlp", "detection_xgb"]},  # remove this argument to evaluate all metrics
)
Benchmarks.print(report)

## Congratulations!

Congratulations on completing this notebook tutorial! If you enjoyed it and would like to join the movement towards machine learning and AI for medicine, you can do so in the following ways!

### Star [Synthcity](https://github.com/vanderschaarlab/synthcity) on GitHub

- The easiest way to help our community is just by starring the repos! This helps raise awareness of the tools we're building.


### Check out other projects from vanderschaarlab
- [HyperImpute](https://github.com/vanderschaarlab/hyperimpute)
- [AutoPrognosis](https://github.com/vanderschaarlab/autoprognosis)
