Load the data into your SQL database before running this notebook. A script for the sample dataset (Chinook) is available at: https://github.com/lerocha/chinook-database/releases


In [1]:
import os


from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_openai import ChatOpenAI



# Read the OpenAI key from the environment. `OPEN_API_KEY` is the name this
# notebook has always used and is still checked first; `OPENAI_API_KEY` is
# accepted as a fallback so a conventionally configured environment also works.
api_key = os.environ.get('OPEN_API_KEY') or os.environ.get('OPENAI_API_KEY')
if not api_key:
    # Same exception type as the original lookup, but with an actionable message.
    raise KeyError(
        "Set the OPEN_API_KEY (or OPENAI_API_KEY) environment variable before running this notebook"
    )
# Chat model shared by every LLM call made later in the notebook.
llm = ChatOpenAI(api_key=api_key)

## DB Connection

In [2]:
from langchain_community.utilities import SQLDatabase


# if you are using MySQL
# Connection settings are taken from the environment when present, so real
# credentials do not have to be written into the notebook. The fallbacks are
# the original demo values, so an unconfigured environment behaves as before.
user=os.environ.get('MYSQL_USER','root')
password=os.environ.get('MYSQL_PASSWORD','1234')
database=os.environ.get('MYSQL_DATABASE','chinook')
port=os.environ.get('MYSQL_PORT','3306')

def return_db(user,password,database,port,host='localhost'):
    """Build a ``mysql+mysqlconnector`` connection URI.

    Args:
        user: Database user name.
        password: Database password, given in plain (not pre-encoded) form.
        database: Name of the database/schema to connect to.
        port: Port of the MySQL server (string or int).
        host: Host name of the MySQL server. Defaults to ``'localhost'``,
            which is what this function always used before.

    Returns:
        The connection URI as a string.
    """
    from urllib.parse import quote

    # Percent-encode the credentials: characters such as '@', ':' or '/' in a
    # password would otherwise be read as URL delimiters and corrupt the URI.
    # safe='' makes '/' encoded too; plain alphanumerics are left unchanged.
    safe_user = quote(str(user), safe='')
    safe_password = quote(str(password), safe='')
    mysql_uri = f'mysql+mysqlconnector://{safe_user}:{safe_password}@{host}:{port}/{database}'
    return mysql_uri

In [3]:
# Returns schema
def get_schema(db):
    """Return whatever ``db.get_table_info()`` reports for the given database object."""
    return db.get_table_info()

In [4]:
def ret_query_prompt(schema,question):
    """Compose the prompt that asks the model to write a SQL query.

    The prompt consists of a fixed instruction line, the schema text, a blank
    line, the user's question, and a trailing ``SQL Query:`` cue.
    """
    instruction = "Based on the table schema below, write a SQL query that would answer the user's question:"
    return f"{instruction}\n{schema}\n\nQuestion: {question}\nSQL Query:"

In [5]:
def ret_final_prompt(question,query,query_result):
    """Compose the prompt that asks the model for a natural-language answer.

    The prompt lists, one per line after a fixed instruction, the original
    question, the SQL query that was run, and the response that query produced.
    """
    prompt_lines = [
        "Based on the question, sql query, and sql response, write a natural language response:",
        f"Question: {question}",
        f"SQL Query: {query}",
        f"SQL Response: {query_result}",
    ]
    return "\n".join(prompt_lines)

In [6]:
def run_query(db,query):
    """Execute ``query`` through ``db.run`` and hand back its result unchanged."""
    result = db.run(query)
    return result

In [7]:
def _extract_sql(llm_text):
    """Return the SQL inside the first Markdown code fence of ``llm_text``.

    If the text contains no fenced block, the whole text is returned with
    surrounding whitespace stripped.
    """
    import re

    # Optional language tag (e.g. "sql") is only consumed when followed by a
    # newline, so a one-line fence such as ```SELECT 1``` keeps its first word.
    fenced = re.search(r"```(?:[a-zA-Z]*[ \t]*\n)?(.*?)```", llm_text, re.DOTALL)
    return (fenced.group(1) if fenced else llm_text).strip()


def execute_query(db,question):
    """Answer a natural-language question about the database.

    Steps: read the schema from ``db``, ask the model for a SQL query, run
    that query, then ask the model to phrase the result as a sentence. Each
    intermediate value is printed as it is produced.

    Args:
        db: Database object providing ``get_table_info()`` and ``run()``.
        question: The user's question as plain text.

    Returns:
        The model's final natural-language response, stripped of surrounding
        whitespace.
    """
    schema=get_schema(db)
    query_prompt=ret_query_prompt(schema,question)
    # Models frequently wrap their answer in ```sql ... ``` fences; passing the
    # fences on to the database would fail, so keep only the SQL itself.
    generated_query=_extract_sql(llm.invoke(query_prompt).content)
    print('Query Generated:\n',generated_query)
    # NOTE(review): the generated SQL is executed as-is with the connection's
    # privileges; consider a read-only database user for anything beyond a demo.
    query_result=run_query(db,generated_query)
    print('\nQuery Result:\n',query_result)
    final_prompt=ret_final_prompt(question,generated_query,query_result)
    final_response=llm.invoke(final_prompt).content.strip()
    print('\nFinal Response:\n',final_response)
    return final_response


### Connect to DB

In [8]:

# Build the connection URI from the settings defined above, then create the
# database object that is passed to execute_query() later in the notebook.
mysql_uri=return_db(user,password,database,port)
db = SQLDatabase.from_uri(mysql_uri)


### Testing

In [9]:
# Sample natural-language questions to try against the database;
# one of them is selected by index when calling execute_query().
questions=[
    'how many albums are there in the database',
    'Which Customer has the highest total bill?',
    'Which Genre had the most sales?',
    'Which artist had the most sales?'
]

In [10]:
# Run the whole pipeline on the sample question at index 3
# ('Which artist had the most sales?').
res=execute_query(db,questions[3])
# execute_query() prints its intermediate steps itself; print the returned
# answer once more beneath a separator line.
print('----------------------------\n',res)

Query Generated:
 SELECT a.Name AS Artist, SUM(il.UnitPrice) AS TotalSales
FROM artist a
JOIN album al ON a.ArtistId = al.ArtistId
JOIN track t ON al.AlbumId = t.AlbumId
JOIN invoiceline il ON t.TrackId = il.TrackId
GROUP BY a.Name
ORDER BY TotalSales DESC
LIMIT 1;

Query Result:
 [('Iron Maiden', Decimal('138.60'))]

Final Response:
 The artist with the most sales is Iron Maiden, with a total sales amount of $138.60.
----------------------------
 The artist with the most sales is Iron Maiden, with a total sales amount of $138.60.
