Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 35 additions & 3 deletions task-sdk/src/airflow/sdk/coordinators/java/coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import selectors
import signal
import socket
import stat
import subprocess
import time
import zipfile
Expand Down Expand Up @@ -62,13 +63,44 @@ def _start_server() -> socket.socket:


def _find_jars(items: Iterable[pathlib.Path]) -> Iterator[pathlib.Path]:
    """
    Yield the JAR files reachable from *items*, recursing into directories.

    Every call starts with an empty set of visited-directory keys and hands
    it to ``_walk_jars``, which records each directory it enters as
    ``(st_dev, st_ino)`` and refuses to enter the same one twice. This keeps
    a symlink loop (or any other path that leads back to a directory already
    scanned) from recursing until the interpreter stack is exhausted. The set
    lives only for the duration of one scan.
    """
    yield from _walk_jars(items, set())


def _walk_jars(items: Iterable[pathlib.Path], seen_dirs: set[tuple[int, int]]) -> Iterator[pathlib.Path]:
for item in items:
if item.is_dir():
yield from _find_jars(item.iterdir())
elif item.is_file() and item.suffix == ".jar":
try:
st = item.stat()
except OSError:
continue
if stat.S_ISDIR(st.st_mode):
key = (st.st_dev, st.st_ino)
if key in seen_dirs:
log.debug("Skipping already-visited directory", path=item)
continue
seen_dirs.add(key)
yield from _walk_jars(_iter_dir(item), seen_dirs)
elif stat.S_ISREG(st.st_mode) and item.suffix == ".jar":
yield item


def _iter_dir(directory: pathlib.Path) -> Iterator[pathlib.Path]:
# iterdir() is lazy, so an unreadable directory raises only once iteration
# starts; swallow it here so a single bad directory does not abort the scan.
try:
yield from directory.iterdir()
except OSError:
return


def _calculate_classpath(jars_root: Sequence[pathlib.Path]) -> str:
    """
    Build a classpath string from every JAR found under *jars_root*.

    JAR paths are rendered in POSIX form, sorted so the result is
    deterministic regardless of directory listing order, and joined with
    ``os.pathsep``.
    """
    posix_paths = sorted(jar.as_posix() for jar in _find_jars(jars_root))
    return os.pathsep.join(posix_paths)
Expand Down
57 changes: 57 additions & 0 deletions task-sdk/tests/task_sdk/coordinators/java/test_coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
_JavaActivitySubprocess,
_ResourceTracker,
_start_server,
_walk_jars,
)
from airflow.sdk.execution_time.coordinator import BaseCoordinator
from airflow.sdk.execution_time.supervisor import ActivitySubprocess
Expand Down Expand Up @@ -214,6 +215,62 @@ def test_find_by_explicit_main_class_not_present_raises(self, tmp_path):
with pytest.raises(FileNotFoundError, match="com.example.Missing"):
_JarInfo.find([tmp_path], "com.example.Missing")

def test_symlink_cycle_does_not_infinite_recurse(self, tmp_path):
    """Lookup by main class still succeeds when a symlink points back at the scan root."""
    inner_dir = tmp_path / "inner"
    inner_dir.mkdir()
    _make_jar(inner_dir / "app.jar", main_class="com.example.Loop", schema_version="2026-06-16")

    # inner/loop -> tmp_path closes the cycle tmp_path/inner/loop/inner/...
    back_link = inner_dir / "loop"
    try:
        back_link.symlink_to(tmp_path)
    except (OSError, NotImplementedError):
        pytest.skip("symlinks not supported on this platform")

    found = _JarInfo.find([tmp_path], "com.example.Loop")
    expected = _JarInfo("com.example.Loop", "2026-06-16")
    assert found == expected


class TestWalkJars:
    """Tests for how ``_walk_jars`` uses and updates its ``seen_dirs`` argument."""

    def test_skips_directory_whose_key_is_already_in_seen_dirs(self, tmp_path):
        """Pre-seeding seen_dirs with a directory's key makes the walk yield nothing from it."""
        _make_jar(tmp_path / "app.jar", main_class="com.example.Main", schema_version="2026-06-16")
        root_stat = tmp_path.stat()
        already_seen: set[tuple[int, int]] = {(root_stat.st_dev, root_stat.st_ino)}

        found = list(_walk_jars([tmp_path], already_seen))

        assert found == []

    def test_records_visited_directories_in_seen_dirs(self, tmp_path):
        """Both the root and a nested directory end up recorded in seen_dirs."""
        sub = tmp_path / "sub"
        sub.mkdir()
        _make_jar(sub / "app.jar", main_class="com.example.Main", schema_version="2026-06-16")
        visited: set[tuple[int, int]] = set()

        # Drain the generator; only the side effect on ``visited`` matters here.
        for _ in _walk_jars([tmp_path], visited):
            pass

        for directory in (tmp_path, sub):
            dir_stat = directory.stat()
            assert (dir_stat.st_dev, dir_stat.st_ino) in visited

    def test_symlink_cycle_yields_each_jar_once(self, tmp_path):
        """With a symlink looping back to an ancestor, the single JAR is yielded exactly once."""
        inner_dir = tmp_path / "inner"
        inner_dir.mkdir()
        jar = _make_jar(inner_dir / "app.jar", main_class="com.example.Loop", schema_version="2026-06-16")
        back_link = inner_dir / "loop"
        try:
            back_link.symlink_to(tmp_path)
        except (OSError, NotImplementedError):
            pytest.skip("symlinks not supported on this platform")

        visited: set[tuple[int, int]] = set()
        resolved = [path.resolve() for path in _walk_jars([tmp_path], visited)]

        assert resolved == [jar.resolve()]

    def test_skip_logged_when_directory_revisited(self, tmp_path):
        """Walking a directory already present in seen_dirs emits the skip debug message."""
        sub = tmp_path / "sub"
        sub.mkdir()
        sub_stat = sub.stat()
        visited: set[tuple[int, int]] = {(sub_stat.st_dev, sub_stat.st_ino)}

        with patch("airflow.sdk.coordinators.java.coordinator.log") as mock_log:
            list(_walk_jars([sub], visited))
            mock_log.debug.assert_any_call("Skipping already-visited directory", path=sub)


class TestAcceptConnections:
def _connect_after_delay(self, addr: tuple[str, int], delay: float = 0.0) -> None:
Expand Down
Loading