Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 11 additions & 6 deletions .github/workflows/main.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ on:
- main
# any branch other than main, will use the test key
- test
workflow_dispatch:
workflow_dispatch:

jobs:
setup-and-test:
Expand All @@ -31,7 +31,8 @@ jobs:
'finetune_v2',
'general_assets',
'apikey',
'agent_and_team_agent',
'agent',
'team_agent',
]
include:
- test-suite: 'unit'
Expand Down Expand Up @@ -79,9 +80,13 @@ jobs:
- test-suite: 'apikey'
path: 'tests/functional/apikey'
timeout: 45
- test-suite: 'agent_and_team_agent'
path: 'tests/functional/agent tests/functional/team_agent'
- test-suite: 'agent'
path: 'tests/functional/agent'
timeout: 45
- test-suite: 'team_agent'
path: 'tests/functional/team_agent'
timeout: 45

steps:
- name: Checkout repository
uses: actions/checkout@v4
Expand All @@ -91,7 +96,7 @@ jobs:
with:
python-version: "3.9"
cache: 'pip'

- name: Install dependencies
run: |
python -m pip install --upgrade pip
Expand All @@ -112,7 +117,7 @@ jobs:
fi
echo "SLACK_TOKEN=${{ secrets.SLACK_TOKEN }}" >> $GITHUB_ENV
echo "HF_TOKEN=${{ secrets.HF_TOKEN }}" >> $GITHUB_ENV

- name: Run Tests
timeout-minutes: ${{ matrix.timeout }}
run: python -m pytest ${{ matrix.path }}
48 changes: 1 addition & 47 deletions aixplain/factories/team_agent_factory/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,11 @@
from aixplain.modules.agent.agent_task import AgentTask
from aixplain.modules.agent.tool.model_tool import ModelTool
from aixplain.modules.team_agent import TeamAgent, InspectorTarget
from aixplain.modules.team_agent.inspector import Inspector, InspectorAction, InspectorAuto, InspectorPolicy, InspectorOutput
from aixplain.modules.team_agent.inspector import Inspector
from aixplain.factories.agent_factory import AgentFactory
from aixplain.factories.model_factory import ModelFactory
from aixplain.modules.model.model_parameters import ModelParameters
from aixplain.modules.agent.output_format import OutputFormat
from aixplain.modules.model.response import ModelResponse

GPT_4o_ID = "6646261c6eb563165658bbb1"
SUPPORTED_TOOLS = ["llm", "website_search", "website_scrape", "website_crawl", "serper_search"]
Expand Down Expand Up @@ -155,51 +154,6 @@ def get_cached_model(model_id: str) -> any:
elif tool["description"] == "mentalist":
mentalist_llm = llm

resolved_model_id = payload.get("llmId", None)
if not resolved_model_id:
resolved_model_id = GPT_4o_ID
has_quality_check = any(
(getattr(ins, "name", "") or "").lower() == "qualitycheckinspector"
for ins in inspectors
)
if not has_quality_check:
try:
def process_response(model_response: ModelResponse, input_content: str) -> InspectorOutput:
critiques = model_response.data
action = InspectorAction.RERUN
return InspectorOutput(critiques=critiques, content_edited=input_content, action=action)

default_inspector = Inspector(
name="QualityCheckInspector",
model_id=resolved_model_id,
model_params={"prompt": "Analyze content to ensure correctness of response"},
policy=process_response,
auto=InspectorAuto.ALIGNMENT
)

inspectors = [default_inspector] + inspectors
inspector_targets = payload.get("inspectorTargets", inspector_targets if 'inspector_targets' in locals() else [])
if isinstance(inspector_targets, (str, InspectorTarget)):
inspector_targets = [inspector_targets]
normalized = []
for t in inspector_targets:
if isinstance(t, InspectorTarget):
normalized.append(t)
elif isinstance(t, str):
try:
normalized.append(InspectorTarget(t.lower()))
except Exception:
logging.warning(f"Ignoring unknown inspector target: {t!r}")
else:
logging.warning(f"Ignoring inspector target with unexpected type: {type(t)}")

if InspectorTarget.STEPS not in normalized:
normalized.append(InspectorTarget.STEPS)

inspector_targets = normalized

except Exception as e:
logging.warning(f"Failed to add default QualityCheckInspector: {e}")

team_agent = TeamAgent(
id=payload.get("id", ""),
Expand Down
13 changes: 5 additions & 8 deletions aixplain/modules/agent/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,9 +393,6 @@ def _format_agent_progress(
if tool:
msg = f"⚙️ {agent_name} | {tool} | {status_icon}"

if runtime is not None and runtime > 0 and success is not None:
msg += f" ({runtime:.1f} s)"

if tool_input:
msg += f" | Input: {tool_input}"

Expand Down Expand Up @@ -604,27 +601,27 @@ def run(
Dict: parsed output from model
"""
start = time.time()

# Extract deprecated parameters from kwargs
output_format = kwargs.get("output_format", None)
expected_output = kwargs.get("expected_output", None)

if output_format is not None:
warnings.warn(
"The 'output_format' parameter is deprecated and will be removed in a future version. "
"Set the output format during agent initialization instead.",
DeprecationWarning,
stacklevel=2,
)

if expected_output is not None:
warnings.warn(
"The 'expected_output' parameter is deprecated and will be removed in a future version. "
"Set the expected output during agent initialization instead.",
DeprecationWarning,
stacklevel=2,
)

if session_id is not None and history is not None:
raise ValueError("Provide either `session_id` or `history`, not both.")

Expand Down Expand Up @@ -661,7 +658,7 @@ def run(
poll_url, name=name, timeout=timeout, wait_time=wait_time, progress_verbosity=progress_verbosity
)
if result.status == ResponseStatus.FAILED:
raise Exception("Model failed to run with error: " + result.error_message)
raise Exception("Model failed to run with error: " + result.error_message)
result_data = result.get("data") or {}
return AgentResponse(
status=ResponseStatus.SUCCESS,
Expand Down
40 changes: 32 additions & 8 deletions aixplain/modules/team_agent/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ class TeamAgent(Model, DeployableMixin[Agent]):
instructions (Optional[Text]): Instructions to guide the team agent.
output_format (OutputFormat): Response format. Defaults to TEXT.
expected_output (Optional[Union[BaseModel, Text, dict]]): Expected output format.

Deprecated Attributes:
llm_id (Text): DEPRECATED. Use 'llm' parameter instead. Large language model ID.
mentalist_llm (Optional[LLM]): DEPRECATED. LLM for planning.
Expand Down Expand Up @@ -132,6 +132,32 @@ def __init__(
expected_output: Optional[Union[BaseModel, Text, dict]] = None,
**additional_info,
) -> None:
"""Initialize a TeamAgent instance.

Args:
id (Text): Unique identifier for the team agent.
name (Text): Name of the team agent.
agents (List[Agent], optional): List of agents in the team. Defaults to [].
description (Text, optional): Description of the team agent. Defaults to "".
llm (Optional[LLM], optional): LLM instance. Defaults to None.
supervisor_llm (Optional[LLM], optional): Supervisor LLM instance. Defaults to None.
api_key (Optional[Text], optional): API key. Defaults to config.TEAM_API_KEY.
supplier (Union[Dict, Text, Supplier, int], optional): Supplier. Defaults to "aiXplain".
version (Optional[Text], optional): Version. Defaults to None.
cost (Optional[Dict], optional): Cost information. Defaults to None.
inspectors (List[Inspector], optional): List of inspectors. Defaults to [].
inspector_targets (List[InspectorTarget], optional): Inspector targets. Defaults to [InspectorTarget.STEPS].
status (AssetStatus, optional): Status of the team agent. Defaults to AssetStatus.DRAFT.
instructions (Optional[Text], optional): Instructions for the team agent. Defaults to None.
output_format (OutputFormat, optional): Output format. Defaults to OutputFormat.TEXT.
expected_output (Optional[Union[BaseModel, Text, dict]], optional): Expected output format. Defaults to None.
**additional_info: Additional keyword arguments.

Deprecated Args:
llm_id (Text, optional): DEPRECATED. Use 'llm' parameter instead. ID of the language model. Defaults to "6646261c6eb563165658bbb1".
mentalist_llm (Optional[LLM], optional): DEPRECATED. Mentalist/Planner LLM instance. Defaults to None.
use_mentalist (bool, optional): DEPRECATED. Whether to use mentalist/planner. Defaults to True.
"""
# Handle deprecated parameters from kwargs
if "llm_id" in additional_info:
llm_id = additional_info.pop("llm_id")
Expand All @@ -143,7 +169,7 @@ def __init__(
)
else:
llm_id = "6646261c6eb563165658bbb1"

if "mentalist_llm" in additional_info:
mentalist_llm = additional_info.pop("mentalist_llm")
warnings.warn(
Expand All @@ -153,7 +179,7 @@ def __init__(
)
else:
mentalist_llm = None

if "use_mentalist" in additional_info:
use_mentalist = additional_info.pop("use_mentalist")
warnings.warn(
Expand Down Expand Up @@ -369,9 +395,6 @@ def _format_team_progress(
# Full verbosity: detailed info
msg = f"{emoji} {context} | {status_icon}"

if runtime is not None and runtime > 0 and success is not None:
msg += f" ({runtime:.1f} s)"

if current_step and total_steps:
msg += f" | Step {current_step}/{total_steps}"

Expand Down Expand Up @@ -576,6 +599,7 @@ def run(
max_iterations (int, optional): maximum number of iterations between the agents. Defaults to 30.
trace_request (bool, optional): return the request id for tracing the request. Defaults to False.
progress_verbosity (Optional[str], optional): Progress display mode - "full" (detailed), "compact" (brief), or None (no progress). Defaults to "compact".
**kwargs: Additional deprecated keyword arguments (output_format, expected_output).

Returns:
AgentResponse: parsed output from model
Expand All @@ -589,7 +613,7 @@ def run(
DeprecationWarning,
stacklevel=2,
)

expected_output = kwargs.pop("expected_output", None)
if expected_output is not None:
warnings.warn(
Expand All @@ -598,7 +622,7 @@ def run(
DeprecationWarning,
stacklevel=2,
)

start = time.time()
result_data = {}
if session_id is not None and history is not None:
Expand Down
61 changes: 61 additions & 0 deletions aixplain/modules/team_agent/inspector.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,3 +380,64 @@ def model_validate(cls, data: Union[Dict, "Inspector"]) -> "Inspector":
data.pop("policy_type", None) # Remove the type indicator

return super().model_validate(data)


class VerificationInspector(Inspector):
"""Pre-defined inspector that checks output against the plan.

This inspector is designed to verify that the output aligns with the intended plan
and provides feedback when discrepancies are found. It uses a RERUN policy by default,
meaning it will request re-execution when issues are detected.

Example usage:
from aixplain.modules.team_agent import VerificationInspector

# Use with default model (GPT-4o or resolved_model_id)
inspector = VerificationInspector()

# Or with custom model
inspector = VerificationInspector(model_id="your_model_id")

team_agent = TeamAgent(
name="my_team",
agents=agents,
inspectors=[VerificationInspector()],
inspector_targets=[InspectorTarget.STEPS]
)
"""

def __init__(self, model_id: Optional[Text] = None, **kwargs):
"""Initialize VerificationInspector with default configuration.

Args:
model_id (Optional[Text]): Model ID to use. If not provided, uses auto configuration.
**kwargs: Additional arguments passed to Inspector parent class.
"""
from aixplain.modules.model.response import ModelResponse

# Replicate resolved_model_id logic from old implementation
resolved_model_id = model_id
if not resolved_model_id:
resolved_model_id = "6646261c6eb563165658bbb1" # GPT_4o_ID

def process_response(model_response: ModelResponse, input_content: str) -> InspectorOutput:
"""Default policy that always requests rerun for verification."""
critiques = model_response.data
action = InspectorAction.RERUN
return InspectorOutput(critiques=critiques, content_edited=input_content, action=action)

# Exact same default inspector configuration as old implementation
# Note: When auto=InspectorAuto.ALIGNMENT is set, Inspector.__init__ will override
# model_id with AUTO_DEFAULT_MODEL_ID
defaults = {
"name": "VerificationInspector",
"model_id": resolved_model_id,
"model_params": {"prompt": "Check the output against the plan"},
"policy": process_response,
"auto": InspectorAuto.ALIGNMENT
}

# Override defaults with any provided kwargs
defaults.update(kwargs)

super().__init__(**defaults)
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,20 @@ sidebar_label: api_key_checker
title: aixplain.decorators.api_key_checker
---

API key validation decorator for aiXplain SDK.

#### check\_api\_key

```python
def check_api_key(method)
```

[[view_source]](https://github.com/aixplain/aiXplain/blob/main/aixplain/decorators/api_key_checker.py#L4)
[[view_source]](https://github.com/aixplain/aiXplain/blob/main/aixplain/decorators/api_key_checker.py#L6)

Decorator to verify that an API key is set before executing the method.

This decorator checks if either TEAM_API_KEY or AIXPLAIN_API_KEY is set in the
configuration. If neither key is set, it raises an exception.
This decorator uses the centralized API key validation logic from config.py
to ensure consistent behavior across the entire SDK.

**Arguments**:

Expand Down
4 changes: 2 additions & 2 deletions docs/api-reference/python/aixplain/exceptions/init.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,14 @@ def get_error_from_status_code(status_code: int,
) -> AixplainBaseException
```

[[view_source]](https://github.com/aixplain/aiXplain/blob/main/aixplain/exceptions/__init__.py#L21)
[[view_source]](https://github.com/aixplain/aiXplain/blob/main/aixplain/exceptions/__init__.py#L35)

Map HTTP status codes to appropriate exception types.

**Arguments**:

- `status_code` _int_ - The HTTP status code to map.
- `default_message` _str, optional_ - The default message to use if no specific message is available.
- `error_details` _str, optional_ - Additional error details to include in the message.


**Returns**:
Expand Down
Loading
Loading