In [23]:
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import pickle
from scipy import stats

In [24]:
# Load the per-experiment score pickles and build:
#   dfs[n]        - wide frame, one column per (outerKey, innerKey) pair
#   dfs_melted[n] - long frame produced by melting dfs[n] (id column 'index'),
#                   tagged with an 'n_groups' column
#   all_dfs       - all long frames stacked together
dfs = {}
dfs_melted = {}
for n_groups in [2, 3, 4, 5]:
    # SECURITY NOTE: pickle.load can execute arbitrary code; only load result
    # files that were produced by a trusted run of the experiments.
    with open(f'experiment_results/final_scores_{n_groups}.pkl', 'rb') as scores_file:
        temp_dict = pickle.load(scores_file)
    # Flatten the nested {outerKey: {innerKey: values}} mapping into tuple keys
    # so pandas builds two-level columns from it.
    temp_reform = {
        (outerKey, innerKey): values
        for outerKey, innerDict in temp_dict.items()
        for innerKey, values in innerDict.items()
    }
    dfs[n_groups] = pd.DataFrame(temp_reform)
    dfs_melted[n_groups] = dfs[n_groups].reset_index().melt(id_vars=['index'])
    dfs_melted[n_groups]['n_groups'] = n_groups

# DataFrame.append was removed in pandas 2.0; concatenate once after the loop
# instead of growing the frame on every iteration. Row labels are kept as-is
# (no ignore_index), matching what repeated append produced.
all_dfs = pd.concat([dfs_melted[n_groups] for n_groups in dfs_melted])

In [25]:
# Swap the internal experiment identifiers for the display names used in the
# figures. The replacement is applied in place to every cell of all_dfs that
# exactly equals one of the keys.
all_dfs.replace(
    {
        'LSTM_MLP_with_CD': 'GroupGAN',
        'LSTM_without_CD': 'GroupGAN without CD',
        'real': 'Real',
        'baseline': 'Baseline',
    },
    inplace=True,
)

In [26]:
# Box plot of the melted scores for a single experiment (n_groups = 5), with
# the mean of the 'real' row drawn as a dashed red horizontal reference line.
n_groups = 5
real_mean = dfs[n_groups].T['real'].mean()
fig = px.box(
    dfs_melted[n_groups],
    x='variable_1',
    y='value',
    color='index',
)
fig.add_hline(y=real_mean, line_dash="dash", line_color="red")
fig.show()

In [34]:
def add_anootation(fig, x0, x1, symbol, y_offset=0.0):
    """Draw a bracket from ``x0`` to ``x1`` above the plot and label it.

    The bracket is made of three black line shapes (left tick, top bar,
    right tick) and a text annotation centred between ``x0`` and ``x1``.
    x coordinates are in data units of the x axis; y coordinates are given
    relative to the y-axis domain, so values at or above 1.0 sit at or above
    the top of the plotting area.

    Args:
        fig: Figure-like object exposing ``add_shape`` and ``add_annotation``.
        x0: x position of the left end of the bracket.
        x1: x position of the right end of the bracket.
        symbol: Text to show above the bracket (e.g. ``'ns'``, ``'*'``).
        y_offset: Amount added to every y coordinate, used to stack brackets.

    Returns:
        The same ``fig`` object that was passed in.
    """
    y_reference = "y" + " domain"
    bar_y = 1.00 + y_offset
    tick_end_y = 0.98 + y_offset

    # Segments as (x_start, y_start, x_end, y_end): left tick, top bar, right tick.
    segments = (
        (x0, bar_y, x0, tick_end_y),
        (x0, bar_y, x1, bar_y),
        (x1, bar_y, x1, tick_end_y),
    )
    for seg_x0, seg_y0, seg_x1, seg_y1 in segments:
        fig.add_shape(
            type="line",
            xref="x", yref=y_reference,
            x0=seg_x0, y0=seg_y0,
            x1=seg_x1, y1=seg_y1,
            line=dict(color='black', width=2),
        )

    # The 'ns' label is placed at 1.05, every other symbol at 1.03.
    label_base_y = 1.03 if symbol != 'ns' else 1.05
    fig.add_annotation(dict(
        font=dict(color='black', size=14),
        x=(x0 + x1)/2,
        y=label_base_y + y_offset,
        showarrow=False,
        text=symbol,
        textangle=0,
        xref="x",
        yref=y_reference,
    ))
    return fig

def symbol_generator(data_1, data_2):
    """Return a significance marker for the difference between two samples.

    Runs a two-sided independent-samples t-test without assuming equal
    variances (``equal_var=False``) and maps the p-value to a marker:

        p < 0.001          -> '***'
        0.001 <= p < 0.01  -> '**'
        0.01 <= p < 0.05   -> '*'
        otherwise          -> 'ns'

    Args:
        data_1: First sample (array-like of numbers).
        data_2: Second sample (array-like of numbers).

    Returns:
        One of ``'***'``, ``'**'``, ``'*'`` or ``'ns'``.
    """
    pvalue = stats.ttest_ind(
        data_1,
        data_2,
        equal_var=False,
    )[1]
    # Test from the most significant threshold downwards with strict "<".
    # A NaN p-value (e.g. empty or zero-variance samples) fails every
    # comparison and therefore ends up as 'ns' instead of being reported as
    # highly significant.
    if pvalue < 0.001:
        symbol = '***'
    elif pvalue < 0.01:
        symbol = '**'
    elif pvalue < 0.05:
        symbol = '*'
    else:
        symbol = 'ns'
    return symbol

In [35]:
# Alias used by the plotting cell below. This binds a second name to the same
# DataFrame object (no copy is made), so in-place changes to either name are
# visible through both.
all_dfs_total = all_dfs

In [66]:
# Box plot of accuracy per model, grouped by number of channels, with
# significance brackets comparing GroupGAN against each other model.
fig = px.box(
    all_dfs_total,
    x='n_groups',
    y='value',
    color='index',
    labels={'index': 'Model', 'value': 'Accuracy', 'n_groups': 'Number of channels'},
    color_discrete_sequence=['red', 'green', 'yellow', 'blue'],
)

# One bracket per comparison against GroupGAN:
#   (other model, bracket start offset, bracket end offset, extra bracket height)
# Offsets are relative to the integer x position of each group; presumably they
# were hand-tuned to line up with the individual boxes -- re-check them if the
# set or order of models changes.
bracket_specs = [
    ('Baseline', -0.26, -0.10, 0.0),
    ('GroupGAN without CD', -0.07, 0.08, 0.0),
    ('Real', -0.07, 0.27, 0.045),
]

for group_count in [2, 3, 4, 5]:
    in_group = all_dfs_total['n_groups'] == group_count
    data_1 = all_dfs_total.loc[in_group & (all_dfs_total['index'] == 'GroupGAN'), 'value']
    for other_model, start_offset, end_offset, bracket_lift in bracket_specs:
        data_2 = all_dfs_total.loc[in_group & (all_dfs_total['index'] == other_model), 'value']
        # round(..., 2) removes floating-point noise from the addition so the
        # bracket ends land exactly on two-decimal positions (1.74, 1.90, ...).
        fig = add_anootation(
            fig,
            round(group_count + start_offset, 2),
            round(group_count + end_offset, 2),
            symbol_generator(data_1, data_2),
            y_offset=bracket_lift,
        )

# Fixed accuracy axis from 0.45 to 1 with a tick every 0.05.
fig.update_layout(
    yaxis=dict(
        range=[0.45, 1],
        tickmode='array',
        tickvals=[0.45, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 1],
    ))
# The top margin leaves room for the brackets, which are drawn above the
# plotting area.
fig.update_layout(margin=dict(l=0, r=15, b=0, t=38), font=dict(size=23), width=700, height=600, legend_title_text='')
# Legend placed inside the plot, near the top-left corner, with a transparent
# background.
# NOTE(review): the alpha value 100 in bordercolor is outside the usual 0-1
# range; presumably a fully opaque black border is intended -- confirm.
fig.update_layout(legend=dict(
    yanchor="top", y=0.95, xanchor="left", x=0.05,
    bgcolor='rgba(0,0,0,0)', bordercolor='rgba(0,0,0,100)', borderwidth=2,
    font=dict(size=18),
))
fig.show()

In [41]:
# Save the current `fig` to a PDF file in the working directory.
# NOTE(review): plotly static image export presumably needs an export backend
# (e.g. kaleido) installed in the kernel environment -- confirm.
fig.write_image("all_fake_with_without.pdf")