Skip to content

Commit

Permalink
fixed bugs in joint-ISI and shuffle ISI (NeuralEnsemble#364)
Browse files Browse the repository at this point in the history
Co-authored-by: dizcza <dizcza@gmail.com>
  • Loading branch information
pbouss and dizcza committed Oct 23, 2020
1 parent a0ca215 commit 1d48a77
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 29 deletions.
61 changes: 36 additions & 25 deletions elephant/spike_train_surrogates.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,31 +365,32 @@ def shuffle_isis(spiketrain, n_surrogates=1, decimals=None):
[0.0 ms, 1000.0 ms])>]
"""
if len(spiketrain) > 0:
isi0 = spiketrain[0] - spiketrain.t_start
ISIs = np.hstack([isi0, isi(spiketrain)])

# Round the isis to decimal position, if requested
if decimals is not None:
ISIs = ISIs.round(decimals)

# Create list of surrogate spike trains by random ISI permutation
sts = []
for surrogate_id in range(n_surrogates):
surr_times = np.cumsum(np.random.permutation(ISIs)) * \
spiketrain.units + spiketrain.t_start
sts.append(neo.SpikeTrain(
surr_times, t_start=spiketrain.t_start,
t_stop=spiketrain.t_stop,
sampling_rate=spiketrain.sampling_rate))

else:
sts = [neo.SpikeTrain([] * spiketrain.units,
t_start=spiketrain.t_start,
t_stop=spiketrain.t_stop,
sampling_rate=spiketrain.sampling_rate)
] * n_surrogates
if len(spiketrain) == 0:
return [neo.SpikeTrain([] * spiketrain.units,
t_start=spiketrain.t_start,
t_stop=spiketrain.t_stop,
sampling_rate=spiketrain.sampling_rate)
for _ in range(n_surrogates)]

# A correct sorting is necessary, to calculate the ISIs
spiketrain = spiketrain.copy()
spiketrain.sort()
isi0 = spiketrain[0] - spiketrain.t_start
isis = np.hstack([isi0, isi(spiketrain)])

# Round the isis to decimal position, if requested
if decimals is not None:
isis = isis.round(decimals)

# Create list of surrogate spike trains by random ISI permutation
sts = []
for surrogate_id in range(n_surrogates):
surr_times = np.cumsum(np.random.permutation(isis)) * \
spiketrain.units + spiketrain.t_start
sts.append(neo.SpikeTrain(
surr_times, t_start=spiketrain.t_start,
t_stop=spiketrain.t_stop,
sampling_rate=spiketrain.sampling_rate))
return sts


Expand Down Expand Up @@ -740,6 +741,9 @@ def __init__(self,
if not isinstance(spiketrain, neo.SpikeTrain):
raise TypeError('spiketrain must be of type neo.SpikeTrain')

# A correct sorting is necessary to calculate the ISIs
spiketrain = spiketrain.copy()
spiketrain.sort()
self.spiketrain = spiketrain
self.truncation_limit = self._get_magnitude(truncation_limit)
self.n_bins = n_bins
Expand Down Expand Up @@ -975,6 +979,12 @@ def dithering(self, n_surrogates=1):
dithered_st = self.spiketrain[0].magnitude + \
np.r_[0., np.cumsum(dithered_isi)]
sampling_rate = self.spiketrain.sampling_rate

# Due to rounding errors, the last spike may be above t_stop.
# If the case, this is set to t_stop.
if dithered_st[-1] > self.spiketrain.t_stop:
dithered_st[-1] = self.spiketrain.t_stop

dithered_st = neo.SpikeTrain(dithered_st * self._unit,
t_start=self.spiketrain.t_start,
t_stop=self.spiketrain.t_stop,
Expand All @@ -994,7 +1004,8 @@ def _determine_cumulative_functions(self):
jisih_cum = self._normalize_cumulative_distribution(
np.cumsum(diagonal))
self._jisih_cumulatives.append(jisih_cum)
self._jisih_cumulatives = np.array(self._jisih_cumulatives)
self._jisih_cumulatives = np.array(
self._jisih_cumulatives, dtype=object)
else:
self._jisih_cumulatives = self._window_cumulatives(rotated_jisih)

Expand Down
62 changes: 58 additions & 4 deletions elephant/test/test_spike_train_surrogates.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,24 @@ def test_shuffle_isis_output_decimals(self):

self.assertTrue(np.all(ISIs_orig == ISIs_surr))

def test_shuffle_isis_with_wrongly_ordered_spikes(self):
surr_method = 'shuffle_isis'
n_surr = 30
dither = 15 * pq.ms
spiketrain = neo.SpikeTrain(
[39.65696411, 98.93868274, 120.2417674, 134.70971166,
154.20788924,
160.29077989, 179.19884034, 212.86773029, 247.59488061,
273.04095041,
297.56437605, 344.99204215, 418.55696486, 460.54298334,
482.82299125,
524.236052, 566.38966742, 597.87562722, 651.26965293,
692.39802855,
740.90285815, 849.45874695, 974.57724848, 8.79247605],
t_start=0.*pq.ms, t_stop=1000.*pq.ms, units=pq.ms)
surr.surrogates(spiketrain, n_surrogates=n_surr, method=surr_method,
dt=dither)

def test_dither_spike_train_output_format(self):

spiketrain = neo.SpikeTrain(
Expand Down Expand Up @@ -531,6 +549,41 @@ def test_joint_isi_dithering_output(self):
0.05146, 0.058489, 0.078053]
assert_array_almost_equal(surrogate_train.magnitude, ground_truth)

def test_joint_isi_with_wrongly_ordered_spikes(self):
surr_method = 'joint_isi_dithering'
n_surr = 30
dither = 15 * pq.ms
spiketrain = neo.SpikeTrain(
[39.65696411, 98.93868274, 120.2417674, 134.70971166,
154.20788924,
160.29077989, 179.19884034, 212.86773029, 247.59488061,
273.04095041,
297.56437605, 344.99204215, 418.55696486, 460.54298334,
482.82299125,
524.236052, 566.38966742, 597.87562722, 651.26965293,
692.39802855,
740.90285815, 849.45874695, 974.57724848, 8.79247605],
t_start=0.*pq.ms, t_stop=1000.*pq.ms, units=pq.ms)
surr.surrogates(spiketrain, n_surrogates=n_surr, method=surr_method,
dt=dither)

def test_joint_isi_spikes_at_border(self):
surr_method = 'joint_isi_dithering'
n_surr = 30
dither = 15 * pq.ms
spiketrain = neo.SpikeTrain(
[4., 28., 45., 51., 83., 87., 96., 111., 126., 131.,
138., 150.,
209., 232., 253., 275., 279., 303., 320., 371., 396.,
401., 429., 447.,
479., 511., 535., 549., 581., 585., 605., 607., 626.,
630., 644., 714.,
832., 835., 853., 858., 878., 905., 909., 932., 950.,
961., 999., 1000.],
t_start=0.*pq.ms, t_stop=1000.*pq.ms, units=pq.ms)
surr.surrogates(
spiketrain, n_surrogates=n_surr, method=surr_method, dt=dither)

def test_bin_shuffling_output_format(self):

self.bin_size = 3*pq.ms
Expand Down Expand Up @@ -573,10 +626,11 @@ def test_bin_shuffling_empty_train(self):
self.assertEqual(np.sum(surrogate_train.to_bool_array()), 0)

def test_trial_shuffling_output_format(self):
spiketrain = [neo.SpikeTrain([90, 93, 97, 100, 105,
150, 180, 190] * pq.ms, t_stop=.2 * pq.s),
neo.SpikeTrain([90, 93, 97, 100, 105,
150, 180, 190] * pq.ms, t_stop=.2 * pq.s)]
spiketrain = \
[neo.SpikeTrain([90, 93, 97, 100, 105, 150, 180, 190] * pq.ms,
t_stop=.2 * pq.s),
neo.SpikeTrain([90, 93, 97, 100, 105, 150, 180, 190] * pq.ms,
t_stop=.2 * pq.s)]
# trial_length = 200 * pq.ms
# trial_separation = 50 * pq.ms
n_surrogates = 2
Expand Down

0 comments on commit 1d48a77

Please sign in to comment.