From 219ec36040210e5188b9c6b7586b75af8158e9e1 Mon Sep 17 00:00:00 2001 From: Kacper Muda Date: Tue, 25 Nov 2025 17:58:20 +0100 Subject: [PATCH] feat: auto-inject OpenLineage parent info into TriggerDagRunOperator conf When OpenLineage is enabled, add injection of OpenLineage parent task metadata into triggered DAG run conf to enable improved lineage tracking across DAG boundaries. --- dev/breeze/tests/test_selective_checks.py | 8 +- .../providers/openlineage/utils/utils.py | 1 + providers/standard/pyproject.toml | 8 + .../standard/operators/trigger_dagrun.py | 30 +- .../providers/standard/utils/openlineage.py | 185 +++++++ .../standard/operators/test_trigger_dagrun.py | 422 +++++++++++++++- .../unit/standard/utils/test_openlineage.py | 470 ++++++++++++++++++ 7 files changed, 1117 insertions(+), 7 deletions(-) create mode 100644 providers/standard/src/airflow/providers/standard/utils/openlineage.py create mode 100644 providers/standard/tests/unit/standard/utils/test_openlineage.py diff --git a/dev/breeze/tests/test_selective_checks.py b/dev/breeze/tests/test_selective_checks.py index 40256e869a58d..8008b84086a06 100644 --- a/dev/breeze/tests/test_selective_checks.py +++ b/dev/breeze/tests/test_selective_checks.py @@ -1941,7 +1941,7 @@ def test_expected_output_push( { "selected-providers-list-as-string": "amazon common.compat common.io common.sql " "databricks dbt.cloud ftp google microsoft.mssql mysql " - "openlineage oracle postgres sftp snowflake trino", + "openlineage oracle postgres sftp snowflake standard trino", "all-python-versions": f"['{DEFAULT_PYTHON_MAJOR_MINOR_VERSION}']", "all-python-versions-list-as-string": DEFAULT_PYTHON_MAJOR_MINOR_VERSION, "ci-image-build": "true", @@ -1952,7 +1952,7 @@ def test_expected_output_push( "docs-build": "true", "docs-list-as-string": "apache-airflow task-sdk amazon common.compat common.io common.sql " "databricks dbt.cloud ftp google microsoft.mssql mysql " - "openlineage oracle postgres sftp snowflake trino", + "openlineage oracle postgres sftp snowflake standard trino", "skip-prek-hooks": ALL_SKIPPED_COMMITS_ON_NO_CI_IMAGE, "run-kubernetes-tests": "false", "upgrade-to-newer-dependencies": "false", @@ -1960,10 +1960,10 @@ def test_expected_output_push( "providers-test-types-list-as-strings-in-json": json.dumps( [ { - "description": "amazon...google", + "description": "amazon...standard", "test_types": "Providers[amazon] Providers[common.compat,common.io,common.sql," "databricks,dbt.cloud,ftp,microsoft.mssql,mysql,openlineage,oracle," - "postgres,sftp,snowflake,trino] Providers[google]", + "postgres,sftp,snowflake,trino] Providers[google] Providers[standard]", } ] ), diff --git a/providers/openlineage/src/airflow/providers/openlineage/utils/utils.py b/providers/openlineage/src/airflow/providers/openlineage/utils/utils.py index 16f6c255dfc6c..477d088d3bfc2 100644 --- a/providers/openlineage/src/airflow/providers/openlineage/utils/utils.py +++ b/providers/openlineage/src/airflow/providers/openlineage/utils/utils.py @@ -803,6 +803,7 @@ class TaskInfo(InfoJsonEncodable): "sla", "task_id", "trigger_dag_id", + "trigger_run_id", "external_dag_id", "external_task_id", "trigger_rule", diff --git a/providers/standard/pyproject.toml b/providers/standard/pyproject.toml index 8bc91fa16e94b..596bb9bf69d94 100644 --- a/providers/standard/pyproject.toml +++ b/providers/standard/pyproject.toml @@ -62,12 +62,20 @@ dependencies = [ "apache-airflow-providers-common-compat>=1.8.0", ] +# The optional dependencies should be modified in place in the generated file +# 
Any change in the dependencies is preserved when the file is regenerated +[project.optional-dependencies] +"openlineage" = [ + "apache-airflow-providers-openlineage" +] + [dependency-groups] dev = [ "apache-airflow", "apache-airflow-task-sdk", "apache-airflow-devel-common", "apache-airflow-providers-common-compat", + "apache-airflow-providers-openlineage", # Additional devel dependencies (do not remove this line and add extra development dependencies) "apache-airflow-providers-mysql", ] diff --git a/providers/standard/src/airflow/providers/standard/operators/trigger_dagrun.py b/providers/standard/src/airflow/providers/standard/operators/trigger_dagrun.py index c0f8709fa87f9..239e8ae540f3a 100644 --- a/providers/standard/src/airflow/providers/standard/operators/trigger_dagrun.py +++ b/providers/standard/src/airflow/providers/standard/operators/trigger_dagrun.py @@ -40,6 +40,7 @@ from airflow.models.serialized_dag import SerializedDagModel from airflow.providers.common.compat.sdk import BaseOperatorLink, XCom, timezone from airflow.providers.standard.triggers.external_task import DagStateTrigger +from airflow.providers.standard.utils.openlineage import safe_inject_openlineage_properties_into_dagrun_conf from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS, BaseOperator from airflow.utils.state import DagRunState from airflow.utils.types import DagRunType @@ -136,6 +137,12 @@ class TriggerDagRunOperator(BaseOperator): :param fail_when_dag_is_paused: If the dag to trigger is paused, DagIsPaused will be raised. :param deferrable: If waiting for completion, whether or not to defer the task until done, default is ``False``. + :param openlineage_inject_parent_info: whether to include OpenLineage metadata about the parent task + in the triggered DAG run's conf, enabling improved lineage tracking. The metadata is only injected + if OpenLineage is enabled and running. This option does not modify any other part of the conf, + and existing OpenLineage-related settings in the conf will not be overwritten. The injection process + is safeguarded against exceptions - if any error occurs during metadata injection, it is gracefully + handled and the conf remains unchanged - so it's safe to use. 
Default is ``True`` """ template_fields: Sequence[str] = ( @@ -165,6 +172,7 @@ def __init__( skip_when_already_exists: bool = False, fail_when_dag_is_paused: bool = False, deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), + openlineage_inject_parent_info: bool = True, **kwargs, ) -> None: super().__init__(**kwargs) @@ -184,6 +192,7 @@ def __init__( self.failed_states = [DagRunState.FAILED] self.skip_when_already_exists = skip_when_already_exists self.fail_when_dag_is_paused = fail_when_dag_is_paused + self.openlineage_inject_parent_info = openlineage_inject_parent_info self._defer = deferrable self.logical_date = logical_date if logical_date is NOTSET: @@ -214,6 +223,12 @@ def execute(self, context: Context): except (TypeError, JSONDecodeError): raise ValueError("conf parameter should be JSON Serializable %s", self.conf) + if self.openlineage_inject_parent_info: + self.log.debug("Checking if OpenLineage information can be safely injected into dagrun conf.") + self.conf = safe_inject_openlineage_properties_into_dagrun_conf( + dr_conf=self.conf, ti=context.get("ti") + ) + if self.trigger_run_id: run_id = str(self.trigger_run_id) else: @@ -226,6 +241,9 @@ def execute(self, context: Context): else: run_id = DagRun.generate_run_id(DagRunType.MANUAL, parsed_logical_date or timezone.utcnow()) # type: ignore[misc,call-arg] + # Save run_id as task attribute - to be used by listeners + self.trigger_run_id = run_id + if self.fail_when_dag_is_paused: dag_model = DagModel.get_current(self.trigger_dag_id) if not dag_model: @@ -237,9 +255,13 @@ def execute(self, context: Context): raise AirflowException(f"Dag {self.trigger_dag_id} is paused") if AIRFLOW_V_3_0_PLUS: - self._trigger_dag_af_3(context=context, run_id=run_id, parsed_logical_date=parsed_logical_date) + self._trigger_dag_af_3( + context=context, run_id=self.trigger_run_id, parsed_logical_date=parsed_logical_date + ) else: - self._trigger_dag_af_2(context=context, run_id=run_id, parsed_logical_date=parsed_logical_date) + self._trigger_dag_af_2( + context=context, run_id=self.trigger_run_id, parsed_logical_date=parsed_logical_date + ) def _trigger_dag_af_3(self, context, run_id, parsed_logical_date): from airflow.exceptions import DagRunTriggerException @@ -327,6 +349,10 @@ def _trigger_dag_af_2(self, context, run_id, parsed_logical_date): return def execute_complete(self, context: Context, event: tuple[str, dict[str, Any]]): + run_ids = event[1]["run_ids"] + # Re-set as attribute after coming back from deferral - to be used by listeners. + # Just a safety check on length, we should always have single run_id here. + self.trigger_run_id = run_ids[0] if len(run_ids) == 1 else None if AIRFLOW_V_3_0_PLUS: self._trigger_dag_run_af_3_execute_complete(event=event) else: diff --git a/providers/standard/src/airflow/providers/standard/utils/openlineage.py b/providers/standard/src/airflow/providers/standard/utils/openlineage.py new file mode 100644 index 0000000000000..3dd06b4eaece7 --- /dev/null +++ b/providers/standard/src/airflow/providers/standard/utils/openlineage.py @@ -0,0 +1,185 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import logging
+from typing import TYPE_CHECKING
+
+from airflow.exceptions import AirflowOptionalProviderFeatureException
+from airflow.providers.common.compat.openlineage.check import require_openlineage_version
+
+if TYPE_CHECKING:
+    from airflow.models import TaskInstance
+    from airflow.sdk.types import RuntimeTaskInstanceProtocol as RuntimeTI
+
+log = logging.getLogger(__name__)
+
+OPENLINEAGE_PROVIDER_MIN_VERSION = "2.8.0"
+
+
+def _is_openlineage_provider_accessible() -> bool:
+    """
+    Check if the OpenLineage provider is accessible.
+
+    This function attempts to import the necessary OpenLineage modules and checks if the provider
+    is enabled and the listener is available.
+
+    Returns:
+        bool: True if the OpenLineage provider is accessible, False otherwise.
+    """
+    try:
+        from airflow.providers.openlineage.conf import is_disabled
+        from airflow.providers.openlineage.plugins.listener import get_openlineage_listener
+    except (ImportError, AttributeError):
+        log.debug("OpenLineage provider could not be imported.")
+        return False
+
+    if is_disabled():
+        log.debug("OpenLineage provider is disabled.")
+        return False
+
+    if not get_openlineage_listener():
+        log.debug("OpenLineage listener could not be found.")
+        return False
+
+    return True
+
+
+@require_openlineage_version(provider_min_version=OPENLINEAGE_PROVIDER_MIN_VERSION)
+def _get_openlineage_parent_info(ti: TaskInstance | RuntimeTI) -> dict[str, str]:
+    """Get OpenLineage metadata about the parent task."""
+    from airflow.providers.openlineage.plugins.macros import (
+        lineage_job_name,
+        lineage_job_namespace,
+        lineage_root_job_name,
+        lineage_root_job_namespace,
+        lineage_root_run_id,
+        lineage_run_id,
+    )
+
+    return {
+        "parentRunId": lineage_run_id(ti),
+        "parentJobName": lineage_job_name(ti),
+        "parentJobNamespace": lineage_job_namespace(),
+        "rootParentRunId": lineage_root_run_id(ti),
+        "rootParentJobName": lineage_root_job_name(ti),
+        "rootParentJobNamespace": lineage_root_job_namespace(ti),
+    }
+
+
+def _inject_openlineage_parent_info_to_dagrun_conf(
+    dr_conf: dict | None, ol_parent_info: dict[str, str]
+) -> dict:
+    """
+    Inject OpenLineage parent and root run metadata into a DAG run configuration.
+
+    This function adds the parent and root job/run identifiers from ``ol_parent_info`` into the
+    `openlineage` section of the DAG run configuration. If an `openlineage` key already exists, it is
+    preserved and extended, but no existing parent or root identifiers are overwritten.
+
+    The function performs several safety checks:
+    - If the existing `openlineage` section is not a dictionary, conf is returned unmodified.
+    - If the `openlineage` section already contains any parent/root identifiers, conf is returned unmodified.
+
+    Args:
+        dr_conf: The original DAG run configuration dictionary or None.
+        ol_parent_info: OpenLineage metadata about the parent task.
+
+    Returns:
+        A modified DAG run conf with injected OpenLineage parent and root metadata,
+        or the original conf if injection is not possible.
+ """ + current_ol_dr_conf = {} + if isinstance(dr_conf, dict) and dr_conf.get("openlineage"): + current_ol_dr_conf = dr_conf["openlineage"] + if not isinstance(current_ol_dr_conf, dict): + log.warning( + "Existing 'openlineage' section of DagRun conf is not a dictionary; " + "skipping injection of parent metadata." + ) + return dr_conf + forbidden_keys = ( + "parentRunId", + "parentJobName", + "parentJobNamespace", + "rootParentRunId", + "rootJobName", + "rootJobNamespace", + ) + + if existing := [k for k in forbidden_keys if k in current_ol_dr_conf]: + log.warning( + "'openlineage' section of DagRun conf already contains parent or root " + "identifiers: `%s`; skipping injection to avoid overwriting existing values.", + ", ".join(existing), + ) + return dr_conf + + return {**(dr_conf or {}), **{"openlineage": {**ol_parent_info, **current_ol_dr_conf}}} + + +def safe_inject_openlineage_properties_into_dagrun_conf( + dr_conf: dict | None, ti: TaskInstance | RuntimeTI | None +) -> dict | None: + """ + Safely inject OpenLineage parent task metadata into a DAG run conf. + + This function checks whether the OpenLineage provider is accessible and supports parent information + injection. If so, it enriches the DAG run conf with OpenLineage metadata about the parent task + to improve lineage tracking. The function does not modify other conf fields, will not overwrite + any existing content, and safely returns the original configuration if OpenLineage is unavailable, + unsupported, or an error occurs during injection. + + :param dr_conf: The original DAG run configuration dictionary. + :param ti: The TaskInstance whose metadata may be injected. + + :return: A potentially enriched DAG run conf with OpenLineage parent information, + or the original conf if injection was skipped or failed. + """ + try: + if ti is None: + log.debug("Task instance not provided - dagrun conf not modified.") + return dr_conf + + if not _is_openlineage_provider_accessible(): + log.debug("OpenLineage provider not accessible - dagrun conf not modified.") + return dr_conf + + ol_parent_info = _get_openlineage_parent_info(ti=ti) + + log.info("Injecting openlineage parent task information into dagrun conf.") + new_conf = _inject_openlineage_parent_info_to_dagrun_conf( + dr_conf=dr_conf.copy() if isinstance(dr_conf, dict) else dr_conf, + ol_parent_info=ol_parent_info, + ) + return new_conf + except AirflowOptionalProviderFeatureException: + log.info( + "Current OpenLineage provider version doesn't support parent information in " + "the DagRun conf. Upgrade `apache-airflow-providers-openlineage>=%s` to use this feature. " + "DagRun conf has not been modified by OpenLineage.", + OPENLINEAGE_PROVIDER_MIN_VERSION, + ) + return dr_conf + except Exception as e: + log.warning( + "An error occurred while trying to inject OpenLineage information into dagrun conf. " + "DagRun conf has not been modified by OpenLineage. 
Error: %s", + str(e), + ) + log.debug("Error details: ", exc_info=e) + return dr_conf diff --git a/providers/standard/tests/unit/standard/operators/test_trigger_dagrun.py b/providers/standard/tests/unit/standard/operators/test_trigger_dagrun.py index ffb4d1524fb50..4242766291153 100644 --- a/providers/standard/tests/unit/standard/operators/test_trigger_dagrun.py +++ b/providers/standard/tests/unit/standard/operators/test_trigger_dagrun.py @@ -66,6 +66,8 @@ task = EmptyOperator(task_id='test', dag=dag) """ +OL_UTILS_PATH = "airflow.providers.standard.utils.openlineage" +TRIGGER_OP_PATH = "airflow.providers.standard.operators.trigger_dagrun" class TestDagRunOperator: @@ -137,9 +139,10 @@ def test_trigger_dagrun(self): ).rsplit("_", 1)[0] # rsplit because last few characters are random. assert exc_info.value.dag_run_id == expected_run_id + assert task.trigger_run_id == expected_run_id # run_id is saved as attribute @pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Implementation is different for Airflow 2 & 3") - @mock.patch("airflow.providers.standard.operators.trigger_dagrun.XCom.get_one") + @mock.patch(f"{TRIGGER_OP_PATH}.XCom.get_one") def test_extra_operator_link(self, mock_xcom_get_one, dag_maker): with dag_maker(TEST_DAG_ID, default_args={"start_date": DEFAULT_DATE}, serialized=True): task = TriggerDagRunOperator( @@ -255,6 +258,30 @@ def test_trigger_dag_run_execute_complete_should_fail(self): ), ) + @pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test only for Airflow 3") + def test_trigger_dag_run_execute_complete_re_set_run_id_attribute(self): + operator = TriggerDagRunOperator( + task_id="test_task", + trigger_dag_id=TRIGGERED_DAG_ID, + wait_for_completion=True, + poke_interval=10, + failed_states=[], + ) + assert operator.trigger_run_id is None + + try: + operator.execute_complete( + {}, + ( + "airflow.providers.standard.triggers.external_task.DagStateTrigger", + {"run_ids": ["run_id_1"], "run_id_1": "success"}, + ), + ) + except Exception as e: + pytest.fail(f"Error: {e}") + + assert operator.trigger_run_id == "run_id_1" + @pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Implementation is different for Airflow 2 & 3") def test_trigger_dag_run_with_fail_when_dag_is_paused_should_fail(self): with pytest.raises( @@ -301,6 +328,191 @@ def test_trigger_dagrun_with_str_conf_error(self): with pytest.raises(ValueError, match="conf parameter should be JSON Serializable"): task.execute(context={}) + @pytest.mark.parametrize("original_conf", (None, {}, {"foo": "bar"})) + @pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Implementation is different for Airflow 2 & 3") + @mock.patch(f"{TRIGGER_OP_PATH}.safe_inject_openlineage_properties_into_dagrun_conf") + def test_trigger_dagrun_conf_openlineage_injection_disabled_with_explicit_false_arg( + self, mock_inject, original_conf + ): + """Test that conf is not modified when openlineage_inject_parent_info=False.""" + with time_machine.travel("2025-02-18T08:04:46Z", tick=False): + task = TriggerDagRunOperator( + task_id="test_task", + trigger_dag_id=TRIGGERED_DAG_ID, + conf=original_conf, + openlineage_inject_parent_info=False, + ) + + with pytest.raises(DagRunTriggerException) as exc_info: + task.execute(context={"ti": mock.MagicMock()}) + + # Injection function should not be called + mock_inject.assert_not_called() + # Conf should remain unchanged + assert exc_info.value.conf == original_conf + + @pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Implementation is different for Airflow 2 & 3") + 
@mock.patch(f"{OL_UTILS_PATH}._is_openlineage_provider_accessible") + def test_trigger_dagrun_conf_openlineage_injection_disabled_when_ol_not_accessible( + self, mock_is_accessible + ): + """Test that conf is not modified when OpenLineage provider is not accessible.""" + original_conf = {"foo": "bar"} + # Simulate OL provider being disabled/not accessible + mock_is_accessible.return_value = False + + with time_machine.travel("2025-02-18T08:04:46Z", tick=False): + # openlineage_inject_parent_info defaults to True + task = TriggerDagRunOperator( + task_id="test_task", + trigger_dag_id=TRIGGERED_DAG_ID, + conf=original_conf, + ) + + ti = mock.MagicMock() + with pytest.raises(DagRunTriggerException) as exc_info: + task.execute(context={"ti": ti}) + + # Conf should remain unchanged when OL is unavailable + assert exc_info.value.conf == original_conf + + @pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Implementation is different for Airflow 2 & 3") + @pytest.mark.parametrize( + ("provider_version", "should_modify"), + [ + ("2.7.0", False), # Below minimum - conf not modified + ("2.7.9", False), # Below minimum - conf not modified + ("2.8.0", True), # Exactly minimum - conf modified + ("2.8.1", True), # Above minimum - conf modified + ], + ) + @mock.patch(f"{OL_UTILS_PATH}._is_openlineage_provider_accessible") + @mock.patch("importlib.metadata.version") + def test_trigger_dagrun_conf_openlineage_injection_disabled_for_older_ol_providers( + self, mock_version, mock_is_accessible, provider_version, should_modify + ): + """Test that conf is only modified when OpenLineage provider version is sufficient.""" + original_conf = {"foo": "bar"} + ol_parent_info = { + "parentRunId": "test-run-id", + "parentJobName": "test-job", + "parentJobNamespace": "test-ns", + "rootParentRunId": "test-root-run-id", + "rootParentJobName": "test-root-job", + "rootParentJobNamespace": "test-root-ns", + } + injected_conf = { + "foo": "bar", + "openlineage": ol_parent_info, + } + + def _mock_version(package): + if package == "apache-airflow-providers-openlineage": + return provider_version + raise Exception(f"Unexpected package: {package}") + + mock_version.side_effect = _mock_version + mock_is_accessible.return_value = True + + with time_machine.travel("2025-02-18T08:04:46Z", tick=False): + task = TriggerDagRunOperator( + task_id="test_task", + trigger_dag_id=TRIGGERED_DAG_ID, + conf=original_conf, + ) + + mock_ti = mock.MagicMock() + if should_modify: + # When version is sufficient, mock _get_openlineage_parent_info to return data + with mock.patch(f"{OL_UTILS_PATH}._get_openlineage_parent_info", return_value=ol_parent_info): + with pytest.raises(DagRunTriggerException) as exc_info: + task.execute(context={"ti": mock_ti}) + # Conf should be modified + assert exc_info.value.conf == injected_conf + else: + # When version is insufficient, _get_openlineage_parent_info will raise + with pytest.raises(DagRunTriggerException) as exc_info: + task.execute(context={"ti": mock_ti}) + # Conf should remain unchanged + assert exc_info.value.conf == original_conf + + @pytest.mark.parametrize( + "exception", + [ + Exception("Generic error during injection"), + ValueError("Invalid data format"), + RuntimeError("Runtime issue"), + ], + ) + @pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Implementation is different for Airflow 2 & 3") + @mock.patch(f"{OL_UTILS_PATH}._is_openlineage_provider_accessible") + def test_trigger_dagrun_conf_openlineage_injection_preserves_conf_on_exception( + self, mock_is_accessible, exception + ): + 
"""Test that original conf is preserved when any exception occurs during injection.""" + original_conf = {"foo": "bar"} + mock_is_accessible.return_value = True + + # Simulate any exception during injection (version check failure, runtime error, etc.) + with ( + mock.patch( + f"{OL_UTILS_PATH}._inject_openlineage_parent_info_to_dagrun_conf", + side_effect=exception, + ), + time_machine.travel("2025-02-18T08:04:46Z", tick=False), + ): + task = TriggerDagRunOperator( + task_id="test_task", + trigger_dag_id=TRIGGERED_DAG_ID, + conf=original_conf, + ) + + mock_ti = mock.MagicMock() + with pytest.raises(DagRunTriggerException) as exc_info: + task.execute(context={"ti": mock_ti}) + + # Conf should remain unchanged when any exception occurs during injection + assert exc_info.value.conf == original_conf + + @pytest.mark.parametrize("original_conf", (None, {}, {"foo": "bar"})) + @pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Implementation is different for Airflow 2 & 3") + @mock.patch(f"{OL_UTILS_PATH}._is_openlineage_provider_accessible") + @mock.patch(f"{OL_UTILS_PATH}._get_openlineage_parent_info") + def test_trigger_dagrun_conf_openlineage_injection_valid_data( + self, mock_get_parent_info, mock_is_accessible, original_conf + ): + """Test that OpenLineage injection works when OL is available and flag is True.""" + ol_parent_info = { + "rootParentRunId": "22222222-2222-2222-2222-222222222222", + "rootParentJobNamespace": "rootns", + "rootParentJobName": "rootjob", + "parentRunId": "33333333-3333-3333-3333-333333333333", + "parentJobNamespace": "parentns", + "parentJobName": "parentjob", + } + injected_conf = { + **(original_conf or {}), + "openlineage": ol_parent_info, + } + mock_is_accessible.return_value = True + mock_get_parent_info.return_value = ol_parent_info + + with time_machine.travel("2025-02-18T08:04:46Z", tick=False): + task = TriggerDagRunOperator( + task_id="test_task", + trigger_dag_id=TRIGGERED_DAG_ID, + conf=original_conf, + ) + + mock_ti = mock.MagicMock() + with pytest.raises(DagRunTriggerException) as exc_info: + task.execute(context={"ti": mock_ti}) + + # Conf should contain injected OpenLineage metadata + assert exc_info.value.conf == injected_conf + # Verify _get_openlineage_parent_info was called with ti + mock_get_parent_info.assert_called_once_with(ti=mock_ti) + # TODO: To be removed once the provider drops support for Airflow 2 @pytest.mark.skipif(AIRFLOW_V_3_0_PLUS, reason="Test only for Airflow 2") @@ -344,6 +556,17 @@ def test_trigger_dagrun(self, dag_maker, mock_supervisor_comms): assert dagrun.run_type == DagRunType.MANUAL assert dagrun.run_id == DagRun.generate_run_id(DagRunType.MANUAL, dagrun.logical_date) + def test_explicitly_provided_trigger_run_id_is_saved_as_attr(self, dag_maker, session): + with dag_maker(TEST_DAG_ID, default_args={"start_date": DEFAULT_DATE}, serialized=True): + task = TriggerDagRunOperator( + task_id="test_task", trigger_dag_id=TRIGGERED_DAG_ID, trigger_run_id="test_run_id" + ) + assert task.trigger_run_id == "test_run_id" + dag_maker.create_dagrun() + task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) + + assert task.trigger_run_id == "test_run_id" + def test_extra_operator_link(self, dag_maker, session): """Asserts whether the correct extra links url will be created.""" with dag_maker(TEST_DAG_ID, default_args={"start_date": DEFAULT_DATE}, serialized=True): @@ -819,3 +1042,200 @@ def test_trigger_dagrun_with_fail_when_dag_is_paused(self, dag_maker, session): dag_maker.create_dagrun() with 
pytest.raises(AirflowException, match=f"^Dag {TRIGGERED_DAG_ID} is paused$"): task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + + @pytest.mark.parametrize("original_conf", (None, {}, {"foo": "bar"})) + @mock.patch(f"{TRIGGER_OP_PATH}.safe_inject_openlineage_properties_into_dagrun_conf") + def test_trigger_dagrun_conf_openlineage_injection_disabled_with_explicit_false_arg( + self, mock_inject, original_conf, dag_maker + ): + """Test that conf is not modified when openlineage_inject_parent_info=False.""" + with dag_maker(TEST_DAG_ID, default_args={"start_date": DEFAULT_DATE}, serialized=True): + task = TriggerDagRunOperator( + task_id="test_task", + trigger_dag_id=TRIGGERED_DAG_ID, + conf=original_conf, + openlineage_inject_parent_info=False, + ) + dag_maker.sync_dagbag_to_db() + parse_and_sync_to_db(self.f_name) + dag_maker.create_dagrun() + + task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) + + # Injection function should not be called + mock_inject.assert_not_called() + + # Verify conf was not modified by checking the triggered DAG run + with create_session() as session: + dagrun = session.query(DagRun).filter(DagRun.dag_id == TRIGGERED_DAG_ID).one() + assert dagrun.conf == (original_conf if original_conf is not None else {}) + + @mock.patch(f"{OL_UTILS_PATH}._is_openlineage_provider_accessible") + def test_trigger_dagrun_conf_openlineage_injection_disabled_when_ol_not_accessible( + self, mock_is_accessible, dag_maker + ): + """Test that conf is not modified when OpenLineage provider is not accessible.""" + original_conf = {"foo": "bar"} + # Simulate OL provider being disabled/not accessible + mock_is_accessible.return_value = False + + with dag_maker(TEST_DAG_ID, default_args={"start_date": DEFAULT_DATE}, serialized=True): + # openlineage_inject_parent_info defaults to True + task = TriggerDagRunOperator( + task_id="test_task", + trigger_dag_id=TRIGGERED_DAG_ID, + conf=original_conf, + ) + dag_maker.sync_dagbag_to_db() + parse_and_sync_to_db(self.f_name) + dag_maker.create_dagrun() + + task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) + + # Verify conf was not modified + with create_session() as session: + dagrun = session.query(DagRun).filter(DagRun.dag_id == TRIGGERED_DAG_ID).one() + assert dagrun.conf == original_conf + + @pytest.mark.parametrize( + ("provider_version", "should_modify"), + [ + ("2.7.0", False), # Below minimum - conf not modified + ("2.7.9", False), # Below minimum - conf not modified + ("2.8.0", True), # Exactly minimum - conf modified + ("2.8.1", True), # Above minimum - conf modified + ], + ) + @mock.patch(f"{OL_UTILS_PATH}._is_openlineage_provider_accessible") + @mock.patch("importlib.metadata.version") + def test_trigger_dagrun_conf_openlineage_injection_disabled_for_older_ol_providers( + self, mock_version, mock_is_accessible, provider_version, should_modify, dag_maker + ): + """Test that conf is only modified when OpenLineage provider version is sufficient.""" + original_conf = {"foo": "bar"} + ol_parent_info = { + "parentRunId": "test-run-id", + "parentJobName": "test-job", + "parentJobNamespace": "test-ns", + "rootParentRunId": "test-root-run-id", + "rootParentJobName": "test-root-job", + "rootParentJobNamespace": "test-root-ns", + } + injected_conf = { + "foo": "bar", + "openlineage": ol_parent_info, + } + + def _mock_version(package): + if package == "apache-airflow-providers-openlineage": + return provider_version + raise Exception(f"Unexpected package: {package}") + + 
mock_version.side_effect = _mock_version + mock_is_accessible.return_value = True + + with dag_maker(TEST_DAG_ID, default_args={"start_date": DEFAULT_DATE}, serialized=True): + task = TriggerDagRunOperator( + task_id="test_task", + trigger_dag_id=TRIGGERED_DAG_ID, + conf=original_conf, + ) + dag_maker.sync_dagbag_to_db() + parse_and_sync_to_db(self.f_name) + dag_maker.create_dagrun() + + if should_modify: + # When version is sufficient, mock _get_openlineage_parent_info to return data + with mock.patch(f"{OL_UTILS_PATH}._get_openlineage_parent_info", return_value=ol_parent_info): + task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) + else: + # When version is insufficient, _get_openlineage_parent_info will raise + task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) + + with create_session() as session: + dagrun = session.query(DagRun).filter(DagRun.dag_id == TRIGGERED_DAG_ID).one() + if should_modify: + # When version is sufficient, conf should be modified + assert dagrun.conf == injected_conf + else: + # When version is insufficient, conf should remain unchanged + assert dagrun.conf == original_conf + + @pytest.mark.parametrize( + "exception", + [ + Exception("Generic error during injection"), + ValueError("Invalid data format"), + RuntimeError("Runtime issue"), + ], + ) + @mock.patch(f"{OL_UTILS_PATH}._is_openlineage_provider_accessible") + def test_trigger_dagrun_conf_openlineage_injection_preserves_conf_on_exception( + self, mock_is_accessible, exception, dag_maker + ): + """Test that original conf is preserved when any exception occurs during injection.""" + original_conf = {"foo": "bar"} + mock_is_accessible.return_value = True + + # Simulate any exception during injection (version check failure, runtime error, etc.) 
+ with mock.patch( + f"{OL_UTILS_PATH}._inject_openlineage_parent_info_to_dagrun_conf", + side_effect=exception, + ): + with dag_maker(TEST_DAG_ID, default_args={"start_date": DEFAULT_DATE}, serialized=True): + task = TriggerDagRunOperator( + task_id="test_task", + trigger_dag_id=TRIGGERED_DAG_ID, + conf=original_conf, + ) + dag_maker.sync_dagbag_to_db() + parse_and_sync_to_db(self.f_name) + dag_maker.create_dagrun() + + task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) + + # Verify conf was not modified when any exception occurs during injection + with create_session() as session: + dagrun = session.query(DagRun).filter(DagRun.dag_id == TRIGGERED_DAG_ID).one() + assert dagrun.conf == original_conf + + @pytest.mark.parametrize("original_conf", (None, {}, {"foo": "bar"})) + @mock.patch(f"{OL_UTILS_PATH}._is_openlineage_provider_accessible") + @mock.patch(f"{OL_UTILS_PATH}._get_openlineage_parent_info") + def test_trigger_dagrun_conf_openlineage_injection_valid_data( + self, mock_get_parent_info, mock_is_accessible, original_conf, dag_maker + ): + """Test that OpenLineage injection works when OL is available and flag is True.""" + ol_parent_info = { + "rootParentRunId": "22222222-2222-2222-2222-222222222222", + "rootParentJobNamespace": "rootns", + "rootParentJobName": "rootjob", + "parentRunId": "33333333-3333-3333-3333-333333333333", + "parentJobNamespace": "parentns", + "parentJobName": "parentjob", + } + injected_conf = { + **(original_conf or {}), + "openlineage": ol_parent_info, + } + mock_is_accessible.return_value = True + mock_get_parent_info.return_value = ol_parent_info + + with dag_maker(TEST_DAG_ID, default_args={"start_date": DEFAULT_DATE}, serialized=True): + task = TriggerDagRunOperator( + task_id="test_task", + trigger_dag_id=TRIGGERED_DAG_ID, + conf=original_conf, + ) + dag_maker.sync_dagbag_to_db() + parse_and_sync_to_db(self.f_name) + dag_maker.create_dagrun() + + task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) + + # Verify conf contains injected OpenLineage metadata + with create_session() as session: + dagrun = session.query(DagRun).filter(DagRun.dag_id == TRIGGERED_DAG_ID).one() + assert dagrun.conf == injected_conf + # Verify _get_openlineage_parent_info was called + mock_get_parent_info.assert_called_once() diff --git a/providers/standard/tests/unit/standard/utils/test_openlineage.py b/providers/standard/tests/unit/standard/utils/test_openlineage.py new file mode 100644 index 0000000000000..d0e0ac5f43c48 --- /dev/null +++ b/providers/standard/tests/unit/standard/utils/test_openlineage.py @@ -0,0 +1,470 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import pytest + +from airflow.exceptions import AirflowOptionalProviderFeatureException +from airflow.providers.standard.utils.openlineage import ( + OPENLINEAGE_PROVIDER_MIN_VERSION, + _get_openlineage_parent_info, + _inject_openlineage_parent_info_to_dagrun_conf, + _is_openlineage_provider_accessible, + safe_inject_openlineage_properties_into_dagrun_conf, +) + +OL_UTILS_PATH = "airflow.providers.standard.utils.openlineage" +OL_PROVIDER_PATH = "airflow.providers.openlineage" +OL_MACROS_PATH = f"{OL_PROVIDER_PATH}.plugins.macros" +OL_CONF_PATH = f"{OL_PROVIDER_PATH}.conf" +OL_LISTENER_PATH = f"{OL_PROVIDER_PATH}.plugins.listener" +IMPORTLIB_VERSION = "importlib.metadata.version" + + +@patch(f"{OL_LISTENER_PATH}.get_openlineage_listener") +@patch(f"{OL_CONF_PATH}.is_disabled") +def test_is_openlineage_provider_accessible(mock_is_disabled, mock_get_listener): + mock_is_disabled.return_value = False + mock_get_listener.return_value = True + assert _is_openlineage_provider_accessible() is True + + +@patch(f"{OL_LISTENER_PATH}.get_openlineage_listener") +@patch(f"{OL_CONF_PATH}.is_disabled") +def test_is_openlineage_provider_disabled(mock_is_disabled, mock_get_listener): + mock_is_disabled.return_value = True + assert _is_openlineage_provider_accessible() is False + + +@patch(f"{OL_LISTENER_PATH}.get_openlineage_listener") +@patch(f"{OL_CONF_PATH}.is_disabled") +def test_is_openlineage_listener_not_found(mock_is_disabled, mock_get_listener): + mock_is_disabled.return_value = False + mock_get_listener.return_value = None + assert _is_openlineage_provider_accessible() is False + + +@patch(f"{OL_CONF_PATH}.is_disabled") +def test_is_openlineage_provider_accessible_import_error(mock_is_disabled): + """Test that ImportError is handled when OpenLineage modules cannot be imported.""" + mock_is_disabled.side_effect = RuntimeError("Should not be called.") + with patch.dict( + "sys.modules", + { + OL_CONF_PATH: None, + OL_LISTENER_PATH: None, + }, + ): + result = _is_openlineage_provider_accessible() + assert result is False + + +def _mock_ol_parent_info(): + """Create a mock OpenLineage parent info dict.""" + return { + "parentRunId": "test-run-id", + "parentJobName": "test-job", + "parentJobNamespace": "test-ns", + "rootParentRunId": "test-root-run-id", + "rootParentJobName": "test-root-job", + "rootParentJobNamespace": "test-root-ns", + } + + +def test_get_openlineage_parent_info(): + """Test that _get_openlineage_parent_info calls all macros correctly and returns proper structure.""" + ti = MagicMock() + expected_values = { + "parentRunId": "parent-run-id-123", + "parentJobName": "parent-job-name", + "parentJobNamespace": "parent-namespace", + "rootParentRunId": "root-run-id-456", + "rootParentJobName": "root-job-name", + "rootParentJobNamespace": "root-namespace", + } + + def _mock_version(package): + if package == "apache-airflow-providers-openlineage": + return OPENLINEAGE_PROVIDER_MIN_VERSION # Exactly minimum version + raise Exception(f"Unexpected package: {package}") + + with ( + patch(f"{OL_MACROS_PATH}.lineage_run_id", return_value=expected_values["parentRunId"]) as mock_run_id, + patch( + f"{OL_MACROS_PATH}.lineage_job_name", return_value=expected_values["parentJobName"] + ) as mock_job_name, + patch( + f"{OL_MACROS_PATH}.lineage_job_namespace", + return_value=expected_values["parentJobNamespace"], + ) as mock_job_namespace, + patch( + f"{OL_MACROS_PATH}.lineage_root_run_id", + 
return_value=expected_values["rootParentRunId"], + ) as mock_root_run_id, + patch( + f"{OL_MACROS_PATH}.lineage_root_job_name", + return_value=expected_values["rootParentJobName"], + ) as mock_root_job_name, + patch( + f"{OL_MACROS_PATH}.lineage_root_job_namespace", + return_value=expected_values["rootParentJobNamespace"], + ) as mock_root_job_namespace, + patch(IMPORTLIB_VERSION, side_effect=_mock_version), + ): + result = _get_openlineage_parent_info(ti) + + # Verify all macros were called correctly + mock_run_id.assert_called_once_with(ti) + mock_job_name.assert_called_once_with(ti) + mock_job_namespace.assert_called_once() # No args + mock_root_run_id.assert_called_once_with(ti) + mock_root_job_name.assert_called_once_with(ti) + mock_root_job_namespace.assert_called_once_with(ti) + + # Verify result structure + assert isinstance(result, dict) + assert result == expected_values + assert set(result.keys()) == { + "parentRunId", + "parentJobName", + "parentJobNamespace", + "rootParentRunId", + "rootParentJobName", + "rootParentJobNamespace", + } + + +@pytest.mark.parametrize( + ("provider_version", "should_raise"), + [ + ("2.7.0", True), # Below minimum + ("2.7.9", True), # Below minimum + ("2.8.0", False), # Exactly minimum + ("2.8.1", False), # Above minimum + ], +) +def test_get_openlineage_parent_info_version_check(provider_version, should_raise): + """Test that _get_openlineage_parent_info raises AirflowOptionalProviderFeatureException when version is insufficient.""" + ti = MagicMock() + + def _mock_version(package): + if package == "apache-airflow-providers-openlineage": + return provider_version + raise Exception(f"Unexpected package: {package}") + + with patch(IMPORTLIB_VERSION, side_effect=_mock_version): + if should_raise: + expected_err = ( + f"OpenLineage provider version `{provider_version}` is lower than " + f"required `{OPENLINEAGE_PROVIDER_MIN_VERSION}`, " + "skipping function `_get_openlineage_parent_info` execution" + ) + with pytest.raises(AirflowOptionalProviderFeatureException, match=expected_err): + _get_openlineage_parent_info(ti) + else: + # When version is sufficient, mock all macros to allow execution + with ( + patch(f"{OL_MACROS_PATH}.lineage_run_id", return_value="run-id"), + patch(f"{OL_MACROS_PATH}.lineage_job_name", return_value="job-name"), + patch(f"{OL_MACROS_PATH}.lineage_job_namespace", return_value="job-ns"), + patch(f"{OL_MACROS_PATH}.lineage_root_run_id", return_value="root-run-id"), + patch(f"{OL_MACROS_PATH}.lineage_root_job_name", return_value="root-job-name"), + patch(f"{OL_MACROS_PATH}.lineage_root_job_namespace", return_value="root-job-ns"), + ): + result = _get_openlineage_parent_info(ti) + assert isinstance(result, dict) + assert "parentRunId" in result + + +@pytest.mark.parametrize("dr_conf", [None, {}]) +def test_inject_parent_info_with_none_or_empty_conf(dr_conf): + """Test injection with None or empty dict creates new openlineage section.""" + result = _inject_openlineage_parent_info_to_dagrun_conf(dr_conf, _mock_ol_parent_info()) + expected = {"openlineage": _mock_ol_parent_info()} + assert result == expected + + +@pytest.mark.parametrize("dr_conf", ["conf as string", ["conf_list"], [{"a": 1}, {"b": 2}]]) +def test_inject_parent_info_with_wrong_type_conf_raises_error(dr_conf): + """Test injection with wrong type of conf raises error (we catch it later on).""" + with pytest.raises(TypeError): + _inject_openlineage_parent_info_to_dagrun_conf(dr_conf, _mock_ol_parent_info()) + + +def 
test_inject_parent_info_with_existing_conf_no_openlineage_key(): + """Test injection with existing conf but no openlineage key.""" + dr_conf = {"some": "other", "config": "value"} + result = _inject_openlineage_parent_info_to_dagrun_conf(dr_conf, _mock_ol_parent_info()) + + expected = { + "some": "other", + "config": "value", + "openlineage": _mock_ol_parent_info(), + } + assert result == expected + # Original conf should not be modified + assert dr_conf == {"some": "other", "config": "value"} + + +def test_inject_parent_info_with_existing_openlineage_dict(): + """Test injection with existing openlineage dict merges correctly.""" + dr_conf = { + "some": "other", + "openlineage": { + "existing": "value", + "otherKey": "otherValue", + }, + } + result = _inject_openlineage_parent_info_to_dagrun_conf(dr_conf, _mock_ol_parent_info()) + + expected = { + "some": "other", + "openlineage": { + "existing": "value", + "otherKey": "otherValue", + **_mock_ol_parent_info(), + }, + } + assert result == expected + # Original conf should not be modified + assert dr_conf == { + "some": "other", + "openlineage": { + "existing": "value", + "otherKey": "otherValue", + }, + } + + +def test_inject_parent_info_with_non_dict_openlineage_returns_unchanged(): + """Test that non-dict openlineage value returns conf unchanged.""" + dr_conf = {"openlineage": "not-a-dict"} + result = _inject_openlineage_parent_info_to_dagrun_conf(dr_conf, _mock_ol_parent_info()) + + assert result == dr_conf + assert result is dr_conf # Should return same object + + +@pytest.mark.parametrize( + "forbidden_key", + [ + "parentRunId", + "parentJobName", + "parentJobNamespace", + "rootParentRunId", + "rootJobName", + "rootJobNamespace", + ], +) +def test_inject_parent_info_with_existing_forbidden_key_returns_unchanged(forbidden_key): + """Test that existing forbidden keys prevent injection.""" + dr_conf = { + "openlineage": { + forbidden_key: "existing-value", + "otherKey": "value", + } + } + result = _inject_openlineage_parent_info_to_dagrun_conf(dr_conf, _mock_ol_parent_info()) + + assert result == dr_conf + assert result is dr_conf # Should return same object + + +def test_inject_parent_info_with_multiple_existing_keys_returns_unchanged(): + """Test that multiple existing forbidden keys are all detected.""" + dr_conf = { + "openlineage": { + "parentRunId": "existing-parent-id", + "rootParentJobName": "existing-root-job", + "otherKey": "value", + } + } + result = _inject_openlineage_parent_info_to_dagrun_conf(dr_conf, _mock_ol_parent_info()) + assert result == dr_conf + # Original conf should not be modified + assert dr_conf == { + "openlineage": { + "parentRunId": "existing-parent-id", + "rootParentJobName": "existing-root-job", + "otherKey": "value", + } + } + + +def test_safe_inject_returns_unchanged_when_provider_not_accessible(): + """Test returns original conf when OpenLineage provider is not accessible.""" + dr_conf = {"some": "config"} + + with patch(f"{OL_UTILS_PATH}._is_openlineage_provider_accessible", return_value=False): + result = safe_inject_openlineage_properties_into_dagrun_conf(dr_conf, MagicMock()) + + assert result == dr_conf + assert result is dr_conf # Should return same object + + +def test_safe_inject_correctly_injects_openlineage_info(): + """Test that OpenLineage injection works when OL is available and version is sufficient.""" + dr_conf = {"some": "config"} + expected_result = { + "some": "config", + "openlineage": _mock_ol_parent_info(), + } + + def _mock_version(package): + if package == 
"apache-airflow-providers-openlineage": + return OPENLINEAGE_PROVIDER_MIN_VERSION + raise Exception(f"Unexpected package: {package}") + + with ( + patch(f"{OL_UTILS_PATH}._is_openlineage_provider_accessible", return_value=True), + patch(IMPORTLIB_VERSION, side_effect=_mock_version), + patch(f"{OL_UTILS_PATH}._get_openlineage_parent_info", return_value=_mock_ol_parent_info()), + ): + result = safe_inject_openlineage_properties_into_dagrun_conf(dr_conf, MagicMock()) + + assert result == expected_result + + +@pytest.mark.parametrize("dr_conf", [None, {}, "not-a-dict", ["a", "b", "c"]]) +def test_safe_inject_handles_none_empty_and_non_dict_conf(dr_conf): + """Test handles None, empty dict, or non-dict conf without raising error.""" + + def _mock_version(package): + if package == "apache-airflow-providers-openlineage": + return OPENLINEAGE_PROVIDER_MIN_VERSION + raise Exception(f"Unexpected package: {package}") + + with ( + patch(f"{OL_UTILS_PATH}._is_openlineage_provider_accessible", return_value=True), + patch(IMPORTLIB_VERSION, side_effect=_mock_version), + patch(f"{OL_UTILS_PATH}._get_openlineage_parent_info", return_value=_mock_ol_parent_info()), + ): + result = safe_inject_openlineage_properties_into_dagrun_conf(dr_conf, MagicMock()) + + if dr_conf is None or isinstance(dr_conf, dict): + assert result == {"openlineage": _mock_ol_parent_info()} + else: + assert result == dr_conf + assert result is dr_conf + + +def test_safe_inject_copies_dict_before_passing(): + """Test that dict is copied before being passed to inject function.""" + dr_conf = {"some": "config", "nested": {"key": "value"}} + + with ( + patch(f"{OL_UTILS_PATH}._is_openlineage_provider_accessible", return_value=True), + patch(f"{OL_UTILS_PATH}._get_openlineage_parent_info", return_value=_mock_ol_parent_info()), + patch(f"{OL_UTILS_PATH}._inject_openlineage_parent_info_to_dagrun_conf") as mock_inject, + ): + expected_result = {**dr_conf, "openlineage": _mock_ol_parent_info()} + mock_inject.return_value = expected_result + safe_inject_openlineage_properties_into_dagrun_conf(dr_conf, MagicMock()) + + # Verify that a copy was passed (not the original) + mock_inject.assert_called_once() + call_args = mock_inject.call_args + passed_conf = call_args[1]["dr_conf"] # Keyword argument + assert passed_conf == dr_conf + # The copy should be a different object (shallow copy) + assert passed_conf is not dr_conf + + +@pytest.mark.parametrize( + "exception", [ValueError("Test error"), KeyError("Missing key"), RuntimeError("Runtime issue")] +) +def test_safe_inject_preserves_original_conf_on_exception(exception): + """Test that original conf is preserved when any exception occurs during injection.""" + dr_conf = {"key": "value", "nested": {"deep": "data"}} + ti = MagicMock() + + with ( + patch(f"{OL_UTILS_PATH}._is_openlineage_provider_accessible", return_value=True), + patch(f"{OL_UTILS_PATH}._get_openlineage_parent_info", side_effect=exception), + ): + result = safe_inject_openlineage_properties_into_dagrun_conf(dr_conf, ti) + + # Should return original conf unchanged + assert result == {"key": "value", "nested": {"deep": "data"}} + assert result is dr_conf # Should return same object + + +@pytest.mark.parametrize( + ("provider_version", "should_raise"), + [ + ("2.7.0", True), # Below minimum + ("2.7.9", True), # Below minimum + ("2.8.0", False), # Exactly minimum + ("2.8.1", False), # Above minimum + ("3.0.0", False), # Well above minimum + ], +) +def test_safe_inject_with_provider_version_check(provider_version, should_raise): + """Test 
that version checking works correctly - exception caught when insufficient, works when sufficient.""" + dr_conf = {"some": "config"} + ti = MagicMock() + ol_parent_info = _mock_ol_parent_info() + + def _mock_version(package): + if package == "apache-airflow-providers-openlineage": + return provider_version + raise Exception(f"Unexpected package: {package}") + + with ( + patch(f"{OL_UTILS_PATH}._is_openlineage_provider_accessible", return_value=True), + patch(IMPORTLIB_VERSION, side_effect=_mock_version), + ): + if should_raise: + # When version is insufficient, _get_openlineage_parent_info will raise + # The exception should be caught and conf returned unchanged + result = safe_inject_openlineage_properties_into_dagrun_conf(dr_conf, ti) + + assert result == dr_conf + else: + # When version is sufficient, mock _get_openlineage_parent_info to return data + with patch(f"{OL_UTILS_PATH}._get_openlineage_parent_info", return_value=ol_parent_info): + result = safe_inject_openlineage_properties_into_dagrun_conf(dr_conf, ti) + expected = { + "some": "config", + "openlineage": ol_parent_info, + } + assert result == expected + + +def test_inject_when_provider_not_found(): + """Test that injection handles case when OpenLineage provider package is not found.""" + dr_conf = {"some": "config"} + ti = MagicMock() + + # Simulate the case where _get_openlineage_parent_info raises AirflowOptionalProviderFeatureException + # because the provider package is not found (this happens inside require_openlineage_version decorator) + with ( + patch(f"{OL_UTILS_PATH}._is_openlineage_provider_accessible", return_value=True), + patch( + f"{OL_UTILS_PATH}._get_openlineage_parent_info", + side_effect=AirflowOptionalProviderFeatureException( + "OpenLineage provider not found or has no version, " + "skipping function `_get_openlineage_parent_info` execution" + ), + ), + ): + result = safe_inject_openlineage_properties_into_dagrun_conf(dr_conf, ti) + + assert result == dr_conf
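
For context, a minimal usage sketch of the new parameter, assuming the OpenLineage provider (>= 2.8.0) is
installed and enabled. The DAG and task names are hypothetical, and the identifiers shown in the resulting
conf are placeholders for the values produced at runtime by the OpenLineage macros called in
_get_openlineage_parent_info():

    from airflow import DAG
    from airflow.providers.standard.operators.trigger_dagrun import TriggerDagRunOperator

    with DAG(dag_id="parent_dag", schedule=None):
        TriggerDagRunOperator(
            task_id="trigger_child",
            trigger_dag_id="child_dag",
            conf={"foo": "bar"},
            # Default is True; set to False to keep the triggered run's conf untouched by OpenLineage.
            openlineage_inject_parent_info=True,
        )

    # With OpenLineage accessible, the triggered DAG run receives a conf shaped roughly like:
    # {
    #     "foo": "bar",
    #     "openlineage": {
    #         "parentRunId": "<lineage_run_id(ti)>",
    #         "parentJobName": "parent_dag.trigger_child",  # typically "<dag_id>.<task_id>"
    #         "parentJobNamespace": "<lineage_job_namespace()>",
    #         "rootParentRunId": "<lineage_root_run_id(ti)>",
    #         "rootParentJobName": "<lineage_root_job_name(ti)>",
    #         "rootParentJobNamespace": "<lineage_root_job_namespace(ti)>",
    #     },
    # }
    # If the conf already carries any of these identifiers, or OpenLineage is disabled, unavailable,
    # or older than the minimum version, the conf is left exactly as passed.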