Support lazy loading of node attributes (#208)
It should give several benefits:

- reduced loading time when only a few attributes are needed
- reduced memory usage when only a few attributes are needed
- possibility to reuse some methods in services that don't benefit from caching the nodes
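
For illustration, a minimal sketch of the intended effect through the public get() API (a hedged example: the circuit config path and population name are hypothetical):

import bluepysnap

circuit = bluepysnap.Circuit("circuit_config.json")  # hypothetical path
nodes = circuit.nodes["default"]  # hypothetical population name

# Only the "mtype" column is read from storage and cached; the other
# attributes stay on disk until requested.
mtypes = nodes.get(properties="mtype")

# Re-uses the cached "mtype" column and reads only "x" from storage.
subset = nodes.get(properties=["mtype", "x"])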
GianlucaFicarelli committed Jun 5, 2023
1 parent d274c40 commit 998d1d6
Showing 4 changed files with 183 additions and 49 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.rst
@@ -8,6 +8,7 @@ Improvements
~~~~~~~~~~~~
- Clarification for partial circuit configs
- Publish version as ``bluepysnap.__version__``
- Support lazy loading of node attributes.
- Add python 3.11 tests.

Version v1.0.5
164 changes: 118 additions & 46 deletions bluepysnap/nodes/node_population.py
@@ -61,6 +61,7 @@
from collections.abc import Mapping, Sequence
from copy import deepcopy

import libsonata
import numpy as np
import pandas as pd
from cached_property import cached_property
@@ -94,28 +95,97 @@ def _node_sets(self):
return self._circuit.node_sets

@cached_property
def _data(self):
"""Collect data for the node population as a pandas.DataFrame."""
nodes = self._population
categoricals = nodes.enumeration_names

_all = nodes.select_all()
result = pd.DataFrame(index=pd.RangeIndex(_all.flat_size, name="node_ids"))

for attr in sorted(nodes.attribute_names):
if attr in categoricals:
enumeration = np.asarray(nodes.get_enumeration(attr, _all))
values = np.asarray(nodes.enumeration_values(attr))
# if the size of `values` is large enough compared to `enumeration`, not using
# categorical reduces the memory usage.
if values.shape[0] < 0.5 * enumeration.shape[0]:
result[attr] = pd.Categorical.from_codes(enumeration, categories=values)
else:
result[attr] = values[enumeration]
else:
result[attr] = nodes.get_attribute(attr, _all)
for attr in sorted(utils.add_dynamic_prefix(nodes.dynamics_attribute_names)):
result[attr] = nodes.get_dynamics_attribute(attr.split(DYNAMICS_PREFIX)[1], _all)
def _cache(self):
"""Cached DataFrame of nodes, to be accessed through _get_data()."""
return pd.DataFrame(index=pd.RangeIndex(self.size, name="node_ids"))

def _get_libsonata_selection(self, node_ids):
"""Return a libsonata Selection from the given node_ids."""
if node_ids is None:
return libsonata.Selection([(0, self.size)])
return libsonata.Selection(node_ids)
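
As a reminder of the libsonata API used here, a small sketch (toy sizes; behaviour as documented upstream):

import libsonata

# A range selection equivalent to "all nodes" for a population of size 5.
all_nodes = libsonata.Selection([(0, 5)])

# An explicit list of node ids.
some_nodes = libsonata.Selection([0, 2, 3])
assert list(some_nodes.flatten()) == [0, 2, 3]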

def _get_values_from_sonata(self, nodes, attr, node_ids):
"""Return the selected values as np.ndarray or pd.Categorical."""
selection = self._get_libsonata_selection(node_ids)
if attr in nodes.enumeration_names:
enumeration = np.asarray(nodes.get_enumeration(attr, selection))
values = np.asarray(nodes.enumeration_values(attr))
# if the size of `values` is large enough compared to `enumeration`, not using
# categorical reduces the memory usage.
# We compare with nodes.size instead of len(enumeration) to not depend on the selection.
if len(values) < 0.5 * nodes.size:
return pd.Categorical.from_codes(enumeration, categories=values)
return values[enumeration]
if attr in nodes.attribute_names:
return nodes.get_attribute(attr, selection)
if attr.startswith(DYNAMICS_PREFIX):
stripped = attr[len(DYNAMICS_PREFIX) :]
if stripped in nodes.dynamics_attribute_names:
return nodes.get_dynamics_attribute(stripped, selection)
raise BluepySnapError(f"Attribute not found in population {self.name}: {attr}")
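
The categorical branch above trades integer codes plus a small table of unique values against plain materialized arrays. A toy sketch of the two representations (data invented for illustration):

import numpy as np
import pandas as pd

codes = np.array([0, 1, 0, 2, 1, 0])         # per-node enumeration codes
values = np.array(["L2_X", "L6_Y", "L4_Z"])  # unique enumeration values

# Few unique values: keep a categorical, storing codes + one copy of each value.
as_categorical = pd.Categorical.from_codes(codes, categories=values)

# Many unique values: materializing plain arrays can be cheaper overall.
as_array = values[codes]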

def _iter_selected_properties(self, existing, desired):
"""Yield ordered (idx, attr) for each attr in desired, and not in existing.
Called to ensure that the order of the columns of the cached DataFrame doesn't depend
on the order of the retrieved attributes, when _get_data is called multiple times
with different properties.
Args:
existing: existing attributes, that are going to be skipped.
desired: desired attributes, that are going to be yielded in order.
"""
idx = 0
existing = set(existing)
desired = set(desired)
for attr in self._ordered_property_names:
if attr in existing:
idx += 1
continue
if attr not in desired:
continue
yield idx, attr
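
A toy walk-through of how the yielded positions combine with the insertion counter in _get_data (hypothetical attribute names "a", "b", "c"):

import pandas as pd

df = pd.DataFrame({"b": [1, 2]})  # cache already holds column "b"

# For desired={"a", "c"} the generator yields (0, "a") and (1, "c").
for n, (loc, name) in enumerate([(0, "a"), (1, "c")]):
    df.insert(n + loc, name, [0, 0])

assert list(df.columns) == ["a", "b", "c"]  # order preserved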

def _get_data(self, properties=None, node_ids=None):
"""Collect data for the node population as a pandas.DataFrame.
Return a DataFrame with node_ids as index, loading the requested properties if needed.
The returned DataFrame isn't filtered by columns, so it may contain more properties than
requested, if they were loaded previously.
This is done for efficiency, since a copy of the data is not needed:
- in self.get(), the DataFrame is filtered by node_ids first, and by property later
- in self.property_dtypes(), all the properties are needed
- in self._node_ids_by_filter(), the required columns are selected if and when needed
Args:
properties (str|set|list|None): properties to load, or None to load all of them.
node_ids (list|np.ndarray|None): node ids to select.
If None, all the ids are selected, and the cache is read and updated if needed.
If not None, the cache is read and used if possible, but not updated.
"""
result = self._cache
if properties is None:
properties_set = self.property_names
else:
properties_set = set(utils.ensure_list(properties))
self._check_properties(properties_set)
if node_ids is not None:
# Select the ids from the cached dataframe.
# The original dataframe won't be updated in this case.
result = result.loc[node_ids]
cached_columns = properties_set.intersection(result.columns)
if len(cached_columns) < len(properties_set):
# some requested properties are missing from the cache
nodes = self._population
# insert columns at the correct position
for n, (loc, name) in enumerate(
self._iter_selected_properties(existing=result.columns, desired=properties_set)
):
values = self._get_values_from_sonata(nodes=nodes, attr=name, node_ids=node_ids)
result.insert(n + loc, name, values)
return result
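
A hedged sketch of the resulting caching behaviour, assuming `population` is a NodePopulation and the property names exist:

# First call: "mtype" is read from storage and stored in the cache.
df = population._get_data(properties="mtype")

# Second call: "mtype" comes from the cache; only "x" is read from storage.
df = population._get_data(properties={"mtype", "x"})

# With explicit node_ids, cached columns are re-used via .loc, but any newly
# loaded column is not written back to the cache.
sub = population._get_data(properties="y", node_ids=[0, 1])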

@property
@@ -144,6 +214,11 @@ def _property_names(self):
def _dynamics_params_names(self):
return set(utils.add_dynamic_prefix(self._population.dynamics_attribute_names))

@cached_property
def _ordered_property_names(self):
"""Similar to self.property_names, but as an ordered list."""
return sorted(self._property_names) + sorted(self._dynamics_params_names)

def source_in_edges(self):
"""Set of edge population names that use this node population as source.
@@ -236,11 +311,12 @@ def property_dtypes(self):
Returns:
pandas.Series: series indexed by field name with the corresponding dtype as value.
"""
return self._data.dtypes.sort_index()
# read all the properties, without loading any node id
return self._get_data(properties=None, node_ids=[]).dtypes.sort_index()
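
The empty id list makes this cheap: each column is created with zero rows but its real dtype, so essentially no node data is read. A sketch, assuming property_dtypes is exposed as a property as elsewhere in this module:

dtypes = population.property_dtypes
# a pandas.Series mapping each property name to its dtype, sorted by name
assert "mtype" in dtypes.index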

def _check_id(self, node_id):
"""Check that single node ID belongs to the circuit."""
if node_id < 0 or node_id >= len(self._data.index):
if node_id < 0 or node_id >= self.size:
raise BluepySnapError(f"node ID not found: {node_id} in population '{self.name}'")

def _check_ids(self, node_ids):
@@ -254,16 +330,16 @@ def _check_ids(self, node_ids):
else:
max_id = max(node_ids)
min_id = min(node_ids)
if min_id < 0 or max_id >= len(self._data.index):
if min_id < 0 or max_id >= self.size:
raise BluepySnapError(
f"All node IDs must be >= 0 and < {len(self._data.index)} "
f"for population '{self.name}'"
f"All node IDs must be >= 0 and < {self.size} " f"for population '{self.name}'"
)

def _check_property(self, prop):
"""Check if a property exists inside the dataset."""
if prop not in self.property_names:
raise BluepySnapError(f"No such property: '{prop}'")
def _check_properties(self, properties):
"""Check if the properties exist inside the dataset."""
unknown_props = properties - self.property_names
if unknown_props:
raise BluepySnapError(f"Unknown node properties: {sorted(unknown_props)}")

def _get_node_set(self, node_set_name):
"""Returns the node set named 'node_set_name'."""
@@ -301,13 +377,13 @@ def _node_ids_by_filter(self, queries, raise_missing_prop):
"""
queries = self._resolve_nodesets(queries)
properties = query.get_properties(queries)
if raise_missing_prop:
properties = query.get_properties(queries)
if not properties.issubset(self._data.columns):
unknown_props = properties - set(self._data.columns)
raise BluepySnapError(f"Unknown node properties: {unknown_props}")
idx = query.resolve_ids(self._data, self.name, queries)
return self._data.index[idx].values
self._check_properties(properties)
# load all the properties needed to execute the query, excluding the unknown properties
data = self._get_data(properties & self.property_names)
idx = query.resolve_ids(data, self.name, queries)
return idx.nonzero()[0]
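
As used here, query.resolve_ids yields a boolean mask aligned with the loaded rows, and nonzero()[0] turns it into the matching ids. A minimal illustration:

import numpy as np

mask = np.array([False, True, False, True])  # nodes matching the query
ids = mask.nonzero()[0]
assert list(ids) == [1, 3]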

def ids(self, group=None, limit=None, sample=None, raise_missing_property=True):
"""Node IDs corresponding to node ``group``.
@@ -337,7 +413,7 @@ def ids(self, group=None, limit=None, sample=None, raise_missing_property=True):
group = group.filter_population(self.name).get_ids()

if group is None:
result = self._data.index.values
result = np.arange(self.size)
elif isinstance(group, Mapping):
result = self._node_ids_by_filter(
queries=group, raise_missing_prop=raise_missing_property
@@ -379,7 +455,7 @@ def get(self, group=None, properties=None):
group (int/CircuitNodeId/CircuitNodeIds/sequence/str/mapping/None):
see :ref:`Group Concept`
properties (list|str|None): If specified, return only the properties in the list.
Otherwise return all properties.
Otherwise, return all the properties.
Returns:
value/pandas.Series/pandas.DataFrame:
@@ -441,21 +517,17 @@ def get(self, group=None, properties=None):
>>> type(result), result.shape
(pandas.core.frame.DataFrame, (1, 1))
"""
result = self._data
if group is not None:
if isinstance(group, (int, np.integer)):
self._check_id(group)
elif isinstance(group, CircuitNodeId):
group = self.ids(group)[0]
else:
group = self.ids(group)
result = result.loc[group]

if properties is not None:
for p in utils.ensure_list(properties):
self._check_property(p)
result = result[properties]

result = self._get_data(properties=properties)
result = result.loc[group] if group is not None else result
result = result[properties] if properties is not None else result
return result

def positions(self, group=None):
59 changes: 56 additions & 3 deletions tests/test_nodes/test_node_population.py
@@ -1,8 +1,10 @@
import itertools
import json
import pickle
import sys
from unittest import mock

import libsonata
import numpy as np
import numpy.testing as npt
import pandas as pd
@@ -19,7 +21,7 @@
from bluepysnap.sonata_constants import DEFAULT_NODE_TYPE, Node
from bluepysnap.utils import IDS_DTYPE

from utils import TEST_DATA_DIR, create_node_population
from utils import TEST_DATA_DIR, assert_array_equal_strict, create_node_population


class TestNodePopulation:
@@ -303,8 +305,8 @@ def test_node_ids_by_filter_complex_query(self):
Cell.MTYPE: ["L23_MC", "L4_BP", "L6_BP", "L6_BPC"],
}
)
# replace the data using the __dict__ directly
test_obj.__dict__["_data"] = data
# populate the cached nodes
test_obj.__dict__["_cache"] = data

# only full match is accepted
npt.assert_equal(
@@ -640,6 +642,57 @@ def test_pickle(self, tmp_path):
assert pickle_path.stat().st_size < 210
assert test_obj.size == 3

def test_filter_properties(self):
assert self.test_obj._ordered_property_names == [
"layer",
"model_template",
"model_type",
"morphology",
"mtype",
"rotation_angle_xaxis",
"rotation_angle_yaxis",
"rotation_angle_zaxis",
"x",
"y",
"z",
"@dynamics:holding_current",
]
existing = ["morphology", "mtype", "y"]
desired = {"@dynamics:holding_current", "z", "x", "y", "model_type", "layer", "mtype"}

result = self.test_obj._iter_selected_properties(existing=existing, desired=desired)

expected = [
(0, "layer"),
(0, "model_type"),
(2, "x"),
(3, "z"),
(3, "@dynamics:holding_current"),
]
for actual_item, expected_item in itertools.zip_longest(result, expected):
assert actual_item == expected_item

def test_get_values_from_sonata(self):
nodes = self.test_obj._population

# valid attributes
result = self.test_obj._get_values_from_sonata(nodes, "mtype", [0, 1])
assert_array_equal_strict(result, np.array(["L2_X", "L6_Y"], dtype=object))

# dynamics attribute
result = self.test_obj._get_values_from_sonata(nodes, "@dynamics:holding_current", [2])
assert_array_equal_strict(result, np.array([0.3], dtype=float))

# empty selection
result = self.test_obj._get_values_from_sonata(nodes, "x", [])
assert_array_equal_strict(result, np.array([], dtype=float))

# unknown attribute
with pytest.raises(
BluepySnapError, match="Attribute not found in population default: unknown"
):
self.test_obj._get_values_from_sonata(nodes, "unknown", [2])


class TestNodePopulationSpatialIndex:
def setup_method(self):
8 changes: 8 additions & 0 deletions tests/utils.py
@@ -8,6 +8,7 @@
from pathlib import Path

import libsonata
import numpy.testing as npt
import pytest

from bluepysnap.circuit import Circuit
@@ -111,3 +112,10 @@ def create_node_population(filepath, pop_name, circuit=None, node_sets=None, pop
node_pop = NodePopulation(circuit, pop_name)
circuit.nodes = Nodes(circuit)
return node_pop


def assert_array_equal_strict(x, y):
# With numpy >= 1.22.4 it would be possible to specify strict=True.
# The strict parameter ensures that the array data types match.
npt.assert_array_equal(x, y)
assert x.dtype == y.dtype
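
For instance, the strict helper catches dtype regressions that a plain assert_array_equal would let through:

import numpy as np
import numpy.testing as npt

a = np.array([1, 2], dtype=np.int64)
b = np.array([1.0, 2.0], dtype=np.float64)

npt.assert_array_equal(a, b)  # passes: values compare equal
# assert_array_equal_strict(a, b) would fail: int64 != float64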
