Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion aixplain/factories/pipeline_factory/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def build_from_response(response: Dict, load_architecture: bool = False) -> Pipe
data_type=custom_input.get("dataType"),
code=custom_input["code"],
value=custom_input.get("value"),
is_required=custom_input.get("isRequired", False),
is_required=custom_input.get("isRequired", True),
)
node.number = node_json["number"]
node.label = node_json["label"]
Expand Down
24 changes: 21 additions & 3 deletions aixplain/modules/pipeline/designer/base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import re
from typing import (
List,
Union,
Expand All @@ -11,7 +12,7 @@

from aixplain.enums import DataType
from .enums import NodeType, ParamType

from .utils import find_prompt_params

if TYPE_CHECKING:
from .pipeline import DesignerPipeline
Expand Down Expand Up @@ -280,14 +281,31 @@ def __getitem__(self, code: str) -> Param:
return param
raise KeyError(f"Parameter with code '{code}' not found.")

def special_prompt_handling(self, code: str, value: str) -> None:
"""
This method will handle the special prompt handling for asset nodes
having `text-generation` function type.
"""
from .nodes import AssetNode

if isinstance(self.node, AssetNode) and self.node.asset.function == "text-generation":
if code == "prompt":
matches = find_prompt_params(value)
for match in matches:
self.node.inputs.create_param(match, DataType.TEXT, is_required=True)

def set_param_value(self, code: str, value: str) -> None:
self.special_prompt_handling(code, value)
self[code].value = value

def __setitem__(self, code: str, value: str) -> None:
# set param value on set item to avoid setting it manually
self[code].value = value
self.set_param_value(code, value)

def __setattr__(self, name: str, value: any) -> None:
# set param value on attribute assignment to avoid setting it manually
if isinstance(value, str) and hasattr(self, name):
self[name].value = value
self.set_param_value(name, value)
else:
super().__setattr__(name, value)

Expand Down
21 changes: 20 additions & 1 deletion aixplain/modules/pipeline/designer/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from .nodes import AssetNode, Decision, Script, Input, Output, Router, Route, BareReconstructor, BareSegmentor, BareMetric
from .enums import NodeType, RouteType, Operation
from .mixins import OutputableMixin

from .utils import find_prompt_params

T = TypeVar("T", bound="AssetNode")

Expand Down Expand Up @@ -125,13 +125,32 @@ def is_param_set(self, node, param):
"""
return param.value or self.is_param_linked(node, param)

def special_prompt_validation(self, node: Node):
"""
This method will handle the special rule for asset nodes having
`text-generation` function type where if any prompt variable exists
then the `text` param is not required but the prompt param are.

:param node: the node
:raises ValueError: if the pipeline is not valid
"""
if isinstance(node, AssetNode) and node.asset.function == "text-generation":
if self.is_param_set(node, node.inputs.prompt):
matches = find_prompt_params(node.inputs.prompt.value)
if matches:
node.inputs.text.is_required = False
for match in matches:
if match not in node.inputs:
raise ValueError(f"Param {match} of node {node.label} should be defined and set")

def validate_params(self):
"""
This method will check if all required params are either set or linked

:raises ValueError: if the pipeline is not valid
"""
for node in self.nodes:
self.special_prompt_validation(node)
for param in node.inputs:
if param.is_required and not self.is_param_set(node, param):
raise ValueError(f"Param {param.code} of node {node.label} is required")
Expand Down
13 changes: 13 additions & 0 deletions aixplain/modules/pipeline/designer/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import re
from typing import List


def find_prompt_params(prompt: str) -> List[str]:
"""
This method will find the prompt parameters in the prompt string.

:param prompt: the prompt string
:return: list of prompt parameters
"""
param_regex = re.compile(r"\{\{([^\}]+)\}\}")
return param_regex.findall(prompt)
Loading