Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create bounds for RotatedPole, Creep fill #174

Merged
merged 4 commits into from Mar 29, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
9 changes: 6 additions & 3 deletions HISTORY.rst
Expand Up @@ -12,9 +12,12 @@ Announcements

New features and enhancements
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
* New 'cos-lat' averaging in `spatial_mean` (:issue:`94`, :pull:`125`).
* Support for computing anomalies in `compute_deltas` (:pull:`165`).
* Add function `diagnostics.measures_improvement_2d`. (:pull:`167`).
* New 'cos-lat' averaging in ``spatial_mean`` (:issue:`94`, :pull:`125`).
* Support for computing anomalies in ``compute_deltas`` (:pull:`165`).
* Add function ``diagnostics.measures_improvement_2d``. (:pull:`167`).
* Add function ``regrid.create_bounds_rotated_pole`` and automatic use in ``regrid_dataset`` and ``spatial_mean``. This is temporary, while we wait for a functionning method in ``cf_xarray``. (:pull:`174`, :issue:`96`).
* Add ``spatial`` submodule with functions ``creep_weights`` and ``creep_fill`` for filling NaNs using neighbours. (:pull:`174`).
* Allow passing ``GeoDataFrame`` instances in ``spatial_mean``'s ``region`` argument, not only geospatial file paths. (:pull:`174`).

Breaking changes
^^^^^^^^^^^^^^^^
Expand Down
1 change: 1 addition & 0 deletions xscen/__init__.py
Expand Up @@ -15,6 +15,7 @@
reduce,
regrid,
scripting,
spatial,
utils,
)

Expand Down
20 changes: 18 additions & 2 deletions xscen/aggregate.py
Expand Up @@ -450,7 +450,10 @@ def spatial_mean(
}

elif region["method"] == "shape":
s = gpd.read_file(region["shape"]["shape"])
if not isinstance(region["shape"]["shape"], gpd.GeoDataFrame):
s = gpd.read_file(region["shape"]["shape"])
else:
s = region["shape"]["shape"]
if len(s != 1):
raise ValueError(
"Only a single polygon should be used with interp_centroid."
Expand Down Expand Up @@ -499,7 +502,10 @@ def spatial_mean(

# If the region is a shapefile, open with geopandas
elif region["method"] == "shape":
polygon = gpd.read_file(region["shape"]["shape"])
if not isinstance(region["shape"]["shape"], gpd.GeoDataFrame):
polygon = gpd.read_file(region["shape"]["shape"])
else:
polygon = region["shape"]["shape"]

# Simplify the geometries to a given tolerance, if needed.
# The simpler the polygons, the faster the averaging, but it will lose some precision.
Expand All @@ -519,6 +525,16 @@ def spatial_mean(

kwargs_copy = deepcopy(kwargs)
skipna = kwargs_copy.pop("skipna", False)

if (
ds.cf["longitude"].ndim == 2
and "longitude" not in ds.cf.bounds
and "rotated_pole" in ds
):
from .regrid import create_bounds_rotated_pole

ds = ds.update(create_bounds_rotated_pole(ds))

savg = xe.SpatialAverager(ds, polygon.geometry, **kwargs_copy)
ds_agg = savg(ds, keep_attrs=True, skipna=skipna)
extra_coords = {
Expand Down
62 changes: 62 additions & 0 deletions xscen/regrid.py
Expand Up @@ -7,6 +7,8 @@
from pathlib import PosixPath
from typing import Optional, Union

import cartopy.crs as ccrs
import cf_xarray as cfxr
import numpy as np
import xarray as xr
import xesmf as xe
Expand Down Expand Up @@ -310,6 +312,20 @@ def _regridder(
xe.frontend.Regridder
Regridder object
"""
if method.startswith("conservative"):
if (
ds_in.cf["longitude"].ndim == 2
and "longitude" not in ds_in.cf.bounds
and "rotated_pole" in ds_in
):
ds_in = ds_in.update(create_bounds_rotated_pole(ds_in))
if (
ds_grid.cf["longitude"].ndim == 2
and "longitude" not in ds_grid.cf.bounds
and "rotated_pole" in ds_grid
):
ds_grid = ds_grid.update(create_bounds_rotated_pole(ds_grid))

regridder = xe.Regridder(
ds_in=ds_in,
ds_out=ds_grid,
Expand All @@ -321,3 +337,49 @@ def _regridder(
regridder.to_netcdf(filename)

return regridder


def create_bounds_rotated_pole(ds):
"""Create bounds for rotated pole datasets."""
ds = ds.cf.add_bounds(["rlat", "rlon"])

# In "vertices" format then expand to 2D. From (N, 2) to (N+1,) to (N+1, M+1)
rlatv1D = cfxr.bounds_to_vertices(ds.rlat_bounds, "bounds")
rlonv1D = cfxr.bounds_to_vertices(ds.rlon_bounds, "bounds")
rlatv = rlatv1D.expand_dims(rlon_vertices=rlonv1D).transpose(
"rlon_vertices", "rlat_vertices"
)
rlonv = rlonv1D.expand_dims(rlat_vertices=rlatv1D).transpose(
"rlon_vertices", "rlat_vertices"
)

# Get cartopy's crs for the projection
RP = ccrs.RotatedPole(
pole_longitude=ds.rotated_pole.grid_north_pole_longitude,
pole_latitude=ds.rotated_pole.grid_north_pole_latitude,
central_rotated_longitude=ds.rotated_pole.north_pole_grid_longitude,
)
PC = ccrs.PlateCarree()

# Project points
pts = PC.transform_points(RP, rlonv.values, rlatv.values)
lonv = rlonv.copy(data=pts[..., 0]).rename("lon_vertices")
latv = rlatv.copy(data=pts[..., 1]).rename("lat_vertices")

# Back to CF bounds format. From (N+1, M+1) to (4, N, M)
lonb = cfxr.vertices_to_bounds(lonv, ("bounds", "rlon", "rlat")).rename(
"lon_bounds"
)
latb = cfxr.vertices_to_bounds(latv, ("bounds", "rlon", "rlat")).rename(
"lat_bounds"
)

# Create dataset, set coords and attrs
ds_bnds = xr.merge([lonb, latb]).assign(
lon=ds.lon, lat=ds.lat, rotated_pole=ds.rotated_pole
)
ds_bnds["rlat"] = ds.rlat
ds_bnds["rlon"] = ds.rlon
ds_bnds.lat.attrs["bounds"] = "lat_bounds"
ds_bnds.lon.attrs["bounds"] = "lon_bounds"
return ds_bnds.transpose(*ds.lon.dims, "bounds")
105 changes: 105 additions & 0 deletions xscen/spatial.py
@@ -0,0 +1,105 @@
"""Spatial tools."""
import itertools

import numpy as np
import sparse as sp
import xarray as xr


def creep_weights(mask, n=1, mode="clip"):
"""Compute weights for the creep fill.

The output is a sparse matrix with the same dimensions as `mask`, twice.

Parameters
----------
mask : DataArray
A boolean DataArray. False values are candidates to the filling.
Usually they represent missing values (`mask = da.notnull()`).
All dimensions are creep filled.
n : int
The order of neighbouring to use. 1 means only the adjacent grid cells are used.
mode : {'clip', 'wrap'}
If a cell is on the edge of the domain, `mode='wrap'` will wrap around to find neighbours.

Returns
-------
DataArray
Weights. The dot product must be taken over the last N dimensions.
"""
da = mask
mask = da.values
neighbors = np.array(
list(itertools.product(*[np.arange(-n, n + 1) for j in range(mask.ndim)]))
).T
src = []
dst = []
w = []
it = np.nditer(mask, flags=["f_index", "multi_index"], order="C")
for i in it:
if not i:
neigh_idx_2d = np.atleast_2d(it.multi_index).T + neighbors
neigh_idx_1d = np.ravel_multi_index(
neigh_idx_2d, mask.shape, order="C", mode=mode
)
neigh_idx = np.unravel_index(np.unique(neigh_idx_1d), mask.shape, order="C")
neigh = mask[neigh_idx]
N = (neigh).sum()
if N > 0:
src.extend([it.multi_index] * N)
dst.extend(np.stack(neigh_idx)[:, neigh].T)
w.extend([1 / N] * N)
else:
src.extend([it.multi_index])
dst.extend([it.multi_index])
w.extend([np.nan])
else:
src.extend([it.multi_index])
dst.extend([it.multi_index])
w.extend([1])
crds = np.concatenate((np.array(src).T, np.array(dst).T), axis=0)
return xr.DataArray(
sp.COO(crds, w, (*da.shape, *da.shape)),
dims=[f"{d}_out" for d in da.dims] + list(da.dims),
coords=da.coords,
name="creep_fill_weights",
)


def creep_fill(da, w):
"""Creep fill using pre-computed weights.

Parameters
----------
da: DataArray
A DataArray sharing the dimensions with the one used to compute the weights.
It can have other dimensions.
Dask is supported as long as there are no chunks over the creeped dims.
w: DataArray
The result of `creep_weights`.

Returns
-------
xarray.DataArray, same shape as `da`, but values filled according to `w`.

Examples
--------
>>> w = creep_weights(da.isel(time=0).notnull(), n=1)
>>> da_filled = creep_fill(da, w)
"""

def _dot(arr, wei):
N = wei.ndim // 2
extra_dim = arr.ndim - N
return np.tensordot(arr, wei, axes=(np.arange(N) + extra_dim, np.arange(N) + N))

N = w.ndim // 2
return xr.apply_ufunc(
_dot,
da,
w,
input_core_dims=[w.dims[N:], w.dims],
output_core_dims=[w.dims[N:]],
dask="parallelized",
output_dtypes=["float64"],
)