Skip to content

Commit

Permalink
Use a L2-regularized GLM to handle perfect predictors
Browse files Browse the repository at this point in the history
  • Loading branch information
edeno committed Dec 11, 2017
1 parent aa2e8c5 commit 20bac2e
Show file tree
Hide file tree
Showing 3 changed files with 225 additions and 217 deletions.
413 changes: 213 additions & 200 deletions examples/Simulate_Ripple_Decoding_Data_Sorted_Spikes.ipynb

Large diffs are not rendered by default.

8 changes: 4 additions & 4 deletions replay_classification/decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from .core import (combined_likelihood, empirical_movement_transition_matrix,
get_bin_centers, inbound_outbound_initial_conditions,
predict_state, uniform_initial_conditions)
from .sorted_spikes import (get_conditional_intensity, glm_fit,
from .sorted_spikes import (get_conditional_intensity, fit_glm_model,
poisson_likelihood,
predictors_by_trajectory_direction)

Expand Down Expand Up @@ -361,15 +361,15 @@ def fit(self):

logger.info('Fitting observation model...')
formula = ('1 + trajectory_direction * '
'bs(position, df=5, degree=3)')
'cr(position, df=5, constraints="center")')

training_data = pd.DataFrame(dict(
position=self.position,
trajectory_direction=self.trajectory_direction))
design_matrix = dmatrix(
formula, training_data, return_type='dataframe')
fit = [glm_fit(spikes, design_matrix, ind)
for ind, spikes in enumerate(self.spikes)]
fit = [fit_glm_model(spikes, design_matrix)
for spikes in self.spikes]

ci_by_state = {
direction: get_conditional_intensity(
Expand Down
21 changes: 8 additions & 13 deletions replay_classification/sorted_spikes.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,31 +7,26 @@
logger = getLogger(__name__)


def glm_fit(spikes, design_matrix, ind):
def fit_glm_model(spikes, design_matrix, penalty=1E-5):
'''Fits the Poisson model to the spikes from a neuron
Parameters
----------
spikes : array_like
design_matrix : array_like or pandas DataFrame
ind : int
penalty : float, optional
Returns
-------
fitted_model : object or NaN
Returns the statsmodel object if successful. If the model fails in
the weighted fit in the IRLS procedure, the model returns NaN.
fitted_model : statsmodel results
'''
try:
logger.debug('\t\t...Neuron #{}'.format(ind + 1))
fit = GLM(spikes, design_matrix,
family=families.Poisson(),
drop='missing').fit(maxiter=30)
return fit if fit.converged else np.nan
except np.linalg.linalg.LinAlgError:
logger.warn('Data is poorly scaled for neuron #{}'.format(ind + 1))
return np.nan
model = GLM(spikes, design_matrix, family=families.Poisson(),
drop='missing')
regularization_weights = np.ones((design_matrix.shape[1],)) * penalty
regularization_weights[0] = 0.0
return model.fit_regularized(alpha=regularization_weights, L1_wt=0)


def predictors_by_trajectory_direction(trajectory_direction,
Expand Down

0 comments on commit 20bac2e

Please sign in to comment.