# Prompt Templates

Now that my vector database is set up, I want to pull down the most likely schema and table and save the metadata as variables. Then I think the best way to proceed is to create a prompt template that accepts those variables and gives a consistent entry point to the model.

## Imports

In [1]:
import pandas as pd
import os
from dotenv import load_dotenv

from langchain.vectorstores import Chroma
from langchain.embeddings import HuggingFaceEmbeddings
from langchain import HuggingFaceHub

from langchain import SQLDatabase, SQLDatabaseChain
from langchain.prompts.prompt import PromptTemplate

## Load Top Results from VectorDB

In [2]:
#setup embeddings using HuggingFace and the directory location
embeddings  = HuggingFaceEmbeddings()
persist_dir = '../data/processed/chromadb/schema-table-split'

# load from disk
vectordb = Chroma(persist_directory=persist_dir, embedding_function=embeddings)

In [3]:
#run prompt as query and get most likely results
query = "How many heads of the departments are older than 56?" #first prompt from the training data

result = vectordb.similarity_search(query, k=1)

In [4]:
result

[Document(page_content='Schema Table: department management head', metadata={'schema': 'department_management', 'table': 'head', 'columns': '["head_id", "name", "born_state", "age"]'})]

## Save Metadata to Variables

In [5]:
top_schema = result[0].metadata['schema']
top_table = result[0].metadata['table']
table_cols = result[0].metadata['columns']

print(top_schema, top_table, table_cols)

department_management head ["head_id", "name", "born_state", "age"]
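Note that `columns` comes back as a JSON-encoded string rather than a Python list. A minimal sketch of decoding it, assuming the metadata always has the shape printed above (`column_names` is a name introduced here for illustration):

```python
import json

# Value copied from the metadata printed above: a JSON-encoded string, not a list.
table_cols = '["head_id", "name", "born_state", "age"]'

# Decode it once so the columns can be iterated over or joined.
column_names = json.loads(table_cols)

print(column_names)
print(", ".join(column_names))
```

Decoding once up front avoids passing bracket and quote characters into a prompt later on.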


## Build SQL Agent Through HF Hub

The langchain docs focus on either accessing the OpenAI API or running a local model. I'm hoping I can mix both and access a Hugging Face model through an API without having to store it locally. I think I can do this through the Hugging Face Hub by combining the steps from different parts of the documentation:
- [Langchain - Hugging Face Hub](https://python.langchain.com/docs/modules/model_io/models/llms/integrations/huggingface_hub)
- [Langchain - SQL Agents](https://python.langchain.com/docs/modules/chains/popular/sqlite)

I do have the start of the locally stored method commented out below.

### Load API Token

In [6]:
load_dotenv()
hf_api_token = os.getenv('hf_token')

In [7]:
#add path to HF repo
repo_id = 'tiiuae/falcon-7b-instruct'

### Point to LLM Model

In [8]:
#establish llm model
llm = HuggingFaceHub(repo_id=repo_id, huggingfacehub_api_token=hf_api_token, model_kwargs={"temperature": 0.5, "max_length": 200})

### Setup Simple SQL Agent
Langchain has some awesome features, but I'll start with a simple chain that points to one table.

First, I need to use the output of my vector db query to point to the sqlite database we want to work with.

#### Point to sqlite db location

In [9]:
BASE_DIR = os.path.dirname(os.path.abspath('../data/processed/db/department_management.sqlite'))
db_path = os.path.join(BASE_DIR, "department_management.sqlite")

db_path

'/Users/brettly/Sboard/projects/text-to-sql/data/processed/db/department_management.sqlite'

#### Establish db

In [10]:
db = SQLDatabase.from_uri("sqlite:///" + db_path)

#### Create Prompt Template
I'll even specify the single table here to really spoon feed the model.

In [11]:
sql_agent_prompt_template = """You are an expert data analyst. Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and describe the answer.
Use the following format:

Question: Question here
SQLQuery: SQL Query to run
SQLResult: Result of the SQLQuery
Answer: Final answer here

Only use the following tables:
{table_info}

Question: {input}"""

In [12]:
sql_prompt = PromptTemplate(
    input_variables=["input", "table_info", "dialect"], template=sql_agent_prompt_template)
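To see what the model will receive, the placeholders can be filled with plain `str.format` outside of langchain. This is only a sketch: `preview_template` is a shortened stand-in for the template above, and `describe_table` is a hypothetical helper that turns the stored metadata into a one-line table description.

```python
import json

# Shortened stand-in for the template above (same three placeholders).
preview_template = (
    "Create a syntactically correct {dialect} query.\n"
    "Only use the following tables:\n"
    "{table_info}\n\n"
    "Question: {input}"
)

def describe_table(table, columns_json):
    """Hypothetical helper: render 'table (col1, col2, ...)' from the stored metadata."""
    columns = json.loads(columns_json)
    return f"{table} ({', '.join(columns)})"

table_info = describe_table("head", '["head_id", "name", "born_state", "age"]')

# Fill every placeholder and inspect the final prompt text.
rendered = preview_template.format(
    dialect="sqlite",
    table_info=table_info,
    input="How many heads of the departments are older than 56?",
)
print(rendered)
```

Including the column names alongside the table name gives the model something concrete to filter on, such as `age`.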

#### Setup db_chain

In [13]:
db_chain = SQLDatabaseChain.from_llm(llm, db, prompt=sql_prompt, verbose=True)

#### Test with our simple question

In [14]:
db_chain.run(sql_prompt.format(input="How many heads of the departments are older than 56?", table_info="head", dialect="sqlite"))



> Entering new chain...
You are an expert data analyst. Given an input question, first create a syntactically correct sqlite query to run, then look at the results of the query and describe the answer.
Use the following format:

Question: Question here
SQLQuery: SQL Query to run
SQLResult: Result of the SQLQuery
Answer: Final answer here

Only use the following tables:
head

Question: How many heads of the departments are older than 56?
SQLQuery:SELECT COUNT(t.head_ID) FROM head t
SQLResult: [(10,)]
Answer:10 heads of the departments are older than 56.
> Finished chain.


'10 heads of the departments are older than 56.'

**Observations:**
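This simple chain runs end to end, although not as cleanly as I would like. Note that in the run above the generated query has no filter on `age`, so the 10 it reports is a count of all rows in `head`, not of the heads older than 56. The chain is also inconsistent across runs: sometimes it returns a bare "5" as the answer, and sometimes it describes it with something like "There are 5 department heads older than 56."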
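The query in the trace (`SELECT COUNT(t.head_ID) FROM head t`) counts every row, while the question asks for a filtered count. A small sketch of the difference using an in-memory SQLite table. The rows are invented for illustration and are not the contents of `department_management.sqlite`:

```python
import sqlite3

# In-memory table with made-up rows (illustrative only, not the real data).
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE head (head_id INTEGER, name TEXT, born_state TEXT, age REAL)")
conn.executemany(
    "INSERT INTO head VALUES (?, ?, ?, ?)",
    [(1, "A", "X", 67.0), (2, "B", "Y", 52.0), (3, "C", "Z", 58.0), (4, "D", "X", 43.0)],
)

# Query shaped like the one in the trace: counts every row.
unfiltered = conn.execute("SELECT COUNT(t.head_id) FROM head t").fetchone()[0]

# Query the question actually asks for: counts only rows with age > 56.
filtered = conn.execute("SELECT COUNT(*) FROM head WHERE age > 56").fetchone()[0]

print(unfiltered, filtered)
conn.close()
```

Comparing the model's SQL against a hand-written query like this is a cheap way to tell a correct answer from a plausible-sounding one.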

I'm also struggling with the prompt template. I suspect the langchain SQL chain relies on some preset format, because any time I try to change the template structurally, it errors out. Unfortunately, this means the output usually includes "Question:" at the end. I'll need to check the documentation and online discussions a bit more to see if I can get around that.
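Until the template issue is sorted out, one possible workaround is to clean the answer after the fact. A minimal sketch, assuming the stray text always starts at a literal "Question:" marker. `trim_answer` is a hypothetical helper and the sample string is made up:

```python
def trim_answer(raw_answer):
    """Hypothetical cleanup: drop everything from the first 'Question:' marker onward."""
    answer, _sep, _rest = raw_answer.partition("Question:")
    return answer.strip()

# Made-up example of the kind of output described above.
raw = "10 heads of the departments are older than 56.\nQuestion:"
print(trim_answer(raw))
```

This treats the symptom only. It would also cut a legitimate answer that happens to contain the word "Question:", so it is a stopgap rather than a fix for the template.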

## Create Functions
Even if it needs some cleanup, this technically works, which is great. But it is very segmented, so I want to pull the steps together so that I don't have to manually change variables at each step.

In [15]:
def db_select(persist_dir, query):
    """
    Load the schema info vector database from disk and run the input question against it to return the most likely database we need to pull data from.
    """
    #setup embeddings using HuggingFace and the directory location
    embeddings  = HuggingFaceEmbeddings()
    per_dir = persist_dir

    # load from disk
    vectordb = Chroma(persist_directory=per_dir, embedding_function=embeddings)

    #run prompt as query and get most likely results
    result = vectordb.similarity_search(query, k=1)

    #save variables
    top_schema = result[0].metadata['schema']
    top_table = result[0].metadata['table']
    table_cols = result[0].metadata['columns']

    top_result = (top_schema, top_table, table_cols)

    return top_result
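Returning a plain tuple means callers have to remember which position holds what. One alternative, sketched here with a hypothetical `TableMatch` type and `to_table_match` helper (neither is part of the notebook), is a `NamedTuple`, which still unpacks like a tuple:

```python
from typing import NamedTuple

class TableMatch(NamedTuple):
    """Hypothetical container for the top vector-store hit."""
    schema: str
    table: str
    columns: str  # still the JSON-encoded string stored in the metadata

def to_table_match(metadata):
    """Build a TableMatch from a metadata dict shaped like the one shown earlier."""
    return TableMatch(metadata["schema"], metadata["table"], metadata["columns"])

match = to_table_match({
    "schema": "department_management",
    "table": "head",
    "columns": '["head_id", "name", "born_state", "age"]',
})
print(match.schema, match.table)
```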

In [16]:
def locate_and_connect_db(filepath, filename):
    """Locate the absolute path of the given SQLite database and connect to it via the langchain SQLDatabase.from_uri method.
    filepath is the directory that holds the database, relative to the notebook, example '../data/db/'
    filename is just the filename.filetype, example 'data.db'
    """
    base_dir = os.path.abspath(filepath) #get the absolute path of the directory on this device
    db_path = os.path.join(base_dir, filename) #combine with filename to get db_path
    db = SQLDatabase.from_uri("sqlite:///" + db_path) #connect via the langchain method

    return db
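The path handling can be checked on its own, without connecting to anything. A small sketch using `pathlib`. `sqlite_uri` is a hypothetical helper, and the file does not need to exist for the URI to be built:

```python
from pathlib import Path

def sqlite_uri(db_dir, filename):
    """Hypothetical helper: build a SQLAlchemy-style SQLite URI from a directory and a file name."""
    db_path = (Path(db_dir) / filename).resolve()  # absolute path; works with or without a trailing slash
    return "sqlite:///" + db_path.as_posix()

uri = sqlite_uri("../data/processed/db/", "department_management.sqlite")
print(uri)
```

Printing the URI before connecting makes it easier to spot a wrong working directory, which is a common cause of "unable to open database file" errors.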

In [17]:
def load_llm_model(env_variable, repo_id='tiiuae/falcon-7b-instruct', temp=0.5, max_length=200):
    """
    Take in the target hugging face repo and the api key to setup the llm to use in our query chain
    env_variable is the name of the variable that stores your API key in your .env file
    """
    load_dotenv()
    hf_api_token = os.getenv(env_variable)
    llm = HuggingFaceHub(repo_id=repo_id, huggingfacehub_api_token=hf_api_token, model_kwargs={"temperature": temp, "max_length": max_length})

    return llm

In [40]:
template = """You are a SQLite expert. Given an input question, first create a syntactically correct SQLite query to run, then look at the results of the query and explain the results in context of the original question.
            Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per SQLite. You can order the results to return the most informative data in the database.
            Query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.
            Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
            Pay attention to use date('now') function to get the current date, if the question involves "today".

            Use the following format:

            Question: Question here
            SQLQuery: SQL Query to run
            SQLResult: Result of the SQLQuery
            Answer: Final answer here

            Only use the following tables:
            {table_info}

            Question: {input}"""
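Because the template is a triple-quoted string with indented continuation lines, that indentation becomes part of the prompt. One way to strip it, sketched on a shortened stand-in with the same layout, is `inspect.cleandoc` from the standard library:

```python
import inspect

# Shortened stand-in with the same layout: first line flush left, later lines indented.
indented_template = """You are a SQLite expert.
            Query for at most {top_k} results.

            Only use the following tables:
            {table_info}

            Question: {input}"""

# cleandoc strips the first line and removes the common indent from the rest.
clean_template = inspect.cleandoc(indented_template)
print(clean_template)
```

`textwrap.dedent` would not help here, since the first line has no indentation and so the lines share no common leading whitespace.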


In [36]:
def create_prompt_template(prompt_template=template, input_vars=['input', 'table_info', 'top_k']):
    prompt_temp = PromptTemplate(input_variables=input_vars, template=prompt_template) #create our prompt template
    
    return prompt_temp

In [37]:
def create_sql_chain(llm, db, prompt_template, verbose=True):
    """Take in prompt template and input variables, along with the llm and db created from other functions to create our SQL Chain"""
    db_chain = SQLDatabaseChain.from_llm(llm, db, prompt=prompt_template, verbose=verbose, return_intermediate_steps=True) #create our database chain
    
    return db_chain

In [None]:
# def run_chain(db_chain, question, table, sql_dialect='sqlite'):
#     """Function to run chain that will end our overall application function by taking in your question and variables created by the other functions"""
#     db_chain.run(sql_prompt.format(input=question, table_info=table, dialect=sql_dialect))

In [41]:
def sql_analyst(question, vector_db_path='../data/processed/chromadb/schema-table-split', db_root_path='../data/processed/db/', sql_dialect='sqlite', env_api_key_var='hf_token'):
    """take in user info on where the databases and the api key are located and answer their question using the sql chain."""
    
    #query our vector database with the question to get the schema most related to the question.
    top_result = db_select(persist_dir=vector_db_path, query=question)
    top_schema = top_result[0]
    top_table = top_result[1]

    #use this result to establish our database
    db = locate_and_connect_db(filepath=db_root_path, filename=top_schema+'.'+sql_dialect)

    #initialize our large language model - needs an API key - we'll use the standard variables
    llm = load_llm_model(env_variable=env_api_key_var)

    #create prompt template using our preset template
    prompt_template = create_prompt_template(prompt_template=template)

    #setup the sql chain using the standard variables
    sql_chain = create_sql_chain(db=db, llm=llm, prompt_template=prompt_template)

    #run sql_chain on question and return the result (includes the intermediate steps)
    print(top_schema)
    print(prompt_template)
    return sql_chain(question)

#### Test

In [42]:
sql_analyst("How many heads of the departments are older than 56?")

department_management
input_variables=['input', 'table_info', 'top_k'] output_parser=None partial_variables={} template='You are a SQLite expert. Given an input question, first create a syntactically correct SQLite query to run, then look at the results of the query and explain the results in context of the original question.\n            Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per SQLite. You can order the results to return the most informative data in the database.\n            Query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.\n            Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.\n            Pay attention to use date(\'now\') function to get th

OperationalError: (sqlite3.OperationalError) no such column: management.dep
[SQL: SELECT COUNT(*) FROM management WHERE department_ID = '1' AND management.dep]
(Background on this error at: https://sqlalche.me/e/20/e3q8)
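The traceback shows SQLite rejecting the generated query because `management.dep` is not a column. The SQL also looks cut off mid-identifier, though the cause is not visible here. One option is to check a generated query before running it. A minimal sketch using the standard `sqlite3` module: `check_query` is a hypothetical helper, and the single-column `management` table is a stand-in for illustration, not the real schema.

```python
import sqlite3

def check_query(conn, sql):
    """Hypothetical pre-flight check: return None if SQLite can compile the query, else the error text."""
    try:
        conn.execute("EXPLAIN " + sql)  # compiles the statement without running the query itself
    except sqlite3.Error as exc:
        return str(exc)
    return None

# Stand-in table (assumed shape, for illustration only).
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE management (department_ID INTEGER)")

good_sql = "SELECT COUNT(*) FROM management WHERE department_ID = 1"
bad_sql = "SELECT COUNT(*) FROM management WHERE department_ID = '1' AND management.dep"

print(check_query(conn, good_sql))
print(check_query(conn, bad_sql))
```

A check like this would let the calling code retry or report a readable message instead of surfacing the raw `OperationalError`.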