diff --git a/aixplain/modules/agent/__init__.py b/aixplain/modules/agent/__init__.py index 517f92cc..5e456026 100644 --- a/aixplain/modules/agent/__init__.py +++ b/aixplain/modules/agent/__init__.py @@ -25,6 +25,7 @@ import re import time import traceback +from datetime import datetime from aixplain.utils.file_utils import _request_with_retry from aixplain.enums import Function, Supplier, AssetStatus, StorageType, ResponseStatus @@ -34,7 +35,7 @@ from aixplain.modules.agent.tool import Tool from aixplain.modules.agent.agent_response import AgentResponse from aixplain.modules.agent.agent_response_data import AgentResponseData -from aixplain.modules.agent.utils import process_variables +from aixplain.modules.agent.utils import process_variables, validate_history from pydantic import BaseModel from typing import Dict, List, Text, Optional, Union from urllib.parse import urljoin @@ -164,6 +165,50 @@ def validate(self, raise_exception: bool = False) -> bool: logging.warning(f"Agent Validation Error: {e}") logging.warning("You won't be able to run the Agent until the issues are handled manually.") return self.is_valid + + def generate_session_id(self, history: list = None) -> str: + if history: + validate_history(history) + timestamp = datetime.now().strftime('%Y%m%d%H%M%S') + session_id = f"{self.id}_{timestamp}" + + if not history: + return session_id + + try: + validate_history(history) + headers = {"x-api-key": self.api_key, "Content-Type": "application/json"} + + payload = { + "id": self.id, + "query": "/", + "sessionId": session_id, + "history": history, + "executionParams": { + "maxTokens": 2048, + "maxIterations": 10, + "outputFormat": OutputFormat.TEXT.value, + "expectedOutput": None, + }, + "allowHistoryAndSessionId": True + } + + r = _request_with_retry("post", self.url, headers=headers, data=json.dumps(payload)) + resp = r.json() + poll_url = resp.get("data") + + result = self.sync_poll(poll_url, name="model_process", timeout=300, wait_time=0.5) + + if result.get("status") == ResponseStatus.SUCCESS: + return session_id + else: + logging.error(f"Session {session_id} initialization failed: {result}") + return session_id + + except Exception as e: + logging.error(f"Failed to initialize session {session_id}: {e}") + return session_id + def run( self, @@ -201,6 +246,14 @@ def run( Dict: parsed output from model """ start = time.time() + if session_id is not None and history is not None: + raise ValueError("Provide either `session_id` or `history`, not both.") + + if session_id is not None: + if not session_id.startswith(f"{self.id}_"): + raise ValueError(f"Session ID '{session_id}' does not belong to this Agent.") + if history: + validate_history(history) result_data = {} try: response = self.run_async( @@ -286,6 +339,17 @@ def run_async( Returns: dict: polling URL in response """ + + if session_id is not None and history is not None: + raise ValueError("Provide either `session_id` or `history`, not both.") + + if session_id is not None: + if not session_id.startswith(f"{self.id}_"): + raise ValueError(f"Session ID '{session_id}' does not belong to this Agent.") + + if history: + validate_history(history) + from aixplain.factories.file_factory import FileFactory if not self.is_valid: @@ -352,7 +416,6 @@ def run_async( "expectedOutput": expected_output, }, } - payload.update(parameters) payload = json.dumps(payload) diff --git a/aixplain/modules/agent/utils.py b/aixplain/modules/agent/utils.py index 684c82db..09076014 100644 --- a/aixplain/modules/agent/utils.py +++ b/aixplain/modules/agent/utils.py @@ -29,3 +29,45 @@ def process_variables( input_data[variable] = parameters.pop(variable) return input_data + + +def validate_history(history): + """ + Validates that `history` is a list of dicts, each with 'role' and 'content' keys. + Raises a ValueError if validation fails. + """ + if not isinstance(history, list): + raise ValueError( + "History must be a list of message dictionaries. " + "Example: [{'role': 'user', 'content': 'Hello'}, {'role': 'assistant', 'content': 'Hi there!'}]" + ) + + + allowed_roles = {"user", "assistant"} + + for i, item in enumerate(history): + if not isinstance(item, dict): + raise ValueError( + f"History item at index {i} is not a dict: {item}. " + "Each item must be a dictionary like: {'role': 'user', 'content': 'Hello'}" + ) + + if "role" not in item or "content" not in item: + raise ValueError( + f"History item at index {i} is missing 'role' or 'content': {item}. " + "Example of a valid message: {'role': 'assistant', 'content': 'Hi there!'}" + ) + + if item["role"] not in allowed_roles: + raise ValueError( + f"Invalid role '{item['role']}' at index {i}. Allowed roles: {allowed_roles}. " + "Example: {'role': 'user', 'content': 'Tell me a joke'}" + ) + + if not isinstance(item["content"], str): + raise ValueError( + f"'content' at index {i} must be a string. Got: {type(item['content'])}. " + "Example: {'role': 'assistant', 'content': 'Sure! Here’s one...'}" + ) + + return True diff --git a/aixplain/modules/team_agent/__init__.py b/aixplain/modules/team_agent/__init__.py index 08a1fa91..65108331 100644 --- a/aixplain/modules/team_agent/__init__.py +++ b/aixplain/modules/team_agent/__init__.py @@ -29,6 +29,7 @@ from enum import Enum from typing import Dict, List, Text, Optional, Union from urllib.parse import urljoin +from datetime import datetime from aixplain.enums import ResponseStatus from aixplain.enums.function import Function @@ -39,7 +40,7 @@ from aixplain.modules.agent import Agent, OutputFormat from aixplain.modules.agent.agent_response import AgentResponse from aixplain.modules.agent.agent_response_data import AgentResponseData -from aixplain.modules.agent.utils import process_variables +from aixplain.modules.agent.utils import process_variables, validate_history from aixplain.modules.team_agent.inspector import Inspector from aixplain.utils import config from aixplain.utils.request_utils import _request_with_retry @@ -122,6 +123,47 @@ def __init__( self.output_format = output_format self.expected_output = expected_output + def generate_session_id(self, history: list = None) -> str: + timestamp = datetime.now().strftime('%Y%m%d%H%M%S') + session_id = f"{self.id}_{timestamp}" + + if not history: + return session_id + + try: + validate_history(history) + headers = {"x-api-key": self.api_key, "Content-Type": "application/json"} + + payload = { + "id": self.id, + "query": "/", + "sessionId": session_id, + "history": history, + "executionParams": { + "maxTokens": 2048, + "maxIterations": 30, + "outputFormat": OutputFormat.TEXT.value, + "expectedOutput": None, + }, + "allowHistoryAndSessionId": True + } + + r = _request_with_retry("post", self.url, headers=headers, data=json.dumps(payload)) + resp = r.json() + poll_url = resp.get("data") + + result = self.sync_poll(poll_url, name="model_process", timeout=300, wait_time=0.5) + + if result.get("status") == ResponseStatus.SUCCESS: + return session_id + else: + logging.error(f"Team session init failed for {session_id}: {result}") + return session_id + except Exception as e: + logging.error(f"Failed to initialize team session {session_id}: {e}") + return session_id + + def run( self, data: Optional[Union[Dict, Text]] = None, @@ -159,6 +201,14 @@ def run( """ start = time.time() result_data = {} + if session_id is not None and history is not None: + raise ValueError("Provide either `session_id` or `history`, not both.") + + if session_id is not None: + if not session_id.startswith(f"{self.id}_"): + raise ValueError(f"Session ID '{session_id}' does not belong to this Agent.") + if history: + validate_history(history) try: response = self.run_async( data=data, @@ -235,6 +285,16 @@ def run_async( Returns: dict: polling URL in response """ + if session_id is not None and history is not None: + raise ValueError("Provide either `session_id` or `history`, not both.") + + if session_id is not None: + if not session_id.startswith(f"{self.id}_"): + raise ValueError(f"Session ID '{session_id}' does not belong to this Agent.") + + if history: + validate_history(history) + from aixplain.factories.file_factory import FileFactory if not self.is_valid: