Skip to content

Commit

Permalink
feat: Support positional arguments
Browse files Browse the repository at this point in the history
Resolves: flyteorg/flyte#5320
Signed-off-by: Chi-Sheng Liu <chishengliu@chishengliu.com>
  • Loading branch information
MortalHappiness committed Jun 20, 2024
1 parent 32bab30 commit a69ed93
Show file tree
Hide file tree
Showing 10 changed files with 67 additions and 61 deletions.
9 changes: 5 additions & 4 deletions flytekit/core/base_sql_task.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import re
from typing import Any, Dict, Optional, Tuple, Type, TypeVar
from collections import OrderedDict
from typing import Any, Optional, Tuple, Type, TypeVar

from flytekit.core.base_task import PythonTask, TaskMetadata
from flytekit.core.interface import Interface
Expand All @@ -24,9 +25,9 @@ def __init__(
query_template: str,
task_config: Optional[T] = None,
task_type="sql_task",
inputs: Optional[Dict[str, Tuple[Type, Any]]] = None,
inputs: Optional[OrderedDict[str, Tuple[Type, Any]]] = None,
metadata: Optional[TaskMetadata] = None,
outputs: Optional[Dict[str, Type]] = None,
outputs: Optional[OrderedDict[str, Type]] = None,
**kwargs,
):
"""
Expand All @@ -36,7 +37,7 @@ def __init__(
super().__init__(
task_type=task_type,
name=name,
interface=Interface(inputs=inputs or {}, outputs=outputs or {}),
interface=Interface(inputs=inputs or OrderedDict(), outputs=outputs or OrderedDict()),
metadata=metadata,
task_config=task_config,
**kwargs,
Expand Down
2 changes: 1 addition & 1 deletion flytekit/core/container_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def __init__(
inputs: Optional[OrderedDict[str, Type]] = None,
metadata: Optional[TaskMetadata] = None,
arguments: Optional[List[str]] = None,
outputs: Optional[Dict[str, Type]] = None,
outputs: Optional[OrderedDict[str, Type]] = None,
requests: Optional[Resources] = None,
limits: Optional[Resources] = None,
input_data_dir: Optional[str] = None,
Expand Down
15 changes: 4 additions & 11 deletions flytekit/core/gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import datetime
import typing
from collections import OrderedDict
from typing import Tuple, Union

import click
Expand Down Expand Up @@ -49,11 +50,7 @@ def __init__(
self._python_interface = flyte_interface.Interface()
elif input_type:
# Waiting for user input, so the output of the node is whatever input the user provides.
self._python_interface = flyte_interface.Interface(
outputs={
"o0": self.input_type,
}
)
self._python_interface = flyte_interface.Interface(outputs=OrderedDict([("o0", self.input_type)]))
else:
# We don't know how to find the python interface here; approve() sets it below. See the code.
self._python_interface = None # type: ignore
Expand Down Expand Up @@ -205,12 +202,8 @@ def approve(upstream_item: Union[Tuple[Promise], Promise, VoidPromise], name: st

# In either case, we need a python interface
g._python_interface = flyte_interface.Interface(
inputs={
io_var_name: io_type,
},
outputs={
io_var_name: io_type,
},
inputs=OrderedDict([(io_var_name, io_type)]),
outputs=OrderedDict([(io_var_name, io_type)]),
)
kwargs = {io_var_name: upstream_item}

Expand Down
28 changes: 14 additions & 14 deletions flytekit/core/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ class Interface(object):

def __init__(
self,
inputs: Union[Optional[Dict[str, Type]], Optional[Dict[str, Tuple[Type, Any]]]] = None,
outputs: Union[Optional[Dict[str, Type]], Optional[Dict[str, Optional[Type]]]] = None,
inputs: Union[OrderedDict[str, Type], OrderedDict[str, Tuple[Type, Any]], None] = None,
outputs: Union[Dict[str, Type], Dict[str, Optional[Type]], None] = None,
output_tuple_name: Optional[str] = None,
docstring: Optional[Docstring] = None,
):
Expand All @@ -67,14 +67,14 @@ def __init__(
primarily used when handling one-element NamedTuples.
:param docstring: Docstring of the annotated @task or @workflow from which the interface derives.
"""
self._inputs: Union[Dict[str, Tuple[Type, Any]], Dict[str, Type]] = {} # type: ignore
self._inputs: Union[OrderedDict[str, Tuple[Type, Any]], OrderedDict[str, Type]] = OrderedDict() # type: ignore
if inputs:
for k, v in inputs.items():
if type(v) is tuple and len(cast(Tuple, v)) > 1:
self._inputs[k] = v # type: ignore
else:
self._inputs[k] = (v, None) # type: ignore
self._outputs = outputs if outputs else {} # type: ignore
self._outputs = outputs if outputs else OrderedDict() # type: ignore
self._output_tuple_name = output_tuple_name

if outputs:
Expand Down Expand Up @@ -123,8 +123,8 @@ def output_tuple_name(self) -> Optional[str]:
return self._output_tuple_name

@property
def inputs(self) -> Dict[str, type]:
r = {}
def inputs(self) -> OrderedDict[str, type]:
r = OrderedDict()
for k, v in self._inputs.items():
r[k] = v[0]
return r
Expand All @@ -144,7 +144,7 @@ def default_inputs_as_kwargs(self) -> Dict[str, Any]:
return {k: v[1] for k, v in self._inputs.items() if v[1] is not None}

@property
def outputs(self) -> typing.Dict[str, type]:
def outputs(self) -> OrderedDict[str, type]:
return self._outputs # type: ignore

@property
Expand Down Expand Up @@ -313,8 +313,8 @@ def verify_outputs_artifact_bindings(


def transform_types_to_list_of_type(
m: Dict[str, type], bound_inputs: typing.Set[str], list_as_optional: bool = False
) -> Dict[str, type]:
m: OrderedDict[str, type], bound_inputs: typing.Set[str], list_as_optional: bool = False
) -> OrderedDict[str, type]:
"""
Converts unbound inputs into the equivalent (optional) collections. This is useful for array jobs / map style code.
It will create a collection of types even if any one of these types is not a collection type.
Expand Down Expand Up @@ -375,7 +375,7 @@ def transform_function_to_interface(fn: typing.Callable, docstring: Optional[Doc
outputs = extract_return_annotation(return_annotation)
for k, v in outputs.items():
outputs[k] = v # type: ignore
inputs: Dict[str, Tuple[Type, Any]] = OrderedDict()
inputs: OrderedDict[str, Tuple[Type, Any]] = OrderedDict()
for k, v in signature.parameters.items(): # type: ignore
annotation = type_hints.get(k, None)
default = v.default if v.default is not inspect.Parameter.empty else None
Expand Down Expand Up @@ -446,7 +446,7 @@ def output_name_generator(length: int) -> Generator[str, None, None]:
yield default_output_name(x)


def extract_return_annotation(return_annotation: Union[Type, Tuple, None]) -> Dict[str, Type]:
def extract_return_annotation(return_annotation: Union[Type, Tuple, None]) -> OrderedDict[str, Type]:
"""
The purpose of this function is to sort out whether a function is returning one thing, or multiple things, and to
name the outputs accordingly, either by using our default name function, or from a typing.NamedTuple.
Expand Down Expand Up @@ -481,7 +481,7 @@ def t(a: int, b: str) -> Dict[str, int]: ...
# Handle Option 6
# We can think about whether we should add a default output name with type None in the future.
if return_annotation in (None, type(None), inspect.Signature.empty):
return {}
return OrderedDict()

# This statement results in true for typing.NamedTuple, single and void return types, so this
# handles Options 1, 2. Even though NamedTuple for us is multi-valued, it's a single value for Python
Expand All @@ -491,7 +491,7 @@ def t(a: int, b: str) -> Dict[str, int]: ...
bases = return_annotation.__bases__ # type: ignore
if len(bases) == 1 and bases[0] == tuple and hasattr(return_annotation, "_fields"):
logger.debug(f"Task returns named tuple {return_annotation}")
return dict(get_type_hints(cast(Type, return_annotation), include_extras=True))
return OrderedDict(get_type_hints(cast(Type, return_annotation), include_extras=True))

if hasattr(return_annotation, "__origin__") and return_annotation.__origin__ is tuple: # type: ignore
# Handle option 3
Expand All @@ -511,7 +511,7 @@ def t(a: int, b: str) -> Dict[str, int]: ...
else:
# Handle all other single return types
logger.debug(f"Task returns unnamed native tuple {return_annotation}")
return {default_output_name(): cast(Type, return_annotation)}
return OrderedDict([(default_output_name(), cast(Type, return_annotation))])


def remap_shared_output_descriptions(output_descriptions: Dict[str, str], outputs: Dict[str, Type]) -> Dict[str, str]:
Expand Down
8 changes: 7 additions & 1 deletion flytekit/core/launch_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,7 +484,13 @@ class ReferenceLaunchPlan(ReferenceEntity, LaunchPlan):
"""

def __init__(
self, project: str, domain: str, name: str, version: str, inputs: Dict[str, Type], outputs: Dict[str, Type]
self,
project: str,
domain: str,
name: str,
version: str,
inputs: typing.OrderedDict[str, Type],
outputs: typing.OrderedDict[str, Type],
):
super().__init__(LaunchPlanReference(project, domain, name, version), inputs, outputs)

Expand Down
40 changes: 20 additions & 20 deletions flytekit/core/promise.py
Original file line number Diff line number Diff line change
Expand Up @@ -1202,19 +1202,22 @@ def flyte_entity_call_handler(
#. Start a local execution - This means that we're not already in a local workflow execution, which means that
we should expect inputs to be native Python values and that we should return Python native values.
"""
# Sanity checks
# Only keyword args allowed
if len(args) > 0:
raise _user_exceptions.FlyteAssertion(
f"When calling tasks, only keyword args are supported. "
f"Aborting execution as detected {len(args)} positional args {args}"
)
# Make sure arguments are part of interface
for k, v in kwargs.items():
if k not in cast(SupportsNodeCreation, entity).python_interface.inputs:
raise AssertionError(
f"Received unexpected keyword argument '{k}' in function '{cast(SupportsNodeCreation, entity).name}'"
)
if k not in entity.python_interface.inputs:
raise AssertionError(f"Received unexpected keyword argument '{k}' in function '{entity.name}'")

# Check if we have more arguments than expected
if len(args) > len(entity.python_interface.inputs):
raise AssertionError(
f"Received more arguments than expected in function '{entity.name}'. Expected {len(entity.python_interface.inputs)} but got {len(args)}"
)

# Convert args to kwargs
for arg, input_name in zip(args, entity.python_interface.inputs.keys()):
if input_name in kwargs:
raise AssertionError(f"Got multiple values for argument '{input_name}' in function '{entity.name}'")
kwargs[input_name] = arg

ctx = FlyteContextManager.current_context()
if ctx.execution_state and (
Expand All @@ -1234,15 +1237,12 @@ def flyte_entity_call_handler(
child_ctx.execution_state
and child_ctx.execution_state.branch_eval_mode == BranchEvalMode.BRANCH_SKIPPED
):
if (
len(cast(SupportsNodeCreation, entity).python_interface.inputs) > 0
or len(cast(SupportsNodeCreation, entity).python_interface.outputs) > 0
):
output_names = list(cast(SupportsNodeCreation, entity).python_interface.outputs.keys())
if len(entity.python_interface.inputs) > 0 or len(entity.python_interface.outputs) > 0:
output_names = list(entity.python_interface.outputs.keys())
if len(output_names) == 0:
return VoidPromise(entity.name)
vals = [Promise(var, None) for var in output_names]
return create_task_output(vals, cast(SupportsNodeCreation, entity).python_interface)
return create_task_output(vals, entity.python_interface)
else:
return None
return cast(LocallyExecutable, entity).local_execute(ctx, **kwargs)
Expand All @@ -1255,7 +1255,7 @@ def flyte_entity_call_handler(
cast(ExecutionParameters, child_ctx.user_space_params)._decks = []
result = cast(LocallyExecutable, entity).local_execute(child_ctx, **kwargs)

expected_outputs = len(cast(SupportsNodeCreation, entity).python_interface.outputs)
expected_outputs = len(entity.python_interface.outputs)
if expected_outputs == 0:
if result is None or isinstance(result, VoidPromise):
return None
Expand All @@ -1268,10 +1268,10 @@ def flyte_entity_call_handler(
if (1 < expected_outputs == len(cast(Tuple[Promise], result))) or (
result is not None and expected_outputs == 1
):
return create_native_named_tuple(ctx, result, cast(SupportsNodeCreation, entity).python_interface)
return create_native_named_tuple(ctx, result, entity.python_interface)

raise AssertionError(
f"Expected outputs and actual outputs do not match."
f"Result {result}. "
f"Python interface: {cast(SupportsNodeCreation, entity).python_interface}"
f"Python interface: {entity.python_interface}"
)
6 changes: 3 additions & 3 deletions flytekit/core/reference.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Dict, Type
from typing import OrderedDict, Type

from flytekit.core.launch_plan import ReferenceLaunchPlan
from flytekit.core.task import ReferenceTask
Expand All @@ -15,8 +15,8 @@ def get_reference_entity(
domain: str,
name: str,
version: str,
inputs: Dict[str, Type],
outputs: Dict[str, Type],
inputs: OrderedDict[str, Type],
outputs: OrderedDict[str, Type],
):
"""
See the documentation for :py:class:`flytekit.reference_task` and :py:class:`flytekit.reference_workflow` as well.
Expand Down
6 changes: 3 additions & 3 deletions flytekit/core/reference_entity.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple, Type, Union
from typing import Any, Optional, OrderedDict, Tuple, Type, Union

from flytekit.core.context_manager import BranchEvalMode, ExecutionState, FlyteContext
from flytekit.core.interface import Interface, transform_interface_to_typed_interface
Expand Down Expand Up @@ -71,8 +71,8 @@ class ReferenceEntity(object):
def __init__(
self,
reference: Union[WorkflowReference, TaskReference, LaunchPlanReference],
inputs: Dict[str, Type],
outputs: Dict[str, Type],
inputs: OrderedDict[str, Type],
outputs: OrderedDict[str, Type],
):
if (
not isinstance(reference, WorkflowReference)
Expand Down
6 changes: 3 additions & 3 deletions flytekit/core/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import datetime
from functools import update_wrapper
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, TypeVar, Union, overload
from typing import Any, Callable, Dict, Iterable, List, Optional, OrderedDict, Tuple, Type, TypeVar, Union, overload

from flytekit.core import launch_plan as _annotated_launchplan
from flytekit.core import workflow as _annotated_workflow
Expand Down Expand Up @@ -366,8 +366,8 @@ def __init__(
domain: str,
name: str,
version: str,
inputs: Dict[str, type],
outputs: Dict[str, Type],
inputs: OrderedDict[str, type],
outputs: OrderedDict[str, Type],
):
super().__init__(TaskReference(project, domain, name, version), inputs, outputs)

Expand Down
8 changes: 7 additions & 1 deletion flytekit/core/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -886,7 +886,13 @@ class ReferenceWorkflow(ReferenceEntity, PythonFunctionWorkflow): # type: ignor
"""

def __init__(
self, project: str, domain: str, name: str, version: str, inputs: Dict[str, Type], outputs: Dict[str, Type]
self,
project: str,
domain: str,
name: str,
version: str,
inputs: typing.OrderedDict[str, Type],
outputs: typing.OrderedDict[str, Type],
):
super().__init__(WorkflowReference(project, domain, name, version), inputs, outputs)

Expand Down

0 comments on commit a69ed93

Please sign in to comment.