Skip to content

Commit

Permalink
sonata module: Support for "target_simulator = 'NEST'" and for record…
Browse files Browse the repository at this point in the history
…ing state variables.
  • Loading branch information
apdavison committed Mar 14, 2019
1 parent 31040db commit 28861a5
Show file tree
Hide file tree
Showing 5 changed files with 135 additions and 55 deletions.
26 changes: 19 additions & 7 deletions pyNN/common/populations.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,9 +229,13 @@ def gen(i):
return gen

def _get_cell_initial_value(self, id, variable):
assert isinstance(self.initial_values[variable], LazyArray)
index = self.id_to_local_index(id)
return self.initial_values[variable][index]
if variable in self.initial_values:
assert isinstance(self.initial_values[variable], LazyArray)
index = self.id_to_local_index(id)
return self.initial_values[variable][index]
else:
logger.warning("Variable '{}' is not in initial values, returning 0.0".format(variable))
return 0.0

def _set_cell_initial_value(self, id, variable, value):
assert isinstance(self.initial_values[variable], LazyArray)
Expand Down Expand Up @@ -266,7 +270,7 @@ def get(self, parameter_names, gather=False, simplify=True):
"""
Get the values of the given parameters for every local cell in the
population, or, if gather=True, for all cells in the population.
Values will be expressed in the standard PyNN units (i.e. millivolts,
nanoamps, milliseconds, microsiemens, nanofarads, event per second).
"""
Expand Down Expand Up @@ -418,6 +422,10 @@ def can_record(self, variable):
"""Determine whether `variable` can be recorded from this population."""
return self.celltype.can_record(variable)

@property
def injectable(self):
return self.celltype.injectable

def record(self, variables, to_file=None, sampling_interval=None):
"""
Record the specified variable or variables for all cells in the
Expand All @@ -429,7 +437,7 @@ def record(self, variables, to_file=None, sampling_interval=None):
If specified, `to_file` should be either a filename or a Neo IO instance and `write_data()`
will be automatically called when `end()` is called.
`sampling_interval` should be a value in milliseconds, and an integer
multiple of the simulation timestep.
"""
Expand Down Expand Up @@ -479,7 +487,7 @@ def write_data(self, io, variables='all', gather=True, clear=False, annotations=
simulated on that node.
If `clear` is True, recorded data will be deleted from the `Population`.
`annotations` should be a dict containing simple data types such as
numbers and strings. The contents will be written into the output data
file as metadata.
Expand Down Expand Up @@ -533,7 +541,7 @@ def get_gsyn(self, gather=True, compatible_output=True):
def get_spike_counts(self, gather=True):
"""
Returns a dict containing the number of spikes for each neuron.
The dict keys are neuron IDs, not indices.
"""
# arguably, we should use indices
Expand Down Expand Up @@ -1426,6 +1434,10 @@ def inject(self, current_source):
for p in self.populations:
current_source.inject_into(p)

@property
def injectable(self):
return all(p.injectable for p in self.populations)

def describe(self, template='assembly_default.txt', engine='default'):
"""
Returns a human-readable description of the assembly.
Expand Down
9 changes: 6 additions & 3 deletions pyNN/common/projections.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,12 @@ def __init__(self, presynaptic_neurons, postsynaptic_neurons, connector,
self.pre = presynaptic_neurons # } these really
self.source = source # } should be
self.post = postsynaptic_neurons # } read-only
self.receptor_type = receptor_type or 'excitatory'
# TO FIX: (1) if weights are negative, default should be 'inhibitory'
# (2) default receptor type should depend on post-synaptic cell type
if receptor_type == "default":
receptor_type = None
self.receptor_type = receptor_type or sorted(postsynaptic_neurons.receptor_types)[0]
# TO FIX: if weights are negative, default should be the first inhibitory receptor type,
# not necessarily the first in alphabetical order.
# Should perhaps explicitly specify the default type(s)
if self.receptor_type not in postsynaptic_neurons.receptor_types:
valid_types = postsynaptic_neurons.receptor_types
assert len(valid_types) > 0
Expand Down
4 changes: 3 additions & 1 deletion pyNN/nest/cells.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
UNITS_MAP = {
'spikes': 'ms',
'V_m': 'mV',
'I_syn_ex': 'pA',
'I_syn_in': 'pA'
}


Expand All @@ -28,7 +30,7 @@ def get_defaults(model_name):
'thread', 'vp', 'receptor_types', 'events', 'global_id',
'element_type', 'type', 'type_id', 'has_connections', 'n_synapses',
'thread_local_id', 'node_uses_wfr', 'supports_precise_spikes',
'synaptic_elements', 'y_0', 'y_1']
'synaptic_elements', 'y_0', 'y_1', 'allow_offgrid_spikes', 'shift_now_spikes']
default_params = {}
default_initial_values = {}
for name, value in defaults.items():
Expand Down
5 changes: 3 additions & 2 deletions pyNN/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,10 @@ def get_component(self, label):
return obj
return None

def record(self, variables, to_file=None, sampling_interval=None):
def record(self, variables, to_file=None, sampling_interval=None, include_spike_source=True):
for obj in chain(self.populations, self.assemblies):
obj.record(variables, to_file=to_file, sampling_interval=sampling_interval)
if include_spike_source or obj.injectable: # spike sources are not injectable
obj.record(variables, to_file=to_file, sampling_interval=sampling_interval)

def get_data(self, variables='all', gather=True, clear=False, annotations=None):
return [assembly.get_data(variables, gather, clear, annotations)
Expand Down
146 changes: 104 additions & 42 deletions pyNN/serialization/sonata.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,14 @@ class SonataIO(BaseIO):

def __init__(self, base_dir,
spikes_file="spikes.h5",
spikes_sort_order=None):
spikes_sort_order=None,
report_config=None,
node_sets=None):
self.base_dir = base_dir
self.spike_file = spikes_file
self.spikes_sort_order = spikes_sort_order
self.report_config = report_config
self.node_sets = node_sets

def read(self):
"""
Expand Down Expand Up @@ -100,8 +104,8 @@ def write(self, blocks):
"""
Write a list of Blocks to SONATA HDF5 files.
(Currently only spike data supported).
"""
# Write spikes
spike_file_path = join(self.base_dir, self.spike_file)
spikes_file = h5py.File(spike_file_path, 'w')
spike_trains = []
Expand All @@ -113,13 +117,51 @@ def write(self, blocks):
spikes_group = spikes_file.create_group("spikes")
all_spike_times = np.hstack(st.rescale('ms').magnitude
for st in spike_trains).astype(np.float64)
gids = np.hstack(st.annotations["source_id"] * np.ones(st.shape, dtype=np.uint64)
gids = np.hstack(st.annotations["source_index"] * np.ones(st.shape, dtype=np.uint64)
for st in spike_trains)
# todo: handle sorting
spikes_group.create_dataset("timestamps", data=all_spike_times, dtype=np.float64)
spikes_group.create_dataset("gids", data=gids, dtype=np.uint64)
spikes_file.close()
logger.info("Wrote output to {}".format(spike_file_path))
logger.info("Wrote spike output to {}".format(spike_file_path))

# Write signals
for report_name, report_metadata in self.report_config.items():
file_name = report_metadata.get("file_name", report_name + ".h5")
file_path = join(self.base_dir, file_name)

signal_file = h5py.File(file_path, 'w')
population_name = self.node_sets[report_metadata["cells"]]["population"]
node_ids = self.node_sets[report_metadata["cells"]]["node_id"]
for block in blocks:
if block.name == population_name:
if len(block.segments) > 1:
raise NotImplementedError()
signal = block.segments[0].filter(name=report_metadata["variable_name"])
if len(signal) != 1:
raise NotImplementedError()

report_group = signal_file.create_group("report")
population_group = report_group.create_group(population_name)
dataset = population_group.create_dataset("data", data=signal[0].magnitude)
dataset.attrs["units"] = signal[0].units.dimensionality.string
dataset.attrs["variable_name"] = report_metadata["variable_name"]
n = dataset.shape[1]
mapping_group = population_group.create_group("mapping")
mapping_group.create_dataset("node_ids", data=node_ids)
# "gids" not in the spec, but expected by some bmtk utils
mapping_group.create_dataset("gids", data=node_ids)
#mapping_group.create_dataset("index_pointers", data=np.zeros((n,)))
mapping_group.create_dataset("index_pointer", data=np.arange(0, n+1)) # ??spec unclear
mapping_group.create_dataset("element_ids", data=np.zeros((n,)))
mapping_group.create_dataset("element_pos", data=np.zeros((n,)))
time_ds = mapping_group.create_dataset("time",
data=(float(signal[0].t_start),
float(signal[0].t_stop),
float(signal[0].sampling_period)))
time_ds.attrs["units"] = "ms"
logger.info("Wrote block {} to {}".format(block.name, file_path))
signal_file.close()


MAGIC = 0x0a7a
Expand Down Expand Up @@ -491,12 +533,8 @@ def import_from_sonata(config_file, sim):
"""
config = load_config(config_file)

if config.get("target_simulator", None) != "PyNN":
warn("`target_simulator` is not set to 'PyNN'. Proceeding with caution...")
# could also easily handle target_simulator="NEST" using native models
# NEURON also possible using native models, but a bit more involved
# seems that target_simulator is sometimes in the circuit_config

if config.get("target_simulator", None) not in ("PyNN", "NEST"):
warn("`target_simulator` is not set to 'PyNN' or 'NEST'. Proceeding with caution...")

sonata_node_populations = []
for nodes_config in config["networks"]["nodes"]:
Expand All @@ -519,7 +557,7 @@ def import_from_sonata(config_file, sim):
])

sonata_edge_populations = []

if "edges" in config["networks"]:
for edges_config in config["networks"]["edges"]:

Expand Down Expand Up @@ -688,6 +726,7 @@ def from_data(cls, id, node_types_array, index, h5_data, node_types_map, config)
parameters[key] = dynamics_params_group[key].value[index]

obj.parameters = parameters
obj.config = config
logger.info(parameters)

return obj
Expand All @@ -704,19 +743,27 @@ def get_cell_type(self, sim):
raise NotImplementedError("Only point neurons currently supported.")

if model_type == "virtual":
cell_types.add("SpikeSourceArray")
if self.config.get("target_simulator") == "NEST":
cell_types.add("nest:spike_generator")
else:
cell_types.add("pyNN:SpikeSourceArray")
else:
prefix, cell_type = self.parameters["model_template"][node_type_id].split(":")
if prefix.lower() not in ("pynn", "nrn"):
raise NotImplementedError("Only PyNN and NEURON-native networks currently supported, not: %s (from %s)."% \
(prefix, self.parameters["model_template"][node_type_id]))
cell_types.add(cell_type)
cell_types.add(self.parameters["model_template"][node_type_id])

if len(cell_types) != 1:
raise Exception("Heterogeneous group, not currently supported.")

cell_type_name = cell_types.pop()
cell_type_cls = getattr(sim, cell_type_name)
cell_type = cell_types.pop()
prefix, cell_type_name = cell_type.split(":")
if prefix.lower() not in ("pynn", "nrn", "nest"):
raise NotImplementedError("Only PyNN, NEST and NEURON-native networks currently supported, not: %s (from %s)."% \
(prefix, self.parameters["model_template"][node_type_id]))
if prefix.lower() == "nest":
cell_type_cls = sim.native_cell_type(cell_type_name)
if cell_type_name == "spike_generator":
cell_type_cls.uses_parrot = False
else:
cell_type_cls = getattr(sim, cell_type_name)
logger.info(" cell_type: {}".format(cell_type_cls))
return cell_type_cls

Expand Down Expand Up @@ -876,8 +923,11 @@ def from_data(cls, id, edge_types_array, index, source_ids, target_ids,
parameters[key] = dynamics_params_group[key].value[index]
if 'nsyns' in h5_data:
parameters['nsyns'] = h5_data['nsyns'].value[index]
if 'syn_weight' in h5_data:
parameters['syn_weight'] = h5_data['syn_weight'].value[index]

obj.parameters = parameters
obj.config = config
logger.info(parameters)
return obj

Expand All @@ -893,19 +943,23 @@ def get_synapse_and_receptor_type(self, sim):
model_templates = self.parameters.get("model_template", None)
if model_templates:
for edge_type_id, model_template in model_templates.items():
prefix, syn_type = model_template.split(":")
if prefix.lower() not in ("pynn", "nrn"):
raise NotImplementedError("Only PyNN and NEURON-native networks currently supported.")
synapse_types.add(syn_type)
synapse_types.add(model_template)

if len(synapse_types) != 1:
raise Exception("Heterogeneous group, not currently supported.")

synapse_type_name = synapse_types.pop()
synapse_type = synapse_types.pop()
prefix, synapse_type_name = model_template.split(":")
if prefix.lower() not in ("pynn", "nrn", "nest"):
raise NotImplementedError("Only PyNN, NEST and NEURON-native networks currently supported.")
else:
prefix = "pyNN"
synapse_type_name = "StaticSynapse"

synapse_type_cls = getattr(sim, synapse_type_name)
if prefix == "nest":
synapse_type_cls = sim.native_synapse_type(synapse_type_name)
else:
synapse_type_cls = getattr(sim, synapse_type_name)

receptor_types = self.parameters.get("receptor_type", None)
if receptor_types:
Expand Down Expand Up @@ -1007,23 +1061,35 @@ def __init__(self, run, inputs=None, output=None, reports=None,
self.inputs = {}
if self.reports is None:
self.reports = {}
if self.node_sets_file is not None:
with open(self.node_sets_file) as fp:
self.node_sets = json.load(fp)
# make all node set names lower case, needed by 300 IF neuron example
self.node_sets = {k.lower(): v for k, v in self.node_sets.items()}
# todo: handle compound node sets

def setup(self, sim):
self.sim = sim
sim.setup(timestep=self.run_config["dt"])

def _get_target(self, input_config, node_sets, net):
target = node_sets[input_config["node_set"]]
def _get_target(self, config, node_sets, net):
if "node_set" in config: # input config
target = node_sets[config["node_set"]]
elif "cells" in config: # recording config
# inconsistency in SONATA spec? Why not call this "node_set" also?
target = node_sets[config["cells"]]
if "model_type" in target:
raise NotImplementedError()
if "location" in target:
raise NotImplementedError()
if "node_id" in target:
raise NotImplementedError()
if "gids" in target:
raise NotImplementedError()
if "population" in target:
return net.get_component(target["population"])
assembly = net.get_component(target["population"])
if "node_id" in target:
indices = target["node_id"]
assembly = assembly[indices]
return assembly

def _set_input_spikes(self, input_config, node_sets, net):
# determine which assembly the spikes are for
Expand All @@ -1047,25 +1113,20 @@ def _set_input_spikes(self, input_config, node_sets, net):
# todo: map cell ids in spikes file to ids/index in the population
#logger.info("SETTING SPIKETIMES")
#logger.info(spiketrains)
assembly.set(spike_times=[Sequence(st.times) for st in spiketrains])
assembly.set(spike_times=[Sequence(st.times.rescale('ms').magnitude) for st in spiketrains])

def execute(self, net):
# create/configure inputs
if self.node_sets_file is not None:
with open(self.node_sets_file) as fp:
node_sets = json.load(fp)
# make all node set names lower case, needed by 300 IF neuron example
node_sets = {k.lower(): v for k, v in node_sets.items()}
# todo: handle compound node sets
for input_name, input_config in self.inputs.items():
if input_config["input_type"] != "spikes":
raise NotImplementedError()
self._set_input_spikes(input_config, node_sets, net)
self._set_input_spikes(input_config, self.node_sets, net)

# configure recording
net.record('spikes') # SONATA requires that we record spikes from all nodes
net.record('spikes', include_spike_source=False) # SONATA requires that we record spikes from all non-virtual nodes
for report_name, report_config in self.reports.items():
raise NotImplementedError("Reports not yet implemented")
assembly = self._get_target(report_config, self.node_sets, net)
assembly.record(report_config["variable_name"])

# run simulation
self.sim.run(self.run_config["tstop"])
Expand All @@ -1078,12 +1139,13 @@ def execute(self, net):
os.makedirs(directory)
io = SonataIO(self.output["output_dir"],
spikes_file=self.output.get("spikes_file", "spikes.h5"),
spikes_sort_order=self.output["spikes_sort_order"])
spikes_sort_order=self.output["spikes_sort_order"],
report_config=self.reports,
node_sets=self.node_sets)
# todo: handle reports
net.write_data(io)

@classmethod
def from_config(cls, config):
obj = cls(**config)
return obj

0 comments on commit 28861a5

Please sign in to comment.