In [1]:
import numpy as np
import os
import pickle
from time import time
from datetime import datetime

In [2]:
import pulse2percept as p2p
import p2pspatial

2018-01-17 00:08:14,649 [pulse2percept] [INFO] Welcome to pulse2percept


In [3]:
# Experiment configuration: which subject and model to fit, which trials to
# load, and how to split them for cross-validation.
subject = '12-005'
# [short name used in the output filename, model class to instantiate]
modelname = ['A', p2pspatial.models.ModelA]
amplitude = 2.0    # forwarded to load_data(amplitude=...) — units/meaning per p2pspatial, TODO confirm
electrodes = None  # forwarded to load_data(electrodes=...); None presumably selects all electrodes — confirm
random_state = 42
n_folds = 5

# Seed numpy's global RNG so stochastic steps (e.g. particle-swarm
# initialization) are repeatable across Restart & Run All. random_state is
# otherwise only forwarded to load_data.
# NOTE(review): assumes the swarm optimizer draws from numpy's global RNG;
# if it keeps a private RNG, seed it through its own API instead.
np.random.seed(random_state)

In [4]:
# Start the wall-clock timer, then build a timestamped output filename so
# repeated runs never overwrite each other's results.
t_start = time()
now = format(datetime.now(), '%Y-%m-%d_%H-%M-%S')
filename = '%s-crossval-swarm_%s_%s.pickle' % (
    modelname[0],  # short model name, e.g. 'A'
    subject,
    now,
)
print(filename)

A-crossval-swarm_12-005_2018-01-17_00-08-15.pickle


In [5]:
# Load the subject's single-stimulation trials from $SECOND_SIGHT_DATA/shape.
# Raises KeyError if the SECOND_SIGHT_DATA environment variable is not set.
rootfolder = os.path.join(os.environ['SECOND_SIGHT_DATA'], 'shape')
X, y = p2pspatial.load_data(
    rootfolder,
    subject=subject,
    electrodes=electrodes,
    amplitude=amplitude,
    random_state=random_state,
    single_stim=True,
    verbose=False,
)
print(X.shape, y.shape)
# Stop here rather than passing an empty dataset on to the fitting cells.
if not len(X):
    raise ValueError('no data found')

(355, 9) (355, 1)


In [6]:
# Instantiate the chosen model class. engine='serial' and n_jobs=1 presumably
# select single-process execution — confirm against p2pspatial.
model_params = dict(engine='serial', n_jobs=1)
regressor = modelname[1](**model_params)

In [None]:
# Wrap the regressor in a particle-swarm optimizer that searches over 'rho'.
# The (20, 1000) tuple is presumably the (lower, upper) search bound, and
# min_func presumably a stopping tolerance — confirm against p2pspatial.
search_params = dict(rho=(20, 1000))
pso_options = dict(max_iter=100, min_func=0.1)
pso = p2pspatial.model_selection.ParticleSwarmOptimizer(
    regressor, search_params, **pso_options)

In [None]:
# Run n_folds-fold cross-validation with the swarm optimizer, collecting the
# held-out targets, the predictions and the best-found parameters
# (presumably one entry per fold — confirm against crossval_predict).
fit_params = dict()
y_test, y_pred, best_params = p2pspatial.model_selection.crossval_predict(
    pso,
    X,
    y,
    fit_params=fit_params,
    n_folds=n_folds,
)

Fold 1 / 5
No constraints given.


In [None]:
# Report wall-clock seconds elapsed since t_start was recorded.
print(
    'Done in %.3fs'
    % (time() - t_start)
)

In [None]:
# Save the cross-validation results together with every setting needed to
# interpret or reproduce this run.
specifics = {'subject': subject,
             'modelname': modelname,
             'amplitude': amplitude,
             'electrodes': electrodes,
             'n_folds': n_folds,
             'regressor': regressor,
             'optimizer': pso,
             'model_params': model_params,
             'search_params': search_params,
             'fit_params': fit_params,
             'now': now,
             'random_state': random_state}
# The context manager guarantees the file is flushed and closed even if
# pickling raises, instead of relying on garbage collection to close a
# dangling handle.
# NOTE: the payload pickles live objects (the model class, regressor and
# optimizer), so loading it requires p2pspatial to be importable. Only
# unpickle files from trusted sources.
with open(filename, 'wb') as fh:
    pickle.dump((y_test, y_pred, best_params, specifics), fh)
print('Dumped data to %s' % filename)