From 8b3814f736aeb68f5bffe77a05ffc1bd697d4921 Mon Sep 17 00:00:00 2001 From: Thiago Castro Ferreira Date: Mon, 5 Aug 2024 12:11:58 -0300 Subject: [PATCH 1/2] Content inputs to be processed according to the query. --- aixplain/modules/agent/__init__.py | 28 ++++++++++++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) diff --git a/aixplain/modules/agent/__init__.py b/aixplain/modules/agent/__init__.py index 2f244d56..2aecc752 100644 --- a/aixplain/modules/agent/__init__.py +++ b/aixplain/modules/agent/__init__.py @@ -28,6 +28,7 @@ from aixplain.utils.file_utils import _request_with_retry from aixplain.enums.supplier import Supplier from aixplain.enums.asset_status import AssetStatus +from aixplain.enums.storage_type import StorageType from aixplain.modules.model import Model from aixplain.modules.agent.tool import Tool from aixplain.modules.agent.tool.model_tool import ModelTool @@ -103,6 +104,7 @@ def run( timeout: float = 300, parameters: Dict = {}, wait_time: float = 0.5, + content_inputs: List[Text] = [], ) -> Dict: """Runs an agent call. @@ -114,13 +116,21 @@ def run( timeout (float, optional): total polling time. Defaults to 300. parameters (Dict, optional): optional parameters to the model. Defaults to "{}". wait_time (float, optional): wait time in seconds between polling calls. Defaults to 0.5. + content_inputs (List[Text], optional): Content inputs to be processed according to the query. Defaults to []. Returns: Dict: parsed output from model """ start = time.time() try: - response = self.run_async(query=query, session_id=session_id, history=history, name=name, parameters=parameters) + response = self.run_async( + query=query, + session_id=session_id, + history=history, + name=name, + parameters=parameters, + content_inputs=content_inputs, + ) if response["status"] == "FAILED": end = time.time() response["elapsed_time"] = end - start @@ -142,6 +152,7 @@ def run_async( history: Optional[List[Dict]] = None, name: Text = "model_process", parameters: Dict = {}, + content_inputs: List[Text] = [], ) -> Dict: """Runs asynchronously an agent call. @@ -151,13 +162,26 @@ def run_async( history (Optional[List[Dict]], optional): chat history (in case session ID is None). Defaults to None. name (Text, optional): ID given to a call. Defaults to "model_process". parameters (Dict, optional): optional parameters to the model. Defaults to "{}". + content_inputs (List[Text], optional): Content inputs to be processed according to the query. Defaults to []. Returns: dict: polling URL in response """ - headers = {"x-api-key": self.api_key, "Content-Type": "application/json"} from aixplain.factories.file_factory import FileFactory + # process content inputs + content_inputs = list(set(content_inputs)) + if len(content_inputs) > 0: + assert ( + FileFactory.check_storage_type(query) == StorageType.TEXT + ), "When providing 'content_inputs', query must be text." + assert len(content_inputs) <= 3, "The maximum number of content inputs is 3." + for input_link in content_inputs: + input_link = FileFactory.to_link(input_link) + query += f"\n{input_link}" + + headers = {"x-api-key": self.api_key, "Content-Type": "application/json"} + payload = {"id": self.id, "query": FileFactory.to_link(query), "sessionId": session_id, "history": history} payload.update(parameters) payload = json.dumps(payload) From f1b339b867d49736cf58d4f2def0314a7da64caf Mon Sep 17 00:00:00 2001 From: Thiago Castro Ferreira Date: Wed, 7 Aug 2024 18:29:14 -0300 Subject: [PATCH 2/2] Add data and query parameters on running agent --- aixplain/modules/agent/__init__.py | 49 +++++++++++++------ .../functional/agent/agent_functional_test.py | 2 +- 2 files changed, 34 insertions(+), 17 deletions(-) diff --git a/aixplain/modules/agent/__init__.py b/aixplain/modules/agent/__init__.py index 2aecc752..8a5cd120 100644 --- a/aixplain/modules/agent/__init__.py +++ b/aixplain/modules/agent/__init__.py @@ -97,26 +97,28 @@ def __init__( def run( self, - query: Text, + data: Optional[Union[Dict, Text]] = None, + query: Optional[Text] = None, session_id: Optional[Text] = None, history: Optional[List[Dict]] = None, name: Text = "model_process", timeout: float = 300, parameters: Dict = {}, wait_time: float = 0.5, - content_inputs: List[Text] = [], + content: List[Text] = [], ) -> Dict: """Runs an agent call. Args: - query (Text): query to be processed by the agent. + data (Optional[Union[Dict, Text]], optional): data to be processed by the agent. Defaults to None. + query (Optional[Text], optional): query to be processed by the agent. Defaults to None. session_id (Optional[Text], optional): conversation Session ID. Defaults to None. history (Optional[List[Dict]], optional): chat history (in case session ID is None). Defaults to None. name (Text, optional): ID given to a call. Defaults to "model_process". timeout (float, optional): total polling time. Defaults to 300. parameters (Dict, optional): optional parameters to the model. Defaults to "{}". wait_time (float, optional): wait time in seconds between polling calls. Defaults to 0.5. - content_inputs (List[Text], optional): Content inputs to be processed according to the query. Defaults to []. + content (List[Text], optional): Content inputs to be processed according to the query. Defaults to []. Returns: Dict: parsed output from model @@ -124,12 +126,13 @@ def run( start = time.time() try: response = self.run_async( + data=data, query=query, session_id=session_id, history=history, name=name, parameters=parameters, - content_inputs=content_inputs, + content=content, ) if response["status"] == "FAILED": end = time.time() @@ -147,36 +150,50 @@ def run( def run_async( self, - query: Text, + data: Optional[Union[Dict, Text]] = None, + query: Optional[Text] = None, session_id: Optional[Text] = None, history: Optional[List[Dict]] = None, name: Text = "model_process", parameters: Dict = {}, - content_inputs: List[Text] = [], + content: List[Text] = [], ) -> Dict: """Runs asynchronously an agent call. Args: - query (Text): query to be processed by the agent. + data (Optional[Union[Dict, Text]], optional): data to be processed by the agent. Defaults to None. + query (Optional[Text], optional): query to be processed by the agent. Defaults to None. session_id (Optional[Text], optional): conversation Session ID. Defaults to None. history (Optional[List[Dict]], optional): chat history (in case session ID is None). Defaults to None. name (Text, optional): ID given to a call. Defaults to "model_process". parameters (Dict, optional): optional parameters to the model. Defaults to "{}". - content_inputs (List[Text], optional): Content inputs to be processed according to the query. Defaults to []. + content (List[Text], optional): Content inputs to be processed according to the query. Defaults to []. Returns: dict: polling URL in response """ from aixplain.factories.file_factory import FileFactory + assert data is not None or query is not None, "Either 'data' or 'query' must be provided." + if data is not None: + if isinstance(data, dict): + assert "query" in data and data["query"] is not None, "When providing a dictionary, 'query' must be provided." + query = data.get("query") + if session_id is None: + session_id = data.get("session_id") + if history is None: + history = data.get("history") + if len(content) == 0: + content = data.get("content", []) + else: + query = data + # process content inputs - content_inputs = list(set(content_inputs)) - if len(content_inputs) > 0: - assert ( - FileFactory.check_storage_type(query) == StorageType.TEXT - ), "When providing 'content_inputs', query must be text." - assert len(content_inputs) <= 3, "The maximum number of content inputs is 3." - for input_link in content_inputs: + content = list(set(content)) + if len(content) > 0: + assert FileFactory.check_storage_type(query) == StorageType.TEXT, "When providing 'content', query must be text." + assert len(content) <= 3, "The maximum number of content inputs is 3." + for input_link in content: input_link = FileFactory.to_link(input_link) query += f"\n{input_link}" diff --git a/tests/functional/agent/agent_functional_test.py b/tests/functional/agent/agent_functional_test.py index 766ba386..1827ef94 100644 --- a/tests/functional/agent/agent_functional_test.py +++ b/tests/functional/agent/agent_functional_test.py @@ -55,7 +55,7 @@ def test_end2end(run_input_map): agent = AgentFactory.create(name=run_input_map["agent_name"], llm_id=run_input_map["llm_id"], tools=tools) print(f"Agent created: {agent.__dict__}") print("Running agent") - response = agent.run(query=run_input_map["query"]) + response = agent.run(data=run_input_map["query"]) print(f"Agent response: {response}") assert response is not None assert response["completed"] is True