In [None]:
import logging
import os
from pathlib import Path

import matplotlib.pyplot as plt
import xarray as xr
from darts_ensemble.ensemble_v1 import EnsembleV1
from darts_postprocessing.prepare_export import prepare_export
from darts_preprocessing.preprocess import load_and_preprocess_planet_scene
from rich import traceback
from rich.logging import RichHandler

from darts.utils.earthengine import init_ee
from darts.utils.logging import setup_logging

# Project logging setup, then a Rich console handler on the root logger and Rich tracebacks.
setup_logging()
# NOTE(review): logging.basicConfig only configures the root logger when it has no handlers
# yet (unless force=True is passed). If setup_logging() already attached root handlers, this
# call is a no-op — confirm which configuration is meant to win.
logging.basicConfig(
    level=logging.INFO,
    format="%(message)s",
    datefmt="[%X]",
    handlers=[RichHandler(rich_tracebacks=True)],
)
traceback.install(show_locals=False)

# Earth Engine project. Override it via the DARTS_EE_PROJECT environment variable so the
# notebook can run under another account without editing code; the default is unchanged.
EE_PROJECT = os.environ.get("DARTS_EE_PROJECT", "ee-tobias-hoelzer")
init_ee(EE_PROJECT)

In [None]:
DATA_ROOT = Path("../data")

# Scene passed to the preprocessing loader (presumably a PSOrthoTile scene directory — confirm).
# Alternative scene; swap it in to process that one instead:
#   DATA_ROOT / "input/planet/PSOrthoTile/4372514/5790392_4372514_2022-07-16_2459"
fpath = DATA_ROOT.joinpath("input/planet/PSOrthoTile/4974017/5854937_4974017_2022-08-14_2475")
# Auxiliary directories handed to load_and_preprocess_planet_scene alongside the scene.
arcticdem_dir = DATA_ROOT.joinpath("input/ArcticDEM")
cache_dir = DATA_ROOT.joinpath("download")

In [None]:
# Cache the preprocessed scene as NetCDF so later runs can skip preprocessing.
# Set `force = True` to recompute and overwrite the cache.
# NOTE(review): `fpath.stem` strips a trailing ".suffix" from the scene name; harmless for the
# current dot-free name, but dotted names could collide.
cache_file = DATA_ROOT / "intermediate" / f"planet_{fpath.stem}.nc"
force = False
if cache_file.exists() and not force:
    # load_dataset reads everything into memory and closes the file. A lazily opened dataset
    # (open_dataset) keeps the HDF5 handle alive, which can block re-writing this same cache
    # file when the cell is re-run with force=True.
    # NOTE(review): mask_and_scale=False returns raw on-disk values (no _FillValue masking or
    # scale/offset decoding), so a cache hit may differ from a fresh computation if variables
    # carry such encodings — confirm this is intended.
    tile = xr.load_dataset(cache_file, engine="h5netcdf", mask_and_scale=False)
else:
    tile = load_and_preprocess_planet_scene(fpath, arcticdem_dir, cache_dir)
    cache_file.parent.mkdir(exist_ok=True, parents=True)
    # Write to a temporary file, then atomically move it into place. An interrupted write then
    # never leaves a truncated file that the exists() check above would accept as a valid cache.
    tmp_file = cache_file.with_name(cache_file.name + ".tmp")
    tile.to_netcdf(tmp_file, engine="h5netcdf")
    tmp_file.replace(cache_file)

tile

In [None]:
# Overview of the preprocessed scene: average over 16x16 pixel windows (edge pixels that do
# not fill a whole window are trimmed), then draw one grayscale panel per data variable.
tile_low_res = tile.coarsen(x=16, y=16, boundary="trim").mean()
ncols = 6
n_panels = len(tile_low_res.data_vars)
# Size the grid from the variable count instead of a fixed 2x6 layout. The fixed layout
# raised IndexError for more than 12 variables and left empty frames for fewer.
nrows = max(1, (n_panels + ncols - 1) // ncols)
fig, axs = plt.subplots(nrows, ncols, figsize=(6 * ncols, 5 * nrows), squeeze=False)
axs = axs.flatten()
for ax, v in zip(axs, tile_low_res.data_vars):
    tile_low_res[v].plot(ax=ax, cmap="gray")
    ax.set_title(v)
# Hide the grid cells left over in the last row.
for ax in axs[n_panels:]:
    ax.set_visible(False)

In [None]:
# Build the v6 RTS ensemble from its two checkpoint files ("tcvis" and "notcvis") and log
# the input combination recorded in each member model's config.
ensemble = EnsembleV1(
    "../models/RTS_v6_tcvis.pt",
    "../models/RTS_v6_notcvis.pt",
)
for member in (ensemble.rts_v6_tcvis_model, ensemble.rts_v6_notcvis_model):
    logging.info(member.config["input_combination"])

# Run inference on the scene. The patch_size/overlap values are presumably in pixels, and
# keep_inputs presumably retains the input variables in the result — confirm against EnsembleV1.
# NOTE: this reassigns `tile`; re-running this cell without first re-running the loading cell
# feeds the previous output back into the ensemble.
tile = ensemble(tile, batch_size=4, keep_inputs=True, patch_size=1024, overlap=256)
tile

In [None]:
# Overview after ensemble inference: same 16x16 window average, one grayscale panel per variable.
final_low_res = tile.coarsen(x=16, y=16, boundary="trim").mean()
ncols = 8
n_panels = len(final_low_res.data_vars)
# Size the grid from the variable count instead of a fixed 2x8 layout. The fixed layout
# raised IndexError for more than 16 variables and left empty frames for fewer.
nrows = max(1, (n_panels + ncols - 1) // ncols)
fig, axs = plt.subplots(nrows, ncols, figsize=(6.5 * ncols, 5 * nrows), squeeze=False)
axs = axs.flatten()
for ax, v in zip(axs, final_low_res.data_vars):
    if v == "probabilities":
        # Fix the colour range to [0, 1] rather than autoscaling to the data range.
        final_low_res[v].plot(ax=ax, cmap="gray", vmin=0, vmax=1)
    else:
        final_low_res[v].plot(ax=ax, cmap="gray")
    ax.set_title(v)
# Hide the grid cells left over in the last row.
for ax in axs[n_panels:]:
    ax.set_visible(False)

In [None]:
# Post-process the predicted scene with prepare_export.
# NOTE: this overwrites `tile`; re-running this cell on its own applies prepare_export to
# its own previous output.
tile_export = prepare_export(tile)
tile = tile_export
tile

In [None]:
# Overview of the exported scene: same 16x16 window average, one grayscale panel per variable.
final_low_res = tile.coarsen(x=16, y=16, boundary="trim").mean()
ncols = 8
n_panels = len(final_low_res.data_vars)
# Size the grid from the variable count instead of a fixed 2x8 layout. The fixed layout
# raised IndexError for more than 16 variables and left empty frames for fewer.
nrows = max(1, (n_panels + ncols - 1) // ncols)
fig, axs = plt.subplots(nrows, ncols, figsize=(6.5 * ncols, 5 * nrows), squeeze=False)
axs = axs.flatten()
for ax, v in zip(axs, final_low_res.data_vars):
    if v == "probabilities":
        # NOTE(review): the 0-100 range assumes prepare_export rescales probabilities to
        # percent (the pre-export plot uses 0-1) — confirm.
        final_low_res[v].plot(ax=ax, cmap="gray", vmin=0, vmax=100)
    else:
        final_low_res[v].plot(ax=ax, cmap="gray")
    ax.set_title(v)
# Hide the grid cells left over in the last row.
for ax in axs[n_panels:]:
    ax.set_visible(False)