Skip to content

Commit

Permalink
Use regularized_glm package for faster sorted spike observation mod…
Browse files Browse the repository at this point in the history
…el fitting
  • Loading branch information
edeno committed Feb 21, 2018
1 parent 15775e5 commit 0be8b8b
Show file tree
Hide file tree
Showing 5 changed files with 178 additions and 187 deletions.
2 changes: 2 additions & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ channels:
- conda-forge
- defaults
- ioam
- edeno
dependencies:
- "python>=3.5"
- setuptools
Expand All @@ -11,6 +12,7 @@ dependencies:
- pandas
- xarray
- statsmodels
- regularized_glm
- matplotlib
- seaborn
- patsy
Expand Down
318 changes: 159 additions & 159 deletions examples/Simulate_Ripple_Decoding_Data_Sorted_Spikes.ipynb

Large diffs are not rendered by default.

7 changes: 4 additions & 3 deletions replay_classification/decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,13 +390,14 @@ def fit(self):
trajectory_direction=self.trajectory_direction))
design_matrix = dmatrix(
formula, training_data, return_type='dataframe')
fit = [fit_glm_model(
fit_coefficients = np.stack(
[fit_glm_model(
pd.DataFrame(spikes).loc[design_matrix.index], design_matrix)
for spikes in self.spikes]
for spikes in self.spikes], axis=1)

ci_by_state = {
direction: get_conditional_intensity(
fit, predictors_by_trajectory_direction(
fit_coefficients, predictors_by_trajectory_direction(
direction, self.place_bin_centers, design_matrix))
for direction in trajectory_directions}
conditional_intensity = np.stack(
Expand Down
36 changes: 12 additions & 24 deletions replay_classification/sorted_spikes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@

import numpy as np
from patsy import build_design_matrices
from statsmodels.api import GLM, families
from statsmodels.api import families
from regularized_glm import penalized_IRLS

logger = getLogger(__name__)


def fit_glm_model(spikes, design_matrix, penalty=1E-5):
def fit_glm_model(spikes, design_matrix, penalty=3):
'''Fits the Poisson model to the spikes from a neuron.
Parameters
Expand All @@ -22,14 +23,13 @@ def fit_glm_model(spikes, design_matrix, penalty=1E-5):
fitted_model : statsmodel results
'''
model = GLM(spikes, design_matrix, family=families.Poisson(),
drop='missing')
if penalty is None:
return model.fit()
else:
regularization_weights = np.ones((design_matrix.shape[1],)) * penalty
regularization_weights[0] = 0.0
return model.fit_regularized(alpha=regularization_weights, L1_wt=0)
regularization_weights = np.ones((design_matrix.shape[1],)) * penalty
regularization_weights[0] = 0.0
return np.squeeze(
penalized_IRLS(
np.array(design_matrix), np.array(spikes),
family=families.Poisson(),
penalty=regularization_weights).coefficients)


def predictors_by_trajectory_direction(trajectory_direction,
Expand All @@ -44,22 +44,10 @@ def predictors_by_trajectory_direction(trajectory_direction,
[design_matrix.design_info], predictors)[0]


def glm_val(fitted_model, predict_design_matrix):
'''Predict the model's response given a design matrix and the model
parameters.
'''
try:
return fitted_model.predict(predict_design_matrix)
except AttributeError:
return np.full(predict_design_matrix.shape[0], np.nan)


def get_conditional_intensity(fit, predict_design_matrix):
def get_conditional_intensity(fit_coefficients, predict_design_matrix):
'''The conditional intensity for each model
'''
return [glm_val(fitted_model, predict_design_matrix)
for fitted_model in fit]
return np.exp(np.dot(predict_design_matrix, fit_coefficients)).T


def atleast_kd(array, k):
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

INSTALL_REQUIRES = ['numpy >= 1.11', 'pandas >= 0.18.0', 'scipy', 'xarray',
'statsmodels', 'matplotlib', 'numba', 'patsy', 'seaborn',
'holoviews', 'bokeh']
'holoviews', 'bokeh', 'regularized_glm']
TESTS_REQUIRE = ['pytest >= 2.7.1']

setup(
Expand Down

0 comments on commit 0be8b8b

Please sign in to comment.