Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion aixplain/factories/agent_factory/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def create(
if isinstance(tool, ModelTool):
tool_payload.append(
{
"function": tool.function.value if tool.function is not None else None,
"function": tool.function.value,
"type": "model",
"description": tool.description,
"supplier": tool.supplier.value["code"] if tool.supplier else None,
Expand Down
6 changes: 4 additions & 2 deletions aixplain/modules/agent/tool/model_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@

from aixplain.enums.function import Function
from aixplain.enums.supplier import Supplier
from aixplain.factories.model_factory import ModelFactory
from aixplain.modules.agent.tool import Tool
from aixplain.modules.model import Model

Expand Down Expand Up @@ -58,7 +57,6 @@ def __init__(
if function is not None:
if isinstance(function, str):
function = Function(function)
self.function = function

try:
if isinstance(supplier, dict):
Expand All @@ -68,9 +66,13 @@ def __init__(

if model is not None:
if isinstance(model, Text) is True:
from aixplain.factories.model_factory import ModelFactory

model = ModelFactory.get(model)
function = model.function
if isinstance(model.supplier, Supplier):
supplier = model.supplier
model = model.id
self.supplier = supplier
self.model = model
self.function = function
4 changes: 4 additions & 0 deletions tests/functional/agent/agent_functional_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ def run_input_map(request):


def test_end2end(run_input_map):
for agent in AgentFactory.list()["results"]:
agent.delete()

tools = []
if "model_tools" in run_input_map:
for tool in run_input_map["model_tools"]:
Expand All @@ -47,6 +50,7 @@ def test_end2end(run_input_map):
]:
tool["supplier"] = supplier
break
print("TOOL: ", tool)
tools.append(AgentFactory.create_model_tool(**tool))
if "pipeline_tools" in run_input_map:
for tool in run_input_map["pipeline_tools"]:
Expand Down
78 changes: 19 additions & 59 deletions tests/unit/designer_test.py → tests/unit/designer_unit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,7 @@ def test_create_node():
class BareNode(Node):
pass

with mock.patch(
"aixplain.modules.pipeline.designer.Node.attach_to"
) as mock_attach_to:
with mock.patch("aixplain.modules.pipeline.designer.Node.attach_to") as mock_attach_to:
node = BareNode(number=3, label="FOO")
mock_attach_to.assert_not_called()
assert isinstance(node.inputs, Inputs)
Expand All @@ -50,9 +48,7 @@ class FooNode(Node[FooNodeInputs, FooNodeOutputs]):
inputs_class = FooNodeInputs
outputs_class = FooNodeOutputs

with mock.patch(
"aixplain.modules.pipeline.designer.Node.attach_to"
) as mock_attach_to:
with mock.patch("aixplain.modules.pipeline.designer.Node.attach_to") as mock_attach_to:
node = FooNode(pipeline=pipeline, number=3, label="FOO")
mock_attach_to.assert_called_once_with(pipeline)
assert isinstance(node.inputs, FooNodeInputs)
Expand Down Expand Up @@ -120,9 +116,7 @@ class AssetNode(Node):
node = AssetNode()

with mock.patch.object(node.inputs, "serialize") as mock_inputs_serialize:
with mock.patch.object(
node.outputs, "serialize"
) as mock_outputs_serialize:
with mock.patch.object(node.outputs, "serialize") as mock_outputs_serialize:
assert node.serialize() == {
"number": node.number,
"type": NodeType.ASSET,
Expand All @@ -148,13 +142,10 @@ class AssetNode(Node):


def test_create_param():

class TypedParam(Param):
param_type = ParamType.INPUT

with mock.patch(
"aixplain.modules.pipeline.designer.Param.attach_to"
) as mock_attach_to:
with mock.patch("aixplain.modules.pipeline.designer.Param.attach_to") as mock_attach_to:
param = TypedParam(
code="param",
data_type=DataType.TEXT,
Expand All @@ -167,9 +158,7 @@ class TypedParam(Param):
assert param.value == "foo"
assert param.param_type == ParamType.INPUT

with mock.patch(
"aixplain.modules.pipeline.designer.Param.attach_to"
) as mock_attach_to:
with mock.patch("aixplain.modules.pipeline.designer.Param.attach_to") as mock_attach_to:
param = TypedParam(
code="param",
data_type=DataType.TEXT,
Expand All @@ -186,9 +175,7 @@ class TypedParam(Param):
class UnTypedParam(Param):
pass

with mock.patch(
"aixplain.modules.pipeline.designer.Param.attach_to"
) as mock_attach_to:
with mock.patch("aixplain.modules.pipeline.designer.Param.attach_to") as mock_attach_to:
param = UnTypedParam(
code="param",
data_type=DataType.TEXT,
Expand All @@ -199,9 +186,7 @@ class UnTypedParam(Param):

assert param.param_type == ParamType.OUTPUT

with mock.patch(
"aixplain.modules.pipeline.designer.Param.attach_to"
) as mock_attach_to:
with mock.patch("aixplain.modules.pipeline.designer.Param.attach_to") as mock_attach_to:
param = UnTypedParam(
code="param",
data_type=DataType.TEXT,
Expand All @@ -217,9 +202,7 @@ class AssetNode(Node):

node = AssetNode()

with mock.patch(
"aixplain.modules.pipeline.designer.Param.attach_to"
) as mock_attach_to:
with mock.patch("aixplain.modules.pipeline.designer.Param.attach_to") as mock_attach_to:
param = UnTypedParam(
code="param",
data_type=DataType.TEXT,
Expand All @@ -243,12 +226,8 @@ class AssetNode(Node):

node = AssetNode()

with mock.patch(
"aixplain.modules.pipeline.designer.Param.attach_to"
) as mock_attach_to:
param = param_cls(
code="param", data_type=DataType.TEXT, value="foo", node=node
)
with mock.patch("aixplain.modules.pipeline.designer.Param.attach_to") as mock_attach_to:
param = param_cls(code="param", data_type=DataType.TEXT, value="foo", node=node)
mock_attach_to.assert_called_once_with(node)
assert param.code == "param"
assert param.data_type == DataType.TEXT
Expand Down Expand Up @@ -322,12 +301,8 @@ class AssetNode(Node, LinkableMixin):

assert "Param not registered as output" in str(excinfo.value)

output = OutputParam(
code="output", data_type=DataType.TEXT, value="bar", node=a
)
input = InputParam(
code="input", data_type=DataType.TEXT, value="foo", node=b
)
output = OutputParam(code="output", data_type=DataType.TEXT, value="bar", node=a)
input = InputParam(code="input", data_type=DataType.TEXT, value="foo", node=b)

with mock.patch.object(input, "back_link") as mock_back_link:
output.link(input)
Expand Down Expand Up @@ -364,12 +339,8 @@ class AssetNode(Node, LinkableMixin):

assert "Param not registered as input" in str(excinfo.value)

output = OutputParam(
code="output", data_type=DataType.TEXT, value="bar", node=a
)
input = InputParam(
code="input", data_type=DataType.TEXT, value="foo", node=b
)
output = OutputParam(code="output", data_type=DataType.TEXT, value="bar", node=a)
input = InputParam(code="input", data_type=DataType.TEXT, value="foo", node=b)

with mock.patch.object(a, "link") as mock_link:
input.back_link(output)
Expand Down Expand Up @@ -429,9 +400,7 @@ class AssetNode(Node, LinkableMixin):

pipeline = DesignerPipeline()

with mock.patch(
"aixplain.modules.pipeline.designer.Link.attach_to"
) as mock_attach_to:
with mock.patch("aixplain.modules.pipeline.designer.Link.attach_to") as mock_attach_to:
link = Link(
from_node=a,
to_node=b,
Expand Down Expand Up @@ -588,12 +557,8 @@ class AssetNode(Node):

with mock.patch.object(param_proxy, "_create_param") as mock_create_param:
with mock.patch.object(param_proxy, "add_param") as mock_add_param:
param = param_proxy.create_param(
"foo", DataType.TEXT, "bar", is_required=True
)
mock_create_param.assert_called_once_with(
"foo", DataType.TEXT, "bar"
)
param = param_proxy.create_param("foo", DataType.TEXT, "bar", is_required=True)
mock_create_param.assert_called_once_with("foo", DataType.TEXT, "bar")
mock_add_param.assert_called_once_with(param)
assert param.is_required is True

Expand Down Expand Up @@ -624,19 +589,14 @@ class FooParam(Param):


def test_node_link():

class AssetNode(Node, LinkableMixin):
type: NodeType = NodeType.ASSET

a = AssetNode()
b = AssetNode()

output = OutputParam(
code="output", data_type=DataType.TEXT, value="bar", node=a
)
input = InputParam(
code="input", data_type=DataType.TEXT, value="foo", node=b
)
output = OutputParam(code="output", data_type=DataType.TEXT, value="bar", node=a)
input = InputParam(code="input", data_type=DataType.TEXT, value="foo", node=b)

# here too lazy to mock Link class properly
# checking the output instance instead
Expand Down