# In [0]:
# Databricks notebook: bronze_validation
# Path: /Workspace/Users/you/bronze_validation
# Widget: optional ingestion_batch_id (string)
#
# Writes:
#  - census.bronze.validation_reports_v1
#  - updates census.bronze.file_registry_v1 when appropriate (mark Failed on validation errors)
#  - appends run row to census.bronze.ingestion_audit_v1
#
# Behavior:
#  - If widget ingestion_batch_id provided -> validate that batch (fail early if not found)
#  - If not provided -> auto-select, among batches with ingestion_attempts > 0, the one with the most succeeded files (most recently updated wins ties)
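#  - On a completed validation, exits with structured JSON: {"status":"VALIDATION_COMPLETE", "validated": bool, "report": {...}}
#    This allows Airflow to decide on reingest vs continue. Pre-flight problems exit early with
#    {"status":"ERROR", ...} or {"status":"NO_BATCH_TO_VALIDATE", ...}. Only unexpected exceptions will crash the notebook.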

from datetime import datetime
import json
from pyspark.sql import functions as F
from pyspark.sql.types import StringType
from delta.tables import DeltaTable

# Set this session's Spark SQL time parser policy to LEGACY.
spark.conf.set("spark.sql.legacy.timeParserPolicy", "LEGACY")

# ---------- CONFIG ----------
# Fully qualified names of the tables this notebook reads from and writes to.
FILE_REG_TABLE = "census.bronze.file_registry_v1"
BRONZE_TABLE = "census.bronze.individuals_raw_v1"
VALIDATION_TABLE = "census.bronze.validation_reports_v1"
INGESTION_AUDIT = "census.bronze.ingestion_audit_v1"

# thresholds (tunable)
# Despite the _PCT suffix this is a fraction: it is compared against
# abs(observed - manifest) / manifest.
DEFAULT_MANIFEST_TOL_PCT = 0.001   # ±0.1%
# Ages outside [AGE_HARD_LOW, AGE_HARD_HIGH] count as invalid.
AGE_HARD_LOW = 0
AGE_HARD_HIGH = 120
# An error is raised when the invalid-age fraction exceeds 1 - AGE_HARD_PASS_RATE.
AGE_HARD_PASS_RATE = 0.999         # 99.9% must be within bounds
# Maximum tolerated null fraction for each required column.
REQUIRED_COL_NULL_TOL = 0.0001     # 0.01%

# ---------- helpers ----------
def now():
    """Return the current UTC time as a naive ``datetime``.

    Produces the same value as the deprecated ``datetime.utcnow()``: the
    clock is read in UTC and the tzinfo is then dropped, so existing
    ``strftime``/``isoformat`` output keeps its shape.
    """
    # Local import: the file-level import only brings ``datetime`` into scope.
    from datetime import timezone

    return datetime.now(timezone.utc).replace(tzinfo=None)

def tidy_status_col(col):
    """Build a normalised string expression for the column named ``col``.

    Nulls become the empty string; the value is then trimmed and lower-cased
    so status comparisons are insensitive to case and surrounding whitespace.
    """
    non_null = F.coalesce(F.col(col), F.lit(""))
    trimmed = F.trim(non_null)
    return F.lower(trimmed)

def available_batches_summary():
    """Summarise the file registry per ingestion batch.

    Returns:
        A list of dicts, one per ``ingestion_batch_id``, holding the
        pending/processing/succeeded/failed file counts, the summed
        ``ingestion_attempts``, ``last_updated`` and ``total_files``.
        Rows are ordered by attempts, then succeeded count, then
        ``last_updated`` (all descending).

        Every value is JSON-serialisable: the summary is embedded in the
        payloads passed to ``json.dumps`` on the notebook's early-exit paths,
        where a raw ``datetime`` in ``last_updated`` would raise ``TypeError``.
    """
    def _json_safe(value):
        # Pass JSON-native values through; render date/time values as ISO-8601
        # and anything else as its string form.
        if value is None or isinstance(value, (bool, int, float, str)):
            return value
        if hasattr(value, "isoformat"):
            return value.isoformat()
        return str(value)

    fr = spark.table(FILE_REG_TABLE)
    s = fr.withColumn("status_norm", tidy_status_col("ingestion_status"))
    agg = s.groupBy("ingestion_batch_id").agg(
        # count() ignores nulls, and when() without otherwise() yields null on
        # a non-match, so each expression counts the rows in that status.
        F.count(F.when(F.col("status_norm") == "pending", True)).alias("pending_count"),
        F.count(F.when(F.col("status_norm") == "processing", True)).alias("processing_count"),
        F.count(F.when(F.col("status_norm") == "succeeded", True)).alias("succeeded_count"),
        F.count(F.when(F.col("status_norm") == "failed", True)).alias("failed_count"),
        F.sum(F.coalesce(F.col("ingestion_attempts"), F.lit(0))).alias("ingestion_attempts"),
        F.max("updated_at").alias("last_updated"),
        F.count(F.lit(1)).alias("total_files")
    )
    # Sort in Spark on the native column types, then convert values on the driver.
    ordered = agg.orderBy(F.desc("ingestion_attempts"), F.desc("succeeded_count"), F.desc("last_updated"))
    return [
        {key: _json_safe(value) for key, value in r.asDict().items()}
        for r in ordered.collect()
    ]

# ---------- determine ingestion_batch_id ----------
# Optional widget. A missing widget (dbutils raises) and a blank value both
# mean "no batch requested", which triggers auto-selection further down.
try:
    raw_batch_id = dbutils.widgets.get("ingestion_batch_id")
except Exception:
    raw_batch_id = None
# Strip surrounding whitespace so a padded value still matches the registry;
# an empty result collapses to None.
ingestion_batch_id_widget = (raw_batch_id or "").strip() or None

# optional widget to override manifest tolerance
try:
    manifest_tol = float(dbutils.widgets.get("manifest_tolerance_pct"))
except Exception:
    manifest_tol = DEFAULT_MANIFEST_TOL_PCT
else:
    # float() accepts "nan", "inf" and negative numbers. NaN/inf would make
    # `pct_diff > manifest_tol` never true (silently disabling the check) and
    # a negative tolerance would fail every batch. NaN fails both comparisons
    # below, so it is rejected as well.
    if not (0 <= manifest_tol < float("inf")):
        print("Warning: ignoring invalid manifest_tolerance_pct widget value:", manifest_tol)
        manifest_tol = DEFAULT_MANIFEST_TOL_PCT

# Quick sanity: the file registry table must exist before a batch can be chosen.
if not spark.catalog.tableExists(FILE_REG_TABLE):
    out = {"status":"ERROR","reason":"missing_file_registry_table","message": f"{FILE_REG_TABLE} does not exist"}
    dbutils.notebook.exit(json.dumps(out))

# Per-batch overview, attached to the early-exit payloads below to help operators.
batches_summary = available_batches_summary()

if ingestion_batch_id_widget:
    # confirm registry rows for requested batch exist (one row is enough evidence)
    requested_rows = (
        spark.table(FILE_REG_TABLE)
        .filter(F.col("ingestion_batch_id") == ingestion_batch_id_widget)
        .limit(1)
        .collect()
    )
    if not requested_rows:
        out = {"status":"ERROR","reason":"no_registry_rows_for_requested_batch","requested_batch": ingestion_batch_id_widget, "available_batches_summary": batches_summary}
        # default=str: the summary can carry timestamp values that json cannot
        # encode natively; stringify them rather than crash on the error path.
        dbutils.notebook.exit(json.dumps(out, default=str))
    chosen_batch = ingestion_batch_id_widget
else:
    # auto-select among batches where ingestion_attempts > 0 (evidence that an ingestion was attempted)
    fr = spark.table(FILE_REG_TABLE).withColumn("status_norm", tidy_status_col("ingestion_status"))
    batches = fr.groupBy("ingestion_batch_id").agg(
        F.sum(F.coalesce(F.col("ingestion_attempts"), F.lit(0))).alias("attempts"),
        F.sum(F.when(F.col("status_norm") == "succeeded", 1).otherwise(0)).alias("succeeded_count"),
        F.sum(F.when(F.col("status_norm") == "failed", 1).otherwise(0)).alias("failed_count"),
        F.max("updated_at").alias("last_updated")
    ).filter(F.col("attempts") > 0)

    # NOTE(review): candidates are ranked by succeeded_count first and recency
    # second, so "most recent" only breaks ties — confirm this is intended.
    # first() returns None when nothing qualifies, so a single action both
    # detects the empty case and picks the winner.
    candidate = batches.orderBy(F.desc("succeeded_count"), F.desc("last_updated")).first()
    if candidate is None:
        out = {"status":"NO_BATCH_TO_VALIDATE","message":"No batch found with ingestion attempts. Ensure register + ingest have run.","available_batches_summary": batches_summary}
        dbutils.notebook.exit(json.dumps(out, default=str))

    chosen_batch = candidate["ingestion_batch_id"]

print("Selected ingestion_batch_id for validation:", chosen_batch)

# ---------- perform validation ----------
start_ts = now()
run_id = f"bronze_validation-{chosen_batch}-{start_ts.strftime('%Y%m%dT%H%M%SZ')}"

# fetch registry rows for batch
reg_df = spark.table(FILE_REG_TABLE).filter(F.col("ingestion_batch_id") == chosen_batch)
reg_rows = reg_df.select("filename","manifest_reported_row_count","sha256_checksum","ingestion_status").collect()
if not reg_rows:
    out = {"status":"ERROR","reason":"no_registry_rows_for_batch_after_selection","selected_batch": chosen_batch, "available_batches_summary": batches_summary}
    # default=str: the summary can carry timestamp values that json cannot
    # encode natively; stringify them rather than crash on the error path.
    dbutils.notebook.exit(json.dumps(out, default=str))

# Manifest row count per file; a missing count is treated as 0.
manifest_counts = {r["filename"]: int(r["manifest_reported_row_count"] or 0) for r in reg_rows}
manifest_total = sum(manifest_counts.values())

# read bronze partition for batch (if table missing, exit with error JSON)
if not spark.catalog.tableExists(BRONZE_TABLE):
    out = {"status":"ERROR","reason":"missing_bronze_table","table": BRONZE_TABLE}
    dbutils.notebook.exit(json.dumps(out))

bronze_df = spark.table(BRONZE_TABLE).filter(F.col("_ingestion_batch_id") == chosen_batch)

required_cols = ["person_id","geoid","census_year","date_of_birth"]
age_col = F.col("age")
income_col = F.col("annual_income_local")
age_out_of_bounds = (age_col < AGE_HARD_LOW) | (age_col > AGE_HARD_HIGH)

# Collect every row count needed below in ONE aggregation pass instead of a
# separate full scan per metric. count() skips nulls and when() without
# otherwise() yields null on a non-match (or a null condition), so each
# expression counts exactly the rows the equivalent filter would keep.
count_exprs = [F.count(F.lit(1)).alias("observed_total")]
count_exprs += [
    F.count(F.when(F.col(c).isNull(), True)).alias(f"nulls__{c}")
    for c in required_cols
]
count_exprs += [
    F.count(age_col).alias("age_total"),
    F.count(F.when(age_out_of_bounds, True)).alias("age_invalid"),
    F.count(F.when(income_col < 0, True)).alias("neg_income_count"),
]
counts = bronze_df.agg(*count_exprs).first()

observed_total = int(counts["observed_total"])

# With no manifest total to compare against, report a 100% difference.
pct_diff = abs(observed_total - manifest_total) / manifest_total if manifest_total > 0 else 1.0

errors = []
warnings = []

# Hard rule: manifest agreement within tolerance -> considered validation error if exceeded
if manifest_total > 0 and pct_diff > manifest_tol:
    errors.append({"code":"manifest_count_mismatch","message":f"observed {observed_total} vs manifest {manifest_total}", "pct_diff": pct_diff})
elif manifest_total <= 0:
    # The agreement check cannot run without a manifest total; surface that
    # instead of skipping it silently.
    warnings.append({"code":"manifest_total_missing","message":"manifest row count is zero or missing; manifest agreement check skipped"})

# Required columns null fraction test
null_fractions = {}
for c in required_cols:
    frac = int(counts[f"nulls__{c}"]) / max(observed_total,1)
    null_fractions[c] = frac
    if frac > REQUIRED_COL_NULL_TOL:
        errors.append({"code":"required_col_null_excess","column":c,"null_fraction":frac})

# Age distribution checks (fraction is relative to rows that have an age)
age_total = int(counts["age_total"])
age_invalid = int(counts["age_invalid"])
age_invalid_frac = age_invalid / max(age_total,1)
if age_invalid_frac > (1 - AGE_HARD_PASS_RATE):
    errors.append({"code":"age_bounds_exceeded","invalid_count":age_invalid,"age_count":age_total,"invalid_fraction":age_invalid_frac})

# Soft rule: negative income fraction
neg_income_count = int(counts["neg_income_count"])
neg_income_frac = neg_income_count / max(observed_total,1)
if neg_income_frac > 0.001:
    warnings.append({"code":"negative_income_fraction_high","fraction":neg_income_frac})

# numeric summaries (safe)
def safe_percentiles(df, col):
    """Return approximate percentiles of ``col`` in ``df``, or ``None``.

    Evaluates ``percentile_approx`` at 0.01, 0.25, 0.5, 0.75 and 0.99.

    Args:
        df: DataFrame to summarise.
        col: Name of the column (interpolated into the SQL expression).

    Returns:
        A list of five values in the order above, or ``None`` when the
        aggregate is null or the computation fails. Values that are not
        already ``int``/``float`` (e.g. ``Decimal``) are converted to
        ``float`` so the result can be embedded in a JSON report.
    """
    try:
        row = df.selectExpr(
            f"percentile_approx({col}, array(0.01,0.25,0.5,0.75,0.99)) as pcts"
        ).first()
        pcts = row[0] if row is not None else None
        if pcts is None:
            return None
        return [
            v if v is None or isinstance(v, (int, float)) else float(v)
            for v in pcts
        ]
    except Exception as exc:
        # Best-effort summary: report the failure instead of hiding it, but do
        # not let it abort the validation run.
        print(f"Warning: percentile computation failed for column {col}:", str(exc))
        return None

# Percentile summaries; the income summary only looks at rows that have an income.
income_known_df = bronze_df.filter(F.col("annual_income_local").isNotNull())
age_pcts = safe_percentiles(bronze_df, "age")
income_pcts = safe_percentiles(income_known_df, "annual_income_local")

# The 25 most frequent reported region names, most frequent first.
region_counts = bronze_df.groupBy("region_name_reported").count()
top_region_variants = []
for region_row in region_counts.orderBy(F.desc("count")).limit(25).collect():
    top_region_variants.append(
        {"region_name_reported": region_row["region_name_reported"], "count": int(region_row["count"])}
    )

# Build the report section by section; key order matches the serialised layout.
run_section = {
    "ingestion_batch_id": chosen_batch,
    "run_id": run_id,
    "start_time": start_ts.isoformat(),
    "end_time": now().isoformat(),
}
volume_section = {
    "manifest_total": manifest_total,
    "observed_total": observed_total,
    "pct_diff": pct_diff,
    "manifest_counts": manifest_counts,
}
quality_section = {
    "null_fractions": null_fractions,
    "age_invalid_count": age_invalid,
    "age_invalid_fraction": age_invalid_frac,
    "negative_income_count": neg_income_count,
    "negative_income_fraction": neg_income_frac,
}
distribution_section = {
    "age_percentiles": age_pcts,
    "income_percentiles": income_pcts,
    "top_region_variants": top_region_variants,
}
report = {
    **run_section,
    **volume_section,
    **quality_section,
    **distribution_section,
    "warnings": warnings,
    "errors": errors,
}

# Serialise the report once. default=str keeps the "always exit with structured
# JSON" contract even if a report value (e.g. a Decimal percentile) is not
# natively JSON-encodable.
report_json = json.dumps(report, default=str)

# persist to validation_reports table (append)
status_flag = "PASS" if not errors else "ERROR"
spark.createDataFrame(
    [(chosen_batch, run_id, now(), report_json, status_flag)],
    schema="ingestion_batch_id string, run_id string, report_time timestamp, report_json string, status string",
).write.format("delta").mode("append").saveAsTable(VALIDATION_TABLE)

# If validation errors present, mark registry rows Failed (so operators can inspect). Do NOT raise/exit with error.
if errors:
    try:
        file_registry = DeltaTable.forName(spark, FILE_REG_TABLE)
        # Column expressions instead of an interpolated SQL string: a quote in
        # the batch id can neither break nor alter the predicate. eqNullSafe
        # also matches rows whose status is NULL, which `!= 'Failed'` skips.
        in_batch = F.col("ingestion_batch_id") == F.lit(chosen_batch)
        not_yet_failed = ~F.col("ingestion_status").eqNullSafe(F.lit("Failed"))
        file_registry.update(
            condition = in_batch & not_yet_failed,
            set = {
                "ingestion_status": F.lit("Failed"),
                "updated_at": F.lit(now())
            }
        )
    except Exception as e_upd:
        # log but continue: the registry update is best-effort
        print("Warning: failed to mark registry rows as Failed:", str(e_upd))

# append run row to ingestion audit (status indicates there were validation errors)
ingest_status = "SUCCEEDED" if not errors else "FAILED_VALIDATION"
audit_notes = json.dumps({"errors": errors, "warnings": warnings})
spark.createDataFrame(
    [(chosen_batch, run_id, start_ts, now(), ingest_status, audit_notes)],
    schema="ingestion_batch_id string, run_id string, start_time timestamp, end_time timestamp, status string, notes string",
).write.format("delta").mode("append").saveAsTable(INGESTION_AUDIT)

# Final exit: always return structured JSON so Airflow can branch
validated_bool = not errors
# To exercise the Airflow failure branch while testing, force
# validated_bool = False here.
result = {"status":"VALIDATION_COMPLETE", "validated": validated_bool, "report": report}
dbutils.notebook.exit(json.dumps(result, default=str))
