diff --git a/scripts/ci/prek/check_sdk_imports.py b/scripts/ci/prek/check_sdk_imports.py index f900390695f9f..b35ebc9db8fd8 100755 --- a/scripts/ci/prek/check_sdk_imports.py +++ b/scripts/ci/prek/check_sdk_imports.py @@ -29,7 +29,9 @@ import sys from pathlib import Path -from common_prek_utils import console +from common_prek_utils import console, has_nocheck_marker + +NOCHECK_MARKER = "# noqa: SDK001" def check_file_for_sdk_imports(file_path: Path) -> list[tuple[int, str]]: @@ -41,11 +43,14 @@ def check_file_for_sdk_imports(file_path: Path) -> list[tuple[int, str]]: except (OSError, UnicodeDecodeError, SyntaxError): return [] + source_lines = source.splitlines() mismatches = [] for node in ast.walk(tree): if isinstance(node, ast.ImportFrom): if node.module and ("airflow.sdk" in node.module): + if has_nocheck_marker(source_lines, node, NOCHECK_MARKER): + continue import_names = ", ".join(alias.name for alias in node.names) statement = f"from {node.module} import {import_names}" mismatches.append((node.lineno, statement)) diff --git a/scripts/ci/prek/common_prek_utils.py b/scripts/ci/prek/common_prek_utils.py index 68db84b6df0c6..b7ec454d1e6a1 100644 --- a/scripts/ci/prek/common_prek_utils.py +++ b/scripts/ci/prek/common_prek_utils.py @@ -521,6 +521,16 @@ def get_all_provider_info_dicts() -> dict[str, dict]: return providers +def has_nocheck_marker(source_lines: list[str], node: ast.ImportFrom, marker: str) -> bool: + """Check if the import statement has the given nocheck marker comment on any of its lines.""" + start = node.lineno + end = node.end_lineno or start + for lineno in range(start, end + 1): + if lineno <= len(source_lines) and marker in source_lines[lineno - 1]: + return True + return False + + def get_imports_from_file(file_path: Path, *, only_top_level: bool) -> list[str]: """ Returns list of all imports in file. diff --git a/scripts/tests/ci/prek/test_check_sdk_imports.py b/scripts/tests/ci/prek/test_check_sdk_imports.py new file mode 100644 index 0000000000000..cf097027dc3ca --- /dev/null +++ b/scripts/tests/ci/prek/test_check_sdk_imports.py @@ -0,0 +1,140 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import textwrap +from pathlib import Path + +import pytest +from check_sdk_imports import check_file_for_sdk_imports + + +class TestCheckFileForSdkImports: + @pytest.mark.parametrize( + "code, expected", + [ + pytest.param( + "from airflow.sdk import DAG\n", + [(1, "from airflow.sdk import DAG")], + id="from-sdk-import", + ), + pytest.param( + "from airflow.sdk.definitions import dag\n", + [(1, "from airflow.sdk.definitions import dag")], + id="from-sdk-submodule-import", + ), + pytest.param( + "from airflow.models import DagRun\n", + [], + id="core-import-allowed", + ), + pytest.param( + "import airflow.sdk\n", + [], + id="plain-import-not-checked", + ), + pytest.param( + "import os\nimport sys\n", + [], + id="stdlib-only", + ), + ], + ) + def test_detects_sdk_imports(self, tmp_path: Path, code: str, expected: list[tuple[int, str]]): + f = tmp_path / "example.py" + f.write_text(code) + assert check_file_for_sdk_imports(f) == expected + + +class TestNocheckMarker: + @pytest.mark.parametrize( + "code, expected", + [ + pytest.param( + "from airflow.sdk import DAG # noqa: SDK001\n", + [], + id="from-import-suppressed", + ), + pytest.param( + "from airflow.sdk.definitions import dag # noqa: SDK001\n", + [], + id="from-submodule-suppressed", + ), + pytest.param( + "from airflow.sdk import DAG # noqa: SDK001 - needed for compat\n", + [], + id="marker-with-extra-text", + ), + pytest.param( + textwrap.dedent("""\ + from airflow.sdk import ( + DAG, + Variable, + ) # noqa: SDK001 + """), + [], + id="multiline-marker-on-closing-paren", + ), + pytest.param( + textwrap.dedent("""\ + from airflow.sdk import ( # noqa: SDK001 + DAG, + Variable, + ) + """), + [], + id="multiline-marker-on-first-line", + ), + pytest.param( + textwrap.dedent("""\ + from airflow.sdk import ( + DAG, # noqa: SDK001 + Variable, + ) + """), + [], + id="multiline-marker-on-middle-line", + ), + pytest.param( + "from airflow.sdk import DAG # noqa: E402\n", + [(1, "from airflow.sdk import DAG")], + id="wrong-marker-not-suppressed", + ), + pytest.param( + textwrap.dedent("""\ + from airflow.sdk import ( + DAG, + Variable, + ) + """), + [(1, "from airflow.sdk import DAG, Variable")], + id="multiline-without-marker-detected", + ), + pytest.param( + textwrap.dedent("""\ + from airflow.sdk import DAG # noqa: SDK001 + from airflow.sdk.definitions import dag + """), + [(2, "from airflow.sdk.definitions import dag")], + id="only-marked-line-suppressed", + ), + ], + ) + def test_nocheck_marker(self, tmp_path: Path, code: str, expected: list[tuple[int, str]]): + f = tmp_path / "example.py" + f.write_text(code) + assert check_file_for_sdk_imports(f) == expected