In [None]:
!pip install fpdf plotly kaleido
from dotenv import load_dotenv
load_dotenv()
from langchain_openai import ChatOpenAI
from langgraph_supervisor import create_supervisor
from langgraph.prebuilt import create_react_agent
from langchain.agents import AgentExecutor
from typing_extensions import Annotated, TypedDict
from typing import Callable, Literal, Optional, Sequence, Type, TypeVar, Union, cast, List, Tuple, Any
from langchain_core.messages import AIMessage, BaseMessage, SystemMessage, ToolMessage
from langgraph.graph.message import add_messages
from langgraph.graph import StateGraph
from langgraph.prebuilt.tool_node import ToolNode, tools_condition
from langgraph.graph import MessagesState, START, END
from langgraph.types import Command
from langchain_core.messages import HumanMessage, AIMessage
from pydantic import BaseModel, Field
from langchain_core.runnables import RunnablePassthrough, RunnableLambda
from langchain_core.tools import tool
import json
import re
from fpdf import FPDF
import datetime
import os
import pandas as pd
import plotly.graph_objects as go

In [None]:
from sql_react_agent_llama import SQL_SUBAGENT, make_dataframe
from viz_agent import VIZ_AGENT
from initiate_llm import gpt_llm, llama_llm

##### Agent State and Setup

In [None]:
class AgentState(TypedDict):
    """Shared state that the supervisor and every worker node read and update.

    Each non-message field carries its human-readable description as
    ``Annotated`` metadata. TypedDict class bodies only support bare
    ``name: type`` items, so descriptions must not be attached with
    ``= Field(...)`` assignments.
    """
    # Conversation history; `add_messages` is the reducer applied to updates.
    messages: Annotated[Sequence[BaseMessage], add_messages]
    question: Annotated[str, Field(description="The user question")]
    sql_query: Annotated[str, Field(description="The SQL query generated by the agent")]
    results: Annotated[
        List[List[Union[int, float, str, bool]]],
        Field(description="The results returned by the SQL agent as a list of tuples."),
    ]
    df: Annotated[
        pd.DataFrame,
        Field(description="The table returned by the make_table node as a pandas dataframe."),
    ]
    python_visualization_code: Annotated[str, Field(description="The Python code to visualize the results")]

##### SQL Sub-agent

In [None]:
def nl2sql_node(state: AgentState) -> Command[Literal["supervisor"]]:
    """Run the SQL sub-agent on the latest message and record its output.

    The latest message's content is forwarded verbatim to ``SQL_SUBAGENT``.
    If that content is a JSON object with a ``"question"`` key, the value of
    that key is stored as the question; otherwise the raw content is stored.

    Args:
        state: Current graph state; only ``state['messages'][-1]`` is read.

    Returns:
        A ``Command`` routing back to ``"supervisor"`` that appends the
        sub-agent's final message (named ``sql_agent``) and sets
        ``question``, ``sql_query`` and ``results``.

    Raises:
        json.JSONDecodeError: If the sub-agent's final message is not JSON.
        KeyError: If that JSON lacks the ``query`` or ``result`` key.
    """
    last = state['messages'][-1]
    raw_content = last.content

    # Plain-text questions are the normal case. Only unwrap the message when
    # it really is a JSON object carrying a "question" key; valid JSON of any
    # other shape ("42", "[]", "{}") is kept as the question text itself.
    question = raw_content
    try:
        payload = json.loads(raw_content)
    except (json.JSONDecodeError, TypeError):
        payload = None
    if isinstance(payload, dict) and "question" in payload:
        question = payload["question"]

    result = SQL_SUBAGENT.invoke({"messages": [HumanMessage(content=raw_content)]})
    answer = result["messages"][-1].content
    # Decode the sub-agent's answer once and reuse it for both fields.
    parsed_answer = json.loads(answer)
    return Command(
        update={
            "messages": [
                HumanMessage(content=answer, name="sql_agent")
            ],
            "question": question,
            "sql_query": parsed_answer['query'],
            "results": parsed_answer['result'],
        },
        goto="supervisor",
    )

##### Make table node

In [None]:
def make_table_node(state: AgentState) -> Command[Literal["supervisor"]]:
    """Turn the stored SQL results into a DataFrame and report back.

    Reads ``sql_query`` and ``results`` from the state and passes them to
    ``make_dataframe``. On success the table is stored in ``df`` and echoed
    in a message; on any exception an empty DataFrame is stored and the
    error text is reported instead. Control always returns to the supervisor.
    """
    def _back_to_supervisor(text, frame):
        # Both outcomes share the same update shape and destination.
        return Command(
            update={
                "messages": [HumanMessage(content=text, name="make_table")],
                "df": frame,
            },
            goto="supervisor",
        )

    print("Entering make_table_node")
    sql_text, rows = state["sql_query"], state["results"]
    print(f"Query: {sql_text}, Results: {rows}")
    try:
        table = make_dataframe(sql_text, rows)
        print(f"DataFrame created: {table}")
        return _back_to_supervisor(
            f"Here is the table created from the results:\n{table.to_string()}",
            table,
        )
    except Exception as err:
        print(f"Error in make_table_node: {str(err)}")
        return _back_to_supervisor(
            f"Error creating DataFrame: {str(err)}",
            pd.DataFrame(),
        )

##### Visualization Node

In [None]:
def viz_node(state: AgentState) -> Command[Literal["supervisor"]]:
    """Ask the visualization agent for plotting code and store it in the state.

    The agent is invoked with the question, DataFrame and raw results taken
    from the state, plus an empty message list and an empty error string.
    The ``imports`` and ``code`` attributes of its ``generation`` output are
    joined with a newline, stored in ``python_visualization_code`` and echoed
    in a message before control returns to the supervisor.
    """
    agent_input = {
        "question": state["question"],
        "df": state["df"],
        "results": state["results"],
        "messages": [],
        "error": "",
    }
    generation = VIZ_AGENT.invoke(agent_input)["generation"]
    viz_code = generation.imports + "\n" + generation.code

    announcement = HumanMessage(
        content=f"Here is the visualization code for plotting the data:\n{viz_code}",
        name="viz_agent",
    )
    return Command(
        update={
            "messages": [announcement],
            "python_visualization_code": viz_code,
        },
        goto="supervisor",
    )

##### Setting up Supervisor and Report Generation

In [None]:
# Worker names shown to the supervisor LLM through the prompt's {members}
# slot. They must match the node names registered on the graph and the
# Router literal, hence "make_table_node" rather than "make_table".
members = ["sql_agent", "make_table_node", "viz_agent"]
# Every routing target: the workers plus the two terminal choices.
options = members + ["generate_report", "finish"]

In [None]:
# System prompt for the supervisor LLM. `{members}` is filled in with
# str.format() inside supervisor_node, so any literal braces added to this
# text must be doubled ({{ }}).
# NOTE(review): the rules below call the table worker `make_table`, while the
# graph node and the Router literal are named `make_table_node` -- confirm the
# model resolves this reliably.
supervisor_prompt = (
    """
    You are a supervisor agent named 'SQLOrchestrator' tasked with managing the following workers: {members}.

    Your job is to orchestrate a multi-step pipeline to answer the user's question stored in `AgentState.question`. Examine the current `AgentState` fields and decide the next worker based on the following strict rules:

    1. **SQL Generation & Execution**  
       - If `AgentState.results` is empty, an empty string, or not populated with query results, route to `sql_agent`.  
       - The `sql_agent` generates `AgentState.sql_query`, executes it, and populates `AgentState.results` with a list of result tuples.

    2. **Table Construction**  
       - After `results` is populated (i.e., contains a non-empty list of result tuples), if `AgentState.df` is empty, an empty string, or marked as '<not populated>', route to `make_table`.  
       - The `make_table` node converts `AgentState.results` into a pandas DataFrame and populates `AgentState.df`.

    3. **Visualization Code**  
       - After `df` is populated (i.e., contains a DataFrame with data, marked as '<DataFrame populated>'), if `AgentState.python_visualization_code` is empty or an empty string, route to `viz_agent`.  
       - The `viz_agent` generates Python visualization code and populates `AgentState.python_visualization_code`.

    4. **Report Generation or Completion**  
       - After `results`, `df`, and `python_visualization_code` are all populated (i.e., `results` is a non-empty list, `df` is '<DataFrame populated>', and `python_visualization_code` is a non-empty string):
         - If the user's question in `AgentState.question` contains phrases like 'generate a report' or 'create a report' (case-insensitive), route to `generate_report`. The `generate_report` node creates a PDF report summarizing the results and visualization.
         - Otherwise, route to `finish` and provide a final answer summarizing `AgentState.df` and referencing `AgentState.python_visualization_code`.

    **Important Instructions**:
    - Follow the pipeline (`sql_agent` -> `make_table` -> `viz_agent`) for all queries.
    - Only route to `generate_report` if the user explicitly requests a report in `AgentState.question` (e.g., includes 'generate a report' or 'create a report').
    - For non-report queries, complete the workflow after `viz_agent` by routing to `finish` with a clear summary of the results.
    - Treat `AgentState.df` as populated only if it is marked as '<DataFrame populated>' in the state.
    - Do not skip steps or route to `generate_report` unless explicitly requested and all prior steps are complete.
    - If unsure, prioritize completing the routine pipeline and routing to `finish` over premature report generation.
    """
)

In [None]:
# Structured-output schema for the supervisor's routing decision.
# NOTE(review): with_structured_output presumably forwards the docstring below
# to the model as the schema description -- treat edits to it as prompt changes.
class Router(TypedDict):
    """Worker to route to next. If no workers needed, route to finish or generate_report with final answer."""
    # Routing target; supervisor_node maps "finish" to END and uses the other
    # values directly as graph node names.
    next: Literal["sql_agent", "make_table_node", "viz_agent", "generate_report", "finish"]
    # Text the supervisor posts as its closing AIMessage; only read when
    # `next` is "generate_report" or "finish".
    final_answer: str

In [None]:
def supervisor_node(state: AgentState) -> Command[Literal["sql_agent", "make_table_node", "viz_agent", "generate_report", "__end__"]]:
    """Ask the routing LLM which node should run next and hand control to it.

    Every state field except ``messages`` is serialized to JSON (DataFrames
    are replaced by a populated / not-populated marker) and sent to the LLM
    after the system prompt, followed by the conversation so far. When the
    decision is ``generate_report`` or ``finish``, the LLM's ``final_answer``
    is appended as a supervisor message and control goes to the report node
    or to END respectively; any other decision is used directly as the next
    node name.
    """
    def _json_fallback(value):
        # json.dumps only calls this for values it cannot encode itself.
        if not isinstance(value, pd.DataFrame):
            return str(value)
        return "<not populated>" if value.empty else "<DataFrame populated>"

    print("\n\n========BACK TO SUPERVISOR========\n")
    snapshot = json.dumps(
        {field: value for field, value in state.items() if field != 'messages'},
        default=_json_fallback,
    )
    system_turn = {"role": "system", "content": supervisor_prompt.format(members=members)}
    state_turn = {"role": "user", "content": f"AgentState: {snapshot}"}
    conversation = [system_turn, state_turn] + state["messages"]
    print("INVOKING WITH STATE INFO\n", snapshot)

    router_llm = gpt_llm.with_structured_output(Router)
    decision = router_llm.invoke(conversation)
    next_worker = decision["next"]
    print(f"Next Worker: {next_worker}")
    print(f"LLM Response: {decision}")

    # Ordinary workers need no state update; just dispatch to them.
    if next_worker not in ("generate_report", "finish"):
        return Command(goto=next_worker)

    closing_note = AIMessage(content=decision["final_answer"], name="supervisor")
    return Command(
        update={"messages": [closing_note]},
        goto=END if next_worker == "finish" else "generate_report",
    )

In [None]:
def generate_report(state: AgentState) -> Command[Literal["__end__"]]:
    """Write a PDF report of the question, SQL query, results and chart.

    The report is saved in the current working directory as
    ``report_<timestamp>.pdf``. The first ten rows of ``df`` are included as
    plain text. If visualization code is present it is executed to render a
    temporary PNG that is embedded and then deleted; a failure there is
    written into the PDF instead of aborting the report.

    Args:
        state: Graph state; ``question``, ``sql_query``, ``df`` and
            ``python_visualization_code`` are read, all optional.

    Returns:
        A ``Command`` that ends the graph and appends an ``AIMessage``
        naming the saved file.
    """
    def _pdf_safe(text) -> str:
        # The built-in PDF fonts only cover latin-1; replace anything else
        # with "?" so unusual characters cannot abort the whole report.
        return str(text).encode("latin-1", "replace").decode("latin-1")

    print("Starting PDF generation...")
    pdf = FPDF()
    pdf.add_page()
    pdf.set_font("Arial", size=12)

    # Extract relevant information
    question = state.get("question", "Unknown question")
    sql_query = state.get("sql_query", "No SQL query found")
    df = state.get("df")
    # The state may still hold a non-DataFrame placeholder (e.g. "") when the
    # table step never ran; treat that as "no results" instead of crashing.
    if not isinstance(df, pd.DataFrame):
        df = pd.DataFrame()
    viz_code = state.get("python_visualization_code", "")

    # Add title
    pdf.cell(200, 10, txt="Analysis Report", ln=True, align='C')
    pdf.ln(10)

    # Add question
    pdf.set_font("Arial", "B", 12)
    pdf.cell(0, 10, txt="Question:", ln=True)
    pdf.set_font("Arial", size=12)
    pdf.multi_cell(0, 10, txt=_pdf_safe(question))
    pdf.ln(5)

    # Add SQL query
    if sql_query and sql_query != "No SQL query found":
        pdf.set_font("Arial", "B", 12)
        pdf.cell(0, 10, txt="SQL Query:", ln=True)
        pdf.set_font("Arial", size=12)
        pdf.multi_cell(0, 10, txt=_pdf_safe(sql_query))
        pdf.ln(5)

    # Add results (first ten rows only, to keep the report short)
    if not df.empty:
        pdf.set_font("Arial", "B", 12)
        pdf.cell(0, 10, txt="Results:", ln=True)
        pdf.set_font("Arial", size=12)
        table_str = df.head(10).to_string(index=False)
        pdf.multi_cell(0, 10, txt=_pdf_safe(table_str))
        pdf.ln(5)

    # Generate and include visualization
    plot_file = None
    if viz_code:
        try:
            # Execute Plotly code to save figure as image
            plot_file = f"plot_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.png"
            # SECURITY NOTE: viz_code is LLM-generated and is executed with
            # full interpreter privileges. Only run this against trusted
            # inputs, or move the execution into a sandbox.
            exec("import pandas as pd\n" + viz_code + f"\nfig.write_image('{plot_file}')", {"df": df, "pd": pd, "go": go, "fig": None})

            # Add visualization to PDF
            pdf.set_font("Arial", "B", 12)
            pdf.cell(0, 10, txt="Visualization:", ln=True)
            pdf.image(plot_file, x=10, w=180)
            pdf.ln(5)
        except Exception as e:
            pdf.set_font("Arial", size=12)
            pdf.multi_cell(0, 10, txt=_pdf_safe(f"Error generating visualization: {str(e)}"))
            pdf.ln(5)
        finally:
            # Clean up plot file
            if plot_file and os.path.exists(plot_file):
                os.remove(plot_file)

    # Save PDF
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    filename = f"report_{timestamp}.pdf"
    pdf.output(filename)
    print(f"Report saved as {filename} in {os.getcwd()}")

    return Command(
        update={
            "messages": [
                AIMessage(content=f"Report has been generated and saved as {filename}", name="generate_report")
            ]
        },
        goto=END
    )

In [None]:
# Assemble the workflow: register every node first, then wire the entry edge.
builder = StateGraph(AgentState)

# The supervisor plus the nodes it can dispatch to. Each node function
# returns a Command that names its own destination, so no worker edges
# are declared here.
builder.add_node("supervisor", supervisor_node)
builder.add_node("sql_agent", nl2sql_node)
builder.add_node("make_table_node", make_table_node)
builder.add_node("viz_agent", viz_node)
builder.add_node("generate_report", generate_report)

# Every run starts at the supervisor.
builder.add_edge(START, "supervisor")

graph = builder.compile()

In [None]:
from IPython.display import Image, display

# Best-effort rendering of the compiled graph as a Mermaid PNG. If rendering
# fails for any reason, print the error message instead of failing the cell.
try:
    display(Image(graph.get_graph().draw_mermaid_png()))
except Exception as render_error:
    print(str(render_error))

In [None]:
# Test routine query (no report)
# Placeholders use the types AgentState declares: an empty list for `results`
# and an empty DataFrame for `df`, which the supervisor serializes as
# '<not populated>'. Empty strings here would break code that expects a
# DataFrame (e.g. `df.empty`).
initial_state = {
    "messages": [HumanMessage(content="Give me the number of employees present from each ethnicity")],
    "question": "",
    "sql_query": "",
    "results": [],
    "df": pd.DataFrame(),
    "python_visualization_code": ""
}

result = graph.invoke(initial_state)
print("Routine Query Result:", result)
print(f"Final messages: {result['messages']}")

In [None]:
# Test report query
# Placeholders use the types AgentState declares: an empty list for `results`
# and an empty DataFrame for `df`, which the supervisor serializes as
# '<not populated>'. Empty strings here would break code that expects a
# DataFrame (e.g. `df.empty` in the report step).
report_state = {
    "messages": [HumanMessage(content="Give me the number of employees for each ethnicity and generate me a report")],
    "question": "",
    "sql_query": "",
    "results": [],
    "df": pd.DataFrame(),
    "python_visualization_code": ""
}

test_result = graph.invoke(report_state)
print("Report Query Result:", test_result)
print(f"Final messages: {test_result['messages']}")
# Verify the PDF was generated
for msg in test_result['messages']:
    if 'Report has been generated' in msg.content:
        print(f"PDF Generation Confirmation: {msg.content}")