Skip to content

Commit

Permalink
RedshiftDataOperator replace await_result with `wait_for_completi…
Browse files Browse the repository at this point in the history
…on` (#29633)

* `RedshiftDataOperator` replace `await_result` with `wait_for_completion`
  • Loading branch information
eladkal committed Feb 20, 2023
1 parent dae7bf0 commit 45419e2
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 12 deletions.
20 changes: 15 additions & 5 deletions airflow/providers/amazon/aws/operators/redshift_data.py
Expand Up @@ -44,7 +44,7 @@ class RedshiftDataOperator(BaseOperator):
:param secret_arn: the name or ARN of the secret that enables db access
:param statement_name: the name of the SQL statement
:param with_event: indicates whether to send an event to EventBridge
:param await_result: indicates whether to wait for a result, if True wait, if False don't wait
:param wait_for_completion: indicates whether to wait for a result, if True wait, if False don't wait
:param poll_interval: how often in seconds to check the query status
:param aws_conn_id: aws connection to use
:param region: aws region to use
Expand Down Expand Up @@ -73,10 +73,11 @@ def __init__(
secret_arn: str | None = None,
statement_name: str | None = None,
with_event: bool = False,
await_result: bool = True,
wait_for_completion: bool = True,
poll_interval: int = 10,
aws_conn_id: str = "aws_default",
region: str | None = None,
await_result: bool | None = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
Expand All @@ -89,6 +90,15 @@ def __init__(
self.statement_name = statement_name
self.with_event = with_event
self.await_result = await_result
self.wait_for_completion = wait_for_completion
if await_result:
warnings.warn(
f"Parameter `{self.__class__.__name__}.await_result` is deprecated and will be removed "
"in a future release. Please use method `wait_for_completion` instead.",
DeprecationWarning,
stacklevel=2,
)
self.wait_for_completion = await_result
if poll_interval > 0:
self.poll_interval = poll_interval
else:
Expand Down Expand Up @@ -121,7 +131,7 @@ def execute_query(self) -> str:
secret_arn=self.secret_arn,
statement_name=self.statement_name,
with_event=self.with_event,
wait_for_completion=self.await_result,
wait_for_completion=self.wait_for_completion,
poll_interval=self.poll_interval,
)
return self.statement_id
Expand All @@ -142,7 +152,7 @@ def execute_batch_query(self) -> str:
secret_arn=self.secret_arn,
statement_name=self.statement_name,
with_event=self.with_event,
wait_for_completion=self.await_result,
wait_for_completion=self.wait_for_completion,
poll_interval=self.poll_interval,
)
return self.statement_id
Expand All @@ -169,7 +179,7 @@ def execute(self, context: Context) -> str:
secret_arn=self.secret_arn,
statement_name=self.statement_name,
with_event=self.with_event,
wait_for_completion=self.await_result,
wait_for_completion=self.wait_for_completion,
poll_interval=self.poll_interval,
)

Expand Down
27 changes: 22 additions & 5 deletions tests/providers/amazon/aws/operators/test_redshift_data.py
Expand Up @@ -19,6 +19,8 @@

from unittest import mock

import pytest

from airflow.providers.amazon.aws.operators.redshift_data import RedshiftDataOperator

CONN_ID = "aws_conn_test"
Expand All @@ -37,7 +39,7 @@ def test_execute(self, mock_exec_query):
statement_name = "statement_name"
parameters = [{"name": "id", "value": "1"}]
poll_interval = 5
await_result = True
wait_for_completion = True

operator = RedshiftDataOperator(
aws_conn_id=CONN_ID,
Expand All @@ -49,7 +51,7 @@ def test_execute(self, mock_exec_query):
secret_arn=secret_arn,
statement_name=statement_name,
parameters=parameters,
await_result=True,
wait_for_completion=True,
poll_interval=poll_interval,
)
operator.execute(None)
Expand All @@ -62,7 +64,7 @@ def test_execute(self, mock_exec_query):
statement_name=statement_name,
parameters=parameters,
with_event=False,
wait_for_completion=await_result,
wait_for_completion=wait_for_completion,
poll_interval=poll_interval,
)

Expand All @@ -74,7 +76,7 @@ def test_on_kill_without_query(self, mock_conn):
task_id=TASK_ID,
sql=SQL,
database=DATABASE,
await_result=False,
wait_for_completion=False,
)
operator.on_kill()
mock_conn.cancel_statement.assert_not_called()
Expand All @@ -87,10 +89,25 @@ def test_on_kill_with_query(self, mock_conn):
task_id=TASK_ID,
sql=SQL,
database=DATABASE,
await_result=False,
wait_for_completion=False,
)
operator.execute(None)
operator.on_kill()
mock_conn.cancel_statement.assert_called_once_with(
Id=STATEMENT_ID,
)

def test_deprecated_await_result_parameter(self):
warning_message = (
"Parameter `RedshiftDataOperator.await_result` is deprecated and will be removed "
"in a future release. Please use method `wait_for_completion` instead."
)
with pytest.warns(DeprecationWarning, match=warning_message):
op = RedshiftDataOperator(
task_id=TASK_ID,
aws_conn_id=CONN_ID,
sql=SQL,
database=DATABASE,
await_result=True,
)
assert op.wait_for_completion
4 changes: 2 additions & 2 deletions tests/system/providers/amazon/aws/example_redshift.py
Expand Up @@ -202,7 +202,7 @@ def delete_security_group(sec_group_id: str, sec_group_name: str):
);
""",
poll_interval=POLL_INTERVAL,
await_result=True,
wait_for_completion=True,
)
# [END howto_operator_redshift_data]

Expand All @@ -220,7 +220,7 @@ def delete_security_group(sec_group_id: str, sec_group_name: str):
INSERT INTO fruit VALUES ( 6, 'Strawberry', 'Red');
""",
poll_interval=POLL_INTERVAL,
await_result=True,
wait_for_completion=True,
)

# [START howto_operator_redshift_sql]
Expand Down

0 comments on commit 45419e2

Please sign in to comment.