# EMIT MF watershed parametrization

We also evaluate morphological post-processing (binary closing) applied to the resulting watershed masks.

Ticket: [#1414](https://git.orbio.earth/orbio/orbio/-/issues/1414)

In [None]:
import fsspec
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import xarray as xr
from lib.plume_masking import retrieval_mask_using_watershed_algo, sobel_marker_coordinates
from satellite_data_product.emit.masking import mask_retrieval_watershed

In [None]:
# Retrievals for the ground-truth sites, read from the AzureML workspace blob datastore.
mf_retrievals_uri = (
    "azureml://subscriptions/6e71ce37-b9fe-4c43-942b-cf0f7e78c8ab/resourcegroups/orbio-ml-rg/workspaces/"
    "orbio-ml-ml-workspace/datastores/workspaceblobstore/paths/data/emit/emit_mf_gt_retrievals.nc"
)
with fsspec.open(mf_retrievals_uri) as remote_file:
    # `.load()` pulls the data into memory while the remote file handle is still open;
    # the handle is closed as soon as this `with` block exits.
    mfda = xr.open_dataset(remote_file)["mf_retrievals"].load()

In [None]:
# Adaptation of `satellite_data_product.emit.masking.mask_retrieval_watershed` that allows us to have a
# fixed-value floor threshold rather than one based on quantiles
def mask_retrieval_watershed_alt(
    retrieval: np.ndarray,
    masked_distance: int = 1,
    watershed_floor_threshold: float = 0.02,
    marker_threshold: float = 0.997,
) -> np.ndarray:
    """
    Apply watershed masking to the retrieval using a fixed-value watershed floor.

    Unlike the quantile-based original, the floor here is a fixed value rather than a quantile of the
    retrieval, so this variant is NOT unit-agnostic: `watershed_floor_threshold` must be expressed in the
    same units as `retrieval` (mol/m2 and ppm retrievals need different thresholds). The default of 0.02
    was presumably tuned for EMIT retrievals in one specific unit — confirm which before reusing it.

    Parameters
    ----------
    retrieval:
        Retrieval image (presumably 2D — TODO confirm). Every non-positive value, including the EMIT
        -9999 no-data fill, is replaced with NaN before masking.
    masked_distance:
        Forwarded to `sobel_marker_coordinates`.
    watershed_floor_threshold:
        Fixed floor forwarded to `retrieval_mask_using_watershed_algo`, in the retrieval's units.
    marker_threshold:
        Forwarded to `sobel_marker_coordinates`.

    Returns
    -------
    The mask returned by `retrieval_mask_using_watershed_algo`.
    """
    # NOTE: EMIT retrievals use -9999 as no data. This masks *all* non-positive values (not only -9999)
    # so that watershed performs better; genuine zero/negative pixels are treated as missing too.
    retrieval = np.where(retrieval <= 0, np.nan, retrieval)

    marker_coords = sobel_marker_coordinates(
        retrieval, masked_distance=masked_distance, marker_threshold=marker_threshold
    )
    return retrieval_mask_using_watershed_algo(
        retrieval, marker_coords=marker_coords, watershed_floor_threshold=watershed_floor_threshold
    )

In [None]:
from skimage.morphology import binary_closing

In [None]:
# Site -> granule metadata, indexed by `dual_index` (the same key used to select from `mfda`).
granule_map_csv = "../emit-cv-mf-comparison/emit_gt_granule_map.csv"
site_granule_map = pd.read_csv(granule_map_csv, index_col="dual_index")

In [None]:
n_samples = 15
fig_scaling = 3
# Seed for `DataFrame.sample`: None draws a fresh random set each run; set an int to reproduce a run.
sample_seed = None

closing_footprint = np.ones((3, 3), dtype=bool)

# Thresholds under comparison. They are identical for every site, so define them once.
fixed_floor = 0.02
tuned_quantile = 0.385

# `DataFrame.sample` (without replacement) raises if asked for more rows than exist, so cap the request.
samples = site_granule_map.sample(n=min(n_samples, len(site_granule_map)), random_state=sample_seed)

for site_id, site_props in samples.iterrows():
    try:
        x = mfda.sel(dual_index=site_id).isel(band=0)
    except KeyError:
        print(f"Missing retrieval for {site_id}")
        continue

    y_current = mask_retrieval_watershed(x)
    y_fixed_threshold = mask_retrieval_watershed_alt(x, watershed_floor_threshold=fixed_floor)
    y_tuned_quantile_threshold = mask_retrieval_watershed(x, watershed_floor_quantile=tuned_quantile)

    plot_payload = [
        {"label": "Retrieval", "data": x, "plt_kwargs": {"cmap": "Reds", "vmin": 0}},
        # NOTE(review): assumes 0.2 is mask_retrieval_watershed's default floor quantile — confirm.
        {"label": "Current quantile (0.2)", "data": y_current},
        {"label": "Current w/ closing", "data": binary_closing(y_current, footprint=closing_footprint)},
        {"label": f"Tuned quantile ({tuned_quantile})", "data": y_tuned_quantile_threshold},
        {
            "label": "Tuned quantile w/ closing",
            "data": binary_closing(y_tuned_quantile_threshold, footprint=closing_footprint),
        },
        {"label": f"Fixed floor ({fixed_floor})", "data": y_fixed_threshold},
        {"label": "Fixed floor w/ closing", "data": binary_closing(y_fixed_threshold, footprint=closing_footprint)},
    ]

    # One panel per payload entry, so adding/removing a variant never desyncs the column count.
    n_cols = len(plot_payload)
    # Create the figure only after the retrieval lookup succeeded, so skipped sites leave no empty figure.
    fig, axes = plt.subplots(1, n_cols, figsize=(fig_scaling * n_cols, fig_scaling), squeeze=False)

    print(site_id)  # for easier copy/pasting
    for ax, d in zip(axes[0], plot_payload):
        ax.imshow(d["data"], **d.get("plt_kwargs", {}))
        ax.set_title(d["label"])
        ax.axis("off")

    # int() raises on NaN, so show a missing quantification as "unknown" instead of aborting the loop.
    quantification = site_props.quantification_kg_h
    estimate = "unknown" if pd.isna(quantification) else f"{int(quantification):,}"
    fig.suptitle(f"{site_id} (estimated {estimate} kg/hr)")
    # Render now so each figure appears directly beneath its printed site id in the notebook output,
    # instead of all figures being emitted after every print at the end of the cell.
    plt.show()