Skip to content

Commit

Permalink
Add bilinear regridding support (#148)
Browse files Browse the repository at this point in the history
* add bilinear method

* add test

* add test for mesh face regridding

* add test

* add add GridToMesh test

* fix tests

* add tests for MeshToGrid

* tidy test code

* add tests

* fix tests

* fix tests

* add io tests

* fix regridder saving

* update docstrings

* add diagrams

* add tests

* fix tests

* fix test

* add test coverage

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix test

* add to docstring

* further improvements

* address review comments

* fix tests

* Update esmf_regrid/tests/unit/experimental/unstructured_scheme/test__mesh_to_MeshInfo.py

Co-authored-by: Martin Yeo <40734014+trexfeathers@users.noreply.github.com>

* add comments

* edit comment

* address review comments

* address review comment

* address review comment

* shorten docstring link

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Martin Yeo <40734014+trexfeathers@users.noreply.github.com>
  • Loading branch information
3 people committed Mar 15, 2022
1 parent baa1030 commit f97a0bf
Show file tree
Hide file tree
Showing 15 changed files with 915 additions and 128 deletions.
18 changes: 12 additions & 6 deletions esmf_regrid/_esmf_sdo.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ def __init__(
crs=None,
circular=False,
areas=None,
center=False,
):
"""
Create a :class:`GridInfo` object describing the grid.
Expand Down Expand Up @@ -127,6 +128,9 @@ def __init__(
Array describing the areas associated with
each face. If ``None``, then :mod:`ESMF` will use its own
calculated areas.
center : bool, default=False
Describes if the center points of the grid cells are used in regridding
calculations.
"""
self.lons = lons
Expand Down Expand Up @@ -173,6 +177,7 @@ def __init__(
self.crs = crs
self.circular = circular
self.areas = areas
self.center = center
super().__init__(
shape=shape,
index_offset=1,
Expand Down Expand Up @@ -253,13 +258,14 @@ def _make_esmf_sdo(self):
grid_corner_y = grid.get_coords(1, staggerloc=ESMF.StaggerLoc.CORNER)
grid_corner_y[:] = truecornerlats

# Grid center points would be added here, this is not necessary for
# Grid center points are added here, this is not necessary for
# conservative area weighted regridding
# grid.add_coords(staggerloc=ESMF.StaggerLoc.CENTER)
# grid_center_x = grid.get_coords(0, staggerloc=ESMF.StaggerLoc.CENTER)
# grid_center_x[:] = truecenterlons
# grid_center_y = grid.get_coords(1, staggerloc=ESMF.StaggerLoc.CENTER)
# grid_center_y[:] = truecenterlats
if self.center:
grid.add_coords(staggerloc=ESMF.StaggerLoc.CENTER)
grid_center_x = grid.get_coords(0, staggerloc=ESMF.StaggerLoc.CENTER)
grid_center_x[:] = truecenterlons
grid_center_y = grid.get_coords(1, staggerloc=ESMF.StaggerLoc.CENTER)
grid_center_y[:] = truecenterlats

if areas is not None:
grid.add_item(ESMF.GridItem.AREA, staggerloc=ESMF.StaggerLoc.CENTER)
Expand Down
24 changes: 20 additions & 4 deletions esmf_regrid/esmf_regridder.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@
]


def _get_regrid_weights_dict(src_field, tgt_field):
def _get_regrid_weights_dict(src_field, tgt_field, regrid_method):
regridder = ESMF.Regrid(
src_field,
tgt_field,
ignore_degenerate=True,
regrid_method=ESMF.RegridMethod.CONSERVE,
regrid_method=regrid_method,
unmapped_action=ESMF.UnmappedAction.IGNORE,
# Choosing the norm_type DSTAREA allows for mdtol type operations
# to be performed using the weights information later on.
Expand Down Expand Up @@ -52,7 +52,7 @@ def _weights_dict_to_sparse_array(weights, shape, index_offsets):
class Regridder:
"""Regridder for directly interfacing with :mod:`ESMF`."""

def __init__(self, src, tgt, precomputed_weights=None):
def __init__(self, src, tgt, method="conservative", precomputed_weights=None):
"""
Create a regridder from descriptions of horizontal grids/meshes.
Expand All @@ -71,6 +71,10 @@ def __init__(self, src, tgt, precomputed_weights=None):
Describes the target mesh/grid.
Data output by this regridder will be a :class:`numpy.ndarray` whose
shape is compatible with ``tgt``.
method : str
Either "conservative" or "bilinear". Corresponds to the :mod:`ESMF` methods
:attr:`~ESMF.api.constants.RegridMethod.CONSERVE` or
:attr:`~ESMF.api.constants.RegridMethod.BILINEAR` used to calculate weights.
precomputed_weights : :class:`scipy.sparse.spmatrix`, optional
If ``None``, :mod:`ESMF` will be used to
calculate regridding weights. Otherwise, :mod:`ESMF` will be bypassed
Expand All @@ -79,11 +83,23 @@ def __init__(self, src, tgt, precomputed_weights=None):
self.src = src
self.tgt = tgt

if method == "conservative":
esmf_regrid_method = ESMF.RegridMethod.CONSERVE
elif method == "bilinear":
esmf_regrid_method = ESMF.RegridMethod.BILINEAR
else:
raise ValueError(
f"method must be either 'bilinear' or 'conservative', got '{method}'."
)
self.method = method

self.esmf_regrid_version = esmf_regrid.__version__
if precomputed_weights is None:
self.esmf_version = ESMF.__version__
weights_dict = _get_regrid_weights_dict(
src.make_esmf_field(), tgt.make_esmf_field()
src.make_esmf_field(),
tgt.make_esmf_field(),
regrid_method=esmf_regrid_method,
)
self.weight_matrix = _weights_dict_to_sparse_array(
weights_dict,
Expand Down
26 changes: 21 additions & 5 deletions esmf_regrid/experimental/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
VERSION_ESMF = "ESMF_version"
VERSION_INITIAL = "esmf_regrid_version_on_initialise"
MDTOL = "mdtol"
METHOD = "method"


def save_regridder(rg, filename):
Expand Down Expand Up @@ -57,15 +58,19 @@ def save_regridder(rg, filename):
src_cube.add_dim_coord(src_grid[1], 1)

tgt_mesh = rg.mesh
tgt_data = np.zeros(tgt_mesh.face_node_connectivity.indices.shape[0])
tgt_location = rg.location
tgt_mesh_coords = tgt_mesh.to_MeshCoords(tgt_location)
tgt_data = np.zeros(tgt_mesh_coords[0].points.shape[0])
tgt_cube = Cube(tgt_data, var_name=TARGET_NAME, long_name=TARGET_NAME)
for coord in tgt_mesh.to_MeshCoords("face"):
for coord in tgt_mesh_coords:
tgt_cube.add_aux_coord(coord, 0)
elif regridder_type == "MeshToGridESMFRegridder":
src_mesh = rg.mesh
src_data = np.zeros(src_mesh.face_node_connectivity.indices.shape[0])
src_location = rg.location
src_mesh_coords = src_mesh.to_MeshCoords(src_location)
src_data = np.zeros(src_mesh_coords[0].points.shape[0])
src_cube = Cube(src_data, var_name=SOURCE_NAME, long_name=SOURCE_NAME)
for coord in src_mesh.to_MeshCoords("face"):
for coord in src_mesh_coords:
src_cube.add_aux_coord(coord, 0)

tgt_grid = (rg.grid_y, rg.grid_x)
Expand All @@ -81,6 +86,8 @@ def save_regridder(rg, filename):
)
raise TypeError(msg)

method = rg.method

weight_matrix = rg.regridder.weight_matrix
reformatted_weight_matrix = weight_matrix.tocoo()
weight_data = reformatted_weight_matrix.data
Expand All @@ -104,6 +111,7 @@ def save_regridder(rg, filename):
"esmf_regrid_version_on_save": save_version,
"normalization": normalization,
MDTOL: mdtol,
METHOD: method,
}

weights_cube = Cube(weight_data, var_name=WEIGHTS_NAME, long_name=WEIGHTS_NAME)
Expand Down Expand Up @@ -167,6 +175,10 @@ def load_regridder(filename):
assert regridder_type in REGRIDDER_NAME_MAP.keys()
scheme = REGRIDDER_NAME_MAP[regridder_type]

# Determine the regridding method, allowing for files created when
# conservative regridding was the only method.
method = weights_cube.attributes.get(METHOD, "conservative")

# Reconstruct the weight matrix.
weight_data = weights_cube.data
weight_rows = weights_cube.coord(WEIGHTS_ROW_NAME).points
Expand All @@ -179,7 +191,11 @@ def load_regridder(filename):
mdtol = weights_cube.attributes[MDTOL]

regridder = scheme(
src_cube, tgt_cube, mdtol=mdtol, precomputed_weights=weight_matrix
src_cube,
tgt_cube,
mdtol=mdtol,
method=method,
precomputed_weights=weight_matrix,
)

esmf_version = weights_cube.attributes[VERSION_ESMF]
Expand Down
38 changes: 34 additions & 4 deletions esmf_regrid/experimental/unstructured_regrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ def __init__(
node_start_index,
elem_start_index=0,
areas=None,
elem_coords=None,
location="face",
):
"""
Create a :class:`MeshInfo` object describing a UGRID-like mesh.
Expand All @@ -35,7 +37,7 @@ def __init__(
node_coords: :obj:`~numpy.typing.ArrayLike`
An ``Nx2`` array describing the location of the nodes of the mesh.
``node_coords[:,0]`` describes the longitudes in degrees and
``node_coords[:,1]`` describes the latitudes in degrees
``node_coords[:,1]`` describes the latitudes in degrees.
face_node_connectivity: :obj:`~numpy.typing.ArrayLike`
A masked array describing the face node connectivity of the
mesh. The unmasked points of ``face_node_connectivity[i]`` describe
Expand All @@ -54,16 +56,34 @@ def __init__(
areas: :obj:`~numpy.typing.ArrayLike`, optional
Array describing the areas associated with
each face. If ``None``, then :mod:`ESMF` will use its own calculated areas.
elem_coords : :obj:`~numpy.typing.ArrayLike`, optional
An ``Nx2`` array describing the location of the face centers of the mesh.
``elem_coords[:,0]`` describes the longitudes in degrees and
``elem_coords[:,1]`` describes the latitudes in degrees.
location : str, default="face"
Either "face" or "node". Describes the location for data on the mesh.
"""
self.node_coords = node_coords
self.fnc = face_node_connectivity
self.nsi = node_start_index
self.esi = elem_start_index
self.areas = areas
self.elem_coords = elem_coords
if location == "face":
field_kwargs = {"meshloc": ESMF.MeshLoc.ELEMENT}
shape = (len(face_node_connectivity),)
elif location == "node":
field_kwargs = {"meshloc": ESMF.MeshLoc.NODE}
shape = (len(node_coords),)
else:
raise ValueError(
f"The mesh location '{location}' is not supported, only "
f"'face' and 'node' are supported."
)
super().__init__(
shape=(len(face_node_connectivity),),
shape=shape,
index_offset=self.esi,
field_kwargs={"meshloc": ESMF.MeshLoc.ELEMENT},
field_kwargs=field_kwargs,
)

def _as_esmf_info(self):
Expand All @@ -78,6 +98,7 @@ def _as_esmf_info(self):
elemType = self.fnc.count(axis=1)
# Experiments seem to indicate that ESMF is using 0 indexing here
elemConn = self.fnc.compressed() - self.nsi
elemCoord = self.elem_coords
result = (
num_node,
num_elem,
Expand All @@ -88,6 +109,7 @@ def _as_esmf_info(self):
elemType,
elemConn,
self.areas,
elemCoord,
)
return result

Expand All @@ -103,6 +125,7 @@ def _make_esmf_sdo(self):
elemType,
elemConn,
areas,
elemCoord,
) = info
# ESMF can handle other dimensionalities, but we are unlikely
# to make such a use of ESMF
Expand All @@ -111,5 +134,12 @@ def _make_esmf_sdo(self):
)

emesh.add_nodes(num_node, nodeId, nodeCoord, nodeOwner)
emesh.add_elements(num_elem, elemId, elemType, elemConn, element_area=areas)
emesh.add_elements(
num_elem,
elemId,
elemType,
elemConn,
element_area=areas,
element_coords=elemCoord,
)
return emesh

0 comments on commit f97a0bf

Please sign in to comment.