In [0]:
# -----------------------------------------------------------------------------
# Notebook: Distributed Raster Statistics with Spark RDD (Sentinel-2 Visual Assets)
# -----------------------------------------------------------------------------
# Purpose:
#   Scale out per-item raster descriptive statistics (min / max / mean / std)
#   for many Sentinel-2 Level-2A scenes stored in Microsoft Planetary Computer Pro.
#   Builds on earlier notebooks that:
#     - Ingest STAC Items to a Delta table (metadata catalog)
#     - Demonstrate single-scene inspection (rasterio Demo)
#
# Approach:
#   1. Authenticate via Managed Identity and obtain a collection SAS token.
#   2. Broadcast the SAS token to executors (avoid repeating retrieval).
#   3. Extract (id, visual asset href) pairs from the STAC Delta table.
#   4. Map an RDD function that opens each asset with rasterio, computes stats
#      for the first band only (src.read(1)), and returns a dictionary (with
#      graceful error capture).
#   5. Materialize results into a structured Spark DataFrame for easy filtering
#      (e.g., identify failed assets, derive quality metrics, join later).
#
# When To Use:
#   - Quick quality scan across many scenes before heavier processing (tiling,
#     machine learning feature extraction, reflectance normalization).
#   - Identifying problematic assets (corrupt, missing, permission issues).
#
# Why RDD Instead of a Pure DataFrame UDF?
#   - Simplicity for demonstration: direct Python logic & rasterio usage.
#   - Fine-grained error handling per record without wrapping UDF exceptions.
#   - For large-scale or performance-sensitive workflows you may prefer:
#       * MapPartitions to reduce overhead
#       * Vectorized UDFs only if data can be streamed in an efficient format
#       * Task batching patterns to limit open/close cycles
#
# Extension Ideas:
#   - Add percentile statistics or histogram summaries per asset.
#   - Persist results to Delta and schedule incremental updates (MERGE).
#   - Add retry logic for transient network/timeouts.
#   - Filter STAC Items (date, cloud cover) before mapping to reduce I/O.
#
# Prerequisites:
#   - Delta table `rt_demo.default.sentinel2` produced by ingestion notebook.
#   - Libraries: azure-identity, requests, rasterio, numpy.
# -----------------------------------------------------------------------------

from azure.identity import ManagedIdentityCredential
from pyspark.sql import Row
from pyspark.sql.types import StructType, StructField, StringType, FloatType
import matplotlib.pyplot as plt  # (Optional) could be used for aggregate viz later
import numpy as np
import rasterio
import requests

# Resource (audience) for GeoCatalog when acquiring AAD token.
MPCPRO_APP_ID = "https://geocatalog.spatio.azure.com"

# -----------------------------------------------------------------------------
# 1. Authenticate using Managed Identity
# -----------------------------------------------------------------------------
credential = ManagedIdentityCredential()
token = credential.get_token(MPCPRO_APP_ID)

# Environment endpoint & target collection (adjust for your deployment/region).
geocatalog_url = "{REPLACE-WITH-YOUR-GEOCATALOG-ENDPOINT}"
collection_id = "sentinel-2-l2a"

# -----------------------------------------------------------------------------
# 2. Request SAS token for collection (enables secure asset blob access)
# -----------------------------------------------------------------------------
headers = {"Authorization": f"Bearer {token.token}"}
params = {"api-version": "2025-04-30-preview"}
response = requests.get(
    f"{geocatalog_url}/sas/token/{collection_id}",
    headers=headers,
    params=params,
    # requests applies no timeout by default; without one a stalled endpoint
    # would block the driver (and the whole notebook run) indefinitely.
    timeout=30,
)
# Fail fast on HTTP errors rather than broadcasting an unusable token.
response.raise_for_status()
sas_token = response.json()["token"]  # Sensitive; avoid logging in production.

# Broadcast SAS token so each executor can append it when opening assets.
broadcast_sas_token = sc.broadcast(sas_token)

# -----------------------------------------------------------------------------
# 3. Select (id, visual asset href) pairs from Delta table
#    NOTE: The column access path uses nested map syntax for assets.visual.href.
#    You can filter here to reduce workload (e.g., cloud cover < 20).
# -----------------------------------------------------------------------------
# Project only the item id and the visual asset URL (aliased to `href`), then
# drop down to the RDD API so a plain Python function can be mapped over rows.
items_rdd = spark.table("rt_demo.default.sentinel2").selectExpr(
    "id",
    "assets.visual.href as href",
).rdd

# -----------------------------------------------------------------------------
# 4. RDD map function: open raster & compute statistics
#    - Executed on workers; must re-import libraries inside the function if the
#      environment isolates them per task process.
#    - Returns a dictionary convertible to a Row; includes error field when
#      exceptions occur (network, decoding, missing band, permissions).
# -----------------------------------------------------------------------------

def compute_stats_rdd(row) -> dict:
    """Open one item's visual asset and compute band-1 descriptive statistics.

    Intended to be mapped over an RDD of rows carrying ``id`` and ``href``.
    Failures are captured per record instead of being raised, so one bad
    asset does not abort the whole job.

    Args:
        row: Record exposing ``row["id"]`` (item identifier) and
            ``row["href"]`` (asset URL without SAS token).

    Returns:
        A dict with keys ``id``, ``min``, ``max``, ``mean``, ``std`` and
        ``error``. On success the four statistics are Python floats and
        ``error`` is ``None``. On failure the statistics are ``None`` and
        ``error`` holds the exception text with the SAS token redacted.
    """
    import rasterio  # Local import for worker context
    import numpy as np

    item_id = row["id"]
    href = row["href"]
    # Tolerate a token delivered with a leading "?" so the URL is not built
    # with a doubled separator.
    sas_query = broadcast_sas_token.value.lstrip("?")

    try:
        if not href:
            # Report a missing asset explicitly instead of letting rasterio
            # fail on a meaningless "None?..." path.
            raise ValueError("item has no visual asset href")

        # Append with "&" when the href already carries a query string.
        separator = "&" if "?" in href else "?"
        with rasterio.open(f"{href}{separator}{sas_query}") as src:
            band = src.read(1)  # First band only.
            return {
                "id": item_id,
                "min": float(np.min(band)),
                "max": float(np.max(band)),
                "mean": float(np.mean(band)),
                "std": float(np.std(band)),
                "error": None,
            }
    except Exception as e:
        # Deliberately broad: any per-asset failure becomes an error record
        # so downstream filtering can isolate failures.
        message = str(e)
        if sas_query:
            # Exception text may echo the opened URL; keep the SAS token out
            # of the results table, which is displayed and may be persisted.
            message = message.replace(sas_query, "<redacted>")
        return {
            "id": item_id,
            "min": None,
            "max": None,
            "mean": None,
            "std": None,
            "error": message,
        }

# Result schema: a string id, four nullable float statistics, and a nullable
# error message. Field order matches the keys returned by compute_stats_rdd.
schema = StructType(
    [StructField("id", StringType(), True)]
    + [
        StructField(stat_name, FloatType(), True)
        for stat_name in ("min", "max", "mean", "std")
    ]
    + [StructField("error", StringType(), True)]
)

# -----------------------------------------------------------------------------
# 5. Execute distributed computation & materialize DataFrame
# -----------------------------------------------------------------------------
# Map the statistics function over every (id, href) record, then wrap each
# result dict in a Row and build a DataFrame using the explicit schema.
results_rdd = items_rdd.map(compute_stats_rdd)
results_df = (
    results_rdd
    .map(lambda stats: Row(**stats))
    .toDF(schema=schema)
)

# Render the results with the Databricks native display. Other options:
#   - results_df.filter("error IS NOT NULL").display()
#   - results_df.describe(["min", "max", "mean", "std"]).show()
#   - results_df.write.format("delta").mode("overwrite").saveAsTable("rt_demo.default.sentinel2_stats")
display(results_df)

# Optional: histogram of per-scene means (uncomment if desired)
# aggregated = results_df.filter("error IS NULL").select("mean").toPandas()
# plt.hist(aggregated["mean"], bins=40, color="#4c72b0", edgecolor="white")
# plt.title("Distribution of Mean Pixel Values Across Scenes")
# plt.xlabel("Mean value")
# plt.ylabel("Frequency")
# plt.tight_layout()
# display()

print("Distributed raster statistics computation complete.")

id,min,max,mean,std,error
S2A_MSIL2A_20230816T105631_R094_T30TUK_20230816T171602,0.0,255.0,62.055912,91.45826,
S2A_MSIL2A_20230816T105631_R094_T30SVJ_20230816T171602,0.0,255.0,200.01283,56.26067,
S2A_MSIL2A_20230816T105631_R094_T30SWJ_20230816T171602,2.0,255.0,199.84663,56.615387,
S2A_MSIL2A_20230816T105631_R094_T30SUJ_20230816T171602,0.0,255.0,84.268295,86.929,
S2A_MSIL2A_20230815T030551_R075_T50TMK_20230815T082905,0.0,255.0,88.633896,68.91471,
