# Rationale for the fusion of independent Sentinel-1 and Sentinel-2 exclusion masks and inland water retrievals

A multi-sensor fusion pipeline that combines independent exclusion masks and independent water retrievals derived from Sentinel-1 (S1) SAR and Sentinel-2 (S2) optical data. Exclusion masks are merged using a decision-level rule. The final water extent is estimated through a Gaussian fusion model that exploits the mean and standard deviation of the sensor-specific retrievals, which have been previously filtered using the exclusion masks.
For further details please refer to the journal submission.

## Workflow Overview
1. Load the previously computed S1 and S2 predictions (in the form of mean and standard deviation).
2. Load or generate independent exclusion masks for the S1 and S2 acquisitions from which the predictions are derived.
3. Fuse the independent masked Gaussian distributions (i.e., the masked S1 and S2 predictions) from both sensors.

## 1. Setup and imports

In [None]:
import numpy as np
import xarray as xr
import rioxarray
import pystac_client
import odc.stac
from odc.geo.geobox import GeoBox
from datetime import datetime, timedelta

## 2. Configuration and file paths

In [None]:
# S1 prediction rasters (mean and standard deviation)
s1mu_result_path = "path/to/your/s1mufile.tif"
s1sigma_result_path = "path/to/your/s1sigmafile.tif"

# S2 prediction rasters (mean and standard deviation)
s2mu_result_path = "path/to/your/s2mufile.tif"
s2sigma_result_path = "path/to/your/s2sigmafile.tif"

# Target grid: pixel size expressed in CRS units and the EPSG code of the CRS.
# With EPSG:4326 the unit is degrees; 0.0001 deg corresponds to roughly 10 m.
dx = 0.0001
epsg = 4326
time_format = "%Y-%m-%d"

# S1 temporal query: "<acquisition day>/<following day>".
# Set start_date_s1 to the S1 acquisition date (same scene as the prediction).
# NOTE(review): depending on how the STAC client expands date-only strings,
# this window may also cover the whole following day -- confirm.
start_date_s1 = datetime(2023, 6, 20)
end_date_s1 = start_date_s1 + timedelta(days=1)
date_query_s1 = "/".join(
    day.strftime(time_format) for day in (start_date_s1, end_date_s1)
)

# S2 temporal query, built the same way.
# Set start_date_s2 to the S2 acquisition date (same scene as the prediction).
start_date_s2 = datetime(2023, 6, 20)
end_date_s2 = start_date_s2 + timedelta(days=1)
date_query_s2 = "/".join(
    day.strftime(time_format) for day in (start_date_s2, end_date_s2)
)

## 3. Utility functions

In [None]:
def is_valid_pixel_S1(data):
    """
    Flag the S1 pixels that the exclusion mask does not exclude.

    Parameters:
    -----------
    data : xarray.DataArray
        GFM exclusion mask

    Returns:
    --------
    xarray.DataArray : Boolean mask, True where the exclusion mask equals 0
    """
    # Any other value (including NaN) compares unequal to 0 and is flagged False.
    not_excluded = data == 0
    return not_excluded

In [None]:
def compute_ndvi_band(data):
    """
    Compute NDVI from Sentinel-2 NIR and Red bands.

    NDVI = (nir - red) / (nir + red + 1e-6); the small constant keeps the
    denominator away from zero when both bands are 0.

    Integer bands are converted to float64 before the arithmetic. Without the
    conversion, unsigned integer reflectances wrap around in ``nir - red``
    whenever red > nir and can overflow in ``nir + red``, giving meaningless
    NDVI values. Floating-point bands are used as they are.

    Parameters:
    -----------
    data : xarray.Dataset
        Dataset containing 'nir' and 'red' bands

    Returns:
    --------
    xarray.DataArray : NDVI values
    """
    nir, red = (
        band if np.issubdtype(band.dtype, np.floating) else band.astype("float64")
        for band in (data.nir, data.red)
    )
    return (nir - red) / (nir + red + 1e-6)

In [None]:
def is_valid_pixel_s2(data):
    """
    Flag usable S2 observations from the Scene Classification Layer (SCL).

    A pixel is kept when its SCL value lies strictly between 3 and 8
    (classes 4-7: vegetation, not vegetated, water, unclassified) or equals
    11 (snow/ice). Everything else is rejected, which covers topographic
    shadow, cloud shadow, cloud medium/high probability and thin cirrus.

    Parameters:
    -----------
    data : xarray.DataArray
        SCL classification data

    Returns:
    --------
    xarray.DataArray : Boolean mask where True indicates valid pixels
    """
    snow_or_ice = data == 11
    clear_surface = (data < 8) & (data > 3)
    return clear_surface | snow_or_ice

In [None]:
def fuse_masked_binary_masks(mask1, mask2):
    """
    Fuse two binary masks with NaN handling.

    Logic:
    - If both mask1 and mask2 are NaN → result is NaN
    - If either mask is 1 → result is 1
    - If one is 0 and the other is NaN → result is 0
    - If both are 0 → result is 0

    Parameters:
    -----------
    mask1, mask2 : xarray.DataArray
        Binary masks (0, 1, or NaN)

    Returns:
    --------
    xarray.DataArray : Fused mask
    """
    # NaN compares unequal to 1, so a missing value never sets the flag itself.
    either_flagged = (mask1 == 1) | (mask2 == 1)

    # np.isnan is applied directly to the DataArrays so the function does not
    # depend on the `xarray.ufuncs` namespace, which is not available in
    # every xarray release.
    both_missing = np.isnan(mask1) & np.isnan(mask2)

    fused = xr.where(either_flagged, 1, 0)
    # Only pixels missing in BOTH inputs become NaN in the output.
    return fused.where(~both_missing)

In [None]:
def fuse_gaussians(mu1, sigma1, mu2, sigma2):
    """
    Fuse two Gaussian distributions using optimal weighting.

    The fusion formula:
    - fused_mu = (var2 * mu1 + var1 * mu2) / (var1 + var2)
    - fused_var = (var1 * var2) / (var1 + var2)

    A sensor is valid at a pixel only when BOTH its mean and its standard
    deviation are finite in the inputs. Validity is evaluated before any
    sigma clean-up, so a NaN/inf sigma marks the pixel as invalid for that
    sensor instead of being turned into a near-zero (over-confident) sigma.
    Zero or negative sigma values are replaced by a small epsilon (1e-6).

    Handles cases where only one sensor is valid:
    - If both valid → use fusion formula
    - If only one valid → use that sensor's mu and sigma
    - If neither valid → output NaN

    Parameters:
    -----------
    mu1, sigma1 : xarray.DataArray
        Mean and std dev for first Gaussian (S1)
    mu2, sigma2 : xarray.DataArray
        Mean and std dev for second Gaussian (S2)

    Returns:
    --------
    fused_mu, fused_sigma : xarray.DataArray
        Fused mean and standard deviation
    """
    epsilon = 1e-6

    # Validity masks, computed on the raw inputs. Doing this after the sigma
    # clean-up would make the sigma check meaningless, because the clean-up
    # output is always finite.
    valid_1 = np.isfinite(mu1) & np.isfinite(sigma1)
    valid_2 = np.isfinite(mu2) & np.isfinite(sigma2)
    both_valid = valid_1 & valid_2
    only_1_valid = valid_1 & ~valid_2
    only_2_valid = valid_2 & ~valid_1

    # Replace zero or negative sigma values with small epsilon so that
    # var1 + var2 is never zero. NaN compares False and is left untouched;
    # such pixels are already excluded by the validity masks above.
    sigma1 = xr.where(sigma1 <= 0, epsilon, sigma1)
    sigma2 = xr.where(sigma2 <= 0, epsilon, sigma2)

    # Compute variances
    var1 = sigma1 ** 2
    var2 = sigma2 ** 2

    # Initialize output with NaN (the value kept where neither sensor is valid)
    fused_mu = xr.full_like(mu1, np.nan)
    fused_sigma = xr.full_like(sigma1, np.nan)

    # Fusion formula, evaluated everywhere and selected only where both valid
    fused_mu_both = (var2 * mu1 + var1 * mu2) / (var1 + var2)
    fused_var_both = (var1 * var2) / (var1 + var2)
    fused_sigma_both = np.sqrt(fused_var_both)

    # Apply conditions (the three masks are mutually exclusive)
    fused_mu = xr.where(both_valid, fused_mu_both, fused_mu)
    fused_sigma = xr.where(both_valid, fused_sigma_both, fused_sigma)

    fused_mu = xr.where(only_1_valid, mu1, fused_mu)
    fused_sigma = xr.where(only_1_valid, sigma1, fused_sigma)

    fused_mu = xr.where(only_2_valid, mu2, fused_mu)
    fused_sigma = xr.where(only_2_valid, sigma2, fused_sigma)

    return fused_mu, fused_sigma

## 4. Sentinel-1 data processing

### 4.1 Load S1 predictions

In [None]:
# Open the S1 mean and standard-deviation prediction rasters
s1_mu, s1_sigma = (
    rioxarray.open_rasterio(raster_path)
    for raster_path in (s1mu_result_path, s1sigma_result_path)
)

# Reproject to WGS84, round to 3 decimals and store as float32.
# NOTE(review): rounding sigma to 3 decimals can produce exact zeros, which
# fuse_gaussians later replaces with its epsilon -- confirm this is intended.
s1mu_result_latlon = (
    s1_mu.rio.reproject("EPSG:4326")
    .round(3)
    .astype("float32")
)
s1sigma_result_latlon = (
    s1_sigma.rio.reproject("EPSG:4326")
    .round(3)
    .astype("float32")
)

print(f"S1 μ shape: {s1mu_result_latlon.shape}")
print(f"S1 μ CRS: {s1mu_result_latlon.rio.crs}")

### 4.2 Load S1 Exclusion Mask

In [None]:
# Get spatial bounds from the reprojected S1 mean raster; they are used as
# the search bbox here and for the datacube geobox afterwards
bounds = s1mu_result_latlon.rio.bounds()

# Connect to EODC STAC catalog
client = pystac_client.Client.open("https://stac.eodc.eu/api/v1")

# Search for GFM items intersecting the S1 footprint in the S1 date window
items = client.search(
    collections=["GFM"],
    bbox=bounds,
    datetime=date_query_s1,
    limit=100,
).item_collection()

print(f"{len(items)} GFM scenes found")

# Fail early with an actionable message: without this guard an empty search
# result surfaces as a bare IndexError on items[0] below.
if len(items) == 0:
    raise RuntimeError(
        f"No GFM scenes found for bbox={bounds} and datetime={date_query_s1}; "
        "check the S1 acquisition date and the prediction footprint."
    )

print(f"Available assets: {items[0].assets.keys()}")

In [None]:
# Target grid for the datacube: S1 footprint, configured CRS and pixel size
geobox = GeoBox.from_bbox(bounds, crs=f"epsg:{epsg}", resolution=dx)

# Load the GFM exclusion mask onto that grid
dc = odc.stac.load(
    items,
    bands=["exclusion_mask"],
    chunks=dict(x=512, y=512),
    geobox=geobox,
    resampling="nearest",
    groupby="solar_day",
)

# Blank out the 255 values of the mask, then drop length-1 dimensions.
# NOTE(review): squeeze() only removes dimensions of length 1; if the search
# returned more than one solar day the time dimension stays -- confirm.
dc = dc.assign(
    exclusion_mask=lambda ds: ds["exclusion_mask"].where(ds["exclusion_mask"] != 255)
).squeeze()

print(f"Datacube shape: {dc.exclusion_mask.shape}")

### 4.3 Apply S1 Mask and register data

In [None]:
# S1 validity mask derived from the GFM exclusion mask
dc["valid_S1"] = is_valid_pixel_S1(dc["exclusion_mask"])

# Put the S1 mean and standard deviation on the datacube grid: match the
# grid of the validity mask, drop length-1 dimensions and rename the spatial
# dimensions to the datacube's latitude/longitude names
for s1_name, s1_prediction in (
    ("S1_mu", s1mu_result_latlon),
    ("S1_sigma", s1sigma_result_latlon),
):
    s1_on_grid = s1_prediction.rio.reproject_match(dc["valid_S1"])
    dc[s1_name] = s1_on_grid.squeeze(drop=True).rename(
        {"y": "latitude", "x": "longitude"}
    )

print("S1 data successfully registered to datacube")

## 5. Sentinel-2 Data Processing

### 5.1 Load S2 Predictions

In [None]:
# Open the S2 mean and standard-deviation prediction rasters
s2_mu, s2_sigma = (
    rioxarray.open_rasterio(raster_path)
    for raster_path in (s2mu_result_path, s2sigma_result_path)
)

# Reproject to WGS84, round to 3 decimals and store as float32.
# NOTE(review): rounding sigma to 3 decimals can produce exact zeros, which
# fuse_gaussians later replaces with its epsilon -- confirm this is intended.
s2mu_result_latlon = (
    s2_mu.rio.reproject("EPSG:4326")
    .round(3)
    .astype("float32")
)
s2sigma_result_latlon = (
    s2_sigma.rio.reproject("EPSG:4326")
    .round(3)
    .astype("float32")
)

print(f"S2 μ shape: {s2mu_result_latlon.shape}")
print(f"S2 μ CRS: {s2mu_result_latlon.rio.crs}")

### 5.2 Load S2 Scene Classification and Spectral Bands

In [None]:
# Footprint of the reprojected S2 mean raster (this replaces the S1 `bounds`)
bounds = s2mu_result_latlon.rio.bounds()

# Query the Earth Search catalog for Sentinel-2 L2A items that intersect
# the footprint within the S2 date window
s2_catalog = pystac_client.Client.open("https://earth-search.aws.element84.com/v1")
s2_search = s2_catalog.search(
    bbox=bounds,
    collections=["sentinel-2-l2a"],
    datetime=date_query_s2,
    limit=100,
)
items_s2 = s2_search.item_collection()

print(f"{len(items_s2)} S2 scenes found")

In [None]:
# Target grid for the S2 datacube: S2 footprint, configured CRS and pixel size
geobox_s2 = GeoBox.from_bbox(bounds, crs=f"epsg:{epsg}", resolution=dx)

# Per-band resampling: bilinear for the reflectance bands, nearest for the
# categorical scene classification
s2_resampling = {"red": "bilinear", "nir": "bilinear", "scl": "nearest"}

# Load the three bands onto the S2 grid
dc_s2 = odc.stac.load(
    items_s2,
    bands=list(s2_resampling),
    chunks=dict(x=512, y=512),
    groupby="solar_day",
    geobox=geobox_s2,
    resampling=s2_resampling,
)

# Collapse the time dimension: median for the reflectances, maximum class
# code for the scene classification
dc_s2 = xr.Dataset(
    {
        "red": dc_s2["red"].median(dim="time"),
        "nir": dc_s2["nir"].median(dim="time"),
        "scl": dc_s2["scl"].max(dim="time"),
    }
)

print(f"S2 datacube shape: {dc_s2.red.shape}")

### 5.3 Build up the S2 exclusion mask

In [None]:
# NDVI from the aggregated red / nir bands
dc_s2["ndvi"] = compute_ndvi_band(dc_s2)

# S2 validity mask: the SCL class must be valid AND NDVI must not exceed 0.6,
# i.e. highly vegetated pixels (NDVI > 0.6) are excluded
scl_ok = is_valid_pixel_s2(dc_s2["scl"])
ndvi_ok = dc_s2["ndvi"] <= 0.6
dc_s2["valid_S2"] = scl_ok & ndvi_ok

### 5.4 Register S2 Predictions

In [None]:
# Put the S2 mean and standard deviation on the S2 datacube grid: match the
# grid of the validity mask, drop length-1 dimensions and rename the spatial
# dimensions to the datacube's latitude/longitude names
for s2_name, s2_prediction in (
    ("S2_mu", s2mu_result_latlon),
    ("S2_sigma", s2sigma_result_latlon),
):
    s2_on_grid = s2_prediction.rio.reproject_match(dc_s2["valid_S2"])
    dc_s2[s2_name] = s2_on_grid.squeeze(drop=True).rename(
        {"y": "latitude", "x": "longitude"}
    )

print("S2 data successfully registered to datacube")

## 6. Data Fusion

### 6.1 Merge S1 and S2 Datacubes

In [None]:
# Merge the S2 and S1 datacubes (in that order) into a single dataset
datacubes = [dc_s2, dc]
combined_dc = xr.merge(datacubes)

print(f"Combined datacube variables: {list(combined_dc.data_vars)}")

### 6.3 Fuse Masked Gaussian Distributions

In [None]:
# Keep each sensor's prediction only where its own validity mask holds;
# everything else becomes NaN
s1_valid = combined_dc["valid_S1"]
s2_valid = combined_dc["valid_S2"]

masked_μ1, masked_σ1 = (
    combined_dc[name].where(s1_valid) for name in ("S1_mu", "S1_sigma")
)
masked_μ2, masked_σ2 = (
    combined_dc[name].where(s2_valid) for name in ("S2_mu", "S2_sigma")
)

# Gaussian fusion of the masked S1 / S2 predictions
masked_fused_mu, masked_fused_sigma = fuse_gaussians(
    masked_μ1,
    masked_σ1,
    masked_μ2,
    masked_σ2,
)

# Add the fused products to the combined datacube
combined_dc["masked_fused_mu"] = masked_fused_mu
combined_dc["masked_fused_sigma"] = masked_fused_sigma

print("Masked Gaussian fusion complete")

### 6.4 Fuse Gaussian Distributions (Unmasked)

In [None]:
# Same fusion on the original (unmasked) predictions, kept for eventual
# comparison with the masked result
μ1, σ1 = combined_dc["S1_mu"], combined_dc["S1_sigma"]
μ2, σ2 = combined_dc["S2_mu"], combined_dc["S2_sigma"]

fused_mu, fused_sigma = fuse_gaussians(
    μ1,
    σ1,
    μ2,
    σ2,
)

# Add the fused products to the combined datacube
combined_dc["fused_mu"] = fused_mu
combined_dc["fused_sigma"] = fused_sigma

print("Unmasked Gaussian fusion complete")

## 7. Summary and Next Steps

In [None]:
# Report the layout of the combined datacube: dimensions, then one line per
# variable with its shape
print("\n=== Final Combined Datacube ===")
print(f"Dimensions: {combined_dc.dims}")
print("\nVariables:")
for var_name, var_array in combined_dc.data_vars.items():
    print(f"  - {var_name}: {var_array.shape}")

# Recap of the products stored in the datacube, grouped by origin
print("\n=== Available Data Products ===")
print("S1 products: S1_mu, S1_sigma, valid_S1")
print("S2 products: S2_mu, S2_sigma, valid_S2, ndvi")
print("Fused products: masked_fused_mu, masked_fused_sigma, fused_mu, fused_sigma")

## 8. Optional: Export Results

In [None]:
# Example: Export fused masked results to GeoTIFF
# output_path = "/path/to/output/"
# combined_dc["masked_fused_mu"].rio.to_raster(f"{output_path}/fused_masked_mu.tif")
# combined_dc["masked_fused_sigma"].rio.to_raster(f"{output_path}/fused_masked_sigma.tif")