providers/edge3/src/airflow/providers/edge3/cli/worker.py (46 changes: 27 additions & 19 deletions)
@@ -59,7 +59,7 @@
EdgeWorkerState,
EdgeWorkerVersionException,
)
from airflow.providers.edge3.version_compat import AIRFLOW_V_3_2_PLUS
from airflow.providers.edge3.version_compat import AIRFLOW_V_3_2_PLUS, AIRFLOW_V_3_3_PLUS
from airflow.utils.net import getfqdn
from airflow.utils.state import TaskInstanceState

@@ -416,30 +416,38 @@ def _get_state(self) -> EdgeWorkerState:
def _run_job_via_supervisor(self, workload: ExecuteTask, results_queue: Queue) -> int:
_reset_parent_signal_state()

from airflow.sdk.execution_time.supervisor import supervise

# Ignore ctrl-c in this process -- we don't want to kill _this_ one. we let tasks run to completion
os.setpgrp()

logger.info("Worker starting up pid=%d", os.getpid())
ti = workload.ti
setproctitle(
"airflow edge supervisor: "
f"dag_id={ti.dag_id} task_id={ti.task_id} run_id={ti.run_id} map_index={ti.map_index} "
f"try_number={ti.try_number}"
)

try:
supervise(
# This is the "wrong" ti type, but it duck types the same. TODO: Create a protocol for this.
# Same as in airflow/executors/local_executor.py:_execute_workload()
ti=ti, # type: ignore[arg-type]
dag_rel_path=workload.dag_rel_path,
bundle_info=workload.bundle_info,
token=workload.token,
server=self._execution_api_server_url,
log_path=workload.log_path,
)
if AIRFLOW_V_3_3_PLUS:
from airflow.executors.base_executor import BaseExecutor

BaseExecutor.run_workload(
workload=workload,
server=self._execution_api_server_url,
)
else:
from airflow.sdk.execution_time.supervisor import supervise

ti = workload.ti
setproctitle(
"airflow edge supervisor: "
f"dag_id={ti.dag_id} task_id={ti.task_id} run_id={ti.run_id} map_index={ti.map_index} "
f"try_number={ti.try_number}"
)
supervise(
# This is the "wrong" ti type, but it duck types the same. TODO: Create a protocol for this.
# Same as in airflow/executors/local_executor.py:_execute_workload()
ti=ti, # type: ignore[arg-type]
dag_rel_path=workload.dag_rel_path,
bundle_info=workload.bundle_info,
token=workload.token,
server=self._execution_api_server_url,
log_path=workload.log_path,
)
results_queue.put("OK")
return 0
except Exception as e:
@@ -34,8 +34,10 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]:

AIRFLOW_V_3_1_PLUS = get_base_airflow_version_tuple() >= (3, 1, 0)
AIRFLOW_V_3_2_PLUS = get_base_airflow_version_tuple() >= (3, 2, 0)
AIRFLOW_V_3_3_PLUS = get_base_airflow_version_tuple() >= (3, 3, 0)

__all__ = [
"AIRFLOW_V_3_1_PLUS",
"AIRFLOW_V_3_2_PLUS",
"AIRFLOW_V_3_3_PLUS",
]
providers/edge3/tests/unit/edge3/cli/test_worker.py (30 changes: 26 additions & 4 deletions)
@@ -54,7 +54,7 @@
)

from tests_common.test_utils.config import conf_vars
from tests_common.test_utils.version_compat import AIRFLOW_V_3_2_PLUS
from tests_common.test_utils.version_compat import AIRFLOW_V_3_2_PLUS, AIRFLOW_V_3_3_PLUS

pytest.importorskip("pydantic", minversion="2.0.0")
pytestmark = [pytest.mark.asyncio]
@@ -235,8 +235,9 @@ def test_execution_api_server_url(
assert url == expected_url

@patch("airflow.sdk.execution_time.supervisor.supervise")
@pytest.mark.skipif(AIRFLOW_V_3_3_PLUS, reason="Test is for Airflow < 3.3.0 where supervise was used")
@pytest.mark.asyncio
async def test_supervise_launch(
async def test_supervise_launch_pre_3_3(
self,
mock_supervise,
worker_with_job: EdgeWorker,
@@ -247,7 +248,25 @@ async def test_supervise_launch(
result = worker_with_job._run_job_via_supervisor(edge_job.command, q)

assert result == 0
q.put.assert_called_once()
q.put.assert_called_once_with("OK")

@patch("airflow.executors.base_executor.BaseExecutor.run_workload")
@pytest.mark.skipif(
not AIRFLOW_V_3_3_PLUS, reason="Test is for Airflow >= 3.3.0 where BaseExecutor.run_workload is used"
)
@pytest.mark.asyncio
async def test_supervise_launch(
self,
mock_run_workload,
worker_with_job: EdgeWorker,
):
worker_with_job.__dict__["_execution_api_server_url"] = "https://mock-server/execution"
edge_job = worker_with_job.jobs.pop().edge_job
q = mock.MagicMock()
result = worker_with_job._run_job_via_supervisor(edge_job.command, q)

assert result == 0
q.put.assert_called_once_with("OK")

@patch("airflow.sdk.execution_time.supervisor.supervise")
@pytest.mark.asyncio
@@ -889,6 +908,9 @@ def test_reset_parent_signal_state_clears_all_handlers(self):
for sig, prev in original.items():
signal.signal(sig, prev)

@pytest.mark.skipif(
not AIRFLOW_V_3_3_PLUS, reason="Test is for Airflow >= 3.3.0 where BaseExecutor.run_workload is used"
)
def test_run_job_via_supervisor_resets_signals_before_supervise(self, tmp_path):
"""Reset must run first: before ``os.setpgrp`` and before ``supervise``."""
worker = EdgeWorker(str(tmp_path / "mock.pid"), "mock", None, 8)
@@ -901,7 +923,7 @@ def test_run_job_via_supervisor_resets_signals_before_supervise(self, tmp_path):
),
mock.patch("os.setpgrp", side_effect=lambda: order("setpgrp")),
mock.patch(
"airflow.sdk.execution_time.supervisor.supervise",
"airflow.executors.base_executor.BaseExecutor.run_workload",
side_effect=lambda **_: order("supervise"),
),
):