diff --git a/src/cr/cube/cube_slice.py b/src/cr/cube/cube_slice.py index 39a2ce378..13f6748d4 100644 --- a/src/cr/cube/cube_slice.py +++ b/src/cr/cube/cube_slice.py @@ -412,12 +412,19 @@ def _prepare_index_baseline(self, axis): # We need this in order to end up with the right shape of the # numerator vs denominator. baseline = self.margin(axis=(1 - axis), include_missing=True) + if len(baseline.shape) <= 1: + # If any dimension gets flattened out, due to having a single + # element, re-inflate it + baseline = baseline[None, :] slice_ = [slice(None)] total_axis = None if isinstance(self.mr_dim_ind, tuple): - if self.get_shape()[0] == 1: + if self.get_shape()[0] == 1 and axis == 0: total_axis = axis slice_ = [0] + elif self.get_shape()[0] == 1 and axis == 1: + total_axis = 1 + slice_ = [slice(None), 0] else: total_axis = axis + 1 slice_ += [slice(None), 0] if axis == 1 else [0] @@ -427,7 +434,11 @@ def _prepare_index_baseline(self, axis): if self.mr_dim_ind == 0 and axis != 0 else [slice(None), 0] ) - total_axis = axis if self.mr_dim_ind != 0 else 1 - axis + total_axis = ( + axis + if self.mr_dim_ind != 0 else + 1 - axis + ) total = np.sum(baseline, axis=total_axis) baseline = baseline[slice_] diff --git a/tests/integration/test_index_table.py b/tests/integration/test_index_table.py index f9c759628..0e4917c66 100644 --- a/tests/integration/test_index_table.py +++ b/tests/integration/test_index_table.py @@ -142,6 +142,7 @@ def test_mr_x_mr_index_tables_parity_with_whaam_and_r(): actual = cat_x_mr.slices[0].index_table(axis=1) np.testing.assert_almost_equal(actual, expected) + def test_mr_x_3vl_index_tables_parity_with_nssat(): mr_x_3vl = CrunchCube(CR.NSSAT_MR_X_3vl) # Test column direction @@ -163,12 +164,14 @@ def test_mr_x_3vl_index_tables_parity_with_nssat(): actual = mr_x_3vl.slices[0].index_table(axis=1) np.testing.assert_almost_equal(actual, expected) + def test_mr_x_mr_index_tables_parity_with_nssat(): mr_x_mr = CrunchCube(CR.NSSAT_MR_X_MR) # Test column direction expected = np.array([ [114.917891097666, 94.6007480891202, 75.7981149285497, - 41.5084915084915, 64.5687645687646, 581.118881118881, np.nan, 0, np.nan], + 41.5084915084915, 64.5687645687646, 581.118881118881, np.nan, 0, + np.nan], [90.0597657183839, 95.9426026719446, 102.497687326549, 84.1945288753799, 261.93853427896, 0, np.nan, 0, np.nan], [99.4879510762734, 101.567130443518, 101.446145177951, @@ -183,18 +186,20 @@ def test_mr_x_mr_index_tables_parity_with_nssat(): [104.349919743178, 85.9011627906977, 68.8276397515528, 37.6913265306122, 58.6309523809524, 527.678571428571, np.nan, 0, np.nan], [98.1631656082071, 104.575328614762, 111.7202268431, 91.7701863354037, - 285.507246376812, 0, np.nan, 0 ,np.nan], - [99.6740889304191, 101.757158356526, 101.635946732516, 107.03419298754, + 285.507246376812, 0, np.nan, 0, np.nan], + [99.6740889304191, 101.757158356526, 101.635946732516, 107.03419298754, 86.57876943881, 59.9391480730223, np.nan, 119.878296146045, np.nan], - [np.nan,np.nan,np.nan,np.nan,np.nan,np.nan,np.nan,np.nan,np.nan] + [np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan] ]) actual = mr_x_mr.slices[0].index_table(axis=1) np.testing.assert_almost_equal(actual, expected) + def test_mr_single_cat_x_mr(): cube_slice = CrunchCube(CR.MR_SINGLE_CAT_X_MR).slices[0] expected = np.array([[100, 100, np.nan]]) np.testing.assert_array_equal(cube_slice.index_table(axis=0), expected) + np.testing.assert_array_equal(cube_slice.index_table(axis=1), expected) def test_mr_x_mr_single_cat(): diff --git a/tests/unit/test_index_table.py b/tests/unit/test_index_table.py index 87c92fb71..c0273a9d7 100644 --- a/tests/unit/test_index_table.py +++ b/tests/unit/test_index_table.py @@ -33,6 +33,7 @@ def index_fixture(request): cc.ndim = 2 cc.mr_dim_ind = None cc.proportions.return_value = np.array(proportions) + cc.as_array.return_value = np.array(proportions) cs = CubeSlice(cc, 0) base = np.array(base) return cs, axis, base, expected