From b7719b4219b08c5d4991537577880a3aaaa79440 Mon Sep 17 00:00:00 2001 From: "zaina.abushaban" Date: Mon, 22 Sep 2025 10:00:31 +0300 Subject: [PATCH 1/2] normalized expected output from basemodel --- aixplain/factories/agent_factory/__init__.py | 4 +++- .../factories/team_agent_factory/__init__.py | 3 ++- aixplain/modules/agent/__init__.py | 2 ++ aixplain/modules/agent/agent_task.py | 1 + aixplain/modules/team_agent/__init__.py | 2 ++ aixplain/utils/convert_datatype_utils.py | 17 ++++++++++++++++- 6 files changed, 26 insertions(+), 3 deletions(-) diff --git a/aixplain/factories/agent_factory/__init__.py b/aixplain/factories/agent_factory/__init__.py index 09e65670..05ef1f2b 100644 --- a/aixplain/factories/agent_factory/__init__.py +++ b/aixplain/factories/agent_factory/__init__.py @@ -34,6 +34,7 @@ from aixplain.modules.agent.tool.pipeline_tool import PipelineTool from aixplain.modules.agent.tool.python_interpreter_tool import PythonInterpreterTool from aixplain.modules.agent.tool.custom_python_code_tool import CustomPythonCodeTool +from aixplain.utils.convert_datatype_utils import normalize_expected_output from aixplain.modules.agent.tool.sql_tool import ( SQLTool, ) @@ -167,7 +168,8 @@ def create( payload["llm"] = llm if expected_output: - payload["expectedOutput"] = expected_output + payload["expectedOutput"] = normalize_expected_output(expected_output) + if output_format: if isinstance(output_format, OutputFormat): output_format = output_format.value diff --git a/aixplain/factories/team_agent_factory/__init__.py b/aixplain/factories/team_agent_factory/__init__.py index b14542d6..60050813 100644 --- a/aixplain/factories/team_agent_factory/__init__.py +++ b/aixplain/factories/team_agent_factory/__init__.py @@ -32,6 +32,7 @@ from aixplain.modules.team_agent.inspector import Inspector from aixplain.utils import config from aixplain.factories.team_agent_factory.utils import build_team_agent +from aixplain.utils.convert_datatype_utils import normalize_expected_output from aixplain.utils.request_utils import _request_with_retry from aixplain.modules.model.llm_model import LLM from aixplain.utils.llm_utils import get_llm_instance @@ -217,7 +218,7 @@ def _setup_llm_and_tool( if mentalist_llm is not None: internal_payload["mentalist_llm"] = mentalist_llm if expected_output: - payload["expectedOutput"] = expected_output + payload["expectedOutput"] = normalize_expected_output(expected_output) if output_format: if isinstance(output_format, OutputFormat): output_format = output_format.value diff --git a/aixplain/modules/agent/__init__.py b/aixplain/modules/agent/__init__.py index 090b4c0e..69c22f03 100644 --- a/aixplain/modules/agent/__init__.py +++ b/aixplain/modules/agent/__init__.py @@ -46,6 +46,7 @@ from aixplain.modules.agent.evolve_param import EvolveParam, validate_evolve_param from urllib.parse import urljoin from aixplain.modules.model.llm_model import LLM +from aixplain.utils.convert_datatype_utils import normalize_expected_output from aixplain.utils import config from aixplain.modules.mixins import DeployableMixin @@ -488,6 +489,7 @@ def run_async( expected_output = self.expected_output if expected_output is not None and issubclass(expected_output, BaseModel): expected_output = expected_output.model_json_schema() + expected_output = normalize_expected_output(expected_output) # Use instance output_format if none provided if output_format is None: output_format = self.output_format diff --git a/aixplain/modules/agent/agent_task.py b/aixplain/modules/agent/agent_task.py index 433d58c0..5b9ce3b9 100644 --- a/aixplain/modules/agent/agent_task.py +++ b/aixplain/modules/agent/agent_task.py @@ -1,4 +1,5 @@ from typing import List, Text, Union, Optional +from aixplain.utils.convert_datatype_utils import normalize_expected_output class WorkflowTask: diff --git a/aixplain/modules/team_agent/__init__.py b/aixplain/modules/team_agent/__init__.py index fcc5df7d..6e338bd7 100644 --- a/aixplain/modules/team_agent/__init__.py +++ b/aixplain/modules/team_agent/__init__.py @@ -45,6 +45,7 @@ from aixplain.modules.agent.utils import process_variables, validate_history from aixplain.modules.team_agent.inspector import Inspector from aixplain.modules.team_agent.evolver_response_data import EvolverResponseData +from aixplain.utils.convert_datatype_utils import normalize_expected_output from aixplain.utils import config from aixplain.utils.request_utils import _request_with_retry from aixplain.modules.model.llm_model import LLM @@ -365,6 +366,7 @@ def run_async( expected_output = self.expected_output if expected_output is not None and issubclass(expected_output, BaseModel): expected_output = expected_output.model_json_schema() + expected_output = normalize_expected_output(expected_output) if output_format is None: output_format = self.output_format if isinstance(output_format, OutputFormat): diff --git a/aixplain/utils/convert_datatype_utils.py b/aixplain/utils/convert_datatype_utils.py index 08864cf6..9d18b4c6 100644 --- a/aixplain/utils/convert_datatype_utils.py +++ b/aixplain/utils/convert_datatype_utils.py @@ -16,7 +16,8 @@ from typing import Union, Dict, List from aixplain.modules.metadata import MetaData - +import json +from pydantic import BaseModel def dict_to_metadata(metadatas: List[Union[Dict, MetaData]]) -> None: @@ -38,3 +39,17 @@ def dict_to_metadata(metadatas: List[Union[Dict, MetaData]]) -> None: metadatas[i] = MetaData(**metadatas[i]) except TypeError: raise TypeError(f"Data Asset Onboarding Error: One or more elements in the metadata_schema are not well-structured") + + + +def normalize_expected_output(obj): + if isinstance(obj, type) and issubclass(obj, BaseModel): + return obj.model_json_schema() if hasattr(obj, "model_json_schema") else obj.schema() + + if isinstance(obj, BaseModel): + return json.loads(obj.model_dump_json()) if hasattr(obj, "model_dump_json") else json.loads(obj.json()) + + if isinstance(obj, (dict, str)) or obj is None: + return obj + + return json.loads(json.dumps(obj)) \ No newline at end of file From 58472b66e498eaff2173143c0770687b9b925196 Mon Sep 17 00:00:00 2001 From: "zaina.abushaban" Date: Wed, 24 Sep 2025 14:22:44 +0300 Subject: [PATCH 2/2] added tests --- aixplain/modules/team_agent/__init__.py | 4 +- tests/unit/agent/agent_test.py | 61 ++++++++++++++++++++++++- 2 files changed, 61 insertions(+), 4 deletions(-) diff --git a/aixplain/modules/team_agent/__init__.py b/aixplain/modules/team_agent/__init__.py index 6e338bd7..74b9a14c 100644 --- a/aixplain/modules/team_agent/__init__.py +++ b/aixplain/modules/team_agent/__init__.py @@ -361,11 +361,9 @@ def run_async( headers = {"x-api-key": self.api_key, "Content-Type": "application/json"} # build query - input_data = process_variables(query, data, parameters, self.description) if expected_output is None: expected_output = self.expected_output - if expected_output is not None and issubclass(expected_output, BaseModel): - expected_output = expected_output.model_json_schema() + input_data = process_variables(query, data, parameters, self.description) expected_output = normalize_expected_output(expected_output) if output_format is None: output_format = self.output_format diff --git a/tests/unit/agent/agent_test.py b/tests/unit/agent/agent_test.py index 24a78b7c..540fb734 100644 --- a/tests/unit/agent/agent_test.py +++ b/tests/unit/agent/agent_test.py @@ -15,7 +15,8 @@ from aixplain.enums import Function, Supplier from aixplain.modules.agent.agent_response import AgentResponse from aixplain.modules.agent.agent_response_data import AgentResponseData - +from pydantic import BaseModel +import json def test_fail_no_data_query(): agent = Agent( @@ -1552,3 +1553,61 @@ def test_agent_serialization_status_enum(status_input, expected_output): agent_dict = agent.to_dict() assert agent_dict["status"] == expected_output + + + +class _EOUser(BaseModel): + id: int + name: str = "alice" + +def _schema_for(cls): + return cls.model_json_schema() if hasattr(cls, "model_json_schema") else cls.schema() + + +def test_run_normalizes_expected_output_pydantic_class_in_execution_params(): + agent = Agent( + id="eo-agent-norm-1", + name="EO Agent", + description="ensure expected_output is normalized", + expected_output=_EOUser, + ) + + run_url = urljoin(config.BACKEND_URL, f"sdk/agents/{agent.id}/run") + agent.url = run_url + + with requests_mock.Mocker() as mock: + headers = {"x-api-key": config.AIXPLAIN_API_KEY, "Content-Type": "application/json"} + mock.post(run_url, headers=headers, json={"data": "dummy", "status": "IN_PROGRESS"}) + + agent.run_async(data={"query": "hi"}) + + sent = mock.last_request.json() + assert "executionParams" in sent + assert "expectedOutput" in sent["executionParams"] + + eo = sent["executionParams"]["expectedOutput"] + assert isinstance(eo, dict), "expectedOutput must be a JSON-serializable dict" + assert eo == _schema_for(_EOUser), "expectedOutput schema doesn't match model schema" + + +def test_run_normalizes_expected_output_tuple_to_list_in_execution_params(): + agent = Agent( + id="eo-agent-norm-2", + name="EO Agent 2", + description="tuple normalization", + expected_output=(1, 2, 3), + ) + + run_url = urljoin(config.BACKEND_URL, f"sdk/agents/{agent.id}/run") + agent.url = run_url + + with requests_mock.Mocker() as mock: + headers = {"x-api-key": config.AIXPLAIN_API_KEY, "Content-Type": "application/json"} + mock.post(run_url, headers=headers, json={"data": "dummy", "status": "IN_PROGRESS"}) + + agent.run_async(data={"query": "hi"}) + + sent = mock.last_request.json() + assert "executionParams" in sent + assert "expectedOutput" in sent["executionParams"] + assert sent["executionParams"]["expectedOutput"] == [1, 2, 3], "tuple should normalize to list"