Skip to content

Commit

Permalink
Allow to disable openlineage at operator level (#33685)
Browse files Browse the repository at this point in the history
  • Loading branch information
RNHTTR committed Aug 29, 2023
1 parent 7d267fb commit 0d49d1f
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 11 deletions.
28 changes: 26 additions & 2 deletions airflow/providers/openlineage/extractors/base.py
Expand Up @@ -18,10 +18,12 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from functools import cached_property
from typing import TYPE_CHECKING

from attrs import Factory, define

from airflow.configuration import conf
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.state import TaskInstanceState

Expand Down Expand Up @@ -62,12 +64,30 @@ def get_operator_classnames(cls) -> list[str]:
"""
raise NotImplementedError()

@cached_property
def disabled_operators(self) -> set[str]:
    """Return the operator class paths for which extraction is disabled.

    The value is read from the ``disabled_for_operators`` option of the
    ``[openlineage]`` configuration section, a semicolon-separated string.
    Surrounding whitespace is stripped from every entry, and empty entries
    (an unset option, a trailing ``;`` or a doubled ``;;``) are dropped so
    the resulting set never contains ``""``.

    The result is computed once per instance (``cached_property``).
    """
    # ``fallback`` keeps a missing option from raising: nothing is disabled.
    raw_value = conf.get("openlineage", "disabled_for_operators", fallback="")
    stripped_entries = (entry.strip() for entry in raw_value.split(";"))
    return {entry for entry in stripped_entries if entry}

def validate(self):
    """Assert that the wrapped operator is one this extractor declares.

    The operator's ``task_type`` must be among the names returned by
    ``get_operator_classnames()``; otherwise the ``assert`` fails with
    ``AssertionError``.
    """
    task_type = self.operator.task_type
    supported_classnames = self.get_operator_classnames()
    assert task_type in supported_classnames

@abstractmethod
def _execute_extraction(self) -> OperatorLineage | None:
...

def extract(self) -> OperatorLineage | None:
    """Extract lineage for the operator unless extraction is disabled for it.

    The operator's fully qualified class name (``module.ClassName``) is
    looked up in ``self.disabled_operators``. When it is present a warning
    is logged and extraction is skipped; otherwise the work is delegated to
    ``_execute_extraction()``.

    Returns:
        Whatever ``_execute_extraction()`` returns, or ``None`` when the
        operator is listed as disabled.
    """
    operator_class = self.operator.__class__
    fully_qualified_class_name = f"{operator_class.__module__}.{operator_class.__name__}"
    if fully_qualified_class_name in self.disabled_operators:
        # Name the option exactly as it is read from the configuration
        # (section ``openlineage``, key ``disabled_for_operators``) so the
        # message points users at a setting that actually exists.
        self.log.warning(
            "Skipping extraction for operator %s "
            "due to its presence in [openlineage] disabled_for_operators.",
            self.operator.task_type,
        )
        return None
    return self._execute_extraction()

def extract_on_complete(self, task_instance) -> OperatorLineage | None:
    """Return lineage for a finished task by forwarding to ``extract()``.

    ``task_instance`` is accepted but not used by this implementation.
    """
    lineage = self.extract()
    return lineage
Expand All @@ -85,7 +105,7 @@ def get_operator_classnames(cls) -> list[str]:
"""
return []

def extract(self) -> OperatorLineage | None:
def _execute_extraction(self) -> OperatorLineage | None:
# OpenLineage methods are optional - if there's no method, return None
try:
return self._get_openlineage_facets(self.operator.get_openlineage_facets_on_start) # type: ignore
Expand All @@ -96,6 +116,10 @@ def extract(self) -> OperatorLineage | None:
)
return None
except AttributeError:
self.log.warning(
f"Operator {self.operator.task_type} does not have the "
"get_openlineage_facets_on_start method."
)
return None

def extract_on_complete(self, task_instance) -> OperatorLineage | None:
Expand Down
10 changes: 8 additions & 2 deletions airflow/providers/openlineage/extractors/bash.py
Expand Up @@ -24,7 +24,10 @@
UnknownOperatorAttributeRunFacet,
UnknownOperatorInstance,
)
from airflow.providers.openlineage.utils.utils import get_filtered_unknown_operator_keys, is_source_enabled
from airflow.providers.openlineage.utils.utils import (
get_filtered_unknown_operator_keys,
is_source_enabled,
)

"""
:meta private:
Expand All @@ -46,7 +49,7 @@ class BashExtractor(BaseExtractor):
def get_operator_classnames(cls) -> list[str]:
return ["BashOperator"]

def extract(self) -> OperatorLineage | None:
def _execute_extraction(self) -> OperatorLineage | None:
job_facets: dict = {}
if is_source_enabled():
job_facets = {
Expand All @@ -73,3 +76,6 @@ def extract(self) -> OperatorLineage | None:
)
},
)

def extract(self) -> OperatorLineage | None:
    """Forward to the parent class's ``extract()`` and return its result unchanged."""
    lineage = super().extract()
    return lineage
15 changes: 12 additions & 3 deletions airflow/providers/openlineage/extractors/python.py
Expand Up @@ -27,7 +27,10 @@
UnknownOperatorAttributeRunFacet,
UnknownOperatorInstance,
)
from airflow.providers.openlineage.utils.utils import get_filtered_unknown_operator_keys, is_source_enabled
from airflow.providers.openlineage.utils.utils import (
get_filtered_unknown_operator_keys,
is_source_enabled,
)

"""
:meta private:
Expand All @@ -49,7 +52,7 @@ class PythonExtractor(BaseExtractor):
def get_operator_classnames(cls) -> list[str]:
return ["PythonOperator"]

def extract(self) -> OperatorLineage | None:
def _execute_extraction(self) -> OperatorLineage | None:
source_code = self.get_source_code(self.operator.python_callable)
job_facet: dict = {}
if is_source_enabled() and source_code:
Expand Down Expand Up @@ -84,5 +87,11 @@ def get_source_code(self, callable: Callable) -> str | None:
# Trying to extract source code of builtin_function_or_method
return str(callable)
except OSError:
self.log.exception("Can't get source code facet of PythonOperator %s", self.operator.task_id)
self.log.exception(
"Can't get source code facet of PythonOperator %s",
self.operator.task_id,
)
return None

def extract(self) -> OperatorLineage | None:
    """Forward to the parent class's ``extract()`` and return its result unchanged."""
    lineage = super().extract()
    return lineage
15 changes: 11 additions & 4 deletions airflow/providers/openlineage/provider.yaml
Expand Up @@ -60,6 +60,13 @@ config:
example: ~
default: "False"
version_added: ~
disabled_for_operators:
description: |
Semicolon-separated list of fully qualified Airflow operator class names
(module path plus class name) for which OpenLineage extraction is disabled.
type: string
example: "airflow.operators.bash.BashOperator;airflow.operators.python.PythonOperator"
default: ""
version_added: 1.1.0
namespace:
description: |
OpenLineage namespace
Expand All @@ -69,10 +76,10 @@ config:
default: ~
extractors:
description: |
Comma-separated paths to custom OpenLineage extractors.
Semicolon separated paths to custom OpenLineage extractors.
type: string
example: full.path.to.ExtractorClass;full.path.to.AnotherExtractorClass
default: ''
default: ""
version_added: ~
config_path:
description: |
Expand All @@ -81,7 +88,7 @@ config:
version_added: ~
type: string
example: ~
default: ''
default: ""
transport:
description: |
OpenLineage Client transport configuration. It should contain type
Expand All @@ -94,7 +101,7 @@ config:
* Console
type: string
example: '{"type": "http", "url": "http://localhost:5000"}'
default: ''
default: ""
version_added: ~
disable_source_code:
description: |
Expand Down
11 changes: 11 additions & 0 deletions tests/providers/openlineage/extractors/test_python_extractor.py
Expand Up @@ -72,6 +72,17 @@ def test_extract_operator_code_disables_on_no_env():
assert "sourceCode" not in extractor.extract().job_facets


@patch.dict(
    os.environ,
    {"AIRFLOW__OPENLINEAGE__DISABLED_FOR_OPERATORS": "airflow.operators.python.PythonOperator"},
)
def test_python_extraction_disabled_operator():
    """Extraction returns ``None`` when PythonOperator is listed as disabled."""
    python_task = PythonOperator(task_id="taskid", python_callable=callable)
    assert PythonExtractor(python_task).extract() is None


@patch.dict(os.environ, {"OPENLINEAGE_AIRFLOW_DISABLE_SOURCE_CODE": "False"})
def test_extract_operator_code_enables_on_false_env():
operator = PythonOperator(task_id="taskid", python_callable=callable)
Expand Down

0 comments on commit 0d49d1f

Please sign in to comment.