# In [ ]:
import os
import re
from urllib.parse import quote_plus

from langchain_community.utilities import SQLDatabase
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_openai import ChatOpenAI
from ragas import evaluate
from ragas.dataset_schema import EvaluationDataset, SingleTurnSample
from ragas.metrics import ContextPrecision, RubricsScore

# Database connection details
# Settings are read from the environment so credentials are not committed with
# the code. Only the non-secret values have defaults; MYSQL_PASSWORD is
# required and a missing value raises KeyError.
host = os.environ.get("MYSQL_HOST", "localhost")
port = os.environ.get("MYSQL_PORT", "3306")
username = os.environ.get("MYSQL_USER", "root")
password = os.environ["MYSQL_PASSWORD"]
database_schema = os.environ.get("MYSQL_DATABASE", "text_to_sql")
# Percent-encode the credentials: characters such as "@", ":" or "/" would
# otherwise be parsed as URI delimiters.
mysql_uri = (
    f"mysql+pymysql://{quote_plus(username)}:{quote_plus(password)}"
    f"@{host}:{port}/{database_schema}"
)
# sample_rows_in_table_info=2 asks for two sample rows per table in the
# table-info text that db.get_table_info() returns.
db = SQLDatabase.from_uri(mysql_uri, sample_rows_in_table_info=2)

# Create LLM prompt template
# The template has two placeholders: {schema} (table schema text) and
# {question} (the user's natural-language question). It instructs the model to
# answer with a single-line SQL query and nothing else.
template = """Based on the table schema below, write an SQL query that would answer the user's question:
Remember : Only provide the sql query, don't include anything else. Provide me sql query in a 
single line don't add line breaks
Table Schema: {schema}
Question : {question}
SQL Query:
 """
prompt = ChatPromptTemplate.from_template(template)

# Initialize LLM
# The API key is read from the GOOGLE_API_KEY environment variable instead of
# being hard-coded, so the secret never lands in version control. A missing
# variable raises KeyError here rather than failing later on the first request.
# NOTE: any key previously committed in this file should be treated as leaked
# and revoked.
llm = ChatGoogleGenerativeAI(
    model="gemini-2.0-flash",
    api_key=os.environ["GOOGLE_API_KEY"],
)

# Create SQL query chain using LLM and the prompt template
# Input: a dict with a "question" key. Output: the model's reply as a string.
sql_chain = (
    # Add the current table info to the input dict under "schema".
    RunnablePassthrough.assign(schema=lambda _: db.get_table_info())
    | prompt
    # Stop sequences must be a list of strings, and the newline escape is
    # "\n" (a "/n" would be matched literally and never act as a line break).
    | llm.bind(stop=["\nSQLResult:"])
    | StrOutputParser()
)

# RAGAS evaluation setup
# Descriptions for each score level (1-5) of the "helpfulness" rubric.
helpfulness_rubrics = {
    "score1_description": "Response is useless/irrelevant...",
    "score2_description": "Response is minimally relevant...",
    "score3_description": "Response is relevant to the instruction...",
    "score4_description": "Response is very relevant to the instruction...",
    "score5_description": "Response is useful and very comprehensive...",
}

# Both metrics are judged by the same LLM that generates the SQL.
# NOTE(review): ragas examples pass a ragas LLM wrapper (e.g.
# LangchainLLMWrapper) to metric constructors -- confirm that a raw LangChain
# chat model is accepted here.
rubrics_score = RubricsScore(
    name="helpfulness",
    rubrics=helpfulness_rubrics,
    llm=llm,
)
context_precision = ContextPrecision(llm=llm)

# Function to evaluate user inputs
def evaluate_user_inputs(user_inputs):
    """Generate an SQL query for each natural-language question.

    Each question is sent through ``sql_chain``. If the reply contains a
    Markdown code fence (with or without an ``sql`` tag), the fenced content
    is used; otherwise the whole reply is used. Either way the result is
    stripped of surrounding whitespace.

    Args:
        user_inputs: Iterable of question strings.

    Returns:
        A list of SQL query strings, one per question, in input order.
    """
    # Compiled once, outside the loop. DOTALL lets the query span lines;
    # IGNORECASE accepts "```SQL" as well as "```sql".
    fenced_sql = re.compile(r"```(?:sql)?\s*(.*?)\s*```", re.DOTALL | re.IGNORECASE)

    responses = []
    for question in user_inputs:
        resp = sql_chain.invoke({"question": question})
        match = fenced_sql.search(resp)
        # The prompt asks for a bare query, so an unfenced reply is the
        # expected case. Dropping it would lose the answer and misalign the
        # results with the questions.
        query = match.group(1) if match else resp
        responses.append(query.strip())
    return responses

# Example user inputs
# Natural-language questions that evaluate_user_inputs turns into SQL queries.
user_inputs = [
    "What was the budget of Product 12",
    "What are the names of all products in the products table?",
    "List all customer names from the customers table.",
    "Find the name and state of all regions in the regions table.",
    "What is the name of the customer with Customer Index = 1"
]

# Evaluate the user inputs
# Passes every example question to evaluate_user_inputs and keeps the returned list.
responses = evaluate_user_inputs(user_inputs)
# Bare expression: presumably here so a notebook shows the list as the cell
# output; it has no effect when the file is run as a plain script.
responses
