From fb0623d5456d58ee778c69cd107b0c4527951330 Mon Sep 17 00:00:00 2001 From: LIU ZHE YOU Date: Thu, 7 May 2026 16:13:56 +0800 Subject: [PATCH 1/2] Support inline ignore marker for check-sdk-imports hook --- scripts/ci/prek/check_sdk_imports.py | 15 ++ .../tests/ci/prek/test_check_sdk_imports.py | 140 ++++++++++++++++++ 2 files changed, 155 insertions(+) create mode 100644 scripts/tests/ci/prek/test_check_sdk_imports.py diff --git a/scripts/ci/prek/check_sdk_imports.py b/scripts/ci/prek/check_sdk_imports.py index f900390695f9f..cc0315daff3f8 100755 --- a/scripts/ci/prek/check_sdk_imports.py +++ b/scripts/ci/prek/check_sdk_imports.py @@ -31,6 +31,8 @@ from common_prek_utils import console +NOCHECK_MARKER = "# nocheck: sdk-imports" + def check_file_for_sdk_imports(file_path: Path) -> list[tuple[int, str]]: """Check file for airflow.sdk imports. Returns list of (line_num, import_statement).""" @@ -41,11 +43,14 @@ def check_file_for_sdk_imports(file_path: Path) -> list[tuple[int, str]]: except (OSError, UnicodeDecodeError, SyntaxError): return [] + source_lines = source.splitlines() mismatches = [] for node in ast.walk(tree): if isinstance(node, ast.ImportFrom): if node.module and ("airflow.sdk" in node.module): + if _has_nocheck_marker(source_lines, node): + continue import_names = ", ".join(alias.name for alias in node.names) statement = f"from {node.module} import {import_names}" mismatches.append((node.lineno, statement)) @@ -53,6 +58,16 @@ def check_file_for_sdk_imports(file_path: Path) -> list[tuple[int, str]]: return mismatches +def _has_nocheck_marker(source_lines: list[str], node: ast.Import | ast.ImportFrom) -> bool: + """Check if the import statement has the nocheck marker comment on any of its lines.""" + start = node.lineno + end = node.end_lineno or start + for lineno in range(start, end + 1): + if lineno <= len(source_lines) and NOCHECK_MARKER in source_lines[lineno - 1]: + return True + return False + + def main(): parser = argparse.ArgumentParser(description="Check for SDK imports in airflow-core files") parser.add_argument("files", nargs="*", help="Files to check") diff --git a/scripts/tests/ci/prek/test_check_sdk_imports.py b/scripts/tests/ci/prek/test_check_sdk_imports.py new file mode 100644 index 0000000000000..458e61dada097 --- /dev/null +++ b/scripts/tests/ci/prek/test_check_sdk_imports.py @@ -0,0 +1,140 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import textwrap +from pathlib import Path + +import pytest +from check_sdk_imports import check_file_for_sdk_imports + + +class TestCheckFileForSdkImports: + @pytest.mark.parametrize( + "code, expected", + [ + pytest.param( + "from airflow.sdk import DAG\n", + [(1, "from airflow.sdk import DAG")], + id="from-sdk-import", + ), + pytest.param( + "from airflow.sdk.definitions import dag\n", + [(1, "from airflow.sdk.definitions import dag")], + id="from-sdk-submodule-import", + ), + pytest.param( + "from airflow.models import DagRun\n", + [], + id="core-import-allowed", + ), + pytest.param( + "import airflow.sdk\n", + [], + id="plain-import-not-checked", + ), + pytest.param( + "import os\nimport sys\n", + [], + id="stdlib-only", + ), + ], + ) + def test_detects_sdk_imports(self, tmp_path: Path, code: str, expected: list[tuple[int, str]]): + f = tmp_path / "example.py" + f.write_text(code) + assert check_file_for_sdk_imports(f) == expected + + +class TestNocheckMarker: + @pytest.mark.parametrize( + "code, expected", + [ + pytest.param( + "from airflow.sdk import DAG # nocheck: sdk-imports\n", + [], + id="from-import-suppressed", + ), + pytest.param( + "from airflow.sdk.definitions import dag # nocheck: sdk-imports\n", + [], + id="from-submodule-suppressed", + ), + pytest.param( + "from airflow.sdk import DAG # nocheck: sdk-imports - needed for compat\n", + [], + id="marker-with-extra-text", + ), + pytest.param( + textwrap.dedent("""\ + from airflow.sdk import ( + DAG, + Variable, + ) # nocheck: sdk-imports + """), + [], + id="multiline-marker-on-closing-paren", + ), + pytest.param( + textwrap.dedent("""\ + from airflow.sdk import ( # nocheck: sdk-imports + DAG, + Variable, + ) + """), + [], + id="multiline-marker-on-first-line", + ), + pytest.param( + textwrap.dedent("""\ + from airflow.sdk import ( + DAG, # nocheck: sdk-imports + Variable, + ) + """), + [], + id="multiline-marker-on-middle-line", + ), + pytest.param( + "from airflow.sdk import DAG # noqa: E402\n", + [(1, "from airflow.sdk import DAG")], + id="wrong-marker-not-suppressed", + ), + pytest.param( + textwrap.dedent("""\ + from airflow.sdk import ( + DAG, + Variable, + ) + """), + [(1, "from airflow.sdk import DAG, Variable")], + id="multiline-without-marker-detected", + ), + pytest.param( + textwrap.dedent("""\ + from airflow.sdk import DAG # nocheck: sdk-imports + from airflow.sdk.definitions import dag + """), + [(2, "from airflow.sdk.definitions import dag")], + id="only-marked-line-suppressed", + ), + ], + ) + def test_nocheck_marker(self, tmp_path: Path, code: str, expected: list[tuple[int, str]]): + f = tmp_path / "example.py" + f.write_text(code) + assert check_file_for_sdk_imports(f) == expected From 15ff2f5db81f36570b8745145d523e9f193ea4f2 Mon Sep 17 00:00:00 2001 From: LIU ZHE YOU Date: Thu, 7 May 2026 17:17:14 +0800 Subject: [PATCH 2/2] Addressed review comments --- scripts/ci/prek/check_sdk_imports.py | 16 +++------------- scripts/ci/prek/common_prek_utils.py | 10 ++++++++++ scripts/tests/ci/prek/test_check_sdk_imports.py | 14 +++++++------- 3 files changed, 20 insertions(+), 20 deletions(-) diff --git a/scripts/ci/prek/check_sdk_imports.py b/scripts/ci/prek/check_sdk_imports.py index cc0315daff3f8..b35ebc9db8fd8 100755 --- a/scripts/ci/prek/check_sdk_imports.py +++ b/scripts/ci/prek/check_sdk_imports.py @@ -29,9 +29,9 @@ import sys from pathlib import Path -from common_prek_utils import console +from common_prek_utils import console, has_nocheck_marker -NOCHECK_MARKER = "# nocheck: sdk-imports" +NOCHECK_MARKER = "# noqa: SDK001" def check_file_for_sdk_imports(file_path: Path) -> list[tuple[int, str]]: @@ -49,7 +49,7 @@ def check_file_for_sdk_imports(file_path: Path) -> list[tuple[int, str]]: for node in ast.walk(tree): if isinstance(node, ast.ImportFrom): if node.module and ("airflow.sdk" in node.module): - if _has_nocheck_marker(source_lines, node): + if has_nocheck_marker(source_lines, node, NOCHECK_MARKER): continue import_names = ", ".join(alias.name for alias in node.names) statement = f"from {node.module} import {import_names}" @@ -58,16 +58,6 @@ def check_file_for_sdk_imports(file_path: Path) -> list[tuple[int, str]]: return mismatches -def _has_nocheck_marker(source_lines: list[str], node: ast.Import | ast.ImportFrom) -> bool: - """Check if the import statement has the nocheck marker comment on any of its lines.""" - start = node.lineno - end = node.end_lineno or start - for lineno in range(start, end + 1): - if lineno <= len(source_lines) and NOCHECK_MARKER in source_lines[lineno - 1]: - return True - return False - - def main(): parser = argparse.ArgumentParser(description="Check for SDK imports in airflow-core files") parser.add_argument("files", nargs="*", help="Files to check") diff --git a/scripts/ci/prek/common_prek_utils.py b/scripts/ci/prek/common_prek_utils.py index 68db84b6df0c6..b7ec454d1e6a1 100644 --- a/scripts/ci/prek/common_prek_utils.py +++ b/scripts/ci/prek/common_prek_utils.py @@ -521,6 +521,16 @@ def get_all_provider_info_dicts() -> dict[str, dict]: return providers +def has_nocheck_marker(source_lines: list[str], node: ast.ImportFrom, marker: str) -> bool: + """Check if the import statement has the given nocheck marker comment on any of its lines.""" + start = node.lineno + end = node.end_lineno or start + for lineno in range(start, end + 1): + if lineno <= len(source_lines) and marker in source_lines[lineno - 1]: + return True + return False + + def get_imports_from_file(file_path: Path, *, only_top_level: bool) -> list[str]: """ Returns list of all imports in file. diff --git a/scripts/tests/ci/prek/test_check_sdk_imports.py b/scripts/tests/ci/prek/test_check_sdk_imports.py index 458e61dada097..cf097027dc3ca 100644 --- a/scripts/tests/ci/prek/test_check_sdk_imports.py +++ b/scripts/tests/ci/prek/test_check_sdk_imports.py @@ -65,17 +65,17 @@ class TestNocheckMarker: "code, expected", [ pytest.param( - "from airflow.sdk import DAG # nocheck: sdk-imports\n", + "from airflow.sdk import DAG # noqa: SDK001\n", [], id="from-import-suppressed", ), pytest.param( - "from airflow.sdk.definitions import dag # nocheck: sdk-imports\n", + "from airflow.sdk.definitions import dag # noqa: SDK001\n", [], id="from-submodule-suppressed", ), pytest.param( - "from airflow.sdk import DAG # nocheck: sdk-imports - needed for compat\n", + "from airflow.sdk import DAG # noqa: SDK001 - needed for compat\n", [], id="marker-with-extra-text", ), @@ -84,14 +84,14 @@ class TestNocheckMarker: from airflow.sdk import ( DAG, Variable, - ) # nocheck: sdk-imports + ) # noqa: SDK001 """), [], id="multiline-marker-on-closing-paren", ), pytest.param( textwrap.dedent("""\ - from airflow.sdk import ( # nocheck: sdk-imports + from airflow.sdk import ( # noqa: SDK001 DAG, Variable, ) @@ -102,7 +102,7 @@ class TestNocheckMarker: pytest.param( textwrap.dedent("""\ from airflow.sdk import ( - DAG, # nocheck: sdk-imports + DAG, # noqa: SDK001 Variable, ) """), @@ -126,7 +126,7 @@ class TestNocheckMarker: ), pytest.param( textwrap.dedent("""\ - from airflow.sdk import DAG # nocheck: sdk-imports + from airflow.sdk import DAG # noqa: SDK001 from airflow.sdk.definitions import dag """), [(2, "from airflow.sdk.definitions import dag")],