Skip to content

Commit

Permalink
[#156434192]: Address PR comments, merge methods
Browse files Browse the repository at this point in the history
  • Loading branch information
slobodan-ilic committed May 7, 2018
1 parent 531eec7 commit 1a4d6f2
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 37 deletions.
27 changes: 12 additions & 15 deletions src/cr/cube/crunch_cube.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ def _prune_body(self, res, transforms=None):
row_margin = np.sum(row_margin, axis=1)
row_prune_inds = self._margin_pruned_indices(
row_margin,
self.inserted_rows_inds(transforms)
self.inserted_dim_inds(transforms, 0)
)

if self.ndim == 1 or len(res.shape) == 1:
Expand All @@ -269,7 +269,7 @@ def _prune_body(self, res, transforms=None):
col_margin = np.sum(col_margin, axis=0)
col_prune_inds = self._margin_pruned_indices(
col_margin,
self.inserted_col_inds(transforms)
self.inserted_dim_inds(transforms, 1)
)
mask = self._create_mask(res, row_prune_inds, col_prune_inds)
res = np.ma.masked_array(res, mask=mask)
Expand Down Expand Up @@ -315,7 +315,7 @@ def _prune_indices(self, transforms):
)
row_indices = self._margin_pruned_indices(
row_margin,
self.inserted_rows_inds(transforms)
self.inserted_dim_inds(transforms, 0)
)
if row_indices.ndim > 1:
# In case of MR, we'd have 2D prune indices
Expand All @@ -330,7 +330,7 @@ def _prune_indices(self, transforms):
)
col_indices = self._margin_pruned_indices(
col_margin,
self.inserted_col_inds(transforms)
self.inserted_dim_inds(transforms, 1)
)
if col_indices.ndim > 1:
# In case of MR, we'd have 2D prune indices
Expand Down Expand Up @@ -359,10 +359,10 @@ def _prune_indices_tuple(self, row_margin, column_margin, transforms):
column_margin = np.sum(column_margin, axis=0)

row_inserted_indices = (
self.inserted_rows_inds(transforms)
self.inserted_dim_inds(transforms, 0)
)
col_inserted_indices = (
self.inserted_col_inds(transforms)
self.inserted_dim_inds(transforms, 1)
)

return (
Expand All @@ -379,13 +379,16 @@ def row_direction_axis(self):
return 2
return 1

def inserted_rows_inds(self, transforms):
def inserted_dim_inds(self, transforms, dim):
if not transforms:
return []
inserted_inds = self.inserted_hs_indices()
row_dim_ind = 0 if self.ndim < 3 else 1
if dim == 0: # In case of row
dim_ind = 0 if self.ndim < 3 else 1
elif dim == 1:
dim_ind = 1
return np.array(
inserted_inds[row_dim_ind] if len(inserted_inds) else []
inserted_inds[dim_ind] if len(inserted_inds) else []
)

@staticmethod
Expand All @@ -402,12 +405,6 @@ def _margin_pruned_indices(margin, insertions):
def col_direction_axis(self):
return self.ndim - 2

def inserted_col_inds(self, transforms):
if not transforms:
return []
inserted_inds = self.inserted_hs_indices()
return np.array(inserted_inds[1] if len(inserted_inds) > 1 else [])

@classmethod
def _fix_valid_indices(cls, valid_indices, insertion_index, dim):
'''Add indices for H&S inserted elements.'''
Expand Down
26 changes: 4 additions & 22 deletions tests/unit/test_crunch_cube.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,35 +416,17 @@ def hs_indices(self):
@patch('numpy.array')
@patch('cr.cube.crunch_cube.CrunchCube.inserted_hs_indices')
@patch('cr.cube.crunch_cube.CrunchCube.ndim', 1)
def test_inserted_row_inds(self, mock_inserted_hs_indices,
mock_np_array):
def test_inserted_inds(self, mock_inserted_hs_indices,
mock_np_array):
mock_np_array.return_value = Mock()

cc = CrunchCube({})
expected = []

# Assert indices are not fetched without trasforms
actual = cc.inserted_rows_inds(None)
actual = cc.inserted_dim_inds(None, 0)
assert actual == expected

# Assert indices are fetch with transforms
actual = cc.inserted_rows_inds(Mock())
mock_inserted_hs_indices.assert_called_once()

@patch('numpy.array')
@patch('cr.cube.crunch_cube.CrunchCube.inserted_hs_indices')
@patch('cr.cube.crunch_cube.CrunchCube.ndim', 1)
def test_inserted_col_inds(self, mock_inserted_hs_indices,
mock_np_array):
mock_np_array.return_value = Mock()

cc = CrunchCube({})
expected = []

# Assert indices are not fetched without trasforms
actual = cc.inserted_col_inds(None)
assert actual == expected

# Assert indices are fetch with transforms
actual = cc.inserted_rows_inds(Mock())
actual = cc.inserted_dim_inds(Mock(), 0)
mock_inserted_hs_indices.assert_called_once()

0 comments on commit 1a4d6f2

Please sign in to comment.