In [None]:
import logging
import os
from math import ceil, sqrt
from pathlib import Path

import folium
import holoviews as hv
import hvplot.xarray  # noqa
import xarray as xr
from darts_acquisition.arcticdem import load_arcticdem_tile
from darts_acquisition.planet import load_planet_masks, load_planet_scene
from darts_acquisition.tcvis import load_tcvis
from darts_ensemble.ensemble_v1 import EnsembleV1
from darts_postprocessing.prepare_export import prepare_export
from darts_preprocessing import preprocess_legacy_fast
from rich import traceback
from rich.logging import RichHandler

from darts.utils.earthengine import init_ee
from darts.utils.logging import LoggingManager

# Configure logging: the project's LoggingManager first, then a RichHandler on the
# root logger with "[HH:MM:SS]"-style timestamps and rich exception rendering.
# NOTE(review): logging.basicConfig is a no-op when the root logger already has
# handlers — confirm that LoggingManager.setup_logging() does not install any,
# otherwise the RichHandler configured below is silently ignored.
LoggingManager.setup_logging()
logging.basicConfig(
    level=logging.INFO,
    format="%(message)s",
    datefmt="[%X]",
    handlers=[RichHandler(rich_tracebacks=True)],
)
# Rich-formatted tracebacks for uncaught exceptions; locals are hidden to keep output short.
traceback.install(show_locals=False)

# Earth Engine project id passed to init_ee. Set the EE_PROJECT environment variable
# to use your own project instead of editing the notebook; the default keeps the
# previous hardcoded value.
EE_PROJECT = os.environ.get("EE_PROJECT", "ee-tobias-hoelzer")
init_ee(EE_PROJECT)

In [None]:
# Root of the local data tree; inputs and download caches all live below it.
DATA_ROOT = Path("../data")

# Planet PSOrthoTile scene to process.
# Alternative scene: DATA_ROOT / "input" / "planet" / "PSOrthoTile" / "4372514" / "5790392_4372514_2022-07-16_2459"
fpath = DATA_ROOT / "input" / "planet" / "PSOrthoTile" / "4974017" / "5854937_4974017_2022-08-14_2475"

# Download directories for the auxiliary ArcticDEM and TCVIS data.
arcticdem_dir = DATA_ROOT / "download" / "arcticdem"
tcvis_dir = DATA_ROOT / "download" / "tcvis"

In [None]:
def plot_tile(tile: xr.Dataset, ncols=4) -> hv.Layout:  # noqa
    """Render every data variable of ``tile`` as its own rasterized hvplot panel.

    Each panel is a grayscale image with a colorbar and equal data aspect, titled
    with the variable name. The CRS reported by ``tile.rio.crs`` is used both as
    the data CRS and as the display projection. Rasterization aggregates pixels
    with ``max``.

    Args:
        tile: Dataset whose data variables are plotted; expected to expose a CRS
            through the ``rio`` accessor and ``x``/``y`` coordinates.
        ncols: Number of panels per row in the returned layout.

    Returns:
        A holoviews Layout holding one panel per data variable.
    """
    crs_name = str(tile.rio.crs)
    panels = []
    for var_name in tile.data_vars:
        panel = tile.hvplot(
            x="x",
            y="y",
            z=var_name,
            title=var_name,
            crs=crs_name,
            projection=crs_name,
            rasterize=True,
            aggregator="max",
            cmap="gray",
            colorbar=True,
            data_aspect=1,
        )
        panels.append(panel)
    layout = hv.Layout(panels)
    return layout.cols(ncols)


def plot_tile_interactive(tile: xr.Dataset) -> folium.Map:  # noqa
    """Build a folium map with one toggleable layer per data variable of ``tile``.

    Each variable is added through the odc-geo ``.odc.add_to`` accessor under its
    own variable name, and a LayerControl is attached so layers can be switched
    on and off in the rendered map.

    Args:
        tile: Dataset whose data variables become map layers.

    Returns:
        The populated folium Map.
    """
    fmap = folium.Map()
    for layer_name, layer in tile.data_vars.items():
        layer.odc.add_to(map=fmap, name=layer_name)
    folium.LayerControl().add_to(fmap)
    return fmap

In [None]:
# Crop window (pixel indices) applied to the Planet scene and its masks.
slc = {"x": slice(0, 2000), "y": slice(6000, 8000)}
# Outer TPI radius for preprocessing; it also sizes the ArcticDEM load buffer below.
tpi_outer_radius = 100
# Set to False to reuse an existing cache file instead of recomputing the tile.
force = True

# The cache key encodes every parameter that changes the preprocessed tile, so that
# editing the crop window or the TPI radius never silently reloads a stale cache.
cache_key = f"x{slc['x'].start}-{slc['x'].stop}_y{slc['y'].start}-{slc['y'].stop}_tpi{tpi_outer_radius}"
cache_file = DATA_ROOT / "intermediate" / f"planet_{fpath.stem}_{cache_key}.nc"

if cache_file.exists() and not force:
    # load_dataset reads everything into memory and closes the file, so a later forced
    # recompute in the same kernel can overwrite cache_file without conflicting with a
    # still-open lazy handle (which open_dataset would leave behind).
    tile = xr.load_dataset(cache_file, engine="h5netcdf", mask_and_scale=False).set_coords("spatial_ref")
else:
    # Extra ArcticDEM margin derived from the TPI radius — presumably so the TPI
    # neighbourhood is fully covered at the crop edges.
    buffer = ceil(tpi_outer_radius / 2 * sqrt(2))
    optical = load_planet_scene(fpath).isel(slc)
    arcticdem = load_arcticdem_tile(optical.odc.geobox, arcticdem_dir, buffer=buffer, resolution=2)
    tcvis = load_tcvis(optical.odc.geobox, tcvis_dir)
    data_masks = load_planet_masks(fpath).isel(slc)
    tile = preprocess_legacy_fast(optical, arcticdem, tcvis, data_masks, tpi_outer_radius)
    cache_file.parent.mkdir(exist_ok=True, parents=True)
    tile.to_netcdf(cache_file, engine="h5netcdf")

display(tile)
# plot_tile(tile)

In [None]:
# Ensemble of two RTS v6 checkpoints (file names suggest one model with TCVIS
# inputs and one without).
ensemble = EnsembleV1(
    Path("../models") / "RTS_v6_tcvis.pt",
    Path("../models") / "RTS_v6_notcvis.pt",
)
# Log each model's configured "input_combination".
logging.info("%s", ensemble.rts_v6_tcvis_model.config["input_combination"])
logging.info("%s", ensemble.rts_v6_notcvis_model.config["input_combination"])

# Patch-wise inference over the tile; keep_inputs=True retains the input variables
# alongside the predictions in the returned dataset.
tile = ensemble(
    tile,
    patch_size=1024,
    overlap=256,
    batch_size=4,
    keep_inputs=True,
)
display(tile)
# plot_tile(tile)

In [None]:
# Post-process the ensemble output for export.
# NOTE(review): use_quality_mask=True presumably masks predictions using a quality
# layer carried in `tile` — confirm against prepare_export.
tile = prepare_export(tile, use_quality_mask=True)
tile

In [None]:
# Static overview: one rasterized grayscale panel per data variable of the exported tile.
plot_tile(tile)

In [None]:
# Interactive folium map with one toggleable layer per data variable.
plot_tile_interactive(tile)