Skip to content

Commit

Permalink
Add sql_hook_params parameter to SqlToS3Operator (#33425)
Browse files Browse the repository at this point in the history
Adding `sql_hook_params` parameter to `SqlToS3Operator`. This will allow you to pass extra config params to the underlying SQL hook.
  • Loading branch information
alexbegg committed Aug 16, 2023
1 parent 70e3143 commit 45d5f64
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 1 deletion.
6 changes: 5 additions & 1 deletion airflow/providers/amazon/aws/transfers/sql_to_s3.py
Expand Up @@ -65,6 +65,8 @@ class SqlToS3Operator(BaseOperator):
:param s3_key: desired key for the file. It includes the name of the file. (templated)
:param replace: whether or not to replace the file in S3 if it previously existed
:param sql_conn_id: reference to a specific database.
:param sql_hook_params: Extra config params to be passed to the underlying hook.
Should match the desired hook constructor params.
:param parameters: (optional) the parameters to render the SQL query with.
:param aws_conn_id: reference to a specific S3 connection
:param verify: Whether or not to verify SSL certificates for S3 connection.
Expand Down Expand Up @@ -100,6 +102,7 @@ def __init__(
s3_bucket: str,
s3_key: str,
sql_conn_id: str,
sql_hook_params: dict | None = None,
parameters: None | Mapping | Iterable = None,
replace: bool = False,
aws_conn_id: str = "aws_default",
Expand All @@ -120,6 +123,7 @@ def __init__(
self.pd_kwargs = pd_kwargs or {}
self.parameters = parameters
self.groupby_kwargs = groupby_kwargs or {}
self.sql_hook_params = sql_hook_params

if "path_or_buf" in self.pd_kwargs:
raise AirflowException("The argument path_or_buf is not allowed, please remove it")
Expand Down Expand Up @@ -200,7 +204,7 @@ def _partition_dataframe(self, df: DataFrame) -> Iterable[tuple[str, DataFrame]]
def _get_hook(self) -> DbApiHook:
self.log.debug("Get connection for %s", self.sql_conn_id)
conn = BaseHook.get_connection(self.sql_conn_id)
hook = conn.get_hook()
hook = conn.get_hook(hook_params=self.sql_hook_params)
if not callable(getattr(hook, "get_pandas_df", None)):
raise AirflowException(
"This hook is not supported. The hook class must have get_pandas_df method."
Expand Down
18 changes: 18 additions & 0 deletions tests/providers/amazon/aws/transfers/test_sql_to_s3.py
Expand Up @@ -24,6 +24,7 @@
import pytest

from airflow.exceptions import AirflowException
from airflow.models import Connection
from airflow.providers.amazon.aws.transfers.sql_to_s3 import SqlToS3Operator


Expand Down Expand Up @@ -269,3 +270,20 @@ def test_without_groupby_kwarg(self):
}
)
)

@mock.patch("airflow.providers.common.sql.operators.sql.BaseHook.get_connection")
def test_hook_params(self, mock_get_conn):
mock_get_conn.return_value = Connection(conn_id="postgres_test", conn_type="postgres")
op = SqlToS3Operator(
query="query",
s3_bucket="bucket",
s3_key="key",
sql_conn_id="postgres_test",
task_id="task_id",
sql_hook_params={
"log_sql": False,
},
dag=None,
)
hook = op._get_hook()
assert hook.log_sql == op.sql_hook_params["log_sql"]

0 comments on commit 45d5f64

Please sign in to comment.