From 810ccc04cf6dd249998c49ad758fef79e6f3ce83 Mon Sep 17 00:00:00 2001 From: Ash Berlin-Taylor Date: Wed, 20 Apr 2022 14:29:11 +0100 Subject: [PATCH 1/4] Fix TI failure handling when task cannot be unmapped. At first glance this looks like a lot of unrelated changes, but it is all related to handling errors in unmapping: - Ensure that SimpleTaskInstance (and thus the Zombie callback) knows about map_index, and simplify the code for SimpleTaskInstance -- no need for properties, plain attributes work. - Be able to create a TaskFail from a TI, not a Task. This is so that we can create the TaskFail with the mapped task so we can delay unmapping the task in TI.handle_failure as long as possible. - Change email_alert and get_email_subject_content to take the task so we can pass the unmapped Task around. --- airflow/dag_processing/processor.py | 2 +- airflow/models/taskfail.py | 14 +-- airflow/models/taskinstance.py | 165 +++++++++------------------- tests/jobs/test_scheduler_job.py | 1 + 4 files changed, 62 insertions(+), 120 deletions(-) diff --git a/airflow/dag_processing/processor.py b/airflow/dag_processing/processor.py index d60d99877f40c..469b55cfeb621 100644 --- a/airflow/dag_processing/processor.py +++ b/airflow/dag_processing/processor.py @@ -604,7 +604,7 @@ def _execute_task_callbacks(self, dagbag: DagBag, request: TaskCallbackRequest): if simple_ti.task_id in dag.task_ids: task = dag.get_task(simple_ti.task_id) if request.is_failure_callback: - ti = TI(task, run_id=simple_ti.run_id) + ti = TI(task, run_id=simple_ti.run_id, map_index=simple_ti.map_index) # TODO: Use simple_ti to improve performance here in the future ti.refresh_from_db() ti.handle_failure_with_callback(error=request.msg, test_mode=self.UNIT_TEST_MODE) diff --git a/airflow/models/taskfail.py b/airflow/models/taskfail.py index a4bd102b5696f..f7de99c308cac 100644 --- a/airflow/models/taskfail.py +++ b/airflow/models/taskfail.py @@ -63,13 +63,13 @@ class TaskFail(Base): viewonly=True, ) - def 
__init__(self, task, run_id, start_date, end_date, map_index): - self.dag_id = task.dag_id - self.task_id = task.task_id - self.run_id = run_id - self.map_index = map_index - self.start_date = start_date - self.end_date = end_date + def __init__(self, ti): + self.dag_id = ti.dag_id + self.task_id = ti.task_id + self.run_id = ti.run_id + self.map_index = ti.map_index + self.start_date = ti.start_date + self.end_date = ti.end_date if self.end_date and self.start_date: self.duration = int((self.end_date - self.start_date).total_seconds()) else: diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 48d3a047fb183..a3844b0feab7a 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -20,6 +20,7 @@ import hashlib import logging import math +import operator import os import pickle import signal @@ -133,6 +134,7 @@ if TYPE_CHECKING: + from airflow.models.baseoperator import BaseOperator from airflow.models.dag import DAG, DagModel from airflow.models.dagrun import DagRun from airflow.models.operator import Operator @@ -1901,24 +1903,15 @@ def handle_failure( if not test_mode: self.refresh_from_db(session) - task = self.task.unmap() self.end_date = timezone.utcnow() self.set_duration() - Stats.incr(f'operator_failures_{task.task_type}', 1, 1) + Stats.incr(f'operator_failures_{self.task.task_type}') Stats.incr('ti_failures') if not test_mode: session.add(Log(State.FAILED, self)) # Log failure duration - session.add( - TaskFail( - task=task, - run_id=self.run_id, - start_date=self.start_date, - end_date=self.end_date, - map_index=self.map_index, - ) - ) + session.add(TaskFail(ti=self)) self.clear_next_method_args() @@ -1934,20 +1927,26 @@ def handle_failure( # only mark task instance as FAILED if the next task instance # try_number exceeds the max_tries ... 
or if force_fail is truthy + task = None + try: + task = self.task.unmap() + except Exception: + self.log.error("Unable to unmap task, can't determine if we need to send an alert email ot not") + if force_fail or not self.is_eligible_to_retry(): self.state = State.FAILED - email_for_state = task.email_on_failure + email_for_state = operator.attrgetter('email_on_failure') else: if self.state == State.QUEUED: # We increase the try_number so as to fail the task if it fails to start after sometime self._try_number += 1 self.state = State.UP_FOR_RETRY - email_for_state = task.email_on_retry + email_for_state = operator.attrgetter('email_on_retry') self._log_state('Immediate failure requested. ' if force_fail else '') - if email_for_state and task.email: + if task and email_for_state(task) and task.email: try: - self.email_alert(error) + self.email_alert(error, task) except Exception: self.log.exception('Failed to send email to: %s', task.email) @@ -2241,11 +2240,15 @@ def render_k8s_pod_yaml(self) -> Optional[dict]: sanitized_pod = ApiClient().sanitize_for_serialization(pod) return sanitized_pod - def get_email_subject_content(self, exception: BaseException) -> Tuple[str, str, str]: + def get_email_subject_content( + self, exception: BaseException, task: Optional["BaseOperator"] = None + ) -> Tuple[str, str, str]: """Get the email subject content for exceptions.""" # For a ti from DB (without ti.task), return the default value # Reuse it for smart sensor to send default email alert - use_default = not hasattr(self, 'task') + if task is None: + task = getattr(self, 'task') + use_default = task is None exception_html = str(exception).replace('\n', '
') default_subject = 'Airflow alert: {{ti}}' @@ -2312,13 +2315,14 @@ def render(key: str, content: str) -> str: return subject, html_content, html_content_err - def email_alert(self, exception): + def email_alert(self, exception, task: "BaseOperator"): """Send alert email with exception information.""" - subject, html_content, html_content_err = self.get_email_subject_content(exception) + subject, html_content, html_content_err = self.get_email_subject_content(exception, task=task) + assert task.email try: - send_email(self.task.email, subject, html_content) + send_email(task.email, subject, html_content) except Exception: - send_email(self.task.email, subject, html_content_err) + send_email(task.email, subject, html_content_err) def set_duration(self) -> None: """Set TI duration""" @@ -2573,9 +2577,10 @@ def __init__( dag_id: str, task_id: str, run_id: str, - start_date: datetime, - end_date: datetime, + start_date: Optional[datetime], + end_date: Optional[datetime], try_number: int, + map_index: int, state: str, executor_config: Any, pool: str, @@ -2584,21 +2589,20 @@ def __init__( run_as_user: Optional[str] = None, priority_weight: Optional[int] = None, ): - self._dag_id: str = dag_id - self._task_id: str = task_id - self._run_id: str = run_id - self._start_date: datetime = start_date - self._end_date: datetime = end_date - self._try_number: int = try_number - self._state: str = state - self._executor_config: Any = executor_config - self._run_as_user: Optional[str] = None - self._run_as_user = run_as_user - self._pool: str = pool - self._priority_weight: Optional[int] = None - self._priority_weight = priority_weight - self._queue: str = queue - self._key = key + self.dag_id = dag_id + self.task_id = task_id + self.run_id = run_id + self.map_index = map_index + self.start_date = start_date + self.end_date = end_date + self.try_number = try_number + self.state = state + self.executor_config = executor_config + self.run_as_user = run_as_user + self.pool = pool + 
self.priority_weight = priority_weight + self.queue = queue + self.key = key def __eq__(self, other): if isinstance(other, self.__class__): @@ -2611,6 +2615,7 @@ def from_ti(cls, ti: TaskInstance): dag_id=ti.dag_id, task_id=ti.task_id, run_id=ti.run_id, + map_index=ti.map_index, start_date=ti.start_date, end_date=ti.end_date, try_number=ti.try_number, @@ -2625,80 +2630,16 @@ def from_ti(cls, ti: TaskInstance): @classmethod def from_dict(cls, obj_dict: dict): - ti_key = obj_dict.get('_key', []) - start_date: Union[Any, datetime] = ( - datetime.fromisoformat(str(obj_dict.get('_start_date'))) - if obj_dict.get('_start_date') is not None - else None - ) - end_date: Union[Any, datetime] = ( - datetime.fromisoformat(str(obj_dict.get('_end_date'))) - if obj_dict.get('_end_date') is not None - else None - ) - return cls( - dag_id=str(obj_dict['_dag_id']), - task_id=str(obj_dict.get('_task_id')), - run_id=str(obj_dict.get('_run_id')), - start_date=start_date, - end_date=end_date, - try_number=obj_dict.get('_try_number', 1), - state=str(obj_dict.get('_state')), - executor_config=obj_dict.get('_executor_config'), - run_as_user=obj_dict.get('_run_as_user', None), - pool=str(obj_dict.get('_pool')), - priority_weight=obj_dict.get('_priority_weight', None), - queue=str(obj_dict.get('_queue')), - key=TaskInstanceKey(ti_key[0], ti_key[1], ti_key[2], ti_key[3], ti_key[4]), - ) - - @property - def dag_id(self) -> str: - return self._dag_id - - @property - def task_id(self) -> str: - return self._task_id - - @property - def run_id(self) -> str: - return self._run_id - - @property - def start_date(self) -> datetime: - return self._start_date - - @property - def end_date(self) -> datetime: - return self._end_date - - @property - def try_number(self) -> int: - return self._try_number - - @property - def state(self) -> str: - return self._state - - @property - def pool(self) -> str: - return self._pool - - @property - def priority_weight(self) -> Optional[int]: - return 
self._priority_weight - - @property - def queue(self) -> str: - return self._queue - - @property - def key(self) -> TaskInstanceKey: - return self._key - - @property - def executor_config(self): - return self._executor_config + ti_key = TaskInstanceKey(**obj_dict.pop('key')) + start_date = None + end_date = None + start_date_str: Optional[str] = obj_dict.pop('start_date') + end_date_str: Optional[str] = obj_dict.pop('end_date') + if start_date_str: + start_date = timezone.parse(start_date_str) + if end_date_str: + end_date = timezone.parse(end_date_str) + return cls(**obj_dict, start_date=start_date, end_date=end_date, key=ti_key) STATICA_HACK = True diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py index ad3e196608307..0b9f3b56b3679 100644 --- a/tests/jobs/test_scheduler_job.py +++ b/tests/jobs/test_scheduler_job.py @@ -3793,6 +3793,7 @@ def test_find_zombies(self): assert ti.dag_id == requests[0].simple_task_instance.dag_id assert ti.task_id == requests[0].simple_task_instance.task_id assert ti.run_id == requests[0].simple_task_instance.run_id + assert ti.map_index == requests[0].simple_task_instance.map_index session.query(TaskInstance).delete() session.query(LocalTaskJob).delete() From 0c31affaf30cf5320633dc0624ec57c226e2cc5c Mon Sep 17 00:00:00 2001 From: Ash Berlin-Taylor Date: Wed, 20 Apr 2022 16:07:42 +0100 Subject: [PATCH 2/4] fixup! Fix TI failure handling when task cannot be unmapped. 
--- airflow/models/taskinstance.py | 2 +- tests/api/common/test_delete_dag.py | 10 +--------- tests/callbacks/test_callback_requests.py | 2 +- 3 files changed, 3 insertions(+), 11 deletions(-) diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index a3844b0feab7a..c3cb5024a6aac 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -2630,7 +2630,7 @@ def from_ti(cls, ti: TaskInstance): @classmethod def from_dict(cls, obj_dict: dict): - ti_key = TaskInstanceKey(**obj_dict.pop('key')) + ti_key = TaskInstanceKey(*obj_dict.pop('key')) start_date = None end_date = None start_date_str: Optional[str] = obj_dict.pop('start_date') diff --git a/tests/api/common/test_delete_dag.py b/tests/api/common/test_delete_dag.py index 9a97b1453584c..2830020d29114 100644 --- a/tests/api/common/test_delete_dag.py +++ b/tests/api/common/test_delete_dag.py @@ -96,15 +96,7 @@ def setup_dag_models(self, for_sub_dag=False): event="varimport", ) ) - session.add( - TF( - task=task, - run_id=ti.run_id, - start_date=test_date, - end_date=test_date, - map_index=ti.map_index, - ) - ) + session.add(TF(ti=ti)) session.add( TR( task=ti.task, diff --git a/tests/callbacks/test_callback_requests.py b/tests/callbacks/test_callback_requests.py index cf7b14ae2c751..286d64eaa156e 100644 --- a/tests/callbacks/test_callback_requests.py +++ b/tests/callbacks/test_callback_requests.py @@ -65,6 +65,6 @@ class TestCallbackRequest(unittest.TestCase): def test_from_json(self, input, request_class): json_str = input.to_json() - result = getattr(request_class, 'from_json')(json_str=json_str) + result = request_class.from_json(json_str=json_str) self.assertEqual(result, input) From 1a523ae640ebcf4320806733eaaa062395049123 Mon Sep 17 00:00:00 2001 From: Ash Berlin-Taylor Date: Wed, 20 Apr 2022 16:44:45 +0100 Subject: [PATCH 3/4] Update airflow/models/taskinstance.py Co-authored-by: Jed Cunningham <66968678+jedcunningham@users.noreply.github.com> --- 
airflow/models/taskinstance.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index c3cb5024a6aac..490691ec2c78b 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -1931,7 +1931,7 @@ def handle_failure( try: task = self.task.unmap() except Exception: - self.log.error("Unable to unmap task, can't determine if we need to send an alert email ot not") + self.log.error("Unable to unmap task, can't determine if we need to send an alert email or not") if force_fail or not self.is_eligible_to_retry(): self.state = State.FAILED From f59a16b870fbbe2260a7c57fa6e99cecb2564436 Mon Sep 17 00:00:00 2001 From: Ash Berlin-Taylor Date: Wed, 20 Apr 2022 16:46:09 +0100 Subject: [PATCH 4/4] fixup! Fix TI failure handling when task cannot be unmapped. --- tests/jobs/test_scheduler_job.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py index 0b9f3b56b3679..4a4b89e619884 100644 --- a/tests/jobs/test_scheduler_job.py +++ b/tests/jobs/test_scheduler_job.py @@ -246,7 +246,7 @@ def test_process_executor_events(self, mock_stats_incr, mock_task_callback, dag_ mock_stats_incr.assert_has_calls( [ mock.call('scheduler.tasks.killed_externally'), - mock.call('operator_failures_EmptyOperator', 1, 1), + mock.call('operator_failures_EmptyOperator'), mock.call('ti_failures'), ], any_order=True, @@ -303,7 +303,7 @@ def test_process_executor_events_with_no_callback(self, mock_stats_incr, mock_ta mock_stats_incr.assert_has_calls( [ mock.call('scheduler.tasks.killed_externally'), - mock.call('operator_failures_EmptyOperator', 1, 1), + mock.call('operator_failures_EmptyOperator'), mock.call('ti_failures'), ], any_order=True,