In [None]:
%load_ext autoreload
%autoreload 2
from dotenv import load_dotenv

In [None]:
# Load environment variables from a local .env file (presumably the LLM API
# credentials used by the cells below — confirm against your .env).
# load_dotenv() returns False when no .env file is found or it defines no
# variables; warn instead of failing silently so a missing file is noticed
# before the first LLM call. Assigning the result also stops the notebook
# from echoing a bare `True`.
import warnings

dotenv_loaded = load_dotenv()
if not dotenv_loaded:
    warnings.warn(
        "No .env file found (or it defines no variables); "
        "relying on variables already set in the environment."
    )

In [None]:
# Create an engine to connect to the database

from sqlalchemy.orm import sessionmaker
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine

# Dialect name; the query processors below also receive it as `db_type`.
db_type = "sqlite"
# SQLite database file, relative to the notebook's working directory.
db_path = "demo_databases/tn_covid_cases_11_may.sqlite"
# "+aiosqlite" selects the asyncio-compatible SQLite driver.
db_url = f"{db_type}+aiosqlite:///{db_path}"

aengine = create_async_engine(url=db_url)

# Factory producing AsyncSession objects bound to `aengine`.
# expire_on_commit=False keeps loaded attributes usable after a commit
# instead of expiring them (reloading would otherwise need implicit I/O).
async_session = sessionmaker(
    bind=aengine, class_=AsyncSession, expire_on_commit=False
)

In [None]:
# Parameters shared by the description, single-turn and multi-turn cells below.
import json

metric_db_id = "test"
llm = "gpt-4o"  # LLM used by the database descriptor and the query processors
validation_llm = "gpt-4o"  # NOTE(review): not referenced by any cell below — confirm it is still needed
guardrails_llm = "gpt-4o"  # passed as `guardrails_llm` to the query processors

# System prompt describing who will ask questions and what help they need.
sys_message = (
    "Government and health officials in Tamil Nadu, India will ask you questions. "
    "You need to help them manage COVID cases and the availability of beds in health facilities."
)

# One entry per table: its name and a plain-English description of a row.
# Serialised to a compact JSON string (no spaces after separators), which the
# cells below pass as `table_description` / `db_description`.
db_description = json.dumps(
    [
        {
            "name": "bed_vacancies_clinics_11_may",
            "description": (
                "Each row identifies a district and the beds earmarked, occupied and available "
                "for COVID cases in the district clinics."
            ),
        },
        {
            "name": "bed_vacancies_health_centers_and_district_hospitals_11_may",
            "description": (
                "Each row identifies a district and the beds earmarked, occupied and available, "
                "with and without oxygen supply, and with and without ICU support, "
                "for COVID cases in the district health centers and hospitals."
            ),
        },
        {
            "name": "covid_cases_11_may",
            "description": (
                "Each row identifies a district and the number of people who received treatment, "
                "were discharged and died due to COVID."
            ),
        },
    ],
    separators=(",", ":"),
)

# Passed as `num_common_values` to the query processors
# (presumably how many frequent column values are shown to the LLM — TODO confirm).
num_common_values = 10

# For multiple variables, pass a single comma-delimited string, e.g. "var_a,var_b".
indicator_vars = "district_name"


### Generate database descriptions and suggested questions for end-users

In [None]:
from askametric.query_processor.db_descriptor import DatabaseDescriptor

db_descr = DatabaseDescriptor(llm=llm)

async with async_session() as asession:
    description = await db_descr.generate_db_description(
        asession=asession,
        metric_db_id=metric_db_id,
        sys_message=sys_message,
        table_description=db_description,
        column_description=""
)

print("DB description:\n", description)

### Single-turn question

In [None]:
# Your question: a single query in the dict format the processors accept.
query = dict(
    query_text="How many beds are available in chennai??",
    query_metadata={},  # no extra metadata for this demo
)


In [None]:
from askametric.query_processor.query_processor import LLMQueryProcessor

async with async_session() as session:
    qp = LLMQueryProcessor(
        query,
        session,
        metric_db_id,
        db_type,
        llm,
        guardrails_llm,
        sys_message,
        db_description,
        column_description="",
        num_common_values=num_common_values,
        indicator_vars=indicator_vars
    )
    await qp.process_query()
    print(qp.final_answer)

### Multi-turn chat

In [None]:
# Simulating a conversation: several follow-ups depend on earlier turns
# ("How about Ranipet??", "here"), and some turns are not in English.
turn_texts = [
    "How many beds are available in chennai??",
    "How about Ranipet??",
    # Romanized Tamil: "Tell me about the COVID patients here."
    "Inge irrukira COVID patients patthi sollu",
    "How many beds with oxygen supply in Ariyalur??",
    "And how many COVID patients here?",
    "No I want the information for Ariyalur",
    # Romanized Hindi: "Tell me again for Ranipet and Ariyalur — what information have you given?"
    "Ranipet aur Ariyalur ke lie phir batao, kya kya jaankari diya hai aapne?",
    "What's the distribution of COVID patients in Madurai?",
]

# Same dict shape as the single-turn query; each turn gets its own empty metadata dict.
queries = [{"query_text": text, "query_metadata": {}} for text in turn_texts]

In [None]:
from askametric.query_processor.query_processor import MultiTurnQueryProcessor

chat_history = []
async with async_session() as session:
    for query in queries:
        print(f"Q: {query['query_text']}")
        mqp = MultiTurnQueryProcessor(
            query=query,
            asession=session,
            metric_db_id=metric_db_id,
            db_type=db_type,
            llm=llm,
            guardrails_llm=guardrails_llm,
            sys_message=sys_message,
            db_description=db_description,
            column_description="",
            indicator_vars=indicator_vars,
            num_common_values=num_common_values,
            chat_history=chat_history
        )
        await mqp.process_query()
        chat_history.append({"user": mqp.reframed_query,
                             "system": mqp.translated_final_answer})
        
        
        print(f"Query type: {mqp.query_type}")
        print(f"Reframed query: {mqp.reframed_query}")
        print(f"A: {mqp.final_answer}")
        print("\n")