In [None]:
import os
import sys

os.environ['USE_PYGEOS'] = '0'
import json

import folium
import geohash
import geopandas
import holoviews as hv
import hvplot.xarray  # noqa: F401
import numpy as np
import odc.geo.xr
import pystac
import rioxarray  # noqa: F401
import s3fs
import shapely
import xarray as xr
from dea_tools.spatial import xr_rasterize, xr_vectorize
from fsspec.implementations.http import HTTPFileSystem
from holoviews import opts
from odc.stac import configure_rio, load

# Make the parent directory importable so that the `utils` module can be found.
# Note that '__file__' below is a string literal, not the __file__ variable:
# os.path.abspath() resolves it against the current working directory, so
# NOTEBOOK_DIR ends up being the kernel's working directory
# (presumably the folder containing this notebook -- confirm if run from elsewhere).
NOTEBOOK_DIR = os.path.dirname(os.path.abspath('__file__'))
sys.path.append(os.path.dirname(NOTEBOOK_DIR))

from utils import get_rgb_dataset

In [None]:
import re

import odc.stac

# Oldest odc.stac release this notebook is known to work with.
MIN_ODC_STAC_VERSION = (0, 3, 6)

# Compare the leading numeric components as integers so that any release at or
# above the minimum is accepted (string comparison would mis-order "0.3.10").
# Suffixes such as "rc1" or ".dev0" are ignored.
_numeric_prefix = re.match(r"\d+(?:\.\d+)*", odc.stac.__version__)
_installed_version = (
    tuple(int(part) for part in _numeric_prefix.group().split("."))
    if _numeric_prefix
    else ()  # unparseable version: treated as too old
)

if _installed_version < MIN_ODC_STAC_VERSION:
    raise Exception(f"You need to use odc.stac version 0.3.6 or greater. You have {odc.stac.__version__}")

In [None]:
s3_uri = "s3://files.auspatious.com/hsi_example/TD1_004930_20230205_L2A_20230224_03001065_COG.stac-item.json"

# Read the STAC item document from S3 anonymously using s3fs
s3 = s3fs.S3FileSystem(anon=True)
with s3.open(s3_uri, "rt") as f:
    stac_dict = json.load(f)
item = pystac.read_dict(stac_dict)

# Band metadata of the reflectance asset. Slice this list to load only a
# subset of bands if the full load takes too long.
eo_bands_subset = item.assets["reflectance"].extra_fields["eo:bands"]

# Load the data, telling rasterio to not sign requests
configure_rio(
    cloud_defaults=True,
    aws={"aws_unsigned": True},
    AWS_S3_ENDPOINT="s3.ap-southeast-2.amazonaws.com",
)
ds = load(
    [item],
    measurements=[band["name"] for band in eo_bands_subset],
    chunks={"bands": 1, "longitude": 1200, "latitude": 1200},
)

# No need for time
ds = ds.squeeze("time")

# Stack up the bands, so we have a multi-dimensional raster instead
ds_stacked = ds.to_array("bands")

# Label every layer with the numeric value of its band description
# (presumably the centre wavelength -- confirm against the STAC item), in the
# same order the bands were requested, and only then sort the whole array by
# that coordinate. Sorting the labels on their own would silently attach the
# wrong wavelength to the data whenever the metadata is not already ascending.
# Assumes load() returns the variables in the order they were requested.
band_wavelengths = [float(band["description"]) for band in eo_bands_subset]
ds = (
    ds_stacked.assign_coords(bands=band_wavelengths)
    .sortby("bands")
    .to_dataset(name="reflectance")
)
# Sorted list of band values, kept under its existing name
bands = sorted(band_wavelengths)

# mask 0 as nan
ds = ds.where(ds != 0)

# Load all the data. Should take less than 2 minutes
ds = ds.compute()

ds


In [None]:
# Build a water layer from a normalised difference of the green and NIR bands.
# Band values and threshold picked from
# https://en.wikipedia.org/wiki/Normalized_difference_water_index
green = ds.reflectance.sel(bands=559, method="nearest").astype("float32")
nir = ds.reflectance.sel(bands=864, method="nearest").astype("float32")

# Index value per pixel; pixels above 0.2 are flagged as water
ndwi = (green - nir) / (green + nir)
water = ndwi > 0.2

# Keep flagged pixels only; where() masks everything else.
# NOTE(review): the fillna() call looks redundant on a boolean array but
# appears to influence the dtype of the result -- do not drop it without checking.
ds["water"] = water.fillna(float("nan")).where(water)

In [None]:
# Plot the water layer with an equal aspect ratio as a quick visual check
ds.water.hvplot(aspect="equal")

In [None]:
MIN_AREA = 80  # Smallest polygon area to keep, in hectares


def add_geohash(row):
    """Return the precision-9 geohash of the centroid of ``row.geometry``.

    The centroid's ``y`` is passed to ``geohash.encode`` first, followed by
    its ``x``.
    """
    centre = row.geometry.centroid
    return geohash.encode(centre.y, centre.x, precision=9)

    
# Vectorise the water layer, keeping only the pixels whose value is 1
water_polygons = xr_vectorize(ds.water, crs="epsg:4326", mask=ds.water.values == 1)

# Polygon area in hectares: measured after reprojecting to EPSG:3577
# (units assumed to be metres), then divided by 10,000
water_polygons["area"] = water_polygons.to_crs("epsg:3577").area / 10000

# Discard every polygon smaller than MIN_AREA
too_small = water_polygons.index[water_polygons["area"] < MIN_AREA]
water_polygons = water_polygons.drop(index=too_small)

# Geohash (precision 9) of each remaining polygon's centroid
geohashes = [add_geohash(polygon) for _, polygon in water_polygons.iterrows()]
water_polygons["geohash"] = geohashes

# Sequential 1-based id per polygon
water_polygons["id"] = range(1, len(water_polygons) + 1)

# Show us what we've got
print(f"Found {len(water_polygons)} water polygons that are larger than {MIN_AREA} hectare(s)")

In [None]:
# View the water layer on an interactive map

# Reduce the polygons by a small amount (in degrees)
# 0.001 degrees is around 100 m
SHRINK_AMOUNT = 0.001


def _water_style(_feature):
    """Return the Leaflet path style applied to every water polygon.

    Leaflet option names are case-sensitive: the outline colour is ``color``
    (a capitalised ``Color`` key is ignored).
    """
    return {"fillColor": "blue", "color": "blue"}


m = folium.Map(control_scale=True, tiles=None)

# One GeoJson layer per polygon, so each gets its own tooltip and popup
for _, row in water_polygons.iterrows():
    # Shrink only for display; water_polygons itself is left untouched here
    shrunk_geometry = row.geometry.buffer(-1 * SHRINK_AMOUNT)
    geojson = folium.GeoJson(
        data=json.dumps(shapely.geometry.mapping(shrunk_geometry)),
        style_function=_water_style,
        tooltip=f"{row.geohash}",
    )
    folium.Popup(
        f"<p><strong>geohash:</strong> {row.geohash}<br><strong>area:</strong> {row['area']:.3f} Ha</p>"
    ).add_to(geojson)
    geojson.add_to(m)

# Zoom map to the extent of the dataset
m.fit_bounds(ds.odc.map_bounds())

# Satellite imagery basemap (the map was created without default tiles)
tile = folium.TileLayer(
    tiles="https://server.arcgisonline.com/ArcGIS/rest/services/World_Imagery/MapServer/tile/{z}/{y}/{x}",
    attr="Esri",
    name="Esri Satellite",
    control=True,
).add_to(m)

folium.LayerControl().add_to(m)
display(m)

In [None]:
# Rasterise the polygons again, so we can join on the geohash later.
# The shrunk geometries go into a copy: overwriting water_polygons here would
# shrink the polygons by a further SHRINK_AMOUNT every time this cell is re-run.
water_polygons_shrunk = water_polygons.copy()
water_polygons_shrunk.geometry = water_polygons.geometry.buffer(-1 * SHRINK_AMOUNT)
water_raster = xr_rasterize(water_polygons_shrunk, ds, attribute_col="id", crs="epsg:4326")

# Join the rasterised polygon ids to the dataset
ds["id"] = xr.DataArray(water_raster, dims=("latitude", "longitude"))

# Start from a grid of empty strings that can hold up to 9 characters
ds["geohash"] = xr.DataArray(
    np.full((ds.latitude.size, ds.longitude.size), "", dtype="U9"),
    dims=("latitude", "longitude"),
)

# Write each polygon's geohash into the pixels carrying its id.
# DataArray.where keeps the existing value where the condition is True and
# substitutes the second argument elsewhere, hence the "!=" comparison.
for polygon_id, polygon_geohash in zip(water_polygons["id"], water_polygons["geohash"]):
    ds["geohash"] = ds.geohash.where(ds.id != polygon_id, polygon_geohash)

# Mask the pixels that were not assigned a geohash
ds["geohash"] = ds.geohash.where(ds.geohash != "", drop=False)
del ds["id"]

In [None]:
%%capture --no-stdout

means = ds.groupby("geohash").mean()
std_dev = ds.groupby("geohash").std()
min = ds.groupby("geohash").min()
max = ds.groupby("geohash").max()

# Create a new dataset with the mean, standard deviation, min and max values
# for each geohash
water_summaries = xr.Dataset(
    {
        "mean": means.reflectance,
        "std_dev": std_dev.reflectance,
        "min": min.reflectance,
        "max": max.reflectance,
    }
)

In [None]:
# Mean and std_dev plots: one small panel per geohash
color_cycle = hv.Cycle("Category20")

plots = []
# The loop variable is not called `geohash`, because that would overwrite the
# imported `geohash` module and break add_geohash() if its cell is run again.
for geohash_id in water_summaries.geohash.values:
    summary = water_summaries.sel(geohash=geohash_id)

    # Shaded band of mean +/- std_dev with the mean curve drawn on top
    plots.append(
        (
            hv.Spread(
                summary,
                vdims=["mean", "std_dev", "std_dev"],
                label=f"{geohash_id}",
            )
            * hv.Curve(
                summary,
                vdims="mean",
                label=f"{geohash_id}",
            )
        )
    )

hv.Layout(plots).opts(
    opts.Spread(color=color_cycle, show_legend=True),
    opts.Curve(color=color_cycle, show_legend=True),
    opts.Overlay(
        show_title=True, frame_width=200, frame_height=50, show_legend=False, yaxis=None
    ),
).cols(4)

In [None]:
# Calculate absorption depth, to be plotted per geohash.
# Band values used to pick the nearest available band in each case.
absorption_band = 627
reference_band = 560
reference_band2 = 648

reflectance = ds.reflectance
absorption = reflectance.sel(bands=absorption_band, method="nearest")
reference1 = reflectance.sel(bands=reference_band, method="nearest")
reference2 = reflectance.sel(bands=reference_band2, method="nearest")

# Depth = average of the two reference bands minus the absorption band
reference_mean = (reference1 + reference2) / 2
ds["absorption_depth"] = reference_mean - absorption

# Simplify to a summary dataset: dropping the bands dimension removes every
# variable that uses it (including reflectance)
ds_summary = ds.drop_dims("bands")

In [None]:
# Violin plots of absorption depth, one violin per geohash
violin_options = opts.Violin(
    width=800,
    height=600,
    xrotation=45,
    show_legend=False,
    title="Absorption Depth",
    # ylim=(-0.02, 0.03),
    violin_fill_color='absorption_depth',
    cmap='Spectral_r',
    # clim = (0, 0.02),
)

violin_plot = ds_summary.hvplot.violin(y="absorption_depth", by="geohash")
violin_plot.opts(violin_options)


In [None]:
# From here on, water bodies (geohashes) are selected by thresholding
# the absorption depth.

# Mean of every summary variable per geohash
geohash_absorption_depths = ds_summary.groupby("geohash").mean()

# Scatter of the mean absorption depth, coloured by its own value
geohash_absorption_depths.hvplot.scatter(
    x="geohash",
    y="absorption_depth",
    title="Mean absorption depth",
    color="absorption_depth",
    cmap='magma_r',
    line_color="grey",  # marker outline
    size=40,
    xaxis=None,
)

In [None]:
def _spectrum_with_spread(summary, label):
    """Return an overlay of the mean curve on a mean +/- std_dev band.

    ``summary`` must provide ``mean`` and ``std_dev`` variables; ``label`` is
    used for both elements so they share one legend entry.
    """
    spread = hv.Spread(summary, vdims=["mean", "std_dev", "std_dev"], label=label)
    curve = hv.Curve(summary, vdims="mean", label=label)
    return spread * curve


# Select the geohashes whose mean absorption depth passes the threshold.
# NOTE(review): the variable name suggests "greater than 0.01", but the
# condition keeps values below 1000 -- confirm the intended threshold.
abs_d_gt_001 = geohash_absorption_depths.where(geohash_absorption_depths.absorption_depth < 1000, drop=True)
high_absv = list(abs_d_gt_001.geohash.values)

# Mean and std_dev plots
color_cycle = hv.Cycle("Category20")

# The loop variable is not called `geohash`, because that would overwrite the
# imported `geohash` module and break add_geohash() if its cell is run again.
plots = [
    _spectrum_with_spread(water_summaries.sel(geohash=geohash_id), f"{geohash_id}")
    for geohash_id in high_absv
]

# Add a mean of all waterbodies plot too. Every summary variable, std_dev
# included, is simply averaged across geohashes here.
mean_all = water_summaries.mean("geohash")
plots.append(_spectrum_with_spread(mean_all, "all"))

hv.Overlay(plots).opts(
    opts.Spread(color=color_cycle, show_legend=True),
    opts.Curve(color=color_cycle, show_legend=True),
    opts.Overlay(
        show_title=True, frame_width=600, frame_height=300, show_legend=True, yaxis=None
    ),
)