# SQLQUeryChain - Mondial - Foreign Key - GPT 3.5

In [49]:

from langchain.sql_database import SQLDatabase

from langchain.chat_models import ChatOpenAI
from langchain.chains import create_sql_query_chain
from urllib.parse import quote  
from langchain.callbacks import get_openai_callback
import time
from dotenv import load_dotenv
import os
import sys
import json
load_dotenv()

experiment_path = '../..'
path = os.path.abspath('')
module_path = os.path.join(path, experiment_path)
print(module_path)
if module_path not in sys.path:
    sys.path.append(module_path+"/functions")


from sqldatabase_langchain_utils import SQLDatabaseLangchainUtils


/Users/moon/Documents/text_to_sql_byLLM/experiments/SQLQueryChain/../..


# Schema

In [50]:
# SCHEMA = 'mondial_gpt'
# PREFIX = 'mondial'

# FILE_NAME_RESULT = f"results/9_sql_queries_chatgpt_{SCHEMA}_fk.json"

SCHEMA = 'shipment'
PREFIX = 'shipment'
FILE_NAME_RESULT = f"results/9_sql_queries_gpt3.5_shipment.json"

In [51]:
def save_queries(queries):
    data = {"queries":queries}
    with open(FILE_NAME_RESULT, "w") as arquivo_json:
        json.dump(data, arquivo_json, indent=4) 

def read_queries():
    with open(FILE_NAME_RESULT, encoding='utf-8', errors='ignore') as json_data:
        data = json.load(json_data, strict=False)
    queries = data["queries"]
    return queries

## Connection

In [52]:
json_file_path = f"../../datasets/{SCHEMA}_db_connection.json"
with open(json_file_path, encoding='utf-8', errors='ignore') as json_data:
    db_connection = json.load(json_data, strict=False)

db_connection

{'DB_USER_NAME': 'SHIPMENT',
 'DB_PASS': 'oraclee',
 'DB_HOST': 'localhost',
 'DB_PORT': 1521,
 'DB_NAME': 'XE',
 'SQL_DRIVER': 'oracle+oracledb',
 'SERVICE_NAME': 'XE',
 'SCHEMA': 'SHIPMENT',
 'KEYWORD_SEARCH_API_URL': ''}

### Using SQLDatabase to get all the information from the database

In [53]:
db = SQLDatabaseLangchainUtils(db_connection=db_connection)

# exclusao = [
#     f"{SCHEMA}_tmdp",
#     f"{SCHEMA}_tmdpmap",
#     f"{SCHEMA}_tmds",
#     f"{SCHEMA}_tmjmap",
#     f"{SCHEMA}_tpv",
#     f"{SCHEMA}_tmdc",
#     f"{SCHEMA}_tmdcmap",
#     f"{SCHEMA}_tmdej",
#     f"{SCHEMA}_log_action",
#     f"{SCHEMA}_log_error",
#     f"{SCHEMA}_favorite_item", 
#     f"{SCHEMA}_favorite_query",
#     f"{SCHEMA}_favorite_tag",
#     f"{SCHEMA}_favorite_tag_item",
#     f"{SCHEMA}_favorite_visualization",
#     f"{SCHEMA}_dashboard",
#     f"{SCHEMA}_history",
#     "teste_cliente",
#     "teste_fornecedor",
#     "teste_funcionario"
# ]

# include_tables = [s for s in db.get_table_names() if not s.startswith(PREFIX) and s not in exclusao]
include_tables = db.get_table_names()
db = SQLDatabaseLangchainUtils(db_connection=db_connection, include_tables=include_tables)
db.get_table_names()

['customer', 'product', 'shipment', 'shippedproduct']

In [54]:
len(include_tables)

4

## Creating the prompt

In [55]:
from langchain.prompts.prompt import PromptTemplate

f = open(f"prompts/prompt_template_sql_query_chain.txt", "r")
prompt_template = f.read()
f.close()


PROMPT = PromptTemplate(
    input_variables=["input", "table_info", "top_k"], template=prompt_template
)

print(PROMPT)

input_variables=['input', 'table_info', 'top_k'] input_types={} partial_variables={} template='You are an Oracle SQL expert. Given an input question, first create a syntactically correct Oracle SQL query to run, then look at the results of the query and return the answer to the input question.\nUnless the user specifies in the question a specific number of examples to obtain, don\'t query for at {top_k} most results or any using the FETCH FIRST n ROWS ONLY clause as per Oracle SQL. You can order the results to return the most informative data in the database.\nNever query for all columns from a table. You must query only the columns that are needed to answer the question.\nPay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.\nPay attention to use TRUNC(SYSDATE) function to get the current date, if the question involves "today". \n\nSome hints:\n- Don\'t u

## Creating Chain to generate SQL

In [56]:
query_chain  = create_sql_query_chain(ChatOpenAI(temperature=0, model_name='gpt-3.5-turbo-16k'), db.db, prompt=PROMPT)


### Test: query and Number of Tokens

In [57]:
# with get_openai_callback() as cb:
#     sql_query = query_chain.invoke({"question":"What are the provinces with an area more than 10000?"})
    
#     print(cb.total_tokens)
#     print(cb.prompt_tokens)
#     print(cb.completion_tokens)
#     print(cb.total_cost)
# sql_query

In [58]:
# with get_openai_callback() as cb:
#     sql_query = query_chain.invoke({"question":"What are the languages spoken in Poland?"})
#     print(cb.total_tokens)
#     print(cb.prompt_tokens)
#     print(cb.completion_tokens)
#     print(cb.total_cost)
# sql_query

## Preparing natural language queries to run in LLM

In [59]:

json_file_path = f"../../datasets/{PREFIX}/queries_{PREFIX}.json"
with open(json_file_path, encoding='utf-8', errors='ignore') as json_data:
    queries = json.load(json_data, strict=False)
queries = queries['queries']
queries

[{'id': 1,
  'question': "Return the customer name and address as one field called fullAddress that consists of the address, city, state. Only show customers in 'IA' or 'BC' that have a valid address. Order by customer name ascending.",
  'query_string': '',
  'type': 'medium'},
 {'id': 2,
  'question': 'Return the shipment id, shipment date, product id, and amount for all shipments in 2022 where there was a product shipped with an amount greater than 8. Only show a shipment once and order by shipment date descending and amount descending.',
  'query_string': '',
  'type': 'medium'},
 {'id': 3,
  'question': "Find all customers that have an 'R' or 'T' in their name and have a state of 'MI' or 'IA'. Order by customer name descending.",
  'query_string': '',
  'type': 'simple'},
 {'id': 4,
  'question': 'Return a list of the unique product id and names that have shipped before with an amount less than 5 and order by product id descending.',
  'query_string': '',
  'type': 'simple'},
 {'i

# Running queries in LLM to generate SQL

In [60]:

# A cada X consultas, vai ser gerado um delay de 10s para evitar o bloqueio da API.
# import warnings
# from sqlalchemy import exc

# # Suppress SAWarning
# warnings.filterwarnings("ignore", category=exc.SAWarning, message=".*Cannot correctly sort tables.*")

# Your code here



# number_of_queries_to_delay = 25

count = 0

for instance in queries:
    with get_openai_callback() as cb:
        start_time = time.time()
        sql_query = query_chain.invoke({"question":instance["question"]})
        # this uses PROMPT template by filling it with input, table_info, top_k (possibly 0), sends filled prompt to GPT 4 and gets back SQL query
        end_time = time.time()
        instance["query_string"] = sql_query.replace('\n', ' ').strip()
        instance['total_tokens'] = cb.total_tokens
        instance['prompt_tokens'] = cb.prompt_tokens
        instance['completion_tokens'] = cb.completion_tokens
        instance['total_cost'] = cb.total_cost
        instance['time'] = end_time - start_time
        print(instance['id'], instance['question'], sql_query, instance['time'], instance['total_cost'])
    save_queries(queries)
    time.sleep(2)
queries

1 Return the customer name and address as one field called fullAddress that consists of the address, city, state. Only show customers in 'IA' or 'BC' that have a valid address. Order by customer name ascending. SELECT cname || ', ' || address || ', ' || city || ', ' || state AS fullAddress
FROM customer
WHERE state IN ('IA', 'BC') AND address IS NOT NULL
ORDER BY cname ASC; 0.7979540824890137 0.001898
2 Return the shipment id, shipment date, product id, and amount for all shipments in 2022 where there was a product shipped with an amount greater than 8. Only show a shipment once and order by shipment date descending and amount descending. SELECT s.sid, s.shipdate, sp.pid, sp.amount
FROM shipment s
JOIN shippedproduct sp ON s.sid = sp.sid
WHERE s.shipdate >= TO_DATE('2022-01-01', 'YYYY-MM-DD')
AND s.shipdate < TO_DATE('2023-01-01', 'YYYY-MM-DD')
AND sp.amount > 8
ORDER BY s.shipdate DESC, sp.amount DESC; 2.002218008041382 0.002094
3 Find all customers that have an 'R' or 'T' in their na

[{'id': 1,
  'question': "Return the customer name and address as one field called fullAddress that consists of the address, city, state. Only show customers in 'IA' or 'BC' that have a valid address. Order by customer name ascending.",
  'query_string': "SELECT cname || ', ' || address || ', ' || city || ', ' || state AS fullAddress FROM customer WHERE state IN ('IA', 'BC') AND address IS NOT NULL ORDER BY cname ASC;",
  'type': 'medium',
  'total_tokens': 618,
  'prompt_tokens': 574,
  'completion_tokens': 44,
  'total_cost': 0.001898,
  'time': 0.7979540824890137},
 {'id': 2,
  'question': 'Return the shipment id, shipment date, product id, and amount for all shipments in 2022 where there was a product shipped with an amount greater than 8. Only show a shipment once and order by shipment date descending and amount descending.',
  'query_string': "SELECT s.sid, s.shipdate, sp.pid, sp.amount FROM shipment s JOIN shippedproduct sp ON s.sid = sp.sid WHERE s.shipdate >= TO_DATE('2022-01-

In [61]:
queries=read_queries()
queries


[{'id': 1,
  'question': "Return the customer name and address as one field called fullAddress that consists of the address, city, state. Only show customers in 'IA' or 'BC' that have a valid address. Order by customer name ascending.",
  'query_string': "SELECT cname || ', ' || address || ', ' || city || ', ' || state AS fullAddress FROM customer WHERE state IN ('IA', 'BC') AND address IS NOT NULL ORDER BY cname ASC;",
  'type': 'medium',
  'total_tokens': 618,
  'prompt_tokens': 574,
  'completion_tokens': 44,
  'total_cost': 0.001898,
  'time': 0.7979540824890137},
 {'id': 2,
  'question': 'Return the shipment id, shipment date, product id, and amount for all shipments in 2022 where there was a product shipped with an amount greater than 8. Only show a shipment once and order by shipment date descending and amount descending.',
  'query_string': "SELECT s.sid, s.shipdate, sp.pid, sp.amount FROM shipment s JOIN shippedproduct sp ON s.sid = sp.sid WHERE s.shipdate >= TO_DATE('2022-01-

## Prompt Generated by Langchain

In [62]:
# sql_query_chain_prompt = query_chain.middle[1].template.format(table_info=db.get_table_info(), top_k=0, input="{input}")
sql_query_chain_prompt = PROMPT.format(
    table_info=db.get_table_info(), 
    top_k=0, 
    input="{input}"
)

print(sql_query_chain_prompt)

You are an Oracle SQL expert. Given an input question, first create a syntactically correct Oracle SQL query to run, then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, don't query for at 0 most results or any using the FETCH FIRST n ROWS ONLY clause as per Oracle SQL. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use TRUNC(SYSDATE) function to get the current date, if the question involves "today". 

Some hints:
- Don't use double quotes in column name
- 
Example:
`SELECT "column_name" FROM table` should be `SELECT column_name FR

In [63]:

f = open(f"prompt_generated_by_langchain_{SCHEMA}_fk.txt", "w")
f.write(sql_query_chain_prompt)
f.close()


#### Fixing queries

In [64]:
to_fix = [40,59,62,72,85,99]
for pos in to_fix:
    instance = queries[pos]
    q = read_queries()
    with get_openai_callback() as cb:
            start_time = time.time()
            sql_query = query_chain.invoke({"question":instance["question"]})
            end_time = time.time()
            instance["query_string"] = sql_query
            instance['total_tokens'] = cb.total_tokens
            instance['prompt_tokens'] = cb.prompt_tokens
            instance['completion_tokens'] = cb.completion_tokens
            instance['total_cost'] = cb.total_cost
            instance['time'] = end_time - start_time
            q[pos] = instance
            print(instance['id'], instance['question'], instance["query_string"], instance['time'], instance['total_cost'])
            save_queries(q)

IndexError: list index out of range