## Identify leading edge

This is a test notebook for experimenting with leading-edge detection on along-track altimeter waveforms.

The algorithm implements the leading-edge detection step of the WHALES algorithm, described in [this document](https://climate.esa.int/sites/default/files/Sea_State_cci_ATBD_v1.1-signed_0.pdf).
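As implemented in `identify_leading_edges` below (the ATBD may differ in details), the detection rules can be summarised as follows. Here $w_i$ is the waveform power at gate $i$, $G$ the number of gates, $s$ the startgate and $e$ the stopgate:

$$
\begin{aligned}
\tilde{w}_i &= \frac{w_i}{1.3\,\operatorname{median}(w)}, \qquad d_i = \tilde{w}_{i+1} - \tilde{w}_i \quad (i = 0, \dots, G-2) \\
s &= \min\{\, i \ge 1 : d_{i-1} \ge 0.01 \,\}, \quad \text{valid only if } s + 4 < G \text{ and } \tilde{w}_s, \dots, \tilde{w}_{s+4} \ge 0.1 \\
e &= \min\{\, j \ge 1 : d_{j-1} > 0,\ d_j < 0,\ d_{j+1} < 0,\ d_{j+2} < 0,\ j + 3 < G \,\}
\end{aligned}
$$

$e$ is only searched for when $s$ is valid; otherwise both stay undefined (NaN).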

For now the code is exploratory and not optimized at all:
- `for` loops over NumPy arrays instead of vectorized NumPy functions (e.g. for the derivative)
- Dask distribution with micro-chunks is inefficient
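As an illustration of the vectorization mentioned above, the startgate search can be run on a whole chunk at once instead of looping over waveforms. This is a minimal sketch assuming a 2D NumPy array of shape `(n_waveforms, n_gates)` with at least 5 gates and the thresholds used in `identify_leading_edges` (0.01 on the normalized derivative, 0.1 over 5 gates); `vectorized_startgates` is a hypothetical helper, not part of the project code, and `sliding_window_view` needs NumPy 1.20 or later.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def vectorized_startgates(waveforms: np.ndarray) -> np.ndarray:
    """Return the startgate of each waveform as float64, NaN when none is valid.

    Hypothetical vectorized counterpart of the per-waveform loop: same
    normalization (1.3 x median) and same thresholds as identify_leading_edges.
    """
    n_waveforms, n_gates = waveforms.shape
    normalized = waveforms / (1.3 * np.median(waveforms, axis=1, keepdims=True))

    # First gate whose forward difference reaches 0.01 (diff index k -> gate k + 1)
    rising = np.diff(normalized, axis=1) >= 0.01
    has_candidate = rising.any(axis=1)
    candidate = np.argmax(rising, axis=1) + 1

    # Gates candidate..candidate+4 must all be >= 0.1: evaluate every 5-gate window once
    above = normalized >= 0.1
    windows_ok = sliding_window_view(above, 5, axis=1).all(axis=2)  # shape (n, n_gates - 4)
    in_range = candidate + 4 < n_gates
    safe_index = np.minimum(candidate, n_gates - 5)  # avoid out-of-bounds lookups
    window_ok = windows_ok[np.arange(n_waveforms), safe_index]

    valid = has_candidate & in_range & window_ok
    return np.where(valid, candidate, np.nan).astype(np.float64)


waveforms = np.array([
    [0.0, 0.0, 0.0, 0.0001, 0.0015, 0.002, 0.004, 0.01, 0.02, 0.05, 0.15, 0.1, 0.08, 0.07, 0.03, 0.01, 0.0],
    [1.0] * 17,  # flat waveform: no leading edge
])
print(vectorized_startgates(waveforms))
```

The flat waveform has no rising gate, so it gets NaN, which keeps the output float64 like the template used later with `map_blocks`.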

The objective, however, is to load Zarr time ranges that have not been processed yet (the Zarr collection is partitioned by time range) and to reuse the chunks defined at the Zarr level.
The input Zarr variable is *waveform_20_plrm_ku*: 20 Hz data, with 128 points per waveform. This variable is 2D, indexed by *time_20_c* (timestamp of the waveform) and *echo_sample_ind* (index of each of the 128 waveform points).
Ideally, chunks are in the MB range, so that the Dask overhead is worth it.
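To make the "MB range" concrete, the uncompressed size of one chunk is rows × 128 × itemsize. A quick check for the 5000-waveform chunks used below; the download cell comment assumes int32 values on disk, while the dataset loaded in the WIP cell holds float64 waveforms, so an in-memory chunk is twice as large (compression is ignored here).

```python
import numpy as np


def chunk_nbytes(n_waveforms: int, n_gates: int = 128, dtype=np.float64) -> int:
    """Uncompressed size in bytes of one (n_waveforms, n_gates) chunk."""
    return n_waveforms * n_gates * np.dtype(dtype).itemsize


for dtype in (np.int32, np.float64):
    size_mb = chunk_nbytes(5000, dtype=dtype) / 1e6
    print(f"{np.dtype(dtype).name}: {size_mb:.2f} MB per 5000-waveform chunk")
print(f"{5000 / 20:.0f} s of data per chunk at 20 Hz")
# int32: 2.56 MB per 5000-waveform chunk
# float64: 5.12 MB per 5000-waveform chunk
# 250 s of data per chunk at 20 Hz
```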

Each chunk holds several waveforms, to which the leading-edge detection algorithm is applied sequentially to derive the *startgate* and *stopgate* of each waveform. Chunks are processed in parallel by Dask.
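At the `dask.array` level, this "one chunk in, one chunk out" pattern can be expressed with `dask.array.map_blocks` and `drop_axis`, since each 2D block (waveforms × gates) is reduced to a 1D block along time. A minimal sketch on random data, assuming the gate axis is a single chunk of 128; `peak_gate` is a hypothetical stand-in for the real per-chunk detection (it simply returns the index of the highest gate).

```python
import dask.array as da
import numpy as np


def peak_gate(block: np.ndarray) -> np.ndarray:
    """Stand-in per-chunk function: index of the highest gate of each waveform."""
    return np.argmax(block, axis=1).astype(np.float64)


# 12 waveforms of 128 gates, split in chunks of 5 waveforms -> (5, 5, 2) along time
waveforms = da.random.random((12, 128), chunks=(5, 128))

# drop_axis=1: each 2D block becomes a 1D block, so output chunks mirror the time chunks
gates = da.map_blocks(peak_gate, waveforms, drop_axis=1, dtype=np.float64)

print(gates.chunks)
print(gates.compute().shape)
```

The output keeps one block per input block, which is the same constraint `xarray`'s `map_blocks` enforces through its `template` argument.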

The *startgate* and *stopgate* values are then stored in the variables *waveform_20_plrm_ku_startgate* and *waveform_20_plrm_ku_stopgate*, indexed by *time_20_c*.
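Since the per-waveform results share the `time_20_c` dimension with the input, they can be added to the dataset as new variables directly. A sketch with synthetic values: the timestamps (50 ms apart, i.e. 20 Hz), the random waveforms and the gate values are made up for illustration.

```python
import numpy as np
import xarray as xr

# Four synthetic 20 Hz timestamps
time = np.datetime64("2024-09-23T00:00:00", "ns") + np.arange(4) * np.timedelta64(50, "ms")
ds = xr.Dataset(
    {"waveform_20_plrm_ku": (("time_20_c", "echo_sample_ind"), np.random.default_rng(0).random((4, 128)))},
    coords={"time_20_c": time, "echo_sample_ind": np.arange(128)},
)

# Per-waveform results, shaped like the output of identify_leading_edges (NaN = not found)
res = xr.Dataset(
    {
        "startgate": ("time_20_c", np.array([4.0, 5.0, np.nan, 4.0])),
        "stopgate": ("time_20_c", np.array([10.0, 11.0, np.nan, 9.0])),
    },
    coords={"time_20_c": time},
)

# Store them as new variables tied to the time_20_c index
ds["waveform_20_plrm_ku_startgate"] = res["startgate"]
ds["waveform_20_plrm_ku_stopgate"] = res["stopgate"]
print(ds)
```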


In [3]:
import logging
import os
import sys
import numpy as np
import xarray as xr
import dask
import dask.distributed
import dask.array as da
import zcollection

from typing import Tuple, Optional

import eumdac.collection
import eumdac.datastore

sys.path.append(os.path.abspath('..'))
from src.connectors.eumdac_connector import EumdacConnector
from src.processors.zarr_processor import ZarrProcessor
from utils.logging_utils import setup_root_logging, setup_module_logger
from utils.opensearch_query_formatter import OpenSearchQueryFormatter

### The leading-edge detection algorithm, applied waveform by waveform on a chunk

To refactor: make better use of NumPy (vectorize the per-waveform loop).
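The stopgate loop can be vectorized in the same spirit. A minimal sketch assuming normalized waveforms of shape `(n_waveforms, n_gates)` and the same rule as the loop below (first local maximum followed by three decreasing steps); `vectorized_stopgates` is a hypothetical helper, and, like the loop, it does not restrict the search to gates after the startgate. The caller would still mask waveforms without a valid startgate.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def vectorized_stopgates(normalized: np.ndarray) -> np.ndarray:
    """First gate j with d[j-1] > 0 and d[j], d[j+1], d[j+2] < 0 (d = np.diff), else NaN."""
    d = np.diff(normalized, axis=1)              # shape (n, G - 1)
    # Window k holds d[k], d[k+1], d[k+2], d[k+3]; a hit means gate j = k + 1 qualifies
    windows = sliding_window_view(d, 4, axis=1)  # shape (n, G - 4, 4)
    hits = (windows[..., 0] > 0) & (windows[..., 1:] < 0).all(axis=2)
    first = np.argmax(hits, axis=1) + 1
    return np.where(hits.any(axis=1), first, np.nan).astype(np.float64)


w = np.array([[0.0, 0.0, 0.0, 0.0001, 0.0015, 0.002, 0.004, 0.01, 0.02, 0.05, 0.15, 0.1, 0.08, 0.07, 0.03, 0.01, 0.0]])
normalized = w / (1.3 * np.median(w, axis=1, keepdims=True))
print(vectorized_stopgates(normalized))
```

The window length of 4 derivatives covers both the sign change at the peak and the three following decreasing steps, so the `j + 3 < G` bound of the loop is enforced by the window count itself.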

In [32]:
def identify_leading_edges(data_chunk: xr.Dataset, tau: float) -> xr.Dataset:
    """
    Identify the startgate and stopgate of each waveform in a dataset chunk.

    :param xr.Dataset data_chunk: Input dataset chunk containing waveform_20_plrm_ku with shape [time_20_c, echo_sample_ind]
    :param float tau: Time spacing between consecutive gates in seconds (not used yet)
    :return: xr.Dataset containing float64 startgate and stopgate variables indexed by time_20_c (NaN when not found)
    :rtype: xr.Dataset
    """
    waveforms = data_chunk['waveform_20_plrm_ku'].values  # shape (num_waveforms, num_gates), e.g. (5000, 128)
    num_waveforms, num_gates = waveforms.shape
    # float64 outputs so that waveforms without a valid leading edge keep NaN
    startgates = np.full(num_waveforms, np.nan, dtype=np.float64)
    stopgates = np.full(num_waveforms, np.nan, dtype=np.float64)

    # Normalize each waveform by 1.3 times its median
    medians = np.median(waveforms, axis=1, keepdims=True)
    normalized_waveforms = waveforms / (1.3 * medians)

    # Startgate candidate: first gate where the normalized derivative reaches 0.01
    # (np.diff index k compares gates k and k + 1, hence the + 1)
    normalized_diff = np.diff(normalized_waveforms, axis=1)
    startgate_candidates = normalized_diff >= 0.01
    has_candidate = startgate_candidates.any(axis=1)
    first_candidates = np.argmax(startgate_candidates, axis=1) + 1

    for i in np.flatnonzero(has_candidate):
        startgate = first_candidates[i]

        # The leading edge must stay >= 0.1 over the startgate and the 4 following gates
        if startgate + 4 < num_gates and np.all(normalized_waveforms[i, startgate:startgate + 5] >= 0.1):
            startgates[i] = startgate

            # Stopgate candidates: local maxima (derivative goes from > 0 to < 0)
            derivatives = normalized_diff[i]
            stopgate_candidates = np.flatnonzero((derivatives[:-1] > 0) & (derivatives[1:] < 0)) + 1

            # Keep the first peak followed by 3 decreasing gates
            for stopgate in stopgate_candidates:
                if stopgate + 3 < num_gates and np.all(derivatives[stopgate:stopgate + 3] < 0):
                    stopgates[i] = stopgate
                    break

    return xr.Dataset({
        "startgate": (["time_20_c"], startgates),
        "stopgate": (["time_20_c"], stopgates),
    })

In [34]:
if __name__ == "__main__":
    # Simulated waveform data: a single 17-gate waveform
    waveforms = xr.Dataset({
        "waveform_20_plrm_ku": xr.DataArray(
            [[0.0, 0.0, 0.0, 0.0001, 0.0015, 0.002, 0.004, 0.01, 0.02, 0.05, 0.15, 0.1, 0.08, 0.07, 0.03, 0.01, 0.0]],
            dims=["time_20_c", "echo_sample_ind"],
            coords={"time_20_c": np.arange(1), "echo_sample_ind": np.arange(17)},
        )
    })

    tau: float = 3.125e-9  # Time spacing between gates, in seconds

    res = identify_leading_edges(waveforms, tau)
    print(res)

<xarray.Dataset> Size: 16B
Dimensions:    (time_20_c: 1)
Dimensions without coordinates: time_20_c
Data variables:
    startgate  (time_20_c) float64 8B 4.0
    stopgate   (time_20_c) float64 8B 10.0


### Download Sentinel-3 datasets containing waveforms and store them in Zarr

**TODO**: zarr storage

In [5]:
setup_root_logging()

logger: logging.Logger = setup_module_logger(__name__)

COLLECTION_ID: str = "EO:EUM:DAT:0415"
DOWNLOAD_DIR: str = "/tmp/products"
MEASUREMENTS_FILENAME: str = "enhanced_measurement.nc"
ZARR_BASE_PATH: str = "/tmp/sen3_sral_enhanced"
INDEX_DIMENSION: str = "time_20_c"
# DOWNLOAD_DIR is absolute, so os.path.join ignores os.getcwd() and returns "/tmp/products"
download_dir: str = os.path.join(os.getcwd(), DOWNLOAD_DIR)

In [2]:

if __name__ == "__main__":
    logger.info("Connecting EUMDAC datastore...")
    connector: EumdacConnector = EumdacConnector()
    datastore: eumdac.datastore.DataStore = connector.datastore

    # Query a few data files for Sentinel-3A and 3B SRAL (Level 2 data) for 2024-09-23, 00:00 to 00:20 UTC
    opensearch_query: str = OpenSearchQueryFormatter(
        query_params={
            "pi": COLLECTION_ID,
            "dtstart": "2024-09-23T00:00:00Z",
            "dtend": "2024-09-23T00:20:00Z",
        }
    ).format()
    logger.info("Listing EUMDAC products matching filters '%s'", opensearch_query)
    products: eumdac.collection.SearchResults = datastore.opensearch(query=opensearch_query)
    product_ids: list[str] = [str(x) for x in products]

    logger.info("%s matching products found", len(product_ids))
    logger.debug("Listed products are: %s", product_ids)

    # Download files in parallel with Dask
    logger.info("Downloading products (dask parallelized)...")
    downloaded_folders: list[str] = connector.download_products(
        COLLECTION_ID, product_ids, download_dir, MEASUREMENTS_FILENAME
    )

    # Store files in a partitioned Zarr collection
    # Partition by month (resolution='M', with zcollection); to be tuned to data usage and volume
    logger.info("Persisting data in a partitionned zarr collection...")
    netcdf_file_paths: list[str] = [
        os.path.join(folder, MEASUREMENTS_FILENAME) for folder in downloaded_folders
    ]
    partition_handler: zcollection.partitioning.Partitioning = zcollection.partitioning.Date(
        (INDEX_DIMENSION,), resolution='M'
    )

    zarr_processor: ZarrProcessor = ZarrProcessor(
        ZARR_BASE_PATH, 
        partition_handler, 
        index_dimension=INDEX_DIMENSION,
    )
    zarr_processor.netcdf_2_zarr(
        netcdf_file_paths,
        variables=["waveform_20_plrm_ku"],  # only extract waveform_20_plrm_ku
        chunk_sizes={
            # 5000 waveforms at 20 Hz is about 4 min of data; as int32 that is about 2.5 MB per chunk
            "waveform_20_plrm_ku": (5000, 128)
        }
    )
    logger.info("Job done")

2024-10-06 22:33:45.461 INFO     | 127.0.1.1 PID:7456 TID:140623667745856 | __main__ <module> 3905230125.py: 13 | Connecting EUMDAC datastore...
2024-10-06 22:33:45.624 INFO     | 127.0.1.1 PID:7456 TID:140623667745856 | __main__ <module> 3905230125.py: 25 | Listing EUMDAC products matching filters 'pi=EO:EUM:DAT:0415&dtstart=2024-09-23T00:00:00Z&dtend=2024-09-23T00:20:00Z'
2024-10-06 22:33:46.959 INFO     | 127.0.1.1 PID:7456 TID:140623667745856 | __main__ <module> 3905230125.py: 29 | 10 matching products found
2024-10-06 22:33:46.961 INFO     | 127.0.1.1 PID:7456 TID:140623667745856 | __main__ <module> 3905230125.py: 33 | Downloading products (dask parallelized)...
2024-10-06 22:33:47.367 INFO     | 127.0.1.1 PID:7456 TID:140623667745856 | src.connectors.eumdac_connector download_products eumdac_connector.py:177 | Downloading products...
2024-10-06 22:34:07.095 INFO     | 127.0.1.1 PID:7456 TID:140623667745856 | __main__ <module> 3905230125.py: 40 | Persisting data in a partitionned 
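The partitioning in the cell above uses `zcollection.partitioning.Date` with `resolution='M'`, i.e. one partition per calendar month, which is what the `year == 2024 and month == 9` filter in the WIP cell selects. The grouping idea can be illustrated with plain NumPy; this mimics the concept, not zcollection's internals, and the timestamps are synthetic (20 Hz samples straddling a month boundary).

```python
import numpy as np

# Synthetic 20 Hz timestamps crossing from September into October (illustration only)
times = np.datetime64("2024-09-30T23:59:59", "ns") + np.arange(60) * np.timedelta64(50, "ms")

months = times.astype("datetime64[M]")  # truncate each timestamp to its month
keys, counts = np.unique(months, return_counts=True)
for key, count in zip(keys, counts):
    year, month = str(key).split("-")
    print(f"year={int(year)} month={int(month)}: {count} records")
# year=2024 month=9: 20 records
# year=2024 month=10: 40 records
```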

### WIP

To be fully refactored: load from Zarr and process bigger chunks.
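For bigger chunks, one option is to derive the time-chunk length from a target chunk size and rechunk before calling `map_blocks`. A sketch assuming a 64 MB target (an arbitrary value, to be tuned) and float64 waveforms of 128 gates; the input is a lazy synthetic dataset chunked like the notebook's (5000 waveforms per chunk), with the `time_20_c` length taken from the dataset printout below.

```python
import dask.array as da
import numpy as np
import xarray as xr

TARGET_BYTES = 64e6  # assumed target size per chunk, to be tuned
n_gates, itemsize = 128, np.dtype(np.float64).itemsize
time_chunk = int(TARGET_BYTES // (n_gates * itemsize))  # 62500 waveforms per chunk

n_waveforms = 203088  # length of time_20_c in the loaded dataset
ds = xr.Dataset(
    {"waveform_20_plrm_ku": (("time_20_c", "echo_sample_ind"),
                             da.zeros((n_waveforms, n_gates), chunks=(5000, n_gates)))}
)

# Bigger chunks along time, one chunk along the gates (-1 = whole dimension)
rechunked = ds.chunk({"time_20_c": time_chunk, "echo_sample_ind": -1})
print(rechunked["waveform_20_plrm_ku"].data.numblocks)  # (4, 1)
```

With 4 blocks instead of 41, the scheduling overhead per waveform drops accordingly, at the cost of more memory per task.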

In [36]:
import numpy as np
import xarray as xr
import dask.array as da

collection = zcollection.open_collection(ZARR_BASE_PATH)
print(collection)
data = collection.load(filters='year == 2024 and month == 9').to_xarray()
data['waveform_20_plrm_ku'] = data['waveform_20_plrm_ku'].chunk({'time_20_c': 5000, 'echo_sample_ind': 128})
print(data)

data = data.unify_chunks()
print(data["time_20_c"])

tau: float = 3.125e-9

# Check how many chunks are in the time_20_c dimension
num_chunks = data["time_20_c"].shape[0] // 5000 + (data["time_20_c"].shape[0] % 5000 > 0)
print(num_chunks)

# Create Dask arrays for the template
num_waveforms = 5000  # Number of waveforms in a chunk
num_data_points = 128  # Number of data points per waveform

template: xr.Dataset = xr.Dataset({
    "startgate": xr.DataArray(
        # 1D Dask array with num_waveforms points (1 per input waveform)
        da.empty(shape=(num_waveforms,), chunks=(num_waveforms,), dtype=np.float64),
        dims=["time_20_c"],
        coords={"time_20_c": np.arange(num_waveforms)},
    ), 
    "stopgate": xr.DataArray(
        # 1D Dask array with num_waveforms points (1 per input waveform)
        da.empty(shape=(num_waveforms,), chunks=(num_waveforms,), dtype=np.float64), 
        dims=["time_20_c"],
        coords={"time_20_c": np.arange(num_waveforms)},
    ),
})

# Now call map_blocks with the updated template
res = data.map_blocks(
    identify_leading_edges, 
    args=(tau,),
    template=template
)
res.compute()


2024-10-06 23:45:27.352 INFO     | 127.0.1.1 PID:9642 TID:140550285526080 | zcollection.collection from_config __init__.py:283 | Opening collection: '/tmp/sen3_sral_enhanced'


<Collection filesystem='LocalFileSystem', partition_base_dir='/tmp/sen3_sral_enhanced'mode='r'>
<xarray.Dataset> Size: 213MB
Dimensions:              (time_20_c: 203088, echo_sample_ind: 128)
Coordinates:
  * time_20_c            (time_20_c) datetime64[ns] 2MB 2024-09-22T23:33:26.0...
  * echo_sample_ind      (echo_sample_ind) int8 128B 0 1 2 3 ... 124 125 126 127
Data variables:
    lon_20_c             (time_20_c) float64 2MB dask.array<chunksize=(203088,), meta=np.ndarray>
    lat_20_c             (time_20_c) float64 2MB dask.array<chunksize=(203088,), meta=np.ndarray>
    waveform_20_plrm_ku  (time_20_c, echo_sample_ind) float64 208MB dask.array<chunksize=(5000, 128), meta=np.ndarray>
Attributes: (12/64)
    Conventions:                                 CF-1.6
    absolute_pass_number:                        89584
    absolute_rev_number:                         44792
    acq_station_name:                            CGS
    algo_bias_wind_speed_2p_sig0_plrm_ku_added:  2.7525
    alg

ValueError: map_blocks requires that one block of the input maps to one block of output. Expected number of output chunks along dimension 'time_20_c' to be 41. Received 1 instead. Please provide template if not provided, or fix the provided template.
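The error comes from the template: `map_blocks` expects the template to be chunked like the input along `time_20_c` (41 chunks here, since 203088 / 5000 = 40.6 rounds up to 41), whereas it was built as a single 5000-element chunk. A sketch of a template derived from the input chunks, on a small synthetic dataset (12 waveforms in chunks of 5); `per_chunk` is a hypothetical stand-in for `identify_leading_edges` that, like it, returns `startgate`/`stopgate` without coordinates, and `time_20_c` is reattached after `map_blocks`.

```python
import dask.array as da
import numpy as np
import xarray as xr


def per_chunk(chunk: xr.Dataset) -> xr.Dataset:
    """Stand-in for identify_leading_edges: one startgate/stopgate per waveform."""
    peak = np.argmax(chunk["waveform_20_plrm_ku"].values, axis=1).astype(np.float64)
    return xr.Dataset({"startgate": ("time_20_c", peak - 1.0), "stopgate": ("time_20_c", peak)})


n_waveforms, n_gates = 12, 128
data = xr.Dataset(
    {"waveform_20_plrm_ku": (("time_20_c", "echo_sample_ind"),
                             da.random.random((n_waveforms, n_gates), chunks=(5, n_gates)))},
    coords={"time_20_c": np.arange(n_waveforms), "echo_sample_ind": np.arange(n_gates)},
)

# Template: same length and same chunks as the input along time_20_c
time_chunks = data["waveform_20_plrm_ku"].chunks[0]  # (5, 5, 2)
empty = da.empty((n_waveforms,), chunks=(time_chunks,), dtype=np.float64)
template = xr.Dataset({"startgate": ("time_20_c", empty), "stopgate": ("time_20_c", empty)})

res = data.map_blocks(per_chunk, template=template)
res = res.assign_coords(time_20_c=data["time_20_c"])  # reattach the time index
print(res.compute())
```

On the real data, the same construction would use `data.sizes["time_20_c"]` and `data["waveform_20_plrm_ku"].chunks[0]`, and should work with `identify_leading_edges` as long as it returns float64 variables. If the template carries the `time_20_c` coordinate instead, the per-chunk function is expected to return it as well.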