Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 65 additions & 2 deletions aixplain/modules/agent/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import re
import time
import traceback
from datetime import datetime

from aixplain.utils.file_utils import _request_with_retry
from aixplain.enums import Function, Supplier, AssetStatus, StorageType, ResponseStatus
Expand All @@ -34,7 +35,7 @@
from aixplain.modules.agent.tool import Tool
from aixplain.modules.agent.agent_response import AgentResponse
from aixplain.modules.agent.agent_response_data import AgentResponseData
from aixplain.modules.agent.utils import process_variables
from aixplain.modules.agent.utils import process_variables, validate_history
from pydantic import BaseModel
from typing import Dict, List, Text, Optional, Union
from urllib.parse import urljoin
Expand Down Expand Up @@ -164,6 +165,50 @@ def validate(self, raise_exception: bool = False) -> bool:
logging.warning(f"Agent Validation Error: {e}")
logging.warning("You won't be able to run the Agent until the issues are handled manually.")
return self.is_valid

def generate_session_id(self, history: list = None) -> str:
if history:
validate_history(history)
timestamp = datetime.now().strftime('%Y%m%d%H%M%S')
session_id = f"{self.id}_{timestamp}"

if not history:
return session_id

try:
validate_history(history)
headers = {"x-api-key": self.api_key, "Content-Type": "application/json"}

payload = {
"id": self.id,
"query": "/",
"sessionId": session_id,
"history": history,
"executionParams": {
"maxTokens": 2048,
"maxIterations": 10,
"outputFormat": OutputFormat.TEXT.value,
"expectedOutput": None,
},
"allowHistoryAndSessionId": True
}

r = _request_with_retry("post", self.url, headers=headers, data=json.dumps(payload))
resp = r.json()
poll_url = resp.get("data")

result = self.sync_poll(poll_url, name="model_process", timeout=300, wait_time=0.5)

if result.get("status") == ResponseStatus.SUCCESS:
return session_id
else:
logging.error(f"Session {session_id} initialization failed: {result}")
return session_id

except Exception as e:
logging.error(f"Failed to initialize session {session_id}: {e}")
return session_id


def run(
self,
Expand Down Expand Up @@ -201,6 +246,14 @@ def run(
Dict: parsed output from model
"""
start = time.time()
if session_id is not None and history is not None:
raise ValueError("Provide either `session_id` or `history`, not both.")

if session_id is not None:
if not session_id.startswith(f"{self.id}_"):
raise ValueError(f"Session ID '{session_id}' does not belong to this Agent.")
if history:
validate_history(history)
result_data = {}
try:
response = self.run_async(
Expand Down Expand Up @@ -286,6 +339,17 @@ def run_async(
Returns:
dict: polling URL in response
"""

if session_id is not None and history is not None:
raise ValueError("Provide either `session_id` or `history`, not both.")

if session_id is not None:
if not session_id.startswith(f"{self.id}_"):
raise ValueError(f"Session ID '{session_id}' does not belong to this Agent.")

if history:
validate_history(history)

from aixplain.factories.file_factory import FileFactory

if not self.is_valid:
Expand Down Expand Up @@ -352,7 +416,6 @@ def run_async(
"expectedOutput": expected_output,
},
}

payload.update(parameters)
payload = json.dumps(payload)

Expand Down
42 changes: 42 additions & 0 deletions aixplain/modules/agent/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,45 @@ def process_variables(
input_data[variable] = parameters.pop(variable)

return input_data


def validate_history(history):
"""
Validates that `history` is a list of dicts, each with 'role' and 'content' keys.
Raises a ValueError if validation fails.
"""
if not isinstance(history, list):
raise ValueError(
"History must be a list of message dictionaries. "
"Example: [{'role': 'user', 'content': 'Hello'}, {'role': 'assistant', 'content': 'Hi there!'}]"
)


allowed_roles = {"user", "assistant"}

for i, item in enumerate(history):
if not isinstance(item, dict):
raise ValueError(
f"History item at index {i} is not a dict: {item}. "
"Each item must be a dictionary like: {'role': 'user', 'content': 'Hello'}"
)

if "role" not in item or "content" not in item:
raise ValueError(
f"History item at index {i} is missing 'role' or 'content': {item}. "
"Example of a valid message: {'role': 'assistant', 'content': 'Hi there!'}"
)

if item["role"] not in allowed_roles:
raise ValueError(
f"Invalid role '{item['role']}' at index {i}. Allowed roles: {allowed_roles}. "
"Example: {'role': 'user', 'content': 'Tell me a joke'}"
)

if not isinstance(item["content"], str):
raise ValueError(
f"'content' at index {i} must be a string. Got: {type(item['content'])}. "
"Example: {'role': 'assistant', 'content': 'Sure! Here’s one...'}"
)

return True
62 changes: 61 additions & 1 deletion aixplain/modules/team_agent/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from enum import Enum
from typing import Dict, List, Text, Optional, Union
from urllib.parse import urljoin
from datetime import datetime

from aixplain.enums import ResponseStatus
from aixplain.enums.function import Function
Expand All @@ -39,7 +40,7 @@
from aixplain.modules.agent import Agent, OutputFormat
from aixplain.modules.agent.agent_response import AgentResponse
from aixplain.modules.agent.agent_response_data import AgentResponseData
from aixplain.modules.agent.utils import process_variables
from aixplain.modules.agent.utils import process_variables, validate_history
from aixplain.modules.team_agent.inspector import Inspector
from aixplain.utils import config
from aixplain.utils.request_utils import _request_with_retry
Expand Down Expand Up @@ -122,6 +123,47 @@ def __init__(
self.output_format = output_format
self.expected_output = expected_output

def generate_session_id(self, history: list = None) -> str:
timestamp = datetime.now().strftime('%Y%m%d%H%M%S')
session_id = f"{self.id}_{timestamp}"

if not history:
return session_id

try:
validate_history(history)
headers = {"x-api-key": self.api_key, "Content-Type": "application/json"}

payload = {
"id": self.id,
"query": "/",
"sessionId": session_id,
"history": history,
"executionParams": {
"maxTokens": 2048,
"maxIterations": 30,
"outputFormat": OutputFormat.TEXT.value,
"expectedOutput": None,
},
"allowHistoryAndSessionId": True
}

r = _request_with_retry("post", self.url, headers=headers, data=json.dumps(payload))
resp = r.json()
poll_url = resp.get("data")

result = self.sync_poll(poll_url, name="model_process", timeout=300, wait_time=0.5)

if result.get("status") == ResponseStatus.SUCCESS:
return session_id
else:
logging.error(f"Team session init failed for {session_id}: {result}")
return session_id
except Exception as e:
logging.error(f"Failed to initialize team session {session_id}: {e}")
return session_id


def run(
self,
data: Optional[Union[Dict, Text]] = None,
Expand Down Expand Up @@ -159,6 +201,14 @@ def run(
"""
start = time.time()
result_data = {}
if session_id is not None and history is not None:
raise ValueError("Provide either `session_id` or `history`, not both.")

if session_id is not None:
if not session_id.startswith(f"{self.id}_"):
raise ValueError(f"Session ID '{session_id}' does not belong to this Agent.")
if history:
validate_history(history)
try:
response = self.run_async(
data=data,
Expand Down Expand Up @@ -235,6 +285,16 @@ def run_async(
Returns:
dict: polling URL in response
"""
if session_id is not None and history is not None:
raise ValueError("Provide either `session_id` or `history`, not both.")

if session_id is not None:
if not session_id.startswith(f"{self.id}_"):
raise ValueError(f"Session ID '{session_id}' does not belong to this Agent.")

if history:
validate_history(history)

from aixplain.factories.file_factory import FileFactory

if not self.is_valid:
Expand Down