In [1]:
# 🔹 Cell 1 — Imports & Setup
import torch
import sys
import os
sys.path.append("..")

from src.schema_parser import SchemaParser
from src.nl_parser import NLParser
from src.semantic_aligner import SemanticAligner
from src.schema_binder import bind_schema_tokens
from src.ast_adapter import adapt_token_ast
from src.ast_renderer import SQLRenderer
from src.where_parser import WhereParser
from src.phase2_inference import infer_phase2_sql
from src.phase3_inference import infer_phase3_sql

from src.utils import (
    tokens_to_ids,
    ids_to_tokens,
    create_attention_mask,
    get_device,
    get_allowed_tokens
)

from src.vocab import START, PAD, TOKEN2ID, ID2TOKEN, tokens_to_ast,ast_to_tokens
from models.sql_transformer import SQLTransformer


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# 🔹 Cell 2 — Load Phase-4.5 Model (FIXED)
# Build the SQL transformer on the selected device and restore the Phase-4.5
# JOIN checkpoint. The checkpoint is read as a dict whose weights live under
# the "model_state_dict" key.
device = get_device()
model = SQLTransformer().to(device)

# NOTE(review): relative path — resolved against the kernel's current working
# directory. Cell 1 appends ".." to sys.path, which suggests the kernel runs
# inside a sub-directory; confirm this path resolves from there.
PHASE45_CKPT = "notebooks/checkpoints/phase4_5_join/phase4_5_best.pt"

# map_location remaps the stored tensors onto `device` (e.g. GPU-saved -> CPU).
# NOTE(review): depending on the torch version / weights_only setting,
# torch.load may unpickle arbitrary objects — only load trusted checkpoints.
ckpt = torch.load(PHASE45_CKPT, map_location=device)
model.load_state_dict(ckpt["model_state_dict"])

# Switch modules such as dropout to evaluation behaviour for inference.
# NOTE(review): none of the visible cells below reference `model` directly;
# confirm it is consumed (e.g. inside the phase-2/3 inference helpers).
model.eval()
print("‚úÖ Phase-4.5 JOIN model loaded (Backward Compatible)")


‚úÖ Phase-4.5 JOIN model loaded (Backward Compatible)


In [3]:
# 🔹 Cell 3 — User Schema (PK–FK EXPLICIT) 🔥
# Schema format read by the discovery / inference cells below:
#   "tables": {
#       <table_name>: {
#           "pk": <primary-key column name>,
#           "columns": [<column names>],
#           "fk": {<local column>: "<table>.<column>"}   # optional
#       }
#   }
# Tables without an "fk" entry (e.g. employees.dept_id) are still joinable:
# PK–FK discovery also links any column that is named like another table's "pk".
USER_SCHEMA = {
    "schema_id": "U1",
    "tables": {

        "employees": {
            "pk": "emp_id",
            "columns": ["emp_id", "first_name", "dept_id", "salary", "location"]
        },

        "departments": {
            "pk": "dept_id",
            "columns": ["dept_id", "dept_name", "manager_id"]
        },

        "customers": {
            "pk": "customer_id",
            "columns": ["customer_id", "last_name", "country"]
        },

        "orders": {
            "pk": "order_id",
            "columns": ["order_id", "customer_id", "total_amount"],
            "fk": {
                "customer_id": "customers.customer_id"
            }
        },

        "projects": {
            "pk": "project_id",
            "columns": ["project_id", "project_name", "dept_id", "budget"],
            "fk": {
                "dept_id": "departments.dept_id"
            }
        },

        "assignments": {
            "pk": "assignment_id",
            "columns": ["assignment_id", "emp_id", "project_id", "hours_worked"],
            "fk": {
                "emp_id": "employees.emp_id",
                "project_id": "projects.project_id"
            }
        }
    }
}

# Query under test. Alternatives are kept commented out below, grouped by the
# SQL feature they exercise; if several are uncommented, the last one wins.

# Single-table projection
# USER_NL_QUERY = "get project_id from projects"
# USER_NL_QUERY = "get dept_id from departments"
USER_NL_QUERY = "get dept_id and dept_name from departments"

# Single-table WHERE
# USER_NL_QUERY = "get dept_name from departments where dept_id = 123"
# USER_NL_QUERY = "show first_name from employees where dept_id = 123 and location is pune"

# Single-table aggregation / GROUP BY / HAVING
# USER_NL_QUERY = "show average salary from employees"
# USER_NL_QUERY = "show average salary from employees by dept_id and location"
# USER_NL_QUERY = "show average salary from employees by dept_id having average salary > 5000"

# Inner join (+ WHERE)
# USER_NL_QUERY = "show last_name and total_amount from customers and orders"
# USER_NL_QUERY = "list last_name and total_amount from customers and orders where total_amount is greater than 500"
# USER_NL_QUERY = "show first_name and dept_name from employees and departments where dept_name is ai and manager_id is 123"

# join+groupby+having
# USER_NL_QUERY = "show average salary by dept_id for employees and departments"
# USER_NL_QUERY = "show last_name with count order_id above 5 for customers and orders"
# USER_NL_QUERY = "show dept_name with average salary above 50000 for employees and departments"

# Left Join
# USER_NL_QUERY = "show last_name with count of order_id including customers with zero orders"
# USER_NL_QUERY = "list all employees and their departments if assigned"

# Right Join
# USER_NL_QUERY = "show first_name and dept_name for employees and their departments including departments without employees"
#17 USER_NL_QUERY = "list dept_name and first_name from employees and departments including departments with no employees"
# USER_NL_QUERY = "show first_name and dept_name from employees and departments where manager_id is 123 including departments without employees"
# USER_NL_QUERY = "show first_name and dept_name from employees and departments where dept_id is greater than 100 including departments without employees"
#20 USER_NL_QUERY = "show dept_name with count emp_id for employees and their departments including departments without employees"
#21 USER_NL_QUERY = "show dept_name with count emp_id above 5 for employees and their departments including departments without employees"


In [4]:
# 🔹 Cell 4 — PK–FK Relationship Discovery (REPLACEMENT)
def discover_pk_fk_relationships(schema_json):
    """
    Discover join relationships between the tables of a user schema.

    Two sources are combined, in priority order:
      1. Explicit ``"fk"`` mappings (``{fk_col: "ref_table.ref_col"}``).
      2. Implicit matches: table T1 has a column whose name equals table
         T2's ``"pk"`` (T1.pk_name -> T2.pk_name).

    Each relationship is a dict with keys ``left_table``, ``left_col``
    (the referencing side), ``right_table`` and ``right_col``. Identical
    relationships are reported once, keeping the first occurrence. This means
    an implicit match that merely restates an explicit FK does not appear twice,
    and explicit FKs stay ahead of implicit ones in the result.

    Args:
        schema_json: dict with a ``"tables"`` mapping of table name to
            ``{"pk": str, "columns": [...], "fk": {...}}``. ``"pk"``,
            ``"columns"`` and ``"fk"`` may each be missing.

    Returns:
        list[dict]: relationships in discovery order.

    Raises:
        ValueError: if an explicit FK reference is not of the form
            ``"table.column"``.
    """
    tables = schema_json["tables"]
    relationships = []
    seen = set()

    def _add(left_table, left_col, right_table, right_col):
        # Record each relationship once; first occurrence wins.
        key = (left_table, left_col, right_table, right_col)
        if key in seen:
            return
        seen.add(key)
        relationships.append({
            "left_table": left_table,
            "left_col": left_col,
            "right_table": right_table,
            "right_col": right_col
        })

    # 1) Explicit FK definitions (highest priority)
    for table, meta in tables.items():
        for fk_col, ref in (meta.get("fk") or {}).items():
            # The column is after the LAST dot, so a schema-qualified target
            # such as "sales.customers.id" still parses.
            ref_table, dot, ref_col = ref.rpartition(".")
            if not dot or not ref_table or not ref_col:
                raise ValueError(
                    f"Invalid FK reference {ref!r} for {table}.{fk_col}; "
                    "expected 'table.column'"
                )
            _add(table, fk_col, ref_table, ref_col)

    # 2) Implicit PK–FK by column name: a column in t1 named like t2's PK
    for t1, m1 in tables.items():
        columns = m1.get("columns", [])
        for t2, m2 in tables.items():
            if t1 == t2:
                continue
            pk = m2.get("pk")
            if pk and pk in columns:
                _add(t1, pk, t2, pk)

    return relationships


In [5]:
def build_schema_bindings(ast, schema_json, signals, nl_query):
    """
    Bind abstract tokens to concrete schema elements using NL intent and
    the tables already chosen in ``ast``.

    Args:
        ast: query AST dict. Reads ``"from"`` (first entry is the base
            table), optional ``"joins"`` (each entry has a ``"table"`` key)
            and optional ``"select"`` (list of dicts with an ``"agg"`` key).
        schema_json: schema dict with ``"tables"`` mapping table name to
            ``{"pk": str, "columns": [...]}``.
        signals: NLParser output; only ``signals["entities"]`` is read.
        nl_query: the raw natural-language query.

    Returns:
        dict with keys:
          ``"<TABLE>"``  – base table followed by every joined table;
          ``"<COLUMN>"`` – de-duplicated ``"table.column"`` strings whose
                           aligned user term occurs in the projection part
                           of the query (text before ``" where "``). Falls
                           back to the base table's PK, or its first column
                           when no PK is declared;
          ``"<AGG>"``    – aggregation of the first SELECT item, or None.
    """
    base_table = ast["from"][0]

    # ---------- TABLE ----------
    # Single-table queries simply have no joins.
    tables = [base_table] + [j["table"] for j in (ast.get("joins") or [])]

    # ---------- COLUMNS (projection-aware) ----------
    aligner = SemanticAligner()
    mapping = aligner.align(
        user_terms=signals["entities"],
        schema_terms=[
            f"{t}.{c}"
            for t, meta in schema_json["tables"].items()
            for c in meta["columns"]
        ]
    )

    # Only the text before WHERE says what to return. Columns named solely
    # in a filter must not become projection columns (same split as
    # infer_phase4_sql uses for its projection text).
    projection_text = nl_query.lower().split(" where ")[0]

    columns = [
        col for term, col in mapping.items()
        if term.lower() in projection_text
    ]

    # Fallback: select PK (or first column) if nothing resolved
    if not columns:
        base_meta = schema_json["tables"][base_table]
        pk = base_meta.get("pk") or base_meta["columns"][0]
        columns.append(f"{base_table}.{pk}")

    # ---------- AGG ----------
    select = ast.get("select") or []
    agg = select[0].get("agg") if select else None

    return {
        "<TABLE>": tables,
        # Order-preserving de-duplication
        "<COLUMN>": list(dict.fromkeys(columns)),
        "<AGG>": agg or None
    }


In [None]:
def deduplicate_select_columns(ast):
    """
    Remove repeated SELECT items from ``ast`` in place, keeping the first
    occurrence of each.

    Two items are duplicates only when both their ``"column"`` and their
    ``"agg"`` match. A plain column and an aggregate over the same column
    (e.g. ``salary`` and ``AVG(salary)``) are therefore both kept.

    Args:
        ast: query AST dict. ``ast["select"]`` is a list of dicts with a
            ``"column"`` key and an optional ``"agg"`` key; a missing or
            ``None`` select is treated as empty.

    Returns:
        None. ``ast["select"]`` is replaced with the de-duplicated list.
    """
    seen = set()
    unique_select = []

    for item in ast.get("select") or []:
        key = (item.get("agg"), item["column"])
        if key not in seen:
            seen.add(key)
            unique_select.append(item)

    ast["select"] = unique_select


In [None]:
# 🔹 Cell 5 — Phase-4 inference (updated for RIGHT JOIN)
def infer_phase4_sql(schema_json, nl_query):
    """
    Phase-4.5 NL -> SQL for two-table JOIN queries (INNER / LEFT / RIGHT),
    optionally with GROUP BY and HAVING.

    Routing:
      * fewer than two schema tables resolved -> ``infer_phase3_sql`` when
        NLParser reports aggregation / group-by / having signals, otherwise
        ``infer_phase2_sql``;
      * no aggregation word in the query -> plain JOIN (branch A);
      * otherwise -> JOIN + GROUP BY (+ HAVING) (branch B).

    Only the first two resolved tables are joined, using the first matching
    relationship returned by ``discover_pk_fk_relationships``.

    Args:
        schema_json: schema dict with ``"tables"`` mapping table name to
            ``{"pk": str, "columns": [...], "fk": {...}}``.
        nl_query: natural-language question.

    Returns:
        The value of ``SQLRenderer().render(ast)``, or of the phase-2/3
        inference functions when the query is routed there.

    Raises:
        ValueError: if the two tables share no PK-FK relationship, or the
            GROUP BY / aggregation column cannot be resolved.
    """
    import re

    nl_lower = nl_query.lower()
    nl_parser = NLParser()
    signals = nl_parser.parse(nl_query)

    schema_tables = schema_json["tables"]
    # Keep first-mention order but drop repeats, so a table named twice is
    # never "joined" to itself.
    resolved_tables = list(dict.fromkeys(
        t for t in signals["tables"] if t in schema_tables
    ))

    # ----------------------------
    # Phase routing
    # ----------------------------
    if len(resolved_tables) < 2:
        if signals["aggregations"] or signals["group_by"] or signals["having"]:
            return infer_phase3_sql(schema_json, nl_query)
        return infer_phase2_sql(schema_json, nl_query)

    base_table, join_table = resolved_tables[:2]

    # ----------------------------
    # TRUE aggregation intent
    # ----------------------------
    # Whole-word matching: plain substring tests would read words such as
    # "country" or "discount" as COUNT. "how many" / "number of" are counts.
    count_intent = re.search(r"\b(?:count|number of|how many)\b", nl_lower) is not None
    avg_intent = re.search(r"\b(?:average|avg)\b", nl_lower) is not None
    plain_join = not (count_intent or avg_intent)

    # ----------------------------
    # Discover JOIN (schema-agnostic)
    # ----------------------------
    pk_fk = discover_pk_fk_relationships(schema_json)
    rel = next(
        (
            r for r in pk_fk
            if {r["left_table"], r["right_table"]} == {base_table, join_table}
        ),
        None
    )
    if rel is None:
        # A bare next() would leak StopIteration out of this function.
        raise ValueError(
            f"No PK-FK relationship found between {base_table} and {join_table}"
        )

    join_type = signals.get("join_type", "INNER")
    preserve_table = signals.get("preserve_table")

    # The table whose rows must all be kept decides the outer-join side.
    if preserve_table:
        if preserve_table == base_table:
            join_type = "LEFT"
        elif preserve_table == join_table:
            join_type = "RIGHT"

    ast = {
        "select": [],
        "from": [base_table],
        "joins": [{
            "type": join_type,
            "table": join_table,
            "on": {
                "left": f"{rel['left_table']}.{rel['left_col']}",
                "op": "=",
                "right": f"{rel['right_table']}.{rel['right_col']}",
                "extra_conditions": []
            }
        }],
        "where": None,
        "group_by": [],
        "having": []
    }

    # ============================================================
    # A. PLAIN JOIN (NO AGG)
    # ============================================================
    if plain_join:
        projection_text = nl_lower.split(" where ")[0]

        for t in resolved_tables:
            for c in schema_json["tables"][t]["columns"]:
                if c in projection_text:
                    ast["select"].append({
                        "agg": None,
                        "column": f"{t}.{c}"
                    })

        if not ast["select"]:
            # Nothing named explicitly: show one readable (non *_id) column
            # per table, falling back to the table's first column.
            for t in resolved_tables:
                cols = schema_json["tables"][t]["columns"]
                readable = next((c for c in cols if not c.endswith("_id")), cols[0])
                ast["select"].append({
                    "agg": None,
                    "column": f"{t}.{readable}"
                })

        if " where " in nl_lower:
            where_parser = WhereParser(nl_parser, SemanticAligner())
            where_ast = where_parser.build_tree(
                # Split on the same padded keyword tested above, so a word
                # that merely contains "where" cannot cut the text early.
                where_parser.tokenize(nl_lower.split(" where ", 1)[1]),
                base_table,
                schema_json["tables"][base_table]["columns"],
                [
                    f"{t}.{c}"
                    for t in resolved_tables
                    for c in schema_json["tables"][t]["columns"]
                ]
            )

            # ------------------------------------------------------------
            # Outer-join WHERE safety: a WHERE filter on the nullable side
            # would discard the preserved unmatched rows, so such conditions
            # are moved into the JOIN ... ON clause instead.
            # ------------------------------------------------------------
            if join_type in ("LEFT", "RIGHT") and preserve_table and where_ast:
                nullable = join_table if preserve_table == base_table else base_table

                safe = []
                conditions = where_ast if isinstance(where_ast, list) else [where_ast]
                for cond in conditions:
                    # assumes each condition is a dict with a qualified
                    # "column" key — TODO confirm for nested AND/OR nodes
                    if cond["column"].startswith(f"{nullable}."):
                        ast["joins"][0]["on"]["extra_conditions"].append(cond)
                    else:
                        safe.append(cond)

                ast["where"] = safe or None
            else:
                ast["where"] = where_ast

        return SQLRenderer().render(ast)

    # ============================================================
    # B. AGGREGATION (JOIN + GROUP BY + HAVING)
    # ============================================================
    # NOTE(review): a WHERE clause in an aggregation query is not applied
    # here, and the first number in the query becomes the HAVING threshold.

    agg_func = "COUNT" if count_intent else "AVG"

    # Column named right after the aggregation phrase ("average salary",
    # "count emp_id", "count of order_id"), if a joined table has it.
    target_match = re.search(
        r"\b(?:average|avg|count|number of|how many)(?:\s+of)?\s+(\w+)", nl_lower
    )
    agg_target = target_match.group(1) if target_match else None
    target_col = next(
        (
            f"{t}.{agg_target}" for t in resolved_tables
            if agg_target in schema_json["tables"][t]["columns"]
        ),
        None
    )

    # ----------------------------
    # GROUP BY resolution
    # ----------------------------
    group_col = None

    # 1) Explicit "by <column>": first column (table order) found after "by"
    if " by " in nl_lower:
        after_by = nl_lower.split(" by ", 1)[1]
        group_col = next(
            (
                f"{t}.{c}"
                for t in resolved_tables
                for c in schema_json["tables"][t]["columns"]
                if c in after_by
            ),
            None
        )

    # 2) Otherwise the first readable projection column that is NOT the
    #    aggregation target (avoids "AVG(salary) ... GROUP BY salary").
    if not group_col:
        projection_text = nl_lower.split(" where ")[0]
        group_col = next(
            (
                f"{t}.{c}"
                for t in resolved_tables
                for c in schema_json["tables"][t]["columns"]
                if c in projection_text
                and not c.endswith("_id")
                and c != agg_target
            ),
            None
        )

    if not group_col:
        raise ValueError("‚ùå GROUP BY column not resolved")

    ast["group_by"] = [group_col]

    # ----------------------------
    # Aggregation column (schema-agnostic)
    # ----------------------------
    if target_col is not None:
        # Honour the explicitly named column; for outer joins counting the
        # nullable side's column yields 0 (not 1) for unmatched rows.
        agg_col = target_col
    elif agg_func == "COUNT":
        pk = schema_json["tables"][join_table]["pk"]
        agg_col = f"{join_table}.{pk}"
    else:
        agg_col = next(
            (
                f"{t}.{c}"
                for t in resolved_tables
                for c in schema_json["tables"][t]["columns"]
                if c in nl_lower
                and not c.endswith("_id")
                and f"{t}.{c}" != group_col
            ),
            None
        )

        if not agg_col:
            raise ValueError("‚ùå Aggregation column not resolved")

    ast["select"] = [
        {"agg": None, "column": group_col},
        {"agg": agg_func, "column": agg_col}
    ]

    # ----------------------------
    # HAVING
    # ----------------------------
    if signals["numbers"]:
        ast["having"] = {
            "agg": agg_func,
            "column": agg_col,
            "op": ">",
            "value": int(signals["numbers"][0])
        }

    return SQLRenderer().render(ast)



# working for all 15 user nl queries till left join
# def infer_phase4_sql(schema_json, nl_query):
#     # ============================================================
#     # Phase-4.5 NL ‚Üí SQL (FINAL ‚Äî ALL 15 USER_NL_QUERY PASS)
#     # ============================================================

#     nl_lower = nl_query.lower()
#     nl_parser = NLParser()
#     signals = nl_parser.parse(nl_query)

#     schema_tables = list(schema_json["tables"].keys())
#     resolved_tables = [t for t in signals["tables"] if t in schema_tables]

#     # ----------------------------
#     # Phase routing
#     # ----------------------------
#     if len(resolved_tables) < 2:
#         if signals["aggregations"] or signals["group_by"] or signals["having"]:
#             return infer_phase3_sql(schema_json, nl_query)
#         return infer_phase2_sql(schema_json, nl_query)

#     base_table, join_table = resolved_tables[:2]

#     # ----------------------------
#     # TRUE aggregation intent
#     # ----------------------------
#     agg_verbs = ["average", "avg", "count", "number of", "how many"]
#     has_agg_intent = any(v in nl_lower for v in agg_verbs)
#     plain_join = not has_agg_intent

#     # ----------------------------
#     # Discover JOIN
#     # ----------------------------
#     pk_fk = discover_pk_fk_relationships(schema_json)
#     rel = next(
#         r for r in pk_fk
#         if {r["left_table"], r["right_table"]} == {base_table, join_table}
#     )

#     join_type = signals.get("join_type", "INNER")

#     ast = {
#         "select": [],
#         "from": [base_table],
#         "joins": [{
#             "type": join_type,
#             "table": join_table,
#             "on": {
#                 "left": f"{rel['left_table']}.{rel['left_col']}",
#                 "op": "=",
#                 "right": f"{rel['right_table']}.{rel['right_col']}",
#                 "extra_conditions": []
#             }
#         }],
#         "where": None,
#         "group_by": [],
#         "having": []
#     }

#     # ============================================================
#     # A. PLAIN JOIN (NO AGG)
#     # ============================================================
#     if plain_join:
#         projection_text = nl_lower.split(" where ")[0]

#         for t in resolved_tables:
#             for c in schema_json["tables"][t]["columns"]:
#                 if c in projection_text:
#                     ast["select"].append({
#                         "agg": None,
#                         "column": f"{t}.{c}"
#                     })

#         if not ast["select"]:
#             for t in resolved_tables:
#                 cols = schema_json["tables"][t]["columns"]
#                 readable = next((c for c in cols if not c.endswith("_id")), cols[0])
#                 ast["select"].append({
#                     "agg": None,
#                     "column": f"{t}.{readable}"
#                 })

#         if " where " in nl_lower:
#             where_parser = WhereParser(nl_parser, SemanticAligner())
#             where_ast = where_parser.build_tree(
#                 where_parser.tokenize(nl_lower.split("where", 1)[1]),
#                 base_table,
#                 schema_json["tables"][base_table]["columns"],
#                 [
#                     f"{t}.{c}"
#                     for t in resolved_tables
#                     for c in schema_json["tables"][t]["columns"]
#                 ]
#             )
#             ast["where"] = where_ast

#         return SQLRenderer().render(ast)

#     # ============================================================
#     # B. AGGREGATION (JOIN + GROUP BY + HAVING)
#     # ============================================================

#     agg_func = "COUNT" if "count" in nl_lower else "AVG"

#     # ----------------------------
#     # GROUP BY resolution
#     # ----------------------------
#     group_col = None

#     # 1Ô∏è‚É£ Strict "by" parsing
#     if " by " in nl_lower:
#         after_by = nl_lower.split(" by ", 1)[1]
#         for t in resolved_tables:
#             for c in schema_json["tables"][t]["columns"]:
#                 if c in after_by:
#                     group_col = f"{t}.{c}"
#                     break
#             if group_col:
#                 break

#     # 2Ô∏è‚É£ If no "by", use non-aggregation projection column
#     if not group_col:
#         projection_text = nl_lower.split(" where ")[0]
#         for t in resolved_tables:
#             for c in schema_json["tables"][t]["columns"]:
#                 if (
#                     c in projection_text
#                     and not c.endswith("_id")
#                     and c not in ["salary", "order_id"]
#                 ):
#                     group_col = f"{t}.{c}"
#                     break
#             if group_col:
#                 break

#     if not group_col:
#         raise ValueError("‚ùå GROUP BY column not resolved")

#     ast["group_by"] = [group_col]

#     # ----------------------------
#     # Aggregation column
#     # ----------------------------
#     if agg_func == "COUNT":
#         pk = schema_json["tables"][join_table]["pk"]
#         agg_col = f"{join_table}.{pk}"
#     else:
#         agg_col = "employees.salary"

#     ast["select"] = [
#         {"agg": None, "column": group_col},
#         {"agg": agg_func, "column": agg_col}
#     ]

#     # ----------------------------
#     # HAVING
#     # ----------------------------
#     if signals["numbers"]:
#         ast["having"] = {
#             "agg": agg_func,
#             "column": agg_col,
#             "op": ">",
#             "value": int(signals["numbers"][0])
#         }

#     return SQLRenderer().render(ast)

In [None]:

# 🔹 Cell 6 — Run Inference
# Translate the active USER_NL_QUERY (Cell 3) against USER_SCHEMA with the
# Phase-4 router defined above, then echo the query and the generated SQL.
sql_output = infer_phase4_sql(USER_SCHEMA, USER_NL_QUERY)

print("üß† NL Query :", USER_NL_QUERY)
print("üßæ SQL Query:", sql_output)


In [None]:
# # # CHATGPTS working for all 15 user nl queries

# def infer_phase4_sql(schema_json, nl_query):
#     # ============================================================
#     # Phase-4.5 NL ‚Üí SQL (FINAL ‚Äî ALL 15 USER_NL_QUERY PASS)
#     # ============================================================

#     nl_lower = nl_query.lower()
#     nl_parser = NLParser()
#     signals = nl_parser.parse(nl_query)

#     schema_tables = list(schema_json["tables"].keys())
#     resolved_tables = [t for t in signals["tables"] if t in schema_tables]

#     # ----------------------------
#     # Phase routing
#     # ----------------------------
#     if len(resolved_tables) < 2:
#         if signals["aggregations"] or signals["group_by"] or signals["having"]:
#             return infer_phase3_sql(schema_json, nl_query)
#         return infer_phase2_sql(schema_json, nl_query)

#     base_table, join_table = resolved_tables[:2]

#     # ----------------------------
#     # TRUE aggregation intent
#     # ----------------------------
#     agg_verbs = ["average", "avg", "count", "number of", "how many"]
#     has_agg_intent = any(v in nl_lower for v in agg_verbs)
#     plain_join = not has_agg_intent

#     # ----------------------------
#     # Discover JOIN
#     # ----------------------------
#     pk_fk = discover_pk_fk_relationships(schema_json)
#     rel = next(
#         r for r in pk_fk
#         if {r["left_table"], r["right_table"]} == {base_table, join_table}
#     )

#     join_type = signals.get("join_type", "INNER")

#     ast = {
#         "select": [],
#         "from": [base_table],
#         "joins": [{
#             "type": join_type,
#             "table": join_table,
#             "on": {
#                 "left": f"{rel['left_table']}.{rel['left_col']}",
#                 "op": "=",
#                 "right": f"{rel['right_table']}.{rel['right_col']}",
#                 "extra_conditions": []
#             }
#         }],
#         "where": None,
#         "group_by": [],
#         "having": []
#     }

#     # ============================================================
#     # A. PLAIN JOIN (NO AGG)
#     # ============================================================
#     if plain_join:
#         projection_text = nl_lower.split(" where ")[0]

#         for t in resolved_tables:
#             for c in schema_json["tables"][t]["columns"]:
#                 if c in projection_text:
#                     ast["select"].append({
#                         "agg": None,
#                         "column": f"{t}.{c}"
#                     })

#         if not ast["select"]:
#             for t in resolved_tables:
#                 cols = schema_json["tables"][t]["columns"]
#                 readable = next((c for c in cols if not c.endswith("_id")), cols[0])
#                 ast["select"].append({
#                     "agg": None,
#                     "column": f"{t}.{readable}"
#                 })

#         if " where " in nl_lower:
#             where_parser = WhereParser(nl_parser, SemanticAligner())
#             where_ast = where_parser.build_tree(
#                 where_parser.tokenize(nl_lower.split("where", 1)[1]),
#                 base_table,
#                 schema_json["tables"][base_table]["columns"],
#                 [
#                     f"{t}.{c}"
#                     for t in resolved_tables
#                     for c in schema_json["tables"][t]["columns"]
#                 ]
#             )
#             ast["where"] = where_ast

#         return SQLRenderer().render(ast)

#     # ============================================================
#     # B. AGGREGATION (JOIN + GROUP BY + HAVING)
#     # ============================================================

#     agg_func = "COUNT" if "count" in nl_lower else "AVG"

#     # ----------------------------
#     # GROUP BY resolution
#     # ----------------------------
#     group_col = None

#     # 1Ô∏è‚É£ Strict "by" parsing
#     if " by " in nl_lower:
#         after_by = nl_lower.split(" by ", 1)[1]
#         for t in resolved_tables:
#             for c in schema_json["tables"][t]["columns"]:
#                 if c in after_by:
#                     group_col = f"{t}.{c}"
#                     break
#             if group_col:
#                 break

#     # 2Ô∏è‚É£ If no "by", use non-aggregation projection column
#     if not group_col:
#         projection_text = nl_lower.split(" where ")[0]
#         for t in resolved_tables:
#             for c in schema_json["tables"][t]["columns"]:
#                 if (
#                     c in projection_text
#                     and not c.endswith("_id")
#                     and c not in ["salary", "order_id"]
#                 ):
#                     group_col = f"{t}.{c}"
#                     break
#             if group_col:
#                 break

#     if not group_col:
#         raise ValueError("‚ùå GROUP BY column not resolved")

#     ast["group_by"] = [group_col]

#     # ----------------------------
#     # Aggregation column
#     # ----------------------------
#     if agg_func == "COUNT":
#         pk = schema_json["tables"][join_table]["pk"]
#         agg_col = f"{join_table}.{pk}"
#     else:
#         agg_col = "employees.salary"

#     ast["select"] = [
#         {"agg": None, "column": group_col},
#         {"agg": agg_func, "column": agg_col}
#     ]

#     # ----------------------------
#     # HAVING
#     # ----------------------------
#     if signals["numbers"]:
#         ast["having"] = {
#             "agg": agg_func,
#             "column": agg_col,
#             "op": ">",
#             "value": int(signals["numbers"][0])
#         }

#     return SQLRenderer().render(ast)


