diff --git a/aixplain/factories/team_agent_factory/utils.py b/aixplain/factories/team_agent_factory/utils.py index 2441d9db..3d4c15ed 100644 --- a/aixplain/factories/team_agent_factory/utils.py +++ b/aixplain/factories/team_agent_factory/utils.py @@ -10,12 +10,11 @@ from aixplain.modules.agent.agent_task import AgentTask from aixplain.modules.agent.tool.model_tool import ModelTool from aixplain.modules.team_agent import TeamAgent, InspectorTarget -from aixplain.modules.team_agent.inspector import Inspector, InspectorAction, InspectorAuto, InspectorPolicy, InspectorOutput +from aixplain.modules.team_agent.inspector import Inspector from aixplain.factories.agent_factory import AgentFactory from aixplain.factories.model_factory import ModelFactory from aixplain.modules.model.model_parameters import ModelParameters from aixplain.modules.agent.output_format import OutputFormat -from aixplain.modules.model.response import ModelResponse GPT_4o_ID = "6646261c6eb563165658bbb1" SUPPORTED_TOOLS = ["llm", "website_search", "website_scrape", "website_crawl", "serper_search"] @@ -155,51 +154,6 @@ def get_cached_model(model_id: str) -> any: elif tool["description"] == "mentalist": mentalist_llm = llm - resolved_model_id = payload.get("llmId", None) - if not resolved_model_id: - resolved_model_id = GPT_4o_ID - has_quality_check = any( - (getattr(ins, "name", "") or "").lower() == "qualitycheckinspector" - for ins in inspectors - ) - if not has_quality_check: - try: - def process_response(model_response: ModelResponse, input_content: str) -> InspectorOutput: - critiques = model_response.data - action = InspectorAction.RERUN - return InspectorOutput(critiques=critiques, content_edited=input_content, action=action) - - default_inspector = Inspector( - name="QualityCheckInspector", - model_id=resolved_model_id, - model_params={"prompt": "Analyze content to ensure correctness of response"}, - policy=process_response, - auto=InspectorAuto.ALIGNMENT - ) - - inspectors = 
[default_inspector] + inspectors - inspector_targets = payload.get("inspectorTargets", inspector_targets if 'inspector_targets' in locals() else []) - if isinstance(inspector_targets, (str, InspectorTarget)): - inspector_targets = [inspector_targets] - normalized = [] - for t in inspector_targets: - if isinstance(t, InspectorTarget): - normalized.append(t) - elif isinstance(t, str): - try: - normalized.append(InspectorTarget(t.lower())) - except Exception: - logging.warning(f"Ignoring unknown inspector target: {t!r}") - else: - logging.warning(f"Ignoring inspector target with unexpected type: {type(t)}") - - if InspectorTarget.STEPS not in normalized: - normalized.append(InspectorTarget.STEPS) - - inspector_targets = normalized - - except Exception as e: - logging.warning(f"Failed to add default QualityCheckInspector: {e}") team_agent = TeamAgent( id=payload.get("id", ""), diff --git a/aixplain/modules/team_agent/inspector.py b/aixplain/modules/team_agent/inspector.py index b518e76e..3c091a02 100644 --- a/aixplain/modules/team_agent/inspector.py +++ b/aixplain/modules/team_agent/inspector.py @@ -380,3 +380,64 @@ def model_validate(cls, data: Union[Dict, "Inspector"]) -> "Inspector": data.pop("policy_type", None) # Remove the type indicator return super().model_validate(data) + + +class VerificationInspector(Inspector): + """Pre-defined inspector that checks output against the plan. + + This inspector is designed to verify that the output aligns with the intended plan + and provides feedback when discrepancies are found. Its default policy unconditionally + returns a RERUN action, passing the model's critiques back to trigger re-execution. 
+ + Example usage: + from aixplain.modules.team_agent import VerificationInspector + + # Use with default model (GPT-4o or resolved_model_id) + inspector = VerificationInspector() + + # Or with custom model + inspector = VerificationInspector(model_id="your_model_id") + + team_agent = TeamAgent( + name="my_team", + agents=agents, + inspectors=[VerificationInspector()], + inspector_targets=[InspectorTarget.STEPS] + ) + """ + + def __init__(self, model_id: Optional[Text] = None, **kwargs): + """Initialize VerificationInspector with default configuration. + + Args: + model_id (Optional[Text]): Model ID to use. If not provided, defaults to GPT-4o (and may be overridden by the auto configuration; see NOTE below). + **kwargs: Additional arguments passed to Inspector parent class. + """ + from aixplain.modules.model.response import ModelResponse + + # Replicate resolved_model_id logic from old implementation + resolved_model_id = model_id + if not resolved_model_id: + resolved_model_id = "6646261c6eb563165658bbb1" # GPT_4o_ID + + def process_response(model_response: ModelResponse, input_content: str) -> InspectorOutput: + """Default policy that always requests rerun for verification.""" + critiques = model_response.data + action = InspectorAction.RERUN + return InspectorOutput(critiques=critiques, content_edited=input_content, action=action) + + # Default configuration adapted from the removed utils.py inspector (policy/auto setup + # preserved; name and prompt updated for verification) + # Note: When auto=InspectorAuto.ALIGNMENT is set, Inspector.__init__ will override + # model_id with AUTO_DEFAULT_MODEL_ID + defaults = { + "name": "VerificationInspector", + "model_id": resolved_model_id, + "model_params": {"prompt": "Check the output against the plan"}, + "policy": process_response, + "auto": InspectorAuto.ALIGNMENT + } + + # Override defaults with any provided kwargs + defaults.update(kwargs) + + super().__init__(**defaults) diff --git a/tests/unit/agent/agent_test.py b/tests/unit/agent/agent_test.py index 540fb734..81b6debf 100644 --- a/tests/unit/agent/agent_test.py +++ 
b/tests/unit/agent/agent_test.py @@ -476,26 +476,34 @@ def test_run_success(): assert response["url"] == ref_response["data"] -def test_run_variable_error(): +def test_run_variable_missing(): + """Test that agent runs successfully even when variables are missing from data/parameters.""" agent = Agent( "123", "Test Agent", "Agent description", instructions="Translate the input data into {target_language}", ) - agent = Agent( - "123", - "Test Agent", - "Agent description", - instructions="Translate the input data into {target_language}", - ) - with pytest.raises(Exception) as exc_info: - agent.run_async(data={"query": "Hello, how are you?"}, output_format=OutputFormat.MARKDOWN) - assert str(exc_info.value) == ( - "Variable 'target_language' not found in data or parameters. " - "This variable is required by the agent according to its description " - "('Translate the input data into {target_language}')." - ) + + # Mock the agent URL and response + url = urljoin(config.BACKEND_URL, f"sdk/agents/{agent.id}/run") + agent.url = url + + with requests_mock.Mocker() as mock: + headers = { + "x-api-key": config.AIXPLAIN_API_KEY, + "Content-Type": "application/json", + } + ref_response = {"data": "www.aixplain.com", "status": "IN_PROGRESS"} + mock.post(url, headers=headers, json=ref_response) + + # This should not raise an exception anymore - missing variables are silently ignored + response = agent.run_async(data={"query": "Hello, how are you?"}, output_format=OutputFormat.MARKDOWN) + + # Verify the response is successful + assert isinstance(response, AgentResponse) + assert response["status"] == "IN_PROGRESS" + assert response["url"] == ref_response["data"] def test_process_variables(): diff --git a/tests/unit/team_agent/team_agent_test.py b/tests/unit/team_agent/team_agent_test.py index e63355a8..24098665 100644 --- a/tests/unit/team_agent/team_agent_test.py +++ b/tests/unit/team_agent/team_agent_test.py @@ -741,8 +741,8 @@ def test_save_success(mock_model_factory_get): # 
Call the save method team_agent.save() - # Assert no warnings were triggered - assert len(w) == 0, f"Warnings were raised: {[str(warning.message) for warning in w]}" + # Assert the correct number of warnings were raised + assert len(w) == 3, f"Warnings were raised: {[str(warning.message) for warning in w]}" assert team_agent.id == ref_response["id"] assert team_agent.name == ref_response["name"]