In [None]:
import logging
from pathlib import Path

import matplotlib.pyplot as plt
import xarray as xr
from darts_postprocessing.prepare_export import prepare_export
from darts_preprocessing.preprocess_tobi import load_and_preprocess_planet_scene
from darts_segmentation.segment import SMPSegmenter
from lovely_tensors import monkey_patch
from rich import traceback
from rich.logging import RichHandler

# Keep xarray reprs compact: do not expand the data section by default.
xr.set_options(display_expand_data=False)

# Logging: everything goes through rich at INFO; the two darts packages used
# in this notebook are raised to DEBUG.
logging.basicConfig(handlers=[RichHandler()], level=logging.INFO)
for darts_logger_name in ("darts_preprocessing", "darts_segmentation"):
    logging.getLogger(darts_logger_name).setLevel(logging.DEBUG)

# Apply the lovely_tensors monkey patch, then install rich tracebacks that
# also print the local variables of each frame.
monkey_patch()
traceback.install(show_locals=True)

In [2]:
# Root directory of all input data, relative to the notebook location.
DATA_ROOT = Path("../data/input")

# fpath = DATA_ROOT / "planet/PSOrthoTile/4372514/5790392_4372514_2022-07-16_2459"
fpath = DATA_ROOT.joinpath("planet/PSOrthoTile/4974017/5854937_4974017_2022-08-14_2475")
# Name of the directory directly above the scene folder ("4974017" for the
# path above); the ArcticDEM file names below are keyed by it.
scene_id = fpath.parent.name

# TODO: change to vrt
arcticdem_dir = DATA_ROOT / "ArcticDEM"
elevation_path = arcticdem_dir / "relative_elevation" / f"{scene_id}_relative_elevation_100.tif"
slope_path = arcticdem_dir / "slope" / f"{scene_id}_slope.tif"


In [None]:
# Load the Planet scene at `fpath` together with the matching relative
# elevation and slope rasters, and run the preprocessing on it.
tile = load_and_preprocess_planet_scene(
    fpath,
    elevation_path,
    slope_path,
)
# Bare last expression: show the result with its rich notebook repr.
tile

In [None]:
# Quick-look overview of the preprocessed tile: one grayscale panel per variable.
# Average 16x16 pixel windows (trimming incomplete windows at the border) so
# that plotting stays fast.
tile_low_res = tile.coarsen(x=16, y=16, boundary="trim").mean()

# Derive the grid from the number of variables instead of hard-coding 2x5:
# a fixed grid raises IndexError as soon as the tile carries more variables
# than panels, and shows empty frames when it carries fewer.
# For 10 variables this yields the same 2x5 grid and 30x10 figure as before.
tile_var_names = list(tile_low_res.data_vars)
n_rows = 2
n_cols = max(1, -(-len(tile_var_names) // n_rows))  # ceiling division
fig, axs = plt.subplots(n_rows, n_cols, figsize=(6 * n_cols, 10))
axs = axs.flatten()
for i, v in enumerate(tile_var_names):
    tile_low_res[v].plot(ax=axs[i], cmap="gray")
    axs[i].set_title(v)
# Hide panels that did not receive a variable.
for unused_ax in axs[len(tile_var_names):]:
    unused_ax.set_visible(False)

In [None]:
# Build the segmenter from the model checkpoint on disk.
model_checkpoint = "../models/RTS_v6_notcvis.pt"
model = SMPSegmenter(model_checkpoint)

# Segment the tile and rebind `tile` to the result.
# NOTE(review): because `tile` is both input and output here, re-running this
# cell feeds the already segmented tile back into the model; re-run the
# loading cell first to start from a fresh tile.
tile = model.segment_tile(tile, batch_size=4)

# Derive the export dataset from the segmented tile.
final = prepare_export(tile)

In [None]:
# Quick-look overview of the export dataset: one grayscale panel per variable.
# Average 16x16 pixel windows (trimming incomplete windows at the border) so
# that plotting stays fast.
final_low_res = final.coarsen(x=16, y=16, boundary="trim").mean()

# Derive the grid from the number of variables instead of hard-coding 2x6:
# a fixed grid raises IndexError as soon as the dataset carries more variables
# than panels, and shows empty frames when it carries fewer.
# For 12 variables this yields the same 2x6 grid and 36x10 figure as before.
final_var_names = list(final_low_res.data_vars)
n_rows = 2
n_cols = max(1, -(-len(final_var_names) // n_rows))  # ceiling division
fig, axs = plt.subplots(n_rows, n_cols, figsize=(6 * n_cols, 10))
axs = axs.flatten()
for i, v in enumerate(final_var_names):
    final_low_res[v].plot(ax=axs[i], cmap="gray")
    axs[i].set_title(v)
# Hide panels that did not receive a variable.
for unused_ax in axs[len(final_var_names):]:
    unused_ax.set_visible(False)