Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion aixplain/factories/agent_factory/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from aixplain.modules.agent.tool.pipeline_tool import PipelineTool
from aixplain.modules.agent.tool.python_interpreter_tool import PythonInterpreterTool
from aixplain.modules.agent.tool.custom_python_code_tool import CustomPythonCodeTool
from aixplain.utils.convert_datatype_utils import normalize_expected_output
from aixplain.modules.agent.tool.sql_tool import (
SQLTool,
)
Expand Down Expand Up @@ -167,7 +168,8 @@ def create(
payload["llm"] = llm

if expected_output:
payload["expectedOutput"] = expected_output
payload["expectedOutput"] = normalize_expected_output(expected_output)

if output_format:
if isinstance(output_format, OutputFormat):
output_format = output_format.value
Expand Down
3 changes: 2 additions & 1 deletion aixplain/factories/team_agent_factory/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from aixplain.modules.team_agent.inspector import Inspector
from aixplain.utils import config
from aixplain.factories.team_agent_factory.utils import build_team_agent
from aixplain.utils.convert_datatype_utils import normalize_expected_output
from aixplain.utils.request_utils import _request_with_retry
from aixplain.modules.model.llm_model import LLM
from aixplain.utils.llm_utils import get_llm_instance
Expand Down Expand Up @@ -217,7 +218,7 @@ def _setup_llm_and_tool(
if mentalist_llm is not None:
internal_payload["mentalist_llm"] = mentalist_llm
if expected_output:
payload["expectedOutput"] = expected_output
payload["expectedOutput"] = normalize_expected_output(expected_output)
if output_format:
if isinstance(output_format, OutputFormat):
output_format = output_format.value
Expand Down
2 changes: 2 additions & 0 deletions aixplain/modules/agent/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
from aixplain.modules.agent.evolve_param import EvolveParam, validate_evolve_param
from urllib.parse import urljoin
from aixplain.modules.model.llm_model import LLM
from aixplain.utils.convert_datatype_utils import normalize_expected_output

from aixplain.utils import config
from aixplain.modules.mixins import DeployableMixin
Expand Down Expand Up @@ -488,6 +489,7 @@ def run_async(
expected_output = self.expected_output
if expected_output is not None and issubclass(expected_output, BaseModel):
expected_output = expected_output.model_json_schema()
expected_output = normalize_expected_output(expected_output)
# Use instance output_format if none provided
if output_format is None:
output_format = self.output_format
Expand Down
1 change: 1 addition & 0 deletions aixplain/modules/agent/agent_task.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import List, Text, Union, Optional
from aixplain.utils.convert_datatype_utils import normalize_expected_output


class WorkflowTask:
Expand Down
6 changes: 3 additions & 3 deletions aixplain/modules/team_agent/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
from aixplain.modules.agent.utils import process_variables, validate_history
from aixplain.modules.team_agent.inspector import Inspector
from aixplain.modules.team_agent.evolver_response_data import EvolverResponseData
from aixplain.utils.convert_datatype_utils import normalize_expected_output
from aixplain.utils import config
from aixplain.utils.request_utils import _request_with_retry
from aixplain.modules.model.llm_model import LLM
Expand Down Expand Up @@ -360,11 +361,10 @@ def run_async(
headers = {"x-api-key": self.api_key, "Content-Type": "application/json"}

# build query
input_data = process_variables(query, data, parameters, self.description)
if expected_output is None:
expected_output = self.expected_output
if expected_output is not None and issubclass(expected_output, BaseModel):
expected_output = expected_output.model_json_schema()
input_data = process_variables(query, data, parameters, self.description)
expected_output = normalize_expected_output(expected_output)
if output_format is None:
output_format = self.output_format
if isinstance(output_format, OutputFormat):
Expand Down
17 changes: 16 additions & 1 deletion aixplain/utils/convert_datatype_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@

from typing import Union, Dict, List
from aixplain.modules.metadata import MetaData

import json
from pydantic import BaseModel

def dict_to_metadata(metadatas: List[Union[Dict, MetaData]]) -> None:

Expand All @@ -38,3 +39,17 @@ def dict_to_metadata(metadatas: List[Union[Dict, MetaData]]) -> None:
metadatas[i] = MetaData(**metadatas[i])
except TypeError:
raise TypeError(f"Data Asset Onboarding Error: One or more elements in the metadata_schema are not well-structured")



def normalize_expected_output(obj):
if isinstance(obj, type) and issubclass(obj, BaseModel):
return obj.model_json_schema() if hasattr(obj, "model_json_schema") else obj.schema()

if isinstance(obj, BaseModel):
return json.loads(obj.model_dump_json()) if hasattr(obj, "model_dump_json") else json.loads(obj.json())

if isinstance(obj, (dict, str)) or obj is None:
return obj

return json.loads(json.dumps(obj))
61 changes: 60 additions & 1 deletion tests/unit/agent/agent_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
from aixplain.enums import Function, Supplier
from aixplain.modules.agent.agent_response import AgentResponse
from aixplain.modules.agent.agent_response_data import AgentResponseData

from pydantic import BaseModel
import json

def test_fail_no_data_query():
agent = Agent(
Expand Down Expand Up @@ -1552,3 +1553,61 @@ def test_agent_serialization_status_enum(status_input, expected_output):

agent_dict = agent.to_dict()
assert agent_dict["status"] == expected_output



class _EOUser(BaseModel):
id: int
name: str = "alice"

def _schema_for(cls):
return cls.model_json_schema() if hasattr(cls, "model_json_schema") else cls.schema()


def test_run_normalizes_expected_output_pydantic_class_in_execution_params():
agent = Agent(
id="eo-agent-norm-1",
name="EO Agent",
description="ensure expected_output is normalized",
expected_output=_EOUser,
)

run_url = urljoin(config.BACKEND_URL, f"sdk/agents/{agent.id}/run")
agent.url = run_url

with requests_mock.Mocker() as mock:
headers = {"x-api-key": config.AIXPLAIN_API_KEY, "Content-Type": "application/json"}
mock.post(run_url, headers=headers, json={"data": "dummy", "status": "IN_PROGRESS"})

agent.run_async(data={"query": "hi"})

sent = mock.last_request.json()
assert "executionParams" in sent
assert "expectedOutput" in sent["executionParams"]

eo = sent["executionParams"]["expectedOutput"]
assert isinstance(eo, dict), "expectedOutput must be a JSON-serializable dict"
assert eo == _schema_for(_EOUser), "expectedOutput schema doesn't match model schema"


def test_run_normalizes_expected_output_tuple_to_list_in_execution_params():
agent = Agent(
id="eo-agent-norm-2",
name="EO Agent 2",
description="tuple normalization",
expected_output=(1, 2, 3),
)

run_url = urljoin(config.BACKEND_URL, f"sdk/agents/{agent.id}/run")
agent.url = run_url

with requests_mock.Mocker() as mock:
headers = {"x-api-key": config.AIXPLAIN_API_KEY, "Content-Type": "application/json"}
mock.post(run_url, headers=headers, json={"data": "dummy", "status": "IN_PROGRESS"})

agent.run_async(data={"query": "hi"})

sent = mock.last_request.json()
assert "executionParams" in sent
assert "expectedOutput" in sent["executionParams"]
assert sent["executionParams"]["expectedOutput"] == [1, 2, 3], "tuple should normalize to list"