Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion airflow-core/.pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ repos:
^src/airflow/dag_processing/processor\.py$|
^src/airflow/datasets/metadata\.py$|
^src/airflow/exceptions\.py$|
^src/airflow/executors/local_executor\.py$|
^src/airflow/executors/base_executor\.py$|
^src/airflow/jobs/triggerer_job_runner\.py$|
^src/airflow/lineage/hook\.py$|
^src/airflow/listeners/spec/asset\.py$|
Expand Down
81 changes: 77 additions & 4 deletions airflow-core/src/airflow/executors/base_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
from airflow.configuration import conf
from airflow.executors import workloads
from airflow.executors.executor_loader import ExecutorLoader
from airflow.executors.workloads.callback import ExecuteCallback
from airflow.executors.workloads.task import ExecuteTask
from airflow.models import Log
from airflow.models.callback import CallbackKey
from airflow.observability.metrics import stats_utils
Expand All @@ -51,6 +53,7 @@
from airflow.callbacks.callback_requests import CallbackRequest
from airflow.cli.cli_config import GroupCommand
from airflow.executors.executor_utils import ExecutorName
from airflow.executors.workloads import ExecutorWorkload
from airflow.executors.workloads.types import WorkloadKey
from airflow.models.taskinstance import TaskInstance
from airflow.models.taskinstancekey import TaskInstanceKey
Expand Down Expand Up @@ -219,7 +222,7 @@ def log_task_event(self, *, event: str, extra: str, ti_key: TaskInstanceKey):
"""Add an event to the log table."""
self._task_event_logs.append(Log(event=event, task_instance=ti_key, extra=extra))

def queue_workload(self, workload: workloads.All, session: Session) -> None:
def queue_workload(self, workload: ExecutorWorkload, session: Session) -> None:
if isinstance(workload, workloads.ExecuteTask):
ti = workload.ti
self.queued_tasks[ti.key] = workload
Expand All @@ -237,7 +240,7 @@ def queue_workload(self, workload: workloads.All, session: Session) -> None:
f"Workload must be one of: ExecuteTask, ExecuteCallback."
)

def _get_workloads_to_schedule(self, open_slots: int) -> list[tuple[WorkloadKey, workloads.All]]:
def _get_workloads_to_schedule(self, open_slots: int) -> list[tuple[WorkloadKey, ExecutorWorkload]]:
"""
Select and return the next batch of workloads to schedule, respecting priority policy.

Expand All @@ -246,7 +249,7 @@ def _get_workloads_to_schedule(self, open_slots: int) -> list[tuple[WorkloadKey,

:param open_slots: Number of available execution slots
"""
workloads_to_schedule: list[tuple[WorkloadKey, workloads.All]] = []
workloads_to_schedule: list[tuple[WorkloadKey, ExecutorWorkload]] = []

if self.queued_callbacks:
for key, workload in self.queued_callbacks.items():
Expand All @@ -262,7 +265,7 @@ def _get_workloads_to_schedule(self, open_slots: int) -> list[tuple[WorkloadKey,

return workloads_to_schedule

def _process_workloads(self, workloads: Sequence[workloads.All]) -> None:
def _process_workloads(self, workloads: Sequence[ExecutorWorkload]) -> None:
"""
Process the given workloads.

Expand Down Expand Up @@ -579,6 +582,76 @@ def get_cli_commands() -> list[GroupCommand]:
"""
return []

@staticmethod
def run_workload(
    workload: ExecutorWorkload,
    *,
    server: str | None = None,
    dry_run: bool = False,
    subprocess_logs_to_stdout: bool = False,
    proctitle: str | None = None,
) -> int:
    """
    Dispatch the workload to the supervisor that matches its type.

    Workload-specific attributes (log_path, sentry_integration, bundle_info, etc.) are read
    directly from the workload object.

    :param workload: The ``ExecutorWorkload`` to execute.
    :param server: Base URL of the API server (used by task workloads).
    :param dry_run: If True, simulate the run without actual task execution.
    :param subprocess_logs_to_stdout: Also forward task logs to stdout via the main logger.
    :param proctitle: Process title to set for this workload; defaults to
        ``"airflow supervisor: <workload.display_name>"``.
    :return: Exit code of the process.
    """
    # setproctitle is an optional dependency; silently skip titling if it is absent.
    try:
        from setproctitle import setproctitle
    except ImportError:
        pass
    else:
        setproctitle(proctitle or f"airflow supervisor: {workload.display_name}")

    if server is None:
        # Resolve the execution API URL from config when not explicitly provided.
        # For example, team-specific executors may wish to pass their own server URL.
        api_base = conf.get("api", "base_url", fallback="/")
        if api_base.startswith("/"):
            # A relative base URL implies a locally-hosted API server.
            api_base = f"http://localhost:8080{api_base}"
        server = conf.get(
            "core",
            "execution_api_server_url",
            fallback=f"{api_base.rstrip('/')}/execution/",
        )

    if isinstance(workload, ExecuteTask):
        from airflow.sdk.execution_time.supervisor import supervise_task

        # workload.ti is a TaskInstanceDTO which duck-types as TaskInstance.
        # TODO: Create a protocol for this.
        return supervise_task(
            ti=workload.ti,  # type: ignore[arg-type]
            bundle_info=workload.bundle_info,
            dag_rel_path=workload.dag_rel_path,
            token=workload.token,
            server=server,
            dry_run=dry_run,
            log_path=workload.log_path,
            subprocess_logs_to_stdout=subprocess_logs_to_stdout,
            sentry_integration=getattr(workload, "sentry_integration", ""),
        )

    if isinstance(workload, ExecuteCallback):
        from airflow.sdk.execution_time.callback_supervisor import supervise_callback

        callback = workload.callback
        return supervise_callback(
            id=callback.id,
            callback_path=callback.data.get("path", ""),
            callback_kwargs=callback.data.get("kwargs", {}),
            log_path=workload.log_path,
            bundle_info=workload.bundle_info,
        )

    raise ValueError(f"Unknown workload type: {type(workload).__name__}")

@classmethod
def _get_parser(cls) -> argparse.ArgumentParser:
"""
Expand Down
116 changes: 37 additions & 79 deletions airflow-core/src/airflow/executors/local_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,7 @@

import structlog

from airflow.executors import workloads
from airflow.executors.base_executor import BaseExecutor
from airflow.executors.workloads.callback import execute_callback_workload
from airflow.utils.state import CallbackState, TaskInstanceState

# add logger to parameter of setproctitle to support logging
if sys.platform == "darwin":
Expand All @@ -49,11 +46,23 @@
setproctitle = lambda title, logger: real_setproctitle(title)

if TYPE_CHECKING:
from structlog.typing import FilteringBoundLogger as Logger

from airflow.executors.workloads import ExecutorWorkload
from airflow.executors.workloads.types import WorkloadResultType


def _get_execution_api_server_url(team_conf) -> str:
"""
Resolve the execution API server URL from team-specific configuration.

:param team_conf: Team-specific executor configuration (ExecutorConf or AirflowConfigParser)
"""
base_url = team_conf.get("api", "base_url", fallback="/")
if base_url.startswith("/"):
base_url = f"http://localhost:8080{base_url}"
default_execution_api_server = f"{base_url.rstrip('/')}/execution/"
return team_conf.get("core", "execution_api_server_url", fallback=default_execution_api_server)


def _get_executor_process_title_prefix(team_name: str | None) -> str:
"""
Build the process title prefix for LocalExecutor workers.
Expand All @@ -66,7 +75,7 @@ def _get_executor_process_title_prefix(team_name: str | None) -> str:

def _run_worker(
logger_name: str,
input: SimpleQueue[workloads.All | None],
input: SimpleQueue[ExecutorWorkload | None],
output: Queue[WorkloadResultType],
unread_messages: multiprocessing.sharedctypes.Synchronized[int],
team_conf,
Expand Down Expand Up @@ -99,74 +108,20 @@ def _run_worker(
with unread_messages:
unread_messages.value -= 1

# Handle different workload types
if isinstance(workload, workloads.ExecuteTask):
try:
_execute_work(log, workload, team_conf)
output.put((workload.ti.key, TaskInstanceState.SUCCESS, None))
except Exception as e:
log.exception("Task execution failed.")
output.put((workload.ti.key, TaskInstanceState.FAILED, e))

elif isinstance(workload, workloads.ExecuteCallback):
output.put((workload.callback.id, CallbackState.RUNNING, None))
try:
_execute_callback(log, workload, team_conf)
output.put((workload.callback.id, CallbackState.SUCCESS, None))
except Exception as e:
log.exception("Callback execution failed")
output.put((workload.callback.id, CallbackState.FAILED, e))

else:
raise ValueError(f"LocalExecutor does not know how to handle {type(workload)}")


def _execute_work(log: Logger, workload: workloads.ExecuteTask, team_conf) -> None:
"""
Execute command received and stores result state in queue.

:param log: Logger instance
:param workload: The workload to execute
:param team_conf: Team-specific executor configuration
"""
from airflow.sdk.execution_time.supervisor import supervise

setproctitle(f"{_get_executor_process_title_prefix(team_conf.team_name)} {workload.ti.id}", log)
if workload.running_state is not None:
output.put((workload.key, workload.running_state, None))

base_url = team_conf.get("api", "base_url", fallback="/")
# If it's a relative URL, use localhost:8080 as the default
if base_url.startswith("/"):
base_url = f"http://localhost:8080{base_url}"
default_execution_api_server = f"{base_url.rstrip('/')}/execution/"

# This will return the exit code of the task process, but we don't care about that, just if the
# _supervisor_ had an error reporting the state back (which will result in an exception.)
supervise(
# This is the "wrong" ti type, but it duck types the same. TODO: Create a protocol for this.
ti=workload.ti, # type: ignore[arg-type]
dag_rel_path=workload.dag_rel_path,
bundle_info=workload.bundle_info,
token=workload.token,
server=team_conf.get("core", "execution_api_server_url", fallback=default_execution_api_server),
log_path=workload.log_path,
subprocess_logs_to_stdout=True,
)


def _execute_callback(log: Logger, workload: workloads.ExecuteCallback, team_conf) -> None:
"""
Execute a callback workload.

:param log: Logger instance
:param workload: The ExecuteCallback workload to execute
:param team_conf: Team-specific executor configuration
"""
setproctitle(f"{_get_executor_process_title_prefix(team_conf.team_name)} {workload.callback.id}", log)

success, error_msg = execute_callback_workload(workload.callback, log)

if not success:
raise RuntimeError(error_msg or "Callback execution failed")
try:
BaseExecutor.run_workload(
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this base run_workload implementation going to be used in other places? I see only one usage, feels strange the way this is designed.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All executors will end up being migrated to it as they are updated to support executor callbacks. This PR has it used in LocalExecutor (where you left this comment), Celery, and here in the K8S container path.

It feels a little odd for it to be in the BaseExecutor, but that is for Core/SDK separation reasons as suggested by Ash.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fair enough, it doesn't quite seem like the right responsibility of that class. It's meant to be a baseclass for executors themselves, all usages of it, so far, are really through the inheritance of the subclasses. Basically calling a method on a "base" class directly like this looks like a code smell to me BaseExecutor.run_workload(...)

But if we're okay with that, and @ashb prefers this for some greater good, then I suppose I'm okay to sign off on it

workload,
server=_get_execution_api_server_url(team_conf),
proctitle=f"{_get_executor_process_title_prefix(team_conf.team_name)} {workload.display_name}",
subprocess_logs_to_stdout=True,
)
output.put((workload.key, workload.success_state, None))
except Exception as e:
log.exception("Workload execution failed.", workload_type=type(workload).__name__)
output.put((workload.key, workload.failure_state, e))
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using .pop(workload.key, None) on both queued_tasks and queued_callbacks for every workload silently swallows missing keys. The old code used del which would raise KeyError if a workload was never queued or got dequeued twice -- surfacing logic bugs rather than hiding them.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hm. I can see where you are going. I don't think it's quite that simple though. with the two pops how I have it, one of them will pop and the other is expected to be a no-op. a task will pop from queued_tasks and the queued_callbacks pop will just not do anything, and vice versa. A del would have a problem with the no-op side of that.

But you are right that I missed the case where it's not in either. How do you feel about this:

for workload in workload_list:
    self.activity_queue.put(workload)
    # A valid workload will exist in exactly one of these dicts.
    # One will succeed, the other will fail gracefully and return None.
    removed = self.queued_tasks.pop(workload.key, None) or self.queued_callbacks.pop(workload.key, None)
    if not removed:
        raise KeyError(f"Workload {workload.key} was not found in any queue.")

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I implemented that in e558b31, if you have another idea I can revert it



class LocalExecutor(BaseExecutor):
Expand All @@ -185,7 +140,7 @@ class LocalExecutor(BaseExecutor):
serve_logs: bool = True
supports_callbacks: bool = True

activity_queue: SimpleQueue[workloads.All | None]
activity_queue: SimpleQueue[ExecutorWorkload | None]
result_queue: SimpleQueue[WorkloadResultType]
workers: dict[int, multiprocessing.Process]
_unread_messages: multiprocessing.sharedctypes.Synchronized[int]
Expand Down Expand Up @@ -214,6 +169,7 @@ def start(self) -> None:

# Mypy sees this value as `SynchronizedBase[c_uint]`, but that isn't the right runtime type behaviour
# (it looks like an int to python)

self._unread_messages = multiprocessing.Value(ctypes.c_uint)

if self.is_mp_using_fork:
Expand Down Expand Up @@ -332,11 +288,13 @@ def terminate(self):
def _process_workloads(self, workload_list):
    """
    Hand the given workloads to the worker processes and drop them from the queued dicts.

    :param workload_list: Workloads to push onto the activity queue.
    :raises KeyError: If a workload was never queued (present in neither dict).
    """
    for workload in workload_list:
        self.activity_queue.put(workload)
        # A valid workload will exist in exactly one of these dicts; pop it from
        # whichever holds it. Check against None explicitly instead of relying on
        # truthiness ("pop(...) or pop(...)") so a falsy-but-present workload
        # object is not mistaken for a missing one.
        removed = self.queued_tasks.pop(workload.key, None)
        if removed is None:
            removed = self.queued_callbacks.pop(workload.key, None)
        if removed is None:
            raise KeyError(f"Workload {workload.key} was not found in any queue")
    with self._unread_messages:
        self._unread_messages.value += len(workload_list)
    self._check_workers()
7 changes: 7 additions & 0 deletions airflow-core/src/airflow/executors/workloads/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,20 @@

TaskInstance = TaskInstanceDTO

ExecutorWorkload = Annotated[
    ExecuteTask | ExecuteCallback,
    # Pydantic selects the concrete class from the "type" field when deserializing.
    Field(discriminator="type"),
]
"""Workload types that can be sent to executors (excludes RunTrigger, which is handled by the triggerer)."""

# Public names re-exported by this package; kept alphabetically sorted.
__all__ = [
    "All",
    "BaseWorkload",
    "BundleInfo",
    "CallbackFetchMethod",
    "ExecuteCallback",
    "ExecuteTask",
    "ExecutorWorkload",
    "TaskInstance",
    "TaskInstanceDTO",
]
Loading
Loading