From b324e2340002357d22080b4fdeed56579f95be0b Mon Sep 17 00:00:00 2001 From: Slobodan Ilic Date: Sat, 27 Jan 2018 20:52:58 +0100 Subject: [PATCH] Fix margin edge case for 3D cube with MR --- src/cr/cube/crunch_cube.py | 27 ++++++--------------------- 1 file changed, 6 insertions(+), 21 deletions(-) diff --git a/src/cr/cube/crunch_cube.py b/src/cr/cube/crunch_cube.py index fa0352c32..8b7b61a66 100644 --- a/src/cr/cube/crunch_cube.py +++ b/src/cr/cube/crunch_cube.py @@ -404,24 +404,6 @@ def _double_mr_margin(self, axis, weighted): return (selected + non_selected)[np.ix_(*self.valid_indices)] - ################################################################### - - # TODO: Revisit this portion of the code, if we have problems. However, - # with current '_get_table' implementation, all dimensions should be - # preserved, and there shouldn't be any need for inflating. - - # if axis == 1: - # if len(array.shape) == 1: - # # In case of a flattened array (which happens with MR x CAT - # # (single element)), restore the flattened dimension. - # array = array[:, np.newaxis] - - # # If MR margin is calculated by rows, we only need the counts - # # and that's why we use array and not margin. - # return np.sum(array, axis) - - ################################################################### - def _mr_margin(self, axis, weighted, adjusted): if self.is_double_mr: return self._double_mr_margin(axis, weighted) @@ -438,9 +420,12 @@ def _mr_margin(self, axis, weighted, adjusted): # For cases when margin is calculated along the axis which is not MR, # we need to perform sumation along that axis, on the tabular # representation of the cube (which is obtained with 'as_array'). - if (self.mr_dim_ind in [0, 2] and axis == 1 or - self.mr_dim_ind == 1 and axis == 0): - + calculate_along_non_mr = ( + self.mr_dim_ind in [0, 2] and axis == 1 or + self.mr_dim_ind == 1 and axis == 0 or + self.mr_dim_ind == 1 and axis == 1 and len(self.dimensions) > 2 + ) + if calculate_along_non_mr: array = self.as_array(weighted=weighted) if axis == 1 and len(array.shape) == 1: