diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 2b8b935d78abf..8f9d71cfe7f43 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -2731,18 +2731,29 @@ def signal_handler(signum, frame): previous_state=TaskInstanceState.QUEUED, task_instance=self, session=session ) - # Execute the task + def _render_map_index(context: Context, *, jinja_env: jinja2.Environment | None) -> str | None: + """Render named map index if the DAG author defined map_index_template at the task level.""" + if jinja_env is None or (template := context.get("map_index_template")) is None: + return None + rendered_map_index = jinja_env.from_string(template).render(context) + log.debug("Map index rendered as %s", rendered_map_index) + return rendered_map_index + + # Execute the task. with set_current_context(context): - result = self._execute_task(context, task_orig) + try: + result = self._execute_task(context, task_orig) + except Exception: + # If the task failed, swallow rendering error so it doesn't mask the main error. + with contextlib.suppress(jinja2.TemplateSyntaxError, jinja2.UndefinedError): + self.rendered_map_index = _render_map_index(context, jinja_env=jinja_env) + raise + else: # If the task succeeded, render normally to let rendering error bubble up. + self.rendered_map_index = _render_map_index(context, jinja_env=jinja_env) # Run post_execute callback self.task.post_execute(context=context, result=result) - # DAG authors define map_index_template at the task level - if jinja_env is not None and (template := context.get("map_index_template")) is not None: - rendered_map_index = self.rendered_map_index = jinja_env.from_string(template).render(context) - self.log.info("Map index rendered as %s", rendered_map_index) - Stats.incr(f"operator_successes_{self.task.task_type}", tags=self.stats_tags) # Same metric with tagging Stats.incr("operator_successes", tags={**self.stats_tags, "task_type": self.task.task_type}) diff --git a/tests/models/test_mappedoperator.py b/tests/models/test_mappedoperator.py index e80a629794ae1..9f31652424aeb 100644 --- a/tests/models/test_mappedoperator.py +++ b/tests/models/test_mappedoperator.py @@ -631,6 +631,33 @@ def task1(map_name): return task1.expand(map_name=map_names) +def _create_named_map_index_renders_on_failure_classic(*, task_id, map_names, template): + class HasMapName(BaseOperator): + def __init__(self, *, map_name: str, **kwargs): + super().__init__(**kwargs) + self.map_name = map_name + + def execute(self, context): + context["map_name"] = self.map_name + raise AirflowSkipException("Imagine this task failed!") + + return HasMapName.partial(task_id=task_id, map_index_template=template).expand( + map_name=map_names, + ) + + +def _create_named_map_index_renders_on_failure_taskflow(*, task_id, map_names, template): + from airflow.operators.python import get_current_context + + @task(task_id=task_id, map_index_template=template) + def task1(map_name): + context = get_current_context() + context["map_name"] = map_name + raise AirflowSkipException("Imagine this task failed!") + + return task1.expand(map_name=map_names) + + @pytest.mark.parametrize( "template, expected_rendered_names", [ @@ -645,6 +672,8 @@ def task1(map_name): [ pytest.param(_create_mapped_with_name_template_classic, id="classic"), pytest.param(_create_mapped_with_name_template_taskflow, id="taskflow"), + pytest.param(_create_named_map_index_renders_on_failure_classic, id="classic-failure"), + pytest.param(_create_named_map_index_renders_on_failure_taskflow, id="taskflow-failure"), ], ) def test_expand_mapped_task_instance_with_named_index(