Skip to content

Commit

Permalink
Corrected improper use of .base on quantities objects (NeuralEnsemble…
Browse files Browse the repository at this point in the history
…#211)

* Corrected impropoer use of .base on quantities objects

* Fixed situation where magnitude would have been called on a numpy array

* Added test to check for too high threshold, and fixed unit tests alltogether
  • Loading branch information
mdenker committed Apr 11, 2019
1 parent f79ce3b commit f2f332f
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 27 deletions.
22 changes: 11 additions & 11 deletions elephant/spike_train_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,19 +152,19 @@ def threshold_detection(signal, threshold=0.0 * mV, sign='above'):
cutout = np.where(signal < threshold)[0]

if len(cutout) <= 0:
events = np.zeros(0)
events_base = np.zeros(0)
else:
take = np.where(np.diff(cutout) > 1)[0] + 1
take = np.append(0, take)

time = signal.times
events = time[cutout][take]

events_base = events.base
if events_base is None:
# This occurs in some Python 3 builds due to some
# bug in quantities.
events_base = np.array([event.base for event in events]) # Workaround
events_base = events.magnitude
if events_base is None:
# This occurs in some Python 3 builds due to some
# bug in quantities.
events_base = np.array([event.magnitude for event in events]) # Workaround

result_st = SpikeTrain(events_base, units=signal.times.units,
t_start=signal.t_start, t_stop=signal.t_stop)
Expand Down Expand Up @@ -236,12 +236,12 @@ def peak_detection(signal, threshold=0.0 * mV, sign='above', format=None):
max_idc = maxima_idc_split + true_borders[0::2]

events = signal.times[max_idc]
events_base = events.base
events_base = events.magnitude

if events_base is None:
# This occurs in some Python 3 builds due to some
# bug in quantities.
events_base = np.array([event.base for event in events]) # Workaround
if events_base is None:
# This occurs in some Python 3 builds due to some
# bug in quantities.
events_base = np.array([event.magnitude for event in events]) # Workaround
if format is None:
result_st = SpikeTrain(events_base, units=signal.times.units,
t_start=signal.t_start, t_stop=signal.t_stop)
Expand Down
35 changes: 19 additions & 16 deletions elephant/test/test_spike_train_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,22 +31,23 @@ def pdiff(a, b):
class AnalogSignalThresholdDetectionTestCase(unittest.TestCase):

def setUp(self):
pass

def test_threshold_detection(self):
# Test whether spikes are extracted at the correct times from
# an analog signal.

# Load membrane potential simulated using Brian2
# according to make_spike_extraction_test_data.py.
curr_dir = os.path.dirname(os.path.realpath(__file__))
raw_data_file_loc = os.path.join(curr_dir,'spike_extraction_test_data.txt')
raw_data_file_loc = os.path.join(curr_dir, 'spike_extraction_test_data.txt')
raw_data = []
with open(raw_data_file_loc, 'r') as f:
for x in (f.readlines()):
raw_data.append(float(x))
vm = neo.AnalogSignal(raw_data, units=V, sampling_period=0.1*ms)
spike_train = stgen.threshold_detection(vm)
self.vm = neo.AnalogSignal(raw_data, units=V, sampling_period=0.1*ms)
self.true_time_stamps = [0.0123, 0.0354, 0.0712, 0.1191, 0.1694,
0.2200, 0.2711] * second

def test_threshold_detection(self):
# Test whether spikes are extracted at the correct times from
# an analog signal.

spike_train = stgen.threshold_detection(self.vm)
try:
len(spike_train)
except TypeError: # Handles an error in Neo related to some zero length
Expand All @@ -58,17 +59,18 @@ def test_threshold_detection(self):
t_stop=spike_train.t_stop,
units=spike_train.units)

# Correct values determined previously.
true_spike_train = [0.0123, 0.0354, 0.0712, 0.1191,
0.1694, 0.22, 0.2711]

# Does threshold_detection gives the correct number of spikes?
self.assertEqual(len(spike_train),len(true_spike_train))
self.assertEqual(len(spike_train),len(self.true_time_stamps))
# Does threshold_detection gives the correct times for the spikes?
try:
assert_array_almost_equal(spike_train,spike_train)
assert_array_almost_equal(spike_train, self.true_time_stamps)
except AttributeError: # If numpy version too old to have allclose
self.assertTrue(np.array_equal(spike_train,spike_train))
self.assertTrue(np.array_equal(spike_train, self.true_time_stamps))

def test_peak_detection_threshold(self):
# Test for empty SpikeTrain when threshold is too high
result = stgen.threshold_detection(self.vm, threshold=30 * mV)
self.assertEqual(len(result), 0)


class AnalogSignalPeakDetectionTestCase(unittest.TestCase):
Expand Down Expand Up @@ -100,6 +102,7 @@ def test_peak_detection_threshold(self):
result = stgen.peak_detection(self.vm, threshold=30 * mV)
self.assertEqual(len(result), 0)


class AnalogSignalSpikeExtractionTestCase(unittest.TestCase):

def setUp(self):
Expand Down

0 comments on commit f2f332f

Please sign in to comment.