Skip to content
31 changes: 26 additions & 5 deletions lib/iris/analysis/_area_weighted.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,16 +41,22 @@ def __init__(self, src_grid_cube, target_grid_cube, mdtol=1):

.. Note::

Both sourge and target cubes must have an XY grid defined by
Both source and target cubes must have an XY grid defined by
separate X and Y dimensions with dimension coordinates.
All of the XY dimension coordinates must also be bounded, and have
the same cooordinate system.

"""
# Snapshot the state of the cubes to ensure that the regridder is
# impervious to external changes to the original source cubes.
# impervious to external changes to the original cubes.
self._src_grid = snapshot_grid(src_grid_cube)
self._target_grid = snapshot_grid(target_grid_cube)

# Store the x_dim and y_dim of the source cube.
x, y = get_xy_dim_coords(src_grid_cube)
self._src_x_dim = src_grid_cube.coord_dims(x)
self._src_y_dim = src_grid_cube.coord_dims(y)

# Missing data tolerance.
if not (0 <= mdtol <= 1):
msg = "Value for mdtol must be in range 0 - 1, got {}."
Expand All @@ -61,6 +67,10 @@ def __init__(self, src_grid_cube, target_grid_cube, mdtol=1):
# current usage of the experimental regrid function.
self._target_grid_cube_cache = None

self._regrid_info = eregrid._regrid_area_weighted_rectilinear_src_and_grid__prepare(
src_grid_cube, self._target_grid_cube
)

@property
def _target_grid_cube(self):
if self._target_grid_cube_cache is None:
Expand Down Expand Up @@ -92,11 +102,22 @@ def __call__(self, cube):
area-weighted regridding.

"""
if get_xy_dim_coords(cube) != self._src_grid:
if get_xy_dim_coords(cube) != self._src_grid or not (
_xy_data_dims_are_equal(cube, self._src_x_dim, self._src_y_dim)
):
raise ValueError(
"The given cube is not defined on the same "
"source grid as this regridder."
)
return eregrid.regrid_area_weighted_rectilinear_src_and_grid(
cube, self._target_grid_cube, mdtol=self._mdtol
return eregrid._regrid_area_weighted_rectilinear_src_and_grid__perform(
cube, self._regrid_info, mdtol=self._mdtol
)


def _xy_data_dims_are_equal(cube, x_dim, y_dim):
"""
Return whether the data dimensions of the x and y coordinates on the
the cube are equal to the values ``x_dim`` and ``y_dim``, respectively.
"""
x1, y1 = get_xy_dim_coords(cube)
return cube.coord_dims(x1) == x_dim and cube.coord_dims(y1) == y_dim
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,17 @@ def extract_grid(self, cube):

def check_mdtol(self, mdtol=None):
src_grid, target_grid = self.grids()
if mdtol is None:
regridder = AreaWeightedRegridder(src_grid, target_grid)
mdtol = 1
else:
regridder = AreaWeightedRegridder(
src_grid, target_grid, mdtol=mdtol
)
with mock.patch(
"iris.experimental.regrid."
"_regrid_area_weighted_rectilinear_src_and_grid__prepare"
) as prepare:
if mdtol is None:
regridder = AreaWeightedRegridder(src_grid, target_grid)
mdtol = 1
else:
regridder = AreaWeightedRegridder(
src_grid, target_grid, mdtol=mdtol
)

# Make a new cube to regrid with different data so we can
# distinguish between regridding the original src grid
Expand All @@ -58,18 +62,22 @@ def check_mdtol(self, mdtol=None):

with mock.patch(
"iris.experimental.regrid."
"regrid_area_weighted_rectilinear_src_and_grid",
"_regrid_area_weighted_rectilinear_src_and_grid__perform",
return_value=mock.sentinel.result,
) as regrid:
) as perform:
result = regridder(src)

self.assertEqual(regrid.call_count, 1)
_, args, kwargs = regrid.mock_calls[0]

self.assertEqual(args[0], src)
# Prepare:
self.assertEqual(prepare.call_count, 1)
_, args, kwargs = prepare.mock_calls[0]
self.assertEqual(
self.extract_grid(args[1]), self.extract_grid(target_grid)
)

# Perform:
self.assertEqual(perform.call_count, 1)
_, args, kwargs = perform.mock_calls[0]
self.assertEqual(args[0], src)
self.assertEqual(kwargs, {"mdtol": mdtol})
self.assertIs(result, mock.sentinel.result)

Expand Down Expand Up @@ -164,6 +172,38 @@ def test_multiple_src_on_same_grid(self):
self.assertArrayEqual(result1.data, reference1.data)
self.assertArrayEqual(result2.data, reference2.data)

def test_mismatched_data_dims(self):
coord_names = ["latitude", "longitude"]
x = np.linspace(20, 32, 4)
y = np.linspace(10, 22, 4)
src1 = self.cube(x, y)

data = np.arange(len(y) * len(x)).reshape(len(x), len(y))
src2 = Cube(data)
lat = DimCoord(y, "latitude", units="degrees")
lon = DimCoord(x, "longitude", units="degrees")
# Add dim coords in opposite order to self.cube.
src2.add_dim_coord(lat, 1)
src2.add_dim_coord(lon, 0)
for name in coord_names:
# Ensure contiguous bounds exists.
src1.coord(name).guess_bounds()
src2.coord(name).guess_bounds()

target = self.cube(np.linspace(20, 32, 2), np.linspace(10, 22, 2))
for name in coord_names:
# Ensure the bounds of the target cover the same range as the
# source.
target.coord(name).bounds = np.column_stack(
(
src1.coord(name).bounds[[0, 1], [0, 1]],
src1.coord(name).bounds[[2, 3], [0, 1]],
)
)

regridder = AreaWeightedRegridder(src1, target)
self.assertRaises(ValueError, regridder, src2)


if __name__ == "__main__":
tests.main()