Skip to content

Commit

Permalink
xarray_utils: mask ocean and Antarctica (#219)
Browse files Browse the repository at this point in the history
* xarray_utils: mask ocean and Antarctica

* add CHANGELOG

* rename mask_land to mask_ocean

* add tests for dataarray

* move to top level

* rename obj -> data

* Update docs/source/api.rst

* mask_ocean: mention grid cell center

* fix merge issue
  • Loading branch information
mathause committed Nov 16, 2022
1 parent 8b8ae76 commit fe14ca6
Show file tree
Hide file tree
Showing 5 changed files with 277 additions and 2 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.rst
Expand Up @@ -32,6 +32,9 @@ New Features
- Added functions to stack regular lat-lon grids to 1D grids and unstack them again (`#217
<https://github.com/MESMER-group/mesmer/pull/217>`_). By `Mathias Hauser
<https://github.com/mathause>`_.
- Added functions to mask the ocean and Antarctica for xarray objects (`#219
<https://github.com/MESMER-group/mesmer/pull/219>`_). By `Mathias Hauser
<https://github.com/mathause>`_.
- Added functions to calculate the weighted global mean (`#220
<https://github.com/MESMER-group/mesmer/pull/220>`_). By `Mathias Hauser
<https://github.com/mathause>`_.
Expand Down
3 changes: 3 additions & 0 deletions docs/source/api.rst
Expand Up @@ -46,6 +46,9 @@ Data manipulation
~xarray_utils.grid.unstack_lat_lon_and_align
~xarray_utils.grid.unstack_lat_lon
~xarray_utils.grid.align_to_coords
~xarray_utils.mask.mask_ocean_fraction
~xarray_utils.mask.mask_ocean
~xarray_utils.mask.mask_antarctica
~xarray_utils.global_mean.lat_weights
~xarray_utils.global_mean.weighted_mean

Expand Down
3 changes: 1 addition & 2 deletions mesmer/xarray_utils/__init__.py
@@ -1,4 +1,3 @@
# flake8: noqa


from mesmer.xarray_utils import global_mean, grid
from mesmer.xarray_utils import global_mean, grid, mask
135 changes: 135 additions & 0 deletions mesmer/xarray_utils/mask.py
@@ -0,0 +1,135 @@
import numpy as np
import regionmask
import xarray as xr

import mesmer.utils


def _where_if_dim(obj, cond, dims):

# xarray applies where to all data_vars - even if they do not have the corresponding
# dimensions - we don't want that https://github.com/pydata/xarray/issues/7027

def _where(da):
if all(dim in da.dims for dim in dims):
return da.where(cond)
return da

if isinstance(obj, xr.Dataset):
return obj.map(_where, keep_attrs=True)

return obj.where(cond)


def mask_ocean_fraction(data, threshold, *, x_coords="lon", y_coords="lat"):
"""mask out ocean using fractional overlap
Parameters
----------
data : xr.Dataset | xr.DataArray
Array to mask.
threshold : float
Threshold above which land fraction to consider a grid point as a land grid
point. Must be must be between 0 and 1 inclusive.
x_coords : str, default: "lon"
Name of the x-coordinates.
y_coords : str, default: "lat"
Name of the y-coordinates.
Returns
-------
data : xr.Dataset | xr.DataArray
Array with ocean grid points masked out.
Notes
-----
- Uses the 1:110m land mask from Natural Earth (http://www.naturalearthdata.com).
- The fractional overlap of individual grid points and the land mask can only be
computed for regularly-spaced 1D x- and y-coordinates. For irregularly spaced
coordinates use :py:func:`mesmer.xarray_utils.mask_land`.
"""

if np.ndim(threshold) != 0 or (threshold < 0) or (threshold > 1):
raise ValueError("`threshold` must be a scalar between 0 and 1 (inclusive).")

# TODO: allow other masks?
land_110 = regionmask.defined_regions.natural_earth_v5_0_0.land_110

try:
mask_fraction = mesmer.utils.regionmaskcompat.mask_3D_frac_approx(
land_110, data[x_coords], data[y_coords]
)
except mesmer.utils.regionmaskcompat.InvalidCoordsError as e:
raise ValueError(
"Cannot calculate fractional mask for irregularly-spaced coords - use "
"``mask_land`` instead."
) from e

# drop region-specific coords
mask_fraction = mask_fraction.squeeze(drop=True)

mask_bool = mask_fraction > threshold

# only mask data_vars that have the coords
return _where_if_dim(data, mask_bool, [y_coords, x_coords])


def mask_ocean(data, *, x_coords="lon", y_coords="lat"):
"""mask out ocean
Parameters
----------
data : xr.Dataset | xr.DataArray
Array to mask.
x_coords : str, default: "lon"
Name of the x-coordinates.
y_coords : str, default: "lat"
Name of the y-coordinates.
Returns
-------
data : xr.Dataset | xr.DataArray
Array with ocean grid points masked out.
Notes
-----
- Uses the 1:110m land mask from Natural Earth (http://www.naturalearthdata.com).
- Whether a grid cell is in the ocean or on land is based on its center. For
regularly spaced coordinates use :py:func:`mesmer.xarray_utils.mask_land_fraction`.
"""

# TODO: allow other masks?
land_110 = regionmask.defined_regions.natural_earth_v5_0_0.land_110

mask_bool = land_110.mask_3D(data[x_coords], data[y_coords])

mask_bool = mask_bool.squeeze(drop=True)

# only mask data_vars that have the coords
return _where_if_dim(data, mask_bool, [y_coords, x_coords])


def mask_antarctica(data, *, y_coords="lat"):
"""mask out ocean
Parameters
----------
data : xr.Dataset | xr.DataArray
Array to mask.
y_coords : str, default: "lat"
Name of the y-coordinates.
Returns
-------
data : xr.Dataset | xr.DataArray
Array with Antarctic grid points masked out.
Notes
-----
- Masks grid points below 60°S.
"""

mask_bool = data[y_coords] >= -60

# only mask if data has y_coords
return _where_if_dim(data, mask_bool, [y_coords])
135 changes: 135 additions & 0 deletions tests/unit/test_mask.py
@@ -0,0 +1,135 @@
import numpy as np
import pytest
import xarray as xr

import mesmer.xarray_utils as mxu


def data_lon_lat(as_dataset, x_coords="lon", y_coords="lat"):

lon = np.arange(0.5, 360, 2)
lat = np.arange(90, -91, -2)

data = np.random.randn(*lat.shape, *lon.shape)

da = xr.DataArray(
data,
dims=(y_coords, x_coords),
coords={x_coords: lon, y_coords: lat},
attrs={"key": "da_attrs"},
)

ds = xr.Dataset(data_vars={"data": da, "scalar": 1}, attrs={"key": "ds_attrs"})

if as_dataset:
return ds
return ds.data


@pytest.mark.parametrize("threshold", ([0, 1], -0.1, 1.1))
def test_ocean_land_fraction_errors(threshold):

data = data_lon_lat(as_dataset=True)

with pytest.raises(
ValueError, match="`threshold` must be a scalar between 0 and 1"
):
mxu.mask.mask_ocean_fraction(data, threshold=threshold)


def test_ocean_land_fraction_irregular():

lon = [0, 1, 3]
lat = [0, 1, 3]

data = xr.Dataset(coords={"lon": lon, "lat": lat})

with pytest.raises(
ValueError,
match="Cannot calculate fractional mask for irregularly-spaced coords",
):
mxu.mask.mask_ocean_fraction(data, threshold=0.5)


def test_ocean_land_fraction_threshold():
# check that the threshold has an influence
data = data_lon_lat(as_dataset=True)

result_033 = mxu.mask.mask_ocean_fraction(data, threshold=0.33)
result_066 = mxu.mask.mask_ocean_fraction(data, threshold=0.66)

assert not (result_033.data == result_066.data).all()


def _test_mask(func, as_dataset, threshold=None, **kwargs):
# not checking the actual mask

data = data_lon_lat(as_dataset=as_dataset, **kwargs)

kwargs = kwargs if threshold is None else {"threshold": threshold, **kwargs}
result = func(data, **kwargs)

if as_dataset:
# ensure scalar is not broadcast
assert result.scalar.ndim == 0
assert result.attrs == {"key": "ds_attrs"}

result_da = result.data
else:
result_da = result

# ensure mask is applied
assert result_da.isnull().any()

assert result_da.attrs == {"key": "da_attrs"}


@pytest.mark.parametrize("as_dataset", (True, False))
def test_ocean_land_fraction_default(as_dataset):

_test_mask(mxu.mask.mask_ocean_fraction, as_dataset, threshold=0.5)


@pytest.mark.parametrize("as_dataset", (True, False))
@pytest.mark.parametrize("x_coords", ("x", "lon"))
@pytest.mark.parametrize("y_coords", ("y", "lat"))
def test_ocean_land_fraction(as_dataset, x_coords, y_coords):

_test_mask(
mxu.mask.mask_ocean_fraction,
as_dataset,
threshold=0.5,
x_coords=x_coords,
y_coords=y_coords,
)


@pytest.mark.parametrize("as_dataset", (True, False))
def test_ocean_land_default(
as_dataset,
):

_test_mask(mxu.mask.mask_ocean, as_dataset)


@pytest.mark.parametrize("as_dataset", (True, False))
@pytest.mark.parametrize("x_coords", ("x", "lon"))
@pytest.mark.parametrize("y_coords", ("y", "lat"))
def test_mask_land(as_dataset, x_coords, y_coords):

_test_mask(mxu.mask.mask_ocean, as_dataset, x_coords=x_coords, y_coords=y_coords)


@pytest.mark.parametrize("as_dataset", (True, False))
def test_mask_antarctiva_default(
as_dataset,
):

_test_mask(mxu.mask.mask_antarctica, as_dataset)


@pytest.mark.parametrize("as_dataset", (True, False))
@pytest.mark.parametrize("y_coords", ("y", "lat"))
def test_mask_antarctiva(as_dataset, y_coords):

_test_mask(mxu.mask.mask_antarctica, as_dataset, y_coords=y_coords)

0 comments on commit fe14ca6

Please sign in to comment.