# Custom Agent

I had limited success with the built-in functions, so I'll test a custom flow: really a combination of tweaked langchain functions.

Part of the fix could be more specialized schema and table extraction, but for now I'll stick with my single-schema example and try to expand later.

### Imports

In [505]:
import pandas as pd
import os
from dotenv import load_dotenv
import time

from langchain.vectorstores import Chroma
from langchain.embeddings import HuggingFaceEmbeddings
from langchain import HuggingFaceHub, SQLDatabase
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain

### Setup query
We'll feed this question into the model.

In [401]:
question = "How many rows in the 'head' table in the 'department_management' schema?"

### Load Target Tables from Vectordb

Load in the top 3 results and save just the unique schemas to a set.

In [402]:
#set up embeddings using HuggingFace and the directory location
embeddings = HuggingFaceEmbeddings()
persist_dir = '../data/processed/chromadb/schema-table-split'

# load from disk
vectordb = Chroma(persist_directory=persist_dir, embedding_function=embeddings)

#run similarity search
top_results = vectordb.similarity_search(question, k=3) #return the top 3 matches for now

In [513]:
top_results[0].metadata['schema']

'department_management'

#### Setup Target Schemas

In [404]:
target_schemas = set()
for doc in top_results:
    schema = doc.metadata['schema']
    target_schemas.add(schema)

print(target_schemas)

{'department_management'}


In [405]:
list(target_schemas)[0]

'department_management'

### Tool - Query for Schema Information
Use the langchain SQLDatabase class to get all the information about the schema. I also want to use the most likely tables returned by the vector search to sort and prioritize how the tables are fed into the LLM.

I could load the sql_query tools from langchain.tools, but I want to tweak a few things, especially the order. I can come back to them later if they simplify things and my approach doesn't improve the outcome.

#### Point to Database

In [406]:
db_filepath = '../data/processed/db/'
db_filename = list(target_schemas)[0] + '.sqlite'

In [407]:
#point to database
base_dir = os.path.dirname(os.path.abspath(db_filepath+db_filename)) #get the full path within the device
db_path = os.path.join(base_dir, db_filename) #combine with filename to get db_path
db = SQLDatabase.from_uri("sqlite:///" + db_path) #connect via the langchain method
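The dirname/join round trip above is equivalent to taking the absolute path of the joined file path. A minimal sketch of that shortcut; the helper name `sqlite_uri` is hypothetical, and it assumes the SQLAlchemy-style `sqlite:///` prefix the cell already uses:

```python
import os


def sqlite_uri(db_dir: str, db_name: str) -> str:
    """Hypothetical helper: build a SQLite connection URI from a folder and a file name."""
    # abspath makes the path absolute and normalizes "..", which is
    # what the dirname/join round trip in the cell above achieves.
    db_path = os.path.abspath(os.path.join(db_dir, db_name))
    # An absolute POSIX path starts with "/", so the result has four slashes.
    return "sqlite:///" + db_path


uri = sqlite_uri("../data/processed/db/", "department_management.sqlite")
print(uri)
```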

#### Set table sorting and then pull table names using langchain

In [408]:
#pull the table names from the vector query so we can have them in order of similarity to the question.
#a set doesn't keep insertion order, so dedupe with a list instead
table_sorting = []
for doc in top_results:
    schema = doc.metadata['schema']
    if schema == list(target_schemas)[0]:
        table_name = doc.metadata['table']
        if table_name not in table_sorting:
            table_sorting.append(table_name)
print(table_sorting)

['management', 'head', 'department']
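Collecting table names into a `set` loses the similarity ranking, because sets don't keep insertion order. A sketch of an order-preserving dedupe; plain dicts stand in for each search result's `metadata`, and the function name and sample values are hypothetical:

```python
def tables_in_rank_order(results_metadata, schema):
    """Hypothetical helper: table names for one schema, deduplicated, in search-rank order."""
    names = (m["table"] for m in results_metadata if m["schema"] == schema)
    # dict keys keep insertion order (Python 3.7+), so dict.fromkeys drops
    # duplicates while keeping the first (highest-ranked) occurrence.
    return list(dict.fromkeys(names))


# Made-up metadata, ordered from most to least similar.
sample = [
    {"schema": "department_management", "table": "head"},
    {"schema": "department_management", "table": "management"},
    {"schema": "other_schema", "table": "students"},
    {"schema": "department_management", "table": "head"},
]
print(tables_in_rank_order(sample, "department_management"))  # ['head', 'management']
```

With the real search results, the input would be `[doc.metadata for doc in top_results]`.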


In [409]:
#get table names
table_names = db.get_usable_table_names()

#sort by the results of our vector query to load the most likely table first. Not sure if this will make a difference, but I can't imagine it would hurt.
priority_tables = sorted(table_names, key=lambda x: (table_sorting.index(x) if x in table_sorting else float('inf'), table_names.index(x)))

priority_tables

['management', 'head', 'department']
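The sort key above calls `.index()` on lists for every element. A rank dictionary does the same lookup in constant time, and since Python's `sorted` is stable, tables the vector search didn't return keep their original relative order without a secondary key. A sketch with made-up table names (`budget` is hypothetical):

```python
def prioritize(table_names, ranked):
    """Hypothetical helper: ranked tables first (in rank order), then the rest unchanged."""
    rank = {name: i for i, name in enumerate(ranked)}
    # Unranked tables get len(ranked), which sorts after every ranked table;
    # the stable sort keeps their original order among themselves.
    return sorted(table_names, key=lambda name: rank.get(name, len(ranked)))


print(prioritize(["department", "head", "management", "budget"], ["head", "management"]))
# ['head', 'management', 'department', 'budget']
```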

#### Get SQL Dialect

Using the langchain class again

In [410]:
sql_dialect = db.dialect

sql_dialect

'sqlite'

#### Get table info

You can do this for all tables in the database at once with get_table_info(), but I want to keep ordering the tables by how likely they are to relate to the question, so I'll call it in a loop.

This gives the full CREATE statement and 3 sample rows for each table, which I think will be a great start. It also gets us to roughly the same point the built-in chain reaches before it fills in the prompt template. I'll prefix each entry with the table name.

In [411]:
table_info = []
for table in priority_tables:
    info = db.get_table_info(table_names=[table])
    table_info.append('DDL for the ' + table + ' table:' + info)

#merge the items
tables_summary = '\n\n'.join(table_info)

tables_summary

'DDL for the management table:\nCREATE TABLE management (\n\t"department_ID" INTEGER, \n\t"head_ID" INTEGER, \n\ttemporary_acting TEXT, \n\tPRIMARY KEY ("department_ID", "head_ID"), \n\tFOREIGN KEY("head_ID") REFERENCES head ("head_ID"), \n\tFOREIGN KEY("department_ID") REFERENCES department ("Department_ID")\n)\n\n/*\n3 rows from management table:\ndepartment_ID\thead_ID\ttemporary_acting\n2\t5\tYes\n15\t4\tYes\n2\t6\tYes\n*/\n\nDDL for the head table:\nCREATE TABLE head (\n\t"head_ID" INTEGER, \n\tname TEXT, \n\tborn_state TEXT, \n\tage REAL, \n\tPRIMARY KEY ("head_ID")\n)\n\n/*\n3 rows from head table:\nhead_ID\tname\tborn_state\tage\n1\tTiger Woods\tAlabama\t67.0\n2\tSergio García\tCalifornia\t68.0\n3\tK. J. Choi\tAlabama\t69.0\n*/\n\nDDL for the department table:\nCREATE TABLE department (\n\t"Department_ID" INTEGER, \n\t"Name" TEXT, \n\t"Creation" TEXT, \n\t"Ranking" INTEGER, \n\t"Budget_in_Billions" REAL, \n\t"Num_Employees" REAL, \n\tPRIMARY KEY ("Department_ID")\n)\n\n/*\n3 ro
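For SQLite, what `get_table_info` returns is roughly the table's CREATE statement plus a few sample rows. A stdlib-only approximation, useful for seeing what ends up in the prompt; the helper name and toy data are hypothetical (the first three rows echo the output above, the fourth is a placeholder), and LangChain renders its DDL through SQLAlchemy, so its exact formatting differs:

```python
import sqlite3


def table_summary(conn: sqlite3.Connection, table: str, sample_rows: int = 3) -> str:
    """Hypothetical helper: a table's stored CREATE statement plus a few sample rows."""
    # sqlite_master keeps the original CREATE TABLE text for every table.
    row = conn.execute(
        "SELECT sql FROM sqlite_master WHERE type = 'table' AND name = ?", (table,)
    ).fetchone()
    if row is None:
        raise ValueError(f"no such table: {table}")
    # Identifiers can't be bound as parameters, so the (trusted) name is quoted inline.
    cur = conn.execute(f'SELECT * FROM "{table}" LIMIT ?', (sample_rows,))
    header = "\t".join(col[0] for col in cur.description)
    rows = ["\t".join(str(value) for value in r) for r in cur.fetchall()]
    body = "\n".join([header, *rows])
    return (f"DDL for the {table} table:\n{row[0]}\n\n"
            f"/*\n{len(rows)} rows from {table} table:\n{body}\n*/")


# Toy in-memory database standing in for department_management.sqlite.
conn = sqlite3.connect(":memory:")
conn.execute('CREATE TABLE head ("head_ID" INTEGER PRIMARY KEY, name TEXT, age REAL)')
conn.executemany(
    "INSERT INTO head VALUES (?, ?, ?)",
    [(1, "Tiger Woods", 67.0), (2, "Sergio García", 68.0),
     (3, "K. J. Choi", 69.0), (4, "Placeholder Name", 50.0)],
)
print(table_summary(conn, "head"))
```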

### Create Prompt Template

In [412]:
create_prompt = PromptTemplate(
    input_variables=[],
    template = f"""
        You are a SQL Query Writer. Given an input question, first create a working {sql_dialect} SQL statement to find the answer to an input question and then return only the syntactically correct SQL statement.
        
        Use one or multiple of these tables:
        {tables_summary} 

        Input question: "{question}"
    """
)
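Because the template is an f-string, the dialect, table summary and question are baked in when the cell runs, and any `{` or `}` inside the sample data could be read as a template variable. A sketch of the deferred alternative: keep named placeholders and fill them per call. It uses plain `str.format`, the same placeholder syntax as LangChain's default PromptTemplate format; with LangChain the equivalent would be declaring `input_variables=["dialect", "tables", "question"]` and passing those values to `chain.predict(...)`. The names below are hypothetical:

```python
# Hypothetical template mirroring the prompt above, with placeholders left unfilled.
CREATE_TEMPLATE = """You are a SQL Query Writer. Given an input question, first create a working {dialect} SQL statement to find the answer to an input question and then return only the syntactically correct SQL statement.

Use one or multiple of these tables:
{tables}

Input question: "{question}"
"""


def build_create_prompt(dialect: str, tables: str, question: str) -> str:
    """Fill the placeholders at call time so one template serves every question."""
    # str.format only parses the template itself; braces inside the
    # substituted values are inserted literally and never re-parsed.
    return CREATE_TEMPLATE.format(dialect=dialect, tables=tables, question=question)


prompt = build_create_prompt(
    "sqlite",
    'CREATE TABLE head ("head_ID" INTEGER, name TEXT)  /* sample value: {braces} */',
    "How many rows in the 'head' table?",
)
print(prompt)
```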

### Setup LLM
I'll test Google's flan-t5-xxl model.

In [413]:
#get api key
load_dotenv()
hf_api_token = os.getenv('hf_token')

#add path to HF repo
repo_id = 'google/flan-t5-xxl'

#establish llm model
llm = HuggingFaceHub(repo_id=repo_id, huggingfacehub_api_token=hf_api_token, model_kwargs={"temperature": 0.1, "max_length": 256})

### 1st Chain - Generate SQL Query

In [414]:
chain = LLMChain(llm=llm, prompt=create_prompt, verbose=False)
sql_query = chain.predict()

In [415]:
sql_query

'SELECT Count ( * ) FROM head'

### Chain for checking SQL Query

This prompt is taken directly from langchain's query checker prompt.

In [416]:
validate_prompt = PromptTemplate(
    input_variables=[],
    template = f"""{sql_query}

    Double check the {sql_dialect} query above for common mistakes, including:
    - Using NOT IN with NULL values
    - Using UNION when UNION ALL should have been used
    - Using BETWEEN for exclusive ranges
    - Data type mismatch in predicates
    - Properly quoting identifiers
    - Using the correct number of arguments for functions
    - Casting to the correct data type
    - Using the proper columns for joins

    If there are any of the above mistakes, rewrite the query. If there are no mistakes, just reproduce the original query."""
)
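The LLM check can be paired with a cheap deterministic one. In SQLite, preparing a statement (here via `EXPLAIN`, which compiles the query to bytecode without running it) already catches syntax errors and unknown tables or columns. This is an added technique, not part of langchain's checker; a minimal sketch with the standard `sqlite3` module, a hypothetical helper name and a toy table:

```python
import sqlite3


def compiles(conn: sqlite3.Connection, sql: str):
    """Hypothetical helper: (True, None) if SQLite can prepare the query, else (False, message)."""
    try:
        # EXPLAIN compiles the statement and lists its bytecode, but does not execute it.
        conn.execute("EXPLAIN " + sql).fetchall()
        return True, None
    except sqlite3.Error as exc:
        return False, str(exc)


# Toy table standing in for the real head table.
conn = sqlite3.connect(":memory:")
conn.execute('CREATE TABLE head ("head_ID" INTEGER, name TEXT, age REAL)')
print(compiles(conn, "SELECT count(*) FROM head"))  # (True, None)
print(compiles(conn, "SELECT count(a) FROM head"))  # (False, 'no such column: a')
```

It can't catch logic mistakes (wrong table, wrong filter), so it complements the LLM check rather than replacing it.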

In [417]:
chain = LLMChain(llm=llm, prompt=validate_prompt, verbose=False)
validated_query = chain.predict()

In [418]:
validated_query

'SELECT count(*) FROM head'

### Define Functions - Identify & Connect to DB, Get & Prioritize Tables, Create & Validate Query

In [510]:
#create unique chains
create_chain = LLMChain(llm=llm, prompt=create_prompt, verbose=False)

validate_chain = LLMChain(llm=llm, prompt=validate_prompt, verbose=False)

In [None]:
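def create_sql(question):
    # TODO: not written yet. The plan is to wrap the steps above (find the schema,
    # connect to the db, prioritize tables, create and validate the query) in here.
    raise NotImplementedError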

### Run and Debug the Query

Run the query on the database. If it succeeds, send the results to an LLM to interpret them.
If it fails, ask an LLM to debug the query and write a new one based on the error message, then try again, only exiting the loop once it works.

The wonderful folks at langchain created a run_no_throw function that returns either the query results or the error message as a string, instead of raising an exception.

In [420]:
sql_output = db.run_no_throw(validated_query)

sql_output

'[(10,)]'
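For experimenting without a LangChain setup, the behavior shown here (results or an `Error: ...` string, never an exception) can be approximated with `sqlite3`. A sketch only: the helper name and toy data are hypothetical, and langchain's real method goes through SQLAlchemy, which is why its error text adds the `[SQL: ...]` and background-link lines:

```python
import sqlite3


def run_no_throw_sketch(conn: sqlite3.Connection, sql: str) -> str:
    """Hypothetical helper: str(rows) on success, an 'Error: ...' string on failure."""
    try:
        rows = conn.execute(sql).fetchall()
    except sqlite3.Error as exc:
        return f"Error: {exc}"
    # Matches the outputs in this notebook: an empty result becomes the string '[]'.
    return str(rows)


# Toy head table with 10 rows.
conn = sqlite3.connect(":memory:")
conn.execute('CREATE TABLE head ("head_ID" INTEGER, name TEXT)')
conn.executemany("INSERT INTO head VALUES (?, ?)", [(i, f"head {i}") for i in range(1, 11)])
print(run_no_throw_sketch(conn, "SELECT count(*) FROM head"))  # [(10,)]
print(run_no_throw_sketch(conn, "SELECT count(a) FROM head"))  # Error: no such column: a
print(run_no_throw_sketch(conn, "SELECT name FROM head WHERE head_id = '25'"))  # []
```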

In [453]:
#create prompt for answering the question
analyze_prompt = PromptTemplate(
    input_variables=[],
    template = f"""{sql_output}

    You are an expert data analyst. First look at the results of the above SQL output then identify the answer to this question:
    Question: "{question}"

    Then the answer in one sentence:
    """
)

In [455]:
#test analyze chain
chain = LLMChain(llm=llm, prompt=analyze_prompt, verbose=False)
answer = chain.predict()

In [456]:
answer

'10'

In [465]:
#create a query that should error
bad_query = 'SELECT count(a) FROM head'

In [468]:
query_error = db.run_no_throw(bad_query)

query_error

'Error: (sqlite3.OperationalError) no such column: a\n[SQL: SELECT count(a) FROM head]\n(Background on this error at: https://sqlalche.me/e/20/e3q8)'

In [None]:
#create a query that should return no rows
empty_query = "SELECT name FROM head WHERE head_id = '25'"

In [486]:
test_empty = db.run_no_throw(empty_query)

test_empty

'[]'

In [477]:
#create prompt for debugging error
debug_error_prompt = PromptTemplate(
    input_variables=[],
    template = f"""{bad_query}

The query above produced the following error:

{query_error}

Rewrite the query with the error fixed:"""
)

In [478]:
#create prompt for debugging an empty result
debug_empty_prompt = PromptTemplate(
    input_variables=[],
    template = f"""{empty_query}

The query above produced no result. Try rewriting the query so it will return results to this question "{question}":"""
)

In [479]:
#test debug
chain = LLMChain(llm=llm, prompt=debug_error_prompt, verbose=True)
debugged = chain.predict()



> Entering new  chain...
Prompt after formatting:
SELECT count(a) FROM head

The query above produced the following error:

Error: (sqlite3.OperationalError) no such column: a
[SQL: SELECT count(a) FROM head]
(Background on this error at: https://sqlalche.me/e/20/e3q8)

Rewrite the query with the error fixed:

> Finished chain.


In [480]:
debugged

'SELECT count(*) FROM head'

In [481]:
#test empty debug
chain = LLMChain(llm=llm, prompt=debug_empty_prompt, verbose=True)
emp_debugged = chain.predict()



> Entering new  chain...
Prompt after formatting:
SELECT name FROM head WHERE head_id = '25'

The query above produced no result. Try rewriting the query so it will return results to this question "How many rows in the 'head' table in the 'department_management' schema?":

> Finished chain.


In [482]:
emp_debugged

'SELECT Count ( * ) FROM head'

### Create Full Loop for this Running and Debugging Process

I want to have it take the output of the create sql statement step and then run it. If it returns any issues, try debugging 3 times. If none of those work, return a scripted error message.

In [508]:
#establish unique chains
analyze_chain = LLMChain(llm=llm, prompt=analyze_prompt, verbose=False)

debug_error_chain = LLMChain(llm=llm, prompt=debug_error_prompt, verbose=False)

debug_empty_chain = LLMChain(llm=llm, prompt=debug_empty_prompt, verbose=False)

In [511]:
def debug_and_run(max_attempts=3, sql_query=validated_query):
    # NOTE: the debug and analyze chains were built from f-string prompts, so they
    # still see bad_query / empty_query / sql_output, not the current sql below.
    sql = sql_query
    output = "No attempts were made."

    for attempt in range(1, max_attempts + 1):
        query_result = db.run_no_throw(sql)

        if query_result.startswith('Error'):
            if attempt == max_attempts:
                output = f"Unable to execute the SQL query after {max_attempts} attempts."
                break

            sql = debug_error_chain.predict()
            time.sleep(1)  # Add a delay of 1 second between attempts
        elif query_result == '[]':
            if attempt == max_attempts:
                output = "Query returned no results after multiple attempts."
                break

            sql = debug_empty_chain.predict()
            time.sleep(1)  # Add a delay of 1 second between attempts
        else:
            answer = analyze_chain.predict()
            output = f"""
            Input question: {question}
            SQL Query: {sql}
            Answer: {answer}"""
            break

    return output
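Since those chains carry their inputs baked into the prompt, here is a framework-free sketch of the same retry logic in which the current query and its error are passed into each step. `run`, `fix_error`, `fix_empty` and `answer` are hypothetical callables; in this notebook they would wrap `db.run_no_throw` and chains whose prompts declare real input variables:

```python
from typing import Callable


def debug_and_run_sketch(
    sql: str,
    run: Callable[[str], str],              # returns results or an "Error: ..." string
    fix_error: Callable[[str, str], str],   # (query, error) -> rewritten query
    fix_empty: Callable[[str], str],        # query -> rewritten query
    answer: Callable[[str, str], str],      # (query, results) -> answer text
    max_attempts: int = 3,
) -> str:
    """Hypothetical sketch: run sql, requesting a rewrite after each failure."""
    for attempt in range(1, max_attempts + 1):
        result = run(sql)
        if result.startswith("Error"):
            if attempt == max_attempts:
                return f"Unable to execute the SQL query after {max_attempts} attempts."
            sql = fix_error(sql, result)  # the fixer sees this attempt's query and error
        elif result == "[]":
            if attempt == max_attempts:
                return "Query returned no results after multiple attempts."
            sql = fix_empty(sql)
        else:
            return f"SQL Query: {sql}\nAnswer: {answer(sql, result)}"
    return "No attempts were made."  # only reached when max_attempts < 1


# Toy stand-ins: the first query fails and the "fixer" swaps in a working one.
canned = {"SELECT count(a) FROM head": "Error: no such column: a",
          "SELECT count(*) FROM head": "[(10,)]"}
out = debug_and_run_sketch(
    "SELECT count(a) FROM head",
    run=lambda q: canned.get(q, "[]"),
    fix_error=lambda q, err: "SELECT count(*) FROM head",
    fix_empty=lambda q: q,
    answer=lambda q, res: res,
)
print(out)
```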

## Create Agent

https://python.langchain.com/docs/modules/agents/

Two types ->
- Action: at each step, decide on the next action using the outputs of the previous steps
- Plan & Execute: decide on the full sequence up front, then execute them all without updating the plan

Plan & Execute is better for larger problems because it maintains focus. Often it's best to combine and let P&E agents use action agents to execute plans.

Tools -> actions an agent can take
Toolkit -> wrapper around a collection of tools for a specific use case. I'll need one for SQL: a tool to inspect tables and a tool to run queries.

#### Action Agents

Wrapped in agent executors: the executor calls the agent, gets back an action and action input, calls the referenced tool, then passes the tool's output back to the agent.

Typically has a prompt template, language model, and output parser
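A stripped-down sketch of that executor loop with no LLM involved: the "agent" is any function that looks at the question and the steps so far and returns either a `(tool, input)` pair or a final answer. All names here (`run_action_agent`, the `"FINAL"` marker, the toy tools) are hypothetical, not langchain's AgentExecutor API:

```python
from typing import Callable, Dict, List, Tuple

Step = Tuple[str, str, str]  # (tool name, tool input, observation)


def run_action_agent(
    question: str,
    agent: Callable[[str, List[Step]], Tuple[str, str]],
    tools: Dict[str, Callable[[str], str]],
    max_steps: int = 5,
) -> str:
    """Hypothetical sketch: call the agent, run its chosen tool, feed the result back."""
    steps: List[Step] = []
    for _ in range(max_steps):
        action, action_input = agent(question, steps)
        if action == "FINAL":
            return action_input  # the agent decided it knows the answer
        observation = tools[action](action_input)
        steps.append((action, action_input, observation))
    return "Agent stopped: step limit reached."


# Toy agent: list tables, then run a query, then answer from the last observation.
def toy_agent(question, steps):
    if not steps:
        return "list_tables", ""
    if len(steps) == 1:
        return "run_query", "SELECT count(*) FROM head"
    return "FINAL", steps[-1][2]


tools = {
    "list_tables": lambda _: "department, head, management",
    "run_query": lambda sql: "[(10,)]",  # canned result instead of a real database
}
print(run_action_agent("How many heads?", toy_agent, tools))  # [(10,)]
```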

#### Plan-and-Execute Agents

Typically the language model is the planner and the executor is an action agent.
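A matching sketch of plan-and-execute, again with toy stand-ins: the planner fixes the whole list of steps up front, and an executor (in practice, an action agent) handles each step while seeing earlier results, but never revises the plan. The function names and canned step outputs are hypothetical:

```python
from typing import Callable, List


def plan_and_execute(
    question: str,
    planner: Callable[[str], List[str]],
    executor: Callable[[str, List[str]], str],
) -> List[str]:
    """Hypothetical sketch: plan once, then run every step without re-planning."""
    plan = planner(question)
    results: List[str] = []
    for step in plan:
        # The executor sees earlier results, but nothing here changes the plan.
        results.append(executor(step, results))
    return results


def toy_planner(question):
    return ["find the table that stores department heads",
            "count the rows in that table"]


def toy_executor(step, previous):
    # Stand-in for an action agent: canned output per step.
    return "head" if step.startswith("find") else f"SELECT count(*) FROM {previous[-1]}"


print(plan_and_execute("How many heads?", toy_planner, toy_executor))
# ['head', 'SELECT count(*) FROM head']
```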

#### Built-In Agent Observations

It appears the built-in SQL agent uses an action agent, so I'll try to tweak it. I might even split the steps out into separate agents and try the combined approach, with a larger plan & execute agent triggering the action agents.

### Agent Imports

In [7]:
from langchain.agents import initialize_agent, AgentType, load_tools, create_sql_agent
from langchain.agents.agent_toolkits import SQLDatabaseToolkit
from langchain.prompts import PromptTemplate

### Setup Toolkit

I'll use the standard langchain one.

In [8]:
toolkit = SQLDatabaseToolkit(db=db, llm=llm)

### Create Prompt Template - Create SQL Query

In [16]:
create_prompt = PromptTemplate(
    input_variables={"tool_names"},
    template="""
        Use the following format:
        Question: the input question you must answer
        Thought: you should always think about what to do
        Action: the action to take, should be one of [{tool_names}]
        Action Input: the input to the action
        Observation: the result of the action
        Thought: I now know the final answer
        Final Answer: SQL query ONLY, always include columns "Indexes", "Index_Name" along with your final answer. 
        If column name is a reserve word add quotes around it
        """
)
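Format instructions like these only work if something turns the model's text back into structured steps (the output parser mentioned earlier). A minimal regex-based sketch, with a hypothetical function name and not langchain's own parser, run against the same kind of text the agent produced below; note that the first step's Action Input comes back empty, which a parser could flag before the tool is ever called:

```python
import re


def parse_react(text: str):
    """Hypothetical parser: ('final', answer) or ('action', tool, tool_input)."""
    final = re.search(r"Final Answer:\s*(.*)", text, re.DOTALL)
    if final:
        return ("final", final.group(1).strip())
    action = re.search(r"Action:\s*(.*?)\s*\nAction Input:\s*(.*)", text, re.DOTALL)
    if action:
        return ("action", action.group(1).strip(), action.group(2).strip())
    raise ValueError(f"Could not parse model output: {text!r}")


first = parse_react("Action: sql_db_query\nAction Input: ")
last = parse_react(" I now know the final answer\nFinal Answer: \n\nSELECT COUNT(*) FROM department WHERE age")
print(first)  # ('action', 'sql_db_query', '')
print(last)   # ('final', 'SELECT COUNT(*) FROM department WHERE age')
```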

In [17]:
agent_exec = create_sql_agent(llm=llm,
                                toolkit=toolkit,
                                verbose=True,
                                format_instructions=create_prompt,
                                top_k=5)

In [18]:
agent_exec("How many heads of the departments are older than 56?")



> Entering new  chain...
Action: sql_db_query
Action Input: 
Observation: 
Thought: I now know the final answer
Final Answer: 

SELECT COUNT(*) FROM department WHERE age

> Finished chain.


{'input': 'How many heads of the departments are older than 56?',
 'output': 'SELECT COUNT(*) FROM department WHERE age'}

### Create Prompt Template - Check the SQL Query

From here: https://canvasapp.com/blog/text-to-sql-in-production

In [25]:
prompt = PromptTemplate(
    input_variables=["question", "sql_dialect", "current_day", "sql"],
    template="""
        You are a helpful AI that verifies that a SQL query runs correctly. If the query does not run successfully, you iteratively update the query based on the error message.
        You follow the following procedure:
            1. Run the input SQL query using the QueryTool
            2. If this runs successfully, return immediately.
            3. If the SQL query fails, debug the query:
                3a. consider the error message
                3b. update the SQL query
                3c. try re-running
            4. Repeat step 3 until you have a valid result. Finish and exit with the updated SQL query.
        You are writing SQL for the {sql_dialect} dialect.
        The current day is {current_day} in case the user references the date
        The user's original question: {question}
        The SQL: {sql}
        Begin!
    """
)

In [27]:
from datetime import date

formatted_prompt = prompt.format(
    question=question,
    sql=validated_query, #feed in the validated query from the earlier step
    sql_dialect=sql_dialect,
    current_day=date.today().isoformat(), #the template also expects current_day
)