Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ClusterPolicyViolation support to airflow local settings #10282

Merged
merged 22 commits into from
Aug 12, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions airflow/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,10 @@ class AirflowDagCycleException(AirflowException):
"""Raise when there is a cycle in Dag definition"""


class AirflowClusterPolicyViolation(AirflowException):
    """Raised when a DAG definition violates a configured cluster policy check."""


class DagNotFound(AirflowNotFoundException):
"""Raise when a DAG is not available in the system"""

Expand Down
7 changes: 4 additions & 3 deletions airflow/models/dagbag.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from airflow import settings
from airflow.configuration import conf
from airflow.dag.base_dag import BaseDagBag
from airflow.exceptions import AirflowDagCycleException
from airflow.exceptions import AirflowClusterPolicyViolation, AirflowDagCycleException
from airflow.plugins_manager import integrate_dag_plugins
from airflow.stats import Stats
from airflow.utils import timezone
Expand Down Expand Up @@ -343,9 +343,10 @@ def _process_modules(self, filepath, mods, file_last_changed_on_disk):
self.import_errors[dag.full_filepath] = f"Invalid Cron expression: {cron_e}"
self.file_last_changed[dag.full_filepath] = \
file_last_changed_on_disk
except AirflowDagCycleException as cycle_exception:
except (AirflowDagCycleException,
AirflowClusterPolicyViolation) as exception:
jaketf marked this conversation as resolved.
Show resolved Hide resolved
self.log.exception("Failed to bag_dag: %s", dag.full_filepath)
self.import_errors[dag.full_filepath] = str(cycle_exception)
self.import_errors[dag.full_filepath] = str(exception)
self.file_last_changed[dag.full_filepath] = file_last_changed_on_disk
return found_dags

Expand Down
36 changes: 36 additions & 0 deletions docs/concepts.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1156,7 +1156,11 @@ state.

Cluster Policy
==============
Cluster policies provide an interface for taking action on every Airflow task
either at DAG load time or just before task execution.

Cluster Policies for Task Mutation
-----------------------------------
In case you want to apply cluster-wide mutations to the Airflow tasks,
you can either mutate the task right after the DAG is loaded or
mutate the task instance before task execution.
Expand Down Expand Up @@ -1207,6 +1211,38 @@ queue during retries:
ti.queue = 'retry_queue'


Cluster Policies for Custom Task Checks
-------------------------------------------
You may also use Cluster Policies to apply cluster-wide checks on Airflow
tasks. You can raise :class:`~airflow.exceptions.AirflowClusterPolicyViolation`
in a policy or task mutation hook (described above) to prevent a DAG from being
imported or prevent a task from being executed if the task is not compliant with
your check.

These checks are intended to help teams using Airflow to protect against common
beginner errors that may get past a code reviewer, rather than as technical
security controls.

For example, to prevent tasks from running without an explicitly set Airflow owner:

.. literalinclude:: /../tests/cluster_policies/__init__.py
:language: python
:start-after: [START example_cluster_policy_rule]
:end-before: [END example_cluster_policy_rule]

If you have multiple checks to apply, it is best practice to curate these rules
in a separate python module and have a single policy / task mutation hook that
performs multiple of these custom checks and aggregates the various error
messages so that a single ``AirflowClusterPolicyViolation`` can be reported in
the UI (and import errors table in the database).

For example, in ``airflow_local_settings.py``:

.. literalinclude:: /../tests/cluster_policies/__init__.py
:language: python
:start-after: [START example_list_of_cluster_policy_rules]
:end-before: [END example_list_of_cluster_policy_rules]

Where to put ``airflow_local_settings.py``?
-------------------------------------------

Expand Down
59 changes: 59 additions & 0 deletions tests/cluster_policies/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Callable, List

from airflow.configuration import conf
from airflow.exceptions import AirflowClusterPolicyViolation
from airflow.models.baseoperator import BaseOperator


# [START example_cluster_policy_rule]
def task_must_have_owners(task: BaseOperator) -> None:
    """Example cluster policy rule: require an explicit, non-default task owner.

    :param task: the task to validate
    :raises AirflowClusterPolicyViolation: if the task has no owner, or its
        owner is still the configured ``[operators] default_owner`` value
    """
    default_owner = conf.get('operators', 'default_owner')
    # Compare case-insensitively on both sides so a capitalized configured
    # default (e.g. "Airflow") is still recognized as "no explicit owner".
    if not task.owner or task.owner.lower() == default_owner.lower():
        raise AirflowClusterPolicyViolation(
            f'''Task must have non-None non-default owner. Current value: {task.owner}''')
# [END example_cluster_policy_rule]


# [START example_list_of_cluster_policy_rules]
# Registry of per-task policy rules. Each rule takes a BaseOperator and
# either returns None (compliant) or raises AirflowClusterPolicyViolation.
TASK_RULES: List[Callable[[BaseOperator], None]] = [
    task_must_have_owners,
]


def _check_task_rules(current_task: BaseOperator):
    """Run every rule in TASK_RULES against the task and aggregate failures.

    All individual violations are collected and re-raised as one
    AirflowClusterPolicyViolation so a single consolidated error is reported.
    """
    violations = []
    for task_rule in TASK_RULES:
        try:
            task_rule(current_task)
        except AirflowClusterPolicyViolation as violation:
            violations.append(str(violation))
    if not violations:
        return
    bullet_list = " * " + "\n * ".join(violations)
    raise AirflowClusterPolicyViolation(
        f"DAG policy violation (DAG ID: {current_task.dag_id}, Path: {current_task.dag.filepath}):\n"
        f"Notices:\n"
        f"{bullet_list}")


def cluster_policy(task: BaseOperator):
    """Cluster policy entry point: apply every custom task rule to ``task``.

    Raises AirflowClusterPolicyViolation (via _check_task_rules) when any
    rule fails, e.g. when the task owner is unset or left at the default.
    """
    _check_task_rules(task)
# [END example_list_of_cluster_policy_rules]
32 changes: 32 additions & 0 deletions tests/dags/test_missing_owner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from datetime import timedelta

from airflow import DAG
from airflow.operators.dummy_operator import DummyOperator
from airflow.utils.dates import days_ago

# DAG whose single task deliberately has no explicit owner, so it falls back
# to the configured default owner; used to exercise cluster policy rejection.
with DAG(
    dag_id="test_missing_owner",
    start_date=days_ago(2),
    schedule_interval="0 0 * * *",
    dagrun_timeout=timedelta(minutes=60),
    tags=["example"],
) as dag:
    task_without_owner = DummyOperator(task_id="test_task")
32 changes: 32 additions & 0 deletions tests/dags/test_with_non_default_owner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from datetime import timedelta

from airflow import DAG
from airflow.operators.dummy_operator import DummyOperator
from airflow.utils.dates import days_ago

# DAG whose single task sets an explicit non-default owner; used to verify
# that a compliant DAG passes the cluster policy without import errors.
with DAG(
    dag_id="test_with_non_default_owner",
    start_date=days_ago(2),
    schedule_interval="0 0 * * *",
    dagrun_timeout=timedelta(minutes=60),
    tags=["example"],
) as dag:
    owned_task = DummyOperator(task_id="test_task", owner="John")
33 changes: 32 additions & 1 deletion tests/models/test_dagbag.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
Expand Down Expand Up @@ -33,6 +32,7 @@
from airflow.models.serialized_dag import SerializedDagModel
from airflow.utils.dates import timezone as tz
from airflow.utils.session import create_session
from tests import cluster_policies
from tests.models import TEST_DAGS_FOLDER
from tests.test_utils import db
from tests.test_utils.asserts import assert_queries_count
Expand Down Expand Up @@ -700,3 +700,34 @@ def test_get_dag_with_dag_serialization(self):

self.assertCountEqual(updated_ser_dag_1.tags, ["example", "new_tag"])
self.assertGreater(updated_ser_dag_1_update_time, ser_dag_1_update_time)

    # Swap in the example cluster policy for the duration of this test.
    @patch("airflow.settings.policy", cluster_policies.cluster_policy)
    def test_cluster_policy_violation(self):
        """test that file processing results in import error when task does not
        obey cluster policy.
        """
        dag_file = os.path.join(TEST_DAGS_FOLDER, "test_missing_owner.py")

        dagbag = DagBag(dag_folder=dag_file)
        # The non-compliant DAG must not be registered at all.
        self.assertEqual(set(), set(dagbag.dag_ids))
        # NOTE(review): expected text assumes [operators] default_owner is
        # "airflow" in the test configuration — confirm if this test flakes.
        expected_import_errors = {
            dag_file: (
                f"""DAG policy violation (DAG ID: test_missing_owner, Path: {dag_file}):\n"""
                """Notices:\n"""
                """ * Task must have non-None non-default owner. Current value: airflow"""
            )
        }
        self.assertEqual(expected_import_errors, dagbag.import_errors)

    # Swap in the example cluster policy for the duration of this test.
    @patch("airflow.settings.policy", cluster_policies.cluster_policy)
    def test_cluster_policy_obeyed(self):
        """test that dag successfully imported without import errors when tasks
        obey cluster policy.
        """
        dag_file = os.path.join(TEST_DAGS_FOLDER,
                                "test_with_non_default_owner.py")

        dagbag = DagBag(dag_folder=dag_file)
        # The compliant DAG loads normally and produces no import errors.
        self.assertEqual({"test_with_non_default_owner"}, set(dagbag.dag_ids))

        self.assertEqual({}, dagbag.import_errors)