diff --git a/aixplain/factories/pipeline_factory/utils.py b/aixplain/factories/pipeline_factory/utils.py index c9291031..08954571 100644 --- a/aixplain/factories/pipeline_factory/utils.py +++ b/aixplain/factories/pipeline_factory/utils.py @@ -88,7 +88,7 @@ def build_from_response(response: Dict, load_architecture: bool = False) -> Pipe data_type=custom_input.get("dataType"), code=custom_input["code"], value=custom_input.get("value"), - is_required=custom_input.get("isRequired", False), + is_required=custom_input.get("isRequired", True), ) node.number = node_json["number"] node.label = node_json["label"] diff --git a/aixplain/modules/pipeline/designer/base.py b/aixplain/modules/pipeline/designer/base.py index a925840f..08d4c8c5 100644 --- a/aixplain/modules/pipeline/designer/base.py +++ b/aixplain/modules/pipeline/designer/base.py @@ -1,3 +1,4 @@ +import re from typing import ( List, Union, @@ -11,7 +12,7 @@ from aixplain.enums import DataType from .enums import NodeType, ParamType - +from .utils import find_prompt_params if TYPE_CHECKING: from .pipeline import DesignerPipeline @@ -280,14 +281,31 @@ def __getitem__(self, code: str) -> Param: return param raise KeyError(f"Parameter with code '{code}' not found.") + def special_prompt_handling(self, code: str, value: str) -> None: + """ + This method will handle the special prompt handling for asset nodes + having `text-generation` function type. + """ + from .nodes import AssetNode + + if isinstance(self.node, AssetNode) and self.node.asset.function == "text-generation": + if code == "prompt": + matches = find_prompt_params(value) + for match in matches: + self.node.inputs.create_param(match, DataType.TEXT, is_required=True) + + def set_param_value(self, code: str, value: str) -> None: + self.special_prompt_handling(code, value) + self[code].value = value + def __setitem__(self, code: str, value: str) -> None: # set param value on set item to avoid setting it manually - self[code].value = value + self.set_param_value(code, value) def __setattr__(self, name: str, value: any) -> None: # set param value on attribute assignment to avoid setting it manually if isinstance(value, str) and hasattr(self, name): - self[name].value = value + self.set_param_value(name, value) else: super().__setattr__(name, value) diff --git a/aixplain/modules/pipeline/designer/pipeline.py b/aixplain/modules/pipeline/designer/pipeline.py index ece5ac0c..79013590 100644 --- a/aixplain/modules/pipeline/designer/pipeline.py +++ b/aixplain/modules/pipeline/designer/pipeline.py @@ -6,7 +6,7 @@ from .nodes import AssetNode, Decision, Script, Input, Output, Router, Route, BareReconstructor, BareSegmentor, BareMetric from .enums import NodeType, RouteType, Operation from .mixins import OutputableMixin - +from .utils import find_prompt_params T = TypeVar("T", bound="AssetNode") @@ -125,6 +125,24 @@ def is_param_set(self, node, param): """ return param.value or self.is_param_linked(node, param) + def special_prompt_validation(self, node: Node): + """ + This method will handle the special rule for asset nodes having + `text-generation` function type where if any prompt variable exists + then the `text` param is not required but the prompt param are. + + :param node: the node + :raises ValueError: if the pipeline is not valid + """ + if isinstance(node, AssetNode) and node.asset.function == "text-generation": + if self.is_param_set(node, node.inputs.prompt): + matches = find_prompt_params(node.inputs.prompt.value) + if matches: + node.inputs.text.is_required = False + for match in matches: + if match not in node.inputs: + raise ValueError(f"Param {match} of node {node.label} should be defined and set") + def validate_params(self): """ This method will check if all required params are either set or linked @@ -132,6 +150,7 @@ def validate_params(self): :raises ValueError: if the pipeline is not valid """ for node in self.nodes: + self.special_prompt_validation(node) for param in node.inputs: if param.is_required and not self.is_param_set(node, param): raise ValueError(f"Param {param.code} of node {node.label} is required") diff --git a/aixplain/modules/pipeline/designer/utils.py b/aixplain/modules/pipeline/designer/utils.py new file mode 100644 index 00000000..250d5501 --- /dev/null +++ b/aixplain/modules/pipeline/designer/utils.py @@ -0,0 +1,13 @@ +import re +from typing import List + + +def find_prompt_params(prompt: str) -> List[str]: + """ + This method will find the prompt parameters in the prompt string. + + :param prompt: the prompt string + :return: list of prompt parameters + """ + param_regex = re.compile(r"\{\{([^\}]+)\}\}") + return param_regex.findall(prompt) diff --git a/tests/unit/designer_unit_test.py b/tests/unit/designer_unit_test.py index 824fd162..57276a20 100644 --- a/tests/unit/designer_unit_test.py +++ b/tests/unit/designer_unit_test.py @@ -1,6 +1,5 @@ import pytest -import unittest.mock as mock - +from unittest.mock import patch, Mock, call from aixplain.enums import DataType from aixplain.modules.pipeline.designer.base import ( @@ -21,7 +20,7 @@ from aixplain.modules.pipeline.designer.mixins import LinkableMixin from aixplain.modules.pipeline.designer.pipeline import DesignerPipeline - +from aixplain.modules.pipeline.designer.base import find_prompt_params def test_create_node(): @@ -30,7 +29,7 @@ def test_create_node(): class BareNode(Node): pass - with mock.patch("aixplain.modules.pipeline.designer.Node.attach_to") as mock_attach_to: + with patch("aixplain.modules.pipeline.designer.Node.attach_to") as mock_attach_to: node = BareNode(number=3, label="FOO") mock_attach_to.assert_not_called() assert isinstance(node.inputs, Inputs) @@ -48,7 +47,7 @@ class FooNode(Node[FooNodeInputs, FooNodeOutputs]): inputs_class = FooNodeInputs outputs_class = FooNodeOutputs - with mock.patch("aixplain.modules.pipeline.designer.Node.attach_to") as mock_attach_to: + with patch("aixplain.modules.pipeline.designer.Node.attach_to") as mock_attach_to: node = FooNode(pipeline=pipeline, number=3, label="FOO") mock_attach_to.assert_called_once_with(pipeline) assert isinstance(node.inputs, FooNodeInputs) @@ -115,8 +114,8 @@ class AssetNode(Node): node = AssetNode() - with mock.patch.object(node.inputs, "serialize") as mock_inputs_serialize: - with mock.patch.object(node.outputs, "serialize") as mock_outputs_serialize: + with patch.object(node.inputs, "serialize") as mock_inputs_serialize: + with patch.object(node.outputs, "serialize") as mock_outputs_serialize: assert node.serialize() == { "number": node.number, "type": NodeType.ASSET, @@ -145,7 +144,7 @@ def test_create_param(): class TypedParam(Param): param_type = ParamType.INPUT - with mock.patch("aixplain.modules.pipeline.designer.Param.attach_to") as mock_attach_to: + with patch("aixplain.modules.pipeline.designer.Param.attach_to") as mock_attach_to: param = TypedParam( code="param", data_type=DataType.TEXT, @@ -158,7 +157,7 @@ class TypedParam(Param): assert param.value == "foo" assert param.param_type == ParamType.INPUT - with mock.patch("aixplain.modules.pipeline.designer.Param.attach_to") as mock_attach_to: + with patch("aixplain.modules.pipeline.designer.Param.attach_to") as mock_attach_to: param = TypedParam( code="param", data_type=DataType.TEXT, @@ -175,7 +174,7 @@ class TypedParam(Param): class UnTypedParam(Param): pass - with mock.patch("aixplain.modules.pipeline.designer.Param.attach_to") as mock_attach_to: + with patch("aixplain.modules.pipeline.designer.Param.attach_to") as mock_attach_to: param = UnTypedParam( code="param", data_type=DataType.TEXT, @@ -186,7 +185,7 @@ class UnTypedParam(Param): assert param.param_type == ParamType.OUTPUT - with mock.patch("aixplain.modules.pipeline.designer.Param.attach_to") as mock_attach_to: + with patch("aixplain.modules.pipeline.designer.Param.attach_to") as mock_attach_to: param = UnTypedParam( code="param", data_type=DataType.TEXT, @@ -202,7 +201,7 @@ class AssetNode(Node): node = AssetNode() - with mock.patch("aixplain.modules.pipeline.designer.Param.attach_to") as mock_attach_to: + with patch("aixplain.modules.pipeline.designer.Param.attach_to") as mock_attach_to: param = UnTypedParam( code="param", data_type=DataType.TEXT, @@ -226,7 +225,7 @@ class AssetNode(Node): node = AssetNode() - with mock.patch("aixplain.modules.pipeline.designer.Param.attach_to") as mock_attach_to: + with patch("aixplain.modules.pipeline.designer.Param.attach_to") as mock_attach_to: param = param_cls(code="param", data_type=DataType.TEXT, value="foo", node=node) mock_attach_to.assert_called_once_with(node) assert param.code == "param" @@ -253,7 +252,7 @@ class NoTypeParam(Param): input = InputParam(code="input", data_type=DataType.TEXT, value="foo") - with mock.patch.object(node.inputs, "add_param") as mock_add_param: + with patch.object(node.inputs, "add_param") as mock_add_param: input.attach_to(node) mock_add_param.assert_called_once_with(input) assert input.node is node @@ -265,7 +264,7 @@ class NoTypeParam(Param): output = OutputParam(code="output", data_type=DataType.TEXT, value="bar") - with mock.patch.object(node.outputs, "add_param") as mock_add_param: + with patch.object(node.outputs, "add_param") as mock_add_param: output.attach_to(node) mock_add_param.assert_called_once_with(output) assert output.node is node @@ -304,7 +303,7 @@ class AssetNode(Node, LinkableMixin): output = OutputParam(code="output", data_type=DataType.TEXT, value="bar", node=a) input = InputParam(code="input", data_type=DataType.TEXT, value="foo", node=b) - with mock.patch.object(input, "back_link") as mock_back_link: + with patch.object(input, "back_link") as mock_back_link: output.link(input) mock_back_link.assert_called_once_with(output) @@ -342,7 +341,7 @@ class AssetNode(Node, LinkableMixin): output = OutputParam(code="output", data_type=DataType.TEXT, value="bar", node=a) input = InputParam(code="input", data_type=DataType.TEXT, value="foo", node=b) - with mock.patch.object(a, "link") as mock_link: + with patch.object(a, "link") as mock_link: input.back_link(output) mock_link.assert_called_once_with(b, output, input) @@ -400,7 +399,7 @@ class AssetNode(Node, LinkableMixin): pipeline = DesignerPipeline() - with mock.patch("aixplain.modules.pipeline.designer.Link.attach_to") as mock_attach_to: + with patch("aixplain.modules.pipeline.designer.Link.attach_to") as mock_attach_to: link = Link( from_node=a, to_node=b, @@ -431,8 +430,8 @@ class AssetNode(Node, LinkableMixin): to_param="input", ) - with mock.patch.object(a, "attach_to") as mock_a_attach_to: - with mock.patch.object(b, "attach_to") as mock_b_attach_to: + with patch.object(a, "attach_to") as mock_a_attach_to: + with patch.object(b, "attach_to") as mock_b_attach_to: link.attach_to(pipeline) mock_a_attach_to.assert_called_once_with(pipeline) mock_b_attach_to.assert_called_once_with(pipeline) @@ -451,8 +450,8 @@ class AssetNode(Node, LinkableMixin): to_param="input", ) - with mock.patch.object(a, "attach_to") as mock_a_attach_to: - with mock.patch.object(b, "attach_to") as mock_b_attach_to: + with patch.object(a, "attach_to") as mock_a_attach_to: + with patch.object(b, "attach_to") as mock_b_attach_to: link.attach_to(pipeline) mock_a_attach_to.assert_not_called() mock_b_attach_to.assert_not_called() @@ -555,8 +554,8 @@ class AssetNode(Node): param_proxy = ParamProxy(node) - with mock.patch.object(param_proxy, "_create_param") as mock_create_param: - with mock.patch.object(param_proxy, "add_param") as mock_add_param: + with patch.object(param_proxy, "_create_param") as mock_create_param: + with patch.object(param_proxy, "add_param") as mock_add_param: param = param_proxy.create_param("foo", DataType.TEXT, "bar", is_required=True) mock_create_param.assert_called_once_with("foo", DataType.TEXT, "bar") mock_add_param.assert_called_once_with(param) @@ -588,6 +587,48 @@ class FooParam(Param): assert "'bar'" in str(excinfo.value) +def test_param_proxy_set_param_value(): + prompt_param = Mock(spec=Param, code="prompt") + param_proxy = ParamProxy(Mock()) + param_proxy._params = [prompt_param] + with patch.object(param_proxy, "special_prompt_handling") as mock_special_prompt_handling: + param_proxy.set_param_value("prompt", "hello {{foo}}") + mock_special_prompt_handling.assert_called_once_with("prompt", "hello {{foo}}") + assert prompt_param.value == "hello {{foo}}" + + +def test_param_proxy_special_prompt_handling(): + from aixplain.modules.pipeline.designer.nodes import AssetNode + + asset_node = Mock(spec=AssetNode, asset=Mock(function="text-generation")) + param_proxy = ParamProxy(asset_node) + with patch( + "aixplain.modules.pipeline.designer.base.find_prompt_params" + ) as mock_find_prompt_params: + mock_find_prompt_params.return_value = [] + param_proxy.special_prompt_handling("prompt", "hello {{foo}}") + mock_find_prompt_params.assert_called_once_with("hello {{foo}}") + asset_node.inputs.create_param.assert_not_called() + asset_node.reset_mock() + mock_find_prompt_params.reset_mock() + + mock_find_prompt_params.return_value = ["foo"] + param_proxy.special_prompt_handling("prompt", "hello {{foo}}") + mock_find_prompt_params.assert_called_once_with("hello {{foo}}") + asset_node.inputs.create_param.assert_called_once_with("foo", DataType.TEXT, is_required=True) + asset_node.reset_mock() + mock_find_prompt_params.reset_mock() + + mock_find_prompt_params.return_value = ["foo", "bar"] + param_proxy.special_prompt_handling("prompt", "hello {{foo}} {{bar}}") + mock_find_prompt_params.assert_called_once_with("hello {{foo}} {{bar}}") + assert asset_node.inputs.create_param.call_count == 2 + assert asset_node.inputs.create_param.call_args_list == [ + call("foo", DataType.TEXT, is_required=True), + call("bar", DataType.TEXT, is_required=True), + ] + + def test_node_link(): class AssetNode(Node, LinkableMixin): type: NodeType = NodeType.ASSET @@ -623,7 +664,7 @@ class AssetNode(Node): type: NodeType = NodeType.ASSET node1 = AssetNode() - with mock.patch.object(node1, "attach_to") as mock_attach_to: + with patch.object(node1, "attach_to") as mock_attach_to: pipeline.add_node(node1) mock_attach_to.assert_called_once_with(pipeline) @@ -636,14 +677,14 @@ class InputNode(Node): node = InputNode() - with mock.patch.object(pipeline, "add_node") as mock_add_node: + with patch.object(pipeline, "add_node") as mock_add_node: pipeline.add_nodes(node) assert mock_add_node.call_count == 1 node1 = InputNode() node2 = InputNode() - with mock.patch.object(pipeline, "add_node") as mock_add_node: + with patch.object(pipeline, "add_node") as mock_add_node: pipeline.add_nodes(node1, node2) assert mock_add_node.call_count == 2 @@ -662,6 +703,95 @@ class AssetNode(Node): link = Link(from_node=a, to_node=b, from_param="output", to_param="input") pipeline.add_link(link) - with mock.patch.object(link, "attach_to") as mock_attach_to: + with patch.object(link, "attach_to") as mock_attach_to: pipeline.add_link(link) mock_attach_to.assert_called_once_with(pipeline) + + +def test_pipeline_special_prompt_validation(): + from aixplain.modules.pipeline.designer.nodes import AssetNode + + pipeline = DesignerPipeline() + asset_node = Mock( + spec=AssetNode, + label="LLM(ID=1)", + asset=Mock(function="text-generation"), + inputs=Mock(prompt=Mock(value="hello {{foo}}"), text=Mock(is_required=True)), + ) + with patch.object(pipeline, "is_param_set") as mock_is_param_set: + mock_is_param_set.return_value = False + pipeline.special_prompt_validation(asset_node) + mock_is_param_set.assert_called_once_with(asset_node, asset_node.inputs.prompt) + assert asset_node.inputs.text.is_required is True + mock_is_param_set.reset_mock() + mock_is_param_set.return_value = True + with patch( + "aixplain.modules.pipeline.designer.pipeline.find_prompt_params" + ) as mock_find_prompt_params: + mock_find_prompt_params.return_value = [] + pipeline.special_prompt_validation(asset_node) + mock_is_param_set.assert_called_once_with( + asset_node, asset_node.inputs.prompt + ) + mock_find_prompt_params.assert_called_once_with( + asset_node.inputs.prompt.value + ) + assert asset_node.inputs.text.is_required is True + + mock_is_param_set.reset_mock() + mock_is_param_set.return_value = True + mock_find_prompt_params.reset_mock() + mock_find_prompt_params.return_value = ["foo"] + asset_node.inputs.__contains__ = Mock(return_value=False) + + with pytest.raises( + ValueError, + match="Param foo of node LLM\\(ID=1\\) should be defined and set", + ): + pipeline.special_prompt_validation(asset_node) + + mock_is_param_set.assert_called_once_with( + asset_node, asset_node.inputs.prompt + ) + mock_find_prompt_params.assert_called_once_with( + asset_node.inputs.prompt.value + ) + assert asset_node.inputs.text.is_required is False + + mock_is_param_set.reset_mock() + mock_is_param_set.return_value = True + mock_find_prompt_params.reset_mock() + mock_find_prompt_params.return_value = ["foo"] + asset_node.inputs.text.is_required = True + + asset_node.inputs.__contains__ = Mock(return_value=True) + pipeline.special_prompt_validation(asset_node) + mock_is_param_set.assert_called_once_with( + asset_node, asset_node.inputs.prompt + ) + mock_find_prompt_params.assert_called_once_with( + asset_node.inputs.prompt.value + ) + assert asset_node.inputs.text.is_required is False + + +@pytest.mark.parametrize( + "input, expected", + [ + ("hello {{foo}}", ["foo"]), + ("hello {{foo}} {{bar}}", ["foo", "bar"]), + ("hello {{foo}} {{bar}} {{baz}}", ["foo", "bar", "baz"]), + # no match cases + ("hello bar", []), + ("hello {{foo]] bar", []), + ("hello {foo} bar", []), + # edge cases + ("", []), + ("{{}}", []), + # interesting cases + ("hello {{foo {{bar}} baz}} {{bar}} {{baz}}", ["foo {{bar", "bar", "baz"]), + ], +) +def test_find_prompt_params(input, expected): + print(input, expected) + assert find_prompt_params(input) == expected