In [None]:
import sqlite3, json, warnings
import pandas as pd
from transformers import AutoTokenizer, AutoModelForCausalLM

In [None]:
# Local checkpoint directory for NSQL-350M; the tokenizer and the model are
# both loaded from this same location.
nsql_350M_path = "../models/nsql-350M"
tokenizer_350M = AutoTokenizer.from_pretrained(nsql_350M_path)
model_350M = AutoModelForCausalLM.from_pretrained(nsql_350M_path)

# The 2B checkpoint is currently disabled; uncomment to load it as well.
# tokenizer_2B = AutoTokenizer.from_pretrained("../models/nsql-2B")
# model_2B = AutoModelForCausalLM.from_pretrained("../models/nsql-2B")

In [None]:
# Mapping from a (lower-cased) table name to the name of the database that
# contains it; used below to decide which .sqlite file a question runs against.
# JSON is read explicitly as UTF-8 so the result does not depend on the
# platform's default text encoding.
with open("../src/spider/table_database_map.json", encoding="utf-8") as f:
    table_map_db = json.load(f)

In [None]:
# Training split of the dataset; later cells read its 'Question', 'Table'
# and 'SQL' columns.
train_csv_path = '../src/NSText2SQL/train_spider.csv'
df = pd.read_csv(train_csv_path)
df.head()

In [None]:
def _table_name_from_create_line(line):
    """Return the lower-cased table name found on a ``CREATE TABLE`` line, or None."""
    remainder = line.split("CREATE TABLE", 1)[1]
    # Everything before the first "(" is the name part; the paren may be
    # glued to the name ("t(") or missing entirely (it is on the next line).
    tokens = remainder.split("(", 1)[0].split()
    if [token.upper() for token in tokens[:3]] == ["IF", "NOT", "EXISTS"]:
        tokens = tokens[3:]
    if not tokens:
        return None
    # Drop identifier quoting so the name can be used as a plain lookup key.
    return tokens[0].strip('"`\'[]').lower()


def table_column_of_create_table(query):
    """Extract table and column names from multi-line ``CREATE TABLE`` text.

    The parser is line based: a line containing ``CREATE TABLE`` opens a
    table, each following line contributes its first token as a column name,
    and a line ending in ``)`` or ``);`` closes the table.

    Args:
        query: Schema text holding one or more ``CREATE TABLE`` statements,
            with one column definition per line.

    Returns:
        A ``(table_names, columns)`` tuple. ``table_names`` holds the
        lower-cased table names in order of appearance. ``columns`` is a flat
        list of column names across all tables, in their original case.
        A ``CREATE TABLE`` line that carries no name contributes no entry.
    """
    # Leading keywords of table-level clauses, which are not column names.
    non_column_keywords = ("CONSTRAINT", "PRIMARY", "FOREIGN")
    closers = (")", ");")

    table_names = []
    columns = []
    capture = False
    for raw_line in query.splitlines():
        line = raw_line.strip()
        if "CREATE TABLE" in line:
            table_name = _table_name_from_create_line(line)
            if table_name is not None:
                table_names.append(table_name)
            # A statement written entirely on one line is already closed,
            # so the lines after it must not be read as its columns.
            capture = not line.endswith(closers)
        elif line.endswith(closers):
            capture = False
        elif capture and line:
            # `and line` skips blank lines, which have no first token.
            column_name = line.split()[0]
            if column_name in non_column_keywords:
                continue
            columns.append(column_name)
    return table_names, columns

In [None]:
def create_prompt(question, schema):
    """Assemble the text-to-SQL prompt for one question.

    The prompt consists of the schema, a fixed instruction comment, the
    question as a SQL comment, and a trailing ``SELECT`` for the model to
    continue from.

    Args:
        question: Natural-language question to answer.
        schema: Schema text (converted with ``str``) placed at the top.

    Returns:
        The full prompt string.
    """
    instruction = "-- Using valid SQLite, answer the following questions for the tables provided above.\n\n"
    pieces = [
        f"{str(schema)}\n\n",
        instruction,
        f"--{question}\n\nSELECT",
    ]
    return "".join(pieces)


In [None]:
def pred_sql(prompt, model_name):
    """Generate a SQL query for ``prompt`` with the selected model.

    Args:
        prompt: Full prompt text, e.g. as built by ``create_prompt``.
        model_name: Model identifier; only ``"nsql-350M"`` is handled.

    Returns:
        The last line of the decoded generation, or ``None`` when
        ``model_name`` is anything other than ``"nsql-350M"``.
    """
    if model_name != "nsql-350M":
        # Only the 350M checkpoint is wired up. To support "nsql-2B", repeat
        # the steps below with tokenizer_2B / model_2B once they are loaded.
        return None

    encoded = tokenizer_350M(prompt, return_tensors="pt")
    # NOTE(review): max_length appears to bound prompt + generated tokens
    # together - confirm that long schemas still leave room for the answer.
    output_ids = model_350M.generate(encoded.input_ids, max_length=500)
    decoded = tokenizer_350M.decode(output_ids[0], skip_special_tokens=True)
    # Keep only the final line of the decoded text.
    return decoded.split('\n')[-1]

## Querying the database with SQL

In [None]:
def query_db(sql_query, db_name):
    """Run ``sql_query`` against the SQLite file of database ``db_name``.

    Failures are reported as sentinel strings rather than raised, so a bad
    query does not stop the caller.

    Args:
        sql_query: SQL text to execute.
        db_name: Database name; the file is expected at
            ``../src/spider/database/<db_name>/<db_name>.sqlite``.

    Returns:
        The list of row tuples from ``fetchall()`` on success,
        ``"CANNOT CONNECT DATABASE"`` if the connection cannot be opened, or
        ``"CANNOT FETCHING DATA"`` if executing or fetching fails.
    """
    db_path = f'../src/spider/database/{db_name}/{db_name}.sqlite'
    # `except Exception` (not a bare except) keeps the best-effort behaviour
    # while still letting KeyboardInterrupt / SystemExit stop a long run.
    try:
        conn = sqlite3.connect(db_path)
    except Exception:
        return "CANNOT CONNECT DATABASE"
    try:
        return conn.execute(sql_query).fetchall()
    except Exception:
        return "CANNOT FETCHING DATA"
    finally:
        # Close on every path; a failing query used to leave the connection open.
        conn.close()


In [None]:
# Show the (rows, columns) size of the loaded dataset.
df.shape

In [None]:
# Per-row outputs. Every path through the loop appends exactly one entry to
# each active list so they stay aligned with the rows of `df`.
expect_query_results = []
nsql350M_query_results = []
nsql350M_query = []
nsql2B_query_results = []
nsql2B_query = []

with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    for i, row in df.iterrows():
        question = row['Question']
        print(question)
        schema = row['Table']
        expect_query = row['SQL']
        schema_tables = table_column_of_create_table(schema)[0]
        try:
            # The database is chosen from the first table declared in the schema.
            db_of_table = table_map_db[schema_tables[0].lower()]
        except (KeyError, IndexError):
            # KeyError: the table is missing from the table -> database map.
            # IndexError: no table name could be parsed from the schema text.
            # Either way, record a placeholder instead of aborting the run.
            error_occur = "TABLE NOT MATCH"
            expect_query_results.append(error_occur)
            nsql350M_query.append(error_occur)
            nsql350M_query_results.append(error_occur)
            # nsql2B_query.append(error_occur)
            # nsql2B_query_results.append(error_occur)
            continue

        expect_result = query_db(expect_query, db_of_table)
        full_prompt = create_prompt(question, schema)
        # Model inference is currently disabled; every row gets the literal
        # placeholder query "test". Restore the try/except below to run the model.
        # try:
        #     pred_sql_query_350M = pred_sql(full_prompt, "nsql-350M")
        # except:
        #     pred_sql_query_350M = "GEN QUERY ERROR"
        pred_sql_query_350M = "test"
        print(pred_sql_query_350M)
        pred_result_350M = query_db(pred_sql_query_350M, db_of_table)

        # try:
        #     pred_sql_query_2B = pred_sql(full_prompt, "nsql-2B")
        # except:
        #     pred_sql_query_2B = "GEN QUERY ERROR"
        # print(pred_sql_query_2B)
        # pred_result_2B = query_db(pred_sql_query_2B, db_of_table)

        expect_query_results.append(expect_result)
        nsql350M_query.append(pred_sql_query_350M)
        nsql350M_query_results.append(pred_result_350M)
        # nsql2B_query.append(pred_sql_query_2B)
        # nsql2B_query_results.append(pred_result_2B)
        print(i)

In [None]:
# Attach the per-row outputs collected by the evaluation loop as new columns,
# in this order, then persist the augmented frame.
result_columns = {
    'nsql-350M-query': nsql350M_query,
    # 'nsql-2B-query': nsql2B_query,
    'expect_result': expect_query_results,
    'nsql-350M-result': nsql350M_query_results,
    # 'nsql-2B-result': nsql2B_query_results,
}
for column_name, column_values in result_columns.items():
    df[column_name] = column_values

df.to_csv("model-expirements.csv", index=False)
df.head()

In [None]:
# Reload saved experiment results and keep only rows whose expected result
# is a real query result.
df = pd.read_csv("../src/temp-model-expirements.csv")
# Sentinel strings that are stored in place of a result when the table lookup,
# the database connection, or the query itself failed.
exclude_values = ["CANNOT FETCHING DATA", "CANNOT CONNECT DATABASE", "TABLE NOT MATCH"]
expect_col = df['Expect-result']
# Empty cells are loaded by read_csv as NaN, so missing values are tested
# with isna() rather than by listing None among the excluded values.
filtered_df = df[~(expect_col.isin(exclude_values) | expect_col.isna())]
print(filtered_df.shape)
filtered_df.head()

In [None]:
# Write the filtered rows to an Excel workbook, without the index column.
filtered_df.to_excel("model-exp.xlsx", index=False)