In [0]:
# Databricks notebook: silver_conform
# Path: /Workspace/Users/you/silver_conform
#
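# Widgets:
#  - ingestion_batch_id (optional): when blank or undefined, a batch is auto-selected from the file registry
#  - census_year (optional): when blank or undefined, all census years in the batch are processed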
#
# Output tables:
#  - census.silver.dim_region
#  - census.silver.dim_person
#  - census.silver.dim_person_history
#  - census.silver.dim_household
#  - census.silver.lineage
#  - census.silver.validation_reports_v1
#
# Behavior: idempotent, defensive, uses only DataFrame APIs and Delta MERGE/SQL.

from datetime import datetime
import json
import uuid
from pyspark.sql import functions as F
from pyspark.sql.window import Window
from pyspark.sql.types import StringType
from delta.tables import DeltaTable

# Use Spark's legacy (pre-3.0) datetime parser for string -> date/timestamp parsing.
spark.conf.set("spark.sql.legacy.timeParserPolicy", "LEGACY")

# ---------- CONFIG ----------
# Raw landing volume and the region lookup file stored inside it.
RAW_VOLUME_ROOT = "/Volumes/census/raw/raw_files"
REGION_LOOKUP_PATH = RAW_VOLUME_ROOT + "/region_lookup.parquet"

# Bronze inputs.
BRONZE_TABLE = "census.bronze.individuals_raw_v1"
FILE_REG_TABLE = "census.bronze.file_registry_v1"

# Silver outputs.
SILVER_REGION = "census.silver.dim_region"
SILVER_PERSON = "census.silver.dim_person"
SILVER_PERSON_HISTORY = "census.silver.dim_person_history"
SILVER_HOUSEHOLD = "census.silver.dim_household"
SILVER_LINEAGE = "census.silver.lineage"
SILVER_VALIDATION = "census.silver.validation_reports_v1"

# Enumeration source ranking: a lower number is the more preferred source.
SOURCE_PRIORITY_MAP = dict(
    AdminRegister=0,
    FieldEnumeration=1,
    SurveySample=2,
    MobileUpdate=3,
)
DEFAULT_SOURCE_PRIORITY = 99  # rank for sources not listed above (or missing)
LEVENSHTEIN_THRESHOLD = 4  # max edit distance accepted for fuzzy region-name matches

# ---------- helpers ----------
def now():
    """Return the current UTC time as a naive ``datetime``.

    The result stays naive (no tzinfo), exactly like the former
    ``datetime.utcnow()`` call, so ``strftime(...'Z')`` in the run id and
    ``isoformat()`` in the validation report keep their output format.
    ``datetime.utcnow()`` is deprecated since Python 3.12, hence the aware
    clock read with its tzinfo dropped.
    """
    from datetime import timezone

    return datetime.now(timezone.utc).replace(tzinfo=None)

def tidy_status_col(col):
    """Build a normalized status expression for column ``col``.

    Nulls become the empty string; the value is then trimmed and lower-cased.
    """
    raw_value = F.col(col)
    non_null_value = F.coalesce(raw_value, F.lit(""))
    return F.lower(F.trim(non_null_value))

def pick_ingestion_batch():
    """Return the ingestion batch id this run should conform.

    Resolution order:
      1. The ``ingestion_batch_id`` widget when it is defined and non-blank
         (surrounding whitespace is stripped so it matches the stored id).
      2. Otherwise the file registry: among non-null batches with at least one
         recorded ingestion attempt, prefer batches that have at least one
         ``succeeded`` file, then the most recently updated one.

    Calls ``dbutils.notebook.exit`` with a FAIL payload when the registry table
    is missing or contains no attempted batch.
    """
    try:
        widget_val = dbutils.widgets.get("ingestion_batch_id")
    except Exception:
        # Widget not defined for this run -> fall back to auto-selection.
        widget_val = None
    if widget_val and widget_val.strip():
        return widget_val.strip()

    if not spark.catalog.tableExists(FILE_REG_TABLE):
        dbutils.notebook.exit(json.dumps({"status":"FAIL","reason":"missing_file_registry","table":FILE_REG_TABLE}))

    fr = spark.table(FILE_REG_TABLE).withColumn("status_norm", tidy_status_col("ingestion_status"))
    batches = (
        fr.filter(F.col("ingestion_batch_id").isNotNull())
        .groupBy("ingestion_batch_id")
        .agg(
            F.sum(F.coalesce(F.col("ingestion_attempts"), F.lit(0))).alias("attempts"),
            F.sum(F.when(F.col("status_norm") == "succeeded", 1).otherwise(0)).alias("succeeded_count"),
            F.max("updated_at").alias("last_updated"),
        )
        .filter(F.col("attempts") > 0)
    )

    if batches.limit(1).count() == 0:
        dbutils.notebook.exit(json.dumps({"status":"FAIL","reason":"no_ingestion_attempts_found","message":"Run Bronze ingest first or provide ingestion_batch_id widget"}))

    # Ordering by succeeded_count first would keep re-selecting an older, larger
    # batch after a newer one lands. Rank "has any succeeded file" first, then
    # recency; succeeded_count only breaks exact recency ties.
    candidate = batches.orderBy(
        F.desc((F.col("succeeded_count") > 0).cast("int")),
        F.desc("last_updated"),
        F.desc("succeeded_count"),
    ).first()
    return candidate["ingestion_batch_id"]

# ---------- start ----------
ingestion_batch_id = pick_ingestion_batch()

# census_year is optional: an undefined widget or a blank value means "all years".
# A non-blank value that is not an integer is rejected rather than silently
# widening the run to every census year in the batch.
try:
    census_year_widget = dbutils.widgets.get("census_year")
except Exception:
    census_year_widget = None

census_year = None
if census_year_widget and census_year_widget.strip():
    try:
        census_year = int(census_year_widget.strip())
    except ValueError:
        dbutils.notebook.exit(json.dumps({"status":"FAIL","reason":"invalid_census_year","value":census_year_widget}))

start_ts = now()
run_id = f"silver_conform-{ingestion_batch_id}-{start_ts.strftime('%Y%m%dT%H%M%SZ')}"
print(f"silver_conform starting. batch={ingestion_batch_id}, census_year={census_year}, run_id={run_id}")

# ---------- Defensive canonical region load ----------
# Builds census.silver.dim_region from the region lookup parquet and accepts
# several source column-naming conventions. Any failure exits the notebook
# with a FAIL payload (reason "region_load_failed").


def _first_present(candidates, available):
    """Return the first name in ``candidates`` that is in ``available``, or None."""
    return next((c for c in candidates if c in available), None)


def _guess_coord_col(columns, tokens, exclude=None):
    """Fuzzy-pick a column whose name contains one of ``tokens`` as a whole word.

    ``columns`` is scanned in schema order so the pick is deterministic across
    runs (iteration order of a ``set`` of strings is not). Names are split on
    non-alphanumerics and camelCase boundaries and compared token by token.
    Plain substring matching is avoided because e.g. "population" contains "lat".
    ``exclude`` is never returned, which keeps latitude and longitude on
    different columns.
    """
    import re

    for c in columns:
        if c == exclude:
            continue
        spaced = re.sub(r"(?<=[a-z0-9])(?=[A-Z])", "_", c)
        if set(re.split(r"[^a-z0-9]+", spaced.lower())) & tokens:
            return c
    return None


try:
    region_src = spark.read.format("parquet").load(REGION_LOOKUP_PATH)
    src_order = region_src.columns  # schema order, used for deterministic fuzzy picks
    src_cols = set(src_order)

    # geoid: first available id-like column; a null geoid is the last resort (should not be the normal case)
    id_name = _first_present(["geoid", "geo_id", "region_id", "id"], src_cols)
    id_col_expr = F.col(id_name).cast("int") if id_name else F.lit(None).cast("int")

    # iso code / standard name: first available variant, else empty string
    iso_name = _first_present(["iso_admin_code", "iso_code"], src_cols)
    iso_expr = F.col(iso_name) if iso_name else F.lit("")

    region_name_col = _first_present(["region_name_standard", "region_name", "name"], src_cols)
    name_expr = F.col(region_name_col) if region_name_col else F.lit("")

    # parent geoid defaults to 0, urban/rural flag to empty string
    parent_expr = F.col("parent_geoid").cast("int") if "parent_geoid" in src_cols else F.lit(0).cast("int")
    urban_expr = F.col("urban_rural_flag") if "urban_rural_flag" in src_cols else F.lit("")

    # Centroid: an existing 'centroid' struct is used as-is (its latitude/longitude
    # fields are normalized after projection). Otherwise latitude and longitude are
    # resolved independently: known names first, then a token-based fuzzy guess.
    # A null struct is used when either side cannot be found.
    if "centroid" in src_cols:
        centroid_expr = F.col("centroid")
    else:
        lat_col = _first_present(
            ["latitude_center", "latitude", "lat", "centroid_latitude", "centroid_lat"], src_cols
        ) or _guess_coord_col(src_order, {"lat", "latitude"})
        lon_col = _first_present(
            ["longitude_center", "longitude", "lon", "lng", "centroid_longitude", "centroid_lon"], src_cols
        ) or _guess_coord_col(src_order, {"lon", "lng", "long", "longitude"}, exclude=lat_col)
        if lat_col and lon_col:
            centroid_expr = F.struct(
                F.col(lat_col).cast("double").alias("latitude"),
                F.col(lon_col).cast("double").alias("longitude"),
            )
        else:
            centroid_expr = F.struct(
                F.lit(None).cast("double").alias("latitude"),
                F.lit(None).cast("double").alias("longitude"),
            )

    # Project canonical schema using only expressions built from existing columns
    region_df = region_src.select(
        id_col_expr.alias("geoid"),
        iso_expr.alias("iso_admin_code"),
        name_expr.alias("region_name_standard"),
        parent_expr.alias("parent_geoid"),
        urban_expr.alias("urban_rural_flag"),
        centroid_expr.alias("centroid"),
    )

    # Normalize centroid to struct<latitude: double, longitude: double>
    region_df = region_df.withColumn(
        "centroid",
        F.struct(
            F.col("centroid.latitude").cast("double").alias("latitude"),
            F.col("centroid.longitude").cast("double").alias("longitude"),
        ),
    )

    # Final write: overwrite canonical dim_region
    region_df.write.format("delta").mode("overwrite").saveAsTable(SILVER_REGION)
    region_count = region_df.count()
    print(f"dim_region written ({region_count} rows).")
except Exception as e:
    dbutils.notebook.exit(json.dumps({"status":"FAIL","reason":"region_load_failed","error":str(e)}))

# ---------- 2) Read Bronze partition for batch ----------
# Missing Bronze table -> nothing to conform; report which table was expected.
if not spark.catalog.tableExists(BRONZE_TABLE):
    dbutils.notebook.exit(json.dumps({"status":"FAIL","reason":"missing_bronze_table","table":BRONZE_TABLE}))

bronze_df = spark.table(BRONZE_TABLE).filter(F.col("_ingestion_batch_id") == ingestion_batch_id)
# `is not None` (not truthiness) so any explicitly supplied census_year is applied.
if census_year is not None:
    bronze_df = bronze_df.filter(F.col("census_year") == int(census_year))

bronze_count = bronze_df.count()
if bronze_count == 0:
    dbutils.notebook.exit(json.dumps({"status":"FAIL","reason":"no_bronze_rows_for_batch","ingestion_batch_id":ingestion_batch_id,"census_year":census_year}))

print(f"bronze rows: {bronze_count}")

# ---------- 3) Prepare canonical projection from Bronze ----------
# Canonical column -> accepted Bronze column names (the first one present wins).
canonical_cols = {
    "person_id": ["person_id"],
    "national_id": ["national_id"],
    "household_id": ["household_id"],
    "geoid": ["geoid"],
    "region_name_reported": ["region_name_reported"],
    "census_year": ["census_year"],
    "date_of_birth": ["date_of_birth"],
    "age": ["age"],
    "sex": ["sex"],
    "education_level": ["education_level"],
    "literacy": ["literacy"],
    "employment_status": ["employment_status"],
    "employment_type": ["employment_type"],
    "industry_code": ["industry_code"],
    "annual_income_local": ["annual_income_local"],
    "record_confidence_score": ["record_confidence_score"],
    "last_updated": ["last_updated"],
    "_ingestion_row_hash": ["_ingestion_row_hash"],
    "_ingestion_source_file": ["_ingestion_source_file"],
    "_ingestion_batch_id": ["_ingestion_batch_id"],
    "enumeration_source": ["enumeration_source"],
    "_raw_payload_json": ["_raw_payload_json"]
}

# Casts applied when the Bronze column is present.
_PRESENT_CASTS = {"geoid": "int", "age": "int", "annual_income_local": "double"}
# Types for placeholders of columns ABSENT from Bronze. An untyped F.lit(None) is
# NullType ("void"), which cannot be reliably persisted to Delta and breaks the
# schema-aligned dim_person writes/MERGEs downstream. Each placeholder therefore
# gets a concrete type (string unless listed here).
_MISSING_TYPES = {
    "geoid": "int",
    "age": "int",
    "census_year": "int",
    "annual_income_local": "double",
    "record_confidence_score": "double",
}

bronze_cols = set(bronze_df.columns)
exprs = []
for canon, variants in canonical_cols.items():
    chosen = next((v for v in variants if v in bronze_cols), None)
    if chosen is None:
        exprs.append(F.lit(None).cast(_MISSING_TYPES.get(canon, "string")).alias(canon))
    elif canon in _PRESENT_CASTS:
        exprs.append(F.col(chosen).cast(_PRESENT_CASTS[canon]).alias(canon))
    else:
        exprs.append(F.col(chosen).alias(canon))

selected = bronze_df.select(*exprs)

# normalize sex: male/female/other (case- and space-insensitive) map to
# Male/Female/Other; any other value is trimmed and title-cased; null stays null.
_sex_raw = F.col("sex")
_sex_key = F.lower(F.trim(F.coalesce(_sex_raw, F.lit(""))))
selected = selected.withColumn(
    "sex",
    F.when(_sex_key == "male", F.lit("Male"))
     .when(_sex_key == "female", F.lit("Female"))
     .when(_sex_key == "other", F.lit("Other"))
     .otherwise(F.initcap(F.trim(_sex_raw))),
)

# ensure _ingestion_row_hash exists: fall back to a SHA-256 of identifying fields
selected = selected.withColumn("_ingestion_row_hash", F.coalesce(F.col("_ingestion_row_hash"),
    F.sha2(F.concat_ws("|", F.coalesce(F.col("person_id"),F.lit("")),
                         F.coalesce(F.col("household_id"),F.lit("")),
                         F.coalesce(F.col("date_of_birth").cast(StringType()),F.lit("")),
                         F.coalesce(F.col("last_updated").cast(StringType()),F.lit(""))
                        ), 256)))

# ---------- 4) Region reconciliation ----------
def _norm_region_name(c):
    """Normalize a region-name column for matching.

    Null becomes "", non-alphanumerics become spaces, runs of whitespace are
    collapsed, and the result is trimmed and lower-cased.
    """
    cleaned = F.regexp_replace(F.coalesce(c, F.lit("")), r"[^a-zA-Z0-9\s]", " ")
    return F.lower(F.trim(F.regexp_replace(cleaned, r"\s+", " ")))


# Exactly one geoid per normalized standard name. Without this, duplicate names
# (or several regions with a blank name) fan each Bronze row out into several
# rows in the joins below. Blank names are dropped so rows with no reported
# region never "match".
# NOTE(review): duplicates keep the smallest geoid, which is deterministic but
# arbitrary; revisit if the lookup holds genuinely ambiguous names.
regions = (
    spark.table(SILVER_REGION)
    .select("geoid", _norm_region_name(F.col("region_name_standard")).alias("region_std_norm"))
    .filter(F.col("geoid").isNotNull() & (F.col("region_std_norm") != ""))
    .groupBy("region_std_norm")
    .agg(F.min("geoid").alias("geoid"))
)
bronze_recon = selected.withColumn("region_norm", _norm_region_name(F.col("region_name_reported")))

# exact normalized match (at most one region per name after the dedupe above)
exact = bronze_recon.join(regions, bronze_recon.region_norm == regions.region_std_norm, how="left") \
    .select(bronze_recon["*"], regions["geoid"].alias("recon_geoid_exact"))

# Fuzzy fallback using levenshtein; the cross join is acceptable because the region
# table is small. Blank reported names are excluded because "" lies within
# LEVENSHTEIN_THRESHOLD of every short region name.
lev = bronze_recon.filter(F.col("region_norm") != "").alias("b").crossJoin(regions.alias("r")) \
    .select(
        F.col("b._ingestion_row_hash"),
        F.col("r.geoid").alias("candidate_geoid"),
        F.levenshtein(F.col("b.region_norm"), F.col("r.region_std_norm")).alias("lev"),
    ) \
    .filter(F.col("lev") <= LEVENSHTEIN_THRESHOLD)

# best candidate per row: smallest distance, then smallest geoid (deterministic tie-break)
w = Window.partitionBy("_ingestion_row_hash").orderBy(F.asc("lev"), F.asc("candidate_geoid"))
lev_best = lev.withColumn("rn", F.row_number().over(w)).filter(F.col("rn") == 1) \
    .select("_ingestion_row_hash", "candidate_geoid", "lev") \
    .withColumnRenamed("candidate_geoid", "recon_geoid_lev")

# Resolution precedence: Bronze geoid, then exact name match, then fuzzy match.
# region_recon_flag marks rows that remain unresolved.
recon = exact.join(lev_best, on="_ingestion_row_hash", how="left")
recon_resolved = recon.withColumn("geoid_resolved", F.coalesce(F.col("geoid"), F.col("recon_geoid_exact"), F.col("recon_geoid_lev"))) \
    .withColumn("region_recon_flag", F.when(F.col("geoid_resolved").isNull(), F.lit(True)).otherwise(F.lit(False))) \
    .drop("geoid").withColumnRenamed("geoid_resolved", "geoid")

# ---------- 5) Deterministic grouping & canonical selection ----------
# A national_id is usable as a match key only if it has non-blank content.
# Empty or whitespace-only ids would otherwise merge unrelated people into one
# group, or give many people the same canonical_person_id ("").
_has_national_id = F.col("national_id").isNotNull() & (F.trim(F.col("national_id")) != "")

# Source ranking derived from SOURCE_PRIORITY_MAP (lower = preferred).
# Unknown or missing enumeration_source gets DEFAULT_SOURCE_PRIORITY.
_source_priority = F.lit(DEFAULT_SOURCE_PRIORITY)
for _src, _prio in SOURCE_PRIORITY_MAP.items():
    _source_priority = F.when(F.col("enumeration_source") == _src, F.lit(_prio)).otherwise(_source_priority)

cand = (
    recon_resolved
    .withColumn("match_key_national", F.when(_has_national_id, F.col("national_id")).otherwise(F.lit(None)))
    .withColumn(
        "match_key_composite",
        F.concat_ws(
            "|",
            F.coalesce(F.col("household_id"), F.lit("")),
            F.coalesce(F.col("date_of_birth").cast(StringType()), F.lit("")),
            F.coalesce(F.col("geoid").cast(StringType()), F.lit("")),
        ),
    )
    .withColumn("source_priority", _source_priority)
    .withColumn("census_year", F.col("census_year").cast("int"))
)
# group by national id when usable, else by household|dob|geoid; always per census year
cand = cand.withColumn(
    "group_key",
    F.when(F.col("match_key_national").isNotNull(),
           F.concat_ws("|", F.col("match_key_national"), F.col("census_year").cast(StringType())))
     .otherwise(F.concat_ws("|", F.col("match_key_composite"), F.col("census_year").cast(StringType()))),
)

# Canonical record per group: best source, then highest confidence, then latest
# last_updated. _ingestion_row_hash is a final tie-breaker so the same input always
# yields the same canonical row (row_number is otherwise arbitrary on ties).
w2 = Window.partitionBy("group_key").orderBy(
    F.asc("source_priority"),
    F.desc(F.coalesce(F.col("record_confidence_score"), F.lit(0.0))),
    F.desc(F.coalesce(F.col("last_updated"), F.lit('1970-01-01'))),
    F.asc("_ingestion_row_hash"),
)
ranked = cand.withColumn("rank", F.row_number().over(w2))
canonical = ranked.filter(F.col("rank") == 1)

# Sorted arrays keep merged_from_bronze / contributing_files stable across runs.
grouped = cand.groupBy("group_key", "census_year").agg(
    F.array_sort(F.collect_list(F.col("_ingestion_row_hash"))).alias("merged_from_bronze"),
    F.array_sort(F.collect_set(F.col("_ingestion_source_file"))).alias("contributing_files"),
    F.min("source_priority").alias("min_source_priority"),
)

canonical_attrs = canonical.select("group_key","person_id","national_id","household_id","geoid","date_of_birth","age","sex","education_level","literacy","employment_status","employment_type","annual_income_local","record_confidence_score","last_updated","enumeration_source")

# canonical_person_id: the (non-blank) national id, else a SHA-256 of the group key.
silver_prep = grouped.join(canonical_attrs, on="group_key", how="left") \
    .withColumn("canonical_person_id", F.coalesce(F.when(_has_national_id, F.col("national_id")), F.sha2(F.col("group_key"), 256))) \
    .withColumn("person_surrogate_id", F.sha2(F.concat_ws("|", F.coalesce(F.col("canonical_person_id"), F.col("group_key")), F.lit(run_id)), 256)) \
    .withColumn("ingestion_batch_id", F.lit(ingestion_batch_id)) \
    .withColumn("is_current", F.lit(True)) \
    .withColumn("effective_from", F.current_timestamp()) \
    .withColumn("effective_to", F.lit(None).cast("timestamp")) \
    .withColumn("source_priority", F.col("min_source_priority"))

person_final = silver_prep.select(
    "person_surrogate_id","canonical_person_id","group_key","census_year","person_id","national_id","household_id","geoid","date_of_birth","age","sex","education_level","literacy","employment_status","employment_type","annual_income_local","record_confidence_score","last_updated","merged_from_bronze","contributing_files","source_priority","ingestion_batch_id","is_current","effective_from","effective_to"
)


print("DEBUG: person_final pre-upsert count:", person_final.count())
display(person_final.select("canonical_person_id","census_year","ingestion_batch_id").limit(10))

print("DEBUG: Incoming batch distinct canonical_person_id count:", person_final.select("canonical_person_id").distinct().count())
print("DEBUG: Sample incoming keys:", person_final.select("canonical_person_id", "census_year").limit(5).collect())

# ---------- 6) SCD2 upsert into dim_person with proper history ----------
def _sql_metric(result_df, metric):
    """Return ``metric`` (e.g. "num_updated_rows") from a Delta DML ``spark.sql`` result, or None.

    On runtimes that report DML metrics, ``spark.sql("MERGE ...")`` / ``DELETE``
    return a ONE-row DataFrame of counters, so ``.count()`` on that result is
    always 1, not the number of rows affected. Returns None when the result is
    empty or lacks the column (metric availability varies by Delta version).
    """
    try:
        row = result_df.first()
    except Exception:
        return None
    if row is None:
        return None
    return row.asDict().get(metric)


def _fmt_metric(value):
    """Render a possibly-missing DML metric for the debug log."""
    return "n/a" if value is None else value


def scd2_upsert_with_history(target_table, history_table, incoming_df):
    """
    SCD2 upsert of ``incoming_df`` into ``target_table``, archiving replaced versions in ``history_table``.

    Rows are keyed on (canonical_person_id, census_year). A new version is
    created ONLY when one of the tracked attributes differs (NULL-safe);
    a changed ingestion_batch_id alone does NOT trigger SCD2.

    Steps:
      1. Append the current target rows whose attributes changed to
         ``history_table`` (is_current=false, effective_to=now).
      2. Expire those rows in the target.
      3. Insert the incoming rows that no longer have a current match
         (changed or brand-new keys).
      4. Delete non-current rows, so ``target_table`` holds only current rows.
      5. For unchanged rows, refresh batch id / merged_from_bronze / contributing_files.

    NOTE: steps are separate Delta commits, not one transaction. If the
    notebook fails between step 1 and step 2, a re-run appends the same
    history versions again.

    Raises RuntimeError if the insert MERGE (step 3) fails; failures in steps
    1, 2 and 4 propagate. Step 5 failures are logged and ignored.
    """
    tmp_view = f"tmp_in_{uuid.uuid4().hex}"
    incoming_df.createOrReplaceTempView(tmp_view)
    try:
        # Attributes compared for detecting changes - ONLY these trigger SCD2
        attr_cols = ["age", "sex", "education_level", "literacy", "employment_status",
                     "employment_type", "annual_income_local", "household_id", "geoid",
                     "date_of_birth", "record_confidence_score"]

        # IS NOT DISTINCT FROM treats NULL = NULL as equal, so NULL -> NULL is "no change"
        diff_cond = " OR ".join(
            f"(NOT (t.{c} IS NOT DISTINCT FROM s.{c}))" for c in attr_cols
        ) or "false"

        print(f"DEBUG: Using diff condition (ONLY attribute changes): {diff_cond}")
        print("DEBUG: Batch ID alone will NOT trigger SCD2")

        # 1) Capture current rows whose attributes changed (for the history table)
        capture_expired_sql = f"""
        SELECT t.*
        FROM {target_table} t
        INNER JOIN {tmp_view} s
          ON t.canonical_person_id = s.canonical_person_id
         AND t.census_year = s.census_year
        WHERE t.is_current = true
          AND ({diff_cond})
        """
        expired_rows = spark.sql(capture_expired_sql)
        expired_capture_count = expired_rows.count()  # computed once; the target is unchanged until step 2

        if expired_capture_count > 0:
            print(f"DEBUG: Capturing {expired_capture_count} expired rows for history (attributes changed)")
            expired_for_history = expired_rows.withColumn("is_current", F.lit(False)) \
                                              .withColumn("effective_to", F.current_timestamp())
            expired_for_history.write.format("delta").mode("append").saveAsTable(history_table)
            print(f"DEBUG: Appended {expired_capture_count} rows to history table")
        else:
            print("DEBUG: No attribute changes detected - no rows to expire")

        # 2) Expire existing current rows ONLY if attributes changed
        expire_merge_sql = f"""
        MERGE INTO {target_table} t
        USING (SELECT * FROM {tmp_view}) s
        ON t.canonical_person_id = s.canonical_person_id
          AND t.census_year = s.census_year
          AND t.is_current = true
        WHEN MATCHED AND ({diff_cond})
          THEN UPDATE SET
            t.is_current = false,
            t.effective_to = current_timestamp()
        """
        try:
            expire_result = spark.sql(expire_merge_sql)
        except Exception as e:
            print(f"[ERROR] Expire MERGE failed: {e}")
            raise
        expired_count = _sql_metric(expire_result, "num_updated_rows")
        print(f"DEBUG: Expired {_fmt_metric(expired_count)} rows via MERGE (attribute changes)")

        # 3) Insert new versions. Rows expired in step 2 are no longer current, so
        #    they fall to WHEN NOT MATCHED together with brand-new keys.
        insert_merge_sql = f"""
        MERGE INTO {target_table} t
        USING (SELECT * FROM {tmp_view}) s
        ON t.canonical_person_id = s.canonical_person_id
          AND t.census_year = s.census_year
          AND t.is_current = true
        WHEN MATCHED AND ({diff_cond})
          THEN UPDATE SET
            t.person_surrogate_id = s.person_surrogate_id,
            t.person_id = s.person_id,
            t.national_id = s.national_id,
            t.household_id = s.household_id,
            t.geoid = s.geoid,
            t.date_of_birth = s.date_of_birth,
            t.age = s.age,
            t.sex = s.sex,
            t.education_level = s.education_level,
            t.literacy = s.literacy,
            t.employment_status = s.employment_status,
            t.employment_type = s.employment_type,
            t.annual_income_local = s.annual_income_local,
            t.record_confidence_score = s.record_confidence_score,
            t.last_updated = s.last_updated,
            t.merged_from_bronze = s.merged_from_bronze,
            t.contributing_files = s.contributing_files,
            t.source_priority = s.source_priority,
            t.ingestion_batch_id = s.ingestion_batch_id,
            t.is_current = true,
            t.effective_from = current_timestamp(),
            t.effective_to = NULL
        WHEN NOT MATCHED
          THEN INSERT *
        """
        try:
            insert_result = spark.sql(insert_merge_sql)
        except Exception as e:
            raise RuntimeError(f"MERGE upsert failed for {target_table}: {e}") from e
        inserted_count = _sql_metric(insert_result, "num_inserted_rows")
        print(f"DEBUG: Insert/Update MERGE completed (inserted={_fmt_metric(inserted_count)})")

        # 4) Remove expired rows from the target (they were archived in step 1)
        delete_sql = f"""
        DELETE FROM {target_table}
        WHERE is_current = false
        """
        try:
            delete_result = spark.sql(delete_sql)
        except Exception as e:
            print(f"[ERROR] DELETE failed: {e}")
            raise
        deleted_count = _sql_metric(delete_result, "num_affected_rows")
        print(f"DEBUG: Deleted {_fmt_metric(deleted_count)} expired rows from {target_table}")

        # 5) Same attributes arriving in a different batch: refresh batch metadata only
        update_batch_only_sql = f"""
        MERGE INTO {target_table} t
        USING (SELECT * FROM {tmp_view}) s
        ON t.canonical_person_id = s.canonical_person_id
          AND t.census_year = s.census_year
          AND t.is_current = true
        WHEN MATCHED AND (NOT ({diff_cond}))
          THEN UPDATE SET
            t.ingestion_batch_id = s.ingestion_batch_id,
            t.merged_from_bronze = s.merged_from_bronze,
            t.contributing_files = s.contributing_files
        """
        try:
            update_result = spark.sql(update_batch_only_sql)
            updated_count = _sql_metric(update_result, "num_updated_rows")
            print(f"DEBUG: Updated batch_id only for {_fmt_metric(updated_count)} rows (no attribute changes)")
        except Exception as e:
            print(f"[WARNING] Batch-only update failed (non-critical): {e}")
    finally:
        # Temp views live for the whole Spark session; drop ours so repeated runs don't accumulate them.
        spark.catalog.dropTempView(tmp_view)

        
# ---------- ENSURE SCHEMA + TARGET TABLES EXIST (defensive bootstrap) ----------
# Make sure the Unity Catalog schema exists (idempotent)
spark.sql("CREATE SCHEMA IF NOT EXISTS census.silver")

# If dim_person does not exist, create an empty Delta table with person_final's schema
if not spark.catalog.tableExists(SILVER_PERSON):
    empty_person_df = spark.createDataFrame([], schema=person_final.schema)
    empty_person_df.write.format("delta").mode("overwrite").partitionBy("geoid").saveAsTable(SILVER_PERSON)
    print("Bootstrapped empty table:", SILVER_PERSON)
else:
    print("Target table already exists:", SILVER_PERSON)

# dim_person_history uses the same schema: SCD2 archives dim_person rows into it verbatim
if not spark.catalog.tableExists(SILVER_PERSON_HISTORY):
    empty_history_df = spark.createDataFrame([], schema=person_final.schema)
    empty_history_df.write.format("delta").mode("overwrite").partitionBy("geoid").saveAsTable(SILVER_PERSON_HISTORY)
    print("Bootstrapped empty table:", SILVER_PERSON_HISTORY)

pre_total = spark.table(SILVER_PERSON).count()
pre_batch_cnt = spark.table(SILVER_PERSON).filter(F.col("ingestion_batch_id")==ingestion_batch_id).count()
# Snapshot the history size as well, so "new historical rows" below is a real delta
pre_history_total = spark.table(SILVER_PERSON_HISTORY).count()
print(f"DEBUG: dim_person pre-merge total={pre_total}, batch_count={pre_batch_cnt}")

# Call the SCD2 upsert function
scd2_upsert_with_history(SILVER_PERSON, SILVER_PERSON_HISTORY, person_final)

# Post-merge counts
post_total = spark.table(SILVER_PERSON).count()
post_batch_cnt = spark.table(SILVER_PERSON).filter(F.col("ingestion_batch_id")==ingestion_batch_id).count()
history_total = spark.table(SILVER_PERSON_HISTORY).count()
new_history_rows = history_total - pre_history_total
print(f"DEBUG: dim_person post-merge total={post_total}, batch_count={post_batch_cnt}")
print(f"DEBUG: dim_person_history total={history_total}")
print(f"DEBUG: Current vs expired ratio: {post_total} current, {new_history_rows} new historical rows")

# Verify that dim_person only has current rows (the upsert deletes non-current ones)
current_check = spark.table(SILVER_PERSON).filter(~F.col("is_current")).count()
if current_check > 0:
    print(f"WARNING: Found {current_check} non-current rows in dim_person. They should have been deleted.")

# ---------- 7) Household aggregates ----------
# Household metrics come from this batch's canonical person rows (person_final:
# one row per deduplicated person group, carrying the reconciled geoid).
# The earlier approach exploded merged_from_bronze and re-joined Bronze. That
# counted every contributing source record (a person reported by two sources
# counted twice) and grouped on Bronze's raw, unreconciled geoid.
# NOTE(review): geoid here is the int-cast reconciled value. Confirm it matches
# the type already stored in dim_household if an older run created that table.
household_agg = (
    person_final
    # Persons without a household_id do not form a household. A NULL key would
    # also never satisfy the MERGE condition and would be re-inserted every run.
    .filter(F.col("household_id").isNotNull())
    .groupBy("household_id", "geoid")
    .agg(
        F.count("*").alias("household_size"),
        F.expr("percentile_approx(annual_income_local, 0.5)").alias("median_household_income"),
        F.avg(F.when(F.col("literacy") == True, 1).otherwise(0)).alias("household_literacy_rate"),
    )
    .withColumn("ingestion_batch_id", F.lit(ingestion_batch_id))
)

# Upsert when the table exists; create it otherwise. A failing MERGE now raises:
# the former catch-all fell back to an overwrite that erased households from
# every earlier batch.
if spark.catalog.tableExists(SILVER_HOUSEHOLD):
    DeltaTable.forName(spark, SILVER_HOUSEHOLD).alias("t").merge(
        household_agg.alias("s"),
        # <=> (null-safe equality) so households with an unresolved (NULL) geoid
        # are updated in place instead of being inserted again each run.
        "t.household_id = s.household_id AND t.geoid <=> s.geoid",
    ).whenMatchedUpdate(set={
        "household_size": F.col("s.household_size"),
        "median_household_income": F.col("s.median_household_income"),
        "household_literacy_rate": F.col("s.household_literacy_rate"),
        "ingestion_batch_id": F.col("s.ingestion_batch_id"),
    }).whenNotMatchedInsertAll().execute()
else:
    household_agg.write.format("delta").mode("overwrite").partitionBy("geoid").saveAsTable(SILVER_HOUSEHOLD)

# ---------- 8) Lineage table ----------
lineage_df = person_final.select("canonical_person_id","person_surrogate_id","census_year","merged_from_bronze","contributing_files","ingestion_batch_id").dropDuplicates(["canonical_person_id","census_year","ingestion_batch_id"])

# Insert-only MERGE keyed on (person, year, batch): re-running a batch adds nothing.
# The table is created only when missing. A failing MERGE must surface rather than
# fall back to an overwrite, which would erase the lineage of all earlier batches.
if spark.catalog.tableExists(SILVER_LINEAGE):
    DeltaTable.forName(spark, SILVER_LINEAGE).alias("t").merge(
        lineage_df.alias("s"),
        "t.canonical_person_id = s.canonical_person_id AND t.census_year = s.census_year AND t.ingestion_batch_id = s.ingestion_batch_id"
    ).whenNotMatchedInsertAll().execute()
else:
    lineage_df.write.format("delta").mode("overwrite").partitionBy("ingestion_batch_id").saveAsTable(SILVER_LINEAGE)

# ---------- 9) Final summary & validation row ----------
end_ts = now()


def _batch_row_count(table_name):
    """Count rows in ``table_name`` stamped with this run's ingestion_batch_id (0 if the table is missing)."""
    if not spark.catalog.tableExists(table_name):
        return 0
    return spark.table(table_name).filter(F.col("ingestion_batch_id") == ingestion_batch_id).count()


person_count = _batch_row_count(SILVER_PERSON)
hh_count = _batch_row_count(SILVER_HOUSEHOLD)
lineage_count = _batch_row_count(SILVER_LINEAGE)

report = {
    "ingestion_batch_id": ingestion_batch_id,
    "run_id": run_id,
    "start_time": start_ts.isoformat(),
    "end_time": end_ts.isoformat(),
    "bronze_row_count": bronze_count,
    "person_count": person_count,
    "household_count": hh_count,
    "lineage_count": lineage_count,
    "status": "SUCCEEDED"
}

# report_time reuses end_ts (from the shared now() helper) so the column agrees
# with report_json["end_time"] instead of reading the clock a second time.
spark.createDataFrame(
    [(ingestion_batch_id, run_id, end_ts, json.dumps(report), "PASS")],
    schema="ingestion_batch_id string, run_id string, report_time timestamp, report_json string, status string",
).write.format("delta").mode("append").saveAsTable(SILVER_VALIDATION)

dbutils.notebook.exit(json.dumps({"status":"SUCCESS","report":report}))
