In [None]:
from datetime import datetime, timezone

from pyspark.sql import functions as F


# Location of every table this notebook touches: <catalog>.<schema>.<table>.
CATALOG = "main"
SCHEMA = "retail_p1"
NAMESPACE = ".".join((CATALOG, SCHEMA))

# Fully qualified names of the tables read by the checks below.
BRONZE_ORDERS_TABLE = NAMESPACE + ".bronze_orders"
SILVER_ORDERS_TABLE = NAMESPACE + ".silver_orders_clean"
SILVER_CUSTOMERS_SCD2_TABLE = NAMESPACE + ".silver_customers_scd2"
SILVER_PRODUCTS_LATEST_TABLE = NAMESPACE + ".silver_products_latest"
GOLD_DAILY_REVENUE_TABLE = NAMESPACE + ".gold_daily_revenue"
# Table that receives one appended row per executed check.
DQ_RESULTS_TABLE = NAMESPACE + ".dq_results"


def get_widget(name: str, default: str) -> str:
    """Return the value of notebook widget ``name``, or ``default``.

    Registers a text widget called ``name`` (seeded with ``default``) through
    ``dbutils.widgets`` and reads it back with surrounding whitespace removed.

    Args:
        name: Widget name to register and read.
        default: Value used to seed the widget and as the fallback result.

    Returns:
        The stripped widget value. ``default`` is returned instead when the
        value is empty after stripping, or when anything in the widget lookup
        raises (for example when ``dbutils`` is not defined).
    """
    try:
        widgets = dbutils.widgets
        widgets.text(name, default)
        raw = widgets.get(name).strip()
    except Exception:
        # Best-effort by design: any failure falls back to the default.
        return default
    return raw if raw else default


# Run configuration, read through get_widget with a fallback for each value.
FRESHNESS_HOURS = int(get_widget(name="freshness_hours", default="168"))
# Only the exact text "true" (case-insensitive) enables fail-on-error.
FAIL_ON_ERROR = "true" == get_widget(name="fail_on_error", default="false").lower()
# The run id defaults to the current UTC time formatted as YYYYmmddHHMMSS.
RUN_ID = get_widget(
    name="run_id",
    default=datetime.now(timezone.utc).strftime("%Y%m%d%H%M%S"),
)
# One timestamp shared by every result row written during this run.
RUN_TS = datetime.now(timezone.utc)

# Make sure the schema and the results table exist before any check runs.
_SETUP_STATEMENTS = (
    f"CREATE SCHEMA IF NOT EXISTS {NAMESPACE}",
    f"""
    CREATE TABLE IF NOT EXISTS {DQ_RESULTS_TABLE} (
      check_name STRING,
      check_type STRING,
      status STRING,
      failed_count DOUBLE,
      threshold DOUBLE,
      run_ts TIMESTAMP,
      details STRING
    ) USING DELTA
    """,
)
for _statement in _SETUP_STATEMENTS:
    spark.sql(_statement)


def run_check(check_name: str, check_type: str, failed_count_sql: str, threshold: float = 0.0, details: str = ""):
    """Execute one data quality query, persist the outcome, and print it.

    Args:
        check_name: Identifier stored with the result and shown in the log line.
        check_type: Category label stored with the result.
        failed_count_sql: SQL text whose first row, first column is the
            failure measure; a NULL value is treated as 0.
        threshold: Largest failure measure that still counts as a pass.
        details: Free text stored after the ``run_id=...; `` prefix.

    Returns:
        A ``(status, failed_count)`` tuple where ``status`` is ``"PASS"`` or
        ``"FAIL"`` and ``failed_count`` is a float.
    """
    first_row = spark.sql(failed_count_sql).first()
    measured = first_row[0]
    if measured is None:
        failed_count = 0.0
    else:
        failed_count = float(measured)

    if failed_count <= threshold:
        status = "PASS"
    else:
        status = "FAIL"

    # Single-row frame laid out like the results table, appended to it.
    record = (
        check_name,
        check_type,
        status,
        failed_count,
        float(threshold),
        RUN_TS,
        f"run_id={RUN_ID}; {details}",
    )
    record_schema = "check_name string, check_type string, status string, failed_count double, threshold double, run_ts timestamp, details string"
    record_df = spark.createDataFrame([record], record_schema)
    appender = record_df.write.format("delta").mode("append")
    appender.saveAsTable(DQ_RESULTS_TABLE)

    print(f"[{status}] {check_name}: failed_count={failed_count}, threshold={threshold}")
    return status, failed_count


# Catalogue of data quality checks. Each entry is a 5-tuple that is unpacked
# positionally into run_check:
#   (check_name, check_type, failed_count_sql, threshold, details)
# Every SQL statement returns a single value; a check passes when that value
# is <= threshold. Table names (and FRESHNESS_HOURS) are interpolated into the
# SQL text when this list is built.
check_definitions = [
    # Counts (order_id, product_id) pairs that appear on more than one row.
    (
        "silver_orders_unique_key",
        "uniqueness",
        f"""
        SELECT COUNT(*) AS failed_count
        FROM (
          SELECT order_id, product_id
          FROM {SILVER_ORDERS_TABLE}
          GROUP BY order_id, product_id
          HAVING COUNT(*) > 1
        )
        """,
        0.0,
        "Expect one row per (order_id, product_id) in silver_orders_clean.",
    ),
    # Counts order rows whose customer_id is NULL.
    (
        "silver_orders_null_customer_id",
        "null_check",
        f"SELECT COUNT(*) AS failed_count FROM {SILVER_ORDERS_TABLE} WHERE customer_id IS NULL",
        0.0,
        "Customer id is required in silver_orders_clean.",
    ),
    # Counts order rows whose quantity is zero, negative, or NULL.
    (
        "silver_orders_quantity_positive",
        "domain_check",
        f"SELECT COUNT(*) AS failed_count FROM {SILVER_ORDERS_TABLE} WHERE quantity <= 0 OR quantity IS NULL",
        0.0,
        "Quantity must be positive.",
    ),
    # Counts order rows whose price is negative or NULL (zero is accepted).
    (
        "silver_orders_price_non_negative",
        "domain_check",
        f"SELECT COUNT(*) AS failed_count FROM {SILVER_ORDERS_TABLE} WHERE price < 0 OR price IS NULL",
        0.0,
        "Price must be non-negative.",
    ),
    # Counts order rows with no matching customer row flagged is_current.
    # Orders with a NULL customer_id never satisfy the join condition, so they
    # are counted here as well as by the null check above.
    (
        "silver_orders_orphan_customers",
        "referential_integrity",
        f"""
        SELECT COUNT(*) AS failed_count
        FROM {SILVER_ORDERS_TABLE} o
        LEFT JOIN {SILVER_CUSTOMERS_SCD2_TABLE} c
          ON o.customer_id = c.customer_id
          AND c.is_current = true
        WHERE c.customer_id IS NULL
        """,
        0.0,
        "All orders must map to a current customer row.",
    ),
    # Counts order rows whose product_id has no match in the products table.
    (
        "silver_orders_orphan_products",
        "referential_integrity",
        f"""
        SELECT COUNT(*) AS failed_count
        FROM {SILVER_ORDERS_TABLE} o
        LEFT JOIN {SILVER_PRODUCTS_LATEST_TABLE} p
          ON o.product_id = p.product_id
        WHERE p.product_id IS NULL
        """,
        0.0,
        "All orders must map to product dimension rows.",
    ),
    # Counts customer_ids whose number of is_current rows is anything other
    # than exactly one (i.e. zero or several).
    (
        "scd2_single_current_row",
        "scd2_integrity",
        f"""
        SELECT COUNT(*) AS failed_count
        FROM (
          SELECT customer_id
          FROM {SILVER_CUSTOMERS_SCD2_TABLE}
          GROUP BY customer_id
          HAVING SUM(CASE WHEN is_current THEN 1 ELSE 0 END) <> 1
        )
        """,
        0.0,
        "Each customer must have exactly one current SCD2 row.",
    ),
    # Counts distinct customers having at least one pair of rows whose
    # [valid_from, valid_to) windows intersect. The (valid_from, customer_sk)
    # ordering condition orients each pair one way only and excludes a row
    # pairing with itself; a NULL valid_to is treated as 9999-12-31.
    (
        "scd2_no_overlapping_windows",
        "scd2_integrity",
        f"""
        SELECT COUNT(DISTINCT customer_id) AS failed_count
        FROM (
          SELECT a.customer_id
          FROM {SILVER_CUSTOMERS_SCD2_TABLE} a
          JOIN {SILVER_CUSTOMERS_SCD2_TABLE} b
           ON a.customer_id = b.customer_id
           AND (
             a.valid_from < b.valid_from
             OR (a.valid_from = b.valid_from AND a.customer_sk < b.customer_sk)
           )
           AND coalesce(a.valid_to, CAST('9999-12-31 00:00:00' AS TIMESTAMP)) > b.valid_from
           AND coalesce(b.valid_to, CAST('9999-12-31 00:00:00' AS TIMESTAMP)) > a.valid_from
        )
        """,
        0.0,
        "SCD2 validity windows must not overlap for the same customer.",
    ),
    # Counts gold rows where any of the four metrics is below zero. NULL
    # metrics do not satisfy "< 0" and are therefore not flagged.
    (
        "gold_daily_revenue_non_negative_metrics",
        "gold_sanity",
        f"""
        SELECT COUNT(*) AS failed_count
        FROM {GOLD_DAILY_REVENUE_TABLE}
        WHERE gross_revenue < 0 OR net_revenue < 0 OR aov < 0 OR order_count < 0
        """,
        0.0,
        "Gold revenue metrics should not be negative.",
    ),
    # Yields 1 when the table has no non-NULL _ingest_ts at all, or when the
    # newest _ingest_ts is more than FRESHNESS_HOURS hours old; otherwise 0.
    (
        "bronze_orders_freshness",
        "freshness",
        f"""
        SELECT
          CASE
            WHEN MAX(_ingest_ts) IS NULL THEN 1
            WHEN ((unix_timestamp(current_timestamp()) - unix_timestamp(MAX(_ingest_ts))) / 3600.0) > {FRESHNESS_HOURS}
              THEN 1
            ELSE 0
          END AS failed_count
        FROM {BRONZE_ORDERS_TABLE}
        """,
        0.0,
        f"Bronze orders should be newer than {FRESHNESS_HOURS} hours.",
    ),
]

# Execute every check in declaration order. Each run_check call appends one
# row to the results table and returns (status, failed_count).
results = [run_check(*check_def) for check_def in check_definitions]
failures = [row for row in results if row[0] == "FAIL"]
# Pair each failing result with its check name so the error raised below
# says which checks failed, not just that something did.
failed_checks = [
    (check_def[0], failed_count)
    for check_def, (status, failed_count) in zip(check_definitions, results)
    if status == "FAIL"
]

# Summarize only the rows written for this run id. The prefix is matched
# literally with Column.startswith instead of an interpolated SQL LIKE:
# with LIKE, "_" and "%" inside RUN_ID act as wildcards (matching other runs)
# and a single quote in RUN_ID breaks the statement.
run_prefix = f"run_id={RUN_ID};"
summary_df = (
    spark.table(DQ_RESULTS_TABLE)
    .where(F.col("details").startswith(run_prefix))
    .groupBy("run_ts")
    .agg(
        F.sum(F.when(F.col("status") == "PASS", 1).otherwise(0)).alias("pass_count"),
        F.sum(F.when(F.col("status") == "FAIL", 1).otherwise(0)).alias("fail_count"),
    )
    .orderBy(F.col("run_ts").desc())
)
summary_df.show(truncate=False)

# Fail the notebook only when explicitly requested and at least one check failed.
if FAIL_ON_ERROR and failures:
    raise RuntimeError(f"Data quality checks failed for run_id={RUN_ID}: {failed_checks}")
