In [None]:
import math

import healpix_geo.nested
import numpy as np
import xarray as xr
import xdggs  # noqa: F401
import xesmf

In [None]:
GRIB_PATH = "output_ifsnemo_highres.grib"


def open_grib_level(level: int) -> xr.Dataset:
    """Open only the messages of ``GRIB_PATH`` whose GRIB ``level`` key equals ``level``.

    Each level is opened as its own dataset via cfgrib's ``filter_by_keys``;
    timedeltas (e.g. forecast steps) are decoded.
    """
    return xr.open_dataset(
        GRIB_PATH,
        engine="cfgrib",
        filter_by_keys={"level": level},
        decode_timedelta=True,
    )


def healpix_level(n_cells: int) -> int:
    """Return the refinement level ``k`` of a full-sphere HEALPix grid of ``n_cells`` cells.

    A HEALPix grid at level ``k`` has exactly ``12 * 4**k`` cells. The level is
    derived with integer arithmetic (``isqrt`` + ``bit_length``) instead of a
    floating-point ``log`` ratio, which ``int()`` could truncate to ``k - 1``.

    Raises
    ------
    ValueError
        If ``n_cells`` is not of the form ``12 * 4**k`` (i.e. not a full HEALPix grid),
        instead of silently returning a wrong level.
    """
    # n_cells // 12 == 4**k == (2**k)**2, so isqrt gives 2**k and bit_length gives k + 1.
    level = math.isqrt(n_cells // 12).bit_length() - 1
    if level < 0 or 12 * 4**level != n_cells:
        raise ValueError(
            f"{n_cells} cells is not a full HEALPix grid (expected 12 * 4**level)"
        )
    return level


# DataTree node name -> GRIB "level" value of the messages stored in that node.
GROUP_LEVELS = {"surface": 0, "h2m": 2, "h10m": 10}
datasets = {name: open_grib_level(level) for name, level in GROUP_LEVELS.items()}
ds0m, ds2m, ds10m = datasets["surface"], datasets["h2m"], datasets["h10m"]

# The position along ``values`` is used directly as the nested HEALPix cell id.
grid_metadata = {
    "indexing_scheme": "nested",
    "grid_name": "healpix",
    "level": healpix_level(ds0m.sizes["values"]),
}


def decode_healpix(ds: xr.Dataset) -> xr.Dataset:
    """Attach ``cell_ids`` (the ``values`` index) and decode ``ds`` with xdggs.

    Nodes without a ``values`` dimension (such as the empty DataTree root) are
    returned unchanged.
    """
    if "values" not in ds.dims:
        return ds
    return ds.assign_coords(cell_ids=ds["values"]).dggs.decode(grid_metadata)


dt = xr.DataTree.from_dict(datasets).map_over_datasets(decode_healpix)
dt

In [None]:
# Source grid: the latitude/longitude of every point in the original GRIB data.
source_grid = ds0m[["latitude", "longitude"]]

# Target grid: all cells of a nested HEALPix grid at the level in ``grid_metadata``,
# with lon/lat positions computed by healpix_geo on the WGS84 ellipsoid.
n_target_cells = 12 * 4 ** grid_metadata["level"]
cell_ids = np.arange(n_target_cells, dtype=np.uint64)
longitude, latitude = healpix_geo.nested.healpix_to_lonlat(
    cell_ids,
    depth=grid_metadata["level"],
    ellipsoid="WGS84",
)
target_grid = xr.Dataset(
    coords=dict(
        cell_ids=("cells", cell_ids, grid_metadata),  # grid_metadata stored as attrs
        longitude=("cells", longitude),
        latitude=("cells", latitude),
    )
)
display(source_grid, target_grid)

In [None]:
# Both grids are passed as "locstreams" (1-D collections of points), and with
# "nearest_s2d" each target cell takes the value of its nearest source point.
regridder_options = {
    "method": "nearest_s2d",
    "locstream_in": True,
    "locstream_out": True,
    "periodic": True,
}
regridder = xesmf.Regridder(source_grid, target_grid, **regridder_options)
regridder

In [None]:
def regrid_to_healpix(ds):
    """Regrid one DataTree node onto ``target_grid`` and decode it with xdggs.

    Nodes that have no ``values`` dimension are returned untouched.
    """
    if "values" in ds.dims:
        regridded_node = regridder.regrid_dataset(ds, keep_attrs=True)
        # NOTE(review): decode() needs a ``cell_ids`` coordinate on the regridded
        # output — confirm xesmf carries it over from ``target_grid``.
        return regridded_node.dggs.decode(grid_metadata)
    return ds


regridded = dt.map_over_datasets(regrid_to_healpix)
regridded

In [None]:
# Interactive map of the regridded ``sp`` variable from the "surface" node
# (GRIB level 0); ``alpha`` presumably sets the opacity of the drawn cells.
sp_regridded = regridded["/surface/sp"]
sp_regridded.dggs.explore(alpha=0.8)