In [1]:
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import pickle
from scipy import stats

In [2]:
dfs = {}
dfs_melted = {}
all_dfs = pd.DataFrame()
for n_groups in [2,3,4,5]:
    temp_dict = pickle.load(open(f'experiment_results/final_scores_{n_groups}.pkl', 'rb'))
    temp_reform = {(outerKey, innerKey): values for outerKey, innerDict in temp_dict.items() for innerKey, values in innerDict.items()}
    dfs[n_groups] = pd.DataFrame(temp_reform)
    dfs_melted[n_groups] = dfs[n_groups].reset_index().melt(id_vars=['index'])
    dfs_melted[n_groups]['n_groups'] = n_groups
    all_dfs = all_dfs.append(dfs_melted[n_groups])

In [3]:
all_dfs.replace('LSTM_MLP_with_CD', 'COSCI-GAN', inplace=True)
all_dfs.replace('LSTM_without_CD', 'COSCI-GAN without CD', inplace=True)

all_dfs_without_CD = all_dfs[all_dfs['index'] != 'COSCI-GAN without CD']

In [4]:
n_groups = 5
fig = px.box(dfs_melted[n_groups], x='variable_1', y='value', color='index')
real_mean = dfs[n_groups].T['real'].mean()
# fig.add_trace(go.Scatter(x=[1,10], y=[real_mean, real_mean], mode="lines", name=""))
fig.add_hline(real_mean,line_dash="dash", line_color="red")
fig.show()

In [6]:
def add_anootation(fig, x0, x1, symbol):
    fig.add_shape(type="line",
        xref="x", yref="y"+" domain",
        x0=x0, y0=1.00, 
        x1=x0, y1=0.98,
        line=dict(color='black', width=2,)
    )
    # Horizontal line
    fig.add_shape(type="line",
        xref="x", yref="y"+" domain",
        x0=x0, y0=1.00,
        x1=x1, y1=1.00,
        line=dict(color='black', width=2,)
    )
    # Vertical line
    fig.add_shape(type="line",
        xref="x", yref="y"+" domain",
        x0=x1, y0=1.00, 
        x1=x1, y1=0.98,
        line=dict(color='black', width=2,)
    )
    ## add text at the correct x, y coordinates
    ## for bars, there is a direct mapping from the bar number to 0, 1, 2...
    fig.add_annotation(dict(font=dict(color='black',size=14),
        x=(x0 + x1)/2,
        y=1.05,
        showarrow=False,
        text=symbol,
        textangle=0,
        xref="x",
        yref="y"+" domain"
    ))
    return fig

def symbol_generator(data_1, data_2):
    pvalue = stats.ttest_ind(
        data_1,
        data_2,
        equal_var=False,
    )[1]
    if pvalue >= 0.05:
        symbol = 'ns'
    elif pvalue >= 0.01: 
        symbol = '*'
    elif pvalue >= 0.001:
        symbol = '**'
    else:
        symbol = '***'
    return symbol

In [33]:
fig = px.box(all_dfs_without_CD, x='n_groups', y='value', color='index', labels={'n_groups': 'Number of channels', 'value': 'Accuracy', 'index': 'Models'}, color_discrete_sequence=['red', 'green', 'blue'] )

# Group 2
data_1 = all_dfs_without_CD[(all_dfs_without_CD['n_groups']==2) & (all_dfs_without_CD['index']=='COSCI-GAN')]['value']
data_2 = all_dfs_without_CD[(all_dfs_without_CD['n_groups']==2) & (all_dfs_without_CD['index']=='baseline')]['value']

fig = add_anootation(fig, 1.75, 1.99, symbol_generator(data_1, data_2))

data_2 = all_dfs_without_CD[(all_dfs_without_CD['n_groups']==2) & (all_dfs_without_CD['index']=='real')]['value']
fig = add_anootation(fig, 2.01, 2.25, symbol_generator(data_1, data_2))

# Group 3
data_1 = all_dfs_without_CD[(all_dfs_without_CD['n_groups']==3) & (all_dfs_without_CD['index']=='COSCI-GAN')]['value']
data_2 = all_dfs_without_CD[(all_dfs_without_CD['n_groups']==3) & (all_dfs_without_CD['index']=='baseline')]['value']

fig = add_anootation(fig, 2.75, 2.99, symbol_generator(data_1, data_2))

data_2 = all_dfs_without_CD[(all_dfs_without_CD['n_groups']==3) & (all_dfs_without_CD['index']=='real')]['value']
fig = add_anootation(fig, 3.01, 3.25, symbol_generator(data_1, data_2))

# Group 4
data_1 = all_dfs_without_CD[(all_dfs_without_CD['n_groups']==4) & (all_dfs_without_CD['index']=='COSCI-GAN')]['value']
data_2 = all_dfs_without_CD[(all_dfs_without_CD['n_groups']==4) & (all_dfs_without_CD['index']=='baseline')]['value']

fig = add_anootation(fig, 3.75, 3.99, symbol_generator(data_1, data_2))

data_2 = all_dfs_without_CD[(all_dfs_without_CD['n_groups']==4) & (all_dfs_without_CD['index']=='real')]['value']
fig = add_anootation(fig, 4.01, 4.25, symbol_generator(data_1, data_2))

# Group 5
data_1 = all_dfs_without_CD[(all_dfs_without_CD['n_groups']==5) & (all_dfs_without_CD['index']=='COSCI-GAN')]['value']
data_2 = all_dfs_without_CD[(all_dfs_without_CD['n_groups']==5) & (all_dfs_without_CD['index']=='baseline')]['value']

fig = add_anootation(fig, 4.75, 4.99, symbol_generator(data_1, data_2))

data_2 = all_dfs_without_CD[(all_dfs_without_CD['n_groups']==5) & (all_dfs_without_CD['index']=='real')]['value']
fig = add_anootation(fig, 5.01, 5.25, symbol_generator(data_1, data_2))

fig.update_layout(yaxis_range=[0.45,1])
fig.update_layout(
    yaxis = dict(
        tickmode = 'array',
        tickvals = [0.45, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 1],
    ))
fig.update_layout(showlegend=False, margin=dict(l=0, r=15, b=0, t=23), font=dict(size=23), width=700, height=600)
fig.show()

In [35]:
fig.write_image("all_fake.pdf")

In [12]:
data_1 = all_dfs_without_CD[(all_dfs_without_CD['n_groups']==4) & (all_dfs_without_CD['index']=='COSCI-GAN')]['value']
data_1

1      0.687805
5      0.687805
9      0.724390
13     0.687805
17     0.695122
21     0.714634
25     0.682927
29     0.717073
33     0.690244
37     0.717073
41     0.678049
45     0.707317
49     0.675610
53     0.675610
57     0.690244
61     0.721951
65     0.670732
69     0.697561
73     0.697561
77     0.687805
81     0.697561
85     0.736585
89     0.717073
93     0.751220
97     0.653659
101    0.643902
105    0.714634
109    0.685366
113    0.692683
117    0.695122
Name: value, dtype: float64

In [7]:
all_dfs = []
for n_groups in [1,2,3,4,5]:
    temp_df = dfs[n_groups].T.droplevel(0)
    temp_df['amount'] = dfs[n_groups].T.droplevel(0).index
    temp_df['n_groups'] = n_groups
    all_dfs.append(temp_df)

all_dfs = pd.concat(all_dfs)
all_dfs

Unnamed: 0,baseline,LSTM_MLP_with_CD,LSTM_without_CD,real,amount,n_groups
1,0.587805,0.580488,0.543902,0.62439,1,1
2,0.529268,0.553659,0.585366,,2,1
4,0.551220,0.597561,0.560976,,4,1
6,0.536585,0.560976,0.570732,,6,1
8,0.524390,0.553659,0.482927,,8,1
...,...,...,...,...,...,...
2,0.495122,0.609756,0.534146,,2,5
4,0.539024,0.617073,0.592683,,4,5
6,0.482927,0.619512,0.565854,,6,5
8,0.453659,0.660976,0.578049,,8,5


In [8]:
fig = px.scatter(all_dfs, x='baseline', y='LSTM_MLP_with_CD', color='n_groups', size='amount')
fig.update_layout(width = 550,
                  height = 500,)
fig.update_xaxes(range=[0.45, 0.8])
fig.update_yaxes(range=[0.45, 0.8])
fig.update_xaxes(tickvals=[0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8])
fig.update_yaxes(tickvals=[0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8])
fig.add_trace(go.Scatter(
    x = [0.45,0.8],
    y = [0.45,0.8]
))
fig.update_layout(showlegend=False)

