In [1]:
import torch
import sys
sys.path.append("..")

from models.sql_transformer import SQLTransformer
from src.schema_parser import SchemaParser
from src.nl_parser import NLParser
from src.semantic_aligner import SemanticAligner
from src.schema_binder import bind_schema_tokens
from src.ast_renderer import SQLRenderer
from src.utils import (
    tokens_to_ids,
    ids_to_tokens,
    create_attention_mask,
    get_device
)
from src.vocab import PAD


In [2]:
# Load model
device = get_device()

model = SQLTransformer().to(device)
model.load_state_dict(
    torch.load("checkpoints/phase1_model.pt", map_location=device)
)
model.eval()

print("‚úÖ Phase-1 model loaded")


‚úÖ Phase-1 model loaded


In [3]:
print("üìò EXAMPLE FOR USER\n")

print("Schema:")
print("employees(id, name, salary, department)")
print("orders(id, amount, date)\n")

print("NL Query:")
print("show salary from employees\n")

print("Expected SQL:")
print("SELECT employees.salary FROM employees")


üìò EXAMPLE FOR USER

Schema:
employees(id, name, salary, department)
orders(id, amount, date)

NL Query:
show salary from employees

Expected SQL:
SELECT employees.salary FROM employees


In [22]:
# =========================
# USER INPUT
# =========================
USER_SCHEMA = {
    "schema_id": "U1",
    "tables": {
        "subject": ["id", "name", "marks"],
        "bat": ["id", "size", "weight"],
        "customers": ["id", "last_name", "age"],
        "orders": ["order_id", "customer_id", "order_date", "total_amount"],
        "products": ["product_id", "product_name", "category", "unit_price"],
        "employees": ["emp_id", "first_name", "department", "salary"],
        "inventory": ["stock_id", "product_id", "warehouse_location", "quantity"],
        "departments": ["dept_id", "dept_name", "manager_id"],
        "suppliers": ["supplier_id", "company_name", "contact_email", "country"],
        "shipments": ["ship_id", "order_id", "tracking_number", "status"]
    }
}

USER_NL_QUERY = "list salry from employees"



In [23]:
def infer_phase1_sql(schema_json, nl_query):
    # 1Ô∏è‚É£ Schema parsing (rule-based)
    schema_parser = SchemaParser(schema_json)
    tables = schema_parser.get_tables()
    columns = schema_parser.get_all_columns()

    # 2Ô∏è‚É£ NL parsing
    nl_parser = NLParser()
    signals = nl_parser.parse(nl_query)

    # 3Ô∏è‚É£ Resolve TABLE from NL (HARD rule)
    resolved_table = None
    for t in signals["entities"]:
        if t in tables:
            resolved_table = t
            break

    if resolved_table is None:
        raise ValueError("‚ùå Failed to resolve table from NL query")

    table_columns = schema_json["tables"][resolved_table]

    # 4Ô∏è‚É£ Semantic alignment
    aligner = SemanticAligner()
    mapping = aligner.align(
        user_terms=signals["entities"],
        schema_terms=tables + columns,
        column_terms=columns
    )

    # 5Ô∏è‚É£ Resolve MULTIPLE SELECT columns (Phase-1 safe)
    select_columns = []

    # --- semantic alignment first (table constrained)
    for val in mapping.values():
        if val.startswith(resolved_table + "."):
            col = val.split(".", 1)[1]
            if col in table_columns:
                select_columns.append(val)

    # --- fallback: direct NL match
    for t in signals["entities"]:
        if t in table_columns:
            col = f"{resolved_table}.{t}"
            if col not in select_columns:
                select_columns.append(col)

    if not select_columns:
        raise ValueError("‚ùå Failed to resolve any SELECT column")

    # 6Ô∏è‚É£ Schema bindings (structure-only)
    # NOTE: model still outputs ONE <COLUMN>, we ignore it
    schema_bindings = {
        "<TABLE>": resolved_table,
        "<COLUMN>": select_columns[0]  # dummy for Phase-1 decoder
    }

    # 7Ô∏è‚É£ Model input (structure only)
    input_ids = torch.tensor(
        [tokens_to_ids(["<START>"])],
        device=device
    )

    attention_mask = torch.tensor(
        [create_attention_mask(input_ids[0].tolist(), PAD)],
        device=device
    )

    # 8Ô∏è‚É£ Generate SQL structure
    output_ids = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        schema_tables=tables,
        schema_columns=columns,
        max_len=10
    )

    # 9Ô∏è‚É£ Render SQL (MULTI-COLUMN SELECT)
    renderer = SQLRenderer()
    sql = renderer.render({
        "select": [
            {"agg": None, "column": col}
            for col in select_columns
        ],
        "from": [resolved_table],
        "where": [],
        "group_by": [],
        "having": []
    })

    return sql

In [24]:
sql_output = infer_phase1_sql(USER_SCHEMA, USER_NL_QUERY)

print("üß† NL Query :", USER_NL_QUERY)
print("üßæ SQL Query:", sql_output)


ValueError: ‚ùå Failed to resolve any SELECT column

In [None]:
def infer_phase1_sql(schema_json, nl_query):
    # 1Ô∏è‚É£ Schema parsing (rule-based)
    schema_parser = SchemaParser(schema_json)
    tables = schema_parser.get_tables()
    columns = schema_parser.get_all_columns()

    # 2Ô∏è‚É£ NL parsing
    nl_parser = NLParser()
    signals = nl_parser.parse(nl_query)

    # 3Ô∏è‚É£ Resolve TABLE from NL (HARD rule)
    resolved_table = None
    for t in signals["entities"]:
        if t in tables:
            resolved_table = t
            break

    if resolved_table is None:
        raise ValueError("‚ùå Failed to resolve table from NL query")

    table_columns = schema_json["tables"][resolved_table]

    # 4Ô∏è‚É£ Semantic alignment
    aligner = SemanticAligner()
    mapping = aligner.align(
        user_terms=signals["entities"],
        schema_terms=tables + columns,
        column_terms=columns
    )

    # 5Ô∏è‚É£ Resolve MULTIPLE SELECT columns (Phase-1 safe)
    select_columns = []

    # --- semantic alignment first (table constrained)
    for val in mapping.values():
        if val.startswith(resolved_table + "."):
            col = val.split(".", 1)[1]
            if col in table_columns:
                select_columns.append(val)

    # --- fallback: direct NL match
    for t in signals["entities"]:
        if t in table_columns:
            col = f"{resolved_table}.{t}"
            if col not in select_columns:
                select_columns.append(col)

    if not select_columns:
        raise ValueError("‚ùå Failed to resolve any SELECT column")

    # 6Ô∏è‚É£ Schema bindings (structure-only)
    # NOTE: model still outputs ONE <COLUMN>, we ignore it
    schema_bindings = {
        "<TABLE>": resolved_table,
        "<COLUMN>": select_columns[0]  # dummy for Phase-1 decoder
    }

    # 7Ô∏è‚É£ Model input (structure only)
    input_ids = torch.tensor(
        [tokens_to_ids(["<START>"])],
        device=device
    )

    attention_mask = torch.tensor(
        [create_attention_mask(input_ids[0].tolist(), PAD)],
        device=device
    )

    # 8Ô∏è‚É£ Generate SQL structure
    output_ids = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        schema_tables=tables,
        schema_columns=columns,
        max_len=10
    )

    # 9Ô∏è‚É£ Render SQL (MULTI-COLUMN SELECT)
    renderer = SQLRenderer()
    sql = renderer.render({
        "select": [
            {"agg": None, "column": col}
            for col in select_columns
        ],
        "from": [resolved_table],
        "where": [],
        "group_by": [],
        "having": []
    })

    return sql