Skip to content

Commit

Permalink
Add support for generator flows (#14061)
Browse files Browse the repository at this point in the history
Co-authored-by: Jeff Hale <discdiver@users.noreply.github.com>
  • Loading branch information
jlowin and discdiver committed Jun 17, 2024
1 parent 1d34acb commit f114496
Show file tree
Hide file tree
Showing 5 changed files with 482 additions and 4 deletions.
64 changes: 64 additions & 0 deletions docs/3.0rc/develop/write-workflows/index.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,70 @@ MyClass.my_class_method()
MyClass.my_static_method()
```

### Generators

Prefect supports synchronous and asynchronous generators as flows. The flow is considered to be `Running` as long as the generator is yielding values. When the generator is exhausted, the flow is considered `Completed`. Any values yielded by the generator can be consumed by other flows or tasks.

```python
from prefect import flow

@flow
def generator():
for i in range(10):
yield i

@flow
def consumer(x):
print(x)

for val in generator():
consumer(val)
```

<Warning>
**Generator functions are consumed when returned from flows**

The result of a completed flow must be serializable, but generators cannot be serialized.
Therefore, if you return a generator from a flow, the generator will be fully consumed and its yielded values will be returned as a list.
This can lead to unexpected behavior or blocking if the generator is infinite or very large.

Here is an example of proactive generator consumption:

```python
from prefect import flow

def gen():
yield from [1, 2, 3]
print('Generator consumed!')

@flow
def f():
return gen()

f() # prints 'Generator consumed!'
```

If you need to return a generator without consuming it, you can `yield` it instead of using `return`.
Values yielded from generator flows are not considered final results and do not face the same serialization constraints:

```python
from prefect import flow

def gen():
yield from [1, 2, 3]
print('Generator consumed!')

@flow
def f():
yield gen

generator = next(f())
list(generator) # prints 'Generator consumed!'

```
</Warning>


## Parameters

As with any Python function, you can pass arguments to a flow.
Expand Down
84 changes: 82 additions & 2 deletions src/prefect/flow_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from dataclasses import dataclass, field
from typing import (
Any,
AsyncGenerator,
Callable,
Coroutine,
Dict,
Expand Down Expand Up @@ -50,12 +51,13 @@
return_value_to_state,
)
from prefect.utilities.asyncutils import run_coro_as_sync
from prefect.utilities.callables import call_with_parameters
from prefect.utilities.callables import call_with_parameters, parameters_to_args_kwargs
from prefect.utilities.collections import visit_collection
from prefect.utilities.engine import (
_get_hook_name,
_resolve_custom_flow_run_name,
capture_sigterm,
link_state_to_result,
propose_state_sync,
resolve_to_final_result,
)
Expand Down Expand Up @@ -632,6 +634,80 @@ async def run_flow_async(
return engine.state if return_type == "state" else engine.result()


def run_generator_flow_sync(
flow: Flow[P, R],
flow_run: Optional[FlowRun] = None,
parameters: Optional[Dict[str, Any]] = None,
wait_for: Optional[Iterable[PrefectFuture]] = None,
return_type: Literal["state", "result"] = "result",
) -> Generator[R, None, None]:
if return_type != "result":
raise ValueError("The return_type for a generator flow must be 'result'")

engine = FlowRunEngine[P, R](
flow=flow, parameters=parameters, flow_run=flow_run, wait_for=wait_for
)

with engine.start():
while engine.is_running():
with engine.run_context():
call_args, call_kwargs = parameters_to_args_kwargs(
flow.fn, engine.parameters or {}
)
gen = flow.fn(*call_args, **call_kwargs)
try:
while True:
gen_result = next(gen)
# link the current state to the result for dependency tracking
link_state_to_result(engine.state, gen_result)
yield gen_result
except StopIteration as exc:
engine.handle_success(exc.value)
except GeneratorExit as exc:
engine.handle_success(None)
gen.throw(exc)

return engine.result()


async def run_generator_flow_async(
flow: Flow[P, R],
flow_run: Optional[FlowRun] = None,
parameters: Optional[Dict[str, Any]] = None,
wait_for: Optional[Iterable[PrefectFuture]] = None,
return_type: Literal["state", "result"] = "result",
) -> AsyncGenerator[R, None]:
if return_type != "result":
raise ValueError("The return_type for a generator flow must be 'result'")

engine = FlowRunEngine[P, R](
flow=flow, parameters=parameters, flow_run=flow_run, wait_for=wait_for
)

with engine.start():
while engine.is_running():
with engine.run_context():
call_args, call_kwargs = parameters_to_args_kwargs(
flow.fn, engine.parameters or {}
)
gen = flow.fn(*call_args, **call_kwargs)
try:
while True:
# can't use anext in Python < 3.10
gen_result = await gen.__anext__()
# link the current state to the result for dependency tracking
link_state_to_result(engine.state, gen_result)
yield gen_result
except (StopAsyncIteration, GeneratorExit) as exc:
engine.handle_success(None)
if isinstance(exc, GeneratorExit):
gen.throw(exc)

# async generators can't return, but we can raise failures here
if engine.state.is_failed():
engine.result()


def run_flow(
flow: Flow[P, R],
flow_run: Optional[FlowRun] = None,
Expand All @@ -646,7 +722,11 @@ def run_flow(
wait_for=wait_for,
return_type=return_type,
)
if flow.isasync:
if flow.isasync and flow.isgenerator:
return run_generator_flow_async(**kwargs)
elif flow.isgenerator:
return run_generator_flow_sync(**kwargs)
elif flow.isasync:
return run_flow_async(**kwargs)
else:
return run_flow_sync(**kwargs)
14 changes: 12 additions & 2 deletions src/prefect/flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,6 @@
from prefect.types import BANNED_CHARACTERS, WITHOUT_BANNED_CHARACTERS
from prefect.utilities.annotations import NotSet
from prefect.utilities.asyncutils import (
is_async_fn,
run_sync_in_worker_thread,
sync_compatible,
)
Expand Down Expand Up @@ -289,7 +288,18 @@ def __init__(
self.description = description or inspect.getdoc(fn)
update_wrapper(self, fn)
self.fn = fn
self.isasync = is_async_fn(self.fn)

# the flow is considered async if its function is async or an async
# generator
self.isasync = inspect.iscoroutinefunction(
self.fn
) or inspect.isasyncgenfunction(self.fn)

# the flow is considered a generator if its function is a generator or
# an async generator
self.isgenerator = inspect.isgeneratorfunction(
self.fn
) or inspect.isasyncgenfunction(self.fn)

raise_for_reserved_arguments(self.fn, ["return_state", "wait_for"])

Expand Down
Loading

0 comments on commit f114496

Please sign in to comment.