Skip to content

Commit

Permalink
Merge pull request #20 from raphaeldussin/extrapolation_update
Browse files Browse the repository at this point in the history
Extrapolation update
  • Loading branch information
raphaeldussin committed Sep 10, 2020
2 parents f5ce8a5 + 13369cc commit ef214fe
Show file tree
Hide file tree
Showing 4 changed files with 164 additions and 317 deletions.
396 changes: 82 additions & 314 deletions doc/notebooks/Masking.ipynb

Large diffs are not rendered by default.

38 changes: 37 additions & 1 deletion xesmf/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,10 @@ def add_corner(grid, lon_b, lat_b):


def esmf_regrid_build(sourcegrid, destgrid, method,
filename=None, extra_dims=None, ignore_degenerate=None):
filename=None, extra_dims=None,
extrap_method=None, extrap_dist_exponent=None,
extrap_num_src_pnts=None,
ignore_degenerate=None):
'''
Create an ESMF.Regrid object, containing regridding weights.
Expand Down Expand Up @@ -252,6 +255,20 @@ def esmf_regrid_build(sourcegrid, destgrid, method,
For example, if extra_dims=[Nlev, Ntime], then the data field dimension
will be [Nlon, Nlat, Nlev, Ntime]
extrap_method : str, optional
Extrapolation method. Options are
- 'inverse_dist'
- 'nearest_s2d'
extrap_dist_exponent : float, optional
The exponent to raise the distance to when calculating weights for the
extrapolation method. If none are specified, defaults to 2.0
extrap_num_src_pnts : int, optional
The number of source points to use for the extrapolation methods
that use more than one source point. If none are specified, defaults to 8
ignore_degenerate : bool, optional
If False (default), raise error if grids contain degenerated cells
(i.e. triangles or lines, instead of quadrilaterals)
Expand All @@ -276,6 +293,22 @@ def esmf_regrid_build(sourcegrid, destgrid, method,
raise ValueError('method should be chosen from '
'{}'.format(list(method_dict.keys())))

# use shorter, clearer names for options in ESMF.ExtrapMethod
extrap_dict = {'inverse_dist': ESMF.ExtrapMethod.NEAREST_IDAVG,
'nearest_s2d': ESMF.ExtrapMethod.NEAREST_STOD,
None: None
}
try:
esmf_extrap_method = extrap_dict[extrap_method]
except KeyError:
raise KeyError('`extrap_method` should be chosen from '
'{}'.format(list(extrap_dict.keys())))

# until ESMPy updates ESMP_FieldRegridStoreFile, extrapolation is not possible
# if files are written on disk
if (extrap_method is not None) & (filename is not None):
raise ValueError('`extrap_method` cannot be used along with `filename`.')

# conservative regridding needs cell corner information
if method in ['conservative', 'conservative_normed']:
if not sourcegrid.has_corners:
Expand Down Expand Up @@ -320,6 +353,9 @@ def esmf_regrid_build(sourcegrid, destgrid, method,
unmapped_action=ESMF.UnmappedAction.IGNORE,
ignore_degenerate=ignore_degenerate,
norm_type=norm_type,
extrap_method=esmf_extrap_method,
extrap_dist_exponent=extrap_dist_exponent,
extrap_num_src_pnts=extrap_num_src_pnts,
factors=filename is None)
if allow_masked_values:
kwargs.update(dict(src_mask_values=[0], dst_mask_values=[0]))
Expand Down
28 changes: 26 additions & 2 deletions xesmf/frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,8 @@ def ds_to_ESMFlocstream(ds):
class Regridder(object):
def __init__(self, ds_in, ds_out, method, periodic=False,
filename=None, reuse_weights=False,
extrap_method=None, extrap_dist_exponent=None,
extrap_num_src_pnts=None,
weights=None, ignore_degenerate=None,
locstream_in=False, locstream_out=False):
"""
Expand All @@ -122,8 +124,10 @@ def __init__(self, ds_in, ds_out, method, periodic=False,
----------
ds_in, ds_out : xarray DataSet, or dictionary
Contain input and output grid coordinates. Look for variables
``lon``, ``lat``, and optionally ``lon_b``, ``lat_b`` for
conservative methods.
``lon``, ``lat``, optionally ``lon_b``, ``lat_b`` for
conservative methods, and ``mask``. Note that for `mask`,
the ESMF convention is used, where masked values are identified
by 0, and non-masked values by 1.
Shape can be 1D (n_lon,) and (n_lat,) for rectilinear grids,
or 2D (n_y, n_x) for general curvilinear grids.
Expand Down Expand Up @@ -158,6 +162,20 @@ def __init__(self, ds_in, ds_out, method, periodic=False,
Whether to read existing weight file to save computing time.
False by default (i.e. re-compute, not reuse).
extrap_method : str, optional
Extrapolation method. Options are
- 'inverse_dist'
- 'nearest_s2d'
extrap_dist_exponent : float, optional
The exponent to raise the distance to when calculating weights for the
extrapolation method. If none are specified, defaults to 2.0
extrap_num_src_pnts : int, optional
The number of source points to use for the extrapolation methods
that use more than one source point. If none are specified, defaults to 8
weights : None, coo_matrix, dict, str, Dataset, Path,
Regridding weights, stored as
- a scipy.sparse COO matrix,
Expand Down Expand Up @@ -192,6 +210,9 @@ def __init__(self, ds_in, ds_out, method, periodic=False,
self.method = method
self.periodic = periodic
self.reuse_weights = reuse_weights
self.extrap_method = extrap_method
self.extrap_dist_exponent = extrap_dist_exponent
self.extrap_num_src_pnts = extrap_num_src_pnts
self.ignore_degenerate = ignore_degenerate
self.locstream_in = locstream_in
self.locstream_out = locstream_out
Expand Down Expand Up @@ -296,6 +317,9 @@ def _get_default_filename(self):

def _compute_weights(self):
regrid = esmf_regrid_build(self._grid_in, self._grid_out, self.method,
extrap_method = self.extrap_method,
extrap_dist_exponent = self.extrap_dist_exponent,
extrap_num_src_pnts = self.extrap_num_src_pnts,
ignore_degenerate=self.ignore_degenerate)

w = regrid.get_weights_dict(deep_copy=True)
Expand Down
19 changes: 19 additions & 0 deletions xesmf/tests/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,25 @@ def test_esmf_build_bilinear():
esmf_regrid_finalize(regrid)


def test_esmf_extrapolation():

grid_in = esmf_grid(lon_in.T, lat_in.T)
grid_out = esmf_grid(lon_out.T, lat_out.T)

regrid = esmf_regrid_build(grid_in, grid_out, 'bilinear')
data_out_esmpy = esmf_regrid_apply(regrid, data_in.T).T
# without extrapolation, the first and last lines/columns = 0
assert data_out_esmpy[0, 0] == 0

regrid = esmf_regrid_build(grid_in, grid_out, 'bilinear',
extrap_method='inverse_dist',
extrap_num_src_pnts=3,
extrap_dist_exponent=1)
data_out_esmpy = esmf_regrid_apply(regrid, data_in.T).T
# the 3 closest points in data_in are 2.010, 2.005, and 1.992. The result should be roughly equal to 2.0
assert np.round(data_out_esmpy[0, 0], 1) == 2.0


def test_regrid():

# use conservative regridding as an example,
Expand Down

0 comments on commit ef214fe

Please sign in to comment.