In [None]:
import datetime

import cartopy.crs as ccrs
import geoviews as gv
import geoviews.feature as gf
import pandas as pd
import panel as pn
import param
import pystac
import rasterio
import rioxarray  # noqa: F401
import xarray as xr
from holoviews.plotting import list_cmaps

import utils.xyt

# Load the notebook extensions: Panel, then GeoViews with the Bokeh plotting backend.
pn.extension()
gv.extension("bokeh")

In [None]:
# Root STAC catalog; the next cell walks it: root -> SITE child catalog -> COLLECTION child.
CATALOG_URL = "https://s3.waw3-2.cloudferro.com/swift/v1/wpl-stac/stac/catalog.json"
SITE = "degero"  # id of the site's child catalog under the root
COLLECTION = "albedo"  # id of the collection within the site catalog

In [None]:
# Walk the STAC hierarchy: root catalog -> site catalog -> collection.
# `get_child` returns None for an unknown id, so fail here with a clear message
# instead of with an AttributeError in a later cell.
root: pystac.Catalog = pystac.read_file(CATALOG_URL)  # type: ignore
catalog = root.get_child(SITE)
if catalog is None:
    raise ValueError(f"No child catalog {SITE!r} found in {CATALOG_URL}")
collection = catalog.get_child(COLLECTION)
# later cells need Collection-only attributes (extent, assets), so a plain Catalog is not enough
if not isinstance(collection, pystac.Collection):
    raise ValueError(f"No collection {COLLECTION!r} found in site catalog {SITE!r}")

In [None]:
extent = utils.xyt.Extent.from_pystac(collection.extent)  # convert the collection's STAC extent into the project's utils.xyt.Extent

In [None]:
xyt = utils.xyt.XYT(extent=extent)  # region of interest bounded by the collection extent; ZarrDataset reads its .date, .extent and .point()

In [None]:
xyt  # display the region-of-interest object (rendered via its own repr)

In [None]:
class ZarrDataset(pn.viewable.Viewer):
    """
    Panel viewer for a Zarr datacube referenced by a STAC collection.

    The same datacube is held twice, chunked differently: `xy_ds` for spatial (map) reads
    and `ts_ds` for temporal (time series) reads. Build instances with
    `ZarrDataset.from_pystac`, which reads default visualisation settings from the
    collection's `wpl:render` field.
    """

    location: utils.xyt.XYT = param.ClassSelector(
        class_=utils.xyt.XYT, label="Region of interest", allow_None=False, constant=True
    )  # type: ignore

    # these should be identical datasets, but chunked differently
    xy_ds: xr.Dataset = param.ClassSelector(
        class_=xr.Dataset, label="Datacube chunked for spatial reads", allow_None=False, constant=True
    )  # type: ignore
    ts_ds: xr.Dataset = param.ClassSelector(
        class_=xr.Dataset, label="Datacube chunked for temporal reads", allow_None=False, constant=True
    )  # type: ignore

    # data variable drawn on the map; allowed options are filled in from xy_ds in __init__
    primary_var_name: str = param.Selector(objects=[])  # type: ignore
    # optional data variable holding the uncertainty of the primary variable (None = not available)
    uncertainty_var_name: str | None = param.Selector(objects=[])  # type: ignore
    # optional constant uncertainty; NOTE(review): presumably an alternative to uncertainty_var_name — confirm
    uncertainty_scalar: float | None = param.Number(default=None, allow_None=True)  # type: ignore

    colormap_name: str = param.Selector(objects=list_cmaps(), allow_None=False)  # type: ignore
    # colour limits for the map. param enables allow_None automatically whenever default=None,
    # so these remain None unless provided (from_pystac always provides them).
    colormap_min: float = param.Number(default=None, allow_None=False)  # type: ignore
    colormap_max: float = param.Number(default=None, allow_None=False)  # type: ignore

    # most appropriate date to visualize (from the time dimension of the dataset)
    # given the region of interest self.location.date; maintained by select_date()
    date: datetime.datetime = param.Date(default=datetime.datetime(2000, 1, 1), allow_None=False)  # type: ignore

    # map projection of xy_ds; derived from its rioxarray CRS when not passed explicitly
    crs: ccrs.CRS = param.ClassSelector(
        class_=ccrs.CRS, default=None, allow_None=False, constant=True
    )  # type: ignore

    def __init__(self, **params):
        """
        Initialise the viewer and restrict the variable selectors to xy_ds's data variables.

        Raises:
            ValueError: if xy_ds has no data variables, if `primary_var_name` /
                `uncertainty_var_name` is not one of them, or if no `crs` is given and
                xy_ds carries no CRS.
            NotImplementedError: if the CRS of xy_ds is not expressible as an EPSG code.
        """
        super().__init__(**params)

        data_vars = list(self.xy_ds.data_vars)
        if not data_vars:
            raise ValueError("xy_ds has no data variables to visualise")

        # set the selector options from the Dataset; each selector gets its own copy of
        # the list so that changes to one parameter's objects cannot leak into the other
        self.param.primary_var_name.objects = list(data_vars)
        self.param.primary_var_name.allow_None = False
        self.param.uncertainty_var_name.objects = list(data_vars)
        self.param.uncertainty_var_name.allow_None = True

        primary_var_name = params.get("primary_var_name", data_vars[0])
        uncertainty_var_name = params.get("uncertainty_var_name")

        # validate explicitly: the selectors were declared with objects=[], so param does
        # not necessarily check values against the options assigned above
        if primary_var_name not in data_vars:
            raise ValueError(f"primary_var_name {primary_var_name!r} is not one of {data_vars}")
        if uncertainty_var_name is not None and uncertainty_var_name not in data_vars:
            raise ValueError(f"uncertainty_var_name {uncertainty_var_name!r} is not one of {data_vars}")

        self.primary_var_name = primary_var_name
        self.uncertainty_var_name = uncertainty_var_name

        if "crs" not in params:
            # pull the CRS from xy_ds (rioxarray metadata)
            rio_crs: rasterio.crs.CRS | None = self.xy_ds.rio.crs
            if rio_crs is None:
                raise ValueError("xy_ds has no CRS; pass `crs` explicitly")
            if not rio_crs.is_epsg_code:
                raise NotImplementedError(f"only EPSG-coded CRSs are supported, got {rio_crs}")
            with param.edit_constant(self):
                self.crs = ccrs.epsg(rio_crs.to_epsg())

    @classmethod
    def from_pystac(cls, location: utils.xyt.XYT, collection: pystac.Collection) -> "ZarrDataset":
        """
        Create a ZarrDataset from a pystac Collection.

        Looks for a custom field `wpl:render` in the collection's metadata, e.g.

        ```json
        {
            "assets": ["albedo.xy.zarr", "albedo.ts.zarr"],
            "colormap_name": "copper",
            "colormap_range": [0, 1],
            "primary_var_name": "albedo",
            "uncertainty_scalar": null,
            "uncertainty_var_name": "albedo_std_dev"
        }
        ```

        `assets` must list one asset key ending in `.xy.zarr` and one ending in `.ts.zarr`;
        each is opened with `xr.open_dataset` using the asset's xarray-extension `open_kwargs`.
        `uncertainty_var_name` and `uncertainty_scalar` are optional and default to None.

        Raises:
            ValueError: if the collection lacks `wpl:render`, or its `assets` list has no
                `.xy.zarr` / `.ts.zarr` entry.
        """
        # this is a custom field which provides some default visualization parameters for the zarr datacube
        WPL_RENDER_KEY = "wpl:render"

        if WPL_RENDER_KEY not in collection.extra_fields:
            raise ValueError(f"Collection {collection.id} does not have the required field {WPL_RENDER_KEY}")

        wpl_render = collection.extra_fields[WPL_RENDER_KEY]

        def asset_key_with_suffix(suffix: str) -> str:
            # first asset key listed in wpl:render that ends with `suffix`
            key = next((a for a in wpl_render["assets"] if a.endswith(suffix)), None)
            if key is None:
                raise ValueError(
                    f"Collection {collection.id}: {WPL_RENDER_KEY} assets contain no entry ending in {suffix!r}"
                )
            return key

        # resolve both keys before opening anything, so a missing asset fails fast
        xy_asset = collection.assets[asset_key_with_suffix(".xy.zarr")]
        ts_asset = collection.assets[asset_key_with_suffix(".ts.zarr")]

        xy_ds = xr.open_dataset(
            xy_asset.href,
            **xy_asset.ext.xarray.open_kwargs,  # type: ignore
        )
        ts_ds = xr.open_dataset(
            ts_asset.href,
            **ts_asset.ext.xarray.open_kwargs,  # type: ignore
        )

        return cls(
            location=location,
            xy_ds=xy_ds,
            ts_ds=ts_ds,
            primary_var_name=wpl_render["primary_var_name"],
            uncertainty_var_name=wpl_render.get("uncertainty_var_name"),
            uncertainty_scalar=wpl_render.get("uncertainty_scalar"),
            colormap_name=wpl_render["colormap_name"],
            colormap_min=wpl_render["colormap_range"][0],
            colormap_max=wpl_render["colormap_range"][1],
        )

    @param.depends("location.date", watch=True, on_init=True)
    def select_date(self):
        """
        Given the date of the region of interest,
        find the most appropriate date to visualize from the time dimension of the dataset.

        Picks the latest timestep at or before `location.date`; if the requested date
        precedes every timestep, falls back to the earliest one. Stores it in `self.date`.
        """
        times = self.xy_ds.time
        try:
            # 0-d DataArray holding the selected time coordinate
            t = times.sel(time=self.location.date, method="ffill")
        except KeyError:
            # xarray raises KeyError when no earlier timestep exists
            t = times.isel(time=0)
        self.date = pd.Timestamp(t.values).to_pydatetime()

    def map_view(self):
        """
        GeoViews map plot of the primary variable @ the time of interest.

        Overlays ocean/land features, the bounding box of the region-of-interest extent,
        the primary-variable image for `self.date` (re-rendered whenever `date` changes),
        and the point of interest.

        TODO: tap / click on the map to update the point of interest is not implemented yet.
        """
        def render_frame(time: datetime.datetime):
            # `time` is fed from the `date` parameter through the DynamicMap stream below
            frame = self.xy_ds[self.primary_var_name].sel(time=time)
            # manually trigger loading the data from the zarr store
            frame.load()
            img = gv.Image(frame, kdims=["x", "y"], crs=self.crs)
            img.opts(colorbar=True, cmap=self.colormap_name, clim=(self.colormap_min, self.colormap_max))
            return img

        map_dmap = gv.DynamicMap(render_frame, streams={"time": self.param.date})

        bbox = self.location.extent.spatial.polygon
        point = self.location.point()

        return gf.ocean * gf.land * bbox * map_dmap * point

    def __panel__(self) -> pn.viewable.Viewable:
        """
        Build a visualisation of the dataset.

        Currently shows the GeoViews map plot of the primary variable @ the time of
        interest together with the point of interest (see `map_view`).

        TODO (not implemented yet):
        - Tap / click on the map to update the point of interest.
        - Time series plot of the primary variable @ the point of interest, including
          uncertainty if available, marking the time of interest; tap / click on it to
          update the time of interest.
        """
        # previously a stub returning None, which breaks displaying the Viewer
        return pn.panel(self.map_view())


In [None]:
obj = ZarrDataset.from_pystac(xyt, collection)  # open the collection's .xy/.ts zarr assets with its wpl:render defaults

In [None]:
obj.map_view().opts(width=500)  # preview the map overlay at a fixed plot width