diff --git a/airflow-core/src/airflow/executors/base_executor.py b/airflow-core/src/airflow/executors/base_executor.py index 75ce7c9f5dfa4..36d8c86f5b737 100644 --- a/airflow-core/src/airflow/executors/base_executor.py +++ b/airflow-core/src/airflow/executors/base_executor.py @@ -35,11 +35,11 @@ from airflow.executors.executor_loader import ExecutorLoader from airflow.executors.workloads.callback import ExecuteCallback from airflow.executors.workloads.task import ExecuteTask +from airflow.executors.workloads.types import state_class_for_key from airflow.models import Log from airflow.models.callback import CallbackKey from airflow.observability.metrics import stats_utils from airflow.utils.log.logging_mixin import LoggingMixin -from airflow.utils.state import TaskInstanceState PARALLELISM: int = conf.getint("core", "PARALLELISM") @@ -76,7 +76,7 @@ def get_execution_api_server_url(conf_source: AirflowConfigParser | ExecutorConf from airflow.configuration import AirflowConfigParser from airflow.executors.executor_utils import ExecutorName from airflow.executors.workloads import ExecutorWorkload - from airflow.executors.workloads.types import WorkloadKey + from airflow.executors.workloads.types import WorkloadKey, WorkloadState from airflow.models.taskinstance import TaskInstance from airflow.models.taskinstancekey import TaskInstanceKey @@ -240,8 +240,11 @@ def __repr__(self): def start(self): # pragma: no cover """Executors may need to get things started.""" - def log_task_event(self, *, event: str, extra: str, ti_key: TaskInstanceKey): + def log_task_event(self, *, event: str, extra: str, ti_key: WorkloadKey): """Add an event to the log table.""" + if isinstance(ti_key, CallbackKey): + self.log.debug("Skipping log_task_event for callback key %s (event=%s)", ti_key, event) + return self._task_event_logs.append(Log(event=event, task_instance=ti_key, extra=extra)) def queue_workload(self, workload: ExecutorWorkload, session: Session) -> None: @@ -428,9 +431,7 @@ def trigger_tasks(self, open_slots: int) -> None: # TODO: This should not be using `TaskInstanceState` here, this is just "did the process complete, or did # it die". It is possible for the task itself to finish with success, but the state of the task to be set # to FAILED. By using TaskInstanceState enum here it confuses matters! - def change_state( - self, key: TaskInstanceKey, state: TaskInstanceState, info=None, remove_running=True - ) -> None: + def change_state(self, key: WorkloadKey, state: WorkloadState, info=None, remove_running=True) -> None: """ Change state of the task. @@ -447,41 +448,41 @@ def change_state( self.log.debug("Could not find key: %s", key) self.event_buffer[key] = state, info - def fail(self, key: TaskInstanceKey, info=None) -> None: + def fail(self, key: WorkloadKey, info=None) -> None: """ Set fail state for the event. :param info: Executor information for the task instance :param key: Unique key for the task instance """ - self.change_state(key, TaskInstanceState.FAILED, info) + self.change_state(key, state_class_for_key(key).FAILED, info) - def success(self, key: TaskInstanceKey, info=None) -> None: + def success(self, key: WorkloadKey, info=None) -> None: """ Set success state for the event. :param info: Executor information for the task instance :param key: Unique key for the task instance """ - self.change_state(key, TaskInstanceState.SUCCESS, info) + self.change_state(key, state_class_for_key(key).SUCCESS, info) - def queued(self, key: TaskInstanceKey, info=None) -> None: + def queued(self, key: WorkloadKey, info=None) -> None: """ Set queued state for the event. :param info: Executor information for the task instance :param key: Unique key for the task instance """ - self.change_state(key, TaskInstanceState.QUEUED, info) + self.change_state(key, state_class_for_key(key).QUEUED, info) - def running_state(self, key: TaskInstanceKey, info=None) -> None: + def running_state(self, key: WorkloadKey, info=None) -> None: """ Set running state for the event. :param info: Executor information for the task instance :param key: Unique key for the task instance """ - self.change_state(key, TaskInstanceState.RUNNING, info, remove_running=False) + self.change_state(key, state_class_for_key(key).RUNNING, info, remove_running=False) def get_event_buffer(self, dag_ids=None) -> dict[WorkloadKey, EventBufferValueType]: """ diff --git a/airflow-core/src/airflow/executors/workloads/types.py b/airflow-core/src/airflow/executors/workloads/types.py index 31cda7028466f..61f7bf037d24e 100644 --- a/airflow-core/src/airflow/executors/workloads/types.py +++ b/airflow-core/src/airflow/executors/workloads/types.py @@ -22,11 +22,11 @@ from airflow.models.callback import ExecutorCallback from airflow.models.taskinstance import TaskInstance +from airflow.models.taskinstancekey import TaskInstanceKey +from airflow.utils.state import CallbackState, TaskInstanceState if TYPE_CHECKING: from airflow.models.callback import CallbackKey - from airflow.models.taskinstancekey import TaskInstanceKey - from airflow.utils.state import CallbackState, TaskInstanceState # Type aliases for workload keys and states (used by executor layer) WorkloadKey: TypeAlias = TaskInstanceKey | CallbackKey @@ -38,3 +38,9 @@ # Type alias for scheduler workloads (ORM models that can be routed to executors) # Must be outside TYPE_CHECKING for use in function signatures SchedulerWorkload: TypeAlias = TaskInstance | ExecutorCallback + + +def state_class_for_key(key: WorkloadKey) -> type[TaskInstanceState] | type[CallbackState]: + if isinstance(key, TaskInstanceKey): + return TaskInstanceState + return CallbackState diff --git a/airflow-core/tests/unit/executors/test_base_executor.py b/airflow-core/tests/unit/executors/test_base_executor.py index 22f68d34963c7..61345dddb67db 100644 --- a/airflow-core/tests/unit/executors/test_base_executor.py +++ b/airflow-core/tests/unit/executors/test_base_executor.py @@ -41,7 +41,7 @@ from airflow.sdk import BaseOperator from airflow.sdk.execution_time.callback_supervisor import execute_callback from airflow.serialization.definitions.baseoperator import SerializedBaseOperator -from airflow.utils.state import State, TaskInstanceState +from airflow.utils.state import CallbackState, State, TaskInstanceState from tests_common.test_utils.config import conf_vars from tests_common.test_utils.markers import skip_if_force_lowest_dependencies_marker @@ -100,6 +100,36 @@ def test_get_event_buffer(): assert len(executor.event_buffer) == 0 +def test_log_task_event_branches_on_key_type(): + executor = BaseExecutor() + ti_key = TaskInstanceKey("my_dag", "my_task", timezone.utcnow(), 1) + + executor.log_task_event(event="task_event", extra="extra", ti_key=ti_key) + assert len(executor._task_event_logs) == 1 + + callback_key = str(UUID("00000000-0000-0000-0000-000000000001")) + executor.log_task_event(event="callback_event", extra="extra", ti_key=callback_key) + assert len(executor._task_event_logs) == 1 + + +@pytest.mark.parametrize( + ("method_name", "expected_state"), + [ + ("fail", CallbackState.FAILED), + ("success", CallbackState.SUCCESS), + ("queued", CallbackState.QUEUED), + ("running_state", CallbackState.RUNNING), + ], +) +def test_state_methods_pick_callback_state_for_callback_key(method_name, expected_state): + executor = BaseExecutor() + callback_key = str(UUID("00000000-0000-0000-0000-000000000002")) + + getattr(executor, method_name)(callback_key) + + assert executor.event_buffer[callback_key] == (expected_state, None) + + def test_fail_and_success(): executor = BaseExecutor() diff --git a/devel-common/src/tests_common/test_utils/version_compat.py b/devel-common/src/tests_common/test_utils/version_compat.py index 7921f02529668..635b0e08e3350 100644 --- a/devel-common/src/tests_common/test_utils/version_compat.py +++ b/devel-common/src/tests_common/test_utils/version_compat.py @@ -40,6 +40,7 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]: AIRFLOW_V_3_1_7_PLUS = get_base_airflow_version_tuple() >= (3, 1, 7) AIRFLOW_V_3_1_9_PLUS = get_base_airflow_version_tuple() >= (3, 1, 9) AIRFLOW_V_3_2_PLUS = get_base_airflow_version_tuple() >= (3, 2, 0) +AIRFLOW_V_3_3_PLUS = get_base_airflow_version_tuple() >= (3, 3, 0) if AIRFLOW_V_3_1_PLUS: from airflow.sdk import PokeReturnValue, timezone @@ -61,6 +62,7 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]: "AIRFLOW_V_3_0_1", "AIRFLOW_V_3_1_PLUS", "AIRFLOW_V_3_2_PLUS", + "AIRFLOW_V_3_3_PLUS", "NOTSET", "XCOM_RETURN_KEY", "ArgNotSet", diff --git a/providers/celery/src/airflow/providers/celery/executors/celery_executor.py b/providers/celery/src/airflow/providers/celery/executors/celery_executor.py index db4df5ab7bcb5..d30bca1e570b7 100644 --- a/providers/celery/src/airflow/providers/celery/executors/celery_executor.py +++ b/providers/celery/src/airflow/providers/celery/executors/celery_executor.py @@ -64,11 +64,16 @@ from airflow.providers.celery.executors.celery_executor_utils import TaskTuple, WorkloadInCelery if AIRFLOW_V_3_2_PLUS: - from airflow.executors.workloads.types import WorkloadKey as _WorkloadKey + from airflow.executors.workloads.types import ( + WorkloadKey as _WorkloadKey, + WorkloadState as _WorkloadState, + ) WorkloadKey: TypeAlias = _WorkloadKey + WorkloadState: TypeAlias = _WorkloadState else: WorkloadKey: TypeAlias = TaskInstanceKey # type: ignore[no-redef, misc] + WorkloadState: TypeAlias = TaskInstanceState # type: ignore[no-redef, misc] # PEP562 @@ -277,9 +282,7 @@ def update_all_workload_states(self) -> None: if state: self.update_task_state(cast("TaskInstanceKey", key), state, info) - def change_state( - self, key: TaskInstanceKey, state: TaskInstanceState, info=None, remove_running=True - ) -> None: + def change_state(self, key: WorkloadKey, state: WorkloadState, info=None, remove_running=True) -> None: super().change_state(key, state, info, remove_running=remove_running) self.workloads.pop(key, None)