In [1]:
from langgraph.graph import StateGraph,START,END,MessagesState
from typing import List,Literal,TypedDict,Optional,Dict
from pydantic_ai.messages import ModelMessagesTypeAdapter,ModelMessage
from pydantic_ai import Agent,RunContext,Tool
from langgraph.types import interrupt, Command,Literal
from dotenv import load_dotenv
from pydantic import BaseModel,Field
from dataclasses import dataclass
from langgraph.checkpoint.memory import MemorySaver
from langchain_experimental.utilities import PythonREPL
import nest_asyncio
import asyncio
import os
load_dotenv()

True

In [2]:
nest_asyncio.apply()

In [3]:
memory = MemorySaver()
config = {'configurable':{'thread_id':'1'}}

In [4]:
@dataclass
class unified_file_details:
    file_name: str
    file_path: str
    save_path: str = Field(default="D:/PROJECT_MULTIMODEL/Pydantic_ai/eda_pydantic/visualizations", 
                          description="Path where visualizations will be saved")

In [5]:
class Response(BaseModel):
    csv_file_name : str = Field(description="Name of the CSV file to be created")
    csv_file_path : str = Field(description="Path of the CSV file to be created")
    goal: List[str] = Field(description="Goals to be done by the agent defined by the user")
    next_question : str = Field(description="Next question to ask based on remaining fields")
    got_all : bool = Field(description="True if all fields are filled else False")

In [6]:
class Supervision(BaseModel):
    go_to: Literal["unified_agent", "__end__"] = Field(
        description="The next agent which needs to be called"
    )
    command: str = Field(description="Prompt or order for the agent")
    finished: bool = Field(description="True if all goals have been completed")
    finished_tasks: List[str] = Field(description="List of tasks which have been completed")
    final_answer_to_user: str = Field(description="Reply to be given to the user")

In [7]:

info_gathering_agent = Agent(
    'google-gla:gemini-2.0-flash',
    result_type=Response,
    system_prompt="""You are part of EDA (Exploratory Data Analysis) team. 
    Your task is to gather information about the data name and the goal of the user.
    Ask the user to provide the name of the csv file and the path where it is located.
    Then ask the user to provide the goals they want to achieve with the data.
    Once all fields are filled, you will return the CSV file name and path, goals of user, along with a message indicating that all fields have been filled.
    Always add finding data description and the data type as the first goal if any data is provided.
    Your goals are prompts to a supervisor agent.Which will use other agent to look into the data.
    """,
)

In [8]:
@dataclass
class task_monitoring:
    tasks_done : List[str] = Field(description="List of tasks which have been completed")
    response : str = Field(description="Response from the agent which did the task")
    execution_output : str = Field(description="Output of the agent after executing the task")
    code : str = Field(description="Code generated by the agent to do the task")
    saved_files : List[str] = Field(description="List of files which have been saved by the agent")

In [9]:
supervisor_agent = Agent(
    'google-gla:gemini-2.0-flash',
    result_type=Supervision,
    system_prompt="""You are a supervisor agent. You can give commands to the unified data agent that can:
    1. Analyze and preprocess data
    2. Create visualizations using matplotlib or seaborn
    
    You will give specific commands based on the user's goals.
    The unified agent has a Python sandbox to execute all tasks.
    You will keep track of completed tasks and remaining ones.
    For visualization tasks, suggest appropriate charts based on data types.
    """,
    deps_type=task_monitoring,
)

In [10]:
class UnifiedTask(BaseModel):
    tasks_done: List[str] = Field(description="A brief description of the tasks done by the agent")
    finished_task: bool = Field(description="True if all tasks have been completed")
    response: str = Field(description="Response of the agent after executing the task")
    code_execution_output: str = Field(description="Output of the agent after executing the task")
    code: str = Field(description="Code executed by the agent")
    saved_files: List[str] = Field(description="List of visualization files saved")

In [11]:
@dataclass
class file_details:
    file_name:str
    file_path:str

In [12]:
unified_agent = Agent(
    'google-gla:gemini-2.0-flash',
    result_type=UnifiedTask,
    tools=[PythonREPL().run],
    system_prompt="""You are a data analysis agent capable of both preprocessing and visualization.

    For preprocessing tasks:
    - Use pandas and numpy to analyze, clean, and transform data
    - Generate descriptive statistics and insights about the data
    
    For visualization tasks:
    - Create appropriate visualizations using matplotlib or seaborn based on data types
    - IMPORTANT: DO NOT use plt.show() as plots should not be displayed but saved to files
    - Always use plt.savefig() to save plots to the specified save path and then plt.close()
    - Name files descriptively (e.g., 'age_distribution.png', 'gender_counts.png')
    - Always return the list of saved files in your response
    
    IMPORTANT: Never use input() or any other user-interactive functions in your code.
    All data should come from the CSV file provided, and your code should run without requiring any user interaction.
    
    Execute all tasks in the Python sandbox and provide clear explanations of your findings.
    """,
    deps_type=unified_file_details,
)

In [13]:
class State(TypedDict):
    user_input: str 
    messages: List = Field(default=[])
    details: Response
    supervision: Supervision
    unified_task: UnifiedTask

In [14]:
def where_to_go(state:State)->Command[Literal['get_next_user_msg','supervisor']]:
    if state['details'].got_all:
        return 'supervisor'
    else:
        return 'get_next_user_msg'

In [15]:
async def get_user_info(state: State):

    message_history = state['messages']
    
    response = await info_gathering_agent.run(
        state['user_input'],
        message_history=message_history
    )
    state['details'] = response.data
    state['messages'] = message_history + response.new_messages()
    
    return state

In [16]:
async def get_next_msg(state:State):
    value = interrupt({})
    state['user_input']=value  
    return state

In [17]:
async def supervisor(state:State):
    if state['unified_task'].tasks_done == []:
        message_history = state['messages']
        details = state['details']
        prompt = "".join([str(i+1)+". "+ j + "\n\n" for i,j in enumerate(details.goal)])
        response = await supervisor_agent.run(prompt,message_history=message_history)
        state['supervision'] = response.data
        state['messages'] = message_history + response.new_messages()
    
    else:
        message_history = state['messages']
        details = state['details']
        prompt = "".join([str(i+1)+". "+ j + "\n\n" for i,j in enumerate(details.goal)])

        @supervisor_agent.system_prompt
        def get_system_prompt(ctx:RunContext[task_monitoring])->str:
            tasks_done = ctx.deps.tasks_done
            response_from_agent = ctx.deps.response
            execution_output = ctx.deps.execution_output
            code = ctx.deps.code
            saved_files = ctx.deps.saved_files
            return f"""Here are the tasks done and the response from the agent {tasks_done} and {response_from_agent}.
            The output from the latest execution {execution_output}, and the code generated by the agent {code}.
            You can use this information to give next commands or end.
            If any files are saved, please mention them in the response.Here are the files saved {saved_files}.
            Always remeber to keep track of outputs and by the end give the answers to the user.
            """

        tasks_montoring = task_monitoring(
            tasks_done = state['unified_task'].tasks_done,
            response = state['unified_task'].response,
            execution_output = state['unified_task'].code_execution_output,
            code = state['unified_task'].code,
            saved_files = state['unified_task'].saved_files,
        )

        

        response = await supervisor_agent.run(prompt,message_history=message_history,deps=tasks_montoring)
        state['supervision'] = response.data
        state['messages'] = message_history + response.new_messages()

    if state['unified_task'].tasks_done != []:
        print(state['unified_task'].tasks_done)
        print(state['unified_task'].response)
        print(state['unified_task'].code_execution_output)
        print(state['unified_task'].code)
        
        print("\n" + "="*80)
        print(f"🔍 TASKS COMPLETED BY AGENT ".ljust(79, "="))
        for i, task in enumerate(state['unified_task'].tasks_done, 1):
            print(f"{i}. {task}")
        
        print("\n" + "-"*80)
        print(f"💬 AGENT RESPONSE ".ljust(79, "-"))
        print(f"{state['unified_task'].response}")
        
        print("\n" + "-"*80)
        print(f"📊 CODE EXECUTION OUTPUT ".ljust(79, "-"))
        print(f"{state['unified_task'].code_execution_output}")
        
        print("\n" + "-"*80)
        print(f"💻 GENERATED CODE ".ljust(79, "-"))
        print(f"{state['unified_task'].code}")
        
        if state['unified_task'].saved_files:
            print("\n" + "-"*80)
            print(f"📁 SAVED FILES ".ljust(79, "-"))
            for i, file in enumerate(state['unified_task'].saved_files, 1):
                print(f"{i}. {file}")
        
        print("="*80)

    return state

In [18]:
async def unified_agent_node(state: State):
    @unified_agent.system_prompt
    def get_system_prompt(ctx: RunContext[unified_file_details]) -> str:
        csv_name = ctx.deps.file_name
        csv_path = ctx.deps.file_path
        save_path = ctx.deps.save_path
        return f'Prompt: Processing file {csv_name} at path {csv_path}. For any visualizations, save to {save_path}. DO NOT use plt.show(), only use plt.savefig() followed by plt.close().\n\n'
    
    file_info = unified_file_details(
        file_name=state['details'].csv_file_name,
        file_path=state['details'].csv_file_path,
        save_path="D:/PROJECT_MULTIMODEL/Pydantic_ai/eda_pydantic/visualizations"
    )
    
    prompt = state['supervision'].command
    
    response = await unified_agent.run(
        prompt,
        deps=file_info,
    )
    
    state['unified_task'] = response.data
    state['messages'] = state['messages'] + response.new_messages()
    

    if state['unified_task'].tasks_done != []:
        print(state['unified_task'].tasks_done)
        print(state['unified_task'].response)
        print(state['unified_task'].code_execution_output)
        print(state['unified_task'].code)
        
        # Replace with these improved print statements
        print("\n" + "="*80)
        print(f"🔍 TASKS COMPLETED BY AGENT ".ljust(79, "="))
        for i, task in enumerate(state['unified_task'].tasks_done, 1):
            print(f"{i}. {task}")
        
        print("\n" + "-"*80)
        print(f"💬 AGENT RESPONSE ".ljust(79, "-"))
        print(f"{state['unified_task'].response}")
        
        print("\n" + "-"*80)
        print(f"📊 CODE EXECUTION OUTPUT ".ljust(79, "-"))
        print(f"{state['unified_task'].code_execution_output}")
        
        print("\n" + "-"*80)
        print(f"💻 GENERATED CODE ".ljust(79, "-"))
        print(f"{state['unified_task'].code}")
        
        if state['unified_task'].saved_files:
            print("\n" + "-"*80)
            print(f"📁 SAVED FILES ".ljust(79, "-"))
            for i, file in enumerate(state['unified_task'].saved_files, 1):
                print(f"{i}. {file}")
        
        print("="*80)

    return state

In [19]:
def agent_selection(state: State) -> Command[Literal['unified_agent', '__end__']]:
    if state['supervision'].go_to == 'unified_agent':
        return 'unified_agent'
    else:
        return '__end__'

In [20]:
graph_builder = StateGraph(State)

In [21]:
graph_builder = StateGraph(State)

graph_builder.add_node("info_gatherer", get_user_info)
graph_builder.add_node("get_next_user_msg", get_next_msg)
graph_builder.add_node("supervisor", supervisor)
graph_builder.add_node("unified_agent", unified_agent_node)

graph_builder.add_conditional_edges("info_gatherer", where_to_go, {'get_next_user_msg':'get_next_user_msg', 'supervisor':'supervisor'})
graph_builder.add_conditional_edges("supervisor", agent_selection, {
    'unified_agent': 'unified_agent',
    '__end__': END
})

graph_builder.add_edge(START, "info_gatherer")
graph_builder.add_edge("get_next_user_msg", "info_gatherer")
graph_builder.add_edge("unified_agent", "supervisor")

# graph_builder.compile()

<langgraph.graph.state.StateGraph at 0x252a81fe690>

In [22]:
graph = graph_builder.compile(checkpointer=memory)

In [24]:

initial_details = Response(
    csv_file_name="",
    csv_file_path="",
    goal=[],
    next_question="",
    got_all=False
)
initial_supervision = Supervision(go_to="__end__", command="", finished=False, finished_tasks=[], final_answer_to_user="")
initial_unified_task = UnifiedTask(
    tasks_done=[],
    finished_task=False,
    response="",
    code_execution_output="",
    code="",
    saved_files=[]
)



In [24]:
response = await graph.ainvoke({
    'user_input': "Tell me what are the column names and columns numbers are present in users_data.csv and path is users_data.csv",
    'messages': [],
    'details': initial_details,
    'supervision': initial_supervision,
    'unified_task': initial_unified_task,
}, config=config)



Python REPL can execute arbitrary code. Use with caution.


['Analyzed the CSV file to identify column names, their numbers, and data types.']
The column names and their corresponding column numbers are listed, along with a description of each column's data type. The data types were inferred directly from the pandas DataFrame.
Column Names with Column Numbers:
1. id
2. current_age
3. retirement_age
4. birth_year
5. birth_month
6. gender
7. address
8. latitude
9. longitude
10. per_capita_income
11. yearly_income
12. total_debt
13. credit_score
14. num_credit_cards

Data Description and Data Type of Each Column:
id: int64
current_age: int64
retirement_age: int64
birth_year: int64
birth_month: int64
gender: object
address: object
latitude: float64
longitude: float64
per_capita_income: object
yearly_income: object
total_debt: object
credit_score: int64
num_credit_cards: int64

import pandas as pd

df = pd.read_csv("users_data.csv")

column_names = df.columns.tolist()

column_info = []
for i, column in enumerate(column_names):
    data_type = str(df

In [25]:
response

{'user_input': 'Tell me what are the column names and columns numbers are present in users_data.csv and path is users_data.csv',
 'messages': [ModelRequest(parts=[SystemPromptPart(content='You are part of EDA (Exploratory Data Analysis) team. \n    Your task is to gather information about the data name and the goal of the user.\n    Ask the user to provide the name of the csv file and the path where it is located.\n    Then ask the user to provide the goals they want to achieve with the data.\n    Once all fields are filled, you will return the CSV file name and path, goals of user, along with a message indicating that all fields have been filled.\n    Always add finding data description and the data type as the first goal if any data is provided.\n    Your goals are prompts to a supervisor agent.Which will use other agent to look into the data.\n    ', timestamp=datetime.datetime(2025, 4, 14, 9, 34, 38, 199664, tzinfo=datetime.timezone.utc), dynamic_ref=None, part_kind='system-prompt'

In [27]:
response2 = await graph.ainvoke({'user_input':"plot me a bar chart showing the birth month and the count of users in each month",},config=config)

['Analyzed the CSV file to identify column names, their numbers, and data types.']
The column names and their corresponding column numbers are listed, along with a description of each column's data type. The data types were inferred directly from the pandas DataFrame.
Column Names with Column Numbers:
1. id
2. current_age
3. retirement_age
4. birth_year
5. birth_month
6. gender
7. address
8. latitude
9. longitude
10. per_capita_income
11. yearly_income
12. total_debt
13. credit_score
14. num_credit_cards

Data Description and Data Type of Each Column:
id: int64
current_age: int64
retirement_age: int64
birth_year: int64
birth_month: int64
gender: object
address: object
latitude: float64
longitude: float64
per_capita_income: object
yearly_income: object
total_debt: object
credit_score: int64
num_credit_cards: int64

import pandas as pd

df = pd.read_csv("users_data.csv")

column_names = df.columns.tolist()

column_info = []
for i, column in enumerate(column_names):
    data_type = str(df

In [25]:
response = await graph.ainvoke({
    'user_input': "In users_data.csv do preprocessing if required . path is users_data.csv",
    'messages': [],
    'details': initial_details,
    'supervision': initial_supervision,
    'unified_task': initial_unified_task,
}, config=config)



Python REPL can execute arbitrary code. Use with caution.


['Analyzed data description and types', 'Preprocessed data', 'Generated visualizations']
I have analyzed the data, preprocessed it by cleaning the currency columns and converting the gender column to numerical values. I have also generated five visualizations: the distribution of current age, the distribution of yearly income, the count of gender, a scatter plot of user locations, and the distribution of credit score. All visualizations are saved to the specified directory.
['D:/PROJECT_MULTIMODEL/Pydantic_ai/eda_pydantic/visualizations/current_age_distribution.png', 'D:/PROJECT_MULTIMODEL/Pydantic_ai/eda_pydantic/visualizations/yearly_income_distribution.png', 'D:/PROJECT_MULTIMODEL/Pydantic_ai/eda_pydantic/visualizations/gender_counts.png', 'D:/PROJECT_MULTIMODEL/Pydantic_ai/eda_pydantic/visualizations/user_locations.png', 'D:/PROJECT_MULTIMODEL/Pydantic_ai/eda_pydantic/visualizations/credit_score_distribution.png']
import pandas as pd
import matplotlib.pyplot as plt
import seaborn a

In [26]:
response2 = await graph.ainvoke({
    'user_input': "drop address column and label encode gender column",
}, config=config)



['Analyzed data description and types', 'Preprocessed data', 'Generated visualizations']
I have analyzed the data, preprocessed it by cleaning the currency columns and converting the gender column to numerical values. I have also generated five visualizations: the distribution of current age, the distribution of yearly income, the count of gender, a scatter plot of user locations, and the distribution of credit score. All visualizations are saved to the specified directory.
['D:/PROJECT_MULTIMODEL/Pydantic_ai/eda_pydantic/visualizations/current_age_distribution.png', 'D:/PROJECT_MULTIMODEL/Pydantic_ai/eda_pydantic/visualizations/yearly_income_distribution.png', 'D:/PROJECT_MULTIMODEL/Pydantic_ai/eda_pydantic/visualizations/gender_counts.png', 'D:/PROJECT_MULTIMODEL/Pydantic_ai/eda_pydantic/visualizations/user_locations.png', 'D:/PROJECT_MULTIMODEL/Pydantic_ai/eda_pydantic/visualizations/credit_score_distribution.png']
import pandas as pd
import matplotlib.pyplot as plt
import seaborn a

<Figure size 800x600 with 0 Axes>