Skip to content

Commit

Permalink
Skip DockerOperator task when it returns a provided exit code (#28996)
Browse files Browse the repository at this point in the history
  • Loading branch information
hussein-awala committed Jan 18, 2023
1 parent bc5cecc commit 3a7bfce
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 3 deletions.
13 changes: 11 additions & 2 deletions airflow/providers/docker/operators/docker.py
Expand Up @@ -32,7 +32,7 @@
from dotenv import dotenv_values

from airflow.compat.functools import cached_property
from airflow.exceptions import AirflowException
from airflow.exceptions import AirflowException, AirflowSkipException
from airflow.models import BaseOperator
from airflow.providers.docker.hooks.docker import DockerHook

Expand Down Expand Up @@ -154,6 +154,9 @@ class DockerOperator(BaseOperator):
If rolling the logs creates excess files, the oldest file is removed.
Only effective when max-size is also set. A positive integer. Defaults to 1.
:param ipc_mode: Set the IPC mode for the container.
:param skip_exit_code: If task exits with this exit code, leave the task
in ``skipped`` state (default: None). If set to ``None``, any non-zero
exit code will be treated as a failure.
"""

template_fields: Sequence[str] = ("image", "command", "environment", "env_file", "container_name")
Expand Down Expand Up @@ -209,6 +212,7 @@ def __init__(
log_opts_max_size: str | None = None,
log_opts_max_file: str | None = None,
ipc_mode: str | None = None,
skip_exit_code: int | None = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
Expand Down Expand Up @@ -269,6 +273,7 @@ def __init__(
self.log_opts_max_size = log_opts_max_size
self.log_opts_max_file = log_opts_max_file
self.ipc_mode = ipc_mode
self.skip_exit_code = skip_exit_code

@cached_property
def hook(self) -> DockerHook:
Expand Down Expand Up @@ -368,7 +373,11 @@ def _run_image_with_mounts(self, target_mounts, add_tmp_variable: bool) -> list[
self.log.info("%s", log_chunk)

result = self.cli.wait(self.container["Id"])
if result["StatusCode"] != 0:
if result["StatusCode"] == self.skip_exit_code:
raise AirflowSkipException(
f"Docker container returned exit code {self.skip_exit_code}. Skipping."
)
elif result["StatusCode"] != 0:
joined_log_lines = "\n".join(log_lines)
raise AirflowException(f"Docker container failed: {repr(result)} lines {joined_log_lines}")

Expand Down
30 changes: 30 additions & 0 deletions tests/providers/docker/decorators/test_docker.py
Expand Up @@ -16,9 +16,13 @@
# under the License.
from __future__ import annotations

import pytest

from airflow.decorators import task
from airflow.exceptions import AirflowException
from airflow.models.dag import DAG
from airflow.utils import timezone
from airflow.utils.state import TaskInstanceState

DEFAULT_DATE = timezone.datetime(2021, 9, 1)

Expand Down Expand Up @@ -100,3 +104,29 @@ def do_run():

assert len(dag.task_ids) == 21
assert dag.task_ids[-1] == "do_run__20"

@pytest.mark.parametrize(
"extra_kwargs, actual_exit_code, expected_state",
[
(None, 99, TaskInstanceState.FAILED),
({"skip_exit_code": 100}, 100, TaskInstanceState.SKIPPED),
({"skip_exit_code": 100}, 101, TaskInstanceState.FAILED),
({"skip_exit_code": None}, 0, TaskInstanceState.SUCCESS),
],
)
def test_skip_docker_operator(self, extra_kwargs, actual_exit_code, expected_state, dag_maker):
@task.docker(image="python:3.9-slim", auto_remove="force", **(extra_kwargs if extra_kwargs else {}))
def f(exit_code):
raise SystemExit(exit_code)

with dag_maker():
ret = f(actual_exit_code)

dr = dag_maker.create_dagrun()
if expected_state == TaskInstanceState.FAILED:
with pytest.raises(AirflowException):
ret.operator.run(start_date=dr.execution_date, end_date=dr.execution_date)
else:
ret.operator.run(start_date=dr.execution_date, end_date=dr.execution_date)
ti = dr.get_task_instances()[0]
assert ti.state == expected_state
23 changes: 22 additions & 1 deletion tests/providers/docker/operators/test_docker.py
Expand Up @@ -26,7 +26,7 @@
from docker.errors import APIError
from docker.types import DeviceRequest, LogConfig, Mount

from airflow.exceptions import AirflowException
from airflow.exceptions import AirflowException, AirflowSkipException
from airflow.providers.docker.operators.docker import DockerOperator

TEST_CONN_ID = "docker_test_connection"
Expand Down Expand Up @@ -510,6 +510,27 @@ def test_execute_unicode_logs(self):
logging.raiseExceptions = original_raise_exceptions
print_exception_mock.assert_not_called()

@pytest.mark.parametrize(
"extra_kwargs, actual_exit_code, expected_exc",
[
(None, 99, AirflowException),
({"skip_exit_code": 100}, 100, AirflowSkipException),
({"skip_exit_code": 100}, 101, AirflowException),
({"skip_exit_code": None}, 100, AirflowException),
],
)
def test_skip(self, extra_kwargs, actual_exit_code, expected_exc):
msg = {"StatusCode": actual_exit_code}
self.client_mock.wait.return_value = msg

kwargs = dict(image="ubuntu", owner="unittest", task_id="unittest")
if extra_kwargs:
kwargs.update(**extra_kwargs)
operator = DockerOperator(**kwargs)

with pytest.raises(expected_exc):
operator.execute({})

def test_execute_container_fails(self):
failed_msg = {"StatusCode": 1}
log_line = ["unicode container log 😁 ", b"byte string container log"]
Expand Down

0 comments on commit 3a7bfce

Please sign in to comment.