Skip to content

Commit

Permalink
Move query functionality to a standalone module (#127)
Browse files Browse the repository at this point in the history
  • Loading branch information
asanin-epfl committed Feb 17, 2021
1 parent 1c0fe49 commit db69493
Show file tree
Hide file tree
Showing 7 changed files with 245 additions and 137 deletions.
1 change: 1 addition & 0 deletions .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,4 @@ ignore-docstrings=yes

[TYPECHECK]
extension-pkg-whitelist=libsonata
ignored-modules=numpy
4 changes: 4 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ New Features
~~~~~~~~~~~~~
- Added NeuronModelsHelper to access nodes neuron models

Improvements
~~~~~~~~~~~~~~
- Moved nodes query mechanism to a separate module

Version v0.9.1
--------------

Expand Down
120 changes: 20 additions & 100 deletions bluepysnap/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,11 @@
from bluepysnap.network import NetworkObject
from bluepysnap import utils
from bluepysnap.exceptions import BluepySnapError
from bluepysnap.sonata_constants import (DYNAMICS_PREFIX, NODE_ID_KEY,
POPULATION_KEY, Node, ConstContainer)
import bluepysnap.query as query
from bluepysnap.sonata_constants import (DYNAMICS_PREFIX, Node, ConstContainer)
from bluepysnap.circuit_ids import CircuitNodeId, CircuitNodeIds
from bluepysnap._doctools import AbstractDocSubstitutionMeta

# this constant is not part of the sonata standard
NODE_SET_KEY = "$node_set"


class Nodes(NetworkObject, metaclass=AbstractDocSubstitutionMeta,
source_word="NetworkObject", target_word="Node"):
Expand Down Expand Up @@ -217,18 +214,6 @@ def load_population_data(self, population):
return result


# TODO: move to `libsonata` library
def _complex_query(prop, query):
# pylint: disable=assignment-from-no-return
result = np.full(len(prop), True)
for key, value in query.items():
if key == '$regex':
result = np.logical_and(result, prop.str.match(value + "\\Z"))
else:
raise BluepySnapError("Unknown query modifier: '%s'" % key)
return result


class NodePopulation:
"""Node population access."""

Expand Down Expand Up @@ -385,89 +370,17 @@ def _get_node_set(self, node_set_name):
raise BluepySnapError("Undefined node set: '%s'" % node_set_name)
return self._node_sets[node_set_name]

def _positional_mask(self, node_ids):
"""Positional mask for the node IDs.
Args:
node_ids (None/numpy.ndarray): the ids array. If None all ids are selected.
Examples:
if the data set contains 5 nodes:
_positional_mask([0,2]) --> [True, False, True, False, False]
"""
if node_ids is None:
return np.full(len(self._data), fill_value=True)
mask = np.full(len(self._data), fill_value=False)
valid_node_ids = pd.Index(utils.ensure_list(node_ids)).intersection(self._data.index)
mask[valid_node_ids] = True
return mask

def _circuit_mask(self, queries):
"""Handle the population, node ID and node set queries."""
populations = queries.pop(POPULATION_KEY, None)
if populations is not None and self.name not in set(utils.ensure_list(populations)):
node_ids = []
else:
node_ids = queries.pop(NODE_ID_KEY, None)
node_set = queries.pop(NODE_SET_KEY, None)
if node_set is not None:
if not isinstance(node_set, str):
raise BluepySnapError("{} is not a valid node set name.".format(node_set))
node_ids = node_ids if node_ids else self._data.index.values
node_ids = np.intersect1d(node_ids, self.ids(node_set))
return queries, self._positional_mask(node_ids)

def _properties_mask(self, queries, raise_missing_prop):
"""Return mask of node IDs with rows matching `props` dict."""
# pylint: disable=assignment-from-no-return
circuit_keys = {POPULATION_KEY, NODE_ID_KEY, NODE_SET_KEY}
unknown_props = set(queries) - set(self._data.columns) - circuit_keys
if unknown_props:
if raise_missing_prop:
raise BluepySnapError(
"Unknown node properties: [{0}]".format(", ".join(unknown_props)))
return np.full(len(self._data), fill_value=False)

queries, mask = self._circuit_mask(queries)
if not mask.any():
# Avoid fail and/or processing time if wrong population or no nodes
return mask

for prop, values in queries.items():
prop = self._data[prop]
if np.issubdtype(prop.dtype.type, np.floating):
v1, v2 = values
prop_mask = np.logical_and(prop >= v1, prop <= v2)
elif isinstance(values, str) and values.startswith('regex:'):
prop_mask = _complex_query(prop, {'$regex': values[6:]})
elif isinstance(values, Mapping):
prop_mask = _complex_query(prop, values)
else:
prop_mask = np.in1d(prop, values)
mask = np.logical_and(mask, prop_mask)
return mask

def _operator_mask(self, queries, raise_missing_prop):
"""Handle the query operators '$or', '$and'."""
if len(queries) == 0:
return np.full(len(self._data), True)

# will pop the population and or/and operators so need to copy
queries = deepcopy(queries)
first_key = list(queries)[0]
if first_key == '$or':
queries = queries.pop("$or")
operator = np.logical_or
elif first_key == '$and':
queries = queries.pop("$and")
operator = np.logical_and
else:
return self._properties_mask(queries, raise_missing_prop)
def _resolve_nodesets(self, queries):
def _resolve(queries, queries_key):
if queries_key == query.NODE_SET_KEY:
if query.AND_KEY not in queries:
queries[query.AND_KEY] = []
queries[query.AND_KEY].append(self._get_node_set(queries[queries_key]))
del queries[queries_key]

mask = np.full(len(self._data), first_key != "$or")
for query in queries:
mask = operator(mask, self._operator_mask(query, raise_missing_prop))
return mask
resolved_queries = deepcopy(queries)
query.traverse_queries_bottom_up(resolved_queries, _resolve)
return resolved_queries

def _node_ids_by_filter(self, queries, raise_missing_prop):
"""Return node IDs if their properties match the `queries` dict.
Expand All @@ -486,7 +399,14 @@ def _node_ids_by_filter(self, queries, raise_missing_prop):
>>> { Node.X: (0, 1), Node.MTYPE: 'L1_SLAC' }]})
"""
return self._data.index[self._operator_mask(queries, raise_missing_prop)].values
queries = self._resolve_nodesets(queries)
if raise_missing_prop:
properties = query.get_properties(queries)
if not properties.issubset(self._data.columns):
unknown_props = properties - set(self._data.columns)
raise BluepySnapError(f"Unknown node properties: {unknown_props}")
idx = query.resolve_ids(self._data, self.name, queries)
return self._data.index[idx].values

def ids(self, group=None, limit=None, sample=None, raise_missing_property=True):
"""Node IDs corresponding to node ``group``.
Expand Down
157 changes: 157 additions & 0 deletions bluepysnap/query.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
"""Module to process search queries of nodes/edges."""
from collections.abc import Mapping
from copy import deepcopy

import numpy as np
import pandas as pd
from bluepysnap.exceptions import BluepySnapError
from bluepysnap import utils

# this constant is not part of the sonata standard
NODE_ID_KEY = "node_id"
EDGE_ID_KEY = "edge_id"
POPULATION_KEY = "population"
OR_KEY = "$or"
AND_KEY = "$and"
REGEX_KEY = "$regex"
NODE_SET_KEY = "$node_set"
VALUE_KEYS = {REGEX_KEY}
ALL_KEYS = {NODE_ID_KEY, EDGE_ID_KEY, POPULATION_KEY, OR_KEY, AND_KEY, NODE_SET_KEY} | VALUE_KEYS


# TODO: move to `libsonata` library
def _complex_query(prop, query):
result = np.full(len(prop), True)
for key, value in query.items():
if key == REGEX_KEY:
result = np.logical_and(result, prop.str.match(value + "\\Z"))
else:
raise BluepySnapError("Unknown query modifier: '%s'" % key)
return result


def _positional_mask(data, ids):
"""Positional mask for the node IDs.
Args:
ids (None/numpy.ndarray): the ids array. If None all ids are selected.
Examples:
if the data set contains 5 nodes:
_positional_mask(data, [0,2]) --> [True, False, True, False, False]
"""
if ids is None:
return np.full(len(data), fill_value=True)
mask = np.full(len(data), fill_value=False)
valid_ids = pd.Index(utils.ensure_list(ids)).intersection(data.index)
mask[valid_ids] = True
return mask


def _circuit_mask(data, population_name, queries):
"""Handle the population, node ID queries."""
populations = queries.pop(POPULATION_KEY, None)
if populations is not None and population_name not in set(utils.ensure_list(populations)):
ids = []
else:
ids = queries.pop(NODE_ID_KEY, queries.pop(EDGE_ID_KEY, None))
return queries, _positional_mask(data, ids)


def _properties_mask(data, population_name, queries):
"""Return mask of IDs matching `props` dict."""
unknown_props = set(queries) - set(data.columns) - ALL_KEYS
if unknown_props:
return np.full(len(data), fill_value=False)

queries, mask = _circuit_mask(data, population_name, queries)
if not mask.any():
# Avoid fail and/or processing time if wrong population or no nodes
return mask

for prop, values in queries.items():
prop = data[prop]
if np.issubdtype(prop.dtype.type, np.floating):
v1, v2 = values
prop_mask = np.logical_and(prop >= v1, prop <= v2)
elif isinstance(values, Mapping):
prop_mask = _complex_query(prop, values)
else:
prop_mask = np.in1d(prop, values)
mask = np.logical_and(mask, prop_mask)
return mask


def traverse_queries_bottom_up(queries, traverse_fn):
"""Traverse queries tree from leaves to root, left to right.
Args:
queries (dict): queries
traverse_fn (function): function to execute on each node of `queries` in traverse order
"""
for key in list(queries.keys()):
if key in {OR_KEY, AND_KEY}:
for subquery in queries[key]:
traverse_queries_bottom_up(subquery, traverse_fn)
elif isinstance(queries[key], Mapping):
if VALUE_KEYS & set(queries[key]):
if not set(queries[key]).issubset(VALUE_KEYS):
raise BluepySnapError("Value operators can't be used with plain values")
else:
traverse_queries_bottom_up(queries[key], traverse_fn)
traverse_fn(queries, key)


def get_properties(queries):
"""Extracts properties names from `queries`.
Args:
queries (dict): queries
Returns:
set: set of properties names
"""

def _collect(_, query_key):
if query_key not in ALL_KEYS:
props.add(query_key)

props = set()
traverse_queries_bottom_up(queries, _collect)
return props


def resolve_ids(data, population_name, queries):
"""Returns an index mask of `data` for given `queries`.
Args:
data (pd.DataFrame): data
population_name (str): population name of `data`
queries (dict): queries
Returns:
np.array: index mask
"""

def _merge_queries_masks(queries):
if len(queries) == 0:
return np.full(len(data), True)
return np.logical_and.reduce(list(queries.values()))

def _collect(queries, queries_key):
# each queries value is replaced with a bit mask of corresponding ids
if queries_key == OR_KEY:
# children are already resolved masks due to traverse order
children_mask = [_merge_queries_masks(query) for query in queries[queries_key]]
queries[queries_key] = np.logical_or.reduce(children_mask)
elif queries_key == AND_KEY:
# children are already resolved masks due to traverse order
children_mask = [_merge_queries_masks(query) for query in queries[queries_key]]
queries[queries_key] = np.logical_and.reduce(children_mask)
else:
queries[queries_key] = _properties_mask(
data, population_name, {queries_key: queries[queries_key]})

queries = deepcopy(queries)
traverse_queries_bottom_up(queries, _collect)
return _merge_queries_masks(queries)
2 changes: 0 additions & 2 deletions bluepysnap/sonata_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@
from bluepysnap.exceptions import BluepySnapError

DYNAMICS_PREFIX = "@dynamics:"
NODE_ID_KEY = "node_id"
POPULATION_KEY = "population"


class ConstContainer:
Expand Down
42 changes: 7 additions & 35 deletions tests/test_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,30 +428,6 @@ def _mock_edge(name, source, target):
assert circuit.nodes['default'].source_in_edges() == {"edge1", "edge3"}
assert circuit.nodes['default'].target_in_edges() == {"edge2"}

def test__positional_mask(self):
npt.assert_array_equal(self.test_obj._positional_mask([1, 2]), [False, True, True])
npt.assert_array_equal(self.test_obj._positional_mask([0, 2]), [True, False, True])

def test__node_population_mask(self):
queries, mask = self.test_obj._circuit_mask({"population": "default",
"other": "val"})
assert queries == {"other": "val"}
npt.assert_array_equal(mask, [True, True, True])

queries, mask = self.test_obj._circuit_mask({"population": "unknown",
"other": "val"})
assert queries == {"other": "val"}
npt.assert_array_equal(mask, [False, False, False])

queries, mask = self.test_obj._circuit_mask({"population": "default",
"node_id": [2], "other": "val"})
assert queries == {"other": "val"}
npt.assert_array_equal(mask, [False, False, True])

queries, mask = self.test_obj._circuit_mask({"other": "val"})
assert queries == {"other": "val"}
npt.assert_array_equal(mask, [True, True, True])

def test_ids(self):
_call = self.test_obj.ids
npt.assert_equal(_call(), [0, 1, 2])
Expand Down Expand Up @@ -589,24 +565,20 @@ def test_node_ids_by_filter_complex_query(self):
# ...not 'startswith'
npt.assert_equal(
[],
test_obj.ids({
Cell.MTYPE: {'$regex': 'L6'}, })
test_obj.ids({Cell.MTYPE: {'$regex': 'L6'}})
)
# ...or 'endswith'
npt.assert_equal(
[],
test_obj.ids({
Cell.MTYPE: {'$regex': 'BP'}, })
)
# tentative support for 'regex:' prefix
npt.assert_equal(
[1, 2],
test_obj.ids({
Cell.MTYPE: 'regex:.*BP', })
test_obj.ids({Cell.MTYPE: {'$regex': 'BP'}})
)
# '$regex' is the only query modifier supported for the moment
with pytest.raises(BluepySnapError):
with pytest.raises(BluepySnapError) as e:
test_obj.ids({Cell.MTYPE: {'err': '.*BP'}}, raise_missing_property=False)
assert 'Unknown query modifier' in e.value.args[0]
with pytest.raises(BluepySnapError) as e:
test_obj.ids({Cell.MTYPE: {'err': '.*BP'}})
assert 'Unknown node properties' in e.value.args[0]

def test_get(self):
_call = self.test_obj.get
Expand Down

0 comments on commit db69493

Please sign in to comment.