Skip to content
Merged
28 changes: 14 additions & 14 deletions atomic_agents/agents/base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from atomic_agents.lib.components.system_prompt_generator import SystemPromptContextProviderBase, SystemPromptGenerator


class BaseAgentIO(BaseModel):
class BaseIOSchema(BaseModel):
"""
Base class for input and output schemas for chat agents.
"""
Expand All @@ -21,7 +21,7 @@ def __rich__(self):
return JSON(json_str)


class BaseAgentInputSchema(BaseAgentIO):
class BaseAgentInputSchema(BaseIOSchema):
chat_message: str = Field(
...,
description="The chat message sent by the user to the assistant.",
Expand All @@ -36,7 +36,7 @@ class Config:
}


class BaseAgentOutputSchema(BaseAgentIO):
class BaseAgentOutputSchema(BaseIOSchema):
chat_message: str = Field(
...,
description=(
Expand All @@ -56,7 +56,7 @@ class Config:

class BaseAgentConfig(BaseModel):
client: instructor.client.Instructor = Field(..., description="Client for interacting with the language model.")
model: str = Field("gpt-3.5-turbo", description="The model to use for generating responses.")
model: str = Field("gpt-4o-mini", description="The model to use for generating responses.")
memory: Optional[AgentMemory] = Field(None, description="Memory component for storing chat history.")
system_prompt_generator: Optional[SystemPromptGenerator] = Field(
None, description="Component for generating system prompts."
Expand All @@ -76,8 +76,8 @@ class BaseAgent:
generating system prompts, and obtaining responses from a language model.

Attributes:
input_schema (Type[BaseAgentIO]): Schema for the input data.
output_schema (Type[BaseAgentIO]): Schema for the output data.
input_schema (Type[BaseIOSchema]): Schema for the input data.
output_schema (Type[BaseIOSchema]): Schema for the output data.
client: Client for interacting with the language model.
model (str): The model to use for generating responses.
memory (AgentMemory): Memory component for storing chat history.
Expand Down Expand Up @@ -133,15 +133,15 @@ def get_response(self, response_model=None) -> Type[BaseModel]:
response = self.client.chat.completions.create(model=self.model, messages=messages, response_model=response_model)
return response

def run(self, user_input: Optional[Type[BaseAgentIO]] = None) -> Type[BaseAgentIO]:
def run(self, user_input: Optional[Type[BaseIOSchema]] = None) -> Type[BaseIOSchema]:
"""
Runs the chat agent with the given user input.

Args:
user_input (Optional[Type[BaseAgentIO]]): The input from the user. If not provided, skips adding to memory.
user_input (Optional[Type[BaseIOSchema]]): The input from the user. If not provided, skips adding to memory.

Returns:
Type[BaseAgentIO]: The response from the chat agent.
Type[BaseIOSchema]: The response from the chat agent.
"""
if user_input:
self.current_user_input = user_input
Expand All @@ -165,9 +165,9 @@ def get_context_provider(self, provider_name: str) -> Type[SystemPromptContextPr
Raises:
KeyError: If the context provider is not found.
"""
if provider_name not in self.system_prompt_generator.system_prompt_info.context_providers:
if provider_name not in self.system_prompt_generator.context_providers:
raise KeyError(f"Context provider '{provider_name}' not found.")
return self.system_prompt_generator.system_prompt_info.context_providers[provider_name]
return self.system_prompt_generator.context_providers[provider_name]

def register_context_provider(self, provider_name: str, provider: SystemPromptContextProviderBase):
"""
Expand All @@ -177,7 +177,7 @@ def register_context_provider(self, provider_name: str, provider: SystemPromptCo
provider_name (str): The name of the context provider.
provider (SystemPromptContextProviderBase): The context provider instance.
"""
self.system_prompt_generator.system_prompt_info.context_providers[provider_name] = provider
self.system_prompt_generator.context_providers[provider_name] = provider

def unregister_context_provider(self, provider_name: str):
"""
Expand All @@ -186,7 +186,7 @@ def unregister_context_provider(self, provider_name: str):
Args:
provider_name (str): The name of the context provider to remove.
"""
if provider_name in self.system_prompt_generator.system_prompt_info.context_providers:
del self.system_prompt_generator.system_prompt_info.context_providers[provider_name]
if provider_name in self.system_prompt_generator.context_providers:
del self.system_prompt_generator.context_providers[provider_name]
else:
raise KeyError(f"Context provider '{provider_name}' not found.")
44 changes: 15 additions & 29 deletions atomic_agents/agents/tool_interface_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

from pydantic import Field, create_model

from atomic_agents.agents.base_agent import BaseAgentIO, BaseAgent, BaseAgentConfig
from atomic_agents.lib.components.system_prompt_generator import SystemPromptGenerator, SystemPromptInfo
from atomic_agents.lib.tools.base import BaseTool
from atomic_agents.agents.base_agent import BaseIOSchema, BaseAgent, BaseAgentConfig
from atomic_agents.lib.components.system_prompt_generator import SystemPromptGenerator
from atomic_agents.lib.tools.base_tool import BaseTool
from atomic_agents.lib.utils.format_tool_message import format_tool_message


Expand All @@ -13,18 +13,6 @@ class ToolInterfaceAgentConfig(BaseAgentConfig):
return_raw_output: bool = False


class ToolInputModel(BaseAgentIO):
tool_input: str = Field(..., description="Tool input. Presented as a single question or instruction")

class Config:
title = "Default Tool"
description = "Default tool description"
json_schema_extra = {
"title": "Default Tool",
"description": "Default tool description",
}


class ToolInterfaceAgent(BaseAgent):
"""
A specialized chat agent designed to interact with a specific tool.
Expand Down Expand Up @@ -60,7 +48,7 @@ def __init__(self, config: ToolInterfaceAgentConfig):
alias=f"tool_input_{self.tool_instance.tool_name}",
),
),
__base__=ToolInputModel,
__base__=BaseIOSchema,
)

# Manually set the configuration attributes
Expand All @@ -86,19 +74,17 @@ def __init__(self, config: ToolInterfaceAgentConfig):
output_instructions.append("Return the raw output of the tool.")

self.system_prompt_generator = SystemPromptGenerator(
system_prompt_info=SystemPromptInfo(
background=[
f"This AI agent is designed to interact with the {self.tool_instance.tool_name} tool.",
f"Tool description: {self.tool_instance.tool_description}",
],
steps=[
"Get the user input.",
"Convert the input to the proper parameters to call the tool.",
"Call the tool with the parameters.",
"Respond to the user",
],
output_instructions=output_instructions,
)
background=[
f"This AI agent is designed to interact with the {self.tool_instance.tool_name} tool.",
f"Tool description: {self.tool_instance.tool_description}",
],
steps=[
"Get the user input.",
"Convert the input to the proper parameters to call the tool.",
"Call the tool with the parameters.",
"Respond to the user",
],
output_instructions=output_instructions,
)

def get_response(self, response_model=None):
Expand Down
38 changes: 18 additions & 20 deletions atomic_agents/lib/components/system_prompt_generator.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Dict, List, Optional


Expand All @@ -15,21 +14,20 @@ def __repr__(self) -> str:
return self.get_info()


@dataclass
class SystemPromptInfo:
background: List[str]
steps: List[str] = field(default_factory=list)
output_instructions: List[str] = field(default_factory=list)
context_providers: Dict[str, SystemPromptContextProviderBase] = field(default_factory=dict)


class SystemPromptGenerator:
def __init__(self, system_prompt_info: Optional[SystemPromptInfo] = None):
self.system_prompt_info = system_prompt_info or SystemPromptInfo(
background=["This is a conversation with a helpful and friendly AI assistant."]
)

self.system_prompt_info.output_instructions.extend(
def __init__(
self,
background: Optional[List[str]] = None,
steps: Optional[List[str]] = None,
output_instructions: Optional[List[str]] = None,
context_providers: Optional[Dict[str, SystemPromptContextProviderBase]] = None,
):
self.background = background or ["This is a conversation with a helpful and friendly AI assistant."]
self.steps = steps or []
self.output_instructions = output_instructions or []
self.context_providers = context_providers or {}

self.output_instructions.extend(
[
"Always respond using the proper JSON schema.",
"Always use the available additional information and context to enhance the response.",
Expand All @@ -38,9 +36,9 @@ def __init__(self, system_prompt_info: Optional[SystemPromptInfo] = None):

def generate_prompt(self) -> str:
sections = [
("IDENTITY and PURPOSE", self.system_prompt_info.background),
("INTERNAL ASSISTANT STEPS", self.system_prompt_info.steps),
("OUTPUT INSTRUCTIONS", self.system_prompt_info.output_instructions),
("IDENTITY and PURPOSE", self.background),
("INTERNAL ASSISTANT STEPS", self.steps),
("OUTPUT INSTRUCTIONS", self.output_instructions),
]

prompt_parts = []
Expand All @@ -51,9 +49,9 @@ def generate_prompt(self) -> str:
prompt_parts.extend(f"- {item}" for item in content)
prompt_parts.append("")

if self.system_prompt_info.context_providers:
if self.context_providers:
prompt_parts.append("# EXTRA INFORMATION AND CONTEXT")
for provider in self.system_prompt_info.context_providers.values():
for provider in self.context_providers.values():
info = provider.get_info()
if info:
prompt_parts.append(f"## {provider.title}")
Expand Down
4 changes: 2 additions & 2 deletions atomic_agents/lib/models/web_document.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from pydantic import BaseModel, Field

from atomic_agents.agents.base_agent import BaseAgentIO
from atomic_agents.agents.base_agent import BaseIOSchema


class WebDocumentMetadata(BaseModel):
Expand All @@ -11,6 +11,6 @@ class WebDocumentMetadata(BaseModel):
author: str = Field(default="")


class WebDocument(BaseAgentIO):
class WebDocument(BaseIOSchema):
content: str
metadata: WebDocumentMetadata = Field(default_factory=lambda: WebDocumentMetadata(url=""))
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from pydantic import BaseModel

from atomic_agents.agents.base_agent import BaseAgentIO
from atomic_agents.agents.base_agent import BaseIOSchema


class BaseToolConfig(BaseModel):
Expand All @@ -15,15 +15,15 @@ class BaseTool:
Base class for all tools in the Atomic Agents framework.

Attributes:
input_schema (Type[BaseAgentIO]): The schema for the input data.
output_schema (Type[BaseAgentIO]): The schema for the output data.
input_schema (Type[BaseIOSchema]): The schema for the input data.
output_schema (Type[BaseIOSchema]): The schema for the output data.
tool_name (str): The name of the tool, derived from the input schema's title.
tool_description (str):
The description of the tool, derived from the input schema's description or overridden by the user.
"""

input_schema: Type[BaseAgentIO]
output_schema: Type[BaseAgentIO]
input_schema: Type[BaseIOSchema]
output_schema: Type[BaseIOSchema]

def __init__(self, config: BaseToolConfig = BaseToolConfig()):
"""
Expand All @@ -35,15 +35,15 @@ def __init__(self, config: BaseToolConfig = BaseToolConfig()):
self.tool_name = config.title or self.input_schema.Config.title
self.tool_description = config.description or self.input_schema.Config.description

def run(self, params: Type[BaseAgentIO]) -> BaseAgentIO:
def run(self, params: Type[BaseIOSchema]) -> BaseIOSchema:
"""
Runs the tool with the given parameters. This method should be implemented by subclasses.

Args:
params (BaseAgentIO): The input parameters for the tool, adhering to the input schema.
params (BaseIOSchema): The input parameters for the tool, adhering to the input schema.

Returns:
BaseAgentIO: The output of the tool, adhering to the output schema.
BaseIOSchema: The output of the tool, adhering to the output schema.

Raises:
NotImplementedError: If the method is not implemented by a subclass.
Expand Down
23 changes: 10 additions & 13 deletions atomic_agents/lib/tools/calculator_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,14 @@
from rich.console import Console
from sympy import sympify

from atomic_agents.agents.base_agent import BaseAgentIO
from atomic_agents.lib.tools.base import BaseTool, BaseToolConfig
from atomic_agents.agents.base_agent import BaseIOSchema
from atomic_agents.lib.tools.base_tool import BaseTool, BaseToolConfig


################
# INPUT SCHEMA #
################


class CalculatorToolSchema(BaseAgentIO):
class CalculatorToolInputSchema(BaseIOSchema):
expression: str = Field(..., description="Mathematical expression to evaluate. For example, '2 + 2'.")

class Config:
Expand All @@ -27,9 +26,7 @@ class Config:
####################
# OUTPUT SCHEMA(S) #
####################


class CalculatorToolOutputSchema(BaseAgentIO):
class CalculatorToolOutputSchema(BaseIOSchema):
result: str = Field(..., description="Result of the calculation.")


Expand All @@ -45,11 +42,11 @@ class CalculatorTool(BaseTool):
Tool for performing calculations based on the provided mathematical expression.

Attributes:
input_schema (CalculatorToolSchema): The schema for the input data.
input_schema (CalculatorToolInputSchema): The schema for the input data.
output_schema (CalculatorToolOutputSchema): The schema for the output data.
"""

input_schema = CalculatorToolSchema
input_schema = CalculatorToolInputSchema
output_schema = CalculatorToolOutputSchema

def __init__(self, config: CalculatorToolConfig = CalculatorToolConfig()):
Expand All @@ -61,12 +58,12 @@ def __init__(self, config: CalculatorToolConfig = CalculatorToolConfig()):
"""
super().__init__(config)

def run(self, params: CalculatorToolSchema) -> CalculatorToolOutputSchema:
def run(self, params: CalculatorToolInputSchema) -> CalculatorToolOutputSchema:
"""
Runs the CalculatorTool with the given parameters.

Args:
params (CalculatorToolSchema): The input parameters for the tool, adhering to the input schema.
params (CalculatorToolInputSchema): The input parameters for the tool, adhering to the input schema.

Returns:
CalculatorToolOutputSchema: The output of the tool, adhering to the output schema.
Expand All @@ -83,4 +80,4 @@ def run(self, params: CalculatorToolSchema) -> CalculatorToolOutputSchema:
#################
if __name__ == "__main__":
rich_console = Console()
rich_console.print(CalculatorTool().run(CalculatorToolSchema(expression="2 + 2")))
rich_console.print(CalculatorTool().run(CalculatorToolInputSchema(expression="2 + 2")))
Loading