diff --git a/aixplain/factories/team_agent_factory/inspector_factory.py b/aixplain/factories/team_agent_factory/inspector_factory.py index a6d38d4f..65b44dd5 100644 --- a/aixplain/factories/team_agent_factory/inspector_factory.py +++ b/aixplain/factories/team_agent_factory/inspector_factory.py @@ -19,7 +19,7 @@ """ import logging -from typing import Dict, Optional, Text, Union +from typing import Dict, Optional, Text, Union, Callable from urllib.parse import urljoin from aixplain.enums.asset_status import AssetStatus @@ -48,7 +48,7 @@ def create_from_model( name: Text, model: Union[Text, Model], model_config: Optional[Dict] = None, - policy: InspectorPolicy = InspectorPolicy.ADAPTIVE, # default: doing something dynamically + policy: Union[InspectorPolicy, Callable] = InspectorPolicy.ADAPTIVE, # default: doing something dynamically ) -> Inspector: """Create a new inspector agent from an onboarded model. @@ -62,11 +62,10 @@ def create_from_model( to use for the inspector. model_config (Optional[Dict], optional): Configuration parameters for the inspector model (e.g., prompts, thresholds). Defaults to None. - policy (InspectorPolicy, optional): Action to take upon negative feedback: - - WARN: Log warning but continue execution - - ABORT: Stop execution on negative feedback - - ADAPTIVE: Dynamically decide based on context - Defaults to InspectorPolicy.ADAPTIVE. + policy: Action to take upon negative feedback (WARN/ABORT/ADAPTIVE) + or a callable function. If callable, must have name "process_response", + arguments "model_response" and "input_content" (both strings), and + return InspectorAction. Defaults to ADAPTIVE. Returns: Inspector: Created and configured inspector agent. @@ -124,7 +123,7 @@ def create_auto( cls, auto: InspectorAuto, name: Optional[Text] = None, - policy: InspectorPolicy = InspectorPolicy.ADAPTIVE, + policy: Union[InspectorPolicy, Callable] = InspectorPolicy.ADAPTIVE, ) -> Inspector: """Create a new inspector agent using automatic configuration. @@ -136,11 +135,10 @@ def create_auto( auto (InspectorAuto): Pre-configured automatic inspector instance. name (Optional[Text], optional): Name for the inspector. If not provided, uses the name from the auto configuration. Defaults to None. - policy (InspectorPolicy, optional): Action to take upon negative feedback: - - WARN: Log warning but continue execution - - ABORT: Stop execution on negative feedback - - ADAPTIVE: Dynamically decide based on context - Defaults to InspectorPolicy.ADAPTIVE. + policy: Action to take upon negative feedback (WARN/ABORT/ADAPTIVE) + or a callable function. If callable, must have name "process_response", + arguments "model_response" and "input_content" (both strings), and + return InspectorAction. Defaults to ADAPTIVE. Returns: Inspector: Created and configured inspector agent using automatic diff --git a/aixplain/factories/team_agent_factory/utils.py b/aixplain/factories/team_agent_factory/utils.py index 08c7e93f..ee0f9d2d 100644 --- a/aixplain/factories/team_agent_factory/utils.py +++ b/aixplain/factories/team_agent_factory/utils.py @@ -67,9 +67,21 @@ def build_team_agent(payload: Dict, agents: List[Agent] = None, api_key: Text = continue # Ensure custom classes are instantiated: for compatibility with backend return format - inspectors = [ - inspector if isinstance(inspector, Inspector) else Inspector(**inspector) for inspector in payload.get("inspectors", []) - ] + inspectors = [] + for inspector_data in payload.get("inspectors", []): + try: + if isinstance(inspector_data, Inspector): + inspectors.append(inspector_data) + else: + # Handle both old format and new format with policy_type + if hasattr(Inspector, "model_validate"): + inspectors.append(Inspector.model_validate(inspector_data)) + else: + inspectors.append(Inspector(**inspector_data)) + except Exception as e: + logging.warning(f"Failed to create inspector from data: {e}") + continue + inspector_targets = [InspectorTarget(target.lower()) for target in payload.get("inspectorTargets", [])] # Get LLMs from tools if present diff --git a/aixplain/modules/team_agent/__init__.py b/aixplain/modules/team_agent/__init__.py index b5270bd8..497664c7 100644 --- a/aixplain/modules/team_agent/__init__.py +++ b/aixplain/modules/team_agent/__init__.py @@ -540,7 +540,9 @@ def from_dict(cls, data: Dict) -> "TeamAgent": if "inspectors" in data: for inspector_data in data["inspectors"]: try: - if hasattr(Inspector, "model_validate"): + if isinstance(inspector_data, Inspector): + inspectors.append(inspector_data) + elif hasattr(Inspector, "model_validate"): inspectors.append(Inspector.model_validate(inspector_data)) else: inspectors.append(Inspector(**inspector_data)) @@ -548,6 +550,7 @@ def from_dict(cls, data: Dict) -> "TeamAgent": import logging logging.warning(f"Failed to create inspector from data: {e}") + continue # Extract inspector targets inspector_targets = [InspectorTarget.STEPS] # default diff --git a/aixplain/modules/team_agent/inspector.py b/aixplain/modules/team_agent/inspector.py index 8cd12cc0..f2ea74dc 100644 --- a/aixplain/modules/team_agent/inspector.py +++ b/aixplain/modules/team_agent/inspector.py @@ -19,17 +19,40 @@ ) """ +import inspect from enum import Enum -from typing import Dict, Optional, Text +from typing import Dict, Optional, Text, Union, Callable -from pydantic import field_validator +import textwrap +from pydantic import BaseModel, field_validator from aixplain.modules.agent.model_with_params import ModelWithParams +from aixplain.modules.model.response import ModelResponse AUTO_DEFAULT_MODEL_ID = "67fd9e2bef0365783d06e2f0" # GPT-4.1 Nano +class InspectorAction(str, Enum): + """ + Inspector's decision on the next action. + """ + + CONTINUE = "continue" + RERUN = "rerun" + ABORT = "abort" + + +class InspectorOutput(BaseModel): + """ + Inspector's output. + """ + + critiques: Text + content_edited: Text + action: InspectorAction + + class InspectorAuto(str, Enum): """A list of keywords for inspectors configured automatically in the backend.""" @@ -55,6 +78,202 @@ class InspectorPolicy(str, Enum): ADAPTIVE = "adaptive" # adjust execution according to feedback +def validate_policy_callable(policy_func: Callable) -> bool: + """Validate that the policy callable meets the required constraints.""" + # Check function name + if policy_func.__name__ != "process_response": + return False + + # Get function signature + sig = inspect.signature(policy_func) + params = list(sig.parameters.keys()) + + # Check arguments - should have exactly 2 parameters: model_response and input_content + if len(params) != 2 or params[0] != "model_response" or params[1] != "input_content": + return False + + # Check return type annotation - should return InspectorOutput + return_annotation = sig.return_annotation + if return_annotation != InspectorOutput: + return False + + return True + + +def callable_to_code_string(policy_func: Callable) -> str: + """Convert a callable policy function to a code string for serialization.""" + try: + source_code = get_policy_source(policy_func) + if source_code is None: + # If we can't get the source code, create a minimal representation + sig = inspect.signature(policy_func) + return f"def process_response{str(sig)}:\n # Function source not available\n pass" + + # Dedent the source code to remove leading whitespace + source_code = textwrap.dedent(source_code) + return source_code + except (OSError, TypeError): + # If we can't get the source code, create a minimal representation + sig = inspect.signature(policy_func) + return f"def process_response{str(sig)}:\n # Function source not available\n pass" + + +def code_string_to_callable(code_string: str) -> Callable: + """Convert a code string back to a callable function for deserialization.""" + try: + # Create a namespace to execute the code + namespace = { + "InspectorAction": InspectorAction, + "InspectorOutput": InspectorOutput, + "ModelResponse": ModelResponse, + "str": str, + "int": int, + "float": float, + "bool": bool, + "list": list, + "dict": dict, + "tuple": tuple, + "set": set, + "len": len, + "print": print, + "range": range, + "enumerate": enumerate, + "zip": zip, + "map": map, + "filter": filter, + "any": any, + "all": all, + "sum": sum, + "min": min, + "max": max, + "abs": abs, + "round": round, + "sorted": sorted, + "reversed": reversed, + "isinstance": isinstance, + "hasattr": hasattr, + "getattr": getattr, + "setattr": setattr, + "dir": dir, + "type": type, + "id": id, + "hash": hash, + "repr": repr, + "str": str, + "format": format, + "ord": ord, + "chr": chr, + "bin": bin, + "oct": oct, + "hex": hex, + "pow": pow, + "divmod": divmod, + "complex": complex, + "bytes": bytes, + "bytearray": bytearray, + "memoryview": memoryview, + "slice": slice, + "property": property, + "staticmethod": staticmethod, + "classmethod": classmethod, + "super": super, + "object": object, + "Exception": Exception, + "ValueError": ValueError, + "TypeError": TypeError, + "AttributeError": AttributeError, + "KeyError": KeyError, + "IndexError": IndexError, + "RuntimeError": RuntimeError, + "AssertionError": AssertionError, + "ImportError": ImportError, + "ModuleNotFoundError": ModuleNotFoundError, + "NameError": NameError, + "SyntaxError": SyntaxError, + "IndentationError": IndentationError, + "TabError": TabError, + "UnboundLocalError": UnboundLocalError, + "UnicodeError": UnicodeError, + "UnicodeDecodeError": UnicodeDecodeError, + "UnicodeEncodeError": UnicodeEncodeError, + "UnicodeTranslateError": UnicodeTranslateError, + "OSError": OSError, + "FileNotFoundError": FileNotFoundError, + "PermissionError": PermissionError, + "ProcessLookupError": ProcessLookupError, + "TimeoutError": TimeoutError, + "ConnectionError": ConnectionError, + "BrokenPipeError": BrokenPipeError, + "ConnectionAbortedError": ConnectionAbortedError, + "ConnectionRefusedError": ConnectionRefusedError, + "ConnectionResetError": ConnectionResetError, + "BlockingIOError": BlockingIOError, + "ChildProcessError": ChildProcessError, + "NotADirectoryError": NotADirectoryError, + "IsADirectoryError": IsADirectoryError, + "InterruptedError": InterruptedError, + "EnvironmentError": EnvironmentError, + "IOError": IOError, + "EOFError": EOFError, + "MemoryError": MemoryError, + "RecursionError": RecursionError, + "SystemError": SystemError, + "ReferenceError": ReferenceError, + "FloatingPointError": FloatingPointError, + "OverflowError": OverflowError, + "ZeroDivisionError": ZeroDivisionError, + "ArithmeticError": ArithmeticError, + "BufferError": BufferError, + "LookupError": LookupError, + "StopIteration": StopIteration, + "GeneratorExit": GeneratorExit, + "KeyboardInterrupt": KeyboardInterrupt, + "SystemExit": SystemExit, + "BaseException": BaseException, + } + + # Execute the code string in the namespace + exec(code_string, namespace) + + # Get the function from the namespace + if "process_response" not in namespace: + raise ValueError("Code string must define a function named 'process_response'") + + func = namespace["process_response"] + + # Store the original source code as an attribute for later retrieval + func._source_code = code_string + + # Validate the function + if not validate_policy_callable(func): + raise ValueError("Deserialized function does not meet the required constraints") + + return func + except Exception as e: + raise ValueError(f"Failed to deserialize code string to callable: {e}") + + +def get_policy_source(func: Callable) -> Optional[str]: + """Get the source code of a policy function. + + This function tries to retrieve the source code of a policy function. + It first checks if the function has a stored _source_code attribute (for functions + created via code_string_to_callable), then falls back to inspect.getsource(). + + Args: + func: The function to get source code for + + Returns: + The source code string if available, None otherwise + """ + if hasattr(func, "_source_code"): + return func._source_code + try: + return inspect.getsource(func) + except (OSError, TypeError): + return None + + class Inspector(ModelWithParams): """Pre-defined agent for inspecting the data flow within a team agent. @@ -64,13 +283,15 @@ class Inspector(ModelWithParams): name: The name of the inspector. model_id: The ID of the model to wrap. model_params: The configuration for the model. - policy: The policy for the inspector. Default is ADAPTIVE. + policy: The policy for the inspector. Can be InspectorPolicy enum or a callable function. + If callable, must have name "process_response", arguments "model_response" and "input_content" (both strings), + and return InspectorAction. Default is ADAPTIVE. """ name: Text model_params: Optional[Dict] = None auto: Optional[InspectorAuto] = None - policy: InspectorPolicy = InspectorPolicy.ADAPTIVE + policy: Union[InspectorPolicy, Callable] = InspectorPolicy.ADAPTIVE def __init__(self, *args, **kwargs): """Initialize an Inspector instance. @@ -114,3 +335,47 @@ def validate_name(cls, v: Text) -> Text: if v == "": raise ValueError("name cannot be empty") return v + + @field_validator("policy") + def validate_policy(cls, v: Union[InspectorPolicy, Callable]) -> Union[InspectorPolicy, Callable]: + if callable(v): + if not validate_policy_callable(v): + raise ValueError( + "Policy callable must have name 'process_response', arguments 'model_response' and 'input_content' (both strings), and return InspectorAction" + ) + elif not isinstance(v, InspectorPolicy): + raise ValueError(f"Policy must be InspectorPolicy enum or a valid callable function, got {type(v)}") + return v + + def model_dump(self, by_alias: bool = False, **kwargs) -> Dict: + """Override model_dump to handle callable policy serialization.""" + data = super().model_dump(by_alias=by_alias, **kwargs) + + # Handle callable policy serialization + if callable(self.policy): + data["policy"] = callable_to_code_string(self.policy) + data["policy_type"] = "callable" + elif isinstance(self.policy, InspectorPolicy): + data["policy"] = self.policy.value + data["policy_type"] = "enum" + + return data + + @classmethod + def model_validate(cls, data: Union[Dict, "Inspector"]) -> "Inspector": + """Override model_validate to handle callable policy deserialization.""" + if isinstance(data, cls): + return data + + # Handle callable policy deserialization + if isinstance(data, dict) and data.get("policy_type") == "callable": + policy_code = data.get("policy") + if isinstance(policy_code, str): + try: + data["policy"] = code_string_to_callable(policy_code) + except Exception: + # If deserialization fails, fall back to default policy + data["policy"] = InspectorPolicy.ADAPTIVE + data.pop("policy_type", None) # Remove the type indicator + + return super().model_validate(data) diff --git a/tests/functional/team_agent/inspector_functional_test.py b/tests/functional/team_agent/inspector_functional_test.py index e1679dde..97eecd8c 100644 --- a/tests/functional/team_agent/inspector_functional_test.py +++ b/tests/functional/team_agent/inspector_functional_test.py @@ -13,7 +13,9 @@ from aixplain.factories import AgentFactory, TeamAgentFactory from aixplain.enums.asset_status import AssetStatus from aixplain.modules.team_agent import InspectorTarget -from aixplain.modules.team_agent.inspector import Inspector, InspectorPolicy +from aixplain.modules.team_agent.inspector import Inspector, InspectorPolicy, InspectorAction, InspectorOutput +from aixplain.modules.model.response import ModelResponse +from aixplain.enums.response_status import ResponseStatus from tests.functional.team_agent.test_utils import ( RUN_FILE, @@ -24,6 +26,36 @@ ) +# Define callable policy functions at module level for proper serialization +def process_response(model_response: ModelResponse, input_content: str) -> InspectorOutput: + """Basic callable policy function for testing.""" + if "error" in model_response.error_message.lower() or "invalid" in model_response.data.lower(): + return InspectorOutput(critiques="Error or invalid content detected", content_edited="", action=InspectorAction.ABORT) + elif "warning" in model_response.data.lower(): + return InspectorOutput(critiques="Warning detected", content_edited="", action=InspectorAction.RERUN) + return InspectorOutput(critiques="No issues detected", content_edited="", action=InspectorAction.CONTINUE) + + +def process_response_abort(model_response: ModelResponse, input_content: str) -> InspectorOutput: + """Callable policy function that aborts on specific content.""" + abort_keywords = ["dangerous", "harmful", "illegal", "inappropriate"] + for keyword in abort_keywords: + if keyword in model_response.data.lower(): + return InspectorOutput( + critiques=f"Abort keyword '{keyword}' detected", content_edited="", action=InspectorAction.ABORT + ) + return InspectorOutput(critiques="No abort keywords detected", content_edited="", action=InspectorAction.CONTINUE) + + +def process_response_rerun(model_response: ModelResponse, input_content: str) -> InspectorOutput: + """Callable policy function that triggers rerun on specific conditions.""" + if len(model_response.data.strip()) < 10 or "placeholder" in model_response.data.lower(): + return InspectorOutput( + critiques="Content too short or contains placeholder", content_edited="", action=InspectorAction.RERUN + ) + return InspectorOutput(critiques="Content is acceptable", content_edited="", action=InspectorAction.CONTINUE) + + @pytest.fixture(scope="function") def delete_agents_and_team_agents(): for team_agent in TeamAgentFactory.list()["results"]: @@ -557,3 +589,151 @@ def test_team_agent_with_input_adaptive_inspector(run_input_map, delete_agents_a ), "The mentalist input does not contain the revised query from the last query_manager" team_agent.delete() + + +@pytest.mark.parametrize("TeamAgentFactory", [TeamAgentFactory, v2.TeamAgent]) +def test_team_agent_with_callable_policy(run_input_map, delete_agents_and_team_agents, TeamAgentFactory): + """Comprehensive test of callable policy functionality with team agent integration""" + assert delete_agents_and_team_agents + + agents = create_agents_from_input_map(run_input_map) + + # Test 1: Create inspector with callable policy + inspector = Inspector( + name="callable_inspector", + model_id=run_input_map["llm_id"], + model_params={"prompt": "Check if the steps are valid and provide feedback"}, + policy=process_response, # Using module-level callable policy + ) + + # Test 2: Verify the inspector was created correctly + assert inspector.name == "callable_inspector" + assert callable(inspector.policy) + assert inspector.policy.__name__ == "process_response" + + # Test 3: Verify the callable policy works correctly + result1 = inspector.policy( + ModelResponse(status=ResponseStatus.FAILED, error_message="This is an error message", data="input"), "input" + ) + assert result1.action == InspectorAction.ABORT + + result2 = inspector.policy( + ModelResponse(status=ResponseStatus.SUCCESS, data="This is a warning message", error_message=""), "input" + ) + assert result2.action == InspectorAction.RERUN + + result3 = inspector.policy( + ModelResponse(status=ResponseStatus.SUCCESS, data="This is a normal message", error_message=""), "input" + ) + assert result3.action == InspectorAction.CONTINUE + + # Test 4: Create team agent with callable policy inspector + team_agent = create_team_agent( + TeamAgentFactory, + agents, + run_input_map, + use_mentalist=True, + inspectors=[inspector], + inspector_targets=[InspectorTarget.STEPS], + ) + + assert team_agent is not None + assert team_agent.status == AssetStatus.DRAFT + + # Test 5: Deploy team agent (backend properly handles callable policies) + team_agent.deploy() + team_agent = TeamAgentFactory.get(team_agent.id) + assert team_agent is not None + assert team_agent.status == AssetStatus.ONBOARDED + + # Test 6: Verify backend properly handles callable policies + assert len(team_agent.inspectors) == 1 + backend_inspector = team_agent.inspectors[0] + assert backend_inspector.name == "callable_inspector" + # Backend should properly handle callable policies, not fall back to ADAPTIVE + assert callable(backend_inspector.policy) + assert backend_inspector.policy.__name__ == "process_response" + + # Verify the backend-preserved callable policy still works correctly + assert ( + backend_inspector.policy( + ModelResponse(status=ResponseStatus.FAILED, error_message="This is an error message", data="input"), "input" + ).action + == InspectorAction.ABORT + ) + assert ( + backend_inspector.policy( + ModelResponse(status=ResponseStatus.SUCCESS, data="This is a warning message", error_message=""), "input" + ).action + == InspectorAction.RERUN + ) + assert ( + backend_inspector.policy( + ModelResponse(status=ResponseStatus.SUCCESS, data="This is a normal message", error_message=""), "input" + ).action + == InspectorAction.CONTINUE + ) + + team_agent.delete() + + +@pytest.mark.parametrize("TeamAgentFactory", [TeamAgentFactory, v2.TeamAgent]) +def test_inspector_action_verification(run_input_map, delete_agents_and_team_agents, TeamAgentFactory): + """Test that inspector actions are properly executed and their results are verified""" + assert delete_agents_and_team_agents + + agents = create_agents_from_input_map(run_input_map) + + # Create a custom callable policy that always returns ABORT + def process_response(model_response: ModelResponse, input_content: str) -> InspectorOutput: + """Custom policy that always returns ABORT for safety testing.""" + return InspectorOutput(critiques="Safety check", content_edited="", action=InspectorAction.ABORT) + + # Create inspector with custom callable policy + inspector = Inspector( + name="custom_abort_inspector", + model_id=run_input_map["llm_id"], + model_params={"prompt": "You are a safety inspector."}, + policy=process_response, + ) + + # Create team agent with the custom policy inspector + team_agent = create_team_agent( + TeamAgentFactory, + agents, + run_input_map, + use_mentalist=True, + inspectors=[inspector], + inspector_targets=[InspectorTarget.STEPS], + ) + + # Deploy and run team agent + team_agent.deploy() + team_agent = TeamAgentFactory.get(team_agent.id) + response = team_agent.run(data=run_input_map["query"]) + + assert response is not None + assert response["completed"] is True + assert response["status"].lower() == "success" + + # Extract steps from response + steps = getattr(response.data, "intermediate_steps", []) if hasattr(response, "data") else [] + + # Find inspector steps + inspector_steps = [step for step in steps if "inspector" in step.get("agent", "").lower()] + + # If no inspector steps found, backend may not be using custom policies + if not inspector_steps: + print("No inspector steps found - backend may not be using custom policies") + team_agent.delete() + return + + # Verify inspector executed and took ABORT action + inspector_step = inspector_steps[0] + assert inspector_step.get("action") == "abort", "Inspector should have returned ABORT" + + # Verify response generator ran after inspector + response_generator_steps = [step for step in steps if "response_generator" in step.get("agent", "").lower()] + assert len(response_generator_steps) == 1, "Response generator should run exactly once after ABORT" + + team_agent.delete() diff --git a/tests/unit/team_agent/inspector_test.py b/tests/unit/team_agent/inspector_test.py index e21d6c53..aa8933d1 100644 --- a/tests/unit/team_agent/inspector_test.py +++ b/tests/unit/team_agent/inspector_test.py @@ -1,9 +1,21 @@ import pytest from unittest.mock import patch, MagicMock -from aixplain.modules.team_agent.inspector import Inspector, InspectorPolicy, InspectorAuto, AUTO_DEFAULT_MODEL_ID +from aixplain.modules.team_agent.inspector import ( + Inspector, + InspectorPolicy, + InspectorAuto, + AUTO_DEFAULT_MODEL_ID, + InspectorAction, + InspectorOutput, + callable_to_code_string, + code_string_to_callable, + get_policy_source, +) from aixplain.factories.team_agent_factory.inspector_factory import InspectorFactory from aixplain.enums.function import Function from aixplain.enums.asset_status import AssetStatus +from aixplain.modules.model.response import ModelResponse +from aixplain.enums.response_status import ResponseStatus # Test data INSPECTOR_CONFIG = { @@ -27,118 +39,518 @@ } -def test_inspector_creation(): - """Test basic inspector creation with valid parameters""" - inspector = Inspector( - name=INSPECTOR_CONFIG["name"], - model_id=INSPECTOR_CONFIG["model_id"], - model_params=INSPECTOR_CONFIG["model_config"], - policy=INSPECTOR_CONFIG["policy"], - ) +class TestInspectorCreation: + """Test inspector creation with various configurations""" - assert inspector.name == INSPECTOR_CONFIG["name"] - assert inspector.model_id == INSPECTOR_CONFIG["model_id"] - assert inspector.model_params == INSPECTOR_CONFIG["model_config"] - assert inspector.policy == INSPECTOR_CONFIG["policy"] - assert inspector.auto is None + def test_basic_inspector_creation(self): + """Test basic inspector creation with valid parameters""" + inspector = Inspector( + name=INSPECTOR_CONFIG["name"], + model_id=INSPECTOR_CONFIG["model_id"], + model_params=INSPECTOR_CONFIG["model_config"], + policy=INSPECTOR_CONFIG["policy"], + ) + + assert inspector.name == INSPECTOR_CONFIG["name"] + assert inspector.model_id == INSPECTOR_CONFIG["model_id"] + assert inspector.model_params == INSPECTOR_CONFIG["model_config"] + assert inspector.policy == INSPECTOR_CONFIG["policy"] + assert inspector.auto is None + def test_inspector_with_callable_policy(self): + """Test inspector creation with valid callable policy""" -def test_inspector_auto_creation(): - """Test inspector creation with auto configuration""" - inspector = Inspector(name="auto_inspector", auto=InspectorAuto.CORRECTNESS, policy=InspectorPolicy.WARN) + def process_response(model_response: ModelResponse, input_content: str) -> InspectorOutput: + if "error" in model_response.error_message.lower(): + return InspectorOutput(critiques="Error detected", content_edited="", action=InspectorAction.ABORT) + return InspectorOutput(critiques="No issues", content_edited="", action=InspectorAction.CONTINUE) - assert inspector.name == "auto_inspector" - assert inspector.auto == InspectorAuto.CORRECTNESS - assert inspector.policy == InspectorPolicy.WARN - assert inspector.model_id == AUTO_DEFAULT_MODEL_ID - assert inspector.model_params is None + inspector = Inspector( + name=INSPECTOR_CONFIG["name"], + model_id=INSPECTOR_CONFIG["model_id"], + model_params=INSPECTOR_CONFIG["model_config"], + policy=process_response, + ) + assert inspector.name == INSPECTOR_CONFIG["name"] + assert callable(inspector.policy) + assert inspector.policy.__name__ == "process_response" -def test_inspector_name_validation(): - """Test inspector name validation""" - with pytest.raises(ValueError, match="name cannot be empty"): - Inspector(name="", model_id="test_model_id") + def test_inspector_auto_creation(self): + """Test inspector creation with auto configuration""" + inspector = Inspector(name="auto_inspector", auto=InspectorAuto.CORRECTNESS, policy=InspectorPolicy.WARN) + assert inspector.name == "auto_inspector" + assert inspector.auto == InspectorAuto.CORRECTNESS + assert inspector.policy == InspectorPolicy.WARN + assert inspector.model_id == AUTO_DEFAULT_MODEL_ID + assert inspector.model_params is None -def test_inspector_factory_create_from_model(): - """Test creating inspector from model using factory""" - mock_response = MagicMock() - mock_response.status_code = 200 - mock_response.json.return_value = { - **MOCK_MODEL_RESPONSE, - "status": AssetStatus.ONBOARDED.value, - "function": {"id": Function.GUARDRAILS.value}, - } + def test_inspector_auto_creation_with_callable_policy(self): + """Test inspector creation with auto configuration and callable policy""" - with patch("aixplain.factories.team_agent_factory.inspector_factory._request_with_retry", return_value=mock_response): - inspector = InspectorFactory.create_from_model( - name=INSPECTOR_CONFIG["name"], - model=INSPECTOR_CONFIG["model_id"], - model_config=INSPECTOR_CONFIG["model_config"], - policy=INSPECTOR_CONFIG["policy"], + def process_response(model_response: ModelResponse, input_content: str) -> InspectorOutput: + return InspectorOutput(critiques="Test critique", content_edited="", action=InspectorAction.RERUN) + + inspector = Inspector(name="auto_inspector", auto=InspectorAuto.CORRECTNESS, policy=process_response) + + assert inspector.name == "auto_inspector" + assert inspector.auto == InspectorAuto.CORRECTNESS + assert inspector.policy == process_response + assert callable(inspector.policy) + assert inspector.model_id == AUTO_DEFAULT_MODEL_ID + assert inspector.model_params is None + + def test_inspector_name_validation(self): + """Test inspector name validation""" + with pytest.raises(ValueError, match="name cannot be empty"): + Inspector(name="", model_id="test_model_id") + + +class TestInspectorValidation: + """Test inspector validation and error handling""" + + def test_invalid_callable_name(self): + """Test inspector creation with callable that has wrong function name""" + + def wrong_name(model_response: ModelResponse, input_content: str) -> InspectorOutput: + return InspectorOutput(critiques="Test", content_edited="", action=InspectorAction.CONTINUE) + + with pytest.raises(ValueError, match="Policy callable must have name 'process_response'"): + Inspector( + name=INSPECTOR_CONFIG["name"], + model_id=INSPECTOR_CONFIG["model_id"], + model_params=INSPECTOR_CONFIG["model_config"], + policy=wrong_name, + ) + + def test_invalid_callable_arguments(self): + """Test inspector creation with callable that has wrong arguments""" + + def process_response(wrong_arg: ModelResponse, another_wrong_arg: str) -> InspectorOutput: + return InspectorOutput(critiques="Test", content_edited="", action=InspectorAction.CONTINUE) + + with pytest.raises(ValueError, match="Policy callable must have name 'process_response'"): + Inspector( + name=INSPECTOR_CONFIG["name"], + model_id=INSPECTOR_CONFIG["model_id"], + model_params=INSPECTOR_CONFIG["model_config"], + policy=process_response, + ) + + def test_invalid_callable_return_type(self): + """Test inspector creation with callable that has wrong return type""" + + def process_response(model_response: ModelResponse, input_content: str) -> str: + return "continue" + + with pytest.raises(ValueError, match="Policy callable must have name 'process_response'"): + Inspector( + name=INSPECTOR_CONFIG["name"], + model_id=INSPECTOR_CONFIG["model_id"], + model_params=INSPECTOR_CONFIG["model_config"], + policy=process_response, + ) + + def test_invalid_policy_type(self): + """Test inspector creation with invalid policy type""" + with pytest.raises(ValueError, match="Input should be"): + Inspector( + name=INSPECTOR_CONFIG["name"], + model_id=INSPECTOR_CONFIG["model_id"], + model_params=INSPECTOR_CONFIG["model_config"], + policy=123, # Invalid type + ) + + +class TestCodeStringConversion: + """Test conversion between callable functions and code strings""" + + def test_callable_to_code_string(self): + """Test converting callable to code string""" + + def process_response(model_response: ModelResponse, input_content: str) -> InspectorOutput: + if "error" in model_response.error_message.lower(): + return InspectorOutput(critiques="Error detected", content_edited="", action=InspectorAction.ABORT) + return InspectorOutput(critiques="No issues", content_edited="", action=InspectorAction.CONTINUE) + + code_string = callable_to_code_string(process_response) + assert isinstance(code_string, str) + assert "def process_response" in code_string + assert "model_response" in code_string + assert "input_content" in code_string + assert "InspectorAction.ABORT" in code_string + + def test_code_string_to_callable(self): + """Test converting code string back to callable""" + code_string = """def process_response(model_response: ModelResponse, input_content: str) -> InspectorOutput: + if "error" in model_response.error_message.lower(): + return InspectorOutput(critiques="Error detected", content_edited="", action=InspectorAction.ABORT) + return InspectorOutput(critiques="No issues", content_edited="", action=InspectorAction.CONTINUE)""" + + func = code_string_to_callable(code_string) + assert callable(func) + assert func.__name__ == "process_response" + + # Test the function works correctly + result1 = func(ModelResponse(status=ResponseStatus.FAILED, error_message="This is an error message"), "input") + assert result1.action == InspectorAction.ABORT + + result2 = func(ModelResponse(status=ResponseStatus.SUCCESS, data="This is a normal message"), "input") + assert result2.action == InspectorAction.CONTINUE + + def test_roundtrip_conversion(self): + """Test that serialization and deserialization work correctly together""" + + def process_response(model_response: ModelResponse, input_content: str) -> InspectorOutput: + if "error" in model_response.error_message.lower(): + return InspectorOutput(critiques="Error detected", content_edited="", action=InspectorAction.ABORT) + elif "warning" in model_response.data.lower(): + return InspectorOutput(critiques="Warning detected", content_edited="", action=InspectorAction.RERUN) + return InspectorOutput(critiques="No issues", content_edited="", action=InspectorAction.CONTINUE) + + # Serialize + code_string = callable_to_code_string(process_response) + + # Deserialize + deserialized_func = code_string_to_callable(code_string) + + # Test that the deserialized function works the same + assert ( + deserialized_func(ModelResponse(status=ResponseStatus.FAILED, error_message="error message"), "input").action + == InspectorAction.ABORT + ) + assert ( + deserialized_func(ModelResponse(status=ResponseStatus.SUCCESS, data="warning message"), "input").action + == InspectorAction.RERUN + ) + assert ( + deserialized_func(ModelResponse(status=ResponseStatus.SUCCESS, data="normal message"), "input").action + == InspectorAction.CONTINUE ) - assert inspector.name == INSPECTOR_CONFIG["name"] - assert inspector.model_id == INSPECTOR_CONFIG["model_id"] - assert inspector.model_params == INSPECTOR_CONFIG["model_config"] - assert inspector.policy == INSPECTOR_CONFIG["policy"] + def test_source_code_preservation(self): + """Test that code_string_to_callable preserves the original source code""" + code_string = """def process_response(model_response: ModelResponse, input_content: str) -> InspectorOutput: + if "error" in model_response.error_message.lower(): + return InspectorOutput(critiques="Error detected", content_edited="", action=InspectorAction.ABORT) + elif "warning" in model_response.data.lower(): + return InspectorOutput(critiques="Warning detected", content_edited="", action=InspectorAction.RERUN) + return InspectorOutput(critiques="No issues", content_edited="", action=InspectorAction.CONTINUE)""" + func = code_string_to_callable(code_string) -def test_inspector_factory_create_from_model_invalid_status(): - """Test creating inspector from model with invalid status""" - mock_response = MagicMock() - mock_response.status_code = 200 - mock_response.json.return_value = { - **MOCK_MODEL_RESPONSE, - "status": AssetStatus.DRAFT.value, - "function": {"id": Function.GUARDRAILS.value}, - } + # Verify the function has the _source_code attribute + assert hasattr(func, "_source_code") + assert func._source_code == code_string - with patch("aixplain.factories.team_agent_factory.inspector_factory._request_with_retry", return_value=mock_response): - with pytest.raises(ValueError, match="is not onboarded"): - InspectorFactory.create_from_model( + # Verify the function works correctly + assert ( + func(ModelResponse(status=ResponseStatus.FAILED, error_message="This is an error message"), "input").action + == InspectorAction.ABORT + ) + assert ( + func(ModelResponse(status=ResponseStatus.SUCCESS, data="This is a warning message"), "input").action + == InspectorAction.RERUN + ) + assert ( + func(ModelResponse(status=ResponseStatus.SUCCESS, data="This is a normal message"), "input").action + == InspectorAction.CONTINUE + ) + + +class TestSourceCodeRetrieval: + """Test source code retrieval functionality""" + + def test_get_policy_source_original_function(self): + """Test get_policy_source with an original function (should use inspect.getsource)""" + + def process_response(model_response: ModelResponse, input_content: str) -> InspectorOutput: + if "error" in model_response.error_message.lower(): + return InspectorOutput(critiques="Error detected", content_edited="", action=InspectorAction.ABORT) + return InspectorOutput(critiques="No issues", content_edited="", action=InspectorAction.CONTINUE) + + source = get_policy_source(process_response) + assert source is not None + assert "def process_response" in source + assert "InspectorAction.ABORT" in source + + def test_get_policy_source_deserialized_function(self): + """Test get_policy_source with a deserialized function (should use _source_code attribute)""" + code_string = """def process_response(model_response: ModelResponse, input_content: str) -> InspectorOutput: + if "error" in model_response.error_message.lower(): + return InspectorOutput(critiques="Error detected", content_edited="", action=InspectorAction.ABORT) + return InspectorOutput(critiques="No issues", content_edited="", action=InspectorAction.CONTINUE)""" + + func = code_string_to_callable(code_string) + + # Verify get_policy_source works with the deserialized function + source = get_policy_source(func) + assert source is not None + assert source == code_string + + def test_get_policy_source_fallback(self): + """Test get_policy_source fallback when neither approach works""" + # Create a function without source code info by using exec() + # This simulates a function created dynamically where inspect.getsource() would fail + namespace = {} + exec( + "def dynamic_func(x, y): return InspectorOutput(critiques='', content_edited='', action=InspectorAction.CONTINUE)", + namespace, + ) + func = namespace["dynamic_func"] + + # Remove any potential source code attributes + if hasattr(func, "_source_code"): + delattr(func, "_source_code") + + source = get_policy_source(func) + assert source is None + + +class TestInspectorSerialization: + """Test Inspector serialization and deserialization""" + + def test_model_dump_with_callable_policy(self): + """Test that Inspector.model_dump properly serializes callable policies""" + + def process_response(model_response: ModelResponse, input_content: str) -> InspectorOutput: + return InspectorOutput(critiques="Test critique", content_edited="", action=InspectorAction.ABORT) + + inspector = Inspector( + name="test_inspector", + model_id="test_model_id", + policy=process_response, + ) + + data = inspector.model_dump() + assert data["policy_type"] == "callable" + assert isinstance(data["policy"], str) + assert "def process_response" in data["policy"] + + def test_model_dump_with_enum_policy(self): + """Test that Inspector.model_dump properly serializes enum policies""" + inspector = Inspector( + name="test_inspector", + model_id="test_model_id", + policy=InspectorPolicy.WARN, + ) + + data = inspector.model_dump() + assert data["policy_type"] == "enum" + assert data["policy"] == "warn" + + def test_model_validate_with_callable_policy(self): + """Test that Inspector.model_validate properly deserializes callable policies""" + inspector_data = { + "name": "test_inspector", + "model_id": "test_model_id", + "policy": """def process_response(model_response: ModelResponse, input_content: str) -> InspectorOutput: + return InspectorOutput(critiques="Test critique", content_edited="", action=InspectorAction.ABORT)""", + "policy_type": "callable", + } + + inspector = Inspector.model_validate(inspector_data) + assert callable(inspector.policy) + assert inspector.policy.__name__ == "process_response" + result = inspector.policy(ModelResponse(status=ResponseStatus.SUCCESS, data="test"), "input") + assert result.action == InspectorAction.ABORT + + def test_model_validate_with_enum_policy(self): + """Test that Inspector.model_validate properly deserializes enum policies""" + inspector_data = { + "name": "test_inspector", + "model_id": "test_model_id", + "policy": "warn", + "policy_type": "enum", + } + + inspector = Inspector.model_validate(inspector_data) + assert inspector.policy == InspectorPolicy.WARN + + def test_model_validate_fallback(self): + """Test that Inspector.model_validate falls back to default policy on error""" + inspector_data = { + "name": "test_inspector", + "model_id": "test_model_id", + "policy": "invalid code string", + "policy_type": "callable", + } + + inspector = Inspector.model_validate(inspector_data) + assert inspector.policy == InspectorPolicy.ADAPTIVE # Default fallback + + def test_roundtrip_serialization_preserves_source_code(self): + """Test that Inspector round-trip serialization preserves source code""" + + def process_response(model_response: ModelResponse, input_content: str) -> InspectorOutput: + if "error" in model_response.error_message.lower(): + return InspectorOutput(critiques="Error detected", content_edited="", action=InspectorAction.ABORT) + elif "warning" in model_response.data.lower(): + return InspectorOutput(critiques="Warning detected", content_edited="", action=InspectorAction.RERUN) + return InspectorOutput(critiques="No issues", content_edited="", action=InspectorAction.CONTINUE) + + # Create inspector with callable policy + inspector = Inspector( + name="test_inspector", + model_id="test_model_id", + policy=process_response, + ) + + # Serialize to dict + inspector_dict = inspector.model_dump() + assert inspector_dict["policy_type"] == "callable" + assert isinstance(inspector_dict["policy"], str) + + # Deserialize from dict + inspector_copy = Inspector.model_validate(inspector_dict) + assert callable(inspector_copy.policy) + assert inspector_copy.policy.__name__ == "process_response" + + # Verify the deserialized function has source code and works correctly + assert hasattr(inspector_copy.policy, "_source_code") + assert "def process_response" in inspector_copy.policy._source_code + assert ( + inspector_copy.policy( + ModelResponse(status=ResponseStatus.FAILED, error_message="This is an error message"), "input" + ).action + == InspectorAction.ABORT + ) + assert ( + inspector_copy.policy( + ModelResponse(status=ResponseStatus.SUCCESS, data="This is a warning message"), "input" + ).action + == InspectorAction.RERUN + ) + assert ( + inspector_copy.policy(ModelResponse(status=ResponseStatus.SUCCESS, data="This is a normal message"), "input").action + == InspectorAction.CONTINUE + ) + + +class TestInspectorFactory: + """Test InspectorFactory functionality""" + + def test_create_from_model(self): + """Test creating inspector from model using factory""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + **MOCK_MODEL_RESPONSE, + "status": AssetStatus.ONBOARDED.value, + "function": {"id": Function.GUARDRAILS.value}, + } + + with patch("aixplain.factories.team_agent_factory.inspector_factory._request_with_retry", return_value=mock_response): + inspector = InspectorFactory.create_from_model( name=INSPECTOR_CONFIG["name"], model=INSPECTOR_CONFIG["model_id"], model_config=INSPECTOR_CONFIG["model_config"], policy=INSPECTOR_CONFIG["policy"], ) + assert inspector.name == INSPECTOR_CONFIG["name"] + assert inspector.model_id == INSPECTOR_CONFIG["model_id"] + assert inspector.model_params == INSPECTOR_CONFIG["model_config"] + assert inspector.policy == INSPECTOR_CONFIG["policy"] + + def test_create_from_model_with_callable_policy(self): + """Test creating inspector from model using factory with callable policy""" -def test_inspector_factory_create_from_model_invalid_function(): - """Test creating inspector from model with invalid function""" - mock_response = MagicMock() - mock_response.status_code = 200 - mock_response.json.return_value = { - **MOCK_MODEL_RESPONSE, - "status": AssetStatus.ONBOARDED.value, - "function": {"id": Function.TRANSLATION.value}, - } + def process_response(model_response: ModelResponse, input_content: str) -> InspectorOutput: + return InspectorOutput(critiques="Test critique", content_edited="", action=InspectorAction.CONTINUE) - with patch("aixplain.factories.team_agent_factory.inspector_factory._request_with_retry", return_value=mock_response): - with pytest.raises(ValueError, match="models are supported"): - InspectorFactory.create_from_model( + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + **MOCK_MODEL_RESPONSE, + "status": AssetStatus.ONBOARDED.value, + "function": {"id": Function.GUARDRAILS.value}, + } + + with patch("aixplain.factories.team_agent_factory.inspector_factory._request_with_retry", return_value=mock_response): + inspector = InspectorFactory.create_from_model( name=INSPECTOR_CONFIG["name"], model=INSPECTOR_CONFIG["model_id"], model_config=INSPECTOR_CONFIG["model_config"], - policy=INSPECTOR_CONFIG["policy"], + policy=process_response, ) + assert inspector.name == INSPECTOR_CONFIG["name"] + assert inspector.model_id == INSPECTOR_CONFIG["model_id"] + assert inspector.model_params == INSPECTOR_CONFIG["model_config"] + assert inspector.policy == process_response + assert callable(inspector.policy) + + def test_create_from_model_invalid_status(self): + """Test creating inspector from model with invalid status""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + **MOCK_MODEL_RESPONSE, + "status": AssetStatus.DRAFT.value, + "function": {"id": Function.GUARDRAILS.value}, + } + + with patch("aixplain.factories.team_agent_factory.inspector_factory._request_with_retry", return_value=mock_response): + with pytest.raises(ValueError, match="is not onboarded"): + InspectorFactory.create_from_model( + name=INSPECTOR_CONFIG["name"], + model=INSPECTOR_CONFIG["model_id"], + model_config=INSPECTOR_CONFIG["model_config"], + policy=INSPECTOR_CONFIG["policy"], + ) + + def test_create_from_model_invalid_function(self): + """Test creating inspector from model with invalid function""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + **MOCK_MODEL_RESPONSE, + "status": AssetStatus.ONBOARDED.value, + "function": {"id": Function.TRANSLATION.value}, + } + + with patch("aixplain.factories.team_agent_factory.inspector_factory._request_with_retry", return_value=mock_response): + with pytest.raises(ValueError, match="models are supported"): + InspectorFactory.create_from_model( + name=INSPECTOR_CONFIG["name"], + model=INSPECTOR_CONFIG["model_id"], + model_config=INSPECTOR_CONFIG["model_config"], + policy=INSPECTOR_CONFIG["policy"], + ) + + def test_create_auto(self): + """Test creating auto-configured inspector using factory""" + inspector = InspectorFactory.create_auto( + auto=InspectorAuto.CORRECTNESS, name="custom_name", policy=InspectorPolicy.ABORT + ) + + assert inspector.name == "custom_name" + assert inspector.auto == InspectorAuto.CORRECTNESS + assert inspector.policy == InspectorPolicy.ABORT + assert inspector.model_id == AUTO_DEFAULT_MODEL_ID + assert inspector.model_params is None + + def test_create_auto_with_callable_policy(self): + """Test creating auto-configured inspector using factory with callable policy""" -def test_inspector_factory_create_auto(): - """Test creating auto-configured inspector using factory""" - inspector = InspectorFactory.create_auto(auto=InspectorAuto.CORRECTNESS, name="custom_name", policy=InspectorPolicy.ABORT) + def process_response(model_response: ModelResponse, input_content: str) -> InspectorOutput: + return InspectorOutput(critiques="Test critique", content_edited="", action=InspectorAction.ABORT) - assert inspector.name == "custom_name" - assert inspector.auto == InspectorAuto.CORRECTNESS - assert inspector.policy == InspectorPolicy.ABORT - assert inspector.model_id == AUTO_DEFAULT_MODEL_ID - assert inspector.model_params is None + inspector = InspectorFactory.create_auto(auto=InspectorAuto.CORRECTNESS, name="custom_name", policy=process_response) + assert inspector.name == "custom_name" + assert inspector.auto == InspectorAuto.CORRECTNESS + assert inspector.policy == process_response + assert callable(inspector.policy) + assert inspector.model_id == AUTO_DEFAULT_MODEL_ID + assert inspector.model_params is None -def test_inspector_factory_create_auto_default_name(): - """Test creating auto-configured inspector with default name""" - inspector = InspectorFactory.create_auto(auto=InspectorAuto.CORRECTNESS) + def test_create_auto_default_name(self): + """Test creating auto-configured inspector with default name""" + inspector = InspectorFactory.create_auto(auto=InspectorAuto.CORRECTNESS) - assert inspector.name == "inspector_correctness" - assert inspector.auto == InspectorAuto.CORRECTNESS - assert inspector.policy == InspectorPolicy.ADAPTIVE # default policy + assert inspector.name == "inspector_correctness" + assert inspector.auto == InspectorAuto.CORRECTNESS + assert inspector.policy == InspectorPolicy.ADAPTIVE # default policy