From 1a4d6f2b1c77f6d6350804dde6994ccae970e6fc Mon Sep 17 00:00:00 2001 From: Slobodan Ilic Date: Mon, 7 May 2018 16:35:36 +0200 Subject: [PATCH] [#156434192]: Address PR comments, merge methods --- src/cr/cube/crunch_cube.py | 27 ++++++++++++--------------- tests/unit/test_crunch_cube.py | 26 ++++---------------------- 2 files changed, 16 insertions(+), 37 deletions(-) diff --git a/src/cr/cube/crunch_cube.py b/src/cr/cube/crunch_cube.py index 8fd43d369..51c4e0ddc 100644 --- a/src/cr/cube/crunch_cube.py +++ b/src/cr/cube/crunch_cube.py @@ -252,7 +252,7 @@ def _prune_body(self, res, transforms=None): row_margin = np.sum(row_margin, axis=1) row_prune_inds = self._margin_pruned_indices( row_margin, - self.inserted_rows_inds(transforms) + self.inserted_dim_inds(transforms, 0) ) if self.ndim == 1 or len(res.shape) == 1: @@ -269,7 +269,7 @@ def _prune_body(self, res, transforms=None): col_margin = np.sum(col_margin, axis=0) col_prune_inds = self._margin_pruned_indices( col_margin, - self.inserted_col_inds(transforms) + self.inserted_dim_inds(transforms, 1) ) mask = self._create_mask(res, row_prune_inds, col_prune_inds) res = np.ma.masked_array(res, mask=mask) @@ -315,7 +315,7 @@ def _prune_indices(self, transforms): ) row_indices = self._margin_pruned_indices( row_margin, - self.inserted_rows_inds(transforms) + self.inserted_dim_inds(transforms, 0) ) if row_indices.ndim > 1: # In case of MR, we'd have 2D prune indices @@ -330,7 +330,7 @@ def _prune_indices(self, transforms): ) col_indices = self._margin_pruned_indices( col_margin, - self.inserted_col_inds(transforms) + self.inserted_dim_inds(transforms, 1) ) if col_indices.ndim > 1: # In case of MR, we'd have 2D prune indices @@ -359,10 +359,10 @@ def _prune_indices_tuple(self, row_margin, column_margin, transforms): column_margin = np.sum(column_margin, axis=0) row_inserted_indices = ( - self.inserted_rows_inds(transforms) + self.inserted_dim_inds(transforms, 0) ) col_inserted_indices = ( - self.inserted_col_inds(transforms) + self.inserted_dim_inds(transforms, 1) ) return ( @@ -379,13 +379,16 @@ def row_direction_axis(self): return 2 return 1 - def inserted_rows_inds(self, transforms): + def inserted_dim_inds(self, transforms, dim): if not transforms: return [] inserted_inds = self.inserted_hs_indices() - row_dim_ind = 0 if self.ndim < 3 else 1 + if dim == 0: # In case of row + dim_ind = 0 if self.ndim < 3 else 1 + elif dim == 1: + dim_ind = 1 return np.array( - inserted_inds[row_dim_ind] if len(inserted_inds) else [] + inserted_inds[dim_ind] if len(inserted_inds) else [] ) @staticmethod @@ -402,12 +405,6 @@ def _margin_pruned_indices(margin, insertions): def col_direction_axis(self): return self.ndim - 2 - def inserted_col_inds(self, transforms): - if not transforms: - return [] - inserted_inds = self.inserted_hs_indices() - return np.array(inserted_inds[1] if len(inserted_inds) > 1 else []) - @classmethod def _fix_valid_indices(cls, valid_indices, insertion_index, dim): '''Add indices for H&S inserted elements.''' diff --git a/tests/unit/test_crunch_cube.py b/tests/unit/test_crunch_cube.py index 59aacd51d..cd3767169 100644 --- a/tests/unit/test_crunch_cube.py +++ b/tests/unit/test_crunch_cube.py @@ -416,35 +416,17 @@ def hs_indices(self): @patch('numpy.array') @patch('cr.cube.crunch_cube.CrunchCube.inserted_hs_indices') @patch('cr.cube.crunch_cube.CrunchCube.ndim', 1) - def test_inserted_row_inds(self, mock_inserted_hs_indices, - mock_np_array): + def test_inserted_inds(self, mock_inserted_hs_indices, + mock_np_array): mock_np_array.return_value = Mock() cc = CrunchCube({}) expected = [] # Assert indices are not fetched without trasforms - actual = cc.inserted_rows_inds(None) + actual = cc.inserted_dim_inds(None, 0) assert actual == expected # Assert indices are fetch with transforms - actual = cc.inserted_rows_inds(Mock()) - mock_inserted_hs_indices.assert_called_once() - - @patch('numpy.array') - @patch('cr.cube.crunch_cube.CrunchCube.inserted_hs_indices') - @patch('cr.cube.crunch_cube.CrunchCube.ndim', 1) - def test_inserted_col_inds(self, mock_inserted_hs_indices, - mock_np_array): - mock_np_array.return_value = Mock() - - cc = CrunchCube({}) - expected = [] - - # Assert indices are not fetched without trasforms - actual = cc.inserted_col_inds(None) - assert actual == expected - - # Assert indices are fetch with transforms - actual = cc.inserted_rows_inds(Mock()) + actual = cc.inserted_dim_inds(Mock(), 0) mock_inserted_hs_indices.assert_called_once()