Skip to content

Commit

Permalink
Update logic to cover more cases of Undefined Jinja Variables.
Browse files Browse the repository at this point in the history
Add tests for valid and invalid check
  • Loading branch information
ashmeet13 committed Oct 6, 2020
1 parent fcb3b8d commit 60d021a
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 47 deletions.
99 changes: 64 additions & 35 deletions airflow/upgrade/rules/undefined_jinja_varaibles.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import jinja2
import six

from airflow import conf
from airflow.models import DagBag, TaskInstance
from airflow.upgrade.rules.base_rule import BaseRule
from airflow.utils import timezone
Expand All @@ -36,57 +37,85 @@ class UndefinedJinjaVariablesRule(BaseRule):
With this change a task will fail if it recieves any undefined variables.
"""

def check_rendered_content(self, rendered_content):
def _check_rendered_content(self, rendered_content):
"""Replicates the logic in BaseOperator.render_template() to
cover all the cases needed to be checked.
"""
if isinstance(rendered_content, six.string_types):
return set(re.findall(r"{{(.*?)}}", rendered_content))

elif isinstance(rendered_content, (tuple, list, set)):
parsed_templates = set()
debug_error_messages = set()
for element in rendered_content:
parsed_templates.union(self.check_rendered_content(element))
return parsed_templates
debug_error_messages.union(self._check_rendered_content(element))
return debug_error_messages

elif isinstance(rendered_content, dict):
parsed_templates = set()
debug_error_messages = set()
for key, value in rendered_content.items():
parsed_templates.union(self.check_rendered_content(str(value)))
return parsed_templates
debug_error_messages.union(self._check_rendered_content(str(value)))
return debug_error_messages

def check(self, dagbag=DagBag()):
dags = dagbag.dags
messages = []
for dag_id, dag in dags.items():
bracket_pattern = r"\[(.*?)\]"
dag.template_undefined = jinja2.DebugUndefined
for task in dag.tasks:
task_instance = TaskInstance(
task=task, execution_date=timezone.utcnow()
def _render_task_content(self, task, content, context):
completed_rendering = False
errors_while_rendering = []
while not completed_rendering:
# Catch errors such as {{ object.element }} where
# object is not defined
try:
renderend_content = task.render_template(content, context)
completed_rendering = True
except Exception as e:
undefined_variable = re.sub(" is undefined", "", str(e))
undefined_variable = re.sub("'", "", undefined_variable)
context[undefined_variable] = dict()
message = "Could not find the object '{}'".format(undefined_variable)
errors_while_rendering.append(message)
return renderend_content, errors_while_rendering

def _task_level_(self, task):
messages = {}
task_instance = TaskInstance(task=task, execution_date=timezone.utcnow())
context = task_instance.get_template_context()
for attr_name in task.template_fields:
content = getattr(task, attr_name)
if content:
rendered_content, errors_while_rendering = self._render_task_content(
task, content, context
)
debug_error_messages = list(
self._check_rendered_content(rendered_content)
)
template_context = task_instance.get_template_context()
messages[attr_name] = errors_while_rendering + debug_error_messages

return messages

rendered_content_collection = []
def _dag_level_(self, dag):
dag.template_undefined = jinja2.DebugUndefined
tasks = dag.tasks
messages = {}
for task in tasks:
error_messages = self._task_level_(task)
messages[task.task_id] = error_messages
return messages

for attr_name in task.template_fields:
content = getattr(task, attr_name)
if content:
rendered_content_collection.append(
task.render_template(content, template_context)
)
def check(self, dagbag=None):
if not dagbag:
dag_folder = conf.get("core", "dags_folder")
dagbag = DagBag(dag_folder)
dags = dagbag.dags
messages = []
for dag_id, dag in dags.items():
dag_messages = self._dag_level_(dag)

for rendered_content in rendered_content_collection:
undefined_variables = self.check_rendered_content(rendered_content)
for undefined_variable in undefined_variables:
result = re.findall(bracket_pattern, undefined_variable)
if result:
undefined_variable = result[0].strip("'")
new_msg = (
"Possible Undefined Jinja Variable -> DAG: {}, Task: {}, "
"Variable: {}".format(
dag_id, task.task_id, undefined_variable.strip()
for task_id, task_messages in dag_messages.items():
for attr_name, error_messages in task_messages.items():
for error_message in error_messages:
message = (
"Possible UndefinedJinjaVariable -> DAG: {}, Task: {}, "
"Attribute: {}, Error: {}".format(
dag_id, task_id, attr_name, error_message.strip()
)
)
messages.append(new_msg)
messages.append(message)
return messages
60 changes: 48 additions & 12 deletions tests/upgrade/rules/test_undefined_jinja_varaibles.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@
from airflow import DAG
from airflow.models import DagBag
from airflow.operators.bash_operator import BashOperator
from airflow.upgrade.rules.undefined_jinja_varaibles import UndefinedJinjaVariablesRule
from airflow.upgrade.rules.undefined_jinja_varaibles import \
UndefinedJinjaVariablesRule
from tests.models import DEFAULT_DATE


Expand All @@ -32,12 +33,18 @@ def setUpClass(cls):

def setUp(self):

self.dag = DAG(dag_id="test-undefined-jinja-variables", start_date=DEFAULT_DATE)
self.invalid_dag = DAG(
dag_id="test-undefined-jinja-variables", start_date=DEFAULT_DATE
)
self.valid_dag = DAG(
dag_id="test-defined-jinja-variables", start_date=DEFAULT_DATE
)

template_command = """
{% for i in range(5) %}
echo "{{ params.defined_variable }}"
echo "{{ execution_date.today }}"
echo "{{ execution_date.invalid_element }}"
echo "{{ params.undefined_variable }}"
echo "{{ foo }}"
{% endfor %}
Expand All @@ -47,29 +54,58 @@ def setUp(self):
task_id="templated_string",
depends_on_past=False,
bash_command=template_command,
env={"undefined_object": "{{ undefined_object.element }}"},
params={"defined_variable": "defined_value"},
dag=self.dag,
dag=self.invalid_dag,
)

self.dagbag = DagBag(dag_folder=self.empty_dir, include_examples=False)
self.dagbag.dags[self.dag.dag_id] = self.dag
BashOperator(
task_id="templated_string",
depends_on_past=False,
bash_command="echo",
env={"defined_object": "{{ params.element }}"},
params={
"element": "defined_value",
},
dag=self.valid_dag,
)

def test_check(self):
def test_invalid_check(self):
dagbag = DagBag(dag_folder=self.empty_dir, include_examples=False)
dagbag.dags[self.invalid_dag.dag_id] = self.invalid_dag
rule = UndefinedJinjaVariablesRule()

assert isinstance(rule.description, str)
assert isinstance(rule.title, str)

messages = rule.check(self.dagbag)
messages = rule.check(dagbag)

expected_messages = [
"Possible Undefined Jinja Variable -> DAG: test-undefined-jinja-variables, "
"Task: templated_string, Variable: undefined",
"Possible Undefined Jinja Variable -> DAG: test-undefined-jinja-variables, "
"Task: templated_string, Variable: foo",
"Possible UndefinedJinjaVariable -> DAG: test-undefined-jinja-variables, "
"Task: templated_string, Attribute: env, Error: Could not find the "
"object 'undefined_object",
"Possible UndefinedJinjaVariable -> DAG: test-undefined-jinja-variables, "
"Task: templated_string, Attribute: bash_command, Error: no such element: "
"dict object['undefined_variable']",
"Possible UndefinedJinjaVariable -> DAG: test-undefined-jinja-variables, "
"Task: templated_string, Attribute: bash_command, Error: no such element: "
"pendulum.pendulum.Pendulum object['invalid_element']",
"Possible UndefinedJinjaVariable -> DAG: test-undefined-jinja-variables, "
"Task: templated_string, Attribute: bash_command, Error: foo",
]

print(messages)
assert [m for m in messages if m in expected_messages], len(messages) == len(
expected_messages
)

def test_valid_check(self):
dagbag = DagBag(dag_folder=self.empty_dir, include_examples=False)
dagbag.dags[self.valid_dag.dag_id] = self.valid_dag
rule = UndefinedJinjaVariablesRule()

assert isinstance(rule.description, str)
assert isinstance(rule.title, str)

messages = rule.check(dagbag)

assert len(messages) == 0

0 comments on commit 60d021a

Please sign in to comment.