### CodeS Query Execution
This notebook executes all of the predicted CodeS queries in order to emulate the CodeS approach of generating multiple candidate inferences for a single NL question and selecting the first one that does not raise a SQL engine error.

We have to do this with our own implementation because the benchmarks in the CodeS code base use SQLite databases, whereas SNAILS uses MS SQL Server databases.

```
Copyright 2024 Kyle Luoma

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
```

In [None]:
import multiprocessing as mp
# force=True keeps this cell re-runnable: without it, a second call to
# set_start_method in the same kernel raises
# "RuntimeError: context has already been set".
mp.set_start_method("spawn", force=True)
import src.util.db_util_sqlite as db_util_sqlite
import pandas as pd
import src.nl_to_sql_inference_and_prompt_generation as pgn
from tqdm import tqdm

In [None]:
# Worker-process count reserved for parallel denaturalization; the notebook
# currently denaturalizes queries sequentially.
NUM_PROCESSES = 32


In [None]:
# Maps the naturalness level names used in the CodeS prediction file names to
# the short labels used in the output file names and modifier columns.
nat_label_lookup = dict(
    Native="NATIVE",
    Regular="N1",
    Low="N2",
    Least="N3",
)

# Level names in processing order (dicts preserve insertion order).
nat_levels = list(nat_label_lookup)

In [None]:
def mp_denat_f(query_data: tuple) -> tuple:
    """Denaturalize one predicted query.

    Args:
        query_data: Positional tuple of
            (database, question number, candidate index, query, naturalness level).

    Returns:
        (database, question number, candidate index, denaturalized query).
    """
    db, q_num, q_num_ix = query_data[0], query_data[1], query_data[2]
    predicted_sql = query_data[3]
    level = query_data[4]

    # The same naturalness level is applied to both tables and columns.
    naturalness = dict.fromkeys(("table", "column"), level)

    # Only the part of the database label before the first "-" is passed on
    # as the database name.
    base_db_name = db.split("-")[0]

    return (
        db,
        q_num,
        q_num_ix,
        pgn.denaturalize_query(
            query=predicted_sql,
            naturalness=naturalness,
            db_name=base_db_name,
            syntax="sqlite"
        ),
    )

In [None]:
# For each naturalness level: load the CodeS candidate predictions, map them
# back to the native schema when needed, execute the candidates, keep one
# candidate per (database, question), and write one workbook per database.
for nat_level in nat_levels:
    codes_df = pd.read_excel(f"./queries/codes_predictions/snails_{nat_level}_CodeS_predictions.xlsx")
    # NOTE(review): this restricts the run to a single database; confirm it is
    # intentional and not a leftover from a partial re-run.
    codes_df = codes_df.query("database=='SBODemoUS-Service'")

    # (database, number, predicted_sql_ix) -> denaturalized query text.
    denat_query_lookup = {}
    if nat_level != "Native":
        # Denaturalization runs sequentially in-process, so no worker pool is
        # created here; an idle pool would only leak spawned processes.
        mp_query_list = [
            (
                row.database,
                row.number,
                row.predicted_sql_ix,
                row.query_predicted,
                # Second character of the label, e.g. "N2" -> "2".
                nat_label_lookup[nat_level][1]
            )
            for row in codes_df.itertuples()
        ]
        print(f"Running mp_denat_f for nat_level {nat_level}")
        denaturalized_queries = [mp_denat_f(q_tuple) for q_tuple in tqdm(mp_query_list)]
        print("Finished mp_denat_f")

        for q in denaturalized_queries:
            denat_query_lookup[(q[0], q[1], q[2])] = q[3]

    output_columns = [
        "database",
        "number",
        "question",
        "hints",
        "notes",
        "query_gold",
        "schema_pruning",
        "prompt",
        "query_predicted_on_naturalized_schema",
        "query_predicted",
        "col_naturalness_modifier",
        "tab_naturalness_modifier",
    ]
    selected_rows = []
    # Questions that already have a selected candidate; a set gives O(1)
    # membership checks.
    completed_queries = set()

    for row in codes_df.itertuples():
        db_name = row.database.split("-")[0]
        if nat_level == "Native":
            query = row.query_predicted
        else:
            query = denat_query_lookup[(row.database, row.number, row.predicted_sql_ix)]

        question_key = (row.database, row.number)
        if question_key in completed_queries:
            continue

        try:
            res_df = db_util_sqlite.do_query(
                query=query,
                database_name=db_name,
                db_list_file="./.local/dbinfo_sqlite.json"
                )
        except db_util_sqlite.sqlite3.OperationalError as e:
            print(e)
            res_df = None

        # Keep the first candidate that returns a DataFrame. The candidate with
        # predicted_sql_ix == 3 is kept even when it fails, so every question
        # gets a row (presumably it is the last candidate - TODO confirm).
        if isinstance(res_df, pd.DataFrame) or row.predicted_sql_ix == 3:
            completed_queries.add(question_key)
            selected_rows.append({
                "database": row.database,
                "number": row.number,
                "question": row.question,
                "hints": [],
                "notes": [],
                "query_gold": row.query_gold,
                "schema_pruning": 1,
                "prompt": "CodeS prompt",
                "query_predicted_on_naturalized_schema": row.query_predicted,
                "query_predicted": query,
                "col_naturalness_modifier": nat_label_lookup[row.col_naturalness_modifier],
                "tab_naturalness_modifier": nat_label_lookup[row.tab_naturalness_modifier],
            })

    # Explicit columns keep the frame's schema stable even when no row was selected.
    predict_df = pd.DataFrame(selected_rows, columns=output_columns)
    for db_name in predict_df.database.unique():
        print(db_name)
        predict_df[predict_df.database == db_name].to_excel(
            f"./queries/predicted/{db_name}-{nat_label_lookup[nat_level]}-{nat_label_lookup[nat_level]}-CodeS-queries-predicted.xlsx"
        )
            

In [None]:
# SNAILS database labels: the standalone databases first, followed by the
# SBODemoUS modules, which share a "SBODemoUS-" prefix.
snails_dbs = [
    "ASIS_20161108_HerpInv_Database",
    "ATBI",
    "CratersWildlifeObservations",
    "KlamathInvasiveSpecies",
    "NorthernPlainsFireManagement",
    "NTSB",
    "NYSED_SRC2022",
    "PacificIslandLandbirds",
] + [
    f"SBODemoUS-{module}"
    for module in (
        "Banking",
        "Business Partners",
        "Finance",
        "General",
        "Human Resources",
        "Inventory and Production",
        "Reports",
        "Sales Opportunities",
        "Service",
    )
]

In [None]:
# Hand-written SQL Server -> SQLite rewrites for individual gold queries.
# Keys are (database label, question number); each value is an ordered list of
# (text to find, replacement text) pairs that are applied with str.replace,
# so the "find" text must match the gold query exactly, including letter case.
query_modifications = {
    ("ATBI", 7): [
        ("select TOP 1 Event_ID", "select Event_ID"), 
        ("count(*) desc", "count(*) desc limit 1")
    ],
    ("ATBI", 9): [
        ("SELECT TOP 1 Event_ID", "SELECT Event_ID"),
        ("order by count(*)", "order by count(*) limit 1")
    ],
    ("ATBI", 36): [
        ("select top 1 decay", "select decay"),
        ("order by MPD asc", "order by MPD asc limit 1")
    ],
    ("ATBI", 37): [
        ("select top 1 decay", "select decay"),
        ("order by Length asc", "order by Length asc limit 1")
    ],
    ("KlamathInvasiveSpecies", 1): [
        ("and YEAR(Start_Date) = 2011", ""),
        ("and MONTH(Start_Date) = 8", ""),
        ("where DAY(Start_Date) = 5", "WHERE date(Start_Date) between date('2011-08-05') AND date('2011-08-05')")
    ],
    ("KlamathInvasiveSpecies", 2): [
        ("and YEAR(Start_Date) = 2011", ""),
        ("and MONTH(Start_Date) = 8", ""),
        ("where DAY(Start_Date) = 5", "WHERE date(Start_Date) between date('2011-08-05') AND date('2011-08-05')")
    ],
    ("KlamathInvasiveSpecies", 8): [
        (
            "where year(Start_Date) = 2009 and month(Start_Date) = 7 and day(Start_Date) = 28",
            "WHERE date(Start_Date) BETWEEN date('2009-07-28') AND date('2009-07-28')"
        ),
        ("top 1 ", ""),
        ("order by Elevation desc", "order by Elevation desc limit 1")
    ],
    ("KlamathInvasiveSpecies", 21): [
        ("year(Revision_Date)", "strftime('%Y', Revision_Date)")
    ],
    ("KlamathInvasiveSpecies", 27): [
        (
            "where year(Start_Date) = 2021 and month(Start_Date) = 8", 
            "where date(Start_Date) >= date('2021-08-01') and date(Start_Date) <= date('2021-08-31')"
        )
    ],
    ("KlamathInvasiveSpecies", 39): [
        (
            "where year(Start_Date) = 2015 and month(Start_Date) = 7",
            "where date(Start_Date) >= date('2015-07-01') and date(Start_Date) <= date('2015-07-31')"
        )
    ],
    ("NorthernPlainsFireManagement", 1): [
        ("year(Event_Date)", "strftime('%Y', Event_Date)")
    ],
    ("NorthernPlainsFireManagement", 4): [
        ("year(event_date) = 2001", "strftime('%Y', Event_Date) = '2001'")
    ],
    ("NorthernPlainsFireManagement", 24): [
        ("where year(release_date) = 2015", "where date(release_date) >= date('2015-01-01') AND date(release_date) <= date('2015-12-31')")
    ],
    ("NorthernPlainsFireManagement", 26): [
        ("year(Event_Date) = 1999", "strftime('%Y', Event_Date) = '1999'")
    ],
    ("NorthernPlainsFireManagement", 27): [
        ("month(e.Event_Date)", "strftime('%m', e.Event_Date)")
    ],
    ("NorthernPlainsFireManagement", 35): [
        ("month(Event_Date)", "strftime('%m', Event_Date)"),
        ("top 1", ""),
        ("desc", "desc limit 1")
    ],
    ("NTSB", 71): [
        ("ORDER BY CASEID", "ORDER BY EV.CASEID")
    ],
    ("NYSED_SRC2022", 12): [
        ("where not exists ", "where p.ENTITY_CD NOT IN"),
        ("and p.ENTITY_CD = q.ENTITY_CD", "")
    ],
    # In the year-range rewrites below, each date(...) call must close before
    # the "and"; otherwise SQLite evaluates the AND inside date() and the
    # lower bound stops filtering anything.
    ("SBODemoUS-Banking", 5): [
        ("year(CreateDate) = 2012", "date(CreateDate) >= date('2012-01-01') and date(CreateDate) <= date('2012-12-31')") 
    ],
    ("SBODemoUS-Banking", 9): [
        ("where DeposDate = '2012-01-31'", "where date(DeposDate) = '2012-01-31'") 
    ],
    ("SBODemoUS-Business Partners", 7): [
        ("Year(CreateDate) = '2021'", "date(CreateDate) >= date('2021-01-01') and date(CreateDate) <= date('2021-12-31')") 
    ],
    ("SBODemoUS-Finance", 2): [
        ("year(DueDate) = '2021'", "date(DueDate) >= date('2021-01-01') and date(DueDate) <= date('2021-12-31')") 
    ],
    ("SBODemoUS-Finance", 6): [
        ("year(RateDate)", "strftime('%Y', RateDate)"),
        ("top 1 ", ""),
        ("desc", "desc limit 1")
    ],
    ("SBODemoUS-Finance", 7): [
        ("year(NextDeu)", "strftime('%Y', NextDeu)") 
    ],
    ("SBODemoUS-Finance", 10): [
        ("year(ValidFrom) > 2020", "date(ValidFrom) > date('2020-12-31')") 
    ],
    ("SBODemoUS-General", 1): [
        ("year < '2021-10-20'", "date(START) < date('2021-10-20')") 
    ],
    ("SBODemoUS-Sales Opportunities", 4): [
        ("year(CloseDate) = 2014", "date(CloseDate) >= date('2014-01-01') and date(CloseDate) <= date('2014-12-31')") 
    ],
    ("SBODemoUS-Sales Opportunities", 5): [
        ("year(ValidFrom) = 2021", "date(ValidFrom) >= date('2021-01-01') and date(ValidFrom) <= date('2021-12-31')") 
    ],
    ("SBODemoUS-Sales Opportunities", 10): [
        ("year(CreateDate)", "strftime('%Y', CreateDate)") 
    ]
}

### Modify gold queries for SQLite compatibility

In [None]:
import re

import pandas as pd
snails_dbs = [
    "ASIS_20161108_HerpInv_Database",
    "ATBI",
    "CratersWildlifeObservations",
    "KlamathInvasiveSpecies",
    "NorthernPlainsFireManagement",
    "NTSB",
    "NYSED_SRC2022",
    "PacificIslandLandbirds",
    "SBODemoUS-Banking",
    "SBODemoUS-Business Partners",
    "SBODemoUS-Finance",
    "SBODemoUS-General",
    "SBODemoUS-Human Resources",
    "SBODemoUS-Inventory and Production",
    "SBODemoUS-Reports",
    "SBODemoUS-Sales Opportunities",
    "SBODemoUS-Service"
]

# Matches a T-SQL "TOP n" / "TOP (n)" clause as a whole word in any letter
# case, so words that merely contain "top" (e.g. "stop") are not mistaken for it.
top_clause_pattern = re.compile(
    r"\btop(?:\s+(\d+)\b|\s*\(\s*(\d+)\s*\))",
    flags=re.IGNORECASE,
)

# Rewrite the gold queries in the archived prediction workbooks so they run on
# SQLite, then save the updated workbooks plus a .sql listing per database.
for db_name in snails_dbs:
    for nat_level in ["NATIVE", "N1", "N2", "N3"]:
        sql_file_text = ""
        limit_queries = []
        pred_df = pd.read_excel(f"./queries/predicted/archive/codes_with_top_in_gold/{db_name}-{nat_level}-{nat_level}-CodeS-queries-predicted.xlsx")
        for row in pred_df.itertuples():
            sql_file_text += f"-- {row.number}: {row.question}\n"
            mod_key = (row.database, row.number)
            top_match = top_clause_pattern.search(row.query_gold)
            if mod_key in query_modifications:
                # Hand-written rewrites take precedence over the generic
                # TOP -> LIMIT conversion.
                limit_query = row.query_gold
                for repl_tuple in query_modifications[mod_key]:
                    limit_query = limit_query.replace(repl_tuple[0], repl_tuple[1])
            elif top_match is not None:
                top_num = top_match.group(1) or top_match.group(2)
                limit_query = top_clause_pattern.sub("", row.query_gold)
                # Put LIMIT on its own line so it never fuses with the last
                # token of the query (e.g. "desc" + "limit 1").
                # NOTE(review): a single trailing LIMIT is only correct when
                # TOP appears in the outermost SELECT; queries with TOP in a
                # subquery need an explicit query_modifications entry.
                limit_query = f"{limit_query.rstrip()}\nlimit {top_num}"
            else:
                limit_query = row.query_gold
            limit_queries.append(limit_query)
            sql_file_text += f"{limit_query}\n;\n\n"
        pred_df["query_gold"] = limit_queries
        pred_df["notes"] = "Gold query modified from top to limit for sqlite compatibility"
        pred_df.to_excel(f"./queries/predicted/{db_name}-{nat_level}-{nat_level}-CodeS-queries-predicted.xlsx")
        # The .sql path does not include nat_level, so each naturalness level
        # overwrites the same file.
        with open(f"./queries/snails_sqlite_queries/{db_name}_Native.sql", "w") as f:
            f.write(sql_file_text)



### Test SQLite-compatible gold queries

In [None]:
import src.util.db_util_sqlite as db_util_sqlite
import pandas as pd
snails_dbs = [
    "ASIS_20161108_HerpInv_Database",
    "ATBI",
    "CratersWildlifeObservations",
    "KlamathInvasiveSpecies",
    "NorthernPlainsFireManagement",
    "NTSB",
    "NYSED_SRC2022",
    "PacificIslandLandbirds",
    "SBODemoUS-Banking",
    "SBODemoUS-Business Partners",
    "SBODemoUS-Finance",
    "SBODemoUS-General",
    "SBODemoUS-Human Resources",
    "SBODemoUS-Inventory and Production",
    "SBODemoUS-Reports",
    "SBODemoUS-Sales Opportunities",
    "SBODemoUS-Service"
]
# Run every rewritten gold query against SQLite and report the ones that fail
# or return no rows. Only the NATIVE workbooks are checked.
for db_name in snails_dbs:
    for nat_level in ["NATIVE"]:
        pred_df = pd.read_excel(f"./queries/predicted/{db_name}-{nat_level}-{nat_level}-CodeS-queries-predicted.xlsx")
        for row in pred_df.itertuples():
            try:
                res = db_util_sqlite.do_query(
                    row.query_gold,
                    row.database.split("-")[0],
                    db_list_file="./.local/dbinfo_sqlite.json"
                )
            except db_util_sqlite.sqlite3.OperationalError as e:
                # Report and keep going so one broken gold query does not hide
                # the problems in the queries after it.
                print(row.database, row.number, "Failed to execute:", e)
                print(row.query_gold)
                continue
            # NOTE(review): assumes do_query may return None instead of a
            # DataFrame on some failures - confirm against db_util_sqlite.
            if res is None:
                print(row.database, row.number, "Returned no result object!!!")
                print(row.query_gold)
                continue
            if res.shape[0] == 0:
                print(row.database, row.number, "Has zero-length result!!!")
                print(row.query_gold)
                print(row.query_gold)
            