Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -414,15 +414,7 @@ def execute_async(self, key: BatchJobWorkloadKey, command: CommandType, queue=No
if isinstance(command[0], workloads.ExecuteTask) or (
AIRFLOW_V_3_3_PLUS and isinstance(command[0], workloads.ExecuteCallback)
):
workload = command[0]
ser_input = workload.model_dump_json()
command = [
"python",
"-m",
"airflow.sdk.execution_time.execute_workload",
"--json-string",
ser_input,
]
command = self._serialize_workload_to_command(command[0])
else:
raise ValueError(
f"BatchExecutor doesn't know how to handle workload of type: {type(command[0])}"
Expand Down Expand Up @@ -509,6 +501,39 @@ def _load_submit_kwargs(self) -> dict:
)
return submit_kwargs

@staticmethod
def _serialize_workload_to_command(workload) -> CommandType:
    """
    Turn a workload into the Task SDK command line that executes it.

    The workload's JSON dump is passed to the
    ``airflow.sdk.execution_time.execute_workload`` module through its
    ``--json-string`` argument.

    :param workload: ExecuteTask or ExecuteCallback workload to serialize
    :return: Command as list of strings for Task SDK execution
    """
    serialized = workload.model_dump_json()
    entrypoint = ["python", "-m", "airflow.sdk.execution_time.execute_workload"]
    return [*entrypoint, "--json-string", serialized]

def _build_task_command(self, ti: TaskInstance) -> CommandType:
    """
    Produce the command used to run ``ti``, depending on the Airflow version.

    When ``AIRFLOW_V_3_0_PLUS`` is falsy the legacy ``ti.command_as_list()``
    result is returned unchanged. Otherwise an ``ExecuteTask`` workload is
    created from the task instance with ``ExecuteTask.make`` and converted
    to a command by ``_serialize_workload_to_command``.

    :param ti: TaskInstance to build command for
    :return: Command as list of strings
    """
    if not AIRFLOW_V_3_0_PLUS:
        # Airflow 2.x: the task instance produces its own CLI command.
        return ti.command_as_list()

    # Imported only on the Airflow 3 path.
    from airflow.executors.workloads import ExecuteTask

    return self._serialize_workload_to_command(ExecuteTask.make(ti))

def try_adopt_task_instances(self, tis: Sequence[TaskInstance]) -> Sequence[TaskInstance]:
"""
Adopt task instances which have an external_executor_id (the Batch job ID).
Expand All @@ -523,10 +548,12 @@ def try_adopt_task_instances(self, tis: Sequence[TaskInstance]) -> Sequence[Task

for batch_job in batch_jobs:
ti = next(ti for ti in tis if ti.external_executor_id == batch_job.job_id)
command = self._build_task_command(ti)

self.active_workers.add_job(
job_id=batch_job.job_id,
airflow_workload_key=ti.key,
airflow_cmd=ti.command_as_list(),
airflow_cmd=command,
queue=ti.queue,
exec_config=ti.executor_config,
attempt_number=ti.try_number,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import os
from unittest import mock
from unittest.mock import patch
from uuid import uuid4

import pytest
import yaml
Expand Down Expand Up @@ -776,7 +777,6 @@ def _mock_sync(
}
executor.batch.describe_jobs.return_value = {"jobs": [after_batch_job]}

@pytest.mark.skip(reason="Adopting task instances hasn't been ported over to Airflow 3 yet")
def test_try_adopt_task_instances(self, mock_executor):
"""Test that executor can adopt orphaned task instances from a SchedulerJob shutdown event."""
mock_executor.batch.describe_jobs.return_value = {
Expand All @@ -794,8 +794,38 @@ def test_try_adopt_task_instances(self, mock_executor):
orphaned_tasks[0].external_executor_id = "001" # Matches a running task_arn
orphaned_tasks[1].external_executor_id = "002" # Matches a running task_arn
orphaned_tasks[2].external_executor_id = None # One orphaned task has no external_executor_id
for task in orphaned_tasks:

for idx, task in enumerate(orphaned_tasks):
task.try_number = 1
task.key = mock.Mock(spec=TaskInstanceKey)
task.queue = "default"
task.executor_config = {}
task.id = uuid4()
task.dag_version_id = uuid4()
task.task_id = f"task_{idx}"
task.dag_id = "test_dag"
task.run_id = "test_run"
task.map_index = -1
task.pool_slots = 1
task.priority_weight = 1
task.context_carrier = {}
task.queued_dttm = dt.datetime(2024, 1, 1, tzinfo=dt.timezone.utc)
task.dag_model = mock.Mock()
task.dag_model.bundle_name = "test_bundle"
task.dag_model.relative_fileloc = "test_dag.py"
task.dag_run = mock.Mock()
task.dag_run.bundle_version = "1.0.0"
task.dag_run.context_carrier = {}

if not AIRFLOW_V_3_0_PLUS:
task.command_as_list.return_value = [
"airflow",
"tasks",
"run",
"dag",
f"task_{idx}",
"2024-01-01",
]

not_adopted_tasks = mock_executor.try_adopt_task_instances(orphaned_tasks)

Expand All @@ -805,6 +835,59 @@ def test_try_adopt_task_instances(self, mock_executor):
# The remaining one task is unable to be adopted.
assert len(not_adopted_tasks) == 1

@pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3+")
def test_serialize_workload_to_command(self, mock_executor):
    """Check that a workload's JSON dump is wrapped in the Task SDK command line."""
    from airflow.executors.workloads import ExecuteTask

    payload = json.dumps({"test_key": "test_value"})
    fake_workload = mock.Mock(spec=ExecuteTask)
    fake_workload.model_dump_json.return_value = payload
    expected = [
        "python",
        "-m",
        "airflow.sdk.execution_time.execute_workload",
        "--json-string",
        payload,
    ]

    result = mock_executor._serialize_workload_to_command(fake_workload)

    assert result == expected
    # The workload must be serialized exactly once.
    fake_workload.model_dump_json.assert_called_once()

@pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3+")
@mock.patch("airflow.executors.workloads.ExecuteTask")
def test_build_task_command_airflow3(self, mock_execute_task_class, mock_executor):
    """Check that on Airflow 3.x+ the command comes from an ExecuteTask workload."""
    payload = json.dumps({"task": "data"})
    fake_ti = mock.Mock(spec=TaskInstance)
    # The patched ExecuteTask.make returns a workload whose JSON dump is ``payload``.
    fake_workload = mock.Mock(**{"model_dump_json.return_value": payload})
    mock_execute_task_class.make.return_value = fake_workload
    expected = [
        "python",
        "-m",
        "airflow.sdk.execution_time.execute_workload",
        "--json-string",
        payload,
    ]

    result = mock_executor._build_task_command(fake_ti)

    mock_execute_task_class.make.assert_called_once_with(fake_ti)
    assert result == expected

@pytest.mark.skipif(AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 2.x")
def test_build_task_command_airflow2(self, mock_executor):
    """Check that on Airflow 2.x the command is whatever command_as_list returns."""
    legacy_command = ["airflow", "tasks", "run", "dag_id", "task_id", "execution_date"]
    fake_ti = mock.Mock(spec=TaskInstance)
    fake_ti.command_as_list.return_value = legacy_command

    result = mock_executor._build_task_command(fake_ti)

    fake_ti.command_as_list.assert_called_once()
    assert result == legacy_command

@pytest.mark.skipif(not AIRFLOW_V_3_1_PLUS, reason="Multi-team support requires Airflow 3.1+")
def test_team_config(self):
"""Test that the executor uses team-specific configuration when provided via self.conf."""
Expand Down