In [0]:
# Databricks notebook: silver_validation
# Path: /Workspace/Users/you/silver_validation
# Widget: ingestion_batch_id (optional)
#
# Writes:
#   - census.silver.validation_reports_v1  (append)
#   - census.silver.ingestion_audit_v1     (append)
#
# Behavior:
#   - If widget ingestion_batch_id provided -> validate that batch (fail early if not found)
#   - If not provided -> auto-select most-recent batch with ingestion_attempts > 0
#   - Always exits with structured JSON:
#       {"status":"VALIDATION_COMPLETE", "validated": bool, "report": {...}}
#
from datetime import datetime
import json
from pyspark.sql import functions as F
from delta.tables import DeltaTable

# Switch Spark's datetime parsing to the legacy (pre-Spark-3.0) policy for this session.
# NOTE(review): nothing in this notebook visibly parses date strings — confirm this is still needed.
spark.conf.set("spark.sql.legacy.timeParserPolicy", "LEGACY")

# ---------- CONFIG ----------
# Fully-qualified table names read by the validation checks below.
FILE_REG_TABLE = "census.bronze.file_registry_v1"
SILVER_PERSON = "census.silver.dim_person"
SILVER_PERSON_HISTORY = "census.silver.dim_person_history"
SILVER_LINEAGE = "census.silver.lineage"
SILVER_HOUSEHOLD = "census.silver.dim_household"
SILVER_REGION = "census.silver.dim_region"

# Output tables; both are written in append mode at the end of the run.
VALIDATION_TABLE = "census.silver.validation_reports_v1"
INGESTION_AUDIT = "census.silver.ingestion_audit_v1"

# thresholds (tunable)
# NOTE(review): neither threshold is referenced by the checks visible in this notebook
# (the duplicate check compares against a literal 0 and the ratio warning was removed).
MAX_DUPLICATES_ALLOWED = 0
PERSONS_TO_BRONZE_RATIO_WARN_THRESHOLD = 2.0

# ---------- helpers ----------
def now():
    """Return the current UTC time as a naive ``datetime`` (``tzinfo`` is ``None``).

    ``datetime.utcnow()`` is deprecated since Python 3.12, so the value is
    taken from an aware UTC clock and the tzinfo is dropped, which yields the
    same naive-UTC value the rest of this notebook formats and stores.
    """
    # Function-scope import: the file-level import only brings in ``datetime``.
    from datetime import timezone

    return datetime.now(timezone.utc).replace(tzinfo=None)

def tidy_status_col(col):
    """Build a normalized status expression for the column named ``col``.

    NULL is replaced by the empty string, then surrounding whitespace is
    trimmed and the text is lower-cased.
    """
    non_null = F.coalesce(F.col(col), F.lit(""))
    trimmed = F.trim(non_null)
    return F.lower(trimmed)

def available_batches_summary():
    """Summarize the file registry per ingestion batch as a list of dicts.

    Returns an empty list when the registry table does not exist. Otherwise
    each dict holds the batch id, per-status file counts, summed ingestion
    attempts, the latest ``updated_at`` (as a string) and the total file
    count, ordered by attempts, succeeded count and last update, all
    descending.
    """
    if not spark.catalog.tableExists(FILE_REG_TABLE):
        return []

    def _status_count(status, alias):
        # count() skips NULLs, so only rows matching the status are counted.
        return F.count(F.when(F.col("status_norm") == status, True)).alias(alias)

    registry = spark.table(FILE_REG_TABLE).withColumn(
        "status_norm", tidy_status_col("ingestion_status")
    )
    aggregates = [
        _status_count("pending", "pending_count"),
        _status_count("processing", "processing_count"),
        _status_count("succeeded", "succeeded_count"),
        _status_count("failed", "failed_count"),
        F.sum(F.coalesce(F.col("ingestion_attempts"), F.lit(0))).alias("ingestion_attempts"),
        # Cast to string so the collected value is a plain str rather than a datetime.
        F.max("updated_at").cast("string").alias("last_updated"),
        F.count(F.lit(1)).alias("total_files"),
    ]
    per_batch = registry.groupBy("ingestion_batch_id").agg(*aggregates)
    ranked = per_batch.orderBy(
        F.desc("ingestion_attempts"),
        F.desc("succeeded_count"),
        F.desc("last_updated"),
    )
    return [row.asDict() for row in ranked.collect()]

def pick_batch():
    """
    Select the ingestion_batch_id to validate:
      - prefer the explicit widget value (surrounding whitespace removed)
      - otherwise choose, among batches with ingestion_attempts > 0, the one with
        the most succeeded files, breaking ties by the latest updated_at

    When no batch can be auto-selected (registry table missing, or no batch with
    attempts), the notebook is exited via dbutils.notebook.exit with a structured
    JSON payload whose "validated" flag is False.

    NOTE(review): the file header says an explicit batch should "fail early if
    not found", but no existence check is performed here — confirm the intent.
    """
    try:
        widget_val = dbutils.widgets.get("ingestion_batch_id")
    except Exception:
        # Best effort: the widget may not be defined; fall back to auto-selection.
        widget_val = None

    if widget_val and widget_val.strip():
        # Return the stripped value: it is later compared against table columns
        # and embedded in the run_id, so stray whitespace would match nothing.
        return widget_val.strip()

    if not spark.catalog.tableExists(FILE_REG_TABLE):
        dbutils.notebook.exit(json.dumps({"status":"VALIDATION_COMPLETE","validated": False, "report": {"reason":"missing_file_registry_table","table": FILE_REG_TABLE}}))

    fr = spark.table(FILE_REG_TABLE).withColumn("status_norm", tidy_status_col("ingestion_status"))
    batches = (
        fr
        # Rows without a batch id cannot be validated; keep them out of the candidates.
        .filter(F.col("ingestion_batch_id").isNotNull())
        .groupBy("ingestion_batch_id")
        .agg(
            F.sum(F.coalesce(F.col("ingestion_attempts"), F.lit(0))).alias("attempts"),
            F.sum(F.when(F.col("status_norm") == "succeeded", 1).otherwise(0)).alias("succeeded_count"),
            F.max("updated_at").alias("last_updated"),
        )
        .filter(F.col("attempts") > 0)
    )

    # first() returns None on an empty result, so a single Spark job covers both
    # the emptiness check and the selection.
    candidate = batches.orderBy(F.desc("succeeded_count"), F.desc("last_updated")).first()
    if candidate is None:
        dbutils.notebook.exit(json.dumps({"status":"VALIDATION_COMPLETE","validated": False, "report": {"reason":"no_ingestion_attempts_found"}}))

    return candidate["ingestion_batch_id"]

# ---------- start ----------
# Resolve the batch, stamp the run, and set up the collectors every check appends to.
ingestion_batch_id = pick_batch()
start_ts = now()
run_stamp = start_ts.strftime('%Y%m%dT%H%M%SZ')
run_id = f"silver_validation-{ingestion_batch_id}-{run_stamp}"

errors, warnings, notes = [], [], {}

# Keep a snapshot of all registry batches in the report notes for debugging.
batches_summary = available_batches_summary()
notes["available_batches_summary"] = batches_summary

# ---------- 0) dim_region existence and schema sanity ----------
# A missing table is a hard error; an empty table or missing expected columns
# are reported as warnings only.
try:
    if not spark.catalog.tableExists(SILVER_REGION):
        errors.append("dim_region_missing")
    else:
        region_df = spark.table(SILVER_REGION)
        region_count = region_df.count()
        if region_count == 0:
            warnings.append("dim_region_empty")
        # basic expected columns check
        expected_region_cols = {"geoid","region_name_standard","iso_admin_code"}
        # sorted() keeps the warning text stable across runs; plain set
        # iteration order depends on string hashing.
        missing_region_cols = sorted(expected_region_cols - set(region_df.columns))
        if missing_region_cols:
            warnings.append(f"dim_region_missing_columns:{missing_region_cols}")
except Exception as e:
    errors.append(f"dim_region_check_error:{str(e)}")

# ---------- 1) dim_person presence & minimal schema ----------
# Determines how many persons this batch touched:
#   - processed_persons_count: dim_person rows carrying this batch id
#   - person_count_for_report: the larger of that and the batch's lineage row count
person_count_for_report = 0
processed_persons_count = 0  # Persons that were actually processed (new or updated)
try:
    if spark.catalog.tableExists(SILVER_PERSON):
        # Count persons that have this batch_id (either newly inserted or updated)
        persons_with_batch = spark.table(SILVER_PERSON).filter(
            F.col("ingestion_batch_id") == ingestion_batch_id
        ).count()

        # Record the processed count regardless of whether the lineage table
        # exists; otherwise a missing lineage table leaves it at 0 and the
        # bronze cross-check reports a false "no new persons" warning.
        processed_persons_count = persons_with_batch
        person_count_for_report = persons_with_batch
        notes["person_count_with_batch"] = persons_with_batch
        notes["processed_persons_count"] = processed_persons_count

        # Also check lineage table to see how many persons this batch contributed to
        if spark.catalog.tableExists(SILVER_LINEAGE):
            lineage_count = spark.table(SILVER_LINEAGE).filter(
                F.col("ingestion_batch_id") == ingestion_batch_id
            ).count()
            notes["lineage_count_for_batch"] = lineage_count

            # Report the maximum of the two counts so that both new inserts
            # and updates are accounted for.
            person_count_for_report = max(persons_with_batch, lineage_count)

            # If lineage has records but dim_person doesn't, it might mean no changes
            if lineage_count > 0 and persons_with_batch == 0:
                notes["batch_processed_no_changes"] = True
                print(f"NOTE: Batch {ingestion_batch_id} processed {lineage_count} persons but no attribute changes detected")

except Exception as e:
    person_count_for_report = 0
    warnings.append(f"person_count_error:{str(e)}")

# ---------- 2) uniqueness: canonical_person_id + census_year + is_current unique ----------
# Counts (canonical_person_id, census_year) groups that have more than one
# current row; more than MAX_DUPLICATES_ALLOWED such groups is a hard error.
try:
    if spark.catalog.tableExists(SILVER_PERSON):
        # Check for ACTUAL duplicates in current records
        dup_q = f"""
            SELECT canonical_person_id, census_year, COUNT(*) as cnt
            FROM {SILVER_PERSON}
            WHERE is_current = true
            GROUP BY canonical_person_id, census_year
            HAVING COUNT(*) > 1
        """
        dup_groups = spark.sql(dup_q)
        dup_count = dup_groups.count()
        # Honor the tunable threshold from the config section instead of a literal 0.
        if dup_count > MAX_DUPLICATES_ALLOWED:
            errors.append(f"dup_current_records_count:{dup_count}")
            # Add debug info about sample duplicates
            dup_samples = dup_groups.limit(5).collect()
            print(f"DEBUG: Found {dup_count} duplicate groups. Samples: {dup_samples}")
except Exception as e:
    errors.append(f"dup_check_error:{str(e)}")

# ---------- 3) lineage completeness ----------
# Two anti-join checks keyed on (canonical_person_id, census_year):
#   a) current dim_person rows of this batch without a lineage row of this batch -> error
#   b) lineage rows of this batch without a person row that is current or
#      belongs to this batch                                                     -> warning
# The checks use the DataFrame API so the batch id (which may come from a
# notebook widget) is passed as a literal value and never spliced into SQL text.
try:
    if spark.catalog.tableExists(SILVER_PERSON) and spark.catalog.tableExists(SILVER_LINEAGE):
        lineage_keys = ["canonical_person_id", "census_year"]
        in_batch = F.col("ingestion_batch_id") == ingestion_batch_id
        is_current = F.col("is_current") == F.lit(True)

        persons_df = spark.table(SILVER_PERSON)
        batch_lineage_df = spark.table(SILVER_LINEAGE).filter(in_batch)

        # a) every current person written by this batch should have a lineage record.
        #    left_anti keeps the left rows that have no key match on the right,
        #    i.e. the same rows a NOT EXISTS subquery would return.
        missing_lineage = (
            persons_df.filter(in_batch & is_current)
            .join(batch_lineage_df, on=lineage_keys, how="left_anti")
            .count()
        )

        if missing_lineage > 0:
            errors.append(f"persons_missing_lineage:{missing_lineage}")
            print(f"ERROR: {missing_lineage} persons in dim_person with batch {ingestion_batch_id} but no lineage record")

        # b) the reverse - every lineage row should have a person row that is
        #    either current or was written by this batch.
        missing_person = (
            batch_lineage_df
            .join(persons_df.filter(is_current | in_batch), on=lineage_keys, how="left_anti")
            .count()
        )

        if missing_person > 0:
            warnings.append(f"lineage_missing_persons:{missing_person}")
            print(f"WARNING: {missing_person} lineage records without corresponding person")

    else:
        # lineage missing entirely -> warning
        if not spark.catalog.tableExists(SILVER_LINEAGE):
            warnings.append("lineage_table_missing")
except Exception as e:
    errors.append(f"lineage_check_error:{str(e)}")

# ---------- 4) household table sanity ----------
# Records how many household rows carry this batch id and warns when any of
# them has a NULL household_id. A zero count is not flagged: the batch may
# have processed persons without changing any household.
try:
    if spark.catalog.tableExists(SILVER_HOUSEHOLD):
        batch_filter = F.col("ingestion_batch_id") == ingestion_batch_id
        hh_df = spark.table(SILVER_HOUSEHOLD).where(batch_filter)
        hh_count = hh_df.count()
        notes["household_count"] = hh_count

        # households without ids
        null_hh = hh_df.where(F.col("household_id").isNull()).count()
        if null_hh > 0:
            warnings.append(f"household_null_count:{null_hh}")
    else:
        warnings.append("dim_household_missing")
except Exception as exc:
    warnings.append(f"household_check_error:{str(exc)}")

# ---------- 5) cross-check bronze->silver population sanity ----------
# Counts the bronze rows of this batch and warns when bronze has rows but no
# dim_person row carries the batch id.
bronze_count = 0
try:
    if spark.catalog.tableExists("census.bronze.individuals_raw_v1"):
        bronze_count = spark.table("census.bronze.individuals_raw_v1").filter(
            F.col("_ingestion_batch_id") == ingestion_batch_id
        ).count()
        notes["bronze_count"] = bronze_count
except Exception as e:
    bronze_count = 0
    # Surface the failure in the report; a silent 0 would be indistinguishable
    # from a batch that genuinely has no bronze rows.
    warnings.append(f"bronze_count_error:{e}")

# If processed_persons_count > 0, then at least some processing happened
if bronze_count > 0 and processed_persons_count == 0:
    # Check if this might be expected (same data re-ingested)
    warnings.append("batch_processed_no_new_persons")
    print(f"WARNING: Bronze has {bronze_count} rows but no new/updated persons in silver. This might be expected if data is identical.")

# The old warning about ratio is removed - it was misleading

# ---------- 6) optional: basic attribute-level checks (age bounds, literacy, employment distributions) ----------
# For the dim_person rows of this batch: warn when more than 5% of ages are
# NULL and when the literacy / employment_status columns are absent.
try:
    if spark.catalog.tableExists(SILVER_PERSON):
        p = spark.table(SILVER_PERSON).filter(F.col("ingestion_batch_id") == ingestion_batch_id)
        # Count once and reuse: each count() is a separate Spark job.
        p_total = p.count()
        if p_total > 0:
            # age nulls fraction
            if "age" in p.columns:
                age_nulls = p.filter(F.col("age").isNull()).count()
                age_null_frac = age_nulls / p_total
                if age_null_frac > 0.05:
                    warnings.append(f"age_high_null_fraction:{age_null_frac:.4f}")
            # literacy present?
            if "literacy" not in p.columns:
                warnings.append("literacy_missing_in_person")
            # employment_status present?
            if "employment_status" not in p.columns:
                warnings.append("employment_status_missing_in_person")
except Exception as e:
    # These checks are optional, so a failure is only a warning, but it is
    # recorded rather than silently dropped.
    warnings.append(f"attribute_check_error:{e}")

# ---------- 7) Check SCD2 history consistency ----------
# Notes how many history rows carry this batch id and whether that count is
# non-zero (reported as "scd2_triggered").
try:
    if spark.catalog.tableExists(SILVER_PERSON_HISTORY):
        history_rows = spark.table(SILVER_PERSON_HISTORY).where(
            F.col("ingestion_batch_id") == ingestion_batch_id
        )
        history_count = history_rows.count()
        notes["history_count_for_batch"] = history_count
        notes["scd2_triggered"] = history_count > 0

        if notes["scd2_triggered"]:
            print(f"INFO: SCD2 was triggered for {history_count} records (attribute changes detected)")
        else:
            print(f"INFO: No SCD2 triggered for batch {ingestion_batch_id} (no attribute changes)")
except Exception as exc:
    warnings.append(f"history_check_error:{str(exc)}")

# ---------- Assemble report ----------
# Collects everything gathered above into one dict, appends it to the
# validation and audit tables, prints a summary and exits the notebook with a
# structured JSON payload.
report = {
    "ingestion_batch_id": ingestion_batch_id,
    "run_id": run_id,
    "start_time": start_ts.isoformat(),
    "end_time": now().isoformat(),
    "errors": errors,
    "warnings": warnings,
    "notes": notes,
    "bronze_count": bronze_count,
    "person_count": person_count_for_report,
    "processed_persons_count": processed_persons_count,
    "scd2_triggered": notes.get("scd2_triggered", False),
    "batch_processed_no_changes": notes.get("batch_processed_no_changes", False)
}

# persist to validation_reports table (append)
status_flag = "PASS" if not errors else "ERROR"
try:
    # default=str: notes may carry values collected from Spark rows that json
    # cannot encode natively; stringify them rather than fail the write.
    spark.createDataFrame([(ingestion_batch_id, run_id, now(), json.dumps(report, default=str), status_flag)],
                          schema="ingestion_batch_id string, run_id string, report_time timestamp, report_json string, status string") \
         .write.format("delta").mode("append").saveAsTable(VALIDATION_TABLE)
except Exception as e:
    # if writing the validation table fails, still exit with structured JSON indicating failure
    report["validation_table_write_error"] = str(e)

# append ingestion audit row
audit_notes = json.dumps({"report_summary": {"errors": len(errors), "warnings": len(warnings)}, "note": "silver_validation run"})
# Use the shared now() helper so start_time and end_time come from the same clock source.
audit_row = (ingestion_batch_id, run_id, start_ts, now(), ("SUCCEEDED" if not errors else "FAILED_VALIDATION"), audit_notes)
audit_schema = "ingestion_batch_id string, run_id string, start_time timestamp, end_time timestamp, status string, notes string"
try:
    spark.createDataFrame([audit_row], schema=audit_schema).write.format("delta").mode("append").saveAsTable(INGESTION_AUDIT)
except Exception as e:
    # Same policy as the validation table: a failed audit write must not
    # prevent the structured exit below, so record it in the report instead.
    report["ingestion_audit_write_error"] = str(e)

# If hard errors exist, do NOT throw an exception — instead return structured result and let orchestrator decide
validated_bool = (len(errors) == 0)
result = {"status":"VALIDATION_COMPLETE", "validated": validated_bool, "report": report}

print("\n=== VALIDATION SUMMARY ===")
print(f"Batch ID: {ingestion_batch_id}")
print(f"Errors: {len(errors)}")
print(f"Warnings: {len(warnings)}")
print(f"Bronze rows: {bronze_count}")
print(f"Persons processed: {processed_persons_count}")
print(f"SCD2 triggered: {notes.get('scd2_triggered', False)}")
print(f"Validation result: {'PASS' if validated_bool else 'FAIL'}")

# final structured exit for Airflow to parse
# default=str keeps this exit from raising on a non-JSON-native note value.
dbutils.notebook.exit(json.dumps(result, default=str))