Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bugfix: Move rendering of map_index_template so it renders for failed tasks as long as it was defined before the point of failure #38902

Merged
merged 20 commits into from
Apr 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 18 additions & 7 deletions airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -2731,18 +2731,29 @@ def signal_handler(signum, frame):
previous_state=TaskInstanceState.QUEUED, task_instance=self, session=session
)

# Execute the task
def _render_map_index(context: Context, *, jinja_env: jinja2.Environment | None) -> str | None:
"""Render named map index if the DAG author defined map_index_template at the task level."""
if jinja_env is None or (template := context.get("map_index_template")) is None:
return None
rendered_map_index = jinja_env.from_string(template).render(context)
log.debug("Map index rendered as %s", rendered_map_index)
return rendered_map_index

# Execute the task.
with set_current_context(context):
result = self._execute_task(context, task_orig)
try:
result = self._execute_task(context, task_orig)
except Exception:
# If the task failed, swallow rendering error so it doesn't mask the main error.
with contextlib.suppress(jinja2.TemplateSyntaxError, jinja2.UndefinedError):
self.rendered_map_index = _render_map_index(context, jinja_env=jinja_env)
raise
else: # If the task succeeded, render normally to let rendering error bubble up.
self.rendered_map_index = _render_map_index(context, jinja_env=jinja_env)

# Run post_execute callback
self.task.post_execute(context=context, result=result)

# DAG authors define map_index_template at the task level
if jinja_env is not None and (template := context.get("map_index_template")) is not None:
rendered_map_index = self.rendered_map_index = jinja_env.from_string(template).render(context)
self.log.info("Map index rendered as %s", rendered_map_index)

Stats.incr(f"operator_successes_{self.task.task_type}", tags=self.stats_tags)
# Same metric with tagging
Stats.incr("operator_successes", tags={**self.stats_tags, "task_type": self.task.task_type})
Expand Down
29 changes: 29 additions & 0 deletions tests/models/test_mappedoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -631,6 +631,33 @@ def task1(map_name):
return task1.expand(map_name=map_names)


def _create_named_map_index_renders_on_failure_classic(*, task_id, map_names, template):
class HasMapName(BaseOperator):
def __init__(self, *, map_name: str, **kwargs):
super().__init__(**kwargs)
self.map_name = map_name

def execute(self, context):
context["map_name"] = self.map_name
raise AirflowSkipException("Imagine this task failed!")

return HasMapName.partial(task_id=task_id, map_index_template=template).expand(
map_name=map_names,
)


def _create_named_map_index_renders_on_failure_taskflow(*, task_id, map_names, template):
from airflow.operators.python import get_current_context

@task(task_id=task_id, map_index_template=template)
def task1(map_name):
context = get_current_context()
context["map_name"] = map_name
raise AirflowSkipException("Imagine this task failed!")

return task1.expand(map_name=map_names)


@pytest.mark.parametrize(
"template, expected_rendered_names",
[
Expand All @@ -645,6 +672,8 @@ def task1(map_name):
[
pytest.param(_create_mapped_with_name_template_classic, id="classic"),
pytest.param(_create_mapped_with_name_template_taskflow, id="taskflow"),
pytest.param(_create_named_map_index_renders_on_failure_classic, id="classic-failure"),
pytest.param(_create_named_map_index_renders_on_failure_taskflow, id="taskflow-failure"),
],
)
def test_expand_mapped_task_instance_with_named_index(
Expand Down