Skip to content

Commit

Permalink
Fix tests for all single-element cases
Browse files Browse the repository at this point in the history
  • Loading branch information
slobodan-ilic committed Jan 2, 2019
1 parent 3414936 commit 0373b3f
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 6 deletions.
15 changes: 13 additions & 2 deletions src/cr/cube/cube_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,12 +412,19 @@ def _prepare_index_baseline(self, axis):
# We need this in order to end up with the right shape of the
# numerator vs denominator.
baseline = self.margin(axis=(1 - axis), include_missing=True)
if len(baseline.shape) <= 1:
# If any dimension gets flattened out, due to having a single
# element, re-inflate it
baseline = baseline[None, :]
slice_ = [slice(None)]
total_axis = None
if isinstance(self.mr_dim_ind, tuple):
if self.get_shape()[0] == 1:
if self.get_shape()[0] == 1 and axis == 0:
total_axis = axis
slice_ = [0]
elif self.get_shape()[0] == 1 and axis == 1:
total_axis = 1
slice_ = [slice(None), 0]
else:
total_axis = axis + 1
slice_ += [slice(None), 0] if axis == 1 else [0]
Expand All @@ -427,7 +434,11 @@ def _prepare_index_baseline(self, axis):
if self.mr_dim_ind == 0 and axis != 0 else
[slice(None), 0]
)
total_axis = axis if self.mr_dim_ind != 0 else 1 - axis
total_axis = (
axis
if self.mr_dim_ind != 0 else
1 - axis
)

total = np.sum(baseline, axis=total_axis)
baseline = baseline[slice_]
Expand Down
13 changes: 9 additions & 4 deletions tests/integration/test_index_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ def test_mr_x_mr_index_tables_parity_with_whaam_and_r():
actual = cat_x_mr.slices[0].index_table(axis=1)
np.testing.assert_almost_equal(actual, expected)


def test_mr_x_3vl_index_tables_parity_with_nssat():
mr_x_3vl = CrunchCube(CR.NSSAT_MR_X_3vl)
# Test column direction
Expand All @@ -163,12 +164,14 @@ def test_mr_x_3vl_index_tables_parity_with_nssat():
actual = mr_x_3vl.slices[0].index_table(axis=1)
np.testing.assert_almost_equal(actual, expected)


def test_mr_x_mr_index_tables_parity_with_nssat():
mr_x_mr = CrunchCube(CR.NSSAT_MR_X_MR)
# Test column direction
expected = np.array([
[114.917891097666, 94.6007480891202, 75.7981149285497,
41.5084915084915, 64.5687645687646, 581.118881118881, np.nan, 0, np.nan],
41.5084915084915, 64.5687645687646, 581.118881118881, np.nan, 0,
np.nan],
[90.0597657183839, 95.9426026719446, 102.497687326549,
84.1945288753799, 261.93853427896, 0, np.nan, 0, np.nan],
[99.4879510762734, 101.567130443518, 101.446145177951,
Expand All @@ -183,18 +186,20 @@ def test_mr_x_mr_index_tables_parity_with_nssat():
[104.349919743178, 85.9011627906977, 68.8276397515528, 37.6913265306122,
58.6309523809524, 527.678571428571, np.nan, 0, np.nan],
[98.1631656082071, 104.575328614762, 111.7202268431, 91.7701863354037,
285.507246376812, 0, np.nan, 0 ,np.nan],
[99.6740889304191, 101.757158356526, 101.635946732516, 107.03419298754,
285.507246376812, 0, np.nan, 0, np.nan],
[99.6740889304191, 101.757158356526, 101.635946732516, 107.03419298754,
86.57876943881, 59.9391480730223, np.nan, 119.878296146045, np.nan],
[np.nan,np.nan,np.nan,np.nan,np.nan,np.nan,np.nan,np.nan,np.nan]
[np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan]
])
actual = mr_x_mr.slices[0].index_table(axis=1)
np.testing.assert_almost_equal(actual, expected)


def test_mr_single_cat_x_mr():
cube_slice = CrunchCube(CR.MR_SINGLE_CAT_X_MR).slices[0]
expected = np.array([[100, 100, np.nan]])
np.testing.assert_array_equal(cube_slice.index_table(axis=0), expected)
np.testing.assert_array_equal(cube_slice.index_table(axis=1), expected)


def test_mr_x_mr_single_cat():
Expand Down
1 change: 1 addition & 0 deletions tests/unit/test_index_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def index_fixture(request):
cc.ndim = 2
cc.mr_dim_ind = None
cc.proportions.return_value = np.array(proportions)
cc.as_array.return_value = np.array(proportions)
cs = CubeSlice(cc, 0)
base = np.array(base)
return cs, axis, base, expected

0 comments on commit 0373b3f

Please sign in to comment.