Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions dev/breeze/src/airflow_breeze/utils/packages.py
Original file line number Diff line number Diff line change
Expand Up @@ -683,6 +683,20 @@ def get_cross_provider_dependent_packages(provider_id: str) -> list[str]:
return get_provider_dependencies()[provider_id]["cross-providers-deps"]


def get_cross_provider_dependent_extras(provider_id: str) -> list[str]:
if provider_id in get_removed_provider_ids():
return []

suspended_provider_ids = set(get_suspended_provider_ids())
required_dependencies = "\n".join(get_provider_requirements(provider_id))
return [
cross_provider_id
for cross_provider_id in get_cross_provider_dependent_packages(provider_id)
if cross_provider_id not in suspended_provider_ids
and get_pip_package_name(cross_provider_id) not in required_dependencies
]


def get_license_files(provider_id: str) -> str:
if provider_id == "fab":
return str(["LICENSE", "NOTICE", "3rd-party-licenses/LICENSES-*"]).replace('"', "'")
Expand All @@ -700,7 +714,7 @@ def get_provider_jinja_context(
supported_python_versions = [
p for p in ALLOWED_PYTHON_MAJOR_MINOR_VERSIONS if p not in provider_details.excluded_python_versions
]
cross_providers_dependencies = get_cross_provider_dependent_packages(provider_id=provider_id)
cross_providers_dependencies = get_cross_provider_dependent_extras(provider_id=provider_id)

requires_python_version: str = f">={DEFAULT_PYTHON_MAJOR_MINOR_VERSION}"
# Most providers require the same python versions, but some may have exclusions
Expand Down Expand Up @@ -729,7 +743,7 @@ def get_provider_jinja_context(
"MIN_AIRFLOW_VERSION": get_min_airflow_version(provider_id),
"PROVIDER_REMOVED": provider_details.removed,
"PROVIDER_INFO": get_provider_info_dict(provider_id),
"CROSS_PROVIDERS_DEPENDENCIES": get_cross_provider_dependent_packages(provider_id),
"CROSS_PROVIDERS_DEPENDENCIES": cross_providers_dependencies,
"CROSS_PROVIDERS_DEPENDENCIES_TABLE_RST": convert_cross_package_dependencies_to_table(
cross_providers_dependencies, markdown=False
),
Expand Down
18 changes: 11 additions & 7 deletions dev/breeze/tests/test_packages.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
expand_all_provider_distributions,
find_matching_long_package_names,
get_available_distributions,
get_cross_provider_dependent_extras,
get_cross_provider_dependent_packages,
get_dist_package_name_prefix,
get_long_package_name,
Expand Down Expand Up @@ -265,19 +266,22 @@ def test_get_min_airflow_version(provider_id: str, min_version: str):

def test_convert_cross_package_dependencies_to_table():
EXPECTED = """
| Dependent package | Extra |
|:----------------------------------------------------------------------------------------|:----------------|
| [apache-airflow-providers-common-compat](https://airflow.apache.org/docs/common-compat) | `common.compat` |
| [apache-airflow-providers-common-sql](https://airflow.apache.org/docs/common-sql) | `common.sql` |
| [apache-airflow-providers-google](https://airflow.apache.org/docs/google) | `google` |
| [apache-airflow-providers-openlineage](https://airflow.apache.org/docs/openlineage) | `openlineage` |
| Dependent package | Extra |
|:------------------------------------------------------------------------------------|:--------------|
| [apache-airflow-providers-google](https://airflow.apache.org/docs/google) | `google` |
| [apache-airflow-providers-openlineage](https://airflow.apache.org/docs/openlineage) | `openlineage` |
"""
assert (
convert_cross_package_dependencies_to_table(get_cross_provider_dependent_packages("trino")).strip()
convert_cross_package_dependencies_to_table(get_cross_provider_dependent_extras("trino")).strip()
== EXPECTED.strip()
)


def test_get_cross_provider_dependent_extras_excludes_required_dependencies():
assert get_cross_provider_dependent_packages("openlineage") == ["common.compat", "common.sql"]
assert get_cross_provider_dependent_extras("openlineage") == []


def test_get_provider_info_dict():
provider_info_dict = get_provider_info_dict("amazon")
assert provider_info_dict["name"] == "Amazon"
Expand Down