Skip to content

Commit

Permalink
1. Add logic to cover the case of nested template rendering.
Browse files Browse the repository at this point in the history
2. Add tests for the same
3. Add descriptive function names
  • Loading branch information
ashmeet13 committed Oct 21, 2020
1 parent 0ee69ad commit 08b0f75
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 29 deletions.
39 changes: 31 additions & 8 deletions airflow/upgrade/rules/undefined_jinja_varaibles.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class UndefinedJinjaVariablesRule(BaseRule):
With this change a task will fail if it recieves any undefined variables.
"""

def _check_rendered_content(self, rendered_content):
def _check_rendered_content(self, rendered_content, seen_oids=None):
"""Replicates the logic in BaseOperator.render_template() to
cover all the cases needed to be checked.
"""
Expand All @@ -47,15 +47,38 @@ def _check_rendered_content(self, rendered_content):
elif isinstance(rendered_content, (tuple, list, set)):
debug_error_messages = set()
for element in rendered_content:
debug_error_messages.union(self._check_rendered_content(element))
debug_error_messages.update(self._check_rendered_content(element))
return debug_error_messages

elif isinstance(rendered_content, dict):
debug_error_messages = set()
for key, value in rendered_content.items():
debug_error_messages.union(self._check_rendered_content(str(value)))
debug_error_messages.update(self._check_rendered_content(value))
return debug_error_messages

else:
if seen_oids is None:
seen_oids = set()
return self._nested_check_rendered(rendered_content, seen_oids)

def _nested_check_rendered(self, rendered_content, seen_oids):
debug_error_messages = set()
if id(rendered_content) not in seen_oids:
seen_oids.add(id(rendered_content))
nested_template_fields = rendered_content.template_fields
for attr_name in nested_template_fields:
nested_rendered_content = getattr(rendered_content, attr_name)

if nested_rendered_content:
errors = list(
self._check_rendered_content(nested_rendered_content, seen_oids)
)
for i in range(len(errors)):
errors[i].strip()
errors[i] += " NestedTemplateField={}".format(attr_name)
debug_error_messages.update(errors)
return debug_error_messages

def _render_task_content(self, task, content, context):
completed_rendering = False
errors_while_rendering = []
Expand All @@ -73,7 +96,7 @@ def _render_task_content(self, task, content, context):
errors_while_rendering.append(message)
return renderend_content, errors_while_rendering

def _task_level_(self, task):
def iterate_over_template_fields(self, task):
messages = {}
task_instance = TaskInstance(task=task, execution_date=timezone.utcnow())
context = task_instance.get_template_context()
Expand All @@ -84,18 +107,18 @@ def _task_level_(self, task):
task, content, context
)
debug_error_messages = list(
self._check_rendered_content(rendered_content)
self._check_rendered_content(rendered_content, set())
)
messages[attr_name] = errors_while_rendering + debug_error_messages

return messages

def _dag_level_(self, dag):
def iterate_over_dag_tasks(self, dag):
dag.template_undefined = jinja2.DebugUndefined
tasks = dag.tasks
messages = {}
for task in tasks:
error_messages = self._task_level_(task)
error_messages = self.iterate_over_template_fields(task)
messages[task.task_id] = error_messages
return messages

Expand All @@ -106,7 +129,7 @@ def check(self, dagbag=None):
dags = dagbag.dags
messages = []
for dag_id, dag in dags.items():
dag_messages = self._dag_level_(dag)
dag_messages = self.iterate_over_dag_tasks(dag)

for task_id, task_messages in dag_messages.items():
for attr_name, error_messages in task_messages.items():
Expand Down
99 changes: 78 additions & 21 deletions tests/upgrade/rules/test_undefined_jinja_varaibles.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,21 +26,61 @@
from tests.models import DEFAULT_DATE


class TestConnTypeIsNotNullableRule(TestCase):
class ClassWithCustomAttributes:
"""Class for testing purpose: allows to create objects with custom attributes in one single statement."""

def __init__(self, **kwargs):
for key, value in kwargs.items():
setattr(self, key, value)

def __str__(self):
return "{}({})".format(ClassWithCustomAttributes.__name__, str(self.__dict__))

def __repr__(self):
return self.__str__()

def __eq__(self, other):
return self.__dict__ == other.__dict__

def __ne__(self, other):
return not self.__eq__(other)


class TestUndefinedJinjaVariablesRule(TestCase):
@classmethod
def setUpClass(cls):
cls.empty_dir = mkdtemp()

def setUp(self):
def setUpValidDag(self):
self.valid_dag = DAG(
dag_id="test-defined-jinja-variables", start_date=DEFAULT_DATE
)

BashOperator(
task_id="templated_string",
depends_on_past=False,
bash_command="echo",
env={
"integer": "{{ params.integer }}",
"float": "{{ params.float }}",
"string": "{{ params.string }}",
"boolean": "{{ params.boolean }}",
},
params={
"integer": 1,
"float": 1.0,
"string": "test_string",
"boolean": True,
},
dag=self.valid_dag,
)

def setUpInvalidDag(self):
self.invalid_dag = DAG(
dag_id="test-undefined-jinja-variables", start_date=DEFAULT_DATE
)
self.valid_dag = DAG(
dag_id="test-defined-jinja-variables", start_date=DEFAULT_DATE
)

template_command = """
invalid_template_command = """
{% for i in range(5) %}
echo "{{ params.defined_variable }}"
echo "{{ execution_date.today }}"
Expand All @@ -50,26 +90,32 @@ def setUp(self):
{% endfor %}
"""

BashOperator(
task_id="templated_string",
depends_on_past=False,
bash_command=template_command,
env={"undefined_object": "{{ undefined_object.element }}"},
params={"defined_variable": "defined_value"},
dag=self.invalid_dag,
nested_validation = ClassWithCustomAttributes(
nested1=ClassWithCustomAttributes(
att1="{{ nested.undefined }}", template_fields=["att1"]
),
nested2=ClassWithCustomAttributes(
att2="{{ bar }}", template_fields=["att2"]
),
template_fields=["nested1", "nested2"],
)

BashOperator(
task_id="templated_string",
depends_on_past=False,
bash_command="echo",
env={"defined_object": "{{ params.element }}"},
params={
"element": "defined_value",
bash_command=invalid_template_command,
env={
"undefined_object": "{{ undefined_object.element }}",
"nested_object": nested_validation,
},
dag=self.valid_dag,
params={"defined_variable": "defined_value"},
dag=self.invalid_dag,
)

def setUp(self):
self.setUpValidDag()
self.setUpInvalidDag()

def test_invalid_check(self):
dagbag = DagBag(dag_folder=self.empty_dir, include_examples=False)
dagbag.dags[self.invalid_dag.dag_id] = self.invalid_dag
Expand All @@ -81,9 +127,6 @@ def test_invalid_check(self):
messages = rule.check(dagbag)

expected_messages = [
"Possible UndefinedJinjaVariable -> DAG: test-undefined-jinja-variables, "
"Task: templated_string, Attribute: env, Error: Could not find the "
"object 'undefined_object",
"Possible UndefinedJinjaVariable -> DAG: test-undefined-jinja-variables, "
"Task: templated_string, Attribute: bash_command, Error: no such element: "
"dict object['undefined_variable']",
Expand All @@ -92,8 +135,22 @@ def test_invalid_check(self):
"pendulum.pendulum.Pendulum object['invalid_element']",
"Possible UndefinedJinjaVariable -> DAG: test-undefined-jinja-variables, "
"Task: templated_string, Attribute: bash_command, Error: foo",
"Possible UndefinedJinjaVariable -> DAG: test-undefined-jinja-variables, "
"Task: templated_string, Attribute: env, Error: Could not find the "
"object 'undefined_object",
"Possible UndefinedJinjaVariable -> DAG: test-undefined-jinja-variables, "
"Task: templated_string, Attribute: env, Error: Could not find the object 'nested'",
"Possible UndefinedJinjaVariable -> DAG: test-undefined-jinja-variables, "
"Task: templated_string, Attribute: env, Error: bar NestedTemplateField=att2 "
"NestedTemplateField=nested2",
"Possible UndefinedJinjaVariable -> DAG: test-undefined-jinja-variables, "
"Task: templated_string, Attribute: env, Error: no such element: "
"dict object['undefined'] NestedTemplateField=att1 NestedTemplateField=nested1",
"Possible UndefinedJinjaVariable -> DAG: test-undefined-jinja-variables, "
"Task: templated_string, Attribute: env, Error: no such element: dict object['element']",
]

assert len(messages) == len(expected_messages)
assert [m for m in messages if m in expected_messages], len(messages) == len(
expected_messages
)
Expand Down

0 comments on commit 08b0f75

Please sign in to comment.