In [17]:
from langchain_ollama import OllamaLLM
from typing_extensions import TypedDict
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langgraph.graph import StateGraph, START, END
from db_create import CargaDeArchivos
import re

In [18]:
a= CargaDeArchivos()
a.run_carga()
db_conn= a.conn


class State(TypedDict):
    """
    Represents the state of the workflow, including the question, schema, database connection,
    relevance, SQL query, query result, and other metadata.
    """
    question: str
    schema: str
    db_conn: None
    relevance: str
    sql_query: str
    query_result: str
    sql_error: bool
    final_answer: str
    attempts: int

In [29]:
def convert_nl_to_sql(state: State):
    """
    Converts a natural language question into an SQL query based on the database schema.

    Args:
        state (State): The current state of the workflow.
        config (RunnableConfig): Configuration for the runnable.

    Returns:
        State: Updated state with the generated SQL query.
    """
    question = state["question"]
    print(f"Converting question to SQL {question}")
    system = """
    You are an SQL assistant specialized in DuckDB. Your task is to generate accurate SQL queries based on natural language questions, following the provided schema.

    ### Database Schema  
    #### Table: "cases"
    - "id" (VARCHAR): Primary key.
    - "insurance" (BIGINT): Foreign key to insurance.
    - "avg_time" (DOUBLE): Duration (seconds) from case initiation to closure.
    - "type" (VARCHAR): Insurance category.
    - "branch" (VARCHAR): Policy branch.
    - "ramo" (VARCHAR): Coverage type.
    - "broker" (VARCHAR): Broker for the policy.
    - "state" (VARCHAR): Current case state.
    - "client" (VARCHAR): Client who bought the insurance.
    - "creator" (VARCHAR): Employee managing the case.
    - "value" (BIGINT): Insurance monetary value.
    - "approved" (BOOLEAN): TRUE if approved, else FALSE.
    - "insurance_creation" (TIMESTAMP_NS): Policy creation timestamp.
    - "insurance_start" (TIMESTAMP_NS): Coverage start timestamp.
    - "insurance_end" (TIMESTAMP_NS): Coverage end timestamp.

    #### Table: "activity"
    - "id" (BIGINT): Primary key.
    - "case" (VARCHAR): Foreign key to "cases"."id".
    - "timestamp" (TIMESTAMP_NS): Activity timestamp.
    - "name" (VARCHAR): Name of the activity.
    - "case_index" (BIGINT): Alias for "id".
    - "tpt" (DOUBLE): Activity duration (seconds).

    ### Query Guidelines  
    1. **Use Table Aliases**: "cases" → c, "activity" → a.
    2. **Always Reference Columns with Aliases**: c."id", a."case".
    3. **Handle Aggregations**: Include non-aggregated columns in GROUP BY.
    4. **Date & Time Calculations**: Use EXTRACT(DAY FROM ...) for durations.
    5. **Filtering Conditions**: Use TRUE/FALSE for boolean values.
    6. **Use Explicit Joins**: Avoid implicit joins.
    7. **Optimize for Performance**: Use indexes, avoid unnecessary calculations, and limit results when needed.
    8. **Restrict Queries to Existing Tables**: Only use "cases" and "activity" tables.

    ### Output Format  
    - Return only the SQL query, with no extra formatting.  
    - Do **NOT** include language tags like `sql`, `vbnet`, or any other markers.  
    """



    llm= OllamaLLM(model="vicuna:latest",temperature="0.0")
    
    convert_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system),
            ("human", "Question: {question}"),
        ]
    )
    sql_generator = convert_prompt | llm
    result = sql_generator.invoke({"question": question})
    message= re.sub(r'^\s*```sql\s*|\s*```$', '', result.strip(), flags=re.IGNORECASE)
    state["sql_query"] = message
    print(f"Generated SQL query: {state['sql_query']}")
    return state


def execute_sql(state:State):
    """
    Executes the SQL query on the  database and retrieves the results.

    Args:
        state (State): The current state of the workflow.
        config (RunnableConfig): Configuration for the runnable.

    Returns:
        State: Updated state with the query results or error information.
    """
    sql_query = state["sql_query"].strip()
    db_conn = state["db_conn"]  
    print(f"Executing SQL query: {sql_query}")

    try:
        # Ensure the query targets only the allowed tables
        allowed_tables = ["cases", "activity"]
        if not any(table in sql_query.lower() for table in allowed_tables):
            raise ValueError(f"Query must target only the tables: {', '.join(allowed_tables)}.")

        # Execute the SQL query using the connection
        cursor = db_conn.cursor()
        cursor.execute(sql_query)

        # Fetch results if it's a SELECT query
        if sql_query.lower().startswith("select"):
            rows = cursor.fetchall()
            columns = [desc[0] for desc in cursor.description]

            # Format the output
            if rows:
                formatted_result = "\n".join(
                    ", ".join(f"{col}: {row[idx]}" for idx, col in enumerate(columns))
                    for row in rows
                )
                print("SQL SELECT query executed successfully.")
            else:
                formatted_result = "No results found."
                print("SQL SELECT query executed successfully but returned no rows.")

            state["query_rows"] = rows
        else:
            formatted_result = "The action has been successfully completed."
            print("SQL command executed successfully.")

        state["query_result"] = formatted_result
        state["sql_error"] = False

    except Exception as e:
        state["query_result"] = f"Error executing SQL query: {str(e)}"
        state["sql_error"] = True
        print(f"Error executing SQL query: {str(e)}")
    print(state['query_result'])
    return state

In [None]:
test_cases = [
    {"question": "How many cases are there in total in the database?",
    "expected_sql": 'SELECT COUNT(*) FROM cases;'},
    {"question": "How many activities have been recorded?",
    "expected_sql": 'SELECT COUNT(*) FROM activity;'},
    {"question": "What are the different types of insurance available?",
    "expected_sql": 'SELECT DISTINCT type FROM cases;'},
    {"question": "What is the total value of approved insurance policies?",
    "expected_sql": 'SELECT SUM(value) FROM cases WHERE approved = 1;'},
    {"question": "How many cases were created in January?",
    "expected_sql": 'SELECT COUNT(*) FROM cases WHERE EXTRACT(MONTH FROM insurance_creation) = 1;'},
    {"question": "Who is the broker with the most assigned cases?",
    "expected_sql": 'SELECT brocker FROM cases GROUP BY brocker ORDER BY COUNT(*) DESC LIMIT 1;'},
    {"question": "What is the most frequent activity in the database?",
    "expected_sql": 'SELECT name FROM activity GROUP BY name ORDER BY COUNT(*) DESC LIMIT 1;'},
    {"question": "What is the total insurance value for each type of 'ramo'?",
    "expected_sql": 'SELECT ramo, SUM(value) FROM cases GROUP BY ramo;'},
    {"question": "On average, how long does it take for an insurance policy to be approved?",
    "expected_sql": 'SELECT AVG(EXTRACT(DAY FROM insurance_start - insurance_creation)) FROM cases WHERE approved = 1;'},
    {"question": "What is the average number of activities per case?",
    "expected_sql": 'SELECT COUNT(activity.id) / COUNT(DISTINCT cases.id) FROM activity INNER JOIN cases ON activity.case = cases.id;'},
    {"question": "What is the most frequent activity performed in approved cases?",
    "expected_sql": 'SELECT activity.name FROM activity INNER JOIN cases ON activity.case = cases.id WHERE cases.approved = TRUE GROUP BY activity.name ORDER BY COUNT(activity.id) DESC LIMIT 1;'},
    {"question": "What is the total duration of all activities for each case?",
    "expected_sql": 'SELECT activity.case, SUM(activity.tpt) FROM activity GROUP BY activity.case;'},
    {"question": "What is the total value of cases that have at least one recorded activity?",
    "expected_sql": 'SELECT SUM(cases.value) FROM cases WHERE cases.id IN (SELECT DISTINCT case FROM activity);'},
    
]

In [31]:
import time
import pandas as pd

def time_query_generation(state):
    """
    Measures the execution time of the model generating an SQL query.

    Args:
        state (dict): The state object that stores the query and result.

    Returns:
        dict: Updated state containing the query, execution time, and SQL output.
    """
    start_time = time.time()
    try:
        convert_nl_to_sql(state)  # Generates SQL query
        execute_sql(state)  # Executes SQL query
    except Exception as e:
        state["sql_query"] = f"Error: {e}"

    execution_time = time.time() - start_time
    state["execution_time"] = execution_time
    return state

In [32]:
results = []

# Run tests
for test in test_cases:
    state = {"question": test["question"], "expected_sql": test["expected_sql"],"db_conn":db_conn}
    result_state = time_query_generation(state)
    
    results.append({
        "Question": test["question"],
        "Expected SQL Query": test["expected_sql"],
        "Generated SQL Query": result_state['sql_query'],
        "Query Result": result_state["query_result"],
        "Execution Time (s)": result_state["execution_time"]
    })

# Convert results to DataFrame
df_results = pd.DataFrame(results)

Converting question to SQL How many cases are there in total in the database?
Generated SQL query: SELECT COUNT(*) FROM cases;
Executing SQL query: SELECT COUNT(*) FROM cases;
SQL SELECT query executed successfully.
count_star(): 1000
Converting question to SQL How many activities have been recorded?
Generated SQL query: SELECT COUNT(a."id") AS "number_of_activities"
FROM activity a
JOIN cases c ON a."case" = c."id"
GROUP BY a."case", a."timestamp", a."name"
Executing SQL query: SELECT COUNT(a."id") AS "number_of_activities"
FROM activity a
JOIN cases c ON a."case" = c."id"
GROUP BY a."case", a."timestamp", a."name"
SQL SELECT query executed successfully.
number_of_activities: 1
number_of_activities: 1
number_of_activities: 1
number_of_activities: 1
number_of_activities: 1
number_of_activities: 1
number_of_activities: 1
number_of_activities: 1
number_of_activities: 1
number_of_activities: 1
number_of_activities: 1
number_of_activities: 1
number_of_activities: 1
number_of_activities: 1


In [33]:
df_results

Unnamed: 0,Question,Expected SQL Query,Generated SQL Query,Query Result,Execution Time (s)
0,How many cases are there in total in the datab...,SELECT COUNT(*) FROM cases;,SELECT COUNT(*) FROM cases;,count_star(): 1000,1.132163
1,How many activities have been recorded?,SELECT COUNT(*) FROM activity;,"SELECT COUNT(a.""id"") AS ""number_of_activities""...",number_of_activities: 1\nnumber_of_activities:...,1.974502
2,What are the different types of insurance avai...,SELECT DISTINCT type FROM cases;,SELECT c.type \nFROM cases c \nJOIN activity a...,type: Policy onboarding\ntype: Issuance\ntype:...,2.146427
3,What is the total value of approved insurance ...,SELECT SUM(value) FROM cases WHERE approved = 1;,SELECT SUM(c.value) AS total_approved_insuranc...,total_approved_insurance_value: 112080900,2.214971
4,How many cases were created in January?,SELECT COUNT(*) FROM cases WHERE EXTRACT(MONTH...,"SELECT COUNT(a.""case"") AS ""count""\nFROM activi...",count: 0,2.544626
5,Who is the broker with the most assigned cases?,SELECT brocker FROM cases GROUP BY brocker ORD...,"SELECT a.""broker"", COUNT(c.""id"") AS ""num_cases...",Error executing SQL query: Binder Error: Table...,2.547432
6,What is the most frequent activity in the data...,SELECT name FROM activity GROUP BY name ORDER ...,"SELECT a.name AS ""Activity"", COUNT(a.id) AS ""C...","Activity: Enviar a Revisión suscripción, Count...",2.444983
7,What is the total insurance value for each typ...,"SELECT ramo, SUM(value) FROM cases GROUP BY ramo;","SELECT c.""type"", SUM(c.""value"") AS ""total_insu...","type: Renewal, total_insurance_value: 36234600...",2.243066
8,"On average, how long does it take for an insur...",SELECT AVG(EXTRACT(DAY FROM insurance_start - ...,"SELECT AVG(a.""tpt"") AS ""avg_time"" FROM c JOIN ...",Error executing SQL query: Query must target o...,1.879625
9,What is the average number of activities per c...,SELECT COUNT(activity.id) / COUNT(DISTINCT cas...,"SELECT \n c.id, \n AVG(a.tpt) AS avg_act...","id: 9488, avg_activity_duration: 14373.7217918...",2.589537


In [34]:
df_results.to_csv("results_vicuna.csv", index=False)