# Custom Agent

I had limited success with the built-in functions, so I'll test a custom flow, really a combination of tweaked langchain functions.

I think part of this could be more specialized schema and table extraction, but for now I think I'll stick with my single schema example and try to expand later.

### Imports

In [1]:
import pandas as pd
import os
from dotenv import load_dotenv
import time

from langchain.vectorstores import Chroma
from langchain.embeddings import HuggingFaceEmbeddings
from langchain import HuggingFaceHub, SQLDatabase
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain

### Setup query
We'll use this to input into the model

In [81]:
question = "Using the 'bike_1' schema, how many trips did not end in San Francisco?"

### Load Target Tables from Vectordb

Load in the top 3 results and save just the unique schemas to a set.

In [82]:
#setup embeddings using HuggingFace and the directory location
embeddings  = HuggingFaceEmbeddings()
persist_dir = '../data/processed/chromadb/schema-table-split'

# load from disk
vectordb = Chroma(persist_directory=persist_dir, embedding_function=embeddings)

#run similarity search
top_results = vectordb.similarity_search(question, k=3) #just return the top for now

In [83]:
top_results

[Document(page_content='Schema Table: bike 1 trip', metadata={'schema': 'bike_1', 'table': 'trip', 'columns': '["id", "duration", "start_date", "start_station_name", "start_station_id", "end_date", "end_station_name", "end_station_id", "bike_id", "subscription_type", "zip_code"]'}),
 Document(page_content='Schema Table: bike 1 station', metadata={'schema': 'bike_1', 'table': 'station', 'columns': '["id", "name", "lat", "long", "dock_count", "city", "installation_date"]'}),
 Document(page_content='Schema Table: bike 1 status', metadata={'schema': 'bike_1', 'table': 'status', 'columns': '["station_id", "bikes_available", "docks_available", "time"]'})]

In [84]:
top_results[0].metadata['schema']

'bike_1'

#### Setup Target Schemas

Sets are unordered, lists aren't automatically unique. I want a unique set of values in order of how they are pulled from the document strings. In some tests, it's giving me the 2nd most relevant schema first. I need to find a way to make sure it always provides the schema names in the order provided by the similarity search.

In [85]:
#try in a standard list with dict.fromkeys
tgt_list = list(dict.fromkeys([doc.metadata['schema'] for doc in top_results]))

tgt_list

['bike_1']

In [86]:
target_schemas = set(doc.metadata['schema'] for doc in top_results)

target_schemas

{'bike_1'}

In [87]:
target_schemas = set()
for doc in top_results:
    schema = doc.metadata['schema']
    target_schemas.add(schema)

print(target_schemas)

{'bike_1'}


In [88]:
list(target_schemas)[0]

'bike_1'

### Tool - Query for Schema Information
Use the langchain SQLDatabase class to get all the information about the schema. I think I also want to use the most likely columns returned to sort and prioritize how the tables are fed into the llm.

I could load in the sql_query tools from langchain.tools, but I want to tweek things a little bit, especially in the order. I could come back to this later if it does simplify things and my way doesn't make the outcome any better.

#### Point to Database

In [89]:
db_filepath = '../data/processed/db/'
db_filename = list(target_schemas)[0] + '.sqlite'

In [90]:
#point to database
base_dir = os.path.dirname(os.path.abspath(db_filepath+db_filename)) #get the full path within the device
db_path = os.path.join(base_dir, db_filename) #combine with filename to get db_path
db = SQLDatabase.from_uri("sqlite:///" + db_path) #connect via the lanchain method

#### Set table sorting and then pull tables names using langchain

In [92]:
#pull the table names from the vector query so we can have them in order of similarity to the question.
table_sorting = set()
for doc in top_results:
    schema = doc.metadata['schema']
    if schema == 'bike_1':
        table_name = doc.metadata['table']
        table_sorting.add(table_name)
table_sorting = list(table_sorting)
print(table_sorting)

['status', 'trip', 'station']


In [93]:
#get tables names
table_names = db.get_usable_table_names()

#sort by the results of our vector query to load the most likely table in first. Not sure if this will make a difference, but can't imagine it would hurt.
priority_tables = sorted(table_names, key=lambda x: (table_sorting.index(x) if x in table_sorting else float('inf'), table_names.index(x)))

priority_tables

['status', 'trip', 'station', 'weather']

#### Get SQL Dialect

Using the langchain class again

In [94]:
sql_dialect = db.dialect

sql_dialect

'sqlite'

#### Get table info

You can do this for all tables in the database automatically with the function, but I want to continue trying to order these by likelihood of being related to the question. So I'll do this in a loop.

This gives the full create statement and 3 sample rows, which I think will be a great start. And really gets us to where the Chain does for adding to the prompt template. I'll also combine this with the table name at the start of it.

In [95]:
table_info = []
for table in priority_tables:
    table_name = table
    info = db.get_table_info(table_names=[table])
    table_info.append('DDL for the ' + table_name + ' table:' + info)

#merge the items
tables_summary = '\n\n'.join(table_info)

print(tables_summary)

DDL for the status table:
CREATE TABLE status (
	station_id INTEGER, 
	bikes_available INTEGER, 
	docks_available INTEGER, 
	time TEXT, 
	FOREIGN KEY(station_id) REFERENCES station (id)
)

/*
3 rows from status table:
station_id	bikes_available	docks_available	time
3	12	3	2015-06-02 12:46:02
3	12	3	2015-06-02 12:47:02
3	12	3	2015-06-02 12:48:02
*/

DDL for the trip table:
CREATE TABLE trip (
	id INTEGER, 
	duration INTEGER, 
	start_date TEXT, 
	start_station_name TEXT, 
	start_station_id INTEGER, 
	end_date TEXT, 
	end_station_name TEXT, 
	end_station_id INTEGER, 
	bike_id INTEGER, 
	subscription_type TEXT, 
	zip_code INTEGER, 
	PRIMARY KEY (id)
)

/*
3 rows from trip table:
id	duration	start_date	start_station_name	start_station_id	end_date	end_station_name	end_station_id	bike_id	subscription_type	zip_code
900504	384	8/21/2015 17:03	Howard at 2nd	63	8/21/2015 17:10	San Francisco Caltrain 2 (330 Townsend)	69	454	Subscriber	94041
900505	588	8/21/2015 17:03	South Van Ness at Market	66	8/

### Create Prompt Template

In [96]:
create_prompt = PromptTemplate(
    input_variables=[],
    template = f"""
        You are a SQL Query Writer. Given an input question, first create a working {sql_dialect} SQL statement to find the answer to an input question and then return only the syntactically correct SQL statement.
        
        Use one or multiple of these tables:
        {tables_summary} 

        Input question: "{question}"
    """
)

In [100]:
print(create_prompt)

input_variables=[] output_parser=None partial_variables={} template='\n        You are a SQL Query Writer. Given an input question, first create a working sqlite SQL statement to find the answer to an input question and then return only the syntactically correct SQL statement.\n        \n        Use one or multiple of these tables:\n        DDL for the status table:\nCREATE TABLE status (\n\tstation_id INTEGER, \n\tbikes_available INTEGER, \n\tdocks_available INTEGER, \n\ttime TEXT, \n\tFOREIGN KEY(station_id) REFERENCES station (id)\n)\n\n/*\n3 rows from status table:\nstation_id\tbikes_available\tdocks_available\ttime\n3\t12\t3\t2015-06-02 12:46:02\n3\t12\t3\t2015-06-02 12:47:02\n3\t12\t3\t2015-06-02 12:48:02\n*/\n\nDDL for the trip table:\nCREATE TABLE trip (\n\tid INTEGER, \n\tduration INTEGER, \n\tstart_date TEXT, \n\tstart_station_name TEXT, \n\tstart_station_id INTEGER, \n\tend_date TEXT, \n\tend_station_name TEXT, \n\tend_station_id INTEGER, \n\tbike_id INTEGER, \n\tsubscriptio

### Setup LLM
I'll test the google flan-t5 model

In [97]:
#get api key
load_dotenv()
hf_api_token = os.getenv('hf_token')

#add path to HF repo
repo_id = 'google/flan-t5-xxl'

#establish llm model
llm = HuggingFaceHub(repo_id=repo_id, huggingfacehub_api_token=hf_api_token, model_kwargs={"temperature": 0.1, "max_length": 512})

### 1st Chain - Generate SQL Query

In [102]:
chain = LLMChain(llm=llm, prompt=create_prompt, verbose=False)
sql_query = chain.predict()

ValueError: Error raised by inference API: Input validation error: `inputs` must have less than 1024 tokens. Given: 1276

### Create Try-Except logic for flagging inputs with too many tokens.

I think with this I can have the program print the error so the user knows what's happening but keep the loop moving so it doesn't totally break.

In [104]:
try:
    sql_query = chain.predict()
except ValueError as e:
    print(e)
finally:
    print('keeping it moving')

Error raised by inference API: Input validation error: `inputs` must have less than 1024 tokens. Given: 1276
keeping it moving


In [38]:
sql_query

'SELECT Count ( * ) FROM head'

### Chain for checking SQL Query

This is directly from the langchain check query prompt

In [20]:
validate_prompt = PromptTemplate(
    input_variables=[],
    template = f"""{sql_query}

    Double check the {sql_dialect} query above for common mistakes, including:
    - Using NOT IN with NULL values
    - Using UNION when UNION ALL should have been used
    - Using BETWEEN for exclusive ranges
    - Data type mismatch in predicates
    - Properly quoting identifiers
    - Using the correct number of arguments for functions
    - Casting to the correct data type
    - Using the proper columns for joins

    If there are any of the above mistakes, rewrite the query. If there are no mistakes, just reproduce the original query."""
)

In [39]:
chain = LLMChain(llm=llm, prompt=validate_prompt, verbose=False)
validated_query = chain.predict()

In [40]:
validated_query

'SELECT count(*) FROM head'

### Run and Debug the Query

Run the query on the database, if it is succesful, send the results to an LLM to interpret them.
If unsuccessful, ask and LLM to debug and write a new one based on the error message. And then try again, only exiting the loop once it works.

The wonderful folks at langchain created a run_no_throw funciton that returns the successful results or the error message. Not a full big error.

In [41]:
sql_output = db.run_no_throw(validated_query)

sql_output

'[(10,)]'

In [45]:
#create prompt for answering the question
analyze_prompt = PromptTemplate(
    input_variables=[],
    template = f"""
    You are an expert data analyst. Given an output of an SQL query, first look at the output and then determine the answer to a user question.
    
    SQL Output: "{sql_output}"
    Question: "{question}"

    Describe your answer with one sentence:
    """
)

In [46]:
#test analyze chain
chain = LLMChain(llm=llm, prompt=analyze_prompt, verbose=False)
answer = chain.predict()

In [47]:
answer

"There are 10 rows in the 'head' table in the 'department_management"

In [47]:
#create a query that should error
bad_query = 'SELECT count(a) FROM head'

In [48]:
query_error = db.run_no_throw(bad_query)

query_error

'Error: (sqlite3.OperationalError) no such column: a\n[SQL: SELECT count(a) FROM head]\n(Background on this error at: https://sqlalche.me/e/20/e3q8)'

In [49]:
#create query that should return a null avlue
empty_query = "SELECT name FROM head WHERE head_id = '25'"

In [50]:
test_empty = db.run_no_throw(empty_query)

test_empty

'[]'

In [51]:
#create prompt for debugging error
debug_error_prompt = PromptTemplate(
    input_variables=[],
    template = f"""{bad_query}

The query above produced the following error:

{query_error}

Rewrite the query with the error fixed:"""
)

In [52]:
#create prompt for debugging error
debug_empty_prompt = PromptTemplate(
    input_variables=[],
    template = f"""{empty_query}

The query above produced no result. Try rewriting the query so it will return results to this question "{question}":"""
)

In [53]:
#test debug
chain = LLMChain(llm=llm, prompt=debug_error_prompt, verbose=True)
debugged = chain.predict()



[1m> Entering new  chain...[0m
Prompt after formatting:
[32;1m[1;3mSELECT count(a) FROM head

The query above produced the following error:

Error: (sqlite3.OperationalError) no such column: a
[SQL: SELECT count(a) FROM head]
(Background on this error at: https://sqlalche.me/e/20/e3q8)

Rewrite the query with the error fixed:[0m

[1m> Finished chain.[0m


In [54]:
debugged

'SELECT count(*) FROM head'

In [55]:
#test empty debug
chain = LLMChain(llm=llm, prompt=debug_empty_prompt, verbose=True)
emp_debugged = chain.predict()



[1m> Entering new  chain...[0m
Prompt after formatting:
[32;1m[1;3mSELECT name FROM head WHERE head_id = '25'

The query above produced no result. Try rewriting the query so it will return results to this question "How many rows in the 'head' table in the 'department_management' schema?":[0m

[1m> Finished chain.[0m


In [119]:
emp_debugged

'SELECT Count ( * ) FROM head'

## Create Functions

### Selecting database, connecting to it, running all of the prompts, etc.

I want to first build this out as a local application that has multiple entry point into the LLMs, and then work on building this out as a custom agent.

In [69]:
#establish unique chains
create_chain = LLMChain(llm=llm, prompt=create_prompt, verbose=False)

validate_chain = LLMChain(llm=llm, prompt=validate_prompt, verbose=False)

analyze_chain = LLMChain(llm=llm, prompt=analyze_prompt, verbose=False)

debug_error_chain = LLMChain(llm=llm, prompt=debug_error_prompt, verbose=False)

debug_empty_chain = LLMChain(llm=llm, prompt=debug_empty_prompt, verbose=False)

In [83]:
def similar_doc_search(question, vector_db=vectordb, k=3):
    """
    Run similarity search through a vectordb query on the user input question and return the k most similar schema-table combinations in document form..
    """
    similar_docs = vector_db.similarity_search(question, k)

    return similar_docs

In [81]:
def identify_schemas(documents):
    """
    Take in a list of documents created from our similar_doc_search function.
    Use the metadata.
    And convert this into a set of unique schemas returned by that search.
    """
    target_schemas = set(doc.metadata['schema'] for doc in documents)

    return target_schemas

In [1]:
def connect_db(db_path, target_schema):
    """
    Take in the identified schema and connect to the sqlite database with that name
    """
    db_filepath = db_path
    db_filename = target_schema + '.sqlite'

    #point to database
    base_dir = os.path.dirname(os.path.abspath(db_filepath+db_filename)) #get the full path within the device
    db_path = os.path.join(base_dir, db_filename) #combine with filename to get db_path
    db = SQLDatabase.from_uri("sqlite:///" + db_path) #connect via the lanchain method

    return db

In [80]:
def prioritize_tables(documents, target_schema, sql_database):
    """
    Take in a list of similar_doc_search documents and prioritize tables from the same schema.
    Sort all tables in the database based on the priority list.
    """
    table_sorting = set(doc.metadata['table'] for doc in documents if doc.metadata['schema'] == target_schema)
    table_names = sql_database.get_usable_table_names()
    table_indices = {table: index for index, table in enumerate(table_sorting)}

    priority_tables = sorted(table_names, key=lambda x: (table_indices.get(x, float('inf')), table_names.index(x)))

    return priority_tables

In [77]:
def get_table_info(tables, database):
    """
    Initialize an empty list, then loop through the list of tables in the database and save the DDL and 3 sample rows for each to a variable.
    """
    table_info = []
    for table in tables:
        table_name = table
        info = db.get_table_info(table_names=[table])
        table_info.append('DDL for the ' + table_name + ' table:' + info)

    #merge the items
    tables_summary = '\n\n'.join(table_info)

    return tables_summary

In [93]:
def get_sql_dialect(database):
    """
    Return the sql dialaect from the database.
    """
    sql_dialect = database.dialect

    return sql_dialect

In [114]:
def llm_create_sql(sql_dialect, table_info, question, lang_model):
    create_prompt = PromptTemplate(
    input_variables=[],
    template = f"""
        You are a SQL Query Writer. Given an input question, first create a working {sql_dialect} SQL statement to find the answer to an input question and then return only the syntactically correct SQL statement.
        
        Use one or multiple of these tables:
        {table_info} 

        Input question: "{question}"
    """
    )

    create_chain = LLMChain(llm=lang_model, prompt=create_prompt, verbose=False)

    sql_query = create_chain.predict()

    return sql_query

In [115]:
def llm_check_sql(sql_query, sql_dialect, lang_model):
    validate_prompt = PromptTemplate(
    input_variables=[],
    template = f"""{sql_query}

        Double check the {sql_dialect} query above for common mistakes, including:
        - Using NOT IN with NULL values
        - Using UNION when UNION ALL should have been used
        - Using BETWEEN for exclusive ranges
        - Data type mismatch in predicates
        - Properly quoting identifiers
        - Using the correct number of arguments for functions
        - Casting to the correct data type
        - Using the proper columns for joins

        If there are any of the above mistakes, rewrite the query. If there are no mistakes, just reproduce the original query."""
    )

    validate_chain = LLMChain(llm=lang_model, prompt=validate_prompt, verbose=False)

    checked_sql = validate_chain.predict()

    return checked_sql

In [111]:
def run_sql(database, sql_query):
    query_result = database.run_no_throw(sql_query)

    return query_result

In [116]:
def llm_debug_error(sql_query, error_message, lang_model):
    """
    To be executed if the SQL query returns an error message
    Will provide the error message to the LLM and ask it to debug, returning a new query
    """
    #setup the debugging prompt
    debug_error_prompt = PromptTemplate(
        input_variables=[],
        template = f"""{sql_query}

    The query above produced the following error:

    {error_message}

    Rewrite the query with the error fixed:"""
    )

    debug_error_chain = LLMChain(llm=lang_model, prompt=debug_error_prompt, verbose=False)

    debugged_query = debug_error_chain.predict()

    return debugged_query

In [117]:
def llm_debug_empty(sql_query, question, lang_model):
    """
    To be executed if the SQL query returns an empty string.
    Will provide a prompt to the LLM asking it to re-review the original question and data and write a new query
    """
    debug_empty_prompt = PromptTemplate(
        input_variables=[],
        template = f"""{sql_query}

        The query above produced no result. Try rewriting the query so it will return results to this question "{question}":"""
        )
    
    debug_empty_chain = LLMChain(llm=lang_model, prompt=debug_empty_prompt, verbose=False)

    debugged_query = debug_empty_chain.predict()

    return debugged_query

In [118]:
def llm_analyze(sql_query, query_result, question, lang_model):
    """
    To be executed if the run SQL query returns a valid results.
    Will provide the full output of the original question, final sql query, and answer.
    """
    analyze_prompt = PromptTemplate(
        input_variables=[],
        template = f"""{query_result}

        You are an expert data analyst. First look at the results of the above SQL output then identify the answer to this question:
        Question: "{question}"

        Then the answer in one sentence:
        """
        )
    
    analyze_chain = LLMChain(llm=lang_model, prompt=analyze_prompt, verbose=False)
    
    llm_answer = analyze_chain.predict()
    
    output = f"""
            Input question: {question}
            SQL Query: {sql_query}
            Answer: {llm_answer}"""

    return output

In [101]:
def debug_and_run(max_attempts=3, sql_query=validated_query):
    """
    Execute the validated SQL query and try `max_attempts` times.
    If successful, return the query result.
    If not, raise an exception or return an appropriate error message.
    """
    attempt = 1
    while attempt <= max_attempts:
        query_result = db.run_no_throw(sql_query)

        if query_result[:5] == 'Error':
            if attempt == max_attempts:
                output = f"Unable to execute the SQL query after {max_attempts} attempts."
                break

            sql_query = debug_error_chain.predict()
            attempt += 1
            time.sleep(1)  # Add a delay of 1 second between attempts


        elif query_result == '[]':
            if attempt == max_attempts:
                output = "Query returned no results after multiple attempts."
                break

            sql_query = debug_empty_chain.predict()
            attempt +=1
            time.sleep(1)  # Add a delay of 1 second between attempts

        else:
            answer = analyze_chain.predict()
            output = f"""
            Input question: {question}
            SQL Query: {sql_query}
            Answer: {query_result}"""
            break

    return output

In [62]:
validated_query

'SELECT count(*) FROM head'

In [67]:
output = debug_and_run()

print(output)


            Input question: How many rows in the 'head' table in the 'department_management' schema?
            SQL Query: SELECT count(*) FROM head
            Answer: 10


## Create Agent

https://python.langchain.com/docs/modules/agents/

Two types ->
- Action: at each step, decide on the next action using the outputs of the previous steps
- Plan & Execute: decide on the full sequence up front, then execute them all without updating the plan

Plan & Execute is better for larger problems because it maintains focus. Often it's best to combine and let P&E agents use action agents to execute plans.

Tools -> actions an agent can take
Toolkit -> wrapper around collections of tools for specifc use cases. Will need this for sql, tool to inspect tables and tool to run queries.

#### Action Agents

Wrapped in agent executors - call the agent, get back an action and action input, call the referenced tool, take the output of it and return it back to the agent.

Typically has a prompt template, language model, and output parser

#### Plan-and-Execute Agents

Typically have the language model be the planner and the executor be an action agent.

#### Built-In Agent Observations

It appears the built-in model using an action agent - I'll try to tweak it. Maybe even by splitting out the steps into agents and trying the combo type with a larger plan & execute operating triggering the action agents.

### Agent Imports

In [7]:
from langchain.agents import initialize_agent, AgentType, load_tools, create_sql_agent
from langchain.agents.agent_toolkits import SQLDatabaseToolkit
from langchain.prompts import PromptTemplate

### Setup Toolkit

I'll use the standard langchain one.

In [8]:
toolkit = SQLDatabaseToolkit(db=db, llm=llm)

### Create Prompt Template - Create SQL Query

In [16]:
create_prompt = PromptTemplate(
    input_variables={"tool_names"},
    template="""
        Use the following format:
        Question: the input question you must answer
        Thought: you should always think about what to do
        Action: the action to take, should be one of [{tool_names}]
        Action Input: the input to the action
        Observation: the result of the action
        Thought: I now know the final answer
        Final Answer: SQL query ONLY, always include columns "Indexes", "Index_Name" along with your final answer. 
        If column name is a reserve word add quotes around it
        """
)

In [17]:
agent_exec = create_sql_agent(llm=llm,
                                toolkit=toolkit,
                                verbose=True,
                                format_instructions=create_prompt,
                                top_k=5)

In [18]:
agent_exec("How many heads of the departments are older than 56?")



[1m> Entering new  chain...[0m
[32;1m[1;3mAction: sql_db_query
Action Input: [0m
Observation: [36;1m[1;3m[0m
Thought:[32;1m[1;3m I now know the final answer
Final Answer: 

SELECT COUNT(*) FROM department WHERE age[0m

[1m> Finished chain.[0m


{'input': 'How many heads of the departments are older than 56?',
 'output': 'SELECT COUNT(*) FROM department WHERE age'}

### Create Prompt Template - Check the SQL Query

From here: https://canvasapp.com/blog/text-to-sql-in-production

In [25]:
prompt = PromptTemplate(
    input_variables=["question", "sql_dialect", "current_day", "sql"],
    template="""
        You are a helpful AI that verifies that a SQL query runs correctly. If the query does not run successfully, you iteratively update the query based on the error message.
        You follow the following procedure:
            1. Run the input SQL query using the QueryTool
            2. If this runs successfully, return immediately.
            3. If the SQL query fails, debug the query:
                3a. consider the error message
                3b. update the SOL query
                3c. try re-running
            4. Repeat step 3 until you have a valid result. Finish and exit with the updated SQL query.
        You are writing SQL for the {sql_dialect} dialect.
        The current day is {current_day} in case the user references the date
        The user's original question: {question}
        The SQL: {sql}
        Begin!
    """
)

In [5]:
i = 4
while True:
    if i == 3:
        output = 3
        break
    elif i < 7:
        i += 1
    else:
        output = "Didn't find that value, sorry!"
        break

In [6]:
output

"Didn't find that value, sorry!"