In [0]:
# 05_feature_water_bodies_v2.py
import json
from pyspark.sql import functions as F
from graphframes import GraphFrame
from trimble_geospatial_demo_utils.site_lock import acquire_site_lock, release_site_lock
from trimble_geospatial_demo_utils import send_notification

# ==================================================
# Unity Catalog context
# ==================================================
# Select the catalog first, then the schema, so the unqualified table names
# used later in this script resolve under main.demo.
for _context_statement in ("USE CATALOG main", "USE SCHEMA demo"):
    spark.sql(_context_statement)

# ==================================================
# Read Job Parameters
# ==================================================
# Declare every widget with an empty default; the values are read back below.
dbutils.widgets.text("siteId", "", "Site ID")
dbutils.widgets.text("tileSizeM", "", "Tile Size (meters)")
dbutils.widgets.text("cellSizeM", "", "Cell Size (meters)")
dbutils.widgets.text("waterCellThreshold", "", "Water Cell Threshold")
dbutils.widgets.text("uploadJobId", "", "Upload Job ID")
dbutils.widgets.text("notificationUrl", "", "Notification URL")
dbutils.widgets.text("dbxWebhookSecret", "", "DBX Webhook Secret")

SITE_ID = dbutils.widgets.get("siteId")
TILE_SIZE_M = dbutils.widgets.get("tileSizeM")
CELL_SIZE_M = dbutils.widgets.get("cellSizeM")
WATER_CELL_THRESHOLD = dbutils.widgets.get("waterCellThreshold")
# The three values below are not validated: they are only consulted when a
# failure notification is built later in the script.
UPLOAD_JOB_ID = dbutils.widgets.get("uploadJobId")
NOTIFICATION_URL = dbutils.widgets.get("notificationUrl")
DBX_WEBHOOK_SECRET = dbutils.widgets.get("dbxWebhookSecret")


def _parse_float_param(name, raw):
    """Convert the raw text of job parameter *name* to a finite float.

    Raises:
        ValueError: if *raw* is not a number, or is NaN / +-infinity. The
            message names the offending parameter so a misconfigured job is
            easy to diagnose.
    """
    try:
        value = float(raw)
    except ValueError:
        raise ValueError(
            f"Job parameter {name} must be a number, got {raw!r}"
        ) from None
    # float() accepts "nan" and "inf"; NaN is the only float unequal to itself.
    if value != value or value in (float("inf"), float("-inf")):
        raise ValueError(
            f"Job parameter {name} must be a finite number, got {raw!r}"
        )
    return value


# Validate required parameters (whitespace-only text counts as missing).
# The raw values themselves are left untouched.
for _param_name, _param_raw in (
    ("siteId", SITE_ID),
    ("tileSizeM", TILE_SIZE_M),
    ("cellSizeM", CELL_SIZE_M),
    ("waterCellThreshold", WATER_CELL_THRESHOLD),
):
    if not _param_raw or not _param_raw.strip():
        raise ValueError(f"Missing required job parameter: {_param_name}")

# Convert to appropriate types
TILE_SIZE_M = _parse_float_param("tileSizeM", TILE_SIZE_M)
CELL_SIZE_M = _parse_float_param("cellSizeM", CELL_SIZE_M)
WATER_CELL_THRESHOLD = _parse_float_param("waterCellThreshold", WATER_CELL_THRESHOLD)

print(f"üèóÔ∏è  Site ID: {SITE_ID}")
print(f"üìè Tile Size: {TILE_SIZE_M}m")
print(f"üìè Cell Size: {CELL_SIZE_M}m")
print(f"üíß Water Cell Threshold: {WATER_CELL_THRESHOLD}")

# ==================================================
# CONFIG
# ==================================================
CELLS_TABLE  = "surface_cells_v2"               # main.demo.surface_cells_v2
OUTPUT_TABLE = "features_water_bodies_v2"       # main.demo.features_water_bodies_v2
OUTPUT_PATH  = "abfss://processed@trimblegeospatialdemo.dfs.core.windows.net/features_water_bodies_v2"

# Neighbourhood type: False = 4-neighbour (recommended); True additionally
# links the four diagonal cells (8-neighbour) when edges are built.
USE_EIGHT_NEIGHBOUR = False

# Falls back to "manual-notebook" when the job run id conf key is not set.
JOB_RUN_ID = spark.conf.get("spark.databricks.job.runId", "manual-notebook")

# ==================================================
# Set checkpoint directory for GraphFrames
# ==================================================
# One checkpoint folder per site; the script removes it again when it finishes.
CHECKPOINT_PATH = f"abfss://processed@trimblegeospatialdemo.dfs.core.windows.net/checkpoints/graphframes/{SITE_ID}"
spark.sparkContext.setCheckpointDir(CHECKPOINT_PATH)

# ==================================================
# Derived constants
# ==================================================
# Guard before dividing: a zero cell size would raise ZeroDivisionError, and
# two negative sizes would yield a positive ratio that passes the check below.
# Written as "not (x > 0)" so that NaN is rejected as well.
if not (TILE_SIZE_M > 0) or not (CELL_SIZE_M > 0):
    raise ValueError(
        "TILE_SIZE_M and CELL_SIZE_M must both be > 0; "
        f"got TILE_SIZE_M={TILE_SIZE_M}, CELL_SIZE_M={CELL_SIZE_M}."
    )

# Number of cells along one tile edge, rounded to the nearest integer.
# NOTE(review): a tile size that is not a whole multiple of the cell size is
# silently rounded here — confirm this matches how the cells table was built.
cellsPerTile = int(round(TILE_SIZE_M / CELL_SIZE_M))
if cellsPerTile <= 0:
    raise ValueError("cellsPerTile computed <= 0. Check TILE_SIZE_M and CELL_SIZE_M.")

# ==================================================
# Acquire site lock
# ==================================================
# Acquired before the try block, so an exception raised here skips the
# release in the finally clause at the bottom of the script.
acquire_site_lock(
    spark=spark,
    site_id=SITE_ID,
    locked_by=JOB_RUN_ID,
    ttl_minutes=120
)

try:
    # SITE_ID is interpolated into SQL string literals further down (the
    # replaceWhere predicate and the verification queries). Reject characters
    # that could terminate or alter such a literal instead of building
    # malformed or injectable SQL.
    if "'" in SITE_ID or "\\" in SITE_ID:
        raise ValueError(
            f"siteId must not contain single quotes or backslashes: {SITE_ID!r}"
        )

    # ==================================================
    # 1) Read water cells from surface_cells_v2
    # ==================================================
    # Keep only this site's cells whose waterPointRatio reaches the threshold.
    df_cells = (
        spark.table(CELLS_TABLE)
             .filter(F.col("siteId") == SITE_ID)
             .filter(F.col("waterPointRatio") >= F.lit(WATER_CELL_THRESHOLD))
             .select(
                 "siteId",
                 "tileId",
                 "cellX", "cellY",
                 "cellSizeM",
                 "minZ", "meanZ", "maxZ",
                 "waterPointRatio"
             )
    )

    # Emptiness check through the DataFrame API (no conversion to an RDD).
    if df_cells.limit(1).count() == 0:
        raise RuntimeError(f"No water cells found for siteId={SITE_ID} with threshold={WATER_CELL_THRESHOLD}")

    # ==================================================
    # 2) Parse tileX/tileY from tileId = "tileX_tileY"
    # ==================================================
    df_cells = (
        df_cells
        .withColumn("tileX", F.split(F.col("tileId"), "_").getItem(0).cast("int"))
        .withColumn("tileY", F.split(F.col("tileId"), "_").getItem(1).cast("int"))
    )

    # ==================================================
    # 3) Compute global grid coordinates so cross-tile neighbors differ by 1
    # ==================================================
    # NOTE(review): assumes cellX/cellY are zero-based offsets within a tile of
    # cellsPerTile cells per side — confirm against the producer of the table.
    df_cells_global = (
        df_cells
        .withColumn("globalCellX", F.col("tileX") * F.lit(cellsPerTile) + F.col("cellX"))
        .withColumn("globalCellY", F.col("tileY") * F.lit(cellsPerTile) + F.col("cellY"))
        .withColumn("vertexId", F.concat_ws("_", F.col("globalCellX"), F.col("globalCellY")))
    )

    # Vertices must have column "id" for GraphFrames
    vertices = df_cells_global.select(F.col("vertexId").alias("id")).distinct()

    # ==================================================
    # 4) Build adjacency edges
    # ==================================================
    # Each cell emits the coordinates of its candidate neighbours; an equi-join
    # back onto the water-cell set keeps only the candidates that exist.
    v = df_cells_global.select("vertexId", "globalCellX", "globalCellY").alias("v")

    # Edge-sharing neighbours: right, left, up, down.
    neighbours = [
        F.struct((F.col("globalCellX") + 1).alias("x"), F.col("globalCellY").alias("y")),
        F.struct((F.col("globalCellX") - 1).alias("x"), F.col("globalCellY").alias("y")),
        F.struct(F.col("globalCellX").alias("x"), (F.col("globalCellY") + 1).alias("y")),
        F.struct(F.col("globalCellX").alias("x"), (F.col("globalCellY") - 1).alias("y")),
    ]

    # Optionally add the four diagonal neighbours.
    if USE_EIGHT_NEIGHBOUR:
        neighbours += [
            F.struct((F.col("globalCellX") + 1).alias("x"), (F.col("globalCellY") + 1).alias("y")),
            F.struct((F.col("globalCellX") + 1).alias("x"), (F.col("globalCellY") - 1).alias("y")),
            F.struct((F.col("globalCellX") - 1).alias("x"), (F.col("globalCellY") + 1).alias("y")),
            F.struct((F.col("globalCellX") - 1).alias("x"), (F.col("globalCellY") - 1).alias("y")),
        ]

    # One row per (source cell, candidate neighbour coordinate).
    nbr = (
        df_cells_global
        .select(
            F.col("vertexId").alias("src"),
            F.explode(F.array(*neighbours)).alias("nbr")
        )
        .select(
            "src",
            F.col("nbr.x").alias("nx"),
            F.col("nbr.y").alias("ny")
        )
    )

    edges = (
        nbr.join(v, (F.col("nx") == F.col("v.globalCellX")) & (F.col("ny") == F.col("v.globalCellY")), "inner")
           .select(F.col("src"), F.col("v.vertexId").alias("dst"))
           .distinct()
    )

    # ==================================================
    # 5) Connected components => waterBodyId
    # ==================================================
    g = GraphFrame(vertices, edges)
    components = g.connectedComponents()  # returns: id, component

    # Attach each cell's component label and expose it as waterBodyId.
    df_labeled = (
        df_cells_global
        .join(components, df_cells_global.vertexId == components.id, "inner")
        .withColumnRenamed("component", "waterBodyId")
    )

    # ==================================================
    # 6) Aggregate water body features
    # ==================================================
    # NOTE(review): areaM2 uses F.first("cellSizeM"), i.e. the cell size of an
    # arbitrary row in the group — presumably uniform per site; confirm.
    # meanZ is the unweighted average of the per-cell meanZ values.
    df_features = (
        df_labeled
        .groupBy("siteId", "waterBodyId")
        .agg(
            F.count("*").alias("cellCount"),
            (F.count("*") * F.first("cellSizeM") * F.first("cellSizeM")).alias("areaM2"),
            F.min("minZ").alias("minZ"),
            F.max("maxZ").alias("maxZ"),
            F.avg("meanZ").alias("meanZ"),
            F.min("globalCellX").alias("bboxMinCellX"),
            F.min("globalCellY").alias("bboxMinCellY"),
            F.max("globalCellX").alias("bboxMaxCellX"),
            F.max("globalCellY").alias("bboxMaxCellY")
        )
        .withColumn("computedAt", F.current_timestamp())
    )

    # ==================================================
    # 7) Write latest snapshot by site
    # ==================================================
    # Overwrite restricted by replaceWhere to this site's rows.
    (
        df_features.write
            .format("delta")
            .mode("overwrite")
            .option("replaceWhere", f"siteId = '{SITE_ID}'")
            .option("path", OUTPUT_PATH)
            .partitionBy("siteId")
            .saveAsTable(OUTPUT_TABLE)
    )

    # ==================================================
    # 8) Quick verification
    # ==================================================
    # Site-level totals.
    spark.sql(f"""
        SELECT
          COUNT(*) AS waterBodies,
          SUM(cellCount) AS totalCells,
          SUM(areaM2) AS totalAreaM2
        FROM {OUTPUT_TABLE}
        WHERE siteId = '{SITE_ID}'
    """).show(truncate=False)

    # Ten largest water bodies by area.
    spark.sql(f"""
        SELECT
          waterBodyId, cellCount, areaM2
        FROM {OUTPUT_TABLE}
        WHERE siteId = '{SITE_ID}'
        ORDER BY areaM2 DESC
        LIMIT 10
    """).show(truncate=False)

    print("‚úÖ features_water_bodies_v2 complete.")
except Exception as e:
    # Report the failure to the webhook when one is configured, then re-raise
    # so the run itself still fails.
    if NOTIFICATION_URL and DBX_WEBHOOK_SECRET:
        payload = {
            # Same run id that was used to take the site lock.
            "runId": JOB_RUN_ID,
            "jobId": UPLOAD_JOB_ID,
            "status": "FAILED",
            "error": str(e),
            "siteId": SITE_ID,
            "tileSizeM": str(TILE_SIZE_M),
            "cellSizeM": str(CELL_SIZE_M),
            "waterCellThreshold": str(WATER_CELL_THRESHOLD),
        }
        try:
            send_notification(json.dumps(payload), NOTIFICATION_URL, webhook_secret=DBX_WEBHOOK_SECRET)
        except Exception as notify_ex:
            # A failed notification must not mask the original error.
            print("‚ö†Ô∏è Notification failed:", str(notify_ex)[:200])
    raise
finally:
    # Clean up checkpoint directory (best effort: failures are logged, not
    # raised). The nested finally guarantees the lock release is attempted
    # even if the cleanup is interrupted.
    try:
        dbutils.fs.rm(CHECKPOINT_PATH, recurse=True)
    except Exception as cleanup_ex:
        print("Checkpoint cleanup failed:", str(cleanup_ex)[:200])
    finally:
        release_site_lock(
            spark=spark,
            site_id=SITE_ID,
            locked_by=JOB_RUN_ID
        )