In [None]:
from pathlib import Path

import ee
import numpy as np
import odc.geo.xr
import xarray as xr
from matplotlib import pyplot as plt

# Initialise the Earth Engine client; the GEE section below issues ee.* requests.
ee.Initialize()

In [None]:
def _plot_optical(tile, vmax=10000):
    slc = slice(None, None, 20)
    tile[["red", "green", "blue", "nir"]].sel(x=slc, y=slc).to_dataarray().plot(
        col="variable", col_wrap=2, vmin=0, vmax=vmax
    )


def _plot_optical_hist(tile, vmax=10000):
    slc = slice(None, None, 20)
    fig, axs = plt.subplots(2, 2, figsize=(6, 6))
    for i, band in enumerate(["red", "green", "blue", "nir"]):
        ax = axs[i // 2, i % 2]
        tile[band].plot.hist(ax=ax, bins=np.linspace(0, vmax, 100))
        ax.set_title(band)
    plt.tight_layout()

## Planet

In [None]:
# Planet
from darts_acquisition import load_planet_masks, load_planet_scene

# Planet scene under comparison. Another scene ID noted in this notebook is
# "20200805_014140_1039" (presumably in the same directory -- TODO confirm before switching).
PLANET_SCENES_DIR = Path("/isipd/projects/p_aicore_pf/initze/data/planet/planet_data_inference_grid/scenes")
PLANET_SCENE_ID = "20240822_053541_20_24f2"

fpath = PLANET_SCENES_DIR / PLANET_SCENE_ID
optical = load_planet_scene(fpath)
data_masks = load_planet_masks(fpath)
# Combine optical bands and data masks into one Dataset; later cells use its grid and red band.
tile_planet = xr.merge([optical, data_masks])
tile_planet

In [None]:
# Planet band previews (every 20th pixel), colour scale 0-10000.
_plot_optical(tile_planet)

In [None]:
# Same previews with the colour scale capped at 3000 to stretch contrast in the low range.
_plot_optical(tile_planet, vmax=3000)

In [None]:
# Planet per-band histograms over 0-10000.
_plot_optical_hist(tile_planet)

In [None]:
# Planet per-band histograms restricted to 0-3000 (larger values fall outside the bins).
_plot_optical_hist(tile_planet, vmax=3000)

## Sentinel-2 L2A: Google Earth Engine and Copernicus STAC

In [None]:
# Sentinel-2 SR (harmonized) scenes on GEE that intersect the Planet footprint, fall in the
# acquisition window and have less than 10 % cloudy pixels.
geom = ee.Geometry.Polygon(list(tile_planet.odc.geobox.geographic_extent.geom.exterior.coords))
ic = (
    ee.ImageCollection("COPERNICUS/S2_SR_HARMONIZED")
    .filterBounds(geom)
    # .filterDate("2020-08-01", "2020-08-02")
    .filterDate("2024-08-21", "2024-08-28")
    # ee.Filter.lt replaces the deprecated filterMetadata(..., "less_than", ...).
    .filter(ee.Filter.lt("CLOUDY_PIXEL_PERCENTAGE", 10))
)
s2gee_ids = ic.aggregate_array("system:index").getInfo()
# The second search result is used. NOTE(review): hard-coded choice -- document why this scene.
S2GEE_PICK = 1
if len(s2gee_ids) <= S2GEE_PICK:
    raise ValueError(f"Expected more than {S2GEE_PICK} GEE scenes, got {len(s2gee_ids)}: {s2gee_ids}")
s2geeid = s2gee_ids[S2GEE_PICK]
s2geeid

In [None]:
from darts_utils.tilecache import XarrayCacheManager


def _download_gee(img, cache):
    """Download a Sentinel-2 SR (harmonized) scene from Google Earth Engine, with caching.

    Args:
        img: Either the GEE ``system:index`` string of an image in
            ``COPERNICUS/S2_SR_HARMONIZED`` or an ``ee.Image``.
        cache: Cache location passed to ``XarrayCacheManager`` (a directory path in this notebook).

    Returns:
        xr.Dataset with variables ``blue``, ``green``, ``red``, ``nir`` and ``scl`` on
        ``y``/``x`` dimensions, requested at a 10 m scale. The acquisition time is kept
        in ``attrs["time"]``.
    """
    # GEE band name -> variable name used in this notebook. The key order is part of the
    # cache identifier below, so reordering it invalidates existing cache entries.
    bands_mapping: dict = {"B2": "blue", "B3": "green", "B4": "red", "B8": "nir", "SCL": "scl"}

    if isinstance(img, str):
        s2id = img
        img = ee.Image(f"COPERNICUS/S2_SR_HARMONIZED/{s2id}")
    else:
        # Keep only the trailing component of the full asset id.
        s2id = img.id().getInfo().split("/")[-1]

    img = img.select(list(bands_mapping))

    def _get_tile():
        # Read the whole image footprint through the "ee" xarray engine, in the CRS of the
        # first selected band, at a 10 m scale.
        ds_s2 = xr.open_dataset(
            img,
            engine="ee",
            geometry=img.geometry(),
            crs=img.select(0).projection().crs().getInfo(),
            scale=10,
        )
        # Keep the acquisition time as an attribute before dropping the single time step.
        ds_s2.attrs["time"] = str(ds_s2.time.values[0])
        ds_s2 = ds_s2.isel(time=0).drop_vars("time").rename({"X": "x", "Y": "y"}).transpose("y", "x")
        ds_s2 = ds_s2.odc.assign_crs(ds_s2.attrs["crs"])
        # Load the data into memory before handing the dataset to the cache.
        ds_s2.load()
        return ds_s2

    ds_s2 = XarrayCacheManager(cache).get_or_create(
        identifier=f"gee-s2srh-{s2id}-{''.join(bands_mapping.keys())}",
        creation_func=_get_tile,
        force=False,
    )

    return ds_s2.rename_vars(bands_mapping)


# Fetch (or load from cache) the GEE scene and align it to the Planet grid: crop to the
# Planet extent, resample onto the Planet geobox, and blank pixels where Planet's red band is null.
tile_s2gee = (
    _download_gee(s2geeid, "/isipd/projects/p_aicore_pf/darts-nextgen/data/cache/s2gee")
    .odc.crop(tile_planet.odc.geobox.extent)
    .odc.reproject(how=tile_planet.odc.geobox)
    .where(tile_planet.red.notnull())
)
tile_s2gee

In [None]:
# GEE Sentinel-2 band previews (every 20th pixel), colour scale 0-10000, for comparison with Planet.
_plot_optical(tile_s2gee)

In [None]:
# Same previews with the colour scale capped at 3000 to stretch contrast in the low range.
_plot_optical(tile_s2gee, vmax=3000)

In [None]:
# GEE Sentinel-2 per-band histograms over 0-10000.
_plot_optical_hist(tile_s2gee)

In [None]:
# GEE Sentinel-2 per-band histograms restricted to 0-3000 (larger values fall outside the bins).
_plot_optical_hist(tile_s2gee, vmax=3000)

In [None]:
from darts_acquisition.s2 import search_s2_stac

# Search the Copernicus STAC over the Planet footprint; the window must include the GEE scene's date.
# s2stacitems = search_s2_stac(tile_planet.odc.geobox.geographic_extent, "2024-08-01", "2024-08-03", 10)
s2stacitems = search_s2_stac(tile_planet.odc.geobox.geographic_extent, "2024-08-21", "2024-08-24", 10)

# Match the GEE scene to its STAC item. GEE S2 ids are assumed to look like
# "<sensing time>_<generation time>_<MGRS tile>" -- confirm. Matching on the tile alone can select
# another overpass of the same tile inside the search window, so the sensing timestamp must match too.
s2gee_sensing, s2gee_tile = s2geeid.split("_")[0], s2geeid.split("_")[-1]
s2stac_matches = [item_id for item_id in s2stacitems if s2gee_tile in item_id and s2gee_sensing in item_id]
if not s2stac_matches:
    raise ValueError(f"No STAC item matches GEE scene {s2geeid}; candidates: {list(s2stacitems)}")
s2id = s2stac_matches[0]
s2stacitems[s2id]

In [None]:
from collections.abc import MutableMapping

from darts_acquisition.utils.copernicus import init_copernicus
from odc.stac import stac_load
from pystac import Item
from pystac_client import Client


def _flatten_dict(d: MutableMapping, parent_key: str = "", sep: str = ".") -> MutableMapping:
    items = []
    for k, v in d.items():
        new_key = parent_key + sep + k if parent_key else k
        if isinstance(v, MutableMapping):
            items.extend(_flatten_dict(v, new_key, sep=sep).items())
        else:
            items.append((new_key, v))
    return dict(items)


def _download_stac(s2item, cache):
    """Download a Sentinel-2 L2A scene from the Copernicus Data Space STAC API, with caching.

    Args:
        s2item: A ``pystac.Item`` or a STAC item id string. A string id is resolved against
            the ``sentinel-2-l2a`` collection, only when the cache has no entry for it.
        cache: Cache location passed to ``XarrayCacheManager`` (a directory path in this notebook).

    Returns:
        xr.Dataset with variables ``blue``, ``green``, ``red``, ``nir`` and ``scl`` loaded with
        ``odc.stac.stac_load`` in UTM at 10 m resolution. The STAC item properties are stored
        flattened in ``attrs``, with the acquisition time in ``attrs["time"]``.

    Raises:
        ValueError: If a string id is not found in the STAC collection.
    """
    # STAC asset name -> variable name used in this notebook. The key order is part of the
    # cache identifier below, so reordering it invalidates existing cache entries.
    # SCL is read from its 20 m asset and loaded onto the same 10 m grid as the other bands.
    bands_mapping: dict = {
        "B02_10m": "blue",
        "B03_10m": "green",
        "B04_10m": "red",
        "B08_10m": "nir",
        "SCL_20m": "scl",
    }
    s2id = s2item.id if isinstance(s2item, Item) else s2item

    def _get_tile():
        nonlocal s2item

        if isinstance(s2item, str):
            # Only an id was given: resolve it to a full STAC item.
            catalog = Client.open("https://stac.dataspace.copernicus.eu/v1/")
            search = catalog.search(collections=["sentinel-2-l2a"], ids=[s2id])
            s2item = next(search.items(), None)
            if s2item is None:
                raise ValueError(f"STAC item {s2id!r} not found in collection 'sentinel-2-l2a'")

        # We can't use xpystac here, because they enforce chunking of 1024x1024, which results in long loading times
        # and a potential AWS limit error.
        init_copernicus(profile_name="default")
        ds_s2 = stac_load([s2item], bands=list(bands_mapping), crs="utm", resolution=10)

        # Store the item properties as attributes: nested properties become dotted keys, and
        # booleans are converted to int, since they are not supported in netcdf.
        ds_s2.attrs = {
            key: int(value) if isinstance(value, bool) else value
            for key, value in _flatten_dict(s2item.properties).items()
        }
        ds_s2.attrs["time"] = str(ds_s2.time.values[0])
        return ds_s2.isel(time=0).drop_vars("time")

    ds_s2 = XarrayCacheManager(cache).get_or_create(
        identifier=f"stac-s2l2a-{s2id}-{''.join(bands_mapping.keys())}", creation_func=_get_tile, force=False
    )

    return ds_s2.rename_vars(bands_mapping)


# Fetch (or load from cache) the STAC scene and align it to the Planet grid: crop to the
# Planet extent, resample onto the Planet geobox, and blank pixels where Planet's red band is null.
tile_s2 = (
    _download_stac(s2stacitems[s2id], "/isipd/projects/p_aicore_pf/darts-nextgen/data/cache/s2stac")
    .odc.crop(tile_planet.odc.geobox.extent)
    .odc.reproject(how=tile_planet.odc.geobox)
    .where(tile_planet.red.notnull())
)
tile_s2

In [None]:
# STAC Sentinel-2 band previews (every 20th pixel), colour scale 0-10000, for comparison with GEE and Planet.
_plot_optical(tile_s2)

In [None]:
# Same previews with the colour scale capped at 3000 to stretch contrast in the low range.
_plot_optical(tile_s2, vmax=3000)

In [None]:
# STAC Sentinel-2 per-band histograms over 0-10000.
_plot_optical_hist(tile_s2)

In [None]:
# STAC Sentinel-2 per-band histograms restricted to 0-3000 (larger values fall outside the bins).
_plot_optical_hist(tile_s2, vmax=3000)