Skip to content

Commit

Permalink
Homogeneous interface for sample and limit (#115)
Browse files Browse the repository at this point in the history
* Add limit/sample to Nodes
* Add limit/sample to Edges
* Add limit/sample to EdgePopulations
* Factorize limit/sample inside _get_ids_from_pop
* Remove the # pylint: disable=arguments-differ for ids
  functions in Nodes/Edges
  • Loading branch information
tomdele committed Jan 5, 2021
1 parent 86414fb commit 12831a1
Show file tree
Hide file tree
Showing 6 changed files with 90 additions and 20 deletions.
51 changes: 37 additions & 14 deletions bluepysnap/edges.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,18 +44,25 @@ def __init__(self, circuit): # pylint: disable=useless-super-delegation
def _collect_populations(self):
return self._get_populations(EdgeStorage, self._config['networks']['edges'])

def ids(self, edge_ids): # pylint: disable=arguments-differ
def ids(self, group=None, sample=None, limit=None):
"""Edge CircuitEdgeIds corresponding to edges ``group``.
Args:
edge_ids (int/CircuitEdgeId/CircuitEdgeIds/sequence): Which IDs will be
group (None/int/CircuitEdgeId/CircuitEdgeIds/sequence): Which IDs will be
returned depends on the type of the ``group`` argument:
- ``None``: return all CircuitEdgeIds.
- ``CircuitEdgeId``: return the ID in a CircuitEdgeIds object.
- ``CircuitEdgeIds``: return the IDs in a CircuitEdgeIds object.
- ``int``: returns a CircuitEdgeIds object containing the corresponding edge ID
for all populations.
- ``sequence``: returns a CircuitEdgeIds object containing the corresponding edge
IDs for all populations.
sample (int): If specified, randomly choose ``sample`` number of
IDs from the match result. If the size of the sample is greater than
the size of all the EdgePopulations then all ids are taken and shuffled.
limit (int): If specified, return the first ``limit`` number of
IDs from the match result. If limit is greater than the size of all the populations,
all edge IDs are returned.
Returns:
CircuitEdgeIds: returns a CircuitEdgeIds containing all the edge IDs and the
Expand All @@ -65,11 +72,12 @@ def ids(self, edge_ids): # pylint: disable=arguments-differ
Notes:
This also anticipates the possible future selection of edges through queries.
"""
if isinstance(edge_ids, CircuitEdgeIds):
diff = np.setdiff1d(edge_ids.get_populations(unique=True), self.population_names)
if isinstance(group, CircuitEdgeIds):
diff = np.setdiff1d(group.get_populations(unique=True), self.population_names)
if diff.size != 0:
raise BluepySnapError("Population {} does not exist in the circuit.".format(diff))
return self._get_ids_from_pop(lambda x: (x.ids(edge_ids), x.name), CircuitEdgeIds)
fun = lambda x: (x.ids(group), x.name)
return self._get_ids_from_pop(fun, CircuitEdgeIds, sample=sample, limit=limit)

def get(self, edge_ids=None, properties=None): # pylint: disable=arguments-differ
"""Edge properties as pandas DataFrame.
Expand Down Expand Up @@ -472,33 +480,48 @@ def _get(self, selection, properties=None):

return result

def ids(self, edge_ids):
def ids(self, group=None, limit=None, sample=None):
"""Edge IDs corresponding to edges ``group``.
Args:
edge_ids (int/CircuitEdgeId/CircuitEdgeIds/sequence): Which IDs will be
group (None/int/CircuitEdgeId/CircuitEdgeIds/sequence): Which IDs will be
returned depends on the type of the ``group`` argument:
- ``None``: return all IDs.
- ``int``, ``CircuitEdgeId``: return a single edge ID.
- ``CircuitEdgeIds`` return IDs of edges in an array.
- ``CircuitEdgeIds``: return IDs of edges from the edge population in an array.
- ``sequence``: return IDs of edges in an array.
sample (int): If specified, randomly choose ``sample`` number of
IDs from the match result. If the size of the sample is greater than
the size of the EdgePopulation then all ids are taken and shuffled.
limit (int): If specified, return the first ``limit`` number of
IDs from the match result. If limit is greater than the size of the population,
all edge IDs are returned.
Returns:
numpy.array: A numpy array of IDs.
"""
if isinstance(edge_ids, CircuitEdgeIds):
result = edge_ids.filter_population(self.name).get_ids()
elif isinstance(edge_ids, np.ndarray):
result = edge_ids
if group is None:
result = self._population.select_all().flatten()
elif isinstance(group, CircuitEdgeIds):
result = group.filter_population(self.name).get_ids()
elif isinstance(group, np.ndarray):
result = group
else:
result = utils.ensure_list(edge_ids)
result = utils.ensure_list(group)
# test if first value is a CircuitEdgeId if yes then all values must be CircuitEdgeId
if isinstance(first(result, None), CircuitEdgeId):
try:
result = [cid.id for cid in result if cid.population == self.name]
except AttributeError:
raise BluepySnapError("All values from a list must be of type int or "
"CircuitEdgeId.")
if sample is not None:
if len(result) > 0:
result = np.random.choice(result, min(sample, len(result)), replace=False)
if limit is not None:
result = result[:limit]
return np.asarray(result)

def get(self, edge_ids, properties):
Expand Down
17 changes: 14 additions & 3 deletions bluepysnap/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,13 +120,19 @@ def property_names(self):
"""Returns all the NetworkObject properties present inside the circuit."""
return set(prop for pop in self.values() for prop in pop.property_names)

def _get_ids_from_pop(self, fun_to_apply, returned_ids_cls):
def _get_ids_from_pop(self, fun_to_apply, returned_ids_cls, sample=None, limit=None):
"""Get CircuitIds of class 'returned_ids_cls' for all populations using 'fun_to_apply'.
Args:
fun_to_apply (function): A function that returns the list of IDs for each population
and the population containing these IDs.
returned_ids_cls (CircuitNodeIds/CircuitEdgeIds): the class for the CircuitIds.
sample (int): If specified, randomly choose ``sample`` number of
IDs from the match result. If the size of the sample is greater than
the size of all the NetworkObjectPopulations then all ids are taken and shuffled.
limit (int): If specified, return the first ``limit`` number of
IDs from the match result. If limit is greater than the size of all the populations
then all IDs are returned.
Returns:
CircuitNodeIds/CircuitEdgeIds: containing the IDs and the populations.
Expand All @@ -141,10 +147,15 @@ def _get_ids_from_pop(self, fun_to_apply, returned_ids_cls):
populations.append(pops)
ids = np.concatenate(ids).astype(np.int64)
populations = np.concatenate(populations).astype(str_type)
return returned_ids_cls.from_arrays(populations, ids)
res = returned_ids_cls.from_arrays(populations, ids)
if sample:
res.sample(sample, inplace=True)
if limit:
res.limit(limit, inplace=True)
return res

@abc.abstractmethod
def ids(self, *args, **kwargs):
def ids(self, group=None, sample=None, limit=None):
"""Resolves the ids of the NetworkObject."""

@abc.abstractmethod
Expand Down
12 changes: 9 additions & 3 deletions bluepysnap/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def property_values(self, prop):
return set(value for pop in self.values() if prop in pop.property_names for value in
pop.property_values(prop))

def ids(self, group=None): # pylint: disable=arguments-differ
def ids(self, group=None, sample=None, limit=None):
"""Returns the CircuitNodeIds corresponding to the nodes from ``group``.
Args:
Expand All @@ -75,6 +75,12 @@ def ids(self, group=None): # pylint: disable=arguments-differ
- ``mapping``: Returns a CircuitNodeIds object containing nodes matching a
properties filter.
- ``None``: return all node IDs of the circuit in a CircuitNodeIds object.
sample (int): If specified, randomly choose ``sample`` number of
IDs from the match result. If the size of the sample is greater than
the size of all the NodePopulations then all ids are taken and shuffled.
limit (int): If specified, return the first ``limit`` number of
IDs from the match result. If limit is greater than the size of all the populations,
all node IDs are returned.
Returns:
CircuitNodeIds: returns a CircuitNodeIds containing all the node IDs and the
Expand Down Expand Up @@ -112,7 +118,7 @@ def ids(self, group=None): # pylint: disable=arguments-differ
raise BluepySnapError("Population {} does not exist in the circuit.".format(diff))

fun = lambda x: (x.ids(group, raise_missing_property=False), x.name)
return self._get_ids_from_pop(fun, CircuitNodeIds)
return self._get_ids_from_pop(fun, CircuitNodeIds, sample=sample, limit=limit)

def get(self, group=None, properties=None): # pylint: disable=arguments-differ
"""Node properties as a pandas DataFrame.
Expand Down Expand Up @@ -490,7 +496,7 @@ def ids(self, group=None, limit=None, sample=None, raise_missing_property=True):
returned depends on the type of the ``group`` argument:
- ``int``, ``CircuitNodeId``: return a single node ID if it belongs to the circuit.
- ``CircuitNodeIds`` return IDs of nodes in an array.
- ``CircuitNodeIds`` return IDs of nodes from the node population in an array.
- ``sequence``: return IDs of nodes in an array.
- ``str``: return IDs of nodes in a node set.
- ``mapping``: return IDs of nodes matching a properties filter.
Expand Down
3 changes: 3 additions & 0 deletions tests/test_circuit_ids.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,16 +223,19 @@ def test_sample(self):
assert len(test_obj) == 4
test_obj.sample(1, inplace=True)
assert len(test_obj) == 1
assert len(self.ids_cls(pd.MultiIndex.from_arrays([[], []])).sample(2)) == 0

def test_limit(self):
tested = self.test_obj_sorted.limit(2, inplace=False)
assert len(tested) == 2
assert tested == self.ids_cls(self._circuit_ids(['a', 'a'], [0, 1]))
assert len(self.ids_cls(pd.MultiIndex.from_arrays([[], []])).limit(2)) == 0

def test_unique(self):
tested = self.ids_cls.from_dict({"a": [0, 0, 1], "b": [1, 2, 2]}).unique()
expected = self.ids_cls.from_dict({"a": [0, 1], "b": [1, 2]})
assert tested == expected
assert len(self.ids_cls(pd.MultiIndex.from_arrays([[], []])).unique()) == 0

def test_tolist(self):
expected = [('a', 0), ('a', 1), ('b', 0), ('a', 2)]
Expand Down
17 changes: 17 additions & 0 deletions tests/test_edges.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ def test_property_dtypes_fail(self):
test_obj.property_dtypes.sort_index()

def test_ids(self):
np.random.seed(42)
# single edge ID --> CircuitEdgeIds return populations with the 0 id
expected = CircuitEdgeIds.from_tuples([("default", 0), ("default2", 0)])
assert self.test_obj.ids(0) == expected
Expand Down Expand Up @@ -173,6 +174,14 @@ def test_ids(self):
expected = CircuitEdgeIds.from_arrays(["default2", "default2"], [0, 1])
assert ids.filter_population("default2").limit(2) == expected

tested = self.test_obj.ids(sample=2)
expected = CircuitEdgeIds.from_arrays(["default2", "default"], [2, 3], sort_index=False)
assert tested == expected

tested = self.test_obj.ids(limit=5)
expected = CircuitEdgeIds.from_dict({"default": [0, 1, 2, 3], "default2": [0]})
assert tested == expected

def test_get(self):
with pytest.raises(BluepySnapError):
self.test_obj.get(properties=["other2", "unknown"])
Expand Down Expand Up @@ -630,6 +639,7 @@ def test_property_dtypes(self):
pdt.assert_series_equal(expected, self.test_obj.property_dtypes)

def test_ids(self):
npt.assert_equal(self.test_obj.ids(), np.array([0, 1, 2, 3]))
assert self.test_obj.ids(0) == [0]
npt.assert_equal(self.test_obj.ids([0, 1]), np.array([0, 1]))
npt.assert_equal(self.test_obj.ids(np.array([0, 1])), np.array([0, 1]))
Expand All @@ -641,6 +651,13 @@ def test_ids(self):
npt.assert_equal(self.test_obj.ids(ids), np.array([0]))
ids = CircuitEdgeIds.from_tuples([("default2", 0), ("default2", 1)])
npt.assert_equal(self.test_obj.ids(ids), [])
npt.assert_equal(self.test_obj.ids(), np.array([0, 1, 2, 3]))

# limit too big compared to the number of ids
npt.assert_equal(self.test_obj.ids(limit=15), [0, 1, 2, 3])
npt.assert_equal(len(self.test_obj.ids(sample=2)), 2)
# if sample > population.size --> sample = population.size
npt.assert_equal(len(self.test_obj.ids(sample=25)), 4)

def test_get_1(self):
properties = [
Expand Down
10 changes: 10 additions & 0 deletions tests/test_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@ def test_property_dtypes_fail(self):
test_obj.property_dtypes.sort_index()

def test_ids(self):
np.random.seed(42)

# None --> CircuitNodeIds with all ids
tested = self.test_obj.ids()
expected = CircuitNodeIds.from_dict({"default": [0, 1, 2], "default2": [0, 1, 2, 3]})
Expand Down Expand Up @@ -235,6 +237,14 @@ def test_ids(self):
expected = CircuitNodeIds.from_arrays(["default2", "default2"], [0, 1])
assert ids.filter_population("default2").limit(2) == expected

tested = self.test_obj.ids(sample=2)
expected = CircuitNodeIds.from_arrays(["default2","default2"], [3, 0], sort_index=False)
assert tested == expected

tested = self.test_obj.ids(limit=4)
expected = CircuitNodeIds.from_dict({"default": [0, 1, 2], "default2": [0]})
assert tested == expected

def test_get(self):
# return all properties for all the ids
tested = self.test_obj.get()
Expand Down

0 comments on commit 12831a1

Please sign in to comment.