In [1]:
# üîπ Cell 1 ‚Äî Imports & Load Phase-2 Model
import torch
import sys
sys.path.append("..")

from models.sql_transformer import SQLTransformer
from src.schema_parser import SchemaParser
from src.nl_parser import NLParser
from src.semantic_aligner import SemanticAligner
from src.schema_binder import bind_schema_tokens
from src.ast_renderer import SQLRenderer
from src.utils import (
    tokens_to_ids,
    ids_to_tokens,
    create_attention_mask,
    get_device
)
from src.vocab import PAD

In [2]:
# Load the Phase-2 checkpoint into a freshly built SQLTransformer.
#
# The current model's vocabulary-sized tensors may have more rows than the
# checkpoint's (earlier notes on this cell: checkpoint rows 0-47, model size 49).
# For those tensors the checkpoint rows are copied into the leading rows, and
# the extra rows keep whatever values the freshly constructed model had.
# Every other checkpoint tensor is taken unchanged.
device = get_device()
model = SQLTransformer().to(device)

PHASE2_CKPT = "checkpoints/phase2_model.pt"
# Tensors whose leading dimension is the vocabulary size and may have grown.
mismatched_layers = ["embedding.weight", "fc_out.weight", "fc_out.bias"]

# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
checkpoint_state = torch.load(PHASE2_CKPT, map_location=device)
model_state = model.state_dict()

# Fail early, with the full list, if the checkpoint has tensors the model lacks.
unexpected_keys = sorted(set(checkpoint_state) - set(model_state))
if unexpected_keys:
    raise RuntimeError(f"Checkpoint contains keys the model lacks: {unexpected_keys}")

for name, param in checkpoint_state.items():
    target = model_state[name]
    if name in mismatched_layers:
        # Only the leading (vocabulary) dimension may differ, and only grow.
        if param.shape[0] > target.shape[0] or param.shape[1:] != target.shape[1:]:
            raise ValueError(
                f"Cannot graft {name}: checkpoint shape {tuple(param.shape)} "
                f"does not fit model shape {tuple(target.shape)}"
            )
        # `[:n]` slices the leading dim for 1-D (bias) and 2-D (weight) alike.
        target[: param.shape[0]] = param
    else:
        model_state[name] = param

# Model tensors the checkpoint does not provide keep their fresh initialisation;
# say so instead of leaving a half-loaded model unnoticed.
missing_keys = sorted(set(model_state) - set(checkpoint_state))
if missing_keys:
    print(f"Warning: not found in checkpoint, left at init: {missing_keys}")

model.load_state_dict(model_state)
model.eval()

vocab_size = model_state["embedding.weight"].shape[0]
print(f"‚úÖ Phase-2  model loaded with Surgery (Size {vocab_size})")

‚úÖ Phase-2  model loaded with Surgery (Size 49)


In [3]:

# üîπ Cell 2 ‚Äî Example shown to user (important UX)
print("üìò EXAMPLE\n")

# Show each heading followed by its value. A value that ends in "\n"
# leaves a blank line before the next heading.
print(
    "Schema:",
    "employees(id, name, salary, department)\n",
    "NL Query:",
    "show name from employees where salary > 100\n",
    "Expected SQL:",
    "SELECT employees.name FROM employees WHERE employees.salary > 100",
    sep="\n",
)

üìò EXAMPLE

Schema:
employees(id, name, salary, department)

NL Query:
show name from employees where salary > 100

Expected SQL:
SELECT employees.name FROM employees WHERE employees.salary > 100


In [7]:
# üîπ Cell 3 ‚Äî USER INPUT (only this cell changes)
# Schema handed to infer_phase2_sql: table name -> ordered list of column names.
_user_tables = {
    "subject": ["id", "name", "marks"],
    "bat": ["id", "size", "weight"],
    "customers": ["id", "last_name", "age"],
    "orders": ["order_id", "customer_id", "order_date", "total_amount"],
    "products": ["product_id", "product_name", "category", "unit_price"],
    "employees": ["emp_id", "first_name", "department", "salary"],
    "inventory": ["stock_id", "product_id", "warehouse_location", "quantity"],
    "departments": ["dept_id", "dept_name", "manager_id"],
    "suppliers": ["supplier_id", "company_name", "contact_email", "country"],
    "shipments": ["ship_id", "order_id", "tracking_number", "status"],
}
USER_SCHEMA = dict(schema_id="U1", tables=_user_tables)

# Natural-language request to translate. Other queries to try:
#   get product_id, stock_id from inventory where quantity > 100
#   list company_name from suppliers where country is USA and supplier_id = 123
USER_NL_QUERY = "list company_name, supplier_id from suppliers"

In [8]:
# Cell 4 -- Phase-2 inference function (core logic).
#
# The function is imported from src/phase2_inference.py. This cell previously
# also held an older inline draft of it (schema parsing -> NL parsing -> table
# and SELECT-column resolution -> WHERE parsing -> SQL rendering) inside a
# disabled, malformed string literal that had no effect at runtime; it has been
# removed. NOTE(review): that draft presumably mirrors the module -- confirm
# against src/phase2_inference.py (and version control) before relying on it.
#
# NOTE(review): infer_phase2_sql is called with only (schema, nl_query); the
# SQLTransformer loaded earlier in this notebook is never passed to it --
# confirm whether the neural model takes part in inference at all.
from src.phase2_inference import infer_phase2_sql

In [9]:
# Cell 5 -- Run Phase-2 inference on USER_SCHEMA / USER_NL_QUERY.
# NOTE(review): only the schema and the NL query are passed; the model loaded
# earlier in this notebook is not an argument here.
sql_output = infer_phase2_sql(USER_SCHEMA, USER_NL_QUERY)

print("üß† NL Query :", USER_NL_QUERY)
print("üßæ SQL Query:", sql_output)

üß† NL Query : list company_name, supplier_id from suppliers
üßæ SQL Query: SELECT suppliers.company_name, suppliers.supplier_id FROM suppliers
