Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions aixplain/factories/agent_factory/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from aixplain.enums.function import Function
from aixplain.enums.supplier import Supplier
from aixplain.modules.agent import Agent, AgentTask, Tool
from aixplain.modules.agent.output_format import OutputFormat
from aixplain.modules.agent.tool.model_tool import ModelTool
from aixplain.modules.agent.tool.pipeline_tool import PipelineTool
from aixplain.modules.agent.tool.python_interpreter_tool import PythonInterpreterTool
Expand All @@ -41,7 +42,7 @@
from aixplain.modules.pipeline import Pipeline
from aixplain.utils import config
from typing import Callable, Dict, List, Optional, Text, Union

from pydantic import BaseModel
from aixplain.utils.request_utils import _request_with_retry
from urllib.parse import urljoin
from aixplain.enums import DatabaseSourceType
Expand All @@ -61,6 +62,8 @@ def create(
supplier: Union[Dict, Text, Supplier, int] = "aiXplain",
version: Optional[Text] = None,
tasks: List[AgentTask] = [],
output_format: Optional[OutputFormat] = None,
expected_output: Optional[Union[BaseModel, Text, dict]] = None,
) -> Agent:
"""Create a new agent in the platform.

Expand All @@ -80,7 +83,8 @@ def create(
supplier (Union[Dict, Text, Supplier, int], optional): owner of the agent. Defaults to "aiXplain".
version (Optional[Text], optional): version of the agent. Defaults to None.
tasks (List[AgentTask], optional): list of tasks for the agent. Defaults to [].

output_format (OutputFormat, optional): default output format for agent responses. Defaults to OutputFormat.TEXT.
expected_output (Union[BaseModel, Text, dict], optional): expected output. Defaults to None.
Returns:
Agent: created Agent
"""
Expand All @@ -92,6 +96,11 @@ def create(
# Use default GPT-4o if no LLM specified
llm = get_llm_instance("669a63646eb56306647e1091", api_key=api_key)

if output_format == OutputFormat.JSON:
assert expected_output is not None and (
issubclass(expected_output, BaseModel) or isinstance(expected_output, dict)
), "'expected_output' must be a Pydantic BaseModel or a JSON object when 'output_format' is JSON."

warnings.warn(
"Use `llm` to define the large language model (aixplain.modules.model.llm_model.LLM) to be used as agent. "
"Use `llm_id` to provide the model ID of the large language model to be used as agent. "
Expand Down Expand Up @@ -135,6 +144,12 @@ def create(
# Store the LLM object in payload to avoid recreating it
payload["llm"] = llm

if expected_output:
payload["expectedOutput"] = expected_output
if output_format:
if isinstance(output_format, OutputFormat):
output_format = output_format.value
payload["outputFormat"] = output_format
agent = build_agent(payload=payload, tools=tools, api_key=api_key)
agent.validate(raise_exception=True)
response = "Unspecified error"
Expand Down
3 changes: 3 additions & 0 deletions aixplain/factories/agent_factory/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from aixplain.modules.agent.tool.python_interpreter_tool import PythonInterpreterTool
from aixplain.modules.agent.tool.custom_python_code_tool import CustomPythonCodeTool
from aixplain.modules.agent.tool.sql_tool import SQLTool
from aixplain.modules.agent.output_format import OutputFormat
from aixplain.modules.model import Model
from aixplain.modules.model.connection import ConnectionTool
from typing import Dict, Text, List, Union
Expand Down Expand Up @@ -185,6 +186,8 @@ def build_agent(payload: Dict, tools: List[Tool] = None, api_key: Text = config.
llm=llm,
api_key=api_key,
status=AssetStatus(payload["status"]),
output_format=OutputFormat(payload.get("outputFormat", OutputFormat.TEXT)),
expected_output=payload.get("expectedOutput", None),
tasks=[
AgentTask(
name=task["name"],
Expand Down
14 changes: 12 additions & 2 deletions aixplain/factories/team_agent_factory/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@
from aixplain.utils.request_utils import _request_with_retry
from aixplain.modules.model.llm_model import LLM
from aixplain.utils.llm_utils import get_llm_instance
from pydantic import BaseModel
from aixplain.modules.agent.output_format import OutputFormat


class TeamAgentFactory:
Expand All @@ -55,6 +57,8 @@ def create(
inspectors: List[Inspector] = [],
inspector_targets: List[Union[InspectorTarget, Text]] = [InspectorTarget.STEPS],
instructions: Optional[Text] = None,
output_format: Optional[OutputFormat] = None,
expected_output: Optional[Union[BaseModel, Text, dict]] = None,
**kwargs,
) -> TeamAgent:
"""Create a new team agent in the platform.
Expand All @@ -75,7 +79,8 @@ def create(
inspector_targets: Which stages to be inspected during an execution of the team agent. (steps, output)
use_mentalist_and_inspector: Whether to use the mentalist and inspector agents. (legacy)
instructions: The instructions to guide the team agent (i.e. appended in the prompt of the team agent).

output_format: The output format to be used for the team agent.
expected_output: The expected output to be used for the team agent.
Returns:
A new team agent instance.
"""
Expand All @@ -90,6 +95,12 @@ def create(
logging.warning("TeamAgent Onboarding Warning: num_inspectors is no longer supported. Use inspectors instead.")

assert len(agents) > 0, "TeamAgent Onboarding Error: At least one agent must be provided."

if output_format == OutputFormat.JSON:
assert expected_output is not None and (
issubclass(expected_output, BaseModel) or isinstance(expected_output, dict)
), "'expected_output' must be a Pydantic BaseModel or a JSON object when 'output_format' is JSON."

agent_list = []
for agent in agents:
if isinstance(agent, Text) is True:
Expand Down Expand Up @@ -185,7 +196,6 @@ def _setup_llm_and_tool(llm_param: Optional[Union[LLM, Text]],
"tools": tools,
"role": instructions,
}

# Store the LLM objects directly in the payload for build_team_agent
internal_payload = payload.copy()
if llm is not None:
Expand Down
4 changes: 3 additions & 1 deletion aixplain/factories/team_agent_factory/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from aixplain.factories.agent_factory import AgentFactory
from aixplain.factories.model_factory import ModelFactory
from aixplain.modules.model.model_parameters import ModelParameters

from aixplain.modules.agent.output_format import OutputFormat

GPT_4o_ID = "6646261c6eb563165658bbb1"

Expand Down Expand Up @@ -94,6 +94,8 @@ def build_team_agent(payload: Dict, agents: List[Agent] = None, api_key: Text =
inspector_targets=inspector_targets,
api_key=api_key,
status=AssetStatus(payload["status"]),
output_format=OutputFormat(payload.get("outputFormat", OutputFormat.TEXT)),
expected_output=payload.get("expectedOutput", None),
)
team_agent.url = urljoin(config.BACKEND_URL, f"sdk/agent-communities/{team_agent.id}/run")

Expand Down
27 changes: 22 additions & 5 deletions aixplain/modules/agent/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ class Agent(Model, DeployableMixin[Tool]):
backend_url (str): URL of the backend.
api_key (str): The TEAM API key used for authentication.
cost (Dict, optional): model price. Defaults to None.
output_format (OutputFormat): default output format for agent responses.
expected_output (Union[BaseModel, Text, dict], optional): expected output. Defaults to None.
"""

is_valid: bool
Expand All @@ -78,6 +80,8 @@ def __init__(
cost: Optional[Dict] = None,
status: AssetStatus = AssetStatus.DRAFT,
tasks: List[AgentTask] = [],
output_format: OutputFormat = OutputFormat.TEXT,
expected_output: Optional[Union[BaseModel, Text, dict]] = None,
**additional_info,
) -> None:
"""Create an Agent with the necessary information.
Expand All @@ -95,6 +99,8 @@ def __init__(
backend_url (str): URL of the backend.
api_key (str): The TEAM API key used for authentication.
cost (Dict, optional): model price. Defaults to None.
output_format (OutputFormat, optional): default output format for agent responses. Defaults to OutputFormat.TEXT.
expected_output (Union[BaseModel, Text, dict], optional): expected output. Defaults to None.
"""
super().__init__(id, name, description, api_key, supplier, version, cost=cost)
self.instructions = instructions
Expand All @@ -111,6 +117,8 @@ def __init__(
status = AssetStatus.DRAFT
self.status = status
self.tasks = tasks
self.output_format = output_format
self.expected_output = expected_output
self.is_valid = True

def _validate(self) -> None:
Expand Down Expand Up @@ -170,7 +178,7 @@ def run(
content: Optional[Union[Dict[Text, Text], List[Text]]] = None,
max_tokens: int = 4096,
max_iterations: int = 3,
output_format: OutputFormat = OutputFormat.TEXT,
output_format: Optional[OutputFormat] = None,
expected_output: Optional[Union[BaseModel, Text, dict]] = None,
) -> AgentResponse:
"""Runs an agent call.
Expand All @@ -187,7 +195,7 @@ def run(
content (Union[Dict[Text, Text], List[Text]], optional): Content inputs to be processed according to the query. Defaults to None.
max_tokens (int, optional): maximum number of tokens which can be generated by the agent. Defaults to 2048.
max_iterations (int, optional): maximum number of iterations between the agent and the tools. Defaults to 10.
output_format (OutputFormat, optional): response format. Defaults to TEXT.
output_format (OutputFormat, optional): response format. If not provided, uses the format set during initialization.
expected_output (Union[BaseModel, Text, dict], optional): expected output. Defaults to None.
Returns:
Dict: parsed output from model
Expand Down Expand Up @@ -258,7 +266,7 @@ def run_async(
content: Optional[Union[Dict[Text, Text], List[Text]]] = None,
max_tokens: int = 2048,
max_iterations: int = 10,
output_format: OutputFormat = OutputFormat.TEXT,
output_format: Optional[OutputFormat] = None,
expected_output: Optional[Union[BaseModel, Text, dict]] = None,
) -> AgentResponse:
"""Runs asynchronously an agent call.
Expand All @@ -273,7 +281,7 @@ def run_async(
content (Union[Dict[Text, Text], List[Text]], optional): Content inputs to be processed according to the query. Defaults to None.
max_tokens (int, optional): maximum number of tokens which can be generated by the agent. Defaults to 2048.
max_iterations (int, optional): maximum number of iterations between the agent and the tools. Defaults to 10.
output_format (OutputFormat, optional): response format. Defaults to TEXT.
output_format (OutputFormat, optional): response format. If not provided, uses the format set during initialization.
expected_output (Union[BaseModel, Text, dict], optional): expected output. Defaults to None.
Returns:
dict: polling URL in response
Expand Down Expand Up @@ -321,9 +329,14 @@ def run_async(

# build query
input_data = process_variables(query, data, parameters, self.instructions)

if expected_output is None:
expected_output = self.expected_output
if expected_output is not None and issubclass(expected_output, BaseModel):
expected_output = expected_output.model_json_schema()
# Use instance output_format if none provided
if output_format is None:
output_format = self.output_format

if isinstance(output_format, OutputFormat):
output_format = output_format.value

Expand Down Expand Up @@ -387,6 +400,8 @@ def to_dict(self) -> Dict:
else [],
"cost": self.cost,
"api_key": self.api_key,
"outputFormat": self.output_format.value,
"expectedOutput": self.expected_output,
}

@classmethod
Expand Down Expand Up @@ -461,6 +476,8 @@ def from_dict(cls, data: Dict) -> "Agent":
cost=data.get("cost"),
status=status,
tasks=tasks,
output_format=OutputFormat(data.get("outputFormat", OutputFormat.TEXT)),
expected_output=data.get("expectedOutput"),
)

def delete(self) -> None:
Expand Down
21 changes: 16 additions & 5 deletions aixplain/modules/team_agent/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,8 @@ def __init__(
inspector_targets: List[InspectorTarget] = [InspectorTarget.STEPS],
status: AssetStatus = AssetStatus.DRAFT,
instructions: Optional[Text] = None,
output_format: OutputFormat = OutputFormat.TEXT,
expected_output: Optional[Union[BaseModel, Text, dict]] = None,
**additional_info,
) -> None:
super().__init__(id, name, description, api_key, supplier, version, cost=cost)
Expand All @@ -117,6 +119,8 @@ def __init__(
status = AssetStatus.DRAFT
self.status = status
self.is_valid = True
self.output_format = output_format
self.expected_output = expected_output

def run(
self,
Expand All @@ -131,7 +135,7 @@ def run(
content: Optional[Union[Dict[Text, Text], List[Text]]] = None,
max_tokens: int = 2048,
max_iterations: int = 30,
output_format: OutputFormat = OutputFormat.TEXT,
output_format: Optional[OutputFormat] = None,
expected_output: Optional[Union[BaseModel, Text, dict]] = None,
) -> AgentResponse:
"""Runs a team agent call.
Expand All @@ -148,7 +152,7 @@ def run(
content (Union[Dict[Text, Text], List[Text]], optional): Content inputs to be processed according to the query. Defaults to None.
max_tokens (int, optional): maximum number of tokens which can be generated by the agents. Defaults to 2048.
max_iterations (int, optional): maximum number of iterations between the agents. Defaults to 30.
output_format (OutputFormat, optional): response format. Defaults to TEXT.
output_format (OutputFormat, optional): response format. If not provided, uses the format set during initialization.
expected_output (Union[BaseModel, Text, dict], optional): expected output. Defaults to None.
Returns:
Dict: parsed output from model
Expand Down Expand Up @@ -211,7 +215,7 @@ def run_async(
content: Optional[Union[Dict[Text, Text], List[Text]]] = None,
max_tokens: int = 2048,
max_iterations: int = 30,
output_format: OutputFormat = OutputFormat.TEXT,
output_format: Optional[OutputFormat] = None,
expected_output: Optional[Union[BaseModel, Text, dict]] = None,
) -> AgentResponse:
"""Runs asynchronously a Team Agent call.
Expand All @@ -226,7 +230,7 @@ def run_async(
content (Union[Dict[Text, Text], List[Text]], optional): Content inputs to be processed according to the query. Defaults to None.
max_tokens (int, optional): maximum number of tokens which can be generated by the agents. Defaults to 2048.
max_iterations (int, optional): maximum number of iterations between the agents. Defaults to 30.
output_format (OutputFormat, optional): response format. Defaults to TEXT.
output_format (OutputFormat, optional): response format. If not provided, uses the format set during initialization.
expected_output (Union[BaseModel, Text, dict], optional): expected output. Defaults to None.
Returns:
dict: polling URL in response
Expand Down Expand Up @@ -271,9 +275,12 @@ def run_async(

# build query
input_data = process_variables(query, data, parameters, self.description)

if expected_output is None:
expected_output = self.expected_output
if expected_output is not None and issubclass(expected_output, BaseModel):
expected_output = expected_output.model_json_schema()
if output_format is None:
output_format = self.output_format
if isinstance(output_format, OutputFormat):
output_format = output_format.value

Expand Down Expand Up @@ -378,6 +385,8 @@ def to_dict(self) -> Dict:
"version": self.version,
"status": self.status.value,
"role": self.instructions,
"outputFormat": self.output_format.value,
"expectedOutput": self.expected_output,
}

@classmethod
Expand Down Expand Up @@ -479,6 +488,8 @@ def from_dict(cls, data: Dict) -> "TeamAgent":
instructions=data.get("role"),
inspectors=inspectors,
inspector_targets=inspector_targets,
output_format=OutputFormat(data.get("outputFormat", OutputFormat.TEXT)),
expected_output=data.get("expectedOutput"),
)

def _validate(self) -> None:
Expand Down
23 changes: 14 additions & 9 deletions tests/unit/agent/agent_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1270,20 +1270,23 @@ def test_agent_serialization_completeness():
agent_dict = agent.to_dict()

required_fields = {
"id",
"name",
"description",
"llmId",
"version",
"role",
"assets",
"api_key",
"supplier",
"version",
"llmId",
"outputFormat",
"status",
"tasks",
"tools",
"name",
"description",
"cost",
"api_key",
"tools",
"assets",
"tasks",
"expectedOutput",
"id",
}

assert set(agent_dict.keys()) == required_fields

# Verify field values
Expand All @@ -1300,6 +1303,8 @@ def test_agent_serialization_completeness():
assert isinstance(agent_dict["assets"], list)
assert isinstance(agent_dict["tasks"], list)
assert len(agent_dict["tasks"]) == 2
assert agent_dict["outputFormat"] == "text"
assert agent_dict["expectedOutput"] is None

# Verify task serialization
task_dict = agent_dict["tasks"][0]
Expand Down
5 changes: 5 additions & 0 deletions tests/unit/team_agent/team_agent_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,10 @@ def test_team_agent_serialization_completeness():
"version",
"status",
"role",
"outputFormat",
"expectedOutput",
}

assert set(team_dict.keys()) == required_fields

# Verify field values
Expand All @@ -457,6 +460,8 @@ def test_team_agent_serialization_completeness():
assert team_dict["status"] == "draft"
assert team_dict["links"] == []
assert team_dict["plannerId"] is None # use_mentalist=False
assert team_dict["outputFormat"] == "text"
assert team_dict["expectedOutput"] is None

# Verify agents serialization
assert isinstance(team_dict["agents"], list)
Expand Down