In [None]:
from collections.abc import Iterable

import healpix_geo.nested
import numpy as np
import xarray as xr
import xdggs  # noqa: F401
import xesmf

In [None]:
def normalize_height_above_ground(ds):
    """Replace the vertical-level coordinate of a GRIB dataset with variable attributes.

    The dataset is expected to carry its level either as a ``heightAboveGround``
    coordinate or as a ``surface`` coordinate. Both are dropped, and the single
    level value is recorded on every data variable instead. Presumably this lets
    datasets opened for different levels be merged without conflicting
    coordinates.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset with either a ``heightAboveGround`` coordinate or a ``surface``
        variable. Either one must hold exactly one value (``.item()`` is used).

    Returns
    -------
    xarray.Dataset
        A copy of ``ds`` without the ``surface``/``heightAboveGround`` variables.
        Every data variable gains two attributes:

        * ``height_above_ground``: the scalar level value.
        * ``_eopf_attrs``: the level metadata. For ``heightAboveGround`` this is
          the coordinate's own attrs. For ``surface`` it is a fixed description
          in metres.

        ``ds`` itself is not modified.
    """
    if "heightAboveGround" in ds.coords:
        coord = ds["heightAboveGround"].variable
        # Copy so the metadata stored below is not the coordinate's live attrs dict.
        metadata = dict(coord.attrs)
    else:
        coord = ds["surface"].variable
        # No height metadata is available for surface fields; describe the
        # surface value as a height above ground in metres.
        metadata = {"units": "m", "standard_name": "height_above_ground"}

    height = coord.item()

    # ``drop_vars`` keeps the original Variable objects, so editing their attrs
    # would mutate the caller's dataset. A shallow copy creates new Variables
    # with their own attrs dicts without copying the data.
    new = ds.drop_vars(["surface", "heightAboveGround"], errors="ignore").copy(deep=False)
    for var in new.data_vars.values():
        # Give each variable its own metadata dict. Otherwise, editing one
        # variable's ``_eopf_attrs`` would change all of them.
        var.attrs |= {"height_above_ground": height, "_eopf_attrs": dict(metadata)}
    return new


def open_grib_levels(url: str, levels: Iterable[int], *, grid_metadata):
    """Open several GRIB levels of one file and merge them into a DGGS dataset.

    Parameters
    ----------
    url : str
        Path or URL of the GRIB file, read with the cfgrib engine.
    levels : iterable of int
        Values of the GRIB ``level`` key to read. Each level is opened separately
        through ``filter_by_keys``, then normalized with
        ``normalize_height_above_ground``. These are GRIB vertical levels
        (presumably heights for ``heightAboveGround`` messages), not the HEALPix
        refinement level given in ``grid_metadata``.
    grid_metadata : dict
        Grid description passed to ``.dggs.decode``, for example ``grid_name``,
        ``level`` and ``indexing_scheme``.

    Returns
    -------
    xarray.Dataset
        The merged dataset, decoded as a DGGS grid. Its ``cell_ids`` coordinate
        is taken from ``ds["values"]``. This presumably indexes the GRIB points
        and assumes they are stored in the grid's cell order (TODO confirm).

    Raises
    ------
    ValueError
        If ``levels`` is empty.
    """
    dss = [
        xr.open_dataset(
            url, engine="cfgrib", filter_by_keys={"level": level}, decode_timedelta=True
        ).pipe(normalize_height_above_ground)
        for level in levels
    ]
    if not dss:
        # Merging nothing gives a dataset without "values", so the failure
        # would only show up later as a much less obvious error.
        raise ValueError("at least one GRIB level is required")

    return (
        # compat="override": where levels disagree on a variable, keep the first one.
        xr.merge(dss, compat="override")
        .assign_coords(cell_ids=lambda ds: ds["values"])
        .dggs.decode(grid_metadata)
    )

In [None]:
url = "output_ifsnemo_highres.grib"  # local path

# Grid description of the high-resolution output: nested HEALPix, refinement level 10.
healpix_l10 = {
    "indexing_scheme": "nested",
    "grid_name": "healpix",
    "level": 10,
}

# GRIB levels 0, 2 and 10 (matched through the GRIB ``level`` key).
ds_l10 = open_grib_levels(url, levels=[0, 2, 10], grid_metadata=healpix_l10)
ds_l10

In [None]:
url = "output_ifsnemo_lowres.grib"  # local path

# Grid description of the low-resolution output: nested HEALPix, refinement level 6.
healpix_l6 = {
    "indexing_scheme": "nested",
    "grid_name": "healpix",
    "level": 6,
}

# Same GRIB levels as the high-resolution file.
ds_l6 = open_grib_levels(url, levels=[0, 2, 10], grid_metadata=healpix_l6)
ds_l6

In [None]:
def spherical_to_ellipsoidal(ds):
    """Regrid a HEALPix dataset onto the WGS84-ellipsoidal centres of the same grid.

    Builds a target grid containing every cell at the HEALPix level of ``ds``
    (``12 * 4**level`` nested cell ids). Its cell-centre coordinates are computed
    by ``healpix_geo`` on the WGS84 ellipsoid. The data is then mapped onto this
    grid with xESMF's ``nearest_s2d`` method. Source and target are both passed
    as location streams, so only the ``latitude``/``longitude`` coordinates of
    ``ds`` serve as source positions.

    Parameters
    ----------
    ds : xarray.Dataset
        DGGS-decoded dataset (``ds.dggs.grid_info`` must be available) with
        ``latitude`` and ``longitude`` coordinates.

    Returns
    -------
    xarray.Dataset
        The regridded dataset (attributes kept), decoded with the grid info of ``ds``.
    """
    positions = ds[["latitude", "longitude"]].load()
    info = ds.dggs.grid_info
    depth = info.level

    # Every cell of the grid at this depth, as nested ids.
    all_cells = np.arange(12 * 4**depth, dtype="uint64")
    lon_wgs84, lat_wgs84 = healpix_geo.nested.healpix_to_lonlat(
        all_cells, depth=depth, ellipsoid="WGS84"
    )
    centres = xr.Dataset(
        coords=dict(
            cell_ids=("cells", all_cells),
            longitude=("cells", lon_wgs84),
            latitude=("cells", lat_wgs84),
        )
    )

    nearest = xesmf.Regridder(
        positions,
        centres,
        method="nearest_s2d",
        locstream_in=True,
        locstream_out=True,
        periodic=True,
    )
    regridded = nearest.regrid_dataset(ds, keep_attrs=True)
    return regridded.dggs.decode(info)

In [None]:
%%time
ds_l10_ellipsoid = spherical_to_ellipsoidal(ds_l10)
ds_l6_ellipsoid = spherical_to_ellipsoidal(ds_l6)

display(ds_l10_ellipsoid, ds_l6_ellipsoid)

In [None]:
# One tree node per HEALPix level.
# The level-10 grid has 12 * 4**10 cells, so chunks of 4**10 cells presumably
# give one chunk per base pixel (12 chunks).
# NOTE(review): only the level-10 node drops its ``cell_ids`` index; the reason
# is not visible here. Confirm.
tree_nodes = {
    "10": ds_l10_ellipsoid.drop_indexes("cell_ids").chunk({"cells": 4**10}),
    "6": ds_l6_ellipsoid.chunk(),
}
regridded = xr.DataTree.from_dict(tree_nodes)
regridded

In [None]:
# Output store for the regridded tree, which holds both resolutions (nodes "10" and "6").
# The path is written out explicitly. Deriving it from ``url`` made it depend on
# which loading cell ran last, because both of those cells reassign ``url``.
# NOTE(review): the "lowres" stem is kept so Restart & Run All writes to the same
# path as before. Consider a resolution-neutral name such as "output_ifsnemo.zarr".
zarr_url = "output_ifsnemo_lowres.zarr"

In [None]:
# mode="w" replaces any existing store at this path.
regridded.to_zarr(store=zarr_url, mode="w")

In [None]:
# Read the store back lazily; chunks={} uses the backend's preferred chunking.
reloaded = xr.open_datatree(
    zarr_url,
    engine="zarr",
    chunks={},
)
reloaded