[Feature] Synchrofact Detection (NeuralEnsemble#322)
Co-authored-by: Aitor <a.morales-gregorio@fz-juelich.de>
Co-authored-by: dizcza <dizcza@gmail.com>
3 people committed Dec 4, 2020
1 parent cc19b17 commit 9eadcf3
Showing 11 changed files with 1,158 additions and 107 deletions.
7 changes: 4 additions & 3 deletions doc/conf.py
@@ -123,8 +123,8 @@
# the autosummary fields of each module.
autosummary_generate = True

-# don't overwrite our custom toctree/*.rst
-autosummary_generate_overwrite = False
+# Set to False to not overwrite our custom toctree/*.rst
+autosummary_generate_overwrite = True

# -- Options for HTML output ---------------------------------------------

@@ -344,7 +344,8 @@

# configuration for intersphinx: refer to Viziphant
intersphinx_mapping = {
-    'viziphant': ('https://viziphant.readthedocs.io/en/latest/', None)
+    'viziphant': ('https://viziphant.readthedocs.io/en/latest/', None),
+    'numpy': ('https://numpy.org/doc/stable', None)
}

# Use more reliable mathjax source
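
With the numpy inventory registered above, numpy objects referenced in Elephant docstrings link to the external numpy documentation when the docs are built. A minimal illustration (the function and its docstring below are hypothetical, not part of this commit):

import numpy as np

def bin_counts(counts):
    """
    Return the spike counts per bin as a :class:`numpy.ndarray`.

    The ``:class:`numpy.ndarray``` role resolves through the ``'numpy'``
    entry of ``intersphinx_mapping`` in ``doc/conf.py``.
    """
    return np.asarray(counts)
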
1 change: 1 addition & 0 deletions doc/modules.rst
@@ -26,6 +26,7 @@ Function Reference by Module
   reference/sta
   reference/statistics
   reference/unitary_event_analysis
   reference/utils
   reference/waveform_features


5 changes: 5 additions & 0 deletions doc/reference/utils.rst
@@ -0,0 +1,5 @@
=================
Utility functions
=================

.. automodule:: elephant.utils
34 changes: 4 additions & 30 deletions elephant/conversion.py
@@ -27,7 +27,7 @@
import scipy.sparse as sps

from elephant.utils import is_binary, deprecated_alias, \
-    check_neo_consistency, get_common_start_stop_times
+    check_neo_consistency, get_common_start_stop_times, round_binning_errors

__all__ = [
    "binarize",
@@ -185,18 +185,6 @@ def binarize(spiketrain, sampling_rate=None, t_start=None, t_stop=None,
###########################################################################


-def _detect_rounding_errors(values, tolerance):
-    """
-    Finds rounding errors in values that will be cast to int afterwards.
-    Returns True for values that are within tolerance of the next integer.
-    Works for both scalars and numpy arrays.
-    """
-    if tolerance is None or tolerance == 0:
-        return np.zeros_like(values, dtype=bool)
-    # same as '1 - (values % 1) <= tolerance' but faster
-    return 1 - tolerance <= values % 1


class BinnedSpikeTrain(object):
"""
Class which calculates a binned spike train and provides methods to
@@ -417,12 +405,8 @@ def get_n_bins():
            n_bins = (self._t_stop - self._t_start) / self._bin_size
            if isinstance(n_bins, pq.Quantity):
                n_bins = n_bins.simplified.item()
-            if _detect_rounding_errors(n_bins, tolerance=tolerance):
-                warnings.warn('Correcting a rounding error in the calculation '
-                              'of n_bins by increasing n_bins by 1. '
-                              'You can set tolerance=None to disable this '
-                              'behaviour.')
-            return int(n_bins)
+            n_bins = round_binning_errors(n_bins, tolerance=tolerance)
+            return n_bins

        def check_n_bins_consistency():
            if self.n_bins != get_n_bins():
@@ -825,17 +809,7 @@ def _create_sparse_matrix(self, spiketrains, tolerance):

            # shift spikes that are very close
            # to the right edge into the next bin
-            rounding_error_indices = _detect_rounding_errors(
-                bins, tolerance=tolerance)
-            num_rounding_corrections = rounding_error_indices.sum()
-            if num_rounding_corrections > 0:
-                warnings.warn('Correcting {} rounding errors by shifting '
-                              'the affected spikes into the following bin. '
-                              'You can set tolerance=None to disable this '
-                              'behaviour.'.format(num_rounding_corrections))
-            bins[rounding_error_indices] += .5
-
-            bins = bins.astype(np.int32)
+            bins = round_binning_errors(bins, tolerance=tolerance)
            valid_bins = bins[bins < self.n_bins]
            n_discarded += len(bins) - len(valid_bins)
            f, c = np.unique(valid_bins, return_counts=True)
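
The rounding logic removed from this module is consolidated into `round_binning_errors`, now imported above from `elephant.utils` (the helper is part of this commit but its implementation is not shown in this excerpt). A minimal sketch of the intended behaviour, assuming it mirrors the removed `_detect_rounding_errors` logic; the real helper presumably also emits the warnings that were issued here before:

import numpy as np

def round_binning_errors_sketch(values, tolerance=1e-8):
    # Bin values lying within `tolerance` below the next integer are rounded
    # up instead of truncated; all other values are floored by the int cast.
    values = np.asarray(values, dtype=float)
    if tolerance is not None:
        values = values + 0.5 * (1 - tolerance <= values % 1)
    rounded = values.astype(np.int32)
    return rounded.item() if rounded.ndim == 0 else rounded

# e.g. a computed bin count of 29.999999999 becomes 30, while 29.4 stays 29
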
175 changes: 175 additions & 0 deletions elephant/spike_train_synchrony.py
@@ -10,6 +10,7 @@
    :toctree: toctree/spike_train_synchrony/

    spike_contrast
    Synchrotool

:copyright: Copyright 2015-2020 by the Elephant team, see `doc/authors.rst`.
@@ -18,17 +19,26 @@
from __future__ import division, print_function, unicode_literals

from collections import namedtuple
from copy import deepcopy

import neo
import numpy as np
import quantities as pq

from elephant.statistics import Complexity
from elephant.utils import is_time_quantity

SpikeContrastTrace = namedtuple("SpikeContrastTrace", (
    "contrast", "active_spiketrains", "synchrony"))


__all__ = [
    "SpikeContrastTrace",
    "spike_contrast",
    "Synchrotool"
]


def _get_theta_and_n_per_bin(spiketrains, t_start, t_stop, bin_size):
    """
    Calculates theta (amount of spikes per bin) and the amount of active spike
@@ -218,3 +228,168 @@ def spike_contrast(spiketrains, t_start=None, t_stop=None,
        return synchrony, spike_contrast_trace

    return synchrony


class Synchrotool(Complexity):
    """
    Tool class to find, remove and/or annotate the presence of synchronous
    spiking events across multiple spike trains.

    The complexity is used to characterize synchronous events within the same
    spike train and across different spike trains in the `spiketrains` list.
    This way, synchronous events can be found both in multi-unit and
    single-unit spike trains.

    This class inherits from :class:`elephant.statistics.Complexity`; see its
    documentation for more details and a description of the input parameters.

    See also
    --------
    elephant.statistics.Complexity

    """

    def __init__(self, spiketrains,
                 sampling_rate,
                 bin_size=None,
                 binary=True,
                 spread=0,
                 tolerance=1e-8):

        self.annotated = False

        super(Synchrotool, self).__init__(spiketrains=spiketrains,
                                          bin_size=bin_size,
                                          sampling_rate=sampling_rate,
                                          binary=binary,
                                          spread=spread,
                                          tolerance=tolerance)

    def delete_synchrofacts(self, threshold, in_place=False, mode='delete'):
        """
        Delete or extract synchronous spiking events.

        Parameters
        ----------
        threshold : int
            Threshold value for the deletion of spikes engaged in synchronous
            activity.

            * `threshold >= 2` leads to all spikes with a complexity value
              greater than or equal to the threshold being deleted/extracted.
            * `threshold <= 1` leads to a ValueError, since this would
              delete/extract all spikes and there are more efficient ways of
              doing so.
        in_place : bool, optional
            Determines whether the modifications are made in place
            on ``self.input_spiketrains``.
            Default: False
        mode : {'delete', 'extract'}, optional
            Inversion of the mask for deletion of synchronous events.

            * ``'delete'`` leads to the deletion of all spikes with
              complexity >= `threshold`, i.e. deletes synchronous spikes.
            * ``'extract'`` leads to the deletion of all spikes with
              complexity < `threshold`, i.e. extracts synchronous spikes.
            Default: 'delete'

        Raises
        ------
        ValueError
            If `mode` is not one of {'delete', 'extract'}.

            If `threshold <= 1`.

        Returns
        -------
        list of neo.SpikeTrain
            List of spiketrains where the spikes with
            ``complexity >= threshold`` have been deleted/extracted.

            * If ``in_place`` is True, the returned list is the same object
              as ``self.input_spiketrains``.
            * If ``in_place`` is False, the returned list is a deepcopy of
              ``self.input_spiketrains``.
        """

        if not self.annotated:
            self.annotate_synchrofacts()

        if mode not in ['delete', 'extract']:
            raise ValueError(f"Invalid mode '{mode}'. Valid modes are: "
                             f"'delete', 'extract'")

        if threshold <= 1:
            raise ValueError('A deletion threshold <= 1 would result '
                             'in the deletion of all spikes.')

        if in_place:
            spiketrain_list = self.input_spiketrains
        else:
            spiketrain_list = deepcopy(self.input_spiketrains)

        for idx, st in enumerate(spiketrain_list):
            mask = st.array_annotations['complexity'] < threshold
            if mode == 'extract':
                mask = np.invert(mask)
            new_st = st[mask]
            spiketrain_list[idx] = new_st
            if in_place:
                segment = st.segment
                if segment is None:
                    continue

                # replace link to spiketrain in segment
                new_index = self._get_spiketrain_index(
                    segment.spiketrains, st)
                segment.spiketrains[new_index] = new_st

                block = segment.block
                if block is None:
                    continue

                # replace link to spiketrain in groups
                for group in block.groups:
                    try:
                        idx = self._get_spiketrain_index(
                            group.spiketrains,
                            st)
                    except ValueError:
                        # st is not in this group, move to next group
                        continue

                    # st found in group, replace with new_st
                    group.spiketrains[idx] = new_st

        return spiketrain_list

    def annotate_synchrofacts(self):
        """
        Annotate each spike with the complexity value of the epoch it falls
        into, writing the result to the ``array_annotations`` of each input
        spike train *in-place*.
        """
        epoch_complexities = self.epoch.array_annotations['complexity']
        right_edges = (
            self.epoch.times.magnitude.flatten()
            + self.epoch.durations.rescale(
                self.epoch.times.units).magnitude.flatten()
        )

        for idx, st in enumerate(self.input_spiketrains):

            # all indices of spikes that are within the half-open intervals
            # defined by the boundaries
            # note that every second entry in boundaries is an upper boundary
            spike_to_epoch_idx = np.searchsorted(
                right_edges,
                st.times.rescale(self.epoch.times.units).magnitude.flatten())
            complexity_per_spike = epoch_complexities[spike_to_epoch_idx]

            st.array_annotate(complexity=complexity_per_spike)

        self.annotated = True
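
A small numeric illustration of the `np.searchsorted` mapping used in `annotate_synchrofacts` (the epoch edges, spike times and complexities below are invented for illustration): each spike is assigned the index of the first epoch whose right edge is not smaller than the spike time, so a spike sitting exactly on a right edge still falls into that epoch.

import numpy as np

right_edges = np.array([2.0, 5.0, 9.0])       # epoch end times, e.g. in ms
spike_times = np.array([1.5, 5.0, 7.2])       # spike times of one train, in ms
epoch_complexities = np.array([1, 3, 2])      # complexity of each epoch

spike_to_epoch_idx = np.searchsorted(right_edges, spike_times)   # [0, 1, 2]
complexity_per_spike = epoch_complexities[spike_to_epoch_idx]    # [1, 3, 2]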

    def _get_spiketrain_index(self, spiketrain_list, spiketrain):
        for index, item in enumerate(spiketrain_list):
            if item is spiketrain:
                return index
        raise ValueError("Spiketrain is not found in the list")
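
A sketch of the intended workflow with the new class; the spike times, sampling rate and threshold below are toy values chosen only for illustration.

import quantities as pq
from neo import SpikeTrain
from elephant.spike_train_synchrony import Synchrotool

# two toy spike trains that share a synchronous spike at 5 ms
spiketrains = [SpikeTrain([1, 5, 9] * pq.ms, t_stop=10 * pq.ms),
               SpikeTrain([5, 7] * pq.ms, t_stop=10 * pq.ms)]

synchrotool = Synchrotool(spiketrains, sampling_rate=1 * pq.kHz, spread=0)

# write a 'complexity' array annotation to every input spike train
synchrotool.annotate_synchrofacts()

# remove every spike that takes part in an event of complexity >= 2;
# mode='extract' would instead keep only those spikes
cleaned = synchrotool.delete_synchrofacts(threshold=2, in_place=False,
                                          mode='delete')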
