diff --git a/aixplain/factories/agent_factory/__init__.py b/aixplain/factories/agent_factory/__init__.py index 6076eef6..0c73637c 100644 --- a/aixplain/factories/agent_factory/__init__.py +++ b/aixplain/factories/agent_factory/__init__.py @@ -67,7 +67,7 @@ def create( if isinstance(tool, ModelTool): tool_payload.append( { - "function": tool.function.value if tool.function is not None else None, + "function": tool.function.value, "type": "model", "description": tool.description, "supplier": tool.supplier.value["code"] if tool.supplier else None, diff --git a/aixplain/modules/agent/tool/model_tool.py b/aixplain/modules/agent/tool/model_tool.py index a5acab30..e15a8bea 100644 --- a/aixplain/modules/agent/tool/model_tool.py +++ b/aixplain/modules/agent/tool/model_tool.py @@ -24,7 +24,6 @@ from aixplain.enums.function import Function from aixplain.enums.supplier import Supplier -from aixplain.factories.model_factory import ModelFactory from aixplain.modules.agent.tool import Tool from aixplain.modules.model import Model @@ -58,7 +57,6 @@ def __init__( if function is not None: if isinstance(function, str): function = Function(function) - self.function = function try: if isinstance(supplier, dict): @@ -68,9 +66,13 @@ def __init__( if model is not None: if isinstance(model, Text) is True: + from aixplain.factories.model_factory import ModelFactory + model = ModelFactory.get(model) + function = model.function if isinstance(model.supplier, Supplier): supplier = model.supplier model = model.id self.supplier = supplier self.model = model + self.function = function diff --git a/tests/functional/agent/agent_functional_test.py b/tests/functional/agent/agent_functional_test.py index cefd34c3..58d421c8 100644 --- a/tests/functional/agent/agent_functional_test.py +++ b/tests/functional/agent/agent_functional_test.py @@ -37,6 +37,9 @@ def run_input_map(request): def test_end2end(run_input_map): + for agent in AgentFactory.list()["results"]: + agent.delete() + tools = [] if "model_tools" in run_input_map: for tool in run_input_map["model_tools"]: @@ -47,6 +50,7 @@ def test_end2end(run_input_map): ]: tool["supplier"] = supplier break + print("TOOL: ", tool) tools.append(AgentFactory.create_model_tool(**tool)) if "pipeline_tools" in run_input_map: for tool in run_input_map["pipeline_tools"]: diff --git a/tests/unit/designer_test.py b/tests/unit/designer_unit_test.py similarity index 89% rename from tests/unit/designer_test.py rename to tests/unit/designer_unit_test.py index 766c7e54..824fd162 100644 --- a/tests/unit/designer_test.py +++ b/tests/unit/designer_unit_test.py @@ -30,9 +30,7 @@ def test_create_node(): class BareNode(Node): pass - with mock.patch( - "aixplain.modules.pipeline.designer.Node.attach_to" - ) as mock_attach_to: + with mock.patch("aixplain.modules.pipeline.designer.Node.attach_to") as mock_attach_to: node = BareNode(number=3, label="FOO") mock_attach_to.assert_not_called() assert isinstance(node.inputs, Inputs) @@ -50,9 +48,7 @@ class FooNode(Node[FooNodeInputs, FooNodeOutputs]): inputs_class = FooNodeInputs outputs_class = FooNodeOutputs - with mock.patch( - "aixplain.modules.pipeline.designer.Node.attach_to" - ) as mock_attach_to: + with mock.patch("aixplain.modules.pipeline.designer.Node.attach_to") as mock_attach_to: node = FooNode(pipeline=pipeline, number=3, label="FOO") mock_attach_to.assert_called_once_with(pipeline) assert isinstance(node.inputs, FooNodeInputs) @@ -120,9 +116,7 @@ class AssetNode(Node): node = AssetNode() with mock.patch.object(node.inputs, "serialize") as mock_inputs_serialize: - with mock.patch.object( - node.outputs, "serialize" - ) as mock_outputs_serialize: + with mock.patch.object(node.outputs, "serialize") as mock_outputs_serialize: assert node.serialize() == { "number": node.number, "type": NodeType.ASSET, @@ -148,13 +142,10 @@ class AssetNode(Node): def test_create_param(): - class TypedParam(Param): param_type = ParamType.INPUT - with mock.patch( - "aixplain.modules.pipeline.designer.Param.attach_to" - ) as mock_attach_to: + with mock.patch("aixplain.modules.pipeline.designer.Param.attach_to") as mock_attach_to: param = TypedParam( code="param", data_type=DataType.TEXT, @@ -167,9 +158,7 @@ class TypedParam(Param): assert param.value == "foo" assert param.param_type == ParamType.INPUT - with mock.patch( - "aixplain.modules.pipeline.designer.Param.attach_to" - ) as mock_attach_to: + with mock.patch("aixplain.modules.pipeline.designer.Param.attach_to") as mock_attach_to: param = TypedParam( code="param", data_type=DataType.TEXT, @@ -186,9 +175,7 @@ class TypedParam(Param): class UnTypedParam(Param): pass - with mock.patch( - "aixplain.modules.pipeline.designer.Param.attach_to" - ) as mock_attach_to: + with mock.patch("aixplain.modules.pipeline.designer.Param.attach_to") as mock_attach_to: param = UnTypedParam( code="param", data_type=DataType.TEXT, @@ -199,9 +186,7 @@ class UnTypedParam(Param): assert param.param_type == ParamType.OUTPUT - with mock.patch( - "aixplain.modules.pipeline.designer.Param.attach_to" - ) as mock_attach_to: + with mock.patch("aixplain.modules.pipeline.designer.Param.attach_to") as mock_attach_to: param = UnTypedParam( code="param", data_type=DataType.TEXT, @@ -217,9 +202,7 @@ class AssetNode(Node): node = AssetNode() - with mock.patch( - "aixplain.modules.pipeline.designer.Param.attach_to" - ) as mock_attach_to: + with mock.patch("aixplain.modules.pipeline.designer.Param.attach_to") as mock_attach_to: param = UnTypedParam( code="param", data_type=DataType.TEXT, @@ -243,12 +226,8 @@ class AssetNode(Node): node = AssetNode() - with mock.patch( - "aixplain.modules.pipeline.designer.Param.attach_to" - ) as mock_attach_to: - param = param_cls( - code="param", data_type=DataType.TEXT, value="foo", node=node - ) + with mock.patch("aixplain.modules.pipeline.designer.Param.attach_to") as mock_attach_to: + param = param_cls(code="param", data_type=DataType.TEXT, value="foo", node=node) mock_attach_to.assert_called_once_with(node) assert param.code == "param" assert param.data_type == DataType.TEXT @@ -322,12 +301,8 @@ class AssetNode(Node, LinkableMixin): assert "Param not registered as output" in str(excinfo.value) - output = OutputParam( - code="output", data_type=DataType.TEXT, value="bar", node=a - ) - input = InputParam( - code="input", data_type=DataType.TEXT, value="foo", node=b - ) + output = OutputParam(code="output", data_type=DataType.TEXT, value="bar", node=a) + input = InputParam(code="input", data_type=DataType.TEXT, value="foo", node=b) with mock.patch.object(input, "back_link") as mock_back_link: output.link(input) @@ -364,12 +339,8 @@ class AssetNode(Node, LinkableMixin): assert "Param not registered as input" in str(excinfo.value) - output = OutputParam( - code="output", data_type=DataType.TEXT, value="bar", node=a - ) - input = InputParam( - code="input", data_type=DataType.TEXT, value="foo", node=b - ) + output = OutputParam(code="output", data_type=DataType.TEXT, value="bar", node=a) + input = InputParam(code="input", data_type=DataType.TEXT, value="foo", node=b) with mock.patch.object(a, "link") as mock_link: input.back_link(output) @@ -429,9 +400,7 @@ class AssetNode(Node, LinkableMixin): pipeline = DesignerPipeline() - with mock.patch( - "aixplain.modules.pipeline.designer.Link.attach_to" - ) as mock_attach_to: + with mock.patch("aixplain.modules.pipeline.designer.Link.attach_to") as mock_attach_to: link = Link( from_node=a, to_node=b, @@ -588,12 +557,8 @@ class AssetNode(Node): with mock.patch.object(param_proxy, "_create_param") as mock_create_param: with mock.patch.object(param_proxy, "add_param") as mock_add_param: - param = param_proxy.create_param( - "foo", DataType.TEXT, "bar", is_required=True - ) - mock_create_param.assert_called_once_with( - "foo", DataType.TEXT, "bar" - ) + param = param_proxy.create_param("foo", DataType.TEXT, "bar", is_required=True) + mock_create_param.assert_called_once_with("foo", DataType.TEXT, "bar") mock_add_param.assert_called_once_with(param) assert param.is_required is True @@ -624,19 +589,14 @@ class FooParam(Param): def test_node_link(): - class AssetNode(Node, LinkableMixin): type: NodeType = NodeType.ASSET a = AssetNode() b = AssetNode() - output = OutputParam( - code="output", data_type=DataType.TEXT, value="bar", node=a - ) - input = InputParam( - code="input", data_type=DataType.TEXT, value="foo", node=b - ) + output = OutputParam(code="output", data_type=DataType.TEXT, value="bar", node=a) + input = InputParam(code="input", data_type=DataType.TEXT, value="foo", node=b) # here too lazy to mock Link class properly # checking the output instance instead