In [149]:
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import pickle
import numpy as np
from scipy import stats

In [150]:
temp_dict = pickle.load(open(f'all/augmentation_final_scores_0.5_5.pkl', 'rb'))
temp_reform = {(outerKey, innerKey): values for outerKey, innerDict in temp_dict.items() for innerKey, values in innerDict.items()}
dfs = pd.DataFrame(temp_reform)
dfs_not_melted = dfs
dfs_melted = dfs.reset_index().melt(id_vars=['index'])
all_dfs = dfs_melted

In [151]:
temp_dict = pickle.load(open(f'old/augmentation_final_scores_0.5_5.pkl', 'rb'))
temp_reform = {(outerKey, innerKey): values for outerKey, innerDict in temp_dict.items() for innerKey, values in innerDict.items()}
dfs = pd.DataFrame(temp_reform)
dfs_not_melted = pd.concat([dfs_not_melted, dfs], axis=1)
dfs_melted = dfs.reset_index().melt(id_vars=['index'])
all_dfs = all_dfs.append(dfs_melted)

In [152]:
temp_dict = pickle.load(open(f'new/augmentation_final_scores_0.5_5.pkl', 'rb'))
dfs = pd.DataFrame(temp_dict)
dfs_not_melted = pd.concat([dfs_not_melted, dfs], axis=1)
dfs_melted = dfs.reset_index().melt(id_vars=['index'])
all_dfs = all_dfs.append(dfs_melted)

In [153]:
all_dfs = all_dfs.drop(['variable_0', 'variable_1', 'variable'], axis=1)

In [154]:
all_dfs.replace('groupgan', 'Group GAN', inplace=True)
all_dfs.replace('timegan', 'Time GAN', inplace=True)
all_dfs.replace('ff', 'Fourier Flows', inplace=True)
all_dfs.replace('real', 'Real', inplace=True)

In [155]:
fig = px.box(all_dfs,
             x='index', 
             y='value', 
             color='index', 
             labels={'index': 'Model', 'value': 'Accuracy'}, 
             color_discrete_sequence=['green', 'orange', 'purple', 'blue'],
             category_orders={"index": ["Group GAN", "Fourier Flows", "Time GAN", "Real"]})
# make yticks more accurate
fig.update_yaxes(tickprefix="",
                    tick0=0,
                    dtick=0.05)
fig.update_layout(showlegend=False, margin=dict(l=0, r=0, b=0, t=0), font=dict(size=20), width=700, height=500)
# remove the x-axis title
fig.update_xaxes(title_text="",
                    showline=True,
                    showgrid=False,
                    zeroline=True)
fig.show()

In [148]:
fig.write_image("augmentation_sota.pdf")

In [128]:
pvalue = stats.ttest_ind(
            all_dfs[all_dfs['index'] == 'GroupGAN']['value'],
            all_dfs[all_dfs['index'] == 'Time GAN']['value'],
            equal_var=False,
        )[1]
pvalue

8.240363339041626e-10

In [129]:
pvalue = stats.ttest_ind(
            all_dfs[all_dfs['index'] == 'GroupGAN']['value'],
            all_dfs[all_dfs['index'] == 'ff']['value'],
            equal_var=False,
        )[1]
pvalue

nan

In [191]:
def add_anootation(fig, x0, x1, level, symbol):
    level = level/30
    fig.add_shape(type="line",
        xref="x", yref="y"+" domain",
        x0=x0, y0=1.00 + level, 
        x1=x0, y1=0.98 + level,
        line=dict(color='black', width=2,)
    )
    # Horizontal line
    fig.add_shape(type="line",
        xref="x", yref="y"+" domain",
        x0=x0, y0=1.00 + level,
        x1=x1, y1=1.00 + level,
        line=dict(color='black', width=2,)
    )
    # Vertical line
    fig.add_shape(type="line",
        xref="x", yref="y"+" domain",
        x0=x1, y0=1.00 + level,
        x1=x1, y1=0.98 + level,
        line=dict(color='black', width=2,)
    )
    ## add text at the correct x, y coordinates
    ## for bars, there is a direct mapping from the bar number to 0, 1, 2...
    fig.add_annotation(dict(font=dict(color='black',size=14),
        x=(x0 + x1)/2,
        y=1.04 + level,
        showarrow=False,
        text=symbol,
        textangle=0,
        xref="x",
        yref="y"+" domain"
    ))
    return fig

def symbol_generator(data_1, data_2):
    pvalue = stats.ttest_ind(
        data_1,
        data_2,
        equal_var=False,
    )[1]
    if pvalue >= 0.05:
        symbol = 'ns'
    elif pvalue >= 0.01: 
        symbol = '*'
    elif pvalue >= 0.001:
        symbol = '**'
    else:
        symbol = '***'
    return symbol

In [196]:
fig = px.box(all_dfs,
             x='index', 
             y='value', 
             color='index', 
             labels={'index': 'Model', 'value': 'Accuracy'}, 
             color_discrete_sequence=['green', 'orange', 'purple', 'blue'],
             category_orders={"index": ["Group GAN", "Fourier Flows", "Time GAN", "Real"]})
# make yticks more accurate
fig.update_yaxes(tickprefix="",
                    tick0=0,
                    dtick=0.05)
fig.update_layout(showlegend=False, margin=dict(l=0, r=0, b=0, t=0), font=dict(size=20), width=700, height=500)
# remove the x-axis title
fig.update_xaxes(title_text="",
                    showline=True,
                    showgrid=False,
                    zeroline=True)
# Group 2
data_1 = all_dfs[(all_dfs['index']=='Group GAN')]['value']
data_2 = all_dfs[(all_dfs['index']=='Fourier Flows')]['value']

fig = add_anootation(fig, 0, 1, 0, symbol_generator(data_1, data_2))

data_2 = all_dfs[(all_dfs['index']=='Time GAN')]['value']
fig = add_anootation(fig, 0, 2, 1, symbol_generator(data_1, data_2))

data_3 = all_dfs[(all_dfs['index']=='Real')]['value']
fig = add_anootation(fig, 0, 3, 2, symbol_generator(data_1, data_3))


fig.update_layout(yaxis_title=None, margin=dict(l=0, r=0, b=0, t=40), font=dict(size=20))
fig.show()


In [199]:
fig.write_image("augmentation_sota.pdf")

In [140]:
fig = px.scatter(dfs_not_melted.T, x='timegan', y='ff', labels={'timegan': 'Time GAN', 'groupgan': 'Group GAN'})
fig.update_layout(width = 550,
                  height = 400,margin=dict(l=0, r=0, b=0, t=10), font=dict(size=15))
fig.update_xaxes(range=[0.45, 1.0])
fig.update_yaxes(range=[0.45, 1.0])
fig.update_xaxes(tickvals=[0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 1.0])
fig.update_yaxes(tickvals=[0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 1.0])
fig.update_traces(marker={'size': 10})

fig.add_trace(go.Scatter(
    x = [0.45,1.1],
    y = [0.45,1.1],
    showlegend=False,
    line=dict(color='red')
))


fig.add_trace(go.Scatter(
    x = [0.45,1.1],
    y = [0.55,1.2],
    showlegend=False,
    line=dict(dash='dash', color='black')
))


fig.add_trace(go.Scatter(
    x = [0.45,1.1],
    y = [0.65,1.3],
    showlegend=False,
    line=dict(dash='dash', color='black')
))

fig.add_trace(go.Scatter(
    x = [0.45,1.1],
    y = [0.75,1.4],
    showlegend=False,
    line=dict(dash='dash', color='black')
))

fig.add_trace(go.Scatter(
    x = [0.45,1.1],
    y = [0.35,1.0],
    showlegend=False,
    line=dict(dash='dash', color='black')
))

In [123]:
fig.write_image("scatter.pdf")


In [79]:
all_dfs

Unnamed: 0,baseline,LSTM_MLP_with_CD,LSTM_without_CD,real,amount,n_groups,fraction
1,0.773171,0.753659,0.743902,0.629268,1,2,0.3
2,0.712195,0.763415,0.668293,,2,2,0.3
4,0.782927,0.751220,0.729268,,4,2,0.3
6,0.748780,0.726829,0.658537,,6,2,0.3
8,0.748780,0.704878,0.648780,,8,2,0.3
...,...,...,...,...,...,...,...
2,0.624390,0.829268,0.631707,,2,5,0.3
4,0.648780,0.824390,0.817073,,4,5,0.3
6,0.646341,0.851220,0.709756,,6,5,0.3
8,0.731707,0.751220,0.704878,,8,5,0.3
