# Llama-Index Text-To-SQL Retrieval Agent
### Thoughts:
- Performance is too inconsistent
- Easily makes up facts when the query returns no results
- Doesn't really grasp the full context of the data structure and its meaning
- The underlying functionality is difficult to modify, particularly the prompt template for the text-to-SQL step that runs before response synthesis
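
One possible mitigation for the fabrication issue noted above is to skip response synthesis entirely when the query returns no rows. The following is only a minimal sketch: `answer_from_rows`, `NO_RESULTS_MESSAGE`, and the `synthesize` callable are hypothetical names standing in for the LLM call, and none of them are part of LlamaIndex.

```python
from typing import Callable, Sequence

# Hypothetical fixed reply used instead of calling the LLM on an empty result set.
NO_RESULTS_MESSAGE = "The query returned no rows, so there is nothing to summarise."


def answer_from_rows(
    rows: Sequence[tuple],
    synthesize: Callable[[Sequence[tuple]], str],
) -> str:
    """Call the (assumed) synthesis step only when there is data to ground it."""
    if not rows:
        # No data: return a fixed message rather than letting the model improvise.
        return NO_RESULTS_MESSAGE
    return synthesize(rows)


# Usage with a stand-in synthesizer that just counts rows.
print(answer_from_rows([], lambda rows: f"{len(rows)} row(s) found"))
print(answer_from_rows([("a",), ("b",)], lambda rows: f"{len(rows)} row(s) found"))
```

The guard does not make the synthesis step itself more factual; it only removes the case where the model has nothing to ground its answer on.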

In [1]:
import os
from dotenv import load_dotenv
from IPython.display import Markdown, display
import pandas as pd

from llama_index.core import SQLDatabase
from llama_index.llms.openai import OpenAI
from llama_index.core.indices.struct_store.sql_query import (
    SQLTableRetrieverQueryEngine,
)
from llama_index.core.objects import (
    SQLTableNodeMapping,
    ObjectIndex,
    SQLTableSchema,
)
from llama_index.core import VectorStoreIndex, PromptTemplate

from src.db.database import engine
from src.db import models


load_dotenv()


OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")


llm = OpenAI(temperature=0.1, model="gpt-4o-mini", api_key=OPENAI_API_KEY)

sql_database = SQLDatabase(engine)

table_node_mapping = SQLTableNodeMapping(sql_database)

table_schema_objs = [
    (SQLTableSchema(table_name=table.__tablename__, context_str=table.__context_str__)) 
    for table in models.__dict__.values() if hasattr(table, '__tablename__')
]

obj_index = ObjectIndex.from_objects(
    table_schema_objs,
    table_node_mapping,
    VectorStoreIndex,
)

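response_synthesis_prompt_str = (
    # Adjacent string literals are concatenated, so no backslash continuation is needed
    # (the backslash previously embedded the next line's indentation in the prompt).
    "Given an input question, synthesize a response from the query results. "
    "You must ensure your response is completely factual.\n"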
    "<query>{query_str}</query>\n"
    "<sql>{sql_query}</sql>\n"
    "<sql response>SQL Response: {context_str}</sql response>\n"
    "Response: "
)
response_synthesis_prompt = PromptTemplate(
    response_synthesis_prompt_str,
)

query_engine = SQLTableRetrieverQueryEngine(
    sql_database, obj_index.as_retriever(similarity_top_k=1),
    response_synthesis_prompt=response_synthesis_prompt,
)

query = "What are the fields in the meetings table and what do they represent contextually?"
# query = "Using just your provided system messaging and without using SQL, \
#     What are the fields in the meetings table and what do they represent contextually?"
# query = "What is the name of the firm that has the most meetings and how many meets do they have?"
# query = "Can you show me the first 5 rows of meetings?"
# query = "Fetch the first 5 meetings and their content which have a firm attended that are in the Energy sector."
response = query_engine.query(query)

print("SQL Query:")
print("```\n" + response.metadata["sql_query"] + "\n```")
print("Response:")
display(Markdown(f"<b>{response}</b>"))
if "result" in response.metadata:
    display(pd.DataFrame(response.metadata["result"], columns=response.metadata["col_keys"]))

SQL Query:
```
SELECT column_name, data_type, description
FROM information_schema.columns
WHERE table_name = 'meetings';
```
Response:


<b>I'm sorry, but there was an error in retrieving the information about the fields in the meetings table. The query was invalid and the column "description" does not exist in the table.</b>
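
The failure above comes from the model querying a `description` column that `information_schema.columns` does not have. Questions about the schema itself can be answered from database metadata directly, without generating SQL through the LLM. The sketch below illustrates the idea with an in-memory SQLite database as a stand-in for the real PostgreSQL database; the column types are assumptions, and `describe_table` is a hypothetical helper.

```python
import sqlite3


def describe_table(conn: sqlite3.Connection, table_name: str) -> list[tuple[str, str]]:
    """Return (column_name, declared_type) pairs for a table, in definition order."""
    # pragma_table_info is SQLite's table-valued form of PRAGMA table_info,
    # which allows the table name to be passed as a bound parameter.
    cursor = conn.execute(
        "SELECT name, type FROM pragma_table_info(?) ORDER BY cid",
        (table_name,),
    )
    return [(name, col_type) for name, col_type in cursor.fetchall()]


# Stand-in schema: the column names appear in this notebook, the types are assumed.
conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE meetings (beam_id TEXT PRIMARY KEY, date TEXT, firm_attended_id INTEGER)"
)

for name, col_type in describe_table(conn, "meetings"):
    print(f"{name}: {col_type}")
```

Contextual meaning would still have to come from somewhere else, for example the `__context_str__` attributes already attached to the models.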

# Custom Simplified Implementation
- Much slower
- Has verbose chain-of-thought reasoning
- Still has issues constructing queries
- Need to consider how the information is presented back to the user in a memory-friendly way
    - Could it return just the beam_ids as part of the retrieval?
        - Could these be added to the user's 'meetings in-focus' view?
    - Could return the results as markdown (big context-size issue)
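
To make the context concern concrete, one option is to cap the number of rows rendered as markdown and report how many were left out. This is a rough sketch only: `rows_to_markdown` and its `max_rows` parameter are hypothetical and not part of the agent defined in this notebook.

```python
from typing import Sequence


def rows_to_markdown(columns: Sequence[str], rows: Sequence[tuple], max_rows: int = 5) -> str:
    """Render at most `max_rows` rows as a markdown table, noting any omitted rows."""
    header = "| " + " | ".join(columns) + " |"
    divider = "|" + "|".join(" --- " for _ in columns) + "|"
    body = ["| " + " | ".join(str(value) for value in row) + " |" for row in rows[:max_rows]]
    lines = [header, divider, *body]
    omitted = len(rows) - max_rows
    if omitted > 0:
        # Tell the reader (or the LLM) that the table is truncated.
        lines.append(f"... {omitted} more row(s) omitted")
    return "\n".join(lines)


# Usage with placeholder ids rather than real data.
example_rows = [("id-1",), ("id-2",), ("id-3",)]
print(rows_to_markdown(["meetings.beam_id"], example_rows, max_rows=2))
```

Returning only identifiers keeps each row small, and the cap bounds the total size regardless of how many rows the query matches.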

In [10]:
from pydantic import BaseModel
from typing import Optional
from sqlalchemy.orm import Session
from sqlalchemy import text
from tenacity import retry, stop_after_attempt

from src.db.database import session_scope

class Thoughts(BaseModel):
    """
    Use this class to think through the problem step-by-step towards constructing the SQL query.

    Attributes:
        thoughts: str - Thinking out loud about what you need to consider to facilitate the user's query.
        outcome: str - What will need to be added to the SQL query to get the desired result based on the thoughts.
    """
    thoughts: str
    outcome: str

class TextToSQL(BaseModel):
    """
    Use this to construct the SQL query to retrieve the correct information from the database according to the user's query.

    Attributes:
        steps: list[Thoughts] - List of thoughts to consider to construct the SQL query.
        joins: Optional[list[str]] - List of joins you will need to add to the SQL query if required (default is an empty list).
        filters: Optional[list[str]] - List of filters you will need to add to the SQL query if required (default is an empty list).
        fields: list[str] - List of fields to add to the SQL query, constructed as 'table_name.field_name'.
        query: str - The final SQL query to be executed.
        possible: bool - If the query is not possible to execute, set this to False.
    """
    steps: list[Thoughts]
    joins: Optional[list[str]] = []
    filters: Optional[list[str]] = []
    fields: list[str] = []
    query: str
    possible: bool

prompt_template = PromptTemplate(
    """Given an input question, you are to create a syntactically correct postgresql \
query to run that will provide the user with the data relevant to their question. \

Pay attention to use only the column names that you can see in the provided schema \
below between the <schema></schema> xml tags which is directly from a sqlalchemy python file. \
Be careful to not query for columns that do not exist. \
Pay attention to which column is in which table. Also, qualify column names \
with the table name when needed.

Only use tables listed below.
<schema>{schema}</schema>

Your response must be formatted according to the structured output format provided.
Ensure that your sql query is well formatted for readability and syntactically correct.

User Question: {query}
"""
)

class SQLAgent:

    _validation_words: list[str] = ["delete", "update", "insert", "drop"]

    def __init__(
            self,
            llm,
            schema_file_path: str,
            output_template: type[TextToSQL] = TextToSQL,
            prompt_template: PromptTemplate = prompt_template,
            verbose: bool = False,
    ):
        self.llm = llm.as_structured_llm(output_template)
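        # Read the schema file with a context manager so the file handle is closed.
        with open(schema_file_path, "r") as schema_file:
            self.schemas_file_str = schema_file.read()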
        self.output_template = output_template
        self.prompt_template = prompt_template
        self.verbose = verbose

    def _complete(self, query: str) -> TextToSQL:
        response = self.llm.complete(self.prompt_template.format(query=query, schema=self.schemas_file_str)).raw
        return response
    
    def _show_cot(self, response: TextToSQL):
        print("CHAIN OF THOUGHTS:")
        for step in response.steps:
            print(f"Thoughts: {step.thoughts}")
            print(f"Outcome: {step.outcome}", "\n")

    def _show_sql_query(self, response: TextToSQL):
        print("SQL QUERY:")
        print("```\n" + response.query + "\n```")

    def _valid_query(self, response: TextToSQL) -> bool:
        for word in self._validation_words:
            if word in response.query.lower():
                return False
        return True
    
    def _return_not_possible_thoughts(self, response: TextToSQL) -> str:
        # Build the message in its own variable so the `response` argument is not overwritten
        # before its steps are read.
        message = "The SQL Agent has determined that the query is not possible to execute. Here are the thoughts that led to this conclusion:\n"
        for step in response.steps:
            message += f"Thoughts: {step.thoughts}\n"
            message += f"Outcome: {step.outcome}\n\n"
        return message
    
    def _execute_query(self, session: Session, response: TextToSQL) -> pd.DataFrame:
        result = session.execute(text(response.query)).fetchall()
        columns = response.fields
        return pd.DataFrame(result, columns=columns)

    @retry(stop=stop_after_attempt(3))
    def complete(self, session: Session, query: str):
        session.flush() # Ensure that the session is clean before executing the query.
        response = self._complete(query)
        if self.verbose:
            self._show_cot(response)
            self._show_sql_query(response)
        if not self._valid_query(response):
            raise NotImplementedError("Handling of invalid (non-read-only) queries is not implemented yet.")
        if not response.possible:
            return self._return_not_possible_thoughts(response)
        results = self._execute_query(session, response)
        return results


agent = SQLAgent(llm, "src/db/models.py", verbose=True)

# query = "I need all meetings between 2022-01-01 and 2023-01-01 where the firms that attended are in the Energy sector."
query = "Return the beam_ids of all meetings between 2022-01-01 and 2023-01-01 where the firms that attended are in the Energy sector."

with session_scope() as session:
    response = agent.complete(session, query)

response

CHAIN OF THOUGHTS:
Thoughts: I need to retrieve the beam_ids from the meetings table where the date is between 2022-01-01 and 2023-01-01. Additionally, I need to filter the results based on the firms that attended the meetings, specifically those in the Energy sector.
Outcome: I will need to join the meetings table with the firms table through the firm_attended_id to filter by the sector. 

Thoughts: The meetings table has a firm_attended_id that links to the firms table. I need to ensure that I filter the firms based on the sector being 'Energy'.
Outcome: I will add a filter for firms.sector = 'Energy'. 

Thoughts: The date filter should be applied to the meetings.date column to ensure the meetings fall within the specified range.
Outcome: I will add a filter for meetings.date between '2022-01-01' and '2023-01-01'. 

SQL QUERY:
```
SELECT meetings.beam_id
FROM meetings
JOIN firms ON meetings.firm_attended_id = firms.firm_id
WHERE meetings.date BETWEEN '2022-01-01' AND '2023-01-01'
  A
```

|    | meetings.beam_id                     |
|---:|:-------------------------------------|
|  0 | 7b61cd78-99bd-4870-84e9-8fe91e4df949 |
|  1 | 0a59596b-b0e9-400f-9535-4424728ded7f |
|  2 | 64820540-2221-4b1c-939b-7fe896ebe3bb |
|  3 | 99a999d7-adb2-4164-bc6e-83faecbed78e |
|  4 | 872b7e38-19ef-4568-aeca-7d1a6b95e0ed |
|  5 | d96cfe90-f9a1-4214-891b-ccce929bce41 |
|  6 | 1f0cad65-204b-4dfe-8594-6dc667646301 |
|  7 | 05c599e7-81cc-4ab4-8bbe-d1d14ae3e0ff |
|  8 | 57f87c6c-3065-4394-8260-c7d04bc7957a |
|  9 | 718f443c-6d8e-4d36-9b60-744994ff0f21 |


In [11]:
response_md = response.to_markdown(index=True)
print(response_md)

|    | meetings.beam_id                     |
|---:|:-------------------------------------|
|  0 | 7b61cd78-99bd-4870-84e9-8fe91e4df949 |
|  1 | 0a59596b-b0e9-400f-9535-4424728ded7f |
|  2 | 64820540-2221-4b1c-939b-7fe896ebe3bb |
|  3 | 99a999d7-adb2-4164-bc6e-83faecbed78e |
|  4 | 872b7e38-19ef-4568-aeca-7d1a6b95e0ed |
|  5 | d96cfe90-f9a1-4214-891b-ccce929bce41 |
|  6 | 1f0cad65-204b-4dfe-8594-6dc667646301 |
|  7 | 05c599e7-81cc-4ab4-8bbe-d1d14ae3e0ff |
|  8 | 57f87c6c-3065-4394-8260-c7d04bc7957a |
|  9 | 718f443c-6d8e-4d36-9b60-744994ff0f21 |
| 10 | d7719519-0626-4a5e-b26a-28913730a025 |
| 11 | 13cfee19-3644-4798-82bf-2af6b327d6ce |
| 12 | 9e6b8d61-ec62-4501-b43e-5219b6363726 |
| 13 | e76c1a29-cd63-426f-ab86-fc27447ea06d |
| 14 | 313fad11-fc6e-4f59-8e96-48c36e704ef4 |
| 15 | 6248a04b-65ea-444f-b521-c5479f114bfb |
| 16 | 73c39876-689f-42e1-9abd-afe471ee8869 |
| 17 | 6ed95bc2-2d95-4489-aef8-c67591ce219c |
| 18 | 3107e204-c709-4e24-9f07-f39eb1294773 |
| 19 | 5bb656aa-ec86-4090-8a57-0d8