In [51]:
import os
from dotenv import load_dotenv

load_dotenv(dotenv_path="../.env")

api_key = os.getenv("OPENAI_API_KEY")
os.environ["OPENAI_API_KEY"] = api_key

In [52]:
from langchain_core.prompts import ChatPromptTemplate

template = """Based on the table schema below, write a SQL query that would answer the user's question:
{schema}

Question: {question}
SQL Query:"""
prompt = ChatPromptTemplate.from_template(template)


In [53]:
prompt.format(schema =  "my schema", question = "how many users are there?")

"Human: Based on the table schema below, write a SQL query that would answer the user's question:\nmy schema\n\nQuestion: how many users are there?\nSQL Query:"

In [54]:
from langchain_community.utilities import SQLDatabase

postgresql_uri = "postgresql+psycopg2://root:cs620ibdc1234@localhost:5432/IowaLiquorSales"


db = SQLDatabase.from_uri(postgresql_uri)

def get_schema(_):
    return db.get_table_info()


In [55]:
db.run("SELECT COUNT(*) FROM LiquorSales;")

'[(2622712,)]'

In [56]:
db.run("SELECT * FROM LiquorSales LIMIT 1;")

"[('INV-33202600029', datetime.datetime(2021, 1, 4, 0, 0), '5336', 'EXPRESS MART', '4804 S HIGHWAY 61', 'MUSCATINE', '52761', None, '70', 'MUSCATINE', '1062100', 'GOLD RUM', '035', 'BACARDI USA INC', '43036', 'BACARDI GOLD', 12, 750.0, 8.26, 12.39, 2, 24.78, 1.5, 0.39)]"

In [57]:
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_openai import ChatOpenAI

llm = ChatOpenAI()

sql_chain = (
    RunnablePassthrough.assign(schema=get_schema)
    | prompt
    | llm.bind(stop=["\nSQLResult:"])
    | StrOutputParser()
)


In [58]:
user_question = 'How many records are in the liquorsales table?'
query = sql_chain.invoke({"question": user_question})
print("Generated SQL Query:", query)

Generated SQL Query: SELECT COUNT(*) AS total_records FROM liquorsales;


In [59]:
template = """Based on the table schema below, question, sql query, and sql response, write a natural language response:
{schema}

Question: {question}
SQL Query: {query}
SQL Response: {response}"""
prompt_response = ChatPromptTemplate.from_template(template)

print("Generated Prompt:", prompt_response)


Generated Prompt: input_variables=['query', 'question', 'response', 'schema'] input_types={} partial_variables={} messages=[HumanMessagePromptTemplate(prompt=PromptTemplate(input_variables=['query', 'question', 'response', 'schema'], input_types={}, partial_variables={}, template='Based on the table schema below, question, sql query, and sql response, write a natural language response:\n{schema}\n\nQuestion: {question}\nSQL Query: {query}\nSQL Response: {response}'), additional_kwargs={})]


In [60]:
def run_query(query):
    return db.run(query)

In [61]:
run_query("SELECT COUNT(*) FROM liquorsales;")

'[(2622712,)]'

In [62]:
full_chain = (
    RunnablePassthrough.assign(query=sql_chain).assign(
        schema=get_schema,
        response=lambda vars: run_query(vars["query"]),
    )
    | prompt_response
    | llm
)


In [63]:
user_question = 'How many records are in the liquorsales table?'
query = full_chain.invoke({"question": user_question})

print("Generated Response:", query)

Generated Response: content='There are a total of 2,622,712 records in the liquorsales table.' additional_kwargs={'refusal': None} response_metadata={'token_usage': {'completion_tokens': 20, 'prompt_tokens': 602, 'total_tokens': 622, 'completion_tokens_details': {'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0}, 'prompt_tokens_details': {'audio_tokens': 0, 'cached_tokens': 0}}, 'model_name': 'gpt-3.5-turbo-0125', 'system_fingerprint': None, 'finish_reason': 'stop', 'logprobs': None} id='run-1b2a9df3-06a5-4b9d-b481-3deb49b90edb-0' usage_metadata={'input_tokens': 602, 'output_tokens': 20, 'total_tokens': 622, 'input_token_details': {'audio': 0, 'cache_read': 0}, 'output_token_details': {'audio': 0, 'reasoning': 0}}
