Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion aixplain/factories/team_agent_factory/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,10 +227,13 @@ def _setup_llm_and_tool(
team_agent = build_team_agent(payload=internal_payload, agents=agent_list, api_key=api_key)
team_agent.validate(raise_exception=True)
response = "Unspecified error"
inspectors=team_agent.inspectors
inspector_targets=team_agent.inspector_targets
try:
payload["inspectors"] = [
inspector.model_dump(by_alias=True) for inspector in inspectors
] # convert Inspector object to dict
]
payload["inspectorTargets"] = inspector_targets
logging.debug(f"Start service for POST Create TeamAgent - {url} - {headers} - {json.dumps(payload)}")
r = _request_with_retry("post", url, headers=headers, json=payload)
response = r.json()
Expand Down
48 changes: 47 additions & 1 deletion aixplain/factories/team_agent_factory/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,12 @@
from aixplain.modules.agent.agent_task import AgentTask
from aixplain.modules.agent.tool.model_tool import ModelTool
from aixplain.modules.team_agent import TeamAgent, InspectorTarget
from aixplain.modules.team_agent.inspector import Inspector
from aixplain.modules.team_agent.inspector import Inspector, InspectorAction, InspectorAuto, InspectorPolicy, InspectorOutput
from aixplain.factories.agent_factory import AgentFactory
from aixplain.factories.model_factory import ModelFactory
from aixplain.modules.model.model_parameters import ModelParameters
from aixplain.modules.agent.output_format import OutputFormat
from aixplain.modules.model.response import ModelResponse

GPT_4o_ID = "6646261c6eb563165658bbb1"
SUPPORTED_TOOLS = ["llm", "website_search", "website_scrape", "website_crawl", "serper_search"]
Expand Down Expand Up @@ -154,6 +155,51 @@ def get_cached_model(model_id: str) -> any:
elif tool["description"] == "mentalist":
mentalist_llm = llm

resolved_model_id = payload.get("llmId", None)
if not resolved_model_id:
resolved_model_id = GPT_4o_ID
has_quality_check = any(
(getattr(ins, "name", "") or "").lower() == "qualitycheckinspector"
for ins in inspectors
)
if not has_quality_check:
try:
def process_response(model_response: ModelResponse, input_content: str) -> InspectorOutput:
critiques = model_response.data
action = InspectorAction.RERUN
return InspectorOutput(critiques=critiques, content_edited=input_content, action=action)

default_inspector = Inspector(
name="QualityCheckInspector",
model_id=resolved_model_id,
model_params={"prompt": "Analyze content to ensure correctness of response"},
policy=process_response
)

inspectors = [default_inspector] + inspectors
inspector_targets = payload.get("inspectorTargets", inspector_targets if 'inspector_targets' in locals() else [])
if isinstance(inspector_targets, (str, InspectorTarget)):
inspector_targets = [inspector_targets]
normalized = []
for t in inspector_targets:
if isinstance(t, InspectorTarget):
normalized.append(t)
elif isinstance(t, str):
try:
normalized.append(InspectorTarget(t.lower()))
except Exception:
logging.warning(f"Ignoring unknown inspector target: {t!r}")
else:
logging.warning(f"Ignoring inspector target with unexpected type: {type(t)}")

if InspectorTarget.STEPS not in normalized:
normalized.append(InspectorTarget.STEPS)

inspector_targets = normalized

except Exception as e:
logging.warning(f"Failed to add default QualityCheckInspector: {e}")

team_agent = TeamAgent(
id=payload.get("id", ""),
name=payload.get("name", ""),
Expand Down
2 changes: 1 addition & 1 deletion aixplain/modules/team_agent/inspector.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ class InspectorOutput(BaseModel):

class InspectorAuto(str, Enum):
"""A list of keywords for inspectors configured automatically in the backend."""

ALIGNMENT = "alignment"
CORRECTNESS = "correctness"

def get_name(self) -> Text:
Expand Down