In [1]:
import os
import six
import glob
import pickle

import numpy as np
import pandas as pd
import scipy.stats as spst

import argus_shapes
import pulse2percept.utils as p2pu

from sklearn.base import clone
import sklearn.metrics as sklm

%matplotlib inline
import matplotlib.pyplot as plt
plt.style.use('ggplot')

2018-04-09 16:17:13,078 [pulse2percept] [INFO] Welcome to pulse2percept


# Load data

In [2]:
def _parse_fold_index(pickle_file, marker="shape3cvLOO-"):
    """Extract the cross-validation fold index encoded in a result file name.

    The fold index is the one- or two-character integer that sits between
    ``marker`` and the next ``-`` in ``pickle_file``
    (e.g. ``..._shape3cvLOO-10-swarm_...`` -> ``10``).

    Parameters
    ----------
    pickle_file : str
        Path or file name of the result pickle.
    marker : str, optional
        Text that immediately precedes the fold index.

    Returns
    -------
    int or None
        The fold index, or None if no valid index follows the marker.

    Raises
    ------
    AssertionError
        If ``marker`` does not occur in ``pickle_file`` at all.
    """
    marker_pos = pickle_file.find(marker)
    # A file without the marker is not a result file of this experiment, so
    # fail loudly instead of silently skipping it.
    assert marker_pos >= 0, "'%s' not found in '%s'" % (marker, pickle_file)

    tail = pickle_file[marker_pos + len(marker):]
    fold_str, dash, _ = tail.partition("-")
    # Without a terminating dash, or with an empty / overly long field, there
    # is no fold index to read (e.g. "shape3cvLOO-swarm_2018-...").
    if not dash or not 1 <= len(fold_str) <= 2:
        return None
    try:
        return int(fold_str)
    except ValueError:
        return None


def fix_data(pickle_files, verbose=True):
    """Bring result pickles up to date by adding missing bookkeeping fields.

    Every file is expected to hold a tuple
    ``(y_test, y_pred, best_params, specifics)``. For each file, the keys
    ``'idx_fold'``, ``'best_train_score'`` and ``'best_test_score'`` are added
    to ``specifics`` if missing, and the tuple is written back to the same
    path:

    - ``idx_fold`` is parsed from the file name.
    - ``best_test_score`` is taken from the legacy ``'best_score'`` entry
      (which is then removed), or set to ``np.inf`` if that is missing, too.
    - ``best_train_score`` is recomputed by reloading the data described in
      ``specifics``, refitting the stored regressor with ``best_params[0]``
      on all rows except those indexed by ``y_test[0]``, and scoring it on
      those same rows.

    Files that already contain all three keys, and files whose name carries no
    fold index, are left untouched.

    Parameters
    ----------
    pickle_files : iterable of str
        Paths of the result files to process. The files are modified in place.
    verbose : bool, optional
        If True, additionally print which file is being processed and whether
        it was already up to date. The per-file summary of an updated file is
        always printed.

    Returns
    -------
    None

    Raises
    ------
    AssertionError
        If a file name does not contain ``"shape3cvLOO-"``.
    KeyError
        If the training score must be recomputed and the environment variable
        ``SECOND_SIGHT_DATA`` is not set.
    ValueError
        If the training score must be recomputed and no data is found.
    """
    required_keys = ('idx_fold', 'best_train_score', 'best_test_score')

    for pickle_file in pickle_files:
        if verbose:
            print('- Processing %s' % pickle_file)
        # SECURITY: unpickling executes arbitrary code; only run this on
        # result files from a trusted source.
        with open(pickle_file, 'rb') as f:
            y_test, y_pred, best_params, specifics = pickle.load(f)

        if all(key in specifics for key in required_keys):
            if verbose:
                print("  - File up-to-date, skip.")
            continue

        if 'idx_fold' in specifics:
            idx_fold = specifics['idx_fold']
        else:
            idx_fold = _parse_fold_index(pickle_file)
            if idx_fold is None:
                print('  - No CV fold found, skip.')
                continue

        if 'best_test_score' in specifics:
            best_test_score = specifics['best_test_score']
        elif 'best_score' in specifics:
            # Legacy layout: the test score is the first entry of 'best_score'.
            best_test_score = specifics['best_score'][0]
            del specifics['best_score']
        else:
            best_test_score = np.inf

        if 'best_train_score' in specifics:
            best_train_score = specifics['best_train_score']
        else:
            # Rebuild the data set as described by `specifics`, then refit.
            rootfolder = os.path.join(os.environ['SECOND_SIGHT_DATA'], 'shape')
            X, y = argus_shapes.load_data(rootfolder, subject=specifics['subject'],
                                          electrodes=None,
                                          amplitude=specifics['amplitude'],
                                          random_state=42, n_jobs=1, verbose=False)
            X, y = argus_shapes.exclude_bistables(X, y)
            if specifics['adjust_bias']:
                y = argus_shapes.adjust_drawing_bias(
                    X, y,
                    scale_major=specifics['drawing']['major'],
                    scale_minor=specifics['drawing']['minor'],
                    rotate=specifics['drawing']['orient'])
                print('  - Adjusted for drawing bias:', X.shape, y.shape)
            if len(X) == 0:
                raise ValueError('No data found. Abort.')
            if specifics['avg_img']:
                X, y = argus_shapes.calc_mean_images(X, y)

            # NOTE(review): `reg` is the very object stored in `specifics`, so
            # the parameter changes and the fit below are pickled back into the
            # file - confirm that this is intended.
            reg = specifics['regressor']
            reg.set_params(**best_params[0])
            # Rows indexed by y_test[0] are excluded from fitting and scoring
            # (presumably the held-out fold - TODO confirm).
            X_train = X.drop(y_test[0].index)
            y_train = y.drop(y_test[0].index)
            reg.fit(X_train)
            reg.set_params(engine='serial')
            best_train_score = reg.score(X_train, y_train)

        print('  - idx_fold=%d, best_train_score=%f, best_test_score=%f' % (
            idx_fold, best_train_score, best_test_score))
        specifics['idx_fold'] = idx_fold
        specifics['best_train_score'] = best_train_score
        specifics['best_test_score'] = best_test_score

        # Serialise before opening the file for writing: opening with 'wb'
        # truncates it, so a pickling error must not destroy the old content.
        payload = pickle.dumps((y_test, y_pred, best_params, specifics))
        with open(pickle_file, 'wb') as f:
            f.write(payload)
        print('  - Dumped new data to', pickle_file)

In [3]:
# Collect every result pickle of the shape3cv experiment, in sorted order.
results_dir = '../../results/shape3cv/'
pickle_pattern = os.path.join(results_dir, '*.pickle')
pickle_files = np.sort(glob.glob(pickle_pattern))
print('Found', len(pickle_files), 'files')

Found 1209 files


In [4]:
# Subject identifiers.
# NOTE(review): not referenced in the cells visible here - presumably used by
# later analysis cells; confirm.
subjects = ['12-005', '51-009', '52-001', 'TB']
# NOTE(review): looks like the parameter values a result file is expected to
# have (here: amplitude 2.0); not referenced in the cells visible here - confirm.
assert_params = {
    'amplitude': 2.0,
}
# Data folder: the 'shape' subfolder of the directory named by the
# SECOND_SIGHT_DATA environment variable (KeyError if the variable is unset).
rootfolder = os.path.join(os.environ['SECOND_SIGHT_DATA'], 'shape')

In [5]:
# Update all result files in place. `fix_data` has no return statement, so
# `data` is None; the call is made for its side effect on the pickle files.
# verbose=False only silences the "Processing"/"up-to-date" messages - the
# summary of each rewritten file is still printed.
data = fix_data(pickle_files, verbose=False)

  - idx_fold=0, best_train_score=12.347908, best_test_score=1830216475410.484619
  - Dumped new data to ../../results/shape3cv/51-009_D__shape3cvLOO-0-swarm_2018-04-09_03-36-50.pickle
  - idx_fold=1, best_train_score=6.988467, best_test_score=163031929329.170471
  - Dumped new data to ../../results/shape3cv/51-009_D__shape3cvLOO-1-swarm_2018-04-09_13-23-16.pickle
  - idx_fold=10, best_train_score=6.404791, best_test_score=3071676425.052069
  - Dumped new data to ../../results/shape3cv/51-009_D__shape3cvLOO-10-swarm_2018-04-09_12-08-10.pickle
  - idx_fold=11, best_train_score=10.074544, best_test_score=469790598523.792847
  - Dumped new data to ../../results/shape3cv/51-009_D__shape3cvLOO-11-swarm_2018-04-09_16-07-22.pickle
  - idx_fold=12, best_train_score=8.003070, best_test_score=65328453968.044060
  - Dumped new data to ../../results/shape3cv/51-009_D__shape3cvLOO-12-swarm_2018-04-09_00-57-14.pickle
  - idx_fold=2, best_train_score=6.042219, best_test_score=230006948312.796875
  - D

In [6]:
# Completion marker: reaching this cell means the preceding cells ran through.
print('All Done!')

All Done!
