Skip to content

Commit

Permalink
Fix type hinting for automatic run inputs (#11796)
Browse files Browse the repository at this point in the history
  • Loading branch information
bunchesofdonald committed Feb 1, 2024
1 parent da8226b commit eab65a4
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 73 deletions.
39 changes: 8 additions & 31 deletions src/prefect/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@
)
from prefect.flows import Flow, load_flow_from_entrypoint
from prefect.futures import PrefectFuture, call_repr, resolve_futures_to_states
from prefect.input import RunInput, keyset_from_paused_state
from prefect.input import keyset_from_paused_state
from prefect.input.run_input import run_input_subclass_from_type
from prefect.logging.configuration import setup_logging
from prefect.logging.handlers import APILogHandler
Expand Down Expand Up @@ -212,7 +212,7 @@
from prefect.utilities.text import truncated_to

R = TypeVar("R")
T = TypeVar("T", bound=RunInput)
T = TypeVar("T")
EngineReturnType = Literal["future", "state", "result"]

NUM_CHARS_DYNAMIC_KEY = 8
Expand Down Expand Up @@ -987,18 +987,6 @@ async def pause_flow_run(
...


@overload
async def pause_flow_run(
wait_for_input: Type[Any],
flow_run_id: UUID = None,
timeout: int = 3600,
poll_interval: int = 10,
reschedule: bool = False,
key: str = None,
) -> Any:
...


@sync_compatible
@deprecated_parameter(
"flow_run_id", start_date="Dec 2023", help="Use `suspend_flow_run` instead."
Expand All @@ -1013,13 +1001,13 @@ async def pause_flow_run(
"wait_for_input", group="flow_run_input", when=lambda y: y is not None
)
async def pause_flow_run(
wait_for_input: Optional[Union[Type[T], Type[Any]]] = None,
wait_for_input: Optional[Type[T]] = None,
flow_run_id: UUID = None,
timeout: int = 3600,
poll_interval: int = 10,
reschedule: bool = False,
key: str = None,
):
) -> Optional[T]:
"""
Pauses the current flow run by blocking execution until resumed.
Expand Down Expand Up @@ -1102,8 +1090,8 @@ async def _in_process_pause(
reschedule=False,
key: str = None,
client=None,
wait_for_input: Optional[Union[Type[RunInput], Type[Any]]] = None,
) -> Optional[RunInput]:
wait_for_input: Optional[T] = None,
) -> Optional[T]:
if TaskRunContext.get():
raise RuntimeError("Cannot pause task runs.")

Expand Down Expand Up @@ -1234,29 +1222,18 @@ async def suspend_flow_run(
...


@overload
async def suspend_flow_run(
wait_for_input: Type[Any],
flow_run_id: Optional[UUID] = None,
timeout: Optional[int] = 3600,
key: Optional[str] = None,
client: PrefectClient = None,
) -> Any:
...


@sync_compatible
@inject_client
@experimental_parameter(
"wait_for_input", group="flow_run_input", when=lambda y: y is not None
)
async def suspend_flow_run(
wait_for_input: Optional[Union[Type[T], Type[Any]]] = None,
wait_for_input: Optional[Type[T]] = None,
flow_run_id: Optional[UUID] = None,
timeout: Optional[int] = 3600,
key: Optional[str] = None,
client: PrefectClient = None,
):
) -> Optional[T]:
"""
Suspends a flow run by stopping code execution until resumed.
Expand Down
134 changes: 92 additions & 42 deletions src/prefect/input/run_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ async def receiver_flow():
Type,
TypeVar,
Union,
cast,
overload,
)
from uuid import UUID, uuid4

Expand All @@ -93,7 +95,9 @@ async def receiver_flow():
if HAS_PYDANTIC_V2:
from prefect._internal.pydantic.v2_schema import create_v2_schema

T = TypeVar("T", bound="RunInput")
R = TypeVar("R", bound="RunInput")
T = TypeVar("T")

Keyset = Dict[
Union[Literal["description"], Literal["response"], Literal["schema"]], str
]
Expand Down Expand Up @@ -219,8 +223,8 @@ def load_from_flow_run_input(cls, flow_run_input: "FlowRunInput"):

@classmethod
def with_initial_data(
cls: Type[T], description: Optional[str] = None, **kwargs: Any
) -> Type[T]:
cls: Type[R], description: Optional[str] = None, **kwargs: Any
) -> Type[R]:
"""
Create a new `RunInput` subclass with the given initial data as field
defaults.
Expand Down Expand Up @@ -316,12 +320,12 @@ def subclass_from_base_model_type(
return type(f"{model_cls.__name__}RunInput", (RunInput, model_cls), {}) # type: ignore


class AutomaticRunInput(RunInput):
value: Any
class AutomaticRunInput(RunInput, Generic[T]):
value: T

@classmethod
@sync_compatible
async def load(cls, keyset: Keyset, flow_run_id: Optional[UUID] = None):
async def load(cls, keyset: Keyset, flow_run_id: Optional[UUID] = None) -> T:
"""
Load the run input response from the given key.
Expand All @@ -333,7 +337,7 @@ async def load(cls, keyset: Keyset, flow_run_id: Optional[UUID] = None):
return instance.value

@classmethod
def subclass_from_type(cls, _type: Type[T]) -> Type["AutomaticRunInput"]:
def subclass_from_type(cls, _type: Type[T]) -> Type["AutomaticRunInput[T]"]:
"""
Create a new `AutomaticRunInput` subclass from the given type.
"""
Expand Down Expand Up @@ -374,33 +378,30 @@ def receive(cls, *args, **kwargs):
return GetAutomaticInputHandler(run_input_cls=cls, *args, **kwargs)


def run_input_subclass_from_type(_type: Type[T]) -> Type[RunInput]:
def run_input_subclass_from_type(
_type: Union[Type[R], Type[T], pydantic.BaseModel]
) -> Union[Type[AutomaticRunInput[T]], Type[R]]:
"""
Create a new `RunInput` subclass from the given type.
"""
try:
is_class = issubclass(_type, object)
if issubclass(_type, RunInput):
return cast(Type[R], _type)
elif issubclass(_type, pydantic.BaseModel):
return cast(Type[R], RunInput.subclass_from_base_model_type(_type))
except TypeError:
is_class = False

if not is_class:
# Could be something like a typing._GenericAlias, so pass it through to
# Pydantic to see if we can create a model from it.
return AutomaticRunInput.subclass_from_type(_type)
if issubclass(_type, RunInput):
return _type
elif issubclass(_type, pydantic.BaseModel):
return RunInput.subclass_from_base_model_type(_type)
else:
# As a fall-through for a type that isn't a `RunInput` subclass or
# `pydantic.BaseModel` subclass, pass it through to Pydantic.
return AutomaticRunInput.subclass_from_type(_type)
pass

# Could be something like a typing._GenericAlias or any other type that
# isn't a `RunInput` subclass or `pydantic.BaseModel` subclass. Try passing
# it to AutomaticRunInput to see if we can create a model from it.
return cast(Type[AutomaticRunInput[T]], AutomaticRunInput.subclass_from_type(_type))


class GetInputHandler(Generic[T]):
class GetInputHandler(Generic[R]):
def __init__(
self,
run_input_cls: Type[T],
run_input_cls: Type[R],
key_prefix: str,
timeout: Optional[float] = 3600,
poll_interval: float = 10,
Expand All @@ -422,7 +423,7 @@ def __init__(
def __iter__(self):
return self

def __next__(self) -> T:
def __next__(self) -> R:
try:
return self.next()
except TimeoutError:
Expand All @@ -433,7 +434,7 @@ def __next__(self) -> T:
def __aiter__(self):
return self

async def __anext__(self) -> T:
async def __anext__(self) -> R:
try:
return await self.next()
except TimeoutError:
Expand All @@ -454,11 +455,11 @@ async def filter_for_inputs(self):

return flow_run_inputs

def to_instance(self, flow_run_input: "FlowRunInput") -> T:
def to_instance(self, flow_run_input: "FlowRunInput") -> R:
return self.run_input_cls.load_from_flow_run_input(flow_run_input)

@sync_compatible
async def next(self) -> T:
async def next(self) -> R:
flow_run_inputs = await self.filter_for_inputs()
if flow_run_inputs:
return self.to_instance(flow_run_inputs[0])
Expand All @@ -471,12 +472,22 @@ async def next(self) -> T:
return self.to_instance(flow_run_inputs[0])


class GetAutomaticInputHandler(GetInputHandler):
class GetAutomaticInputHandler(GetInputHandler, Generic[T]):
def __init__(self, *args, **kwargs):
self.with_metadata = kwargs.pop("with_metadata", False)
super().__init__(*args, **kwargs)

def to_instance(self, flow_run_input: "FlowRunInput") -> Any:
def __next__(self) -> T:
return cast(T, super().__next__())

async def __anext__(self) -> T:
return cast(T, await super().__anext__())

@sync_compatible
async def next(self) -> T:
return cast(T, await super().next())

def to_instance(self, flow_run_input: "FlowRunInput") -> T:
run_input = self.run_input_cls.load_from_flow_run_input(flow_run_input)

if self.with_metadata:
Expand Down Expand Up @@ -521,23 +532,62 @@ async def send_input(
)


@overload
def receive_input(
input_type: type,
input_type: Type[R],
timeout: Optional[float] = 3600,
poll_interval: float = 10,
raise_timeout_error: bool = False,
exclude_keys: Optional[Set[str]] = None,
key_prefix: Optional[str] = None,
flow_run_id: Optional[UUID] = None,
with_metadata: bool = False,
):
) -> GetInputHandler[R]:
...


@overload
def receive_input(
input_type: Type[T],
timeout: Optional[float] = 3600,
poll_interval: float = 10,
raise_timeout_error: bool = False,
exclude_keys: Optional[Set[str]] = None,
key_prefix: Optional[str] = None,
flow_run_id: Optional[UUID] = None,
with_metadata: bool = False,
) -> GetAutomaticInputHandler[T]:
...


def receive_input(
input_type: Union[Type[R], Type[T]],
timeout: Optional[float] = 3600,
poll_interval: float = 10,
raise_timeout_error: bool = False,
exclude_keys: Optional[Set[str]] = None,
key_prefix: Optional[str] = None,
flow_run_id: Optional[UUID] = None,
with_metadata: bool = False,
) -> Union[GetAutomaticInputHandler[T], GetInputHandler[R]]:
input_cls = run_input_subclass_from_type(input_type)
return input_cls.receive(
timeout=timeout,
poll_interval=poll_interval,
raise_timeout_error=raise_timeout_error,
exclude_keys=exclude_keys,
key_prefix=key_prefix,
flow_run_id=flow_run_id,
with_metadata=with_metadata,
)

if issubclass(input_cls, AutomaticRunInput):
return input_cls.receive(
timeout=timeout,
poll_interval=poll_interval,
raise_timeout_error=raise_timeout_error,
exclude_keys=exclude_keys,
key_prefix=key_prefix,
flow_run_id=flow_run_id,
with_metadata=with_metadata,
)
else:
return input_cls.receive(
timeout=timeout,
poll_interval=poll_interval,
raise_timeout_error=raise_timeout_error,
exclude_keys=exclude_keys,
key_prefix=key_prefix,
flow_run_id=flow_run_id,
)
8 changes: 8 additions & 0 deletions tests/input/test_run_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -588,6 +588,14 @@ def test_automatic_input_receive_can_can_raise_timeout_errors_as_generator_sync(
pass


async def test_automatic_input_receive_run_input_subclass(flow_run):
await send_input(Place(city="New York", state="NY"), flow_run_id=flow_run.id)

received = await receive_input(Place, flow_run_id=flow_run.id, timeout=0).next()
assert received.city == "New York"
assert received.state == "NY"


async def test_receive(flow_run):
async def send():
for city, state in [("New York", "NY"), ("Boston", "MA"), ("Chicago", "IL")]:
Expand Down

0 comments on commit eab65a4

Please sign in to comment.