# Prepare a Training Data Set for ML

The main goal of this tutorial is to extract data from **multiple** STAC collections, only at the **specified locations**, as a **time series**.

As an example use case we are going to...
1. Get Station Data:
    - Download point data time series with a measured variable (e.g. snow depth)
    - Preprocess it to be used with EO Data
2. Get EO Data:
    - Query a STAC catalog
    - Get acquisitions for the points in the relevant time frame (temporal, spatial)
    - Get acquisitions from different collections
3. Prepare EO Data:
    - Load the found items into one data cube, by only loading the relevant geometry
    - Homogenize the datacube to a common temporal and spatial resolution
    - Add new information to the data cube: Calculate the NDSI (*ideally the collections should be radiometrically harmonized, e.g. [sen2like](https://github.com/senbox-org/sen2like), we are not doing this for this tutorial*)
    - Evaluate the gain in time steps 
4. Combine EO and Station Data:
    - Convert the data cube and the station measurements
    - Use a format that can be easily used for machine learning
5. Apply Machine Learning
    - Regression model to predict snow depth (this is illustrative, not scientifically valid)
6. Weather Data
    - Compare to snow depth from ERA5 Land 

This tutorial should serve as guidance on how to extract data from multiple sources from STAC catalogs and use them in further workflows.
There are many more applications that could be covered. The next step could be to add more predictors and do a multivariate regression, taking into account more factors like elevation, aspect, temperature, etc.

Things to consider
- sparse xarray data cubes
- xvec


<img src="sketch_ws_prepml.png" width="600">

## Environment

This notebook needs a custom environment on terrabyte to run. Uncomment and run the `micromamba create ...` cell below, then close the session. Start a new Jupyter session from the [terrabyte portal](https://portal.terrabyte.lrz.de/) and specify the name of the newly created environment (*prepml* in this example) in the *custom environment field*. You have to type it there the first time you use it; afterwards it is available from the dropdown list above.

In [None]:
# !micromamba create -y -n prepml requests numpy pandas geopandas xarray xvec rioxarray shapely odc-stac odc-geo pystac-client dask graphviz folium branca tensorflow seaborn libgdal libgdal-jp2openjpeg zarr jupyter jupyter-server-proxy

In [None]:
import os
import time
import socket
import io
import zipfile
from datetime import datetime, timezone

import requests
import numpy as np
import pandas as pd
import geopandas as gpd
import matplotlib.pyplot as plt

import xarray
import xvec
import rioxarray
from shapely.geometry import box, shape
from odc import stac as odc_stac
from odc.geo import geobox
from pystac_client import Client as pystacclient
from pystac.extensions.raster import RasterExtension
import tensorflow as tf
import seaborn as sns
import dask
from dask.distributed import Client

import folium
import folium.plugins as folium_plugins
import branca.colormap as cm


## 1. Get Station data

### 1.1 Download station data 

Download station data time series. We are going to use monthly snow depth measurements in South Tyrol. They have been prepared, gapfilled and made available via the ClirSnow project.

In [None]:
# set url
url = 'https://zenodo.org/records/5109574/files/meta_all.csv?download=1'
filename = url.split('/')[-1]
filename = filename.split('?')[0]

# Send a GET request to download the file
response = requests.get(url)

# Check if the request was successful
if response.status_code == 200:
    # Save the file locally
    with open(filename, 'wb') as file:
        file.write(response.content)
    print("File downloaded and saved successfully.")
else:
    print(f"Failed to download file. Status code: {response.status_code}")
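
The file name above is derived by splitting the URL string twice. As a minimal alternative sketch, the standard library can separate the path from the query string; `filename_from_url`, `meta_url` and `meta_filename` are hypothetical names introduced only for this illustration.

```python
from pathlib import PurePosixPath
from urllib.parse import urlparse


def filename_from_url(url: str) -> str:
    """Return the last path component of a URL, ignoring the query string."""
    return PurePosixPath(urlparse(url).path).name


meta_url = "https://zenodo.org/records/5109574/files/meta_all.csv?download=1"
meta_filename = filename_from_url(meta_url)
print(meta_filename)
```

Both approaches give the same result for the two Zenodo URLs used in this notebook; the parsing variant is simply less sensitive to additional query parameters.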

In [None]:
sd_meta = pd.read_csv(filename)
sd_meta.head()

In [None]:
# set url
url = 'https://zenodo.org/records/5109574/files/data_monthly_IT_BZ.zip?download=1'
filename = url.split('/')[-1]
filename = filename.split('?')[0]

# Send a GET request to download the file
response = requests.get(url)

# Check if the request was successful
if response.status_code == 200:
    # Open the downloaded file as a zip file
    with zipfile.ZipFile(io.BytesIO(response.content)) as zip_ref:
        # Extract all contents into the current directory
        zip_ref.extractall()
    print("File downloaded and extracted successfully.")
else:
    print(f"Failed to download file. Status code: {response.status_code}")

In [None]:
# read it into memory
filename = filename.split('.')[0] + '.csv'
sd_mnth = pd.read_csv(filename)
sd_mnth.head()

### 1.2 Prepare station data

Filter the metadata to keep only stations from the province of South Tyrol that have values between 2000 and 2019.

In [None]:
start = 2000
end = 2019
sd_meta = sd_meta[
    (sd_meta['Provider'] == 'IT_BZ') &
    (sd_meta['HS_year_start'] <= start) &
    (sd_meta['HS_year_end'] >= end)]
sd_meta = sd_meta[['Name', 'Longitude', 'Latitude', 'Elevation']]

Select relevant columns.

In [None]:
sd_mnth = sd_mnth[['Name', 'year', 'month', 'HSmean_gapfill']]
sd_mnth = sd_mnth[sd_mnth['year'].between(start, end)]

And join the metadata (geographical information, ...).

In [None]:
sd = pd.merge(sd_meta, sd_mnth, on='Name', how='inner')

This is what our station data looks like now: the metadata combined with the measurements.

In [None]:
sd.head()

Turn snow depth time series into a geodataframe (we also do it for the metadata for plotting the locations).

In [None]:
sd = gpd.GeoDataFrame(data=sd,
                      geometry=gpd.points_from_xy(sd.Longitude, sd.Latitude),
                      crs="EPSG:4326")
sd_meta = gpd.GeoDataFrame(data=sd_meta,
                           geometry=gpd.points_from_xy(sd_meta.Longitude, sd_meta.Latitude),
                           crs="EPSG:4326")

Create a buffer around the points for extracting more than just one pixel.

In [None]:
sd_meta = sd_meta.to_crs(3035)  # LAEA Europe
sd_meta['geometry'] = sd_meta['geometry'].buffer(distance=200, cap_style='square')  # square
sd_meta = sd_meta.to_crs(4326)  # Back to 4326
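
To get a feeling for the size of the buffer: a square-capped buffer with a distance of 200 m around a point is a 400 m x 400 m square in the projected CRS. The sketch below reproduces only that arithmetic in plain Python; `square_buffer_bounds` and the station coordinates are hypothetical and not part of the workflow.

```python
def square_buffer_bounds(x, y, distance):
    """Bounds (minx, miny, maxx, maxy) of a square buffer around a point.

    Coordinates and distance are assumed to be in metres in a projected CRS.
    """
    return (x - distance, y - distance, x + distance, y + distance)


# Hypothetical station position in EPSG:3035 coordinates (metres).
minx, miny, maxx, maxy = square_buffer_bounds(4_400_000.0, 2_650_000.0, 200.0)
side = maxx - minx
area = side * (maxy - miny)
print(side, area)
```

After transforming back to EPSG:4326 the square is no longer axis-aligned in degrees, so the bounding boxes created later are slightly larger than the buffer itself.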

Look at the distribution of the stations and the buffers on a map.

In [None]:
colormap = cm.linear.viridis.scale(round(sd_meta['Elevation'].min(), -2),
                                   round(sd_meta['Elevation'].max(), -2))
colormap.caption = 'Elevation m'

m = folium.Map(tiles="OpenStreetMap", zoom_start=9)

folium.GeoJson(
    sd_meta,
    name="Snow Depth Stations Buffer",
    style_function=lambda feature: {
        'fillColor': colormap(feature['properties']['Elevation']),
        'color': 'black',
        'weight': 1,
        'fillOpacity': 0.7,
    },
    tooltip=folium.GeoJsonTooltip(fields=["Name", "Elevation"]),
).add_to(m)

folium.GeoJson(
    sd_meta.geometry.centroid,
    name="Snow Depth Stations",
).add_to(m)

colormap.add_to(m)
m.fit_bounds(m.get_bounds())
m

Create bounding box columns for querying the STAC catalog.

In [None]:
sd_meta = pd.concat([sd_meta, sd_meta.bounds], axis=1)

## 2. Get EO Data: The terrabyte STAC Catalog

### 2.1 Discover the STAC Catalog

List all available collections

In [None]:
catalog_url = 'https://stac.terrabyte.lrz.de/public/api'
catalog = pystacclient.open(catalog_url)
collections = catalog.get_all_collections()
for collection in collections:
    print(
        f"{collection.id} | {collection.title} | "
        f"{collection.extent.temporal.intervals[0][0].year} - "
        f"{collection.extent.temporal.intervals[0][1].year}")


Check the band names of the Sentinel-2 C1 L2A collection. The same can be done for the Landsat collections.

In [None]:
catalog.get_collection('sentinel-2-c1-l2a')

### 2.2 Query data from the STAC catalog

#### 2.2.1 Query a single station

Let's define the parameters we want to use for both collections.

In [None]:
max_cloud_cover = 25

query = {
    'eo:cloud_cover': {
        "gte": 0,
        "lte": max_cloud_cover
    }
}

start = '2000-01-01T00:00:00Z'
end = '2019-12-31T23:59:59Z'

bands = ['swir16',
         'green']  # fortunately the bands have the same names across multiple collections. That's not always the case.
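
The `query` dictionary is evaluated by the STAC API on the server side. To make its meaning concrete, the following client-side sketch applies the same `gte`/`lte` conditions to a plain dictionary of item properties. `matches_query` is a hypothetical helper for illustration only, and how a real server treats items without an `eo:cloud_cover` property is an assumption here (they are rejected).

```python
def matches_query(properties, query):
    """Check item properties against a dict of gte/lte conditions.

    Client-side illustration only; the real filtering happens in the STAC API.
    """
    for key, conditions in query.items():
        value = properties.get(key)
        if value is None:
            # Assumption: items lacking the property do not match.
            return False
        if "gte" in conditions and value < conditions["gte"]:
            return False
        if "lte" in conditions and value > conditions["lte"]:
            return False
    return True


example_query = {"eo:cloud_cover": {"gte": 0, "lte": 25}}
clear_scene = matches_query({"eo:cloud_cover": 12.5}, example_query)
cloudy_scene = matches_query({"eo:cloud_cover": 80.0}, example_query)
print(clear_scene, cloudy_scene)
```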

Search Sentinel-2 for one specific station. The first station in the list.

In [None]:
%%time
collection = ['sentinel-2-c1-l2a']
bbox = [sd_meta.minx.iloc[0], sd_meta.miny.iloc[0],
        sd_meta.maxx.iloc[0], sd_meta.maxy.iloc[0]]  # first station in list

search = catalog.search(collections=collection,
                        bbox=bbox,
                        datetime=[start, end],
                        query=query)
items = list(search.items())  # items() pages through all matching results

print(f'Found {len(items)} Scenes')

Inspect the search results. Full tiles are returned with their according metadata.

In [None]:
items[0]

Search Sentinel-2 and all Landsat collections for the specified parameters and the same specific station.

In [None]:
%%time
collection = ['landsat-ot-c2-l2', 'landsat-etm-c2-l2', 'landsat-tm-c2-l2',
              'sentinel-2-c1-l2a']  # searching for each collection separately is more performant for larger requests!

search = catalog.search(collections=collection,
                        bbox=bbox,
                        datetime=[start, end],
                        query=query)
items = list(search.items())  # items() pages through all matching results

print(f'Found {len(items)} Scenes')

Searching each collection explicitly, in its own request, is more efficient:

In [None]:
%%time
collection = ['landsat-ot-c2-l2', 'landsat-etm-c2-l2',
              'landsat-tm-c2-l2', 'sentinel-2-c1-l2a']

for col in collection:
    search = catalog.search(collections=col,
                            bbox=bbox,
                            datetime=[start, end],
                            query=query)
    items = list(search.items())  # items() pages through all matching results
    print(f'Collection: {col}. Found {len(items)} Scenes')

#### 2.2.2 Query multiple stations

Define a function to iterate over all stations. The only thing that varies between calls is `bbox`, which is taken from each row of the station table.

In [None]:
def query_stac(row, collection):
    bbox = [row.minx, row.miny, row.maxx, row.maxy]
    search = catalog.search(collections=collection,
                            bbox=bbox,
                            datetime=[start, end],
                            query=query)
    items = list(search.items())
    print(f"Name: {row['Name']}, Items: {len(items)}")
    return items

The function is applied to the station metadata geodataframe, where we had stored the buffers around the stations. The result is a list with all found STAC items for each of the stations.

In [None]:
%%time
collection = ['sentinel-2-c1-l2a']
items_list_s2 = sd_meta.apply(query_stac, args=collection, axis=1)

#### 2.2.3 Excursion: Use geometry of interest directly in search

If your point/vector data is spread out across the globe, with many items in between the geometries, it is better to use the geometry explicitly in the STAC search and not the bounding box. In this way you will only get the tiles you are interested in.

If your point/vector data is close to one another, with few items in between the geometries, it is better to use the bounding box of your geometries in the STAC search. In this way you will not duplicate any items in your search. This is what we will do now.

#### 2.2.4 Query full bounding box

We pass the full bounding box of our station network to the search.

In [None]:
%%time
collection = 'sentinel-2-c1-l2a'
bbox = [sd_meta.minx.min(), sd_meta.miny.min(),
        sd_meta.maxx.max(), sd_meta.maxy.max()]  # all stations

search = catalog.search(collections=collection,
                        bbox=bbox,
                        datetime=[start, end],
                        query=query)
items_s2 = list(search.items())  # items() pages through all matching results
print(f'Collection: {collection}. Found {len(items_s2)} Scenes')

In [None]:
%%time
collection = ['landsat-ot-c2-l2', 'landsat-etm-c2-l2', 'landsat-tm-c2-l2']

items_ls = []
for col in collection:
    search = catalog.search(collections=col,
                            bbox=bbox,
                            datetime=[start, end],
                            query=query)
    items = list(search.items())  # items() pages through all matching results
    items_ls.append(items)
    print(f'Collection: {col}. Found {len(items)} Scenes')

# currently we have a list with 3 entries, let's flatten it
items_ls = [item for sublist in items_ls for item in sublist]
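
The nested list comprehension above flattens the per-collection result lists. An equivalent sketch with `itertools.chain` is shown below; the strings are stand-ins for the pystac items and `results_per_collection` is a hypothetical name.

```python
from itertools import chain

# Stand-in for the per-collection result lists; real entries would be pystac Items.
results_per_collection = [["ls8_a", "ls8_b"], ["ls7_a"], []]

# chain.from_iterable walks through the sublists one after another.
flat = list(chain.from_iterable(results_per_collection))
print(flat)
```

Calling `items_ls.extend(items)` inside the loop instead of `append` would avoid the flattening step altogether.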

Let's plot some items in relation to a station. It becomes clear that loading all of this data should be avoided if possible!

In [None]:
map = folium.Map()
layer_control = folium.LayerControl(position='topright', collapsed=True)

tile_s2 = shape(items_s2[0].geometry)
tile_s2 = gpd.GeoDataFrame([{'geometry': tile_s2}], crs="EPSG:4326")
tile_s2 = folium.GeoJson(tile_s2.to_json(), name="S2",
                         style_function=lambda x: {"fillColor": "blue"})

tile_ls = shape(items_ls[0].geometry)
tile_ls = gpd.GeoDataFrame([{'geometry': tile_ls}], crs="EPSG:4326")
tile_ls = folium.GeoJson(tile_ls.to_json(), name="LS",
                         style_function=lambda x: {"fillColor": "green"})

aoi = box(*bbox)
aoi = gpd.GeoDataFrame({"geometry": [aoi]}, crs="EPSG:4326")
aoi = folium.GeoJson(aoi.to_json(), name="aoi",
                     style_function=lambda x: {"fillColor": "white"})

station = box(*[sd_meta.minx.iloc[0], sd_meta.miny.iloc[0],
                sd_meta.maxx.iloc[0], sd_meta.maxy.iloc[0]])
station = gpd.GeoDataFrame({"geometry": [station]}, crs="EPSG:4326")
station_mark = station.geometry.centroid
station = folium.GeoJson(station.to_json(), name="Station",
                         style_function=lambda x: {"fillColor": "red"})

station_mark = folium.GeoJson(station_mark, name="Station Marker")

tile_s2.add_to(map)
tile_ls.add_to(map)
station.add_to(map)
aoi.add_to(map)
station_mark.add_to(map)
layer_control.add_to(map)
map.fit_bounds(map.get_bounds())
map

## 3. Starting a Dask Cluster

Here we are starting the dask client for scaling the computation to the available resources.
Once started, a link to the dask dashboard will be shown which will display details on the dask computation status.
This should be done **before** the first calculation on xarray objects takes place!

In [None]:
dir_out = os.path.expanduser('~/ws_prepml')  # expand '~' to the home directory
dask_tmpdir = os.path.join(dir_out, 'scratch', 'localCluster')
# from testing: running without threads is the faster option
dask_threads = 1

In [None]:
host = os.getenv('host')
jl_port = os.getenv('port')
# create the URL that points to the jupyter-server-proxy
dask_url = f'https://portal.terrabyte.lrz.de/node/{host}/{jl_port}' + '/proxy/{port}/status'
#dask will insert the final port chosen by the Cluster

dask.config.set({'temporary_directory': dask_tmpdir,
                 'distributed.dashboard.link': dask_url})

#some settings to increase network timeouts and allow the dashboard to plot larger graphs
#dask.config.set({'distributed.comm.timeouts.tcp': '180s',
#                 'distributed.comm.timeouts.connect': '120s',
#                 'distributed.dashboard.graph-max-items': 55000,
#                 'distributed.deploy.lost-worker-timeout': '90s',
#                 'distributed.scheduler.allowed-failures': 180,
#                 })

#we set the dashboard address for dask to choose a free random port,
# so there is no error with multiple tasks running on same node
client = Client(threads_per_worker=dask_threads,
                dashboard_address="127.0.0.1:0")
client
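
The dashboard link above mixes two kinds of placeholders: the f-string fills in the host and the Jupyter port immediately, while `{port}` is kept as literal text for dask to fill in later. The sketch below shows this with made-up values; `example_host`, `example_jl_port` and the port `8787` are assumptions for illustration, and the final `format` call only imitates the substitution.

```python
# Hypothetical values; on terrabyte they come from the 'host' and 'port'
# environment variables of the Jupyter session.
example_host = "node123"
example_jl_port = "8888"

# The f-string fills host and Jupyter port now; '{port}' stays a literal placeholder.
template = (f"https://portal.terrabyte.lrz.de/node/{example_host}/{example_jl_port}"
            + "/proxy/{port}/status")
print(template)

# Imitation of the later substitution of the dashboard port.
dashboard_link = template.format(port=8787)
print(dashboard_link)
```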

## 4. Prototyping

**Testing the workflow on a small subset of the data**. Before the workflow is mature it needs to be tested. This should be done on a small subset of the data set to: reduce processing time, save resources, get a feeling for the value ranges etc. Let's develop the workflow on a single station before expanding to all of them.

### 4.1 Reduce amount of data for quick prototyping
First, some tweaks to reduce the amount of data.
*Note: Do this for s2, ls, and then also combine the two into one cube.*

In [None]:
def filter_time(items, start_date, end_date):
    items_tst = [
        item for item in items
        if start_date <= item.datetime <= end_date
    ]
    return items_tst
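
A quick way to check the behaviour of `filter_time` without touching the catalog is to feed it stand-in objects. The sketch below repeats the function so that it runs on its own and uses `SimpleNamespace` objects instead of STAC items; it assumes that item datetimes are timezone-aware. Both bounds are inclusive, and comparing against naive datetimes raises a `TypeError`.

```python
from datetime import datetime, timezone
from types import SimpleNamespace


def filter_time(items, start_date, end_date):
    return [item for item in items if start_date <= item.datetime <= end_date]


# Stand-ins for STAC items: only the 'datetime' attribute is needed here.
fake_items = [
    SimpleNamespace(id="a", datetime=datetime(2016, 12, 31, 10, 0, tzinfo=timezone.utc)),
    SimpleNamespace(id="b", datetime=datetime(2017, 6, 15, 10, 0, tzinfo=timezone.utc)),
    SimpleNamespace(id="c", datetime=datetime(2018, 1, 1, 0, 0, tzinfo=timezone.utc)),
]

window_start = datetime(2017, 1, 1, tzinfo=timezone.utc)
window_end = datetime(2018, 1, 1, tzinfo=timezone.utc)
kept = [item.id for item in filter_time(fake_items, window_start, window_end)]
print(kept)

# Naive bounds cannot be compared with timezone-aware item datetimes.
try:
    filter_time(fake_items, datetime(2017, 1, 1), datetime(2018, 1, 1))
    naive_failed = False
except TypeError:
    naive_failed = True
print(naive_failed)
```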

In [None]:
# choose the timerange for prototyping
start_date = datetime(2017, 1, 1, tzinfo=timezone.utc)
end_date = datetime(2018, 1, 1, tzinfo=timezone.utc)

# define the parameters for prototyping per collection
proto_dict = {
    "s2": {
        "items_tst": filter_time(items_s2, start_date, end_date),
        "scale": RasterExtension.ext(items_s2[0].assets["B03"]).bands[0].scale,
        # 0.0001; it is constant across the relevant bands and time steps
        "offset": RasterExtension.ext(items_s2[0].assets["B03"]).bands[0].offset,
        # -0.1; it is constant across the relevant bands and time steps
    },
    "ls": {
        "items_tst": filter_time(items_ls, start_date, end_date),
        "scale": RasterExtension.ext(items_ls[0].assets["B03"]).bands[0].scale,
        # 2.75e-05; it is constant across the relevant bands and time steps
        "offset": RasterExtension.ext(items_ls[0].assets["B03"]).bands[0].offset,
        # -0.2; it is constant across the relevant bands and time steps
    }
}

# aoi
aoi_tst = sd_meta.iloc[0]

Defining a reduced data cube (**lazily - nothing is loaded so far**)

In [None]:
cube_s2 = odc_stac.load(proto_dict["s2"]["items_tst"],
                        geopolygon=aoi_tst.geometry,
                        groupby='solar_day',
                        chunks={"time": -1},  # keep time in one chunk.
                        bands=bands,
                        nodata=0,
                        )

In [None]:
cube_s2

In [None]:
cube_ls = odc_stac.load(proto_dict["ls"]["items_tst"],
                        geopolygon=aoi_tst.geometry,
                        groupby='solar_day',
                        chunks={"time": -1},  # keep time in one chunk.
                        bands=bands,
                        nodata=0,
                        )

In [None]:
cube_ls

### 4.2 Test the processing steps
Scale, aggregate to monthly values (median composite), calculate NDSI. **Still lazy.**

In [None]:
def scale_and_offset(cube, scale, offset):
    cube = cube.where(cube['green'] != 0)
    cube = cube.where(cube['swir16'] != 0)
    cube = cube * scale + offset
    return cube


def aggregate_monthly_ndsi(cube):
    # Using the median here, since the collections are not spectrally harmonized:
    # this ensures we keep information from both collections. Max would be better.
    cube = cube.resample(time="1ME").median()
    cube["ndsi"] = (cube.green - cube.swir16) / (cube.green + cube.swir16)
    #cube = cube["ndsi"]
    return cube
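
The two functions above boil down to simple arithmetic per pixel: mask nodata, convert digital numbers to reflectance, then compute NDSI = (green - swir16) / (green + swir16). The following NumPy sketch walks through this for three made-up pixels, using the Sentinel-2 scale and offset quoted in the comments of the prototyping cell (0.0001 and -0.1).

```python
import numpy as np

# Made-up digital numbers; 0 marks nodata, as in the cubes above.
green_dn = np.array([8000.0, 0.0, 3000.0])
swir16_dn = np.array([2000.0, 1500.0, 3000.0])

# Sentinel-2 values quoted in the comments of the prototyping cell.
scale, offset = 0.0001, -0.1

# A pixel is only valid if both bands carry data.
valid = (green_dn != 0) & (swir16_dn != 0)
green = np.where(valid, green_dn * scale + offset, np.nan)
swir16 = np.where(valid, swir16_dn * scale + offset, np.nan)

ndsi = (green - swir16) / (green + swir16)
print(ndsi)
```

The first pixel (reflectances 0.7 and 0.1) is strongly snow-like, the second is masked because one band is nodata, and the third has equal reflectance in both bands.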


**Scale and Offset**

In [None]:
cube_s2 = scale_and_offset(cube_s2, scale=proto_dict['s2']['scale'],
                           offset=proto_dict['s2']['offset'])
cube_ls = scale_and_offset(cube_ls, scale=proto_dict['ls']['scale'],
                           offset=proto_dict['ls']['offset'])

**Merge Cubes**

This has to happen before calculating the monthly NDSI to increase the observations per month.

In [None]:
cube_s2_rep = cube_s2.rio.reproject_match(cube_ls)
cube_mg = xarray.concat([cube_ls, cube_s2_rep], dim="time")
cube_mg = cube_mg.sortby("time")
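
Why merging before the monthly aggregation helps can be seen on a toy time series. The pandas sketch below uses made-up green reflectance values for a single pixel; `ls` and `s2` are hypothetical stand-ins for the two cubes. After concatenation each month contains more observations for the median than either sensor provides alone.

```python
import pandas as pd

# Made-up acquisition dates and green reflectance values for one pixel.
ls = pd.Series([0.60, 0.20],
               index=pd.to_datetime(["2017-01-05", "2017-02-06"]))
s2 = pd.Series([0.70, 0.80, 0.30],
               index=pd.to_datetime(["2017-01-12", "2017-01-22", "2017-02-11"]))

# Concatenate along time and sort, as done for the cubes above.
merged = pd.concat([ls, s2]).sort_index()
by_month = merged.groupby(merged.index.to_period("M"))

counts = by_month.count()
medians = by_month.median()
print(counts)
print(medians)
```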

**Monthly NDSI**

In [None]:
cube_s2 = aggregate_monthly_ndsi(cube=cube_s2)
cube_ls = aggregate_monthly_ndsi(cube=cube_ls)
cube_mg = aggregate_monthly_ndsi(cube=cube_mg)

**Inspect**

Load a couple of time steps to check the values! Change the bands for more inspection. **Now data is loaded!**

In [None]:
%%time
cube_s2['ndsi'].isel(time=slice(0, 5)).plot.imshow(col="time", size=8, aspect=1, vmin=-1, vmax=1)

In [None]:
%%time
cube_ls['ndsi'].isel(time=slice(0, 5)).plot.imshow(col="time", size=8, aspect=1, vmin=-1, vmax=1)

In [None]:
%%time
cube_mg['ndsi'].isel(time=slice(0, 5)).plot.imshow(col="time", size=8, aspect=1, vmin=-1, vmax=1)

Let's look at the values as well!

In [None]:
cube_mg.isel(time=slice(0, 5)).load()

**Aggregate Spatially**

Aggregate the cube spatially to get a time series for the whole period we chose. This is to check if the seasonality makes sense.

In [None]:
cube_s2['ndsi'].mean(dim=["x", "y"]).plot(ylim=[-1, 1])

In [None]:
cube_ls['ndsi'].mean(dim=["x", "y"]).plot(ylim=[-1, 1])

In [None]:
cube_mg['ndsi'].mean(dim=["x", "y"]).plot(ylim=[-1, 1])

### 4.3 Compare to station data
Get the snow depth data for the test station

In [None]:
aoi_tst

Check what the actual measurements look like.

In [None]:
sd_mnth[(sd_mnth['Name'] == aoi_tst.Name)].head()

Add a time column in date format that matches the date format in the data cube.

In [None]:
sd_mnth['time'] = pd.to_datetime(
    sd_mnth['year'].astype(str) + '-' + sd_mnth['month'].astype(str)) + pd.offsets.MonthEnd(0)
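
The month-end offset matters because `resample(time="1ME")` labels each monthly bin with the last day of the month, so the station table needs the same timestamps for the join to work. A minimal sketch on a toy table (`toy` is a hypothetical name, and the dates are built from year/month/day columns rather than strings):

```python
import pandas as pd

toy = pd.DataFrame({"year": [2016, 2017, 2017], "month": [2, 2, 12]})

# Build the first day of each month, then roll forward to the month end.
first_of_month = pd.to_datetime(toy.assign(day=1)[["year", "month", "day"]])
toy["time"] = first_of_month + pd.offsets.MonthEnd(0)
print(toy)
```

`MonthEnd(0)` leaves dates that already fall on a month end untouched and rolls all others forward, which also handles leap years.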

Filter to the chosen time range and station.

In [None]:
sd_tst = sd_mnth[(sd_mnth['Name'] == aoi_tst.Name) &
                 (sd_mnth['time'] >= cube_mg.time.min().values) &
                 (sd_mnth['time'] <= cube_mg.time.max().values)]
sd_tst.head()

Convert the NDSI time series to a data frame.

In [None]:
%%time
df_ts = cube_mg['ndsi'].mean(dim=["x", "y"]).to_dataframe()
df_ts.head()

Join the two time series by date.

In [None]:
sd_tst = pd.merge(sd_tst, df_ts, on="time")
sd_tst

Plot their relationship (as a time series and as a scatter plot).

In [None]:
ax = sd_tst.plot(x="time", y="HSmean_gapfill", label="HSmean_gapfill",
                 marker="o", color="blue", figsize=(10, 6))
ax.set_ylabel("HSmean_gapfill", color="blue")
ax.legend(loc="upper left")

ax2 = ax.twinx()
sd_tst.plot(x="time", y="ndsi", label="ndsi", marker="s", color="orange", ax=ax2)
ax2.set_ylabel("NDSI", color="orange")

In [None]:
sd_tst.plot.scatter(x="ndsi", y="HSmean_gapfill", figsize=(10, 6))

Delete unneeded objects.

In [None]:
del sd_tst
del df_ts
del cube_mg
del cube_s2_rep
del cube_ls
del cube_s2

## 5. Scaling up to the full extent

#### 5.1 Defining the data cubes

Converting the bounding box of the station network we used above to a geodataframe

In [None]:
aoi = box(*bbox)
aoi = gpd.GeoDataFrame({"geometry": [aoi]}, crs="EPSG:4326")

Defining the Sentinel-2 data cube

In [None]:
%%time
bands = ["green", "swir16"]  # s2 and ls share the same band names
chunk_size = 512

cube_s2 = odc_stac.load(items_s2,
                        geopolygon=aoi.geometry,
                        groupby='solar_day',
                        chunks={"x": chunk_size, "y": chunk_size, "time": -1},
                        bands=bands,
                        resolution=60,  # going to 60 m resolution
                        )
cube_s2

Defining the Landsat data cube.

In [None]:
%%time
chunk_size = 512
cube_ls = odc_stac.load(items_ls,
                        geopolygon=aoi.geometry,
                        groupby='solar_day',
                        chunks={"x": chunk_size, "y": chunk_size, "time": -1},
                        bands=bands,
                        resolution=60,
                        )
cube_ls

#### 5.2 Scale and Offset

In [None]:
cube_s2 = scale_and_offset(cube_s2, scale=proto_dict['s2']['scale'],
                           offset=proto_dict['s2']['offset'])
cube_ls = scale_and_offset(cube_ls, scale=proto_dict['ls']['scale'],
                           offset=proto_dict['ls']['offset'])

#### 5.3 Merging the Cubes

Check the pixel sizes and CRS of the cubes.

In [None]:
def pixel_size_crs(cube):
    x_p = np.diff(cube.x).mean()
    y_p = np.diff(cube.y).mean()
    cube_crs = cube.rio.crs
    print(x_p)
    print(y_p)
    print(cube_crs)
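
The pixel size is simply the mean spacing of the coordinate values. Equal pixel size and CRS do not guarantee that two grids coincide, though: the origins may be shifted by a fraction of a pixel. The sketch below illustrates both points on made-up coordinates; `grids_aligned` is a hypothetical helper, not part of any library.

```python
import numpy as np

# Made-up pixel-centre coordinates of a 60 m grid (metres).
x = np.array([600030.0, 600090.0, 600150.0, 600210.0])
y = np.array([5200170.0, 5200110.0, 5200050.0])  # y decreases in north-up rasters

x_size = np.diff(x).mean()
y_size = np.diff(y).mean()
print(x_size, y_size)


def grids_aligned(coords_a, coords_b, pixel):
    """True if the two coordinate axes share the same pixel grid."""
    shift = (coords_a[0] - coords_b[0]) % pixel
    return bool(np.isclose(shift, 0.0) or np.isclose(shift, pixel))


shifted_by_half_pixel = grids_aligned(x, x + 30.0, 60.0)
shifted_by_two_pixels = grids_aligned(x, x + 120.0, 60.0)
print(shifted_by_half_pixel, shifted_by_two_pixels)
```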

In [None]:
# pixel_size_crs prints its results itself and returns nothing
pixel_size_crs(cube=cube_s2)
pixel_size_crs(cube=cube_ls)

However, the pixels are not aligned, so one cube has to be reprojected onto the grid of the other.

In [None]:
%%time
# takes forever, dask execution starts!
# cube_s2_rep = cube_s2.rio.reproject_match(cube_ls)

Merge. And sort the time dimension.

In [None]:
cube_mg = xarray.concat([cube_ls, cube_s2], dim="time")
cube_mg = cube_mg.sortby("time")

Resampling the grid of the Landsat data cube to twice its pixel size, aggregating using max().
**This is a proxy for using the geometries to extract zonal statistics. It's more efficient.**

In [None]:
#cube_ls = cube_ls.coarsen(x=2, y=2, boundary="trim").max() # chunksize automatically halved.

In [None]:
#chunk_size = 1024
#cube_ls = cube_ls.chunk({"x": chunk_size, "y": chunk_size})

Resampling the grid of the S2 data cube to six times its pixel size (60 m), aggregating using max(). Then aligning to the Landsat grid.

In [None]:
#cube_s2 = cube_s2.coarsen(x = 6, y = 6, boundary="trim").max() # interp(x=cube_ls["x"], y=cube_ls["y"])   # https://github.com/pydata/xarray/issues/6799 --> doesn't work on chunks!

In [None]:
#chunk_size = 1024
#cube_s2 = cube_s2.chunk({"x": chunk_size, "y": chunk_size})
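
What a block-wise coarsening with `max` and trimmed edges does can be shown with plain NumPy. This is a sketch of the idea only, not of xarray's implementation: `coarsen_max` is a hypothetical helper, and unlike a NaN-skipping reduction it lets NaN values propagate (`np.nanmax` would be the alternative).

```python
import numpy as np


def coarsen_max(arr, factor):
    """Block-wise maximum over factor x factor windows, trimming leftover edges."""
    ny, nx = arr.shape
    ny_trim, nx_trim = ny - ny % factor, nx - nx % factor
    blocks = arr[:ny_trim, :nx_trim].reshape(
        ny_trim // factor, factor, nx_trim // factor, factor)
    return blocks.max(axis=(1, 3))


# 5 x 5 grid: the last row and column do not fill a 2 x 2 block and are trimmed.
grid = np.arange(25, dtype=float).reshape(5, 5)
coarse = coarsen_max(grid, 2)
print(coarse)
```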

#### 5.4 NDSI and temporal aggregation

Temporal resampling to one month and calculating the NDSI. (**lazy**)

In [None]:
cube_s2 = aggregate_monthly_ndsi(cube=cube_s2)
cube_ls = aggregate_monthly_ndsi(cube=cube_ls)
cube_mg = aggregate_monthly_ndsi(cube=cube_mg)

In [None]:
cube_ls

### 5.5 Extract the NDSI at the stations

Extract the stations via xvec. In this workshop we extract the values at the point locations, which is very efficient. That is why we coarsened the resolution of our data cubes beforehand: to simulate a spatial aggregation into zones. Using xvec `zonal_stats` is possible but more costly.

In [None]:
sd_meta.geometry.centroid.head()

### Sentinel

In [None]:
ndsi_s2_points = cube_s2.xvec.extract_points(sd_meta.geometry.centroid, x_coords="x",
                                             y_coords="y", index=True)

In [None]:
ndsi_s2_points

In [None]:
%%time
ndsi_s2_points = ndsi_s2_points["ndsi"].compute()

In [None]:
ndsi_s2_ts = ndsi_s2_points.to_dataframe(dim_order=["geometry", "time"])

In [None]:
ndsi_s2_ts

In [None]:
ndsi_s2_ts = ndsi_s2_ts.reset_index()
ndsi_s2_ts = pd.merge(ndsi_s2_ts, sd_meta[["Name", "Elevation"]], left_on="index", right_index=True)
ndsi_s2_ts = pd.merge(ndsi_s2_ts, sd_mnth, left_on=["Name", "time"], right_on=["Name", "time"])
ndsi_s2_ts = ndsi_s2_ts[["time", "Name", "index", "Elevation", "ndsi", "HSmean_gapfill"]]
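
The two-step merge above first attaches the station name via the row label (`index`) and then joins the measurements on name and time. A toy sketch with made-up values shows what survives; `toy_points`, `toy_stations` and `toy_depth` are hypothetical stand-ins for the extracted points, `sd_meta` and `sd_mnth`.

```python
import pandas as pd

# Toy extraction result: 'index' refers to the row label of the station table.
toy_points = pd.DataFrame({
    "index": [0, 0, 1],
    "time": pd.to_datetime(["2017-01-31", "2017-02-28", "2017-01-31"]),
    "ndsi": [0.6, 0.4, 0.1],
})
toy_stations = pd.DataFrame({"Name": ["A", "B"], "Elevation": [1200, 1900]})
toy_depth = pd.DataFrame({
    "Name": ["A", "B"],
    "time": pd.to_datetime(["2017-01-31", "2017-01-31"]),
    "HSmean_gapfill": [35.0, 5.0],
})

# Step 1: match the 'index' column against the row labels of the station table.
step1 = pd.merge(toy_points, toy_stations, left_on="index", right_index=True)
# Step 2: inner join on station name and month-end timestamp.
step2 = pd.merge(step1, toy_depth, on=["Name", "time"])
print(step2[["time", "Name", "ndsi", "HSmean_gapfill"]])
```

The February NDSI value of station A is dropped because the toy measurement table has no matching row; the same happens in the real data for months without a snow depth record.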

In [None]:
ndsi_s2_ts.hist("HSmean_gapfill")

In [None]:
# Scatter plot
df = ndsi_s2_ts[ndsi_s2_ts['ndsi'] > 0]
plt.figure(figsize=(10, 6))
plt.scatter(df['ndsi'], df['HSmean_gapfill'], alpha=0.7, edgecolors='k')

# Add labels and title
plt.xlabel('NDSI')
plt.ylabel('HSmean_gapfill')
plt.title('Scatter Plot of NDSI vs HSmean_gapfill')
plt.grid(True)
plt.show()

### Landsat

In [None]:
ndsi_ls_points = cube_ls.xvec.extract_points(sd_meta.geometry.centroid,
                                             x_coords="x", y_coords="y", index=True)

In [None]:
ndsi_ls_points

Convert to data frames, and combine the NDSI time series.

In [None]:
%%time
ndsi_ls_points = ndsi_ls_points.compute()

In [None]:
ndsi_ls_ts = ndsi_ls_points.to_dataframe(dim_order=["geometry", "time"])

In [None]:
ndsi_ls_ts

Merge snow depth data. First metadata, then snow depth.

In [None]:
ndsi_ls_ts = ndsi_ls_ts.reset_index()
ndsi_ls_ts = pd.merge(ndsi_ls_ts, sd_meta[["Name", "Elevation"]], left_on="index", right_index=True)
ndsi_ls_ts = pd.merge(ndsi_ls_ts, sd_mnth, left_on=["Name", "time"], right_on=["Name", "time"])
ndsi_ls_ts = ndsi_ls_ts[["time", "Name", "index", "Elevation", "ndsi", "HSmean_gapfill"]]

In [None]:
# Scatter plot
df = ndsi_ls_ts[ndsi_ls_ts['ndsi'] > 0]
plt.figure(figsize=(10, 6))
plt.scatter(df['ndsi'], df['HSmean_gapfill'], alpha=0.7, edgecolors='k')

# Add labels and title
plt.xlabel('NDSI')
plt.ylabel('HSmean_gapfill')
plt.title('Scatter Plot of NDSI vs HSmean_gapfill')
plt.grid(True)
plt.show()

### Merged Cube


In [None]:
ndsi_mg_points = cube_mg.xvec.extract_points(sd_meta.geometry.centroid,
                                             x_coords="x", y_coords="y", index=True)

In [None]:
ndsi_mg_points

In [None]:
%%time
ndsi_mg_points = ndsi_mg_points.compute()

In [None]:
ndsi_mg_ts = ndsi_mg_points.to_dataframe(dim_order=["geometry", "time"])

In [None]:
ndsi_mg_ts

In [None]:
ndsi_mg_ts = ndsi_mg_ts.reset_index()
ndsi_mg_ts = pd.merge(ndsi_mg_ts, sd_meta[["Name", "Elevation"]], left_on="index", right_index=True)
ndsi_mg_ts = pd.merge(ndsi_mg_ts, sd_mnth, left_on=["Name", "time"], right_on=["Name", "time"])
ndsi_mg_ts = ndsi_mg_ts[["time", "Name", "index", "Elevation", "ndsi", "HSmean_gapfill"]]

In [None]:
# Scatter plot
df = ndsi_mg_ts[ndsi_mg_ts['ndsi'] > 0]
plt.figure(figsize=(10, 6))
plt.scatter(df['ndsi'], df['HSmean_gapfill'], alpha=0.7, edgecolors='k')

# Add labels and title
plt.xlabel('NDSI')
plt.ylabel('HSmean_gapfill')
plt.title('Scatter Plot of NDSI vs HSmean_gapfill')
plt.grid(True)
plt.show()

### 5.x Excursion: Extract via zonal_statistics

In [None]:
# replace nan with -2, then use normal rasterization
# ndsi_s2 = ndsi_s2.fillna(-2)
#print(sd_meta.crs)
#sd_meta_rep = sd_meta.to_crs(ndsi_s2.rio.crs)
#print(sd_meta_rep.crs)

In [None]:
##%%time
#ndsi_s2_zonal = ndsi_s2.isel(time=slice(None, 5)).xvec.zonal_stats(
#    list(sd_meta.geometry[0:5]), x_coords="x", y_coords="y", stats="max", 
#    method="iterate",
#    n_jobs=-1,
#    index=True,
#)
#ndsi_s2_zonal


## Outlook
Next steps in this workflow could be:

- Test an example where the stations are spread out over the globe.
- Include cloud and quality masking to get more robust results.
- Use static information like elevation, aspect etc.
- Apply more sophisticated machine learning:
    - Time series prediction (extrapolation in time)
    - Mapping (extrapolation in space)
- Add more features (e.g. ERA5, S1) and do a multivariate analysis.
- Find a solution for `xvec.zonal_stats(method="rasterize")`.
- Try a depth-first approach (loop through the workflow in prototyping).