Skip to content

Commit

Permalink
Allow user to specify sampling frequency
Browse files Browse the repository at this point in the history
  • Loading branch information
edeno committed Feb 7, 2018
1 parent 2dda678 commit cdb129e
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions replay_classification/decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,9 +226,10 @@ def marginalized_intensities(self):
return xr.DataArray(marginalized_intensities, dims=dims,
coords=coords)

def plot_observation_model(self):
def plot_observation_model(self, sampling_frequency=1):
marginalized_intensities = (
self.marginalized_intensities().sum('mark_dimension'))
self.marginalized_intensities().sum('mark_dimension')
* sampling_frequency)
try:
return marginalized_intensities.plot(
row='signal', col='state', x='position', y='marks',
Expand Down Expand Up @@ -423,7 +424,7 @@ def plot_state_transition_model(self, **kwargs):
.plot(x='position_t', y='position_t_1',
robust=True, **kwargs))

def plot_observation_model(self):
def plot_observation_model(self, sampling_frequency=1):
conditional_intensity = self._combined_likelihood_kwargs[
'likelihood_kwargs']['conditional_intensity']
coords = dict(
Expand All @@ -435,7 +436,8 @@ def plot_observation_model(self):
dims=['signal', 'state', 'position'],
name='firing_rate').to_dataframe().reset_index()
g = sns.FacetGrid(
conditional_intensity, row='signal', col='state')
conditional_intensity * sampling_frequency,
row='signal', col='state')
return g.map(plt.plot, 'position', 'firing_rate')

def predict(self, spikes, time=None):
Expand Down

0 comments on commit cdb129e

Please sign in to comment.