In [1]:
from batch_inference_llm import create_batch_file, push_batch_file, create_batch, get_batch_status, download_batch_results
from enums import InferenceOptions
from utils import post_process_sql, post_process_responce_string
import json

In [2]:
# PARAMS

# JSON file with the prompt sets; it is loaded and merged into one dict further down.
PROMPTS_FILE = "./pipeline_results/prompts.json"

# Inference options this notebook is concerned with.
# NOTE(review): not referenced by any cell visible here (the batch-file cell
# passes InferenceOptions.SQL_UNI directly) - confirm whether it is still needed.
INFERENCE_MODELS = [
    InferenceOptions.SQL_UNI,
    InferenceOptions.SQL_REASON,
]

In [3]:
# Developer/system message sent with every request. The pieces are joined with
# single newlines; the empty strings produce the blank lines between sections.
# NOTE(review): the schema line contains trailing commas, so it is not strict
# JSON. It is part of the prompt text and is kept verbatim here.
DEV_MESSAGE = "\n".join((
    "You are an expert SQL translator specialized in converting natural language queries into correct and efficient SQL statements.",
    "Pay attention to PROVIDED INFORMATION and EXAMPLES.",
    "If aliases for table names are needed use: T1, T2 ... as aliases!",
    "",
    "# Response Formats",
    "",
    "## response_format_schema",
    '{"type": "object","properties": {"sql_query": {"type": "string",},}}',
))
print(f'DEV_MESSAGE: {DEV_MESSAGE}')

DEV_MESSAGE: You are an expert SQL translator specialized in converting natural language queries into correct and efficient SQL statements.
Pay attention to PROVIDED INFORMATION and EXAMPLES.
If aliases for table names are needed use: T1, T2 ... as aliases!

# Response Formats

## response_format_schema
{"type": "object","properties": {"sql_query": {"type": "string",},}}


In [4]:
# Load the prompt sets. The file is parsed as JSON and iterated as a sequence
# of dicts. The encoding is given explicitly so the read does not depend on
# the platform's default locale encoding.
with open(PROMPTS_FILE, 'r', encoding='utf-8') as f:
    list_of_dicts = json.load(f)

# Flatten the sequence into a single mapping. dict.update() means that if the
# same key appears in more than one dict, the later entry silently wins.
prompt_merged_dict = {}
for d in list_of_dicts:
    prompt_merged_dict.update(d)

# List the available prompt-set keys (one of them is selected below).
for k in prompt_merged_dict:
    print(f'Key: {k}')

Key: openai$full_schema
Key: reasoning$full_schema
Key: openai$full_schema$fk
Key: reasoning$full_schema$fk
Key: openai$full_schema$fk$random_hardness$3
Key: openai$full_schema$fk$random_hardness$5
Key: openai$full_schema$fk$cosine_sim_hardness$3
Key: openai$full_schema$fk$cosine_sim_hardness$5
Key: reasoning$full_schema$fk$random_hardness$3
Key: reasoning$full_schema$fk$random_hardness$5
Key: reasoning$full_schema$fk$cosine_sim_hardness$3
Key: reasoning$full_schema$fk$cosine_sim_hardness$5


---

In [5]:
# Prompt set submitted by this run: used as the key into prompt_merged_dict,
# as the batch file name, and in the batch description in the cells below.
propmpt_type = "openai$full_schema"

In [6]:
# Build the batch request file for the selected prompt set. The file is named
# after the prompt type (the cell output reports '<propmpt_type>.jsonl').
# NOTE(review): the inference option is fixed to SQL_UNI here rather than
# derived from propmpt_type - confirm it matches the chosen prompt set.
create_batch_file(
    file_name=propmpt_type,
    prompts=prompt_merged_dict[propmpt_type],
    dev_message=DEV_MESSAGE,
    inference_option=InferenceOptions.SQL_UNI,
)

Batch file 'openai$full_schema.jsonl'created successfully. 1034 requests.


In [7]:
# Upload the batch request file for this prompt type. The return value is
# passed to create_batch() as `file_id` in the next cell.
file_id = push_batch_file(propmpt_type)

FileObject(id='file-PdPzKyrURg74cFJRL9Lbop', bytes=1461458, created_at=1748172529, filename='openai$full_schema.jsonl', object='file', purpose='batch', status='processed', expires_at=None, status_details=None)


In [8]:
# Start a batch job over the uploaded file; the description tags the job with
# the prompt type and a version suffix.
batch_description = f"{propmpt_type} v0"
batch = create_batch(description=batch_description, file_id=file_id)
print(f'Batch created:\n {batch}')

Batch created:
 Batch(id='batch_6832ff2a635081908b6fca2dea42a094', completion_window='24h', created_at=1748172586, endpoint='/v1/responses', input_file_id='file-PdPzKyrURg74cFJRL9Lbop', object='batch', status='validating', cancelled_at=None, cancelling_at=None, completed_at=None, error_file_id=None, errors=None, expired_at=None, expires_at=1748258986, failed_at=None, finalizing_at=None, in_progress_at=None, metadata={'description': 'openai$full_schema v0'}, output_file_id=None, request_counts=BatchRequestCounts(completed=0, failed=0, total=0))


---

In [9]:
# Look up the batch and show the id of its output file.
# NOTE(review): the batch id is pasted from the "Batch created" output above,
# which allows resuming in a fresh kernel; `batch.id` is the in-session equivalent.
batch_status = get_batch_status('batch_6832ff2a635081908b6fca2dea42a094')
print(batch_status.output_file_id)

file-27KpgfMGsQy9pJaCerkmaw


In [10]:
# Download the batch output. The first argument is the output file id printed
# by the status cell above; the second is the prompt-type name.
# NOTE(review): both values are pasted by hand - keep them in sync with
# `propmpt_type` and the status output when re-running for another prompt set.
download_batch_results('file-27KpgfMGsQy9pJaCerkmaw', 'openai$full_schema')

Batch results downloaded to openai$full_schema


In [11]:
# Parse the downloaded batch output as JSON Lines: one JSON object per line,
# collected in file order.
responce = []
with open('./pipeline_results/openai$full_schema.jsonl', 'r', encoding='utf-8') as f:
    for line in f:
        # Skip blank lines (e.g. a trailing empty line); json.loads() would
        # raise JSONDecodeError on them and abort the whole load.
        if not line.strip():
            continue
        responce_dict = json.loads(line)
        responce.append(responce_dict)


In [12]:
# Peek at the raw model text of the first response; per the output below it is
# a JSON string with a "sql_query" field.
# Per the original note, for reasoning batches the text is at output[1] instead:
# responce[0]['response']['body']['output'][1]['content'][0]['text'] - reasoning
responce[0]['response']['body']['output'][0]['content'][0]['text']

'{"sql_query": "SELECT COUNT(*) FROM singer;"}'

In [61]:
# custom_id is returned as a string (see output); it is cast to int when the
# `results` list is built.
responce[0]['custom_id']

'0'

In [13]:
# Convert every batch response into {"id": <int custom_id>, "sql": <post-processed text>}.
results = []
for r in responce:
    try:
        responce_text = r['response']['body']['output'][0]['content'][0]['text']
    except (KeyError, IndexError, TypeError):
        # KeyError/IndexError: an expected key or list position is missing.
        # TypeError: some level is None or otherwise not subscriptable
        # (e.g. `"response": null`); without it one bad line would abort the
        # whole loop instead of falling back to the placeholder.
        # NOTE(review): the placeholder uses single quotes, so it is not valid
        # JSON, unlike real responses. Kept verbatim because the behaviour of
        # post_process_responce_string on it is not visible here - confirm.
        responce_text = "{'sql_query': 'SELECT'}"

    results.append({
        "id": int(r['custom_id']),
        "sql": post_process_responce_string(responce_text),
    })
print(f'len(results): {len(results)}')
    

len(results): 1034


In [14]:
# Preview the first five parsed predictions.
results[:5]

[{'id': 0, 'sql': 'SELECT COUNT(*) FROM singer;'},
 {'id': 1, 'sql': 'SELECT COUNT(*) FROM singer;'},
 {'id': 2, 'sql': 'SELECT name, country, age FROM singer ORDER BY age DESC;'},
 {'id': 3, 'sql': 'SELECT name, country, age FROM singer ORDER BY age DESC;'},
 {'id': 4,
  'sql': "SELECT AVG(age), MIN(age), MAX(age) FROM singer WHERE country = 'France';"}]

In [15]:
# Order the predictions by their integer id, in place, so the file written
# below follows id order rather than the order of lines in the batch output.
results.sort(key=lambda entry: entry['id'])

In [16]:
# Write one post-processed SQL prediction per line, in the current order of
# `results`. The encoding is explicit so non-ASCII characters in SQL literals
# are written the same way on every platform.
output_file = './pipeline_results/dev_pred_openai$full_schema_v1.txt'
with open(output_file, "w", encoding="utf-8") as f:
    for r in results:
        f.write(post_process_sql(r['sql'])+'\n')

In [None]:
# results = []
# for r in responce:
#     res = {
#         "id": int(r['custom_id']),
#     }
#     try:
#         sql = r['response']['body']['output'][1]['content'][0]['text']
#     except (KeyError, IndexError):
#         sql = "{'sql_query': 'SELECT'}"
    
#     # if int(r['custom_id']) == 188:
#     #     print(sql)
#     try:
#         d = json.loads(sql)
#     except Exception:
#         d = {}

#     res_sql = d.get('sql_query', 'SELECT')
#     if not res_sql.startswith('SELECT'):
#         res_sql = 'SELECT'

#     res['sql'] = res_sql
#     results.append(res)
# print(f'len(results): {len(results)}')

len(results): 1034


In [49]:
# Spot-check the entry at position 188; the output shows it holds only the
# bare 'SELECT' placeholder.
# NOTE(review): this cell's execution count is out of order, so the output may
# come from a different run of the results-building code - re-run to confirm.
results[188]


{'id': 188, 'sql': 'SELECT'}