Skip to content

Commit

Permalink
feat: Support positional arguments
Browse files Browse the repository at this point in the history
Resolves: flyteorg/flyte#5320
Signed-off-by: Chi-Sheng Liu <chishengliu@chishengliu.com>
  • Loading branch information
MortalHappiness committed Jun 20, 2024
1 parent 32bab30 commit 4f8c3eb
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 31 deletions.
22 changes: 11 additions & 11 deletions flytekit/core/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ class Interface(object):

def __init__(
self,
inputs: Union[Optional[Dict[str, Type]], Optional[Dict[str, Tuple[Type, Any]]]] = None,
outputs: Union[Optional[Dict[str, Type]], Optional[Dict[str, Optional[Type]]]] = None,
inputs: Union[OrderedDict[str, Type], OrderedDict[str, Tuple[Type, Any]], None] = None,
outputs: Optional[OrderedDict[str, Optional[Type]]] = None,
output_tuple_name: Optional[str] = None,
docstring: Optional[Docstring] = None,
):
Expand All @@ -67,14 +67,14 @@ def __init__(
primarily used when handling one-element NamedTuples.
:param docstring: Docstring of the annotated @task or @workflow from which the interface derives from.
"""
self._inputs: Union[Dict[str, Tuple[Type, Any]], Dict[str, Type]] = {} # type: ignore
self._inputs: Union[OrderedDict[str, Tuple[Type, Any]], OrderedDict[str, Type]] = OrderedDict() # type: ignore
if inputs:
for k, v in inputs.items():
if type(v) is tuple and len(cast(Tuple, v)) > 1:
self._inputs[k] = v # type: ignore
else:
self._inputs[k] = (v, None) # type: ignore
self._outputs = outputs if outputs else {} # type: ignore
self._outputs = outputs if outputs else OrderedDict() # type: ignore
self._output_tuple_name = output_tuple_name

if outputs:
Expand Down Expand Up @@ -123,8 +123,8 @@ def output_tuple_name(self) -> Optional[str]:
return self._output_tuple_name

@property
def inputs(self) -> Dict[str, type]:
r = {}
def inputs(self) -> OrderedDict[str, type]:
r = OrderedDict()
for k, v in self._inputs.items():
r[k] = v[0]
return r
Expand Down Expand Up @@ -375,7 +375,7 @@ def transform_function_to_interface(fn: typing.Callable, docstring: Optional[Doc
outputs = extract_return_annotation(return_annotation)
for k, v in outputs.items():
outputs[k] = v # type: ignore
inputs: Dict[str, Tuple[Type, Any]] = OrderedDict()
inputs: OrderedDict[str, Tuple[Type, Any]] = OrderedDict()
for k, v in signature.parameters.items(): # type: ignore
annotation = type_hints.get(k, None)
default = v.default if v.default is not inspect.Parameter.empty else None
Expand Down Expand Up @@ -446,7 +446,7 @@ def output_name_generator(length: int) -> Generator[str, None, None]:
yield default_output_name(x)


def extract_return_annotation(return_annotation: Union[Type, Tuple, None]) -> Dict[str, Type]:
def extract_return_annotation(return_annotation: Union[Type, Tuple, None]) -> OrderedDict[str, Type]:
"""
The purpose of this function is to sort out whether a function is returning one thing, or multiple things, and to
name the outputs accordingly, either by using our default name function, or from a typing.NamedTuple.
Expand Down Expand Up @@ -481,7 +481,7 @@ def t(a: int, b: str) -> Dict[str, int]: ...
# Handle Option 6
# We can think about whether we should add a default output name with type None in the future.
if return_annotation in (None, type(None), inspect.Signature.empty):
return {}
return OrderedDict()

# This statement results in true for typing.Namedtuple, single and void return types, so this
# handles Options 1, 2. Even though NamedTuple for us is multi-valued, it's a single value for Python
Expand All @@ -491,7 +491,7 @@ def t(a: int, b: str) -> Dict[str, int]: ...
bases = return_annotation.__bases__ # type: ignore
if len(bases) == 1 and bases[0] == tuple and hasattr(return_annotation, "_fields"):
logger.debug(f"Task returns named tuple {return_annotation}")
return dict(get_type_hints(cast(Type, return_annotation), include_extras=True))
return OrderedDict(get_type_hints(cast(Type, return_annotation), include_extras=True))

if hasattr(return_annotation, "__origin__") and return_annotation.__origin__ is tuple: # type: ignore
# Handle option 3
Expand All @@ -511,7 +511,7 @@ def t(a: int, b: str) -> Dict[str, int]: ...
else:
# Handle all other single return types
logger.debug(f"Task returns unnamed native tuple {return_annotation}")
return {default_output_name(): cast(Type, return_annotation)}
return OrderedDict([(default_output_name(), cast(Type, return_annotation))])


def remap_shared_output_descriptions(output_descriptions: Dict[str, str], outputs: Dict[str, Type]) -> Dict[str, str]:
Expand Down
40 changes: 20 additions & 20 deletions flytekit/core/promise.py
Original file line number Diff line number Diff line change
Expand Up @@ -1202,19 +1202,22 @@ def flyte_entity_call_handler(
#. Start a local execution - This means that we're not already in a local workflow execution, which means that
we should expect inputs to be native Python values and that we should return Python native values.
"""
# Sanity checks
# Only keyword args allowed
if len(args) > 0:
raise _user_exceptions.FlyteAssertion(
f"When calling tasks, only keyword args are supported. "
f"Aborting execution as detected {len(args)} positional args {args}"
)
# Make sure arguments are part of interface
for k, v in kwargs.items():
if k not in cast(SupportsNodeCreation, entity).python_interface.inputs:
raise AssertionError(
f"Received unexpected keyword argument '{k}' in function '{cast(SupportsNodeCreation, entity).name}'"
)
if k not in entity.python_interface.inputs:
raise AssertionError(f"Received unexpected keyword argument '{k}' in function '{entity.name}'")

# Check if we have more arguments than expected
if len(args) > len(entity.python_interface.inputs):
raise AssertionError(
f"Received more arguments than expected in function '{entity.name}'. Expected {len(entity.python_interface.inputs)} but got {len(args)}"
)

# Convert args to kwargs
for arg, input_name in zip(args, entity.python_interface.inputs.keys()):
if input_name in kwargs:
raise AssertionError(f"Got multiple values for argument '{input_name}' in function '{entity.name}'")
kwargs[input_name] = arg

ctx = FlyteContextManager.current_context()
if ctx.execution_state and (
Expand All @@ -1234,15 +1237,12 @@ def flyte_entity_call_handler(
child_ctx.execution_state
and child_ctx.execution_state.branch_eval_mode == BranchEvalMode.BRANCH_SKIPPED
):
if (
len(cast(SupportsNodeCreation, entity).python_interface.inputs) > 0
or len(cast(SupportsNodeCreation, entity).python_interface.outputs) > 0
):
output_names = list(cast(SupportsNodeCreation, entity).python_interface.outputs.keys())
if len(entity.python_interface.inputs) > 0 or len(entity.python_interface.outputs) > 0:
output_names = list(entity.python_interface.outputs.keys())
if len(output_names) == 0:
return VoidPromise(entity.name)
vals = [Promise(var, None) for var in output_names]
return create_task_output(vals, cast(SupportsNodeCreation, entity).python_interface)
return create_task_output(vals, entity.python_interface)
else:
return None
return cast(LocallyExecutable, entity).local_execute(ctx, **kwargs)
Expand All @@ -1255,7 +1255,7 @@ def flyte_entity_call_handler(
cast(ExecutionParameters, child_ctx.user_space_params)._decks = []
result = cast(LocallyExecutable, entity).local_execute(child_ctx, **kwargs)

expected_outputs = len(cast(SupportsNodeCreation, entity).python_interface.outputs)
expected_outputs = len(entity.python_interface.outputs)
if expected_outputs == 0:
if result is None or isinstance(result, VoidPromise):
return None
Expand All @@ -1268,10 +1268,10 @@ def flyte_entity_call_handler(
if (1 < expected_outputs == len(cast(Tuple[Promise], result))) or (
result is not None and expected_outputs == 1
):
return create_native_named_tuple(ctx, result, cast(SupportsNodeCreation, entity).python_interface)
return create_native_named_tuple(ctx, result, entity.python_interface)

raise AssertionError(
f"Expected outputs and actual outputs do not match."
f"Result {result}. "
f"Python interface: {cast(SupportsNodeCreation, entity).python_interface}"
f"Python interface: {entity.python_interface}"
)

0 comments on commit 4f8c3eb

Please sign in to comment.