In [None]:
from collections.abc import Iterable

import healpix_geo.nested
import numpy as np
import xarray as xr
import xdggs  # noqa: F401
import xesmf

In [None]:
import distributed

# Start a dask.distributed client; with no arguments this launches a local cluster.
# Presumably the chunked (lazy) computations later in the notebook run on its workers.
client = distributed.Client()
client

In [None]:
url = "standard-2020-01.grib"  # local path to the source GRIB file; the Zarr output name is derived from it

In [None]:
def normalize_height_above_ground(ds):
    """Move the scalar vertical-level coordinate of a GRIB dataset into attributes.

    The dataset is expected to carry a single-value ``heightAboveGround``
    coordinate or a ``surface`` variable. That variable is dropped, and every
    data variable instead receives two attributes:

    * ``height_above_ground``: the level value (``coord.item()``)
    * ``_eopf_attrs``: the coordinate's own attributes for ``heightAboveGround``,
      or a synthetic ``{"units": "m", "standard_name": "height_above_ground"}``
      for ``surface``

    Presumably this lets the per-level datasets opened in ``open_grib_levels``
    be merged without conflicting level coordinates.

    Parameters
    ----------
    ds : xr.Dataset
        Dataset with a size-1 ``heightAboveGround`` coordinate or ``surface``
        variable.

    Returns
    -------
    xr.Dataset
        A new dataset without the ``surface`` / ``heightAboveGround`` variables.
        ``ds`` itself is left unmodified.

    Raises
    ------
    KeyError
        If ``ds`` has neither a ``heightAboveGround`` coordinate nor a
        ``surface`` variable.
    """
    if "heightAboveGround" in ds.coords:
        coord = ds["heightAboveGround"].variable
        # Copy so the stored metadata does not alias the coordinate's attrs dict.
        metadata = dict(coord.attrs)
    elif "surface" in ds.variables:
        coord = ds["surface"].variable
        metadata = {"units": "m", "standard_name": "height_above_ground"}
    else:
        raise KeyError(
            "expected a 'heightAboveGround' coordinate or a 'surface' variable"
        )

    level_value = coord.item()
    # drop_vars shares Variable objects with ``ds``. A shallow copy gives every
    # variable its own attrs dict (data is not copied), so the update below
    # does not leak back into the caller's dataset.
    new = ds.drop_vars(["surface", "heightAboveGround"], errors="ignore").copy(
        deep=False
    )
    for var in new.data_vars.values():
        # Each variable gets its own copy of the metadata dict.
        var.attrs |= {
            "height_above_ground": level_value,
            "_eopf_attrs": dict(metadata),
        }
    return new


def open_grib_levels(url: str, levels: Iterable[int], *, grid_metadata, **kwargs):
    """Open several GRIB level selections from one file and merge them into one DGGS dataset.

    Each value in ``levels`` is opened separately with the cfgrib engine, filtered
    on the GRIB ``level`` key. It is then passed through
    ``normalize_height_above_ground``, so the level ends up as per-variable
    attributes instead of a coordinate. Finally the selections are merged.

    Parameters
    ----------
    url : str
        Path of the GRIB file to open.
    levels : iterable of int
        Values of the GRIB ``level`` key to open, e.g. ``[0, 2, 10]``.
    grid_metadata : dict or grid info object
        Grid definition passed to ``ds.dggs.decode``.
    **kwargs
        Forwarded to ``xr.open_dataset`` (e.g. ``chunks``).

    Returns
    -------
    xr.Dataset
        The merged dataset, with a ``cell_ids`` coordinate built from the
        ``values`` dimension and decoded as a DGGS grid.

    Raises
    ------
    ValueError
        If ``levels`` is empty.
    """
    dss = [
        xr.open_dataset(
            url,
            engine="cfgrib",
            filter_by_keys={"level": level},
            decode_timedelta=True,
            **kwargs,
        ).pipe(normalize_height_above_ground)
        for level in levels
    ]
    if not dss:
        # xr.merge([]) yields an empty Dataset, which would only fail later with
        # an obscure missing-"values" error inside assign_coords.
        raise ValueError("at least one GRIB level is required")

    return (
        # compat="override": conflicting non-index variables are not compared;
        # the first dataset's version is kept.
        xr.merge(dss, compat="override")
        # NOTE(review): "values" presumably has no coordinate, so this yields
        # positional indices 0..n-1. That assumes the GRIB points are stored in
        # the HEALPix order given by grid_metadata — confirm.
        .assign_coords(cell_ids=lambda ds: ds["values"])
        .dggs.decode(grid_metadata)
    )


# HEALPix grid definition of the source GRIB data.
healpix_grid = {
    "indexing_scheme": "nested",
    "grid_name": "healpix",
    "level": 7,
}

# GRIB "level" values to open. Given normalize_height_above_ground, 0 is presumably
# the surface and 2 / 10 are heights above ground in m — confirm.
# Chunking: 24 time steps per chunk; -1 keeps the whole "values" dimension in one chunk.
ds = open_grib_levels(
    url,
    grid_metadata=healpix_grid,
    levels=[0, 2, 10],
    chunks={"time": 24, "values": -1},
)
ds

In [None]:
# Source points: latitude/longitude of the GRIB grid, loaded eagerly for xesmf.
source_grid = ds[["latitude", "longitude"]].load()

# Target points: every cell of a full-sphere HEALPix grid (12 * 4**level cells)
# at the same level as the source, located via healpix_to_lonlat.
# NOTE(review): confirm that ellipsoid="WGS84" matches how the GRIB
# latitudes/longitudes were computed, otherwise nearest-neighbour matching may shift.
grid_info = ds.dggs.grid_info
depth = grid_info.level
cell_ids = np.arange(12 * 4**depth, dtype="uint64")
longitude, latitude = healpix_geo.nested.healpix_to_lonlat(
    cell_ids, depth=depth, ellipsoid="WGS84"
)
target_coords = {"cell_ids": cell_ids, "longitude": longitude, "latitude": latitude}
target_grid = xr.Dataset(
    coords={name: ("cells", values) for name, values in target_coords.items()}
)
display(source_grid, target_grid)

In [None]:
%%time
# Nearest-neighbour ("nearest_s2d") regridding between two point sets
# (locstreams on both sides): each target point takes the nearest source value.
regridder_options = {
    "method": "nearest_s2d",
    "locstream_in": True,
    "locstream_out": True,
    "periodic": True,
}
regridder = xesmf.Regridder(source_grid, target_grid, **regridder_options)
regridder

In [None]:
# Regrid the dataset (keeping attributes), then re-attach the source grid's DGGS
# definition so the result is again a decoded HEALPix dataset.
grid_metadata = ds.dggs.grid_info
regridded_raw = regridder.regrid_dataset(ds, keep_attrs=True)
regridded = regridded_raw.dggs.decode(grid_metadata)
regridded

In [None]:
# Materialize the lazy (chunked) regridded dataset. compute() returns a new object;
# `regridded` itself stays lazy.
in_memory = regridded.compute()
in_memory

In [None]:
# Interactive map of the regridded "t2m" variable (presumably 2 m temperature);
# alpha presumably sets the layer opacity.
in_memory["t2m"].dggs.explore(alpha=0.8)

In [None]:
from collections.abc import Hashable
from dataclasses import dataclass
from typing import Self

import pandas as pd


@dataclass
class DGGSGrouper(xr.groupers.Resampler):
    """Group DGGS cells by their parent cell ``delta_level`` levels coarser.

    Intended for ``obj.groupby(cell_ids=DGGSGrouper(delta_level=n))``. Every
    cell is mapped (via ``group.dggs.zoom_to``) to its parent at
    ``level - delta_level``, so a reduction such as ``.mean()`` aggregates each
    parent's child cells.

    Parameters
    ----------
    delta_level : int
        Number of levels to coarsen by; must be a positive integer.

    Raises
    ------
    TypeError
        If ``delta_level`` is not an integer.
    ValueError
        If ``delta_level`` is not positive.
    """

    # Number of refinement levels to move up; validated in __post_init__.
    delta_level: int

    def __post_init__(self) -> None:
        # delta_level ∈ ℕ (> 0): 0 would be a no-op grouping and a negative value
        # would request a finer level, which cannot be expressed as one group per cell.
        if isinstance(self.delta_level, bool) or not isinstance(
            self.delta_level, (int, np.integer)
        ):
            raise TypeError(
                f"delta_level must be an integer, got {type(self.delta_level).__name__}"
            )
        if self.delta_level <= 0:
            raise ValueError(f"delta_level must be positive, got {self.delta_level}")
        self.delta_level = int(self.delta_level)

    def factorize(self, group: xr.groupers.T_Group) -> xr.groupers.EncodedGroups:
        """Encode each cell id in ``group`` as the id of its parent cell.

        Parent ids at ``new_level`` lie in ``0 .. 12 * 4**new_level - 1``, so
        they double as integer codes into ``full_index``. ``full_index`` lists
        every possible parent, including those with no member cells.

        Raises
        ------
        ValueError
            If ``delta_level`` exceeds the grid's level.
        """
        self.group = group

        level = group.dggs.grid_info.level
        new_level = level - self.delta_level
        if new_level < 0:
            raise ValueError(
                f"cannot coarsen a level-{level} grid by {self.delta_level} levels"
            )
        codes = group.dggs.zoom_to(new_level).rename("cell_ids")
        index = pd.Index(np.arange(12 * 4**new_level, dtype="uint64"))

        return xr.groupers.EncodedGroups(codes=codes, full_index=index)

    def reset(self) -> Self:
        """Return a fresh grouper with the same ``delta_level`` and no cached group."""
        return type(self)(delta_level=self.delta_level)

    def compute_chunks(
        self, variable: xr.Variable, *, dim: Hashable
    ) -> tuple[int, ...]:
        # Chunk computation is not implemented. Raise instead of implicitly
        # returning None, which would violate the declared tuple[int, ...] return.
        raise NotImplementedError(
            "DGGSGrouper does not support computing chunks (compute_chunks) yet"
        )

In [None]:
%%time
# Coarsen by 4 levels: average the 4**4 = 256 child cells of every parent cell.
parent_grouper = DGGSGrouper(delta_level=4)
downscaled = in_memory.groupby(cell_ids=parent_grouper).mean()
downscaled

In [None]:
# Store path for the Zarr output: the GRIB path with ".grib" swapped for ".zarr".
zarr_url = url.removesuffix(".grib") + ".zarr"

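Chunk size: one HEALPix base cell per chunk along `cells`, i.e. $n_\text{side}^2 = 4^\text{level}$ cells per chunk (so 12 chunks cover the whole sphere).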

In [None]:
# One HEALPix base cell (4**level cells) per chunk along "cells", and 7 * 24 steps
# along "time" (one week, assuming hourly data).
# NOTE(review): `regridded` is lazy, so writing it re-runs the GRIB read + regrid;
# `in_memory` already holds the same computed values.
chunk_size = 4**regridded.dggs.grid_info.level
rechunked = regridded.chunk(cells=chunk_size, time=7 * 24)
rechunked.to_zarr(zarr_url, mode="w")

In [None]:
# Re-open the written store lazily; chunks={} uses the chunking stored in the Zarr store.
reloaded = xr.open_dataset(zarr_url, engine="zarr", chunks={})
reloaded