## Text-to-SQL on a biomedical dataset, optimized for low latency on a single table
---
This notebook shows how to build a conversational chatbot that extracts information from a relational database containing a single table. It is a relatively simple text-to-SQL example because no joins are required. The focus is on reducing latency when using the [SQLDatabaseToolkit](https://python.langchain.com/v0.2/docs/integrations/toolkits/sql_database/) from [LangChain](https://www.langchain.com).

In the generic case, SQLDatabaseToolkit uses the ReAct framework and makes multiple calls to the LLM: to ask the database which tables it contains, to request the schema of a subset of those tables, to check a candidate SQL query, to run a query, and more. Because we know the database has only one table, we can make fewer LLM calls and so reduce the latency of the overall text-to-SQL process.
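
To make the idea concrete, here is a minimal sketch, using only Python's standard `sqlite3` module and an in-memory database, of reading a single table's schema once up front so that it can be pasted into the prompt. The two-column demo table and the `schema_for_prompt` helper are illustrative assumptions, not part of the notebook's own code.

```python
import sqlite3


def schema_for_prompt(con: sqlite3.Connection, table: str) -> str:
    """Return the CREATE TABLE statement for `table` (hypothetical helper)."""
    row = con.execute(
        "SELECT sql FROM sqlite_master WHERE type = 'table' AND name = ?",
        (table,),
    ).fetchone()
    if row is None:
        raise ValueError(f"no such table: {table}")
    return row[0]


# Illustrative stand-in for the real database: one small table in memory.
demo_con = sqlite3.connect(":memory:")
demo_con.execute('CREATE TABLE "patients" ("BMI" REAL, "Age" INTEGER)')

# Fetched once at startup and then embedded in every prompt, so the agent has
# no reason to spend LLM round trips on listing tables or requesting a schema.
demo_ddl = schema_for_prompt(demo_con, "patients")
print(demo_ddl)
```

With the schema already in the prompt, the only tool calls the agent should still need are the ones that check and run the query.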

We use the following database of diabetes patients:
```
@article{Machado2024,
    author = "Angela Machado",
    title = "{diabetes.csv}",
    year = "2024",
    month = "3",
    url = "https://figshare.com/articles/dataset/diabetes_csv/25421347",
    doi = "10.6084/m9.figshare.25421347.v1"
}
```

In [1]:
%pip install -qU openpyxl langchain boto3
%pip install -qU langchain-community langchain-aws
# %pip install -qU torch

ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
astropy 6.1.0 requires numpy>=1.23, but you have numpy 1.22.4 which is incompatible.
awscli 1.33.24 requires botocore==1.34.142, but you have botocore 1.34.152 which is incompatible.
scikit-image 0.23.2 requires numpy>=1.23, but you have numpy 1.22.4 which is incompatible.
sparkmagic 0.21.0 requires pandas<2.0.0,>=0.17.1, but you have pandas 2.2.2 which is incompatible.
sphinx 7.3.7 requires docutils<0.22,>=0.18.1, but you have docutils 0.16 which is incompatible.
Note: you may need to restart the kernel to use updated packages.


In [1]:
import os
import re
# import io
from typing import List, Tuple
# from pprint import pprint
# import random
import itertools
from time import time

import jinja2
from langchain_community.utilities import SQLDatabase
import sqlite3
import boto3
# import torch
import pandas as pd
# import numpy as np
from langchain_aws import ChatBedrock
from langchain_community.agent_toolkits.sql.base import create_sql_agent
from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
from langchain.agents.agent_types import AgentType
from langchain.chains import create_sql_query_chain
from langchain_core.prompts import PromptTemplate
# from langchain_community.vectorstores import FAISS
# from langchain_community.embeddings.bedrock import BedrockEmbeddings

In [2]:
model_id = "anthropic.claude-3-sonnet-20240229-v1:0"
# model_id = "anthropic.claude-3-haiku-20240307-v1:0"
con = sqlite3.connect("test.db")
jenv = jinja2.Environment(trim_blocks=True, lstrip_blocks=True)
# Enable LangSmith tracing only when an API key is supplied via the environment;
# never hard-code a secret in the notebook.
os.environ["LANGCHAIN_TRACING_V2"] = "true" if os.environ.get("LANGCHAIN_API_KEY") else "false"
os.environ["AWS_DEFAULT_REGION"] = "us-east-1"

is_conversational = True
force_setup_db = False
do_few_shot_prompting = False
show_SQL = True

llm = ChatBedrock(model_id=model_id)
db = SQLDatabase.from_uri("sqlite:///test.db")
context = db.get_context()
chain = create_sql_query_chain(llm, db)

In [4]:
prompt_template = '''Answer the following questions as best you can. You have access to the following tools:

{tools}

Use the following format:

Question: the input question you must answer
Thought: you should always think about what to do
Action: the action to take, should be one of [{tool_names}]
Action Input: the input to the action
Observation: the result of the action
... (this Thought/Action/Action Input/Observation can repeat N times)
Thought: I now know the final answer
Final Answer: the final answer to the original input question

You might find the following tips useful:
{% for tip in tips %}
  - {{ tip }}
{% endfor %}

The database has the following single table:

{{ table_info }}

You should NEVER have to use either the sql_db_schema tool or the sql_db_list_tables tool
as you know the only table is the "patients" table and you know its schema.

You must NEVER produce a SELECT statement with no LIMIT clause. You should always have an ORDER BY
clause and a "LIMIT 20" to avoid returning too many useless results.

When describing the final result you don't have to describe HOW the SQL statement worked,
just describe the results.

Begin!

Question: {input}
Thought: {agent_scratchpad}'''
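
A prompt instruction such as the "LIMIT 20" rule above is only a request, and a model may not always follow it. As a possible safety net, the sketch below appends a LIMIT clause to a SELECT statement that lacks one. The `ensure_limit` helper is hypothetical and deliberately naive: it uses a regular expression rather than a SQL parser, so a `LIMIT` inside a subquery or a string literal would fool it. The notebook itself does not do this.

```python
import re

# Matches "LIMIT <number>" anywhere in the statement, case-insensitively.
_LIMIT_RE = re.compile(r"\blimit\s+\d+", re.IGNORECASE)


def ensure_limit(sql: str, max_rows: int = 20) -> str:
    """Append a LIMIT clause to a SELECT that has none (naive, hypothetical)."""
    stripped = sql.strip().rstrip(";").rstrip()
    if not stripped.lower().startswith("select"):
        return sql  # leave non-SELECT statements untouched
    if _LIMIT_RE.search(stripped):
        return stripped + ";"  # already limited: only normalize the terminator
    return f"{stripped} LIMIT {max_rows};"


print(ensure_limit("SELECT * FROM patients ORDER BY Age"))
print(ensure_limit("select * from patients limit 5;"))
```

Appending a limit to an aggregate such as `SELECT COUNT(*) ...` is harmless, since that query returns a single row anyway.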

In [5]:
df = pd.read_csv("diabetes.csv")
df.head()

Unnamed: 0,Pregnancies,Glucose,BloodPressure,SkinThickness,Insulin,BMI,DiabetesPedigreeFunction,Age,Outcome
0,6,148,72,35,0,33.6,0.627,50,1
1,1,85,66,29,0,26.6,0.351,31,0
2,8,183,64,0,0,23.3,0.672,32,1
3,1,89,66,23,94,28.1,0.167,21,0
4,0,137,40,35,168,43.1,2.288,33,1


In [6]:
def setup_db():
    print("Setting up DB")
    df.to_sql(name="patients", con=con, if_exists="replace", index=True)
    con.commit()

In [7]:
def maybe_setup_db():
    if force_setup_db:
        print("Forcing DB setup")
        setup_db()
    else:
        cur = con.cursor()
        try:
            # The table is created as "patients" in setup_db(), so check for that name.
            cur.execute("SELECT count(*) FROM patients")
            print(f"Table exists ({cur.fetchone()[0]}), no need to recreate DB")
        except sqlite3.OperationalError as ex:
            if "no such table: patients" in str(ex):
                print("Table not there, need to recreate DB")
                setup_db()
            else:
                raise
        finally:
            cur.close()

In [8]:
def extract_tag(response: str, name: str, greedy: bool = True,
                opening_optional: bool = False) -> Tuple[str, int]:
    """
    >>> extract_tag("foo <a>baz</a> bar", "a")
    ('baz', 10)

    >>> extract_tag("baz</a> bar", "a", opening_optional=True)
    ('baz', 3)

    """
    if opening_optional:
        patn = f"\\s*(.*)</{name}>" if greedy else f"\\s*(.*?)</{name}>"
        match = re.match(patn, response, re.DOTALL)
        if match:
            return match.group(1).strip(), match.end(1)
    else:
        patn = f"<{name}>(.*)</{name}>" if greedy else\
               f"<{name}>(.*?)</{name}>"
        match = re.search(patn, response, re.DOTALL)
        if match:
            return match.group(1).strip(), match.end(1)
    print(f"Couldn't find tag {name} in <<<{response}>>>")
    return "", -1
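
The `greedy` flag matters when a response contains the same tag more than once. The following standalone sketch uses the same two regular-expression shapes as `extract_tag` on a made-up response string to show the difference.

```python
import re

# Made-up response containing the same tag twice.
text = "<result>first</result> noise <result>second</result>"

# Greedy: (.*) runs to the LAST closing tag, swallowing everything in between.
greedy = re.search(r"<result>(.*)</result>", text, re.DOTALL).group(1)

# Non-greedy: (.*?) stops at the FIRST closing tag.
lazy = re.search(r"<result>(.*?)</result>", text, re.DOTALL).group(1)

print(repr(greedy))  # 'first</result> noise <result>second'
print(repr(lazy))    # 'first'
```

Since `extract_tag` defaults to `greedy=True`, a response that repeats the tag would return the whole span between the first opening and the last closing tag.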

In [9]:
maybe_setup_db()

Table not there, need to recreate DB
Setting up DB


In [10]:
def decontextualize_question(question: str, messages: List[List[str]]) -> str:
    print(f"decontextualize_question {question} {messages}")
    prompt_template = """
I am going to give you a history of questions and answers, followed by a new question.
I want you to rewrite the new question so that it stands alone, not needing the
historical context to make sense.

<history>
{% for x in history %}
  <question>{{ x[0] }}</question>
  <answer>{{ x[1] }}</answer>
{% endfor %}
</history>

Here is the new question:
<new_question>
{{question}}
</new_question>

You must make the absolute MINIMUM changes required to make the meaning of
the sentence clear without the context of the history. Make NO other changes.

Return the rewritten, standalone question in <result></result> tags.
"""
    prompt = jenv.from_string(prompt_template).render(history=messages, question=question)
    print(f"prompt:\n{prompt}\n-----")
    response = llm.invoke(prompt)
    print(f"response:\n{response}\n--------")
    answer = extract_tag(response.content, "result")[0]
    print(f"answer: <<{answer}>>")
    return answer

In [11]:
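cur = con.cursor()
# Select the "sql" column for the patients table explicitly rather than relying
# on row order and column position in sqlite_master.
cur.execute("SELECT sql FROM sqlite_master WHERE type = 'table' AND name = 'patients'")
DDL = cur.fetchone()[0]
print(DDL)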

CREATE TABLE "patients" (
"index" INTEGER,
  "Pregnancies" INTEGER,
  "Glucose" INTEGER,
  "BloodPressure" INTEGER,
  "SkinThickness" INTEGER,
  "Insulin" INTEGER,
  "BMI" REAL,
  "DiabetesPedigreeFunction" REAL,
  "Age" INTEGER,
  "Outcome" INTEGER
)


In [12]:
from langchain.callbacks.base import BaseCallbackHandler


class SQLHandler(BaseCallbackHandler):
    def __init__(self):
        self._sql_result = []
        self._num_tool_actions = 0

    def on_agent_action(self, action, **kwargs):
        """Run on every agent action. If the tool being used is sql_db_query
        or sql_db_query_checker, the agent is submitting SQL, so record it;
        the last recorded entry is treated as the final SQL."""

        self._num_tool_actions += 1
        if action.tool in ["sql_db_query_checker","sql_db_query"]:
            self._sql_result.append(action.tool_input)

    def sql_results(self) -> List[str]:
        return self._sql_result
    
    def num_tool_actions(self) -> int:
        return self._num_tool_actions
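
To see the bookkeeping in isolation, the sketch below replays two fake agent actions through a stand-in class with the same logic as `SQLHandler`. The stand-in omits the LangChain base class, and the `SimpleNamespace` objects are assumptions that mimic only the `tool` and `tool_input` attributes the handler reads.

```python
from types import SimpleNamespace
from typing import List


class RecordingHandler:
    """Stand-in with SQLHandler's bookkeeping, minus the LangChain base class."""

    def __init__(self) -> None:
        self.sql: List[str] = []
        self.num_actions = 0

    def on_agent_action(self, action, **kwargs) -> None:
        self.num_actions += 1  # every tool call counts as one action
        if action.tool in ("sql_db_query_checker", "sql_db_query"):
            self.sql.append(action.tool_input)


demo_handler = RecordingHandler()
demo_sql = "SELECT COUNT(*) FROM patients"

# Fake actions: the agent first checks the query, then runs it.
demo_handler.on_agent_action(SimpleNamespace(tool="sql_db_query_checker", tool_input=demo_sql))
demo_handler.on_agent_action(SimpleNamespace(tool="sql_db_query", tool_input=demo_sql))

print(demo_handler.num_actions)  # 2
print(demo_handler.sql[-1])      # SELECT COUNT(*) FROM patients
```

Taking the last recorded entry means that, if the agent checks a query and then runs it, the executed query is the one reported.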

In [13]:
notes = []

In [14]:
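def create_prompt(notes, DDL, question: str):
    # Stage 1: Jinja2 fills in {{ table_info }} and the {% for tip in tips %} loop.
    # The single-brace placeholders ({tools}, {tool_names}, {input},
    # {agent_scratchpad}) pass through untouched.
    prompt_0 = jenv.from_string(prompt_template).render(tips=notes,
                                                        table_info=DDL)
    print(f"prompt0:\n{prompt_0}")
    # Stage 2: LangChain's PromptTemplate fills the remaining single-brace
    # placeholders when the agent runs.
    prompt = PromptTemplate.from_template(prompt_0)
    return prompt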


In [16]:
def answer_standalone_question(question: str, messages: List[List[str]]) -> str:
    start_time = time()
    print(f"\n>>>> answer_question {question}")
    if is_conversational and messages:
        question = decontextualize_question(question, messages)
    handler = SQLHandler()
    try:
        agent_executor = create_sql_agent(
            llm=llm,
            toolkit=SQLDatabaseToolkit(db=db, llm=llm),
            verbose=True,
            prompt=create_prompt(notes, DDL, question),
            agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
            callbacks=[handler],
            handle_parsing_errors=True)
        print(f"++++++ about to invoke")
        for iteration in itertools.count(0):
            try:
                answer = agent_executor.invoke(input={"input": question}, config={"callbacks": [handler]})
                # print(f"++++++ finished invoke: {type(answer)} {answer}")
                # print(f"sql results: {handler.sql_results()}")
                duration = time() - start_time
                iter_str = f", {iteration} retries" if iteration > 0 else ""
                history_str = f", history {len(messages):,}" if len(messages) > 0 else ""
                sql_result = handler.sql_results()[-1].strip() if len(handler.sql_results()) > 0 else None
                print(f"sql_result: {sql_result}")
                SQL_str = f"\n ```{sql_result}```" if show_SQL and sql_result else ""
                return answer['output'], f"{duration:.1f} secs, {handler.num_tool_actions():,} actions{iter_str}{history_str} {SQL_str}"
            except ValueError as ex:
                if iteration < 10:
                    print(f"iteration #{iteration}: caught {ex}")
                    print("retrying")
                else:
                    raise
    except Exception as ex:
        print(f"answer_standalone_question failed: {ex}")
        raise


In [17]:
def answer_multiple_questions(questions: List[str]) -> List[Tuple[str, str]]:
    messages = []
    answers = []
    for question in questions:
        answer, extra_info = answer_standalone_question(question, messages)
        answers.append(answer)
        messages.append([question, answer])
    return list(zip(questions, answers))

In [18]:
answer_standalone_question("How many patients have a BMI over 20 and are older than 30?", [])


>>>> answer_question How many patients have a BMI over 20 and are older than 30?
prompt0:
Answer the following questions as best you can. You have access to the following tools:

{tools}

Use the following format:

Question: the input question you must answer
Thought: you should always think about what to do
Action: the action to take, should be one of [{tool_names}]
Action Input: the input to the action
Observation: the result of the action
... (this Thought/Action/Action Input/Observation can repeat N times)
Thought: I now know the final answer
Final Answer: the final answer to the original input question

You might find the following tips useful:

The database has the following single table:

CREATE TABLE "patients" (
"index" INTEGER,
  "Pregnancies" INTEGER,
  "Glucose" INTEGER,
  "BloodPressure" INTEGER,
  "SkinThickness" INTEGER,
  "Insulin" INTEGER,
  "BMI" REAL,
  "DiabetesPedigreeFunction" REAL,
  "Age" INTEGER,
  "Outcome" INTEGER
)

You should NEVER have to use either the s

('There are 347 patients who have a BMI over 20 and are older than 30.',
 '12.1 secs, 2 actions \n ```SELECT COUNT(*) FROM patients WHERE BMI > 20 AND Age > 30;```')

In [19]:
answer_multiple_questions(["How many patients have a BMI over 20 and are older than 30?",
                           "How many are over 50?"])


>>>> answer_question How many patients have a BMI over 20 and are older than 30?
prompt0:
Answer the following questions as best you can. You have access to the following tools:

{tools}

Use the following format:

Question: the input question you must answer
Thought: you should always think about what to do
Action: the action to take, should be one of [{tool_names}]
Action Input: the input to the action
Observation: the result of the action
... (this Thought/Action/Action Input/Observation can repeat N times)
Thought: I now know the final answer
Final Answer: the final answer to the original input question

You might find the following tips useful:

The database has the following single table:

CREATE TABLE "patients" (
"index" INTEGER,
  "Pregnancies" INTEGER,
  "Glucose" INTEGER,
  "BloodPressure" INTEGER,
  "SkinThickness" INTEGER,
  "Insulin" INTEGER,
  "BMI" REAL,
  "DiabetesPedigreeFunction" REAL,
  "Age" INTEGER,
  "Outcome" INTEGER
)

You should NEVER have to use either the s

[('How many patients have a BMI over 20 and are older than 30?',
  'There are 347 patients in the database who have a BMI over 20 and are older than 30 years old.'),
 ('How many are over 50?',
  'According to the query, there are 329 patients in the database who have a BMI over 20, are older than 30 years, and have a blood pressure reading over 50.')]