Skip to content

Commit

Permalink
Ensure that the report DataFrames have the same schema even when empty
Browse files Browse the repository at this point in the history
  • Loading branch information
GianlucaFicarelli committed Jun 19, 2023
1 parent 2da1c6d commit 4e3deea
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 50 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ Version v1.0.7
Bug Fixes
~~~~~~~~~
- Fix CircuitIds.sample() to always return different samples.
- Ensure that the report DataFrames have the same schema even when empty.


Version v1.0.6
Expand Down
28 changes: 11 additions & 17 deletions bluepysnap/frame_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

import bluepysnap._plotting
from bluepysnap.exceptions import BluepySnapError
from bluepysnap.utils import ensure_ids, ensure_list
from bluepysnap.utils import ensure_ids

L = logging.getLogger(__name__)

Expand Down Expand Up @@ -104,9 +104,6 @@ def get(self, group=None, t_start=None, t_stop=None, t_step=None):
except SonataError as e:
raise BluepySnapError(e) from e

if len(view.ids) == 0:
return pd.DataFrame()

# cell ids and section ids in the columns are enforced to be int64
# to avoid issues with numpy automatic conversions and to ensure that
# the results are the same regardless of the libsonata version [NSETM-1766]
Expand Down Expand Up @@ -163,21 +160,18 @@ def report(self):
- (population_name, node_id, compartment id) for the CompartmentReport
- (population_name, node_id) for the SomaReport
"""
res = pd.DataFrame()
dataframes = {}
for population in self.frame_report.population_names:
frames = self.frame_report[population]
try:
ids = frames.nodes.ids(group=self.group)
except BluepySnapError:
continue
data = frames.get(group=ids, t_start=self.t_start, t_stop=self.t_stop)
if data.empty:
continue
new_index = tuple(tuple([population] + ensure_list(x)) for x in data.columns)
data.columns = pd.MultiIndex.from_tuples(new_index)
# need to do this in order to preserve MultiIndex for columns
res = data if res.empty else data.join(res, how="outer")
return res.sort_index().sort_index(axis=1)
ids = frames.nodes.ids(group=self.group, raise_missing_property=False)
df = frames.get(group=ids, t_start=self.t_start, t_stop=self.t_stop)
dataframes[population] = df
# optimize when there is at most one non-empty df: use copy=False, and no need to sort
if sum(not df.empty for df in dataframes.values()) <= 1:
return pd.concat(dataframes, axis=1, copy=False)
# when concatenating multiple DataFrames, don't use copy=False because it is 2x slower (Pandas 2.0.2)
result = pd.concat(dataframes, axis=1)
return result.sort_index(axis=0).sort_index(axis=1)

# pylint: disable=protected-access
trace = bluepysnap._plotting.frame_trace
Expand Down
100 changes: 67 additions & 33 deletions tests/test_frame_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,37 +99,63 @@ def test_iter(self):
isinstance(report, test_module.PopulationCompartmentReport)

def test_filter(self):
expected = pd.DataFrame(
data=[
np.array([0.3, 1.3, 2.3, 3.3, 4.3, 5.3, 6.3] * 2, dtype=np.float32) + 0.1 * i
for i in range(4)
],
columns=pd.MultiIndex.from_tuples(
[
("default", 0, 0),
("default", 0, 1),
("default", 1, 0),
("default", 1, 1),
("default", 2, 0),
("default", 2, 1),
("default", 2, 1),
("default2", 0, 0),
("default2", 0, 1),
("default2", 1, 0),
("default2", 1, 1),
("default2", 2, 0),
("default2", 2, 1),
("default2", 2, 1),
]
),
index=np.array([0.3, 0.4, 0.5, 0.6]),
)

filtered = self.test_obj.filter(group=[0], t_start=0.3, t_stop=0.6)
assert filtered.frame_report == self.test_obj
assert filtered.t_start == 0.3
assert filtered.t_stop == 0.6
assert filtered.group == [0]
assert isinstance(filtered, test_module.FilteredFrameReport)
npt.assert_allclose(filtered.report.index, np.array([0.3, 0.4, 0.5, 0.6]))
assert filtered.report.columns.tolist() == [
expected_columns = [
("default", 0, 0),
("default", 0, 1),
("default2", 0, 0),
("default2", 0, 1),
]
pdt.assert_frame_equal(filtered.report, expected.loc[:, expected_columns])

filtered = self.test_obj.filter(group={"other1": ["B"]}, t_start=0.3, t_stop=0.6)
npt.assert_allclose(filtered.report.index, np.array([0.3, 0.4, 0.5, 0.6]))
assert filtered.report.columns.tolist() == [("default2", 1, 0), ("default2", 1, 1)]
expected_columns = [("default2", 1, 0), ("default2", 1, 1)]
pdt.assert_frame_equal(filtered.report, expected.loc[:, expected_columns])

filtered = self.test_obj.filter(group={"population": "default2"}, t_start=0.3, t_stop=0.6)
assert filtered.report.columns.tolist() == [
expected_columns = [
("default2", 0, 0),
("default2", 0, 1),
("default2", 1, 0),
("default2", 1, 1),
("default2", 2, 0),
("default2", 2, 1),
("default2", 2, 1),
]
pdt.assert_frame_equal(filtered.report, expected.loc[:, expected_columns])

filtered = self.test_obj.filter(group={"population": "default3"}, t_start=0.3, t_stop=0.6)
pdt.assert_frame_equal(filtered.report, pd.DataFrame())
pdt.assert_frame_equal(filtered.report, expected.iloc[:0, :0])


class TestSomaReport:
Expand All @@ -146,42 +172,45 @@ def test_iter(self):
isinstance(report, test_module.PopulationSomaReport)

def test_filter(self):
expected = pd.DataFrame(
data=[
np.array([0.3, 1.3, 2.3, 0.3, 1.3, 2.3], dtype=np.float32) + 0.1 * i
for i in range(4)
],
columns=pd.MultiIndex.from_tuples(
[
("default", 0),
("default", 1),
("default", 2),
("default2", 0),
("default2", 1),
("default2", 2),
]
),
index=np.array([0.3, 0.4, 0.5, 0.6]),
)

filtered = self.test_obj.filter(group=None, t_start=0.3, t_stop=0.6)
assert filtered.frame_report == self.test_obj
assert filtered.t_start == 0.3
assert filtered.t_stop == 0.6
assert filtered.group is None
assert isinstance(filtered, test_module.FilteredFrameReport)
npt.assert_allclose(filtered.report.index, np.array([0.3, 0.4, 0.5, 0.6]))
assert filtered.report.columns.tolist() == [
("default", 0),
("default", 1),
("default", 2),
("default2", 0),
("default2", 1),
("default2", 2),
]
pdt.assert_frame_equal(filtered.report, expected)

filtered = self.test_obj.filter(group={"other1": ["B"]}, t_start=0.3, t_stop=0.6)
npt.assert_allclose(filtered.report.index, np.array([0.3, 0.4, 0.5, 0.6]))
assert filtered.report.columns.tolist() == [("default2", 1)]
pdt.assert_frame_equal(filtered.report, expected.loc[:, [("default2", 1)]])

filtered = self.test_obj.filter(group={"population": "default2"}, t_start=0.3, t_stop=0.6)
assert filtered.report.columns.tolist() == [
("default2", 0),
("default2", 1),
("default2", 2),
]
pdt.assert_frame_equal(filtered.report, expected.loc[:, ["default2"]])

filtered = self.test_obj.filter(group={"population": "default3"}, t_start=0.3, t_stop=0.6)
pdt.assert_frame_equal(filtered.report, pd.DataFrame())
pdt.assert_frame_equal(filtered.report, expected.iloc[:0, :0])

ids = CircuitNodeIds.from_arrays(["default", "default", "default2"], [0, 1, 1])
filtered = self.test_obj.filter(group=ids, t_start=0.3, t_stop=0.6)
assert filtered.report.columns.tolist() == [("default", 0), ("default", 1), ("default2", 1)]
ids = CircuitNodeIds.from_tuples([("default2", 1)])
npt.assert_allclose(filtered.report.loc[:, ids.index].index, np.array([0.3, 0.4, 0.5, 0.6]))
npt.assert_allclose(filtered.report.loc[:, ids.index], np.array([[1.3, 1.4, 1.5, 1.6]]).T)
expected_columns = [("default", 0), ("default", 1), ("default2", 1)]
pdt.assert_frame_equal(filtered.report, expected.loc[:, expected_columns])


class TestPopulationFrameReport:
Expand All @@ -206,6 +235,11 @@ def setup_method(self):
ids = [(0, 0), (0, 1), (1, 0), (1, 1), (2, 0), (2, 1), (2, 1)]
self.df = pd.DataFrame(data=data, columns=pd.MultiIndex.from_tuples(ids), index=timestamps)

@property
def empty_df(self):
"""Return an empty DataFrame with the original types of index and columns."""
return self.df.iloc[:0, :0]

def test__resolve(self):
npt.assert_array_equal(self.test_obj._resolve({Cell.MTYPE: "L6_Y"}), [1, 2])
assert self.test_obj._resolve({Cell.MTYPE: "L2_X"}) == [0]
Expand All @@ -231,9 +265,9 @@ def _assert_frame_equal(df1, df2):
t_stride = round(t_step / self.test_obj.frame_report.dt) if t_step is not None else 1

_assert_frame_equal(self.test_obj.get(t_step=t_step), self.df)
_assert_frame_equal(self.test_obj.get([], t_step=t_step), pd.DataFrame())
_assert_frame_equal(self.test_obj.get(np.array([]), t_step=t_step), pd.DataFrame())
_assert_frame_equal(self.test_obj.get((), t_step=t_step), pd.DataFrame())
_assert_frame_equal(self.test_obj.get([], t_step=t_step), self.empty_df)
_assert_frame_equal(self.test_obj.get(np.array([]), t_step=t_step), self.empty_df)
_assert_frame_equal(self.test_obj.get((), t_step=t_step), self.empty_df)

_assert_frame_equal(self.test_obj.get(2, t_step=t_step), self.df.loc[:, [2]])
_assert_frame_equal(
Expand All @@ -242,7 +276,7 @@ def _assert_frame_equal(df1, df2):

# not from this population
_assert_frame_equal(
self.test_obj.get(CircuitNodeId("default2", 2), t_step=t_step), pd.DataFrame()
self.test_obj.get(CircuitNodeId("default2", 2), t_step=t_step), self.empty_df
)

_assert_frame_equal(self.test_obj.get([2, 0], t_step=t_step), self.df.loc[:, [0, 2]])
Expand Down Expand Up @@ -322,7 +356,7 @@ def test_get_partially_not_in_report(self):

def test_get_not_in_report(self):
with patch.object(self.test_obj.__class__, "_resolve", return_value=np.asarray([4])):
pdt.assert_frame_equal(self.test_obj.get([4]), pd.DataFrame())
pdt.assert_frame_equal(self.test_obj.get([4]), self.empty_df)

def test_node_ids(self):
npt.assert_array_equal(self.test_obj.node_ids, np.array(sorted([0, 1, 2]), dtype=IDS_DTYPE))
Expand Down

0 comments on commit 4e3deea

Please sign in to comment.