In [None]:
import datetime as dt
from enum import Enum

import matplotlib.pyplot as plt
import odc.stac as odc_stac
import pystac
import pystac_client
import xarray as xr
from odc.geo.geobox import GeoBox

In [None]:
type Bounds = tuple[float, float, float, float]


def get_items(bounds: Bounds, timerange: str) -> pystac.ItemCollection:
    """Get Sentinel-2 items for a given bounding box and time range."""
    return (
        pystac_client.Client.open("https://earth-search.aws.element84.com/v1")
        .search(
            bbox=bounds,
            collections=["sentinel-2-l2a"],
            datetime=timerange,
            limit=100,
        )
        .item_collection()
    )

In [None]:
# Common parameters
# Output grid resolution. `epsg` is 4326 (lat/lon), so `dx` is in degrees:
# 0.002 deg is roughly 210-220 m at this latitude. 0.0006 deg (~60 m) gives a
# finer grid with ~11x as many pixels.
dx: float = 0.002
epsg = 4326

# Set Spatial extent (degrees; southern latitudes are negative)
latmin: float = -19.6
latmax: float = -18.1
lonmin: float = 32.9
lonmax: float = 34.4
# STAC bbox order: (min lon, min lat, max lon, max lat).
bounds: Bounds = (lonmin, latmin, lonmax, latmax)
assert lonmin < lonmax and latmin < latmax, "bounds must be (min, min, max, max)"


# Set Temporal extent: the same calendar window is compared in two years.
year_before: int = 2017
year_after: int = 2020
month_start: int = 11
month_end: int = 12
day_end: int = 31  # last day of the window; must be valid for `month_end`


def make_timerange(
    year: int,
    month_start: int,
    month_end: int,
    day_end: int,
) -> str:
    """Build an ISO 8601 ``start/end`` interval for one year's window.

    The window runs from the 1st of ``month_start`` to ``day_end`` of
    ``month_end``. Going through ``dt.date`` zero-pads months/days (STAC
    requires ``YYYY-MM-DD``) and raises ``ValueError`` for impossible dates.
    """
    start = dt.date(year, month_start, 1)
    end = dt.date(year, month_end, day_end)
    return f"{start.isoformat()}/{end.isoformat()}"


timerange_after: str = make_timerange(year_after, month_start, month_end, day_end)
timerange_before: str = make_timerange(year_before, month_start, month_end, day_end)

# Search for Sentinel-2 data
items_before = get_items(bounds, timerange_before)
items_after = get_items(bounds, timerange_after)
# Fail fast with a readable message: an empty search would otherwise surface
# as a less obvious error (or an empty mosaic) further down the notebook.
assert len(items_before) > 0, f"No Sentinel-2 items found for {timerange_before}"
assert len(items_after) > 0, f"No Sentinel-2 items found for {timerange_after}"

In [None]:
# Shared load settings so both periods are loaded identically and stay comparable.
BANDS: list[str] = ["scl", "red", "green", "blue", "nir"]
CHUNKS: dict[str, int] = {"time": 5, "x": 600, "y": 600}

# Common target grid for both periods.
geobox = GeoBox.from_bbox(bounds, crs=f"epsg:{epsg}", resolution=dx)


def load_datacube(items: pystac.ItemCollection) -> xr.Dataset:
    """Load ``items`` onto the shared ``geobox`` with the shared bands/chunks.

    Passing ``chunks`` makes odc-stac return lazily evaluated (dask) arrays.
    """
    return odc_stac.load(
        items,
        bands=BANDS,
        chunks=CHUNKS,
        geobox=geobox,
    )


dc_before = load_datacube(items_before)
dc_after = load_datacube(items_after)

In [None]:
# Preprocess the data


class SCLValues(Enum):
    """Enum for Sentinel-2 Scene Classification Layer (SCL) values.

    Member values are the integer class codes stored in the ``scl`` band.
    """

    NO_DATA = 0
    SATURATED_DEFECTIVE = 1
    DARK_AREA = 2
    CLOUD_SHADOW = 3
    VEGETATION = 4
    BARE_SOIL = 5  # "not vegetated"
    WATER = 6
    CLOUD_LOW_PROB = 7
    # Current ESA Sen2Cor docs label class 7 "Unclassified"; this alias allows
    # that name while CLOUD_LOW_PROB remains the canonical member.
    UNCLASSIFIED = 7
    CLOUD_MEDIUM_PROB = 8
    CLOUD_HIGH_PROB = 9
    CIRRUS = 10
    SNOW_ICE = 11


def is_valid_pixel(data: xr.DataArray) -> xr.DataArray:
    """Check if the pixel is valid based on the SCL band.

    Keeps vegetation, not-vegetated (bare soil), water and snow/ice pixels;
    no-data, defective, dark-area, cloud-shadow, cloud, cirrus and class-7
    pixels are rejected.

    Args:
        data: SCL band holding integer class codes (see ``SCLValues``).

    Returns:
        Boolean array, True where the pixel's class is in the keep-list.
    """
    keep = [
        SCLValues.VEGETATION.value,
        SCLValues.BARE_SOIL.value,
        SCLValues.WATER.value,
        SCLValues.SNOW_ICE.value,
    ]
    return data.isin(keep)


# Attach a per-observation validity mask derived from the SCL band.
dc_before = dc_before.assign(valid=is_valid_pixel(dc_before["scl"]))
dc_after = dc_after.assign(valid=is_valid_pixel(dc_after["scl"]))


In [None]:
def _valid_median(cube: xr.Dataset) -> xr.Dataset:
    """Per-pixel median over time, using only observations flagged ``valid``."""
    masked = cube.where(cube["valid"])  # invalid observations become NaN
    return masked.median(dim="time", skipna=True)


mosaic_before = _valid_median(dc_before)
mosaic_after = _valid_median(dc_after)


def normalized_difference(
    band1: xr.DataArray,
    band2: xr.DataArray,
) -> xr.DataArray:
    """Calculate the normalized difference ``(band1 - band2) / (band1 + band2)``.

    Integer/boolean inputs (e.g. raw uint16 reflectances) are promoted to
    float64 first; unsigned subtraction would otherwise wrap around instead of
    going negative. Float inputs are used unchanged. Where both bands are 0 the
    result is NaN (0 / 0).

    Args:
        band1: First band (e.g. NIR for NDVI).
        band2: Second band (e.g. red for NDVI).

    Returns:
        The normalized difference; in [-1, 1] when both inputs are non-negative.
    """
    if band1.dtype.kind in "biu":
        band1 = band1.astype("float64")
    if band2.dtype.kind in "biu":
        band2 = band2.astype("float64")
    return (band1 - band2) / (band1 + band2)


# NDVI = (NIR - red) / (NIR + red), computed on each period's median mosaic.
mosaic_before = mosaic_before.assign(
    ndvi=normalized_difference(mosaic_before["nir"], mosaic_before["red"]),
)
mosaic_after = mosaic_after.assign(
    ndvi=normalized_difference(mosaic_after["nir"], mosaic_after["red"]),
)

In [None]:
def get_mosaic_time_title(
    year: int,
    month_start: int,
    month_end: int,
    day_end: int,
) -> str:
    """Generate a title for the mosaic based on the time range."""
    start_date = dt.datetime(
        year=year,
        month=month_start,
        day=1,
        tzinfo=dt.UTC,
    )
    end_date = dt.datetime(
        year=year,
        month=month_end,
        day=day_end,
        tzinfo=dt.UTC,
    )
    fmt_start: str = start_date.strftime("%d.%b")
    fmt_end: str = end_date.strftime("%d.%b %Y")
    return f"{fmt_start} - {fmt_end}"


# Plot titles for each period's single-year window.
timestamp_title_before, timestamp_title_after = (
    get_mosaic_time_title(year, month_start, month_end, day_end)
    for year in (year_before, year_after)
)

In [None]:
# Classification of "Forest"
# Could be more sophisticated...
threshold: float = 0.6


def _encode_forest(ndvi: xr.DataArray, min_ndvi: float, bit: int) -> xr.DataArray:
    """Encode an NDVI-threshold forest mask into a single bit.

    Returns ``1 << bit`` where ``ndvi >= min_ndvi``, 0 where it is below, and
    NaN where ``ndvi`` is NaN (e.g. no valid observation in the mosaic).
    Without that mask, NaN compares False, is encoded as "non-forest", and
    later shows up as spurious forest loss/gain.
    """
    encoded = (ndvi >= min_ndvi) * (1 << bit)
    return encoded.where(ndvi.notnull())


# Distinct bits per period (before -> 1, after -> 2) so that the sum
# forest_before + forest_after identifies each before/after combination.
forest_before = _encode_forest(mosaic_before["ndvi"], threshold, bit=0)
forest_after = _encode_forest(mosaic_after["ndvi"], threshold, bit=1)

In [None]:
# Forest mask for the "before" period. With vmin=-1/vmax=1 on RdYlGn, forest
# (1) renders green and non-forest (0) renders yellow.
# `threshold` is taken from the classification cell (not redefined here, so the
# title always matches the threshold actually used).
fig, ax = plt.subplots()
forest_before.plot.imshow(
    ax=ax,
    robust=True,
    vmin=-1,
    vmax=1,
    cmap="RdYlGn",
)
ax.set_title(f"NDVI >= {threshold}  {timestamp_title_before}")
plt.show()

In [None]:
# Forest mask for the "after" period. forest_after encodes forest as 2 (bit 1);
# with vmax=1 that value is clipped to the green end of RdYlGn, matching the
# "before" plot.
fig, ax = plt.subplots()
forest_after.plot.imshow(
    ax=ax,
    robust=True,
    vmin=-1,
    vmax=1,
    cmap="RdYlGn",
)
ax.set_title(f"NDVI >= {threshold}  {timestamp_title_after}")
plt.show()

In [None]:
def get_diff_time_title(
    year_start: int,
    year_end: int,
    month_start: int,
    month_end: int,
    day_end: int,
) -> str:
    """Generate a title for a two-year comparison over the same calendar window.

    Example: ``(2017, 2020, 11, 12, 31)`` -> ``"2017 - 2020 (01.Nov - 31.Dec)"``.
    Named differently from ``get_mosaic_time_title`` on purpose: reusing that
    name would shadow the single-year 4-argument helper and break re-running
    the earlier title cell.
    """
    # Only day and month are printed; year_start is used just to build
    # (and validate) the dates.
    start_date = dt.datetime(
        year=year_start,
        month=month_start,
        day=1,
        tzinfo=dt.UTC,
    )
    end_date = dt.datetime(
        year=year_start,
        month=month_end,
        day=day_end,
        tzinfo=dt.UTC,
    )
    fmt_start: str = start_date.strftime("%d.%b")
    fmt_end: str = end_date.strftime("%d.%b")
    return f"{year_start} - {year_end} ({fmt_start} - {fmt_end})"


diff_time_title = get_diff_time_title(
    year_before,
    year_after,
    month_start,
    month_end,
    day_end,
)
diff_time_title

In [None]:
# Difference
# forest_before is 1/0 (bit 0) and forest_after is 2/0 (bit 1), so their sum is
# 0 = never forest, 1 = forest before only (loss), 2 = forest after only (gain),
# 3 = forest in both periods. Only the loss code is kept for the plot.
FOREST_LOSS_CODE: int = 1
forest_codes = forest_before + forest_after
diff = forest_codes.where(forest_codes == FOREST_LOSS_CODE)
fig, ax = plt.subplots()
diff.plot.imshow(ax=ax, robust=True, cmap="Reds")
ax.set_title(f"Forest loss (NDVI >= {threshold}) {diff_time_title}")
plt.show()