From 4f8c3eb88ac3c99afd6360434aaaf428fa1e5de1 Mon Sep 17 00:00:00 2001 From: Chi-Sheng Liu Date: Thu, 20 Jun 2024 15:57:55 +0800 Subject: [PATCH] feat: Support positional arguments Resolves: flyteorg/flyte#5320 Signed-off-by: Chi-Sheng Liu --- flytekit/core/interface.py | 22 ++++++++++----------- flytekit/core/promise.py | 40 +++++++++++++++++++------------------- 2 files changed, 31 insertions(+), 31 deletions(-) diff --git a/flytekit/core/interface.py b/flytekit/core/interface.py index 13b6af2d4b..2f6f5d9108 100644 --- a/flytekit/core/interface.py +++ b/flytekit/core/interface.py @@ -52,8 +52,8 @@ class Interface(object): def __init__( self, - inputs: Union[Optional[Dict[str, Type]], Optional[Dict[str, Tuple[Type, Any]]]] = None, - outputs: Union[Optional[Dict[str, Type]], Optional[Dict[str, Optional[Type]]]] = None, + inputs: Union[OrderedDict[str, Type], OrderedDict[str, Tuple[Type, Any]], None] = None, + outputs: Optional[OrderedDict[str, Optional[Type]]] = None, output_tuple_name: Optional[str] = None, docstring: Optional[Docstring] = None, ): @@ -67,14 +67,14 @@ def __init__( primarily used when handling one-element NamedTuples. :param docstring: Docstring of the annotated @task or @workflow from which the interface derives from. """ - self._inputs: Union[Dict[str, Tuple[Type, Any]], Dict[str, Type]] = {} # type: ignore + self._inputs: Union[OrderedDict[str, Tuple[Type, Any]], OrderedDict[str, Type]] = OrderedDict() # type: ignore if inputs: for k, v in inputs.items(): if type(v) is tuple and len(cast(Tuple, v)) > 1: self._inputs[k] = v # type: ignore else: self._inputs[k] = (v, None) # type: ignore - self._outputs = outputs if outputs else {} # type: ignore + self._outputs = outputs if outputs else OrderedDict() # type: ignore self._output_tuple_name = output_tuple_name if outputs: @@ -123,8 +123,8 @@ def output_tuple_name(self) -> Optional[str]: return self._output_tuple_name @property - def inputs(self) -> Dict[str, type]: - r = {} + def inputs(self) -> OrderedDict[str, type]: + r = OrderedDict() for k, v in self._inputs.items(): r[k] = v[0] return r @@ -375,7 +375,7 @@ def transform_function_to_interface(fn: typing.Callable, docstring: Optional[Doc outputs = extract_return_annotation(return_annotation) for k, v in outputs.items(): outputs[k] = v # type: ignore - inputs: Dict[str, Tuple[Type, Any]] = OrderedDict() + inputs: OrderedDict[str, Tuple[Type, Any]] = OrderedDict() for k, v in signature.parameters.items(): # type: ignore annotation = type_hints.get(k, None) default = v.default if v.default is not inspect.Parameter.empty else None @@ -446,7 +446,7 @@ def output_name_generator(length: int) -> Generator[str, None, None]: yield default_output_name(x) -def extract_return_annotation(return_annotation: Union[Type, Tuple, None]) -> Dict[str, Type]: +def extract_return_annotation(return_annotation: Union[Type, Tuple, None]) -> OrderedDict[str, Type]: """ The purpose of this function is to sort out whether a function is returning one thing, or multiple things, and to name the outputs accordingly, either by using our default name function, or from a typing.NamedTuple. @@ -481,7 +481,7 @@ def t(a: int, b: str) -> Dict[str, int]: ... # Handle Option 6 # We can think about whether we should add a default output name with type None in the future. if return_annotation in (None, type(None), inspect.Signature.empty): - return {} + return OrderedDict() # This statement results in true for typing.Namedtuple, single and void return types, so this # handles Options 1, 2. Even though NamedTuple for us is multi-valued, it's a single value for Python @@ -491,7 +491,7 @@ def t(a: int, b: str) -> Dict[str, int]: ... bases = return_annotation.__bases__ # type: ignore if len(bases) == 1 and bases[0] == tuple and hasattr(return_annotation, "_fields"): logger.debug(f"Task returns named tuple {return_annotation}") - return dict(get_type_hints(cast(Type, return_annotation), include_extras=True)) + return OrderedDict(get_type_hints(cast(Type, return_annotation), include_extras=True)) if hasattr(return_annotation, "__origin__") and return_annotation.__origin__ is tuple: # type: ignore # Handle option 3 @@ -511,7 +511,7 @@ def t(a: int, b: str) -> Dict[str, int]: ... else: # Handle all other single return types logger.debug(f"Task returns unnamed native tuple {return_annotation}") - return {default_output_name(): cast(Type, return_annotation)} + return OrderedDict([(default_output_name(), cast(Type, return_annotation))]) def remap_shared_output_descriptions(output_descriptions: Dict[str, str], outputs: Dict[str, Type]) -> Dict[str, str]: diff --git a/flytekit/core/promise.py b/flytekit/core/promise.py index 557d621dd4..c4f71eb2d6 100644 --- a/flytekit/core/promise.py +++ b/flytekit/core/promise.py @@ -1202,19 +1202,22 @@ def flyte_entity_call_handler( #. Start a local execution - This means that we're not already in a local workflow execution, which means that we should expect inputs to be native Python values and that we should return Python native values. """ - # Sanity checks - # Only keyword args allowed - if len(args) > 0: - raise _user_exceptions.FlyteAssertion( - f"When calling tasks, only keyword args are supported. " - f"Aborting execution as detected {len(args)} positional args {args}" - ) # Make sure arguments are part of interface for k, v in kwargs.items(): - if k not in cast(SupportsNodeCreation, entity).python_interface.inputs: - raise AssertionError( - f"Received unexpected keyword argument '{k}' in function '{cast(SupportsNodeCreation, entity).name}'" - ) + if k not in entity.python_interface.inputs: + raise AssertionError(f"Received unexpected keyword argument '{k}' in function '{entity.name}'") + + # Check if we have more arguments than expected + if len(args) > len(entity.python_interface.inputs): + raise AssertionError( + f"Received more arguments than expected in function '{entity.name}'. Expected {len(entity.python_interface.inputs)} but got {len(args)}" + ) + + # Convert args to kwargs + for arg, input_name in zip(args, entity.python_interface.inputs.keys()): + if input_name in kwargs: + raise AssertionError(f"Got multiple values for argument '{input_name}' in function '{entity.name}'") + kwargs[input_name] = arg ctx = FlyteContextManager.current_context() if ctx.execution_state and ( @@ -1234,15 +1237,12 @@ def flyte_entity_call_handler( child_ctx.execution_state and child_ctx.execution_state.branch_eval_mode == BranchEvalMode.BRANCH_SKIPPED ): - if ( - len(cast(SupportsNodeCreation, entity).python_interface.inputs) > 0 - or len(cast(SupportsNodeCreation, entity).python_interface.outputs) > 0 - ): - output_names = list(cast(SupportsNodeCreation, entity).python_interface.outputs.keys()) + if len(entity.python_interface.inputs) > 0 or len(entity.python_interface.outputs) > 0: + output_names = list(entity.python_interface.outputs.keys()) if len(output_names) == 0: return VoidPromise(entity.name) vals = [Promise(var, None) for var in output_names] - return create_task_output(vals, cast(SupportsNodeCreation, entity).python_interface) + return create_task_output(vals, entity.python_interface) else: return None return cast(LocallyExecutable, entity).local_execute(ctx, **kwargs) @@ -1255,7 +1255,7 @@ def flyte_entity_call_handler( cast(ExecutionParameters, child_ctx.user_space_params)._decks = [] result = cast(LocallyExecutable, entity).local_execute(child_ctx, **kwargs) - expected_outputs = len(cast(SupportsNodeCreation, entity).python_interface.outputs) + expected_outputs = len(entity.python_interface.outputs) if expected_outputs == 0: if result is None or isinstance(result, VoidPromise): return None @@ -1268,10 +1268,10 @@ def flyte_entity_call_handler( if (1 < expected_outputs == len(cast(Tuple[Promise], result))) or ( result is not None and expected_outputs == 1 ): - return create_native_named_tuple(ctx, result, cast(SupportsNodeCreation, entity).python_interface) + return create_native_named_tuple(ctx, result, entity.python_interface) raise AssertionError( f"Expected outputs and actual outputs do not match." f"Result {result}. " - f"Python interface: {cast(SupportsNodeCreation, entity).python_interface}" + f"Python interface: {entity.python_interface}" )