In [0]:
%load_ext autoreload
%autoreload 2
# Enables autoreload; learn more at https://docs.databricks.com/en/files/workspace-modules.html#autoreload-for-python-modules
# To disable autoreload; run %autoreload 0

# Layker - Dev Testing

### Yaml

In [0]:
# src/layker/snapshot_yaml.py

"""
YAML TABLE VALIDATION & SNAPSHOT
-------------------------------------------------------------------------------
This file validates and sanitizes YAML table DDL for Databricks/Delta Lake.
Validation covers all of the following rules:

TABLE-LEVEL VALIDATION CHECKS
-------------------------------------------------------------------------------
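- Required top-level keys: catalog, schema, table, columns (the table comment and properties go in the optional top-level keys table_comment / table_properties, not under a nested 'properties' mapping)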
- catalog/schema/table: must be valid SQL identifiers ([a-z][a-z0-9_]*)
- At least one column must be defined under columns
- Column keys must be continuous (1,2,...N)
- Each column must include: name, datatype, nullable, active
- No duplicate column names allowed
- Column names must be valid SQL identifiers
- datatype must be supported Spark type (or valid complex type)
- active must be boolean
- If default_value present:
    - Must match expected type for that datatype
    - Boolean must be bool or string "true"/"false"
- Column comments cannot contain newline, carriage return, tab, or single quote

COLUMN CHECK CONSTRAINTS
-------------------------------------------------------------------------------
- If present, must be a dict
- Each constraint must be a dict with both name and expression
- No duplicate constraint names per column

SCHEMA-LEVEL REFERENCES
-------------------------------------------------------------------------------
- primary_key, partitioned_by: must reference only existing columns
- unique_keys: must be a list of lists, each referencing valid columns
- foreign_keys:
    - Must be a dict
    - Each FK must have: columns, reference_table, reference_columns
    - columns must exist
    - reference_table must be fully qualified (catalog.schema.table)
    - reference_columns must be list of strings

TABLE-LEVEL FEATURES
-------------------------------------------------------------------------------
- table_check_constraints: dict, each with name and expression, no duplicates
- row_filters: dict, each with name and expression, no duplicates
- tags: must be a dict (if present)
- owner: must be string or null (if present)
- table_comment: must be a string (if present)
- table_properties: must be a dict (if present)

"""

import sys
import yaml
import re
from typing import Any, Dict, List, Tuple, Optional

# ---- VALIDATOR ----

class YamlSnapshot:
    """Validator for a parsed YAML table definition (Databricks / Delta Lake).

    ``validate_dict`` runs every rule group and collects all problems rather
    than stopping at the first one, so a single run reports the whole file.
    Each rule group lives in a private ``_validate_*`` helper that appends to
    a shared ``errors`` list.
    """

    REQUIRED_TOP_KEYS = ["catalog", "schema", "table", "columns"]
    OPTIONAL_TOP_KEYS = [
        "primary_key", "partitioned_by", "unique_keys", "foreign_keys",
        "table_check_constraints", "row_filters", "tags", "owner",
        "table_comment", "table_properties"
    ]

    REQUIRED_COL_KEYS = {"name", "datatype", "nullable", "active"}
    ALLOWED_OPTIONAL_COL_KEYS = {
        "comment", "tags", "column_masking_rule", "default_value", "variable_value", "column_check_constraints"
    }

    # Characters rejected in column comments (presumably because comments end
    # up inside single-quoted SQL text).
    DISALLOWED_COMMENT_CHARS = ["\n", "\r", "\t", "'"]

    COMPLEX_TYPE_PATTERNS = [
        r"^array<.+>$", r"^struct<.+>$", r"^map<.+>$"
    ]

    # Parameterized scalar types: decimal(p) / decimal(p, s).
    PARAMETERIZED_TYPE_PATTERNS = [
        r"^decimal\(\s*\d+\s*(,\s*\d+\s*)?\)$"
    ]

    # Simple Spark type name -> Python type a YAML default_value should load as.
    ALLOWED_SPARK_TYPES = {
        "string": str, "int": int, "double": float, "float": float,
        "bigint": int, "boolean": bool, "binary": bytes,
        "date": str, "timestamp": str, "decimal": float,
    }

    @staticmethod
    def _is_valid_sql_identifier(name: Any) -> bool:
        """True when ``name`` is a string matching ``[a-z][a-z0-9_]*`` (surrounding whitespace ignored)."""
        return isinstance(name, str) and bool(re.match(r"^[a-z][a-z0-9_]*$", name.strip()))

    @classmethod
    def _is_valid_spark_type(cls, dt: Any) -> bool:
        """True for a supported simple, parameterized (decimal) or complex Spark type string."""
        if not isinstance(dt, str):
            return False
        dt_lc = dt.lower()
        if dt_lc in cls.ALLOWED_SPARK_TYPES:
            return True
        patterns = cls.COMPLEX_TYPE_PATTERNS + cls.PARAMETERIZED_TYPE_PATTERNS
        return any(re.match(p, dt_lc) for p in patterns)

    @staticmethod
    def _is_fully_qualified_table(ref: Any) -> bool:
        """True for ``catalog.schema.table`` made of exactly three non-empty parts."""
        if not isinstance(ref, str):
            return False
        parts = ref.split(".")
        return len(parts) == 3 and all(part.strip() for part in parts)

    @classmethod
    def validate_dict(cls, cfg: Dict[str, Any]) -> Tuple[bool, List[str]]:
        """Validate a parsed YAML table config.

        Args:
            cfg: the object produced by ``yaml.safe_load`` for the table file.

        Returns:
            ``(is_valid, errors)``. ``errors`` lists every problem found and is
            empty exactly when ``is_valid`` is True. Malformed input is reported
            as an error instead of raising.
        """
        if not isinstance(cfg, dict):
            # An empty YAML file loads as None; a bare list/scalar is not a table spec.
            return False, [f"Top-level YAML must be a mapping, got {type(cfg).__name__}"]
        errors: List[str] = []
        cls._validate_top_level(cfg, errors)
        all_col_names = cls._validate_columns(cfg, errors)
        cls._validate_references(cfg, all_col_names, errors)
        cls._validate_table_features(cfg, errors)
        return (len(errors) == 0, errors)

    @classmethod
    def _validate_top_level(cls, cfg: Dict[str, Any], errors: List[str]) -> None:
        """Required keys, misplaced entries under 'properties', and catalog/schema/table names."""
        for key in cls.REQUIRED_TOP_KEYS:
            if key not in cfg or cfg[key] in (None, ""):
                errors.append(f"Missing top-level key: '{key}'")

        # table_comment / table_properties must be top-level, not nested under 'properties'.
        props = cfg.get("properties")
        if isinstance(props, dict):
            if "comment" in props:
                errors.append("Move 'comment' out of 'properties' and use top-level 'table_comment'")
            if "table_properties" in props:
                errors.append("Move 'table_properties' out of 'properties' and use top-level 'table_properties'")

        # Validate against the documented identifier rule directly: dots would
        # break the catalog.schema.table split used when building the FQ name.
        for k in ("catalog", "schema", "table"):
            v = cfg.get(k, "")
            if v and not cls._is_valid_sql_identifier(v):
                errors.append(f"Invalid {k} name: '{v}'")

    @classmethod
    def _validate_columns(cls, cfg: Dict[str, Any], errors: List[str]) -> List[str]:
        """Validate the ``columns`` mapping (keys 1..N); return the declared column names."""
        raw = cfg.get("columns", {})
        if not raw:
            errors.append("No columns defined. At least one column is required.")
            return []
        if not isinstance(raw, dict):
            errors.append("'columns' must be a mapping of position (1..N) to column definition")
            return []
        try:
            # int(str(k)) so that keys like 1.5 or True are rejected, not coerced.
            by_pos = {int(str(k)): v for k, v in raw.items()}
        except (TypeError, ValueError):
            by_pos = {}
        if len(by_pos) != len(raw) or sorted(by_pos) != list(range(1, len(raw) + 1)):
            errors.append(f"Column keys must be continuous 1..N, got {[str(k) for k in raw]}")
            return []

        seen_names: set = set()
        all_col_names: List[str] = []
        for pos in sorted(by_pos):
            name = cls._validate_column(pos, by_pos[pos], seen_names, errors)
            if name is not None:
                all_col_names.append(name)
        return all_col_names

    @classmethod
    def _validate_column(cls, i: int, col: Any, seen_names: set, errors: List[str]) -> Optional[str]:
        """Validate column ``i``; return its name when it is a string, else None."""
        if not isinstance(col, dict):
            errors.append(f"Column {i} must be a mapping, got {type(col).__name__}")
            return None
        missing = cls.REQUIRED_COL_KEYS - set(col.keys())
        if missing:
            errors.append(f"Column {i} missing keys: {sorted(missing)}")

        name = col.get("name")
        if not name or not cls._is_valid_sql_identifier(name):
            errors.append(f"Column {i} name '{name}' invalid")
        if isinstance(name, str):
            if name in seen_names:
                errors.append(f"Duplicate column name: '{name}'")
            seen_names.add(name)

        dt = col.get("datatype")
        if not cls._is_valid_spark_type(dt):
            errors.append(f"Column {i} datatype '{dt}' not allowed")
        if not isinstance(col.get("active"), bool):
            errors.append(f"Column {i} 'active' must be boolean")
        cls._validate_default(i, dt, col.get("default_value"), errors)

        comment = col.get("comment")
        if comment is not None:  # `comment:` with no value loads as None
            if not isinstance(comment, str):
                errors.append(f"Column {i} comment must be a string")
            else:
                bad = [ch for ch in cls.DISALLOWED_COMMENT_CHARS if ch in comment]
                if bad:
                    errors.append(f"Column {i} comment contains {bad}")

        cls._validate_column_constraints(i, col.get("column_check_constraints", {}), errors)
        return name if isinstance(name, str) else None

    @classmethod
    def _validate_default(cls, i: int, dt: Any, dv: Any, errors: List[str]) -> None:
        """Check ``default_value`` against the Python type mapped to a simple datatype."""
        if dv in (None, "") or not isinstance(dt, str):
            return
        dt_lc = dt.lower()
        expected = cls.ALLOWED_SPARK_TYPES.get(dt_lc)
        # date/timestamp defaults are free-form strings; complex/decimal(p,s) are not checked.
        if expected is None or dt_lc in ("date", "timestamp"):
            return
        if expected is bool:
            if not isinstance(dv, bool) and not (isinstance(dv, str) and dv.lower() in ("true", "false")):
                errors.append(f"Column {i} default '{dv}' invalid for boolean")
            return
        if isinstance(dv, bool):
            # bool subclasses int, so it must be rejected explicitly for non-boolean types.
            matches = False
        elif expected is float:
            # YAML loads `0` as int; an integer literal is a valid float/double/decimal default.
            matches = isinstance(dv, (int, float))
        else:
            matches = isinstance(dv, expected)
        if not matches:
            errors.append(f"Column {i} default '{dv}' does not match {dt}")

    @staticmethod
    def _validate_column_constraints(i: int, ccc: Any, errors: List[str]) -> None:
        """Check a column's ``column_check_constraints`` mapping (skipped when empty)."""
        if not ccc:
            return
        if not isinstance(ccc, dict):
            errors.append(f"Column {i} column_check_constraints must be a dict")
            return
        seen_constraint_names: set = set()
        for cname, cdict in ccc.items():
            if not isinstance(cdict, dict):
                errors.append(f"Column {i} constraint '{cname}' must be a dict")
                continue
            if "name" not in cdict or "expression" not in cdict:
                errors.append(f"Column {i} constraint '{cname}' missing 'name' or 'expression'")
            name_val = cdict.get("name")
            if isinstance(name_val, str):
                if name_val in seen_constraint_names:
                    errors.append(f"Column {i} has duplicate column_check_constraint name '{name_val}'")
                seen_constraint_names.add(name_val)

    @classmethod
    def _validate_references(cls, cfg: Dict[str, Any], all_col_names: List[str], errors: List[str]) -> None:
        """Check primary_key, partitioned_by, unique_keys and foreign_keys against declared columns."""

        def check_columns(field: str, value: Any) -> None:
            # A bare string names one column; iterating it would yield characters.
            names = [value] if isinstance(value, str) else value
            if not isinstance(names, list):
                errors.append(f"Field '{field}' must be a column name or a list of column names")
                return
            for col in names:
                if col not in all_col_names:
                    errors.append(f"Field '{field}' references unknown column '{col}'")

        for field in ("primary_key", "partitioned_by"):
            if field in cfg:
                check_columns(field, cfg[field])

        if "unique_keys" in cfg:
            uk = cfg["unique_keys"]
            if not isinstance(uk, list):
                errors.append("unique_keys must be a list of lists")
            else:
                for idx, group in enumerate(uk):
                    if not isinstance(group, list):
                        errors.append(f"unique_keys entry {idx} must be a list")
                        continue
                    check_columns(f"unique_keys[{idx}]", group)

        if "foreign_keys" in cfg:
            fks = cfg["foreign_keys"]
            if not isinstance(fks, dict):
                errors.append("foreign_keys must be a dict")
                return
            required_fk_keys = {"columns", "reference_table", "reference_columns"}
            for fk_name, fk in fks.items():
                if not isinstance(fk, dict):
                    errors.append(f"Foreign key '{fk_name}' must be a dict")
                    continue
                missing_fk = required_fk_keys - set(fk)
                if missing_fk:
                    errors.append(f"Foreign key '{fk_name}' missing keys: {sorted(missing_fk)}")
                    continue
                check_columns(f"foreign_keys.{fk_name}.columns", fk["columns"])
                ref_tbl = fk["reference_table"]
                if not cls._is_fully_qualified_table(ref_tbl):
                    errors.append(f"Foreign key '{fk_name}' reference_table '{ref_tbl}' must be fully qualified (catalog.schema.table)")
                ref_cols = fk["reference_columns"]
                if not isinstance(ref_cols, list) or not all(isinstance(x, str) for x in ref_cols):
                    errors.append(f"Foreign key '{fk_name}' reference_columns must be a list of strings")

    @staticmethod
    def _validate_named_rules(value: Any, field: str, dup_label: str, errors: List[str]) -> None:
        """Shared check for table_check_constraints / row_filters: dict of {name, expression} dicts, unique names."""
        if not isinstance(value, dict):
            errors.append(f"{field} must be a dict")
            return
        names_seen: set = set()
        for key, rule in value.items():
            if not isinstance(rule, dict):
                errors.append(f"{field} '{key}' must be a dict")
                continue
            if "name" not in rule or "expression" not in rule:
                errors.append(f"{field} '{key}' missing 'name' or 'expression'")
            name_val = rule.get("name")
            if isinstance(name_val, str):
                if name_val in names_seen:
                    errors.append(f"Duplicate {dup_label} name: '{name_val}'")
                names_seen.add(name_val)

    @classmethod
    def _validate_table_features(cls, cfg: Dict[str, Any], errors: List[str]) -> None:
        """Check table-level constraints, row filters, tags, owner, comment and properties."""
        if "table_check_constraints" in cfg:
            cls._validate_named_rules(cfg["table_check_constraints"], "table_check_constraints",
                                      "table_check_constraint", errors)
        if "row_filters" in cfg:
            cls._validate_named_rules(cfg["row_filters"], "row_filters", "row_filter", errors)
        if "tags" in cfg and not isinstance(cfg["tags"], dict):
            errors.append("Top-level 'tags' must be a dict")
        if "owner" in cfg and not (cfg["owner"] is None or isinstance(cfg["owner"], str)):
            errors.append("'owner' must be a string or null")
        if "table_comment" in cfg and not isinstance(cfg["table_comment"], str):
            errors.append("'table_comment' must be a string")
        if "table_properties" in cfg and not isinstance(cfg["table_properties"], dict):
            errors.append("'table_properties' must be a dict")

# ---- SANITIZER ----

def sanitize_text(text: Any) -> str:
    """Flatten ``text`` to a single trimmed line with single quotes replaced.

    ``None`` becomes ``""``. Every other value goes through ``str()``, so
    falsy-but-meaningful values such as ``0`` or ``False`` are kept rather
    than being erased. Newlines, carriage returns and tabs become spaces,
    surrounding whitespace is stripped, and ``'`` is replaced with a backtick.
    The replacement is presumably there so the result can sit inside a
    single-quoted SQL literal.
    """
    t = "" if text is None else str(text)
    clean = t.replace("\n", " ").replace("\r", " ").replace("\t", " ").strip()
    return clean.replace("'", "`")

def recursive_sanitize_comments(obj: Any, path: str = "") -> Any:
    """Sanitize every column ``comment`` found under a ``columns`` mapping, in place.

    Walks dicts and lists while tracking a dotted key ``path``. A dict value
    whose parent path is ``columns`` (the top-level key) or ends with
    ``.columns`` is treated as a column definition. Its string ``comment`` is
    passed through ``sanitize_text`` and the walk does not descend into it.
    Keys are converted to ``str`` for the path, so integer keys are safe.

    Returns:
        ``obj`` itself (mutated in place), for chaining.
    """
    if isinstance(obj, dict):
        # The top-level "columns" key produces path "columns" (no leading dot),
        # so both forms must be recognized.
        in_columns = path == "columns" or path.endswith(".columns")
        for k, v in obj.items():
            key = str(k)
            curr = f"{path}.{key}" if path else key
            if in_columns and isinstance(v, dict) and "comment" in v:
                if isinstance(v["comment"], str):
                    v["comment"] = sanitize_text(v["comment"])
            else:
                recursive_sanitize_comments(v, curr)
    elif isinstance(obj, list):
        for i, item in enumerate(obj):
            recursive_sanitize_comments(item, f"{path}[{i}]")
    return obj

def sanitize_metadata(cfg: Dict[str, Any]) -> Dict[str, Any]:
    """Normalize table-level text fields of a config, in place.

    - ``table_comment``: each line is stripped; line breaks are preserved.
    - ``table_properties`` and ``tags`` values: passed through ``sanitize_text``.
      ``tags`` is created as an empty dict when absent.
    - ``row_filters`` / ``table_check_constraints``: ``name`` goes through
      ``sanitize_text``. ``expression`` only has newlines/CR/tabs turned into
      spaces and is trimmed. Quotes are left untouched, because rewriting ``'``
      to a backtick turns a SQL string literal (``status = 'open'``) into an
      identifier reference and silently changes the rule.

    Sections that are not dicts are skipped rather than raising.

    Returns:
        ``cfg`` itself, for chaining.
    """
    def _flatten_expression(expr: Any) -> str:
        # Same whitespace handling as sanitize_text, but quotes are preserved.
        text = "" if expr is None else str(expr)
        return text.replace("\n", " ").replace("\r", " ").replace("\t", " ").strip()

    comment = cfg.get("table_comment")
    if isinstance(comment, str):
        cfg["table_comment"] = "\n".join(line.strip() for line in comment.splitlines())

    props = cfg.get("table_properties")
    if isinstance(props, dict):
        for key in list(props):
            props[key] = sanitize_text(props[key])

    tags = cfg.setdefault("tags", {})
    if isinstance(tags, dict):
        for key in list(tags):
            tags[key] = sanitize_text(tags[key])

    for section in ("row_filters", "table_check_constraints"):
        rules = cfg.get(section)
        if not isinstance(rules, dict):
            continue
        for rule in rules.values():
            if not isinstance(rule, dict):
                continue
            if "name" in rule:
                rule["name"] = sanitize_text(rule["name"])
            if "expression" in rule:
                rule["expression"] = _flatten_expression(rule["expression"])
    return cfg

# ---- SNAPSHOT YAML BUILDER ----

def build_snapshot_yaml(cfg: Dict[str, Any], env: Optional[str] = None) -> Tuple[Dict[str, Any], str]:
    """Project a validated, sanitized YAML config onto the snapshot layout.

    Args:
        cfg: the table config.
        env: optional suffix appended to the catalog name, but only when the
            configured catalog ends with ``_`` (``"dq_"`` + ``"dev"`` -> ``"dq_dev"``).

    Returns:
        ``(snapshot, full_table_name)``. ``snapshot["columns"]`` maps the column
        position (as a string) to a normalized column dict, in ascending
        numeric order. Missing optional fields get empty defaults.
    """
    catalog = cfg.get("catalog", "").strip()
    if catalog.endswith("_") and env:
        catalog = f"{catalog}{env}"
    schema = cfg.get("schema", "").strip()
    table = cfg.get("table", "").strip()
    full_name = ".".join((catalog, schema, table))

    specs_by_key = {str(key): spec for key, spec in cfg.get("columns", {}).items()}
    columns: Dict[str, Any] = {}
    for position in sorted(int(key) for key in specs_by_key):
        spec = specs_by_key[str(position)]
        columns[str(position)] = dict(
            name=spec.get("name", ""),
            datatype=spec.get("datatype", ""),
            nullable=spec.get("nullable", True),
            active=spec.get("active", True),
            comment=spec.get("comment", ""),
            tags=spec.get("tags", {}),
            column_masking_rule=spec.get("column_masking_rule", ""),
            column_check_constraints=spec.get("column_check_constraints", {}),
        )

    snapshot: Dict[str, Any] = {}
    snapshot["full_table_name"] = full_name
    snapshot["catalog"] = catalog
    snapshot["schema"] = schema
    snapshot["table"] = table
    snapshot["primary_key"] = cfg.get("primary_key", [])
    snapshot["foreign_keys"] = cfg.get("foreign_keys", {})
    snapshot["unique_keys"] = cfg.get("unique_keys", [])
    snapshot["partitioned_by"] = cfg.get("partitioned_by", [])
    snapshot["table_tags"] = cfg.get("tags", {})
    snapshot["row_filters"] = cfg.get("row_filters", {})
    snapshot["table_check_constraints"] = cfg.get("table_check_constraints", {})
    snapshot["table_properties"] = cfg.get("table_properties", {})
    snapshot["table_comment"] = cfg.get("table_comment", "")
    snapshot["owner"] = cfg.get("owner", "")
    snapshot["columns"] = columns
    return snapshot, full_name

# ---- MAIN ENTRY ----

def validate_and_snapshot_yaml(yaml_path: str, env: Optional[str] = None, mode: str = "all") -> Tuple[Optional[Dict[str, Any]], Optional[str]]:
    """Load, validate, sanitize and snapshot a YAML table definition.

    Args:
        yaml_path: path to the YAML file, read as UTF-8.
        env: optional catalog suffix forwarded to ``build_snapshot_yaml``.
        mode: ``"validate"`` stops via ``sys.exit(0)`` after a successful
            validation. Any other value builds and returns the snapshot.

    Returns:
        ``(snapshot, fully_qualified_table_name)``.

    Exits (raises ``SystemExit``):
        2 if the file cannot be read or parsed; 1 if the YAML is not a mapping
        or fails validation; 0 after successful validation in ``"validate"`` mode.
    """
    # 1. Load YAML (explicit UTF-8 so the result does not depend on the platform locale)
    try:
        with open(yaml_path, "r", encoding="utf-8") as f:
            raw_cfg = yaml.safe_load(f)
    except FileNotFoundError:
        print(f"YAML file not found: {yaml_path}")
        sys.exit(2)
    except yaml.YAMLError as e:
        print(f"YAML syntax error in {yaml_path}: {e}")
        sys.exit(2)
    except Exception as e:
        print(f"Error loading or parsing YAML: {e}")
        sys.exit(2)

    # An empty file loads as None; report it cleanly instead of crashing in the validator.
    if not isinstance(raw_cfg, dict):
        print(f"YAML file {yaml_path} must contain a mapping at the top level, got {type(raw_cfg).__name__}")
        sys.exit(1)

    # 2. Validate
    valid, errors = YamlSnapshot.validate_dict(raw_cfg)
    if not valid:
        print("Validation failed:")
        for err in errors:
            print(f"  - {err}")
        print("YAML validation failed. See errors above.")
        sys.exit(1)
    print("YAML validation passed.")

    if mode == "validate":
        print("Validation complete. Exiting after successful validation.")
        sys.exit(0)

    # 3. Sanitize (only needed when a snapshot is actually built)
    cfg_clean = recursive_sanitize_comments(raw_cfg)
    cfg_clean = sanitize_metadata(cfg_clean)

    # 4. Build snapshot
    snapshot_yaml, fq_table = build_snapshot_yaml(cfg_clean, env=env)
    print("Snapshot YAML and fully qualified table name are ready.")
    return snapshot_yaml, fq_table

In [0]:
import pprint

yaml_path = "layker/resources/test.yaml"

# Validate + sanitize the YAML spec, then show its snapshot dict and FQ table name.
snapshot_yaml, fq_table = validate_and_snapshot_yaml(yaml_path, env=None, mode="all")

print("--- FULLY SANITIZED SNAPSHOT YAML ---")
pprint.pprint(snapshot_yaml)
print("--- FULLY QUALIFIED TABLE NAME ---")
print(fq_table)

### Snapshot

#### Table

In [0]:
table_name = "dq_dev.lmg_sandbox.config_driven_table_example"

# Raw DESCRIBE EXTENDED output for eyeballing before the parsers below run.
describe_extended_query = """
DESCRIBE EXTENDED
  {table_name}
"""

describe_df = spark.sql(describe_extended_query.format(table_name=table_name))
describe_df.show(n=100, truncate=False)

In [0]:
import re
from typing import List, Dict, Any
from pyspark.sql import SparkSession


############################################################
#                   Start of First Section
############################################################

# --- SYSTEM TABLE SNAPSHOT QUERIES ---
# One entry per Unity Catalog information_schema query:
#   table      - system table to read
#   columns    - columns to SELECT
#   where_keys - (filter column, index) pairs; build_metadata_sql uses the
#                index to pick catalog (0), schema (1) or table (2) from the
#                parsed fully qualified table name.
SNAPSHOT_QUERIES = {
    "table_tags": {
        "table": "system.information_schema.table_tags",
        "columns": ["catalog_name", "schema_name", "table_name", "tag_name", "tag_value"],
        "where_keys": [("catalog_name", 0), ("schema_name", 1), ("table_name", 2)],
    },
    "column_tags": {
        "table": "system.information_schema.column_tags",
        "columns": ["catalog_name", "schema_name", "table_name", "column_name", "tag_name", "tag_value"],
        "where_keys": [("catalog_name", 0), ("schema_name", 1), ("table_name", 2)],
    },
    "row_filters": {
        "table": "system.information_schema.row_filters",
        "columns": ["table_catalog", "table_schema", "table_name", "filter_name", "target_columns"],
        "where_keys": [("table_catalog", 0), ("table_schema", 1), ("table_name", 2)],
    },
    "constraint_table_usage": {
        "table": "system.information_schema.constraint_table_usage",
        "columns": ["constraint_catalog", "constraint_schema", "constraint_name"],
        "where_keys": [("table_catalog", 0), ("table_schema", 1), ("table_name", 2)],
    },
    "constraint_column_usage": {
        "table": "system.information_schema.constraint_column_usage",
        "columns": ["column_name", "constraint_name"],
        "where_keys": [("table_catalog", 0), ("table_schema", 1), ("table_name", 2)],
    },
}

def parse_fq_table(fq_table: str):
    """Split ``catalog.schema.table`` into a ``(catalog, schema, table)`` tuple.

    Raises:
        ValueError: unless there are exactly three dot-separated parts and
            none of them is empty or blank (e.g. ``"a.b"``, ``"a..b"``).
    """
    parts = fq_table.split(".")
    if len(parts) != 3 or not all(part.strip() for part in parts):
        raise ValueError("Expected format: catalog.schema.table")
    return parts[0], parts[1], parts[2]

def build_metadata_sql(kind: str, fq_table: str) -> str:
    """Build the SELECT for one ``SNAPSHOT_QUERIES`` entry, filtered to ``fq_table``.

    Args:
        kind: a key of ``SNAPSHOT_QUERIES``.
        fq_table: ``catalog.schema.table``.

    Returns:
        SQL text. Each name part is embedded as a single-quoted literal, with
        backslashes and single quotes escaped, so a name containing ``'``
        cannot terminate the literal early.

    Raises:
        KeyError: unknown ``kind``.
        ValueError: malformed ``fq_table`` (from ``parse_fq_table``).
    """
    def _quote(value: str) -> str:
        # Spark SQL escapes inside string literals with backslashes. A doubled ''
        # would instead be read as two adjacent literals and concatenated.
        return "'" + value.replace("\\", "\\\\").replace("'", "\\'") + "'"

    config = SNAPSHOT_QUERIES[kind]
    name_parts = parse_fq_table(fq_table)  # (catalog, schema, table)
    where_sql = " AND ".join(
        f"{col_name} = {_quote(name_parts[idx])}"
        for col_name, idx in config["where_keys"]
    )
    columns = ", ".join(config["columns"])
    return f"SELECT {columns} FROM {config['table']} WHERE {where_sql}"

def get_metadata_snapshot(spark: SparkSession, fq_table: str) -> Dict[str, List[Dict[str, Any]]]:
    """Run every ``SNAPSHOT_QUERIES`` query for ``fq_table``.

    Returns a dict keyed like ``SNAPSHOT_QUERIES``. Each value is either the
    list of result rows (as dicts) or, on a best-effort basis, the string
    ``"[ERROR] <exception>"`` when building or running that query raised. One
    unreadable system table therefore does not abort the whole snapshot.
    NOTE: despite the annotation, callers must handle the string form.
    """
    snapshot: Dict[str, Any] = {}
    for kind in SNAPSHOT_QUERIES:
        try:
            collected = spark.sql(build_metadata_sql(kind, fq_table)).collect()
            snapshot[kind] = [record.asDict() for record in collected]
        except Exception as exc:
            snapshot[kind] = f"[ERROR] {exc}"
    return snapshot

# --- DESCRIBE TABLE EXTENDED PARSERS ---
def get_describe_rows(spark: SparkSession, fq_table: str) -> List[Dict[str, Any]]:
    """Return the ``DESCRIBE EXTENDED <fq_table>`` result as a list of row dicts.

    The parsers in this cell read the ``col_name``, ``data_type`` and
    ``comment`` keys of these dicts.
    """
    collected = spark.sql(f"DESCRIBE EXTENDED {fq_table}").collect()
    return list(map(lambda record: record.asDict(), collected))

def extract_columns(describe_rows: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Extract the column list from ``DESCRIBE EXTENDED`` rows.

    Assumes the column rows come first in the output. The column section ends
    at the first blank row or ``#`` section header (``# Partition Information``,
    ``# Detailed Table Information``, ...). Stopping there, instead of skipping
    headers and continuing, keeps key/value rows of later sections
    (``Catalog``, ``Owner``, ...) from being reported as columns on
    unpartitioned tables.

    Returns:
        ``[{"name", "datatype", "comment"}, ...]`` in table order. A missing,
        blank or ``NULL`` comment becomes ``""``.
    """
    columns = []
    for row in describe_rows:
        col_name = (row.get("col_name") or "").strip()
        if not col_name or col_name.startswith("#"):
            break
        data_type = (row.get("data_type") or "").strip()
        comment = (row.get("comment") or "").strip()
        columns.append({
            "name": col_name,
            "datatype": data_type,
            "comment": comment if comment.upper() != "NULL" else "",
        })
    return columns

def extract_partitioned_by(describe_rows: List[Dict[str, Any]]) -> List[str]:
    """Return partition column names from ``DESCRIBE EXTENDED`` rows.

    Reads the rows after ``# Partition Information`` and skips that section's
    ``# col_name`` sub-header. This check must run before the generic ``#``
    terminator, otherwise nothing is ever collected. Collection stops at the
    next blank row or ``#`` header. Returns ``[]`` when the section is absent.
    """
    partition_cols: List[str] = []
    collecting = False
    for row in describe_rows:
        col_name = (row.get("col_name") or "").strip()
        if not collecting:
            collecting = col_name == "# Partition Information"
            continue
        if col_name == "# col_name":
            # Column-header row of the partition listing, not a partition column.
            continue
        if not col_name or col_name.startswith("#"):
            break
        partition_cols.append(col_name)
    return partition_cols

def _parse_table_properties(text: str) -> Dict[str, str]:
    """Parse DESCRIBE's ``[k1=v1,k2=v2]`` table-properties text into a dict."""
    body = text[1:-1] if text.startswith("[") and text.endswith("]") else text
    props: Dict[str, str] = {}
    # Split only on commas that begin a new `key=` pair, so commas inside a
    # value (e.g. a CHECK expression "id IN (1, 2)") stay part of that value.
    for item in re.split(r",(?=\s*[A-Za-z0-9_.\-]+=)", body):
        if "=" in item:
            key, value = item.split("=", 1)
            props[key.strip()] = value.strip()
    return props


def extract_table_details(describe_rows: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Extract owner, comment and table properties from ``DESCRIBE EXTENDED`` rows.

    Reads the ``# Detailed Table Information`` section until the next blank
    row or ``#`` header.

    Returns:
        A dict that always has ``table_properties`` (possibly empty), plus
        ``owner`` / ``comment`` when those rows are present.
    """
    details: Dict[str, Any] = {}
    table_properties: Dict[str, str] = {}
    in_details = False
    for row in describe_rows:
        col_name = (row.get("col_name") or "").strip()
        data_type = (row.get("data_type") or "").strip()
        if col_name == "# Detailed Table Information":
            in_details = True
            continue
        if not in_details:
            continue
        if not col_name or col_name.startswith("#"):
            break
        if col_name == "Owner":
            details["owner"] = data_type
        elif col_name == "Comment":
            details["comment"] = data_type
        elif col_name == "Table Properties":
            table_properties.update(_parse_table_properties(data_type))
    details["table_properties"] = table_properties
    return details

def extract_constraints(describe_rows: List[Dict[str, Any]]) -> List[Dict[str, str]]:
    """Collect ``{"name", "type"}`` pairs from the ``# Constraints`` section.

    ``type`` is the raw ``data_type`` text of each row (for a primary key it
    contains ``PRIMARY KEY (...)``). Rows with an empty ``data_type`` are
    skipped. The section ends at the first blank row or ``#`` header.
    Returns ``[]`` when the section is absent.
    """
    found: List[Dict[str, str]] = []
    section_started = False
    for record in describe_rows:
        key = (record.get("col_name") or "").strip()
        value = (record.get("data_type") or "").strip()
        if key == "# Constraints":
            section_started = True
        elif section_started:
            if key == "" or key.startswith("#"):
                break
            if value:
                found.append({"name": key, "type": value})
    return found

# ------------------------------------------------------------------
# Example usage, kept as an inert string literal (not executed).
"""
# --- Example usage, all prints at the bottom: ---
table_name = "dq_dev.lmg_sandbox.config_driven_table_example"
spark = SparkSession.getActiveSession() or SparkSession.builder.getOrCreate()

# Fetch Unity Catalog system metadata (raw output)
uc_metadata = get_metadata_snapshot(spark, table_name)

# Fetch DESCRIBE EXTENDED rows (raw output)
describe_rows = get_describe_rows(spark, table_name)

# --- Now print results ---
for kind, rows in uc_metadata.items():
    print(f"\n--- {kind.upper()} ---")
    if isinstance(rows, str) and rows.startswith("[ERROR]"):
        print(rows)
    elif not rows:
        print("No rows found.")
    else:
        for row in rows:
            print(row)

print("\n--- COLUMNS ---")
columns = extract_columns(describe_rows)
print(columns if columns else "No columns found.")

print("\n--- PARTITIONED BY ---")
partitioned_by = extract_partitioned_by(describe_rows)
print(partitioned_by if partitioned_by else "No partitions found.")

print("\n--- TABLE DETAILS ---")
table_details = extract_table_details(describe_rows)
print(table_details if table_details else "No table details found.")

print("\n--- CONSTRAINTS ---")
constraints = extract_constraints(describe_rows)
print(constraints if constraints else "No constraints found.")
"""

############################################################
#                   End of First Section
############################################################


# -------------------------------------------------------- #


############################################################
#                   Start of Second Section
############################################################
def build_table_metadata_snapshot(
    fq_table: str,
    uc_metadata: Dict[str, List[Dict[str, Any]]],
    describe_rows: List[Dict[str, Any]]
) -> Dict[str, Any]:
    """Assemble a nested snapshot of an existing table's metadata.

    Args:
        fq_table: ``catalog.schema.table``.
        uc_metadata: output of ``get_metadata_snapshot``. Entries that hold an
            ``"[ERROR] ..."`` string instead of a row list are treated as empty.
        describe_rows: output of ``get_describe_rows``.

    Returns:
        ``{"table": {...}}`` containing:
        - owner, comment and properties from DESCRIBE;
        - table/column tags and row filters from information_schema;
        - CHECK constraints recovered from ``delta.constraints.<name>`` table properties;
        - the primary key parsed from the ``# Constraints`` section;
        - ``columns`` keyed by 1-based position (int).
    """
    def _rows(kind: str) -> List[Dict[str, Any]]:
        # Failed queries are stored as "[ERROR] ..." strings; iterating one would
        # yield characters and crash on row["..."], so treat it as no rows.
        value = uc_metadata.get(kind, [])
        return value if isinstance(value, list) else []

    catalog, schema, table = parse_fq_table(fq_table)

    # Table tags
    table_tags = {row["tag_name"]: row["tag_value"] for row in _rows("table_tags")}

    # Table properties, owner, comment
    details = extract_table_details(describe_rows)
    table_properties = details.get("table_properties", {})

    # Delta keeps CHECK constraints as table properties "delta.constraints.<name>" = <expression>;
    # report the bare constraint name, matching the YAML {name, expression} shape.
    constraint_prefix = "delta.constraints."
    table_check_constraints = {}
    for key, expression in table_properties.items():
        if key.startswith(constraint_prefix):
            cname = key[len(constraint_prefix):]
            table_check_constraints[cname] = {"name": cname, "expression": expression}

    # Row filters
    row_filters = [
        {"filter_name": row["filter_name"], "target_columns": row["target_columns"]}
        for row in _rows("row_filters")
    ]

    # Partition columns
    partitioned_by = extract_partitioned_by(describe_rows)

    # Primary key: column list inside "PRIMARY KEY (...)" of the constraints section
    pk: List[str] = []
    for c in extract_constraints(describe_rows):
        if "PRIMARY KEY" in c["type"]:
            m = re.search(r"\((.*?)\)", c["type"])
            if m:
                pk = [col.strip().replace("`", "") for col in m.group(1).split(",")]

    # Column tags, merged by column name
    col_tag_lookup: Dict[str, Dict[str, Any]] = {}
    for row in _rows("column_tags"):
        col_tag_lookup.setdefault(row["column_name"], {})[row["tag_name"]] = row["tag_value"]

    # Constraints per column from constraint_column_usage. NOTE(review): this view
    # presumably also lists PK/FK constraints, not only CHECKs; no expression is available here.
    col_constraint_lookup: Dict[str, Dict[str, Any]] = {}
    for row in _rows("constraint_column_usage"):
        cons = row["constraint_name"]
        col_constraint_lookup.setdefault(row["column_name"], {})[cons] = {"name": cons}

    # Columns dictionary by position (1-based)
    columns = {}
    for idx, col in enumerate(extract_columns(describe_rows), start=1):
        colname = col["name"]
        columns[idx] = {
            "column_name": colname,
            "datatype": col["datatype"],
            "comment": col["comment"],
            "nullable": None,  # not available from DESCRIBE parsing yet
            "masking_rule": None,  # not available from DESCRIBE parsing yet
            "column_tags": col_tag_lookup.get(colname, {}),
            "column_check_constraints": col_constraint_lookup.get(colname, {}),
        }

    return {
        "table": {
            "fully_qualified_name": fq_table,
            "catalog": catalog,
            "schema": schema,
            "table": table,
            "owner": details.get("owner", ""),
            "comment": details.get("comment", ""),
            "table_properties": table_properties,
            "table_tags": table_tags,
            "table_check_constraints": table_check_constraints,
            "row_filters": row_filters,
            "partitioned_by": partitioned_by,
            "primary_key": pk,
            "columns": columns,
        }
    }

# --- Usage Example ---
import pprint

table_name = "dq_dev.lmg_sandbox.config_driven_table_example"
spark = SparkSession.getActiveSession() or SparkSession.builder.getOrCreate()
uc_metadata = get_metadata_snapshot(spark, table_name)
describe_rows = get_describe_rows(spark, table_name)

snapshot = build_table_metadata_snapshot(table_name, uc_metadata, describe_rows)
pprint.pprint(snapshot, width=120)

############################################################
#                   End of Second Section
############################################################

In [0]:
import re
from typing import Any, Dict, List, Optional
from pyspark.sql import SparkSession

class TableSnapshot:
    """
    Capture the current metadata of an existing Unity Catalog table as a plain dict.

    Metadata is combined from two sources:

    * the ``system.information_schema`` views listed in ``SNAPSHOT_QUERIES``
      (table/column tags, row filters, constraint usage), queried best-effort;
    * the rows of ``DESCRIBE EXTENDED <table>`` (columns, partition columns,
      owner, comment, table properties, constraints).

    ``build_table_metadata_dict`` returns the same top-level and per-column keys
    as ``YamlSnapshot.build_table_metadata_dict`` so the two can be diffed.
    """

    # kind -> information_schema view, projected columns, and filter columns.
    # Each where_keys entry is (filter column, index into [catalog, schema, table]).
    SNAPSHOT_QUERIES = {
        "table_tags": {
            "table": "system.information_schema.table_tags",
            "columns": ["catalog_name", "schema_name", "table_name", "tag_name", "tag_value"],
            "where_keys": [("catalog_name", 0), ("schema_name", 1), ("table_name", 2)],
        },
        "column_tags": {
            "table": "system.information_schema.column_tags",
            "columns": ["catalog_name", "schema_name", "table_name", "column_name", "tag_name", "tag_value"],
            "where_keys": [("catalog_name", 0), ("schema_name", 1), ("table_name", 2)],
        },
        "row_filters": {
            "table": "system.information_schema.row_filters",
            "columns": ["table_catalog", "table_schema", "table_name", "filter_name", "target_columns"],
            "where_keys": [("table_catalog", 0), ("table_schema", 1), ("table_name", 2)],
        },
        "constraint_table_usage": {
            "table": "system.information_schema.constraint_table_usage",
            "columns": ["constraint_catalog", "constraint_schema", "constraint_name"],
            "where_keys": [("table_catalog", 0), ("table_schema", 1), ("table_name", 2)],
        },
        "constraint_column_usage": {
            "table": "system.information_schema.constraint_column_usage",
            "columns": ["column_name", "constraint_name"],
            "where_keys": [("table_catalog", 0), ("table_schema", 1), ("table_name", 2)],
        },
    }

    # Delta stores each CHECK constraint as the table property "delta.constraints.<name>".
    CHECK_CONSTRAINT_PREFIX = "delta.constraints."

    # In DESCRIBE's "Table Properties" value, a comma starts a new pair only when it is
    # followed by "<key>=". Commas inside values (e.g. a CHECK "IN ('a', 'b')") are kept.
    _PROPERTY_SPLIT_RE = re.compile(r",(?=\s*[\w.\-]+=)")

    def __init__(self, spark: "SparkSession", fq_table: str):
        self.spark = spark
        self.fq_table = fq_table
        self.catalog, self.schema, self.table = self._parse_fq_table(fq_table)

    def _parse_fq_table(self, fq_table: str):
        """Split ``catalog.schema.table`` into its three parts.

        Raises:
            ValueError: if there are not exactly three non-empty dot-separated parts.
        """
        parts = fq_table.split(".")
        # "a..c" has three parts but can never name a real table, so reject it too.
        if len(parts) != 3 or not all(part.strip() for part in parts):
            raise ValueError("Expected format: catalog.schema.table")
        return parts[0], parts[1], parts[2]

    @staticmethod
    def _sql_string_literal(value: str) -> str:
        """Quote ``value`` as a Spark SQL string literal, backslash-escaping backslashes and quotes."""
        escaped = str(value).replace("\\", "\\\\").replace("'", "\\'")
        return f"'{escaped}'"

    def _build_metadata_sql(self, kind: str) -> str:
        """Build the SELECT for one ``SNAPSHOT_QUERIES`` entry, filtered to this table.

        The name parts are embedded as escaped string literals so a quote in a
        name cannot break (or inject into) the statement.
        """
        config = self.SNAPSHOT_QUERIES[kind]
        table_vars = [self.catalog, self.schema, self.table]
        where_clauses = [
            f"{col_name} = {self._sql_string_literal(table_vars[idx])}"
            for col_name, idx in config["where_keys"]
        ]
        columns = ", ".join(config["columns"])
        return f"SELECT {columns} FROM {config['table']} WHERE {' AND '.join(where_clauses)}"

    def _get_metadata_snapshot(self) -> Dict[str, List[Dict[str, Any]]]:
        """Run every ``SNAPSHOT_QUERIES`` query and return ``{kind: [row dicts]}``.

        Best effort: a failing query yields ``[]`` for that kind rather than aborting.
        The failure reason is kept in ``self.metadata_errors[kind]``.
        """
        results: Dict[str, List[Dict[str, Any]]] = {}
        self.metadata_errors: Dict[str, str] = {}
        for kind in self.SNAPSHOT_QUERIES:
            try:
                rows = self.spark.sql(self._build_metadata_sql(kind)).collect()
                results[kind] = [row.asDict() for row in rows]
            except Exception as exc:  # deliberate best-effort; reason recorded below
                results[kind] = []
                self.metadata_errors[kind] = f"{type(exc).__name__}: {exc}"
        return results

    def _get_describe_rows(self) -> List[Dict[str, Any]]:
        """Return ``DESCRIBE EXTENDED`` output for the table as a list of row dicts."""
        sql = f"DESCRIBE EXTENDED {self.fq_table}"
        df = self.spark.sql(sql)
        return [row.asDict() for row in df.collect()]

    def _extract_columns(self, describe_rows: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Return ``[{name, datatype, comment}]`` for the leading column section of DESCRIBE.

        The column list ends at the first blank row or ``#`` section header
        ("# Partition Information", "# Detailed Table Information", ...). Stopping
        there keeps rows such as "Catalog"/"Owner" from being read as columns on
        non-partitioned tables.
        """
        columns = []
        for row in describe_rows:
            col_name = (row.get("col_name") or "").strip()
            if col_name == "# col_name":
                continue  # header row, not a column
            if not col_name or col_name.startswith("#"):
                break
            comment = (row.get("comment") or "").strip()
            columns.append({
                "name": col_name,
                "datatype": (row.get("data_type") or "").strip(),
                # DESCRIBE can render a missing comment as the text "NULL".
                "comment": "" if comment.upper() == "NULL" else comment,
            })
        return columns

    def _extract_partitioned_by(self, describe_rows: List[Dict[str, Any]]) -> List[str]:
        """Return the column names listed under "# Partition Information" (``[]`` if none)."""
        collecting = False
        partition_cols = []
        for row in describe_rows:
            col_name = (row.get("col_name") or "").strip()
            if col_name == "# Partition Information":
                collecting = True
                continue
            if collecting:
                if col_name.startswith("#") and col_name != "# col_name":
                    break  # next section begins
                if col_name == "" or col_name == "# col_name":
                    continue  # skip header/blank rows
                partition_cols.append(col_name)
        return partition_cols

    def _split_table_properties(self, raw: str) -> List[str]:
        """Split the bracket-stripped "Table Properties" value into ``key=value`` strings."""
        return self._PROPERTY_SPLIT_RE.split(raw) if raw else []

    def _extract_table_details(self, describe_rows: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Read owner, comment and table properties from "# Detailed Table Information".

        Returns a dict that always has ``table_properties`` (possibly empty). It has
        ``owner``/``comment`` only when DESCRIBE printed those rows.
        """
        details: Dict[str, Any] = {}
        table_properties: Dict[str, str] = {}
        in_details = False
        for row in describe_rows:
            col_name = (row.get("col_name") or "").strip()
            data_type = (row.get("data_type") or "").strip()
            if col_name == "# Detailed Table Information":
                in_details = True
                continue
            if not in_details:
                continue
            if not col_name or col_name.startswith("#"):
                break
            if col_name == "Owner":
                details["owner"] = data_type
            elif col_name == "Comment":
                details["comment"] = data_type
            elif col_name == "Table Properties":
                for prop in self._split_table_properties(data_type.strip("[]")):
                    if "=" in prop:
                        key, value = prop.split("=", 1)
                        table_properties[key.strip()] = value.strip()
        details["table_properties"] = table_properties
        return details

    def _extract_constraints(self, describe_rows: List[Dict[str, Any]]) -> List[Dict[str, str]]:
        """Return ``[{name, type}]`` rows from the "# Constraints" section of DESCRIBE."""
        constraints = []
        in_constraints = False
        for row in describe_rows:
            col_name = (row.get("col_name") or "").strip()
            data_type = (row.get("data_type") or "").strip()
            if col_name == "# Constraints":
                in_constraints = True
                continue
            if in_constraints:
                if not col_name or col_name.startswith("#"):
                    break
                if col_name and data_type:
                    constraints.append({"name": col_name, "type": data_type})
        return constraints

    def _extract_primary_key(self, constraints: List[Dict[str, str]]) -> List[str]:
        """Return the column list of the PRIMARY KEY constraint (last one wins), else ``[]``."""
        pk: List[str] = []
        for constraint in constraints:
            if "PRIMARY KEY" in constraint["type"]:
                match = re.search(r"\((.*?)\)", constraint["type"])
                if match:
                    pk = [col.strip().replace("`", "") for col in match.group(1).split(",")]
        return pk

    def _extract_table_check_constraints(self, table_properties: Dict[str, Any]) -> Dict[str, Dict[str, Any]]:
        """Map ``delta.constraints.<name>`` properties to ``{name: {"name", "expression"}}``.

        The property prefix is stripped so keys are the bare constraint names. Those
        are what the YAML side declares and what DROP CONSTRAINT needs.
        """
        prefix = self.CHECK_CONSTRAINT_PREFIX
        checks: Dict[str, Dict[str, Any]] = {}
        for key, value in table_properties.items():
            if key.startswith(prefix) and len(key) > len(prefix):
                name = key[len(prefix):]
                checks[name] = {"name": name, "expression": value}
        return checks

    def _build_columns(self, columns_raw: List[Dict[str, Any]], col_tags: Dict[str, Dict[str, Any]], col_checks: Dict[str, Dict[str, Any]]) -> Dict[int, Dict[str, Any]]:
        """Number the columns from 1 and attach their tags / constraint usage."""
        columns = {}
        for idx, col in enumerate(columns_raw, start=1):
            name = col["name"]
            columns[idx] = {
                "name": name,
                "datatype": col["datatype"],
                "nullable": None,  # not extracted from DESCRIBE yet
                "active": True,
                "comment": col.get("comment", ""),
                "tags": col_tags.get(name, {}),
                "column_masking_rule": "",  # masking rules are not captured by this snapshot
                "column_check_constraints": col_checks.get(name, {}),
            }
        return columns

    def _get_foreign_keys(self, uc_metadata: Dict[str, List[Dict[str, Any]]]) -> Dict[str, Any]:
        """Placeholder: foreign keys are not extracted yet, always ``{}``."""
        return {}

    def _get_unique_keys(self, uc_metadata: Dict[str, List[Dict[str, Any]]]) -> List[List[str]]:
        """Placeholder: unique keys are not extracted yet, always ``[]``."""
        return []

    def build_table_metadata_dict(self) -> Dict[str, Any]:
        """Query Spark and return the table's metadata snapshot dict."""
        uc_metadata = self._get_metadata_snapshot()
        describe_rows = self._get_describe_rows()

        details = self._extract_table_details(describe_rows)
        table_properties = details.get("table_properties", {})

        table_tags = {row["tag_name"]: row["tag_value"] for row in uc_metadata.get("table_tags", [])}

        row_filters = {}
        for row in uc_metadata.get("row_filters", []):
            fname = row.get("filter_name")
            if fname:
                # NOTE(review): target_columns is stored as the "expression" — confirm intent.
                row_filters[fname] = {"name": fname, "expression": row.get("target_columns", "")}

        col_tag_lookup: Dict[str, Dict[str, Any]] = {}
        for row in uc_metadata.get("column_tags", []):
            col_tag_lookup.setdefault(row["column_name"], {})[row["tag_name"]] = row["tag_value"]

        # constraint_column_usage lists every constraint touching a column; only names are kept.
        col_constraint_lookup: Dict[str, Dict[str, Any]] = {}
        for row in uc_metadata.get("constraint_column_usage", []):
            cons = row["constraint_name"]
            col_constraint_lookup.setdefault(row["column_name"], {})[cons] = {"name": cons}

        columns = self._build_columns(self._extract_columns(describe_rows), col_tag_lookup, col_constraint_lookup)

        return {
            "full_table_name": self.fq_table,
            "catalog": self.catalog,
            "schema": self.schema,
            "table": self.table,
            "primary_key": self._extract_primary_key(self._extract_constraints(describe_rows)),
            "foreign_keys": self._get_foreign_keys(uc_metadata),
            "unique_keys": self._get_unique_keys(uc_metadata),
            "partitioned_by": self._extract_partitioned_by(describe_rows),
            "tags": table_tags,
            "row_filters": row_filters,
            "table_check_constraints": self._extract_table_check_constraints(table_properties),
            "table_properties": table_properties,
            "comment": details.get("comment", ""),
            "owner": details.get("owner", ""),
            "columns": columns,
        }

In [0]:
# Usage example:
spark = SparkSession.getActiveSession() or SparkSession.builder.getOrCreate()
metadata_dict = TableSnapshot(spark, "dq_dev.lmg_sandbox.config_driven_table_example")
table_snapshot = metadata_dict.build_table_metadata_dict()
print(table_snapshot)

In [0]:
import pprint  # explicit, so this cell does not depend on an earlier cell's import

# Keep the snapshot's key order instead of pprint's default alphabetical sort.
pprint.pprint(table_snapshot, width=120, sort_dicts=False)

#### Yaml

In [0]:
import yaml
from typing import Any, Dict, List, Optional

class YamlSnapshot:
    """
    Loader for YAML DDL config files, exposing each config block as a property.

    Handles an environment suffix on the catalog (a catalog ending in "_" gets the
    env appended). Scalar-or-list keys (primary_key, partitioned_by) and explicit
    YAML nulls are normalized to empty defaults.
    """

    def __init__(self, config_path: str, env: Optional[str] = None):
        self.config_path = config_path
        self._env = env
        self._config: Dict[str, Any] = {}
        self.load_config()

    def load_config(self) -> None:
        """Parse ``config_path`` with ``yaml.safe_load`` into ``self._config``.

        Raises:
            ValueError: if the file is missing, is not valid YAML, or its top level
                is not a mapping (an empty file parses to None and is rejected too).
        """
        try:
            with open(self.config_path, "r", encoding="utf-8") as f:
                loaded = yaml.safe_load(f)
        except (FileNotFoundError, yaml.YAMLError) as e:
            raise ValueError(f"Error loading YAML configuration from {self.config_path}: {e}") from e
        if not isinstance(loaded, dict):
            raise ValueError(
                f"Error loading YAML configuration from {self.config_path}: "
                f"expected a mapping at the top level, got {type(loaded).__name__}"
            )
        self._config = loaded

    def _get(self, key: str, default: Any) -> Any:
        """Return ``config[key]``, treating a missing key or an explicit YAML null as ``default``."""
        value = self._config.get(key)
        return default if value is None else value

    @staticmethod
    def _as_list(value: Any) -> List[Any]:
        """Normalize a scalar-or-list config value to a list; null or "" becomes ``[]``."""
        if value is None or value == "":
            return []
        return value if isinstance(value, list) else [value]

    @property
    def catalog(self) -> str:
        return self._get("catalog", "")

    @property
    def schema(self) -> str:
        return self._get("schema", "")

    @property
    def table(self) -> str:
        return self._get("table", "")

    @property
    def env(self) -> Optional[str]:
        return self._env

    @property
    def full_table_name(self) -> str:
        """``catalog.schema.table``; a catalog ending in "_" gets the env appended."""
        cat = self.catalog.strip()
        sch = self.schema.strip()
        tbl = self.table.strip()
        env = self.env
        cat_full = f"{cat}{env}" if cat.endswith("_") and env else cat
        return f"{cat_full}.{sch}.{tbl}"

    @property
    def owner(self) -> str:
        return self._get("owner", "")

    @property
    def tags(self) -> Dict[str, Any]:
        return self._get("tags", {})

    @property
    def properties(self) -> Dict[str, Any]:
        return self._get("properties", {})

    @property
    def table_comment(self) -> str:
        comment = self.properties.get("comment")
        return "" if comment is None else comment

    @property
    def table_properties(self) -> Dict[str, Any]:
        return self.properties.get("table_properties") or {}

    @property
    def primary_key(self) -> List[str]:
        return self._as_list(self._config.get("primary_key"))

    @property
    def partitioned_by(self) -> List[str]:
        return self._as_list(self._config.get("partitioned_by"))

    @property
    def unique_keys(self) -> List[List[str]]:
        return self._get("unique_keys", [])

    @property
    def foreign_keys(self) -> Dict[str, Any]:
        return self._get("foreign_keys", {})

    @property
    def table_check_constraints(self) -> Dict[str, Any]:
        return self._get("table_check_constraints", {})

    @property
    def row_filters(self) -> Dict[str, Any]:
        return self._get("row_filters", {})

    @property
    def columns(self) -> List[Dict[str, Any]]:
        """Column specs ordered by their numeric key (keys may load as ints or strings)."""
        cols_by_key = {str(k): v for k, v in self._get("columns", {}).items()}
        # Sort by integer value but look up by the original string, so keys such as "01" work.
        return [cols_by_key[key] for key in sorted(cols_by_key, key=int)]

    def build_table_metadata_dict(self) -> Dict[str, Any]:
        """Return the YAML snapshot dict (key order matters: relies on dict insertion order)."""
        result: Dict[str, Any] = {
            "full_table_name": self.full_table_name,
            "catalog": self.catalog,
            "schema": self.schema,
            "table": self.table,
            "primary_key": self.primary_key or [],
            "foreign_keys": self.foreign_keys or {},
            "unique_keys": self.unique_keys or [],
            "partitioned_by": self.partitioned_by or [],
            "tags": self.tags or {},
            "row_filters": self.row_filters or {},
            "table_check_constraints": self.table_check_constraints or {},
            "table_properties": self.table_properties or {},
            "comment": self.table_comment,
            "owner": self.owner,
            "columns": {},
        }

        # Columns are renumbered 1..N in key order; null optional fields become ""/{}.
        for idx, col in enumerate(self.columns, 1):
            result["columns"][idx] = {
                "name": col.get("name", ""),
                "datatype": col.get("datatype", ""),
                "nullable": col.get("nullable", True),
                "active": col.get("active", True),
                "comment": col.get("comment") or "",
                "tags": col.get("tags") or {},
                "column_masking_rule": col.get("column_masking_rule") or "",
                "column_check_constraints": col.get("column_check_constraints") or {},
            }

        return result

    def describe(self) -> None:
        """Print a human-readable summary of the config (dev/test helper)."""
        print(f"Table: {self.full_table_name}")
        print(f"  Owner: {self.owner}")
        print(f"  Tags: {self.tags}")
        print(f"  Primary Key: {self.primary_key}")
        print(f"  Partitioned By: {self.partitioned_by}")
        print(f"  Unique Keys: {self.unique_keys}")
        print(f"  Foreign Keys: {self.foreign_keys}")
        print(f"  Table Check Constraints: {self.table_check_constraints}")
        print(f"  Row Filters: {self.row_filters}")
        print(f"  Table Properties: {self.table_properties}")
        print(f"  Columns:")
        for i, col in enumerate(self.columns, 1):
            print(
                f"    {i}: {col.get('name','')} ({col.get('datatype','')}, nullable={col.get('nullable', True)}) | "
                f"comment={col.get('comment','')}, tags={col.get('tags',{})}, active={col.get('active', True)}"
            )
            ccc = col.get("column_check_constraints", {})
            if ccc:
                print(f"      Column Check Constraints: {ccc}")

In [0]:

# Snapshot the example YAML config with env="dev" (appended to a catalog ending in "_").
yaml_path = "layker/resources/example.yaml"

cfg = YamlSnapshot(yaml_path, env="dev")
yaml_snapshot = cfg.build_table_metadata_dict()
print(yaml_snapshot)

In [0]:
import pprint

# YamlSnapshot builds its dict in a deliberate key order; don't let pprint re-sort it.
pprint.pprint(yaml_snapshot, width=120, sort_dicts=False)

### Differences

In [0]:
from typing import Dict, Any, Optional

def diff_primary_key(yaml: Dict[str, Any], table: Optional[Dict[str, Any]], add, update):
    """Record a primary-key change: under ``add`` if the table has none, else under ``update``.

    Nothing is recorded when the YAML declares no primary key or it already matches.
    """
    desired = yaml.get("primary_key", [])
    current = (table or {}).get("primary_key", [])
    if not desired or desired == current:
        return
    target = update if current else add
    target["primary_key"] = desired

def diff_partitioned_by(yaml: Dict[str, Any], table: Optional[Dict[str, Any]], add):
    """Queue YAML partition columns under ``add`` only when the table has no partitioning.

    Existing partitioning is never compared or altered.
    """
    wanted = yaml.get("partitioned_by", [])
    existing = (table or {}).get("partitioned_by", [])
    if not wanted or existing:
        return
    add["partitioned_by"] = wanted

def diff_unique_keys(yaml: Dict[str, Any], table: Optional[Dict[str, Any]], add):
    """Queue the full YAML unique-key list under ``add`` whenever it differs from the table's."""
    wanted_keys = yaml.get("unique_keys", [])
    existing_keys = (table or {}).get("unique_keys", [])
    if not wanted_keys or wanted_keys == existing_keys:
        return
    add["unique_keys"] = wanted_keys

def diff_foreign_keys(yaml: Dict[str, Any], table: Optional[Dict[str, Any]], add):
    """Queue the full YAML foreign-key mapping under ``add`` whenever it differs from the table's."""
    wanted_fks = yaml.get("foreign_keys", {})
    existing_fks = (table or {}).get("foreign_keys", {})
    if not wanted_fks or wanted_fks == existing_fks:
        return
    add["foreign_keys"] = wanted_fks

def diff_table_check_constraints(yaml: Dict[str, Any], table: Optional[Dict[str, Any]], add, update, remove):
    """Diff table CHECK constraints by name.

    New names go to ``add``, names whose spec changed go to ``update``, and names
    present only on the table go to ``remove`` (each under "table_check_constraints").
    """
    desired = yaml.get("table_check_constraints", {})
    current = (table or {}).get("table_check_constraints", {})

    added = {name: spec for name, spec in desired.items() if name not in current}
    changed = {name: spec for name, spec in desired.items() if name in current and current[name] != spec}
    stale = {name: spec for name, spec in current.items() if name not in desired}

    if added:
        add["table_check_constraints"].update(added)
    if changed:
        update["table_check_constraints"].update(changed)
    if stale:
        remove["table_check_constraints"].update(stale)

def diff_row_filters(yaml: Dict[str, Any], table: Optional[Dict[str, Any]], add, update, remove):
    """Diff row filters by name into ``add`` / ``update`` / ``remove`` (key "row_filters")."""
    wanted = yaml.get("row_filters", {})
    current = (table or {}).get("row_filters", {})

    for fname, spec in wanted.items():
        if fname not in current:
            add["row_filters"][fname] = spec
            continue
        if current[fname] != spec:
            update["row_filters"][fname] = spec

    # Filters that exist on the table but were dropped from the YAML.
    for fname in [f for f in current if f not in wanted]:
        remove["row_filters"][fname] = current[fname]

def diff_table_tags(yaml: Dict[str, Any], table: Optional[Dict[str, Any]], add, remove):
    """Diff table tags by key: YAML-only keys go to ``add``, table-only keys to ``remove``.

    Values of keys present on both sides are not compared.
    """
    wanted = yaml.get("tags", {})
    present = (table or {}).get("tags", {})

    new_tags = {key: val for key, val in wanted.items() if key not in present}
    dropped = {key: val for key, val in present.items() if key not in wanted}

    if new_tags:
        add["table_tags"].update(new_tags)
    if dropped:
        remove["table_tags"].update(dropped)

def diff_owner(yaml: Dict[str, Any], table: Optional[Dict[str, Any]], update):
    """Put the YAML owner under ``update`` when it is set and differs from the table's owner."""
    desired_owner = yaml.get("owner", "")
    current_owner = (table or {}).get("owner", "")
    if not desired_owner or desired_owner == current_owner:
        return
    update["owner"] = desired_owner

def diff_table_comment(yaml: Dict[str, Any], table: Optional[Dict[str, Any]], update):
    """Put the YAML table comment (whitespace-stripped) under ``update['table_comment']`` when it differs.

    None and missing comments are treated as "", and an empty YAML comment never
    clears an existing one.
    """
    desired = (yaml.get("comment", "") or "").strip()
    current = ((table or {}).get("comment", "") or "").strip()
    if not desired or desired == current:
        return
    update["table_comment"] = desired

def diff_table_properties(yaml: Dict[str, Any], table: Optional[Dict[str, Any]], add):
    """Queue YAML table properties whose key is absent on the table under ``add``.

    Properties already present are never compared or removed.
    """
    wanted = yaml.get("table_properties", {})
    existing = (table or {}).get("table_properties", {})
    missing = {key: val for key, val in wanted.items() if key not in existing}
    if missing:
        add["table_properties"].update(missing)

def diff_columns(yaml: Dict[str, Any], table: Optional[Dict[str, Any]], add, update, remove):
    """
    Diff the position-keyed column maps of the YAML and table snapshots.

    * YAML positions beyond the table's highest position -> ``add["columns"][idx]``.
    * Table positions missing from the YAML -> ``remove["columns"][idx]``.
    * Shared positions -> ``update["columns"][idx]`` for name, comment and masking rule
      changes plus newly added tags / check constraints. Tags and check constraints
      that disappeared go to ``remove["columns"][idx]``.

    None and missing comments / masking rules are treated as "" and comments are
    compared whitespace-stripped. This avoids spurious updates when one side has
    None and the other "". Positions are visited in ascending order so the
    output dicts are ordered deterministically.
    """
    y_cols = yaml.get("columns", {}) or {}
    t_cols = (table.get("columns", {}) or {}) if table else {}
    y_idxs, t_idxs = set(y_cols), set(t_cols)

    # New columns can only be appended after the table's current last column.
    max_t_idx = max(t_idxs) if t_idxs else 0
    for idx in sorted(y_idxs):
        if idx <= max_t_idx:
            continue
        y_col = y_cols[idx]
        add["columns"][idx] = {
            "name": y_col.get("name", ""),
            "datatype": y_col.get("datatype", ""),
            "nullable": y_col.get("nullable", True),
            "column_comment": (y_col.get("comment") or "").strip(),
            "column_tags": y_col.get("tags", {}) or {},
            "column_masking_rule": y_col.get("column_masking_rule") or "",
            "column_check_constraints": y_col.get("column_check_constraints", {}) or {},
        }

    # Columns present on the table but missing from the YAML.
    for idx in sorted(t_idxs - y_idxs):
        t_col = t_cols[idx]
        remove["columns"][idx] = {
            "name": t_col.get("name", ""),
            "column_tags": t_col.get("tags", {}),
            "column_check_constraints": t_col.get("column_check_constraints", {}),
        }

    # Columns at the same position on both sides.
    for idx in sorted(y_idxs & t_idxs):
        y_col, t_col = y_cols[idx], t_cols[idx]
        col_update: Dict[str, Any] = {}

        # A different name at the same position is treated as a rename.
        y_name = y_col.get("name", "")
        if y_name != t_col.get("name", ""):
            col_update["name"] = y_name

        y_comment = (y_col.get("comment") or "").strip()
        if y_comment != (t_col.get("comment") or "").strip():
            col_update["column_comment"] = y_comment

        y_mask = y_col.get("column_masking_rule") or ""
        if y_mask != (t_col.get("column_masking_rule") or ""):
            col_update["column_masking_rule"] = y_mask

        # Column tags: only add new keys and remove missing keys (values are not compared).
        y_ctags = y_col.get("tags", {}) or {}
        t_ctags = t_col.get("tags", {}) or {}
        tag_add = {k: v for k, v in y_ctags.items() if k not in t_ctags}
        tag_remove = {k: v for k, v in t_ctags.items() if k not in y_ctags}
        if tag_add:
            col_update["column_tags"] = tag_add
        if tag_remove:
            remove["columns"].setdefault(idx, {}).setdefault("column_tags", {}).update(tag_remove)

        # Column check constraints: add new names, remove missing names.
        y_cc = y_col.get("column_check_constraints", {}) or {}
        t_cc = t_col.get("column_check_constraints", {}) or {}
        cc_add = {k: v for k, v in y_cc.items() if k not in t_cc}
        cc_remove = {k: v for k, v in t_cc.items() if k not in y_cc}
        if cc_add:
            col_update["column_check_constraints"] = cc_add
        if cc_remove:
            remove["columns"].setdefault(idx, {}).setdefault("column_check_constraints", {}).update(cc_remove)

        if col_update:
            update["columns"][idx] = col_update

def generate_differences(
    yaml_snapshot: Dict[str, Any],
    table_snapshot: Optional[Dict[str, Any]]
) -> Dict[str, Any]:
    """
    Compute the add / update / remove changes needed to turn the table into the YAML spec.

    If ``table_snapshot`` is None (no table yet), every YAML field except
    ``full_table_name`` is returned under ``add`` for a full create. Otherwise each
    ``diff_*`` rule fills its slots. Only non-empty slots, and only non-empty
    sections, are kept in the result. ``full_table_name`` always comes from the YAML.
    """
    full_table_name = yaml_snapshot.get("full_table_name", "")

    # Full create: no existing table, so everything in the YAML is an "add".
    if table_snapshot is None:
        return {
            "full_table_name": full_table_name,
            "add": {key: value for key, value in yaml_snapshot.items() if key != "full_table_name"},
        }

    # Slot layout per section (order is preserved in the output).
    add = {
        "primary_key": [],
        "partitioned_by": [],
        "unique_keys": [],
        "foreign_keys": {},
        "table_check_constraints": {},
        "row_filters": {},
        "table_tags": {},
        "owner": "",
        "table_comment": "",
        "table_properties": {},
        "columns": {},
    }
    update = {
        "primary_key": [],
        "table_check_constraints": {},
        "row_filters": {},
        "table_tags": {},
        "owner": "",
        "table_comment": "",
        "columns": {},
    }
    remove = {
        "table_check_constraints": {},
        "row_filters": {},
        "table_tags": {},
        "columns": {},
    }

    diff_primary_key(yaml_snapshot, table_snapshot, add, update)
    diff_partitioned_by(yaml_snapshot, table_snapshot, add)
    diff_unique_keys(yaml_snapshot, table_snapshot, add)
    diff_foreign_keys(yaml_snapshot, table_snapshot, add)
    diff_table_check_constraints(yaml_snapshot, table_snapshot, add, update, remove)
    diff_row_filters(yaml_snapshot, table_snapshot, add, update, remove)
    diff_table_tags(yaml_snapshot, table_snapshot, add, remove)
    diff_owner(yaml_snapshot, table_snapshot, update)
    diff_table_comment(yaml_snapshot, table_snapshot, update)
    diff_table_properties(yaml_snapshot, table_snapshot, add)
    diff_columns(yaml_snapshot, table_snapshot, add, update, remove)

    # Drop empty slots, then drop sections left with nothing in them.
    result: Dict[str, Any] = {"full_table_name": full_table_name}
    for section_name, section in (("add", add), ("update", update), ("remove", remove)):
        non_empty = {key: value for key, value in section.items() if value}
        if non_empty:
            result[section_name] = non_empty
    return result

In [0]:
from pyspark.sql import SparkSession
from layker.snapshot_yaml import validate_and_snapshot_yaml
# from layker.snapshot_table import TableSnapshot  # re-enable to diff against the live table

yaml_path = "layker/resources/test.yaml"
table_name = "dq_dev.lmg_sandbox.config_driven_table_example"

spark = SparkSession.getActiveSession() or SparkSession.builder.getOrCreate()
snapshot_yaml, fq_table = validate_and_snapshot_yaml(
    yaml_path,
    env=None,
    mode="all",
)
# table_snapshot = TableSnapshot(spark, table_name).build_table_metadata_dict()

# Passing None as the table snapshot sends generate_differences down its full-create branch.
diffs = generate_differences(
    snapshot_yaml,
    None,
)
print(diffs)

In [0]:
import pprint

# Equivalent to pprint.pprint(diffs, width=120, sort_dicts=False), via an explicit printer
# so the diff's section/key order is shown as built.
pprint.PrettyPrinter(width=120, sort_dicts=False).pprint(diffs)

### Loader

In [0]:
# src/layker/loader.py

from typing import Any, Dict
from pyspark.sql import SparkSession

# ---- Centralized Loader Config ----
# Maps diff section ("add" / "update" / "remove") -> diff key -> {"sql": template, "desc": log text}.
# The keys largely mirror the slots that generate_differences builds for each section.
# Templates contain {placeholder} fields ({fq}, {cols}, {name}, {expression}, {key}, {val},
# {owner}, {comment}, {ref_tbl}, {ref_cols}); presumably the loader fills them with
# str.format — TODO confirm against the code that consumes this dict.
# NOTE(review): values are substituted inside '...' without escaping, so a tag/property
# value or comment containing a single quote would produce invalid SQL.
LOADER_CONFIG = {
    "add": {
        "primary_key": {
            "sql": "ALTER TABLE {fq} ADD PRIMARY KEY ({cols})",
            "desc": "ADD primary key: {cols}"
        },
        # NOTE(review): ALTER TABLE does not appear to support "ADD PARTITIONED BY" on
        # Delta/Databricks; partitioning is normally fixed at CREATE time — confirm.
        "partitioned_by": {
            "sql": "ALTER TABLE {fq} ADD PARTITIONED BY ({cols})",
            "desc": "ADD partition: {cols}"
        },
        # NOTE(review): a unique_keys entry is a list of column lists, so where {key} comes
        # from is unclear here; also confirm UNIQUE constraints are supported by the target.
        "unique_keys": {
            "sql": "ALTER TABLE {fq} ADD CONSTRAINT uq_{key} UNIQUE ({cols})",
            "desc": "ADD unique key: {cols}"
        },
        "foreign_keys": {
            "sql": (
                "ALTER TABLE {fq} ADD CONSTRAINT {name} "
                "FOREIGN KEY ({cols}) REFERENCES {ref_tbl} ({ref_cols})"
            ),
            "desc": "ADD foreign key: {name} ({cols})"
        },
        "table_check_constraints": {
            "sql": "ALTER TABLE {fq} ADD CONSTRAINT {name} CHECK ({expression})",
            "desc": "ADD table check constraint: {name}"
        },
        # NOTE(review): Databricks documents row filters as "SET ROW FILTER <func> ON (<cols>)";
        # this "ADD ROW FILTER ... WHERE" form should be checked before use.
        "row_filters": {
            "sql": "ALTER TABLE {fq} ADD ROW FILTER {name} WHERE {expression}",
            "desc": "ADD row filter: {name}"
        },
        "table_tags": {
            "sql": "ALTER TABLE {fq} SET TAGS ('{key}' = '{val}')",
            "desc": "ADD table tag: {key}={val}"
        },
        "owner": {
            "sql": "ALTER TABLE {fq} OWNER TO `{owner}`",
            "desc": "SET owner: {owner}"
        },
        # NOTE(review): confirm "ALTER TABLE ... SET COMMENT" is accepted; "COMMENT ON TABLE"
        # is the documented statement for table comments.
        "table_comment": {
            "sql": "ALTER TABLE {fq} SET COMMENT '{comment}'",
            "desc": "SET table comment"
        },
        "table_properties": {
            "sql": "ALTER TABLE {fq} SET TBLPROPERTIES ('{key}' = '{val}')",
            "desc": "ADD table property: {key}={val}"
        },
        # No template (and no "desc"): column changes are handled outside this table.
        "columns": {
            "sql": None  # Handled separately or via create
        }
    },
    "update": {
        # NOTE(review): "ALTER PRIMARY KEY" is not a documented ALTER TABLE clause; a PK
        # change presumably needs DROP PRIMARY KEY followed by ADD PRIMARY KEY — confirm.
        "primary_key": {
            "sql": "ALTER TABLE {fq} ALTER PRIMARY KEY ({cols})",
            "desc": "UPDATE primary key: {cols}"
        },
        # NOTE(review): there appears to be no "ALTER CONSTRAINT ... CHECK"; changing a CHECK
        # presumably requires DROP CONSTRAINT then ADD CONSTRAINT — confirm.
        "table_check_constraints": {
            "sql": "ALTER TABLE {fq} ALTER CONSTRAINT {name} CHECK ({expression})",
            "desc": "UPDATE table check constraint: {name}"
        },
        "row_filters": {
            "sql": "ALTER TABLE {fq} ALTER ROW FILTER {name} WHERE {expression}",
            "desc": "UPDATE row filter: {name}"
        },
        "table_tags": {
            "sql": "ALTER TABLE {fq} SET TAGS ('{key}' = '{val}')",
            "desc": "UPDATE table tag: {key}={val}"
        },
        "owner": {
            "sql": "ALTER TABLE {fq} OWNER TO `{owner}`",
            "desc": "UPDATE owner: {owner}"
        },
        "table_comment": {
            "sql": "ALTER TABLE {fq} SET COMMENT '{comment}'",
            "desc": "UPDATE table comment"
        },
        # Present here although generate_differences builds no update["table_properties"] slot.
        "table_properties": {
            "sql": "ALTER TABLE {fq} SET TBLPROPERTIES ('{key}' = '{val}')",
            "desc": "UPDATE table property: {key}={val}"
        },
        "columns": {
            "sql": None  # Handled separately
        }
    },
    "remove": {
        "table_check_constraints": {
            "sql": "ALTER TABLE {fq} DROP CONSTRAINT {name}",
            "desc": "REMOVE table check constraint: {name}"
        },
        # NOTE(review): Databricks' DROP ROW FILTER takes no filter name — confirm this template.
        "row_filters": {
            "sql": "ALTER TABLE {fq} DROP ROW FILTER {name}",
            "desc": "REMOVE row filter: {name}"
        },
        "table_tags": {
            "sql": "ALTER TABLE {fq} UNSET TAGS ('{key}')",
            "desc": "REMOVE table tag: {key}"
        },
        "columns": {
            "sql": None  # Handled separately
        }
    }
}


class DatabricksTableLoader:
    """
    Apply a differences dictionary to a Databricks table as SQL DDL.

    ``diff_dict`` must contain ``full_table_name`` and may contain ``add``,
    ``update`` and ``remove`` sections whose keys are looked up in
    ``LOADER_CONFIG``.

    * CREATE path: if ``add`` holds columns whose lowest index is 1 and there
      is nothing to update or remove, the table is created with CREATE TABLE.
      Column comments/tags and the remaining table-level metadata (tags, owner,
      keys, constraints, row filters, ...) are then applied with ALTER TABLE.
    * ALTER path: otherwise the ``add``, ``update`` and ``remove`` sections are
      applied, in that order.

    Every SQL statement goes through ``_run``. In dry-run mode it is printed;
    otherwise it is executed with ``spark.sql``. A human-readable description
    of each step is appended to ``self.log``.
    """

    def __init__(self, diff_dict: Dict[str, Any], spark: SparkSession, dry_run: bool = False):
        self.diff = diff_dict
        self.spark = spark
        self.dry_run = dry_run
        self.fq = diff_dict["full_table_name"]
        self.log = []

    # ------------------------------------------------------------------ #
    # SQL rendering helpers
    # ------------------------------------------------------------------ #
    @staticmethod
    def _quote_ident(name: Any) -> str:
        """Backtick-quote an identifier, doubling any embedded backticks."""
        return "`" + str(name).replace("`", "``") + "`"

    @staticmethod
    def _escape_literal(value: Any) -> str:
        """
        Escape text for use inside a single-quoted Spark SQL string literal.

        Backslashes and single quotes are backslash-escaped, so a value such as
        ``it's`` does not terminate the literal early.
        """
        return str(value).replace("\\", "\\\\").replace("'", "\\'")

    # ------------------------------------------------------------------ #
    # Entry point
    # ------------------------------------------------------------------ #
    def run(self):
        """Apply the diff (CREATE or ALTER path) and print a summary of ``self.log``."""
        # `or {}` also tolerates sections that are present but set to None.
        add = self.diff.get("add") or {}
        update = self.diff.get("update") or {}
        remove = self.diff.get("remove") or {}
        columns = add.get("columns") or {}

        # --- CREATE TABLE path ---
        # An add-only diff whose columns start at index 1 describes a new table.
        if columns and min(int(k) for k in columns) == 1 and not update and not remove:
            self._create_table(add)
            self._print_summary("[SUMMARY] Table CREATE complete:")
            return

        # --- ALTER TABLE path ---
        for action, section in (("add", add), ("update", update), ("remove", remove)):
            if section:
                self._handle_section(action, section)
        self._print_summary("[SUMMARY] Table modifications complete:")

    def _print_summary(self, header: str) -> None:
        """Print ``header`` followed by every entry recorded in ``self.log``."""
        print(header)
        for entry in self.log:
            print(f"  - {entry}")

    # ------------------------------------------------------------------ #
    # CREATE path
    # ------------------------------------------------------------------ #
    def _create_table(self, add_section):
        """
        Issue CREATE TABLE for ``add_section``, then apply the remaining metadata.

        Columns, partitioning, table properties and the table comment go into
        the CREATE statement itself. Column comments/tags and the other
        table-level keys are applied afterwards with ALTER TABLE.
        """
        cols = add_section["columns"]

        col_sqls = []
        for idx in sorted(cols, key=int):
            col = cols[idx]
            col_sql = f"{self._quote_ident(col['name'])} {col['datatype']}"
            if not col.get("nullable", True):
                col_sql += " NOT NULL"
            col_sqls.append(col_sql)
        columns_sql = ",\n  ".join(col_sqls)

        # Partitioning
        partitioned_by = add_section.get("partitioned_by") or []
        partition_sql = (
            f"\nPARTITIONED BY ({', '.join(self._quote_ident(c) for c in partitioned_by)})"
            if partitioned_by
            else ""
        )

        # Properties
        tbl_props = add_section.get("table_properties") or {}
        tbl_props_sql = ""
        if tbl_props:
            props = ", ".join(
                f"'{self._escape_literal(k)}' = '{self._escape_literal(v)}'"
                for k, v in tbl_props.items()
            )
            tbl_props_sql = f"\nTBLPROPERTIES ({props})"

        # Table comment
        tbl_comment = add_section.get("table_comment") or ""
        tbl_comment_sql = f"\nCOMMENT '{self._escape_literal(tbl_comment)}'" if tbl_comment else ""

        sql = f"CREATE TABLE {self.fq} (\n  {columns_sql}\n){partition_sql}{tbl_comment_sql}{tbl_props_sql}"
        self._run(sql, "CREATE TABLE")
        self.log.append(f"CREATE TABLE with columns: {[c['name'] for c in cols.values()]}")

        # Column comments/tags cannot be set inside CREATE here, so apply them now.
        self._handle_column_comments_and_tags(cols)

        # Table tags, owner, PK, unique keys, FKs, checks, row filters, ...
        self._handle_post_create(add_section)

    def _handle_column_comments_and_tags(self, cols):
        """Apply comment and tags for every column in ``cols``, in index order."""
        for idx in sorted(cols, key=int):
            self._set_column_comment_and_tags(cols[idx], verb="ADD", prep="to")

    def _handle_post_create(self, add_section):
        """Apply the add-section keys that CREATE TABLE itself did not consume."""
        # Keys already rendered into the CREATE TABLE statement.
        handled_by_create = ("columns", "table_properties", "table_comment", "partitioned_by")
        for key, val in add_section.items():
            if key in handled_by_create:
                continue
            meta = LOADER_CONFIG["add"].get(key)
            if not meta or not val or meta.get("sql") is None:
                continue
            # Keys without a dedicated renderer are skipped here, not sent generically.
            self._apply_entry("add", key, meta, val, allow_generic=False)

    # ------------------------------------------------------------------ #
    # ALTER path
    # ------------------------------------------------------------------ #
    def _handle_section(self, action: str, section_dict: Dict[str, Any]):
        """Apply one diff section (``add``/``update``/``remove``) in LOADER_CONFIG key order."""
        for key, meta in LOADER_CONFIG[action].items():
            val = section_dict.get(key)
            if not val:
                continue
            if meta.get("sql") is None:
                # Config entries without a template (columns) have their own handler.
                self._handle_columns(action, val)
            else:
                self._apply_entry(action, key, meta, val)

    def _apply_entry(self, action: str, key: str, meta: Dict[str, Any], val: Any,
                     allow_generic: bool = True):
        """
        Render and run the LOADER_CONFIG template(s) for one non-column diff key.

        ``meta`` must carry a non-None ``"sql"`` template. String values placed
        inside quoted literals (tag/property keys and values, comments) are
        escaped in the SQL but logged verbatim. Unknown keys fall back to a
        generic ``{val}`` substitution only when ``allow_generic`` is true.
        """
        template = meta["sql"]
        esc = self._escape_literal

        if key in ("primary_key", "partitioned_by"):
            cols = ", ".join(val)
            self._run(template.format(fq=self.fq, cols=cols), meta["desc"].format(cols=cols))
        elif key == "unique_keys":
            for group in val:
                cols, uk_name = ", ".join(group), "_".join(group)
                self._run(
                    template.format(fq=self.fq, key=uk_name, cols=cols),
                    meta["desc"].format(cols=cols, key=uk_name),
                )
        elif key == "foreign_keys":
            for fk_name, fk in val.items():
                cols = ", ".join(fk.get("columns", []))
                sql = template.format(
                    fq=self.fq,
                    name=fk_name,
                    cols=cols,
                    ref_tbl=fk.get("reference_table", ""),
                    ref_cols=", ".join(fk.get("reference_columns", [])),
                )
                self._run(sql, meta["desc"].format(name=fk_name, cols=cols))
        elif key in ("table_check_constraints", "row_filters"):
            for item_name, spec in val.items():
                sql = template.format(fq=self.fq, name=item_name, expression=spec.get("expression"))
                self._run(sql, meta["desc"].format(name=item_name))
        elif key in ("table_tags", "table_properties"):
            for k, v in val.items():
                self._run(template.format(fq=self.fq, key=esc(k), val=esc(v)), meta["desc"].format(key=k, val=v))
        elif key == "owner":
            self._run(template.format(fq=self.fq, owner=val), meta["desc"].format(owner=val))
        elif key == "table_comment":
            self._run(template.format(fq=self.fq, comment=esc(val)), meta["desc"])
        elif allow_generic:
            self._run(template.format(fq=self.fq, val=val), f"{action.upper()} {key}: {val}")

    # ------------------------------------------------------------------ #
    # Column-level changes
    # ------------------------------------------------------------------ #
    def _set_column_comment_and_tags(self, col, verb="ADD", prep="to"):
        """Set the comment (if any) and every tag of one column via ALTER COLUMN."""
        name = col["name"]
        ident = self._quote_ident(name)
        comment = col.get("comment")
        if comment:
            self._run(
                f"ALTER TABLE {self.fq} ALTER COLUMN {ident} COMMENT '{self._escape_literal(comment)}'",
                f"{verb} comment {prep} {name}",
            )
        for tag, value in (col.get("tags") or {}).items():
            self._run(
                f"ALTER TABLE {self.fq} ALTER COLUMN {ident} "
                f"SET TAGS ('{self._escape_literal(tag)}' = '{self._escape_literal(value)}')",
                f"{verb} tag {tag} {prep} {name}",
            )

    def _log_column_rules(self, col, verb):
        """Record masking rules and column check constraints in the log only."""
        # NOTE(review): no SQL is issued for these; the check-constraint log
        # entries describe intent, not an applied change.
        name = col["name"]
        if col.get("column_masking_rule"):
            self.log.append(f"{verb} masking rule for {name} (not supported)")
        for cc_name, cc_def in (col.get("column_check_constraints") or {}).items():
            expr = cc_def.get("expression", "")
            self.log.append(f"{verb} check constraint {cc_name} on {name}: {expr}")

    def _handle_columns(self, action: str, columns: Dict[int, Dict[str, Any]]):
        """Add, update, or drop individual columns for an existing table."""
        if action == "add":
            for col in columns.values():
                name = col.get("name")
                datatype = col.get("datatype")
                if not name or not datatype:
                    continue
                ddl = f"{self._quote_ident(name)} {datatype}"
                if not col.get("nullable", True):
                    ddl += " NOT NULL"
                self._run(f"ALTER TABLE {self.fq} ADD COLUMNS ({ddl})", f"ADD column {name}")
                self._set_column_comment_and_tags(col, verb="ADD", prep="to")
                self._log_column_rules(col, verb="ADD")
        elif action == "update":
            for col in columns.values():
                if not col.get("name"):
                    continue
                self._set_column_comment_and_tags(col, verb="UPDATE", prep="for")
                self._log_column_rules(col, verb="UPDATE")
        elif action == "remove":
            for col in columns.values():
                name = col.get("name")
                if not name:
                    # Without a name there is no column to address; previously this
                    # produced SQL against a column literally named "None".
                    continue
                ident = self._quote_ident(name)
                # Unset tags before dropping: ALTER COLUMN on a dropped column fails.
                for tag in (col.get("tags") or {}):
                    self._run(
                        f"ALTER TABLE {self.fq} ALTER COLUMN {ident} UNSET TAGS ('{self._escape_literal(tag)}')",
                        f"REMOVE tag {tag} from {name}",
                    )
                self._run(f"ALTER TABLE {self.fq} DROP COLUMN {ident}", f"REMOVE column {name}")
                for cc_name in (col.get("column_check_constraints") or {}):
                    self.log.append(f"REMOVE check constraint {cc_name} from {name}")

    def _run(self, sql, desc):
        """Print (dry run) or execute ``sql``, then record ``desc`` in the log."""
        if self.dry_run:
            print(f"[DRY RUN] {sql}")
        else:
            self.spark.sql(sql)
        self.log.append(desc)

In [0]:
from pprint import pprint

from pyspark.sql import SparkSession
from layker.snapshot_yaml import validate_and_snapshot_yaml
from layker.snapshot_table import TableSnapshot
from layker.differences import generate_differences

# Dev-test configuration.
# DRY_RUN=False executes the generated DDL against the live table;
# set it to True to only print the statements.
yaml_path = "layker/resources/test.yaml"
DRY_RUN = False

spark = SparkSession.getActiveSession() or SparkSession.builder.getOrCreate()

# Validate the YAML, snapshot the current table metadata, and diff the two.
snapshot_yaml, fq_table = validate_and_snapshot_yaml(yaml_path, env=None, mode="all")
table_snapshot = TableSnapshot(spark, fq_table).build_table_metadata_dict()
diffs = generate_differences(snapshot_yaml, table_snapshot)
pprint(diffs, sort_dicts=False)

# run() prints its own "[SUMMARY]" listing of every loader.log entry.
loader = DatabricksTableLoader(diff_dict=diffs, spark=spark, dry_run=DRY_RUN)
loader.run()

### Main Function

In [0]:
# NOTE(review): env="prd" combined with dry_run=False presumably applies the
# generated DDL to the production target; confirm before running this cell.
run_table_load(
    # Execution context
    spark=spark,
    env="prd",
    mode="all",    # "validate", "diff", "apply", "all"
    dry_run=False,
    # Input / output locations
    yaml_path="src/layker/resources/example.yaml",
    audit_log_table="src/layker/resources/audit.yaml",  # if you want to use the built-in one
    log_ddl="logs/example_comparison.yaml",  # adjust this path if you moved logs
)