# DQX Rules Generator

In [0]:
# -- DQX profiling options handed to RuleGenerator --
# NOTE: "round" was dropped on purpose (Removed - not valid for this profiler version).
profile_options = dict(
    # --- is_in / null / empty thresholds ---
    max_in_count=10,            # Max distinct values for is_in rule
    distinct_ratio=0.05,        # Max unique/total ratio for is_in rule
    max_null_ratio=0.01,        # Max null fraction to allow is_not_null rule
    # --- outlier handling for min/max ---
    remove_outliers=True,       # Remove outliers for min/max
    outlier_columns=[],         # Only these columns get outlier removal (empty=all numerics)
    num_sigmas=3,               # Stddev for outlier removal (z-score cutoff)
    # --- string handling ---
    trim_strings=True,          # Strip whitespace before profiling strings
    max_empty_ratio=0.01,       # Max empty string ratio for is_not_null_or_empty
    # --- sampling ---
    sample_fraction=0.3,        # Row fraction to sample
    sample_seed=None,           # Seed for reproducibility (set int for deterministic)
    limit=1000,                 # Max number of rows to profile
    # --- optional overrides (None/False leaves them switched off) ---
    profile_types=None,         # List of rule types (e.g. ["is_in", "is_not_null"]); None=default
    min_length=None,            # Min string length to consider (None disables)
    max_length=None,            # Max string length to consider (None disables)
    include_histograms=False,   # Compute histograms as part of profiling
    min_value=None,             # Numeric min override (None disables)
    max_value=None,             # Numeric max override (None disables)
)

In [0]:
%pip install databricks-labs-dqx

In [0]:
# Restart the notebook's Python process after the %pip install in the previous
# cell, presumably so the imports below pick up the freshly installed package.
dbutils.library.restartPython()

In [0]:
import re
import json
import inspect
from datetime import datetime, timezone
from typing import List, Optional, Dict, Any, Tuple, Literal

import yaml
from databricks.sdk import WorkspaceClient
from databricks.labs.dqx.profiler.profiler import DQProfiler
from databricks.labs.dqx.profiler.dlt_generator import DQDltGenerator
from databricks.labs.dqx.engine import DQEngine
from databricks.labs.dqx.config import (
    FileChecksStorageConfig,
    WorkspaceFileChecksStorageConfig,
    TableChecksStorageConfig,
    VolumeFileChecksStorageConfig,
)
from pyspark.sql import SparkSession


def glob_to_regex(glob_pattern: str) -> str:
    """Translate a dot-prefixed table-name glob into an anchored regex string.

    The leading dot is mandatory and is dropped; every ``*`` becomes ``.*``
    and all other characters are matched literally.

    Raises:
        ValueError: If the pattern is empty/None or does not begin with '.'.
    """
    if not (glob_pattern and glob_pattern.startswith('.')):
        raise ValueError("Exclude pattern must start with a dot, e.g. '.tamarack_*'")
    # Escape the literal pieces between wildcards, then stitch them back
    # together with the regex equivalent of '*'.
    literal_pieces = glob_pattern[1:].split('*')
    body = '.*'.join(re.escape(piece) for piece in literal_pieces)
    return f'^{body}$'


class RuleGenerator:
    def __init__(
        self,
        mode: str,                       # "pipeline" | "catalog" | "schema" | "table"
        name_param: str,                 # pipeline name(s) CSV | catalog | catalog.schema | fully qualified table(s) CSV
        output_format: str,              # "yaml" | "table"
        output_location: str,            # yaml: folder or full file path; table: catalog.schema.table
        profile_options: Dict[str, Any],
        exclude_pattern: Optional[str] = None,     # e.g. ".foo_*" (table-name glob with leading dot)
        created_by: Optional[str] = "LMG",
        columns: Optional[List[str]] = None,       # None => whole table; only valid if mode=="table"
        run_config_name: str = "default",          # DQX run group tag on checks
        criticality: str = "warn",                 # "warn" | "error"
        yaml_key_order: Literal["engine", "custom"] = "engine",  # "engine" (DQX save) | "custom" (strict key order)
    ):
        self.mode = mode.lower().strip()
        self.name_param = name_param
        self.output_format = output_format.lower().strip()
        self.output_location = output_location
        # Pass options through as-is; let the profiler decide what to use.
        self.profile_options = profile_options or {}
        self.exclude_pattern = exclude_pattern
        self.created_by = created_by
        self.columns = columns
        self.run_config_name = run_config_name
        self.criticality = criticality
        self.yaml_key_order = yaml_key_order

        self.spark = SparkSession.getActiveSession()
        if not self.spark:
            raise RuntimeError("No active Spark session found. Run this in a Databricks notebook.")

        if self.output_format not in {"yaml", "table"}:
            raise ValueError("output_format must be 'yaml' or 'table'.")
        if self.output_format == "yaml" and not self.output_location:
            raise ValueError("When output_format='yaml', provide output_location (folder or file).")

    # -------------------- discovery helpers --------------------

    def _exclude_tables_by_pattern(self, fq_tables: List[str]) -> List[str]:
        if not self.exclude_pattern:
            return fq_tables
        regex = glob_to_regex(self.exclude_pattern)
        pattern = re.compile(regex)
        filtered = []
        for fq in fq_tables:
            tbl = fq.split('.')[-1]
            if not pattern.match(tbl):
                filtered.append(fq)
        print(f"[INFO] Excluded {len(fq_tables) - len(filtered)} tables by pattern '{self.exclude_pattern}'")
        return filtered

    def _discover_tables(self) -> List[str]:
        print("\n===== PARAMETERS PASSED THIS RUN =====")
        print(f"mode:             {self.mode}")
        print(f"name_param:       {self.name_param}")
        print(f"output_format:    {self.output_format}")
        print(f"output_location:  {self.output_location}")
        print(f"exclude_pattern:  {self.exclude_pattern}")
        print(f"created_by:       {self.created_by}")
        print(f"columns:          {self.columns}")
        print(f"run_config_name:  {self.run_config_name}")
        print(f"criticality:      {self.criticality}")
        print(f"yaml_key_order:   {self.yaml_key_order}")
        print(f"profile_options:")
        for k, v in self.profile_options.items():
            print(f"  {k}: {v}")
        print("======================================\n")

        mode = self.mode
        name_param = self.name_param
        allowed_modes = {"pipeline", "catalog", "schema", "table"}
        if mode not in allowed_modes:
            raise ValueError(f"Invalid mode '{mode}'. Must be one of: {sorted(allowed_modes)}.")

        if self.columns is not None and mode != "table":
            raise ValueError("The 'columns' parameter can only be used in mode='table'.")

        discovered: List[str] = []
        if mode == "pipeline":
            print("Searching for pipeline output tables...")
            ws = WorkspaceClient()
            pipelines = [p.strip() for p in name_param.split(",") if p.strip()]
            print(f"Pipelines passed: {pipelines}")
            for pipeline_name in pipelines:
                print(f"Finding output tables for pipeline: {pipeline_name}")
                pls = list(ws.pipelines.list_pipelines())
                pl = next((p for p in pls if p.name == pipeline_name), None)
                if not pl:
                    raise RuntimeError(f"Pipeline '{pipeline_name}' not found via SDK.")
                latest_update = pl.latest_updates[0].update_id
                events = ws.pipelines.list_pipeline_events(pipeline_id=pl.pipeline_id, max_results=250)
                pipeline_tables = [
                    getattr(ev.origin, "flow_name", None)
                    for ev in events
                    if getattr(ev.origin, "update_id", None) == latest_update and getattr(ev.origin, "flow_name", None)
                ]
                pipeline_tables = [x for x in pipeline_tables if x]
                print(f"Found tables for pipeline '{pipeline_name}': {pipeline_tables}")
                discovered += pipeline_tables
            print(f"All discovered tables from pipelines: {discovered}")

        elif mode == "catalog":
            print("Searching for tables in catalog...")
            catalog = name_param.strip()
            schemas = [row.namespace for row in self.spark.sql(f"SHOW SCHEMAS IN {catalog}").collect()]
            catalog_tables = []
            for s in schemas:
                tbls = self.spark.sql(f"SHOW TABLES IN {catalog}.{s}").collect()
                found = [f"{catalog}.{s}.{row.tableName}" for row in tbls]
                catalog_tables += found
                if found:
                    print(f"Tables in {catalog}.{s}: {found}")
            discovered = catalog_tables
            print(f"All discovered tables in catalog: {discovered}")

        elif mode == "schema":
            print("Searching for tables in schema...")
            if name_param.count(".") != 1:
                raise ValueError("For 'schema' mode, name_param must be catalog.schema")
            catalog, schema = name_param.strip().split(".")
            tbls = self.spark.sql(f"SHOW TABLES IN {catalog}.{schema}").collect()
            schema_tables = [f"{catalog}.{schema}.{row.tableName}" for row in tbls]
            discovered = schema_tables
            print(f"Tables in schema {catalog}.{schema}: {schema_tables}")

        elif mode == "table":
            print("Profiling one or more specific tables...")
            tables = [t.strip() for t in name_param.split(",") if t.strip()]
            for t in tables:
                if t.count(".") != 2:
                    raise ValueError(f"Table name '{t}' must be fully qualified (catalog.schema.table)")
            discovered = tables
            print(f"Tables to be profiled: {discovered}")

        print("\nRunning exclude pattern filtering (if any)...")
        discovered = self._exclude_tables_by_pattern(discovered)
        print("\nFinal table list to generate DQX rules for:")
        print(discovered)
        print("==========================================\n")
        return sorted(set(discovered))

    # -------------------- storage config helpers --------------------

    @staticmethod
    def _infer_file_storage_config(file_path: str):
        """
        Build the DQX storage config that matches the shape of ``file_path``:
          - starts with /Volumes/  -> VolumeFileChecksStorageConfig
          - any other /... path    -> WorkspaceFileChecksStorageConfig
          - everything else        -> FileChecksStorageConfig
        """
        if file_path.startswith("/Volumes/"):
            config_cls = VolumeFileChecksStorageConfig
        elif file_path.startswith("/"):
            config_cls = WorkspaceFileChecksStorageConfig
        else:
            config_cls = FileChecksStorageConfig
        return config_cls(location=file_path)

    @staticmethod
    def _table_storage_config(table_fqn: str, run_config_name: Optional[str] = None, mode: str = "append"):
        """Build a TableChecksStorageConfig for ``table_fqn`` with the given run config name and mode."""
        storage_kwargs = {
            "location": table_fqn,
            "run_config_name": run_config_name,
            "mode": mode,
        }
        return TableChecksStorageConfig(**storage_kwargs)

    @staticmethod
    def _dq_constraint_to_check(
        rule_name: str,
        constraint_sql: str,
        table_name: str,
        criticality: str,
        run_config_name: str
    ) -> Dict[str, Any]:
        """
        Convert a profiler constraint (SQL) into a DQX check dict.
        Key order: table_name, name, criticality, run_config_name, check
        (insertion order preserved in Python 3.7+)
        """
        return {
            "table_name": table_name,
            "name": rule_name,
            "criticality": criticality,
            "run_config_name": run_config_name,
            "check": {
                "function": "sql_expression",
                "arguments": {
                    "expression": constraint_sql,
                    "name": rule_name,
                }
            },
        }

    def _write_yaml_ordered(self, checks: List[Dict[str, Any]], path: str) -> None:
        """
        Dump ``checks`` to YAML with key order preserved and write it to ``path``:
          - dbfs:/... | /dbfs/... | /Volumes/... -> dbutils.fs.put (overwrite)
          - any other absolute path              -> workspace file upload (overwrite)
          - anything else                        -> local file on the driver

        Raises:
            RuntimeError: A DBFS/Volumes target was requested but ``dbutils``
                cannot be imported.
        """
        yaml_str = yaml.safe_dump(checks, sort_keys=False, default_flow_style=False)

        if path.startswith(("dbfs:/", "/dbfs/", "/Volumes/")):
            try:
                from databricks.sdk.runtime import dbutils  # available in notebooks
            except Exception as exc:
                raise RuntimeError("dbutils not available to write to DBFS/Volumes.") from exc
            if not dbutils:
                raise RuntimeError("dbutils not available to write to DBFS/Volumes.")

            if path.startswith("dbfs:/"):
                target = path
            elif path.startswith("/dbfs/"):
                # /dbfs/<x> is the local mount of dbfs:/<x>; prefixing the
                # whole path would write to dbfs:/dbfs/<x> instead.
                target = "dbfs:" + path[len("/dbfs"):]
            else:
                target = f"dbfs:{path}"
            dbutils.fs.put(target, yaml_str, True)
            print(f"[RUN] Wrote ordered YAML to {path}")
            return

        if path.startswith("/"):
            # Workspace files go through the Workspace import API. The SDK's
            # `files` client addresses Unity Catalog volume paths and its
            # upload() has no `path` keyword, so it cannot be used here.
            import io
            from databricks.sdk.service.workspace import ImportFormat

            wc = WorkspaceClient()
            parent_dir = path.rsplit("/", 1)[0]
            if parent_dir:
                # The import call does not create missing folders itself.
                wc.workspace.mkdirs(parent_dir)
            wc.workspace.upload(
                path,
                io.BytesIO(yaml_str.encode("utf-8")),
                format=ImportFormat.AUTO,
                overwrite=True,
            )
            print(f"[RUN] Wrote ordered YAML to workspace file: {path}")
            return

        with open(path, "w", encoding="utf-8") as f:
            f.write(yaml_str)
        print(f"[RUN] Wrote ordered YAML to local path: {path}")

    # -------------------- main entrypoint --------------------

    def run(self):
        try:
            tables = self._discover_tables()
            print("[RUN] Beginning DQX rule generation on these tables:")
            for t in tables:
                print(f"  {t}")
            print("==========================================\n")

            # Echo profiler options (now passed through as-is)
            print("[RUN] Profiler options in effect:")
            for k, v in self.profile_options.items():
                print(f"  {k}: {v}")

            dq_engine = DQEngine(WorkspaceClient())
            used_options = self.profile_options

            total_checks = 0
            for fq_table in tables:
                parts = fq_table.split('.')
                if len(parts) != 3:
                    print(f"[WARN] Skipping invalid table name: {fq_table}")
                    continue
                cat, sch, tab = parts

                # Verify table readable
                try:
                    print(f"[RUN] Checking table readability: {fq_table}")
                    self.spark.table(fq_table).limit(1).collect()
                except Exception as e:
                    print(f"[WARN] Table {fq_table} not readable in Spark: {e}")
                    continue

                profiler = DQProfiler(WorkspaceClient())
                generator = DQDltGenerator(WorkspaceClient())
                df = self.spark.table(fq_table)

                try:
                    print(f"[RUN] Profiling and generating rules for: {fq_table}")
                    if self.mode == "table" and self.columns is not None:
                        summary_stats, profiles = profiler.profile(df, cols=self.columns, **used_options)
                    else:
                        summary_stats, profiles = profiler.profile(df, **used_options)

                    # Generate constraints (DLT-style) as Python dict
                    rules_dict = generator.generate_dlt_rules(profiles, language="Python_Dict")
                except TypeError as e:
                    # If profiler rejects unexpected kwargs, retry with no options
                    print(f"[WARN] Profiler rejected some options ({e}). Retrying with defaults...")
                    summary_stats, profiles = profiler.profile(df)
                    rules_dict = generator.generate_dlt_rules(profiles, language="Python_Dict")
                except Exception as e:
                    print(f"[WARN] Profiling failed for {fq_table}: {e}")
                    continue

                # Convert constraints -> DQX checks list (with our key order)
                checks: List[Dict[str, Any]] = []
                for rule_name, constraint in (rules_dict or {}).items():
                    checks.append(
                        self._dq_constraint_to_check(
                            rule_name=rule_name,
                            constraint_sql=constraint,
                            table_name=fq_table,
                            criticality=self.criticality,
                            run_config_name=self.run_config_name,
                        )
                    )

                if not checks:
                    print(f"[INFO] No checks generated for {fq_table}.")
                    continue

                # Determine output target per table
                if self.output_format == "yaml":
                    # If output_location is a directory, write {table}.yaml under it.
                    # If it's a file path ending with .yml/.yaml, use as-is.
                    if self.output_location.endswith(".yaml") or self.output_location.endswith(".yml"):
                        path = self.output_location
                    else:
                        path = self.output_location.rstrip("/") + f"/{tab}.yaml"

                    if self.yaml_key_order == "engine":
                        cfg = self._infer_file_storage_config(path)
                        print(f"[RUN] Saving {len(checks)} checks via DQX to: {path}")
                        dq_engine.save_checks(checks, config=cfg)
                    else:
                        print(f"[RUN] Saving {len(checks)} checks with strict key order to: {path}")
                        self._write_yaml_ordered(checks, path)

                    total_checks += len(checks)

                elif self.output_format == "table":
                    cfg = self._table_storage_config(
                        table_fqn=self.output_location,
                        run_config_name=self.run_config_name,
                        mode="append"
                    )
                    print(f"[RUN] Appending {len(checks)} checks to table: {self.output_location} (run_config_name={self.run_config_name})")
                    dq_engine.save_checks(checks, config=cfg)
                    total_checks += len(checks)

                else:
                    raise ValueError(f"Unsupported output_format: {self.output_format}")

            if total_checks:
                print(f"[RUN] Successfully saved {total_checks} checks.")
            else:
                print("[INFO] No checks were saved.")
        except Exception as e:
            print(f"[ERROR] Rule generation failed: {e}")


# -------------------- Usage example --------------------
if __name__ == "__main__":
    # Example profiler settings; keys and values mirror the options cell at the top.
    profile_options = dict(
        max_in_count=10,
        distinct_ratio=0.05,
        max_null_ratio=0.01,
        remove_outliers=True,
        outlier_columns=[],
        num_sigmas=3,
        trim_strings=True,
        max_empty_ratio=0.01,
        sample_fraction=0.3,
        sample_seed=None,
        limit=1000,
        profile_types=None,
        min_length=None,
        max_length=None,
        include_histograms=False,
        min_value=None,
        max_value=None,
    )

    # Profile a single table and write its checks as ordered YAML under /Shared.
    RuleGenerator(
        # --- what to profile ---
        mode="table",                                   # "pipeline" | "catalog" | "schema" | "table"
        name_param="dq_prd.monitoring.job_run_audit",   # depends on mode
        columns=None,                                   # None => whole table (only valid when mode=="table")
        exclude_pattern=None,                           # e.g. ".tmp_*"
        profile_options=profile_options,
        # --- where and how to write ---
        output_format="yaml",                           # "yaml" | "table"
        output_location="/Shared/dqx_checks",           # yaml: folder or /Shared/foo.yml; table: catalog.schema.table
        yaml_key_order="custom",                        # "engine" (DQX write) | "custom" (strict key order in YAML)
        # --- metadata stamped on every check ---
        created_by="LMG",
        run_config_name="default",                      # tag checks for filtering when applying
        criticality="warn",                             # "warn" | "error"
    ).run()