# Load DQX Checks

In [0]:
"""
START: update_dqx_rules_table
|
|-- 1. Check if Delta config table exists (recorded before the write):
|     |-- If table does NOT exist:
|     |     |-- The write creates it; column comments are applied after this first create.
|     |
|     |-- If table DOES exist:
|           |-- It is overwritten; only the table comment is re-applied.
|
|-- 2. For each YAML file found recursively under rules_dir:
|     |-- Parse file (YAML load; multi-document files are supported).
|     |-- FILE-LEVEL validation (every rule has a name; no dup rule names within the file).
|     |-- For each rule:
|           |-- RULE-LEVEL validation (required fields, fully qualified table_name, criticality, check.function).
|           |-- Compute check_id (sha256 over the canonical table_name/filter/check payload).
|           |-- Flatten and collect rule as dict (with check_id, run_config_name, audit fields, etc).
|     |-- DQX syntax validation (DQEngine.validate_checks).
|
|
|-- 3. Combine all flattened rules into one list.
|
|-- 4. Write all rules into the Delta table (full overwrite on every run):
|     |-- Table content and schema are replaced; no per-row upsert is performed.
|     |-- Every row is written with created_at=load time, created_by='AdminUser', updated_at/updated_by=None.
|
|-- END: update_dqx_rules_table
"""

In [0]:
%pip install databricks-labs-dqx==0.8.0

In [0]:
dbutils.library.restartPython()

In [0]:
# Databricks notebook: 01_load_dqx_checks  (DISPLAY-FIRST + METADATA, SAFE COMMENTS, RECURSIVE YAML DISCOVERY)
# - Uses display() for clean tables
# - Clearer section titles/column names
# - Applies TABLE comment always (markdown), COLUMN comments only on first create
# - Uses ALTER TABLE ... ALTER COLUMN COMMENT to avoid parser issues
# - Shows a markdown-ish preview table of table/column docs
# - NEW: Recursively loads *all* .yaml/.yml files under the configured folder (folders inside folders)

import os
import json
import hashlib
import yaml
from typing import Dict, Any, Optional, List, Tuple

from pyspark.sql import SparkSession, DataFrame, types as T
from pyspark.sql.functions import to_timestamp, col, desc

# from delta.tables import DeltaTable   # not needed when we always overwrite
from databricks.labs.dqx.engine import DQEngine

from utils.print import print_notebook_env
from utils.timezone import current_time_iso

# ======================================================
# Small helpers for Databricks-friendly tabular display
# ======================================================

def _can_display() -> bool:
    """Report whether a ``display`` name is present in this module's globals."""
    module_names = globals()
    return "display" in module_names

def show_df(df: DataFrame, n: int = 100, truncate: bool = False) -> None:
    """Render at most ``n`` rows of ``df``.

    Uses ``display()`` when that name is available; otherwise falls back to
    ``df.show`` with the given ``truncate`` setting.
    """
    if not _can_display():
        df.show(n, truncate=truncate)
        return
    display(df.limit(n))

def display_section(title: str) -> None:
    """Print ``title`` framed by two 80-character rules, preceded by a blank line."""
    rule = "═" * 80
    print(f"\n{rule}")
    print(f"║ {title}")
    print(rule)

# =========================
# Target schema (Delta sink)
# =========================
# Schema of the Delta sink: one row per flattened DQX rule.
# The third positional argument of each StructField is its nullable flag.
TABLE_SCHEMA = T.StructType([
    T.StructField("check_id",            T.StringType(), False),  # sha256 hex digest of check_id_payload
    T.StructField("check_id_payload",    T.StringType(), False),  # canonical JSON used to compute check_id
    T.StructField("table_name",          T.StringType(), False),  # table the rule targets (validated as a 3-part name)

    # DQX fields
    T.StructField("name",                T.StringType(), False),
    T.StructField("criticality",         T.StringType(), False),
    T.StructField("check", T.StructType([
        T.StructField("function",        T.StringType(), False),
        T.StructField("for_each_column", T.ArrayType(T.StringType()), True),
        # argument values are stored as strings (see _stringify_map_values)
        T.StructField("arguments",       T.MapType(T.StringType(), T.StringType()), True),
    ]), False),
    T.StructField("filter",              T.StringType(), True),
    T.StructField("run_config_name",     T.StringType(), False),
    # metadata values are stored as strings (see _stringify_map_values)
    T.StructField("user_metadata",       T.MapType(T.StringType(), T.StringType()), True),

    # Ops fields
    T.StructField("yaml_path",           T.StringType(), False),  # path of the YAML file the rule came from
    T.StructField("active",              T.BooleanType(), False),
    T.StructField("created_by",          T.StringType(), False),
    T.StructField("created_at",          T.StringType(), False),  # ISO string here; cast to timestamp before the write
    T.StructField("updated_by",          T.StringType(), True),
    T.StructField("updated_at",          T.StringType(), True),   # string here; cast to timestamp before the write
])

# =========================
# YAML loading (robust)
# =========================
def load_yaml_rules(path: str) -> List[dict]:
    """Load every rule mapping from a (possibly multi-document) YAML file.

    Each YAML document may be a single mapping or a list of mappings.
    Empty documents, and list items or documents that are not mappings,
    are dropped without an error.

    Args:
        path: Filesystem path of the YAML file.

    Returns:
        The rule mappings in file order; an empty list when the file holds
        no mappings.
    """
    # Decode explicitly as UTF-8: without an encoding argument open() falls
    # back to the locale default, which makes parsing machine-dependent.
    with open(path, "r", encoding="utf-8") as fh:
        docs = list(yaml.safe_load_all(fh))

    out: List[dict] = []
    for d in docs:
        if isinstance(d, dict):
            out.append(d)
        elif isinstance(d, list):
            out.extend(x for x in d if isinstance(x, dict))
        # None (empty document) and scalars fall through and are ignored.
    return out

# =========================
# Canonicalization & IDs
# =========================
def _canon_filter(s: Optional[str]) -> str:
    """Collapse every whitespace run in ``s`` to one space; falsy input gives ``""``."""
    if not s:
        return ""
    tokens = str(s).split()
    return " ".join(tokens)

def _canon_check(chk: Dict[str, Any]) -> Dict[str, Any]:
    """Build the canonical form of a check mapping used for hashing.

    - ``function`` is copied as-is.
    - ``for_each_column`` becomes a sorted list of strings, or ``None`` when
      it is not a list or is empty.
    - ``arguments`` values are stringified and stripped (``None`` -> ``""``);
      strings that look like a JSON object/array and parse as JSON are
      re-serialized compactly with sorted keys. Keys are emitted sorted.
    """
    columns = chk.get("for_each_column")
    canon_columns = None
    if isinstance(columns, list):
        canon_columns = sorted(str(c) for c in columns) or None

    rendered: Dict[str, str] = {}
    for key, value in (chk.get("arguments") or {}).items():
        text = "" if value is None else str(value).strip()
        # Same test as startswith/endswith on the matching bracket pair.
        if (text[:1], text[-1:]) in (("{", "}"), ("[", "]")):
            try:
                text = json.dumps(json.loads(text), sort_keys=True, separators=(",", ":"))
            except Exception:
                # Not valid JSON: keep the stripped text unchanged.
                pass
        rendered[str(key)] = text

    return {
        "function": chk.get("function"),
        "for_each_column": canon_columns,
        "arguments": dict(sorted(rendered.items())),
    }

def compute_check_id_payload(table_name: str, check_dict: Dict[str, Any], filter_str: Optional[str]) -> str:
    """Serialize the identity of a rule as compact, key-sorted JSON.

    The payload holds the lower-cased table name, the whitespace-normalized
    filter and the canonicalized check mapping.
    """
    normalized_table = table_name.lower() if table_name else ""
    identity = dict(
        table_name=normalized_table,
        filter=_canon_filter(filter_str),
        check=_canon_check(check_dict or {}),
    )
    return json.dumps(identity, sort_keys=True, separators=(",", ":"))

def compute_check_id_from_payload(payload: str) -> str:
    """Return the hex SHA-256 digest of ``payload`` encoded as UTF-8."""
    digest = hashlib.sha256()
    digest.update(payload.encode("utf-8"))
    return digest.hexdigest()

# =========================
# Conversions / validation
# =========================
def _stringify_map_values(d: Dict[str, Any]) -> Dict[str, str]:
    """Return a copy of ``d`` whose values are all strings.

    Lists and dicts are JSON-encoded, booleans become ``"true"``/``"false"``,
    ``None`` becomes ``"null"`` and anything else goes through ``str()``.
    A falsy ``d`` yields an empty dict.
    """

    def _as_text(value: Any) -> str:
        if isinstance(value, (list, dict)):
            return json.dumps(value)
        # bool is tested before the generic str() so True renders as "true".
        if isinstance(value, bool):
            return "true" if value else "false"
        return "null" if value is None else str(value)

    return {key: _as_text(value) for key, value in (d or {}).items()}

def validate_rules_file(rules: List[dict], file_path: str):
    """Run file-level checks on the rules parsed from one YAML file.

    Checks that the file contains at least one rule, that every rule has a
    non-empty ``name`` and that no name occurs twice.

    Args:
        rules: Rule mappings loaded from the file.
        file_path: Path of the file, used in error messages only.

    Raises:
        ValueError: If the file has no rules, or if any rule is unnamed or
            duplicates an earlier name. All problems found are reported in
            one message.
    """
    if not rules:
        raise ValueError(f"No rules found in {file_path} (empty or invalid YAML).")

    probs, seen = [], set()
    for r in rules:
        nm = r.get("name")
        if not nm:
            probs.append(f"Missing rule name in {file_path}")
            # Do not track missing names: otherwise a second unnamed rule
            # is also reported as a bogus "Duplicate rule name 'None'".
            continue
        if nm in seen:
            probs.append(f"Duplicate rule name '{nm}' in {file_path}")
        seen.add(nm)

    if probs:
        raise ValueError(f"File-level validation failed in {file_path}: {probs}")

def validate_rule_fields(
    rule: dict,
    file_path: str,
    required_fields: List[str],
    allowed_criticality={"error", "warn", "warning"},
):
    """Run rule-level checks on a single rule mapping.

    Checks that every required field is present and truthy, that
    ``table_name`` is a string with exactly two dots, that ``criticality``
    is one of ``allowed_criticality`` and that ``check.function`` is set.

    Args:
        rule: The rule mapping to validate.
        file_path: Source file, used in error messages only.
        required_fields: Keys that must be present with a truthy value.
        allowed_criticality: Accepted criticality values. The default set is
            only read, never mutated.

    Raises:
        ValueError: If any check fails; all problems are joined into one
            message.
    """
    rule_name = rule.get("name")
    probs = []

    for f in required_fields:
        if not rule.get(f):
            probs.append(f"Missing required field '{f}' in rule '{rule_name}' ({file_path})")

    # A key present with a null or non-string value must be reported as a
    # validation problem, not crash with AttributeError on .count().
    table_name = rule.get("table_name")
    if not isinstance(table_name, str) or table_name.count(".") != 2:
        probs.append(
            f"table_name '{table_name}' not fully qualified in rule '{rule_name}' ({file_path})"
        )

    criticality = rule.get("criticality")
    try:
        criticality_ok = criticality in allowed_criticality
    except TypeError:
        # Unhashable value (e.g. a YAML list) can never be an allowed one.
        criticality_ok = False
    if not criticality_ok:
        probs.append(
            f"Invalid criticality '{criticality}' in rule '{rule_name}' ({file_path})"
        )

    # `check: null` or a scalar check has no .get(); treat it as "no function".
    check = rule.get("check")
    function = check.get("function") if isinstance(check, dict) else None
    if not function:
        probs.append(f"Missing check.function in rule '{rule_name}' ({file_path})")

    if probs:
        raise ValueError("Rule-level validation failed: " + "; ".join(probs))

def validate_with_dqx(rules: List[dict], file_path: str):
    """Validate ``rules`` with ``DQEngine.validate_checks``.

    Raises:
        ValueError: If the returned status reports errors; the message
            carries the status rendered via ``to_string()``.
    """
    status = DQEngine.validate_checks(rules)
    has_errors = getattr(status, "has_errors", False)
    if not has_errors:
        return
    raise ValueError(f"DQX validation failed in {file_path}:\n{status.to_string()}")

# =========================
# Build rows
# =========================
def process_yaml_file(path: str, required_fields: List[str], time_zone: str = "UTC") -> List[dict]:
    """Load, validate and flatten every rule of one YAML file.

    Steps: parse the file, run file-level and rule-level validation,
    type-check the ``check`` parts, compute the ``check_id`` hash, build one
    flat row dict per rule and finally run DQX validation on the raw rules.

    Args:
        path: Path of the YAML file; stored on every row as ``yaml_path``.
        required_fields: Keys each rule must provide (passed to
            ``validate_rule_fields``).
        time_zone: Zone name passed to ``current_time_iso`` to stamp
            ``created_at``.

    Returns:
        One dict per rule with the keys of ``TABLE_SCHEMA``; an empty list
        when the file contains no rules.

    Raises:
        ValueError: On any validation failure or malformed ``check`` part.
    """
    docs = load_yaml_rules(path)
    if not docs:
        print(f"[skip] {path} has no rules (empty/comment-only).")
        return []

    validate_rules_file(docs, path)

    now = current_time_iso(time_zone)
    flat: List[dict] = []

    for rule in docs:
        validate_rule_fields(rule, path, required_fields=required_fields)
        rule_name = rule.get("name")

        raw_check = rule["check"] or {}
        if not isinstance(raw_check, dict):
            raise ValueError(f"{path}: check must be a map (rule '{rule_name}').")

        # All shape checks run BEFORE hashing: the canonicalizer iterates
        # check.arguments as a mapping, so a malformed value would otherwise
        # crash there with AttributeError instead of the messages below.
        function = raw_check.get("function")
        if not isinstance(function, str) or not function:
            raise ValueError(
                f"{path}: check.function must be a non-empty string (rule '{rule_name}')."
            )

        for_each = raw_check.get("for_each_column")
        if for_each is not None and not isinstance(for_each, list):
            raise ValueError(
                f"{path}: check.for_each_column must be an array of strings (rule '{rule_name}')."
            )
        if isinstance(for_each, list):
            try:
                for_each = [str(x) for x in for_each]
            except Exception as exc:
                raise ValueError(
                    f"{path}: unable to cast for_each_column items to strings (rule '{rule_name}')."
                ) from exc

        raw_arguments = raw_check.get("arguments", {}) or {}
        if not isinstance(raw_arguments, dict):
            raise ValueError(f"{path}: check.arguments must be a map (rule '{rule_name}').")
        arguments = _stringify_map_values(raw_arguments)

        user_metadata = rule.get("user_metadata")
        if user_metadata is not None:
            if not isinstance(user_metadata, dict):
                raise ValueError(
                    f"{path}: user_metadata must be a map (rule '{rule_name}')."
                )
            user_metadata = _stringify_map_values(user_metadata)

        # Hash the raw (un-stringified) check so check_id values match those
        # produced before the validation was reordered.
        payload = compute_check_id_payload(rule["table_name"], raw_check, rule.get("filter"))
        check_id = compute_check_id_from_payload(payload)

        flat.append(
            {
                "check_id": check_id,
                "check_id_payload": payload,
                "table_name": rule["table_name"],
                "name": rule["name"],
                "criticality": rule["criticality"],
                "check": {
                    "function": function,
                    # Empty containers are stored as NULL rather than []/{}.
                    "for_each_column": for_each if for_each else None,
                    "arguments": arguments if arguments else None,
                },
                "filter": rule.get("filter"),
                "run_config_name": rule["run_config_name"],
                "user_metadata": user_metadata if user_metadata else None,
                "yaml_path": path,
                "active": rule.get("active", True),
                "created_by": "AdminUser",
                "created_at": now,
                "updated_by": None,
                "updated_at": None,
            }
        )

    validate_with_dqx(docs, path)
    return flat

# =========================
# Batch dedupe (on check_id ONLY)
# =========================
def _fmt_rule_for_dup(r: dict) -> str:
    """Format the identifying fields of rule ``r`` as one ``key=value | ...`` line."""
    parts = [
        f"name={r.get('name')}",
        f"file={r.get('yaml_path')}",
        f"criticality={r.get('criticality')}",
        f"run_config={r.get('run_config_name')}",
        f"filter={r.get('filter')}",
    ]
    return " | ".join(parts)

def dedupe_rules_in_batch_by_check_id(rules: List[dict], mode: str = "warn") -> List[dict]:
    """Keep one rule per ``check_id`` within a batch.

    Among rules sharing a ``check_id`` the one that sorts first by
    ``(yaml_path, name)`` is kept. Output order follows the first appearance
    of each ``check_id`` in ``rules``.

    Args:
        rules: Flattened rule dicts, each with a ``check_id`` key.
        mode: ``"error"`` raises when duplicates exist, ``"warn"`` prints a
            report; any other value drops duplicates silently.

    Raises:
        ValueError: In ``"error"`` mode when at least one duplicate exists.
    """
    by_id: Dict[str, List[dict]] = {}
    for rule in rules:
        key = rule["check_id"]
        if key not in by_id:
            by_id[key] = []
        by_id[key].append(rule)

    survivors: List[dict] = []
    reports: List[str] = []
    dropped_total = 0

    for check_id, members in by_id.items():
        if len(members) > 1:
            ordered = sorted(members, key=lambda m: (m.get("yaml_path", ""), m.get("name", "")))
            winner = ordered[0]
            losers = ordered[1:]
            dropped_total += len(losers)
            report_lines = [
                f"[dup/batch/check_id] {len(losers)} duplicate(s) for check_id={check_id[:12]}…"
            ]
            report_lines.extend("    " + _fmt_rule_for_dup(m) for m in ordered)
            report_lines.append(
                f"    -> keeping: name={winner.get('name')} | file={winner.get('yaml_path')}"
            )
            reports.append("\n".join(report_lines))
        else:
            winner = members[0]
        survivors.append(winner)

    if not dropped_total:
        return survivors

    msg = "\n\n".join(reports) + f"\n[dedupe/batch] total dropped={dropped_total}"
    if mode == "error":
        raise ValueError(msg)
    if mode == "warn":
        print(msg)
    return survivors

# =========================
# Table & column comments (safe, creation-only for columns)
# =========================
def _esc_comment(s: str) -> str:
    """Double every single quote in ``s``; a falsy ``s`` yields ``""``."""
    text = s if s else ""
    return "''".join(text.split("'"))

def _q_fqn(fqn: str) -> str:
    """Wrap each dot-separated part of ``fqn`` in backticks: a.b.c -> `a`.`b`.`c`."""
    quoted = ["`" + part + "`" for part in fqn.split(".")]
    return ".".join(quoted)

def default_table_documentation(target_table_fqn: str) -> Dict[str, Any]:
    """
    Build the default documentation dict (markdown strings) for the checks table.

    The result has three keys: ``table`` (the given name), ``table_comment``
    (markdown text) and ``columns`` (column name -> markdown comment).
    Pass your own dict to main(table_doc=...) to override it.
    """
    comment_lines = [
        "# DQX Checks Configuration\n",
        f"- **Target**: `{target_table_fqn}`\n",
        "- Stores flattened Data Quality checks generated from YAML files.\n",
        "- `check_id` is a stable hash over (table_name, filter, check.*).\n",
    ]

    column_docs = {}
    column_docs["check_id"] = "**Hash** of canonical rule payload. Used for identity/dedupe."
    column_docs["check_id_payload"] = "Canonical **JSON** used to compute `check_id`."
    column_docs["table_name"] = "Fully qualified **target table** (`catalog.schema.table`)."
    column_docs["name"] = "Human-readable **rule name**."
    column_docs["criticality"] = "Rule severity: `warn|warning|error`."
    column_docs["check"] = "Structured **check** object `{function, for_each_column, arguments}`."
    column_docs["filter"] = "Optional row-level **filter** expression."
    column_docs["run_config_name"] = "**Execution group/tag** to route this rule."
    column_docs["user_metadata"] = "User-provided **metadata** `map<string,string>`."
    column_docs["yaml_path"] = "Source **YAML** file path that defined this rule."
    column_docs["active"] = "If **false**, the rule is ignored by runners."
    column_docs["created_by"] = "Audit: **creator** of the row."
    column_docs["created_at"] = "Audit: **creation timestamp**."
    column_docs["updated_by"] = "Audit: **last updater**."
    column_docs["updated_at"] = "Audit: **last update timestamp**."

    return {
        "table": target_table_fqn,
        "table_comment": "".join(comment_lines),
        "columns": column_docs,
    }

def preview_table_documentation(spark: SparkSession, table_fqn: str, doc: Dict[str, Any]) -> None:
    """Show the table comment and the per-column comments of ``doc`` as two small tables."""
    display_section("TABLE METADATA PREVIEW (markdown text stored in comments)")

    table_rows = [(table_fqn, doc.get("table_comment", ""))]
    table_preview = spark.createDataFrame(
        table_rows,
        schema="table string, table_comment_markdown string",
    )
    show_df(table_preview, n=1, truncate=False)

    column_docs = doc.get("columns", {}) or {}
    column_rows = list(column_docs.items())
    column_preview = spark.createDataFrame(
        column_rows,
        schema="column string, column_comment_markdown string",
    )
    show_df(column_preview, n=200, truncate=False)

def apply_table_documentation(
    spark: SparkSession,
    table_fqn: str,
    doc: Optional[Dict[str, Any]],
    created_now: bool,
) -> None:
    """
    Write the comments from ``doc`` onto the table.

    - TABLE comment: applied whenever ``doc`` provides a non-empty one.
    - COLUMN comments: applied only when ``created_now`` is True, and only
      for columns that exist on the table; a failing column is reported and
      skipped.
    Does nothing when ``doc`` is falsy.
    """
    if not doc:
        return

    quoted_table = _q_fqn(table_fqn)

    escaped_table_comment = _esc_comment(doc.get("table_comment", ""))
    if escaped_table_comment:
        # First choice is COMMENT ON TABLE; on any failure fall back to TBLPROPERTIES.
        try:
            spark.sql(f"COMMENT ON TABLE {quoted_table} IS '{escaped_table_comment}'")
        except Exception:
            spark.sql(f"ALTER TABLE {quoted_table} SET TBLPROPERTIES ('comment' = '{escaped_table_comment}')")

    if created_now:
        column_docs: Dict[str, str] = doc.get("columns", {}) or {}
        # Compare case-insensitively against the columns actually on the table.
        table_columns = {field.name.lower() for field in spark.table(table_fqn).schema.fields}
        for col_name, comment in column_docs.items():
            if col_name.lower() in table_columns:
                quoted_col = f"`{col_name}`"
                statement = (
                    f"ALTER TABLE {quoted_table} ALTER COLUMN {quoted_col} COMMENT '{_esc_comment(comment)}'"
                )
                try:
                    spark.sql(statement)
                except Exception as e:
                    print(f"[meta] Skipped column comment for {table_fqn}.{col_name}: {e}")

# =========================
# Recursive YAML discovery
# =========================
def _normalize_base(path: str) -> str:
    """Translate a ``dbfs:/...`` path to ``/dbfs/...``; other paths are returned unchanged."""
    scheme = "dbfs:/"
    if not path.startswith(scheme):
        return path
    # dbfs:/a/b -> /dbfs/a/b (extra leading slashes after the scheme are dropped)
    remainder = path[len(scheme):].lstrip("/")
    return "/dbfs/" + remainder

def discover_yaml_files_recursive(base_dir: str) -> List[str]:
    """
    Walk ``base_dir`` and return the sorted paths of every *.yaml / *.yml file.

    Hidden directories and hidden files (names starting with ".") are skipped;
    the extension match is case-insensitive. ``dbfs:/`` paths are first
    translated by ``_normalize_base``.

    Raises:
        FileNotFoundError: If the resolved base path is not a directory.
    """
    base = _normalize_base(base_dir)
    if not os.path.isdir(base):
        raise FileNotFoundError(f"Rules folder not found or not a directory: {base_dir} (resolved: {base})")

    found: List[str] = []
    for root, dirs, files in os.walk(base):
        # Prune in place so os.walk never descends into hidden directories.
        dirs[:] = [d for d in dirs if not d.startswith(".")]
        found.extend(
            os.path.join(root, name)
            for name in files
            if not name.startswith(".") and name.lower().endswith((".yaml", ".yml"))
        )
    found.sort()
    return found

# =========================
# Delta I/O (ALWAYS OVERWRITE)
# =========================
def ensure_schema_exists(spark: SparkSession, full_table_name: str):
    """Issue CREATE SCHEMA IF NOT EXISTS for the catalog.schema of a 3-part table name.

    Raises:
        ValueError: If ``full_table_name`` does not split into exactly three
            dot-separated parts.
    """
    name_parts = full_table_name.split(".")
    if len(name_parts) == 3:
        catalog, schema = name_parts[0], name_parts[1]
        spark.sql(f"CREATE SCHEMA IF NOT EXISTS `{catalog}`.`{schema}`")
        return
    raise ValueError(f"Expected a 3-part name (catalog.schema.table), got '{full_table_name}'")

def overwrite_rules_into_delta(
    spark: SparkSession,
    df: DataFrame,
    delta_table_name: str,
    table_doc: Optional[Dict[str, Any]] = None,
):
    """Replace the content and schema of the Delta table with ``df``.

    Steps, in order: record whether the table already exists, make sure the
    target schema exists, cast the audit columns to timestamps, overwrite the
    table, apply table/column comments, preview them and print a summary.

    Args:
        spark: Session used for catalog lookups, SQL and DataFrame creation.
        df: Rules to write; must contain ``created_at`` and ``updated_at``.
        delta_table_name: Target table as ``catalog.schema.table``.
        table_doc: Optional documentation dict; when falsy no comments are
            applied or previewed.
    """
    # Track existence BEFORE write: afterwards the table always exists, and
    # this flag decides whether column comments are applied.
    existed_before = spark.catalog.tableExists(delta_table_name)

    # Ensure schema (raises ValueError for names that are not 3-part)
    ensure_schema_exists(spark, delta_table_name)

    # Cast audit timestamps from their string form
    df = df.withColumn("created_at", to_timestamp(col("created_at"))) \
           .withColumn("updated_at", to_timestamp(col("updated_at")))

    # Overwrite (content + schema)
    # NOTE(review): overwriteSchema replaces the schema on every run; confirm
    # that column comments set on the first create survive later overwrites.
    df.write.format("delta") \
      .mode("overwrite") \
      .option("overwriteSchema", "true") \
      .saveAsTable(delta_table_name)

    # Apply comments (table: always; columns: only on first create)
    apply_table_documentation(spark, delta_table_name, table_doc, created_now=not existed_before)

    # Preview metadata
    if table_doc:
        preview_table_documentation(spark, delta_table_name, table_doc)

    # Display a clean confirmation table; the count is taken from the
    # DataFrame that was written, not read back from the table.
    display_section("WRITE RESULT (Delta)")
    summary = spark.createDataFrame(
        [(df.count(), delta_table_name)],
        schema="`rules written` long, `target table` string",
    )
    show_df(summary, n=1)
    print(f"\nrules written to target table: {delta_table_name}")

# =========================
# Display-first debug helpers
# =========================
def debug_display_batch(spark: SparkSession, df_rules: DataFrame) -> None:
    """Display diagnostics for a batch of rules.

    Shows overall counts, a sample of up to 50 rules ordered by ``yaml_path``
    descending, a per-table rule count and, when there are at most three
    distinct ``check_id`` values, the hash payloads.
    """
    display_section("SUMMARY OF RULES LOADED FROM YAML")
    total_rules = df_rules.count()
    distinct_ids = df_rules.select("check_id").distinct().count()
    distinct_pairs = df_rules.select("check_id", "run_config_name").distinct().count()
    totals_df = spark.createDataFrame(
        [(total_rules, distinct_ids, distinct_pairs)],
        schema="`total number of rules found` long, `unique rules found` long, `distinct pair of rules` long",
    )
    show_df(totals_df, n=1)

    display_section("SAMPLE OF RULES LOADED FROM YAML (check_id, name, run_config_name, yaml_path)")
    sample = df_rules.select("check_id", "name", "run_config_name", "yaml_path")
    show_df(sample.orderBy(desc("yaml_path")), n=50, truncate=False)

    display_section("RULES LOADED PER TABLE")
    per_table = df_rules.groupBy("table_name").count()
    show_df(per_table.orderBy(desc("count")), n=200)

    # The payload preview is only shown for very small batches.
    if distinct_ids > 3:
        return
    display_section("PAYLOAD PREVIEW (first 3)")
    show_df(df_rules.select("check_id", "check_id_payload"), n=3, truncate=False)

def print_rules_df(spark: SparkSession, rules: List[dict]) -> Optional[DataFrame]:
    """Build a DataFrame from ``rules``, display its diagnostics and return it.

    The audit columns are cast to timestamps. When ``rules`` is empty a
    message is printed and ``None`` is returned.
    """
    if rules:
        typed = spark.createDataFrame(rules, schema=TABLE_SCHEMA)
        for ts_column in ("created_at", "updated_at"):
            typed = typed.withColumn(ts_column, to_timestamp(col(ts_column)))
        debug_display_batch(spark, typed)
        return typed

    print("No rules to show.")
    return None

# =========================
# Validation (now works off a file list)
# =========================
def validate_rule_files(file_paths: List[str], required_fields: List[str], fail_fast: bool = True) -> List[str]:
    """Validate every YAML file in ``file_paths`` and report progress on stdout.

    Each file goes through file-level, rule-level and DQX validation; files
    without rules are skipped. Failures are caught and collected instead of
    propagating.

    Args:
        file_paths: YAML files to validate.
        required_fields: Keys each rule must provide.
        fail_fast: Stop after the first failing file when True.

    Returns:
        The collected error messages; empty when everything is valid.
    """
    errors: List[str] = []

    for full_path in file_paths:
        print(f"\nValidating file: {full_path}")
        try:
            docs = load_yaml_rules(full_path)
            if docs:
                validate_rules_file(docs, full_path)
                print(f"  File-level validation passed for {full_path}")
                for rule in docs:
                    validate_rule_fields(rule, full_path, required_fields=required_fields)
                    print(f"    Rule-level validation passed for rule '{rule.get('name')}'")
                validate_with_dqx(docs, full_path)
                print(f"  DQX validation PASSED for {full_path}")
            else:
                print(f"  (empty) skipped: {full_path}")
        except Exception as ex:
            print(f"  Validation FAILED for file {full_path}\n  Reason: {ex}")
            errors.append(str(ex))
            if fail_fast:
                break

    if errors:
        print("\nRule validation errors found:")
        for message in errors:
            print(message)
    else:
        print("\nAll YAML rule files are valid!")
    return errors

# =========================
# Main
# =========================
def main(
    output_config_path: str = "resources/dqx_config.yaml",
    rules_dir: Optional[str] = None,
    time_zone: str = "America/Chicago",
    dry_run: bool = False,
    validate_only: bool = False,
    required_fields: Optional[List[str]] = None,
    batch_dedupe_mode: str = "warn",  # warn | error | skip
    table_doc: Optional[Dict[str, Any]] = None,   # table & column comments; markdown strings
):
    """Discover YAML rule files, flatten them and overwrite the checks table.

    Args:
        output_config_path: YAML config file; supplies the target table name
            (``dqx_checks_config_table_name``) and, when ``rules_dir`` is not
            given, the rules folder (``dqx_yaml_checks``).
        rules_dir: Folder searched recursively for rule files.
        time_zone: Zone used for the environment banner and ``created_at``.
        dry_run: Display the assembled rules and return without writing.
        validate_only: Only validate the discovered files, then return.
        required_fields: Keys each rule must provide; a default list is used
            when falsy.
        batch_dedupe_mode: Passed to ``dedupe_rules_in_batch_by_check_id``.
        table_doc: Documentation dict; defaults to
            ``default_table_documentation`` for the target table.
    """
    spark = SparkSession.builder.getOrCreate()

    # Environment banner
    print_notebook_env(spark, local_timezone=time_zone)

    with open(output_config_path, "r") as fh:
        output_config = yaml.safe_load(fh) or {}
    # The config key for the rules folder is only read when no folder was passed in.
    if not rules_dir:
        rules_dir = output_config["dqx_yaml_checks"]
    delta_table_name = output_config["dqx_checks_config_table_name"]

    if not required_fields:
        required_fields = ["table_name", "name", "criticality", "run_config_name", "check"]

    # Recursively discover YAML files and list them
    yaml_files = discover_yaml_files_recursive(rules_dir)
    display_section("YAML FILES DISCOVERED (recursive)")
    discovered_rows = [(p,) for p in yaml_files]
    show_df(spark.createDataFrame(discovered_rows, "yaml_path string"), n=500, truncate=False)

    if validate_only:
        print("\nValidation only: not writing any rules.")
        validate_rule_files(yaml_files, required_fields)
        return

    # Collect the flattened rules of every file
    all_rules: List[dict] = []
    for full_path in yaml_files:
        file_rules = process_yaml_file(full_path, required_fields=required_fields, time_zone=time_zone)
        if not file_rules:
            continue
        all_rules.extend(file_rules)
        print(f"[loader] {full_path}: rules={len(file_rules)}")

    if not all_rules:
        print("No rules discovered; nothing to do.")
        return

    print(f"[loader] total parsed rules (pre-dedupe): {len(all_rules)}")

    # In-batch dedupe on check_id only
    all_rules = dedupe_rules_in_batch_by_check_id(all_rules, mode=batch_dedupe_mode)

    # Assemble the DataFrame and display diagnostics
    df = spark.createDataFrame(all_rules, schema=TABLE_SCHEMA)
    debug_display_batch(spark, df)

    if dry_run:
        display_section("DRY-RUN: FULL RULES PREVIEW")
        show_df(df.orderBy("table_name", "name"), n=1000, truncate=False)
        return

    # The target table is overwritten on each run
    if not table_doc:
        table_doc = default_table_documentation(delta_table_name)
    overwrite_rules_into_delta(spark, df, delta_table_name, table_doc=table_doc)
    print(f"Finished writing rules to '{delta_table_name}' (overwrite).")


if __name__ == "__main__":
    # Default entry point: full load that overwrites the target table.
    # Alternative invocations for manual runs (preview only / validation only):
    # main(dry_run=True)
    # main(validate_only=True)
    main()