In [None]:
import pandas as pd 
import numpy as np
import glob

## Prep data for error bars

### The function below is the one to change if you want to switch to confidence intervals. You should be able to replace the percentile statistics ('10%', '20%', ...) with whatever calculation you want.

## If you change shaded_error_prep(), single_shaded_error_fig() will need to be updated to support the changes


In [None]:
# takes in a df of trials (one column per trial), computes per-row statistics and returns a new df

def shaded_error_prep(hold):
    """Append per-row summary statistics computed across the trial columns.

    The columns of ``hold`` are first sorted by name. For every row, the
    statistics below are computed over all of those columns (NaNs skipped)
    and appended as new columns, in this order:
    '10%', '20%', '30%', '40%', 'mean', '75%', '25%', 'median', 'min',
    'max', 'std'.

    Parameters
    ----------
    hold : pandas.DataFrame
        Every column is included in the statistics, so it should contain
        only numeric trial data. Rows are aligned samples (presumably time
        points -- confirm against the producing notebook).

    Returns
    -------
    pandas.DataFrame
        A new frame (the input is not modified): the sorted trial columns
        followed by the statistic columns. Rows whose values are all NaN get
        NaN statistics (numpy emits a RuntimeWarning for those).
    """
    hold = hold.reindex(sorted(hold.columns), axis=1)

    def _row_percentile(q):
        # np.nanpercentile returns a bare ndarray. Give it hold's own index so
        # the column assignment below lines up row-for-row even when the index
        # is not 0..n-1 (wrapping it in a DataFrame aligned on a fresh
        # RangeIndex, which produced NaNs / misaligned rows otherwise).
        if len(hold.index) == 0:
            # np.nanpercentile cannot iterate over zero rows.
            return pd.Series(index=hold.index, dtype=float)
        return pd.Series(np.nanpercentile(hold, q, axis=1), index=hold.index)

    # Compute every statistic before adding any column, so the statistics only
    # ever see the trial columns. Dict order is the output column order.
    stats = {
        '10%': _row_percentile(10),
        '20%': _row_percentile(20),
        '30%': _row_percentile(30),
        '40%': _row_percentile(40),
        'mean': hold.mean(axis=1, skipna=True),
        '75%': _row_percentile(75),
        '25%': _row_percentile(25),
        'median': hold.median(axis=1, skipna=True),
        'min': hold.min(axis=1),
        'max': hold.max(axis=1),
        'std': hold.std(axis=1, skipna=True),
    }
    for column, values in stats.items():
        hold[column] = values

    return hold

## These functions build and save the figures

### acc_shaded_error produces multi-panel figures (one subplot per DataFrame), like the pre/post or sess2&5 graphs.

In [None]:
import plotly.graph_objs as go 
from plotly.subplots import make_subplots


## Inputs:
# folder: output sub-folder under chunk_accs_figs/ (where you want to save figs to)
# name: what you would like to call this particular file (also used as the figure title)
# df_list: list of dfs you would like to plot. each df receives its own subplot, stacked vertically.
#          (for a horizontal layout, use make_subplots(rows=1, cols=len(df_list)) and pass (1, i) as row/col)
def acc_shaded_error(folder, name, df_list):
    """Plot each DataFrame in ``df_list`` as a stacked subplot and save a PNG.

    The image is written to ``chunk_accs_figs/<folder>/<name>.png``; the
    directory is created first if it does not exist. ``name`` is also set as
    the figure title.
    """
    import os

    fig = make_subplots(rows=len(df_list), cols=1)

    # One subplot per DataFrame: row i (1-based), column 1.
    for row, df in enumerate(df_list, start=1):
        fig = single_shaded_error_fig(fig, df, row, 1)

    # Title from this call's ``name`` argument. single_shaded_error_fig has no
    # name/title parameter, so any title it sets does not come from here.
    fig.update_layout(title=name)

    out_dir = 'chunk_accs_figs/' + folder
    # write_image does not create missing directories.
    os.makedirs(out_dir, exist_ok=True)
    save_path = out_dir + '/' + name + '.png'

    fig.write_image(save_path)


## Inputs:
# Normally called by acc_shaded_error
# fig: plotly figure (with subplots) to add the traces to
# df: df to plot; must contain the 'median', '10%', '20%', '30%', '40%' and '75%' columns
# row: row position in figure
# col: col position in figure
#
# Any edits made in shaded_error_prep must be cascaded into this function
def single_shaded_error_fig(fig, df, row, col):
    """Add stacked percentile-band traces for one DataFrame to subplot (row, col).

    Six line traces are added, in order: '10%', '20%', '30%', '40%',
    'median', '75%'. Every trace after the first uses fill='tonexty' with a
    translucent grey, so the area between consecutive traces is shaded.
    The x values are the DataFrame's index. Legend entries are shown only
    when row * col <= 1 (the first subplot), so they are not repeated.

    Returns the same ``fig`` after adding the traces and updating its layout.
    """
    # Look every column up before touching the figure, in the same order as
    # before, so a missing column raises before any trace is added.
    needed = ('median', '10%', '20%', '30%', '40%', '75%')
    series = {key: df[key] for key in needed}

    x_vals = list(df.index)
    show_legend = row * col <= 1
    band_fill = dict(fillcolor='rgba(68, 68, 68, 0.3)', fill='tonexty')

    # (trace name, column, colour, shaded down to the previous trace?)
    layers = (
        ('10%', '10%', "#33ccff", False),
        ('20%', '20%', "#03fce3", True),
        ('30%', '30%', "#9dfc03", True),
        ('40%', '40%', "#ffd966", True),
        ('median', 'median', "#ff6666", True),
        ('75%', '75%', "#cc33ff", True),
    )
    for label, column, colour, shaded in layers:
        extra = band_fill if shaded else {}
        fig.add_scatter(
            name=label,
            x=x_vals,
            y=series[column],
            mode='lines',
            marker=dict(color=colour),
            line=dict(width=1.5),
            showlegend=show_legend,
            row=row,
            col=col,
            **extra,
        )

    # NOTE(review): ``name`` is not a parameter here; Python resolves it as a
    # module-level (notebook) global at call time, so the title is whatever
    # global ``name`` currently holds (NameError if none is defined).
    fig.update_layout(
        yaxis_title='rolling acc (25ticks)',
        title=name,
        hovermode="x",
    )

    return fig

### This cell makes one figure per participant per session. It is mainly a way to see the general structure of this figure-maker.

In [None]:
files = glob.glob("chunk_accs/subj_accs/*.csv")

# Gets the subject and session out of the known file-name structure
def get_subjsess(path):
    """Return ``(subject, session)`` parsed from a per-subject CSV path.

    The file name (directories stripped) is split on '_' and its first two
    fields are returned, e.g. '.../sub-bi01_sess2_x.csv' -> ('sub-bi01', 'sess2').
    Raises IndexError if the file name has no '_'.
    """
    import os

    # os.path.basename also strips backslash-separated directories on
    # Windows, where glob returns such paths and split('/') would keep them.
    fields = os.path.basename(path).split('_')
    return fields[0], fields[1]


for path in files:
    df = pd.read_csv(path)

    subj, sess = get_subjsess(path)

    name = subj + '_' + sess

    print(name)

    # NOTE(review): df is passed straight from the CSV, but the plotting code
    # reads 'median', '10%', ... columns. This assumes the per-subject files
    # already contain them (otherwise pass shaded_error_prep(df)) -- confirm.
    acc_shaded_error('subj_accs_figs', name, [df])

## This next bit preps the files for use in figures

### Input files should be in the chunk_accs/subj_accs folder.

### This next cell takes the individual files and builds aggregated files for groupings such as bio/arb, motion class (rest/open/close), and session. Feel free to group however you like, using the column names in the subj_accs files to build new dfs.

#### I just aggregated things by different groupings so that it would be easier to process later. you can prep however you like really, as long as your data is in the right format by the time it goes into acc_shaded_error()

In [None]:
import os

bio_files = glob.glob("chunk_accs/subj_accs/*bi*.csv")
arb_files = glob.glob("chunk_accs/subj_accs/*ar*.csv")

## Remove problem participants
# NOTE(review): no participants are excluded yet; filter sess2_files /
# sess5_files below to drop any before aggregation.

sess2_files = [file for file in files if 'sess2' in file]
sess5_files = [file for file in files if 'sess5' in file]

which_files = [(sess2_files, 'sess2'), (sess5_files, 'sess5')]

# A trial column is filed under the first of these keywords its name contains.
GESTURES = ('rest', 'open', 'close')
AGG_DIR = 'chunk_accs/aggregate_accs/'


def reindex_df(df):
    """Return df with its columns sorted by name."""
    return df.reindex(sorted(df.columns), axis=1)


def _frame_from_columns(columns):
    """Combine a ``{column name: Series}`` dict into one DataFrame.

    pd.concat outer-aligns the Series on their row index, so trials of
    different lengths are NaN-padded instead of being truncated to the first
    trial's length. An empty dict gives an empty DataFrame (pd.concat rejects
    empty input).
    """
    if not columns:
        return pd.DataFrame()
    return pd.concat(columns, axis=1)


# The loop variable is not called `files`: that would overwrite the global
# list above and break re-running this cell.
for session_files, session in which_files:
    # Fresh accumulators per session. They used to be created once before the
    # loop, so sess5 aggregates also held every sess2 trial plus the statistic
    # columns shaded_error_prep had appended.
    bi_cols = {gesture: {} for gesture in GESTURES}
    ar_cols = {gesture: {} for gesture in GESTURES}

    for path in session_files:
        if 'sub-bi' in path:
            target = bi_cols
        elif 'sub-ar' in path:
            target = ar_cols
        else:
            continue  # neither group marker: nothing to aggregate
        subj, sess = get_subjsess(path)
        prefix = subj + '_' + sess + '_'
        df = pd.read_csv(path)
        for column in df.columns.values:
            for gesture in GESTURES:
                if gesture in column:
                    # A repeated subj/sess prefix overwrites the same key,
                    # matching the previous column-assignment behaviour.
                    target[gesture][prefix + column] = df[column]
                    break

    bi_rest, bi_open, bi_close = (_frame_from_columns(bi_cols[g]) for g in GESTURES)
    ar_rest, ar_open, ar_close = (_frame_from_columns(ar_cols[g]) for g in GESTURES)

    # All gestures for a group in one frame (keys are disjoint across
    # gestures). Unlike chaining DataFrame.join onto the rest frame, this keeps
    # every row, even when there are no rest columns.
    bi_agg = _frame_from_columns({k: v for g in GESTURES for k, v in bi_cols[g].items()})
    ar_agg = _frame_from_columns({k: v for g in GESTURES for k, v in ar_cols[g].items()})

    bi_agg = shaded_error_prep(reindex_df(bi_agg))
    ar_agg = shaded_error_prep(reindex_df(ar_agg))

    bi_rest = shaded_error_prep(reindex_df(bi_rest))
    bi_open = shaded_error_prep(reindex_df(bi_open))
    bi_close = shaded_error_prep(reindex_df(bi_close))
    ar_rest = shaded_error_prep(reindex_df(ar_rest))
    ar_open = shaded_error_prep(reindex_df(ar_open))
    ar_close = shaded_error_prep(reindex_df(ar_close))

    save_path = AGG_DIR + session + '_'
    os.makedirs(AGG_DIR, exist_ok=True)
    for suffix, frame in (
        ('bio_agg.csv', bi_agg),
        ('bio_rest_aggregate.csv', bi_rest),
        ('bio_open_aggregate.csv', bi_open),
        ('bio_close_aggregate.csv', bi_close),
        ('arb_agg.csv', ar_agg),
        ('arb_rest_aggregate.csv', ar_rest),
        ('arb_open_aggregate.csv', ar_open),
        ('arb_close_aggregate.csv', ar_close),
    ):
        frame.to_csv(save_path + suffix)

## Builds Error Bars for sess2 & 5

In [None]:
import os

agg_files = glob.glob("chunk_accs/aggregate_accs/*.csv")

for file in agg_files:
    # Aggregate file names look like <session>_<group>_<gesture...>.csv,
    # e.g. sess2_bio_rest_aggregate.csv or sess2_bio_agg.csv.
    bits = os.path.basename(file).split('_')
    sess = bits[0]
    group = bits[1]
    gesture = bits[2]

    if 'sess5' in file and 'sess2' not in file:
        # sess5 files are only plotted next to their sess2 counterpart.
        continue

    hold = pd.read_csv(file)

    df_list = [hold]
    if 'sess2' in file:
        partner = 'sess5_' + group + '_' + gesture
        # next(..., None) instead of list(...)[0]: the old form raised
        # IndexError when no sess5 file existed, so the check never saw a miss.
        file2 = next((a for a in agg_files if partner in a), None)
        if file2:
            df_list.append(pd.read_csv(file2))

            sess = 'sess2&5'

    if 'agg.csv' in file:
        name = group + '_' + sess
    else:
        name = group + '_' + gesture + '_' + sess

    acc_shaded_error('paired_agg_accs_figs', name, df_list)

## Builds same error bars but for pre & post within a single session

In [None]:
import os

pre_post_agg_files = glob.glob("chunk_accs/agg_pre_post_accs/*.csv")

for file in pre_post_agg_files:
    base = os.path.basename(file)
    print(base)

    # Tokens come from the file name only, with the extension dropped. The
    # directory name "agg_pre_post_accs" itself contains 'post', so the old
    # substring test on the full path skipped every file not already caught by
    # the 'pre' branch (e.g. any "..._pre.csv" whose last token was 'pre.csv').
    bits = os.path.splitext(base)[0].split('_')
    group = bits[0]
    gesture = bits[1]
    pre_post = bits[2]

    if 'post' in bits and 'pre' not in bits:
        # post files are plotted alongside their matching pre file.
        continue

    hold = pd.read_csv(file)

    df_list = [hold]
    if 'pre' in bits:
        partner = group + '_' + gesture + '_post'
        # next(..., None) instead of list(...)[0], which raised IndexError
        # when no matching post file existed.
        file2 = next(
            (a for a in pre_post_agg_files if partner in os.path.basename(a)),
            None,
        )
        if file2:
            print('Appending... ', file2)
            df_list.append(pd.read_csv(file2))

            pre_post = 'pre_post'

    name = group + '_' + gesture + '_' + pre_post

    acc_shaded_error('agg_prepost_accs_figs', name, df_list)