In [92]:
# Load variables from a local .env file into the process environment so that
# credentials do not have to be hard-coded in the notebook.
# NOTE(review): presumably this supplies the OpenAI API key used by ChatOpenAI below — confirm.
from dotenv import load_dotenv

env_loaded = load_dotenv()
env_loaded  # the boolean returned by load_dotenv is echoed as the cell output

True

In [93]:
# Connect to the SQLite sample database file `chinook.db` in the working directory.

from pathlib import Path

from langchain_community.utilities import SQLDatabase

DB_PATH = Path("chinook.db")

# sqlite creates a brand-new empty database when the file does not exist, which
# would leave the LLM with no tables to query — fail fast with a clear error instead.
if not DB_PATH.is_file():
    raise FileNotFoundError(f"SQLite database not found: {DB_PATH.resolve()}")

db = SQLDatabase.from_uri(f"sqlite:///{DB_PATH}")

print(db.dialect)
print(db.get_usable_table_names())
db.run("SELECT * FROM Artist LIMIT 10;")

sqlite
['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']


"[(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith'), (4, 'Alanis Morissette'), (5, 'Alice In Chains'), (6, 'AntÃ´nio Carlos Jobim'), (7, 'Apocalyptica'), (8, 'Audioslave'), (9, 'BackBeat'), (10, 'Billy Cobham')]"

In [94]:
# Build the pipeline as three chains:
# 1. Convert the question into a SQL query
# 2. Execute the SQL query
# 3. Use the result to generate an answer

# Chain #1 -> Convert the question into a SQL query

from langchain_openai import ChatOpenAI
from langchain.chains import create_sql_query_chain

# temperature=0 reduces sampling randomness so re-running the notebook tends to
# produce the same SQL for the same question.
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)

# Runnable that takes {"question": ...} and returns a SQL query string for `db`.
# NOTE: later cells in this notebook rebind the name `chain`.
chain = create_sql_query_chain(llm, db)

In [104]:
# Validate the generated SQL before running it: a second LLM call reviews the query
# for common mistakes and returns a (possibly rewritten) query.

import re

from langchain.chains import create_sql_query_chain
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate

system = """Double check the user's {dialect} query for common mistakes, including:
- Using NOT IN with NULL values
- Using UNION when UNION ALL should have been used
- Using BETWEEN for exclusive ranges
- Data type mismatch in predicates
- Properly quoting identifiers
- Using the correct number of arguments for functions
- Casting to the correct data type
- Using the proper columns for joins

If there are any of the above mistakes, rewrite the query. If there are no mistakes, just reproduce the original query.

Output the final SQL query only."""

prompt = ChatPromptTemplate.from_messages([("system", system), ("human", "{query}")]).partial(dialect=db.dialect)


def strip_sql_markdown(text: str) -> str:
    """Return the SQL inside a Markdown code fence, or ``text`` itself if there is none.

    Chat models often wrap SQL in a fenced block (optionally tagged, e.g. ```sql)
    even when asked for the query only; the backticks would make ``db.run`` fail.
    Surrounding whitespace is stripped in both cases.
    """
    fenced = re.search(r"```(?:[\w+-]*\n)?(.*?)```", text, flags=re.DOTALL)
    return (fenced.group(1) if fenced else text).strip()


# Reviewer chain: {"query": <sql>} (or a bare SQL string) -> cleaned SQL string.
validation_chain = prompt | llm | StrOutputParser() | strip_sql_markdown

# Build a fresh query writer rather than reading the global `chain`: later cells
# rebind `chain`, so re-running this cell after them would feed in the wrong chain.
full_chain = {"query": create_sql_query_chain(llm, db)} | validation_chain
query = full_chain.invoke({"question": "AC/DC"})
query

'SELECT "Title" FROM "Album" WHERE "ArtistId" = (SELECT "ArtistId" FROM "Artist" WHERE "Name" = \'AC/DC\') LIMIT 5;'

In [105]:
# Execute the validated SQL produced by the previous cell; db.run returns the rows as a string.
validated_query_rows = db.run(query)
validated_query_rows

"[('For Those About To Rock We Salute You',), ('Let There Be Rock',)]"

In [97]:
# Chain #2 -> Execute the SQL query.
# TODO: add a human-approval step before generated queries are executed.

from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool

# NOTE(review): newer langchain-community releases may deprecate this class name in
# favour of `QuerySQLDatabaseTool` — check against the installed version.
execute_query = QuerySQLDataBaseTool(db=db)
write_query = create_sql_query_chain(llm, db)

# Generate SQL, pass it through the validation chain defined earlier, then execute it.
# Bound to its own name (not `chain`) so the query-writing `chain` that earlier
# cells rely on is not clobbered when cells are re-run out of order.
sql_execution_chain = {"query": write_query} | validation_chain | execute_query
sql_execution_chain.invoke({"question": "what is the average invoice"})

# TODO: turn the raw SQL result into a natural-language answer (see the next cell).

'[(5.651941747572815,)]'

In [110]:
# Chain #3 -> Answer the original question in natural language.

from operator import itemgetter

from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough

answer_prompt = PromptTemplate.from_template(
    """Given the following user question, corresponding SQL query, and SQL result, answer the user question.
    
    Question: {question}
    SQL Query: {query}
    SQL Result: {result}
    Answer: """
)

# Validate the generated SQL *before* storing it under "query", then execute that
# same string. The prompt therefore shows the query that actually produced the
# result; previously it received the unvalidated query while execution used the
# validated one.
chain = (
    RunnablePassthrough.assign(query={"query": write_query} | validation_chain)
    .assign(result=itemgetter("query") | execute_query)
    | answer_prompt
    | llm
    | StrOutputParser()
)

chain.invoke({"question": "albums of AC/DC"})

'The albums of AC/DC are "For Those About To Rock We Salute You" and "Let There Be Rock".'