atomic-agents/atomic_agents/agents/base_agent.py (221 additions, 0 deletions)
@@ -1,4 +1,6 @@
import instructor
import asyncio
import warnings  # needed for the DeprecationWarning in stream_response_async
from pydantic import BaseModel, Field
from typing import Optional, Type
from atomic_agents.lib.components.agent_memory import AgentMemory
@@ -78,6 +79,27 @@ class BaseAgentConfig(BaseModel):
)


class BaseAgentConfigAsync(BaseModel):
    # TODO: Add a test for this
client: instructor.client.AsyncInstructor = Field(..., description="Async client for interacting with the language model.")
model: str = Field("gpt-4o-mini", description="The model to use for generating responses.")
memory: Optional[AgentMemory] = Field(None, description="Memory component for storing chat history.")
system_prompt_generator: Optional[SystemPromptGenerator] = Field(
None, description="Component for generating system prompts."
)
input_schema: Optional[Type[BaseModel]] = Field(None, description="The schema for the input data.")
output_schema: Optional[Type[BaseModel]] = Field(None, description="The schema for the output data.")
model_config = {"arbitrary_types_allowed": True}
temperature: Optional[float] = Field(
0,
description="Temperature for response generation, typically ranging from 0 to 1.",
)
max_tokens: Optional[int] = Field(
None,
description="Maximum number of token allowed in the response generation.",
)
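
# A minimal, illustrative sketch of wiring up BaseAgentConfigAsync. It assumes
# the openai package is installed and OPENAI_API_KEY is set in the environment;
# the function name is hypothetical and used only for demonstration.
def _example_async_config() -> "BaseAgentConfigAsync":
    import openai

    # instructor.from_openai wraps the async SDK client in an AsyncInstructor,
    # which is what the `client` field above expects.
    async_client = instructor.from_openai(openai.AsyncOpenAI())
    return BaseAgentConfigAsync(client=async_client, model="gpt-4o-mini", temperature=0.5)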


class BaseAgent:
"""
Base class for chat agents.
@@ -269,6 +291,205 @@ def unregister_context_provider(self, provider_name: str):
raise KeyError(f"Context provider '{provider_name}' not found.")


class BaseAgentAsync:
    # TODO: Add a test for this
    """
    Asynchronous base class for chat agents.

This class provides the core functionality for handling chat interactions, including managing memory,
generating system prompts, and obtaining responses from a language model.

Attributes:
input_schema (Type[BaseIOSchema]): Schema for the input data.
output_schema (Type[BaseIOSchema]): Schema for the output data.
        client: Async client for interacting with the language model.
model (str): The model to use for generating responses.
memory (AgentMemory): Memory component for storing chat history.
system_prompt_generator (SystemPromptGenerator): Component for generating system prompts.
initial_memory (AgentMemory): Initial state of the memory.
        max_tokens (int): Maximum number of tokens allowed in the response.
"""

input_schema = BaseAgentInputSchema
output_schema = BaseAgentOutputSchema

    def __init__(self, config: BaseAgentConfigAsync):
"""
        Initializes the BaseAgentAsync.

Args:
            config (BaseAgentConfigAsync): Configuration for the chat agent.
"""
self.input_schema = config.input_schema or self.input_schema
self.output_schema = config.output_schema or self.output_schema
self.client = config.client
self.model = config.model
self.memory = config.memory or AgentMemory()
self.system_prompt_generator = config.system_prompt_generator or SystemPromptGenerator()
self.initial_memory = self.memory.copy()
self.current_user_input = None
self.temperature = config.temperature
self.max_tokens = config.max_tokens

        self.lock = asyncio.Lock()  # Serializes concurrent runs so memory updates stay consistent.

def reset_memory(self):
"""
Resets the memory to its initial state.
"""
self.memory = self.initial_memory.copy()

    async def get_response(self, response_model=None) -> BaseModel:
        """
        Obtains a response from the language model asynchronously.

Args:
response_model (Type[BaseModel], optional):
The schema for the response data. If not set, self.output_schema is used.

Returns:
            BaseModel: The response from the language model.
"""
if response_model is None:
response_model = self.output_schema

messages = [
{
"role": "system",
"content": self.system_prompt_generator.generate_prompt(),
}
] + self.memory.get_history()

response = await self.client.chat.completions.create(
messages=messages,
model=self.model,
response_model=response_model,
temperature=self.temperature,
max_tokens=self.max_tokens,
)

return response
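        # Illustrative override: any pydantic model can be requested in place of
        # the agent's output_schema (SummarySchema below is hypothetical):
        #
        #     class SummarySchema(BaseModel):
        #         summary: str = Field(..., description="A one-sentence summary.")
        #
        #     summary = await agent.get_response(response_model=SummarySchema)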

async def run(self, user_input: Optional[BaseIOSchema] = None) -> BaseIOSchema:
"""
        Runs the chat agent with the given user input asynchronously.

Args:
user_input (Optional[BaseIOSchema]): The input from the user. If not provided, skips adding to memory.

Returns:
BaseIOSchema: The response from the chat agent.
"""
async with self.lock:
if user_input:
self.memory.initialize_turn()
self.current_user_input = user_input
self.memory.add_message("user", user_input)

response = await self.get_response(response_model=self.output_schema)
self.memory.add_message("assistant", response)

return response
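        # Illustrative usage, assuming an agent built from a BaseAgentConfigAsync
        # (the names below are hypothetical):
        #
        #     agent = BaseAgentAsync(config)
        #     reply = asyncio.run(agent.run(BaseAgentInputSchema(chat_message="Hi!")))
        #     print(reply.chat_message)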

async def run_async(self, user_input: Optional[BaseIOSchema] = None):
"""
        Runs the chat agent with the given user input, streaming partial responses asynchronously.

Args:
user_input (Optional[BaseIOSchema]): The input from the user. If not provided, skips adding to memory.

Yields:
BaseModel: Partial responses from the chat agent.
"""
async with self.lock:
if user_input:
self.memory.initialize_turn()
self.current_user_input = user_input
self.memory.add_message("user", user_input)

messages = [
{
"role": "system",
"content": self.system_prompt_generator.generate_prompt(),
}
] + self.memory.get_history()

response_stream = await self.client.chat.completions.create_partial(
model=self.model,
messages=messages,
response_model=self.output_schema,
temperature=self.temperature,
max_tokens=self.max_tokens,
stream=True,
)

async for partial_response in response_stream:
yield partial_response

            # The last partial from the stream carries the complete response;
            # re-validate it against the full output schema before storing it.
            full_response_content = self.output_schema(**partial_response.model_dump())
            self.memory.add_message("assistant", full_response_content)
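        # Illustrative streaming usage (hypothetical names):
        #
        #     async def main():
        #         async for partial in agent.run_async(
        #             BaseAgentInputSchema(chat_message="Hi!")
        #         ):
        #             print(partial)
        #
        #     asyncio.run(main())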

    async def stream_response_async(self, user_input: Optional[BaseIOSchema] = None):
"""
Deprecated method for streaming responses asynchronously. Use run_async instead.

Args:
            user_input (Optional[BaseIOSchema]): The input from the user. If not provided, skips adding to memory.

Yields:
BaseModel: Partial responses from the chat agent.
"""
warnings.warn(
"stream_response_async is deprecated and will be removed in version 1.1. Use run_async instead which can be used in the exact same way.",
DeprecationWarning,
stacklevel=2,
)
async for response in self.run_async(user_input):
yield response

    def get_context_provider(self, provider_name: str) -> SystemPromptContextProviderBase:
"""
Retrieves a context provider by name.

Args:
provider_name (str): The name of the context provider.

Returns:
SystemPromptContextProviderBase: The context provider if found.

Raises:
KeyError: If the context provider is not found.
"""
if provider_name not in self.system_prompt_generator.context_providers:
raise KeyError(f"Context provider '{provider_name}' not found.")
return self.system_prompt_generator.context_providers[provider_name]

def register_context_provider(self, provider_name: str, provider: SystemPromptContextProviderBase):
"""
Registers a new context provider.

Args:
provider_name (str): The name of the context provider.
provider (SystemPromptContextProviderBase): The context provider instance.
"""
self.system_prompt_generator.context_providers[provider_name] = provider

def unregister_context_provider(self, provider_name: str):
"""
Unregisters an existing context provider.

Args:
provider_name (str): The name of the context provider to remove.
"""
if provider_name in self.system_prompt_generator.context_providers:
del self.system_prompt_generator.context_providers[provider_name]
else:
raise KeyError(f"Context provider '{provider_name}' not found.")
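
# A minimal, illustrative sketch of a custom context provider for either agent
# class. It assumes, as in the library's examples, that
# SystemPromptContextProviderBase takes a title and requires a
# get_info() -> str implementation; SearchResultsProvider and
# _example_register_provider are hypothetical names.
class SearchResultsProvider(SystemPromptContextProviderBase):
    def __init__(self, title: str):
        super().__init__(title=title)
        self.results: list[str] = []

    def get_info(self) -> str:
        # Whatever this returns is injected into the generated system prompt.
        return "\n".join(self.results)


def _example_register_provider(agent: BaseAgentAsync) -> None:
    provider = SearchResultsProvider(title="Search Results")
    agent.register_context_provider("search_results", provider)
    # The same name later retrieves or removes the provider.
    assert agent.get_context_provider("search_results") is provider
    agent.unregister_context_provider("search_results")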

if __name__ == "__main__":
from rich.console import Console
from rich.panel import Panel