# DQX Rules Generator

In [0]:
%pip install databricks-labs-dqx
# Restart the Python process so this session picks up the library installed above
# (Databricks utility).
dbutils.library.restartPython()

# Print the installed DQX version so each run's output records which release was used.
import databricks.labs.dqx
print(databricks.labs.dqx.__version__)

In [0]:
import re
import json
import inspect
from datetime import datetime, timezone
from typing import List, Optional, Dict, Any

from pyspark.sql import SparkSession, types as T, DataFrame

from databricks.sdk import WorkspaceClient
from databricks.labs.dqx.profiler.profiler import DQProfiler
from databricks.labs.dqx.profiler.dlt_generator import DQDltGenerator

def glob_to_regex(glob_pattern: str) -> str:
    """Translate a dot-prefixed exclude glob into an anchored regular expression.

    The leading ``.`` is stripped, every character is escaped literally, and each
    ``*`` becomes ``.*``. For example ``.tamarack_*`` becomes ``^tamarack_.*$``.

    Raises:
        ValueError: If the pattern is empty/None or does not start with ``.``.
    """
    if not (glob_pattern and glob_pattern.startswith('.')):
        raise ValueError("Exclude pattern must start with a dot, e.g. '.tamarack_*'")
    escaped_body = re.escape(glob_pattern[1:])
    wildcard_body = escaped_body.replace(r'\*', '.*')
    return f"^{wildcard_body}$"

class RuleGenerator:
    """Generate DQX data-quality rules for a set of tables and append them to a table.

    Tables are discovered according to ``mode``:

    * ``pipeline`` - tables whose flows appear in the latest update of each
      comma-separated pipeline named in ``name_param``;
    * ``catalog``  - every table in the catalog ``name_param``;
    * ``schema``   - every table in ``name_param`` (``catalog.schema``);
    * ``table``    - the comma-separated, fully qualified tables in ``name_param``.

    The discovered tables are filtered by ``exclude_pattern``, profiled with
    ``DQProfiler``, turned into rules with ``DQDltGenerator`` and appended to
    ``output_table`` under a new ``batch_id``.
    """

    # DQProfiler.profile has used different names for its options parameter
    # across DQX releases (``opts`` vs ``options``); the one actually present is
    # detected from the signature at call time.
    _OPTIONS_PARAM_NAMES = ("options", "opts")
    _ALLOWED_MODES = ("pipeline", "catalog", "schema", "table")

    def __init__(
        self,
        mode: str,
        name_param: str,
        output_table: str,
        profile_options: Dict[str, Any],
        exclude_pattern: Optional[str] = None,
        created_by: Optional[str] = "admin",
    ):
        """Store the run configuration and bind the active Spark session.

        Args:
            mode: One of ``pipeline``, ``catalog``, ``schema`` or ``table``
                (case-insensitive; surrounding whitespace is ignored).
            name_param: Pipeline name(s), catalog, ``catalog.schema`` or table
                name(s), depending on ``mode``.
            output_table: Table the generated rules are appended to.
            profile_options: Profiling options forwarded to ``DQProfiler.profile``;
                keys known to be unsupported are dropped with a warning.
            exclude_pattern: Optional glob (starting with ``.``) matched against the
                bare table name; matching tables are skipped.
            created_by: Value stored in the ``created_by`` audit column.

        Raises:
            RuntimeError: If there is no active Spark session.
        """
        self.mode = mode.lower().strip()
        self.name_param = name_param
        self.output_table = output_table
        self.profile_options = self._validate_profile_options(profile_options)
        self.exclude_pattern = exclude_pattern
        self.created_by = created_by
        # Discovered table -> pipeline it came from (filled in pipeline mode only),
        # so each rule row records its own pipeline rather than the whole name list.
        self._table_pipelines = {}
        self.spark = SparkSession.getActiveSession()
        if not self.spark:
            raise RuntimeError("No active Spark session found. Run this in a Databricks notebook.")

    # ------------------------------------------------------------------ options

    @staticmethod
    def _profile_parameters() -> set:
        """Return the parameter names of ``DQProfiler.profile``."""
        return set(inspect.signature(DQProfiler.profile).parameters)

    @classmethod
    def _profile_param_name(cls, candidates) -> Optional[str]:
        """Return the first name in ``candidates`` accepted by ``DQProfiler.profile``, if any."""
        params = cls._profile_parameters()
        return next((name for name in candidates if name in params), None)

    @classmethod
    def _allowed_profile_options(cls) -> Optional[set]:
        """Return the accepted profiling option keys, or None when they cannot be determined.

        When ``DQProfiler.profile`` takes an options dict, the keys of
        ``DQProfiler.default_profile_options`` (if that attribute exists) are the
        accepted ones. Otherwise options are treated as direct keyword arguments,
        as the original implementation assumed.
        """
        if cls._profile_param_name(cls._OPTIONS_PARAM_NAMES) is None:
            return cls._profile_parameters() - {"self", "df"}
        defaults = getattr(DQProfiler, "default_profile_options", None)
        return set(defaults) if isinstance(defaults, dict) else None

    def _validate_profile_options(self, opts: Dict[str, Any]) -> Dict[str, Any]:
        """Drop option keys that DQProfiler is known not to accept, warning about them."""
        opts = dict(opts or {})
        allowed = self._allowed_profile_options()
        if allowed is None:
            return opts
        ignored = sorted(set(opts).difference(allowed))
        if ignored:
            print(f"[WARN] Ignored invalid profile_options keys: {ignored}")
        return {k: v for k, v in opts.items() if k in allowed}

    def _build_profile_kwargs(self) -> Dict[str, Any]:
        """Build the keyword arguments passed to ``DQProfiler.profile`` after ``df``."""
        options_param = self._profile_param_name(self._OPTIONS_PARAM_NAMES)
        if options_param is None:
            # Legacy/unknown signature: options are individual keyword arguments.
            return dict(self.profile_options)
        return {options_param: dict(self.profile_options)} if self.profile_options else {}

    # ------------------------------------------------------------------ schema

    def _get_rules_output_schema(self):
        """Return the Spark schema of the rules output table."""
        return T.StructType([
            T.StructField("batch_id", T.IntegerType(), True),
            T.StructField("pipeline", T.StringType(), True),
            T.StructField("is_pipeline", T.BooleanType(), True),
            T.StructField("source", T.StringType(), True),
            T.StructField("catalog", T.StringType(), True),
            T.StructField("schema", T.StringType(), True),
            T.StructField("table", T.StringType(), True),
            T.StructField("profile_options", T.ArrayType(
                T.StructType([
                    T.StructField("key", T.StringType(), False),
                    T.StructField("value", T.StringType(), True)
                ])
            ), True),
            T.StructField("rule_name", T.StringType(), True),
            T.StructField("rule_constraint", T.StringType(), True),
            T.StructField("created_by", T.StringType(), True),
            T.StructField("created_at", T.TimestampType(), True),
        ])

    # ------------------------------------------------------------------ discovery

    def _exclude_tables_by_pattern(self, fq_tables: List[str]) -> List[str]:
        """Drop tables whose bare name matches ``exclude_pattern``.

        Matching is case-insensitive because Spark/Unity Catalog identifiers are.
        """
        if not self.exclude_pattern:
            return fq_tables
        pattern = re.compile(glob_to_regex(self.exclude_pattern), re.IGNORECASE)
        filtered = [fq for fq in fq_tables if not pattern.match(fq.split('.')[-1])]
        print(f"[INFO] Excluded {len(fq_tables) - len(filtered)} tables by pattern '{self.exclude_pattern}'")
        return filtered

    def _print_parameters(self) -> None:
        """Echo the run parameters to the notebook output."""
        print("\n===== PARAMETERS PASSED THIS RUN =====")
        print(f"mode:           {self.mode}")
        print(f"name_param:     {self.name_param}")
        print(f"output_table:   {self.output_table}")
        print(f"exclude_pattern:{self.exclude_pattern}")
        print(f"created_by:     {self.created_by}")
        print(f"profile_options:")
        for k, v in self.profile_options.items():
            print(f"  {k}: {v}")
        print("======================================\n")

    @staticmethod
    def _qualify_pipeline_tables(ws, pipeline_id, tables: List[str]) -> List[str]:
        """Prefix not-fully-qualified flow names with the pipeline's catalog/target schema.

        Names that already have three parts are returned unchanged, as is the whole
        list when the pipeline spec cannot be read or lacks a catalog or schema.
        """
        if all(t.count(".") >= 2 for t in tables):
            return tables
        try:
            spec = getattr(ws.pipelines.get(pipeline_id=pipeline_id), "spec", None)
        except Exception as e:
            print(f"[WARN] Could not read pipeline spec for {pipeline_id}: {e}")
            return tables
        catalog = getattr(spec, "catalog", None)
        schema = getattr(spec, "schema", None) or getattr(spec, "target", None)
        if not (catalog and schema):
            return tables
        qualified = []
        for t in tables:
            depth = t.count(".")
            if depth == 0:
                qualified.append(f"{catalog}.{schema}.{t}")
            elif depth == 1:
                qualified.append(f"{catalog}.{t}")
            else:
                qualified.append(t)
        return qualified

    def _discover_pipeline_tables(self, name_param: str) -> List[str]:
        """Return tables whose flows appear in the latest update of each named pipeline."""
        print("Searching for pipeline output tables...")
        ws = WorkspaceClient()
        pipelines = [p.strip() for p in name_param.split(",") if p.strip()]
        print(f"Pipelines passed: {pipelines}")
        # List workspace pipelines once (not once per requested name); keep the
        # first pipeline seen for each name, as the original linear search did.
        by_name = {}
        for p in ws.pipelines.list_pipelines():
            by_name.setdefault(p.name, p)
        discovered = []
        for pipeline_name in pipelines:
            print(f"Finding output tables for pipeline: {pipeline_name}")
            pl = by_name.get(pipeline_name)
            if not pl:
                raise RuntimeError(f"Pipeline '{pipeline_name}' not found via SDK.")
            if not pl.latest_updates:
                raise RuntimeError(
                    f"Pipeline '{pipeline_name}' has no updates; run it at least once before profiling."
                )
            latest_update = pl.latest_updates[0].update_id
            events = ws.pipelines.list_pipeline_events(pipeline_id=pl.pipeline_id, max_results=250)
            flow_names = [
                getattr(ev.origin, "flow_name", None)
                for ev in events
                if getattr(ev.origin, "update_id", None) == latest_update
            ]
            # Many events share a flow; de-duplicate while keeping first-seen order.
            pipeline_tables = list(dict.fromkeys(name for name in flow_names if name))
            pipeline_tables = self._qualify_pipeline_tables(ws, pl.pipeline_id, pipeline_tables)
            print(f"Found tables for pipeline '{pipeline_name}': {pipeline_tables}")
            for table in pipeline_tables:
                self._table_pipelines.setdefault(table, pipeline_name)
            discovered += pipeline_tables
        print(f"All discovered tables from pipelines: {discovered}")
        return discovered

    def _list_schema_tables(self, catalog: str, schema: str) -> List[str]:
        """Return fully qualified names of the persistent tables in ``catalog.schema``."""
        rows = self.spark.sql(f"SHOW TABLES IN {catalog}.{schema}").collect()
        # SHOW TABLES also lists session temporary views, which cannot be addressed
        # as catalog.schema.name; skip them.
        return [
            f"{catalog}.{schema}.{row.tableName}"
            for row in rows
            if not getattr(row, "isTemporary", False)
        ]

    def _discover_catalog_tables(self, name_param: str) -> List[str]:
        """Return every table in every user schema of the catalog ``name_param``."""
        print("Searching for tables in catalog...")
        catalog = name_param.strip()
        # Column 0 holds the schema name whether Spark labels it `namespace` or `databaseName`.
        schemas = [row[0] for row in self.spark.sql(f"SHOW SCHEMAS IN {catalog}").collect()]
        discovered = []
        for s in schemas:
            if s.lower() == "information_schema":
                # System metadata views, not user data.
                continue
            found = self._list_schema_tables(catalog, s)
            discovered += found
            if found:
                print(f"Tables in {catalog}.{s}: {found}")
        print(f"All discovered tables in catalog: {discovered}")
        return discovered

    def _discover_schema_tables(self, name_param: str) -> List[str]:
        """Return every table in ``name_param`` given as ``catalog.schema``."""
        print("Searching for tables in schema...")
        if name_param.count(".") != 1:
            raise ValueError("For 'schema' mode, name_param must be catalog.schema")
        catalog, schema = name_param.strip().split(".")
        schema_tables = self._list_schema_tables(catalog, schema)
        print(f"Tables in schema {catalog}.{schema}: {schema_tables}")
        return schema_tables

    @staticmethod
    def _discover_named_tables(name_param: str) -> List[str]:
        """Parse a comma-separated list of fully qualified table names."""
        print("Profiling one or more specific tables...")
        tables = [t.strip() for t in name_param.split(",") if t.strip()]
        for t in tables:
            if t.count(".") != 2:
                raise ValueError(f"Table name '{t}' must be fully qualified (catalog.schema.table)")
        print(f"Tables to be profiled: {tables}")
        return tables

    def _discover_tables(self) -> List[str]:
        """Resolve ``mode``/``name_param`` into a sorted, de-duplicated list of tables.

        Raises:
            ValueError: On an unknown mode or malformed ``name_param``.
            RuntimeError: When a named pipeline is missing or has never run.
        """
        self._print_parameters()
        mode = self.mode
        if mode not in self._ALLOWED_MODES:
            raise ValueError(f"Invalid mode '{mode}'. Must be one of: {sorted(self._ALLOWED_MODES)}.")
        discover = {
            "pipeline": self._discover_pipeline_tables,
            "catalog": self._discover_catalog_tables,
            "schema": self._discover_schema_tables,
            "table": self._discover_named_tables,
        }[mode]
        discovered = discover(self.name_param)

        print("\nRunning exclude pattern filtering (if any)...")
        discovered = self._exclude_tables_by_pattern(discovered)
        # Sort/de-duplicate before printing so the printed list is what gets profiled.
        final_tables = sorted(set(discovered))
        print("\nFinal table list to generate DQX rules for:")
        print(final_tables)
        print("==========================================\n")
        return final_tables

    # ------------------------------------------------------------------ run

    def _get_next_batch_id(self) -> int:
        """Return one more than the highest ``batch_id`` in the output table (1 if none)."""
        # Public Catalog API; the former `_jsparkSession` route is private and is
        # not available on Spark Connect (e.g. serverless / shared clusters).
        if not self.spark.catalog.tableExists(self.output_table):
            return 1
        try:
            max_id = self.spark.table(self.output_table).agg({"batch_id": "max"}).collect()[0][0]
        except Exception as e:
            print(f"[WARN] Could not read max batch_id from {self.output_table}; using 1: {e}")
            return 1
        return (max_id or 0) + 1

    def _dict_to_kv_array(self, d: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Convert a dict into a key-sorted list of ``{"key", "value"}`` string pairs."""
        return [{"key": k, "value": str(v)} for k, v in sorted(d.items())]

    @staticmethod
    def _split_fq_name(fq_table: str):
        """Split a table name into ``(catalog, schema, table)``, right-aligned.

        Missing leading parts are None, so an unqualified name lands in ``table``
        rather than in ``catalog``.
        """
        padded = [None, None, None] + fq_table.split(".")
        return tuple(padded[-3:])

    def _generate_rules(self, profiler, generator, df) -> Dict[str, Any]:
        """Profile ``df`` and return DQX's rule-name -> constraint dict."""
        _summary_stats, profiles = profiler.profile(df, **self._build_profile_kwargs())
        return generator.generate_dlt_rules(profiles, language="Python_Dict")

    def _write_rules(self, rows: List[Dict[str, Any]]) -> None:
        """Append the rule rows to ``output_table``, creating its schema if needed."""
        if not rows:
            print("[INFO] No rules generated.")
            return
        print(f"[RUN] Writing {len(rows)} rules to output table: {self.output_table}")
        rules_df = self.spark.createDataFrame(rows, schema=self._get_rules_output_schema())
        output_schema = ".".join(self.output_table.split(".")[:-1])
        if output_schema:
            # An unqualified output table lives in the current schema; nothing to create.
            self.spark.sql(f"CREATE SCHEMA IF NOT EXISTS {output_schema}")
        rules_df.write.mode("append").saveAsTable(self.output_table)
        print(f"[RUN] Successfully wrote to table: {self.output_table}")

    def run(self):
        """Discover tables, generate rules for each and append them to ``output_table``.

        Tables that cannot be read or profiled are logged and skipped. Any other
        failure is logged and re-raised so notebook jobs do not report success.
        """
        try:
            tables = self._discover_tables()
            print("[RUN] Beginning DQX rule generation on these tables:")
            for t in tables:
                print(f"  {t}")
            print("==========================================\n")
            utc_now = datetime.now(timezone.utc)
            profile_options_array = self._dict_to_kv_array(self.profile_options)
            batch_id = self._get_next_batch_id()
            print(f"[RUN] Using batch_id={batch_id}")
            is_pipeline = self.mode == "pipeline"
            # One SDK client serves both DQX helpers for every table.
            ws = WorkspaceClient()
            profiler = DQProfiler(ws)
            generator = DQDltGenerator(ws)
            all_rules = []
            for fq_table in tables:
                cat, sch, tab = self._split_fq_name(fq_table)
                try:
                    print(f"[RUN] Checking table: {fq_table}")
                    df = self.spark.table(fq_table)
                    df.show(1)
                except Exception as e:
                    print(f"[WARN] Table {fq_table} not readable in Spark: {e}")
                    continue
                try:
                    print(f"[RUN] Profiling and generating rules for: {fq_table}")
                    rules_dict = self._generate_rules(profiler, generator, df)
                except Exception as e:
                    print(f"[WARN] Profiling failed for {fq_table}: {e}")
                    continue
                pipeline = self._table_pipelines.get(fq_table, self.name_param) if is_pipeline else None
                for rule_name, constraint in rules_dict.items():
                    all_rules.append({
                        "batch_id": batch_id,
                        "pipeline": pipeline,
                        "is_pipeline": is_pipeline,
                        "source": fq_table,
                        "catalog": cat,
                        "schema": sch,
                        "table": tab,
                        "profile_options": profile_options_array,
                        "rule_name": rule_name,
                        "rule_constraint": constraint,
                        "created_by": self.created_by,
                        "created_at": utc_now,
                    })
            self._write_rules(all_rules)
        except Exception as e:
            print(f"[ERROR] Rule generation failed: {e}")
            raise

# Example usage:
if __name__ == "__main__":
    # Example: generate rules for the output tables of pipeline `pl_zoo_bronze`
    # and append them to the shared expectations table.
    profile_options = dict(
        max_in_count=10,
        distinct_ratio=0.05,
        max_null_ratio=0.01,
        remove_outliers=True,
        outlier_columns=[],
        num_sigmas=3,
        trim_strings=True,
        max_empty_ratio=0.01,
        sample_fraction=0.3,
        sample_seed=None,
        limit=1000,
        profile_types=None,
        min_length=None,
        max_length=None,
        include_histograms=False,
        min_value=None,
        max_value=None,
    )
    pipeline_rule_generator = RuleGenerator(
        mode="pipeline",
        name_param="pl_zoo_bronze",
        output_table="dq_dev.expectations.dqx_expectations",
        profile_options=profile_options,
        exclude_pattern=None,
        created_by="levi",
    )
    pipeline_rule_generator.run()

In [0]:
import re
import json
import inspect
from datetime import datetime, timezone
from typing import List, Optional, Dict, Any
from databricks.sdk import WorkspaceClient
from databricks.labs.dqx.profiler.profiler import DQProfiler
from databricks.labs.dqx.profiler.dlt_generator import DQDltGenerator
from pyspark.sql import SparkSession, types as T, DataFrame

def glob_to_regex(glob_pattern: str) -> str:
    """Translate a dot-prefixed exclude glob into an anchored regular expression.

    The leading ``.`` is stripped, every character is escaped literally, and each
    ``*`` becomes ``.*``. For example ``.tamarack_*`` becomes ``^tamarack_.*$``.

    Raises:
        ValueError: If the pattern is empty/None or does not start with ``.``.
    """
    if not (glob_pattern and glob_pattern.startswith('.')):
        raise ValueError("Exclude pattern must start with a dot, e.g. '.tamarack_*'")
    escaped_body = re.escape(glob_pattern[1:])
    wildcard_body = escaped_body.replace(r'\*', '.*')
    return f"^{wildcard_body}$"

class RuleGenerator:
    """Generate DQX data-quality rules for a set of tables and append them to a table.

    Tables are discovered according to ``mode``:

    * ``pipeline`` - tables whose flows appear in the latest update of each
      comma-separated pipeline named in ``name_param``;
    * ``catalog``  - every table in the catalog ``name_param``;
    * ``schema``   - every table in ``name_param`` (``catalog.schema``);
    * ``table``    - the comma-separated, fully qualified tables in ``name_param``
      (optionally restricted to ``columns``).

    The discovered tables are filtered by ``exclude_pattern``, profiled with
    ``DQProfiler``, turned into rules with ``DQDltGenerator`` and appended to
    ``output_table`` under a new ``batch_id`` together with audit metadata.
    """

    # DQProfiler.profile has used different parameter names across DQX releases
    # (``cols``/``opts`` vs ``columns``/``options``); the name actually present is
    # detected from the signature at call time.
    _OPTIONS_PARAM_NAMES = ("options", "opts")
    _COLUMNS_PARAM_NAMES = ("columns", "cols")
    _ALLOWED_MODES = ("pipeline", "catalog", "schema", "table")

    def __init__(
        self,
        mode: str,
        name_param: str,
        output_table: str,
        profile_options: Dict[str, Any],
        exclude_pattern: Optional[str] = None,
        created_by: Optional[str] = "admin",
        columns: Optional[List[str]] = None,
    ):
        """Store the run configuration and bind the active Spark session.

        Args:
            mode: One of ``pipeline``, ``catalog``, ``schema`` or ``table``
                (case-insensitive; surrounding whitespace is ignored).
            name_param: Pipeline name(s), catalog, ``catalog.schema`` or table
                name(s), depending on ``mode``.
            output_table: Table the generated rules are appended to.
            profile_options: Profiling options forwarded to ``DQProfiler.profile``;
                keys known to be unsupported are dropped with a warning.
            exclude_pattern: Optional glob (starting with ``.``) matched against the
                bare table name; matching tables are skipped.
            created_by: Value stored in the ``created_by`` audit column.
            columns: Optional column subset to profile; only valid in ``table`` mode.

        Raises:
            RuntimeError: If there is no active Spark session.
        """
        self.mode = mode.lower().strip()
        self.name_param = name_param
        self.output_table = output_table
        self.profile_options = self._validate_profile_options(profile_options)
        self.exclude_pattern = exclude_pattern
        self.created_by = created_by
        self.columns = columns
        # Discovered table -> pipeline it came from (filled in pipeline mode only),
        # so each rule row records its own pipeline rather than the whole name list.
        self._table_pipelines = {}
        self.spark = SparkSession.getActiveSession()
        if not self.spark:
            raise RuntimeError("No active Spark session found. Run this in a Databricks notebook.")

    # ------------------------------------------------------------------ options

    @staticmethod
    def _profile_parameters() -> set:
        """Return the parameter names of ``DQProfiler.profile``."""
        return set(inspect.signature(DQProfiler.profile).parameters)

    @classmethod
    def _profile_param_name(cls, candidates) -> Optional[str]:
        """Return the first name in ``candidates`` accepted by ``DQProfiler.profile``, if any."""
        params = cls._profile_parameters()
        return next((name for name in candidates if name in params), None)

    @classmethod
    def _allowed_profile_options(cls) -> Optional[set]:
        """Return the accepted profiling option keys, or None when they cannot be determined.

        When ``DQProfiler.profile`` takes an options dict, the keys of
        ``DQProfiler.default_profile_options`` (if that attribute exists) are the
        accepted ones. Otherwise options are treated as direct keyword arguments,
        as the original implementation assumed.
        """
        if cls._profile_param_name(cls._OPTIONS_PARAM_NAMES) is None:
            return cls._profile_parameters() - {"self", "df"}
        defaults = getattr(DQProfiler, "default_profile_options", None)
        return set(defaults) if isinstance(defaults, dict) else None

    def _validate_profile_options(self, opts: Dict[str, Any]) -> Dict[str, Any]:
        """Drop option keys that DQProfiler is known not to accept, warning about them."""
        opts = dict(opts or {})
        allowed = self._allowed_profile_options()
        if allowed is None:
            return opts
        ignored = sorted(set(opts).difference(allowed))
        if ignored:
            print(f"[WARN] Ignored invalid profile_options keys: {ignored}")
        return {k: v for k, v in opts.items() if k in allowed}

    def _build_profile_kwargs(self) -> Dict[str, Any]:
        """Build the keyword arguments passed to ``DQProfiler.profile`` after ``df``.

        Raises:
            TypeError: If a column subset is requested but the installed DQX
                ``profile`` signature has no column parameter.
        """
        options_param = self._profile_param_name(self._OPTIONS_PARAM_NAMES)
        if options_param is None:
            # Legacy/unknown signature: options are individual keyword arguments.
            kwargs = dict(self.profile_options)
        else:
            kwargs = {options_param: dict(self.profile_options)} if self.profile_options else {}
        if self.mode == "table" and self.columns is not None:
            columns_param = self._profile_param_name(self._COLUMNS_PARAM_NAMES)
            if columns_param is None:
                raise TypeError("DQProfiler.profile does not accept a column subset in this DQX version.")
            kwargs[columns_param] = self.columns
        return kwargs

    # ------------------------------------------------------------------ schema

    def _get_rules_output_schema(self):
        """Return the Spark schema of the rules output table."""
        return T.StructType([
            T.StructField("batch_id", T.IntegerType(), True),
            T.StructField("pipeline", T.StringType(), True),
            T.StructField("is_pipeline", T.BooleanType(), True),
            T.StructField("source", T.StringType(), True),
            T.StructField("catalog", T.StringType(), True),
            T.StructField("schema", T.StringType(), True),
            T.StructField("table", T.StringType(), True),
            T.StructField("columns", T.ArrayType(T.StringType()), True),
            # -- Audit fields (order matters here) --
            T.StructField("profile_options", T.ArrayType(
                T.StructType([
                    T.StructField("key", T.StringType(), False),
                    T.StructField("value", T.StringType(), True)
                ])
            ), True),
            T.StructField("rule_generator_params", T.ArrayType(
                T.StructType([
                    T.StructField("key", T.StringType(), False),
                    T.StructField("value", T.StringType(), True)
                ])
            ), True),
            T.StructField("created_by", T.StringType(), True),
            T.StructField("created_at", T.TimestampType(), True),
            # -- End audit fields --
            T.StructField("rule_name", T.StringType(), True),
            T.StructField("rule_constraint", T.StringType(), True),
        ])

    # ------------------------------------------------------------------ discovery

    def _exclude_tables_by_pattern(self, fq_tables: List[str]) -> List[str]:
        """Drop tables whose bare name matches ``exclude_pattern``.

        Matching is case-insensitive because Spark/Unity Catalog identifiers are.
        """
        if not self.exclude_pattern:
            return fq_tables
        pattern = re.compile(glob_to_regex(self.exclude_pattern), re.IGNORECASE)
        filtered = [fq for fq in fq_tables if not pattern.match(fq.split('.')[-1])]
        print(f"[INFO] Excluded {len(fq_tables) - len(filtered)} tables by pattern '{self.exclude_pattern}'")
        return filtered

    def _print_parameters(self) -> None:
        """Echo the run parameters to the notebook output."""
        print("\n===== PARAMETERS PASSED THIS RUN =====")
        print(f"mode:           {self.mode}")
        print(f"name_param:     {self.name_param}")
        print(f"output_table:   {self.output_table}")
        print(f"exclude_pattern:{self.exclude_pattern}")
        print(f"created_by:     {self.created_by}")
        print(f"columns:        {self.columns}")
        print(f"profile_options:")
        for k, v in self.profile_options.items():
            print(f"  {k}: {v}")
        print("======================================\n")

    @staticmethod
    def _qualify_pipeline_tables(ws, pipeline_id, tables: List[str]) -> List[str]:
        """Prefix not-fully-qualified flow names with the pipeline's catalog/target schema.

        Names that already have three parts are returned unchanged, as is the whole
        list when the pipeline spec cannot be read or lacks a catalog or schema.
        """
        if all(t.count(".") >= 2 for t in tables):
            return tables
        try:
            spec = getattr(ws.pipelines.get(pipeline_id=pipeline_id), "spec", None)
        except Exception as e:
            print(f"[WARN] Could not read pipeline spec for {pipeline_id}: {e}")
            return tables
        catalog = getattr(spec, "catalog", None)
        schema = getattr(spec, "schema", None) or getattr(spec, "target", None)
        if not (catalog and schema):
            return tables
        qualified = []
        for t in tables:
            depth = t.count(".")
            if depth == 0:
                qualified.append(f"{catalog}.{schema}.{t}")
            elif depth == 1:
                qualified.append(f"{catalog}.{t}")
            else:
                qualified.append(t)
        return qualified

    def _discover_pipeline_tables(self, name_param: str) -> List[str]:
        """Return tables whose flows appear in the latest update of each named pipeline."""
        print("Searching for pipeline output tables...")
        ws = WorkspaceClient()
        pipelines = [p.strip() for p in name_param.split(",") if p.strip()]
        print(f"Pipelines passed: {pipelines}")
        # List workspace pipelines once (not once per requested name); keep the
        # first pipeline seen for each name, as the original linear search did.
        by_name = {}
        for p in ws.pipelines.list_pipelines():
            by_name.setdefault(p.name, p)
        discovered = []
        for pipeline_name in pipelines:
            print(f"Finding output tables for pipeline: {pipeline_name}")
            pl = by_name.get(pipeline_name)
            if not pl:
                raise RuntimeError(f"Pipeline '{pipeline_name}' not found via SDK.")
            if not pl.latest_updates:
                raise RuntimeError(
                    f"Pipeline '{pipeline_name}' has no updates; run it at least once before profiling."
                )
            latest_update = pl.latest_updates[0].update_id
            events = ws.pipelines.list_pipeline_events(pipeline_id=pl.pipeline_id, max_results=250)
            flow_names = [
                getattr(ev.origin, "flow_name", None)
                for ev in events
                if getattr(ev.origin, "update_id", None) == latest_update
            ]
            # Many events share a flow; de-duplicate while keeping first-seen order.
            pipeline_tables = list(dict.fromkeys(name for name in flow_names if name))
            pipeline_tables = self._qualify_pipeline_tables(ws, pl.pipeline_id, pipeline_tables)
            print(f"Found tables for pipeline '{pipeline_name}': {pipeline_tables}")
            for table in pipeline_tables:
                self._table_pipelines.setdefault(table, pipeline_name)
            discovered += pipeline_tables
        print(f"All discovered tables from pipelines: {discovered}")
        return discovered

    def _list_schema_tables(self, catalog: str, schema: str) -> List[str]:
        """Return fully qualified names of the persistent tables in ``catalog.schema``."""
        rows = self.spark.sql(f"SHOW TABLES IN {catalog}.{schema}").collect()
        # SHOW TABLES also lists session temporary views, which cannot be addressed
        # as catalog.schema.name; skip them.
        return [
            f"{catalog}.{schema}.{row.tableName}"
            for row in rows
            if not getattr(row, "isTemporary", False)
        ]

    def _discover_catalog_tables(self, name_param: str) -> List[str]:
        """Return every table in every user schema of the catalog ``name_param``."""
        print("Searching for tables in catalog...")
        catalog = name_param.strip()
        # Column 0 holds the schema name whether Spark labels it `namespace` or `databaseName`.
        schemas = [row[0] for row in self.spark.sql(f"SHOW SCHEMAS IN {catalog}").collect()]
        discovered = []
        for s in schemas:
            if s.lower() == "information_schema":
                # System metadata views, not user data.
                continue
            found = self._list_schema_tables(catalog, s)
            discovered += found
            if found:
                print(f"Tables in {catalog}.{s}: {found}")
        print(f"All discovered tables in catalog: {discovered}")
        return discovered

    def _discover_schema_tables(self, name_param: str) -> List[str]:
        """Return every table in ``name_param`` given as ``catalog.schema``."""
        print("Searching for tables in schema...")
        if name_param.count(".") != 1:
            raise ValueError("For 'schema' mode, name_param must be catalog.schema")
        catalog, schema = name_param.strip().split(".")
        schema_tables = self._list_schema_tables(catalog, schema)
        print(f"Tables in schema {catalog}.{schema}: {schema_tables}")
        return schema_tables

    @staticmethod
    def _discover_named_tables(name_param: str) -> List[str]:
        """Parse a comma-separated list of fully qualified table names."""
        print("Profiling one or more specific tables...")
        tables = [t.strip() for t in name_param.split(",") if t.strip()]
        for t in tables:
            if t.count(".") != 2:
                raise ValueError(f"Table name '{t}' must be fully qualified (catalog.schema.table)")
        print(f"Tables to be profiled: {tables}")
        return tables

    def _discover_tables(self) -> List[str]:
        """Resolve ``mode``/``name_param`` into a sorted, de-duplicated list of tables.

        Raises:
            ValueError: On an unknown mode, ``columns`` outside table mode, or a
                malformed ``name_param``.
            RuntimeError: When a named pipeline is missing or has never run.
        """
        self._print_parameters()
        mode = self.mode
        if mode not in self._ALLOWED_MODES:
            raise ValueError(f"Invalid mode '{mode}'. Must be one of: {sorted(self._ALLOWED_MODES)}.")
        if self.columns is not None and mode != "table":
            raise ValueError("The 'columns' parameter can only be used in mode='table'.")
        discover = {
            "pipeline": self._discover_pipeline_tables,
            "catalog": self._discover_catalog_tables,
            "schema": self._discover_schema_tables,
            "table": self._discover_named_tables,
        }[mode]
        discovered = discover(self.name_param)

        print("\nRunning exclude pattern filtering (if any)...")
        discovered = self._exclude_tables_by_pattern(discovered)
        # Sort/de-duplicate before printing so the printed list is what gets profiled.
        final_tables = sorted(set(discovered))
        print("\nFinal table list to generate DQX rules for:")
        print(final_tables)
        print("==========================================\n")
        return final_tables

    # ------------------------------------------------------------------ run

    def _get_next_batch_id(self) -> int:
        """Return one more than the highest ``batch_id`` in the output table (1 if none)."""
        # Public Catalog API; the former `_jsparkSession` route is private and is
        # not available on Spark Connect (e.g. serverless / shared clusters).
        if not self.spark.catalog.tableExists(self.output_table):
            return 1
        try:
            max_id = self.spark.table(self.output_table).agg({"batch_id": "max"}).collect()[0][0]
        except Exception as e:
            print(f"[WARN] Could not read max batch_id from {self.output_table}; using 1: {e}")
            return 1
        return (max_id or 0) + 1

    def _dict_to_kv_array(self, d: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Convert a dict into a key-sorted list of ``{"key", "value"}`` string pairs."""
        return [{"key": k, "value": str(v)} for k, v in sorted(d.items())]

    def _get_rule_generator_params(self) -> List[Dict[str, str]]:
        """Return the generator's own run parameters as audit key/value pairs."""
        params = {
            "mode": self.mode,
            "name_param": self.name_param,
            "output_table": self.output_table,
            "columns": json.dumps(self.columns) if self.columns is not None else "None",
            "exclude_pattern": str(self.exclude_pattern),
            "created_by": self.created_by,
        }
        return [{"key": k, "value": str(v)} for k, v in params.items()]

    @staticmethod
    def _split_fq_name(fq_table: str):
        """Split a table name into ``(catalog, schema, table)``, right-aligned.

        Missing leading parts are None, so an unqualified name lands in ``table``
        rather than in ``catalog``.
        """
        padded = [None, None, None] + fq_table.split(".")
        return tuple(padded[-3:])

    def _rule_columns(self, rule_name, df_columns) -> List[str]:
        """Best-effort guess of the column a generated rule applies to.

        NOTE(review): assumes DQX names rules ``<column>`` or ``<column>_<check>``;
        confirm against the installed DQDltGenerator. The longest matching column
        wins. Candidates are ``self.columns`` when given, else ``df_columns``.
        Without a match, the previous behaviour is kept: the requested
        ``columns`` (if any) or an empty list.
        """
        candidates = self.columns if self.columns is not None else list(df_columns)
        if isinstance(rule_name, str):
            matches = [c for c in candidates if rule_name == c or rule_name.startswith(f"{c}_")]
            if matches:
                return [max(matches, key=len)]
        return list(self.columns) if self.columns is not None else []

    def _generate_rules(self, profiler, generator, df) -> Dict[str, Any]:
        """Profile ``df`` and return DQX's rule-name -> constraint dict."""
        _summary_stats, profiles = profiler.profile(df, **self._build_profile_kwargs())
        return generator.generate_dlt_rules(profiles, language="Python_Dict")

    def _write_rules(self, rows: List[Dict[str, Any]]) -> None:
        """Append the rule rows to ``output_table``, creating its schema if needed."""
        if not rows:
            print("[INFO] No rules generated.")
            return
        print(f"[RUN] Writing {len(rows)} rules to output table: {self.output_table}")
        rules_df = self.spark.createDataFrame(rows, schema=self._get_rules_output_schema())
        output_schema = ".".join(self.output_table.split(".")[:-1])
        if output_schema:
            # An unqualified output table lives in the current schema; nothing to create.
            self.spark.sql(f"CREATE SCHEMA IF NOT EXISTS {output_schema}")
        rules_df.write.mode("append").saveAsTable(self.output_table)
        print(f"[RUN] Successfully wrote to table: {self.output_table}")

    def run(self):
        """Discover tables, generate rules for each and append them to ``output_table``.

        Tables that cannot be read or profiled are logged and skipped. Any other
        failure is logged and re-raised so notebook jobs do not report success.
        """
        try:
            tables = self._discover_tables()
            print("[RUN] Beginning DQX rule generation on these tables:")
            for t in tables:
                print(f"  {t}")
            print("==========================================\n")
            utc_now = datetime.now(timezone.utc)
            profile_options_array = self._dict_to_kv_array(self.profile_options)
            rule_generator_params_array = self._get_rule_generator_params()
            batch_id = self._get_next_batch_id()
            print(f"[RUN] Using batch_id={batch_id}")
            is_pipeline = self.mode == "pipeline"
            # One SDK client serves both DQX helpers for every table.
            ws = WorkspaceClient()
            profiler = DQProfiler(ws)
            generator = DQDltGenerator(ws)
            all_rules = []
            for fq_table in tables:
                cat, sch, tab = self._split_fq_name(fq_table)
                try:
                    print(f"[RUN] Checking table: {fq_table}")
                    df = self.spark.table(fq_table)
                    df.show(1)
                except Exception as e:
                    print(f"[WARN] Table {fq_table} not readable in Spark: {e}")
                    continue
                try:
                    print(f"[RUN] Profiling and generating rules for: {fq_table}")
                    rules_dict = self._generate_rules(profiler, generator, df)
                except Exception as e:
                    print(f"[WARN] Profiling failed for {fq_table}: {e}")
                    continue
                pipeline = self._table_pipelines.get(fq_table, self.name_param) if is_pipeline else None
                df_columns = df.columns
                for rule_name, constraint in rules_dict.items():
                    all_rules.append({
                        "batch_id": batch_id,
                        "pipeline": pipeline,
                        "is_pipeline": is_pipeline,
                        "source": fq_table,
                        "catalog": cat,
                        "schema": sch,
                        "table": tab,
                        "columns": self._rule_columns(rule_name, df_columns),
                        "profile_options": profile_options_array,
                        "rule_generator_params": rule_generator_params_array,
                        "created_by": self.created_by,
                        "created_at": utc_now,
                        "rule_name": rule_name,
                        "rule_constraint": constraint,
                    })
            self._write_rules(all_rules)
        except Exception as e:
            print(f"[ERROR] Rule generation failed: {e}")
            raise

# Usage example
if __name__ == "__main__":
    # Example: generate rules for two columns of a single staging table and
    # append them to the shared expectations table.
    profile_options = dict(
        max_in_count=10,
        distinct_ratio=0.05,
        max_null_ratio=0.01,
        remove_outliers=True,
        outlier_columns=[],
        num_sigmas=3,
        trim_strings=True,
        max_empty_ratio=0.01,
        sample_fraction=0.3,
        sample_seed=None,
        limit=1000,
        profile_types=None,
        min_length=None,
        max_length=None,
        include_histograms=False,
        min_value=None,
        max_value=None,
    )
    table_rule_generator = RuleGenerator(
        mode="table",
        name_param="dq_dev.lmg_expectations.zoo_animal_inventory_stg",
        output_table="dq_dev.expectations.dqx_expectations",
        profile_options=profile_options,
        columns=["animal_id", "species"],
        exclude_pattern=None,
        created_by="LMG",
    )
    table_rule_generator.run()