Skip to content

Commit

Permalink
Add script_args for S3FileTransformOperator (#9019)
Browse files Browse the repository at this point in the history
Co-authored-by: Andrej Svec <asvec@slido.com>
  • Loading branch information
sweco and sweco committed May 28, 2020
1 parent 369e637 commit 1ed171b
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 48 deletions.
11 changes: 8 additions & 3 deletions airflow/providers/amazon/aws/operators/s3_file_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import subprocess
import sys
from tempfile import NamedTemporaryFile
from typing import Optional, Union
from typing import Optional, Sequence, Union

from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
Expand Down Expand Up @@ -52,6 +52,8 @@ class S3FileTransformOperator(BaseOperator):
:type transform_script: str
:param select_expression: S3 Select expression
:type select_expression: str
:param script_args: arguments for transformation script (templated)
:type script_args: sequence of str
:param source_aws_conn_id: source s3 connection
:type source_aws_conn_id: str
:param source_verify: Whether or not to verify SSL certificates for S3 connection.
Expand All @@ -76,7 +78,7 @@ class S3FileTransformOperator(BaseOperator):
:type replace: bool
"""

template_fields = ('source_s3_key', 'dest_s3_key')
template_fields = ('source_s3_key', 'dest_s3_key', 'script_args')
template_ext = ()
ui_color = '#f9c915'

Expand All @@ -87,12 +89,14 @@ def __init__(
dest_s3_key: str,
transform_script: Optional[str] = None,
select_expression=None,
script_args: Optional[Sequence[str]] = None,
source_aws_conn_id: str = 'aws_default',
source_verify: Optional[Union[bool, str]] = None,
dest_aws_conn_id: str = 'aws_default',
dest_verify: Optional[Union[bool, str]] = None,
replace: bool = False,
*args, **kwargs) -> None:
# pylint: disable=too-many-arguments
super().__init__(*args, **kwargs)
self.source_s3_key = source_s3_key
self.source_aws_conn_id = source_aws_conn_id
Expand All @@ -103,6 +107,7 @@ def __init__(
self.replace = replace
self.transform_script = transform_script
self.select_expression = select_expression
self.script_args = script_args or []
self.output_encoding = sys.getdefaultencoding()

def execute(self, context):
Expand Down Expand Up @@ -137,7 +142,7 @@ def execute(self, context):

if self.transform_script is not None:
process = subprocess.Popen(
[self.transform_script, f_source.name, f_dest.name],
[self.transform_script, f_source.name, f_dest.name, *self.script_args],
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
close_fds=True
Expand Down
99 changes: 54 additions & 45 deletions tests/providers/amazon/aws/operators/test_s3_file_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,25 +53,12 @@ def tearDown(self):
@mock_s3
def test_execute_with_transform_script(self, mock_log, mock_popen):
process_output = [b"Foo", b"Bar", b"Baz"]
self.mock_process(mock_popen, process_output=process_output)
input_path, output_path = self.s3_paths()

process = mock_popen.return_value
process.stdout.readline.side_effect = process_output
process.wait.return_value = None
process.returncode = 0

bucket = "bucket"
input_key = "foo"
output_key = "bar"
bio = io.BytesIO(b"input")

conn = boto3.client('s3')
conn.create_bucket(Bucket=bucket)
conn.upload_fileobj(Bucket=bucket, Key=input_key, Fileobj=bio)

s3_url = "s3://{0}/{1}"
op = S3FileTransformOperator(
source_s3_key=s3_url.format(bucket, input_key),
dest_s3_key=s3_url.format(bucket, output_key),
source_s3_key=input_path,
dest_s3_key=output_path,
transform_script=self.transform_script,
replace=True,
task_id="task_id")
Expand All @@ -84,24 +71,12 @@ def test_execute_with_transform_script(self, mock_log, mock_popen):
@mock.patch('subprocess.Popen')
@mock_s3
def test_execute_with_failing_transform_script(self, mock_popen):
process = mock_popen.return_value
process.stdout.readline.side_effect = []
process.wait.return_value = None
process.returncode = 42
self.mock_process(mock_popen, return_code=42)
input_path, output_path = self.s3_paths()

bucket = "bucket"
input_key = "foo"
output_key = "bar"
bio = io.BytesIO(b"input")

conn = boto3.client('s3')
conn.create_bucket(Bucket=bucket)
conn.upload_fileobj(Bucket=bucket, Key=input_key, Fileobj=bio)

s3_url = "s3://{0}/{1}"
op = S3FileTransformOperator(
source_s3_key=s3_url.format(bucket, input_key),
dest_s3_key=s3_url.format(bucket, output_key),
source_s3_key=input_path,
dest_s3_key=output_path,
transform_script=self.transform_script,
replace=True,
task_id="task_id")
Expand All @@ -111,9 +86,52 @@ def test_execute_with_failing_transform_script(self, mock_popen):

self.assertEqual('Transform script failed: 42', str(e.exception))

@mock.patch('subprocess.Popen')
@mock_s3
def test_execute_with_transform_script_args(self, mock_popen):
    """Extra ``script_args`` must be appended to the transform command line."""
    self.mock_process(mock_popen, process_output=[b"Foo", b"Bar", b"Baz"])
    source, dest = self.s3_paths()
    extra_args = ['arg1', 'arg2']

    operator = S3FileTransformOperator(
        source_s3_key=source,
        dest_s3_key=dest,
        transform_script=self.transform_script,
        script_args=extra_args,
        replace=True,
        task_id="task_id")
    operator.execute(None)

    # Popen argv is [script, source_file, dest_file, *script_args];
    # everything after the first three entries must be our extra args.
    self.assertEqual(extra_args, mock_popen.call_args[0][0][3:])

@mock.patch('airflow.providers.amazon.aws.hooks.s3.S3Hook.select_key', return_value="input")
@mock_s3
def test_execute_with_select_expression(self, mock_select_key):
    """When ``select_expression`` is given, the operator must fetch via S3 Select."""
    source, dest = self.s3_paths()
    expression = "SELECT * FROM s3object s"

    operator = S3FileTransformOperator(
        source_s3_key=source,
        dest_s3_key=dest,
        select_expression=expression,
        replace=True,
        task_id="task_id")
    operator.execute(None)

    # The hook's select_key must be called exactly once, with the source
    # key and the expression passed through unchanged.
    mock_select_key.assert_called_once_with(
        key=source,
        expression=expression
    )

@staticmethod
def mock_process(mock_popen, return_code=0, process_output=None):
process = mock_popen.return_value
process.stdout.readline.side_effect = process_output or []
process.wait.return_value = None
process.returncode = return_code

@staticmethod
def s3_paths():
bucket = "bucket"
input_key = "foo"
output_key = "bar"
Expand All @@ -124,16 +142,7 @@ def test_execute_with_select_expression(self, mock_select_key):
conn.upload_fileobj(Bucket=bucket, Key=input_key, Fileobj=bio)

s3_url = "s3://{0}/{1}"
select_expression = "SELECT * FROM S3Object s"
op = S3FileTransformOperator(
source_s3_key=s3_url.format(bucket, input_key),
dest_s3_key=s3_url.format(bucket, output_key),
select_expression=select_expression,
replace=True,
task_id="task_id")
op.execute(None)
input_path = s3_url.format(bucket, input_key)
output_path = s3_url.format(bucket, output_key)

mock_select_key.assert_called_once_with(
key=s3_url.format(bucket, input_key),
expression=select_expression
)
return input_path, output_path

0 comments on commit 1ed171b

Please sign in to comment.