-
Notifications
You must be signed in to change notification settings - Fork 17.1k
Add Databricks query tags for DatabricksSqlOperator and DatabricksCopyIntoOperator #66886
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
200733c
af9bd0b
fbbc6b8
d425a49
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -45,6 +45,64 @@ | |
| _IDENTIFIER_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$") | ||
| _DISALLOWED_SQL_TOKENS = (";", "--", "/*", "*/") | ||
|
|
||
| _QUERY_TAG_FIELDS = { | ||
| "airflow_dag_id": ("dag", "dag_id"), | ||
| "airflow_task_id": ("task", "task_id"), | ||
| "airflow_run_id": ("run_id", None), | ||
| } | ||
|
|
||
| _QUERY_TAG_ESCAPE_SEQUENCES = { | ||
| "\\": "\\\\", | ||
| ",": "\\,", | ||
| ":": "\\:", | ||
| } | ||
|
|
||
|
|
||
| def _escape_query_tag_value(value: str) -> str: | ||
| escaped = str(value) | ||
|
|
||
| for char, replacement in _QUERY_TAG_ESCAPE_SEQUENCES.items(): | ||
| escaped = escaped.replace(char, replacement) | ||
|
|
||
| return escaped | ||
|
|
||
|
|
||
| def _format_query_tags(context: Context) -> str: | ||
| tags = [] | ||
|
|
||
| for tag_name, (context_key, attr) in _QUERY_TAG_FIELDS.items(): | ||
| value = context.get(context_key) | ||
|
|
||
| if attr: | ||
| value = getattr(value, attr, None) | ||
|
|
||
| if value: | ||
| tags.append(f"{tag_name}:{_escape_query_tag_value(value)}") | ||
|
|
||
| return ",".join(tags) | ||
|
|
||
|
|
||
| def _merge_query_tags(session_config: dict[str, Any], query_tags: str) -> dict[str, Any]: | ||
| """Return a copied session config with Airflow query tags appended.""" | ||
| updated_config = session_config.copy() | ||
| existing_tags = updated_config.get("query_tags", "") | ||
| updated_config["query_tags"] = f"{existing_tags},{query_tags}" if existing_tags else query_tags | ||
| return updated_config | ||
|
|
||
|
|
||
def _inject_query_tags(hook: DatabricksSqlHook, context: Context) -> None:
    """Inject Airflow context metadata into the hook's Databricks query tags.

    No-op when the context yields no tags. When the hook has no session
    config yet, the ``session_configuration`` from the connection extras (if
    any) is used as the starting point so user-defined settings are kept.
    A non-dict session config is left untouched.
    """
    tags = _format_query_tags(context)
    if not tags:
        return

    if hook.session_config is None:
        # Seed from the connection's extras so user-defined query_tags (and
        # other session settings) survive the merge below.
        extras = hook.databricks_conn.extra_dejson
        hook.session_config = extras.get("session_configuration", {})

    config = hook.session_config
    if isinstance(config, dict):
        hook.session_config = _merge_query_tags(config, tags)
|
|
||
|
|
||
| class DatabricksSqlOperator(SQLExecuteQueryOperator): | ||
| """ | ||
|
|
@@ -83,6 +141,11 @@ class DatabricksSqlOperator(SQLExecuteQueryOperator): | |
| :param gcs_impersonation_chain: Optional service account to impersonate using short-term | ||
| credentials for GCS upload, or chained list of accounts required to get the access_token | ||
| of the last account in the list, which will be impersonated in the request. (templated) | ||
| :param inject_query_tags: If ``True`` (default), Airflow context metadata | ||
| (``airflow_dag_id``, ``airflow_task_id``, ``airflow_run_id``) is injected into the | ||
| Databricks session ``query_tags`` at execution time, preserving any user-defined | ||
| ``query_tags`` already present in ``session_configuration``. Set to ``False`` to | ||
| retain full control over ``session_configuration`` and skip the automatic injection. | ||
| """ | ||
|
|
||
| template_fields: Sequence[str] = tuple( | ||
|
|
@@ -117,6 +180,7 @@ def __init__( | |
| client_parameters: dict[str, Any] | None = None, | ||
| gcp_conn_id: str = "google_cloud_default", | ||
| gcs_impersonation_chain: str | Sequence[str] | None = None, | ||
| inject_query_tags: bool = True, | ||
| **kwargs, | ||
| ) -> None: | ||
| super().__init__(conn_id=databricks_conn_id, **kwargs) | ||
|
|
@@ -134,6 +198,7 @@ def __init__( | |
| self.schema = schema | ||
| self._gcp_conn_id = gcp_conn_id | ||
| self._gcs_impersonation_chain = gcs_impersonation_chain | ||
| self.inject_query_tags = inject_query_tags | ||
|
|
||
| @cached_property | ||
| def _hook(self) -> DatabricksSqlHook: | ||
|
|
@@ -153,6 +218,11 @@ def _hook(self) -> DatabricksSqlHook: | |
| def get_db_hook(self) -> DatabricksSqlHook: | ||
| return self._hook | ||
|
|
||
| def execute(self, context: Context) -> Any: | ||
| if self.inject_query_tags: | ||
| _inject_query_tags(self.get_db_hook(), context) | ||
| return super().execute(context) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we want this behavior configurable (operator/provider-level opt-out)? Since this mutates
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good point — making this configurable makes sense so users retain explicit control over |
||
|
|
||
| def _should_run_output_processing(self) -> bool: | ||
| return self.do_xcom_push or bool(self._output_path) | ||
|
|
||
|
|
@@ -348,6 +418,11 @@ class DatabricksCopyIntoOperator(BaseOperator): | |
| :param validate: optional configuration for schema & data validation. ``True`` forces validation | ||
| of all rows, integer number - validate only N first rows | ||
| :param copy_options: optional dictionary of copy options. Right now only ``force`` option is supported. | ||
| :param inject_query_tags: If ``True`` (default), Airflow context metadata | ||
| (``airflow_dag_id``, ``airflow_task_id``, ``airflow_run_id``) is injected into the | ||
| Databricks session ``query_tags`` at execution time, preserving any user-defined | ||
| ``query_tags`` already present in ``session_configuration``. Set to ``False`` to | ||
| retain full control over ``session_configuration`` and skip the automatic injection. | ||
| """ | ||
|
|
||
| template_fields: Sequence[str] = ( | ||
|
|
@@ -381,6 +456,7 @@ def __init__( | |
| force_copy: bool | None = None, | ||
| copy_options: dict[str, str] | None = None, | ||
| validate: bool | int | None = None, | ||
| inject_query_tags: bool = True, | ||
| **kwargs, | ||
| ) -> None: | ||
| """Create a new ``DatabricksSqlOperator``.""" | ||
|
|
@@ -415,6 +491,7 @@ def __init__( | |
| self._client_parameters = client_parameters or {} | ||
| if force_copy is not None: | ||
| self._copy_options["force"] = "true" if force_copy else "false" | ||
| self.inject_query_tags = inject_query_tags | ||
| self._sql: str | None = None | ||
|
|
||
| def _get_hook(self) -> DatabricksSqlHook: | ||
|
|
@@ -518,6 +595,8 @@ def execute(self, context: Context) -> Any: | |
| self._sql = self._create_sql_query() | ||
| self.log.info("Executing: %s", self._sql) | ||
| hook = self._get_hook() | ||
| if self.inject_query_tags: | ||
| _inject_query_tags(hook, context) | ||
| hook.run(self._sql) | ||
|
|
||
| def on_kill(self) -> None: | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this could be made more clear and explicit via a mapping driven approach. Please see the below for guidance:
Also, you could do the same for `_escape_query_tag_value`:
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the suggestion @SameerMesiah97! I agree the mapping-driven approach is cleaner and more maintainable. I've gone ahead and refactored both
`_format_query_tags` and `_escape_query_tag_value` to use the `_QUERY_TAG_FIELDS` and `_QUERY_TAG_ESCAPE_SEQUENCES` mappings exactly as you suggested — this is included in commit 732df03 ("Refactor Databricks query tag helper utilities"). Please take another look and let me know if you'd like any further tweaks.