Update `DaskTaskRunner` for compatibility with the updated engine #13555
from prefect.utilities.engine import (
    collect_task_run_inputs_sync,
    resolve_inputs_sync,
)

# We need to resolve some futures to map over their data, collect the upstream
# links beforehand to retain relationship tracking.
task_inputs = {
    k: collect_task_run_inputs_sync(v, max_depth=0)
    for k, v in parameters.items()
}

# Resolve the top-level parameters in order to get mappable data of a known length.
# Nested parameters will be resolved in each mapped child where their relationships
# will also be tracked.
parameters = resolve_inputs_sync(parameters, max_depth=0)

# Ensure that any parameters in kwargs are expanded before this check
parameters = explode_variadic_parameter(task.fn, parameters)

iterable_parameters = {}
static_parameters = {}
annotated_parameters = {}
for key, val in parameters.items():
    if isinstance(val, (allow_failure, quote)):
        # Unwrap annotated parameters to determine if they are iterable
        annotated_parameters[key] = val
        val = val.unwrap()

    if isinstance(val, unmapped):
        static_parameters[key] = val.value
    elif isiterable(val):
        iterable_parameters[key] = list(val)
    else:
        static_parameters[key] = val

if not len(iterable_parameters):
    raise MappingMissingIterable(
        "No iterable parameters were received. Parameters for map must "
        f"include at least one iterable. Parameters: {parameters}"
    )

iterable_parameter_lengths = {
    key: len(val) for key, val in iterable_parameters.items()
}
lengths = set(iterable_parameter_lengths.values())
if len(lengths) > 1:
    raise MappingLengthMismatch(
        "Received iterable parameters with different lengths. Parameters for map"
        f" must all be the same length. Got lengths: {iterable_parameter_lengths}"
    )

map_length = list(lengths)[0]

futures = []
for i in range(map_length):
    call_parameters = {
        key: value[i] for key, value in iterable_parameters.items()
    }
    call_parameters.update(
        {key: value for key, value in static_parameters.items()}
    )

    # Add default values for parameters; these are skipped earlier since they should
    # not be mapped over
    for key, value in get_parameter_defaults(task.fn).items():
        call_parameters.setdefault(key, value)

    # Re-apply annotations to each key again
    for key, annotation in annotated_parameters.items():
        call_parameters[key] = annotation.rewrap(call_parameters[key])

    # Collapse any previously exploded kwargs
    call_parameters = collapse_variadic_parameters(task.fn, call_parameters)
This is identical to a chunk of logic in the `ThreadPoolTaskRunner`. If I need to repeat this in the `RayTaskRunner` too, I'll move it up to the `TaskRunner` base class.
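For readers skimming the hunk, here is a minimal standalone sketch of the core splitting step: separate iterable from static parameters, check that the iterables agree in length, then build one parameter dict per mapped call. It is not the PR's code. `Unmapped`, `is_iterable`, and `expand_map_parameters` are hypothetical stand-ins, plain `ValueError` replaces the Prefect-specific exceptions, treating strings and bytes as scalars is an assumption of this sketch, and annotation unwrapping, parameter defaults, and variadic kwargs are left out.

```python
from typing import Any, Dict, List


class Unmapped:
    """Hypothetical stand-in for an `unmapped` annotation: marks a value as static."""

    def __init__(self, value: Any) -> None:
        self.value = value


def is_iterable(obj: Any) -> bool:
    """Assumption of this sketch: strings and bytes count as scalars."""
    if isinstance(obj, (str, bytes)):
        return False
    try:
        iter(obj)
    except TypeError:
        return False
    return True


def expand_map_parameters(parameters: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Return one parameter dict per mapped call."""
    iterable_parameters: Dict[str, list] = {}
    static_parameters: Dict[str, Any] = {}
    for key, val in parameters.items():
        if isinstance(val, Unmapped):
            # Explicitly static, even if the wrapped value is iterable
            static_parameters[key] = val.value
        elif is_iterable(val):
            iterable_parameters[key] = list(val)
        else:
            static_parameters[key] = val

    if not iterable_parameters:
        raise ValueError("Parameters for map must include at least one iterable.")

    lengths = {key: len(val) for key, val in iterable_parameters.items()}
    if len(set(lengths.values())) > 1:
        raise ValueError(f"Iterable parameters must all be the same length. Got: {lengths}")

    map_length = next(iter(lengths.values()))
    return [
        {**{k: v[i] for k, v in iterable_parameters.items()}, **static_parameters}
        for i in range(map_length)
    ]


calls = expand_map_parameters({"x": [1, 2, 3], "y": Unmapped([10, 20]), "z": "tag"})
```

Here `calls` holds three dicts, one per element of `x`, each carrying the full `y` list and the `z` string unchanged.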
overall looks good, left a few questions that I'd like to understand before ✅
    parameters: Dict[str, Any],
    wait_for: Iterable[PrefectFuture],
    dependencies: Optional[Dict[str, Set[TaskRunInput]]] = None,
) -> PrefectDaskFuture:
Every future type that we implement for each task runner will have the same interface, right? It's just that there will be some special handling based on the underlying system?
Yes. Each specific future will have a different implementation for waiting on the future and getting its result, but the `.wait`, `.result`, and `.state` interfaces will all be the same.
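A rough illustration of that idea, using only the standard library: one abstract interface with `wait`, `result`, and `state`, and two backends that implement it differently. `SketchFuture`, `ThreadBackedFuture`, `ImmediateFuture`, and the state strings are hypothetical names for this sketch, not Prefect's classes, and modelling `state` as a property is an assumption.

```python
import abc
import concurrent.futures
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Any, Optional


class SketchFuture(abc.ABC):
    """Hypothetical common interface; not Prefect's actual base class."""

    @abc.abstractmethod
    def wait(self, timeout: Optional[float] = None) -> None:
        """Block until the work finishes or the timeout elapses."""

    @abc.abstractmethod
    def result(self, timeout: Optional[float] = None) -> Any:
        """Return the computed value."""

    @property
    @abc.abstractmethod
    def state(self) -> str:
        """Report a coarse status string."""


class ThreadBackedFuture(SketchFuture):
    """Wraps a concurrent.futures.Future directly."""

    def __init__(self, wrapped: Future) -> None:
        self._wrapped = wrapped

    def wait(self, timeout: Optional[float] = None) -> None:
        # concurrent.futures.wait does not raise the task's exception
        concurrent.futures.wait([self._wrapped], timeout=timeout)

    def result(self, timeout: Optional[float] = None) -> Any:
        return self._wrapped.result(timeout=timeout)

    @property
    def state(self) -> str:
        if not self._wrapped.done():
            return "PENDING"
        if self._wrapped.cancelled():
            return "CANCELLED"
        return "FAILED" if self._wrapped.exception() is not None else "COMPLETED"


class ImmediateFuture(SketchFuture):
    """Backend-specific handling can be trivial: the value already exists."""

    def __init__(self, value: Any) -> None:
        self._value = value

    def wait(self, timeout: Optional[float] = None) -> None:
        return None

    def result(self, timeout: Optional[float] = None) -> Any:
        return self._value

    @property
    def state(self) -> str:
        return "COMPLETED"


def describe(future: SketchFuture) -> str:
    """Caller code depends only on the shared interface."""
    future.wait()
    return f"{future.state}: {future.result()}"


with ThreadPoolExecutor(max_workers=1) as pool:
    outputs = [
        describe(ThreadBackedFuture(pool.submit(pow, 2, 5))),
        describe(ImmediateFuture(32)),
    ]
```

`describe` never needs to know which backend produced the future, which is the property the reply describes.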
""" | ||
    self.__dict__.update(data)
    self._client = distributed.get_client()
def __exit__(self, *args):
What let you drop the serialization requirement?
Previously, a `PrefectFuture` carried around an instance of the task runner that submitted the run. We no longer have to do that because we directly wrap a future, so we shouldn't be pickling task runners anymore.
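To make the difference concrete, here is a small standard-library sketch. Every class in it is hypothetical: `SketchTaskRunner` stands in for a runner holding an unpicklable resource, and `RemoteHandle` stands in for a picklable backend future. It only illustrates why a future that drags its runner along forces the runner to be serializable, while a future that wraps the backend handle does not.

```python
import pickle
import threading
from dataclasses import dataclass


class SketchTaskRunner:
    """Hypothetical runner that owns an unpicklable resource (a lock)."""

    def __init__(self) -> None:
        self._lock = threading.Lock()


@dataclass
class RemoteHandle:
    """Hypothetical picklable stand-in for a backend future."""

    key: str


@dataclass
class RunnerCarryingFuture:
    """Old shape: the future carries the task runner that submitted it."""

    key: str
    task_runner: SketchTaskRunner


@dataclass
class WrappingFuture:
    """New shape: the future only wraps the backend handle."""

    wrapped: RemoteHandle


def can_pickle(obj: object) -> bool:
    """Return True if pickle can serialize the object."""
    try:
        pickle.dumps(obj)
    except (TypeError, pickle.PicklingError):
        return False
    return True


old_ok = can_pickle(RunnerCarryingFuture("task-1", SketchTaskRunner()))
new_ok = can_pickle(WrappingFuture(RemoteHandle("task-1")))
```

In this sketch `old_ok` is `False` because the lock inside the runner cannot be pickled, and `new_ok` is `True` because nothing in the wrapping future refers to the runner.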
Nice.
Updates the task runner to implement the new `TaskRunner` interface. The new task runner delegates most responsibilities to the `PrefectDistributedClient`, but the task runner handles wrapping and unwrapping `PrefectDaskFuture`s to ensure work is efficiently scheduled on a Dask cluster.