@@ -45,6 +45,64 @@
_IDENTIFIER_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")
_DISALLOWED_SQL_TOKENS = (";", "--", "/*", "*/")

_QUERY_TAG_FIELDS = {
    "airflow_dag_id": ("dag", "dag_id"),
    "airflow_task_id": ("task", "task_id"),
    "airflow_run_id": ("run_id", None),
}

_QUERY_TAG_ESCAPE_SEQUENCES = {
    "\\": "\\\\",
    ",": "\\,",
    ":": "\\:",
}


def _escape_query_tag_value(value: str) -> str:
    """Escape Databricks query-tag separator characters in a tag value."""
    escaped = str(value)

    for char, replacement in _QUERY_TAG_ESCAPE_SEQUENCES.items():
        escaped = escaped.replace(char, replacement)

    return escaped


def _format_query_tags(context: Context) -> str:
    """Format Airflow context metadata as a comma-separated query-tag string."""
    tags = []

    for tag_name, (context_key, attr) in _QUERY_TAG_FIELDS.items():
        value = context.get(context_key)

        if attr:
            value = getattr(value, attr, None)

        if value:
            tags.append(f"{tag_name}:{_escape_query_tag_value(value)}")

    return ",".join(tags)
Contributor


I think this could be made clearer and more explicit via a mapping-driven approach. Please see below for guidance:

_QUERY_TAG_FIELDS = {
    "airflow_dag_id": ("dag", "dag_id"),
    "airflow_task_id": ("task", "task_id"),
    "airflow_run_id": ("run_id", None),
}


def _format_query_tags(context: Context) -> str:
    tags = []

    for tag_name, (context_key, attr) in _QUERY_TAG_FIELDS.items():
        value = context.get(context_key)

        if attr:
            value = getattr(value, attr, None)

        if value:
            tags.append(f"{tag_name}:{_escape_query_tag_value(value)}")

    return ",".join(tags)

Also, you could do the same for _escape_query_tag_value:

_QUERY_TAG_ESCAPE_SEQUENCES = {
    "\\": "\\\\",
    ",": "\\,",
    ":": "\\:",
}


def _escape_query_tag_value(value: str) -> str:
    """Escape Databricks query-tag separator characters in a tag value."""
    escaped = str(value)

    for char, replacement in _QUERY_TAG_ESCAPE_SEQUENCES.items():
        escaped = escaped.replace(char, replacement)

    return escaped

Author


Thanks for the suggestion @SameerMesiah97! I agree the mapping-driven approach is cleaner and more maintainable. I've gone ahead and refactored both _format_query_tags and _escape_query_tag_value to use the _QUERY_TAG_FIELDS and _QUERY_TAG_ESCAPE_SEQUENCES mappings exactly as you suggested — this is included in commit 732df03 ("Refactor Databricks query tag helper utilities"). Please take another look and let me know if you'd like any further tweaks.
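For reference, a minimal standalone sketch of the escaping behavior (the sample value is made up, not part of the patch). Note that the backslash entry must stay first in _QUERY_TAG_ESCAPE_SEQUENCES: Python dicts preserve insertion order, so the backslashes introduced by the comma and colon replacements are not themselves re-escaped.

# Minimal sketch, mirroring the mappings above; the sample value is hypothetical.
escape_sequences = {
    "\\": "\\\\",  # backslash first, so later escapes are not double-escaped
    ",": "\\,",
    ":": "\\:",
}

value = "etl:daily,run\\1"  # hypothetical tag value
for char, replacement in escape_sequences.items():
    value = value.replace(char, replacement)

print(value)  # prints: etl\:daily\,run\\1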



def _merge_query_tags(session_config: dict[str, Any], query_tags: str) -> dict[str, Any]:
    """Return a copied session config with Airflow query tags appended."""
    updated_config = session_config.copy()
    existing_tags = updated_config.get("query_tags", "")
    updated_config["query_tags"] = f"{existing_tags},{query_tags}" if existing_tags else query_tags
    return updated_config


def _inject_query_tags(hook: DatabricksSqlHook, context: Context) -> None:
    """Inject Airflow context metadata into Databricks query tags."""
    query_tags = _format_query_tags(context)
    if not query_tags:
        return

    if hook.session_config is None:
        conn_extra = hook.databricks_conn.extra_dejson
        hook.session_config = conn_extra.get("session_configuration", {})

    if isinstance(hook.session_config, dict):
        hook.session_config = _merge_query_tags(hook.session_config, query_tags)


class DatabricksSqlOperator(SQLExecuteQueryOperator):
"""
Expand Down Expand Up @@ -83,6 +141,11 @@ class DatabricksSqlOperator(SQLExecuteQueryOperator):
    :param gcs_impersonation_chain: Optional service account to impersonate using short-term
        credentials for GCS upload, or chained list of accounts required to get the access_token
        of the last account in the list, which will be impersonated in the request. (templated)
    :param inject_query_tags: If ``True`` (default), Airflow context metadata
        (``airflow_dag_id``, ``airflow_task_id``, ``airflow_run_id``) is injected into the
        Databricks session ``query_tags`` at execution time, preserving any user-defined
        ``query_tags`` already present in ``session_configuration``. Set to ``False`` to
        retain full control over ``session_configuration`` and skip the automatic injection.
    """

    template_fields: Sequence[str] = tuple(
@@ -117,6 +180,7 @@ def __init__(
        client_parameters: dict[str, Any] | None = None,
        gcp_conn_id: str = "google_cloud_default",
        gcs_impersonation_chain: str | Sequence[str] | None = None,
        inject_query_tags: bool = True,
        **kwargs,
    ) -> None:
        super().__init__(conn_id=databricks_conn_id, **kwargs)
@@ -134,6 +198,7 @@ def __init__(
        self.schema = schema
        self._gcp_conn_id = gcp_conn_id
        self._gcs_impersonation_chain = gcs_impersonation_chain
        self.inject_query_tags = inject_query_tags

    @cached_property
    def _hook(self) -> DatabricksSqlHook:
@@ -153,6 +218,11 @@ def _hook(self) -> DatabricksSqlHook:
    def get_db_hook(self) -> DatabricksSqlHook:
        return self._hook

    def execute(self, context: Context) -> Any:
        if self.inject_query_tags:
            _inject_query_tags(self.get_db_hook(), context)
        return super().execute(context)
Contributor


Do we want this behavior configurable (operator/provider-level opt-out)? Since this mutates session_configuration automatically, some users may prefer explicit control over injected warehouse metadata.

Author


Good point — making this configurable makes sense so users retain explicit control over session_configuration. I'll add an inject_query_tags: bool = True parameter to both DatabricksSqlOperator and DatabricksCopyIntoOperator (defaulting to True to preserve the observability benefit, while allowing easy opt-out). I'll also document the new parameter in the operator docstrings. Will push the update shortly.
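As a usage sketch (hypothetical task id and SQL; the connection id shown is the provider default), opting out would look like:

# Hypothetical opt-out example: session_configuration is used exactly as defined.
from airflow.providers.databricks.operators.databricks_sql import DatabricksSqlOperator

run_report = DatabricksSqlOperator(
    task_id="run_report",  # made-up task id
    databricks_conn_id="databricks_default",
    sql="SELECT 1",  # placeholder statement
    inject_query_tags=False,  # skip automatic airflow_* query tags
)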


    def _should_run_output_processing(self) -> bool:
        return self.do_xcom_push or bool(self._output_path)

@@ -348,6 +418,11 @@ class DatabricksCopyIntoOperator(BaseOperator):
    :param validate: optional configuration for schema & data validation. ``True`` forces validation
        of all rows; an integer value N validates only the first N rows
    :param copy_options: optional dictionary of copy options. Right now only the ``force`` option is supported.
    :param inject_query_tags: If ``True`` (default), Airflow context metadata
        (``airflow_dag_id``, ``airflow_task_id``, ``airflow_run_id``) is injected into the
        Databricks session ``query_tags`` at execution time, preserving any user-defined
        ``query_tags`` already present in ``session_configuration``. Set to ``False`` to
        retain full control over ``session_configuration`` and skip the automatic injection.
    """

    template_fields: Sequence[str] = (
@@ -381,6 +456,7 @@ def __init__(
        force_copy: bool | None = None,
        copy_options: dict[str, str] | None = None,
        validate: bool | int | None = None,
        inject_query_tags: bool = True,
        **kwargs,
    ) -> None:
        """Create a new ``DatabricksCopyIntoOperator``."""
@@ -415,6 +491,7 @@ def __init__(
        self._client_parameters = client_parameters or {}
        if force_copy is not None:
            self._copy_options["force"] = "true" if force_copy else "false"
        self.inject_query_tags = inject_query_tags
        self._sql: str | None = None

    def _get_hook(self) -> DatabricksSqlHook:
@@ -518,6 +595,8 @@ def execute(self, context: Context) -> Any:
        self._sql = self._create_sql_query()
        self.log.info("Executing: %s", self._sql)
        hook = self._get_hook()
        if self.inject_query_tags:
            _inject_query_tags(hook, context)
        hook.run(self._sql)

    def on_kill(self) -> None:
@@ -367,6 +367,161 @@ def test_hook_is_cached():
    assert hook is hook2


def _make_context(*, dag_id=None, task_id=None, run_id=None):
    context: dict = {}
    if dag_id is not None:
        context["dag"] = mock.MagicMock(dag_id=dag_id)
    if task_id is not None:
        context["task"] = mock.MagicMock(task_id=task_id)
    if run_id is not None:
        context["run_id"] = run_id
    return context


def _run_with_mocked_hook(op, context, initial_session_config, conn_extra=None):
    """Execute the operator with a mocked hook and return the resulting session_config."""
    with mock.patch(
        "airflow.providers.databricks.operators.databricks_sql.DatabricksSqlHook"
    ) as db_mock_class:
        db_mock = db_mock_class.return_value
        db_mock.session_config = initial_session_config
        db_mock.databricks_conn = mock.MagicMock(extra_dejson=conn_extra or {})
        op.execute(context)
    return db_mock.session_config


def test_query_tags_injection_appends_to_existing_tags():
    op = DatabricksCopyIntoOperator(
        task_id=TASK_ID,
        file_location=COPY_FILE_LOCATION,
        file_format="JSON",
        table_name="test",
    )
    context = _make_context(dag_id="test_dag", task_id="test_task", run_id="test_run_123")

    result = _run_with_mocked_hook(op, context, {"query_tags": "user_tag:value"})

    assert result["query_tags"] == (
        "user_tag:value,airflow_dag_id:test_dag,"
        "airflow_task_id:test_task,airflow_run_id:test_run_123"
    )


def test_query_tags_injection_with_no_existing_tags():
    op = DatabricksCopyIntoOperator(
        task_id=TASK_ID,
        file_location=COPY_FILE_LOCATION,
        file_format="JSON",
        table_name="test",
    )
    context = _make_context(dag_id="d", task_id="t", run_id="r")

    result = _run_with_mocked_hook(op, context, {})

    assert result["query_tags"] == "airflow_dag_id:d,airflow_task_id:t,airflow_run_id:r"


def test_query_tags_injection_with_partial_context():
    op = DatabricksCopyIntoOperator(
        task_id=TASK_ID,
        file_location=COPY_FILE_LOCATION,
        file_format="JSON",
        table_name="test",
    )
    context = _make_context(task_id="only_task")

    result = _run_with_mocked_hook(op, context, {})

    assert result["query_tags"] == "airflow_task_id:only_task"


def test_query_tags_injection_with_empty_context():
    op = DatabricksCopyIntoOperator(
        task_id=TASK_ID,
        file_location=COPY_FILE_LOCATION,
        file_format="JSON",
        table_name="test",
    )

    result = _run_with_mocked_hook(op, {}, {"unrelated": "keep"})

    assert result == {"unrelated": "keep"}


def test_query_tags_injection_escapes_special_chars():
    op = DatabricksCopyIntoOperator(
        task_id=TASK_ID,
        file_location=COPY_FILE_LOCATION,
        file_format="JSON",
        table_name="test",
    )
    context = _make_context(
        dag_id="dag,with,commas",
        task_id="task:with:colons",
        run_id="run\\with\\backslashes",
    )

    result = _run_with_mocked_hook(op, context, {})

    assert result["query_tags"] == (
        "airflow_dag_id:dag\\,with\\,commas,"
        "airflow_task_id:task\\:with\\:colons,"
        "airflow_run_id:run\\\\with\\\\backslashes"
    )


def test_query_tags_injection_preserves_unrelated_session_config():
    op = DatabricksCopyIntoOperator(
        task_id=TASK_ID,
        file_location=COPY_FILE_LOCATION,
        file_format="JSON",
        table_name="test",
    )
    context = _make_context(dag_id="d", task_id="t", run_id="r")
    initial = {"spark.sql.shuffle.partitions": "200", "query_tags": "x:y"}

    result = _run_with_mocked_hook(op, context, initial)

    assert result["spark.sql.shuffle.partitions"] == "200"
    assert result["query_tags"] == "x:y,airflow_dag_id:d,airflow_task_id:t,airflow_run_id:r"


def test_query_tags_injection_falls_back_to_conn_extra_when_session_config_none():
    op = DatabricksCopyIntoOperator(
        task_id=TASK_ID,
        file_location=COPY_FILE_LOCATION,
        file_format="JSON",
        table_name="test",
    )
    context = _make_context(dag_id="d", task_id="t", run_id="r")

    result = _run_with_mocked_hook(
        op,
        context,
        initial_session_config=None,
        conn_extra={"session_configuration": {"query_tags": "conn_tag:1"}},
    )

    assert result["query_tags"] == (
        "conn_tag:1,airflow_dag_id:d,airflow_task_id:t,airflow_run_id:r"
    )


def test_query_tags_injection_disabled():
    op = DatabricksCopyIntoOperator(
        task_id=TASK_ID,
        file_location=COPY_FILE_LOCATION,
        file_format="JSON",
        table_name="test",
        inject_query_tags=False,
    )
    context = _make_context(dag_id="d", task_id="t", run_id="r")

    result = _run_with_mocked_hook(op, context, {"query_tags": "user_tag:value"})

    assert result == {"query_tags": "user_tag:value"}


@pytest.mark.parametrize(
("file_location", "expected_namespace", "expected_name"),
(