diff --git a/dev/breeze/src/airflow_breeze/utils/packages.py b/dev/breeze/src/airflow_breeze/utils/packages.py index 553f0379b4cc1..88809f30148e1 100644 --- a/dev/breeze/src/airflow_breeze/utils/packages.py +++ b/dev/breeze/src/airflow_breeze/utils/packages.py @@ -683,6 +683,20 @@ def get_cross_provider_dependent_packages(provider_id: str) -> list[str]: return get_provider_dependencies()[provider_id]["cross-providers-deps"] +def get_cross_provider_dependent_extras(provider_id: str) -> list[str]: + if provider_id in get_removed_provider_ids(): + return [] + + suspended_provider_ids = set(get_suspended_provider_ids()) + required_dependencies = "\n".join(get_provider_requirements(provider_id)) + return [ + cross_provider_id + for cross_provider_id in get_cross_provider_dependent_packages(provider_id) + if cross_provider_id not in suspended_provider_ids + and get_pip_package_name(cross_provider_id) not in required_dependencies + ] + + def get_license_files(provider_id: str) -> str: if provider_id == "fab": return str(["LICENSE", "NOTICE", "3rd-party-licenses/LICENSES-*"]).replace('"', "'") @@ -700,7 +714,7 @@ def get_provider_jinja_context( supported_python_versions = [ p for p in ALLOWED_PYTHON_MAJOR_MINOR_VERSIONS if p not in provider_details.excluded_python_versions ] - cross_providers_dependencies = get_cross_provider_dependent_packages(provider_id=provider_id) + cross_providers_dependencies = get_cross_provider_dependent_extras(provider_id=provider_id) requires_python_version: str = f">={DEFAULT_PYTHON_MAJOR_MINOR_VERSION}" # Most providers require the same python versions, but some may have exclusions @@ -729,7 +743,7 @@ def get_provider_jinja_context( "MIN_AIRFLOW_VERSION": get_min_airflow_version(provider_id), "PROVIDER_REMOVED": provider_details.removed, "PROVIDER_INFO": get_provider_info_dict(provider_id), - "CROSS_PROVIDERS_DEPENDENCIES": get_cross_provider_dependent_packages(provider_id), + "CROSS_PROVIDERS_DEPENDENCIES": cross_providers_dependencies, "CROSS_PROVIDERS_DEPENDENCIES_TABLE_RST": convert_cross_package_dependencies_to_table( cross_providers_dependencies, markdown=False ), diff --git a/dev/breeze/tests/test_packages.py b/dev/breeze/tests/test_packages.py index 2a11a4ef29d2a..e4ca8981e1572 100644 --- a/dev/breeze/tests/test_packages.py +++ b/dev/breeze/tests/test_packages.py @@ -31,6 +31,7 @@ expand_all_provider_distributions, find_matching_long_package_names, get_available_distributions, + get_cross_provider_dependent_extras, get_cross_provider_dependent_packages, get_dist_package_name_prefix, get_long_package_name, @@ -265,19 +266,22 @@ def test_get_min_airflow_version(provider_id: str, min_version: str): def test_convert_cross_package_dependencies_to_table(): EXPECTED = """ -| Dependent package | Extra | -|:----------------------------------------------------------------------------------------|:----------------| -| [apache-airflow-providers-common-compat](https://airflow.apache.org/docs/common-compat) | `common.compat` | -| [apache-airflow-providers-common-sql](https://airflow.apache.org/docs/common-sql) | `common.sql` | -| [apache-airflow-providers-google](https://airflow.apache.org/docs/google) | `google` | -| [apache-airflow-providers-openlineage](https://airflow.apache.org/docs/openlineage) | `openlineage` | +| Dependent package | Extra | +|:------------------------------------------------------------------------------------|:--------------| +| [apache-airflow-providers-google](https://airflow.apache.org/docs/google) | `google` | +| [apache-airflow-providers-openlineage](https://airflow.apache.org/docs/openlineage) | `openlineage` | """ assert ( - convert_cross_package_dependencies_to_table(get_cross_provider_dependent_packages("trino")).strip() + convert_cross_package_dependencies_to_table(get_cross_provider_dependent_extras("trino")).strip() == EXPECTED.strip() ) +def test_get_cross_provider_dependent_extras_excludes_required_dependencies(): + assert get_cross_provider_dependent_packages("openlineage") == ["common.compat", "common.sql"] + assert get_cross_provider_dependent_extras("openlineage") == [] + + def test_get_provider_info_dict(): provider_info_dict = get_provider_info_dict("amazon") assert provider_info_dict["name"] == "Amazon"