Skip to content

Commit

Permalink
Bugfix: Move rendering of map_index_template so it renders for fail…
Browse files Browse the repository at this point in the history
…ed tasks as long as it was defined before the point of failure (#38902)

Co-authored-by: Tzu-ping Chung <uranusjr@gmail.com>
  • Loading branch information
TJaniF and uranusjr committed Apr 15, 2024
1 parent 1f0f907 commit 456ec48
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 7 deletions.
25 changes: 18 additions & 7 deletions airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -2731,18 +2731,29 @@ def signal_handler(signum, frame):
previous_state=TaskInstanceState.QUEUED, task_instance=self, session=session
)

# Execute the task
def _render_map_index(context: Context, *, jinja_env: jinja2.Environment | None) -> str | None:
"""Render named map index if the DAG author defined map_index_template at the task level."""
if jinja_env is None or (template := context.get("map_index_template")) is None:
return None
rendered_map_index = jinja_env.from_string(template).render(context)
log.debug("Map index rendered as %s", rendered_map_index)
return rendered_map_index

# Execute the task.
with set_current_context(context):
result = self._execute_task(context, task_orig)
try:
result = self._execute_task(context, task_orig)
except Exception:
# If the task failed, swallow rendering error so it doesn't mask the main error.
with contextlib.suppress(jinja2.TemplateSyntaxError, jinja2.UndefinedError):
self.rendered_map_index = _render_map_index(context, jinja_env=jinja_env)
raise
else: # If the task succeeded, render normally to let rendering error bubble up.
self.rendered_map_index = _render_map_index(context, jinja_env=jinja_env)

# Run post_execute callback
self.task.post_execute(context=context, result=result)

# DAG authors define map_index_template at the task level
if jinja_env is not None and (template := context.get("map_index_template")) is not None:
rendered_map_index = self.rendered_map_index = jinja_env.from_string(template).render(context)
self.log.info("Map index rendered as %s", rendered_map_index)

Stats.incr(f"operator_successes_{self.task.task_type}", tags=self.stats_tags)
# Same metric with tagging
Stats.incr("operator_successes", tags={**self.stats_tags, "task_type": self.task.task_type})
Expand Down
29 changes: 29 additions & 0 deletions tests/models/test_mappedoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -631,6 +631,33 @@ def task1(map_name):
return task1.expand(map_name=map_names)


def _create_named_map_index_renders_on_failure_classic(*, task_id, map_names, template):
class HasMapName(BaseOperator):
def __init__(self, *, map_name: str, **kwargs):
super().__init__(**kwargs)
self.map_name = map_name

def execute(self, context):
context["map_name"] = self.map_name
raise AirflowSkipException("Imagine this task failed!")

return HasMapName.partial(task_id=task_id, map_index_template=template).expand(
map_name=map_names,
)


def _create_named_map_index_renders_on_failure_taskflow(*, task_id, map_names, template):
from airflow.operators.python import get_current_context

@task(task_id=task_id, map_index_template=template)
def task1(map_name):
context = get_current_context()
context["map_name"] = map_name
raise AirflowSkipException("Imagine this task failed!")

return task1.expand(map_name=map_names)


@pytest.mark.parametrize(
"template, expected_rendered_names",
[
Expand All @@ -645,6 +672,8 @@ def task1(map_name):
[
pytest.param(_create_mapped_with_name_template_classic, id="classic"),
pytest.param(_create_mapped_with_name_template_taskflow, id="taskflow"),
pytest.param(_create_named_map_index_renders_on_failure_classic, id="classic-failure"),
pytest.param(_create_named_map_index_renders_on_failure_taskflow, id="taskflow-failure"),
],
)
def test_expand_mapped_task_instance_with_named_index(
Expand Down

0 comments on commit 456ec48

Please sign in to comment.