# Tutorial 8: Hyperparameter Optimization

To automatically tune hyperparameters in a `synthcity` plugin to generate more realistic data, we use hyperparameter optimization (HPO) algorithms such as Tree-structured Parzen estimators (TPE), Bayesian optimization, and genetic programming. In this tutorial we will use `optuna`, a very popular HPO library implementing TPE, to tune the hyperparameters of the `nflow` plugin to synthesize the diabetes dataset.

This tutorial requires the third party library `plotly` to be installed. This is not included in synthcity, as this tutorial is the only place it is needed. So in order to run this tutorial you will need to run `pip install plotly` as well as install synthcity.

In [1]:
# !pip install synthcity
# !pip install plotly
# !pip uninstall -y torchaudio torchdata

In [1]:
# stdlib
from typing import Any, List
  
# third party
import numpy as np
import pandas as pd
from ctgan import CTGAN
from torch.nn import TransformerEncoder
# synthcity absolute
from synthcity.plugins.core.dataloader import DataLoader, GenericDataLoader
from synthcity.plugins.core.distribution import Distribution
from synthcity.plugins.core.plugin import Plugin
from synthcity.plugins.core.schema import Schema
from synthcity.plugins.core.distribution import (
    Distribution,
    IntegerDistribution,
)


class sdv_ctgan_plugin(Plugin):
    """SDV CTGAN integration in synthcity."""

    def __init__(
        self,
        embedding_n_units: int = 128,
        n_iter: int = 150,
        batch_size: int = 100,
        cat_limit: int = 35,
        num_layers = 4,
        num_heads = 8,
        **kwargs: Any
    ) -> None:
        super().__init__(**kwargs)
        self.cat_limit = cat_limit
        self.model = CTGAN(
            embedding_dim=embedding_n_units,
            batch_size=batch_size,
            epochs=n_iter,
            verbose=False,
            num_layers = num_layers,
            num_heads = num_heads
        )

    @staticmethod
    def name() -> str:
        return "trans_ctgan"

    @staticmethod
    def type() -> str:
        return "debug"

    @staticmethod
    def hyperparameter_space(**kwargs: Any) -> List[Distribution]:
        """
        We can customize the hyperparameter space, and use it in AutoML benchmarks.
        """
        return [
            IntegerDistribution(name="batch_size", low=100, high=300, step=50),
            IntegerDistribution(name="n_iter", low=100, high=300, step=50),
            IntegerDistribution(name="num_layers", low=2, high=12, step=2),
            IntegerDistribution(name="num_heads", low=4, high=8, step=4)
        ]

    def _fit(self, X: DataLoader, *args: Any, **kwargs: Any) -> "sdvPlugin":
        """We selected the discrete columns based on the count of unique values, and train the CTGAN"""
        discrete_columns = []

        for col in X.columns:
            if len(X[col].unique()) < self.cat_limit:
                discrete_columns.append(col)
        from torch.nn import TransformerEncoder
        self.model.fit(X.dataframe(), discrete_columns=discrete_columns)
        return self

    def _generate(self, count: int, syn_schema: Schema, **kwargs: Any) -> pd.DataFrame:
        return self._safe_generate(self.model.sample, count, syn_schema)

  from .autonotebook import tqdm as notebook_tqdm




In [2]:
# synthcity absolute
from synthcity.plugins import Plugins

generators = Plugins()

generators.list()

[2025-02-02T16:44:24.400157+0800][6236][CRITICAL] module disabled: e:\qycache\anaconda\envs\LLM\lib\site-packages\synthcity\plugins\generic\plugin_goggle.py


['marginal_distributions',
 'fflows',
 'decaf',
 'rtvae',
 'ddpm',
 'timegan',
 'aim',
 'privbayes',
 'adsgan',
 'survival_gan',
 'survae',
 'tvae',
 'image_cgan',
 'bayesian_network',
 'great',
 'uniform_sampler',
 'dpgan',
 'image_adsgan',
 'radialgan',
 'dummy_sampler',
 'survival_nflow',
 'survival_ctgan',
 'arf',
 'nflow',
 'ctgan',
 'timevae',
 'pategan']

In [3]:
generators.add("trans_ctgan", sdv_ctgan_plugin)

generators.list()

['marginal_distributions',
 'fflows',
 'decaf',
 'rtvae',
 'ddpm',
 'timegan',
 'aim',
 'privbayes',
 'adsgan',
 'survival_gan',
 'survae',
 'tvae',
 'image_cgan',
 'bayesian_network',
 'great',
 'uniform_sampler',
 'dpgan',
 'image_adsgan',
 'radialgan',
 'dummy_sampler',
 'survival_nflow',
 'survival_ctgan',
 'arf',
 'nflow',
 'ctgan',
 'timevae',
 'pategan',
 'trans_ctgan']

In [4]:
# from sklearn.datasets import load_breast_cancer,load_diabetes
# X, y = load_diabetes(return_X_y=True, as_frame=True)
#real_path = "C:/Users/26332/Desktop/表格数据生成课题/表格GAN相关研究汇总（截止2024.11.2）/论文源代码整合/CTAB-GAN-main/Real_Datasets/Adult3.csv"
#real_path = "C:/Users/26332/Desktop/表格数据生成课题/表格GAN相关研究汇总（截止2024.11.2）/论文源代码整合/CTAB-GAN-main/Real_Datasets/creditcard2.csv"
#real_path = 'C:/Users/26332/Desktop/表格数据生成课题/表格GAN相关研究汇总（截止2024.11.2）/论文源代码整合/CTGAN-main/CTGAN-main/examples/csv/train_clean.csv'
#real_path = "C:/Users/26332/Desktop/表格数据生成课题/表格GAN相关研究汇总（截止2024.11.2）/论文源代码整合/CTAB-GAN-main/Real_Datasets/train2.csv"
real_path = "C:/Users/26332/Desktop/表格数据生成课题/表格GAN相关研究汇总（截止2024.11.2）/论文源代码整合/synthcity-main/tutorials/covertype_preprocessed.csv"
data = pd.read_csv(real_path)
#data = pd.read_csv('C:/Users/26332/Desktop/表格数据生成课题/表格GAN相关研究汇总（截止2024.11.2）/论文源代码整合/CTGAN-main/CTGAN-main/examples/csv/train_clean.csv')
#data = pd.read_csv('C:/Users/26332/Desktop/表格数据生成课题/表格GAN相关研究汇总（截止2024.11.2）/论文源代码整合/CTGAN-main/CTGAN-main/Adult_datasets.csv')

In [5]:
# stdlib
import sys
import warnings

# third party
import optuna
from sklearn.datasets import load_diabetes

# synthcity absolute
import synthcity.logger as log
from synthcity.plugins import Plugins
from synthcity.plugins.core.dataloader import GenericDataLoader

log.add(sink=sys.stderr, level="INFO")
warnings.filterwarnings("ignore")

## Load the dataset

In [7]:
# X, y = load_diabetes(return_X_y=True, as_frame=True)
# X["target"] = y
# X

In [6]:
loader = GenericDataLoader(
    data,
    target_column="Cover_Type",
    sensitive_columns=[],
)
train_loader, test_loader = loader.train(), loader.test()

## Load the plugin class

In [7]:
PLUGIN = "trans_ctgan"
plugin_cls = type(Plugins().get(PLUGIN))
plugin_cls

[2025-02-02T16:44:26.902942+0800][6236][CRITICAL] module disabled: e:\qycache\anaconda\envs\LLM\lib\site-packages\synthcity\plugins\generic\plugin_goggle.py
[2025-02-02T16:44:26.902942+0800][6236][CRITICAL] module disabled: e:\qycache\anaconda\envs\LLM\lib\site-packages\synthcity\plugins\generic\plugin_goggle.py


__main__.sdv_ctgan_plugin

## Display the hyperparameter space

In [8]:
plugin_cls.hyperparameter_space()

[IntegerDistribution(name='batch_size', data=None, random_state=0, marginal_distribution=None, low=100, high=300, step=50),
 IntegerDistribution(name='n_iter', data=None, random_state=0, marginal_distribution=None, low=100, high=300, step=50),
 IntegerDistribution(name='num_layers', data=None, random_state=0, marginal_distribution=None, low=2, high=12, step=2),
 IntegerDistribution(name='num_heads', data=None, random_state=0, marginal_distribution=None, low=4, high=8, step=4)]

## Use a trial to suggest a set of hyperparameters

In [9]:
from synthcity.utils.optuna_sample import suggest_all

trial = optuna.create_study().ask()
params = suggest_all(trial, plugin_cls.hyperparameter_space())
params['n_iter'] = 100  # speed up
params

{'batch_size': 250, 'n_iter': 100, 'num_layers': 10, 'num_heads': 8}

## Evaluate the plugin with the suggested hyperparameters

In [10]:
from synthcity.benchmark import Benchmarks

plugin = plugin_cls(**params).fit(train_loader)
report = Benchmarks.evaluate(
    [("trial", PLUGIN, params)],
    train_loader,  # Benchmarks.evaluate will split out a validation set
    repeats=2,
    metrics={"detection": ["detection_mlp"]},  # DELETE THIS LINE FOR ALL METRICS
)
report['trial']

  0%|          | 0/100 [00:00<?, ?it/s]
epoch 0
Aw
epoch 1
Aw
epoch 2
Aw
epoch 3
Aw
epoch 4
Aw
epoch 5
Aw
epoch 6
Aw
epoch 7
Aw
epoch 8
Aw
epoch 9
Aw
epoch 10
Aw
epoch 11
Aw
epoch 12
Aw
epoch 13
Aw
epoch 14
Aw
epoch 15
Aw
epoch 16
Aw
epoch 17
Aw
epoch 18
Aw
epoch 19
Aw
epoch 20
Aw
epoch 21
Aw
epoch 22
Aw
epoch 23
Aw
epoch 24
Aw
epoch 25
Aw
epoch 26
Aw
epoch 27
Aw
epoch 28
Aw
epoch 29
Aw
epoch 30
Aw
epoch 31
Aw
epoch 32
Aw
epoch 33
Aw
epoch 34
Aw
epoch 35
Aw
epoch 36
Aw
epoch 37
Aw
epoch 38
Aw
epoch 39
Aw
epoch 40
Aw
epoch 41
Aw
epoch 42
Aw
epoch 43
Aw
epoch 44
Aw
epoch 45
Aw
epoch 46
Aw
epoch 47
Aw
epoch 48
Aw
epoch 49
Aw
epoch 50
Aw
epoch 51
Aw
epoch 52
Aw
epoch 53
Aw
epoch 54
Aw
epoch 55
Aw
epoch 56
Aw
epoch 57
Aw
epoch 58
Aw
epoch 59
Aw
epoch 60
Aw
epoch 61
Aw
epoch 62
Aw
epoch 63
Aw
epoch 64
Aw
epoch 65
Aw
epoch 66
Aw
epoch 67
Aw
epoch 68
Aw
epoch 69
Aw
epoch 70
Aw
epoch 71
Aw
epoch 72
Aw
epoch 73
Aw
epoch 74
Aw
epoch 75
Aw
epoch 76
Aw
epoch 77
Aw
epoch 78
Aw
epoch 79
Aw
epoch 80
A

[2025-02-02T17:15:09.519060+0800][6236][INFO] Testcase : trial
[2025-02-02T17:15:09.531552+0800][6236][INFO] [testcase] Experiment repeat: 0 task type: classification Train df hash = 7570246924558574363
[2025-02-02T17:15:09.533561+0800][6236][CRITICAL] module disabled: e:\qycache\anaconda\envs\LLM\lib\site-packages\synthcity\plugins\generic\plugin_goggle.py
[2025-02-02T17:15:09.533561+0800][6236][CRITICAL] module disabled: e:\qycache\anaconda\envs\LLM\lib\site-packages\synthcity\plugins\generic\plugin_goggle.py


  0%|          | 0/100 [00:00<?, ?it/s]
epoch 0
Aw
epoch 1
Aw
epoch 2
Aw
epoch 3
Aw
epoch 4
Aw
epoch 5
Aw
epoch 6
Aw
epoch 7
Aw
epoch 8
Aw
epoch 9
Aw
epoch 10
Aw
epoch 11
Aw
epoch 12
Aw
epoch 13
Aw
epoch 14
Aw
epoch 15
Aw
epoch 16
Aw
epoch 17
Aw
epoch 18
Aw
epoch 19
Aw
epoch 20
Aw
epoch 21
Aw
epoch 22
Aw
epoch 23
Aw
epoch 24
Aw
epoch 25
Aw
epoch 26
Aw
epoch 27
Aw
epoch 28
Aw
epoch 29
Aw
epoch 30
Aw
epoch 31
Aw
epoch 32
Aw
epoch 33
Aw
epoch 34
Aw
epoch 35
Aw
epoch 36
Aw
epoch 37
Aw
epoch 38
Aw
epoch 39
Aw
epoch 40
Aw
epoch 41
Aw
epoch 42
Aw
epoch 43
Aw
epoch 44
Aw
epoch 45
Aw
epoch 46
Aw
epoch 47
Aw
epoch 48
Aw
epoch 49
Aw
epoch 50
Aw
epoch 51
Aw
epoch 52
Aw
epoch 53
Aw
epoch 54
Aw
epoch 55
Aw
epoch 56
Aw
epoch 57
Aw
epoch 58
Aw
epoch 59
Aw
epoch 60
Aw
epoch 61
Aw
epoch 62
Aw
epoch 63
Aw
epoch 64
Aw
epoch 65
Aw
epoch 66
Aw
epoch 67
Aw
epoch 68
Aw
epoch 69
Aw
epoch 70
Aw
epoch 71
Aw
epoch 72
Aw
epoch 73
Aw
epoch 74
Aw
epoch 75
Aw
epoch 76
Aw
epoch 77
Aw
epoch 78
Aw
epoch 79
Aw
epoch 80
A

[2025-02-02T17:39:53.559021+0800][6236][INFO] [Hillshade_3pm] quality loss for constraints le = 2.7096743428136514. Remaining 203. prev length 1280. Original dtype float64.
[2025-02-02T17:39:53.561022+0800][6236][INFO] [Hillshade_3pm] quality loss for constraints ge = -3.731208897747535. Remaining 91. prev length 203. Original dtype float64.
[2025-02-02T17:39:53.564033+0800][6236][INFO] [Elevation] quality loss for constraints ge = -3.7461682315564473. Remaining 90. prev length 91. Original dtype float64.
[2025-02-02T17:39:53.566030+0800][6236][INFO] [Slope] quality loss for constraints le = 3.747858032342592. Remaining 82. prev length 90. Original dtype float64.
[2025-02-02T17:39:53.569046+0800][6236][INFO] [Horizontal_Distance_To_Hydrology] quality loss for constraints le = 4.54625632482397. Remaining 74. prev length 82. Original dtype float64.
[2025-02-02T17:39:53.571047+0800][6236][INFO] [Horizontal_Distance_To_Hydrology] quality loss for constraints ge = -1.2579114221529188. Remai

  0%|          | 0/100 [00:00<?, ?it/s]
epoch 0
Aw
epoch 1
Aw
epoch 2
Aw
epoch 3
Aw
epoch 4
Aw
epoch 5
Aw
epoch 6
Aw
epoch 7
Aw
epoch 8
Aw
epoch 9
Aw
epoch 10
Aw
epoch 11
Aw
epoch 12
Aw
epoch 13
Aw
epoch 14
Aw
epoch 15
Aw
epoch 16
Aw
epoch 17
Aw
epoch 18
Aw
epoch 19
Aw
epoch 20
Aw
epoch 21
Aw
epoch 22
Aw
epoch 23
Aw
epoch 24
Aw
epoch 25
Aw
epoch 26
Aw
epoch 27
Aw
epoch 28
Aw
epoch 29
Aw
epoch 30
Aw
epoch 31
Aw
epoch 32
Aw
epoch 33
Aw
epoch 34
Aw
epoch 35
Aw
epoch 36
Aw
epoch 37
Aw
epoch 38
Aw
epoch 39
Aw
epoch 40
Aw
epoch 41
Aw
epoch 42
Aw
epoch 43
Aw
epoch 44
Aw
epoch 45
Aw
epoch 46
Aw
epoch 47
Aw
epoch 48
Aw
epoch 49
Aw
epoch 50
Aw
epoch 51
Aw
epoch 52
Aw
epoch 53
Aw
epoch 54
Aw
epoch 55
Aw
epoch 56
Aw
epoch 57
Aw
epoch 58
Aw
epoch 59
Aw
epoch 60
Aw
epoch 61
Aw
epoch 62
Aw
epoch 63
Aw
epoch 64
Aw
epoch 65
Aw
epoch 66
Aw
epoch 67
Aw
epoch 68
Aw
epoch 69
Aw
epoch 70
Aw
epoch 71
Aw
epoch 72
Aw
epoch 73
Aw
epoch 74
Aw
epoch 75
Aw
epoch 76
Aw
epoch 77
Aw
epoch 78
Aw
epoch 79
Aw
epoch 80
A

[2025-02-02T18:10:12.157288+0800][6236][INFO] [Hillshade_3pm] quality loss for constraints le = 2.7096743428136514. Remaining 297. prev length 1280. Original dtype float64.
[2025-02-02T18:10:12.159287+0800][6236][INFO] [Hillshade_3pm] quality loss for constraints ge = -3.731208897747535. Remaining 289. prev length 297. Original dtype float64.
[2025-02-02T18:10:12.161262+0800][6236][INFO] [Elevation] quality loss for constraints ge = -3.7461682315564473. Remaining 200. prev length 289. Original dtype float64.
[2025-02-02T18:10:12.164278+0800][6236][INFO] [Slope] quality loss for constraints le = 3.747858032342592. Remaining 197. prev length 200. Original dtype float64.
[2025-02-02T18:10:12.165267+0800][6236][INFO] [Slope] quality loss for constraints ge = -1.7466507308578898. Remaining 149. prev length 197. Original dtype float64.
[2025-02-02T18:10:12.168262+0800][6236][INFO] [Horizontal_Distance_To_Hydrology] quality loss for constraints le = 4.54625632482397. Remaining 148. prev lengt

Unnamed: 0,min,max,mean,stddev,median,iqr,rounds,errors,durations,direction
detection.detection_mlp.mean,0.998414,0.998623,0.998518,0.000104,0.998518,0.000104,2,0,1.88,minimize


## Create an Optuna study and optimize the hyperparameters

In [11]:
def objective(trial: optuna.Trial):
    hp_space = Plugins().get(PLUGIN).hyperparameter_space()
    hp_space[0].high = 100  # speed up for now
    params = suggest_all(trial, hp_space)
    ID = f"trial_{trial.number}"
    try:
        report = Benchmarks.evaluate(
            [(ID, PLUGIN, params)],
            train_loader,
            repeats=2,
            metrics={"detection": ["detection_mlp"]},  # DELETE THIS LINE FOR ALL METRICS
        )
    except Exception as e:  # invalid set of params
        print(f"{type(e).__name__}: {e}")
        print(params)
        raise optuna.TrialPruned()
    score = report[ID].query('direction == "minimize"')['mean'].mean()
    # average score across all metrics with direction="minimize"
    return score

study = optuna.create_study(direction="minimize")
study.optimize(objective, n_trials=4)
study.best_params

[2025-02-02T18:29:19.362801+0800][6236][CRITICAL] module disabled: e:\qycache\anaconda\envs\LLM\lib\site-packages\synthcity\plugins\generic\plugin_goggle.py
[2025-02-02T18:29:19.362801+0800][6236][CRITICAL] module disabled: e:\qycache\anaconda\envs\LLM\lib\site-packages\synthcity\plugins\generic\plugin_goggle.py
[2025-02-02T18:29:19.387319+0800][6236][INFO] Testcase : trial_0
[2025-02-02T18:29:19.389319+0800][6236][INFO] [testcase] Experiment repeat: 0 task type: classification Train df hash = 7570246924558574363
[2025-02-02T18:29:19.389319+0800][6236][CRITICAL] module disabled: e:\qycache\anaconda\envs\LLM\lib\site-packages\synthcity\plugins\generic\plugin_goggle.py
[2025-02-02T18:29:19.389319+0800][6236][CRITICAL] module disabled: e:\qycache\anaconda\envs\LLM\lib\site-packages\synthcity\plugins\generic\plugin_goggle.py


  0%|          | 0/300 [00:00<?, ?it/s]
epoch 0
Aw
epoch 1
Aw
epoch 2
Aw
epoch 3
Aw
epoch 4
Aw
epoch 5
Aw
epoch 6
Aw
epoch 7
Aw
epoch 8
Aw
epoch 9
Aw
epoch 10
Aw
epoch 11
Aw
epoch 12
Aw
epoch 13
Aw
epoch 14
Aw
epoch 15
Aw
epoch 16
Aw
epoch 17
Aw
epoch 18
Aw
epoch 19
Aw
epoch 20
Aw
epoch 21
Aw
epoch 22
Aw
epoch 23
Aw
epoch 24
Aw
epoch 25
Aw
epoch 26
Aw
epoch 27
Aw
epoch 28
Aw
epoch 29
Aw
epoch 30
Aw
epoch 31
Aw
epoch 32
Aw
epoch 33
Aw
epoch 34
Aw
epoch 35
Aw
epoch 36
Aw
epoch 37
Aw
epoch 38
Aw
epoch 39
Aw
epoch 40
Aw
epoch 41
Aw
epoch 42
Aw
epoch 43
Aw
epoch 44
Aw
epoch 45
Aw
epoch 46
Aw
epoch 47
Aw
epoch 48
Aw
epoch 49
Aw
epoch 50
Aw
epoch 51
Aw
epoch 52
Aw
epoch 53
Aw
epoch 54
Aw
epoch 55
Aw
epoch 56
Aw
epoch 57
Aw
epoch 58
Aw
epoch 59
Aw
epoch 60
Aw
epoch 61
Aw
epoch 62
Aw
epoch 63
Aw
epoch 64
Aw
epoch 65
Aw
epoch 66
Aw
epoch 67
Aw
epoch 68
Aw
epoch 69
Aw
epoch 70
Aw
epoch 71
Aw
epoch 72
Aw
epoch 73
Aw
epoch 74
Aw
epoch 75
Aw
epoch 76
Aw
epoch 77
Aw
epoch 78
Aw
epoch 79
Aw
epoch 80
A

[2025-02-02T18:35:57.872763+0800][6236][INFO] [Hillshade_3pm] quality loss for constraints le = 2.7096743428136514. Remaining 1239. prev length 1280. Original dtype float64.
[2025-02-02T18:35:57.873780+0800][6236][INFO] [Hillshade_3pm] quality loss for constraints ge = -3.731208897747535. Remaining 1234. prev length 1239. Original dtype float64.
[2025-02-02T18:35:57.875767+0800][6236][INFO] [Elevation] quality loss for constraints ge = -3.7461682315564473. Remaining 1216. prev length 1234. Original dtype float64.
[2025-02-02T18:35:57.876767+0800][6236][INFO] [Slope] quality loss for constraints le = 3.747858032342592. Remaining 1205. prev length 1216. Original dtype float64.
[2025-02-02T18:35:57.877759+0800][6236][INFO] [Slope] quality loss for constraints ge = -1.7466507308578898. Remaining 1196. prev length 1205. Original dtype float64.
[2025-02-02T18:35:57.879769+0800][6236][INFO] [Horizontal_Distance_To_Hydrology] quality loss for constraints ge = -1.2579114221529188. Remaining 652

  0%|          | 0/300 [00:00<?, ?it/s]
epoch 0
Aw
epoch 1
Aw
epoch 2
Aw
epoch 3
Aw
epoch 4
Aw
epoch 5
Aw
epoch 6
Aw
epoch 7
Aw
epoch 8
Aw
epoch 9
Aw
epoch 10
Aw
epoch 11
Aw
epoch 12
Aw
epoch 13
Aw
epoch 14
Aw
epoch 15
Aw
epoch 16
Aw
epoch 17
Aw
epoch 18
Aw
epoch 19
Aw
epoch 20
Aw
epoch 21
Aw
epoch 22
Aw
epoch 23
Aw
epoch 24
Aw
epoch 25
Aw
epoch 26
Aw
epoch 27
Aw
epoch 28
Aw
epoch 29
Aw
epoch 30
Aw
epoch 31
Aw
epoch 32
Aw
epoch 33
Aw
epoch 34
Aw
epoch 35
Aw
epoch 36
Aw
epoch 37
Aw
epoch 38
Aw
epoch 39
Aw
epoch 40
Aw
epoch 41
Aw
epoch 42
Aw
epoch 43
Aw
epoch 44
Aw
epoch 45
Aw
epoch 46
Aw
epoch 47
Aw
epoch 48
Aw
epoch 49
Aw
epoch 50
Aw
epoch 51
Aw
epoch 52
Aw
epoch 53
Aw
epoch 54
Aw
epoch 55
Aw
epoch 56
Aw
epoch 57
Aw
epoch 58
Aw
epoch 59
Aw
epoch 60
Aw
epoch 61
Aw
epoch 62
Aw
epoch 63
Aw
epoch 64
Aw
epoch 65
Aw
epoch 66
Aw
epoch 67
Aw
epoch 68
Aw
epoch 69
Aw
epoch 70
Aw
epoch 71
Aw
epoch 72
Aw
epoch 73
Aw
epoch 74
Aw
epoch 75
Aw
epoch 76
Aw
epoch 77
Aw
epoch 78
Aw
epoch 79
Aw
epoch 80
A

[2025-02-02T18:42:40.115624+0800][6236][INFO] [Hillshade_3pm] quality loss for constraints le = 2.7096743428136514. Remaining 1249. prev length 1280. Original dtype float64.
[2025-02-02T18:42:40.116622+0800][6236][INFO] [Hillshade_3pm] quality loss for constraints ge = -3.731208897747535. Remaining 1089. prev length 1249. Original dtype float64.
[2025-02-02T18:42:40.118627+0800][6236][INFO] [Elevation] quality loss for constraints le = 2.717010401566637. Remaining 1088. prev length 1089. Original dtype float64.
[2025-02-02T18:42:40.119628+0800][6236][INFO] [Elevation] quality loss for constraints ge = -3.7461682315564473. Remaining 995. prev length 1088. Original dtype float64.
[2025-02-02T18:42:40.120634+0800][6236][INFO] [Slope] quality loss for constraints le = 3.747858032342592. Remaining 987. prev length 995. Original dtype float64.
[2025-02-02T18:42:40.121633+0800][6236][INFO] [Slope] quality loss for constraints ge = -1.7466507308578898. Remaining 773. prev length 987. Original 

  0%|          | 0/150 [00:00<?, ?it/s]
epoch 0
Aw
epoch 1
Aw
epoch 2
Aw
epoch 3
Aw
epoch 4
Aw
epoch 5
Aw
epoch 6
Aw
epoch 7
Aw
epoch 8
Aw
epoch 9
Aw
epoch 10
Aw
epoch 11
Aw
epoch 12
Aw
epoch 13
Aw
epoch 14
Aw
epoch 15
Aw
epoch 16
Aw
epoch 17
Aw
epoch 18
Aw
epoch 19
Aw
epoch 20
Aw
epoch 21
Aw
epoch 22
Aw
epoch 23
Aw
epoch 24
Aw
epoch 25
Aw
epoch 26
Aw
epoch 27
Aw
epoch 28
Aw
epoch 29
Aw
epoch 30
Aw
epoch 31
Aw
epoch 32
Aw
epoch 33
Aw
epoch 34
Aw
epoch 35
Aw
epoch 36
Aw
epoch 37
Aw
epoch 38
Aw
epoch 39
Aw
epoch 40
Aw
epoch 41
Aw
epoch 42
Aw
epoch 43
Aw
epoch 44
Aw
epoch 45
Aw
epoch 46
Aw
epoch 47
Aw
epoch 48
Aw
epoch 49
Aw
epoch 50
Aw
epoch 51
Aw
epoch 52
Aw
epoch 53
Aw
epoch 54
Aw
epoch 55
Aw
epoch 56
Aw
epoch 57
Aw
epoch 58
Aw
epoch 59
Aw
epoch 60
Aw
epoch 61
Aw
epoch 62
Aw
epoch 63
Aw
epoch 64
Aw
epoch 65
Aw
epoch 66
Aw
epoch 67
Aw
epoch 68
Aw
epoch 69
Aw
epoch 70
Aw
epoch 71
Aw
epoch 72
Aw
epoch 73
Aw
epoch 74
Aw
epoch 75
Aw
epoch 76
Aw
epoch 77
Aw
epoch 78
Aw
epoch 79
Aw
epoch 80
A

[2025-02-02T18:51:28.343306+0800][6236][INFO] [Hillshade_3pm] quality loss for constraints le = 2.7096743428136514. Remaining 1143. prev length 1280. Original dtype float64.
[2025-02-02T18:51:28.344302+0800][6236][INFO] [Hillshade_3pm] quality loss for constraints ge = -3.731208897747535. Remaining 1115. prev length 1143. Original dtype float64.
[2025-02-02T18:51:28.345301+0800][6236][INFO] [Elevation] quality loss for constraints le = 2.717010401566637. Remaining 1114. prev length 1115. Original dtype float64.
[2025-02-02T18:51:28.346307+0800][6236][INFO] [Elevation] quality loss for constraints ge = -3.7461682315564473. Remaining 1105. prev length 1114. Original dtype float64.
[2025-02-02T18:51:28.347300+0800][6236][INFO] [Slope] quality loss for constraints le = 3.747858032342592. Remaining 1094. prev length 1105. Original dtype float64.
[2025-02-02T18:51:28.348300+0800][6236][INFO] [Slope] quality loss for constraints ge = -1.7466507308578898. Remaining 1064. prev length 1094. Orig

  0%|          | 0/150 [00:00<?, ?it/s]
epoch 0
Aw
epoch 1
Aw
epoch 2
Aw
epoch 3
Aw
epoch 4
Aw
epoch 5
Aw
epoch 6
Aw
epoch 7
Aw
epoch 8
Aw
epoch 9
Aw
epoch 10
Aw
epoch 11
Aw
epoch 12
Aw
epoch 13
Aw
epoch 14
Aw
epoch 15
Aw
epoch 16
Aw
epoch 17
Aw
epoch 18
Aw
epoch 19
Aw
epoch 20
Aw
epoch 21
Aw
epoch 22
Aw
epoch 23
Aw
epoch 24
Aw
epoch 25
Aw
epoch 26
Aw
epoch 27
Aw
epoch 28
Aw
epoch 29
Aw
epoch 30
Aw
epoch 31
Aw
epoch 32
Aw
epoch 33
Aw
epoch 34
Aw
epoch 35
Aw
epoch 36
Aw
epoch 37
Aw
epoch 38
Aw
epoch 39
Aw
epoch 40
Aw
epoch 41
Aw
epoch 42
Aw
epoch 43
Aw
epoch 44
Aw
epoch 45
Aw
epoch 46
Aw
epoch 47
Aw
epoch 48
Aw
epoch 49
Aw
epoch 50
Aw
epoch 51
Aw
epoch 52
Aw
epoch 53
Aw
epoch 54
Aw
epoch 55
Aw
epoch 56
Aw
epoch 57
Aw
epoch 58
Aw
epoch 59
Aw
epoch 60
Aw
epoch 61
Aw
epoch 62
Aw
epoch 63
Aw
epoch 64
Aw
epoch 65
Aw
epoch 66
Aw
epoch 67
Aw
epoch 68
Aw
epoch 69
Aw
epoch 70
Aw
epoch 71
Aw
epoch 72
Aw
epoch 73
Aw
epoch 74
Aw
epoch 75
Aw
epoch 76
Aw
epoch 77
Aw
epoch 78
Aw
epoch 79
Aw
epoch 80
A

[2025-02-02T19:00:13.572163+0800][6236][INFO] [Hillshade_3pm] quality loss for constraints le = 2.7096743428136514. Remaining 1039. prev length 1280. Original dtype float64.
[2025-02-02T19:00:13.573154+0800][6236][INFO] [Hillshade_3pm] quality loss for constraints ge = -3.731208897747535. Remaining 1025. prev length 1039. Original dtype float64.
[2025-02-02T19:00:13.574156+0800][6236][INFO] [Elevation] quality loss for constraints le = 2.717010401566637. Remaining 1001. prev length 1025. Original dtype float64.
[2025-02-02T19:00:13.575155+0800][6236][INFO] [Elevation] quality loss for constraints ge = -3.7461682315564473. Remaining 956. prev length 1001. Original dtype float64.
[2025-02-02T19:00:13.576071+0800][6236][INFO] [Slope] quality loss for constraints le = 3.747858032342592. Remaining 953. prev length 956. Original dtype float64.
[2025-02-02T19:00:13.577160+0800][6236][INFO] [Slope] quality loss for constraints ge = -1.7466507308578898. Remaining 167. prev length 953. Original 

  0%|          | 0/250 [00:00<?, ?it/s]
epoch 0
Aw
epoch 1
Aw
epoch 2
Aw
epoch 3
Aw
epoch 4
Aw
epoch 5
Aw
epoch 6
Aw
epoch 7
Aw
epoch 8
Aw
epoch 9
Aw
epoch 10
Aw
epoch 11
Aw
epoch 12
Aw
epoch 13
Aw
epoch 14
Aw
epoch 15
Aw
epoch 16
Aw
epoch 17
Aw
epoch 18
Aw
epoch 19
Aw
epoch 20
Aw
epoch 21
Aw
epoch 22
Aw
epoch 23
Aw
epoch 24
Aw
epoch 25
Aw
epoch 26
Aw
epoch 27
Aw
epoch 28
Aw
epoch 29
Aw
epoch 30
Aw
epoch 31
Aw
epoch 32
Aw
epoch 33
Aw
epoch 34
Aw
epoch 35
Aw
epoch 36
Aw
epoch 37
Aw
epoch 38
Aw
epoch 39
Aw
epoch 40
Aw
epoch 41
Aw
epoch 42
Aw
epoch 43
Aw
epoch 44
Aw
epoch 45
Aw
epoch 46
Aw
epoch 47
Aw
epoch 48
Aw
epoch 49
Aw
epoch 50
Aw
epoch 51
Aw
epoch 52
Aw
epoch 53
Aw
epoch 54
Aw
epoch 55
Aw
epoch 56
Aw
epoch 57
Aw
epoch 58
Aw
epoch 59
Aw
epoch 60
Aw
epoch 61
Aw
epoch 62
Aw
epoch 63
Aw
epoch 64
Aw
epoch 65
Aw
epoch 66
Aw
epoch 67
Aw
epoch 68
Aw
epoch 69
Aw
epoch 70
Aw
epoch 71
Aw
epoch 72
Aw
epoch 73
Aw
epoch 74
Aw
epoch 75
Aw
epoch 76
Aw
epoch 77
Aw
epoch 78
Aw
epoch 79
Aw
epoch 80
A

[2025-02-02T19:40:56.573817+0800][6236][INFO] [Hillshade_3pm] quality loss for constraints le = 2.7096743428136514. Remaining 1268. prev length 1280. Original dtype float64.
[2025-02-02T19:40:56.574813+0800][6236][INFO] [Hillshade_3pm] quality loss for constraints ge = -3.731208897747535. Remaining 1246. prev length 1268. Original dtype float64.
[2025-02-02T19:40:56.577520+0800][6236][INFO] [Elevation] quality loss for constraints le = 2.717010401566637. Remaining 1243. prev length 1246. Original dtype float64.
[2025-02-02T19:40:56.577520+0800][6236][INFO] [Elevation] quality loss for constraints ge = -3.7461682315564473. Remaining 1221. prev length 1243. Original dtype float64.
[2025-02-02T19:40:56.579568+0800][6236][INFO] [Slope] quality loss for constraints le = 3.747858032342592. Remaining 1207. prev length 1221. Original dtype float64.
[2025-02-02T19:40:56.579568+0800][6236][INFO] [Slope] quality loss for constraints ge = -1.7466507308578898. Remaining 851. prev length 1207. Origi

  0%|          | 0/250 [00:00<?, ?it/s]
epoch 0
Aw
epoch 1
Aw
epoch 2
Aw
epoch 3
Aw
epoch 4
Aw
epoch 5
Aw
epoch 6
Aw
epoch 7
Aw
epoch 8
Aw
epoch 9
Aw
epoch 10
Aw
epoch 11
Aw
epoch 12
Aw
epoch 13
Aw
epoch 14
Aw
epoch 15
Aw
epoch 16
Aw
epoch 17
Aw
epoch 18
Aw
epoch 19
Aw
epoch 20
Aw
epoch 21
Aw
epoch 22
Aw
epoch 23
Aw
epoch 24
Aw
epoch 25
Aw
epoch 26
Aw
epoch 27
Aw
epoch 28
Aw
epoch 29
Aw
epoch 30
Aw
epoch 31
Aw
epoch 32
Aw
epoch 33
Aw
epoch 34
Aw
epoch 35
Aw
epoch 36
Aw
epoch 37
Aw
epoch 38
Aw
epoch 39
Aw
epoch 40
Aw
epoch 41
Aw
epoch 42
Aw
epoch 43
Aw
epoch 44
Aw
epoch 45
Aw
epoch 46
Aw
epoch 47
Aw
epoch 48
Aw
epoch 49
Aw
epoch 50
Aw
epoch 51
Aw
epoch 52
Aw
epoch 53
Aw
epoch 54
Aw
epoch 55
Aw
epoch 56
Aw
epoch 57
Aw
epoch 58
Aw
epoch 59
Aw
epoch 60
Aw
epoch 61
Aw
epoch 62
Aw
epoch 63
Aw
epoch 64
Aw
epoch 65
Aw
epoch 66
Aw
epoch 67
Aw
epoch 68
Aw
epoch 69
Aw
epoch 70
Aw
epoch 71
Aw
epoch 72
Aw
epoch 73
Aw
epoch 74
Aw
epoch 75
Aw
epoch 76
Aw
epoch 77
Aw
epoch 78
Aw
epoch 79
Aw
epoch 80
A

[2025-02-02T19:57:46.273109+0800][6236][INFO] [Hillshade_3pm] quality loss for constraints le = 2.7096743428136514. Remaining 1237. prev length 1280. Original dtype float64.
[2025-02-02T19:57:46.274110+0800][6236][INFO] [Hillshade_3pm] quality loss for constraints ge = -3.731208897747535. Remaining 1184. prev length 1237. Original dtype float64.
[2025-02-02T19:57:46.275622+0800][6236][INFO] [Elevation] quality loss for constraints le = 2.717010401566637. Remaining 1182. prev length 1184. Original dtype float64.
[2025-02-02T19:57:46.276630+0800][6236][INFO] [Elevation] quality loss for constraints ge = -3.7461682315564473. Remaining 1031. prev length 1182. Original dtype float64.
[2025-02-02T19:57:46.278215+0800][6236][INFO] [Slope] quality loss for constraints le = 3.747858032342592. Remaining 998. prev length 1031. Original dtype float64.
[2025-02-02T19:57:46.279128+0800][6236][INFO] [Slope] quality loss for constraints ge = -1.7466507308578898. Remaining 946. prev length 998. Origina

  0%|          | 0/250 [00:00<?, ?it/s]
epoch 0
Aw
epoch 1
Aw
epoch 2
Aw
epoch 3
Aw
epoch 4
Aw
epoch 5
Aw
epoch 6
Aw
epoch 7
Aw
epoch 8
Aw
epoch 9
Aw
epoch 10
Aw
epoch 11
Aw
epoch 12
Aw
epoch 13
Aw
epoch 14
Aw
epoch 15
Aw
epoch 16
Aw
epoch 17
Aw
epoch 18
Aw
epoch 19
Aw
epoch 20
Aw
epoch 21
Aw
epoch 22
Aw
epoch 23
Aw
epoch 24
Aw
epoch 25
Aw
epoch 26
Aw
epoch 27
Aw
epoch 28
Aw
epoch 29
Aw
epoch 30
Aw
epoch 31
Aw
epoch 32
Aw
epoch 33
Aw
epoch 34
Aw
epoch 35
Aw
epoch 36
Aw
epoch 37
Aw
epoch 38
Aw
epoch 39
Aw
epoch 40
Aw
epoch 41
Aw
epoch 42
Aw
epoch 43
Aw
epoch 44
Aw
epoch 45
Aw
epoch 46
Aw
epoch 47
Aw
epoch 48
Aw
epoch 49
Aw
epoch 50
Aw
epoch 51
Aw
epoch 52
Aw
epoch 53
Aw
epoch 54
Aw
epoch 55
Aw
epoch 56
Aw
epoch 57
Aw
epoch 58
Aw
epoch 59
Aw
epoch 60
Aw
epoch 61
Aw
epoch 62
Aw
epoch 63
Aw
epoch 64
Aw
epoch 65
Aw
epoch 66
Aw
epoch 67
Aw
epoch 68
Aw
epoch 69
Aw
epoch 70
Aw
epoch 71
Aw
epoch 72
Aw
epoch 73
Aw
epoch 74
Aw
epoch 75
Aw
epoch 76
Aw
epoch 77
Aw
epoch 78
Aw
epoch 79
Aw
epoch 80
A

[2025-02-02T20:21:45.523010+0800][6236][INFO] [Hillshade_3pm] quality loss for constraints le = 2.7096743428136514. Remaining 1251. prev length 1280. Original dtype float64.
[2025-02-02T20:21:45.524531+0800][6236][INFO] [Hillshade_3pm] quality loss for constraints ge = -3.731208897747535. Remaining 1212. prev length 1251. Original dtype float64.
[2025-02-02T20:21:45.526539+0800][6236][INFO] [Elevation] quality loss for constraints le = 2.717010401566637. Remaining 1200. prev length 1212. Original dtype float64.
[2025-02-02T20:21:45.527539+0800][6236][INFO] [Elevation] quality loss for constraints ge = -3.7461682315564473. Remaining 1145. prev length 1200. Original dtype float64.
[2025-02-02T20:21:45.529538+0800][6236][INFO] [Slope] quality loss for constraints le = 3.747858032342592. Remaining 1090. prev length 1145. Original dtype float64.
[2025-02-02T20:21:45.530538+0800][6236][INFO] [Slope] quality loss for constraints ge = -1.7466507308578898. Remaining 1069. prev length 1090. Orig

  0%|          | 0/250 [00:00<?, ?it/s]
epoch 0
Aw
epoch 1
Aw
epoch 2
Aw
epoch 3
Aw
epoch 4
Aw
epoch 5
Aw
epoch 6
Aw
epoch 7
Aw
epoch 8
Aw
epoch 9
Aw
epoch 10
Aw
epoch 11
Aw
epoch 12
Aw
epoch 13
Aw
epoch 14
Aw
epoch 15
Aw
epoch 16
Aw
epoch 17
Aw
epoch 18
Aw
epoch 19
Aw
epoch 20
Aw
epoch 21
Aw
epoch 22
Aw
epoch 23
Aw
epoch 24
Aw
epoch 25
Aw
epoch 26
Aw
epoch 27
Aw
epoch 28
Aw
epoch 29
Aw
epoch 30
Aw
epoch 31
Aw
epoch 32
Aw
epoch 33
Aw
epoch 34
Aw
epoch 35
Aw
epoch 36
Aw
epoch 37
Aw
epoch 38
Aw
epoch 39
Aw
epoch 40
Aw
epoch 41
Aw
epoch 42
Aw
epoch 43
Aw
epoch 44
Aw
epoch 45
Aw
epoch 46
Aw
epoch 47
Aw
epoch 48
Aw
epoch 49
Aw
epoch 50
Aw
epoch 51
Aw
epoch 52
Aw
epoch 53
Aw
epoch 54
Aw
epoch 55
Aw
epoch 56
Aw
epoch 57
Aw
epoch 58
Aw
epoch 59
Aw
epoch 60
Aw
epoch 61
Aw
epoch 62
Aw
epoch 63
Aw
epoch 64
Aw
epoch 65
Aw
epoch 66
Aw
epoch 67
Aw
epoch 68
Aw
epoch 69
Aw
epoch 70
Aw
epoch 71
Aw
epoch 72
Aw
epoch 73
Aw
epoch 74
Aw
epoch 75
Aw
epoch 76
Aw
epoch 77
Aw
epoch 78
Aw
epoch 79
Aw
epoch 80
A

[2025-02-02T20:45:29.945557+0800][6236][INFO] [Hillshade_3pm] quality loss for constraints le = 2.7096743428136514. Remaining 1149. prev length 1280. Original dtype float64.
[2025-02-02T20:45:29.947579+0800][6236][INFO] [Hillshade_3pm] quality loss for constraints ge = -3.731208897747535. Remaining 1148. prev length 1149. Original dtype float64.
[2025-02-02T20:45:29.949970+0800][6236][INFO] [Elevation] quality loss for constraints le = 2.717010401566637. Remaining 1145. prev length 1148. Original dtype float64.
[2025-02-02T20:45:29.950973+0800][6236][INFO] [Elevation] quality loss for constraints ge = -3.7461682315564473. Remaining 1110. prev length 1145. Original dtype float64.
[2025-02-02T20:45:29.952977+0800][6236][INFO] [Slope] quality loss for constraints le = 3.747858032342592. Remaining 1108. prev length 1110. Original dtype float64.
[2025-02-02T20:45:29.953483+0800][6236][INFO] [Slope] quality loss for constraints ge = -1.7466507308578898. Remaining 835. prev length 1108. Origi

{'batch_size': 100, 'n_iter': 250, 'num_layers': 8, 'num_heads': 8}

## Visualize the study

In [14]:
# pip install nbformat

Note: you may need to restart the kernel to use updated packages.


In [12]:
from optuna.visualization import plot_contour
from optuna.visualization import plot_edf
from optuna.visualization import plot_optimization_history
from optuna.visualization import plot_parallel_coordinate
from optuna.visualization import plot_param_importances
from optuna.visualization import plot_slice

plot_optimization_history(study)

In [13]:
# Visualize high-dimensional parameter relationships. 
plot_parallel_coordinate(study)

## Test performance of the optimized plugin

In [None]:
best_params = study.best_params
report = Benchmarks.evaluate(
    [("test", PLUGIN, best_params)],
    train_loader,
    test_loader,
    repeats=1,
    metrics={"detection": ["detection_mlp", "detection_xgb"]},  # DELETE THIS LINE FOR ALL METRICS
)
Benchmarks.print(report)

[2024-11-24T21:39:13.996872+0800][38128][CRITICAL] module disabled: e:\qycache\anaconda\envs\LLM\lib\site-packages\synthcity\plugins\generic\plugin_goggle.py
100%|██████████| 100/100 [10:39<00:00,  6.40s/it]



[4m[1mPlugin : test[0m[0m


Unnamed: 0,min,max,mean,stddev,median,iqr,rounds,errors,durations
detection.detection_xgb.mean,1.0,1.0,1.0,0.0,1.0,0.0,1,0,0.25
detection.detection_mlp.mean,0.361494,0.361494,0.361494,0.0,0.361494,0.0,1,0,0.61





## Congratulations!

Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the movement towards Machine learning and AI for medicine, you can do so in the following ways!

### Star [Synthcity](https://github.com/vanderschaarlab/synthcity) on GitHub

- The easiest way to help our community is just by starring the Repos! This helps raise awareness of the tools we're building.


### Checkout other projects from vanderschaarlab
- [HyperImpute](https://github.com/vanderschaarlab/hyperimpute)
- [AutoPrognosis](https://github.com/vanderschaarlab/autoprognosis)
