Skip to content

Commit

Permalink
fix report query for empty arrays or tuple (#52)
Browse files Browse the repository at this point in the history
  • Loading branch information
tomdele committed Apr 9, 2020
1 parent e1589b5 commit 6b2a847
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 2 deletions.
3 changes: 2 additions & 1 deletion bluepysnap/frame_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from cached_property import cached_property
from pathlib2 import Path
import numpy as np
import pandas as pd
from libsonata import ElementReportReader

Expand Down Expand Up @@ -212,7 +213,7 @@ def nodes(self):

def _resolve(self, group):
"""Transform a group into a node_id array."""
if group == []:
if isinstance(group, (np.ndarray, list, tuple)) and len(group) == 0:
return fix_libsonata_empty_list()
return self.nodes.ids(group=group)

Expand Down
2 changes: 1 addition & 1 deletion bluepysnap/spike_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def nodes(self):

def _resolve_nodes(self, group):
"""Transform a node group into a node_id array."""
if group == []:
if isinstance(group, (np.ndarray, list, tuple)) and len(group) == 0:
return fix_libsonata_empty_list()
return self.nodes.ids(group=group)

Expand Down
2 changes: 2 additions & 0 deletions tests/test_frame_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,8 @@ def test_get(self):
pdt.assert_frame_equal(self.test_obj.get(), self.df)

pdt.assert_frame_equal(self.test_obj.get([]), pd.DataFrame())
pdt.assert_frame_equal(self.test_obj.get(np.array([])), pd.DataFrame())
pdt.assert_frame_equal(self.test_obj.get(()), pd.DataFrame())

pdt.assert_frame_equal(self.test_obj.get(2), self.df.loc[:, [2]])

Expand Down
2 changes: 2 additions & 0 deletions tests/test_spike_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,8 @@ def test_get(self):
pdt.assert_series_equal(self.test_obj.get(),
_create_series([2, 0, 1, 2, 0], [0.1, 0.2, 0.3, 0.7, 1.3]))
pdt.assert_series_equal(self.test_obj.get([]), _create_series([], []))
pdt.assert_series_equal(self.test_obj.get(np.array([])), _create_series([], []))
pdt.assert_series_equal(self.test_obj.get(()), _create_series([], []))
npt.assert_allclose(self.test_obj.get(2), np.array([0.1, 0.7]))
npt.assert_allclose(self.test_obj.get(0, t_start=1.), [1.3])
npt.assert_allclose(self.test_obj.get(0, t_stop=1.), [0.2])
Expand Down

0 comments on commit 6b2a847

Please sign in to comment.