# Generate DQX Checks

In [0]:
# -- DQX profiling options: all known keys --
# Notebook-level profiler options.
# NOTE(review): the usage example under `if __name__ == "__main__":` further down assigns its
# own `profile_options` dict before constructing RuleGenerator, so when that cell runs it
# replaces this one — confirm which of the two is meant to be authoritative.
# Keys outside DOC_SUPPORTED_KEYS are still passed to the profiler; RuleGenerator only prints
# an INFO line listing them.
profile_options = {
    # "round": True,                 # (Removed - not valid for this profiler version)
    "max_in_count": 10,            # Max distinct values for is_in rule
    "distinct_ratio": 0.05,        # Max unique/total ratio for is_in rule
    "max_null_ratio": 0.01,        # Max null fraction to allow is_not_null rule
    "remove_outliers": True,       # Remove outliers for min/max
    "outlier_columns": [],         # Only these columns get outlier removal (empty=all numerics)
    "num_sigmas": 3,               # Stddev for outlier removal (z-score cutoff)
    "trim_strings": True,          # Strip whitespace before profiling strings
    "max_empty_ratio": 0.01,       # Max empty string ratio for is_not_null_or_empty
    "sample_fraction": 0.3,        # Row fraction to sample
    "sample_seed": None,           # Seed for reproducibility (set int for deterministic)
    "limit": 1000,                 # Max number of rows to profile
    "profile_types": None,         # List of rule types (e.g. ["is_in", "is_not_null"]); None=default
    "min_length": None,            # Min string length to consider (None disables)
    "max_length": None,            # Max string length to consider (None disables)
    "include_histograms": False,   # Compute histograms as part of profiling
    "min_value": None,             # Numeric min override (None disables)
    "max_value": None,             # Numeric max override (None disables)
}

In [0]:
%pip install databricks-labs-dqx==0.8.0

In [0]:
dbutils.library.restartPython()

In [0]:
# dqx_rule_generator.py
# Full file — adds YAML list support via header-comment parsing + display() preview.
# NOTHING REMOVED: all params/options stay. Renamed yaml_key_order -> key_order.
# Single-table & CSV FQNs still work. One YAML file per table is written in output_location.

import os
import re
import io
import json
from typing import List, Optional, Dict, Any, Literal, Tuple

import yaml
from databricks.sdk import WorkspaceClient
from databricks.labs.dqx.profiler.profiler import DQProfiler
from databricks.labs.dqx.profiler.dlt_generator import DQDltGenerator
from databricks.labs.dqx.engine import DQEngine
from databricks.labs.dqx.config import (
    FileChecksStorageConfig,
    WorkspaceFileChecksStorageConfig,
    TableChecksStorageConfig,
    VolumeFileChecksStorageConfig,
)
from pyspark.sql import SparkSession


def glob_to_regex(glob_pattern: str) -> str:
    """Translate a dot-prefixed glob such as '.tmp_*' into an anchored regex string.

    The leading dot is dropped, every remaining character is matched literally,
    and each '*' stands for any run of characters (including none).

    Raises:
        ValueError: if the pattern is empty or does not begin with '.'.
    """
    has_dot_prefix = bool(glob_pattern) and glob_pattern.startswith('.')
    if not has_dot_prefix:
        raise ValueError("Exclude pattern must start with a dot, e.g. '.tamarack_*'")
    # Escape everything first, then stitch the pieces around each escaped
    # asterisk back together with the regex wildcard.
    literal_chunks = re.escape(glob_pattern[1:]).split(r'\*')
    return f"^{'.*'.join(literal_chunks)}$"


# Keys currently shown in DQX docs for profiler "options" (kept permissive; we only warn on extras)
# Used by RuleGenerator._profile_call_kwargs purely for reporting: option keys that are not
# in this set are listed in an INFO message but are still forwarded to the profiler.
DOC_SUPPORTED_KEYS = {
    "sample_fraction",
    "sample_seed",
    "limit",
    "remove_outliers",
    "outlier_columns",
    "num_sigmas",
    "max_null_ratio",
    "trim_strings",
    "max_empty_ratio",
    "distinct_ratio",
    "max_in_count",
    # NOTE(review): the notebook's first options cell marks "round" as removed / not valid
    # for the installed profiler version, yet it is still listed here — confirm.
    "round",
}

# ---------- YAML file support (header-comment parsing for defaults) ----------
# Matches a trailing '.yaml' or '.yml' extension, ignoring case.
_YAML_PATH_RE = re.compile(r"\.(ya?ml)$", re.IGNORECASE)
# Captures <catalog> from 'FROM <catalog>.information_schema.tables'; the catalog may only
# consist of letters, digits and underscores.
_FROM_INFOSCHEMA = re.compile(r"FROM\s+([A-Za-z0-9_]+)\.information_schema\.tables", re.IGNORECASE)
# Captures the single-quoted value from "table_schema = '<schema>'".
_TABLE_SCHEMA_EQ = re.compile(r"table_schema\s*=\s*'([^']+)'", re.IGNORECASE)

def _is_yaml_path(path: str) -> bool:
    """Return True when *path* ends in '.yaml' or '.yml' (case-insensitive)."""
    extension_hit = _YAML_PATH_RE.search(path)
    return extension_hit is not None

def _resolve_local_like_path(path: str) -> Optional[str]:
    """
    Try hard to resolve a repo-relative file path from a notebook:
    - as-given
    - join with cwd
    - walk up 6 parents and join
    Returns a filesystem path if found, else None.
    """
    candidates = []
    if os.path.exists(path):
        return path
    cwd = os.getcwd()
    base = cwd
    for _ in range(6):
        cand = os.path.abspath(os.path.join(base, path))
        candidates.append(cand)
        if os.path.exists(cand):
            return cand
        parent = os.path.dirname(base)
        if parent == base:
            break
        base = parent
    return None

def _read_text_any(path: str) -> str:
    """
    Read a small text file and return its contents as a string.

    Supported locations, checked in this order:
      - 'dbfs:/...', '/dbfs/...' or '/Volumes/...' -> dbutils.fs.head, which
        returns at most 10 MiB (larger files come back truncated);
      - any other path starting with '/' -> Files API download via WorkspaceClient,
        decoded as UTF-8;
      - anything else -> repo/local filesystem via _resolve_local_like_path.

    Raises:
        RuntimeError: dbutils cannot be imported for a DBFS/Volumes path.
        FileNotFoundError: a relative/local path could not be resolved to a file.
    """
    # DBFS / Volumes
    if path.startswith(("dbfs:/", "/dbfs/", "/Volumes/")):
        try:
            from databricks.sdk.runtime import dbutils
        except Exception as e:
            raise RuntimeError("dbutils is required to read from DBFS/Volumes") from e
        if path.startswith("dbfs:"):
            target = path
        elif path.startswith("/dbfs/"):
            # '/dbfs/...' is the driver-local mount of the DBFS root, so the mount
            # prefix must be dropped: '/dbfs/tmp/x' -> 'dbfs:/tmp/x'
            # (prefixing blindly would yield the non-existent 'dbfs:/dbfs/tmp/x').
            target = "dbfs:" + path[len("/dbfs"):]
        else:
            target = f"dbfs:{path}"
        return dbutils.fs.head(target, 10 * 1024 * 1024)

    # Absolute workspace file path (e.g. /Workspace/Repos/.../file.yaml)
    if path.startswith("/"):
        wc = WorkspaceClient()
        try:
            # newer SDK signature
            resp = wc.files.download(file_path=path)
        except TypeError:
            # older SDK signature
            resp = wc.files.download(path=path)
        # NOTE(review): current SDK versions appear to return a response wrapper that
        # exposes the byte stream as `.contents`; fall back to the object itself for
        # versions that return a readable stream directly — confirm against the SDK in use.
        stream = getattr(resp, "contents", None) or resp
        data = stream.read()
        return data.decode("utf-8")

    # Relative/local resolution (repo-friendly)
    resolved = _resolve_local_like_path(path)
    if resolved and os.path.isfile(resolved):
        with open(resolved, "r", encoding="utf-8") as fh:
            return fh.read()

    # Nothing worked → helpful error
    msg = [
        f"Could not find YAML at '{path}'.",
        f"cwd={os.getcwd()}",
        "Hints:",
        "  - If this file is in your repo, pass a path relative to the notebook or the repo root.",
        "  - Or pass an absolute workspace path like '/Workspace/Repos/.../file.yaml'.",
        "  - Or use a DBFS/Volumes path like 'dbfs:/...' or '/Volumes/...'.",
    ]
    raise FileNotFoundError("\n".join(msg))

def _parse_global_hints_from_comments(text: str) -> Tuple[Optional[str], Optional[str]]:
    """
    Extract default catalog/schema hints from the raw file text (no YAML keys involved).

      - catalog: first match of 'FROM <catalog>.information_schema.tables'
      - schema:  first match of "table_schema = '<schema>'"

    The patterns are searched across the whole text, not only comment lines.
    Returns (catalog, schema); either element is None when its pattern is absent.
    """
    def _first_capture(pattern) -> Optional[str]:
        hit = pattern.search(text)
        if hit is None:
            return None
        return hit.group(1)

    return _first_capture(_FROM_INFOSCHEMA), _first_capture(_TABLE_SCHEMA_EQ)

def _ensure_fqns(names: List[str], hint_catalog: Optional[str], hint_schema: Optional[str]) -> List[str]:
    """
    Accept: catalog.schema.table | schema.table | table
    Use comment-derived defaults to fill missing parts. Dotted entries override per item.
    """
    out: List[str] = []
    for n in names:
        n = n.strip()
        if not n:
            continue
        parts = n.split(".")
        if len(parts) == 3:
            out.append(n)
        elif len(parts) == 2:
            if not hint_catalog:
                raise ValueError(f"'{n}' lacks catalog; add it to the item or provide a comment default.")
            out.append(f"{hint_catalog}.{n}")
        elif len(parts) == 1:
            if not (hint_catalog and hint_schema):
                raise ValueError(f"'{n}' needs catalog & schema; set via comments or use dotted forms.")
            out.append(f"{hint_catalog}.{hint_schema}.{n}")
        else:
            raise ValueError(f"Unrecognized table format: {n}")
    bad = [t for t in out if t.count(".") != 2]
    if bad:
        raise ValueError(f"Invalid FQN(s) after resolution: {bad}")
    return sorted(set(out))

def _discover_tables_from_yaml_by_comments(yaml_path: str) -> List[str]:
    """
    Read a YAML table list and return fully qualified table names.

    The file must parse to a mapping holding a list under one of the keys
    'table_name', 'tables', 'table_names' or 'list' (first match wins). Catalog and
    schema defaults come only from the text patterns recognised by
    _parse_global_hints_from_comments; dotted entries override them per item.

    Raises:
        ValueError: the YAML is not a mapping, or no non-empty table list was found.
    """
    raw_text = _read_text_any(yaml_path)
    default_catalog, default_schema = _parse_global_hints_from_comments(raw_text)

    document = yaml.safe_load(io.StringIO(raw_text))
    if not isinstance(document, dict):
        raise ValueError(f"YAML must contain a mapping with a list; got: {type(document).__name__}")

    list_key = next(
        (k for k in ("table_name", "tables", "table_names", "list") if isinstance(document.get(k), list)),
        None,
    )
    # Falsy entries (None, '', 0) are dropped before stringifying.
    names = None if list_key is None else [str(item).strip() for item in document[list_key] if item]
    if not names:
        raise ValueError(f"No table list found in YAML: {yaml_path}")

    resolved = _ensure_fqns(names, default_catalog, default_schema)
    print(f"[INFO] Parsed defaults from comments: catalog={default_catalog} schema={default_schema}")
    return resolved

def _display_table_preview(spark: SparkSession, fqns: List[str], title: str = "Resolved Tables") -> None:
    """
    Show the resolved table list as a four-column DataFrame (fqn, catalog, schema, table).

    Uses display() when it is defined (Databricks notebooks) and falls back to
    DataFrame.show() otherwise.

    Names that do not split into exactly three dot-separated parts are still listed,
    with catalog/schema/table left null, so that a malformed entry cannot break the
    preview (every row must have four values to match the DataFrame schema).
    """
    rows = []
    for fqn in fqns:
        parts = fqn.split(".")
        if len(parts) == 3:
            rows.append((fqn, parts[0], parts[1], parts[2]))
        else:
            rows.append((fqn, None, None, None))
    df = spark.createDataFrame(rows, "fqn string, catalog string, schema string, table string")
    print(f"\n=== {title} ({len(fqns)}) ===")
    try:
        display(df)
    except NameError:
        # display() only exists inside a Databricks notebook.
        df.show(len(fqns), truncate=False)


class RuleGenerator:
    """
    Profile one or more tables with DQX and persist the generated checks.

    Tables are discovered according to `mode` ("pipeline", "catalog", "schema" or
    "table"), each readable table is profiled, the resulting constraints are turned
    into `sql_expression` check dicts, and the checks are saved either as YAML
    (one file per table when `output_location` is a folder) or appended to a table.
    """

    def __init__(
        self,
        mode: str,                       # "pipeline" | "catalog" | "schema" | "table"
        name_param: str,                 # pipeline CSV | catalog | catalog.schema | table FQN CSV | YAML path
        output_format: str,              # "yaml" | "table"
        output_location: str,            # yaml: folder or file; table: catalog.schema.table
        profile_options: Dict[str, Any],
        exclude_pattern: Optional[str] = None,     # e.g. ".tmp_*"
        created_by: Optional[str] = "LMG",
        columns: Optional[List[str]] = None,       # None => whole table (only if mode=="table")
        run_config_name: str = "default",          # DQX run group tag
        criticality: str = "warn",                 # "warn" | "error"
        key_order: Literal["engine", "custom"] = "custom",  # "engine" uses DQX save; "custom" enforces key order
        include_table_name: bool = True,           # include table_name in each rule dict
    ):
        """
        Store the run configuration and validate the output settings.

        Raises:
            RuntimeError: no active Spark session is available.
            ValueError: `output_format` is not 'yaml'/'table', or 'yaml' was chosen
                without an `output_location`.
        """
        self.mode = mode.lower().strip()
        self.name_param = name_param
        self.output_format = output_format.lower().strip()
        self.output_location = output_location
        self.profile_options = profile_options or {}
        self.exclude_pattern = exclude_pattern
        self.created_by = created_by
        self.columns = columns
        self.run_config_name = run_config_name
        self.criticality = criticality
        self.key_order = key_order
        self.include_table_name = include_table_name

        self.spark = SparkSession.getActiveSession()
        if not self.spark:
            raise RuntimeError("No active Spark session found. Run this in a Databricks notebook.")

        if self.output_format not in {"yaml", "table"}:
            raise ValueError("output_format must be 'yaml' or 'table'.")
        if self.output_format == "yaml" and not self.output_location:
            raise ValueError("When output_format='yaml', provide output_location (folder or file).")

    # ---------- profile options: pass via 'options' kwarg, warn on unknowns ----------
    def _profile_call_kwargs(self) -> Dict[str, Any]:
        """
        Build kwargs for DQProfiler.profile / profile_table.
        We pass:
          - cols=self.columns (when provided)
          - options=self.profile_options (dict; profiler reads keys internally)
        We only WARN on keys not in the current documented set; we do not drop them.
        """
        kwargs: Dict[str, Any] = {}
        if self.columns is not None:
            kwargs["cols"] = self.columns
        if self.profile_options:
            unknown = sorted(set(self.profile_options) - DOC_SUPPORTED_KEYS)
            if unknown:
                print(f"[INFO] Profiling options not in current docs (passing through anyway): {unknown}")
            kwargs["options"] = self.profile_options
        return kwargs

    # ---------- discovery ----------
    def _exclude_tables_by_pattern(self, fq_tables: List[str]) -> List[str]:
        """
        Drop tables whose last dot-separated segment matches `exclude_pattern`.

        The pattern is a dot-prefixed glob (see glob_to_regex); with no pattern set
        the input list is returned unchanged.
        """
        if not self.exclude_pattern:
            return fq_tables
        pattern = re.compile(glob_to_regex(self.exclude_pattern))
        # Only the table segment is matched, never catalog or schema.
        filtered = [fq for fq in fq_tables if not pattern.match(fq.split('.')[-1])]
        print(f"[INFO] Excluded {len(fq_tables) - len(filtered)} tables by pattern '{self.exclude_pattern}'")
        return filtered

    def _discover_tables(self) -> List[str]:
        """
        Print the run parameters, resolve the table list for the configured mode,
        apply the exclude pattern, preview the result and return it sorted and
        de-duplicated.

        Raises:
            ValueError: invalid mode, `columns` used outside mode='table', or a
                malformed `name_param` for the chosen mode.
            RuntimeError: a named pipeline was not found or has no updates.
        """
        print("\n===== PARAMETERS PASSED THIS RUN =====")
        print(f"mode:             {self.mode}")
        print(f"name_param:       {self.name_param}")
        print(f"output_format:    {self.output_format}")
        print(f"output_location:  {self.output_location}")
        print(f"exclude_pattern:  {self.exclude_pattern}")
        print(f"created_by:       {self.created_by}")
        print(f"columns:          {self.columns}")
        print(f"run_config_name:  {self.run_config_name}")
        print(f"criticality:      {self.criticality}")
        print(f"key_order:        {self.key_order}")
        print(f"include_table_name: {self.include_table_name}")
        print("profile_options:")
        for k, v in self.profile_options.items():
            print(f"  {k}: {v}")
        print("======================================\n")

        allowed_modes = {"pipeline", "catalog", "schema", "table"}
        if self.mode not in allowed_modes:
            raise ValueError(f"Invalid mode '{self.mode}'. Must be one of: {sorted(allowed_modes)}.")
        if self.columns is not None and self.mode != "table":
            raise ValueError("The 'columns' parameter can only be used in mode='table'.")

        discovered: List[str] = []

        if self.mode == "pipeline":
            print("Searching for pipeline output tables...")
            ws = WorkspaceClient()
            pipelines = [p.strip() for p in self.name_param.split(",") if p.strip()]
            print(f"Pipelines passed: {pipelines}")
            # List the workspace pipelines once instead of once per requested name.
            all_pipelines = list(ws.pipelines.list_pipelines()) if pipelines else []
            for pipeline_name in pipelines:
                print(f"Finding output tables for pipeline: {pipeline_name}")
                pl = next((p for p in all_pipelines if p.name == pipeline_name), None)
                if not pl:
                    raise RuntimeError(f"Pipeline '{pipeline_name}' not found via SDK.")
                latest_updates = getattr(pl, "latest_updates", None) or []
                if not latest_updates:
                    # Without an update there is no event stream to read flow names from.
                    raise RuntimeError(
                        f"Pipeline '{pipeline_name}' has no updates; cannot determine its output tables."
                    )
                latest_update = latest_updates[0].update_id
                events = ws.pipelines.list_pipeline_events(pipeline_id=pl.pipeline_id, max_results=250)
                # Keep the flow names reported by events belonging to the latest update.
                pipeline_tables = [
                    getattr(ev.origin, "flow_name", None)
                    for ev in events
                    if getattr(ev.origin, "update_id", None) == latest_update and getattr(ev.origin, "flow_name", None)
                ]
                discovered += [x for x in pipeline_tables if x]

        elif self.mode == "catalog":
            print("Searching for tables in catalog...")
            catalog = self.name_param.strip()
            schemas = [row.namespace for row in self.spark.sql(f"SHOW SCHEMAS IN {catalog}").collect()]
            for s in schemas:
                tbls = self.spark.sql(f"SHOW TABLES IN {catalog}.{s}").collect()
                discovered += [f"{catalog}.{s}.{row.tableName}" for row in tbls]

        elif self.mode == "schema":
            print("Searching for tables in schema...")
            if self.name_param.count(".") != 1:
                raise ValueError("For 'schema' mode, name_param must be catalog.schema")
            catalog, schema = self.name_param.strip().split(".")
            tbls = self.spark.sql(f"SHOW TABLES IN {catalog}.{schema}").collect()
            discovered = [f"{catalog}.{schema}.{row.tableName}" for row in tbls]

        else:  # table
            print("Profiling one or more specific tables...")
            if _is_yaml_path(self.name_param):
                print(f"[INFO] name_param is a YAML list → {self.name_param}")
                discovered = _discover_tables_from_yaml_by_comments(self.name_param)
            else:
                tables = [t.strip() for t in self.name_param.split(",") if t.strip()]
                for t in tables:
                    if t.count(".") != 2:
                        raise ValueError(f"Table name '{t}' must be fully qualified (catalog.schema.table)")
                discovered = tables

        print("\nRunning exclude pattern filtering (if any)...")
        # De-duplicate and sort before previewing so the preview shows exactly the
        # list that is returned.
        discovered = sorted(set(self._exclude_tables_by_pattern(discovered)))

        # Interactive preview instead of plain prints
        _display_table_preview(self.spark, discovered, title="Final table list to generate DQX rules for")

        print("==========================================\n")
        return discovered

    # ---------- storage config helpers ----------
    @staticmethod
    def _infer_file_storage_config(file_path: str):
        """Pick the DQX file storage config class from the path prefix."""
        if file_path.startswith("/Volumes/"):
            return VolumeFileChecksStorageConfig(location=file_path)
        if file_path.startswith("/"):
            return WorkspaceFileChecksStorageConfig(location=file_path)
        return FileChecksStorageConfig(location=file_path)

    @staticmethod
    def _table_storage_config(table_fqn: str, run_config_name: Optional[str] = None, mode: str = "append"):
        """Build the DQX table storage config for `table_fqn`."""
        return TableChecksStorageConfig(location=table_fqn, run_config_name=run_config_name, mode=mode)

    @staticmethod
    def _workspace_files_upload(path: str, payload: bytes) -> None:
        """Upload `payload` to `path` through the Files API, overwriting any existing file."""
        wc = WorkspaceClient()
        try:
            wc.files.upload(file_path=path, contents=payload, overwrite=True)  # newer SDK
        except TypeError:
            wc.files.upload(path=path, contents=payload, overwrite=True)       # older SDK

    @staticmethod
    def _ensure_parent(path: str) -> None:
        """Create parent dir for local paths."""
        parent = os.path.dirname(path)
        if parent and not os.path.exists(parent):
            os.makedirs(parent, exist_ok=True)

    @staticmethod
    def _dbfs_uri(path: str) -> str:
        """
        Normalise a DBFS/Volumes path to the 'dbfs:' form used with dbutils.fs.

          - 'dbfs:...'  -> returned unchanged
          - '/dbfs/x'   -> 'dbfs:/x' (the '/dbfs' mount prefix is dropped)
          - '/other/x'  -> 'dbfs:/other/x'
        """
        if path.startswith("dbfs:"):
            return path
        if path.startswith("/dbfs/"):
            return "dbfs:" + path[len("/dbfs"):]
        return f"dbfs:{path}"

    @staticmethod
    def _ensure_dbfs_parent(dbutils, path: str) -> None:
        """Create the directory that contains `path` (a full file path) via dbutils.fs.mkdirs."""
        parent = path.rsplit("/", 1)[0] if "/" in path else path
        # A bare scheme such as 'dbfs:' is the root and needs no mkdirs.
        if parent and not parent.endswith(":"):
            dbutils.fs.mkdirs(parent)

    # ---------- DQX check shaping ----------
    def _dq_constraint_to_check(
        self,
        rule_name: str,
        constraint_sql: str,
        table_name: str,
        criticality: str,
        run_config_name: str
    ) -> Dict[str, Any]:
        """
        Convert a profiler constraint (SQL) into a DQX check dict.
        Key order: table_name, name, criticality, run_config_name, check (insertion order).
        `table_name` is only present when `include_table_name` is set.
        """
        d = {
            "name": rule_name,
            "criticality": criticality,
            "run_config_name": run_config_name,
            "check": {
                "function": "sql_expression",
                "arguments": {
                    "expression": constraint_sql,
                    "name": rule_name,
                }
            },
        }
        if self.include_table_name:
            # Rebuild the dict so table_name becomes the first key.
            d = {"table_name": table_name, **d}
        return d

    # ---------- YAML writers ----------
    def _write_yaml_ordered(self, checks: List[Dict[str, Any]], path: str) -> None:
        """
        Dump YAML preserving key order and upload:
          - /Volumes/... | dbfs:/... | /dbfs/... -> dbutils.fs.put (mkdirs parent)
          - /Shared/... (workspace files) -> Files API
          - relative/local path -> os.makedirs + open(...)

        Raises:
            RuntimeError: dbutils cannot be imported for a DBFS/Volumes path.
        """
        yaml_str = yaml.safe_dump(checks, sort_keys=False, default_flow_style=False)

        # DBFS / Volumes
        if path.startswith(("dbfs:/", "/dbfs/", "/Volumes/")):
            try:
                from databricks.sdk.runtime import dbutils
            except Exception as e:
                raise RuntimeError("dbutils is required to write to DBFS/Volumes.") from e
            target = self._dbfs_uri(path)
            # Pass the full file path: _ensure_dbfs_parent derives the parent itself.
            self._ensure_dbfs_parent(dbutils, target)
            dbutils.fs.put(target, yaml_str, True)
            print(f"[RUN] Wrote ordered YAML to {path}")
            return

        # Workspace files
        if path.startswith("/"):
            self._workspace_files_upload(path, yaml_str.encode("utf-8"))
            print(f"[RUN] Wrote ordered YAML to workspace file: {path}")
            return

        # Local (driver) relative/absolute filesystem (Repos-friendly)
        full_path = os.path.abspath(path)
        self._ensure_parent(full_path)
        with open(full_path, "w", encoding="utf-8") as f:
            f.write(yaml_str)
        print(f"[RUN] Wrote ordered YAML to local path: {full_path}")

    # ---------- main ----------
    def run(self):
        """
        Discover tables, profile each one and save the generated checks.

        Tables that are unreadable, fail profiling or yield no checks are reported
        and skipped. Any other exception is caught and printed as '[ERROR] ...'
        rather than raised, so callers never see an exception from this method.
        """
        try:
            tables = self._discover_tables()
            print("[RUN] Beginning DQX rule generation on these tables:")
            for t in tables:
                print(f"  {t}")
            print("==========================================\n")

            call_kwargs = self._profile_call_kwargs()
            print("[RUN] Profiler call kwargs:")
            print(f"  cols:    {call_kwargs.get('cols')}")
            print(f"  options: {json.dumps(call_kwargs.get('options', {}), indent=2)}")

            # Build the DQX helpers once and reuse them for every table.
            ws = WorkspaceClient()
            dq_engine = DQEngine(ws)
            profiler = DQProfiler(ws)
            generator = DQDltGenerator(ws)
            total_checks = 0

            writes_single_file = (
                self.output_format == "yaml"
                and self.output_location.endswith((".yaml", ".yml"))
            )
            if writes_single_file and len(tables) > 1:
                print(
                    f"[WARN] output_location is a single file but {len(tables)} tables were resolved; "
                    "every table's checks are saved to that same path, so later tables may replace "
                    "earlier ones. Pass a folder to get one YAML per table."
                )

            for fq_table in tables:
                if fq_table.count(".") != 2:
                    print(f"[WARN] Skipping invalid table name: {fq_table}")
                    continue
                tab = fq_table.split(".")[2]

                # Verify readability
                try:
                    print(f"[RUN] Checking table readability: {fq_table}")
                    df = self.spark.table(fq_table)
                    df.limit(1).collect()
                except Exception as e:
                    print(f"[WARN] Table {fq_table} not readable in Spark: {e}")
                    continue

                try:
                    print(f"[RUN] Profiling and generating rules for: {fq_table}")
                    # DataFrame profiling with options and optional cols
                    summary_stats, profiles = profiler.profile(df, **call_kwargs)
                    # If you prefer table-based API:
                    # summary_stats, profiles = profiler.profile_table(table=fq_table, **call_kwargs)

                    rules_dict = generator.generate_dlt_rules(profiles, language="Python_Dict")
                except Exception as e:
                    print(f"[WARN] Profiling failed for {fq_table}: {e}")
                    continue

                checks: List[Dict[str, Any]] = [
                    self._dq_constraint_to_check(
                        rule_name=rule_name,
                        constraint_sql=constraint,
                        table_name=fq_table,
                        criticality=self.criticality,
                        run_config_name=self.run_config_name,
                    )
                    for rule_name, constraint in (rules_dict or {}).items()
                ]

                if not checks:
                    print(f"[INFO] No checks generated for {fq_table}.")
                    continue

                # Destination selection
                if self.output_format == "yaml":            # "yaml" | "table"
                    # Directory -> {table}.yaml ; or exact file path
                    if writes_single_file:
                        path = self.output_location
                    else:
                        path = f"{self.output_location.rstrip('/')}/{tab}.yaml"

                    if self.key_order == "engine":
                        cfg = self._infer_file_storage_config(path)
                        print(f"[RUN] Saving {len(checks)} checks via DQX to: {path}")
                        dq_engine.save_checks(checks, config=cfg)
                    else:  # "custom"
                        print(f"[RUN] Saving {len(checks)} checks with strict key order to: {path}")
                        self._write_yaml_ordered(checks, path)

                    total_checks += len(checks)

                else:  # table sink
                    cfg = self._table_storage_config(
                        table_fqn=self.output_location,
                        run_config_name=self.run_config_name,
                        mode="append"
                    )
                    print(f"[RUN] Appending {len(checks)} checks to table: {self.output_location} (run_config_name={self.run_config_name})")
                    dq_engine.save_checks(checks, config=cfg)
                    total_checks += len(checks)

            print(f"[RUN] {'Successfully saved' if total_checks else 'No'} checks. Count: {total_checks}")
        except Exception as e:
            print(f"[ERROR] Rule generation failed: {e}")


# -------------------- Usage examples --------------------
# Runs when this cell/file is executed directly: builds a profiler options dict and
# launches one RuleGenerator run (Example B); Example A is kept as an inert string.
if __name__ == "__main__":
    profile_options = {
        # Sampling
        "sample_fraction": 0.3,
        "sample_seed": 42,
        "limit": 1000,
        # Outliers
        "remove_outliers": True,
        "outlier_columns": [],
        "num_sigmas": 3,
        # Nulls / empties
        "max_null_ratio": 0.01,
        "trim_strings": True,
        "max_empty_ratio": 0.01,
        # Distincts → is_in
        "distinct_ratio": 0.05,
        "max_in_count": 10,
        # Rounding
        # NOTE(review): the notebook's first options cell marks "round" as removed / not
        # valid for the installed profiler version — confirm it should still be passed here.
        "round": True,
        # (Keys not in DOC_SUPPORTED_KEYS are still passed through; RuleGenerator prints an INFO line listing them)
        # "include_histograms": False,
        # "min_length": None,
        # "max_length": None,
        # "min_value": None,
        # "max_value": None,
        # "profile_types": None,
    }
    """
    # Example A: Single table (unchanged)
    RuleGenerator(
        mode="table",                                   # "pipeline" | "catalog" | "schema" | "table"
        name_param="dq_prd.monitoring.job_run_audit",   # depends on mode
        output_format="yaml",                           # "yaml" | "table"
        output_location="dqx_checks",                   # yaml dir (local) OR "/Shared/..." OR "dbfs:/..." OR "/Volumes/..."
        profile_options=profile_options,
        columns=None,                                   # None => whole table (only valid when mode=="table")
        exclude_pattern=None,                           # e.g. ".tmp_*"
        created_by="LMG",
        run_config_name="default",
        criticality="error",
        key_order="custom",                             # "engine" or "custom"
        include_table_name=True,
    ).run()
    """
    # Example B: YAML list of tables; catalog/schema defaults are parsed from the file text
    #   Notebook path: src/dqx/00_generate_dqx_checks
    #   YAML file path: resources/dqx_checks_generated/audit/audit_table_list-gold.yaml
    #     (relative; looked up from the working directory and then its parent directories)
    #   Output folder:  resources/dqx_checks_generated/audit/audit_generated_dqx_checks
    RuleGenerator(
        mode="table",                                   # "pipeline" | "catalog" | "schema" | "table"
        name_param="resources/dqx_checks_generated/audit/audit_table_list-gold.yaml",
        output_format="yaml",                           # "yaml" | "table"
        output_location="resources/dqx_checks_generated/audit/audit_generated_dqx_checks",  # one YAML per table
        profile_options=profile_options,
        columns=None,                                   # None => whole table (only valid when mode=="table")
        exclude_pattern=None,                           # e.g. ".tmp_*"
        created_by="LMG",
        run_config_name="default",
        criticality="error",
        key_order="custom",                             # "engine" or "custom"
        include_table_name=True,
    ).run()

## Testing

In [0]:
# generate_checks.py
# Generate DQX checks from tables or YAML table-lists and write to YAML, table, or both.

import os
import re
import io
import json
import hashlib
import datetime
from typing import List, Optional, Dict, Any, Literal, Tuple

import yaml
from databricks.sdk import WorkspaceClient
from databricks.labs.dqx.profiler.profiler import DQProfiler
from databricks.labs.dqx.profiler.dlt_generator import DQDltGenerator
from databricks.labs.dqx.engine import DQEngine
from databricks.labs.dqx.config import (
    FileChecksStorageConfig,
    WorkspaceFileChecksStorageConfig,
    TableChecksStorageConfig,
    VolumeFileChecksStorageConfig,
)
from pyspark.sql import SparkSession, types as T
from pyspark.sql import DataFrame  # for type hints in show_df()

# Notebook env helper (prints banner in the notebook and returns a dict we can reuse)
from utils.print import print_notebook_env, get_notebook_path as _nb_path

# --------------------------------------------------------------------------------------
# Documentation dictionary for the generated checks table (apply on first create)
# --------------------------------------------------------------------------------------

# Descriptive metadata for the generated-checks table: a default table name, a
# Markdown table comment, and one Markdown description per column. The column keys
# mirror the field names of DQX_GENERATED_CHECKS_CONFIG_SCHEMA.
DQX_GENERATED_CHECKS_CONFIG_METADATA: Dict[str, Any] = {
    "table": "dq_dev.dqx.checks_generated_config",  # will be overridden by the actual FQN you pass
    "table_comment": (
        "# DQX *Generated* Checks Configuration\n"
        "- Stores flattened rules generated by the profiler.\n"
        "- Each row is a rule; `check_id` is a stable hash of the canonical payload.\n"
        "- `generator_meta` captures the profiler options and generator settings used to create these rows.\n"
    ),
    "columns": {
        "check_id": "SHA-256 **hash** of canonical payload (stable rule identity).",
        "check_id_payload": "Canonical **JSON** used to compute `check_id`.",
        "table_name": "Fully qualified **target table** (`catalog.schema.table`).",

        "name": "Human-readable **rule name**.",
        "criticality": "Rule severity: `warn|warning|error`.",
        "check": "Structured **check** object: `{function, for_each_column, arguments}`.",
        "filter": "Optional row-level **filter** expression.",
        "run_config_name": "**Execution group/tag** for this rule.",
        "user_metadata": "User-provided **metadata** `map<string,string>`.",

        "yaml_path": "YAML **file path** that held this rule (or `<generated://...>` if ephemeral).",
        "active": "If **false**, the rule is ignored by runners.",

        # NEW (placed after active, before audit)
        "generator_meta": (
            "Array of two items: "
            "`[{section:'profile_options', payload:map}, {section:'generator_settings', payload:map}]`."
        ),

        "created_by": "Audit: **creator** of the row.",
        "created_at": "Audit: **creation timestamp** (UTC ISO).",
        "updated_by": "Audit: **last updater**.",
        "updated_at": "Audit: **last update timestamp**.",
    },
}

# Second name bound to the very same dict object (not a copy).
DQXGeneratedChecksConfig = DQX_GENERATED_CHECKS_CONFIG_METADATA

# --------------------------------------------------------------------------------------
# Schema (unified) for the generated checks table (adds generator_meta)
# --------------------------------------------------------------------------------------

# Spark schema for one generated-check row. The third argument of each StructField is
# its nullability: False marks a required column, True an optional one.
DQX_GENERATED_CHECKS_CONFIG_SCHEMA = T.StructType([
    T.StructField("check_id",            T.StringType(),  False),
    T.StructField("check_id_payload",    T.StringType(),  False),
    T.StructField("table_name",          T.StringType(),  False),

    # DQX fields
    T.StructField("name",                T.StringType(),  False),
    T.StructField("criticality",         T.StringType(),  False),
    # Nested check definition; argument values are typed as strings (map<string,string>).
    T.StructField("check", T.StructType([
        T.StructField("function",        T.StringType(),  False),
        T.StructField("for_each_column", T.ArrayType(T.StringType()), True),
        T.StructField("arguments",       T.MapType(T.StringType(), T.StringType()), True),
    ]), False),
    T.StructField("filter",              T.StringType(),  True),
    T.StructField("run_config_name",     T.StringType(),  False),
    T.StructField("user_metadata",       T.MapType(T.StringType(), T.StringType()), True),

    # Ops
    T.StructField("yaml_path",           T.StringType(),  False),
    T.StructField("active",              T.BooleanType(), False),

    # NEW: meta goes right here (as requested, before audit)
    T.StructField("generator_meta", T.ArrayType(T.StructType([
        T.StructField("section", T.StringType(), False),  # "profile_options" | "generator_settings"
        T.StructField("payload", T.MapType(T.StringType(), T.StringType()), True),
    ])), True),

    # Audit
    T.StructField("created_by",          T.StringType(),  False),
    T.StructField("created_at",          T.StringType(),  False),  # ISO string; cast downstream if needed
    T.StructField("updated_by",          T.StringType(),  True),
    T.StructField("updated_at",          T.StringType(),  True),
])

# --------------------------------------------------------------------------------------
# Constants / helpers (kept intact)
# --------------------------------------------------------------------------------------

# Profiler option names regarded as documented.
# NOTE(review): "round" is listed although the notebook's first options cell marks it
# as removed / not valid for the installed profiler version — confirm.
DOC_SUPPORTED_KEYS = {
    "sample_fraction", "sample_seed", "limit",
    "remove_outliers", "outlier_columns", "num_sigmas",
    "max_null_ratio", "trim_strings", "max_empty_ratio",
    "distinct_ratio", "max_in_count", "round",
}

# Matches a trailing '.yaml' or '.yml' extension, ignoring case.
_YAML_PATH_RE = re.compile(r"\.(ya?ml)$", re.IGNORECASE)
# Captures <catalog> from 'FROM <catalog>.information_schema.tables' (letters, digits, '_' only).
_FROM_INFOSCHEMA = re.compile(r"FROM\s+([A-Za-z0-9_]+)\.information_schema\.tables", re.IGNORECASE)
# Captures the single-quoted value from "table_schema = '<schema>'".
_TABLE_SCHEMA_EQ = re.compile(r"table_schema\s*=\s*'([^']+)'", re.IGNORECASE)

def _is_yaml_path(p: str) -> bool:
    """Return True when `p` ends with a `.yaml` or `.yml` extension (any letter case)."""
    hit = _YAML_PATH_RE.search(p)
    return hit is not None

def _esc_sql_comment(s: str) -> str:
    return s.replace("'", "''")

def _safe_json(obj: Any) -> str:
    return json.dumps(obj, sort_keys=True, separators=(",", ":"))

def _resolve_local_like_path(path: str) -> Optional[str]:
    """Resolve a repo/local-like path by walking up a few parents."""
    if os.path.exists(path):
        return os.path.abspath(path)
    base = os.getcwd()
    for _ in range(6):
        cand = os.path.abspath(os.path.join(base, path))
        if os.path.exists(cand):
            return cand
        parent = os.path.dirname(base)
        if parent == base:
            break
        base = parent
    return None

def _read_text_any(path: str) -> str:
    """Read a text file from DBFS/Volumes, Workspace Files, or a local/repo path.

    Dispatch is by path prefix:
      * ``dbfs:/``, ``/dbfs/`` or ``/Volumes/`` -> ``dbutils.fs.head`` (at most 10 MiB requested)
      * any other absolute path                  -> ``WorkspaceClient().files.download``
      * anything else                            -> local file, resolved via ``_resolve_local_like_path``

    Raises:
        RuntimeError: dbutils cannot be imported for a DBFS/Volumes path.
        FileNotFoundError: a local/relative path cannot be resolved to a file.
    """
    # DBFS / Volumes
    if path.startswith(("dbfs:/", "/dbfs/", "/Volumes/")):
        try:
            from databricks.sdk.runtime import dbutils
        except Exception as e:
            raise RuntimeError("dbutils is required to read from DBFS/Volumes") from e
        if path.startswith("dbfs:/"):
            target = path
        elif path.startswith("/dbfs/"):
            # "/dbfs/<x>" is the local mount of "dbfs:/<x>"; drop the mount prefix
            # instead of producing "dbfs:/dbfs/<x>".
            target = "dbfs:" + path[len("/dbfs"):]
        else:
            target = "dbfs:" + path
        # NOTE(review): only the first 10 MiB are requested; presumably larger files are cut off.
        return dbutils.fs.head(target, 10 * 1024 * 1024)

    # Workspace Files (absolute)
    if path.startswith("/"):
        wc = WorkspaceClient()
        try:
            resp = wc.files.download(file_path=path)
        except TypeError:
            # Fallback for SDK versions that name the parameter `path`.
            resp = wc.files.download(path=path)
        # NOTE(review): the SDK's DownloadResponse exposes the byte stream as `.contents`;
        # fall back to the response object itself if that attribute is absent.
        stream = getattr(resp, "contents", None) or resp
        return stream.read().decode("utf-8")

    # Local / repo-relative
    resolved = _resolve_local_like_path(path)
    if resolved and os.path.isfile(resolved):
        with open(resolved, "r", encoding="utf-8") as fh:
            return fh.read()

    raise FileNotFoundError(f"Could not read file: {path}")

def _ensure_parent_local(path: str) -> None:
    parent = os.path.dirname(path)
    if parent and not os.path.exists(parent):
        os.makedirs(parent, exist_ok=True)

def _to_dbfs_target(path: str) -> str:
    if path.startswith("dbfs:/"):
        return path
    if path.startswith("/dbfs/") or path.startswith("/Volumes/"):
        return "dbfs:" + path
    return path

def _write_text_any(path: str, payload: str) -> None:
    """Write `payload` as text to DBFS/Volumes, Workspace Files, or a local path.

    Dispatch is by path prefix:
      * ``dbfs:/``, ``/dbfs/`` or ``/Volumes/`` -> ``dbutils.fs.put`` with overwrite,
        after ``dbutils.fs.mkdirs`` on the parent
      * any other absolute path                  -> ``WorkspaceClient().files.upload`` with overwrite
      * anything else                            -> local file (parent directory created first)

    Raises:
        RuntimeError: dbutils cannot be imported for a DBFS/Volumes path.
    """
    # DBFS / Volumes
    if path.startswith(("dbfs:/", "/dbfs/", "/Volumes/")):
        try:
            from databricks.sdk.runtime import dbutils
        except Exception as e:
            raise RuntimeError("dbutils is required to write to DBFS/Volumes.") from e
        if path.startswith("dbfs:/"):
            target = path
        elif path.startswith("/dbfs/"):
            # "/dbfs/<x>" is the local mount of "dbfs:/<x>"; drop the mount prefix
            # instead of producing "dbfs:/dbfs/<x>".
            target = "dbfs:" + path[len("/dbfs"):]
        else:
            target = "dbfs:" + path
        parent = target.rsplit("/", 1)[0]
        if parent:
            dbutils.fs.mkdirs(parent)
        dbutils.fs.put(target, payload, True)
        return

    # Workspace Files
    if path.startswith("/"):
        wc = WorkspaceClient()
        data = payload.encode("utf-8")
        # The SDK types `contents` as a binary stream, so wrap the bytes; a fresh
        # stream is built per attempt so the fallback never sees a consumed one.
        try:
            wc.files.upload(file_path=path, contents=io.BytesIO(data), overwrite=True)
        except TypeError:
            # Fallback for SDK versions that name the parameter `path`.
            wc.files.upload(path=path, contents=io.BytesIO(data), overwrite=True)
        return

    # Local (Repos/driver)
    full = os.path.abspath(path)
    _ensure_parent_local(full)
    with open(full, "w", encoding="utf-8") as fh:
        fh.write(payload)

def _parse_global_hints_from_comments(text: str) -> Tuple[Optional[str], Optional[str]]:
    """Extract default (catalog, schema) hints from `text`.

    The catalog comes from the first ``FROM <catalog>.information_schema.tables``
    match and the schema from the first ``table_schema = '<schema>'`` match;
    each element is None when its pattern does not occur.
    """
    catalog_hit = _FROM_INFOSCHEMA.search(text)
    schema_hit = _TABLE_SCHEMA_EQ.search(text)
    catalog = None if catalog_hit is None else catalog_hit.group(1)
    schema = None if schema_hit is None else schema_hit.group(1)
    return (catalog, schema)

def _ensure_fqns(names: List[str], hint_catalog: Optional[str], hint_schema: Optional[str]) -> List[str]:
    out: List[str] = []
    for n in names:
        n = n.strip()
        if not n:
            continue
        parts = n.split(".")
        if len(parts) == 3:
            out.append(n)
        elif len(parts) == 2:
            if not hint_catalog:
                raise ValueError(f"'{n}' lacks catalog; add it or provide a comment default.")
            out.append(f"{hint_catalog}.{n}")
        elif len(parts) == 1:
            if not (hint_catalog and hint_schema):
                raise ValueError(f"'{n}' needs catalog & schema; add comments or use dotted forms.")
            out.append(f"{hint_catalog}.{hint_schema}.{n}")
        else:
            raise ValueError(f"Unrecognized table format: {n}")
    return sorted(set(out))

def _discover_tables_from_yaml_file(yaml_path: str) -> List[str]:
    """Load a table list from a YAML file and return it as fully-qualified names.

    The document must be a mapping. The first of the keys `table_name`,
    `tables`, `table_names`, `list` that holds a list is used; falsy items are
    dropped. Catalog/schema defaults for partially-qualified names are taken
    from hints found in the file text.

    Raises:
        ValueError: the YAML is not a mapping, or no non-empty table list is found.
    """
    raw_text = _read_text_any(yaml_path)
    catalog_hint, schema_hint = _parse_global_hints_from_comments(raw_text)

    parsed = yaml.safe_load(io.StringIO(raw_text))
    if not isinstance(parsed, dict):
        raise ValueError(f"YAML must contain a mapping with a list; got: {type(parsed).__name__}")

    candidate_keys = ("table_name", "tables", "table_names", "list")
    first_list = next(
        (parsed[key] for key in candidate_keys if isinstance(parsed.get(key), list)),
        None,
    )
    names = None
    if first_list is not None:
        names = [str(item).strip() for item in first_list if item]
    if not names:
        raise ValueError(f"No table list found in YAML: {yaml_path}")

    return _ensure_fqns(names, catalog_hint, schema_hint)

def _prefix_of(table_fqn: str) -> str:
    """Prefix up to first underscore of the *table* portion."""
    base = table_fqn.split(".")[-1]
    return base.split("_", 1)[0].lower() if base else ""

def _filter_by_prefix_regex(tables: List[str], exclude_prefix_regex: Optional[str]) -> List[str]:
    if not exclude_prefix_regex:
        return tables
    pat = re.compile(exclude_prefix_regex, re.IGNORECASE)
    keep: List[str] = []
    for t in tables:
        if not pat.search(_prefix_of(t)):
            keep.append(t)
    return keep

def _display_table_preview(spark: SparkSession, fqns: List[str], title: str = "Resolved Tables") -> None:
    """Show the resolved table names as a four-column DataFrame (fqn, catalog, schema, table).

    Uses `display()` when that name exists, otherwise falls back to `df.show`.
    Names that are not exactly three-part no longer break the preview: missing
    leading parts are shown as null, and anything past the second dot stays in
    the `table` column.
    """
    rows = []
    for fqn in fqns:
        parts = fqn.split(".", 2)
        # Left-pad so the last part always lands in the `table` column.
        padded = [None] * (3 - len(parts)) + parts
        rows.append((fqn, *padded))
    df = spark.createDataFrame(rows, "fqn string, catalog string, schema string, table string")
    print(f"\n=== {title} ({len(fqns)}) ===")
    try:
        display(df)
    except NameError:
        # `display` is not defined in this environment.
        df.show(len(fqns), truncate=False)

def _now_iso() -> str:
    return datetime.datetime.utcnow().replace(microsecond=0).isoformat() + "Z"

def _stringify_map_values(d: Optional[Dict[str, Any]]) -> Dict[str, str]:
    out: Dict[str, str] = {}
    for k, v in (d or {}).items():
        if isinstance(v, (list, dict)):
            out[k] = _safe_json(v)
        elif isinstance(v, bool):
            out[k] = "true" if v else "false"
        elif v is None:
            out[k] = "null"
        else:
            out[k] = str(v)
    return out

def _compute_check_id_payload(table_name: str, check_dict: Dict[str, Any], filter_str: Optional[str]) -> str:
    """Build the canonical JSON string that identifies a check.

    Canonicalization:
      * table name is lower-cased,
      * filter has its whitespace runs collapsed to single spaces ("" when empty),
      * `for_each_column` is stringified and sorted (None when absent or empty),
      * argument values are stringified and stripped; values that look like a
        JSON object/array are re-serialized compactly when they parse,
      * argument keys are emitted in sorted order.
    """
    def _normalize_arg(raw: Any) -> str:
        text = "" if raw is None else str(raw).strip()
        looks_like_object = text.startswith("{") and text.endswith("}")
        looks_like_array = text.startswith("[") and text.endswith("]")
        if looks_like_object or looks_like_array:
            try:
                text = _safe_json(json.loads(text))
            except Exception:
                # Not valid JSON after all: keep the stripped text as-is.
                pass
        return text

    check = check_dict or {}

    columns = check.get("for_each_column")
    canon_columns = None
    if isinstance(columns, list):
        canon_columns = sorted([str(c) for c in columns]) or None

    raw_args = check.get("arguments") or {}
    normalized = {str(name): _normalize_arg(value) for name, value in raw_args.items()}
    canon_args = {name: normalized[name] for name in sorted(normalized)}

    canon_filter = " ".join(str(filter_str).split()) if filter_str else ""

    return _safe_json({
        "table_name": (table_name or "").lower(),
        "filter": canon_filter,
        "check": {
            "function": check.get("function"),
            "for_each_column": canon_columns,
            "arguments": canon_args,
        },
    })

def _compute_check_id(payload: str) -> str:
    return hashlib.sha256(payload.encode("utf-8")).hexdigest()

# --------------------------------------------------------------------------------------
# Public aliases: PascalCase wrappers that expose the private helpers under stable names
# --------------------------------------------------------------------------------------

def DisplayTablePreview(spark: SparkSession, fqns: List[str], title: str = "Resolved Tables") -> None:
    """Public alias: forwards all arguments to `_display_table_preview`."""
    return _display_table_preview(spark, fqns, title)

def NowISO() -> str:
    """Public alias for `_now_iso()`: the current UTC timestamp as an ISO string ending in "Z"."""
    return _now_iso()

def GetNotebookPath() -> str:
    """Public alias for `_nb_path()`.

    NOTE(review): `_nb_path` is defined outside this section; presumably it
    returns the path of the running notebook — confirm against its definition.
    """
    return _nb_path()

def PathStartsWith(s: str, *prefixes: str) -> bool:
    """Return True when `s` starts with at least one of `prefixes` (False when none are given)."""
    for candidate in prefixes:
        if s.startswith(candidate):
            return True
    return False

def ResolveLocalLikePath(p: str) -> Optional[str]:
    """Public alias for `_resolve_local_like_path`: absolute path of `p` if found, else None."""
    return _resolve_local_like_path(p)

def DiscoverTablesFromYAML(yaml_path: str) -> List[str]:
    """Public alias for `_discover_tables_from_yaml_file`: table FQNs listed in the YAML file."""
    return _discover_tables_from_yaml_file(yaml_path)

def ListYAMLPathsInFolder(folder: str) -> List[str]:
    """Recursively collect the paths of all `.yaml` / `.yml` files under `folder`.

    DBFS/Volumes folders (`dbfs:/`, `/dbfs/`, `/Volumes/`) are walked with
    `dbutils.fs.ls`, treating entries whose path ends in "/" as directories.
    Every other folder, absolute or relative, is resolved on the local
    filesystem and walked with `os.walk`; an unresolvable folder gives [].

    Raises:
        RuntimeError: dbutils cannot be imported for a DBFS/Volumes folder.
    """
    found: List[str] = []

    if folder.startswith("dbfs:/") or folder.startswith("/dbfs/") or folder.startswith("/Volumes/"):
        try:
            from databricks.sdk.runtime import dbutils
        except Exception:
            raise RuntimeError("dbutils is required to traverse DBFS/Volumes.")

        # Depth-first walk with an explicit stack of directory listings; a
        # subdirectory is listed and descended into as soon as it is encountered.
        pending = [iter(dbutils.fs.ls(_to_dbfs_target(folder)))]
        while pending:
            entry = next(pending[-1], None)
            if entry is None:
                pending.pop()
                continue
            entry_path = entry.path
            if entry_path.endswith("/"):
                pending.append(iter(dbutils.fs.ls(entry_path)))
            elif _is_yaml_path(entry_path):
                found.append(entry_path)
        return found

    # Workspace-style absolute folders and local/repo folders are handled the
    # same way: best-effort local resolution followed by a recursive walk.
    base = _resolve_local_like_path(folder)
    if base and os.path.isdir(base):
        for current_dir, _subdirs, filenames in os.walk(base):
            found.extend(
                os.path.join(current_dir, filename)
                for filename in filenames
                if _is_yaml_path(filename)
            )
    return found

# --------------------------------------------------------------------------------------
# Pretty display helpers
# --------------------------------------------------------------------------------------

def _can_display() -> bool:
    return "display" in globals()

def show_df(df: DataFrame, n: int = 100, truncate: bool = False) -> None:
    """Render up to `n` rows of `df`: via `display()` when available, else `df.show`.

    `truncate` only affects the `df.show` fallback.
    """
    if not _can_display():
        df.show(n, truncate=truncate)
        return
    display(df.limit(n))

def display_section(title: str) -> None:
    """Print `title` as a banner: a blank line, an 80-char rule, the title line, another rule."""
    rule = "═" * 80
    print(f"\n{rule}\n║ {title}\n{rule}")

# --------------------------------------------------------------------------------------
# Column comment helper with robust multi-syntax fallback
# --------------------------------------------------------------------------------------

def _apply_column_comment_with_fallback(
    spark: SparkSession,
    cat: str,
    sch: str,
    tbl: str,
    col_name: str,
    comment_text: str,
    col_types_lower: Dict[str, str],
) -> bool:
    """Set a column comment, trying three SQL syntaxes in order until one succeeds.

    1. ``COMMENT ON COLUMN <table>.<col> IS '...'``
    2. ``ALTER TABLE <table> ALTER COLUMN <col> COMMENT '...'``
    3. ``ALTER TABLE <table> CHANGE COLUMN <col> <col> <type> COMMENT '...'``

    The third form needs the column's data type, looked up in `col_types_lower`
    by lower-cased column name. Failures of the first two forms are ignored;
    a missing type or a failure of the third form prints a warning.

    Returns:
        True when a statement succeeded, False otherwise.
    """
    table_ref = f"`{cat}`.`{sch}`.`{tbl}`"
    column_ref = f"`{col_name}`"
    escaped = _esc_sql_comment(comment_text)

    # Attempts 1 and 2 need no type information.
    type_free_statements = (
        f"COMMENT ON COLUMN {table_ref}.{column_ref} IS '{escaped}'",
        f"ALTER TABLE {table_ref} ALTER COLUMN {column_ref} COMMENT '{escaped}'",
    )
    for statement in type_free_statements:
        try:
            spark.sql(statement)
        except Exception:
            continue
        return True

    # Attempt 3 re-declares the column, so the data type is mandatory.
    dtype = col_types_lower.get(col_name.lower())
    if not dtype:
        print(f"[WARN] Cannot determine data type for {cat}.{sch}.{tbl}.{col_name}; skipping column comment.")
        return False
    try:
        spark.sql(f"ALTER TABLE {table_ref} CHANGE COLUMN {column_ref} {column_ref} {dtype} COMMENT '{escaped}'")
    except Exception as err:
        print(f"[WARN] Failed to set comment for {cat}.{sch}.{tbl}.{col_name}: {err}")
        return False
    return True

# --------------------------------------------------------------------------------------
# Table documentation application
#  - Table comment applied when just-created
#  - Column comments applied on every call, via the multi-syntax fallback helper
# --------------------------------------------------------------------------------------

def _apply_table_documentation_on_create(spark: SparkSession, table_fqn: str, doc: Dict[str, Any], just_created: bool):
    """Apply table and column comments from `doc` to `table_fqn`.

    `doc["table_comment"]` is applied only when `just_created` is true.
    `doc["columns"]` (column name -> comment) is applied on every call, but only
    for columns reported by ``DESCRIBE TABLE``. Does nothing when `table_fqn`
    is not a three-part name.
    """
    try:
        cat, sch, tbl = table_fqn.split(".")
    except ValueError:
        return

    doc_map = doc or {}
    quoted_table = f"`{cat}`.`{sch}`.`{tbl}`"

    table_comment = doc_map.get("table_comment") or ""
    if just_created and table_comment:
        spark.sql(
            f"COMMENT ON TABLE {quoted_table} IS '{_esc_sql_comment(table_comment)}'"
        )

    column_docs: Dict[str, str] = doc_map.get("columns") or {}
    if not column_docs:
        return

    # Collect existing column names and types; the types feed the fallback syntax.
    known_columns = set()
    column_types = {}
    for row in spark.sql(f"DESCRIBE TABLE {quoted_table}").collect():
        name = row.col_name
        # Skip blank rows and "#"-prefixed section headers in the DESCRIBE output.
        if not name or name.startswith("#"):
            continue
        key = name.lower()
        known_columns.add(key)
        # Guard defensively: a row may carry no data_type.
        if hasattr(row, "data_type") and row.data_type:
            column_types[key] = row.data_type

    for col_name, comment in column_docs.items():
        if col_name.lower() in known_columns:
            _apply_column_comment_with_fallback(
                spark, cat, sch, tbl, col_name, comment, col_types_lower=column_types
            )

# --------------------------------------------------------------------------------------
# Main generator
# --------------------------------------------------------------------------------------

class CheckGenerator:
    def __init__(
        self,
        scope: str,                         # "pipeline" | "catalog" | "schema" | "table" | "file" | "folder"
        source: str,                        # depends on scope
        output_format: str,                 # "yaml" | "table" | "both"
        output_yaml: Optional[str],         # folder or /Volumes/... or dbfs:/... or workspace "/..."
        output_table: Optional[str],        # fully-qualified table FQN
        profile_options: Dict[str, Any],
        exclude_prefix_regex: Optional[str] = None,   # regex on table prefix (before first "_")
        created_by: Optional[str] = "LMG",
        columns: Optional[List[str]] = None,          # only valid when scope=="table"
        run_config_name: str = "default",
        criticality: str = "warn",
        key_order: Literal["engine", "custom"] = "custom",
        include_table_name: bool = True,
        yaml_metadata: bool = False,                  # add commented header on each YAML
        table_doc: Optional[Dict[str, Any]] = None,   # documentation dict; defaults to DQX_GENERATED_CHECKS_CONFIG_METADATA
    ):
        """Store the run configuration, attach a Spark session, and validate the settings.

        `scope` and `output_format` are normalized to lower case without
        surrounding whitespace. A falsy `profile_options` becomes `{}` and a
        falsy `table_doc` falls back to DQX_GENERATED_CHECKS_CONFIG_METADATA.

        Raises:
            ValueError: from `_validate_top_level` when the configuration is inconsistent.
        """
        # What to scan
        self.scope = scope.lower().strip()
        self.source = source

        # Where results go
        self.output_format = output_format.lower().strip()
        self.output_yaml = output_yaml
        self.output_table = output_table

        # Profiling and table filtering
        self.profile_options = profile_options or {}
        self.exclude_prefix_regex = exclude_prefix_regex
        self.columns = columns

        # Values stamped onto every generated check / row
        self.created_by = created_by
        self.run_config_name = run_config_name
        self.criticality = criticality

        # Output shaping
        self.key_order = key_order
        self.include_table_name = include_table_name
        self.yaml_metadata = yaml_metadata
        self.table_doc = table_doc or DQX_GENERATED_CHECKS_CONFIG_METADATA

        # Reuse the active session when there is one; otherwise build/get one.
        active_session = SparkSession.getActiveSession()
        self.spark = active_session if active_session else SparkSession.builder.getOrCreate()

        self._validate_top_level()

    # -------------------
    # Validation
    # -------------------
    def _validate_top_level(self):
        allowed_scopes = {"pipeline", "catalog", "schema", "table", "file", "folder"}
        if self.scope not in allowed_scopes:
            raise ValueError(f"Invalid scope '{self.scope}'. Must be one of: {sorted(allowed_scopes)}.")

        allowed_formats = {"yaml", "table", "both"}
        if self.output_format not in allowed_formats:
            raise ValueError("output_format must be 'yaml' or 'table' or 'both'.")

        # Source expectations per scope (explicit notes)
        # scope="pipeline" -> source="pipeline_name1,pipeline_name2"
        # scope="catalog"  -> source="catalog"
        # scope="schema"   -> source="catalog.schema"
        # scope="table"    -> source="catalog.schema.table[,catalog.schema.table]"
        # scope="file"     -> source="<path to YAML file listing tables>"
        # scope="folder"   -> source="<path to folder with YAML table lists>"
        if self.scope == "catalog":
            if "." in self.source:
                raise ValueError("For scope='catalog', pass just the catalog name (no dots).")
        if self.scope == "schema":
            if self.source.count(".") != 1:
                raise ValueError("For scope='schema', pass 'catalog.schema'.")
        if self.scope == "table":
            for t in [x.strip() for x in self.source.split(",") if x.strip()]:
                if t.count(".") != 2:
                    raise ValueError(f"Invalid table FQN '{t}'. Use catalog.schema.table")
        if self.scope == "file":
            if not _is_yaml_path(self.source):
                raise ValueError("For scope='file', source must be a YAML path.")
        # folder: any path is ok; we'll scan recursively

        # Sinks
        if self.output_format == "yaml" and not self.output_yaml:
            raise ValueError("output_yaml is required when output_format='yaml'.")
        if self.output_format == "table" and not self.output_table:
            raise ValueError("output_table is required when output_format='table'.")
        if self.output_format == "both":
            if not self.output_yaml or not self.output_table:
                raise ValueError("When output_format='both', both output_yaml and output_table are required.")

    # -------------------
    # Discovery
    # -------------------
    def _walk_yaml_files(self, folder: str) -> List[str]:
        """Return the YAML file paths found under `folder`; delegates to `ListYAMLPathsInFolder`."""
        return ListYAMLPathsInFolder(folder)

    def _discover_tables(self) -> List[str]:
        """Print the run parameters, then resolve `scope`/`source` into the tables to process.

        Resolution per scope:
          * pipeline -> `flow_name` of events belonging to each pipeline's latest update
          * catalog  -> every table of every schema in the catalog
          * schema   -> every table in the schema
          * table    -> the comma-separated names in `source`, as given
          * file     -> names listed in the YAML file
          * folder   -> names from every YAML file under the folder; unreadable files are
                        skipped with a warning

        The result is filtered by `exclude_prefix_regex` and previewed before
        being returned.

        Raises:
            RuntimeError: a named pipeline does not exist or has no updates.
        """
        print("\n===== PARAMETERS PASSED THIS RUN =====")
        print(f"scope:            {self.scope}")
        print(f"source:           {self.source}")
        print(f"output_format:    {self.output_format}")
        print(f"output_yaml:      {self.output_yaml}")
        print(f"output_table:     {self.output_table}")
        print(f"exclude_prefix_rx:{self.exclude_prefix_regex}")
        print(f"created_by:       {self.created_by}")
        print(f"columns:          {self.columns}")
        print(f"run_config_name:  {self.run_config_name}")
        print(f"criticality:      {self.criticality}")
        print(f"key_order:        {self.key_order}")
        print(f"include_table_name: {self.include_table_name}")
        print(f"yaml_metadata:    {self.yaml_metadata}")
        print("profile_options:")
        for k, v in self.profile_options.items():
            print(f"  {k}: {v}")
        print("======================================\n")

        discovered: List[str] = []

        if self.scope == "pipeline":
            print("Searching for pipeline output tables...")
            ws = WorkspaceClient()
            pipeline_names = [p.strip() for p in self.source.split(",") if p.strip()]
            # The listing does not depend on the name being looked up, so fetch it once.
            pls = list(ws.pipelines.list_pipelines()) if pipeline_names else []
            for pipeline_name in pipeline_names:
                pl = next((p for p in pls if p.name == pipeline_name), None)
                if not pl:
                    raise RuntimeError(f"Pipeline '{pipeline_name}' not found via SDK.")
                updates = pl.latest_updates or []
                if not updates:
                    # Previously surfaced as a bare IndexError/TypeError.
                    raise RuntimeError(
                        f"Pipeline '{pipeline_name}' has no updates; cannot determine its output tables."
                    )
                latest_update = updates[0].update_id
                events = ws.pipelines.list_pipeline_events(pipeline_id=pl.pipeline_id, max_results=250)
                pipeline_tables = [
                    getattr(ev.origin, "flow_name", None)
                    for ev in events
                    if getattr(ev.origin, "update_id", None) == latest_update and getattr(ev.origin, "flow_name", None)
                ]
                discovered += [x for x in pipeline_tables if x]
            # The same flow_name can appear on more than one event; keep only the first
            # occurrence so each table is processed once.
            discovered = list(dict.fromkeys(discovered))

        elif self.scope == "catalog":
            print("Discovering all tables in catalog...")
            catalog = self.source.strip()
            schemas = [row.namespace for row in self.spark.sql(f"SHOW SCHEMAS IN {catalog}").collect()]
            for s in schemas:
                tbls = self.spark.sql(f"SHOW TABLES IN {catalog}.{s}").collect()
                discovered += [f"{catalog}.{s}.{row.tableName}" for row in tbls]

        elif self.scope == "schema":
            print("Discovering all tables in schema...")
            catalog, schema = self.source.strip().split(".")
            tbls = self.spark.sql(f"SHOW TABLES IN {catalog}.{schema}").collect()
            discovered = [f"{catalog}.{schema}.{row.tableName}" for row in tbls]

        elif self.scope == "table":
            print("Using provided fully-qualified table(s)...")
            discovered = [t.strip() for t in self.source.split(",") if t.strip()]

        elif self.scope == "file":
            print("Reading table list from YAML file...")
            discovered = _discover_tables_from_yaml_file(self.source)

        else:  # folder
            print("Reading table lists from all YAML files in folder (recursive)...")
            yaml_files = self._walk_yaml_files(self.source)
            agg: List[str] = []
            for yp in yaml_files:
                try:
                    agg += _discover_tables_from_yaml_file(yp)
                except Exception as e:
                    # Best-effort: one bad file must not stop discovery from the rest.
                    print(f"[WARN] Skipping YAML '{yp}': {e}")
            discovered = sorted(set(agg))

        discovered = _filter_by_prefix_regex(discovered, self.exclude_prefix_regex)
        _display_table_preview(self.spark, discovered, title="Final table list to generate DQX rules for")
        print("==========================================\n")
        return discovered

    # -------------------
    # Profiler call args
    # -------------------
    def _profile_call_kwargs(self) -> Dict[str, Any]:
        kwargs: Dict[str, Any] = {}
        if self.columns is not None:
            kwargs["cols"] = self.columns
        if self.profile_options:
            unknown = sorted(set(self.profile_options) - DOC_SUPPORTED_KEYS)
            if unknown:
                print(f"[INFO] Profiling options not in current docs (passing through anyway): {unknown}")
            kwargs["options"] = self.profile_options
        return kwargs

    # -------------------
    # Rule shaping
    # -------------------
    def _dq_constraint_to_check(self, rule_name: str, constraint_sql: str, table_name: str) -> Dict[str, Any]:
        d = {
            "name": rule_name,
            "criticality": self.criticality,
            "run_config_name": self.run_config_name,
            "check": {
                "function": "sql_expression",
                "arguments": {
                    # remove duplicate 'name' from arguments per your preference
                    "expression": constraint_sql,
                }
            },
        }
        if self.include_table_name:
            d = {"table_name": table_name, **d}  # keep table_name first
        return d

    # -------------------
    # YAML emission (header + list items with blank lines)
    # -------------------
    def _yaml_header_block(self, table_fqn: str, env_info: Dict[str, Any]) -> str:
        """Build the commented header placed at the top of a generated YAML file.

        The header names the table, reports environment details read from
        `env_info`, and embeds the profile options and generator settings as
        compact JSON. Every non-blank line starts with "#"; the block ends with
        a blank line.
        """
        info = env_info.get
        banner = "#" * 76
        dashed = "-" * 81

        settings = {
            "scope": self.scope,
            "source": self.source,
            "output_format": self.output_format,
            "output_yaml": self.output_yaml,
            "output_table": self.output_table,
            "criticality": self.criticality,
            "run_config_name": self.run_config_name,
            "include_table_name": self.include_table_name,
            "key_order": self.key_order,
            "exclude_prefix_regex": self.exclude_prefix_regex,
        }

        header_lines = [
            banner,
            "# GENERATED DQX CHECKS",
            f"# Table: {table_fqn}",
            f"# Generated at (UTC): {info('utc_time', '')}",
            f"# Notebook: {info('notebook_path', 'Unknown')}",
            f"# Spark: {info('spark_version', '')}  |  Python: {info('python_version', '')}",
            f"# Cluster: {info('cluster_name', '')} ({info('cluster_id', '')})  |  Executor memory: {info('executor_memory', '')}",
            banner,
            "",
            f"# {dashed}",
            "# Profile options:",
            "# " + _safe_json(self.profile_options),
            "# Generator settings:",
            "# " + _safe_json(settings),
        ]
        return "\n".join(header_lines) + "\n\n"

    def _dump_rules_as_yaml_stream(self, rules: List[Dict[str, Any]]) -> str:
        """
        Emit a single YAML document that is a list of rule objects:
        - table_name: ...
          name: ...
          ...
        (Blank line between items for readability.)
        """
        pieces: List[str] = []
        for r in rules:
            block = yaml.safe_dump(r, sort_keys=False, default_flow_style=False).rstrip()
            lines = block.splitlines()
            if not lines:
                continue
            first = f"- {lines[0]}"
            rest = "\n".join(("  " + ln) for ln in lines[1:])
            pieces.append(first + ("\n" + rest if rest else ""))
        return "\n\n".join(pieces) + "\n"

    # -------------------
    # Table write helpers
    # -------------------
    def _ensure_schema_exists(self, fqn: str):
        cat, sch, _ = fqn.split(".")
        self.spark.sql(f"CREATE SCHEMA IF NOT EXISTS `{cat}`.`{sch}`")

    def _write_rows_to_table(self, fqn: str, rows: List[Dict[str, Any]], mode: str = "append"):
        """Write `rows` to the Delta table `fqn`, creating and documenting it when needed.

        Steps: ensure the schema exists; create an empty table with
        DQX_GENERATED_CHECKS_CONFIG_SCHEMA if the table is missing; apply
        documentation (table comment only for a new table, column comments on
        every call); write the rows with the given save `mode`.
        """
        self._ensure_schema_exists(fqn)

        table_is_new = not self.spark.catalog.tableExists(fqn)
        if table_is_new:
            # Materialize the table with the right schema first so docs can be attached.
            bootstrap_df = self.spark.createDataFrame([], DQX_GENERATED_CHECKS_CONFIG_SCHEMA)
            bootstrap_df.write.format("delta").mode("overwrite").saveAsTable(fqn)

        doc_for_table = dict(self.table_doc)
        doc_for_table["table"] = fqn
        _apply_table_documentation_on_create(self.spark, fqn, doc_for_table, just_created=table_is_new)

        rows_df = self.spark.createDataFrame(rows, schema=DQX_GENERATED_CHECKS_CONFIG_SCHEMA)
        rows_df.write.format("delta").mode(mode).saveAsTable(fqn)
        print(f"[WRITE] {len(rows)} rows -> {fqn} ({mode})")

    # -------------------
    # Summary display
    # -------------------
    def _show_summary_table(self, summary: Dict[str, Dict[str, Any]]):
        """Display one row per processed table: checks generated, YAML written, rows written.

        Prints a placeholder message instead when `summary` is empty. Rows are
        ordered by table name.
        """
        display_section("Checks written per table")
        if not summary:
            print("(no tables processed)")
            return

        sink_table = self.output_table or ""
        summary_rows = [
            (
                stats.get("table_name", table_key),
                int(stats.get("checks_generated", 0)),
                bool(stats.get("wrote_yaml", False)),
                stats.get("yaml_path", None),
                int(stats.get("table_rows_written", 0)),
                sink_table,
            )
            for table_key, stats in summary.items()
        ]
        columns_ddl = "table_name string, checks_generated int, wrote_yaml boolean, yaml_path string, table_rows_written int, output_table string"
        summary_df = self.spark.createDataFrame(summary_rows, schema=columns_ddl)
        show_df(summary_df.orderBy("table_name"))

    # -------------------
    # Main
    # -------------------
    def run(self):
        dq_engine = DQEngine(WorkspaceClient())
        profiler = DQProfiler(WorkspaceClient())
        generator = DQDltGenerator(WorkspaceClient())

        env_info = print_notebook_env(self.spark)  # prints banner and returns dict
        call_kwargs = self._profile_call_kwargs()
        tables = self._discover_tables()

        all_rows_for_table_sink: List[Dict[str, Any]] = []
        written_yaml_paths: List[str] = []

        # Per-table summary tracking
        per_table_summary: Dict[str, Dict[str, Any]] = {}

        for fq in tables:
            if fq.count(".") != 2:
                print(f"[WARN] Skipping invalid table name: {fq}")
                continue

            cat, sch, tab = fq.split(".")
            per_table_summary.setdefault(fq, {
                "table_name": fq,
                "checks_generated": 0,
                "wrote_yaml": False,
                "yaml_path": None,
                "table_rows_written": 0,
            })

            try:
                self.spark.table(fq).limit(1).collect()  # readability probe
            except Exception as e:
                print(f"[WARN] Table not readable: {fq} -> {e}")
                continue

            # Profile & generate DLT rules
            try:
                df = self.spark.table(fq)
                _, profiles = profiler.profile(df, **call_kwargs)
                rules_dict = generator.generate_dlt_rules(profiles, language="Python_Dict")
            except Exception as e:
                print(f"[WARN] Profiling/rule-gen failed for {fq}: {e}")
                continue

            # Shape checks
            checks: List[Dict[str, Any]] = []
            for rule_name, constraint_sql in (rules_dict or {}).items():
                checks.append(self._dq_constraint_to_check(rule_name, constraint_sql, fq))

            per_table_summary[fq]["checks_generated"] = len(checks)

            if not checks:
                print(f"[INFO] No checks generated for {fq}.")
                continue

            # YAML sink
            yaml_path_for_rows: str = f"<generated://{fq}>"
            if self.output_format in {"yaml", "both"}:
                if self.output_yaml.endswith((".yaml", ".yml")):
                    path = self.output_yaml  # explicit file path (edge case)
                else:
                    path = f"{self.output_yaml.rstrip('/')}/{tab}.yaml"

                if self.key_order == "engine":
                    cfg = self._infer_file_storage_config(path)
                    dq_engine.save_checks(checks, config=cfg)  # writes list-of-dicts
                else:
                    header = self._yaml_header_block(fq, env_info) if self.yaml_metadata else ""
                    body = self._dump_rules_as_yaml_stream(checks)
                    _write_text_any(path, header + body)

                yaml_path_for_rows = path
                written_yaml_paths.append(path)
                per_table_summary[fq]["wrote_yaml"] = True
                per_table_summary[fq]["yaml_path"] = path
                print(f"[RUN] Wrote {len(checks)} rule(s) to YAML: {path}")

            # Prepare table rows (for TABLE only path; BOTH path reloads YAMLs to avoid drift)
            if self.output_format == "table":
                gen_meta = [
                    {"section": "profile_options", "payload": _stringify_map_values(self.profile_options)},
                    {"section": "generator_settings", "payload": _stringify_map_values({
                        "scope": self.scope, "source": self.source, "output_format": self.output_format,
                        "output_yaml": self.output_yaml or "", "output_table": self.output_table or "",
                        "criticality": self.criticality, "run_config_name": self.run_config_name,
                        "include_table_name": self.include_table_name, "key_order": self.key_order,
                        "exclude_prefix_regex": self.exclude_prefix_regex or "",
                    })},
                ]
                for rule in checks:
                    raw_check = rule["check"]
                    payload = _compute_check_id_payload(fq, raw_check, rule.get("filter"))
                    all_rows_for_table_sink.append({
                        "check_id": _compute_check_id(payload),
                        "check_id_payload": payload,
                        "table_name": fq,
                        "name": rule["name"],
                        "criticality": rule["criticality"],
                        "check": {
                            "function": raw_check.get("function"),
                            "for_each_column": raw_check.get("for_each_column"),
                            "arguments": _stringify_map_values(raw_check.get("arguments") or {}),
                        },
                        "filter": rule.get("filter"),
                        "run_config_name": rule["run_config_name"],
                        "user_metadata": _stringify_map_values(rule.get("user_metadata") or None) or None,
                        "yaml_path": yaml_path_for_rows,
                        "active": True,
                        "generator_meta": gen_meta,
                        "created_by": self.created_by,
                        "created_at": _now_iso(),
                        "updated_by": None,
                        "updated_at": None,
                    })

        # BOTH → reload the exact YAMLs we wrote and write those rows (canonical)
        if self.output_format == "both":
            rows_from_yaml: List[Dict[str, Any]] = []
            for yp in written_yaml_paths:
                try:
                    txt = _read_text_any(yp)
                    docs = list(yaml.safe_load_all(io.StringIO(txt)))
                    rules: List[dict] = []
                    for d in docs:
                        if not d:
                            continue
                        if isinstance(d, dict):
                            rules.append(d)
                        elif isinstance(d, list):
                            rules.extend([x for x in d if isinstance(x, dict)])

                    # If file is a single list (our custom format), docs will be [list]; handled above.
                    for r in rules:
                        fq = r.get("table_name")
                        raw_check = r.get("check") or {}
                        payload = _compute_check_id_payload(fq, raw_check, r.get("filter"))
                        row_obj = {
                            "check_id": _compute_check_id(payload),
                            "check_id_payload": payload,
                            "table_name": fq,
                            "name": r.get("name"),
                            "criticality": r.get("criticality"),
                            "check": {
                                "function": raw_check.get("function"),
                                "for_each_column": raw_check.get("for_each_column"),
                                "arguments": _stringify_map_values(raw_check.get("arguments") or {}),
                            },
                            "filter": r.get("filter"),
                            "run_config_name": r.get("run_config_name", self.run_config_name),
                            "user_metadata": _stringify_map_values(r.get("user_metadata") or None) or None,
                            "yaml_path": yp,
                            "active": True,
                            "generator_meta": [
                                {"section": "profile_options", "payload": _stringify_map_values(self.profile_options)},
                                {"section": "generator_settings", "payload": _stringify_map_values({
                                    "scope": self.scope, "source": self.source, "output_format": self.output_format,
                                    "output_yaml": self.output_yaml or "", "output_table": self.output_table or "",
                                    "criticality": self.criticality,
                                    "run_config_name": r.get("run_config_name", self.run_config_name),
                                    "include_table_name": self.include_table_name, "key_order": self.key_order,
                                    "exclude_prefix_regex": self.exclude_prefix_regex or "",
                                })},
                            ],
                            "created_by": self.created_by,
                            "created_at": _now_iso(),
                            "updated_by": None,
                            "updated_at": None,
                        }
                        rows_from_yaml.append(row_obj)
                        # count per-table table_rows_written (we'll add to summary after the write too)
                        if fq in per_table_summary:
                            per_table_summary[fq]["table_rows_written"] = per_table_summary[fq].get("table_rows_written", 0) + 1
                        else:
                            per_table_summary[fq] = {
                                "table_name": fq,
                                "checks_generated": 0,
                                "wrote_yaml": True,
                                "yaml_path": yp,
                                "table_rows_written": 1,
                            }
                except Exception as e:
                    print(f"[WARN] Could not load back YAML '{yp}' for table sink: {e}")

            if self.output_table and rows_from_yaml:
                self._write_rows_to_table(self.output_table, rows_from_yaml, mode="append")
            print(f"[DONE] Wrote YAML files ({len(written_yaml_paths)}). Then loaded {len(rows_from_yaml)} rows into {self.output_table}.")

        elif self.output_format == "table":
            if self.output_table and all_rows_for_table_sink:
                # update per-table counts before write
                for r in all_rows_for_table_sink:
                    fq = r["table_name"]
                    per_table_summary.setdefault(fq, {
                        "table_name": fq,
                        "checks_generated": 0,
                        "wrote_yaml": False,
                        "yaml_path": None,
                        "table_rows_written": 0,
                    })
                    per_table_summary[fq]["table_rows_written"] = per_table_summary[fq].get("table_rows_written", 0) + 1

                self._write_rows_to_table(self.output_table, all_rows_for_table_sink, mode="append")
            print(f"[DONE] Wrote {len(all_rows_for_table_sink)} rows into {self.output_table}.")
        else:
            print(f"[DONE] Wrote YAML files ({len(written_yaml_paths)}).")

        # Print a nice per-table summary
        self._show_summary_table(per_table_summary)

    # Storage config passthrough (kept)
    @staticmethod
    def _infer_file_storage_config(file_path: str):
        """Build the checks-storage config that matches the shape of *file_path*.

        Selection is purely by path prefix:
          * starts with "/Volumes/" -> VolumeFileChecksStorageConfig
          * any other path starting with "/" -> WorkspaceFileChecksStorageConfig
          * everything else -> FileChecksStorageConfig

        The path itself is forwarded unchanged as the config's ``location``.
        """
        # Order matters: "/Volumes/" must be tested before the generic "/" prefix.
        if file_path.startswith("/Volumes/"):
            config_cls = VolumeFileChecksStorageConfig
        elif file_path.startswith("/"):
            config_cls = WorkspaceFileChecksStorageConfig
        else:
            config_cls = FileChecksStorageConfig
        return config_cls(location=file_path)

    @staticmethod
    def _table_storage_config(table_fqn: str, run_config_name: Optional[str] = None, mode: str = "append"):
        """Return a TableChecksStorageConfig for *table_fqn*.

        ``table_fqn`` is passed as the config's ``location``; ``run_config_name``
        and ``mode`` are forwarded as-is (``mode`` defaults to "append").
        """
        storage_kwargs = {
            "location": table_fqn,
            "run_config_name": run_config_name,
            "mode": mode,
        }
        return TableChecksStorageConfig(**storage_kwargs)


# -------------------- Usage examples --------------------
if __name__ == "__main__":
    # Profiler options handed to CheckGenerator through `profile_options=` below.
    profile_options = {
        # -- sampling --
        "sample_fraction": 0.3,
        "sample_seed": 42,
        "limit": 1000,
        # -- outlier handling --
        "remove_outliers": True,
        "outlier_columns": [],
        "num_sigmas": 3,
        # -- null / string handling --
        "max_null_ratio": 0.01,
        "trim_strings": True,
        "max_empty_ratio": 0.01,
        # -- is_in rule thresholds --
        "distinct_ratio": 0.05,
        "max_in_count": 10,
        # NOTE(review): the options cell at the top of this file marks "round" as
        # removed / not valid for this profiler version — confirm it should stay here.
        "round": True,
        # other passthrough keys are fine
    }

    # Example A — scope="schema", output_format="both":
    # YAML files are written first, then exactly those YAMLs are loaded into the table.
    example_a_generator = CheckGenerator(
        # scope: "pipeline" | "catalog" | "schema" | "table" | "file" | "folder"
        scope="schema",
        source="de_prd.gold",
        # output_format: "yaml" | "table" | "both"
        output_format="both",
        # Workspace folder or Volume that receives the YAML files
        output_yaml="/Volumes/dq_dev/dqx/generated_checks/2025-08-12/",
        # (yaml_modifier) True = add run metadata to the file | False = leave it out
        yaml_metadata=True,
        # (yaml_modifier) "custom" = our ordered YAML with list items; "engine" = DQX default writer
        key_order="custom",
        # (yaml_modifier)
        include_table_name=True,
        # Table FQN, or None when not writing to a table
        output_table="dq_dev.dqx.generated_checks_config",
        profile_options=profile_options,
        # exclude tables whose prefix (before _) matches
        exclude_prefix_regex=r"^tama",
        # (table_modifier) populates the 'created_by' column
        created_by="LMG",
        # (check_modifier) only valid when scope == "table"
        columns=None,
        # (check_modifier)
        run_config_name="default",
        # (check_modifier) "error" | "warn"
        criticality="error",
        # used if/when the output table gets created
        table_doc=DQX_GENERATED_CHECKS_CONFIG_METADATA,
    )
    example_a_generator.run()

    # Examples B and C are kept disabled inside the string literal below.
    """
    # Example B — scope="catalog": write BOTH (YAMLs first, then load those YAMLs into table)
    CheckGenerator(
        scope="catalog",
        source="de_prd",
        output_format="both",
        output_yaml="dbfs:/mnt/dqx/generated_checks/de_prd",
        output_table="dq_dev.dqx.checks_generated_config",
        profile_options=profile_options,
        exclude_prefix_regex=r"^tamarack$",
        created_by="LMG",
        columns=None,
        run_config_name="default",
        criticality="error",
        key_order="custom",
        include_table_name=True,
        yaml_metadata=True,
        table_doc=DQX_GENERATED_CHECKS_CONFIG_METADATA,
    ).run()

    # Example C — scope="table": write directly to table (no YAML)
    CheckGenerator(
        scope="table",
        source="dq_prd.monitoring.job_run_audit",
        output_format="table",
        output_yaml=None,
        output_table="dq_dev.dqx.checks_generated_config",
        profile_options=profile_options,
        exclude_prefix_regex=None,
        created_by="LMG",
        columns=None,
        run_config_name="default",
        criticality="error",
        key_order="custom",
        include_table_name=True,
        yaml_metadata=False,
        table_doc=DQX_GENERATED_CHECKS_CONFIG_METADATA,
    ).run()
    """