Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
xarray_utils: mask ocean and Antarctica (#219)
* xarray_utils: mask ocean and Antarctica * add CHANGELOG * rename mask_land to mask_ocean * add tests for dataarray * move to top level * rename obj -> data * Update docs/source/api.rst * mask_ocean: mention grid cell center * fix merge issue
- Loading branch information
Showing
5 changed files
with
277 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,3 @@ | ||
# flake8: noqa | ||
|
||
|
||
from mesmer.xarray_utils import global_mean, grid | ||
from mesmer.xarray_utils import global_mean, grid, mask |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,135 @@ | ||
import numpy as np | ||
import regionmask | ||
import xarray as xr | ||
|
||
import mesmer.utils | ||
|
||
|
||
def _where_if_dim(obj, cond, dims): | ||
|
||
# xarray applies where to all data_vars - even if they do not have the corresponding | ||
# dimensions - we don't want that https://github.com/pydata/xarray/issues/7027 | ||
|
||
def _where(da): | ||
if all(dim in da.dims for dim in dims): | ||
return da.where(cond) | ||
return da | ||
|
||
if isinstance(obj, xr.Dataset): | ||
return obj.map(_where, keep_attrs=True) | ||
|
||
return obj.where(cond) | ||
|
||
|
||
def mask_ocean_fraction(data, threshold, *, x_coords="lon", y_coords="lat"): | ||
"""mask out ocean using fractional overlap | ||
Parameters | ||
---------- | ||
data : xr.Dataset | xr.DataArray | ||
Array to mask. | ||
threshold : float | ||
Threshold above which land fraction to consider a grid point as a land grid | ||
point. Must be must be between 0 and 1 inclusive. | ||
x_coords : str, default: "lon" | ||
Name of the x-coordinates. | ||
y_coords : str, default: "lat" | ||
Name of the y-coordinates. | ||
Returns | ||
------- | ||
data : xr.Dataset | xr.DataArray | ||
Array with ocean grid points masked out. | ||
Notes | ||
----- | ||
- Uses the 1:110m land mask from Natural Earth (http://www.naturalearthdata.com). | ||
- The fractional overlap of individual grid points and the land mask can only be | ||
computed for regularly-spaced 1D x- and y-coordinates. For irregularly spaced | ||
coordinates use :py:func:`mesmer.xarray_utils.mask_land`. | ||
""" | ||
|
||
if np.ndim(threshold) != 0 or (threshold < 0) or (threshold > 1): | ||
raise ValueError("`threshold` must be a scalar between 0 and 1 (inclusive).") | ||
|
||
# TODO: allow other masks? | ||
land_110 = regionmask.defined_regions.natural_earth_v5_0_0.land_110 | ||
|
||
try: | ||
mask_fraction = mesmer.utils.regionmaskcompat.mask_3D_frac_approx( | ||
land_110, data[x_coords], data[y_coords] | ||
) | ||
except mesmer.utils.regionmaskcompat.InvalidCoordsError as e: | ||
raise ValueError( | ||
"Cannot calculate fractional mask for irregularly-spaced coords - use " | ||
"``mask_land`` instead." | ||
) from e | ||
|
||
# drop region-specific coords | ||
mask_fraction = mask_fraction.squeeze(drop=True) | ||
|
||
mask_bool = mask_fraction > threshold | ||
|
||
# only mask data_vars that have the coords | ||
return _where_if_dim(data, mask_bool, [y_coords, x_coords]) | ||
|
||
|
||
def mask_ocean(data, *, x_coords="lon", y_coords="lat"): | ||
"""mask out ocean | ||
Parameters | ||
---------- | ||
data : xr.Dataset | xr.DataArray | ||
Array to mask. | ||
x_coords : str, default: "lon" | ||
Name of the x-coordinates. | ||
y_coords : str, default: "lat" | ||
Name of the y-coordinates. | ||
Returns | ||
------- | ||
data : xr.Dataset | xr.DataArray | ||
Array with ocean grid points masked out. | ||
Notes | ||
----- | ||
- Uses the 1:110m land mask from Natural Earth (http://www.naturalearthdata.com). | ||
- Whether a grid cell is in the ocean or on land is based on its center. For | ||
regularly spaced coordinates use :py:func:`mesmer.xarray_utils.mask_land_fraction`. | ||
""" | ||
|
||
# TODO: allow other masks? | ||
land_110 = regionmask.defined_regions.natural_earth_v5_0_0.land_110 | ||
|
||
mask_bool = land_110.mask_3D(data[x_coords], data[y_coords]) | ||
|
||
mask_bool = mask_bool.squeeze(drop=True) | ||
|
||
# only mask data_vars that have the coords | ||
return _where_if_dim(data, mask_bool, [y_coords, x_coords]) | ||
|
||
|
||
def mask_antarctica(data, *, y_coords="lat"): | ||
"""mask out ocean | ||
Parameters | ||
---------- | ||
data : xr.Dataset | xr.DataArray | ||
Array to mask. | ||
y_coords : str, default: "lat" | ||
Name of the y-coordinates. | ||
Returns | ||
------- | ||
data : xr.Dataset | xr.DataArray | ||
Array with Antarctic grid points masked out. | ||
Notes | ||
----- | ||
- Masks grid points below 60°S. | ||
""" | ||
|
||
mask_bool = data[y_coords] >= -60 | ||
|
||
# only mask if data has y_coords | ||
return _where_if_dim(data, mask_bool, [y_coords]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,135 @@ | ||
import numpy as np | ||
import pytest | ||
import xarray as xr | ||
|
||
import mesmer.xarray_utils as mxu | ||
|
||
|
||
def data_lon_lat(as_dataset, x_coords="lon", y_coords="lat"): | ||
|
||
lon = np.arange(0.5, 360, 2) | ||
lat = np.arange(90, -91, -2) | ||
|
||
data = np.random.randn(*lat.shape, *lon.shape) | ||
|
||
da = xr.DataArray( | ||
data, | ||
dims=(y_coords, x_coords), | ||
coords={x_coords: lon, y_coords: lat}, | ||
attrs={"key": "da_attrs"}, | ||
) | ||
|
||
ds = xr.Dataset(data_vars={"data": da, "scalar": 1}, attrs={"key": "ds_attrs"}) | ||
|
||
if as_dataset: | ||
return ds | ||
return ds.data | ||
|
||
|
||
@pytest.mark.parametrize("threshold", ([0, 1], -0.1, 1.1)) | ||
def test_ocean_land_fraction_errors(threshold): | ||
|
||
data = data_lon_lat(as_dataset=True) | ||
|
||
with pytest.raises( | ||
ValueError, match="`threshold` must be a scalar between 0 and 1" | ||
): | ||
mxu.mask.mask_ocean_fraction(data, threshold=threshold) | ||
|
||
|
||
def test_ocean_land_fraction_irregular(): | ||
|
||
lon = [0, 1, 3] | ||
lat = [0, 1, 3] | ||
|
||
data = xr.Dataset(coords={"lon": lon, "lat": lat}) | ||
|
||
with pytest.raises( | ||
ValueError, | ||
match="Cannot calculate fractional mask for irregularly-spaced coords", | ||
): | ||
mxu.mask.mask_ocean_fraction(data, threshold=0.5) | ||
|
||
|
||
def test_ocean_land_fraction_threshold(): | ||
# check that the threshold has an influence | ||
data = data_lon_lat(as_dataset=True) | ||
|
||
result_033 = mxu.mask.mask_ocean_fraction(data, threshold=0.33) | ||
result_066 = mxu.mask.mask_ocean_fraction(data, threshold=0.66) | ||
|
||
assert not (result_033.data == result_066.data).all() | ||
|
||
|
||
def _test_mask(func, as_dataset, threshold=None, **kwargs): | ||
# not checking the actual mask | ||
|
||
data = data_lon_lat(as_dataset=as_dataset, **kwargs) | ||
|
||
kwargs = kwargs if threshold is None else {"threshold": threshold, **kwargs} | ||
result = func(data, **kwargs) | ||
|
||
if as_dataset: | ||
# ensure scalar is not broadcast | ||
assert result.scalar.ndim == 0 | ||
assert result.attrs == {"key": "ds_attrs"} | ||
|
||
result_da = result.data | ||
else: | ||
result_da = result | ||
|
||
# ensure mask is applied | ||
assert result_da.isnull().any() | ||
|
||
assert result_da.attrs == {"key": "da_attrs"} | ||
|
||
|
||
@pytest.mark.parametrize("as_dataset", (True, False)) | ||
def test_ocean_land_fraction_default(as_dataset): | ||
|
||
_test_mask(mxu.mask.mask_ocean_fraction, as_dataset, threshold=0.5) | ||
|
||
|
||
@pytest.mark.parametrize("as_dataset", (True, False)) | ||
@pytest.mark.parametrize("x_coords", ("x", "lon")) | ||
@pytest.mark.parametrize("y_coords", ("y", "lat")) | ||
def test_ocean_land_fraction(as_dataset, x_coords, y_coords): | ||
|
||
_test_mask( | ||
mxu.mask.mask_ocean_fraction, | ||
as_dataset, | ||
threshold=0.5, | ||
x_coords=x_coords, | ||
y_coords=y_coords, | ||
) | ||
|
||
|
||
@pytest.mark.parametrize("as_dataset", (True, False)) | ||
def test_ocean_land_default( | ||
as_dataset, | ||
): | ||
|
||
_test_mask(mxu.mask.mask_ocean, as_dataset) | ||
|
||
|
||
@pytest.mark.parametrize("as_dataset", (True, False)) | ||
@pytest.mark.parametrize("x_coords", ("x", "lon")) | ||
@pytest.mark.parametrize("y_coords", ("y", "lat")) | ||
def test_mask_land(as_dataset, x_coords, y_coords): | ||
|
||
_test_mask(mxu.mask.mask_ocean, as_dataset, x_coords=x_coords, y_coords=y_coords) | ||
|
||
|
||
@pytest.mark.parametrize("as_dataset", (True, False)) | ||
def test_mask_antarctiva_default( | ||
as_dataset, | ||
): | ||
|
||
_test_mask(mxu.mask.mask_antarctica, as_dataset) | ||
|
||
|
||
@pytest.mark.parametrize("as_dataset", (True, False)) | ||
@pytest.mark.parametrize("y_coords", ("y", "lat")) | ||
def test_mask_antarctiva(as_dataset, y_coords): | ||
|
||
_test_mask(mxu.mask.mask_antarctica, as_dataset, y_coords=y_coords) |