In [0]:
# Databricks notebook: bronze_ingest
# Notebook path: /Workspace/Users/you/bronze_ingest
# Expects widget/base parameter: ingestion_batch_id (string). If not provided, it will pick the oldest Pending batch.

from datetime import datetime
import json
import hashlib
from pyspark.sql import functions as F
from pyspark.sql.types import StringType, DoubleType
from delta.tables import DeltaTable

# Switch this Spark session to the legacy (pre-Spark-3.0) date/time parser.
# NOTE(review): presumably required because some source files carry timestamp
# strings that the default parser rejects (see F.to_timestamp on last_updated
# further down) - confirm before removing.
spark.conf.set("spark.sql.legacy.timeParserPolicy", "LEGACY")

# ---------- CONFIG ----------
# Every table this notebook touches lives under the same catalog.schema prefix.
_BRONZE_NAMESPACE = "census.bronze"

FILE_REG_TABLE = f"{_BRONZE_NAMESPACE}.file_registry_v1"          # one row per registered source file
BRONZE_TABLE = f"{_BRONZE_NAMESPACE}.individuals_raw_v1"          # append target for conformed rows
INGESTION_AUDIT = f"{_BRONZE_NAMESPACE}.ingestion_audit_v1"       # one row per notebook run
VALIDATION_REPORTS = f"{_BRONZE_NAMESPACE}.validation_reports_v1"

# light check; main validation notebook enforces authoritative tolerance
MANIFEST_TOLERANCE_PCT = 0.001

# ---------- Helpers ----------
def now():
    """Return the current UTC time as a naive ``datetime``.

    ``datetime.utcnow()`` is deprecated since Python 3.12, so the value is
    taken from the timezone-aware clock and the tzinfo is stripped again.
    Callers (``strftime``, ``F.lit``, ``spark.createDataFrame``) therefore
    keep receiving exactly the same kind of naive UTC value as before.
    """
    # Local import: the file-level import only brings in ``datetime`` itself.
    from datetime import timezone

    return datetime.now(timezone.utc).replace(tzinfo=None)

def tidy_status_col(col):
    """Build a Column expression with the normalised text of column *col*.

    NULL is replaced by the empty string, then the value is trimmed and
    lower-cased, so status comparisons can be made against lower-case literals.
    """
    without_nulls = F.coalesce(F.col(col), F.lit(""))
    trimmed = F.trim(without_nulls)
    return F.lower(trimmed)

def list_pending_batches():
    """Summarise the file registry per ingestion batch.

    Returns a list of dicts, one per ``ingestion_batch_id``, holding
    ``pending_count``, ``processing_count``, ``succeeded_count``,
    ``failed_count`` and ``total_files``. The list is ordered by
    ``pending_count`` descending, then ``ingestion_batch_id`` ascending.
    """
    registry = spark.table(FILE_REG_TABLE).withColumn(
        "status_norm", tidy_status_col("ingestion_status")
    )

    def _count_with_status(status):
        # F.when without otherwise yields NULL for non-matching rows, and
        # F.count skips NULLs, so this counts only rows in the given status.
        matches = F.when(F.col("status_norm") == status, True)
        return F.count(matches).alias(f"{status}_count")

    per_batch = registry.groupBy("ingestion_batch_id").agg(
        *[_count_with_status(s) for s in ("pending", "processing", "succeeded", "failed")],
        F.count(F.lit(1)).alias("total_files"),
    )
    ordered = per_batch.orderBy(F.desc("pending_count"), F.asc("ingestion_batch_id"))
    return [record.asDict() for record in ordered.collect()]

# ---------- Determine ingestion_batch_id and pending rows ----------
# Read the optional ``ingestion_batch_id`` widget / base parameter.
# The result is None when no usable value was supplied, which later triggers
# auto-selection of the oldest pending batch.
try:
    _raw_widget_value = dbutils.widgets.get("ingestion_batch_id")
    # Trim surrounding whitespace: a blank value already meant "not supplied",
    # but a padded value such as " batch1 " used to be kept verbatim and then
    # never matched any registry row.
    ingestion_batch_id_widget = (_raw_widget_value or "").strip() or None
except Exception:
    # Deliberate best-effort: any failure to read the widget (e.g. it is not
    # defined for this run) is treated as "no batch id supplied".
    ingestion_batch_id_widget = None

# Registry snapshot with a normalised status column, plus a per-batch summary
# that is attached to every early-exit message for diagnostics.
file_registry_df = spark.table(FILE_REG_TABLE).withColumn("status_norm", tidy_status_col("ingestion_status"))
batches_summary = list_pending_batches()

if ingestion_batch_id_widget:
    # Explicit batch requested: it must have at least one pending file.
    pending_df = file_registry_df.filter(
        (F.col("ingestion_batch_id") == ingestion_batch_id_widget) & (F.col("status_norm") == "pending")
    )
    pending = pending_df.collect()
    if not pending:
        msg = {
            "status": "NO_PENDING_FOR_BATCH",
            "requested_batch": ingestion_batch_id_widget,
            "available_batches_summary": batches_summary,
            "message": f"No registry rows with ingestion_status='Pending' for ingestion_batch_id={ingestion_batch_id_widget}"
        }
        print(json.dumps(msg, indent=2))
        dbutils.notebook.exit(json.dumps(msg))
    chosen_batch = ingestion_batch_id_widget
else:
    # Auto-select: the batch that owns the oldest pending file.
    # Rows with a NULL ingestion_batch_id are excluded because the equality
    # filter used afterwards can never match NULL; selecting such a row would
    # make every auto-selected run exit with NO_PENDING_POST_SELECT.
    pending_batches = (
        file_registry_df
        .filter((F.col("status_norm") == "pending") & F.col("ingestion_batch_id").isNotNull())
        .select("ingestion_batch_id", "created_at")
        .distinct()
    )
    # first() returns None on an empty frame, so no separate count job is
    # needed. The batch id is a tie-break that keeps the choice deterministic
    # when several batches share the same created_at.
    chosen_row = pending_batches.orderBy(F.asc("created_at"), F.asc("ingestion_batch_id")).first()
    if chosen_row is None:
        msg = {
            "status":"NO_PENDING",
            "message":"No pending files found for ingestion (no ingestion_batch_id was supplied and no registry rows with ingestion_status='Pending').",
            "available_batches_summary": batches_summary
        }
        print(json.dumps(msg, indent=2))
        dbutils.notebook.exit(json.dumps(msg))
    chosen_batch = chosen_row["ingestion_batch_id"]
    print(f"Auto-selected ingestion_batch_id = {chosen_batch} (oldest batch with Pending files)")

# Load the pending registry rows of the chosen batch onto the driver.
_in_chosen_batch = F.col("ingestion_batch_id") == chosen_batch
_still_pending = F.col("status_norm") == "pending"
pending_df = file_registry_df.filter(_in_chosen_batch & _still_pending)
pending = pending_df.collect()

# The batch was selected because it had pending rows; an empty result here
# means the registry changed in between, so report it and stop.
if len(pending) == 0:
    race_report = {
        "status":"NO_PENDING_POST_SELECT",
        "chosen_batch": chosen_batch,
        "available_batches_summary": batches_summary,
        "message": "After selecting batch, no pending rows found. This indicates a race or state change; please re-run registration or check file_registry."
    }
    print(json.dumps(race_report, indent=2))
    dbutils.notebook.exit(json.dumps(race_report))

# Short run banner: batch id, number of files, and up to ten file names.
sample_filenames = [entry["filename"] for entry in pending[:10]]
print("Selected batch:", chosen_batch)
print("Pending files count:", len(pending))
print("Pending files sample:", sample_filenames)

# ---------- Start the run and claim the pending files ----------
ingestion_batch_id = chosen_batch
start_ts = now()
# The run id embeds the start time, so it always agrees with start_ts.
run_id = f"bronze-{ingestion_batch_id}-{start_ts.strftime('%Y%m%dT%H%M%SZ')}"

file_registry = DeltaTable.forName(spark, FILE_REG_TABLE)

# Flip all pending files of this batch to "Processing" in ONE Delta update.
# The condition is built from Column expressions rather than an f-string SQL
# predicate, so file names containing quotes cannot break or alter it. It is
# also restricted to the chosen batch, so a file name reused by another batch
# is left untouched.
pending_filenames = [r["filename"] for r in pending]
file_registry.update(
    condition=(
        (F.col("ingestion_batch_id") == F.lit(ingestion_batch_id))
        & F.col("filename").isin(pending_filenames)
    ),
    set={
        "ingestion_status": F.lit("Processing"),
        # NULL attempts count as zero before incrementing.
        "ingestion_attempts": F.coalesce(F.col("ingestion_attempts"), F.lit(0)) + F.lit(1),
        "last_ingestion_timestamp": F.lit(start_ts),
        "updated_at": F.lit(start_ts),
    },
)

# ---------- Read files and coalesce schema (robust) ----------
# Accumulators filled by the per-file read loop that follows.
conformed_list, errors, processed_files = [], [], []

# Canonical column name -> source column names accepted for it, in priority
# order. Dict order is significant: it is the column order of the projection.
canonical_cols = {
    "person_id": ["person_id", "personId", "id"],
    "household_id": ["household_id", "householdId", "hh_id"],
    "geoid": ["geoid", "geo_id", "region_id"],
    "region_code_legacy": ["region_code_legacy", "region_code"],
    "region_name_reported": ["region_name_reported", "region_name", "regionName"],
    "census_year": ["census_year", "year"],
    "date_of_birth": ["date_of_birth", "dob", "DOB", "birth_date"],
    "age": ["age", "Age"],
    "sex": ["sex", "gender"],
    "ethnicity_code": ["ethnicity_code", "ethnicity"],
    "education_level": ["education_level", "education"],
    "literacy": ["literacy", "is_literate"],
    "employment_status": ["employment_status", "employed_status"],
    "employment_type": ["employment_type", "employer_type"],
    "industry_code": ["industry_code", "industry"],
    "annual_income_local": ["annual_income_local", "income", "income_local"],
    "marital_status": ["marital_status", "marital"],
    "migration_status": ["migration_status"],
    "arrival_year": ["arrival_year"],
    "is_head_of_household": ["is_head_of_household", "head_of_household"],
    "record_confidence_score": ["record_confidence_score", "confidence_score"],
    "enumeration_source": ["enumeration_source"],
    "national_id": ["national_id", "nationalId"],
    "last_updated": ["last_updated", "updated_at"],
}

# Flat set of every accepted source name, across all canonical columns.
all_variant_names = set().union(*canonical_cols.values())

# ---------- Read each pending file and conform it to the canonical schema ----------
# Spark types enforced on ingest. Canonical columns not listed here keep
# whatever type the source file provides.
_CANONICAL_TYPES = {
    "annual_income_local": "double",
    "age": "int",
    "geoid": "int",
    "arrival_year": "int",
    "record_confidence_score": "double",
    "last_updated": "timestamp",
}


def _read_source(fpath):
    """Load one registered file: Parquet first, then CSV as a fallback."""
    try:
        return spark.read.format("parquet").load(fpath)
    except Exception:
        # Deliberate best-effort fallback: whatever cannot be opened as Parquet
        # is retried as a ';'-separated, latin1 CSV with a header row and every
        # column read as a string.
        return (
            spark.read
            .option("header", "true")
            .option("sep", ";")
            .option("encoding", "latin1")
            .option("inferSchema", "false")
            .csv(fpath)
        )


def _quoted(name):
    """Backtick-quote a raw column name so a '.' in it is not parsed as field access."""
    return "`" + name.replace("`", "``") + "`"


def _conform(src, fname):
    """Project *src* onto the canonical columns and add payload, metadata and row hash.

    Returns a DataFrame holding, in order: every key of ``canonical_cols``,
    ``_raw_payload_json``, ``_ingestion_source_file``, ``_ingestion_batch_id``
    and ``_ingestion_row_hash``.
    """
    src_cols = src.columns
    consumed = set()  # source columns actually mapped to a canonical column
    projections = []
    for canon, variants in canonical_cols.items():
        target_type = _CANONICAL_TYPES.get(canon)
        # The first variant present in the file wins.
        match = next((v for v in variants if v in src_cols), None)
        if match is None:
            # Missing column: emit NULL, typed like the populated case where a
            # type is known, so frames from different files share one schema.
            placeholder = F.lit(None)
            if target_type is not None:
                placeholder = placeholder.cast(target_type)
            projections.append(placeholder.alias(canon))
            continue
        consumed.add(match)
        if canon == "last_updated":
            value = F.to_timestamp(F.col(match))
        elif target_type is not None:
            value = F.col(match).cast(target_type)
        else:
            value = F.col(match)
        projections.append(value.alias(canon))

    # Everything that was NOT mapped goes into the raw payload. This includes
    # lower-priority alias columns (e.g. "id" when "person_id" is also
    # present), which would otherwise be dropped without a trace.
    extras = [c for c in src_cols if c not in consumed]
    if extras:
        payload = F.to_json(F.struct(*[F.col(_quoted(c)) for c in extras]))
    else:
        payload = F.lit(None).cast(StringType())

    # The payload is built in the same select as the projection, so the extras
    # never become top-level columns that could clash with canonical names.
    selected = src.select(*projections, payload.alias("_raw_payload_json"))

    # Normalise sex to trimmed, capitalised text ("male" / " FEMALE " ->
    # "Male" / "Female"); NULL stays NULL.
    selected = selected.withColumn("sex", F.initcap(F.trim(F.col("sex"))))

    # Row hash over the identifying fields; NULLs hash as empty strings.
    row_hash = F.sha2(
        F.concat_ws(
            "|",
            F.coalesce(F.col("person_id"), F.lit("")),
            F.coalesce(F.col("household_id"), F.lit("")),
            F.coalesce(F.col("date_of_birth").cast(StringType()), F.lit("")),
            F.coalesce(F.col("last_updated").cast(StringType()), F.lit("")),
        ),
        256,
    )
    return (
        selected
        .withColumn("_ingestion_source_file", F.lit(fname))
        .withColumn("_ingestion_batch_id", F.lit(ingestion_batch_id))
        .withColumn("_ingestion_row_hash", row_hash)
    )


for row in pending:
    fname = row["filename"]
    fpath = row["filepath"]
    try:
        conformed = _conform(_read_source(fpath), fname)
        read_entry = {
            "filename": fname,
            "status": "read_ok",
            "manifest_count": row["manifest_reported_row_count"],
        }
    except Exception as e:
        errors.append({"filename": fname, "error": str(e)})
        # Mark this file Failed and keep going with the remaining files.
        file_registry.update(
            # Column expressions instead of an f-string predicate: a quote in
            # the file name cannot break the condition.
            condition=(
                (F.col("ingestion_batch_id") == F.lit(ingestion_batch_id))
                & (F.col("filename") == F.lit(fname))
            ),
            set={
                "ingestion_status": F.lit("Failed"),
                # concat_ws skips NULL inputs. Plain concat would return NULL
                # and lose the error whenever provenance_json is still NULL.
                "provenance_json": F.concat_ws(
                    "\n", F.col("provenance_json"), F.lit(json.dumps({"error": str(e)}))
                ),
                "updated_at": F.lit(now()),
            },
        )
    else:
        # Record the file only once everything for it succeeded, so both lists
        # always describe the same set of files.
        conformed_list.append(conformed)
        processed_files.append(read_entry)

# No file could be read: write a FAILED audit row carrying the collected
# errors, then end the notebook run with a FAIL payload.
if not conformed_list:
    end_ts = datetime.utcnow()
    _audit_schema = "ingestion_batch_id string, run_id string, start_time timestamp, end_time timestamp, status string, notes string"
    _audit_row = (ingestion_batch_id, run_id, start_ts, end_ts, "FAILED", json.dumps({"errors": errors}))
    _audit_df = spark.createDataFrame([_audit_row], schema=_audit_schema)
    _audit_df.write.format("delta").mode("append").saveAsTable(INGESTION_AUDIT)

    _exit_payload = {"status":"FAIL","reason":"no_files_ingested","errors":errors}
    dbutils.notebook.exit(json.dumps(_exit_payload))

# ---------- Union the conformed frames and run the light manifest check ----------
# allowMissingColumns tolerates frames whose column sets differ (schema drift).
union_df = conformed_list[0]
for frame in conformed_list[1:]:
    union_df = union_df.unionByName(frame, allowMissingColumns=True)

observed_total = union_df.count()

# Compare like with like: union_df only holds rows from files that were read
# successfully, so the manifest total covers exactly those files. Summing over
# every pending file would inflate the expected count whenever a file failed.
# A NULL manifest count is treated as 0.
manifest_total = sum(int(p["manifest_count"] or 0) for p in processed_files)

if manifest_total == 0:
    # No expected row count to compare against: fail the files that were read.
    # Files that could not be read were already marked Failed with their error.
    read_filenames = [p["filename"] for p in processed_files]
    file_registry.update(
        # Column expressions instead of f-string SQL, restricted to this batch.
        condition=(
            (F.col("ingestion_batch_id") == F.lit(ingestion_batch_id))
            & F.col("filename").isin(read_filenames)
        ),
        set={
            "ingestion_status": F.lit("Failed"),
            # concat_ws skips NULLs; concat would null out the whole value when
            # provenance_json is still NULL.
            "provenance_json": F.concat_ws("\n", F.col("provenance_json"), F.lit("manifest_total_zero")),
            "updated_at": F.lit(now()),
        },
    )
    spark.createDataFrame(
        [(ingestion_batch_id, run_id, start_ts, now(), "FAILED", "manifest_total_zero")],
        schema="ingestion_batch_id string, run_id string, start_time timestamp, end_time timestamp, status string, notes string",
    ).write.format("delta").mode("append").saveAsTable(INGESTION_AUDIT)
    dbutils.notebook.exit(json.dumps({"status":"FAIL","reason":"manifest_total_zero"}))

# Relative difference between observed and manifest-reported row counts.
pct_diff = abs(observed_total - manifest_total) / manifest_total

# Light, non-blocking check: only warn here, the ingest still proceeds.
# NOTE(review): assumes MANIFEST_TOLERANCE_PCT is a fraction on the same scale
# as pct_diff (0.001 == 0.1 %) - confirm against the validation notebook.
if pct_diff > MANIFEST_TOLERANCE_PCT:
    print(
        f"WARNING: observed row count {observed_total} differs from manifest total "
        f"{manifest_total} by {pct_diff:.4%} (tolerance {MANIFEST_TOLERANCE_PCT})"
    )

# ---------- Write to Bronze and finalise the run ----------
# Append to the Bronze table, partitioned by ingestion batch.
union_df.write.format("delta").mode("append").partitionBy("_ingestion_batch_id").saveAsTable(BRONZE_TABLE)

# Mark as Succeeded ONLY the files whose rows were actually written. Files that
# failed to read were set to Failed earlier; updating every pending file here
# would overwrite that status and hide the failure.
succeeded_filenames = [p["filename"] for p in processed_files]
failed_filenames = [e["filename"] for e in errors]
file_registry.update(
    # Column expressions instead of f-string SQL (safe for any file name),
    # restricted to this batch, and a single Delta commit for all files.
    condition=(
        (F.col("ingestion_batch_id") == F.lit(ingestion_batch_id))
        & F.col("filename").isin(succeeded_filenames)
    ),
    set={
        "ingestion_status": F.lit("Succeeded"),
        "updated_at": F.lit(now()),
    },
)

end_ts = now()
# Run statistics shared by the audit row and the notebook exit value.
# "failed_files" is an additive key; all previously emitted keys are unchanged.
run_stats = {
    "observed_total": observed_total,
    "manifest_total": manifest_total,
    "pct_diff": pct_diff,
    "failed_files": failed_filenames,
}
spark.createDataFrame(
    [(ingestion_batch_id, run_id, start_ts, end_ts, "SUCCEEDED", json.dumps(run_stats))],
    schema="ingestion_batch_id string, run_id string, start_time timestamp, end_time timestamp, status string, notes string",
).write.format("delta").mode("append").saveAsTable(INGESTION_AUDIT)

dbutils.notebook.exit(json.dumps({"status":"SUCCESS","ingestion_batch_id":ingestion_batch_id, **run_stats}))
