# In [0]:  (cell marker left over from the notebook export; kept as a comment so the file parses as Python)
# Databricks notebook: gold_validation
# Path: /Workspace/Users/you/gold_validation
# Widget: ingestion_batch_id (optional)
#
# Purpose: Validate Gold layer tables created by gold_materialize
# Returns structured JSON via dbutils.notebook.exit()
#
# Writes:
#   - census.gold.validation_reports_v1  (append)
#   - census.gold.ingestion_audit_v1     (append)  # Same as gold_materialize uses
#
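# Behavior:
#   - If widget ingestion_batch_id is provided -> validate that batch
#   - If not provided -> auto-select the most recent SUCCEEDED batch from the gold audit table,
#     falling back to the most recent PASS entry in census.silver.validation_reports_v1
#   - If no batch can be determined -> exits with "validated": false and
#     report.reason = "no_ingestion_batch_id_found"
#   - Always exits with structured JSON:
#       {"status":"VALIDATION_COMPLETE", "validated": bool, "report": {...}}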
#
from datetime import datetime
import json
from pyspark.sql import functions as F

# Switch this Spark session's time parser policy to LEGACY.
# NOTE(review): nothing visible in this notebook parses date/time strings through
# Spark — confirm the setting is still required here before keeping it.
spark.conf.set("spark.sql.legacy.timeParserPolicy", "LEGACY")

# ---------- CONFIG ----------
# Catalog.schema prefix shared by every gold-layer object referenced below.
GOLD_PREFIX = "census.gold"
GOLD_AUDIT_TABLE = GOLD_PREFIX + ".ingestion_audit_v1"

# Gold tables to validate, keyed by short name. Each fully-qualified name is
# simply the prefix plus the short name, so the mapping is derived, not typed.
GOLD_TABLES = {
    short_name: f"{GOLD_PREFIX}.{short_name}"
    for short_name in (
        "dim_age_group",
        "metric_definitions",
        "fact_population_by_region_year",
        "indicators_literacy_employment",
        "fact_household_summary",
        "income_distribution_by_region_year",
        "small_area_shrinkage_estimates",
        "fact_population_flat_region_year",
        "education_distribution_by_region_year",
        "education_employment_crosswalk",
    )
}

# Destination for the per-run validation report rows.
VALIDATION_TABLE = GOLD_PREFIX + ".validation_reports_v1"

# thresholds (tunable)
# NOTE(review): MIN_FACT_ROWS_THRESHOLD is not referenced anywhere in the
# visible notebook; the emptiness checks compare against 0 directly.
MIN_FACT_ROWS_THRESHOLD = 1  # minimum acceptable row count for fact tables
EXPECTED_AGE_GROUPS = 9  # exact row count expected in dim_age_group

# ---------- helpers ----------
def now() -> datetime:
    """Return the current UTC time as a naive ``datetime``.

    ``datetime.utcnow()`` is deprecated since Python 3.12. This builds the
    same value from an aware UTC timestamp and then drops the tzinfo, so the
    result stays naive: ``isoformat()`` output carries no offset suffix and
    the values written to timestamp columns are unchanged.
    """
    # Local import: the file-level import only brings in ``datetime`` itself.
    from datetime import timezone

    return datetime.now(timezone.utc).replace(tzinfo=None)

def table_exists(tname: str) -> bool:
    """Return True when the Spark catalog reports that ``tname`` exists.

    Args:
        tname: Table name passed straight to ``spark.catalog.tableExists``.

    Returns:
        The catalog's answer, or False when the lookup itself raises. The
        failure is printed rather than swallowed so that a lookup problem is
        distinguishable in the run log from a table that is truly absent.
    """
    try:
        return spark.catalog.tableExists(tname)
    except Exception as e:
        # Best-effort by design: callers treat "cannot tell" as "missing",
        # but the underlying cause must remain visible for debugging.
        print(f"WARNING: could not check existence of table {tname}: {str(e)}")
        return False

def pick_ingestion_batch():
    """
    Select the ingestion_batch_id to validate.

    Resolution order:
      1. the ``ingestion_batch_id`` widget, when set to a non-blank value
         (returned with surrounding whitespace removed, so later equality
         filters on the batch id are not defeated by stray padding)
      2. the most recent SUCCEEDED row (by ``run_ts``) in the gold audit table
      3. the most recent PASS row (by ``report_time``) in
         ``census.silver.validation_reports_v1``

    Rows whose ``ingestion_batch_id`` is NULL are ignored in steps 2 and 3 so
    a NULL is never returned as the selected batch.

    Returns:
        The selected batch id. When no source yields one, the notebook is
        exited via ``dbutils.notebook.exit`` with a ``validated: False``
        payload. NOTE(review): this presumably halts execution; if it ever
        returns, this function falls through and returns None.
    """
    # 1) Explicit widget value wins.
    try:
        widget_val = dbutils.widgets.get("ingestion_batch_id")
    except Exception:
        # The widget is optional; a failed lookup just means "not provided".
        widget_val = None

    if widget_val and widget_val.strip():
        batch_id = widget_val.strip()
        print(f"Using widget ingestion_batch_id: {batch_id}")
        return batch_id

    # 2) + 3) Auto-select from the gold audit table, then fall back to the
    # upstream silver validation reports. Each entry is
    # (table, accepted status, ordering column, label used in log messages).
    sources = (
        (GOLD_AUDIT_TABLE, "SUCCEEDED", "run_ts", "gold audit"),
        ("census.silver.validation_reports_v1", "PASS", "report_time", "silver validation"),
    )
    for source_table, ok_status, order_col, label in sources:
        try:
            if not table_exists(source_table):
                continue

            latest = (
                spark.table(source_table)
                .filter(F.col("status") == ok_status)
                .filter(F.col("ingestion_batch_id").isNotNull())
                .orderBy(F.desc(order_col))
                .select("ingestion_batch_id")
                .first()
            )

            if latest is not None:
                batch_id = latest["ingestion_batch_id"]
                print(f"Auto-selected ingestion_batch_id from {label}: {batch_id}")
                return batch_id
        except Exception as e:
            # A broken source must not block the next fallback.
            print(f"Error selecting batch from {label}: {str(e)}")

    # If still nothing, exit with error
    dbutils.notebook.exit(json.dumps({
        "status": "VALIDATION_COMPLETE",
        "validated": False,
        "report": {
            "reason": "no_ingestion_batch_id_found",
            "message": "Provide ingestion_batch_id widget or ensure gold audit table exists with successful runs"
        }
    }))

# ---------- start ----------
# Resolve the batch first, then stamp the run; the run id embeds both.
ingestion_batch_id = pick_ingestion_batch()
start_ts = now()
run_id = "gold_validation-{}-{:%Y%m%dT%H%M%SZ}".format(ingestion_batch_id, start_ts)

# Accumulators filled in by the validation steps below.
errors, warnings = [], []
notes, table_counts = {}, {}

print(
    f"Starting gold validation for batch: {ingestion_batch_id}",
    f"Run ID: {run_id}",
    sep="\n",
)

# ---------- 1) Check all gold tables exist ----------
# Probe each expected table once: present tables are noted, absent ones are
# collected and recorded as errors.
missing_tables = []
for table_name, table_path in GOLD_TABLES.items():
    if table_exists(table_path):
        notes[f"{table_name}_exists"] = True
        continue
    missing_tables.append(table_name)
    errors.append(f"missing_table:{table_name}")

if not missing_tables:
    print("✓ All expected gold tables exist")
else:
    print(f"ERROR: Missing gold tables: {missing_tables}")

# ---------- 2) Validate each table ----------
# Every table that exists is validated, even when other gold tables are
# missing, so the report still carries row counts and data-quality findings
# for the tables that are present. Tables listed in missing_tables have
# already been recorded as errors and are skipped here.
for table_name, table_path in GOLD_TABLES.items():
    if table_name in missing_tables:
        continue

    try:
        df = spark.table(table_path)
        row_count = df.count()
        table_counts[table_name] = row_count
        notes[f"{table_name}_count"] = row_count

        # Name-based classification: tables whose name contains "fact" or
        # "indicators" get the stricter (error-level) treatment.
        is_fact_like = "fact" in table_name or "indicators" in table_name

        # Table-specific validations
        if row_count == 0:
            if is_fact_like:
                errors.append(f"empty_fact_table:{table_name}")
                print(f"ERROR: {table_name} is empty (0 rows)")
            else:
                warnings.append(f"empty_dimension_table:{table_name}")
                print(f"WARNING: {table_name} is empty (0 rows)")

        # Check for rows belonging to this batch (only if the column exists)
        if "ingestion_batch_id" in df.columns and row_count > 0:
            batch_rows = df.filter(F.col("ingestion_batch_id") == ingestion_batch_id).count()
            notes[f"{table_name}_batch_rows"] = batch_rows

            if batch_rows == 0 and is_fact_like:
                warnings.append(f"no_rows_for_batch:{table_name}")
                print(f"WARNING: {table_name} has no rows for batch {ingestion_batch_id}")

        # Specific checks for dimension tables
        if table_name == "dim_age_group" and row_count != EXPECTED_AGE_GROUPS:
            errors.append(f"dim_age_group_wrong_row_count: expected={EXPECTED_AGE_GROUPS}, actual={row_count}")
            print(f"ERROR: dim_age_group has {row_count} rows, expected {EXPECTED_AGE_GROUPS}")

        # Check for nulls in key columns. Both counts come from a single
        # aggregation so the table is scanned once instead of twice;
        # count(when(cond, 1)) counts only rows where cond holds and yields 0
        # (never NULL) when there are none.
        if table_name == "fact_population_by_region_year" and row_count > 0:
            null_counts = df.agg(
                F.count(F.when(F.col("geoid").isNull(), 1)).alias("null_geoid"),
                F.count(F.when(F.col("census_year").isNull(), 1)).alias("null_census_year"),
            ).first()
            null_geoid = null_counts["null_geoid"]
            null_census_year = null_counts["null_census_year"]

            if null_geoid > 0:
                errors.append(f"null_geoid_in_{table_name}:count={null_geoid}")
            if null_census_year > 0:
                errors.append(f"null_census_year_in_{table_name}:count={null_census_year}")

    except Exception as e:
        # One failing table must not stop validation of the remaining ones.
        errors.append(f"table_check_error:{table_name}:{str(e)}")
        print(f"ERROR checking {table_name}: {str(e)}")

# ---------- 3) Cross-table consistency checks ----------
# Compares the number of distinct (geoid, census_year) pairs in the population
# fact table against the indicators table; runs only when both were counted.
try:
    if table_counts.keys() >= {"fact_population_by_region_year", "indicators_literacy_employment"}:
        fact_regions = (
            spark.table(GOLD_TABLES["fact_population_by_region_year"])
            .select("geoid", "census_year")
            .distinct()
            .count()
        )
        ind_regions = (
            spark.table(GOLD_TABLES["indicators_literacy_employment"])
            .select("geoid", "census_year")
            .distinct()
            .count()
        )

        notes["fact_unique_regions"] = fact_regions
        notes["indicators_unique_regions"] = ind_regions

        # Indicators covering more pairs than the population fact is suspicious.
        if ind_regions > fact_regions:
            warnings.append("indicators_has_more_regions_than_fact_pop")
            print(f"WARNING: indicators has {ind_regions} unique regions, fact_population has {fact_regions}")
except Exception as e:
    # A failure here is downgraded to a warning, not a validation error.
    warnings.append(f"cross_table_check_error:{str(e)}")

# ---------- 4) Check gold audit table has entry for this batch ----------
# Counts SUCCEEDED audit rows for the batch; absence is only a warning.
try:
    if not table_exists(GOLD_AUDIT_TABLE):
        warnings.append("gold_audit_table_missing")
    else:
        audit_rows = (
            spark.table(GOLD_AUDIT_TABLE)
            .where(
                (F.col("ingestion_batch_id") == ingestion_batch_id)
                & (F.col("status") == "SUCCEEDED")
            )
            .count()
        )
        notes["gold_audit_rows_for_batch"] = audit_rows

        if not audit_rows:
            warnings.append("no_gold_audit_entry_for_batch")
            print(f"WARNING: No SUCCEEDED entry in gold audit table for batch {ingestion_batch_id}")
except Exception as e:
    warnings.append(f"gold_audit_check_error:{str(e)}")

# ---------- 5) Summary statistics ----------
# sum() over an empty mapping is already 0, so no special case is needed.
total_gold_rows = sum(table_counts.values())
notes.update(total_gold_rows=total_gold_rows, table_counts=table_counts)

print("\n=== GOLD VALIDATION SUMMARY ===")
print(f"Total gold rows across all tables: {total_gold_rows}")
print(f"Errors: {len(errors)}")
print(f"Warnings: {len(warnings)}")
# One line per table that was actually counted.
for table, count in table_counts.items():
    print(f"  {table}: {count} rows")

# ---------- Assemble report ----------
# The report holds references to the live errors/warnings/notes collections,
# so entries appended later in the run still appear when it is serialized.
report = dict(
    ingestion_batch_id=ingestion_batch_id,
    run_id=run_id,
    start_time=start_ts.isoformat(),
    end_time=now().isoformat(),
    errors=errors,
    warnings=warnings,
    notes=notes,
    table_counts=table_counts,
    total_gold_rows=total_gold_rows,
    validation_layer="gold",
)

# ---------- Write to validation table ----------
# Any recorded error fails the run; warnings alone still count as PASS.
status_flag = "PASS" if not errors else "ERROR"

try:
    # Create validation table if it doesn't exist.
    # IF NOT EXISTS makes the create step safe when the existence check gives
    # a false negative (table_exists returns False on lookup errors) or when
    # another run creates the table in between; a plain CREATE TABLE would
    # raise in those cases and the append below would be skipped.
    if not table_exists(VALIDATION_TABLE):
        spark.sql(f"""
            CREATE TABLE IF NOT EXISTS {VALIDATION_TABLE} (
                ingestion_batch_id STRING,
                run_id STRING,
                report_time TIMESTAMP,
                report_json STRING,
                status STRING
            )
            USING DELTA
        """)
        print(f"Created validation table: {VALIDATION_TABLE}")

    # Write validation result: one row per run, with the full report as JSON.
    spark.createDataFrame([(ingestion_batch_id, run_id, now(), json.dumps(report), status_flag)],
                         schema="ingestion_batch_id string, run_id string, report_time timestamp, report_json string, status string") \
         .write.format("delta").mode("append").saveAsTable(VALIDATION_TABLE)
    print(f"✓ Wrote validation result to {VALIDATION_TABLE}")

except Exception as e:
    # Best-effort persistence: the failure is surfaced in the returned report
    # but does not change the validation verdict.
    report["validation_table_write_error"] = str(e)
    print(f"ERROR writing to validation table: {str(e)}")

# ---------- Write to ingestion audit table ----------
# Appends one audit row describing this validation run. The row is only
# written when the audit table already exists; it is never created here.
try:
    if not table_exists(GOLD_AUDIT_TABLE):
        warnings.append("gold_audit_table_missing_could_not_write")
    else:
        audit_notes = json.dumps({
            "validation_summary": {
                "errors": len(errors),
                "warnings": len(warnings),
                "validated": not errors
            },
            "note": "gold_validation run"
        })

        # Timestamps are stored as ISO-8601 strings, per the schema below.
        audit_row = (
            run_id,
            ingestion_batch_id,
            start_ts.isoformat(),
            now().isoformat(),
            "VALIDATION_FAILED" if errors else "VALIDATION_SUCCEEDED",
            audit_notes,
        )

        audit_df = spark.createDataFrame(
            [audit_row],
            schema="run_id string, ingestion_batch_id string, start_time string, end_time string, status string, notes string",
        )

        # mergeSchema lets the append succeed when the table's columns differ
        # from this row's schema.
        (
            audit_df.write.format("delta")
            .mode("append")
            .option("mergeSchema", "true")
            .saveAsTable(GOLD_AUDIT_TABLE)
        )
        print(f"✓ Wrote audit entry to {GOLD_AUDIT_TABLE}")
except Exception as e:
    report["audit_table_write_error"] = str(e)
    print(f"ERROR writing to audit table: {str(e)}")

# ---------- Final result ----------
# The run counts as validated only when no errors were recorded.
validated_bool = not errors
result = dict(
    status="VALIDATION_COMPLETE",
    validated=validated_bool,
    report=report,
)

print("\n=== FINAL RESULT ===")
print(f"Validation passed: {validated_bool}")
print("Returning structured JSON result")

# Hand the structured JSON back to the caller (Airflow) as the notebook's exit value.
dbutils.notebook.exit(json.dumps(result))