Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 47 additions & 6 deletions aixplain/modules/agent/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from aixplain.utils.file_utils import _request_with_retry
from aixplain.enums.supplier import Supplier
from aixplain.enums.asset_status import AssetStatus
from aixplain.enums.storage_type import StorageType
from aixplain.modules.model import Model
from aixplain.modules.agent.tool import Tool
from aixplain.modules.agent.tool.model_tool import ModelTool
Expand Down Expand Up @@ -96,31 +97,43 @@ def __init__(

def run(
self,
query: Text,
data: Optional[Union[Dict, Text]] = None,
query: Optional[Text] = None,
session_id: Optional[Text] = None,
history: Optional[List[Dict]] = None,
name: Text = "model_process",
timeout: float = 300,
parameters: Dict = {},
wait_time: float = 0.5,
content: List[Text] = [],
) -> Dict:
"""Runs an agent call.

Args:
query (Text): query to be processed by the agent.
data (Optional[Union[Dict, Text]], optional): data to be processed by the agent. Defaults to None.
query (Optional[Text], optional): query to be processed by the agent. Defaults to None.
session_id (Optional[Text], optional): conversation Session ID. Defaults to None.
history (Optional[List[Dict]], optional): chat history (in case session ID is None). Defaults to None.
name (Text, optional): ID given to a call. Defaults to "model_process".
timeout (float, optional): total polling time. Defaults to 300.
parameters (Dict, optional): optional parameters to the model. Defaults to "{}".
wait_time (float, optional): wait time in seconds between polling calls. Defaults to 0.5.
content (List[Text], optional): Content inputs to be processed according to the query. Defaults to [].

Returns:
Dict: parsed output from model
"""
start = time.time()
try:
response = self.run_async(query=query, session_id=session_id, history=history, name=name, parameters=parameters)
response = self.run_async(
data=data,
query=query,
session_id=session_id,
history=history,
name=name,
parameters=parameters,
content=content,
)
if response["status"] == "FAILED":
end = time.time()
response["elapsed_time"] = end - start
Expand All @@ -137,27 +150,55 @@ def run(

def run_async(
self,
query: Text,
data: Optional[Union[Dict, Text]] = None,
query: Optional[Text] = None,
session_id: Optional[Text] = None,
history: Optional[List[Dict]] = None,
name: Text = "model_process",
parameters: Dict = {},
content: List[Text] = [],
) -> Dict:
"""Runs asynchronously an agent call.

Args:
query (Text): query to be processed by the agent.
data (Optional[Union[Dict, Text]], optional): data to be processed by the agent. Defaults to None.
query (Optional[Text], optional): query to be processed by the agent. Defaults to None.
session_id (Optional[Text], optional): conversation Session ID. Defaults to None.
history (Optional[List[Dict]], optional): chat history (in case session ID is None). Defaults to None.
name (Text, optional): ID given to a call. Defaults to "model_process".
parameters (Dict, optional): optional parameters to the model. Defaults to "{}".
content (List[Text], optional): Content inputs to be processed according to the query. Defaults to [].

Returns:
dict: polling URL in response
"""
headers = {"x-api-key": self.api_key, "Content-Type": "application/json"}
from aixplain.factories.file_factory import FileFactory

assert data is not None or query is not None, "Either 'data' or 'query' must be provided."
if data is not None:
if isinstance(data, dict):
assert "query" in data and data["query"] is not None, "When providing a dictionary, 'query' must be provided."
query = data.get("query")
if session_id is None:
session_id = data.get("session_id")
if history is None:
history = data.get("history")
if len(content) == 0:
content = data.get("content", [])
else:
query = data

# process content inputs
content = list(set(content))
if len(content) > 0:
assert FileFactory.check_storage_type(query) == StorageType.TEXT, "When providing 'content', query must be text."
assert len(content) <= 3, "The maximum number of content inputs is 3."
for input_link in content:
input_link = FileFactory.to_link(input_link)
query += f"\n{input_link}"

headers = {"x-api-key": self.api_key, "Content-Type": "application/json"}

payload = {"id": self.id, "query": FileFactory.to_link(query), "sessionId": session_id, "history": history}
payload.update(parameters)
payload = json.dumps(payload)
Expand Down
2 changes: 1 addition & 1 deletion tests/functional/agent/agent_functional_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def test_end2end(run_input_map):
agent = AgentFactory.create(name=run_input_map["agent_name"], llm_id=run_input_map["llm_id"], tools=tools)
print(f"Agent created: {agent.__dict__}")
print("Running agent")
response = agent.run(query=run_input_map["query"])
response = agent.run(data=run_input_map["query"])
print(f"Agent response: {response}")
assert response is not None
assert response["completed"] is True
Expand Down