Commit
Fix Kubernetes job watch exit when no timeout given (#8350)
Co-authored-by: Nathan Nowack <thrast36@gmail.com>
zanieb and zzstoatzz committed Feb 2, 2023
1 parent fc5104a commit 4766c2e
Showing 2 changed files with 148 additions and 25 deletions.
39 changes: 18 additions & 21 deletions src/prefect/infrastructure/kubernetes.py
@@ -1,11 +1,11 @@
import copy
import enum
import os
import time
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple, Union

import anyio.abc
import pendulum
import yaml
from pydantic import Field, root_validator, validator
from typing_extensions import Literal
@@ -602,32 +602,35 @@ def _watch_job(self, job_name: str) -> int:
)

self.logger.debug(f"Job {job_name!r}: Starting watch for job completion")
start_time = pendulum.now("utc")
deadline = (
(time.time() + self.job_watch_timeout_seconds)
if self.job_watch_timeout_seconds is not None
else None
)
completed = False
while not completed:
elapsed = (pendulum.now("utc") - start_time).in_seconds()
if (
self.job_watch_timeout_seconds is not None
and elapsed > self.job_watch_timeout_seconds
):
self.logger.error(f"Job {job_name!r}: Job timed out after {elapsed}s.")
remaining_time = deadline - time.time() if deadline else None
if deadline and remaining_time <= 0:
self.logger.error(
f"Job {job_name!r}: Job did not complete within "
f"timeout of {self.job_watch_timeout_seconds}s."
)
return -1

watch = kubernetes.watch.Watch()
with self.get_batch_client() as batch_client:
remaining_timeout = (
( # subtract previous watch time
self.job_watch_timeout_seconds - elapsed
)
if self.job_watch_timeout_seconds
else None
# The kubernetes library will disable retries if the timeout kwarg is
# present regardless of the value so we do not pass it unless given
# https://github.com/kubernetes-client/python/blob/84f5fea2a3e4b161917aa597bf5e5a1d95e24f5a/kubernetes/base/watch/watch.py#LL160
timeout_seconds = (
{"timeout_seconds": remaining_time} if deadline else {}
)

for event in watch.stream(
func=batch_client.list_namespaced_job,
field_selector=f"metadata.name={job_name}",
namespace=self.namespace,
timeout_seconds=remaining_timeout,
**timeout_seconds,
):
if event["object"].status.completion_time:
if not event["object"].status.succeeded:
@@ -636,12 +639,6 @@ def _watch_job(self, job_name: str) -> int:
completed = True
watch.stop()
break
else:
self.logger.error(
f"Job {job_name!r}: Job did not complete within "
f"timeout of {self.job_watch_timeout_seconds}s."
)
return -1

with self.get_client() as client:
pod_status = client.read_namespaced_pod_status(
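The hunk above replaces the pendulum-based elapsed-time check with a single deadline computed up front, and only forwards timeout_seconds to watch.stream() when a timeout was actually configured, since the kubernetes client disables its own watch retries whenever that kwarg is present. Below is a minimal, self-contained sketch of that pattern; fake_stream and watch_until_complete are hypothetical stand-ins for illustration, not Prefect or kubernetes-client APIs.

import time
from typing import Iterator, Optional


def fake_stream(timeout_seconds: Optional[float] = None) -> Iterator[dict]:
    # Stand-in for kubernetes.watch.Watch().stream(): pretend the server closes
    # the watch after ~0.1s without the job having completed.
    time.sleep(0.1)
    yield {"completed": False}


def watch_until_complete(job_watch_timeout_seconds: Optional[float]) -> int:
    # Compute the deadline once; None means "watch until the job completes".
    # (A monotonic clock is used here for the sketch.)
    deadline = (
        time.monotonic() + job_watch_timeout_seconds
        if job_watch_timeout_seconds is not None
        else None
    )
    while True:
        remaining = deadline - time.monotonic() if deadline is not None else None
        if remaining is not None and remaining <= 0:
            return -1  # timed out

        # Only pass the kwarg when a timeout was configured, mirroring the fix.
        kwargs = {"timeout_seconds": remaining} if remaining is not None else {}
        for event in fake_stream(**kwargs):
            if event["completed"]:
                return 0
        # The stream ended without completion: loop and restart the watch.


print(watch_until_complete(job_watch_timeout_seconds=0.3))  # prints -1 after ~0.3s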
134 changes: 130 additions & 4 deletions tests/infrastructure/test_kubernetes_job.py
@@ -714,7 +714,7 @@ def test_uses_cluster_config_if_not_in_cluster(
mock_cluster_config.load_kube_config.assert_called_once()


@pytest.mark.parametrize("job_timeout", [24, None])
@pytest.mark.parametrize("job_timeout", [24, 100])
def test_allows_configurable_timeouts_for_pod_and_job_watches(
mock_k8s_client,
mock_watch,
@@ -728,9 +728,17 @@ def test_allows_configurable_timeouts_for_pod_and_job_watches(
command=["echo", "hello"],
pod_watch_timeout_seconds=42,
)
expected_job_call_kwargs = dict(
func=mock_k8s_batch_client.list_namespaced_job,
namespace=mock.ANY,
field_selector=mock.ANY,
)

if job_timeout is not None:
k8s_job_args["job_watch_timeout_seconds"] = job_timeout
expected_job_call_kwargs["timeout_seconds"] = pytest.approx(
job_timeout, abs=0.01
)

KubernetesJob(**k8s_job_args).run(MagicMock())

@@ -742,11 +750,41 @@ def test_allows_configurable_timeouts_for_pod_and_job_watches(
label_selector=mock.ANY,
timeout_seconds=42,
),
mock.call(**expected_job_call_kwargs),
]
)


@pytest.mark.parametrize("job_timeout", [None])
def test_excludes_timeout_from_job_watches_when_null(
mock_k8s_client,
mock_watch,
mock_k8s_batch_client,
job_timeout,
):
mock_watch.stream = mock.Mock(
side_effect=_mock_pods_stream_that_returns_running_pod
)
k8s_job_args = dict(
command=["echo", "hello"],
job_watch_timeout_seconds=job_timeout,
)

KubernetesJob(**k8s_job_args).run(MagicMock())

mock_watch.stream.assert_has_calls(
[
mock.call(
func=mock_k8s_client.list_namespaced_pod,
namespace=mock.ANY,
label_selector=mock.ANY,
timeout_seconds=mock.ANY,
),
mock.call(
func=mock_k8s_batch_client.list_namespaced_job,
namespace=mock.ANY,
field_selector=mock.ANY,
timeout_seconds=job_timeout,
# Note: timeout_seconds is excluded here
),
]
)
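As the inline note above says, the expected job-watch call excludes timeout_seconds. unittest.mock compares call keyword arguments exactly, so the assertion only passes when the watch was invoked without that kwarg at all. A standalone illustration of that behaviour (hypothetical, not part of the commit):

from unittest import mock

m = mock.Mock()
m(namespace="ns", field_selector="x")  # called without timeout_seconds

# Matches because the expected call also omits timeout_seconds entirely.
m.assert_has_calls([mock.call(namespace="ns", field_selector="x")])
assert "timeout_seconds" not in m.call_args.kwargs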
Expand All @@ -771,13 +809,12 @@ def test_watches_the_right_namespace(
func=mock_k8s_client.list_namespaced_pod,
namespace="my-awesome-flows",
label_selector=mock.ANY,
timeout_seconds=mock.ANY,
timeout_seconds=60,
),
mock.call(
func=mock_k8s_batch_client.list_namespaced_job,
namespace="my-awesome-flows",
field_selector=mock.ANY,
timeout_seconds=mock.ANY,
),
]
)
@@ -828,6 +865,95 @@ def mock_stream(*args, **kwargs):
assert result.status_code == -1


def test_watch_is_restarted_until_job_is_complete(
mock_k8s_client, mock_watch, mock_k8s_batch_client
):
def mock_stream(*args, **kwargs):
if kwargs["func"] == mock_k8s_client.list_namespaced_pod:
job_pod = MagicMock(spec=kubernetes.client.V1Pod)
job_pod.status.phase = "Running"
yield {"object": job_pod}

if kwargs["func"] == mock_k8s_batch_client.list_namespaced_job:
job = MagicMock(spec=kubernetes.client.V1Job)

# Yield the job then return exiting the stream
# After restarting the watch a few times, we'll report completion
job.status.completion_time = (
None if mock_watch.stream.call_count < 3 else True
)
yield {"object": job}

mock_watch.stream.side_effect = mock_stream
result = KubernetesJob(command=["echo", "hello"]).run(MagicMock())
assert result.status_code == 1
assert mock_watch.stream.call_count == 3


def test_watch_timeout_is_restarted_until_job_is_complete(
mock_k8s_client, mock_watch, mock_k8s_batch_client
):
def mock_stream(*args, **kwargs):
if kwargs["func"] == mock_k8s_client.list_namespaced_pod:
job_pod = MagicMock(spec=kubernetes.client.V1Pod)
job_pod.status.phase = "Running"
yield {"object": job_pod}

if kwargs["func"] == mock_k8s_batch_client.list_namespaced_job:
job = MagicMock(spec=kubernetes.client.V1Job)

# Sleep a little
sleep(0.25)

# Yield the job then return exiting the stream
job.status.completion_time = None
yield {"object": job}

mock_watch.stream.side_effect = mock_stream
result = KubernetesJob(command=["echo", "hello"], job_watch_timeout_seconds=1).run(
MagicMock()
)
assert result.status_code == -1

mock_watch.stream.assert_has_calls(
[
mock.call(
func=mock_k8s_client.list_namespaced_pod,
namespace=mock.ANY,
label_selector=mock.ANY,
timeout_seconds=mock.ANY,
),
# Starts with the full timeout
# Approximate comparisons are needed since executing code takes some time
mock.call(
func=mock_k8s_batch_client.list_namespaced_job,
field_selector=mock.ANY,
namespace=mock.ANY,
timeout_seconds=pytest.approx(1, abs=0.01),
),
# Then, elapsed time removed on each call
mock.call(
func=mock_k8s_batch_client.list_namespaced_job,
field_selector=mock.ANY,
namespace=mock.ANY,
timeout_seconds=pytest.approx(0.75, abs=0.05),
),
mock.call(
func=mock_k8s_batch_client.list_namespaced_job,
field_selector=mock.ANY,
namespace=mock.ANY,
timeout_seconds=pytest.approx(0.5, abs=0.05),
),
mock.call(
func=mock_k8s_batch_client.list_namespaced_job,
field_selector=mock.ANY,
namespace=mock.ANY,
timeout_seconds=pytest.approx(0.25, abs=0.05),
),
]
)
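The decreasing timeout_seconds values above are compared with pytest.approx and an absolute tolerance, since a little wall-clock time elapses between watch restarts. A quick illustration of how that comparison behaves (not part of the commit):

import pytest

# Within +/- 0.05 of the expected value -> the comparison passes.
assert 0.76 == pytest.approx(0.75, abs=0.05)

# Outside the tolerance -> the comparison fails.
assert not (0.82 == pytest.approx(0.75, abs=0.05))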


class TestCustomizingBaseJob:
"""Tests scenarios where a user is providing a customized base Job template"""


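For context, a hedged usage sketch of the behaviour this commit fixes, assuming the prefect.infrastructure.KubernetesJob API exercised by the tests above: leaving job_watch_timeout_seconds unset now means the job watch runs until completion (with the kubernetes client free to reconnect), while setting it enforces a hard deadline across watch restarts.

from prefect.infrastructure import KubernetesJob

# Hard deadline: the watch is restarted with a shrinking timeout, and the run
# reports status code -1 if the job has not completed within 60 seconds.
bounded = KubernetesJob(command=["echo", "hello"], job_watch_timeout_seconds=60)

# No timeout kwarg is forwarded to the watch at all; it simply runs until the
# job completes.
unbounded = KubernetesJob(command=["echo", "hello"])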