From 8b3814f736aeb68f5bffe77a05ffc1bd697d4921 Mon Sep 17 00:00:00 2001 From: Thiago Castro Ferreira Date: Mon, 5 Aug 2024 12:11:58 -0300 Subject: [PATCH 1/4] Content inputs to be processed according to the query. --- aixplain/modules/agent/__init__.py | 28 ++++++++++++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) diff --git a/aixplain/modules/agent/__init__.py b/aixplain/modules/agent/__init__.py index 2f244d56..2aecc752 100644 --- a/aixplain/modules/agent/__init__.py +++ b/aixplain/modules/agent/__init__.py @@ -28,6 +28,7 @@ from aixplain.utils.file_utils import _request_with_retry from aixplain.enums.supplier import Supplier from aixplain.enums.asset_status import AssetStatus +from aixplain.enums.storage_type import StorageType from aixplain.modules.model import Model from aixplain.modules.agent.tool import Tool from aixplain.modules.agent.tool.model_tool import ModelTool @@ -103,6 +104,7 @@ def run( timeout: float = 300, parameters: Dict = {}, wait_time: float = 0.5, + content_inputs: List[Text] = [], ) -> Dict: """Runs an agent call. @@ -114,13 +116,21 @@ def run( timeout (float, optional): total polling time. Defaults to 300. parameters (Dict, optional): optional parameters to the model. Defaults to "{}". wait_time (float, optional): wait time in seconds between polling calls. Defaults to 0.5. + content_inputs (List[Text], optional): Content inputs to be processed according to the query. Defaults to []. Returns: Dict: parsed output from model """ start = time.time() try: - response = self.run_async(query=query, session_id=session_id, history=history, name=name, parameters=parameters) + response = self.run_async( + query=query, + session_id=session_id, + history=history, + name=name, + parameters=parameters, + content_inputs=content_inputs, + ) if response["status"] == "FAILED": end = time.time() response["elapsed_time"] = end - start @@ -142,6 +152,7 @@ def run_async( history: Optional[List[Dict]] = None, name: Text = "model_process", parameters: Dict = {}, + content_inputs: List[Text] = [], ) -> Dict: """Runs asynchronously an agent call. @@ -151,13 +162,26 @@ def run_async( history (Optional[List[Dict]], optional): chat history (in case session ID is None). Defaults to None. name (Text, optional): ID given to a call. Defaults to "model_process". parameters (Dict, optional): optional parameters to the model. Defaults to "{}". + content_inputs (List[Text], optional): Content inputs to be processed according to the query. Defaults to []. Returns: dict: polling URL in response """ - headers = {"x-api-key": self.api_key, "Content-Type": "application/json"} from aixplain.factories.file_factory import FileFactory + # process content inputs + content_inputs = list(set(content_inputs)) + if len(content_inputs) > 0: + assert ( + FileFactory.check_storage_type(query) == StorageType.TEXT + ), "When providing 'content_inputs', query must be text." + assert len(content_inputs) <= 3, "The maximum number of content inputs is 3." + for input_link in content_inputs: + input_link = FileFactory.to_link(input_link) + query += f"\n{input_link}" + + headers = {"x-api-key": self.api_key, "Content-Type": "application/json"} + payload = {"id": self.id, "query": FileFactory.to_link(query), "sessionId": session_id, "history": history} payload.update(parameters) payload = json.dumps(payload) From f1b339b867d49736cf58d4f2def0314a7da64caf Mon Sep 17 00:00:00 2001 From: Thiago Castro Ferreira Date: Wed, 7 Aug 2024 18:29:14 -0300 Subject: [PATCH 2/4] Add data and query parameters on running agent --- aixplain/modules/agent/__init__.py | 49 +++++++++++++------ .../functional/agent/agent_functional_test.py | 2 +- 2 files changed, 34 insertions(+), 17 deletions(-) diff --git a/aixplain/modules/agent/__init__.py b/aixplain/modules/agent/__init__.py index 2aecc752..8a5cd120 100644 --- a/aixplain/modules/agent/__init__.py +++ b/aixplain/modules/agent/__init__.py @@ -97,26 +97,28 @@ def __init__( def run( self, - query: Text, + data: Optional[Union[Dict, Text]] = None, + query: Optional[Text] = None, session_id: Optional[Text] = None, history: Optional[List[Dict]] = None, name: Text = "model_process", timeout: float = 300, parameters: Dict = {}, wait_time: float = 0.5, - content_inputs: List[Text] = [], + content: List[Text] = [], ) -> Dict: """Runs an agent call. Args: - query (Text): query to be processed by the agent. + data (Optional[Union[Dict, Text]], optional): data to be processed by the agent. Defaults to None. + query (Optional[Text], optional): query to be processed by the agent. Defaults to None. session_id (Optional[Text], optional): conversation Session ID. Defaults to None. history (Optional[List[Dict]], optional): chat history (in case session ID is None). Defaults to None. name (Text, optional): ID given to a call. Defaults to "model_process". timeout (float, optional): total polling time. Defaults to 300. parameters (Dict, optional): optional parameters to the model. Defaults to "{}". wait_time (float, optional): wait time in seconds between polling calls. Defaults to 0.5. - content_inputs (List[Text], optional): Content inputs to be processed according to the query. Defaults to []. + content (List[Text], optional): Content inputs to be processed according to the query. Defaults to []. Returns: Dict: parsed output from model @@ -124,12 +126,13 @@ def run( start = time.time() try: response = self.run_async( + data=data, query=query, session_id=session_id, history=history, name=name, parameters=parameters, - content_inputs=content_inputs, + content=content, ) if response["status"] == "FAILED": end = time.time() @@ -147,36 +150,50 @@ def run( def run_async( self, - query: Text, + data: Optional[Union[Dict, Text]] = None, + query: Optional[Text] = None, session_id: Optional[Text] = None, history: Optional[List[Dict]] = None, name: Text = "model_process", parameters: Dict = {}, - content_inputs: List[Text] = [], + content: List[Text] = [], ) -> Dict: """Runs asynchronously an agent call. Args: - query (Text): query to be processed by the agent. + data (Optional[Union[Dict, Text]], optional): data to be processed by the agent. Defaults to None. + query (Optional[Text], optional): query to be processed by the agent. Defaults to None. session_id (Optional[Text], optional): conversation Session ID. Defaults to None. history (Optional[List[Dict]], optional): chat history (in case session ID is None). Defaults to None. name (Text, optional): ID given to a call. Defaults to "model_process". parameters (Dict, optional): optional parameters to the model. Defaults to "{}". - content_inputs (List[Text], optional): Content inputs to be processed according to the query. Defaults to []. + content (List[Text], optional): Content inputs to be processed according to the query. Defaults to []. Returns: dict: polling URL in response """ from aixplain.factories.file_factory import FileFactory + assert data is not None or query is not None, "Either 'data' or 'query' must be provided." + if data is not None: + if isinstance(data, dict): + assert "query" in data and data["query"] is not None, "When providing a dictionary, 'query' must be provided." + query = data.get("query") + if session_id is None: + session_id = data.get("session_id") + if history is None: + history = data.get("history") + if len(content) == 0: + content = data.get("content", []) + else: + query = data + # process content inputs - content_inputs = list(set(content_inputs)) - if len(content_inputs) > 0: - assert ( - FileFactory.check_storage_type(query) == StorageType.TEXT - ), "When providing 'content_inputs', query must be text." - assert len(content_inputs) <= 3, "The maximum number of content inputs is 3." - for input_link in content_inputs: + content = list(set(content)) + if len(content) > 0: + assert FileFactory.check_storage_type(query) == StorageType.TEXT, "When providing 'content', query must be text." + assert len(content) <= 3, "The maximum number of content inputs is 3." + for input_link in content: input_link = FileFactory.to_link(input_link) query += f"\n{input_link}" diff --git a/tests/functional/agent/agent_functional_test.py b/tests/functional/agent/agent_functional_test.py index 766ba386..1827ef94 100644 --- a/tests/functional/agent/agent_functional_test.py +++ b/tests/functional/agent/agent_functional_test.py @@ -55,7 +55,7 @@ def test_end2end(run_input_map): agent = AgentFactory.create(name=run_input_map["agent_name"], llm_id=run_input_map["llm_id"], tools=tools) print(f"Agent created: {agent.__dict__}") print("Running agent") - response = agent.run(query=run_input_map["query"]) + response = agent.run(data=run_input_map["query"]) print(f"Agent response: {response}") assert response is not None assert response["completed"] is True From 12753d01d88b0b9fb489671e003e07722c8ea93d Mon Sep 17 00:00:00 2001 From: Thiago Castro Ferreira Date: Thu, 8 Aug 2024 13:05:21 -0300 Subject: [PATCH 3/4] Enable processing keys/values in content as well --- aixplain/modules/agent/__init__.py | 30 ++++++++++++++++++------------ 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/aixplain/modules/agent/__init__.py b/aixplain/modules/agent/__init__.py index 8a5cd120..d05e70fb 100644 --- a/aixplain/modules/agent/__init__.py +++ b/aixplain/modules/agent/__init__.py @@ -105,7 +105,7 @@ def run( timeout: float = 300, parameters: Dict = {}, wait_time: float = 0.5, - content: List[Text] = [], + content: Optional[Union[Dict[Text, Text], List[Text]]] = None, ) -> Dict: """Runs an agent call. @@ -118,7 +118,7 @@ def run( timeout (float, optional): total polling time. Defaults to 300. parameters (Dict, optional): optional parameters to the model. Defaults to "{}". wait_time (float, optional): wait time in seconds between polling calls. Defaults to 0.5. - content (List[Text], optional): Content inputs to be processed according to the query. Defaults to []. + content (Union[Dict[Text, Text], List[Text]], optional): Content inputs to be processed according to the query. Defaults to None. Returns: Dict: parsed output from model @@ -156,7 +156,7 @@ def run_async( history: Optional[List[Dict]] = None, name: Text = "model_process", parameters: Dict = {}, - content: List[Text] = [], + content: Optional[Union[Dict[Text, Text], List[Text]]] = None, ) -> Dict: """Runs asynchronously an agent call. @@ -167,7 +167,7 @@ def run_async( history (Optional[List[Dict]], optional): chat history (in case session ID is None). Defaults to None. name (Text, optional): ID given to a call. Defaults to "model_process". parameters (Dict, optional): optional parameters to the model. Defaults to "{}". - content (List[Text], optional): Content inputs to be processed according to the query. Defaults to []. + content (Union[Dict[Text, Text], List[Text]], optional): Content inputs to be processed according to the query. Defaults to None. Returns: dict: polling URL in response @@ -183,19 +183,25 @@ def run_async( session_id = data.get("session_id") if history is None: history = data.get("history") - if len(content) == 0: - content = data.get("content", []) + if content is None: + content = data.get("content") else: query = data # process content inputs - content = list(set(content)) - if len(content) > 0: + if content is not None: assert FileFactory.check_storage_type(query) == StorageType.TEXT, "When providing 'content', query must be text." - assert len(content) <= 3, "The maximum number of content inputs is 3." - for input_link in content: - input_link = FileFactory.to_link(input_link) - query += f"\n{input_link}" + + if isinstance(content, list): + assert len(content) <= 3, "The maximum number of content inputs is 3." + for input_link in content: + input_link = FileFactory.to_link(input_link) + query += f"\n{input_link}" + elif isinstance(content, dict): + for key, value in content.items(): + assert key in query, f"Key '{key}' not found in query." + value = FileFactory.to_link(value) + query = query.replace(key, f"'{value}'") headers = {"x-api-key": self.api_key, "Content-Type": "application/json"} From 4c1238d9d09f4ea73bd342986988acfed6774b28 Mon Sep 17 00:00:00 2001 From: Thiago Castro Ferreira Date: Tue, 13 Aug 2024 13:58:39 -0300 Subject: [PATCH 4/4] Agent units tests and tags simolar Jinja2 --- aixplain/modules/agent/__init__.py | 4 +- tests/unit/agent_test.py | 63 ++++++++++++++++++++++++++++++ 2 files changed, 65 insertions(+), 2 deletions(-) create mode 100644 tests/unit/agent_test.py diff --git a/aixplain/modules/agent/__init__.py b/aixplain/modules/agent/__init__.py index d05e70fb..c0604f6a 100644 --- a/aixplain/modules/agent/__init__.py +++ b/aixplain/modules/agent/__init__.py @@ -199,9 +199,9 @@ def run_async( query += f"\n{input_link}" elif isinstance(content, dict): for key, value in content.items(): - assert key in query, f"Key '{key}' not found in query." + assert "{{" + key + "}}" in query, f"Key '{key}' not found in query." value = FileFactory.to_link(value) - query = query.replace(key, f"'{value}'") + query = query.replace("{{" + key + "}}", f"'{value}'") headers = {"x-api-key": self.api_key, "Content-Type": "application/json"} diff --git a/tests/unit/agent_test.py b/tests/unit/agent_test.py new file mode 100644 index 00000000..680fc21a --- /dev/null +++ b/tests/unit/agent_test.py @@ -0,0 +1,63 @@ +import pytest +import requests_mock +from aixplain.modules import Agent +from aixplain.utils import config + + +def test_fail_no_data_query(): + agent = Agent("123", "Test Agent") + with pytest.raises(Exception) as exc_info: + agent.run_async() + assert str(exc_info.value) == "Either 'data' or 'query' must be provided." + + +def test_fail_query_must_be_provided(): + agent = Agent("123", "Test Agent") + with pytest.raises(Exception) as exc_info: + agent.run_async(data={}) + assert str(exc_info.value) == "When providing a dictionary, 'query' must be provided." + + +def test_fail_query_as_text_when_content_not_empty(): + agent = Agent("123", "Test Agent") + with pytest.raises(Exception) as exc_info: + agent.run_async( + data={"query": "https://aixplain-platform-assets.s3.amazonaws.com/samples/en/CPAC1x2.wav"}, + content=["https://aixplain-platform-assets.s3.amazonaws.com/samples/en/CPAC1x2.wav"], + ) + assert str(exc_info.value) == "When providing 'content', query must be text." + + +def test_fail_content_exceed_maximum(): + agent = Agent("123", "Test Agent") + with pytest.raises(Exception) as exc_info: + agent.run_async( + data={"query": "Transcribe the audios:"}, + content=[ + "https://aixplain-platform-assets.s3.amazonaws.com/samples/en/CPAC1x2.wav", + "https://aixplain-platform-assets.s3.amazonaws.com/samples/en/CPAC1x2.wav", + "https://aixplain-platform-assets.s3.amazonaws.com/samples/en/CPAC1x2.wav", + "https://aixplain-platform-assets.s3.amazonaws.com/samples/en/CPAC1x2.wav", + ], + ) + assert str(exc_info.value) == "The maximum number of content inputs is 3." + + +def test_fail_key_not_found(): + agent = Agent("123", "Test Agent") + with pytest.raises(Exception) as exc_info: + agent.run_async(data={"query": "Translate the text: {{input1}}"}, content={"input2": "Hello, how are you?"}) + assert str(exc_info.value) == "Key 'input2' not found in query." + + +def test_sucess_query_content(): + agent = Agent("123", "Test Agent") + with requests_mock.Mocker() as mock: + url = agent.url + headers = {"x-api-key": config.TEAM_API_KEY, "Content-Type": "application/json"} + ref_response = {"data": "Hello, how are you?", "status": "IN_PROGRESS"} + mock.post(url, headers=headers, json=ref_response) + + response = agent.run_async(data={"query": "Translate the text: {{input1}}"}, content={"input1": "Hello, how are you?"}) + assert response["status"] == ref_response["status"] + assert response["url"] == ref_response["data"]