Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 18 additions & 9 deletions aixplain/modules/agent/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from aixplain.enums.asset_status import AssetStatus
from aixplain.enums.storage_type import StorageType
from aixplain.modules.model import Model
from aixplain.modules.agent.output_format import OutputFormat
from aixplain.modules.agent.tool import Tool
from aixplain.modules.agent.tool.model_tool import ModelTool
from aixplain.modules.agent.tool.pipeline_tool import PipelineTool
Expand Down Expand Up @@ -127,6 +128,7 @@ def run(
content: Optional[Union[Dict[Text, Text], List[Text]]] = None,
max_tokens: int = 2048,
max_iterations: int = 10,
output_format: OutputFormat = OutputFormat.TEXT,
) -> Dict:
"""Runs an agent call.

Expand All @@ -142,7 +144,7 @@ def run(
content (Union[Dict[Text, Text], List[Text]], optional): Content inputs to be processed according to the query. Defaults to None.
max_tokens (int, optional): maximum number of tokens which can be generated by the agent. Defaults to 2048.
max_iterations (int, optional): maximum number of iterations between the agent and the tools. Defaults to 10.

output_format (ResponseFormat, optional): response format. Defaults to TEXT.
Returns:
Dict: parsed output from model
"""
Expand All @@ -158,6 +160,7 @@ def run(
content=content,
max_tokens=max_tokens,
max_iterations=max_iterations,
output_format=output_format,
)
if response["status"] == "FAILED":
end = time.time()
Expand All @@ -184,6 +187,7 @@ def run_async(
content: Optional[Union[Dict[Text, Text], List[Text]]] = None,
max_tokens: int = 2048,
max_iterations: int = 10,
output_format: OutputFormat = OutputFormat.TEXT,
) -> Dict:
"""Runs asynchronously an agent call.

Expand All @@ -197,7 +201,7 @@ def run_async(
content (Union[Dict[Text, Text], List[Text]], optional): Content inputs to be processed according to the query. Defaults to None.
max_tokens (int, optional): maximum number of tokens which can be generated by the agent. Defaults to 2048.
max_iterations (int, optional): maximum number of iterations between the agent and the tools. Defaults to 10.

output_format (ResponseFormat, optional): response format. Defaults to TEXT.
Returns:
dict: polling URL in response
"""
Expand Down Expand Up @@ -234,13 +238,18 @@ def run_async(

headers = {"x-api-key": self.api_key, "Content-Type": "application/json"}

payload = {"id": self.id, "query": FileFactory.to_link(query), "sessionId": session_id, "history": history}
parameters.update(
{
"max_tokens": parameters["max_tokens"] if "max_tokens" in parameters else max_tokens,
"max_iterations": parameters["max_iterations"] if "max_iterations" in parameters else max_iterations,
}
)
payload = {
"id": self.id,
"query": FileFactory.to_link(query),
"sessionId": session_id,
"history": history,
"executionParams": {
"maxTokens": parameters["max_tokens"] if "max_tokens" in parameters else max_tokens,
"maxIterations": parameters["max_iterations"] if "max_iterations" in parameters else max_iterations,
"outputFormat": output_format.value,
},
}

payload.update(parameters)
payload = json.dumps(payload)

Expand Down
30 changes: 30 additions & 0 deletions aixplain/modules/agent/output_format.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
__author__ = "thiagocastroferreira"

"""
Copyright 2024 The aiXplain SDK authors

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

Author: Duraikrishna Selvaraju, Thiago Castro Ferreira, Shreyas Sharma and Lucas Pavanelli
Date: February 21st 2024
Description:
Asset Enum
"""

from enum import Enum
from typing import Text


class OutputFormat(Text, Enum):
MARKDOWN = "markdown"
TEXT = "text"
26 changes: 17 additions & 9 deletions aixplain/modules/team_agent/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from aixplain.enums.asset_status import AssetStatus
from aixplain.enums.storage_type import StorageType
from aixplain.modules.model import Model
from aixplain.modules.agent import Agent
from aixplain.modules.agent import Agent, OutputFormat
from typing import Dict, List, Text, Optional, Union
from urllib.parse import urljoin

Expand Down Expand Up @@ -112,6 +112,7 @@ def run(
content: Optional[Union[Dict[Text, Text], List[Text]]] = None,
max_tokens: int = 2048,
max_iterations: int = 30,
output_format: OutputFormat = OutputFormat.TEXT,
) -> Dict:
"""Runs a team agent call.

Expand All @@ -127,6 +128,7 @@ def run(
content (Union[Dict[Text, Text], List[Text]], optional): Content inputs to be processed according to the query. Defaults to None.
max_tokens (int, optional): maximum number of tokens which can be generated by the agents. Defaults to 2048.
max_iterations (int, optional): maximum number of iterations between the agents. Defaults to 30.
output_format (ResponseFormat, optional): response format. Defaults to TEXT.
Returns:
Dict: parsed output from model
"""
Expand All @@ -142,6 +144,7 @@ def run(
content=content,
max_tokens=max_tokens,
max_iterations=max_iterations,
output_format=output_format,
)
if response["status"] == "FAILED":
end = time.time()
Expand All @@ -168,6 +171,7 @@ def run_async(
content: Optional[Union[Dict[Text, Text], List[Text]]] = None,
max_tokens: int = 2048,
max_iterations: int = 30,
output_format: OutputFormat = OutputFormat.TEXT,
) -> Dict:
"""Runs asynchronously a Team Agent call.

Expand All @@ -181,7 +185,7 @@ def run_async(
content (Union[Dict[Text, Text], List[Text]], optional): Content inputs to be processed according to the query. Defaults to None.
max_tokens (int, optional): maximum number of tokens which can be generated by the agents. Defaults to 2048.
max_iterations (int, optional): maximum number of iterations between the agents. Defaults to 30.

output_format (ResponseFormat, optional): response format. Defaults to TEXT.
Returns:
dict: polling URL in response
"""
Expand Down Expand Up @@ -218,13 +222,17 @@ def run_async(

headers = {"x-api-key": self.api_key, "Content-Type": "application/json"}

payload = {"id": self.id, "query": FileFactory.to_link(query), "sessionId": session_id, "history": history}
parameters.update(
{
"max_tokens": parameters["max_tokens"] if "max_tokens" in parameters else max_tokens,
"max_iterations": parameters["max_iterations"] if "max_iterations" in parameters else max_iterations,
}
)
payload = {
"id": self.id,
"query": FileFactory.to_link(query),
"sessionId": session_id,
"history": history,
"executionParams": {
"maxTokens": parameters["max_tokens"] if "max_tokens" in parameters else max_tokens,
"maxIterations": parameters["max_iterations"] if "max_iterations" in parameters else max_iterations,
"outputFormat": output_format.value,
},
}
payload.update(parameters)
payload = json.dumps(payload)

Expand Down
18 changes: 18 additions & 0 deletions tests/unit/agent_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import requests_mock
from aixplain.enums.asset_status import AssetStatus
from aixplain.modules import Agent
from aixplain.modules.agent import OutputFormat
from aixplain.utils import config
from aixplain.factories import AgentFactory
from aixplain.modules.agent import PipelineTool, ModelTool
Expand Down Expand Up @@ -226,3 +227,20 @@ def test_update_success():
assert agent.description == ref_response["description"]
assert agent.llm_id == ref_response["llmId"]
assert agent.tools[0].function.value == ref_response["assets"][0]["function"]


def test_run_success():
agent = Agent("123", "Test Agent", "Sample Description")
url = urljoin(config.BACKEND_URL, f"sdk/agents/{agent.id}/run")
agent.url = url
with requests_mock.Mocker() as mock:
headers = {"x-aixplain-key": config.AIXPLAIN_API_KEY, "Content-Type": "application/json"}

ref_response = {"data": "www.aixplain.com", "status": "IN_PROGRESS"}
mock.post(url, headers=headers, json=ref_response)

response = agent.run_async(
data={"query": "Hello, how are you?"}, max_iterations=10, output_format=OutputFormat.MARKDOWN
)
assert response["status"] == "IN_PROGRESS"
assert response["url"] == ref_response["data"]