diff --git a/src/cr/cube/crunch_cube.py b/src/cr/cube/crunch_cube.py index 4e12a9c3c..dee38d221 100644 --- a/src/cr/cube/crunch_cube.py +++ b/src/cr/cube/crunch_cube.py @@ -72,7 +72,7 @@ def __init__(self, response): 'A `cube` must be JSON or `dict`.' ).format(type(response))) - self.slices = self._get_slices() + self.slices = self.get_slices() def _fix_shape(self, array): '''Fixes shape of MR variables. @@ -1165,8 +1165,8 @@ def scale_means(self): '''Get cube means.''' return ScaleMeans(self).data - def _get_slices(self): - if self.ndim < 3: + def get_slices(self, ca_as_0th=False): + if self.ndim < 3 and not ca_as_0th: return [CubeSlice(self, 0)] - return [CubeSlice(self, i) for i, _ in enumerate(self.labels()[0])] + return [CubeSlice(self, i, ca_as_0th) for i, _ in enumerate(self.labels()[0])] diff --git a/src/cr/cube/cube_slice.py b/src/cr/cube/cube_slice.py index d550681ed..7a06b5814 100644 --- a/src/cr/cube/cube_slice.py +++ b/src/cr/cube/cube_slice.py @@ -14,11 +14,23 @@ class CubeSlice(object): ''' row_dim_ind = 0 - col_dim_ind = 1 - def __init__(self, cube, index): + def __init__(self, cube, index, ca_as_0th=False): + + if ca_as_0th and cube.dim_types[0] != 'categorical_array': + msg = ( + 'Cannot set CA as 0th for cube that ' + 'does not have CA items as the 0th dimension.' + ) + raise ValueError(msg) + self._cube = cube self._index = index + self.ca_as_0th = ca_as_0th + + @property + def col_dim_ind(self): + return 1 if not self.ca_as_0th else 0 def __getattr__(self, attr): cube_attr = getattr(self._cube, attr) @@ -45,6 +57,12 @@ def _update_args(self, kwargs): # If cube is 2D it doesn't actually have slices (itself is a slice). # In this case we don't need to convert any arguments, but just # pass them to the underlying cube (which is the slice). + if self.ca_as_0th: + axis = kwargs.get('axis', False) + if axis is None: + # TODO: Write detailed explanation here in comments. + # Special case for CA slices (in multitables). + kwargs['axis'] = 1 return kwargs # Handling API methods that include 'axis' parameter @@ -88,7 +106,7 @@ def _update_args(self, kwargs): return kwargs def _update_result(self, result): - if self.ndim < 3 or len(result) - 1 < self._index: + if (self.ndim < 3 and not self.ca_as_0th) or len(result) - 1 < self._index: return result result = result[self._index] if isinstance(result, tuple): @@ -100,7 +118,7 @@ def _update_result(self, result): def _call_cube_method(self, method, *args, **kwargs): kwargs = self._update_args(kwargs) result = getattr(self._cube, method)(*args, **kwargs) - if method in ('labels', 'inserted_hs_indices'): + if method in ('labels', 'inserted_hs_indices') and not self.ca_as_0th: return result[-2:] return self._update_result(result) @@ -112,7 +130,7 @@ def table_name(self): of the cube name with the label of the corresponding slice (nth label of the 0th dimension). ''' - if self.ndim < 3: + if self.ndim < 3 and not self.ca_as_0th: return None title = self._cube.name @@ -153,7 +171,11 @@ def ca_main_axis(self): def labels(self, hs_dims=None, prune=False): '''Get labels for the cube slice, and perform pruning by slice.''' - labels = self._cube.labels(include_transforms_for_dims=hs_dims)[-2:] + if self.ca_as_0th: + labels = self._cube.labels(include_transforms_for_dims=hs_dims)[1:] + else: + labels = self._cube.labels(include_transforms_for_dims=hs_dims)[-2:] + if not prune: return labels diff --git a/tests/unit/test_crunch_cube.py b/tests/unit/test_crunch_cube.py index 6d2a664b7..5874f80cb 100644 --- a/tests/unit/test_crunch_cube.py +++ b/tests/unit/test_crunch_cube.py @@ -8,7 +8,7 @@ # pylint: disable=invalid-name, no-self-use, protected-access -@patch('cr.cube.crunch_cube.CrunchCube._get_slices', lambda x: None) +@patch('cr.cube.crunch_cube.CrunchCube.get_slices', lambda x: None) class TestCrunchCube(TestCase): '''Test class for the CrunchCube unit tests.