https://python.langchain.com/docs/use_cases/sql/quickstart/

### **Test the sqldb**

In [2]:
from langchain_community.utilities import SQLDatabase
from pyprojroot import here
import warnings
warnings.filterwarnings("ignore")

**Connecting to the sqldb**

In [3]:
db_path = str(here("data") / "sqldb.db")
db = SQLDatabase.from_uri(f"sqlite:///{db_path}")

In [4]:
db

<langchain_community.utilities.sql_database.SQLDatabase at 0x2c321af1790>

In [5]:
# Validate the connection to the SQL database
print(db.dialect)
print(db.get_usable_table_names())
db.run("SELECT * FROM Artist LIMIT 10;")

sqlite
['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']


"[(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith'), (4, 'Alanis Morissette'), (5, 'Alice In Chains'), (6, 'Antônio Carlos Jobim'), (7, 'Apocalyptica'), (8, 'Audioslave'), (9, 'BackBeat'), (10, 'Billy Cobham')]"

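For reference, the two calls above roughly correspond to what you could do directly with Python's built-in `sqlite3` module: read the table names from `sqlite_master` and stringify the fetched rows. This is a minimal sketch against a hypothetical in-memory database with a three-row `Artist` table. It is not LangChain's implementation, which goes through SQLAlchemy and may also truncate long values.

```python
import sqlite3

# Hypothetical in-memory stand-in for data/sqldb.db
conn = sqlite3.connect(":memory:")
conn.executescript("""
CREATE TABLE "Artist" ("ArtistId" INTEGER PRIMARY KEY, "Name" NVARCHAR(120));
INSERT INTO "Artist" VALUES (1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith');
""")

def usable_table_names(conn: sqlite3.Connection) -> list:
    """List user tables in name order, skipping SQLite's internal sqlite_* tables."""
    rows = conn.execute("SELECT name FROM sqlite_master WHERE type = 'table' ORDER BY name")
    return [name for (name,) in rows if not name.startswith("sqlite_")]

def run(conn: sqlite3.Connection, sql: str) -> str:
    """Execute a query and return its rows as a string, similar in spirit to db.run()."""
    return str(conn.execute(sql).fetchall())

print(usable_table_names(conn))  # ['Artist']
print(run(conn, 'SELECT * FROM "Artist" LIMIT 2;'))  # [(1, 'AC/DC'), (2, 'Accept')]
```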
### **Test the access to the environment variables**

In [6]:
from dotenv import load_dotenv
import os
print("Environment variables are loaded:", load_dotenv())
#print("test by reading a variable:", os.getenv("OPENAI_API_TYPE"))

Environment variables are loaded: True
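`load_dotenv()` returning `True` does not guarantee that the particular key you need is set. A small guard like the one below fails early with a clear message and lets you confirm the key without printing it in full. `require_env` and `mask` are hypothetical helpers, not part of python-dotenv.

```python
import os

def require_env(name: str) -> str:
    """Return an environment variable's value, raising a clear error if it is missing or empty."""
    value = os.getenv(name)
    if not value:
        raise RuntimeError(f"{name} is not set; check your .env file")
    return value

def mask(secret: str, visible: int = 4) -> str:
    """Hide all but the last `visible` characters of a secret (short secrets are fully hidden)."""
    if len(secret) <= visible:
        return "*" * len(secret)
    return "*" * (len(secret) - visible) + secret[-visible:]

# Example, after load_dotenv():
# print("OPENAI_API_KEY:", mask(require_env("OPENAI_API_KEY")))
```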


### **Test your GPT model**

In [7]:
from openai import OpenAI

messages = [
    {"role": "system", "content": "You are a helpful assistant"},
    {"role": "user", "content": "hello"},
]

# The commented-out arguments are AzureOpenAI settings; the standard OpenAI client does not accept them
client = OpenAI(
    #api_version=os.getenv("OPENAI_API_VERSION"),
    #azure_endpoint=os.getenv("OPENAI_API_BASE"),
    api_key=os.getenv("OPENAI_API_KEY"),
)

response = client.chat.completions.create(
    model="gpt-4o-mini",
    messages=messages
)

print(response.choices[0].message.content)

Hello! How can I assist you today?


### **1. SQL query chain**

In [8]:
# Load the LLM
from langchain_openai import ChatOpenAI

model_name = "gpt-4o-mini"
# The commented-out lines are left over from an Azure OpenAI setup and are not needed for ChatOpenAI
#azure_openai_api_key = os.environ["OPENAI_API_KEY"]
#azure_openai_endpoint = os.environ["OPENAI_API_BASE"]
llm = ChatOpenAI(
    #openai_api_version=os.getenv("OPENAI_API_VERSION"),
    #azure_deployment=model_name,
    model_name=model_name,
    temperature=0.0,
)

In [9]:
from langchain.chains import create_sql_query_chain

chain = create_sql_query_chain(llm, db)
response = chain.invoke({"question": "How many employees are there"})
print(response)


SQLQuery: SELECT COUNT("EmployeeId") AS "EmployeeCount" FROM "Employee"


In [9]:
sql_query = response.split("SQLQuery: ")[-1].strip()
print(sql_query)

SELECT COUNT("EmployeeId") AS "EmployeeCount" FROM "Employee"
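Splitting on `"SQLQuery: "` works for the output above, but the model's formatting is not guaranteed: the same question comes back later in this notebook with a trailing semicolon, and models sometimes wrap SQL in a markdown code fence or append an `SQLResult:` line. A slightly more defensive extractor might look like this. `extract_sql` is a hypothetical helper, not a LangChain function.

```python
import re

def extract_sql(response: str) -> str:
    """Extract a single SQL statement from an LLM reply (hypothetical helper)."""
    text = response.strip()
    # If the model wrapped the query in a code fence (three backticks, optional language tag), keep its body
    fence = re.search(r"`{3}[A-Za-z]*\s*\n(.*?)`{3}", text, flags=re.DOTALL)
    if fence:
        text = fence.group(1).strip()
    # Keep only what follows the last "SQLQuery:" label, if there is one
    if "SQLQuery:" in text:
        text = text.split("SQLQuery:")[-1].strip()
    # Drop a trailing "SQLResult:" section the model may have invented
    text = text.split("SQLResult:")[0].strip()
    # Normalise the optional trailing semicolon
    return text.rstrip(";").strip()

print(extract_sql('SQLQuery: SELECT COUNT("EmployeeId") AS "EmployeeCount" FROM "Employee";'))
# SELECT COUNT("EmployeeId") AS "EmployeeCount" FROM "Employee"
```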


Execute the query to make sure it’s valid

In [10]:
db.run(sql_query)

'[(8,)]'
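Executing a generated query is a reasonable validity check, but it runs the full query, and any statement the model produces, including one that modifies data, would actually take effect. With SQLite you can instead ask the engine to compile the statement without running it by prefixing `EXPLAIN`: unknown tables and unquoted unknown columns then raise an error. This is a sketch with the standard `sqlite3` module and a hypothetical toy `Employee` table; against the notebook's database you would open `sqlite3.connect(db_path)` instead. One caveat that matters here: the chain's prompt asks for double-quoted column names, and SQLite may treat a double-quoted name that matches no column as a string literal (a legacy compatibility behavior), so a misspelled quoted column can slip through this check.

```python
import sqlite3

def check_query(conn: sqlite3.Connection, sql: str):
    """Return None if SQLite can compile `sql`, otherwise the error message.

    EXPLAIN makes SQLite prepare the statement (parse it and resolve tables and
    columns) and return its bytecode listing instead of executing it.
    """
    try:
        conn.execute("EXPLAIN " + sql)
    except sqlite3.Error as exc:
        return str(exc)
    return None

# Hypothetical toy schema for illustration
conn = sqlite3.connect(":memory:")
conn.execute('CREATE TABLE "Employee" ("EmployeeId" INTEGER PRIMARY KEY, "LastName" TEXT)')

print(check_query(conn, 'SELECT COUNT("EmployeeId") AS "EmployeeCount" FROM "Employee"'))  # None
print(check_query(conn, 'SELECT Salary FROM "Employee"'))  # no such column: Salary
```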

In [11]:
chain.get_prompts()[0].pretty_print()

You are a SQLite expert. Given an input question, first create a syntactically correct SQLite query to run, then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most 5 results using the LIMIT clause as per SQLite. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use date('now') function to get the current date, if the question involves "today".

Use the following format:

Question: Question here
SQLQuery: SQL Query to run
SQLResult: Result

### **Answer the question in a user-friendly manner**

In [16]:
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool

# Create the SQL query generation chain
write_query = create_sql_query_chain(llm, db)

# Create the SQL execution tool
execute_query = QuerySQLDataBaseTool(db=db)

def extract_sql_query(response):
    """
    Extracts the SQL query from the response.
    Assumes the response includes the SQL query preceded by 'SQLQuery: '.
    """
    # Adjust the split pattern to match the response format
    sql_query = response.split("SQLQuery: ")[-1].strip()
    return sql_query

def get_final_answer(question):
    """
    Generate an SQL query based on the question, execute it, and return the final answer.
    """
    # Step 1: Generate the SQL query
    response = write_query.invoke({"question": question})
    
    # Step 2: Extract the SQL query from the response
    sql_query = extract_sql_query(response)
    print("Extracted SQL Query:", sql_query)  # For debugging purposes
    
    # Step 3: Execute the extracted SQL query.
    # QuerySQLDataBaseTool accepts the SQL string under the "query" key,
    # so no wrapper function is needed; it returns the result as a string.
    final_response = execute_query.invoke({"query": sql_query})
    
    # Return the final answer from the execution result
    return final_response

# Example usage
question = "How many employees are there?"
final_answer = get_final_answer(question)
print("Final Answer:", final_answer)


Extracted SQL Query: SELECT COUNT("EmployeeId") AS "EmployeeCount" FROM "Employee";
Final Answer: [(8,)]
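The tool hands back the result as a string (`[(8,)]` above), which is fine for feeding into an LLM but awkward if you want the number itself. For simple results like this one, `ast.literal_eval` can turn the string back into Python rows. This is a sketch: `parse_result` is a hypothetical helper, and it assumes an empty string means the query returned no rows. Error messages returned by the tool, or values whose repr is not a plain literal (for example `datetime` or `Decimal` objects), will not parse.

```python
import ast

def parse_result(result: str) -> list:
    """Convert a stringified list of row tuples back into Python objects (hypothetical helper)."""
    if not result.strip():
        return []  # assumption: an empty string means the query returned no rows
    try:
        return ast.literal_eval(result)
    except (ValueError, SyntaxError) as exc:
        # e.g. an "Error: ..." message from the tool, or reprs such as Decimal('1.98')
        raise ValueError(f"Result is not a plain Python literal: {result[:80]!r}") from exc

rows = parse_result("[(8,)]")
employee_count = rows[0][0]
print(employee_count)  # 8
```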


In [21]:
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool

# Create the SQL query generation chain
write_query = create_sql_query_chain(llm, db)

# Create the SQL execution tool
execute_query = QuerySQLDataBaseTool(db=db)

# Define the answer prompt template
answer_prompt = PromptTemplate.from_template(
    """Given the following user question, corresponding SQL query, and SQL result, answer the user question.

    Question: {question}
    SQL Query: {query}
    SQL Result: {result}
    Answer: """
)

# Define the answer chain
answer = answer_prompt | llm | StrOutputParser()

def get_final_answer(question):
    """
    Generate an SQL query based on the question, execute it, and return the final answer.
    """
    # Step 1: Generate SQL query
    query_response = write_query.invoke({"question": question})
    
    # Debugging: Print the query_response to understand its format
    print("Query Response:", query_response)
    
    # Handle different possible formats of query_response
    if isinstance(query_response, dict):
        sql_query = query_response.get("query", "")
    elif isinstance(query_response, str):
        sql_query = extract_sql_query(query_response)  # Custom function to extract SQL from a string response
    else:
        raise TypeError("Unexpected format of query_response")
    
    print("Extracted SQL Query:", sql_query)  # Debugging purpose
    
    # Step 2: Execute SQL query
    execute_response = execute_query.invoke({"query": sql_query})
    
    # Debugging: Print the execute_response to understand its format
    print("Execute Response:", execute_response)
    
    # QuerySQLDataBaseTool returns the result (or an error message) as a string
    if isinstance(execute_response, dict):
        sql_result = execute_response.get("result", "")
    elif isinstance(execute_response, str):
        sql_result = execute_response  # already the query result; no SQL extraction needed
    else:
        raise TypeError("Unexpected format of execute_response")
    
    print("SQL Result:", sql_result)  # Debugging purpose
    
    # Step 3: Generate final answer
    final_answer = answer.invoke({
        "question": question,
        "query": sql_query,
        "result": sql_result
    })
    
    return final_answer

# Custom function to extract SQL query from a string if needed
def extract_sql_query(response):
    """
    Extracts the SQL query from the response string if it's in a specific format.
    """
    sql_query = response.split("SQLQuery: ")[-1].strip()
    return sql_query

# Example usage
question = "How many employees are there?"
final_answer = get_final_answer(question)
print("Final Answer:", final_answer)


Query Response: SQLQuery: SELECT COUNT("EmployeeId") AS "EmployeeCount" FROM "Employee";
Extracted SQL Query: SELECT COUNT("EmployeeId") AS "EmployeeCount" FROM "Employee";
Execute Response: [(8,)]
SQL Result: [(8,)]
Final Answer: There are 8 employees.
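Because each step is just a function call, the pipeline can be exercised without an API key by passing the steps in as plain callables and stubbing them. This is a minimal sketch of that dependency-injection style, not a LangChain API; `answer_question` and the stub lambdas are hypothetical. In the notebook you would pass wrappers such as `lambda q: write_query.invoke({"question": q})`, `lambda sql: execute_query.invoke({"query": sql})` and `lambda q, s, r: answer.invoke({"question": q, "query": s, "result": r})`. Returning the intermediate steps in a dict also replaces the debugging prints.

```python
from typing import Callable

def answer_question(
    question: str,
    write_query: Callable[[str], str],
    run_query: Callable[[str], str],
    answer: Callable[[str, str, str], str],
) -> dict:
    """Question -> SQL -> result -> natural-language answer, with every step injected."""
    raw = write_query(question)
    sql = raw.split("SQLQuery:")[-1].strip()  # essentially the same rule as extract_sql_query
    result = run_query(sql)
    return {"question": question, "query": sql, "result": result,
            "answer": answer(question, sql, result)}

# Stubbed steps: no LLM or database involved
steps = answer_question(
    "How many employees are there?",
    write_query=lambda q: 'SQLQuery: SELECT COUNT("EmployeeId") FROM "Employee";',
    run_query=lambda sql: "[(8,)]",
    answer=lambda q, sql, res: f"There are {res.strip('[](),')} employees.",
)
print(steps["answer"])  # There are 8 employees.
```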


In [22]:
chain

RunnableAssign(mapper={
  input: RunnableLambda(...),
  table_info: RunnableLambda(...)
})
| RunnableLambda(lambda x: {k: v for k, v in x.items() if k not in ('question', 'table_names_to_use')})
| PromptTemplate(input_variables=['input', 'table_info'], partial_variables={'top_k': '5'}, template='You are a SQLite expert. Given an input question, first create a syntactically correct SQLite query to run, then look at the results of the query and return the answer to the input question.\nUnless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per SQLite. You can order the results to return the most informative data in the database.\nNever query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.\nPay attention to use only the column names you can see in the tables below. B

### **2. Agents**

An agent provides a more flexible way of interacting with SQL databases. The main advantages of using the SQL agent are:

- It can answer questions based on the database's schema as well as on its content (for example, describing a specific table).
- It can recover from errors by running a generated query, catching the traceback, and regenerating a corrected query.
- It can answer questions that require multiple dependent queries.
- It saves tokens by only considering the schema of relevant tables.
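The error-recovery point can be illustrated with a toy loop: run the query, and if the database raises, hand the error text back to whatever generates queries and try again. In the real agent the LLM plays that role through its tools; here `run_with_retries` and the `regenerate` callback are hypothetical stand-ins, run against an in-memory SQLite database.

```python
import sqlite3
from typing import Callable

def run_with_retries(conn, sql: str, regenerate: Callable[[str, str], str], max_attempts: int = 3):
    """Execute `sql`; on a database error, ask `regenerate(sql, error)` for a new query and retry."""
    for attempt in range(1, max_attempts + 1):
        try:
            return sql, conn.execute(sql).fetchall()
        except sqlite3.Error as exc:
            if attempt == max_attempts:
                raise  # give up and surface the last error
            sql = regenerate(sql, str(exc))  # an agent would send the traceback to the LLM here
    raise ValueError("max_attempts must be at least 1")

conn = sqlite3.connect(":memory:")
conn.execute('CREATE TABLE "Employee" ("EmployeeId" INTEGER PRIMARY KEY)')
conn.executemany('INSERT INTO "Employee" VALUES (?)', [(1,), (2,), (3,)])

def fake_regenerate(bad_sql: str, error: str) -> str:
    # Stand-in for the LLM: fix the misspelled table name reported in the error
    print("Retrying after:", error)
    return bad_sql.replace("Employees", "Employee")

print(run_with_retries(conn, 'SELECT COUNT(*) FROM "Employees"', fake_regenerate))
# prints the retry message, then ('SELECT COUNT(*) FROM "Employee"', [(3,)])
```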

To initialize the agent, we use the `create_sql_agent` function. This agent uses the `SQLDatabaseToolkit`, which contains tools to:

- Create and execute queries
- Check query syntax
- Retrieve table descriptions
- …
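To make the toolkit's schema tool less of a black box, here is a rough stdlib equivalent of the output it produces in the agent runs below: the table's `CREATE TABLE` statement followed by a few sample rows in a comment block. `describe_table` is a hypothetical helper written against `sqlite3`, not the toolkit's code, and its formatting only approximates the tool's.

```python
import sqlite3

def describe_table(conn, table: str, sample_rows: int = 3) -> str:
    """Return a table's CREATE statement plus a few sample rows (hypothetical helper)."""
    row = conn.execute(
        "SELECT sql FROM sqlite_master WHERE type = 'table' AND name = ?", (table,)
    ).fetchone()
    if row is None:
        raise ValueError(f"No such table: {table}")
    # Quote the identifier properly (embedded double quotes are doubled)
    quoted = '"' + table.replace('"', '""') + '"'
    cur = conn.execute(f"SELECT * FROM {quoted} LIMIT ?", (sample_rows,))
    header = "\t".join(col[0] for col in cur.description)
    rows = ["\t".join(str(value) for value in r) for r in cur.fetchall()]
    return "\n".join([row[0], "", "/*", f"{len(rows)} rows from {table} table:", header, *rows, "*/"])

conn = sqlite3.connect(":memory:")
conn.execute(
    'CREATE TABLE "PlaylistTrack" ("PlaylistId" INTEGER NOT NULL, "TrackId" INTEGER NOT NULL, '
    'PRIMARY KEY ("PlaylistId", "TrackId"))'
)
conn.executemany('INSERT INTO "PlaylistTrack" VALUES (?, ?)', [(1, 3402), (1, 3389), (1, 3390)])
print(describe_table(conn, "PlaylistTrack"))
```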

In [23]:
from langchain_community.agent_toolkits import create_sql_agent

agent_executor = create_sql_agent(llm, db=db, agent_type="openai-tools", verbose=True)

In [24]:
agent_executor.invoke(
    {
        "input": "List the total sales per country. Which country's customers spent the most?"
    }
)



> Entering new SQL Agent Executor chain...

Invoking: `sql_db_list_tables` with `{}`

Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track

Invoking: `sql_db_schema` with `{'table_names': 'Customer'}`

CREATE TABLE "Customer" (
	"CustomerId" INTEGER NOT NULL, 
	"FirstName" NVARCHAR(40) NOT NULL, 
	"LastName" NVARCHAR(20) NOT NULL, 
	"Company" NVARCHAR(80), 
	"Address" NVARCHAR(70), 
	"City" NVARCHAR(40), 
	"State" NVARCHAR(40), 
	"Country" NVARCHAR(40), 
	"PostalCode" NVARCHAR(10), 
	"Phone" NVARCHAR(24), 
	"Fax" NVARCHAR(24), 
	"Email" NVARCHAR(60) NOT NULL, 
	"SupportRepId" INTEGER, 
	PRIMARY KEY ("CustomerId"), 
	FOREIGN KEY("SupportRepId") REFERENCES "Employee" ("EmployeeId")
)

/*
3 rows from Customer table:
CustomerId	FirstName	LastName	Company	Address	City	State	Country	PostalCode	Phone	Fax	Email	SupportRepId
1	Luís	Gonçalves	Embraer - Empresa Bras

{'input': "List the total sales per country. Which country's customers spent the most?",
 'output': 'The total sales per country are as follows:\n\n1. **USA**: $523.06\n2. **Canada**: $303.96\n3. **France**: $195.10\n4. **Brazil**: $190.10\n5. **Germany**: $156.48\n6. **United Kingdom**: $112.86\n7. **Czech Republic**: $90.24\n8. **Portugal**: $77.24\n9. **India**: $75.26\n10. **Chile**: $46.62\n\nThe country whose customers spent the most is the **USA** with total sales of $523.06.'}
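The log above is cut off before the agent's final query, so the exact SQL it ran is not shown. For comparison, a hand-written query for this question could look like the following, assuming the `Invoice` table has `CustomerId` and `Total` columns as in the Chinook sample database this schema resembles. The three customers and invoices below are made-up toy data, so the numbers do not reproduce the agent's figures.

```python
import sqlite3

TOTAL_SALES_BY_COUNTRY = """
SELECT c."Country", ROUND(SUM(i."Total"), 2) AS TotalSales
FROM "Invoice" AS i
JOIN "Customer" AS c ON c."CustomerId" = i."CustomerId"
GROUP BY c."Country"
ORDER BY TotalSales DESC
"""

# Toy data only: a minimal subset of the Customer and Invoice columns
conn = sqlite3.connect(":memory:")
conn.executescript("""
CREATE TABLE "Customer" ("CustomerId" INTEGER PRIMARY KEY, "Country" TEXT);
CREATE TABLE "Invoice" ("InvoiceId" INTEGER PRIMARY KEY, "CustomerId" INTEGER, "Total" NUMERIC(10,2));
INSERT INTO "Customer" VALUES (1, 'USA'), (2, 'Canada'), (3, 'USA');
INSERT INTO "Invoice" VALUES (1, 1, 1.98), (2, 2, 3.96), (3, 3, 5.94);
""")

for country, total in conn.execute(TOTAL_SALES_BY_COUNTRY):
    print(country, total)
```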

In [19]:
agent_executor.invoke({"input": "Describe the playlisttrack table"})
# agent_executor.invoke("Describe the playlisttrack table")



> Entering new SQL Agent Executor chain...

Invoking: `sql_db_list_tables` with `{}`

Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track

Invoking: `sql_db_schema` with `{'table_names': 'PlaylistTrack'}`

CREATE TABLE "PlaylistTrack" (
	"PlaylistId" INTEGER NOT NULL, 
	"TrackId" INTEGER NOT NULL, 
	PRIMARY KEY ("PlaylistId", "TrackId"), 
	FOREIGN KEY("TrackId") REFERENCES "Track" ("TrackId"), 
	FOREIGN KEY("PlaylistId") REFERENCES "Playlist" ("PlaylistId")
)

/*
3 rows from PlaylistTrack table:
PlaylistId	TrackId
1	3402
1	3389
1	3390
*/
The `PlaylistTrack` table has the following structure:

- **Columns**:
  - `PlaylistId` (INTEGER, NOT NULL): This is a foreign key referencing the `Playlist` table.
  - `TrackId` (INTEGER, NOT NULL): This is a foreign key referencing the `Track` table.

- **Primary Key**: The combination of `PlaylistId` a

{'input': 'Describe the playlisttrack table',
 'output': 'The `PlaylistTrack` table has the following structure:\n\n- **Columns**:\n  - `PlaylistId` (INTEGER, NOT NULL): This is a foreign key referencing the `Playlist` table.\n  - `TrackId` (INTEGER, NOT NULL): This is a foreign key referencing the `Track` table.\n\n- **Primary Key**: The combination of `PlaylistId` and `TrackId` serves as the primary key for this table, ensuring that each track can only appear once in a playlist.\n\n**Sample Rows**:\n1. PlaylistId: 1, TrackId: 3402\n2. PlaylistId: 1, TrackId: 3389\n3. PlaylistId: 1, TrackId: 3390'}
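The agent's description can be checked mechanically: SQLite's `pragma_table_info` and `pragma_foreign_key_list` table-valued functions report the declared columns, the primary-key order and the foreign keys. This is a sketch with a hypothetical `columns_and_keys` helper, run on a toy copy of the `PlaylistTrack` definition shown in the log.

```python
import sqlite3

def columns_and_keys(conn, table: str) -> dict:
    """Summarise a table's columns, primary key and foreign keys from SQLite's PRAGMA functions."""
    # pragma_table_info rows: (cid, name, type, notnull, dflt_value, pk)
    info = conn.execute("SELECT * FROM pragma_table_info(?)", (table,)).fetchall()
    columns = [(name, col_type, bool(notnull)) for _, name, col_type, notnull, _, _ in info]
    # pk is the 1-based position in the primary key, or 0 for non-key columns
    primary_key = [r[1] for r in sorted((r for r in info if r[5]), key=lambda r: r[5])]
    # pragma_foreign_key_list rows: (id, seq, table, from, to, on_update, on_delete, match)
    fks = conn.execute("SELECT * FROM pragma_foreign_key_list(?)", (table,)).fetchall()
    foreign_keys = [(r[3], r[2], r[4]) for r in fks]  # (column, referenced table, referenced column)
    return {"columns": columns, "primary_key": primary_key, "foreign_keys": foreign_keys}

conn = sqlite3.connect(":memory:")
conn.execute("""
CREATE TABLE "PlaylistTrack" (
    "PlaylistId" INTEGER NOT NULL,
    "TrackId" INTEGER NOT NULL,
    PRIMARY KEY ("PlaylistId", "TrackId"),
    FOREIGN KEY("TrackId") REFERENCES "Track" ("TrackId"),
    FOREIGN KEY("PlaylistId") REFERENCES "Playlist" ("PlaylistId")
)""")
print(columns_and_keys(conn, "PlaylistTrack"))
```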