Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from collections import defaultdict, deque
from collections.abc import Sequence
from copy import deepcopy
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, cast

from botocore.exceptions import ClientError, NoCredentialsError

Expand All @@ -46,7 +46,7 @@
exponential_backoff_retry,
)
from airflow.providers.amazon.aws.hooks.ecs import EcsHook
from airflow.providers.amazon.version_compat import AIRFLOW_V_3_0_PLUS
from airflow.providers.amazon.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_2_PLUS
from airflow.providers.common.compat.sdk import AirflowException, Stats, timezone
from airflow.utils.helpers import merge_dicts
from airflow.utils.state import State
Expand All @@ -55,6 +55,7 @@
from sqlalchemy.orm import Session

from airflow.executors import workloads
from airflow.executors.workloads.types import WorkloadKey
from airflow.models.taskinstance import TaskInstance, TaskInstanceKey
from airflow.providers.amazon.aws.executors.ecs.utils import (
CommandType,
Expand Down Expand Up @@ -92,6 +93,9 @@ class AwsEcsExecutor(BaseExecutor):

supports_multi_team: bool = True

if AIRFLOW_V_3_2_PLUS:
supports_callbacks: bool = True

# AWS limits the maximum number of ARNs in the describe_tasks function.
DESCRIBE_TASKS_BATCH_SIZE = 99

Expand All @@ -104,6 +108,8 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.active_workers: EcsTaskCollection = EcsTaskCollection()
self.pending_tasks: deque = deque()
if AIRFLOW_V_3_2_PLUS:
self.queued_callbacks: dict[str, workloads.ExecuteCallback] = {}

# Check if self has the ExecutorConf set on the self.conf attribute, and if not, set it to the global
# configuration object. This allows the changes to be backwards compatible with older versions of
Expand Down Expand Up @@ -134,27 +140,38 @@ def __init__(self, *args, **kwargs):
def queue_workload(self, workload: workloads.All, session: Session | None) -> None:
from airflow.executors import workloads

if not isinstance(workload, workloads.ExecuteTask):
if AIRFLOW_V_3_2_PLUS and isinstance(workload, workloads.ExecuteCallback):
self.queued_callbacks[workload.callback.id] = workload
elif isinstance(workload, workloads.ExecuteTask):
ti = workload.ti
self.queued_tasks[ti.key] = workload
else:
raise RuntimeError(f"{type(self)} cannot handle workloads of type {type(workload)}")
ti = workload.ti
self.queued_tasks[ti.key] = workload

def _process_workloads(self, workloads: Sequence[workloads.All]) -> None:
from airflow.executors.workloads import ExecuteTask
from airflow.executors import workloads as wl

for workload in workloads:
if isinstance(workload, wl.ExecuteTask):
command = [workload]
key = workload.ti.key
queue = workload.ti.queue
executor_config = workload.ti.executor_config or {}

# Airflow V3 version
for w in workloads:
if not isinstance(w, ExecuteTask):
raise RuntimeError(f"{type(self)} cannot handle workloads of type {type(w)}")
del self.queued_tasks[key]
self.execute_async(key=key, command=command, queue=queue, executor_config=executor_config) # type: ignore[arg-type]
self.running.add(key)

command = [w]
key = w.ti.key
queue = w.ti.queue
executor_config = w.ti.executor_config or {}
elif AIRFLOW_V_3_2_PLUS and isinstance(workload, wl.ExecuteCallback):
command = [workload] # type: ignore[list-item]
key = workload.callback.id # type: ignore[assignment]

del self.queued_tasks[key]
self.execute_async(key=key, command=command, queue=queue, executor_config=executor_config) # type: ignore[arg-type]
self.running.add(key)
del self.queued_callbacks[key] # type: ignore[arg-type]
self.execute_async(key=key, command=command, queue=None) # type: ignore[arg-type]
self.running.add(key)

else:
raise RuntimeError(f"{type(self)} cannot handle workloads of type {type(workload)}")

def start(self):
"""Call this when the Executor is run for the first time by the scheduler."""
Expand Down Expand Up @@ -292,12 +309,12 @@ def __update_running_task(self, task):
self.__handle_failed_task(task.task_arn, task.stopped_reason)
elif task_state == State.SUCCESS:
self.log.debug(
"Airflow task %s marked as %s after running on ECS Task (arn) %s",
"Airflow workload %s marked as %s after running on ECS Task (arn) %s",
task_key,
task_state,
task.task_arn,
)
self.success(task_key)
self.success(cast("TaskInstanceKey", task_key))
self.active_workers.pop_by_key(task_key)

def __describe_tasks(self, task_arns):
Expand Down Expand Up @@ -346,7 +363,7 @@ def __handle_failed_task(self, task_arn: str, reason: str):
failure_count = self.active_workers.failure_count_by_key(task_key)
if int(failure_count) < int(self.max_run_task_attempts):
self.log.warning(
"Airflow task %s failed due to %s. Failure %s out of %s occurred on %s. Rescheduling.",
"Airflow workload %s failed due to %s. Failure %s out of %s occurred on %s. Rescheduling.",
task_key,
reason,
failure_count,
Expand All @@ -365,11 +382,11 @@ def __handle_failed_task(self, task_arn: str, reason: str):
)
else:
self.log.error(
"Airflow task %s has failed a maximum of %s times. Marking as failed",
"Airflow workload %s has failed a maximum of %s times. Marking as failed",
task_key,
failure_count,
)
self.fail(task_key)
self.fail(cast("TaskInstanceKey", task_key))
self.active_workers.pop_by_key(task_key)

def attempt_task_runs(self):
Expand Down Expand Up @@ -430,37 +447,43 @@ def attempt_task_runs(self):
else:
reasons_str = ", ".join(failure_reasons)
self.log.error(
"ECS task %s has failed a maximum of %s times. Marking as failed. Reasons: %s",
"ECS workload %s has failed a maximum of %s times. Marking as failed. Reasons: %s",
task_key,
attempt_number,
reasons_str,
)
if isinstance(task_key, tuple):
self.log_task_event(
event="ecs task submit failure",
ti_key=task_key,
extra=(
f"Task could not be queued after {attempt_number} attempts. "
f"Marking as failed. Reasons: {reasons_str}"
),
)
self.fail(cast("TaskInstanceKey", task_key))
elif not run_task_response["tasks"]:
self.log.error("ECS RunTask Response: %s", run_task_response)
if isinstance(task_key, tuple):
self.log_task_event(
event="ecs task submit failure",
extra=f"ECS RunTask Response: {run_task_response}",
ti_key=task_key,
extra=(
f"Task could not be queued after {attempt_number} attempts. "
f"Marking as failed. Reasons: {reasons_str}"
),
)
self.fail(task_key)
elif not run_task_response["tasks"]:
self.log.error("ECS RunTask Response: %s", run_task_response)
self.log_task_event(
event="ecs task submit failure",
extra=f"ECS RunTask Response: {run_task_response}",
ti_key=task_key,
)
raise EcsExecutorException(
"No failures and no ECS tasks provided in response. This should never happen."
)
else:
task = run_task_response["tasks"][0]
self.active_workers.add_task(task, task_key, queue, cmd, exec_config, attempt_number)
self.running_state(task_key, task.task_arn)
self.running_state(cast("TaskInstanceKey", task_key), task.task_arn)

def _run_task(
self, task_id: TaskInstanceKey, cmd: CommandType, queue: str, exec_config: ExecutorConfigType
self,
task_id: WorkloadKey,
cmd: CommandType,
queue: str | None,
exec_config: ExecutorConfigType,
):
"""
Run a queued-up Airflow task.
Expand All @@ -475,7 +498,11 @@ def _run_task(
return run_task_response

def _run_task_kwargs(
self, task_id: TaskInstanceKey, cmd: CommandType, queue: str, exec_config: ExecutorConfigType
self,
task_id: WorkloadKey,
cmd: CommandType,
queue: str | None,
exec_config: ExecutorConfigType,
) -> dict:
"""
Update the Airflow command by modifying container overrides for task-specific kwargs.
Expand All @@ -494,14 +521,16 @@ def _run_task_kwargs(

return run_task_kwargs

def execute_async(self, key: TaskInstanceKey, command: CommandType, queue=None, executor_config=None):
"""Save the task to be executed in the next sync by inserting the commands into a queue."""
def execute_async(self, key: WorkloadKey, command: CommandType, queue=None, executor_config=None):
"""Save the workload to be executed in the next sync by inserting the commands into a queue."""
if executor_config and ("name" in executor_config or "command" in executor_config):
raise ValueError('Executor Config should never override "name" or "command"')
if len(command) == 1:
from airflow.executors.workloads import ExecuteTask
from airflow.executors import workloads

if isinstance(command[0], ExecuteTask):
if isinstance(command[0], workloads.ExecuteTask) or (
AIRFLOW_V_3_2_PLUS and isinstance(command[0], workloads.ExecuteCallback)
):
command = self._serialize_workload_to_command(command[0])
else:
raise ValueError(
Expand Down Expand Up @@ -567,9 +596,9 @@ def get_container(self, container_list):
@staticmethod
def _serialize_workload_to_command(workload) -> CommandType:
"""
Serialize an ExecuteTask workload into a command for the Task SDK.
Serialize a workload into a command for the Task SDK.

:param workload: ExecuteTask workload to serialize
:param workload: ExecuteTask or ExecuteCallback workload to serialize
:return: Command as list of strings for Task SDK execution
"""
return [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from airflow.utils.state import State

if TYPE_CHECKING:
from airflow.models.taskinstance import TaskInstanceKey
from airflow.executors.workloads.types import WorkloadKey

CommandType = Sequence[str]
ExecutorConfigFunctionType = Callable[[CommandType], dict]
Expand All @@ -57,11 +57,11 @@

@dataclass
class EcsQueuedTask:
"""Represents an ECS task that is queued. The task will be run in the next heartbeat."""
"""Represents a queued ECS workload (task or callback). The workload will be run in the next heartbeat."""

key: TaskInstanceKey
key: WorkloadKey
command: CommandType
queue: str
queue: str | None
executor_config: ExecutorConfigType
attempt_number: int
next_attempt_time: datetime.datetime
Expand All @@ -72,7 +72,7 @@ class EcsTaskInfo:
"""Contains information about a currently running ECS task."""

cmd: CommandType
queue: str
queue: str | None
config: ExecutorConfigType


Expand Down Expand Up @@ -156,20 +156,20 @@ def __repr__(self):


class EcsTaskCollection:
"""A five-way dictionary between Airflow task ids, Airflow cmds, ECS ARNs, and ECS task objects."""
"""A five-way dictionary between Airflow workload keys, commands, ECS ARNs, and ECS task objects."""

def __init__(self):
self.key_to_arn: dict[TaskInstanceKey, str] = {}
self.arn_to_key: dict[str, TaskInstanceKey] = {}
self.key_to_arn: dict[WorkloadKey, str] = {}
self.arn_to_key: dict[str, WorkloadKey] = {}
self.tasks: dict[str, EcsExecutorTask] = {}
self.key_to_failure_counts: dict[TaskInstanceKey, int] = defaultdict(int)
self.key_to_task_info: dict[TaskInstanceKey, EcsTaskInfo] = {}
self.key_to_failure_counts: dict[WorkloadKey, int] = defaultdict(int)
self.key_to_task_info: dict[WorkloadKey, EcsTaskInfo] = {}

def add_task(
self,
task: EcsExecutorTask,
airflow_task_key: TaskInstanceKey,
queue: str,
airflow_task_key: WorkloadKey,
queue: str | None,
airflow_cmd: CommandType,
exec_config: ExecutorConfigType,
attempt_number: int,
Expand All @@ -186,17 +186,17 @@ def update_task(self, task: EcsExecutorTask):
"""Update the state of the given task based on task ARN."""
self.tasks[task.task_arn] = task

def task_by_key(self, task_key: TaskInstanceKey) -> EcsExecutorTask:
"""Get a task by Airflow Instance Key."""
def task_by_key(self, task_key: WorkloadKey) -> EcsExecutorTask:
    """Return the ECS task tracked under the given Airflow workload key."""
    # Resolve the key to its ECS ARN, then reuse the ARN-based lookup.
    return self.task_by_arn(self.key_to_arn[task_key])

def task_by_arn(self, arn) -> EcsExecutorTask:
    """Return the ECS task object registered under the given AWS ARN."""
    task = self.tasks[arn]
    return task

def pop_by_key(self, task_key: TaskInstanceKey) -> EcsExecutorTask:
"""Delete task from collection based off of Airflow Task Instance Key."""
def pop_by_key(self, task_key: WorkloadKey) -> EcsExecutorTask:
"""Delete task from collection based off of Airflow workload key."""
arn = self.key_to_arn[task_key]
task = self.tasks[arn]
del self.key_to_arn[task_key]
Expand All @@ -211,20 +211,20 @@ def get_all_arns(self) -> list[str]:
"""Get all AWS ARNs in collection."""
return list(self.key_to_arn.values())

def get_all_task_keys(self) -> list[TaskInstanceKey]:
"""Get all Airflow Task Keys in collection."""
def get_all_task_keys(self) -> list[WorkloadKey]:
    """Return every Airflow workload key currently tracked in the collection."""
    # Iterating the mapping directly yields its keys.
    return [tracked_key for tracked_key in self.key_to_arn]

def failure_count_by_key(self, task_key: TaskInstanceKey) -> int:
"""Get the number of times a task has failed given an Airflow Task Key."""
def failure_count_by_key(self, task_key: WorkloadKey) -> int:
    """Return how many times the workload with the given key has failed."""
    # key_to_failure_counts is a defaultdict(int), so an unseen key reads as 0.
    count = self.key_to_failure_counts[task_key]
    return count

def increment_failure_count(self, task_key: TaskInstanceKey):
"""Increment the failure counter given an Airflow Task Key."""
def increment_failure_count(self, task_key: WorkloadKey):
    """Bump the failure counter for the given Airflow workload key by one."""
    current = self.key_to_failure_counts[task_key]
    self.key_to_failure_counts[task_key] = current + 1

def info_by_key(self, task_key: TaskInstanceKey) -> EcsTaskInfo:
"""Get the Airflow Command given an Airflow task key."""
def info_by_key(self, task_key: WorkloadKey) -> EcsTaskInfo:
    """Return the stored task info record for the given Airflow workload key."""
    info = self.key_to_task_info[task_key]
    return info

def __getitem__(self, value):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]:
# Version-gate flags: each is True when the Airflow core this provider runs
# against is at least the stated version, so callers can conditionally enable
# features that only newer Airflow releases support.
AIRFLOW_V_3_1_PLUS: bool = get_base_airflow_version_tuple() >= (3, 1, 0)
AIRFLOW_V_3_1_1_PLUS: bool = get_base_airflow_version_tuple() >= (3, 1, 1)
AIRFLOW_V_3_1_8_PLUS: bool = get_base_airflow_version_tuple() >= (3, 1, 8)
AIRFLOW_V_3_2_PLUS: bool = get_base_airflow_version_tuple() >= (3, 2, 0)

try:
from airflow.sdk.definitions._internal.types import NOTSET, ArgNotSet
Expand All @@ -58,6 +59,7 @@ def is_arg_set(value): # type: ignore[misc,no-redef]
"AIRFLOW_V_3_1_PLUS",
"AIRFLOW_V_3_1_1_PLUS",
"AIRFLOW_V_3_1_8_PLUS",
"AIRFLOW_V_3_2_PLUS",
"NOTSET",
"ArgNotSet",
"is_arg_set",
Expand Down
Loading
Loading