In [56]:
import pandas as pd
import numpy as np
from datetime import datetime

In [57]:
#dataset is downloaded from https://github.com/BigTuna08/Banksformer_ecml_2022/tree/4e998944825f04f94f6cb3a0058b4733fd70ccf5/czech/banksformer/data
real = pd.read_csv('../CTGAN/tr_by_acct_w_age.csv')
real

Unnamed: 0.1,Unnamed: 0,column_a,account_id,date,type,operation,amount,balance,k_symbol,age,tcode
0,0,149432,1,950324,CREDIT,CREDIT IN CASH,1000.0,1000.0,,29,cash_cr
1,1,157404,1,950413,CREDIT,COLLECTION FROM ANOTHER BANK,3679.0,4679.0,,29,bank_cr
2,2,158832,1,950423,CREDIT,CREDIT IN CASH,12600.0,17279.0,,29,cash_cr
3,3,162681,1,950430,CREDIT,,19.2,17298.2,INTEREST CREDITED,29,interest_cr
4,4,167083,1,950513,CREDIT,COLLECTION FROM ANOTHER BANK,3679.0,20977.2,,29,bank_cr
...,...,...,...,...,...,...,...,...,...,...,...
1056315,1056315,1033141,11382,981202,DEBIT,CASH WITHDRAWAL,25600.0,41114.4,,46,cash_db_nan
1056316,1056316,1040574,11382,981210,CREDIT,COLLECTION FROM ANOTHER BANK,46248.0,87362.4,,46,bank_cr
1056317,1056317,1050362,11382,981225,DEBIT,CASH WITHDRAWAL,6300.0,81062.4,,46,cash_db_nan
1056318,1056318,1053037,11382,981231,CREDIT,,311.3,81373.6,INTEREST CREDITED,46,interest_cr


In [58]:
# data preprocessing , some parts are the same as 
#https://github.com/BigTuna08/Banksformer_ecml_2022/blob/4e998944825f04f94f6cb3a0058b4733fd70ccf5/czech/banksformer/nb1_preprocess_czech.ipynb

real = real.sort_values(by = ["account_id", "date"])
czech_date_parser = lambda x: datetime.strptime(str(x), "%y%m%d")
real["datetime"] = real["date"].apply(czech_date_parser)
#df["datetime"] = pd.to_datetime(df["datetime"])

real["month"] = real["datetime"].dt.month
real["day"] = real["datetime"].dt.day
real["dow"] =  real["datetime"].dt.dayofweek
real["year"] = real["datetime"].dt.year
real["doy"] = real["datetime"].dt.dayofyear


real["td"] = real[["account_id", "datetime"]].groupby("account_id").diff()
real["td"] = real["td"].apply(lambda x: x.days)
real["td"].fillna(0.0, inplace=True)


real['raw_amount'] = real.apply(lambda row: row['amount'] if row['type'] == 'CREDIT' else -row['amount'], axis=1)


bin_edges = [17, 30, 40, 50, 60, 81]
labels = ['18-30', '31-40', '41-50', '51-60', '61+']
# Use pd.cut() to convert ages to categorical groups
real['age_group'] = pd.cut(real['age'], bins=bin_edges, labels=labels, right=False)
real['age_group'] = real['age_group'].astype('object')


conditions = [
    (real['day'] >= 1) & (real['day'] <= 10),
    (real['day'] > 10) & (real['day'] <= 20),
    (real['day'] > 20) & (real['day'] <= 31)
      ]
categories = ['first', 'middle', 'last']
real['dtme_cat'] = np.select(conditions, categories)


cat_code_fields = ['type', 'operation', 'k_symbol']
TCODE_SEP = "__"
# create tcode by concating fields in "cat_code_fields"
tcode = real[cat_code_fields[0]].astype(str)
for ccf in cat_code_fields[1:]:
    tcode += TCODE_SEP + real[ccf].astype(str)

real["tcode"] = tcode


In [59]:
real

Unnamed: 0.1,Unnamed: 0,column_a,account_id,date,type,operation,amount,balance,k_symbol,age,...,datetime,month,day,dow,year,doy,td,raw_amount,age_group,dtme_cat
0,0,149432,1,950324,CREDIT,CREDIT IN CASH,1000.0,1000.0,,29,...,1995-03-24,3,24,4,1995,83,0.0,1000.0,18-30,last
1,1,157404,1,950413,CREDIT,COLLECTION FROM ANOTHER BANK,3679.0,4679.0,,29,...,1995-04-13,4,13,3,1995,103,20.0,3679.0,18-30,middle
2,2,158832,1,950423,CREDIT,CREDIT IN CASH,12600.0,17279.0,,29,...,1995-04-23,4,23,6,1995,113,10.0,12600.0,18-30,last
3,3,162681,1,950430,CREDIT,,19.2,17298.2,INTEREST CREDITED,29,...,1995-04-30,4,30,6,1995,120,7.0,19.2,18-30,last
4,4,167083,1,950513,CREDIT,COLLECTION FROM ANOTHER BANK,3679.0,20977.2,,29,...,1995-05-13,5,13,5,1995,133,13.0,3679.0,18-30,middle
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1056315,1056315,1033141,11382,981202,DEBIT,CASH WITHDRAWAL,25600.0,41114.4,,46,...,1998-12-02,12,2,2,1998,336,2.0,-25600.0,41-50,first
1056316,1056316,1040574,11382,981210,CREDIT,COLLECTION FROM ANOTHER BANK,46248.0,87362.4,,46,...,1998-12-10,12,10,3,1998,344,8.0,46248.0,41-50,first
1056317,1056317,1050362,11382,981225,DEBIT,CASH WITHDRAWAL,6300.0,81062.4,,46,...,1998-12-25,12,25,4,1998,359,15.0,-6300.0,41-50,last
1056318,1056318,1053037,11382,981231,CREDIT,,311.3,81373.6,INTEREST CREDITED,46,...,1998-12-31,12,31,3,1998,365,6.0,311.3,41-50,last


In [62]:
real_df = real[['amount', 'tcode', 'month', 'dow', 'year', 'dtme_cat', 'age_group', 'td']]   #real data given to discriminator
real_df

Unnamed: 0,amount,tcode,month,dow,year,dtme_cat,age_group,td
0,1000.0,CREDIT__CREDIT IN CASH__nan,3,4,1995,last,18-30,0.0
1,3679.0,CREDIT__COLLECTION FROM ANOTHER BANK__nan,4,3,1995,middle,18-30,20.0
2,12600.0,CREDIT__CREDIT IN CASH__nan,4,6,1995,last,18-30,10.0
3,19.2,CREDIT__nan__INTEREST CREDITED,4,6,1995,last,18-30,7.0
4,3679.0,CREDIT__COLLECTION FROM ANOTHER BANK__nan,5,5,1995,middle,18-30,13.0
...,...,...,...,...,...,...,...,...
1056315,25600.0,DEBIT__CASH WITHDRAWAL__nan,12,2,1998,first,41-50,2.0
1056316,46248.0,CREDIT__COLLECTION FROM ANOTHER BANK__nan,12,3,1998,first,41-50,8.0
1056317,6300.0,DEBIT__CASH WITHDRAWAL__nan,12,4,1998,last,41-50,15.0
1056318,311.3,CREDIT__nan__INTEREST CREDITED,12,3,1998,last,41-50,6.0


In [63]:
synth = pd.read_csv('synth.csv')
synth

Unnamed: 0,amount,tcode,month,dow,year,dtme_cat,age_group,td
0,6000.984844,DEBIT__REMITTANCE TO ANOTHER BANK__INSURANCE P...,3,2,1997,first,51-60,1.173169
1,-788.373432,DEBIT__CASH WITHDRAWAL__PAYMENT ON STATEMENT,11,6,1996,last,31-40,4.108638
2,21545.802611,CREDIT__CREDIT IN CASH__nan,1,6,1997,first,31-40,0.918106
3,5787.214126,DEBIT__REMITTANCE TO ANOTHER BANK__INSURANCE P...,2,2,1997,first,51-60,5.593859
4,2515.699736,DEBIT__CASH WITHDRAWAL__nan,3,1,1997,first,31-40,6.379620
...,...,...,...,...,...,...,...,...
3995,55.376225,CREDIT__COLLECTION FROM ANOTHER BANK__OLD AGE ...,8,0,1998,last,61+,0.017258
3996,1252.254299,DEBIT__CASH WITHDRAWAL__nan,10,4,1996,middle,51-60,1.180114
3997,11908.408018,CREDIT__CREDIT IN CASH__nan,4,6,1998,middle,51-60,2.961073
3998,436.063440,DEBIT__CASH WITHDRAWAL__nan,1,1,1995,first,31-40,4.842274


In [64]:
conditions = [
    (synth['dtme_cat'] == 'first'),
    (synth['dtme_cat'] == 'middle'),
    (synth['dtme_cat'] == 'last')
]

# Create the 'day' column based on 'dtme_cat'
synth['day'] = np.select(conditions, [
    np.random.randint(1, 11, size=len(synth)),   # Random values between 1 and 10
    np.random.randint(10, 21, size=len(synth)),  # Random values between 10 and 20
    np.random.randint(20, 31, size=len(synth))   # Random values between 20 and 30
])


In [65]:
synth

Unnamed: 0,amount,tcode,month,dow,year,dtme_cat,age_group,td,day
0,6000.984844,DEBIT__REMITTANCE TO ANOTHER BANK__INSURANCE P...,3,2,1997,first,51-60,1.173169,8
1,-788.373432,DEBIT__CASH WITHDRAWAL__PAYMENT ON STATEMENT,11,6,1996,last,31-40,4.108638,23
2,21545.802611,CREDIT__CREDIT IN CASH__nan,1,6,1997,first,31-40,0.918106,3
3,5787.214126,DEBIT__REMITTANCE TO ANOTHER BANK__INSURANCE P...,2,2,1997,first,51-60,5.593859,7
4,2515.699736,DEBIT__CASH WITHDRAWAL__nan,3,1,1997,first,31-40,6.379620,8
...,...,...,...,...,...,...,...,...,...
3995,55.376225,CREDIT__COLLECTION FROM ANOTHER BANK__OLD AGE ...,8,0,1998,last,61+,0.017258,20
3996,1252.254299,DEBIT__CASH WITHDRAWAL__nan,10,4,1996,middle,51-60,1.180114,18
3997,11908.408018,CREDIT__CREDIT IN CASH__nan,4,6,1998,middle,51-60,2.961073,15
3998,436.063440,DEBIT__CASH WITHDRAWAL__nan,1,1,1995,first,31-40,4.842274,9


In [66]:
synth['account_id'] = (synth.index // 80) + 1

In [67]:
synth['type'] = synth['tcode'].str.split('__').str[0]
synth['raw_amount'] = synth.apply(lambda row: row['amount'] if row['type'] == 'CREDIT' else -row['amount'], axis=1)

In [68]:
cat_order = ['first', 'middle', 'last']
synth['dtme_cat_order'] = pd.Categorical(synth['dtme_cat'], categories=cat_order, ordered=True)

# Sort the DataFrame based on 'account_id', 'year', 'month', 'dow', and 'dtme_cat'
synth_sorted = synth.sort_values(['account_id', 'year', 'month', 'dtme_cat'])

In [69]:
synth_sorted

Unnamed: 0,amount,tcode,month,dow,year,dtme_cat,age_group,td,day,account_id,type,raw_amount,dtme_cat_order
56,1320.637478,DEBIT__REMITTANCE TO ANOTHER BANK__HOUSEHOLD,4,3,1993,middle,61+,2.810049,12,1,DEBIT,-1320.637478,middle
32,142.130913,DEBIT__CASH WITHDRAWAL__nan,1,1,1995,last,51-60,15.043193,23,1,DEBIT,-142.130913,last
45,-874.737235,DEBIT__CASH WITHDRAWAL__nan,1,0,1995,middle,51-60,3.768954,12,1,DEBIT,874.737235,middle
5,151.876689,CREDIT__nan__INTEREST CREDITED,6,3,1995,last,31-40,2.170649,29,1,CREDIT,151.876689,last
44,-125.579254,DEBIT__CASH WITHDRAWAL__nan,7,2,1995,last,31-40,3.282290,22,1,DEBIT,125.579254,last
...,...,...,...,...,...,...,...,...,...,...,...,...,...
3995,55.376225,CREDIT__COLLECTION FROM ANOTHER BANK__OLD AGE ...,8,0,1998,last,61+,0.017258,20,50,CREDIT,55.376225,last
3988,181.378304,CREDIT__nan__INTEREST CREDITED,10,2,1998,last,51-60,0.026866,24,50,CREDIT,181.378304,last
3960,2970.501907,DEBIT__REMITTANCE TO ANOTHER BANK__HOUSEHOLD,10,3,1998,middle,31-40,3.956926,14,50,DEBIT,-2970.501907,middle
3982,2699.952496,CREDIT__nan__INTEREST CREDITED,11,0,1998,last,18-30,5.350949,27,50,CREDIT,2699.952496,last


In [70]:
real

Unnamed: 0.1,Unnamed: 0,column_a,account_id,date,type,operation,amount,balance,k_symbol,age,...,datetime,month,day,dow,year,doy,td,raw_amount,age_group,dtme_cat
0,0,149432,1,950324,CREDIT,CREDIT IN CASH,1000.0,1000.0,,29,...,1995-03-24,3,24,4,1995,83,0.0,1000.0,18-30,last
1,1,157404,1,950413,CREDIT,COLLECTION FROM ANOTHER BANK,3679.0,4679.0,,29,...,1995-04-13,4,13,3,1995,103,20.0,3679.0,18-30,middle
2,2,158832,1,950423,CREDIT,CREDIT IN CASH,12600.0,17279.0,,29,...,1995-04-23,4,23,6,1995,113,10.0,12600.0,18-30,last
3,3,162681,1,950430,CREDIT,,19.2,17298.2,INTEREST CREDITED,29,...,1995-04-30,4,30,6,1995,120,7.0,19.2,18-30,last
4,4,167083,1,950513,CREDIT,COLLECTION FROM ANOTHER BANK,3679.0,20977.2,,29,...,1995-05-13,5,13,5,1995,133,13.0,3679.0,18-30,middle
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1056315,1056315,1033141,11382,981202,DEBIT,CASH WITHDRAWAL,25600.0,41114.4,,46,...,1998-12-02,12,2,2,1998,336,2.0,-25600.0,41-50,first
1056316,1056316,1040574,11382,981210,CREDIT,COLLECTION FROM ANOTHER BANK,46248.0,87362.4,,46,...,1998-12-10,12,10,3,1998,344,8.0,46248.0,41-50,first
1056317,1056317,1050362,11382,981225,DEBIT,CASH WITHDRAWAL,6300.0,81062.4,,46,...,1998-12-25,12,25,4,1998,359,15.0,-6300.0,41-50,last
1056318,1056318,1053037,11382,981231,CREDIT,,311.3,81373.6,INTEREST CREDITED,46,...,1998-12-31,12,31,3,1998,365,6.0,311.3,41-50,last


In [71]:
import numpy as np
conditions = [
(real['day'] >= 1) & (real['day'] <= 10),
(real['day'] > 10) & (real['day'] <= 20),
(real['day'] > 20) & (real['day'] <= 31)
    ]
    
categories = ['first', 'middle', 'last']

# Use numpy.select() to map the numbers to categories
real['dtme_cat'] = np.select(conditions, categories)

In [72]:
cat_order = ['first', 'middle', 'last']
real['dtme_cat_order'] = pd.Categorical(real['dtme_cat'], categories=cat_order, ordered=True)

# Sort the DataFrame based on 'account_id', 'year', 'month', 'dow', and 'dtme_cat'
real_sorted = real.sort_values(['account_id', 'year', 'month', 'dtme_cat'])

## Wasserstein-1 distances for univariate amount (Amt) and Cash Flow

In [73]:
from scipy.stats import energy_distance
from scipy.stats import wasserstein_distance
from scipy.stats import ks_2samp

def ks_dist(real_obs, gen_obs):
    stat, pval = ks_2samp(real_obs, gen_obs)
    
    return stat

# the column names of continous features
CONT_FIELDS = ["amount", "td"]
CONTINUOUS_METRICS = {"wasser": wasserstein_distance,
                "ks": ks_dist,
                "energy_d": energy_distance}
univariate_cont_res = {}

for field in CONT_FIELDS:
    univariate_cont_res[field] = {}
    for name, fn in CONTINUOUS_METRICS.items():
        univariate_cont_res[field][name] = fn(real[field], synth[field])

In [74]:
univariate_cont_res

{'amount': {'wasser': 1626.505185236183,
  'ks': 0.17368312632535596,
  'energy_d': 12.757106607028666},
 'td': {'wasser': 1.0189492696349403,
  'ks': 0.12718320205998185,
  'energy_d': 0.319319711501024}}

In [75]:
gbr = real[["account_id", "month", "raw_amount", "year"]].groupby(["account_id", "month", "year"],as_index=False)["raw_amount"].sum()
gbr

Unnamed: 0,account_id,month,year,raw_amount
0,1,1,1996,-4593.5
1,1,1,1997,157.8
2,1,1,1998,-1621.7
3,1,2,1996,1372.1
4,1,2,1997,1770.7
...,...,...,...,...
185052,11382,11,1998,12336.5
185053,11382,12,1995,10563.6
185054,11382,12,1996,-5313.8
185055,11382,12,1997,17677.7


In [76]:
# Calculate the sum of 'raw_amount' for each 'account_id' in each 'month'
real_cf = real[["account_id", "month", "raw_amount", "year"]].groupby(["account_id", "month", "year"],as_index=False)["raw_amount"].sum()
synth_cf = synth[["account_id", "month", "raw_amount", "year"]].groupby(["account_id", "month", "year"],as_index=False)["raw_amount"].sum()

In [77]:
field = 'raw_amount'
univariate_cont_res['CF'] = {}
for name, fn in CONTINUOUS_METRICS.items():
    univariate_cont_res['CF'][name] = fn(real_cf[field], synth_cf[field])

In [78]:
univariate_cont_res

{'amount': {'wasser': 1626.505185236183,
  'ks': 0.17368312632535596,
  'energy_d': 12.757106607028666},
 'td': {'wasser': 1.0189492696349403,
  'ks': 0.12718320205998185,
  'energy_d': 0.319319711501024},
 'CF': {'wasser': 2542.3328818878304,
  'ks': 0.25618140634954634,
  'energy_d': 19.675372821093546}}

## JSD between the distributions of tcode 3-grams

In [79]:
import nltk
from nltk.util import ngrams
from scipy.special import rel_entr
from scipy.special import entr
from scipy.spatial import distance



In [80]:
def create_ngramcount_df(df, n, field):
    #gb = df.sort_values(by=["account_id", "datetime"]).groupby("account_id", sort=False)[field]
    gb = df.groupby("account_id", sort=False)[field]
    ngram_list = gb.apply(lambda x: list(ngrams(x, n=n)))

    counts = {}
    for ngram_seq in ngram_list:
        for ngram in ngram_seq:
            ngram = str(ngram)[1:-1]
            counts[ngram] = counts.get(ngram, 0) + 1
            
            
    df = pd.DataFrame.from_dict(counts, orient="index", columns=["counts"]).sort_values("counts", ascending=False)
    
            
    return df.reset_index().rename(columns={"index": "ngram"})

def compute_ngram_metrics(real_df, gen_df, field, n , pseudo_counts=0.0):

    
    n_codes_unique = len(set(real_df[field].unique()).union(set(gen_df[field].unique())))

    
    # create combo_df, which contains counts of all ngrams for both datasets (note: it omits any ngrams which do not occur in either dataset)
    real_ngrams = create_ngramcount_df(real_df, n, field)
    gen_ngrams = create_ngramcount_df(gen_df, n, field)
    combo_df = pd.merge(real_ngrams, gen_ngrams, on="ngram", how="outer", suffixes=("_real", "_gen")).fillna(0.0)


    N_obs_real = real_ngrams["counts"].sum()
    N_obs_gen = gen_ngrams["counts"].sum()
    N_possible_ngrams = n_codes_unique**n 

    
    # add psudo-counts
    combo_df["counts_real"] += pseudo_counts
    combo_df["ps_real"] = combo_df["counts_real"] / (N_obs_real + N_possible_ngrams*pseudo_counts)
    combo_df["counts_gen"] += pseudo_counts
    combo_df["ps_gen"] = combo_df["counts_gen"] / (N_obs_gen + N_possible_ngrams*pseudo_counts)
    

        
        
    # compute jsd (note: contribution to jsd from any ngram not in either dataset is 0)
    combo_df["ps_mid"] = (combo_df["ps_real"] + combo_df["ps_gen"])/2
    kl_real_M = sum(rel_entr(combo_df["ps_real"], combo_df["ps_mid"])) 
    kl_gen_M = sum(rel_entr(combo_df["ps_gen"], combo_df["ps_mid"]))

    jsd = (kl_real_M + kl_gen_M)/2
        
        
    # compute entropy for both distributions
    n_unobs = N_possible_ngrams - len(combo_df)

    entr_r = entr(combo_df["ps_real"]).sum()  # from observed
    
    entr_g = entr(combo_df["ps_gen"]).sum()  # from observed
    
    results = {"jsd":jsd, 
                      "entr_r":entr_r, 
                      "entr_g":entr_g,
                      "NED": entr_r - entr_g,
                      "l1":distance.minkowski(combo_df["ps_real"], combo_df["ps_gen"], p=1), 
                      "l2":distance.minkowski(combo_df["ps_real"], combo_df["ps_gen"], p=2),
                      "jac": distance.jaccard(combo_df["counts_real"]>0, combo_df["counts_gen"] > 0),
                      "count_r": len(real_ngrams),
                      "coverage_r": len(real_ngrams)/N_possible_ngrams,
                      "count_g": len(gen_ngrams),
                      "coverage_g": len(gen_ngrams)/N_possible_ngrams,
                      "count_max": N_possible_ngrams,
                      "field": field, 
                       "n":n, 
                       "pseudo_counts":pseudo_counts}
        
    return combo_df, results

In [82]:
combo_df2, result1 = compute_ngram_metrics(real_sorted, synth_sorted, 'tcode', 3)

In [83]:
result1


{'jsd': 0.20909098792798952,
 'entr_r': 5.482565585140385,
 'entr_g': 5.423129455171868,
 'NED': 0.05943612996851755,
 'l1': 0.9729732151513459,
 'l2': 0.08020194891820201,
 'jac': 0.7232824427480916,
 'count_r': 1469,
 'coverage_r': 0.358642578125,
 'count_g': 538,
 'coverage_g': 0.13134765625,
 'count_max': 4096,
 'field': 'tcode',
 'n': 3,
 'pseudo_counts': 0.0}

## JSD result comparing the univariate distributions of the tcode (Tcode)

In [84]:
real_distribution = real['tcode'].value_counts(normalize=True).sort_index()
synthetic_distribution = synth['tcode'].value_counts(normalize=True).sort_index()
# Add missing categories to the synthetic distribution with 0 probability


In [85]:
df_tcode = pd.merge(real_distribution, synthetic_distribution, left_index=True, right_index=True, how='outer')
df_tcode.columns = ['real', 'synthetic']

# Fill missing values with 0
df_tcode.fillna(0, inplace=True)

In [86]:
df_tcode['mid'] = (df_tcode['real'] + df_tcode['synthetic'])/2

In [87]:
kl_real_M = sum(rel_entr(df_tcode['real'], df_tcode['mid'])) 
kl_gen_M = sum(rel_entr(df_tcode['synthetic'], df_tcode['mid']))

jsd = (kl_real_M + kl_gen_M)/2

In [88]:
jsd

0.011530278901628317

## JSD result comparing the univariate distributions of the dtme_cat

In [89]:
real_distribution2 = real['dtme_cat'].value_counts(normalize=True).sort_index()
synthetic_distribution2 = synth['dtme_cat'].value_counts(normalize=True).sort_index()
df_dom = pd.merge(real_distribution2, synthetic_distribution2, left_index=True, right_index=True, how='outer')
df_dom.columns = ['real', 'synthetic']
# Fill missing values with 0
df_dom.fillna(0, inplace=True)
df_dom['mid'] = (df_dom['real'] + df_dom['synthetic'])/2
kl_real_M = sum(rel_entr(df_dom['real'], df_dom['mid'])) 
kl_gen_M = sum(rel_entr(df_dom['synthetic'], df_dom['mid']))

jsd = (kl_real_M + kl_gen_M)/2
jsd


0.0002999375208018228

## JSD result comparing the univariate distributions of the DOM

In [90]:
real_distribution3 = real['day'].value_counts(normalize=True).sort_index()
synthetic_distribution3 = synth['day'].value_counts(normalize=True).sort_index()
df_dom = pd.merge(real_distribution3, synthetic_distribution3, left_index=True, right_index=True, how='outer')
df_dom.columns = ['real', 'synthetic']
# Fill missing values with 0
df_dom.fillna(0, inplace=True)
df_dom['mid'] = (df_dom['real'] + df_dom['synthetic'])/2
kl_real_M = sum(rel_entr(df_dom['real'], df_dom['mid'])) 
kl_gen_M = sum(rel_entr(df_dom['synthetic'], df_dom['mid']))

jsd = (kl_real_M + kl_gen_M)/2
jsd

0.1428194689483469

## compare to the results in paper

In [91]:
data = {
    'Model': ['SeqCTGAN', 'BF', 'DG', 'TG'],
    'Amt': [1626, 2102, 1939, 1931],
    'CF': [2542, 2738, 57800, 4980],
    'Tcode': [0.01, 0.004, 0.007, 0.075],
    'DoM': [0.14, 0.011, 0.09, 0.059],
    'Tcode 3G': [0.2, 0.042, 0.132, 0.337]
}

In [92]:
df = pd.DataFrame(data)
df

Unnamed: 0,Model,Amt,CF,Tcode,DoM,Tcode 3G
0,SeqCTGAN,1626,2542,0.01,0.14,0.2
1,BF,2102,2738,0.004,0.011,0.042
2,DG,1939,57800,0.007,0.09,0.132
3,TG,1931,4980,0.075,0.059,0.337


## synth2

In [93]:
import pandas as pd
synth2 = pd.read_csv('synth2.csv')
synth2['type'] = synth2['tcode'].str.split('__').str[0]
synth2['raw_amount'] = synth2.apply(lambda row: row['amount'] if row['type'] == 'CREDIT' else -row['amount'], axis=1)
# Sort the DataFrame based on 'account_id', 'year', 'month', 'dow', and 'dtme_cat'
synth_sorted2 = synth2.sort_values(['account_id', 'year', 'month', 'day'])


In [94]:
synth_sorted2

Unnamed: 0,amount,tcode,month,dow,year,day,account_id,type,raw_amount
19,2256.261628,DEBIT__CASH WITHDRAWAL__nan,1,0,1996,2,1,DEBIT,-2256.261628
130,2596.664712,DEBIT__CASH WITHDRAWAL__nan,1,5,1996,2,1,DEBIT,-2596.664712
161,25298.418754,DEBIT__REMITTANCE TO ANOTHER BANK__,1,5,1996,9,1,DEBIT,-25298.418754
91,1243.841528,DEBIT__CASH WITHDRAWAL__nan,1,0,1996,25,1,DEBIT,-1243.841528
29,2386.010944,DEBIT__CASH WITHDRAWAL__nan,1,5,1996,27,1,DEBIT,-2386.010944
...,...,...,...,...,...,...,...,...,...
232303,8614.929748,DEBIT__REMITTANCE TO ANOTHER BANK__HOUSEHOLD,12,4,1998,25,1000,DEBIT,-8614.929748
232001,-134.654998,CREDIT__nan__INTEREST CREDITED,12,0,1998,31,1000,CREDIT,-134.654998
232065,7.044088,CREDIT__nan__INTEREST CREDITED,12,2,1998,31,1000,CREDIT,7.044088
232088,-417.322429,DEBIT__CASH WITHDRAWAL__PAYMENT ON STATEMENT,12,3,1998,31,1000,DEBIT,417.322429


In [95]:
from scipy.stats import energy_distance
from scipy.stats import wasserstein_distance
from scipy.stats import ks_2samp

def ks_dist(real_obs, gen_obs):
    stat, pval = ks_2samp(real_obs, gen_obs)
    
    return stat

# the column names of continous features
CONT_FIELDS = ["amount"]
CONTINUOUS_METRICS = {"wasser": wasserstein_distance,
                "ks": ks_dist,
                "energy_d": energy_distance}
univariate_cont_res = {}

for field in CONT_FIELDS:
    univariate_cont_res[field] = {}
    for name, fn in CONTINUOUS_METRICS.items():
        univariate_cont_res[field][name] = fn(real[field], synth2[field])

In [96]:
univariate_cont_res

{'amount': {'wasser': 1947.9626640897031,
  'ks': 0.24419681840801521,
  'energy_d': 16.1206782223226}}

In [97]:
# Calculate the sum of 'raw_amount' for each 'account_id' in each 'month'
real_cf = real[["account_id", "month", "raw_amount", "year"]].groupby(["account_id", "month", "year"],as_index=False)["raw_amount"].sum()
synth_cf2 = synth2[["account_id", "month", "raw_amount", "year"]].groupby(["account_id", "month", "year"],as_index=False)["raw_amount"].sum()

In [98]:
real_cf

Unnamed: 0,account_id,month,year,raw_amount
0,1,1,1996,-4593.5
1,1,1,1997,157.8
2,1,1,1998,-1621.7
3,1,2,1996,1372.1
4,1,2,1997,1770.7
...,...,...,...,...
185052,11382,11,1998,12336.5
185053,11382,12,1995,10563.6
185054,11382,12,1996,-5313.8
185055,11382,12,1997,17677.7


In [99]:
field = 'raw_amount'
univariate_cont_res['CF'] = {}
for name, fn in CONTINUOUS_METRICS.items():
    univariate_cont_res['CF'][name] = fn(real_cf[field], synth_cf2[field])

In [100]:
univariate_cont_res

{'amount': {'wasser': 1947.9626640897031,
  'ks': 0.24419681840801521,
  'energy_d': 16.1206782223226},
 'CF': {'wasser': 23146.255661084673,
  'ks': 0.41126547304691563,
  'energy_d': 102.82834319994029}}

In [101]:
# Sort the DataFrame based on 'account_id', 'year', 'month', 'dow', and 'dtme_cat'
real_sorted = real.sort_values(['account_id', 'year', 'month', 'day'])
# JSD between the distributions of tcode 3-grams
combo_df, result = compute_ngram_metrics(real_sorted, synth_sorted2, 'tcode', 3)
result['jsd']

0.2578011887289422

In [102]:
#JSD results comparing the univariate distributions of the tcode (Tcode)
real_distribution = real['tcode'].value_counts(normalize=True).sort_index()
synthetic_distribution = synth2['tcode'].value_counts(normalize=True).sort_index()
df_tcode = pd.merge(real_distribution, synthetic_distribution, left_index=True, right_index=True, how='outer')
df_tcode.columns = ['real', 'synthetic']

# Fill missing values with 0
df_tcode.fillna(0, inplace=True)
df_tcode['mid'] = (df_tcode['real'] + df_tcode['synthetic'])/2
kl_real_M = sum(rel_entr(df_tcode['real'], df_tcode['mid'])) 
kl_gen_M = sum(rel_entr(df_tcode['synthetic'], df_tcode['mid']))

jsd = (kl_real_M + kl_gen_M)/2
jsd

0.02852238898191151

In [103]:
#JSD results comparing the univariate distributions of the transaction day of the month (DoM)
real_distribution2 = real['day'].value_counts(normalize=True).sort_index()
synthetic_distribution2 = synth2['day'].value_counts(normalize=True).sort_index()
df_dom = pd.merge(real_distribution2, synthetic_distribution2, left_index=True, right_index=True, how='outer')
df_dom.columns = ['real', 'synthetic']
# Fill missing values with 0
df_dom.fillna(0, inplace=True)
df_dom['mid'] = (df_dom['real'] + df_dom['synthetic'])/2
kl_real_M = sum(rel_entr(df_dom['real'], df_dom['mid'])) 
kl_gen_M = sum(rel_entr(df_dom['synthetic'], df_dom['mid']))

jsd = (kl_real_M + kl_gen_M)/2
jsd


0.02261430100935216

In [104]:
data = {
    'Model': ['SeqCTGAN2', 'SeqCTGAN1', 'BF', 'DG', 'TG'],
    'Amt': [1947, 1626, 2102, 1939, 1931],
    'CF': [23146, 2542, 2738, 57800, 4980],
    'Tcode': [0.028,0.01, 0.004, 0.007, 0.075],
    'DoM': [0.02, 0.14, 0.011, 0.09, 0.059],
    'Tcode 3G': [0.25, 0.2, 0.042, 0.132, 0.337],
}

In [105]:
df = pd.DataFrame(data)
df

Unnamed: 0,Model,Amt,CF,Tcode,DoM,Tcode 3G
0,SeqCTGAN2,1947,23146,0.028,0.02,0.25
1,SeqCTGAN1,1626,2542,0.01,0.14,0.2
2,BF,2102,2738,0.004,0.011,0.042
3,DG,1939,57800,0.007,0.09,0.132
4,TG,1931,4980,0.075,0.059,0.337
