In [1]:
from datasets import SpiderDataset
from prompt.prompt_templates import OAIPrompt, ReasoningPrompt
from pprint import pprint
from inference_llm import get_sql_uni_llm, get_sql_reason_llm
from tqdm import tqdm
import json

---

In [2]:
test_dataset = SpiderDataset(
        dataset_dir="spider_data",
        file_name="dev.json",
        path_to_gold="dev_gold.sql",
        table_file_path="tables.json",
    )

Processing samples: 1034it [00:00, 382409.87it/s]


In [3]:
test_samples = test_dataset.samples[:10]

In [18]:
with open('./pipeline_results/dev_gold.txt', 'w') as f:
    for t in test_samples:
        f.write(t.query_gold + '\t' + t.db_id +'\n')

In [11]:
op_rule='You must minimize execution time while ensuring correctness.'
DEV_MESSAGE = "You are an expert SQL translator specialized in converting natural language queries into correct and efficient SQL statements.\n"
DEV_MESSAGE += "Do not use any aliases for tables or columns!\n\n"
DEV_MESSAGE += """# Response Formats\n\n## response_format_schema\n{"type": "object","properties": {"sql_query": {"type": "string",},}}"""


In [12]:
print(DEV_MESSAGE)

You are an expert SQL translator specialized in converting natural language queries into correct and efficient SQL statements.
Do not use any aliases for tables or columns!

# Response Formats

## response_format_schema
{"type": "object","properties": {"sql_query": {"type": "string",},}}


## OAI base

In [None]:
responces_to_save_OAI_base = []

for s in tqdm(test_samples):
    prompt = OAIPrompt.build_prompt(
        sample=s, k=0, op_rule=op_rule
    )
    responce = get_sql_uni_llm(prompt, model="gpt-4o-2024-08-06", dev_message=DEV_MESSAGE)
    responces_to_save_OAI_base.append(responce)

100%|██████████| 3/3 [00:03<00:00,  1.16s/it]


In [21]:
# Save the list of dictionaries to a JSON file
output_file = "OAI_output_base.json"

with open(output_file, "w") as f:
    json.dump([r.dict() for r in responces_to_save_OAI_base], f)

/var/folders/p2/skh823fd5t5fcl9_fbmxvw24rb6_10/T/ipykernel_90728/2484296937.py:5: PydanticDeprecatedSince20: The `dict` method is deprecated; use `model_dump` instead. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.10/migration/
  json.dump([r.dict() for r in responces_to_save_OAI_base], f)


In [None]:
pred = []
for p in list(responces_to_save_OAI_base):
    try:
        responce_string = p.output[0].content[0].text
    except:
        responce_string = "{'sql_query': ''}"
    response_dict = json.loads(responce_string)
    pred.append('SELECT '+ response_dict['sql_query'].strip().rstrip(';'))


In [25]:
output_file = 'dev_pred_OAI_base.txt'
with open(output_file, "w") as f:
    for line in pred:
        f.write(line + '\n')

## OAI with full schema

In [None]:
responces_to_save_OAI_wsc = []

for s in tqdm(test_samples):
    prompt = OAIPrompt.build_prompt(
        sample=s, k=0, op_rule=op_rule, schema_info_option='all'
    )
    responce = get_sql_uni_llm(prompt, model="gpt-4o-2024-08-06", dev_message=DEV_MESSAGE)
    responces_to_save_OAI_wsc.append(responce)

100%|██████████| 3/3 [00:04<00:00,  1.36s/it]


In [27]:
# Save the list of dictionaries to a JSON file
output_file = "OAI_output_wsc.json"

with open(output_file, "w") as f:
    json.dump([r.dict() for r in responces_to_save_OAI_wsc], f)

/var/folders/p2/skh823fd5t5fcl9_fbmxvw24rb6_10/T/ipykernel_90728/728498247.py:5: PydanticDeprecatedSince20: The `dict` method is deprecated; use `model_dump` instead. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.10/migration/
  json.dump([r.dict() for r in responces_to_save_OAI_wsc], f)


In [None]:
gold = []
for p in list(responces_to_save_OAI_wsc):
    try:
        reesponce_string = p.output[0].content[0].text
    except:
        reesponce_string = "{'sql_query': ''}"
    response_dict = json.loads(reesponce_string)
    pred.append('SELECT '+ response_dict['sql_query'].strip().rstrip(';'))


In [29]:
output_file = 'dev_pred_OAI_wsc.txt'
with open(output_file, "w") as f:
    for line in pred:
        f.write(line + '\n')

## OAI with full schema and fk

In [None]:
responces_to_save_OAI_wsc_fk = []

for s in tqdm(test_samples):
    prompt = OAIPrompt.build_prompt(
        sample=s, k=0, op_rule=op_rule, schema_info_option='all', add_fk_info=True
    )
    responce = get_sql_uni_llm(prompt, model="gpt-4o-2024-08-06", dev_message=DEV_MESSAGE)
    responces_to_save_OAI_wsc_fk.append(responce)

100%|██████████| 100/100 [01:44<00:00,  1.05s/it]


In [16]:
# Save the list of dictionaries to a JSON file
output_file = "OAI_output_wsc_fk.json"

with open(output_file, "w") as f:
    json.dump([r.dict() for r in responces_to_save_OAI_wsc_fk], f)

/var/folders/p2/skh823fd5t5fcl9_fbmxvw24rb6_10/T/ipykernel_33592/3059441803.py:5: PydanticDeprecatedSince20: The `dict` method is deprecated; use `model_dump` instead. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.10/migration/
  json.dump([r.dict() for r in responces_to_save_OAI_wsc_fk], f)


In [None]:
pred = []
for p in list(responces_to_save_OAI_wsc_fk):
    gold.append(s.query + '\t' + s.db_id)
    try:
        reesponce_string = p.output[0].content[0].text
    except:
        reesponce_string = "{'sql_query': ''}"
    response_dict = json.loads(reesponce_string)
    pred.append('SELECT '+ response_dict['sql_query'].strip().rstrip(';'))


In [18]:
output_file = 'dev_pred_OAI_wsc_fk.txt'
with open(output_file, "w") as f:
    for line in pred:
        f.write(line + '\n')

## Reasoning base

In [5]:
REASON_DEV_MESSAGE = "You are an expert SQL translator specialized in converting natural language queries into correct and efficient SQL statements.\n" \
    "For example:\n\n" \
    "USER INPUT: How many heads of the departments are older than 56 ?\n" \
    "OUTPUT: SELECT count(*) FROM head WHERE age  >  56"
print(REASON_DEV_MESSAGE)

You are an expert SQL translator specialized in converting natural language queries into correct and efficient SQL statements.
For example:

USER INPUT: How many heads of the departments are older than 56 ?
OUTPUT: SELECT count(*) FROM head WHERE age  >  56


In [14]:
responces_to_save_Reason_base = []

for s in tqdm(test_samples):
    prompt = ReasoningPrompt.build_prompt(
        sample=s, op_rule=op_rule, k=0
    )
    responce = get_sql_reason_llm(prompt, model="o3-mini-2025-01-31", dev_message=DEV_MESSAGE)
    responces_to_save_Reason_base.append(responce)

100%|██████████| 10/10 [00:41<00:00,  4.17s/it]


In [15]:
# Save the list of dictionaries to a JSON file
output_file = "./pipeline_results/Reason_output_base.json"

with open(output_file, "w") as f:
    json.dump([r.dict() for r in responces_to_save_Reason_base], f)

/var/folders/p2/skh823fd5t5fcl9_fbmxvw24rb6_10/T/ipykernel_36582/1912213914.py:5: PydanticDeprecatedSince20: The `dict` method is deprecated; use `model_dump` instead. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.10/migration/
  json.dump([r.dict() for r in responces_to_save_Reason_base], f)


In [16]:
pred = []
for p in list(responces_to_save_Reason_base):
    try:
        reesponce_string = p.output[1].content[0].text.strip()
        response_dict = json.loads(reesponce_string)
    except:
        reesponce_string = '{"sql_query": "SELECT"}'
        response_dict = json.loads(reesponce_string)
    pred.append(response_dict['sql_query'].strip().rstrip(';'))


In [17]:
output_file = './pipeline_results/dev_pred_Reason_base.txt'
with open(output_file, "w") as f:
    for line in pred:
        f.write(line + '\n')

## Reasoning with full schema

In [19]:
responces_to_save_Reason_wsc = []

for s in tqdm(test_samples):
    prompt = ReasoningPrompt.build_prompt(
        sample=s, op_rule=op_rule, k=0, schema_info_option='all'
    )
    responce = get_sql_reason_llm(prompt, model="o3-mini-2025-01-31", dev_message=DEV_MESSAGE)
    responces_to_save_Reason_wsc.append(responce)

100%|██████████| 10/10 [00:30<00:00,  3.08s/it]


In [20]:
# Save the list of dictionaries to a JSON file
output_file = "./pipeline_results/Reason_output_wsc.json"

with open(output_file, "w") as f:
    json.dump([r.dict() for r in responces_to_save_Reason_base], f)

/var/folders/p2/skh823fd5t5fcl9_fbmxvw24rb6_10/T/ipykernel_36582/2569758144.py:5: PydanticDeprecatedSince20: The `dict` method is deprecated; use `model_dump` instead. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.10/migration/
  json.dump([r.dict() for r in responces_to_save_Reason_base], f)


In [21]:
pred = []
for p in list(responces_to_save_Reason_wsc):
    try:
        reesponce_string = p.output[1].content[0].text.strip()
        response_dict = json.loads(reesponce_string)
    except:
        reesponce_string = '{"sql_query": "SELECT"}'
        response_dict = json.loads(reesponce_string)
    pred.append(response_dict['sql_query'].strip().rstrip(';'))


In [22]:
output_file = './pipeline_results/dev_pred_Reason_wsc.txt'
with open(output_file, "w") as f:
    for line in pred:
        f.write(line + '\n')