import json
import math
import random
import time
from metaflow.tracing import inject_tracing_vars
from metaflow.exception import MetaflowException
from metaflow.metaflow_config import KUBERNETES_SECRETS
CLIENT_REFRESH_INTERVAL_SECONDS = 300
class KubernetesJobException(MetaflowException):
headline = "Kubernetes job error"
# Implements truncated exponential backoff from
# https://cloud.google.com/storage/docs/retry-strategy#exponential-backoff
def k8s_retry(deadline_seconds=60, max_backoff=32):
def decorator(function):
from functools import wraps
@wraps(function)
def wrapper(*args, **kwargs):
from kubernetes import client
deadline = time.time() + deadline_seconds
retry_number = 0
while True:
try:
result = function(*args, **kwargs)
return result
except client.rest.ApiException as e:
if e.status == 500:
current_t = time.time()
backoff_delay = min(
math.pow(2, retry_number) + random.random(), max_backoff
)
if current_t + backoff_delay < deadline:
time.sleep(backoff_delay)
retry_number += 1
continue # retry again
else:
raise
else:
raise
return wrapper
return decorator
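

# A minimal usage sketch for k8s_retry (hypothetical; `read_job` and `api`
# are illustrative and not part of this module):
#
#     @k8s_retry(deadline_seconds=60, max_backoff=32)
#     def read_job(api, name, namespace):
#         return api.read_namespaced_job(name=name, namespace=namespace)
#
# On repeated HTTP 500s, the wrapper sleeps for
# min(2**retry_number + random(), max_backoff) seconds between attempts -
# roughly 1s, 2s, 4s, 8s, ... capped at 32s - until the deadline passes,
# after which the exception propagates to the caller.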
class KubernetesJob(object):
def __init__(self, client, **kwargs):
self._client = client
self._kwargs = kwargs
def create(self):
# A discerning eye would notice and question the choice of using the
# V1Job construct over the V1Pod construct given that we don't rely much
# on any of the V1Job semantics. The major reasons at the moment are -
# 1. It makes the Kubernetes UIs (Octant, Lens) a bit easier on
# the eyes, although even that can be questioned.
        # 2. AWS Step Functions, at the moment (Apr' 22), only supports
        #    executing Jobs and not Pods as part of its publicly declared
        #    API. When we ship the AWS Step Functions integration with EKS,
        #    it will hopefully lessen our workload.
#
# Note: This implementation ensures that there is only one unique Pod
# (unique UID) per Metaflow task attempt.
client = self._client.get()
        # tmpfs variables
        use_tmpfs = self._kwargs["use_tmpfs"]
        tmpfs_size = self._kwargs["tmpfs_size"]
        # tmpfs is enabled when requested explicitly or implied by a size.
        tmpfs_enabled = bool(use_tmpfs or tmpfs_size)
shared_memory = (
int(self._kwargs["shared_memory"])
if self._kwargs["shared_memory"]
else None
)
self._job = client.V1Job(
api_version="batch/v1",
kind="Job",
metadata=client.V1ObjectMeta(
# Annotations are for humans
annotations=self._kwargs.get("annotations", {}),
# While labels are for Kubernetes
labels=self._kwargs.get("labels", {}),
generate_name=self._kwargs["generate_name"],
namespace=self._kwargs["namespace"], # Defaults to `default`
),
spec=client.V1JobSpec(
# Retries are handled by Metaflow when it is responsible for
# executing the flow. The responsibility is moved to Kubernetes
# when Argo Workflows is responsible for the execution.
backoff_limit=self._kwargs.get("retries", 0),
completions=1, # A single non-indexed pod job
                # Remove the job from the API server a week after completion.
                # TODO: Make this configurable
                ttl_seconds_after_finished=7 * 24 * 60 * 60,
template=client.V1PodTemplateSpec(
metadata=client.V1ObjectMeta(
annotations=self._kwargs.get("annotations", {}),
labels=self._kwargs.get("labels", {}),
namespace=self._kwargs["namespace"],
),
spec=client.V1PodSpec(
# Timeout is set on the pod and not the job (important!)
active_deadline_seconds=self._kwargs["timeout_in_seconds"],
# TODO (savin): Enable affinities for GPU scheduling.
# affinity=?,
containers=[
client.V1Container(
command=self._kwargs["command"],
env=[
client.V1EnvVar(name=k, value=str(v))
for k, v in self._kwargs.get(
"environment_variables", {}
).items()
]
                                # And some downward API magic. Add (key, value)
                                # pairs below to make pod metadata available
                                # within the Kubernetes container.
+ [
client.V1EnvVar(
name=k,
value_from=client.V1EnvVarSource(
field_ref=client.V1ObjectFieldSelector(
field_path=str(v)
)
),
)
for k, v in {
"METAFLOW_KUBERNETES_POD_NAMESPACE": "metadata.namespace",
"METAFLOW_KUBERNETES_POD_NAME": "metadata.name",
"METAFLOW_KUBERNETES_POD_ID": "metadata.uid",
"METAFLOW_KUBERNETES_SERVICE_ACCOUNT_NAME": "spec.serviceAccountName",
"METAFLOW_KUBERNETES_NODE_IP": "status.hostIP",
}.items()
]
+ [
client.V1EnvVar(name=k, value=str(v))
for k, v in inject_tracing_vars({}).items()
],
env_from=[
client.V1EnvFromSource(
secret_ref=client.V1SecretEnvSource(
name=str(k),
# optional=True
)
)
for k in list(self._kwargs.get("secrets", []))
+ KUBERNETES_SECRETS.split(",")
if k
],
image=self._kwargs["image"],
image_pull_policy=self._kwargs["image_pull_policy"],
name=self._kwargs["step_name"].replace("_", "-"),
resources=client.V1ResourceRequirements(
requests={
"cpu": str(self._kwargs["cpu"]),
"memory": "%sM" % str(self._kwargs["memory"]),
"ephemeral-storage": "%sM"
% str(self._kwargs["disk"]),
},
                                    limits={
                                        # Don't set GPU limits if gpu isn't
                                        # specified; the single-iteration
                                        # comprehension builds the dict
                                        # conditionally. Note that lower() is
                                        # applied to the full resource name,
                                        # not just the format template.
                                        (
                                            "%s.com/gpu"
                                            % self._kwargs["gpu_vendor"]
                                        ).lower(): str(self._kwargs["gpu"])
                                        for k in [0]
                                        if self._kwargs["gpu"] is not None
                                    },
),
volume_mounts=(
[
client.V1VolumeMount(
mount_path=self._kwargs.get("tmpfs_path"),
name="tmpfs-ephemeral-volume",
)
]
if tmpfs_enabled
else []
)
+ (
[
client.V1VolumeMount(
mount_path="/dev/shm", name="dhsm"
)
]
if shared_memory
else []
)
+ (
[
client.V1VolumeMount(
mount_path=path, name=claim
)
for claim, path in self._kwargs[
"persistent_volume_claims"
].items()
]
if self._kwargs["persistent_volume_claims"]
is not None
else []
),
)
],
node_selector=self._kwargs.get("node_selector"),
# TODO (savin): Support image_pull_secrets
# image_pull_secrets=?,
# TODO (savin): Support preemption policies
# preemption_policy=?,
#
# A Container in a Pod may fail for a number of
# reasons, such as because the process in it exited
# with a non-zero exit code, or the Container was
# killed due to OOM etc. If this happens, fail the pod
# and let Metaflow handle the retries.
restart_policy="Never",
service_account_name=self._kwargs["service_account"],
# Terminate the container immediately on SIGTERM
termination_grace_period_seconds=0,
tolerations=[
client.V1Toleration(**toleration)
for toleration in self._kwargs.get("tolerations") or []
],
volumes=(
[
client.V1Volume(
name="tmpfs-ephemeral-volume",
empty_dir=client.V1EmptyDirVolumeSource(
medium="Memory",
                                        # Add a default unit as ours differs
                                        # from the Kubernetes default.
size_limit="{}Mi".format(tmpfs_size),
),
)
]
if tmpfs_enabled
else []
)
+ (
[
client.V1Volume(
name="dhsm",
empty_dir=client.V1EmptyDirVolumeSource(
medium="Memory",
size_limit="{}Mi".format(shared_memory),
),
)
]
if shared_memory
else []
)
+ (
[
client.V1Volume(
name=claim,
persistent_volume_claim=client.V1PersistentVolumeClaimVolumeSource(
claim_name=claim
),
)
for claim in self._kwargs[
"persistent_volume_claims"
].keys()
]
if self._kwargs["persistent_volume_claims"] is not None
else []
),
# TODO (savin): Set termination_message_policy
),
),
),
)
return self
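
    # For orientation, the V1Job built above serializes to roughly the
    # following manifest (abridged; the bracketed values are illustrative
    # placeholders, not defaults):
    #
    #     apiVersion: batch/v1
    #     kind: Job
    #     metadata:
    #       generateName: <step-name>-
    #     spec:
    #       backoffLimit: 0
    #       completions: 1
    #       ttlSecondsAfterFinished: 604800
    #       template:
    #         spec:
    #           containers:
    #             - name: <step-name>
    #               image: <image>
    #               resources:
    #                 requests: {cpu: "1", memory: 4096M, ephemeral-storage: 10240M}
    #           restartPolicy: Never
    #           terminationGracePeriodSeconds: 0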
def execute(self):
client = self._client.get()
try:
# TODO: Make job submission back-pressure aware. Currently
# there doesn't seem to be a kubernetes-native way to
# achieve the guarantees that we are seeking.
# https://github.com/kubernetes/enhancements/issues/1040
# Hopefully, we will be able to get creative with kube-batch
response = (
client.BatchV1Api()
.create_namespaced_job(
body=self._job, namespace=self._kwargs["namespace"]
)
.to_dict()
)
return RunningJob(
client=self._client,
name=response["metadata"]["name"],
uid=response["metadata"]["uid"],
namespace=response["metadata"]["namespace"],
)
except client.rest.ApiException as e:
raise KubernetesJobException(
"Unable to launch Kubernetes job.\n %s"
% (json.loads(e.body)["message"] if e.body is not None else e.reason)
)
def step_name(self, step_name):
self._kwargs["step_name"] = step_name
return self
def namespace(self, namespace):
self._kwargs["namespace"] = namespace
return self
def name(self, name):
self._kwargs["name"] = name
return self
def command(self, command):
self._kwargs["command"] = command
return self
def image(self, image):
self._kwargs["image"] = image
return self
def cpu(self, cpu):
self._kwargs["cpu"] = cpu
return self
def memory(self, mem):
self._kwargs["memory"] = mem
return self
def environment_variable(self, name, value):
# Never set to None
if value is None:
return self
self._kwargs["environment_variables"] = dict(
self._kwargs.get("environment_variables", {}), **{name: value}
)
return self
def label(self, name, value):
self._kwargs["labels"] = dict(self._kwargs.get("labels", {}), **{name: value})
return self
def annotation(self, name, value):
self._kwargs["annotations"] = dict(
self._kwargs.get("annotations", {}), **{name: value}
)
return self
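
# A hypothetical usage sketch of the fluent builder above
# (`kubernetes_client` stands in for the client wrapper passed to
# KubernetesJob; the chained values are illustrative, not a required set):
#
#     running_job = (
#         KubernetesJob(kubernetes_client, **defaults)
#         .step_name("train")
#         .namespace("default")
#         .environment_variable("METAFLOW_RUN_ID", run_id)
#         .label("app", "metaflow")
#         .create()
#         .execute()
#     )
#
# execute() returns a RunningJob handle that can then be polled through
# `is_done`, `status`, and friends below.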
class RunningJob(object):
# State Machine implementation for the lifecycle behavior documented in
# https://kubernetes.io/docs/concepts/workloads/pods/pod-lifecycle/
#
    # This object encapsulates *both* V1Job and V1Pod. It simplifies the
    # status to "running" and "done" (failed/succeeded) states. Note that the
    # V1Job and V1Pod status fields are not guaranteed to always be in sync
    # due to the way the job controller works.
# To ascertain the status of V1Job, we peer into the lifecycle status of
# the pod it is responsible for executing. Unfortunately, the `phase`
# attributes (pending, running, succeeded, failed etc.) only provide
# partial answers and the official API conventions guide suggests that
# it may soon be deprecated (however, not anytime soon - see
    # https://github.com/kubernetes/kubernetes/issues/7856). `conditions`, on
    # the other hand, provide a deeper understanding of the state of the pod;
    # however, conditions are not state machines and can oscillate - from the
# official API conventions guide:
# In general, condition values may change back and forth, but some
# condition transitions may be monotonic, depending on the resource and
# condition type. However, conditions are observations and not,
# themselves, state machines, nor do we define comprehensive state
# machines for objects, nor behaviors associated with state
# transitions. The system is level-based rather than edge-triggered,
# and should assume an Open World.
# As a follow-up, we can synthesize our notion of "phase" state
# machine from `conditions`, since Kubernetes won't do it for us (for
# many good reasons).
#
# `conditions` can be of the following types -
# 1. (kubelet) Initialized (always True since we don't rely on init
# containers)
# 2. (kubelet) ContainersReady
# 3. (kubelet) Ready (same as ContainersReady since we don't use
# ReadinessGates -
# https://github.com/kubernetes/kubernetes/blob/master/pkg/kubelet/status/generate.go)
# 4. (kube-scheduler) PodScheduled
# (https://github.com/kubernetes/kubernetes/blob/master/pkg/scheduler/scheduler.go)
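    #
    # A sketch of how such a synthesized "phase" could look (illustrative
    # only; not used by the implementation below):
    #
    #     def synthesize_phase(pod):
    #         conditions = pod.get("status", {}).get("conditions") or []
    #         by_type = {c["type"]: c["status"] for c in conditions}
    #         if by_type.get("PodScheduled") != "True":
    #             return "unscheduled"
    #         if by_type.get("ContainersReady") == "True":
    #             return "running"
    #         return "initializing"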
def __init__(self, client, name, uid, namespace):
self._client = client
self._name = name
self._pod_name = None
self._id = uid
self._namespace = namespace
self._job = self._fetch_job()
self._pod = self._fetch_pod()
import atexit
        def best_effort_kill():
            try:
                self.kill()
            except Exception:
                pass

        atexit.register(best_effort_kill)
def __repr__(self):
return "{}('{}/{}')".format(
self.__class__.__name__, self._namespace, self._name
)
@k8s_retry()
def _fetch_job(self):
client = self._client.get()
try:
return (
client.BatchV1Api()
.read_namespaced_job(name=self._name, namespace=self._namespace)
.to_dict()
)
except client.rest.ApiException as e:
if e.status == 404:
raise KubernetesJobException(
"Unable to locate Kubernetes batch/v1 job %s" % self._name
)
raise
@k8s_retry()
def _fetch_pod(self):
# Fetch pod metadata.
client = self._client.get()
pods = (
client.CoreV1Api()
.list_namespaced_pod(
namespace=self._namespace,
label_selector="job-name={}".format(self._name),
)
.to_dict()["items"]
)
if pods:
return pods[0]
return {}
def kill(self):
# Terminating a Kubernetes job is a bit tricky. Issuing a
# `BatchV1Api.delete_namespaced_job` will also remove all traces of the
# job object from the Kubernetes API server which may not be desirable.
# This forces us to be a bit creative in terms of how we handle kill:
#
# 1. If the container is alive and kicking inside the pod, we simply
# attach ourselves to the container and issue a kill signal. The
# way we have initialized the Job ensures that the job will cleanly
# terminate.
# 2. In scenarios where either the pod (unschedulable etc.) or the
# container (ImagePullError etc.) hasn't come up yet, we become a
# bit creative by patching the job parallelism to 0. This ensures
# that the underlying node's resources are made available to
# kube-scheduler again. The downside is that the Job wouldn't mark
# itself as done and the pod metadata disappears from the API
# server. There is an open issue in the Kubernetes GH to provide
# better support for job terminations -
# https://github.com/kubernetes/enhancements/issues/2232
# 3. If the pod object hasn't shown up yet, we set the parallelism to 0
# to preempt it.
client = self._client.get()
if not self.is_done:
if self.is_running:
# Case 1.
from kubernetes.stream import stream
                api_instance = client.CoreV1Api()
try:
# TODO: stream opens a web-socket connection. It may
# not be desirable to open multiple web-socket
# connections frivolously (think killing a
# workflow during a for-each step).
stream(
                        api_instance.connect_get_namespaced_pod_exec,
name=self._pod["metadata"]["name"],
namespace=self._namespace,
command=[
"/bin/sh",
"-c",
"/sbin/killall5",
],
stderr=True,
stdin=False,
stdout=True,
tty=False,
)
                except Exception:
                    # Best effort. It's likely that this API call is blocked
                    # for the user.
                    # --------------------------------------------------------
                    # We try patching the job parallelism anyway. Stopping any
                    # runaway jobs (and their pods) is secondary to correctly
                    # showing the "Killed" status on the Kubernetes pod.
                    #
                    # Patching parallelism to 0 has the effect of pausing the
                    # job.
                    try:
                        client.BatchV1Api().patch_namespaced_job(
                            name=self._name,
                            namespace=self._namespace,
                            field_manager="metaflow",
                            body={"spec": {"parallelism": 0}},
                        )
                    except Exception:
                        # Best effort.
                        pass
            else:
                # Cases 2 and 3. Patching parallelism to 0 has the effect of
                # pausing the job.
                try:
                    client.BatchV1Api().patch_namespaced_job(
                        name=self._name,
                        namespace=self._namespace,
                        field_manager="metaflow",
                        body={"spec": {"parallelism": 0}},
                    )
                except Exception:
                    # Best effort.
                    pass
return self
@property
def id(self):
if self._pod_name:
return "pod %s" % self._pod_name
if self._pod:
self._pod_name = self._pod["metadata"]["name"]
return self.id
return "job %s" % self._name
@property
def is_done(self):
# Check if the container is done. As a side effect, also refreshes self._job and
# self._pod with the latest state
def done():
# Either the container succeeds or fails naturally or else we may have
# forced the pod termination causing the job to still be in an
# active state but for all intents and purposes dead to us.
return (
bool(self._job["status"].get("succeeded"))
or bool(self._job["status"].get("failed"))
or self._are_pod_containers_done
or (self._job["spec"]["parallelism"] == 0)
)
if not done():
# If not done, fetch newer status
self._job = self._fetch_job()
self._pod = self._fetch_pod()
return done()
@property
def status(self):
if not self.is_done:
if bool(self._job["status"].get("active")):
if self._pod:
msg = (
"Pod is %s"
% self._pod.get("status", {})
.get("phase", "uninitialized")
.lower()
)
# TODO (savin): parse Pod conditions
container_status = (
self._pod["status"].get("container_statuses") or [None]
)[0]
if container_status:
# We have a single container inside the pod
status = {"status": "waiting"}
for k, v in container_status["state"].items():
if v is not None:
status["status"] = k
status.update(v)
msg += ", Container is %s" % status["status"].lower()
reason = ""
if status.get("reason"):
pass
reason = status["reason"]
if status.get("message"):
reason += " - %s" % status["message"]
if reason:
msg += " - %s" % reason
return msg
return "Job is active"
return "Job status is unknown"
return "Job is done"
@property
def has_succeeded(self):
        # The task's container is in a terminal state and the status is
        # marked as succeeded.
return self.is_done and self._have_containers_succeeded
@property
def has_failed(self):
        # Either the container is marked as failed or the job is not allowed
        # to launch any more pods.
retval = self.is_done and (
bool(self._job["status"].get("failed"))
or self._has_any_container_failed
or (self._job["spec"]["parallelism"] == 0)
)
return retval
@property
def _have_containers_succeeded(self):
container_statuses = self._pod.get("status", {}).get("container_statuses", [])
if not container_statuses:
return False
for cstatus in container_statuses:
# If the terminated field is not set, the pod is still running.
terminated = cstatus.get("state", {}).get("terminated", {})
if not terminated:
return False
# If the terminated field is set but the `finished_at` field is not set,
# the pod is still considered as running.
if not terminated.get("finished_at"):
return False
            # If finished_at is set but the reason is not Completed, the
            # container did not succeed.
            if terminated.get("reason", "").lower() != "completed":
                return False
return True
@property
def _has_any_container_failed(self):
container_statuses = self._pod.get("status", {}).get("container_statuses", [])
if not container_statuses:
return False
for cstatus in container_statuses:
# If the terminated field is not set, the pod is still running. Too early
# to determine if any container failed.
terminated = cstatus.get("state", {}).get("terminated", {})
if not terminated:
return False
# If the terminated field is set but the `finished_at` field is not set,
# the pod is still considered as running. Too early to determine if any
# container failed.
if not terminated.get("finished_at"):
return False
# If finished_at is set AND reason is Error, it means that the
# container failed.
if terminated.get("reason", "").lower() == "error":
return True
# If none of the containers are marked as failed, the pod is not
# considered failed.
return False
@property
def _are_pod_containers_done(self):
# All containers in the pod have a containerStatus that has a
# finishedAt set.
container_statuses = self._pod.get("status", {}).get("container_statuses", [])
if not container_statuses:
return False
for cstatus in container_statuses:
# If the terminated field is not set, the pod is still running. Too early
# to determine if any container failed.
terminated = cstatus.get("state", {}).get("terminated", {})
if not terminated:
return False
# If the terminated field is set but the `finished_at` field is not set,
# the pod is still considered as running.
if not terminated.get("finished_at"):
return False
        # If we reach this point, all containers were marked terminated and
        # their finishedAt was set.
return True
@property
def is_running(self):
# Returns true if the container is running.
if self.is_done:
return False
return not self._are_pod_containers_done
@property
def is_waiting(self):
return not self.is_done and not self.is_running
@property
def reason(self):
if self.is_done:
if self.has_succeeded:
return 0, None
            else:
                # Best effort since the Pod object can disappear on us at any
                # time.
if self._pod.get("status", {}).get("phase") not in (
"Succeeded",
"Failed",
):
# If pod status is dirty, check for newer status
self._pod = self._fetch_pod()
if self._pod:
if self._pod.get("status", {}).get("container_statuses") is None:
# We're done, but no container_statuses is set
# This can happen when the pod is evicted
return None, ": ".join(
filter(
None,
[
self._pod.get("status", {}).get("reason"),
self._pod.get("status", {}).get("message"),
],
)
)
for k, v in (
self._pod.get("status", {})
.get("container_statuses", [{}])[0]
.get("state", {})
.items()
):
if v is not None:
return v.get("exit_code"), ": ".join(
filter(
None,
[v.get("reason"), v.get("message")],
)
)
return None, None