In [1]:
from inference_llm import infer_llm
from tqdm import tqdm
import json
from enums import InferenceOptions
from utils import post_process_sql

## Load prompts

In [2]:
# PARAMS
# JSON file holding the prepared prompt sets (a list of {prompt_set_key: prompts} dicts).
PROMPTS_FILE = "./pipeline_results/prompts.json"

# Models under evaluation; NOTE(review): not referenced by the visible cells.
INFERENCE_MODELS = [
    InferenceOptions.SQL_UNI,
    InferenceOptions.SQL_REASON,
]

In [3]:
# Load the prompt sets and merge the list of {prompt_set_key: prompts} dicts
# into a single lookup keyed by prompt-set name (e.g. 'openai$full_schema').
# JSON text is UTF-8, so the encoding is explicit rather than platform-dependent.
with open(PROMPTS_FILE, 'r', encoding='utf-8') as f:
    list_of_dicts = json.load(f)

prompt_merged_dict = {}
for d in list_of_dicts:
    # A key present in two entries would otherwise be silently overwritten by the
    # later one, and the wrong prompt set would then be run under that name.
    duplicate_keys = prompt_merged_dict.keys() & d.keys()
    if duplicate_keys:
        raise ValueError(f'Duplicate prompt-set keys in {PROMPTS_FILE}: {sorted(duplicate_keys)}')
    prompt_merged_dict.update(d)

for k in prompt_merged_dict.keys():
    print(f'Key: {k}')

Key: openai$full_schema
Key: reasoning$full_schema
Key: openai$full_schema$fk
Key: reasoning$full_schema$fk
Key: openai$full_schema$fk$random_hardness$3
Key: openai$full_schema$fk$random_hardness$5
Key: openai$full_schema$fk$cosine_sim_hardness$3
Key: openai$full_schema$fk$cosine_sim_hardness$5
Key: reasoning$full_schema$fk$random_hardness$3
Key: reasoning$full_schema$fk$random_hardness$5
Key: reasoning$full_schema$fk$cosine_sim_hardness$3
Key: reasoning$full_schema$fk$cosine_sim_hardness$5


In [4]:
# Route prompt sets to models by key prefix: 'openai...' sets go to SQL_UNI and
# 'reasoning...' sets go to SQL_REASON. The dict is keyed by the model's enum value.
model_to_promp = {
    model.value: [key for key in prompt_merged_dict if key.startswith(prefix)]
    for model, prefix in (
        (InferenceOptions.SQL_UNI, 'openai'),
        (InferenceOptions.SQL_REASON, 'reasoning'),
    )
}
print(f'Model to prompt mapping: {model_to_promp}')

Model to prompt mapping: {'gpt-4.1-2025-04-14': ['openai$full_schema', 'openai$full_schema$fk', 'openai$full_schema$fk$random_hardness$3', 'openai$full_schema$fk$random_hardness$5', 'openai$full_schema$fk$cosine_sim_hardness$3', 'openai$full_schema$fk$cosine_sim_hardness$5'], 'o4-mini-2025-04-16': ['reasoning$full_schema', 'reasoning$full_schema$fk', 'reasoning$full_schema$fk$random_hardness$3', 'reasoning$full_schema$fk$random_hardness$5', 'reasoning$full_schema$fk$cosine_sim_hardness$3', 'reasoning$full_schema$fk$cosine_sim_hardness$5']}


In [5]:
# Developer/system message sent with every prompt. Adjacent string literals are
# joined at compile time into one message.
# NOTE(review): the schema text below has trailing commas, so it is not strictly
# valid JSON; it is kept verbatim so prompts stay identical to earlier runs.
DEV_MESSAGE = (
    "You are an expert SQL translator specialized in converting natural language queries into correct and efficient SQL statements.\n"
    "Pay attention to PROVIDED INFORMATION and EXAMPLES.\n"
    "If aliases for table names are needed use: T1, T2 ... as aliases!\n\n"
    "# Response Formats\n\n"
    "## response_format_schema\n"
    '{"type": "object","properties": {"sql_query": {"type": "string",},}}'
)
print(f'DEV_MESSAGE: {DEV_MESSAGE}')

DEV_MESSAGE: You are an expert SQL translator specialized in converting natural language queries into correct and efficient SQL statements.
Pay attention to PROVIDED INFORMATION and EXAMPLES.
If aliases for table names are needed use: T1, T2 ... as aliases!

# Response Formats

## response_format_schema
{"type": "object","properties": {"sql_query": {"type": "string",},}}


---

# BASE

### OAI

In [7]:
# Base setting: full schema only, run with the SQL_UNI model.
oai_prompts = prompt_merged_dict['openai$full_schema']
oai_base_responces = infer_llm(
    inference_option=InferenceOptions.SQL_UNI,
    dev_message=DEV_MESSAGE,
    prompts=oai_prompts,
)

Inferring SQL queries: 100%|██████████| 1034/1034 [24:32<00:00,  1.42s/it]


In [None]:
# Turn each response into one predicted SQL string. 'SELECT ' is prepended to
# the model's `sql_query` (presumably the prompt ends with "SELECT" and the model
# continues it — TODO confirm against the prompt builder). Unreadable or
# unparsable responses fall back to an empty query so `pred` stays aligned
# one-to-one with the prompts.
pred = []
fallback_idx = set()
for idx, p in enumerate(oai_base_responces):
    try:
        response_text = p.output[0].content[0].text
    except (AttributeError, IndexError, TypeError):
        # No readable message item in this response.
        fallback_idx.add(idx)
        response_text = '{"sql_query": ""}'  # must be valid JSON (double quotes)
    try:
        sql_body = json.loads(response_text)['sql_query']
    except (json.JSONDecodeError, KeyError, TypeError):
        # Malformed JSON, or a JSON value without a 'sql_query' key.
        fallback_idx.add(idx)
        sql_body = ''
    pred_sql = 'SELECT ' + sql_body
    pred.append(post_process_sql(pred_sql))

if fallback_idx:
    print(f'{len(fallback_idx)} response(s) fell back to an empty query: {sorted(fallback_idx)}')


In [13]:
# Write one prediction per line. Line breaks inside a prediction are flattened to
# spaces so a multi-line query cannot shift every following line and misalign
# predictions with their examples. UTF-8 avoids depending on the platform default.
output_file = './pipeline_results/dev_pred_oai_base_v0.txt'
with open(output_file, "w", encoding="utf-8") as f:
    for line in pred:
        f.write(' '.join(line.splitlines()) + '\n')

In [10]:
# Debug check: print the indices of predictions equal to one specific query.
# The query is bound once instead of being repeated as a no-op string statement.
target_sql = "SELECT name FROM conductor WHERE nationality != 'USA'"
for i, p in enumerate(pred):
    if p == target_sql:
        print(i)

826
827


### Reasoning

In [17]:
# Base setting: full schema only, run with the SQL_REASON model.
reasoning_prompts = prompt_merged_dict['reasoning$full_schema']
reasoning_base_responces = infer_llm(
    inference_option=InferenceOptions.SQL_REASON,
    dev_message=DEV_MESSAGE,
    prompts=reasoning_prompts,
)

Inferring SQL queries: 100%|██████████| 1034/1034 [1:20:58<00:00,  4.70s/it]


In [33]:
# Turn each reasoning-model response into one predicted SQL string. The answer
# text is read from output[1]; output[0] is expected to be the reasoning item
# (responses holding only a reasoning item have no output[1] and fall back).
# 'SELECT ' is prepended to `sql_query` (presumably the prompt ends with "SELECT"
# — TODO confirm). Unparsable responses fall back to an empty query so `pred`
# stays aligned one-to-one with the prompts.
pred = []
fallback_idx = set()
for idx, p in enumerate(reasoning_base_responces):
    try:
        response_text = p.output[1].content[0].text
    except (AttributeError, IndexError, TypeError):
        print(idx, getattr(p, 'output', p))
        fallback_idx.add(idx)
        response_text = '{"sql_query": ""}'
    cleaned = post_process_sql(response_text)
    sql_body = None
    # Try the text as-is, then with '"}' appended to repair output that was
    # (presumably) cut off inside the sql_query string.
    for candidate in (cleaned, cleaned + '"}'):
        try:
            sql_body = json.loads(candidate)['sql_query']
            break
        except (json.JSONDecodeError, KeyError, TypeError):
            continue
    if sql_body is None:
        # Neither attempt parsed: keep the row, but with an empty query.
        fallback_idx.add(idx)
        sql_body = ''
    pred_sql = 'SELECT ' + sql_body
    pred.append(post_process_sql(pred_sql))

if fallback_idx:
    print(f'{len(fallback_idx)} response(s) fell back to an empty query: {sorted(fallback_idx)}')


[ResponseReasoningItem(id='rs_6814f0cdae988191a562f9bf2b3e1c7d0ad775f874f68bc9', summary=[], type='reasoning', status=None)]
[ResponseReasoningItem(id='rs_6814f0d564048191abd872781a536f180eb243310bc2e59f', summary=[], type='reasoning', status=None)]
[ResponseReasoningItem(id='rs_6814f0dcc5f481918a74e8d08d33a3a2071b20b1d043bd36', summary=[], type='reasoning', status=None)]
[ResponseReasoningItem(id='rs_6814f0fa27a48191a38b24a1d020448f0adaafef88fef6c3', summary=[], type='reasoning', status=None)]
[ResponseReasoningItem(id='rs_6814f104a4388191b793c16829e851f10e0525ae333542ba', summary=[], type='reasoning', status=None)]
[ResponseReasoningItem(id='rs_6814f1480bc88191b4583e7ad3f5f40c0b391d2a85fe8b89', summary=[], type='reasoning', status=None)]
[ResponseReasoningItem(id='rs_6814f14fa9408191a562b748c1e1dc090b3b297dc7dc0a7b', summary=[], type='reasoning', status=None)]
[ResponseReasoningItem(id='rs_6814f1c0bbf48191b2dc4e71e6267ebe05a861ffe4324360', summary=[], type='reasoning', status=None)]


In [29]:
# Write one prediction per line. Line breaks inside a prediction are flattened to
# spaces so a multi-line query cannot shift every following line and misalign
# predictions with their examples. UTF-8 avoids depending on the platform default.
output_file = './pipeline_results/dev_pred_reason_base_v0.txt'
with open(output_file, "w", encoding="utf-8") as f:
    for line in pred:
        f.write(' '.join(line.splitlines()) + '\n')

# BASE + META INFO

### OAI

In [15]:
# Full schema plus foreign-key info ('$fk' prompt set), run with the SQL_UNI model.
oai_base_meta_prompts = prompt_merged_dict['openai$full_schema$fk']
oai_base_meta_responces = infer_llm(
    inference_option=InferenceOptions.SQL_UNI,
    dev_message=DEV_MESSAGE,
    prompts=oai_base_meta_prompts,
)


Inferring SQL queries: 100%|██████████| 1034/1034 [27:54<00:00,  1.62s/it] 


In [None]:
# Turn each response into one predicted SQL string. 'SELECT ' is prepended to
# the model's `sql_query` (presumably the prompt ends with "SELECT" and the model
# continues it — TODO confirm against the prompt builder). Unreadable or
# unparsable responses fall back to an empty query so `pred` stays aligned
# one-to-one with the prompts.
pred = []
fallback_idx = set()
for idx, p in enumerate(oai_base_meta_responces):
    try:
        response_text = p.output[0].content[0].text
    except (AttributeError, IndexError, TypeError):
        # No readable message item in this response.
        fallback_idx.add(idx)
        response_text = '{"sql_query": ""}'  # must be valid JSON (double quotes)
    try:
        sql_body = json.loads(response_text)['sql_query']
    except (json.JSONDecodeError, KeyError, TypeError):
        # Malformed JSON, or a JSON value without a 'sql_query' key.
        fallback_idx.add(idx)
        sql_body = ''
    pred_sql = 'SELECT ' + sql_body
    pred.append(post_process_sql(pred_sql))

if fallback_idx:
    print(f'{len(fallback_idx)} response(s) fell back to an empty query: {sorted(fallback_idx)}')


In [None]:
# Write one prediction per line. Line breaks inside a prediction are flattened to
# spaces so a multi-line query cannot shift every following line and misalign
# predictions with their examples. UTF-8 avoids depending on the platform default.
output_file = './pipeline_results/dev_pred_oai_base_meta_v0.txt'
with open(output_file, "w", encoding="utf-8") as f:
    for line in pred:
        f.write(' '.join(line.splitlines()) + '\n')

# BASE + META INFO + EXAMPLES

### OAI + random by hardness k=3

In [7]:
# Full schema + FK info + 3 examples ('random_hardness$3' set), run with SQL_UNI.
oai_base_meta_rand_hard_3_prompts = prompt_merged_dict['openai$full_schema$fk$random_hardness$3']
oai_base_meta_rand_hard_3_responces = infer_llm(
    inference_option=InferenceOptions.SQL_UNI,
    dev_message=DEV_MESSAGE,
    prompts=oai_base_meta_rand_hard_3_prompts,
)


Inferring SQL queries: 100%|██████████| 1034/1034 [24:07<00:00,  1.40s/it]


In [10]:
# Turn each response into one predicted SQL string. 'SELECT ' is prepended to
# the model's `sql_query` (presumably the prompt ends with "SELECT" and the model
# continues it — TODO confirm against the prompt builder). Unreadable or
# unparsable responses fall back to an empty query so `pred` stays aligned
# one-to-one with the prompts.
pred = []
fallback_idx = set()
for idx, p in enumerate(oai_base_meta_rand_hard_3_responces):
    try:
        response_text = p.output[0].content[0].text
    except (AttributeError, IndexError, TypeError):
        # No readable message item in this response.
        fallback_idx.add(idx)
        response_text = '{"sql_query": ""}'  # must be valid JSON (double quotes)
    try:
        sql_body = json.loads(response_text)['sql_query']
    except (json.JSONDecodeError, KeyError, TypeError):
        # Malformed JSON, or a JSON value without a 'sql_query' key.
        fallback_idx.add(idx)
        sql_body = ''
    pred_sql = 'SELECT ' + sql_body
    pred.append(post_process_sql(pred_sql))

if fallback_idx:
    print(f'{len(fallback_idx)} response(s) fell back to an empty query: {sorted(fallback_idx)}')


In [None]:
# Write one prediction per line. Line breaks inside a prediction are flattened to
# spaces so a multi-line query cannot shift every following line and misalign
# predictions with their examples. UTF-8 avoids depending on the platform default.
output_file = './pipeline_results/dev_pred_oai_base_meta_rand_hard_3_v0.txt'
with open(output_file, "w", encoding="utf-8") as f:
    for line in pred:
        f.write(' '.join(line.splitlines()) + '\n')

### OAI + random by hardness k=5

In [6]:
# Full schema + FK info + 5 examples ('random_hardness$5' set), run with SQL_UNI.
oai_base_meta_rand_hard_5_prompts = prompt_merged_dict['openai$full_schema$fk$random_hardness$5']
oai_base_meta_rand_hard_5_responces = infer_llm(
    inference_option=InferenceOptions.SQL_UNI,
    dev_message=DEV_MESSAGE,
    prompts=oai_base_meta_rand_hard_5_prompts,
)


Inferring SQL queries: 100%|██████████| 1034/1034 [24:07<00:00,  1.40s/it]


In [7]:
# Turn each response into one predicted SQL string. 'SELECT ' is prepended to
# the model's `sql_query` (presumably the prompt ends with "SELECT" and the model
# continues it — TODO confirm against the prompt builder). Unreadable or
# unparsable responses fall back to an empty query so `pred` stays aligned
# one-to-one with the prompts.
pred = []
fallback_idx = set()
for idx, p in enumerate(oai_base_meta_rand_hard_5_responces):
    try:
        response_text = p.output[0].content[0].text
    except (AttributeError, IndexError, TypeError):
        # No readable message item in this response.
        fallback_idx.add(idx)
        response_text = '{"sql_query": ""}'  # must be valid JSON (double quotes)
    try:
        sql_body = json.loads(response_text)['sql_query']
    except (json.JSONDecodeError, KeyError, TypeError):
        # Malformed JSON, or a JSON value without a 'sql_query' key.
        fallback_idx.add(idx)
        sql_body = ''
    pred_sql = 'SELECT ' + sql_body
    pred.append(post_process_sql(pred_sql))

if fallback_idx:
    print(f'{len(fallback_idx)} response(s) fell back to an empty query: {sorted(fallback_idx)}')


In [None]:
# Write one prediction per line. Line breaks inside a prediction are flattened to
# spaces so a multi-line query cannot shift every following line and misalign
# predictions with their examples. UTF-8 avoids depending on the platform default.
output_file = './pipeline_results/dev_pred_oai_base_meta_rand_hard_5_v0.txt'
with open(output_file, "w", encoding="utf-8") as f:
    for line in pred:
        f.write(' '.join(line.splitlines()) + '\n')

---