Skip to content

Commit

Permalink
Merge pull request #646 from apdavison/issue645
Browse files Browse the repository at this point in the history
Fix for #645
  • Loading branch information
apdavison committed Jul 1, 2019
2 parents df49d69 + 3dad39e commit 1142720
Show file tree
Hide file tree
Showing 4 changed files with 159 additions and 7 deletions.
110 changes: 110 additions & 0 deletions examples/update_spike_source_array.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
"""
A demonstration of the use of callbacks to update the spike times in a SpikeSourceArray.
Usage: update_spike_source_array.py [-h] [--plot-figure] simulator
positional arguments:
simulator neuron, nest, brian or another backend simulator
optional arguments:
-h, --help show this help message and exit
--plot-figure Plot the simulation results to a file.
"""

import numpy as np
from pyNN.utility import get_simulator, normalized_filename, ProgressBar
from pyNN.utility.plotting import Figure, Panel
from pyNN.parameters import Sequence

sim, options = get_simulator(("--plot-figure", "Plot the simulation results to a file.",
{"action": "store_true"}))

rate_increment = 20
interval = 200


class SetRate(object):
"""
A callback which changes the firing rate of a population of spike
sources at a fixed interval.
"""

def __init__(self, population, rate_generator, update_interval=20.0):
assert isinstance(population.celltype, sim.SpikeSourceArray)
self.population = population
self.update_interval = update_interval
self.rate_generator = rate_generator

def __call__(self, t):
try:
rate = next(rate_generator)
if rate > 0:
isi = 1000.0/rate
times = t + np.arange(0, self.update_interval, isi)
# here each neuron fires with the same isi,
# but there is a phase offset between neurons
spike_times = [
Sequence(times + phase * isi)
for phase in self.population.annotations["phase"]
]
else:
spike_times = []
self.population.set(spike_times=spike_times)
except StopIteration:
pass
return t + self.update_interval


class MyProgressBar(object):
"""
A callback which draws a progress bar in the terminal.
"""

def __init__(self, interval, t_stop):
self.interval = interval
self.t_stop = t_stop
self.pb = ProgressBar(width=int(t_stop / interval), char=".")

def __call__(self, t):
self.pb(t / self.t_stop)
return t + self.interval


sim.setup()


# === Create a population of poisson processes ===============================

p = sim.Population(50, sim.SpikeSourceArray())
p.annotate(phase=np.random.uniform(0, 1, size=p.size))
p.record('spikes')


# === Run the simulation, with two callback functions ========================

rate_generator = iter(range(0, 100, rate_increment))
sim.run(1000, callbacks=[MyProgressBar(10.0, 1000.0),
SetRate(p, rate_generator, interval)])


# === Retrieve recorded data, and count the spikes in each interval ==========

data = p.get_data().segments[0]

all_spikes = np.hstack([st.magnitude for st in data.spiketrains])
spike_counts = [((all_spikes >= x) & (all_spikes < x + interval)).sum()
for x in range(0, 1000, interval)]
expected_spike_counts = [p.size * rate * interval / 1000.0
for rate in range(0, 100, rate_increment)]

print("\nActual spike counts: {}".format(spike_counts))
print("Expected mean spike counts: {}".format(expected_spike_counts))

if options.plot_figure:
Figure(
Panel(data.spiketrains, xlabel="Time (ms)", xticks=True, markersize=0.5),
title="Incrementally updated SpikeSourceArrays",
annotations="Simulated with %s" % options.simulator.upper()
).save(normalized_filename("Results", "update_spike_source_array", "png", options.simulator))

sim.end()
22 changes: 19 additions & 3 deletions pyNN/neuron/cells.py
Original file line number Diff line number Diff line change
Expand Up @@ -581,10 +581,12 @@ class VectorSpikeSource(hclass(h.VecStim)):
parameter_names = ('spike_times',)

def __init__(self, spike_times=[]):
self.recording = False
self.spike_times = spike_times
self.source = self
self.source_section = None
self.rec = None
self._recorded_spikes = numpy.array([])

def _set_spike_times(self, spike_times):
# spike_times should be a Sequence object
Expand All @@ -595,18 +597,32 @@ def _set_spike_times(self, spike_times):
if numpy.any(spike_times.value[:-1] > spike_times.value[1:]):
raise errors.InvalidParameterValueError("Spike times given to SpikeSourceArray must be in increasing order")
self.play(self._spike_times)
if self.recording:
self._recorded_spikes = numpy.hstack((self._recorded_spikes, spike_times.value))

def _get_spike_times(self):
return self._spike_times

spike_times = property(fget=_get_spike_times,
fset=_set_spike_times)

@property
def recording(self):
return self._recording

@recording.setter
def recording(self, value):
self._recording = value
if value:
# when we turn recording on, the cell may already have had its spike times assigned
self._recorded_spikes = numpy.hstack((self._recorded_spikes, self.spike_times))

def get_recorded_spike_times(self):
return self._recorded_spikes

def clear_past_spikes(self):
"""If previous recordings are cleared, need to remove spikes from before the current time."""
end = self._spike_times.indwhere(">", h.t)
if end > 0:
self._spike_times.remove(0, end - 1) # range is inclusive
self._recorded_spikes = self._recorded_spikes[self._recorded_spikes > h.t]


class ArtificialCell(object):
Expand Down
7 changes: 6 additions & 1 deletion pyNN/neuron/recording.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ def _record(self, variable, new_ids, sampling_interval=None):
for id in new_ids:
if id._cell.rec is not None:
id._cell.rec.record(id._cell.spike_times)
else: # SpikeSourceArray
id._cell.recording = True
else:
self.sampling_interval = sampling_interval or self._simulator.state.dt
for id in new_ids:
Expand Down Expand Up @@ -97,7 +99,10 @@ def _get_spiketimes(self, id):
if hasattr(id, "__len__"):
all_spiketimes = {}
for cell_id in id:
spikes = numpy.array(cell_id._cell.spike_times)
if cell_id._cell.rec is None: # SpikeSourceArray
spikes = cell_id._cell.get_recorded_spike_times()
else:
spikes = numpy.array(cell_id._cell.spike_times)
all_spiketimes[cell_id] = spikes[spikes <= simulator.state.t + 1e-9]
return all_spiketimes
else:
Expand Down
27 changes: 24 additions & 3 deletions test/system/scenarios/test_cell_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@
have_scipy = True
except ImportError:
have_scipy = False
from numpy.testing import assert_array_equal
import quantities as pq
from nose.tools import assert_greater, assert_less, assert_raises
from pyNN.parameters import Sequence
from pyNN.errors import InvalidParameterValueError

from .registry import register
Expand All @@ -31,7 +33,7 @@ def test_EIF_cond_alpha_isfa_ista(sim, plot_figure=False):
plt.plot(expected_spike_times, -40 * numpy.ones_like(expected_spike_times), "ro")
plt.savefig("test_EIF_cond_alpha_isfa_ista_%s.png" % sim.__name__)
diff = (data.spiketrains[0].rescale(pq.ms).magnitude - expected_spike_times) / expected_spike_times
assert abs(diff).max() < 0.01, abs(diff).max()
assert abs(diff).max() < 0.01, abs(diff).max()
sim.end()
return data
test_EIF_cond_alpha_isfa_ista.__test__ = False
Expand Down Expand Up @@ -262,7 +264,6 @@ def test_SpikeSourcePoissonRefractory(sim, plot_figure=False):
test_SpikeSourcePoissonRefractory.__test__ = False



@register()
def issue511(sim):
"""Giving SpikeSourceArray an array of non-ordered spike times should produce an InvalidParameterValueError error"""
Expand All @@ -271,6 +272,25 @@ def issue511(sim):
assert_raises(InvalidParameterValueError, sim.Population, 2, celltype)


@register()
def test_update_SpikeSourceArray(sim, plot_figure=False):
sim.setup()
sources = sim.Population(2, sim.SpikeSourceArray(spike_times=[]))
sources.record('spikes')
sim.run(10.0)
sources.set(spike_times=[
Sequence([12, 15, 18]),
Sequence([17, 19])
])
sim.run(10.0)
sources.set(spike_times=[
Sequence([22, 25]),
Sequence([23, 27, 29])
])
sim.run(10.0)
data = sources.get_data().segments[0].spiketrains
assert_array_equal(data[0].magnitude, numpy.array([12, 15, 18, 22, 25]))
test_update_SpikeSourceArray.__test__ = False

# todo: add test of Izhikevich model

Expand All @@ -286,4 +306,5 @@ def issue511(sim):
test_SpikeSourcePoisson(sim, plot_figure=args.plot_figure)
test_SpikeSourceGamma(sim, plot_figure=args.plot_figure)
test_SpikeSourcePoissonRefractory(sim, plot_figure=args.plot_figure)
issue511(sim)
issue511(sim)
test_update_SpikeSourceArray(sim, plot_figure=args.plot_figure)

0 comments on commit 1142720

Please sign in to comment.