diff --git a/aixplain/factories/team_agent_factory/__init__.py b/aixplain/factories/team_agent_factory/__init__.py index 60050813..5463902e 100644 --- a/aixplain/factories/team_agent_factory/__init__.py +++ b/aixplain/factories/team_agent_factory/__init__.py @@ -227,10 +227,13 @@ def _setup_llm_and_tool( team_agent = build_team_agent(payload=internal_payload, agents=agent_list, api_key=api_key) team_agent.validate(raise_exception=True) response = "Unspecified error" + inspectors=team_agent.inspectors + inspector_targets=team_agent.inspector_targets try: payload["inspectors"] = [ inspector.model_dump(by_alias=True) for inspector in inspectors - ] # convert Inspector object to dict + ] + payload["inspectorTargets"] = inspector_targets logging.debug(f"Start service for POST Create TeamAgent - {url} - {headers} - {json.dumps(payload)}") r = _request_with_retry("post", url, headers=headers, json=payload) response = r.json() diff --git a/aixplain/factories/team_agent_factory/utils.py b/aixplain/factories/team_agent_factory/utils.py index edbef6d4..9f23f56e 100644 --- a/aixplain/factories/team_agent_factory/utils.py +++ b/aixplain/factories/team_agent_factory/utils.py @@ -10,11 +10,12 @@ from aixplain.modules.agent.agent_task import AgentTask from aixplain.modules.agent.tool.model_tool import ModelTool from aixplain.modules.team_agent import TeamAgent, InspectorTarget -from aixplain.modules.team_agent.inspector import Inspector +from aixplain.modules.team_agent.inspector import Inspector, InspectorAction, InspectorAuto, InspectorPolicy, InspectorOutput from aixplain.factories.agent_factory import AgentFactory from aixplain.factories.model_factory import ModelFactory from aixplain.modules.model.model_parameters import ModelParameters from aixplain.modules.agent.output_format import OutputFormat +from aixplain.modules.model.response import ModelResponse GPT_4o_ID = "6646261c6eb563165658bbb1" SUPPORTED_TOOLS = ["llm", "website_search", "website_scrape", "website_crawl", "serper_search"] @@ -154,6 +155,51 @@ def get_cached_model(model_id: str) -> any: elif tool["description"] == "mentalist": mentalist_llm = llm + resolved_model_id = payload.get("llmId", None) + if not resolved_model_id: + resolved_model_id = GPT_4o_ID + has_quality_check = any( + (getattr(ins, "name", "") or "").lower() == "qualitycheckinspector" + for ins in inspectors + ) + if not has_quality_check: + try: + def process_response(model_response: ModelResponse, input_content: str) -> InspectorOutput: + critiques = model_response.data + action = InspectorAction.RERUN + return InspectorOutput(critiques=critiques, content_edited=input_content, action=action) + + default_inspector = Inspector( + name="QualityCheckInspector", + model_id=resolved_model_id, + model_params={"prompt": "Analyze content to ensure correctness of response"}, + policy=process_response + ) + + inspectors = [default_inspector] + inspectors + inspector_targets = payload.get("inspectorTargets", inspector_targets if 'inspector_targets' in locals() else []) + if isinstance(inspector_targets, (str, InspectorTarget)): + inspector_targets = [inspector_targets] + normalized = [] + for t in inspector_targets: + if isinstance(t, InspectorTarget): + normalized.append(t) + elif isinstance(t, str): + try: + normalized.append(InspectorTarget(t.lower())) + except Exception: + logging.warning(f"Ignoring unknown inspector target: {t!r}") + else: + logging.warning(f"Ignoring inspector target with unexpected type: {type(t)}") + + if InspectorTarget.STEPS not in normalized: + normalized.append(InspectorTarget.STEPS) + + inspector_targets = normalized + + except Exception as e: + logging.warning(f"Failed to add default QualityCheckInspector: {e}") + team_agent = TeamAgent( id=payload.get("id", ""), name=payload.get("name", ""), diff --git a/aixplain/modules/team_agent/inspector.py b/aixplain/modules/team_agent/inspector.py index d0a10932..2e870c00 100644 --- a/aixplain/modules/team_agent/inspector.py +++ b/aixplain/modules/team_agent/inspector.py @@ -58,7 +58,7 @@ class InspectorOutput(BaseModel): class InspectorAuto(str, Enum): """A list of keywords for inspectors configured automatically in the backend.""" - + ALIGNMENT = "alignment" CORRECTNESS = "correctness" def get_name(self) -> Text: