# Custom Agent

I had limited success with the built-in functions, so here I build and test a custom method, which is really a combination of tweaked LangChain functions.

I think part of this could become more specialized schema and table extraction, but for now I test with the same question as before.

In [62]:
from watermark import watermark
print(watermark())

Last updated: 2023-08-01T00:56:17.968275-07:00

Python implementation: CPython
Python version       : 3.11.4
IPython version      : 8.14.0

Compiler    : Clang 15.0.7 
OS          : Darwin
Release     : 22.5.0
Machine     : x86_64
Processor   : i386
CPU cores   : 8
Architecture: 64bit



### Imports

In [2]:
import pandas as pd
import os
from dotenv import load_dotenv
import time

from langchain.vectorstores import Chroma
from langchain.embeddings import HuggingFaceEmbeddings
from langchain import HuggingFaceHub, SQLDatabase
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain

In [61]:
#set this to false to get rid of warning
os.environ["TOKENIZERS_PARALLELISM"] = "false"

### Setup query
We'll use this to input into the model

In [1]:
question = "How many heads of the departments are older than 56?"

### Load Target Tables from Vectordb

Load in the top 3 results; the unique schemas are pulled out of them below.

In [3]:
#setup embeddings using HuggingFace and the directory location
embeddings  = HuggingFaceEmbeddings()
persist_dir = '../data/processed/chromadb/schema-table-info'

# load from disk
vectordb = Chroma(persist_directory=persist_dir, embedding_function=embeddings)

#run similarity search
top_results = vectordb.similarity_search(question, k=3) #just return the top 3 for now

In [4]:
top_results

[Document(page_content='\nSchema: department_management\nTable: head\nColumns: ["head_ID", "name", "born_state", "age"]\nDDL:\n    \nCREATE TABLE head (\n\t"head_ID" INTEGER, \n\tname TEXT, \n\tborn_state TEXT, \n\tage REAL, \n\tPRIMARY KEY ("head_ID")\n)\n\n/*\n3 rows from head table:\nhead_ID\tname\tborn_state\tage\n1\tTiger Woods\tAlabama\t67.0\n2\tSergio García\tCalifornia\t68.0\n3\tK. J. Choi\tAlabama\t69.0\n*/\n', metadata={'schema': 'department_management', 'table': 'head', 'columns': '["head_ID", "name", "born_state", "age"]'}),
 Document(page_content='\nSchema: hr_1\nTable: departments\nColumns: ["DEPARTMENT_ID", "DEPARTMENT_NAME", "MANAGER_ID", "LOCATION_ID"]\nDDL:\n    \nCREATE TABLE departments (\n\t"DEPARTMENT_ID" DECIMAL(4, 0) DEFAULT \'0\' NOT NULL, \n\t"DEPARTMENT_NAME" VARCHAR(30) NOT NULL, \n\t"MANAGER_ID" DECIMAL(6, 0) DEFAULT NULL, \n\t"LOCATION_ID" DECIMAL(4, 0) DEFAULT NULL, \n\tPRIMARY KEY ("DEPARTMENT_ID")\n)\n\n/*\n3 rows from departments table:\nDEPARTMENT_ID\

In [5]:
top_results[0].metadata['schema']

'department_management'

#### Setup Target Schemas

I want to get a unique list of schemas from the top 3 documents returned by the vector database similarity search, and make sure they stay in order of importance.

In [6]:
#try in a standard list with dict.fromkeys
tgt_list = list(dict.fromkeys([doc.metadata['schema'] for doc in top_results]))

tgt_list

['department_management', 'hr_1', 'company_1']
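A minimal sketch of why `dict.fromkeys` is used here rather than a set. The list of hits below is an assumption made up for illustration (the repeated schema is added only so there is something to deduplicate); it is not output from the vector database.

```python
# Hypothetical schema names from a search, best match first.
# The duplicate entry is invented for this example.
hit_schemas = ["department_management", "hr_1", "department_management", "company_1"]

# dict keys keep insertion order (Python 3.7+), so the first occurrence of
# each schema wins and the similarity ranking is preserved.
ordered_unique = list(dict.fromkeys(hit_schemas))
print(ordered_unique)

# A set removes duplicates too, but makes no promise about iteration order,
# so the "most similar first" ranking could be lost.
unordered_unique = set(hit_schemas)
print(unordered_unique == set(ordered_unique))
```

Both approaches keep the same members; only the list built from `dict.fromkeys` is guaranteed to keep the ranking.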

In [9]:
tgt_list[0]

'department_management'

### Tool - Query for Schema Information
Use the langchain SQLDatabase class to get all the information about the schema. I think I also want to use the most likely columns returned to sort and prioritize how the tables are fed into the llm.

I could load in the sql_query tools from langchain.tools, but I want to tweak things a little bit, especially the order. I could come back to this later if it does simplify things and my way doesn't make the outcome any better.

#### Point to Database

In [10]:
db_filepath = '../data/processed/db/'
db_filename = tgt_list[0] + '.sqlite' #use the top-ranked schema from the ordered list

In [11]:
#point to database
base_dir = os.path.dirname(os.path.abspath(db_filepath+db_filename)) #get the full path within the device
db_path = os.path.join(base_dir, db_filename) #combine with filename to get db_path
db = SQLDatabase.from_uri("sqlite:///" + db_path) #connect via the langchain method
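The path handling above can be folded into one small helper. This is only a sketch: `sqlite_uri` is a hypothetical name, and it assumes each schema lives in a file called `<schema>.sqlite` inside the given directory, as in the cell above.

```python
import os

def sqlite_uri(db_dir, schema):
    """Build a SQLAlchemy-style SQLite URI for <db_dir>/<schema>.sqlite."""
    # os.path.join works whether or not db_dir ends with a slash.
    db_path = os.path.abspath(os.path.join(db_dir, schema + ".sqlite"))
    # "sqlite:///" followed by an absolute path is the form used in this notebook.
    return "sqlite:///" + db_path

uri = sqlite_uri("../data/processed/db/", "department_management")
print(uri)
```

The resulting string is what gets passed to `SQLDatabase.from_uri`. Checking `os.path.exists` on the file first would be a reasonable extra guard, since SQLite will otherwise create an empty database at a mistyped path.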

#### Set table sorting and then pull tables names using langchain

While testing prompts in the last 3 notebooks, I noticed that loading in all the tables without any indication of which table the similarity search matched caused some trouble, even for ChatGPT. This got me thinking that simply putting the tables in the order the similarity search ranked them against our question could be a useful change to the prompt. So the next step is to pull in the potential target tables from the search results.

In [14]:
#pull the table names from the vector query so we can have them in order of similarity to the question.
#a list (not a set) keeps the similarity ranking; tables already seen are skipped
table_sorting = []
for doc in top_results:
    schema = doc.metadata['schema']
    table_name = doc.metadata['table']
    if schema == tgt_list[0] and table_name not in table_sorting:
        table_sorting.append(table_name)
print(table_sorting)

['head']


I could just use this one returned table, but for this version of the app I will use this list to sort all of the tables in the schema before feeding them to the LLM, with the matched tables prioritized.

In [15]:
#get tables names
table_names = db.get_usable_table_names()

#sort by the results of our vector query to load the most likely table in first. Not sure if this will make a difference, but can't imagine it would hurt.
priority_tables = sorted(table_names, key=lambda x: (table_sorting.index(x) if x in table_sorting else float('inf'), table_names.index(x)))

priority_tables

['head', 'department', 'management']
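The sort key above is a two-part tuple, which is easier to see in isolation. In this sketch the table names are taken from the example, but the starting order is an assumption: `head` is placed last so the reordering is visible.

```python
# Table names as a database might report them (order assumed for illustration).
table_names = ["department", "management", "head"]
# Tables ranked by the similarity search, best match first.
ranked = ["head"]

# Map each ranked table to its position in the ranking.
rank = {name: i for i, name in enumerate(ranked)}

def sort_key(name):
    # Ranked tables sort by their rank. Unranked tables all get "infinity"
    # and fall back to their original position in the database listing.
    return (rank.get(name, float("inf")), table_names.index(name))

priority_tables = sorted(table_names, key=sort_key)
print(priority_tables)  # ['head', 'department', 'management']
```

Here `head` gets the key `(0, 2)`, while `department` and `management` get `(inf, 0)` and `(inf, 1)`, so the matched table moves to the front and the rest keep their relative order.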

#### Get SQL Dialect

Using the langchain class again

In [16]:
sql_dialect = db.dialect

sql_dialect

'sqlite'

#### Get table info

You can do this for all tables in the database automatically with the function, but I want to continue trying to order these by likelihood of being related to the question. So I'll do this in a loop.

This gives the full create statement and 3 sample rows, which I think will be a great start, and it gets us to the same point the built-in chain reaches when it fills the prompt template. I'll also prefix each entry with the table name.

In [17]:
table_info = []
for table in priority_tables:
    table_name = table
    info = db.get_table_info(table_names=[table])
    table_info.append('DDL for the ' + table_name + ' table:' + info)

#merge the items
tables_summary = '\n\n'.join(table_info)

print(tables_summary)

DDL for the head table:
CREATE TABLE head (
	"head_ID" INTEGER, 
	name TEXT, 
	born_state TEXT, 
	age REAL, 
	PRIMARY KEY ("head_ID")
)

/*
3 rows from head table:
head_ID	name	born_state	age
1	Tiger Woods	Alabama	67.0
2	Sergio García	California	68.0
3	K. J. Choi	Alabama	69.0
*/

DDL for the department table:
CREATE TABLE department (
	"Department_ID" INTEGER, 
	"Name" TEXT, 
	"Creation" TEXT, 
	"Ranking" INTEGER, 
	"Budget_in_Billions" REAL, 
	"Num_Employees" REAL, 
	PRIMARY KEY ("Department_ID")
)

/*
3 rows from department table:
Department_ID	Name	Creation	Ranking	Budget_in_Billions	Num_Employees
1	State	1789	1	9.96	30266.0
2	Treasury	1789	2	11.1	115897.0
3	Defense	1947	3	439.3	3000000.0
*/

DDL for the management table:
CREATE TABLE management (
	"department_ID" INTEGER, 
	"head_ID" INTEGER, 
	temporary_acting TEXT, 
	PRIMARY KEY ("department_ID", "head_ID"), 
	FOREIGN KEY("head_ID") REFERENCES head ("head_ID"), 
	FOREIGN KEY("department_ID") REFERENCES department ("Department_ID"
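To make it clearer what goes into each table summary, here is a rough standard-library approximation of "DDL plus sample rows". It is a sketch, not how langchain implements `get_table_info`: `describe_table` is a hypothetical helper, and the in-memory table is a toy copy of the `head` table shown above.

```python
import sqlite3

# Toy in-memory database mirroring the head table from the output above.
conn = sqlite3.connect(":memory:")
conn.execute(
    'CREATE TABLE head ("head_ID" INTEGER, name TEXT, born_state TEXT, age REAL, '
    'PRIMARY KEY ("head_ID"))'
)
conn.executemany(
    "INSERT INTO head VALUES (?, ?, ?, ?)",
    [(1, "Tiger Woods", "Alabama", 67.0),
     (2, "Sergio García", "California", 68.0),
     (3, "K. J. Choi", "Alabama", 69.0)],
)

def describe_table(conn, table, n_rows=3):
    """Return the stored CREATE statement plus up to n_rows sample rows."""
    # SQLite keeps the original CREATE statement in sqlite_master.
    ddl = conn.execute(
        "SELECT sql FROM sqlite_master WHERE type = 'table' AND name = ?", (table,)
    ).fetchone()[0]
    # Table names cannot be bound as parameters, so the identifier is quoted by hand.
    cursor = conn.execute(f'SELECT * FROM "{table}" LIMIT {int(n_rows)}')
    header = "\t".join(col[0] for col in cursor.description)
    rows = ["\t".join(str(value) for value in row) for row in cursor.fetchall()]
    return "\n".join([ddl, "", header, *rows])

summary = describe_table(conn, "head")
print(summary)
```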

### Create Prompt Template

Create a simple template based on the original provided by the langchain repo.

In [18]:
create_prompt = PromptTemplate(
    input_variables=[],
    template = f"""
        You are a SQL Query Writer. Given an input question, first create a working {sql_dialect} SQL statement to find the answer to an input question and then return only the syntactically correct SQL statement.
        
        Use one or multiple of these tables:
        {tables_summary} 

        Input question: "{question}"
    """
)

In [19]:
print(create_prompt)

input_variables=[] output_parser=None partial_variables={} template='\n        You are a SQL Query Writer. Given an input question, first create a working sqlite SQL statement to find the answer to an input question and then return only the syntactically correct SQL statement.\n        \n        Use one or multiple of these tables:\n        DDL for the head table:\nCREATE TABLE head (\n\t"head_ID" INTEGER, \n\tname TEXT, \n\tborn_state TEXT, \n\tage REAL, \n\tPRIMARY KEY ("head_ID")\n)\n\n/*\n3 rows from head table:\nhead_ID\tname\tborn_state\tage\n1\tTiger Woods\tAlabama\t67.0\n2\tSergio García\tCalifornia\t68.0\n3\tK. J. Choi\tAlabama\t69.0\n*/\n\nDDL for the department table:\nCREATE TABLE department (\n\t"Department_ID" INTEGER, \n\t"Name" TEXT, \n\t"Creation" TEXT, \n\t"Ranking" INTEGER, \n\t"Budget_in_Billions" REAL, \n\t"Num_Employees" REAL, \n\tPRIMARY KEY ("Department_ID")\n)\n\n/*\n3 rows from department table:\nDepartment_ID\tName\tCreation\tRanking\tBudget_in_Billions\tNum_
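The template above is an f-string, so the dialect, table summary, and question are baked in when the cell runs and the prompt cannot be reused for another question. A possible alternative is to leave placeholders in the template and fill them at call time. The sketch below uses plain `str.format` as a stand-in so it runs without langchain; `CREATE_TEMPLATE` and `build_create_prompt` are hypothetical names, and the table text is a shortened placeholder.

```python
# Placeholders stay unfilled until a question arrives, so one template
# can serve every request.
CREATE_TEMPLATE = (
    "You are a SQL Query Writer. Given an input question, first create a working "
    "{dialect} SQL statement to find the answer to an input question and then "
    "return only the syntactically correct SQL statement.\n\n"
    "Use one or multiple of these tables:\n{tables}\n\n"
    'Input question: "{question}"'
)

def build_create_prompt(dialect, tables, question):
    """Fill the template at call time (hypothetical helper, not from the notebook)."""
    return CREATE_TEMPLATE.format(dialect=dialect, tables=tables, question=question)

prompt_text = build_create_prompt(
    dialect="sqlite",
    tables="DDL for the head table: (shortened placeholder)",
    question="How many heads of the departments are older than 56?",
)
print(prompt_text)
```

`PromptTemplate` supports the same idea through named `input_variables`, with the values passed when the chain is run. Only the template string is parsed for placeholders, so braces inside the substituted values do no harm.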

### Setup LLM
I'll test the Google flan-t5-xxl model.

In [20]:
#get api key
load_dotenv()
hf_api_token = os.getenv('hf_token')

#add path to HF repo
repo_id = 'google/flan-t5-xxl'

#establish llm model
llm = HuggingFaceHub(repo_id=repo_id, huggingfacehub_api_token=hf_api_token, model_kwargs={"temperature": 0.1, "max_length": 512})

### 1st Chain - Generate SQL Query

In [21]:
chain = LLMChain(llm=llm, prompt=create_prompt, verbose=False)
sql_query = chain.predict()

In [22]:
sql_query

'SELECT count(*) FROM head WHERE age > 56'

### Chain for checking SQL Query

This is directly from the langchain check query prompt. It looks for common mistakes in the initially written sql query.

In [24]:
validate_prompt = PromptTemplate(
    input_variables=[],
    template = f"""{sql_query}

    Double check the {sql_dialect} query above for common mistakes, including:
    - Using NOT IN with NULL values
    - Using UNION when UNION ALL should have been used
    - Using BETWEEN for exclusive ranges
    - Data type mismatch in predicates
    - Properly quoting identifiers
    - Using the correct number of arguments for functions
    - Casting to the correct data type
    - Using the proper columns for joins

    If there are any of the above mistakes, rewrite the query. If there are no mistakes, just reproduce the original query."""
)

In [25]:
chain = LLMChain(llm=llm, prompt=validate_prompt, verbose=False)
validated_query = chain.predict()

In [26]:
validated_query

'SELECT count(*) FROM head WHERE age > 56'

### Run and Debug the Query

Run the query on the database. If it is successful, send the results to an LLM to interpret them.
If unsuccessful, ask an LLM to debug it and write a new one based on the error message, then try again, only exiting the loop once it works.

The wonderful folks at langchain created a run_no_throw function that returns either the successful results or the error message as a string, rather than raising a full exception.

In [27]:
sql_output = db.run_no_throw(validated_query)

sql_output

'[(5,)]'
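For intuition, the "no throw" behaviour can be approximated with the standard library. This is a sketch of the idea only, not langchain's implementation: `run_no_throw_sketch` is a hypothetical name, and the toy table holds three made-up rows, so the count differs from the real database.

```python
import sqlite3

def run_no_throw_sketch(conn, sql):
    """Return the rows as a string, or an 'Error: ...' string instead of raising."""
    try:
        return str(conn.execute(sql).fetchall())
    except sqlite3.Error as exc:
        return f"Error: {exc}"

# Toy in-memory table with three rows (assumed data for illustration).
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE head (head_ID INTEGER PRIMARY KEY, name TEXT, age REAL)")
conn.executemany(
    "INSERT INTO head VALUES (?, ?, ?)",
    [(1, "Tiger Woods", 67.0), (2, "Sergio García", 68.0), (3, "K. J. Choi", 69.0)],
)

ok = run_no_throw_sketch(conn, "SELECT count(*) FROM head WHERE age > 56")
bad = run_no_throw_sketch(conn, "SELECT count(a) FROM head WHERE age > 56")
print(ok)   # a stringified list of row tuples
print(bad)  # an error message instead of a traceback
```

Because both outcomes come back as strings, the caller can branch on the text rather than wrapping every query in its own try/except.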

In [28]:
#create prompt for answering the question
analyze_prompt = PromptTemplate(
    input_variables=[],
    template = f"""
    You are an expert data analyst. Given an output of an SQL query, first look at the output and then determine the answer to a user question.
    
    SQL Output: "{sql_output}"
    Question: "{question}"

    Describe your answer with one sentence:
    """
)

In [29]:
#test analyze chain
chain = LLMChain(llm=llm, prompt=analyze_prompt, verbose=False)
answer = chain.predict()

In [30]:
answer

'The number of heads of the departments older than 56 is 5.'

In [31]:
#create a query that should error to test off of
bad_query = 'SELECT count(a) FROM head WHERE age > 56'

In [33]:
#test what output looks like
query_error = db.run_no_throw(bad_query)

query_error

'Error: (sqlite3.OperationalError) no such column: a\n[SQL: SELECT count(a) FROM head WHERE age > 56]\n(Background on this error at: https://sqlalche.me/e/20/e3q8)'

In [34]:
#create query that should return an empty result
empty_query = "SELECT name FROM head WHERE head_id = '25'"

In [35]:
test_empty = db.run_no_throw(empty_query)

test_empty

'[]'
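The three outputs seen so far (rows, an error message, and an empty list) are all strings, so the branching that follows can be sketched as a small classifier. `classify_result` is a hypothetical helper; the prefix and empty-list checks assume the output formats shown in the cells above.

```python
def classify_result(result):
    """Map the string returned by the query runner to 'error', 'empty', or 'ok'."""
    text = result.strip()
    if text.startswith("Error"):
        return "error"   # send the query and message to the error-debug prompt
    if text in ("", "[]"):
        return "empty"   # send the query to the empty-result prompt
    return "ok"          # pass the rows on to the analysis prompt

print(classify_result("Error: (sqlite3.OperationalError) no such column: a"))
print(classify_result("[]"))
print(classify_result("[(5,)]"))
```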

In [36]:
#create prompt for debugging error
debug_error_prompt = PromptTemplate(
    input_variables=[],
    template = f"""{bad_query}

The query above produced the following error:

{query_error}

Rewrite the query with the error fixed:"""
)

In [37]:
#create prompt for trying to fix an empty output
debug_empty_prompt = PromptTemplate(
    input_variables=[],
    template = f"""{empty_query}

The query above produced no result. Try rewriting the query so it will return results to this question "{question}":"""
)

In [38]:
#test debug
chain = LLMChain(llm=llm, prompt=debug_error_prompt, verbose=True)
debugged = chain.predict()



> Entering new  chain...
Prompt after formatting:
SELECT count(a) FROM head WHERE age > 56

The query above produced the following error:

Error: (sqlite3.OperationalError) no such column: a
[SQL: SELECT count(a) FROM head WHERE age > 56]
(Background on this error at: https://sqlalche.me/e/20/e3q8)

Rewrite the query with the error fixed:

> Finished chain.


In [39]:
debugged

'SELECT count(*) FROM head WHERE age > 56'

In [40]:
#test empty debug
chain = LLMChain(llm=llm, prompt=debug_empty_prompt, verbose=True)
emp_debugged = chain.predict()



> Entering new  chain...
Prompt after formatting:
SELECT name FROM head WHERE head_id = '25'

The query above produced no result. Try rewriting the query so it will return results to this question "How many heads of the departments are older than 56?":

> Finished chain.


In [41]:
emp_debugged

'SELECT count(*) FROM head WHERE age > 56'

That looks to be working with the steps manually broken out. So now I'll work on creating functions to make this easier to repeat.

## Create Functions

### Selecting database, connecting to it, running all of the prompts, etc.

I want to first build this out as a local application that has multiple entry points into the LLMs, and then work on building this out as a custom agent.

In [42]:
#establish unique chains
create_chain = LLMChain(llm=llm, prompt=create_prompt, verbose=False)

validate_chain = LLMChain(llm=llm, prompt=validate_prompt, verbose=False)

analyze_chain = LLMChain(llm=llm, prompt=analyze_prompt, verbose=False)

debug_error_chain = LLMChain(llm=llm, prompt=debug_error_prompt, verbose=False)

debug_empty_chain = LLMChain(llm=llm, prompt=debug_empty_prompt, verbose=False)

In [43]:
def similar_doc_search(question, vector_db=vectordb, k=3):
    """
    Run similarity search through a vectordb query on the user input question and return the k most similar schema-table combinations in document form.
    """
    similar_docs = vector_db.similarity_search(question, k)

    return similar_docs

In [44]:
def identify_schemas(documents):
    """
    Take in a list of documents created from our similar_doc_search function.
    Use the metadata to build a list of the unique schemas returned by that search,
    keeping them in order of similarity (a set would lose the ranking).
    """
    target_schemas = list(dict.fromkeys(doc.metadata['schema'] for doc in documents))

    return target_schemas

In [45]:
def connect_db(db_path, target_schema):
    """
    Take in the identified schema and connect to the sqlite database with that name
    """
    db_filename = target_schema + '.sqlite'

    #point to database
    full_path = os.path.abspath(os.path.join(db_path, db_filename)) #get the full path within the device
    db = SQLDatabase.from_uri("sqlite:///" + full_path) #connect via the langchain method

    return db

In [46]:
def prioritize_tables(documents, target_schema, sql_database):
    """
    Take in a list of similar_doc_search documents and prioritize tables from the same schema.
    Sort all tables in the database based on the priority list.
    """
    #dict.fromkeys drops duplicates but keeps the similarity ranking (a set would not)
    table_sorting = list(dict.fromkeys(doc.metadata['table'] for doc in documents if doc.metadata['schema'] == target_schema))
    table_names = sql_database.get_usable_table_names()
    table_indices = {table: index for index, table in enumerate(table_sorting)}

    priority_tables = sorted(table_names, key=lambda x: (table_indices.get(x, float('inf')), table_names.index(x)))

    return priority_tables

In [47]:
def get_table_info(tables, database):
    """
    Initialize an empty list, then loop through the list of tables in the database and save the DDL and 3 sample rows for each to a variable.
    """
    table_info = []
    for table in tables:
        table_name = table
        info = database.get_table_info(table_names=[table]) #use the database passed in, not the global db
        table_info.append('DDL for the ' + table_name + ' table:' + info)

    #merge the items
    tables_summary = '\n\n'.join(table_info)

    return tables_summary

In [48]:
def get_sql_dialect(database):
    """
    Return the sql dialect from the database.
    """
    sql_dialect = database.dialect

    return sql_dialect

In [49]:
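def llm_create_sql(sql_dialect, table_info, question, lang_model):
    """
    Build the SQL-writing prompt from the dialect, table info, and question,
    then return the query the LLM generates.
    """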
    create_prompt = PromptTemplate(
    input_variables=[],
    template = f"""
        You are a SQL Query Writer. Given an input question, first create a working {sql_dialect} SQL statement to find the answer to an input question and then return only the syntactically correct SQL statement.
        
        Use one or multiple of these tables:
        {table_info} 

        Input question: "{question}"
    """
    )

    create_chain = LLMChain(llm=lang_model, prompt=create_prompt, verbose=False)

    sql_query = create_chain.predict()

    return sql_query

In [50]:
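def llm_check_sql(sql_query, sql_dialect, lang_model):
    """
    Ask the LLM to double check the query for common mistakes
    and return the (possibly rewritten) query.
    """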
    validate_prompt = PromptTemplate(
    input_variables=[],
    template = f"""{sql_query}

        Double check the {sql_dialect} query above for common mistakes, including:
        - Using NOT IN with NULL values
        - Using UNION when UNION ALL should have been used
        - Using BETWEEN for exclusive ranges
        - Data type mismatch in predicates
        - Properly quoting identifiers
        - Using the correct number of arguments for functions
        - Casting to the correct data type
        - Using the proper columns for joins

        If there are any of the above mistakes, rewrite the query. If there are no mistakes, just reproduce the original query."""
    )

    validate_chain = LLMChain(llm=lang_model, prompt=validate_prompt, verbose=False)

    checked_sql = validate_chain.predict()

    return checked_sql

In [51]:
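def run_sql(database, sql_query):
    """
    Run the query with run_no_throw, returning either the results or the error message as a string.
    """
    query_result = database.run_no_throw(sql_query)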

    return query_result

In [52]:
def llm_debug_error(sql_query, error_message, lang_model):
    """
    To be executed if the SQL query returns an error message
    Will provide the error message to the LLM and ask it to debug, returning a new query
    """
    #setup the debugging prompt
    debug_error_prompt = PromptTemplate(
        input_variables=[],
        template = f"""{sql_query}

    The query above produced the following error:

    {error_message}

    Rewrite the query with the error fixed:"""
    )

    debug_error_chain = LLMChain(llm=lang_model, prompt=debug_error_prompt, verbose=False)

    debugged_query = debug_error_chain.predict()

    return debugged_query

In [53]:
def llm_debug_empty(sql_query, question, lang_model):
    """
    To be executed if the SQL query returns an empty result ('[]').
    Will provide a prompt to the LLM asking it to re-review the original question and write a new query
    """
    debug_empty_prompt = PromptTemplate(
        input_variables=[],
        template = f"""{sql_query}

        The query above produced no result. Try rewriting the query so it will return results to this question "{question}":"""
        )
    
    debug_empty_chain = LLMChain(llm=lang_model, prompt=debug_empty_prompt, verbose=False)

    debugged_query = debug_empty_chain.predict()

    return debugged_query

In [54]:
def llm_analyze(sql_query, query_result, question, lang_model):
    """
    To be executed if the SQL query that was run returns valid results.
    Will provide the full output of the original question, final sql query, and answer.
    """
    analyze_prompt = PromptTemplate(
        input_variables=[],
        template = f"""{query_result}

        You are an expert data analyst. First look at the results of the above SQL output then identify the answer to this question:
        Question: "{question}"

        Then give the answer in one sentence:
        """
        )
    
    analyze_chain = LLMChain(llm=lang_model, prompt=analyze_prompt, verbose=False)
    
    llm_answer = analyze_chain.predict()
    
    output = f"""
            Input question: {question}
            SQL Query: {sql_query}
            Answer: {llm_answer}"""

    return output

In [55]:
def debug_and_run(max_attempts=3, sql_query=validated_query):
    """
    Execute the validated SQL query, retrying up to `max_attempts` times.
    If successful, return the question, the final query, and the query result.
    If not, return an appropriate error message.
    """
    attempt = 1
    while attempt <= max_attempts:
        query_result = db.run_no_throw(sql_query)

        if query_result[:5] == 'Error':
            if attempt == max_attempts:
                output = f"Unable to execute the SQL query after {max_attempts} attempts."
                break

            #pass the failing query and its error message to the LLM (the prebuilt chain only knows about bad_query)
            sql_query = llm_debug_error(sql_query, query_result, llm)
            attempt += 1
            time.sleep(1)  # Add a delay of 1 second between attempts

        elif query_result == '[]':
            if attempt == max_attempts:
                output = "Query returned no results after multiple attempts."
                break

            #pass the current query, not the hard-coded empty_query
            sql_query = llm_debug_empty(sql_query, question, llm)
            attempt += 1
            time.sleep(1)  # Add a delay of 1 second between attempts

        else:
            #report the raw query result; llm_analyze can turn it into a sentence
            output = f"""
            Input question: {question}
            SQL Query: {sql_query}
            Answer: {query_result}"""
            break

    return output

In [56]:
validated_query

'SELECT count(*) FROM head WHERE age > 56'

In [57]:
output = debug_and_run()

print(output)


            Input question: How many heads of the departments are older than 56?
            SQL Query: SELECT count(*) FROM head WHERE age > 56
            Answer: [(5,)]
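The retry logic can also be written with the runner and the fixers passed in as arguments, which makes it testable without a database or an LLM. This is a sketch under that assumption: `run_with_retries` and the `fake_*` stand-ins are hypothetical names, not part of the notebook.

```python
def run_with_retries(sql_query, run_query, fix_error, fix_empty, max_attempts=3):
    """Run sql_query, asking a fixer for a new query after each failure.

    run_query(sql) -> result string
    fix_error(sql, error_message) -> new sql
    fix_empty(sql) -> new sql
    Returns (final_sql, result); result is None if every attempt failed.
    """
    for attempt in range(1, max_attempts + 1):
        result = run_query(sql_query)
        failed_with_error = result.startswith("Error")
        if not failed_with_error and result != "[]":
            return sql_query, result
        if attempt == max_attempts:
            break
        # Always hand the fixer the query that just failed.
        if failed_with_error:
            sql_query = fix_error(sql_query, result)
        else:
            sql_query = fix_empty(sql_query)
    return sql_query, None

# Stand-ins for the database and the LLM so the loop can be run offline.
def fake_run(sql):
    if "count(a)" in sql:
        return "Error: no such column: a"
    return "[(5,)]"

def fake_fix_error(sql, error_message):
    return sql.replace("count(a)", "count(*)")

def fake_fix_empty(sql):
    return sql

final_sql, result = run_with_retries(
    "SELECT count(a) FROM head WHERE age > 56", fake_run, fake_fix_error, fake_fix_empty
)
print(final_sql)
print(result)
```

In the notebook's terms, `run_query` would wrap `run_sql`, and the two fixers would wrap `llm_debug_error` and `llm_debug_empty` with the model and question already bound.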


## Observations

This works for simple questions, and there is a fun roadmap of improvements to work on. I'm happy with it as a proof of concept, so next I'll work on building this into .py files that can run through the terminal.