In [1]:
import h5py
import matplotlib.pyplot as plt
import numpy as np
import os
import time

from sklearn.linear_model import (LinearRegression,
                                  RidgeCV,
                                  LassoCV,
                                  ElasticNetCV)
from sklearn.metrics import r2_score
from sklearn.model_selection import train_test_split
from uoineuro.tuning_utils import create_strf_design

%matplotlib inline

In [2]:
# Location of the ECoG DMR recording. Defaults to ~/data/dmr/ecog_dmr.h5;
# set the ECOG_DMR_PATH environment variable to read it from elsewhere.
# os.path.expanduser('~') is used instead of os.environ['HOME'] because HOME
# is not guaranteed to be set on every platform (e.g. Windows).
data_path = os.environ.get(
    'ECOG_DMR_PATH',
    os.path.join(os.path.expanduser('~'), 'data', 'dmr', 'ecog_dmr.h5'))

In [3]:
# Open the HDF5 recording read-only. The handle is kept in `data` so the
# following cell can pull datasets out of it.
data = h5py.File(data_path, mode='r')

In [4]:
# Read both datasets fully into memory (stim first, then resp).
stim, resp = (data[dataset_name][:] for dataset_name in ('stim', 'resp'))

In [5]:
# Analysis settings.
n_frames = 20            # stimulus frames per STRF (passed to create_strf_design)
train_frac = 0.9         # fraction of samples used for fitting; the rest is held out
method = 'elasticnet'    # regression model: 'ridge' or 'elasticnet'

# Problem dimensions read off the loaded arrays: stim is 2-D
# (n_samples, n_features) and the first axis of resp indexes electrodes.
n_samples, n_features = stim.shape
n_electrodes = resp.shape[0]

In [6]:
# Build the STRF regression problem: X is used as the feature matrix and
# Y[electrode] as the target vector for each electrode in the fitting loop.
# NOTE(review): presumably X holds n_frames lagged copies of stim, i.e.
# n_frames * n_features columns (each fitted coef_ is reshaped to an
# (n_frames, n_features) STRF later) — confirm against uoineuro.tuning_utils.
X, Y = create_strf_design(stim, resp, n_frames)

In [7]:
# Per-electrode results, filled in by the fitting loop. They are pre-filled
# with NaN rather than 0 so that electrodes the loop has not reached yet
# (e.g. after an interrupted run) cannot be mistaken for genuine results such
# as an all-zero elastic-net STRF or an R^2 of exactly 0.
strfs = np.full((n_electrodes, n_frames, n_features), np.nan)
train_scores = np.full(n_electrodes, np.nan)
test_scores = np.full(n_electrodes, np.nan)

In [None]:
# Fixed seed so the train/test split is reproducible across runs. Because X is
# shared, every electrode also gets the same split.
split_seed = 0


def _make_fitter(method):
    """Return an unfitted cross-validated linear model for `method`.

    Parameters
    ----------
    method : str
        'ridge' (RidgeCV over a log-spaced alpha grid) or 'elasticnet'
        (ElasticNetCV over several L1 ratios, 5-fold CV).

    Raises
    ------
    ValueError
        If `method` is not one of the supported names.
    """
    if method == 'ridge':
        # `normalize` was removed from RidgeCV in scikit-learn 1.2. Its old
        # value here (False) is the current behaviour, so it is omitted.
        return RidgeCV(alphas=np.logspace(3, 6, num=300),
                       fit_intercept=False,
                       cv=None)
    if method == 'elasticnet':
        return ElasticNetCV(l1_ratio=np.array([0.1, 0.2, 0.3, 0.5]),
                            eps=1e-5,
                            n_alphas=100,
                            fit_intercept=False,
                            max_iter=5000,
                            cv=5)
    raise ValueError(
        "Unknown method {!r}; expected 'ridge' or 'elasticnet'.".format(method))


for electrode in range(n_electrodes):
    t = time.time()
    print('Electrode: ', electrode)

    # NOTE(review): train_test_split shuffles samples by default. If
    # neighbouring samples are strongly autocorrelated, near-duplicates can
    # land in both train and test and inflate test R^2. Consider shuffle=False
    # or a blocked split.
    X_train, X_test, y_train, y_test = train_test_split(
        X, Y[electrode], train_size=train_frac, random_state=split_seed)

    # Centre features using the training mean only, so nothing from the test
    # set leaks in. The models below are fit without an intercept.
    # NOTE(review): y is not centred. If resp is not already zero-mean, the
    # no-intercept fit cannot model its offset — confirm resp is normalised
    # upstream.
    centering = X_train.mean(axis=0, keepdims=True)
    X_train -= centering
    X_test -= centering

    fitter = _make_fitter(method).fit(X_train, y_train)

    strfs[electrode] = fitter.coef_.reshape((n_frames, n_features))
    train_scores[electrode] = fitter.score(X_train, y_train)
    test_scores[electrode] = fitter.score(X_test, y_test)

    print('Alpha: ', fitter.alpha_)
    # Only ElasticNetCV selects an L1 ratio; RidgeCV has no `l1_ratio_`.
    if hasattr(fitter, 'l1_ratio_'):
        print('L1 ratio: ', fitter.l1_ratio_)
    print('Time: ', time.time() - t)
    print('---')

Electrode:  0
