## Identify leading edge

This is a test notebook for experimenting with leading-edge detection on along-track altimeter waveforms.

The algorithm implements the leading-edge detection step of the WHALES algorithm described in [this document](https://climate.esa.int/sites/default/files/Sea_State_cci_ATBD_v1.1-signed_0.pdf).

For now this is rough code, written with no optimization in mind:
- `for` loops over numpy arrays instead of vectorized numpy functions (derivative)
- dask distribution with micro-chunks, which is inefficient

The objective here, however, is to load zarr time ranges that have not been processed yet (the zarr store is partitioned by time range) and to use the chunks defined at zarr level.
The zarr variable used as input is *waveform_20_plrm_ku*, which holds 20 Hz data with 128 points per waveform. This variable is 2D, indexed by *time_20_ku* (timestamp of the waveform) and *echo_sample_ind* (index of each of the 128 waveform data points).
Ideally, chunks are in the MB range, so that using dask is worthwhile.
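
As a rough sizing aid, the arithmetic below estimates how many waveforms fit in a chunk of a given size. It assumes values are held as float64 once decoded, and the 16 MiB target is only an illustrative figure, not a recommendation taken from the zarr store. The 628-waveform count is the size of the single product loaded later in this notebook.

```python
POINTS_PER_WAVEFORM = 128       # length of echo_sample_ind
BYTES_PER_VALUE = 8             # assumption: float64 after decoding
WAVEFORMS_IN_SAMPLE_FILE = 628  # size of the product loaded later in this notebook

bytes_per_waveform = POINTS_PER_WAVEFORM * BYTES_PER_VALUE
sample_file_kb = WAVEFORMS_IN_SAMPLE_FILE * bytes_per_waveform / 1000

TARGET_CHUNK_BYTES = 16 * 1024**2  # hypothetical 16 MiB target
waveforms_per_chunk = TARGET_CHUNK_BYTES // bytes_per_waveform

print(f"One waveform: {bytes_per_waveform} bytes")
print(f"Sample file: about {sample_file_kb:.0f} kB")
print(f"Waveforms per 16 MiB chunk: {waveforms_per_chunk}")
```

Under these assumptions a single product is well below the MB range, so several products along the time axis would have to share a chunk.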

Each chunk contains multiple waveforms, to which we apply the leading-edge detection algorithm sequentially to derive the *startgate* and *stopgate* attributes of each waveform. Chunks are processed in parallel via dask.

The *startgate* and *stopgate* values are then stored in the variables *waveform_20_plrm_ku_startgate* and *waveform_20_plrm_ku_stopgate*, which are indexed by *time_20_ku*.
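
A minimal sketch of this storage step on a synthetic dataset. The `toy_gates` function is a hypothetical stand-in for the real detector, and the random waveforms and timestamps are invented for illustration. Only the variable and dimension names come from the description above.

```python
import numpy as np
import xarray as xr


def toy_gates(waveform: np.ndarray) -> np.ndarray:
    """Hypothetical stand-in for the real detector: returns [startgate, stopgate]."""
    startgate = int(np.argmax(np.diff(waveform))) + 1  # gate after the steepest rise
    stopgate = int(np.argmax(waveform))                # gate of maximum power
    return np.array([startgate, stopgate], dtype=np.float64)


# Synthetic stand-in for one block of the real dataset (values are random)
n_times = 6
rng = np.random.default_rng(0)
times = np.datetime64("2024-09-23T00:24:00", "ns") + np.arange(n_times) * np.timedelta64(50_000_000, "ns")
ds = xr.Dataset(
    {"waveform_20_plrm_ku": (("time_20_ku", "echo_sample_ind"), rng.random((n_times, 128)))},
    coords={"time_20_ku": times, "echo_sample_ind": np.arange(128)},
)

# Apply the detector to each waveform of the block, one row at a time
gates = np.apply_along_axis(toy_gates, 1, ds["waveform_20_plrm_ku"].values)  # shape (n_times, 2)

# Store both results as 1D variables indexed by time_20_ku
ds["waveform_20_plrm_ku_startgate"] = ("time_20_ku", gates[:, 0])
ds["waveform_20_plrm_ku_stopgate"] = ("time_20_ku", gates[:, 1])
```

Because both new variables share the *time_20_ku* dimension, they stay aligned with the waveforms they were derived from.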


In [1]:
import os
import sys
import numpy as np
import xarray as xr
import dask
import dask.distributed
import dask.array as da

from typing import Tuple, Optional

import eumdac.collection

sys.path.append(os.path.abspath('..'))
from src.connectors.eumdac_connector import EumdacConnector
from utils.opensearch_query_formatter import OpenSearchQueryFormatter

### The leading-edge detection algorithm, applied to one waveform

To refactor: make better use of numpy.

In [92]:
def identify_leading_edge(waveform: np.ndarray, tau: float) -> np.ndarray:
    """
    Identify the startgate and stopgate of a waveform.
    From algo described here: https://climate.esa.int/sites/default/files/Sea_State_cci_ATBD_v1.1-signed_0.pdf

    :param np.ndarray waveform: 1D numpy array of waveform values
    :param float tau: time spacing between consecutive gates in seconds (currently unused)
    :return: 1D np.ndarray with 2 float values:
      - startgate: index of the start gate (-1.0 if no valid leading edge is found)
      - stopgate: index of the stop gate (-1.0 if no valid leading edge is found, NaN if no stop gate is found)
    :rtype: np.ndarray

    """
    # Default result, returned when no valid leading edge is found
    res: np.ndarray = np.array([-1.0, -1.0], dtype=np.float64)
    try:
        # The waveform is normalised with normalisation factor N, 
        # where N = 1.3 * median(waveform)
        N: float = 1.3 * np.median(waveform)
        normalized_waveform: np.ndarray = waveform / N
    
        # The leading edge starts when the normalised waveform has a 
        # rise of 0.01 units compared to the previous gate (startgate)
        startgate: int | None = None
        for i in range(1, len(normalized_waveform)):
            if normalized_waveform[i] - normalized_waveform[i-1] >= 0.01:
                startgate = i
                break
    
        if startgate is None:
            return res  # No valid startgate found
    
        # At this point, the leading edge is considered valid if, for at least four gates 
        # after startgate, it does not decrease below 0.1 units (10% of the normalised power).
        valid_leading_edge: bool = (
            startgate + 4 < len(normalized_waveform) and
            np.all(normalized_waveform[startgate:startgate + 5] >= 0.1)
        )
    
        if not valid_leading_edge:
            return res  # Leading edge is not valid
    
        # The end of the leading edge (stopgate) is fixed at the first gate 
        # in which the derivative changes sign (i.e. the signal starts decreasing 
        # and the trailing edge begins), if the change of sign is kept 
        # for the following 3 gates.
        stopgate: int | None = None
        for i in range(startgate + 1, len(normalized_waveform) - 3):
            # Calculate the derivative
            derivative: float = normalized_waveform[i + 1] - normalized_waveform[i]
            if derivative < 0:  # Start of decrease
                # Check if the derivative stays negative for the next 3 gates
                if (normalized_waveform[i + 2] - normalized_waveform[i + 1] < 0 and
                    normalized_waveform[i + 3] - normalized_waveform[i + 2] < 0):
                    stopgate = i
                    break
    
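        # A stopgate of None is stored as NaN in the float64 result
        res = np.array([startgate, stopgate], dtype=np.float64)
        return res
    except Exception:
        # Any unexpected failure falls back to the "not found" result
        return res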

In [83]:
if __name__ == "__main__":
    # Simulated waveform data
    waveform: np.ndarray = np.array([0.0, 0.0, 0.0, 0.0001, 0.0015, 0.002, 0.004, 0.01, 0.02, 0.05, 0.15, 0.1, 0.08, 0.07, 0.03, 0.01, 0.0])
    tau: float = 3.125e-9  # Time spacing in seconds

    res = identify_leading_edge(waveform, tau)
    print(res)
    startgate, stopgate = res
    print(f"Startgate: {startgate}, Stopgate: {stopgate}")

(17,)
[ 4. 10.]
Startgate: 4.0, Stopgate: 10.0
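
The loops in `identify_leading_edge` can probably be replaced by `np.diff` and boolean masks. The following is a sketch rather than a drop-in replacement: the function name `identify_leading_edge_vectorized` and the guard on a non-positive median are assumptions added here, and `tau` is dropped because the detection does not use it.

```python
import numpy as np


def identify_leading_edge_vectorized(waveform: np.ndarray) -> np.ndarray:
    """Return [startgate, stopgate] using the same rules as the loop version."""
    res = np.array([-1.0, -1.0], dtype=np.float64)

    median = np.median(waveform)
    if not np.isfinite(median) or median <= 0:
        return res  # assumption: such a waveform cannot be normalised

    norm = waveform / (1.3 * median)
    diff = np.diff(norm)  # diff[k] = norm[k + 1] - norm[k]

    # startgate: first gate rising by at least 0.01 over the previous gate
    rises = np.flatnonzero(diff >= 0.01)
    if rises.size == 0:
        return res
    startgate = int(rises[0]) + 1

    # validity: startgate and the four following gates stay at or above 0.1
    if startgate + 4 >= norm.size or not np.all(norm[startgate:startgate + 5] >= 0.1):
        return res

    # stopgate: first gate after startgate followed by three negative differences
    neg = diff < 0
    run = neg[:-2] & neg[1:-1] & neg[2:]  # run[i]: diff[i], diff[i+1], diff[i+2] all < 0
    candidates = np.flatnonzero(run[startgate + 1:])
    stopgate = float(startgate + 1 + candidates[0]) if candidates.size else np.nan

    return np.array([startgate, stopgate], dtype=np.float64)


waveform = np.array([0.0, 0.0, 0.0, 0.0001, 0.0015, 0.002, 0.004, 0.01, 0.02,
                     0.05, 0.15, 0.1, 0.08, 0.07, 0.03, 0.01, 0.0])
print(identify_leading_edge_vectorized(waveform))
```

On the simulated waveform used above, this returns the same start gate (4) and stop gate (10) as the loop version.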


### Download datasets containing waveforms (Sentinel-3) and store them to zarr

**TODO**: zarr storage

In [3]:
COLLECTION_ID: str = "EO:EUM:DAT:0415"
DOWNLOAD_DIR: str = "/tmp/products"
MEASUREMENTS_FILENAME: str = "enhanced_measurement.nc"
INDEX_DIMENSION: str = "time_01"
download_dir: str = os.path.join(os.getcwd(), DOWNLOAD_DIR)

connector: EumdacConnector = EumdacConnector()
datastore: eumdac.datastore.DataStore = connector.datastore

# Query a few data files for Sentinel3A and 3B SRAL (Level2 data)
opensearch_query: str = OpenSearchQueryFormatter(
    query_params={
        "pi": COLLECTION_ID,
        "dtstart": "2024-09-23T00:20:00Z",
        "dtend": "2024-09-23T00:30:00Z",
    }
).format()
products: eumdac.collection.SearchResults = datastore.opensearch(query=opensearch_query)
product_ids: list[str] = [str(x) for x in products]
# If in local mode, process only a subset of the products for faster execution
# LOCAL_MODE defaults to "1"; set it to "0" or to an empty string to process all products
if os.getenv("LOCAL_MODE", "1") not in ("", "0"):
    print("Local mode: processing every 50th product to debug faster")
    product_ids = product_ids[::50]
print(f"{len(product_ids)} matching products found")
print(f"Listed products are: {product_ids}")

# Download files - benefits of dask parallelization
print("Downloading products (dask parallelized)...")
downloaded_folders: list[str] = connector.download_products(
    COLLECTION_ID, product_ids, download_dir, MEASUREMENTS_FILENAME
)

Local mode: processing every 50th product to debug faster
1 matching products found
Listed products are: ['S3B_SR_2_WAT____20240923T002358_20240923T002431_20240923T015119_0033_098_016______MAR_O_NR_005.SEN3']
Downloading products (dask parallelized)...


Perhaps you already have a cluster running?
Hosting the HTTP server on port 43047 instead


### WIP

To be fully refactored: load from zarr and process bigger chunks.
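
One possible shape for that refactoring, sketched on synthetic data: each dask block holds many waveforms, and `map_blocks` is told explicitly that a block of shape `(n, 128)` becomes a block of shape `(n, 2)`. The `toy_detect_block` function is a hypothetical placeholder for the real per-block detection, and the chunk length of 256 is arbitrary.

```python
import dask.array as da
import numpy as np


def toy_detect_block(block: np.ndarray) -> np.ndarray:
    """Hypothetical placeholder: maps an (n, 128) block to an (n, 2) array of gates."""
    start = np.argmax(np.diff(block, axis=1), axis=1) + 1  # gate after the steepest rise
    stop = np.argmax(block, axis=1)                        # gate of maximum power
    return np.stack([start, stop], axis=1).astype(np.float64)


# Synthetic stand-in for waveform_20_plrm_ku (values are random)
rng = np.random.default_rng(0)
waveforms = rng.random((1000, 128))
lazy = da.from_array(waveforms, chunks=(256, 128))

gates = lazy.map_blocks(
    toy_detect_block,
    # Same split along time, but only 2 values per waveform on the second axis
    chunks=(lazy.chunks[0], (2,)),
    dtype=np.float64,
    meta=np.empty((0, 2), dtype=np.float64),
)

# The synchronous scheduler keeps the sketch independent of any running cluster
result = gates.compute(scheduler="synchronous")
print(result.shape)
```

Passing `chunks` matters here because, without it, dask assumes the output blocks have the same structure as the input blocks.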

In [6]:
# Use a local cluster with processes only: the work is CPU-intensive, so threads would be limited by the GIL
cluster: dask.distributed.LocalCluster = dask.distributed.LocalCluster(processes=True)
client: dask.distributed.Client = dask.distributed.Client(cluster)

# Process only 1 file
ds = xr.open_dataset(f"{DOWNLOAD_DIR}/S3B_SR_2_WAT____20240923T002358_20240923T002431_20240923T015119_0033_098_016______MAR_O_NR_005.SEN3/{MEASUREMENTS_FILENAME}")
ds.close()

print(ds)
lrm_da = ds["waveform_20_plrm_ku"]

# This is knowingly very inefficient: each dask task gets about 1 kB of data -> chunks in the MB range are the usual dask guideline
# It is here only to learn how map_blocks works
lrm_dda = da.from_array(lrm_da, chunks=(1, 128))
tau: float = 3.125e-9

# identify_leading_edge(waveform: np.ndarray, tau: float) -> np.array 1D with 2 values
result_array = lrm_dda.map_blocks(
    # Squeeze the 2D array (1, 128) to 1D (128) as expected by 'identify_leading_edge'
    lambda x: identify_leading_edge(np.squeeze(x), tau),
    meta=np.zeros(2, dtype=np.float64)
)
result = result_array.compute()
print(result)

client.close()
cluster.close()

Perhaps you already have a cluster running?
Hosting the HTTP server on port 41935 instead


<xarray.Dataset> Size: 3MB
Dimensions:                                        (time_01: 33,
                                                    time_20_ku: 628,
                                                    time_20_c: 628,
                                                    echo_sample_ind: 128)
Coordinates:
  * time_01                                        (time_01) datetime64[ns] 264B ...
  * time_20_ku                                     (time_20_ku) datetime64[ns] 5kB ...
  * time_20_c                                      (time_20_c) datetime64[ns] 5kB ...
  * echo_sample_ind                                (echo_sample_ind) int8 128B ...
    lat_01                                         (time_01) float64 264B ...
    lon_01                                         (time_01) float64 264B ...
    lat_20_ku                                      (time_20_ku) float64 5kB ...
    lon_20_ku                                      (time_20_ku) float64 5kB ...
    lat_20_c               

2024-10-05 22:01:19,635 - distributed.worker - ERROR - Compute Failed
Key:       ('lambda-c9198e16ddca1e5fdea9fc365bdbf830', 95, 0)
State:     executing
Function:  subgraph_callable-369879cfed3b0dbf39fe955e1502bb9d
args:      (<xarray.DataArray 'waveform_20_plrm_ku' (time_20_c: 628, echo_sample_ind: 128)> Size: 643kB
array([[ 9.421,  8.474,  5.44 , ..., 26.968, 25.114, 12.534],
       [ 6.666,  4.541,  3.079, ..., 14.956, 10.474,  8.22 ],
       [ 8.677,  6.729,  4.366, ..., 20.98 , 17.217, 16.225],
       ...,
       [23.719, 22.934, 21.37 , ..., 38.054, 41.226, 28.066],
       [26.14 , 24.18 , 20.693, ..., 41.975, 39.995, 33.043],
       [29.138, 23.461, 18.273, ..., 38.704, 26.502, 27.926]])
Coordinates:
  * time_20_c        (time_20_c) datetime64[ns] 5kB 2024-09-23T00:23:57.53276...
  * echo_sample_ind  (echo_sample_ind) int8 128B 0 1 2 3 4 ... 124 125 126 127
    lat_20_c         (time_20_c) float64 5kB -65.15 -65.16 ... -67.0 -67.0
    lon_20_c         (time_20_c) float64 5kB 125

NameError: name 'identify_leading_edge' is not defined

rror("name \'identify_leading_edge\' is not defined")'
Traceback: '  File "/home/abonnin/anaconda3/envs/sentinel3-sral-demo/lib/python3.12/site-packages/dask/optimization.py", line 1001, in __call__\n    return core.get(self.dsk, self.outkey, dict(zip(self.inkeys, args)))\n           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n  File "/home/abonnin/anaconda3/envs/sentinel3-sral-demo/lib/python3.12/site-packages/dask/core.py", line 157, in get\n    result = _execute_task(task, cache)\n             ^^^^^^^^^^^^^^^^^^^^^^^^^^\n  File "/home/abonnin/anaconda3/envs/sentinel3-sral-demo/lib/python3.12/site-packages/dask/core.py", line 127, in _execute_task\n    return func(*(_execute_task(a, cache) for a in args))\n           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n  File "/tmp/ipykernel_3165/3513934251.py", line 20, in <lambda>\n'

2024-10-05 22:01:19,783 - distributed.worker - ERROR - Compute Failed
Key:       ('lambda-c9198e16ddca1e5fdea9fc365bdbf830', 622, 0)
St