diff --git a/airflow/providers/openlineage/extractors/base.py b/airflow/providers/openlineage/extractors/base.py
index 3e18bb2399f21..d87334f48657b 100644
--- a/airflow/providers/openlineage/extractors/base.py
+++ b/airflow/providers/openlineage/extractors/base.py
@@ -70,6 +70,13 @@ def disabled_operators(self) -> set[str]:
             operator.strip() for operator in conf.get("openlineage", "disabled_for_operators").split(";")
         )
 
+    @cached_property
+    def _is_operator_disabled(self) -> bool:
+        fully_qualified_class_name = (
+            self.operator.__class__.__module__ + "." + self.operator.__class__.__name__
+        )
+        return fully_qualified_class_name in self.disabled_operators
+
     def validate(self):
         assert self.operator.task_type in self.get_operator_classnames()
 
@@ -78,10 +85,7 @@ def _execute_extraction(self) -> OperatorLineage | None:
         ...
 
     def extract(self) -> OperatorLineage | None:
-        fully_qualified_class_name = (
-            self.operator.__class__.__module__ + "." + self.operator.__class__.__name__
-        )
-        if fully_qualified_class_name in self.disabled_operators:
+        if self._is_operator_disabled:
             self.log.debug(
                 f"Skipping extraction for operator {self.operator.task_type} "
                 "due to its presence in [openlineage] openlineage_disabled_for_operators."
@@ -123,6 +127,12 @@ def _execute_extraction(self) -> OperatorLineage | None:
             return None
 
     def extract_on_complete(self, task_instance) -> OperatorLineage | None:
+        if self._is_operator_disabled:
+            self.log.debug(
+                f"Skipping extraction for operator {self.operator.task_type} "
+                "due to its presence in [openlineage] openlineage_disabled_for_operators."
+            )
+            return None
         if task_instance.state == TaskInstanceState.FAILED:
             on_failed = getattr(self.operator, "get_openlineage_facets_on_failure", None)
             if on_failed and callable(on_failed):
diff --git a/tests/providers/openlineage/extractors/test_base.py b/tests/providers/openlineage/extractors/test_base.py
index 309e0c1a79f52..35d51ee2937af 100644
--- a/tests/providers/openlineage/extractors/test_base.py
+++ b/tests/providers/openlineage/extractors/test_base.py
@@ -285,3 +285,17 @@ def test_default_extractor_uses_wrong_operatorlineage_class():
     assert (
         ExtractorManager().extract_metadata(mock.MagicMock(), operator, complete=False) == OperatorLineage()
     )
+
+
+@mock.patch.dict(
+    os.environ,
+    {
+        "AIRFLOW__OPENLINEAGE__DISABLED_FOR_OPERATORS": "tests.providers.openlineage.extractors.test_base.ExampleOperator"
+    },
+)
+def test_default_extraction_disabled_operator():
+    extractor = DefaultExtractor(ExampleOperator(task_id="test"))
+    metadata = extractor.extract()
+    assert metadata is None
+    metadata = extractor.extract_on_complete(None)
+    assert metadata is None
diff --git a/tests/providers/openlineage/extractors/test_bash.py b/tests/providers/openlineage/extractors/test_bash.py
index 4919f2b873650..b5fe07741e60c 100644
--- a/tests/providers/openlineage/extractors/test_bash.py
+++ b/tests/providers/openlineage/extractors/test_bash.py
@@ -117,3 +117,16 @@ def test_extract_dag_bash_command_env_does_not_disable_on_random_string():
 def test_extract_dag_bash_command_conf_does_not_disable_on_random_string():
     extractor = BashExtractor(bash_task)
     assert extractor.extract().job_facets["sourceCode"] == SourceCodeJobFacet("bash", "ls -halt && exit 0")
+
+
+@patch.dict(
+    os.environ,
+    {"AIRFLOW__OPENLINEAGE__DISABLED_FOR_OPERATORS": "airflow.operators.bash.BashOperator"},
+)
+def test_bash_extraction_disabled_operator():
+    operator = BashOperator(task_id="taskid", bash_command="echo 1;")
+    extractor = BashExtractor(operator)
+    metadata = extractor.extract()
+    assert metadata is None
+    metadata = extractor.extract_on_complete(None)
+    assert metadata is None
diff --git a/tests/providers/openlineage/extractors/test_python.py b/tests/providers/openlineage/extractors/test_python.py
index f90366e58e147..5a607e14684af 100644
--- a/tests/providers/openlineage/extractors/test_python.py
+++ b/tests/providers/openlineage/extractors/test_python.py
@@ -87,6 +87,8 @@ def test_python_extraction_disabled_operator():
     extractor = PythonExtractor(operator)
     metadata = extractor.extract()
     assert metadata is None
+    metadata = extractor.extract_on_complete(None)
+    assert metadata is None
 
 
 @patch.dict(os.environ, {"OPENLINEAGE_AIRFLOW_DISABLE_SOURCE_CODE": "False"})