# Table 2: Cross-validated model parameter values

In [1]:
import pickle
import pandas as pd
import argus_shapes as shapes

2019-05-08 11:09:45,622 [pulse2percept] [INFO] Welcome to pulse2percept


In [2]:
# The particle swarm optimizer is sensitive to its random starting point, so
# every CV fold was simulated several times. For each combination of subject,
# model, and CV fold we keep only the simulation run with the best `col_score`.
results_dir = '../results'
col_score = 'best_train_score'
col_groupby = [
    'subject',
    'modelname',
    'idx_fold',
]

try:
    files = shapes.extract_best_pickle_files(results_dir, col_score, col_groupby)
except FileNotFoundError:
    # Nothing usable on disk yet (per the original author: the results
    # directory is missing or holds no files). Pull the data from OSF into
    # `results_dir`, then repeat the search.
    shapes.fetch_data(save_path=results_dir, osf_zip_url='https://osf.io/prv5z')
    files = shapes.extract_best_pickle_files(results_dir, col_score, col_groupby)

In [3]:
# Gather the fitted parameters of the best run of every CV fold into a single
# table: one row per pickle file, with columns subject / model / rho / axlambda.
#
# NOTE(security): `pickle.load` can execute arbitrary code. These files come
# either from local simulation runs or from the OSF download in the previous
# cell; only run this on results you trust.
param_records = []
for pickle_file in files:
    # Each pickle unpacks into four items; only the last two are used here.
    # The context manager makes sure the handle is closed after loading.
    with open(pickle_file, 'rb') as fh:
        _, _, best_params, specifics = pickle.load(fh)
    # NOTE(review): `best_params[0]` is assumed to be the best particle's
    # parameter mapping -- confirm against argus_shapes.
    fitted = best_params[0]
    param_records.append({
        'subject': specifics['subject'],
        'model': specifics['modelname'],
        'rho': fitted['rho'],
        # Not every model fits `axlambda` (e.g. Scoreboard). Fill with NaN so
        # the column always exists and the summary below cannot KeyError.
        'axlambda': fitted['axlambda'] if 'axlambda' in fitted else float('nan'),
    })
df_params = pd.DataFrame(param_records)

In [4]:
# Mean and standard error of the mean (SEM) of each fitted parameter over the
# CV folds, per subject and model. Models that do not fit `axlambda` show NaN.
# Columns are selected with a list: tuple-style selection `gb['a', 'b']` was
# deprecated in pandas 1.0 and raises an error in pandas 2.0.
df_params.groupby(['subject', 'model'])[['rho', 'axlambda']].agg(['mean', 'sem'])

subject,model,rho_mean,rho_sem,axlambda_mean,axlambda_sem
S1,AxonMap,409.859708,4.697559,1189.556762,156.637505
S1,Scoreboard,532.674499,10.660461,,
S2,AxonMap,315.246876,17.074557,499.679555,141.643332
S2,Scoreboard,243.824212,33.815198,,
S3,AxonMap,143.837119,7.440525,1414.382895,95.578594
S3,Scoreboard,170.287357,1.163572,,
S4,AxonMap,437.193516,6.313911,1419.558295,42.481071
S4,Scoreboard,174.78679,1.414194,,
