diff --git a/providers/databricks/src/airflow/providers/databricks/operators/databricks_sql.py b/providers/databricks/src/airflow/providers/databricks/operators/databricks_sql.py
index b50c434d04cb9..bb577d7b7b3a4 100644
--- a/providers/databricks/src/airflow/providers/databricks/operators/databricks_sql.py
+++ b/providers/databricks/src/airflow/providers/databricks/operators/databricks_sql.py
@@ -45,6 +45,64 @@
 _IDENTIFIER_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")
 _DISALLOWED_SQL_TOKENS = (";", "--", "/*", "*/")
 
+_QUERY_TAG_FIELDS = {
+    "airflow_dag_id": ("dag", "dag_id"),
+    "airflow_task_id": ("task", "task_id"),
+    "airflow_run_id": ("run_id", None),
+}
+
+_QUERY_TAG_ESCAPE_SEQUENCES = {
+    "\\": "\\\\",
+    ",": "\\,",
+    ":": "\\:",
+}
+
+
+def _escape_query_tag_value(value: str) -> str:
+    escaped = str(value)
+
+    for char, replacement in _QUERY_TAG_ESCAPE_SEQUENCES.items():
+        escaped = escaped.replace(char, replacement)
+
+    return escaped
+
+
+def _format_query_tags(context: Context) -> str:
+    tags = []
+
+    for tag_name, (context_key, attr) in _QUERY_TAG_FIELDS.items():
+        value = context.get(context_key)
+
+        if attr:
+            value = getattr(value, attr, None)
+
+        if value:
+            tags.append(f"{tag_name}:{_escape_query_tag_value(value)}")
+
+    return ",".join(tags)
+
+
+def _merge_query_tags(session_config: dict[str, Any], query_tags: str) -> dict[str, Any]:
+    """Return a copied session config with Airflow query tags appended."""
+    updated_config = session_config.copy()
+    existing_tags = updated_config.get("query_tags", "")
+    updated_config["query_tags"] = f"{existing_tags},{query_tags}" if existing_tags else query_tags
+    return updated_config
+
+
+def _inject_query_tags(hook: DatabricksSqlHook, context: Context) -> None:
+    """Inject Airflow context metadata into Databricks query tags."""
+    query_tags = _format_query_tags(context)
+    if not query_tags:
+        return
+
+    if hook.session_config is None:
+        conn_extra = hook.databricks_conn.extra_dejson
+        hook.session_config = conn_extra.get("session_configuration", {})
+
+    if isinstance(hook.session_config, dict):
+        hook.session_config = _merge_query_tags(hook.session_config, query_tags)
+
 
 class DatabricksSqlOperator(SQLExecuteQueryOperator):
     """
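Reviewer aside (not part of the patch): a standalone sketch of the tag string these helpers produce. The snippet below mirrors `_escape_query_tag_value`/`_format_query_tags` rather than importing them, and every id value is a made-up example:

```python
# Mirrors the escaping above: backslash is handled first, then comma and
# colon, so replacement backslashes are never re-escaped.
ESCAPES = {"\\": "\\\\", ",": "\\,", ":": "\\:"}


def escape(value: str) -> str:
    for char, replacement in ESCAPES.items():
        value = value.replace(char, replacement)
    return value


# Hypothetical run: note the colons inside the run_id get escaped.
fields = {
    "airflow_dag_id": "etl_daily",
    "airflow_task_id": "load_raw",
    "airflow_run_id": "manual__2024-01-01T00:00:00",
}
print(",".join(f"{name}:{escape(value)}" for name, value in fields.items()))
# airflow_dag_id:etl_daily,airflow_task_id:load_raw,airflow_run_id:manual__2024-01-01T00\:00\:00
```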
""" template_fields: Sequence[str] = tuple( @@ -117,6 +180,7 @@ def __init__( client_parameters: dict[str, Any] | None = None, gcp_conn_id: str = "google_cloud_default", gcs_impersonation_chain: str | Sequence[str] | None = None, + inject_query_tags: bool = True, **kwargs, ) -> None: super().__init__(conn_id=databricks_conn_id, **kwargs) @@ -134,6 +198,7 @@ def __init__( self.schema = schema self._gcp_conn_id = gcp_conn_id self._gcs_impersonation_chain = gcs_impersonation_chain + self.inject_query_tags = inject_query_tags @cached_property def _hook(self) -> DatabricksSqlHook: @@ -153,6 +218,11 @@ def _hook(self) -> DatabricksSqlHook: def get_db_hook(self) -> DatabricksSqlHook: return self._hook + def execute(self, context: Context) -> Any: + if self.inject_query_tags: + _inject_query_tags(self.get_db_hook(), context) + return super().execute(context) + def _should_run_output_processing(self) -> bool: return self.do_xcom_push or bool(self._output_path) @@ -348,6 +418,11 @@ class DatabricksCopyIntoOperator(BaseOperator): :param validate: optional configuration for schema & data validation. ``True`` forces validation of all rows, integer number - validate only N first rows :param copy_options: optional dictionary of copy options. Right now only ``force`` option is supported. + :param inject_query_tags: If ``True`` (default), Airflow context metadata + (``airflow_dag_id``, ``airflow_task_id``, ``airflow_run_id``) is injected into the + Databricks session ``query_tags`` at execution time, preserving any user-defined + ``query_tags`` already present in ``session_configuration``. Set to ``False`` to + retain full control over ``session_configuration`` and skip the automatic injection. """ template_fields: Sequence[str] = ( @@ -381,6 +456,7 @@ def __init__( force_copy: bool | None = None, copy_options: dict[str, str] | None = None, validate: bool | int | None = None, + inject_query_tags: bool = True, **kwargs, ) -> None: """Create a new ``DatabricksSqlOperator``.""" @@ -415,6 +491,7 @@ def __init__( self._client_parameters = client_parameters or {} if force_copy is not None: self._copy_options["force"] = "true" if force_copy else "false" + self.inject_query_tags = inject_query_tags self._sql: str | None = None def _get_hook(self) -> DatabricksSqlHook: @@ -518,6 +595,8 @@ def execute(self, context: Context) -> Any: self._sql = self._create_sql_query() self.log.info("Executing: %s", self._sql) hook = self._get_hook() + if self.inject_query_tags: + _inject_query_tags(hook, context) hook.run(self._sql) def on_kill(self) -> None: diff --git a/providers/databricks/tests/unit/databricks/operators/test_databricks_copy.py b/providers/databricks/tests/unit/databricks/operators/test_databricks_copy.py index f00653f8b22c8..1055bd5da2ed4 100644 --- a/providers/databricks/tests/unit/databricks/operators/test_databricks_copy.py +++ b/providers/databricks/tests/unit/databricks/operators/test_databricks_copy.py @@ -367,6 +367,161 @@ def test_hook_is_cached(): assert hook is hook2 +def _make_context(*, dag_id=None, task_id=None, run_id=None): + context: dict = {} + if dag_id is not None: + context["dag"] = mock.MagicMock(dag_id=dag_id) + if task_id is not None: + context["task"] = mock.MagicMock(task_id=task_id) + if run_id is not None: + context["run_id"] = run_id + return context + + +def _run_with_mocked_hook(op, context, initial_session_config, conn_extra=None): + """Execute the operator with a mocked hook and return the resulting session_config.""" + with mock.patch( + 
"airflow.providers.databricks.operators.databricks_sql.DatabricksSqlHook" + ) as db_mock_class: + db_mock = db_mock_class.return_value + db_mock.session_config = initial_session_config + db_mock.databricks_conn = mock.MagicMock(extra_dejson=conn_extra or {}) + op.execute(context) + return db_mock.session_config + + +def test_query_tags_injection_appends_to_existing_tags(): + op = DatabricksCopyIntoOperator( + task_id=TASK_ID, + file_location=COPY_FILE_LOCATION, + file_format="JSON", + table_name="test", + ) + context = _make_context(dag_id="test_dag", task_id="test_task", run_id="test_run_123") + + result = _run_with_mocked_hook(op, context, {"query_tags": "user_tag:value"}) + + assert result["query_tags"] == ( + "user_tag:value,airflow_dag_id:test_dag," + "airflow_task_id:test_task,airflow_run_id:test_run_123" + ) + + +def test_query_tags_injection_with_no_existing_tags(): + op = DatabricksCopyIntoOperator( + task_id=TASK_ID, + file_location=COPY_FILE_LOCATION, + file_format="JSON", + table_name="test", + ) + context = _make_context(dag_id="d", task_id="t", run_id="r") + + result = _run_with_mocked_hook(op, context, {}) + + assert result["query_tags"] == "airflow_dag_id:d,airflow_task_id:t,airflow_run_id:r" + + +def test_query_tags_injection_with_partial_context(): + op = DatabricksCopyIntoOperator( + task_id=TASK_ID, + file_location=COPY_FILE_LOCATION, + file_format="JSON", + table_name="test", + ) + context = _make_context(task_id="only_task") + + result = _run_with_mocked_hook(op, context, {}) + + assert result["query_tags"] == "airflow_task_id:only_task" + + +def test_query_tags_injection_with_empty_context(): + op = DatabricksCopyIntoOperator( + task_id=TASK_ID, + file_location=COPY_FILE_LOCATION, + file_format="JSON", + table_name="test", + ) + + result = _run_with_mocked_hook(op, {}, {"unrelated": "keep"}) + + assert result == {"unrelated": "keep"} + + +def test_query_tags_injection_escapes_special_chars(): + op = DatabricksCopyIntoOperator( + task_id=TASK_ID, + file_location=COPY_FILE_LOCATION, + file_format="JSON", + table_name="test", + ) + context = _make_context( + dag_id="dag,with,commas", + task_id="task:with:colons", + run_id="run\\with\\backslashes", + ) + + result = _run_with_mocked_hook(op, context, {}) + + assert result["query_tags"] == ( + "airflow_dag_id:dag\\,with\\,commas," + "airflow_task_id:task\\:with\\:colons," + "airflow_run_id:run\\\\with\\\\backslashes" + ) + + +def test_query_tags_injection_preserves_unrelated_session_config(): + op = DatabricksCopyIntoOperator( + task_id=TASK_ID, + file_location=COPY_FILE_LOCATION, + file_format="JSON", + table_name="test", + ) + context = _make_context(dag_id="d", task_id="t", run_id="r") + initial = {"spark.sql.shuffle.partitions": "200", "query_tags": "x:y"} + + result = _run_with_mocked_hook(op, context, initial) + + assert result["spark.sql.shuffle.partitions"] == "200" + assert result["query_tags"] == "x:y,airflow_dag_id:d,airflow_task_id:t,airflow_run_id:r" + + +def test_query_tags_injection_falls_back_to_conn_extra_when_session_config_none(): + op = DatabricksCopyIntoOperator( + task_id=TASK_ID, + file_location=COPY_FILE_LOCATION, + file_format="JSON", + table_name="test", + ) + context = _make_context(dag_id="d", task_id="t", run_id="r") + + result = _run_with_mocked_hook( + op, + context, + initial_session_config=None, + conn_extra={"session_configuration": {"query_tags": "conn_tag:1"}}, + ) + + assert result["query_tags"] == ( + "conn_tag:1,airflow_dag_id:d,airflow_task_id:t,airflow_run_id:r" + ) + + +def 
diff --git a/providers/databricks/tests/unit/databricks/operators/test_databricks_sql.py b/providers/databricks/tests/unit/databricks/operators/test_databricks_sql.py
index e216c56bea2b4..721cbfbc4a24a 100644
--- a/providers/databricks/tests/unit/databricks/operators/test_databricks_sql.py
+++ b/providers/databricks/tests/unit/databricks/operators/test_databricks_sql.py
@@ -432,6 +432,125 @@ def test_exec_write_gcs_parquet_output(tmp_path):
     assert call_kwargs["object_name"] == "data/results.parquet"
 
 
+def _make_context(*, dag_id=None, task_id=None, run_id=None):
+    from unittest.mock import MagicMock
+
+    context: dict = {}
+    if dag_id is not None:
+        context["dag"] = MagicMock(dag_id=dag_id)
+    if task_id is not None:
+        context["task"] = MagicMock(task_id=task_id)
+    if run_id is not None:
+        context["run_id"] = run_id
+    return context
+
+
+def _run_with_mocked_hook(op, context, initial_session_config, conn_extra=None):
+    """Execute the operator with a mocked hook and return the resulting session_config."""
+    from unittest.mock import MagicMock
+
+    op.do_xcom_push = False
+    with patch(
+        "airflow.providers.databricks.operators.databricks_sql.DatabricksSqlHook"
+    ) as db_mock_class:
+        db_mock = db_mock_class.return_value
+        db_mock.session_config = initial_session_config
+        db_mock.databricks_conn = MagicMock(extra_dejson=conn_extra or {})
+        op.execute(context)
+    return db_mock.session_config
+
+
+def test_query_tags_injection_appends_to_existing_tags():
+    op = DatabricksSqlOperator(task_id=TASK_ID, sql="SELECT 1")
+    context = _make_context(dag_id="test_dag", task_id="test_task", run_id="test_run_123")
+
+    result = _run_with_mocked_hook(op, context, {"query_tags": "user_tag:value"})
+
+    assert result["query_tags"] == (
+        "user_tag:value,airflow_dag_id:test_dag,"
+        "airflow_task_id:test_task,airflow_run_id:test_run_123"
+    )
+
+
+def test_query_tags_injection_with_no_existing_tags():
+    op = DatabricksSqlOperator(task_id=TASK_ID, sql="SELECT 1")
+    context = _make_context(dag_id="d", task_id="t", run_id="r")
+
+    result = _run_with_mocked_hook(op, context, {})
+
+    assert result["query_tags"] == "airflow_dag_id:d,airflow_task_id:t,airflow_run_id:r"
+
+
+def test_query_tags_injection_with_partial_context():
+    op = DatabricksSqlOperator(task_id=TASK_ID, sql="SELECT 1")
+    context = _make_context(task_id="only_task")
+
+    result = _run_with_mocked_hook(op, context, {})
+
+    assert result["query_tags"] == "airflow_task_id:only_task"
+
+
+def test_query_tags_injection_with_empty_context():
+    op = DatabricksSqlOperator(task_id=TASK_ID, sql="SELECT 1")
+
+    result = _run_with_mocked_hook(op, {}, {"unrelated": "keep"})
+
+    assert result == {"unrelated": "keep"}
+
+
+def test_query_tags_injection_escapes_special_chars():
+    op = DatabricksSqlOperator(task_id=TASK_ID, sql="SELECT 1")
+    context = _make_context(
+        dag_id="dag,with,commas",
+        task_id="task:with:colons",
+        run_id="run\\with\\backslashes",
+    )
+
+    result = _run_with_mocked_hook(op, context, {})
+
+    assert result["query_tags"] == (
+        "airflow_dag_id:dag\\,with\\,commas,"
"airflow_task_id:task\\:with\\:colons," + "airflow_run_id:run\\\\with\\\\backslashes" + ) + + +def test_query_tags_injection_preserves_unrelated_session_config(): + op = DatabricksSqlOperator(task_id=TASK_ID, sql="SELECT 1") + context = _make_context(dag_id="d", task_id="t", run_id="r") + initial = {"spark.sql.shuffle.partitions": "200", "query_tags": "x:y"} + + result = _run_with_mocked_hook(op, context, initial) + + assert result["spark.sql.shuffle.partitions"] == "200" + assert result["query_tags"] == "x:y,airflow_dag_id:d,airflow_task_id:t,airflow_run_id:r" + + +def test_query_tags_injection_falls_back_to_conn_extra_when_session_config_none(): + op = DatabricksSqlOperator(task_id=TASK_ID, sql="SELECT 1") + context = _make_context(dag_id="d", task_id="t", run_id="r") + + result = _run_with_mocked_hook( + op, + context, + initial_session_config=None, + conn_extra={"session_configuration": {"query_tags": "conn_tag:1"}}, + ) + + assert result["query_tags"] == ( + "conn_tag:1,airflow_dag_id:d,airflow_task_id:t,airflow_run_id:r" + ) + + +def test_query_tags_injection_disabled(): + op = DatabricksSqlOperator(task_id=TASK_ID, sql="SELECT 1", inject_query_tags=False) + context = _make_context(dag_id="d", task_id="t", run_id="r") + + result = _run_with_mocked_hook(op, context, {"query_tags": "user_tag:value"}) + + assert result == {"query_tags": "user_tag:value"} + + def test_is_gcs_output(): """Test _is_gcs_output property.""" op_gcs = DatabricksSqlOperator(task_id=TASK_ID, sql="SELECT 1", output_path="gs://bucket/path")