In [None]:
import math

import healpix_geo.nested
import numpy as np
import xarray as xr
import xdggs  # noqa: F401
import xesmf

In [None]:
# Path of the GRIB input file; the Zarr output name is derived from it later
# by swapping the ".grib" suffix.
url = "output_ifsnemo_highres.grib"  # local path

In [None]:
def normalize_height_above_ground(ds):
    """Fold the scalar height coordinate of ``ds`` into variable attributes.

    The height is read from the ``heightAboveGround`` coordinate when present,
    otherwise from the ``surface`` coordinate. Both coordinates are then
    dropped and every data variable receives two extra attributes:

    - ``height_above_ground``: the scalar value of the coordinate
    - ``_eopf_attrs``: the coordinate's own attributes, or a fixed
      units/standard_name pair when the ``surface`` coordinate was used

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset carrying a scalar ``heightAboveGround`` or ``surface``
        coordinate. It is not modified.

    Returns
    -------
    xarray.Dataset
        A shallow copy of ``ds`` without the two coordinates and with the
        additional variable attributes.

    Raises
    ------
    KeyError
        If neither ``heightAboveGround`` nor ``surface`` is present.
    """
    if "heightAboveGround" in ds.coords:
        coord = ds["heightAboveGround"].variable
        metadata = coord.attrs
    else:
        coord = ds["surface"].variable
        metadata = {"units": "m", "standard_name": "height_above_ground"}

    height = coord.item()

    # `drop_vars` hands back a dataset that still shares its variable objects
    # with `ds`; take a shallow copy (no data is duplicated) so that the
    # attribute changes below do not leak into the caller's dataset.
    new = ds.drop_vars(["surface", "heightAboveGround"], errors="ignore").copy()
    for var in new.data_vars.values():
        # Assign a fresh dict, with its own copy of the metadata, so that no
        # two variables (nor the source coordinate) alias the same mapping.
        var.attrs = {
            **var.attrs,
            "height_above_ground": height,
            "_eopf_attrs": dict(metadata),
        }
    return new


# Open the same GRIB file once per level filter (0, 2 and 10) and move each
# dataset's height coordinate into variable attributes before merging.
ds0m, ds2m, ds10m = (
    xr.open_dataset(
        url,
        engine="cfgrib",
        filter_by_keys={"level": level},
        decode_timedelta=True,
    ).pipe(normalize_height_above_ground)
    for level in (0, 2, 10)
)

# A HEALPix grid has 12 * nside**2 cells with nside == 2**level. Derive the
# level with exact integer arithmetic: truncating a floating-point log ratio
# can land one level too low when the quotient rounds just below an integer,
# and would silently accept cell counts that are not valid HEALPix sizes.
n_cells = int(ds0m.sizes["values"])
nside = math.isqrt(n_cells // 12)
if nside < 1 or nside & (nside - 1) or 12 * nside * nside != n_cells:
    raise ValueError(
        f"{n_cells} cells is not a valid HEALPix cell count (12 * 4**level)"
    )

grid_metadata = {
    "indexing_scheme": "nested",
    "grid_name": "healpix",
    "level": nside.bit_length() - 1,  # nside == 2**level
}
ds = (
    # compat="override" skips conflict checks and keeps the first dataset's
    # version of any variable present in more than one input.
    xr.merge([ds0m, ds2m, ds10m], compat="override")
    # The position along the "values" dimension is used as the cell id.
    .assign_coords(cell_ids=lambda ds: ds["values"])
    .dggs.decode(grid_metadata)
)
# Shown as the cell output only; `ds` itself is not reassigned here.
ds.compute()

In [None]:
# Source grid: only the latitude/longitude of the input dataset, loaded eagerly.
source_grid = ds[["latitude", "longitude"]].load()

# Target grid: every cell id 0 .. 12 * 4**level - 1 at the same level, with the
# longitude/latitude that healpix_geo returns for them on the WGS84 ellipsoid.
cell_ids = np.arange(12 * 4 ** grid_metadata["level"], dtype=np.uint64)
longitude, latitude = healpix_geo.nested.healpix_to_lonlat(
    cell_ids,
    depth=grid_metadata["level"],
    ellipsoid="WGS84",
)
target_grid = xr.Dataset(
    coords=dict(
        # the grid metadata is attached as attributes of the cell id coordinate
        cell_ids=("cells", cell_ids, grid_metadata),
        longitude=("cells", longitude),
        latitude=("cells", latitude),
    )
)
display(source_grid, target_grid)

In [None]:
%%time
# Build the regridding weights from the source points to the target cells.
# Both grids are passed as location streams (1-D point lists) and values are
# mapped with the "nearest_s2d" method.
# NOTE(review): `periodic=True` presumably has no effect for location-stream
# inputs; confirm against the xESMF documentation.
regridder = xesmf.Regridder(
    source_grid,
    target_grid,
    method="nearest_s2d",
    locstream_in=True,
    locstream_out=True,
    periodic=True,
)
regridder

In [None]:
# Apply the regridder built above to `ds`, keeping attributes, then decode the
# result as a DGGS dataset using the same `grid_metadata`.
regridded = regridder.regrid_dataset(ds, keep_attrs=True).dggs.decode(grid_metadata)
regridded

In [None]:
# Visual sanity check: render the regridded "sp" variable through the xdggs
# accessor's `explore`, with partially transparent cells.
regridded["sp"].dggs.explore(alpha=0.8)

In [None]:
# Output store sits next to the input: same name, ".grib" swapped for ".zarr".
zarr_url = url.removesuffix(".grib") + ".zarr"

In [None]:
# One chunk holds 4**level cells; for a grid of 12 * 4**level cells that is
# 12 equally sized chunks along "cells".
chunk_size = 4 ** regridded.dggs.grid_info.level
(
    regridded
    .chunk(chunks={"cells": chunk_size})
    # mode="w" replaces any existing store at this path
    .to_zarr(zarr_url, mode="w")
)

In [None]:
# Re-open the store that was just written to check the round trip.
# `chunks={}` asks xarray for dask-backed (lazy) arrays using the chunking
# preferred by the backend.
reloaded = xr.open_dataset(zarr_url, engine="zarr", chunks={})
reloaded