In [1]:
import pandas as pd 
import numpy as np
from scipy.stats import wasserstein_distance as wd
from scipy.stats import chisquare
from scipy import stats
import warnings
from tqdm import tqdm
from SIMPRO import simpro

In [2]:
def create_datasets(df1, other_dfs = ''):
    """Bundle one or more tables into a dict keyed 'd1', 'd2', ...

    Parameters
    ----------
    df1 : object
        The first table; always stored under the key 'd1'.
    other_dfs : iterable, pandas.DataFrame, None or '', optional
        Additional tables, stored in iteration order under 'd2', 'd3', ...
        ``''`` (the historical default) and ``None`` both mean "no extra
        tables". A single DataFrame is treated as a one-element list rather
        than being iterated (iterating a DataFrame yields its column labels).

    Returns
    -------
    dict
        Mapping from 'd<k>' to the k-th table.
    """
    output = {'d1': df1}

    # Check the type before comparing with '': `DataFrame != ''` is an
    # element-wise comparison and cannot be used as a truth value.
    no_extras = other_dfs is None or (isinstance(other_dfs, str) and other_dfs == '')
    if no_extras:
        return output

    if isinstance(other_dfs, pd.DataFrame):
        other_dfs = [other_dfs]

    for count, df in enumerate(other_dfs, start=2):
        output[f"d{count}"] = df
    return output

The task_id subgroup 31941 is shown as an example.

In [81]:
# Subgroup to evaluate. Every data path below interpolates this value
# (data/task_id_<task_id>/...), so this cell must run before any cell that loads data.
task_id = 31941

In [3]:
# Only the column labels of the two source tables are needed here, so read
# just the header row (nrows=0) instead of parsing every data row.
d1_col = pd.read_csv(f"data/task_id_{task_id}/ads.csv", nrows=0).columns
d2_col = pd.read_csv(f"data/task_id_{task_id}/feeds.csv", nrows=0).columns

# DEREC-REaLTabFormer

In [50]:
og_p = pd.read_csv(f"data/task_id_{task_id}/DEREC/original/unique_parent_small.csv")
og_c_1 = pd.read_csv(f"data/task_id_{task_id}/DEREC/original/ads_child_small.csv")
og_c_2 = pd.read_csv(f"data/task_id_{task_id}/DEREC/original/feeds_child_small.csv")

# Columns listed in d1_col / d2_col that the child table does not carry are
# taken from the parent table instead.
og_parent_cols_1 = ['user_id'] + list(d1_col.difference(og_c_1.columns))
og_parent_cols_2 = ['user_id'] + list(d2_col.difference(og_c_2.columns))

# Re-attach those parent columns to each child table via the shared user_id key.
og_1 = og_p[og_parent_cols_1].merge(og_c_1, on = 'user_id')
og_2 = og_p[og_parent_cols_2].merge(og_c_2, on = 'user_id')


In [52]:
syn_p = pd.read_csv(f"data/task_id_{task_id}/DEREC/synthetic/realtabformer_syn_parent.csv")
syn_c_1 = pd.read_csv(f"data/task_id_{task_id}/DEREC/synthetic/realtabformer_syn_child_ads.csv")
syn_c_2 = pd.read_csv(f"data/task_id_{task_id}/DEREC/synthetic/realtabformer_syn_child_feeds.csv")

# Columns listed in d1_col / d2_col that the synthetic child table does not
# carry are taken from the synthetic parent table instead.
syn_parent_cols_1 = ['user_id'] + list(d1_col.difference(syn_c_1.columns))
syn_parent_cols_2 = ['user_id'] + list(d2_col.difference(syn_c_2.columns))

# Re-attach those parent columns to each child table via the shared user_id key.
syn_1 = syn_p[syn_parent_cols_1].merge(syn_c_1, on = 'user_id')
syn_2 = syn_p[syn_parent_cols_2].merge(syn_c_2, on = 'user_id')


In [53]:
drop_list = ['e_et', 'pt_d', 'ad_close_list_v001', 'ad_close_list_v002', 'ad_close_list_v003', 'log_id']

# Remove the excluded columns from every frame. errors='ignore' skips labels a
# frame does not have, so each frame is checked against its OWN columns.
# The earlier loop tested og_2.columns before dropping from syn_2; because og_2
# had already been trimmed, that test was always False and syn_2 was never trimmed.
og_1 = og_1.drop(columns = drop_list, errors = 'ignore')
og_2 = og_2.drop(columns = drop_list, errors = 'ignore')
syn_1 = syn_1.drop(columns = drop_list, errors = 'ignore')
syn_2 = syn_2.drop(columns = drop_list, errors = 'ignore')


In [55]:
# Package the original and synthetic frames as {'d1': ..., 'd2': ...} dicts.
derec_og = create_datasets(og_1, [og_2])
derec_syn = create_datasets(syn_1, [syn_2])

# NOTE(review): simpro is imported from the SIMPRO module, which is not shown in
# this notebook; presumably it scores the synthetic dict against the original
# one — confirm against that module before relying on this description.
derec_evaluation = simpro(derec_og, derec_syn)
derec_evaluation.cal_marginal_indicators()
derec_evaluation.cal_conditional_indicators()


100%|██████████████████████████████████████████████████████████████████████████████| 1936/1936 [08:31<00:00,  3.79it/s]


# Control Group

In [75]:
# Control-group tables are read directly from CSV; no parent/child merge is done here.
# NOTE(review): og_1 is loaded from the feeds file and og_2 from the ads file,
# whereas the DEREC section builds og_1 from the ads child and og_2 from the
# feeds child — confirm this ordering is intended and lines up with syn_1/syn_2.
og_1 = pd.read_csv(f"data/task_id_{task_id}/Control Group/original/unique_feeds_small.csv")
og_2 = pd.read_csv(f"data/task_id_{task_id}/Control Group/original/unique_ads_small.csv")

# Synthetic counterparts (file names say parent/child, produced by REaLTabFormer).
syn_1 = pd.read_csv(f"data/task_id_{task_id}/Control Group/synthetic/realtabformer_syn_parent.csv")
syn_2 = pd.read_csv(f"data/task_id_{task_id}/Control Group/synthetic/realtabformer_syn_child.csv")

drop_list = ['e_et', 'pt_d', 'ad_close_list_v001', 'ad_close_list_v002', 'ad_close_list_v003', 'log_id']

# Remove the excluded columns from every frame. errors='ignore' skips labels a
# frame does not have, so each frame is checked against its OWN columns.
# The earlier loop tested og_2.columns before dropping from syn_2; because og_2
# had already been trimmed, that test was always False and syn_2 was never trimmed.
og_1 = og_1.drop(columns = drop_list, errors = 'ignore')
og_2 = og_2.drop(columns = drop_list, errors = 'ignore')
syn_1 = syn_1.drop(columns = drop_list, errors = 'ignore')
syn_2 = syn_2.drop(columns = drop_list, errors = 'ignore')

In [76]:
# Package the control-group original and synthetic frames as {'d1': ..., 'd2': ...} dicts.
cg_og = create_datasets(og_1, [og_2])
cg_syn = create_datasets(syn_1, [syn_2])

# NOTE(review): simpro is imported from the SIMPRO module, which is not shown in
# this notebook; presumably it scores the synthetic dict against the original
# one — confirm against that module before relying on this description.
cg_evaluation = simpro(cg_og, cg_syn)
cg_evaluation.cal_marginal_indicators()
cg_evaluation.cal_conditional_indicators()

100%|██████████████████████████████████████████████████████████████████████████████| 1936/1936 [05:02<00:00,  6.41it/s]


# CT-GAN 

The evaluation for CT-GAN takes an exceptionally long time because the CTGAN model synthesised a large number of categories in the form of integers, so there are many different conditional distributions for each column (the run below was interrupted manually at about 7% progress). This demonstrates an advantage of REaLTabFormer over CT-GAN: CT-GAN requires discrete columns to be specified in advance, whereas REaLTabFormer is capable of detecting them itself.

In [77]:
og = pd.read_csv(f"data/task_id_{task_id}/CTGAN/original_dataset.csv")
syn = pd.read_csv(f"data/task_id_{task_id}/CTGAN/synthetic_dataset.csv")

# Split each flat table back into two views, one per reference column list.
# A reference column is taken under its own name when present; otherwise the
# suffixed variant ("_x" for the d1_col view, "_y" for the d2_col view) is used
# and renamed back to the plain name. Columns found under neither name are skipped.
views = []
for frame, wanted, suffix in (
    (og, d1_col, "_x"),
    (og, d2_col, "_y"),
    (syn, d1_col, "_x"),
    (syn, d2_col, "_y"),
):
    source_labels = []
    plain_labels = []
    for name in wanted:
        if name in frame.columns:
            source_labels.append(name)
            plain_labels.append(name)
        elif f"{name}{suffix}" in frame.columns:
            source_labels.append(f"{name}{suffix}")
            plain_labels.append(name)

    view = frame[source_labels]
    view.columns = plain_labels
    views.append(view)

og_1, og_2, syn_1, syn_2 = views

drop_list = ['e_et', 'pt_d', 'ad_close_list_v001', 'ad_close_list_v002', 'ad_close_list_v003', 'log_id']

# Remove the excluded columns from every frame. errors='ignore' skips labels a
# frame does not have, so each frame is checked against its OWN columns.
# The earlier loop tested og_2.columns before dropping from syn_2; because og_2
# had already been trimmed, that test was always False and syn_2 was never trimmed.
og_1 = og_1.drop(columns = drop_list, errors = 'ignore')
og_2 = og_2.drop(columns = drop_list, errors = 'ignore')
syn_1 = syn_1.drop(columns = drop_list, errors = 'ignore')
syn_2 = syn_2.drop(columns = drop_list, errors = 'ignore')

In [78]:
# Package the CTGAN original and synthetic frames as {'d1': ..., 'd2': ...} dicts.
ctgan_og = create_datasets(og_1, [og_2])
ctgan_syn = create_datasets(syn_1, [syn_2])

# NOTE(review): simpro is imported from the SIMPRO module, which is not shown in
# this notebook; presumably it scores the synthetic dict against the original
# one — confirm against that module before relying on this description.
ctgan_evaluation = simpro(ctgan_og, ctgan_syn)
ctgan_evaluation.cal_marginal_indicators()
ctgan_evaluation.cal_conditional_indicators()

  7%|█████▎                                                                       | 135/1936 [14:58<3:19:50,  6.66s/it]


KeyboardInterrupt: 

# TabDDPM

In [79]:
og = pd.read_csv(f"data/task_id_{task_id}/TabDDPM/original_dataset.csv")
syn = pd.read_csv(f"data/task_id_{task_id}/TabDDPM/synthetic_dataset.csv")

# Split each flat table back into two views, one per reference column list.
# A reference column is taken under its own name when present; otherwise the
# suffixed variant ("_x" for the d1_col view, "_y" for the d2_col view) is used
# and renamed back to the plain name. Columns found under neither name are skipped.
views = []
for frame, wanted, suffix in (
    (og, d1_col, "_x"),
    (og, d2_col, "_y"),
    (syn, d1_col, "_x"),
    (syn, d2_col, "_y"),
):
    source_labels = []
    plain_labels = []
    for name in wanted:
        if name in frame.columns:
            source_labels.append(name)
            plain_labels.append(name)
        elif f"{name}{suffix}" in frame.columns:
            source_labels.append(f"{name}{suffix}")
            plain_labels.append(name)

    view = frame[source_labels]
    view.columns = plain_labels
    views.append(view)

og_1, og_2, syn_1, syn_2 = views

drop_list = ['e_et', 'pt_d', 'ad_close_list_v001', 'ad_close_list_v002', 'ad_close_list_v003', 'log_id']

# Remove the excluded columns from every frame. errors='ignore' skips labels a
# frame does not have, so each frame is checked against its OWN columns.
# The earlier loop tested og_2.columns before dropping from syn_2; because og_2
# had already been trimmed, that test was always False and syn_2 was never trimmed.
og_1 = og_1.drop(columns = drop_list, errors = 'ignore')
og_2 = og_2.drop(columns = drop_list, errors = 'ignore')
syn_1 = syn_1.drop(columns = drop_list, errors = 'ignore')
syn_2 = syn_2.drop(columns = drop_list, errors = 'ignore')

In [80]:
# Package the TabDDPM original and synthetic frames as {'d1': ..., 'd2': ...} dicts.
tabddpm_og = create_datasets(og_1, [og_2])
tabddpm_syn = create_datasets(syn_1, [syn_2])

# NOTE(review): simpro is imported from the SIMPRO module, which is not shown in
# this notebook; presumably it scores the synthetic dict against the original
# one — confirm against that module before relying on this description.
tabddpm_evaluation = simpro(tabddpm_og, tabddpm_syn)
tabddpm_evaluation.cal_marginal_indicators()
tabddpm_evaluation.cal_conditional_indicators()

100%|██████████████████████████████████████████████████████████████████████████████| 1936/1936 [03:08<00:00, 10.24it/s]


# All indicators are written to the all_result_record.xlsx file for logging and easier table formatting.