In [1]:
import sys
from pathlib import Path

# Make the directory above the notebook's working directory importable so the
# local `snowflake_llm` package resolves (presumably the repo root — confirm).
# Resolve to an absolute path so a later os.chdir cannot change its meaning,
# and only append once so re-running this cell does not pile up duplicates.
PROJECT_ROOT = str(Path("..").resolve())
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

In [2]:
from snowflake_llm.snowflake_utils import get_lagchain_connection
# Database handle used by every chain below (passed to create_sql_query_chain,
# QuerySQLDataBaseTool, and db.run) — presumably a LangChain SQLDatabase; confirm.
db = get_lagchain_connection()

In [3]:
from langchain.chains import create_sql_query_chain
from langchain_openai import ChatOpenAI

In [4]:
# temperature=0 keeps SQL generation as repeatable as the API allows.
llm = ChatOpenAI(temperature=0, model="gpt-3.5-turbo")

# Text-to-SQL only: the chain returns a SQL string; it does not execute it.
chain = create_sql_query_chain(llm, db)
response = chain.invoke(
    {"question": "How many locations are there"}
)
response

'SELECT COUNT(*) FROM locations;'

In [5]:
# `response` is model-generated SQL, so treat it as untrusted: refuse to run
# anything that does not start like a read-only query. This prefix check is a
# coarse safety net only (it does not catch e.g. a second statement after `;`);
# the real protection is a read-only role on the Snowflake connection.
if not response.lstrip().upper().startswith(("SELECT", "WITH")):
    raise ValueError(f"Refusing to execute non-SELECT SQL: {response!r}")
db.run(response)

'[(5,)]'

In [6]:
# Inspect the prompt template the text-to-SQL chain sends to the model.
sql_prompt = chain.get_prompts()[0]
sql_prompt.pretty_print()


Given an input question, first create a syntactically correct snowflake query to run, then look at the results of the query and return the answer. Unless the user specifies in his question a specific number of examples he wishes to obtain, always limit your query to at most 5 results. You can order the results by a relevant column to return the most interesting examples in the database.

Never query for all the columns from a specific table, only ask for a the few relevant columns given the question.

Pay attention to use only the column names that you can see in the schema description. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.

Use the following format:

Question: Question here
SQLQuery: SQL Query to run
SQLResult: Result of the SQLQuery
Answer: Final answer here

Only use the following tables:
[33;1m[1;3m{table_info}[0m

Question: [33;1m[1;3m{input}[0m


In [7]:
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool

# Tool that runs a SQL string against `db` (presumably returning the result as
# text, like db.run in cell 5 — confirm).
execute_query = QuerySQLDataBaseTool(db=db)
# Question -> SQL string, built the same way as the exploratory chain in cell 4.
write_query = create_sql_query_chain(llm, db)
# Question -> SQL -> raw result. Given its own name instead of rebinding `chain`,
# which the next cell reassigns to the full question-answering pipeline.
# Example: question_to_result.invoke({"question": "How many locations are there"})
question_to_result = write_query | execute_query


In [8]:
from operator import itemgetter

from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough

# Second LLM pass: turn the question, the generated SQL and its raw result into prose.
answer_prompt = PromptTemplate.from_template(
    "Given the following user question, corresponding SQL query, and SQL result, "
    "answer the user question.\n"
    "\n"
    "Question: {question}\n"
    "SQL Query: {query}\n"
    "SQL Result: {result}\n"
    "Answer: "
)
answer = answer_prompt | llm | StrOutputParser()

# Grow the input dict one key at a time: {question} -> +query -> +result,
# then hand the complete dict to the answer prompt.
add_query = RunnablePassthrough.assign(query=write_query)
add_result = add_query.assign(result=itemgetter("query") | execute_query)
chain = add_result | answer

In [9]:
# NOTE: the text-to-SQL prompt (printed in cell 6) tells the model to limit
# results to at most 5 rows unless the question asks for a specific number, so
# this answer may describe only a subset of the matching orders.
res = chain.invoke({"question": "describe orders with product A including locations"})

In [10]:
# The answer is multi-line text; print it so newlines render instead of
# showing the escaped string repr.
print(res)

'The orders for Product A include the following details:\n1. Order ID: 1, Quantity: 3, Order Date: August 15, 2023, Order Amount: $150.00, Location: Warehouse A, City: New York, State: NY\n2. Order ID: 7, Quantity: 2, Order Date: August 19, 2023, Order Amount: $100.00, Location: Warehouse B, City: Los Angeles, State: CA\n3. Order ID: 13, Quantity: 4, Order Date: August 22, 2023, Order Amount: $200.00, Location: Warehouse C, City: Chicago, State: IL\n4. Order ID: 4, Quantity: 4, Order Date: August 16, 2023, Order Amount: $200.00, Location: Warehouse D, City: Houston, State: TX\n5. Order ID: 10, Quantity: 3, Order Date: August 21, 2023, Order Amount: $135.00, Location: Warehouse E, City: Miami, State: FL'