In [None]:
import logging
from pathlib import Path

import branca.colormap as cm
import geopandas as gpd
import pandas as pd
import xarray as xr
from darts_acquisition.utils.copernicus import init_copernicus
from rich import traceback
from rich.logging import RichHandler

# Route darts-library log records through rich at DEBUG level. The handler is only attached
# once: re-running this cell would otherwise add another RichHandler and print every record
# multiple times.
darts_logger = logging.getLogger("darts")
darts_logger.setLevel(logging.DEBUG)
if not any(isinstance(handler, RichHandler) for handler in darts_logger.handlers):
    darts_logger.addHandler(RichHandler())

# Rich tracebacks, including the local variables of each frame.
traceback.install(show_locals=True)

# Keep the attrs section of xarray HTML reprs collapsed so outputs stay compact.
xr.set_options(display_expand_attrs=False)

# NOTE(review): presumably sets up Copernicus data-space access for darts_acquisition — confirm.
init_copernicus()

# NOTE(review): machine-specific absolute paths; adjust them for your environment.
labels_dir = Path("/home/pd/tohoel001/repositories/ML_training_labels/retrogressive_thaw_slumps")
data_dir = Path("/isipd/projects/p_aicore_pf/initze/training_data_creation/slumps/03_processed")

In [None]:
def parse_date(row):
    """Parse the acquisition date of a footprint row from its ``image_id``.

    Rows whose ``datasource`` is ``"PlanetScope OrthoTile"`` take the second-to-last
    ``_``-separated token of ``image_id``, formatted ``YYYY-MM-DD``. Every other row takes
    the first token, formatted ``YYYYMMDD``.

    Args:
        row: A mapping-like row (e.g. a DataFrame row) with ``datasource`` and ``image_id``.

    Returns:
        pd.Timestamp: The parsed date as a tz-aware UTC timestamp.
    """
    is_orthotile = row["datasource"] == "PlanetScope OrthoTile"
    tokens = row["image_id"].split("_")
    date_token, date_format = (tokens[-2], "%Y-%m-%d") if is_orthotile else (tokens[0], "%Y%m%d")
    return pd.to_datetime(date_token, format=date_format, utc=True)


def _sr_image_dir(imgdir: Path):
    """Return the directory holding an SR GeoTIFF directly inside ``imgdir``, or ``None``.

    ``*_SR.tif`` is preferred and ``*_SR_clip.tif`` is the fallback. ``next(..., None)`` is
    used so a miss never raises StopIteration. Inside a generator, a raised StopIteration
    would surface as ``RuntimeError: generator raised StopIteration`` (PEP 479).
    """
    for pattern in ("*_SR.tif", "*_SR_clip.tif"):
        sr_file = next(imgdir.glob(pattern), None)
        if sr_file is not None:
            return sr_file.parent
    return None


def _iter_image_dirs(parent: Path):
    """Yield the SR image directories directly below ``parent``; skip files and dirs without SR data."""
    for imgdir in parent.iterdir():
        if not imgdir.is_dir():
            continue
        sr_dir = _sr_image_dir(imgdir)
        if sr_dir is None:
            logging.getLogger(__name__).warning("Skipping %s: no *_SR.tif or *_SR_clip.tif found", imgdir)
            continue
        yield sr_dir


def _legacy_path_gen(data_dir: Path):
    """Yield every image directory containing a surface-reflectance GeoTIFF under ``data_dir``.

    Expected layout:
        data_dir/iteration001/<site>/<image>/   (one extra per-site level)
        data_dir/<other>/<image>/

    Non-directory entries are ignored at every level. Image directories that contain neither
    ``*_SR.tif`` nor ``*_SR_clip.tif`` are skipped with a logged warning.

    Args:
        data_dir: Root of the processed training-data tree.

    Yields:
        Path: One image directory per scene (its ``.stem`` is used as the image id downstream).
    """
    for iterdir in data_dir.iterdir():
        if not iterdir.is_dir():
            continue
        if iterdir.stem == "iteration001":
            for sitedir in iterdir.iterdir():
                if sitedir.is_dir():
                    yield from _iter_image_dirs(sitedir)
        else:
            yield from _iter_image_dirs(iterdir)


# Read every ImageFootprints*.gpkg found one level below the labels directory into one frame.
footprint_frames = [gpd.read_file(gpkg_path) for gpkg_path in labels_dir.glob("*/ImageFootprints*.gpkg")]
footprints = gpd.GeoDataFrame(pd.concat(footprint_frames, ignore_index=True))

# Acquisition date parsed from each row's image_id.
footprints["date"] = footprints.apply(parse_date, axis=1)

# Link each footprint to its image directory on disk (keyed by directory name == image_id).
fpaths = {sr_dir.stem: sr_dir for sr_dir in _legacy_path_gen(data_dir)}
footprints["fpath"] = footprints["image_id"].map(fpaths)

display(footprints.head())
footprints.info()

In [None]:
from darts_acquisition.s2 import match_s2ids_from_geodataframe_stac

# Work on a small subset while experimenting. `.iloc[:N]` selects by position like the former
# `.take([0, 1, 2])`, but does not raise IndexError when fewer than N footprints exist.
# NOTE: this overwrites `footprints`; re-run the loading cell to get the full set back.
N_DEBUG_FOOTPRINTS = 3
S2_SCORES_FILE = Path("./s2-scores.parquet")

footprints = footprints.iloc[:N_DEBUG_FOOTPRINTS]
matches = match_s2ids_from_geodataframe_stac(
    aoi=footprints, day_range=14, max_cloud_cover=10, simplify_geometry=0.1, save_scores=S2_SCORES_FILE
)
matches

In [None]:
# Inspect the scores file written by `match_s2ids_from_geodataframe_stac(save_scores=...)`.
# (geopandas is already imported as `gpd` in the first cell; no need to re-import it here.)
gpd.read_parquet(Path("./s2-scores.parquet"))

In [None]:
# Attach the matched STAC item (or None) to each footprint row, keyed by the frame index.
# `matches.get` yields a real None for rows absent from `matches`. `Index.map(dict)` would
# yield NaN for them instead, which breaks the `.id` lookup below and the `is None` checks later.
footprints["s2_item"] = footprints.index.map(lambda idx: matches.get(idx))
footprints["s2id"] = footprints["s2_item"].map(lambda item: item.id if item is not None else None)

In [None]:
def _hex_palette(step_colormap, n_colors=8):
    """Sample ``n_colors`` evenly spaced hex colours from a branca step colormap."""
    return [step_colormap.rgb_hex_str(k / n_colors) for k in range(n_colors)]


# Footprints of the matched Sentinel-2 scenes (pastel) overlaid with the Planet footprints (saturated).
matched_items = [item for item in matches.values() if item]
matches_gdf = gpd.GeoDataFrame.from_features([item.to_dict() for item in matched_items], crs="EPSG:4326")
matches_gdf["s2id"] = [item.id for item in matched_items]

m = matches_gdf.explore(column="s2id", cmap=_hex_palette(cm.step.Pastel1_08))
footprints.explore(m=m, column="s2id", cmap=_hex_palette(cm.step.Set1_08))
m

In [None]:
# Pick the first footprint that has a matching Sentinel-2 item. `footprint` and `s2_item` are
# used by the loading and alignment cells below. Resetting `s2_item` first stops a value from a
# previous run leaking through when nothing matches this time.
s2_item = None
for _, footprint in footprints.iterrows():
    # pd.isna catches both None and NaN placeholders for "no match".
    if pd.isna(footprint["s2_item"]):
        print(f"No matching Sentinel-2 item found for {footprint['image_id']}.")
        continue

    s2_item = footprint["s2_item"]
    print(f"Found matching Sentinel-2 item {s2_item.id} for {footprint['image_id']}.")
    break

if s2_item is None:
    raise ValueError("None of the selected footprints has a matching Sentinel-2 item.")
s2_item

In [None]:
from darts_acquisition.planet import load_planet_masks, load_planet_scene
from darts_acquisition.s2 import load_s2_from_stac
from matplotlib import pyplot as plt

S2_CACHE_DIR = Path("/isipd/projects/p_aicore_pf/darts-nextgen/data/cache/s2stac")
PLOT_STRIDE = 20  # plot every 20th pixel per axis to keep the overview figure cheap

s2ds = load_s2_from_stac(s2_item, cache=S2_CACHE_DIR)
display(s2ds)
planetds = load_planet_scene(footprint.fpath)
planet_mask = load_planet_masks(footprint.fpath).quality_data_mask == 2
display(planetds)

# Row 0: Sentinel-2, row 1: Planet. Columns: red, green, blue, nir, quality mask.
band_cmaps = {"red": "Reds", "green": "Greens", "blue": "Blues", "nir": "gray"}
fig, axs = plt.subplots(2, 5, figsize=(20, 6))
for col, (band, band_cmap) in enumerate(band_cmaps.items()):
    s2ds[band][::PLOT_STRIDE, ::PLOT_STRIDE].plot(ax=axs[0, col], cmap=band_cmap, vmax=0.3)
    planetds[band][::PLOT_STRIDE, ::PLOT_STRIDE].plot(ax=axs[1, col], cmap=band_cmap)

s2ds.quality_data_mask[::PLOT_STRIDE, ::PLOT_STRIDE].plot(ax=axs[0, 4])
planet_mask[::PLOT_STRIDE, ::PLOT_STRIDE].plot(ax=axs[1, 4])

In [None]:
from darts_acquisition.utils.arosics import align

# Re-assert DEBUG so align()'s log output shows even if another cell changed the level.
logging.getLogger("darts").setLevel(logging.DEBUG)

# Artificially shift the Sentinel-2 scene (and its mask) by 10 px along x before aligning it
# to the Planet scene.
s2ds_shifted = s2ds.shift(x=-10)
s2_target_valid = s2ds.quality_data_mask.shift(x=-10) == 2

align(
    s2ds_shifted,
    planetds.astype("float32"),
    target_mask=s2_target_valid,
    reference_mask=planet_mask,
    resample_to="target",
    return_offset=True,
)

In [None]:
# Sanity check: align a copy of the scene shifted by 10 px along x back onto the original.
# The target mask is shifted together with the target data, matching how the target mask is
# built when aligning to Planet, so masked pixels line up with the image they describe.
s2ds_shifted = s2ds.shift(x=-10)
s2ds_aligned, oi = align(
    s2ds_shifted,
    s2ds,
    target_mask=s2ds.quality_data_mask.shift(x=-10) == 2,
    reference_mask=s2ds.quality_data_mask == 2,
    window_size=64,
    bands=["red", "green", "blue", "nir"],
    return_offset=True,
)
display(oi)
# Residual between the re-aligned and original red band (subsampled for plotting).
(s2ds_aligned.red - s2ds.red)[::20, ::20].plot()