In [None]:
from pathlib import Path

import geopandas as gpd
import numpy as np
import pandas as pd
from dask.distributed import Client as DaskClient
from odc.stac import load, configure_s3_access
from pystac_client import Client

from dea_tools.coastal import pixel_tides
from dea_tools.spatial import subpixel_contours, points_on_line
from coastlines.utils import tide_cutoffs
from coastlines.vector import annual_movements, calculate_regressions

In [None]:
# Open the Earth Search (Element 84) STAC API, v1 endpoint
catalog = "https://earth-search.aws.element84.com/v1"
client = Client.open(catalog)

# Enable S3 access with requester-pays turned on.
# NOTE: this call fails unless AWS credentials are configured.
_ = configure_s3_access(cloud_defaults=True, requester_pays=True)

# Local dask cluster: 4 workers, 4 threads per worker
dask_client = DaskClient(
    n_workers=4,
    threads_per_worker=4,
)
dask_client

In [None]:
# Area of interest. Find a location on Google Maps and copy its coordinates
# by right-clicking on the map and clicking the coordinates.
# Coordinates are (latitude, longitude), i.e. Y then X.
# coords = 20.775, 106.780  # Near Haiphong
coords = 12.293, 109.225  # Near Phuong Vinh Hoa
buffer = 0.05  # half-width of the box around `coords`, in the same units (degrees)
# (min lon, min lat, max lon, max lat)
bbox = (coords[1] - buffer, coords[0] - buffer, coords[1] + buffer, coords[0] + buffer)

# Display stretch used for the raw Landsat RGB previews
landsat_stretch = dict(vmin=7500, vmax=18000)

# STAC search date range, "start/end"
datetime = "2019/2024"

# Tide model data lives in <home>/tide_models. pathlib never expands "~" on
# its own, so Path("~") would produce a literal, cwd-relative "~" directory;
# resolve the home directory explicitly instead.
home = Path.home()
tide_data_location = str(home / "tide_models")

In [None]:
# Search the catalog for Landsat Collection 2 Level-2 scenes that intersect
# the area of interest within the date range
search = client.search(
    collections=["landsat-c2-l2"],
    bbox=bbox,
    datetime=datetime,
)
items = search.item_collection()

print(f"Found {len(items)} items")

In [None]:
# Load the found scenes as a chunked (dask-backed) xarray Dataset. Scenes from
# the same solar day are grouped into a single time step.
# NOTE(review): confirm odc.stac.load actually uses the `collection` keyword;
# it is not one of its documented parameters.
data = load(
    items,
    measurements=["red", "green", "blue", "nir08", "swir16", "qa_pixel"],
    bbox=bbox,
    collection="landsat-c2-l2",
    groupby="solar_day",
    chunks={"x": 2048, "y": 2048},
)
data

In [None]:
# Preview the first four time steps (before masking) as RGB images
raw_rgb = data[["red", "green", "blue"]].isel(time=list(range(4)))
raw_rgb.to_array().plot.imshow(
    col="time",
    col_wrap=2,
    size=6,
    **landsat_stretch,
)

In [None]:
# Cloud mask from the QA_PIXEL band. Bits are numbered from 0 at the right.
# In the Landsat Collection 2 QA_PIXEL layout, bit 3 flags cloud and bit 4
# flags cloud shadow. A pixel is masked if either bit is set.
bitflags = (1 << 3) | (1 << 4)  # == 0b00011000
cloud_mask = (data.qa_pixel & bitflags) != 0

# Pixels where the red band holds its nodata value
nodata = data.red == data.red.odc.nodata

# Keep pixels that are neither cloudy nor nodata; everything else becomes NaN
mask = cloud_mask | nodata
masked = data.where(~mask, other=np.nan)

In [None]:
# Same four time steps after cloud / nodata masking
masked_rgb = masked[["red", "green", "blue"]].isel(time=list(range(4)))
masked_rgb.to_array().plot.imshow(
    col="time",
    col_wrap=2,
    size=6,
    **landsat_stretch,
)

In [None]:
# Model tide heights for each time step with the FES2022 model.
# `tides_lowres` is the modelled tide field; with `resample=True`,
# `tides_hires` is presumably that field resampled onto the pixel grid of
# `masked` — confirm against the dea_tools docs.
tides_hires, tides_lowres = pixel_tides(
    masked, resample=True, directory=tide_data_location, model="FES2022", dask_compute=True
)

# Tide-height window (centred on 0.0) used to select observations
tide_cutoff_min, tide_cutoff_max = tide_cutoffs(masked, tides_lowres, tide_centre=0.0)

# Keep only time steps where at least one pixel is inside the tide window
tide_bool = (tides_hires >= tide_cutoff_min) & (tides_hires <= tide_cutoff_max)
data_filtered = masked.sel(time=tide_bool.any(dim=["x", "y"]))

# Drop pixels outside the tide window. Start from `masked` (not the raw
# `data`) so the cloud / nodata mask carries through to the water-index
# composites downstream.
data_tide_masked = data_filtered.where(tide_bool)

data_tide_masked

In [None]:
# Water indices: MNDWI (green vs SWIR1), NDWI (green vs NIR), and their mean
# ("combined"), which is the index the shoreline is extracted from.
# NOTE(review): the bands appear to be the raw scaled integers (no
# scale/offset applied at load). The sign of each index is largely unaffected,
# but magnitudes — and therefore the zero crossing of the averaged "combined"
# index — may differ from a surface-reflectance calculation. Confirm intended.
def _normalized_difference(a, b):
    """Return the normalized difference (a - b) / (a + b)."""
    return (a - b) / (a + b)


data_tide_masked["mndwi"] = _normalized_difference(data_tide_masked.green, data_tide_masked.swir16)
data_tide_masked["ndwi"] = _normalized_difference(data_tide_masked.green, data_tide_masked.nir08)
data_tide_masked["combined"] = (data_tide_masked.mndwi + data_tide_masked.ndwi) / 2

# Annual median of the combined index, computed eagerly
combined_by_year = (
    data_tide_masked.combined
    .groupby("time.year")
    .median()
    .to_dataset(name="combined")
    .compute()
)
combined_by_year

In [None]:
# One panel per year of the annual median combined water index
combined_by_year["combined"].plot.imshow(
    col="year",
    col_wrap=2,
    size=6,
    cmap="RdBu",
    robust=True,
)

In [None]:
# Extract the 0.0 contour of each year's combined water index as that year's
# shoreline.
# NOTE(review): each group from `groupby("year")` appears to keep a length-1
# "year" dimension, hence the `.squeeze()`; subpixel_contours presumably
# expects a 2D array when called this way. Check whether its `dim` argument
# could replace the loop.
contour_frames = []
for year, year_ds in combined_by_year.groupby("year"):
    contours = subpixel_contours(
        da=year_ds.combined.squeeze(),
        z_values=0.0,
        crs=combined_by_year.geobox.crs,
        min_vertices=15,
    )
    # Tag each contour row with its own year, so the year/geometry pairing
    # stays correct even if a year yields zero or several rows instead of
    # exactly one.
    contour_frames.append(
        gpd.GeoDataFrame(
            {"year": [year] * len(contours)},
            geometry=contours.geometry.values,
            crs=contours.crs,
        )
    )

contour_gdf = pd.concat(contour_frames, ignore_index=True).set_index("year")

contour_gdf.reset_index().explore(
    column="year",
    cmap="magma",
)

In [None]:
# Sample points every 30 metres along the 2023 shoreline, the baseline year
# used for measuring movements below.
# NOTE(review): the search range runs to 2024, so 2023 may not be the most
# recent year available.
points_gdf = points_on_line(
    contour_gdf,
    index=2023,
    distance=30,
)
points_gdf.plot(markersize=3)

In [None]:
# Compare each year's shoreline against the 2023 baseline at every point.
# calculate_regressions presumably fits a trend through time; the next cell
# reads its rate_time / se_time / sig_time outputs.
points_gdf = annual_movements(
    points_gdf,
    contours_gdf=contour_gdf,
    yearly_ds=combined_by_year,
    baseline_year=2023,
    water_index="combined",
)
points_gdf = calculate_regressions(points_gdf=points_gdf)

points_gdf

In [None]:
# Human-friendly HTML labels for the map tooltip
first_year = contour_gdf.index.min()


def _describe_change(row):
    """Return an HTML tooltip describing one point's rate of shoreline change.

    The direction word ("retreated"/"grown") carries the sign, so the rate is
    reported as a magnitude to avoid "retreated by -1.23 m".
    """
    direction = "retreated" if row.rate_time < 0 else "grown"
    return (
        f"<h4>This coastline has <b>{direction}</b> "
        f"by<br><b>{abs(row.rate_time):.2f} m (±{row.se_time:.1f}) per year</b> since "
        f"<b>{first_year}</b></h4>"
    )


points_gdf["Coastal change"] = points_gdf.apply(_describe_change, axis=1)

# Points whose significance value is above 0.05, or missing (NaN), get a
# neutral label instead of a possibly "nan m" rate.
not_significant = ~(points_gdf.sig_time <= 0.05)
points_gdf.loc[not_significant, "Coastal change"] = "<h4>No significant trend of retreat or growth</h4>"

# Annual shorelines over ESRI satellite imagery
m = contour_gdf.reset_index().explore(
    column="year",
    cmap="inferno",
    categorical=True,
    tiles="https://server.arcgisonline.com/ArcGIS/rest/services/World_Imagery/MapServer/tile/{z}/{y}/{x}",
    attr="ESRI WorldImagery",
    tooltip=False,
    style_kwds={"opacity": 0.5},
)

# Rate-of-change points on top. `explore` has no `markersize` option (it was
# silently ignored); the point radius is set through `marker_kwds`.
points_gdf.explore(
    m=m,
    column="rate_time",
    cmap="RdBu",
    marker_kwds={"radius": 5},
    tooltip="Coastal change",
)