In [1]:
import sys
# Put the parent directory on the module search path.
# NOTE(review): presumably this is what lets the local `snowflake_llm` package
# imported further down resolve — confirm against the repository layout.
sys.path.append("..")

<h1>Example with the built-in chain</h1>

In [2]:
# Database and schema names handed to get_lagchain_connection() in the next cell.
database="TESTDB"
schema="SALES"

In [3]:
from snowflake_llm.snowflake_utils import get_lagchain_connection
# `db` is shared by every chain and tool below. The notebook calls
# get_usable_table_names() and get_table_info() on it.
# NOTE(review): presumably a LangChain SQLDatabase wrapper — confirm in snowflake_utils.
db=get_lagchain_connection(database,schema)

In [4]:
from langchain.chains import create_sql_query_chain
from langchain_openai import ChatOpenAI
# One chat model instance is reused for SQL generation and, later, for answering.
llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
# Built-in question -> SQL chain.
# NOTE(review): `k` is presumably the per-query row limit the chain asks the
# model to respect — confirm against the create_sql_query_chain reference.
query_chain = create_sql_query_chain(llm, db,k=1)

In [5]:
# Ask the built-in chain for a query; this cell only prints the generated SQL
# text and does not pass it to any execution tool.
response = query_chain.invoke({"question":  "How many tables do we have"})
print(response)

SELECT COUNT(*) FROM information_schema.tables WHERE table_schema = 'PUBLIC';


In [6]:
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool

# Tool that runs a SQL string against `db` and returns the result as text.
execute_query = QuerySQLDataBaseTool(db=db)
# Same built-in SQL writer as before, this time without overriding `k`.
write_query = create_sql_query_chain(llm, db)
# Pipe the generated SQL straight into the execution tool.
chain = write_query | execute_query
chain.invoke({"question": "How many locations are there"})
# Alternative question to try:
# chain.invoke({"question": "what are the locations for order id 1"})


'[(5,)]'

<h1>Build the same chain using a retrieval-based query chain for large databases</h1>

In [7]:
# Collect the tables visible through `db` and their schema description text.
# `table_info` is later split on the keyword 'CREATE' into per-table chunks.
table_names=db.get_usable_table_names()
table_info=db.get_table_info(table_names)

In [9]:
from langchain.text_splitter import CharacterTextSplitter 
from langchain.schema.document import Document
from langchain_community.vectorstores import FAISS
from langchain_openai import OpenAIEmbeddings
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableLambda, RunnablePassthrough

<h2>Create a context vector store for all table metadata</h2>

In [10]:
def get_text_chunks_langchain(text):
    """Split a database description into one Document per table DDL statement.

    The description is expected to be a sequence of ``CREATE ...`` statements,
    so it is cut at every occurrence of the keyword ``CREATE`` and the keyword
    is put back at the front of each non-blank piece.

    Args:
        text: Schema description text, e.g. the output of ``db.get_table_info``.

    Returns:
        A list of ``Document`` objects, one per non-blank piece, in their
        original order. An empty or missing description yields an empty list.
    """
    if not text:
        return []
    # Split the argument itself rather than a notebook-level global, so the
    # function works for any description passed in.
    ddl_statements = [
        'CREATE' + fragment
        for fragment in text.split('CREATE')
        if fragment.strip()
    ]
    return [Document(page_content=ddl) for ddl in ddl_statements]

In [11]:
def convert_table_info_to_documents(table_info):
    """Index the per-table DDL chunks of ``table_info`` in a FAISS vector store.

    Args:
        table_info: Schema description text; it is split into documents by
            ``get_text_chunks_langchain``.

    Returns:
        The FAISS store built from those documents using ``OpenAIEmbeddings``.
    """
    # Arguments are evaluated left to right: chunk first, then create the embedder.
    return FAISS.from_documents(
        get_text_chunks_langchain(table_info),
        OpenAIEmbeddings(),
    )

In [12]:
# Build the vector store over the table metadata gathered above.
table_metadata_vector_store=convert_table_info_to_documents(table_info)

In [13]:
# Retriever over the metadata store, created with the library's default settings.
retriever=table_metadata_vector_store.as_retriever()

In [14]:
# Prompt for the retrieval-based SQL writer. It takes two variables:
#   {context}  - the retrieved table DDL text
#   {question} - the user question (the chain appends "\nSQLQuery: " to it)
# The row limit is stated as 10 in both places. It previously said 5 in the
# opening paragraph and 10 in rule 1, which gave the model conflicting
# instructions; 10 is kept because it is the explicit "MUST" rule.
template = """
Given an input question, first create a syntactically correct snowflake query to run, then look at the results of the query and return the answer. Unless the user specifies in his question a specific number of examples he wishes to obtain, always limit your query to at most 10 results. You can order the results by a relevant column to return the most interesting examples in the database.

Never query for all the columns from a specific table, only ask for the few relevant columns given the question.

Pay attention to use only the column names that you can see in the schema description. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Be very strict about the syntax of the sql query. remember the following
1. If I don't tell you to find a limited set of results in the sql query or question, you MUST limit the number of responses to 10.
2. Text / string where clauses must be fuzzy match e.g ilike %keyword%
3. Make sure to generate a single snowflake sql code, not multiple.

Use the following format:

Question: Question here
SQLQuery: SQL Query to run
SQLResult: Result of the SQLQuery
Answer: Final answer here

Use the context provided to parse the table information:
Context:{context}

Question: {question}
"""

In [15]:
from langchain_core.prompts import PromptTemplate
# Wrap the raw prompt text so the chain can fill in {context} and {question}.
prompt_template = PromptTemplate.from_template(template)

In [18]:
from operator import itemgetter
from langchain_core.runnables import RunnablePassthrough,RunnableParallel
def _format_context(docs):
    """Join the retrieved table-DDL documents into one plain-text block."""
    return "\n\n".join(doc.page_content for doc in docs)


# Question -> SQL text, using only the table definitions retrieved for the question.
#   1. Look up the DDL chunks relevant to the question and flatten them to text.
#   2. Append the "SQLQuery: " cue to the question and forward ONLY the
#      retrieved text as `context`. Forwarding the whole input dict would put
#      a dict repr, including the question again, into the prompt's {context} slot.
#   3. Render the prompt, stop the model before it writes a "SQLResult:"
#      section, and strip surrounding whitespace from the generated SQL.
retrieval_chain = (
    {
        "context": itemgetter("question") | retriever | _format_context,
        "question": itemgetter("question"),
    }
    | RunnableParallel(
        question=lambda x: x["question"] + "\nSQLQuery: ",
        context=itemgetter("context"),
    )
    | prompt_template
    | llm.bind(stop=["\nSQLResult:"])
    | StrOutputParser()
    | (lambda text: text.strip())
)

In [29]:
# Generate SQL with the retrieval-based chain; nothing is executed here.
res=retrieval_chain.invoke({"question":  "what tables are there in the database"})

In [30]:
# Show the generated SQL string.
res

"SELECT table_name\nFROM information_schema.tables\nWHERE table_schema = 'public'"

In [21]:
# Same construction as the execute_query built in the built-in chain section;
# repeated here so this section can be read on its own.
execute_query = QuerySQLDataBaseTool(db=db)

In [22]:
# Retrieval-based SQL writer piped into the execution tool.
# NOTE(review): the recorded run returned an empty string. The SQL generated
# for a similar question filtered on table_schema = 'public' while the
# configured schema is "SALES" — possibly the cause; confirm.
chain = retrieval_chain  | execute_query
chain.invoke({"question": "what tables are in the database?"})

''

<h1>Let's use the new query builder with the answer chain</h1>

In [31]:
from operator import itemgetter

from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough

# Prompt that turns (question, generated SQL, SQL result) into a prose answer.
_ANSWER_TEMPLATE = """Given the following user question, corresponding SQL query, and SQL result, answer the user question.

Question: {question}
SQL Query: {query}
SQL Result: {result}
Answer: """
answer_prompt = PromptTemplate.from_template(_ANSWER_TEMPLATE)

# Render the prompt, call the model, and keep only the text of its reply.
answer = answer_prompt | llm | StrOutputParser()

# Step 1: add the generated SQL to the input dict under "query".
_with_query = RunnablePassthrough.assign(query=retrieval_chain)
# Step 2: execute that SQL and add the outcome under "result".
_with_result = _with_query.assign(result=itemgetter("query") | execute_query)
# Step 3: hand question, query and result to the answer prompt.
chain = _with_result | answer

In [32]:
# Full pipeline: write the SQL, execute it, then answer in natural language.
chain.invoke({"question": "what tables are in the database?"})

"The SQL query provided will return a list of tables in the database schema named 'PUBLIC'. The result of the query will be a list of table names that exist in that schema."