Skip to content

Commit

Permalink
Fix @task.kubernetes to receive input and send output (#28942)
Browse files Browse the repository at this point in the history
* Fix @task.kubernetes to receive input and send output

* Pickle input and rm unnecessary env vars

* Back to env vars and make cmds easier to read

* Remove check for op_args and op_kwargs on input write
  • Loading branch information
vchiapaikeo committed Feb 18, 2023
1 parent 73c8e7d commit 9a5c3e0
Show file tree
Hide file tree
Showing 3 changed files with 130 additions and 20 deletions.
62 changes: 48 additions & 14 deletions airflow/providers/cncf/kubernetes/decorators/kubernetes.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,17 @@
# under the License.
from __future__ import annotations

import base64
import inspect
import os
import pickle
import uuid
from shlex import quote
from tempfile import TemporaryDirectory
from textwrap import dedent
from typing import TYPE_CHECKING, Callable, Sequence

import dill
from kubernetes.client import models as k8s

from airflow.decorators.base import DecoratedOperator, TaskDecorator, task_decorator_factory
Expand All @@ -37,21 +40,20 @@
from airflow.utils.context import Context

_PYTHON_SCRIPT_ENV = "__PYTHON_SCRIPT"
_PYTHON_INPUT_ENV = "__PYTHON_INPUT"

_FILENAME_IN_CONTAINER = "/tmp/script.py"


def _generate_decode_command() -> str:
def _generate_decoded_command(env_var: str, file: str) -> str:
return (
f'python -c "import base64, os;'
rf"x = os.environ[\"{_PYTHON_SCRIPT_ENV}\"];"
rf'f = open(\"{_FILENAME_IN_CONTAINER}\", \"w\"); f.write(x); f.close()"'
rf"x = base64.b64decode(os.environ[\"{env_var}\"]);"
rf'f = open(\"{file}\", \"wb\"); f.write(x); f.close()"'
)


def _read_file_contents(filename):
with open(filename) as script_file:
return script_file.read()
def _read_file_contents(filename: str) -> str:
with open(filename, "rb") as script_file:
return base64.b64encode(script_file.read()).decode("utf-8")


class _KubernetesDecoratedOperator(DecoratedOperator, KubernetesPodOperator):
Expand All @@ -62,17 +64,16 @@ class _KubernetesDecoratedOperator(DecoratedOperator, KubernetesPodOperator):
{"op_args", "op_kwargs", *KubernetesPodOperator.template_fields} - {"cmds", "arguments"}
)

# since we won't mutate the arguments, we should just do the shallow copy
# Since we won't mutate the arguments, we should just do the shallow copy
# there are some cases we can't deepcopy the objects (e.g protobuf).
shallow_copy_attrs: Sequence[str] = ("python_callable",)

def __init__(self, namespace: str = "default", **kwargs) -> None:
self.pickling_library = pickle
def __init__(self, namespace: str = "default", use_dill: bool = False, **kwargs) -> None:
self.pickling_library = dill if use_dill else pickle
super().__init__(
namespace=namespace,
name=kwargs.pop("name", f"k8s_airflow_pod_{uuid.uuid4().hex}"),
cmds=["bash"],
arguments=["-cx", f"{_generate_decode_command()} && python {_FILENAME_IN_CONTAINER}"],
cmds=["dummy-command"],
**kwargs,
)

Expand All @@ -82,11 +83,41 @@ def _get_python_source(self):
res = remove_task_decorator(res, "@task.kubernetes")
return res

def _generate_cmds(self) -> list[str]:
script_filename = "/tmp/script.py"
input_filename = "/tmp/script.in"
output_filename = "/airflow/xcom/return.json"

write_local_script_file_cmd = (
f"{_generate_decoded_command(quote(_PYTHON_SCRIPT_ENV), quote(script_filename))}"
)
write_local_input_file_cmd = (
f"{_generate_decoded_command(quote(_PYTHON_INPUT_ENV), quote(input_filename))}"
)
make_xcom_dir_cmd = "mkdir -p /airflow/xcom"
exec_python_cmd = f"python {script_filename} {input_filename} {output_filename}"
return [
"bash",
"-cx",
" && ".join(
[
write_local_script_file_cmd,
write_local_input_file_cmd,
make_xcom_dir_cmd,
exec_python_cmd,
]
),
]

def execute(self, context: Context):
with TemporaryDirectory(prefix="venv") as tmp_dir:
script_filename = os.path.join(tmp_dir, "script.py")
py_source = self._get_python_source()
input_filename = os.path.join(tmp_dir, "script.in")

with open(input_filename, "wb") as file:
self.pickling_library.dump({"args": self.op_args, "kwargs": self.op_kwargs}, file)

py_source = self._get_python_source()
jinja_context = {
"op_args": self.op_args,
"op_kwargs": self.op_kwargs,
Expand All @@ -100,7 +131,10 @@ def execute(self, context: Context):
self.env_vars = [
*self.env_vars,
k8s.V1EnvVar(name=_PYTHON_SCRIPT_ENV, value=_read_file_contents(script_filename)),
k8s.V1EnvVar(name=_PYTHON_INPUT_ENV, value=_read_file_contents(input_filename)),
]

self.cmds = self._generate_cmds()
return super().execute(context)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
under the License.
-#}

import json
import {{ pickling_library }}
import sys

Expand All @@ -42,3 +43,8 @@ arg_dict = {"args": [], "kwargs": {}}
# Script
{{ python_callable_source }}
res = {{ python_callable }}(*arg_dict["args"], **arg_dict["kwargs"])

# Write output
with open(sys.argv[2], "w") as file:
if res is not None:
file.write(json.dumps(res))
82 changes: 76 additions & 6 deletions tests/providers/cncf/kubernetes/decorators/test_kubernetes.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
# under the License.
from __future__ import annotations

import base64
import pickle
from unittest import mock

import pytest
Expand All @@ -29,6 +31,8 @@
POD_MANAGER_CLASS = "airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager"
HOOK_CLASS = "airflow.providers.cncf.kubernetes.operators.kubernetes_pod.KubernetesHook"

XCOM_IMAGE = "XCOM_IMAGE"


@pytest.fixture(autouse=True)
def mock_create_pod() -> mock.Mock:
Expand All @@ -40,6 +44,18 @@ def mock_await_pod_start() -> mock.Mock:
return mock.patch(f"{POD_MANAGER_CLASS}.await_pod_start").start()


@pytest.fixture(autouse=True)
def await_xcom_sidecar_container_start() -> mock.Mock:
return mock.patch(f"{POD_MANAGER_CLASS}.await_xcom_sidecar_container_start").start()


@pytest.fixture(autouse=True)
def extract_xcom() -> mock.Mock:
f = mock.patch(f"{POD_MANAGER_CLASS}.extract_xcom").start()
f.return_value = '{"key1": "value1", "key2": "value2"}'
return f


@pytest.fixture(autouse=True)
def mock_await_pod_completion() -> mock.Mock:
f = mock.patch(f"{POD_MANAGER_CLASS}.await_pod_completion").start()
Expand Down Expand Up @@ -81,11 +97,65 @@ def f():

containers = mock_create_pod.call_args[1]["pod"].spec.containers
assert len(containers) == 1
assert containers[0].command == ["bash"]
assert containers[0].command[0] == "bash"
assert len(containers[0].args) == 0
assert containers[0].env[0].name == "__PYTHON_SCRIPT"
assert containers[0].env[0].value
assert containers[0].env[1].name == "__PYTHON_INPUT"

# Ensure we pass input through a b64 encoded env var
decoded_input = pickle.loads(base64.b64decode(containers[0].env[1].value))
assert decoded_input == {"args": [], "kwargs": {}}


def test_kubernetes_with_input_output(
dag_maker, session, mock_create_pod: mock.Mock, mock_hook: mock.Mock
) -> None:
with dag_maker(session=session) as dag:

@task.kubernetes(
image="python:3.10-slim-buster",
in_cluster=False,
cluster_context="default",
config_file="/tmp/fake_file",
)
def f(arg1, arg2, kwarg1=None, kwarg2=None):
return {"key1": "value1", "key2": "value2"}

f.override(task_id="my_task_id", do_xcom_push=True)("arg1", "arg2", kwarg1="kwarg1")

dr = dag_maker.create_dagrun()
(ti,) = dr.task_instances

mock_hook.return_value.get_xcom_sidecar_container_image.return_value = XCOM_IMAGE

dag.get_task("my_task_id").execute(context=ti.get_template_context(session=session))

mock_hook.assert_called_once_with(
conn_id=None,
in_cluster=False,
cluster_context="default",
config_file="/tmp/fake_file",
)
assert mock_create_pod.call_count == 1
assert mock_hook.return_value.get_xcom_sidecar_container_image.call_count == 1

containers = mock_create_pod.call_args[1]["pod"].spec.containers

# First container is Python script
assert len(containers) == 2
assert containers[0].command[0] == "bash"
assert len(containers[0].args) == 0

assert containers[0].env[0].name == "__PYTHON_SCRIPT"
assert containers[0].env[0].value
assert containers[0].env[1].name == "__PYTHON_INPUT"
assert containers[0].env[1].value

assert len(containers[0].args) == 2
assert containers[0].args[0] == "-cx"
assert containers[0].args[1].endswith("/tmp/script.py")
# Ensure we pass input through a b64 encoded env var
decoded_input = pickle.loads(base64.b64decode(containers[0].env[1].value))
assert decoded_input == {"args": ("arg1", "arg2"), "kwargs": {"kwarg1": "kwarg1"}}

assert containers[0].env[-1].name == "__PYTHON_SCRIPT"
assert containers[0].env[-1].value
# Second container is xcom image
assert containers[1].image == XCOM_IMAGE
assert containers[1].volume_mounts[0].mount_path == "/airflow/xcom"

0 comments on commit 9a5c3e0

Please sign in to comment.