In [1]:
import json
import os
import re

from datasets import SpiderDataset
from prompt.prompt_factory import PromptFactory
from enums import PromptRepresentationType, ShemaInfoOptions, ExampleSelectionType


  from .autonotebook import tqdm as notebook_tqdm


## Load Data

In [2]:
# Dev split used as the evaluation set (the output below reports 1034 samples).
# NOTE(review): file names are presumably resolved relative to dataset_dir — confirm in SpiderDataset.
dev_split_kwargs = dict(
    dataset_dir="spider_data",
    file_name="dev.json",
    path_to_gold="dev_gold.sql",
    table_file_path="tables.json",
)
test_dataset = SpiderDataset(**dev_split_kwargs)

Processing samples: 1034it [00:00, 269775.46it/s]


In [3]:
# Compute an embedding for every dev question using the "all-MiniLM-L6-v2" model.
# NOTE(review): presumably consumed by COSINE_SIM_HARDNESS example selection, and should use the
# same model as the train-question embeddings below — confirm in SpiderDataset / PromptFactory.
test_dataset.calculate_question_embeddings("all-MiniLM-L6-v2")

Calculating question embeddings: 100%|██████████| 1034/1034 [00:10<00:00, 94.35it/s] 


In [4]:
# Dev samples: the questions every prompt variant below is built for (passed as `samples=`).
test_samples = test_dataset.samples

In [5]:
# Train split: the pool of candidate few-shot examples (the output below reports 7000 samples).
# Shares tables.json with the dev split.
train_split_kwargs = dict(
    dataset_dir="spider_data",
    file_name="train_spider.json",
    path_to_gold="train_gold.sql",
    table_file_path="tables.json",
)
train_dataset = SpiderDataset(**train_split_kwargs)

Processing samples: 7000it [00:00, 545291.46it/s]


In [6]:
# Embed the train questions with the same model name as the dev questions
# (presumably so dev/train question similarity can be computed — confirm in PromptFactory).
train_dataset.calculate_question_embeddings("all-MiniLM-L6-v2")

Calculating question embeddings: 100%|██████████| 7000/7000 [00:59<00:00, 117.41it/s]


In [7]:
# Train samples form the pool few-shot examples are drawn from (passed as `examples=` below).
samples_for_examples = train_dataset.samples

## Prompt parameters

In [8]:
# Experiment grid swept by the prompt-building cells below.
PROMPT_REPRESENTATIONS = [
    PromptRepresentationType.OAI,
    PromptRepresentationType.REASONING,
]
SCHEMA_INFO_OPTION = [ShemaInfoOptions.FULL_SCHEMA]
EXAMPLE_SELECTION_TYPES = [
    ExampleSelectionType.RAND_HARDNESS,
    ExampleSelectionType.COSINE_SIM_HARDNESS,
]
ks = [3, 5]  # number of few-shot examples (passed as k=) per prompt

# NOTE(review): OP_RULE is not referenced by any visible cell of this notebook — confirm it is still needed.
OP_RULE = 'You must minimize execution time while ensuring correctness.'


In [9]:
# Accumulator: each builder cell appends one {variant_key: prompts} dict per prompt variant.
PROMPTS_LIST = []
# JSON file the accumulated variants are dumped to (and re-read from in "Check prompts").
OUTPUT_FILE = "./pipeline_results/prompts.json"

## Base prompts (schema only, no FK info or examples)

In [10]:
# Zero-shot baseline: one prompt set per (representation, schema option) pair,
# stored under the key "<repr>$<schema>".
for representation in PROMPT_REPRESENTATIONS:
    for schema_option in SCHEMA_INFO_OPTION:
        print(f'prompt_repr: {representation}, schema_info: {schema_option}')
        base_prompts = PromptFactory.build_prompts(
            prompt_type=representation,
            samples=test_samples,
            schema_info_option=schema_option,
        )
        variant_key = '$'.join([representation.value, schema_option.value])
        PROMPTS_LIST.append({variant_key: base_prompts})

prompt_repr: PromptRepresentationType.OAI, schema_info: ShemaInfoOptions.FULL_SCHEMA


Building prompts: 100%|██████████| 1034/1034 [00:00<00:00, 256015.96prompt/s]


prompt_repr: PromptRepresentationType.REASONING, schema_info: ShemaInfoOptions.FULL_SCHEMA


Building prompts: 100%|██████████| 1034/1034 [00:00<00:00, 227504.08prompt/s]


## Base + meta info (foreign-key information added)

In [11]:
# Same grid as the baseline, built with add_fk_info=True; keys get a trailing "$fk":
# "<repr>$<schema>$fk".
for representation in PROMPT_REPRESENTATIONS:
    for schema_option in SCHEMA_INFO_OPTION:
        print(f'prompt_repr: {representation}, schema_info: {schema_option} + FK')
        fk_prompts = PromptFactory.build_prompts(
            prompt_type=representation,
            samples=test_samples,
            schema_info_option=schema_option,
            add_fk_info=True,
        )
        variant_key = '$'.join([representation.value, schema_option.value, 'fk'])
        PROMPTS_LIST.append({variant_key: fk_prompts})

prompt_repr: PromptRepresentationType.OAI, schema_info: ShemaInfoOptions.FULL_SCHEMA + FK


Building prompts: 100%|██████████| 1034/1034 [00:00<00:00, 180772.39prompt/s]


prompt_repr: PromptRepresentationType.REASONING, schema_info: ShemaInfoOptions.FULL_SCHEMA + FK


Building prompts: 100%|██████████| 1034/1034 [00:00<00:00, 170181.70prompt/s]


## Base + meta info + k few-shot examples from the train split

In [12]:
# Few-shot variants: FK-augmented schema plus `num_examples` examples drawn from the train pool
# with each selection strategy. Key: "<repr>$<schema>$fk$<selection>$<k>".
# NOTE(review): RAND_HARDNESS presumably samples examples at random and no seed is set anywhere in
# this notebook, so these prompts may differ between runs — confirm and seed if so.
for representation in PROMPT_REPRESENTATIONS:
    for schema_option in SCHEMA_INFO_OPTION:
        for selection in EXAMPLE_SELECTION_TYPES:
            for num_examples in ks:
                print(f'prompt_repr: {representation}, schema_info: {schema_option}, example_selection_type: {selection}, k: {num_examples}')
                few_shot_prompts = PromptFactory.build_prompts(
                    prompt_type=representation,
                    samples=test_samples,
                    examples=samples_for_examples,
                    example_selection_type=selection,
                    k=num_examples,
                    schema_info_option=schema_option,
                    add_fk_info=True,
                )
                variant_key = '$'.join([
                    representation.value,
                    schema_option.value,
                    'fk',
                    selection.value,
                    str(num_examples),
                ])
                PROMPTS_LIST.append({variant_key: few_shot_prompts})

prompt_repr: PromptRepresentationType.OAI, schema_info: ShemaInfoOptions.FULL_SCHEMA, example_selection_type: ExampleSelectionType.RAND_HARDNESS, k: 3


Building prompts: 100%|██████████| 1034/1034 [00:33<00:00, 30.61prompt/s]


prompt_repr: PromptRepresentationType.OAI, schema_info: ShemaInfoOptions.FULL_SCHEMA, example_selection_type: ExampleSelectionType.RAND_HARDNESS, k: 5


Building prompts: 100%|██████████| 1034/1034 [00:33<00:00, 31.15prompt/s]


prompt_repr: PromptRepresentationType.OAI, schema_info: ShemaInfoOptions.FULL_SCHEMA, example_selection_type: ExampleSelectionType.COSINE_SIM_HARDNESS, k: 3


Building prompts: 100%|██████████| 1034/1034 [00:35<00:00, 28.89prompt/s]


prompt_repr: PromptRepresentationType.OAI, schema_info: ShemaInfoOptions.FULL_SCHEMA, example_selection_type: ExampleSelectionType.COSINE_SIM_HARDNESS, k: 5


Building prompts: 100%|██████████| 1034/1034 [00:36<00:00, 28.18prompt/s]


prompt_repr: PromptRepresentationType.REASONING, schema_info: ShemaInfoOptions.FULL_SCHEMA, example_selection_type: ExampleSelectionType.RAND_HARDNESS, k: 3


Building prompts: 100%|██████████| 1034/1034 [00:33<00:00, 30.90prompt/s]


prompt_repr: PromptRepresentationType.REASONING, schema_info: ShemaInfoOptions.FULL_SCHEMA, example_selection_type: ExampleSelectionType.RAND_HARDNESS, k: 5


Building prompts: 100%|██████████| 1034/1034 [00:33<00:00, 31.15prompt/s]


prompt_repr: PromptRepresentationType.REASONING, schema_info: ShemaInfoOptions.FULL_SCHEMA, example_selection_type: ExampleSelectionType.COSINE_SIM_HARDNESS, k: 3


Building prompts: 100%|██████████| 1034/1034 [00:35<00:00, 28.87prompt/s]


prompt_repr: PromptRepresentationType.REASONING, schema_info: ShemaInfoOptions.FULL_SCHEMA, example_selection_type: ExampleSelectionType.COSINE_SIM_HARDNESS, k: 5


Building prompts: 100%|██████████| 1034/1034 [00:35<00:00, 28.88prompt/s]


---

In [13]:
# Persist every prompt variant as a JSON list of single-key dicts ({variant_key: prompts}).
# The output directory is created first; otherwise a fresh checkout fails with FileNotFoundError.
os.makedirs(os.path.dirname(OUTPUT_FILE) or ".", exist_ok=True)
with open(OUTPUT_FILE, "w", encoding="utf-8") as f:
    json.dump(PROMPTS_LIST, f, indent=2)

In [14]:
# Write the dev gold queries as "<query>\t<db_id>" lines, one per test sample, in test order.
GOLD_OUTPUT_FILE = './pipeline_results/dev_gold.txt'
os.makedirs(os.path.dirname(GOLD_OUTPUT_FILE) or ".", exist_ok=True)
# Explicit UTF-8 so non-ASCII literals in gold SQL don't depend on the platform default encoding.
with open(GOLD_OUTPUT_FILE, 'w', encoding='utf-8') as f:
    for sample in test_samples:
        # A tab or line break inside a query would shift the db_id column or split the record,
        # so those characters become single spaces; everything else is written unchanged.
        gold_query = sample.query_gold.replace('\r', ' ').replace('\n', ' ').replace('\t', ' ')
        f.write(gold_query + '\t' + sample.db_id + '\n')

## Check prompts

In [15]:
# Reload the file just written and merge its single-key dicts into one {variant_key: prompts} map.
with open(OUTPUT_FILE, 'r', encoding='utf-8') as f:
    list_of_dicts = json.load(f)

prompt_merged_dict = {}
for d in list_of_dicts:
    # dict.update would silently keep only the last value for a repeated key (e.g. after a builder
    # cell was re-run and appended to PROMPTS_LIST a second time), so fail loudly instead.
    duplicated_keys = prompt_merged_dict.keys() & d.keys()
    if duplicated_keys:
        raise ValueError(
            f'Duplicate prompt keys in {OUTPUT_FILE}: {sorted(duplicated_keys)}; '
            'restart from the PROMPTS_LIST cell and rebuild.'
        )
    prompt_merged_dict.update(d)

# `key` rather than `k`, so the few-shot loop's `k` is not silently rebound to a string.
for key in prompt_merged_dict:
    print(f'Key: {key}')

Key: openai$full_schema
Key: reasoning$full_schema
Key: openai$full_schema$fk
Key: reasoning$full_schema$fk
Key: openai$full_schema$fk$random_hardness$3
Key: openai$full_schema$fk$random_hardness$5
Key: openai$full_schema$fk$cosine_sim_hardness$3
Key: openai$full_schema$fk$cosine_sim_hardness$5
Key: reasoning$full_schema$fk$random_hardness$3
Key: reasoning$full_schema$fk$random_hardness$5
Key: reasoning$full_schema$fk$cosine_sim_hardness$3
Key: reasoning$full_schema$fk$cosine_sim_hardness$5


In [16]:
# Print the index of every dev sample whose gold SQL is exactly this query
# (exact string match, including the double spaces around '=').
TARGET_GOLD_SQL = 'SELECT T1.last_name FROM Owners AS T1 JOIN Dogs AS T2 ON T1.owner_id  =  T2.owner_id WHERE T2.age  =  ( SELECT max(age) FROM Dogs )'
matching_indices = [
    idx for idx, sample in enumerate(test_samples) if sample.query_gold == TARGET_GOLD_SQL
]
for idx in matching_indices:
    print(idx)
        

960
961


In [18]:
# Dev sample 960 is one of the matches found above. Its question asks for the owner of the
# *youngest* dog, while that gold query selects max(age) (the oldest). Print the zero-shot
# prompt to see exactly what the model is given for this case.
zero_shot_prompts = prompt_merged_dict['openai$full_schema']
print(zero_shot_prompts[960])

### Complete sqlite SQL query only and with no explanation.
### SQLite SQL tables, with their properties:
#
# breeds: (breed_code, breed_name)
# charges: (charge_id, charge_type, charge_amount)
# sizes: (size_code, size_description)
# treatment_types: (treatment_type_code, treatment_type_description)
# owners: (owner_id, first_name, last_name, street, city, state, zip_code, email_address, home_phone, cell_number)
# dogs: (dog_id, owner_id, abandoned_yn, breed_code, size_code, name, age, date_of_birth, gender, weight, date_arrived, date_adopted, date_departed)
# professionals: (professional_id, role_code, first_name, street, city, state, zip_code, last_name, email_address, home_phone, cell_number)
# treatments: (treatment_id, dog_id, professional_id, treatment_type_code, date_of_treatment, cost_of_treatment)
#
### Question: List the last name of the owner owning the youngest dog.
### Answer: SELECT


In [19]:
# Print a second zero-shot prompt (dev sample 826) for comparison.
prompt_826 = prompt_merged_dict['openai$full_schema'][826]
print(prompt_826)

### Complete sqlite SQL query only and with no explanation.
### SQLite SQL tables, with their properties:
#
# conductor: (conductor_id, name, age, nationality, year_of_work)
# orchestra: (orchestra_id, orchestra, conductor_id, record_company, year_of_founded, major_record_format)
# performance: (performance_id, orchestra_id, type, date, official_ratings_(millions), weekly_rank, share)
# show: (show_id, performance_id, if_first_show, result, attendance)
#
### Question: What are the names of conductors whose nationalities are not "USA"?
### Answer: SELECT


In [20]:
# Scratch check: flatten a hand-written multi-line SQL query onto a single line.
# NOTE(review): the second EXISTS (v3/c3) only requires some 'Tabatha Gehling' vote sharing v's area
# code; v itself already satisfies that through the c1 filter, so that clause looks redundant.
sql = """
SELECT DISTINCT substr(v.phone_number, 1, 3) AS area_code
FROM votes v
JOIN contestants c1 ON v.contestant_number = c1.contestant_number
WHERE c1.contestant_name = 'Tabatha Gehling'
AND EXISTS (
    SELECT 1 FROM votes v2
    JOIN contestants c2 ON v2.contestant_number = c2.contestant_number
    WHERE substr(v2.phone_number, 1, 3) = substr(v.phone_number, 1, 3)
      AND c2.contestant_name = 'Kelly Clauss'
)
AND EXISTS (
    SELECT 1 FROM votes v3
    JOIN contestants c3 ON v3.contestant_number = c3.contestant_number
    WHERE substr(v3.phone_number, 1, 3) = substr(v.phone_number, 1, 3)
      AND c3.contestant_name = 'Tabatha Gehling'
)
"""

# Quoted SQL literals ('...' with '' escapes, or "..." identifiers) are matched first so their
# contents are kept verbatim; every other run of whitespace is collapsed to a single space.
_SQL_LITERAL_OR_SPACE = re.compile(r"""('(?:[^']|'')*'|"(?:[^"]|"")*")|\s+""")


def collapse_sql_whitespace(query: str) -> str:
    """Return ``query`` on one line: newlines, tabs and repeated spaces become one space.

    Unlike a blanket whitespace-collapsing regex, whitespace inside quoted literals
    (e.g. ``'a  b'``) is left untouched, so string comparisons in the query keep their
    meaning. Leading and trailing whitespace is removed.
    """
    return _SQL_LITERAL_OR_SPACE.sub(lambda m: m.group(1) or " ", query).strip()


single_line_sql = collapse_sql_whitespace(sql)
print(single_line_sql)

SELECT DISTINCT substr(v.phone_number, 1, 3) AS area_code FROM votes v JOIN contestants c1 ON v.contestant_number = c1.contestant_number WHERE c1.contestant_name = 'Tabatha Gehling' AND EXISTS ( SELECT 1 FROM votes v2 JOIN contestants c2 ON v2.contestant_number = c2.contestant_number WHERE substr(v2.phone_number, 1, 3) = substr(v.phone_number, 1, 3) AND c2.contestant_name = 'Kelly Clauss' ) AND EXISTS ( SELECT 1 FROM votes v3 JOIN contestants c3 ON v3.contestant_number = c3.contestant_number WHERE substr(v3.phone_number, 1, 3) = substr(v.phone_number, 1, 3) AND c3.contestant_name = 'Tabatha Gehling' )
