In [1]:
from langchain_ollama import OllamaLLM
from typing_extensions import TypedDict
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langgraph.graph import StateGraph, START, END
from db_create import CargaDeArchivos

In [2]:
def fetch_schema(conn):
    """
    Fetch the schema of the 'main' schema of a DuckDB database, including column descriptions.

    Args:
        conn: Active DuckDB connection. Only ``conn.cursor()`` is used; the cursor
            must support ``execute(sql)``, ``execute(sql, params)`` and ``fetchall()``.

    Returns:
        dict: A dictionary where keys are table names and values are dictionaries
              mapping each column name to ``{"type": ..., "description": ...}``.
              Columns without a predefined description get a generic one.
    """
    cursor = conn.cursor()

    # Get all table names
    cursor.execute("SELECT table_name FROM information_schema.tables WHERE table_schema = 'main'")
    tables = [row[0] for row in cursor.fetchall()]

    # Predefined column descriptions optimized for SQL agent interpretation
    column_descriptions = {
        "cases": {
            "id": "Primary key. Unique identifier for each case.",
            "insurance": "Foreign key. Unique identifier for the related insurance policy.",
            "avg_time": "Time duration (in seconds) from case initiation to closure.",
            "type": "Category of insurance policy (e.g., health, auto, home).",
            "branch": "Branch where the policy was issued.",
            "ramo": "Specific coverage type under the insurance policy.",
            "brocker": "Broker responsible for selling the insurance policy.",
            "client": "Client who purchased the insurance policy.",
            "creator": "Employee responsible for managing the case.",
            "value": "Monetary value of the insurance policy.",
            "approved": "Approval status: 1 = Approved, 0 = Not Approved.",
            "insurance_creation": "Timestamp of when the insurance policy was created.",
            "insurance_start": "Timestamp of when the policy coverage begins.",
            "insurance_end": "Timestamp of when the policy coverage expires.",
        },
        "activity": {
            "id": "Primary key. Unique identifier for each recorded activity.",
            "case": "Foreign key. Identifier of the case this activity belongs to.",
            "timestamp": "Timestamp indicating when the activity took place.",
            "name": "Descriptive name of the activity performed.",
            "case_index": "Alias for 'id'. Unique identifier for the activity record.",
            "tpt": "Time duration (in seconds) for this specific activity.",
        }
    }
    default_description = "Relevant column for querying this table."

    # The table name is bound as a parameter instead of being pasted into the
    # SQL text, so names containing quotes cannot break (or alter) the query.
    # The lookup is restricted to the same schema the table list came from and
    # ordered by column position so the result is deterministic.
    columns_query = """
        SELECT column_name, data_type
        FROM information_schema.columns
        WHERE table_schema = 'main' AND table_name = ?
        ORDER BY ordinal_position
    """

    schema_info = {}
    for table in tables:
        cursor.execute(columns_query, [table])
        known_descriptions = column_descriptions.get(table, {})
        schema_info[table] = {
            column_name: {
                "type": data_type,
                "description": known_descriptions.get(column_name, default_description),
            }
            for column_name, data_type in cursor.fetchall()
        }

    return schema_info

In [3]:
# Build the loader from the project's db_create module and run its load step.
# NOTE(review): run_carga() presumably loads the source files into the DuckDB
# database — confirm against db_create.
a= CargaDeArchivos()
a.run_carga()
# Connection exposed by the loader; it is passed to fetch_schema here and to the workflow later.
db_conn= a.conn
# Introspect the schema once so it can be handed to the workflow as part of its initial state.
schema=fetch_schema(db_conn)
# Bare expression: lets the notebook display the schema dict as the cell output.
schema

{'activity': {'id': {'type': 'BIGINT',
   'description': 'Primary key. Unique identifier for each recorded activity.'},
  'case': {'type': 'VARCHAR',
   'description': 'Foreign key. Identifier of the case this activity belongs to.'},
  'timestamp': {'type': 'TIMESTAMP_NS',
   'description': 'Timestamp indicating when the activity took place.'},
  'name': {'type': 'VARCHAR',
   'description': 'Descriptive name of the activity performed.'},
  'case_index': {'type': 'BIGINT',
   'description': "Alias for 'id'. Unique identifier for the activity record."},
  'tpt': {'type': 'DOUBLE',
   'description': 'Time duration (in seconds) for this specific activity.'}},
 'cases': {'id': {'type': 'VARCHAR',
   'description': 'Primary key. Unique identifier for each case.'},
  'insurance': {'type': 'BIGINT',
   'description': 'Foreign key. Unique identifier for the related insurance policy.'},
  'avg_time': {'type': 'DOUBLE',
   'description': 'Time duration (in seconds) from case initiation to closur

In [4]:
class State(TypedDict):
    """
    Shared state passed between the nodes of the workflow.

    Each node reads the keys it needs, writes its results back and returns the
    state. Only ``question``, ``schema`` and ``db_conn`` are supplied when the
    workflow is invoked; the remaining keys are filled in by the nodes.
    """
    # User question; replaced by a reformulated version in regenerate_query.
    question: str
    # Schema dictionary as returned by fetch_schema (table -> column -> info).
    schema: dict
    # Open database connection; execute_sql calls .cursor() on it.
    db_conn: object
    # "relevant" or "not_relevant", set by check_relevance.
    relevance: str
    # SQL text produced by convert_nl_to_sql and rewritten by format_query.
    sql_query: str
    # Human-readable query output, or the error message when execution failed.
    query_result: str
    # Raw rows of the last successful SELECT, written by execute_sql.
    query_rows: list
    # True when the last execute_sql call raised.
    sql_error: bool
    # Text answer produced by one of the answer-generating nodes.
    final_answer: str
    # Number of question reformulations performed so far.
    attempts: int

In [5]:
def check_relevance(state: State):
    """
    Determines whether the user's question is relevant to the database schema.

    The question and schema are passed to the prompt as template variables
    (never pasted into the template text), so braces in the question cannot be
    misread as placeholders.

    Args:
        state (State): The current state of the workflow. Reads ``question``
            and ``schema``.

    Returns:
        State: The same state with ``relevance`` set to ``"relevant"`` or
        ``"not_relevant"`` and ``attempts`` reset to 0.

    Raises:
        ValueError: If the model's reply cannot be mapped to one of the two
            expected labels.
    """
    # Extract the question and schema from the state
    question = state["question"]
    schema = state["schema"]
    print(f"Checking relevance of the question: {question}")

    # Define the system prompt for relevance checking ({schema} is filled in at invoke time)
    system = """
    You are an assistant that determines whether a given question is related to the following database schema
    and to the use case related with process mining for policy insurances.
    A question is considered **relevant** if it pertains to activities, cases, durations, business insights, or any other concepts related to process analysis and revenue assessment.
    
    ### **Relevance Criteria:**
    - Questions about **cases, activities, and process durations** are relevant.
    - Questions involving **business insights**, such as client revenue, broker performance, or policy value trends, are relevant.
    - If the question is purely conceptual (e.g., "What is an activity?"), it is **not relevant**, even if it contains a column name.
    
    ### **Schema:**
    {schema}
    
    Respond with only "relevant" or "not_relevant" no explanation is needed.
    """

    # Create a prompt template for the LLM
    check_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system),
            ("human", "Question: {question}"),
        ]
    )

    # Invoke the LLM to determine relevance
    llm = OllamaLLM(model="mistral:latest", temperature=0.0)
    relevance_checker = check_prompt | llm
    response = relevance_checker.invoke({"schema": schema, "question": question}).strip().lower()

    # Tolerate harmless deviations such as "Relevant.", "**relevant**" or
    # "not relevant" instead of aborting the whole workflow on them.
    normalized = response.strip(" \t\r\n\"'`*.!:").replace(" ", "_").replace("-", "_")
    if normalized.startswith("not_relevant"):
        relevance = "not_relevant"
    elif normalized.startswith("relevant"):
        relevance = "relevant"
    else:
        raise ValueError(f"Unexpected relevance response: {response}")

    # Update the state with the relevance result
    state["relevance"] = relevance
    state["attempts"] = 0
    print(f"Relevance determined: {state['relevance']}")
    return state

def convert_nl_to_sql(state: State):
    """
    Converts a natural language question into an SQL query based on the database schema.

    The prompt asks for DuckDB (PostgreSQL-compatible) SQL, matching the
    database that fetch_schema documents as the target, instead of SQLite
    syntax.

    Args:
        state (State): The current state of the workflow. Reads ``question``
            and ``schema``.

    Returns:
        State: The same state with ``sql_query`` set to the model's raw reply.
    """
    question = state["question"]
    schema = state["schema"]
    print(f"Converting question to SQL {question}")

    # {schema} and {question} are template variables filled in at invoke time.
    system = """You are an SQL assistant. Your task is to transform natural language questions into 
    SQL queries that conform to the following schema:
    
    {schema}
    
    ### **Database Overview**
    - The database consists of **two tables**: `"cases"` and `"activity"`.
    - The `"cases"` table contains details about each case, uniquely identified by `"id"`.
    - The `"activity"` table stores activities related to cases.
      - The `"case"` column in `"activity"` corresponds to `"id"` in `"cases"`, forming a logical relationship.
      - This relationship allows **joining the tables** to retrieve case-specific activities.
      - Each **case can have multiple activities**, each recorded with a `"timestamp"`.
    
    ### **Query Guidelines**
    - **To retrieve case-related data**, ensure:
      - Use **explicit joins** when combining data from multiple tables.
      - Match the `"case"` column in `"activity"` to the `"id"` column in `"cases"`.
    
    ### **Rules:**
    - Ensure SQL queries match the exact column names in the schema.
    - Durations for activities refers to the column tpt, for the durations of cases look to the column avg_time.
    - Avoid incorrect grouping or aggregations that do not consider timestamps.
    - Return only the **DuckDB (PostgreSQL-compatible) SQL query**, without explanations.
    - **Prohibited SQL:** DELETE, CREATE, INSERT, ALTER, UPDATE, TRUNCATE.
    """

    llm = OllamaLLM(model="deepseek-coder:6.7b", temperature=0.0)

    convert_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system),
            ("human", "Question: {question}"),
        ]
    )
    sql_generator = convert_prompt | llm
    result = sql_generator.invoke({"question": question, "schema": schema})
    state["sql_query"] = result
    print(f"Generated SQL query: {state['sql_query']}")
    return state


def format_query(state: State):
    """
    Asks the LLM to correct column names in the SQL query so it can run as raw SQL.

    The prompt targets PostgreSQL-style quoting. The query is handed to the
    prompt as the ``{query}`` template variable rather than being pasted into
    the template text, so braces inside the SQL (e.g. JSON or struct literals)
    cannot be misread as placeholders.

    Args:
        state (State): The current state of the workflow. Reads ``sql_query``.

    Returns:
        State: The same state with ``sql_query`` replaced by the model's reply.
    """
    print("Formatting query.")
    query = state["sql_query"]
    system = """
    You are an AI assistant responsible for formatting SQL queries to ensure they can be executed as raw SQL over a **PostgreSQL database**.

    ### Rules:
    - Ensure column names match the **PostgreSQL database schema exactly**.
    - The database consists of **two tables**: `"cases"` and `"activity"`.
    - Column names and their correct format:
        - cases.id → "id"
        - cases.insurance → "insurance"
        - cases.avg_time → "avg_time"
        - cases.type → "type"
        - cases.branch → "branch"
        - cases.ramo → "ramo"
        - cases.broker → "brocker"
        - cases.state → "state"
        - cases.client → "client"
        - cases.creator → "creator"
        - cases.value → "value"
        - cases.approved → "approved"
        - cases.insurance_creation → "insurance_creation"
        - cases.insurance_start → "insurance_start"
        - cases.insurance_end → "insurance_end"
        - activity.id → "id"
        - activity.case → "case"
        - activity.timestamp → "timestamp"
        - activity.name → "name"
        - activity.case_index → "case_index"
        - activity.tpt → "tpt"

    ### Query Formatting:
    - **Correct column names only if they are incorrect.** Otherwise, keep them unchanged.
    - **Do NOT modify query structure or logic.**
    - **Strictly return only the corrected SQL query.** No explanations, comments, formatting notes, or additional text.
    """
    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system),
            ("human", "Input: {query}"),
        ]
    )
    llm = OllamaLLM(model="deepseek-coder:6.7b", temperature=0.0)
    response = prompt | llm | StrOutputParser()
    message = response.invoke({"query": query})
    state["sql_query"] = message
    return state


def execute_sql(state:State):
    """
    Executes the generated SQL query on the database connection and stores the results.

    The SQL comes from an LLM, so it is treated as untrusted: only a single
    read-only statement (``SELECT ...`` or a ``WITH ...`` CTE query) is executed.
    Anything else is rejected before it reaches the database and reported
    through the normal error path, which lets the workflow retry.

    Args:
        state (State): The current state of the workflow. Reads ``sql_query``
            and ``db_conn``.

    Returns:
        State: The same state. On success ``query_result`` holds one
        ``"col: value, ..."`` line per row (or ``"No results found."``),
        ``query_rows`` holds the raw rows and ``sql_error`` is False. On any
        failure ``query_result`` holds the error message and ``sql_error`` is
        True; no exception escapes this function.
    """
    sql_query = state["sql_query"].strip()
    db_conn = state["db_conn"]
    print(f"Executing SQL query: {sql_query}")

    try:
        lowered = sql_query.lower()

        # Coarse sanity check: at least one of the known table names must appear in the text.
        allowed_tables = ["cases", "activity"]
        if not any(table in lowered for table in allowed_tables):
            raise ValueError(f"Query must target only the tables: {', '.join(allowed_tables)}.")

        # Enforce the read-only rule the prompts only ask for (no DELETE/CREATE/INSERT/...).
        if not lowered.startswith(("select", "with")):
            raise ValueError("Only read-only SELECT queries are allowed.")

        # A ';' anywhere but the very end means more than one statement was supplied.
        if ";" in sql_query.rstrip("; \t\r\n"):
            raise ValueError("Only a single SQL statement is allowed.")

        cursor = db_conn.cursor()
        cursor.execute(sql_query)

        # Both SELECT and WITH queries produce a result set, so always fetch it.
        rows = cursor.fetchall()
        columns = [desc[0] for desc in cursor.description]

        # Format the output
        if rows:
            formatted_result = "\n".join(
                ", ".join(f"{col}: {row[idx]}" for idx, col in enumerate(columns))
                for row in rows
            )
            print("SQL SELECT query executed successfully.")
        else:
            formatted_result = "No results found."
            print("SQL SELECT query executed successfully but returned no rows.")

        state["query_rows"] = rows
        state["query_result"] = formatted_result
        state["sql_error"] = False

    except Exception as e:
        # Deliberately broad: any failure is turned into state so the router can trigger a retry.
        state["query_result"] = f"Error executing SQL query: {str(e)}"
        state["sql_error"] = True
        print(f"Error executing SQL query: {str(e)}")
    print(state['query_result'])
    return state


    
def generate_serious_answer(state: State):
    """
    Generates a serious and business-oriented response based on the SQL query results.

    The question, SQL text and query result are supplied to the prompt as
    template variables at invoke time. They are not formatted into the template
    text beforehand, because the template is parsed for ``{...}`` placeholders
    and any brace inside those values would otherwise break prompt rendering.

    Args:
        state (State): The current state of the workflow. Reads ``question``,
            ``query_result`` and ``sql_query``.

    Returns:
        State: The same state with ``final_answer`` set to the model's reply.
    """
    print("Generating a response for a related question.")
    question = state["question"]
    query_result = state['query_result']
    sql_query = state['sql_query']
    system = """
    You are a process mining assistant that helps users analyze business processes. Your task is to:
    1. Answer the user's original question based on the SQL query results.
    2. Provide relevant insights and business recommendations.
    
    ### **Context:**
    - The user's question: "{question}"
    - The SQL query executed: "{sql_query}"
    - The query result: "{query_result}"
    
    ### **How to Structure Your Response:**
    - **Start with a clear answer** to the user's question.
    - **Follow up with insights** based on the result. 
      - Identify trends, inefficiencies, or unusual patterns.
      - Suggest improvements (e.g., automation, better resource allocation).
      - Highlight comparisons if relevant (e.g., increase/decrease over time).
    
    ### **Example Outputs:**
    
    ❌ **User Question:** "What is the average duration of cases?"  
    ✅ **Response:**  
    "The average case duration is **5.2 days**.  
    Interestingly, cases in Department X take 40% longer than the company average. Consider automating Task Y to speed up the process."
    
    ❌ **User Question:** "How many cases happened in March?"  
    ✅ **Response:**  
    "There were **120 cases** in March.  
    This marks a 25% increase compared to February. If this trend continues, you may need additional resources for peak months."
    
    Keep your response concise, insightful, and business-oriented.
    """
    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system),
            ("human", "Question: {question}"),
        ]
    )
    llm = OllamaLLM(model="mistral:latest", temperature=0.0)
    response = prompt | llm | StrOutputParser()
    message = response.invoke(
        {"question": question, "sql_query": sql_query, "query_result": query_result}
    )
    state["final_answer"] = message
    print("Generated business response.")
    print(message)
    return state


def regenerate_query(state: State):
    """
    Reformulates the user's question to enable more precise SQL queries in case of errors.

    The failed query, its error message and the schema are inserted through
    single-brace template variables. Double braces are the escape for a literal
    brace, so a ``{{query}}`` placeholder would show the model the text
    "{query}" instead of the actual value.

    Args:
        state (State): The current state of the workflow. Reads ``question``,
            ``sql_query``, ``query_result`` (the error text) and ``schema``.

    Returns:
        State: The same state with ``question`` replaced by the reformulated
        question and ``attempts`` incremented by one (starting from 0 if unset).
    """
    question = state["question"]
    query = state['sql_query']
    error = state['query_result']
    schema = state['schema']
    print("Regenerating the SQL query by rewriting the question.")
    system = """You are an assistant that reformulates an original question to enable more precise SQL queries, while 
    considering that the previous sql query was {query} and it produced the next error {error}.. 
    Ensure that all necessary details, such as table joins, are preserved to retrieve complete and accurate data.
    take into account that the database follows the following schema:
    {schema}
    ### **Database Overview**
    - The database consists of **two tables**: `"cases"` and `"activity"`.
    - The `"cases"` table contains details about each case, uniquely identified by `"id"`.
    - The `"activity"` table stores activities related to cases.
      - The `"case"` column in `"activity"` corresponds to `"id"` in `"cases"`, forming a logical relationship.
      - Each **case can have multiple activities**, each recorded with a `"timestamp"`.
    ### Considerations:
    
        Answer only with the reformulated question, do not show what the query and error was. 
        No aditional information is needed
    """
    rewrite_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system),
            (
                "human",
                "Original Question: {question}\nReformulate the question to enable more precise SQL queries, ensuring all necessary details are preserved.",
            ),
        ]
    )
    llm = OllamaLLM(model="mistral:latest", temperature=0.0)
    rewriter = rewrite_prompt | llm
    rewritten = rewriter.invoke(
        {"schema": schema, "query": query, "error": error, "question": question}
    )
    state["question"] = rewritten
    # Tolerate a state where the counter was never initialised.
    state["attempts"] = state.get("attempts", 0) + 1
    print(f"Rewritten question: {state['question']}")
    return state


def end_max_iterations(state: State):
    """
    Ends the workflow after reaching the maximum number of attempts.

    Like the other terminal nodes, this one also fills ``final_answer`` so that
    callers can always read the outcome from the same key.

    Args:
        state (State): The current state of the workflow.

    Returns:
        State: The same state with ``query_result`` and ``final_answer`` both
        set to the termination message.
    """
    termination_message = "Please try again."
    state["query_result"] = termination_message
    state["final_answer"] = termination_message
    print("Maximum attempts reached. Ending the workflow.")
    return state



def generate_funny_response(state: State):
    """
    Generates a playful and humorous response for unrelated questions.

    The question is supplied as the ``{question}`` template variable rather than
    being pasted into the template text, so braces typed by the user cannot be
    misread as placeholders.

    Args:
        state (State): The current state of the workflow. Reads ``question``.

    Returns:
        State: The same state with ``final_answer`` set to the model's reply.
    """
    print("Generating a funny response for an unrelated question.")
    question = state["question"]
    system = """You are a charming and funny assistant who responds in a playful manner.
    """
    funny_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system),
            ("human", "Question: {question}"),
        ]
    )
    # Non-zero temperature on purpose: this branch is meant to be creative.
    llm = OllamaLLM(model="mistral:latest", temperature=0.7)
    funny_response = funny_prompt | llm | StrOutputParser()
    message = funny_response.invoke({"question": question})
    state["final_answer"] = message
    print("Generated funny response.")
    print(message)
    return state

    
def check_attempts_router(state: State):
    """
    Decide whether another SQL generation attempt is allowed.

    Args:
        state (State): The current state of the workflow; ``attempts`` is read.

    Returns:
        str: ``"convert_to_sql"`` while fewer than 3 attempts have been made,
        otherwise ``"end_max_iterations"``.
    """
    may_retry = state["attempts"] < 3
    return "convert_to_sql" if may_retry else "end_max_iterations"



def execute_sql_router(state: State):
    """
    Pick the node that follows SQL execution.

    Args:
        state (State): The current state of the workflow; ``sql_error`` is read
            and treated as False when absent.

    Returns:
        str: ``"regenerate_query"`` when the execution failed, otherwise
        ``"generate_serious_answer"``.
    """
    execution_failed = state.get("sql_error", False)
    if execution_failed:
        return "regenerate_query"
    return "generate_serious_answer"

    
    
def relevance_router(state: State):
    """
    Pick the node that follows the relevance check.

    Args:
        state (State): The current state of the workflow; ``relevance`` is read
            and compared case-insensitively.

    Returns:
        str: ``"convert_to_sql"`` when the relevance is ``"relevant"``,
        otherwise ``"generate_funny_response"``.
    """
    is_relevant = state["relevance"].lower() == "relevant"
    return "convert_to_sql" if is_relevant else "generate_funny_response"


In [6]:
# --- Graph definition -------------------------------------------------------
workflow = StateGraph(State)

# Nodes: one per processing step defined above.
workflow.add_node("check_relevance", check_relevance)
workflow.add_node("convert_to_sql", convert_nl_to_sql)
workflow.add_node("execute_sql", execute_sql)
workflow.add_node("format_query", format_query)
workflow.add_node("regenerate_query", regenerate_query)
workflow.add_node("generate_funny_response", generate_funny_response)
workflow.add_node("generate_serious_answer", generate_serious_answer)
workflow.add_node("end_max_iterations", end_max_iterations)

# Entry point, then branch on relevance: SQL pipeline or playful reply.
workflow.add_edge(START, "check_relevance")
workflow.add_conditional_edges(
    "check_relevance",
    relevance_router,
    {
        "convert_to_sql": "convert_to_sql",
        "generate_funny_response": "generate_funny_response",
    },
)

# Linear part of the SQL pipeline: generate -> format -> execute.
workflow.add_edge("convert_to_sql", "format_query")
workflow.add_edge("format_query", "execute_sql")

# After execution: answer on success, otherwise rewrite the question.
workflow.add_conditional_edges(
    "execute_sql",
    execute_sql_router,
    {
        "generate_serious_answer": "generate_serious_answer",
        "regenerate_query": "regenerate_query",
    },
)

# Retry loop: go back to SQL generation until the attempt limit is reached.
workflow.add_conditional_edges(
    "regenerate_query",
    check_attempts_router,
    {
        "convert_to_sql": "convert_to_sql",
        "end_max_iterations": "end_max_iterations",
    },
)

# Terminal nodes.
workflow.add_edge("end_max_iterations", END)
workflow.add_edge("generate_serious_answer", END)
workflow.add_edge("generate_funny_response", END)

chain = workflow.compile()


In [8]:
# Ask for the question interactively; this cell blocks until the user submits a line.
question= input("Enter your question: ")
# Run the compiled graph with the initial keys; the nodes fill in the rest of the state.
# The returned value is the final state of the run.
state= chain.invoke({"question":question,"schema":schema,"db_conn":db_conn})

Checking relevance of the question: how much was sold in march?
Relevance determined: relevant
Converting question to SQL how much was sold in march?
Generated SQL query: SELECT SUM(value) FROM cases WHERE strftime('%m', insurance_start) = '03' AND approved = 1;

Formatting query.
Executing SQL query: SELECT SUM("value") FROM "cases" WHERE EXTRACT(MONTH FROM "insurance_start") = 3 AND "approved" = 1;
SQL SELECT query executed successfully.
sum("value"): 1012700
Generating a response for a related question.
Generated business response.
 Answer: The total amount sold in March was **$1,012,700**.

   Insights: This represents a significant increase compared to the average monthly sales, indicating a successful marketing campaign or increased customer interest during that period. To maintain this momentum, consider reinvesting in marketing efforts for similar campaigns in future months. Additionally, analyzing the specific products with high sales in March could help identify opportunities