diff --git a/airflow/providers/openlineage/extractors/manager.py b/airflow/providers/openlineage/extractors/manager.py index 480ffc9f39981..6d9fabc5cf47b 100644 --- a/airflow/providers/openlineage/extractors/manager.py +++ b/airflow/providers/openlineage/extractors/manager.py @@ -63,8 +63,9 @@ def __init__(self): for operator_class in extractor.get_operator_classnames(): self.extractors[operator_class] = extractor - env_extractors = conf.get("openlinege", "extractors", fallback=os.getenv("OPENLINEAGE_EXTRACTORS")) - if env_extractors is not None: + env_extractors = conf.get("openlineage", "extractors", fallback=os.getenv("OPENLINEAGE_EXTRACTORS")) + # skip either when it's empty string or None + if env_extractors: for extractor in env_extractors.split(";"): extractor: type[BaseExtractor] = try_import_from_string(extractor.strip()) for operator_class in extractor.get_operator_classnames(): diff --git a/airflow/providers/openlineage/provider.yaml b/airflow/providers/openlineage/provider.yaml index 9bd98e8b937c9..ebfe31749bf5b 100644 --- a/airflow/providers/openlineage/provider.yaml +++ b/airflow/providers/openlineage/provider.yaml @@ -85,7 +85,7 @@ config: Semicolon separated paths to custom OpenLineage extractors. type: string example: full.path.to.ExtractorClass;full.path.to.AnotherExtractorClass - default: "" + default: ~ version_added: ~ config_path: description: | diff --git a/tests/providers/openlineage/extractors/test_base.py b/tests/providers/openlineage/extractors/test_base.py index 7c2174fe5b270..309e0c1a79f52 100644 --- a/tests/providers/openlineage/extractors/test_base.py +++ b/tests/providers/openlineage/extractors/test_base.py @@ -16,6 +16,7 @@ # under the License. from __future__ import annotations +import os from typing import Any from unittest import mock @@ -27,11 +28,13 @@ from airflow.models.baseoperator import BaseOperator from airflow.operators.python import PythonOperator from airflow.providers.openlineage.extractors.base import ( + BaseExtractor, DefaultExtractor, OperatorLineage, ) from airflow.providers.openlineage.extractors.manager import ExtractorManager from airflow.providers.openlineage.extractors.python import PythonExtractor +from tests.test_utils.config import conf_vars pytestmark = pytest.mark.db_test @@ -52,6 +55,12 @@ class CompleteRunFacet(BaseFacet): FINISHED_FACETS: dict[str, BaseFacet] = {"complete": CompleteRunFacet(True)} +class ExampleExtractor(BaseExtractor): + @classmethod + def get_operator_classnames(cls): + return ["ExampleOperator"] + + class ExampleOperator(BaseOperator): def execute(self, context) -> Any: pass @@ -221,6 +230,24 @@ def test_extraction_without_on_start(): ) +@mock.patch.dict( + os.environ, + {"OPENLINEAGE_EXTRACTORS": "tests.providers.openlineage.extractors.test_base.ExampleExtractor"}, +) +def test_extractors_env_var(): + extractor = ExtractorManager().get_extractor_class(ExampleOperator(task_id="example")) + assert extractor is ExampleExtractor + + +@mock.patch.dict(os.environ, {"OPENLINEAGE_EXTRACTORS": "no.such.extractor"}) +@conf_vars( + {("openlineage", "extractors"): "tests.providers.openlineage.extractors.test_base.ExampleExtractor"} +) +def test_config_has_precedence_over_env_var(): + extractor = ExtractorManager().get_extractor_class(ExampleOperator(task_id="example")) + assert extractor is ExampleExtractor + + def test_does_not_use_default_extractor_when_not_a_method(): extractor_class = ExtractorManager().get_extractor_class(BrokenOperator(task_id="a")) assert extractor_class is None