## SQLDatabaseSequentialChain

Using the simple chain, I got the process working, but the results aren't great. If I point the input variables to very specific places, it can get the right answer, but the standard prompt template from langchain doesn't give correct results for our test queries. That prompt feeds in the table creation statements and the first few rows of each table, then relies on the model to pick out the useful information based on the input question. I think there is some further prompt engineering I could do, but first I want to explore the other SQL options from LangChain and iterate based on which one I feel offers the best long-term solution.
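
To make concrete what "table creation statements and first few rows" looks like, here is a minimal sketch that builds a comparable context string using only the standard library. It is not langchain's own implementation; `build_table_info` is a hypothetical helper and the `head` table holds made-up data.

```python
import sqlite3

def build_table_info(conn, sample_rows=3):
    """Return each table's CREATE statement followed by its first few rows."""
    parts = []
    tables = conn.execute(
        "SELECT name, sql FROM sqlite_master "
        "WHERE type = 'table' AND name NOT LIKE 'sqlite_%' ORDER BY name"
    ).fetchall()
    for name, create_sql in tables:
        cursor = conn.execute(f'SELECT * FROM "{name}" LIMIT ?', (sample_rows,))
        header = "\t".join(col[0] for col in cursor.description)
        rows = ["\t".join(str(value) for value in row) for row in cursor.fetchall()]
        parts.append("\n".join([create_sql, f"/* first {sample_rows} rows */", header, *rows]))
    return "\n\n".join(parts)

# Toy in-memory database standing in for the real one (assumed, made-up data).
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE head (head_id INTEGER PRIMARY KEY, name TEXT, age REAL)")
conn.executemany("INSERT INTO head VALUES (?, ?, ?)", [(1, "Ann", 67.0), (2, "Bob", 50.0)])

table_info = build_table_info(conn)
print(table_info)
```

With many tables this string grows quickly, which is part of why the model has to do a lot of filtering on its own.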

There is a sequential chain that works in two steps:

1. Determine which tables to use based on the query.
2. Based on those tables, call the normal SQL database chain.

This sounds similar, but I'm curious whether splitting the work and asking the model to do it in two steps will improve performance.
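
The two-step idea can be sketched without any LLM at all. In the sketch below, `pick_tables` is a hypothetical keyword-matching stand-in for the model's table decision (the real chain asks the LLM), and `schema_for` builds the reduced schema that would be passed to the second step. The three tables are a toy schema assumed for illustration.

```python
import sqlite3

def pick_tables(question, table_names):
    """Step 1 stand-in: keep tables whose name appears in the question.
    The real chain asks an LLM to make this decision."""
    text = question.lower()
    return [name for name in table_names if name.lower() in text]

def schema_for(conn, table_names):
    """Step 2 input: CREATE statements for the chosen tables only."""
    statements = []
    for name in table_names:
        row = conn.execute(
            "SELECT sql FROM sqlite_master WHERE type = 'table' AND name = ?", (name,)
        ).fetchone()
        if row:
            statements.append(row[0])
    return "\n".join(statements)

# Toy schema (assumed for illustration only).
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE department (department_id INTEGER, name TEXT)")
conn.execute("CREATE TABLE head (head_id INTEGER, age REAL)")
conn.execute("CREATE TABLE management (department_id INTEGER, head_id INTEGER)")

all_tables = [
    row[0]
    for row in conn.execute(
        "SELECT name FROM sqlite_master WHERE type = 'table' ORDER BY name"
    )
]
question = "How many heads of the departments are older than 56?"
chosen = pick_tables(question, all_tables)
prompt_schema = schema_for(conn, chosen)
print(chosen)
print(prompt_schema)
```

The second step then only sees the schema for the chosen tables, so a wrong or empty choice in step one directly limits what the SQL step can do.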

### Imports

In [1]:
import pandas as pd
import os
from dotenv import load_dotenv

from langchain.vectorstores import Chroma
from langchain.embeddings import HuggingFaceEmbeddings
from langchain import HuggingFaceHub

from langchain import SQLDatabase, SQLDatabaseChain
from langchain.chains import SQLDatabaseSequentialChain
from langchain.prompts.prompt import PromptTemplate

### Identify Target Schema & Load DB

I think this could be built out, but I just want to test the sequential chain for now, so I'll leave it as is.

In [2]:
def db_select(persist_dir, query):
    """
    Load the schema info vector database from disk and run the input question against it to return the most likely database we need to pull data from.
    """
    #setup embeddings using HuggingFace and the directory location
    embeddings  = HuggingFaceEmbeddings()
    per_dir = persist_dir

    # load from disk
    vectordb = Chroma(persist_directory=per_dir, embedding_function=embeddings)

    #run prompt as query and get most likely results
    result = vectordb.similarity_search(query, k=1)

    #save variables
    top_schema = result[0].metadata['schema']
    top_table = result[0].metadata['table']
    table_cols = result[0].metadata['columns']

    top_result = (top_schema, top_table, table_cols)

    return top_result

In [3]:
def locate_and_connect_db(filepath, filename):
    """Locate the absolute path of the given SQLite database and connect to it via the langchain SQLDatabase.from_uri method.
    filepath is the directory within the repo that holds the database, example '../data/db/'
    filename is just the filename.filetype, example 'data.db'
    """
    db_path = os.path.abspath(os.path.join(filepath, filename)) #get the full path on the device
    db = SQLDatabase.from_uri("sqlite:///" + db_path) #connect via the langchain method

    return db
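
One thing to watch with this step: `sqlite3` creates a new empty database when the target file does not exist, so a wrong schema name can fail later with confusing "no such table" errors instead of failing here. A minimal sketch of an early check, assuming a hypothetical helper `resolve_db_uri` that is not part of the notebook:

```python
import os
import sqlite3
import tempfile

def resolve_db_uri(db_dir, filename):
    """Build an absolute sqlite:/// URI, failing early if the file is missing."""
    db_path = os.path.abspath(os.path.join(db_dir, filename))
    if not os.path.isfile(db_path):
        raise FileNotFoundError(f"No database file at {db_path}")
    return "sqlite:///" + db_path

# Demonstrate with a throwaway database in a temporary directory.
with tempfile.TemporaryDirectory() as tmp:
    conn = sqlite3.connect(os.path.join(tmp, "data.db"))
    conn.execute("CREATE TABLE t (x INTEGER)")
    conn.commit()
    conn.close()
    uri = resolve_db_uri(tmp, "data.db")
    print(uri)
```

The returned string could then be passed to `SQLDatabase.from_uri` in place of the concatenation above.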

### Load LLM

In [4]:
def load_llm_model(env_variable, repo_id='tiiuae/falcon-7b-instruct', temp=0.5, max_length=200):
    """
    Take in the target hugging face repo and the api key to setup the llm to use in our query chain
    env_variable is the name of the variable that stores your API key in your .env file
    """
    load_dotenv()
    hf_api_token = os.getenv(env_variable)
    llm = HuggingFaceHub(repo_id=repo_id, huggingfacehub_api_token=hf_api_token, model_kwargs={"temperature": temp, "max_length": max_length})

    return llm

### Create Prompt

In [5]:
# template = """You are a SQLite expert. Given an input question, first create a syntactically correct SQLite query to run, then look at the results of the query and explain the results in context of the original question.
#             Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per SQLite. You can order the results to return the most informative data in the database.
#             Query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.
#             Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
#             Pay attention to use date('now') function to get the current date, if the question involves "today".

#             Use the following format:

#             Question: Question here
#             SQLQuery: SQL Query to run
#             SQLResult: Result of the SQLQuery
#             Answer: Final answer here

#             Only use the following tables:
#             {table_info}

#             Question: {input}"""

In [6]:
# def create_prompt_template(prompt_template=template, input_vars=['input', 'table_info', 'top_k']):
#     prompt_temp = PromptTemplate(input_variables=input_vars, template=prompt_template) #create our prompt template
    
#     return prompt_temp

### Create Chain

In [7]:
def create_sql_chain(llm, db, verbose=True):
    """Take in the llm and db created by the other functions and create our sequential SQL chain, using the chain's default prompts"""
    db_chain = SQLDatabaseSequentialChain.from_llm(llm, db, verbose=verbose) #create our database chain
    
    return db_chain

### Link Full Process

In [12]:
def sql_analyst_seq(question, vector_db_path='../data/processed/chromadb/schema-table-split', db_root_path='../data/processed/db/', sql_dialect='sqlite', env_api_key_var='hf_token'):
    """take in user info on where the databases and the api key are located and answer their question using the sql chain."""
    
    #query our vector database with the question to get the schema most related to the question.
    top_result = db_select(persist_dir=vector_db_path, query=question)
    top_schema = top_result[0]
    top_table = top_result[1]

    #use this result to establish our database
    db = locate_and_connect_db(filepath=db_root_path, filename=top_schema+'.'+sql_dialect)

    #initialize our large language model - needs an API key - we'll use the standard variables
    llm = load_llm_model(env_variable=env_api_key_var)

    # #create prompt template using our preset template
    # prompt_template = create_prompt_template(prompt_template=template)

    #setup the sql chain using the standard variables
    sql_chain = create_sql_chain(db=db, llm=llm)

    #run sql_chain on question and return the answer
    return sql_chain.run(question)

## Test

In [13]:
sql_analyst_seq("How many heads of the departments are older than 56?")

department_management


> Entering new  chain...

Table names to use:
[]

> Entering new  chain...
How many heads of the departments are older than 56?
SQLQuery:SELECT COUNT(DISTINCT(departments.name) FROM departments INNER JO

OperationalError: (sqlite3.OperationalError) near "FROM": syntax error
[SQL: SELECT COUNT(DISTINCT(departments.name) FROM departments INNER JO]
(Background on this error at: https://sqlalche.me/e/20/e3q8)
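
In the output above, the generated query stops at `INNER JO` and has an unbalanced parenthesis, so SQLite rejects it. One way to catch this kind of output before running it is to ask SQLite to prepare the statement with `EXPLAIN`. This is a minimal sketch, not something the chain does itself; `check_sql` is a hypothetical helper and the one-column `departments` table is assumed for illustration.

```python
import sqlite3

def check_sql(conn, sql):
    """Return None if SQLite can prepare the statement, else the error message."""
    try:
        # EXPLAIN compiles the statement without running the query itself.
        conn.execute("EXPLAIN " + sql)
    except sqlite3.Error as exc:
        return str(exc)
    return None

# Toy table (assumed) so that valid queries have something to resolve against.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE departments (name TEXT)")

truncated = "SELECT COUNT(DISTINCT(departments.name) FROM departments INNER JO"
complete = "SELECT COUNT(DISTINCT departments.name) FROM departments"

print(check_sql(conn, truncated))
print(check_sql(conn, complete))
```

A check like this only validates the statement against the schema; it says nothing about whether the query answers the question.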