# In[21]:
import pandas as pd
import numpy as np
import plotly.graph_objects as go
from using_deepsulci.settings import Settings

# Identifiers used to fill in the "labelled_cohort_stats" output-path template.
kwargs = {
    'train_cohort': 'pclean50A',
    'hemi': 'L',
    'model': 'unet3d',
    'run': '01'
}

settings = Settings()
# Resolve the statistics CSV path from the template and load it.
f = settings.outputs.get_from_template("labelled_cohort_stats", **kwargs)
df = pd.read_csv(f)

# Aggregate the per-subject metric columns of `df` into one summary row per
# sulcus. Columns are read as "<metric>_<sulcus>"; the sulcus name is
# everything after the first underscore.
# NOTE(review): `len(name) > 3` is a heuristic that skips columns whose suffix
# is too short to be a sulcus label — confirm against the CSV schema.
seen = set()
ss_list = []
for colname in df.columns:
    name = colname.partition('_')[2]
    if len(name) > 3 and name not in seen:
        seen.add(name)
        ss_list.append(name)

# Output column -> prefix of the matching per-subject column in `df`.
metric_prefixes = {
    'bacc': 'bacc_',
    'acc': 'acc_',
    'esi': 'ESI_',
    'specificity': 'spec_',
    'sensibility': 'sens_',
}
data = {'sulci': ss_list}
for key, prefix in metric_prefixes.items():
    data[key] = [np.mean(df[prefix + s]) for s in ss_list]
data = pd.DataFrame(data)

# Sort once and derive ss_list from the sorted table so the two orderings can
# never disagree (e.g. on ties, or on NaN, which np.argsort()[::-1] puts first
# while sort_values puts it last).
data = data.sort_values('bacc', ascending=False, kind='stable')
ss_list = data['sulci'].to_numpy()

# One distinct hue per sulcus. endpoint=False because hue 360 equals hue 0,
# which would otherwise give the first and last sulci the same color.
c = [
    'hsl(' + str(h) + ',50%' + ',50%)'
    for h in np.linspace(0, 360, len(ss_list), endpoint=False)
]

# One box per sulcus (in ss_list order) showing the per-subject balanced
# accuracy distribution; boxmean=True also draws the mean the title refers to.
fig = go.Figure(
    data=[
        go.Box(
            y=df['bacc_' + ss],
            name=ss,
            marker_color=c[i],
            boxmean=True,
        )
        for i, ss in enumerate(ss_list)
    ],
    layout_title_text="Average Balanced Accuracy",
)

fig.update_layout(
    # Each box is already labelled by its sulcus name on the x axis.
    showlegend=False,
    xaxis=dict(
        # Force a tick label for every sulcus instead of letting plotly thin them.
        tickmode='array',
        tickvals=list(ss_list),
        ticktext=list(ss_list),
    ),
)

fig.show()