# DQX POC

In [0]:
%pip install databricks-labs-dqx

# Restart the notebook's Python process after the install above
# (presumably so the freshly installed package becomes importable — confirm against Databricks docs).
dbutils.library.restartPython()

In [0]:
from databricks.sdk import WorkspaceClient
from databricks.labs.dqx.profiler.profiler import DQProfiler
from databricks.labs.dqx.profiler.dlt_generator import DQDltGenerator
from pyspark.sql import SparkSession, types as T
from typing import List, Optional
import inspect

# -- DQX profiling options --
# Candidate settings handed to the profiling step further down in this notebook.
# The trailing comments are the author's description of each key.
# NOTE(review): which of these keys the installed DQX version actually honours cannot be
# verified from this notebook — confirm against the DQX profiler documentation.
profile_options = {
    # "round" was dropped from this dict by the author as "not valid for this profiler version".
    "max_in_count": 10,            # Max distinct values for is_in rule
    "distinct_ratio": 0.05,        # Max unique/total ratio for is_in rule
    "max_null_ratio": 0.01,        # Max null fraction to allow is_not_null rule
    "remove_outliers": True,       # Remove outliers for min/max
    "outlier_columns": [],         # Only these columns get outlier removal (empty=all numerics)
    "num_sigmas": 3,               # Stddev for outlier removal (z-score cutoff)
    "trim_strings": True,          # Strip whitespace before profiling strings
    "max_empty_ratio": 0.01,       # Max empty string ratio for is_not_null_or_empty
    "sample_fraction": 0.3,        # Row fraction to sample
    "sample_seed": None,           # Seed for reproducibility (set int for deterministic)
    "limit": 1000,                 # Max number of rows to profile
    "profile_types": None,         # List of rule types (e.g. ["is_in", "is_not_null"]); None=default
    "min_length": None,            # Min string length to consider (None disables)
    "max_length": None,            # Max string length to consider (None disables)
    "include_histograms": False,   # Compute histograms as part of profiling
    "min_value": None,             # Numeric min override (None disables)
    "max_value": None,             # Numeric max override (None disables)
}

def valid_profile_options(profile_options):
    """Keep only the entries whose key is a parameter name of DQProfiler.profile().

    The accepted names are read from the method's signature each time this is
    called; ``self`` and ``df`` are excluded because the caller supplies them
    positionally. Entries with any other key are dropped silently.

    Args:
        profile_options: Mapping of option name to value.

    Returns:
        A new dict holding the accepted subset, in the original order.
    """
    accepted = set(inspect.signature(DQProfiler.profile).parameters)
    accepted -= {"self", "df"}
    kept = {}
    for option_name, option_value in profile_options.items():
        if option_name in accepted:
            kept[option_name] = option_value
    return kept

# --- Table Discovery ---
def discover_output_tables(
    pipeline_name: str,
    sdk_client: Optional[object] = None
) -> List[str]:
    """Collect the flow names emitted by the latest update of a pipeline.

    The pipeline is looked up by exact name, then its event log is scanned in
    batches. An event contributes its ``origin.flow_name`` when that name is
    non-empty and ``origin.update_id`` equals the pipeline's most recent update.

    Args:
        pipeline_name: Exact name of the pipeline to inspect.
        sdk_client: Workspace client to use; a new ``WorkspaceClient()`` is
            created when omitted.

    Returns:
        Sorted list of distinct flow names (empty when none were found).

    Raises:
        RuntimeError: If no pipeline has that name, or the pipeline reports no
            updates (for example because it has never been run).
    """
    # Function-scope import: the file's import block does not provide itertools.
    from itertools import islice

    batch_size = 250
    w = sdk_client or WorkspaceClient()
    # Stop at the first name match instead of materialising every pipeline.
    pl = next((p for p in w.pipelines.list_pipelines() if p.name == pipeline_name), None)
    if not pl:
        print(f"[ERROR] Pipeline '{pipeline_name}' not found via SDK.")
        raise RuntimeError(f"Pipeline '{pipeline_name}' not found via SDK.")
    if not pl.latest_updates:
        # Without this guard, indexing None / [] below surfaces as an opaque
        # TypeError / IndexError rather than an actionable message.
        print(f"[ERROR] Pipeline '{pipeline_name}' has no updates; run it before discovering tables.")
        raise RuntimeError(f"Pipeline '{pipeline_name}' has no updates; run it before discovering tables.")
    latest_update = pl.latest_updates[0].update_id
    print(f"[INFO] Using latest update ID: {latest_update} for pipeline: {pipeline_name}")
    events = iter(w.pipelines.list_pipeline_events(pipeline_id=pl.pipeline_id, max_results=batch_size))
    tables = set()
    empty_pages = 0
    while True:
        page = list(islice(events, batch_size))
        page_tables = set()
        for ev in page:
            origin = ev.origin  # getattr with a default tolerates a missing origin
            flow_name = getattr(origin, "flow_name", None)
            if flow_name and getattr(origin, "update_id", None) == latest_update:
                page_tables.add(flow_name)
        if page_tables:
            tables |= page_tables
            empty_pages = 0
            print(f"[DEBUG] Found tables in this page: {sorted(page_tables)}")
        else:
            empty_pages += 1
            print(f"[DEBUG] No tables found in this page. Empty pages count: {empty_pages}")
        # Heuristic early exit: give up after two consecutive batches without a
        # match, or as soon as the event stream is exhausted.
        if empty_pages >= 2 or not page:
            break
    found = sorted(tables)
    if not found:
        print(f"[ERROR] No output tables found for pipeline '{pipeline_name}' using SDK event logs.")
    else:
        print(f"[INFO] Found {len(found)} tables for pipeline '{pipeline_name}': {found}")
    return found

# --- Helper Functions ---

def get_schema_name(table_name: str) -> str:
    """Return everything before the final dot-separated part of *table_name*.

    A three-part name yields its first two parts joined by a dot; a two-part
    name yields its first part.

    Raises:
        ValueError: If the name does not split into exactly two or three parts.
    """
    pieces = table_name.split(".")
    if len(pieces) not in (2, 3):
        raise ValueError(f"Invalid table name: {table_name}")
    # Dropping the last piece leaves the qualifier in both accepted shapes.
    return ".".join(pieces[:-1])

def ensure_schema_exists(spark, table_name: str):
    """Issue ``CREATE SCHEMA IF NOT EXISTS`` for the schema qualifying *table_name*.

    The schema name comes from ``get_schema_name`` and is interpolated into the
    statement unquoted.
    """
    target_schema = get_schema_name(table_name)
    ddl = f"CREATE SCHEMA IF NOT EXISTS {target_schema}"
    spark.sql(ddl)

def table_exists(spark, table_name):
    """Report whether ``spark.table(table_name)`` can be called without raising.

    Any exception from that call is printed and treated as "does not exist",
    so an unreadable table is indistinguishable from a missing one.

    Returns:
        True when the lookup succeeds, False otherwise.
    """
    try:
        spark.table(table_name)
    except Exception as err:
        print(f"[table_exists] Table {table_name} not found or not readable: {err}")
        return False
    else:
        return True

def create_table_with_schema(spark, table_name):
    """Write an empty table with the four nullable string columns used for expectations.

    The schema qualifying *table_name* is ensured first; the table is then
    saved in ``overwrite`` mode with columns ``pipeline``, ``table``, ``name``
    and ``constraint``.
    """
    ensure_schema_exists(spark, table_name)
    column_names = ("pipeline", "table", "name", "constraint")
    fields = [T.StructField(column, T.StringType(), True) for column in column_names]
    placeholder_df = spark.createDataFrame([], T.StructType(fields))
    writer = placeholder_df.write.mode("overwrite")
    writer.saveAsTable(table_name)

def generate_dlt_expectations_for_table(table_name, ws, spark, profile_options):
    """Profile one table with DQX and return the generated DLT expectations.

    Args:
        table_name: Name of the table to profile.
        ws: Workspace client handed to ``DQProfiler`` and ``DQDltGenerator``.
        spark: Spark session used to read the table.
        profile_options: Mapping of profiler option name to value; ``None`` is
            treated as empty.

    Returns:
        Whatever ``generate_dlt_rules(..., language="Python_Dict")`` returns,
        or ``None`` when the table cannot be read.
    """
    if not table_exists(spark, table_name):
        print(f"Skipping missing table: {table_name}")
        return None
    df = spark.table(table_name)
    profiler = DQProfiler(ws)
    generator = DQDltGenerator(ws)

    requested = dict(profile_options or {})
    # DQX documents profiling options as ONE dict argument (named `options`, `opts` in
    # older releases), not as individual keyword arguments. Filtering the option keys
    # against the method's parameter names therefore discards every key, and profiling
    # silently runs with defaults. Pass the dict through that parameter when it exists.
    profile_params = inspect.signature(profiler.profile).parameters
    options_param = next((name for name in ("options", "opts") if name in profile_params), None)
    if options_param is not None:
        profile_kwargs = {options_param: requested}
    else:
        # Unknown signature: keep the previous behaviour of per-key keyword filtering.
        profile_kwargs = valid_profile_options(requested)

    _summary_stats, profiles = profiler.profile(df, **profile_kwargs)
    return generator.generate_dlt_rules(profiles, language="Python_Dict")

def create_or_overwrite_table(spark, df, output_table):
    """Save *df* as *output_table*, overwriting only when the table is readable.

    The qualifying schema is ensured first. When ``table_exists`` reports the
    table, the write uses ``overwrite`` mode; otherwise the writer's default
    save mode applies.
    """
    ensure_schema_exists(spark, output_table)
    already_there = table_exists(spark, output_table)
    writer = df.write.mode("overwrite") if already_there else df.write
    writer.saveAsTable(output_table)

def filter_existing_tables(spark, table_list):
    """Return the names from *table_list* that Spark can load and preview.

    Each candidate is loaded with ``spark.table`` and previewed with
    ``show(1)``; a name is kept only if both calls succeed. Failures are
    printed as warnings and the name is skipped. Input order is preserved.
    """
    readable = []
    for candidate in table_list:
        try:
            print(f"[filter_existing_tables] Trying to preview table: {candidate}")
            spark.table(candidate).show(1)
        except Exception as err:
            print(f"[WARN] Table {candidate} not readable in Spark: {err}")
            continue
        readable.append(candidate)
    return readable

# --- Main Orchestration ---

def generate_and_save_all_expectations(pipeline_name, output_table, profile_options):
    """Discover a pipeline's tables, generate expectations for each, and persist them.

    Steps: discover flow names for the pipeline, keep those Spark can read,
    generate rules per table, then write one row per rule
    (``pipeline``, ``table``, ``name``, ``constraint``) to *output_table*.
    When no rules are produced, an empty table with that schema is created
    only if *output_table* is not already readable.

    Raises:
        RuntimeError: If there is no active Spark session.
    """
    ws = WorkspaceClient()
    spark = SparkSession.getActiveSession()
    if not spark:
        raise RuntimeError("No active Spark session found. Run this in a Databricks notebook.")
    print(f"Using Spark session: {spark}")

    tables = discover_output_tables(pipeline_name, sdk_client=ws)
    print(f"Discovered tables for pipeline '{pipeline_name}': {tables}")
    existing_tables = filter_existing_tables(spark, tables)
    print(f"Tables readable in Spark: {existing_tables}")

    rows = []
    for tbl in existing_tables:
        print(f"Profiling and generating expectations for table: {tbl}")
        rules = generate_dlt_expectations_for_table(tbl, ws, spark, profile_options)
        if rules:
            rows.extend(
                (pipeline_name, tbl, rule_name, expression)
                for rule_name, expression in rules.items()
            )

    if not rows:
        print("No rules generated—check pipeline/table names and data.")
        # Leave an existing table untouched; only materialise the empty shell when absent.
        if not table_exists(spark, output_table):
            create_table_with_schema(spark, output_table)
        return

    rules_df = spark.createDataFrame(rows, schema=["pipeline", "table", "name", "constraint"])
    create_or_overwrite_table(spark, rules_df, output_table)
    print(f"Wrote {len(rows)} expectations to table '{output_table}'.")

# --- Execution ---

pipeline_name = "pl_zoo_bronze"
output_constraints_table = "dq_dev.expectations.dqx_expectations"

# Generate expectations for every readable table of the pipeline and persist them.
generate_and_save_all_expectations(
    pipeline_name=pipeline_name,
    output_table=output_constraints_table,
    profile_options=profile_options,
)

# Show the result, or report that the table is still not readable.
spark = SparkSession.getActiveSession()
if not spark or not table_exists(spark, output_constraints_table):
    print(f"Table {output_constraints_table} does not exist after execution.")
else:
    display(spark.table(output_constraints_table))