In [1]:
# # üîπ Cell 1 ‚Äî Imports & Setup
# import torch
# import sys
# import os
# sys.path.append("..")

# from src.schema_parser import SchemaParser
# from src.nl_parser import NLParser
# from src.semantic_aligner import SemanticAligner
# from src.schema_binder import bind_schema_tokens
# from src.ast_adapter import adapt_token_ast
# from src.ast_renderer import SQLRenderer
# from src.where_parser import WhereParser


# from src.utils import (
#     tokens_to_ids,
#     ids_to_tokens,

#     create_attention_mask,
#     get_device,
#     get_allowed_tokens  # üî• Ensure this is the updated version
# )

# from src.vocab import START, PAD, TOKEN2ID, ID2TOKEN, tokens_to_ast, AGG, VALUE,OPS
# from models.sql_transformer import SQLTransformer


import torch
import sys
import os
sys.path.append("..")

from src.schema_parser import SchemaParser
from src.nl_parser import NLParser
from src.semantic_aligner import SemanticAligner
from src.schema_binder import bind_schema_tokens
from src.ast_adapter import adapt_token_ast
from src.ast_renderer import SQLRenderer
from src.where_parser import WhereParser
from src.phase2_inference import infer_phase2_sql
from src.phase3_inference import infer_phase3_sql

from src.utils import (
    tokens_to_ids,
    ids_to_tokens,
    create_attention_mask,
    get_device,
    get_allowed_tokens
)

from src.vocab import START, PAD, TOKEN2ID, ID2TOKEN, tokens_to_ast
from models.sql_transformer import SQLTransformer


In [2]:
# Cell 2 — Load the Phase-4 JOIN model.
# The checkpoint may have been saved from a model with smaller vocabulary-sized
# layers than the current SQLTransformer (an earlier version of this cell
# mentioned indices 0-47 vs. size 49 — TODO confirm). For those layers, the
# saved rows are copied into the leading rows of the freshly initialised
# tensors. Every other tensor is taken from the checkpoint unchanged.
device = get_device()
model = SQLTransformer().to(device)

PHASE4_CKPT = "checkpoints/phase4_model.pt"
checkpoint_state = torch.load(PHASE4_CKPT, map_location=device)
model_state = model.state_dict()

# Layers whose leading dimension may differ between checkpoint and model.
mismatched_layers = ["embedding.weight", "fc_out.weight", "fc_out.bias"]

for layer_name, saved_tensor in checkpoint_state.items():
    if layer_name not in mismatched_layers:
        model_state[layer_name] = saved_tensor
        continue
    n_rows = saved_tensor.shape[0]
    # Overwrite only the first n_rows; any extra rows in the current model keep
    # their fresh initialisation.
    if saved_tensor.ndim > 1:
        model_state[layer_name][:n_rows, :] = saved_tensor
    else:
        model_state[layer_name][:n_rows] = saved_tensor

model.load_state_dict(model_state)
model.eval()
print("‚úÖ Phase-4 JOIN model loaded (Backward Compatible)")


‚úÖ Phase-4 JOIN model loaded (Backward Compatible)


In [31]:
# Cell 3 — User schema and natural-language query.
# Only tables and their columns are declared here; JOIN relationships are
# derived later from shared column names (see discover_relationships).
USER_SCHEMA = {
    "schema_id": "U1",
    "tables": {
        "employees": ["emp_id", "first_name", "dept_id", "salary", "location"],
        "departments": ["dept_id", "dept_name", "manager_id"],
        "orders": ["order_id", "customer_id", "total_amount"],
        "customers": ["customer_id", "last_name", "country"],
    },
}

# Other queries tried against this schema (swap one in below to test it):
#   show last_name and order total_amount from customers and orders
#   show first_name and dept_name from employees and departments where dept_name is ai and manager_id is 123
#   list last_name and total_amount from customers and orders where total_amount is greater than 500
#   show average salary from employees
#   get dept_name from departments where dept_id = 123
#   show average salary from employees by dept_id and location
USER_NL_QUERY = "get dept_id, dept_name from departments"

In [32]:
# Cell 4 — Discover JOIN relationships from shared column names.


def discover_relationships(
    schema_tables,
    exclude=("id", "name", "created_at", "updated_at"),
):
    """Infer join relationships between tables that share a column name.

    Every unordered pair of tables is compared. Each column name present in
    both tables yields one relationship, unless its lower-cased form is in
    ``exclude``. The default excludes generic names such as ``id`` or ``name``,
    which would produce false-positive joins.

    Args:
        schema_tables: mapping of table name -> iterable of column names.
        exclude: column names never used as join keys (compared lower-cased).

    Returns:
        A list of dicts with keys ``left_table``, ``right_table``, ``left_col``
        and ``right_col``. Table pairs follow the mapping's iteration order,
        and shared columns follow the left table's column order. The result is
        therefore deterministic across interpreter runs, which matters because
        callers take the first matching relationship.
    """
    from itertools import combinations

    excluded = {name.lower() for name in exclude}
    relationships = []
    for left_table, right_table in combinations(schema_tables, 2):
        right_cols = set(schema_tables[right_table])
        # Walk the left table's columns in declaration order (de-duplicated)
        # instead of iterating a set intersection: set iteration order for str
        # depends on hash randomisation, so the "first" join would vary per run.
        for col in dict.fromkeys(schema_tables[left_table]):
            if col in right_cols and col.lower() not in excluded:
                relationships.append({
                    "left_table": left_table,
                    "right_table": right_table,
                    "left_col": col,
                    "right_col": col,
                })
    return relationships


In [33]:
# Cell 5 — Phase-4 inference: single-table queries are routed to Phase-2/3;
# queries naming two or more schema tables are rendered as an INNER JOIN.


def _resolve_join_pair(relationships, tables):
    """Return the first relationship whose two tables are both in ``tables``, else None.

    ``relationships`` uses the dict shape produced by ``discover_relationships``
    (``left_table`` / ``right_table`` / ``left_col`` / ``right_col``).
    """
    return next(
        (
            rel
            for rel in relationships
            if rel["left_table"] in tables and rel["right_table"] in tables
        ),
        None,
    )


def _select_columns(entities, mapping, projection_text):
    """Return the qualified (``table.column``) SELECT columns, de-duplicated in order.

    A term is kept only if it:
      (a) contains an underscore (the heuristic used to recognise column names),
      (b) occurs in the projection part of the query, and
      (c) is aligned by ``mapping`` to a ``table.column`` schema term.
    """
    selected = []
    for term in entities:
        # NOTE(review): the underscore heuristic drops single-word columns
        # (e.g. "salary", "country") from JOIN projections.
        if "_" not in term or term not in projection_text:
            continue
        if term in mapping and "." in mapping[term]:
            selected.append(mapping[term])
    # dict.fromkeys removes duplicates while keeping first-seen order.
    return list(dict.fromkeys(selected))


def _build_join_clause(join_pair, base_table):
    """Build the INNER JOIN AST node that attaches the non-base table of ``join_pair``."""
    left, right = join_pair["left_table"], join_pair["right_table"]
    return {
        "type": "INNER",
        "table": right if left == base_table else left,
        "on": {
            "left": f"{left}.{join_pair['left_col']}",
            "op": "=",
            "right": f"{right}.{join_pair['right_col']}",
        },
    }


def infer_phase4_sql(schema_json, nl_query):
    """Translate a natural-language query into SQL for ``schema_json``.

    Routing:
      * The query names fewer than two distinct schema tables:
          - delegate to ``infer_phase3_sql`` when aggregation / GROUP BY /
            HAVING signals are present;
          - otherwise delegate to ``infer_phase2_sql``.
      * The query names two or more tables: build
        ``SELECT ... FROM base INNER JOIN other ON ...`` using the first
        relationship from ``discover_relationships`` that links two of them.
        Only one JOIN is emitted.

    Args:
        schema_json: dict with a ``"tables"`` mapping of table name -> column
            list (the format of ``USER_SCHEMA``).
        nl_query: the user's natural-language query.

    Returns:
        The SQL string produced by ``SQLRenderer``, or by the Phase-2/3 path.

    Raises:
        ValueError: two or more tables are named but no shared column links
            any pair of them.
    """
    schema_parser = SchemaParser(schema_json)
    schema_tables = schema_parser.get_tables()
    all_columns = schema_parser.get_all_columns()

    nl_parser = NLParser()
    signals = nl_parser.parse(nl_query)
    nl_lower = nl_query.lower()

    # Text before " from ". SELECT columns must be mentioned here so that
    # columns named only in FROM/WHERE are not projected.
    projection_text = nl_lower.split(" from ", 1)[0]

    # Schema tables named in the query, in mention order. They are de-duplicated
    # so that a table mentioned twice is not mistaken for a two-table JOIN.
    resolved_tables = list(dict.fromkeys(
        t for t in signals["entities"] if t in schema_tables
    ))

    # ---- Phase routing: only multi-table queries are handled below ----
    if len(resolved_tables) < 2:
        if signals["aggregations"] or signals.get("group_by") or signals.get("having"):
            return infer_phase3_sql(schema_json, nl_query)
        return infer_phase2_sql(schema_json, nl_query)

    # ---- JOIN resolution (done before the model decode so failures are cheap) ----
    join_pair = _resolve_join_pair(
        discover_relationships(schema_json["tables"]), resolved_tables
    )
    if join_pair is None:
        raise ValueError("‚ùå Could not resolve JOIN relationship")

    # ---- Transformer structural decode ----
    # NOTE(review): generated_ids is not consumed below; the SQL is assembled
    # from the rule-based signals. Confirm whether this decode is still needed.
    input_ids = torch.tensor([tokens_to_ids([START])], device=device)
    generated_ids = model.generate(
        input_ids=input_ids,
        attention_mask=torch.tensor(
            [create_attention_mask(input_ids[0].tolist(), PAD)],
            device=device,
        ),
        schema_tables=schema_tables,
        schema_columns=all_columns,
        intent_signals=signals,
    )

    # ---- SELECT alignment ----
    aligner = SemanticAligner()
    mapping = aligner.align(
        user_terms=signals["entities"],
        schema_terms=all_columns,
    )
    select_cols = _select_columns(signals["entities"], mapping, projection_text)

    base_table = resolved_tables[0]

    # ---- WHERE: resolved against the base table's columns (plus all_columns) ----
    where_ast = None
    if " where " in nl_lower:
        where_text = nl_lower.split(" where ", 1)[1]
        where_parser = WhereParser(nl_parser, aligner)
        where_ast = where_parser.build_tree(
            where_parser.tokenize(where_text),
            table=base_table,
            table_cols=schema_json["tables"][base_table],
            all_columns=all_columns,
        )

    # ---- AST ----
    # NOTE(review): aggregation signals are not applied on the JOIN path;
    # every projected column has agg=None.
    bound_ast = {
        "select": [{"agg": None, "column": c} for c in select_cols],
        "from": [base_table],
        "joins": [_build_join_clause(join_pair, base_table)],
        "where": where_ast or [],
    }

    return SQLRenderer().render(bound_ast)



In [34]:
# Cell 6 — Run inference on the query and schema defined in Cell 3.
sql_output = infer_phase4_sql(USER_SCHEMA, USER_NL_QUERY)

for label, text in (
    ("üß† NL Query :", USER_NL_QUERY),
    ("üßæ SQL Query:", sql_output),
):
    print(label, text)


shit1
üß† NL Query : get dept_id, dept_name from departments
üßæ SQL Query: SELECT departments.dept_id, departments.dept_name FROM departments
