diff --git a/airflow/models/dag.py b/airflow/models/dag.py index 348e19d0ac0c4..a1908e34e7a3a 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -224,7 +224,7 @@ def __init__( end_date=None, # type: Optional[datetime] full_filepath=None, # type: Optional[str] template_searchpath=None, # type: Optional[Union[str, Iterable[str]]] - template_undefined=jinja2.Undefined, # type: Type[jinja2.Undefined] + template_undefined=None, # type: Optional[Type[jinja2.Undefined]] user_defined_macros=None, # type: Optional[Dict] user_defined_filters=None, # type: Optional[Dict] default_args=None, # type: Optional[Dict] @@ -807,7 +807,7 @@ def get_template_env(self): # type: () -> jinja2.Environment # Default values (for backward compatibility) jinja_env_options = { 'loader': jinja2.FileSystemLoader(searchpath), - 'undefined': self.template_undefined, + 'undefined': self.template_undefined or jinja2.Undefined, 'extensions': ["jinja2.ext.do"], 'cache_size': 0 } diff --git a/airflow/upgrade/rules/undefined_jinja_varaibles.py b/airflow/upgrade/rules/undefined_jinja_varaibles.py new file mode 100644 index 0000000000000..b97cfbc696996 --- /dev/null +++ b/airflow/upgrade/rules/undefined_jinja_varaibles.py @@ -0,0 +1,153 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import absolute_import + +import re + +import jinja2 +import six + +from airflow import conf +from airflow.models import DagBag, TaskInstance +from airflow.upgrade.rules.base_rule import BaseRule +from airflow.utils import timezone + + +class UndefinedJinjaVariablesRule(BaseRule): + + title = "Jinja Template Variables cannot be undefined" + + description = """\ +The default behavior for DAG's Jinja templates has changed. Now, more restrictive validation +of non-existent variables is applied - `jinja2.StrictUndefined`. + +The user should do either of the following to fix this - +1. Fix the Jinja Templates by defining every variable or providing default values +2. Explicitly declare `template_undefined=jinja2.Undefined` while defining the DAG +""" + + def _check_rendered_content(self, rendered_content, seen_oids=None): + """Replicates the logic in BaseOperator.render_template() to + cover all the cases needed to be checked. + """ + if isinstance(rendered_content, six.string_types): + return set(re.findall(r"{{(.*?)}}", rendered_content)) + + elif isinstance(rendered_content, (int, float, bool)): + return set() + + elif isinstance(rendered_content, (tuple, list, set)): + debug_error_messages = set() + for element in rendered_content: + debug_error_messages.update(self._check_rendered_content(element)) + return debug_error_messages + + elif isinstance(rendered_content, dict): + debug_error_messages = set() + for key, value in rendered_content.items(): + debug_error_messages.update(self._check_rendered_content(value)) + return debug_error_messages + + else: + if seen_oids is None: + seen_oids = set() + return self._nested_check_rendered(rendered_content, seen_oids) + + def _nested_check_rendered(self, rendered_content, seen_oids): + debug_error_messages = set() + if id(rendered_content) not in seen_oids: + seen_oids.add(id(rendered_content)) + nested_template_fields = rendered_content.template_fields + for attr_name in nested_template_fields: + nested_rendered_content = getattr(rendered_content, attr_name) + + if nested_rendered_content: + errors = list( + self._check_rendered_content(nested_rendered_content, seen_oids) + ) + for i in range(len(errors)): + errors[i].strip() + errors[i] += " NestedTemplateField={}".format(attr_name) + debug_error_messages.update(errors) + return debug_error_messages + + def _render_task_content(self, task, content, context): + completed_rendering = False + errors_while_rendering = [] + while not completed_rendering: + # Catch errors such as {{ object.element }} where + # object is not defined + try: + renderend_content = task.render_template(content, context) + completed_rendering = True + except Exception as e: + undefined_variable = re.sub(" is undefined", "", str(e)) + undefined_variable = re.sub("'", "", undefined_variable) + context[undefined_variable] = dict() + message = "Could not find the object '{}'".format(undefined_variable) + errors_while_rendering.append(message) + return renderend_content, errors_while_rendering + + def iterate_over_template_fields(self, task): + messages = {} + task_instance = TaskInstance(task=task, execution_date=timezone.utcnow()) + context = task_instance.get_template_context() + for attr_name in task.template_fields: + content = getattr(task, attr_name) + if content: + rendered_content, errors_while_rendering = self._render_task_content( + task, content, context + ) + debug_error_messages = list( + self._check_rendered_content(rendered_content, set()) + ) + messages[attr_name] = errors_while_rendering + debug_error_messages + + return messages + + def iterate_over_dag_tasks(self, dag): + dag.template_undefined = jinja2.DebugUndefined + tasks = dag.tasks + messages = {} + for task in tasks: + error_messages = self.iterate_over_template_fields(task) + messages[task.task_id] = error_messages + return messages + + def check(self, dagbag=None): + if not dagbag: + dag_folder = conf.get("core", "dags_folder") + dagbag = DagBag(dag_folder) + dags = dagbag.dags + messages = [] + for dag_id, dag in dags.items(): + if dag.template_undefined: + continue + dag_messages = self.iterate_over_dag_tasks(dag) + + for task_id, task_messages in dag_messages.items(): + for attr_name, error_messages in task_messages.items(): + for error_message in error_messages: + message = ( + "Possible UndefinedJinjaVariable -> DAG: {}, Task: {}, " + "Attribute: {}, Error: {}".format( + dag_id, task_id, attr_name, error_message.strip() + ) + ) + messages.append(message) + return messages diff --git a/tests/upgrade/rules/test_undefined_jinja_varaibles.py b/tests/upgrade/rules/test_undefined_jinja_varaibles.py new file mode 100644 index 0000000000000..83f99a33044e7 --- /dev/null +++ b/tests/upgrade/rules/test_undefined_jinja_varaibles.py @@ -0,0 +1,192 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from tempfile import mkdtemp +from unittest import TestCase + +import jinja2 + +from airflow import DAG +from airflow.models import DagBag +from airflow.operators.bash_operator import BashOperator +from airflow.upgrade.rules.undefined_jinja_varaibles import UndefinedJinjaVariablesRule +from tests.models import DEFAULT_DATE + + +class ClassWithCustomAttributes: + """Class for testing purpose: allows to create objects with custom attributes in one single statement.""" + + def __init__(self, **kwargs): + for key, value in kwargs.items(): + setattr(self, key, value) + + def __str__(self): + return "{}({})".format(ClassWithCustomAttributes.__name__, str(self.__dict__)) + + def __repr__(self): + return self.__str__() + + def __eq__(self, other): + return self.__dict__ == other.__dict__ + + def __ne__(self, other): + return not self.__eq__(other) + + +class TestUndefinedJinjaVariablesRule(TestCase): + @classmethod + def setUpClass(cls): + cls.empty_dir = mkdtemp() + + def setUpValidDag(self): + self.valid_dag = DAG( + dag_id="test-defined-jinja-variables", start_date=DEFAULT_DATE + ) + + BashOperator( + task_id="templated_string", + depends_on_past=False, + bash_command="echo", + env={ + "integer": "{{ params.integer }}", + "float": "{{ params.float }}", + "string": "{{ params.string }}", + "boolean": "{{ params.boolean }}", + }, + params={ + "integer": 1, + "float": 1.0, + "string": "test_string", + "boolean": True, + }, + dag=self.valid_dag, + ) + + def setUpDagToSkip(self): + self.skip_dag = DAG( + dag_id="test-defined-jinja-variables", + start_date=DEFAULT_DATE, + template_undefined=jinja2.Undefined, + ) + + BashOperator( + task_id="templated_string", + depends_on_past=False, + bash_command="{{ undefined }}", + dag=self.skip_dag, + ) + + def setUpInvalidDag(self): + self.invalid_dag = DAG( + dag_id="test-undefined-jinja-variables", start_date=DEFAULT_DATE + ) + + invalid_template_command = """ + {% for i in range(5) %} + echo "{{ params.defined_variable }}" + echo "{{ execution_date.today }}" + echo "{{ execution_date.invalid_element }}" + echo "{{ params.undefined_variable }}" + echo "{{ foo }}" + {% endfor %} + """ + + nested_validation = ClassWithCustomAttributes( + nested1=ClassWithCustomAttributes( + att1="{{ nested.undefined }}", template_fields=["att1"] + ), + nested2=ClassWithCustomAttributes( + att2="{{ bar }}", template_fields=["att2"] + ), + template_fields=["nested1", "nested2"], + ) + + BashOperator( + task_id="templated_string", + depends_on_past=False, + bash_command=invalid_template_command, + env={ + "undefined_object": "{{ undefined_object.element }}", + "nested_object": nested_validation, + }, + params={"defined_variable": "defined_value"}, + dag=self.invalid_dag, + ) + + def setUp(self): + self.setUpValidDag() + self.setUpDagToSkip() + self.setUpInvalidDag() + + def test_description_and_title_is_defined(self): + rule = UndefinedJinjaVariablesRule() + assert isinstance(rule.description, str) + assert isinstance(rule.title, str) + + def test_valid_check(self): + dagbag = DagBag(dag_folder=self.empty_dir, include_examples=False) + dagbag.dags[self.valid_dag.dag_id] = self.valid_dag + rule = UndefinedJinjaVariablesRule() + + messages = rule.check(dagbag) + + assert len(messages) == 0 + + def test_skipping_dag_check(self): + dagbag = DagBag(dag_folder=self.empty_dir, include_examples=False) + dagbag.dags[self.skip_dag.dag_id] = self.skip_dag + rule = UndefinedJinjaVariablesRule() + + messages = rule.check(dagbag) + + assert len(messages) == 0 + + def test_invalid_check(self): + dagbag = DagBag(dag_folder=self.empty_dir, include_examples=False) + dagbag.dags[self.invalid_dag.dag_id] = self.invalid_dag + rule = UndefinedJinjaVariablesRule() + + messages = rule.check(dagbag) + + expected_messages = [ + "Possible UndefinedJinjaVariable -> DAG: test-undefined-jinja-variables, " + "Task: templated_string, Attribute: bash_command, Error: no such element: " + "dict object['undefined_variable']", + "Possible UndefinedJinjaVariable -> DAG: test-undefined-jinja-variables, " + "Task: templated_string, Attribute: bash_command, Error: no such element: " + "pendulum.pendulum.Pendulum object['invalid_element']", + "Possible UndefinedJinjaVariable -> DAG: test-undefined-jinja-variables, " + "Task: templated_string, Attribute: bash_command, Error: foo", + "Possible UndefinedJinjaVariable -> DAG: test-undefined-jinja-variables, " + "Task: templated_string, Attribute: env, Error: Could not find the " + "object 'undefined_object", + "Possible UndefinedJinjaVariable -> DAG: test-undefined-jinja-variables, " + "Task: templated_string, Attribute: env, Error: Could not find the object 'nested'", + "Possible UndefinedJinjaVariable -> DAG: test-undefined-jinja-variables, " + "Task: templated_string, Attribute: env, Error: bar NestedTemplateField=att2 " + "NestedTemplateField=nested2", + "Possible UndefinedJinjaVariable -> DAG: test-undefined-jinja-variables, " + "Task: templated_string, Attribute: env, Error: no such element: " + "dict object['undefined'] NestedTemplateField=att1 NestedTemplateField=nested1", + "Possible UndefinedJinjaVariable -> DAG: test-undefined-jinja-variables, " + "Task: templated_string, Attribute: env, Error: no such element: dict object['element']", + ] + + assert len(messages) == len(expected_messages) + assert [m for m in messages if m in expected_messages], len(messages) == len( + expected_messages + )