Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions aixplain/modules/agent/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@ def run(
parameters: Dict = {},
wait_time: float = 0.5,
content: Optional[Union[Dict[Text, Text], List[Text]]] = None,
max_tokens: int = 2048,
max_iterations: int = 10,
) -> Dict:
"""Runs an agent call.

Expand All @@ -118,6 +120,8 @@ def run(
parameters (Dict, optional): optional parameters to the model. Defaults to "{}".
wait_time (float, optional): wait time in seconds between polling calls. Defaults to 0.5.
content (Union[Dict[Text, Text], List[Text]], optional): Content inputs to be processed according to the query. Defaults to None.
max_tokens (int, optional): maximum number of tokens which can be generated by the agent. Defaults to 2048.
max_iterations (int, optional): maximum number of iterations between the agent and the tools. Defaults to 10.

Returns:
Dict: parsed output from model
Expand All @@ -132,6 +136,8 @@ def run(
name=name,
parameters=parameters,
content=content,
max_tokens=max_tokens,
max_iterations=max_iterations,
)
if response["status"] == "FAILED":
end = time.time()
Expand All @@ -156,6 +162,8 @@ def run_async(
name: Text = "model_process",
parameters: Dict = {},
content: Optional[Union[Dict[Text, Text], List[Text]]] = None,
max_tokens: int = 2048,
max_iterations: int = 10,
) -> Dict:
"""Runs asynchronously an agent call.

Expand All @@ -167,6 +175,8 @@ def run_async(
name (Text, optional): ID given to a call. Defaults to "model_process".
parameters (Dict, optional): optional parameters to the model. Defaults to "{}".
content (Union[Dict[Text, Text], List[Text]], optional): Content inputs to be processed according to the query. Defaults to None.
max_tokens (int, optional): maximum number of tokens which can be generated by the agent. Defaults to 2048.
max_iterations (int, optional): maximum number of iterations between the agent and the tools. Defaults to 10.

Returns:
dict: polling URL in response
Expand Down Expand Up @@ -205,6 +215,12 @@ def run_async(
headers = {"x-api-key": self.api_key, "Content-Type": "application/json"}

payload = {"id": self.id, "query": FileFactory.to_link(query), "sessionId": session_id, "history": history}
parameters.update(
{
"max_tokens": parameters["max_tokens"] if "max_tokens" in parameters else max_tokens,
"max_iterations": parameters["max_iterations"] if "max_iterations" in parameters else max_iterations,
}
)
payload.update(parameters)
payload = json.dumps(payload)

Expand Down
17 changes: 16 additions & 1 deletion aixplain/modules/team_agent/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,8 @@ def run(
parameters: Dict = {},
wait_time: float = 0.5,
content: Optional[Union[Dict[Text, Text], List[Text]]] = None,
max_tokens: int = 2048,
max_iterations: int = 30,
) -> Dict:
"""Runs a team agent call.

Expand All @@ -121,7 +123,8 @@ def run(
parameters (Dict, optional): optional parameters to the model. Defaults to "{}".
wait_time (float, optional): wait time in seconds between polling calls. Defaults to 0.5.
content (Union[Dict[Text, Text], List[Text]], optional): Content inputs to be processed according to the query. Defaults to None.

max_tokens (int, optional): maximum number of tokens which can be generated by the agents. Defaults to 2048.
max_iterations (int, optional): maximum number of iterations between the agents. Defaults to 30.
Returns:
Dict: parsed output from model
"""
Expand All @@ -135,6 +138,8 @@ def run(
name=name,
parameters=parameters,
content=content,
max_tokens=max_tokens,
max_iterations=max_iterations,
)
if response["status"] == "FAILED":
end = time.time()
Expand All @@ -159,6 +164,8 @@ def run_async(
name: Text = "model_process",
parameters: Dict = {},
content: Optional[Union[Dict[Text, Text], List[Text]]] = None,
max_tokens: int = 2048,
max_iterations: int = 30,
) -> Dict:
"""Runs asynchronously a Team Agent call.

Expand All @@ -170,6 +177,8 @@ def run_async(
name (Text, optional): ID given to a call. Defaults to "model_process".
parameters (Dict, optional): optional parameters to the model. Defaults to "{}".
content (Union[Dict[Text, Text], List[Text]], optional): Content inputs to be processed according to the query. Defaults to None.
max_tokens (int, optional): maximum number of tokens which can be generated by the agents. Defaults to 2048.
max_iterations (int, optional): maximum number of iterations between the agents. Defaults to 30.

Returns:
dict: polling URL in response
Expand Down Expand Up @@ -208,6 +217,12 @@ def run_async(
headers = {"x-api-key": self.api_key, "Content-Type": "application/json"}

payload = {"id": self.id, "query": FileFactory.to_link(query), "sessionId": session_id, "history": history}
parameters.update(
{
"max_tokens": parameters["max_tokens"] if "max_tokens" in parameters else max_tokens,
"max_iterations": parameters["max_iterations"] if "max_iterations" in parameters else max_iterations,
}
)
payload.update(parameters)
payload = json.dumps(payload)

Expand Down