Skip to content

Commit

Permalink
Add user defined execution context
Browse files Browse the repository at this point in the history
  • Loading branch information
jonathanshir authored and turbaszek committed Jul 3, 2020
1 parent 8cc7fc1 commit de148b0
Show file tree
Hide file tree
Showing 7 changed files with 270 additions and 3 deletions.
11 changes: 11 additions & 0 deletions airflow/config_templates/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,17 @@
type: string
example: "path.to.CustomXCom"
default: "airflow.models.xcom.BaseXCom"
- name: additional_execute_contextmanager
description: |
Custom user function that returns a context manager. Syntax is "package.method".
Context is entered when operator starts executing task. ``__enter__()`` will be called
before the operator's execute method, and ``__exit__()`` shortly after.
Function's signature should accept two positional parameters - task instance
and execution context
version_added: 2.0.0
type: string
example: my.path.my_context_manager
default: ""

- name: logging
description: ~
Expand Down
8 changes: 8 additions & 0 deletions airflow/config_templates/default_airflow.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,14 @@ check_slas = True
# Example: xcom_backend = path.to.CustomXCom
xcom_backend = airflow.models.xcom.BaseXCom

# Custom user function that returns a context manager. Syntax is "package.method".
# Context is entered when operator starts executing task. ``__enter__()`` will be called
# before the operator's execute method, and ``__exit__()`` shortly after.
# Function's signature should accept two positional parameters - task instance
# and execution context
# Example: additional_execute_contextmanager = my.path.my_context_manager
additional_execute_contextmanager =

[logging]
# The folder where airflow should store its log files
# This path must be absolute
Expand Down
2 changes: 1 addition & 1 deletion airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -1037,7 +1037,7 @@ def signal_handler(signum, frame):
self.log.error("Failed when executing execute callback")
self.log.exception(e3)

with set_current_context(context):
with set_current_context(context, task_copy):
result = self._execute_task(task_copy, context)

# If the task returns a result, push an XCom containing it
Expand Down
46 changes: 44 additions & 2 deletions airflow/task/context/current.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,21 +19,63 @@
import logging
from typing import Any, Dict

from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator

_CURRENT_CONTEXT = []
log = logging.getLogger(__name__)


def get_additional_execution_contextmanager(task_instance, execution_context):
"""
Retrieves the user defined execution context callback from the configuration,
and validates that it is indeed a context manager
:param execution_context: the current execution context to be passed to user ctx
"""
additional_execution_contextmanager = conf.getimport(
"core", "additional_execute_contextmanager"
)
if additional_execution_contextmanager:
try:
user_ctx_obj = additional_execution_contextmanager(
task_instance, execution_context
)
if hasattr(user_ctx_obj, "__enter__") and hasattr(user_ctx_obj, "__exit__"):
return user_ctx_obj
else:
raise AirflowException(
f"Loaded function {additional_execution_contextmanager} "
f"as additional execution contextmanager, but it does not have "
f"__enter__ or __exit__ method!"
)
except ImportError as e:
raise AirflowException(
f"Could not import additional execution contextmanager "
f"{additional_execution_contextmanager}!",
e,
)


@contextlib.contextmanager
def set_current_context(context: Dict[str, Any]):
def set_current_context(context: Dict[str, Any], task_instance: BaseOperator):
"""
Sets the current execution context to the provided context object.
This method should be called once per Task execution, before calling operator.execute
"""
_CURRENT_CONTEXT.append(context)

user_defined_exec_context = get_additional_execution_contextmanager(
context, task_instance
)

try:
yield context
if user_defined_exec_context is not None:
with user_defined_exec_context:
yield context
else:
yield context
finally:
expected_state = _CURRENT_CONTEXT.pop()
if expected_state != context:
Expand Down
1 change: 1 addition & 0 deletions docs/howto/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,4 @@ configuring an Airflow environment.
tracking-user-activity
email-config
use-alternative-secrets-backend
use-additional-execute-contextmanager
68 changes: 68 additions & 0 deletions docs/howto/use-additional-execute-contextmanager.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
.. Licensed to the Apache Software Foundation (ASF) under one
or more contributor license agreements. See the NOTICE file
distributed with this work for additional information
regarding copyright ownership. The ASF licenses this file
to you under the Apache License, Version 2.0 (the
"License"); you may not use this file except in compliance
with the License. You may obtain a copy of the License at
.. http://www.apache.org/licenses/LICENSE-2.0
.. Unless required by applicable law or agreed to in writing,
software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, either express or implied. See the License for the
specific language governing permissions and limitations
under the License.
Defining Additional Execute Context Manager
===========================================

Creating new context manager
----------------------------

Users can create their own execution context manager to allow context management on a higher level.
To do so, one must define a new context manager in one of their files. This context manager is entered
before calling ``execute`` method and is exited shortly after it. Here is an example context manager
which provides authentication to Google Cloud Platform:

.. code-block:: python
import os
import subprocess
from contextlib import contextmanager
from tempfile import TemporaryDirectory
from unittest import mock
from google.auth.environment_vars import CLOUD_SDK_CONFIG_DIR
from airflow.providers.google.cloud.utils.credentials_provider import provide_gcp_conn_and_credentials
def execute_cmd(cmd):
with open(os.devnull, 'w') as dev_null:
return subprocess.call(args=cmd, stdout=dev_null, stderr=subprocess.STDOUT)
@contextmanager
def provide_gcp_context(task_instance, execution_context):
"""
Context manager that provides:
- GCP credentials for application supporting `Application Default Credentials (ADC)
strategy <https://cloud.google.com/docs/authentication/production>`__.
- temporary value of ``AIRFLOW_CONN_GOOGLE_CLOUD_DEFAULT`` connection
- the ``gcloud`` config directory isolated from user configuration
"""
project_id = os.environ["GCP_PROJECT_ID"]
key_file_path = os.environ["GCP_DEFAULT_SERVICE_KEY"]
with provide_gcp_conn_and_credentials(key_file_path, project_id=project_id), \
TemporaryDirectory() as gcloud_config_tmp, \
mock.patch.dict('os.environ', {CLOUD_SDK_CONFIG_DIR: gcloud_config_tmp}):
execute_cmd(["gcloud", "config", "set", "core/project", project_id])
execute_cmd(["gcloud", "auth", "activate-service-account", f"--key-file={key_file_path}"])
yield
execute_cmd(["gcloud", "config", "set", "account", "none", f"--project={project_id}"])
Your custom context manager has to accept two arguments:
1. ``task_instance`` - the executing task instance object (can also be retrieved from execution context via ``"ti"`` key.
2. ``execution_context`` - the execution context that is provided to an operator's ``execute`` function.
137 changes: 137 additions & 0 deletions tests/task/context/test_additional_execute_contextmanager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import contextlib
import os
import sys
from datetime import datetime, timedelta
from unittest import mock
from unittest.mock import MagicMock, Mock

import pytest

from airflow import DAG, configuration
from airflow.exceptions import AirflowException
from airflow.models import TaskInstance
from airflow.models.baseoperator import BaseOperator
from airflow.utils.dates import days_ago

DEFAULT_ARGS = {
"owner": "test",
"depends_on_past": True,
"start_date": days_ago(1),
"end_date": datetime.today(),
"schedule_interval": "@once",
"retries": 1,
"retry_delay": timedelta(minutes=1),
}


@contextlib.contextmanager
def user_defined_contextmanager(task_instance, execution_context): # pylint: disable=W0613
try:
yield True
finally:
pass


@contextlib.contextmanager
def incorrect_user_defined_context_func():
return 8


def not_ctx_manager(task_instance, execution_context): # pylint: disable=W0613
return None


def get_user_contextmanager(section="core", key="additional_execute_contextmanager",
value="test_additional_execute_contextmanager.user_defined_contextmanager"):
configuration.conf.set(section, key, value)
sys.path.append(os.path.dirname(__file__))
ti = Mock(TaskInstance)
ti.configure_mock(
get_additional_execution_contextmanager=TaskInstance.get_additional_execution_contextmanager)
context = ti.get_additional_execution_contextmanager(None, None)

return context


class TestUserDefinedContextLoad:
def test_config_read(self):
user_ctx = get_user_contextmanager()
assert user_ctx

def test_assert_exception_on_invalid_value(self):
with pytest.raises(AirflowException):
get_user_contextmanager(value="INVALID_PACKAGE.INVALID_MODULE.INVALID_FUNC")

def test_user_func_incorrect_signature(self):
with pytest.raises(TypeError):
get_user_contextmanager(
value="test_additional_execute_contextmanager.incorrect_user_defined_context_func")

def test_user_func_not_ctx_manager(self):
with pytest.raises(AirflowException):
get_user_contextmanager(value="test_additional_execute_contextmanager.not_ctx_manager")

def test_enter_exit_exists(self):
user_ctx = get_user_contextmanager()
assert user_ctx
# Ensure these attributes were loaded
assert user_ctx.__enter__
assert user_ctx.__exit__


class TestUserDefinedContextRuntime:
marker = MagicMock()
enter_counter = 0
exit_counter = 0

@staticmethod
def increment_enter_counter(p): # pylint: disable=W0613
TestUserDefinedContextRuntime.enter_counter += 1

@staticmethod
def increment_exit_counter(p1, p2, p3, p4): # pylint: disable=W0613
TestUserDefinedContextRuntime.exit_counter += 1

def test_simple_runtime(self):
# Configure mock so user context manager received is our mock marker object:
# (TestUserDefinedContextRuntime.marker)
attrs = {"__enter__": TestUserDefinedContextRuntime.increment_enter_counter,
"__exit__": TestUserDefinedContextRuntime.increment_exit_counter}
TestUserDefinedContextRuntime.marker.configure_mock(**attrs)

with mock.patch("test_additional_execute_contextmanager.user_defined_contextmanager",
return_value=TestUserDefinedContextRuntime.marker):
configuration.conf.set(
"core", "additional_execute_contextmanager", ""
"test_additional_execute_contextmanager"
".user_defined_contextmanager")
sys.path.append(os.path.dirname(__file__))

with DAG(dag_id="context_runtime_dag", default_args=DEFAULT_ARGS):
op = self.MySimpleOperator(task_id="check_affected_value")
op.run(ignore_ti_state=True, ignore_first_depends_on_past=True)

assert TestUserDefinedContextRuntime.marker.call_count == 1
assert TestUserDefinedContextRuntime.enter_counter == 1
assert TestUserDefinedContextRuntime.exit_counter == 1

class MySimpleOperator(BaseOperator):
def execute(self, context):
TestUserDefinedContextRuntime.marker()

0 comments on commit de148b0

Please sign in to comment.