diff --git a/airflow/contrib/operators/docker_swarm_operator.py b/airflow/contrib/operators/docker_swarm_operator.py new file mode 100644 index 0000000000000..9ac3057a75fac --- /dev/null +++ b/airflow/contrib/operators/docker_swarm_operator.py @@ -0,0 +1,146 @@ +''' +Run ephemeral Docker Swarm services +''' +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from docker import types + +from airflow.exceptions import AirflowException +from airflow.operators.docker_operator import DockerOperator +from airflow.utils.decorators import apply_defaults +from airflow.utils.strings import get_random_string + + +class DockerSwarmOperator(DockerOperator): + """ + Execute a command as an ephemeral docker swarm service. + Example use-case - Using Docker Swarm orchestration to make one-time + scripts highly available. + + A temporary directory is created on the host and + mounted into a container to allow storing files + that together exceed the default disk size of 10GB in a container. + The path to the mounted directory can be accessed + via the environment variable ``AIRFLOW_TMP_DIR``. + + If a login to a private registry is required prior to pulling the image, a + Docker connection needs to be configured in Airflow and the connection ID + be provided with the parameter ``docker_conn_id``. + + :param image: Docker image from which to create the container. + If image tag is omitted, "latest" will be used. + :type image: str + :param api_version: Remote API version. Set to ``auto`` to automatically + detect the server's version. + :type api_version: str + :param auto_remove: Auto-removal of the container on daemon side when the + container's process exits. + The default is False. + :type auto_remove: bool + :param command: Command to be run in the container. (templated) + :type command: str or list + :param docker_url: URL of the host running the docker daemon. + Default is unix://var/run/docker.sock + :type docker_url: str + :param environment: Environment variables to set in the container. (templated) + :type environment: dict + :param force_pull: Pull the docker image on every run. Default is False. + :type force_pull: bool + :param mem_limit: Maximum amount of memory the container can use. + Either a float value, which represents the limit in bytes, + or a string like ``128m`` or ``1g``. + :type mem_limit: float or str + :param tls_ca_cert: Path to a PEM-encoded certificate authority + to secure the docker connection. + :type tls_ca_cert: str + :param tls_client_cert: Path to the PEM-encoded certificate + used to authenticate docker client. + :type tls_client_cert: str + :param tls_client_key: Path to the PEM-encoded key used to authenticate docker client. + :type tls_client_key: str + :param tls_hostname: Hostname to match against + the docker server certificate or False to disable the check. + :type tls_hostname: str or bool + :param tls_ssl_version: Version of SSL to use when communicating with docker daemon. + :type tls_ssl_version: str + :param tmp_dir: Mount point inside the container to + a temporary directory created on the host by the operator. + The path is also made available via the environment variable + ``AIRFLOW_TMP_DIR`` inside the container. + :type tmp_dir: str + :param user: Default user inside the docker container. + :type user: int or str + :param docker_conn_id: ID of the Airflow connection to use + :type docker_conn_id: str + """ + + @apply_defaults + def __init__( + self, + image, + *args, + **kwargs): + + super().__init__(image=image, *args, **kwargs) + + self.service = None + + def _run_image(self): + self.log.info('Starting docker service from image %s', self.image) + + self.service = self.cli.create_service( + types.TaskTemplate( + container_spec=types.ContainerSpec( + image=self.image, + command=self.get_command(), + env=self.environment, + user=self.user + ), + restart_policy=types.RestartPolicy(condition='none'), + resources=types.Resources(mem_limit=self.mem_limit) + ), + name='airflow-%s' % get_random_string(), + labels={'name': 'airflow__%s__%s' % (self.dag_id, self.task_id)} + ) + + self.log.info('Service started: %s', str(self.service)) + + status = None + # wait for the service to start the task + while not self.cli.tasks(filters={'service': self.service['ID']}): + continue + while True: + + status = self.cli.tasks( + filters={'service': self.service['ID']} + )[0]['Status']['State'] + if status in ['failed', 'complete']: + self.log.info('Service status before exiting: %s', status) + break + + if self.auto_remove: + self.cli.remove_service(self.service['ID']) + if status == 'failed': + raise AirflowException('Service failed: ' + repr(self.service)) + + def on_kill(self): + if self.cli is not None: + self.log.info('Removing docker service: %s', self.service['ID']) + self.cli.remove_service(self.service['ID']) diff --git a/airflow/example_dags/example_docker_swarm_operator.py b/airflow/example_dags/example_docker_swarm_operator.py new file mode 100644 index 0000000000000..1f2e9734657e4 --- /dev/null +++ b/airflow/example_dags/example_docker_swarm_operator.py @@ -0,0 +1,51 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +from datetime import timedelta +import airflow +from airflow import DAG +from airflow.contrib.operators.docker_swarm_operator import DockerSwarmOperator + +default_args = { + 'owner': 'airflow', + 'depends_on_past': False, + 'start_date': airflow.utils.dates.days_ago(1), + 'email': ['airflow@example.com'], + 'email_on_failure': False, + 'email_on_retry': False +} + +dag = DAG( + 'docker_swarm_sample', + default_args=default_args, + schedule_interval=timedelta(minutes=10), + catchup=False +) + +with dag as dag: + t1 = DockerSwarmOperator( + api_version='auto', + docker_url='tcp://localhost:2375', # Set your docker URL + command='/bin/sleep 10', + image='centos:latest', + auto_remove=True, + task_id='sleep_with_swarm', + ) +""" diff --git a/airflow/operators/docker_operator.py b/airflow/operators/docker_operator.py index 777af003e85b2..15ec0dc121f53 100644 --- a/airflow/operators/docker_operator.py +++ b/airflow/operators/docker_operator.py @@ -187,29 +187,13 @@ def get_hook(self): tls=self.__get_tls_config() ) - def execute(self, context): + def _run_image(self): + """ + Run a Docker container with the provided image + """ self.log.info('Starting docker container from image %s', self.image) - tls_config = self.__get_tls_config() - - if self.docker_conn_id: - self.cli = self.get_hook().get_conn() - else: - self.cli = APIClient( - base_url=self.docker_url, - version=self.api_version, - tls=tls_config - ) - - if self.force_pull or len(self.cli.images(name=self.image)) == 0: - self.log.info('Pulling docker image %s', self.image) - for l in self.cli.pull(self.image, stream=True): - output = json.loads(l.decode('utf-8').strip()) - if 'status' in output: - self.log.info("%s", output['status']) - with TemporaryDirectory(prefix='airflowtmp', dir=self.host_tmp_dir) as host_tmp_dir: - self.environment['AIRFLOW_TMP_DIR'] = self.tmp_dir self.volumes.append('{0}:{1}'.format(host_tmp_dir, self.tmp_dir)) self.container = self.cli.create_container( @@ -249,6 +233,31 @@ def execute(self, context): return self.cli.logs(container=self.container['Id']) \ if self.xcom_all else line.encode('utf-8') + def execute(self, context): + + tls_config = self.__get_tls_config() + + if self.docker_conn_id: + self.cli = self.get_hook().get_conn() + else: + self.cli = APIClient( + base_url=self.docker_url, + version=self.api_version, + tls=tls_config + ) + + # Pull the docker image if `force_pull` is set or image does not exist locally + if self.force_pull or len(self.cli.images(name=self.image)) == 0: + self.log.info('Pulling docker image %s', self.image) + for l in self.cli.pull(self.image, stream=True): + output = json.loads(l.decode('utf-8').strip()) + if 'status' in output: + self.log.info("%s", output['status']) + + self.environment['AIRFLOW_TMP_DIR'] = self.tmp_dir + + self._run_image() + def get_command(self): if self.command is not None and self.command.strip().find('[') == 0: commands = ast.literal_eval(self.command) diff --git a/airflow/utils/strings.py b/airflow/utils/strings.py new file mode 100644 index 0000000000000..179d2b15fd29c --- /dev/null +++ b/airflow/utils/strings.py @@ -0,0 +1,31 @@ +''' +Common utility functions with strings +''' +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import string +from random import choice + + +def get_random_string(length=8, choices=string.ascii_letters + string.digits): + ''' + Generate random string + ''' + return ''.join([choice(choices) for i in range(length)]) diff --git a/tests/operators/test_docker_swarm_operator.py b/tests/operators/test_docker_swarm_operator.py new file mode 100644 index 0000000000000..b59c664bfde95 --- /dev/null +++ b/tests/operators/test_docker_swarm_operator.py @@ -0,0 +1,143 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import unittest + +from docker import APIClient +from tests.compat import mock + +from airflow.contrib.operators.docker_swarm_operator import DockerSwarmOperator +from airflow.exceptions import AirflowException + + +class DockerSwarmOperatorTestCase(unittest.TestCase): + + @mock.patch('airflow.operators.docker_operator.APIClient') + @mock.patch('airflow.contrib.operators.docker_swarm_operator.types') + def test_execute(self, types_mock, client_class_mock): + + mock_obj = mock.Mock() + + def _client_tasks_side_effect(): + for _ in range(2): + yield [{'Status': {'State': 'pending'}}] + yield [{'Status': {'State': 'complete'}}] + + client_mock = mock.Mock(spec=APIClient) + client_mock.create_service.return_value = {'ID': 'some_id'} + client_mock.images.return_value = [] + client_mock.pull.return_value = [b'{"status":"pull log"}'] + client_mock.tasks.side_effect = _client_tasks_side_effect() + types_mock.TaskTemplate.return_value = mock_obj + types_mock.ContainerSpec.return_value = mock_obj + types_mock.RestartPolicy.return_value = mock_obj + types_mock.Resources.return_value = mock_obj + + client_class_mock.return_value = client_mock + + operator = DockerSwarmOperator( + api_version='1.19', command='env', environment={'UNIT': 'TEST'}, image='ubuntu:latest', + mem_limit='128m', user='unittest', task_id='unittest', auto_remove=True + ) + operator.execute(None) + + types_mock.TaskTemplate.assert_called_with( + container_spec=mock_obj, restart_policy=mock_obj, resources=mock_obj + ) + types_mock.ContainerSpec.assert_called_with( + image='ubuntu:latest', command='env', user='unittest', + env={'UNIT': 'TEST', 'AIRFLOW_TMP_DIR': '/tmp/airflow'} + ) + types_mock.RestartPolicy.assert_called_with(condition='none') + types_mock.Resources.assert_called_with(mem_limit='128m') + + client_class_mock.assert_called_with( + base_url='unix://var/run/docker.sock', tls=None, version='1.19' + ) + + csargs, cskwargs = client_mock.create_service.call_args_list[0] + self.assertEqual( + len(csargs), 1, 'create_service called with different number of arguments than expected' + ) + self.assertEqual(csargs, (mock_obj, )) + self.assertEqual(cskwargs['labels'], {'name': 'airflow__adhoc_airflow__unittest'}) + self.assertTrue(cskwargs['name'].startswith('airflow-')) + self.assertEqual(client_mock.tasks.call_count, 3) + client_mock.remove_service.assert_called_with('some_id') + + @mock.patch('airflow.operators.docker_operator.APIClient') + @mock.patch('airflow.contrib.operators.docker_swarm_operator.types') + def test_no_auto_remove(self, types_mock, client_class_mock): + + mock_obj = mock.Mock() + + client_mock = mock.Mock(spec=APIClient) + client_mock.create_service.return_value = {'ID': 'some_id'} + client_mock.images.return_value = [] + client_mock.pull.return_value = [b'{"status":"pull log"}'] + client_mock.tasks.return_value = [{'Status': {'State': 'complete'}}] + types_mock.TaskTemplate.return_value = mock_obj + types_mock.ContainerSpec.return_value = mock_obj + types_mock.RestartPolicy.return_value = mock_obj + types_mock.Resources.return_value = mock_obj + + client_class_mock.return_value = client_mock + + operator = DockerSwarmOperator(image='', auto_remove=False, task_id='unittest') + operator.execute(None) + + self.assertEqual( + client_mock.remove_service.call_count, 0, + 'Docker service being removed even when `auto_remove` set to `False`' + ) + + @mock.patch('airflow.operators.docker_operator.APIClient') + @mock.patch('airflow.contrib.operators.docker_swarm_operator.types') + def test_failed_service_raises_error(self, types_mock, client_class_mock): + + mock_obj = mock.Mock() + + client_mock = mock.Mock(spec=APIClient) + client_mock.create_service.return_value = {'ID': 'some_id'} + client_mock.images.return_value = [] + client_mock.pull.return_value = [b'{"status":"pull log"}'] + client_mock.tasks.return_value = [{'Status': {'State': 'failed'}}] + types_mock.TaskTemplate.return_value = mock_obj + types_mock.ContainerSpec.return_value = mock_obj + types_mock.RestartPolicy.return_value = mock_obj + types_mock.Resources.return_value = mock_obj + + client_class_mock.return_value = client_mock + + operator = DockerSwarmOperator(image='', auto_remove=False, task_id='unittest') + msg = "Service failed: {'ID': 'some_id'}" + with self.assertRaises(AirflowException) as error: + operator.execute(None) + self.assertEqual(str(error.exception), msg) + + def test_on_kill(self): + client_mock = mock.Mock(spec=APIClient) + + operator = DockerSwarmOperator(image='', auto_remove=False, task_id='unittest') + operator.cli = client_mock + operator.service = {'ID': 'some_id'} + + operator.on_kill() + + client_mock.remove_service.assert_called_with('some_id')