# Load DQX Checks

## Reference

### Runbook — `01_load_dqx_checks` (DQX Rule Loader)

**Purpose**  
Load and validate DQX rules defined in YAML, canonicalize each rule into a stable `check_id` (SHA-256 of a canonical JSON payload), overwrite the target **rules catalog** table, and then apply table/column comments and a primary-key declaration. Designed for Databricks + Unity Catalog.

---

## Quick Start

**Standard run**
```python
main()
```

**Dry run (no write)**
```python
main(dry_run=True)
```

**Validate only (no write)**
```python
main(validate_only=True)
```

**Common overrides**
```python
main(
  output_config_path="resources/dqx_config.yaml",
  rules_dir=None,                          # None -> use dqx_yaml_checks from config
  batch_dedupe_mode="warn",                # warn | error | skip
  time_zone="America/Chicago"
)
```

---

## Inputs & Outputs

**Config (`resources/dqx_config.yaml`)**
- `dqx_yaml_checks`: folder with YAML files (recursive; supports `dbfs:/` via `/dbfs/...` bridge)
- `dqx_checks_config_table_name`: fully qualified UC table to overwrite, e.g. `dq_dev.dqx.checks`
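Discovery can be sketched as a standalone function. The names `normalize_base` and `discover_yaml_files` are simplified stand-ins for the notebook's `_normalize_base` and `discover_yaml_files_recursive`. The logic follows that code: map `dbfs:/` onto the `/dbfs` mount, walk recursively, skip hidden entries, and match `.yaml`/`.yml` case-insensitively. Treat it as an illustration rather than the loader itself.

```python
import os
from typing import List


def normalize_base(path: str) -> str:
    """Map a dbfs:/ URI onto the /dbfs FUSE mount; leave other paths unchanged."""
    if path.startswith("dbfs:/"):
        return "/dbfs/" + path[len("dbfs:/"):].lstrip("/")
    return path


def discover_yaml_files(base_dir: str) -> List[str]:
    """Return every *.yaml / *.yml file under base_dir, skipping hidden files and folders."""
    base = normalize_base(base_dir)
    if not os.path.isdir(base):
        raise FileNotFoundError(f"Rules folder not found: {base_dir} (resolved: {base})")
    found: List[str] = []
    for root, dirs, files in os.walk(base):
        # Prune hidden directories in place so os.walk does not descend into them.
        dirs[:] = [d for d in dirs if not d.startswith(".")]
        for name in files:
            if not name.startswith(".") and name.lower().endswith((".yaml", ".yml")):
                found.append(os.path.join(root, name))
    # Sorting keeps the file order deterministic from run to run.
    return sorted(found)
```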

**YAML rule shape (per rule)**
- Required: `table_name` (catalog.schema.table), `name`, `criticality` (`warn|warning|error`), `run_config_name`, `check.function`
- Optional: `check.for_each_column` (list of strings), `check.arguments` (map; values stringified), `filter`, `user_metadata`, `active` (default `true`)
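For orientation, here is one rule as the Python dict that YAML parsing would produce, plus a minimal check that mirrors the rule-level validation described in this runbook. The table name, the `is_not_null` function, and its `column` argument are illustrative assumptions. The real loader also runs `DQEngine.validate_checks`, which this sketch does not reproduce.

```python
from typing import Any, Dict, List

# Hypothetical rule, shaped like one YAML list item after yaml.safe_load.
# Table, function name and arguments are illustrative assumptions.
EXAMPLE_RULE: Dict[str, Any] = {
    "table_name": "dq_dev.sales.orders",
    "name": "order_id_not_null",
    "criticality": "error",
    "run_config_name": "daily",
    "check": {
        "function": "is_not_null",
        "arguments": {"column": "order_id"},
    },
    "filter": "order_status <> 'CANCELLED'",
    "active": True,
}

REQUIRED_FIELDS = ["table_name", "name", "criticality", "run_config_name", "check"]
ALLOWED_CRITICALITY = {"warn", "warning", "error"}


def rule_problems(rule: Dict[str, Any]) -> List[str]:
    """Return human-readable problems; an empty list means the rule passes these checks."""
    probs = [f"missing '{f}'" for f in REQUIRED_FIELDS if not rule.get(f)]
    if str(rule.get("table_name") or "").count(".") != 2:
        probs.append("table_name must be catalog.schema.table")
    if rule.get("criticality") not in ALLOWED_CRITICALITY:
        probs.append(f"invalid criticality {rule.get('criticality')!r}")
    chk = rule.get("check")
    if not isinstance(chk, dict) or not chk.get("function"):
        probs.append("missing check.function")
    return probs
```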

**Output table (overwrite)**
- Schema includes: `check_id` (PRIMARY KEY by design), `check_id_payload`, `table_name`, `name`, `criticality`, `check{function,for_each_column,arguments}`, `filter`, `run_config_name`, `user_metadata`, `yaml_path`, `active`, `created_by/at`, `updated_by/at`.
- **Metadata**: table & column **comments** are applied.
- **Constraint**: attempts `ALTER TABLE ... ADD CONSTRAINT pk_check_id PRIMARY KEY (check_id) RELY` (Unity Catalog).
- **Guardrail**: hard **runtime assertion** ensures `check_id` uniqueness (fails the run if duplicates exist).

---

## What the Notebook Does (High Level)

1) **Environment banner** — prints cluster/warehouse info with local timezone.  
2) **Read config** — resolves `rules_dir` and `delta_table_name`.  
3) **Discover YAMLs** — recursively finds `*.yaml|*.yml` under `rules_dir` (including nested folders).  
4) **Validate (optional short-circuit)** — per file + per rule + **DQEngine.validate_checks**.  
5) **Load & canonicalize** — normalizes `filter`, sorts `for_each_column`, stringifies `arguments`, builds canonical JSON payload, computes `check_id=sha256(payload)`.  
6) **Batch de-dupe by `check_id`** — keep the lexicographically first (by `yaml_path`, `name`); mode: `warn|error|skip`.  
7) **Diagnostics** — displays totals, sample rules, and counts by table.  
8) **Overwrite target table** — writes Delta with `overwriteSchema=true` and timestamp casting.  
9) **Apply documentation** — table comment (fallback to TBLPROPERTIES) + column comments (with ALTER COLUMN; fallback to COMMENT ON COLUMN).  
10) **PK & enforcement** — add/refresh **informational PK** on `check_id` (if UC) and **assert uniqueness** at runtime.  
11) **Result summary** — small confirmation table and a printed success line.
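Step 5 can be reproduced in plain Python to see which edits change a rule's identity. This sketch restates the notebook's `_canon_filter` / `_canon_check` / `compute_check_id_payload` logic; `canon_arg` and `check_id_for` are hypothetical helper names introduced here.

```python
import hashlib
import json
from typing import Any, Dict, Optional, Tuple


def canon_filter(s: Optional[str]) -> str:
    # Collapse whitespace runs to single spaces; None or "" becomes "".
    return "" if not s else " ".join(str(s).split())


def canon_arg(v: Any) -> str:
    # Stringify and strip; re-serialize JSON-looking values compactly with sorted keys.
    sv = "" if v is None else str(v).strip()
    if (sv.startswith("{") and sv.endswith("}")) or (sv.startswith("[") and sv.endswith("]")):
        try:
            sv = json.dumps(json.loads(sv), sort_keys=True, separators=(",", ":"))
        except ValueError:
            pass  # not valid JSON (e.g. a Python repr); keep the stripped string
    return sv


def canon_check(chk: Dict[str, Any]) -> Dict[str, Any]:
    fec = chk.get("for_each_column")
    columns = sorted(str(c) for c in fec) if isinstance(fec, list) else []
    args = chk.get("arguments") or {}
    return {
        "function": chk.get("function"),
        "for_each_column": columns or None,  # order-insensitive; empty list -> None
        "arguments": {str(k): canon_arg(v) for k, v in args.items()},
    }


def check_id_for(table_name: str, check: Dict[str, Any],
                 filter_str: Optional[str] = None) -> Tuple[str, str]:
    # sort_keys + tight separators make the JSON text (and therefore the hash) stable.
    payload = json.dumps(
        {
            "table_name": (table_name or "").lower(),
            "filter": canon_filter(filter_str),
            "check": canon_check(check or {}),
        },
        sort_keys=True,
        separators=(",", ":"),
    )
    return payload, hashlib.sha256(payload.encode("utf-8")).hexdigest()
```

Because `name`, `criticality`, and `run_config_name` are not inputs to the payload, two rules that differ only in those fields get the same `check_id`, which is exactly what the batch de-dupe step collapses.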

---

## Parameters (function `main`)

- `output_config_path`: YAML path for loader config (default `resources/dqx_config.yaml`).  
- `rules_dir`: override rules folder (else from config).  
- `time_zone`: used for audit timestamps.  
- `dry_run`: if `True`, loads + displays but **does not write**.  
- `validate_only`: if `True`, runs validations and **exits**.  
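- `required_fields`: override the minimal required YAML fields (default: `table_name`, `name`, `criticality`, `run_config_name`, `check`).  
- `batch_dedupe_mode`: `warn` (print duplicates, keep the first) | `error` (fail) | `skip` (no de-dupe; any duplicates then fail the post-write uniqueness assertion, after the table has already been overwritten).  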
- `table_doc`: optional dict to override default table/column comments.
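The three `batch_dedupe_mode` values can be illustrated with a small, hypothetical `dedupe_by_check_id`. It follows the semantics documented above (keep the first rule by `yaml_path`, then `name`; `skip` keeps every row) and is a simplified sketch, not the notebook's `dedupe_rules_in_batch_by_check_id`.

```python
from typing import Dict, List


def dedupe_by_check_id(rules: List[dict], mode: str = "warn") -> List[dict]:
    """Keep one rule per check_id: the first by (yaml_path, name)."""
    if mode not in {"warn", "error", "skip"}:
        raise ValueError(f"unknown batch_dedupe_mode: {mode!r}")
    if mode == "skip":
        return list(rules)  # no de-dupe at all
    groups: Dict[str, List[dict]] = {}
    for r in rules:
        groups.setdefault(r["check_id"], []).append(r)
    kept: List[dict] = []
    dropped: List[dict] = []
    for lst in groups.values():
        lst = sorted(lst, key=lambda r: (r.get("yaml_path", ""), r.get("name", "")))
        kept.append(lst[0])
        dropped.extend(lst[1:])
    if dropped:
        msg = f"{len(dropped)} duplicate rule(s) dropped: " + ", ".join(r["name"] for r in dropped)
        if mode == "error":
            raise ValueError(msg)
        print(msg)  # mode == "warn"
    return kept
```

Under these semantics, a `skip` run that contains duplicates would write the table and then fail the runtime uniqueness assertion.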

---

## Prereqs & Permissions

- Python deps: `pyspark`, `pyyaml`, `databricks-labs-dqx (0.8.x)` available on the cluster/warehouse.  
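- Data access: read access to `rules_dir` (workspace file system or `dbfs:/`); in Unity Catalog, `USE CATALOG` and `USE SCHEMA` on the target, `CREATE SCHEMA` on the catalog if the schema may need to be created, `CREATE TABLE` on the schema, and ownership of the target table (or equivalent rights) to replace it and set comments and constraints.  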
- Unity Catalog strongly recommended (for PK declaration). The run still proceeds if constraint DDL is not supported; a message is logged.

---

## Failure Modes & Fixes

- **Validation failed** (missing required fields, invalid `criticality`, `table_name` not fully qualified, DQEngine errors): fix YAML and re-run.  
- **Duplicate `check_id`** (raised by de-dupe in `error` mode, or by the final uniqueness assertion): make the canonical payloads differ by changing `table_name`, `filter`, or `check.*`, or delete the redundant rule. Renaming does not help, because `name` is not part of the identity.  
- **Insufficient privileges** during write/comments/constraint: grant the privileges listed under *Prereqs & Permissions*, or run as a service principal that has them.  
- **Column comments not applied**: only columns present in the table are touched; fallback DDL is attempted and any skips are logged.

---

## Post‑Run Sanity Checks (SQL)

```sql
-- Table present & counts
SELECT COUNT(*) AS rules, COUNT(DISTINCT check_id) AS unique_rules FROM dq_dev.dqx.checks;

-- Any dup check_id? (should be 0 rows)
SELECT check_id, COUNT(*) c FROM dq_dev.dqx.checks GROUP BY check_id HAVING COUNT(*)>1;

-- Peek arguments & coverage
SELECT name, check.function, check.arguments FROM dq_dev.dqx.checks LIMIT 20;
```

---

## Notes

- Overwrite semantics: the YAML files are the source of truth and the table is rebuilt from them on every run, so manual edits in the table are lost on the next run.  
- `check_id` identity covers only the lowercased `table_name`, the whitespace-normalized `filter`, and the canonical `check.*`; it does **not** include `name`, `criticality`, `run_config_name`, `user_metadata`, or `active`.  
- Argument values are persisted as strings for stability; cast/parse (e.g., `try_cast`, `from_json`) downstream.
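The string encoding applied to `check.arguments` and `user_metadata` can be sketched as follows, together with a hypothetical `parse_arg` decoder for Python consumers. The encoding follows the notebook's `_stringify_map_values` (keys are also coerced with `str()` here); `parse_arg` is an assumption, not part of the loader.

```python
import json
from typing import Any, Dict


def stringify_map_values(d: Dict[Any, Any]) -> Dict[str, str]:
    """Encode values the way the loader persists them in map<string,string> columns."""
    out: Dict[str, str] = {}
    for k, v in (d or {}).items():
        if isinstance(v, (list, dict)):
            out[str(k)] = json.dumps(v)  # JSON text
        elif isinstance(v, bool):  # checked before the generic branch: bool is an int subclass
            out[str(k)] = "true" if v else "false"
        elif v is None:
            out[str(k)] = "null"
        else:
            out[str(k)] = str(v)
    return out


def parse_arg(s: str) -> Any:
    """Hypothetical downstream decoder: try JSON first, otherwise keep the raw string."""
    try:
        return json.loads(s)
    except ValueError:
        return s
```

Decoding by "try JSON, else keep the string" is lossy for strings that look like JSON (the string `"10"` decodes to the integer 10), so consumers that care about exact types should decode per argument name.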

In [0]:
"""
START: main (01_load_dqx_checks)
|
|-- 0. Environment banner
|     |-- print_notebook_env(spark, local_timezone)
|
|-- 1. Load output config
|     |-- output_config = yaml(resources/dqx_config.yaml)
|     |-- rules_dir = output_config["dqx_yaml_checks"]
|     |-- delta_table_name = output_config["dqx_checks_config_table_name"]
|     |-- required_fields = ["table_name","name","criticality","run_config_name","check"]
|
|-- 2. Recursively discover YAMLs
|     |-- files = discover_yaml_files_recursive(rules_dir)  # supports dbfs:/ via /dbfs bridge
|     |-- DISPLAY: list of YAML paths
|
|-- 3. Validation-only short-circuit (if validate_only=True)
|     |-- For each file:
|     |     |-- load_yaml_rules(file)  # supports multi-doc YAML, flattens lists
|     |     |-- File-level checks: non-empty; no duplicate rule names
|     |     |-- For each rule:
|     |     |     |-- Rule-level checks:
|     |     |     |     |-- table_name fully qualified (catalog.schema.table)
|     |     |     |     |-- criticality ∈ {warn, warning, error}
|     |     |     |     |-- check.function exists
|     |     |-- DQEngine.validate_checks(docs) must pass (per file)
|     |-- PRINT summary; END
|
|-- 4. Load + flatten rules (normal path)
|     |-- all_rules = []
|     |-- For each file in files:
|           |-- docs = load_yaml_rules(file)
|           |-- File-level checks (as above)
|           |-- now = current_time_iso(time_zone)
|           |-- For each rule in docs:
|                 |-- Rule-level checks (as above)
|                 |-- Canonicalize:
|                 |     |-- filter := normalized whitespace (or "")
|                 |     |-- check := {
|                 |            function,
|                 |            for_each_column := sorted list or None (validated as list[str]),
|                 |            arguments := map with all values stringified
|                 |        }
|                 |     |-- user_metadata := map with stringified values (or None)
|                 |-- payload := JSON over {table_name_lower, filter, canonical check} (sorted keys, tight separators)
|                 |-- check_id := sha256(payload)
|                 |-- Append flattened dict:
|                 |     |-- keys: [check_id, check_id_payload, table_name, name, criticality,
|                 |                check{function, for_each_column, arguments}, filter, run_config_name,
|                 |                user_metadata, yaml_path=file, active (default True),
|                 |                created_by="AdminUser", created_at=now, updated_by=None, updated_at=None]
|           |-- DQEngine.validate_checks(docs) must pass (file-level)
|           |-- PRINT "[loader] {file}: rules={n}"
|
|-- 5. Batch de-duplication (by check_id only)
|     |-- Group all_rules by check_id
|     |-- Keep the lexicographically first (yaml_path, name); drop others
|     |-- Mode:
|           |-- "warn" (default): print detailed duplicate blocks
|           |-- "error": raise
|           |-- "skip": keep all (no drops)
|
|-- 6. Assemble DataFrame + display-first diagnostics
|     |-- df := spark.createDataFrame(all_rules, schema=TABLE_SCHEMA)
|     |-- DISPLAY: totals (count, distinct check_id, distinct (check_id, run_config_name))
|     |-- DISPLAY: sample(check_id, name, run_config_name, yaml_path) ordered by yaml_path desc
|     |-- DISPLAY: rules per table_name
|     |-- (If very small) DISPLAY: first 3 payloads
|
|-- 7. Dry-run short-circuit (if dry_run=True)
|     |-- DISPLAY: full rules preview ordered by (table_name, name)
|     |-- END
|
|-- 8. Overwrite Delta target table (ALWAYS overwrite)
|     |-- existed_before := spark.catalog.tableExists(delta_table_name)
|     |-- ensure_schema_exists(catalog.schema)
|     |-- Cast timestamps: created_at/updated_at := to_timestamp
|     |-- df.write.format("delta").mode("overwrite").option("overwriteSchema","true").saveAsTable(delta_table_name)
|
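|-- 9. Apply documentation metadata (table + columns, on every run)
|     |-- doc := materialize (table_doc or DQX_CHECKS_CONFIG_METADATA) with {TABLE_FQN}=delta_table_name
|     |-- Apply table comment (COMMENT ON TABLE; fallback to TBLPROPERTIES)
|     |-- For each documented column present in the table:
|           |-- ALTER TABLE ... ALTER COLUMN ... COMMENT (fallback: COMMENT ON COLUMN; skips are logged)
|
|-- 10. PK + runtime uniqueness
|     |-- ALTER TABLE ... ADD CONSTRAINT pk_check_id PRIMARY KEY (check_id) RELY (drop + re-add if present)
|     |-- assert_unique_check_id: fail the run on any duplicate check_id
|
|-- 11. Write result summary
|     |-- DISPLAY: preview of table comment + column comments
|     |-- DISPLAY: single-row summary (rules written, target table, mode, constraint)
|     |-- PRINT: confirmation line with target table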
|
END: main
"""

## Implementation

In [0]:
%pip install databricks-labs-dqx==0.8.0

In [0]:
dbutils.library.restartPython()

In [0]:
# Databricks notebook: 01_load_dqx_checks  (DISPLAY-FIRST + METADATA, COMMENTS-ALWAYS, PK + RUNTIME UNIQUENESS)

import os
import json
import hashlib
import yaml
from typing import Dict, Any, Optional, List, Tuple

from pyspark.sql import SparkSession, DataFrame, types as T
from pyspark.sql.functions import to_timestamp, col, desc

from databricks.labs.dqx.engine import DQEngine

from utils.print import print_notebook_env
from utils.timezone import current_time_iso
from utils.color import Color

# ======================================================
# Small helpers for Databricks-friendly tabular display
# ======================================================

def _can_display() -> bool:
    return "display" in globals()

def show_df(df: DataFrame, n: int = 100, truncate: bool = False) -> None:
    if _can_display():
        display(df.limit(n))
    else:
        df.show(n, truncate=truncate)

def display_section(title: str) -> None:
    print("\n" + f"{Color.b}{Color.deep_magenta}═{Color.r}" * 80)
    print(f"{Color.b}{Color.deep_magenta}║{Color.r} {Color.b}{Color.ghost_white}{title}{Color.r}")
    print(f"{Color.b}{Color.deep_magenta}═{Color.r}" * 80)

# =========================
# Target schema (Delta sink)
# =========================
TABLE_SCHEMA = T.StructType([
    T.StructField("check_id",            T.StringType(), False),  # PRIMARY KEY (by convention + runtime assertion)
    T.StructField("check_id_payload",    T.StringType(), False),
    T.StructField("table_name",          T.StringType(), False),

    # DQX fields
    T.StructField("name",                T.StringType(), False),
    T.StructField("criticality",         T.StringType(), False),
    T.StructField("check", T.StructType([
        T.StructField("function",        T.StringType(), False),
        T.StructField("for_each_column", T.ArrayType(T.StringType()), True),
        T.StructField("arguments",       T.MapType(T.StringType(), T.StringType()), True),
    ]), False),
    T.StructField("filter",              T.StringType(), True),
    T.StructField("run_config_name",     T.StringType(), False),
    T.StructField("user_metadata",       T.MapType(T.StringType(), T.StringType()), True),

    # Ops fields
    T.StructField("yaml_path",           T.StringType(), False),
    T.StructField("active",              T.BooleanType(), False),
    T.StructField("created_by",          T.StringType(), False),
    T.StructField("created_at",          T.StringType(), False),  # ISO string; cast on write
    T.StructField("updated_by",          T.StringType(), True),
    T.StructField("updated_at",          T.StringType(), True),
])

# ======================================================
# Documentation dictionary
# ======================================================
DQX_CHECKS_CONFIG_METADATA: Dict[str, Any] = {
    "table": "<override at create time>",
    "table_comment": (
        "## DQX Checks Configuration\n"
        "- One row per **unique canonical rule** generated from YAML (source of truth).\n"
        "- **Primary key**: `check_id` (sha256 of canonical payload). Uniqueness is enforced by the loader and a runtime assertion.\n"
        "- Rebuilt by the loader on every run (**overwrite** semantics); manual edits will be lost.\n"
        "- Used by runners to resolve rules per `run_config_name` and by logs to map back to rule identity.\n"
        "- `check_id_payload` preserves the exact canonical JSON used to compute `check_id` for reproducibility.\n"
        "- `run_config_name` is a **routing tag**, not part of identity.\n"
        "- Only rows with `active=true` are executed."
    ),
    "columns": {
        "check_id": "PRIMARY KEY. Stable sha256 over canonical {table_name↓, filter, check.*}.",
        "check_id_payload": "Canonical JSON used to derive `check_id` (sorted keys, normalized values).",
        "table_name": "Target table FQN (`catalog.schema.table`). Lowercased in payload for stability.",
        "name": "Human-readable rule name. Used in UI/diagnostics and name-based joins when enriching logs.",
        "criticality": "Rule severity: `warn|warning|error`. Reporting normalizes warn/warning → `warning`.",
        "check": "Structured rule: `{function, for_each_column?, arguments?}`; argument values stringified.",
        "filter": "Optional SQL predicate applied before evaluation (row-level). Normalized in payload.",
        "run_config_name": "Execution group/tag. Drives which runs pick up this rule; **not** part of identity.",
        "user_metadata": "Free-form `map<string,string>` carried through to issues for traceability.",
        "yaml_path": "Absolute/volume path to the defining YAML doc (lineage).",
        "active": "If `false`, rule is ignored by runners.",
        "created_by": "Audit: creator/principal that materialized the row.",
        "created_at": "Audit: creation timestamp (cast to TIMESTAMP on write).",
        "updated_by": "Audit: last updater (nullable).",
        "updated_at": "Audit: last update timestamp (nullable; cast to TIMESTAMP on write).",
    },
}

def _materialize_table_doc(doc_template: Dict[str, Any], table_fqn: str) -> Dict[str, Any]:
    copy = json.loads(json.dumps(doc_template))
    copy["table"] = table_fqn
    if "table_comment" in copy and isinstance(copy["table_comment"], str):
        copy["table_comment"] = copy["table_comment"].replace("{TABLE_FQN}", table_fqn)
    return copy

# =========================
# YAML loading (robust)
# =========================
def load_yaml_rules(path: str) -> List[dict]:
    with open(path, "r") as fh:
        docs = list(yaml.safe_load_all(fh))
    out: List[dict] = []
    for d in docs:
        if d is None:
            continue
        if isinstance(d, dict):
            out.append(d)
        elif isinstance(d, list):
            out.extend([x for x in d if isinstance(x, dict)])
    return out

# =========================
# Canonicalization & IDs
# =========================
def _canon_filter(s: Optional[str]) -> str:
    return "" if not s else " ".join(str(s).split())

def _canon_check(chk: Dict[str, Any]) -> Dict[str, Any]:
    out = {"function": chk.get("function"), "for_each_column": None, "arguments": {}}
    fec = chk.get("for_each_column")
    if isinstance(fec, list):
        out["for_each_column"] = sorted([str(x) for x in fec]) or None
    args = chk.get("arguments") or {}
    canon_args: Dict[str, str] = {}
    for k, v in args.items():
        sv = "" if v is None else str(v).strip()
        if (sv.startswith("{") and sv.endswith("}")) or (sv.startswith("[") and sv.endswith("]")):
            try:
                sv = json.dumps(json.loads(sv), sort_keys=True, separators=(",", ":"))
            except Exception:
                pass
        canon_args[str(k)] = sv
    out["arguments"] = {k: canon_args[k] for k in sorted(canon_args)}
    return out

def compute_check_id_payload(table_name: str, check_dict: Dict[str, Any], filter_str: Optional[str]) -> str:
    payload_obj = {
        "table_name": (table_name or "").lower(),
        "filter": _canon_filter(filter_str),
        "check": _canon_check(check_dict or {}),
    }
    return json.dumps(payload_obj, sort_keys=True, separators=(",", ":"))

def compute_check_id_from_payload(payload: str) -> str:
    return hashlib.sha256(payload.encode()).hexdigest()

# =========================
# Conversions / validation
# =========================
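def _stringify_map_values(d: Dict[str, Any]) -> Dict[str, str]:
    # map<string,string> columns need string keys and values: lists/dicts become JSON,
    # booleans become "true"/"false", None becomes "null", everything else uses str().
    out: Dict[str, str] = {}
    for k, v in (d or {}).items():
        key = str(k)  # YAML may yield non-string keys (e.g. ints)
        if isinstance(v, (list, dict)):
            out[key] = json.dumps(v)
        elif isinstance(v, bool):
            out[key] = "true" if v else "false"
        elif v is None:
            out[key] = "null"
        else:
            out[key] = str(v)
    return out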

def validate_rules_file(rules: List[dict], file_path: str):
    if not rules:
        raise ValueError(f"No rules found in {file_path} (empty or invalid YAML).")
    probs, seen = [], set()
    for r in rules:
        nm = r.get("name")
        if not nm:
            probs.append(f"Missing rule name in {file_path}")
            continue  # don't report unnamed rules as duplicates of each other
        if nm in seen:
            probs.append(f"Duplicate rule name '{nm}' in {file_path}")
        seen.add(nm)
    if probs:
        raise ValueError(f"File-level validation failed in {file_path}: {probs}")

def validate_rule_fields(
    rule: dict,
    file_path: str,
    required_fields: List[str],
    allowed_criticality={"error", "warn", "warning"},
):
    probs = []
    for f in required_fields:
        if not rule.get(f):
            probs.append(f"Missing required field '{f}' in rule '{rule.get('name')}' ({file_path})")
    # `or ""` guards against an explicit `table_name: null` in YAML.
    if str(rule.get("table_name") or "").count(".") != 2:
        probs.append(
            f"table_name '{rule.get('table_name')}' not fully qualified in rule '{rule.get('name')}' ({file_path})"
        )
    if rule.get("criticality") not in allowed_criticality:
        probs.append(
            f"Invalid criticality '{rule.get('criticality')}' in rule '{rule.get('name')}' ({file_path})"
        )
    chk = rule.get("check")
    if not isinstance(chk, dict) or not chk.get("function"):
        probs.append(f"Missing check.function in rule '{rule.get('name')}' ({file_path})")
    if probs:
        raise ValueError("Rule-level validation failed: " + "; ".join(probs))

def validate_with_dqx(rules: List[dict], file_path: str):
    status = DQEngine.validate_checks(rules)
    if getattr(status, "has_errors", False):
        raise ValueError(f"DQX validation failed in {file_path}:\n{status.to_string()}")

# =========================
# Build rows
# =========================
def process_yaml_file(path: str, required_fields: List[str], time_zone: str = "UTC", created_by: str = "AdminUser") -> List[dict]:
    docs = load_yaml_rules(path)
    if not docs:
        print(f"[skip] {path} has no rules.")
        return []

    validate_rules_file(docs, path)

    now = current_time_iso(time_zone)
    flat: List[dict] = []

    for rule in docs:
        validate_rule_fields(rule, path, required_fields=required_fields)

        raw_check = rule["check"] or {}

        function = raw_check.get("function")
        if not isinstance(function, str) or not function:
            raise ValueError(
                f"{path}: check.function must be a non-empty string (rule '{rule.get('name')}')."
            )

        for_each = raw_check.get("for_each_column")
        if for_each is not None and not isinstance(for_each, list):
            raise ValueError(
                f"{path}: check.for_each_column must be an array of strings (rule '{rule.get('name')}')."
            )
        if isinstance(for_each, list):
            try:
                for_each = [str(x) for x in for_each]
            except Exception as exc:
                raise ValueError(
                    f"{path}: unable to cast for_each_column items to strings (rule '{rule.get('name')}')."
                ) from exc

        arguments = raw_check.get("arguments", {}) or {}
        if not isinstance(arguments, dict):
            raise ValueError(f"{path}: check.arguments must be a map (rule '{rule.get('name')}').")

        # Compute identity only after the shape checks above, so malformed input
        # (e.g. a list for `arguments`) fails with a clear message instead of an AttributeError.
        payload = compute_check_id_payload(rule["table_name"], raw_check, rule.get("filter"))
        check_id = compute_check_id_from_payload(payload)

        arguments = _stringify_map_values(arguments)

        user_metadata = rule.get("user_metadata")
        if user_metadata is not None:
            if not isinstance(user_metadata, dict):
                raise ValueError(f"{path}: user_metadata must be a map (rule '{rule.get('name')}').")
            user_metadata = _stringify_map_values(user_metadata)

        flat.append(
            {
                "check_id": check_id,
                "check_id_payload": payload,
                "table_name": rule["table_name"],
                "name": rule["name"],
                "criticality": rule["criticality"],
                "check": {
                    "function": function,
                    "for_each_column": for_each if for_each else None,
                    "arguments": arguments if arguments else None,
                },
                "filter": rule.get("filter"),
                "run_config_name": rule["run_config_name"],
                "user_metadata": user_metadata if user_metadata else None,
                "yaml_path": path,
                "active": rule.get("active", True),
                "created_by": created_by,
                "created_at": now,
                "updated_by": None,
                "updated_at": None,
            }
        )

    validate_with_dqx(docs, path)
    return flat

# =========================
# Batch dedupe (on check_id ONLY)
# =========================
def _fmt_rule_for_dup(r: dict) -> str:
    return (
        f"name={r.get('name')} | file={r.get('yaml_path')} | "
        f"criticality={r.get('criticality')} | run_config={r.get('run_config_name')} | "
        f"filter={r.get('filter')}"
    )

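def dedupe_rules_in_batch_by_check_id(rules: List[dict], mode: str = "warn") -> List[dict]:
    if mode not in {"warn", "error", "skip"}:
        raise ValueError(f"batch_dedupe_mode must be 'warn', 'error' or 'skip', got {mode!r}")
    if mode == "skip":
        # Documented semantics: keep every row; duplicates then fail assert_unique_check_id.
        return list(rules)

    groups: Dict[str, List[dict]] = {}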
    for r in rules:
        groups.setdefault(r["check_id"], []).append(r)

    out: List[dict] = []
    dropped = 0
    blocks: List[str] = []

    for cid, lst in groups.items():
        if len(lst) == 1:
            out.append(lst[0]); continue
        lst = sorted(lst, key=lambda x: (x.get("yaml_path", ""), x.get("name", "")))
        keep, dups = lst[0], lst[1:]
        dropped += len(dups)
        head = f"[dup/batch/check_id] {len(dups)} duplicate(s) for check_id={cid[:12]}…"
        lines = ["    " + _fmt_rule_for_dup(x) for x in lst]
        tail = f"    -> keeping: name={keep.get('name')} | file={keep.get('yaml_path')}"
        blocks.append("\n".join([head, *lines, tail]))
        out.append(keep)

    if dropped:
        msg = "\n\n".join(blocks) + f"\n[dedupe/batch] total dropped={dropped}"
        if mode == "error":
            raise ValueError(msg)
        if mode == "warn":
            print(msg)
    return out

# =========================
# Comments + Constraints + Helpers
# =========================
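def _esc_comment(s: str) -> str:
    # Spark SQL string literals use backslash escapes. A doubled quote ('') would be parsed
    # as two adjacent literals and concatenated, silently dropping the apostrophe.
    return (s or "").replace("\\", "\\\\").replace("'", "\\'")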

def _q_fqn(fqn: str) -> str:
    return ".".join(f"`{p}`" for p in fqn.split("."))

def preview_table_documentation(spark: SparkSession, table_fqn: str, doc: Dict[str, Any]) -> None:
    display_section("TABLE METADATA PREVIEW (markdown text stored in comments)")
    doc_df = spark.createDataFrame(
        [(table_fqn, doc.get("table_comment", ""))],
        schema="table string, table_comment_markdown string",
    )
    show_df(doc_df, n=1, truncate=False)

    cols = doc.get("columns", {}) or {}
    cols_df = spark.createDataFrame(
        [(k, v) for k, v in cols.items()],
        schema="column string, column_comment_markdown string",
    )
    show_df(cols_df, n=200, truncate=False)

def apply_table_documentation(
    spark: SparkSession,
    table_fqn: str,
    doc: Optional[Dict[str, Any]],
) -> None:
    """
    Apply table & column comments (ALWAYS).
    Uses robust syntax; falls back to TBLPROPERTIES for table comment if needed.
    """
    if not doc:
        return
    qtable = _q_fqn(table_fqn)

    # Table comment
    table_comment = _esc_comment(doc.get("table_comment", ""))
    if table_comment:
        try:
            spark.sql(f"COMMENT ON TABLE {qtable} IS '{table_comment}'")
        except Exception:
            spark.sql(f"ALTER TABLE {qtable} SET TBLPROPERTIES ('comment' = '{table_comment}')")

    # Column comments (always attempt)
    cols: Dict[str, str] = doc.get("columns", {}) or {}
    existing_cols = {f.name.lower() for f in spark.table(table_fqn).schema.fields}
    for col_name, comment in cols.items():
        if col_name.lower() not in existing_cols:
            continue
        qcol = f"`{col_name}`"
        comment_sql = f"ALTER TABLE {qtable} ALTER COLUMN {qcol} COMMENT '{_esc_comment(comment)}'"
        try:
            spark.sql(comment_sql)
        except Exception as e:
            # Try COMMENT ON COLUMN as fallback
            try:
                spark.sql(f"COMMENT ON COLUMN {qtable}.{qcol} IS '{_esc_comment(comment)}'")
            except Exception as e2:
                print(f"[meta] Skipped column comment for {table_fqn}.{col_name}: {e2}")

def ensure_primary_key_constraint(
    spark: SparkSession,
    table_fqn: str,
    column: str = "check_id",
    constraint_name: str = "pk_check_id",
    rely: bool = True,
) -> None:
    """
    Adds/refreshes an informational PRIMARY KEY constraint (Unity Catalog only).
    Not enforced by the engine; used for metadata/optimization. Will replace if exists.
    """
    qtable = _q_fqn(table_fqn)
    opt_rely = " RELY" if rely else ""
    try:
        spark.sql(f"ALTER TABLE {qtable} ADD CONSTRAINT {constraint_name} PRIMARY KEY ({column}){opt_rely}")
    except Exception:
        # try drop+add in case it already existed
        try:
            spark.sql(f"ALTER TABLE {qtable} DROP CONSTRAINT {constraint_name}")
            spark.sql(f"ALTER TABLE {qtable} ADD CONSTRAINT {constraint_name} PRIMARY KEY ({column}){opt_rely}")
        except Exception as e:
            print(f"[meta] Could not set PRIMARY KEY for {table_fqn}.{column}: {e}")

def assert_unique_check_id(spark: SparkSession, table_fqn: str) -> None:
    """
    Hard runtime enforcement: fail if any duplicate check_id exists.
    """
    qtable = _q_fqn(table_fqn)
    dup = spark.sql(f"""
        SELECT check_id, COUNT(*) AS c
        FROM {qtable}
        GROUP BY check_id
        HAVING COUNT(*) > 1
        LIMIT 1
    """).collect()
    if dup:
        cid, c = dup[0]["check_id"], dup[0]["c"]
        raise RuntimeError(f"[ENFORCE UNIQUE] Found duplicate check_id={cid} (count={c}) in {table_fqn}. Aborting.")

# =========================
# Recursive YAML discovery
# =========================
def _normalize_base(path: str) -> str:
    if path.startswith("dbfs:/"):
        return "/dbfs/" + path[len("dbfs:/") :].lstrip("/")
    return path

def discover_yaml_files_recursive(base_dir: str) -> List[str]:
    base = _normalize_base(base_dir)
    if not os.path.isdir(base):
        raise FileNotFoundError(f"Rules folder not found or not a directory: {base_dir} (resolved: {base})")
    out: List[str] = []
    for root, dirs, files in os.walk(base):
        dirs[:] = [d for d in dirs if not d.startswith(".")]
        for f in files:
            if f.startswith("."):
                continue
            fl = f.lower()
            if fl.endswith(".yaml") or fl.endswith(".yml"):
                out.append(os.path.join(root, f))
    return sorted(out)

# =========================
# Delta I/O (ALWAYS OVERWRITE)
# =========================
def ensure_schema_exists(spark: SparkSession, full_table_name: str):
    parts = full_table_name.split(".")
    if len(parts) != 3:
        raise ValueError(f"Expected a 3-part name (catalog.schema.table), got '{full_table_name}'")
    cat, sch, _ = parts
    spark.sql(f"CREATE SCHEMA IF NOT EXISTS `{cat}`.`{sch}`")

def overwrite_rules_into_delta(
    spark: SparkSession,
    df: DataFrame,
    delta_table_name: str,
    table_doc: Optional[Dict[str, Any]] = None,
):
    # Track existence BEFORE write (for info only)
    existed_before = spark.catalog.tableExists(delta_table_name)

    # Ensure schema
    ensure_schema_exists(spark, delta_table_name)

    # Cast audit timestamps
    df = df.withColumn("created_at", to_timestamp(col("created_at"))) \
           .withColumn("updated_at", to_timestamp(col("updated_at")))

    # Overwrite (content + schema)
    df.write.format("delta") \
      .mode("overwrite") \
      .option("overwriteSchema", "true") \
      .saveAsTable(delta_table_name)

    # Apply comments (table + columns) ALWAYS
    doc_to_apply = _materialize_table_doc(table_doc or DQX_CHECKS_CONFIG_METADATA, delta_table_name)
    apply_table_documentation(spark, delta_table_name, doc_to_apply)

    # Add/refresh informational PK on check_id (Unity Catalog required)
    ensure_primary_key_constraint(spark, delta_table_name, column="check_id", constraint_name="pk_check_id", rely=True)

    # Enforce uniqueness at runtime (hard fail if violated)
    assert_unique_check_id(spark, delta_table_name)

    # Preview metadata
    preview_table_documentation(spark, delta_table_name, doc_to_apply)

    # Display a clean confirmation table
    display_section("WRITE RESULT (Delta)")
    summary = spark.createDataFrame(
        [(df.count(), delta_table_name, "overwrite", "pk_check_id")],
        schema="`rules written` long, `target table` string, `mode` string, `constraint` string",
    )
    show_df(summary, n=1)
    print(f"\n{Color.b}{Color.ivory}Rules written to: '{Color.r}{Color.b}{Color.chartreuse}{delta_table_name}{Color.r}{Color.b}{Color.ivory}'  (PK declared, uniqueness asserted){Color.r}")

# =========================
# Display-first debug helpers
# =========================
def debug_display_batch(spark: SparkSession, df_rules: DataFrame) -> None:
    display_section("SUMMARY OF RULES LOADED FROM YAML")
    totals = [
        (
            df_rules.count(),
            df_rules.select("check_id").distinct().count(),
            df_rules.select("check_id", "run_config_name").distinct().count(),
        )
    ]
    totals_df = spark.createDataFrame(
        totals,
        schema="`total number of rules found` long, `unique rules found` long, `distinct check_id/run_config_name pairs` long",
    )
    show_df(totals_df, n=1)

    display_section("SAMPLE OF RULES LOADED FROM YAML (check_id, name, run_config_name, yaml_path)")
    sample_cols = df_rules.select("check_id", "name", "run_config_name", "yaml_path").orderBy(
        desc("yaml_path")
    )
    show_df(sample_cols, n=50, truncate=False)

    display_section("RULES LOADED PER TABLE")
    by_table = df_rules.groupBy("table_name").count().orderBy(desc("count"))
    show_df(by_table, n=200)
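
# --- Sketch: the three totals from debug_display_batch ------------------------
# Hypothetical helper for illustration only. It computes the same three numbers
# as the summary table above, but over plain rule dicts (assumed to carry
# "check_id" and "run_config_name"). If the pair count exceeds the unique
# check_id count, the same check_id appears under more than one run_config_name.
from typing import Dict, List, Tuple


def _sketch_batch_totals(rules: List[Dict]) -> Tuple[int, int, int]:
    total = len(rules)
    unique_ids = len({r["check_id"] for r in rules})
    distinct_pairs = len({(r["check_id"], r["run_config_name"]) for r in rules})
    return total, unique_ids, distinct_pairs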

def print_rules_df(spark: SparkSession, rules: List[dict]) -> Optional[DataFrame]:
    if not rules:
        print("No rules to show.")
        return None
    df = (
        spark.createDataFrame(rules, schema=TABLE_SCHEMA)
        .withColumn("created_at", to_timestamp(col("created_at")))
        .withColumn("updated_at", to_timestamp(col("updated_at")))
    )
    debug_display_batch(spark, df)
    return df

# =========================
# Validation (works off a file list)
# =========================
def validate_rule_files(file_paths: List[str], required_fields: List[str], fail_fast: bool = True) -> List[str]:
    errors = []
    for full_path in file_paths:
        print(f"\nValidating file: {full_path}")
        try:
            docs = load_yaml_rules(full_path)
            if not docs:
                print(f"  (empty) skipped: {full_path}")
                continue
            validate_rules_file(docs, full_path)
            print(f"  File-level validation passed for {full_path}")
            for rule in docs:
                validate_rule_fields(rule, full_path, required_fields=required_fields)
                print(f"    Rule-level validation passed for rule '{rule.get('name')}'")
            validate_with_dqx(docs, full_path)
            print(f"  DQX validation PASSED for {full_path}")
        except Exception as ex:
            print(f"  Validation FAILED for file {full_path}\n  Reason: {ex}")
            errors.append(f"{full_path}: {ex}")
            if fail_fast:
                break
    if not errors:
        print("\nAll YAML rule files are valid!")
    else:
        print("\nRule validation errors found:")
        for e in errors:
            print(e)
    return errors
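
# --- Sketch: fail_fast control flow of validate_rule_files --------------------
# Hypothetical helper that isolates the loop semantics above from YAML and DQX:
# `validate` is any callable assumed to raise on an invalid file. Each failing
# file contributes one error; with fail_fast=True the loop stops at the first one.
from typing import Callable, List


def _sketch_collect_errors(
    paths: List[str], validate: Callable[[str], None], fail_fast: bool = True
) -> List[str]:
    errors: List[str] = []
    for path in paths:
        try:
            validate(path)
        except Exception as ex:  # same broad catch as validate_rule_files
            errors.append(f"{path}: {ex}")
            if fail_fast:
                break
    return errors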

# =========================
# Recursive discovery + write
# =========================
def _load_output_config(path: str) -> Dict[str, Any]:
    with open(path, "r") as fh:
        return yaml.safe_load(fh) or {}
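
# --- Sketch: resolving the config keys that main() reads ----------------------
# Hypothetical helper; the key names come from this notebook, but the helper
# name and error wording are illustrative. main() indexes the config directly,
# so a missing key surfaces as a bare KeyError. This version reports every
# missing key at once and mirrors main()'s "explicit rules_dir wins" fallback.
from typing import Any, Dict, Optional, Tuple


def _sketch_resolve_config(cfg: Dict[str, Any], rules_dir: Optional[str] = None) -> Tuple[str, str]:
    missing = []
    if not rules_dir and "dqx_yaml_checks" not in cfg:
        missing.append("dqx_yaml_checks")
    if "dqx_checks_config_table_name" not in cfg:
        missing.append("dqx_checks_config_table_name")
    if missing:
        raise KeyError(f"dqx config is missing required keys: {missing}")
    return rules_dir or cfg["dqx_yaml_checks"], cfg["dqx_checks_config_table_name"]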

def main(
    output_config_path: str = "resources/dqx_config.yaml",
    rules_dir: Optional[str] = None,
    time_zone: str = "America/Chicago",
    dry_run: bool = False,
    validate_only: bool = False,
    required_fields: Optional[List[str]] = None,
    batch_dedupe_mode: str = "warn",  # warn | error | skip
    table_doc: Optional[Dict[str, Any]] = None,
    created_by: str = "AdminUser",
):
    spark = SparkSession.builder.getOrCreate()
    print_notebook_env(spark, local_timezone=time_zone)

    output_config = _load_output_config(output_config_path)
    rules_dir = rules_dir or output_config["dqx_yaml_checks"]
    delta_table_name = output_config["dqx_checks_config_table_name"]

    required_fields = required_fields or ["table_name", "name", "criticality", "run_config_name", "check"]

    # Recursively discover YAML files
    yaml_files = discover_yaml_files_recursive(rules_dir)
    display_section("YAML FILES DISCOVERED (recursive)")
    files_df = spark.createDataFrame([(p,) for p in yaml_files], "yaml_path string")
    show_df(files_df, n=500, truncate=False)

    if validate_only:
        print("\nValidation only: not writing any rules.")
        errs = validate_rule_files(yaml_files, required_fields)
        return {
            "mode": "validate_only",
            "config_path": output_config_path,
            "rules_files": len(yaml_files),
            "errors": errs,
        }

    # Collect rules
    all_rules: List[dict] = []
    for full_path in yaml_files:
        file_rules = process_yaml_file(full_path, required_fields=required_fields, time_zone=time_zone, created_by=created_by)
        if file_rules:
            all_rules.extend(file_rules)
            print(f"[loader] {full_path}: rules={len(file_rules)}")

    if not all_rules:
        print("No rules discovered; nothing to do.")
        return {"mode": "no_op", "config_path": output_config_path, "rules_files": len(yaml_files), "wrote_rows": 0}

    print(f"[loader] total parsed rules (pre-dedupe): {len(all_rules)}")

    # In-batch dedupe on check_id only
    pre_dedupe = len(all_rules)
    all_rules = dedupe_rules_in_batch_by_check_id(all_rules, mode=batch_dedupe_mode)
    post_dedupe = len(all_rules)

    # Assemble DataFrame and DISPLAY diagnostics
    df = spark.createDataFrame(all_rules, schema=TABLE_SCHEMA)
    debug_display_batch(spark, df)

    unique_check_ids = df.select("check_id").distinct().count()
    distinct_pairs  = df.select("check_id", "run_config_name").distinct().count()

    if dry_run:
        display_section("DRY-RUN: FULL RULES PREVIEW")
        show_df(df.orderBy("table_name", "name"), n=1000, truncate=False)
        return {
            "mode": "dry_run",
            "config_path": output_config_path,
            "rules_files": len(yaml_files),
            "rules_pre_dedupe": pre_dedupe,
            "rules_post_dedupe": post_dedupe,
            "unique_check_ids": unique_check_ids,
            "distinct_rule_run_pairs": distinct_pairs,
            "target_table": delta_table_name,
            "wrote_rows": 0,
        }

    # ALWAYS OVERWRITE on each run
    doc = _materialize_table_doc(table_doc or DQX_CHECKS_CONFIG_METADATA, delta_table_name)
    overwrite_rules_into_delta(spark, df, delta_table_name, table_doc=doc)
    wrote_rows = df.count()
    print(f"{Color.b}{Color.ivory}Finished writing rules to '{Color.r}{Color.b}{Color.i}{Color.sea_green}{delta_table_name}{Color.r}{Color.b}{Color.ivory}' (overwrite){Color.r}.")

    return {
        "mode": "overwrite",
        "config_path": output_config_path,
        "rules_files": len(yaml_files),
        "rules_pre_dedupe": pre_dedupe,
        "rules_post_dedupe": post_dedupe,
        "unique_check_ids": unique_check_ids,
        "distinct_rule_run_pairs": distinct_pairs,
        "target_table": delta_table_name,
        "wrote_rows": wrote_rows,
        "constraint": "pk_check_id",
    }
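
# --- Sketch: one plausible reading of batch_dedupe_mode -----------------------
# Hypothetical; this is NOT the body of dedupe_rules_in_batch_by_check_id, which
# is defined outside this section. Assumptions: the first occurrence of a
# check_id wins; "warn" drops later duplicates and prints a message; "error"
# raises; "skip" drops duplicates silently. Verify against the real helper.
from typing import Dict, List


def _sketch_dedupe_by_check_id(rules: List[Dict], mode: str = "warn") -> List[Dict]:
    if mode not in ("warn", "error", "skip"):
        raise ValueError(f"unknown batch_dedupe_mode: {mode!r}")
    seen = set()
    kept: List[Dict] = []
    dropped: List[str] = []
    for rule in rules:
        cid = rule["check_id"]
        if cid in seen:
            dropped.append(cid)  # later duplicate: not kept
            continue
        seen.add(cid)
        kept.append(rule)
    if dropped and mode == "error":
        raise ValueError(f"duplicate check_id values in batch: {sorted(set(dropped))}")
    if dropped and mode == "warn":
        print(f"[dedupe] dropped {len(dropped)} duplicate rule(s): {sorted(set(dropped))}")
    return kept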

def load_checks(
    dqx_cfg_yaml: str = "resources/dqx_config.yaml",
    created_by: str = "AdminUser",
    time_zone: str = "America/Chicago",
    dry_run: bool = False,
    validate_only: bool = False,
    batch_dedupe_mode: str = "warn",
    table_doc: Optional[Dict[str, Any]] = None,
):
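    """
    Notebook-friendly entrypoint mirroring 02_run_dqx_checks.run_checks(...).

    Thin wrapper around main(): the rules folder is always read from
    `dqx_yaml_checks` in the config, and the default required fields are used.
    """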
    return main(
        output_config_path=dqx_cfg_yaml,
        rules_dir=None,
        time_zone=time_zone,
        dry_run=dry_run,
        validate_only=validate_only,
        required_fields=None,
        batch_dedupe_mode=batch_dedupe_mode,
        table_doc=table_doc,
        created_by=created_by,
    )

# ---- run it ----
res = load_checks(
    dqx_cfg_yaml="resources/dqx_config.yaml",
    created_by="AdminUser",
    # dry_run=True,
    # validate_only=True,
    batch_dedupe_mode="warn",
)
print(res)
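
# --- Sketch: failing a job on validation errors -------------------------------
# In validate_only mode, main() returns its errors in res["errors"] instead of
# raising, so a scheduled run would still report success. This hypothetical
# helper (not called here) shows one way a caller could turn those errors into
# a hard failure; it only inspects the returned dict.
from typing import Any, Dict


def _sketch_raise_on_validation_errors(result: Dict[str, Any]) -> None:
    errors = result.get("errors") or []
    if result.get("mode") == "validate_only" and errors:
        raise RuntimeError(f"{len(errors)} rule file(s) failed validation")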