In [0]:
# 04_surface_cells_v2.py
from pyspark.sql import functions as F
from utils.site_lock import acquire_site_lock, release_site_lock

# ==================================================
# Unity Catalog context
# ==================================================
# Set the session's current catalog and schema so that the unqualified table
# names used later in this script resolve inside main.demo.
for _uc_statement in ("USE CATALOG main", "USE SCHEMA demo"):
    spark.sql(_uc_statement)
del _uc_statement

# ==================================================
# CONFIG
# ==================================================
# Tunables are read from Spark conf keys (settable per run, e.g. as job
# parameters); the literal second argument of each spark.conf.get is the
# default used when the key is not set.
SITE_ID = spark.conf.get("pipeline.siteId", "wellington_cbd")

POINTS_TABLE   = "processed_points_tiled_v2"
TILE_STATS_TBL = "tile_stats_v2"

OUTPUT_TABLE = "surface_cells_v2"
OUTPUT_PATH  = "abfss://processed@trimblegeospatialdemo.dfs.core.windows.net/surface_cells_v2"

# Cell resolution (meters). Choose 0.25 / 0.5 / 1.0 depending on density + cost.
CELL_SIZE_M = float(spark.conf.get("pipeline.cellSizeM", "0.50"))
# Used as a divisor when computing cell indices; `not (x > 0)` also rejects NaN.
if not (CELL_SIZE_M > 0):
    raise ValueError(f"pipeline.cellSizeM must be > 0, got {CELL_SIZE_M}")

# Skip tiles whose waterPointRatio is >= this value (mostly-water tiles);
# shoreline / mixed tiles below the threshold are kept.
SKIP_WATER_RATIO_GE = float(spark.conf.get("pipeline.skipWaterTileRatio", "0.8"))

# Skip tiles with fewer points than this (saves cost + avoids noisy cells).
MIN_TILE_POINTS = int(spark.conf.get("pipeline.minTilePoints", "5000"))

# Point `classification` value counted as water in the per-cell statistics.
# Default 9 is "Water" in the ASPRS LAS point classification table.
# NOTE(review): should match the class used upstream to compute
# tile_stats_v2.waterPointRatio -- confirm.
WATER_CLASS = int(spark.conf.get("pipeline.waterClass", "9"))

# ==================================================
# Job identity (for locking / audit)
# ==================================================
# Falls back to "manual-notebook" when the conf key is absent
# (e.g. an interactive, non-job run).
JOB_RUN_ID = spark.conf.get("spark.databricks.job.runId", "manual-notebook")

# ==================================================
# Validate SITE_ID before it is spliced into SQL text
# ==================================================
# SITE_ID comes from Spark conf and is embedded verbatim in the Delta
# `replaceWhere` predicate of the write below. A single quote or backslash
# would change that predicate's meaning (at worst widening the overwrite to
# rows of other sites), so such values are rejected here -- before the lock
# is acquired, so there is nothing to release.
if "'" in SITE_ID or "\\" in SITE_ID:
    raise ValueError(f"siteId must not contain quotes or backslashes: {SITE_ID!r}")

# ==================================================
# Acquire site-level lock (latest snapshot semantics)
# ==================================================
# Called outside the try block, so the `finally` release below only runs once
# acquisition has returned normally.
# NOTE(review): ttl_minutes=90 presumably bounds how long a stale lock lives;
# a run longer than that may lose exclusivity -- confirm in utils.site_lock.
acquire_site_lock(
    spark=spark,
    site_id=SITE_ID,
    locked_by=JOB_RUN_ID,
    ttl_minutes=90
)

try:
    # ==================================================
    # 1) Select tiles to process (routing)
    #    - skip mostly-water tiles (waterPointRatio >= SKIP_WATER_RATIO_GE)
    #    - skip tiles with fewer than MIN_TILE_POINTS points
    # ==================================================
    df_tiles = (
        spark.table(TILE_STATS_TBL)
             .filter(F.col("siteId") == SITE_ID)
             .filter(F.col("waterPointRatio") < F.lit(SKIP_WATER_RATIO_GE))
             .filter(F.col("pointCount") >= F.lit(MIN_TILE_POINTS))
             .select("siteId", "tileId")
             .distinct()
    )

    # DataFrame.isEmpty() (rather than .rdd.isEmpty()) avoids an RDD
    # conversion; the RDD API is not available on Spark Connect /
    # shared-access Unity Catalog clusters.
    if df_tiles.isEmpty():
        raise RuntimeError(
            f"No tiles eligible for processing for siteId={SITE_ID}. "
            f"Check water ratio threshold / MIN_TILE_POINTS."
        )

    # Broadcast tile list (typically small)
    df_tiles_b = F.broadcast(df_tiles)

    # ==================================================
    # 2) Read points for selected tiles only
    # ==================================================
    df_points = (
        spark.table(POINTS_TABLE)
             .filter(F.col("siteId") == SITE_ID)
             .join(df_tiles_b, ["siteId", "tileId"], "inner")
             .select(
                 "siteId", "tileId",
                 "x", "y", "z",
                 "originX", "originY", "tileSizeM",
                 "tileX", "tileY",
                 "classification", "intensity"
             )
    )

    if df_points.isEmpty():
        raise RuntimeError(f"No point data found after routing join for siteId={SITE_ID}")

    # ==================================================
    # 3) Compute per-tile origin for cell indexing
    #
    # We want cell coords *within* each tile:
    #   localX = x - (originX + tileX * tileSizeM)
    #   localY = y - (originY + tileY * tileSizeM)
    #
    # Then:
    #   cellX = floor(localX / CELL_SIZE_M)
    #   cellY = floor(localY / CELL_SIZE_M)
    #
    # Indexing relative to the tile origin keeps the divided values small
    # (less floating-point error than using large absolute coordinates) and
    # makes cell indices start at 0 at each tile's origin.
    # ==================================================
    df_cells_keyed = (
        df_points
        .withColumn("tileOriginX", F.col("originX") + (F.col("tileX") * F.col("tileSizeM")))
        .withColumn("tileOriginY", F.col("originY") + (F.col("tileY") * F.col("tileSizeM")))
        .withColumn("localX", F.col("x") - F.col("tileOriginX"))
        .withColumn("localY", F.col("y") - F.col("tileOriginY"))
        .withColumn("cellSizeM", F.lit(float(CELL_SIZE_M)))
        .withColumn("cellX", F.floor(F.col("localX") / F.lit(CELL_SIZE_M)).cast("int"))
        .withColumn("cellY", F.floor(F.col("localY") / F.lit(CELL_SIZE_M)).cast("int"))
    )

    # ==================================================
    # 4) Aggregate into cells
    # ==================================================
    df_surface_cells = (
        df_cells_keyed
        .groupBy("siteId", "tileId", "cellX", "cellY")
        .agg(
            F.count("*").alias("pointCount"),

            # cell-level water count
            F.sum(
                F.when(F.col("classification") == F.lit(WATER_CLASS), 1)
                .otherwise(0)
            ).alias("waterPointCount"),

            F.min("z").alias("minZ"),
            F.avg("z").alias("meanZ"),
            F.max("z").alias("maxZ"),
            F.expr("percentile_approx(z, 0.50)").alias("z_p50")
        )
        # cell-level water ratio
        .withColumn(
            "waterPointRatio",
            F.when(
                F.col("pointCount") > 0,
                F.col("waterPointCount") / F.col("pointCount")
            ).otherwise(F.lit(0.0))
        )
        .withColumn("cellSizeM", F.lit(float(CELL_SIZE_M)))
        .withColumn("computedAt", F.current_timestamp())
    )

    # ==================================================
    # 5) Safety check: ensure only one siteId
    # ==================================================
    # NOTE: this is an action, so it executes the read/join/groupBy plan once
    # more before the write below executes it again.
    if df_surface_cells.select("siteId").distinct().count() != 1:
        raise RuntimeError("surface_cells output contains multiple siteId values")

    # ==================================================
    # 6) Write latest snapshot (replace entire site)
    # ==================================================
    # replaceWhere restricts the overwrite to this site's rows; SITE_ID was
    # validated above so it cannot break out of the quoted literal.
    (
        df_surface_cells.write
            .format("delta")
            .mode("overwrite")
            .option("replaceWhere", f"siteId = '{SITE_ID}'")
            .option("path", OUTPUT_PATH)
            .partitionBy("siteId", "tileId")
            .saveAsTable(OUTPUT_TABLE)
    )

    # ==================================================
    # 7) Verification (high-level)
    # ==================================================
    # DataFrame API rather than f-string SQL: SITE_ID is bound as a value,
    # not spliced into query text.
    print("\n=== Verify surface_cells_v2 ===")
    df_site_cells = spark.table(OUTPUT_TABLE).filter(F.col("siteId") == SITE_ID)

    (
        df_site_cells
        .groupBy("siteId")
        .agg(
            F.count("*").alias("cellRows"),
            F.sum("pointCount").alias("totalPointsUsed"),
            F.min("minZ").alias("siteMinZ"),
            F.max("maxZ").alias("siteMaxZ"),
            F.min("computedAt").alias("minComputedAt"),
            F.max("computedAt").alias("maxComputedAt"),
        )
        .show(truncate=False)
    )

    # Optional: how many tiles produced cells?
    (
        df_site_cells
        .groupBy("siteId")
        .agg(F.countDistinct("tileId").alias("tilesWithCells"))
        .show(truncate=False)
    )

    print("✅ surface_cells_v2 written successfully")

finally:
    # ==================================================
    # Release site-level lock
    # ==================================================
    # Runs on success and on any exception raised inside the try block.
    release_site_lock(
        spark=spark,
        site_id=SITE_ID,
        locked_by=JOB_RUN_ID
    )
