-
Notifications
You must be signed in to change notification settings - Fork 1.5k
/
test_cloud_task_runner.py
417 lines (347 loc) · 14.8 KB
/
test_cloud_task_runner.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
import cloudpickle
import datetime
import tempfile
import time
import uuid
from unittest.mock import MagicMock
import pytest
import prefect
from prefect.client import Client
from prefect.core import Edge, Task
from prefect.engine.cloud import CloudTaskRunner, CloudResultSerializer
from prefect.engine.result_serializers import (
JSONResultSerializer,
LocalResultSerializer,
)
from prefect.engine.runner import ENDRUN
from prefect.engine.state import (
Cached,
Failed,
Finished,
Mapped,
Paused,
Pending,
Running,
Retrying,
Skipped,
Success,
TimedOut,
TriggerFailed,
)
from prefect.serialization.result_serializers import ResultSerializerSchema
from prefect.utilities.configuration import set_temporary_config
@pytest.fixture(autouse=True)
def cloud_settings():
    """Apply Cloud-flavoured configuration around every test in this module.

    The default flow/task runner classes are pointed at their Cloud variants
    and a placeholder auth token is supplied. The test body executes while the
    temporary configuration context is open.
    """
    overrides = {
        "engine.flow_runner.default_class": "prefect.engine.cloud.CloudFlowRunner",
        "engine.task_runner.default_class": "prefect.engine.cloud.CloudTaskRunner",
        "cloud.auth_token": "token",
    }
    with set_temporary_config(overrides):
        yield
@pytest.fixture()
def client(monkeypatch):
    """Yield a mock Cloud client installed for both Cloud runner modules.

    The ``Client`` name in the Cloud task-runner and flow-runner modules is
    replaced with a factory returning this single mock, so tests can inspect
    the calls made against it. Run-info lookups report ``state=None`` and
    ``get_latest_task_run_states`` hands back the states it was given.
    """

    def _echo_states(flow_run_id, states):
        # Return the supplied states untouched.
        return states

    cloud_client = MagicMock(
        get_flow_run_info=MagicMock(return_value=MagicMock(state=None)),
        set_flow_run_state=MagicMock(),
        get_task_run_info=MagicMock(return_value=MagicMock(state=None)),
        set_task_run_state=MagicMock(),
        get_latest_task_run_states=MagicMock(side_effect=_echo_states),
    )

    patch_targets = (
        "prefect.engine.cloud.task_runner.Client",
        "prefect.engine.cloud.flow_runner.Client",
    )
    for target in patch_targets:
        # Each module gets its own factory mock, all returning the same client.
        monkeypatch.setattr(target, MagicMock(return_value=cloud_client))

    yield cloud_client
class TestInitializeRun:
    """Checks that ``initialize_run`` resolves non-raw results to real values."""

    def test_ensures_all_upstream_states_are_raw(self, client):
        """Upstream states flagged ``raw=False`` come back holding the stored value.

        Two of the three upstream states carry a file path as their result plus
        metadata naming a local result serializer; the third holds a plain value.
        """
        serializer_spec = ResultSerializerSchema().dump(LocalResultSerializer())

        with tempfile.NamedTemporaryFile() as tmp:
            # Persist the value that the path-based results should resolve to.
            with open(tmp.name, "wb") as handle:
                cloudpickle.dump(42, handle)

            stored_success = Success(result=tmp.name)
            plain_failed = Failed(result=55)
            stored_pending = Pending(result=tmp.name)

            stored_success._metadata["result"] = {
                "raw": False,
                "result_serializer": serializer_spec,
            }
            stored_pending._metadata["result"] = {
                "raw": False,
                "result_serializer": serializer_spec,
            }

            upstream = {1: stored_success, 2: plain_failed, 3: stored_pending}
            runner = CloudTaskRunner(Task())
            result = runner.initialize_run(
                state=Success(), context={}, upstream_states=upstream
            )

            assert result.upstream_states[1].result == 42
            assert result.upstream_states[2].result == 55
            assert result.upstream_states[3].result == 42

    def test_ensures_provided_initial_state_is_raw(self, client):
        """An initial state flagged ``raw=False`` comes back holding the stored value."""
        serializer_spec = ResultSerializerSchema().dump(LocalResultSerializer())

        with tempfile.NamedTemporaryFile() as tmp:
            # Persist the value that the path-based result should resolve to.
            with open(tmp.name, "wb") as handle:
                cloudpickle.dump(42, handle)

            initial = Success(result=tmp.name)
            initial._metadata["result"] = {
                "raw": False,
                "result_serializer": serializer_spec,
            }

            runner = CloudTaskRunner(Task())
            result = runner.initialize_run(
                state=initial, context={}, upstream_states={}
            )

            assert result.state.result == 42
def test_task_runner_doesnt_call_client_if_map_index_is_none(client):
    """Without a map index the runner reports states but never fetches run info."""
    runner = CloudTaskRunner(task=Task(name="test"))
    final_state = runner.run()

    ## assertions
    info_getter = client.get_task_run_info
    state_setter = client.set_task_run_state
    assert info_getter.call_count == 0  # never called
    assert state_setter.call_count == 2  # Pending -> Running -> Success

    # Each recorded call is an (args, kwargs) pair; the state is passed by keyword.
    reported = [kwargs["state"] for _, kwargs in state_setter.call_args_list]
    assert reported == [Running(), Success()]
    assert final_state.is_successful()
def test_task_runner_calls_get_task_run_info_if_map_index_is_not_none(client):
    """With a map index in the context, task run info is fetched exactly once.

    The run is still expected to report ``Running`` then ``Success`` to the
    client and to finish successfully.
    """
    task = Task(name="test")
    res = CloudTaskRunner(task=task).run(context={"map_index": 1})

    ## assertions
    assert client.get_task_run_info.call_count == 1  # called once for the mapped run
    assert client.set_task_run_state.call_count == 2  # Pending -> Running -> Success
    states = [call[1]["state"] for call in client.set_task_run_state.call_args_list]
    assert states == [Running(), Success()]
    # Mirror the sibling no-map-index test: the returned state must be checked too.
    assert res.is_successful()
def test_task_runner_raises_endrun_if_client_cant_communicate_during_state_updates(
    monkeypatch
):
    """When every state update raises, the returned state is still ``Running``."""

    @prefect.task(name="test")
    def raise_error():
        raise NameError("I don't exist")

    info_getter = MagicMock(return_value=MagicMock(state=None))
    failing_setter = MagicMock(side_effect=SyntaxError)
    cloud_client = MagicMock(
        get_task_run_info=info_getter, set_task_run_state=failing_setter
    )
    client_factory = MagicMock(return_value=cloud_client)
    monkeypatch.setattr("prefect.engine.cloud.task_runner.Client", client_factory)

    ## an ENDRUN will cause the TaskRunner to return the most recently computed state
    runner = CloudTaskRunner(task=raise_error)
    outcome = runner.run(context={"map_index": 1})

    assert failing_setter.called
    assert outcome.is_running()
def test_task_runner_raises_endrun_if_client_cant_receive_state_updates(monkeypatch):
    """If fetching task run info raises, the run ends ``Failed`` carrying the error."""
    task = Task(name="test")

    failing_getter = MagicMock(side_effect=SyntaxError)
    state_setter = MagicMock()
    cloud_client = MagicMock(
        get_task_run_info=failing_getter, set_task_run_state=state_setter
    )
    client_factory = MagicMock(return_value=cloud_client)
    monkeypatch.setattr("prefect.engine.cloud.task_runner.Client", client_factory)

    ## an ENDRUN will cause the TaskRunner to return the most recently computed state
    runner = CloudTaskRunner(task=task)
    outcome = runner.run(context={"map_index": 1})

    assert failing_getter.called
    assert outcome.is_failed()
    assert isinstance(outcome.result, SyntaxError)
def test_task_runner_raises_endrun_with_correct_state_if_client_cant_receive_state_updates(
    monkeypatch
):
    """If fetching run info raises, a caller-supplied state is returned as-is.

    The check is by identity: the very object passed in must come back.
    """
    task = Task(name="test")

    failing_getter = MagicMock(side_effect=SyntaxError)
    state_setter = MagicMock()
    cloud_client = MagicMock(
        get_task_run_info=failing_getter, set_task_run_state=state_setter
    )
    client_factory = MagicMock(return_value=cloud_client)
    monkeypatch.setattr("prefect.engine.cloud.task_runner.Client", client_factory)

    ## an ENDRUN will cause the TaskRunner to return the most recently computed state
    supplied = Pending(message="unique message", result=42)
    runner = CloudTaskRunner(task=task)
    outcome = runner.run(state=supplied, context={"map_index": 1})

    assert failing_getter.called
    assert outcome is supplied
@pytest.mark.parametrize(
    "state", [Finished, Success, Skipped, Failed, TimedOut, TriggerFailed]
)
def test_task_runner_respects_the_db_state(monkeypatch, state):
    """A state of each parametrized type reported by the client is returned untouched.

    No state updates should be sent back to the client in that case.
    """
    task = Task(name="test")
    stored_state = state("already", result=10)

    info_getter = MagicMock(return_value=MagicMock(state=stored_state))
    state_setter = MagicMock()
    cloud_client = MagicMock(
        get_task_run_info=info_getter, set_task_run_state=state_setter
    )
    client_factory = MagicMock(return_value=cloud_client)
    monkeypatch.setattr("prefect.engine.cloud.task_runner.Client", client_factory)

    runner = CloudTaskRunner(task=task)
    outcome = runner.run(context={"map_index": 1})

    ## assertions
    assert info_getter.call_count == 1  # one time to pull latest state
    assert state_setter.call_count == 0  # never needs to update state
    assert outcome == stored_state
def test_task_runner_uses_cached_inputs_from_db_state(monkeypatch):
    """Cached inputs on the state returned by the client feed the task's arguments.

    The client reports a ``Retrying`` state with ``x=41`` cached; the task adds
    one, so a successful run must produce 42.
    """

    @prefect.task(name="test")
    def add_one(x):
        return x + 1

    stored_state = Retrying(cached_inputs={"x": 41})
    info_getter = MagicMock(return_value=MagicMock(state=stored_state))
    state_setter = MagicMock()
    cloud_client = MagicMock(
        get_task_run_info=info_getter, set_task_run_state=state_setter
    )
    client_factory = MagicMock(return_value=cloud_client)
    monkeypatch.setattr("prefect.engine.cloud.task_runner.Client", client_factory)

    runner = CloudTaskRunner(task=add_one)
    outcome = runner.run(context={"map_index": 1})

    ## assertions
    assert info_getter.call_count == 1  # one time to pull latest state
    assert state_setter.call_count == 2  # Pending -> Running -> Success
    assert outcome.is_successful()
    assert outcome.result == 42
@pytest.mark.parametrize(
    "state", [Finished, Success, Skipped, Failed, TimedOut, TriggerFailed]
)
def test_task_runner_prioritizes_kwarg_states_over_db_states(monkeypatch, state):
    """A ``Pending`` state passed to ``run`` wins over the state the client reports.

    Even though the client reports a state of the parametrized type, the task
    is expected to run and report ``Running`` then ``Success``.
    """
    task = Task(name="test")
    stored_state = state("already", result=10)

    info_getter = MagicMock(return_value=MagicMock(state=stored_state))
    state_setter = MagicMock()
    cloud_client = MagicMock(
        get_task_run_info=info_getter, set_task_run_state=state_setter
    )
    client_factory = MagicMock(return_value=cloud_client)
    monkeypatch.setattr("prefect.engine.cloud.task_runner.Client", client_factory)

    runner = CloudTaskRunner(task=task)
    runner.run(state=Pending("let's do this"), context={"map_index": 1})

    ## assertions
    assert info_getter.call_count == 1  # one time to pull latest state
    assert state_setter.call_count == 2  # Pending -> Running -> Success

    # Each recorded call is an (args, kwargs) pair; the state is passed by keyword.
    reported = [kwargs["state"] for _, kwargs in state_setter.call_args_list]
    assert reported == [Running(), Success()]
class TestHeartBeats:
    """Heartbeat behaviour of ``CloudTaskRunner`` against a mocked client."""

    def test_heartbeat_traps_errors_caused_by_client(self, monkeypatch):
        """A client error during the heartbeat surfaces as a ``UserWarning``."""
        failing_heartbeat = MagicMock(side_effect=SyntaxError)
        cloud_client = MagicMock(update_task_run_heartbeat=failing_heartbeat)
        monkeypatch.setattr(
            "prefect.engine.cloud.task_runner.Client",
            MagicMock(return_value=cloud_client),
        )

        runner = CloudTaskRunner(task=Task(name="bad"))
        runner.task_run_id = None

        with pytest.warns(UserWarning) as captured:
            outcome = runner._heartbeat()

        assert outcome is None
        assert cloud_client.update_task_run_heartbeat.called
        last_warning = captured.pop()
        assert "Heartbeat failed for Task 'bad'" in repr(last_warning.message)

    def test_heartbeat_traps_errors_caused_by_bad_attributes(self, monkeypatch):
        """``_heartbeat`` on a freshly built runner warns instead of raising."""
        monkeypatch.setattr("prefect.engine.cloud.task_runner.Client", MagicMock())
        runner = CloudTaskRunner(task=Task())

        with pytest.warns(UserWarning) as captured:
            outcome = runner._heartbeat()

        assert outcome is None
        last_warning = captured.pop()
        assert "Heartbeat failed for Task 'Task'" in repr(last_warning.message)

    @pytest.mark.parametrize(
        "executor", ["local", "sync", "mproc", "mthread"], indirect=True
    )
    def test_task_runner_has_a_heartbeat(self, executor, monkeypatch):
        """A 0.2s task with a 0.05s heartbeat interval beats at least twice."""
        cloud_client = MagicMock()
        monkeypatch.setattr(
            "prefect.engine.cloud.task_runner.Client",
            MagicMock(return_value=cloud_client),
        )

        @prefect.task
        def sleeper():
            time.sleep(0.2)

        with set_temporary_config({"cloud.heartbeat_interval": 0.05}):
            runner = CloudTaskRunner(task=sleeper)
            outcome = runner.run(executor=executor)

        heartbeat = cloud_client.update_task_run_heartbeat
        assert outcome.is_successful()
        assert heartbeat.called
        assert heartbeat.call_count >= 2

    @pytest.mark.parametrize(
        "executor", ["local", "sync", "mproc", "mthread"], indirect=True
    )
    def test_task_runner_has_a_heartbeat_only_during_execution(
        self, executor, monkeypatch
    ):
        """Time spent in ``cache_result`` does not add heartbeats.

        ``cache_result`` is swapped for a 0.2s sleep; exactly one heartbeat is
        still expected for the whole run.
        """
        cloud_client = MagicMock()
        monkeypatch.setattr(
            "prefect.engine.cloud.task_runner.Client",
            MagicMock(return_value=cloud_client),
        )

        with set_temporary_config({"cloud.heartbeat_interval": 0.05}):
            runner = CloudTaskRunner(task=Task())
            # Make the post-execution caching step slow on purpose.
            runner.cache_result = lambda *args, **kwargs: time.sleep(0.2)
            runner.run(executor=executor)

        heartbeat = cloud_client.update_task_run_heartbeat
        assert heartbeat.called
        assert heartbeat.call_count == 1

    @pytest.mark.parametrize(
        "executor", ["local", "sync", "mproc", "mthread"], indirect=True
    )
    def test_task_runner_has_a_heartbeat_with_task_run_id(self, executor, monkeypatch):
        """The heartbeat call receives the ``task_run_id`` from the run context."""
        cloud_client = MagicMock()
        monkeypatch.setattr(
            "prefect.engine.cloud.task_runner.Client",
            MagicMock(return_value=cloud_client),
        )

        task = Task(name="test")
        runner = CloudTaskRunner(task=task)
        outcome = runner.run(executor=executor, context={"task_run_id": 1234})

        assert outcome.is_successful()
        # call_args is an (args, kwargs) pair; the id is the first positional arg.
        positional, _ = cloud_client.update_task_run_heartbeat.call_args
        assert positional[0] == 1234
class TestStateResultHandling:
    """Checks the form of results and cached inputs in states sent to the client."""

    def test_task_runner_handles_outputs_prior_to_setting_state(self, client):
        """The ``Cached`` state sent to the client carries serialized values.

        With a JSON result serializer and integer inputs of 1 and 1, the cached
        inputs are expected as the strings ``"1"`` and the result as ``"2"``.
        """
        serialized = ResultSerializerSchema().dump(JSONResultSerializer())

        @prefect.task(
            cache_for=datetime.timedelta(days=1),
            result_serializer=JSONResultSerializer(),
        )
        def add(x, y):
            return x + y

        x_state = Success(result=1)
        y_state = Success(result=1)
        for upstream in (x_state, y_state):
            upstream._metadata["result"]["result_serializer"] = serialized

        upstream_states = {
            Edge(Task(), Task(), key="x"): x_state,
            Edge(Task(), Task(), key="y"): y_state,
        }

        runner = CloudTaskRunner(task=add)
        runner.run(upstream_states=upstream_states)

        ## assertions
        assert client.get_task_run_info.call_count == 0  # never called
        assert (
            client.set_task_run_state.call_count == 3
        )  # Pending -> Running -> Successful -> Cached

        # Each recorded call is an (args, kwargs) pair; the state is passed by keyword.
        reported = [
            kwargs["state"] for _, kwargs in client.set_task_run_state.call_args_list
        ]
        running, succeeded, cached = reported[0], reported[1], reported[2]
        assert running.is_running()
        assert succeeded.is_successful()
        assert isinstance(cached, Cached)
        assert cached.cached_inputs == {"x": "1", "y": "1"}
        assert cached.result == "2"

    def test_task_runner_handles_inputs_prior_to_setting_state(self, client):
        """The ``Retrying`` state sent to the client carries serialized inputs.

        The initial state caches ``y="0"`` so ``1 + "0"`` fails; the retry state
        is expected to hold ``x`` as ``"1"`` and ``y`` as the JSON string ``'"0"'``.
        """
        serialized = ResultSerializerSchema().dump(JSONResultSerializer())

        @prefect.task(max_retries=1, retry_delay=datetime.timedelta(days=1))
        def add(x, y):
            return x + y

        initial = Pending(cached_inputs={"x": 1, "y": "0"})
        x_state = Success(result=1)
        y_state = Success(result=1)
        for upstream in (x_state, y_state):
            upstream._metadata["result"]["result_serializer"] = serialized

        upstream_states = {
            Edge(Task(), Task(), key="x"): x_state,
            Edge(Task(), Task(), key="y"): y_state,
        }

        runner = CloudTaskRunner(task=add)
        runner.run(state=initial, upstream_states=upstream_states)

        ## assertions
        assert client.get_task_run_info.call_count == 0  # never called
        assert (
            client.set_task_run_state.call_count == 3
        )  # Pending -> Running -> Failed -> Retrying

        # Each recorded call is an (args, kwargs) pair; the state is passed by keyword.
        reported = [
            kwargs["state"] for _, kwargs in client.set_task_run_state.call_args_list
        ]
        running, failed, retrying = reported[0], reported[1], reported[2]
        assert running.is_running()
        assert failed.is_failed()
        assert isinstance(retrying, Retrying)
        assert retrying.cached_inputs == {"x": "1", "y": '"0"'}