Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 18 additions & 12 deletions aixplain/modules/agent/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def run(
timeout: float = 300,
parameters: Dict = {},
wait_time: float = 0.5,
content: List[Text] = [],
content: Optional[Union[Dict[Text, Text], List[Text]]] = None,
) -> Dict:
"""Runs an agent call.

Expand All @@ -118,7 +118,7 @@ def run(
timeout (float, optional): total polling time. Defaults to 300.
parameters (Dict, optional): optional parameters to the model. Defaults to "{}".
wait_time (float, optional): wait time in seconds between polling calls. Defaults to 0.5.
content (List[Text], optional): Content inputs to be processed according to the query. Defaults to [].
content (Union[Dict[Text, Text], List[Text]], optional): Content inputs to be processed according to the query. Defaults to None.

Returns:
Dict: parsed output from model
Expand Down Expand Up @@ -156,7 +156,7 @@ def run_async(
history: Optional[List[Dict]] = None,
name: Text = "model_process",
parameters: Dict = {},
content: List[Text] = [],
content: Optional[Union[Dict[Text, Text], List[Text]]] = None,
) -> Dict:
"""Runs asynchronously an agent call.

Expand All @@ -167,7 +167,7 @@ def run_async(
history (Optional[List[Dict]], optional): chat history (in case session ID is None). Defaults to None.
name (Text, optional): ID given to a call. Defaults to "model_process".
parameters (Dict, optional): optional parameters to the model. Defaults to "{}".
content (List[Text], optional): Content inputs to be processed according to the query. Defaults to [].
content (Union[Dict[Text, Text], List[Text]], optional): Content inputs to be processed according to the query. Defaults to None.

Returns:
dict: polling URL in response
Expand All @@ -183,19 +183,25 @@ def run_async(
session_id = data.get("session_id")
if history is None:
history = data.get("history")
if len(content) == 0:
content = data.get("content", [])
if content is None:
content = data.get("content")
else:
query = data

# process content inputs
content = list(set(content))
if len(content) > 0:
if content is not None:
assert FileFactory.check_storage_type(query) == StorageType.TEXT, "When providing 'content', query must be text."
assert len(content) <= 3, "The maximum number of content inputs is 3."
for input_link in content:
input_link = FileFactory.to_link(input_link)
query += f"\n{input_link}"

if isinstance(content, list):
assert len(content) <= 3, "The maximum number of content inputs is 3."
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tim-nelson, we need to set this max number of content in the documentation.

for input_link in content:
input_link = FileFactory.to_link(input_link)
query += f"\n{input_link}"
elif isinstance(content, dict):
for key, value in content.items():
assert "{{" + key + "}}" in query, f"Key '{key}' not found in query."
value = FileFactory.to_link(value)
query = query.replace("{{" + key + "}}", f"'{value}'")

headers = {"x-api-key": self.api_key, "Content-Type": "application/json"}

Expand Down
63 changes: 63 additions & 0 deletions tests/unit/agent_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import pytest
import requests_mock
from aixplain.modules import Agent
from aixplain.utils import config


def test_fail_no_data_query():
agent = Agent("123", "Test Agent")
with pytest.raises(Exception) as exc_info:
agent.run_async()
assert str(exc_info.value) == "Either 'data' or 'query' must be provided."


def test_fail_query_must_be_provided():
agent = Agent("123", "Test Agent")
with pytest.raises(Exception) as exc_info:
agent.run_async(data={})
assert str(exc_info.value) == "When providing a dictionary, 'query' must be provided."


def test_fail_query_as_text_when_content_not_empty():
agent = Agent("123", "Test Agent")
with pytest.raises(Exception) as exc_info:
agent.run_async(
data={"query": "https://aixplain-platform-assets.s3.amazonaws.com/samples/en/CPAC1x2.wav"},
content=["https://aixplain-platform-assets.s3.amazonaws.com/samples/en/CPAC1x2.wav"],
)
assert str(exc_info.value) == "When providing 'content', query must be text."


def test_fail_content_exceed_maximum():
agent = Agent("123", "Test Agent")
with pytest.raises(Exception) as exc_info:
agent.run_async(
data={"query": "Transcribe the audios:"},
content=[
"https://aixplain-platform-assets.s3.amazonaws.com/samples/en/CPAC1x2.wav",
"https://aixplain-platform-assets.s3.amazonaws.com/samples/en/CPAC1x2.wav",
"https://aixplain-platform-assets.s3.amazonaws.com/samples/en/CPAC1x2.wav",
"https://aixplain-platform-assets.s3.amazonaws.com/samples/en/CPAC1x2.wav",
],
)
assert str(exc_info.value) == "The maximum number of content inputs is 3."


def test_fail_key_not_found():
agent = Agent("123", "Test Agent")
with pytest.raises(Exception) as exc_info:
agent.run_async(data={"query": "Translate the text: {{input1}}"}, content={"input2": "Hello, how are you?"})
assert str(exc_info.value) == "Key 'input2' not found in query."


def test_sucess_query_content():
agent = Agent("123", "Test Agent")
with requests_mock.Mocker() as mock:
url = agent.url
headers = {"x-api-key": config.TEAM_API_KEY, "Content-Type": "application/json"}
ref_response = {"data": "Hello, how are you?", "status": "IN_PROGRESS"}
mock.post(url, headers=headers, json=ref_response)

response = agent.run_async(data={"query": "Translate the text: {{input1}}"}, content={"input1": "Hello, how are you?"})
assert response["status"] == ref_response["status"]
assert response["url"] == ref_response["data"]