diff --git a/.github/workflows/docs.yaml b/.github/workflows/docs.yaml new file mode 100644 index 00000000..65658e18 --- /dev/null +++ b/.github/workflows/docs.yaml @@ -0,0 +1,45 @@ +name: Generate API docs + +on: + push: + branches: + - main + +permissions: + contents: write + pull-requests: write + +jobs: + build-docs: + runs-on: ubuntu-latest + + steps: + - name: Checkout repo + uses: actions/checkout@v4 + with: + persist-credentials: false + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.x' + + - name: Install dependencies + run: | + pip install pydoc-markdown + + - name: Run pydoc-markdown + run: | + pydoc-markdown pydoc-markdown.yml + + - name: Create Pull Request if docs changed + uses: peter-evans/create-pull-request@v7 + with: + branch: docs/regenerate-api-docs + commit-message: "chore(docs): regenerate API docs" + title: "chore(docs): regenerate API docs" + body: | + This PR was automatically generated by the workflow to regenerate the API documentation. + add-paths: | + docs/api-reference/python/** + delete-branch: true diff --git a/.github/workflows/main.yaml b/.github/workflows/main.yaml index 9286cae1..e1214b24 100644 --- a/.github/workflows/main.yaml +++ b/.github/workflows/main.yaml @@ -6,7 +6,7 @@ on: - main # any branch other than main, will use the test key - test - workflow_dispatch: + workflow_dispatch: jobs: setup-and-test: @@ -31,7 +31,8 @@ jobs: 'finetune_v2', 'general_assets', 'apikey', - 'agent_and_team_agent', + 'agent', + 'team_agent', ] include: - test-suite: 'unit' @@ -79,9 +80,13 @@ jobs: - test-suite: 'apikey' path: 'tests/functional/apikey' timeout: 45 - - test-suite: 'agent_and_team_agent' - path: 'tests/functional/agent tests/functional/team_agent' + - test-suite: 'agent' + path: 'tests/functional/agent' timeout: 45 + - test-suite: 'team_agent' + path: 'tests/functional/team_agent' + timeout: 45 + steps: - name: Checkout repository uses: actions/checkout@v4 @@ -91,7 +96,7 @@ jobs: with: python-version: "3.9" cache: 'pip' - + - name: Install dependencies run: | python -m pip install --upgrade pip @@ -112,7 +117,7 @@ jobs: fi echo "SLACK_TOKEN=${{ secrets.SLACK_TOKEN }}" >> $GITHUB_ENV echo "HF_TOKEN=${{ secrets.HF_TOKEN }}" >> $GITHUB_ENV - + - name: Run Tests timeout-minutes: ${{ matrix.timeout }} run: python -m pytest ${{ matrix.path }} diff --git a/aixplain/factories/agent_factory/__init__.py b/aixplain/factories/agent_factory/__init__.py index 05ef1f2b..ae0158e1 100644 --- a/aixplain/factories/agent_factory/__init__.py +++ b/aixplain/factories/agent_factory/__init__.py @@ -1,7 +1,4 @@ -__author__ = "lucaspavanelli" - -""" -Copyright 2024 The aiXplain SDK authors +"""Copyright 2024 The aiXplain SDK authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -21,6 +18,8 @@ Agent Factory Class """ +__author__ = "lucaspavanelli" + import json import logging import warnings @@ -49,6 +48,19 @@ from aixplain.enums import DatabaseSourceType +def to_literal_text(x): + """Convert value to literal text, escaping braces for string formatting. + + Args: + x: Value to convert (dict, list, or any other type) + + Returns: + str: Escaped string representation + """ + s = json.dumps(x, ensure_ascii=False, indent=2) if isinstance(x, (dict, list)) else str(x) + return s.replace("{", "{{").replace("}", "}}") + + class AgentFactory: """Factory class for creating and managing agents in the aiXplain system. 
@@ -63,7 +75,6 @@ def create( description: Text, instructions: Optional[Text] = None, llm: Optional[Union[LLM, Text]] = None, - llm_id: Optional[Text] = None, tools: Optional[List[Union[Tool, Model]]] = None, api_key: Text = config.TEAM_API_KEY, supplier: Union[Dict, Text, Supplier, int] = "aiXplain", @@ -72,6 +83,7 @@ def create( workflow_tasks: Optional[List[WorkflowTask]] = None, output_format: Optional[OutputFormat] = None, expected_output: Optional[Union[BaseModel, Text, dict]] = None, + **kwargs, ) -> Agent: """Create a new agent in the platform. @@ -85,14 +97,15 @@ def create( description (Text): description of the agent instructions. instructions (Text): instructions of the agent. llm (Optional[Union[LLM, Text]], optional): LLM instance to use as an object or as an ID. - llm_id (Optional[Text], optional): ID of LLM to use if no LLM instance provided. Defaults to None. tools (List[Union[Tool, Model]], optional): list of tool for the agent. Defaults to []. api_key (Text, optional): team/user API key. Defaults to config.TEAM_API_KEY. supplier (Union[Dict, Text, Supplier, int], optional): owner of the agent. Defaults to "aiXplain". version (Optional[Text], optional): version of the agent. Defaults to None. + tasks (List[WorkflowTask], optional): Deprecated. Use workflow_tasks instead. Defaults to None. workflow_tasks (List[WorkflowTask], optional): list of tasks for the agent. Defaults to []. output_format (OutputFormat, optional): default output format for agent responses. Defaults to OutputFormat.TEXT. expected_output (Union[BaseModel, Text, dict], optional): expected output. Defaults to None. + **kwargs: Additional keyword arguments. Returns: Agent: created Agent """ @@ -100,6 +113,16 @@ def create( workflow_tasks = [] if workflow_tasks is None else list(workflow_tasks) from aixplain.utils.llm_utils import get_llm_instance + # Extract llm_id from kwargs if present (deprecated parameter) + llm_id = kwargs.get("llm_id", None) + if llm_id is not None: + warnings.warn( + "The 'llm_id' parameter is deprecated and will be removed in a future version. " + "Use the 'llm' parameter instead by passing the LLM ID or LLM instance directly.", + DeprecationWarning, + stacklevel=2, + ) + if llm is None and llm_id is not None: llm = get_llm_instance(llm_id, api_key=api_key, use_cache=True) elif llm is None: @@ -112,9 +135,7 @@ def create( ), "'expected_output' must be a Pydantic BaseModel or a JSON object when 'output_format' is JSON." warnings.warn( - "Use `llm` to define the large language model (aixplain.modules.model.llm_model.LLM) to be used as agent. " - "Use `llm_id` to provide the model ID of the large language model to be used as agent. " - "Note: In upcoming releases, `llm` will become a required parameter.", + "The 'llm_id' parameter is deprecated; use `llm` to define the large language model in agents.", UserWarning, ) from aixplain.factories.agent_factory.utils import ( @@ -133,7 +154,8 @@ def create( if tasks is not None: warnings.warn( - "The 'tasks' parameter is deprecated and will be removed in a future version. " "Use 'workflow_tasks' instead.", + "The 'tasks' parameter is deprecated and will be removed in a future version. 
" + "Use 'workflow_tasks' instead.", DeprecationWarning, stacklevel=2, ) @@ -145,7 +167,7 @@ def create( "name": name, "assets": [build_tool_payload(tool) for tool in tools], "description": description, - "instructions": instructions or description, + "instructions": instructions if instructions is not None else description, "supplier": supplier, "version": version, "llmId": llm_id, @@ -228,6 +250,17 @@ def create_workflow_task( expected_output: Text, dependencies: Optional[List[Text]] = None, ) -> WorkflowTask: + """Create a new workflow task for an agent. + + Args: + name (Text): Name of the task + description (Text): Description of what the task does + expected_output (Text): Expected output format or content + dependencies (Optional[List[Text]], optional): List of task names this task depends on. Defaults to None. + + Returns: + WorkflowTask: Created workflow task object + """ dependencies = [] if dependencies is None else list(dependencies) return WorkflowTask( name=name, @@ -238,6 +271,11 @@ def create_workflow_task( @classmethod def create_task(cls, *args, **kwargs): + """Create a workflow task (deprecated - use create_workflow_task instead). + + .. deprecated:: + Use :meth:`create_workflow_task` instead. + """ warnings.warn( "The 'create_task' method is deprecated and will be removed in a future version. " "Use 'create_workflow_task' instead.", @@ -351,7 +389,7 @@ def create_sql_tool( tables: Optional[List[Text]] = None, enable_commit: bool = False, ) -> SQLTool: - """Create a new SQL tool + """Create a new SQL tool. Args: name (Text): name of the tool @@ -361,6 +399,7 @@ def create_sql_tool( schema (Optional[Text], optional): database schema description tables (Optional[List[Text]], optional): table names to work with (optional) enable_commit (bool, optional): enable to modify the database (optional) + Returns: SQLTool: created SQLTool @@ -403,7 +442,9 @@ def create_sql_tool( # Already the correct type, no conversion needed pass else: - raise SQLToolError(f"Source type must be either a string or DatabaseSourceType enum, got {type(source_type)}") + raise SQLToolError( + f"Source type must be either a string or DatabaseSourceType enum, got {type(source_type)}" + ) database_path = None # Final database path to pass to SQLTool diff --git a/aixplain/factories/team_agent_factory/__init__.py b/aixplain/factories/team_agent_factory/__init__.py index 60050813..e9d2ba7a 100644 --- a/aixplain/factories/team_agent_factory/__init__.py +++ b/aixplain/factories/team_agent_factory/__init__.py @@ -23,6 +23,7 @@ import json import logging +import warnings from typing import Dict, List, Optional, Text, Union from urllib.parse import urljoin @@ -54,15 +55,12 @@ def create( cls, name: Text, agents: List[Union[Text, Agent]], - llm_id: Text = "669a63646eb56306647e1091", llm: Optional[Union[LLM, Text]] = None, supervisor_llm: Optional[Union[LLM, Text]] = None, - mentalist_llm: Optional[Union[LLM, Text]] = None, description: Text = "", api_key: Text = config.TEAM_API_KEY, supplier: Union[Dict, Text, Supplier, int] = "aiXplain", version: Optional[Text] = None, - use_mentalist: bool = True, inspectors: List[Inspector] = [], inspector_targets: List[Union[InspectorTarget, Text]] = [InspectorTarget.STEPS], instructions: Optional[Text] = None, @@ -75,24 +73,57 @@ def create( Args: name: The name of the team agent. agents: A list of agents to be added to the team. - llm_id: The ID of the LLM to be used for the team agent. llm (Optional[Union[LLM, Text]], optional): The LLM to be used for the team agent. 
supervisor_llm (Optional[Union[LLM, Text]], optional): Main supervisor LLM. Defaults to None. - mentalist_llm (Optional[Union[LLM, Text]], optional): LLM for planning. Defaults to None. description: The description of the team agent to be displayed in the aiXplain platform. api_key: The API key to be used for the team agent. supplier: The supplier of the team agent. version: The version of the team agent. - use_mentalist: Whether to use the mentalist agent. inspectors: A list of inspectors to be added to the team. inspector_targets: Which stages to be inspected during an execution of the team agent. (steps, output) - use_mentalist_and_inspector: Whether to use the mentalist and inspector agents. (legacy) instructions: The instructions to guide the team agent (i.e. appended in the prompt of the team agent). output_format: The output format to be used for the team agent. expected_output: The expected output to be used for the team agent. Returns: A new team agent instance. + + Deprecated Args: + llm_id: DEPRECATED. Use 'llm' parameter instead. The ID of the LLM to be used for the team agent. + mentalist_llm: DEPRECATED. LLM for planning. + use_mentalist: DEPRECATED. Whether to use the mentalist agent. """ + # Handle deprecated parameters from kwargs + if "llm_id" in kwargs: + llm_id = kwargs.pop("llm_id") + warnings.warn( + "Parameter 'llm_id' is deprecated and will be removed in a future version. " + "Please use 'llm' parameter instead.", + DeprecationWarning, + stacklevel=2, + ) + else: + llm_id = "669a63646eb56306647e1091" + + if "mentalist_llm" in kwargs: + mentalist_llm = kwargs.pop("mentalist_llm") + warnings.warn( + "Parameter 'mentalist_llm' is deprecated and will be removed in a future version.", + DeprecationWarning, + stacklevel=2, + ) + else: + mentalist_llm = None + + if "use_mentalist" in kwargs: + use_mentalist = kwargs.pop("use_mentalist") + warnings.warn( + "Parameter 'use_mentalist' is deprecated and will be removed in a future version.", + DeprecationWarning, + stacklevel=2, + ) + else: + use_mentalist = True + # legacy params if "use_mentalist_and_inspector" in kwargs: logging.warning( @@ -227,10 +258,13 @@ def _setup_llm_and_tool( team_agent = build_team_agent(payload=internal_payload, agents=agent_list, api_key=api_key) team_agent.validate(raise_exception=True) response = "Unspecified error" + inspectors = team_agent.inspectors + inspector_targets = team_agent.inspector_targets try: payload["inspectors"] = [ inspector.model_dump(by_alias=True) for inspector in inspectors - ] # convert Inspector object to dict + ] + payload["inspectorTargets"] = inspector_targets logging.debug(f"Start service for POST Create TeamAgent - {url} - {headers} - {json.dumps(payload)}") r = _request_with_retry("post", url, headers=headers, json=payload) response = r.json() diff --git a/aixplain/factories/team_agent_factory/utils.py b/aixplain/factories/team_agent_factory/utils.py index edbef6d4..3d4c15ed 100644 --- a/aixplain/factories/team_agent_factory/utils.py +++ b/aixplain/factories/team_agent_factory/utils.py @@ -154,6 +154,7 @@ def get_cached_model(model_id: str) -> any: elif tool["description"] == "mentalist": mentalist_llm = llm + team_agent = TeamAgent( id=payload.get("id", ""), name=payload.get("name", ""), @@ -349,6 +350,6 @@ def build_team_agent_from_yaml(yaml_code: str, llm_id: str, api_key: str, team_i agents=agent_objs, llm=llm, api_key=api_key, - use_mentalist=True, inspectors=[], + use_mentalist=True, # Deprecated parameter ) diff --git a/aixplain/modules/agent/__init__.py 
b/aixplain/modules/agent/__init__.py index 33d0a7fb..18483076 100644 --- a/aixplain/modules/agent/__init__.py +++ b/aixplain/modules/agent/__init__.py @@ -24,6 +24,7 @@ """ __author__ = "aiXplain" +import inspect import json import logging import re @@ -291,6 +292,279 @@ def generate_session_id(self, history: list = None) -> str: logging.error(f"Failed to initialize session {session_id}: {e}") return session_id + def _normalize_progress_data(self, progress: Dict) -> Dict: + """Normalize progress data from camelCase to snake_case. + + Args: + progress (Dict): Progress data from backend (may use camelCase) + + Returns: + Dict: Normalized progress data with snake_case keys + """ + if not progress: + return progress + + # Map camelCase to snake_case for known fields + normalized = {} + key_mapping = { + "toolInput": "tool_input", + "toolOutput": "tool_output", + "currentStep": "current_step", + "totalSteps": "total_steps", + } + + for key, value in progress.items(): + # Use mapped key if available, otherwise keep original + normalized_key = key_mapping.get(key, key) + normalized[normalized_key] = value + + return normalized + + def poll(self, poll_url: Text, name: Text = "model_process") -> "AgentResponse": + """Override poll to normalize progress data from camelCase to snake_case. + + Args: + poll_url (Text): URL to poll for operation status. + name (Text, optional): Identifier for the operation. Defaults to "model_process". + + Returns: + AgentResponse: Response with normalized progress data. + """ + # Call parent poll method + response = super().poll(poll_url, name) + + # Normalize progress data if present (stored in additional_fields) + if hasattr(response, "additional_fields") and isinstance(response.additional_fields, dict): + if "progress" in response.additional_fields and response.additional_fields["progress"]: + response.additional_fields["progress"] = self._normalize_progress_data( + response.additional_fields["progress"] + ) + + return response + + def _format_agent_progress( + self, + progress: Dict, + verbosity: Optional[str] = "full", + ) -> Optional[str]: + """Format agent progress message based on verbosity level. + + Args: + progress (Dict): Progress data from polling response + verbosity (Optional[str]): "full", "compact", or None (disables output) + + Returns: + Optional[str]: Formatted message or None + """ + if verbosity is None: + return None + + stage = progress.get("stage", "working") + tool = progress.get("tool") + runtime = progress.get("runtime") + success = progress.get("success") + reason = progress.get("reason", "") + tool_input = progress.get("tool_input", "") + tool_output = progress.get("tool_output", "") + + # Determine status icon + if success is True: + status_icon = "✓" + elif success is False: + status_icon = "✗" + else: + status_icon = "⏳" + + agent_name = self.name + + if verbosity == "compact": + # Compact mode: minimal info + if tool: + msg = f"⚙️ {agent_name} | {tool} | {status_icon}" + if success is True and tool_output: + output_str = str(tool_output)[:200] + msg += f" {output_str}" + msg += "..." 
if len(str(tool_output)) > 200 else "" + else: + stage_name = stage.replace("_", " ").title() + msg = f"🤖 {agent_name} | {status_icon} {stage_name}" + else: + # Full verbosity: detailed info + if tool: + msg = f"⚙️ {agent_name} | {tool} | {status_icon}" + + if runtime is not None and runtime > 0 and success is not None: + msg += f" ({runtime:.1f} s)" + + if tool_input: + msg += f" | Input: {tool_input}" + + if tool_output: + msg += f" | Output: {tool_output}" + + if reason: + msg += f" | Reason: {reason}" + else: + stage_name = stage.replace("_", " ").title() + msg = f"🤖 {agent_name} | {status_icon} {stage_name}" + if reason: + msg += f" | {reason}" + + return msg + + def _format_completion_message( + self, + elapsed_time: float, + response_body: "AgentResponse", + timed_out: bool = False, + timeout: float = 300, + verbosity: Optional[str] = "full", + ) -> str: + """Format completion message with metrics. + + Args: + elapsed_time (float): Total elapsed time in seconds + response_body (AgentResponse): Final response + timed_out (bool): Whether the operation timed out + timeout (float): Timeout value if timed out + verbosity (Optional[str]): "full" or "compact" + + Returns: + str: Formatted completion message + """ + if timed_out: + return f"✅ Done | ✗ Timeout - No response after {timeout}s" + + # Collect metrics from execution_stats if available + total_api_calls = 0 + total_credits = 0.0 + runtime = elapsed_time + + # Extract data dict (handle tuple or direct object) + data_dict = None + if hasattr(response_body, "data") and response_body.data: + if isinstance(response_body.data, tuple) and len(response_body.data) > 0: + # Data is a tuple, get first element + data_dict = response_body.data[0] if isinstance(response_body.data[0], dict) else None + elif isinstance(response_body.data, dict): + # Data is already a dict + data_dict = response_body.data + elif hasattr(response_body.data, "executionStats") or hasattr(response_body.data, "execution_stats"): + # Data is an object with attributes + exec_stats = getattr(response_body.data, "executionStats", None) or getattr( + response_body.data, "execution_stats", None + ) + if exec_stats and isinstance(exec_stats, dict): + total_api_calls = exec_stats.get("api_calls", 0) + total_credits = exec_stats.get("credits", 0.0) + runtime = exec_stats.get("runtime", elapsed_time) + + # Try to get metrics from data dict (camelCase fields from backend) + if data_dict and isinstance(data_dict, dict): + # Check executionStats first + exec_stats = data_dict.get("executionStats") + if exec_stats and isinstance(exec_stats, dict): + total_api_calls = exec_stats.get("api_calls", 0) + total_credits = exec_stats.get("credits", 0.0) + runtime = exec_stats.get("runtime", elapsed_time) + + # Fallback: check top-level fields (usedCredits, runTime) + if total_credits == 0.0: + total_credits = data_dict.get("usedCredits", 0.0) + if runtime == elapsed_time: + runtime = data_dict.get("runTime", elapsed_time) + + # Build single-line completion message with metrics + if verbosity == "compact": + msg = f"✅ Done | ({runtime:.1f} s total" + else: + msg = f"✅ Done | Completed successfully ({runtime:.1f} s total" + + # Always show API calls and credits + if total_api_calls > 0: + msg += f" | {total_api_calls} API calls" + msg += f" | ${total_credits}" + msg += ")" + + return msg + + def sync_poll( + self, + poll_url: Text, + name: Text = "model_process", + wait_time: float = 0.5, + timeout: float = 300, + progress_verbosity: Optional[str] = "compact", + ) -> "AgentResponse": + 
"""Poll the platform until agent execution completes or times out. + + Args: + poll_url (Text): URL to poll for operation status. + name (Text, optional): Identifier for the operation. Defaults to "model_process". + wait_time (float, optional): Initial wait time in seconds between polls. Defaults to 0.5. + timeout (float, optional): Maximum total time to poll in seconds. Defaults to 300. + progress_verbosity (Optional[str], optional): Progress display mode - "full" (detailed), "compact" (brief), or None (no progress). Defaults to "compact". + + Returns: + AgentResponse: The final response from the agent execution. + """ + logging.info(f"Polling for Agent: Start polling for {name}") + start, end = time.time(), time.time() + wait_time = max(wait_time, 0.2) + completed = False + response_body = AgentResponse(status=ResponseStatus.FAILED, completed=False) + last_message = None # Track last message to avoid duplicates + + while not completed and (end - start) < timeout: + try: + response_body = self.poll(poll_url, name=name) + completed = response_body["completed"] + + # Display progress inline if enabled + if progress_verbosity and not completed: + progress = response_body.get("progress") + if progress: + msg = self._format_agent_progress(progress, progress_verbosity) + if msg and msg != last_message: + print(msg, flush=True) + last_message = msg + + end = time.time() + if completed is False: + time.sleep(wait_time) + if wait_time < 60: + wait_time *= 1.1 + except Exception as e: + response_body = AgentResponse( + status=ResponseStatus.FAILED, + completed=False, + error_message="No response from the service.", + ) + logging.error(f"Polling for Agent: polling for {name}: {e}") + return response_body + break + + # Display completion message + if progress_verbosity: + elapsed_time = end - start + timed_out = response_body["completed"] is not True + completion_msg = self._format_completion_message( + elapsed_time, response_body, timed_out, timeout, progress_verbosity + ) + print(completion_msg, flush=True) + + if response_body["completed"] is True: + logging.debug(f"Polling for Agent: Final status of polling for {name}: {response_body}") + else: + response_body = AgentResponse( + status=ResponseStatus.FAILED, + completed=False, + error_message="No response from the service.", + ) + logging.error(f"Polling for Agent: Final status of polling for {name}: No response in {timeout} seconds") + + return response_body + def run( self, data: Optional[Union[Dict, Text]] = None, @@ -304,9 +578,9 @@ def run( content: Optional[Union[Dict[Text, Text], List[Text]]] = None, max_tokens: int = 4096, max_iterations: int = 5, - output_format: Optional[OutputFormat] = None, - expected_output: Optional[Union[BaseModel, Text, dict]] = None, trace_request: bool = False, + progress_verbosity: Optional[str] = "compact", + **kwargs, ) -> AgentResponse: """Runs an agent call. @@ -322,13 +596,35 @@ def run( content (Union[Dict[Text, Text], List[Text]], optional): Content inputs to be processed according to the query. Defaults to None. max_tokens (int, optional): maximum number of tokens which can be generated by the agent. Defaults to 2048. max_iterations (int, optional): maximum number of iterations between the agent and the tools. Defaults to 10. - output_format (OutputFormat, optional): response format. If not provided, uses the format set during initialization. - expected_output (Union[BaseModel, Text, dict], optional): expected output. Defaults to None. 
trace_request (bool, optional): return the request id for tracing the request. Defaults to False. + progress_verbosity (Optional[str], optional): Progress display mode - "full" (detailed), "compact" (brief), or None (no progress). Defaults to "compact". + **kwargs: Additional keyword arguments. + Returns: Dict: parsed output from model """ start = time.time() + + # Extract deprecated parameters from kwargs + output_format = kwargs.get("output_format", None) + expected_output = kwargs.get("expected_output", None) + + if output_format is not None: + warnings.warn( + "The 'output_format' parameter is deprecated and will be removed in a future version. " + "Set the output format during agent initialization instead.", + DeprecationWarning, + stacklevel=2, + ) + + if expected_output is not None: + warnings.warn( + "The 'expected_output' parameter is deprecated and will be removed in a future version. " + "Set the expected output during agent initialization instead.", + DeprecationWarning, + stacklevel=2, + ) + if session_id is not None and history is not None: raise ValueError("Provide either `session_id` or `history`, not both.") @@ -361,9 +657,11 @@ def run( return response poll_url = response["url"] end = time.time() - result = self.sync_poll(poll_url, name=name, timeout=timeout, wait_time=wait_time) - # if result.status == ResponseStatus.FAILED: - # raise Exception("Model failed to run with error: " + result.error_message) + result = self.sync_poll( + poll_url, name=name, timeout=timeout, wait_time=wait_time, progress_verbosity=progress_verbosity + ) + if result.status == ResponseStatus.FAILED: + raise Exception("Model failed to run with error: " + result.error_message) result_data = result.get("data") or {} return AgentResponse( status=ResponseStatus.SUCCESS, @@ -427,6 +725,7 @@ def run_async( output_format (ResponseFormat, optional): response format. Defaults to TEXT. evolve (Union[Dict[str, Any], EvolveParam, None], optional): evolve the agent configuration. Can be a dictionary, EvolveParam instance, or None. trace_request (bool, optional): return the request id for tracing the request. Defaults to False. + Returns: dict: polling URL in response """ @@ -448,7 +747,8 @@ def run_async( if output_format == OutputFormat.JSON: assert expected_output is not None and ( - issubclass(expected_output, BaseModel) or isinstance(expected_output, dict) + (inspect.isclass(expected_output) and issubclass(expected_output, BaseModel)) + or isinstance(expected_output, dict) ), "Expected output must be a Pydantic BaseModel or a JSON object when output format is JSON." assert data is not None or query is not None, "Either 'data' or 'query' must be provided." @@ -490,7 +790,7 @@ def run_async( input_data = process_variables(query, data, parameters, self.instructions) if expected_output is None: expected_output = self.expected_output - if expected_output is not None and issubclass(expected_output, BaseModel): + if expected_output is not None and isinstance(expected_output, type) and issubclass(expected_output, BaseModel): expected_output = expected_output.model_json_schema() expected_output = normalize_expected_output(expected_output) # Use instance output_format if none provided diff --git a/aixplain/modules/agent/utils.py b/aixplain/modules/agent/utils.py index 588a57ea..9da5d787 100644 --- a/aixplain/modules/agent/utils.py +++ b/aixplain/modules/agent/utils.py @@ -38,16 +38,10 @@ def process_variables( variables = re.findall(r"(? 
None: @@ -79,14 +78,16 @@ def __init__( function (Function, optional): Model's AI function. Must be Function.TEXT_GENERATION. is_subscribed (bool, optional): Whether the user is subscribed. Defaults to False. cost (Dict, optional): Cost of the model. Defaults to None. - temperature (float, optional): Default temperature for text generation. Defaults to 0.001. + temperature (Optional[float], optional): Default temperature for text generation. Defaults to None. function_type (FunctionType, optional): Type of the function. Defaults to FunctionType.AI. **additional_info: Any additional model info to be saved. Raises: AssertionError: If function is not Function.TEXT_GENERATION. """ - assert function == Function.TEXT_GENERATION, "LLM only supports large language models (i.e. text generation function)" + assert function == Function.TEXT_GENERATION, ( + "LLM only supports large language models (i.e. text generation function)" + ) super().__init__( id=id, name=name, @@ -112,12 +113,13 @@ def run( history: Optional[List[Dict]] = None, temperature: Optional[float] = None, max_tokens: int = 128, - top_p: float = 1.0, + top_p: Optional[float] = None, name: Text = "model_process", timeout: float = 300, parameters: Optional[Dict] = None, wait_time: float = 0.5, stream: bool = False, + response_format: Optional[Text] = None, ) -> Union[ModelResponse, ModelResponseStreamer]: """Run the LLM model synchronously to generate text. @@ -138,8 +140,8 @@ def run( Defaults to None. max_tokens (int, optional): Maximum number of tokens to generate. Defaults to 128. - top_p (float, optional): Nucleus sampling parameter. Only tokens with cumulative - probability < top_p are considered. Defaults to 1.0. + top_p (Optional[float], optional): Nucleus sampling parameter. Only tokens with cumulative + probability < top_p are considered. Defaults to None. name (Text, optional): Identifier for this model run. Useful for logging. Defaults to "model_process". timeout (float, optional): Maximum time in seconds to wait for completion. @@ -150,6 +152,8 @@ def run( Defaults to 0.5. stream (bool, optional): Whether to stream the model's output tokens. Defaults to False. + response_format (Optional[Union[str, dict, BaseModel]], optional): + Specifies the desired output structure or format of the model’s response. Returns: Union[ModelResponse, ModelResponseStreamer]: If stream=False, returns a ModelResponse @@ -166,9 +170,13 @@ def run( parameters.setdefault("context", context) parameters.setdefault("prompt", prompt) parameters.setdefault("history", history) - parameters.setdefault("temperature", temperature if temperature is not None else self.temperature) + temp_value = temperature if temperature is not None else self.temperature + if temp_value is not None: + parameters.setdefault("temperature", temp_value) parameters.setdefault("max_tokens", max_tokens) - parameters.setdefault("top_p", top_p) + if top_p is not None: + parameters.setdefault("top_p", top_p) + parameters.setdefault("response_format", response_format) if stream: return self.run_stream(data=data, parameters=parameters) @@ -210,9 +218,10 @@ def run_async( history: Optional[List[Dict]] = None, temperature: Optional[float] = None, max_tokens: int = 128, - top_p: float = 1.0, + top_p: Optional[float] = None, name: Text = "model_process", parameters: Optional[Dict] = None, + response_format: Optional[Text] = None, ) -> ModelResponse: """Run the LLM model asynchronously to generate text. @@ -233,12 +242,14 @@ def run_async( Defaults to None. 
max_tokens (int, optional): Maximum number of tokens to generate. Defaults to 128. - top_p (float, optional): Nucleus sampling parameter. Only tokens with cumulative - probability < top_p are considered. Defaults to 1.0. + top_p (Optional[float], optional): Nucleus sampling parameter. Only tokens with cumulative + probability < top_p are considered. Defaults to None. name (Text, optional): Identifier for this model run. Useful for logging. Defaults to "model_process". parameters (Optional[Dict], optional): Additional model-specific parameters. Defaults to None. + response_format (Optional[Text], optional): Desired output format specification. + Defaults to None. Returns: ModelResponse: A response object containing: @@ -261,9 +272,13 @@ def run_async( parameters.setdefault("context", context) parameters.setdefault("prompt", prompt) parameters.setdefault("history", history) - parameters.setdefault("temperature", temperature if temperature is not None else self.temperature) + temp_value = temperature if temperature is not None else self.temperature + if temp_value is not None: + parameters.setdefault("temperature", temp_value) parameters.setdefault("max_tokens", max_tokens) - parameters.setdefault("top_p", top_p) + if top_p is not None: + parameters.setdefault("top_p", top_p) + parameters.setdefault("response_format", response_format) payload = build_payload(data=data, parameters=parameters) response = call_run_endpoint(payload=payload, url=url, api_key=self.api_key) return ModelResponse( diff --git a/aixplain/modules/team_agent/__init__.py b/aixplain/modules/team_agent/__init__.py index 58c19cd4..09d444be 100644 --- a/aixplain/modules/team_agent/__init__.py +++ b/aixplain/modules/team_agent/__init__.py @@ -1,6 +1,8 @@ -__author__ = "aiXplain" +"""Team Agent module for aiXplain SDK. + +This module provides the TeamAgent class and related functionality for creating and managing +multi-agent teams that can collaborate on complex tasks. -""" Copyright 2024 The aiXplain SDK authors Licensed under the Apache License, Version 2.0 (the "License"); @@ -21,11 +23,14 @@ Team Agent Class """ +__author__ = "aiXplain" + import json import logging import time import traceback import re +import warnings from enum import Enum from typing import Dict, List, Text, Optional, Union, Any from urllib.parse import urljoin @@ -86,14 +91,23 @@ class TeamAgent(Model, DeployableMixin[Agent]): name (Text): Name of the Team Agent agents (List[Agent]): List of agents that the Team Agent uses. description (Text, optional): description of the Team Agent. Defaults to "". - llm_id (Text, optional): large language model. Defaults to GPT-4o (6646261c6eb563165658bbb1). + llm (Optional[LLM]): Main LLM instance for the team agent. + supervisor_llm (Optional[LLM]): Supervisor LLM instance for the team agent. api_key (str): The TEAM API key used for authentication. supplier (Text): Supplier of the Team Agent. version (Text): Version of the Team Agent. cost (Dict, optional): model price. Defaults to None. - use_mentalist (bool): Use Mentalist agent for pre-planning. Defaults to True. inspectors (List[Inspector]): List of inspectors that the team agent uses. inspector_targets (List[InspectorTarget]): List of targets where the inspectors are applied. Defaults to [InspectorTarget.STEPS]. + status (AssetStatus): Status of the Team Agent. Defaults to DRAFT. + instructions (Optional[Text]): Instructions to guide the team agent. + output_format (OutputFormat): Response format. Defaults to TEXT. 
+        expected_output (Optional[Union[BaseModel, Text, dict]]): Expected output format. + + Deprecated Attributes: + llm_id (Text): DEPRECATED. Use 'llm' parameter instead. Large language model ID. + mentalist_llm (Optional[LLM]): DEPRECATED. LLM for planning. + use_mentalist (bool): DEPRECATED. Whether to use Mentalist agent for pre-planning. """ is_valid: bool @@ -104,15 +118,12 @@ def __init__( name: Text, agents: List[Agent] = [], description: Text = "", - llm_id: Text = "6646261c6eb563165658bbb1", llm: Optional[LLM] = None, supervisor_llm: Optional[LLM] = None, - mentalist_llm: Optional[LLM] = None, api_key: Optional[Text] = config.TEAM_API_KEY, supplier: Union[Dict, Text, Supplier, int] = "aiXplain", version: Optional[Text] = None, cost: Optional[Dict] = None, - use_mentalist: bool = True, inspectors: List[Inspector] = [], inspector_targets: List[InspectorTarget] = [InspectorTarget.STEPS], status: AssetStatus = AssetStatus.DRAFT, @@ -121,7 +132,64 @@ expected_output: Optional[Union[BaseModel, Text, dict]] = None, **additional_info, ) -> None: + """Initialize a TeamAgent instance. + + Args: + id (Text): Unique identifier for the team agent. + name (Text): Name of the team agent. + agents (List[Agent], optional): List of agents in the team. Defaults to []. + description (Text, optional): Description of the team agent. Defaults to "". + llm (Optional[LLM], optional): LLM instance. Defaults to None. + supervisor_llm (Optional[LLM], optional): Supervisor LLM instance. Defaults to None. + api_key (Optional[Text], optional): API key. Defaults to config.TEAM_API_KEY. + supplier (Union[Dict, Text, Supplier, int], optional): Supplier. Defaults to "aiXplain". + version (Optional[Text], optional): Version. Defaults to None. + cost (Optional[Dict], optional): Cost information. Defaults to None. + inspectors (List[Inspector], optional): List of inspectors. Defaults to []. + inspector_targets (List[InspectorTarget], optional): Inspector targets. Defaults to [InspectorTarget.STEPS]. + status (AssetStatus, optional): Status of the team agent. Defaults to AssetStatus.DRAFT. + instructions (Optional[Text], optional): Instructions for the team agent. Defaults to None. + output_format (OutputFormat, optional): Output format. Defaults to OutputFormat.TEXT. + expected_output (Optional[Union[BaseModel, Text, dict]], optional): Expected output format. Defaults to None. + **additional_info: Additional keyword arguments. + + Deprecated Args: + llm_id (Text, optional): DEPRECATED. Use 'llm' parameter instead. ID of the language model. Defaults to "6646261c6eb563165658bbb1". 
+            mentalist_llm (Optional[LLM], optional): DEPRECATED. Mentalist/Planner LLM instance. Defaults to None. + use_mentalist (bool, optional): DEPRECATED. Whether to use mentalist/planner. Defaults to True. + """ + # Handle deprecated parameters from kwargs (must follow the docstring so it stays the docstring) + if "llm_id" in additional_info: + llm_id = additional_info.pop("llm_id") + warnings.warn( + "Parameter 'llm_id' is deprecated and will be removed in a future version. " + "Please use 'llm' parameter instead.", + DeprecationWarning, + stacklevel=2, + ) + else: + llm_id = "6646261c6eb563165658bbb1" + + if "mentalist_llm" in additional_info: + mentalist_llm = additional_info.pop("mentalist_llm") + warnings.warn( + "Parameter 'mentalist_llm' is deprecated and will be removed in a future version.", + DeprecationWarning, + stacklevel=2, + ) + else: + mentalist_llm = None + + if "use_mentalist" in additional_info: + use_mentalist = additional_info.pop("use_mentalist") + warnings.warn( + "Parameter 'use_mentalist' is deprecated and will be removed in a future version.", + DeprecationWarning, + stacklevel=2, + ) + else: + use_mentalist = True super().__init__(id, name, description, api_key, supplier, version, cost=cost) self.additional_info = additional_info self.agents = agents @@ -146,6 +214,14 @@ def __init__( self.expected_output = expected_output def generate_session_id(self, history: list = None) -> str: + """Generate a new session ID for the team agent. + + Args: + history (list, optional): Chat history to initialize the session with. Defaults to None. + + Returns: + str: The generated session ID in format "{team_agent_id}_{timestamp}". + """ timestamp = datetime.now().strftime("%Y%m%d%H%M%S") session_id = f"{self.id}_{timestamp}" @@ -185,6 +261,288 @@ def generate_session_id(self, history: list = None) -> str: logging.error(f"Failed to initialize team session {session_id}: {e}") return session_id + def _normalize_progress_data(self, progress: Dict) -> Dict: + """Normalize progress data from camelCase to snake_case. + + Args: + progress (Dict): Progress data from backend (may use camelCase) + + Returns: + Dict: Normalized progress data with snake_case keys + """ + if not progress: + return progress + + # Map camelCase to snake_case for known fields + normalized = {} + key_mapping = { + "toolInput": "tool_input", + "toolOutput": "tool_output", + "currentStep": "current_step", + "totalSteps": "total_steps", + } + + for key, value in progress.items(): + # Use mapped key if available, otherwise keep original + normalized_key = key_mapping.get(key, key) + normalized[normalized_key] = value + + return normalized + + def _format_team_progress( + self, + progress: Dict, + verbosity: Optional[str] = "full", + ) -> Optional[str]: + """Format team agent progress message based on verbosity level. 
+ + Args: + progress (Dict): Progress data from polling response + verbosity (Optional[str]): "full", "compact", or None (disables output) + + Returns: + Optional[str]: Formatted message or None + """ + if verbosity is None: + return None + + stage = progress.get("stage", "working") + agent_name = progress.get("agent") + tool = progress.get("tool") + message = progress.get("message", "") + runtime = progress.get("runtime") + success = progress.get("success") + current_step = progress.get("current_step", 0) + total_steps = progress.get("total_steps", 0) + reason = progress.get("reason", "") + tool_input = progress.get("tool_input", "") + tool_output = progress.get("tool_output", "") + + # Determine status icon + if success is True: + status_icon = "✓" + elif success is False: + status_icon = "✗" + else: + status_icon = "⏳" + + # Capitalize system agent names for better display + if agent_name: + system_agents = { + "orchestrator": "Orchestrator", + "mentalist": "Mentalist", + "response_generator": "Response Generator", + } + display_agent_name = system_agents.get(agent_name.lower(), agent_name) + else: + display_agent_name = None + + # Determine emoji and context + if stage in ["planning", "mentalist"]: + emoji = "🤖" + context = "Mentalist" + elif display_agent_name and tool: + emoji = "⚙️" + context = f"{display_agent_name} | {tool}" + elif display_agent_name: + emoji = "🤖" + context = display_agent_name + else: + emoji = "🤖" + context = self.name + + if verbosity == "compact": + # Compact mode: minimal info + msg = f"{emoji} {context} | {status_icon}" + + if current_step and total_steps: + msg += f" [{current_step}/{total_steps}]" + + # Show message if available (common for planning/orchestration stages) + if message and not tool_output: + msg += f" {message[:100]}" + elif success is True and tool_output: + output_str = str(tool_output)[:200] + msg += f" {output_str}" + msg += "..." if len(str(tool_output)) > 200 else "" + else: + # Full verbosity: detailed info + msg = f"{emoji} {context} | {status_icon}" + + if runtime is not None and runtime > 0 and success is not None: + msg += f" ({runtime:.1f} s)" + + if current_step and total_steps: + msg += f" | Step {current_step}/{total_steps}" + + if tool_input: + msg += f" | Input: {tool_input}" + + if tool_output: + msg += f" | Output: {tool_output}" + + if reason: + msg += f" | Reason: {reason}" + elif message: + # Show message if reason is not available (common for planning/orchestration) + msg += f" | {message}" + + return msg + + def _format_completion_message( + self, + elapsed_time: float, + response_body: AgentResponse, + timed_out: bool = False, + timeout: float = 300, + verbosity: Optional[str] = "full", + ) -> str: + """Format completion message with metrics. 
+ + Args: + elapsed_time (float): Total elapsed time in seconds + response_body (AgentResponse): Final response + timed_out (bool): Whether the operation timed out + timeout (float): Timeout value if timed out + verbosity (Optional[str]): "full" or "compact" + + Returns: + str: Formatted completion message + """ + if timed_out: + return f"✅ Done | ✗ Timeout - No response after {timeout}s" + + # Collect metrics from execution_stats if available + total_api_calls = 0 + total_credits = 0.0 + runtime = elapsed_time + + # Extract data dict (handle tuple or direct object) + data_dict = None + if hasattr(response_body, "data") and response_body.data: + if isinstance(response_body.data, tuple) and len(response_body.data) > 0: + # Data is a tuple, get first element + data_dict = response_body.data[0] if isinstance(response_body.data[0], dict) else None + elif isinstance(response_body.data, dict): + # Data is already a dict + data_dict = response_body.data + elif hasattr(response_body.data, "executionStats") or hasattr(response_body.data, "execution_stats"): + # Data is an object with attributes + exec_stats = getattr(response_body.data, "executionStats", None) or getattr( + response_body.data, "execution_stats", None + ) + if exec_stats and isinstance(exec_stats, dict): + total_api_calls = exec_stats.get("api_calls", 0) + total_credits = exec_stats.get("credits", 0.0) + runtime = exec_stats.get("runtime", elapsed_time) + + # Try to get metrics from data dict (camelCase fields from backend) + if data_dict and isinstance(data_dict, dict): + # Check executionStats first + exec_stats = data_dict.get("executionStats") + if exec_stats and isinstance(exec_stats, dict): + total_api_calls = exec_stats.get("api_calls", 0) + total_credits = exec_stats.get("credits", 0.0) + runtime = exec_stats.get("runtime", elapsed_time) + + # Fallback: check top-level fields (usedCredits, runTime) + if total_credits == 0.0: + total_credits = data_dict.get("usedCredits", 0.0) + if runtime == elapsed_time: + runtime = data_dict.get("runTime", elapsed_time) + + # Build single-line completion message with metrics + if verbosity == "compact": + msg = f"✅ Done | ({runtime:.1f} s total" + else: + msg = f"✅ Done | Completed successfully ({runtime:.1f} s total" + + # Always show API calls and credits + if total_api_calls > 0: + msg += f" | {total_api_calls} API calls" + msg += f" | ${total_credits}" + msg += ")" + + return msg + + def sync_poll( + self, + poll_url: Text, + name: Text = "model_process", + wait_time: float = 0.5, + timeout: float = 300, + progress_verbosity: Optional[str] = "compact", + ) -> AgentResponse: + """Poll the platform until team agent execution completes or times out. + + Args: + poll_url (Text): URL to poll for operation status. + name (Text, optional): Identifier for the operation. Defaults to "model_process". + wait_time (float, optional): Initial wait time in seconds between polls. Defaults to 0.5. + timeout (float, optional): Maximum total time to poll in seconds. Defaults to 300. + progress_verbosity (Optional[str], optional): Progress display mode - "full" (detailed), "compact" (brief), or None (no progress). Defaults to "compact". + + Returns: + AgentResponse: The final response from the team agent execution. 
+ """ + logging.info(f"Polling for Team Agent: Start polling for {name}") + start, end = time.time(), time.time() + wait_time = max(wait_time, 0.2) + completed = False + response_body = AgentResponse(status=ResponseStatus.FAILED, completed=False) + last_message = None # Track last message to avoid duplicates + + while not completed and (end - start) < timeout: + try: + response_body = self.poll(poll_url, name=name) + completed = response_body["completed"] + + # Display progress inline if enabled + if progress_verbosity and not completed: + progress = response_body.get("progress") + if progress: + msg = self._format_team_progress(progress, progress_verbosity) + if msg and msg != last_message: + print(msg, flush=True) + last_message = msg + + end = time.time() + if completed is False: + time.sleep(wait_time) + if wait_time < 60: + wait_time *= 1.1 + except Exception as e: + response_body = AgentResponse( + status=ResponseStatus.FAILED, + completed=False, + error_message="No response from the service.", + ) + logging.error(f"Polling for Team Agent: polling for {name}: {e}") + break + + # Display completion message + if progress_verbosity: + elapsed_time = end - start + timed_out = response_body["completed"] is not True + completion_msg = self._format_completion_message( + elapsed_time, response_body, timed_out, timeout, progress_verbosity + ) + print(completion_msg, flush=True) + + if response_body["completed"] is True: + logging.debug(f"Polling for Team Agent: Final status of polling for {name}: {response_body}") + else: + response_body = AgentResponse( + status=ResponseStatus.FAILED, + completed=False, + error_message="No response from the service.", + ) + logging.error( + f"Polling for Team Agent: Final status of polling for {name}: No response in {timeout} seconds" + ) + + return response_body + def run( self, data: Optional[Union[Dict, Text]] = None, @@ -198,9 +556,9 @@ def run( content: Optional[Union[Dict[Text, Text], List[Text]]] = None, max_tokens: int = 2048, max_iterations: int = 30, - output_format: Optional[OutputFormat] = None, - expected_output: Optional[Union[BaseModel, Text, dict]] = None, trace_request: bool = False, + progress_verbosity: Optional[str] = "compact", + **kwargs, ) -> AgentResponse: """Runs a team agent call. @@ -216,12 +574,31 @@ def run( content (Union[Dict[Text, Text], List[Text]], optional): Content inputs to be processed according to the query. Defaults to None. max_tokens (int, optional): maximum number of tokens which can be generated by the agents. Defaults to 2048. max_iterations (int, optional): maximum number of iterations between the agents. Defaults to 30. - output_format (OutputFormat, optional): response format. If not provided, uses the format set during initialization. - expected_output (Union[BaseModel, Text, dict], optional): expected output. Defaults to None. trace_request (bool, optional): return the request id for tracing the request. Defaults to False. + progress_verbosity (Optional[str], optional): Progress display mode - "full" (detailed), "compact" (brief), or None (no progress). Defaults to "compact". + Returns: AgentResponse: parsed output from model """ + # Handle deprecated parameters from kwargs + output_format = kwargs.pop("output_format", None) + if output_format is not None: + warnings.warn( + "Parameter 'output_format' in run() is deprecated and will be removed in a future version. 
" + "Please set 'output_format' during TeamAgent initialization instead.", + DeprecationWarning, + stacklevel=2, + ) + + expected_output = kwargs.pop("expected_output", None) + if expected_output is not None: + warnings.warn( + "Parameter 'expected_output' in run() is deprecated and will be removed in a future version. " + "Please set 'expected_output' during TeamAgent initialization instead.", + DeprecationWarning, + stacklevel=2, + ) + start = time.time() result_data = {} if session_id is not None and history is not None: @@ -253,7 +630,9 @@ def run( return response poll_url = response["url"] end = time.time() - result = self.sync_poll(poll_url, name=name, timeout=timeout, wait_time=wait_time) + result = self.sync_poll( + poll_url, name=name, timeout=timeout, wait_time=wait_time, progress_verbosity=progress_verbosity + ) result_data = result.data return AgentResponse( status=ResponseStatus.SUCCESS, @@ -310,6 +689,7 @@ def run_async( expected_output (Union[BaseModel, Text, dict], optional): expected output. Defaults to None. evolve (Union[Dict[str, Any], EvolveParam, None], optional): evolve the team agent configuration. Can be a dictionary, EvolveParam instance, or None. trace_request (bool, optional): return the request id for tracing the request. Defaults to False. + Returns: AgentResponse: polling URL in response """ @@ -335,7 +715,9 @@ def run_async( assert data is not None or query is not None, "Either 'data' or 'query' must be provided." if data is not None: if isinstance(data, dict): - assert "query" in data and data["query"] is not None, "When providing a dictionary, 'query' must be provided." + assert "query" in data and data["query"] is not None, ( + "When providing a dictionary, 'query' must be provided." + ) if session_id is None: session_id = data.pop("session_id", None) if history is None: @@ -348,9 +730,9 @@ def run_async( # process content inputs if content is not None: - assert ( - isinstance(query, str) and FileFactory.check_storage_type(query) == StorageType.TEXT - ), "When providing 'content', query must be text." + assert isinstance(query, str) and FileFactory.check_storage_type(query) == StorageType.TEXT, ( + "When providing 'content', query must be text." + ) if isinstance(content, list): assert len(content) <= 3, "The maximum number of content inputs is 3." @@ -419,6 +801,15 @@ def run_async( return response def poll(self, poll_url: Text, name: Text = "model_process") -> AgentResponse: + """Poll once for team agent execution status. + + Args: + poll_url (Text): URL to poll for status. + name (Text, optional): Identifier for the operation. Defaults to "model_process". + + Returns: + AgentResponse: Response containing status, data, and progress information. 
+ """ used_credits, run_time = 0.0, 0.0 resp, error_message, status = None, None, ResponseStatus.SUCCESS headers = {"x-api-key": self.api_key, "Content-Type": "application/json"} @@ -459,10 +850,17 @@ def poll(self, poll_url: Text, name: Text = "model_process") -> AgentResponse: except Exception as e: import traceback - logging.error(f"Single Poll for Team Agent: Error of polling for {name}: {e}, traceback: {traceback.format_exc()}") + logging.error( + f"Single Poll for Team Agent: Error of polling for {name}: {e}, traceback: {traceback.format_exc()}" + ) status = ResponseStatus.FAILED error_message = str(e) finally: + # Normalize progress data from camelCase to snake_case + progress_data = resp.get("progress") if resp else None + if progress_data: + progress_data = self._normalize_progress_data(progress_data) + response = AgentResponse( status=status, data=resp_data, @@ -472,11 +870,12 @@ def poll(self, poll_url: Text, name: Text = "model_process") -> AgentResponse: run_time=run_time, usage=resp.get("usage", None), error_message=error_message, + progress=progress_data, ) return response def delete(self) -> None: - """Delete Corpus service""" + """Deletes Team Agent.""" try: url = urljoin(config.BACKEND_URL, f"sdk/agent-communities/{self.id}") headers = { @@ -488,9 +887,7 @@ def delete(self) -> None: if r.status_code != 200: raise Exception() except Exception: - message = ( - f"Team Agent Deletion Error (HTTP {r.status_code}): Make sure the Team Agent exists and you are the owner." - ) + message = f"Team Agent Deletion Error (HTTP {r.status_code}): Make sure the Team Agent exists and you are the owner." logging.error(message) raise Exception(f"{message}") @@ -673,19 +1070,20 @@ def from_dict(cls, data: Dict) -> "TeamAgent": name=data["name"], agents=agents, description=data.get("description", ""), - llm_id=data.get("llmId", "6646261c6eb563165658bbb1"), llm=llm, supervisor_llm=supervisor_llm, - mentalist_llm=mentalist_llm, supplier=data.get("supplier", "aiXplain"), version=data.get("version"), - use_mentalist=use_mentalist, status=status, instructions=data.get("instructions"), inspectors=inspectors, inspector_targets=inspector_targets, output_format=OutputFormat(data.get("outputFormat", OutputFormat.TEXT)), expected_output=data.get("expectedOutput"), + # Pass deprecated params via kwargs + llm_id=data.get("llmId", "6646261c6eb563165658bbb1"), + mentalist_llm=mentalist_llm, + use_mentalist=use_mentalist, ) def _validate(self) -> None: @@ -694,9 +1092,9 @@ def _validate(self) -> None: """Validate the Team.""" # validate name - assert ( - re.match(r"^[a-zA-Z0-9 \-\(\)]*$", self.name) is not None - ), "Team Agent Creation Error: Team name contains invalid characters. Only alphanumeric characters, spaces, hyphens, and brackets are allowed." + assert re.match(r"^[a-zA-Z0-9 \-\(\)]*$", self.name) is not None, ( + "Team Agent Creation Error: Team name contains invalid characters. Only alphanumeric characters, spaces, hyphens, and brackets are allowed." + ) try: llm = get_llm_instance(self.llm_id, use_cache=True) @@ -769,7 +1167,7 @@ def update(self) -> None: stack = inspect.stack() if len(stack) > 2 and stack[1].function != "save": warnings.warn( - "update() is deprecated and will be removed in a future version. " "Please use save() instead.", + "update() is deprecated and will be removed in a future version. 
Please use save() instead.", DeprecationWarning, stacklevel=2, ) @@ -903,7 +1301,9 @@ def evolve( end = time.time() result = self.sync_poll(poll_url, name="evolve_process", timeout=600) result_data = result.data - current_code = result_data.get("current_code") if isinstance(result_data, dict) else result_data.current_code + current_code = ( + result_data.get("current_code") if isinstance(result_data, dict) else result_data.current_code + ) if current_code is not None: if evolve_parameters.evolve_type == EvolveType.TEAM_TUNING: result_data["evolved_agent"] = build_team_agent_from_yaml( diff --git a/aixplain/modules/team_agent/inspector.py b/aixplain/modules/team_agent/inspector.py index d0a10932..3c091a02 100644 --- a/aixplain/modules/team_agent/inspector.py +++ b/aixplain/modules/team_agent/inspector.py @@ -16,8 +16,6 @@ name="team" agents=agents, description="team description", - llm_id="xyz", - use_mentalist=True, inspectors=[inspector], ) """ @@ -58,7 +56,7 @@ class InspectorOutput(BaseModel): class InspectorAuto(str, Enum): """A list of keywords for inspectors configured automatically in the backend.""" - + ALIGNMENT = "alignment" CORRECTNESS = "correctness" def get_name(self) -> Text: @@ -382,3 +380,64 @@ def model_validate(cls, data: Union[Dict, "Inspector"]) -> "Inspector": data.pop("policy_type", None) # Remove the type indicator return super().model_validate(data) + + +class VerificationInspector(Inspector): + """Pre-defined inspector that checks output against the plan. + + This inspector is designed to verify that the output aligns with the intended plan + and provides feedback when discrepancies are found. It uses a RERUN policy by default, + meaning it will request re-execution when issues are detected. + + Example usage: + from aixplain.modules.team_agent import VerificationInspector + + # Use with default model (GPT-4o or resolved_model_id) + inspector = VerificationInspector() + + # Or with custom model + inspector = VerificationInspector(model_id="your_model_id") + + team_agent = TeamAgent( + name="my_team", + agents=agents, + inspectors=[VerificationInspector()], + inspector_targets=[InspectorTarget.STEPS] + ) + """ + + def __init__(self, model_id: Optional[Text] = None, **kwargs): + """Initialize VerificationInspector with default configuration. + + Args: + model_id (Optional[Text]): Model ID to use. If not provided, uses auto configuration. + **kwargs: Additional arguments passed to Inspector parent class. 
+ """ + from aixplain.modules.model.response import ModelResponse + + # Replicate resolved_model_id logic from old implementation + resolved_model_id = model_id + if not resolved_model_id: + resolved_model_id = "6646261c6eb563165658bbb1" # GPT_4o_ID + + def process_response(model_response: ModelResponse, input_content: str) -> InspectorOutput: + """Default policy that always requests rerun for verification.""" + critiques = model_response.data + action = InspectorAction.RERUN + return InspectorOutput(critiques=critiques, content_edited=input_content, action=action) + + # Exact same default inspector configuration as old implementation + # Note: When auto=InspectorAuto.ALIGNMENT is set, Inspector.__init__ will override + # model_id with AUTO_DEFAULT_MODEL_ID + defaults = { + "name": "VerificationInspector", + "model_id": resolved_model_id, + "model_params": {"prompt": "Check the output against the plan"}, + "policy": process_response, + "auto": InspectorAuto.ALIGNMENT + } + + # Override defaults with any provided kwargs + defaults.update(kwargs) + + super().__init__(**defaults) diff --git a/tests/functional/agent/agent_functional_test.py b/tests/functional/agent/agent_functional_test.py index 5dd73d2a..0d541500 100644 --- a/tests/functional/agent/agent_functional_test.py +++ b/tests/functional/agent/agent_functional_test.py @@ -122,7 +122,7 @@ def test_python_interpreter_tool(delete_agents_and_team_agents, AgentFactory): assert len(response["data"]["intermediate_steps"]) > 0 intermediate_step = response["data"]["intermediate_steps"][0] assert len(intermediate_step["tool_steps"]) > 0 - assert intermediate_step["tool_steps"][0]["tool"] == "Custom Code Tool" + assert intermediate_step["tool_steps"][0]["tool"] == "Python Code Interpreter Tool" agent.delete() diff --git a/tests/functional/team_agent/team_agent_functional_test.py b/tests/functional/team_agent/team_agent_functional_test.py index 3709c0f8..c357b8d4 100644 --- a/tests/functional/team_agent/team_agent_functional_test.py +++ b/tests/functional/team_agent/team_agent_functional_test.py @@ -247,7 +247,7 @@ def test_add_remove_agents_from_team_agent(run_input_map, delete_agents_and_team def test_team_agent_tasks(delete_agents_and_team_agents): assert delete_agents_and_team_agents agent = AgentFactory.create( - name="Teste", + name="Test Sub Agent", description="You are a test agent that always returns the same answer", tools=[ AgentFactory.create_model_tool(function=Function.TRANSLATION, supplier=Supplier.MICROSOFT), @@ -268,13 +268,13 @@ def test_team_agent_tasks(delete_agents_and_team_agents): ) team_agent = TeamAgentFactory.create( - name="Teste", + name="Test Multi Agent", agents=[agent], description="Teste", ) response = team_agent.run(data="Translate 'teste'") assert response.status == "SUCCESS" - assert "teste" in response.data["output"] + assert "test" in response.data["output"] def test_team_agent_with_parameterized_agents(run_input_map, delete_agents_and_team_agents): @@ -538,6 +538,7 @@ def test_team_agent_with_slack_connector(): authentication_schema=AuthenticationSchema.BEARER_TOKEN, data={"token": os.getenv("SLACK_TOKEN")}, ) + connection_id = response.data["id"] connection = ModelFactory.get(connection_id) diff --git a/tests/unit/agent/agent_test.py b/tests/unit/agent/agent_test.py index 540fb734..81b6debf 100644 --- a/tests/unit/agent/agent_test.py +++ b/tests/unit/agent/agent_test.py @@ -476,26 +476,34 @@ def test_run_success(): assert response["url"] == ref_response["data"] -def test_run_variable_error(): +def 
test_run_variable_missing(): + """Test that agent runs successfully even when variables are missing from data/parameters.""" agent = Agent( "123", "Test Agent", "Agent description", instructions="Translate the input data into {target_language}", ) - agent = Agent( - "123", - "Test Agent", - "Agent description", - instructions="Translate the input data into {target_language}", - ) - with pytest.raises(Exception) as exc_info: - agent.run_async(data={"query": "Hello, how are you?"}, output_format=OutputFormat.MARKDOWN) - assert str(exc_info.value) == ( - "Variable 'target_language' not found in data or parameters. " - "This variable is required by the agent according to its description " - "('Translate the input data into {target_language}')." - ) + + # Mock the agent URL and response + url = urljoin(config.BACKEND_URL, f"sdk/agents/{agent.id}/run") + agent.url = url + + with requests_mock.Mocker() as mock: + headers = { + "x-api-key": config.AIXPLAIN_API_KEY, + "Content-Type": "application/json", + } + ref_response = {"data": "www.aixplain.com", "status": "IN_PROGRESS"} + mock.post(url, headers=headers, json=ref_response) + + # This should not raise an exception anymore - missing variables are silently ignored + response = agent.run_async(data={"query": "Hello, how are you?"}, output_format=OutputFormat.MARKDOWN) + + # Verify the response is successful + assert isinstance(response, AgentResponse) + assert response["status"] == "IN_PROGRESS" + assert response["url"] == ref_response["data"] def test_process_variables(): diff --git a/tests/unit/team_agent/team_agent_test.py b/tests/unit/team_agent/team_agent_test.py index e63355a8..24098665 100644 --- a/tests/unit/team_agent/team_agent_test.py +++ b/tests/unit/team_agent/team_agent_test.py @@ -741,8 +741,8 @@ def test_save_success(mock_model_factory_get): # Call the save method team_agent.save() - # Assert no warnings were triggered - assert len(w) == 0, f"Warnings were raised: {[str(warning.message) for warning in w]}" + # Assert the correct number of warnings were raised + assert len(w) == 3, f"Warnings were raised: {[str(warning.message) for warning in w]}" assert team_agent.id == ref_response["id"] assert team_agent.name == ref_response["name"]
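
Below is a minimal usage sketch of the API surface this diff migrates to: the `llm` parameter replacing the deprecated `llm_id`, the new `progress_verbosity` option on `run()`, and the pre-defined `VerificationInspector`. It is illustrative only — it assumes a configured `TEAM_API_KEY`, the model IDs shown are the ones referenced in the diff and may differ per account, and the `VerificationInspector` import path follows the docstring example in `inspector.py`.

```python
# Sketch of the post-migration API, under the assumptions stated above.
from aixplain.factories import AgentFactory, TeamAgentFactory
from aixplain.modules.team_agent import VerificationInspector

# New style: pass the LLM as an ID (or LLM instance) via `llm`, not `llm_id`.
agent = AgentFactory.create(
    name="Translator",
    description="Translates user input into English",
    instructions="Translate the input data into English",
    llm="6646261c6eb563165658bbb1",  # illustrative GPT-4o ID from the diff
)

# `progress_verbosity` drives the new inline progress printing:
# "full" (detailed), "compact" (brief, the default), or None (silent).
response = agent.run(data="Hola, ¿cómo estás?", progress_verbosity="full")
print(response.data["output"])

# Team agent using the pre-defined alignment inspector added in this PR.
team = TeamAgentFactory.create(
    name="Translation Team",
    agents=[agent],
    llm="6646261c6eb563165658bbb1",
    inspectors=[VerificationInspector()],
)

# Legacy keywords are still accepted via **kwargs but now emit a
# DeprecationWarning instead of being part of the signature.
legacy_team = TeamAgentFactory.create(
    name="Legacy Team",
    agents=[agent],
    llm_id="669a63646eb56306647e1091",  # deprecated; prefer `llm`
)
```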