Skip to content

Commit

Permalink
fixed warnings in time_histogram function (NeuralEnsemble#380)
Browse files Browse the repository at this point in the history
* added tolerance in check_neo_consistency
  • Loading branch information
dizcza committed Nov 6, 2020
1 parent 13be6fb commit a6a0854
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 32 deletions.
3 changes: 2 additions & 1 deletion elephant/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,7 +452,8 @@ def check_consistency():
check_neo_consistency(spiketrains,
object_type=neo.SpikeTrain,
t_start=self._t_start,
t_stop=self._t_stop)
t_stop=self._t_stop,
tolerance=tolerance)
except ValueError as er:
# different t_start/t_stop
raise ValueError(er, "If you want to bin over the shared "
Expand Down
29 changes: 3 additions & 26 deletions elephant/statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -895,32 +895,8 @@ def time_histogram(spiketrains, bin_size, t_start=None, t_stop=None,
elephant.conversion.BinnedSpikeTrain
"""
min_tstop = 0
if t_start is None:
# Find the internal range for t_start, where all spike trains are
# defined; cut all spike trains taking that time range only
max_tstart, min_tstop = get_common_start_stop_times(spiketrains)
t_start = max_tstart
if not all([max_tstart == t.t_start for t in spiketrains]):
warnings.warn(
"Spiketrains have different t_start values -- "
"using maximum t_start as t_start.")

if t_stop is None:
# Find the internal range for t_stop
if not min_tstop:
min_tstop = get_common_start_stop_times(spiketrains)[1]
t_stop = min_tstop
if not all([min_tstop == t.t_stop for t in spiketrains]):
warnings.warn(
"Spiketrains have different t_stop values -- "
"using minimum t_stop as t_stop.")

sts_cut = [st.time_slice(t_start=t_start, t_stop=t_stop) for st in
spiketrains]

# Bin the spike trains and sum across columns
bs = BinnedSpikeTrain(sts_cut, t_start=t_start, t_stop=t_stop,
bs = BinnedSpikeTrain(spiketrains, t_start=t_start, t_stop=t_stop,
bin_size=bin_size)

if binary:
Expand All @@ -944,7 +920,8 @@ def time_histogram(spiketrains, bin_size, t_start=None, t_stop=None,

return neo.AnalogSignal(signal=np.expand_dims(bin_hist, axis=1),
sampling_period=bin_size, units=bin_hist.units,
t_start=t_start, normalization=output, copy=False)
t_start=bs.t_start, normalization=output,
copy=False)


@deprecated_alias(binsize='bin_size')
Expand Down
16 changes: 11 additions & 5 deletions elephant/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def get_common_start_stop_times(neo_objects):


def check_neo_consistency(neo_objects, object_type, t_start=None,
t_stop=None):
t_stop=None, tolerance=1e-6):
"""
Checks that all input neo objects share the same units, t_start, and
t_stop.
Expand All @@ -149,6 +149,10 @@ def check_neo_consistency(neo_objects, object_type, t_start=None,
The common type.
t_start, t_stop : pq.Quantity or None, optional
If None, check for exact match of t_start/t_stop across the input.
tolerance : float, optional
The absolute affordable tolerance for the discrepancies between
t_start/stop magnitude values across trials.
Default : 1e-6
Raises
------
Expand All @@ -161,20 +165,22 @@ def check_neo_consistency(neo_objects, object_type, t_start=None,
neo_objects = [neo_objects]
try:
units = neo_objects[0].units
t_start0 = neo_objects[0].t_start.item()
t_stop0 = neo_objects[0].t_stop.item()
start = neo_objects[0].t_start.item()
stop = neo_objects[0].t_stop.item()
except AttributeError:
raise TypeError("The input must be a list of {}. Got {}".format(
object_type.__name__, type(neo_objects[0]).__name__))
if tolerance is None:
tolerance = 0
for neo_obj in neo_objects:
if not isinstance(neo_obj, object_type):
raise TypeError("The input must be a list of {}. Got {}".format(
object_type.__name__, type(neo_obj).__name__))
if neo_obj.units != units:
raise ValueError("The input must have the same units.")
if t_start is None and neo_obj.t_start.item() != t_start0:
if t_start is None and abs(neo_obj.t_start.item() - start) > tolerance:
raise ValueError("The input must have the same t_start.")
if t_stop is None and neo_obj.t_stop.item() != t_stop0:
if t_stop is None and abs(neo_obj.t_stop.item() - stop) > tolerance:
raise ValueError("The input must have the same t_stop.")


Expand Down

0 comments on commit a6a0854

Please sign in to comment.