Add a new parameter to SQL operators to specify conn id field (#30784)
hussein-awala committed Aug 7, 2023
1 parent 5387c0c commit 9736143
Showing 7 changed files with 74 additions and 8 deletions.
11 changes: 7 additions & 4 deletions airflow/providers/common/sql/operators/sql.py
@@ -123,6 +123,8 @@ class BaseSQLOperator(BaseOperator):
:param conn_id: reference to a specific database
"""

conn_id_field = "conn_id"

def __init__(
self,
*,
@@ -141,8 +143,9 @@ def __init__(
@cached_property
def _hook(self):
"""Get DB Hook based on connection type."""
- self.log.debug("Get connection for %s", self.conn_id)
- conn = BaseHook.get_connection(self.conn_id)
+ conn_id = getattr(self, self.conn_id_field)
+ self.log.debug("Get connection for %s", conn_id)
+ conn = BaseHook.get_connection(conn_id)
hook = conn.get_hook(hook_params=self.hook_params)
if not isinstance(hook, DbApiHook):
from airflow.hooks.dbapi_hook import DbApiHook as _DbApiHook
@@ -411,7 +414,7 @@ class SQLColumnCheckOperator(BaseSQLOperator):
:ref:`howto/operator:SQLColumnCheckOperator`
"""

template_fields = ("partition_clause", "table", "sql")
template_fields: Sequence[str] = ("partition_clause", "table", "sql")
template_fields_renderers = {"sql": "sql"}

sql_check_template = """
@@ -639,7 +642,7 @@ class SQLTableCheckOperator(BaseSQLOperator):
:ref:`howto/operator:SQLTableCheckOperator`
"""

template_fields = ("partition_clause", "table", "sql", "conn_id")
template_fields: Sequence[str] = ("partition_clause", "table", "sql", "conn_id")

template_fields_renderers = {"sql": "sql"}

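How the new hook resolution works: BaseSQLOperator._hook now reads the connection id from whatever attribute is named by the new conn_id_field class attribute, instead of hard-coding self.conn_id. Provider operators that keep their connection id under a different name only have to point conn_id_field at it. A minimal sketch of a subclass opting in — the class, attribute, and connection names below are illustrative only, not part of this commit:

from airflow.providers.common.sql.operators.sql import BaseSQLOperator


class MyVendorSQLOperator(BaseSQLOperator):
    # Tell BaseSQLOperator._hook which attribute holds the connection id.
    conn_id_field = "my_vendor_conn_id"

    def __init__(self, *, my_vendor_conn_id: str = "my_vendor_default", **kwargs):
        super().__init__(**kwargs)
        self.my_vendor_conn_id = my_vendor_conn_id


# At runtime the cached _hook property then does, roughly:
#   conn_id = getattr(self, self.conn_id_field)   # -> "my_vendor_default"
#   conn = BaseHook.get_connection(conn_id)
#   hook = conn.get_hook(hook_params=self.hook_params)
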
1 change: 1 addition & 0 deletions airflow/providers/databricks/operators/databricks_sql.py
@@ -77,6 +77,7 @@ class DatabricksSqlOperator(SQLExecuteQueryOperator):

template_ext: Sequence[str] = (".sql",)
template_fields_renderers = {"sql": "sql"}
conn_id_field = "databricks_conn_id"

def __init__(
self,
4 changes: 3 additions & 1 deletion airflow/providers/exasol/operators/exasol.py
@@ -38,10 +38,11 @@ class ExasolOperator(SQLExecuteQueryOperator):
:param handler: (optional) handler to process the results of the query
"""

template_fields: Sequence[str] = ("sql",)
template_fields: Sequence[str] = ("sql", "exasol_conn_id")
template_ext: Sequence[str] = (".sql",)
template_fields_renderers = {"sql": "sql"}
ui_color = "#ededed"
conn_id_field = "exasol_conn_id"

def __init__(
self,
@@ -51,6 +52,7 @@ def __init__(
handler=exasol_fetch_all_handler,
**kwargs,
) -> None:
self.exasol_conn_id = exasol_conn_id
if schema is not None:
hook_params = kwargs.pop("hook_params", {})
kwargs["hook_params"] = {"schema": schema, **hook_params}
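
For Exasol the connection id is both stored on the instance (so the conn_id_field lookup above can find it) and added to template_fields, which makes the connection id itself templatable. A hedged usage sketch — the Jinja expression and connection ids are assumptions for illustration, not part of this commit:

from airflow.providers.exasol.operators.exasol import ExasolOperator

run_report = ExasolOperator(
    task_id="run_report",
    sql="SELECT 1",
    # Rendered by Jinja because "exasol_conn_id" is now in template_fields.
    exasol_conn_id="{{ var.value.get('exasol_conn', 'exasol_default') }}",
)
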
11 changes: 11 additions & 0 deletions airflow/providers/google/cloud/operators/bigquery.py
@@ -253,6 +253,7 @@ class BigQueryCheckOperator(_BigQueryDbHookMixin, SQLCheckOperator):
)
template_ext: Sequence[str] = (".sql",)
ui_color = BigQueryUIColors.CHECK.value
conn_id_field = "gcp_conn_id"

def __init__(
self,
@@ -371,6 +372,7 @@ class BigQueryValueCheckOperator(_BigQueryDbHookMixin, SQLValueCheckOperator):
)
template_ext: Sequence[str] = (".sql",)
ui_color = BigQueryUIColors.CHECK.value
conn_id_field = "gcp_conn_id"

def __init__(
self,
@@ -509,6 +511,7 @@ class BigQueryIntervalCheckOperator(_BigQueryDbHookMixin, SQLIntervalCheckOperator):
"labels",
)
ui_color = BigQueryUIColors.CHECK.value
conn_id_field = "gcp_conn_id"

def __init__(
self,
@@ -634,6 +637,10 @@ class BigQueryColumnCheckOperator(_BigQueryDbHookMixin, SQLColumnCheckOperator):
:param labels: a dictionary containing labels for the table, passed to BigQuery
"""

template_fields: Sequence[str] = tuple(set(SQLColumnCheckOperator.template_fields) | {"gcp_conn_id"})

conn_id_field = "gcp_conn_id"

def __init__(
self,
*,
@@ -757,6 +764,10 @@ class BigQueryTableCheckOperator(_BigQueryDbHookMixin, SQLTableCheckOperator):
:param labels: a dictionary containing labels for the table, passed to BigQuery
"""

template_fields: Sequence[str] = tuple(set(SQLTableCheckOperator.template_fields) | {"gcp_conn_id"})

conn_id_field = "gcp_conn_id"

def __init__(
self,
*,
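
The BigQuery check operators extend the parent operator's template_fields with a set union rather than restating the whole tuple. Because the fields pass through a set, the order of the resulting tuple is not guaranteed. A small sketch of what the expression evaluates to, assuming the SQLColumnCheckOperator definition shown earlier in this commit:

from airflow.providers.common.sql.operators.sql import SQLColumnCheckOperator

merged = tuple(set(SQLColumnCheckOperator.template_fields) | {"gcp_conn_id"})
# Contains "partition_clause", "table", "sql" and "gcp_conn_id", in arbitrary order.
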
11 changes: 9 additions & 2 deletions airflow/providers/qubole/operators/qubole_check.py
@@ -103,8 +103,10 @@ class QuboleCheckOperator(_QuboleCheckOperatorMixin, SQLCheckOperator, QuboleOperator):
"""

conn_id_field = "qubole_conn_id"

template_fields: Sequence[str] = tuple(
- set(QuboleOperator.template_fields) | set(SQLCheckOperator.template_fields)
+ set(QuboleOperator.template_fields) | set(SQLCheckOperator.template_fields) | {"qubole_conn_id"}
)
template_ext = QuboleOperator.template_ext
ui_fgcolor = "#000"
@@ -123,6 +125,7 @@ def __init__(
self.on_failure_callback = QuboleCheckHook.handle_failure_retry
self.on_retry_callback = QuboleCheckHook.handle_failure_retry
self._hook_context = None
self.qubole_conn_id = qubole_conn_id


# TODO(xinbinhuang): refactor to reduce levels of inheritance
@@ -155,9 +158,12 @@ class QuboleValueCheckOperator(_QuboleCheckOperatorMixin, SQLValueCheckOperator, QuboleOperator):
QuboleOperator and SQLValueCheckOperator are template-supported.
"""

- template_fields = tuple(set(QuboleOperator.template_fields) | set(SQLValueCheckOperator.template_fields))
+ template_fields = tuple(
+ set(QuboleOperator.template_fields) | set(SQLValueCheckOperator.template_fields) | {"qubole_conn_id"}
+ )
template_ext = QuboleOperator.template_ext
ui_fgcolor = "#000"
conn_id_field = "qubole_conn_id"

def __init__(
self,
@@ -177,6 +183,7 @@ def __init__(
self.on_failure_callback = QuboleCheckHook.handle_failure_retry
self.on_retry_callback = QuboleCheckHook.handle_failure_retry
self._hook_context = None
self.qubole_conn_id = qubole_conn_id


def get_sql_from_qbol_cmd(params) -> str:
12 changes: 11 additions & 1 deletion airflow/providers/snowflake/operators/snowflake.py
@@ -192,9 +192,10 @@ class SnowflakeCheckOperator(SQLCheckOperator):
the time you connect to Snowflake
"""

template_fields: Sequence[str] = ("sql",)
template_fields: Sequence[str] = tuple(set(SQLCheckOperator.template_fields) | {"snowflake_conn_id"})
template_ext: Sequence[str] = (".sql",)
ui_color = "#ededed"
conn_id_field = "snowflake_conn_id"

def __init__(
self,
@@ -259,6 +260,10 @@ class SnowflakeValueCheckOperator(SQLValueCheckOperator):
the time you connect to Snowflake
"""

template_fields: Sequence[str] = tuple(set(SQLValueCheckOperator.template_fields) | {"snowflake_conn_id"})

conn_id_field = "snowflake_conn_id"

def __init__(
self,
*,
@@ -333,6 +338,11 @@ class SnowflakeIntervalCheckOperator(SQLIntervalCheckOperator):
the time you connect to Snowflake
"""

template_fields: Sequence[str] = tuple(
set(SQLIntervalCheckOperator.template_fields) | {"snowflake_conn_id"}
)
conn_id_field = "snowflake_conn_id"

def __init__(
self,
*,
32 changes: 32 additions & 0 deletions tests/providers/common/sql/operators/test_sql.py
@@ -1278,3 +1278,35 @@ def test_with_skip_in_branch_downstream_dependencies2(self, mock_get_db_hook):
assert ti.state == State.NONE
else:
raise ValueError(f"Invalid task id {ti.task_id} found!")


class TestBaseSQLOperatorSubClass:

from airflow.providers.common.sql.operators.sql import BaseSQLOperator

class NewStyleBaseSQLOperatorSubClass(BaseSQLOperator):
"""New style subclass of BaseSQLOperator"""

conn_id_field = "custom_conn_id_field"

def __init__(self, custom_conn_id_field="test_conn", **kwargs):
super().__init__(**kwargs)
self.custom_conn_id_field = custom_conn_id_field

class OldStyleBaseSQLOperatorSubClass(BaseSQLOperator):
"""Old style subclass of BaseSQLOperator"""

def __init__(self, custom_conn_id_field="test_conn", **kwargs):
super().__init__(conn_id=custom_conn_id_field, **kwargs)

@pytest.mark.parametrize(
"operator_class", [NewStyleBaseSQLOperatorSubClass, OldStyleBaseSQLOperatorSubClass]
)
@mock.patch("airflow.hooks.base.BaseHook.get_connection")
def test_new_style_subclass(self, mock_get_connection, operator_class):
from airflow.providers.common.sql.hooks.sql import DbApiHook

op = operator_class(task_id="test_task")
mock_get_connection.return_value.get_hook.return_value = MagicMock(spec=DbApiHook)
op.get_db_hook()
mock_get_connection.assert_called_once_with("test_conn")
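
The parametrized test exercises both subclass styles: the new style declares conn_id_field = "custom_conn_id_field" and stores the id under that attribute, while the old style keeps routing it through the inherited conn_id; either way the hook is built from BaseHook.get_connection("test_conn"). As a usage note (the command is assumed, not part of the commit), the class can be run locally with: pytest tests/providers/common/sql/operators/test_sql.py -k TestBaseSQLOperatorSubClass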
