From 9767b0771cc2c913090b31fb509f9f2af9b3a37a Mon Sep 17 00:00:00 2001 From: LIU ZHE YOU Date: Mon, 18 May 2026 23:11:38 +0800 Subject: [PATCH 1/9] Add prek hook to check that no new @provide_session functions declare `session` positionally --- .pre-commit-config.yaml | 6 + .../ci/prek/check_provide_session_kwargs.py | 308 +++++++++++++++++ .../prek/known_provide_session_positional.txt | 89 +++++ .../prek/test_check_provide_session_kwargs.py | 320 ++++++++++++++++++ 4 files changed, 723 insertions(+) create mode 100755 scripts/ci/prek/check_provide_session_kwargs.py create mode 100644 scripts/ci/prek/known_provide_session_positional.txt create mode 100644 scripts/tests/ci/prek/test_check_provide_session_kwargs.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d019a32715075..0b659de77b80c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1064,6 +1064,12 @@ repos: language: python pass_filenames: true files: ^(airflow-core|airflow-ctl|task-sdk|providers|shared)/.*\.py$ + - id: check-no-new-provide-session-positional + name: Check that no new @provide_session functions declare `session` positionally + entry: ./scripts/ci/prek/check_provide_session_kwargs.py + language: python + pass_filenames: true + files: ^(airflow-core|airflow-ctl|task-sdk|providers|shared)/.*\.py$|^scripts/ci/prek/known_provide_session_positional\.txt$|^scripts/ci/prek/check_provide_session_kwargs\.py$ - id: check-no-new-airflow-core-utils-modules name: Check that no new modules are added under airflow-core/src/airflow/utils entry: ./scripts/ci/prek/check_no_new_airflow_core_utils_modules.py diff --git a/scripts/ci/prek/check_provide_session_kwargs.py b/scripts/ci/prek/check_provide_session_kwargs.py new file mode 100755 index 0000000000000..ae1bcfc6dec50 --- /dev/null +++ b/scripts/ci/prek/check_provide_session_kwargs.py @@ -0,0 +1,308 @@ +#!/usr/bin/env python +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# /// script +# requires-python = ">=3.10" +# dependencies = [ +# "rich>=13.0.0", +# ] +# /// +"""Check that no new ``@provide_session`` functions declare ``session`` positionally. + +The project convention is that any function decorated with ``@provide_session`` +must declare ``session`` as keyword-only (after a bare ``*`` in the signature), +so callers cannot pass it positionally by accident. See +``contributing-docs/05_pull_requests.rst#database-session-handling``. + +All *existing* offenders are recorded in ``known_provide_session_positional.txt`` +next to this script as ``relative/path::N`` entries (one per file), where ``N`` +is the maximum number of ``@provide_session`` functions with a positional +``session`` argument allowed in that file. A file whose current count exceeds +the recorded limit is treated as a violation – move the ``session`` argument +behind a bare ``*`` instead. + +Modes +----- +Default (files passed by prek/pre-commit): + Check only the supplied files; fail if any file's count exceeds the limit. + When a file's count has *decreased*, the allowlist entry is tightened + automatically and the hook exits with a non-zero code so that pre-commit + reports the modified allowlist – just stage + ``scripts/ci/prek/known_provide_session_positional.txt`` and re-run. + +``--all-files``: + Walk the whole repository and check every ``.py`` file. + +``--cleanup``: + Remove entries for files that no longer exist. Safe to run at any time; + does not add new entries or raise limits. + +``--generate``: + Scan the whole repository and *rebuild* the allowlist from scratch. + Intended for the initial setup or after a large-scale clean-up sprint. +""" + +from __future__ import annotations + +import argparse +import ast +import typing +from pathlib import Path + +from rich.console import Console +from rich.panel import Panel + +console = Console(color_system="standard", width=200) + +REPO_ROOT = Path(__file__).parents[3] + +_PROVIDE_SESSION_DECORATOR = "provide_session" + + +def _has_provide_session_decorator(nodes: list[ast.expr]) -> bool: + """Whether one of ``nodes`` is a ``@provide_session`` decorator. + + Accepts both bare names (``@provide_session``) and attribute access + (``@something.provide_session``). + """ + for node in nodes: + if isinstance(node, ast.Name) and node.id == _PROVIDE_SESSION_DECORATOR: + return True + if isinstance(node, ast.Attribute) and node.attr == _PROVIDE_SESSION_DECORATOR: + return True + return False + + +def _session_is_positional(args: ast.arguments) -> ast.arg | None: + """Return the ``session`` arg if it is positional (not keyword-only).""" + for argument in args.args: + if argument.arg == "session": + return argument + return None + + +def _iter_positional_session_in_provide_session( + path: Path, +) -> typing.Iterator[tuple[ast.FunctionDef | ast.AsyncFunctionDef, ast.arg]]: + """Yield ``@provide_session`` functions in *path* whose ``session`` is positional.""" + try: + source = path.read_text(encoding="utf-8") + except OSError: + return + try: + tree = ast.parse(source, str(path)) + except SyntaxError: + return + for node in ast.walk(tree): + if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): + continue + if not _has_provide_session_decorator(node.decorator_list): + continue + argument = _session_is_positional(node.args) + if argument is None: + continue + yield node, argument + + +def _count_violations(path: Path) -> int: + return sum(1 for _ in _iter_positional_session_in_provide_session(path)) + + +class AllowlistManager: + def __init__(self, allowlist_file: Path) -> None: + self.allowlist_file = allowlist_file + + def load(self) -> dict[str, int]: + if not self.allowlist_file.exists(): + return {} + + result: dict[str, int] = {} + for raw_line in self.allowlist_file.read_text().splitlines(): + if not (stripped := raw_line.strip()): + continue + + rel_str, _, count_str = stripped.rpartition("::") + if not rel_str or not count_str: + continue + + try: + result[rel_str] = int(count_str) + except ValueError: + continue + + return result + + def save(self, counts: dict[str, int]) -> None: + lines = [f"{rel}::{count}" for rel, count in sorted(counts.items())] + self.allowlist_file.write_text("\n".join(lines) + "\n") + + def generate(self) -> int: + console.print( + f"Scanning [cyan]{REPO_ROOT}[/cyan] for @provide_session functions with positional session …" + ) + counts: dict[str, int] = {} + for path in _iter_python_files(): + n = _count_violations(path) + if n > 0: + counts[str(path.relative_to(REPO_ROOT))] = n + + self.save(counts) + total = sum(counts.values()) + console.print( + f"[green]Generated[/green] [cyan]{self.allowlist_file.relative_to(REPO_ROOT)}[/cyan] " + f"with [bold]{len(counts)}[/bold] files / [bold]{total}[/bold] offenders." + ) + return 0 + + def cleanup(self) -> int: + allowlist = self.load() + if not allowlist: + console.print("[yellow]Allowlist is empty - nothing to clean up.[/yellow]") + return 0 + + stale: list[str] = [rel for rel in allowlist if not (REPO_ROOT / rel).exists()] + if stale: + console.print( + f"[yellow]Removing {len(stale)} stale entr{'y' if len(stale) == 1 else 'ies'}:[/yellow]" + ) + for s in sorted(stale): + console.print(f" [dim]-[/dim] {s}") + for s in stale: + del allowlist[s] + self.save(allowlist) + console.print( + f"\n[green]Updated[/green] [cyan]{self.allowlist_file.relative_to(REPO_ROOT)}[/cyan]" + ) + else: + console.print("[green]No stale entries found.[/green]") + return 0 + + +def _iter_python_files() -> list[Path]: + candidates: list[Path] = [] + for top in ("airflow-core", "airflow-ctl", "task-sdk", "providers", "shared"): + candidates.extend( + p.resolve() + for p in (REPO_ROOT / top).rglob("*.py") + if ".tox" not in p.parts and "__pycache__" not in p.parts + ) + return candidates + + +def _check_provide_session_kwargs( + files: list[Path], allowlist: dict[str, int], manager: AllowlistManager +) -> int: + violations: list[tuple[Path, int, int]] = [] + tightened: list[tuple[str, int, int]] = [] + + for path in files: + if not path.exists() or path.suffix != ".py": + continue + actual = _count_violations(path) + rel = str(path.relative_to(REPO_ROOT)) + allowed = allowlist.get(rel, 0) + if actual > allowed: + violations.append((path, actual, allowed)) + elif actual < allowed: + if actual == 0: + del allowlist[rel] + else: + allowlist[rel] = actual + tightened.append((rel, allowed, actual)) + + if tightened: + manager.save(allowlist) + console.print( + f"[green]Tightened {len(tightened)} entr{'y' if len(tightened) == 1 else 'ies'} " + f"in [cyan]{manager.allowlist_file.relative_to(REPO_ROOT)}[/cyan][/green] " + "(stage the updated file):" + ) + for rel, old, new in tightened: + console.print(f" [cyan]{rel}[/cyan] {old} -> {new}") + + if violations: + console.print( + Panel.fit( + "New [bold]@provide_session[/bold] function with positional ``session`` detected.\n" + "Move ``session`` after a bare ``*`` in the signature so callers must pass it by keyword:\n\n" + " [cyan]@provide_session\n" + " def foo(arg, *, session: Session = NEW_SESSION) -> None: ...[/cyan]\n\n" + "If this usage is intentional and pre-existing, run:\n\n" + " [cyan]uv run ./scripts/ci/prek/check_provide_session_kwargs.py --generate[/cyan]\n\n" + "to regenerate the allowlist, then commit the updated\n" + "[cyan]scripts/ci/prek/known_provide_session_positional.txt[/cyan].", + title="[red]Check failed[/red]", + border_style="red", + ) + ) + for path, actual, allowed in violations: + console.print(f" [cyan]{path.relative_to(REPO_ROOT)}[/cyan] count={actual} (allowed={allowed})") + for func, argument in _iter_positional_session_in_provide_session(path): + console.print(f" [dim]L{argument.lineno}[/dim] def {func.name}(...)") + return 1 + + return 1 if tightened else 0 + + +def main(argv: list[str] | None = None) -> int: + parser = argparse.ArgumentParser( + description="Prevent new @provide_session functions from declaring `session` positionally.", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=__doc__, + ) + parser.add_argument("files", nargs="*", metavar="FILE", help="Files to check (provided by prek)") + parser.add_argument( + "--all-files", + action="store_true", + help="Check every Python file in the repository", + ) + parser.add_argument( + "--cleanup", + action="store_true", + help="Remove stale entries from the allowlist and exit", + ) + parser.add_argument( + "--generate", + action="store_true", + help="Regenerate the allowlist from the current codebase and exit", + ) + args = parser.parse_args(argv) + + manager = AllowlistManager(Path(__file__).parent / "known_provide_session_positional.txt") + + if args.generate: + return manager.generate() + + if args.cleanup: + return manager.cleanup() + + allowlist = manager.load() + + if args.all_files: + return _check_provide_session_kwargs(_iter_python_files(), allowlist, manager) + + if not args.files: + console.print( + "[yellow]No files provided. Pass filenames or use --all-files to scan the whole repo.[/yellow]" + ) + return 0 + + return _check_provide_session_kwargs([Path(f).resolve() for f in args.files], allowlist, manager) + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/scripts/ci/prek/known_provide_session_positional.txt b/scripts/ci/prek/known_provide_session_positional.txt new file mode 100644 index 0000000000000..76f515eeca7b7 --- /dev/null +++ b/scripts/ci/prek/known_provide_session_positional.txt @@ -0,0 +1,89 @@ +airflow-core/src/airflow/api/common/delete_dag.py::1 +airflow-core/src/airflow/api/common/mark_tasks.py::1 +airflow-core/src/airflow/callbacks/database_callback_sink.py::1 +airflow-core/src/airflow/cli/commands/dag_command.py::8 +airflow-core/src/airflow/cli/commands/jobs_command.py::1 +airflow-core/src/airflow/cli/commands/task_command.py::1 +airflow-core/src/airflow/cli/commands/team_command.py::4 +airflow-core/src/airflow/cli/commands/variable_command.py::1 +airflow-core/src/airflow/dag_processing/dagbag.py::1 +airflow-core/src/airflow/dag_processing/manager.py::4 +airflow-core/src/airflow/jobs/base_job_runner.py::2 +airflow-core/src/airflow/jobs/job.py::7 +airflow-core/src/airflow/jobs/scheduler_job_runner.py::11 +airflow-core/src/airflow/jobs/triggerer_job_runner.py::1 +airflow-core/src/airflow/models/connection.py::2 +airflow-core/src/airflow/models/dag.py::7 +airflow-core/src/airflow/models/dagcode.py::6 +airflow-core/src/airflow/models/dagrun.py::15 +airflow-core/src/airflow/models/dagwarning.py::1 +airflow-core/src/airflow/models/deadline.py::1 +airflow-core/src/airflow/models/deadline_alert.py::1 +airflow-core/src/airflow/models/pool.py::13 +airflow-core/src/airflow/models/renderedtifields.py::4 +airflow-core/src/airflow/models/revoked_token.py::2 +airflow-core/src/airflow/models/serialized_dag.py::6 +airflow-core/src/airflow/models/taskinstance.py::21 +airflow-core/src/airflow/models/taskinstancehistory.py::2 +airflow-core/src/airflow/models/team.py::1 +airflow-core/src/airflow/models/trigger.py::7 +airflow-core/src/airflow/models/variable.py::2 +airflow-core/src/airflow/secrets/metastore.py::2 +airflow-core/src/airflow/serialization/definitions/dag.py::2 +airflow-core/src/airflow/ti_deps/deps/base_ti_dep.py::2 +airflow-core/src/airflow/ti_deps/deps/dag_ti_slots_available_dep.py::1 +airflow-core/src/airflow/ti_deps/deps/dag_unpaused_dep.py::1 +airflow-core/src/airflow/ti_deps/deps/dagrun_exists_dep.py::1 +airflow-core/src/airflow/ti_deps/deps/exec_date_after_start_date_dep.py::1 +airflow-core/src/airflow/ti_deps/deps/not_in_retry_period_dep.py::1 +airflow-core/src/airflow/ti_deps/deps/pool_slots_available_dep.py::1 +airflow-core/src/airflow/ti_deps/deps/prev_dagrun_dep.py::1 +airflow-core/src/airflow/ti_deps/deps/ready_to_reschedule.py::1 +airflow-core/src/airflow/ti_deps/deps/runnable_exec_date_dep.py::1 +airflow-core/src/airflow/ti_deps/deps/task_concurrency_dep.py::1 +airflow-core/src/airflow/ti_deps/deps/task_not_running_dep.py::1 +airflow-core/src/airflow/ti_deps/deps/valid_state_dep.py::1 +airflow-core/src/airflow/utils/cli_action_loggers.py::1 +airflow-core/src/airflow/utils/db.py::7 +airflow-core/src/airflow/utils/db_cleanup.py::2 +airflow-core/src/airflow/utils/log/file_task_handler.py::1 +airflow-core/tests/unit/api_fastapi/common/test_exceptions.py::4 +airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_assets.py::19 +airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_connections.py::2 +airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_dag_run.py::1 +airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_dag_tags.py::1 +airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_dag_warning.py::1 +airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_event_logs.py::1 +airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_import_error.py::8 +airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_job.py::1 +airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_monitor.py::2 +airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_pools.py::2 +airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_variables.py::2 +airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_xcom.py::3 +airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_calendar.py::2 +airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_dags.py::1 +airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_gantt.py::1 +airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_grid.py::1 +airflow-core/tests/unit/cli/commands/test_rotate_fernet_key_command.py::2 +airflow-core/tests/unit/jobs/test_scheduler_job.py::1 +airflow-core/tests/unit/listeners/test_listeners.py::7 +airflow-core/tests/unit/models/test_taskinstance.py::4 +airflow-core/tests/unit/models/test_timestamp.py::2 +providers/amazon/src/airflow/providers/amazon/aws/triggers/emr.py::1 +providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor.py::1 +providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/template_rendering.py::1 +providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/triggers/pod.py::1 +providers/common/ai/tests/unit/common/ai/plugins/test_hitl_review.py::1 +providers/databricks/src/airflow/providers/databricks/plugins/databricks_workflow.py::2 +providers/edge3/src/airflow/providers/edge3/executors/edge_executor.py::3 +providers/edge3/src/airflow/providers/edge3/models/edge_worker.py::10 +providers/edge3/src/airflow/providers/edge3/plugins/edge_executor_plugin.py::1 +providers/edge3/src/airflow/providers/edge3/worker_api/routes/logs.py::1 +providers/fab/src/airflow/providers/fab/auth_manager/cli_commands/permissions_command.py::1 +providers/google/src/airflow/providers/google/cloud/triggers/bigquery.py::1 +providers/google/src/airflow/providers/google/cloud/triggers/dataproc.py::3 +providers/openlineage/src/airflow/providers/openlineage/utils/utils.py::1 +providers/standard/src/airflow/providers/standard/operators/trigger_dagrun.py::1 +providers/standard/src/airflow/providers/standard/sensors/external_task.py::1 +providers/standard/src/airflow/providers/standard/utils/sensor_helper.py::1 +providers/standard/tests/unit/standard/sensors/test_external_task_sensor.py::3 diff --git a/scripts/tests/ci/prek/test_check_provide_session_kwargs.py b/scripts/tests/ci/prek/test_check_provide_session_kwargs.py new file mode 100644 index 0000000000000..0f7d298ba7a3a --- /dev/null +++ b/scripts/tests/ci/prek/test_check_provide_session_kwargs.py @@ -0,0 +1,320 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import ast +import textwrap +from pathlib import Path + +import pytest +from ci.prek.check_provide_session_kwargs import ( + AllowlistManager, + _check_provide_session_kwargs, + _count_violations, + _has_provide_session_decorator, + _iter_positional_session_in_provide_session, + _session_is_positional, +) + + +@pytest.fixture +def find_violations(write_python_file): + """Factory fixture: write code to a temp file and return positional-session violations.""" + + def _check(code: str) -> list[tuple[ast.FunctionDef, ast.arg]]: + path = write_python_file(code) + return list(_iter_positional_session_in_provide_session(path)) + + return _check + + +@pytest.fixture +def fake_repo(tmp_path, monkeypatch): + """Create a fake repo layout and patch REPO_ROOT so paths resolve correctly.""" + import ci.prek.check_provide_session_kwargs as hook + + monkeypatch.setattr(hook, "REPO_ROOT", tmp_path) + + def _write(rel: str, code: str) -> Path: + path = tmp_path / rel + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(textwrap.dedent(code)) + return path + + _write.root = tmp_path # type: ignore[attr-defined] + return _write + + +class TestHasProvideSessionDecorator: + def test_provide_session_name(self): + func = ast.parse("@provide_session\ndef foo(): pass").body[0] + assert _has_provide_session_decorator(func.decorator_list) is True + + def test_provide_session_attribute(self): + func = ast.parse("@utils.provide_session\ndef foo(): pass").body[0] + assert _has_provide_session_decorator(func.decorator_list) is True + + def test_no_decorator(self): + func = ast.parse("def foo(): pass").body[0] + assert _has_provide_session_decorator(func.decorator_list) is False + + def test_unrelated_decorator(self): + func = ast.parse("@staticmethod\ndef foo(): pass").body[0] + assert _has_provide_session_decorator(func.decorator_list) is False + + def test_multiple_decorators_including_provide_session(self): + func = ast.parse("@staticmethod\n@provide_session\ndef foo(): pass").body[0] + assert _has_provide_session_decorator(func.decorator_list) is True + + +class TestSessionIsPositional: + def test_no_session_arg(self): + func = ast.parse("def foo(x, y): pass").body[0] + assert _session_is_positional(func.args) is None + + def test_session_positional(self): + func = ast.parse("def foo(session=NEW_SESSION): pass").body[0] + argument = _session_is_positional(func.args) + assert argument is not None + assert argument.arg == "session" + + def test_session_keyword_only(self): + func = ast.parse("def foo(*, session=NEW_SESSION): pass").body[0] + assert _session_is_positional(func.args) is None + + def test_session_positional_among_other_args(self): + func = ast.parse("def foo(x, y, session=NEW_SESSION): pass").body[0] + argument = _session_is_positional(func.args) + assert argument is not None + assert argument.arg == "session" + + def test_session_kwonly_after_other_positional(self): + func = ast.parse("def foo(x, y, *, session=NEW_SESSION): pass").body[0] + assert _session_is_positional(func.args) is None + + +class TestIterPositionalSessionInProvideSession: + def test_keyword_only_session_is_clean(self, find_violations): + code = """\ + @provide_session + def foo(*, session=NEW_SESSION): + pass + """ + assert find_violations(code) == [] + + def test_positional_session_is_flagged(self, find_violations): + code = """\ + @provide_session + def foo(session=NEW_SESSION): + pass + """ + violations = find_violations(code) + assert len(violations) == 1 + func, argument = violations[0] + assert func.name == "foo" + assert argument.arg == "session" + + def test_no_provide_session_decorator_is_ignored(self, find_violations): + code = """\ + def foo(session=NEW_SESSION): + pass + """ + assert find_violations(code) == [] + + def test_async_function_with_positional_session_is_flagged(self, find_violations): + code = """\ + @provide_session + async def foo(session=NEW_SESSION): + pass + """ + violations = find_violations(code) + assert len(violations) == 1 + + def test_method_with_positional_session_is_flagged(self, find_violations): + code = """\ + class C: + @provide_session + def foo(self, session=NEW_SESSION): + pass + """ + violations = find_violations(code) + assert len(violations) == 1 + assert violations[0][0].name == "foo" + + def test_attribute_decorator_is_recognised(self, find_violations): + code = """\ + @airflow.utils.session.provide_session + def foo(session=NEW_SESSION): + pass + """ + violations = find_violations(code) + assert len(violations) == 1 + + def test_count_violations_multiple_in_file(self, write_python_file): + code = """\ + @provide_session + def a(session=NEW_SESSION): + pass + + @provide_session + def b(x, session=NEW_SESSION): + pass + + @provide_session + def c(*, session=NEW_SESSION): + pass + """ + path = write_python_file(code) + assert _count_violations(path) == 2 + + def test_syntax_error_returns_no_violations(self, write_python_file): + path = write_python_file("def foo(:\n pass") + assert _count_violations(path) == 0 + + +class TestAllowlistManager: + def test_load_missing_file_returns_empty(self, tmp_path): + manager = AllowlistManager(tmp_path / "missing.txt") + assert manager.load() == {} + + def test_save_and_load_round_trip(self, tmp_path): + manager = AllowlistManager(tmp_path / "allowlist.txt") + manager.save({"b/file.py": 2, "a/file.py": 1}) + # Sorted by key in the file + text = (tmp_path / "allowlist.txt").read_text() + assert text.splitlines() == ["a/file.py::1", "b/file.py::2"] + assert manager.load() == {"a/file.py": 1, "b/file.py": 2} + + def test_load_skips_blank_and_malformed_lines(self, tmp_path): + path = tmp_path / "allowlist.txt" + path.write_text("\nvalid/file.py::3\nnocount\n::5\nbad::notanumber\n") + assert AllowlistManager(path).load() == {"valid/file.py": 3} + + +class TestCheckProvideSessionKwargs: + def test_no_violations_in_clean_file(self, fake_repo, tmp_path): + path = fake_repo( + "airflow-core/src/airflow/clean.py", + """\ + @provide_session + def foo(*, session=NEW_SESSION): + pass + """, + ) + manager = AllowlistManager(tmp_path / "allowlist.txt") + assert _check_provide_session_kwargs([path], {}, manager) == 0 + + def test_new_violation_fails(self, fake_repo, tmp_path): + path = fake_repo( + "airflow-core/src/airflow/bad.py", + """\ + @provide_session + def foo(session=NEW_SESSION): + pass + """, + ) + manager = AllowlistManager(tmp_path / "allowlist.txt") + assert _check_provide_session_kwargs([path], {}, manager) == 1 + + def test_violation_within_allowlist_passes(self, fake_repo, tmp_path): + path = fake_repo( + "airflow-core/src/airflow/grandfathered.py", + """\ + @provide_session + def foo(session=NEW_SESSION): + pass + """, + ) + manager = AllowlistManager(tmp_path / "allowlist.txt") + allowlist = {"airflow-core/src/airflow/grandfathered.py": 1} + assert _check_provide_session_kwargs([path], allowlist, manager) == 0 + + def test_exceeding_allowlist_fails(self, fake_repo, tmp_path): + path = fake_repo( + "airflow-core/src/airflow/grew.py", + """\ + @provide_session + def a(session=NEW_SESSION): + pass + + @provide_session + def b(session=NEW_SESSION): + pass + """, + ) + manager = AllowlistManager(tmp_path / "allowlist.txt") + allowlist = {"airflow-core/src/airflow/grew.py": 1} + assert _check_provide_session_kwargs([path], allowlist, manager) == 1 + + def test_reducing_violations_tightens_allowlist(self, fake_repo, tmp_path): + path = fake_repo( + "airflow-core/src/airflow/improved.py", + """\ + @provide_session + def foo(session=NEW_SESSION): + pass + + @provide_session + def bar(*, session=NEW_SESSION): + pass + """, + ) + manager = AllowlistManager(tmp_path / "allowlist.txt") + allowlist = {"airflow-core/src/airflow/improved.py": 2} + # Exit non-zero so pre-commit reports the modified allowlist + assert _check_provide_session_kwargs([path], allowlist, manager) == 1 + assert manager.load() == {"airflow-core/src/airflow/improved.py": 1} + + def test_fixing_all_violations_removes_entry(self, fake_repo, tmp_path): + path = fake_repo( + "airflow-core/src/airflow/fixed.py", + """\ + @provide_session + def foo(*, session=NEW_SESSION): + pass + """, + ) + manager = AllowlistManager(tmp_path / "allowlist.txt") + allowlist = {"airflow-core/src/airflow/fixed.py": 1} + assert _check_provide_session_kwargs([path], allowlist, manager) == 1 + assert manager.load() == {} + + def test_non_python_file_is_skipped(self, fake_repo, tmp_path): + path = fake_repo( + "airflow-core/src/airflow/not_python.txt", "@provide_session\ndef foo(session=N): pass\n" + ) + manager = AllowlistManager(tmp_path / "allowlist.txt") + assert _check_provide_session_kwargs([path], {}, manager) == 0 + + +class TestCleanup: + def test_cleanup_removes_stale_entries(self, fake_repo, tmp_path): + fake_repo("airflow-core/src/airflow/keeper.py", "pass") + allowlist_path = tmp_path / "allowlist.txt" + manager = AllowlistManager(allowlist_path) + manager.save( + { + "airflow-core/src/airflow/keeper.py": 1, + "airflow-core/src/airflow/gone.py": 1, + } + ) + assert manager.cleanup() == 0 + assert manager.load() == {"airflow-core/src/airflow/keeper.py": 1} + + def test_cleanup_empty_allowlist(self, tmp_path): + manager = AllowlistManager(tmp_path / "allowlist.txt") + assert manager.cleanup() == 0 From b960a2f3529b1c1c0b03967e480370096121c7e5 Mon Sep 17 00:00:00 2001 From: LIU ZHE YOU Date: Tue, 19 May 2026 13:03:08 +0800 Subject: [PATCH 2/9] Address Copilot review feedback on provide_session prek hook - Detect positional-only `session` parameters (`def f(session, /, ...)`) by also scanning `args.posonlyargs`. - Close the allowlist-edit bypass: when `known_provide_session_positional.txt` is among the changed files, the hook now re-validates every listed file so contributors cannot raise counts without triggering the check. - Clarify that `--all-files` / `--generate` scan the project source roots (airflow-core, airflow-ctl, task-sdk, providers, shared), not the entire repository, and centralise that list in `_PROJECT_SOURCE_ROOTS`. - Update the test annotation for `find_violations` to include `ast.AsyncFunctionDef` so async functions are typed correctly. - Add tests covering positional-only detection and the allowlist-edit re-validation path. --- .../ci/prek/check_provide_session_kwargs.py | 54 ++++++++++++++--- .../prek/test_check_provide_session_kwargs.py | 59 ++++++++++++++++++- 2 files changed, 104 insertions(+), 9 deletions(-) diff --git a/scripts/ci/prek/check_provide_session_kwargs.py b/scripts/ci/prek/check_provide_session_kwargs.py index ae1bcfc6dec50..79b5f9ef84f49 100755 --- a/scripts/ci/prek/check_provide_session_kwargs.py +++ b/scripts/ci/prek/check_provide_session_kwargs.py @@ -45,15 +45,18 @@ ``scripts/ci/prek/known_provide_session_positional.txt`` and re-run. ``--all-files``: - Walk the whole repository and check every ``.py`` file. + Walk every ``.py`` file under the project source roots + (``airflow-core``, ``airflow-ctl``, ``task-sdk``, ``providers``, ``shared``) — + the same scope the pre-commit hook applies to. ``--cleanup``: Remove entries for files that no longer exist. Safe to run at any time; does not add new entries or raise limits. ``--generate``: - Scan the whole repository and *rebuild* the allowlist from scratch. - Intended for the initial setup or after a large-scale clean-up sprint. + Scan the same project source roots as ``--all-files`` and *rebuild* the + allowlist from scratch. Intended for the initial setup or after a + large-scale clean-up sprint. """ from __future__ import annotations @@ -72,6 +75,10 @@ _PROVIDE_SESSION_DECORATOR = "provide_session" +# Top-level directories scanned by ``--all-files`` / ``--generate``. Keep in sync with the +# ``files:`` pattern for this hook in ``.pre-commit-config.yaml``. +_PROJECT_SOURCE_ROOTS = ("airflow-core", "airflow-ctl", "task-sdk", "providers", "shared") + def _has_provide_session_decorator(nodes: list[ast.expr]) -> bool: """Whether one of ``nodes`` is a ``@provide_session`` decorator. @@ -88,8 +95,11 @@ def _has_provide_session_decorator(nodes: list[ast.expr]) -> bool: def _session_is_positional(args: ast.arguments) -> ast.arg | None: - """Return the ``session`` arg if it is positional (not keyword-only).""" - for argument in args.args: + """Return the ``session`` arg if it is positional (not keyword-only). + + Covers both regular positional args and positional-only args (``def f(session, /, ...)``). + """ + for argument in (*args.posonlyargs, *args.args): if argument.arg == "session": return argument return None @@ -194,7 +204,7 @@ def cleanup(self) -> int: def _iter_python_files() -> list[Path]: candidates: list[Path] = [] - for top in ("airflow-core", "airflow-ctl", "task-sdk", "providers", "shared"): + for top in _PROJECT_SOURCE_ROOTS: candidates.extend( p.resolve() for p in (REPO_ROOT / top).rglob("*.py") @@ -268,7 +278,10 @@ def main(argv: list[str] | None = None) -> int: parser.add_argument( "--all-files", action="store_true", - help="Check every Python file in the repository", + help=( + "Check every Python file under the project source roots " + "(airflow-core, airflow-ctl, task-sdk, providers, shared)" + ), ) parser.add_argument( "--cleanup", @@ -301,7 +314,32 @@ def main(argv: list[str] | None = None) -> int: ) return 0 - return _check_provide_session_kwargs([Path(f).resolve() for f in args.files], allowlist, manager) + paths = [Path(f).resolve() for f in args.files] + paths = _expand_for_allowlist_edits(paths, manager, allowlist) + return _check_provide_session_kwargs(paths, allowlist, manager) + + +def _expand_for_allowlist_edits( + paths: list[Path], manager: AllowlistManager, allowlist: dict[str, int] +) -> list[Path]: + """Add allowlisted files when the allowlist itself is being changed. + + Without this, a contributor could raise counts in + ``known_provide_session_positional.txt`` and the hook would do no validation + (since only the ``.txt`` file is passed), letting the loosened allowlist + sail through. + """ + if not any(p == manager.allowlist_file for p in paths): + return paths + + expanded = list(paths) + seen = {p for p in paths if p.suffix == ".py"} + for rel in allowlist: + candidate = (REPO_ROOT / rel).resolve() + if candidate.exists() and candidate not in seen: + seen.add(candidate) + expanded.append(candidate) + return expanded if __name__ == "__main__": diff --git a/scripts/tests/ci/prek/test_check_provide_session_kwargs.py b/scripts/tests/ci/prek/test_check_provide_session_kwargs.py index 0f7d298ba7a3a..aa4412da8c2b9 100644 --- a/scripts/tests/ci/prek/test_check_provide_session_kwargs.py +++ b/scripts/tests/ci/prek/test_check_provide_session_kwargs.py @@ -25,6 +25,7 @@ AllowlistManager, _check_provide_session_kwargs, _count_violations, + _expand_for_allowlist_edits, _has_provide_session_decorator, _iter_positional_session_in_provide_session, _session_is_positional, @@ -35,7 +36,7 @@ def find_violations(write_python_file): """Factory fixture: write code to a temp file and return positional-session violations.""" - def _check(code: str) -> list[tuple[ast.FunctionDef, ast.arg]]: + def _check(code: str) -> list[tuple[ast.FunctionDef | ast.AsyncFunctionDef, ast.arg]]: path = write_python_file(code) return list(_iter_positional_session_in_provide_session(path)) @@ -106,6 +107,12 @@ def test_session_kwonly_after_other_positional(self): func = ast.parse("def foo(x, y, *, session=NEW_SESSION): pass").body[0] assert _session_is_positional(func.args) is None + def test_session_positional_only(self): + func = ast.parse("def foo(session, /, x): pass").body[0] + argument = _session_is_positional(func.args) + assert argument is not None + assert argument.arg == "session" + class TestIterPositionalSessionInProvideSession: def test_keyword_only_session_is_clean(self, find_violations): @@ -301,6 +308,56 @@ def test_non_python_file_is_skipped(self, fake_repo, tmp_path): assert _check_provide_session_kwargs([path], {}, manager) == 0 +class TestExpandForAllowlistEdits: + def test_unchanged_when_allowlist_not_in_paths(self, fake_repo, tmp_path): + py = fake_repo("airflow-core/src/airflow/x.py", "pass") + manager = AllowlistManager(tmp_path / "allowlist.txt") + assert _expand_for_allowlist_edits([py], manager, {"airflow-core/src/airflow/x.py": 1}) == [py] + + def test_appends_allowlisted_files_when_allowlist_edited(self, fake_repo, tmp_path): + allowlist_path = tmp_path / "allowlist.txt" + manager = AllowlistManager(allowlist_path) + listed = fake_repo("airflow-core/src/airflow/listed.py", "pass") + # File in allowlist that does not exist on disk should be ignored. + result = _expand_for_allowlist_edits( + [allowlist_path], + manager, + {"airflow-core/src/airflow/listed.py": 1, "airflow-core/src/airflow/gone.py": 1}, + ) + assert allowlist_path in result + assert listed in result + assert (tmp_path / "airflow-core/src/airflow/gone.py").resolve() not in result + + def test_re_validates_listed_files_so_loosening_cannot_bypass(self, fake_repo, tmp_path, capsys): + """Editing only the allowlist must still trigger validation of listed files.""" + rel = "airflow-core/src/airflow/loosened.py" + fake_repo( + rel, + """\ + @provide_session + def foo(session=NEW_SESSION): + pass + + @provide_session + def bar(session=NEW_SESSION): + pass + """, + ) + allowlist_path = tmp_path / "allowlist.txt" + manager = AllowlistManager(allowlist_path) + # Allowlist loosened to 5 although file only has 2 positional sessions. + allowlist = {rel: 5} + manager.save(allowlist) + + # Only the allowlist file is "changed"; without re-validation this would return 0. + paths = _expand_for_allowlist_edits([allowlist_path], manager, allowlist) + rc = _check_provide_session_kwargs(paths, allowlist, manager) + + # Tightened from 5 -> 2, so the hook exits non-zero to surface the modified allowlist. + assert rc == 1 + assert manager.load() == {rel: 2} + + class TestCleanup: def test_cleanup_removes_stale_entries(self, fake_repo, tmp_path): fake_repo("airflow-core/src/airflow/keeper.py", "pass") From 6ee56796cbe1992490638b128fdee7f575a347f0 Mon Sep 17 00:00:00 2001 From: LIU ZHE YOU Date: Tue, 19 May 2026 13:17:42 +0800 Subject: [PATCH 3/9] Normalize paths when detecting allowlist-edit bypass Resolve both sides of the allowlist-file comparison in `_expand_for_allowlist_edits` so the detection is robust to symlinks and to callers passing unresolved paths. Without this, a symlinked or non-normalized invocation could skip the re-validation path and let a loosened allowlist sail through. Update tests to pass resolved paths (matching what `main()` does in production) and add a dedicated test that points the helper at a symlink to the allowlist file. --- .../ci/prek/check_provide_session_kwargs.py | 8 ++++-- .../prek/test_check_provide_session_kwargs.py | 26 ++++++++++++++++--- 2 files changed, 28 insertions(+), 6 deletions(-) diff --git a/scripts/ci/prek/check_provide_session_kwargs.py b/scripts/ci/prek/check_provide_session_kwargs.py index 79b5f9ef84f49..0fd3aa47218c3 100755 --- a/scripts/ci/prek/check_provide_session_kwargs.py +++ b/scripts/ci/prek/check_provide_session_kwargs.py @@ -328,12 +328,16 @@ def _expand_for_allowlist_edits( ``known_provide_session_positional.txt`` and the hook would do no validation (since only the ``.txt`` file is passed), letting the loosened allowlist sail through. + + Both sides of the allowlist-file comparison are resolved so the detection is + robust to symlinks and unresolved inputs (the hook can be invoked with either). """ - if not any(p == manager.allowlist_file for p in paths): + allowlist_file = manager.allowlist_file.resolve() + if not any(p.resolve() == allowlist_file for p in paths): return paths expanded = list(paths) - seen = {p for p in paths if p.suffix == ".py"} + seen = {p.resolve() for p in paths if p.suffix == ".py"} for rel in allowlist: candidate = (REPO_ROOT / rel).resolve() if candidate.exists() and candidate not in seen: diff --git a/scripts/tests/ci/prek/test_check_provide_session_kwargs.py b/scripts/tests/ci/prek/test_check_provide_session_kwargs.py index aa4412da8c2b9..0f05510d7411a 100644 --- a/scripts/tests/ci/prek/test_check_provide_session_kwargs.py +++ b/scripts/tests/ci/prek/test_check_provide_session_kwargs.py @@ -318,16 +318,33 @@ def test_appends_allowlisted_files_when_allowlist_edited(self, fake_repo, tmp_pa allowlist_path = tmp_path / "allowlist.txt" manager = AllowlistManager(allowlist_path) listed = fake_repo("airflow-core/src/airflow/listed.py", "pass") - # File in allowlist that does not exist on disk should be ignored. + # Pass a resolved path — matches production behavior (``main()`` resolves argv). result = _expand_for_allowlist_edits( - [allowlist_path], + [allowlist_path.resolve()], manager, {"airflow-core/src/airflow/listed.py": 1, "airflow-core/src/airflow/gone.py": 1}, ) - assert allowlist_path in result + assert allowlist_path.resolve() in result assert listed in result + # File in allowlist that does not exist on disk should be ignored. assert (tmp_path / "airflow-core/src/airflow/gone.py").resolve() not in result + def test_detection_robust_to_symlinked_allowlist(self, fake_repo, tmp_path): + """A symlink pointing at the allowlist file must still trigger expansion.""" + allowlist_path = tmp_path / "allowlist.txt" + manager = AllowlistManager(allowlist_path) + listed = fake_repo("airflow-core/src/airflow/listed.py", "pass") + manager.save({"airflow-core/src/airflow/listed.py": 1}) + + symlink = tmp_path / "allowlist_link.txt" + symlink.symlink_to(allowlist_path) + + # Production resolves argv before calling the helper — a symlinked path resolves + # to the real allowlist file and must be recognised as an allowlist edit. + result = _expand_for_allowlist_edits([symlink.resolve()], manager, manager.load()) + + assert listed in result + def test_re_validates_listed_files_so_loosening_cannot_bypass(self, fake_repo, tmp_path, capsys): """Editing only the allowlist must still trigger validation of listed files.""" rel = "airflow-core/src/airflow/loosened.py" @@ -350,7 +367,8 @@ def bar(session=NEW_SESSION): manager.save(allowlist) # Only the allowlist file is "changed"; without re-validation this would return 0. - paths = _expand_for_allowlist_edits([allowlist_path], manager, allowlist) + # Resolve the path to mirror what ``main()`` does in production. + paths = _expand_for_allowlist_edits([allowlist_path.resolve()], manager, allowlist) rc = _check_provide_session_kwargs(paths, allowlist, manager) # Tightened from 5 -> 2, so the hook exits non-zero to surface the modified allowlist. From 05f4b75e7cc2844ca93d785f085f3af1ad5254fd Mon Sep 17 00:00:00 2001 From: LIU ZHE YOU Date: Tue, 19 May 2026 13:23:55 +0800 Subject: [PATCH 4/9] Clarify --generate log line to name the scanned source roots The previous message said "Scanning " but the helper only walks the project source roots (airflow-core, airflow-ctl, task-sdk, providers, shared). Spell out those roots in the log so the output matches what is actually scanned. --- scripts/ci/prek/check_provide_session_kwargs.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/scripts/ci/prek/check_provide_session_kwargs.py b/scripts/ci/prek/check_provide_session_kwargs.py index 0fd3aa47218c3..f866c698e5d4d 100755 --- a/scripts/ci/prek/check_provide_session_kwargs.py +++ b/scripts/ci/prek/check_provide_session_kwargs.py @@ -161,8 +161,10 @@ def save(self, counts: dict[str, int]) -> None: self.allowlist_file.write_text("\n".join(lines) + "\n") def generate(self) -> int: + roots = ", ".join(_PROJECT_SOURCE_ROOTS) console.print( - f"Scanning [cyan]{REPO_ROOT}[/cyan] for @provide_session functions with positional session …" + f"Scanning project source roots ([cyan]{roots}[/cyan]) under [cyan]{REPO_ROOT}[/cyan] " + "for @provide_session functions with positional session …" ) counts: dict[str, int] = {} for path in _iter_python_files(): From ac9fdfef1eb047eace5b485b5aeec079094d371c Mon Sep 17 00:00:00 2001 From: LIU ZHE YOU Date: Tue, 19 May 2026 13:32:09 +0800 Subject: [PATCH 5/9] Harden allowlist handling against deletion and unsafe entries - Reject allowlist entries that escape `REPO_ROOT` (absolute paths or `..` segments) at load time. Without this guard, `relative_to()` in the violation reporter could crash on a maliciously or accidentally malformed allowlist line. Skipped entries are logged. - Detect when the allowlist file itself is among the changed paths but no longer exists on disk (deletion or rename in the same commit) and fail with a clear "restore or regenerate" message instead of silently passing. - Move the `check_provide_session_kwargs` module alias import to the top of the test module rather than inside the `fake_repo` fixture. Tests cover the missing-file failure path and the unsafe-entry skip. --- .../ci/prek/check_provide_session_kwargs.py | 40 ++++++++++++++++++- .../prek/test_check_provide_session_kwargs.py | 17 +++++++- 2 files changed, 54 insertions(+), 3 deletions(-) diff --git a/scripts/ci/prek/check_provide_session_kwargs.py b/scripts/ci/prek/check_provide_session_kwargs.py index f866c698e5d4d..6000fac2b4a9b 100755 --- a/scripts/ci/prek/check_provide_session_kwargs.py +++ b/scripts/ci/prek/check_provide_session_kwargs.py @@ -132,6 +132,22 @@ def _count_violations(path: Path) -> int: return sum(1 for _ in _iter_positional_session_in_provide_session(path)) +def _is_safe_relative(rel: str) -> bool: + """Whether ``rel`` is a plain relative path that stays inside ``REPO_ROOT``. + + Rejects absolute paths and any entry that resolves outside the repo root so + callers can ``relative_to(REPO_ROOT)`` without fear of a ``ValueError``. + """ + candidate = Path(rel) + if candidate.is_absolute(): + return False + try: + (REPO_ROOT / candidate).resolve().relative_to(REPO_ROOT.resolve()) + except ValueError: + return False + return True + + class AllowlistManager: def __init__(self, allowlist_file: Path) -> None: self.allowlist_file = allowlist_file @@ -150,10 +166,18 @@ def load(self) -> dict[str, int]: continue try: - result[rel_str] = int(count_str) + count = int(count_str) except ValueError: continue + if not _is_safe_relative(rel_str): + console.print( + f"[yellow]Ignoring unsafe allowlist entry (escapes repo root):[/yellow] {rel_str}" + ) + continue + + result[rel_str] = count + return result def save(self, counts: dict[str, int]) -> None: @@ -218,6 +242,20 @@ def _iter_python_files() -> list[Path]: def _check_provide_session_kwargs( files: list[Path], allowlist: dict[str, int], manager: AllowlistManager ) -> int: + allowlist_file = manager.allowlist_file.resolve() + if any(p.resolve() == allowlist_file for p in files) and not allowlist_file.exists(): + console.print( + Panel.fit( + f"Allowlist file [cyan]{allowlist_file}[/cyan] is missing.\n" + "It was passed to the hook but cannot be read, so the check cannot proceed.\n" + "Restore it from git or regenerate it with:\n\n" + " [cyan]uv run ./scripts/ci/prek/check_provide_session_kwargs.py --generate[/cyan]", + title="[red]Check failed[/red]", + border_style="red", + ) + ) + return 1 + violations: list[tuple[Path, int, int]] = [] tightened: list[tuple[str, int, int]] = [] diff --git a/scripts/tests/ci/prek/test_check_provide_session_kwargs.py b/scripts/tests/ci/prek/test_check_provide_session_kwargs.py index 0f05510d7411a..ea46555d5f21f 100644 --- a/scripts/tests/ci/prek/test_check_provide_session_kwargs.py +++ b/scripts/tests/ci/prek/test_check_provide_session_kwargs.py @@ -21,6 +21,7 @@ from pathlib import Path import pytest +from ci.prek import check_provide_session_kwargs as hook from ci.prek.check_provide_session_kwargs import ( AllowlistManager, _check_provide_session_kwargs, @@ -46,8 +47,6 @@ def _check(code: str) -> list[tuple[ast.FunctionDef | ast.AsyncFunctionDef, ast. @pytest.fixture def fake_repo(tmp_path, monkeypatch): """Create a fake repo layout and patch REPO_ROOT so paths resolve correctly.""" - import ci.prek.check_provide_session_kwargs as hook - monkeypatch.setattr(hook, "REPO_ROOT", tmp_path) def _write(rel: str, code: str) -> Path: @@ -211,6 +210,13 @@ def test_load_skips_blank_and_malformed_lines(self, tmp_path): path.write_text("\nvalid/file.py::3\nnocount\n::5\nbad::notanumber\n") assert AllowlistManager(path).load() == {"valid/file.py": 3} + def test_load_skips_unsafe_entries(self, fake_repo, tmp_path): + """Entries that escape REPO_ROOT (absolute paths or `..` segments) are ignored.""" + path = tmp_path / "allowlist.txt" + path.write_text("airflow-core/src/airflow/safe.py::1\n../escape.py::1\n/etc/passwd::1\n") + # `fake_repo` patches REPO_ROOT to tmp_path so the safety check is meaningful. + assert AllowlistManager(path).load() == {"airflow-core/src/airflow/safe.py": 1} + class TestCheckProvideSessionKwargs: def test_no_violations_in_clean_file(self, fake_repo, tmp_path): @@ -307,6 +313,13 @@ def test_non_python_file_is_skipped(self, fake_repo, tmp_path): manager = AllowlistManager(tmp_path / "allowlist.txt") assert _check_provide_session_kwargs([path], {}, manager) == 0 + def test_missing_allowlist_file_fails_loudly(self, fake_repo, tmp_path): + """Passing the allowlist path when the file is missing must fail, not silently pass.""" + allowlist_path = tmp_path / "allowlist.txt" + manager = AllowlistManager(allowlist_path) + assert not allowlist_path.exists() + assert _check_provide_session_kwargs([allowlist_path.resolve()], {}, manager) == 1 + class TestExpandForAllowlistEdits: def test_unchanged_when_allowlist_not_in_paths(self, fake_repo, tmp_path): From 4070be00bd67106e01efa6a99b130827c6a6d42c Mon Sep 17 00:00:00 2001 From: LIU ZHE YOU Date: Tue, 19 May 2026 14:08:39 +0800 Subject: [PATCH 6/9] Handle invalid UTF-8 in provide_session hook --- scripts/ci/prek/check_provide_session_kwargs.py | 2 +- scripts/tests/ci/prek/test_check_provide_session_kwargs.py | 6 ++++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/scripts/ci/prek/check_provide_session_kwargs.py b/scripts/ci/prek/check_provide_session_kwargs.py index 6000fac2b4a9b..8bfb752c931c4 100755 --- a/scripts/ci/prek/check_provide_session_kwargs.py +++ b/scripts/ci/prek/check_provide_session_kwargs.py @@ -110,7 +110,7 @@ def _iter_positional_session_in_provide_session( ) -> typing.Iterator[tuple[ast.FunctionDef | ast.AsyncFunctionDef, ast.arg]]: """Yield ``@provide_session`` functions in *path* whose ``session`` is positional.""" try: - source = path.read_text(encoding="utf-8") + source = path.read_text(encoding="utf-8", errors="replace") except OSError: return try: diff --git a/scripts/tests/ci/prek/test_check_provide_session_kwargs.py b/scripts/tests/ci/prek/test_check_provide_session_kwargs.py index ea46555d5f21f..73eb482c13cbb 100644 --- a/scripts/tests/ci/prek/test_check_provide_session_kwargs.py +++ b/scripts/tests/ci/prek/test_check_provide_session_kwargs.py @@ -191,6 +191,12 @@ def test_syntax_error_returns_no_violations(self, write_python_file): path = write_python_file("def foo(:\n pass") assert _count_violations(path) == 0 + def test_invalid_utf8_does_not_crash(self, tmp_path): + path = tmp_path / "invalid_utf8.py" + path.write_bytes(b"# bad byte: \xff\n@provide_session\ndef foo(session=NEW_SESSION):\n pass\n") + + assert _count_violations(path) == 1 + class TestAllowlistManager: def test_load_missing_file_returns_empty(self, tmp_path): From 7c6e4521d41f4cf42107e3a7996043fd556aa1e2 Mon Sep 17 00:00:00 2001 From: LIU ZHE YOU Date: Tue, 19 May 2026 14:23:40 +0800 Subject: [PATCH 7/9] Remove unused provide_session test fixture attribute --- scripts/tests/ci/prek/test_check_provide_session_kwargs.py | 1 - 1 file changed, 1 deletion(-) diff --git a/scripts/tests/ci/prek/test_check_provide_session_kwargs.py b/scripts/tests/ci/prek/test_check_provide_session_kwargs.py index 73eb482c13cbb..6e42dc57d4932 100644 --- a/scripts/tests/ci/prek/test_check_provide_session_kwargs.py +++ b/scripts/tests/ci/prek/test_check_provide_session_kwargs.py @@ -55,7 +55,6 @@ def _write(rel: str, code: str) -> Path: path.write_text(textwrap.dedent(code)) return path - _write.root = tmp_path # type: ignore[attr-defined] return _write From 3190db6a533f5f7881741b8c357c09c0e29b5f5a Mon Sep 17 00:00:00 2001 From: LIU ZHE YOU Date: Tue, 19 May 2026 14:40:45 +0800 Subject: [PATCH 8/9] Close allowlist-removal bypass in provide_session prek hook When only the allowlist file is changed, the hook now unions the previous allowlist (from `git show HEAD:`) with the current one, so removing an entry can't silently drop coverage for a file that still has positional `session` arguments. --- .../ci/prek/check_provide_session_kwargs.py | 52 +++++++++++++-- .../prek/test_check_provide_session_kwargs.py | 64 +++++++++++++++++++ 2 files changed, 110 insertions(+), 6 deletions(-) diff --git a/scripts/ci/prek/check_provide_session_kwargs.py b/scripts/ci/prek/check_provide_session_kwargs.py index 8bfb752c931c4..831094c6fdc07 100755 --- a/scripts/ci/prek/check_provide_session_kwargs.py +++ b/scripts/ci/prek/check_provide_session_kwargs.py @@ -63,6 +63,7 @@ import argparse import ast +import subprocess import typing from pathlib import Path @@ -152,12 +153,16 @@ class AllowlistManager: def __init__(self, allowlist_file: Path) -> None: self.allowlist_file = allowlist_file - def load(self) -> dict[str, int]: - if not self.allowlist_file.exists(): - return {} + @staticmethod + def parse(text: str) -> dict[str, int]: + """Parse allowlist *text* into a ``{rel_path: count}`` mapping. + Same validation rules as :meth:`load` so we can reuse parsing for the + on-disk allowlist *and* for the previous version fetched from git when + guarding against entry-removal bypasses. + """ result: dict[str, int] = {} - for raw_line in self.allowlist_file.read_text().splitlines(): + for raw_line in text.splitlines(): if not (stripped := raw_line.strip()): continue @@ -180,6 +185,11 @@ def load(self) -> dict[str, int]: return result + def load(self) -> dict[str, int]: + if not self.allowlist_file.exists(): + return {} + return self.parse(self.allowlist_file.read_text()) + def save(self, counts: dict[str, int]) -> None: lines = [f"{rel}::{count}" for rel, count in sorted(counts.items())] self.allowlist_file.write_text("\n".join(lines) + "\n") @@ -359,6 +369,33 @@ def main(argv: list[str] | None = None) -> int: return _check_provide_session_kwargs(paths, allowlist, manager) +def _previous_allowlist(manager: AllowlistManager) -> dict[str, int]: + """Return the allowlist as recorded at ``HEAD``, or an empty dict. + + Used by :func:`_expand_for_allowlist_edits` so that *removing* an entry + cannot silently drop coverage: the previously-listed file is still + re-validated against the new (post-edit) allowlist. Returns an empty mapping + when git is unavailable, the file does not yet exist at ``HEAD``, or the + allowlist sits outside ``REPO_ROOT``. + """ + try: + rel = manager.allowlist_file.resolve().relative_to(REPO_ROOT.resolve()) + except ValueError: + return {} + try: + completed = subprocess.run( + ["git", "-C", str(REPO_ROOT), "show", f"HEAD:{rel.as_posix()}"], + capture_output=True, + text=True, + check=False, + ) + except (FileNotFoundError, OSError): + return {} + if completed.returncode != 0: + return {} + return AllowlistManager.parse(completed.stdout) + + def _expand_for_allowlist_edits( paths: list[Path], manager: AllowlistManager, allowlist: dict[str, int] ) -> list[Path]: @@ -367,7 +404,9 @@ def _expand_for_allowlist_edits( Without this, a contributor could raise counts in ``known_provide_session_positional.txt`` and the hook would do no validation (since only the ``.txt`` file is passed), letting the loosened allowlist - sail through. + sail through. We also union the *previous* allowlist (from ``HEAD``) so that + removing an entry cannot silently bypass the check for a file that still + has positional ``session`` arguments. Both sides of the allowlist-file comparison are resolved so the detection is robust to symlinks and unresolved inputs (the hook can be invoked with either). @@ -378,7 +417,8 @@ def _expand_for_allowlist_edits( expanded = list(paths) seen = {p.resolve() for p in paths if p.suffix == ".py"} - for rel in allowlist: + previous = _previous_allowlist(manager) + for rel in {*allowlist, *previous}: candidate = (REPO_ROOT / rel).resolve() if candidate.exists() and candidate not in seen: seen.add(candidate) diff --git a/scripts/tests/ci/prek/test_check_provide_session_kwargs.py b/scripts/tests/ci/prek/test_check_provide_session_kwargs.py index 6e42dc57d4932..30a4d4a5bd71b 100644 --- a/scripts/tests/ci/prek/test_check_provide_session_kwargs.py +++ b/scripts/tests/ci/prek/test_check_provide_session_kwargs.py @@ -17,6 +17,8 @@ from __future__ import annotations import ast +import os +import subprocess import textwrap from pathlib import Path @@ -29,6 +31,7 @@ _expand_for_allowlist_edits, _has_provide_session_decorator, _iter_positional_session_in_provide_session, + _previous_allowlist, _session_is_positional, ) @@ -58,6 +61,34 @@ def _write(rel: str, code: str) -> Path: return _write +@pytest.fixture +def git_repo(fake_repo, tmp_path): + """Initialise ``tmp_path`` as a git repo so ``git show HEAD:`` works. + + Returns a helper that commits the current working-tree contents under a given + message, so tests can stage a "previous" allowlist at HEAD before mutating it. + """ + env = { + **os.environ, + "GIT_AUTHOR_NAME": "t", + "GIT_AUTHOR_EMAIL": "t@t", + "GIT_COMMITTER_NAME": "t", + "GIT_COMMITTER_EMAIL": "t@t", + } + + def _run(*args: str) -> None: + subprocess.run(["git", "-C", str(tmp_path), *args], check=True, env=env, capture_output=True) + + _run("init", "-q", "-b", "main") + _run("config", "commit.gpgsign", "false") + + def _commit(message: str) -> None: + _run("add", "-A") + _run("commit", "-q", "--allow-empty", "-m", message) + + return _commit + + class TestHasProvideSessionDecorator: def test_provide_session_name(self): func = ast.parse("@provide_session\ndef foo(): pass").body[0] @@ -363,6 +394,39 @@ def test_detection_robust_to_symlinked_allowlist(self, fake_repo, tmp_path): assert listed in result + def test_includes_previous_allowlist_entries_when_removed(self, fake_repo, git_repo, tmp_path): + """Removing an entry from the allowlist must still re-check the previously-listed file.""" + rel = "airflow-core/src/airflow/dropped.py" + fake_repo( + rel, + """\ + @provide_session + def foo(session=NEW_SESSION): + pass + """, + ) + allowlist_path = tmp_path / "allowlist.txt" + manager = AllowlistManager(allowlist_path) + manager.save({rel: 1}) + git_repo("seed allowlist at HEAD") + + # Working tree: remove the entry, but the offending file still exists. + allowlist_path.write_text("") + current = manager.load() + assert current == {} + + expanded = _expand_for_allowlist_edits([allowlist_path.resolve()], manager, current) + # The previously-listed file must be re-validated. + assert (tmp_path / rel).resolve() in expanded + + # And the full check should fail because the file still has positional sessions. + assert _check_provide_session_kwargs(expanded, current, manager) == 1 + + def test_previous_allowlist_empty_when_no_git_history(self, fake_repo, tmp_path): + """Without a git repo the previous-allowlist lookup returns empty and does not crash.""" + manager = AllowlistManager(tmp_path / "allowlist.txt") + assert _previous_allowlist(manager) == {} + def test_re_validates_listed_files_so_loosening_cannot_bypass(self, fake_repo, tmp_path, capsys): """Editing only the allowlist must still trigger validation of listed files.""" rel = "airflow-core/src/airflow/loosened.py" From 19f2e2a1afb50a881a941f6349d6c937b3f0a9ba Mon Sep 17 00:00:00 2001 From: LIU ZHE YOU Date: Fri, 22 May 2026 16:38:46 +0800 Subject: [PATCH 9/9] CI: Fix scripts/ci/prek/known_provide_session_positional.txt --- scripts/ci/prek/known_provide_session_positional.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/ci/prek/known_provide_session_positional.txt b/scripts/ci/prek/known_provide_session_positional.txt index 76f515eeca7b7..d0c84e2f6b48f 100644 --- a/scripts/ci/prek/known_provide_session_positional.txt +++ b/scripts/ci/prek/known_provide_session_positional.txt @@ -19,7 +19,7 @@ airflow-core/src/airflow/models/dagrun.py::15 airflow-core/src/airflow/models/dagwarning.py::1 airflow-core/src/airflow/models/deadline.py::1 airflow-core/src/airflow/models/deadline_alert.py::1 -airflow-core/src/airflow/models/pool.py::13 +airflow-core/src/airflow/models/pool.py::11 airflow-core/src/airflow/models/renderedtifields.py::4 airflow-core/src/airflow/models/revoked_token.py::2 airflow-core/src/airflow/models/serialized_dag.py::6