In [99]:
import json
import os
import re
import sqlite3
import sys
import time

import numpy as np
import pandas as pd
from dotenv import load_dotenv
from nltk.stem import WordNetLemmatizer
from openai import OpenAI
from sentence_transformers import SentenceTransformer
from sql_metadata import Parser
from transformers import AutoTokenizer, AutoModelForCausalLM

# The project-local package lives one directory up, so extend the path first.
sys.path.append('../')
from filtering_schema.Description_base_linking import SchemaLinking

## Load the schema-linking filter from the `filtering_schema` module

In [6]:
# Pull variables from a local .env file (if any) into the process environment
# before the explicit settings below are applied.
load_dotenv()

base_dir = "../filtering_schema"
schema_root = os.path.join(base_dir, 'src', 'schemas')
models_root = os.path.join(base_dir, 'models')

# Every setting is exported as an environment variable (all values are strings).
# The schema folders point at the coffee-shop domain; the generic
# 'column-descriptions' / 'column-datatypes' folders under the same root are
# the alternative that was previously configured here.
os.environ.update({
    'nsql_model_path': os.path.join(models_root, 'nsql-350M'),
    'sentence_emb_model_path': os.path.join(models_root, 'all-MiniLM-L6-v2'),
    'schema_description_folder_path': os.path.join(schema_root, 'coffeeshop-descriptions'),
    'schema_data_types_folder_path': os.path.join(schema_root, 'coffeeshop-datatypes'),
    'column_threshold': '0.2',
    'table_threshold': '0.2',
    'max_select_column': '10',
    'filter_table': 'False',
    'verbose': 'False',
})

# Build the schema linker and point it at the selected domain's folders.
# NOTE(review): SchemaLinking presumably reads the thresholds above from the
# environment -- confirm against filtering_schema.Description_base_linking.
schema_link = SchemaLinking()
schema_link.selected_domain(
    schema_description_folder_path=os.environ['schema_description_folder_path'],
    schema_data_types_folder_path=os.environ['schema_data_types_folder_path'],
)

# Tokenizer and causal language model are both loaded from the same local folder.
nsql_model_path = os.environ['nsql_model_path']
tokenizer = AutoTokenizer.from_pretrained(nsql_model_path)
model = AutoModelForCausalLM.from_pretrained(nsql_model_path)

# The flag is stored as text in the environment; convert it to a real boolean.
verbose = os.environ['verbose'].lower() == 'true'

In [7]:
# Try the schema filter on a single natural-language question. The cell output
# below shows the result: a mapping of table name -> {column name: score}.
schema_link.filter_schema("Which shop opening from 2022")

Table string match  ----> shop


{'happy_hour': {},
 'happy_hour_member': {},
 'member': {},
 'shop': {'Shop_ID': 0.394,
  'Address': 0.605,
  'Num_of_staff': 0.408,
  'Score': 0.357,
  'Open_Year': 0.643}}

## Example: asking the LLM to fill in the [MASK] placeholders

In [27]:
# OpenAI is already imported in the notebook's first cell; this re-import is
# redundant but harmless.
from openai import OpenAI

# API client built with default constructor arguments.
# NOTE(review): presumably picks its credentials up from the environment
# populated by load_dotenv() -- confirm.
client = OpenAI()
llm_model_name = 'gpt-3.5-turbo'
# Stop sequence passed to the API: a blank line ends the generation.
stop = ['\n\n']
# Few-shot prompt: task instructions, a toy `cat` schema, three solved
# question/input/output examples, and a final unsolved example whose
# "output:" line the model is expected to complete.
prompt = """You are a SQL query assistant.
I have some SQL where the [MASK] column is syntaxed and I want you to respond to output that populates the [MASK] column of the SQL input followed by the question and schema description (name - description).
If you don't know which column to fill in Do not include columns that you have created yourself. And only columns defined from the schema must be used. 
Do not use columns from other tables or schema. must also be used from the same table defined in the input.


table :     cat - this table contain cat information
column :    id - number for identify cat
            name - name of cat
            age - age of cat
            gender - gender of cat
            
question: show me the name of cat.
input: SELECT [MASK] FROM cat;
output: SELECT name FROM cat;

question: show me the name and age of cat.
input: SELECT [MASK], [MASK] FROM cat;
output: SELECT name, age FROM cat;

question: Count number of cate each gender.
input: SELECT [MASK], COUNT([MASK]) FROM cat GROUP BY [MASK];
output: SELECT gender, COUNT(*) FROM cat GROUP BY gender;

question: show me the name and age of cat.
input: SELECT [MASK], [MASK] FROM cat;
output:
"""

# Single chat-completion request: the system message asks for the bare answer
# only, and the user message carries the few-shot prompt built above.
response = client.chat.completions.create(
    model=llm_model_name,
    messages=[
            {"role": "system",
                "content": "I will give you some x-y examples followed by a x, you need to give me the y, and no other content."},
            {"role": "user", "content": prompt},
            ],
    stop=stop
)

In [28]:
# Display the text of the first completion choice returned for the example prompt.
response.choices[0].message.content

'SELECT name, age FROM cat;'

In [68]:
# One shared client for every request made through llm_response below.
client = OpenAI()
# Stop sequence passed to the API: a blank line ends the generation.
stop = ['\n\n']

def llm_response(prompt, model='gpt-3.5-turbo'):
    """Send `prompt` to a chat model and return the text of its first choice.

    Args:
        prompt: Complete few-shot prompt, sent as the single user message.
        model: Name of the OpenAI chat model to query.

    Returns:
        The content of the first completion choice as a string. When the API
        reports no text content (``None``) an empty string is returned, so
        callers can always apply string operations to the result.

    Raises:
        Whatever the OpenAI client raises for network or API errors; nothing
        is caught here.
    """
    response = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system",
             "content": "I will give you some x-y examples followed by a x, you need to give me the y, and no other content."},
            {"role": "user", "content": prompt},
        ],
        stop=stop,
    )
    content = response.choices[0].message.content
    # The API declares `content` as nullable; normalise to str for callers.
    return "" if content is None else content

## Evaluate on the Spider dataset

In [33]:
# Read the hand-written table/column descriptions (a JSON list with one entry
# per table: 'table', 'description', 'columns').
src_folder = "../src"
schema_description_file = "mockup_schema_description.json"
schema_description_path = os.path.join(src_folder, schema_description_file)
with open(schema_description_path) as f:
    dbs = json.load(f)

# NOTE(review): this rebinds `model`, which the setup cell bound to the causal
# language model -- the earlier object is no longer reachable under this name.
model = SentenceTransformer('../models/all-MiniLM-L6-v2')
lemmanizer = WordNetLemmatizer()

# Normalise table and column names to lower case, in place, so later lookups
# can compare names without worrying about case.
for entry in dbs:
    entry['table'] = entry['table'].lower()
    entry['columns'] = {
        column_name.lower(): column_description
        for column_name, column_description in entry['columns'].items()
    }

# Peek at the first normalised entry.
dbs[:1]

[{'table': 'musical',
  'description': 'This table contains information about musicals.',
  'columns': {'musical_id': 'Unique identifier for the musical',
   'name': 'Name of the musical',
   'year': 'Year the musical was produced',
   'award': 'Award received by the musical',
   'category': 'Category of the award',
   'nominee': 'Name of the nominee associated with the musical',
   'result': 'Result of the award nomination for the musical'}}]

In [34]:
# Load the training split; per the displayed preview it carries the columns
# Question, Table (CREATE TABLE statements) and SQL (the reference query).
df = pd.read_csv('../src/NSText2SQL/train_spider.csv')
# Print (rows, columns) and then show the first rows as the cell output.
print(df.shape)
df.head()

(6994, 3)


Unnamed: 0,Question,Table,SQL
0,"What are the first names, office locations of ...","CREATE TABLE course (\n crs_code text,\n ...","SELECT T2.emp_fname, T4.prof_office, T3.crs_de..."
1,Please show the songs that have result 'nomina...,"CREATE TABLE artist (\n artist_id number,\n...",SELECT T2.song FROM music_festival AS T1 JOIN ...
2,Which teams had more than 3 eliminations?,CREATE TABLE elimination (\n elimination_id...,SELECT team FROM elimination GROUP BY team HAV...
3,"Show the names of people, and dates and venues...","CREATE TABLE people (\n people_id number,\n...","SELECT T3.name, T2.date, T2.venue FROM debate_..."
4,Tell me the the date when the first claim was ...,CREATE TABLE settlements (\n settlement_id ...,SELECT date_claim_made FROM claims ORDER BY da...


In [40]:
def table_column_of_create_table(query):
    """Extract ``{table_name: [column_name, ...]}`` from CREATE TABLE text.

    The text is scanned line by line and is expected to hold one or more
    multi-line statements of the form::

        CREATE TABLE course (
            crs_code text,
            dept_code text
        )

    Table and column names are lower-cased. Lines whose first word is a
    table-level constraint keyword (CONSTRAINT, PRIMARY, FOREIGN, UNIQUE,
    CHECK) are not treated as columns. Blank lines are ignored.

    A table body ends at a line that starts with ``)`` or at a line that
    closes more parentheses than it opens (e.g. ``name text)``); a line with
    balanced parentheses such as ``name varchar(20)`` or ``PRIMARY KEY (id)``
    does not end the table.

    Args:
        query: String containing the CREATE TABLE statement(s).

    Returns:
        Dict mapping each table name to the list of its column names in
        declaration order. A statement written entirely on one line yields
        the table with an empty column list (its columns are not parsed).
    """
    non_column_keywords = {"CONSTRAINT", "PRIMARY", "FOREIGN", "UNIQUE", "CHECK"}
    schema = {}
    table_name = None  # None while we are outside any CREATE TABLE body

    for raw_line in query.splitlines():
        line = raw_line.strip()
        if not line:
            continue

        if "CREATE TABLE" in line:
            # The name is the last word between "CREATE TABLE" and the opening
            # parenthesis, whether or not a space separates them.
            header_words = line.split("CREATE TABLE", 1)[1].split("(", 1)[0].split()
            if not header_words:
                table_name = None
                continue
            table_name = header_words[-1].lower()
            schema[table_name] = []
            if line.endswith((")", ");")):
                # Whole statement on one line: nothing further to capture.
                table_name = None
            continue

        if table_name is None:
            continue

        if line.startswith(")"):
            table_name = None
            continue

        # The last column may share its line with the closing parenthesis.
        closes_table = line.count(")") > line.count("(")

        first_word = line.split()[0].rstrip(",)")
        if first_word and first_word.upper() not in non_column_keywords:
            schema[table_name].append(first_word.lower())

        if closes_table:
            table_name = None

    return schema

In [81]:
def query_db(sql_query, db_name):
    """Run `sql_query` against one Spider SQLite database and return all rows.

    The database file is looked up at
    ``../src/spider/database/<db_name>/<db_name>.sqlite``.

    Args:
        sql_query: SQL text to execute.
        db_name: Spider database identifier (used as folder and file name).

    Returns:
        The list of result rows from ``fetchall()`` on success; the string
        ``"CANNOT CONNECT DATABASE"`` if the connection or cursor could not
        be created; the string ``"CANNOT FETCHING DATA"`` if executing the
        query or fetching its rows failed.
    """
    conn = None
    try:
        try:
            conn = sqlite3.connect(f'../src/spider/database/{db_name}/{db_name}.sqlite')
            cursor = conn.cursor()
        except Exception:
            # `Exception` rather than a bare except, so Ctrl-C still interrupts.
            return "CANNOT CONNECT DATABASE"
        try:
            cursor.execute(sql_query)
            return cursor.fetchall()
        except Exception:
            return "CANNOT FETCHING DATA"
    finally:
        # Always release the connection, including when the query failed.
        if conn is not None:
            conn.close()

In [39]:
# Lookup used by the evaluation loop to find the Spider database id for a
# given table name.
with open("../src/spider/table_database_map.json", "r") as mapping_file:
    map_table_db = json.loads(mapping_file.read())

In [103]:
# Per-question results, collected column-wise and turned into a DataFrame later.
df_data = {"Question": [],
           "Actual SQL": [],
           "Predict MASK SQL": [],
           "Actual result": [],
           "Predict result": []}

# Names of all tables that have a hand-written description.
exists_table = [entry['table'].lower() for entry in dbs]
known_tables = set(exists_table)

# Instructions plus two solved examples; the target schema, question and
# masked query are appended to this for every evaluated row.
short_learning_prompt = """You are a SQL query assistant.
I have some SQL where the [MASK] column is syntaxed and I want you to respond to output that populates the [MASK] column of the SQL input followed by the question and schema description (name - description).
If you don't know which column to fill in Do not include columns that you have created yourself. And only columns defined from the schema must be used. 
Do not use columns from other tables or schema. must also be used from the same table defined in the input.

#################

table :     cat - this table contain cat information
column :    id - number for identify cat
            name - name of cat
            age - age of cat
            gender - gender of cat
question: Count number of cate each gender.
input: SELECT [MASK], COUNT([MASK]) FROM cat GROUP BY [MASK];
output: SELECT gender, COUNT(*) FROM cat GROUP BY gender;

#################

table :     cat - this table contain cat information
column :    id - number for identify cat
            name - name of cat
            age - age of cat
            gender - gender of cat
question: show me the name and age of cat.
input: SELECT [MASK], [MASK] FROM cat;
output: SELECT name, age FROM cat;

#################
"""


def _referenced_tables(sql_text):
    """Return the lower-cased names of the tables that `sql_text` refers to.

    Combines the tables reported by the SQL parser with the table part of
    every qualified ``table.column`` reference.
    """
    parsed = Parser(sql_text)
    referenced = set()
    for col in parsed.columns:
        if "." in col:
            # Split on the last dot only, so names with several dots do not
            # break the unpacking.
            referenced.add(col.rsplit(".", 1)[0].lower())
    referenced.update(name.lower() for name in parsed.tables)
    return referenced


def _mask_columns(sql_text, column_names):
    """Replace every whole-word occurrence of a column name with ``[MASK]``.

    Matching is case-insensitive (the schema names are lower-cased while the
    query keeps its original case) and anchored on word boundaries, so a
    column called ``name`` does not corrupt ``emp_fname``.
    """
    names = sorted({name for name in column_names if name}, key=len, reverse=True)
    if not names:
        return sql_text
    alternatives = "|".join(re.escape(name) for name in names)
    pattern = rf"(?<!\w)(?:{alternatives})(?!\w)"
    return re.sub(pattern, "[MASK]", sql_text, flags=re.IGNORECASE)


for _, row in df.iterrows():
    expect_schema = table_column_of_create_table(row['Table'])
    # Only evaluate questions whose every schema table has a description.
    if not known_tables.issuperset(expect_schema):
        continue

    question = row['Question']
    actual_sql = row['SQL']
    expect_table = _referenced_tables(actual_sql)

    full_prompt = short_learning_prompt
    db_id = None  # reset per row so a previous row's database is never reused
    columns_to_mask = []
    for table_info in dbs:
        table_name = table_info['table']
        if table_name not in expect_table:
            continue
        db_id = map_table_db[table_name]
        full_prompt += f"\ntable: {table_name} - {table_info['description']}\ncolumn: "
        for column in table_info['columns']:
            full_prompt += f"{column} - {table_info['columns'][column]}\n"
        columns_to_mask.extend(table_info['columns'])

    if db_id is None:
        # Without a matching described table there is no schema to show the
        # model and no database to run the queries against.
        print("SKIPPED (no described table matches the query):", actual_sql)
        print()
        continue

    # Hide the column names first, then the `*` wildcard.
    sql = _mask_columns(actual_sql, columns_to_mask).replace('*', "[MASK]")

    full_prompt += f"question: {question}\n"
    full_prompt += f"input: {sql}\noutput:"
    # Collapse a multi-line answer onto one line before storing and running it.
    response = llm_response(full_prompt).replace('\n', ' ')

    df_data['Question'].append(question)
    df_data['Actual SQL'].append(actual_sql)
    df_data['Predict MASK SQL'].append(response)
    df_data['Actual result'].append(query_db(actual_sql, db_id))
    df_data['Predict result'].append(query_db(response, db_id))

    print("DB:",db_id)
    print("Question:",question)
    print("SQL:",actual_sql)
    print("MASKED SQL:",sql)
    print("Response:",response)
    print()

DB: coffee_shop
Question: Find the number of members living in each address.
SQL: SELECT COUNT(*), address FROM member GROUP BY address
MASKED SQL: SELECT COUNT([MASK]), [MASK] FROM member GROUP BY [MASK]
Response: SELECT COUNT(*) as count, address FROM member GROUP BY address

DB: cinema
Question: Count the number of cinemas.
SQL: SELECT COUNT(*) FROM cinema
MASKED SQL: SELECT COUNT([MASK]) FROM cinema
Response: SELECT COUNT(*) FROM cinema

DB: hospital_1
Question: How many rooms does each block floor have?
SQL: SELECT COUNT(*), T1.blockfloor FROM block AS T1 JOIN room AS T2 ON T1.blockfloor = T2.blockfloor AND T1.blockcode = T2.blockcode GROUP BY T1.blockfloor
MASKED SQL: SELECT COUNT([MASK]), T1.[MASK] FROM block AS T1 JOIN room AS T2 ON T1.[MASK] = T2.[MASK] AND T1.[MASK] = T2.[MASK] GROUP BY T1.[MASK]
Response: SELECT COUNT(roomnumber), T1.blockfloor FROM block AS T1 JOIN room AS T2 ON T1.blockcode = T2.blockcode AND T1.blockfloor = T2.blockfloor GROUP BY T1.blockfloor

DB: hospit

In [104]:
# Turn the collected per-question lists into a table, one row per evaluated question.
result_df = pd.DataFrame(df_data)
# Persist the results as an Excel workbook (without the index column) for
# inspection outside the notebook.
result_df.to_excel("LLM-predict-MASK.xlsx", index=False)
result_df.head()

Unnamed: 0,Question,Actual SQL,Predict MASK SQL,Actual result,Predict result
0,Find the number of members living in each addr...,"SELECT COUNT(*), address FROM member GROUP BY ...","SELECT COUNT(*) as count, address FROM member ...","[(1, Bridgeport), (2, Cheshire), (3, Hartford)...","[(1, Bridgeport), (2, Cheshire), (3, Hartford)..."
1,Count the number of cinemas.,SELECT COUNT(*) FROM cinema,SELECT COUNT(*) FROM cinema,"[(10,)]","[(10,)]"
2,How many rooms does each block floor have?,"SELECT COUNT(*), T1.blockfloor FROM block AS T...","SELECT COUNT(roomnumber), T1.blockfloor FROM b...","[(9, 1), (9, 2), (9, 3), (9, 4)]","[(9, 1), (9, 2), (9, 3), (9, 4)]"
3,What procedures cost less than 5000 and have J...,SELECT name FROM procedures WHERE cost < 5000 ...,SELECT name FROM procedures WHERE cost < 5000 ...,"[(Folded Demiophtalmectomy,), (Follicular Demi...","[(Folded Demiophtalmectomy,), (Follicular Demi..."
4,What is the location with the most cinemas ope...,SELECT location FROM cinema WHERE openning_yea...,SELECT location FROM cinema WHERE openning_yea...,"[(County Tipperary,)]","[(County Tipperary,)]"
