Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix mask_ocean for 2D grids #314

Merged
merged 3 commits into from Sep 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 3 additions & 1 deletion CHANGELOG.rst
Expand Up @@ -60,7 +60,9 @@ New Features
- Added functions to stack regular lat-lon grids to 1D grids and unstack them again (`#217
<https://github.com/MESMER-group/mesmer/pull/217>`_). By `Mathias Hauser`_.
- Added functions to mask the ocean and Antarctica (`#219
<https://github.com/MESMER-group/mesmer/pull/219>`_). By `Mathias Hauser`_.
- Added functions to mask the ocean and Antarctica (
`#219 <https://github.com/MESMER-group/mesmer/pull/219>`_ and
`#314 <https://github.com/MESMER-group/mesmer/pull/314>`_). By `Mathias Hauser`_.
- Added functions to calculate the weighted global mean
(`#220 <https://github.com/MESMER-group/mesmer/pull/220>`_ and
`#287 <https://github.com/MESMER-group/mesmer/pull/287>`_). By `Mathias Hauser`_.
Expand Down
10 changes: 5 additions & 5 deletions mesmer/core/mask.py
Expand Up @@ -5,13 +5,13 @@
import mesmer


def _where_if_dim(obj, cond, dims):
def _where_if_coords(obj, cond, coords):

# xarray applies where to all data_vars - even if they do not have the corresponding
# dimensions - we don't want that https://github.com/pydata/xarray/issues/7027

def _where(da):
if all(dim in da.dims for dim in dims):
if all(coord in da.coords for coord in coords):
return da.where(cond)
return da

Expand Down Expand Up @@ -71,7 +71,7 @@ def mask_ocean_fraction(data, threshold, *, x_coords="lon", y_coords="lat"):
mask_bool = mask_fraction > threshold

# only mask data_vars that have the coords
return _where_if_dim(data, mask_bool, [y_coords, x_coords])
return _where_if_coords(data, mask_bool, [y_coords, x_coords])


def mask_ocean(data, *, x_coords="lon", y_coords="lat"):
Expand Down Expand Up @@ -106,7 +106,7 @@ def mask_ocean(data, *, x_coords="lon", y_coords="lat"):
mask_bool = mask_bool.squeeze(drop=True)

# only mask data_vars that have the coords
return _where_if_dim(data, mask_bool, [y_coords, x_coords])
return _where_if_coords(data, mask_bool, [y_coords, x_coords])


def mask_antarctica(data, *, y_coords="lat"):
Expand All @@ -132,4 +132,4 @@ def mask_antarctica(data, *, y_coords="lat"):
mask_bool = data[y_coords] >= -60

# only mask if data has y_coords
return _where_if_dim(data, mask_bool, [y_coords])
return _where_if_coords(data, mask_bool, [y_coords])
24 changes: 24 additions & 0 deletions tests/unit/test_mask.py
Expand Up @@ -133,3 +133,27 @@ def test_mask_antarctiva_default(
def test_mask_antarctiva(as_dataset, y_coords):

_test_mask(mesmer.mask.mask_antarctica, as_dataset, y_coords=y_coords)


def test_mask_ocean_2D_grid():

lon = lat = np.arange(0, 30)
LON, LAT = np.meshgrid(lon, lat)

dims = ("rlat", "rlon")

data = np.random.randn(*LON.shape)

data_2D_grid = xr.Dataset(
{"data": (dims, data)}, coords={"lon": (dims, LON), "lat": (dims, LAT)}
)

data_1D_grid = xr.Dataset(
{"data": (("lat", "lon"), data)}, coords={"lon": lon, "lat": lat}
)

result = mesmer.mask.mask_ocean(data_2D_grid)
expected = mesmer.mask.mask_ocean(data_1D_grid)

# the Datasets don't have equal coords but their arrays should be the same
np.testing.assert_equal(result.data.values, expected.data.values)