In [None]:
!pip install git+https://github.com/Ashhad785/synthcity.git

In [None]:
!pip uninstall -y torchaudio torchdata
!pip install pycox
from pycox import datasets
from synthcity.metrics import Metrics
from synthcity.plugins.core.dataloader import SurvivalAnalysisDataLoader
import numpy as np
import pandas as pd
import os
import shutil
from timeit import default_timer as timer
from synthcity.plugins import Plugins

In [None]:
# Name of the synthcity plugin under evaluation: it is passed to Plugins().get()
# in every dataset cell and is embedded in all output file names.
plugin_name="ctgan"

# Functions

In [None]:
from scipy.stats import mannwhitneyu, chi2_contingency,wilcoxon
import matplotlib.pyplot as plt

def identify_variable_types(df, threshold=20):
    """Split the columns of ``df`` into continuous and discrete by cardinality.

    The decision is based purely on the number of distinct values in a column;
    the dtype is not inspected. NaN counts as one distinct value.

    Args:
        df: DataFrame whose columns are to be classified.
        threshold: A column with strictly more than this many distinct values
            is treated as continuous, every other column as discrete.
            Defaults to 20.

    Returns:
        Tuple ``(continuous_columns, discrete_columns)``: two lists of column
        labels, each in the order the columns appear in ``df``.
    """
    continuous_columns = []
    discrete_columns = []

    for col in df.columns:
        # dropna=False keeps NaN in the count, like len(Series.unique()).
        if df[col].nunique(dropna=False) > threshold:
            continuous_columns.append(col)
        else:
            discrete_columns.append(col)

    return continuous_columns, discrete_columns

def compare_distributions(real_df, synthetic_df, alpha=0.05):
    """Test every covariate for a distribution difference between two frames.

    The survival columns ``duration`` and ``event`` are dropped from both
    frames. The remaining columns are classified independently for each frame
    with ``identify_variable_types``; a column is tested only when both frames
    classify it the same way, otherwise it is left out of the result.

    * Continuous columns: Mann-Whitney U test (equivalent to the Wilcoxon
      rank-sum test) with scipy's default settings.
    * Discrete columns: chi-square test on a (category x source) table of
      counts, i.e. a test that both frames share the same category
      frequencies. Categories seen in only one frame get a count of 0 in the
      other. Missing values are not counted.

    Args:
        real_df: DataFrame of real records; must contain ``duration`` and
            ``event``.
        synthetic_df: DataFrame of generated records; must contain
            ``duration`` and ``event``.
        alpha: Significance level. Not used by the computation; kept so
            existing callers that pass it keep working.

    Returns:
        Tuple ``(p_values_continuous, p_values_discrete)``: two dicts mapping
        column name to p-value.

    Raises:
        KeyError: If ``duration`` or ``event`` is missing from either frame.
    """
    survival_columns = ['duration', 'event']

    real_df = real_df.drop(survival_columns, axis=1)
    synthetic_df = synthetic_df.drop(survival_columns, axis=1)

    real_continuous, real_discrete = identify_variable_types(real_df)
    synthetic_continuous, synthetic_discrete = identify_variable_types(synthetic_df)

    p_values_continuous = {}
    p_values_discrete = {}

    # Wilcoxon rank-sum (Mann-Whitney U) test for continuous variables
    for col in real_continuous:
        if col in synthetic_continuous:
            _, p_value = mannwhitneyu(real_df[col], synthetic_df[col])
            p_values_continuous[col] = p_value

    # Chi-square test for discrete variables
    for col in real_discrete:
        if col in synthetic_discrete:
            # Count each category once per source. Cross-tabulating the two
            # columns against each other would instead pair rows by index
            # label and test whether unrelated row pairs are independent,
            # which says nothing about the two distributions being equal.
            category_counts = pd.DataFrame({
                'real': real_df[col].value_counts(),
                'synthetic': synthetic_df[col].value_counts(),
            }).fillna(0)
            _, p, _, _ = chi2_contingency(category_counts.to_numpy())
            p_values_discrete[col] = p

    return p_values_continuous, p_values_discrete



# FLCHAIN

In [None]:
dataset="flchain"

metrics_list = []      # one Series of metric scores per round
fit_times = []         # wall-clock seconds spent in fit(), per round
generate_times = []    # wall-clock seconds spent in generate(), per round
p_values_list = []     # one {column name: p-value} dict per round

for i in range(5):
    # The data is reloaded on every round, so the frame handed to the library
    # is never one that an earlier round has already worked with.
    df = pd.read_csv('/content/drive/MyDrive/Datasets/flchain_final.csv')
    df = df.drop('Unnamed: 0', axis=1)
    df = df[df['duration'] != 0]

    syn_model = Plugins().get(plugin_name)

    # Measure the execution time of the fit function
    start = timer()
    syn_model.fit(df)
    fit_time = timer() - start
    fit_times.append(fit_time)

    # NOTE(review): NumPy is seeded only after fit(), so this seed cannot
    # influence training. How (or whether) the plugin seeds training is not
    # visible in this notebook -- confirm the rounds are independent fits.
    random_state = i + 1
    np.random.seed(random_state)

    # Measure the execution time of the generate function
    start = timer()
    X_gen = syn_model.generate(count=len(df)).dataframe()
    generate_time = timer() - start
    generate_times.append(generate_time)

    # Save X_gen as a CSV file
    X_gen.to_csv(f"/content/drive/MyDrive/Nips/{dataset}_{plugin_name}_nocond_iteration_{i+1}.csv", index=False)

    loader1 = SurvivalAnalysisDataLoader(df, target_column="event", time_to_event_column="duration")
    loader2 = SurvivalAnalysisDataLoader(X_gen, target_column="event", time_to_event_column="duration")

    met_df = Metrics.evaluate(X_gt=loader1, X_syn=loader2, task_type='survival_analysis', metrics={
        'stats': ['jensenshannon_dist', 'chi_squared_test', 'feature_corr', 'inv_kl_divergence', 'ks_test',
                 'max_mean_discrepancy', 'wasserstein_dist', 'prdc', 'alpha_precision', 'survival_km_distance'],
        'performance': ['linear_model', 'mlp', 'xgb', 'feat_rank_distance']
    }, use_cache=False, random_state=random_state)

    # Only the first column of the evaluation table is kept as this round's scores.
    met_df = met_df.iloc[:, 0]
    metrics_list.append(met_df)

    # Calculate p-values. They are stored keyed by column name so that rounds
    # stay aligned even when a column is tested in some rounds but not others.
    p_values_continuous, p_values_discrete = compare_distributions(df, X_gen)
    p_values_list.append({**p_values_continuous, **p_values_discrete})

    # Remove the 'workspace' directory under the current working directory
    # (presumably synthcity's cache -- TODO confirm) before the next round.
    workspace_dir = os.path.join(os.getcwd(), 'workspace')
    if os.path.exists(workspace_dir):
        shutil.rmtree(workspace_dir)

run_scores = pd.concat(metrics_list, axis=1)

# Row-wise mean and sample standard deviation (pandas default, ddof=1) of the
# metrics. Both are computed from the per-round columns only: taking std() of
# a frame that already holds 'Mean' would count the mean as an extra
# observation and understate the spread.
result_df = run_scores.copy()
result_df['Mean'] = run_scores.mean(axis=1)
result_df['Std'] = run_scores.std(axis=1).round(4)

# Create DataFrame for p-values: one row per round, one column per tested
# variable (continuous ones first, then discrete ones).
p_values_df = pd.DataFrame(p_values_list)

# Save result_df and p_values_df as CSV files
result_df.to_csv(f"/content/drive/MyDrive/Nips/{dataset}_{plugin_name}_nocond_result_df.csv")
p_values_df.to_csv(f"/content/drive/MyDrive/Nips/{dataset}_{plugin_name}_nocond_p_values_df.csv")

# np.std is the population standard deviation (ddof=0), unlike the pandas
# sample standard deviation used for the metrics above.
avg_fit_time = np.mean(fit_times)
avg_generate_time = np.mean(generate_times)
std_fit_time = np.std(fit_times)
std_generate_time = np.std(generate_times)

print(f"\nAverage Fit Time: {avg_fit_time:.4f} seconds, Standard Deviation: {std_fit_time:.4f} seconds")
print(f"Average Generate Time: {avg_generate_time:.4f} seconds, Standard Deviation: {std_generate_time:.4f} seconds")


[2024-04-19T22:38:21.591519+0000][1377][CRITICAL] module disabled: /usr/local/lib/python3.10/dist-packages/synthcity/plugins/generic/plugin_goggle.py
 25%|██▍       | 499/2000 [08:22<25:11,  1.01s/it]


[KeOps] Generating code for Max_SumShiftExpWeight_Reduction reduction (with parameters 0) of formula [c-1/2*(d*Sum((a-b)**2)),1] with a=Var(0,11,0), b=Var(1,11,1), c=Var(2,1,1), d=Var(3,1,2) ... OK


[2024-04-19T22:52:52.570611+0000][1377][CRITICAL] module disabled: /usr/local/lib/python3.10/dist-packages/synthcity/plugins/generic/plugin_goggle.py
 25%|██▍       | 499/2000 [08:26<25:22,  1.01s/it]




[2024-04-19T23:07:12.692475+0000][1377][CRITICAL] module disabled: /usr/local/lib/python3.10/dist-packages/synthcity/plugins/generic/plugin_goggle.py
 25%|██▍       | 499/2000 [08:27<25:26,  1.02s/it]




[2024-04-19T23:21:53.888532+0000][1377][CRITICAL] module disabled: /usr/local/lib/python3.10/dist-packages/synthcity/plugins/generic/plugin_goggle.py
 25%|██▍       | 499/2000 [08:26<25:24,  1.02s/it]




[2024-04-19T23:36:32.468572+0000][1377][CRITICAL] module disabled: /usr/local/lib/python3.10/dist-packages/synthcity/plugins/generic/plugin_goggle.py
 25%|██▍       | 499/2000 [08:27<25:27,  1.02s/it]



Average Fit Time: 513.0025 seconds, Standard Deviation: 1.9139 seconds
Average Generate Time: 0.1245 seconds, Standard Deviation: 0.0117 seconds


In [None]:
# FLCHAIN: per-round metric scores with their row-wise Mean and Std.
result_df

Unnamed: 0,min,min.1,min.2,min.3,min.4,Mean,Std
stats.jensenshannon_dist.marginal,0.005824,0.005791,0.005786,0.005847,0.005831,0.005816,0.0
stats.chi_squared_test.marginal,0.624263,0.624142,0.624317,0.623993,0.623983,0.62414,0.0001
stats.inv_kl_divergence.marginal,0.937778,0.937862,0.93803,0.937964,0.937947,0.937916,0.0001
stats.ks_test.marginal,0.866449,0.866668,0.86668,0.866322,0.866726,0.866569,0.0002
stats.max_mean_discrepancy.joint,0.000262,0.000261,0.000261,0.000261,0.000261,0.000261,0.0
stats.wasserstein_dist.joint,0.021938,0.021545,0.021755,0.021553,0.021347,0.021628,0.0002
stats.prdc.precision,0.964935,0.966586,0.96557,0.966078,0.966713,0.965976,0.0007
stats.prdc.recall,0.986279,0.985516,0.985771,0.986787,0.985644,0.985999,0.0005
stats.prdc.density,0.868276,0.868378,0.872824,0.875086,0.867056,0.870324,0.0031
stats.prdc.coverage,0.809046,0.812603,0.806505,0.812603,0.813874,0.810926,0.0027


In [None]:
# FLCHAIN: one row of per-column p-values for each round.
p_values_df

Unnamed: 0,age,kappa,lambda,creatinine,sex,sample.yr,flc.grp,mgus,chapter
0,3.685625e-77,7.149978e-13,2.401125e-49,1.253561e-21,0.751146,0.713233,0.392621,0.447718,0.521891
1,3.7850549999999997e-78,1.838366e-13,4.85798e-51,9.947069999999999e-21,0.850218,0.549042,0.285245,0.974322,0.434317
2,1.446438e-76,6.281723e-13,4.545437e-49,1.153858e-20,0.978087,0.73681,0.2766,0.420234,0.131515
3,7.823914e-77,9.208031e-13,2.196384e-50,3.39405e-21,0.881069,0.820678,0.566046,1.0,0.326209
4,5.239544e-76,4.058707e-13,5.393889e-49,1.0127559999999999e-19,0.920359,0.704702,0.431346,0.392921,0.100039


# AIDS

In [None]:
dataset="aids"

metrics_list = []      # one Series of metric scores per round
fit_times = []         # wall-clock seconds spent in fit(), per round
generate_times = []    # wall-clock seconds spent in generate(), per round
p_values_list = []     # one {column name: p-value} dict per round

for i in range(5):
    # The data is reloaded on every round, so the frame handed to the library
    # is never one that an earlier round has already worked with.
    df = pd.read_csv('/content/drive/MyDrive/Datasets/aids.csv')
    df = df.drop('Unnamed: 0', axis=1)
    df = df[df['duration'] != 0]

    syn_model = Plugins().get(plugin_name)

    # Measure the execution time of the fit function
    start = timer()
    syn_model.fit(df)
    fit_time = timer() - start
    fit_times.append(fit_time)

    # NOTE(review): NumPy is seeded only after fit(), so this seed cannot
    # influence training. How (or whether) the plugin seeds training is not
    # visible in this notebook -- confirm the rounds are independent fits.
    random_state = i + 1
    np.random.seed(random_state)

    # Measure the execution time of the generate function
    start = timer()
    X_gen = syn_model.generate(count=len(df)).dataframe()
    generate_time = timer() - start
    generate_times.append(generate_time)

    # Save X_gen as a CSV file
    X_gen.to_csv(f"/content/drive/MyDrive/Nips/{dataset}_{plugin_name}_nocond_iteration_{i+1}.csv", index=False)

    loader1 = SurvivalAnalysisDataLoader(df, target_column="event", time_to_event_column="duration")
    loader2 = SurvivalAnalysisDataLoader(X_gen, target_column="event", time_to_event_column="duration")

    met_df = Metrics.evaluate(X_gt=loader1, X_syn=loader2, task_type='survival_analysis', metrics={
        'stats': ['jensenshannon_dist', 'chi_squared_test', 'feature_corr', 'inv_kl_divergence', 'ks_test',
                 'max_mean_discrepancy', 'wasserstein_dist', 'prdc', 'alpha_precision', 'survival_km_distance'],
        'performance': ['linear_model', 'mlp', 'xgb', 'feat_rank_distance']
    }, use_cache=False, random_state=random_state)

    # Only the first column of the evaluation table is kept as this round's scores.
    met_df = met_df.iloc[:, 0]
    metrics_list.append(met_df)

    # Calculate p-values. They are stored keyed by column name so that rounds
    # stay aligned even when a column is tested in some rounds but not others.
    p_values_continuous, p_values_discrete = compare_distributions(df, X_gen)
    p_values_list.append({**p_values_continuous, **p_values_discrete})

    # Remove the 'workspace' directory under the current working directory
    # (presumably synthcity's cache -- TODO confirm) before the next round.
    workspace_dir = os.path.join(os.getcwd(), 'workspace')
    if os.path.exists(workspace_dir):
        shutil.rmtree(workspace_dir)

run_scores = pd.concat(metrics_list, axis=1)

# Row-wise mean and sample standard deviation (pandas default, ddof=1) of the
# metrics. Both are computed from the per-round columns only: taking std() of
# a frame that already holds 'Mean' would count the mean as an extra
# observation and understate the spread.
result_df = run_scores.copy()
result_df['Mean'] = run_scores.mean(axis=1)
result_df['Std'] = run_scores.std(axis=1).round(4)

# Create DataFrame for p-values: one row per round, one column per tested
# variable (continuous ones first, then discrete ones).
p_values_df = pd.DataFrame(p_values_list)

# Save result_df and p_values_df as CSV files
result_df.to_csv(f"/content/drive/MyDrive/Nips/{dataset}_{plugin_name}_nocond_result_df.csv")
p_values_df.to_csv(f"/content/drive/MyDrive/Nips/{dataset}_{plugin_name}_nocond_p_values_df.csv")

# np.std is the population standard deviation (ddof=0), unlike the pandas
# sample standard deviation used for the metrics above.
avg_fit_time = np.mean(fit_times)
avg_generate_time = np.mean(generate_times)
std_fit_time = np.std(fit_times)
std_generate_time = np.std(generate_times)

print(f"\nAverage Fit Time: {avg_fit_time:.4f} seconds, Standard Deviation: {std_fit_time:.4f} seconds")
print(f"Average Generate Time: {avg_generate_time:.4f} seconds, Standard Deviation: {std_generate_time:.4f} seconds")


[2024-04-19T23:51:07.243674+0000][1377][CRITICAL] module disabled: /usr/local/lib/python3.10/dist-packages/synthcity/plugins/generic/plugin_goggle.py
 72%|███████▏  | 1449/2000 [04:38<01:46,  5.20it/s]
[2024-04-19T23:56:48.914372+0000][1377][CRITICAL] module disabled: /usr/local/lib/python3.10/dist-packages/synthcity/plugins/generic/plugin_goggle.py
 72%|███████▏  | 1449/2000 [04:39<01:46,  5.19it/s]
[2024-04-20T00:02:31.503345+0000][1377][CRITICAL] module disabled: /usr/local/lib/python3.10/dist-packages/synthcity/plugins/generic/plugin_goggle.py
 72%|███████▏  | 1449/2000 [04:38<01:46,  5.19it/s]
[2024-04-20T00:08:13.099380+0000][1377][CRITICAL] module disabled: /usr/local/lib/python3.10/dist-packages/synthcity/plugins/generic/plugin_goggle.py
 72%|███████▏  | 1449/2000 [04:39<01:46,  5.18it/s]
[2024-04-20T00:13:54.293297+0000][1377][CRITICAL] module disabled: /usr/local/lib/python3.10/dist-packages/synthcity/plugins/generic/plugin_goggle.py
 72%|███████▏  | 1449/2000 [04:39<01:46,  


Average Fit Time: 281.0611 seconds, Standard Deviation: 0.4652 seconds
Average Generate Time: 0.0792 seconds, Standard Deviation: 0.0026 seconds


In [None]:
# AIDS: per-round metric scores with their row-wise Mean and Std.
result_df

Unnamed: 0,min,min.1,min.2,min.3,min.4,Mean,Std
stats.jensenshannon_dist.marginal,0.007846,0.008251,0.008159,0.007842,0.007889,0.007997,0.0002
stats.chi_squared_test.marginal,0.83366,0.756574,0.756058,0.757799,0.834232,0.787665,0.0378
stats.inv_kl_divergence.marginal,0.945473,0.942242,0.941453,0.939588,0.951929,0.944137,0.0043
stats.ks_test.marginal,0.940988,0.937847,0.938315,0.94072,0.941389,0.939852,0.0015
stats.max_mean_discrepancy.joint,0.001738,0.001738,0.001738,0.001738,0.001738,0.001738,0.0
stats.wasserstein_dist.joint,0.06149,0.063997,0.058614,0.06393,0.060067,0.06162,0.0021
stats.prdc.precision,0.994787,0.997394,0.997394,0.997394,0.993918,0.996177,0.0015
stats.prdc.recall,0.926151,0.92007,0.926151,0.906169,0.915725,0.918853,0.0075
stats.prdc.density,1.07715,1.091051,1.083232,1.088792,1.086533,1.085352,0.0048
stats.prdc.coverage,0.874023,0.893136,0.876629,0.86881,0.873154,0.87715,0.0084


In [None]:
# AIDS: one row of per-column p-values for each round.
p_values_df

Unnamed: 0,age,cd4,priorzdv,hemophil,ivdrug,karnof,raceth,sex,strat2,tx,txgrp
0,2.5e-05,0.009989,0.28449,0.582189,0.966148,0.355496,0.150918,0.234461,0.173183,0.496377,0.928062
1,1.9e-05,0.009262,0.288531,1.0,0.685472,0.2038,0.312499,0.246964,0.123337,0.424347,0.931062
2,6e-06,0.027894,0.404605,0.582189,0.720066,0.477754,0.193536,0.633631,0.178271,0.459915,0.859484
3,2e-06,0.037497,0.228223,0.582189,0.655114,0.648449,0.243695,0.463991,0.149245,0.615116,0.814682
4,1.6e-05,0.022279,0.24929,0.137161,0.950359,0.560413,0.069481,0.176643,0.241894,0.743613,0.948738


# METABRIC

In [None]:
dataset="metabric"

metrics_list = []      # one Series of metric scores per round
fit_times = []         # wall-clock seconds spent in fit(), per round
generate_times = []    # wall-clock seconds spent in generate(), per round
p_values_list = []     # one {column name: p-value} dict per round

for i in range(5):
    # The data is reloaded on every round, so the frame handed to the library
    # is never one that an earlier round has already worked with.
    df = datasets.metabric.read_df()
    df = df[df['duration'] != 0]

    syn_model = Plugins().get(plugin_name)

    # Measure the execution time of the fit function
    start = timer()
    syn_model.fit(df)
    fit_time = timer() - start
    fit_times.append(fit_time)

    # NOTE(review): NumPy is seeded only after fit(), so this seed cannot
    # influence training. How (or whether) the plugin seeds training is not
    # visible in this notebook -- confirm the rounds are independent fits.
    random_state = i + 1
    np.random.seed(random_state)

    # Measure the execution time of the generate function
    start = timer()
    X_gen = syn_model.generate(count=len(df)).dataframe()
    generate_time = timer() - start
    generate_times.append(generate_time)

    # Save X_gen as a CSV file
    X_gen.to_csv(f"/content/drive/MyDrive/Nips/{dataset}_{plugin_name}_nocond_iteration_{i+1}.csv", index=False)

    loader1 = SurvivalAnalysisDataLoader(df, target_column="event", time_to_event_column="duration")
    loader2 = SurvivalAnalysisDataLoader(X_gen, target_column="event", time_to_event_column="duration")

    met_df = Metrics.evaluate(X_gt=loader1, X_syn=loader2, task_type='survival_analysis', metrics={
        'stats': ['jensenshannon_dist', 'chi_squared_test', 'feature_corr', 'inv_kl_divergence', 'ks_test',
                 'max_mean_discrepancy', 'wasserstein_dist', 'prdc', 'alpha_precision', 'survival_km_distance'],
        'performance': ['linear_model', 'mlp', 'xgb', 'feat_rank_distance']
    }, use_cache=False, random_state=random_state)

    # Only the first column of the evaluation table is kept as this round's scores.
    met_df = met_df.iloc[:, 0]
    metrics_list.append(met_df)

    # Calculate p-values. They are stored keyed by column name so that rounds
    # stay aligned even when a column is tested in some rounds but not others.
    p_values_continuous, p_values_discrete = compare_distributions(df, X_gen)
    p_values_list.append({**p_values_continuous, **p_values_discrete})

    # Remove the 'workspace' directory under the current working directory
    # (presumably synthcity's cache -- TODO confirm) before the next round.
    workspace_dir = os.path.join(os.getcwd(), 'workspace')
    if os.path.exists(workspace_dir):
        shutil.rmtree(workspace_dir)

run_scores = pd.concat(metrics_list, axis=1)

# Row-wise mean and sample standard deviation (pandas default, ddof=1) of the
# metrics. Both are computed from the per-round columns only: taking std() of
# a frame that already holds 'Mean' would count the mean as an extra
# observation and understate the spread.
result_df = run_scores.copy()
result_df['Mean'] = run_scores.mean(axis=1)
result_df['Std'] = run_scores.std(axis=1).round(4)

# Create DataFrame for p-values: one row per round, one column per tested
# variable (continuous ones first, then discrete ones).
p_values_df = pd.DataFrame(p_values_list)

# Save result_df and p_values_df as CSV files
result_df.to_csv(f"/content/drive/MyDrive/Nips/{dataset}_{plugin_name}_nocond_result_df.csv")
p_values_df.to_csv(f"/content/drive/MyDrive/Nips/{dataset}_{plugin_name}_nocond_p_values_df.csv")

# np.std is the population standard deviation (ddof=0), unlike the pandas
# sample standard deviation used for the metrics above.
avg_fit_time = np.mean(fit_times)
avg_generate_time = np.mean(generate_times)
std_fit_time = np.std(fit_times)
std_generate_time = np.std(generate_times)

print(f"\nAverage Fit Time: {avg_fit_time:.4f} seconds, Standard Deviation: {std_fit_time:.4f} seconds")
print(f"Average Generate Time: {avg_generate_time:.4f} seconds, Standard Deviation: {std_generate_time:.4f} seconds")


Dataset 'metabric' not locally available. Downloading...


[2024-04-20T00:19:36.780488+0000][1377][CRITICAL] module disabled: /usr/local/lib/python3.10/dist-packages/synthcity/plugins/generic/plugin_goggle.py


Done


 20%|█▉        | 399/2000 [01:48<07:16,  3.67it/s]
[2024-04-20T00:23:10.577508+0000][1377][CRITICAL] module disabled: /usr/local/lib/python3.10/dist-packages/synthcity/plugins/generic/plugin_goggle.py
 20%|█▉        | 399/2000 [01:50<07:22,  3.62it/s]
[2024-04-20T00:26:47.789247+0000][1377][CRITICAL] module disabled: /usr/local/lib/python3.10/dist-packages/synthcity/plugins/generic/plugin_goggle.py
 20%|█▉        | 399/2000 [01:49<07:21,  3.63it/s]
[2024-04-20T00:30:22.709871+0000][1377][CRITICAL] module disabled: /usr/local/lib/python3.10/dist-packages/synthcity/plugins/generic/plugin_goggle.py
 20%|█▉        | 399/2000 [01:49<07:19,  3.65it/s]
[2024-04-20T00:33:58.293684+0000][1377][CRITICAL] module disabled: /usr/local/lib/python3.10/dist-packages/synthcity/plugins/generic/plugin_goggle.py
 20%|█▉        | 399/2000 [01:50<07:23,  3.61it/s]



Average Fit Time: 112.6770 seconds, Standard Deviation: 0.3831 seconds
Average Generate Time: 0.0878 seconds, Standard Deviation: 0.0041 seconds


In [None]:
# METABRIC: per-round metric scores with their row-wise Mean and Std.
result_df

Unnamed: 0,min,min.1,min.2,min.3,min.4,Mean,Std
stats.jensenshannon_dist.marginal,0.012045,0.012793,0.012374,0.012314,0.012589,0.012423,0.0003
stats.chi_squared_test.marginal,0.624305,0.619957,0.622046,0.531274,0.621494,0.603815,0.0363
stats.inv_kl_divergence.marginal,0.883081,0.881954,0.883066,0.8766,0.880482,0.881036,0.0024
stats.ks_test.marginal,0.891033,0.889552,0.888979,0.890412,0.890508,0.890097,0.0007
stats.max_mean_discrepancy.joint,0.001063,0.001061,0.001065,0.001063,0.001064,0.001063,0.0
stats.wasserstein_dist.joint,0.041228,0.041507,0.040145,0.042583,0.04094,0.041281,0.0008
stats.prdc.precision,0.976353,0.979506,0.978981,0.976353,0.978981,0.978035,0.0014
stats.prdc.recall,0.950604,0.93011,0.939569,0.939044,0.95113,0.942091,0.0079
stats.prdc.density,0.956595,0.966474,0.959643,0.962165,0.966684,0.962312,0.0039
stats.prdc.coverage,0.823962,0.815029,0.81608,0.825013,0.807147,0.817446,0.0065


In [None]:
# METABRIC: one row of per-column p-values for each round.
p_values_df

Unnamed: 0,x0,x1,x2,x3,x8,x4,x5,x6,x7
0,4.485423e-13,4.1797169999999995e-34,3.0374749999999996e-38,6.2530110000000004e-52,4.60083e-33,0.394441,1.0,0.127821,0.739378
1,1.170342e-13,2.278491e-34,6.292248e-38,7.383384e-52,6.863909e-34,0.144465,0.947125,0.331931,0.370634
2,3.03194e-13,3.97165e-32,1.03634e-38,4.214538e-53,4.679554e-36,0.052951,1.0,0.031412,0.8449
3,1.762658e-13,1.75781e-34,2.833987e-37,6.288411e-51,3.014409e-36,0.067842,0.72001,0.24522,0.841839
4,4.155925e-13,3.256996e-32,9.842594000000001e-39,9.049426e-51,1.987569e-35,0.460579,0.573296,0.679005,0.417389


# GBSG

In [None]:
dataset="gbsg"

metrics_list = []      # one Series of metric scores per round
fit_times = []         # wall-clock seconds spent in fit(), per round
generate_times = []    # wall-clock seconds spent in generate(), per round
p_values_list = []     # one {column name: p-value} dict per round

for i in range(5):
    # The data is reloaded on every round, so the frame handed to the library
    # is never one that an earlier round has already worked with.
    df = datasets.gbsg.read_df()
    df = df[df['duration'] != 0]

    syn_model = Plugins().get(plugin_name)

    # Measure the execution time of the fit function
    start = timer()
    syn_model.fit(df)
    fit_time = timer() - start
    fit_times.append(fit_time)

    # NOTE(review): NumPy is seeded only after fit(), so this seed cannot
    # influence training. How (or whether) the plugin seeds training is not
    # visible in this notebook -- confirm the rounds are independent fits.
    random_state = i + 1
    np.random.seed(random_state)

    # Measure the execution time of the generate function
    start = timer()
    X_gen = syn_model.generate(count=len(df)).dataframe()
    generate_time = timer() - start
    generate_times.append(generate_time)

    # Save X_gen as a CSV file
    X_gen.to_csv(f"/content/drive/MyDrive/Nips/{dataset}_{plugin_name}_nocond_iteration_{i+1}.csv", index=False)

    loader1 = SurvivalAnalysisDataLoader(df, target_column="event", time_to_event_column="duration")
    loader2 = SurvivalAnalysisDataLoader(X_gen, target_column="event", time_to_event_column="duration")

    met_df = Metrics.evaluate(X_gt=loader1, X_syn=loader2, task_type='survival_analysis', metrics={
        'stats': ['jensenshannon_dist', 'chi_squared_test', 'feature_corr', 'inv_kl_divergence', 'ks_test',
                 'max_mean_discrepancy', 'wasserstein_dist', 'prdc', 'alpha_precision', 'survival_km_distance'],
        'performance': ['linear_model', 'mlp', 'xgb', 'feat_rank_distance']
    }, use_cache=False, random_state=random_state)

    # Only the first column of the evaluation table is kept as this round's scores.
    met_df = met_df.iloc[:, 0]
    metrics_list.append(met_df)

    # Calculate p-values. They are stored keyed by column name so that rounds
    # stay aligned even when a column is tested in some rounds but not others.
    p_values_continuous, p_values_discrete = compare_distributions(df, X_gen)
    p_values_list.append({**p_values_continuous, **p_values_discrete})

    # Remove the 'workspace' directory under the current working directory
    # (presumably synthcity's cache -- TODO confirm) before the next round.
    workspace_dir = os.path.join(os.getcwd(), 'workspace')
    if os.path.exists(workspace_dir):
        shutil.rmtree(workspace_dir)

run_scores = pd.concat(metrics_list, axis=1)

# Row-wise mean and sample standard deviation (pandas default, ddof=1) of the
# metrics. Both are computed from the per-round columns only: taking std() of
# a frame that already holds 'Mean' would count the mean as an extra
# observation and understate the spread.
result_df = run_scores.copy()
result_df['Mean'] = run_scores.mean(axis=1)
result_df['Std'] = run_scores.std(axis=1).round(4)

# Create DataFrame for p-values: one row per round, one column per tested
# variable (continuous ones first, then discrete ones).
p_values_df = pd.DataFrame(p_values_list)

# Save result_df and p_values_df as CSV files
result_df.to_csv(f"/content/drive/MyDrive/Nips/{dataset}_{plugin_name}_nocond_result_df.csv")
p_values_df.to_csv(f"/content/drive/MyDrive/Nips/{dataset}_{plugin_name}_nocond_p_values_df.csv")

# np.std is the population standard deviation (ddof=0), unlike the pandas
# sample standard deviation used for the metrics above.
avg_fit_time = np.mean(fit_times)
avg_generate_time = np.mean(generate_times)
std_fit_time = np.std(fit_times)
std_generate_time = np.std(generate_times)

print(f"\nAverage Fit Time: {avg_fit_time:.4f} seconds, Standard Deviation: {std_fit_time:.4f} seconds")
print(f"Average Generate Time: {avg_generate_time:.4f} seconds, Standard Deviation: {std_generate_time:.4f} seconds")


Dataset 'gbsg' not locally available. Downloading...


[2024-04-20T00:37:34.586479+0000][1377][CRITICAL] module disabled: /usr/local/lib/python3.10/dist-packages/synthcity/plugins/generic/plugin_goggle.py


Done


 50%|████▉     | 999/2000 [04:45<04:45,  3.50it/s]
[2024-04-20T00:44:19.996323+0000][1377][CRITICAL] module disabled: /usr/local/lib/python3.10/dist-packages/synthcity/plugins/generic/plugin_goggle.py
 50%|████▉     | 999/2000 [04:44<04:45,  3.51it/s]
[2024-04-20T00:51:05.033608+0000][1377][CRITICAL] module disabled: /usr/local/lib/python3.10/dist-packages/synthcity/plugins/generic/plugin_goggle.py
 50%|████▉     | 999/2000 [04:46<04:47,  3.48it/s]
[2024-04-20T00:57:52.104107+0000][1377][CRITICAL] module disabled: /usr/local/lib/python3.10/dist-packages/synthcity/plugins/generic/plugin_goggle.py
 50%|████▉     | 999/2000 [04:44<04:44,  3.51it/s]
[2024-04-20T01:04:35.913003+0000][1377][CRITICAL] module disabled: /usr/local/lib/python3.10/dist-packages/synthcity/plugins/generic/plugin_goggle.py
 50%|████▉     | 999/2000 [04:43<04:43,  3.53it/s]



Average Fit Time: 287.4944 seconds, Standard Deviation: 1.4224 seconds
Average Generate Time: 0.0730 seconds, Standard Deviation: 0.0023 seconds


In [None]:
# GBSG: per-round metric scores with their row-wise Mean and Std.
result_df

Unnamed: 0,min,min.1,min.2,min.3,min.4,Mean,Std
stats.jensenshannon_dist.marginal,0.009117,0.00855,0.008749,0.009098,0.008774,0.008858,0.0002
stats.chi_squared_test.marginal,0.522624,0.482485,0.512003,0.521355,0.511565,0.510006,0.0145
stats.inv_kl_divergence.marginal,0.925917,0.934726,0.937446,0.938412,0.937871,0.934874,0.0047
stats.ks_test.marginal,0.898994,0.903325,0.900936,0.899791,0.901085,0.900826,0.0015
stats.max_mean_discrepancy.joint,0.000898,0.000899,0.000899,0.000899,0.000899,0.000899,0.0
stats.wasserstein_dist.joint,0.017127,0.01685,0.017187,0.017678,0.016298,0.017028,0.0005
stats.prdc.precision,0.993728,0.994624,0.994624,0.99328,0.994176,0.994086,0.0005
stats.prdc.recall,0.893817,0.898297,0.892921,0.891577,0.876344,0.890591,0.0075
stats.prdc.density,0.985036,0.992025,0.992204,0.996326,0.986201,0.990358,0.0042
stats.prdc.coverage,0.77957,0.780914,0.785842,0.772401,0.769713,0.777688,0.0059


In [None]:
# GBSG: one row of per-column p-values for each round.
p_values_df

Unnamed: 0,x3,x4,x5,x6,x0,x1,x2
0,0.221038,0.127862,7.235115e-09,9.480833e-10,0.038887,0.352351,1.0
1,0.356394,0.052472,7.559601e-09,1.088234e-09,0.246401,0.281872,0.149961
2,0.27571,0.041214,1.269046e-08,1.11636e-09,0.23806,0.330919,0.773334
3,0.248111,0.127376,1.051723e-08,8.373663e-09,0.333071,0.227022,0.718208
4,0.280662,0.10797,2.695925e-08,1.176554e-08,0.460186,0.473726,0.594067


# SUPPORT

In [None]:
dataset="support"

metrics_list = []      # one Series of metric scores per round
fit_times = []         # wall-clock seconds spent in fit(), per round
generate_times = []    # wall-clock seconds spent in generate(), per round
p_values_list = []     # one {column name: p-value} dict per round

for i in range(5):
    # The data is reloaded on every round, so the frame handed to the library
    # is never one that an earlier round has already worked with.
    df = datasets.support.read_df()
    df = df[df['duration'] != 0]

    syn_model = Plugins().get(plugin_name)

    # Measure the execution time of the fit function
    start = timer()
    syn_model.fit(df)
    fit_time = timer() - start
    fit_times.append(fit_time)

    # NOTE(review): NumPy is seeded only after fit(), so this seed cannot
    # influence training. How (or whether) the plugin seeds training is not
    # visible in this notebook -- confirm the rounds are independent fits.
    random_state = i + 1
    np.random.seed(random_state)

    # Measure the execution time of the generate function
    start = timer()
    X_gen = syn_model.generate(count=len(df)).dataframe()
    generate_time = timer() - start
    generate_times.append(generate_time)

    # Save X_gen as a CSV file
    X_gen.to_csv(f"/content/drive/MyDrive/Nips/{dataset}_{plugin_name}_nocond_iteration_{i+1}.csv", index=False)

    loader1 = SurvivalAnalysisDataLoader(df, target_column="event", time_to_event_column="duration")
    loader2 = SurvivalAnalysisDataLoader(X_gen, target_column="event", time_to_event_column="duration")

    met_df = Metrics.evaluate(X_gt=loader1, X_syn=loader2, task_type='survival_analysis', metrics={
        'stats': ['jensenshannon_dist', 'chi_squared_test', 'feature_corr', 'inv_kl_divergence', 'ks_test',
                 'max_mean_discrepancy', 'wasserstein_dist', 'prdc', 'alpha_precision', 'survival_km_distance'],
        'performance': ['linear_model', 'mlp', 'xgb', 'feat_rank_distance']
    }, use_cache=False, random_state=random_state)

    # Only the first column of the evaluation table is kept as this round's scores.
    met_df = met_df.iloc[:, 0]
    metrics_list.append(met_df)

    # Calculate p-values. They are stored keyed by column name so that rounds
    # stay aligned even when a column is tested in some rounds but not others.
    p_values_continuous, p_values_discrete = compare_distributions(df, X_gen)
    p_values_list.append({**p_values_continuous, **p_values_discrete})

    # Remove the 'workspace' directory under the current working directory
    # (presumably synthcity's cache -- TODO confirm) before the next round.
    workspace_dir = os.path.join(os.getcwd(), 'workspace')
    if os.path.exists(workspace_dir):
        shutil.rmtree(workspace_dir)

run_scores = pd.concat(metrics_list, axis=1)

# Row-wise mean and sample standard deviation (pandas default, ddof=1) of the
# metrics. Both are computed from the per-round columns only: taking std() of
# a frame that already holds 'Mean' would count the mean as an extra
# observation and understate the spread.
result_df = run_scores.copy()
result_df['Mean'] = run_scores.mean(axis=1)
result_df['Std'] = run_scores.std(axis=1).round(4)

# Create DataFrame for p-values: one row per round, one column per tested
# variable (continuous ones first, then discrete ones).
p_values_df = pd.DataFrame(p_values_list)

# Save result_df and p_values_df as CSV files
result_df.to_csv(f"/content/drive/MyDrive/Nips/{dataset}_{plugin_name}_nocond_result_df.csv")
p_values_df.to_csv(f"/content/drive/MyDrive/Nips/{dataset}_{plugin_name}_nocond_p_values_df.csv")

# np.std is the population standard deviation (ddof=0), unlike the pandas
# sample standard deviation used for the metrics above.
avg_fit_time = np.mean(fit_times)
avg_generate_time = np.mean(generate_times)
std_fit_time = np.std(fit_times)
std_generate_time = np.std(generate_times)

print(f"\nAverage Fit Time: {avg_fit_time:.4f} seconds, Standard Deviation: {std_fit_time:.4f} seconds")
print(f"Average Generate Time: {avg_generate_time:.4f} seconds, Standard Deviation: {std_generate_time:.4f} seconds")


Dataset 'support' not locally available. Downloading...
Done


[2024-04-20T09:10:11.458726+0000][2633][CRITICAL] module disabled: /usr/local/lib/python3.10/dist-packages/synthcity/plugins/generic/plugin_goggle.py
 20%|█▉        | 399/2000 [08:10<32:46,  1.23s/it]


[KeOps] Generating code for Max_SumShiftExpWeight_Reduction reduction (with parameters 0) of formula [c-1/2*(d*Sum((a-b)**2)),1] with a=Var(0,16,0), b=Var(1,16,1), c=Var(2,1,1), d=Var(3,1,2) ... OK


[2024-04-20T09:23:05.773820+0000][2633][CRITICAL] module disabled: /usr/local/lib/python3.10/dist-packages/synthcity/plugins/generic/plugin_goggle.py
 20%|█▉        | 399/2000 [08:32<34:16,  1.28s/it]




[2024-04-20T09:36:00.329455+0000][2633][CRITICAL] module disabled: /usr/local/lib/python3.10/dist-packages/synthcity/plugins/generic/plugin_goggle.py
 20%|█▉        | 399/2000 [08:43<35:01,  1.31s/it]




[2024-04-20T09:49:05.604950+0000][2633][CRITICAL] module disabled: /usr/local/lib/python3.10/dist-packages/synthcity/plugins/generic/plugin_goggle.py
 20%|█▉        | 399/2000 [08:30<34:09,  1.28s/it]




[2024-04-20T10:01:52.083277+0000][2633][CRITICAL] module disabled: /usr/local/lib/python3.10/dist-packages/synthcity/plugins/generic/plugin_goggle.py
 20%|█▉        | 399/2000 [08:31<34:13,  1.28s/it]



Average Fit Time: 523.6260 seconds, Standard Deviation: 5.6088 seconds
Average Generate Time: 0.1924 seconds, Standard Deviation: 0.0357 seconds


In [None]:
# SUPPORT: per-round metric scores with their row-wise Mean and Std.
result_df

Unnamed: 0,min,min.1,min.2,min.3,min.4,Mean,Std
stats.jensenshannon_dist.marginal,0.004982,0.005033,0.004966,0.005017,0.004978,0.004995,0.0
stats.chi_squared_test.marginal,0.742279,0.741677,0.679545,0.742122,0.741909,0.729507,0.025
stats.inv_kl_divergence.marginal,0.968301,0.968099,0.968104,0.96807,0.968135,0.968142,0.0001
stats.ks_test.marginal,0.933393,0.933041,0.933154,0.933464,0.933302,0.933271,0.0002
stats.max_mean_discrepancy.joint,0.000225,0.000225,0.000225,0.000225,0.000225,0.000225,0.0
stats.wasserstein_dist.joint,0.036289,0.036568,0.036415,0.03627,0.036611,0.036431,0.0001
stats.prdc.precision,0.971599,0.971825,0.969683,0.971487,0.971036,0.971126,0.0008
stats.prdc.recall,0.916601,0.918517,0.920545,0.916826,0.918291,0.918156,0.0014
stats.prdc.density,1.036966,1.037191,1.033247,1.035929,1.035276,1.035722,0.0014
stats.prdc.coverage,0.896427,0.895638,0.899696,0.89654,0.896315,0.896923,0.0014


In [None]:
# SUPPORT: one row of per-column p-values for each round.
p_values_df

Unnamed: 0,x0,x7,x8,x9,x10,x11,x12,x13,x1,x2,x3,x4,x5,x6
0,0.726121,3.310172e-12,1.548332e-36,0.001049,3.059714e-59,0.050561,0.027578,5.139781e-31,0.776252,0.902419,0.940691,0.974938,0.915735,0.87667
1,0.823218,3.828446e-12,7.179075e-37,0.001328,1.832767e-59,0.035229,0.032326,6.934507000000001e-31,0.691674,0.286341,0.880353,0.712349,1.0,0.864436
2,0.804119,6.81556e-13,1.5309069999999998e-36,0.001206,1.014069e-59,0.043853,0.031378,2.5035470000000003e-31,0.395874,0.963978,0.893514,0.4744,1.0,0.905785
3,0.978721,8.747006e-13,7.895356e-35,0.000542,1.655767e-58,0.050156,0.037892,4.9320980000000005e-31,0.707631,0.262724,0.841616,0.855166,0.947947,0.967275
4,0.764374,8.559673e-13,2.005448e-34,0.00309,1.46934e-59,0.054052,0.015606,1.304853e-31,0.69773,0.967451,0.857312,0.762023,0.555978,0.896069
