### Importing the required Python packages

In [284]:
from urllib.parse import quote
import os
import getpass
import pandas as pd
from langchain_community.agent_toolkits import SQLDatabaseToolkit
from langgraph.prebuilt import create_react_agent
from langchain_core.prompts import ChatPromptTemplate
from langchain_community.utilities import SQLDatabase
from langchain_groq import ChatGroq
import time
from tqdm import tqdm

### Connecting to the appropriate SQL DB

In [None]:
# Read the MySQL password from a local file so it is not hard-coded in the notebook.
# Strip the trailing line terminator: otherwise it becomes part of the password
# (percent-encoded as %0A) and authentication fails. Only "\r\n" is stripped so
# that intentional leading/trailing spaces in the password are preserved.
with open('pass.txt', 'r') as file:
    password = file.read().rstrip("\r\n")

user = "root"
host = "localhost"
database = "SEMEVALSAMPLE"

# safe="" percent-encodes every reserved character (including "/"), so passwords
# containing "@", ":", "/" etc. cannot be misparsed as URL delimiters.
encoded_password = quote(password, safe="")
connection_uri = f"mysql+pymysql://{user}:{encoded_password}@{host}/{database}"
db = SQLDatabase.from_uri(connection_uri)

# Sanity check: show the SQL dialect and the tables the agent will be able to see.
print(db.dialect)
print(db.get_usable_table_names())

### Defining the system messages for the answer-generation prompt and the formatting prompt

In [286]:
# System prompt for the SQL ReAct agent (passed as `state_modifier` to
# create_react_agent in the generation loop). It instructs the model to inspect the
# named table's schema, issue read-only MySQL queries, and answer only from the
# tool outputs.
system_message="""
You are an agent designed to interact with a SQL database.
Given is an input question and the table name which has its answer. First, get the schema for this table and then create a syntactically correct MySQL query to run, then look at the results of the query and return the answer.
Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most 1 result.
You can order the results by a relevant column to return the most interesting examples in the database.
Never query for all the columns from a specific table, only ask for the relevant columns given the question.
You have access to tools for interacting with the database.
Only use the below tools. Only use the information returned by the below tools to construct your final answer.
You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.

DO NOT make any DML or DDL statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.

If multiple queries need to be executed, execute them one by one.
There might be cases where you need to query the table multiple times (example: finding unique values of a particular column from the table first).
Then you should query the schema of the table and give the final answer based on the query output.
"""

# System prompt for the second-pass formatting model (llm2): given the question and
# the agent's free-text answer, it must reply with only a bare value of one of the
# answer types listed below.
sys_msg = (
    "\n"
    "You are given a question and answer pair. Give the answer in one of the following answer types :-\n"
    "Boolean: Valid answers include True/False, Y/N, Yes/No (all case insensitive).\n"
    "Category: A value from a cell (or a substring of a cell) in the dataset.\n"
    "Number: A numerical value from a cell in the dataset, which may represent a computed statistic (e.g., average, maximum, minimum).\n"
    "List[category]: A list containing a fixed number of categories. The expected format is: \"['cat', 'dog']\". Pay attention to the wording of the question to determine if uniqueness is required or if repeated values are allowed.\n"
    "List[number]: Similar to List[category], but with numbers as its elements.\n"
    "Do not give anything else other than the final answer with the correct datatype.\n"
)

### Defining the model

In [287]:
# Prompt for the Groq API key only when the environment does not already provide
# a non-empty one; ChatGroq reads it from GROQ_API_KEY.
if not os.getenv("GROQ_API_KEY"):
    os.environ["GROQ_API_KEY"] = getpass.getpass("Enter API key for Groq: ")

# `llm` drives the SQL agent and its toolkit; `llm2` only reformats the agent's
# answer into the required answer type.
# NOTE(review): confirm both model ids are still served by Groq before a new run.
llm = ChatGroq(model="llama-3.3-70b-versatile")
llm2 = ChatGroq(model="llama3-8b-8192")

### Generating the answers by iterating through the test_qa file and using the agent for each question

In [None]:
# For every question in test_qa.csv: let the SQL agent answer it, have llm2 reduce
# the answer to the required answer type, and append the result as ONE line of
# predictions.txt (line i corresponds to row i of the CSV).
# The file is opened in append mode, so delete predictions.txt before a fresh run.
REQUEST_DELAY_S = 10  # pause between questions; presumably to respect Groq rate limits — TODO confirm

df = pd.read_csv('test_qa.csv')

# The toolkit, tools and agent do not depend on the question, so build them once
# instead of on every iteration.
toolkit = SQLDatabaseToolkit(db=db, llm=llm)
tools = toolkit.get_tools()
# NOTE(review): newer langgraph releases rename `state_modifier` to `prompt`;
# keep this in sync with the installed version.
agent_executor = create_react_agent(llm, tools, state_modifier=system_message)

with open("predictions.txt", "a+") as f:
    for _, row in tqdm(df.iterrows(), total=len(df)):
        query = row['question']
        table = row['dataset']
        question = f"Table name: {table}, Question: {query}"

        # `stream` returns a generator (not indexable), and its first value would be
        # the initial state rather than the answer. `invoke` runs the agent to
        # completion; the last message of the final state is the agent's answer.
        result = agent_executor.invoke({"messages": [("user", question)]})
        out_str = result["messages"][-1]

        # Pass plain (role, content) tuples rather than a ChatPromptTemplate: a
        # template object is not a valid chat-model input, and braces in the
        # interpolated answer would be parsed as template variables.
        response = llm2.invoke([
            ("system", sys_msg),
            ("human", f"Question: {query}, Answer:{out_str.content}"),
        ])

        # Collapse to a single line so predictions stay aligned with questions, and
        # flush so completed answers survive an interrupted run.
        print(" ".join(response.content.strip().splitlines()), file=f)
        f.flush()

        time.sleep(REQUEST_DELAY_S)