diff --git a/aixplain/modules/agent/__init__.py b/aixplain/modules/agent/__init__.py index 8d7391af..b7aad7aa 100644 --- a/aixplain/modules/agent/__init__.py +++ b/aixplain/modules/agent/__init__.py @@ -32,6 +32,7 @@ from aixplain.enums.asset_status import AssetStatus from aixplain.enums.storage_type import StorageType from aixplain.modules.model import Model +from aixplain.modules.agent.output_format import OutputFormat from aixplain.modules.agent.tool import Tool from aixplain.modules.agent.tool.model_tool import ModelTool from aixplain.modules.agent.tool.pipeline_tool import PipelineTool @@ -127,6 +128,7 @@ def run( content: Optional[Union[Dict[Text, Text], List[Text]]] = None, max_tokens: int = 2048, max_iterations: int = 10, + output_format: OutputFormat = OutputFormat.TEXT, ) -> Dict: """Runs an agent call. @@ -142,7 +144,7 @@ def run( content (Union[Dict[Text, Text], List[Text]], optional): Content inputs to be processed according to the query. Defaults to None. max_tokens (int, optional): maximum number of tokens which can be generated by the agent. Defaults to 2048. max_iterations (int, optional): maximum number of iterations between the agent and the tools. Defaults to 10. - + output_format (ResponseFormat, optional): response format. Defaults to TEXT. Returns: Dict: parsed output from model """ @@ -158,6 +160,7 @@ def run( content=content, max_tokens=max_tokens, max_iterations=max_iterations, + output_format=output_format, ) if response["status"] == "FAILED": end = time.time() @@ -184,6 +187,7 @@ def run_async( content: Optional[Union[Dict[Text, Text], List[Text]]] = None, max_tokens: int = 2048, max_iterations: int = 10, + output_format: OutputFormat = OutputFormat.TEXT, ) -> Dict: """Runs asynchronously an agent call. @@ -197,7 +201,7 @@ def run_async( content (Union[Dict[Text, Text], List[Text]], optional): Content inputs to be processed according to the query. Defaults to None. max_tokens (int, optional): maximum number of tokens which can be generated by the agent. Defaults to 2048. max_iterations (int, optional): maximum number of iterations between the agent and the tools. Defaults to 10. - + output_format (ResponseFormat, optional): response format. Defaults to TEXT. Returns: dict: polling URL in response """ @@ -234,13 +238,18 @@ def run_async( headers = {"x-api-key": self.api_key, "Content-Type": "application/json"} - payload = {"id": self.id, "query": FileFactory.to_link(query), "sessionId": session_id, "history": history} - parameters.update( - { - "max_tokens": parameters["max_tokens"] if "max_tokens" in parameters else max_tokens, - "max_iterations": parameters["max_iterations"] if "max_iterations" in parameters else max_iterations, - } - ) + payload = { + "id": self.id, + "query": FileFactory.to_link(query), + "sessionId": session_id, + "history": history, + "executionParams": { + "maxTokens": parameters["max_tokens"] if "max_tokens" in parameters else max_tokens, + "maxIterations": parameters["max_iterations"] if "max_iterations" in parameters else max_iterations, + "outputFormat": output_format.value, + }, + } + payload.update(parameters) payload = json.dumps(payload) diff --git a/aixplain/modules/agent/output_format.py b/aixplain/modules/agent/output_format.py new file mode 100644 index 00000000..3a53e2f8 --- /dev/null +++ b/aixplain/modules/agent/output_format.py @@ -0,0 +1,30 @@ +__author__ = "thiagocastroferreira" + +""" +Copyright 2024 The aiXplain SDK authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +Author: Duraikrishna Selvaraju, Thiago Castro Ferreira, Shreyas Sharma and Lucas Pavanelli +Date: February 21st 2024 +Description: + Asset Enum +""" + +from enum import Enum +from typing import Text + + +class OutputFormat(Text, Enum): + MARKDOWN = "markdown" + TEXT = "text" diff --git a/aixplain/modules/team_agent/__init__.py b/aixplain/modules/team_agent/__init__.py index 2f8b5c3b..08d820f0 100644 --- a/aixplain/modules/team_agent/__init__.py +++ b/aixplain/modules/team_agent/__init__.py @@ -33,7 +33,7 @@ from aixplain.enums.asset_status import AssetStatus from aixplain.enums.storage_type import StorageType from aixplain.modules.model import Model -from aixplain.modules.agent import Agent +from aixplain.modules.agent import Agent, OutputFormat from typing import Dict, List, Text, Optional, Union from urllib.parse import urljoin @@ -112,6 +112,7 @@ def run( content: Optional[Union[Dict[Text, Text], List[Text]]] = None, max_tokens: int = 2048, max_iterations: int = 30, + output_format: OutputFormat = OutputFormat.TEXT, ) -> Dict: """Runs a team agent call. @@ -127,6 +128,7 @@ def run( content (Union[Dict[Text, Text], List[Text]], optional): Content inputs to be processed according to the query. Defaults to None. max_tokens (int, optional): maximum number of tokens which can be generated by the agents. Defaults to 2048. max_iterations (int, optional): maximum number of iterations between the agents. Defaults to 30. + output_format (ResponseFormat, optional): response format. Defaults to TEXT. Returns: Dict: parsed output from model """ @@ -142,6 +144,7 @@ def run( content=content, max_tokens=max_tokens, max_iterations=max_iterations, + output_format=output_format, ) if response["status"] == "FAILED": end = time.time() @@ -168,6 +171,7 @@ def run_async( content: Optional[Union[Dict[Text, Text], List[Text]]] = None, max_tokens: int = 2048, max_iterations: int = 30, + output_format: OutputFormat = OutputFormat.TEXT, ) -> Dict: """Runs asynchronously a Team Agent call. @@ -181,7 +185,7 @@ def run_async( content (Union[Dict[Text, Text], List[Text]], optional): Content inputs to be processed according to the query. Defaults to None. max_tokens (int, optional): maximum number of tokens which can be generated by the agents. Defaults to 2048. max_iterations (int, optional): maximum number of iterations between the agents. Defaults to 30. - + output_format (ResponseFormat, optional): response format. Defaults to TEXT. Returns: dict: polling URL in response """ @@ -218,13 +222,17 @@ def run_async( headers = {"x-api-key": self.api_key, "Content-Type": "application/json"} - payload = {"id": self.id, "query": FileFactory.to_link(query), "sessionId": session_id, "history": history} - parameters.update( - { - "max_tokens": parameters["max_tokens"] if "max_tokens" in parameters else max_tokens, - "max_iterations": parameters["max_iterations"] if "max_iterations" in parameters else max_iterations, - } - ) + payload = { + "id": self.id, + "query": FileFactory.to_link(query), + "sessionId": session_id, + "history": history, + "executionParams": { + "maxTokens": parameters["max_tokens"] if "max_tokens" in parameters else max_tokens, + "maxIterations": parameters["max_iterations"] if "max_iterations" in parameters else max_iterations, + "outputFormat": output_format.value, + }, + } payload.update(parameters) payload = json.dumps(payload) diff --git a/tests/unit/agent_test.py b/tests/unit/agent_test.py index 1b4fd929..9e38937f 100644 --- a/tests/unit/agent_test.py +++ b/tests/unit/agent_test.py @@ -2,6 +2,7 @@ import requests_mock from aixplain.enums.asset_status import AssetStatus from aixplain.modules import Agent +from aixplain.modules.agent import OutputFormat from aixplain.utils import config from aixplain.factories import AgentFactory from aixplain.modules.agent import PipelineTool, ModelTool @@ -226,3 +227,20 @@ def test_update_success(): assert agent.description == ref_response["description"] assert agent.llm_id == ref_response["llmId"] assert agent.tools[0].function.value == ref_response["assets"][0]["function"] + + +def test_run_success(): + agent = Agent("123", "Test Agent", "Sample Description") + url = urljoin(config.BACKEND_URL, f"sdk/agents/{agent.id}/run") + agent.url = url + with requests_mock.Mocker() as mock: + headers = {"x-aixplain-key": config.AIXPLAIN_API_KEY, "Content-Type": "application/json"} + + ref_response = {"data": "www.aixplain.com", "status": "IN_PROGRESS"} + mock.post(url, headers=headers, json=ref_response) + + response = agent.run_async( + data={"query": "Hello, how are you?"}, max_iterations=10, output_format=OutputFormat.MARKDOWN + ) + assert response["status"] == "IN_PROGRESS" + assert response["url"] == ref_response["data"]