In [None]:
import distributed

# Create a `distributed.Client` with default arguments (no scheduler address is
# given, so this is expected to start a local cluster) and show its summary as
# the cell output.
client = distributed.Client()
client

In [None]:
import xarray as xr
import xdggs

In [None]:
import os

# Location of the zarr store. The default is the original local path; set the
# DESTINE_ZARR_URL environment variable to run the notebook against a copy of
# the data stored somewhere else, without editing this cell.
url = os.environ.get(
    "DESTINE_ZARR_URL",
    "/home/jmagin/work/data/destine/average_surface_temperature.zarr",
)
ds = (
    # chunks={} asks xarray for dask-backed (lazy) variables
    xr.open_dataset(url, engine="zarr", chunks={})
    # keep only the first entry along the ocean model layer dimension
    .isel(oceanModelLayer=0)
    # let xdggs decode the grid metadata so the `.dggs` accessor can be used
    .pipe(xdggs.decode)
)
ds

In [None]:
import xarray as xr
import xdggs  # noqa: F401

from healpix_convolution.kernels import gaussian


def gaussian_kernel(
    cell_ids, sigma: float, truncate: float = 4.0, kernel_size: int | None = None
):
    """Create a symmetric gaussian kernel for the given cell ids

    Parameters
    ----------
    cell_ids : xarray.DataArray
        The cell ids. Must be valid according to xdggs: one-dimensional, with
        the grid parameters stored in ``attrs``. The dimension does not have
        to be called ``"cells"``.
    sigma : float
        The standard deviation of the gaussian function in radians.
    truncate : float, default: 4.0
        Truncate the kernel after this many multiples of sigma.
    kernel_size : int, optional
        If given, will be used instead of ``truncate`` to determine the size of the kernel.

    Returns
    -------
    kernel : xarray.DataArray
        The kernel as a sparse matrix with dimensions
        ``("output_cells", "input_cells")``. The cell ids are attached along
        both dimensions as the ``output_cell_ids`` and ``input_cell_ids``
        coordinates, and the kernel parameters are recorded in ``attrs``.

    Raises
    ------
    ValueError
        If ``cell_ids`` is not one-dimensional.
    """
    # Fail before the (potentially expensive) kernel computation: the cell ids
    # are relabelled onto a single kernel dimension below, which cannot work
    # for anything but 1-D input.
    if cell_ids.ndim != 1:
        raise ValueError(
            f"cell_ids must be one-dimensional, got dimensions {cell_ids.dims}"
        )
    # Use the actual dimension name instead of assuming it is "cells".
    (cell_dim,) = cell_ids.dims

    grid = xdggs.healpix.HealpixInfo.from_dict(cell_ids.attrs)

    matrix = xr.apply_ufunc(
        gaussian.gaussian_kernel,
        cell_ids,
        kwargs={
            "resolution": grid.resolution,
            "indexing_scheme": grid.indexing_scheme,
            "sigma": sigma,
            "truncate": truncate,
            "kernel_size": kernel_size,
        },
        input_core_dims=[[cell_dim]],
        output_core_dims=[["output_cells", "input_cells"]],
        dask="allowed",
        keep_attrs="drop",
    )

    # Only record the parameter that actually determined the kernel size.
    if kernel_size is not None:
        size_param = {"kernel_size": kernel_size}
    else:
        size_param = {"truncate": truncate}

    return matrix.assign_attrs(
        {"kernel_type": "gaussian", "method": "continuous", "sigma": sigma} | size_param
    ).assign_coords(
        # `.variable` strips the coordinates of `cell_ids`, so only the ids
        # themselves end up on the kernel.
        input_cell_ids=cell_ids.swap_dims({cell_dim: "input_cells"}).variable,
        output_cell_ids=cell_ids.swap_dims({cell_dim: "output_cells"}).variable,
    )

In [None]:
%%time
# Build the kernel for the dataset's cell ids: `sigma` is given in radians and
# `truncate=3` cuts the kernel off after 3 sigma (see the `gaussian_kernel`
# docstring).
kernel = gaussian_kernel(ds["cell_ids"], sigma=0.0015, truncate=3)
kernel

In [None]:
def convolve(ds, kernel):
    """Convolve the cell-based variables of a dataset with a kernel matrix

    Parameters
    ----------
    ds : xarray.Dataset
        The data to convolve. Data variables with a ``"cells"`` dimension are
        convolved along it, all other data variables are passed through
        unchanged.
    kernel : xarray.DataArray
        The kernel matrix. Must have the dimensions ``"output_cells"`` and
        ``"input_cells"`` and an ``output_cell_ids`` coordinate, for example
        as returned by ``gaussian_kernel``.

    Returns
    -------
    convolved : xarray.Dataset
        The convolved dataset. The result is indexed by ``"cells"`` again, the
        kernel's ``output_cell_ids`` become the new ``cell_ids``, and all
        coordinates of ``ds`` that do not depend on ``"cells"`` are kept.
    """
    # Chunk the kernel as well if the data is chunked, so that both operands
    # of the dot product are lazy.
    if ds.chunksizes:
        kernel = kernel.chunk()

    def _convolve(arr, weights):
        src_dims = ["input_cells"]

        # Variables without the cell dimension are not affected.
        if not set(src_dims).issubset(arr.dims):
            return arr

        return xr.dot(
            # drop all input coords, as those would most likely be broadcast
            arr.variable,
            weights,
            # This dimension will be "contracted"
            # or summed over after multiplying by the weights
            dims=src_dims,
        )

    input_dims = ["cells"]

    # `_convolve` strips every coordinate from the convolved variables.
    # Coordinates that do not depend on the cell dimension are still valid
    # for the result, so remember them and attach them again below.
    unchanged_coords = {
        k: v for k, v in ds.coords.items() if not set(input_dims).intersection(v.dims)
    }

    return (
        ds.rename_dims({"cells": "input_cells"})
        .map(_convolve, weights=kernel)
        .rename_dims({"output_cells": "cells"})
        .rename_vars({"output_cell_ids": "cell_ids"})
        .assign_coords(unchanged_coords)
    )

In [None]:
%%time
# Remember where the input has data, convolve with the missing values replaced
# by 0 (the filled dataset is loaded into memory first), then blank out the
# originally missing positions again in the result.
mask = ds.notnull()
convolved = convolve(ds.fillna(0).compute(), kernel).where(mask)
convolved

In [None]:
# Evaluate whatever part of the result is still lazy and keep it in memory.
computed = convolved.compute()
computed

In [None]:
# Decode the grid metadata of the computed result again and attach
# latitude / longitude coordinates through the `.dggs` accessor.
convolved_ = xdggs.decode(computed).dggs.assign_latlon_coords()
convolved_

## visualization

In [None]:
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature

In [None]:
import healpy as hp
import numpy as np

In [None]:
%%time
# Reference result: smooth the same variable with healpy, using the same
# sigma (0.0015) that was passed to `gaussian_kernel`. As for the convolution,
# missing values are replaced by 0 before smoothing and masked out afterwards.
c = ds.compute()
# NOTE: this rebinds `mask`, which the convolution cell defined for the whole
# dataset; here it only covers `avg_thetao`.
mask = c.avg_thetao.notnull()
smoothed = xr.DataArray(
    # nest=True tells healpy the map is in nested ordering — presumably this
    # matches the dataset's indexing scheme; TODO confirm
    hp.smoothing(c.avg_thetao.fillna(0).data, sigma=0.0015, nest=True),
    dims="cells",
    coords=c.coords,
).where(mask)
smoothed

In [None]:
# Number of non-missing values in the healpy result.
smoothed.count()

In [None]:
# Cell-wise difference between the healpix-convolution result and the healpy
# reference.
convolved_.avg_thetao - smoothed

In [None]:
from healpix_convolution.plotting import xr_plot_healpix

In [None]:
# Three stacked Mollweide panels: the healpy reference, the
# healpix-convolution result, and the difference between the two.
fig, axes = plt.subplots(
    3, 1, figsize=(16, 16), subplot_kw=dict(projection=ccrs.Mollweide())
)

mappable1 = xr_plot_healpix(
    smoothed, title="smoothing", ax=axes[0], vmin=230, cmap="plasma"
)
mappable2 = xr_plot_healpix(
    convolved_.avg_thetao,
    title="healpix-convolution",
    ax=axes[1],
    vmin=230,
    cmap="plasma",
)
mappable3 = xr_plot_healpix(
    convolved_.avg_thetao - smoothed,
    title="diff",
    ax=axes[2],
    vmin=-60,
    vmax=60,
    cmap="RdBu_r",
)

# Only the difference panel gets a colorbar.
fig.colorbar(mappable3)
fig.savefig("comparison_smoothing.png", dpi=300, bbox_inches="tight", format="png")