diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 2500c2c8..4ad3e5bc 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -8,18 +8,26 @@ repos: pass_filenames: false types: [python] always_run: true - + - repo: https://github.com/psf/black - rev: 22.10.0 + rev: 25.1.0 hooks: - id: black language_version: python3 args: # arguments to configure black - --line-length=128 - + - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v2.0.0 # Use the latest version + rev: v5.0.0 # Use the latest version + hooks: + - id: trailing-whitespace + - id: end-of-file-fixer + - id: check-merge-conflict + - id: check-added-large-files + + - repo: https://github.com/pycqa/flake8 + rev: 7.2.0 hooks: - id: flake8 args: # arguments to configure flake8 - - --ignore=E402,E501,E203,W503 \ No newline at end of file + - --ignore=E402,E501,E203,W503 diff --git a/aixplain/enums/__init__.py b/aixplain/enums/__init__.py index 17308467..5c3a71f5 100644 --- a/aixplain/enums/__init__.py +++ b/aixplain/enums/__init__.py @@ -20,4 +20,5 @@ from .asset_status import AssetStatus from .index_stores import IndexStores from .function_type import FunctionType +from .evolve_type import EvolveType from .code_interpreter import CodeInterpreterModel diff --git a/aixplain/enums/evolve_type.py b/aixplain/enums/evolve_type.py new file mode 100644 index 00000000..555fdb53 --- /dev/null +++ b/aixplain/enums/evolve_type.py @@ -0,0 +1,6 @@ +from enum import Enum + + +class EvolveType(str, Enum): + TEAM_TUNING = "team_tuning" + INSTRUCTION_TUNING = "instruction_tuning" diff --git a/aixplain/factories/agent_factory/__init__.py b/aixplain/factories/agent_factory/__init__.py index 95037f65..0caf0814 100644 --- a/aixplain/factories/agent_factory/__init__.py +++ b/aixplain/factories/agent_factory/__init__.py @@ -28,7 +28,7 @@ from aixplain.enums.function import Function from aixplain.enums.supplier import Supplier -from aixplain.modules.agent import Agent, AgentTask, Tool +from 
aixplain.modules.agent import Agent, Tool, WorkflowTask from aixplain.modules.agent.output_format import OutputFormat from aixplain.modules.agent.tool.model_tool import ModelTool from aixplain.modules.agent.tool.pipeline_tool import PipelineTool @@ -67,7 +67,8 @@ def create( api_key: Text = config.TEAM_API_KEY, supplier: Union[Dict, Text, Supplier, int] = "aiXplain", version: Optional[Text] = None, - tasks: List[AgentTask] = [], + tasks: List[WorkflowTask] = None, + workflow_tasks: List[WorkflowTask] = [], output_format: Optional[OutputFormat] = None, expected_output: Optional[Union[BaseModel, Text, dict]] = None, ) -> Agent: @@ -80,7 +81,7 @@ def create( Args: name (Text): name of the agent - description (Text): description of the agent role. + description (Text): description of the agent instructions. instructions (Text): instructions of the agent. llm (Optional[Union[LLM, Text]], optional): LLM instance to use as an object or as an ID. llm_id (Optional[Text], optional): ID of LLM to use if no LLM instance provided. Defaults to None. @@ -88,7 +89,7 @@ def create( api_key (Text, optional): team/user API key. Defaults to config.TEAM_API_KEY. supplier (Union[Dict, Text, Supplier, int], optional): owner of the agent. Defaults to "aiXplain". version (Optional[Text], optional): version of the agent. Defaults to None. - tasks (List[AgentTask], optional): list of tasks for the agent. Defaults to []. + workflow_tasks (List[WorkflowTask], optional): list of tasks for the agent. Defaults to []. output_format (OutputFormat, optional): default output format for agent responses. Defaults to OutputFormat.TEXT. expected_output (Union[BaseModel, Text, dict], optional): expected output. Defaults to None. Returns: @@ -127,6 +128,16 @@ def create( elif isinstance(supplier, Supplier): supplier = supplier.value["code"] + if tasks is not None: + warnings.warn( + "The 'tasks' parameter is deprecated and will be removed in a future version. 
" "Use 'workflow_tasks' instead.", + DeprecationWarning, + stacklevel=2, + ) + workflow_tasks = tasks if workflow_tasks is None or workflow_tasks == [] else workflow_tasks + + workflow_tasks = workflow_tasks or [] + payload = { "name": name, "assets": [build_tool_payload(tool) for tool in tools], @@ -136,7 +147,7 @@ def create( "version": version, "llmId": llm_id, "status": "draft", - "tasks": [task.to_dict() for task in tasks], + "tasks": [task.to_dict() for task in workflow_tasks], "tools": [], } @@ -206,32 +217,30 @@ def create_from_dict(cls, dict: Dict) -> Agent: return agent @classmethod - def create_task( + def create_workflow_task( cls, name: Text, description: Text, expected_output: Text, - dependencies: Optional[List[Text]] = None, - ) -> AgentTask: - """Create a new task for an agent. - - Args: - name (Text): Name of the task. - description (Text): Description of what the task should accomplish. - expected_output (Text): Description of the expected output format. - dependencies (Optional[List[Text]], optional): List of task names that must - complete before this task can start. Defaults to None. - - Returns: - AgentTask: Created task object. - """ - return AgentTask( + dependencies: Optional[List[Text]] = [], + ) -> WorkflowTask: + return WorkflowTask( name=name, description=description, expected_output=expected_output, dependencies=dependencies, ) + @classmethod + def create_task(cls, *args, **kwargs): + warnings.warn( + "The 'create_task' method is deprecated and will be removed in a future version. 
" + "Use 'create_workflow_task' instead.", + DeprecationWarning, + stacklevel=2, + ) + return cls.create_workflow_task(*args, **kwargs) + @classmethod def create_model_tool( cls, diff --git a/aixplain/factories/model_factory/utils.py b/aixplain/factories/model_factory/utils.py index 2cd049ee..a831a7c7 100644 --- a/aixplain/factories/model_factory/utils.py +++ b/aixplain/factories/model_factory/utils.py @@ -58,10 +58,14 @@ def create_model_from_response(response: Dict) -> Model: additional_kwargs = {} attributes = response.get("attributes", None) if attributes: - embedding_model = next((item["code"] for item in attributes if item["name"] == "embeddingmodel"), None) + embedding_model = next( + (item.get("code") for item in attributes if item.get("name") == "embeddingmodel" and "code" in item), None + ) if embedding_model: additional_kwargs["embedding_model"] = embedding_model - embedding_size = next((item["value"] for item in attributes if item["name"] == "embeddingSize"), None) + embedding_size = next( + (item.get("value") for item in attributes if item.get("name") == "embeddingSize" and "value" in item), None + ) if embedding_size: additional_kwargs["embedding_size"] = embedding_size @@ -87,7 +91,7 @@ def create_model_from_response(response: Dict) -> Model: ModelClass = IndexModel elif function_type == FunctionType.INTEGRATION: ModelClass = Integration - elif function_type == FunctionType.CONNECTION : + elif function_type == FunctionType.CONNECTION: ModelClass = ConnectionTool elif function_type == FunctionType.MCP_CONNECTION: ModelClass = MCPConnection @@ -297,11 +301,52 @@ def get_model_from_ids(model_ids: List[str], api_key: Optional[str] = None) -> L raise Exception(f"{message}") if 200 <= r.status_code < 300: models = [] - for item in resp["items"]: - item["api_key"] = config.TEAM_API_KEY - if api_key is not None: - item["api_key"] = api_key - models.append(create_model_from_response(item)) + + def process_items(items): + """Helper function to process model 
items and add API key""" + for item in items: + item["api_key"] = config.TEAM_API_KEY + if api_key is not None: + item["api_key"] = api_key + models.append(create_model_from_response(item)) + + # Check if pagination is needed ( pageNumber: 0 indicates pagination required) + if "pageTotal" in resp: + # Handle paginated response - need to fetch all pages + page_number = 1 + total_fetched = resp.get("pageTotal", 0) + page_items = resp.get("items", []) + total_items = resp.get("total", 0) + process_items(page_items) + + while True: + # Make request for current page + paginated_url = urljoin(config.BACKEND_URL, f"sdk/models?ids={','.join(model_ids)}&pageNumber={page_number}") + logging.info(f"Fetching page {page_number} - {paginated_url}") + page_r = _request_with_retry("get", paginated_url, headers=headers) + page_resp = page_r.json() + + if not (200 <= page_r.status_code < 300): + error_message = f"Model GET Error: Failed to retrieve models page {page_number}. Status Code: {page_r.status_code}. Error: {page_resp}" + logging.error(error_message) + raise Exception(error_message) + + # Process items from current page + page_items = page_resp.get("items", []) + if not page_items: + break + + process_items(page_items) + total_fetched += len(page_items) + + if total_fetched >= total_items: + break + + page_number += 1 + else: + # Handle non-paginated response (original logic) + process_items(resp["items"]) + return models else: error_message = f"Model GET Error: Failed to retrieve models {model_ids}. Status Code: {r.status_code}. 
Error: {resp}" diff --git a/aixplain/factories/team_agent_factory/__init__.py b/aixplain/factories/team_agent_factory/__init__.py index 580c28c7..ba8d1119 100644 --- a/aixplain/factories/team_agent_factory/__init__.py +++ b/aixplain/factories/team_agent_factory/__init__.py @@ -145,13 +145,13 @@ def _get_llm_safely(llm_id: str, llm_type: str) -> LLM: try: return get_llm_instance(llm_id, api_key=api_key) except Exception: - raise Exception(f"TeamAgent Onboarding Error: LLM {llm_id} does not exist for {llm_type}. To resolve this, set the following LLM parameters to a valid LLM object or LLM ID: llm, supervisor_llm, mentalist_llm.") + raise Exception( + f"TeamAgent Onboarding Error: LLM {llm_id} does not exist for {llm_type}. To resolve this, set the following LLM parameters to a valid LLM object or LLM ID: llm, supervisor_llm, mentalist_llm." + ) - def _setup_llm_and_tool(llm_param: Optional[Union[LLM, Text]], - default_id: Text, - llm_type: str, - description: str, - tools: List[Dict]) -> LLM: + def _setup_llm_and_tool( + llm_param: Optional[Union[LLM, Text]], default_id: Text, llm_type: str, description: str, tools: List[Dict] + ) -> LLM: """Helper to set up an LLM and add its tool configuration.""" llm_instance = None # Set up LLM @@ -159,21 +159,25 @@ def _setup_llm_and_tool(llm_param: Optional[Union[LLM, Text]], llm_instance = _get_llm_safely(default_id, llm_type) else: llm_instance = _get_llm_safely(llm_param, llm_type) if isinstance(llm_param, str) else llm_param - + # Add tool configuration if llm_instance is not None: - tools.append({ - "type": "llm", - "description": description, - "parameters": llm_instance.get_parameters().to_list() if llm_instance.get_parameters() else None, - }) + tools.append( + { + "type": "llm", + "description": description, + "parameters": llm_instance.get_parameters().to_list() if llm_instance.get_parameters() else None, + } + ) return llm_instance, tools # Set up LLMs and their tools tools = [] llm, tools = _setup_llm_and_tool(llm, 
llm_id, "Main LLM", "main", tools) supervisor_llm, tools = _setup_llm_and_tool(supervisor_llm, llm_id, "Supervisor LLM", "supervisor", tools) - mentalist_llm, tools = _setup_llm_and_tool(mentalist_llm, llm_id, "Mentalist LLM", "mentalist", tools) if use_mentalist else None + mentalist_llm, tools = ( + _setup_llm_and_tool(mentalist_llm, llm_id, "Mentalist LLM", "mentalist", tools) if use_mentalist else (None, []) + ) team_agent = None url = urljoin(config.BACKEND_URL, "sdk/agent-communities") @@ -201,9 +205,8 @@ def _setup_llm_and_tool(llm_param: Optional[Union[LLM, Text]], "supplier": supplier, "version": version, "status": "draft", - "tools": [], - "instructions": instructions, "tools": tools, + "instructions": instructions, } # Store the LLM objects directly in the payload for build_team_agent internal_payload = payload.copy() diff --git a/aixplain/factories/team_agent_factory/inspector_factory.py b/aixplain/factories/team_agent_factory/inspector_factory.py index a6d38d4f..ce6c4927 100644 --- a/aixplain/factories/team_agent_factory/inspector_factory.py +++ b/aixplain/factories/team_agent_factory/inspector_factory.py @@ -4,6 +4,8 @@ and monitor team agent operations. Inspectors can be created from existing models or using automatic configurations. +WARNING: This feature is currently in private beta. 
+ Example: Create an inspector from a model with adaptive policy:: @@ -19,7 +21,7 @@ """ import logging -from typing import Dict, Optional, Text, Union +from typing import Dict, Optional, Text, Union, Callable from urllib.parse import urljoin from aixplain.enums.asset_status import AssetStatus @@ -31,7 +33,7 @@ from aixplain.utils.file_utils import _request_with_retry -INSPECTOR_SUPPORTED_FUNCTIONS = [Function.GUARDRAILS, Function.TEXT_GENERATION] +INSPECTOR_SUPPORTED_FUNCTIONS = [Function.GUARDRAILS, Function.TEXT_GENERATION, Function.UTILITIES] class InspectorFactory: @@ -48,7 +50,7 @@ def create_from_model( name: Text, model: Union[Text, Model], model_config: Optional[Dict] = None, - policy: InspectorPolicy = InspectorPolicy.ADAPTIVE, # default: doing something dynamically + policy: Union[InspectorPolicy, Callable] = InspectorPolicy.ADAPTIVE, # default: doing something dynamically ) -> Inspector: """Create a new inspector agent from an onboarded model. @@ -62,11 +64,10 @@ def create_from_model( to use for the inspector. model_config (Optional[Dict], optional): Configuration parameters for the inspector model (e.g., prompts, thresholds). Defaults to None. - policy (InspectorPolicy, optional): Action to take upon negative feedback: - - WARN: Log warning but continue execution - - ABORT: Stop execution on negative feedback - - ADAPTIVE: Dynamically decide based on context - Defaults to InspectorPolicy.ADAPTIVE. + policy: Action to take upon negative feedback (WARN/ABORT/ADAPTIVE) + or a callable function. If callable, must have name "process_response", + arguments "model_response" and "input_content" (both strings), and + return InspectorAction. Defaults to ADAPTIVE. Returns: Inspector: Created and configured inspector agent. 
@@ -124,7 +125,7 @@ def create_auto( cls, auto: InspectorAuto, name: Optional[Text] = None, - policy: InspectorPolicy = InspectorPolicy.ADAPTIVE, + policy: Union[InspectorPolicy, Callable] = InspectorPolicy.ADAPTIVE, ) -> Inspector: """Create a new inspector agent using automatic configuration. @@ -136,11 +137,10 @@ def create_auto( auto (InspectorAuto): Pre-configured automatic inspector instance. name (Optional[Text], optional): Name for the inspector. If not provided, uses the name from the auto configuration. Defaults to None. - policy (InspectorPolicy, optional): Action to take upon negative feedback: - - WARN: Log warning but continue execution - - ABORT: Stop execution on negative feedback - - ADAPTIVE: Dynamically decide based on context - Defaults to InspectorPolicy.ADAPTIVE. + policy: Action to take upon negative feedback (WARN/ABORT/ADAPTIVE) + or a callable function. If callable, must have name "process_response", + arguments "model_response" and "input_content" (both strings), and + return InspectorAction. Defaults to ADAPTIVE. 
Returns: Inspector: Created and configured inspector agent using automatic diff --git a/aixplain/factories/team_agent_factory/utils.py b/aixplain/factories/team_agent_factory/utils.py index 08c7e93f..4a0eefc1 100644 --- a/aixplain/factories/team_agent_factory/utils.py +++ b/aixplain/factories/team_agent_factory/utils.py @@ -1,12 +1,14 @@ __author__ = "lucaspavanelli" import logging -from typing import Dict, Text, List +from typing import Dict, Text, List, Optional from urllib.parse import urljoin import aixplain.utils.config as config from aixplain.enums.asset_status import AssetStatus from aixplain.modules.agent import Agent +from aixplain.modules.agent.agent_task import AgentTask +from aixplain.modules.agent.tool.model_tool import ModelTool from aixplain.modules.team_agent import TeamAgent, InspectorTarget from aixplain.modules.team_agent.inspector import Inspector from aixplain.factories.agent_factory import AgentFactory @@ -15,6 +17,7 @@ from aixplain.modules.agent.output_format import OutputFormat GPT_4o_ID = "6646261c6eb563165658bbb1" +SUPPORTED_TOOLS = ["llm", "website_search", "website_scrape", "website_crawl", "serper_search"] def build_team_agent(payload: Dict, agents: List[Agent] = None, api_key: Text = config.TEAM_API_KEY) -> TeamAgent: @@ -29,7 +32,7 @@ def build_team_agent(payload: Dict, agents: List[Agent] = None, api_key: Text = - name: Team agent name - agents: List of agent configurations - description: Optional description - - role: Optional instructions + - instructions: Optional instructions - teamId: Optional supplier information - version: Optional version - cost: Optional cost information @@ -67,9 +70,21 @@ def build_team_agent(payload: Dict, agents: List[Agent] = None, api_key: Text = continue # Ensure custom classes are instantiated: for compatibility with backend return format - inspectors = [ - inspector if isinstance(inspector, Inspector) else Inspector(**inspector) for inspector in payload.get("inspectors", []) - ] + inspectors = [] + 
for inspector_data in payload.get("inspectors", []): + try: + if isinstance(inspector_data, Inspector): + inspectors.append(inspector_data) + else: + # Handle both old format and new format with policy_type + if hasattr(Inspector, "model_validate"): + inspectors.append(Inspector.model_validate(inspector_data)) + else: + inspectors.append(Inspector(**inspector_data)) + except Exception as e: + logging.warning(f"Failed to create inspector from data: {e}") + continue + inspector_targets = [InspectorTarget(target.lower()) for target in payload.get("inspectorTargets", [])] # Get LLMs from tools if present @@ -100,7 +115,10 @@ def build_team_agent(payload: Dict, agents: List[Agent] = None, api_key: Text = # Convert parameters list to dictionary format expected by ModelParameters params_dict = {} for param in tool["parameters"]: - params_dict[param["name"]] = {"required": False, "value": param["value"]} + params_dict[param["name"]] = { + "required": False, + "value": param["value"], + } # Create ModelParameters and set it on the LLM llm.model_params = ModelParameters(params_dict) @@ -146,5 +164,165 @@ def build_team_agent(payload: Dict, agents: List[Agent] = None, api_key: Text = if task_dependency: team_agent.agents[idx].tasks[i].dependencies[j] = task_dependency else: - raise Exception(f"Team Agent Creation Error: Task dependency not found - {dependency}") + team_agent.agents[idx].tasks[i].dependencies[j] = None + return team_agent + + +def parse_tool_from_yaml(tool: str) -> ModelTool: + from aixplain.enums import Function + + tool_name = tool.strip() + if tool_name == "translation": + return ModelTool( + function=Function.TRANSLATION, + ) + elif tool_name == "speech-recognition": + return ModelTool( + function=Function.SPEECH_RECOGNITION, + ) + elif tool_name == "text-to-speech": + return ModelTool( + function=Function.SPEECH_SYNTHESIS, + ) + elif tool_name == "llm": + return ModelTool(function=Function.TEXT_GENERATION) + elif tool_name == "serper_search": + return 
ModelTool(model="65c51c556eb563350f6e1bb1") + elif tool.strip() == "website_search": + return ModelTool(model="6736411cf127849667606689") + elif tool.strip() == "website_scrape": + return ModelTool(model="6748e4746eb5633559668a15") + elif tool.strip() == "website_crawl": + return ModelTool(model="6748d4cff12784b6014324e2") + else: + raise Exception(f"Tool {tool} in yaml not found.") + + +import yaml + + +def is_yaml_formatted(text): + """ + Check if a string is valid YAML format with additional validation. + + Args: + text (str): The string to check + + Returns: + bool: True if valid YAML, False otherwise + """ + if not text or not isinstance(text, str): + return False + + # Strip whitespace + text = text.strip() + + # Empty string is valid YAML + if not text: + return True + + try: + parsed = yaml.safe_load(text) + + # If it's just a plain string without YAML structure, + # we might want to consider it as non-YAML + # This is optional depending on your requirements + if isinstance(parsed, str) and "\n" not in text and ":" not in text: + return False + + return True + except yaml.YAMLError: + return False + + +def build_team_agent_from_yaml(yaml_code: str, llm_id: str, api_key: str, team_id: Optional[str] = None) -> TeamAgent: + import yaml + from aixplain.factories import AgentFactory, TeamAgentFactory + + # check if it is a yaml or just as string + if not is_yaml_formatted(yaml_code): + return None + team_config = yaml.safe_load(yaml_code) + + agents_data = team_config.get("agents", []) + tasks_data = team_config.get("tasks", []) + system_data = team_config.get("system", {"query": "", "name": "Test Team"}) + team_name = system_data.get("name", "") + team_description = system_data.get("description", "") + team_instructions = system_data.get("instructions", "") + llm = ModelFactory.get(llm_id) + # Create agent mapping by name for easier task assignment + agents_mapping = {} + agent_objs = [] + + # Parse agents + for agent_entry in agents_data: + agent_name = 
list(agent_entry.keys())[0] + agent_info = agent_entry[agent_name] + agent_instructions = agent_info.get("instructions", "") + agent_description = agent_info["description"] + agent_name = agent_name.replace("_", " ") + agent_name = f"{agent_name} agent" if not agent_name.endswith(" agent") else agent_name + agent_obj = Agent( + id="", + name=agent_name, + description=agent_description, + instructions=agent_instructions, + tasks=[], # Tasks will be assigned later + tools=[parse_tool_from_yaml(tool) for tool in agent_info.get("tools", []) if tool in SUPPORTED_TOOLS], + llm=llm, + ) + agents_mapping[agent_name] = agent_obj + agent_objs.append(agent_obj) + + # Create task collections for each agent (clean approach) + agent_tasks = {agent_name: [] for agent_name in agents_mapping.keys()} + + # Parse tasks and collect them by agent + for task in tasks_data: + for task_name, task_info in task.items(): + task_description = task_info.get("description", "") + expected_output = task_info.get("expected_output", "") + dependencies = task_info.get("dependencies", []) + agent_name = task_info.get("agent", "") + agent_name = agent_name.replace("_", " ") + agent_name = f"{agent_name} agent" if not agent_name.endswith(" agent") else agent_name + + task_obj = AgentTask( + name=task_name, + description=task_description, + expected_output=expected_output, + dependencies=dependencies, + ) + + # Add task to the corresponding agent's collection + if agent_name in agent_tasks: + # Check for duplicates within this build + existing_task_names = [task.name for task in agent_tasks[agent_name]] + if task_name not in existing_task_names: + agent_tasks[agent_name].append(task_obj) + else: + raise Exception(f"Agent '{agent_name}' referenced in tasks not found.") + + # Create agents with their respective task collections + for i, agent in enumerate(agent_objs): + agent_name = agent.name + agent_objs[i] = AgentFactory.create( + name=agent.name, + description=agent.description, + 
instructions=agent.instructions, + tools=agent.tools, + llm=llm, + tasks=agent_tasks.get(agent_name, []), # Use collected tasks + ) + return TeamAgentFactory.create( + name=team_name, + description=team_description, + instructions=team_instructions, + agents=agent_objs, + llm=llm, + api_key=api_key, + use_mentalist=True, + inspectors=[], + ) diff --git a/aixplain/modules/agent/__init__.py b/aixplain/modules/agent/__init__.py index 3ed4b04a..68007914 100644 --- a/aixplain/modules/agent/__init__.py +++ b/aixplain/modules/agent/__init__.py @@ -29,20 +29,23 @@ from aixplain.utils.file_utils import _request_with_retry from aixplain.enums import Function, Supplier, AssetStatus, StorageType, ResponseStatus +from aixplain.enums.evolve_type import EvolveType from aixplain.modules.model import Model -from aixplain.modules.agent.agent_task import AgentTask +from aixplain.modules.agent.agent_task import WorkflowTask, AgentTask from aixplain.modules.agent.output_format import OutputFormat from aixplain.modules.agent.tool import Tool from aixplain.modules.agent.agent_response import AgentResponse from aixplain.modules.agent.agent_response_data import AgentResponseData from aixplain.modules.agent.utils import process_variables, validate_history from pydantic import BaseModel -from typing import Dict, List, Text, Optional, Union +from typing import Dict, List, Text, Optional, Union, Any +from aixplain.modules.agent.evolve_param import EvolveParam, validate_evolve_param from urllib.parse import urljoin from aixplain.modules.model.llm_model import LLM from aixplain.utils import config from aixplain.modules.mixins import DeployableMixin +import warnings class Agent(Model, DeployableMixin[Tool]): @@ -92,6 +95,7 @@ def __init__( cost: Optional[Dict] = None, status: AssetStatus = AssetStatus.DRAFT, tasks: List[AgentTask] = [], + workflow_tasks: List[WorkflowTask] = [], output_format: OutputFormat = OutputFormat.TEXT, expected_output: Optional[Union[BaseModel, Text, dict]] = None, 
**additional_info, @@ -138,7 +142,16 @@ def __init__( except Exception: status = AssetStatus.DRAFT self.status = status - self.tasks = tasks + if tasks: + warnings.warn( + "The 'tasks' parameter is deprecated and will be removed in a future version. " "Use 'workflow_tasks' instead.", + DeprecationWarning, + stacklevel=2, + ) + self.workflow_tasks = tasks + else: + self.workflow_tasks = workflow_tasks + self.tasks = self.workflow_tasks self.output_format = output_format self.expected_output = expected_output self.is_valid = True @@ -373,6 +386,7 @@ def run_async( max_iterations: int = 10, output_format: Optional[OutputFormat] = None, expected_output: Optional[Union[BaseModel, Text, dict]] = None, + evolve: Union[Dict[str, Any], EvolveParam, None] = None, ) -> AgentResponse: """Runs asynchronously an agent call. @@ -388,6 +402,8 @@ def run_async( max_iterations (int, optional): maximum number of iterations between the agent and the tools. Defaults to 10. output_format (OutputFormat, optional): response format. If not provided, uses the format set during initialization. expected_output (Union[BaseModel, Text, dict], optional): expected output. Defaults to None. + output_format (ResponseFormat, optional): response format. Defaults to TEXT. + evolve (Union[Dict[str, Any], EvolveParam, None], optional): evolve the agent configuration. Can be a dictionary, EvolveParam instance, or None. Returns: dict: polling URL in response """ @@ -404,8 +420,9 @@ def run_async( from aixplain.factories.file_factory import FileFactory - if not self.is_valid: - raise Exception("Agent is not valid. 
Please validate the agent before running.") + # Validate and normalize evolve parameters using the base model + evolve_param = validate_evolve_param(evolve) + evolve_dict = evolve_param.to_dict() if output_format == OutputFormat.JSON: assert expected_output is not None and ( @@ -467,6 +484,7 @@ def run_async( "outputFormat": output_format, "expectedOutput": expected_output, }, + "evolve": json.dumps(evolve_dict), } payload.update(parameters) payload = json.dumps(payload) @@ -503,7 +521,7 @@ def to_dict(self) -> Dict: "version": self.version, "llmId": self.llm_id if self.llm is None else self.llm.id, "status": self.status.value, - "tasks": [task.to_dict() for task in self.tasks], + "tasks": [task.to_dict() for task in self.workflow_tasks], "tools": ( [ { @@ -533,7 +551,7 @@ def from_dict(cls, data: Dict) -> "Agent": """ from aixplain.factories.agent_factory.utils import build_tool from aixplain.enums import AssetStatus - from aixplain.modules.agent_task import AgentTask + from aixplain.modules.agent import WorkflowTask # Extract tools from assets using proper tool building tools = [] @@ -549,10 +567,10 @@ def from_dict(cls, data: Dict) -> "Agent": logging.warning(f"Failed to build tool from asset data: {e}") # Extract tasks using from_dict method - tasks = [] + workflow_tasks = [] if "tasks" in data: for task_data in data["tasks"]: - tasks.append(AgentTask.from_dict(task_data)) + workflow_tasks.append(WorkflowTask.from_dict(task_data)) # Extract LLM from tools section (main LLM info) llm = None @@ -583,7 +601,7 @@ def from_dict(cls, data: Dict) -> "Agent": id=data["id"], name=data["name"], description=data["description"], - instructions=data.get("role"), + instructions=data.get("instructions"), tools=tools, llm_id=data.get("llmId", "6646261c6eb563165658bbb1"), llm=llm, @@ -592,7 +610,7 @@ def from_dict(cls, data: Dict) -> "Agent": version=data.get("version"), cost=data.get("cost"), status=status, - tasks=tasks, + workflow_tasks=workflow_tasks, 
output_format=OutputFormat(data.get("outputFormat", OutputFormat.TEXT)), expected_output=data.get("expectedOutput"), ) @@ -725,3 +743,131 @@ def __repr__(self) -> str: str: A string in the format "Agent: (id=)". """ return f"Agent: {self.name} (id={self.id})" + + def evolve_async( + self, + evolve_type: Union[EvolveType, str] = EvolveType.TEAM_TUNING, + max_successful_generations: int = 3, + max_failed_generation_retries: int = 3, + max_iterations: int = 50, + max_non_improving_generations: Optional[int] = 2, + llm: Optional[Union[Text, LLM]] = None, + ) -> AgentResponse: + """Asynchronously evolve the Agent and return a polling URL in the AgentResponse. + + Args: + evolve_type (Union[EvolveType, str]): Type of evolution (TEAM_TUNING or INSTRUCTION_TUNING). Defaults to TEAM_TUNING. + max_successful_generations (int): Maximum number of successful generations to evolve. Defaults to 3. + max_failed_generation_retries (int): Maximum retry attempts for failed generations. Defaults to 3. + max_iterations (int): Maximum number of iterations. Defaults to 50. + max_non_improving_generations (Optional[int]): Stop condition parameter for non-improving generations. Defaults to 2, can be None. + llm (Optional[Union[Text, LLM]]): LLM to use for evolution. Can be an LLM ID string or LLM object. Defaults to None. + + Returns: + AgentResponse: Response containing polling URL and status. 
+ """ + from aixplain.utils.evolve_utils import create_llm_dict + + query = "" + + # Create EvolveParam from individual parameters + evolve_parameters = EvolveParam( + to_evolve=True, + evolve_type=evolve_type, + max_successful_generations=max_successful_generations, + max_failed_generation_retries=max_failed_generation_retries, + max_iterations=max_iterations, + max_non_improving_generations=max_non_improving_generations, + llm=create_llm_dict(llm), + ) + + return self.run_async(query=query, evolve=evolve_parameters) + + def evolve( + self, + evolve_type: Union[EvolveType, str] = EvolveType.TEAM_TUNING, + max_successful_generations: int = 3, + max_failed_generation_retries: int = 3, + max_iterations: int = 50, + max_non_improving_generations: Optional[int] = 2, + llm: Optional[Union[Text, LLM]] = None, + ) -> AgentResponse: + """Synchronously evolve the Agent and poll for the result. + + Args: + evolve_type (Union[EvolveType, str]): Type of evolution (TEAM_TUNING or INSTRUCTION_TUNING). Defaults to TEAM_TUNING. + max_successful_generations (int): Maximum number of successful generations to evolve. Defaults to 3. + max_failed_generation_retries (int): Maximum retry attempts for failed generations. Defaults to 3. + max_iterations (int): Maximum number of iterations. Defaults to 50. + max_non_improving_generations (Optional[int]): Stop condition parameter for non-improving generations. Defaults to 2, can be None. + llm (Optional[Union[Text, LLM]]): LLM to use for evolution. Can be an LLM ID string or LLM object. Defaults to None. + + Returns: + AgentResponse: Final response from the evolution process. 
+ """ + from aixplain.utils.evolve_utils import create_llm_dict + from aixplain.factories.team_agent_factory.utils import build_team_agent_from_yaml + + # Create EvolveParam from individual parameters + evolve_parameters = EvolveParam( + to_evolve=True, + evolve_type=evolve_type, + max_successful_generations=max_successful_generations, + max_failed_generation_retries=max_failed_generation_retries, + max_iterations=max_iterations, + max_non_improving_generations=max_non_improving_generations, + llm=create_llm_dict(llm), + ) + + start = time.time() + try: + logging.info(f"Evolve started with parameters: {evolve_parameters}") + logging.info("It might take a while...") + response = self.evolve_async( + evolve_type=evolve_type, + max_successful_generations=max_successful_generations, + max_failed_generation_retries=max_failed_generation_retries, + max_iterations=max_iterations, + max_non_improving_generations=max_non_improving_generations, + llm=llm, + ) + if response["status"] == ResponseStatus.FAILED: + end = time.time() + response["elapsed_time"] = end - start + return response + poll_url = response["url"] + end = time.time() + result = self.sync_poll(poll_url, name="evolve_process", timeout=600) + result_data = result.data + + if "current_code" in result_data and result_data["current_code"] is not None: + if evolve_parameters.evolve_type == EvolveType.TEAM_TUNING: + result_data["evolved_agent"] = build_team_agent_from_yaml( + result_data["current_code"], + self.llm_id, + self.api_key, + self.id, + ) + elif evolve_parameters.evolve_type == EvolveType.INSTRUCTION_TUNING: + self.instructions = result_data["current_code"] + self.update() + result_data["evolved_agent"] = self + else: + raise ValueError( + "evolve_parameters.evolve_type must be one of the following: TEAM_TUNING, INSTRUCTION_TUNING" + ) + return AgentResponse( + status=ResponseStatus.SUCCESS, + completed=True, + data=result_data, + used_credits=getattr(result, "used_credits", 0.0), + 
run_time=getattr(result, "run_time", end - start), + ) + except Exception as e: + logging.error(f"Agent Evolve: Error in evolving: {e}") + end = time.time() + return AgentResponse( + status=ResponseStatus.FAILED, + completed=False, + error_message="No response from the service.", + ) diff --git a/aixplain/modules/agent/agent_response.py b/aixplain/modules/agent/agent_response.py index fc35e072..d10cff29 100644 --- a/aixplain/modules/agent/agent_response.py +++ b/aixplain/modules/agent/agent_response.py @@ -1,8 +1,11 @@ from aixplain.enums import ResponseStatus -from typing import Any, Dict, Optional, Text, Union, List +from typing import Any, Dict, Optional, Text, Union, List, TYPE_CHECKING from aixplain.modules.agent.agent_response_data import AgentResponseData from aixplain.modules.model.response import ModelResponse +if TYPE_CHECKING: + from aixplain.modules.team_agent.evolver_response_data import EvolverResponseData + class AgentResponse(ModelResponse): """A response object for agent execution results. @@ -25,7 +28,7 @@ class AgentResponse(ModelResponse): def __init__( self, status: ResponseStatus = ResponseStatus.FAILED, - data: Optional[AgentResponseData] = None, + data: Optional[Union[AgentResponseData, "EvolverResponseData"]] = None, details: Optional[Union[Dict, List]] = {}, completed: bool = False, error_message: Text = "", diff --git a/aixplain/modules/agent/agent_response_data.py b/aixplain/modules/agent/agent_response_data.py index 4de10d12..b97b16c1 100644 --- a/aixplain/modules/agent/agent_response_data.py +++ b/aixplain/modules/agent/agent_response_data.py @@ -95,6 +95,9 @@ def to_dict(self) -> Dict[str, Any]: "critiques": self.critiques, } + def get(self, key: str, default: Optional[Any] = None) -> Any: + return getattr(self, key, default) + def __getitem__(self, key: str) -> Any: """Get an attribute value using dictionary-style access. 
diff --git a/aixplain/modules/agent/agent_task.py b/aixplain/modules/agent/agent_task.py index 469beba1..433d58c0 100644 --- a/aixplain/modules/agent/agent_task.py +++ b/aixplain/modules/agent/agent_task.py @@ -1,7 +1,7 @@ -from typing import List, Optional, Text, Union +from typing import List, Text, Union, Optional -class AgentTask: +class WorkflowTask: """A task definition for an AI agent to execute. This class represents a task that can be assigned to an agent, including its @@ -11,23 +11,24 @@ class AgentTask: name (Text): The unique identifier/name of the task. description (Text): Detailed description of what the task should accomplish. expected_output (Text): Description of the expected output format or content. - dependencies (Optional[List[Union[Text, AgentTask]]]): List of tasks or task + dependencies (Optional[List[Union[Text, WorkflowTask]]]): List of tasks or task names that must be completed before this task. Defaults to None. """ + def __init__( self, name: Text, description: Text, expected_output: Text, - dependencies: Optional[List[Union[Text, "AgentTask"]]] = None, + dependencies: Optional[List[Union[Text, "WorkflowTask"]]] = [], ): - """Initialize a new AgentTask instance. + """Initialize a new WorkflowTask instance. Args: name (Text): The unique identifier/name of the task. description (Text): Detailed description of what the task should accomplish. expected_output (Text): Description of the expected output format or content. - dependencies (Optional[List[Union[Text, AgentTask]]], optional): List of + dependencies (Optional[List[Union[Text, WorkflowTask]]], optional): List of tasks or task names that must be completed before this task. Defaults to None. """ @@ -39,7 +40,7 @@ def __init__( def to_dict(self) -> dict: """Convert the task to a dictionary representation. - This method serializes the task data, converting any AgentTask dependencies + This method serializes the task data, converting any WorkflowTask dependencies to their name strings. 
Returns: @@ -49,7 +50,7 @@ def to_dict(self) -> dict: - expectedOutput: The expected output description - dependencies: List of dependency names or None """ - agent_task_dict = { + workflow_task_dict = { "name": self.name, "description": self.description, "expectedOutput": self.expected_output, @@ -57,20 +58,20 @@ def to_dict(self) -> dict: } if self.dependencies: - for i, dependency in enumerate(agent_task_dict["dependencies"]): - if isinstance(dependency, AgentTask): - agent_task_dict["dependencies"][i] = dependency.name - return agent_task_dict + for i, dependency in enumerate(workflow_task_dict["dependencies"]): + if isinstance(dependency, WorkflowTask): + workflow_task_dict["dependencies"][i] = dependency.name + return workflow_task_dict @classmethod - def from_dict(cls, data: dict) -> "AgentTask": - """Create an AgentTask instance from a dictionary representation. + def from_dict(cls, data: dict) -> "WorkflowTask": + """Create a WorkflowTask instance from a dictionary representation. 
Args: - data: Dictionary containing AgentTask parameters + data: Dictionary containing WorkflowTask parameters Returns: - AgentTask instance + WorkflowTask instance """ return cls( name=data["name"], @@ -78,3 +79,17 @@ def from_dict(cls, data: dict) -> "AgentTask": expected_output=data["expectedOutput"], dependencies=data.get("dependencies", None), ) + + +# NOTE: AgentTask is kept only for backward compatibility with the old class name. +# It will be removed in a future release. +class AgentTask(WorkflowTask): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def to_dict(self): + return super().to_dict() + + @classmethod + def from_dict(cls, data: dict) -> "AgentTask": + return super().from_dict(data) diff --git a/aixplain/modules/agent/evolve_param.py b/aixplain/modules/agent/evolve_param.py new file mode 100644 index 00000000..1693293e --- /dev/null +++ b/aixplain/modules/agent/evolve_param.py @@ -0,0 +1,269 @@ +__author__ = "aiXplain" + +""" +Copyright 2024 The aiXplain SDK authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +Author: aiXplain Team +Date: December 2024 +Description: + EvolveParam Base Model Class for Agent and TeamAgent evolve functionality +""" +from aixplain.enums import EvolveType +from dataclasses import dataclass, field +from typing import Any, Dict, Optional, Union + + +@dataclass +class EvolveParam: + """Base model for evolve parameters used in Agent and TeamAgent evolution. + + Attributes: + to_evolve (bool): Whether to enable evolution. Defaults to False. 
+ evolve_type (Optional[EvolveType]): Type of evolve. + max_successful_generations (int): Maximum number of successful generations. + max_failed_generation_retries (int): Maximum number of failed generation retries. + max_iterations (int): Maximum number of iterations. + max_non_improving_generations (Optional[int]): Maximum number of non-improving generations. + llm (Optional[Dict[str, Any]]): LLM configuration with all parameters. + additional_params (Optional[Dict[str, Any]]): Additional parameters. + """ + + to_evolve: bool = False + evolve_type: Optional[EvolveType] = EvolveType.TEAM_TUNING + max_successful_generations: int = 3 + max_failed_generation_retries: int = 3 + max_iterations: int = 50 + max_non_improving_generations: Optional[int] = 2 + llm: Optional[Dict[str, Any]] = None + additional_params: Optional[Dict[str, Any]] = field(default_factory=dict) + + def __post_init__(self): + """Validate parameters after initialization.""" + self.validate() + + def validate(self) -> None: + """Validate evolve parameters. + + Raises: + ValueError: If any parameter is invalid. + """ + if self.evolve_type is not None: + if isinstance(self.evolve_type, str): + # Convert string to EvolveType + try: + self.evolve_type = EvolveType(self.evolve_type) + except ValueError: + raise ValueError( + f"evolve_type '{self.evolve_type}' is not a valid EvolveType. 
Valid values are: {list(EvolveType)}" + ) + elif not isinstance(self.evolve_type, EvolveType): + raise ValueError("evolve_type must be a valid EvolveType or string") + if self.additional_params is not None: + if not isinstance(self.additional_params, dict): + raise ValueError("additional_params must be a dictionary") + + if self.max_successful_generations is not None: + if not isinstance(self.max_successful_generations, int): + raise ValueError("max_successful_generations must be an integer") + if self.max_successful_generations <= 0: + raise ValueError("max_successful_generations must be positive") + + if self.max_failed_generation_retries is not None: + if not isinstance(self.max_failed_generation_retries, int): + raise ValueError("max_failed_generation_retries must be an integer") + if self.max_failed_generation_retries <= 0: + raise ValueError("max_failed_generation_retries must be positive") + + if self.max_iterations is not None: + if not isinstance(self.max_iterations, int): + raise ValueError("max_iterations must be an integer") + if self.max_iterations <= 0: + raise ValueError("max_iterations must be positive") + + if self.max_non_improving_generations is not None: + if not isinstance(self.max_non_improving_generations, int): + raise ValueError("max_non_improving_generations must be an integer or None") + if self.max_non_improving_generations <= 0: + raise ValueError("max_non_improving_generations must be positive or None") + + # Add validation for llm parameter + if self.llm is not None: + if not isinstance(self.llm, dict): + raise ValueError("llm must be a dictionary or None") + + @classmethod + def from_dict(cls, data: Union[Dict[str, Any], None]) -> "EvolveParam": + """Create EvolveParam instance from dictionary. + + Args: + data (Union[Dict[str, Any], None]): Dictionary containing evolve parameters. + + Returns: + EvolveParam: Instance with parameters set from dictionary. + + Raises: + ValueError: If data format is invalid. 
+ """ + if data is None: + return cls() + + if not isinstance(data, dict): + raise ValueError("evolve parameter must be a dictionary or None") + + # Extract known parameters + known_params = { + "to_evolve": data.get("toEvolve", data.get("to_evolve", False)), + "evolve_type": data.get("evolve_type"), + "max_successful_generations": data.get("max_successful_generations"), + "max_failed_generation_retries": data.get("max_failed_generation_retries"), + "max_iterations": data.get("max_iterations"), + "max_non_improving_generations": data.get("max_non_improving_generations"), + "llm": data.get("llm"), + "additional_params": data.get("additional_params"), + } + + # Remove None values + known_params = {k: v for k, v in known_params.items() if v is not None} + + # Collect additional parameters + additional_params = { + k: v + for k, v in data.items() + if k + not in [ + "toEvolve", + "to_evolve", + "evolve_type", + "max_successful_generations", + "max_failed_generation_retries", + "max_iterations", + "max_non_improving_generations", + "llm", + "additional_params", + ] + } + + return cls(additional_params=additional_params, **known_params) + + def to_dict(self) -> Dict[str, Any]: + """Convert EvolveParam instance to dictionary for API calls. + + Returns: + Dict[str, Any]: Dictionary representation with API-compatible keys. 
+ """ + result = { + "toEvolve": self.to_evolve, + } + + # Add optional parameters if they are set + if self.evolve_type is not None: + result["evolve_type"] = self.evolve_type + if self.max_successful_generations is not None: + result["max_successful_generations"] = self.max_successful_generations + if self.max_failed_generation_retries is not None: + result["max_failed_generation_retries"] = self.max_failed_generation_retries + if self.max_iterations is not None: + result["max_iterations"] = self.max_iterations + # Always include max_non_improving_generations, even if None + result["max_non_improving_generations"] = self.max_non_improving_generations + if self.llm is not None: + result["llm"] = self.llm + if self.additional_params is not None: + result.update(self.additional_params) + + return result + + def merge(self, other: Union[Dict[str, Any], "EvolveParam"]) -> "EvolveParam": + """Merge this EvolveParam with another set of parameters. + + Args: + other (Union[Dict[str, Any], EvolveParam]): Other parameters to merge. + + Returns: + EvolveParam: New instance with merged parameters. 
+ """ + if isinstance(other, dict): + other = EvolveParam.from_dict(other) + elif not isinstance(other, EvolveParam): + raise ValueError("other must be a dictionary or EvolveParam instance") + + # Create merged parameters + merged_additional = {**self.additional_params, **other.additional_params} + + return EvolveParam( + to_evolve=other.to_evolve if other.to_evolve else self.to_evolve, + evolve_type=(other.evolve_type if other.evolve_type is not None else self.evolve_type), + max_successful_generations=( + other.max_successful_generations + if other.max_successful_generations is not None + else self.max_successful_generations + ), + max_failed_generation_retries=( + other.max_failed_generation_retries + if other.max_failed_generation_retries is not None + else self.max_failed_generation_retries + ), + max_iterations=(other.max_iterations if other.max_iterations is not None else self.max_iterations), + max_non_improving_generations=( + other.max_non_improving_generations + if other.max_non_improving_generations is not None + else self.max_non_improving_generations + ), + llm=(other.llm if other.llm is not None else self.llm), + additional_params=merged_additional, + ) + + def __repr__(self) -> str: + return ( + f"EvolveParam(" + f"to_evolve={self.to_evolve}, " + f"evolve_type={self.evolve_type}, " + f"max_successful_generations={self.max_successful_generations}, " + f"max_failed_generation_retries={self.max_failed_generation_retries}, " + f"max_iterations={self.max_iterations}, " + f"max_non_improving_generations={self.max_non_improving_generations}, " + f"llm={self.llm}, " + f"additional_params={self.additional_params})" + ) + + +def validate_evolve_param( + evolve_param: Union[Dict[str, Any], EvolveParam, None], +) -> EvolveParam: + """Utility function to validate and convert evolve parameters. + + Args: + evolve_param (Union[Dict[str, Any], EvolveParam, None]): Input evolve parameters. + + Returns: + EvolveParam: Validated EvolveParam instance. 
+ + Raises: + ValueError: If parameters are invalid. + """ + if evolve_param is None: + return EvolveParam() + + if isinstance(evolve_param, EvolveParam): + evolve_param.validate() + return evolve_param + + if isinstance(evolve_param, dict): + # Check for required toEvolve key for backward compatibility + if "toEvolve" not in evolve_param and "to_evolve" not in evolve_param: + raise ValueError("evolve parameter must contain 'toEvolve' key") + return EvolveParam.from_dict(evolve_param) + + raise ValueError("evolve parameter must be a dictionary, EvolveParam instance, or None") diff --git a/aixplain/modules/model/index_model.py b/aixplain/modules/model/index_model.py index b9420f52..c4343a4c 100644 --- a/aixplain/modules/model/index_model.py +++ b/aixplain/modules/model/index_model.py @@ -1,3 +1,6 @@ +import os +import warnings +from uuid import uuid4 from aixplain.enums import EmbeddingModel, Function, Supplier, ResponseStatus, StorageType, FunctionType from aixplain.modules.model import Model from aixplain.utils import config @@ -9,9 +12,7 @@ from aixplain.enums.splitting_options import SplittingOptions import os -from urllib.parse import urljoin -from aixplain.utils.file_utils import _request_with_retry - +DOCLING_MODEL_ID = "677bee6c6eb56331f9192a91" class IndexFilterOperator(Enum): """Enumeration of operators available for filtering index records. 
@@ -177,8 +178,6 @@ def __init__( model = ModelFactory.get(embedding_model) self.embedding_size = model.additional_info["embedding_size"] except Exception as e: - import warnings - warnings.warn(f"Failed to get embedding size for embedding model {embedding_model}: {e}") self.embedding_size = None @@ -231,11 +230,11 @@ def search(self, query: str, top_k: int = 10, filters: List[IndexFilter] = []) - } return self.run(data=data) - def upsert(self, documents: List[Record], splitter: Optional[Splitter] = None) -> ModelResponse: + def upsert(self, documents: List[Record] | str, splitter: Optional[Splitter] = None) -> ModelResponse: """Upsert documents into the index Args: - documents (List[Record]): List of documents to be upserted + documents (List[Record] | str): List of documents to be upserted or a file path splitter (Splitter, optional): Splitter to be applied. Defaults to None. Returns: @@ -244,8 +243,12 @@ def upsert(self, documents: List[Record], splitter: Optional[Splitter] = None) - Examples: index_model.upsert([Record(value="Hello, world!", value_type="text", uri="", id="1", attributes={})]) index_model.upsert([Record(value="Hello, world!", value_type="text", uri="", id="1", attributes={})], splitter=Splitter(split=True, split_by=SplittingOptions.WORD, split_length=1, split_overlap=0)) + index_model.upsert("my_file.pdf") + index_model.upsert("my_file.pdf", splitter=Splitter(split=True, split_by=SplittingOptions.WORD, split_length=400, split_overlap=50)) Splitter in the above example is optional and can be used to split the documents into smaller chunks. 
""" + if isinstance(documents, str): + documents = [self.prepare_record_from_file(documents)] # Validate documents for doc in documents: doc.validate() @@ -272,7 +275,7 @@ def upsert(self, documents: List[Record], splitter: Optional[Splitter] = None) - return response raise Exception(f"Failed to upsert documents: {response.error_message}") - def count(self) -> float: + def count(self) -> int: """Get the total number of documents in the index. Returns: @@ -335,6 +338,63 @@ def delete_record(self, record_id: Text) -> ModelResponse: return response raise Exception(f"Failed to delete record: {response.error_message}") + def prepare_record_from_file(self, file_path: str, file_id: str = None) -> Record: + """Prepare a record from a file. + + Args: + file_path (str): The path to the file to be processed. + file_id (str, optional): The ID to assign to the record. If not provided, a unique ID is generated. + + Returns: + Record: A Record object containing the file's content and metadata. + + Raises: + Exception: If the file cannot be parsed. + + Example: + >>> record = index_model.prepare_record_from_file("/path/to/file.txt") + """ + response = self.parse_file(file_path) + file_name = file_path.split("/")[-1] + if not file_id: + file_id = file_name + "_" + str(uuid4()) + return Record(value=response.data, value_type="text", id=file_id, attributes={"file_name": file_name}) + + @staticmethod + def parse_file(file_path: str) -> ModelResponse: + """Parse a file using the Docling model. + + Args: + file_path (str): The path to the file to be parsed. + + Returns: + ModelResponse: The response containing the parsed file content. + + Raises: + Exception: If the file does not exist or cannot be parsed. 
+ + Example: + >>> response = IndexModel.parse_file("/path/to/file.pdf") + """ + if not os.path.exists(file_path): + raise Exception(f"File {file_path} does not exist") + if file_path.endswith(".txt"): + with open(file_path, "r") as file: + data = file.read() + if not data: + warnings.warn(f"File {file_path} is empty") + return ModelResponse(status=ResponseStatus.SUCCESS, data=data, completed=True) + try: + from aixplain.factories import ModelFactory + + model = ModelFactory.get(DOCLING_MODEL_ID) + response = model.run(file_path) + if not response.data: + warnings.warn(f"File {file_path} is empty") + return response + except Exception as e: + raise Exception(f"Failed to parse file: {e}") + def retrieve_records_with_filter(self, filter: IndexFilter) -> ModelResponse: """ Retrieve records from the index that match the given filter. diff --git a/aixplain/modules/model/utils.py b/aixplain/modules/model/utils.py index 4c1b6239..9fa3bf2d 100644 --- a/aixplain/modules/model/utils.py +++ b/aixplain/modules/model/utils.py @@ -2,11 +2,82 @@ import json import logging +import ast +import inspect from aixplain.utils.file_utils import _request_with_retry from typing import Callable, Dict, List, Text, Tuple, Union, Optional from aixplain.exceptions import get_error_from_status_code +def _extract_function_parameters(func: Callable) -> List[Tuple[str, str]]: + """ + Extract function parameters via inspect.signature, falling back to AST parsing of the source if that fails. 
+ + Args: + func: The function to extract parameters from + + Returns: + List of tuples containing (parameter_name, parameter_type) + """ + try: + # Use inspect.signature for the most reliable approach + sig = inspect.signature(func) + parameters = [] + + for param_name, param in sig.parameters.items(): + # Extract type annotation + if param.annotation != inspect.Parameter.empty: + # Handle complex type annotations + if hasattr(param.annotation, "__name__"): + param_type = param.annotation.__name__ + else: + param_type = str(param.annotation) + else: + raise ValueError(f"Parameter '{param_name}' missing type annotation") + + parameters.append((param_name, param_type)) + + return parameters + + except Exception as e: + # Fallback to AST parsing if inspect fails + try: + source = inspect.getsource(func) + tree = ast.parse(source) + + # Find the function definition + func_def = None + for node in ast.walk(tree): + if isinstance(node, ast.FunctionDef) and node.name == func.__name__: + func_def = node + break + + if not func_def: + raise ValueError(f"Could not find function definition for {func.__name__}") + + parameters = [] + for arg in func_def.args.args: + param_name = arg.arg + + # Extract type annotation + if arg.annotation: + if isinstance(arg.annotation, ast.Name): + param_type = arg.annotation.id + elif isinstance(arg.annotation, ast.Constant): + param_type = str(arg.annotation.value) + else: + param_type = ast.unparse(arg.annotation) + else: + raise ValueError(f"Parameter '{param_name}' missing type annotation") + + parameters.append((param_name, param_type)) + + return parameters + + except Exception as ast_error: + raise ValueError(f"Failed to extract parameters: {e}. AST fallback also failed: {ast_error}") + + def build_payload(data: Union[Text, Dict], parameters: Optional[Dict] = None, stream: Optional[bool] = None): """Build a JSON payload for API requests. 
@@ -301,9 +372,7 @@ def parse_code_decorated(code: Union[Text, Callable]) -> Tuple[Text, List, Text, description = ( getattr(code, "_tool_description", None) if hasattr(code, "_tool_description") - else code.__doc__.strip() - if code.__doc__ - else "" + else code.__doc__.strip() if code.__doc__ else "" ) name = getattr(code, "_tool_name", None) if hasattr(code, "_tool_name") else "" if hasattr(code, "_tool_inputs") and code._tool_inputs != []: @@ -330,20 +399,12 @@ def parse_code_decorated(code: Union[Text, Callable]) -> Tuple[Text, List, Text, str_code = inspect.getsource(code) description = code.__doc__.strip() if code.__doc__ else "" name = code.__name__ - # Try to infer parameters - params_match = re.search(r"def\s+\w+\s*\((.*?)\)\s*(?:->.*?)?:", str_code) - parameters = params_match.group(1).split(",") if params_match else [] - - for input in parameters: - if not input: - continue - assert ( - len(input.split(":")) > 1 - ), "Utility Model Error: Input type is required. For instance def main(a: int, b: int) -> int:" - input_name, input_type = input.split(":") - input_name = input_name.strip() - input_type = input_type.split("=")[0].strip() + # Extract parameters using AST for robust parsing + parameters = _extract_function_parameters(code) + inputs = [] + + for input_name, input_type in parameters: if input_type in ["int", "float"]: input_type = "number" inputs.append( diff --git a/aixplain/modules/team_agent/__init__.py b/aixplain/modules/team_agent/__init__.py index b5270bd8..b5669ece 100644 --- a/aixplain/modules/team_agent/__init__.py +++ b/aixplain/modules/team_agent/__init__.py @@ -27,7 +27,7 @@ import traceback import re from enum import Enum -from typing import Dict, List, Text, Optional, Union +from typing import Dict, List, Text, Optional, Union, Any from urllib.parse import urljoin from datetime import datetime @@ -36,12 +36,15 @@ from aixplain.enums.supplier import Supplier from aixplain.enums.asset_status import AssetStatus from 
aixplain.enums.storage_type import StorageType +from aixplain.enums.evolve_type import EvolveType from aixplain.modules.model import Model from aixplain.modules.agent import Agent, OutputFormat from aixplain.modules.agent.agent_response import AgentResponse from aixplain.modules.agent.agent_response_data import AgentResponseData +from aixplain.modules.agent.evolve_param import EvolveParam, validate_evolve_param from aixplain.modules.agent.utils import process_variables, validate_history from aixplain.modules.team_agent.inspector import Inspector +from aixplain.modules.team_agent.evolver_response_data import EvolverResponseData from aixplain.utils import config from aixplain.utils.request_utils import _request_with_retry from aixplain.modules.model.llm_model import LLM @@ -123,6 +126,7 @@ def __init__( self.agents = agents self.llm_id = llm_id self.llm = llm + self.api_key = api_key self.use_mentalist = use_mentalist self.inspectors = inspectors self.inspector_targets = inspector_targets @@ -213,7 +217,7 @@ def run( output_format (OutputFormat, optional): response format. If not provided, uses the format set during initialization. expected_output (Union[BaseModel, Text, dict], optional): expected output. Defaults to None. Returns: - Dict: parsed output from model + AgentResponse: parsed output from model """ start = time.time() result_data = {} @@ -283,6 +287,7 @@ def run_async( max_iterations: int = 30, output_format: Optional[OutputFormat] = None, expected_output: Optional[Union[BaseModel, Text, dict]] = None, + evolve: Union[Dict[str, Any], EvolveParam, None] = None, ) -> AgentResponse: """Runs asynchronously a Team Agent call. @@ -298,8 +303,9 @@ def run_async( max_iterations (int, optional): maximum number of iterations between the agents. Defaults to 30. output_format (OutputFormat, optional): response format. If not provided, uses the format set during initialization. expected_output (Union[BaseModel, Text, dict], optional): expected output. Defaults to None. 
+ evolve (Union[Dict[str, Any], EvolveParam, None], optional): evolve the team agent configuration. Can be a dictionary, EvolveParam instance, or None. Returns: - dict: polling URL in response + AgentResponse: polling URL in response """ if session_id is not None and history is not None: raise ValueError("Provide either `session_id` or `history`, not both.") @@ -313,6 +319,10 @@ def run_async( from aixplain.factories.file_factory import FileFactory + # Validate and normalize evolve parameters using the base model + evolve_param = validate_evolve_param(evolve) + evolve_dict = evolve_param.to_dict() + if not self.is_valid: raise Exception("Team Agent is not valid. Please validate the team agent before running.") @@ -371,6 +381,7 @@ def run_async( "outputFormat": output_format, "expectedOutput": expected_output, }, + "evolve": json.dumps(evolve_dict), } payload.update(parameters) payload = json.dumps(payload) @@ -384,7 +395,7 @@ def run_async( logging.info(f"Result of request for {name} - {r.status_code} - {resp}") poll_url = resp["data"] - return AgentResponse( + response = AgentResponse( status=ResponseStatus.IN_PROGRESS, url=poll_url, data=AgentResponseData(input=input_data), @@ -394,10 +405,69 @@ def run_async( except Exception: msg = f"Error in request for {name} - {traceback.format_exc()}" logging.error(f"Team Agent Run Async: Error in running for {name}: {resp}") - return AgentResponse( - status=ResponseStatus.FAILED, - error=msg, + if resp is not None: + response = AgentResponse( + status=ResponseStatus.FAILED, + error=msg, + ) + return response + + def poll(self, poll_url: Text, name: Text = "model_process") -> AgentResponse: + used_credits, run_time = 0.0, 0.0 + resp, error_message, status = None, None, ResponseStatus.SUCCESS + headers = {"x-api-key": self.api_key, "Content-Type": "application/json"} + r = _request_with_retry("get", poll_url, headers=headers) + try: + resp = r.json() + if resp["completed"] is True: + status = 
ResponseStatus(resp.get("status", "FAILED")) + if "error_message" in resp or "supplierError" in resp: + status = ResponseStatus.FAILED + error_message = resp.get("error_message") + else: + status = ResponseStatus.IN_PROGRESS + logging.debug(f"Single Poll for Team Agent: Status of polling for {name}: {resp}") + + resp_data = resp.get("data") or {} + used_credits = resp_data.get("usedCredits", 0.0) + run_time = resp_data.get("runTime", 0.0) + evolve_type = resp_data.get("evolve_type", EvolveType.TEAM_TUNING.value) + if "evolved_agent" in resp_data and status == ResponseStatus.SUCCESS: + if evolve_type == EvolveType.INSTRUCTION_TUNING.value: + # return this class as it is but replace its description and instructions + evolved_agent = self + current_code = resp_data.get("current_code", "") + evolved_agent.description = current_code + evolved_agent.update() + resp_data["evolved_agent"] = evolved_agent + else: + resp_data = EvolverResponseData.from_dict(resp_data, llm_id=self.llm_id, api_key=self.api_key) + else: + resp_data = AgentResponseData( + input=resp_data.get("input"), + output=resp_data.get("output"), + session_id=resp_data.get("session_id"), + intermediate_steps=resp_data.get("intermediate_steps"), + execution_stats=resp_data.get("executionStats"), + ) + except Exception as e: + import traceback + + logging.error(f"Single Poll for Team Agent: Error of polling for {name}: {e}, traceback: {traceback.format_exc()}") + status = ResponseStatus.FAILED + error_message = str(e) + finally: + response = AgentResponse( + status=status, + data=resp_data, + details=resp.get("details", {}), + completed=resp.get("completed", False), + used_credits=used_credits, + run_time=run_time, + usage=resp.get("usage", None), + error_message=error_message, ) + return response def delete(self) -> None: """Delete Corpus service""" @@ -480,7 +550,7 @@ def to_dict(self) -> Dict: - supplier (str): The supplier code - version (str): The version number - status (str): The current status - - 
role (str): The team agent's instructions + - instructions (str): The team agent's instructions """ if self.use_mentalist: planner_id = self.mentalist_llm.id if self.mentalist_llm else self.llm_id @@ -493,11 +563,11 @@ def to_dict(self) -> Dict: "links": [], "description": self.description, "llmId": self.llm.id if self.llm else self.llm_id, - "supervisorId": self.supervisor_llm.id if self.supervisor_llm else self.llm_id, + "supervisorId": (self.supervisor_llm.id if self.supervisor_llm else self.llm_id), "plannerId": planner_id, "inspectors": [inspector.model_dump(by_alias=True) for inspector in self.inspectors], "inspectorTargets": [target.value for target in self.inspector_targets], - "supplier": self.supplier.value["code"] if isinstance(self.supplier, Supplier) else self.supplier, + "supplier": (self.supplier.value["code"] if isinstance(self.supplier, Supplier) else self.supplier), "version": self.version, "status": self.status.value, "instructions": self.instructions, @@ -515,10 +585,10 @@ def from_dict(cls, data: Dict) -> "TeamAgent": Returns: TeamAgent instance """ - from aixplain.factories.agent_factory import AgentFactory from aixplain.factories.model_factory import ModelFactory from aixplain.enums import AssetStatus from aixplain.modules.team_agent import Inspector, InspectorTarget + from aixplain.modules.agent import Agent # Extract agents from agents list using proper agent loading agents = [] @@ -527,20 +597,23 @@ def from_dict(cls, data: Dict) -> "TeamAgent": if "assetId" in agent_data: try: # Load agent using AgentFactory - agent = AgentFactory.get(agent_data["assetId"]) + agent = Agent.from_dict(agent_data) agents.append(agent) except Exception as e: # Log warning but continue processing other agents import logging logging.warning(f"Failed to load agent {agent_data['assetId']}: {e}") - + else: + agents.append(Agent.from_dict(agent_data)) # Extract inspectors using proper model validation inspectors = [] if "inspectors" in data: for inspector_data in 
data["inspectors"]: try: - if hasattr(Inspector, "model_validate"): + if isinstance(inspector_data, Inspector): + inspectors.append(inspector_data) + elif hasattr(Inspector, "model_validate"): inspectors.append(Inspector.model_validate(inspector_data)) else: inspectors.append(Inspector(**inspector_data)) @@ -548,6 +621,7 @@ def from_dict(cls, data: Dict) -> "TeamAgent": import logging logging.warning(f"Failed to create inspector from data: {e}") + continue # Extract inspector targets inspector_targets = [InspectorTarget.STEPS] # default @@ -601,7 +675,7 @@ def from_dict(cls, data: Dict) -> "TeamAgent": version=data.get("version"), use_mentalist=use_mentalist, status=status, - instructions=data.get("role"), + instructions=data.get("instructions"), inspectors=inspectors, inspector_targets=inspector_targets, output_format=OutputFormat(data.get("outputFormat", OutputFormat.TEXT)), @@ -726,3 +800,133 @@ def __repr__(self): str: A string in the format "TeamAgent: (id=)". """ return f"TeamAgent: {self.name} (id={self.id})" + + def evolve_async( + self, + evolve_type: Union[EvolveType, str] = EvolveType.TEAM_TUNING, + max_successful_generations: int = 3, + max_failed_generation_retries: int = 3, + max_iterations: int = 50, + max_non_improving_generations: Optional[int] = 2, + llm: Optional[Union[Text, LLM]] = None, + ) -> AgentResponse: + """Asynchronously evolve the Team Agent and return a polling URL in the AgentResponse. + + Args: + evolve_type (Union[EvolveType, str]): Type of evolution (TEAM_TUNING or INSTRUCTION_TUNING). Defaults to TEAM_TUNING. + max_successful_generations (int): Maximum number of successful generations to evolve. Defaults to 3. + max_failed_generation_retries (int): Maximum retry attempts for failed generations. Defaults to 3. + max_iterations (int): Maximum number of iterations. Defaults to 50. + max_non_improving_generations (Optional[int]): Stop condition parameter for non-improving generations. Defaults to 2, can be None. 
+ llm (Optional[Union[Text, LLM]]): LLM to use for evolution. Can be an LLM ID string or LLM object. Defaults to None. + + Returns: + AgentResponse: Response containing polling URL and status. + """ + from aixplain.utils.evolve_utils import create_llm_dict + + query = "" + + # Create EvolveParam from individual parameters + evolve_parameters = EvolveParam( + to_evolve=True, + evolve_type=evolve_type, + max_successful_generations=max_successful_generations, + max_failed_generation_retries=max_failed_generation_retries, + max_iterations=max_iterations, + max_non_improving_generations=max_non_improving_generations, + llm=create_llm_dict(llm), + ) + + response = self.run_async(query=query, evolve=evolve_parameters) + return response + + def evolve( + self, + evolve_type: Union[EvolveType, str] = EvolveType.TEAM_TUNING, + max_successful_generations: int = 3, + max_failed_generation_retries: int = 3, + max_iterations: int = 50, + max_non_improving_generations: Optional[int] = 2, + llm: Optional[Union[Text, LLM]] = None, + ) -> AgentResponse: + """Synchronously evolve the Team Agent and poll for the result. + + Args: + evolve_type (Union[EvolveType, str]): Type of evolution (TEAM_TUNING or INSTRUCTION_TUNING). Defaults to TEAM_TUNING. + max_successful_generations (int): Maximum number of successful generations to evolve. Defaults to 3. + max_failed_generation_retries (int): Maximum retry attempts for failed generations. Defaults to 3. + max_iterations (int): Maximum number of iterations. Defaults to 50. + max_non_improving_generations (Optional[int]): Stop condition parameter for non-improving generations. Defaults to 2, can be None. + llm (Optional[Union[Text, LLM]]): LLM to use for evolution. Can be an LLM ID string or LLM object. Defaults to None. + + Returns: + AgentResponse: Final response from the evolution process. 
+ """ + from aixplain.enums import EvolveType + from aixplain.utils.evolve_utils import create_llm_dict + from aixplain.factories.team_agent_factory.utils import build_team_agent_from_yaml + + # Create EvolveParam from individual parameters + evolve_parameters = EvolveParam( + to_evolve=True, + evolve_type=evolve_type, + max_successful_generations=max_successful_generations, + max_failed_generation_retries=max_failed_generation_retries, + max_iterations=max_iterations, + max_non_improving_generations=max_non_improving_generations, + llm=create_llm_dict(llm), + ) + start = time.time() + try: + logging.info(f"Evolve started with parameters: {evolve_parameters}") + logging.info("It might take a while...") + response = self.evolve_async( + evolve_type=evolve_type, + max_successful_generations=max_successful_generations, + max_failed_generation_retries=max_failed_generation_retries, + max_iterations=max_iterations, + max_non_improving_generations=max_non_improving_generations, + llm=llm, + ) + if response["status"] == ResponseStatus.FAILED: + end = time.time() + response["elapsed_time"] = end - start + return response + poll_url = response["url"] + end = time.time() + result = self.sync_poll(poll_url, name="evolve_process", timeout=600) + result_data = result.data + current_code = result_data.get("current_code") if isinstance(result_data, dict) else result_data.current_code + if current_code is not None: + if evolve_parameters.evolve_type == EvolveType.TEAM_TUNING: + result_data["evolved_agent"] = build_team_agent_from_yaml( + result_data["current_code"], + self.llm_id, + self.api_key, + self.id, + ) + elif evolve_parameters.evolve_type == EvolveType.INSTRUCTION_TUNING: + self.instructions = result_data["current_code"] + self.description = result_data["current_code"] + self.update() + result_data["evolved_agent"] = self + else: + raise ValueError( + "evolve_parameters.evolve_type must be one of the following: TEAM_TUNING, INSTRUCTION_TUNING" + ) + return AgentResponse( 
+ status=ResponseStatus.SUCCESS, + completed=True, + data=result_data, + used_credits=getattr(result, "used_credits", 0.0), + run_time=getattr(result, "run_time", end - start), + ) + except Exception as e: + logging.error(f"Team Agent Evolve: Error in evolving: {e}") + end = time.time() + return AgentResponse( + status=ResponseStatus.FAILED, + completed=False, + error_message="No response from the service.", + ) diff --git a/aixplain/modules/team_agent/evolver_response_data.py b/aixplain/modules/team_agent/evolver_response_data.py new file mode 100644 index 00000000..06479e4e --- /dev/null +++ b/aixplain/modules/team_agent/evolver_response_data.py @@ -0,0 +1,70 @@ +from typing import Any, Dict, List, Text, TYPE_CHECKING + +if TYPE_CHECKING: + from aixplain.modules.team_agent import TeamAgent + + +class EvolverResponseData: + def __init__( + self, + evolved_agent: "TeamAgent", + current_code: Text, + evaluation_report: Text, + comparison_report: Text, + criteria: Text, + archive: List[Text], + current_output: Text = "", + ) -> None: + self.evolved_agent = evolved_agent + self.current_code = current_code + self.evaluation_report = evaluation_report + self.comparison_report = comparison_report + self.criteria = criteria + self.archive = archive + self.current_output = current_output + + @classmethod + def from_dict(cls, data: Dict[str, Any], llm_id: Text, api_key: Text) -> "EvolverResponseData": + from aixplain.factories.team_agent_factory.utils import build_team_agent_from_yaml + + yaml_code = data.get("current_code", "") + evolved_team_agent = build_team_agent_from_yaml(yaml_code=yaml_code, llm_id=llm_id, api_key=api_key) + return cls( + evolved_agent=evolved_team_agent, + current_code=yaml_code, + evaluation_report=data.get("evaluation_report", ""), + comparison_report=data.get("comparison_report", ""), + criteria=data.get("criteria", ""), + archive=data.get("archive", []), + current_output=data.get("current_output", ""), + ) + + def to_dict(self) -> Dict[str, 
Any]: + return { + "evolved_agent": self.evolved_agent, + "current_code": self.current_code, + "evaluation_report": self.evaluation_report, + "comparison_report": self.comparison_report, + "criteria": self.criteria, + "archive": self.archive, + "current_output": self.current_output, + } + + def __getitem__(self, key: str) -> Any: + return getattr(self, key, None) + + def __setitem__(self, key: str, value: Any) -> None: + if hasattr(self, key): + setattr(self, key, value) + else: + raise KeyError(f"{key} is not a valid attribute of {self.__class__.__name__}") + + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}(" + f"evolved_agent='{self.evolved_agent}', " + f"evaluation_report='{self.evaluation_report}', " + f"comparison_report='{self.comparison_report}', " + f"criteria='{self.criteria}', " + f"archive='{self.archive}', " + ) diff --git a/aixplain/modules/team_agent/inspector.py b/aixplain/modules/team_agent/inspector.py index 8cd12cc0..7e8024a0 100644 --- a/aixplain/modules/team_agent/inspector.py +++ b/aixplain/modules/team_agent/inspector.py @@ -1,4 +1,5 @@ """Pre-defined agent for inspecting the data flow within a team agent. +WARNING: This feature is currently in private beta. Example usage: @@ -19,17 +20,40 @@ ) """ +import inspect from enum import Enum -from typing import Dict, Optional, Text +from typing import Dict, Optional, Text, Union, Callable -from pydantic import field_validator +import textwrap +from pydantic import BaseModel, field_validator from aixplain.modules.agent.model_with_params import ModelWithParams +from aixplain.modules.model.response import ModelResponse AUTO_DEFAULT_MODEL_ID = "67fd9e2bef0365783d06e2f0" # GPT-4.1 Nano +class InspectorAction(str, Enum): + """ + Inspector's decision on the next action. + """ + + CONTINUE = "continue" + RERUN = "rerun" + ABORT = "abort" + + +class InspectorOutput(BaseModel): + """ + Inspector's output. 
+ """ + + critiques: Text + content_edited: Text + action: InspectorAction + + class InspectorAuto(str, Enum): """A list of keywords for inspectors configured automatically in the backend.""" @@ -55,6 +79,202 @@ class InspectorPolicy(str, Enum): ADAPTIVE = "adaptive" # adjust execution according to feedback +def validate_policy_callable(policy_func: Callable) -> bool: + """Validate that the policy callable meets the required constraints.""" + # Check function name + if policy_func.__name__ != "process_response": + return False + + # Get function signature + sig = inspect.signature(policy_func) + params = list(sig.parameters.keys()) + + # Check arguments - should have exactly 2 parameters: model_response and input_content + if len(params) != 2 or params[0] != "model_response" or params[1] != "input_content": + return False + + # Check return type annotation - should return InspectorOutput + return_annotation = sig.return_annotation + if return_annotation != InspectorOutput: + return False + + return True + + +def callable_to_code_string(policy_func: Callable) -> str: + """Convert a callable policy function to a code string for serialization.""" + try: + source_code = get_policy_source(policy_func) + if source_code is None: + # If we can't get the source code, create a minimal representation + sig = inspect.signature(policy_func) + return f"def process_response{str(sig)}:\n # Function source not available\n pass" + + # Dedent the source code to remove leading whitespace + source_code = textwrap.dedent(source_code) + return source_code + except (OSError, TypeError): + # If we can't get the source code, create a minimal representation + sig = inspect.signature(policy_func) + return f"def process_response{str(sig)}:\n # Function source not available\n pass" + + +def code_string_to_callable(code_string: str) -> Callable: + """Convert a code string back to a callable function for deserialization.""" + try: + # Create a namespace to execute the code + namespace = { + 
"InspectorAction": InspectorAction, + "InspectorOutput": InspectorOutput, + "ModelResponse": ModelResponse, + "str": str, + "int": int, + "float": float, + "bool": bool, + "list": list, + "dict": dict, + "tuple": tuple, + "set": set, + "len": len, + "print": print, + "range": range, + "enumerate": enumerate, + "zip": zip, + "map": map, + "filter": filter, + "any": any, + "all": all, + "sum": sum, + "min": min, + "max": max, + "abs": abs, + "round": round, + "sorted": sorted, + "reversed": reversed, + "isinstance": isinstance, + "hasattr": hasattr, + "getattr": getattr, + "setattr": setattr, + "dir": dir, + "type": type, + "id": id, + "hash": hash, + "repr": repr, + "str": str, + "format": format, + "ord": ord, + "chr": chr, + "bin": bin, + "oct": oct, + "hex": hex, + "pow": pow, + "divmod": divmod, + "complex": complex, + "bytes": bytes, + "bytearray": bytearray, + "memoryview": memoryview, + "slice": slice, + "property": property, + "staticmethod": staticmethod, + "classmethod": classmethod, + "super": super, + "object": object, + "Exception": Exception, + "ValueError": ValueError, + "TypeError": TypeError, + "AttributeError": AttributeError, + "KeyError": KeyError, + "IndexError": IndexError, + "RuntimeError": RuntimeError, + "AssertionError": AssertionError, + "ImportError": ImportError, + "ModuleNotFoundError": ModuleNotFoundError, + "NameError": NameError, + "SyntaxError": SyntaxError, + "IndentationError": IndentationError, + "TabError": TabError, + "UnboundLocalError": UnboundLocalError, + "UnicodeError": UnicodeError, + "UnicodeDecodeError": UnicodeDecodeError, + "UnicodeEncodeError": UnicodeEncodeError, + "UnicodeTranslateError": UnicodeTranslateError, + "OSError": OSError, + "FileNotFoundError": FileNotFoundError, + "PermissionError": PermissionError, + "ProcessLookupError": ProcessLookupError, + "TimeoutError": TimeoutError, + "ConnectionError": ConnectionError, + "BrokenPipeError": BrokenPipeError, + "ConnectionAbortedError": ConnectionAbortedError, + 
"ConnectionRefusedError": ConnectionRefusedError, + "ConnectionResetError": ConnectionResetError, + "BlockingIOError": BlockingIOError, + "ChildProcessError": ChildProcessError, + "NotADirectoryError": NotADirectoryError, + "IsADirectoryError": IsADirectoryError, + "InterruptedError": InterruptedError, + "EnvironmentError": EnvironmentError, + "IOError": IOError, + "EOFError": EOFError, + "MemoryError": MemoryError, + "RecursionError": RecursionError, + "SystemError": SystemError, + "ReferenceError": ReferenceError, + "FloatingPointError": FloatingPointError, + "OverflowError": OverflowError, + "ZeroDivisionError": ZeroDivisionError, + "ArithmeticError": ArithmeticError, + "BufferError": BufferError, + "LookupError": LookupError, + "StopIteration": StopIteration, + "GeneratorExit": GeneratorExit, + "KeyboardInterrupt": KeyboardInterrupt, + "SystemExit": SystemExit, + "BaseException": BaseException, + } + + # Execute the code string in the namespace + exec(code_string, namespace) + + # Get the function from the namespace + if "process_response" not in namespace: + raise ValueError("Code string must define a function named 'process_response'") + + func = namespace["process_response"] + + # Store the original source code as an attribute for later retrieval + func._source_code = code_string + + # Validate the function + if not validate_policy_callable(func): + raise ValueError("Deserialized function does not meet the required constraints") + + return func + except Exception as e: + raise ValueError(f"Failed to deserialize code string to callable: {e}") + + +def get_policy_source(func: Callable) -> Optional[str]: + """Get the source code of a policy function. + + This function tries to retrieve the source code of a policy function. + It first checks if the function has a stored _source_code attribute (for functions + created via code_string_to_callable), then falls back to inspect.getsource(). 
+ + Args: + func: The function to get source code for + + Returns: + The source code string if available, None otherwise + """ + if hasattr(func, "_source_code"): + return func._source_code + try: + return inspect.getsource(func) + except (OSError, TypeError): + return None + + class Inspector(ModelWithParams): """Pre-defined agent for inspecting the data flow within a team agent. @@ -64,13 +284,15 @@ class Inspector(ModelWithParams): name: The name of the inspector. model_id: The ID of the model to wrap. model_params: The configuration for the model. - policy: The policy for the inspector. Default is ADAPTIVE. + policy: The policy for the inspector. Can be InspectorPolicy enum or a callable function. + If callable, must have name "process_response", arguments "model_response" and "input_content" (both strings), + and return InspectorAction. Default is ADAPTIVE. """ name: Text model_params: Optional[Dict] = None auto: Optional[InspectorAuto] = None - policy: InspectorPolicy = InspectorPolicy.ADAPTIVE + policy: Union[InspectorPolicy, Callable] = InspectorPolicy.ADAPTIVE def __init__(self, *args, **kwargs): """Initialize an Inspector instance. 
@@ -114,3 +336,47 @@ def validate_name(cls, v: Text) -> Text: if v == "": raise ValueError("name cannot be empty") return v + + @field_validator("policy") + def validate_policy(cls, v: Union[InspectorPolicy, Callable]) -> Union[InspectorPolicy, Callable]: + if callable(v): + if not validate_policy_callable(v): + raise ValueError( + "Policy callable must have name 'process_response', arguments 'model_response' and 'input_content' (both strings), and return InspectorAction" + ) + elif not isinstance(v, InspectorPolicy): + raise ValueError(f"Policy must be InspectorPolicy enum or a valid callable function, got {type(v)}") + return v + + def model_dump(self, by_alias: bool = False, **kwargs) -> Dict: + """Override model_dump to handle callable policy serialization.""" + data = super().model_dump(by_alias=by_alias, **kwargs) + + # Handle callable policy serialization + if callable(self.policy): + data["policy"] = callable_to_code_string(self.policy) + data["policy_type"] = "callable" + elif isinstance(self.policy, InspectorPolicy): + data["policy"] = self.policy.value + data["policy_type"] = "enum" + + return data + + @classmethod + def model_validate(cls, data: Union[Dict, "Inspector"]) -> "Inspector": + """Override model_validate to handle callable policy deserialization.""" + if isinstance(data, cls): + return data + + # Handle callable policy deserialization + if isinstance(data, dict) and data.get("policy_type") == "callable": + policy_code = data.get("policy") + if isinstance(policy_code, str): + try: + data["policy"] = code_string_to_callable(policy_code) + except Exception: + # If deserialization fails, fall back to default policy + data["policy"] = InspectorPolicy.ADAPTIVE + data.pop("policy_type", None) # Remove the type indicator + + return super().model_validate(data) diff --git a/aixplain/utils/evolve_utils.py b/aixplain/utils/evolve_utils.py new file mode 100644 index 00000000..437e1a31 --- /dev/null +++ b/aixplain/utils/evolve_utils.py @@ -0,0 +1,29 @@ 
+from typing import Union, Dict, Any, Optional, Text +from aixplain.modules.model.llm_model import LLM + + +def create_llm_dict(llm: Optional[Union[Text, LLM]]) -> Optional[Dict[str, Any]]: + """Create a dictionary representation of an LLM for evolution parameters. + + Args: + llm: Either an LLM ID string or an LLM object instance. + + Returns: + Dictionary with LLM information if llm is provided, None otherwise. + """ + if llm is None: + return None + + if isinstance(llm, LLM): + return { + "id": llm.id, + "name": llm.name, + "description": llm.description, + "supplier": llm.supplier, + "version": llm.version, + "function": llm.function, + "parameters": (llm.get_parameters().to_list() if llm.get_parameters() else None), + "temperature": getattr(llm, "temperature", None), + } + else: + return {"id": llm} diff --git a/tests/functional/general_assets/asset_functional_test.py b/tests/functional/general_assets/asset_functional_test.py index 1074c095..7617e30d 100644 --- a/tests/functional/general_assets/asset_functional_test.py +++ b/tests/functional/general_assets/asset_functional_test.py @@ -95,7 +95,7 @@ def test_model_supplier(ModelFactory): @pytest.mark.parametrize( "model_ids,model_names", [ - (("674728f51ed8e18fd8a1383f", "669a63646eb56306647e1091"), ("Yi-Large", "GPT-4o Mini")), + (("67be216bd8f6a65d6f74d5e9", "669a63646eb56306647e1091"), ("Anthropic Claude 3.7 Sonnet", "GPT-4o Mini")), ], ) @pytest.mark.parametrize("ModelFactory", [ModelFactory, v2.Model]) diff --git a/tests/functional/model/data/test_file_parser_input.pdf b/tests/functional/model/data/test_file_parser_input.pdf new file mode 100644 index 00000000..5882d1bc Binary files /dev/null and b/tests/functional/model/data/test_file_parser_input.pdf differ diff --git a/tests/functional/model/hf_onboarding_test.py b/tests/functional/model/hf_onboarding_test.py deleted file mode 100644 index fa68d2e8..00000000 --- a/tests/functional/model/hf_onboarding_test.py +++ /dev/null @@ -1,60 +0,0 @@ -__author__ = 
"michaellam" - -import pytest -import time - -from aixplain.factories.model_factory import ModelFactory -from tests.test_utils import delete_asset -from aixplain.utils import config - - -@pytest.mark.skip(reason="Model Deployment is deactivated for improvements.") -def test_deploy_model(): - # Start the deployment - model_name = "Test Model" - repo_id = "tiiuae/falcon-7b" - response = ModelFactory.deploy_huggingface_model(model_name, repo_id, hf_token=config.HF_TOKEN) - assert "id" in response.keys() - - # Check for status - model_id = response["id"] - num_retries = 120 - counter = 0 - while ModelFactory.get_huggingface_model_status(model_id)["status"].lower() != "onboarded": - time.sleep(10) - counter += 1 - if counter == num_retries: - assert ModelFactory.get_huggingface_model_status(model_id)["status"].lower() == "onboarded" - - # Clean up - delete_asset(model_id, config.TEAM_API_KEY) - - -# @pytest.mark.skip(reason="Model Deployment is deactivated for improvements.") -def test_nonexistent_model(): - # Start the deployment - model_name = "Test Model" - repo_id = "nonexistent-supplier/nonexistent-model" - response = ModelFactory.deploy_huggingface_model(model_name, repo_id, hf_token=config.HF_TOKEN) - assert response["statusCode"] == 400 - assert response["message"] == "err.unable_to_onboard_model" - - -# @pytest.mark.skip(reason="Model Deployment is deactivated for improvements.") -def test_size_limit(): - # Start the deployment - model_name = "Test Model" - repo_id = "tiiuae/falcon-40b" - response = ModelFactory.deploy_huggingface_model(model_name, repo_id, hf_token=config.HF_TOKEN) - assert response["statusCode"] == 400 - assert response["message"] == "err.unable_to_onboard_model" - - -# @pytest.mark.skip(reason="Model Deployment is deactivated for improvements.") -def test_gated_model(): - # Start the deployment - model_name = "Test Model" - repo_id = "meta-llama/Llama-2-7b-hf" - response = ModelFactory.deploy_huggingface_model(model_name, repo_id, 
hf_token="mock_key") - assert response["statusCode"] == 400 - assert response["message"] == "err.unable_to_onboard_model" diff --git a/tests/functional/model/image_upload_e2e_test.py b/tests/functional/model/image_upload_e2e_test.py deleted file mode 100644 index 90ebddfd..00000000 --- a/tests/functional/model/image_upload_e2e_test.py +++ /dev/null @@ -1,71 +0,0 @@ -__author__ = "michaellam" - -from pathlib import Path -import json -from aixplain.factories.model_factory import ModelFactory -from tests.test_utils import delete_asset, delete_service_account -from aixplain.utils import config -import docker -import pytest - - -def test_create_and_upload_model(): - # List the host machines - host_response = ModelFactory.list_host_machines() - for hosting_machine_dict in host_response: - assert "code" in hosting_machine_dict.keys() - assert "type" in hosting_machine_dict.keys() - assert "cores" in hosting_machine_dict.keys() - assert "memory" in hosting_machine_dict.keys() - assert "hourlyCost" in hosting_machine_dict.keys() - - # List the functions - response = ModelFactory.list_functions() - items = response["items"] - for item in items: - assert "output" not in item.keys() - assert "params" not in item.keys() - assert "id" not in item.keys() - assert "name" in item.keys() - - # Register the model, and create an image repository for it. 
- with open(Path("tests/test_requests/create_asset_request.json")) as f: - mock_register_payload = json.load(f) - name = mock_register_payload["name"] - description = mock_register_payload["description"] - function = mock_register_payload["function"] - source_language = mock_register_payload["sourceLanguage"] - input_modality = mock_register_payload["input_modality"] - output_modality = mock_register_payload["output_modality"] - documentation_url = mock_register_payload["documentation_url"] - register_response = ModelFactory.create_asset_repo( - name, description, function, source_language, input_modality, output_modality, documentation_url, config.TEAM_API_KEY - ) - assert "id" in register_response.keys() - assert "repositoryName" in register_response.keys() - model_id = register_response["id"] - repo_name = register_response["repositoryName"] - - # Log into the image repository. - login_response = ModelFactory.asset_repo_login() - - assert login_response["username"] == "AWS" - assert login_response["registry"] == "535945872701.dkr.ecr.us-east-1.amazonaws.com" - assert "password" in login_response.keys() - - username = login_response["username"] - password = login_response["password"] - registry = login_response["registry"] - - # Push an image to ECR - low_level_client = docker.APIClient(base_url="unix://var/run/docker.sock") - low_level_client.pull("bash") - low_level_client.tag("bash", f"{registry}/{repo_name}") - low_level_client.push(f"{registry}/{repo_name}", auth_config={"username": username, "password": password}) - - # Send an email to finalize onboarding process - ModelFactory.onboard_model(model_id, "latest", "fake_hash") - - # Clean up - delete_service_account(config.TEAM_API_KEY) - delete_asset(model_id, config.TEAM_API_KEY) diff --git a/tests/functional/model/image_upload_functional_test.py b/tests/functional/model/image_upload_functional_test.py deleted file mode 100644 index c5abd487..00000000 --- 
a/tests/functional/model/image_upload_functional_test.py +++ /dev/null @@ -1,82 +0,0 @@ -__author__ = "michaellam" -from pathlib import Path -import json -from aixplain.factories.model_factory import ModelFactory -from tests.test_utils import delete_asset, delete_service_account -from aixplain.utils import config -import pytest - - -def test_login(): - response = ModelFactory.asset_repo_login() - assert response["username"] == "AWS" - assert response["registry"] == "535945872701.dkr.ecr.us-east-1.amazonaws.com" - assert "password" in response.keys() - - # Test cleanup - delete_service_account(config.TEAM_API_KEY) - - -def test_create_asset_repo(): - with open(Path("tests/test_requests/create_asset_request.json")) as f: - mock_register_payload = json.load(f) - name = mock_register_payload["name"] - description = mock_register_payload["description"] - function = mock_register_payload["function"] - source_language = mock_register_payload["sourceLanguage"] - input_modality = mock_register_payload["input_modality"] - output_modality = mock_register_payload["output_modality"] - documentation_url = mock_register_payload["documentation_url"] - response = ModelFactory.create_asset_repo( - name, - description, - function, - source_language, - input_modality, - output_modality, - documentation_url, - config.TEAM_API_KEY, - ) - response_dict = dict(response) - assert "id" in response_dict.keys() - assert "repositoryName" in response_dict.keys() - - # Test cleanup - delete_asset(response["id"], config.TEAM_API_KEY) - - -def test_list_host_machines(): - response = ModelFactory.list_host_machines() - for hosting_machine_dict in response: - assert "code" in hosting_machine_dict.keys() - assert "type" in hosting_machine_dict.keys() - assert "cores" in hosting_machine_dict.keys() - assert "memory" in hosting_machine_dict.keys() - assert "hourlyCost" in hosting_machine_dict.keys() - - -def test_get_functions(): - # Verbose - response = ModelFactory.list_functions(True) - items = 
response["items"] - for item in items: - assert "output" in item.keys() - assert "params" in item.keys() - assert "id" in item.keys() - assert "name" in item.keys() - - # Non-verbose - response = ModelFactory.list_functions() # Not verbose by default - items = response["items"] - for item in items: - assert "output" not in item.keys() - assert "params" not in item.keys() - assert "id" not in item.keys() - assert "name" in item.keys() - - -@pytest.mark.skip(reason="Not included in first release") -def list_image_repo_tags(): - response = ModelFactory.list_image_repo_tags() - assert "Image tags" in response.keys() - assert "nextToken" in response.keys() diff --git a/tests/functional/model/run_model_test.py b/tests/functional/model/run_model_test.py index bf7415fb..e4bdb496 100644 --- a/tests/functional/model/run_model_test.py +++ b/tests/functional/model/run_model_test.py @@ -11,7 +11,9 @@ from aixplain.factories.index_factory.utils import AirParams, VectaraParams, GraphRAGParams, ZeroEntropyParams import time import os +import json +CACHE_FOLDER = ".cache" def pytest_generate_tests(metafunc): if "llm_model" in metafunc.fixturenames: @@ -88,7 +90,7 @@ def run_index_model(index_model, retries): [Record(value="Berlin is the capital of Germany.", value_type="text", uri="", id="1", attributes={})] ) break - except Exception as e: + except Exception: time.sleep(180) response = index_model.search("Berlin") @@ -143,9 +145,6 @@ def test_index_model_with_filter(embedding_model, supplier_params): from aixplain.factories import IndexFactory from aixplain.modules.model.index_model import IndexFilter, IndexFilterOperator - for index in IndexFactory.list()["results"]: - index.delete() - params = supplier_params(name=str(uuid4()), description=str(uuid4())) if embedding_model is not None: params = supplier_params(name=str(uuid4()), description=str(uuid4()), embedding_model=embedding_model) @@ -212,7 +211,7 @@ def test_aixplain_model_cache_creation(): # Instantiate the Model (replace 
this with a real model ID from your env) model_id = "6239efa4822d7a13b8e20454" # Translate from Punjabi to Portuguese (Brazil) - _ = Model(id=model_id) + _ = ModelFactory.get(model_id) # Assert the cache file was created assert os.path.exists(cache_file), "Expected cache file was not created." @@ -230,9 +229,6 @@ def test_index_model_air_with_image(): from uuid import uuid4 from aixplain.factories.index_factory.utils import AirParams - for index in IndexFactory.list()["results"]: - index.delete() - params = AirParams( name=f"Image Index {uuid4()}", description="Index for images", embedding_model=EmbeddingModel.JINA_CLIP_V2_MULTIMODAL ) @@ -305,9 +301,6 @@ def test_index_model_air_with_splitter(embedding_model, supplier_params): from aixplain.modules.model.index_model import Splitter from aixplain.enums.splitting_options import SplittingOptions - for index in IndexFactory.list()["results"]: - index.delete() - params = supplier_params( name=f"Splitter Index {uuid4()}", description="Index for splitter", embedding_model=embedding_model ) @@ -324,6 +317,103 @@ def test_index_model_air_with_splitter(embedding_model, supplier_params): index_model.delete() +def test_index_model_with_txt_file(): + """Testing Index Model with local txt file input""" + from aixplain.factories import IndexFactory + from uuid import uuid4 + from aixplain.factories.index_factory.utils import AirParams + from pathlib import Path + + # Create test file path + test_file_path = Path(__file__).parent / "data" / "test_input.txt" + + # Create index with OpenAI Ada 002 for text processing + params = AirParams( + name=f"File Index {uuid4()}", description="Index for file processing", embedding_model=EmbeddingModel.OPENAI_ADA002 + ) + index_model = IndexFactory.create(params=params) + + try: + # Upsert the file + response = index_model.upsert(str(test_file_path)) + assert str(response.status) == "SUCCESS" + + # Verify the content was indexed + response = index_model.search("demo") + assert 
str(response.status) == "SUCCESS" + assert "🤖" in response.data, "Robot emoji should be present in the response" + + # Verify count + assert index_model.count() > 0 + + finally: + # Cleanup + index_model.delete() + + +def test_index_model_with_pdf_file(): + """Testing Index Model with PDF file input""" + from aixplain.factories import IndexFactory + from uuid import uuid4 + from aixplain.factories.index_factory.utils import AirParams + from pathlib import Path + + # Create test file path + test_file_path = Path(__file__).parent / "data" / "test_file_parser_input.pdf" + + # Create index with OpenAI Ada 002 for text processing + params = AirParams( + name=f"PDF Index {uuid4()}", description="Index for PDF processing", embedding_model=EmbeddingModel.OPENAI_ADA002 + ) + index_model = IndexFactory.create(params=params) + + try: + # Upsert the PDF file + response = index_model.upsert(str(test_file_path)) + assert str(response.status) == "SUCCESS" + + # Verify the content was indexed + response = index_model.search("document") + assert str(response.status) == "SUCCESS" + assert len(response.data) > 0 + + # Verify count + assert index_model.count() > 0 + + finally: + # Cleanup + index_model.delete() + + +def test_index_model_with_invalid_file(): + """Testing Index Model with invalid file input""" + from aixplain.factories import IndexFactory + from uuid import uuid4 + from aixplain.factories.index_factory.utils import AirParams + from pathlib import Path + + # Create non-existent file path + test_file_path = Path(__file__).parent / "data" / "nonexistent.pdf" + + # Create index with OpenAI Ada 002 for text processing + params = AirParams( + name=f"Invalid File Index {uuid4()}", + description="Index for invalid file testing", + embedding_model=EmbeddingModel.OPENAI_ADA002, + ) + index_model = IndexFactory.create(params=params) + + try: + # Attempt to upsert non-existent file + with pytest.raises(Exception) as e: + index_model.upsert(str(test_file_path)) + assert "does not 
exist" in str(e.value) + + finally: + # Cleanup + index_model.delete() + + def _test_records(): from aixplain.modules.model.record import Record from aixplain.enums import DataType @@ -373,11 +463,6 @@ def setup_index_with_test_records(): from aixplain.enums import EmbeddingModel from aixplain.factories.index_factory.utils import AirParams from uuid import uuid4 - import time - - # Clean up all existing indexes - for index in IndexFactory.list()["results"]: - index.delete() params = AirParams( name=f"Test Index {uuid4()}", diff --git a/tests/functional/team_agent/evolver_test.py b/tests/functional/team_agent/evolver_test.py new file mode 100644 index 00000000..47c45e45 --- /dev/null +++ b/tests/functional/team_agent/evolver_test.py @@ -0,0 +1,138 @@ +import pytest +from aixplain.enums.function import Function +from aixplain.enums.supplier import Supplier +from aixplain.enums import ResponseStatus +from aixplain.factories.agent_factory import AgentFactory +from aixplain.factories.team_agent_factory import TeamAgentFactory +import time + + +team_dict = { + "team_agent_name": "Test Text Speech Team", + "llm_id": "6646261c6eb563165658bbb1", + "llm_name": "GPT4o", + "query": "Translate this text into Portuguese: 'This is a test'. Translate to pt and synthesize in audio", + "description": "You are a text translation and speech synthesizing agent. You will be provided a text in the source language and expected to translate and synthesize in the target language.", + "agents": [ + { + "agent_name": "Text Translation agent", + "llm_id": "6646261c6eb563165658bbb1", + "llm_name": "GPT4o", + "description": "## ROLE\nText Translator\n\n## GOAL\nTranslate the text supplied into the users desired language.\n\n## BACKSTORY\nYou are a text translation agent. 
You will be provided a text in the source language and expected to translate in the target language.", + "tasks": [ + { + "name": "Text translation", + "description": "Translate a text from source language (English) to target language (Portuguese)", + "expected_output": "target language text", + } + ], + "model_tools": [{"function": "translation", "supplier": "AWS"}], + }, + { + "agent_name": "Test Speech Synthesis agent", + "llm_id": "6646261c6eb563165658bbb1", + "llm_name": "GPT4o", + "description": "## ROLE\nSpeech Synthesizer\n\n## GOAL\nTranscribe the translated text into speech.\n\n## BACKSTORY\nYou are a speech synthesizing agent. You will be provided a text to synthesize into audio and return the audio link.", + "tasks": [ + { + "name": "Speech synthesis", + "description": "Synthesize a text from text to speech", + "expected_output": "audio link of the synthesized text", + "dependencies": ["Text translation"], + } + ], + "model_tools": [{"function": "speech_synthesis", "supplier": "Google"}], + }, + ], +} + + +def parse_tools(tools_info): + tools = [] + for tool in tools_info: + function_enum = Function[tool["function"].upper().replace(" ", "_")] + supplier_enum = Supplier[tool["supplier"].upper().replace(" ", "_")] + tools.append(AgentFactory.create_model_tool(function=function_enum, supplier=supplier_enum)) + return tools + + +def build_team_agent_from_json(team_config: dict): + agents_data = team_config["agents"] + tasks_data = team_config.get("tasks", []) + + agent_objs = [] + for agent_entry in agents_data: + agent_name = agent_entry["agent_name"] + agent_description = agent_entry["description"] + agent_llm_id = agent_entry.get("llm_id", None) + + agent_tasks = [] + for task in tasks_data: + task_name = task.get("task_name", "") + task_info = task + + if agent_name == task_info["agent"]: + task_obj = AgentFactory.create_task( + name=task_name.replace("_", " "), + description=task_info.get("description", ""), + 
expected_output=task_info.get("expected_output", ""), + dependencies=[t.replace("_", " ") for t in task_info.get("dependencies", [])], + ) + agent_tasks.append(task_obj) + + if "model_tools" in agent_entry: + agent_tools = parse_tools(agent_entry["model_tools"]) + else: + agent_tools = [] + + agent_obj = AgentFactory.create( + name=agent_name.replace("_", " "), + description=agent_description, + tools=agent_tools, + tasks=agent_tasks, + llm_id=agent_llm_id, + ) + agent_objs.append(agent_obj) + + return TeamAgentFactory.create( + name=team_config["team_agent_name"], + agents=agent_objs, + description=team_config["description"], + llm_id=team_config.get("llm_id", None), + inspectors=[], + use_mentalist=True, + ) + + +@pytest.fixture +def team_agent(): + return build_team_agent_from_json(team_dict) + + +def test_evolver_output(team_agent): + response = team_agent.evolve() + poll_url = response["url"] + result = team_agent.poll(poll_url) + + while result.status == ResponseStatus.IN_PROGRESS: + time.sleep(30) + result = team_agent.poll(poll_url) + + assert "system" in result["data"]["evolved_agent"]["name"].lower(), "System should be in the system name" + assert result["status"] == ResponseStatus.SUCCESS, "Final result should have a 'SUCCESS' status" + assert "evolved_agent" in result["data"], "Data should contain 'evolved_agent'" + assert "evaluation_report" in result["data"], "Data should contain 'evaluation_report'" + assert "criteria" in result["data"], "Data should contain 'criteria'" + assert "archive" in result["data"], "Data should contain 'archive'" + + +def test_evolver_with_custom_llm_id(team_agent): + """Test evolver functionality with custom LLM ID""" + custom_llm_id = "6646261c6eb563165658bbb1" # GPT-4o ID + + # Test with llm parameter + response = team_agent.evolve_async(llm=custom_llm_id) + + assert response is not None + assert "url" in response or response.get("url") is not None + assert response["status"] == ResponseStatus.IN_PROGRESS diff --git 
a/tests/functional/team_agent/inspector_functional_test.py b/tests/functional/team_agent/inspector_functional_test.py index e1679dde..9e50eac0 100644 --- a/tests/functional/team_agent/inspector_functional_test.py +++ b/tests/functional/team_agent/inspector_functional_test.py @@ -1,5 +1,7 @@ """ Functional tests for team agents with inspectors. + +WARNING: This feature is currently in private beta. """ from dotenv import load_dotenv @@ -10,10 +12,12 @@ import pytest from aixplain import aixplain_v2 as v2 -from aixplain.factories import AgentFactory, TeamAgentFactory +from aixplain.factories import AgentFactory, TeamAgentFactory, ModelFactory from aixplain.enums.asset_status import AssetStatus from aixplain.modules.team_agent import InspectorTarget -from aixplain.modules.team_agent.inspector import Inspector, InspectorPolicy +from aixplain.modules.team_agent.inspector import Inspector, InspectorPolicy, InspectorAction, InspectorOutput +from aixplain.modules.model.response import ModelResponse +from aixplain.enums.response_status import ResponseStatus from tests.functional.team_agent.test_utils import ( RUN_FILE, @@ -24,6 +28,36 @@ ) +# Define callable policy functions at module level for proper serialization +def process_response(model_response: ModelResponse, input_content: str) -> InspectorOutput: + """Basic callable policy function for testing.""" + if "error" in model_response.error_message.lower() or "invalid" in model_response.data.lower(): + return InspectorOutput(critiques="Error or invalid content detected", content_edited="", action=InspectorAction.ABORT) + elif "warning" in model_response.data.lower(): + return InspectorOutput(critiques="Warning detected", content_edited="", action=InspectorAction.RERUN) + return InspectorOutput(critiques="No issues detected", content_edited="", action=InspectorAction.CONTINUE) + + +def process_response_abort(model_response: ModelResponse, input_content: str) -> InspectorOutput: + """Callable policy function that aborts on 
specific content.""" + abort_keywords = ["dangerous", "harmful", "illegal", "inappropriate"] + for keyword in abort_keywords: + if keyword in model_response.data.lower(): + return InspectorOutput( + critiques=f"Abort keyword '{keyword}' detected", content_edited="", action=InspectorAction.ABORT + ) + return InspectorOutput(critiques="No abort keywords detected", content_edited="", action=InspectorAction.CONTINUE) + + +def process_response_rerun(model_response: ModelResponse, input_content: str) -> InspectorOutput: + """Callable policy function that triggers rerun on specific conditions.""" + if len(model_response.data.strip()) < 10 or "placeholder" in model_response.data.lower(): + return InspectorOutput( + critiques="Content too short or contains placeholder", content_edited="", action=InspectorAction.RERUN + ) + return InspectorOutput(critiques="Content is acceptable", content_edited="", action=InspectorAction.CONTINUE) + + @pytest.fixture(scope="function") def delete_agents_and_team_agents(): for team_agent in TeamAgentFactory.list()["results"]: @@ -557,3 +591,218 @@ def test_team_agent_with_input_adaptive_inspector(run_input_map, delete_agents_a ), "The mentalist input does not contain the revised query from the last query_manager" team_agent.delete() + + +@pytest.mark.parametrize("TeamAgentFactory", [TeamAgentFactory, v2.TeamAgent]) +def test_team_agent_with_callable_policy(run_input_map, delete_agents_and_team_agents, TeamAgentFactory): + """Comprehensive test of callable policy functionality with team agent integration""" + assert delete_agents_and_team_agents + + agents = create_agents_from_input_map(run_input_map) + + # Test 1: Create inspector with callable policy + inspector = Inspector( + name="callable_inspector", + model_id=run_input_map["llm_id"], + model_params={"prompt": "Check if the steps are valid and provide feedback"}, + policy=process_response, # Using module-level callable policy + ) + + # Test 2: Verify the inspector was created correctly + 
assert inspector.name == "callable_inspector" + assert callable(inspector.policy) + assert inspector.policy.__name__ == "process_response" + + # Test 3: Verify the callable policy works correctly + result1 = inspector.policy( + ModelResponse(status=ResponseStatus.FAILED, error_message="This is an error message", data="input"), "input" + ) + assert result1.action == InspectorAction.ABORT + + result2 = inspector.policy( + ModelResponse(status=ResponseStatus.SUCCESS, data="This is a warning message", error_message=""), "input" + ) + assert result2.action == InspectorAction.RERUN + + result3 = inspector.policy( + ModelResponse(status=ResponseStatus.SUCCESS, data="This is a normal message", error_message=""), "input" + ) + assert result3.action == InspectorAction.CONTINUE + + # Test 4: Create team agent with callable policy inspector + team_agent = create_team_agent( + TeamAgentFactory, + agents, + run_input_map, + use_mentalist=True, + inspectors=[inspector], + inspector_targets=[InspectorTarget.STEPS], + ) + + assert team_agent is not None + assert team_agent.status == AssetStatus.DRAFT + + # Test 5: Deploy team agent (backend properly handles callable policies) + team_agent.deploy() + team_agent = TeamAgentFactory.get(team_agent.id) + assert team_agent is not None + assert team_agent.status == AssetStatus.ONBOARDED + + # Test 6: Verify backend properly handles callable policies + assert len(team_agent.inspectors) == 1 + backend_inspector = team_agent.inspectors[0] + assert backend_inspector.name == "callable_inspector" + # Backend should properly handle callable policies, not fall back to ADAPTIVE + assert callable(backend_inspector.policy) + assert backend_inspector.policy.__name__ == "process_response" + + # Verify the backend-preserved callable policy still works correctly + assert ( + backend_inspector.policy( + ModelResponse(status=ResponseStatus.FAILED, error_message="This is an error message", data="input"), "input" + ).action + == InspectorAction.ABORT + ) + 
assert ( + backend_inspector.policy( + ModelResponse(status=ResponseStatus.SUCCESS, data="This is a warning message", error_message=""), "input" + ).action + == InspectorAction.RERUN + ) + assert ( + backend_inspector.policy( + ModelResponse(status=ResponseStatus.SUCCESS, data="This is a normal message", error_message=""), "input" + ).action + == InspectorAction.CONTINUE + ) + + team_agent.delete() + + +@pytest.mark.parametrize("TeamAgentFactory", [TeamAgentFactory, v2.TeamAgent]) +def test_inspector_action_verification(run_input_map, delete_agents_and_team_agents, TeamAgentFactory): + """Test that inspector actions are properly executed and their results are verified""" + assert delete_agents_and_team_agents + + agents = create_agents_from_input_map(run_input_map) + + # Create a custom callable policy that always returns ABORT + def process_response(model_response: ModelResponse, input_content: str) -> InspectorOutput: + """Custom policy that always returns ABORT for safety testing.""" + return InspectorOutput(critiques="Safety check", content_edited="", action=InspectorAction.ABORT) + + # Create inspector with custom callable policy + inspector = Inspector( + name="custom_abort_inspector", + model_id=run_input_map["llm_id"], + model_params={"prompt": "You are a safety inspector."}, + policy=process_response, + ) + + # Create team agent with the custom policy inspector + team_agent = create_team_agent( + TeamAgentFactory, + agents, + run_input_map, + use_mentalist=True, + inspectors=[inspector], + inspector_targets=[InspectorTarget.STEPS], + ) + + # Deploy and run team agent + team_agent.deploy() + team_agent = TeamAgentFactory.get(team_agent.id) + response = team_agent.run(data=run_input_map["query"]) + + assert response is not None + assert response["completed"] is True + assert response["status"].lower() == "success" + + # Extract steps from response + steps = getattr(response.data, "intermediate_steps", []) if hasattr(response, "data") else [] + + # Find 
inspector steps + inspector_steps = [step for step in steps if "inspector" in step.get("agent", "").lower()] + + # If no inspector steps found, backend may not be using custom policies + if not inspector_steps: + print("No inspector steps found - backend may not be using custom policies") + team_agent.delete() + return + + # Verify inspector executed and took ABORT action + inspector_step = inspector_steps[0] + assert inspector_step.get("action") == "abort", "Inspector should have returned ABORT" + + # Verify response generator ran after inspector + response_generator_steps = [step for step in steps if "response_generator" in step.get("agent", "").lower()] + assert len(response_generator_steps) == 1, "Response generator should run exactly once after ABORT" + + team_agent.delete() + + +@pytest.mark.parametrize("TeamAgentFactory", [TeamAgentFactory, v2.TeamAgent]) +def test_team_agent_with_utility_inspector(run_input_map, delete_agents_and_team_agents, TeamAgentFactory): + """Test team agent with a Utility model as inspector""" + assert delete_agents_and_team_agents + + agents = create_agents_from_input_map(run_input_map) + + def lowercase_inspector(content: str) -> bool: + if content.islower(): + return "" + else: + return "Content is not all lowercase. There are uppercase characters in the content." + + utility_model = ModelFactory.create_utility_model( + name="Lowercase Inspector Test", + description="Inspect the content of the response. 
If the content is not all lowercase, provide feedback.'", + code=lowercase_inspector, + ) + utility_model.deploy() + + utility_model_id = utility_model.id + inspector = Inspector( + name="utility_inspector", + model_id=utility_model_id, + policy=InspectorPolicy.WARN, + ) + + # Create team agent with steps inspector + team_agent = create_team_agent( + TeamAgentFactory, + agents, + run_input_map, + use_mentalist=True, + inspectors=[inspector], + inspector_targets=[InspectorTarget.STEPS], + ) + + assert team_agent is not None + assert team_agent.status == AssetStatus.DRAFT + + # deploy team agent + team_agent.deploy() + team_agent = TeamAgentFactory.get(team_agent.id) + assert team_agent is not None + assert team_agent.status == AssetStatus.ONBOARDED + + # Run the team agent + response = team_agent.run(data=run_input_map["query"]) + + assert response is not None + assert response["completed"] is True + assert response["status"].lower() == "success" + + # Check for inspector steps + if "intermediate_steps" in response["data"]: + steps = response["data"]["intermediate_steps"] + verify_inspector_steps(steps, ["utility_inspector"], [InspectorTarget.STEPS]) + verify_response_generator(steps) + + # Verify inspector runs and execution continues + inspector_steps = [step for step in steps if "utility_inspector" in step.get("agent", "").lower()] + assert len(inspector_steps) > 0, "Utility inspector should run at least once" + + utility_model.delete() + team_agent.delete() diff --git a/tests/functional/team_agent/team_agent_functional_test.py b/tests/functional/team_agent/team_agent_functional_test.py index 5b9b4972..e8352e1d 100644 --- a/tests/functional/team_agent/team_agent_functional_test.py +++ b/tests/functional/team_agent/team_agent_functional_test.py @@ -124,7 +124,10 @@ def test_fail_non_existent_llm(run_input_map, TeamAgentFactory): llm_id="non_existent_llm", agents=agents, ) - assert str(exc_info.value) == "Large Language Model with ID 'non_existent_llm' not found." 
+ assert ( + str(exc_info.value) + == "TeamAgent Onboarding Error: LLM non_existent_llm does not exist for Main LLM. To resolve this, set the following LLM parameters to a valid LLM object or LLM ID: llm, supervisor_llm, mentalist_llm." + ) @pytest.mark.parametrize("TeamAgentFactory", [TeamAgentFactory, v2.TeamAgent]) @@ -288,7 +291,7 @@ def test_team_agent_with_instructions(delete_agents_and_team_agents): instructions="Use only 'Agent 2' to solve the tasks.", llm_id="6646261c6eb563165658bbb1", use_mentalist=True, - use_inspector=False, + inspectors=[], ) response = team_agent.run(data="Translate 'cat' to Portuguese") @@ -405,7 +408,7 @@ class Response(BaseModel): description="Team agent", llm_id="6646261c6eb563165658bbb1", use_mentalist=False, - use_inspector=False, + inspectors=[], ) # Run the team agent @@ -481,7 +484,7 @@ def test_team_agent_with_slack_connector(): description="Team agent", llm_id="6646261c6eb563165658bbb1", use_mentalist=False, - use_inspector=False, + inspectors=[], ) response = team_agent.run( diff --git a/tests/unit/agent/agent_factory_utils_test.py b/tests/unit/agent/agent_factory_utils_test.py index 5c7fbb77..89c1b695 100644 --- a/tests/unit/agent/agent_factory_utils_test.py +++ b/tests/unit/agent/agent_factory_utils_test.py @@ -10,7 +10,7 @@ from aixplain.modules.agent.tool.custom_python_code_tool import CustomPythonCodeTool from aixplain.modules.agent.tool.sql_tool import SQLTool from aixplain.modules.agent import Agent -from aixplain.modules.agent.agent_task import AgentTask +from aixplain.modules.agent.agent_task import WorkflowTask from aixplain.factories import ModelFactory, PipelineFactory import os @@ -50,7 +50,13 @@ def mock_tools(): "tool_dict,expected_error", [ pytest.param( - {"type": "model", "supplier": "aixplain", "version": "1.0", "assetId": "test_model", "description": "Test model"}, + { + "type": "model", + "supplier": "aixplain", + "version": "1.0", + "assetId": "test_model", + "description": "Test model", + }, 
"Function is required for model tools", id="missing_function", ), @@ -124,19 +130,30 @@ def test_build_tool_error_cases(tool_dict, expected_error): id="model_tool_with_params", ), pytest.param( - {"type": "pipeline", "description": "Test pipeline", "assetId": "test_pipeline"}, + { + "type": "pipeline", + "description": "Test pipeline", + "assetId": "test_pipeline", + }, PipelineTool, {"description": "Test pipeline", "pipeline": "test_pipeline"}, id="pipeline_tool", ), pytest.param( - {"type": "utility", "description": "Test utility", "utilityCode": "print('Hello World')"}, + { + "type": "utility", + "description": "Test utility", + "utilityCode": "print('Hello World')", + }, CustomPythonCodeTool, {"description": "Test utility", "code": "print('Hello World')"}, id="custom_python_tool", ), pytest.param( - {"type": "utility", "description": "Test utility"}, PythonInterpreterTool, {}, id="python_interpreter_tool" + {"type": "utility", "description": "Test utility"}, + PythonInterpreterTool, + {}, + id="python_interpreter_tool", ), pytest.param( { @@ -186,8 +203,6 @@ def test_build_tool_error_cases(tool_dict, expected_error): ), ], ) - - def test_build_tool_success_cases(tool_dict, expected_type, expected_attrs, mock_model, mocker): """Test successful tool creation with various configurations.""" mocker.patch.object(ModelFactory, "get", return_value=mock_model) @@ -195,13 +210,24 @@ def test_build_tool_success_cases(tool_dict, expected_type, expected_attrs, mock "aixplain.modules.model.utils.parse_code_decorated", return_value=("print('Hello World')", [], "Test description", "test_name"), ) - mocker.patch("os.path.exists", lambda path: True if path == "test_db.db" else os.path.exists(path)) - mocker.patch("aixplain.factories.file_factory.FileFactory.upload", return_value="s3://mocked-file-path/test_db.db") + mocker.patch( + "os.path.exists", + lambda path: True if path == "test_db.db" else os.path.exists(path), + ) + mocker.patch( + 
"aixplain.factories.file_factory.FileFactory.upload", + return_value="s3://mocked-file-path/test_db.db", + ) if tool_dict["type"] == "pipeline": mocker.patch.object( PipelineFactory, "get", - return_value=Pipeline(id=tool_dict["assetId"], description=tool_dict["description"], name="Pipeline", api_key=""), + return_value=Pipeline( + id=tool_dict["assetId"], + description=tool_dict["description"], + name="Pipeline", + api_key="", + ), ) tool = build_tool(tool_dict) @@ -273,7 +299,11 @@ def test_build_tool_success_cases(tool_dict, expected_type, expected_attrs, mock "description": "Test model", "function": "speech-recognition", }, - {"type": "pipeline", "description": "Test pipeline", "assetId": "test_pipeline"}, + { + "type": "pipeline", + "description": "Test pipeline", + "assetId": "test_pipeline", + }, ], }, { @@ -286,8 +316,6 @@ def test_build_tool_success_cases(tool_dict, expected_type, expected_attrs, mock ), ], ) - - def test_build_agent_success_cases(payload, expected_attrs, mock_tools, mocker): """Test successful agent creation with various configurations.""" mocker.patch.object( @@ -301,9 +329,9 @@ def test_build_agent_success_cases(payload, expected_attrs, mock_tools, mocker): for attr, value in expected_attrs.items(): if attr == "tasks": - assert len(agent.tasks) == len(value) - for task, expected_task in zip(agent.tasks, value): - assert isinstance(task, AgentTask) + assert len(agent.workflow_tasks) == len(value) + for task, expected_task in zip(agent.workflow_tasks, value): + assert isinstance(task, WorkflowTask) for task_attr, task_value in expected_task.items(): assert getattr(task, task_attr) == task_value elif attr == "tools": @@ -322,7 +350,13 @@ def test_build_agent_success_cases(payload, expected_attrs, mock_tools, mocker): "id": "test_agent", "name": "Test Agent", "status": "onboarded", - "assets": [{"type": "invalid_type", "description": "Test tool", "assetId": "invalid_asset"}], + "assets": [ + { + "type": "invalid_type", + "description": 
"Test tool", + "assetId": "invalid_asset", + } + ], }, "Agent Creation Error: Tool type not supported", id="invalid_tool_type", @@ -369,7 +403,13 @@ def test_build_agent_success_cases(payload, expected_attrs, mock_tools, mocker): "id": "test_agent", "name": "Test Agent", "status": "onboarded", - "assets": [{"type": "model", "assetId": "test_model", "function": "speech-recognition"}], + "assets": [ + { + "type": "model", + "assetId": "test_model", + "function": "speech-recognition", + } + ], }, "Tool test_model is not available. Make sure it exists or you have access to it. If you think this is an error, please contact the administrators.", id="generic_error", diff --git a/tests/unit/agent/agent_test.py b/tests/unit/agent/agent_test.py index fc66141b..24a78b7c 100644 --- a/tests/unit/agent/agent_test.py +++ b/tests/unit/agent/agent_test.py @@ -803,13 +803,13 @@ def test_create_agent_task(): assert task.name == "Test Task" assert task.description == "Test Description" assert task.expected_output == "Test Output" - assert task.dependencies is None + assert task.dependencies == [] task_dict = task.to_dict() assert task_dict["name"] == "Test Task" assert task_dict["description"] == "Test Description" assert task_dict["expectedOutput"] == "Test Output" - assert task_dict["dependencies"] is None + assert task_dict["dependencies"] == [] def test_agent_response(): diff --git a/tests/unit/agent/evolve_param_test.py b/tests/unit/agent/evolve_param_test.py new file mode 100644 index 00000000..2f0f744e --- /dev/null +++ b/tests/unit/agent/evolve_param_test.py @@ -0,0 +1,187 @@ +""" +Unit tests for EvolveParam base model functionality +""" + +import pytest +from aixplain.modules.agent.evolve_param import ( + EvolveParam, + EvolveType, + validate_evolve_param, +) + + +class TestEvolveParam: + """Test class for EvolveParam functionality""" + + def test_default_initialization(self): + """Test EvolveParam default initialization""" + default_param = EvolveParam() + + assert 
default_param is not None + assert default_param.to_evolve is False + assert default_param.evolve_type == EvolveType.TEAM_TUNING + assert default_param.max_successful_generations == 3 + assert default_param.max_failed_generation_retries == 3 + assert default_param.max_iterations == 50 + assert default_param.max_non_improving_generations == 2 + assert default_param.llm is None + assert default_param.additional_params == {} + + # Test to_dict method + result_dict = default_param.to_dict() + assert isinstance(result_dict, dict) + assert "toEvolve" in result_dict + + def test_custom_initialization(self): + """Test EvolveParam custom initialization""" + custom_param = EvolveParam( + to_evolve=True, + max_successful_generations=5, + max_failed_generation_retries=2, + max_iterations=30, + max_non_improving_generations=4, + evolve_type=EvolveType.TEAM_TUNING, + llm={"id": "test_llm_id", "name": "Test LLM"}, + additional_params={"customParam": "custom_value"}, + ) + + assert custom_param.to_evolve is True + assert custom_param.max_successful_generations == 5 + assert custom_param.max_failed_generation_retries == 2 + assert custom_param.max_iterations == 30 + assert custom_param.max_non_improving_generations == 4 + assert custom_param.evolve_type == EvolveType.TEAM_TUNING + assert custom_param.llm == {"id": "test_llm_id", "name": "Test LLM"} + assert custom_param.additional_params == {"customParam": "custom_value"} + + # Test to_dict method + result_dict = custom_param.to_dict() + assert result_dict["toEvolve"] is True + assert result_dict["max_successful_generations"] == 5 + assert result_dict["max_failed_generation_retries"] == 2 + assert result_dict["max_iterations"] == 30 + assert result_dict["max_non_improving_generations"] == 4 + assert result_dict["evolve_type"] == EvolveType.TEAM_TUNING + assert result_dict["llm"] == {"id": "test_llm_id", "name": "Test LLM"} + assert result_dict["customParam"] == "custom_value" + + def test_from_dict_with_api_format(self): + """Test 
EvolveParam from_dict() with API format""" + api_dict = { + "toEvolve": True, + "max_successful_generations": 10, + "max_failed_generation_retries": 4, + "max_iterations": 40, + "max_non_improving_generations": 5, + "evolve_type": EvolveType.TEAM_TUNING, + "llm": {"id": "api_llm_id", "name": "API LLM"}, + "customParam": "custom_value", + } + + from_dict_param = EvolveParam.from_dict(api_dict) + + assert from_dict_param.to_evolve is True + assert from_dict_param.max_successful_generations == 10 + assert from_dict_param.max_failed_generation_retries == 4 + assert from_dict_param.max_iterations == 40 + assert from_dict_param.max_non_improving_generations == 5 + assert from_dict_param.evolve_type == EvolveType.TEAM_TUNING + assert from_dict_param.llm == {"id": "api_llm_id", "name": "API LLM"} + + # Test round-trip conversion + result_dict = from_dict_param.to_dict() + assert result_dict["toEvolve"] is True + assert result_dict["max_successful_generations"] == 10 + assert result_dict["max_failed_generation_retries"] == 4 + assert result_dict["max_iterations"] == 40 + assert result_dict["max_non_improving_generations"] == 5 + + def test_validate_evolve_param_with_none(self): + """Test validate_evolve_param() with None input""" + validated_none = validate_evolve_param(None) + + assert validated_none is not None + assert isinstance(validated_none, EvolveParam) + assert validated_none.to_evolve is False + + result_dict = validated_none.to_dict() + assert "toEvolve" in result_dict + + def test_validate_evolve_param_with_dict(self): + """Test validate_evolve_param() with dictionary input""" + input_dict = {"toEvolve": True, "max_successful_generations": 5} + validated_dict = validate_evolve_param(input_dict) + + assert isinstance(validated_dict, EvolveParam) + assert validated_dict.to_evolve is True + assert validated_dict.max_successful_generations == 5 + + result_dict = validated_dict.to_dict() + assert result_dict["toEvolve"] is True + assert 
result_dict["max_successful_generations"] == 5 + + def test_validate_evolve_param_with_instance(self): + """Test validate_evolve_param() with EvolveParam instance""" + custom_param = EvolveParam( + to_evolve=True, + max_successful_generations=5, + max_failed_generation_retries=2, + max_iterations=30, + max_non_improving_generations=4, + evolve_type=EvolveType.TEAM_TUNING, + llm={"id": "instance_llm_id"}, + additional_params={"customParam": "custom_value"}, + ) + + validated_instance = validate_evolve_param(custom_param) + + assert validated_instance is custom_param # Should return the same instance + assert validated_instance.to_evolve is True + assert validated_instance.max_successful_generations == 5 + assert validated_instance.max_failed_generation_retries == 2 + + def test_invalid_max_successful_generations_raises_error(self): + """Test that invalid max_successful_generations raises ValueError""" + with pytest.raises(ValueError, match="max_successful_generations must be positive"): + EvolveParam(max_successful_generations=0) # max_successful_generations <= 0 should fail + + def test_validate_evolve_param_missing_to_evolve_key(self): + """Test that missing toEvolve key raises ValueError""" + with pytest.raises(ValueError, match="evolve parameter must contain 'toEvolve' key"): + validate_evolve_param({"no_to_evolve": True}) # Missing toEvolve key + + def test_evolve_type_enum_values(self): + """Test that EvolveType enum values work correctly""" + param_team_tuning = EvolveParam(evolve_type=EvolveType.TEAM_TUNING) + + assert param_team_tuning.evolve_type == EvolveType.TEAM_TUNING + + # Test in to_dict conversion + dict_team_tuning = param_team_tuning.to_dict() + + assert "evolve_type" in dict_team_tuning + + def test_invalid_additional_params_type(self): + """Test that invalid additional_params type raises ValueError""" + with pytest.raises(ValueError, match="additional_params must be a dictionary"): + EvolveParam(additional_params="not a dict") + + def 
test_merge_with_dict(self): + """Test merging with a dictionary""" + base_param = EvolveParam(to_evolve=False, max_successful_generations=3, additional_params={"base": "value"}) + merge_dict = { + "toEvolve": True, + "max_successful_generations": 5, + "llm": {"id": "merged_llm_id"}, + "customParam": "custom_value", + } + + merged = base_param.merge(merge_dict) + + assert merged.to_evolve is True + assert merged.max_successful_generations == 5 + assert merged.llm == {"id": "merged_llm_id"} + assert merged.additional_params == { + "base": "value", + "customParam": "custom_value", + } diff --git a/tests/unit/agent/test_agent_evolve.py b/tests/unit/agent/test_agent_evolve.py new file mode 100644 index 00000000..f4aa1990 --- /dev/null +++ b/tests/unit/agent/test_agent_evolve.py @@ -0,0 +1,214 @@ +""" +Unit tests for Agent evolve functionality with llm parameter +""" + +import pytest +from unittest.mock import Mock, patch +from aixplain.modules.agent import Agent +from aixplain.modules.model.llm_model import LLM +from aixplain.modules.agent.evolve_param import EvolveParam +from aixplain.enums import EvolveType, Function, Supplier, ResponseStatus +from aixplain.modules.agent.agent_response import AgentResponse +from aixplain.modules.agent.agent_response_data import AgentResponseData + + +class TestAgentEvolve: + """Test class for Agent evolve functionality""" + + @pytest.fixture + def mock_agent(self): + """Create a mock Agent for testing""" + agent = Mock(spec=Agent) + agent.id = "test_agent_id" + agent.name = "Test Agent" + agent.api_key = "test_api_key" + return agent + + @pytest.fixture + def mock_llm(self): + """Create a mock LLM for testing""" + llm = Mock(spec=LLM) + llm.id = "test_llm_id" + llm.name = "Test LLM" + llm.description = "Test LLM Description" + llm.supplier = Supplier.OPENAI + llm.version = "1.0.0" + llm.function = Function.TEXT_GENERATION + llm.temperature = 0.7 + + # Mock get_parameters + mock_params = Mock() + mock_params.to_list.return_value = 
[{"name": "temperature", "type": "float"}] + llm.get_parameters.return_value = mock_params + + return llm + + def test_evolve_async_with_llm_string(self, mock_agent): + """Test evolve_async with llm as string ID""" + from aixplain.modules.agent import Agent + + # Create a real Agent instance but mock its methods + agent = Agent( + id="test_agent_id", + name="Test Agent", + description="Test Description", + instructions="Test Instructions", + tools=[], + llm_id="6646261c6eb563165658bbb1", + ) + + # Mock the run_async method + mock_response = AgentResponse( + status=ResponseStatus.IN_PROGRESS, + url="http://test-poll-url.com", + data=AgentResponseData(input="test input"), + run_time=0.0, + used_credits=0.0, + ) + + with patch.object(agent, "run_async", return_value=mock_response) as mock_run_async: + result = agent.evolve_async(llm="custom_llm_id_123") + + # Verify run_async was called with correct evolve parameter + mock_run_async.assert_called_once() + call_args = mock_run_async.call_args + + # Check that evolve parameter contains llm + evolve_param = call_args[1]["evolve"] + assert isinstance(evolve_param, EvolveParam) + assert evolve_param.llm == {"id": "custom_llm_id_123"} + + assert result == mock_response + + def test_evolve_async_with_llm_object(self, mock_agent, mock_llm): + """Test evolve_async with llm as LLM object""" + from aixplain.modules.agent import Agent + + # Create a real Agent instance but mock its methods + agent = Agent( + id="test_agent_id", + name="Test Agent", + description="Test Description", + instructions="Test Instructions", + tools=[], + llm_id="6646261c6eb563165658bbb1", + ) + + # Mock the run_async method + mock_response = AgentResponse( + status=ResponseStatus.IN_PROGRESS, + url="http://test-poll-url.com", + data=AgentResponseData(input="test input"), + run_time=0.0, + used_credits=0.0, + ) + + with patch.object(agent, "run_async", return_value=mock_response) as mock_run_async: + result = agent.evolve_async(llm=mock_llm) + + # Verify 
run_async was called with correct evolve parameter + mock_run_async.assert_called_once() + call_args = mock_run_async.call_args + + # Check that evolve parameter contains llm + evolve_param = call_args[1]["evolve"] + assert isinstance(evolve_param, EvolveParam) + + expected_llm_dict = { + "id": "test_llm_id", + "name": "Test LLM", + "description": "Test LLM Description", + "supplier": Supplier.OPENAI, + "version": "1.0.0", + "function": Function.TEXT_GENERATION, + "parameters": [{"name": "temperature", "type": "float"}], + "temperature": 0.7, + } + assert evolve_param.llm == expected_llm_dict + + assert result == mock_response + + def test_evolve_async_without_llm(self, mock_agent): + """Test evolve_async without llm parameter""" + from aixplain.modules.agent import Agent + + # Create a real Agent instance but mock its methods + agent = Agent( + id="test_agent_id", + name="Test Agent", + description="Test Description", + instructions="Test Instructions", + tools=[], + llm_id="6646261c6eb563165658bbb1", + ) + + # Mock the run_async method + mock_response = AgentResponse( + status=ResponseStatus.IN_PROGRESS, + url="http://test-poll-url.com", + data=AgentResponseData(input="test input"), + run_time=0.0, + used_credits=0.0, + ) + + with patch.object(agent, "run_async", return_value=mock_response) as mock_run_async: + result = agent.evolve_async() + + # Verify run_async was called with correct evolve parameter + mock_run_async.assert_called_once() + call_args = mock_run_async.call_args + + # Check that evolve parameter has llm as None + evolve_param = call_args[1]["evolve"] + assert isinstance(evolve_param, EvolveParam) + assert evolve_param.llm is None + + assert result == mock_response + + def test_evolve_with_custom_parameters(self, mock_agent): + """Test evolve with custom parameters including llm""" + from aixplain.modules.agent import Agent + + # Create a real Agent instance but mock its methods + agent = Agent( + id="test_agent_id", + name="Test Agent", + 
description="Test Description", + instructions="Test Instructions", + tools=[], + llm_id="6646261c6eb563165658bbb1", + ) + + with patch.object(agent, "evolve_async") as mock_evolve_async, patch.object(agent, "sync_poll") as mock_sync_poll: + + # Mock evolve_async response + mock_evolve_async.return_value = {"status": ResponseStatus.IN_PROGRESS, "url": "http://test-poll-url.com"} + + # Mock sync_poll response + mock_result = Mock() + mock_result.data = {"current_code": "test code", "evolved_agent": "evolved_agent_data"} + mock_sync_poll.return_value = mock_result + + result = agent.evolve( + evolve_type=EvolveType.TEAM_TUNING, + max_successful_generations=5, + max_failed_generation_retries=3, + max_iterations=40, + max_non_improving_generations=3, + llm="custom_llm_id", + ) + + # Verify evolve_async was called with correct parameters + mock_evolve_async.assert_called_once_with( + evolve_type=EvolveType.TEAM_TUNING, + max_successful_generations=5, + max_failed_generation_retries=3, + max_iterations=40, + max_non_improving_generations=3, + llm="custom_llm_id", + ) + + # Verify sync_poll was called + mock_sync_poll.assert_called_once_with("http://test-poll-url.com", name="evolve_process", timeout=600) + + assert result is not None diff --git a/tests/unit/agent/test_evolver_llm_utils.py b/tests/unit/agent/test_evolver_llm_utils.py new file mode 100644 index 00000000..1997115b --- /dev/null +++ b/tests/unit/agent/test_evolver_llm_utils.py @@ -0,0 +1,129 @@ +""" +Unit tests for evolver LLM utility functions +""" + +from unittest.mock import Mock +from aixplain.utils.evolve_utils import create_llm_dict +from aixplain.modules.model.llm_model import LLM +from aixplain.enums import Function, Supplier + + +class TestCreateLLMDict: + """Test class for create_llm_dict functionality""" + + def test_create_llm_dict_with_none(self): + """Test create_llm_dict with None input""" + result = create_llm_dict(None) + assert result is None + + def 
test_create_llm_dict_with_string_id(self): + """Test create_llm_dict with LLM ID string""" + llm_id = "test_llm_id_123" + result = create_llm_dict(llm_id) + + expected = {"id": llm_id} + assert result == expected + + def test_create_llm_dict_with_llm_object(self): + """Test create_llm_dict with LLM object""" + # Create a mock LLM object + mock_llm = Mock(spec=LLM) + mock_llm.id = "llm_id_456" + mock_llm.name = "Test LLM Model" + mock_llm.description = "A test LLM model for unit testing" + mock_llm.supplier = Supplier.OPENAI + mock_llm.version = "1.0.0" + mock_llm.function = Function.TEXT_GENERATION + mock_llm.temperature = 0.7 + + # Mock the get_parameters method + mock_parameters = Mock() + mock_parameters.to_list.return_value = [ + {"name": "max_tokens", "type": "integer", "default": 2048}, + {"name": "temperature", "type": "float", "default": 0.7}, + ] + mock_llm.get_parameters.return_value = mock_parameters + + result = create_llm_dict(mock_llm) + + expected = { + "id": "llm_id_456", + "name": "Test LLM Model", + "description": "A test LLM model for unit testing", + "supplier": Supplier.OPENAI, + "version": "1.0.0", + "function": Function.TEXT_GENERATION, + "parameters": [ + {"name": "max_tokens", "type": "integer", "default": 2048}, + {"name": "temperature", "type": "float", "default": 0.7}, + ], + "temperature": 0.7, + } + assert result == expected + + def test_create_llm_dict_with_llm_object_no_parameters(self): + """Test create_llm_dict with LLM object that has no parameters""" + # Create a mock LLM object + mock_llm = Mock(spec=LLM) + mock_llm.id = "llm_id_789" + mock_llm.name = "Simple LLM" + mock_llm.description = "A simple LLM without parameters" + mock_llm.supplier = Supplier.OPENAI + mock_llm.version = "2.0.0" + mock_llm.function = Function.TEXT_GENERATION + mock_llm.temperature = 0.5 + + # Mock get_parameters to return None + mock_llm.get_parameters.return_value = None + + result = create_llm_dict(mock_llm) + + expected = { + "id": "llm_id_789", + 
"name": "Simple LLM", + "description": "A simple LLM without parameters", + "supplier": Supplier.OPENAI, + "version": "2.0.0", + "function": Function.TEXT_GENERATION, + "parameters": None, + "temperature": 0.5, + } + assert result == expected + + def test_create_llm_dict_with_llm_object_no_temperature(self): + """Test create_llm_dict with LLM object that has no temperature attribute""" + # Create a mock LLM object without temperature + mock_llm = Mock(spec=LLM) + mock_llm.id = "llm_id_999" + mock_llm.name = "No Temp LLM" + mock_llm.description = "LLM without temperature" + mock_llm.supplier = Supplier.GOOGLE + mock_llm.version = "3.0.0" + mock_llm.function = Function.TEXT_GENERATION + + # Remove temperature attribute + del mock_llm.temperature + + # Mock get_parameters to return None + mock_llm.get_parameters.return_value = None + + result = create_llm_dict(mock_llm) + + expected = { + "id": "llm_id_999", + "name": "No Temp LLM", + "description": "LLM without temperature", + "supplier": Supplier.GOOGLE, + "version": "3.0.0", + "function": Function.TEXT_GENERATION, + "parameters": None, + "temperature": None, + } + assert result == expected + + def test_create_llm_dict_with_empty_string(self): + """Test create_llm_dict with empty string""" + result = create_llm_dict("") + + expected = {"id": ""} + assert result == expected diff --git a/tests/unit/index_model_test.py b/tests/unit/index_model_test.py index 8d5c3a74..5f5e13eb 100644 --- a/tests/unit/index_model_test.py +++ b/tests/unit/index_model_test.py @@ -252,3 +252,60 @@ def test_index_model_splitter(): assert splitter.split_by == "sentence" assert splitter.split_length == 100 assert splitter.split_overlap == 0 + + +def test_parse_file_success(mocker): + mock_response = {"status": "SUCCESS", "data": "parsed content"} + mock_model = mocker.Mock() + mock_model.run.return_value = ModelResponse(status=ResponseStatus.SUCCESS, data="parsed content") + + mocker.patch("aixplain.factories.ModelFactory.get", 
return_value=mock_model) + mocker.patch("os.path.exists", return_value=True) + + response = IndexModel.parse_file("test.pdf") + + assert isinstance(response, ModelResponse) + assert response.status == ResponseStatus.SUCCESS + assert response.data == "parsed content" + mock_model.run.assert_called_once_with("test.pdf") + + +def test_parse_file_not_found(): + with pytest.raises(Exception) as e: + IndexModel.parse_file("nonexistent.pdf") + assert str(e.value) == "File nonexistent.pdf does not exist" + + +def test_parse_file_error(mocker): + mocker.patch("os.path.exists", return_value=True) + mocker.patch("aixplain.factories.ModelFactory.get", side_effect=Exception("Model error")) + + with pytest.raises(Exception) as e: + IndexModel.parse_file("test.pdf") + assert str(e.value) == "Failed to parse file: Model error" + + +def test_upsert_with_file_path(mocker): + mock_parse_response = ModelResponse(status=ResponseStatus.SUCCESS, data="parsed content") + mock_upsert_response = {"status": "SUCCESS"} + + mocker.patch("aixplain.modules.model.index_model.IndexModel.parse_file", return_value=mock_parse_response) + mocker.patch("aixplain.factories.FileFactory.check_storage_type", return_value=StorageType.TEXT) + + with requests_mock.Mocker() as mock: + mock.post(execute_url, json=mock_upsert_response, status_code=200) + index_model = IndexModel(id=index_id, data=data, name="name", function=Function.SEARCH) + response = index_model.upsert("test.pdf") + + assert isinstance(response, ModelResponse) + assert response.status == ResponseStatus.SUCCESS + + +def test_upsert_with_invalid_file_path(mocker): + mocker.patch("aixplain.modules.model.index_model.IndexModel.parse_file", side_effect=Exception("File not found")) + + index_model = IndexModel(id=index_id, data=data, name="name", function=Function.SEARCH) + + with pytest.raises(Exception) as e: + index_model.upsert("nonexistent.pdf") + assert str(e.value) == "File not found" diff --git a/tests/unit/team_agent/inspector_test.py 
b/tests/unit/team_agent/inspector_test.py index e21d6c53..aa8933d1 100644 --- a/tests/unit/team_agent/inspector_test.py +++ b/tests/unit/team_agent/inspector_test.py @@ -1,9 +1,21 @@ import pytest from unittest.mock import patch, MagicMock -from aixplain.modules.team_agent.inspector import Inspector, InspectorPolicy, InspectorAuto, AUTO_DEFAULT_MODEL_ID +from aixplain.modules.team_agent.inspector import ( + Inspector, + InspectorPolicy, + InspectorAuto, + AUTO_DEFAULT_MODEL_ID, + InspectorAction, + InspectorOutput, + callable_to_code_string, + code_string_to_callable, + get_policy_source, +) from aixplain.factories.team_agent_factory.inspector_factory import InspectorFactory from aixplain.enums.function import Function from aixplain.enums.asset_status import AssetStatus +from aixplain.modules.model.response import ModelResponse +from aixplain.enums.response_status import ResponseStatus # Test data INSPECTOR_CONFIG = { @@ -27,118 +39,518 @@ } -def test_inspector_creation(): - """Test basic inspector creation with valid parameters""" - inspector = Inspector( - name=INSPECTOR_CONFIG["name"], - model_id=INSPECTOR_CONFIG["model_id"], - model_params=INSPECTOR_CONFIG["model_config"], - policy=INSPECTOR_CONFIG["policy"], - ) +class TestInspectorCreation: + """Test inspector creation with various configurations""" - assert inspector.name == INSPECTOR_CONFIG["name"] - assert inspector.model_id == INSPECTOR_CONFIG["model_id"] - assert inspector.model_params == INSPECTOR_CONFIG["model_config"] - assert inspector.policy == INSPECTOR_CONFIG["policy"] - assert inspector.auto is None + def test_basic_inspector_creation(self): + """Test basic inspector creation with valid parameters""" + inspector = Inspector( + name=INSPECTOR_CONFIG["name"], + model_id=INSPECTOR_CONFIG["model_id"], + model_params=INSPECTOR_CONFIG["model_config"], + policy=INSPECTOR_CONFIG["policy"], + ) + + assert inspector.name == INSPECTOR_CONFIG["name"] + assert inspector.model_id == 
INSPECTOR_CONFIG["model_id"] + assert inspector.model_params == INSPECTOR_CONFIG["model_config"] + assert inspector.policy == INSPECTOR_CONFIG["policy"] + assert inspector.auto is None + def test_inspector_with_callable_policy(self): + """Test inspector creation with valid callable policy""" -def test_inspector_auto_creation(): - """Test inspector creation with auto configuration""" - inspector = Inspector(name="auto_inspector", auto=InspectorAuto.CORRECTNESS, policy=InspectorPolicy.WARN) + def process_response(model_response: ModelResponse, input_content: str) -> InspectorOutput: + if "error" in model_response.error_message.lower(): + return InspectorOutput(critiques="Error detected", content_edited="", action=InspectorAction.ABORT) + return InspectorOutput(critiques="No issues", content_edited="", action=InspectorAction.CONTINUE) - assert inspector.name == "auto_inspector" - assert inspector.auto == InspectorAuto.CORRECTNESS - assert inspector.policy == InspectorPolicy.WARN - assert inspector.model_id == AUTO_DEFAULT_MODEL_ID - assert inspector.model_params is None + inspector = Inspector( + name=INSPECTOR_CONFIG["name"], + model_id=INSPECTOR_CONFIG["model_id"], + model_params=INSPECTOR_CONFIG["model_config"], + policy=process_response, + ) + assert inspector.name == INSPECTOR_CONFIG["name"] + assert callable(inspector.policy) + assert inspector.policy.__name__ == "process_response" -def test_inspector_name_validation(): - """Test inspector name validation""" - with pytest.raises(ValueError, match="name cannot be empty"): - Inspector(name="", model_id="test_model_id") + def test_inspector_auto_creation(self): + """Test inspector creation with auto configuration""" + inspector = Inspector(name="auto_inspector", auto=InspectorAuto.CORRECTNESS, policy=InspectorPolicy.WARN) + assert inspector.name == "auto_inspector" + assert inspector.auto == InspectorAuto.CORRECTNESS + assert inspector.policy == InspectorPolicy.WARN + assert inspector.model_id == 
AUTO_DEFAULT_MODEL_ID + assert inspector.model_params is None -def test_inspector_factory_create_from_model(): - """Test creating inspector from model using factory""" - mock_response = MagicMock() - mock_response.status_code = 200 - mock_response.json.return_value = { - **MOCK_MODEL_RESPONSE, - "status": AssetStatus.ONBOARDED.value, - "function": {"id": Function.GUARDRAILS.value}, - } + def test_inspector_auto_creation_with_callable_policy(self): + """Test inspector creation with auto configuration and callable policy""" - with patch("aixplain.factories.team_agent_factory.inspector_factory._request_with_retry", return_value=mock_response): - inspector = InspectorFactory.create_from_model( - name=INSPECTOR_CONFIG["name"], - model=INSPECTOR_CONFIG["model_id"], - model_config=INSPECTOR_CONFIG["model_config"], - policy=INSPECTOR_CONFIG["policy"], + def process_response(model_response: ModelResponse, input_content: str) -> InspectorOutput: + return InspectorOutput(critiques="Test critique", content_edited="", action=InspectorAction.RERUN) + + inspector = Inspector(name="auto_inspector", auto=InspectorAuto.CORRECTNESS, policy=process_response) + + assert inspector.name == "auto_inspector" + assert inspector.auto == InspectorAuto.CORRECTNESS + assert inspector.policy == process_response + assert callable(inspector.policy) + assert inspector.model_id == AUTO_DEFAULT_MODEL_ID + assert inspector.model_params is None + + def test_inspector_name_validation(self): + """Test inspector name validation""" + with pytest.raises(ValueError, match="name cannot be empty"): + Inspector(name="", model_id="test_model_id") + + +class TestInspectorValidation: + """Test inspector validation and error handling""" + + def test_invalid_callable_name(self): + """Test inspector creation with callable that has wrong function name""" + + def wrong_name(model_response: ModelResponse, input_content: str) -> InspectorOutput: + return InspectorOutput(critiques="Test", content_edited="", 
action=InspectorAction.CONTINUE) + + with pytest.raises(ValueError, match="Policy callable must have name 'process_response'"): + Inspector( + name=INSPECTOR_CONFIG["name"], + model_id=INSPECTOR_CONFIG["model_id"], + model_params=INSPECTOR_CONFIG["model_config"], + policy=wrong_name, + ) + + def test_invalid_callable_arguments(self): + """Test inspector creation with callable that has wrong arguments""" + + def process_response(wrong_arg: ModelResponse, another_wrong_arg: str) -> InspectorOutput: + return InspectorOutput(critiques="Test", content_edited="", action=InspectorAction.CONTINUE) + + with pytest.raises(ValueError, match="Policy callable must have name 'process_response'"): + Inspector( + name=INSPECTOR_CONFIG["name"], + model_id=INSPECTOR_CONFIG["model_id"], + model_params=INSPECTOR_CONFIG["model_config"], + policy=process_response, + ) + + def test_invalid_callable_return_type(self): + """Test inspector creation with callable that has wrong return type""" + + def process_response(model_response: ModelResponse, input_content: str) -> str: + return "continue" + + with pytest.raises(ValueError, match="Policy callable must have name 'process_response'"): + Inspector( + name=INSPECTOR_CONFIG["name"], + model_id=INSPECTOR_CONFIG["model_id"], + model_params=INSPECTOR_CONFIG["model_config"], + policy=process_response, + ) + + def test_invalid_policy_type(self): + """Test inspector creation with invalid policy type""" + with pytest.raises(ValueError, match="Input should be"): + Inspector( + name=INSPECTOR_CONFIG["name"], + model_id=INSPECTOR_CONFIG["model_id"], + model_params=INSPECTOR_CONFIG["model_config"], + policy=123, # Invalid type + ) + + +class TestCodeStringConversion: + """Test conversion between callable functions and code strings""" + + def test_callable_to_code_string(self): + """Test converting callable to code string""" + + def process_response(model_response: ModelResponse, input_content: str) -> InspectorOutput: + if "error" in 
model_response.error_message.lower(): + return InspectorOutput(critiques="Error detected", content_edited="", action=InspectorAction.ABORT) + return InspectorOutput(critiques="No issues", content_edited="", action=InspectorAction.CONTINUE) + + code_string = callable_to_code_string(process_response) + assert isinstance(code_string, str) + assert "def process_response" in code_string + assert "model_response" in code_string + assert "input_content" in code_string + assert "InspectorAction.ABORT" in code_string + + def test_code_string_to_callable(self): + """Test converting code string back to callable""" + code_string = """def process_response(model_response: ModelResponse, input_content: str) -> InspectorOutput: + if "error" in model_response.error_message.lower(): + return InspectorOutput(critiques="Error detected", content_edited="", action=InspectorAction.ABORT) + return InspectorOutput(critiques="No issues", content_edited="", action=InspectorAction.CONTINUE)""" + + func = code_string_to_callable(code_string) + assert callable(func) + assert func.__name__ == "process_response" + + # Test the function works correctly + result1 = func(ModelResponse(status=ResponseStatus.FAILED, error_message="This is an error message"), "input") + assert result1.action == InspectorAction.ABORT + + result2 = func(ModelResponse(status=ResponseStatus.SUCCESS, data="This is a normal message"), "input") + assert result2.action == InspectorAction.CONTINUE + + def test_roundtrip_conversion(self): + """Test that serialization and deserialization work correctly together""" + + def process_response(model_response: ModelResponse, input_content: str) -> InspectorOutput: + if "error" in model_response.error_message.lower(): + return InspectorOutput(critiques="Error detected", content_edited="", action=InspectorAction.ABORT) + elif "warning" in model_response.data.lower(): + return InspectorOutput(critiques="Warning detected", content_edited="", action=InspectorAction.RERUN) + return 
InspectorOutput(critiques="No issues", content_edited="", action=InspectorAction.CONTINUE) + + # Serialize + code_string = callable_to_code_string(process_response) + + # Deserialize + deserialized_func = code_string_to_callable(code_string) + + # Test that the deserialized function works the same + assert ( + deserialized_func(ModelResponse(status=ResponseStatus.FAILED, error_message="error message"), "input").action + == InspectorAction.ABORT + ) + assert ( + deserialized_func(ModelResponse(status=ResponseStatus.SUCCESS, data="warning message"), "input").action + == InspectorAction.RERUN + ) + assert ( + deserialized_func(ModelResponse(status=ResponseStatus.SUCCESS, data="normal message"), "input").action + == InspectorAction.CONTINUE ) - assert inspector.name == INSPECTOR_CONFIG["name"] - assert inspector.model_id == INSPECTOR_CONFIG["model_id"] - assert inspector.model_params == INSPECTOR_CONFIG["model_config"] - assert inspector.policy == INSPECTOR_CONFIG["policy"] + def test_source_code_preservation(self): + """Test that code_string_to_callable preserves the original source code""" + code_string = """def process_response(model_response: ModelResponse, input_content: str) -> InspectorOutput: + if "error" in model_response.error_message.lower(): + return InspectorOutput(critiques="Error detected", content_edited="", action=InspectorAction.ABORT) + elif "warning" in model_response.data.lower(): + return InspectorOutput(critiques="Warning detected", content_edited="", action=InspectorAction.RERUN) + return InspectorOutput(critiques="No issues", content_edited="", action=InspectorAction.CONTINUE)""" + func = code_string_to_callable(code_string) -def test_inspector_factory_create_from_model_invalid_status(): - """Test creating inspector from model with invalid status""" - mock_response = MagicMock() - mock_response.status_code = 200 - mock_response.json.return_value = { - **MOCK_MODEL_RESPONSE, - "status": AssetStatus.DRAFT.value, - "function": {"id": 
Function.GUARDRAILS.value}, - } + # Verify the function has the _source_code attribute + assert hasattr(func, "_source_code") + assert func._source_code == code_string - with patch("aixplain.factories.team_agent_factory.inspector_factory._request_with_retry", return_value=mock_response): - with pytest.raises(ValueError, match="is not onboarded"): - InspectorFactory.create_from_model( + # Verify the function works correctly + assert ( + func(ModelResponse(status=ResponseStatus.FAILED, error_message="This is an error message"), "input").action + == InspectorAction.ABORT + ) + assert ( + func(ModelResponse(status=ResponseStatus.SUCCESS, data="This is a warning message"), "input").action + == InspectorAction.RERUN + ) + assert ( + func(ModelResponse(status=ResponseStatus.SUCCESS, data="This is a normal message"), "input").action + == InspectorAction.CONTINUE + ) + + +class TestSourceCodeRetrieval: + """Test source code retrieval functionality""" + + def test_get_policy_source_original_function(self): + """Test get_policy_source with an original function (should use inspect.getsource)""" + + def process_response(model_response: ModelResponse, input_content: str) -> InspectorOutput: + if "error" in model_response.error_message.lower(): + return InspectorOutput(critiques="Error detected", content_edited="", action=InspectorAction.ABORT) + return InspectorOutput(critiques="No issues", content_edited="", action=InspectorAction.CONTINUE) + + source = get_policy_source(process_response) + assert source is not None + assert "def process_response" in source + assert "InspectorAction.ABORT" in source + + def test_get_policy_source_deserialized_function(self): + """Test get_policy_source with a deserialized function (should use _source_code attribute)""" + code_string = """def process_response(model_response: ModelResponse, input_content: str) -> InspectorOutput: + if "error" in model_response.error_message.lower(): + return InspectorOutput(critiques="Error detected", 
content_edited="", action=InspectorAction.ABORT) + return InspectorOutput(critiques="No issues", content_edited="", action=InspectorAction.CONTINUE)""" + + func = code_string_to_callable(code_string) + + # Verify get_policy_source works with the deserialized function + source = get_policy_source(func) + assert source is not None + assert source == code_string + + def test_get_policy_source_fallback(self): + """Test get_policy_source fallback when neither approach works""" + # Create a function without source code info by using exec() + # This simulates a function created dynamically where inspect.getsource() would fail + namespace = {} + exec( + "def dynamic_func(x, y): return InspectorOutput(critiques='', content_edited='', action=InspectorAction.CONTINUE)", + namespace, + ) + func = namespace["dynamic_func"] + + # Remove any potential source code attributes + if hasattr(func, "_source_code"): + delattr(func, "_source_code") + + source = get_policy_source(func) + assert source is None + + +class TestInspectorSerialization: + """Test Inspector serialization and deserialization""" + + def test_model_dump_with_callable_policy(self): + """Test that Inspector.model_dump properly serializes callable policies""" + + def process_response(model_response: ModelResponse, input_content: str) -> InspectorOutput: + return InspectorOutput(critiques="Test critique", content_edited="", action=InspectorAction.ABORT) + + inspector = Inspector( + name="test_inspector", + model_id="test_model_id", + policy=process_response, + ) + + data = inspector.model_dump() + assert data["policy_type"] == "callable" + assert isinstance(data["policy"], str) + assert "def process_response" in data["policy"] + + def test_model_dump_with_enum_policy(self): + """Test that Inspector.model_dump properly serializes enum policies""" + inspector = Inspector( + name="test_inspector", + model_id="test_model_id", + policy=InspectorPolicy.WARN, + ) + + data = inspector.model_dump() + assert data["policy_type"] 
== "enum" + assert data["policy"] == "warn" + + def test_model_validate_with_callable_policy(self): + """Test that Inspector.model_validate properly deserializes callable policies""" + inspector_data = { + "name": "test_inspector", + "model_id": "test_model_id", + "policy": """def process_response(model_response: ModelResponse, input_content: str) -> InspectorOutput: + return InspectorOutput(critiques="Test critique", content_edited="", action=InspectorAction.ABORT)""", + "policy_type": "callable", + } + + inspector = Inspector.model_validate(inspector_data) + assert callable(inspector.policy) + assert inspector.policy.__name__ == "process_response" + result = inspector.policy(ModelResponse(status=ResponseStatus.SUCCESS, data="test"), "input") + assert result.action == InspectorAction.ABORT + + def test_model_validate_with_enum_policy(self): + """Test that Inspector.model_validate properly deserializes enum policies""" + inspector_data = { + "name": "test_inspector", + "model_id": "test_model_id", + "policy": "warn", + "policy_type": "enum", + } + + inspector = Inspector.model_validate(inspector_data) + assert inspector.policy == InspectorPolicy.WARN + + def test_model_validate_fallback(self): + """Test that Inspector.model_validate falls back to default policy on error""" + inspector_data = { + "name": "test_inspector", + "model_id": "test_model_id", + "policy": "invalid code string", + "policy_type": "callable", + } + + inspector = Inspector.model_validate(inspector_data) + assert inspector.policy == InspectorPolicy.ADAPTIVE # Default fallback + + def test_roundtrip_serialization_preserves_source_code(self): + """Test that Inspector round-trip serialization preserves source code""" + + def process_response(model_response: ModelResponse, input_content: str) -> InspectorOutput: + if "error" in model_response.error_message.lower(): + return InspectorOutput(critiques="Error detected", content_edited="", action=InspectorAction.ABORT) + elif "warning" in 
model_response.data.lower(): + return InspectorOutput(critiques="Warning detected", content_edited="", action=InspectorAction.RERUN) + return InspectorOutput(critiques="No issues", content_edited="", action=InspectorAction.CONTINUE) + + # Create inspector with callable policy + inspector = Inspector( + name="test_inspector", + model_id="test_model_id", + policy=process_response, + ) + + # Serialize to dict + inspector_dict = inspector.model_dump() + assert inspector_dict["policy_type"] == "callable" + assert isinstance(inspector_dict["policy"], str) + + # Deserialize from dict + inspector_copy = Inspector.model_validate(inspector_dict) + assert callable(inspector_copy.policy) + assert inspector_copy.policy.__name__ == "process_response" + + # Verify the deserialized function has source code and works correctly + assert hasattr(inspector_copy.policy, "_source_code") + assert "def process_response" in inspector_copy.policy._source_code + assert ( + inspector_copy.policy( + ModelResponse(status=ResponseStatus.FAILED, error_message="This is an error message"), "input" + ).action + == InspectorAction.ABORT + ) + assert ( + inspector_copy.policy( + ModelResponse(status=ResponseStatus.SUCCESS, data="This is a warning message"), "input" + ).action + == InspectorAction.RERUN + ) + assert ( + inspector_copy.policy(ModelResponse(status=ResponseStatus.SUCCESS, data="This is a normal message"), "input").action + == InspectorAction.CONTINUE + ) + + +class TestInspectorFactory: + """Test InspectorFactory functionality""" + + def test_create_from_model(self): + """Test creating inspector from model using factory""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + **MOCK_MODEL_RESPONSE, + "status": AssetStatus.ONBOARDED.value, + "function": {"id": Function.GUARDRAILS.value}, + } + + with patch("aixplain.factories.team_agent_factory.inspector_factory._request_with_retry", return_value=mock_response): + inspector = 
InspectorFactory.create_from_model( name=INSPECTOR_CONFIG["name"], model=INSPECTOR_CONFIG["model_id"], model_config=INSPECTOR_CONFIG["model_config"], policy=INSPECTOR_CONFIG["policy"], ) + assert inspector.name == INSPECTOR_CONFIG["name"] + assert inspector.model_id == INSPECTOR_CONFIG["model_id"] + assert inspector.model_params == INSPECTOR_CONFIG["model_config"] + assert inspector.policy == INSPECTOR_CONFIG["policy"] + + def test_create_from_model_with_callable_policy(self): + """Test creating inspector from model using factory with callable policy""" -def test_inspector_factory_create_from_model_invalid_function(): - """Test creating inspector from model with invalid function""" - mock_response = MagicMock() - mock_response.status_code = 200 - mock_response.json.return_value = { - **MOCK_MODEL_RESPONSE, - "status": AssetStatus.ONBOARDED.value, - "function": {"id": Function.TRANSLATION.value}, - } + def process_response(model_response: ModelResponse, input_content: str) -> InspectorOutput: + return InspectorOutput(critiques="Test critique", content_edited="", action=InspectorAction.CONTINUE) - with patch("aixplain.factories.team_agent_factory.inspector_factory._request_with_retry", return_value=mock_response): - with pytest.raises(ValueError, match="models are supported"): - InspectorFactory.create_from_model( + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + **MOCK_MODEL_RESPONSE, + "status": AssetStatus.ONBOARDED.value, + "function": {"id": Function.GUARDRAILS.value}, + } + + with patch("aixplain.factories.team_agent_factory.inspector_factory._request_with_retry", return_value=mock_response): + inspector = InspectorFactory.create_from_model( name=INSPECTOR_CONFIG["name"], model=INSPECTOR_CONFIG["model_id"], model_config=INSPECTOR_CONFIG["model_config"], - policy=INSPECTOR_CONFIG["policy"], + policy=process_response, ) + assert inspector.name == INSPECTOR_CONFIG["name"] + assert inspector.model_id == 
INSPECTOR_CONFIG["model_id"] + assert inspector.model_params == INSPECTOR_CONFIG["model_config"] + assert inspector.policy == process_response + assert callable(inspector.policy) + + def test_create_from_model_invalid_status(self): + """Test creating inspector from model with invalid status""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + **MOCK_MODEL_RESPONSE, + "status": AssetStatus.DRAFT.value, + "function": {"id": Function.GUARDRAILS.value}, + } + + with patch("aixplain.factories.team_agent_factory.inspector_factory._request_with_retry", return_value=mock_response): + with pytest.raises(ValueError, match="is not onboarded"): + InspectorFactory.create_from_model( + name=INSPECTOR_CONFIG["name"], + model=INSPECTOR_CONFIG["model_id"], + model_config=INSPECTOR_CONFIG["model_config"], + policy=INSPECTOR_CONFIG["policy"], + ) + + def test_create_from_model_invalid_function(self): + """Test creating inspector from model with invalid function""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + **MOCK_MODEL_RESPONSE, + "status": AssetStatus.ONBOARDED.value, + "function": {"id": Function.TRANSLATION.value}, + } + + with patch("aixplain.factories.team_agent_factory.inspector_factory._request_with_retry", return_value=mock_response): + with pytest.raises(ValueError, match="models are supported"): + InspectorFactory.create_from_model( + name=INSPECTOR_CONFIG["name"], + model=INSPECTOR_CONFIG["model_id"], + model_config=INSPECTOR_CONFIG["model_config"], + policy=INSPECTOR_CONFIG["policy"], + ) + + def test_create_auto(self): + """Test creating auto-configured inspector using factory""" + inspector = InspectorFactory.create_auto( + auto=InspectorAuto.CORRECTNESS, name="custom_name", policy=InspectorPolicy.ABORT + ) + + assert inspector.name == "custom_name" + assert inspector.auto == InspectorAuto.CORRECTNESS + assert inspector.policy == InspectorPolicy.ABORT + 
assert inspector.model_id == AUTO_DEFAULT_MODEL_ID + assert inspector.model_params is None + + def test_create_auto_with_callable_policy(self): + """Test creating auto-configured inspector using factory with callable policy""" -def test_inspector_factory_create_auto(): - """Test creating auto-configured inspector using factory""" - inspector = InspectorFactory.create_auto(auto=InspectorAuto.CORRECTNESS, name="custom_name", policy=InspectorPolicy.ABORT) + def process_response(model_response: ModelResponse, input_content: str) -> InspectorOutput: + return InspectorOutput(critiques="Test critique", content_edited="", action=InspectorAction.ABORT) - assert inspector.name == "custom_name" - assert inspector.auto == InspectorAuto.CORRECTNESS - assert inspector.policy == InspectorPolicy.ABORT - assert inspector.model_id == AUTO_DEFAULT_MODEL_ID - assert inspector.model_params is None + inspector = InspectorFactory.create_auto(auto=InspectorAuto.CORRECTNESS, name="custom_name", policy=process_response) + assert inspector.name == "custom_name" + assert inspector.auto == InspectorAuto.CORRECTNESS + assert inspector.policy == process_response + assert callable(inspector.policy) + assert inspector.model_id == AUTO_DEFAULT_MODEL_ID + assert inspector.model_params is None -def test_inspector_factory_create_auto_default_name(): - """Test creating auto-configured inspector with default name""" - inspector = InspectorFactory.create_auto(auto=InspectorAuto.CORRECTNESS) + def test_create_auto_default_name(self): + """Test creating auto-configured inspector with default name""" + inspector = InspectorFactory.create_auto(auto=InspectorAuto.CORRECTNESS) - assert inspector.name == "inspector_correctness" - assert inspector.auto == InspectorAuto.CORRECTNESS - assert inspector.policy == InspectorPolicy.ADAPTIVE # default policy + assert inspector.name == "inspector_correctness" + assert inspector.auto == InspectorAuto.CORRECTNESS + assert inspector.policy == InspectorPolicy.ADAPTIVE # 
default policy