# An EGMS data script

A seamless way to reproduce EGMS L3 ORTHO products from L2a or L2b products.

### Imports

In [None]:
import logging
import math
import re
from datetime import datetime, timedelta
from pathlib import Path
from typing import Any

import dask.dataframe as dd
import geopandas as gpd
import numpy as np
import pandas as pd
from dask import delayed
from dask.distributed import Client, LocalCluster
from shapely.geometry import Polygon, box

from egms_io import ortho_writer, read_csv_or_zip
from format_dataframe import (
    FieldSpec,
    OutputNamingConvention,
    compute_pid,
    format_dataframe,
)
from metadata import generate_output_name
from model_estimation import model_estimation

### Parameters

Define the spatial/temporal sampling plan and AOI upfront so every later stage (gridding, interpolation, LS solving) can rely on consistent spacing, anchoring rules, acquisition window, and burst inventory.

In [None]:
# Burst files (*.zip) to process. The glob result is sorted so that the task
# order does not depend on the filesystem's directory-listing order, which
# keeps runs reproducible.
input_bursts: list[Path] = sorted(Path.cwd().glob("test/input/*.zip"))
# Area of interest as box(xmin, ymin, xmax, ymax). The coordinates must be in
# the CRS assigned to the output GeoDataFrames (EPSG:3035). Along each axis the
# AOI may span at most 65536 cells of `spacing`, which is the capacity of the
# 16-bit spatial fields of the packed cell ID.
aoi: Polygon = box(4500000.0, 1700000.0, 4600000.0, 1800000.0)
aoi_name: str = "E45N17"  # passed to generate_output_name for the products
start_date: datetime = datetime.fromisoformat("2020-01-03")
end_date: datetime = datetime.fromisoformat("2024-12-25")
spacing: list[int] = [100, 100, 6]  # meters in easting, northing, days
# Position of the representative point inside a cell, per axis: "start",
# "center" or "end". The grid-definition cell converts these in place into
# integer offsets.
anchor: list[str] = ["center", "center", "start"]
max_gap_days: int = 90  # maximum gap in days for interpolation
output_dir: Path = Path.cwd() / "output"

### Start distributed environment

Spin up a lightweight local Dask cluster. It can be replaced by a different cluster type (e.g. Slurm, PBS or Kubernetes) so that the subsequent rasterization and least-squares steps can scale beyond a single node while still respecting queue policies. The cell also prints the dashboard URL, which can be used to monitor the progress of the computation.

To follow the progress of the computation in JupyterLab, create a cluster through the Dask JupyterLab web interface and paste its scheduler address into `scheduler_address` in the cell below; otherwise, a new local cluster is created, detached from the web interface.

In [None]:
# Paste the address of an already-running scheduler (e.g. one created from the
# Dask JupyterLab interface) to reuse it; leave None to start a local cluster.
scheduler_address: str | None = None  # e.g. "tcp://127.0.0.1:8786"
if scheduler_address is not None:
    client = Client(scheduler_address)
else:
    # 20 worker processes on this machine.
    # NOTE(review): memory_limit=0 is treated by Dask as "no per-worker memory
    # limit" — confirm against the installed Dask version.
    cluster = LocalCluster(n_workers=20, memory_limit=0, processes=True)
    client = Client(cluster)
print(f"Dashboard available at: {client.dashboard_link}")
logger = logging.getLogger("egms")

### Data ingestion

Load each Sentinel-1 burst (CSV or zipped CSV) into memory via the shared `egms_io` helper, tagging rows with their source path so downstream diagnostics can trace anomalies back to the originating burst.

In [None]:
def data_ingestion(
    path: Path,
    *args: Any,
    logger: logging.Logger | None = None,
    **kwargs: Any,
) -> gpd.GeoDataFrame | tuple[gpd.GeoDataFrame, ...]:
    """Load one burst file and record which file every row came from.

    Positional and keyword arguments beyond ``path`` are forwarded unchanged
    to ``read_csv_or_zip``. That helper returns either a GeoDataFrame or a
    tuple whose first item is the GeoDataFrame. In both cases the data frame
    gets a ``path`` column holding the file stem, and the helper's result is
    returned as-is.
    """
    log = logger if logger else logging.getLogger("egms")
    result = read_csv_or_zip(path, *args, **kwargs)
    # Pick the data frame out of the result, whichever shape it came in.
    table = result if isinstance(result, gpd.GeoDataFrame) else result[0]
    log.info("Read %s with %d records", path, len(table))
    table["path"] = path.stem
    log.debug("Added 'path' column to data")
    return result

In [None]:
# One lazy ingestion task per burst file; read_xml=False is forwarded to
# read_csv_or_zip by data_ingestion.
read_burst = delayed(data_ingestion)
raw_dfs = [
    read_burst(burst_path, read_xml=False, logger=logger)
    for burst_path in input_bursts
]

### Time-interpolation procedure

Normalize each burst’s timeline by resampling its acquisitions onto a regular cadence of `spacing[2]` days, interpolating between the original acquisition dates. This yields dense, gap-free temporal stacks, so that cells share comparable observations when the LS systems are solved later.

In [None]:
def interpolate_time_columns(
    df: gpd.GeoDataFrame,
    *,
    start: datetime | None,
    end: datetime | None,
    spacing_days: int,
    max_gap_days: int | None = None,
    interpolate_kwargs: dict[str, Any] | None = None,
    logger: logging.Logger | None = None,
) -> gpd.GeoDataFrame:
    """Resample the acquisition-date columns of ``df`` onto a regular cadence.

    Columns whose name consists only of digits are read as ``YYYYMMDD``
    acquisition dates. Their values are coerced to numbers and interpolated
    in time onto the dates ``start, start + spacing_days, ...`` up to
    ``end``. All other columns are kept unchanged. ``df`` itself is not
    modified; a new frame is returned.

    Args:
        df: Per-point table with one column per acquisition date.
        start: First target date; defaults to the earliest acquisition.
        end: Last target date; defaults to the latest acquisition.
        spacing_days: Step of the target cadence, in days.
        max_gap_days: If given, converted into the ``limit`` (number of
            consecutive missing samples) of ``DataFrame.interpolate``.
            Samples still missing afterwards are forward/backward filled.
        interpolate_kwargs: Extra options overriding the defaults passed to
            ``DataFrame.interpolate``.
        logger: Logger to use; defaults to the ``"egms"`` logger.

    Returns:
        A copy of ``df`` in which the date columns are replaced by one column
        per target date, named ``YYYYMMDD``. If ``df`` has no date columns,
        the copy is returned unchanged.
    """
    log = logger or logging.getLogger("egms")
    log.info("Starting interpolation of time columns")
    frame = df.copy()
    time_columns = sorted(
        (col for col in frame.columns if re.fullmatch(r"\d+", str(col))),
        key=int,
    )
    log.debug("Identified time columns: %s", time_columns)
    if not time_columns:
        return frame
    numeric = frame[time_columns].apply(pd.to_numeric, errors="coerce")
    time_index = pd.to_datetime(time_columns, format="%Y%m%d")
    start_ts = pd.Timestamp(start or time_index.min())
    log.debug("Using start timestamp: %s", start_ts)
    end_ts = pd.Timestamp(end or time_index.max())
    log.debug("Using end timestamp: %s", end_ts)
    freq = pd.Timedelta(days=int(spacing_days))
    log.debug("Using frequency: %s", freq)
    target_index = pd.date_range(start=start_ts, end=end_ts, freq=freq)
    log.debug("Generated target index with %d timestamps", len(target_index))
    if target_index.empty:
        # end precedes start: fall back to a single sample at the start date.
        target_index = pd.DatetimeIndex([start_ts])
    # Rows are dates and columns are points, so interpolation runs along time.
    observed = pd.DataFrame(
        numeric.to_numpy().T,
        index=time_index,
        columns=frame.index,
    )
    # Acquisition dates generally do not fall exactly on the target grid.
    # Reindexing straight to target_index would discard every such sample and
    # leave nothing to interpolate from. Instead, interpolate over the union
    # of both sets of dates and keep only the target dates afterwards.
    temporal_matrix = observed.reindex(time_index.union(target_index))
    log.debug(
        "Reindexed temporal matrix to target index with shape %s",
        temporal_matrix.shape,
    )
    interp_opts: dict[str, Any] = {
        "axis": 0,
        "limit_direction": "both",
        "method": "time",
    }
    if max_gap_days is not None:
        # The limit counts consecutive missing rows. In the merged index these
        # are (approximately) target steps, hence the division by the spacing.
        interp_opts["limit"] = math.ceil(max_gap_days / float(spacing_days))
    if interpolate_kwargs:
        interp_opts.update(interpolate_kwargs)
    interpolated = (
        temporal_matrix.interpolate(**interp_opts)
        .loc[target_index]
        # Any samples still missing (e.g. gaps longer than the limit) are
        # step-filled so that the stacks stay gap-free.
        .ffill()
        .bfill()
    )
    log.debug(
        "Interpolated temporal matrix with shape %s",
        interpolated.shape,
    )
    target_columns = [stamp.strftime("%Y%m%d") for stamp in target_index]
    interpolated = interpolated.T.set_axis(target_columns, axis=1)
    frame = frame.drop(columns=time_columns, errors="ignore").join(
        interpolated,
        how="left",
    )
    log.info(
        "Completed interpolation of time columns with %d columns",
        len(target_columns),
    )
    return frame


# Lazily resample every ingested burst onto the common `spacing[2]`-day
# cadence covering [start_date, end_date].
interpolate_task = delayed(interpolate_time_columns)
dfs = [
    interpolate_task(
        burst_df,
        start=start_date,
        end=end_date,
        spacing_days=spacing[2],
        max_gap_days=max_gap_days,
        logger=logger,
    )
    for burst_df in raw_dfs
]

### Rasterization

Translate each observation into a spatio-temporal cell ID derived from the anchor-aware grid. The encoding packs x/y/time indices into a 64-bit integer so we can shuffle gigantic datasets through Dask while keeping the geometry reconstruction deterministic.

In [None]:
# Output grid definition.
# Turn the spacing into an integer array and replace each symbolic anchor
# ("start" / "center" / "end") with its integer offset inside a cell along
# that axis (metres for easting/northing, days for time).
spacing = np.array(spacing, dtype=int)
for axis in range(3):
    position = anchor[axis]
    step = spacing[axis]
    if position == "start":
        anchor[axis] = 0
    elif position == "center":
        anchor[axis] = int(step / 2)
    elif position == "end":
        anchor[axis] = int(step) - 1
    else:
        raise ValueError(
            f"Invalid {axis}-th anchor value: {position}"
        )
anchor = np.array(anchor, dtype=int)

# Grid origin: shift the AOI lower-left corner and the start date back by the
# anchor offset, so that the anchor point of cell 0 lands on the AOI corner
# and on start_date.
xmin = aoi.bounds[0] - anchor[0]
ymin = aoi.bounds[1] - anchor[1]
tmin = start_date - timedelta(days=anchor[2].item())

def date_to_days(date_str: str) -> int:
    """Map a ``YYYYMMDD`` acquisition date onto its time-cell index.

    The index is the number of whole ``spacing[2]``-day steps elapsed since
    the grid time origin ``tmin``.
    """
    elapsed = datetime.strptime(date_str, "%Y%m%d") - tmin
    step_days = spacing[2]
    return int(elapsed.days // step_days)

In [None]:
def compute_cell_id(df: gpd.GeoDataFrame) -> pd.DataFrame:
    """Flatten per-date columns into one record per (point, date) with a cell ID.

    Every column whose name consists only of digits is read as a ``YYYYMMDD``
    acquisition date. For each such column, one record per point is emitted.
    It carries the point's LOS components, height, GNSS velocity, the
    displacement of that date and the packed spatio-temporal cell ID.

    The ID is made up as follows:
    - bits 0-15: cell index in easting
    - bits 16-31: cell index in northing
    - bits 32-63: cell index in time (days since tmin divided by spacing in days)

    Points whose easting or northing index does not fit its 16-bit field
    (i.e. lies before ``xmin``/``ymin`` or 65536 or more cells past it) are
    dropped. Their index would otherwise spill into the neighbouring bit
    field and alias a different cell.

    Returns:
        DataFrame with float64 columns ``los_east``, ``los_up``,
        ``los_north``, ``height``, ``gnss_velocity``, ``displ`` and an int64
        ``cid`` column. It is empty (with that schema) when ``df`` has no
        date columns.
    """
    index_limit = 1 << 16  # capacity of each 16-bit spatial field
    # int64 explicitly: the platform `int` may be 32-bit, and the time index
    # is shifted into bits 32-63.
    cid_x = df.geometry.x.sub(xmin).floordiv(spacing[0]).astype("int64").to_numpy()
    cid_y = df.geometry.y.sub(ymin).floordiv(spacing[1]).astype("int64").to_numpy()
    in_grid = (
        (cid_x >= 0) & (cid_x < index_limit) & (cid_y >= 0) & (cid_y < index_limit)
    )
    points = df.loc[in_grid]
    spatial_cid = cid_x[in_grid] + (cid_y[in_grid] << 16)

    granules: list[pd.DataFrame] = []
    for col in points.columns:
        if not re.fullmatch(r"\d+", str(col)):
            continue
        time_cid = np.int64(date_to_days(str(col))) << 32
        granule = pd.DataFrame(
            {
                "los_east": points["los_east"],
                "los_up": points["los_up"],
                "los_north": points["los_north"],
                "height": points["height_ortho"],
                "gnss_velocity": points["gnss_velocity"],
                "displ": points[col],
                "cid": spatial_cid + time_cid,
            }
        )
        granules.append(granule)
    if not granules:
        # No date columns: pd.concat([]) would raise, so return an empty
        # frame with the schema that downstream Dask metadata expects.
        float_columns = (
            "los_east", "los_up", "los_north", "height", "gnss_velocity", "displ",
        )
        empty = {name: pd.Series(dtype="float64") for name in float_columns}
        empty["cid"] = pd.Series(dtype="int64")
        return pd.DataFrame(empty)
    return pd.concat(granules, ignore_index=True)

In [None]:
# Reconstruct the final GeoDataFrame, recovering the cell spatio-temporal coordinates
def cell_id_to_coords(cid: int) -> tuple[float, float, datetime]:
    """Decode a packed cell ID into the representative point (x, y, t) of its cell.

    This inverts the packing done by ``compute_cell_id``: bits 0-15 hold the
    easting index, bits 16-31 the northing index and bits 32-63 the time
    index. Each index is scaled by the grid spacing from the grid origin
    (``xmin``, ``ymin``, ``tmin``) and shifted by the per-axis ``anchor``
    offset.
    """
    x_idx = cid & 0xFFFF
    y_idx = (cid >> 16) & 0xFFFF
    # The time index occupies the upper 32 bits, so mask with 32 bits to
    # match the documented layout (a 16-bit mask would truncate it).
    t_idx = (cid >> 32) & 0xFFFFFFFF
    x = xmin + x_idx * spacing[0] + anchor[0]
    y = ymin + y_idx * spacing[1] + anchor[1]
    t = tmin + timedelta(days=int(t_idx * spacing[2] + anchor[2]))
    return (x, y, t)

In [None]:
# Lazily flatten every interpolated burst into per-(point, date) records
# tagged with their packed spatio-temporal cell ID.
dfs = list(map(delayed(compute_cell_id), dfs))

In [None]:
# Schema of the frames produced by compute_cell_id: six float columns followed
# by the int64 packed cell ID (column order matters for Dask metadata).
cell_meta = dict.fromkeys(
    ("los_east", "los_up", "los_north", "height", "gnss_velocity", "displ"),
    "float64",
)
cell_meta["cid"] = "int64"
df = dd.from_delayed(dfs, meta=cell_meta)

Repartition the DataFrame to have partitions of ~4 MB, so that each partition can fit in memory during LS computation.

In [None]:
# Rebalance the flattened records into partitions of roughly 4 MB each, so that
# every partition fits comfortably in worker memory during the LS step.
df = df.repartition(partition_size="4MB", force=True)

### LS computation

Group observations by cell ID and solve a per-cell 2×N least-squares problem (mixing ascending and descending looks) to recover the east and up displacements. Cells lacking coverage from both geometries are not fitted; their results are set to NaN to avoid unstable fits.

A similar 3×N least-squares problem is also solved to retrieve the GNSS velocity components (east, up and north) for each cell.

In [None]:
def _ls_solve(partition: pd.DataFrame) -> pd.Series:
    """Solve east/up displacement and GNSS velocity for a cell partition."""
    asc_mask = partition.los_east < 0
    # Exclusion conditions
    excl_cond = asc_mask.all() # all ascending, no descending
    excl_cond |= not asc_mask.any() # all descending, no ascending
    excl_cond |= partition.size < 3 # less than three points
    if excl_cond:
        displ_sol = np.full(2, np.nan, dtype=float)
        gnss_sol = np.full(3, np.nan, dtype=float)
    else:
        # Solve the problem for the displacement
        design = partition[["los_east", "los_up"]].to_numpy()
        observed = partition["displ"].to_numpy()
        try:
            displ_sol, *_ = np.linalg.lstsq(design, observed, rcond=None)
        except np.linalg.LinAlgError:
            displ_sol = np.full(2, np.nan, dtype=float)
        # Solve the problem for the GNSS velocity
        design = partition[["los_east", "los_up", "los_north"]].to_numpy()
        observed = partition["gnss_velocity"].to_numpy()
        try:
            gnss_sol, *_ = np.linalg.lstsq(design, observed, rcond=None)
        except np.linalg.LinAlgError:
            gnss_sol = np.full(3, np.nan, dtype=float)

    return pd.Series({
        "east_disp": displ_sol[0],
        "up_disp": displ_sol[1],
        "gnss_velocity_e": gnss_sol[0],
        "gnss_velocity_u": gnss_sol[1],
        "gnss_velocity_n": gnss_sol[2],
        "height": partition["height"].mean(),
    })


# One output row per cell ID; all solved quantities are float64.
ls_meta = dict.fromkeys(
    (
        "east_disp",
        "up_disp",
        "gnss_velocity_e",
        "gnss_velocity_u",
        "gnss_velocity_n",
        "height",
    ),
    "float64",
)
df = df.groupby("cid").apply(_ls_solve, meta=ls_meta)

In [None]:
# Execute the whole lazy pipeline (ingestion, interpolation, cell IDs, LS
# solve) and collect the per-cell results locally as a pandas DataFrame
# indexed by the packed cell ID.
computed_df = df.compute()

### Rebuild displacement cubes

Once the LS inversion finishes, decode the packed cell IDs back to x/y/time triples and pivot the results into GeoDataFrames so each acquisition date becomes a column aligned with its geometry.

In [None]:
# Decode each packed cell ID into the representative point (x, y, t) of its
# cell. cell_id_to_coords already applies the per-axis anchor offsets, so the
# decoded coordinates are used as-is. Adding the anchor again would move every
# point (and date) away from the cell its observations were binned into,
# because the grid origin was shifted back by exactly that offset.
cell_coords = pd.DataFrame(
    computed_df.index.map(cell_id_to_coords).tolist(),
    columns=["x", "y", "t"],
    index=computed_df.index,
)
data = cell_coords.join(computed_df, how="outer")
# Ensure a datetime dtype (also when there are no rows) before using .dt.
data["t"] = pd.to_datetime(data["t"])
data["t_str"] = data["t"].dt.strftime("%Y%m%d")

In [None]:
def build_component_gdf(data: pd.DataFrame, component: str) -> gpd.GeoDataFrame:
    """Build one row per grid point holding the time series of ``component``.

    Each acquisition date (``t_str``) becomes a column with the mean of
    ``component`` at that (x, y) position. GNSS velocities and height are
    taken from the earliest-time record of each point. The x/y coordinates
    become point geometries in EPSG:3035.
    """
    keys = ["x", "y"]
    per_date = data.pivot_table(
        index=keys,
        columns="t_str",
        values=component,
        aggfunc="mean",
    )
    per_date = per_date.rename_axis(None, axis=1).reset_index()

    static_columns = [
        *keys,
        "gnss_velocity_e",
        "gnss_velocity_u",
        "gnss_velocity_n",
        "height",
    ]
    earliest = data.sort_values("t").drop_duplicates(subset=keys, keep="first")
    table = per_date.merge(earliest[static_columns], on=keys, how="left")

    xs = table.pop("x")
    ys = table.pop("y")
    return gpd.GeoDataFrame(
        table, geometry=gpd.points_from_xy(xs, ys), crs="EPSG:3035"
    )


# Pivot the east and up LS solutions into per-point time-series tables.
east_gdf, up_gdf = (
    build_component_gdf(data, column) for column in ("east_disp", "up_disp")
)

### Output formatting

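Clip each east/up displacement stack to the AOI and apply `model_estimation` to derive the per-point displacement summaries. Then assign point IDs, format the columns, and persist each component through `ortho_writer` as a ZIP product, as required by the EGMS product specifications, plus a Parquet file for further analysis.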

In [None]:
# Field specifications of the output product. NOTE(review): `previous_name`
# appears to rename a pipeline column (e.g. "height") to its product name,
# and `precision` presumably sets the number of decimals — confirm in
# format_dataframe.
output_fields = [
    FieldSpec("pid"),
    FieldSpec("easting", precision=0),
    FieldSpec("northing", precision=0),
    FieldSpec("height_ortho", previous_name="height", precision=1),
    FieldSpec("rmse_ts", previous_name="rmse", precision=1),
    FieldSpec("mean_velocity", precision=1),
    FieldSpec("mean_velocity_std", precision=1),
    FieldSpec("acceleration", precision=2),
    FieldSpec("acceleration_std", precision=2),
    FieldSpec("seasonality", precision=1),
    FieldSpec("seasonality_std", precision=1),
    FieldSpec("gnss_velocity_n", precision=1),
    FieldSpec("gnss_velocity_e", precision=1),
    FieldSpec("gnss_velocity_u", precision=1),
]
output_format = OutputNamingConvention(
    field_list=output_fields,
    date_format="%Y%m%d",
    date_precision=1,
)

In [None]:
output_dir.mkdir(parents=True, exist_ok=True)

wave_length_mm = 55.465760  # Sentinel-1 radar wavelength, in millimetres
components = {"east": east_gdf, "up": up_gdf}
for suffix, component_gdf in components.items():
    # Clip to the AOI, run the model estimation, then attach point IDs.
    product = model_estimation(component_gdf.clip(aoi), wave_length_mm)
    product["pid"] = compute_pid(product)
    product = format_dataframe(product, output_format)
    output_name = generate_output_name(product, suffix, aoi_name, 1)
    ortho_writer(product, output_dir / f"{output_name}.zip")
    # Also write a Parquet copy for further analysis.
    product.to_parquet(output_dir / f"{output_name}.parquet")