Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 27 additions & 4 deletions aixplain/factories/agent_factory/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from aixplain.utils import config
from typing import Dict, List, Optional, Text, Union

from aixplain.factories.agent_factory.utils import build_agent
from aixplain.factories.agent_factory.utils import build_agent, validate_llm
from aixplain.utils.file_utils import _request_with_retry
from urllib.parse import urljoin

Expand All @@ -50,8 +50,30 @@ def create(
api_key: Text = config.TEAM_API_KEY,
supplier: Union[Dict, Text, Supplier, int] = "aiXplain",
version: Optional[Text] = None,
use_mentalist_and_inspector: bool = False,
) -> Agent:
"""Create a new agent in the platform."""
"""Create a new agent in the platform.

Args:
name (Text): name of the agent
llm_id (Text): aiXplain ID of the large language model to be used as agent.
tools (List[Tool], optional): list of tool for the agent. Defaults to [].
description (Text, optional): description of the agent role. Defaults to "".
api_key (Text, optional): team/user API key. Defaults to config.TEAM_API_KEY.
supplier (Union[Dict, Text, Supplier, int], optional): owner of the agent. Defaults to "aiXplain".
version (Optional[Text], optional): version of the agent. Defaults to None.
use_mentalist_and_inspector (bool, optional): flag to enable mentalist and inspector agents (which only works when a supervisor is enabled). Defaults to False.

Returns:
Agent: created Agent
"""
# validate LLM ID
validate_llm(llm_id)

orchestrator_llm_id, mentalist_and_inspector_llm_id = llm_id, None
if use_mentalist_and_inspector is True:
mentalist_and_inspector_llm_id = llm_id

try:
agent = None
url = urljoin(config.BACKEND_URL, "sdk/agents")
Expand Down Expand Up @@ -94,9 +116,10 @@ def create(
"description": description,
"supplier": supplier,
"version": version,
"llmId": llm_id,
"supervisorId": orchestrator_llm_id,
"plannerId": mentalist_and_inspector_llm_id,
}
if llm_id is not None:
payload["llmId"] = llm_id

logging.info(f"Start service for POST Create Agent - {url} - {headers} - {json.dumps(payload)}")
r = _request_with_retry("post", url, headers=headers, json=payload)
Expand Down
10 changes: 10 additions & 0 deletions aixplain/factories/agent_factory/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,13 @@ def build_agent(payload: Dict, api_key: Text = config.TEAM_API_KEY) -> Agent:
)
agent.url = urljoin(config.BACKEND_URL, f"sdk/agents/{agent.id}/run")
return agent


def validate_llm(model_id: Text) -> None:
from aixplain.factories.model_factory import ModelFactory

try:
llm = ModelFactory.get(model_id)
assert llm.function == Function.TEXT_GENERATION, "Large Language Model must be a text generation model."
except Exception:
raise Exception(f"Large Language Model with ID '{model_id}' not found.")
6 changes: 6 additions & 0 deletions tests/functional/agent/agent_functional_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,9 @@ def test_list_agents():
assert "results" in agents
agents_result = agents["results"]
assert type(agents_result) is list


def test_fail_non_existent_llm():
with pytest.raises(Exception) as exc_info:
AgentFactory.create(name="Test Agent", llm_id="non_existent_llm", tools=[])
assert str(exc_info.value) == "Large Language Model with ID 'non_existent_llm' not found."
55 changes: 55 additions & 0 deletions tests/unit/agent_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from aixplain.utils import config
from aixplain.factories import AgentFactory
from aixplain.modules.agent import PipelineTool, ModelTool
from urllib.parse import urljoin


def test_fail_no_data_query():
Expand Down Expand Up @@ -77,3 +78,57 @@ def test_invalid_modeltool():
with pytest.raises(Exception) as exc_info:
AgentFactory.create(name="Test", tools=[ModelTool(model="309851793")], llm_id="6646261c6eb563165658bbb1")
assert str(exc_info.value) == "Model Tool Unavailable. Make sure Model '309851793' exists or you have access to it."


def test_create_agent():
from aixplain.enums import Supplier

with requests_mock.Mocker() as mock:
url = urljoin(config.BACKEND_URL, "sdk/agents")
headers = {"x-api-key": config.TEAM_API_KEY, "Content-Type": "application/json"}

ref_response = {
"id": "123",
"name": "Test Agent",
"description": "Test Agent Description",
"teamId": "123",
"version": "1.0",
"status": "onboarded",
"llmId": "6646261c6eb563165658bbb1",
"pricing": {"currency": "USD", "value": 0.0},
"assets": [
{
"type": "model",
"supplier": "openai",
"version": "1.0",
"assetId": "6646261c6eb563165658bbb1",
"function": "text-generation",
}
],
}
mock.post(url, headers=headers, json=ref_response)

url = urljoin(config.BACKEND_URL, "sdk/models/6646261c6eb563165658bbb1")
model_ref_response = {
"id": "6646261c6eb563165658bbb1",
"name": "Test LLM",
"description": "Test LLM Description",
"function": {"id": "text-generation"},
"supplier": "openai",
"version": {"id": "1.0"},
"status": "onboarded",
"pricing": {"currency": "USD", "value": 0.0},
}
mock.get(url, headers=headers, json=model_ref_response)

agent = AgentFactory.create(
name="Test Agent",
description="Test Agent Description",
llm_id="6646261c6eb563165658bbb1",
tools=[AgentFactory.create_model_tool(supplier=Supplier.OPENAI, function="text-generation")],
)

assert agent.name == ref_response["name"]
assert agent.description == ref_response["description"]
assert agent.llm_id == ref_response["llmId"]
assert agent.tools[0].function.value == ref_response["assets"][0]["function"]
Loading