Skip to content

Commit

Permalink
unit-less BinnedSpikeTrain (NeuralEnsemble#378)
Browse files Browse the repository at this point in the history
* BinnedSpikeTrainView

* optimized bin_shuffling

* optimized spike_train_timescale

* optimized CCH
  • Loading branch information
dizcza committed Nov 5, 2020
1 parent d1f2fd5 commit 366cd7c
Show file tree
Hide file tree
Showing 7 changed files with 286 additions and 155 deletions.
270 changes: 181 additions & 89 deletions elephant/conversion.py

Large diffs are not rendered by default.

83 changes: 45 additions & 38 deletions elephant/spike_train_correlation.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,24 +74,24 @@ def get_valid_lags(binned_spiketrain_i, binned_spiketrain_j):
at full overlap (valid mode).
"""

bin_size = binned_spiketrain_i.bin_size
bin_size = binned_spiketrain_i._bin_size

# see cross_correlation_histogram for the examples
if binned_spiketrain_i.n_bins < binned_spiketrain_j.n_bins:
# ex. 1) lags range: [-2, 5] ms
# ex. 2) lags range: [1, 2] ms
left_edge = (binned_spiketrain_j.t_start -
binned_spiketrain_i.t_start) / bin_size
right_edge = (binned_spiketrain_j.t_stop -
binned_spiketrain_i.t_stop) / bin_size
left_edge = (binned_spiketrain_j._t_start -
binned_spiketrain_i._t_start) / bin_size
right_edge = (binned_spiketrain_j._t_stop -
binned_spiketrain_i._t_stop) / bin_size
else:
# ex. 3) lags range: [-1, 3] ms
left_edge = (binned_spiketrain_j.t_stop -
binned_spiketrain_i.t_stop) / bin_size
right_edge = (binned_spiketrain_j.t_start -
binned_spiketrain_i.t_start) / bin_size
right_edge = int(right_edge.simplified.magnitude)
left_edge = int(left_edge.simplified.magnitude)
left_edge = (binned_spiketrain_j._t_stop -
binned_spiketrain_i._t_stop) / bin_size
right_edge = (binned_spiketrain_j._t_start -
binned_spiketrain_i._t_start) / bin_size
right_edge = int(right_edge)
left_edge = int(left_edge)
lags = np.arange(left_edge, right_edge + 1, dtype=np.int32)

return lags
Expand All @@ -106,9 +106,6 @@ def correlate_memory(self, cch_mode):
Cross-correlation of `self.binned_spiketrain1` and
`self.binned_spiketrain2`.
"""
binned_spiketrain1 = self.binned_spiketrain_i
binned_spiketrain2 = self.binned_spiketrain_j

st1_spmat = self.binned_spiketrain_i.sparse_matrix
st2_spmat = self.binned_spiketrain_j.sparse_matrix
left_edge, right_edge = self.window
Expand All @@ -120,7 +117,8 @@ def correlate_memory(self, cch_mode):
# 'valid' mode requires bins correction due to the shift in t_starts
# 'full' and 'pad' modes don't need this correction
if cch_mode == "valid":
if binned_spiketrain1.n_bins > binned_spiketrain2.n_bins:
if self.binned_spiketrain_i.n_bins > \
self.binned_spiketrain_j.n_bins:
st2_bin_idx_unique += right_edge
else:
st2_bin_idx_unique += left_edge
Expand Down Expand Up @@ -354,7 +352,7 @@ def covariance(binned_spiketrain, binary=False, fast=True):
"""
if binary:
binned_spiketrain = binned_spiketrain.binarize(copy=True)
binned_spiketrain = binned_spiketrain.binarize()

if fast and binned_spiketrain.sparsity > _SPARSITY_MEMORY_EFFICIENT_THR:
array = binned_spiketrain.to_array()
Expand Down Expand Up @@ -454,7 +452,7 @@ def correlation_coefficient(binned_spiketrain, binary=False, fast=True):
"""
if binary:
binned_spiketrain = binned_spiketrain.binarize(copy=True)
binned_spiketrain = binned_spiketrain.binarize()

if fast and binned_spiketrain.sparsity > _SPARSITY_MEMORY_EFFICIENT_THR:
array = binned_spiketrain.to_array()
Expand Down Expand Up @@ -679,17 +677,21 @@ def cross_correlation_histogram(
if binned_spiketrain_i.shape[0] != 1 or \
binned_spiketrain_j.shape[0] != 1:
raise ValueError("Spike trains must be one dimensional")
if not np.isclose(binned_spiketrain_i.bin_size.simplified.item(),
binned_spiketrain_j.bin_size.simplified.item()):

# rescale to the common units
# this does not change the data - only its representation
binned_spiketrain_j.rescale(binned_spiketrain_i.units)

if not np.isclose(binned_spiketrain_i._bin_size,
binned_spiketrain_j._bin_size):
raise ValueError("Bin sizes must be equal")

bin_size = binned_spiketrain_i.bin_size
bin_size = binned_spiketrain_i._bin_size
left_edge_min = -binned_spiketrain_i.n_bins + 1
right_edge_max = binned_spiketrain_j.n_bins - 1

t_lags_shift = (binned_spiketrain_j.t_start -
binned_spiketrain_i.t_start) / bin_size
t_lags_shift = t_lags_shift.simplified.item()
t_lags_shift = (binned_spiketrain_j._t_start -
binned_spiketrain_i._t_start) / bin_size
if not np.isclose(t_lags_shift, round(t_lags_shift)):
# For example, if bin_size=1 ms, binned_spiketrain_i.t_start=0 ms, and
# binned_spiketrain_j.t_start=0.5 ms then there is a global shift in
Expand Down Expand Up @@ -746,8 +748,8 @@ def cross_correlation_histogram(
raise ValueError("Invalid window parameter")

if binary:
binned_spiketrain_i = binned_spiketrain_i.binarize(copy=True)
binned_spiketrain_j = binned_spiketrain_j.binarize(copy=True)
binned_spiketrain_i = binned_spiketrain_i.binarize()
binned_spiketrain_j = binned_spiketrain_j.binarize()

cch_builder = _CrossCorrHist(binned_spiketrain_i, binned_spiketrain_j,
window=(left_edge, right_edge))
Expand Down Expand Up @@ -775,11 +777,14 @@ def cross_correlation_histogram(
annotations = dict(cch_parameters=annotations)

# Transform the array count into an AnalogSignal
t_start = pq.Quantity((lags[0] - 0.5) * bin_size,
units=binned_spiketrain_i.units, copy=False)
cch_result = neo.AnalogSignal(
signal=np.expand_dims(cross_corr, axis=1),
units=pq.dimensionless,
t_start=(lags[0] - 0.5) * binned_spiketrain_i.bin_size,
sampling_period=binned_spiketrain_i.bin_size, **annotations)
t_start=t_start,
sampling_period=binned_spiketrain_i.bin_size, copy=False,
**annotations)
return cch_result, lags


Expand Down Expand Up @@ -975,9 +980,9 @@ def spike_train_timescale(binned_spiketrain, max_tau):
Returns
-------
timescale : pq.Quantity
The auto-correlation time of the binned spiketrain. If
`binned_spiketrain` has less than 2 spikes, a warning is raised and
`np.nan` is returned.
The auto-correlation time of the binned spiketrain with the same units
as in the input. If `binned_spiketrain` has less than 2 spikes, a
warning is raised and `np.nan` is returned.
Notes
-----
Expand All @@ -1002,14 +1007,15 @@ def spike_train_timescale(binned_spiketrain, max_tau):
"np.nan will be returned.")
return np.nan

bin_size = binned_spiketrain.bin_size
if not (max_tau / bin_size).simplified.units == pq.dimensionless:
bin_size = binned_spiketrain._bin_size
try:
max_tau = max_tau.rescale(binned_spiketrain.units).item()
except (AttributeError, ValueError):
raise ValueError("max_tau needs units of time")

# safe casting of max_tau/bin_size to integer
max_tau_bins = int(np.round((max_tau / bin_size).simplified.magnitude))
if not np.isclose(max_tau.simplified.magnitude,
(max_tau_bins * bin_size).simplified.magnitude):
max_tau_bins = int(round(max_tau / bin_size))
if not np.isclose(max_tau, max_tau_bins * bin_size):
raise ValueError("max_tau has to be a multiple of the bin_size")

cch_window = [-max_tau_bins, max_tau_bins]
Expand All @@ -1018,9 +1024,10 @@ def spike_train_timescale(binned_spiketrain, max_tau):
cross_correlation_coefficient=True
)
# Take only t > 0 values, in particular neglecting the delta peak.
corrfct_pos = corrfct.time_slice(bin_size / 2, corrfct.t_stop).flatten()
start_id = corrfct.time_index((bin_size / 2) * binned_spiketrain.units)
corrfct = corrfct.magnitude.squeeze()[start_id:]

# Calculate the timescale using trapezoidal integration
integr = np.abs((corrfct_pos / corrfct_pos[0]).magnitude)**2
integr = (corrfct / corrfct[0]) ** 2
timescale = 2 * integrate.trapz(integr, dx=bin_size)
return timescale
return pq.Quantity(timescale, units=binned_spiketrain.units, copy=False)
15 changes: 9 additions & 6 deletions elephant/spike_train_surrogates.py
Original file line number Diff line number Diff line change
Expand Up @@ -1376,16 +1376,19 @@ def surrogates(
if method is bin_shuffling:
binned_spiketrain = conv.BinnedSpikeTrain(
spiketrain, bin_size=kwargs['bin_size'])
bin_grid = binned_spiketrain.bin_centers.simplified.magnitude
bin_size = binned_spiketrain._bin_size
# bin_centers share the same units as bin_size
bin_grid = binned_spiketrain.bin_centers.magnitude
max_displacement = int(
dt.simplified.magnitude / kwargs['bin_size'].simplified.magnitude)
binned_surrogates = method(
binned_spiketrain, max_displacement, n_surrogates=n_surrogates)
dt.rescale(binned_spiketrain.units).item() / bin_size)
binned_surrogates = bin_shuffling(binned_spiketrain,
max_displacement=max_displacement,
n_surrogates=n_surrogates)
surrogate_spiketrains = \
[neo.SpikeTrain(bin_grid[binned_surr.to_bool_array()[0]] * pq.s,
[neo.SpikeTrain(bin_grid[binned_surr.sparse_matrix.nonzero()[1]],
t_start=spiketrain.t_start,
t_stop=spiketrain.t_stop,
units=spiketrain.units,
units=binned_spiketrain.units,
sampling_rate=spiketrain.sampling_rate)
for binned_surr in binned_surrogates]
return surrogate_spiketrains
Expand Down
4 changes: 2 additions & 2 deletions elephant/statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -924,7 +924,7 @@ def time_histogram(spiketrains, bin_size, t_start=None, t_stop=None,
bin_size=bin_size)

if binary:
bs = bs.binarize(copy=False)
bs = bs.binarize()
bin_hist = bs.get_num_of_spikes(axis=0)
# Flatten array
bin_hist = np.ravel(bin_hist)
Expand All @@ -944,7 +944,7 @@ def time_histogram(spiketrains, bin_size, t_start=None, t_stop=None,

return neo.AnalogSignal(signal=np.expand_dims(bin_hist, axis=1),
sampling_period=bin_size, units=bin_hist.units,
t_start=t_start, normalization=output)
t_start=t_start, normalization=output, copy=False)


@deprecated_alias(binsize='bin_size')
Expand Down
48 changes: 37 additions & 11 deletions elephant/test/test_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -531,29 +531,55 @@ def test_binned_to_binned(self):
self.assertRaises(ValueError, cv.BinnedSpikeTrain, a,
bin_size=1 * pq.s)

def test_binnend_spiketrain_rescaling(self):
def test_binnend_spiketrain_different_input_units(self):
train = neo.SpikeTrain(times=np.array([1.001, 1.002, 1.005]) * pq.s,
t_start=1 * pq.s, t_stop=1.01 * pq.s)
bst = cv.BinnedSpikeTrain(train,
t_start=1 * pq.s, t_stop=1.01 * pq.s,
bin_size=1 * pq.ms)
self.assertEqual(bst.units, pq.s)
target_edges = np.array([1000, 1001, 1002, 1003, 1004, 1005, 1006,
1007, 1008, 1009, 1010], dtype=np.float)
1007, 1008, 1009, 1010], dtype=np.float
) * pq.ms
target_centers = np.array(
[1000.5, 1001.5, 1002.5, 1003.5, 1004.5, 1005.5, 1006.5, 1007.5,
1008.5, 1009.5], dtype=np.float)
assert_array_almost_equal(bst.bin_edges.magnitude, target_edges)
assert_array_almost_equal(bst.bin_centers.magnitude, target_centers)
self.assertEqual(bst.bin_centers.units, pq.ms)
self.assertEqual(bst.bin_edges.units, pq.ms)
1008.5, 1009.5], dtype=np.float) * pq.ms
assert_array_almost_equal(bst.bin_edges, target_edges)
assert_array_almost_equal(bst.bin_centers, target_centers)

bst = cv.BinnedSpikeTrain(train,
t_start=1 * pq.s, t_stop=1010 * pq.ms,
bin_size=1 * pq.ms)
assert_array_almost_equal(bst.bin_edges.magnitude, target_edges)
assert_array_almost_equal(bst.bin_centers.magnitude, target_centers)
self.assertEqual(bst.bin_centers.units, pq.ms)
self.assertEqual(bst.bin_edges.units, pq.ms)
self.assertEqual(bst.units, pq.s)
assert_array_almost_equal(bst.bin_edges, target_edges)
assert_array_almost_equal(bst.bin_centers, target_centers)

def test_rescale(self):
    """Rescaling a BinnedSpikeTrain changes its units and re-expresses the
    internal unit-less floats (_t_start, _t_stop, _bin_size) accordingly."""
    train = neo.SpikeTrain(times=np.array([1.001, 1.002, 1.005]) * pq.s,
                           t_start=1 * pq.s, t_stop=1.01 * pq.s)
    bst = cv.BinnedSpikeTrain(train, t_start=1 * pq.s,
                              t_stop=1.01 * pq.s,
                              bin_size=1 * pq.ms)
    # Units are taken from the input spike train (seconds), so the raw
    # floats below are expressed in seconds.
    self.assertEqual(bst.units, pq.s)
    self.assertEqual(bst._t_start, 1)  # 1 s
    self.assertEqual(bst._t_stop, 1.01)  # 1.01 s
    self.assertEqual(bst._bin_size, 0.001)  # 0.001 s = 1 ms

    bst.rescale(units='ms')
    # After rescaling, the same quantities are stored in milliseconds.
    self.assertEqual(bst.units, pq.ms)
    self.assertEqual(bst._t_start, 1000)  # 1000 ms = 1 s
    self.assertEqual(bst._t_stop, 1010)  # 1010 ms = 1.01 s
    self.assertEqual(bst._bin_size, 1)  # 1 ms = 0.001 s

def test_repr(self):
    """repr() reports t_start, t_stop and bin_size in the train's units
    (seconds here), plus the (n_trains, n_bins) shape."""
    train = neo.SpikeTrain(times=np.array([1.001, 1.002, 1.005]) * pq.s,
                           t_start=1 * pq.s, t_stop=1.01 * pq.s)
    bst = cv.BinnedSpikeTrain(train, t_start=1 * pq.s,
                              t_stop=1.01 * pq.s,
                              bin_size=1 * pq.ms)
    self.assertEqual(repr(bst), "BinnedSpikeTrain(t_start=1.0 s, "
                                "t_stop=1.01 s, bin_size=0.001 s; "
                                "shape=(1, 10))")

def test_binned_sparsity(self):
train = neo.SpikeTrain(np.arange(10), t_stop=10 * pq.s, units=pq.s)
Expand Down
6 changes: 1 addition & 5 deletions elephant/test/test_spike_train_correlation.py
Original file line number Diff line number Diff line change
Expand Up @@ -792,15 +792,11 @@ def test_timescale_calculation(self):
timescale = 1 / (4 * nu)
np.random.seed(35)

timescale_num = []
for _ in range(10):
spikes = homogeneous_gamma_process(2, 2 * nu, 0 * pq.ms, T)
spikes_bin = conv.BinnedSpikeTrain(spikes, bin_size)
timescale_i = sc.spike_train_timescale(spikes_bin, 10 * timescale)
timescale_i.units = timescale.units
timescale_num.append(timescale_i.magnitude)
assert_array_almost_equal(timescale.magnitude, timescale_num,
decimal=3)
assert_array_almost_equal(timescale, timescale_i, decimal=3)

def test_timescale_errors(self):
spikes = neo.SpikeTrain([1, 5, 7, 8] * pq.ms, t_stop=10 * pq.ms)
Expand Down
15 changes: 11 additions & 4 deletions elephant/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,16 +159,23 @@ def check_neo_consistency(neo_objects, object_type, t_start=None,
"""
if not isinstance(neo_objects, (list, tuple)):
neo_objects = [neo_objects]
try:
units = neo_objects[0].units
t_start0 = neo_objects[0].t_start.item()
t_stop0 = neo_objects[0].t_stop.item()
except AttributeError:
raise TypeError("The input must be a list of {}. Got {}".format(
object_type.__name__, type(neo_objects[0]).__name__))
for neo_obj in neo_objects:
if not isinstance(neo_obj, object_type):
raise TypeError("The input must be a list of {}. Got {}".format(
object_type.__name__, type(neo_obj).__name__))
if t_start is None and not neo_obj.t_start == neo_objects[0].t_start:
if neo_obj.units != units:
raise ValueError("The input must have the same units.")
if t_start is None and neo_obj.t_start.item() != t_start0:
raise ValueError("The input must have the same t_start.")
if t_stop is None and not neo_obj.t_stop == neo_objects[0].t_stop:
if t_stop is None and neo_obj.t_stop.item() != t_stop0:
raise ValueError("The input must have the same t_stop.")
if not neo_obj.units == neo_objects[0].units:
raise ValueError("The input must have the same units.")


def check_same_units(quantities, object_type=pq.Quantity):
Expand Down

0 comments on commit 366cd7c

Please sign in to comment.