# SQLQUeryChain - Mondial - Foreign Key - GPT 3.5

In [1]:

from langchain.sql_database import SQLDatabase

from langchain.chat_models import ChatOpenAI
from langchain.chains import create_sql_query_chain
from urllib.parse import quote  
from langchain.callbacks import get_openai_callback
import time
from dotenv import load_dotenv
import os
import sys
import json
load_dotenv()

experiment_path = '../..'
path = os.path.abspath('')
module_path = os.path.join(path, experiment_path)
print(module_path)
if module_path not in sys.path:
    sys.path.append(module_path+"/functions")


from sqldatabase_langchain_utils import SQLDatabaseLangchainUtils


/Users/moon/Documents/text_to_sql_byLLM/experiments/SQLQueryChain/../..


# Schema

In [2]:
SCHEMA = 'mondial_gpt'
PREFIX = 'mondial'

FILE_NAME_RESULT = f"results/9_sql_queries_chatgpt_{SCHEMA}_fk.json"

In [3]:
def save_queries(queries):
    data = {"queries":queries}
    with open(FILE_NAME_RESULT, "w") as arquivo_json:
        json.dump(data, arquivo_json, indent=4) 

def read_queries():
    with open(FILE_NAME_RESULT, encoding='utf-8', errors='ignore') as json_data:
        data = json.load(json_data, strict=False)
    queries = data["queries"]
    return queries

## Connection

In [4]:
json_file_path = f"../../datasets/{SCHEMA}_db_connection.json"
with open(json_file_path, encoding='utf-8', errors='ignore') as json_data:
    db_connection = json.load(json_data, strict=False)

db_connection

{'DB_HOST': 'localhost',
 'DB_PORT': '1521',
 'DB_USER_NAME': 'MONDIAL',
 'DB_PASS': 'oraclee',
 'DB_NAME': 'XE',
 'SQL_DRIVER': 'oracle+oracledb',
 'SERVICE_NAME': 'XE',
 'SCHEMA': 'MONDIAL_GPT',
 'KEYWORD_SEARCH_API_URL': ''}

### Using SQLDatabase to get all the information from the database

In [5]:
db = SQLDatabaseLangchainUtils(db_connection=db_connection)

exclusao = [
    f"{SCHEMA}_tmdp",
    f"{SCHEMA}_tmdpmap",
    f"{SCHEMA}_tmds",
    f"{SCHEMA}_tmjmap",
    f"{SCHEMA}_tpv",
    f"{SCHEMA}_tmdc",
    f"{SCHEMA}_tmdcmap",
    f"{SCHEMA}_tmdej",
    f"{SCHEMA}_log_action",
    f"{SCHEMA}_log_error",
    f"{SCHEMA}_favorite_item", 
    f"{SCHEMA}_favorite_query",
    f"{SCHEMA}_favorite_tag",
    f"{SCHEMA}_favorite_tag_item",
    f"{SCHEMA}_favorite_visualization",
    f"{SCHEMA}_dashboard",
    f"{SCHEMA}_history",
    "teste_cliente",
    "teste_fornecedor",
    "teste_funcionario"
]

include_tables = [s for s in db.get_table_names() if not s.startswith(PREFIX) and s not in exclusao]
db = SQLDatabaseLangchainUtils(db_connection=db_connection, include_tables=include_tables)
db.get_table_names()

  self._metadata.reflect(
  self._metadata.reflect(
  self._metadata.reflect(


['airport',
 'borders',
 'city',
 'citylocalname',
 'cityothername',
 'citypops',
 'continent',
 'country',
 'countrylocalname',
 'countryothername',
 'countrypops',
 'desert',
 'economy',
 'encompasses',
 'ethnicgroup',
 'geo_desert',
 'geo_estuary',
 'geo_island',
 'geo_lake',
 'geo_mountain',
 'geo_river',
 'geo_sea',
 'geo_source',
 'island',
 'islandin',
 'ismember',
 'lake',
 'lakeonisland',
 'language',
 'located',
 'locatedon',
 'mergeswith',
 'mountain',
 'mountainonisland',
 'organization',
 'politics',
 'population',
 'province',
 'provincelocalname',
 'provinceothername',
 'provpops',
 'religion',
 'river',
 'riveronisland',
 'riverthrough',
 'sea',
 'spoken']

In [6]:
len(include_tables)

47

## Creating the prompt

In [7]:
from langchain.prompts.prompt import PromptTemplate

f = open(f"prompts/prompt_template_sql_query_chain.txt", "r")
prompt_template = f.read()
f.close()


PROMPT = PromptTemplate(
    input_variables=["input", "table_info", "top_k"], template=prompt_template
)

print(PROMPT)

input_variables=['input', 'table_info', 'top_k'] input_types={} partial_variables={} template='You are an Oracle SQL expert. Given an input question, first create a syntactically correct Oracle SQL query to run, then look at the results of the query and return the answer to the input question.\nUnless the user specifies in the question a specific number of examples to obtain, don\'t query for at {top_k} most results or any using the FETCH FIRST n ROWS ONLY clause as per Oracle SQL. You can order the results to return the most informative data in the database.\nNever query for all columns from a table. You must query only the columns that are needed to answer the question.\nPay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.\nPay attention to use TRUNC(SYSDATE) function to get the current date, if the question involves "today". \n\nSome hints:\n- Don\'t u

## Creating Chain to generate SQL

In [8]:
query_chain  = create_sql_query_chain(ChatOpenAI(temperature=0, model_name='gpt-3.5-turbo-16k'), db.db, prompt=PROMPT)


  query_chain  = create_sql_query_chain(ChatOpenAI(temperature=0, model_name='gpt-3.5-turbo-16k'), db.db, prompt=PROMPT)


### Test: query and Number of Tokens

In [9]:
with get_openai_callback() as cb:
    sql_query = query_chain.invoke({"question":"What are the provinces with an area more than 10000?"})
    
    print(cb.total_tokens)
    print(cb.prompt_tokens)
    print(cb.completion_tokens)
    print(cb.total_cost)
sql_query

4213
4193
20
0.012659


'SELECT name, country, area\nFROM province\nWHERE area > 10000;'

In [10]:
with get_openai_callback() as cb:
    sql_query = query_chain.invoke({"question":"What are the languages spoken in Poland?"})
    print(cb.total_tokens)
    print(cb.prompt_tokens)
    print(cb.completion_tokens)
    print(cb.total_cost)
sql_query

4221
4188
33
0.012695999999999999


"SELECT name\nFROM language\nWHERE name IN (\n    SELECT language\n    FROM spoken\n    WHERE country = 'POL'\n)\nORDER BY name;"

## Preparing natural language queries to run in LLM

In [11]:

json_file_path = f"../../datasets/{PREFIX}/queries_{PREFIX}.json"
with open(json_file_path, encoding='utf-8', errors='ignore') as json_data:
    queries = json.load(json_data, strict=False)
queries = queries['queries']
queries

[{'id': '1',
  'question': 'What is the area of Thailand?',
  'query_string': '',
  'type': 'simple'},
 {'id': '2',
  'question': 'What are the provinces with an area greater than 10000?',
  'query_string': '',
  'type': 'simple'},
 {'id': '3',
  'question': 'What are the languages spoken in Poland?',
  'query_string': '',
  'type': 'medium'},
 {'id': '4',
  'question': 'How deep is Lake Kariba?',
  'query_string': '',
  'type': 'simple'},
 {'id': '5',
  'question': 'What is the total of provinces of Netherlands?',
  'query_string': '',
  'type': 'complex'},
 {'id': '6',
  'question': 'What is the percentage of religious people are hindu in thailand?',
  'query_string': '',
  'type': 'complex'},
 {'id': '7',
  'question': 'List the number of provinces each river flows through.',
  'query_string': '',
  'type': 'medium'},
 {'id': '8',
  'question': 'Find all countries that became independent between 8/1/1910 and 8/1/1950.',
  'query_string': '',
  'type': 'complex'},
 {'id': '9',
  'que

# Running queries in LLM to generate SQL

In [12]:

# A cada X consultas, vai ser gerado um delay de 10s para evitar o bloqueio da API.
import warnings
from sqlalchemy import exc

# Suppress SAWarning
warnings.filterwarnings("ignore", category=exc.SAWarning, message=".*Cannot correctly sort tables.*")

# Your code here



number_of_queries_to_delay = 25
count = 0
for instance in queries:
    if count == number_of_queries_to_delay:
        count = 0
        time.sleep(10)
    with get_openai_callback() as cb:
        start_time = time.time()
        sql_query = query_chain.invoke({"question":instance["question"]})
        end_time = time.time()
        instance["query_string"] = sql_query
        instance['total_tokens'] = cb.total_tokens
        instance['prompt_tokens'] = cb.prompt_tokens
        instance['completion_tokens'] = cb.completion_tokens
        instance['total_cost'] = cb.total_cost
        instance['time'] = end_time - start_time
        print(instance['id'], instance['question'], instance["query_string"], instance['time'], instance['total_cost'])
    save_queries(queries)
    count += 1
queries

1 What is the area of Thailand? SELECT area FROM country WHERE name = 'Thailand'; 3.0832347869873047 0.012617000000000001
2 What are the provinces with an area greater than 10000? SELECT name, country, area
FROM province
WHERE area > 10000; 1.0262291431427002 0.012659
3 What are the languages spoken in Poland? SELECT name
FROM language
WHERE name IN (
    SELECT language
    FROM spoken
    WHERE country = 'POL'
)
ORDER BY name; 7.529034852981567 0.012695999999999999
4 How deep is Lake Kariba? SELECT depth FROM lake WHERE name = 'Lake Kariba'; 1.1658940315246582 0.012621
5 What is the total of provinces of Netherlands? SELECT COUNT(name) 
FROM province 
WHERE country = 'NL'; 0.6863560676574707 0.012635
6 What is the percentage of religious people are hindu in thailand? SELECT percentage 
FROM religion 
WHERE country = 'THA' AND name = 'Hindu' 0.5334818363189697 0.012678
7 List the number of provinces each river flows through. SELECT r.name AS river_name, COUNT(DISTINCT p.name) AS num_p

[{'id': '1',
  'question': 'What is the area of Thailand?',
  'query_string': "SELECT area FROM country WHERE name = 'Thailand';",
  'type': 'simple',
  'total_tokens': 4201,
  'prompt_tokens': 4187,
  'completion_tokens': 14,
  'total_cost': 0.012617000000000001,
  'time': 3.0832347869873047},
 {'id': '2',
  'question': 'What are the provinces with an area greater than 10000?',
  'query_string': 'SELECT name, country, area\nFROM province\nWHERE area > 10000;',
  'type': 'simple',
  'total_tokens': 4213,
  'prompt_tokens': 4193,
  'completion_tokens': 20,
  'total_cost': 0.012659,
  'time': 1.0262291431427002},
 {'id': '3',
  'question': 'What are the languages spoken in Poland?',
  'query_string': "SELECT name\nFROM language\nWHERE name IN (\n    SELECT language\n    FROM spoken\n    WHERE country = 'POL'\n)\nORDER BY name;",
  'type': 'medium',
  'total_tokens': 4221,
  'prompt_tokens': 4188,
  'completion_tokens': 33,
  'total_cost': 0.012695999999999999,
  'time': 7.529034852981567

In [13]:
queries=read_queries()
queries


[{'id': '1',
  'question': 'What is the area of Thailand?',
  'query_string': "SELECT area FROM country WHERE name = 'Thailand';",
  'type': 'simple',
  'total_tokens': 4201,
  'prompt_tokens': 4187,
  'completion_tokens': 14,
  'total_cost': 0.012617000000000001,
  'time': 3.0832347869873047},
 {'id': '2',
  'question': 'What are the provinces with an area greater than 10000?',
  'query_string': 'SELECT name, country, area\nFROM province\nWHERE area > 10000;',
  'type': 'simple',
  'total_tokens': 4213,
  'prompt_tokens': 4193,
  'completion_tokens': 20,
  'total_cost': 0.012659,
  'time': 1.0262291431427002},
 {'id': '3',
  'question': 'What are the languages spoken in Poland?',
  'query_string': "SELECT name\nFROM language\nWHERE name IN (\n    SELECT language\n    FROM spoken\n    WHERE country = 'POL'\n)\nORDER BY name;",
  'type': 'medium',
  'total_tokens': 4221,
  'prompt_tokens': 4188,
  'completion_tokens': 33,
  'total_cost': 0.012695999999999999,
  'time': 7.529034852981567

## Prompt Generated by Langchain

In [14]:
sql_query_chain_prompt = query_chain.middle[1].template.format(table_info=db.get_table_info(), top_k=0, input="{input}")
print(sql_query_chain_prompt)

You are an Oracle SQL expert. Given an input question, first create a syntactically correct Oracle SQL query to run, then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, don't query for at 0 most results or any using the FETCH FIRST n ROWS ONLY clause as per Oracle SQL. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use TRUNC(SYSDATE) function to get the current date, if the question involves "today". 

Some hints:
- Don't use double quotes in column name
- 
Example:
`SELECT "column_name" FROM table` should be `SELECT column_name FR

In [15]:

f = open(f"prompt_generated_by_langchain_{SCHEMA}_fk.txt", "w")
f.write(sql_query_chain_prompt)
f.close()


#### Fixing queries

In [None]:
to_fix = [40,59,62,72,85,99]
for pos in to_fix:
    instance = queries[pos]
    q = read_queries()
    with get_openai_callback() as cb:
            start_time = time.time()
            sql_query = query_chain.invoke({"question":instance["question"]})
            end_time = time.time()
            instance["query_string"] = sql_query
            instance['total_tokens'] = cb.total_tokens
            instance['prompt_tokens'] = cb.prompt_tokens
            instance['completion_tokens'] = cb.completion_tokens
            instance['total_cost'] = cb.total_cost
            instance['time'] = end_time - start_time
            q[pos] = instance
            print(instance['id'], instance['question'], instance["query_string"], instance['time'], instance['total_cost'])
            save_queries(q)

41 How many countries that are close to the Mediterranean Sea? SELECT COUNT(DISTINCT country) 
FROM borders
WHERE country1 IN (
    SELECT country 
    FROM sea 
    WHERE name = 'Mediterranean Sea'
) 
OR country2 IN (
    SELECT country 
    FROM sea 
    WHERE name = 'Mediterranean Sea'
) 0.9733130931854248 0.012832999999999999
60  List the names of capital cities which are the base for organizations in alphabetical order SELECT DISTINCT city.name
FROM city
JOIN organization ON city.name = organization.city
WHERE city.country = organization.country
ORDER BY city.name ASC; 1.0307528972625732 0.012711
63 Show the inflation rate of countries that are washed by the Arabian Sea SELECT c.name, e.inflation
FROM country c
JOIN economy e ON c.code = e.country
JOIN located l ON c.code = l.country
WHERE l.sea = 'Arabian Sea'
ORDER BY e.inflation DESC; 0.705967903137207 0.012794
73 What area is the largest continent? SELECT name, area FROM continent ORDER BY area DESC; 0.5486047267913818 0.01261