Skip to content

Commit

Permalink
Defer to hook setting for split_statements in SQLExecuteQueryOperator (
Browse files Browse the repository at this point in the history
…#28635)

Some databases, such as snowflake, require you to split statements in order to submit multi-statement sql.  For such databases, splitting is the natural default, and we should defer to the hook to control that.
  • Loading branch information
dstandish committed Dec 30, 2022
1 parent 2aa52f4 commit 2e7b9f5
Show file tree
Hide file tree
Showing 9 changed files with 9 additions and 14 deletions.
11 changes: 8 additions & 3 deletions airflow/providers/common/sql/operators/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,8 @@ class SQLExecuteQueryOperator(BaseSQLOperator):
:param autocommit: (optional) if True, each command is automatically committed (default: False).
:param parameters: (optional) the parameters to render the SQL query with.
:param handler: (optional) the function that will be applied to the cursor (default: fetch_all_handler).
:param split_statements: (optional) if split single SQL string into statements (default: False).
:param split_statements: (optional) if split single SQL string into statements. By default, defers
to the default value in the ``run`` method of the configured hook.
:param return_last: (optional) return the result of only last statement (default: True).
.. seealso::
Expand All @@ -218,7 +219,7 @@ def __init__(
autocommit: bool = False,
parameters: Mapping | Iterable | None = None,
handler: Callable[[Any], Any] = fetch_all_handler,
split_statements: bool = False,
split_statements: bool | None = None,
return_last: bool = True,
**kwargs,
) -> None:
Expand Down Expand Up @@ -252,13 +253,17 @@ def _process_output(self, results: list[Any], descriptions: list[Sequence[Sequen
def execute(self, context):
self.log.info("Executing: %s", self.sql)
hook = self.get_db_hook()
if self.split_statements is not None:
extra_kwargs = {"split_statements": self.split_statements}
else:
extra_kwargs = {}
output = hook.run(
sql=self.sql,
autocommit=self.autocommit,
parameters=self.parameters,
handler=self.handler if self.do_xcom_push else None,
split_statements=self.split_statements,
return_last=self.return_last,
**extra_kwargs,
)
if return_single_query_results(self.sql, self.return_last, self.split_statements):
# For simplicity, we pass always list as input to _process_output, regardless if
Expand Down
2 changes: 1 addition & 1 deletion airflow/providers/common/sql/operators/sql.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ class SQLExecuteQueryOperator(BaseSQLOperator):
autocommit: bool = ...,
parameters: Union[Mapping, Iterable, None] = ...,
handler: Callable[[Any], Any] = ...,
split_statements: bool = ...,
split_statements: Union[bool, None] = ...,
return_last: bool = ...,
**kwargs,
) -> None: ...
Expand Down
1 change: 0 additions & 1 deletion tests/providers/amazon/aws/operators/test_redshift_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,5 +43,4 @@ def test_redshift_operator(self, mock_get_hook, test_autocommit, test_parameters
parameters=test_parameters,
handler=fetch_all_handler,
return_last=True,
split_statements=False,
)
2 changes: 0 additions & 2 deletions tests/providers/common/sql/operators/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ def test_do_xcom_push(self, mock_get_db_hook):
handler=fetch_all_handler,
parameters=None,
return_last=True,
split_statements=False,
)

@mock.patch.object(SQLExecuteQueryOperator, "get_db_hook")
Expand All @@ -87,7 +86,6 @@ def test_dont_xcom_push(self, mock_get_db_hook):
sql="SELECT 1;",
autocommit=False,
parameters=None,
split_statements=False,
handler=None,
return_last=True,
)
Expand Down
2 changes: 0 additions & 2 deletions tests/providers/exasol/operators/test_exasol.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ def test_overwrite_autocommit(self, mock_get_db_hook):
parameters=None,
handler=fetch_all_handler,
return_last=True,
split_statements=False,
)

@mock.patch("airflow.providers.common.sql.operators.sql.SQLExecuteQueryOperator.get_db_hook")
Expand All @@ -47,7 +46,6 @@ def test_pass_parameters(self, mock_get_db_hook):
parameters={"value": 1},
handler=fetch_all_handler,
return_last=True,
split_statements=False,
)

@mock.patch("airflow.providers.common.sql.operators.sql.BaseSQLOperator.__init__")
Expand Down
2 changes: 0 additions & 2 deletions tests/providers/jdbc/operators/test_jdbc.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ def test_execute_do_push(self, mock_get_db_hook):
handler=fetch_all_handler,
parameters=jdbc_operator.parameters,
return_last=True,
split_statements=False,
)

@patch("airflow.providers.common.sql.operators.sql.SQLExecuteQueryOperator.get_db_hook")
Expand All @@ -52,5 +51,4 @@ def test_execute_dont_push(self, mock_get_db_hook):
parameters=jdbc_operator.parameters,
handler=None,
return_last=True,
split_statements=False,
)
1 change: 0 additions & 1 deletion tests/providers/oracle/operators/test_oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ def test_execute(self, mock_get_db_hook):
parameters=parameters,
handler=fetch_all_handler,
return_last=True,
split_statements=False,
)


Expand Down
1 change: 0 additions & 1 deletion tests/providers/trino/operators/test_trino.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,5 +47,4 @@ def test_execute(self, mock_get_db_hook):
handler=list,
parameters=None,
return_last=True,
split_statements=False,
)
1 change: 0 additions & 1 deletion tests/providers/vertica/operators/test_vertica.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,5 +35,4 @@ def test_execute(self, mock_get_db_hook):
handler=fetch_all_handler,
parameters=None,
return_last=True,
split_statements=False,
)

0 comments on commit 2e7b9f5

Please sign in to comment.