In [1]:
import os
import six
import glob
import pickle

import numpy as np
import pandas as pd
import scipy.stats as spst

import p2pspatial
import pulse2percept.utils as p2pu

from sklearn.base import clone
import sklearn.metrics as sklm

%matplotlib inline
import matplotlib.pyplot as plt
plt.style.use('ggplot')

2018-04-07 21:54:52,482 [pulse2percept] [INFO] Welcome to pulse2percept


# Load data

In [2]:
def load_file(pickle_file, verbose=False):
    if verbose:
        print('- Processing %s' % pickle_file)
    _, _, _, specifics = pickle.load(open(pickle_file, 'rb'))
    if isinstance(specifics, list):
        print('List of specifics found in', pickle_file)
        return None
    row = {
        'subject': specifics['subject'],
        'model': specifics['modelname'],
        'exetime': specifics['exetime'],
        'adjust_bias': specifics['adjust_bias'],
        'n_folds': specifics['n_folds'],
        'idx_fold': specifics['idx_fold'],
        'best_cost': specifics['best_train_score'],
        'filepath': os.path.dirname(pickle_file),
        'filename': os.path.basename(pickle_file)
    }
    return row

In [3]:
results_dir = '../../results/shape3cv/'
pickle_files = np.sort(glob.glob(os.path.join(results_dir, '*.pickle')))
print('Found', len(pickle_files), 'files')

Found 1141 files


In [4]:
subjects = ['12-005', '51-009', '52-001', 'TB']
assert_params = {
    'amplitude': 2.0,
}
rootfolder = os.path.join(os.environ['SECOND_SIGHT_DATA'], 'shape')

In [5]:
data = pd.DataFrame(list(filter(None, p2pu.parfor(load_file, pickle_files))))

# All runs

In [6]:
# Group by these:
groupcols = ['subject', 'model', 'adjust_bias', 'idx_fold']
# Don't show these:
extracols = ['filepath', 'exetime', 'filename']
r2cols = ['r2_area', 'r2_orientation', 'r2_eccentricity', 'r2_compactness']

In [7]:
data.groupby(groupcols)['exetime'].agg(['mean', 'count'])

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,mean,count
subject,model,adjust_bias,idx_fold,Unnamed: 4_level_1,Unnamed: 5_level_1
12-005,A,True,0,431.473700,10
12-005,A,True,1,364.866903,10
12-005,A,True,2,254.916891,10
12-005,A,True,3,224.855537,10
12-005,A,True,4,494.697876,10
12-005,A,True,5,359.697818,10
12-005,A,True,6,319.071494,10
12-005,A,True,7,362.117683,10
12-005,A,True,8,279.848912,10
12-005,B,False,0,422.243781,5


# The best runs in their respective category

In [8]:
print('Best scores:')
# Find the rows that have the best score in their subject/model/bias group:
best_idx = data.groupby(groupcols)['best_cost'].transform(np.min) == data['best_cost']
data.loc[best_idx, :].drop(columns=extracols).groupby(groupcols).max()

Best scores:


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,best_cost,n_folds
subject,model,adjust_bias,idx_fold,Unnamed: 4_level_1,Unnamed: 5_level_1
12-005,A,True,0,8.066014,18
12-005,A,True,1,8.372955,18
12-005,A,True,2,8.273124,18
12-005,A,True,3,7.937856,18
12-005,A,True,4,8.013519,18
12-005,A,True,5,8.556084,18
12-005,A,True,6,7.783529,18
12-005,A,True,7,7.971503,18
12-005,A,True,8,8.771234,18
12-005,B,False,0,4.587605,19


# Runs that give sufficiently good results

Predicts all four parameters:

In [9]:
best = data.loc[best_idx, :]
best4_idx = (best['r2_area'] > 0) & (best['r2_orientation'] > 0) & (best['r2_eccentricity'] > 0) & (best['r2_compactness'] > 0)
best.loc[best4_idx, :].drop(columns=extracols[:-1]).groupby(groupcols).max()

KeyError: 'r2_area'

Predicts at least 3 out of 4:

In [None]:
best3_idx = (best['r2_area'] > 0) & (best['r2_orientation'] > 0) & (best['r2_eccentricity'] > 0)
best3_idx |= (best['r2_area'] > 0) & (best['r2_orientation'] > 0) & (best['r2_compactness'] > 0)
best3_idx |= (best['r2_eccentricity'] > 0) & (best['r2_orientation'] > 0) & (best['r2_compactness'] > 0)
best3_idx |= (best['r2_area'] > 0) & (best['r2_eccentricity'] > 0) & (best['r2_compactness'] > 0)
best.loc[best3_idx, :].drop(columns=extracols[:-1]).groupby(groupcols).max()

Predicts at least 2 out of 4:

In [None]:
best2_idx = (best['r2_area'] > 0) & (best['r2_orientation'] > 0)
best2_idx |= (best['r2_area'] > 0) & (best['r2_compactness'] > 0)
best2_idx |= (best['r2_area'] > 0) & (best['r2_eccentricity'] > 0)
best2_idx |= (best['r2_orientation'] > 0) & (best['r2_compactness'] > 0)
best2_idx |= (best['r2_orientation'] > 0) & (best['r2_eccentricity'] > 0)
best2_idx |= (best['r2_eccentricity'] > 0) & (best['r2_compactness'] > 0)
best.loc[best2_idx, :].drop(columns=extracols[:-1]).groupby(groupcols).max()
best.loc[best2_idx, :].drop(columns=extracols[:-1]).groupby(groupcols).max()

In [None]:
plot_files = best.loc[best2_idx, :].reset_index()

In [None]:
y, _, _, _ = pickle.load(open(os.path.join(*plot_files.loc[plot_files.index[0], ['filepath', 'filename']]), 'rb'))
if isinstance(y, list):
    y = pd.concat(y)
columns = y.drop(columns=['electrode', 'image']).columns
columns

# Plot best ones

In [None]:
fig, axes = plt.subplots(ncols=len(columns), nrows=len(plot_files), figsize=(14, 2 * len(plot_files)))
for (_, row), axrow in zip(plot_files.iterrows(), axes):
    y, y_pred, _, _ = pickle.load(open(os.path.join(*row[['filepath', 'filename']]), 'rb'))
    if isinstance(y, list):
        y = pd.concat(y)
        y_pred = pd.concat(y_pred)
    for col, ax in zip(columns, axrow):
        ax.scatter(y[col], y_pred[col])
        minval = np.minimum(y[col].min(), y_pred[col].min())
        maxval = np.maximum(y[col].max(), y_pred[col].max())
        ax.plot([minval, maxval], [minval, maxval], 'k--')
        ax.set_xlabel(col)
        ax.text(minval, maxval, "$R^2$=%.3f" % float(row['r2_' + col]), va='top')
    axrow[0].set_ylabel('%s %s %s' % (row['subject'], row['model'],
                                      "adjust" if row['adjust_bias'] else ""))
fig.tight_layout()

In [None]:
from statsmodels.formula.api import ols
from statsmodels.stats.anova import anova_lm

In [None]:
y, y_pred, _, _ = pickle.load(open(os.path.join(*row[['filepath', 'filename']]), 'rb'))

In [None]:
for col in columns:
    df = pd.DataFrame({'y': y[col], 'y_hat': y_pred[col]})
    model = ols('y ~ y_hat', data=df).fit()
    anovaResults = anova_lm(model)
    print('')
    print(col)
    print(anovaResults)