Skip to content

Commit

Permalink
support build aux coord with rollaxis
Browse files Browse the repository at this point in the history
  • Loading branch information
bjlittle committed Feb 15, 2024
1 parent 700d2ce commit ba36f80
Showing 1 changed file with 17 additions and 7 deletions.
Expand Up @@ -96,16 +96,24 @@ def _get_per_test_bounds_var(_coord_unused):
)

@classmethod
def _make_array_and_cf_data(cls, dimension_names):
def _make_array_and_cf_data(cls, dimension_names, rollaxis=False):
shape = tuple(cls.dim_names_lens[name] for name in dimension_names)
cf_data = mock.MagicMock(_FillValue=None, spec=[])
cf_data.chunking = mock.MagicMock(return_value=shape)
data = np.arange(np.prod(shape), dtype=float).reshape(shape)
data = np.arange(np.prod(shape), dtype=float)
if rollaxis:
shape = shape[1:] + (shape[0],)
data = data.reshape(shape)
data = np.rollaxis(data, -1)
else:
data = data.reshape(shape)
return data, cf_data

def _make_cf_bounds_var(self, dimension_names):
def _make_cf_bounds_var(self, dimension_names, rollaxis=False):
# Create the bounds cf variable.
bounds, cf_data = self._make_array_and_cf_data(dimension_names)
bounds, cf_data = self._make_array_and_cf_data(
dimension_names, rollaxis=rollaxis
)
bounds *= 1000 # Convert to metres.
cf_bounds_var = mock.Mock(
spec=CFVariable,
Expand All @@ -121,8 +129,10 @@ def _make_cf_bounds_var(self, dimension_names):

return cf_bounds_var

def _check_case(self, dimension_names):
self.cf_bounds_var = self._make_cf_bounds_var(dimension_names=dimension_names)
def _check_case(self, dimension_names, rollaxis=False):
self.cf_bounds_var = self._make_cf_bounds_var(
dimension_names, rollaxis=rollaxis
)

# Asserts must lie within context manager because of deferred loading.
build_auxiliary_coordinate(self.engine, self.cf_coord_var)
Expand All @@ -140,7 +150,7 @@ def test_fastest_varying_vertex_dim__normalise_bounds(self):

def test_slowest_varying_vertex_dim__normalise_bounds(self):
# Bounds in the first (slowest varying) dimension.
self._check_case(dimension_names=("nv", "foo", "bar"))
self._check_case(dimension_names=("nv", "foo", "bar"), rollaxis=True)

def test_fastest_with_different_dim_names__normalise_bounds(self):
# Despite the dimension names ('x', and 'y') differing from the coord's
Expand Down

0 comments on commit ba36f80

Please sign in to comment.