Skip to content

Commit

Permalink
[AIRFLOW-4175] S3Hook load_file should support ACL policy parameter (#…
Browse files Browse the repository at this point in the history
…7733)

- Added acl_policy parameter to all the S3Hook.load_*() and S3Hook.copy_object() functions
                     - Added unit tests verifying the object permissions granted when the policy is passed
                     - Updated the docstrings of the affected functions

Co-authored-by: retornam <retornam@users.noreply.github.com>
  • Loading branch information
OmairK and retornam committed Mar 16, 2020
1 parent ae035cd commit a8b5fc7
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 10 deletions.
46 changes: 36 additions & 10 deletions airflow/providers/amazon/aws/hooks/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,7 +428,8 @@ def load_file(self,
bucket_name=None,
replace=False,
encrypt=False,
gzip=False):
gzip=False,
acl_policy=None):
"""
Loads a local file to S3
Expand All @@ -447,6 +448,9 @@ def load_file(self,
:type encrypt: bool
:param gzip: If True, the file will be compressed locally
:type gzip: bool
:param acl_policy: String specifying the canned ACL policy for the file being
uploaded to the S3 bucket.
:type acl_policy: str
"""

if not replace and self.check_for_key(key, bucket_name):
Expand All @@ -461,6 +465,8 @@ def load_file(self,
with gz.open(filename_gz, 'wb') as f_out:
shutil.copyfileobj(f_in, f_out)
filename = filename_gz
if acl_policy:
extra_args['ACL'] = acl_policy

client = self.get_conn()
client.upload_file(filename, bucket_name, key, ExtraArgs=extra_args)
Expand All @@ -473,7 +479,8 @@ def load_string(self,
bucket_name=None,
replace=False,
encrypt=False,
encoding='utf-8'):
encoding='utf-8',
acl_policy=None):
"""
Loads a string to S3
Expand All @@ -494,10 +501,13 @@ def load_string(self,
:type encrypt: bool
:param encoding: The string to byte encoding
:type encoding: str
:param acl_policy: The string to specify the canned ACL policy for the
object to be uploaded
:type acl_policy: str
"""
bytes_data = string_data.encode(encoding)
file_obj = io.BytesIO(bytes_data)
self._upload_file_obj(file_obj, key, bucket_name, replace, encrypt)
self._upload_file_obj(file_obj, key, bucket_name, replace, encrypt, acl_policy)

@provide_bucket_name
@unify_bucket_name_and_key
Expand All @@ -506,7 +516,8 @@ def load_bytes(self,
key,
bucket_name=None,
replace=False,
encrypt=False):
encrypt=False,
acl_policy=None):
"""
Loads bytes to S3
Expand All @@ -525,9 +536,12 @@ def load_bytes(self,
:param encrypt: If True, the file will be encrypted on the server-side
by S3 and will be stored in an encrypted form while at rest in S3.
:type encrypt: bool
:param acl_policy: The string to specify the canned ACL policy for the
object to be uploaded
:type acl_policy: str
"""
file_obj = io.BytesIO(bytes_data)
self._upload_file_obj(file_obj, key, bucket_name, replace, encrypt)
self._upload_file_obj(file_obj, key, bucket_name, replace, encrypt, acl_policy)

@provide_bucket_name
@unify_bucket_name_and_key
Expand All @@ -536,7 +550,8 @@ def load_file_obj(self,
key,
bucket_name=None,
replace=False,
encrypt=False):
encrypt=False,
acl_policy=None):
"""
Loads a file object to S3
Expand All @@ -552,21 +567,27 @@ def load_file_obj(self,
:param encrypt: If True, S3 encrypts the file on the server,
and the file is stored in encrypted form at rest in S3.
:type encrypt: bool
:param acl_policy: The string to specify the canned ACL policy for the
object to be uploaded
:type acl_policy: str
"""
self._upload_file_obj(file_obj, key, bucket_name, replace, encrypt)
self._upload_file_obj(file_obj, key, bucket_name, replace, encrypt, acl_policy)

def _upload_file_obj(self,
                     file_obj,
                     key,
                     bucket_name=None,
                     replace=False,
                     encrypt=False,
                     acl_policy=None):
    """
    Upload a readable file-like object to S3 under the given key.

    :param file_obj: file-like object to upload
    :param key: S3 key under which the object is stored
    :param bucket_name: name of the target bucket
    :param replace: if False, refuse to overwrite an already-existing key
    :param encrypt: if True, request AES256 server-side encryption
    :param acl_policy: canned ACL name to apply to the uploaded object
    """
    # Refuse to clobber an existing key unless the caller opted in.
    if not replace and self.check_for_key(key, bucket_name):
        raise ValueError("The key {key} already exists.".format(key=key))

    # Assemble boto3 ExtraArgs only from the options that were actually set.
    upload_extra = {}
    if encrypt:
        upload_extra['ServerSideEncryption'] = "AES256"
    if acl_policy:
        upload_extra['ACL'] = acl_policy

    self.get_conn().upload_fileobj(file_obj, bucket_name, key, ExtraArgs=upload_extra)
Expand All @@ -576,7 +597,8 @@ def copy_object(self,
dest_bucket_key,
source_bucket_name=None,
dest_bucket_name=None,
source_version_id=None):
source_version_id=None,
acl_policy='private'):
"""
Creates a copy of an object that is already stored in S3.
Expand Down Expand Up @@ -604,6 +626,9 @@ def copy_object(self,
:type dest_bucket_name: str
:param source_version_id: Version ID of the source object (OPTIONAL)
:type source_version_id: str
:param acl_policy: The string to specify the canned ACL policy for the
object to be copied which is private by default.
:type acl_policy: str
"""

if dest_bucket_name is None:
Expand All @@ -629,7 +654,8 @@ def copy_object(self,
'VersionId': source_version_id}
response = self.get_conn().copy_object(Bucket=dest_bucket_name,
Key=dest_bucket_key,
CopySource=copy_source)
CopySource=copy_source,
ACL=acl_policy)
return response

def delete_objects(self, bucket, keys):
Expand Down
55 changes: 55 additions & 0 deletions tests/providers/amazon/aws/hooks/test_s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,12 +231,28 @@ def test_load_string(self, s3_bucket):
resource = boto3.resource('s3').Object(s3_bucket, 'my_key') # pylint: disable=no-member
assert resource.get()['Body'].read() == b'Cont\xC3\xA9nt'

def test_load_string_acl(self, s3_bucket):
    """Uploading a string with a canned ACL should be reflected in the object's grants."""
    hook = S3Hook()
    hook.load_string("Contént", "my_key", s3_bucket,
                     acl_policy='public-read')
    acl = boto3.client('s3').get_object_acl(Bucket=s3_bucket, Key="my_key", RequestPayer='requester')
    grants = acl['Grants']
    # 'public-read': owner keeps FULL_CONTROL, everyone else gets READ
    assert grants[0]['Permission'] == 'FULL_CONTROL'
    assert grants[1]['Permission'] == 'READ'

def test_load_bytes(self, s3_bucket):
    """Bytes uploaded via load_bytes should round-trip unchanged."""
    hook = S3Hook()
    hook.load_bytes(b"Content", "my_key", s3_bucket)
    stored = boto3.resource('s3').Object(s3_bucket, 'my_key')  # pylint: disable=no-member
    body = stored.get()['Body'].read()
    assert body == b'Content'

def test_load_bytes_acl(self, s3_bucket):
    """Uploading bytes with a canned ACL should be reflected in the object's grants."""
    hook = S3Hook()
    hook.load_bytes(b"Content", "my_key", s3_bucket,
                    acl_policy='public-read')
    acl = boto3.client('s3').get_object_acl(Bucket=s3_bucket, Key="my_key", RequestPayer='requester')
    grants = acl['Grants']
    # 'public-read': owner keeps FULL_CONTROL, everyone else gets READ
    assert grants[0]['Permission'] == 'FULL_CONTROL'
    assert grants[1]['Permission'] == 'READ'

def test_load_fileobj(self, s3_bucket):
hook = S3Hook()
with tempfile.TemporaryFile() as temp_file:
Expand All @@ -246,6 +262,19 @@ def test_load_fileobj(self, s3_bucket):
resource = boto3.resource('s3').Object(s3_bucket, 'my_key') # pylint: disable=no-member
assert resource.get()['Body'].read() == b'Content'

def test_load_fileobj_acl(self, s3_bucket):
    """load_file_obj with a canned ACL should apply the permissions to the object."""
    hook = S3Hook()
    with tempfile.TemporaryFile() as source:
        source.write(b"Content")
        source.seek(0)
        hook.load_file_obj(source, "my_key", s3_bucket,
                           acl_policy='public-read')
        acl = boto3.client('s3').get_object_acl(Bucket=s3_bucket,
                                                Key="my_key",
                                                RequestPayer='requester')  # pylint: disable=no-member
        grants = acl['Grants']
        # 'public-read': owner keeps FULL_CONTROL, everyone else gets READ
        assert grants[0]['Permission'] == 'FULL_CONTROL'
        assert grants[1]['Permission'] == 'READ'

def test_load_file_gzip(self, s3_bucket):
hook = S3Hook()
with tempfile.NamedTemporaryFile() as temp_file:
Expand All @@ -255,6 +284,32 @@ def test_load_file_gzip(self, s3_bucket):
resource = boto3.resource('s3').Object(s3_bucket, 'my_key') # pylint: disable=no-member
assert gz.decompress(resource.get()['Body'].read()) == b'Content'

def test_load_file_acl(self, s3_bucket):
    """
    load_file with gzip and a canned ACL should apply the permissions to the object.

    NOTE(review): load_file expects a *path*, not a file object — the gzip branch
    re-opens the file and boto3's ``upload_file`` takes a filename string — so the
    temp file must be passed as ``temp_file.name``.
    """
    hook = S3Hook()
    with tempfile.NamedTemporaryFile() as temp_file:
        temp_file.write(b"Content")
        temp_file.seek(0)
        hook.load_file(temp_file.name, "my_key", s3_bucket, gzip=True,
                       acl_policy='public-read')
        response = boto3.client('s3').get_object_acl(Bucket=s3_bucket,
                                                     Key="my_key",
                                                     RequestPayer='requester')  # pylint: disable=no-member
        # 'public-read': owner keeps FULL_CONTROL, everyone else gets READ
        assert ((response['Grants'][1]['Permission'] == 'READ') and
                (response['Grants'][0]['Permission'] == 'FULL_CONTROL'))

def test_copy_object_acl(self, s3_bucket):
    """copy_object defaults to a private ACL: only the owner's FULL_CONTROL grant remains."""
    hook = S3Hook()
    with tempfile.NamedTemporaryFile() as temp_file:
        temp_file.write(b"Content")
        temp_file.seek(0)
        hook.load_file_obj(temp_file, "my_key", s3_bucket)
        hook.copy_object("my_key", "my_key", s3_bucket, s3_bucket)
        acl = boto3.client('s3').get_object_acl(Bucket=s3_bucket,
                                                Key="my_key",
                                                RequestPayer='requester')  # pylint: disable=no-member
        grants = acl['Grants']
        assert len(grants) == 1
        assert grants[0]['Permission'] == 'FULL_CONTROL'

@mock.patch.object(S3Hook, 'get_connection', return_value=Connection(schema='test_bucket'))
def test_provide_bucket_name(self, mock_get_connection):

Expand Down

0 comments on commit a8b5fc7

Please sign in to comment.