## Text-to-SQL on a biomedical dataset, optimized for low latency on a single table
---
We show here how to build a conversational chatbot that is able to extract information from a relational database with a single table. This is a relatively simple example of text-to-SQL, as there are no joins required. We focus here on showing how to optimize latency using the [SQLDatabaseToolkit](https://python.langchain.com/v0.2/docs/integrations/toolkits/sql_database/) from [LangChain](https://www.langchain.com).

In the generic case, SQLDatabaseToolkit is driven by a ReAct agent that makes multiple sequential calls to the LLM. At each step the LLM chooses a tool: to ask the database what tables it contains, to fetch the schema of a subset of those tables, to check a candidate SQL query, to run a query, and so on. Because we know the database has only one table, we can supply its schema up front and skip the discovery steps. This means fewer calls to the LLM and hence lower latency for the overall text-to-SQL process.

We use the following database of diabetes patients, which has been downloaded for you as the file `diabetes.csv`:
```
@article{Machado2024,
    author = "Angela Machado",
    title = "{diabetes.csv}",
    year = "2024",
    month = "3",
    url = "https://figshare.com/articles/dataset/diabetes_csv/25421347",
    doi = "10.6084/m9.figshare.25421347.v1"
}
```

Note that the following `pip install` commands may generate warnings: you can safely ignore these.

In [None]:
%pip install -qU openpyxl pandas jinja2 langchain boto3
%pip install -qU langchain-community langchain-aws

In [None]:
import os
import sys
from typing import List, Tuple
import itertools
from time import time

import jinja2
from langchain_community.utilities import SQLDatabase
import sqlite3
import boto3
import pandas as pd
from langchain_aws import ChatBedrock
from langchain_community.agent_toolkits.sql.base import create_sql_agent
from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
from langchain.agents.agent_types import AgentType
from langchain.chains import create_sql_query_chain
from langchain_core.prompts import PromptTemplate
from langchain.callbacks.base import BaseCallbackHandler

sys.path.append('../')
import utilities as u

In [None]:
model_id = "anthropic.claude-3-sonnet-20240229-v1:0"
# model_id = "anthropic.claude-3-haiku-20240307-v1:0"  # faster, cheaper alternative
con = sqlite3.connect("test.db")
jenv = jinja2.Environment(trim_blocks=True, lstrip_blocks=True)
# Uncomment to enable LangSmith tracing, a useful way to keep track of tool invocations:
#os.environ["LANGCHAIN_TRACING_V2"] = "true"
#os.environ["LANGCHAIN_API_KEY"] = "..."
os.environ["AWS_DEFAULT_REGION"] = "us-west-2"

is_conversational = True      # rewrite follow-up questions so that they stand alone
force_setup_db = False        # recreate the patients table even if it already exists
do_few_shot_prompting = False # not used below
show_SQL = True               # append the generated SQL to each answer's metadata

llm = ChatBedrock(model_id=model_id, region_name="us-west-2")
db = SQLDatabase.from_uri("sqlite:///test.db")

### Load the test data into a database

First, we load the CSV file into a DataFrame and take a look at some rows:

In [None]:
df = pd.read_csv("diabetes.csv")
df.head()

Next, we load this data into a SQLite database:

In [None]:
def setup_db():
    print("Setting up DB")
    df.to_sql(name="patients", con=con, if_exists="replace", index=True)
    con.commit()

In [None]:
def maybe_setup_db():
    """Create the `patients` table unless it already exists (or setup is forced)."""
    if force_setup_db:
        print("Forcing DB setup")
        setup_db()
        return
    cur = con.cursor()
    try:
        # The table written by setup_db() is called "patients" (plural)
        cur.execute("SELECT count(*) FROM patients")
        print(f"Table exists ({cur.fetchone()[0]:,} rows), no need to recreate DB")
    except sqlite3.OperationalError as ex:
        if "no such table: patients" in str(ex):
            print("Table not there, need to recreate DB")
            setup_db()
        else:
            raise
    finally:
        cur.close()

In [None]:
maybe_setup_db()
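Matching on the text of an exception message works, but it is brittle. A minimal sketch of an alternative check, assuming only the standard-library `sqlite3` module, asks SQLite's `sqlite_master` catalog directly whether the table exists. `table_exists` is a hypothetical helper, not part of this notebook:

```python
import sqlite3

def table_exists(con: sqlite3.Connection, name: str) -> bool:
    """Return True if a table called `name` exists in the connected SQLite database."""
    row = con.execute(
        "SELECT 1 FROM sqlite_master WHERE type = 'table' AND name = ?", (name,)
    ).fetchone()
    return row is not None

# Usage against a throwaway in-memory database
mem = sqlite3.connect(":memory:")
print(table_exists(mem, "patients"))  # False
mem.execute("CREATE TABLE patients (age INTEGER)")
print(table_exists(mem, "patients"))  # True
```

With such a helper, `maybe_setup_db` could simply call `setup_db()` whenever `force_setup_db or not table_exists(con, "patients")`.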

### In order to make the chatbot conversational, we need to de-contextualize questions

For example, if the first question is "How many patients are over 30?" and the second question is "And how many of those have a BMI > 30?" then we need to rewrite the second question to replace "those" with an appropriate referent. For example, we could rewrite the question as "How many patients that are over 30 also have a BMI > 30?"

In [None]:
def decontextualize_question(question: str, messages: List[List[str]]) -> str:
    """
    Each message is a list of [question, answer].
    """
    print(f"decontextualize_question {question} {messages}")
    prompt_template = """
I am going to give you a history of questions and answers, followed by a new question.
I want you to rewrite the new question so that it stands alone, not needing the
historical context to make sense.

<history>
{% for x in history %}
  <question>{{ x[0] }}</question>
  <answer>{{ x[1] }}</answer>
{% endfor %}
</history>

Here is the new question:
<new_question>
{{question}}
</new_question>

You must make the absolute MINIMUM changes required to make the meaning of
the sentence clear without the context of the history. Make NO other changes.

Return the rewritten, standalone, question in <result></result> tags.
"""
    prompt = jenv.from_string(prompt_template).render(history=messages, question=question)
    # print(f"prompt:\n{prompt}\n-----")
    response = llm.invoke(prompt)
    # print(f"response:\n{response}\n--------")
    answer = u.extract_tag(response.content, "result")[0]
    # print(f"answer: <<{answer}>>")
    return answer
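The helper `u.extract_tag` comes from the shared `utilities` module, which isn't shown here. Assuming it returns the contents of every `<tag>...</tag>` pair as a list, a regex-based sketch of such a helper (hypothetical, not the module's actual code) might look like this:

```python
import re
from typing import List

def extract_tag(text: str, tag: str) -> List[str]:
    """Return the stripped contents of every <tag>...</tag> pair in `text`, in order."""
    pattern = re.compile(rf"<{re.escape(tag)}>(.*?)</{re.escape(tag)}>", re.DOTALL)
    return [match.strip() for match in pattern.findall(text)]

response_text = (
    "Here is the rewritten question:\n"
    "<result>\nHow many patients that are over 30 also have a BMI > 30?\n</result>"
)
print(extract_tag(response_text, "result"))
# ['How many patients that are over 30 also have a BMI > 30?']
```

Since `decontextualize_question` takes element `[0]`, a response without `<result>` tags would raise an `IndexError`; falling back to the original question is one option.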

Extract the `CREATE TABLE` statement from the database and store it away so we can later insert it into the prompt.

In [None]:
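cur = con.cursor()
# Select the patients table's DDL explicitly rather than relying on row order in sqlite_master
cur.execute("SELECT sql FROM sqlite_master WHERE type = 'table' AND name = 'patients'")
DDL = cur.fetchone()[0]
cur.close()
print(DDL)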
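The `sqlite_master` catalog holds one row per schema object, and its fifth column (`sql`) is the statement that created that object. Because `setup_db` writes the table with `index=True`, the catalog can also contain an index, so selecting the row by `type` and `name` is safer than assuming the first row is the table. A stdlib-only sketch on a hypothetical two-column table (the index name is illustrative):

```python
import sqlite3

mem = sqlite3.connect(":memory:")
# Roughly what to_sql(..., index=True) leaves behind: a table plus an index on its "index" column
mem.execute('CREATE TABLE "patients" ("index" INTEGER, "Age" INTEGER)')
mem.execute('CREATE INDEX "ix_patients_index" ON "patients" ("index")')

# sqlite_master has one row per schema object: (type, name, tbl_name, rootpage, sql)
for obj_type, name, _tbl_name, _rootpage, _sql in mem.execute("SELECT * FROM sqlite_master"):
    print(obj_type, name)

# Filtering on type and name returns the CREATE TABLE text regardless of row order
ddl = mem.execute(
    "SELECT sql FROM sqlite_master WHERE type = 'table' AND name = ?", ("patients",)
).fetchone()[0]
print(ddl)
```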

We subclass `BaseCallbackHandler` to observe the sequence of agent actions (tool invocations), so that we can later
report useful information about the run, such as the generated SQL and the number of tool invocations.

In [None]:
class SQLHandler(BaseCallbackHandler):
    def __init__(self):
        self._sql_result = []
        self._num_tool_actions = 0

    def on_agent_action(self, action, **kwargs):
        """Runs on each agent action. If the tool being used is sql_db_query_checker
        or sql_db_query, record its input as SQL; the last SQL recorded is reported
        as the final SQL.
        """
        self._num_tool_actions += 1
        if action.tool in ["sql_db_query_checker", "sql_db_query"]:
            self._sql_result.append(action.tool_input)

    def sql_results(self) -> List[str]:
        return self._sql_result

    def num_tool_actions(self) -> int:
        return self._num_tool_actions
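`SQLHandler` records the input of both `sql_db_query_checker` and `sql_db_query`, so the "last" SQL it reports could in principle be a statement that was only checked. A variant that keeps the two apart is sketched below. It is duck-typed: it only reads `.tool` and `.tool_input` from each action, and `FakeAction` and `SQLRecorder` are hypothetical stand-ins used so the logic can be exercised without calling an LLM:

```python
from dataclasses import dataclass
from typing import List, Optional, Tuple

@dataclass
class FakeAction:
    """Hypothetical stand-in for an agent action: only .tool and .tool_input are used."""
    tool: str
    tool_input: str

class SQLRecorder:
    """Sketch of SQLHandler's bookkeeping that remembers which tool saw each SQL string."""
    SQL_TOOLS = ("sql_db_query_checker", "sql_db_query")

    def __init__(self) -> None:
        self.num_tool_actions = 0
        self.sql_events: List[Tuple[str, str]] = []

    def on_agent_action(self, action, **kwargs) -> None:
        self.num_tool_actions += 1
        if action.tool in self.SQL_TOOLS:
            self.sql_events.append((action.tool, str(action.tool_input).strip()))

    def last_executed_sql(self) -> Optional[str]:
        """Return the most recent SQL passed to sql_db_query, ignoring checker-only SQL."""
        for tool, sql in reversed(self.sql_events):
            if tool == "sql_db_query":
                return sql
        return None

rec = SQLRecorder()
rec.on_agent_action(FakeAction("sql_db_query_checker", "SELECT count(*) FROM patients"))
rec.on_agent_action(FakeAction("sql_db_query", "SELECT count(*) FROM patients LIMIT 20"))
print(rec.num_tool_actions, rec.last_executed_sql())
```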

We can optionally provide notes or hints about the schema to help guide the model towards generating more accurate
SQL. In this case the schema is straightforward, so we haven't needed to add any notes, but you can experiment with adding
some here.
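Notes can be short, general instructions about SQL dialect or answer style. The two tips below are purely hypothetical, and `preview_tips` is a plain-Python approximation of how the prompt's `{% for tip in tips %}` loop lays them out:

```python
from typing import List

# Hypothetical tips; none of these are required for this dataset.
example_notes: List[str] = [
    "The database is SQLite, so use LIMIT rather than TOP to restrict rows.",
    "For 'how many' questions, return a single count rather than listing rows.",
]

def preview_tips(tips: List[str]) -> str:
    """Approximate the bulleted list produced by the prompt's Jinja2 tips loop."""
    return "\n".join(f"  - {tip}" for tip in tips)

print(preview_tips(example_notes))
```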

In [None]:
notes: List[str] = []

The following is the main prompt we use to direct the [ReAct](https://arxiv.org/pdf/2210.03629) workflow. Typically, this agentic workflow would call the `sql_db_list_tables` and `sql_db_schema` tools to extract metadata (the schema) from the database. Each such call requires an extra LLM inference, which increases the latency of the overall workflow. Here we instead provide the table name and its `CREATE TABLE` statement explicitly, and tell the LLM not to call these tools.

In [None]:
prompt_template = '''
Answer the following questions as best you can.

You have access to the following tools:

{tools}

Use the following format:

Question: the input question you must answer
Thought: you should always think about what to do
Action: the action to take, should be one of [{tool_names}]
Action Input: the input to the action
Observation: the result of the action
... (this Thought/Action/Action Input/Observation can repeat N times)
Thought: I now know the final answer
Final Answer: the final answer to the original input question

You might find the following tips useful:
{% for tip in tips %}
  - {{ tip }}
{% endfor %}

The database has the following single table:

{{ table_info }}

You should NEVER have to use either the sql_db_schema tool or the sql_db_list_tables tool
as you know the only table is the "patients" table and you know its schema.

You must NEVER produce a SELECT statement without a LIMIT clause. Always include an ORDER BY
clause and a "LIMIT 20" to avoid returning too many useless results.

When describing the final result you don't have to describe HOW the SQL statement worked,
just describe the results.

Begin!

Question: {input}
Thought: {agent_scratchpad}'''
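To make the expected format concrete: each LLM completion should end either in an `Action`/`Action Input` pair, which the agent turns into a tool call, or in a `Final Answer`. LangChain ships its own output parser for this; the function below is only a simplified, hypothetical sketch of the idea using the standard `re` module. It raises `ValueError` when the text matches neither shape, which is the kind of error the retry loop in `answer_standalone_question` catches:

```python
import re
from typing import Tuple

# Matches an "Action: <tool>" line followed by an "Action Input: <input>" line
ACTION_RE = re.compile(r"Action\s*:\s*(.*?)\s*\nAction\s*Input\s*:\s*(.*)", re.DOTALL)

def parse_react_step(text: str) -> Tuple[str, str]:
    """Return (tool_name, tool_input), or ("final_answer", answer) for a final step."""
    if "Final Answer:" in text:
        return "final_answer", text.split("Final Answer:", 1)[1].strip()
    match = ACTION_RE.search(text)
    if match is None:
        raise ValueError(f"Could not parse LLM output: {text!r}")
    return match.group(1).strip(), match.group(2).strip().strip('"')

step = ("Thought: I need to count matching rows.\n"
        "Action: sql_db_query\n"
        "Action Input: SELECT count(*) FROM patients LIMIT 20")
print(parse_react_step(step))
# ('sql_db_query', 'SELECT count(*) FROM patients LIMIT 20')
```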

In [None]:
def create_prompt(notes, DDL, question: str):
    prompt_0 = jenv.from_string(prompt_template).render(tips=notes,
                                                        table_info=DDL)
    prompt = PromptTemplate.from_template(prompt_0)
    return prompt
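`create_prompt` templates in two stages. First, Jinja2 inlines the tips and the DDL. Then `PromptTemplate.from_template` treats the result as a `str.format`-style template whose remaining single-brace placeholders (`{tools}`, `{tool_names}`, `{input}`, `{agent_scratchpad}`) LangChain fills in later. Jinja2 leaves single braces alone, which is why this works; the flip side is that a literal `{` or `}` in a note or in the DDL would be read as a placeholder in the second stage. A stdlib sketch, with two hypothetical helpers built on `string.Formatter().parse`, that lists the placeholders a rendered prompt exposes and escapes literal braces:

```python
import string
from typing import List

def template_variables(template: str) -> List[str]:
    """Return the str.format-style placeholder names in `template`, in first-seen order."""
    names: List[str] = []
    for _literal, field, _spec, _conversion in string.Formatter().parse(template):
        if field and field not in names:
            names.append(field)
    return names

def escape_braces(text: str) -> str:
    """Double literal braces so a second, str.format-style templating stage keeps them."""
    return text.replace("{", "{{").replace("}", "}}")

# Hypothetical output of the first (Jinja2) stage, with a tip and the DDL already inlined
rendered = (
    "Action: the action to take, should be one of [{tool_names}]\n"
    "Tools:\n{tools}\n"
    '  - Prefer count(*) for "how many" questions\n'
    'CREATE TABLE "patients" ("index" INTEGER, "Age" INTEGER)\n'
    "Question: {input}\nThought: {agent_scratchpad}"
)
print(template_variables(rendered))
# ['tool_names', 'tools', 'input', 'agent_scratchpad']

note = 'Results look like {"count": 42}'
print(template_variables(note))                 # ['"count"'], an unintended placeholder
print(template_variables(escape_braces(note)))  # []
```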

## Answering questions

Below we provide two functions, `answer_standalone_question` and `answer_multiple_questions`, that you can use to drive a chatbot. While the interaction here is admittedly crude, you can easily plug these functions into a framework such as [Gradio's Chatbot](https://www.gradio.app/docs/gradio/chatbot) to create a more sophisticated UX.
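When plugging these functions into a chat UI, the UI's history usually has to be converted into the `[question, answer]` pairs that `answer_standalone_question` expects. A hypothetical adapter, assuming an OpenAI-style history of `{"role": ..., "content": ...}` dictionaries (which, as far as we know, is the shape of Gradio's `messages` history format):

```python
from typing import Dict, List, Optional

def history_to_pairs(history: List[Dict[str, str]]) -> List[List[str]]:
    """Pair each user turn with the assistant turn that follows it; drop unanswered turns."""
    pairs: List[List[str]] = []
    pending_question: Optional[str] = None
    for turn in history:
        if turn["role"] == "user":
            pending_question = turn["content"]
        elif turn["role"] == "assistant" and pending_question is not None:
            pairs.append([pending_question, turn["content"]])
            pending_question = None
    return pairs

history = [
    {"role": "user", "content": "How many patients are over 30?"},
    {"role": "assistant", "content": "(answer text)"},
    {"role": "user", "content": "And how many of those have a BMI > 30?"},
]
print(history_to_pairs(history))
# [['How many patients are over 30?', '(answer text)']]
```

A chat callback could then call `answer_standalone_question(new_message, history_to_pairs(history))`.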

In [None]:
def answer_standalone_question(question: str,
                               messages: List[List[str]]) -> Tuple[str, str]:
    start_time: float = time()
    if is_conversational and messages:
        question = decontextualize_question(question, messages)
    handler = SQLHandler()
    try:
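        agent_executor = create_sql_agent(
            llm=llm,
            toolkit=SQLDatabaseToolkit(db=db, llm=llm),
            verbose=True,
            prompt=create_prompt(notes, DDL, question),
            agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
            # handle_parsing_errors is an AgentExecutor option, so route it through
            # agent_executor_kwargs; the handler is attached per call in invoke() below
            agent_executor_kwargs={"handle_parsing_errors": True})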
        for iteration in itertools.count(0):
            try:
                answer = agent_executor.invoke(input={"input": question},
                                               config={"callbacks": [handler]})
                duration = time() - start_time
                attempts = iteration + 1
                iter_str = f", {attempts} attempts" if attempts > 1 else ""
                history_str = f", history {len(messages):,}" if len(messages) > 0 else ""
                sql_result = handler.sql_results()[-1].strip() if len(handler.sql_results()) > 0\
                             else None
                print(f"sql_result: {sql_result}")
                SQL_str = f"\n ```{sql_result}```" if show_SQL and sql_result else ""
                return answer['output'],\
                       f"{duration:.1f} secs, {handler.num_tool_actions():,} actions{iter_str}{history_str} {SQL_str}"
            except ValueError as ex:
                if iteration < 10:
                    print(f"iteration #{iteration}: caught {ex}")
                    print("retrying")
                else:
                    raise ex
    except Exception as ex:
        print(f"Caught: {ex}")
        raise ex
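The retry loop above could also be factored into a small reusable helper. A minimal sketch (the name `call_with_retries` is hypothetical) that retries only on `ValueError`, re-raises once the retry budget is spent, and reports how many attempts were needed:

```python
from typing import Callable, Tuple, TypeVar

T = TypeVar("T")

def call_with_retries(fn: Callable[[], T], max_retries: int = 10) -> Tuple[T, int]:
    """Call `fn`, retrying on ValueError up to `max_retries` times; return (result, attempts)."""
    for attempt in range(1, max_retries + 2):
        try:
            return fn(), attempt
        except ValueError as ex:
            if attempt > max_retries:
                raise
            print(f"attempt #{attempt}: caught {ex}; retrying")
    raise AssertionError("unreachable")

# Hypothetical flaky function that fails twice before succeeding
calls = {"n": 0}
def flaky() -> str:
    calls["n"] += 1
    if calls["n"] < 3:
        raise ValueError("could not parse LLM output")
    return "ok"

print(call_with_retries(flaky))  # prints two retry messages, then ('ok', 3)
```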

In [None]:
def answer_multiple_questions(questions: List[str]) -> List[Tuple[str, str]]:
    messages: List[List[str]] = []
    answers: List[str] = []
    for question in questions:
        answer, extra_info = answer_standalone_question(question, messages)
        answers.append(answer)
        messages.append([question, answer])
    return list(zip(questions, answers))

If you see the following error when executing the next cell:

![model access error](content/model-access-error.png)

then you need to request access to the model (`model_id` above) in the Amazon Bedrock console.

In [None]:
answer_standalone_question("How many patients have a BMI over 20 and are older than 30?",
                           [])

In [None]:
answer_multiple_questions(
    ["How many patients have a BMI over 20 and are older than 30?",
     "How many are over 50?"])