In [1]:
import optuna
from optuna import Trial
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import roc_auc_score
import numpy as np
from ctgan.synthesizers.ctgan import CTGAN,EnhancedCTGAN
from scipy.stats import wasserstein_distance, entropy
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

  from .autonotebook import tqdm as notebook_tqdm


In [None]:

plt.rcParams.update({
    'font.family': 'serif',
    'font.size': 12,
    'axes.labelsize': 14,
    'axes.titlesize': 16,
    'figure.dpi': 300,
    'savefig.dpi': 300,
    'pdf.fonttype': 42
})

def objective(trial: Trial, real_data, discrete_columns):

    params = {
        "dropout": trial.suggest_float("dropout", 0.1, 0.5),         
        "num_first": trial.suggest_categorical("num_first", [True, False]),
        "batch_first": trial.suggest_categorical("batch_first", [True, False]), 
        "num_layers": trial.suggest_int("num_layers", 2, 8),        
        "num_heads": trial.suggest_categorical("num_heads", [4, 8]),
    }

    try:

        ctgan = EnhancedCTGAN(
            embedding_dim=128,
            generator_dim=(256, 256),
            discriminator_dim=(256, 256),
            dropout=params["dropout"],
            num_first=params["num_first"],
            batch_first=params["batch_first"],
            num_layers=params["num_layers"],
            num_heads=params["num_heads"],
            cuda=True 
        )

        ctgan.fit(real_data, discrete_columns=discrete_columns)

        synthetic_data = ctgan.sample(len(real_data))

        metrics = evaluate_data_quality(real_data, synthetic_data, discrete_columns)
        trial.set_user_attr("metrics", metrics)
        
        composite_score = (
            0.4 * (1 - metrics["wd"]) + 
            0.3 * (1 - metrics["jsd"]) +
            0.3 * (1 - abs(metrics["auc"] - 0.5))
        )
        return composite_score

    except Exception as e:
        print(f"Trial failed with parameters: {params}. Error: {str(e)}")
        return 0 

def evaluate_data_quality(real, synthetic, discrete_columns):
    metrics = {}

    cont_cols = [c for c in real.columns if c not in discrete_columns]
    if cont_cols:
        wd = np.mean([wasserstein_distance(real[c], synthetic[c]) for c in cont_cols])
    else:
        wd = 0.0
    metrics["wd"] = wd

    jsd_values = []
    for col in discrete_columns:
        real_p = real[col].value_counts(normalize=True)
        syn_p = synthetic[col].value_counts(normalize=True)
        all_cats = real_p.index.union(syn_p.index)
        
        real_p = real_p.reindex(all_cats, fill_value=0).values + 1e-10
        syn_p = syn_p.reindex(all_cats, fill_value=0).values + 1e-10
        
        m = 0.5 * (real_p + syn_p)
        jsd = 0.5 * (entropy(real_p, m) + entropy(syn_p, m))
        jsd_values.append(jsd)
    metrics["jsd"] = np.mean(jsd_values)
    
    combined = pd.concat([real, synthetic])
    labels = np.concatenate([np.ones(len(real)), np.zeros(len(synthetic))])

    auc_scores = []
    for _ in range(3):
        X_train, X_test, y_train, y_test = train_test_split(
            combined, labels, test_size=0.3, stratify=labels
        )
        clf = RandomForestClassifier(n_estimators=100, random_state=0)
        clf.fit(X_train, y_train)
        preds = clf.predict_proba(X_test)[:, 1]
        auc_scores.append(roc_auc_score(y_test, preds))
    metrics["auc"] = np.mean(auc_scores)
    
    return metrics

def generate_publication_figures(study):

    df = study.trials_dataframe()
    params = ["dropout", "num_first", "batch_first", "num_layers", "num_heads"]

    fig, axs = plt.subplots(2, 3, figsize=(18, 12))

    importance = optuna.importance.get_param_importances(study)
    sns.heatmap(
        pd.DataFrame([importance]), 
        annot=True, 
        cmap="YlGnBu",
        ax=axs[0,0],
        cbar_kws={'label': 'Importance Score'}
    )
    axs[0,0].set_title("(a) Parameter Importance Matrix")
    axs[0,0].set_xticklabels(importance.keys(), rotation=45)

    sns.countplot(
        x=df["params_num_first"].astype(bool),
        ax=axs[0,1],
        palette="Set2"
    )
    axs[0,1].set_title("(b) Norm First Distribution")
    axs[0,1].set_xlabel("Norm First")
    
    sns.countplot(
        x=df["params_batch_first"].astype(bool),
        ax=axs[0,2],
        palette="Set3"
    )
    axs[0,2].set_title("(c) Batch First Distribution")
    axs[0,2].set_xlabel("Batch First")

    sns.scatterplot(
        data=df,
        x="params_dropout",
        y="value",
        hue="params_num_layers",
        size="params_num_heads",
        sizes=(20, 200),
        palette="viridis",
        ax=axs[1,0]
    )
    axs[1,0].set_title("(d) Dropout vs Score")

    sns.boxplot(
        x="params_num_heads",
        y="value",
        hue="params_num_first",
        data=df,
        ax=axs[1,1],
        palette="Set2"
    )
    axs[1,1].set_title("(e) Heads & Norm Interaction")

    sns.lineplot(
        x=df["number"],
        y=df["value"].cummax(),
        ax=axs[1,2],
        color="darkblue"
    )
    axs[1,2].set_title("(f) Optimization Progress")
    
    plt.tight_layout()
    plt.savefig("full_analysis.pdf", bbox_inches='tight')
    plt.close()

def run_optimization(real_data, discrete_columns, n_trials=20):
    study = optuna.create_study(
        direction="maximize",
        sampler=optuna.samplers.TPESampler(seed=42),
        pruner=optuna.pruners.MedianPruner()
    )
    
    study.optimize(
        lambda trial: objective(trial, real_data, discrete_columns),
        n_trials=n_trials,
        show_progress_bar=True
    )
    
    generate_publication_figures(study)
    return study

In [None]:
import pandas as pd
import numpy as np

def auto_detect_discrete_columns(data, unique_ratio_threshold=0.05, unique_count_threshold=20):

    discrete_cols = []
    
    for col in data.columns:
        col_data = data[col].dropna()
        

        if len(col_data) == 0:
            continue
        

        dtype = col_data.dtype
        

        if dtype in [object, 'category', bool]:
            discrete_cols.append(col)
            continue
            
        if np.issubdtype(dtype, np.number):
            n_unique = col_data.nunique()
            total = len(col_data)
            

            if n_unique <= unique_count_threshold:
                discrete_cols.append(col)
                continue
                

            unique_ratio = n_unique / total
            if unique_ratio < unique_ratio_threshold:

                if np.issubdtype(dtype, np.integer):
                    discrete_cols.append(col)
                elif (col_data == col_data.astype(int)).all():
                    discrete_cols.append(col)
    
    return discrete_cols


if __name__ == "__main__":

    #real_path = "../synthcity-main/tutorials/covertype_preprocessed.csv"
    #real_path = "../CTAB-GAN-main/Real_Datasets/Credit.csv"
    #real_path = "../CTAB-GAN-main/Real_Datasets/Adult3.csv"
    real_path = '../CTGAN-main/CTGAN-main/examples/csv/train_clean.csv'
    real_data = pd.read_csv(real_path)
    
    discrete_cols = auto_detect_discrete_columns(real_data)
    print(discrete_cols)

In [None]:
def smart_normalization(data, discrete_cols, threshold=0.1):

    continuous_cols = [col for col in data.columns if col not in discrete_cols]
    if not continuous_cols:
        return data.copy(), False
    

    needs_scaling = False
    for col in continuous_cols:
        col_mean = data[col].mean()
        col_std = data[col].std()

        if abs(col_mean) > threshold or not (1-threshold < col_std < 1+threshold):
            needs_scaling = True
            break
    
    if needs_scaling:
        from sklearn.preprocessing import StandardScaler
        scaler = StandardScaler()
        data_scaled = data.copy()
        data_scaled[continuous_cols] = scaler.fit_transform(data[continuous_cols])
        return data_scaled, True
    else:
        return data.copy(), False

In [None]:
#discrete_cols = []
study = run_optimization(real_data_normalized, discrete_cols, n_trials=20)
best_params = study.best_params
print("\nOptimal Parameters:", best_params)
print("Optimal Parameters:")
for k, v in study.best_params.items():
    print(f"- {k}: {v}")

[I 2025-02-22 18:42:02,682] A new study created in memory with name: no-name-6c0af77c-217a-4880-b9ef-7dd041a64c75


  0%|          | 0/300 [00:00<?, ?it/s]
Hype-parameter:[dropout:0.249816047538945,num_first:True,batch_first:True,num_heads:8,num_layers:3]
epoch!0
epoch!1
epoch!2
epoch!3
epoch!4
epoch!5
epoch!6
epoch!7
epoch!8
epoch!9
epoch!10
epoch!11
epoch!12
epoch!13
epoch!14
epoch!15
epoch!16
epoch!17
epoch!18
epoch!19
epoch!20
epoch!21
epoch!22
epoch!23
epoch!24
epoch!25
epoch!26
epoch!27
epoch!28
epoch!29
epoch!30
epoch!31
epoch!32
epoch!33
epoch!34
epoch!35
epoch!36
epoch!37
epoch!38
epoch!39
epoch!40
epoch!41
epoch!42
epoch!43
epoch!44
epoch!45
epoch!46
epoch!47
epoch!48
epoch!49
epoch!50
epoch!51
epoch!52
epoch!53
epoch!54
epoch!55
epoch!56
epoch!57
epoch!58
epoch!59
epoch!60
epoch!61
epoch!62
epoch!63
epoch!64
epoch!65
epoch!66
epoch!67
epoch!68
epoch!69
epoch!70
epoch!71
epoch!72
epoch!73
epoch!74
epoch!75
epoch!76
epoch!77
epoch!78
epoch!79
epoch!80
epoch!81
epoch!82
epoch!83
epoch!84
epoch!85
epoch!86
epoch!87
epoch!88
epoch!89
epoch!90
epoch!91
epoch!92
epoch!93
epoch!94
epoch!95
epoch!

Best trial: 0. Best value: 0.303153:   5%|▌         | 1/20 [02:30<47:37, 150.39s/it]

[I 2025-02-22 18:44:33,071] Trial 0 finished with value: 0.30315300557667 and parameters: {'dropout': 0.249816047538945, 'num_first': True, 'batch_first': True, 'num_layers': 3, 'num_heads': 8}. Best is trial 0 with value: 0.30315300557667.




  0%|          | 0/300 [00:00<?, ?it/s]
Hype-parameter:[dropout:0.34044600469728353,num_first:True,batch_first:True,num_heads:8,num_layers:3]
epoch!0
epoch!1
epoch!2
epoch!3
epoch!4
epoch!5
epoch!6
epoch!7
epoch!8
epoch!9
epoch!10
epoch!11
epoch!12
epoch!13
epoch!14
epoch!15
epoch!16
epoch!17
epoch!18
epoch!19
epoch!20
epoch!21
epoch!22
epoch!23
epoch!24
epoch!25
epoch!26
epoch!27
epoch!28
epoch!29
epoch!30
epoch!31
epoch!32
epoch!33
epoch!34
epoch!35
epoch!36
epoch!37
epoch!38
epoch!39
epoch!40
epoch!41
epoch!42
epoch!43
epoch!44
epoch!45
epoch!46
epoch!47
epoch!48
epoch!49
epoch!50
epoch!51
epoch!52
epoch!53
epoch!54
epoch!55
epoch!56
epoch!57
epoch!58
epoch!59
epoch!60
epoch!61
epoch!62
epoch!63
epoch!64
epoch!65
epoch!66
epoch!67
epoch!68
epoch!69
epoch!70
epoch!71
epoch!72
epoch!73
epoch!74
epoch!75
epoch!76
epoch!77
epoch!78
epoch!79
epoch!80
epoch!81
epoch!82
epoch!83
epoch!84
epoch!85
epoch!86
epoch!87
epoch!88
epoch!89
epoch!90
epoch!91
epoch!92
epoch!93
epoch!94
epoch!95
epoc

Best trial: 1. Best value: 0.401519:  10%|█         | 2/20 [05:09<46:41, 155.62s/it]

[I 2025-02-22 18:47:12,347] Trial 1 finished with value: 0.40151888030178284 and parameters: {'dropout': 0.34044600469728353, 'num_first': True, 'batch_first': True, 'num_layers': 3, 'num_heads': 8}. Best is trial 1 with value: 0.40151888030178284.




  0%|          | 0/300 [00:00<?, ?it/s]
Hype-parameter:[dropout:0.2216968971838151,num_first:True,batch_first:False,num_heads:8,num_layers:2]
epoch!0
epoch!1
epoch!2
epoch!3
epoch!4
epoch!5
epoch!6
epoch!7
epoch!8
epoch!9
epoch!10
epoch!11
epoch!12
epoch!13
epoch!14
epoch!15
epoch!16
epoch!17
epoch!18
epoch!19
epoch!20
epoch!21
epoch!22
epoch!23
epoch!24
epoch!25
epoch!26
epoch!27
epoch!28
epoch!29
epoch!30
epoch!31
epoch!32
epoch!33
epoch!34
epoch!35
epoch!36
epoch!37
epoch!38
epoch!39
epoch!40
epoch!41
epoch!42
epoch!43
epoch!44
epoch!45
epoch!46
epoch!47
epoch!48
epoch!49
epoch!50
epoch!51
epoch!52
epoch!53
epoch!54
epoch!55
epoch!56
epoch!57
epoch!58
epoch!59
epoch!60
epoch!61
epoch!62
epoch!63
epoch!64
epoch!65
epoch!66
epoch!67
epoch!68
epoch!69
epoch!70
epoch!71
epoch!72
epoch!73
epoch!74
epoch!75
epoch!76
epoch!77
epoch!78
epoch!79
epoch!80
epoch!81
epoch!82
epoch!83
epoch!84
epoch!85
epoch!86
epoch!87
epoch!88
epoch!89
epoch!90
epoch!91
epoch!92
epoch!93
epoch!94
epoch!95
epoc

Best trial: 2. Best value: 0.492544:  15%|█▌        | 3/20 [06:55<37:40, 132.95s/it]

[I 2025-02-22 18:48:58,323] Trial 2 finished with value: 0.49254355111757586 and parameters: {'dropout': 0.2216968971838151, 'num_first': True, 'batch_first': False, 'num_layers': 2, 'num_heads': 8}. Best is trial 2 with value: 0.49254355111757586.




  0%|          | 0/300 [00:00<?, ?it/s]
Hype-parameter:[dropout:0.28242799368681437,num_first:True,batch_first:False,num_heads:4,num_layers:2]
epoch!0
epoch!1
epoch!2
epoch!3
epoch!4
epoch!5
epoch!6
epoch!7
epoch!8
epoch!9
epoch!10
epoch!11
epoch!12
epoch!13
epoch!14
epoch!15
epoch!16
epoch!17
epoch!18
epoch!19
epoch!20
epoch!21
epoch!22
epoch!23
epoch!24
epoch!25
epoch!26
epoch!27
epoch!28
epoch!29
epoch!30
epoch!31
epoch!32
epoch!33
epoch!34
epoch!35
epoch!36
epoch!37
epoch!38
epoch!39
epoch!40
epoch!41
epoch!42
epoch!43
epoch!44
epoch!45
epoch!46
epoch!47
epoch!48
epoch!49
epoch!50
epoch!51
epoch!52
epoch!53
epoch!54
epoch!55
epoch!56
epoch!57
epoch!58
epoch!59
epoch!60
epoch!61
epoch!62
epoch!63
epoch!64
epoch!65
epoch!66
epoch!67
epoch!68
epoch!69
epoch!70
epoch!71
epoch!72
epoch!73
epoch!74
epoch!75
epoch!76
epoch!77
epoch!78
epoch!79
epoch!80
epoch!81
epoch!82
epoch!83
epoch!84
epoch!85
epoch!86
epoch!87
epoch!88
epoch!89
epoch!90
epoch!91
epoch!92
epoch!93
epoch!94
epoch!95
epo

Best trial: 2. Best value: 0.492544:  20%|██        | 4/20 [08:29<31:20, 117.51s/it]

[I 2025-02-22 18:50:32,156] Trial 3 finished with value: 0.4518108353267789 and parameters: {'dropout': 0.28242799368681437, 'num_first': True, 'batch_first': False, 'num_layers': 2, 'num_heads': 4}. Best is trial 2 with value: 0.49254355111757586.
  0%|          | 0/300 [00:00<?, ?it/s]
Hype-parameter:[dropout:0.12602063719411183,num_first:False,batch_first:True,num_heads:4,num_layers:2]
epoch!0
epoch!1
epoch!2
epoch!3
epoch!4
epoch!5
epoch!6
epoch!7
epoch!8
epoch!9
epoch!10
epoch!11
epoch!12
epoch!13
epoch!14
epoch!15
epoch!16
epoch!17
epoch!18
epoch!19
epoch!20
epoch!21
epoch!22
epoch!23
epoch!24
epoch!25
epoch!26
epoch!27
epoch!28
epoch!29
epoch!30
epoch!31
epoch!32
epoch!33
epoch!34
epoch!35
epoch!36
epoch!37
epoch!38
epoch!39
epoch!40
epoch!41
epoch!42
epoch!43
epoch!44
epoch!45
epoch!46
epoch!47
epoch!48
epoch!49
epoch!50
epoch!51
epoch!52
epoch!53
epoch!54
epoch!55
epoch!56
epoch!57
epoch!58
epoch!59
epoch!60
epoch!61
epoch!62
epoch!63
epoch!64
epoch!65
epoch!66
epoch!67
epoch!

Best trial: 4. Best value: 0.676671:  25%|██▌       | 5/20 [10:12<28:04, 112.29s/it]

[I 2025-02-22 18:52:15,207] Trial 4 finished with value: 0.6766713372469819 and parameters: {'dropout': 0.12602063719411183, 'num_first': False, 'batch_first': True, 'num_layers': 2, 'num_heads': 4}. Best is trial 4 with value: 0.6766713372469819.




  0%|          | 0/300 [00:00<?, ?it/s]
Hype-parameter:[dropout:0.14881529393791154,num_first:True,batch_first:True,num_heads:8,num_layers:6]
epoch!0
epoch!1
epoch!2
epoch!3
epoch!4
epoch!5
epoch!6
epoch!7
epoch!8
epoch!9
epoch!10
epoch!11
epoch!12
epoch!13
epoch!14
epoch!15
epoch!16
epoch!17
epoch!18
epoch!19
epoch!20
epoch!21
epoch!22
epoch!23
epoch!24
epoch!25
epoch!26
epoch!27
epoch!28
epoch!29
epoch!30
epoch!31
epoch!32
epoch!33
epoch!34
epoch!35
epoch!36
epoch!37
epoch!38
epoch!39
epoch!40
epoch!41
epoch!42
epoch!43
epoch!44
epoch!45
epoch!46
epoch!47
epoch!48
epoch!49
epoch!50
epoch!51
epoch!52
epoch!53
epoch!54
epoch!55
epoch!56
epoch!57
epoch!58
epoch!59
epoch!60
epoch!61
epoch!62
epoch!63
epoch!64
epoch!65
epoch!66
epoch!67
epoch!68
epoch!69
epoch!70
epoch!71
epoch!72
epoch!73
epoch!74
epoch!75
epoch!76
epoch!77
epoch!78
epoch!79
epoch!80
epoch!81
epoch!82
epoch!83
epoch!84
epoch!85
epoch!86
epoch!87
epoch!88
epoch!89
epoch!90
epoch!91
epoch!92
epoch!93
epoch!94
epoch!95
epoc

Best trial: 4. Best value: 0.676671:  30%|███       | 6/20 [14:31<37:51, 162.24s/it]

[I 2025-02-22 18:56:34,414] Trial 5 finished with value: -0.2224975941788036 and parameters: {'dropout': 0.14881529393791154, 'num_first': True, 'batch_first': True, 'num_layers': 6, 'num_heads': 8}. Best is trial 4 with value: 0.6766713372469819.




  0%|          | 0/300 [00:00<?, ?it/s]
Hype-parameter:[dropout:0.31868411173731187,num_first:False,batch_first:False,num_heads:8,num_layers:8]
epoch!0
epoch!1
epoch!2
epoch!3
epoch!4
epoch!5
epoch!6
epoch!7
epoch!8
epoch!9
epoch!10
epoch!11
epoch!12
epoch!13
epoch!14
epoch!15
epoch!16
epoch!17
epoch!18
epoch!19
epoch!20
epoch!21
epoch!22
epoch!23
epoch!24
epoch!25
epoch!26
epoch!27
epoch!28
epoch!29
epoch!30
epoch!31
epoch!32
epoch!33
epoch!34
epoch!35
epoch!36
epoch!37
epoch!38
epoch!39
epoch!40
epoch!41
epoch!42
epoch!43
epoch!44
epoch!45
epoch!46
epoch!47
epoch!48
epoch!49
epoch!50
epoch!51
epoch!52
epoch!53
epoch!54
epoch!55
epoch!56
epoch!57
epoch!58
epoch!59
epoch!60
epoch!61
epoch!62
epoch!63
epoch!64
epoch!65
epoch!66
epoch!67
epoch!68
epoch!69
epoch!70
epoch!71
epoch!72
epoch!73
epoch!74
epoch!75
epoch!76
epoch!77
epoch!78
epoch!79
epoch!80
epoch!81
epoch!82
epoch!83
epoch!84
epoch!85
epoch!86
epoch!87
epoch!88
epoch!89
epoch!90
epoch!91
epoch!92
epoch!93
epoch!94
epoch!95
ep

Best trial: 4. Best value: 0.676671:  35%|███▌      | 7/20 [19:55<46:37, 215.20s/it]

[I 2025-02-22 19:01:58,635] Trial 6 finished with value: 0.39741748593188386 and parameters: {'dropout': 0.31868411173731187, 'num_first': False, 'batch_first': False, 'num_layers': 8, 'num_heads': 8}. Best is trial 4 with value: 0.6766713372469819.




  0%|          | 0/300 [00:00<?, ?it/s]
Hype-parameter:[dropout:0.1353970008207678,num_first:True,batch_first:False,num_heads:4,num_layers:3]
epoch!0
epoch!1
epoch!2
epoch!3
epoch!4
epoch!5
epoch!6
epoch!7
epoch!8
epoch!9
epoch!10
epoch!11
epoch!12
epoch!13
epoch!14
epoch!15
epoch!16
epoch!17
epoch!18
epoch!19
epoch!20
epoch!21
epoch!22
epoch!23
epoch!24
epoch!25
epoch!26
epoch!27
epoch!28
epoch!29
epoch!30
epoch!31
epoch!32
epoch!33
epoch!34
epoch!35
epoch!36
epoch!37
epoch!38
epoch!39
epoch!40
epoch!41
epoch!42
epoch!43
epoch!44
epoch!45
epoch!46
epoch!47
epoch!48
epoch!49
epoch!50
epoch!51
epoch!52
epoch!53
epoch!54
epoch!55
epoch!56
epoch!57
epoch!58
epoch!59
epoch!60
epoch!61
epoch!62
epoch!63
epoch!64
epoch!65
epoch!66
epoch!67
epoch!68
epoch!69
epoch!70
epoch!71
epoch!72
epoch!73
epoch!74
epoch!75
epoch!76
epoch!77
epoch!78
epoch!79
epoch!80
epoch!81
epoch!82
epoch!83
epoch!84
epoch!85
epoch!86
epoch!87
epoch!88
epoch!89
epoch!90
epoch!91
epoch!92
epoch!93
epoch!94
epoch!95
epoc

Best trial: 4. Best value: 0.676671:  40%|████      | 8/20 [21:57<37:03, 185.29s/it]

[I 2025-02-22 19:03:59,875] Trial 7 finished with value: 0.5046595511834501 and parameters: {'dropout': 0.1353970008207678, 'num_first': True, 'batch_first': False, 'num_layers': 3, 'num_heads': 4}. Best is trial 4 with value: 0.6766713372469819.




  0%|          | 0/300 [00:00<?, ?it/s]
Hype-parameter:[dropout:0.2123738038749523,num_first:True,batch_first:True,num_heads:4,num_layers:8]
epoch!0
epoch!1
epoch!2
epoch!3
epoch!4
epoch!5
epoch!6
epoch!7
epoch!8
epoch!9
epoch!10
epoch!11
epoch!12
epoch!13
epoch!14
epoch!15
epoch!16
epoch!17
epoch!18
epoch!19
epoch!20
epoch!21
epoch!22
epoch!23
epoch!24
epoch!25
epoch!26
epoch!27
epoch!28
epoch!29
epoch!30
epoch!31
epoch!32
epoch!33
epoch!34
epoch!35
epoch!36
epoch!37
epoch!38
epoch!39
epoch!40
epoch!41
epoch!42
epoch!43
epoch!44
epoch!45
epoch!46
epoch!47
epoch!48
epoch!49
epoch!50
epoch!51
epoch!52
epoch!53
epoch!54
epoch!55
epoch!56
epoch!57
epoch!58
epoch!59
epoch!60
epoch!61
epoch!62
epoch!63
epoch!64
epoch!65
epoch!66
epoch!67
epoch!68
epoch!69
epoch!70
epoch!71
epoch!72
epoch!73
epoch!74
epoch!75
epoch!76
epoch!77
epoch!78
epoch!79
epoch!80
epoch!81
epoch!82
epoch!83
epoch!84
epoch!85
epoch!86
epoch!87
epoch!88
epoch!89
epoch!90
epoch!91
epoch!92
epoch!93
epoch!94
epoch!95
epoch

Best trial: 4. Best value: 0.676671:  45%|████▌     | 9/20 [26:49<40:04, 218.63s/it]

[I 2025-02-22 19:08:51,807] Trial 8 finished with value: 0.09316597233902979 and parameters: {'dropout': 0.2123738038749523, 'num_first': True, 'batch_first': True, 'num_layers': 8, 'num_heads': 4}. Best is trial 4 with value: 0.6766713372469819.




  0%|          | 0/300 [00:00<?, ?it/s]
Hype-parameter:[dropout:0.10220884684944097,num_first:True,batch_first:False,num_heads:4,num_layers:2]
epoch!0
epoch!1
epoch!2
epoch!3
epoch!4
epoch!5
epoch!6
epoch!7
epoch!8
epoch!9
epoch!10
epoch!11
epoch!12
epoch!13
epoch!14
epoch!15
epoch!16
epoch!17
epoch!18
epoch!19
epoch!20
epoch!21
epoch!22
epoch!23
epoch!24
epoch!25
epoch!26
epoch!27
epoch!28
epoch!29
epoch!30
epoch!31
epoch!32
epoch!33
epoch!34
epoch!35
epoch!36
epoch!37
epoch!38
epoch!39
epoch!40
epoch!41
epoch!42
epoch!43
epoch!44
epoch!45
epoch!46
epoch!47
epoch!48
epoch!49
epoch!50
epoch!51
epoch!52
epoch!53
epoch!54
epoch!55
epoch!56
epoch!57
epoch!58
epoch!59
epoch!60
epoch!61
epoch!62
epoch!63
epoch!64
epoch!65
epoch!66
epoch!67
epoch!68
epoch!69
epoch!70
epoch!71
epoch!72
epoch!73
epoch!74
epoch!75
epoch!76
epoch!77
epoch!78
epoch!79
epoch!80
epoch!81
epoch!82
epoch!83
epoch!84
epoch!85
epoch!86
epoch!87
epoch!88
epoch!89
epoch!90
epoch!91
epoch!92
epoch!93
epoch!94
epoch!95
epo

Best trial: 4. Best value: 0.676671:  50%|█████     | 10/20 [28:20<29:52, 179.28s/it]

[I 2025-02-22 19:10:22,989] Trial 9 finished with value: 0.3460211660793039 and parameters: {'dropout': 0.10220884684944097, 'num_first': True, 'batch_first': False, 'num_layers': 2, 'num_heads': 4}. Best is trial 4 with value: 0.6766713372469819.
  0%|          | 0/300 [00:00<?, ?it/s]
Hype-parameter:[dropout:0.47805461769215485,num_first:False,batch_first:True,num_heads:4,num_layers:5]
epoch!0
epoch!1
epoch!2
epoch!3
epoch!4
epoch!5
epoch!6
epoch!7
epoch!8
epoch!9
epoch!10
epoch!11
epoch!12
epoch!13
epoch!14
epoch!15
epoch!16
epoch!17
epoch!18
epoch!19
epoch!20
epoch!21
epoch!22
epoch!23
epoch!24
epoch!25
epoch!26
epoch!27
epoch!28
epoch!29
epoch!30
epoch!31
epoch!32
epoch!33
epoch!34
epoch!35
epoch!36
epoch!37
epoch!38
epoch!39
epoch!40
epoch!41
epoch!42
epoch!43
epoch!44
epoch!45
epoch!46
epoch!47
epoch!48
epoch!49
epoch!50
epoch!51
epoch!52
epoch!53
epoch!54
epoch!55
epoch!56
epoch!57
epoch!58
epoch!59
epoch!60
epoch!61
epoch!62
epoch!63
epoch!64
epoch!65
epoch!66
epoch!67
epoch!6

Best trial: 10. Best value: 0.731606:  55%|█████▌    | 11/20 [31:27<27:14, 181.60s/it]

[I 2025-02-22 19:13:29,861] Trial 10 finished with value: 0.7316059474439767 and parameters: {'dropout': 0.47805461769215485, 'num_first': False, 'batch_first': True, 'num_layers': 5, 'num_heads': 4}. Best is trial 10 with value: 0.7316059474439767.
  0%|          | 0/300 [00:00<?, ?it/s]
Hype-parameter:[dropout:0.4630659181130072,num_first:False,batch_first:True,num_heads:4,num_layers:5]
epoch!0
epoch!1
epoch!2
epoch!3
epoch!4
epoch!5
epoch!6
epoch!7
epoch!8
epoch!9
epoch!10
epoch!11
epoch!12
epoch!13
epoch!14
epoch!15
epoch!16
epoch!17
epoch!18
epoch!19
epoch!20
epoch!21
epoch!22
epoch!23
epoch!24
epoch!25
epoch!26
epoch!27
epoch!28
epoch!29
epoch!30
epoch!31
epoch!32
epoch!33
epoch!34
epoch!35
epoch!36
epoch!37
epoch!38
epoch!39
epoch!40
epoch!41
epoch!42
epoch!43
epoch!44
epoch!45
epoch!46
epoch!47
epoch!48
epoch!49
epoch!50
epoch!51
epoch!52
epoch!53
epoch!54
epoch!55
epoch!56
epoch!57
epoch!58
epoch!59
epoch!60
epoch!61
epoch!62
epoch!63
epoch!64
epoch!65
epoch!66
epoch!67
epoch!

Best trial: 10. Best value: 0.731606:  60%|██████    | 12/20 [34:35<24:30, 183.75s/it]

[I 2025-02-22 19:16:38,519] Trial 11 finished with value: 0.7235858521191011 and parameters: {'dropout': 0.4630659181130072, 'num_first': False, 'batch_first': True, 'num_layers': 5, 'num_heads': 4}. Best is trial 10 with value: 0.7316059474439767.
  0%|          | 0/300 [00:00<?, ?it/s]
Hype-parameter:[dropout:0.48795153806627994,num_first:False,batch_first:True,num_heads:4,num_layers:5]
epoch!0
epoch!1
epoch!2
epoch!3
epoch!4
epoch!5
epoch!6
epoch!7
epoch!8
epoch!9
epoch!10
epoch!11
epoch!12
epoch!13
epoch!14
epoch!15
epoch!16
epoch!17
epoch!18
epoch!19
epoch!20
epoch!21
epoch!22
epoch!23
epoch!24
epoch!25
epoch!26
epoch!27
epoch!28
epoch!29
epoch!30
epoch!31
epoch!32
epoch!33
epoch!34
epoch!35
epoch!36
epoch!37
epoch!38
epoch!39
epoch!40
epoch!41
epoch!42
epoch!43
epoch!44
epoch!45
epoch!46
epoch!47
epoch!48
epoch!49
epoch!50
epoch!51
epoch!52
epoch!53
epoch!54
epoch!55
epoch!56
epoch!57
epoch!58
epoch!59
epoch!60
epoch!61
epoch!62
epoch!63
epoch!64
epoch!65
epoch!66
epoch!67
epoch!

Best trial: 10. Best value: 0.731606:  65%|██████▌   | 13/20 [37:38<21:23, 183.32s/it]

[I 2025-02-22 19:19:40,838] Trial 12 finished with value: 0.7250771825341151 and parameters: {'dropout': 0.48795153806627994, 'num_first': False, 'batch_first': True, 'num_layers': 5, 'num_heads': 4}. Best is trial 10 with value: 0.7316059474439767.
  0%|          | 0/300 [00:00<?, ?it/s]
Hype-parameter:[dropout:0.4871260012059897,num_first:False,batch_first:True,num_heads:4,num_layers:5]
epoch!0
epoch!1
epoch!2
epoch!3
epoch!4
epoch!5
epoch!6
epoch!7
epoch!8
epoch!9
epoch!10
epoch!11
epoch!12
epoch!13
epoch!14
epoch!15
epoch!16
epoch!17
epoch!18
epoch!19
epoch!20
epoch!21
epoch!22
epoch!23
epoch!24
epoch!25
epoch!26
epoch!27
epoch!28
epoch!29
epoch!30
epoch!31
epoch!32
epoch!33
epoch!34
epoch!35
epoch!36
epoch!37
epoch!38
epoch!39
epoch!40
epoch!41
epoch!42
epoch!43
epoch!44
epoch!45
epoch!46
epoch!47
epoch!48
epoch!49
epoch!50
epoch!51
epoch!52
epoch!53
epoch!54
epoch!55
epoch!56
epoch!57
epoch!58
epoch!59
epoch!60
epoch!61
epoch!62
epoch!63
epoch!64
epoch!65
epoch!66
epoch!67
epoch!

Best trial: 10. Best value: 0.731606:  70%|███████   | 14/20 [41:09<19:10, 191.72s/it]

[I 2025-02-22 19:23:11,985] Trial 13 finished with value: 0.6822523472123603 and parameters: {'dropout': 0.4871260012059897, 'num_first': False, 'batch_first': True, 'num_layers': 5, 'num_heads': 4}. Best is trial 10 with value: 0.7316059474439767.
  0%|          | 0/300 [00:00<?, ?it/s]
Hype-parameter:[dropout:0.416587732899981,num_first:False,batch_first:True,num_heads:4,num_layers:6]
epoch!0
epoch!1
epoch!2
epoch!3
epoch!4
epoch!5
epoch!6
epoch!7
epoch!8
epoch!9
epoch!10
epoch!11
epoch!12
epoch!13
epoch!14
epoch!15
epoch!16
epoch!17
epoch!18
epoch!19
epoch!20
epoch!21
epoch!22
epoch!23
epoch!24
epoch!25
epoch!26
epoch!27
epoch!28
epoch!29
epoch!30
epoch!31
epoch!32
epoch!33
epoch!34
epoch!35
epoch!36
epoch!37
epoch!38
epoch!39
epoch!40
epoch!41
epoch!42
epoch!43
epoch!44
epoch!45
epoch!46
epoch!47
epoch!48
epoch!49
epoch!50
epoch!51
epoch!52
epoch!53
epoch!54
epoch!55
epoch!56
epoch!57
epoch!58
epoch!59
epoch!60
epoch!61
epoch!62
epoch!63
epoch!64
epoch!65
epoch!66
epoch!67
epoch!68

Best trial: 10. Best value: 0.731606:  75%|███████▌  | 15/20 [44:53<16:47, 201.53s/it]

[I 2025-02-22 19:26:56,238] Trial 14 finished with value: 0.6696073595501961 and parameters: {'dropout': 0.416587732899981, 'num_first': False, 'batch_first': True, 'num_layers': 6, 'num_heads': 4}. Best is trial 10 with value: 0.7316059474439767.
  0%|          | 0/300 [00:00<?, ?it/s]
Hype-parameter:[dropout:0.4012897576105889,num_first:False,batch_first:True,num_heads:4,num_layers:6]
epoch!0
epoch!1
epoch!2
epoch!3
epoch!4
epoch!5
epoch!6
epoch!7
epoch!8
epoch!9
epoch!10
epoch!11
epoch!12
epoch!13
epoch!14
epoch!15
epoch!16
epoch!17
epoch!18
epoch!19
epoch!20
epoch!21
epoch!22
epoch!23
epoch!24
epoch!25
epoch!26
epoch!27
epoch!28
epoch!29
epoch!30
epoch!31
epoch!32
epoch!33
epoch!34
epoch!35
epoch!36
epoch!37
epoch!38
epoch!39
epoch!40
epoch!41
epoch!42
epoch!43
epoch!44
epoch!45
epoch!46
epoch!47
epoch!48
epoch!49
epoch!50
epoch!51
epoch!52
epoch!53
epoch!54
epoch!55
epoch!56
epoch!57
epoch!58
epoch!59
epoch!60
epoch!61
epoch!62
epoch!63
epoch!64
epoch!65
epoch!66
epoch!67
epoch!68

Best trial: 15. Best value: 0.76467:  80%|████████  | 16/20 [48:26<13:39, 204.96s/it] 

[I 2025-02-22 19:30:29,181] Trial 15 finished with value: 0.7646698858385752 and parameters: {'dropout': 0.4012897576105889, 'num_first': False, 'batch_first': True, 'num_layers': 6, 'num_heads': 4}. Best is trial 15 with value: 0.7646698858385752.
  0%|          | 0/300 [00:00<?, ?it/s]
Hype-parameter:[dropout:0.39236452092136653,num_first:False,batch_first:True,num_heads:4,num_layers:7]
epoch!0
epoch!1
epoch!2
epoch!3
epoch!4
epoch!5
epoch!6
epoch!7
epoch!8
epoch!9
epoch!10
epoch!11
epoch!12
epoch!13
epoch!14
epoch!15
epoch!16
epoch!17
epoch!18
epoch!19
epoch!20
epoch!21
epoch!22
epoch!23
epoch!24
epoch!25
epoch!26
epoch!27
epoch!28
epoch!29
epoch!30
epoch!31
epoch!32
epoch!33
epoch!34
epoch!35
epoch!36
epoch!37
epoch!38
epoch!39
epoch!40
epoch!41
epoch!42
epoch!43
epoch!44
epoch!45
epoch!46
epoch!47
epoch!48
epoch!49
epoch!50
epoch!51
epoch!52
epoch!53
epoch!54
epoch!55
epoch!56
epoch!57
epoch!58
epoch!59
epoch!60
epoch!61
epoch!62
epoch!63
epoch!64
epoch!65
epoch!66
epoch!67
epoch!

Best trial: 15. Best value: 0.76467:  85%|████████▌ | 17/20 [52:40<10:59, 219.79s/it]

[I 2025-02-22 19:34:43,460] Trial 16 finished with value: 0.7048519667569416 and parameters: {'dropout': 0.39236452092136653, 'num_first': False, 'batch_first': True, 'num_layers': 7, 'num_heads': 4}. Best is trial 15 with value: 0.7646698858385752.
  0%|          | 0/300 [00:00<?, ?it/s]
Hype-parameter:[dropout:0.41036179992175853,num_first:False,batch_first:True,num_heads:4,num_layers:4]
epoch!0
epoch!1
epoch!2
epoch!3
epoch!4
epoch!5
epoch!6
epoch!7
epoch!8
epoch!9
epoch!10
epoch!11
epoch!12
epoch!13
epoch!14
epoch!15
epoch!16
epoch!17
epoch!18
epoch!19
epoch!20
epoch!21
epoch!22
epoch!23
epoch!24
epoch!25
epoch!26
epoch!27
epoch!28
epoch!29
epoch!30
epoch!31
epoch!32
epoch!33
epoch!34
epoch!35
epoch!36
epoch!37
epoch!38
epoch!39
epoch!40
epoch!41
epoch!42
epoch!43
epoch!44
epoch!45
epoch!46
epoch!47
epoch!48
epoch!49
epoch!50
epoch!51
epoch!52
epoch!53
epoch!54
epoch!55
epoch!56
epoch!57
epoch!58
epoch!59
epoch!60
epoch!61
epoch!62
epoch!63
epoch!64
epoch!65
epoch!66
epoch!67
epoch

Best trial: 15. Best value: 0.76467:  90%|█████████ | 18/20 [55:06<06:35, 197.54s/it]

[I 2025-02-22 19:37:09,196] Trial 17 finished with value: 0.4865077489215951 and parameters: {'dropout': 0.41036179992175853, 'num_first': False, 'batch_first': True, 'num_layers': 4, 'num_heads': 4}. Best is trial 15 with value: 0.7646698858385752.
  0%|          | 0/300 [00:00<?, ?it/s]
Hype-parameter:[dropout:0.36725582885275265,num_first:False,batch_first:True,num_heads:4,num_layers:6]
epoch!0
epoch!1
epoch!2
epoch!3
epoch!4
epoch!5
epoch!6
epoch!7
epoch!8
epoch!9
epoch!10
epoch!11
epoch!12
epoch!13
epoch!14
epoch!15
epoch!16
epoch!17
epoch!18
epoch!19
epoch!20
epoch!21
epoch!22
epoch!23
epoch!24
epoch!25
epoch!26
epoch!27
epoch!28
epoch!29
epoch!30
epoch!31
epoch!32
epoch!33
epoch!34
epoch!35
epoch!36
epoch!37
epoch!38
epoch!39
epoch!40
epoch!41
epoch!42
epoch!43
epoch!44
epoch!45
epoch!46
epoch!47
epoch!48
epoch!49
epoch!50
epoch!51
epoch!52
epoch!53
epoch!54
epoch!55
epoch!56
epoch!57
epoch!58
epoch!59
epoch!60
epoch!61
epoch!62
epoch!63
epoch!64
epoch!65
epoch!66
epoch!67
epoch

Best trial: 15. Best value: 0.76467:  95%|█████████▌| 19/20 [58:59<03:28, 208.32s/it]

[I 2025-02-22 19:41:02,641] Trial 18 finished with value: 0.7120339538170241 and parameters: {'dropout': 0.36725582885275265, 'num_first': False, 'batch_first': True, 'num_layers': 6, 'num_heads': 4}. Best is trial 15 with value: 0.7646698858385752.
  0%|          | 0/300 [00:00<?, ?it/s]
Hype-parameter:[dropout:0.4429564520037397,num_first:False,batch_first:True,num_heads:4,num_layers:7]
epoch!0
epoch!1
epoch!2
epoch!3
epoch!4
epoch!5
epoch!6
epoch!7
epoch!8
epoch!9
epoch!10
epoch!11
epoch!12
epoch!13
epoch!14
epoch!15
epoch!16
epoch!17
epoch!18
epoch!19
epoch!20
epoch!21
epoch!22
epoch!23
epoch!24
epoch!25
epoch!26
epoch!27
epoch!28
epoch!29
epoch!30
epoch!31
epoch!32
epoch!33
epoch!34
epoch!35
epoch!36
epoch!37
epoch!38
epoch!39
epoch!40
epoch!41
epoch!42
epoch!43
epoch!44
epoch!45
epoch!46
epoch!47
epoch!48
epoch!49
epoch!50
epoch!51
epoch!52
epoch!53
epoch!54
epoch!55
epoch!56
epoch!57
epoch!58
epoch!59
epoch!60
epoch!61
epoch!62
epoch!63
epoch!64
epoch!65
epoch!66
epoch!67
epoch!

Best trial: 15. Best value: 0.76467: 100%|██████████| 20/20 [1:03:30<00:00, 190.54s/it]


[I 2025-02-22 19:45:33,454] Trial 19 finished with value: 0.37506619210469727 and parameters: {'dropout': 0.4429564520037397, 'num_first': False, 'batch_first': True, 'num_layers': 7, 'num_heads': 4}. Best is trial 15 with value: 0.7646698858385752.



Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `x` variable to `hue` and set `legend=False` for the same effect.

  sns.countplot(

Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `x` variable to `hue` and set `legend=False` for the same effect.

  sns.countplot(



Optimal Parameters: {'dropout': 0.4012897576105889, 'num_first': False, 'batch_first': True, 'num_layers': 6, 'num_heads': 4}
Optimal Parameters:
- dropout: 0.4012897576105889
- num_first: False
- batch_first: True
- num_layers: 6
- num_heads: 4
