Merged
52 changes: 43 additions & 9 deletions aixplain/factories/agent_factory/__init__.py
@@ -1,7 +1,4 @@
__author__ = "lucaspavanelli"

"""
Copyright 2024 The aiXplain SDK authors
"""Copyright 2024 The aiXplain SDK authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -21,6 +18,8 @@
Agent Factory Class
"""

__author__ = "lucaspavanelli"

import json
import logging
import warnings
@@ -49,6 +48,19 @@
from aixplain.enums import DatabaseSourceType


def to_literal_text(x):
"""Convert value to literal text, escaping braces for string formatting.

Args:
x: Value to convert (dict, list, or any other type)

Returns:
str: Escaped string representation
"""
s = json.dumps(x, ensure_ascii=False, indent=2) if isinstance(x, (dict, list)) else str(x)
return s.replace("{", "{{").replace("}", "}}")

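For illustration, a minimal sketch of the escaping behaviour (assuming this module is importable as aixplain.factories.agent_factory, and that the doubled braces exist to survive downstream str.format-style templating of instructions):

    from aixplain.factories.agent_factory import to_literal_text

    raw = {"query": "{user_input}"}
    escaped = to_literal_text(raw)
    # json.dumps renders the dict, then braces are doubled:
    # '{{\n  "query": "{{user_input}}"\n}}'
    template = "Context:\n" + escaped
    # str.format collapses the doubled braces back to literal braces
    # instead of treating {user_input} as a placeholder:
    print(template.format())
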

class AgentFactory:
"""Factory class for creating and managing agents in the aiXplain system.

@@ -90,9 +102,11 @@ def create(
api_key (Text, optional): team/user API key. Defaults to config.TEAM_API_KEY.
supplier (Union[Dict, Text, Supplier, int], optional): owner of the agent. Defaults to "aiXplain".
version (Optional[Text], optional): version of the agent. Defaults to None.
tasks (List[WorkflowTask], optional): Deprecated. Use workflow_tasks instead. Defaults to None.
workflow_tasks (List[WorkflowTask], optional): list of tasks for the agent. Defaults to [].
output_format (OutputFormat, optional): default output format for agent responses. Defaults to OutputFormat.TEXT.
expected_output (Union[BaseModel, Text, dict], optional): expected output. Defaults to None.

Returns:
Agent: created Agent
"""
@@ -133,7 +147,8 @@ def create(

if tasks is not None:
warnings.warn(
"The 'tasks' parameter is deprecated and will be removed in a future version. " "Use 'workflow_tasks' instead.",
"The 'tasks' parameter is deprecated and will be removed in a future version. "
"Use 'workflow_tasks' instead.",
DeprecationWarning,
stacklevel=2,
)
@@ -144,8 +159,8 @@
payload = {
"name": name,
"assets": [build_tool_payload(tool) for tool in tools],
"description": description,
"instructions": instructions or description,
"description": to_literal_text(description),
"instructions": to_literal_text(instructions) if instructions is not None else description,
"supplier": supplier,
"version": version,
"llmId": llm_id,
@@ -228,6 +243,17 @@ def create_workflow_task(
expected_output: Text,
dependencies: Optional[List[Text]] = None,
) -> WorkflowTask:
"""Create a new workflow task for an agent.

Args:
name (Text): Name of the task
description (Text): Description of what the task does
expected_output (Text): Expected output format or content
dependencies (Optional[List[Text]], optional): List of task names this task depends on. Defaults to None.

Returns:
WorkflowTask: Created workflow task object
"""
dependencies = [] if dependencies is None else list(dependencies)
return WorkflowTask(
name=name,
@@ -238,6 +264,11 @@

@classmethod
def create_task(cls, *args, **kwargs):
"""Create a workflow task (deprecated - use create_workflow_task instead).

.. deprecated::
Use :meth:`create_workflow_task` instead.
"""
warnings.warn(
"The 'create_task' method is deprecated and will be removed in a future version. "
"Use 'create_workflow_task' instead.",
@@ -351,7 +382,7 @@ def create_sql_tool(
tables: Optional[List[Text]] = None,
enable_commit: bool = False,
) -> SQLTool:
"""Create a new SQL tool
"""Create a new SQL tool.

Args:
name (Text): name of the tool
@@ -361,6 +392,7 @@
schema (Optional[Text], optional): database schema description
tables (Optional[List[Text]], optional): table names to work with (optional)
enable_commit (bool, optional): enable to modify the database (optional)

Returns:
SQLTool: created SQLTool

@@ -403,7 +435,9 @@ def create_sql_tool(
# Already the correct type, no conversion needed
pass
else:
raise SQLToolError(f"Source type must be either a string or DatabaseSourceType enum, got {type(source_type)}")
raise SQLToolError(
f"Source type must be either a string or DatabaseSourceType enum, got {type(source_type)}"
)

database_path = None # Final database path to pass to SQLTool

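For illustration, a hedged usage sketch of the task API documented above (method names and parameters come from this diff; any create() arguments beyond those shown, such as required LLM settings, are omitted):

    from aixplain.factories.agent_factory import AgentFactory

    fetch = AgentFactory.create_workflow_task(
        name="fetch_ticket",
        description="Fetch the latest support ticket",
        expected_output="JSON with the ticket id and body",
    )
    summarize = AgentFactory.create_workflow_task(
        name="summarize_ticket",
        description="Summarize the fetched ticket",
        expected_output="A two-sentence summary",
        dependencies=["fetch_ticket"],  # task names this task depends on
    )
    agent = AgentFactory.create(
        name="Ticket Triage Agent",
        description="Triages incoming support tickets",
        workflow_tasks=[fetch, summarize],  # replaces the deprecated tasks=
    )
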
5 changes: 4 additions & 1 deletion aixplain/factories/team_agent_factory/__init__.py
@@ -227,10 +227,13 @@ def _setup_llm_and_tool(
team_agent = build_team_agent(payload=internal_payload, agents=agent_list, api_key=api_key)
team_agent.validate(raise_exception=True)
response = "Unspecified error"
inspectors = team_agent.inspectors
inspector_targets = team_agent.inspector_targets
try:
payload["inspectors"] = [
inspector.model_dump(by_alias=True) for inspector in inspectors
]  # convert Inspector object to dict
payload["inspectorTargets"] = inspector_targets
logging.debug(f"Start service for POST Create TeamAgent - {url} - {headers} - {json.dumps(payload)}")
r = _request_with_retry("post", url, headers=headers, json=payload)
response = r.json()
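For context, model_dump(by_alias=True) is Pydantic v2's serializer with field aliases applied; a generic illustration follows (the real Inspector schema is not shown in this diff, so the field name here is hypothetical):

    from pydantic import BaseModel, Field

    class Example(BaseModel):
        model_id: str = Field(alias="modelId")

    # Construct by alias, then serialize with aliases applied:
    assert Example(modelId="abc").model_dump(by_alias=True) == {"modelId": "abc"}
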
48 changes: 47 additions & 1 deletion aixplain/factories/team_agent_factory/utils.py
@@ -10,11 +10,12 @@
from aixplain.modules.agent.agent_task import AgentTask
from aixplain.modules.agent.tool.model_tool import ModelTool
from aixplain.modules.team_agent import TeamAgent, InspectorTarget
from aixplain.modules.team_agent.inspector import Inspector
from aixplain.modules.team_agent.inspector import Inspector, InspectorAction, InspectorAuto, InspectorPolicy, InspectorOutput
from aixplain.factories.agent_factory import AgentFactory
from aixplain.factories.model_factory import ModelFactory
from aixplain.modules.model.model_parameters import ModelParameters
from aixplain.modules.agent.output_format import OutputFormat
from aixplain.modules.model.response import ModelResponse

GPT_4o_ID = "6646261c6eb563165658bbb1"
SUPPORTED_TOOLS = ["llm", "website_search", "website_scrape", "website_crawl", "serper_search"]
@@ -154,6 +155,51 @@ def get_cached_model(model_id: str) -> any:
elif tool["description"] == "mentalist":
mentalist_llm = llm

resolved_model_id = payload.get("llmId") or GPT_4o_ID
has_quality_check = any(
(getattr(ins, "name", "") or "").lower() == "qualitycheckinspector"
for ins in inspectors
)
if not has_quality_check:
try:
def process_response(model_response: ModelResponse, input_content: str) -> InspectorOutput:
critiques = model_response.data
action = InspectorAction.RERUN
return InspectorOutput(critiques=critiques, content_edited=input_content, action=action)

default_inspector = Inspector(
name="QualityCheckInspector",
model_id=resolved_model_id,
model_params={"prompt": "Analyze content to ensure correctness of response"},
policy=process_response
)

inspectors = [default_inspector] + inspectors
inspector_targets = payload.get("inspectorTargets", inspector_targets if 'inspector_targets' in locals() else [])
if isinstance(inspector_targets, (str, InspectorTarget)):
inspector_targets = [inspector_targets]
normalized = []
for t in inspector_targets:
if isinstance(t, InspectorTarget):
normalized.append(t)
elif isinstance(t, str):
try:
normalized.append(InspectorTarget(t.lower()))
except Exception:
logging.warning(f"Ignoring unknown inspector target: {t!r}")
else:
logging.warning(f"Ignoring inspector target with unexpected type: {type(t)}")

if InspectorTarget.STEPS not in normalized:
normalized.append(InspectorTarget.STEPS)

inspector_targets = normalized

except Exception as e:
logging.warning(f"Failed to add default QualityCheckInspector: {e}")

team_agent = TeamAgent(
id=payload.get("id", ""),
name=payload.get("name", ""),
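For illustration, a hedged sketch of a caller-supplied inspector mirroring the default added above (imports and constructor arguments are taken from this diff; InspectorAction.RERUN is the only action exercised here, so other members are not assumed):

    from aixplain.modules.model.response import ModelResponse
    from aixplain.modules.team_agent.inspector import (
        Inspector,
        InspectorAction,
        InspectorOutput,
    )

    def rerun_on_critique(model_response: ModelResponse, input_content: str) -> InspectorOutput:
        # Pass the inspector model's critiques through and request a rerun.
        return InspectorOutput(
            critiques=model_response.data,
            content_edited=input_content,
            action=InspectorAction.RERUN,
        )

    strict_inspector = Inspector(
        name="StrictFactCheckInspector",
        model_id="6646261c6eb563165658bbb1",  # GPT_4o_ID in this file
        model_params={"prompt": "Flag factual errors in the response"},
        policy=rerun_on_critique,
    )
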
4 changes: 3 additions & 1 deletion aixplain/modules/agent/__init__.py
@@ -325,6 +325,7 @@ def run(
output_format (OutputFormat, optional): response format. If not provided, uses the format set during initialization.
expected_output (Union[BaseModel, Text, dict], optional): expected output. Defaults to None.
trace_request (bool, optional): return the request id for tracing the request. Defaults to False.

Returns:
Dict: parsed output from model
"""
@@ -427,6 +428,7 @@ def run_async(
output_format (ResponseFormat, optional): response format. Defaults to TEXT.
evolve (Union[Dict[str, Any], EvolveParam, None], optional): evolve the agent configuration. Can be a dictionary, EvolveParam instance, or None.
trace_request (bool, optional): return the request id for tracing the request. Defaults to False.

Returns:
dict: polling URL in response
"""
@@ -490,7 +492,7 @@ def run_async(
input_data = process_variables(query, data, parameters, self.instructions)
if expected_output is None:
expected_output = self.expected_output
if expected_output is not None and issubclass(expected_output, BaseModel):
if expected_output is not None and isinstance(expected_output, type) and issubclass(expected_output, BaseModel):
expected_output = expected_output.model_json_schema()
expected_output = normalize_expected_output(expected_output)
# Use instance output_format if none provided
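The isinstance(expected_output, type) guard above matters because issubclass() raises TypeError when handed an instance rather than a class; a minimal sketch of the shapes expected_output can take (per the docstring: BaseModel, Text, or dict):

    from pydantic import BaseModel

    class Ticket(BaseModel):
        id: int
        body: str

    # A BaseModel subclass passes both checks and becomes a JSON schema:
    assert isinstance(Ticket, type) and issubclass(Ticket, BaseModel)
    # A dict (or str) fails isinstance(..., type), so issubclass() is never
    # called and the value goes straight to normalize_expected_output:
    assert not isinstance({"id": 1}, type)
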
49 changes: 32 additions & 17 deletions aixplain/modules/model/llm_model.py
@@ -1,7 +1,4 @@
__author__ = "lucaspavanelli"

"""
Copyright 2024 The aiXplain SDK authors
"""Copyright 2024 The aiXplain SDK authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -20,6 +17,8 @@
Description:
Large Language Model Class
"""

__author__ = "lucaspavanelli"
import time
import logging
import traceback
@@ -63,7 +62,7 @@ def __init__(
function: Optional[Function] = None,
is_subscribed: bool = False,
cost: Optional[Dict] = None,
temperature: float = 0.001,
temperature: Optional[float] = None,
function_type: Optional[FunctionType] = FunctionType.AI,
**additional_info,
) -> None:
@@ -79,14 +78,16 @@
function (Function, optional): Model's AI function. Must be Function.TEXT_GENERATION.
is_subscribed (bool, optional): Whether the user is subscribed. Defaults to False.
cost (Dict, optional): Cost of the model. Defaults to None.
temperature (float, optional): Default temperature for text generation. Defaults to 0.001.
temperature (Optional[float], optional): Default temperature for text generation. Defaults to None.
function_type (FunctionType, optional): Type of the function. Defaults to FunctionType.AI.
**additional_info: Any additional model info to be saved.

Raises:
AssertionError: If function is not Function.TEXT_GENERATION.
"""
assert function == Function.TEXT_GENERATION, "LLM only supports large language models (i.e. text generation function)"
assert function == Function.TEXT_GENERATION, (
"LLM only supports large language models (i.e. text generation function)"
)
super().__init__(
id=id,
name=name,
@@ -112,12 +113,13 @@ def run(
history: Optional[List[Dict]] = None,
temperature: Optional[float] = None,
max_tokens: int = 128,
top_p: float = 1.0,
top_p: Optional[float] = None,
name: Text = "model_process",
timeout: float = 300,
parameters: Optional[Dict] = None,
wait_time: float = 0.5,
stream: bool = False,
response_format: Optional[Text] = None,
) -> Union[ModelResponse, ModelResponseStreamer]:
"""Run the LLM model synchronously to generate text.

@@ -138,8 +140,8 @@
Defaults to None.
max_tokens (int, optional): Maximum number of tokens to generate.
Defaults to 128.
top_p (float, optional): Nucleus sampling parameter. Only tokens with cumulative
probability < top_p are considered. Defaults to 1.0.
top_p (Optional[float], optional): Nucleus sampling parameter. Only tokens with cumulative
probability < top_p are considered. Defaults to None.
name (Text, optional): Identifier for this model run. Useful for logging.
Defaults to "model_process".
timeout (float, optional): Maximum time in seconds to wait for completion.
Expand All @@ -150,6 +152,8 @@ def run(
Defaults to 0.5.
stream (bool, optional): Whether to stream the model's output tokens.
Defaults to False.
response_format (Optional[Text], optional): Specifies the desired output
    structure or format of the model's response. Defaults to None.

Returns:
Union[ModelResponse, ModelResponseStreamer]: If stream=False, returns a ModelResponse
@@ -150,6 +152,8 @@
parameters.setdefault("context", context)
parameters.setdefault("prompt", prompt)
parameters.setdefault("history", history)
parameters.setdefault("temperature", temperature if temperature is not None else self.temperature)
temp_value = temperature if temperature is not None else self.temperature
if temp_value is not None:
parameters.setdefault("temperature", temp_value)
parameters.setdefault("max_tokens", max_tokens)
parameters.setdefault("top_p", top_p)
if top_p is not None:
parameters.setdefault("top_p", top_p)
parameters.setdefault("response_format", response_format)

if stream:
return self.run_stream(data=data, parameters=parameters)
@@ -210,9 +218,10 @@ def run_async(
history: Optional[List[Dict]] = None,
temperature: Optional[float] = None,
max_tokens: int = 128,
top_p: float = 1.0,
top_p: Optional[float] = None,
name: Text = "model_process",
parameters: Optional[Dict] = None,
response_format: Optional[Text] = None,
) -> ModelResponse:
"""Run the LLM model asynchronously to generate text.

@@ -233,12 +242,14 @@
Defaults to None.
max_tokens (int, optional): Maximum number of tokens to generate.
Defaults to 128.
top_p (float, optional): Nucleus sampling parameter. Only tokens with cumulative
probability < top_p are considered. Defaults to 1.0.
top_p (Optional[float], optional): Nucleus sampling parameter. Only tokens with cumulative
probability < top_p are considered. Defaults to None.
name (Text, optional): Identifier for this model run. Useful for logging.
Defaults to "model_process".
parameters (Optional[Dict], optional): Additional model-specific parameters.
Defaults to None.
response_format (Optional[Text], optional): Desired output format specification.
Defaults to None.

Returns:
ModelResponse: A response object containing:
@@ -261,9 +272,13 @@
parameters.setdefault("context", context)
parameters.setdefault("prompt", prompt)
parameters.setdefault("history", history)
parameters.setdefault("temperature", temperature if temperature is not None else self.temperature)
temp_value = temperature if temperature is not None else self.temperature
if temp_value is not None:
parameters.setdefault("temperature", temp_value)
parameters.setdefault("max_tokens", max_tokens)
parameters.setdefault("top_p", top_p)
if top_p is not None:
parameters.setdefault("top_p", top_p)
parameters.setdefault("response_format", response_format)
payload = build_payload(data=data, parameters=parameters)
response = call_run_endpoint(payload=payload, url=url, api_key=self.api_key)
return ModelResponse(
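For illustration, a standalone sketch of the new parameter handling (it mirrors the diff: unset sampling knobs are no longer forced into the payload, so provider-side defaults apply; the helper name is hypothetical):

    def build_sampling_params(temperature=None, top_p=None, model_default_temperature=None):
        parameters = {}
        # Call-time temperature wins; otherwise fall back to the model default.
        temp_value = temperature if temperature is not None else model_default_temperature
        if temp_value is not None:
            parameters["temperature"] = temp_value
        if top_p is not None:
            parameters["top_p"] = top_p
        return parameters

    assert build_sampling_params() == {}                  # backend defaults apply
    assert build_sampling_params(top_p=0.9) == {"top_p": 0.9}
    assert build_sampling_params(model_default_temperature=0.7) == {"temperature": 0.7}
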
2 changes: 1 addition & 1 deletion aixplain/modules/team_agent/inspector.py
@@ -58,7 +58,7 @@ class InspectorOutput(BaseModel):

class InspectorAuto(str, Enum):
"""A list of keywords for inspectors configured automatically in the backend."""

ALIGNMENT = "alignment"
CORRECTNESS = "correctness"

def get_name(self) -> Text:
Expand Down
2 changes: 1 addition & 1 deletion tests/functional/agent/agent_functional_test.py
@@ -122,7 +122,7 @@ def test_python_interpreter_tool(delete_agents_and_team_agents, AgentFactory):
assert len(response["data"]["intermediate_steps"]) > 0
intermediate_step = response["data"]["intermediate_steps"][0]
assert len(intermediate_step["tool_steps"]) > 0
assert intermediate_step["tool_steps"][0]["tool"] == "Custom Code Tool"
assert intermediate_step["tool_steps"][0]["tool"] == "Python Code Interpreter Tool"
agent.delete()


Expand Down