Instructions before running this notebook
---

1. Follow the instructions in the `README` at the root of this repository to set up the required packages for Ask-a-Metric.

2. Create an account on vanna.ai and get an API key (skip if you already have a [vanna.ai](https://vanna.ai) account)

3. On your vanna.ai account, create **2 RAG models**. If you are running this notebook with existing models, reset their training data first, so that you are working with "fresh" models.

4. Locally in your AAM conda environment / poetry shell, run `pip install vanna pandas`

5. Inside the `experiments/` folder, create a `.env` file with the following env variables:
    ```
    VANNA_API_KEY=<your vanna.ai API key, downloaded in step 2>

    VANNA_RAG_MODEL_WITH_SCHEMA=<name of first RAG model you created in step 3>
    
    VANNA_RAG_MODEL_WITHOUT_SCHEMA=<name of second RAG model you created in step 3>
    
    OPENAI_API_KEY=<your OpenAI API key>
    ```


In [None]:
%load_ext autoreload
%autoreload 2

## Imports

In [None]:
import os
import pandas as pd
import time
import asyncio

from askametric.query_processor.query_processor import LLMQueryProcessor
from askametric.utils import _ask_llm_json

from vanna.openai import OpenAI_Chat
from vanna.vannadb import VannaDB_VectorStore

from litellm import cost_per_token, token_counter
from dotenv import load_dotenv

from sqlalchemy.orm import sessionmaker
from sqlalchemy.ext.asyncio import (
    AsyncSession,
    create_async_engine,
)

# Load environment variables (the API keys and Vanna RAG model names read
# below via os.getenv) from the `.env` file; the path is relative, so it is
# resolved against the current working directory
load_dotenv(".env")


## Set parameters

In [None]:
# General Parameters
OPEN_AI_API_KEY = os.getenv('OPENAI_API_KEY')
WHICH_DB = "tn_covid_cases_11_may"
LLM_MODEL = "gpt-4o"

# AAM-specific Parameters
GUARDRAILS_LLM_MODEL = "gpt-4o"
# The long prompts below are built from adjacent string literals inside
# parentheses. A backslash continuation *inside* a string literal would drop
# the newline but keep the next line's indentation, which injects stray
# whitespace or glues separate bullet points together.
SYSTEM_MESSAGE = (
    "Government and health officials in Tamil Nadu, India will ask you questions. "
    "You need to help them manage active COVID cases and the availability of beds in health facilities."
)
# One bullet per table, each on its own line.
# NOTE(review): the first table is named `..._may_11` while the other names use
# `..._11_may` -- confirm this matches the actual table name in the database.
DB_TABLE_DESCRIPTION = (
    "- bed_vacancies_clinics_may_11: Each row identifies a district and the beds earmarked, occupied and available for COVID cases in the district clinics.\n"
    "- bed_vacancies_health_centers_and_district_hospitals_11_may: Each row identifies a district and the beds earmarked, occupied and available, with and without oxygen supply, and with and without ICU support, for COVID cases in the district health centers and hospitals.\n"
    "- covid_cases_11_may: Each row identifies a district and the number of people who received treatment, were discharged and died due to COVID."
)
DB_COLUMN_DESCRIPTION = ""
NUM_COMMON_VALUES = 10
DB_PATH = "../demo_databases/tn_covid_cases_11_may.sqlite"


# Vanna-specific Parameters
VANNA_AI_API_KEY = os.getenv("VANNA_API_KEY")
VANNA_RAG_MODEL_WITH_SCHEMA = os.getenv('VANNA_RAG_MODEL_WITH_SCHEMA')
VANNA_RAG_MODEL_WITHOUT_SCHEMA = os.getenv('VANNA_RAG_MODEL_WITHOUT_SCHEMA')
# these parameters are set to imitate the AAM setup as closely as possible,
# to allow a fair comparison between the two systems
ALLOW_AUTO_TRAIN_ON_QUERIES = False
VISUALIZE = False
ALLOW_LLM_TO_SEE_DATA_DURING_RAG = True

## Set up evaluation prompts for text responses

In [None]:
# User-message template sent to the grading LLM; the four placeholders are
# filled with str.format for every graded question.
text_eval_message_template = """
----Message Begins----------------
Question: {question}
Answer: {text_response}
Correct Answer: {correct_answer}
Correct Language: {correct_language}
----Message Ends----------------
"""

# System prompt for the grading LLM. It is passed as-is (never run through
# str.format), which is why the literal JSON braces at the end are safe.
SYSTEM_MESSAGE_TEXT_EVAL = """
You are a grading bot. You will get messages in the following format -

----Message Begins----------------
Question: ```<Some Question>```
Answer: ```<Answer to be graded>```
Correct Answer: ```<Correct Answer to Question>```
Correct Language: ```<Correct Language the "Answer" should be in with the language script if relevant>```
----Message Ends----------------

Give grades based on the following two points:
(a) Accuracy and Relevancy:
Is the "Answer" similar in meaning to "Correct Answer"?
Does the "Answer" address the key elements of the "Question"?
If yes, give a grade of 1, otherwise 0.
Remember, "Answer" and "Correct Answer" ONLY NEED TO BE SIMILAR in general meaning.
Look at the NUMBERS in the response to make your judgement.
(b) Language and Script:
Is "Answer" in the "Correct Language"? Is "Answer" in the same script as "Question"?
If yes, give a grade of 1, otherwise 0.

REMEMBER: Scores in each category (a or b) can ONLY be 0 or 1

Reply in the following json format -
{"accuracy_and_relevancy": <your grade for accuracy and relevancy>,
"language_and_script": <your grade for language and script>}
"""


## Set up utility functions for Vanna

In [None]:
# Util functions for Vanna
def setup_vanna_ai(
        vanna_rag_model,
        train_with_sql_schema: bool = False,
        system_message: str = SYSTEM_MESSAGE,
        db_path: str = DB_PATH):
    """
    Create a Vanna.ai client, connect it to the SQLite database and train it.

    The client is always trained on one documentation entry made of the
    system message plus the table and column descriptions. The database
    schema is added to the training data only on request.

    Args:
        vanna_rag_model: the RAG model to use for Vanna
        train_with_sql_schema: whether to train the model with the SQL schema
        system_message: the system message for Vanna.ai
        db_path: the path to the database

    Returns: the connected and trained Vanna client
    """
    class MyVanna(VannaDB_VectorStore, OpenAI_Chat):
        # Both parent initialisers are invoked explicitly, each with the
        # arguments it needs.
        def __init__(self, config=None):
            VannaDB_VectorStore.__init__(
                self,
                vanna_model=vanna_rag_model,
                vanna_api_key=VANNA_AI_API_KEY,
                config=config,
            )
            OpenAI_Chat.__init__(self, config=config)

    llm_config = {'api_key': OPEN_AI_API_KEY, 'model': LLM_MODEL}
    vanna_client = MyVanna(config=llm_config)
    vanna_client.connect_to_sqlite(db_path)

    # One documentation entry: system message, table notes and column notes,
    # separated by blank lines.
    documentation = "\n\n".join(
        [system_message, DB_TABLE_DESCRIPTION, DB_COLUMN_DESCRIPTION]
    )
    vanna_client.train(documentation=documentation)

    if not train_with_sql_schema:
        return vanna_client

    # Feed every stored DDL statement of the database to Vanna
    schema_df = vanna_client.run_sql(
        "SELECT type, sql FROM sqlite_master WHERE sql is not null"
    )
    for ddl_statement in schema_df['sql'].to_list():
        vanna_client.train(ddl=ddl_statement)

    return vanna_client


async def get_vanna_reponse(vn, question: pd.DataFrame | str):
    """
    Send a single question to vanna.ai, estimate its cost and grade the answer.

    Args:
        vn: the Vanna model
        question: either the question text, or a row of the evaluation
            questions exposing `question`, `correct_answer` and `language`

    Returns: a dictionary with the question, Vanna's SQL query and text
        response, the estimated cost, the response time and the two grades
        (both grades are 'N/A' when `question` is a plain string, since there
        is no reference answer to grade against)
    """
    has_reference_answer = not isinstance(question, str)
    question_text = question.question if has_reference_answer else question

    # Get response
    tic = time.time()  # record the start time
    vn_answer = vn.ask(
        question_text,
        print_results=False,
        auto_train=ALLOW_AUTO_TRAIN_ON_QUERIES,
        visualize=VISUALIZE,
        allow_llm_to_see_data=ALLOW_LLM_TO_SEE_DATA_DURING_RAG
        )
    response_time = time.time() - tic  # calculate the response time

    # Approximate the cost from a rebuilt SQL prompt and the stringified answer.
    # NOTE(review): the value returned by `get_sql_prompt` is passed as `text`
    # and no model is given to `token_counter` -- confirm this matches what
    # litellm expects for an accurate count.
    prompt_token_count = token_counter(
        text = vn.get_sql_prompt(
            initial_prompt="",
            question=question_text,
            question_sql_list=[],
            ddl_list=[],
            doc_list=[DB_TABLE_DESCRIPTION]
        )
    )
    response_token_count = token_counter(text=str(vn_answer))

    cost = cost_per_token(
        model=LLM_MODEL,
        prompt_tokens=prompt_token_count,
        completion_tokens=response_token_count,
    )

    # When a query cannot be processed Vanna may return None, or a tuple whose
    # SQL / result entries are None. A non-empty tuple is always truthy, so each
    # entry is checked individually; otherwise a failure would be recorded
    # (and graded) as the literal string "None".
    generated_sql, result_data = (
        (vn_answer[0], vn_answer[1]) if vn_answer else (None, None)
    )
    sql_response = "" if generated_sql is None else str(generated_sql)
    text_response = "" if result_data is None else str(result_data)

    if has_reference_answer:
        evaluation_json = await _ask_llm_json(
                prompt=text_eval_message_template.format(
                    question=question_text,
                    text_response=text_response,
                    correct_answer=question.correct_answer,
                    correct_language=question.language
                    ),
                system_message=SYSTEM_MESSAGE_TEXT_EVAL,
                llm=LLM_MODEL)
        evaluation_json = evaluation_json['answer']
    else:
        evaluation_json = {'accuracy_and_relevancy': 'N/A',
                           'language_and_script': 'N/A'}

    # Save only the sql query and text outputs along with the cost,
    # response time and grades
    response = {
        "question": question_text,
        "vanna_cost": sum(cost),
        "vanna_response_time": response_time,
        "vanna_text_response": text_response,
        "vanna_sql_query": sql_response,
        "vanna_accuracy_and_relevancy": evaluation_json['accuracy_and_relevancy'],
        "vanna_language_and_script": evaluation_json['language_and_script']
    }
    return response


## Set up utility functions for AAM

In [None]:
def get_db_asession(db_path: str = DB_PATH):
    """
    Build the async session factory used to reach the SQLite database.

    Args:
        db_path: the path to the database

    Returns: a session factory bound to an async (aiosqlite) engine that
        produces `AsyncSession` objects
    """
    db_url = f"sqlite+aiosqlite:///{db_path}"
    return sessionmaker(
        bind=create_async_engine(url=db_url),
        class_=AsyncSession,
        expire_on_commit=False,
    )

async def get_aam_reponse(question: pd.DataFrame | str,
                          asession: AsyncSession,
                          system_message: str = SYSTEM_MESSAGE,
                          db_description: str = DB_TABLE_DESCRIPTION,
                          ):
    """
    Run one question through the AAM pipeline and grade the answer.

    Args:
        question: either the question text, or a validation question exposing
            `question`, `correct_answer` and `language`
        asession: callable that opens an async database session
        system_message: the system message for AAM
        db_description: the description of the database

    Returns: a dictionary with the question, AAM's SQL query and text
        response, its cost, the response time and the two grades (both
        'N/A' when `question` is a plain string)
    """
    start_time = time.time()
    is_plain_text = isinstance(question, str)
    query_text = question if is_plain_text else question.question

    async with asession() as session:
        qp = LLMQueryProcessor(
            query={"query_text": query_text, "query_metadata": {}},
            asession=session,
            which_db=WHICH_DB,
            llm=LLM_MODEL,
            guardrails_llm=GUARDRAILS_LLM_MODEL,
            sys_message=system_message,
            db_description=db_description,
            column_description=DB_COLUMN_DESCRIPTION,
            num_common_values=NUM_COMMON_VALUES
        )
        try:
            await qp.process_query()
        except Exception as e:
            # Best effort: keep the run going and report the error text as
            # the answer, with no SQL query
            qp.final_answer = str(e)
            qp.sql_query = ""

        # Grading time is deliberately excluded from the response time
        response_time = time.time() - start_time

        if is_plain_text:
            # No reference answer available, so nothing to grade
            grades = {'accuracy_and_relevancy': 'N/A',
                      'language_and_script': 'N/A'}
        else:
            grading_prompt = text_eval_message_template.format(
                question=question.question,
                text_response=qp.final_answer,
                correct_answer=question.correct_answer,
                correct_language=question.language
            )
            llm_reply = await _ask_llm_json(
                prompt=grading_prompt,
                system_message=SYSTEM_MESSAGE_TEXT_EVAL,
                llm=LLM_MODEL)
            grades = llm_reply['answer']

        # Keep only the sql query and text outputs along with cost, timing
        # and grades
        return {
            "question": query_text,
            "aam_cost": qp.cost,
            "aam_response_time": response_time,
            "aam_text_response": qp.final_answer,
            "aam_sql_query": qp.sql_query,
            "aam_accuracy_and_relevancy": grades["accuracy_and_relevancy"],
            "aam_language_and_script": grades["language_and_script"]
        }


## Prepare Vanna.ai and AAM models

In [None]:
# ---- Run this block ONCE -----
# Two Vanna instances are trained for the subsequent experiments: one without
# and one with the SQL schema.
# NB: the RAG models are expected to be FRESH on each run of this notebook, so
# there is no need to worry about overwriting models with the same name, or
# about working with pre-trained models
vn_without_schema = setup_vanna_ai(
    vanna_rag_model=VANNA_RAG_MODEL_WITHOUT_SCHEMA,
    train_with_sql_schema=False,
)
vn_with_schema = setup_vanna_ai(
    vanna_rag_model=VANNA_RAG_MODEL_WITH_SCHEMA,
    train_with_sql_schema=True,
)

# Session factory through which AAM reaches the database
async_session = get_db_asession()


## Comparison experiments

In [None]:
# Read the evaluation questions, then discard rows in which every column is empty
raw_eval_questions = pd.read_csv(
    "tn_covid_val_questions.csv",
    skip_blank_lines=True,
)
eval_questions = raw_eval_questions.dropna(how="all")
eval_questions

### Run through basic, language and guardrail questions

In [None]:
# Get responses from AAM
tasks = [get_aam_reponse(eval_row, async_session) for _, eval_row in eval_questions.iterrows()]
aam_responses = await asyncio.gather(*tasks)
aam_responses = pd.DataFrame(aam_responses)
aam_responses

In [None]:
# Get responses from Vanna model without schema
tasks = [get_vanna_reponse(vn_without_schema, eval_row)
         for _, eval_row in eval_questions.iterrows()]
vn_without_schema_responses = await asyncio.gather(*tasks)
vn_without_schema_responses = pd.DataFrame(vn_without_schema_responses)

vn_without_schema_responses

In [None]:
# Get responses from Vanna model with schema
tasks = [get_vanna_reponse(vn_with_schema, eval_row)
         for _, eval_row in eval_questions.iterrows()]
vn_with_schema_responses = await asyncio.gather(*tasks)
vn_with_schema_responses = pd.DataFrame(vn_without_schema_responses)

vn_with_schema_responses

In [None]:
# Join the three result tables (and the evaluation questions) on the question
# text, then write the combined table to disk
vanna_merge = vn_without_schema_responses.merge(
    vn_with_schema_responses,
    on="question",
    suffixes=("_without_schema", "_with_schema"),
)
all_response = pd.merge(left=vanna_merge, right=aam_responses, on="question")
all_response_plus_eval = pd.merge(
    left=all_response, right=eval_questions, on="question"
)
all_response_plus_eval.to_csv("tn_covid_all_responses.csv", index=False)

### Compare accuracy, relevancy and language of responses

In [None]:
# For each grading category, print the mean of every column whose name
# contains that category
score_sections = (
    ("Accuracy and Relevancy Scores", "accuracy_and_relevancy"),
    ("Language and Script Scores", "language_and_script"),
)
for heading, column_tag in score_sections:
    print(heading)
    print(all_response.filter(like=column_tag).mean(axis=0))

### Compare cost and response time

In [None]:
# Mean cost of a single query (in USD)

# Caveat: this might NOT be a fair comparison. For vanna.ai only input and
# output token costs are counted (training and intermediate RAG queries are
# left out), whereas the AAM figure covers every OpenAI API call in the pipeline.

cost_columns = all_response.filter(like="cost")
cost_columns.mean(axis=0)

In [None]:
# Mean response time of a single query (in s)

response_time_columns = all_response.filter(like="response_time")
response_time_columns.mean(axis=0)

## Compare how easy it is to make quick updates to the LLM

In [None]:
# Goal: change the system prompt so that the LLM defaults to answering about
# Chennai whenever the query does not name a location. This matters for AAM,
# since user queries are often ambiguous in this way.

# Append the extra instruction to the original system message on a new line
updated_system_message = "\n".join([
    SYSTEM_MESSAGE,
    "REMEMBER: If the user query does not specify a district, assume it is about Chennai.",
])


In [None]:
# Retrain vanna.ai with the updated system message

# Wipe the existing training data: first for the model without the schema,
# then for the model with the schema
for stale_model in (vn_without_schema, vn_with_schema):
    for training_id in stale_model.get_training_data().id.values:
        stale_model.remove_training_data(id=training_id)

# Train the model without the schema again
vn_without_schema = setup_vanna_ai(
    vanna_rag_model=VANNA_RAG_MODEL_WITHOUT_SCHEMA,
    train_with_sql_schema=False,
    system_message=updated_system_message,
)

# Train the model with the schema again
vn_with_schema = setup_vanna_ai(
    vanna_rag_model=VANNA_RAG_MODEL_WITH_SCHEMA,
    train_with_sql_schema=True,
    system_message=updated_system_message,
)


In [None]:
# Get response from AAM, Vanna without schema, and Vanna with schema
# with ambiguous queries

question = "How many ICU beds?"

# AAM
aam_response = await get_aam_reponse(question, async_session, updated_system_message)

# Vanna without schema
vn_without_schema_response = await get_vanna_reponse(vn_without_schema, question)

# Vanna with schema
vn_with_schema_response = await get_vanna_reponse(vn_with_schema, question)


In [None]:
# Show the three text answers one after another for a side-by-side read
final_answers = {
    "AAM": aam_response['aam_text_response'],
    "Vanna without schema": vn_without_schema_response['vanna_text_response'],
    "Vanna with schema": vn_with_schema_response['vanna_text_response'],
}
for system_name, answer_text in final_answers.items():
    print(f"{system_name} Response: {answer_text}")