In [9]:
from langchain_community.utilities import SQLDatabase
from pyprojroot import here
import warnings
import os
from dotenv import load_dotenv
from urllib.parse import quote_plus

In [10]:
# Load environment variables
load_dotenv()


True

In [11]:
# PostgreSQL connection parameters
db_params = {
    "host": os.getenv("PG_HOST", "localhost"),
    "port": os.getenv("PG_PORT", "5432"),
    "database": "chinook",
    "user": os.getenv("PG_USER", "postgres"),
    "password": os.getenv("PG_PASSWORD", ""),
}

In [12]:
encoded_password = quote_plus(db_params["password"])

In [13]:
postgres_uri = (
    f"postgresql+psycopg2://"
    f"{db_params['user']}:{encoded_password}@"
    f"{db_params['host']}:{db_params['port']}/"
    f"{db_params['database']}"
)

In [15]:
# Connect to PostgreSQL database
db = SQLDatabase.from_uri(postgres_uri)

# Test the connection
print(db.dialect)
print(db.get_usable_table_names())

postgresql
['album', 'artist', 'customer', 'employee', 'genre', 'invoice', 'invoice_line', 'media_type', 'playlist', 'playlist_track', 'track']


In [37]:
# validate the connection to the vectordb
print(db.dialect)
print(db.get_usable_table_names())
db.run("SELECT * FROM Artist LIMIT 10;")

postgresql
['album', 'artist', 'customer', 'employee', 'genre', 'invoice', 'invoice_line', 'media_type', 'playlist', 'playlist_track', 'track']


"[(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith'), (4, 'Alanis Morissette'), (5, 'Alice In Chains'), (6, 'Antônio Carlos Jobim'), (7, 'Apocalyptica'), (8, 'Audioslave'), (9, 'BackBeat'), (10, 'Billy Cobham')]"

In [41]:
from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage, SystemMessage
import os

# Set your OpenAI API key (from https://platform.openai.com/api-keys)
os.environ["OPENAI_API_KEY"] = os.getenv("OPENAI_API_KEY")

# Initialize the LLM - specify your preferred model
llm = ChatOpenAI(
    model="gpt-3.5-turbo",  # or "gpt-4", "gpt-4-turbo", etc.
    temperature=0.0
)

# Create messages (LangChain message format)
messages = [
    SystemMessage(content="You are a helpful assistant"),
    HumanMessage(content="hello")
]

# Get response
response = llm.invoke(messages)
print(response.content)

Hello! How can I assist you today?


In [42]:
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
import os

# Define the prompt template
prompt = ChatPromptTemplate.from_messages([
    ("system", "You are a helpful assistant"),
    ("user", "{input}")
])

# Create the chain
llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0.0)
chain = prompt | llm | StrOutputParser()

# Invoke the chain
response = chain.invoke({"input": "hello"})
print(response)

Hello! How can I assist you today?


In [43]:
from langchain.chains import create_sql_query_chain

chain = create_sql_query_chain(llm, db)
response = chain.invoke({"question": "How many employees are there"})
print(response)

SELECT COUNT(employee_id) AS total_employees
FROM employee;


In [44]:
db.run(response.replace("SQLQuery:", "").strip())

'[(8,)]'

In [45]:
chain.get_prompts()[0].pretty_print()

You are a PostgreSQL expert. Given an input question, first create a syntactically correct PostgreSQL query to run, then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most 5 results using the LIMIT clause as per PostgreSQL. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use CURRENT_DATE function to get the current date, if the question involves "today".

Use the following format:

Question: Question here
SQLQuery: SQL Query to run
SQLR

In [46]:
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool

write_query = create_sql_query_chain(llm, db)
execute_query = QuerySQLDataBaseTool(db=db)

chain = write_query | execute_query

chain.invoke({"question": "How many employees are there"})

'[(8,)]'

In [47]:
from operator import itemgetter
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough

answer_prompt = PromptTemplate.from_template(
    """Given the following user question, corresponding SQL query, and SQL result, answer the user question. The response should be user friendly.
Question: {question}
SQL Query: {query}
SQL Result: {result}
Answer: """
)

answer = answer_prompt | llm | StrOutputParser()
chain = (
    RunnablePassthrough.assign(query=write_query).assign(
        result=itemgetter("query") | execute_query
    )
    | answer
)

chain.invoke({"question": "List the total sales per country. Which country's customers spent the most?"})

'The customers from the United States spent the most with a total sales amount of $523.06.'

In [48]:
from langchain_community.agent_toolkits import create_sql_agent

agent_executor = create_sql_agent(llm, db=db, agent_type="openai-tools", verbose=True)

In [49]:
agent_executor.invoke(
    {
        "input": "List the total sales per country. Which country's customers spent the most?"
    }
)



[1m> Entering new SQL Agent Executor chain...[0m
[32;1m[1;3m
Invoking: `sql_db_list_tables` with `{}`


[0m[38;5;200m[1;3malbum, artist, customer, employee, genre, invoice, invoice_line, media_type, playlist, playlist_track, track[0m[32;1m[1;3m
Invoking: `sql_db_schema` with `{'table_names': 'customer, invoice, invoice_line'}`


[0m[33;1m[1;3m
CREATE TABLE customer (
	customer_id INTEGER NOT NULL, 
	first_name VARCHAR(40) NOT NULL, 
	last_name VARCHAR(20) NOT NULL, 
	company VARCHAR(80), 
	address VARCHAR(70), 
	city VARCHAR(40), 
	state VARCHAR(40), 
	country VARCHAR(40), 
	postal_code VARCHAR(10), 
	phone VARCHAR(24), 
	fax VARCHAR(24), 
	email VARCHAR(60) NOT NULL, 
	support_rep_id INTEGER, 
	CONSTRAINT customer_pkey PRIMARY KEY (customer_id), 
	CONSTRAINT customer_support_rep_id_fkey FOREIGN KEY(support_rep_id) REFERENCES employee (employee_id)
)

/*
3 rows from customer table:
customer_id	first_name	last_name	company	address	city	state	country	postal_code	phone	fax	

{'input': "List the total sales per country. Which country's customers spent the most?",
 'output': 'The total sales per country are as follows:\n\n1. USA: $523.06\n2. Canada: $303.96\n3. France: $195.10\n4. Brazil: $190.10\n5. Germany: $156.48\n6. United Kingdom: $112.86\n7. Czech Republic: $90.24\n8. Portugal: $77.24\n9. India: $75.26\n10. Chile: $46.62\n\nThe country whose customers spent the most is the USA with a total sales of $523.06.'}

In [50]:
agent_executor.invoke({"input": "Describe the playlisttrack table"})



[1m> Entering new SQL Agent Executor chain...[0m
[32;1m[1;3m
Invoking: `sql_db_list_tables` with `{}`


[0m[38;5;200m[1;3malbum, artist, customer, employee, genre, invoice, invoice_line, media_type, playlist, playlist_track, track[0m[32;1m[1;3m
Invoking: `sql_db_schema` with `{'table_names': 'playlist_track'}`


[0m[33;1m[1;3m
CREATE TABLE playlist_track (
	playlist_id INTEGER NOT NULL, 
	track_id INTEGER NOT NULL, 
	CONSTRAINT playlist_track_pkey PRIMARY KEY (playlist_id, track_id), 
	CONSTRAINT playlist_track_playlist_id_fkey FOREIGN KEY(playlist_id) REFERENCES playlist (playlist_id), 
	CONSTRAINT playlist_track_track_id_fkey FOREIGN KEY(track_id) REFERENCES track (track_id)
)

/*
3 rows from playlist_track table:
playlist_id	track_id
1	3402
1	3389
1	3390
*/[0m[32;1m[1;3mThe `playlist_track` table has the following columns:
- `playlist_id` (INTEGER, NOT NULL)
- `track_id` (INTEGER, NOT NULL)

It also has the following constraints:
- Primary Key: (playlist_id, track_

{'input': 'Describe the playlisttrack table',
 'output': 'The `playlist_track` table has the following columns:\n- `playlist_id` (INTEGER, NOT NULL)\n- `track_id` (INTEGER, NOT NULL)\n\nIt also has the following constraints:\n- Primary Key: (playlist_id, track_id)\n- Foreign Key: playlist_id references playlist(playlist_id)\n- Foreign Key: track_id references track(track_id)\n\nHere are 3 sample rows from the `playlist_track` table:\n1. playlist_id: 1, track_id: 3402\n2. playlist_id: 1, track_id: 3389\n3. playlist_id: 1, track_id: 3390'}