diff --git a/providers/openlineage/src/airflow/providers/openlineage/utils/utils.py b/providers/openlineage/src/airflow/providers/openlineage/utils/utils.py index 5f11b186b62c7..78441eaddf707 100644 --- a/providers/openlineage/src/airflow/providers/openlineage/utils/utils.py +++ b/providers/openlineage/src/airflow/providers/openlineage/utils/utils.py @@ -841,6 +841,9 @@ class TaskInstanceInfo(InfoJsonEncodable): casts = { "log_url": lambda ti: getattr(ti, "log_url", None), "map_index": lambda ti: ti.map_index if getattr(ti, "map_index", -1) != -1 else None, + "rendered_map_index": lambda ti: ( + getattr(ti, "rendered_map_index", None) if getattr(ti, "map_index", -1) != -1 else None + ), "dag_bundle_version": lambda ti: ( ti.bundle_instance.version if hasattr(ti, "bundle_instance") else None ), diff --git a/providers/openlineage/tests/unit/openlineage/utils/test_utils.py b/providers/openlineage/tests/unit/openlineage/utils/test_utils.py index c406938196196..b64d77014c748 100644 --- a/providers/openlineage/tests/unit/openlineage/utils/test_utils.py +++ b/providers/openlineage/tests/unit/openlineage/utils/test_utils.py @@ -3035,11 +3035,22 @@ def test_taskinstance_info_af3(): assert dict(TaskInstanceInfo(runtime_ti)) == { "log_url": runtime_ti.log_url, "map_index": 2, + "rendered_map_index": None, "try_number": 1, "dag_bundle_version": "bundle_version", "dag_bundle_name": "bundle_name", } + runtime_ti.rendered_map_index = "country=PL" + assert dict(TaskInstanceInfo(runtime_ti))["rendered_map_index"] == "country=PL" + + # Should only be included if task is mapped + runtime_ti.map_index = -1 + runtime_ti.rendered_map_index = None + result = dict(TaskInstanceInfo(runtime_ti)) + assert result["map_index"] is None + assert result["rendered_map_index"] is None + @pytest.mark.skipif(AIRFLOW_V_3_0_PLUS, reason="Airflow 2 test") @patch.object(TaskInstance, "log_url", "some_log_url") # Depends on the host, hard to test exact value @@ -3055,6 +3066,7 @@ def test_taskinstance_info_af2(): assert dict(TaskInstanceInfo(ti)) == { "duration": 12.345, "map_index": 2, + "rendered_map_index": None, "pool": "default_pool", "try_number": 0, "queued_dttm": "2024-06-01T00:00:00+00:00", @@ -3063,6 +3075,19 @@ def test_taskinstance_info_af2(): "dag_bundle_version": None, } + # Also tested manually that it works well on AF2, hard to test hybrid property so just mocking it here + with patch.object(TaskInstance, "rendered_map_index", "country=PL"): + assert dict(TaskInstanceInfo(ti))["rendered_map_index"] == "country=PL" + + ti_unmapped = TaskInstance( + task=task_obj, run_id="task_instance_run_id", state=TaskInstanceState.RUNNING, map_index=-1 + ) + # Should only be included if task is mapped + with patch.object(TaskInstance, "rendered_map_index", "should-not-be-emitted"): + result = dict(TaskInstanceInfo(ti_unmapped)) + assert result["map_index"] is None + assert result["rendered_map_index"] is None + @pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Airflow 3 test") def test_task_info_af3():