Skip to content

Commit

Permalink
fix: Check if operator is disabled in DefaultExtractor.extract_on_com…
Browse files Browse the repository at this point in the history
…plete (#37392)
  • Loading branch information
kacpermuda committed Feb 14, 2024
1 parent 7461ac7 commit 61f0adf
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 4 deletions.
18 changes: 14 additions & 4 deletions airflow/providers/openlineage/extractors/base.py
Expand Up @@ -70,6 +70,13 @@ def disabled_operators(self) -> set[str]:
operator.strip() for operator in conf.get("openlineage", "disabled_for_operators").split(";")
)

@cached_property
def _is_operator_disabled(self) -> bool:
    """Report whether this extractor's operator appears in ``disabled_operators``.

    The operator is identified by the module of its class joined to the class
    name with a dot; the result is ``True`` when that identifier is a member
    of ``self.disabled_operators``.
    """
    operator_cls = self.operator.__class__
    qualified_name = ".".join((operator_cls.__module__, operator_cls.__name__))
    return qualified_name in self.disabled_operators

def validate(self):
    """Assert that the operator's ``task_type`` is one of ``get_operator_classnames()``.

    Raises ``AssertionError`` when the task type is not among the names
    returned by ``self.get_operator_classnames()``.
    """
    task_type = self.operator.task_type
    assert task_type in self.get_operator_classnames()

Expand All @@ -78,10 +85,7 @@ def _execute_extraction(self) -> OperatorLineage | None:
...

def extract(self) -> OperatorLineage | None:
fully_qualified_class_name = (
self.operator.__class__.__module__ + "." + self.operator.__class__.__name__
)
if fully_qualified_class_name in self.disabled_operators:
if self._is_operator_disabled:
self.log.debug(
f"Skipping extraction for operator {self.operator.task_type} "
"due to its presence in [openlineage] openlineage_disabled_for_operators."
Expand Down Expand Up @@ -123,6 +127,12 @@ def _execute_extraction(self) -> OperatorLineage | None:
return None

def extract_on_complete(self, task_instance) -> OperatorLineage | None:
if self._is_operator_disabled:
self.log.debug(
f"Skipping extraction for operator {self.operator.task_type} "
"due to its presence in [openlineage] openlineage_disabled_for_operators."
)
return None
if task_instance.state == TaskInstanceState.FAILED:
on_failed = getattr(self.operator, "get_openlineage_facets_on_failure", None)
if on_failed and callable(on_failed):
Expand Down
14 changes: 14 additions & 0 deletions tests/providers/openlineage/extractors/test_base.py
Expand Up @@ -285,3 +285,17 @@ def test_default_extractor_uses_wrong_operatorlineage_class():
assert (
ExtractorManager().extract_metadata(mock.MagicMock(), operator, complete=False) == OperatorLineage()
)


@mock.patch.dict(
    os.environ,
    {
        "AIRFLOW__OPENLINEAGE__DISABLED_FOR_OPERATORS": "tests.providers.openlineage.extractors.test_base.ExampleOperator"
    },
)
def test_default_extraction_disabled_operator():
    """Both extraction entry points return None when the operator is listed as disabled."""
    extractor = DefaultExtractor(ExampleOperator(task_id="test"))
    # Check the start-time path first, then the completion path, on the same extractor.
    assert extractor.extract() is None
    assert extractor.extract_on_complete(None) is None
13 changes: 13 additions & 0 deletions tests/providers/openlineage/extractors/test_bash.py
Expand Up @@ -117,3 +117,16 @@ def test_extract_dag_bash_command_env_does_not_disable_on_random_string():
def test_extract_dag_bash_command_conf_does_not_disable_on_random_string():
extractor = BashExtractor(bash_task)
assert extractor.extract().job_facets["sourceCode"] == SourceCodeJobFacet("bash", "ls -halt && exit 0")


@patch.dict(
    os.environ,
    {"AIRFLOW__OPENLINEAGE__DISABLED_FOR_OPERATORS": "airflow.operators.bash.BashOperator"},
)
def test_bash_extraction_disabled_operator():
    """BashExtractor returns None from both extraction paths when BashOperator is disabled."""
    bash_operator = BashOperator(task_id="taskid", bash_command="echo 1;")
    extractor = BashExtractor(bash_operator)
    # Check the start-time path first, then the completion path, on the same extractor.
    assert extractor.extract() is None
    assert extractor.extract_on_complete(None) is None
2 changes: 2 additions & 0 deletions tests/providers/openlineage/extractors/test_python.py
Expand Up @@ -87,6 +87,8 @@ def test_python_extraction_disabled_operator():
extractor = PythonExtractor(operator)
metadata = extractor.extract()
assert metadata is None
metadata = extractor.extract_on_complete(None)
assert metadata is None


@patch.dict(os.environ, {"OPENLINEAGE_AIRFLOW_DISABLE_SOURCE_CODE": "False"})
Expand Down

0 comments on commit 61f0adf

Please sign in to comment.