Skip to content

Commit

Permalink
Add the $target query to nodes (#76)
Browse files Browse the repository at this point in the history
  • Loading branch information
tomdele committed Jul 13, 2020
1 parent addfb54 commit c098e29
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 8 deletions.
19 changes: 15 additions & 4 deletions bluepysnap/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@
POPULATION_KEY, Node, ConstContainer)


# this constant is not part of the sonata standard
NODE_SET_KEY = "$node_set"


class NodeStorage(object):
"""Node storage access."""

Expand Down Expand Up @@ -282,23 +286,30 @@ def _positional_mask(self, node_ids):
mask[valid_node_ids] = True
return mask

def _node_population_mask(self, queries):
"""Handle the population and node ID queries."""
def _circuit_mask(self, queries):
"""Handle the population, node ID and node set queries."""
populations = queries.pop(POPULATION_KEY, None)
if populations is not None and self.name not in set(utils.ensure_list(populations)):
node_ids = []
else:
node_ids = queries.pop(NODE_ID_KEY, None)
node_set = queries.pop(NODE_SET_KEY, None)
if node_set is not None:
if not isinstance(node_set, six.string_types):
raise BluepySnapError("{} is not a valid node set name.".format(node_set))
node_ids = node_ids if node_ids else self._data.index.values
node_ids = np.intersect1d(node_ids, self.ids(node_set))
return queries, self._positional_mask(node_ids)

def _properties_mask(self, queries):
"""Return mask of node IDs with rows matching `props` dict."""
# pylint: disable=assignment-from-no-return
unknown_props = set(queries) - set(self._data.columns) - {POPULATION_KEY, NODE_ID_KEY}
circuit_keys = {POPULATION_KEY, NODE_ID_KEY, NODE_SET_KEY}
unknown_props = set(queries) - set(self._data.columns) - circuit_keys
if unknown_props:
raise BluepySnapError("Unknown node properties: [{0}]".format(", ".join(unknown_props)))

queries, mask = self._node_population_mask(queries)
queries, mask = self._circuit_mask(queries)
if not mask.any():
# Avoid fail and/or processing time if wrong population or no nodes
return mask
Expand Down
26 changes: 22 additions & 4 deletions tests/test_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,22 +159,22 @@ def test__positional_mask(self):
npt.assert_array_equal(self.test_obj._positional_mask([0, 2]), [True, False, True])

def test__node_population_mask(self):
queries, mask = self.test_obj._node_population_mask({"population": "default",
queries, mask = self.test_obj._circuit_mask({"population": "default",
"other": "val"})
assert queries == {"other": "val"}
npt.assert_array_equal(mask, [True, True, True])

queries, mask = self.test_obj._node_population_mask({"population": "unknown",
queries, mask = self.test_obj._circuit_mask({"population": "unknown",
"other": "val"})
assert queries == {"other": "val"}
npt.assert_array_equal(mask, [False, False, False])

queries, mask = self.test_obj._node_population_mask({"population": "default",
queries, mask = self.test_obj._circuit_mask({"population": "default",
"node_id": [2], "other": "val"})
assert queries == {"other": "val"}
npt.assert_array_equal(mask, [False, False, True])

queries, mask = self.test_obj._node_population_mask({"other": "val"})
queries, mask = self.test_obj._circuit_mask({"other": "val"})
assert queries == {"other": "val"}
npt.assert_array_equal(mask, [True, True, True])

Expand All @@ -198,6 +198,8 @@ def test_ids(self):
npt.assert_equal(_call({"node_id": [1]}), [1])
npt.assert_equal(_call({"node_id": [1, 2]}), [1, 2])
npt.assert_equal(_call({"node_id": [1, 2, 42]}), [1, 2])
npt.assert_equal(_call({"node_id": [1], "population": ["default"],
Cell.MORPHOLOGY: "morph-B"}), [1])

# same query with a $and operator
npt.assert_equal(_call({"$and": [{Cell.MTYPE: 'L6_Y'}, {Cell.MORPHOLOGY: "morph-B"}]}), [1])
Expand Down Expand Up @@ -235,6 +237,18 @@ def test_ids(self):
npt.assert_equal(_call('combined_combined_Node0_L6_Y__Node12_L6_Y__'),
[0, 1, 2]) # imbricated '$or' functions

npt.assert_equal(_call({"$node_set": 'Node12_L6_Y', "node_id": 1}), [1])
npt.assert_equal(_call({"$node_set": 'Node12_L6_Y', "node_id": [1, 2, 3]}), [1, 2])
npt.assert_equal(_call({"$node_set": 'Node12_L6_Y', "population": "default"}), [1, 2])
npt.assert_equal(_call({"$node_set": 'Node12_L6_Y', "population": "default", "node_id": 1}),
[1])
npt.assert_equal(_call({"$node_set": 'Node12_L6_Y', Cell.MORPHOLOGY: "morph-B"}),
[1])
npt.assert_equal(_call({"$and": [{"$node_set": 'Node12_L6_Y', "population": "default"},
{Cell.MORPHOLOGY: "morph-B"}]}), [1])
npt.assert_equal(_call({"$or": [{"$node_set": 'Node12_L6_Y', "population": "default"},
{Cell.MORPHOLOGY: "morph-B"}]}), [1, 2])

with pytest.raises(BluepySnapError):
_call('no-such-node-set')
with pytest.raises(BluepySnapError):
Expand All @@ -245,6 +259,10 @@ def test_ids(self):
_call([1, 999]) # one of node IDs out of range
with pytest.raises(BluepySnapError):
_call({'no-such-node-property': 42})
with pytest.raises(BluepySnapError):
_call({"$node_set": [1, 2]})
with pytest.raises(BluepySnapError):
_call({"$node_set": 'no-such-node-set'})

def test_node_ids_by_filter_complex_query(self):
test_obj = create_node_population(str(TEST_DATA_DIR / 'nodes.h5'), "default")
Expand Down

0 comments on commit c098e29

Please sign in to comment.