diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d019a32715075..0b659de77b80c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1064,6 +1064,12 @@ repos: language: python pass_filenames: true files: ^(airflow-core|airflow-ctl|task-sdk|providers|shared)/.*\.py$ + - id: check-no-new-provide-session-positional + name: Check that no new @provide_session functions declare `session` positionally + entry: ./scripts/ci/prek/check_provide_session_kwargs.py + language: python + pass_filenames: true + files: ^(airflow-core|airflow-ctl|task-sdk|providers|shared)/.*\.py$|^scripts/ci/prek/known_provide_session_positional\.txt$|^scripts/ci/prek/check_provide_session_kwargs\.py$ - id: check-no-new-airflow-core-utils-modules name: Check that no new modules are added under airflow-core/src/airflow/utils entry: ./scripts/ci/prek/check_no_new_airflow_core_utils_modules.py diff --git a/scripts/ci/prek/check_provide_session_kwargs.py b/scripts/ci/prek/check_provide_session_kwargs.py new file mode 100755 index 0000000000000..831094c6fdc07 --- /dev/null +++ b/scripts/ci/prek/check_provide_session_kwargs.py @@ -0,0 +1,430 @@ +#!/usr/bin/env python +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# /// script +# requires-python = ">=3.10" +# dependencies = [ +# "rich>=13.0.0", +# ] +# /// +"""Check that no new ``@provide_session`` functions declare ``session`` positionally. + +The project convention is that any function decorated with ``@provide_session`` +must declare ``session`` as keyword-only (after a bare ``*`` in the signature), +so callers cannot pass it positionally by accident. See +``contributing-docs/05_pull_requests.rst#database-session-handling``. + +All *existing* offenders are recorded in ``known_provide_session_positional.txt`` +next to this script as ``relative/path::N`` entries (one per file), where ``N`` +is the maximum number of ``@provide_session`` functions with a positional +``session`` argument allowed in that file. A file whose current count exceeds +the recorded limit is treated as a violation – move the ``session`` argument +behind a bare ``*`` instead. + +Modes +----- +Default (files passed by prek/pre-commit): + Check only the supplied files; fail if any file's count exceeds the limit. + When a file's count has *decreased*, the allowlist entry is tightened + automatically and the hook exits with a non-zero code so that pre-commit + reports the modified allowlist – just stage + ``scripts/ci/prek/known_provide_session_positional.txt`` and re-run. + +``--all-files``: + Walk every ``.py`` file under the project source roots + (``airflow-core``, ``airflow-ctl``, ``task-sdk``, ``providers``, ``shared``) — + the same scope the pre-commit hook applies to. + +``--cleanup``: + Remove entries for files that no longer exist. Safe to run at any time; + does not add new entries or raise limits. + +``--generate``: + Scan the same project source roots as ``--all-files`` and *rebuild* the + allowlist from scratch. Intended for the initial setup or after a + large-scale clean-up sprint. +""" + +from __future__ import annotations + +import argparse +import ast +import subprocess +import typing +from pathlib import Path + +from rich.console import Console +from rich.panel import Panel + +console = Console(color_system="standard", width=200) + +REPO_ROOT = Path(__file__).parents[3] + +_PROVIDE_SESSION_DECORATOR = "provide_session" + +# Top-level directories scanned by ``--all-files`` / ``--generate``. Keep in sync with the +# ``files:`` pattern for this hook in ``.pre-commit-config.yaml``. +_PROJECT_SOURCE_ROOTS = ("airflow-core", "airflow-ctl", "task-sdk", "providers", "shared") + + +def _has_provide_session_decorator(nodes: list[ast.expr]) -> bool: + """Whether one of ``nodes`` is a ``@provide_session`` decorator. + + Accepts both bare names (``@provide_session``) and attribute access + (``@something.provide_session``). + """ + for node in nodes: + if isinstance(node, ast.Name) and node.id == _PROVIDE_SESSION_DECORATOR: + return True + if isinstance(node, ast.Attribute) and node.attr == _PROVIDE_SESSION_DECORATOR: + return True + return False + + +def _session_is_positional(args: ast.arguments) -> ast.arg | None: + """Return the ``session`` arg if it is positional (not keyword-only). + + Covers both regular positional args and positional-only args (``def f(session, /, ...)``). + """ + for argument in (*args.posonlyargs, *args.args): + if argument.arg == "session": + return argument + return None + + +def _iter_positional_session_in_provide_session( + path: Path, +) -> typing.Iterator[tuple[ast.FunctionDef | ast.AsyncFunctionDef, ast.arg]]: + """Yield ``@provide_session`` functions in *path* whose ``session`` is positional.""" + try: + source = path.read_text(encoding="utf-8", errors="replace") + except OSError: + return + try: + tree = ast.parse(source, str(path)) + except SyntaxError: + return + for node in ast.walk(tree): + if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): + continue + if not _has_provide_session_decorator(node.decorator_list): + continue + argument = _session_is_positional(node.args) + if argument is None: + continue + yield node, argument + + +def _count_violations(path: Path) -> int: + return sum(1 for _ in _iter_positional_session_in_provide_session(path)) + + +def _is_safe_relative(rel: str) -> bool: + """Whether ``rel`` is a plain relative path that stays inside ``REPO_ROOT``. + + Rejects absolute paths and any entry that resolves outside the repo root so + callers can ``relative_to(REPO_ROOT)`` without fear of a ``ValueError``. + """ + candidate = Path(rel) + if candidate.is_absolute(): + return False + try: + (REPO_ROOT / candidate).resolve().relative_to(REPO_ROOT.resolve()) + except ValueError: + return False + return True + + +class AllowlistManager: + def __init__(self, allowlist_file: Path) -> None: + self.allowlist_file = allowlist_file + + @staticmethod + def parse(text: str) -> dict[str, int]: + """Parse allowlist *text* into a ``{rel_path: count}`` mapping. + + Same validation rules as :meth:`load` so we can reuse parsing for the + on-disk allowlist *and* for the previous version fetched from git when + guarding against entry-removal bypasses. + """ + result: dict[str, int] = {} + for raw_line in text.splitlines(): + if not (stripped := raw_line.strip()): + continue + + rel_str, _, count_str = stripped.rpartition("::") + if not rel_str or not count_str: + continue + + try: + count = int(count_str) + except ValueError: + continue + + if not _is_safe_relative(rel_str): + console.print( + f"[yellow]Ignoring unsafe allowlist entry (escapes repo root):[/yellow] {rel_str}" + ) + continue + + result[rel_str] = count + + return result + + def load(self) -> dict[str, int]: + if not self.allowlist_file.exists(): + return {} + return self.parse(self.allowlist_file.read_text()) + + def save(self, counts: dict[str, int]) -> None: + lines = [f"{rel}::{count}" for rel, count in sorted(counts.items())] + self.allowlist_file.write_text("\n".join(lines) + "\n") + + def generate(self) -> int: + roots = ", ".join(_PROJECT_SOURCE_ROOTS) + console.print( + f"Scanning project source roots ([cyan]{roots}[/cyan]) under [cyan]{REPO_ROOT}[/cyan] " + "for @provide_session functions with positional session …" + ) + counts: dict[str, int] = {} + for path in _iter_python_files(): + n = _count_violations(path) + if n > 0: + counts[str(path.relative_to(REPO_ROOT))] = n + + self.save(counts) + total = sum(counts.values()) + console.print( + f"[green]Generated[/green] [cyan]{self.allowlist_file.relative_to(REPO_ROOT)}[/cyan] " + f"with [bold]{len(counts)}[/bold] files / [bold]{total}[/bold] offenders." + ) + return 0 + + def cleanup(self) -> int: + allowlist = self.load() + if not allowlist: + console.print("[yellow]Allowlist is empty - nothing to clean up.[/yellow]") + return 0 + + stale: list[str] = [rel for rel in allowlist if not (REPO_ROOT / rel).exists()] + if stale: + console.print( + f"[yellow]Removing {len(stale)} stale entr{'y' if len(stale) == 1 else 'ies'}:[/yellow]" + ) + for s in sorted(stale): + console.print(f" [dim]-[/dim] {s}") + for s in stale: + del allowlist[s] + self.save(allowlist) + console.print( + f"\n[green]Updated[/green] [cyan]{self.allowlist_file.relative_to(REPO_ROOT)}[/cyan]" + ) + else: + console.print("[green]No stale entries found.[/green]") + return 0 + + +def _iter_python_files() -> list[Path]: + candidates: list[Path] = [] + for top in _PROJECT_SOURCE_ROOTS: + candidates.extend( + p.resolve() + for p in (REPO_ROOT / top).rglob("*.py") + if ".tox" not in p.parts and "__pycache__" not in p.parts + ) + return candidates + + +def _check_provide_session_kwargs( + files: list[Path], allowlist: dict[str, int], manager: AllowlistManager +) -> int: + allowlist_file = manager.allowlist_file.resolve() + if any(p.resolve() == allowlist_file for p in files) and not allowlist_file.exists(): + console.print( + Panel.fit( + f"Allowlist file [cyan]{allowlist_file}[/cyan] is missing.\n" + "It was passed to the hook but cannot be read, so the check cannot proceed.\n" + "Restore it from git or regenerate it with:\n\n" + " [cyan]uv run ./scripts/ci/prek/check_provide_session_kwargs.py --generate[/cyan]", + title="[red]Check failed[/red]", + border_style="red", + ) + ) + return 1 + + violations: list[tuple[Path, int, int]] = [] + tightened: list[tuple[str, int, int]] = [] + + for path in files: + if not path.exists() or path.suffix != ".py": + continue + actual = _count_violations(path) + rel = str(path.relative_to(REPO_ROOT)) + allowed = allowlist.get(rel, 0) + if actual > allowed: + violations.append((path, actual, allowed)) + elif actual < allowed: + if actual == 0: + del allowlist[rel] + else: + allowlist[rel] = actual + tightened.append((rel, allowed, actual)) + + if tightened: + manager.save(allowlist) + console.print( + f"[green]Tightened {len(tightened)} entr{'y' if len(tightened) == 1 else 'ies'} " + f"in [cyan]{manager.allowlist_file.relative_to(REPO_ROOT)}[/cyan][/green] " + "(stage the updated file):" + ) + for rel, old, new in tightened: + console.print(f" [cyan]{rel}[/cyan] {old} -> {new}") + + if violations: + console.print( + Panel.fit( + "New [bold]@provide_session[/bold] function with positional ``session`` detected.\n" + "Move ``session`` after a bare ``*`` in the signature so callers must pass it by keyword:\n\n" + " [cyan]@provide_session\n" + " def foo(arg, *, session: Session = NEW_SESSION) -> None: ...[/cyan]\n\n" + "If this usage is intentional and pre-existing, run:\n\n" + " [cyan]uv run ./scripts/ci/prek/check_provide_session_kwargs.py --generate[/cyan]\n\n" + "to regenerate the allowlist, then commit the updated\n" + "[cyan]scripts/ci/prek/known_provide_session_positional.txt[/cyan].", + title="[red]Check failed[/red]", + border_style="red", + ) + ) + for path, actual, allowed in violations: + console.print(f" [cyan]{path.relative_to(REPO_ROOT)}[/cyan] count={actual} (allowed={allowed})") + for func, argument in _iter_positional_session_in_provide_session(path): + console.print(f" [dim]L{argument.lineno}[/dim] def {func.name}(...)") + return 1 + + return 1 if tightened else 0 + + +def main(argv: list[str] | None = None) -> int: + parser = argparse.ArgumentParser( + description="Prevent new @provide_session functions from declaring `session` positionally.", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=__doc__, + ) + parser.add_argument("files", nargs="*", metavar="FILE", help="Files to check (provided by prek)") + parser.add_argument( + "--all-files", + action="store_true", + help=( + "Check every Python file under the project source roots " + "(airflow-core, airflow-ctl, task-sdk, providers, shared)" + ), + ) + parser.add_argument( + "--cleanup", + action="store_true", + help="Remove stale entries from the allowlist and exit", + ) + parser.add_argument( + "--generate", + action="store_true", + help="Regenerate the allowlist from the current codebase and exit", + ) + args = parser.parse_args(argv) + + manager = AllowlistManager(Path(__file__).parent / "known_provide_session_positional.txt") + + if args.generate: + return manager.generate() + + if args.cleanup: + return manager.cleanup() + + allowlist = manager.load() + + if args.all_files: + return _check_provide_session_kwargs(_iter_python_files(), allowlist, manager) + + if not args.files: + console.print( + "[yellow]No files provided. Pass filenames or use --all-files to scan the whole repo.[/yellow]" + ) + return 0 + + paths = [Path(f).resolve() for f in args.files] + paths = _expand_for_allowlist_edits(paths, manager, allowlist) + return _check_provide_session_kwargs(paths, allowlist, manager) + + +def _previous_allowlist(manager: AllowlistManager) -> dict[str, int]: + """Return the allowlist as recorded at ``HEAD``, or an empty dict. + + Used by :func:`_expand_for_allowlist_edits` so that *removing* an entry + cannot silently drop coverage: the previously-listed file is still + re-validated against the new (post-edit) allowlist. Returns an empty mapping + when git is unavailable, the file does not yet exist at ``HEAD``, or the + allowlist sits outside ``REPO_ROOT``. + """ + try: + rel = manager.allowlist_file.resolve().relative_to(REPO_ROOT.resolve()) + except ValueError: + return {} + try: + completed = subprocess.run( + ["git", "-C", str(REPO_ROOT), "show", f"HEAD:{rel.as_posix()}"], + capture_output=True, + text=True, + check=False, + ) + except (FileNotFoundError, OSError): + return {} + if completed.returncode != 0: + return {} + return AllowlistManager.parse(completed.stdout) + + +def _expand_for_allowlist_edits( + paths: list[Path], manager: AllowlistManager, allowlist: dict[str, int] +) -> list[Path]: + """Add allowlisted files when the allowlist itself is being changed. + + Without this, a contributor could raise counts in + ``known_provide_session_positional.txt`` and the hook would do no validation + (since only the ``.txt`` file is passed), letting the loosened allowlist + sail through. We also union the *previous* allowlist (from ``HEAD``) so that + removing an entry cannot silently bypass the check for a file that still + has positional ``session`` arguments. + + Both sides of the allowlist-file comparison are resolved so the detection is + robust to symlinks and unresolved inputs (the hook can be invoked with either). + """ + allowlist_file = manager.allowlist_file.resolve() + if not any(p.resolve() == allowlist_file for p in paths): + return paths + + expanded = list(paths) + seen = {p.resolve() for p in paths if p.suffix == ".py"} + previous = _previous_allowlist(manager) + for rel in {*allowlist, *previous}: + candidate = (REPO_ROOT / rel).resolve() + if candidate.exists() and candidate not in seen: + seen.add(candidate) + expanded.append(candidate) + return expanded + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/scripts/ci/prek/known_provide_session_positional.txt b/scripts/ci/prek/known_provide_session_positional.txt new file mode 100644 index 0000000000000..d0c84e2f6b48f --- /dev/null +++ b/scripts/ci/prek/known_provide_session_positional.txt @@ -0,0 +1,89 @@ +airflow-core/src/airflow/api/common/delete_dag.py::1 +airflow-core/src/airflow/api/common/mark_tasks.py::1 +airflow-core/src/airflow/callbacks/database_callback_sink.py::1 +airflow-core/src/airflow/cli/commands/dag_command.py::8 +airflow-core/src/airflow/cli/commands/jobs_command.py::1 +airflow-core/src/airflow/cli/commands/task_command.py::1 +airflow-core/src/airflow/cli/commands/team_command.py::4 +airflow-core/src/airflow/cli/commands/variable_command.py::1 +airflow-core/src/airflow/dag_processing/dagbag.py::1 +airflow-core/src/airflow/dag_processing/manager.py::4 +airflow-core/src/airflow/jobs/base_job_runner.py::2 +airflow-core/src/airflow/jobs/job.py::7 +airflow-core/src/airflow/jobs/scheduler_job_runner.py::11 +airflow-core/src/airflow/jobs/triggerer_job_runner.py::1 +airflow-core/src/airflow/models/connection.py::2 +airflow-core/src/airflow/models/dag.py::7 +airflow-core/src/airflow/models/dagcode.py::6 +airflow-core/src/airflow/models/dagrun.py::15 +airflow-core/src/airflow/models/dagwarning.py::1 +airflow-core/src/airflow/models/deadline.py::1 +airflow-core/src/airflow/models/deadline_alert.py::1 +airflow-core/src/airflow/models/pool.py::11 +airflow-core/src/airflow/models/renderedtifields.py::4 +airflow-core/src/airflow/models/revoked_token.py::2 +airflow-core/src/airflow/models/serialized_dag.py::6 +airflow-core/src/airflow/models/taskinstance.py::21 +airflow-core/src/airflow/models/taskinstancehistory.py::2 +airflow-core/src/airflow/models/team.py::1 +airflow-core/src/airflow/models/trigger.py::7 +airflow-core/src/airflow/models/variable.py::2 +airflow-core/src/airflow/secrets/metastore.py::2 +airflow-core/src/airflow/serialization/definitions/dag.py::2 +airflow-core/src/airflow/ti_deps/deps/base_ti_dep.py::2 +airflow-core/src/airflow/ti_deps/deps/dag_ti_slots_available_dep.py::1 +airflow-core/src/airflow/ti_deps/deps/dag_unpaused_dep.py::1 +airflow-core/src/airflow/ti_deps/deps/dagrun_exists_dep.py::1 +airflow-core/src/airflow/ti_deps/deps/exec_date_after_start_date_dep.py::1 +airflow-core/src/airflow/ti_deps/deps/not_in_retry_period_dep.py::1 +airflow-core/src/airflow/ti_deps/deps/pool_slots_available_dep.py::1 +airflow-core/src/airflow/ti_deps/deps/prev_dagrun_dep.py::1 +airflow-core/src/airflow/ti_deps/deps/ready_to_reschedule.py::1 +airflow-core/src/airflow/ti_deps/deps/runnable_exec_date_dep.py::1 +airflow-core/src/airflow/ti_deps/deps/task_concurrency_dep.py::1 +airflow-core/src/airflow/ti_deps/deps/task_not_running_dep.py::1 +airflow-core/src/airflow/ti_deps/deps/valid_state_dep.py::1 +airflow-core/src/airflow/utils/cli_action_loggers.py::1 +airflow-core/src/airflow/utils/db.py::7 +airflow-core/src/airflow/utils/db_cleanup.py::2 +airflow-core/src/airflow/utils/log/file_task_handler.py::1 +airflow-core/tests/unit/api_fastapi/common/test_exceptions.py::4 +airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_assets.py::19 +airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_connections.py::2 +airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_dag_run.py::1 +airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_dag_tags.py::1 +airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_dag_warning.py::1 +airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_event_logs.py::1 +airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_import_error.py::8 +airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_job.py::1 +airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_monitor.py::2 +airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_pools.py::2 +airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_variables.py::2 +airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_xcom.py::3 +airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_calendar.py::2 +airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_dags.py::1 +airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_gantt.py::1 +airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_grid.py::1 +airflow-core/tests/unit/cli/commands/test_rotate_fernet_key_command.py::2 +airflow-core/tests/unit/jobs/test_scheduler_job.py::1 +airflow-core/tests/unit/listeners/test_listeners.py::7 +airflow-core/tests/unit/models/test_taskinstance.py::4 +airflow-core/tests/unit/models/test_timestamp.py::2 +providers/amazon/src/airflow/providers/amazon/aws/triggers/emr.py::1 +providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor.py::1 +providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/template_rendering.py::1 +providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/triggers/pod.py::1 +providers/common/ai/tests/unit/common/ai/plugins/test_hitl_review.py::1 +providers/databricks/src/airflow/providers/databricks/plugins/databricks_workflow.py::2 +providers/edge3/src/airflow/providers/edge3/executors/edge_executor.py::3 +providers/edge3/src/airflow/providers/edge3/models/edge_worker.py::10 +providers/edge3/src/airflow/providers/edge3/plugins/edge_executor_plugin.py::1 +providers/edge3/src/airflow/providers/edge3/worker_api/routes/logs.py::1 +providers/fab/src/airflow/providers/fab/auth_manager/cli_commands/permissions_command.py::1 +providers/google/src/airflow/providers/google/cloud/triggers/bigquery.py::1 +providers/google/src/airflow/providers/google/cloud/triggers/dataproc.py::3 +providers/openlineage/src/airflow/providers/openlineage/utils/utils.py::1 +providers/standard/src/airflow/providers/standard/operators/trigger_dagrun.py::1 +providers/standard/src/airflow/providers/standard/sensors/external_task.py::1 +providers/standard/src/airflow/providers/standard/utils/sensor_helper.py::1 +providers/standard/tests/unit/standard/sensors/test_external_task_sensor.py::3 diff --git a/scripts/tests/ci/prek/test_check_provide_session_kwargs.py b/scripts/tests/ci/prek/test_check_provide_session_kwargs.py new file mode 100644 index 0000000000000..30a4d4a5bd71b --- /dev/null +++ b/scripts/tests/ci/prek/test_check_provide_session_kwargs.py @@ -0,0 +1,477 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import ast +import os +import subprocess +import textwrap +from pathlib import Path + +import pytest +from ci.prek import check_provide_session_kwargs as hook +from ci.prek.check_provide_session_kwargs import ( + AllowlistManager, + _check_provide_session_kwargs, + _count_violations, + _expand_for_allowlist_edits, + _has_provide_session_decorator, + _iter_positional_session_in_provide_session, + _previous_allowlist, + _session_is_positional, +) + + +@pytest.fixture +def find_violations(write_python_file): + """Factory fixture: write code to a temp file and return positional-session violations.""" + + def _check(code: str) -> list[tuple[ast.FunctionDef | ast.AsyncFunctionDef, ast.arg]]: + path = write_python_file(code) + return list(_iter_positional_session_in_provide_session(path)) + + return _check + + +@pytest.fixture +def fake_repo(tmp_path, monkeypatch): + """Create a fake repo layout and patch REPO_ROOT so paths resolve correctly.""" + monkeypatch.setattr(hook, "REPO_ROOT", tmp_path) + + def _write(rel: str, code: str) -> Path: + path = tmp_path / rel + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(textwrap.dedent(code)) + return path + + return _write + + +@pytest.fixture +def git_repo(fake_repo, tmp_path): + """Initialise ``tmp_path`` as a git repo so ``git show HEAD:`` works. + + Returns a helper that commits the current working-tree contents under a given + message, so tests can stage a "previous" allowlist at HEAD before mutating it. + """ + env = { + **os.environ, + "GIT_AUTHOR_NAME": "t", + "GIT_AUTHOR_EMAIL": "t@t", + "GIT_COMMITTER_NAME": "t", + "GIT_COMMITTER_EMAIL": "t@t", + } + + def _run(*args: str) -> None: + subprocess.run(["git", "-C", str(tmp_path), *args], check=True, env=env, capture_output=True) + + _run("init", "-q", "-b", "main") + _run("config", "commit.gpgsign", "false") + + def _commit(message: str) -> None: + _run("add", "-A") + _run("commit", "-q", "--allow-empty", "-m", message) + + return _commit + + +class TestHasProvideSessionDecorator: + def test_provide_session_name(self): + func = ast.parse("@provide_session\ndef foo(): pass").body[0] + assert _has_provide_session_decorator(func.decorator_list) is True + + def test_provide_session_attribute(self): + func = ast.parse("@utils.provide_session\ndef foo(): pass").body[0] + assert _has_provide_session_decorator(func.decorator_list) is True + + def test_no_decorator(self): + func = ast.parse("def foo(): pass").body[0] + assert _has_provide_session_decorator(func.decorator_list) is False + + def test_unrelated_decorator(self): + func = ast.parse("@staticmethod\ndef foo(): pass").body[0] + assert _has_provide_session_decorator(func.decorator_list) is False + + def test_multiple_decorators_including_provide_session(self): + func = ast.parse("@staticmethod\n@provide_session\ndef foo(): pass").body[0] + assert _has_provide_session_decorator(func.decorator_list) is True + + +class TestSessionIsPositional: + def test_no_session_arg(self): + func = ast.parse("def foo(x, y): pass").body[0] + assert _session_is_positional(func.args) is None + + def test_session_positional(self): + func = ast.parse("def foo(session=NEW_SESSION): pass").body[0] + argument = _session_is_positional(func.args) + assert argument is not None + assert argument.arg == "session" + + def test_session_keyword_only(self): + func = ast.parse("def foo(*, session=NEW_SESSION): pass").body[0] + assert _session_is_positional(func.args) is None + + def test_session_positional_among_other_args(self): + func = ast.parse("def foo(x, y, session=NEW_SESSION): pass").body[0] + argument = _session_is_positional(func.args) + assert argument is not None + assert argument.arg == "session" + + def test_session_kwonly_after_other_positional(self): + func = ast.parse("def foo(x, y, *, session=NEW_SESSION): pass").body[0] + assert _session_is_positional(func.args) is None + + def test_session_positional_only(self): + func = ast.parse("def foo(session, /, x): pass").body[0] + argument = _session_is_positional(func.args) + assert argument is not None + assert argument.arg == "session" + + +class TestIterPositionalSessionInProvideSession: + def test_keyword_only_session_is_clean(self, find_violations): + code = """\ + @provide_session + def foo(*, session=NEW_SESSION): + pass + """ + assert find_violations(code) == [] + + def test_positional_session_is_flagged(self, find_violations): + code = """\ + @provide_session + def foo(session=NEW_SESSION): + pass + """ + violations = find_violations(code) + assert len(violations) == 1 + func, argument = violations[0] + assert func.name == "foo" + assert argument.arg == "session" + + def test_no_provide_session_decorator_is_ignored(self, find_violations): + code = """\ + def foo(session=NEW_SESSION): + pass + """ + assert find_violations(code) == [] + + def test_async_function_with_positional_session_is_flagged(self, find_violations): + code = """\ + @provide_session + async def foo(session=NEW_SESSION): + pass + """ + violations = find_violations(code) + assert len(violations) == 1 + + def test_method_with_positional_session_is_flagged(self, find_violations): + code = """\ + class C: + @provide_session + def foo(self, session=NEW_SESSION): + pass + """ + violations = find_violations(code) + assert len(violations) == 1 + assert violations[0][0].name == "foo" + + def test_attribute_decorator_is_recognised(self, find_violations): + code = """\ + @airflow.utils.session.provide_session + def foo(session=NEW_SESSION): + pass + """ + violations = find_violations(code) + assert len(violations) == 1 + + def test_count_violations_multiple_in_file(self, write_python_file): + code = """\ + @provide_session + def a(session=NEW_SESSION): + pass + + @provide_session + def b(x, session=NEW_SESSION): + pass + + @provide_session + def c(*, session=NEW_SESSION): + pass + """ + path = write_python_file(code) + assert _count_violations(path) == 2 + + def test_syntax_error_returns_no_violations(self, write_python_file): + path = write_python_file("def foo(:\n pass") + assert _count_violations(path) == 0 + + def test_invalid_utf8_does_not_crash(self, tmp_path): + path = tmp_path / "invalid_utf8.py" + path.write_bytes(b"# bad byte: \xff\n@provide_session\ndef foo(session=NEW_SESSION):\n pass\n") + + assert _count_violations(path) == 1 + + +class TestAllowlistManager: + def test_load_missing_file_returns_empty(self, tmp_path): + manager = AllowlistManager(tmp_path / "missing.txt") + assert manager.load() == {} + + def test_save_and_load_round_trip(self, tmp_path): + manager = AllowlistManager(tmp_path / "allowlist.txt") + manager.save({"b/file.py": 2, "a/file.py": 1}) + # Sorted by key in the file + text = (tmp_path / "allowlist.txt").read_text() + assert text.splitlines() == ["a/file.py::1", "b/file.py::2"] + assert manager.load() == {"a/file.py": 1, "b/file.py": 2} + + def test_load_skips_blank_and_malformed_lines(self, tmp_path): + path = tmp_path / "allowlist.txt" + path.write_text("\nvalid/file.py::3\nnocount\n::5\nbad::notanumber\n") + assert AllowlistManager(path).load() == {"valid/file.py": 3} + + def test_load_skips_unsafe_entries(self, fake_repo, tmp_path): + """Entries that escape REPO_ROOT (absolute paths or `..` segments) are ignored.""" + path = tmp_path / "allowlist.txt" + path.write_text("airflow-core/src/airflow/safe.py::1\n../escape.py::1\n/etc/passwd::1\n") + # `fake_repo` patches REPO_ROOT to tmp_path so the safety check is meaningful. + assert AllowlistManager(path).load() == {"airflow-core/src/airflow/safe.py": 1} + + +class TestCheckProvideSessionKwargs: + def test_no_violations_in_clean_file(self, fake_repo, tmp_path): + path = fake_repo( + "airflow-core/src/airflow/clean.py", + """\ + @provide_session + def foo(*, session=NEW_SESSION): + pass + """, + ) + manager = AllowlistManager(tmp_path / "allowlist.txt") + assert _check_provide_session_kwargs([path], {}, manager) == 0 + + def test_new_violation_fails(self, fake_repo, tmp_path): + path = fake_repo( + "airflow-core/src/airflow/bad.py", + """\ + @provide_session + def foo(session=NEW_SESSION): + pass + """, + ) + manager = AllowlistManager(tmp_path / "allowlist.txt") + assert _check_provide_session_kwargs([path], {}, manager) == 1 + + def test_violation_within_allowlist_passes(self, fake_repo, tmp_path): + path = fake_repo( + "airflow-core/src/airflow/grandfathered.py", + """\ + @provide_session + def foo(session=NEW_SESSION): + pass + """, + ) + manager = AllowlistManager(tmp_path / "allowlist.txt") + allowlist = {"airflow-core/src/airflow/grandfathered.py": 1} + assert _check_provide_session_kwargs([path], allowlist, manager) == 0 + + def test_exceeding_allowlist_fails(self, fake_repo, tmp_path): + path = fake_repo( + "airflow-core/src/airflow/grew.py", + """\ + @provide_session + def a(session=NEW_SESSION): + pass + + @provide_session + def b(session=NEW_SESSION): + pass + """, + ) + manager = AllowlistManager(tmp_path / "allowlist.txt") + allowlist = {"airflow-core/src/airflow/grew.py": 1} + assert _check_provide_session_kwargs([path], allowlist, manager) == 1 + + def test_reducing_violations_tightens_allowlist(self, fake_repo, tmp_path): + path = fake_repo( + "airflow-core/src/airflow/improved.py", + """\ + @provide_session + def foo(session=NEW_SESSION): + pass + + @provide_session + def bar(*, session=NEW_SESSION): + pass + """, + ) + manager = AllowlistManager(tmp_path / "allowlist.txt") + allowlist = {"airflow-core/src/airflow/improved.py": 2} + # Exit non-zero so pre-commit reports the modified allowlist + assert _check_provide_session_kwargs([path], allowlist, manager) == 1 + assert manager.load() == {"airflow-core/src/airflow/improved.py": 1} + + def test_fixing_all_violations_removes_entry(self, fake_repo, tmp_path): + path = fake_repo( + "airflow-core/src/airflow/fixed.py", + """\ + @provide_session + def foo(*, session=NEW_SESSION): + pass + """, + ) + manager = AllowlistManager(tmp_path / "allowlist.txt") + allowlist = {"airflow-core/src/airflow/fixed.py": 1} + assert _check_provide_session_kwargs([path], allowlist, manager) == 1 + assert manager.load() == {} + + def test_non_python_file_is_skipped(self, fake_repo, tmp_path): + path = fake_repo( + "airflow-core/src/airflow/not_python.txt", "@provide_session\ndef foo(session=N): pass\n" + ) + manager = AllowlistManager(tmp_path / "allowlist.txt") + assert _check_provide_session_kwargs([path], {}, manager) == 0 + + def test_missing_allowlist_file_fails_loudly(self, fake_repo, tmp_path): + """Passing the allowlist path when the file is missing must fail, not silently pass.""" + allowlist_path = tmp_path / "allowlist.txt" + manager = AllowlistManager(allowlist_path) + assert not allowlist_path.exists() + assert _check_provide_session_kwargs([allowlist_path.resolve()], {}, manager) == 1 + + +class TestExpandForAllowlistEdits: + def test_unchanged_when_allowlist_not_in_paths(self, fake_repo, tmp_path): + py = fake_repo("airflow-core/src/airflow/x.py", "pass") + manager = AllowlistManager(tmp_path / "allowlist.txt") + assert _expand_for_allowlist_edits([py], manager, {"airflow-core/src/airflow/x.py": 1}) == [py] + + def test_appends_allowlisted_files_when_allowlist_edited(self, fake_repo, tmp_path): + allowlist_path = tmp_path / "allowlist.txt" + manager = AllowlistManager(allowlist_path) + listed = fake_repo("airflow-core/src/airflow/listed.py", "pass") + # Pass a resolved path — matches production behavior (``main()`` resolves argv). + result = _expand_for_allowlist_edits( + [allowlist_path.resolve()], + manager, + {"airflow-core/src/airflow/listed.py": 1, "airflow-core/src/airflow/gone.py": 1}, + ) + assert allowlist_path.resolve() in result + assert listed in result + # File in allowlist that does not exist on disk should be ignored. + assert (tmp_path / "airflow-core/src/airflow/gone.py").resolve() not in result + + def test_detection_robust_to_symlinked_allowlist(self, fake_repo, tmp_path): + """A symlink pointing at the allowlist file must still trigger expansion.""" + allowlist_path = tmp_path / "allowlist.txt" + manager = AllowlistManager(allowlist_path) + listed = fake_repo("airflow-core/src/airflow/listed.py", "pass") + manager.save({"airflow-core/src/airflow/listed.py": 1}) + + symlink = tmp_path / "allowlist_link.txt" + symlink.symlink_to(allowlist_path) + + # Production resolves argv before calling the helper — a symlinked path resolves + # to the real allowlist file and must be recognised as an allowlist edit. + result = _expand_for_allowlist_edits([symlink.resolve()], manager, manager.load()) + + assert listed in result + + def test_includes_previous_allowlist_entries_when_removed(self, fake_repo, git_repo, tmp_path): + """Removing an entry from the allowlist must still re-check the previously-listed file.""" + rel = "airflow-core/src/airflow/dropped.py" + fake_repo( + rel, + """\ + @provide_session + def foo(session=NEW_SESSION): + pass + """, + ) + allowlist_path = tmp_path / "allowlist.txt" + manager = AllowlistManager(allowlist_path) + manager.save({rel: 1}) + git_repo("seed allowlist at HEAD") + + # Working tree: remove the entry, but the offending file still exists. + allowlist_path.write_text("") + current = manager.load() + assert current == {} + + expanded = _expand_for_allowlist_edits([allowlist_path.resolve()], manager, current) + # The previously-listed file must be re-validated. + assert (tmp_path / rel).resolve() in expanded + + # And the full check should fail because the file still has positional sessions. + assert _check_provide_session_kwargs(expanded, current, manager) == 1 + + def test_previous_allowlist_empty_when_no_git_history(self, fake_repo, tmp_path): + """Without a git repo the previous-allowlist lookup returns empty and does not crash.""" + manager = AllowlistManager(tmp_path / "allowlist.txt") + assert _previous_allowlist(manager) == {} + + def test_re_validates_listed_files_so_loosening_cannot_bypass(self, fake_repo, tmp_path, capsys): + """Editing only the allowlist must still trigger validation of listed files.""" + rel = "airflow-core/src/airflow/loosened.py" + fake_repo( + rel, + """\ + @provide_session + def foo(session=NEW_SESSION): + pass + + @provide_session + def bar(session=NEW_SESSION): + pass + """, + ) + allowlist_path = tmp_path / "allowlist.txt" + manager = AllowlistManager(allowlist_path) + # Allowlist loosened to 5 although file only has 2 positional sessions. + allowlist = {rel: 5} + manager.save(allowlist) + + # Only the allowlist file is "changed"; without re-validation this would return 0. + # Resolve the path to mirror what ``main()`` does in production. + paths = _expand_for_allowlist_edits([allowlist_path.resolve()], manager, allowlist) + rc = _check_provide_session_kwargs(paths, allowlist, manager) + + # Tightened from 5 -> 2, so the hook exits non-zero to surface the modified allowlist. + assert rc == 1 + assert manager.load() == {rel: 2} + + +class TestCleanup: + def test_cleanup_removes_stale_entries(self, fake_repo, tmp_path): + fake_repo("airflow-core/src/airflow/keeper.py", "pass") + allowlist_path = tmp_path / "allowlist.txt" + manager = AllowlistManager(allowlist_path) + manager.save( + { + "airflow-core/src/airflow/keeper.py": 1, + "airflow-core/src/airflow/gone.py": 1, + } + ) + assert manager.cleanup() == 0 + assert manager.load() == {"airflow-core/src/airflow/keeper.py": 1} + + def test_cleanup_empty_allowlist(self, tmp_path): + manager = AllowlistManager(tmp_path / "allowlist.txt") + assert manager.cleanup() == 0