# In [0]:
# 02_spatial_tiling_v2.py
import json
from pyspark.sql import functions as F, Window
from datetime import datetime
from trimble_geospatial_demo_utils.site_lock import acquire_site_lock, release_site_lock
from trimble_geospatial_demo_utils import send_notification

# ==================================================
# Unity Catalog context
# ==================================================
# Unqualified table names used later resolve against catalog "main", schema "demo".
spark.sql("USE CATALOG main")
spark.sql("USE SCHEMA demo")

# ==================================================
# Read Job Parameters
# ==================================================
# (widget name, UI label) pairs. Every widget is registered with an empty
# default, so a parameter the job did not supply reads back as "".
_JOB_WIDGETS = (
    ("siteId", "Site ID"),
    ("ingestRunId", "Ingest Run ID"),
    ("uploadJobId", "Upload Job ID"),
    ("notificationUrl", "Notification URL"),
    ("dbxWebhookSecret", "DBX Webhook Secret"),
)
for _widget_name, _widget_label in _JOB_WIDGETS:
    dbutils.widgets.text(_widget_name, "", _widget_label)

# Read the values back in the same order the widgets were declared.
(
    SITE_ID,
    TARGET_INGEST_RUN_ID,
    UPLOAD_JOB_ID,
    NOTIFICATION_URL,
    DBX_WEBHOOK_SECRET,
) = (dbutils.widgets.get(_name) for _name, _ in _JOB_WIDGETS)

# Validate required parameters. Only siteId and ingestRunId are mandatory;
# the other three are allowed to stay empty.
for _required_name, _required_value in (
    ("siteId", SITE_ID),
    ("ingestRunId", TARGET_INGEST_RUN_ID),
):
    if not _required_value:
        raise ValueError(f"Missing required job parameter: {_required_name}")

print(f"üèóÔ∏è  Site ID: {SITE_ID}")
print(f"üîÑ Ingest Run ID: {TARGET_INGEST_RUN_ID}")

# ==================================================
# CONFIG SWITCHES
# ==================================================
# True  -> look up per-run tiling parameters in CONTROL_TABLE.
# False -> always use the code defaults declared below.
USE_CONTROL_TABLE = False
# True  -> a missing or unreadable control row falls back to the code defaults.
# False -> that situation is an error.
ALLOW_FALLBACK = True

# ==================================================
# TABLES & PATHS (UC External Table)
# ==================================================
# Source table, resolved as main.demo.points_raw.
RAW_TABLE = "points_raw"

# Output table and the external storage location its data is written to.
PROCESSED_TABLE_V2 = "processed_points_tiled_v2"
PROCESSED_PATH_V2 = "abfss://processed@trimblegeospatialdemo.dfs.core.windows.net/points_tiled_v2"

# Per-run parameter overrides, resolved as main.demo.control_tiling_params.
CONTROL_TABLE = "control_tiling_params"

# ==================================================
# CODE DEFAULTS (SAFE BASELINE)
# ==================================================
# Tile edge length in the units of the x/y columns (presumably metres - confirm).
DEFAULT_TILE_SIZE_M = 25.0
# A tile holding at least this many points is treated as "hot" when salting is on.
DEFAULT_HOT_TILE_THRESHOLD = 100_000
# Compared against the control row's maxTilePoints to decide whether to salt.
DEFAULT_TARGET_POINTS_PER_BUCKET = 600_000
# Modulus used to spread the points of a hot tile across tileSalt values.
DEFAULT_SALT_BUCKETS = 16
# Salting is off unless a control-table row turns it on.
DEFAULT_ENABLE_SALT = False

# ==================================================
# Get Job Run ID (for lock tracking)
# ==================================================
# Falls back to "manual-notebook" when the Spark conf key is not set.
JOB_RUN_ID = spark.conf.get("spark.databricks.job.runId", "manual-notebook")

print(f"Job Run ID: {JOB_RUN_ID}")

# ==================================================
# Acquire site lock before processing
# ==================================================
# The lock is requested under this run's ID with a 90-minute TTL. It is taken
# before the try/finally that follows, so a failed acquisition never reaches
# the release step.
print(f"\n=== Acquiring lock for site: {SITE_ID} ===")
acquire_site_lock(spark=spark, site_id=SITE_ID, locked_by=JOB_RUN_ID, ttl_minutes=90)
print(f"‚úÖ Lock acquired for site: {SITE_ID}")

try:
    # Everything in this block runs while this run holds the site lock:
    #   * `except` reports the failure to the notification endpoint (when one
    #     is configured) and re-raises, so the job run itself still fails;
    #   * `finally` releases the lock on success and on error.

    # ==================================================
    # 0) Guard values that get embedded in SQL text
    # ==================================================
    # SITE_ID is interpolated into the replaceWhere predicate and into the
    # verification queries below. A quote or backslash would change the
    # meaning of that SQL, so fail fast instead of building a broken predicate.
    if "'" in SITE_ID or "\\" in SITE_ID:
        raise ValueError(
            f"siteId {SITE_ID!r} contains a quote or backslash and cannot be "
            "embedded safely in a SQL predicate."
        )

    # ==================================================
    # 1) Load raw from Unity Catalog & use specified ingestRunId
    # ==================================================
    df_raw_site = (
        spark.table(RAW_TABLE)
             .filter(F.col("siteId") == SITE_ID)
    )

    INGEST_RUN_ID = TARGET_INGEST_RUN_ID

    df_raw = df_raw_site.filter(F.col("ingestRunId") == INGEST_RUN_ID)

    print("Using siteId =", SITE_ID)
    print("Using ingestRunId =", INGEST_RUN_ID)

    # ==================================================
    # 2) Resolve parameters (override -> control -> default)
    # ==================================================
    param_source = "code-defaults"

    TILE_SIZE_M = DEFAULT_TILE_SIZE_M
    HOT_TILE_THRESHOLD = DEFAULT_HOT_TILE_THRESHOLD
    TARGET_POINTS_PER_BUCKET = DEFAULT_TARGET_POINTS_PER_BUCKET
    SALT_BUCKETS = DEFAULT_SALT_BUCKETS
    ENABLE_SALT = DEFAULT_ENABLE_SALT

    if USE_CONTROL_TABLE:
        try:
            # Most recent control row for this site + ingest run, if any.
            params = (
                spark.table(CONTROL_TABLE)
                     .filter((F.col("siteId") == SITE_ID) & (F.col("ingestRunId") == INGEST_RUN_ID))
                     .orderBy(F.col("computedAt").desc())
                     .limit(1)
                     .collect()
            )

            if params:
                p = params[0]
                # Parse every field into a temporary first. If any conversion
                # fails (e.g. a NULL column) nothing has been overwritten yet,
                # so the fallback below really does run on the code defaults
                # instead of a half-applied mix of control and default values.
                ctl_tile_size_m = float(p["tileSizeM"])
                ctl_hot_tile_threshold = int(p["hotTileThreshold"])
                ctl_target_points = int(p["targetPointsPerBucket"])
                ctl_salt_buckets = int(p["saltBuckets"])
                ctl_enable_salt = int(p["maxTilePoints"]) >= ctl_target_points

                TILE_SIZE_M = ctl_tile_size_m
                HOT_TILE_THRESHOLD = ctl_hot_tile_threshold
                TARGET_POINTS_PER_BUCKET = ctl_target_points
                SALT_BUCKETS = ctl_salt_buckets
                ENABLE_SALT = ctl_enable_salt
                param_source = "control-table"
            else:
                if not ALLOW_FALLBACK:
                    raise ValueError("No control table entry found and fallback disabled.")
                print("‚ö†Ô∏è No control params found; falling back to code defaults.")

        except Exception as control_err:
            # Also catches the ValueError raised just above; it is re-raised
            # unchanged when fallback is disabled.
            if not ALLOW_FALLBACK:
                raise
            print("‚ö†Ô∏è Failed to read control table, falling back to code defaults.")
            print("Reason:", str(control_err)[:200])

    print("=== Spatial Tiling Parameters (V2) ===")
    print("Source:", param_source)
    print("TILE_SIZE_M =", TILE_SIZE_M)
    print("HOT_TILE_THRESHOLD =", HOT_TILE_THRESHOLD)
    print("TARGET_POINTS_PER_BUCKET =", TARGET_POINTS_PER_BUCKET)
    print("SALT_BUCKETS =", SALT_BUCKETS)
    print("ENABLE_SALT =", ENABLE_SALT)
    print("======================================")

    # Both values are used as divisors/moduli in column expressions below.
    # Reject values that cannot produce a usable tile index or salt, rather
    # than writing rows keyed on a meaningless tileId / tileSalt.
    if not TILE_SIZE_M > 0:
        raise ValueError(f"TILE_SIZE_M must be positive, got {TILE_SIZE_M}")
    if ENABLE_SALT and SALT_BUCKETS < 1:
        raise ValueError(f"SALT_BUCKETS must be >= 1 when salting is enabled, got {SALT_BUCKETS}")

    # ==================================================
    # 3) Compute origin (bbox min) & tile indices
    #    NOTE: origin derived from THIS RUN to ensure determinism within run.
    #    If you need cross-run stable tiling, store origin per site in a control table.
    # ==================================================
    origin_row = df_raw.agg(F.min("x").alias("minX"), F.min("y").alias("minY")).first()

    # min() over zero rows (or over all-NULL coordinates) is NULL. Stop here
    # with a clear message instead of failing inside float(None), and before
    # anything is written for this site.
    if origin_row is None or origin_row["minX"] is None or origin_row["minY"] is None:
        raise ValueError(
            f"No raw points with x/y coordinates found for siteId='{SITE_ID}', "
            f"ingestRunId='{INGEST_RUN_ID}'; nothing to tile."
        )
    originX, originY = float(origin_row["minX"]), float(origin_row["minY"])

    # tileX/tileY are the zero-based tile column/row relative to the origin;
    # tileId is "<tileX>_<tileY>".
    df_tiled = (
        df_raw
        .withColumn("originX", F.lit(originX))
        .withColumn("originY", F.lit(originY))
        .withColumn("tileSizeM", F.lit(float(TILE_SIZE_M)))
        .withColumn("tileX", F.floor((F.col("x") - F.lit(originX)) / F.lit(TILE_SIZE_M)).cast("int"))
        .withColumn("tileY", F.floor((F.col("y") - F.lit(originY)) / F.lit(TILE_SIZE_M)).cast("int"))
        .withColumn("tileId", F.concat_ws("_", F.col("tileX").cast("string"), F.col("tileY").cast("string")))
    )

    # ==================================================
    # 4) Salt only if enabled
    # ==================================================
    if ENABLE_SALT:
        tile_counts = (
            df_tiled.groupBy("tileId", "tileX", "tileY")
                    .agg(F.count("*").alias("pointCount"))
        )

        # Tiles at or above the threshold are "hot".
        hot_keys = (
            tile_counts.filter(F.col("pointCount") >= HOT_TILE_THRESHOLD)
                       .select("tileId")
                       .distinct()
        )

        hot_keys_b = F.broadcast(hot_keys)

        # Hot tiles: tileSalt is a non-negative hash of the coordinates
        # modulo SALT_BUCKETS.
        df_hot = (
            df_tiled.join(hot_keys_b, ["tileId"], "left_semi")
                    .withColumn("tileSalt", F.pmod(F.hash("x","y","z"), F.lit(SALT_BUCKETS)))
                    .withColumn("isHotTile", F.lit(1))
        )

        # Every other tile keeps the constant salt 0.
        df_non_hot = (
            df_tiled.join(hot_keys_b, ["tileId"], "left_anti")
                    .withColumn("tileSalt", F.lit(0))
                    .withColumn("isHotTile", F.lit(0))
        )

        df_processed = df_hot.unionByName(df_non_hot)
    else:
        # Same output columns as the salted branch, with constant values.
        df_processed = (
            df_tiled
            .withColumn("tileSalt", F.lit(0))
            .withColumn("isHotTile", F.lit(0))
        )

    # ==================================================
    # 5) Add snapshot metadata (traceability)
    #    Keep ingestRunId as snapshot marker + add snapshotAt timestamp
    # ==================================================
    df_processed = (
        df_processed
        .withColumn("snapshotAt", F.current_timestamp())  # When this snapshot was created
    )

    # The timestamps printed here and after the write come from datetime.now()
    # on the driver. They are log output only and are not the value stored in
    # the snapshotAt column.
    print("\n=== Snapshot Metadata ===")
    print(f"  - ingestRunId: {INGEST_RUN_ID} (source data run)")
    print(f"  - snapshotAt: {datetime.now().isoformat()} (processing timestamp)")
    print("  - Strategy: Latest snapshot only (previous snapshots overwritten)")

    # ==================================================
    # 6) Write V2: Only keep latest data per site
    #    Strategy: replaceWhere by siteId only (overwrite entire site)
    #    Keep ingestRunId + snapshotAt for traceability
    # ==================================================
    print(f"\n=== Writing V2 to Unity Catalog table: {PROCESSED_TABLE_V2} ===")

    (
        df_processed.write
            .format("delta")
            .mode("overwrite")
            .option("replaceWhere", f"siteId = '{SITE_ID}'")  # Only filter by siteId (replace entire site)
            .option("path", PROCESSED_PATH_V2)
            .partitionBy("siteId", "tileId")
            .saveAsTable(PROCESSED_TABLE_V2)
    )

    print("‚úÖ V2 Data written and table registered in Unity Catalog")
    print(f"   - Table: main.demo.{PROCESSED_TABLE_V2}")
    print(f"   - Location: {PROCESSED_PATH_V2}")
    print(f"   - Strategy: Only latest ingestRunId ({INGEST_RUN_ID}) kept for site '{SITE_ID}'")
    print(f"   - Snapshot timestamp: {datetime.now().isoformat()}")

    # ==================================================
    # 7) Verify V2 write (snapshot metadata)
    # ==================================================
    print("\n=== Verify V2 table (snapshot metadata) ===")
    spark.sql(f"""
    SELECT
      siteId,
      ingestRunId,
      MIN(snapshotAt) AS snapshotAt,
      COUNT(*) AS rows,
      COUNT(DISTINCT tileId) AS tiles
    FROM {PROCESSED_TABLE_V2}
    WHERE siteId = '{SITE_ID}'
    GROUP BY siteId, ingestRunId
    ORDER BY ingestRunId DESC
    """).show(truncate=False)

    print("\n=== Sample data with metadata ===")
    spark.sql(f"""
    SELECT siteId, ingestRunId, tileId, tileX, tileY, snapshotAt
    FROM {PROCESSED_TABLE_V2}
    WHERE siteId = '{SITE_ID}'
    LIMIT 5
    """).show(truncate=False)

    print("\n‚úÖ Complete (V2)!")
    print("\nSnapshot design benefits:")
    print("  ‚úÖ ingestRunId: Tracks source data run")
    print("  ‚úÖ snapshotAt: Tracks when this snapshot was created")
    print("  ‚úÖ Latest snapshot only: No historical data accumulation")
    print("  ‚úÖ API-friendly: Easy to return metadata with results")
    print("  ‚úÖ Troubleshooting: Can trace back to source data")

except Exception as e:
    # Best-effort failure notification; sent only when both the URL and the
    # webhook secret were supplied as job parameters.
    if NOTIFICATION_URL and DBX_WEBHOOK_SECRET:
        payload = {
            # Same run ID the lock was taken under; reusing it avoids a second
            # Spark conf lookup while already handling an error.
            "runId": JOB_RUN_ID,
            "jobId": UPLOAD_JOB_ID,
            "status": "FAILED",
            "error": str(e),
            "siteId": SITE_ID,
            "ingestRunId": TARGET_INGEST_RUN_ID,
        }
        try:
            send_notification(json.dumps(payload), NOTIFICATION_URL, webhook_secret=DBX_WEBHOOK_SECRET)
        except Exception as notify_ex:
            # A failed notification must not hide the original error.
            print("‚ö†Ô∏è Notification failed:", str(notify_ex)[:200])
    raise

finally:
    # ==================================================
    # Release site lock (always executed, even on error)
    # ==================================================
    print(f"\n=== Releasing lock for site: {SITE_ID} ===")
    release_site_lock(
        spark=spark,
        site_id=SITE_ID,
        locked_by=JOB_RUN_ID
    )
    print(f"‚úÖ Lock released for site: {SITE_ID}")