## imports

In [None]:
import json
import os
from typing import Annotated, TypedDict, Optional

# Assuming these are available or you will implement them
# from DataDistributor import DataDistributor
# from Constants import Constants # For API keys etc.

# You will need to replace this with your actual LLM initialization
# from langchain_core.language_models import BaseChatModel
# from langchain_core.output_parsers import JsonOutputParser
# from langchain_core.runnables import RunnablePassthrough
from langchain_core.messages import HumanMessage

# --- LangGraph Setup ---
from langgraph.graph import StateGraph, END, START
# from langgraph.prebuilt import ToolExecutor # You will likely use this in your real implementation
from langchain.tools import BaseTool

## Tools

In [None]:
# Placeholder for your tool implementations
# You will need to replace this with your actual tool definitions
# from your existing project (e.g., vector_search, parameter_search)
class MockVectorSearchTool:
    """Stand-in for the real vector search tool; always yields the same sample rows."""

    def invoke(self, query):
        """Log the query and return a freshly built list of mock records.

        The query is only echoed; it does not influence the returned rows.
        """
        print(f"MockVectorSearchTool invoked with query: {query}")
        # Mock structured data for visualization testing, as (organization, category, value).
        sample_rows = (
            ("Org A", "Category 1", 100),
            ("Org B", "Category 2", 150),
            ("Org A", "Category 1", 120),
            ("Org C", "Category 1", 80),
        )
        return [
            {"organization": org, "category": cat, "value": val}
            for org, cat, val in sample_rows
        ]

class MockParameterSearchTool:
    """Stand-in for the real parameter search tool; always yields the same sample rows."""

    def invoke(self, query):
        """Log the query and return a freshly built list of mock records.

        The query is only echoed; it does not influence the returned rows.
        """
        print(f"MockParameterSearchTool invoked with query: {query}")
        # Mock structured data for visualization testing.
        field_names = ("organization", "category", "value")
        sample_values = [
            ("Org A", "Category 1", 100),
            ("Org B", "Category 2", 150),
            ("Org A", "Category 1", 120),
            ("Org C", "Category 1", 80),
        ]
        return [dict(zip(field_names, values)) for values in sample_values]

## DataFraming 

In [None]:

# You will need to implement this based on your data structure
def data_to_dataframe(raw_data: list) -> 'pd.DataFrame':
    """
    Build a Pandas DataFrame from raw records (a list of dicts).

    pandas is imported lazily so the notebook loads without it;
    install it with: pip install pandas
    """
    from pandas import DataFrame

    print("Transforming data to DataFrame...")
    # Hook for flattening / pre-processing the records if your data needs it.
    frame = DataFrame(raw_data)
    return frame

## LLM Classes 

In [None]:
# Mock LLM classes for decision making and response generation
class MockDecisionLLM:
    """Keyword-driven stand-in for the LLM that classifies a request as visualization or text."""

    # Words whose presence in the prompt triggers the visualization branch.
    _VIZ_KEYWORDS = ("visualization", "chart", "graph", "plot")

    async def ainvoke(self, prompt: str):
        """
        Return a message whose content is a JSON object describing the decision.

        The content is produced with json.dumps so it is always valid JSON
        (the previous hand-written literal used doubled braces in a plain
        string, which json.loads rejects).
        """
        print(f"MockDecisionLLM received prompt:\n{prompt}")
        # NOTE(review): matching runs over the whole prompt, so instruction text
        # that itself contains one of these words also selects the visualization branch.
        lowered = prompt.lower()
        if any(keyword in lowered for keyword in self._VIZ_KEYWORDS):
            # Simulate a visualization request decision
            decision = {
                "request_type": "visualization",
                "chart_type": "bar",
                "data_fields": ["organization", "category"],
            }
        else:
            # Simulate a text workflow decision
            decision = {"request_type": "continue_text_workflow"}
        return HumanMessage(content=json.dumps(decision))

class MockToolCallingLLM:
    """Keyword-driven stand-in for the LLM that selects which search tool to call."""

    async def ainvoke(self, prompt: str):
        """
        Return a message carrying a simulated function call.

        The call is placed in the message's ``additional_kwargs["function_call"]``
        (not serialized into ``content``), because that is where fetch_context
        looks for it.
        """
        print(f"MockToolCallingLLM received prompt:\n{prompt}")
        # Simple mock logic to decide which tool to call
        lowered = prompt.lower()
        if "specific attributes" in lowered or "precise terms" in lowered:
            tool_name = "parameter_search_tool"
        else:
            tool_name = "vector_search_tool"
        # Simulate a tool call
        return HumanMessage(
            content="",
            additional_kwargs={
                "function_call": {
                    "name": tool_name,
                    "arguments": "{}",  # Mock arguments, refine as needed
                }
            },
        )


class MockSummarizeLLM:
    """Stand-in for the summarization LLM; always answers with a fixed summary."""

    async def ainvoke(self, prompt: str):
        """Echo the prompt to stdout and return a canned summary message."""
        print(f"MockSummarizeLLM received prompt:\n{prompt}")
        canned_summary = "This is a summarized context based on the search results."
        return HumanMessage(content=canned_summary)

class MockGenerateAnswerLLM:
    """Stand-in for the answer-generation LLM; always answers with a fixed text."""

    async def ainvoke(self, prompt: str):
        """Echo the prompt to stdout and return a canned final-answer message."""
        print(f"MockGenerateAnswerLLM received prompt:\n{prompt}")
        canned_answer = "This is the final text answer based on the summary."
        return HumanMessage(content=canned_answer)



## Utilities

In [None]:
# Placeholder for your add_messages function
def add_messages(left: list, right: list):
    """
    Return a new list containing the existing messages followed by the new ones.

    Neither input is modified: this function is used as the reducer for
    State['messages'], and mutating ``left`` in place would also change any
    earlier state (or the caller's initial list) that still references it.
    ``None`` for either argument is treated as an empty list.
    """
    return list(left or []) + list(right or [])

# Update State TypedDict
class State(TypedDict):
    """Shared state dictionary passed between the graph nodes."""

    # Message list; annotated with add_messages, which LangGraph is expected to
    # use as the merge function for this key (confirm against the LangGraph docs).
    messages: Annotated[list, add_messages]
    # Chat history text; interpolated into the tool-selection and answer prompts.
    history: str
    # The user's question; read by every node that builds a prompt.
    question: str
    # NOTE(review): should_split, split_questions, needs_history and needs_context
    # are not read by any node visible in this notebook; presumably they are
    # carried over from the existing project.
    should_split: bool
    split_questions: list
    needs_history: bool
    needs_context: bool
    # Set by fetch_context: JSON text of the tool output, or a status/error message.
    context: str
    # Set by summarize_context for the text path; cleared to "" by fetch_context.
    summarized_context: str
    # User-facing answer text; written by several nodes, including their error paths.
    final_answer: str

    # New fields for visualization
    # Parsed decision object from decide_next_step, or None when no chart is requested.
    visualization_request: Optional[dict]
    # Raw tool output (a list) stored by fetch_context when a chart is requested.
    visualization_data: Optional[list]
    # Written by generate_visualization_node: {"type": "vega-lite", "spec": ...}
    # or {"type": "image", "data": ...}.
    visualization_output: Optional[dict]
    # Description of the most recent visualization-related failure, if any.
    visualization_error: Optional[str]

# You will wrap your logic in a class like this
class GraphBuilder:
    def __init__(self):
        """Create the (mock) models and search tools that the graph nodes rely on."""
        # Search tools, keyed by the tool name that fetch_context looks up.
        # Replace the mocks with your real vector / parameter search tools.
        self.tools = dict(
            vector_search_tool=MockVectorSearchTool(),
            parameter_search_tool=MockParameterSearchTool(),
        )
        # self.tool_executor = ToolExecutor(list(self.tools.values())) # Use ToolExecutor in real implementation

        # One model per role; replace each mock with your actual LLM.
        self.decision_model = MockDecisionLLM()        # visualization vs. text decision
        self.tool_model = MockToolCallingLLM()         # search-tool selection
        self.summarize_model = MockSummarizeLLM()      # context summarization
        self.answer_model = MockGenerateAnswerLLM()    # final answer generation

## LIDA Manager

In [None]:
        # Placeholder for the LIDA Manager setup - requires installation: pip install lida
        # You will also need an API key for the text-generation provider that LIDA uses.
        # from lida import Manager, TextGenerationConfig, llm
        # from Constants import Constants # Assuming Constants holds your API key

## Workflow

In [None]:
    def create_workflow(self):
        """
        Build and return the (uncompiled) StateGraph for the chatbot.

        Flow: START -> decide_next_step -> fetch_context -> route_request, which
        sends the run either to generate_visualization -> END or to
        summarize -> generate_answer -> END.

        The decision runs before fetching because fetch_context only stores the
        raw tool output in 'visualization_data' when 'visualization_request' is
        already set; fetching first would leave the visualization node without data.
        """
        workflow = StateGraph(State)

        # Decision node: sets 'visualization_request'
        workflow.add_node("decide_next_step", self.decide_next_step)

        # Existing nodes (placeholders for your actual implementations)
        workflow.add_node("fetch_context", self.fetch_context) # Fetches data based on query
        workflow.add_node("summarize", self.summarize_context) # Summarizes context for text answer
        workflow.add_node("generate_answer", self.generate_answer) # Generates final text answer

        # New node for visualization
        workflow.add_node("generate_visualization", self.generate_visualization_node) # Will call the LIDA tool logic

        # Define edges
        workflow.add_edge(START, "decide_next_step")
        workflow.add_edge("decide_next_step", "fetch_context") # Fetch data once the request type is known

        # Conditional edges out of fetch_context (the StateGraph method is add_conditional_edges)
        workflow.add_conditional_edges(
            "fetch_context",
            self.route_request, # Method to determine the next node
            {
                "generate_visualization": "generate_visualization", # If visualization is requested
                "continue_text_workflow": "summarize", # If a standard text answer is needed
            },
        )

        # Edges from new nodes
        # Note: In your plan, generate_visualization goes to END.
        # If you need to potentially add a text response *after* visualization,
        # you might route it to 'generate_answer' instead of END, and modify generate_answer
        # to handle both text-only and viz+text responses. For this setup,
        # we'll follow your plan and route viz to END.
        workflow.add_edge("generate_visualization", END)
        workflow.add_edge("summarize", "generate_answer") # Existing text workflow continues
        workflow.add_edge("generate_answer", END) # End of text workflow

        return workflow

## Workflow (duplicate of the previous cell)

In [None]:
    def create_workflow(self):
        """
        Build and return the (uncompiled) StateGraph for the chatbot.

        Flow: START -> decide_next_step -> fetch_context -> route_request, which
        sends the run either to generate_visualization -> END or to
        summarize -> generate_answer -> END.

        The decision runs before fetching because fetch_context only stores the
        raw tool output in 'visualization_data' when 'visualization_request' is
        already set; fetching first would leave the visualization node without data.
        """
        workflow = StateGraph(State)

        # Decision node: sets 'visualization_request'
        workflow.add_node("decide_next_step", self.decide_next_step)

        # Existing nodes (placeholders for your actual implementations)
        workflow.add_node("fetch_context", self.fetch_context) # Fetches data based on query
        workflow.add_node("summarize", self.summarize_context) # Summarizes context for text answer
        workflow.add_node("generate_answer", self.generate_answer) # Generates final text answer

        # New node for visualization
        workflow.add_node("generate_visualization", self.generate_visualization_node) # Will call the LIDA tool logic

        # Define edges
        workflow.add_edge(START, "decide_next_step")
        workflow.add_edge("decide_next_step", "fetch_context") # Fetch data once the request type is known

        # Conditional edges out of fetch_context (the StateGraph method is add_conditional_edges)
        workflow.add_conditional_edges(
            "fetch_context",
            self.route_request, # Method to determine the next node
            {
                "generate_visualization": "generate_visualization", # If visualization is requested
                "continue_text_workflow": "summarize", # If a standard text answer is needed
            },
        )

        # Edges from new nodes
        # Note: In your plan, generate_visualization goes to END.
        # If you need to potentially add a text response *after* visualization,
        # you might route it to 'generate_answer' instead of END, and modify generate_answer
        # to handle both text-only and viz+text responses. For this setup,
        # we'll follow your plan and route viz to END.
        workflow.add_edge("generate_visualization", END)
        workflow.add_edge("summarize", "generate_answer") # Existing text workflow continues
        workflow.add_edge("generate_answer", END) # End of text workflow

        return workflow

## Node Methods

In [None]:
    # --- New and Modified Node Methods ---

    async def decide_next_step(self, state: State) -> State:
        """
        Ask the decision model whether the user wants a visualization.

        The user's question and the chat history are embedded in the prompt so
        the model can actually classify the request. When the parsed reply has
        "request_type" == "visualization", the whole parsed object is stored in
        state['visualization_request']; otherwise that field is set to None.

        Any failure (model error, reply that is not valid JSON, ...) is treated
        as "no visualization" and described in state['visualization_error'].
        """
        question = state['question']
        history = state.get('history', '')
        decision_prompt = f"""
         Based on the following user question and chat history, determine if the user is requesting a data visualization (chart, graph, plot).

         User question: {question}
         Chat history: {history}

         If the user explicitly asks for a visualization (e.g., "show me a bar chart", "plot X vs Y", "graph the distribution of Z"), respond with a JSON object indicating:
         1. "request_type": "visualization"
         2. "chart_type": Infer the type if mentioned (e.g., "bar", "line", "scatter"), otherwise null.
         3. "data_fields": A list of the specific fields from the data ("organization", "category", "research_areas", "keywords", "contacts", etc.) that are needed to create this visualization. Analyze the question carefully to identify the relevant fields.

         If the user is NOT requesting a visualization, respond with a JSON object indicating:
         1. "request_type": "continue_text_workflow"

         Your response should be a JSON object and contain ONLY the JSON object.
         """
        try:
            response = await self.decision_model.ainvoke(decision_prompt)
            decision = json.loads(response.content.strip())
            print(f"Decision from LLM: {decision}")

            if decision.get("request_type") == "visualization":
                # Store the structured request; the data itself is fetched by fetch_context.
                state['visualization_request'] = decision
            else:
                state['visualization_request'] = None # No visualization requested

        except Exception as e:
            # Node boundary: fall back to the text workflow rather than aborting the graph.
            print(f"Error in decide_next_step: {e}")
            state['visualization_request'] = None
            state['visualization_error'] = f"Error determining visualization request: {e}"

        return state

    def route_request(self, state: State) -> str:
        """
        Pick the next branch: "generate_visualization" when the state holds a
        truthy 'visualization_request', "continue_text_workflow" otherwise.
        """
        wants_chart = bool(state.get('visualization_request'))
        next_step = "generate_visualization" if wants_chart else "continue_text_workflow"
        print(f"Routing to {next_step}")
        return next_step

    async def generate_visualization_node(self, state: State) -> State:
        """
        Generate a chart for the user's question with LIDA and record it in the state.

        Reads state['question'] and state['visualization_data']. On success,
        state['visualization_output'] is set to {"type": "vega-lite", "spec": ...}
        or {"type": "image", "data": ...} and state['final_answer'] to a short
        intro. On any failure, state['final_answer'] and
        state['visualization_error'] describe the problem instead.

        Uses LIDA's public Manager API: Manager(text_gen=llm(...)), summarize()
        and visualize(). visualize() generates the chart code and executes it,
        returning result objects whose fields are read as attributes.
        """
        print("Executing generate_visualization_node")
        question = state['question']
        raw_data = state.get('visualization_data') # Data should be populated by fetch_context

        if not raw_data:
            state['final_answer'] = "I couldn't find enough data to create a visualization for that."
            state['visualization_error'] = "No data context available for visualization."
            return state

        try:
            # Transform data for LIDA (list of dicts to DataFrame)
            # Ensure you have pandas installed: pip install pandas
            df = data_to_dataframe(raw_data)

            # Ensure you have lida installed: pip install lida
            # (lida re-exports TextGenerationConfig and llm from its llmx dependency)
            from lida import Manager, TextGenerationConfig, llm

            # Replace with your actual API key retrieval
            lida_api_key = os.environ.get("GOOGLE_API_KEY") # Example using environment variable

            if not lida_api_key:
                state['final_answer'] = "LIDA API key not found. Please set it up."
                state['visualization_error'] = "Missing LIDA API key."
                return state

            # The API key belongs to the text generator, not to TextGenerationConfig.
            # NOTE(review): the provider name ("palm" is presumably llmx's Google
            # provider) and the model name are carried over from the draft -
            # confirm both against the llmx version you have installed.
            text_gen = llm(provider="palm", api_key=lida_api_key)
            textgen_config = TextGenerationConfig(n=1, temperature=0, model="gemini-flash")
            lida = Manager(text_gen=text_gen)

            # Summarize the data using LIDA (accepts a file path or a pandas DataFrame)
            data_summary = lida.summarize(df, textgen_config=textgen_config)
            print(f"LIDA Data Summary: {data_summary}")

            # Generate AND execute the visualization code; the user's query is the goal.
            # Ensure you have altair installed: pip install altair
            charts = lida.visualize(
                summary=data_summary,
                goal=question,
                textgen_config=textgen_config,
                library="altair",
            )
            print(f"LIDA Executed Charts: {charts}")

            if not charts:
                state['final_answer'] = "I understood you wanted a visualization, but I couldn't generate a suitable one based on the data and your request."
                state['visualization_error'] = "LIDA visualize returned no charts."
                return state

            # Chart results are objects, so read their fields as attributes.
            first_chart = charts[0]
            spec = getattr(first_chart, 'spec', None)
            raster = getattr(first_chart, 'raster', None)

            if spec:
                state['visualization_output'] = {
                    "type": "vega-lite",
                    "spec": spec # Vega-Lite JSON specification
                }
                state['final_answer'] = "Here is the visualization you requested:" # A text intro
            elif raster:
                # No spec but an image is available (assumed base64 - check LIDA output)
                state['visualization_output'] = {
                    "type": "image",
                    "data": raster # Base64 encoded image data
                }
                state['final_answer'] = "Here is the visualization you requested:" # A text intro
            else:
                state['final_answer'] = "I generated a visualization, but I couldn't format it correctly."
                state['visualization_error'] = "LIDA execution returned unexpected format."

        except ImportError as e:
            state['final_answer'] = f"Missing required library: {e}. Please install it (e.g., pip install lida pandas altair)."
            state['visualization_error'] = f"Import error: {e}"
        except Exception as e:
            # Node boundary: report the failure in the state instead of aborting the graph.
            print(f"Error during LIDA visualization generation: {e}")
            state['final_answer'] = "An error occurred while trying to generate the visualization."
            state['visualization_error'] = f"LIDA generation error: {e}"

        return state

## Fetch Context and Raw Output

In [None]:
    # Modify fetch_context to store raw results if visualization is requested
    async def fetch_context(self, state: State) -> State:
        """
        Pick a search tool via the tool model, run it, and store the result.

        When state['visualization_request'] is set and the tool returned a
        non-empty list, the raw list is kept in state['visualization_data'].
        Otherwise the tool output is serialized to JSON text in state['context']
        for the summarization step. If no tool can be determined, or anything
        raises, state['context'] and state['final_answer'] carry an explanatory
        message and the graph keeps running.
        """
        print("Executing fetch_context")
        question = state['question']
        history = state.get('history', '')
        viz_request = state.get('visualization_request') # Check if viz was requested

        try:
            # Use the tool_model to decide which search tool to use
            tool_decision_prompt = f"""
            Question Instructions\\n
                Inputs:

                Question: {question}

                Chat History: {history}

                Task: Your task is to generate an effective search query and determine the appropriate tool (vector_search_tool or parameter_search_tool) to retrieve the necessary information from the database. Consider both the question and the chat history (if relevant) when constructing your query and selecting the tool.
                If the question asks for specific attributes or uses precise terms, favor `parameter_search_tool`. If it's a more general or conceptual query, favor `vector_search_tool`.
                Also, determine if the user is asking for a *count* of entities. If so, set the `number` argument for `vector_search_tool` to true.

                Output: Return the tool call in the specified format.

                End of Instructions\\n
            """
            tool_d = await self.tool_model.ainvoke(tool_decision_prompt)
            tool_call = tool_d.additional_kwargs.get("function_call", {})

            if not tool_call:
                state['context'] = "Could not determine tool to call."
                state['summarized_context'] = ""
                state['final_answer'] = "I couldn't determine how to fetch data for your request."
                return state

            func_name = tool_call.get("name", "")
            if func_name not in self.tools:
                state['context'] = "Could not identify relevant tool."
                state['summarized_context'] = ""
                state['final_answer'] = "I couldn't find the right tool to fetch data for that."
                return state

            # In a real scenario, you'd parse and use the call's arguments.
            # For this mock, the original question serves as the query string.
            tool_query = question

            print(f"Invoking tool: {func_name}")
            # Invoke the tool and get raw output
            raw_output = self.tools[func_name].invoke(tool_query) # Use .invoke for sync mock tools

            if viz_request and raw_output and isinstance(raw_output, list):
                # Viz requested and the tool returned structured data:
                # keep the raw records for the visualization node.
                state['visualization_data'] = raw_output
                state['context'] = "Data fetched for visualization." # Indicate data was fetched
                state['summarized_context'] = "" # Clear summarized context for viz path
            else:
                # Text path: serialize the output for summarization. default=str
                # keeps non-JSON-native values (dates, custom objects) from
                # aborting the fetch; they are rendered via str().
                state['context'] = json.dumps(raw_output, default=str)
                state['summarized_context'] = "" # Summarization will happen in the next node

        except Exception as e:
            # Node boundary: report the failure in the state instead of aborting the graph.
            print(f"Error in fetch_context: {e}")
            state['context'] = f"Error fetching data: {e}"
            state['summarized_context'] = ""
            state['final_answer'] = "An error occurred while fetching data."
            # Optionally, re-raise the exception if you want the graph to stop
            # raise e

        return state

    async def summarize_context(self, state: State) -> State:
        """
        Produce state['summarized_context'] from state['context'] for the text path.

        When a visualization request is present the summary is cleared and the
        model is not called; when there is no context nothing is changed.
        """
        print("Executing summarize_context")

        if state.get('visualization_request'):
            print("Skipping summarization as visualization is requested.")
            state['summarized_context'] = "" # Ensure empty if skipped
            return state

        fetched_context = state.get('context')
        if not fetched_context:
            return state

        summary_prompt = f"""Summarize the following context for a chatbot response:
             {fetched_context}
             """
        try:
            reply = await self.summarize_model.ainvoke(summary_prompt)
            state['summarized_context'] = reply.content
        except Exception as e:
            print(f"Error in summarize_context: {e}")
            state['summarized_context'] = "Error summarizing context."
            state['final_answer'] = "An error occurred while summarizing the information."
            # Optionally, re-raise the exception
            # raise e

        return state

    async def generate_answer(self, state: State) -> State:
        """
        Write the final text answer into state['final_answer'].

        The answer model is called only when a summarized context exists and no
        visualization output is present. Otherwise an answer already set by an
        earlier node is kept, and a generic fallback is used if there is none.
        """
        print("Executing generate_answer")
        question = state['question']
        summarized_context = state.get('summarized_context')
        history = state.get('history', '')
        viz_output = state.get('visualization_output')

        if summarized_context and not viz_output:
            answer_prompt = f"""Based on the following summarized context and chat history, answer the user's question:

             Question: {question}
             Chat History: {history}
             Summarized Context: {summarized_context}

             Provide a concise and helpful answer.
             """
            try:
                reply = await self.answer_model.ainvoke(answer_prompt)
                state['final_answer'] = reply.content
            except Exception as e:
                print(f"Error in generate_answer: {e}")
                state['final_answer'] = "An error occurred while generating the answer."
                # Optionally, re-raise the exception
                # raise e
            return state

        existing_answer = state.get('final_answer')
        if viz_output and existing_answer and "Here is the visualization" in existing_answer:
            # Visualization succeeded: keep the answer set by generate_visualization_node
            print("Skipping generate_answer as visualization output exists.")
        elif not existing_answer:
            # Neither a summary nor a usable earlier answer; do not overwrite a previous error message
            state['final_answer'] = "Could not generate a response based on the available information."

        return state

## Usage in Jupyter Notebook:

In [None]:
# --- How to use in your Jupyter Notebook ---

# 1. Initialize the GraphBuilder
graph_builder = GraphBuilder()

# 2. Create the workflow
workflow = graph_builder.create_workflow()

# 3. Compile the graph
app = workflow.compile()

# 4. Run the graph with an initial state
# You'll need to provide the initial state dictionary
# Example state for a text query:
initial_state_text = {
    'messages': [],
    'history': '',
    'question': 'Tell me about the organizations in the data.',
    'should_split': False,
    'split_questions': [],
    'needs_history': False,
    'needs_context': False,
    'context': '',
    'summarized_context': '',
    'final_answer': '',
    'visualization_request': None,
    'visualization_data': None,
    'visualization_output': None,
    'visualization_error': None
}

# Example state for a visualization query:
initial_state_viz = {
    'messages': [],
    'history': '',
    'question': 'Show me a bar chart of organizations by category.',
    'should_split': False,
    'split_questions': [],
    'needs_history': False,
    'needs_context': False,
    'context': '',
    'summarized_context': '',
    'final_answer': '',
    'visualization_request': None, # This will be set by decide_next_step
    'visualization_data': None,     # This will be populated by fetch_context
    'visualization_output': None,   # This will be populated by generate_visualization
    'visualization_error': None
}

# To run the graph (use await if in an async environment like certain parts of Jupyter)
# For a simple test in a synchronous notebook cell, you might need to use asyncio.run()
import asyncio

async def run_graph(state):
    # The graph runs until it reaches END
    result = await app.ainvoke(state)
    return result

# Example usage (run one at a time)
# print("Running text query...")
# text_result = await run_graph(initial_state_text)
# print("\n--- Text Result ---")
# print(text_result)

print("\nRunning visualization query...")
viz_result = await run_graph(initial_state_viz)
print("\n--- Visualization Result ---")
print(viz_result)

# To view the graph structure (requires graphviz: pip install graphviz)
# import graphviz
# graphviz.Source(app.get_graph().draw()).render('graph', view=True)

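Recorded cell output: `ModuleNotFoundError: No module named 'langchain_core'` — the `langchain_core` package was not installed when the notebook was run. Install the dependencies (e.g. `pip install langchain-core langchain langgraph`) before executing the cells.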