Skip to content

Commit

Permalink
fix: Fix parent id macro and remove unused utils (#37877)
Browse files Browse the repository at this point in the history
  • Loading branch information
kacpermuda committed Mar 5, 2024
1 parent b541f55 commit 2852976
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 146 deletions.
26 changes: 10 additions & 16 deletions airflow/providers/openlineage/plugins/macros.py
Expand Up @@ -16,17 +16,14 @@
# under the License.
from __future__ import annotations

import os
import typing

from airflow.configuration import conf
from airflow.providers.openlineage.plugins.adapter import OpenLineageAdapter
from airflow.providers.openlineage.plugins.adapter import _DAG_NAMESPACE, OpenLineageAdapter
from airflow.providers.openlineage.utils.utils import get_job_name

if typing.TYPE_CHECKING:
from airflow.models import TaskInstance

_JOB_NAMESPACE = conf.get("openlineage", "namespace", fallback=os.getenv("OPENLINEAGE_NAMESPACE", "default"))


def lineage_run_id(task_instance: TaskInstance):
"""
Expand All @@ -46,21 +43,18 @@ def lineage_run_id(task_instance: TaskInstance):
)


def lineage_parent_id(run_id: str, task_instance: TaskInstance):
def lineage_parent_id(task_instance: TaskInstance):
"""
Macro function which returns the generated job and run id for a given task.
Macro function which returns a unique identifier of given task that can be used to create ParentRunFacet.
This can be used to forward the ids from a task to a child run so the job
hierarchy is preserved. Child run can create ParentRunFacet from those ids.
This identifier is composed of the namespace, job name, and generated run id for given task, structured
as '{namespace}/{job_name}/{run_id}'. This can be used to forward task information from a task to a child
run so the job hierarchy is preserved. Child run can easily create ParentRunFacet from these information.
.. seealso::
For more information on how to use this macro, take a look at the guide:
:ref:`howto/macros:openlineage`
"""
job_name = OpenLineageAdapter.build_task_instance_run_id(
dag_id=task_instance.dag_id,
task_id=task_instance.task.task_id,
execution_date=task_instance.execution_date,
try_number=task_instance.try_number,
)
return f"{_JOB_NAMESPACE}/{job_name}/{run_id}"
job_name = get_job_name(task_instance.task)
run_id = lineage_run_id(task_instance)
return f"{_DAG_NAMESPACE}/{job_name}/{run_id}"
85 changes: 1 addition & 84 deletions airflow/providers/openlineage/utils/utils.py
Expand Up @@ -24,7 +24,6 @@
from contextlib import suppress
from functools import wraps
from typing import TYPE_CHECKING, Any, Iterable
from urllib.parse import parse_qsl, urlencode, urlparse, urlunparse

import attrs
from attrs import asdict
Expand All @@ -42,101 +41,19 @@
from airflow.utils.log.secrets_masker import Redactable, Redacted, SecretsMasker, should_hide_value_for_key

if TYPE_CHECKING:
from airflow.models import DAG, BaseOperator, Connection, DagRun, TaskInstance
from airflow.models import DAG, BaseOperator, DagRun, TaskInstance


log = logging.getLogger(__name__)
_NOMINAL_TIME_FORMAT = "%Y-%m-%dT%H:%M:%S.%fZ"


def openlineage_job_name(dag_id: str, task_id: str) -> str:
return f"{dag_id}.{task_id}"


def get_operator_class(task: BaseOperator) -> type:
if task.__class__.__name__ in ("DecoratedMappedOperator", "MappedOperator"):
return task.operator_class
return task.__class__


def to_json_encodable(task: BaseOperator) -> dict[str, object]:
def _task_encoder(obj):
from airflow.models import DAG

if isinstance(obj, datetime.datetime):
return obj.isoformat()
elif isinstance(obj, DAG):
return {
"dag_id": obj.dag_id,
"tags": obj.tags,
"schedule_interval": obj.schedule_interval,
"timetable": obj.timetable.serialize(),
}
else:
return str(obj)

return json.loads(json.dumps(task.__dict__, default=_task_encoder))


def url_to_https(url) -> str | None:
# Ensure URL exists
if not url:
return None

base_url = None
if url.startswith("git@"):
part = url.split("git@")[1:2]
if part:
base_url = f'https://{part[0].replace(":", "/", 1)}'
elif url.startswith("https://"):
base_url = url

if not base_url:
raise ValueError(f"Unable to extract location from: {url}")

if base_url.endswith(".git"):
base_url = base_url[:-4]
return base_url


def redacted_connection_uri(conn: Connection, filtered_params=None, filtered_prefixes=None):
"""
Return the connection URI for the given Connection.
This method additionally filters URI by removing query parameters that are known to carry sensitive data
like username, password, access key.
"""
if filtered_prefixes is None:
filtered_prefixes = []
if filtered_params is None:
filtered_params = []

def filter_key_params(k: str):
return k not in filtered_params and any(substr in k for substr in filtered_prefixes)

conn_uri = conn.get_uri()
parsed = urlparse(conn_uri)

# Remove username and password
netloc = f"{parsed.hostname}" + (f":{parsed.port}" if parsed.port else "")
parsed = parsed._replace(netloc=netloc)
if parsed.query:
query_dict = dict(parse_qsl(parsed.query))
if conn.EXTRA_KEY in query_dict:
query_dict = json.loads(query_dict[conn.EXTRA_KEY])
filtered_qs = {k: v for k, v in query_dict.items() if not filter_key_params(k)}
parsed = parsed._replace(query=urlencode(filtered_qs))
return urlunparse(parsed)


def get_connection(conn_id) -> Connection | None:
from airflow.hooks.base import BaseHook

with suppress(Exception):
return BaseHook.get_connection(conn_id=conn_id)
return None


def get_job_name(task):
return f"{task.dag_id}.{task.task_id}"

Expand Down
19 changes: 9 additions & 10 deletions tests/providers/openlineage/plugins/test_macros.py
Expand Up @@ -37,16 +37,15 @@ def test_lineage_run_id():
assert actual == expected


def test_lineage_parent_id():
@mock.patch("airflow.providers.openlineage.plugins.macros.lineage_run_id")
def test_lineage_parent_id(mock_run_id):
mock_run_id.return_value = "run_id"
task = mock.MagicMock(
dag_id="dag_id", execution_date="execution_date", try_number=1, task=mock.MagicMock(task_id="task_id")
)
actual = lineage_parent_id(run_id="run_id", task_instance=task)
job_name = str(
uuid.uuid3(
uuid.NAMESPACE_URL,
f"{_DAG_NAMESPACE}.dag_id.task_id.execution_date.1",
)
dag_id="dag_id",
execution_date="execution_date",
try_number=1,
task=mock.MagicMock(task_id="task_id", dag_id="dag_id"),
)
expected = f"{_DAG_NAMESPACE}/{job_name}/run_id"
actual = lineage_parent_id(task_instance=task)
expected = f"{_DAG_NAMESPACE}/dag_id.task_id/run_id"
assert actual == expected
36 changes: 0 additions & 36 deletions tests/providers/openlineage/plugins/test_utils.py
Expand Up @@ -18,7 +18,6 @@

import datetime
import json
import os
import uuid
from json import JSONEncoder
from typing import Any
Expand All @@ -29,23 +28,15 @@
from pkg_resources import parse_version

from airflow.models import DAG as AIRFLOW_DAG, DagModel
from airflow.operators.empty import EmptyOperator
from airflow.providers.openlineage.utils.utils import (
InfoJsonEncodable,
OpenLineageRedactor,
_is_name_redactable,
get_connection,
to_json_encodable,
url_to_https,
)
from airflow.utils import timezone
from airflow.utils.log.secrets_masker import _secrets_masker
from airflow.utils.state import State

AIRFLOW_CONN_ID = "test_db"
AIRFLOW_CONN_URI = "postgres://localhost:5432/testdb"
SNOWFLAKE_CONN_URI = "snowflake://12345.us-east-1.snowflakecomputing.com/MyTestRole?extra__snowflake__account=12345&extra__snowflake__database=TEST_DB&extra__snowflake__insecure_mode=false&extra__snowflake__region=us-east-1&extra__snowflake__role=MyTestRole&extra__snowflake__warehouse=TEST_WH&extra__snowflake__aws_access_key_id=123456&extra__snowflake__aws_secret_access_key=abcdefg"


class SafeStrDict(dict):
def __str__(self):
Expand All @@ -59,21 +50,6 @@ def __str__(self):
return str(dict(castable))


def test_get_connection():
os.environ["AIRFLOW_CONN_DEFAULT"] = AIRFLOW_CONN_URI

conn = get_connection("default")
assert conn.host == "localhost"
assert conn.port == 5432
assert conn.conn_type == "postgres"
assert conn


def test_url_to_https_no_url():
assert url_to_https(None) is None
assert url_to_https("") is None


@pytest.mark.db_test
def test_get_dagrun_start_end():
start_date = datetime.datetime(2022, 1, 1)
Expand Down Expand Up @@ -105,18 +81,6 @@ def test_parse_version():
assert parse_version("2.2.4.dev0") < parse_version("2.3.0.dev0")


def test_to_json_encodable():
dag = AIRFLOW_DAG(
dag_id="test_dag", schedule_interval="*/2 * * * *", start_date=datetime.datetime.now(), catchup=False
)
task = EmptyOperator(task_id="test_task", dag=dag)

encodable = to_json_encodable(task)
encoded = json.dumps(encodable)
decoded = json.loads(encoded)
assert decoded == encodable


def test_safe_dict():
assert str(SafeStrDict({"a": 1})) == str({"a": 1})

Expand Down

0 comments on commit 2852976

Please sign in to comment.