Skip to content

Commit

Permalink
added uniform dithering method with refr period (NeuralEnsemble#297)
Browse files Browse the repository at this point in the history
* upd surrogate type

* Removed dither_spikes_with_refr_period from surrogates wrapper

Co-authored-by: p-bouss <peter.bouss@googlemail.com>
Co-authored-by: dizcza <dizcza@gmail.com>
  • Loading branch information
3 people committed Apr 3, 2020
1 parent 8022903 commit 8a9f6ec
Show file tree
Hide file tree
Showing 3 changed files with 414 additions and 242 deletions.
7 changes: 7 additions & 0 deletions elephant/spade.py
Original file line number Diff line number Diff line change
Expand Up @@ -1265,6 +1265,13 @@ def pvalue_spectrum(spiketrains, binsize, winlen, dither, n_surr, min_spikes=2,
if surr_method == 'joint_isi_dithering':
surrs = [instance.dithering()[0] for
instance in joint_isi_instances]
elif surr_method == 'dither_spikes_with_refractory_period':
# The initial refractory period is set to the bin size in order to
# prevent that spikes fall into the same bin, if the spike trains
# are sparse (min(ISI)>bin size).
surrs = [surr.dither_spikes(
spiketrain, dither=dither, n=1, refractory_period=binsize)[0]
for spiketrain in spiketrains]
else:
surrs = [surr.surrogates(
spiketrain, n=1, surr_method=surr_method,
Expand Down
163 changes: 118 additions & 45 deletions elephant/spike_train_surrogates.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,25 +46,70 @@
from __future__ import division, print_function, unicode_literals

import random
from functools import partial

import neo
import numpy as np
import quantities as pq
from scipy.ndimage import gaussian_filter

try:
import elephant.statistics as es

isi = es.isi
except ImportError:
from .statistics import isi # Convenience when in elephant working dir.
from elephant.statistics import isi

# List of all available surrogate methods
SURR_METHODS = ['dither_spike_train', 'dither_spikes', 'jitter_spikes',
'randomise_spikes', 'shuffle_isis', 'joint_isi_dithering']


def dither_spikes(spiketrain, dither, n=1, decimals=None, edges=True):
'randomise_spikes', 'shuffle_isis', 'joint_isi_dithering',
'dither_spikes_with_refractory_period']


def _dither_spikes_with_refractory_period(spiketrain, dither, n,
refractory_period):
units = spiketrain.units
t_start = spiketrain.t_start.rescale(units).magnitude
t_stop = spiketrain.t_stop.rescale(units).magnitude

dither = dither.rescale(units).magnitude
refractory_period = refractory_period.rescale(units).magnitude
# The initially guesses refractory period is compared to the minimal ISI.
# The smaller value is taken as the refractory to calculate with.
refractory_period = np.min(np.diff(spiketrain.magnitude),
initial=refractory_period)

dithered_spiketrains = []
for _ in range(n):
dithered_st = np.copy(spiketrain.magnitude)
random_ordered_ids = np.arange(len(spiketrain))
np.random.shuffle(random_ordered_ids)

for random_id in random_ordered_ids:
spike = dithered_st[random_id]
prev_spike = dithered_st[random_id - 1] \
if random_id > 0 \
else t_start - refractory_period
# subtract refractory period so that the first spike can move up
# to t_start
next_spike = dithered_st[random_id + 1] \
if random_id < len(spiketrain) - 1 \
else t_stop + refractory_period
# add refractory period so that the last spike can move up
# to t_stop

# Dither range in the direction to the previous spike
prev_dither = min(dither, spike - prev_spike - refractory_period)
# Dither range in the direction to the next spike
next_dither = min(dither, next_spike - spike - refractory_period)

dt = (prev_dither + next_dither) * random.random() - prev_dither
dithered_st[random_id] += dt

dithered_spiketrains.append(dithered_st)

dithered_spiketrains = np.array(dithered_spiketrains) * units

return dithered_spiketrains


def dither_spikes(spiketrain, dither, n=1, decimals=None, edges=True,
refractory_period=None):
"""
Generates surrogates of a spike train by spike dithering.
Expand All @@ -87,7 +132,8 @@ def dither_spikes(spiketrain, dither, n=1, decimals=None, edges=True):
Number of surrogates to be generated.
Default: 1.
decimals : int or None, optional
Number of decimal points for every spike time in the surrogates.
Number of decimal points for every spike time in the surrogates at a
millisecond level.
If None, machine precision is used.
Default: None.
edges : bool, optional
Expand All @@ -96,6 +142,16 @@ def dither_spikes(spiketrain, dither, n=1, decimals=None, edges=True):
(for `edges = True`) or set them to the range's closest end
(for `edges = False`).
Default: True.
refractory_period : pq.Quantity or None, optional
The dither range of each spike is adjusted such that the spike can not
fall into the `refractory_period` of the previous or next spike.
To account this, the refractory period is estimated as the smallest ISI
of the spike train. The given argument `refractory_period` here is thus
an initial estimation.
Note, that with this option a spike cannot "jump" over the previous or
next spike as it is normally possible.
If set to `None`, no refractoriness is in dithering.
Default: None
Returns
-------
Expand Down Expand Up @@ -124,34 +180,54 @@ def dither_spikes(spiketrain, dither, n=1, decimals=None, edges=True):
[0.0 ms, 1000.0 ms])>]
"""
# Transform spiketrain into a Quantity object (needed for matrix algebra)
data = spiketrain.view(pq.Quantity)
if len(spiketrain) == 0:
# return the empty spiketrain `n` times
return [spiketrain.copy() for _ in range(n)]

units = spiketrain.units
t_start = spiketrain.t_start.rescale(units).magnitude
t_stop = spiketrain.t_stop.rescale(units).magnitude

if refractory_period is None or refractory_period == 0:
# Main: generate the surrogates
dither = dither.rescale(units).magnitude
dithered_spiketrains = \
spiketrain.magnitude.reshape((1, len(spiketrain))) \
+ 2 * dither * np.random.random_sample((n, len(spiketrain)))\
- dither
dithered_spiketrains.sort(axis=0)

if edges:
# Leave out all spikes outside
# [spiketrain.t_start, spiketrain.t_stop]
dithered_spiketrains = \
[train[
np.all([t_start < train, train < t_stop], axis=0)]
for train in dithered_spiketrains]
else:
# Move all spikes outside
# [spiketrain.t_start, spiketrain.t_stop] to the range's ends
dithered_spiketrains = np.minimum(
np.maximum(dithered_spiketrains, t_start),
t_stop)

# Main: generate the surrogates
surr = data.reshape((1, len(data))) + 2 * dither * np.random.random_sample(
(n, len(data))) - dither
# Round the surrogate data to decimal position, if requested
if decimals is not None:
surr = surr.round(decimals)
dithered_spiketrains = dithered_spiketrains * units

if edges is False:
# Move all spikes outside [spiketrain.t_start, spiketrain.t_stop] to
# the range's ends
surr = np.minimum(np.maximum(surr.simplified.magnitude,
spiketrain.t_start.simplified.magnitude),
spiketrain.t_stop.simplified.magnitude) * pq.s
elif isinstance(refractory_period, pq.Quantity):
dithered_spiketrains = _dither_spikes_with_refractory_period(
spiketrain, dither, n, refractory_period)
else:
# Leave out all spikes outside [spiketrain.t_start, spiketrain.t_stop]
tstart, tstop = spiketrain.t_start.simplified.magnitude, \
spiketrain.t_stop.simplified.magnitude
surr = [np.sort(s[np.all([s >= tstart, s < tstop], axis=0)]) * pq.s
for s in surr.simplified.magnitude]
raise ValueError("refractory_period must be of type pq.Quantity")

# Return the surrogates as SpikeTrains
return [neo.SpikeTrain(s,
t_start=spiketrain.t_start,
t_stop=spiketrain.t_stop).rescale(spiketrain.units)
for s in surr]
# Round the surrogate data to decimal position, if requested
if decimals is not None:
dithered_spiketrains = \
dithered_spiketrains.rescale(pq.ms).round(decimals).rescale(units)

# Return the surrogates as list of neo.SpikeTrain
return [neo.SpikeTrain(
train, t_start=t_start, t_stop=t_stop)
for train in dithered_spiketrains]


def randomise_spikes(spiketrain, n=1, decimals=None):
Expand Down Expand Up @@ -266,7 +342,7 @@ def shuffle_isis(spiketrain, n=1, decimals=None):
isi0 = spiketrain[0] - spiketrain.t_start
ISIs = np.hstack([isi0, isi(spiketrain)])

# Round the ISIs to decimal position, if requested
# Round the isis to decimal position, if requested
if decimals is not None:
ISIs = ISIs.round(decimals)

Expand All @@ -280,12 +356,9 @@ def shuffle_isis(spiketrain, n=1, decimals=None):
t_stop=spiketrain.t_stop))

else:
sts = []
empty_train = neo.SpikeTrain([] * spiketrain.units,
t_start=spiketrain.t_start,
t_stop=spiketrain.t_stop)
for i in range(n):
sts.append(empty_train)
sts = [neo.SpikeTrain([] * spiketrain.units,
t_start=spiketrain.t_start,
t_stop=spiketrain.t_stop)] * n

return sts

Expand Down Expand Up @@ -936,13 +1009,13 @@ def surrogates(spiketrain, n=1, surr_method='dither_spike_train', dt=None,
raise AttributeError(
'specified surr_method (=%s) not valid' % surr_method)

if surr_method in ['dither_spike_train', 'dither_spikes']:
if surr_method in ('dither_spike_train', 'dither_spikes'):
return surrogate_types[surr_method](
spiketrain, dt, n=n, decimals=decimals, edges=edges)
if surr_method in ['randomise_spikes', 'shuffle_isis']:
if surr_method in ('randomise_spikes', 'shuffle_isis'):
return surrogate_types[surr_method](
spiketrain, n=n, decimals=decimals)
elif surr_method == 'jitter_spikes':
if surr_method == 'jitter_spikes':
return surrogate_types[surr_method](
spiketrain, dt, n=n)
# surr_method == 'joint_isi_dithering':
Expand Down

0 comments on commit 8a9f6ec

Please sign in to comment.