In [1]:
from langchain_ollama import OllamaLLM
from typing_extensions import TypedDict
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langgraph.graph import StateGraph, START, END
from db_create import CargaDeArchivos
import re

In [2]:
# Build the database with the project-local loader (db_create.CargaDeArchivos).
# NOTE(review): presumably run_carga() loads the source files into the
# "cases"/"activity" tables used below — TODO confirm against db_create.
a= CargaDeArchivos()
a.run_carga()
# Connection handed to the workflow through the state's "db_conn" field.
db_conn= a.conn

In [3]:
class State(TypedDict):
    """
    Shared state passed between the nodes of the LangGraph workflow.

    Holds the user's question, the database connection, the relevance verdict,
    the generated SQL query and its execution outcome (formatted text, raw rows
    and an error flag), the final answer, and the retry counter used by the
    SQL-regeneration loop.
    """
    question: str  # natural-language question from the user
    # Live database connection (the notebook passes ``CargaDeArchivos().conn``).
    # Annotated as ``object``: the previous ``None`` annotation declared the
    # field's type as NoneType.
    db_conn: object
    relevance: str  # "relevant" or "not_relevant", set by check_relevance
    sql_query: str
    query_result: str  # formatted rows, a status message, or error text
    # Raw rows of the last row-returning query. Declared so LangGraph keeps the
    # value execute_sql writes instead of dropping the unknown key.
    query_rows: list
    sql_error: bool
    final_answer: str
    attempts: int  # number of SQL regeneration attempts so far

In [4]:
def check_relevance(state: State):
    """
    Classify whether the user's question is relevant to the database schema.

    Sends the question, together with the schema description and relevance
    criteria, to a local Ollama model (gemma3:1b). The reply is normalized and
    stored in the state. Also resets the SQL retry counter to 0.

    Args:
        state (State): The current state of the workflow; "question" is read.

    Returns:
        State: The same state with "relevance" ("relevant" or "not_relevant")
        and "attempts" set.

    Raises:
        ValueError: If the model's reply cannot be mapped to either label.
    """
    question = state["question"]
    print(f"Checking relevance of the question: {question}")

    # System prompt for relevance checking (contains no template placeholders).
    system = """
    You are an assistant that determines whether a given question is related to the following database schema
    A question is considered **relevant** if it pertains to activities, cases, durations, business insights, or any other concepts related to process analysis and revenue assessment.
    ### Database Schema  
    #### Table: "cases"
    - "id" (VARCHAR): Primary key.
    - "insurance" (BIGINT): Foreign key to insurance.
    - "avg_time" (DOUBLE): Duration (seconds) from case initiation to closure.
    - "type" (VARCHAR): Insurance category.
    - "branch" (VARCHAR): Policy branch.
    - "ramo" (VARCHAR): Coverage type.
    - "broker" (VARCHAR): Broker for the policy.
    - "state" (VARCHAR): Current case state.
    - "client" (VARCHAR): Client who bought the insurance.
    - "creator" (VARCHAR): Employee managing the case.
    - "value" (BIGINT): Insurance monetary value.
    - "approved" (BOOLEAN): TRUE if approved, else FALSE.
    - "insurance_creation" (TIMESTAMP_NS): Policy creation timestamp.
    - "insurance_start" (TIMESTAMP_NS): Coverage start timestamp.
    - "insurance_end" (TIMESTAMP_NS): Coverage end timestamp.

    #### Table: "activity"
    - "id" (BIGINT): Primary key.
    - "case" (VARCHAR): Foreign key to "cases"."id".
    - "timestamp" (TIMESTAMP_NS): Activity timestamp.
    - "name" (VARCHAR): Name of the activity.
    - "case_index" (BIGINT): Alias for "id".
    - "tpt" (DOUBLE): Activity duration (seconds).
    ### **Relevance Criteria:**
    - A question is considered **relevant** if it pertains to activities, cases, durations, business insights, or any other concepts related to process analysis and revenue assessment.
    - Questions involving **business insights**, such as client revenue, broker performance, or policy value trends, are relevant.
    - If the question is purely conceptual (e.g., "What is an activity?"), it is **not relevant**, even if it contains a column name.
    ###Response Format:
    - Respond with "relevant" if the question is related to the schema.
    - Respond with "not_relevant" if the question is not related to the schema.
    """

    # The question is supplied as a template variable instead of being
    # interpolated into the template text: literal braces in user input would
    # otherwise be parsed as placeholders and make invoke() fail.
    check_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system),
            ("human", "Question: {question}"),
        ]
    )

    llm = OllamaLLM(model="gemma3:1b", temperature=0.0)
    relevance_checker = check_prompt | llm
    raw_response = relevance_checker.invoke({"question": question})

    # Small models often wrap the label in quotes/markdown or add trailing
    # punctuation; strip that decoration before validating.
    response = raw_response.strip().strip("\"'`*.!").strip().lower()
    if response == "not relevant":
        response = "not_relevant"

    if response not in ("relevant", "not_relevant"):
        raise ValueError(f"Unexpected relevance response: {raw_response}")

    state["relevance"] = response
    state["attempts"] = 0
    print(f"Relevance determined: {state['relevance']}")
    return state

def convert_nl_to_sql(state: State):
    """
    Convert the natural-language question into a DuckDB SQL query.

    Prompts the local "duckdb-nsql" Ollama model with the schema and query
    guidelines. Any Markdown code fence the model wraps around its answer is
    removed.

    Args:
        state (State): The current state of the workflow; "question" is read.

    Returns:
        State: The same state with "sql_query" set to the generated query.
    """
    question = state["question"]
    print(f"Converting question to SQL {question}")
    system = """
    You are an SQL assistant specialized in DuckDB. Your task is to generate accurate SQL queries based on natural language questions, following the provided schema.

    ### Database Schema  
    #### Table: "cases"
    - "id" (VARCHAR): Primary key.
    - "insurance" (BIGINT): Foreign key to insurance.
    - "avg_time" (DOUBLE): Duration (seconds) from case initiation to closure.
    - "type" (VARCHAR): Insurance category.
    - "branch" (VARCHAR): Policy branch.
    - "ramo" (VARCHAR): Coverage type.
    - "broker" (VARCHAR): Broker for the policy.
    - "state" (VARCHAR): Current case state.
    - "client" (VARCHAR): Client who bought the insurance.
    - "creator" (VARCHAR): Employee managing the case.
    - "value" (BIGINT): Insurance monetary value.
    - "approved" (BOOLEAN): TRUE if approved, else FALSE.
    - "insurance_creation" (TIMESTAMP_NS): Policy creation timestamp.
    - "insurance_start" (TIMESTAMP_NS): Coverage start timestamp.
    - "insurance_end" (TIMESTAMP_NS): Coverage end timestamp.

    #### Table: "activity"
    - "id" (BIGINT): Primary key.
    - "case" (VARCHAR): Foreign key to "cases"."id".
    - "timestamp" (TIMESTAMP_NS): Activity timestamp.
    - "name" (VARCHAR): Name of the activity.
    - "case_index" (BIGINT): Alias for "id".
    - "tpt" (DOUBLE): Activity duration (seconds).

    ### Query Guidelines  
    1. Convert any time differences (e.g., between `insurance_start` and `insurance_creation`) from `INTERVAL` to a numeric type, such as seconds or minutes, for accurate calculations.
    2. Use functions like `EXTRACT(EPOCH FROM ...)` to convert `INTERVAL` types into numeric values (e.g., seconds) that can be averaged.
    3. **Use Table Aliases**: "cases" → c, "activity" → a.
    4. **Always Reference Columns with Aliases**: c."id", a."case".
    5. **Handle Aggregations**: Include non-aggregated columns in GROUP BY.
    6. **Date & Time Calculations**: Use EXTRACT(DAY FROM ...) for durations.
    7. **Filtering Conditions**: Use TRUE/FALSE for boolean values.
    8. **Use Explicit Joins**: Avoid implicit joins.
    9. **Optimize for Performance**: Use indexes, avoid unnecessary calculations, and limit results when needed.
    10. **Restrict Queries to Existing Tables**: Only use "cases" and "activity" tables.
    11. **Use JOINS only when necessary**: Avoid unnecessary joins.
    ### Output Format  
    - Return only the SQL query, with no extra formatting.  
    - Do **NOT** include language tags like `sql`, `vbnet`, or any other markers.  
    """
    llm = OllamaLLM(model="duckdb-nsql:latest", temperature=0.0)

    convert_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system),
            ("human", "Question: {question}"),
        ]
    )
    sql_generator = convert_prompt | llm
    result = sql_generator.invoke({"question": question})

    # Remove an opening fence (```sql, ```sqlite, ```duckdb or a bare ```) and
    # a closing fence. The previous pattern required the "sql" tag, so a bare
    # ``` opener was left in the query and caused a syntax error.
    message = re.sub(
        r"^\s*```(?:duckdb|sqlite|sql)?\s*|\s*```\s*$",
        "",
        result.strip(),
        flags=re.IGNORECASE,
    )
    state["sql_query"] = message
    print(f"Generated SQL query: {state['sql_query']}")
    return state



def execute_sql(state: "State"):
    """
    Run the generated SQL query against the database and record the outcome.

    Row-returning statements are those starting with SELECT, WITH (CTEs) or
    DuckDB's FROM-first syntax, optionally wrapped in parentheses. Their rows
    are fetched and formatted as one ``"col: value, ..."`` line per row. Other
    statements only report success. Any exception is caught and stored as an
    error message, so the workflow can route to query regeneration.

    Args:
        state (State): The current state; "sql_query" and "db_conn" (a DB-API
            style connection exposing ``cursor()``) are read.

    Returns:
        State: The same state with "query_result" and "sql_error" set, plus
        "query_rows" for row-returning statements.
    """
    sql_query = state["sql_query"].strip()
    db_conn = state["db_conn"]
    print(f"Executing SQL query: {sql_query}")

    # CTEs (WITH) and FROM-first queries also return rows. Checking only for a
    # leading SELECT misreported them as commands and discarded their results.
    result_prefixes = ("select", "with", "from")
    returns_rows = sql_query.lstrip("( \t\r\n").lower().startswith(result_prefixes)

    try:
        # Weak sanity check: the query must mention at least one allowed table
        # (substring match). It does not stop other tables from being used.
        # NOTE(review): this SQL is produced by an LLM from user input and is
        # executed as-is, including write statements; consider a read-only
        # connection if modifications are never intended.
        allowed_tables = ["cases", "activity"]
        if not any(table in sql_query.lower() for table in allowed_tables):
            raise ValueError(f"Query must target only the tables: {', '.join(allowed_tables)}.")

        cursor = db_conn.cursor()
        try:
            cursor.execute(sql_query)

            if returns_rows:
                rows = cursor.fetchall()
                # description is None for statements without a result set.
                columns = [desc[0] for desc in (cursor.description or [])]

                if rows:
                    formatted_result = "\n".join(
                        ", ".join(f"{col}: {value}" for col, value in zip(columns, row))
                        for row in rows
                    )
                    print("SQL SELECT query executed successfully.")
                else:
                    formatted_result = "No results found."
                    print("SQL SELECT query executed successfully but returned no rows.")

                state["query_rows"] = rows
            else:
                formatted_result = "The action has been successfully completed."
                print("SQL command executed successfully.")
        finally:
            # Release the cursor even if execution or fetching fails.
            cursor.close()

        state["query_result"] = formatted_result
        state["sql_error"] = False

    except Exception as e:
        state["query_result"] = f"Error executing SQL query: {str(e)}"
        state["sql_error"] = True
        print(f"Error executing SQL query: {str(e)}")
    print(state['query_result'])
    return state


    
def generate_serious_answer(state: State):
    """
    Produce a business-oriented answer from the SQL query result.

    Sends the question and the formatted SQL result to the local gemma3:1b
    model ("sOFIa" persona) and stores its reply.

    Args:
        state (State): The current state; "question" and "query_result" are read.

    Returns:
        State: The same state with "final_answer" set to the model's reply.
    """
    question = state["question"]
    query_result = state['query_result']

    # {question} and {query_result} are template placeholders filled at invoke
    # time. Interpolating the values directly (as an f-string) would let braces
    # in the question or in query results be parsed as template variables.
    system = """
    You are sOFIa, an AI assistant designed by the AI dream team of OFI Services. Your task is to:
    1. Answer the user's question based on the SQL result.
    2. Provide relevant business insights and recommendations based on the result.
    
    ### **Context:**
    - **User's question:** {question}
    - **SQL result:** {query_result}

    ### **Instructions:**
    - Provide an answer to the question.
    - Offer insights based on the query result, including trends, recommendations, or comparisons.
    
    Keep your response concise, relevant, and business-focused.
    """

    answer_prompt = ChatPromptTemplate.from_messages([
        ("system", system),
        ("human", "Question: {question}"),
    ])
    llm = OllamaLLM(model="gemma3:1b", temperature=0.0)
    answer_chain = answer_prompt | llm | StrOutputParser()

    message = answer_chain.invoke({"question": question, "query_result": query_result})
    print(message)
    state["final_answer"] = message
    return state



def regenerate_query(state):
    """
    Ask the SQL model to repair the last failed query using its error message.

    The failed query and the error text (read from "query_result") are sent to
    the duckdb-nsql model, which is told to return only a corrected query. A
    Markdown code fence around the reply is removed before it is stored.

    Args:
        state (State): The current state; "sql_query", "query_result" and
            "attempts" are read.

    Returns:
        State: The same state with the corrected "sql_query" and "attempts"
        incremented by one.
    """
    query = state["sql_query"]
    error = state["query_result"]

    print(f"🔄 Regenerating query. Attempt {state['attempts'] + 1}")

    system = """You are an expert in SQL for DuckDB. Your task is to correct SQL queries based on error messages.
    
    ### Database Schema  
    #### Table: "cases"
    - "id" (VARCHAR): Primary key.
    - "insurance" (BIGINT): Foreign key to insurance.
    - "avg_time" (DOUBLE): Duration (seconds) from case initiation to closure.
    - "type" (VARCHAR): Insurance category.
    - "branch" (VARCHAR): Policy branch.
    - "ramo" (VARCHAR): Coverage type.
    - "broker" (VARCHAR): Broker for the policy.
    - "state" (VARCHAR): Current case state.
    - "client" (VARCHAR): Client who bought the insurance.
    - "creator" (VARCHAR): Employee managing the case.
    - "value" (BIGINT): Insurance monetary value.
    - "approved" (BOOLEAN): TRUE if approved, else FALSE.
    - "insurance_creation" (TIMESTAMP_NS): Policy creation timestamp.
    - "insurance_start" (TIMESTAMP_NS): Coverage start timestamp.
    - "insurance_end" (TIMESTAMP_NS): Coverage end timestamp.

    #### Table: "activity"
    - "id" (BIGINT): Primary key.
    - "case" (VARCHAR): Foreign key to "cases"."id".
    - "timestamp" (TIMESTAMP_NS): Activity timestamp.
    - "name" (VARCHAR): Name of the activity.
    - "case_index" (BIGINT): Alias for "id".
    - "tpt" (DOUBLE): Activity duration (seconds).

    ### Task Instructions:
    - **Return only the corrected SQL query. No explanations.**
    - **Ensure it runs correctly on DuckDB.**
    - **Preserve the original query intent while fixing errors.**

    """

    # {query} and {error} are template placeholders filled by invoke(). SQL
    # (e.g. DuckDB struct literals) or error text containing braces would
    # otherwise corrupt the template if it were formatted in with an f-string.
    sql_fix_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system),
            (
                "human",
                "The following SQL query failed:\n{query}\n\nError encountered:\n{error}\n\nProvide a corrected SQL query.",
            ),
        ]
    )

    llm = OllamaLLM(model="duckdb-nsql:latest", temperature=0.0)  # DuckDB-specific SQL model
    fixer = sql_fix_prompt | llm
    raw_query = fixer.invoke({"query": query, "error": error})

    # Drop Markdown code fences (```sql, ```sqlite, ```duckdb or bare ```) so
    # the fixed query can be executed directly.
    corrected_query = re.sub(
        r"^\s*```(?:duckdb|sqlite|sql)?\s*|\s*```\s*$",
        "",
        raw_query.strip(),
        flags=re.IGNORECASE,
    )

    print(f"✅ Fixed SQL query: {corrected_query}")
    state["sql_query"] = corrected_query
    state["attempts"] += 1
    return state


def end_max_iterations(state: "State"):
    """
    Terminate the workflow after the SQL retry budget is exhausted.

    This node leads straight to END, so it sets "final_answer" as well as
    "query_result". Without that, the graph's final state on this path would
    carry no answer.

    Args:
        state (State): The current state of the workflow.

    Returns:
        State: The same state with "query_result" and "final_answer" set to a
        retry message.
    """
    message = "Please try again."
    state["query_result"] = message
    state["final_answer"] = message
    print("Maximum attempts reached. Ending the workflow.")
    return state



def generate_funny_response(state: State):
    """
    Generate a playful reply for questions unrelated to the database.

    Sends the question to the local gemma3:1b model with the "sOFIa" persona
    prompt, at a higher temperature (0.7) for more varied answers.

    Args:
        state (State): The current state; "question" is read.

    Returns:
        State: The same state with "final_answer" set to the model's reply.
    """
    print("Generating a funny response for an unrelated question.")
    question = state["question"]

    system = """You are **sOFIa**, a charming and funny assistant dessigned by AI team at OFI Services who responds in a playful and lighthearted manner.
    Your responses should always be fun, engaging, and humorous, but you should introduce yourself when needed, especially if the user doesn't know you yet. 
    Keep it light, and full of personality. You can even throw in a little joke here and there!
    """

    # The question is a template variable, so braces in user input cannot be
    # mistaken for template placeholders.
    funny_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system),
            ("human", "Question: {question}"),
        ]
    )

    llm = OllamaLLM(model="gemma3:1b", temperature=0.7)
    funny_response = funny_prompt | llm | StrOutputParser()
    message = funny_response.invoke({"question": question})

    state["final_answer"] = message
    print("Generated funny response.")
    print(message)

    return state


    
def check_attempts_router(state: State):
    """
    Decide whether the regenerated query gets another execution attempt.

    Args:
        state (State): Workflow state; only its "attempts" counter is read.

    Returns:
        str: "execute_sql" while fewer than 3 attempts have been made,
        "end_max_iterations" once the limit is reached.
    """
    attempt_limit = 3
    return "execute_sql" if state["attempts"] < attempt_limit else "end_max_iterations"



def execute_sql_router(state: State):
    """
    Pick the next node after SQL execution.

    Args:
        state (State): Workflow state; "sql_error" is read (missing counts as
            no error).

    Returns:
        str: "regenerate_query" when execution failed, otherwise
        "generate_serious_answer".
    """
    execution_failed = state.get("sql_error", False)
    if execution_failed:
        return "regenerate_query"
    return "generate_serious_answer"

    
    
def relevance_router(state: State):
    """
    Send relevant questions to SQL generation and everything else to the
    playful-response node.

    Args:
        state (State): Workflow state; "relevance" is compared
            case-insensitively with "relevant".

    Returns:
        str: "convert_to_sql" or "generate_funny_response".
    """
    is_relevant = state["relevance"].lower() == "relevant"
    return "convert_to_sql" if is_relevant else "generate_funny_response"


In [5]:
# Graph layout:
#   START -> check_relevance -> convert_to_sql | generate_funny_response
#   convert_to_sql -> execute_sql -> generate_serious_answer | regenerate_query
#   regenerate_query -> execute_sql | end_max_iterations
workflow = StateGraph(State)

# Nodes (graph name -> implementing function).
workflow.add_node("check_relevance", check_relevance)
workflow.add_node("convert_to_sql", convert_nl_to_sql)
workflow.add_node("execute_sql", execute_sql)
workflow.add_node("regenerate_query", regenerate_query)
workflow.add_node("generate_funny_response", generate_funny_response)
workflow.add_node("generate_serious_answer", generate_serious_answer)
workflow.add_node("end_max_iterations", end_max_iterations)

# Entry point and the relevance branch.
workflow.add_edge(START, "check_relevance")
workflow.add_conditional_edges(
    "check_relevance",
    relevance_router,
    {
        "convert_to_sql": "convert_to_sql",
        "generate_funny_response": "generate_funny_response",
    },
)
workflow.add_edge("convert_to_sql", "execute_sql")

# Execution outcome: answer on success, otherwise try to repair the query.
workflow.add_conditional_edges(
    "execute_sql",
    execute_sql_router,
    {
        "generate_serious_answer": "generate_serious_answer",
        "regenerate_query": "regenerate_query",
    },
)

# Retry loop, bounded by the attempt counter.
workflow.add_conditional_edges(
    "regenerate_query",
    check_attempts_router,
    {
        "execute_sql": "execute_sql",
        "end_max_iterations": "end_max_iterations",
    },
)

# Terminal nodes.
workflow.add_edge("end_max_iterations", END)
workflow.add_edge("generate_serious_answer", END)
workflow.add_edge("generate_funny_response", END)

chain = workflow.compile()


In [6]:
# Smoke test: a question unrelated to the schema; the captured output below
# shows it being routed to the playful-response branch.
state= chain.invoke({"question":"Who are you?","db_conn":db_conn})

Checking relevance of the question: Who are you?
Relevance determined: not_relevant
Generating a funny response for an unrelated question.
Generated funny response.
Well hello there! I’m your friendly neighborhood AI assistant, designed to make your day a little brighter (and maybe a little less boring!). Think of me as a digital comedian with a serious problem-solving ability. 

I’m a bit of a puzzle, really. I was built by a team at OFI Services – they’re experts at making things fun and engaging. So, I’m here to help you with whatever you need, all while keeping things light and a little silly. 

Don't be shy! Ask me anything – I’m ready to chat, tell a joke, or just provide a bit of digital amusement. 😊


In [None]:
import time
import pandas as pd

def time_query_generation(state):
    """
    Generate and execute an SQL query for ``state["question"]`` and time it.

    Runs convert_nl_to_sql followed by execute_sql on the same state dict. This
    bypasses the LangGraph workflow, so there is no relevance check and no
    query regeneration.

    Args:
        state (dict): Must contain "question" and "db_conn".

    Returns:
        dict: The same state with "execution_time" (elapsed seconds) set. On
        success it also holds "sql_query", "query_result" and "sql_error". If
        generation or execution raises, "sql_query" holds the error text,
        "sql_error" is True, and "query_result" is filled in if it was still
        missing, so callers can always read it.
    """
    # perf_counter is monotonic and high-resolution, unlike time.time().
    start_time = time.perf_counter()
    try:
        convert_nl_to_sql(state)  # generates the SQL query
        execute_sql(state)  # executes it and records the result
    except Exception as e:
        # Record the failure instead of aborting the whole benchmark run.
        state["sql_query"] = f"Error: {e}"
        state.setdefault("query_result", f"Error: {e}")
        state["sql_error"] = True

    state["execution_time"] = time.perf_counter() - start_time
    return state

In [None]:
# Benchmark questions paired with a reference query for manual comparison.
# Column names follow the schema described in the prompts ("broker", and the
# reserved word "case", which must be quoted when used bare).
test_cases = [
    {"question": "How many cases are there in total in the database?",
     "expected_sql": 'SELECT COUNT(*) FROM cases;'},
    {"question": "How many activities have been recorded?",
     "expected_sql": 'SELECT COUNT(*) FROM activity;'},
    {"question": "What are the different types of insurance available?",
     "expected_sql": 'SELECT DISTINCT type FROM cases;'},
    {"question": "What is the total value of approved insurance policies?",
     "expected_sql": 'SELECT SUM(value) FROM cases WHERE approved = 1;'},
    {"question": "How many cases were created in January?",
     "expected_sql": 'SELECT COUNT(*) FROM cases WHERE EXTRACT(MONTH FROM insurance_creation) = 1;'},
    # Was "brocker": no such column exists, and the typo in the question
    # nudged the model toward the wrong column name.
    {"question": "Who is the broker with the most assigned cases?",
     "expected_sql": 'SELECT broker FROM cases GROUP BY broker ORDER BY COUNT(*) DESC LIMIT 1;'},
    {"question": "What is the most frequent activity in the database?",
     "expected_sql": 'SELECT name FROM activity GROUP BY name ORDER BY COUNT(*) DESC LIMIT 1;'},
    {"question": "What is the total insurance value for each type of 'ramo'?",
     "expected_sql": 'SELECT ramo, SUM(value) FROM cases GROUP BY ramo;'},
    {"question": "On average, how long does it take for an insurance policy to be approved?",
     "expected_sql": 'SELECT AVG(EXTRACT(DAY FROM insurance_start - insurance_creation)) FROM cases WHERE approved = 1;'},
    {"question": "What is the average number of activities per case?",
     "expected_sql": 'SELECT COUNT(activity.id) / COUNT(DISTINCT cases.id) FROM activity INNER JOIN cases ON activity.case = cases.id;'},
    {"question": "What is the most frequent activity performed in approved cases?",
     "expected_sql": 'SELECT activity.name FROM activity INNER JOIN cases ON activity.case = cases.id WHERE cases.approved = TRUE GROUP BY activity.name ORDER BY COUNT(activity.id) DESC LIMIT 1;'},
    {"question": "What is the total duration of all activities for each case?",
     "expected_sql": 'SELECT activity.case, SUM(activity.tpt) FROM activity GROUP BY activity.case;'},
    # A bare `case` starts a CASE expression (syntax error); quote the column.
    {"question": "What is the total value of cases that have at least one recorded activity?",
     "expected_sql": 'SELECT SUM(cases.value) FROM cases WHERE cases.id IN (SELECT DISTINCT "case" FROM activity);'},
]

In [None]:
results = []

# Run every benchmark question through SQL generation + execution and collect
# one report row per question.
for test in test_cases:
    state = {"question": test["question"], "expected_sql": test["expected_sql"], "db_conn": db_conn}
    result_state = time_query_generation(state)

    results.append({
        "Question": test["question"],
        "Expected SQL Query": test["expected_sql"],
        # .get(): a failure before execution can leave these keys unset, and
        # one failing question should not abort the whole run.
        "Generated SQL Query": result_state.get("sql_query", ""),
        "Query Result": result_state.get("query_result", ""),
        "Execution Time (s)": result_state["execution_time"],
    })

# Convert results to DataFrame
df_results = pd.DataFrame(results)



In [None]:
# Display the benchmark table (last expression of the cell is rendered).
df_results



In [None]:
# Persist the benchmark results.
# NOTE(review): the file name says "DeepSeek", but this notebook queries
# gemma3:1b and duckdb-nsql — confirm the intended output name so earlier
# runs are not overwritten or mislabeled.
df_results.to_csv("results_DeepSeek.csv", index=False)
