From 148f85dd0cca18a41e7cc60e0bb670c76d8e0110 Mon Sep 17 00:00:00 2001 From: Jens Scheffler Date: Sat, 25 Apr 2026 20:09:40 +0200 Subject: [PATCH 1/2] Support BaseExecutor.run_workload() for Airflow 3.3 in Edge Worker --- .../src/airflow/providers/edge3/cli/worker.py | 46 +++++++++++-------- .../airflow/providers/edge3/version_compat.py | 2 + .../edge3/tests/unit/edge3/cli/test_worker.py | 28 +++++++++-- 3 files changed, 54 insertions(+), 22 deletions(-) diff --git a/providers/edge3/src/airflow/providers/edge3/cli/worker.py b/providers/edge3/src/airflow/providers/edge3/cli/worker.py index 5eff1a9cac850..778af0dc325b2 100644 --- a/providers/edge3/src/airflow/providers/edge3/cli/worker.py +++ b/providers/edge3/src/airflow/providers/edge3/cli/worker.py @@ -59,7 +59,7 @@ EdgeWorkerState, EdgeWorkerVersionException, ) -from airflow.providers.edge3.version_compat import AIRFLOW_V_3_2_PLUS +from airflow.providers.edge3.version_compat import AIRFLOW_V_3_2_PLUS, AIRFLOW_V_3_3_PLUS from airflow.utils.net import getfqdn from airflow.utils.state import TaskInstanceState @@ -416,30 +416,38 @@ def _get_state(self) -> EdgeWorkerState: def _run_job_via_supervisor(self, workload: ExecuteTask, results_queue: Queue) -> int: _reset_parent_signal_state() - from airflow.sdk.execution_time.supervisor import supervise - # Ignore ctrl-c in this process -- we don't want to kill _this_ one. we let tasks run to completion os.setpgrp() logger.info("Worker starting up pid=%d", os.getpid()) - ti = workload.ti - setproctitle( - "airflow edge supervisor: " - f"dag_id={ti.dag_id} task_id={ti.task_id} run_id={ti.run_id} map_index={ti.map_index} " - f"try_number={ti.try_number}" - ) try: - supervise( - # This is the "wrong" ti type, but it duck types the same. TODO: Create a protocol for this. 
- # Same like in airflow/executors/local_executor.py:_execute_workload() - ti=ti, # type: ignore[arg-type] - dag_rel_path=workload.dag_rel_path, - bundle_info=workload.bundle_info, - token=workload.token, - server=self._execution_api_server_url, - log_path=workload.log_path, - ) + if AIRFLOW_V_3_3_PLUS: + from airflow.executors.base_executor import BaseExecutor + + BaseExecutor.run_workload( + workload=workload, + server=self._execution_api_server_url, + ) + else: + from airflow.sdk.execution_time.supervisor import supervise + + ti = workload.ti + setproctitle( + "airflow edge supervisor: " + f"dag_id={ti.dag_id} task_id={ti.task_id} run_id={ti.run_id} map_index={ti.map_index} " + f"try_number={ti.try_number}" + ) + supervise( + # This is the "wrong" ti type, but it duck types the same. TODO: Create a protocol for this. + # Same like in airflow/executors/local_executor.py:_execute_workload() + ti=ti, # type: ignore[arg-type] + dag_rel_path=workload.dag_rel_path, + bundle_info=workload.bundle_info, + token=workload.token, + server=self._execution_api_server_url, + log_path=workload.log_path, + ) results_queue.put("OK") return 0 except Exception as e: diff --git a/providers/edge3/src/airflow/providers/edge3/version_compat.py b/providers/edge3/src/airflow/providers/edge3/version_compat.py index 61b31ae45a22d..0f3b2b445c15a 100644 --- a/providers/edge3/src/airflow/providers/edge3/version_compat.py +++ b/providers/edge3/src/airflow/providers/edge3/version_compat.py @@ -34,8 +34,10 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]: AIRFLOW_V_3_1_PLUS = get_base_airflow_version_tuple() >= (3, 1, 0) AIRFLOW_V_3_2_PLUS = get_base_airflow_version_tuple() >= (3, 2, 0) +AIRFLOW_V_3_3_PLUS = get_base_airflow_version_tuple() >= (3, 3, 0) __all__ = [ "AIRFLOW_V_3_1_PLUS", "AIRFLOW_V_3_2_PLUS", + "AIRFLOW_V_3_3_PLUS", ] diff --git a/providers/edge3/tests/unit/edge3/cli/test_worker.py b/providers/edge3/tests/unit/edge3/cli/test_worker.py index 
5ac72b84867f1..cc77e31388acb 100644 --- a/providers/edge3/tests/unit/edge3/cli/test_worker.py +++ b/providers/edge3/tests/unit/edge3/cli/test_worker.py @@ -54,7 +54,7 @@ ) from tests_common.test_utils.config import conf_vars -from tests_common.test_utils.version_compat import AIRFLOW_V_3_2_PLUS +from tests_common.test_utils.version_compat import AIRFLOW_V_3_2_PLUS, AIRFLOW_V_3_3_PLUS pytest.importorskip("pydantic", minversion="2.0.0") pytestmark = [pytest.mark.asyncio] @@ -235,8 +235,9 @@ def test_execution_api_server_url( assert url == expected_url @patch("airflow.sdk.execution_time.supervisor.supervise") + @pytest.mark.skipif(AIRFLOW_V_3_3_PLUS, reason="Test is for Airflow < 3.3.0 where supervise was used") @pytest.mark.asyncio - async def test_supervise_launch( + async def test_supervise_launch_pre_3_3( self, mock_supervise, worker_with_job: EdgeWorker, @@ -246,6 +247,24 @@ async def test_supervise_launch( q = mock.MagicMock() result = worker_with_job._run_job_via_supervisor(edge_job.command, q) + assert result == 0 + q.put.assert_not_called() + + @patch("airflow.executors.base_executor.BaseExecutor.run_workload") + @pytest.mark.skipif( + not AIRFLOW_V_3_3_PLUS, reason="Test is for Airflow >= 3.3.0 where BaseExecutor.run_workload is used" + ) + @pytest.mark.asyncio + async def test_supervise_launch( + self, + mock_run_workload, + worker_with_job: EdgeWorker, + ): + worker_with_job.__dict__["_execution_api_server_url"] = "https://mock-server/execution" + edge_job = worker_with_job.jobs.pop().edge_job + q = mock.MagicMock() + result = worker_with_job._run_job_via_supervisor(edge_job.command, q) + assert result == 0 q.put.assert_called_once() @@ -889,6 +908,9 @@ def test_reset_parent_signal_state_clears_all_handlers(self): for sig, prev in original.items(): signal.signal(sig, prev) + @pytest.mark.skipif( + not AIRFLOW_V_3_3_PLUS, reason="Test is for Airflow >= 3.3.0 where BaseExecutor.run_workload is used" + ) def 
test_run_job_via_supervisor_resets_signals_before_supervise(self, tmp_path): """Reset must run first: before ``os.setpgrp`` and before ``supervise``.""" worker = EdgeWorker(str(tmp_path / "mock.pid"), "mock", None, 8) @@ -901,7 +923,7 @@ def test_run_job_via_supervisor_resets_signals_before_supervise(self, tmp_path): ), mock.patch("os.setpgrp", side_effect=lambda: order("setpgrp")), mock.patch( - "airflow.sdk.execution_time.supervisor.supervise", + "airflow.executors.base_executor.BaseExecutor.run_workload", side_effect=lambda **_: order("supervise"), ), ): From d99f9ed008201db5192cdce4cd113c697baaae0e Mon Sep 17 00:00:00 2001 From: Jens Scheffler Date: Sun, 10 May 2026 16:56:25 +0200 Subject: [PATCH 2/2] Fix results-queue assertions in supervise launch tests --- providers/edge3/tests/unit/edge3/cli/test_worker.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/providers/edge3/tests/unit/edge3/cli/test_worker.py b/providers/edge3/tests/unit/edge3/cli/test_worker.py index cc77e31388acb..0be1573353dbe 100644 --- a/providers/edge3/tests/unit/edge3/cli/test_worker.py +++ b/providers/edge3/tests/unit/edge3/cli/test_worker.py @@ -248,7 +248,7 @@ async def test_supervise_launch_pre_3_3( result = worker_with_job._run_job_via_supervisor(edge_job.command, q) assert result == 0 - q.put.assert_not_called() + q.put.assert_called_once_with("OK") @patch("airflow.executors.base_executor.BaseExecutor.run_workload") @pytest.mark.skipif( @@ -266,7 +266,7 @@ async def test_supervise_launch( result = worker_with_job._run_job_via_supervisor(edge_job.command, q) assert result == 0 - q.put.assert_called_once() + q.put.assert_called_once_with("OK") @patch("airflow.sdk.execution_time.supervisor.supervise") @pytest.mark.asyncio