Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 15 additions & 14 deletions airflow-core/src/airflow/executors/base_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,11 @@
from airflow.executors.executor_loader import ExecutorLoader
from airflow.executors.workloads.callback import ExecuteCallback
from airflow.executors.workloads.task import ExecuteTask
from airflow.executors.workloads.types import state_class_for_key
from airflow.models import Log
from airflow.models.callback import CallbackKey
from airflow.observability.metrics import stats_utils
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.state import TaskInstanceState

PARALLELISM: int = conf.getint("core", "PARALLELISM")

Expand Down Expand Up @@ -76,7 +76,7 @@ def get_execution_api_server_url(conf_source: AirflowConfigParser | ExecutorConf
from airflow.configuration import AirflowConfigParser
from airflow.executors.executor_utils import ExecutorName
from airflow.executors.workloads import ExecutorWorkload
from airflow.executors.workloads.types import WorkloadKey
from airflow.executors.workloads.types import WorkloadKey, WorkloadState
from airflow.models.taskinstance import TaskInstance
from airflow.models.taskinstancekey import TaskInstanceKey

Expand Down Expand Up @@ -240,8 +240,11 @@ def __repr__(self):
def start(self): # pragma: no cover
"""Executors may need to get things started."""

def log_task_event(self, *, event: str, extra: str, ti_key: WorkloadKey):
    """Record an executor event for a workload in the Log table.

    Only task-instance keys map to a Log row; callback keys have no
    task-instance row to attach to, so they are skipped with a debug message.

    :param event: Name of the event being recorded
    :param extra: Extra information to attach to the Log entry
    :param ti_key: Key of the workload the event relates to
    """
    if not isinstance(ti_key, CallbackKey):
        self._task_event_logs.append(Log(event=event, task_instance=ti_key, extra=extra))
        return
    self.log.debug("Skipping log_task_event for callback key %s (event=%s)", ti_key, event)

def queue_workload(self, workload: ExecutorWorkload, session: Session) -> None:
Expand Down Expand Up @@ -428,9 +431,7 @@ def trigger_tasks(self, open_slots: int) -> None:
# TODO: This should not be using `TaskInstanceState` here, this is just "did the process complete, or did
# it die". It is possible for the task itself to finish with success, but the state of the task to be set
# to FAILED. By using TaskInstanceState enum here it confuses matters!
def change_state(
self, key: TaskInstanceKey, state: TaskInstanceState, info=None, remove_running=True
) -> None:
def change_state(self, key: WorkloadKey, state: WorkloadState, info=None, remove_running=True) -> None:
"""
Change state of the task.

Expand All @@ -447,41 +448,41 @@ def change_state(
self.log.debug("Could not find key: %s", key)
self.event_buffer[key] = state, info

def fail(self, key: WorkloadKey, info=None) -> None:
    """
    Set fail state for the event.

    :param info: Executor information for the task instance
    :param key: Unique key for the task instance
    """
    # Pick the state enum matching the workload kind (task vs callback).
    state_cls = state_class_for_key(key)
    self.change_state(key, state_cls.FAILED, info)

def success(self, key: WorkloadKey, info=None) -> None:
    """
    Set success state for the event.

    :param info: Executor information for the task instance
    :param key: Unique key for the task instance
    """
    # Pick the state enum matching the workload kind (task vs callback).
    state_cls = state_class_for_key(key)
    self.change_state(key, state_cls.SUCCESS, info)

def queued(self, key: WorkloadKey, info=None) -> None:
    """
    Set queued state for the event.

    :param info: Executor information for the task instance
    :param key: Unique key for the task instance
    """
    # Pick the state enum matching the workload kind (task vs callback).
    state_cls = state_class_for_key(key)
    self.change_state(key, state_cls.QUEUED, info)

def running_state(self, key: WorkloadKey, info=None) -> None:
    """
    Set running state for the event.

    :param info: Executor information for the task instance
    :param key: Unique key for the task instance
    """
    # Pick the state enum matching the workload kind (task vs callback).
    # remove_running=False: the workload stays in the running set while RUNNING.
    state_cls = state_class_for_key(key)
    self.change_state(key, state_cls.RUNNING, info, remove_running=False)

def get_event_buffer(self, dag_ids=None) -> dict[WorkloadKey, EventBufferValueType]:
"""
Expand Down
10 changes: 8 additions & 2 deletions airflow-core/src/airflow/executors/workloads/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@

from airflow.models.callback import ExecutorCallback
from airflow.models.taskinstance import TaskInstance
from airflow.models.taskinstancekey import TaskInstanceKey
from airflow.utils.state import CallbackState, TaskInstanceState

if TYPE_CHECKING:
from airflow.models.callback import CallbackKey
from airflow.models.taskinstancekey import TaskInstanceKey
from airflow.utils.state import CallbackState, TaskInstanceState

# Type aliases for workload keys and states (used by executor layer)
WorkloadKey: TypeAlias = TaskInstanceKey | CallbackKey
Expand All @@ -38,3 +38,9 @@
# Type alias for scheduler workloads (ORM models that can be routed to executors)
# Must be outside TYPE_CHECKING for use in function signatures
SchedulerWorkload: TypeAlias = TaskInstance | ExecutorCallback


def state_class_for_key(key: WorkloadKey) -> type[TaskInstanceState] | type[CallbackState]:
    """Return the state enum class appropriate for *key*.

    A ``TaskInstanceKey`` maps to ``TaskInstanceState``; any other key type
    (i.e. ``CallbackKey``) maps to ``CallbackState``.
    """
    if isinstance(key, TaskInstanceKey):
        return TaskInstanceState
    return CallbackState
32 changes: 31 additions & 1 deletion airflow-core/tests/unit/executors/test_base_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
from airflow.sdk import BaseOperator
from airflow.sdk.execution_time.callback_supervisor import execute_callback
from airflow.serialization.definitions.baseoperator import SerializedBaseOperator
from airflow.utils.state import State, TaskInstanceState
from airflow.utils.state import CallbackState, State, TaskInstanceState

from tests_common.test_utils.config import conf_vars
from tests_common.test_utils.markers import skip_if_force_lowest_dependencies_marker
Expand Down Expand Up @@ -100,6 +100,36 @@ def test_get_event_buffer():
assert len(executor.event_buffer) == 0


def test_log_task_event_branches_on_key_type():
    # log_task_event should append a Log entry for task-instance keys but
    # silently skip callback keys (which have no task-instance row).
    executor = BaseExecutor()
    ti_key = TaskInstanceKey("my_dag", "my_task", timezone.utcnow(), 1)

    # TaskInstanceKey: one Log entry is buffered.
    executor.log_task_event(event="task_event", extra="extra", ti_key=ti_key)
    assert len(executor._task_event_logs) == 1

    # Callback key (UUID string): skipped, buffer length unchanged.
    callback_key = str(UUID("00000000-0000-0000-0000-000000000001"))
    executor.log_task_event(event="callback_event", extra="extra", ti_key=callback_key)
    assert len(executor._task_event_logs) == 1


@pytest.mark.parametrize(
    ("method_name", "expected_state"),
    [
        ("fail", CallbackState.FAILED),
        ("success", CallbackState.SUCCESS),
        ("queued", CallbackState.QUEUED),
        ("running_state", CallbackState.RUNNING),
    ],
)
def test_state_methods_pick_callback_state_for_callback_key(method_name, expected_state):
    # Each state-transition helper should resolve the CallbackState enum
    # (not TaskInstanceState) when given a callback key.
    executor = BaseExecutor()
    callback_key = str(UUID("00000000-0000-0000-0000-000000000002"))

    getattr(executor, method_name)(callback_key)

    # change_state records (state, info) in the event buffer keyed by the workload key.
    assert executor.event_buffer[callback_key] == (expected_state, None)


def test_fail_and_success():
executor = BaseExecutor()

Expand Down
2 changes: 2 additions & 0 deletions devel-common/src/tests_common/test_utils/version_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]:
AIRFLOW_V_3_1_7_PLUS = get_base_airflow_version_tuple() >= (3, 1, 7)
AIRFLOW_V_3_1_9_PLUS = get_base_airflow_version_tuple() >= (3, 1, 9)
AIRFLOW_V_3_2_PLUS = get_base_airflow_version_tuple() >= (3, 2, 0)
AIRFLOW_V_3_3_PLUS = get_base_airflow_version_tuple() >= (3, 3, 0)

if AIRFLOW_V_3_1_PLUS:
from airflow.sdk import PokeReturnValue, timezone
Expand All @@ -61,6 +62,7 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]:
"AIRFLOW_V_3_0_1",
"AIRFLOW_V_3_1_PLUS",
"AIRFLOW_V_3_2_PLUS",
"AIRFLOW_V_3_3_PLUS",
"NOTSET",
"XCOM_RETURN_KEY",
"ArgNotSet",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,16 @@
from airflow.providers.celery.executors.celery_executor_utils import TaskTuple, WorkloadInCelery

if AIRFLOW_V_3_2_PLUS:
from airflow.executors.workloads.types import WorkloadKey as _WorkloadKey
from airflow.executors.workloads.types import (
WorkloadKey as _WorkloadKey,
WorkloadState as _WorkloadState,
)

WorkloadKey: TypeAlias = _WorkloadKey
WorkloadState: TypeAlias = _WorkloadState
else:
WorkloadKey: TypeAlias = TaskInstanceKey # type: ignore[no-redef, misc]
WorkloadState: TypeAlias = TaskInstanceState # type: ignore[no-redef, misc]


# PEP562
Expand Down Expand Up @@ -277,9 +282,7 @@ def update_all_workload_states(self) -> None:
if state:
self.update_task_state(cast("TaskInstanceKey", key), state, info)

def change_state(self, key: WorkloadKey, state: WorkloadState, info=None, remove_running=True) -> None:
    """Record the state change via the base executor, then stop tracking the workload.

    :param key: Unique key for the workload
    :param state: New state of the workload
    :param info: Executor information for the workload
    :param remove_running: Whether to remove the key from the running set
    """
    super().change_state(key, state, info, remove_running=remove_running)
    # The workload has reached a tracked state; drop it from local bookkeeping.
    self.workloads.pop(key, None)

Expand Down
Loading