Allow AWS Operator RedshiftToS3Transfer To Run a Custom Query (#14177)
Co-authored-by: Arati Nagmal <arati@stepfunction.ai>
AratiNagmal and Arati Nagmal committed Feb 11, 2021
1 parent a3b9f1e commit 1b14726
Showing 2 changed files with 88 additions and 13 deletions.
40 changes: 28 additions & 12 deletions airflow/providers/amazon/aws/transfers/redshift_to_s3.py
@@ -29,15 +29,19 @@ class RedshiftToS3Operator(BaseOperator):
"""
Executes an UNLOAD command to s3 as a CSV with headers
:param schema: reference to a specific schema in redshift database
:type schema: str
:param table: reference to a specific table in redshift database
:type table: str
:param s3_bucket: reference to a specific S3 bucket
:type s3_bucket: str
:param s3_key: reference to a specific S3 key. If ``table_as_file_name`` is set
to False, this param must include the desired file name
:type s3_key: str
:param schema: reference to a specific schema in redshift database
Applicable when ``table`` param is provided.
:type schema: str
:param table: reference to a specific table in redshift database
Used when ``select_query`` param is not provided.
:type table: str
:param select_query: custom select query to fetch data from redshift database
:type select_query: str
:param redshift_conn_id: reference to a specific redshift database
:type redshift_conn_id: str
:param aws_conn_id: reference to a specific S3 connection
@@ -63,7 +67,8 @@ class RedshiftToS3Operator(BaseOperator):
:type autocommit: bool
:param include_header: If set to True the s3 file contains the header columns.
:type include_header: bool
:param table_as_file_name: If set to True, the s3 file will be named as the table
:param table_as_file_name: If set to True, the s3 file will be named as the table.
Applicable when ``table`` param is provided.
:type table_as_file_name: bool
"""

@@ -75,10 +80,11 @@ class RedshiftToS3Operator(BaseOperator):
def __init__( # pylint: disable=too-many-arguments
self,
*,
schema: str,
table: str,
s3_bucket: str,
s3_key: str,
schema: str = None,
table: str = None,
select_query: str = None,
redshift_conn_id: str = 'redshift_default',
aws_conn_id: str = 'aws_default',
verify: Optional[Union[bool, str]] = None,
@@ -89,10 +95,10 @@ def __init__( # pylint: disable=too-many-arguments
**kwargs,
) -> None:
super().__init__(**kwargs)
self.s3_bucket = s3_bucket
self.s3_key = f'{s3_key}/{table}_' if (table and table_as_file_name) else s3_key
self.schema = schema
self.table = table
self.s3_bucket = s3_bucket
self.s3_key = s3_key
self.redshift_conn_id = redshift_conn_id
self.aws_conn_id = aws_conn_id
self.verify = verify
@@ -101,6 +107,16 @@ def __init__( # pylint: disable=too-many-arguments
self.include_header = include_header
self.table_as_file_name = table_as_file_name

self._select_query = None
if select_query:
self._select_query = select_query
elif self.schema and self.table:
self._select_query = f"SELECT * FROM {self.schema}.{self.table}"
else:
raise ValueError(
'Please provide both `schema` and `table` params or `select_query` to fetch the data.'
)

if self.include_header and 'HEADER' not in [uo.upper().strip() for uo in self.unload_options]:
self.unload_options = list(self.unload_options) + [
'HEADER',
@@ -124,10 +140,10 @@ def execute(self, context) -> None:
credentials = s3_hook.get_credentials()
credentials_block = build_credentials_block(credentials)
unload_options = '\n\t\t\t'.join(self.unload_options)
s3_key = f"{self.s3_key}/{self.table}_" if self.table_as_file_name else self.s3_key
select_query = f"SELECT * FROM {self.schema}.{self.table}"

unload_query = self._build_unload_query(credentials_block, select_query, s3_key, unload_options)
unload_query = self._build_unload_query(
credentials_block, self._select_query, self.s3_key, unload_options
)

self.log.info('Executing UNLOAD command...')
postgres_hook.run(unload_query, self.autocommit)
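
To make the new control flow concrete: the query is now resolved once in __init__ (an explicit select_query wins, otherwise SELECT * FROM schema.table is built), and execute simply hands self._select_query to _build_unload_query. For the sketch above, the statement sent to Redshift would look roughly like the string below; the exact formatting and credentials string come from _build_unload_query and build_credentials_block, which are outside this diff, so treat it as illustrative only.

# Illustrative shape only (credentials shortened to '...').
unload_query = """
    UNLOAD ('SELECT order_id, amount FROM public.orders WHERE amount > 100')
    TO 's3://my-data-lake/exports/recent_orders'
    credentials
    'aws_access_key_id=...;aws_secret_access_key=...'
    HEADER;
"""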
61 changes: 60 additions & 1 deletion tests/providers/amazon/aws/transfers/test_redshift_to_s3.py
@@ -37,7 +37,7 @@ class TestRedshiftToS3Transfer(unittest.TestCase):
)
@mock.patch("boto3.session.Session")
@mock.patch("airflow.providers.postgres.hooks.postgres.PostgresHook.run")
def test_execute(
def test_table_unloading(
self,
table_as_file_name,
expected_s3_key,
@@ -147,6 +147,65 @@ def test_execute_sts_token(
assert token in unload_query
assert_equal_ignore_multiple_spaces(self, mock_run.call_args[0][0], unload_query)

@parameterized.expand(
[
["table", True, "key/table_"],
["table", False, "key"],
[None, False, "key"],
[None, True, "key"],
]
)
@mock.patch("boto3.session.Session")
@mock.patch("airflow.providers.postgres.hooks.postgres.PostgresHook.run")
def test_custom_select_query_unloading(
self,
table,
table_as_file_name,
expected_s3_key,
mock_run,
mock_session,
):
access_key = "aws_access_key_id"
secret_key = "aws_secret_access_key"
mock_session.return_value = Session(access_key, secret_key)
mock_session.return_value.access_key = access_key
mock_session.return_value.secret_key = secret_key
mock_session.return_value.token = None
s3_bucket = "bucket"
s3_key = "key"
unload_options = [
'HEADER',
]
select_query = "select column from table"

op = RedshiftToS3Operator(
select_query=select_query,
table=table,
table_as_file_name=table_as_file_name,
s3_bucket=s3_bucket,
s3_key=s3_key,
unload_options=unload_options,
include_header=True,
redshift_conn_id="redshift_conn_id",
aws_conn_id="aws_conn_id",
task_id="task_id",
dag=None,
)

op.execute(None)

unload_options = '\n\t\t\t'.join(unload_options)
credentials_block = build_credentials_block(mock_session.return_value)

unload_query = op._build_unload_query(
credentials_block, select_query, expected_s3_key, unload_options
)

assert mock_run.call_count == 1
assert access_key in unload_query
assert secret_key in unload_query
assert_equal_ignore_multiple_spaces(self, mock_run.call_args[0][0], unload_query)

def test_template_fields_overrides(self):
assert RedshiftToS3Operator.template_fields == (
's3_bucket',

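As a closing note on the validation added in __init__: constructing the operator with neither select_query nor a schema/table pair now fails fast. A minimal sketch of that behaviour, with hypothetical identifiers:

from airflow.providers.amazon.aws.transfers.redshift_to_s3 import RedshiftToS3Operator

try:
    RedshiftToS3Operator(
        task_id="missing_source",
        s3_bucket="bucket",
        s3_key="key",
        # neither select_query nor schema/table is supplied
    )
except ValueError as err:
    # Expected message, per the diff above:
    # "Please provide both `schema` and `table` params or `select_query` to fetch the data."
    print(err)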