diff --git a/aixplain/modules/agent/__init__.py b/aixplain/modules/agent/__init__.py index 8a5cd120..c0604f6a 100644 --- a/aixplain/modules/agent/__init__.py +++ b/aixplain/modules/agent/__init__.py @@ -105,7 +105,7 @@ def run( timeout: float = 300, parameters: Dict = {}, wait_time: float = 0.5, - content: List[Text] = [], + content: Optional[Union[Dict[Text, Text], List[Text]]] = None, ) -> Dict: """Runs an agent call. @@ -118,7 +118,7 @@ def run( timeout (float, optional): total polling time. Defaults to 300. parameters (Dict, optional): optional parameters to the model. Defaults to "{}". wait_time (float, optional): wait time in seconds between polling calls. Defaults to 0.5. - content (List[Text], optional): Content inputs to be processed according to the query. Defaults to []. + content (Union[Dict[Text, Text], List[Text]], optional): Content inputs to be processed according to the query. Defaults to None. Returns: Dict: parsed output from model @@ -156,7 +156,7 @@ def run_async( history: Optional[List[Dict]] = None, name: Text = "model_process", parameters: Dict = {}, - content: List[Text] = [], + content: Optional[Union[Dict[Text, Text], List[Text]]] = None, ) -> Dict: """Runs asynchronously an agent call. @@ -167,7 +167,7 @@ def run_async( history (Optional[List[Dict]], optional): chat history (in case session ID is None). Defaults to None. name (Text, optional): ID given to a call. Defaults to "model_process". parameters (Dict, optional): optional parameters to the model. Defaults to "{}". - content (List[Text], optional): Content inputs to be processed according to the query. Defaults to []. + content (Union[Dict[Text, Text], List[Text]], optional): Content inputs to be processed according to the query. Defaults to None. Returns: dict: polling URL in response @@ -183,19 +183,25 @@ def run_async( session_id = data.get("session_id") if history is None: history = data.get("history") - if len(content) == 0: - content = data.get("content", []) + if content is None: + content = data.get("content") else: query = data # process content inputs - content = list(set(content)) - if len(content) > 0: + if content is not None: assert FileFactory.check_storage_type(query) == StorageType.TEXT, "When providing 'content', query must be text." - assert len(content) <= 3, "The maximum number of content inputs is 3." - for input_link in content: - input_link = FileFactory.to_link(input_link) - query += f"\n{input_link}" + + if isinstance(content, list): + assert len(content) <= 3, "The maximum number of content inputs is 3." + for input_link in content: + input_link = FileFactory.to_link(input_link) + query += f"\n{input_link}" + elif isinstance(content, dict): + for key, value in content.items(): + assert "{{" + key + "}}" in query, f"Key '{key}' not found in query." + value = FileFactory.to_link(value) + query = query.replace("{{" + key + "}}", f"'{value}'") headers = {"x-api-key": self.api_key, "Content-Type": "application/json"} diff --git a/tests/unit/agent_test.py b/tests/unit/agent_test.py new file mode 100644 index 00000000..680fc21a --- /dev/null +++ b/tests/unit/agent_test.py @@ -0,0 +1,63 @@ +import pytest +import requests_mock +from aixplain.modules import Agent +from aixplain.utils import config + + +def test_fail_no_data_query(): + agent = Agent("123", "Test Agent") + with pytest.raises(Exception) as exc_info: + agent.run_async() + assert str(exc_info.value) == "Either 'data' or 'query' must be provided." + + +def test_fail_query_must_be_provided(): + agent = Agent("123", "Test Agent") + with pytest.raises(Exception) as exc_info: + agent.run_async(data={}) + assert str(exc_info.value) == "When providing a dictionary, 'query' must be provided." + + +def test_fail_query_as_text_when_content_not_empty(): + agent = Agent("123", "Test Agent") + with pytest.raises(Exception) as exc_info: + agent.run_async( + data={"query": "https://aixplain-platform-assets.s3.amazonaws.com/samples/en/CPAC1x2.wav"}, + content=["https://aixplain-platform-assets.s3.amazonaws.com/samples/en/CPAC1x2.wav"], + ) + assert str(exc_info.value) == "When providing 'content', query must be text." + + +def test_fail_content_exceed_maximum(): + agent = Agent("123", "Test Agent") + with pytest.raises(Exception) as exc_info: + agent.run_async( + data={"query": "Transcribe the audios:"}, + content=[ + "https://aixplain-platform-assets.s3.amazonaws.com/samples/en/CPAC1x2.wav", + "https://aixplain-platform-assets.s3.amazonaws.com/samples/en/CPAC1x2.wav", + "https://aixplain-platform-assets.s3.amazonaws.com/samples/en/CPAC1x2.wav", + "https://aixplain-platform-assets.s3.amazonaws.com/samples/en/CPAC1x2.wav", + ], + ) + assert str(exc_info.value) == "The maximum number of content inputs is 3." + + +def test_fail_key_not_found(): + agent = Agent("123", "Test Agent") + with pytest.raises(Exception) as exc_info: + agent.run_async(data={"query": "Translate the text: {{input1}}"}, content={"input2": "Hello, how are you?"}) + assert str(exc_info.value) == "Key 'input2' not found in query." + + +def test_sucess_query_content(): + agent = Agent("123", "Test Agent") + with requests_mock.Mocker() as mock: + url = agent.url + headers = {"x-api-key": config.TEAM_API_KEY, "Content-Type": "application/json"} + ref_response = {"data": "Hello, how are you?", "status": "IN_PROGRESS"} + mock.post(url, headers=headers, json=ref_response) + + response = agent.run_async(data={"query": "Translate the text: {{input1}}"}, content={"input1": "Hello, how are you?"}) + assert response["status"] == ref_response["status"] + assert response["url"] == ref_response["data"]