diff --git a/aixplain/factories/team_agent_factory/utils.py b/aixplain/factories/team_agent_factory/utils.py index 2441d9db..3d4c15ed 100644 --- a/aixplain/factories/team_agent_factory/utils.py +++ b/aixplain/factories/team_agent_factory/utils.py @@ -10,12 +10,11 @@ from aixplain.modules.agent.agent_task import AgentTask from aixplain.modules.agent.tool.model_tool import ModelTool from aixplain.modules.team_agent import TeamAgent, InspectorTarget -from aixplain.modules.team_agent.inspector import Inspector, InspectorAction, InspectorAuto, InspectorPolicy, InspectorOutput +from aixplain.modules.team_agent.inspector import Inspector from aixplain.factories.agent_factory import AgentFactory from aixplain.factories.model_factory import ModelFactory from aixplain.modules.model.model_parameters import ModelParameters from aixplain.modules.agent.output_format import OutputFormat -from aixplain.modules.model.response import ModelResponse GPT_4o_ID = "6646261c6eb563165658bbb1" SUPPORTED_TOOLS = ["llm", "website_search", "website_scrape", "website_crawl", "serper_search"] @@ -155,51 +154,6 @@ def get_cached_model(model_id: str) -> any: elif tool["description"] == "mentalist": mentalist_llm = llm - resolved_model_id = payload.get("llmId", None) - if not resolved_model_id: - resolved_model_id = GPT_4o_ID - has_quality_check = any( - (getattr(ins, "name", "") or "").lower() == "qualitycheckinspector" - for ins in inspectors - ) - if not has_quality_check: - try: - def process_response(model_response: ModelResponse, input_content: str) -> InspectorOutput: - critiques = model_response.data - action = InspectorAction.RERUN - return InspectorOutput(critiques=critiques, content_edited=input_content, action=action) - - default_inspector = Inspector( - name="QualityCheckInspector", - model_id=resolved_model_id, - model_params={"prompt": "Analyze content to ensure correctness of response"}, - policy=process_response, - auto=InspectorAuto.ALIGNMENT - ) - - inspectors = [default_inspector] + inspectors - inspector_targets = payload.get("inspectorTargets", inspector_targets if 'inspector_targets' in locals() else []) - if isinstance(inspector_targets, (str, InspectorTarget)): - inspector_targets = [inspector_targets] - normalized = [] - for t in inspector_targets: - if isinstance(t, InspectorTarget): - normalized.append(t) - elif isinstance(t, str): - try: - normalized.append(InspectorTarget(t.lower())) - except Exception: - logging.warning(f"Ignoring unknown inspector target: {t!r}") - else: - logging.warning(f"Ignoring inspector target with unexpected type: {type(t)}") - - if InspectorTarget.STEPS not in normalized: - normalized.append(InspectorTarget.STEPS) - - inspector_targets = normalized - - except Exception as e: - logging.warning(f"Failed to add default QualityCheckInspector: {e}") team_agent = TeamAgent( id=payload.get("id", ""), diff --git a/aixplain/modules/team_agent/inspector.py b/aixplain/modules/team_agent/inspector.py index b518e76e..3c091a02 100644 --- a/aixplain/modules/team_agent/inspector.py +++ b/aixplain/modules/team_agent/inspector.py @@ -380,3 +380,64 @@ def model_validate(cls, data: Union[Dict, "Inspector"]) -> "Inspector": data.pop("policy_type", None) # Remove the type indicator return super().model_validate(data) + + +class VerificationInspector(Inspector): + """Pre-defined inspector that checks output against the plan. + + This inspector is designed to verify that the output aligns with the intended plan + and provides feedback when discrepancies are found. It uses a RERUN policy by default, + meaning it will request re-execution when issues are detected. + + Example usage: + from aixplain.modules.team_agent import VerificationInspector + + # Use with default model (GPT-4o or resolved_model_id) + inspector = VerificationInspector() + + # Or with custom model + inspector = VerificationInspector(model_id="your_model_id") + + team_agent = TeamAgent( + name="my_team", + agents=agents, + inspectors=[VerificationInspector()], + inspector_targets=[InspectorTarget.STEPS] + ) + """ + + def __init__(self, model_id: Optional[Text] = None, **kwargs): + """Initialize VerificationInspector with default configuration. + + Args: + model_id (Optional[Text]): Model ID to use. If not provided, uses auto configuration. + **kwargs: Additional arguments passed to Inspector parent class. + """ + from aixplain.modules.model.response import ModelResponse + + # Replicate resolved_model_id logic from old implementation + resolved_model_id = model_id + if not resolved_model_id: + resolved_model_id = "6646261c6eb563165658bbb1" # GPT_4o_ID + + def process_response(model_response: ModelResponse, input_content: str) -> InspectorOutput: + """Default policy that always requests rerun for verification.""" + critiques = model_response.data + action = InspectorAction.RERUN + return InspectorOutput(critiques=critiques, content_edited=input_content, action=action) + + # Exact same default inspector configuration as old implementation + # Note: When auto=InspectorAuto.ALIGNMENT is set, Inspector.__init__ will override + # model_id with AUTO_DEFAULT_MODEL_ID + defaults = { + "name": "VerificationInspector", + "model_id": resolved_model_id, + "model_params": {"prompt": "Check the output against the plan"}, + "policy": process_response, + "auto": InspectorAuto.ALIGNMENT + } + + # Override defaults with any provided kwargs + defaults.update(kwargs) + + super().__init__(**defaults)