Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix type hinting for automatic run inputs #11796

Merged
merged 1 commit into from
Feb 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 8 additions & 31 deletions src/prefect/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@
)
from prefect.flows import Flow, load_flow_from_entrypoint
from prefect.futures import PrefectFuture, call_repr, resolve_futures_to_states
from prefect.input import RunInput, keyset_from_paused_state
from prefect.input import keyset_from_paused_state
from prefect.input.run_input import run_input_subclass_from_type
from prefect.logging.configuration import setup_logging
from prefect.logging.handlers import APILogHandler
Expand Down Expand Up @@ -212,7 +212,7 @@
from prefect.utilities.text import truncated_to

R = TypeVar("R")
T = TypeVar("T", bound=RunInput)
T = TypeVar("T")
EngineReturnType = Literal["future", "state", "result"]

NUM_CHARS_DYNAMIC_KEY = 8
Expand Down Expand Up @@ -987,18 +987,6 @@ async def pause_flow_run(
...


@overload
async def pause_flow_run(
wait_for_input: Type[Any],
flow_run_id: UUID = None,
timeout: int = 3600,
poll_interval: int = 10,
reschedule: bool = False,
key: str = None,
) -> Any:
...


@sync_compatible
@deprecated_parameter(
"flow_run_id", start_date="Dec 2023", help="Use `suspend_flow_run` instead."
Expand All @@ -1013,13 +1001,13 @@ async def pause_flow_run(
"wait_for_input", group="flow_run_input", when=lambda y: y is not None
)
async def pause_flow_run(
wait_for_input: Optional[Union[Type[T], Type[Any]]] = None,
wait_for_input: Optional[Type[T]] = None,
flow_run_id: UUID = None,
timeout: int = 3600,
poll_interval: int = 10,
reschedule: bool = False,
key: str = None,
):
) -> Optional[T]:
"""
Pauses the current flow run by blocking execution until resumed.

Expand Down Expand Up @@ -1102,8 +1090,8 @@ async def _in_process_pause(
reschedule=False,
key: str = None,
client=None,
wait_for_input: Optional[Union[Type[RunInput], Type[Any]]] = None,
) -> Optional[RunInput]:
wait_for_input: Optional[T] = None,
) -> Optional[T]:
if TaskRunContext.get():
raise RuntimeError("Cannot pause task runs.")

Expand Down Expand Up @@ -1234,29 +1222,18 @@ async def suspend_flow_run(
...


@overload
async def suspend_flow_run(
wait_for_input: Type[Any],
flow_run_id: Optional[UUID] = None,
timeout: Optional[int] = 3600,
key: Optional[str] = None,
client: PrefectClient = None,
) -> Any:
...


@sync_compatible
@inject_client
@experimental_parameter(
"wait_for_input", group="flow_run_input", when=lambda y: y is not None
)
async def suspend_flow_run(
wait_for_input: Optional[Union[Type[T], Type[Any]]] = None,
wait_for_input: Optional[Type[T]] = None,
flow_run_id: Optional[UUID] = None,
timeout: Optional[int] = 3600,
key: Optional[str] = None,
client: PrefectClient = None,
):
) -> Optional[T]:
"""
Suspends a flow run by stopping code execution until resumed.

Expand Down
134 changes: 92 additions & 42 deletions src/prefect/input/run_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ async def receiver_flow():
Type,
TypeVar,
Union,
cast,
overload,
)
from uuid import UUID, uuid4

Expand All @@ -93,7 +95,9 @@ async def receiver_flow():
if HAS_PYDANTIC_V2:
from prefect._internal.pydantic.v2_schema import create_v2_schema

T = TypeVar("T", bound="RunInput")
R = TypeVar("R", bound="RunInput")
T = TypeVar("T")

Keyset = Dict[
Union[Literal["description"], Literal["response"], Literal["schema"]], str
]
Expand Down Expand Up @@ -219,8 +223,8 @@ def load_from_flow_run_input(cls, flow_run_input: "FlowRunInput"):

@classmethod
def with_initial_data(
cls: Type[T], description: Optional[str] = None, **kwargs: Any
) -> Type[T]:
cls: Type[R], description: Optional[str] = None, **kwargs: Any
) -> Type[R]:
"""
Create a new `RunInput` subclass with the given initial data as field
defaults.
Expand Down Expand Up @@ -316,12 +320,12 @@ def subclass_from_base_model_type(
return type(f"{model_cls.__name__}RunInput", (RunInput, model_cls), {}) # type: ignore


class AutomaticRunInput(RunInput):
value: Any
class AutomaticRunInput(RunInput, Generic[T]):
value: T

@classmethod
@sync_compatible
async def load(cls, keyset: Keyset, flow_run_id: Optional[UUID] = None):
async def load(cls, keyset: Keyset, flow_run_id: Optional[UUID] = None) -> T:
"""
Load the run input response from the given key.

Expand All @@ -333,7 +337,7 @@ async def load(cls, keyset: Keyset, flow_run_id: Optional[UUID] = None):
return instance.value

@classmethod
def subclass_from_type(cls, _type: Type[T]) -> Type["AutomaticRunInput"]:
def subclass_from_type(cls, _type: Type[T]) -> Type["AutomaticRunInput[T]"]:
"""
Create a new `AutomaticRunInput` subclass from the given type.
"""
Expand Down Expand Up @@ -374,33 +378,30 @@ def receive(cls, *args, **kwargs):
return GetAutomaticInputHandler(run_input_cls=cls, *args, **kwargs)


def run_input_subclass_from_type(_type: Type[T]) -> Type[RunInput]:
def run_input_subclass_from_type(
_type: Union[Type[R], Type[T], pydantic.BaseModel]
) -> Union[Type[AutomaticRunInput[T]], Type[R]]:
"""
Create a new `RunInput` subclass from the given type.
"""
try:
is_class = issubclass(_type, object)
if issubclass(_type, RunInput):
return cast(Type[R], _type)
elif issubclass(_type, pydantic.BaseModel):
return cast(Type[R], RunInput.subclass_from_base_model_type(_type))
except TypeError:
is_class = False

if not is_class:
# Could be something like a typing._GenericAlias, so pass it through to
# Pydantic to see if we can create a model from it.
return AutomaticRunInput.subclass_from_type(_type)
if issubclass(_type, RunInput):
return _type
elif issubclass(_type, pydantic.BaseModel):
return RunInput.subclass_from_base_model_type(_type)
else:
# As a fall-through for a type that isn't a `RunInput` subclass or
# `pydantic.BaseModel` subclass, pass it through to Pydantic.
return AutomaticRunInput.subclass_from_type(_type)
pass

# Could be something like a typing._GenericAlias or any other type that
# isn't a `RunInput` subclass or `pydantic.BaseModel` subclass. Try passing
# it to AutomaticRunInput to see if we can create a model from it.
return cast(Type[AutomaticRunInput[T]], AutomaticRunInput.subclass_from_type(_type))


class GetInputHandler(Generic[T]):
class GetInputHandler(Generic[R]):
def __init__(
self,
run_input_cls: Type[T],
run_input_cls: Type[R],
key_prefix: str,
timeout: Optional[float] = 3600,
poll_interval: float = 10,
Expand All @@ -422,7 +423,7 @@ def __init__(
def __iter__(self):
return self

def __next__(self) -> T:
def __next__(self) -> R:
try:
return self.next()
except TimeoutError:
Expand All @@ -433,7 +434,7 @@ def __next__(self) -> T:
def __aiter__(self):
return self

async def __anext__(self) -> T:
async def __anext__(self) -> R:
try:
return await self.next()
except TimeoutError:
Expand All @@ -454,11 +455,11 @@ async def filter_for_inputs(self):

return flow_run_inputs

def to_instance(self, flow_run_input: "FlowRunInput") -> T:
def to_instance(self, flow_run_input: "FlowRunInput") -> R:
return self.run_input_cls.load_from_flow_run_input(flow_run_input)

@sync_compatible
async def next(self) -> T:
async def next(self) -> R:
flow_run_inputs = await self.filter_for_inputs()
if flow_run_inputs:
return self.to_instance(flow_run_inputs[0])
Expand All @@ -471,12 +472,22 @@ async def next(self) -> T:
return self.to_instance(flow_run_inputs[0])


class GetAutomaticInputHandler(GetInputHandler):
class GetAutomaticInputHandler(GetInputHandler, Generic[T]):
def __init__(self, *args, **kwargs):
self.with_metadata = kwargs.pop("with_metadata", False)
super().__init__(*args, **kwargs)

def to_instance(self, flow_run_input: "FlowRunInput") -> Any:
def __next__(self) -> T:
return cast(T, super().__next__())

async def __anext__(self) -> T:
return cast(T, await super().__anext__())

@sync_compatible
async def next(self) -> T:
return cast(T, await super().next())

def to_instance(self, flow_run_input: "FlowRunInput") -> T:
run_input = self.run_input_cls.load_from_flow_run_input(flow_run_input)

if self.with_metadata:
Expand Down Expand Up @@ -521,23 +532,62 @@ async def send_input(
)


@overload
def receive_input(
input_type: type,
input_type: Type[R],
timeout: Optional[float] = 3600,
poll_interval: float = 10,
raise_timeout_error: bool = False,
exclude_keys: Optional[Set[str]] = None,
key_prefix: Optional[str] = None,
flow_run_id: Optional[UUID] = None,
with_metadata: bool = False,
):
) -> GetInputHandler[R]:
...


@overload
def receive_input(
input_type: Type[T],
timeout: Optional[float] = 3600,
poll_interval: float = 10,
raise_timeout_error: bool = False,
exclude_keys: Optional[Set[str]] = None,
key_prefix: Optional[str] = None,
flow_run_id: Optional[UUID] = None,
with_metadata: bool = False,
) -> GetAutomaticInputHandler[T]:
...


def receive_input(
input_type: Union[Type[R], Type[T]],
timeout: Optional[float] = 3600,
poll_interval: float = 10,
raise_timeout_error: bool = False,
exclude_keys: Optional[Set[str]] = None,
key_prefix: Optional[str] = None,
flow_run_id: Optional[UUID] = None,
with_metadata: bool = False,
) -> Union[GetAutomaticInputHandler[T], GetInputHandler[R]]:
input_cls = run_input_subclass_from_type(input_type)
return input_cls.receive(
timeout=timeout,
poll_interval=poll_interval,
raise_timeout_error=raise_timeout_error,
exclude_keys=exclude_keys,
key_prefix=key_prefix,
flow_run_id=flow_run_id,
with_metadata=with_metadata,
)

if issubclass(input_cls, AutomaticRunInput):
return input_cls.receive(
timeout=timeout,
poll_interval=poll_interval,
raise_timeout_error=raise_timeout_error,
exclude_keys=exclude_keys,
key_prefix=key_prefix,
flow_run_id=flow_run_id,
with_metadata=with_metadata,
)
else:
return input_cls.receive(
timeout=timeout,
poll_interval=poll_interval,
raise_timeout_error=raise_timeout_error,
exclude_keys=exclude_keys,
key_prefix=key_prefix,
flow_run_id=flow_run_id,
)
8 changes: 8 additions & 0 deletions tests/input/test_run_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -588,6 +588,14 @@ def test_automatic_input_receive_can_can_raise_timeout_errors_as_generator_sync(
pass


async def test_automatic_input_receive_run_input_subclass(flow_run):
await send_input(Place(city="New York", state="NY"), flow_run_id=flow_run.id)

received = await receive_input(Place, flow_run_id=flow_run.id, timeout=0).next()
assert received.city == "New York"
assert received.state == "NY"


async def test_receive(flow_run):
async def send():
for city, state in [("New York", "NY"), ("Boston", "MA"), ("Chicago", "IL")]:
Expand Down