In [1]:
%%capture --no-stderr
%pip install --upgrade --quiet langchain langchain-community langchain-openai faiss-cpu

In [2]:
import getpass
import os

# Ask for the OpenAI key only when it is not already exported in the environment.
# The prompt is labelled because this notebook asks for several keys in a row.
if not os.environ.get("OPENAI_API_KEY"):
    os.environ["OPENAI_API_KEY"] = getpass.getpass("OPENAI_API_KEY: ")

 ········


In [4]:
# Optional LangSmith tracing for this notebook. Comment out this cell to opt out; not required.
# Tracing is enabled even when LANGCHAIN_API_KEY is already set in the environment;
# only the key prompt is conditional.
os.environ["LANGCHAIN_TRACING_V2"] = "true"
if not os.environ.get("LANGCHAIN_API_KEY"):
    os.environ["LANGCHAIN_API_KEY"] = getpass.getpass("LANGCHAIN_API_KEY: ")

In [5]:
from pathlib import Path

from langchain_community.utilities import SQLDatabase

# Connecting to a missing SQLite file silently creates an empty database, which would only
# show up as an empty table list and a "no such table" error below. Fail fast instead.
CHINOOK_DB_PATH = Path("Chinook.db")
if not CHINOOK_DB_PATH.is_file():
    raise FileNotFoundError(
        f"{CHINOOK_DB_PATH.resolve()} not found; download the Chinook sample database "
        "and save it as Chinook.db in the current working directory."
    )

db = SQLDatabase.from_uri(f"sqlite:///{CHINOOK_DB_PATH}")
print(db.dialect)
print(db.get_usable_table_names())
# Sanity check: the last expression is displayed as the cell's output.
db.run("SELECT * FROM Artist LIMIT 10;")

sqlite
['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']


"[(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith'), (4, 'Alanis Morissette'), (5, 'Alice In Chains'), (6, 'Antônio Carlos Jobim'), (7, 'Apocalyptica'), (8, 'Audioslave'), (9, 'BackBeat'), (10, 'Billy Cobham')]"

In [7]:
pip install -qU langchain-groq

Note: you may need to restart the kernel to use updated packages.


In [6]:
import getpass
import os

from langchain_groq import ChatGroq

# Ask for the Groq key only when it is not already exported, like the OpenAI key cell,
# so re-running the notebook does not prompt again. The prompt is labelled for clarity.
if not os.environ.get("GROQ_API_KEY"):
    os.environ["GROQ_API_KEY"] = getpass.getpass("GROQ_API_KEY: ")

# Chat model shared by the SQL-generation and answer cells below.
llm = ChatGroq(model="llama3-8b-8192")

 ········


In [7]:
# Generate a SQL query with the LLM, execute it, and print the raw result.

from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
from langchain_core.runnables import RunnableMap, RunnablePassthrough
from operator import itemgetter
from langchain.chains import create_sql_query_chain

# Build the two tools: an LLM chain that writes SQL for `db`, and a tool that runs SQL on `db`.
write_query = create_sql_query_chain(llm, db)
execute_query = QuerySQLDataBaseTool(db=db)

# Helper function to extract the SQL statement from the raw LLM completion
def extract_sql_query(output):
    """Return the bare SQL statement contained in a query-writing chain completion.

    The completion may look like ``"SQLQuery: SELECT ..."``, may carry extra text
    before the marker, may omit the marker entirely, and may wrap the statement
    in a Markdown code fence.

    Args:
        output: Raw text produced by the query-writing chain.

    Returns:
        The SQL statement with surrounding whitespace and any code fence removed.
        If no ``SQLQuery:`` marker is present, the whole completion is treated as
        the query rather than returning an empty string (which would execute an
        empty statement).
    """
    import re

    # Split on the marker without a trailing space so "SQLQuery:\nSELECT ..." also works;
    # as before, keep only the text between the first and a possible second marker.
    parts = output.split("SQLQuery:")
    sql = parts[1] if len(parts) > 1 else output
    sql = sql.strip()

    # If the model wrapped the statement in ```sql ... ```, the fence would break execution.
    fenced = re.fullmatch(r"```(?:[A-Za-z]+[ \t]*\n)?\s*(.*?)\s*```", sql, flags=re.DOTALL)
    if fenced:
        sql = fenced.group(1)
    return sql

# Wire the steps together: the LLM writes SQL -> keep only the statement -> run it on the database.
chain = write_query.pipe(
    RunnableMap({"query": extract_sql_query}),  # the SQL tool expects {"query": <sql>}
    execute_query,
)

# Run the pipeline for one question and show the raw database result.
response = chain.invoke({"question": "Which country has the highest total sales?"})
print(response)


[('Brazil',), ('Germany',), ('Canada',), ('Norway',), ('Czech Republic',), ('Czech Republic',), ('Austria',), ('Belgium',), ('Denmark',), ('Brazil',), ('Brazil',), ('Brazil',), ('Brazil',), ('Canada',), ('Canada',), ('USA',), ('USA',), ('USA',), ('USA',), ('USA',), ('USA',), ('USA',), ('USA',), ('USA',), ('USA',), ('USA',), ('USA',), ('USA',), ('Canada',), ('Canada',), ('Canada',), ('Canada',), ('Canada',), ('Portugal',), ('Portugal',), ('Germany',), ('Germany',), ('Germany',), ('France',), ('France',), ('France',), ('France',), ('France',), ('Finland',), ('Hungary',), ('Ireland',), ('Italy',), ('Netherlands',), ('Poland',), ('Spain',), ('Sweden',), ('United Kingdom',), ('United Kingdom',), ('United Kingdom',), ('Australia',), ('Argentina',), ('Chile',), ('India',), ('India',)]


In [8]:
# Generate and execute a SQL query, then have the LLM phrase the result as a natural-language answer.

from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
from langchain.chains import create_sql_query_chain
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from operator import itemgetter

# Prompt that turns the question, the generated SQL and its raw result into an answer.
# Placeholders filled later: {question}, {query}, {result}.
answer_prompt = PromptTemplate.from_template(
    "Given the following user question, corresponding SQL query, and SQL result, answer the user question.\n"
    "\n"
    "Question: {question}\n"
    "SQL Query: {query}\n"
    "SQL Result: {result}\n"
    "Answer: "
)

# Rebuild both tools so this cell is self-contained: a SQL runner and an LLM-backed SQL writer.
execute_query = QuerySQLDataBaseTool(db=db)
write_query = create_sql_query_chain(llm, db)

# Helper function to extract the SQL statement from the chain output
def extract_sql_query(query_with_text):
    """Return the bare SQL statement contained in a query-writing chain completion.

    The completion may look like ``"SQLQuery: SELECT ..."``, may carry extra text
    before the marker, may omit the marker entirely, and may wrap the statement
    in a Markdown code fence.

    Args:
        query_with_text: Raw text produced by the query-writing chain.

    Returns:
        The SQL statement with surrounding whitespace and any code fence removed.
        If no ``SQLQuery:`` marker is present, the whole completion is treated as
        the query rather than returning an empty string (which would execute an
        empty statement).
    """
    import re

    # Split on the marker without a trailing space so "SQLQuery:\nSELECT ..." also works;
    # as before, keep only the text between the first and a possible second marker.
    parts = query_with_text.split("SQLQuery:")
    sql_query = parts[1] if len(parts) > 1 else query_with_text
    sql_query = sql_query.strip()

    # If the model wrapped the statement in ```sql ... ```, the fence would break execution.
    fenced = re.fullmatch(r"```(?:[A-Za-z]+[ \t]*\n)?\s*(.*?)\s*```", sql_query, flags=re.DOTALL)
    if fenced:
        sql_query = fenced.group(1)
    return sql_query

# The same question drives SQL generation and the answer prompt; define it once so the
# two can never drift apart.
question = "Find all albums for the artist 'AC/DC'"

# Step 1: Generate the SQL query from the user question and keep only the statement.
generated_query = write_query.invoke({"question": question})  # Full LLM completion
sql_query = extract_sql_query(generated_query)
if not sql_query:
    # Fail fast rather than executing an empty statement and asking the LLM to interpret it.
    raise ValueError(f"No SQL query found in the LLM output: {generated_query!r}")

# Step 2: Execute the extracted query against the database.
query_result = execute_query.invoke({"query": sql_query})

# Step 3: Collect the inputs for the answer prompt.
answer_data = {
    "question": question,
    "query": sql_query,
    "result": query_result,
}

# Step 4: Fill the answer template with the question, SQL and result.
formatted_prompt = answer_prompt.format(**answer_data)

# Step 5: Ask the LLM to phrase the result as an answer (returns an AIMessage).
llm_response = llm.invoke(formatted_prompt)

# Step 6: Turn the AIMessage content into a plain string.
response = StrOutputParser().invoke(llm_response.content)

# Output the final result
print(response)



The user question is: Find all albums for the artist 'AC/DC'.

The SQL query provided is a correct solution to this problem. It first finds the ArtistId of the artist 'AC/DC' from the Artist table and then finds all albums with that ArtistId from the Album table. The result is a list of all albums by AC/DC, which in this case is [(1, 'For Those About To Rock We Salute You'), (4, 'Let There Be Rock')].
