From 9f084f2cdb8aabff9532a0124e2e4e823fc36646 Mon Sep 17 00:00:00 2001 From: Slobodan Ilic Date: Fri, 8 Feb 2019 09:39:50 +0100 Subject: [PATCH 1/5] Implement min base size masks * TDD approach * Implement masks in a separate class --- src/cr/cube/min_base_size_mask.py | 69 ++++++++++++++ tests/unit/test_min_base_size.py | 151 ++++++++++++++++++++++++++++++ 2 files changed, 220 insertions(+) create mode 100644 src/cr/cube/min_base_size_mask.py create mode 100644 tests/unit/test_min_base_size.py diff --git a/src/cr/cube/min_base_size_mask.py b/src/cr/cube/min_base_size_mask.py new file mode 100644 index 000000000..259629fe8 --- /dev/null +++ b/src/cr/cube/min_base_size_mask.py @@ -0,0 +1,69 @@ +# encoding: utf-8 + +"""MinBaseSize class.""" + +from __future__ import division +import numpy as np + +from cr.cube.util import lazyproperty +from cr.cube.enum import DIMENSION_TYPE as DT + + +class MinBaseSizeMask: + """Helper for deciding which rows/columns to suppress, based on min base size. + + If a base value, that is used when calculating percentages, is less than a given + minimum base size, then all of the values obtained in such a way need to + suppressed. We achieve this by generating a mask, based on row/column/table + marginal values and the shape of the underlying slice. + """ + + def __init__(self, slice_, size): + self._slice = slice_ + self._size = size + + @lazyproperty + def column_mask(self): + margin = self._slice.margin(axis=0) + mask = margin < self._size + + if margin.shape == self._shape: + # If margin shape is the same as slice's (such as in a col margin for + # MR x CAT), don't broadcast the mask to the array shape, since + # they're already the same. 
+ return mask + + # If the column margin is a row vector - broadcast its mask to the array shape + return np.logical_or(np.zeros(self._shape, dtype=bool), mask) + + @lazyproperty + def row_mask(self): + margin = self._slice.margin(axis=1) + mask = margin < self._size + + if margin.shape == self._shape: + # If margin shape is the same as slice's (such as in a row margin for + # CAT x MR), don't broadcast the mask to the array shape, since + # they're already the same. + return mask + + # If the row margin is a column vector - broadcast its mask to the array shape + return np.logical_or(np.zeros(self._shape, dtype=bool), mask[:, None]) + + @lazyproperty + def table_mask(self): + margin = self._slice.margin(axis=None) + mask = margin < self._size + + if margin.shape == self._shape: + return mask + + if self._slice.dim_types[0] == DT.MR: + # If the margin is a column vector - broadcast its mask to the array shape + return np.logical_or(np.zeros(self._shape, dtype=bool), mask[:, None]) + + return np.logical_or(np.zeros(self._shape, dtype=bool), mask) + + @lazyproperty + def _shape(self): + return self._slice.get_shape() diff --git a/tests/unit/test_min_base_size.py b/tests/unit/test_min_base_size.py new file mode 100644 index 000000000..a4e56b52f --- /dev/null +++ b/tests/unit/test_min_base_size.py @@ -0,0 +1,151 @@ +# encoding: utf-8 + +"""Unit test suite for cr.cube.min_base_size module.""" + +import pytest +import numpy as np + +from cr.cube.cube_slice import CubeSlice +from cr.cube.min_base_size_mask import MinBaseSizeMask +from cr.cube.enum import DIMENSION_TYPE as DT + +from ..unitutil import instance_mock, method_mock, property_mock + + +class DescribeMinBaseSizeMask: + def it_provides_access_to_column_direction_mask( + self, _margin, _get_shape, column_mask_fixture + ): + size, shape, margin, expected_mask = column_mask_fixture + _margin.return_value = margin + _get_shape.return_value = shape + row_mask = MinBaseSizeMask(CubeSlice(None, None), size).column_mask 
+ np.testing.assert_array_equal(row_mask, expected_mask) + + def it_provides_access_to_row_direction_mask( + self, _margin, _get_shape, row_mask_fixture + ): + size, shape, margin, expected_mask = row_mask_fixture + _margin.return_value = margin + _get_shape.return_value = shape + row_mask = MinBaseSizeMask(CubeSlice(None, None), size).row_mask + np.testing.assert_array_equal(row_mask, expected_mask) + + def it_provides_access_to_table_direction_mask( + self, _margin, _get_shape, _dim_types, table_mask_fixture + ): + size, shape, dim_types, margin, expected_mask = table_mask_fixture + _margin.return_value = margin + _get_shape.return_value = shape + _dim_types.return_value = dim_types + table_mask = MinBaseSizeMask(CubeSlice(None, None), size).table_mask + np.testing.assert_array_equal(table_mask, expected_mask) + + def it_sets_slice_on_construction(self, slice_): + size = 50 + min_base_size = MinBaseSizeMask(slice_, size) + assert min_base_size._slice is slice_ + assert min_base_size._size is size + + # fixtures ------------------------------------------------------- + + @pytest.fixture( + params=[ + # Margin is just a single row - broadcast it across shape + (30, (2, 3), [10, 20, 30], [[True, True, False], [True, True, False]]), + # Margin is 2D table (as in CAT x MR), use that shape (don't broadcast) + ( + 40, + (2, 3), + [[10, 20, 40], [30, 50, 60]], + [[True, True, False], [True, False, False]], + ), + ] + ) + def column_mask_fixture(self, request): + size, shape, margin, expected = request.param + margin = np.array(margin) + expected = np.array(expected) + return size, shape, margin, expected + + @pytest.fixture( + params=[ + # Margin is just a single column - broadcast it across shape + (30, (3, 2), [10, 20, 30], [[True, True], [True, True], [False, False]]), + # Margin is 2D table (as in CAT x MR), use that shape (don't broadcast) + ( + 40, + (2, 3), + [[10, 20, 40], [30, 50, 60]], + [[True, True, False], [True, False, False]], + ), + ] + ) + def 
row_mask_fixture(self, request): + size, shape, margin, expected = request.param + margin = np.array(margin) + expected = np.array(expected) + return size, shape, margin, expected + + @pytest.fixture( + params=[ + ( + 30, + (3, 2), + (DT.CAT, DT.CAT), + 10, + [[True, True], [True, True], [True, True]], + ), + ( + 30, + (3, 2), + (DT.CAT, DT.CAT), + 40, + [[False, False], [False, False], [False, False]], + ), + ( + 40, + (2, 3), + (DT.CAT, DT.CAT), + [[10, 20, 40], [30, 50, 60]], + [[True, True, False], [True, False, False]], + ), + ( + 40, + (2, 3), + (DT.CAT, DT.MR), + [10, 20, 40], + [[True, True, False], [True, True, False]], + ), + ( + 40, + (2, 3), + (DT.MR, DT.CAT), + [10, 40], + [[True, True, True], [False, False, False]], + ), + ] + ) + def table_mask_fixture(self, request): + size, shape, dim_types, margin, expected = request.param + margin = np.array(margin) + expected = np.array(expected) + return size, shape, dim_types, margin, expected + + # fixture components --------------------------------------------- + + @pytest.fixture + def slice_(self, request): + return instance_mock(request, CubeSlice) + + @pytest.fixture + def _margin(self, request): + return method_mock(request, CubeSlice, "margin") + + @pytest.fixture + def _get_shape(self, request): + return method_mock(request, CubeSlice, "get_shape") + + @pytest.fixture + def _dim_types(self, request): + return property_mock(request, CubeSlice, "dim_types") From d05fb30477584049213b6c0aeefa3f8dfb0d063f Mon Sep 17 00:00:00 2001 From: Slobodan Ilic Date: Tue, 12 Feb 2019 14:32:37 +0100 Subject: [PATCH 2/5] Add min base size to cube slice * Support with integration tests * Add additional function that returns the object of the new mask class --- src/cr/cube/cube_slice.py | 5 + tests/integration/test_multiple_response.py | 151 ++++++++++++++++++++ 2 files changed, 156 insertions(+) diff --git a/src/cr/cube/cube_slice.py b/src/cr/cube/cube_slice.py index 2b4005a07..b39a96c5f 100644 --- 
a/src/cr/cube/cube_slice.py +++ b/src/cr/cube/cube_slice.py @@ -13,6 +13,7 @@ from scipy.stats.contingency import expected_freq from cr.cube.enum import DIMENSION_TYPE as DT +from cr.cube.min_base_size_mask import MinBaseSizeMask from cr.cube.measures.scale_means import ScaleMeans from cr.cube.measures.pairwise_pvalues import PairwisePvalues from cr.cube.util import compress_pruned, lazyproperty, memoize @@ -275,6 +276,10 @@ def margin( return self._extract_slice_result_from_cube(margin) + def min_base_size_mask(self, size): + """Return MinBaseSizeMask object with correct row, col and table masks.""" + return MinBaseSizeMask(self, size) + @lazyproperty def mr_dim_ind(self): """Get the correct index of the MR dimension in the cube slice.""" diff --git a/tests/integration/test_multiple_response.py b/tests/integration/test_multiple_response.py index bd358acec..096f818b9 100644 --- a/tests/integration/test_multiple_response.py +++ b/tests/integration/test_multiple_response.py @@ -1832,3 +1832,154 @@ def test_mr_by_cat_hs_cell_percentage(): ) actual = cube.proportions(axis=None, include_transforms_for_dims=[0, 1]) np.testing.assert_almost_equal(actual, expected) + + +def test_mr_x_cat_min_base_size_mask(): + cube_slice = CrunchCube(CR.MR_X_CAT_HS).slices[0] + + # Table margin evaluates to: + # + # array([176.36555176, 211.42058767, 247.74073787, 457.05095566, 471.93176847]) + # + # We thus choose the min base size to be 220, and expect it to broadcast across + # columns (in the row direction, i.e. axis=1), since the MR is what won't be + # collapsed after doing the base calculation in the table direction. 
+ expected_table_mask = np.array( + [ + [True, True, True, True, True, True], + [True, True, True, True, True, True], + [False, False, False, False, False, False], + [False, False, False, False, False, False], + [False, False, False, False, False, False], + ] + ) + np.testing.assert_array_equal( + cube_slice.min_base_size_mask(220).table_mask, expected_table_mask + ) + + # Column margin evaluates to: + # + # np.array([ + # [21.78869966, 32.81576041, 0.0, 58.86625406, 62.89483764, 0.0], + # [15.7386377, 40.78574176, 0.0, 76.98691601, 77.90929221, 0.0], + # [12.22150269, 40.98148847, 0.0, 91.95428994, 102.58345677, 0.0], + # [20.95300034, 63.13595644, 0.0, 165.67203661, 207.28996226, 0.0], + # [30.94322363, 88.23933157, 0.0, 165.82148906, 186.92772421, 0.0], + # ]) + # + # We thus choose the min base size to be 30, and expeect it to not be broadcast. + expected_column_mask = np.array( + [ + [True, False, True, False, False, True], + [True, False, True, False, False, True], + [True, False, True, False, False, True], + [True, False, True, False, False, True], + [False, False, True, False, False, True], + ] + ) + np.testing.assert_array_equal( + cube_slice.min_base_size_mask(30).column_mask, expected_column_mask + ) + + # Row margin evaluates to: + # + # np.array([31.63152104, 70.73073413, 125.75911351, 366.88839144, 376.76564059]) + # + # We thus choose the min base size to be 80, and expeect it to broadcast across + # columns (in the row direction, i.e. axis=1), sincee the MR is what won't be + # collapsed after doing the base calculation in the row direction. 
+ expected_row_mask = np.array( + [ + [True, True, True, True, True, True], + [True, True, True, True, True, True], + [False, False, False, False, False, False], + [False, False, False, False, False, False], + [False, False, False, False, False, False], + ] + ) + np.testing.assert_array_equal( + cube_slice.min_base_size_mask(80).row_mask, expected_row_mask + ) + + +def test_cat_x_mr_min_base_size_mask(): + cube_slice = CrunchCube(CR.CAT_X_MR).slices[0] + + # Table margin evaluates to: + # + # array([80, 79, 70]) + # + # We thus choose the min base size to be 75, and expect it to broadcast across + # rows (in the col direction, i.e. axis=0), since the MR is what won't be + # collapsed after doing the base calculation in the table direction. + expected_table_mask = np.array([[False, False, True], [False, False, True]]) + np.testing.assert_array_equal( + cube_slice.min_base_size_mask(75).table_mask, expected_table_mask + ) + + # Column margin evaluates to: + # + # np.array([40, 34, 38]) + # + # We thus choose the min base size to be 35, and expect it to broadcast across + # rows (in the col direction, i.e. axis=0), since the MR is what won't be + # collapsed after doing the base calculation in the table direction. 
+ expected_column_mask = np.array([[False, True, False], [False, True, False]]) + np.testing.assert_array_equal( + cube_slice.min_base_size_mask(35).column_mask, expected_column_mask + ) + + # Row margin evaluates to: + # + # np.array([[28, 25, 23], [52, 54, 47]]) + # + # We thus choose the min base size to be 25, and expect it to not be broadcast + expected_row_mask = np.array([[False, False, True], [False, False, False]]) + np.testing.assert_array_equal( + cube_slice.min_base_size_mask(25).row_mask, expected_row_mask + ) + + +def test_mr_x_mr_min_base_size_mask(): + cube_slice = CrunchCube(CR.CAT_X_MR_X_MR).slices[0] + + # Table margin evaluates to: + # + # array([[10000, 10000], + # [10000, 10000], + # [10000, 10000]]) + # + # We thus choose the min base size to be 11000, and expect it to be broadcast + # across all values + expected_table_mask = np.array([[True, True], [True, True], [True, True]]) + np.testing.assert_array_equal( + cube_slice.min_base_size_mask(11000).table_mask, expected_table_mask + ) + + # Column margin evaluates to: + # + # array([[1914, 5958], + # [1914, 5958], + # [1914, 5958]]) + # + # We thus choose the min base size to be 2000, and expect it to broadcast across + # rows (in the col direction, i.e. axis=0), since the MR is what won't be + # collapsed after doing the base calculation in the table direction. + expected_column_mask = np.array([[True, False], [True, False], [True, False]]) + np.testing.assert_array_equal( + cube_slice.min_base_size_mask(2000).column_mask, expected_column_mask + ) + + # Row margin evaluates to: + # + # array([[6046, 6046], + # [1008, 1008], + # [ 974, 974]]) + # + # We thus choose the min base size to be 1000, and expect it to broadcast across + # rows (in the col direction, i.e. axis=0), since the MR is what won't be + # collapsed after doing the base calculation in the table direction. 
+ expected_row_mask = np.array([[False, False], [False, False], [True, True]]) + np.testing.assert_array_equal( + cube_slice.min_base_size_mask(1000).row_mask, expected_row_mask + ) From 562d6e22a7c339cfd440fba523c6d74b70ec0063 Mon Sep 17 00:00:00 2001 From: Slobodan Ilic Date: Wed, 13 Feb 2019 10:03:42 +0100 Subject: [PATCH 3/5] Add docstrings --- src/cr/cube/cube_slice.py | 14 +++++++++++++- src/cr/cube/min_base_size_mask.py | 3 +++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/src/cr/cube/cube_slice.py b/src/cr/cube/cube_slice.py index b39a96c5f..4eb61b9db 100644 --- a/src/cr/cube/cube_slice.py +++ b/src/cr/cube/cube_slice.py @@ -277,7 +277,19 @@ def margin( return self._extract_slice_result_from_cube(margin) def min_base_size_mask(self, size): - """Return MinBaseSizeMask object with correct row, col and table masks.""" + """Returns MinBaseSizeMask object with correct row, col and table masks. + + The returned object stores the necessary information about the base size, as + well as about the base values. It can create corresponding masks in the row, + column, and table directions, based on the corresponding base values + (the values of the unweighted margins). 
+ + Usage: + >>> slice_ = cube.slices[0] # obtain a valid cube slice + >>> slice_.min_base_size_mask(30).row_mask + >>> slice_.min_base_size_mask(50).column_mask + >>> slice_.min_base_size_mask(22).table_mask + """ return MinBaseSizeMask(self, size) @lazyproperty diff --git a/src/cr/cube/min_base_size_mask.py b/src/cr/cube/min_base_size_mask.py index 259629fe8..d8011b113 100644 --- a/src/cr/cube/min_base_size_mask.py +++ b/src/cr/cube/min_base_size_mask.py @@ -24,6 +24,7 @@ def __init__(self, slice_, size): @lazyproperty def column_mask(self): + """ndarray, True where column margin <= min_base_size, same shape as slice.""" margin = self._slice.margin(axis=0) mask = margin < self._size @@ -38,6 +39,7 @@ def column_mask(self): @lazyproperty def row_mask(self): + """ndarray, True where row margin <= min_base_size, same shape as slice.""" margin = self._slice.margin(axis=1) mask = margin < self._size @@ -52,6 +54,7 @@ def row_mask(self): @lazyproperty def table_mask(self): + """ndarray, True where table margin <= min_base_size, same shape as slice.""" margin = self._slice.margin(axis=None) mask = margin < self._size From 8d1ef299e9d9c58a18c1563d5279cd8aa62c86e3 Mon Sep 17 00:00:00 2001 From: Slobodan Ilic Date: Wed, 13 Feb 2019 19:14:34 +0100 Subject: [PATCH 4/5] Fix so that exporter tests pass --- src/cr/cube/cube_slice.py | 12 +++++++----- src/cr/cube/min_base_size_mask.py | 18 +++++++++++++----- tests/integration/test_multiple_response.py | 18 ++++++++++-------- tests/unit/test_min_base_size.py | 13 ++++++++++--- 4 files changed, 40 insertions(+), 21 deletions(-) diff --git a/src/cr/cube/cube_slice.py b/src/cr/cube/cube_slice.py index 4eb61b9db..0bbf025f3 100644 --- a/src/cr/cube/cube_slice.py +++ b/src/cr/cube/cube_slice.py @@ -127,7 +127,7 @@ def dim_types(self): return self._cube.dim_types[-2:] @memoize - def get_shape(self, prune=False): + def get_shape(self, prune=False, hs_dims=None): """Tuple of array dimensions' lengths. 
It returns a tuple of ints, each representing the length of a cube @@ -144,9 +144,11 @@ def get_shape(self, prune=False): >>> pruned_shape = get_shape(prune=True) """ if not prune: - return self.as_array().shape + return self.as_array(include_transforms_for_dims=hs_dims).shape - shape = compress_pruned(self.as_array(prune=True)).shape + shape = compress_pruned( + self.as_array(prune=True, include_transforms_for_dims=hs_dims) + ).shape # Eliminate dimensions that get reduced to 1 # (e.g. single element categoricals) return tuple(n for n in shape if n > 1) @@ -276,7 +278,7 @@ def margin( return self._extract_slice_result_from_cube(margin) - def min_base_size_mask(self, size): + def min_base_size_mask(self, size, hs_dims=None): """Returns MinBaseSizeMask object with correct row, col and table masks. The returned object stores the necessary information about the base size, as @@ -290,7 +292,7 @@ def min_base_size_mask(self, size): >>> slice_.min_base_size_mask(50).column_mask >>> slice_.min_base_size_mask(22).table_mask """ - return MinBaseSizeMask(self, size) + return MinBaseSizeMask(self, size, hs_dims) @lazyproperty def mr_dim_ind(self): diff --git a/src/cr/cube/min_base_size_mask.py b/src/cr/cube/min_base_size_mask.py index d8011b113..760548234 100644 --- a/src/cr/cube/min_base_size_mask.py +++ b/src/cr/cube/min_base_size_mask.py @@ -18,14 +18,17 @@ class MinBaseSizeMask: marginal values and the shape of the underlying slice. 
""" - def __init__(self, slice_, size): + def __init__(self, slice_, size, hs_dims=None): self._slice = slice_ self._size = size + self._hs_dims = hs_dims @lazyproperty def column_mask(self): """ndarray, True where column margin <= min_base_size, same shape as slice.""" - margin = self._slice.margin(axis=0) + margin = self._slice.margin( + axis=0, weighted=False, include_transforms_for_dims=self._hs_dims + ) mask = margin < self._size if margin.shape == self._shape: @@ -40,7 +43,9 @@ def column_mask(self): @lazyproperty def row_mask(self): """ndarray, True where row margin <= min_base_size, same shape as slice.""" - margin = self._slice.margin(axis=1) + margin = self._slice.margin( + axis=1, weighted=False, include_transforms_for_dims=self._hs_dims + ) mask = margin < self._size if margin.shape == self._shape: @@ -55,7 +60,7 @@ def row_mask(self): @lazyproperty def table_mask(self): """ndarray, True where table margin <= min_base_size, same shape as slice.""" - margin = self._slice.margin(axis=None) + margin = self._slice.margin(axis=None, weighted=False) mask = margin < self._size if margin.shape == self._shape: @@ -69,4 +74,7 @@ def table_mask(self): @lazyproperty def _shape(self): - return self._slice.get_shape() + shape = self._slice.get_shape(hs_dims=self._hs_dims) + if len(shape) != self._slice.ndim: + shape = (shape[0], 1) + return shape diff --git a/tests/integration/test_multiple_response.py b/tests/integration/test_multiple_response.py index 096f818b9..27f6c6d33 100644 --- a/tests/integration/test_multiple_response.py +++ b/tests/integration/test_multiple_response.py @@ -1859,18 +1859,20 @@ def test_mr_x_cat_min_base_size_mask(): # Column margin evaluates to: # - # np.array([ - # [21.78869966, 32.81576041, 0.0, 58.86625406, 62.89483764, 0.0], - # [15.7386377, 40.78574176, 0.0, 76.98691601, 77.90929221, 0.0], - # [12.22150269, 40.98148847, 0.0, 91.95428994, 102.58345677, 0.0], - # [20.95300034, 63.13595644, 0.0, 165.67203661, 207.28996226, 0.0], - # 
[30.94322363, 88.23933157, 0.0, 165.82148906, 186.92772421, 0.0], - # ]) + # np.array( + # [ + # [15, 24, 0, 57, 69, 0], + # [15, 34, 0, 75, 86, 0], + # [13, 37, 0, 81, 111, 0], + # [20, 50, 0, 159, 221, 0], + # [32, 69, 0, 167, 208, 0], + # ] + # ) # # We thus choose the min base size to be 30, and expeect it to not be broadcast. expected_column_mask = np.array( [ - [True, False, True, False, False, True], + [True, True, True, False, False, True], [True, False, True, False, False, True], [True, False, True, False, False, True], [True, False, True, False, False, True], diff --git a/tests/unit/test_min_base_size.py b/tests/unit/test_min_base_size.py index a4e56b52f..d3af7ecea 100644 --- a/tests/unit/test_min_base_size.py +++ b/tests/unit/test_min_base_size.py @@ -14,30 +14,33 @@ class DescribeMinBaseSizeMask: def it_provides_access_to_column_direction_mask( - self, _margin, _get_shape, column_mask_fixture + self, _margin, _get_shape, _ndim, column_mask_fixture ): size, shape, margin, expected_mask = column_mask_fixture _margin.return_value = margin _get_shape.return_value = shape + _ndim.return_value = len(shape) row_mask = MinBaseSizeMask(CubeSlice(None, None), size).column_mask np.testing.assert_array_equal(row_mask, expected_mask) def it_provides_access_to_row_direction_mask( - self, _margin, _get_shape, row_mask_fixture + self, _margin, _get_shape, _ndim, row_mask_fixture ): size, shape, margin, expected_mask = row_mask_fixture _margin.return_value = margin _get_shape.return_value = shape + _ndim.return_value = len(shape) row_mask = MinBaseSizeMask(CubeSlice(None, None), size).row_mask np.testing.assert_array_equal(row_mask, expected_mask) def it_provides_access_to_table_direction_mask( - self, _margin, _get_shape, _dim_types, table_mask_fixture + self, _margin, _get_shape, _ndim, _dim_types, table_mask_fixture ): size, shape, dim_types, margin, expected_mask = table_mask_fixture _margin.return_value = margin _get_shape.return_value = shape 
_dim_types.return_value = dim_types + _ndim.return_value = len(shape) table_mask = MinBaseSizeMask(CubeSlice(None, None), size).table_mask np.testing.assert_array_equal(table_mask, expected_mask) @@ -149,3 +152,7 @@ def _get_shape(self, request): @pytest.fixture def _dim_types(self, request): return property_mock(request, CubeSlice, "dim_types") + + @pytest.fixture + def _ndim(self, request): + return property_mock(request, CubeSlice, "ndim") From 9d354fdc676237fff8baa1b5e4244aa56c402e45 Mon Sep 17 00:00:00 2001 From: Slobodan Ilic Date: Wed, 13 Feb 2019 19:21:56 +0100 Subject: [PATCH 5/5] Coverage back to 100% --- src/cr/cube/min_base_size_mask.py | 6 ++++++ tests/unit/test_min_base_size.py | 14 ++++++++++++++ 2 files changed, 20 insertions(+) diff --git a/src/cr/cube/min_base_size_mask.py b/src/cr/cube/min_base_size_mask.py index 760548234..f796c5012 100644 --- a/src/cr/cube/min_base_size_mask.py +++ b/src/cr/cube/min_base_size_mask.py @@ -75,6 +75,12 @@ def table_mask(self): @lazyproperty def _shape(self): shape = self._slice.get_shape(hs_dims=self._hs_dims) + if len(shape) != self._slice.ndim: + # TODO: This is an ugly hack that needs to happen due to the fact that we + # purge dimensions with the count of 1, when getting the slice shape. This + # will be addressed in a PR (already on the way) that strives to abandon + # the ad-hoc purging of 1-element dimensions altogether. 
shape = (shape[0], 1) + return shape diff --git a/tests/unit/test_min_base_size.py b/tests/unit/test_min_base_size.py index d3af7ecea..501f76516 100644 --- a/tests/unit/test_min_base_size.py +++ b/tests/unit/test_min_base_size.py @@ -44,6 +44,15 @@ def it_provides_access_to_table_direction_mask( table_mask = MinBaseSizeMask(CubeSlice(None, None), size).table_mask np.testing.assert_array_equal(table_mask, expected_mask) + def it_retains_single_element_dimension_in_shape( + self, _ndim, _get_shape, shape_fixture + ): + slice_shape, ndim, expected_mask_shape = shape_fixture + min_base_size = MinBaseSizeMask(CubeSlice(None, None), None) + _ndim.return_value = ndim + _get_shape.return_value = slice_shape + assert min_base_size._shape == expected_mask_shape + def it_sets_slice_on_construction(self, slice_): size = 50 min_base_size = MinBaseSizeMask(slice_, size) @@ -90,6 +99,11 @@ def row_mask_fixture(self, request): expected = np.array(expected) return size, shape, margin, expected + @pytest.fixture(params=[((2, 3), 2, (2, 3)), ((2,), 2, (2, 1))]) + def shape_fixture(self, request): + slice_shape, ndim, expected = request.param + return slice_shape, ndim, expected + @pytest.fixture( params=[ (