In [0]:
# Databricks notebook source
# Title: gold_materialize_extended.py
# Purpose: materialize Gold medallion artifacts, including education tables
# Assumptions:
#  - Silver conformed tables exist under census.silver (dim_person, dim_household, dim_region, lineage)
#  - Unity Catalog is available; write to census.gold
#  - Notebook runs with appropriate privileges to create/overwrite gold tables
#  - Gold tables are overwritten on every run, so re-running is safe; only the ingestion audit table is appended to

import json
import math
import statistics
import uuid
from datetime import datetime, timezone
from typing import List, Tuple

from pyspark.sql import functions as F
from pyspark.sql.window import Window
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, DoubleType, TimestampType, ArrayType, MapType, LongType
from delta.tables import DeltaTable

# compute stats
import pandas as pd
import numpy as np

# Switch this Spark session to the LEGACY time-parser policy.
# NOTE(review): nothing in the visible notebook parses date/time strings —
# confirm this setting is still required before keeping it.
spark.conf.set("spark.sql.legacy.timeParserPolicy", "LEGACY")

# ---------- CONFIG ----------
# Catalog / schema that every Gold table name below is built from.
CATALOG = "census"
GOLD_SCHEMA = "gold"
GOLD_PREFIX = f"{CATALOG}.{GOLD_SCHEMA}"

# Silver source tables.
SILVER_PERSON = f"{CATALOG}.silver.dim_person"
SILVER_HOUSEHOLD = f"{CATALOG}.silver.dim_household"
SILVER_REGION = f"{CATALOG}.silver.dim_region"
SILVER_LINEAGE = f"{CATALOG}.silver.lineage"

# Output Gold table names
DIM_AGE = f"{GOLD_PREFIX}.dim_age_group"
METRIC_DEFS = f"{GOLD_PREFIX}.metric_definitions"

FACT_POP = f"{GOLD_PREFIX}.fact_population_by_region_year"
INDICATORS = f"{GOLD_PREFIX}.indicators_literacy_employment"
FACT_HH = f"{GOLD_PREFIX}.fact_household_summary"
INCOME_DIST = f"{GOLD_PREFIX}.income_distribution_by_region_year"
FLAT_FACT = f"{GOLD_PREFIX}.fact_population_flat_region_year"
SMALL_AREA = f"{GOLD_PREFIX}.small_area_shrinkage_estimates"
INGEST_AUDIT = f"{GOLD_PREFIX}.ingestion_audit_v1"

# NEW education tables
EDU_DIST = f"{GOLD_PREFIX}.education_distribution_by_region_year"
EDU_XWALK = f"{GOLD_PREFIX}.education_employment_crosswalk"

# Partitioning config
PARTITION_COL = "census_year"

# Run metadata
RUN_ID = f"gold_materialize_run_{uuid.uuid4().hex[:8]}"
# datetime.utcnow() is deprecated; take an aware UTC "now" and drop the tzinfo
# so RUN_TS stays a naive UTC datetime and RUN_TS.isoformat() keeps its
# existing format (no "+00:00" suffix) in the audit table.
RUN_TS = datetime.now(timezone.utc).replace(tzinfo=None)

# Utility
def table_exists(tname: str) -> bool:
    """Return whether the Spark catalog reports that table ``tname`` exists.

    Args:
        tname: Table name, passed unchanged to ``spark.catalog.tableExists``.

    Returns:
        The catalog's answer, or ``False`` when the lookup itself raises.
    """
    try:
        return spark.catalog.tableExists(tname)
    except Exception as exc:
        # Best-effort contract is kept (callers expect a bool, never an
        # exception), but the cause is surfaced: a failed lookup would
        # otherwise be indistinguishable from a table that is really absent.
        print(f"table_exists({tname!r}) lookup failed; treating table as missing: {exc}")
        return False

def ensure_schema():
    """Issue ``CREATE SCHEMA IF NOT EXISTS`` for the Gold schema (GOLD_PREFIX)."""
    ddl = f"CREATE SCHEMA IF NOT EXISTS {GOLD_PREFIX}"
    spark.sql(ddl)


# Executed once, before any Gold table is written.
ensure_schema()

# ---------- pick ingestion_batch_id (widget or auto-detect) ----------
def tidy_status_col(col):
    """Build a Column expression holding a normalised copy of column ``col``.

    Nulls are replaced by the empty string, then the value is trimmed and
    lower-cased.
    """
    non_null = F.coalesce(F.col(col), F.lit(""))
    trimmed = F.trim(non_null)
    return F.lower(trimmed)

def pick_ingestion_batch():
    """Choose the ingestion_batch_id recorded as provenance for this run.

    Resolution order:
      1. The ``ingestion_batch_id`` notebook widget, when it holds a
         non-blank value (returned with surrounding whitespace removed).
      2. Otherwise the batch in ``census.bronze.file_registry_v1`` that has at
         least one ingestion attempt, preferring the highest number of
         ``succeeded`` files and then the most recent ``updated_at``.

    Returns:
        The chosen batch id, or ``None`` when no widget value is supplied and
        the registry is missing, has no attempted batches, or cannot be read.
    """
    # First try widget
    try:
        widget_val = dbutils.widgets.get("ingestion_batch_id")
    except Exception:
        # Widget not defined (or dbutils unavailable): fall back to auto-detect.
        widget_val = None

    if widget_val and widget_val.strip():
        # Return the stripped value: it is the form that was validated, and
        # stray whitespace would make the recorded batch id differ from the
        # one the operator meant.
        widget_val = widget_val.strip()
        print(f"Using supplied ingestion_batch_id widget: {widget_val}")
        return widget_val

    # Fallback: inspect bronze file_registry for recent batches with attempts
    try:
        fr_table = "census.bronze.file_registry_v1"
        if not spark.catalog.tableExists(fr_table):
            print("No file_registry found; ingestion_batch_id will be None.")
            return None

        fr = spark.table(fr_table).withColumn("status_norm", tidy_status_col("ingestion_status"))
        batches = fr.groupBy("ingestion_batch_id").agg(
            F.sum(F.coalesce(F.col("ingestion_attempts"), F.lit(0))).alias("attempts"),
            F.sum(F.when(F.col("status_norm") == "succeeded", 1).otherwise(0)).alias("succeeded_count"),
            F.max("updated_at").alias("last_updated")
        ).filter(F.col("attempts") > 0)

        # first() returns None on an empty result, so one Spark action covers
        # both the emptiness check and the selection.
        candidate = batches.orderBy(F.desc("succeeded_count"), F.desc("last_updated")).first()
        if candidate is None:
            print("No ingestion attempts found in file_registry; ingestion_batch_id will be None.")
            return None

        print(f"Auto-selected ingestion_batch_id = {candidate['ingestion_batch_id']}")
        return candidate["ingestion_batch_id"]
    except Exception as e:
        # Provenance lookup is best-effort: the run continues with a null batch id.
        print("Error while picking ingestion_batch_id:", str(e))
        return None

ingestion_batch_id = pick_ingestion_batch()
print("ingestion_batch_id used for this run:", ingestion_batch_id)
# RUN_ID will act as the gold run id / run identifier; keep existing name RUN_ID for compatibility

# ---------- 1) Create / Overwrite dim_age_group ----------
# Each entry is (label, age_min, age_max); both bounds are used inclusively
# by the age-group join further down.
age_bins = [
    ("0-4", 0, 4),
    ("5-14", 5, 14),
    ("15-24", 15, 24),
    ("25-34", 25, 34),
    ("35-44", 35, 44),
    ("45-54", 45, 54),
    ("55-64", 55, 64),
    ("65-74", 65, 74),
    ("75+", 75, 200),
]

age_rows = list(age_bins)
schema_age = (
    StructType()
    .add("age_group", StringType(), False)
    .add("age_min", IntegerType(), False)
    .add("age_max", IntegerType(), False)
)
df_age = spark.createDataFrame(age_rows, schema_age)

# Full overwrite (schema included) on every run.
age_writer = df_age.write.format("delta").mode("overwrite").option("overwriteSchema", "true")
age_writer.saveAsTable(DIM_AGE)
print(f"Written {DIM_AGE} ({df_age.count()} rows)")

# ---------- 2) Metric definitions (simple examples) ----------
# Each entry is (metric_id, description, sql_definition).
metric_defs = [
    ("population_count", "Total headcount", "COUNT(person_id)"),
    ("population_by_age_group", "Population by age group", "SUM(population_count) GROUP BY age_group"),
    ("literacy_rate", "Literacy rate (age 5+)", "SUM(literate_count) / SUM(pop_age_5_plus)"),
    ("employment_rate", "Employment rate (age 15+)", "SUM(employed_count) / SUM(pop_age_15_plus)"),
    ("mean_income", "Mean annual income (employed)", "AVG(annual_income_local)"),
    ("median_income", "Median annual income (employed)", "APPROX_PERCENTILE(annual_income_local, 0.5)"),
    ("gini_income", "Gini coefficient of income", "Gini(annual_income_local)"),
    ("dependency_ratio", "Dependency ratio", "(pop_0_14 + pop_65_plus) / pop_15_64"),
]
schema_metrics = (
    StructType()
    .add("metric_id", StringType(), False)
    .add("description", StringType(), True)
    .add("sql_definition", StringType(), True)
)
df_metrics = spark.createDataFrame(metric_defs, schema_metrics)

# Full overwrite (schema included) on every run.
metrics_writer = df_metrics.write.format("delta").mode("overwrite").option("overwriteSchema", "true")
metrics_writer.saveAsTable(METRIC_DEFS)
print(f"Written {METRIC_DEFS} ({df_metrics.count()} rows)")

# ---------- 3) Read Silver person (canonical current persons) ----------
if not table_exists(SILVER_PERSON):
    raise RuntimeError(f"Required source table not found: {SILVER_PERSON}")

person = spark.table(SILVER_PERSON)
# Prefer current persons if SCD2 is used (only when an is_current column exists)
if "is_current" in person.columns:
    person = person.filter(F.col("is_current") == True)

# Cast the columns every downstream aggregate groups or filters on.
person = (
    person
    .withColumn("geoid", F.col("geoid").cast("int"))
    .withColumn("census_year", F.col("census_year").cast("int"))
    .withColumn("age", F.col("age").cast("int"))
    .withColumn("sex", F.col("sex").cast("string"))
)

# ---------- helper: assign age_group using dim_age_group mapping ----------
age_df = df_age  # created above

# Broadcast hint for the small age dimension. This hinted frame is the one
# that must be joined; the hint was previously built but never used.
age_broadcast = F.broadcast(age_df)

# Inclusive range join: a person gets the bin whose [age_min, age_max]
# contains their age. The left join keeps persons with a null or
# out-of-range age; their age_group is null here and filled just below.
person_with_agegroup = person.join(
    age_broadcast,
    (person["age"] >= age_df["age_min"]) & (person["age"] <= age_df["age_max"]),
    how="left",
).select(*person.columns, "age_group")

# fallback fill
person_with_agegroup = person_with_agegroup.withColumn("age_group", F.coalesce(F.col("age_group"), F.lit("unknown")))

print("Person rows for Gold processing:", person_with_agegroup.count())

# === Add derived flags on the canonical person_with_agegroup DataFrame ===
# Each flag is an integer 0/1 column. The text comparisons are null-safe
# (null is treated as "") and case-insensitive.
_literacy_text = F.lower(F.coalesce(F.col("literacy"), F.lit("")))
_is_literate = (F.col("literacy") == True) | _literacy_text.isin("true", "1", "yes")
_is_employed = F.lower(F.coalesce(F.col("employment_status"), F.lit(""))) == "employed"
_is_informal = F.lower(F.coalesce(F.col("employment_type"), F.lit(""))) == "informal"

person_with_agegroup = person_with_agegroup.withColumn(
    "literacy_flag", F.when(_is_literate, F.lit(1)).otherwise(F.lit(0))
)
person_with_agegroup = person_with_agegroup.withColumn(
    "is_employed_flag", F.when(_is_employed, F.lit(1)).otherwise(F.lit(0))
)
person_with_agegroup = person_with_agegroup.withColumn(
    "is_informal_flag", F.when(_is_informal, F.lit(1)).otherwise(F.lit(0))
)

# ---------- 4) fact_population_by_region_year ----------
# One row per (geoid, census_year, age_group, sex) with its head count.
_pop_keys = ["geoid", "census_year", "age_group", "sex"]
_head_count = F.count(F.lit(1)).cast("long").alias("population_count")
pop_agg = person_with_agegroup.groupBy(*_pop_keys).agg(_head_count)

# Provenance: which ingestion batch and which Gold run produced the rows.
pop_agg = pop_agg.withColumn("ingestion_batch_id", F.lit(ingestion_batch_id))
pop_agg = pop_agg.withColumn("run_id", F.lit(RUN_ID))

# Overwrite the table, partitioned by census year.
_pop_writer = (
    pop_agg.write.format("delta")
    .mode("overwrite")
    .partitionBy(PARTITION_COL)
    .option("overwriteSchema", "true")
)
_pop_writer.saveAsTable(FACT_POP)
print(f"Written {FACT_POP} ({pop_agg.count()} rows)")

# ---------- 5) indicators_literacy_employment ----------
# Compute key indicators per geoid,census_year
#
# literacy_flag / is_employed_flag / is_informal_flag are already present on
# person_with_agegroup (null-safe, case-insensitive). They are deliberately
# NOT re-derived here: the former re-derivation used strict-case comparisons
# ("Employed", "Informal", literacy == True) and silently replaced the robust
# flags, so this table could disagree with the other Gold tables.
p = person_with_agegroup

ind = p.groupBy("geoid","census_year").agg(
    # Literacy is measured on the population aged 5+.
    F.sum(F.when(F.col("age") >= 5, F.col("literacy_flag")).otherwise(0)).alias("literate_count"),
    F.sum(F.when(F.col("age") >= 5, 1).otherwise(0)).alias("pop_age_5_plus"),
    # Employment is measured on the population aged 15+.
    F.sum(F.when(F.col("age") >= 15, F.col("is_employed_flag")).otherwise(0)).alias("employed_count"),
    F.sum(F.when(F.col("age") >= 15, 1).otherwise(0)).alias("pop_age_15_plus"),
    # NOTE(review): this count is not restricted to employed persons aged 15+,
    # so informal_employment_share below can exceed 1 — confirm the intended
    # numerator before tightening it.
    F.sum(F.col("is_informal_flag")).alias("informal_employed_count")
)

# Add derived rates (null when the region-year has no employed persons)
ind = ind.withColumn("informal_employment_share",
                     F.when(F.col("employed_count") > 0, F.col("informal_employed_count") / F.col("employed_count")).otherwise(F.lit(None)))

# attach provenance columns
ind = ind.withColumn("ingestion_batch_id", F.lit(ingestion_batch_id)) \
         .withColumn("run_id", F.lit(RUN_ID))

ind = ind.select("geoid","census_year","literate_count","pop_age_5_plus","employed_count","pop_age_15_plus","informal_employed_count","informal_employment_share","ingestion_batch_id","run_id")

ind.write.format("delta").mode("overwrite").partitionBy(PARTITION_COL).option("overwriteSchema","true").saveAsTable(INDICATORS)
print(f"Written {INDICATORS} ({ind.count()} rows)")

# ---------- 6) fact_household_summary ----------
# Household aggregates are derived from the person rows in both cases. The
# presence of silver.dim_household only changes the log message, so the
# aggregation and write exist once instead of being duplicated per branch
# (the household table itself was loaded but never used).
if table_exists(SILVER_HOUSEHOLD):
    hh_log_prefix = "Written"
else:
    # No silver household table: the summary is a best-effort derivation.
    hh_log_prefix = "Derived & written"

# Per (household_id, geoid): member count, approximate median income and
# share of members flagged literate.
# NOTE(review): census_year is not part of the grouping key, so a household
# present in several census years is merged into one row — confirm intended.
hh_agg = (
    person_with_agegroup
    .select("household_id", "geoid", "annual_income_local", "literacy_flag")
    .groupBy("household_id", "geoid")
    .agg(
        F.count("*").alias("household_size"),
        F.expr("percentile_approx(annual_income_local, 0.5)").alias("median_household_income"),
        F.avg(F.col("literacy_flag")).alias("household_literacy_rate"),
    )
    .withColumn("ingestion_batch_id", F.lit(ingestion_batch_id))
    .withColumn("run_id", F.lit(RUN_ID))
)

hh_agg.write.format("delta").mode("overwrite").partitionBy("geoid").option("overwriteSchema","true").saveAsTable(FACT_HH)
print(f"{hh_log_prefix} {FACT_HH} ({hh_agg.count()} rows)")

# ---------- 7) income_distribution_by_region_year  (deciles, mean, median, gini) ----------
# Use employed persons with non-null positive incomes.
# "Employed" is decided by the shared is_employed_flag (null-safe and
# case-insensitive) rather than a case-sensitive string literal, so this
# table uses the same definition of employment as the derived flags.
employed_income = (
    person_with_agegroup
    .filter(
        (F.col("is_employed_flag") == 1)
        & F.col("annual_income_local").isNotNull()
        & (F.col("annual_income_local") > 0)
    )
    .select("geoid", "census_year", "annual_income_local")
)

# One row per (geoid, census_year) carrying the full list of incomes plus
# count and mean; deciles and Gini are then computed per group on the driver.
grp = employed_income.groupBy("geoid","census_year").agg(
    F.collect_list("annual_income_local").alias("incomes"),
    F.count("*").alias("employed_count"),
    F.avg("annual_income_local").alias("mean_income")
)

# Pulls every selected income to the driver: fine at the current ~100k person
# rows, but this does not scale with the person table.
pdf = grp.toPandas()

def compute_stats(row, topcode_threshold=200000):
    """Compute income-distribution statistics for one (geoid, census_year) group.

    Args:
        row: Mapping-like record (dict or pandas Series) with keys ``geoid``,
            ``census_year``, ``incomes`` (iterable of numbers, or None),
            ``employed_count`` and ``mean_income``.
        topcode_threshold: Incomes greater than or equal to this value are
            counted in ``topcoded_count``.

    Returns:
        dict with keys ``geoid``, ``census_year``, ``employed_count``,
        ``mean_income``, ``median_income``, ``deciles`` (the nine cut points
        at 10%..90%), ``gini_income`` and ``topcoded_count``. All statistic
        fields are None (and ``topcoded_count`` is 0) when the group has no
        incomes; ``gini_income`` is None when the incomes do not sum to a
        positive value.
    """
    raw_incomes = row["incomes"] if row["incomes"] is not None else []
    reported_n = int(row["employed_count"] or 0)

    # Sort once; median, quantiles and the Gini formula all use this array.
    values = np.sort(np.asarray([float(v) for v in raw_incomes], dtype=float))
    n = int(values.size)

    # The emptiness of the actual list is checked as well as the reported
    # count, so an inconsistent row yields nulls instead of raising.
    if reported_n == 0 or n == 0:
        return {
            "geoid": int(row["geoid"]),
            "census_year": int(row["census_year"]),
            "employed_count": reported_n,
            "mean_income": None,
            "median_income": None,
            "deciles": None,
            "gini_income": None,
            "topcoded_count": 0
        }

    median = float(np.median(values))
    # deciles - 10%,20%...90%
    deciles = [float(d) for d in np.quantile(values, [q / 10.0 for q in range(1, 10)])]

    # Gini from sorted values x_1..x_n: (2 * sum(i * x_i) - (n + 1) * S) / (n * S).
    # n is the real list length, so rank and value arrays always line up.
    total = float(values.sum())
    if total <= 0:
        gini = None
    else:
        ranks = np.arange(1, n + 1)
        gini = float((2.0 * np.sum(ranks * values) - (n + 1) * total) / (n * total))

    topcoded = int(np.count_nonzero(values >= topcode_threshold))
    return {
        "geoid": int(row["geoid"]),
        "census_year": int(row["census_year"]),
        "employed_count": reported_n,
        "mean_income": float(row["mean_income"] or 0.0),
        "median_income": median,
        "deciles": deciles,
        "gini_income": gini,
        "topcoded_count": topcoded
    }

# One statistics dict per (geoid, census_year) group row.
stats_rows = [compute_stats(r) for _, r in pdf.iterrows()]

# build spark dataframe for income dist; the same explicit schema serves both
# the populated and the empty case, so no branching is needed.
schema_income = (
    StructType()
    .add("geoid", IntegerType(), True)
    .add("census_year", IntegerType(), True)
    .add("employed_count", IntegerType(), True)
    .add("mean_income", DoubleType(), True)
    .add("median_income", DoubleType(), True)
    .add("deciles", ArrayType(DoubleType()), True)
    .add("gini_income", DoubleType(), True)
    .add("topcoded_count", IntegerType(), True)
)
spark_income = spark.createDataFrame(stats_rows, schema=schema_income)

# attach provenance columns
spark_income = spark_income.withColumn("ingestion_batch_id", F.lit(ingestion_batch_id))
spark_income = spark_income.withColumn("run_id", F.lit(RUN_ID))

_income_writer = (
    spark_income.write.format("delta")
    .mode("overwrite")
    .partitionBy(PARTITION_COL)
    .option("overwriteSchema", "true")
)
_income_writer.saveAsTable(INCOME_DIST)
print(f"Written {INCOME_DIST} ({spark_income.count()} rows)")

# ---------- 8) small_area_shrinkage_estimates (empirical-Bayes for literacy) ----------
# Use indicators table to compute region-level k (literate_count) and n (pop_age_5_plus)
ind_df = spark.table(INDICATORS)

# Global prior inputs, gathered in ONE aggregation (previously two separate
# Spark jobs over the same table): total literate K, total population 5+ N,
# and the average region sample size.
global_sum = ind_df.agg(
    F.sum("literate_count").alias("K"),
    F.sum("pop_age_5_plus").alias("N"),
    F.avg("pop_age_5_plus").alias("avg_n"),
).collect()[0]
K = global_sum["K"] or 0
N = global_sum["N"] or 0
# With no population at all, fall back to an uninformative 0.5 rate.
global_rate = (K / N) if N and N > 0 else 0.5

# choose equivalent sample size m as average sample size across regions (at least 5)
avg_n = global_sum["avg_n"] or 10
m = max(5, int(avg_n))  # pseudo-counts

# Beta(alpha0, beta0) prior whose mean is global_rate and whose weight is m.
alpha0 = global_rate * m
beta0 = (1.0 - global_rate) * m

# Posterior mean per region-year: (alpha0 + k) / (alpha0 + beta0 + n).
# m >= 5 keeps the denominator positive even when n is 0.
sa = ind_df.withColumn("alpha0", F.lit(float(alpha0))).withColumn("beta0", F.lit(float(beta0))) \
           .withColumn("shrinkage_literacy_rate", (F.col("alpha0") + F.col("literate_count")) / (F.col("alpha0") + F.col("beta0") + F.col("pop_age_5_plus"))) \
           .withColumn("k_region", F.col("literate_count")) \
           .withColumn("n_region", F.col("pop_age_5_plus")) \
           .select("geoid","census_year","k_region","n_region","shrinkage_literacy_rate","alpha0","beta0")

# attach provenance columns
sa = sa.withColumn("ingestion_batch_id", F.lit(ingestion_batch_id)).withColumn("run_id", F.lit(RUN_ID))

sa.write.format("delta").mode("overwrite").partitionBy(PARTITION_COL).option("overwriteSchema","true").saveAsTable(SMALL_AREA)
print(f"Written {SMALL_AREA} ({sa.count()} rows)")

# ---------- 9) fact_population_flat_region_year (flatten several components into one table) ----------
# Region names come from silver.dim_region when available; otherwise a
# "Region <geoid>" label is synthesised from the geoids in pop_agg.
if table_exists(SILVER_REGION):
    region_df = spark.table(SILVER_REGION).select("geoid","region_name_standard","iso_admin_code")
    region_df = region_df.withColumnRenamed("region_name_standard","region_name")
else:
    region_df = pop_agg.select("geoid").distinct().withColumn(
        "region_name", F.concat(F.lit("Region "), F.col("geoid").cast("string"))
    )

# Household summary averaged per geoid. FACT_HH has no census_year column, so
# these averages can only be joined on geoid and repeat across census years.
# (The table is overwritten each run with a single ingestion_batch_id, so the
# former per-batch pre-aggregation step was redundant.)
hh_by_geoid = spark.table(FACT_HH).groupBy("geoid").agg(
    F.avg("median_household_income").alias("avg_median_hh_income"),
    F.avg("household_size").alias("avg_household_size")
)

# Build a base frame for each geoid,census_year from indicators
base = ind_df.alias("ind").join(region_df.alias("r"), on="geoid", how="left") \
                   .select(F.col("ind.geoid"), F.col("ind.census_year"), F.col("r.region_name"), "literate_count","pop_age_5_plus","employed_count","pop_age_15_plus","informal_employed_count")

# join income stats and small-area estimates on the region-year key
flat = base.alias("b").join(
    spark_income.alias("inc"),
    on=["geoid", "census_year"],
    how="left"
).join(
    sa.alias("s"),
    on=["geoid", "census_year"],
    how="left"
)

# join household aggregates by geoid (best-effort)
flat = flat.join(hh_by_geoid.alias("h"), on="geoid", how="left")

flat = flat.select(
    "b.*",                                # Select all from indicators/base
    "inc.mean_income",
    "inc.median_income",
    "inc.deciles",
    "inc.gini_income",
    "inc.topcoded_count",
    "s.shrinkage_literacy_rate",
    "h.avg_median_hh_income",
    "h.avg_household_size"
)

# Add the extra *_final copies of two counts, plus provenance. The former
# self-assigning withColumn / same-name withColumnRenamed calls were no-ops
# and have been dropped; the resulting column set and order are unchanged.
flat = flat.withColumn("pop_age_5_plus_final", F.col("pop_age_5_plus")) \
           .withColumn("employed_final_count", F.col("employed_count")) \
           .withColumn("ingestion_batch_id", F.lit(ingestion_batch_id)) \
           .withColumn("run_id", F.lit(RUN_ID))

# write flat
flat.write.format("delta").mode("overwrite").partitionBy(PARTITION_COL).option("overwriteSchema","true").saveAsTable(FLAT_FACT)
print(f"Written {FLAT_FACT} ({flat.count()} rows)")

# ---------- 10) EDUCATION: education_distribution_by_region_year ----------
# Build education distribution from silver.dim_person (education_level expected in silver)
if "education_level" not in person_with_agegroup.columns:
    # Absent education_level: write empty placeholder tables and log.
    print("education_level not found in silver.dim_person; creating empty education tables as placeholders.")
    empty_schema = StructType([
        StructField("geoid", IntegerType(), True),
        StructField("census_year", IntegerType(), True),
        StructField("education_level", StringType(), True),
        StructField("sex", StringType(), True),
        StructField("population_count", LongType(), True),
        StructField("population_share", DoubleType(), True),
        StructField("literate_count", LongType(), True),
        StructField("literacy_rate_within_level", DoubleType(), True),
        StructField("ingestion_batch_id", StringType(), True),
        StructField("run_id", StringType(), True)
    ])
    # The crosswalk has its own column set; it was previously created with the
    # distribution schema above, giving the placeholder the wrong columns.
    empty_xwalk_schema = StructType([
        StructField("geoid", IntegerType(), True),
        StructField("census_year", IntegerType(), True),
        StructField("education_level", StringType(), True),
        StructField("pop_age_15_plus_edu", LongType(), True),
        StructField("employed_count_edu", LongType(), True),
        StructField("employment_rate", DoubleType(), True),
        StructField("informal_employment_share", DoubleType(), True),
        StructField("mean_income", DoubleType(), True),
        StructField("median_income", DoubleType(), True),
        StructField("ingestion_batch_id", StringType(), True),
        StructField("run_id", StringType(), True)
    ])
    # overwriteSchema + partitioning match the populated writes below, so the
    # placeholder can replace a table left by an earlier populated run.
    for placeholder_table, placeholder_schema in ((EDU_DIST, empty_schema), (EDU_XWALK, empty_xwalk_schema)):
        spark.createDataFrame([], schema=placeholder_schema).write.format("delta").mode("overwrite") \
            .partitionBy(PARTITION_COL).option("overwriteSchema", "true").saveAsTable(placeholder_table)
else:
    # Compute distribution
    # population_count by geoid,census_year,education_level,sex
    edu_base = person_with_agegroup.filter(F.col("education_level").isNotNull())
    # NOTE(review): literate_count only counts persons aged 5+, while
    # literacy_rate_within_level divides by population_count of all ages —
    # confirm the intended denominator.
    edu_agg = edu_base.groupBy("geoid","census_year","education_level","sex").agg(
        F.count("*").alias("population_count"),
        F.sum(F.when(F.col("age") >= 5, F.col("literacy_flag")).otherwise(0)).alias("literate_count")
    )
    # compute population_share per geoid,census_year
    total_by_region = edu_agg.groupBy("geoid","census_year").agg(F.sum("population_count").alias("region_pop"))
    edu = edu_agg.join(total_by_region, on=["geoid","census_year"], how="left") \
                 .withColumn("population_share", F.col("population_count") / F.col("region_pop")) \
                 .withColumn("literacy_rate_within_level", F.when(F.col("population_count") > 0, F.col("literate_count") / F.col("population_count")).otherwise(F.lit(None))) \
                 .select("geoid","census_year","education_level","sex","population_count","population_share","literate_count","literacy_rate_within_level")

    # attach provenance columns
    edu = edu.withColumn("ingestion_batch_id", F.lit(ingestion_batch_id)).withColumn("run_id", F.lit(RUN_ID))

    # write
    edu.write.format("delta").mode("overwrite").partitionBy(PARTITION_COL).option("overwriteSchema","true").saveAsTable(EDU_DIST)
    print(f"Written {EDU_DIST} ({edu.count()} rows)")

    # ---------- 11) EDUCATION × EMPLOYMENT CROSSWALK ----------
    # For each education_level, geoid,census_year compute employment stats and income stats
    # Denominators: persons age >= 15 for employment rate.
    # Employment / informality use the shared null-safe, case-insensitive
    # flags instead of case-sensitive string literals.
    edu_15_plus = edu_base.filter(F.col("age") >= 15)
    edu_x = edu_15_plus.groupBy("geoid","census_year","education_level").agg(
        F.count("*").alias("pop_age_15_plus_edu"),
        F.sum(F.col("is_employed_flag")).alias("employed_count_edu"),
        F.sum(F.col("is_informal_flag")).alias("informal_count_edu")
    )

    # Mean and exact median income of employed persons per education level,
    # computed in Spark: `percentile` is the exact (interpolating) aggregate,
    # so no collect_list / toPandas round trip through the driver is needed.
    spark_inc_x = (
        edu_15_plus
        .filter((F.col("is_employed_flag") == 1) & F.col("annual_income_local").isNotNull())
        .groupBy("geoid", "census_year", "education_level")
        .agg(
            F.avg("annual_income_local").cast("double").alias("mean_income"),
            F.expr("percentile(annual_income_local, 0.5)").cast("double").alias("median_income"),
        )
    )

    # join the edu_x and income summary; rates are null when the denominator is 0
    edu_x_full = edu_x.join(spark_inc_x, on=["geoid","census_year","education_level"], how="left") \
                     .withColumn("employment_rate", F.when(F.col("pop_age_15_plus_edu") > 0, F.col("employed_count_edu") / F.col("pop_age_15_plus_edu")).otherwise(F.lit(None))) \
                     .withColumn("informal_employment_share", F.when(F.col("employed_count_edu") > 0, F.col("informal_count_edu") / F.col("employed_count_edu")).otherwise(F.lit(None)))

    # attach provenance columns
    edu_x_full = edu_x_full.withColumn("ingestion_batch_id", F.lit(ingestion_batch_id)).withColumn("run_id", F.lit(RUN_ID))

    edu_x_full = edu_x_full.select("geoid","census_year","education_level","pop_age_15_plus_edu","employed_count_edu","employment_rate","informal_employment_share","mean_income","median_income","ingestion_batch_id","run_id")

    edu_x_full.write.format("delta").mode("overwrite").partitionBy(PARTITION_COL).option("overwriteSchema","true").saveAsTable(EDU_XWALK)
    print(f"Written {EDU_XWALK} ({edu_x_full.count()} rows)")

# ---------- 12) ingestion_audit_v1 (append run metadata) ----------
def _written_row_count(table_name):
    """Return the row count of a written Gold table, or 0 if it does not exist."""
    return int(spark.table(table_name).count()) if table_exists(table_name) else 0


# Row counts are read back from the tables that were just written. Calling
# count() on the lazy DataFrames would re-execute each aggregation pipeline,
# and would not reflect what actually landed in the table.
_audited_tables = [
    ("fact_population_by_region_year", FACT_POP),
    ("indicators_literacy_employment", INDICATORS),
    ("fact_household_summary", FACT_HH),
    ("income_distribution_by_region_year", INCOME_DIST),
    ("small_area_shrinkage_estimates", SMALL_AREA),
    ("fact_population_flat_region_year", FLAT_FACT),
    ("education_distribution_by_region_year", EDU_DIST),
    ("education_employment_crosswalk", EDU_XWALK),
]

# Build a small audit row. This point is only reached when every step above
# completed without raising, hence the fixed SUCCEEDED status.
audit_row = {
    "run_id": RUN_ID,
    "ingestion_batch_id": ingestion_batch_id,
    "run_ts": RUN_TS.isoformat(),
    "notebook": "gold_materialize_extended",
    "status": "SUCCEEDED",
    "notes": json.dumps({
        "tables_written": [
            DIM_AGE, METRIC_DEFS, FACT_POP, INDICATORS, FACT_HH, INCOME_DIST, SMALL_AREA, FLAT_FACT, EDU_DIST, EDU_XWALK
        ],
        "row_counts": {label: _written_row_count(table) for label, table in _audited_tables}
    })
}

audit_schema = StructType([
    StructField("run_id", StringType(), False),
    StructField("ingestion_batch_id", StringType(), True),
    StructField("run_ts", StringType(), False),
    StructField("notebook", StringType(), False),
    StructField("status", StringType(), False),
    StructField("notes", StringType(), True)
])
audit_df = spark.createDataFrame([ (audit_row["run_id"], audit_row["ingestion_batch_id"], audit_row["run_ts"], audit_row["notebook"], audit_row["status"], audit_row["notes"]) ], schema=audit_schema)

# Always append: saveAsTable in append mode creates the table when it is
# absent. The former "overwrite when the table does not exist" branch could
# wipe the audit history whenever the existence check failed spuriously.
audit_df.write.format("delta").mode("append").option("mergeSchema", "true").saveAsTable(INGEST_AUDIT)

print(f"Appended ingestion audit to {INGEST_AUDIT}")

# ---------- Done ----------
print("Gold materialization complete.")
print("Run ID:", RUN_ID)
print("Ingestion batch id:", ingestion_batch_id)


  RUN_TS = datetime.utcnow()


Auto-selected ingestion_batch_id = reg-20260103T154251Z
ingestion_batch_id used for this run: reg-20260103T154251Z
Written census.gold.dim_age_group (9 rows)
Written census.gold.metric_definitions (8 rows)
Person rows for Gold processing: 100000
Written census.gold.fact_population_by_region_year (9670 rows)
Written census.gold.indicators_literacy_employment (360 rows)
Written census.gold.fact_household_summary (82642 rows)
Written census.gold.income_distribution_by_region_year (360 rows)
Written census.gold.small_area_shrinkage_estimates (360 rows)
Written census.gold.fact_population_flat_region_year (360 rows)
Written census.gold.education_distribution_by_region_year (5387 rows)
Written census.gold.education_employment_crosswalk (1440 rows)
Appended ingestion audit to census.gold.ingestion_audit_v1
Gold materialization complete.
Run ID: gold_materialize_run_b2e70432
Ingestion batch id: reg-20260103T154251Z
