In [1]:
from langchain_community.utilities import SQLDatabase
from pyprojroot import here
import warnings
warnings.filterwarnings("ignore")

In [6]:
import os
from dotenv import load_dotenv

load_dotenv()

True

Connecting to the sqldb

In [2]:
db_path = r"data\sql\sqldb.db"
db = SQLDatabase.from_uri(f"sqlite:///{db_path}")
db

<langchain_community.utilities.sql_database.SQLDatabase at 0x1e9096c8ef0>

In [3]:
# validate the connection to the vectordb
print(db.dialect)
print(db.get_usable_table_names())
db.run("SELECT * FROM Artist LIMIT 10;")

sqlite
['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']


"[(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith'), (4, 'Alanis Morissette'), (5, 'Alice In Chains'), (6, 'Antônio Carlos Jobim'), (7, 'Apocalyptica'), (8, 'Audioslave'), (9, 'BackBeat'), (10, 'Billy Cobham')]"

In [7]:
from langchain_groq import ChatGroq
# llm=ChatGroq(model="llama3-70b-8192")
llm = ChatGroq(
    temperature = 0,
    groq_api_key = os.getenv("groq_api_key"),
    model_name = os.getenv("llama_model_name")
)

In [None]:
llm.invoke("hello how are you?")

AIMessage(content="Hello. I'm doing well, thanks for asking. I'm a large language model, so I don't have feelings or emotions like humans do, but I'm functioning properly and ready to help with any questions or tasks you may have. How about you? How's your day going so far?", additional_kwargs={}, response_metadata={'token_usage': {'completion_tokens': 61, 'prompt_tokens': 40, 'total_tokens': 101, 'completion_time': 0.221818182, 'prompt_time': 0.004503605, 'queue_time': 0.233389189, 'total_time': 0.226321787}, 'model_name': 'llama-3.3-70b-versatile', 'system_fingerprint': 'fp_e669a124b2', 'finish_reason': 'stop', 'logprobs': None}, id='run-a2806b69-def1-4ee5-9a0c-2e3b84312058-0', usage_metadata={'input_tokens': 40, 'output_tokens': 61, 'total_tokens': 101})

#### 1. SQL query chain

In [9]:
from langchain.chains import create_sql_query_chain

chain = create_sql_query_chain(llm, db)
response = chain.invoke({"question": "How many employees are there"})
print(response)

Question: How many employees are there
SQLQuery: SELECT COUNT("EmployeeId") FROM "Employee"


In [14]:
db.run("SELECT COUNT('EmployeeId') FROM 'Employee';")

'[(8,)]'

In [15]:
chain.get_prompts()[0].pretty_print()

You are a SQLite expert. Given an input question, first create a syntactically correct SQLite query to run, then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most 5 results using the LIMIT clause as per SQLite. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use date('now') function to get the current date, if the question involves "today".

Use the following format:

Question: Question here
SQLQuery: SQL Query to run
SQLResult: Result

#### Add QuerySQLDataBaseTool to the chain

In [17]:
from langchain_community.tools import QuerySQLDataBaseTool

write_query = create_sql_query_chain(llm, db)
execute_query = QuerySQLDataBaseTool(db=db)

chain = write_query | execute_query

chain.invoke({"question": "How many employees are there"})

'Error: (sqlite3.OperationalError) near "Question": syntax error\n[SQL: Question: How many employees are there\nSQLQuery: SELECT COUNT("EmployeeId") FROM "Employee"]\n(Background on this error at: https://sqlalche.me/e/20/e3q8)'

#### Answer the question in a user friendly manner

In [18]:
from operator import itemgetter
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough

answer_prompt = PromptTemplate.from_template(
    """Given the following user question, corresponding SQL query, and SQL result, answer the user question.

Question: {question}
SQL Query: {query}
SQL Result: {result}
Answer: """
)

answer = answer_prompt | llm | StrOutputParser()
chain = (
    RunnablePassthrough.assign(query=write_query).assign(
        result=itemgetter("query") | execute_query
    )
    | answer
)

chain.invoke({"question": "How many employees are there"})

'It seems like there was an error in the SQL query. The query itself is correct, but it appears that the query was not executed properly due to extra text being included.\n\nThe correct SQL query should be:\n```sql\nSELECT COUNT("EmployeeId") FROM "Employee"\n```\nHowever, since the SQL result is an error message and not the actual result of the query, we cannot determine the exact number of employees.\n\nTo answer the user\'s question, we would need to execute the correct SQL query and retrieve the result. If we assume that the query is executed correctly, the answer would be the count of rows in the "Employee" table, which would be a numerical value.\n\nFor example, if the query returns a result of 50, the answer would be:\n"There are 50 employees." \n\nHowever, without the actual result, we cannot provide a specific answer.'

#### 2. Agents

In [21]:
from langchain_community.agent_toolkits import create_sql_agent

agent_executor = create_sql_agent(llm, db=db, agent_type="openai-tools", 
                                  verbose=True,
                                  agent_executor_kwargs = {"return_intermediate_steps": True}
                                  )

In [23]:
response = agent_executor.invoke(
    {
        "input": "List the total sales per country. Which country's customers spent the most?"
    }
)



[1m> Entering new SQL Agent Executor chain...[0m
[32;1m[1;3m
Invoking: `sql_db_list_tables` with `{'tool_input': ''}`
responded:   Then I should double check my query before running it to make sure it is correct.  Then I should run the query to get the answer to the question.



[0m[38;5;200m[1;3mAlbum, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track[0m[32;1m[1;3m
Invoking: `sql_db_schema` with `{'table_names': 'customers, orders, products, order_items, order_item_products'}`
responded:   Then I should double check my query before running it to make sure it is correct.  Then I should run the query to get the answer to the question.



[0m[33;1m[1;3mError: table_names {'orders', 'products', 'customers', 'order_items', 'order_item_products'} not found in database[0m[32;1m[1;3m
Invoking: `sql_db_query_checker` with `{'query': 'SELECT country, SUM(total_amount) FROM customers JOIN orders ON customers.customer_id = orders.

In [24]:
queries = []
for (log, output) in response["intermediate_steps"]:
    if log.tool == 'sql_db_query':
        queries.append(log.tool_input)

In [35]:
print(f"Query\n{queries[-1]['query']}\nAnswer\n{response['output']}")

Query
SELECT BillingCountry, SUM(Total) FROM Invoice GROUP BY BillingCountry ORDER BY SUM(Total) DESC LIMIT 10;
Answer
The country with the highest total sales is the USA, with a total of $523.06.


In [31]:
response['output']

'The country with the highest total sales is the USA, with a total of $523.06.'