## Identify leading edge

This is a test notebook, for me to make a few tests about leading edge detection on along track altimeter Waveforms.

The algo is an implementation of the leading edge detection of WHALES algo described in [this document](https://climate.esa.int/sites/default/files/Sea_State_cci_ATBD_v1.1-signed_0.pdf)

The objective here is to load zarr time ranges that have not been processed (zarr partitionned by time range) and to use chunks defined at zarr level. 
Zarr variable we use as input is *waveform_20_plrm_ku*, which is 20Hz data, with 128 waveform points per waveform. This variable is 2D, indexed by *time_20_c* (timestamp of the waveform) and *echo_sample_ind* (index of each of the 128 waveform data points).
Idealy, chunks are in the MBs range, so that it is worth using dask.

On each chunk, we have multiple waveforms on which we apply the leading_edge detection algorithm sequentially, to derive *startgate* and *stopgate* attributes of the waveform. And chunks are processed in parallel via dask. I say 'sequentially', but I try to use vectorized np operation (optimization) such as np.median, np.diff, ... any time I can, and then I iterate over numpy array for operation I do not know how to vectorize.

Then the *startgate* and *stopgate* are stored to variables *waveform_20_plrm_ku_startgate* and *waveform_20_plrm_ku_stopgate* which are tied to index *time_20_c*.


In [1]:
import logging
import os
import sys
import numpy as np
import pandas as pd
import xarray as xr
import dask
import dask.distributed
import dask.array as da
import zcollection

from typing import Tuple, Optional

import eumdac.collection

sys.path.append(os.path.abspath('..'))
from src.connectors.eumdac_connector import EumdacConnector
from src.processors.zarr_processor import ZarrProcessor
from utils.logging_utils import setup_root_logging, setup_module_logger
from utils.opensearch_query_formatter import OpenSearchQueryFormatter

### The Leading edge detection algo - to apply on one waveform

I tried to use np vectorized functions as much as I could for oprimization, and process chunks of multiple waveforms (5000)

In [2]:
def identify_leading_edges(data_chunk: xr.Dataset, tau: float) -> xr.Dataset:
    """
    Identify the startgate and stopgate for multiple waveforms in a dataset chunk.

    :param xr.Dataset data_chunk: Input dataset chunk containing waveforms with shape [time_20_c, 128]
    :param float tau: Time spacing between consecutive gates in seconds
    :return: xr.Dataset containing 'waveform_20_plrm_ku_startgate' and 'waveform_20_plrm_ku_stopgate'
    variables indexed by 'time_20_c'.
    :rtype: xr.Dataset
    """
    waveforms = data_chunk['waveform_20_plrm_ku'].values  # Accessing the waveform data
    num_waveforms = waveforms.shape[0]  # This will be the chunk size, e.g., 5000
    startgates = np.full(num_waveforms, -1, dtype=np.int8)
    stopgates = np.full(num_waveforms, -1, dtype=np.int8)

    # "The waveform is normalised with normalisation factor N, where N = 1.3 * median(waveform)"
    medians = np.median(waveforms, axis=1, keepdims=True)
    N = 1.3 * medians
    normalized_waveforms = waveforms / N

    # Identify potential startgates
    # "The leading edge starts when the normalised waveform has a rise of 0.01 units
    # compared to the previous gate (startgate)"
    normalized_diff = np.diff(normalized_waveforms, axis=1)
    startgate_candidates = (normalized_diff >= 0.01).astype(np.int32)
    startgates = np.argmax(startgate_candidates, axis=1) + 1

    # Create a mask for valid startgates
    valid_startgate_mask = (startgate_candidates[np.arange(num_waveforms), startgates] > 0)

    # Check for valid startgates, and get stopgates
    for i in range(num_waveforms):
        if valid_startgate_mask[i]:
            startgate = startgates[i]

            # Validating startgate
            # "At this point, the leading edge is considered valid if, for at least four gates
            # after startgate, it does not decrease below 0.1 units (10% of the normalised power)"
            if (
                (startgate + 4 < normalized_waveforms.shape[1]) and 
                (np.all(normalized_waveforms[i, startgate+1:startgate + 5] >= 0.1))
            ):
                startgates[i] = startgate

                # Identify stopgates
                # "The end of the leading edge (stopgate) is fixed at the first gate in which the
                # derivative changes sign (i.e. the signal start decreasing and the trailing edge begins), if the
                # change of sign is kept for the following 3 gates."
                derivatives = normalized_diff[i]
                stopgate_candidates = np.where((derivatives[:-1] > 0) & (derivatives[1:] < 0))[0] + 1

                # Check for valid stopgates
                for stopgate in stopgate_candidates:
                    if stopgate + 3 < normalized_waveforms.shape[1] and np.all(derivatives[stopgate:stopgate + 3] < 0):
                        stopgates[i] = stopgate
                        break

    # Create a dataset to return
    retval = xr.Dataset({
        "waveform_20_plrm_ku_startgate": (["time_20_c"], startgates),
        "waveform_20_plrm_ku_stopgate": (["time_20_c"], stopgates),
    }, coords={
        "time_20_c": (["time_20_c"], data_chunk['time_20_c'].values)
    })
    return retval

In [3]:
if __name__ == "__main__":
    # Simulated waveform data
    waveforms: xr.Dataset = xr.Dataset({
    "waveform_20_plrm_ku": xr.DataArray(
        [[0.0, 0.0, 0.0, 0.0001, 0.0015, 0.002, 0.004, 0.01, 0.02, 0.05, 0.15, 0.1, 0.08, 0.07, 0.03, 0.01, 0.0],
        [0.0, 0.0, 0.0, 0.00004, 0.001, 0.0015, 0.004, 0.01, 0.02, 0.05, 0.15, 0.1, 0.08, 0.07, 0.03, 0.01, 0.0],
        [0.0, 0.0, 0.0002, 0.0015, 0.0018, 0.002, 0.004, 0.01, 0.02, 0.15, 0.12, 0.1, 0.08, 0.07, 0.03, 0.01, 0.0],
        [0.0, 0.0003, 0.0017, 0.0022, 0.0028, 0.0032, 0.004, 0.01, 0.02, 0.075, 0.073, 0.072, 0.07, 0.07, 0.03, 0.01, 0.0]],
        dims=["time_20_c", "echo_sample_ind"], 
        coords={"time_20_c": np.arange(4), "echo_sample_ind": np.arange(17)},
    )})
        
    tau: float = 3.125e-9  # Time spacing in seconds

    res: xr.Dataset = identify_leading_edges(waveforms, tau)
    print(res)

<xarray.Dataset> Size: 68B
Dimensions:                        (time_20_c: 4)
Coordinates:
  * time_20_c                      (time_20_c) int64 32B 0 1 2 3
Data variables:
    waveform_20_plrm_ku_startgate  (time_20_c) int64 32B 4 4 2 1
    waveform_20_plrm_ku_stopgate   (time_20_c) int8 4B 10 10 9 9


### Download datasets containing waveform (Sentinel 3) and store them to zarr


In [4]:
setup_root_logging()

logger: logging.Logger = setup_module_logger(__name__)

COLLECTION_ID: str = "EO:EUM:DAT:0415"
DOWNLOAD_DIR: str = "/tmp/products"
MEASUREMENTS_FILENAME: str = "enhanced_measurement.nc"
ZARR_BASE_PATH: str = "/tmp/sen3_sral_enhanced"
INDEX_DIMENSION: str = "time_20_c"
download_dir: str = os.path.join(os.getcwd(), DOWNLOAD_DIR)

In [5]:

if __name__ == "__main__":
    logger.info("Connecting EUMDAC datastore...")
    connector: EumdacConnector = EumdacConnector()
    datastore: eumdac.datastore.DataStore = connector.datastore

    # Query a few data files for Sentinel3A and 3B SRAL (Level2 data) for 2024-09-20
    opensearch_query: str = OpenSearchQueryFormatter(
        query_params={
            "pi": COLLECTION_ID,
            "dtstart": "2024-09-23T00:00:00Z",
            "dtend": "2024-09-23T00:20:00Z",
        }
    ).format()
    logger.info("Listing EUMDAC products matching filters '%s'", opensearch_query)
    products: eumdac.collection.SearchResults = datastore.opensearch(query=opensearch_query)
    product_ids: list[str] = [str(x) for x in products]

    logger.info("%s matching products found", len(product_ids))
    logger.debug("Listed products are: %s", product_ids)

    # Download files - benefits of dask parallelization
    logger.info("Downloading products (dask parallelized)...")
    downloaded_folders: list[str] = connector.download_products(
        COLLECTION_ID, product_ids, download_dir, MEASUREMENTS_FILENAME
    )

    # Store files to partitionned zarr files
    # Partition by day (with zcollection) - to be tuned depending on data use / volumetry
    logger.info("Persisting data in a partitionned zarr collection...")
    netcdf_file_paths: list[str] = [
        os.path.join(folder, MEASUREMENTS_FILENAME) for folder in downloaded_folders
    ]
    partition_handler: zcollection.partitioning.Partitioning = zcollection.partitioning.Date(
        (INDEX_DIMENSION,), resolution='M'
    )

    zarr_processor: ZarrProcessor = ZarrProcessor(
        ZARR_BASE_PATH, 
        partition_handler, 
        index_dimension=INDEX_DIMENSION,
    )
    zarr_processor.netcdf_2_zarr(
        netcdf_file_paths,
        variables = ["waveform_20_plrm_ku"], # only extract waveform_20_plrm_ku
        chunk_sizes={
            # About 4mn chunks at 50Hz with int32 values- 2MB
            "waveform_20_plrm_ku": (5000,128)
        } 
    )
    logger.info("Job done")

2024-10-07 09:49:19.171 INFO     | 127.0.1.1 PID:8506 TID:139623468074048 | __main__ <module> 3869866198.py:  2 | Connecting EUMDAC datastore...
2024-10-07 09:49:19.401 INFO     | 127.0.1.1 PID:8506 TID:139623468074048 | __main__ <module> 3869866198.py: 14 | Listing EUMDAC products matching filters 'pi=EO:EUM:DAT:0415&dtstart=2024-09-23T00:00:00Z&dtend=2024-09-23T00:20:00Z'
2024-10-07 09:49:20.202 INFO     | 127.0.1.1 PID:8506 TID:139623468074048 | __main__ <module> 3869866198.py: 18 | 10 matching products found
2024-10-07 09:49:20.203 INFO     | 127.0.1.1 PID:8506 TID:139623468074048 | __main__ <module> 3869866198.py: 22 | Downloading products (dask parallelized)...
2024-10-07 09:49:20.561 INFO     | 127.0.1.1 PID:8506 TID:139623468074048 | src.connectors.eumdac_connector download_products eumdac_connector.py:177 | Downloading products...
2024-10-07 09:49:32.574 INFO     | 127.0.1.1 PID:8506 TID:139623468074048 | __main__ <module> 3869866198.py: 29 | Persisting data in a partitionned 

### Compute startgate / stopgate of every waveform in the zarr collection

Chunk the zarr collection with 2MB chunks and apply the algorithm to extract startgate / stopgate.

Then, store the new variables representing waveforms  startgate / stopgate to zarr, in the same collection, but new variables.

In [6]:
collection: zcollection.Collection = zcollection.open_collection(ZARR_BASE_PATH)
data: xr.Dataset = collection.load(filters='year == 2024 and month == 9').to_xarray()
data['waveform_20_plrm_ku'] = data['waveform_20_plrm_ku'].chunk({'time_20_c': 5000, 'echo_sample_ind': 128})
data = data.unify_chunks()
logger.info("Input xr_dataset: %s", data)

tau: float = 3.125e-9

# Check how many chunks are in the time_20_c dimension
chunk_size: int = 5000  # Number of waveforms in a chunk
waveform_len: int = 128  # Number of waveforms in a chunk
num_chunks: int = data["time_20_c"].shape[0] // chunk_size + (data["time_20_c"].shape[0] % chunk_size > 0)
logger.info("num_chunks: %s", num_chunks)

# Create template for map_blocks transformation
total_waveforms: int = data["time_20_c"].shape[0]

template: xr.Dataset = xr.Dataset({
    "waveform_20_plrm_ku_startgate": (["time_20_c"], da.full(total_waveforms, -1, dtype=np.int8, chunks=(chunk_size,))),
    "waveform_20_plrm_ku_stopgate": (["time_20_c"], da.full(total_waveforms, -1, dtype=np.int8, chunks=(chunk_size,))),
}, coords={
    "time_20_c": (["time_20_c"], data['time_20_c'].values),
})

# map 'identify_leading_edges' to each chunk and compute the leading edges
xr_dataset: xr.Dataset = data.map_blocks(
    identify_leading_edges, 
    args=(tau,),
    template=template
)
xr_dataset.compute()
logger.info("Output xr_dataset: %s", xr_dataset)

# Add new variables to zarr collection
collection: zcollection.Collection = zcollection.open_collection(ZARR_BASE_PATH, mode="w")
zds: zcollection.dataset.Dataset = zcollection.dataset.Dataset.from_xarray(xr_dataset)

collection_variables = [x.name for  x in collection.variables()]
new_variables = ['waveform_20_plrm_ku_startgate', 'waveform_20_plrm_ku_stopgate']
for new_variables in new_variables:
    if new_variables not in collection_variables:
        collection.add_variable(zds.metadata().variables[new_variables])

# Add the echo_sample_ind dimension, for zcollection insert to work. 
# zcollection expects same dimensions even if not used by new variables
xr_dataset: xr.Dataset = xr_dataset.expand_dims(dim={"echo_sample_ind": waveform_len}, axis=1)

# Finally, insert values for new variables
collection.insert(xr_dataset)
logger.info("Successfully wrote output variables to zcollection.")


2024-10-07 09:49:36.878 INFO     | 127.0.1.1 PID:8506 TID:139623468074048 | zcollection.collection from_config __init__.py:283 | Opening collection: '/tmp/sen3_sral_enhanced'
2024-10-07 09:49:36.973 INFO     | 127.0.1.1 PID:8506 TID:139623468074048 | __main__ <module> 3580136275.py:  5 | Input xr_dataset: <xarray.Dataset> Size: 213MB
Dimensions:              (time_20_c: 203088, echo_sample_ind: 128)
Coordinates:
  * time_20_c            (time_20_c) datetime64[ns] 2MB 2024-09-22T23:33:26.0...
  * echo_sample_ind      (echo_sample_ind) int8 128B 0 1 2 3 ... 124 125 126 127
Data variables:
    lon_20_c             (time_20_c) float64 2MB dask.array<chunksize=(5000,), meta=np.ndarray>
    waveform_20_plrm_ku  (time_20_c, echo_sample_ind) float64 208MB dask.array<chunksize=(5000, 128), meta=np.ndarray>
    lat_20_c             (time_20_c) float64 2MB dask.array<chunksize=(5000,), meta=np.ndarray>
Attributes: (12/64)
    Conventions:                                 CF-1.6
    absolute_pass_n

In [7]:
# Check if zarr collection holds data fro new variables

collection: zcollection.Collection = zcollection.open_collection(ZARR_BASE_PATH)
data: zcollection.dataset.Dataset = collection.load(filters='year == 2024 and month == 9').to_xarray()
logger.info("collection data: %s", data)

2024-10-07 09:49:58.749 INFO     | 127.0.1.1 PID:8506 TID:139623468074048 | zcollection.collection from_config __init__.py:283 | Opening collection: '/tmp/sen3_sral_enhanced'
2024-10-07 09:49:58.801 INFO     | 127.0.1.1 PID:8506 TID:139623468074048 | __main__ <module> 1706179248.py:  5 | collection data: <xarray.Dataset> Size: 265MB
Dimensions:                        (time_20_c: 203088, echo_sample_ind: 128)
Coordinates:
  * time_20_c                      (time_20_c) datetime64[ns] 2MB 2024-09-22T...
  * echo_sample_ind                (echo_sample_ind) int8 128B 0 1 2 ... 126 127
Data variables:
    waveform_20_plrm_ku            (time_20_c, echo_sample_ind) float64 208MB dask.array<chunksize=(203088, 128), meta=np.ndarray>
    lat_20_c                       (time_20_c) float64 2MB dask.array<chunksize=(203088,), meta=np.ndarray>
    lon_20_c                       (time_20_c) float64 2MB dask.array<chunksize=(203088,), meta=np.ndarray>
    waveform_20_plrm_ku_stopgate   (time_20_c, ech