# Generate DQX Checks

In [0]:
# -- DQX profiling options: all known keys --
# NOTE(review): `dbutils.library.restartPython()` in a later cell restarts the Python
# interpreter, which discards this dict; the generator cell defines its own
# `profile_options` inside its __main__ block.
profile_options = dict(
    # round=True,                  # (Removed - not valid for this profiler version)
    max_in_count=10,               # Max distinct values for is_in rule
    distinct_ratio=0.05,           # Max unique/total ratio for is_in rule
    max_null_ratio=0.01,           # Max null fraction to allow is_not_null rule
    remove_outliers=True,          # Remove outliers for min/max
    outlier_columns=[],            # Only these columns get outlier removal (empty=all numerics)
    num_sigmas=3,                  # Stddev for outlier removal (z-score cutoff)
    trim_strings=True,             # Strip whitespace before profiling strings
    max_empty_ratio=0.01,          # Max empty string ratio for is_not_null_or_empty
    sample_fraction=0.3,           # Row fraction to sample
    sample_seed=None,              # Seed for reproducibility (set int for deterministic)
    limit=1000,                    # Max number of rows to profile
    profile_types=None,            # List of rule types (e.g. ["is_in", "is_not_null"]); None=default
    min_length=None,               # Min string length to consider (None disables)
    max_length=None,               # Max string length to consider (None disables)
    include_histograms=False,      # Compute histograms as part of profiling
    min_value=None,                # Numeric min override (None disables)
    max_value=None,                # Numeric max override (None disables)
)

In [0]:
%pip install databricks-labs-dqx==0.8.0

In [0]:
# Restart the Python interpreter so the databricks-labs-dqx package installed by the
# %pip cell above becomes importable.
# NOTE(review): restarting also clears variables defined in earlier cells
# (e.g. the `profile_options` dict in the first cell).
dbutils.library.restartPython()

In [0]:
# dqx_rule_generator.py
# Full file — adds YAML list support via header-comment parsing + display() preview.
# NOTHING REMOVED: all params/options stay. Renamed yaml_key_order -> key_order.
# Single-table & CSV FQNs still work. One YAML file per table is written in output_location.

import os
import re
import io
import json
from typing import List, Optional, Dict, Any, Literal, Tuple

import yaml
from databricks.sdk import WorkspaceClient
from databricks.labs.dqx.profiler.profiler import DQProfiler
from databricks.labs.dqx.profiler.dlt_generator import DQDltGenerator
from databricks.labs.dqx.engine import DQEngine
from databricks.labs.dqx.config import (
    FileChecksStorageConfig,
    WorkspaceFileChecksStorageConfig,
    TableChecksStorageConfig,
    VolumeFileChecksStorageConfig,
)
from pyspark.sql import SparkSession


def glob_to_regex(glob_pattern: str) -> str:
    """
    Translate a dot-prefixed exclude glob (e.g. '.tmp_*') into an anchored regex.

    The leading '.' is a required marker and is dropped. Each '*' becomes '.*';
    every other character is matched literally.

    Raises:
        ValueError: if the pattern is empty/None or does not start with '.'.
    """
    has_marker = bool(glob_pattern) and glob_pattern.startswith('.')
    if not has_marker:
        raise ValueError("Exclude pattern must start with a dot, e.g. '.tamarack_*'")
    body = re.escape(glob_pattern[1:]).replace(r'\*', '.*')
    return f"^{body}$"


# Profiler "options" keys that appear in the DQX docs. The set is only used to print an
# INFO line for extra keys; unknown keys are still passed through to the profiler.
DOC_SUPPORTED_KEYS = {
    # sampling
    "sample_fraction", "sample_seed", "limit",
    # outlier handling for min/max rules
    "remove_outliers", "outlier_columns", "num_sigmas",
    # null / empty-string thresholds
    "max_null_ratio", "trim_strings", "max_empty_ratio",
    # is_in rule thresholds
    "distinct_ratio", "max_in_count",
    # value rounding
    "round",
}

# ---------- YAML file support (header-comment parsing for defaults) ----------
# Path ends in '.yaml' or '.yml' (case-insensitive).
_YAML_PATH_RE = re.compile(r"\.(ya?ml)$", flags=re.I)
# Captures <catalog> from 'FROM <catalog>.information_schema.tables'.
_FROM_INFOSCHEMA = re.compile(r"FROM\s+([A-Za-z0-9_]+)\.information_schema\.tables", flags=re.I)
# Captures <schema> from "table_schema = '<schema>'" (single quotes only).
_TABLE_SCHEMA_EQ = re.compile(r"table_schema\s*=\s*'([^']+)'", flags=re.I)

def _is_yaml_path(path: str) -> bool:
    """Return True when ``path`` ends with '.yaml' or '.yml' (any case)."""
    return _YAML_PATH_RE.search(path) is not None

def _resolve_local_like_path(path: str, max_levels: int = 6) -> Optional[str]:
    """
    Resolve a (possibly repo-relative) path from a notebook's working directory.

    Tries, in order:
      - ``path`` as given (absolute, or relative to the cwd);
      - ``path`` joined onto the cwd, then onto each ancestor directory, checking at
        most ``max_levels`` directories in total (the cwd plus ``max_levels - 1`` parents).

    Args:
        path: file path to look for.
        max_levels: how many directories (starting with the cwd) to try; default 6.

    Returns:
        The first existing path (file or directory), or None if nothing matched.
    """
    if os.path.exists(path):
        return path
    base = os.getcwd()
    for _ in range(max_levels):
        candidate = os.path.abspath(os.path.join(base, path))
        if os.path.exists(candidate):
            return candidate
        parent = os.path.dirname(base)
        if parent == base:  # reached the filesystem root; nothing further up
            break
        base = parent
    return None

def _to_dbfs_uri(path: str) -> str:
    """
    Convert a DBFS/Volumes path into the 'dbfs:' URI form passed to dbutils.fs.

      - 'dbfs:/x'    -> unchanged
      - '/dbfs/x'    -> 'dbfs:/x'   ('/dbfs' is the local mount of the DBFS root, so it
                                     must not become 'dbfs:/dbfs/x')
      - '/Volumes/x' -> 'dbfs:/Volumes/x'
    """
    if path.startswith("dbfs:"):
        return path
    if path.startswith("/dbfs/"):
        return "dbfs:/" + path[len("/dbfs/"):]
    return f"dbfs:{path}" if path.startswith("/") else f"dbfs:/{path}"


def _read_text_any(path: str) -> str:
    """
    Read a small UTF-8 text file (e.g. a YAML table list) from one of:
      - dbfs:/..., /dbfs/..., /Volumes/...  -> dbutils.fs.head (first 10 MiB only)
      - other absolute paths (e.g. /Workspace/Repos/.../file.yaml) -> read from the
        driver filesystem when the file exists there, otherwise via the SDK Files API
      - relative paths -> resolved against the cwd and its ancestors

    Raises:
        RuntimeError: dbutils is unavailable for a DBFS/Volumes path.
        FileNotFoundError: a relative path could not be resolved to a file.
    """
    # DBFS / Volumes
    if path.startswith(("dbfs:/", "/dbfs/", "/Volumes/")):
        try:
            from databricks.sdk.runtime import dbutils
        except Exception as e:
            raise RuntimeError("dbutils is required to read from DBFS/Volumes") from e
        return dbutils.fs.head(_to_dbfs_uri(path), 10 * 1024 * 1024)

    # Absolute workspace file path (e.g. /Workspace/Repos/.../file.yaml)
    if path.startswith("/"):
        # /Workspace paths are usually readable through the driver's local filesystem,
        # so try that first and only fall back to the REST API when it is not there.
        if os.path.isfile(path):
            with open(path, "r", encoding="utf-8") as fh:
                return fh.read()
        wc = WorkspaceClient()
        try:
            # newer SDK signature
            resp = wc.files.download(file_path=path)
        except TypeError:
            # older SDK signature
            resp = wc.files.download(path=path)
        # FilesAPI.download returns a DownloadResponse whose body stream is `.contents`;
        # fall back to the response itself in case it is already file-like.
        stream = resp.contents if hasattr(resp, "contents") else resp
        return stream.read().decode("utf-8")

    # Relative/local resolution (repo-friendly)
    resolved = _resolve_local_like_path(path)
    if resolved and os.path.isfile(resolved):
        with open(resolved, "r", encoding="utf-8") as fh:
            return fh.read()

    # Nothing worked → helpful error
    msg = [
        f"Could not find YAML at '{path}'.",
        f"cwd={os.getcwd()}",
        "Hints:",
        "  - If this file is in your repo, pass a path relative to the notebook or the repo root.",
        "  - Or pass an absolute workspace path like '/Workspace/Repos/.../file.yaml'.",
        "  - Or use a DBFS/Volumes path like 'dbfs:/...' or '/Volumes/...'.",
    ]
    raise FileNotFoundError("\n".join(msg))

def _parse_global_hints_from_comments(text: str) -> Tuple[Optional[str], Optional[str]]:
    """
    Extract default (catalog, schema) hints from a YAML table-list file's text.

    The patterns are searched across the whole text (in practice they live in the
    header comments); each hint is None when its pattern is absent:
      - catalog: <catalog> in 'FROM <catalog>.information_schema.tables'
      - schema:  <schema>  in "table_schema = '<schema>'"
    """
    hints: List[Optional[str]] = []
    for pattern in (_FROM_INFOSCHEMA, _TABLE_SCHEMA_EQ):
        match = pattern.search(text)
        hints.append(match.group(1) if match else None)
    catalog_hint, schema_hint = hints
    return catalog_hint, schema_hint

def _ensure_fqns(names: List[str], hint_catalog: Optional[str], hint_schema: Optional[str]) -> List[str]:
    """
    Expand table names to fully qualified 'catalog.schema.table' form.

    Accepts 'catalog.schema.table', 'schema.table' or 'table'; missing leading parts
    are filled from the comment-derived hints. Blank entries are ignored.

    Returns:
        The de-duplicated FQNs, sorted.

    Raises:
        ValueError: a name has an empty component (e.g. 'a..b' or '.t'), has more
            than three parts, or needs a hint that was not provided.
    """
    out: List[str] = []
    for raw in names:
        n = raw.strip()
        if not n:
            continue
        parts = n.split(".")
        # Stray dots would otherwise yield bogus FQNs like 'cat..tbl' that only fail later in Spark.
        if any(not p.strip() for p in parts):
            raise ValueError(f"'{n}' has an empty name component; check for stray dots.")
        if len(parts) == 3:
            out.append(n)
        elif len(parts) == 2:
            if not hint_catalog:
                raise ValueError(f"'{n}' lacks catalog; add it to the item or provide a comment default.")
            out.append(f"{hint_catalog}.{n}")
        elif len(parts) == 1:
            if not (hint_catalog and hint_schema):
                raise ValueError(f"'{n}' needs catalog & schema; set via comments or use dotted forms.")
            out.append(f"{hint_catalog}.{hint_schema}.{n}")
        else:
            raise ValueError(f"Unrecognized table format: {n}")
    # A hint that itself contains dots (the schema pattern allows them) can still break the shape.
    bad = [t for t in out if t.count(".") != 2]
    if bad:
        raise ValueError(f"Invalid FQN(s) after resolution: {bad}")
    return sorted(set(out))

def _discover_tables_from_yaml_by_comments(yaml_path: str) -> List[str]:
    """
    Read a YAML table list and return the FQNs it names.

    The document must be a mapping; the first of 'table_name', 'tables',
    'table_names', 'list' whose value is a list supplies the names. Catalog/schema
    defaults for partial names come from _parse_global_hints_from_comments; dotted
    entries override them per item.

    Raises:
        ValueError: the document is not a mapping, or no non-empty list is found.
    """
    raw_text = _read_text_any(yaml_path)
    default_catalog, default_schema = _parse_global_hints_from_comments(raw_text)

    document = yaml.safe_load(io.StringIO(raw_text))
    if not isinstance(document, dict):
        raise ValueError(f"YAML must contain a mapping with a list; got: {type(document).__name__}")

    list_key = next(
        (k for k in ("table_name", "tables", "table_names", "list") if isinstance(document.get(k), list)),
        None,
    )
    table_names = [str(item).strip() for item in document[list_key] if item] if list_key else None
    if not table_names:
        raise ValueError(f"No table list found in YAML: {yaml_path}")

    resolved = _ensure_fqns(table_names, default_catalog, default_schema)
    print(f"[INFO] Parsed defaults from comments: catalog={default_catalog} schema={default_schema}")
    return resolved

def _fqn_parts(fqn: str) -> Tuple[Optional[str], Optional[str], str]:
    """
    Split a dotted table name into (catalog, schema, table) for display.

    Missing leading parts become None (e.g. a bare pipeline flow name); names with
    more than three parts keep everything before the last two in the catalog slot.
    """
    parts = fqn.split(".")
    if len(parts) >= 3:
        return ".".join(parts[:-2]), parts[-2], parts[-1]
    padded = [None] * (3 - len(parts)) + parts
    return padded[0], padded[1], padded[2]


def _display_table_preview(spark: SparkSession, fqns: List[str], title: str = "Resolved Tables") -> None:
    """
    Show the resolved tables as a 4-column DataFrame (fqn, catalog, schema, table).

    Uses the notebook's display() when available, otherwise DataFrame.show().
    Every row is padded to four values so names that are not 3-part FQNs (such as
    pipeline flow names) no longer make createDataFrame fail on a width mismatch.
    """
    rows = [(fqn, *_fqn_parts(fqn)) for fqn in fqns]
    df = spark.createDataFrame(rows, "fqn string, catalog string, schema string, table string")
    print(f"\n=== {title} ({len(fqns)}) ===")
    try:
        display(df)
    except NameError:
        df.show(len(fqns), truncate=False)


class RuleGenerator:
    """
    Profile Databricks tables with DQX and persist the generated data-quality checks.

    ``run`` discovers tables (by pipeline, catalog, schema, or an explicit FQN list /
    YAML table list), profiles each with ``DQProfiler``, converts the profiler rules
    into ``sql_expression`` checks via ``DQDltGenerator`` and saves them either as
    YAML (``output_format="yaml"``) or by appending to a checks table
    (``output_format="table"``).
    """

    def __init__(
        self,
        mode: str,                       # "pipeline" | "catalog" | "schema" | "table"
        name_param: str,                 # pipeline CSV | catalog | catalog.schema | table FQN CSV | YAML path
        output_format: str,              # "yaml" | "table"
        output_location: str,            # yaml: folder or file; table: catalog.schema.table
        profile_options: Dict[str, Any],
        exclude_pattern: Optional[str] = None,     # e.g. ".tmp_*"
        created_by: Optional[str] = "LMG",
        columns: Optional[List[str]] = None,       # None => whole table (only if mode=="table")
        run_config_name: str = "default",          # DQX run group tag
        criticality: str = "warn",                 # "warn" | "error"
        key_order: Literal["engine", "custom"] = "custom",  # "engine" uses DQX save; "custom" enforces key order
        include_table_name: bool = True,           # include table_name in each rule dict
    ):
        """
        Store the run configuration and validate the output settings.

        Raises:
            RuntimeError: if no active Spark session exists.
            ValueError: if ``output_format`` is not 'yaml'/'table', or YAML output has
                no ``output_location``.
        """
        self.mode = mode.lower().strip()
        self.name_param = name_param
        self.output_format = output_format.lower().strip()
        self.output_location = output_location
        self.profile_options = profile_options or {}
        self.exclude_pattern = exclude_pattern
        self.created_by = created_by  # printed with the run parameters; not written into checks
        self.columns = columns
        self.run_config_name = run_config_name
        self.criticality = criticality
        self.key_order = key_order
        self.include_table_name = include_table_name

        self.spark = SparkSession.getActiveSession()
        if not self.spark:
            raise RuntimeError("No active Spark session found. Run this in a Databricks notebook.")

        if self.output_format not in {"yaml", "table"}:
            raise ValueError("output_format must be 'yaml' or 'table'.")
        if self.output_format == "yaml" and not self.output_location:
            raise ValueError("When output_format='yaml', provide output_location (folder or file).")

    # ---------- profile options: pass via 'options' kwarg, warn on unknowns ----------
    def _profile_call_kwargs(self) -> Dict[str, Any]:
        """
        Build kwargs for DQProfiler.profile / profile_table.
        We pass:
          - cols=self.columns (when provided)
          - options=self.profile_options (dict; profiler reads keys internally)
        We only WARN on keys not in the current documented set; we do not drop them.
        """
        kwargs: Dict[str, Any] = {}
        if self.columns is not None:
            kwargs["cols"] = self.columns
        if self.profile_options:
            unknown = sorted(set(self.profile_options) - DOC_SUPPORTED_KEYS)
            if unknown:
                print(f"[INFO] Profiling options not in current docs (passing through anyway): {unknown}")
            kwargs["options"] = self.profile_options
        return kwargs

    @staticmethod
    def _adapt_profile_kwargs(func: Any, kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """
        Rename profiler kwargs to the keyword names ``func`` actually accepts.

        The keyword names of DQProfiler.profile appear to differ between DQX releases
        ('cols' vs 'columns', 'options' vs 'opts' — verify against the installed
        version); passing an unknown one raises TypeError, which ``run`` would report
        as a profiling failure for every table. Returns a new dict; ``kwargs`` is not
        modified. If the signature cannot be inspected, a copy is returned unchanged.
        """
        import inspect

        adapted = dict(kwargs)
        try:
            params = inspect.signature(func).parameters
        except (TypeError, ValueError):
            return adapted
        for name, alias in (("cols", "columns"), ("options", "opts")):
            if name in adapted and name not in params and alias in params:
                adapted[alias] = adapted.pop(name)
        return adapted

    # ---------- discovery ----------
    def _exclude_tables_by_pattern(self, fq_tables: List[str]) -> List[str]:
        """Drop tables whose last name component matches ``exclude_pattern`` (see glob_to_regex)."""
        if not self.exclude_pattern:
            return fq_tables
        pattern = re.compile(glob_to_regex(self.exclude_pattern))
        filtered = [fq for fq in fq_tables if not pattern.match(fq.split('.')[-1])]
        print(f"[INFO] Excluded {len(fq_tables) - len(filtered)} tables by pattern '{self.exclude_pattern}'")
        return filtered

    def _print_parameters(self) -> None:
        """Print the configuration used for this run."""
        print("\n===== PARAMETERS PASSED THIS RUN =====")
        print(f"mode:             {self.mode}")
        print(f"name_param:       {self.name_param}")
        print(f"output_format:    {self.output_format}")
        print(f"output_location:  {self.output_location}")
        print(f"exclude_pattern:  {self.exclude_pattern}")
        print(f"created_by:       {self.created_by}")
        print(f"columns:          {self.columns}")
        print(f"run_config_name:  {self.run_config_name}")
        print(f"criticality:      {self.criticality}")
        print(f"key_order:        {self.key_order}")
        print(f"include_table_name: {self.include_table_name}")
        print("profile_options:")
        for k, v in self.profile_options.items():
            print(f"  {k}: {v}")
        print("======================================\n")

    def _list_schema_tables(self, catalog: str, schema: str) -> List[str]:
        """Return FQNs of the tables in ``catalog.schema``; session temp views are skipped."""
        rows = self.spark.sql(f"SHOW TABLES IN {catalog}.{schema}").collect()
        # Temp views show up in SHOW TABLES but do not live in this schema, so
        # 'catalog.schema.<view>' would not be readable anyway.
        return [
            f"{catalog}.{schema}.{row.tableName}"
            for row in rows
            if not getattr(row, "isTemporary", False)
        ]

    def _discover_pipeline_tables(self) -> List[str]:
        """Return flow names seen in the latest update's events of each named pipeline."""
        print("Searching for pipeline output tables...")
        ws = WorkspaceClient()
        pipelines = [p.strip() for p in self.name_param.split(",") if p.strip()]
        print(f"Pipelines passed: {pipelines}")
        # List the workspace's pipelines once instead of once per requested name.
        all_pipelines = list(ws.pipelines.list_pipelines()) if pipelines else []
        found: List[str] = []
        for pipeline_name in pipelines:
            print(f"Finding output tables for pipeline: {pipeline_name}")
            pl = next((p for p in all_pipelines if p.name == pipeline_name), None)
            if not pl:
                raise RuntimeError(f"Pipeline '{pipeline_name}' not found via SDK.")
            if not pl.latest_updates:
                raise RuntimeError(f"Pipeline '{pipeline_name}' has no updates to read output tables from.")
            latest_update = pl.latest_updates[0].update_id
            events = ws.pipelines.list_pipeline_events(pipeline_id=pl.pipeline_id, max_results=250)
            for ev in events:
                flow_name = getattr(ev.origin, "flow_name", None)
                if flow_name and getattr(ev.origin, "update_id", None) == latest_update:
                    found.append(flow_name)
        return found

    def _discover_catalog_tables(self) -> List[str]:
        """Return FQNs of all tables in every user schema of the catalog in ``name_param``."""
        print("Searching for tables in catalog...")
        catalog = self.name_param.strip()
        found: List[str] = []
        for row in self.spark.sql(f"SHOW SCHEMAS IN {catalog}").collect():
            # Read the first column by position: it is 'databaseName' on Databricks and
            # 'namespace' in OSS Spark 3.x, so attribute access is not portable.
            schema = row[0]
            if schema.lower() == "information_schema":
                continue  # Unity Catalog system views, not user data
            found += self._list_schema_tables(catalog, schema)
        return found

    def _discover_schema_tables(self) -> List[str]:
        """Return FQNs of all tables in the 'catalog.schema' given in ``name_param``."""
        print("Searching for tables in schema...")
        if self.name_param.count(".") != 1:
            raise ValueError("For 'schema' mode, name_param must be catalog.schema")
        catalog, schema = (part.strip() for part in self.name_param.strip().split("."))
        return self._list_schema_tables(catalog, schema)

    def _discover_listed_tables(self) -> List[str]:
        """Return FQNs from a YAML table-list path or a comma-separated FQN list."""
        print("Profiling one or more specific tables...")
        if _is_yaml_path(self.name_param):
            print(f"[INFO] name_param is a YAML list → {self.name_param}")
            return _discover_tables_from_yaml_by_comments(self.name_param)
        tables = [t.strip() for t in self.name_param.split(",") if t.strip()]
        for t in tables:
            if t.count(".") != 2:
                raise ValueError(f"Table name '{t}' must be fully qualified (catalog.schema.table)")
        return tables

    def _discover_tables(self) -> List[str]:
        """
        Resolve the de-duplicated, sorted list of tables to process.

        Prints the run parameters, validates ``mode``/``columns``, discovers tables for
        the mode, applies ``exclude_pattern`` and previews the final list.

        Raises:
            ValueError: invalid mode, ``columns`` outside table mode, or malformed names.
            RuntimeError: a named pipeline is missing or has no updates.
        """
        self._print_parameters()

        allowed_modes = {"pipeline", "catalog", "schema", "table"}
        if self.mode not in allowed_modes:
            raise ValueError(f"Invalid mode '{self.mode}'. Must be one of: {sorted(allowed_modes)}.")
        if self.columns is not None and self.mode != "table":
            raise ValueError("The 'columns' parameter can only be used in mode='table'.")

        discoverers = {
            "pipeline": self._discover_pipeline_tables,
            "catalog": self._discover_catalog_tables,
            "schema": self._discover_schema_tables,
            "table": self._discover_listed_tables,
        }
        discovered = discoverers[self.mode]()

        print("\nRunning exclude pattern filtering (if any)...")
        discovered = self._exclude_tables_by_pattern(discovered)
        # De-duplicate before previewing so the preview matches what will be processed
        # (pipeline events repeat the same flow name many times).
        discovered = sorted(set(discovered))

        # Interactive preview instead of plain prints
        _display_table_preview(self.spark, discovered, title="Final table list to generate DQX rules for")

        print("==========================================\n")
        return discovered

    # ---------- storage config helpers ----------
    @staticmethod
    def _infer_file_storage_config(file_path: str):
        """Pick the DQX file storage config class from the path prefix."""
        if file_path.startswith("/Volumes/"):
            return VolumeFileChecksStorageConfig(location=file_path)
        if file_path.startswith("/"):
            return WorkspaceFileChecksStorageConfig(location=file_path)
        return FileChecksStorageConfig(location=file_path)

    @staticmethod
    def _table_storage_config(table_fqn: str, run_config_name: Optional[str] = None, mode: str = "append"):
        """Build a DQX table storage config for ``table_fqn``."""
        return TableChecksStorageConfig(location=table_fqn, run_config_name=run_config_name, mode=mode)

    @staticmethod
    def _workspace_files_upload(path: str, payload: bytes) -> None:
        """Upload ``payload`` to ``path`` via the SDK Files API, overwriting any existing file."""
        wc = WorkspaceClient()
        try:
            wc.files.upload(file_path=path, contents=payload, overwrite=True)  # newer SDK
        except TypeError:
            wc.files.upload(path=path, contents=payload, overwrite=True)       # older SDK

    @staticmethod
    def _ensure_parent(path: str) -> None:
        """Create parent dir for local paths."""
        parent = os.path.dirname(path)
        if parent and not os.path.exists(parent):
            os.makedirs(parent, exist_ok=True)

    @staticmethod
    def _ensure_dbfs_parent(dbutils, path: str) -> None:
        """Create the parent directory of file ``path`` with dbutils.fs.mkdirs (no-op if none)."""
        if "/" not in path:
            return
        parent = path.rsplit("/", 1)[0]
        # Skip bare scheme roots such as 'dbfs:' (from 'dbfs:/file.yaml').
        if parent and not parent.endswith(":"):
            dbutils.fs.mkdirs(parent)

    @staticmethod
    def _dbfs_target(path: str) -> str:
        """
        Map a DBFS/Volumes path to the 'dbfs:' URI form passed to dbutils.fs.

        '/dbfs/x' is the local mount of 'dbfs:/x', so it maps to 'dbfs:/x' (not
        'dbfs:/dbfs/x'); '/Volumes/x' maps to 'dbfs:/Volumes/x'; 'dbfs:' URIs pass through.
        """
        if path.startswith("dbfs:"):
            return path
        if path.startswith("/dbfs/"):
            return "dbfs:/" + path[len("/dbfs/"):]
        return f"dbfs:{path}" if path.startswith("/") else f"dbfs:/{path}"

    # ---------- DQX check shaping ----------
    def _dq_constraint_to_check(
        self,
        rule_name: str,
        constraint_sql: str,
        table_name: str,
        criticality: str,
        run_config_name: str
    ) -> Dict[str, Any]:
        """
        Convert a profiler constraint (SQL) into a DQX check dict.
        Key order: table_name (only when include_table_name), name, criticality,
        run_config_name, check (insertion order is what the YAML writer emits).
        """
        d = {
            "name": rule_name,
            "criticality": criticality,
            "run_config_name": run_config_name,
            "check": {
                "function": "sql_expression",
                "arguments": {
                    "expression": constraint_sql,
                    "name": rule_name,
                }
            },
        }
        if self.include_table_name:
            d = {"table_name": table_name, **d}
        return d

    # ---------- YAML writers ----------
    def _write_yaml_ordered(self, checks: List[Dict[str, Any]], path: str) -> None:
        """
        Dump ``checks`` as YAML (dict insertion order preserved) and write it to ``path``:
          - /Volumes/... | dbfs:/... | /dbfs/... -> dbutils.fs.put (parent dir created first)
          - other absolute paths, e.g. /Shared/... (workspace files) -> Files API
          - relative/local path -> os.makedirs + open(...)
        """
        yaml_str = yaml.safe_dump(checks, sort_keys=False, default_flow_style=False)

        # DBFS / Volumes
        if path.startswith(("dbfs:/", "/dbfs/", "/Volumes/")):
            try:
                from databricks.sdk.runtime import dbutils
            except Exception as e:
                raise RuntimeError("dbutils is required to write to DBFS/Volumes.") from e
            target = self._dbfs_target(path)
            # Pass the file path itself: _ensure_dbfs_parent derives the parent directory.
            self._ensure_dbfs_parent(dbutils, target)
            dbutils.fs.put(target, yaml_str, True)
            print(f"[RUN] Wrote ordered YAML to {path}")
            return

        # Workspace files
        if path.startswith("/"):
            self._workspace_files_upload(path, yaml_str.encode("utf-8"))
            print(f"[RUN] Wrote ordered YAML to workspace file: {path}")
            return

        # Local (driver) relative/absolute filesystem (Repos-friendly)
        full_path = os.path.abspath(path)
        self._ensure_parent(full_path)
        with open(full_path, "w", encoding="utf-8") as f:
            f.write(yaml_str)
        print(f"[RUN] Wrote ordered YAML to local path: {full_path}")

    @staticmethod
    def _is_yaml_file_target(output_location: str) -> bool:
        """True when ``output_location`` names a single YAML file rather than a folder."""
        return output_location.endswith((".yaml", ".yml"))

    @staticmethod
    def _yaml_file_for_table(output_location: str, fq_table: str, disambiguate: bool = False) -> str:
        """
        Return the per-table YAML path inside folder ``output_location``.

        Normally '<table>.yaml'; with ``disambiguate`` (the same table name occurs in
        several schemas/catalogs this run) the full '<catalog>.<schema>.<table>.yaml'
        is used so the files do not overwrite each other.
        """
        cat, sch, tab = fq_table.split(".")
        stem = f"{cat}.{sch}.{tab}" if disambiguate else tab
        return f"{output_location.rstrip('/')}/{stem}.yaml"

    def _save_yaml_checks(self, checks: List[Dict[str, Any]], path: str, dq_engine) -> None:
        """Save checks to a YAML ``path`` via DQX ("engine") or the ordered writer ("custom")."""
        if self.key_order == "engine":
            cfg = self._infer_file_storage_config(path)
            print(f"[RUN] Saving {len(checks)} checks via DQX to: {path}")
            dq_engine.save_checks(checks, config=cfg)
        else:  # "custom"
            print(f"[RUN] Saving {len(checks)} checks with strict key order to: {path}")
            self._write_yaml_ordered(checks, path)

    def _generate_checks_for_table(
        self, fq_table: str, profiler, generator, call_kwargs: Dict[str, Any]
    ) -> List[Dict[str, Any]]:
        """
        Profile one table and convert the generated rules into DQX check dicts.

        Returns [] (after printing a warning/info line) when the table is unreadable,
        profiling/rule generation fails, or no rules were produced.
        """
        # Verify readability
        try:
            print(f"[RUN] Checking table readability: {fq_table}")
            self.spark.table(fq_table).limit(1).collect()
        except Exception as e:
            print(f"[WARN] Table {fq_table} not readable in Spark: {e}")
            return []

        df = self.spark.table(fq_table)
        try:
            print(f"[RUN] Profiling and generating rules for: {fq_table}")
            # DataFrame profiling with options and optional columns
            _summary_stats, profiles = profiler.profile(df, **call_kwargs)
            rules_dict = generator.generate_dlt_rules(profiles, language="Python_Dict")
        except Exception as e:
            print(f"[WARN] Profiling failed for {fq_table}: {e}")
            return []

        checks = [
            self._dq_constraint_to_check(
                rule_name=rule_name,
                constraint_sql=constraint,
                table_name=fq_table,
                criticality=self.criticality,
                run_config_name=self.run_config_name,
            )
            for rule_name, constraint in (rules_dict or {}).items()
        ]
        if not checks:
            print(f"[INFO] No checks generated for {fq_table}.")
        return checks

    # ---------- main ----------
    def run(self, raise_on_error: bool = False) -> int:
        """
        Discover tables, profile them, and save the generated checks.

        Per-table problems (invalid name, unreadable table, profiling failure, no
        checks) are reported and skipped. Any other failure is reported with a
        traceback and re-raised only when ``raise_on_error`` is True; the default
        keeps the original report-and-continue behaviour.

        When ``output_location`` is a single '.yaml'/'.yml' file, checks from all
        tables are combined and written once (previously each table overwrote the
        previous one). Keep include_table_name=True so combined checks stay attributable.

        Returns:
            Number of checks saved.
        """
        import traceback
        from collections import Counter

        total_checks = 0
        try:
            tables = self._discover_tables()
            print("[RUN] Beginning DQX rule generation on these tables:")
            for t in tables:
                print(f"  {t}")
            print("==========================================\n")

            base_kwargs = self._profile_call_kwargs()
            print("[RUN] Profiler call kwargs:")
            print(f"  cols:    {base_kwargs.get('cols')}")
            print(f"  options: {json.dumps(base_kwargs.get('options', {}), indent=2)}")

            # Build the SDK client and DQX helpers once per run, not once per table.
            ws = WorkspaceClient()
            dq_engine = DQEngine(ws)
            profiler = DQProfiler(ws)
            generator = DQDltGenerator(ws)
            call_kwargs = self._adapt_profile_kwargs(profiler.profile, base_kwargs)

            short_name_counts = Counter(t.split(".")[-1] for t in tables if t.count(".") == 2)
            single_yaml_file = self.output_format == "yaml" and self._is_yaml_file_target(self.output_location)
            combined: List[Dict[str, Any]] = []

            for fq_table in tables:
                if fq_table.count(".") != 2:
                    print(f"[WARN] Skipping invalid table name: {fq_table}")
                    continue

                checks = self._generate_checks_for_table(fq_table, profiler, generator, call_kwargs)
                if not checks:
                    continue

                if self.output_format == "yaml":
                    if single_yaml_file:
                        combined.extend(checks)
                        print(f"[RUN] Collected {len(checks)} checks for {fq_table} (written to {self.output_location} at the end)")
                        continue
                    short_name = fq_table.split(".")[-1]
                    path = self._yaml_file_for_table(
                        self.output_location, fq_table, disambiguate=short_name_counts[short_name] > 1
                    )
                    self._save_yaml_checks(checks, path, dq_engine)
                else:  # table sink
                    cfg = self._table_storage_config(
                        table_fqn=self.output_location,
                        run_config_name=self.run_config_name,
                        mode="append"
                    )
                    print(f"[RUN] Appending {len(checks)} checks to table: {self.output_location} (run_config_name={self.run_config_name})")
                    dq_engine.save_checks(checks, config=cfg)
                total_checks += len(checks)

            if combined:
                self._save_yaml_checks(combined, self.output_location, dq_engine)
                total_checks += len(combined)

            print(f"[RUN] {'Successfully saved' if total_checks else 'No'} checks. Count: {total_checks}")
        except Exception as e:
            print(f"[ERROR] Rule generation failed: {e}")
            traceback.print_exc()
            if raise_on_error:
                raise
        return total_checks


# -------------------- Usage examples --------------------
if __name__ == "__main__":
    # Profiler options for the examples below. Keys outside DOC_SUPPORTED_KEYS are
    # still passed through; RuleGenerator prints an INFO line listing them.
    profile_options = {
        # Sampling
        "sample_fraction": 0.3,
        "sample_seed": 42,
        "limit": 1000,
        # Outliers
        "remove_outliers": True,
        "outlier_columns": [],
        "num_sigmas": 3,
        # Nulls / empties
        "max_null_ratio": 0.01,
        "trim_strings": True,
        "max_empty_ratio": 0.01,
        # Distincts → is_in
        "distinct_ratio": 0.05,
        "max_in_count": 10,
        # Rounding
        "round": True,
        # Optional keys not in the current docs (uncomment to pass them through):
        # "include_histograms": False,
        # "min_length": None,
        # "max_length": None,
        # "min_value": None,
        # "max_value": None,
        # "profile_types": None,
    }

    # Arguments shared by both examples.
    common_args = dict(
        mode="table",                                   # "pipeline" | "catalog" | "schema" | "table"
        output_format="yaml",                           # "yaml" | "table"
        profile_options=profile_options,
        columns=None,                                   # None => whole table (only valid when mode=="table")
        exclude_pattern=None,                           # e.g. ".tmp_*"
        created_by="LMG",
        run_config_name="default",
        criticality="error",
        key_order="custom",                             # "engine" or "custom"
        include_table_name=True,
    )

    # Example A: single table (disabled; uncomment to run)
    # RuleGenerator(
    #     name_param="dq_prd.monitoring.job_run_audit",   # depends on mode
    #     output_location="dqx_checks",                   # yaml dir (local) OR "/Shared/..." OR "dbfs:/..." OR "/Volumes/..."
    #     **common_args,
    # ).run()

    # Example B: YAML list of tables. Catalog/schema defaults are parsed from the
    # YAML file's header comments; one YAML file per table is written to output_location.
    #   Notebook path: src/dqx/00_generate_dqx_checks
    RuleGenerator(
        name_param="resources/dqx_checks_generated/audit/audit_table_list-gold.yaml",
        output_location="resources/dqx_checks_generated/audit/audit_generated_dqx_checks",  # one YAML per table
        **common_args,
    ).run()

In [0]:
# dqx_rule_generator.py
# Full file — adds YAML list support via header-comment parsing + display() preview.
# NOTHING REMOVED: all params/options stay. Renamed yaml_key_order -> key_order.
# Single-table & CSV FQNs still work. One YAML file per table is written in output_location.

import os
import re
import io
import json
from typing import List, Optional, Dict, Any, Literal, Tuple

import yaml
from databricks.sdk import WorkspaceClient
from databricks.labs.dqx.profiler.profiler import DQProfiler
from databricks.labs.dqx.profiler.dlt_generator import DQDltGenerator
from databricks.labs.dqx.engine import DQEngine
from databricks.labs.dqx.config import (
    FileChecksStorageConfig,
    WorkspaceFileChecksStorageConfig,
    TableChecksStorageConfig,
    VolumeFileChecksStorageConfig,
)
from pyspark.sql import SparkSession


def glob_to_regex(glob_pattern: str) -> str:
    """
    Translate a dot-prefixed exclude glob (e.g. '.tmp_*') into an anchored regex.

    The leading '.' is a required marker and is dropped. Each '*' becomes '.*';
    every other character is matched literally.

    Raises:
        ValueError: if the pattern is empty/None or does not start with '.'.
    """
    has_marker = bool(glob_pattern) and glob_pattern.startswith('.')
    if not has_marker:
        raise ValueError("Exclude pattern must start with a dot, e.g. '.tamarack_*'")
    body = re.escape(glob_pattern[1:]).replace(r'\*', '.*')
    return f"^{body}$"


# Profiler "options" keys that appear in the DQX docs. The set is only used to print an
# INFO line for extra keys; unknown keys are still passed through to the profiler.
DOC_SUPPORTED_KEYS = {
    # sampling
    "sample_fraction", "sample_seed", "limit",
    # outlier handling for min/max rules
    "remove_outliers", "outlier_columns", "num_sigmas",
    # null / empty-string thresholds
    "max_null_ratio", "trim_strings", "max_empty_ratio",
    # is_in rule thresholds
    "distinct_ratio", "max_in_count",
    # value rounding
    "round",
}

# ---------- YAML file support (header-comment parsing for defaults) ----------
# Path ends in '.yaml' or '.yml' (case-insensitive).
_YAML_PATH_RE = re.compile(r"\.(ya?ml)$", flags=re.I)
# Captures <catalog> from 'FROM <catalog>.information_schema.tables'.
_FROM_INFOSCHEMA = re.compile(r"FROM\s+([A-Za-z0-9_]+)\.information_schema\.tables", flags=re.I)
# Captures <schema> from "table_schema = '<schema>'" (single quotes only).
_TABLE_SCHEMA_EQ = re.compile(r"table_schema\s*=\s*'([^']+)'", flags=re.I)

def _is_yaml_path(path: str) -> bool:
    """Return True when ``path`` ends with '.yaml' or '.yml' (any case)."""
    return _YAML_PATH_RE.search(path) is not None

def _resolve_local_like_path(path: str, max_levels: int = 6) -> Optional[str]:
    """
    Resolve a (possibly repo-relative) path from a notebook's working directory.

    Tries, in order:
      - ``path`` as given (absolute, or relative to the cwd);
      - ``path`` joined onto the cwd, then onto each ancestor directory, checking at
        most ``max_levels`` directories in total (the cwd plus ``max_levels - 1`` parents).

    Args:
        path: file path to look for.
        max_levels: how many directories (starting with the cwd) to try; default 6.

    Returns:
        The first existing path (file or directory), or None if nothing matched.
    """
    if os.path.exists(path):
        return path
    base = os.getcwd()
    for _ in range(max_levels):
        candidate = os.path.abspath(os.path.join(base, path))
        if os.path.exists(candidate):
            return candidate
        parent = os.path.dirname(base)
        if parent == base:  # reached the filesystem root; nothing further up
            break
        base = parent
    return None

def _to_dbfs_uri(path: str) -> str:
    """
    Convert a DBFS/Volumes path into the 'dbfs:' URI form passed to dbutils.fs.

      - 'dbfs:/x'    -> unchanged
      - '/dbfs/x'    -> 'dbfs:/x'   ('/dbfs' is the local mount of the DBFS root, so it
                                     must not become 'dbfs:/dbfs/x')
      - '/Volumes/x' -> 'dbfs:/Volumes/x'
    """
    if path.startswith("dbfs:"):
        return path
    if path.startswith("/dbfs/"):
        return "dbfs:/" + path[len("/dbfs/"):]
    return f"dbfs:{path}" if path.startswith("/") else f"dbfs:/{path}"


def _read_text_any(path: str) -> str:
    """
    Read a small UTF-8 text file (e.g. a YAML table list) from one of:
      - dbfs:/..., /dbfs/..., /Volumes/...  -> dbutils.fs.head (first 10 MiB only)
      - other absolute paths (e.g. /Workspace/Repos/.../file.yaml) -> read from the
        driver filesystem when the file exists there, otherwise via the SDK Files API
      - relative paths -> resolved against the cwd and its ancestors

    Raises:
        RuntimeError: dbutils is unavailable for a DBFS/Volumes path.
        FileNotFoundError: a relative path could not be resolved to a file.
    """
    # DBFS / Volumes
    if path.startswith(("dbfs:/", "/dbfs/", "/Volumes/")):
        try:
            from databricks.sdk.runtime import dbutils
        except Exception as e:
            raise RuntimeError("dbutils is required to read from DBFS/Volumes") from e
        return dbutils.fs.head(_to_dbfs_uri(path), 10 * 1024 * 1024)

    # Absolute workspace file path (e.g. /Workspace/Repos/.../file.yaml)
    if path.startswith("/"):
        # /Workspace paths are usually readable through the driver's local filesystem,
        # so try that first and only fall back to the REST API when it is not there.
        if os.path.isfile(path):
            with open(path, "r", encoding="utf-8") as fh:
                return fh.read()
        wc = WorkspaceClient()
        try:
            # newer SDK signature
            resp = wc.files.download(file_path=path)
        except TypeError:
            # older SDK signature
            resp = wc.files.download(path=path)
        # FilesAPI.download returns a DownloadResponse whose body stream is `.contents`;
        # fall back to the response itself in case it is already file-like.
        stream = resp.contents if hasattr(resp, "contents") else resp
        return stream.read().decode("utf-8")

    # Relative/local resolution (repo-friendly)
    resolved = _resolve_local_like_path(path)
    if resolved and os.path.isfile(resolved):
        with open(resolved, "r", encoding="utf-8") as fh:
            return fh.read()

    # Nothing worked → helpful error
    msg = [
        f"Could not find YAML at '{path}'.",
        f"cwd={os.getcwd()}",
        "Hints:",
        "  - If this file is in your repo, pass a path relative to the notebook or the repo root.",
        "  - Or pass an absolute workspace path like '/Workspace/Repos/.../file.yaml'.",
        "  - Or use a DBFS/Volumes path like 'dbfs:/...' or '/Volumes/...'.",
    ]
    raise FileNotFoundError("\n".join(msg))

def _parse_global_hints_from_comments(text: str) -> Tuple[Optional[str], Optional[str]]:
    """
    Extract default (catalog, schema) hints from raw YAML text.

    The catalog comes from the first capture group of ``_FROM_INFOSCHEMA`` (intended
    for 'FROM <catalog>.information_schema.tables'). The schema comes from the first
    capture group of ``_TABLE_SCHEMA_EQ`` (intended for "table_schema = '<schema>'").
    Both regexes are module-level and defined elsewhere.

    The whole text is searched. Placing the hints in header comments is a convention
    of the YAML files, not something enforced here.

    Returns:
        (catalog, schema); an element is None when its pattern finds no match.
    """
    def _first_capture(pattern) -> Optional[str]:
        found = pattern.search(text)
        return None if found is None else found.group(1)

    catalog_hint = _first_capture(_FROM_INFOSCHEMA)
    schema_hint = _first_capture(_TABLE_SCHEMA_EQ)
    return catalog_hint, schema_hint

def _ensure_fqns(names: List[str], hint_catalog: Optional[str], hint_schema: Optional[str]) -> List[str]:
    """
    Accept: catalog.schema.table | schema.table | table
    Use comment-derived defaults to fill missing parts. Dotted entries override per item.
    """
    out: List[str] = []
    for n in names:
        n = n.strip()
        if not n:
            continue
        parts = n.split(".")
        if len(parts) == 3:
            out.append(n)
        elif len(parts) == 2:
            if not hint_catalog:
                raise ValueError(f"'{n}' lacks catalog; add it to the item or provide a comment default.")
            out.append(f"{hint_catalog}.{n}")
        elif len(parts) == 1:
            if not (hint_catalog and hint_schema):
                raise ValueError(f"'{n}' needs catalog & schema; set via comments or use dotted forms.")
            out.append(f"{hint_catalog}.{hint_schema}.{n}")
        else:
            raise ValueError(f"Unrecognized table format: {n}")
    bad = [t for t in out if t.count(".") != 2]
    if bad:
        raise ValueError(f"Invalid FQN(s) after resolution: {bad}")
    return sorted(set(out))

def _discover_tables_from_yaml_by_comments(yaml_path: str) -> List[str]:
    """
    Resolve the table list stored in a YAML file into fully-qualified names.

    The file must be a mapping. The first of the keys ``table_name``, ``tables``,
    ``table_names`` or ``list`` whose value is a list supplies the names; falsy
    items are dropped. Missing catalog/schema parts are filled from hints that
    ``_parse_global_hints_from_comments`` extracts from the raw text. Dotted
    entries override those hints.

    Raises:
        ValueError: the YAML root is not a mapping, or no non-empty list is found.
    """
    raw_text = _read_text_any(yaml_path)
    catalog_hint, schema_hint = _parse_global_hints_from_comments(raw_text)

    document = yaml.safe_load(io.StringIO(raw_text))
    if not isinstance(document, dict):
        raise ValueError(f"YAML must contain a mapping with a list; got: {type(document).__name__}")

    candidate_keys = ("table_name", "tables", "table_names", "list")
    list_key = next((k for k in candidate_keys if isinstance(document.get(k), list)), None)
    table_names = [] if list_key is None else [str(item).strip() for item in document[list_key] if item]
    if not table_names:
        raise ValueError(f"No table list found in YAML: {yaml_path}")

    resolved = _ensure_fqns(table_names, catalog_hint, schema_hint)
    print(f"[INFO] Parsed defaults from comments: catalog={catalog_hint} schema={schema_hint}")
    return resolved

def _display_table_preview(spark: SparkSession, fqns: List[str], title: str = "Resolved Tables") -> None:
    """
    Use display() (Databricks) for a clean, interactive preview instead of printing.
    """
    rows = [(f, *f.split(".")) for f in fqns]
    df = spark.createDataFrame(rows, "fqn string, catalog string, schema string, table string")
    print(f"\n=== {title} ({len(fqns)}) ===")
    try:
        display(df)
    except NameError:
        df.show(len(fqns), truncate=False)


class RuleGenerator:
    """
    Discover tables, profile them with DQX, and persist the generated data-quality checks.

    Discovery depends on ``mode``:
      - "pipeline": flow names from the latest update of each pipeline in the CSV ``name_param``
      - "catalog":  every table of every schema in catalog ``name_param`` (``information_schema`` skipped)
      - "schema":   every table in ``name_param`` given as ``catalog.schema``
      - "table":    a CSV of ``catalog.schema.table`` names, or a YAML file listing tables

    Every profiled constraint becomes a ``sql_expression`` check. Checks go to one of:
      - YAML: one ``<table>.yaml`` per table inside a folder, or every table into a
        single ``.yaml``/``.yml`` file
      - a checks table (``output_format="table"``), appended to
    """

    _VALID_CRITICALITY = ("warn", "error")
    _VALID_KEY_ORDER = ("engine", "custom")

    def __init__(
        self,
        mode: str,                       # "pipeline" | "catalog" | "schema" | "table"
        name_param: str,                 # pipeline CSV | catalog | catalog.schema | table FQN CSV | YAML path
        output_format: str,              # "yaml" | "table"
        output_location: str,            # yaml: folder or file; table: catalog.schema.table
        profile_options: Dict[str, Any],
        exclude_pattern: Optional[str] = None,     # e.g. ".tmp_*"
        created_by: Optional[str] = "LMG",
        columns: Optional[List[str]] = None,       # None => whole table (only if mode=="table")
        run_config_name: str = "default",          # DQX run group tag
        criticality: str = "warn",                 # "warn" | "error"
        key_order: Literal["engine", "custom"] = "custom",  # "engine" uses DQX save; "custom" enforces key order
        include_table_name: bool = True,           # include table_name in each rule dict
    ):
        """
        Store and validate the configuration.

        Raises:
            RuntimeError: no active Spark session.
            ValueError: ``output_format``, ``criticality`` or ``key_order`` is unsupported,
                or ``output_location`` is missing while writing YAML.
        """
        self.mode = mode.lower().strip()
        self.name_param = name_param
        self.output_format = output_format.lower().strip()
        self.output_location = output_location
        self.profile_options = profile_options or {}
        self.exclude_pattern = exclude_pattern
        self.created_by = created_by  # stored for reference; not read by this class
        self.columns = columns
        self.run_config_name = run_config_name
        self.criticality = criticality
        self.key_order = key_order
        self.include_table_name = include_table_name

        self.spark = SparkSession.getActiveSession()
        if not self.spark:
            raise RuntimeError("No active Spark session found. Run this in a Databricks notebook.")

        if self.output_format not in {"yaml", "table"}:
            raise ValueError("output_format must be 'yaml' or 'table'.")
        if self.output_format == "yaml" and not self.output_location:
            raise ValueError("When output_format='yaml', provide output_location (folder or file).")
        # Fail fast instead of after profiling every table.
        if self.criticality not in self._VALID_CRITICALITY:
            raise ValueError(f"criticality must be one of {self._VALID_CRITICALITY}; got {self.criticality!r}.")
        if self.key_order not in self._VALID_KEY_ORDER:
            raise ValueError(f"key_order must be one of {self._VALID_KEY_ORDER}; got {self.key_order!r}.")

    # ---------- profile options: pass via 'options' kwarg, warn on unknowns ----------
    def _profile_call_kwargs(self) -> Dict[str, Any]:
        """
        Build kwargs for DQProfiler.profile / profile_table.

        The kwargs contain:
          - cols=self.columns (only when provided)
          - options=a copy of self.profile_options (the profiler reads the keys internally)

        Keys outside DOC_SUPPORTED_KEYS trigger an INFO message but are still passed through.
        """
        kwargs: Dict[str, Any] = {}
        if self.columns is not None:
            kwargs["cols"] = self.columns
        if self.profile_options:
            unknown = sorted(set(self.profile_options) - DOC_SUPPORTED_KEYS)
            if unknown:
                print(f"[INFO] Profiling options not in current docs (passing through anyway): {unknown}")
            # Copy so the profiler cannot mutate the caller's dict between tables.
            kwargs["options"] = dict(self.profile_options)
        return kwargs

    # ---------- discovery ----------
    @staticmethod
    def _quote_ident(name: str) -> str:
        """
        Backtick-quote each dot-separated part of a SQL identifier.

        This makes names containing characters such as '-' usable in SHOW statements.
        Parts already wrapped in backticks are left as they are; embedded backticks
        are doubled.
        """
        quoted: List[str] = []
        for part in name.split("."):
            part = part.strip()
            if len(part) >= 2 and part.startswith("`") and part.endswith("`"):
                quoted.append(part)
            else:
                quoted.append("`" + part.replace("`", "``") + "`")
        return ".".join(quoted)

    def _exclude_tables_by_pattern(self, fq_tables: List[str]) -> List[str]:
        """Drop tables whose bare table name matches the glob ``exclude_pattern`` (if set)."""
        if not self.exclude_pattern:
            return fq_tables
        regex = glob_to_regex(self.exclude_pattern)
        pattern = re.compile(regex)
        filtered = []
        for fq in fq_tables:
            tbl = fq.split('.')[-1]
            if not pattern.match(tbl):
                filtered.append(fq)
        print(f"[INFO] Excluded {len(fq_tables) - len(filtered)} tables by pattern '{self.exclude_pattern}'")
        return filtered

    def _discover_tables(self) -> List[str]:
        """
        Print the run parameters and resolve the tables to profile for ``self.mode``.

        Returns:
            Sorted, de-duplicated table names, after exclude-pattern filtering.

        Raises:
            ValueError: an invalid mode, ``columns`` used outside table mode, or a malformed ``name_param``.
            RuntimeError: a named pipeline is not found or has no updates.
        """
        print("\n===== PARAMETERS PASSED THIS RUN =====")
        print(f"mode:             {self.mode}")
        print(f"name_param:       {self.name_param}")
        print(f"output_format:    {self.output_format}")
        print(f"output_location:  {self.output_location}")
        print(f"exclude_pattern:  {self.exclude_pattern}")
        print(f"created_by:       {self.created_by}")
        print(f"columns:          {self.columns}")
        print(f"run_config_name:  {self.run_config_name}")
        print(f"criticality:      {self.criticality}")
        print(f"key_order:        {self.key_order}")
        print(f"include_table_name: {self.include_table_name}")
        print("profile_options:")
        for k, v in self.profile_options.items():
            print(f"  {k}: {v}")
        print("======================================\n")

        allowed_modes = {"pipeline", "catalog", "schema", "table"}
        if self.mode not in allowed_modes:
            raise ValueError(f"Invalid mode '{self.mode}'. Must be one of: {sorted(allowed_modes)}.")
        if self.columns is not None and self.mode != "table":
            raise ValueError("The 'columns' parameter can only be used in mode='table'.")

        discovered: List[str] = []

        if self.mode == "pipeline":
            print("Searching for pipeline output tables...")
            ws = WorkspaceClient()
            pipelines = [p.strip() for p in self.name_param.split(",") if p.strip()]
            print(f"Pipelines passed: {pipelines}")
            # List once; the result does not depend on which pipeline we are looking for.
            all_pipelines = list(ws.pipelines.list_pipelines()) if pipelines else []
            for pipeline_name in pipelines:
                print(f"Finding output tables for pipeline: {pipeline_name}")
                pl = next((p for p in all_pipelines if p.name == pipeline_name), None)
                if not pl:
                    raise RuntimeError(f"Pipeline '{pipeline_name}' not found via SDK.")
                if not pl.latest_updates:
                    raise RuntimeError(f"Pipeline '{pipeline_name}' has no updates yet; nothing to profile.")
                latest_update = pl.latest_updates[0].update_id
                events = ws.pipelines.list_pipeline_events(pipeline_id=pl.pipeline_id, max_results=250)
                pipeline_tables = [
                    getattr(ev.origin, "flow_name", None)
                    for ev in events
                    if getattr(ev.origin, "update_id", None) == latest_update and getattr(ev.origin, "flow_name", None)
                ]
                discovered += [x for x in pipeline_tables if x]

        elif self.mode == "catalog":
            print("Searching for tables in catalog...")
            catalog = self.name_param.strip()
            # row[0] is the schema name whether the column is called 'namespace' or 'databaseName'.
            schema_rows = self.spark.sql(f"SHOW SCHEMAS IN {self._quote_ident(catalog)}").collect()
            for s in (row[0] for row in schema_rows):
                if s.lower() == "information_schema":
                    print(f"[INFO] Skipping system schema {catalog}.{s}")
                    continue
                tbls = self.spark.sql(
                    f"SHOW TABLES IN {self._quote_ident(catalog)}.{self._quote_ident(s)}"
                ).collect()
                discovered += [f"{catalog}.{s}.{row.tableName}" for row in tbls]

        elif self.mode == "schema":
            print("Searching for tables in schema...")
            if self.name_param.count(".") != 1:
                raise ValueError("For 'schema' mode, name_param must be catalog.schema")
            catalog, schema = (part.strip() for part in self.name_param.strip().split("."))
            tbls = self.spark.sql(
                f"SHOW TABLES IN {self._quote_ident(catalog)}.{self._quote_ident(schema)}"
            ).collect()
            discovered = [f"{catalog}.{schema}.{row.tableName}" for row in tbls]

        else:  # table
            print("Profiling one or more specific tables...")
            if _is_yaml_path(self.name_param):
                print(f"[INFO] name_param is a YAML list → {self.name_param}")
                discovered = _discover_tables_from_yaml_by_comments(self.name_param)
            else:
                tables = [t.strip() for t in self.name_param.split(",") if t.strip()]
                for t in tables:
                    if t.count(".") != 2:
                        raise ValueError(f"Table name '{t}' must be fully qualified (catalog.schema.table)")
                discovered = tables

        print("\nRunning exclude pattern filtering (if any)...")
        discovered = self._exclude_tables_by_pattern(discovered)
        # De-duplicate before previewing so the preview matches exactly what run() processes.
        discovered = sorted(set(discovered))

        # Interactive preview instead of plain prints
        _display_table_preview(self.spark, discovered, title="Final table list to generate DQX rules for")

        print("==========================================\n")
        return discovered

    # ---------- storage config helpers ----------
    @staticmethod
    def _infer_file_storage_config(file_path: str):
        """Pick the DQX file storage config matching where ``file_path`` lives."""
        if file_path.startswith("/Volumes/"):
            return VolumeFileChecksStorageConfig(location=file_path)
        if file_path.startswith("/dbfs/"):
            # Driver-local FUSE mount of DBFS: a plain file path, not a workspace file.
            return FileChecksStorageConfig(location=file_path)
        if file_path.startswith("/"):
            return WorkspaceFileChecksStorageConfig(location=file_path)
        return FileChecksStorageConfig(location=file_path)

    @staticmethod
    def _table_storage_config(table_fqn: str, run_config_name: Optional[str] = None, mode: str = "append"):
        """Build the DQX storage config for writing checks into a table."""
        return TableChecksStorageConfig(location=table_fqn, run_config_name=run_config_name, mode=mode)

    @staticmethod
    def _workspace_files_upload(path: str, payload: bytes) -> None:
        """Upload ``payload`` to ``path`` via the Files API, overwriting any existing file."""
        wc = WorkspaceClient()
        # upload() expects a binary stream. The positional file path works across SDK versions.
        wc.files.upload(path, io.BytesIO(payload), overwrite=True)

    @staticmethod
    def _ensure_parent(path: str) -> None:
        """Create parent dir for local paths."""
        parent = os.path.dirname(path)
        if parent and not os.path.exists(parent):
            os.makedirs(parent, exist_ok=True)

    @staticmethod
    def _ensure_dbfs_parent(dbutils, path: str) -> None:
        """Create the parent directory of the dbutils.fs path ``path``, if it has one below the root."""
        parent = path.rsplit("/", 1)[0] if "/" in path else ""
        # 'dbfs:/x.yaml' has parent 'dbfs:' (the root): nothing to create.
        if parent and not parent.endswith(":"):
            dbutils.fs.mkdirs(parent)

    @staticmethod
    def _to_dbfs_uri(path: str) -> str:
        """
        Map a DBFS/Volumes path to the URI handed to dbutils.fs.

          'dbfs:/x'    -> unchanged
          '/dbfs/x'    -> 'dbfs:/x'   ('/dbfs' is the driver-local mount of the DBFS root)
          '/Volumes/x' -> 'dbfs:/Volumes/x'
        """
        if path.startswith("dbfs:"):
            return path
        if path.startswith("/dbfs/"):
            return "dbfs:/" + path[len("/dbfs/"):]
        return f"dbfs:{path}" if path.startswith("/") else f"dbfs:/{path}"

    # ---------- DQX check shaping ----------
    def _dq_constraint_to_check(
        self,
        rule_name: str,
        constraint_sql: str,
        table_name: str,
        criticality: str,
        run_config_name: str
    ) -> Dict[str, Any]:
        """
        Convert a profiler constraint (SQL) into a DQX ``sql_expression`` check dict.

        Key order (dict insertion order): table_name (only if include_table_name),
        name, criticality, run_config_name, check.
        """
        d = {
            "name": rule_name,
            "criticality": criticality,
            "run_config_name": run_config_name,
            "check": {
                "function": "sql_expression",
                "arguments": {
                    "expression": constraint_sql,
                    "name": rule_name,
                }
            },
        }
        if self.include_table_name:
            d = {"table_name": table_name, **d}
        return d

    # ---------- YAML writers ----------
    def _yaml_path_for(self, table: str) -> str:
        """Return the YAML destination: ``output_location`` itself if it is a .yaml/.yml file, else ``<folder>/<table>.yaml``."""
        if self.output_location.endswith((".yaml", ".yml")):
            return self.output_location
        return f"{self.output_location.rstrip('/')}/{table}.yaml"

    def _save_yaml(self, dq_engine, path: str, checks: List[Dict[str, Any]]) -> None:
        """Write ``checks`` to the YAML file ``path`` using the configured ``key_order`` strategy."""
        if self.key_order == "engine":
            cfg = self._infer_file_storage_config(path)
            print(f"[RUN] Saving {len(checks)} checks via DQX to: {path}")
            dq_engine.save_checks(checks, config=cfg)
        else:  # "custom"
            print(f"[RUN] Saving {len(checks)} checks with strict key order to: {path}")
            self._write_yaml_ordered(checks, path)

    def _write_yaml_ordered(self, checks: List[Dict[str, Any]], path: str) -> None:
        """
        Dump ``checks`` as YAML preserving key order, then write it:
          - /Volumes/... | dbfs:/... | /dbfs/... -> dbutils.fs.put (parent created first)
          - other absolute paths (workspace files) -> Files API
          - relative paths -> os.makedirs + open(...) on the driver
        """
        yaml_str = yaml.safe_dump(checks, sort_keys=False, default_flow_style=False)

        # DBFS / Volumes
        if path.startswith(("dbfs:/", "/dbfs/", "/Volumes/")):
            try:
                from databricks.sdk.runtime import dbutils
            except Exception as e:
                raise RuntimeError("dbutils is required to write to DBFS/Volumes.") from e
            target = self._to_dbfs_uri(path)
            self._ensure_dbfs_parent(dbutils, target)
            dbutils.fs.put(target, yaml_str, True)
            print(f"[RUN] Wrote ordered YAML to {path}")
            return

        # Workspace files
        if path.startswith("/"):
            self._workspace_files_upload(path, yaml_str.encode("utf-8"))
            print(f"[RUN] Wrote ordered YAML to workspace file: {path}")
            return

        # Local (driver) relative filesystem path (Repos-friendly)
        full_path = os.path.abspath(path)
        self._ensure_parent(full_path)
        with open(full_path, "w", encoding="utf-8") as f:
            f.write(yaml_str)
        print(f"[RUN] Wrote ordered YAML to local path: {full_path}")

    # ---------- main ----------
    def run(self):
        """
        Discover tables, profile each one, and save the generated checks.

        Per-table failures (unreadable table, profiling error) are reported as [WARN] and
        that table is skipped. Several tables can map to the same YAML file: a single-file
        ``output_location``, or same-named tables from different schemas in folder mode.
        In that case the file is rewritten with the checks of all those tables rather than
        only the last one. Any other error aborts the run and is printed as [ERROR] with a
        traceback; it is not re-raised.
        """
        try:
            tables = self._discover_tables()
            print("[RUN] Beginning DQX rule generation on these tables:")
            for t in tables:
                print(f"  {t}")
            print("==========================================\n")

            call_kwargs = self._profile_call_kwargs()
            print("[RUN] Profiler call kwargs:")
            print(f"  cols:    {call_kwargs.get('cols')}")
            print(f"  options: {json.dumps(call_kwargs.get('options', {}), indent=2)}")

            # One client shared by engine, profiler and generator instead of new ones per table.
            ws = WorkspaceClient()
            dq_engine = DQEngine(ws)
            profiler = DQProfiler(ws)
            generator = DQDltGenerator(ws)
            total_checks = 0
            # YAML destination -> all checks written to it so far in this run.
            yaml_contents: Dict[str, List[Dict[str, Any]]] = {}

            for fq_table in tables:
                if fq_table.count(".") != 2:
                    print(f"[WARN] Skipping invalid table name: {fq_table}")
                    continue
                tab = fq_table.rsplit(".", 1)[-1]

                # Verify readability (and keep the DataFrame for profiling).
                try:
                    print(f"[RUN] Checking table readability: {fq_table}")
                    df = self.spark.table(fq_table)
                    df.limit(1).collect()
                except Exception as e:
                    print(f"[WARN] Table {fq_table} not readable in Spark: {e}")
                    continue

                try:
                    print(f"[RUN] Profiling and generating rules for: {fq_table}")
                    # DataFrame profiling with options and optional cols.
                    # Table-based alternative:
                    #   profiler.profile_table(table=fq_table, **call_kwargs)
                    _summary_stats, profiles = profiler.profile(df, **call_kwargs)
                    rules_dict = generator.generate_dlt_rules(profiles, language="Python_Dict")
                except Exception as e:
                    print(f"[WARN] Profiling failed for {fq_table}: {e}")
                    continue

                checks: List[Dict[str, Any]] = [
                    self._dq_constraint_to_check(
                        rule_name=rule_name,
                        constraint_sql=constraint,
                        table_name=fq_table,
                        criticality=self.criticality,
                        run_config_name=self.run_config_name,
                    )
                    for rule_name, constraint in (rules_dict or {}).items()
                ]
                if not checks:
                    print(f"[INFO] No checks generated for {fq_table}.")
                    continue

                if self.output_format == "yaml":
                    path = self._yaml_path_for(tab)
                    file_checks = yaml_contents.setdefault(path, [])
                    file_checks.extend(checks)
                    if len(file_checks) > len(checks):
                        print(f"[INFO] {path} also holds checks from earlier tables; "
                              f"rewriting it with all {len(file_checks)} checks.")
                    # Rewrite with the cumulative list so earlier tables' checks are not lost.
                    self._save_yaml(dq_engine, path, file_checks)
                else:  # table sink
                    cfg = self._table_storage_config(
                        table_fqn=self.output_location,
                        run_config_name=self.run_config_name,
                        mode="append"
                    )
                    print(f"[RUN] Appending {len(checks)} checks to table: {self.output_location} (run_config_name={self.run_config_name})")
                    dq_engine.save_checks(checks, config=cfg)
                total_checks += len(checks)

            print(f"[RUN] {'Successfully saved' if total_checks else 'No'} checks. Count: {total_checks}")
        except Exception as e:
            import traceback
            print(f"[ERROR] Rule generation failed: {e}")
            traceback.print_exc()


# -------------------- Usage examples --------------------
if __name__ == "__main__":
    # Profiling options for the examples below. This rebinds the `profile_options`
    # defined in the first notebook cell.
    profile_options = {
        # Sampling
        "sample_fraction": 0.3,
        "sample_seed": 42,
        "limit": 1000,
        # Outliers
        "remove_outliers": True,
        "outlier_columns": [],
        "num_sigmas": 3,
        # Nulls / empties
        "max_null_ratio": 0.01,
        "trim_strings": True,
        "max_empty_ratio": 0.01,
        # Distincts → is_in
        "distinct_ratio": 0.05,
        "max_in_count": 10,
        # Rounding
        "round": True,
        # Keys not in the documented set still pass through; run() prints an INFO line listing them.
        # "include_histograms": False,
        # "min_length": None,
        # "max_length": None,
        # "min_value": None,
        # "max_value": None,
        # "profile_types": None,
    }

    # Example A: Single table (disabled; uncomment to run)
    # RuleGenerator(
    #     mode="table",                                   # "pipeline" | "catalog" | "schema" | "table"
    #     name_param="dq_prd.monitoring.job_run_audit",   # depends on mode
    #     output_format="yaml",                           # "yaml" | "table"
    #     output_location="dqx_checks",                   # yaml dir (local) OR "/Shared/..." OR "dbfs:/..." OR "/Volumes/..."
    #     profile_options=profile_options,
    #     columns=None,                                   # None => whole table (only valid when mode=="table")
    #     exclude_pattern=None,                           # e.g. ".tmp_*"
    #     created_by="LMG",
    #     run_config_name="default",
    #     criticality="error",
    #     key_order="custom",                             # "engine" or "custom"
    #     include_table_name=True,
    # ).run()

    # Example B: YAML list of tables; catalog/schema defaults are parsed from the YAML's header comments.
    #   Notebook path: src/dqx/00_generate_dqx_checks
    #   YAML file path (relative to repo): resources/dqx_checks_generated/crm/crm_table_list-gold.yaml
    #   Output folder:                      resources/dqx_checks_generated/crm/crm_generated_dqx_checks
    RuleGenerator(
        mode="table",                                   # "pipeline" | "catalog" | "schema" | "table"
        name_param="resources/dqx_checks_generated/crm/crm_table_list-gold.yaml",
        output_format="yaml",                           # "yaml" | "table"
        output_location="resources/dqx_checks_generated/crm/crm_generated_dqx_checks",  # one YAML per table
        profile_options=profile_options,
        columns=None,                                   # None => whole table (only valid when mode=="table")
        exclude_pattern=None,                           # e.g. ".tmp_*"
        created_by="LMG",
        run_config_name="default",
        criticality="error",
        key_order="custom",                             # "engine" or "custom"
        include_table_name=True,
    ).run()