# Loading libaries

In [1]:
import numpy as np
import pandas as pd
import plotly.express as px
from scipy import stats
from copy import deepcopy
from plotly.subplots import make_subplots
from catch22 import catch22_all

In [2]:
label = 1

# Extract Catch22 features

In [3]:
timegan_data = np.load(f'generated_datasets/timeGAN_EEG_Eye_State_ZeroOne_chop_5best_{label}.npy')
groupgan_data = np.load(f'generated_datasets/GroupGAN_EEG_Eye_State_ZeroOne_chop_5best_{label}.npy')
ff_data = np.load(f'generated_datasets/FF_EEG_Eye_State_ZeroOne_chop_5best_{label}.npy')

In [4]:
real_data = np.array(pd.read_csv(f'/Users/aliseyfi/Documents/UBC/Research/GroupGAN-private/Dataset/EEG_Eye_State_ZeroOne_chop_5best_{label}.csv'))
real_data = real_data.reshape(real_data.shape[0], -1, timegan_data.shape[-1])

In [5]:
n_groups = timegan_data.shape[-1]

In [6]:
column_names = catch22_all(groupgan_data[0,:,0])['names']

In [7]:
for i in range(n_groups):
    features_generated_all = np.zeros((groupgan_data.shape[0], 22))
    for ind, data in enumerate(groupgan_data[:,:,i]):
        features_generated_all[ind] = catch22_all(data)['values']

    features_generated_all_df = pd.DataFrame(features_generated_all, columns=column_names)
    # save features_df to csv
    features_generated_all_df.to_csv(f'Catch22/GroupGAN_{i}.csv')

In [8]:
for i in range(n_groups):
    features_generated_all_timegan = np.zeros((timegan_data.shape[0], 22))
    for ind_timegan, data_timegan in enumerate(timegan_data[:,:,i]):
        features_generated_all_timegan[ind_timegan] = catch22_all(data_timegan)['values']

    features_generated_all_df_timegan = pd.DataFrame(features_generated_all_timegan, columns=column_names)
    # save features_df to csv
    features_generated_all_df_timegan.to_csv(f'Catch22/TimeGAN_{i}.csv')

In [9]:
for i in range(n_groups):
    features_generated_all_real = np.zeros((real_data.shape[0], 22))
    for ind_real, data_real in enumerate(real_data[:,:,i]):
        features_generated_all_real[ind_real] = catch22_all(data_real)['values']

    features_generated_all_df_real = pd.DataFrame(features_generated_all_real, columns=column_names)
    # save features_df to csv
    features_generated_all_df_real.to_csv(f'Catch22/real_{i}.csv')

In [10]:
for i in range(n_groups):
    features_generated_all_ff = np.zeros((ff_data.shape[0], 22))
    for ind_ff, data_ff in enumerate(ff_data[:,:,i]):
        features_generated_all_ff[ind_ff] = catch22_all(data_ff)['values']

    features_generated_all_df_ff = pd.DataFrame(features_generated_all_ff, columns=column_names)
    # save features_df to csv
    features_generated_all_df_ff.to_csv(f'Catch22/FF_{i}.csv')

# EEG

## Loading Data

In [11]:
DIR = 'Catch22/'

In [12]:
df_real = {}
df_timegan = {}
df_groupgan = {}
df_ff = {}

for i in range(n_groups):
    df_real[i] = pd.read_csv(f'{DIR}real_{i}.csv', index_col=0).add_prefix(f'real_{i}_')
    df_timegan[i] = pd.read_csv(f'{DIR}TimeGAN_{i}.csv', index_col=0).add_prefix(f'timegan_{i}_')
    df_groupgan[i] = pd.read_csv(f'{DIR}GroupGAN_{i}.csv', index_col=0).add_prefix(f'groupgan_{i}_')
    df_ff[i] = pd.read_csv(f'{DIR}FF_{i}.csv', index_col=0).add_prefix(f'ff_{i}_')

df_real[0]

Unnamed: 0,real_0_DN_HistogramMode_5,real_0_DN_HistogramMode_10,real_0_CO_f1ecac,real_0_CO_FirstMin_ac,real_0_CO_HistogramAMI_even_2_5,real_0_CO_trev_1_num,real_0_MD_hrv_classic_pnn40,real_0_SB_BinaryStats_mean_longstretch1,real_0_SB_TransitionMatrix_3ac_sumdiagcov,real_0_PD_PeriodicityWang_th0_01,...,real_0_FC_LocalSimple_mean1_tauresrat,real_0_DN_OutlierInclude_p_001_mdrmd,real_0_DN_OutlierInclude_n_001_mdrmd,real_0_SP_Summaries_welch_rect_area_5_1,real_0_SB_BinaryStats_diff_longstretch0,real_0_SB_MotifThree_quantile_hh,real_0_SC_FluctAnal_2_rsrangefit_50_1_logi_prop_r1,real_0_SC_FluctAnal_2_dfa_50_1_2_logi_prop_r1,real_0_SP_Summaries_welch_rect_centroid,real_0_FC_LocalSimple_mean3_stderr
0,-0.341505,-0.542205,5.0,11.0,0.480219,-0.510614,0.888889,21.0,0.032922,30.0,...,0.090909,0.250,-0.085,0.832639,6.0,1.831255,0.828571,0.714286,0.196350,0.681979
1,0.462960,-0.532909,5.0,9.0,0.472877,-0.096977,0.888889,22.0,0.034014,29.0,...,0.076923,0.220,-0.080,0.833623,6.0,1.755639,0.828571,0.685714,0.196350,0.691626
2,0.463938,0.261536,5.0,9.0,0.471295,-0.327071,0.919192,23.0,0.006803,28.0,...,0.071429,-0.200,-0.080,0.828017,6.0,1.748904,0.828571,0.714286,0.196350,0.689431
3,0.466615,0.260974,5.0,20.0,0.431445,-0.392937,0.929293,21.0,0.013605,27.0,...,0.071429,-0.210,-0.050,0.815713,4.0,1.737983,0.828571,0.714286,0.196350,0.688082
4,-0.211136,1.148033,5.0,20.0,0.424617,-0.186014,0.939394,22.0,0.054422,27.0,...,0.071429,-0.225,-0.020,0.830187,4.0,1.737983,0.828571,0.742857,0.196350,0.662210
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1019,-0.204267,0.013195,5.0,6.0,0.265351,0.085203,0.939394,17.0,0.074830,2.0,...,0.071429,-0.020,0.750,0.699909,5.0,1.956283,0.828571,0.828571,0.196350,0.760793
1020,-0.141329,0.084448,5.0,6.0,0.230401,0.207643,0.969697,16.0,0.013605,7.0,...,0.071429,0.080,0.740,0.640727,5.0,1.945973,0.828571,0.828571,0.245437,0.822684
1021,-0.075538,0.154464,3.0,3.0,0.214318,0.098282,0.969697,15.0,0.027211,3.0,...,0.071429,0.080,0.730,0.592499,5.0,1.946834,0.828571,0.828571,0.343612,0.884301
1022,-0.071338,0.157405,3.0,5.0,0.235900,0.225631,0.969697,14.0,0.054422,7.0,...,0.071429,0.070,0.720,0.549560,6.0,1.915905,0.828571,0.828571,0.490874,0.915448


## Correlation Analysis

### Computing Correlation

In [13]:
# Channel i and j

i = 2
j = 3

real = pd.concat([df_real[i], df_real[j]], axis=1).corr()
real = real[list(df_real[j].columns)].loc[list(df_real[i].columns)]

timegan = pd.concat([df_timegan[i], df_timegan[j]], axis=1).corr()
timegan = timegan[list(df_timegan[j].columns)].loc[list(df_timegan[i].columns)]

groupgan = pd.concat([df_groupgan[i], df_groupgan[j]], axis=1).corr()
groupgan = groupgan[list(df_groupgan[j].columns)].loc[list(df_groupgan[i].columns)]

ff = pd.concat([df_ff[i], df_ff[j]], axis=1).corr()
ff = ff[list(df_ff[j].columns)].loc[list(df_ff[i].columns)]


indices_real = real.columns[real.isna().all()].tolist()

features = list(set(['SC_FluctAnal_2_rsrangefit_50_1_logi_prop_r1', 'SC_FluctAnal_2_dfa_50_1_2_logi_prop_r1']))

real_features_j = ['real_' + str(j) + '_' + feature for feature in features]
timegan_features_j = ['timegan_' + str(j) + '_' + feature for feature in features]
groupgan_features_j = ['groupgan_' + str(j) + '_' + feature for feature in features]
ff_features_j = ['ff_' + str(j) + '_' + feature for feature in features]

real_features_i = ['real_' + str(i) + '_' + feature for feature in features]
timegan_features_i = ['timegan_' + str(i) + '_' + feature for feature in features]
groupgan_features_i = ['groupgan_' + str(i) + '_' + feature for feature in features]
ff_features_i = ['ff_' + str(i) + '_' + feature for feature in features]

real_not_nan = real.drop(real_features_j, axis=1)
timegan_not_nan = timegan.drop(timegan_features_j, axis=1)
groupgan_not_nan = groupgan.drop(groupgan_features_j, axis=1)
ff_not_nan = ff.drop(ff_features_j, axis=1)

real_not_nan = real_not_nan.drop(real_features_i, axis=0)
timegan_not_nan = timegan_not_nan.drop(timegan_features_i, axis=0)
groupgan_not_nan = groupgan_not_nan.drop(groupgan_features_i, axis=0)
ff_not_nan = ff_not_nan.drop(ff_features_i, axis=0)

In [14]:
real_not_nan

Unnamed: 0,real_3_DN_HistogramMode_5,real_3_DN_HistogramMode_10,real_3_CO_f1ecac,real_3_CO_FirstMin_ac,real_3_CO_HistogramAMI_even_2_5,real_3_CO_trev_1_num,real_3_MD_hrv_classic_pnn40,real_3_SB_BinaryStats_mean_longstretch1,real_3_SB_TransitionMatrix_3ac_sumdiagcov,real_3_PD_PeriodicityWang_th0_01,real_3_CO_Embed2_Dist_tau_d_expfit_meandiff,real_3_IN_AutoMutualInfoStats_40_gaussian_fmmi,real_3_FC_LocalSimple_mean1_tauresrat,real_3_DN_OutlierInclude_p_001_mdrmd,real_3_DN_OutlierInclude_n_001_mdrmd,real_3_SP_Summaries_welch_rect_area_5_1,real_3_SB_BinaryStats_diff_longstretch0,real_3_SB_MotifThree_quantile_hh,real_3_SP_Summaries_welch_rect_centroid,real_3_FC_LocalSimple_mean3_stderr
real_2_DN_HistogramMode_5,0.452799,0.329409,-0.12355,-0.022167,-0.148327,-0.035484,-0.07482,0.011939,-0.026233,-0.019993,-0.08741,-0.171422,0.084851,-0.04519,-0.154021,-0.125617,0.071505,-0.026179,0.129467,0.157603
real_2_DN_HistogramMode_10,0.349134,0.504495,-0.129595,-0.033936,-0.066421,0.066015,-0.030432,-0.015309,-0.014064,0.010654,-0.061696,-0.115335,0.029732,0.003759,-0.237511,-0.059188,0.037424,-0.057967,0.046504,0.123825
real_2_CO_f1ecac,-0.112753,-0.078579,0.861768,0.367658,0.669522,0.104292,-0.2828,0.552208,0.300193,0.190436,0.53404,0.391757,-0.377206,-0.291296,0.140837,0.736971,0.260827,-0.560335,-0.711107,-0.794648
real_2_CO_FirstMin_ac,-0.104301,-0.127055,0.415074,0.672852,0.505037,0.046731,-0.336558,0.3968,0.075732,0.462064,0.52448,0.571662,-0.053212,-0.42778,0.152497,0.524654,0.251718,-0.580124,-0.332437,-0.481569
real_2_CO_HistogramAMI_even_2_5,-0.114582,-0.045908,0.699843,0.479052,0.841076,-0.014731,-0.505424,0.58529,0.268183,0.329599,0.597635,0.532943,-0.218221,-0.469243,0.287862,0.766117,0.42179,-0.697405,-0.571913,-0.804017
real_2_CO_trev_1_num,-0.076298,0.091063,0.089586,0.078195,0.067118,0.739667,0.06125,0.060616,0.031277,0.093818,0.141113,0.095664,-0.041858,0.082705,-0.037651,0.157066,0.063529,-0.067543,-0.200845,-0.058601
real_2_MD_hrv_classic_pnn40,-0.014536,0.034587,-0.326224,-0.427199,-0.474178,0.145831,0.921726,-0.284647,-0.343337,-0.270968,-0.482586,-0.287063,-0.131845,0.430766,-0.406393,-0.344727,-0.809592,0.659066,0.200953,0.424774
real_2_SB_BinaryStats_mean_longstretch1,-0.045089,-0.088127,0.490664,0.387995,0.572774,-0.003911,-0.217786,0.758885,0.105841,0.222128,0.387896,0.367573,-0.259298,-0.506896,0.16447,0.493464,0.203918,-0.409555,-0.355114,-0.520693
real_2_SB_TransitionMatrix_3ac_sumdiagcov,-0.044895,-0.000445,0.29371,0.124912,0.192208,-0.001033,-0.314753,0.176498,0.633771,0.038555,0.228209,0.087391,-0.318652,-0.127901,0.169791,0.192214,0.351961,-0.26141,-0.192905,-0.319549
real_2_PD_PeriodicityWang_th0_01,-0.022649,0.002646,0.264224,0.470859,0.32718,0.065796,-0.217034,0.257212,0.009819,0.547794,0.380502,0.420276,-0.002829,-0.24648,0.043863,0.418889,0.196881,-0.420365,-0.267647,-0.332869


### Plotting Correlation

In [15]:
fig = make_subplots(rows=1, cols=4, subplot_titles=('Real Data',
                                     'Group GAN', 'Time GAN', 'Fourier Flow'))

fig.add_trace(
    px.imshow(real_not_nan, 
                labels=dict(x="Channel 1", y="Channel 2", color="Correlation"),
                color_continuous_scale=px.colors.diverging.RdBu, 
                range_color=(-1, 1)).data[0], row=1, col=1
)
fig.add_trace(
    px.imshow(timegan_not_nan, 
                labels=dict(x="Channel 1", y="Channel 2", color="Correlation"),
                color_continuous_scale=px.colors.diverging.RdBu, 
                range_color=(-1, 1)).data[0], row=1, col=3
)

fig.add_trace(
    px.imshow(groupgan_not_nan, 
                labels=dict(x="Channel 1", y="Channel 2", color="Correlation"),
                color_continuous_scale=px.colors.diverging.RdBu, 
                range_color=(-1, 1)).data[0], row=1, col=2
)

fig.add_trace(
    px.imshow(ff_not_nan, 
                labels=dict(x="Channel 1", y="Channel 2", color="Correlation"),
                color_continuous_scale=px.colors.diverging.RdBu, 
                range_color=(-1, 1)).data[0], row=1, col=4
)

fig.update_xaxes(showticklabels=False) 
fig.update_yaxes(showticklabels=False) 
fig.update_xaxes(title_text="Feature Channel 1", row = 1, col = 1)
fig.update_yaxes(title_text="Feature Channel 2", row = 1, col = 1)
fig.update_xaxes(title_text="Feature Channel 1", row = 1, col = 2)
fig.update_xaxes(title_text="Feature Channel 1", row = 1, col = 3)
fig.update_xaxes(title_text="Feature Channel 1", row = 1, col = 4)

fig.update_layout(height=260, width=1100, margin = {'l':0,'r':0,'t':23,'b':0},)
fig.show()

In [16]:
fig.write_image("correlation.pdf")

### Find uper triangle without common Nans

In [17]:
stats.spearmanr(np.array(real_not_nan.values).flatten(), np.array(timegan_not_nan.values).flatten())

SpearmanrResult(correlation=0.7333962087263044, pvalue=1.055668746233312e-68)

### MAE

In [18]:
# Channel i and j
n_channels = 5

MSE_groupgan = 0
MSE_timegan = 0
MSE_ff = 0

MAE_groupgan = 0
MAE_timegan = 0
MAE_ff = 0

Spearsman_groupgan = 0
Spearsman_timegan = 0
Spearsman_ff = 0   

norm_groupgan = 0
norm_timegan = 0
norm_ff = 0

MAEs_groupgan = []
MAEs_timegan = []
MAEs_ff = []
counter = 0

for i in range(0, n_channels):
    for j in range(i, n_channels):
        counter += 1
        
        real = pd.concat([df_real[i], df_real[j]], axis=1).corr()
        real = real[list(df_real[j].columns)].loc[list(df_real[i].columns)]

        timegan = pd.concat([df_timegan[i], df_timegan[j]], axis=1).corr()
        timegan = timegan[list(df_timegan[j].columns)].loc[list(df_timegan[i].columns)]

        groupgan = pd.concat([df_groupgan[i], df_groupgan[j]], axis=1).corr()
        groupgan = groupgan[list(df_groupgan[j].columns)].loc[list(df_groupgan[i].columns)]

        ff = pd.concat([df_ff[i], df_ff[j]], axis=1).corr()
        ff = ff[list(df_ff[j].columns)].loc[list(df_ff[i].columns)]

        features = list(set(['SC_FluctAnal_2_rsrangefit_50_1_logi_prop_r1', 'SC_FluctAnal_2_dfa_50_1_2_logi_prop_r1']))

        real_features_j = ['real_' + str(j) + '_' + feature for feature in features]
        timegan_features_j = ['timegan_' + str(j) + '_' + feature for feature in features]
        groupgan_features_j = ['groupgan_' + str(j) + '_' + feature for feature in features]
        ff_features_j = ['ff_' + str(j) + '_' + feature for feature in features]

        real_features_i = ['real_' + str(i) + '_' + feature for feature in features]
        timegan_features_i = ['timegan_' + str(i) + '_' + feature for feature in features]
        groupgan_features_i = ['groupgan_' + str(i) + '_' + feature for feature in features]
        ff_features_i = ['ff_' + str(i) + '_' + feature for feature in features]

        real_not_nan = real.drop(real_features_j, axis=1)
        timegan_not_nan = timegan.drop(timegan_features_j, axis=1)
        groupgan_not_nan = groupgan.drop(groupgan_features_j, axis=1)
        ff_not_nan = ff.drop(ff_features_j, axis=1)

        real_not_nan = np.array(real_not_nan.drop(real_features_i, axis=0))
        timegan_not_nan = np.array(timegan_not_nan.drop(timegan_features_i, axis=0))
        groupgan_not_nan = np.array(groupgan_not_nan.drop(groupgan_features_i, axis=0))
        ff_not_nan = np.array(ff_not_nan.drop(ff_features_i, axis=0))

        MAE_groupgan_temp = np.mean(np.abs(real_not_nan - groupgan_not_nan))
        MAE_timegan_temp = np.mean(np.abs(real_not_nan - timegan_not_nan))
        MAE_ff_temp = np.mean(np.abs(real_not_nan - ff_not_nan))

        MAE_groupgan += MAE_groupgan_temp
        MAE_timegan += MAE_timegan_temp
        MAE_ff += MAE_ff_temp

        Spearman_groupgan_temp = stats.spearmanr(real_not_nan.flatten(), groupgan_not_nan.flatten())
        Spearman_timegan_temp = stats.spearmanr(real_not_nan.flatten(), timegan_not_nan.flatten())
        Spearman_ff_temp = stats.spearmanr(real_not_nan.flatten(), ff_not_nan.flatten())

        Spearsman_groupgan += Spearman_groupgan_temp[0]
        Spearsman_timegan += Spearman_timegan_temp[0]
        Spearsman_ff += Spearman_ff_temp[0]

        norm_groupgan_temp = np.linalg.norm(real_not_nan - groupgan_not_nan)
        norm_timegan_temp = np.linalg.norm(real_not_nan - timegan_not_nan)
        norm_ff_temp = np.linalg.norm(real_not_nan - ff_not_nan)

        norm_groupgan += norm_groupgan_temp
        norm_timegan += norm_timegan_temp
        norm_ff += norm_ff_temp

        MSE_groupgan_temp = np.mean((real_not_nan - groupgan_not_nan)**2)
        MSE_timegan_temp = np.mean((real_not_nan - timegan_not_nan)**2)
        MSE_ff_temp = np.mean((real_not_nan - ff_not_nan)**2)

        MSE_groupgan += MSE_groupgan_temp
        MSE_timegan += MSE_timegan_temp
        MSE_ff += MSE_ff_temp

        MAEs_groupgan.append(MAE_groupgan_temp)
        MAEs_timegan.append(MAE_timegan_temp)
        MAEs_ff.append(MAE_ff_temp)

        print('Channel ' + str(i) + ' and ' + str(j) + ':')
        print('GroupGAN: ' + str(MAE_groupgan_temp))
        print('TimeGAN: ' + str(MAE_timegan_temp))
        print('FF: ' + str(MAE_ff_temp))

Channel 0 and 0:
GroupGAN: 0.11825316812165192
TimeGAN: 0.25687015644386946
FF: 0.13779135219453761
Channel 0 and 1:
GroupGAN: 0.11470377975526365
TimeGAN: 0.26405552651178665
FF: 0.14537920060326934
Channel 0 and 2:
GroupGAN: 0.11549882343933124
TimeGAN: 0.261502421621957
FF: 0.1491231994116522
Channel 0 and 3:
GroupGAN: 0.11563175181188862
TimeGAN: 0.26597377754517887
FF: 0.1535906323862086
Channel 0 and 4:
GroupGAN: 0.1174027180233502
TimeGAN: 0.2656651974907343
FF: 0.14925887691772083
Channel 1 and 1:
GroupGAN: 0.10405399302215507
TimeGAN: 0.2570947123007378
FF: 0.13548321224643448
Channel 1 and 2:
GroupGAN: 0.10609336217561921
TimeGAN: 0.2569190328094731
FF: 0.1435368279134886
Channel 1 and 3:
GroupGAN: 0.11138411351604134
TimeGAN: 0.26147041922639336
FF: 0.15051634978199413
Channel 1 and 4:
GroupGAN: 0.11362712690607188
TimeGAN: 0.2647777151277615
FF: 0.14924613294287017
Channel 2 and 2:
GroupGAN: 0.09895684255730305
TimeGAN: 0.2402216231358227
FF: 0.13298081527029143
Channel 2 a

In [19]:
print("mean and std of MAE of GroupGAN: " + str(np.mean(MAEs_groupgan)) + ' ' + str(np.std(MAEs_groupgan)))
print("mean and std of MAE of TimeGAN: " + str(np.mean(MAEs_timegan)) + ' ' + str(np.std(MAEs_timegan)))
print("mean and std of MAE of FF: " + str(np.mean(MAEs_ff)) + ' ' + str(np.std(MAEs_ff)))

mean and std of MAE of GroupGAN: 0.11132749853547735 0.005040756888179478
mean and std of MAE of TimeGAN: 0.2570726372214571 0.008531535143097822
mean and std of MAE of FF: 0.14553728766178695 0.0062238850940166505


In [20]:
print("MAE GroupGAN: ", MAE_groupgan/counter)
print("MAE TimeGAN: ", MAE_timegan/counter)
print("MAE FF: ", MAE_ff/counter)

MAE GroupGAN:  0.11132749853547735
MAE TimeGAN:  0.2570726372214571
MAE FF:  0.14553728766178692


In [21]:
print("Spearman GroupGAN: ", Spearsman_groupgan/counter)
print("Spearman TimeGAN: ", Spearsman_timegan/counter)
print("Spearman FF: ", Spearsman_ff/counter)

Spearman GroupGAN:  0.8848469725858855
Spearman TimeGAN:  0.6967821789320937
Spearman FF:  0.9058619273346455


In [22]:
print("Norm GroupGAN: ", norm_groupgan/counter)
print("Norm TimeGAN: ", norm_timegan/counter)
print("Norm FF: ", norm_ff/counter)

Norm GroupGAN:  4.252135884017667
Norm TimeGAN:  8.672334675252968
Norm FF:  4.838795542979265


In [23]:
print("MSE GroupGAN: ", MSE_groupgan/counter)
print("MSE TimeGAN: ", MSE_timegan/counter)
print("MSE FF: ", MSE_ff/counter)

MSE GroupGAN:  0.026015885697387733
MSE TimeGAN:  0.10646817654825318
MSE FF:  0.03368241378878176
