Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion airflow/dag_processing/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -604,7 +604,7 @@ def _execute_task_callbacks(self, dagbag: DagBag, request: TaskCallbackRequest):
if simple_ti.task_id in dag.task_ids:
task = dag.get_task(simple_ti.task_id)
if request.is_failure_callback:
ti = TI(task, run_id=simple_ti.run_id)
ti = TI(task, run_id=simple_ti.run_id, map_index=simple_ti.map_index)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps we should revive #19242; it would make the TI <-> SimpleTI conversions more future-proof.

# TODO: Use simple_ti to improve performance here in the future
ti.refresh_from_db()
ti.handle_failure_with_callback(error=request.msg, test_mode=self.UNIT_TEST_MODE)
Expand Down
14 changes: 7 additions & 7 deletions airflow/models/taskfail.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,13 +63,13 @@ class TaskFail(Base):
viewonly=True,
)

def __init__(self, task, run_id, start_date, end_date, map_index):
self.dag_id = task.dag_id
self.task_id = task.task_id
self.run_id = run_id
self.map_index = map_index
self.start_date = start_date
self.end_date = end_date
def __init__(self, ti):
self.dag_id = ti.dag_id
self.task_id = ti.task_id
self.run_id = ti.run_id
self.map_index = ti.map_index
self.start_date = ti.start_date
self.end_date = ti.end_date
if self.end_date and self.start_date:
self.duration = int((self.end_date - self.start_date).total_seconds())
else:
Expand Down
165 changes: 53 additions & 112 deletions airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import hashlib
import logging
import math
import operator
import os
import pickle
import signal
Expand Down Expand Up @@ -133,6 +134,7 @@


if TYPE_CHECKING:
from airflow.models.baseoperator import BaseOperator
from airflow.models.dag import DAG, DagModel
from airflow.models.dagrun import DagRun
from airflow.models.operator import Operator
Expand Down Expand Up @@ -1901,24 +1903,15 @@ def handle_failure(
if not test_mode:
self.refresh_from_db(session)

task = self.task.unmap()
self.end_date = timezone.utcnow()
self.set_duration()
Stats.incr(f'operator_failures_{task.task_type}', 1, 1)
Stats.incr(f'operator_failures_{self.task.task_type}')
Comment on lines -1907 to +1908
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does this change?

Stats.incr('ti_failures')
if not test_mode:
session.add(Log(State.FAILED, self))

# Log failure duration
session.add(
TaskFail(
task=task,
run_id=self.run_id,
start_date=self.start_date,
end_date=self.end_date,
map_index=self.map_index,
)
)
session.add(TaskFail(ti=self))

self.clear_next_method_args()

Expand All @@ -1934,20 +1927,26 @@ def handle_failure(
# only mark task instance as FAILED if the next task instance
# try_number exceeds the max_tries ... or if force_fail is truthy

task = None
try:
task = self.task.unmap()
except Exception:
self.log.error("Unable to unmap task, can't determine if we need to send an alert email or not")
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we log the traceback here with `self.log.exception()` instead of `self.log.error()`?


if force_fail or not self.is_eligible_to_retry():
self.state = State.FAILED
email_for_state = task.email_on_failure
email_for_state = operator.attrgetter('email_on_failure')
else:
if self.state == State.QUEUED:
# We increase the try_number so as to fail the task if it fails to start after some time
self._try_number += 1
self.state = State.UP_FOR_RETRY
email_for_state = task.email_on_retry
email_for_state = operator.attrgetter('email_on_retry')

self._log_state('Immediate failure requested. ' if force_fail else '')
if email_for_state and task.email:
if task and email_for_state(task) and task.email:
try:
self.email_alert(error)
self.email_alert(error, task)
except Exception:
self.log.exception('Failed to send email to: %s', task.email)

Expand Down Expand Up @@ -2241,11 +2240,15 @@ def render_k8s_pod_yaml(self) -> Optional[dict]:
sanitized_pod = ApiClient().sanitize_for_serialization(pod)
return sanitized_pod

def get_email_subject_content(self, exception: BaseException) -> Tuple[str, str, str]:
def get_email_subject_content(
self, exception: BaseException, task: Optional["BaseOperator"] = None
) -> Tuple[str, str, str]:
"""Get the email subject content for exceptions."""
# For a ti from DB (without ti.task), return the default value
# Reuse it for smart sensor to send default email alert
use_default = not hasattr(self, 'task')
if task is None:
task = getattr(self, 'task')
use_default = task is None
exception_html = str(exception).replace('\n', '<br>')

default_subject = 'Airflow alert: {{ti}}'
Expand Down Expand Up @@ -2312,13 +2315,14 @@ def render(key: str, content: str) -> str:

return subject, html_content, html_content_err

def email_alert(self, exception):
def email_alert(self, exception, task: "BaseOperator"):
"""Send alert email with exception information."""
subject, html_content, html_content_err = self.get_email_subject_content(exception)
subject, html_content, html_content_err = self.get_email_subject_content(exception, task=task)
assert task.email
try:
send_email(self.task.email, subject, html_content)
send_email(task.email, subject, html_content)
except Exception:
send_email(self.task.email, subject, html_content_err)
send_email(task.email, subject, html_content_err)

def set_duration(self) -> None:
"""Set TI duration"""
Expand Down Expand Up @@ -2573,9 +2577,10 @@ def __init__(
dag_id: str,
task_id: str,
run_id: str,
start_date: datetime,
end_date: datetime,
start_date: Optional[datetime],
end_date: Optional[datetime],
try_number: int,
map_index: int,
state: str,
executor_config: Any,
pool: str,
Expand All @@ -2584,21 +2589,20 @@ def __init__(
run_as_user: Optional[str] = None,
priority_weight: Optional[int] = None,
):
self._dag_id: str = dag_id
self._task_id: str = task_id
self._run_id: str = run_id
self._start_date: datetime = start_date
self._end_date: datetime = end_date
self._try_number: int = try_number
self._state: str = state
self._executor_config: Any = executor_config
self._run_as_user: Optional[str] = None
self._run_as_user = run_as_user
self._pool: str = pool
self._priority_weight: Optional[int] = None
self._priority_weight = priority_weight
self._queue: str = queue
self._key = key
self.dag_id = dag_id
self.task_id = task_id
self.run_id = run_id
self.map_index = map_index
self.start_date = start_date
self.end_date = end_date
self.try_number = try_number
self.state = state
self.executor_config = executor_config
self.run_as_user = run_as_user
self.pool = pool
self.priority_weight = priority_weight
self.queue = queue
self.key = key

def __eq__(self, other):
if isinstance(other, self.__class__):
Expand All @@ -2611,6 +2615,7 @@ def from_ti(cls, ti: TaskInstance):
dag_id=ti.dag_id,
task_id=ti.task_id,
run_id=ti.run_id,
map_index=ti.map_index,
start_date=ti.start_date,
end_date=ti.end_date,
try_number=ti.try_number,
Expand All @@ -2625,80 +2630,16 @@ def from_ti(cls, ti: TaskInstance):

@classmethod
def from_dict(cls, obj_dict: dict):
ti_key = obj_dict.get('_key', [])
start_date: Union[Any, datetime] = (
datetime.fromisoformat(str(obj_dict.get('_start_date')))
if obj_dict.get('_start_date') is not None
else None
)
end_date: Union[Any, datetime] = (
datetime.fromisoformat(str(obj_dict.get('_end_date')))
if obj_dict.get('_end_date') is not None
else None
)
return cls(
dag_id=str(obj_dict['_dag_id']),
task_id=str(obj_dict.get('_task_id')),
run_id=str(obj_dict.get('_run_id')),
start_date=start_date,
end_date=end_date,
try_number=obj_dict.get('_try_number', 1),
state=str(obj_dict.get('_state')),
executor_config=obj_dict.get('_executor_config'),
run_as_user=obj_dict.get('_run_as_user', None),
pool=str(obj_dict.get('_pool')),
priority_weight=obj_dict.get('_priority_weight', None),
queue=str(obj_dict.get('_queue')),
key=TaskInstanceKey(ti_key[0], ti_key[1], ti_key[2], ti_key[3], ti_key[4]),
)

@property
def dag_id(self) -> str:
return self._dag_id

@property
def task_id(self) -> str:
return self._task_id

@property
def run_id(self) -> str:
return self._run_id

@property
def start_date(self) -> datetime:
return self._start_date

@property
def end_date(self) -> datetime:
return self._end_date

@property
def try_number(self) -> int:
return self._try_number

@property
def state(self) -> str:
return self._state

@property
def pool(self) -> str:
return self._pool

@property
def priority_weight(self) -> Optional[int]:
return self._priority_weight

@property
def queue(self) -> str:
return self._queue

@property
def key(self) -> TaskInstanceKey:
return self._key

@property
def executor_config(self):
return self._executor_config
ti_key = TaskInstanceKey(*obj_dict.pop('key'))
start_date = None
end_date = None
start_date_str: Optional[str] = obj_dict.pop('start_date')
end_date_str: Optional[str] = obj_dict.pop('end_date')
if start_date_str:
start_date = timezone.parse(start_date_str)
if end_date_str:
end_date = timezone.parse(end_date_str)
return cls(**obj_dict, start_date=start_date, end_date=end_date, key=ti_key)


STATICA_HACK = True
Expand Down
10 changes: 1 addition & 9 deletions tests/api/common/test_delete_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,15 +96,7 @@ def setup_dag_models(self, for_sub_dag=False):
event="varimport",
)
)
session.add(
TF(
task=task,
run_id=ti.run_id,
start_date=test_date,
end_date=test_date,
map_index=ti.map_index,
)
)
session.add(TF(ti=ti))
session.add(
TR(
task=ti.task,
Expand Down
2 changes: 1 addition & 1 deletion tests/callbacks/test_callback_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,6 @@ class TestCallbackRequest(unittest.TestCase):
def test_from_json(self, input, request_class):
json_str = input.to_json()

result = getattr(request_class, 'from_json')(json_str=json_str)
result = request_class.from_json(json_str=json_str)

self.assertEqual(result, input)
5 changes: 3 additions & 2 deletions tests/jobs/test_scheduler_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ def test_process_executor_events(self, mock_stats_incr, mock_task_callback, dag_
mock_stats_incr.assert_has_calls(
[
mock.call('scheduler.tasks.killed_externally'),
mock.call('operator_failures_EmptyOperator', 1, 1),
mock.call('operator_failures_EmptyOperator'),
mock.call('ti_failures'),
],
any_order=True,
Expand Down Expand Up @@ -303,7 +303,7 @@ def test_process_executor_events_with_no_callback(self, mock_stats_incr, mock_ta
mock_stats_incr.assert_has_calls(
[
mock.call('scheduler.tasks.killed_externally'),
mock.call('operator_failures_EmptyOperator', 1, 1),
mock.call('operator_failures_EmptyOperator'),
mock.call('ti_failures'),
],
any_order=True,
Expand Down Expand Up @@ -3793,6 +3793,7 @@ def test_find_zombies(self):
assert ti.dag_id == requests[0].simple_task_instance.dag_id
assert ti.task_id == requests[0].simple_task_instance.task_id
assert ti.run_id == requests[0].simple_task_instance.run_id
assert ti.map_index == requests[0].simple_task_instance.map_index

session.query(TaskInstance).delete()
session.query(LocalTaskJob).delete()
Expand Down