-
Notifications
You must be signed in to change notification settings - Fork 1.5k
/
dask.py
267 lines (218 loc) · 9.7 KB
/
dask.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
import datetime
import logging
import queue
import uuid
import warnings
from contextlib import contextmanager
from typing import Any, Callable, Iterable, Iterator, List
import dask
from distributed import Client, Future, fire_and_forget, worker_client
from prefect import context
from prefect.engine.executors.base import Executor
class DaskExecutor(Executor):
    """
    An executor that runs all functions using the `dask.distributed` scheduler on
    a (possibly local) dask cluster. If you already have one running, provide the
    address of the scheduler upon initialization; otherwise, one will be created
    (and subsequently torn down) within the `start()` contextmanager.

    Note that if you have tasks with tags of the form `"dask-resource:KEY=NUM"` they will be parsed
    and passed as [Worker Resources](https://distributed.dask.org/en/latest/resources.html) of the form
    `{"KEY": float(NUM)}` to the Dask Scheduler.

    Args:
        - address (string, optional): address of a currently running dask
            scheduler; if one is not provided, a `distributed.LocalCluster()` will be created in `executor.start()`.
            Defaults to `None`
        - local_processes (bool, optional): whether to use multiprocessing or not
            (computations will still be multithreaded). Ignored if address is provided.
            Defaults to `False`.
        - debug (bool, optional): whether to operate in debug mode; `debug=True`
            will produce many additional dask logs. Defaults to the `debug` value in your Prefect configuration
        - **kwargs (dict, optional): additional kwargs to be passed to the
            `dask.distributed.Client` upon initialization (e.g., `n_workers`)
    """

    def __init__(
        self,
        address: str = None,
        local_processes: bool = None,
        debug: bool = None,
        **kwargs: Any
    ):
        # Unset options fall back to the Prefect config in `context`.
        if address is None:
            address = context.config.engine.executor.dask.address
        # "local" is a config-level sentinel meaning "no external scheduler";
        # a `None` address makes `start()` spin up a LocalCluster instead.
        if address == "local":
            address = None
        if local_processes is None:
            local_processes = context.config.engine.executor.dask.local_processes
        if debug is None:
            debug = context.config.debug
        self.address = address              # scheduler address, or None for a LocalCluster
        self.local_processes = local_processes  # only consulted when no address is given
        self.debug = debug                  # controls dask log verbosity for local clusters
        self.is_started = False             # toggled by the start() contextmanager
        self.kwargs = kwargs                # forwarded verbatim to distributed.Client
        super().__init__()

    @contextmanager
    def start(self) -> Iterator[None]:
        """
        Context manager for initializing execution.

        Creates a `dask.distributed.Client` and yields it.

        Yields:
            - the active `dask.distributed.Client`; `self.client` and
              `self.is_started` are reset in the `finally` clause regardless
              of how the body exits.
        """
        try:
            if self.address is None:
                # No external scheduler: Client will create a LocalCluster, so
                # inject local-only options (log level, process vs. thread workers).
                self.kwargs.update(
                    silence_logs=logging.CRITICAL if not self.debug else logging.WARNING
                )
                self.kwargs.update(processes=self.local_processes)
            with Client(self.address, **self.kwargs) as client:
                self.client = client
                self.is_started = True
                yield self.client
        finally:
            # Always clear executor state so submit/map/wait fail loudly
            # (ValueError) outside of an active start() block.
            self.client = None
            self.is_started = False

    def _prep_dask_kwargs(self) -> dict:
        """
        Build per-task kwargs for `Client.submit` / `Client.map` from the
        current Prefect `context`: a unique dask key (for the scheduler UI)
        and any worker resources parsed from `dask-resource:KEY=NUM` tags.
        """
        # pure=False: Prefect tasks may be impure, so never dedupe by dask key hashing
        dask_kwargs = {"pure": False}  # type: dict

        ## set a key for the dask scheduler UI
        if context.get("task_full_name"):
            # uuid suffix keeps keys unique across retries / mapped runs
            key = context.get("task_full_name", "") + "-" + str(uuid.uuid4())
            dask_kwargs.update(key=key)

        ## infer from context if dask resources are being utilized
        dask_resource_tags = [
            tag
            for tag in context.get("task_tags", [])
            if tag.lower().startswith("dask-resource")
        ]
        if dask_resource_tags:
            resources = {}
            for tag in dask_resource_tags:
                # tag shape is "dask-resource:KEY=NUM" -> {"KEY": float(NUM)}
                prefix, val = tag.split("=")
                resources.update({prefix.split(":")[1]: float(val)})
            dask_kwargs.update(resources=resources)

        return dask_kwargs

    def __getstate__(self) -> dict:
        # The distributed Client is not picklable; drop it so the executor
        # itself can be serialized (it is restored as absent on unpickle).
        state = self.__dict__.copy()
        if "client" in state:
            del state["client"]
        return state

    def __setstate__(self, state: dict) -> None:
        self.__dict__.update(state)

    def submit(self, fn: Callable, *args: Any, **kwargs: Any) -> Future:
        """
        Submit a function to the executor for execution. Returns a Future object.

        Args:
            - fn (Callable): function that is being submitted for execution
            - *args (Any): arguments to be passed to `fn`
            - **kwargs (Any): keyword arguments to be passed to `fn`

        Returns:
            - Future: a Future-like object that represents the computation of `fn(*args, **kwargs)`

        Raises:
            - ValueError: if called outside an active `start()` context
        """
        dask_kwargs = self._prep_dask_kwargs()
        kwargs.update(dask_kwargs)

        if self.is_started and hasattr(self, "client"):
            # normal path: submit through the executor's own client
            future = self.client.submit(fn, *args, **kwargs)
        elif self.is_started:
            # running inside a dask worker (e.g. nested submission): use the
            # worker's client; separate_thread avoids blocking the worker thread
            future = client.submit(fn, *args, **kwargs)  # noqa: F821 -- see note below
        else:
            raise ValueError("This executor has not been started.")

        # keep the future's result alive on the cluster even if no one holds
        # a reference to it yet
        fire_and_forget(future)
        return future

    def map(self, fn: Callable, *args: Any, **kwargs: Any) -> List[Future]:
        """
        Submit a function to be mapped over its iterable arguments.

        Args:
            - fn (Callable): function that is being submitted for execution
            - *args (Any): arguments that the function will be mapped over
            - **kwargs (Any): additional keyword arguments that will be passed to the Dask Client

        Returns:
            - List[Future]: a list of Future-like objects that represent each computation of
                fn(*a), where a = zip(*args)[i]

        Raises:
            - ValueError: if called outside an active `start()` context
        """
        if not args:
            return []

        dask_kwargs = self._prep_dask_kwargs()
        kwargs.update(dask_kwargs)

        if self.is_started and hasattr(self, "client"):
            futures = self.client.map(fn, *args, **kwargs)
        elif self.is_started:
            with worker_client(separate_thread=True) as client:
                futures = client.map(fn, *args, **kwargs)
                # NOTE: on the worker-client path the futures are gathered
                # immediately, so resolved values (not Futures) are returned —
                # presumably to avoid deadlocking the worker; confirm before
                # relying on the List[Future] annotation here.
                return client.gather(futures)
        else:
            raise ValueError("This executor has not been started.")

        fire_and_forget(futures)
        return futures

    def wait(self, futures: Any) -> Any:
        """
        Resolves the Future objects to their values. Blocks until the computation is complete.

        Args:
            - futures (Any): single or iterable of future-like objects to compute

        Returns:
            - Any: an iterable of resolved futures with similar shape to the input

        Raises:
            - ValueError: if called outside an active `start()` context
        """
        if self.is_started and hasattr(self, "client"):
            return self.client.gather(futures)
        elif self.is_started:
            # inside a dask worker: gather via the worker's own client
            with worker_client(separate_thread=True) as client:
                return client.gather(futures)
        else:
            raise ValueError("This executor has not been started.")
class LocalDaskExecutor(Executor):
    """
    An executor that runs all functions locally using `dask` and a configurable dask scheduler. Note that
    this executor is known to occasionally run tasks twice when using multi-level mapping.

    Prefect's mapping feature will not work in conjunction with setting `scheduler="processes"`.

    Args:
        - scheduler (str): The local dask scheduler to use; common options are "synchronous", "threads" and "processes". Defaults to "synchronous".
        - **kwargs (Any): Additional keyword arguments to pass to dask config
    """

    def __init__(self, scheduler: str = "synchronous", **kwargs: Any):
        # Stash the scheduler name and extra dask config; both are re-applied
        # via dask.config.set() in start() and wait().
        self.scheduler = scheduler
        self.kwargs = kwargs
        super().__init__()

    @contextmanager
    def start(self) -> Iterator:
        """
        Context manager for initializing execution.

        Configures `dask` and yields the `dask.config` contextmanager.
        """
        with dask.config.set(scheduler=self.scheduler, **self.kwargs) as config:
            yield config

    def submit(self, fn: Callable, *args: Any, **kwargs: Any) -> dask.delayed:
        """
        Submit a function to the executor for execution. Returns a `dask.delayed` object.

        Args:
            - fn (Callable): function that is being submitted for execution
            - *args (Any): arguments to be passed to `fn`
            - **kwargs (Any): keyword arguments to be passed to `fn`

        Returns:
            - dask.delayed: a `dask.delayed` object that represents the computation of `fn(*args, **kwargs)`
        """
        # Wrap fn lazily; nothing runs until wait()/dask.compute is invoked.
        delayed_fn = dask.delayed(fn)
        return delayed_fn(*args, **kwargs)

    def map(self, fn: Callable, *args: Any) -> List[dask.delayed]:
        """
        Submit a function to be mapped over its iterable arguments.

        Args:
            - fn (Callable): function that is being submitted for execution
            - *args (Any): arguments that the function will be mapped over

        Returns:
            - List[dask.delayed]: the result of computating the function over the arguments

        Raises:
            - RuntimeError: if the configured scheduler is "processes"
        """
        # Guard clause: mapping is unsupported with the multiprocessing scheduler.
        if self.scheduler == "processes":
            raise RuntimeError(
                "LocalDaskExecutor cannot map if scheduler='processes'. Please set to either 'synchronous' or 'threads'."
            )

        # One delayed call per zipped argument tuple, mirroring zip(*args).
        return [self.submit(fn, *call_args) for call_args in zip(*args)]

    def wait(self, futures: Any) -> Any:
        """
        Resolves a `dask.delayed` object to its values. Blocks until the computation is complete.

        Args:
            - futures (Any): iterable of `dask.delayed` objects to compute

        Returns:
            - Any: an iterable of resolved futures
        """
        # Re-apply the stored scheduler config for the duration of the compute.
        with dask.config.set(scheduler=self.scheduler, **self.kwargs):
            (resolved,) = dask.compute(futures)
        return resolved