In [1]:
%pip install sqlglot

Note: you may need to restart the kernel to use updated packages.



[notice] A new release of pip is available: 23.2.1 -> 25.2
[notice] To update, run: python.exe -m pip install --upgrade pip


# Evaluation steps

 

### Model Eval
   - Unions

   - CTEs & Recursive CTEs

   - Time travel syntax

   - Sub-queries

 

### Column Eval
  - Aliases

  - "SELECT *"

  - Calculated/Multi-column fields

  - Window Functions

    - Qualified Column Refs

 

### Other

   - Masking salt key in output

# Models Eval

## Unions

In [2]:
import sqlglot
from sqlglot import expressions as exp

"""
Can I assume all snowflake compiled models will be formatted as db.schema.tbl? I think?
"""


# Example UNION query in Snowflake syntax (no quotes)
union_query = """
SELECT CUSTOMER_ID, ORDER_DATE, 'online' AS CHANNEL
FROM ECOMMERCE_DB.SALES.ONLINE_ORDERS
WHERE ORDER_DATE >= '2024-01-01'
UNION ALL
SELECT CUSTOMER_ID, PURCHASE_DATE AS ORDER_DATE, 'retail' AS CHANNEL
FROM ECOMMERCE_DB.SALES.RETAIL_SALES
WHERE PURCHASE_DATE >= '2024-01-01'
UNION
SELECT CUST_ID AS CUSTOMER_ID, TRANSACTION_DATE AS ORDER_DATE, 'mobile' AS CHANNEL
FROM MOBILE_APP_DB.TRANSACTIONS.MOBILE_TRANSACTIONS
WHERE TRANSACTION_DATE >= '2024-01-01'
"""

def extract_snowflake_tables(sql_query):
    parsed = sqlglot.parse_one(sql_query, dialect="snowflake")
    tables = set()
    for node in parsed.walk():
        if isinstance(node, exp.Table):
            db = node.catalog or ""
            schema = node.db or ""
            name = node.name
            # Build full table name: DATABASE.SCHEMA.TABLE (no quotes)
            if db and schema:
                full_name = f"{db}.{schema}.{name}"
            elif schema:
                full_name = f"{schema}.{name}"
            else:
                full_name = name
            tables.add(full_name)
    return sorted(tables)

# Test extraction
tables = extract_snowflake_tables(union_query)
print("Snowflake tables found:")
for t in tables:
    print(f"  - {t}")

Snowflake tables found:
  - ECOMMERCE_DB.SALES.ONLINE_ORDERS
  - ECOMMERCE_DB.SALES.RETAIL_SALES
  - MOBILE_APP_DB.TRANSACTIONS.MOBILE_TRANSACTIONS


In [3]:
test_queries = [
    # 1. Simple UNION with single tables
    """
    SELECT id FROM db1.schema1.tableA
    UNION
    SELECT id FROM db2.schema2.tableB
    """,

    # 2. UNION ALL with JOIN and subquery
    """
    SELECT u.user_id, o.order_id
    FROM analytics.users u
    JOIN analytics.orders o ON u.user_id = o.user_id
    UNION ALL
    SELECT user_id, NULL
    FROM analytics.inactive_users
    WHERE last_login < '2024-01-01'
    """,

    # 3. UNION with nested SELECT and CTE
    """
    WITH recent_orders AS (
        SELECT order_id, customer_id
        FROM sales.orders
        WHERE order_date > '2025-01-01'
    )
    SELECT customer_id FROM recent_orders
    UNION
    SELECT customer_id FROM sales.customers
    WHERE signup_date > '2025-01-01'
    UNION ALL
    SELECT customer_id FROM marketing.leads
    WHERE source = 'web'
    """
]

for i, q in enumerate(test_queries, 1):
    tables = extract_snowflake_tables(q)
    print(f"\nTest case {i}:")
    for t in tables:
        print(f"  - {t}")


Test case 1:
  - db1.schema1.tableA
  - db2.schema2.tableB

Test case 2:
  - analytics.inactive_users
  - analytics.orders
  - analytics.users

Test case 3:
  - marketing.leads
  - recent_orders
  - sales.customers
  - sales.orders


## CTEs and Recursive CTES

In [4]:
test_queries.append(
    """
    WITH active_customers AS (
        SELECT customer_id
        FROM crm_db.sales.customers
        WHERE status = 'active'
    ),
    recent_orders AS (
        SELECT order_id, customer_id
        FROM crm_db.sales.orders
        WHERE order_date > '2025-01-01'
    ),
    top_products AS (
        SELECT product_id
        FROM crm_db.sales.products
        WHERE rating > 4.5
    )
    SELECT ac.customer_id, ro.order_id
    FROM active_customers ac
    JOIN recent_orders ro ON ac.customer_id = ro.customer_id
    UNION
    SELECT customer_id, NULL
    FROM crm_db.marketing.leads
    WHERE source = 'web'
    UNION ALL
    SELECT NULL, order_id
    FROM recent_orders
    WHERE order_id NOT IN (SELECT order_id FROM crm_db.sales.returns)
    """
)

for i, q in enumerate(test_queries, 1):
    tables = extract_snowflake_tables(q)
    print(f"\nTest case {i}:")
    for t in tables:
        print(f"  - {t}")


Test case 1:
  - db1.schema1.tableA
  - db2.schema2.tableB

Test case 2:
  - analytics.inactive_users
  - analytics.orders
  - analytics.users

Test case 3:
  - marketing.leads
  - recent_orders
  - sales.customers
  - sales.orders

Test case 4:
  - active_customers
  - crm_db.marketing.leads
  - crm_db.sales.customers
  - crm_db.sales.orders
  - crm_db.sales.products
  - crm_db.sales.returns
  - recent_orders


In [5]:
test_queries.append(
    """
    WITH __dbt__cte__dummy_data AS (
        SELECT
            upper(nullif(v:DUMMY_VER_NAME::STRING,'')) AS dummy_ver_name,
            upper(nullif(v:DUMMY_POP_NAME::STRING,'')) AS dummy_pop_name,
            upper(nullif(v:DUMMY_LEVEL_CD::STRING,'')) AS dummy_level_cd,
            upper(nullif(v:DUMMY_VAR_NAME::STRING,'')) AS dummy_var_name,
            nullif(v:DUMMY_COEF::STRING,'')::NUMBER(8,3) AS dummy_coef
        FROM dummy_schema.dummy_table
    ),
    get_dummy_data AS (
        SELECT
            dummy_ver_name,
            dummy_pop_name,
            dummy_level_cd,
            dummy_var_name,
            dummy_coef
        FROM __dbt__cte__dummy_data
    )
    SELECT
        COALESCE(gd.dummy_ver_name::VARCHAR, '') || '~' || COALESCE(gd.dummy_pop_name::VARCHAR, '') || '~' || COALESCE(gd.dummy_level_cd::VARCHAR, '') || '~' || COALESCE(gd.dummy_var_name::VARCHAR, '') AS dummy_id,
        dd.dummy_key,
        gd.dummy_ver_name,
        gd.dummy_pop_name,
        gd.dummy_level_cd,
        gd.dummy_var_name,
        gd.dummy_coef
    FROM get_dummy_data gd
    INNER JOIN dummy_schema.dummy_dim dd ON gd.dummy_ver_name = dd.dummy_ver_name
    """
)

for i, q in enumerate(test_queries, 1):
    tables = extract_snowflake_tables(q)
    print(f"\nTest case {i}:")
    for t in tables:
        print(f"  - {t}")



Test case 1:
  - db1.schema1.tableA
  - db2.schema2.tableB

Test case 2:
  - analytics.inactive_users
  - analytics.orders
  - analytics.users

Test case 3:
  - marketing.leads
  - recent_orders
  - sales.customers
  - sales.orders

Test case 4:
  - active_customers
  - crm_db.marketing.leads
  - crm_db.sales.customers
  - crm_db.sales.orders
  - crm_db.sales.products
  - crm_db.sales.returns
  - recent_orders

Test case 5:
  - __dbt__cte__dummy_data
  - dummy_schema.dummy_dim
  - dummy_schema.dummy_table
  - get_dummy_data


In [6]:
def extract_snowflake_tables(sql_query):
    parsed = sqlglot.parse_one(sql_query, dialect="snowflake")
    tables = set()
    cte_names = set()

    # Collect CTE names
    for node in parsed.find_all(exp.CTE):
        if node.alias:
            cte_names.add(node.alias)

    # Collect all table references
    for node in parsed.walk():
        if isinstance(node, exp.Table):
            db = node.catalog or ""
            schema = node.db or ""
            name = node.name
            if db and schema:
                full_name = f"{db}.{schema}.{name}"
            elif schema:
                full_name = f"{schema}.{name}"
            else:
                full_name = name
            tables.add(full_name)

    # Separate physical tables from CTEs
    physical_tables = [t for t in tables if t not in cte_names]
    return sorted(physical_tables), sorted(cte_names)



for i, q in enumerate(test_queries, 1):
    physical_tables, cte_names = extract_snowflake_tables(q)
    print(f"\nTest case {i}:")
    print("  Physical tables:")
    for t in physical_tables:
        print(f"    - {t}")
    print("  CTE names:")
    for c in cte_names:
        print(f"    - {c}")


Test case 1:
  Physical tables:
    - db1.schema1.tableA
    - db2.schema2.tableB
  CTE names:

Test case 2:
  Physical tables:
    - analytics.inactive_users
    - analytics.orders
    - analytics.users
  CTE names:

Test case 3:
  Physical tables:
    - marketing.leads
    - sales.customers
    - sales.orders
  CTE names:
    - recent_orders

Test case 4:
  Physical tables:
    - crm_db.marketing.leads
    - crm_db.sales.customers
    - crm_db.sales.orders
    - crm_db.sales.products
    - crm_db.sales.returns
  CTE names:
    - active_customers
    - recent_orders
    - top_products

Test case 5:
  Physical tables:
    - dummy_schema.dummy_dim
    - dummy_schema.dummy_table
  CTE names:
    - __dbt__cte__dummy_data
    - get_dummy_data


## Timestamp Example

In [7]:
test_queries.append(
    """
    WITH __dbt__cte__dummy_data AS (
        SELECT
            upper(nullif(v:DUMMY_VER_NAME::STRING,'')) AS dummy_ver_name,
            upper(nullif(v:DUMMY_POP_NAME::STRING,'')) AS dummy_pop_name,
            upper(nullif(v:DUMMY_LEVEL_CD::STRING,'')) AS dummy_level_cd,
            upper(nullif(v:DUMMY_VAR_NAME::STRING,'')) AS dummy_var_name,
            nullif(v:DUMMY_COEF::STRING,'')::NUMBER(8,3) AS dummy_coef
        FROM dummy_schema.dummy_table AT (TIMESTAMP => '2025-07-31 00:00:00')
    ),
    get_dummy_data AS (
        SELECT
            dummy_ver_name,
            dummy_pop_name,
            dummy_level_cd,
            dummy_var_name,
            dummy_coef
        FROM __dbt__cte__dummy_data
    )
    SELECT
        COALESCE(gd.dummy_ver_name::VARCHAR, '') || '~' || COALESCE(gd.dummy_pop_name::VARCHAR, '') || '~' || COALESCE(gd.dummy_level_cd::VARCHAR, '') || '~' || COALESCE(gd.dummy_var_name::VARCHAR, '') AS dummy_id,
        dd.dummy_key,
        gd.dummy_ver_name,
        gd.dummy_pop_name,
        gd.dummy_level_cd,
        gd.dummy_var_name,
        gd.dummy_coef
    FROM get_dummy_data gd
    INNER JOIN dummy_schema.dummy_dim dd ON gd.dummy_ver_name = dd.dummy_ver_name
    """
)

for i, q in enumerate(test_queries, 1):
    physical_tables, cte_names = extract_snowflake_tables(q)
    print(f"\nTest case {i}:")
    print("  Physical tables:")
    for t in physical_tables:
        print(f"    - {t}")
    print("  CTE names:")
    for c in cte_names:
        print(f"    - {c}")



Test case 1:
  Physical tables:
    - db1.schema1.tableA
    - db2.schema2.tableB
  CTE names:

Test case 2:
  Physical tables:
    - analytics.inactive_users
    - analytics.orders
    - analytics.users
  CTE names:

Test case 3:
  Physical tables:
    - marketing.leads
    - sales.customers
    - sales.orders
  CTE names:
    - recent_orders

Test case 4:
  Physical tables:
    - crm_db.marketing.leads
    - crm_db.sales.customers
    - crm_db.sales.orders
    - crm_db.sales.products
    - crm_db.sales.returns
  CTE names:
    - active_customers
    - recent_orders
    - top_products

Test case 5:
  Physical tables:
    - dummy_schema.dummy_dim
    - dummy_schema.dummy_table
  CTE names:
    - __dbt__cte__dummy_data
    - get_dummy_data

Test case 6:
  Physical tables:
    - dummy_schema.dummy_dim
    - dummy_schema.dummy_table
  CTE names:
    - __dbt__cte__dummy_data
    - get_dummy_data


## Derived/Sub-query example

In [8]:
test_queries.append(
    """
    WITH __dbt__cte__dummy_data AS (
        SELECT
            upper(nullif(v:DUMMY_VER_NAME::STRING,'')) AS dummy_ver_name,
            upper(nullif(v:DUMMY_POP_NAME::STRING,'')) AS dummy_pop_name,
            upper(nullif(v:DUMMY_LEVEL_CD::STRING,'')) AS dummy_level_cd,
            upper(nullif(v:DUMMY_VAR_NAME::STRING,'')) AS dummy_var_name,
            nullif(v:DUMMY_COEF::STRING,'')::NUMBER(8,3) AS dummy_coef
        FROM dummy_schema.dummy_table AT (TIMESTAMP => '2025-07-31 00:00:00')
    ),
    get_dummy_data AS (
        SELECT
            dummy_ver_name,
            dummy_pop_name,
            dummy_level_cd,
            dummy_var_name,
            dummy_coef
        FROM __dbt__cte__dummy_data
    )
    SELECT
        COALESCE(gd.dummy_ver_name::VARCHAR, '') || '~' || COALESCE(gd.dummy_pop_name::VARCHAR, '') || '~' || COALESCE(gd.dummy_level_cd::VARCHAR, '') || '~' || COALESCE(gd.dummy_var_name::VARCHAR, '') AS dummy_id,
        gd.dummy_key,
        gd.dummy_ver_name,
        gd.dummy_pop_name,
        gd.dummy_level_cd,
        gd.dummy_var_name,
        gd.dummy_coef,
        sub.latest_status
    FROM get_dummy_data gd
    INNER JOIN (
        SELECT
            dummy_ver_name,
            MAX(status_date) AS latest_status
        FROM dummy_schema.dummy_status
        WHERE status_code IN (
            SELECT code FROM dummy_schema.status_codes WHERE is_active = 1
        )
        GROUP BY dummy_ver_name
    ) sub ON gd.dummy_ver_name = sub.dummy_ver_name
    WHERE gd.dummy_coef > (
        SELECT AVG(dummy_coef) FROM dummy_schema.dummy_table WHERE dummy_level_cd = gd.dummy_level_cd
    )
    """
)



In [9]:
def extract_snowflake_tables(sql_query):
    parsed = sqlglot.parse_one(sql_query, dialect="snowflake")
    cte_names = set()
    table_aliases = dict()
    all_physical_tables = set()
    join_subquery_tables = set()
    where_subquery_tables = set()
    cte_source_tables = set()

    # Collect CTE names and their source tables
    for cte in parsed.find_all(exp.CTE):
        if cte.alias:
            cte_names.add(cte.alias)
        # Find tables referenced inside CTE definitions
        for node in cte.find_all(exp.Table):
            db = node.catalog or ""
            schema = node.db or ""
            name = node.name
            if db and schema:
                full_name = f"{db}.{schema}.{name}"
            elif schema:
                full_name = f"{schema}.{name}"
            else:
                full_name = name
            if full_name not in cte_names:
                cte_source_tables.add(full_name)

    # Helper to get full table name
    def get_full_name(node):
        db = node.catalog or ""
        schema = node.db or ""
        name = node.name
        if db and schema:
            return f"{db}.{schema}.{name}"
        elif schema:
            return f"{schema}.{name}"
        else:
            return name

    # Collect all physical tables (not CTEs) anywhere in the query
    for node in parsed.walk():
        if isinstance(node, exp.Table):
            full_name = get_full_name(node)
            if full_name not in cte_names:
                all_physical_tables.add(full_name)
            if node.alias:
                table_aliases[node.alias] = full_name

    # Collect tables in JOIN subqueries and derived tables
    for join in parsed.find_all(exp.Join):
        for subquery in join.find_all(exp.Subquery):
            for node in subquery.walk():
                if isinstance(node, exp.Table):
                    tbl = get_full_name(node)
                    if tbl not in cte_names:
                        join_subquery_tables.add(tbl)
                    if node.alias:
                        table_aliases[node.alias] = tbl

    # Collect tables in WHERE subqueries
    for where in parsed.find_all(exp.Where):
        for subquery in where.find_all(exp.Subquery):
            for node in subquery.walk():
                if isinstance(node, exp.Table):
                    tbl = get_full_name(node)
                    if tbl not in cte_names:
                        where_subquery_tables.add(tbl)
                    if node.alias:
                        table_aliases[node.alias] = tbl

    valuable_join_tables = sorted(join_subquery_tables - where_subquery_tables)

    # Only include as source/target if:
    # - referenced in a CTE definition (cte_source_tables)
    # - or referenced outside of WHERE subqueries (i.e., not only in where_subquery_tables)
    source_target_tables = sorted(
        t for t in all_physical_tables
        if t in cte_source_tables or t not in where_subquery_tables
    )

    return (
        source_target_tables,
        sorted(cte_names),
        valuable_join_tables,
        sorted(where_subquery_tables),
        table_aliases
    )

# Example usage and test logic:
for i, q in enumerate(test_queries, 1):
    target_tables, cte_names, join_subquery_tables, where_subquery_tables, table_aliases = extract_snowflake_tables(q)
    print(f"\nTest case {i}:")
    print("  Source/target tables (all physical tables):")
    for t in target_tables:
        print(f"    - {t}")
    print("  CTE names:")
    for c in cte_names:
        print(f"    - {c}")
    print("  JOIN/derived subquery tables (valuable for lineage):")
    for j in join_subquery_tables:
        print(f"    - {j}")
    print("  WHERE subquery tables (not useful for lineage):")
    for w in where_subquery_tables:
        print(f"    - {w}")
    print("  Table aliases:")
    for alias, table in table_aliases.items():
        print(f"    {alias} -> {table}")


Test case 1:
  Source/target tables (all physical tables):
    - db1.schema1.tableA
    - db2.schema2.tableB
  CTE names:
  JOIN/derived subquery tables (valuable for lineage):
  WHERE subquery tables (not useful for lineage):
  Table aliases:

Test case 2:
  Source/target tables (all physical tables):
    - analytics.inactive_users
    - analytics.orders
    - analytics.users
  CTE names:
  JOIN/derived subquery tables (valuable for lineage):
  WHERE subquery tables (not useful for lineage):
  Table aliases:
    u -> analytics.users
    o -> analytics.orders

Test case 3:
  Source/target tables (all physical tables):
    - marketing.leads
    - sales.customers
    - sales.orders
  CTE names:
    - recent_orders
  JOIN/derived subquery tables (valuable for lineage):
  WHERE subquery tables (not useful for lineage):
  Table aliases:

Test case 4:
  Source/target tables (all physical tables):
    - crm_db.marketing.leads
    - crm_db.sales.customers
    - crm_db.sales.orders
    - crm_d

# Columns Eval

## Aliases

In [10]:
import sqlglot
from sqlglot import expressions as exp

def extract_snowflake_columns(sql_query):
    """
    Extracts column lineage information from a Snowflake SQL query.
    Returns a list of dicts, each describing an output column.
    """
    parsed = sqlglot.parse_one(sql_query, dialect="snowflake")
    columns = []

    # Helper to get the string representation of an expression
    def expr_to_str(expr):
        return expr.sql(dialect="snowflake") if expr else None

    # Helper to recursively collect all column references in an expression
    def collect_source_columns(expr):
        sources = set()
        for node in expr.walk():
            if isinstance(node, exp.Column):
                # node.table can be None if unqualified
                sources.add((node.table, node.name))
        return list(sources)

    # Find the outermost SELECT (not inside a subquery)
    select = parsed
    while not isinstance(select, exp.Select) and select:
        select = select.args.get("this") if hasattr(select, "args") else None

    if not isinstance(select, exp.Select):
        # Try to find any SELECT if not top-level
        select = next(parsed.find_all(exp.Select), None)

    if select:
        for proj in select.expressions:
            # Target/output column name
            alias = proj.alias_or_name
            # Raw SQL for the expression
            expression_sql = expr_to_str(proj)
            # Source columns referenced in the expression
            source_columns = collect_source_columns(proj)
            # Type: direct, calculated, or constant
            if isinstance(proj, exp.Column):
                col_type = "direct"
            elif proj.is_star:
                col_type = "star"
            elif not source_columns:
                col_type = "constant"
            else:
                col_type = "calculated"
            columns.append({
                "target_column": alias,
                "expression": expression_sql,
                "source_columns": source_columns,
                "type": col_type
            })
    return columns

# Example usage:
if __name__ == "__main__":
    example_query = """
    SELECT
        a,
        b + c AS sum_col,
        'foo' AS const_col,
        t1.d AS d_alias
    FROM my_schema.my_table t1
    """
    cols = extract_snowflake_columns(example_query)
    for col in cols:
        print(col)

{'target_column': 'a', 'expression': 'a', 'source_columns': [('', 'a')], 'type': 'direct'}
{'target_column': 'sum_col', 'expression': 'b + c AS sum_col', 'source_columns': [('', 'c'), ('', 'b')], 'type': 'calculated'}
{'target_column': 'const_col', 'expression': "'foo' AS const_col", 'source_columns': [], 'type': 'constant'}
{'target_column': 'd_alias', 'expression': 't1.d AS d_alias', 'source_columns': [('t1', 'd')], 'type': 'calculated'}


In [11]:
import sqlglot
from sqlglot import expressions as exp

def extract_snowflake_tables(sql_query):
    parsed = sqlglot.parse_one(sql_query, dialect="snowflake")
    cte_names = set()
    table_aliases = dict()
    all_physical_tables = set()
    join_subquery_tables = set()
    where_subquery_tables = set()
    cte_source_tables = set()

    # Collect CTE names and their source tables
    for cte in parsed.find_all(exp.CTE):
        if cte.alias:
            cte_names.add(cte.alias)
        # Find tables referenced inside CTE definitions
        for node in cte.find_all(exp.Table):
            db = node.catalog or ""
            schema = node.db or ""
            name = node.name
            if db and schema:
                full_name = f"{db}.{schema}.{name}"
            elif schema:
                full_name = f"{schema}.{name}"
            else:
                full_name = name
            if full_name not in cte_names:
                cte_source_tables.add(full_name)

    # Helper to get full table name
    def get_full_name(node):
        db = node.catalog or ""
        schema = node.db or ""
        name = node.name
        if db and schema:
            return f"{db}.{schema}.{name}"
        elif schema:
            return f"{schema}.{name}"
        else:
            return name

    # Collect all physical tables (not CTEs) anywhere in the query
    for node in parsed.walk():
        if isinstance(node, exp.Table):
            full_name = get_full_name(node)
            if full_name not in cte_names:
                all_physical_tables.add(full_name)
            if node.alias:
                table_aliases[node.alias] = full_name

    # Collect tables in JOIN subqueries and derived tables
    for join in parsed.find_all(exp.Join):
        for subquery in join.find_all(exp.Subquery):
            for node in subquery.walk():
                if isinstance(node, exp.Table):
                    tbl = get_full_name(node)
                    if tbl not in cte_names:
                        join_subquery_tables.add(tbl)
                    if node.alias:
                        table_aliases[node.alias] = tbl

    # Collect tables in WHERE subqueries
    for where in parsed.find_all(exp.Where):
        for subquery in where.find_all(exp.Subquery):
            for node in subquery.walk():
                if isinstance(node, exp.Table):
                    tbl = get_full_name(node)
                    if tbl not in cte_names:
                        where_subquery_tables.add(tbl)
                    if node.alias:
                        table_aliases[node.alias] = tbl

    valuable_join_tables = sorted(join_subquery_tables - where_subquery_tables)

    # Only include as source/target if:
    # - referenced in a CTE definition (cte_source_tables)
    # - or referenced outside of WHERE subqueries (i.e., not only in where_subquery_tables)
    source_target_tables = sorted(
        t for t in all_physical_tables
        if t in cte_source_tables or t not in where_subquery_tables
    )

    return (
        source_target_tables,
        sorted(cte_names),
        valuable_join_tables,
        sorted(where_subquery_tables),
        table_aliases
    )

def extract_snowflake_columns(sql_query):
    """
    Extracts column lineage information from a Snowflake SQL query.
    Returns a list of dicts, each describing an output column.
    """
    parsed = sqlglot.parse_one(sql_query, dialect="snowflake")
    columns = []

    # Helper to get the string representation of an expression
    def expr_to_str(expr):
        return expr.sql(dialect="snowflake") if expr else None

    # Helper to recursively collect all column references in an expression
    def collect_source_columns(expr):
        sources = set()
        for node in expr.walk():
            if isinstance(node, exp.Column):
                # node.table can be None if unqualified
                sources.add((node.table, node.name))
        return list(sources)

    # Find the outermost SELECT (not inside a subquery)
    select = parsed
    while not isinstance(select, exp.Select) and select:
        select = select.args.get("this") if hasattr(select, "args") else None

    if not isinstance(select, exp.Select):
        # Try to find any SELECT if not top-level
        select = next(parsed.find_all(exp.Select), None)

    if select:
        for proj in select.expressions:
            # Target/output column name
            alias = proj.alias_or_name
            # Raw SQL for the expression
            expression_sql = expr_to_str(proj)
            # Source columns referenced in the expression
            source_columns = collect_source_columns(proj)
            # Type: direct, calculated, or constant
            if isinstance(proj, exp.Column):
                col_type = "direct"
            elif proj.is_star:
                col_type = "star"
            elif not source_columns:
                col_type = "constant"
            else:
                col_type = "calculated"
            columns.append({
                "target_column": alias,
                "expression": expression_sql,
                "source_columns": source_columns,
                "type": col_type
            })
    return columns

# Example usage and integration:
test_sql = """
SELECT
    a,
    b + c AS sum_col,
    'foo' AS const_col,
    t1.d AS d_alias
FROM my_schema.my_table t1
"""

# Extract tables and aliases
tables_result = extract_snowflake_tables(test_sql)
source_target_tables, cte_names, join_subquery_tables, where_subquery_tables, table_aliases = tables_result

# Extract columns
columns_result = extract_snowflake_columns(test_sql)

# Resolve source tables for each column using table_aliases
for col in columns_result:
    resolved_sources = []
    for alias, col_name in col["source_columns"]:
        if alias in table_aliases:
            resolved_sources.append((table_aliases[alias], col_name))
        elif alias is None and len(source_target_tables) == 1:
            # Unqualified column, only one table in FROM
            resolved_sources.append((source_target_tables[0], col_name))
        else:
            resolved_sources.append((alias, col_name))  # Could be None or a CTE
    col["resolved_source_columns"] = resolved_sources

# Print results
print("Source/target tables (all physical tables):")
for t in source_target_tables:
    print(f"  - {t}")
print("CTE names:")
for c in cte_names:
    print(f"  - {c}")
print("Table aliases:")
for alias, table in table_aliases.items():
    print(f"  {alias} -> {table}")

print("\nColumns lineage:")
for col in columns_result:
    print(f"Target column: {col['target_column']}")
    print(f"  Expression: {col['expression']}")
    print(f"  Source columns: {col['source_columns']}")
    print(f"  Resolved source columns: {col['resolved_source_columns']}")
    print(f"  Type: {col['type']}")
    print()

Source/target tables (all physical tables):
  - my_schema.my_table
CTE names:
Table aliases:
  t1 -> my_schema.my_table

Columns lineage:
Target column: a
  Expression: a
  Source columns: [('', 'a')]
  Resolved source columns: [('', 'a')]
  Type: direct

Target column: sum_col
  Expression: b + c AS sum_col
  Source columns: [('', 'c'), ('', 'b')]
  Resolved source columns: [('', 'c'), ('', 'b')]
  Type: calculated

Target column: const_col
  Expression: 'foo' AS const_col
  Source columns: []
  Resolved source columns: []
  Type: constant

Target column: d_alias
  Expression: t1.d AS d_alias
  Source columns: [('t1', 'd')]
  Resolved source columns: [('my_schema.my_table', 'd')]
  Type: calculated



In [12]:
# Example usage and integration:
test_sql = """
SELECT
    a,
    b + c AS sum_col,
    'foo' AS const_col,
    t1.d AS d_alias
FROM my_schema.my_table t1
"""

# Add the complex CTE/subquery example as a test case
test_sql_2 = """
WITH __dbt__cte__dummy_data AS (
    SELECT
        upper(nullif(v:DUMMY_VER_NAME::STRING,'')) AS dummy_ver_name,
        upper(nullif(v:DUMMY_POP_NAME::STRING,'')) AS dummy_pop_name,
        upper(nullif(v:DUMMY_LEVEL_CD::STRING,'')) AS dummy_level_cd,
        upper(nullif(v:DUMMY_VAR_NAME::STRING,'')) AS dummy_var_name,
        nullif(v:DUMMY_COEF::STRING,'')::NUMBER(8,3) AS dummy_coef
    FROM dummy_schema.dummy_table AT (TIMESTAMP => '2025-07-31 00:00:00')
),
get_dummy_data AS (
    SELECT
        dummy_ver_name,
        dummy_pop_name,
        dummy_level_cd,
        dummy_var_name,
        dummy_coef
    FROM __dbt__cte__dummy_data
)
SELECT
    COALESCE(gd.dummy_ver_name::VARCHAR, '') || '~' || COALESCE(gd.dummy_pop_name::VARCHAR, '') || '~' || COALESCE(gd.dummy_level_cd::VARCHAR, '') || '~' || COALESCE(gd.dummy_var_name::VARCHAR, '') AS dummy_id,
    gd.dummy_key,
    gd.dummy_ver_name,
    gd.dummy_pop_name,
    gd.dummy_level_cd,
    gd.dummy_var_name,
    gd.dummy_coef,
    gd.latest_status
FROM get_dummy_data gd
INNER JOIN (
    SELECT
        dummy_ver_name,
        MAX(status_date) AS latest_status
    FROM dummy_schema.dummy_status
    WHERE status_code IN (
        SELECT code FROM dummy_schema.status_codes WHERE is_active = 1
    )
    GROUP BY dummy_ver_name
) sub ON gd.dummy_ver_name = sub.dummy_ver_name
WHERE gd.dummy_coef > (
    SELECT AVG(dummy_coef) FROM dummy_schema.dummy_table WHERE dummy_level_cd = gd.dummy_level_cd
)
"""

for sql in [test_sql_2]:
    print("\n=== NEW TEST CASE ===")
    tables_result = extract_snowflake_tables(sql)
    source_target_tables, cte_names, join_subquery_tables, where_subquery_tables, table_aliases = tables_result

    # --- Add this block to show all table-related variables ---
    print("All table variables:")
    print(f"  source_target_tables: {source_target_tables}")
    print(f"  cte_names: {cte_names}")
    print(f"  join_subquery_tables: {join_subquery_tables}")
    print(f"  where_subquery_tables: {where_subquery_tables}")
    print(f"  table_aliases: {table_aliases}")
    # If you want to see all_physical_tables and cte_source_tables, you need to modify extract_snowflake_tables to return them as well.
    # For now, only the above are available from the return value.
    print()

    columns_result = extract_snowflake_columns(sql)

    for col in columns_result:
        resolved_sources = []
        for alias, col_name in col["source_columns"]:
            if alias in table_aliases:
                resolved_sources.append((table_aliases[alias], col_name))
            elif alias is None and len(source_target_tables) == 1:
                resolved_sources.append((source_target_tables[0], col_name))
            else:
                resolved_sources.append((alias, col_name))
        col["resolved_source_columns"] = resolved_sources

    print("Source/target tables (all physical tables):")
    for t in source_target_tables:
        print(f"  - {t}")
    print("CTE names:")
    for c in cte_names:
        print(f"  - {c}")
    print("Table aliases:")
    for alias, table in table_aliases.items():
        print(f"  {alias} -> {table}")

    print("\nColumns lineage:")
    for col in columns_result:
        print(f"Target column: {col['target_column']}")
        print(f"  Expression: {col['expression']}")
        print(f"  Source columns: {col['source_columns']}")
        print(f"  Resolved source columns: {col['resolved_source_columns']}")
        print(f"  Type: {col['type']}")
        print()


=== NEW TEST CASE ===
All table variables:
  source_target_tables: ['dummy_schema.dummy_status', 'dummy_schema.dummy_table']
  cte_names: ['__dbt__cte__dummy_data', 'get_dummy_data']
  join_subquery_tables: ['dummy_schema.dummy_status']
  where_subquery_tables: ['dummy_schema.dummy_table', 'dummy_schema.status_codes']
  table_aliases: {'gd': 'get_dummy_data'}

Source/target tables (all physical tables):
  - dummy_schema.dummy_status
  - dummy_schema.dummy_table
CTE names:
  - __dbt__cte__dummy_data
  - get_dummy_data
Table aliases:
  gd -> get_dummy_data

Columns lineage:
Target column: dummy_id
  Expression: COALESCE(CAST(gd.dummy_ver_name AS VARCHAR), '') || '~' || COALESCE(CAST(gd.dummy_pop_name AS VARCHAR), '') || '~' || COALESCE(CAST(gd.dummy_level_cd AS VARCHAR), '') || '~' || COALESCE(CAST(gd.dummy_var_name AS VARCHAR), '') AS dummy_id
  Source columns: [('gd', 'dummy_ver_name'), ('gd', 'dummy_var_name'), ('gd', 'dummy_pop_name'), ('gd', 'dummy_level_cd')]
  Resolved source col

In [13]:
for sql in test_queries[6:7]:
    print("\n=== NEW TEST CASE ===")
    print(sql)
    tables_result = extract_snowflake_tables(sql)
    source_target_tables, cte_names, join_subquery_tables, where_subquery_tables, table_aliases = tables_result

    # --- Add this block to show all table-related variables ---
    print("All table variables:")
    print(f"  source_target_tables: {source_target_tables}")
    print(f"  cte_names: {cte_names}")
    print(f"  join_subquery_tables: {join_subquery_tables}")
    print(f"  where_subquery_tables: {where_subquery_tables}")
    print(f"  table_aliases: {table_aliases}")
    # If you want to see all_physical_tables and cte_source_tables, you need to modify extract_snowflake_tables to return them as well.
    # For now, only the above are available from the return value.
    print()

    columns_result = extract_snowflake_columns(sql)

    for col in columns_result:
        resolved_sources = []
        for alias, col_name in col["source_columns"]:
            if alias in table_aliases:
                resolved_sources.append((table_aliases[alias], col_name))
            elif alias is None and len(source_target_tables) == 1:
                resolved_sources.append((source_target_tables[0], col_name))
            else:
                resolved_sources.append((alias, col_name))
        col["resolved_source_columns"] = resolved_sources

    print("Source/target tables (all physical tables):")
    for t in source_target_tables:
        print(f"  - {t}")
    print("CTE names:")
    for c in cte_names:
        print(f"  - {c}")
    print("Table aliases:")
    for alias, table in table_aliases.items():
        print(f"  {alias} -> {table}")

    print("\nColumns lineage:")
    for col in columns_result:
        print(f"Target column: {col['target_column']}")
        print(f"  Expression: {col['expression']}")
        print(f"  Source columns: {col['source_columns']}")
        print(f"  Resolved source columns: {col['resolved_source_columns']}")
        print(f"  Type: {col['type']}")
        print()


=== NEW TEST CASE ===

    WITH __dbt__cte__dummy_data AS (
        SELECT
            upper(nullif(v:DUMMY_VER_NAME::STRING,'')) AS dummy_ver_name,
            upper(nullif(v:DUMMY_POP_NAME::STRING,'')) AS dummy_pop_name,
            upper(nullif(v:DUMMY_LEVEL_CD::STRING,'')) AS dummy_level_cd,
            upper(nullif(v:DUMMY_VAR_NAME::STRING,'')) AS dummy_var_name,
            nullif(v:DUMMY_COEF::STRING,'')::NUMBER(8,3) AS dummy_coef
        FROM dummy_schema.dummy_table AT (TIMESTAMP => '2025-07-31 00:00:00')
    ),
    get_dummy_data AS (
        SELECT
            dummy_ver_name,
            dummy_pop_name,
            dummy_level_cd,
            dummy_var_name,
            dummy_coef
        FROM __dbt__cte__dummy_data
    )
    SELECT
        COALESCE(gd.dummy_ver_name::VARCHAR, '') || '~' || COALESCE(gd.dummy_pop_name::VARCHAR, '') || '~' || COALESCE(gd.dummy_level_cd::VARCHAR, '') || '~' || COALESCE(gd.dummy_var_name::VARCHAR, '') AS dummy_id,
        gd.dummy_key,
        

In [14]:
# No source column identified because col is coming from 3 tables?
# How to align query alias "sub" with source table from within derived query? 
# Need to work on table to column association when is a column
# for models that can't find their source is there a source search based on the tables in the model?
# enforce best practice to add this aliases to all columns

In [15]:
import sqlglot
from sqlglot import expressions as exp

def extract_snowflake_columns(sql_query):
    """
    Extracts column lineage information from a Snowflake SQL query.
    Returns a list of dicts, each describing an output column.
    """
    parsed = sqlglot.parse_one(sql_query, dialect="snowflake")
    columns = []

    # Helper to get the string representation of an expression
    def expr_to_str(expr):
        return expr.sql(dialect="snowflake") if expr else None

    # Helper to recursively collect all column references in an expression
    def collect_source_columns(expr):
        sources = set()
        for node in expr.walk():
            if isinstance(node, exp.Column):
                sources.add((node.table, node.name))
        return list(sources)

    # Find all SELECTs in the query (including UNION branches, subqueries, etc.)
    selects = [node for node in parsed.walk() if isinstance(node, exp.Select)]

    all_columns = []
    for idx, select in enumerate(selects):
        select_columns = []
        for proj in select.expressions:
            alias = proj.alias_or_name
            expression_sql = expr_to_str(proj)
            source_columns = collect_source_columns(proj)
            if isinstance(proj, exp.Column):
                col_type = "direct"
            elif proj.is_star:
                col_type = "star"
            elif not source_columns:
                col_type = "constant"
            else:
                col_type = "calculated"
            select_columns.append({
                "select_idx": idx,
                "target_column": alias,
                "expression": expression_sql,
                "source_columns": source_columns,
                "type": col_type
            })
        all_columns.append(select_columns)
    return all_columns

# Example usage:
for i, sql in enumerate(test_queries[6:7], 1):
    print(f"\n=== Test Query {i} ===")
    print(f'{sql}\n')
    all_columns = extract_snowflake_columns(sql)
    for select_idx, select_columns in enumerate(all_columns):
        print(f"  SELECT branch {select_idx+1}:")
        for col in select_columns:
            print(f"    Target column: {col['target_column']}")
            print(f"      Expression: {col['expression']}")
            print(f"      Source columns: {col['source_columns']}")
            print(f"      Type: {col['type']}")


=== Test Query 1 ===

    WITH __dbt__cte__dummy_data AS (
        SELECT
            upper(nullif(v:DUMMY_VER_NAME::STRING,'')) AS dummy_ver_name,
            upper(nullif(v:DUMMY_POP_NAME::STRING,'')) AS dummy_pop_name,
            upper(nullif(v:DUMMY_LEVEL_CD::STRING,'')) AS dummy_level_cd,
            upper(nullif(v:DUMMY_VAR_NAME::STRING,'')) AS dummy_var_name,
            nullif(v:DUMMY_COEF::STRING,'')::NUMBER(8,3) AS dummy_coef
        FROM dummy_schema.dummy_table AT (TIMESTAMP => '2025-07-31 00:00:00')
    ),
    get_dummy_data AS (
        SELECT
            dummy_ver_name,
            dummy_pop_name,
            dummy_level_cd,
            dummy_var_name,
            dummy_coef
        FROM __dbt__cte__dummy_data
    )
    SELECT
        COALESCE(gd.dummy_ver_name::VARCHAR, '') || '~' || COALESCE(gd.dummy_pop_name::VARCHAR, '') || '~' || COALESCE(gd.dummy_level_cd::VARCHAR, '') || '~' || COALESCE(gd.dummy_var_name::VARCHAR, '') AS dummy_id,
        gd.dummy_key,
        g

In [25]:
def extract_snowflake_columns(sql_query):
    """
    Extracts column lineage information from a Snowflake SQL query.
    Returns a list of lists, each describing the output columns for each SELECT.
    """
    parsed = sqlglot.parse_one(sql_query, dialect="snowflake")

    def expr_to_str(expr):
        return expr.sql(dialect="snowflake") if expr else None

    def collect_source_columns(expr):
        sources = set()
        for node in expr.walk():
            if isinstance(node, exp.Column):
                sources.add((node.table, node.name))
        return list(sources)

    # Helper: get all tables in the FROM clause of a SELECT
    def get_from_tables(select):
        """
        Returns a dict mapping alias (lowercase) -> (full_table_name, alias)
        """
        tables = {}
        from_expr = select.args.get("from")
        if from_expr:
            # Base table
            base = from_expr.args.get("this")
            if isinstance(base, exp.Table):
                db = base.catalog or ""
                schema = base.db or ""
                name = base.name
                if db and schema:
                    full_name = f"{db}.{schema}.{name}"
                elif schema:
                    full_name = f"{schema}.{name}"
                else:
                    full_name = name
                alias = base.alias or name
                tables[alias.lower()] = (full_name, alias)
            # JOINed tables
            for join in from_expr.find_all(exp.Join):
                join_table = join.args.get("this")
                if isinstance(join_table, exp.Table):
                    db = join_table.catalog or ""
                    schema = join_table.db or ""
                    name = join_table.name
                    if db and schema:
                        full_name = f"{db}.{schema}.{name}"
                    elif schema:
                        full_name = f"{schema}.{name}"
                    else:
                        full_name = name
                    alias = join_table.alias or name
                    tables[alias.lower()] = (full_name, alias)
        return tables
    
    selects = [node for node in parsed.walk() if isinstance(node, exp.Select)]
    all_columns = []
    for idx, select in enumerate(selects):
        select_columns = []
        from_tables = get_from_tables(select)
        only_table = list(from_tables.values())[0][0] if len(from_tables) == 1 else None
        for proj in select.expressions:
            alias = proj.alias_or_name
            expression_sql = expr_to_str(proj)
            source_columns = collect_source_columns(proj)
            resolved_sources = []
            for tbl_alias, col_name in source_columns:
                if (not tbl_alias or tbl_alias == "") and only_table:
                    resolved_sources.append((only_table, col_name))
                elif tbl_alias:
                    tbl_alias_lc = tbl_alias.lower()
                    if tbl_alias_lc in from_tables:
                        full_table, real_alias = from_tables[tbl_alias_lc]
                        resolved_sources.append((full_table, real_alias, col_name))
                    else:
                        resolved_sources.append((tbl_alias, col_name))
                else:
                    resolved_sources.append((tbl_alias, col_name))
            if isinstance(proj, exp.Column):
                col_type = "direct"
            elif proj.is_star:
                col_type = "star"
            elif not source_columns:
                col_type = "constant"
            else:
                col_type = "calculated"
            select_columns.append({
                "select_idx": idx,
                "target_column": alias,
                "expression": expression_sql,
                "source_columns": source_columns,
                "resolved_source_columns": resolved_sources,
                "type": col_type
            })
        all_columns.append(select_columns)
    return all_columns

for i, sql in enumerate(test_queries[1:2], 1):
    print(f"\n=== Test Query {i} ===")
    print(f'\n{sql}\n')
    all_columns = extract_snowflake_columns(sql)
    for select_idx, select_columns in enumerate(all_columns):
        print(f"  SELECT branch {select_idx+1}:")
        for col in select_columns:
            print(f"    Target column: {col['target_column']}")
            print(f"      Expression: {col['expression']}")
            print(f"      Source columns: {col['source_columns']}")
            print(f"      Resolved source columns: {col.get('resolved_source_columns', [])}")
            print(f"      Type: {col['type']}")



=== Test Query 1 ===


    SELECT u.user_id, o.order_id
    FROM analytics.users u
    JOIN analytics.orders o ON u.user_id = o.user_id
    UNION ALL
    SELECT user_id, NULL
    FROM analytics.inactive_users
    WHERE last_login < '2024-01-01'
    

  SELECT branch 1:
    Target column: user_id
      Expression: u.user_id
      Source columns: [('u', 'user_id')]
      Resolved source columns: [('analytics.users', 'u', 'user_id')]
      Type: direct
    Target column: order_id
      Expression: o.order_id
      Source columns: [('o', 'order_id')]
      Resolved source columns: [('o', 'order_id')]
      Type: direct
  SELECT branch 2:
    Target column: user_id
      Expression: user_id
      Source columns: [('', 'user_id')]
      Resolved source columns: [('analytics.inactive_users', 'user_id')]
      Type: direct
    Target column: NULL
      Expression: NULL
      Source columns: []
      Resolved source columns: []
      Type: constant


In [35]:
# claude attempt

import sqlglot
from sqlglot import exp

def extract_snowflake_columns(sql_query):
    """
    Extracts column lineage information from a Snowflake SQL query.
    Returns a list of lists, each describing the output columns for each SELECT.
    """
    parsed = sqlglot.parse_one(sql_query, dialect="snowflake")

    def expr_to_str(expr):
        return expr.sql(dialect="snowflake") if expr else None

    def collect_source_columns(expr):
        sources = set()
        for node in expr.walk():
            if isinstance(node, exp.Column):
                # Get table alias/name - could be empty string
                table_ref = node.table if node.table else ""
                sources.add((table_ref, node.name))
        return list(sources)

    # Helper: get all tables in the FROM clause of a SELECT
    def get_from_tables(select):
        """
        Returns a dict mapping alias (lowercase) -> (full_table_name, alias)
        """
        tables = {}
        
        from_expr = select.args.get("from")
        if from_expr:
            # Base table
            base = from_expr.args.get("this")
            if isinstance(base, exp.Table):
                db = base.catalog or ""
                schema = base.db or ""
                name = base.name
                if db and schema:
                    full_name = f"{db}.{schema}.{name}"
                elif schema:
                    full_name = f"{schema}.{name}"
                else:
                    full_name = name
                alias = base.alias or name
                tables[alias.lower()] = (full_name, alias)
        
        # JOINs are stored at the SELECT level, not FROM level
        joins = select.args.get("joins")
        if joins:
            for join in joins:
                join_table = join.args.get("this")
                if isinstance(join_table, exp.Table):
                    db = join_table.catalog or ""
                    schema = join_table.db or ""
                    name = join_table.name
                    if db and schema:
                        full_name = f"{db}.{schema}.{name}"
                    elif schema:
                        full_name = f"{schema}.{name}"
                    else:
                        full_name = name
                    alias = join_table.alias or name
                    tables[alias.lower()] = (full_name, alias)
                        
        return tables
    
    selects = [node for node in parsed.walk() if isinstance(node, exp.Select)]
    all_columns = []
    
    for idx, select in enumerate(selects):
        select_columns = []
        from_tables = get_from_tables(select)
        only_table = list(from_tables.values())[0][0] if len(from_tables) == 1 else None
        
        for proj in select.expressions:
            alias = proj.alias_or_name
            expression_sql = expr_to_str(proj)
            source_columns = collect_source_columns(proj)
            resolved_sources = []
            
            for tbl_alias, col_name in source_columns:
                if not tbl_alias and only_table:
                    # No table alias and only one table - use that table
                    resolved_sources.append((only_table, col_name))
                elif tbl_alias:
                    # Has table alias - look it up in from_tables
                    tbl_alias_lc = tbl_alias.lower()
                    if tbl_alias_lc in from_tables:
                        full_table, real_alias = from_tables[tbl_alias_lc]
                        resolved_sources.append((full_table, real_alias, col_name))
                    else:
                        # Alias not found in from_tables - keep as is
                        resolved_sources.append((tbl_alias, col_name))
                else:
                    # No table alias and multiple tables - ambiguous
                    resolved_sources.append((tbl_alias, col_name))
            
            # Determine column type
            if isinstance(proj, exp.Column):
                col_type = "direct"
            elif proj.is_star:
                col_type = "star"
            elif not source_columns:
                col_type = "constant"
            else:
                col_type = "calculated"
            
            select_columns.append({
                "select_idx": idx,
                "target_column": alias,
                "expression": expression_sql,
                "source_columns": source_columns,
                "resolved_source_columns": resolved_sources,
                "type": col_type
            })
        
        all_columns.append(select_columns)
    
    return all_columns

# Test with your example query
test_query = """
SELECT u.user_id, o.order_id
FROM analytics.users u
JOIN analytics.orders o ON u.user_id = o.user_id
UNION ALL
SELECT user_id, NULL
FROM analytics.inactive_users
WHERE last_login < '2024-01-01'
"""

print("=== Test Query 1 ===")
print(f'\n{test_queries[0]}\n')
all_columns = extract_snowflake_columns(test_queries[0])

for select_idx, select_columns in enumerate(all_columns):
    print(f"  SELECT branch {select_idx+1}:")
    for col in select_columns:
        print(f"    Target column: {col['target_column']}")
        print(f"      Expression: {col['expression']}")
        print(f"      Source columns: {col['source_columns']}")
        print(f"      Resolved source columns: {col.get('resolved_source_columns', [])}")
        print(f"      Type: {col['type']}")

=== Test Query 1 ===


    SELECT id FROM db1.schema1.tableA
    UNION
    SELECT id FROM db2.schema2.tableB
    

  SELECT branch 1:
    Target column: id
      Expression: id
      Source columns: [('', 'id')]
      Resolved source columns: [('db1.schema1.tableA', 'id')]
      Type: direct
  SELECT branch 2:
    Target column: id
      Expression: id
      Source columns: [('', 'id')]
      Resolved source columns: [('db2.schema2.tableB', 'id')]
      Type: direct


## Snapshots

Testing nested CTE resolution fix...

=== TESTING NESTED CTE RESOLUTION ===
Expected: customer_id should trace through both CTEs to 'customers' table



UnboundLocalError: cannot access local variable 'with_clause' where it is not associated with a value

In [20]:
# Copy the functions and test the nested CTE fix

# [The complete fixed functions would be copied here - same as the artifact above]

# Quick test for the nested CTE issue
def quick_test():
    sql = """
    WITH base_customers AS (
        SELECT customer_id, customer_name
        FROM customers
    ),
    enriched_customers AS (
        SELECT customer_id, customer_name, 'active' as status
        FROM base_customers  
    )
    SELECT customer_id, status
    FROM enriched_customers
    """
    
    print("=== QUICK NESTED CTE TEST ===")
    result = trace_column_lineage(sql, "customer_id")
    
    print("Dependencies found:")
    for col in result.get("next_columns_to_search", []):
        print(f"  - {col['table']}.{col['column']} ({col.get('level', 'unknown')})")
    
    # Success check
    customers_found = any(col['table'] == 'customers' for col in result.get("next_columns_to_search", []))
    print(f"\nResult: {'✅ SUCCESS' if customers_found else '❌ STILL BROKEN'}")
    
    if not customers_found:
        print("\nDebug - what we found instead:")
        for col in result.get("next_columns_to_search", []):
            print(f"  Table: {col['table']}")
            print(f"  Level: {col.get('level')}")
            print(f"  Reason: {col.get('source_reason', col.get('trace_reason', 'unknown'))}")

if __name__ == "__main__":
    quick_test()

=== QUICK NESTED CTE TEST ===


UnboundLocalError: cannot access local variable 'with_clause' where it is not associated with a value

In [4]:
trace_column_lineage(query,"etg_cd")

{'error': "Column 'etg_cd' not found in query output",
 'llm_context': "The column 'etg_cd' was not found in the final query output.",
 'next_columns_to_search': [],
 'full_lineage': {}}

In [22]:
import sqlglot
from sqlglot import exp

def extract_snowflake_columns(sql_query, existing_cte_registry=None):
    """
    Extracts column lineage information from a Snowflake SQL query.
    Returns a list of lists, each describing the output columns for each SELECT.
    FIXED: Now handles SELECT * properly and accepts existing CTE registry
    """
    parsed = sqlglot.parse_one(sql_query, dialect="snowflake")

    def expr_to_str(expr):
        return expr.sql(dialect="snowflake") if expr else None

    def collect_source_columns(expr):
        sources = set()
        for node in expr.walk():
            if isinstance(node, exp.Column):
                # Get table alias/name - could be empty string
                table_ref = node.table if node.table else ""
                sources.add((table_ref, node.name))
        return list(sources)

    # Helper: get all tables in the FROM clause of a SELECT
    def get_from_tables(select, cte_registry=None):
        """
        Returns a dict mapping alias (lowercase) -> (full_table_name, alias)
        Now also considers CTEs in the registry
        """
        if cte_registry is None:
            cte_registry = {}
            
        tables = {}
        
        from_expr = select.args.get("from")
        if from_expr:
            # Base table
            base = from_expr.args.get("this")
            if isinstance(base, exp.Table):
                db = base.catalog or ""
                schema = base.db or ""
                name = base.name
                
                # Check if this is a CTE first
                if name.lower() in cte_registry:
                    # This is a CTE reference
                    alias = base.alias or name
                    tables[alias.lower()] = (f"CTE:{name}", alias)
                else:
                    # Regular table
                    if db and schema:
                        full_name = f"{db}.{schema}.{name}"
                    elif schema:
                        full_name = f"{schema}.{name}"
                    else:
                        full_name = name
                    alias = base.alias or name
                    tables[alias.lower()] = (full_name, alias)
        
        # JOINs are stored at the SELECT level, not FROM level
        joins = select.args.get("joins")
        if joins:
            for join in joins:
                join_table = join.args.get("this")
                if isinstance(join_table, exp.Table):
                    db = join_table.catalog or ""
                    schema = join_table.db or ""
                    name = join_table.name
                    
                    # Check if this is a CTE first
                    if name.lower() in cte_registry:
                        # This is a CTE reference
                        alias = join_table.alias or name
                        tables[alias.lower()] = (f"CTE:{name}", alias)
                    else:
                        # Regular table
                        if db and schema:
                            full_name = f"{db}.{schema}.{name}"
                        elif schema:
                            full_name = f"{schema}.{name}"
                        else:
                            full_name = name
                        alias = join_table.alias or name
                        tables[alias.lower()] = (full_name, alias)
                        
        return tables
    
    # Build CTE registry first (or use existing one)
    cte_registry = existing_cte_registry or {}
    with_clause = parsed.args.get("with")
    
    # If no existing registry, build it from this query
    if not existing_cte_registry and with_clause:
        for cte in with_clause.expressions:
            cte_name = cte.alias
            cte_query = cte.this  # The SELECT part of the CTE
            cte_registry[cte_name.lower()] = cte_query
    
    selects = [node for node in parsed.walk() if isinstance(node, exp.Select)]
    
    # FIXED: Find the outermost/final SELECT (not part of CTE definitions)
    final_selects = []
    cte_select_ids = set()
    
    # Mark all CTE definition selects
    if with_clause:
        for cte in with_clause.expressions:
            cte_select = cte.this
            cte_select_ids.add(id(cte_select))
    
    # Only include selects that are NOT part of CTE definitions
    for select in selects:
        if id(select) not in cte_select_ids:
            final_selects.append(select)
    
    # If no final selects found, fall back to all selects (for simple queries)
    if not final_selects:
        final_selects = selects
    
    all_columns = []
    
    for idx, select in enumerate(final_selects):
        select_columns = []
        from_tables = get_from_tables(select, cte_registry)
        only_table = list(from_tables.values())[0][0] if len(from_tables) == 1 else None
        
        for proj in select.expressions:
            alias = proj.alias_or_name
            expression_sql = expr_to_str(proj)
            source_columns = collect_source_columns(proj)
            resolved_sources = []
            
            # FIXED: Handle SELECT * expansion
            if proj.is_star:
                # For SELECT *, create dependencies on all source tables
                for table_alias, (full_table, real_alias) in from_tables.items():
                    resolved_sources.append((full_table, real_alias, "*"))
                col_type = "star"
            else:
                for tbl_alias, col_name in source_columns:
                    if not tbl_alias and only_table:
                        # No table alias and only one table - use that table
                        resolved_sources.append((only_table, col_name))
                    elif tbl_alias:
                        # Has table alias - look it up in from_tables
                        tbl_alias_lc = tbl_alias.lower()
                        if tbl_alias_lc in from_tables:
                            full_table, real_alias = from_tables[tbl_alias_lc]
                            resolved_sources.append((full_table, real_alias, col_name))
                        else:
                            # Alias not found in from_tables - keep as is
                            resolved_sources.append((tbl_alias, col_name))
                    else:
                        # No table alias and multiple tables - ambiguous
                        resolved_sources.append((tbl_alias, col_name))
                
                # Determine column type
                if isinstance(proj, exp.Column):
                    col_type = "direct"
                elif not source_columns:
                    col_type = "constant"
                else:
                    col_type = "calculated"
            
            select_columns.append({
                "select_idx": idx,
                "target_column": alias,
                "expression": expression_sql,
                "source_columns": source_columns,
                "resolved_source_columns": resolved_sources,
                "type": col_type
            })
        
        all_columns.append(select_columns)
    
    return all_columns


def trace_column_lineage(sql_query, target_column_name, existing_cte_registry=None):
    """
    Traces a specific column through all transformations and builds LLM-ready context.
    FIXED: Properly handles aliases, single names, and recursive CTE resolution
    """
    
    def should_stop_tracing(full_table_name, internal_prefixes=['ph_'], cte_registry=None):
        """Determine if we should STOP tracing (found external source)"""
        
        # Always continue for CTE references - they need recursive resolution
        if full_table_name.startswith("CTE:"):
            return False, "cte_reference"
        
        # Check if this is actually a CTE name (without CTE: prefix)
        if cte_registry and full_table_name.lower() in cte_registry:
            return False, "is_cte_name"
        
        # Parse the table name
        parts = full_table_name.split('.')
        
        if len(parts) >= 3:
            # Full qualified name: database.schema.table
            database = parts[0]
            database_lower = database.lower()
            starts_with_internal = any(database_lower.startswith(prefix.lower()) for prefix in internal_prefixes)
            
            if not starts_with_internal:
                return True, "external_database"  # STOP - external database
            else:
                return False, "internal_database"  # CONTINUE - internal database
                
        elif len(parts) == 2:
            # schema.table - check if schema indicates external
            schema = parts[0].lower()
            if any(schema.startswith(prefix.lower()) for prefix in internal_prefixes):
                return False, "internal_schema"  # CONTINUE - internal schema
            else:
                return True, "external_schema"  # STOP - external schema
            
        else:
            # Single name - check if it's a CTE first, then treat as external
            if cte_registry and full_table_name.lower() in cte_registry:
                return False, "single_name_is_cte"  # CONTINUE - it's a CTE
            else:
                return True, "external_source_table"  # STOP - treat as external
    
    # Parse and build CTE registry (or use existing one for nested calls)
    parsed = sqlglot.parse_one(sql_query, dialect="snowflake")
    cte_registry = existing_cte_registry or {}
    with_clause = parsed.args.get("with")
    
    # If no existing registry, build it from this query
    if not existing_cte_registry and with_clause:
        for cte in with_clause.expressions:
            cte_name = cte.alias
            cte_query = cte.this
            cte_registry[cte_name.lower()] = cte_query
    
    # Get basic column analysis (pass CTE registry for nested CTE detection)
    base_columns = extract_snowflake_columns(sql_query, cte_registry)
    
    # Find the target column in the final output
    target_column_info = None
    target_select_branch = None
    
    for select_idx, select_columns in enumerate(base_columns):
        for col_info in select_columns:
            if col_info['target_column'].lower() == target_column_name.lower():
                target_column_info = col_info
                target_select_branch = select_idx + 1
                break
            # Also check if target column could come from SELECT *
            elif col_info['type'] == 'star':
                # For star selections, assume any requested column could be available
                target_column_info = {
                    'target_column': target_column_name,
                    'expression': f"* (includes {target_column_name})",
                    'type': 'star',
                    'resolved_source_columns': col_info['resolved_source_columns'],
                    'select_idx': select_idx
                }
                target_select_branch = select_idx + 1
                break
        if target_column_info:
            break
    
    if not target_column_info:
        return {
            "error": f"Column '{target_column_name}' not found in query output",
            "llm_context": f"The column '{target_column_name}' was not found in the final query output.",
            "next_columns_to_search": [],
            "full_lineage": {}
        }
    
    # Build LLM context
    llm_context_parts = []
    next_columns = []
    cte_transformations = []
    
    # Basic column information
    llm_context_parts.append(f"COLUMN: {target_column_name}")
    llm_context_parts.append(f"EXPRESSION: {target_column_info['expression']}")
    llm_context_parts.append(f"TRANSFORMATION TYPE: {target_column_info['type']}")
    
    if target_select_branch:
        llm_context_parts.append(f"FOUND IN: SELECT branch {target_select_branch}")
    
    # Process resolved sources and trace through CTEs
    resolved_sources = target_column_info.get('resolved_source_columns', [])
    
    if resolved_sources:
        llm_context_parts.append("\nSOURCE ANALYSIS:")
        
        # Group resolved sources by table to detect multiple dependencies from same table/CTE
        sources_by_table = {}
        for source in resolved_sources:
            if len(source) >= 3:  # (table, alias, column)
                table, alias, column = source[:3]
                table_key = table
                # For star selections, use the target column name instead of "*"
                if column == "*":
                    column = target_column_name
            elif len(source) == 2:  # (table, column)
                table, column = source
                alias = table
                table_key = table
                # For star selections, use the target column name
                if column == "*":
                    column = target_column_name
            else:
                continue
                
            if table_key not in sources_by_table:
                sources_by_table[table_key] = []
            sources_by_table[table_key].append((table, alias, column))
        
        # Process each table's dependencies
        for table_key, table_sources in sources_by_table.items():
            if len(table_sources) == 1:
                # Single dependency from this table
                table, alias, column = table_sources[0]
                
                # Check if this is a CTE reference
                if table.startswith("CTE:"):
                    cte_name = table.replace("CTE:", "")
                    llm_context_parts.append(f"  └─ CTE REFERENCE: {alias}.{column} → {cte_name}.{column}")
                    
                    # Trace through the CTE
                    if cte_name.lower() in cte_registry:
                        cte_query = cte_registry[cte_name.lower()]
                        cte_sql = cte_query.sql(dialect="snowflake")
                        llm_context_parts.append(f"  └─ TRACING CTE '{cte_name}' (INTRA-FILE TRANSFORMATION):")
                        
                        # FIXED: Recursively analyze the CTE with the current CTE registry
                        # This ensures nested CTE references are properly resolved
                        cte_trace = trace_column_lineage(cte_sql, column, cte_registry)
                        if "error" not in cte_trace:
                            # Add CTE transformation info
                            cte_transformations.append({
                                "cte_name": cte_name,
                                "column": column,
                                "transformation_type": "intra_file_cte",
                                "details": cte_trace.get("llm_context", ""),
                                "dependencies": cte_trace.get("next_columns_to_search", [])
                            })
                            
                            # Add CTE's dependencies to our next_columns
                            for cte_dep in cte_trace.get("next_columns_to_search", []):
                                next_columns.append({
                                    "table": cte_dep["table"],
                                    "column": cte_dep["column"],
                                    "context": f"External source for {target_column_name} via CTE {cte_name}",
                                    "level": "external_via_cte",
                                    "cte_intermediate": cte_name
                                })
                            
                            # Add CTE context to LLM output
                            cte_context_lines = cte_trace["llm_context"].split('\n')
                            for line in cte_context_lines:
                                if line.strip():
                                    llm_context_parts.append(f"    {line}")
                        else:
                            llm_context_parts.append(f"    ERROR tracing CTE: {cte_trace['error']}")
                    else:
                        llm_context_parts.append(f"    WARNING: CTE '{cte_name}' not found in registry")
                        
                else:
                    # Regular table reference - check if it's actually a CTE first
                    should_stop, reason = should_stop_tracing(table, cte_registry=cte_registry)
                    
                    if not should_stop and reason in ["is_cte_name", "single_name_is_cte"]:
                        # This is actually a CTE that we need to trace through
                        cte_name = table
                        llm_context_parts.append(f"  └─ NESTED CTE REFERENCE: {alias}.{column} → {cte_name}.{column}")
                        
                        if cte_name.lower() in cte_registry:
                            cte_query = cte_registry[cte_name.lower()]
                            cte_sql = cte_query.sql(dialect="snowflake")
                            llm_context_parts.append(f"  └─ TRACING NESTED CTE '{cte_name}' (INTRA-FILE TRANSFORMATION):")
                            
                            # Recursively analyze the nested CTE
                            cte_trace = trace_column_lineage(cte_sql, column, cte_registry)
                            if "error" not in cte_trace:
                                # Add CTE transformation info
                                cte_transformations.append({
                                    "cte_name": cte_name,
                                    "column": column,
                                    "transformation_type": "nested_cte",
                                    "details": cte_trace.get("llm_context", ""),
                                    "dependencies": cte_trace.get("next_columns_to_search", [])
                                })
                                
                                # Add CTE's dependencies to our next_columns
                                for cte_dep in cte_trace.get("next_columns_to_search", []):
                                    next_columns.append({
                                        "table": cte_dep["table"],
                                        "column": cte_dep["column"],
                                        "context": f"External source for {target_column_name} via nested CTE {cte_name}",
                                        "level": "external_via_cte",
                                        "cte_intermediate": cte_name
                                    })
                                
                                # Add CTE context to LLM output
                                cte_context_lines = cte_trace["llm_context"].split('\n')
                                for line in cte_context_lines:
                                    if line.strip():
                                        llm_context_parts.append(f"    {line}")
                            else:
                                llm_context_parts.append(f"    ERROR tracing nested CTE: {cte_trace['error']}")
                        else:
                            llm_context_parts.append(f"    WARNING: Nested CTE '{cte_name}' not found in registry")
                    
                    elif should_stop:
                        # External table - add to dependencies
                        llm_context_parts.append(f"  └─ EXTERNAL TABLE: {table}.{column} (referenced as {alias}.{column}) - {reason}")
                        next_columns.append({
                            "table": table,
                            "column": column,
                            "context": f"External table dependency for {target_column_name}",
                            "level": "external_table",
                            "source_reason": reason
                        })
                    else:
                        # Internal table - would need further tracing in full system
                        llm_context_parts.append(f"  └─ INTERNAL TABLE: {table}.{column} (referenced as {alias}.{column}) - {reason}")
                        # In test mode, treat internal tables as external to show the dependency
                        next_columns.append({
                            "table": table,
                            "column": column,
                            "context": f"Internal table dependency for {target_column_name}",
                            "level": "internal_table",
                            "trace_reason": reason
                        })
            
            else:
                # Multiple dependencies from same table - group them
                table_name = table_key.replace("CTE:", "") if table_key.startswith("CTE:") else table_key
                columns = [col for _, _, col in table_sources]
                
                llm_context_parts.append(f"  └─ MULTIPLE DEPENDENCIES FROM {table_name}:")
                for table, alias, column in table_sources:
                    llm_context_parts.append(f"    • {alias}.{column}")
                
                if table_key.startswith("CTE:"):
                    # Multiple CTE dependencies - consolidate them
                    cte_name = table_key.replace("CTE:", "")
                    llm_context_parts.append(f"  └─ CONSOLIDATED CTE ANALYSIS for '{cte_name}':")
                    
                    if cte_name.lower() in cte_registry:
                        cte_query = cte_registry[cte_name.lower()]
                        cte_sql = cte_query.sql(dialect="snowflake")
                        
                        # Get all unique external dependencies from this CTE
                        all_cte_deps = set()
                        cte_column_details = []
                        
                        for table, alias, column in table_sources:
                            # FIXED: Pass CTE registry to nested calls
                            cte_trace = trace_column_lineage(cte_sql, column, cte_registry)
                            if "error" not in cte_trace:
                                cte_column_details.append({
                                    "column": column,
                                    "trace": cte_trace
                                })
                                
                                # Collect external dependencies
                                for cte_dep in cte_trace.get("next_columns_to_search", []):
                                    dep_key = (cte_dep["table"], cte_dep["column"])
                                    all_cte_deps.add(dep_key)
                        
                        # Add consolidated CTE transformation
                        cte_transformations.append({
                            "cte_name": cte_name,
                            "columns": columns,
                            "transformation_type": "consolidated_cte",
                            "details": f"CTE processes {len(columns)} columns: {', '.join(columns)}",
                            "column_details": cte_column_details
                        })
                        
                        # Add unique external dependencies
                        for dep_table, dep_column in all_cte_deps:
                            next_columns.append({
                                "table": dep_table,
                                "column": dep_column,
                                "context": f"External source for {target_column_name} via consolidated CTE {cte_name}",
                                "level": "external_via_cte",
                                "cte_intermediate": cte_name
                            })
                        
                        # Show consolidated CTE analysis
                        llm_context_parts.append(f"    └─ CTE '{cte_name}' processes {len(columns)} output columns")
                        llm_context_parts.append(f"    └─ External dependencies: {len(all_cte_deps)} unique sources")
                        
                else:
                    # Multiple dependencies from regular table
                    for table, alias, column in table_sources:
                        should_stop, reason = should_stop_tracing(table, cte_registry=cte_registry)
                        
                        if not should_stop and reason in ["is_cte_name", "single_name_is_cte"]:
                            # This is actually a nested CTE reference
                            cte_name = table
                            if cte_name.lower() in cte_registry:
                                cte_query = cte_registry[cte_name.lower()]
                                cte_sql = cte_query.sql(dialect="snowflake")
                                
                                # Trace through this nested CTE
                                cte_trace = trace_column_lineage(cte_sql, column, cte_registry)
                                if "error" not in cte_trace:
                                    # Add nested CTE's dependencies
                                    for cte_dep in cte_trace.get("next_columns_to_search", []):
                                        next_columns.append({
                                            "table": cte_dep["table"],
                                            "column": cte_dep["column"],
                                            "context": f"External source for {target_column_name} via nested CTE {cte_name}",
                                            "level": "external_via_cte",
                                            "cte_intermediate": cte_name
                                        })
                        
                        elif should_stop:
                            next_columns.append({
                                "table": table,
                                "column": column,
                                "context": f"External table dependency for {target_column_name}",
                                "level": "external_table",
                                "source_reason": reason
                            })
                        else:
                            next_columns.append({
                                "table": table,
                                "column": column,
                                "context": f"Internal table dependency for {target_column_name}",
                                "level": "internal_table",
                                "trace_reason": reason
                            })
    
    # Show CTE transformations summary
    if cte_transformations:
        llm_context_parts.append(f"\nINTRA-FILE CTE TRANSFORMATIONS:")
        for cte_info in cte_transformations:
            # Handle both 'column' and 'columns' keys safely
            if 'column' in cte_info:
                column_info = cte_info['column']
            elif 'columns' in cte_info:
                columns_list = cte_info['columns']
                if isinstance(columns_list, list):
                    column_info = ', '.join(columns_list)
                else:
                    column_info = str(columns_list)
            else:
                column_info = 'unknown'
            
            llm_context_parts.append(f"  CTE '{cte_info['cte_name']}' transforms {column_info}")
            llm_context_parts.append(f"    └─ Type: {cte_info.get('transformation_type', 'unknown')}")
    
    # Show CTE definitions if relevant
    if cte_registry:
        llm_context_parts.append(f"\nAVAILABLE CTEs IN THIS FILE:")
        for cte_name in cte_registry.keys():
            llm_context_parts.append(f"  - {cte_name}")
    
    # Remove duplicates from next_columns
    unique_next_columns = []
    seen = set()
    for col in next_columns:
        key = (col['table'], col['column'], col.get('level', 'unknown'))
        if key not in seen:
            seen.add(key)
            unique_next_columns.append(col)
    
    return {
        "llm_context": "\n".join(llm_context_parts),
        "next_columns_to_search": unique_next_columns,
        "cte_transformations": cte_transformations,
        "full_lineage": target_column_info
    }


# Test functions for Jupyter
def test_nested_cte_fix():
    """Test that nested CTEs are properly resolved to external sources"""
    
    # Test case: enriched_customers -> base_customers -> customers
    sql = """
    WITH base_customers AS (
        SELECT customer_id, customer_name
        FROM customers
    ),
    enriched_customers AS (
        SELECT customer_id, customer_name, 'active' as status
        FROM base_customers  
    )
    SELECT customer_id, status
    FROM enriched_customers
    """
    
    print("=== TESTING NESTED CTE RESOLUTION ===")
    print("Expected: customer_id should trace through both CTEs to 'customers' table")
    print()
    
    result = trace_column_lineage(sql, "customer_id")
    
    if "error" in result:
        print(f"ERROR: {result['error']}")
        return False
    
    # Check if we found the ultimate external source
    next_columns = result.get("next_columns_to_search", [])
    
    print("Found dependencies:")
    for col in next_columns:
        print(f"  - {col['table']}.{col['column']} ({col['level']})")
    
    # Success criteria: should find 'customers' table as external source
    external_sources = [col for col in next_columns if col['table'] == 'customers']
    
    if external_sources:
        print("✅ SUCCESS: Found 'customers' as external source!")
        print(f"   Details: {external_sources[0]}")
        return True
    else:
        print("❌ FAILED: Should have found 'customers' table as external source")
        print("   Currently stops at intermediate CTE instead of tracing to ultimate source")
        
        # Show debug info
        print("\nDEBUG - Full LLM Context:")
        print(result.get("llm_context", "No context"))
        return False

def test_simple_cte_still_works():
    """Ensure simple CTE case still works after our changes"""
    
    sql = """
    WITH customer_base AS (
        SELECT customer_id, customer_name
        FROM customers
    )
    SELECT customer_id
    FROM customer_base
    """
    
    print("\n=== TESTING SIMPLE CTE (REGRESSION TEST) ===")
    
    result = trace_column_lineage(sql, "customer_id")
    
    if "error" in result:
        print(f"ERROR: {result['error']}")
        return False
    
    next_columns = result.get("next_columns_to_search", [])
    external_sources = [col for col in next_columns if col['table'] == 'customers']
    
    if external_sources:
        print("✅ SUCCESS: Simple CTE still works!")
        return True
    else:
        print("❌ FAILED: Simple CTE broken")
        return False

# Quick test runner
def run_nested_cte_tests():
    print("Testing nested CTE resolution fix...\n")
    
    # Run tests
    nested_ok = test_nested_cte_fix()
    simple_ok = test_simple_cte_still_works()
    
    print(f"\n=== RESULTS ===")
    print(f"Nested CTE resolution: {'✅ PASS' if nested_ok else '❌ FAIL'}")
    print(f"Simple CTE regression: {'✅ PASS' if simple_ok else '❌ FAIL'}")
    
    if nested_ok and simple_ok:
        print("\n🎉 All tests passed! Ready to integrate with main DBT tracer.")
    else:
        print("\n🔧 Still needs fixes before integration.")
    
    return nested_ok and simple_ok

In [23]:
# Copy the functions from the artifact above, then run:
run_nested_cte_tests()

Testing nested CTE resolution fix...

=== TESTING NESTED CTE RESOLUTION ===
Expected: customer_id should trace through both CTEs to 'customers' table

Found dependencies:
  - customers.customer_id (external_via_cte)
✅ SUCCESS: Found 'customers' as external source!
   Details: {'table': 'customers', 'column': 'customer_id', 'context': 'External source for customer_id via CTE enriched_customers', 'level': 'external_via_cte', 'cte_intermediate': 'enriched_customers'}

=== TESTING SIMPLE CTE (REGRESSION TEST) ===
✅ SUCCESS: Simple CTE still works!

=== RESULTS ===
Nested CTE resolution: ✅ PASS
Simple CTE regression: ✅ PASS

🎉 All tests passed! Ready to integrate with main DBT tracer.


True

In [24]:
def test_simple_select():
    """Test 1: Basic SELECT"""
    sql = """
    SELECT customer_id, customer_name
    FROM customers
    """
    
    print("=== TEST 1: Simple SELECT ===")
    result = trace_column_lineage(sql, "customer_id")
    print("Result:", result.get("error", "SUCCESS"))
    print("Next columns:", result.get("next_columns_to_search", []))
    print()

def test_select_star():
    """Test 2: SELECT * """
    sql = """
    SELECT *
    FROM customers
    """
    
    print("=== TEST 2: SELECT * ===")
    result = trace_column_lineage(sql, "customer_id")
    print("Result:", result.get("error", "SUCCESS"))
    print("Next columns:", result.get("next_columns_to_search", []))
    print()

def test_simple_cte():
    """Test 3: Simple CTE"""
    sql = """
    WITH customer_base AS (
        SELECT customer_id, customer_name
        FROM customers
    )
    SELECT customer_id
    FROM customer_base
    """
    
    print("=== TEST 3: Simple CTE ===")
    result = trace_column_lineage(sql, "customer_id")
    print("Result:", result.get("error", "SUCCESS"))
    print("Next columns:", result.get("next_columns_to_search", []))
    print("CTE transformations:", len(result.get("cte_transformations", [])))
    print()

def test_cte_with_star():
    """Test 4: CTE with SELECT *"""
    sql = """
    WITH customer_base AS (
        SELECT *
        FROM customers
    )
    SELECT customer_id
    FROM customer_base
    """
    
    print("=== TEST 4: CTE with SELECT * ===")
    result = trace_column_lineage(sql, "customer_id")
    print("Result:", result.get("error", "SUCCESS"))
    print("Next columns:", result.get("next_columns_to_search", []))
    print()

def test_multiple_ctes():
    """Test 5: Multiple CTEs"""
    sql = """
    WITH base_customers AS (
        SELECT customer_id, customer_name
        FROM customers
    ),
    enriched_customers AS (
        SELECT customer_id, customer_name, 'active' as status
        FROM base_customers
    )
    SELECT customer_id, status
    FROM enriched_customers
    """
    
    print("=== TEST 5: Multiple CTEs ===")
    result = trace_column_lineage(sql, "customer_id")
    print("Result:", result.get("error", "SUCCESS"))
    print("Next columns:", result.get("next_columns_to_search", []))
    print("CTE transformations:", len(result.get("cte_transformations", [])))
    print()

def test_dbt_style_ctes():
    """Test 6: DBT-style CTEs (simplified version of your complex query)"""
    sql = """
    WITH __dbt__cte__table1 AS (
        SELECT col1, col2, col3
        FROM schema_name.table_name1
    ),
    cte_main AS (
        SELECT col1, col2
        FROM __dbt__cte__table1
    )
    SELECT col1
    FROM cte_main
    """
    
    print("=== TEST 6: DBT-style CTEs ===")
    result = trace_column_lineage(sql, "col1")
    print("Result:", result.get("error", "SUCCESS"))
    print("Next columns:", result.get("next_columns_to_search", []))
    print()

def test_extract_columns():
    """Test 7: Just test column extraction"""
    sql = """
    SELECT customer_id, customer_name
    FROM customers
    """
    
    print("=== TEST 7: Column Extraction ===")
    try:
        columns = extract_snowflake_columns(sql)
        print("Extracted columns:", len(columns), "select statements")
        for i, select_cols in enumerate(columns):
            print(f"  Select {i}: {len(select_cols)} columns")
            for col in select_cols:
                print(f"    - {col['target_column']}: {col['type']}")
    except Exception as e:
        print("Error:", str(e))
    print()

if __name__ == "__main__":
    print("Running simple lineage tests...\n")
    
    # Run tests in order of complexity
    test_extract_columns()
    test_simple_select()
    test_select_star()
    test_simple_cte()
    test_cte_with_star()
    test_multiple_ctes()
    test_dbt_style_ctes()
    
    print("Tests complete!")

Running simple lineage tests...

=== TEST 7: Column Extraction ===
Extracted columns: 1 select statements
  Select 0: 2 columns
    - customer_id: direct
    - customer_name: direct

=== TEST 1: Simple SELECT ===
Result: SUCCESS
Next columns: [{'table': 'customers', 'column': 'customer_id', 'context': 'External table dependency for customer_id', 'level': 'external_table', 'source_reason': 'external_source_table'}]

=== TEST 2: SELECT * ===
Result: SUCCESS
Next columns: [{'table': 'customers', 'column': 'customer_id', 'context': 'External table dependency for customer_id', 'level': 'external_table', 'source_reason': 'external_source_table'}]

=== TEST 3: Simple CTE ===
Result: SUCCESS
Next columns: [{'table': 'customers', 'column': 'customer_id', 'context': 'External source for customer_id via CTE customer_base', 'level': 'external_via_cte', 'cte_intermediate': 'customer_base'}]
CTE transformations: 1

=== TEST 4: CTE with SELECT * ===
Result: SUCCESS
Next columns: [{'table': 'customers