[AIRFLOW-5924] Automatically unify bucket name and key in S3Hook (apache#6574)

- change provide_bucket_name to provide the bucket name for functions with keys as well
- refactoring
feluelle authored and galuszkak committed Mar 5, 2020
1 parent 15050b1 commit 580a7f5
Showing 2 changed files with 83 additions and 40 deletions.
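
In practice, the change means the keyed S3Hook methods now accept a full s3:// URL in place of a separate bucket name and key, while provide_bucket_name falls back to the connection schema only when no bucket name can be derived. A minimal sketch of the resulting behavior, assuming a configured aws_default connection and a hypothetical bucket named my-bucket:

from airflow.providers.amazon.aws.hooks.s3 import S3Hook

hook = S3Hook(aws_conn_id='aws_default')

# Equivalent calls after this commit: unify_bucket_name_and_key
# splits the full S3 URL into bucket_name and key automatically.
hook.check_for_key(key='data.csv', bucket_name='my-bucket')
hook.check_for_key(key='s3://my-bucket/data.csv')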
74 changes: 50 additions & 24 deletions airflow/providers/amazon/aws/hooks/s3.py
@@ -24,6 +24,7 @@
 import io
 import re
 from functools import wraps
+from inspect import signature
 from urllib.parse import urlparse
 
 from botocore.exceptions import ClientError
@@ -36,24 +37,52 @@
 def provide_bucket_name(func):
     """
     Function decorator that provides a bucket name taken from the connection
-    in case no bucket name has been passed to the function and, if available, also no key has been passed.
+    in case no bucket name has been passed to the function.
     """
 
+    function_signature = signature(func)
+
     @wraps(func)
     def wrapper(*args, **kwargs):
-        func_params = func.__code__.co_varnames
-
-        def has_arg(name):
-            name_in_args = name in func_params and func_params.index(name) < len(args)
-            name_in_kwargs = name in kwargs
-            return name_in_args or name_in_kwargs
+        bound_args = function_signature.bind(*args, **kwargs)
 
-        if not has_arg('bucket_name') and not (has_arg('key') or has_arg('wildcard_key')):
+        if 'bucket_name' not in bound_args.arguments:
             self = args[0]
-            connection = self.get_connection(self.aws_conn_id)
-            kwargs['bucket_name'] = connection.schema
+            if self.aws_conn_id:
+                connection = self.get_connection(self.aws_conn_id)
+                if connection.schema:
+                    bound_args.arguments['bucket_name'] = connection.schema
+
+        return func(*bound_args.args, **bound_args.kwargs)
+
+    return wrapper
+
+
+def unify_bucket_name_and_key(func):
+    """
+    Function decorator that unifies bucket name and key taken from the key
+    in case no bucket name and at least a key has been passed to the function.
+    """
+
+    function_signature = signature(func)
+
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        bound_args = function_signature.bind(*args, **kwargs)
+
+        def get_key_name():
+            if 'wildcard_key' in bound_args.arguments:
+                return 'wildcard_key'
+            if 'key' in bound_args.arguments:
+                return 'key'
+            raise ValueError('Missing key parameter!')
 
-        return func(*args, **kwargs)
+        key_name = get_key_name()
+        if key_name and 'bucket_name' not in bound_args.arguments:
+            bound_args.arguments['bucket_name'], bound_args.arguments[key_name] = \
+                S3Hook.parse_s3_url(bound_args.arguments[key_name])
+
+        return func(*bound_args.args, **bound_args.kwargs)
 
     return wrapper
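
Both decorators lean on inspect.signature(...).bind(...), which maps the caller's positional and keyword arguments onto parameter names so 'bucket_name' can be detected and injected however it was (or wasn't) passed. A standalone sketch of that mechanism, using an illustrative function rather than the hook itself:

from inspect import signature

def example(self, key, bucket_name=None):
    return key, bucket_name

# Positional arguments are bound to their parameter names.
bound = signature(example).bind('fake_self', 's3://foo/bar.csv')
print(bound.arguments)  # {'self': 'fake_self', 'key': 's3://foo/bar.csv'}

# Parameters left at their defaults do not appear in bound.arguments,
# which is how the decorators spot a missing bucket_name and inject one.
bound.arguments['bucket_name'] = 'foo'
print(bound.args)  # ('fake_self', 's3://foo/bar.csv', 'foo')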

@@ -242,6 +271,7 @@ def list_keys(self, bucket_name=None, prefix='', delimiter='',
         return None
 
     @provide_bucket_name
+    @unify_bucket_name_and_key
     def check_for_key(self, key, bucket_name=None):
         """
         Checks if a key exists in a bucket
@@ -253,8 +283,6 @@ def check_for_key(self, key, bucket_name=None):
         :return: True if the key exists and False if not.
         :rtype: bool
         """
-        if not bucket_name:
-            (bucket_name, key) = self.parse_s3_url(key)
 
         try:
             self.get_conn().head_object(Bucket=bucket_name, Key=key)
@@ -264,6 +292,7 @@ def check_for_key(self, key, bucket_name=None):
             return False
 
     @provide_bucket_name
+    @unify_bucket_name_and_key
     def get_key(self, key, bucket_name=None):
         """
         Returns a boto3.s3.Object
@@ -275,14 +304,13 @@ def get_key(self, key, bucket_name=None):
         :return: the key object from the bucket
         :rtype: boto3.s3.Object
         """
-        if not bucket_name:
-            (bucket_name, key) = self.parse_s3_url(key)
 
         obj = self.get_resource_type('s3').Object(bucket_name, key)
         obj.load()
         return obj
 
     @provide_bucket_name
+    @unify_bucket_name_and_key
     def read_key(self, key, bucket_name=None):
         """
         Reads a key from S3
@@ -299,6 +327,7 @@ def read_key(self, key, bucket_name=None):
         return obj.get()['Body'].read().decode('utf-8')
 
     @provide_bucket_name
+    @unify_bucket_name_and_key
     def select_key(self, key, bucket_name=None,
                    expression='SELECT * FROM S3Object',
                    expression_type='SQL',
@@ -330,8 +359,6 @@ def select_key(self, key, bucket_name=None,
             input_serialization = {'CSV': {}}
         if output_serialization is None:
             output_serialization = {'CSV': {}}
-        if not bucket_name:
-            (bucket_name, key) = self.parse_s3_url(key)
 
         response = self.get_conn().select_object_content(
             Bucket=bucket_name,
@@ -346,6 +373,7 @@ def select_key(self, key, bucket_name=None,
                        if 'Records' in event)
 
     @provide_bucket_name
+    @unify_bucket_name_and_key
     def check_for_wildcard_key(self,
                                wildcard_key, bucket_name=None, delimiter=''):
         """
@@ -365,6 +393,7 @@ def check_for_wildcard_key(self,
                                      delimiter=delimiter) is not None
 
     @provide_bucket_name
+    @unify_bucket_name_and_key
     def get_wildcard_key(self, wildcard_key, bucket_name=None, delimiter=''):
         """
         Returns a boto3.s3.Object object matching the wildcard expression
@@ -378,8 +407,6 @@ def get_wildcard_key(self, wildcard_key, bucket_name=None, delimiter=''):
         :return: the key object from the bucket or None if none has been found.
         :rtype: boto3.s3.Object
         """
-        if not bucket_name:
-            (bucket_name, wildcard_key) = self.parse_s3_url(wildcard_key)
 
         prefix = re.split(r'[*]', wildcard_key, 1)[0]
         key_list = self.list_keys(bucket_name, prefix=prefix, delimiter=delimiter)
@@ -390,6 +417,7 @@ def get_wildcard_key(self, wildcard_key, bucket_name=None, delimiter=''):
         return None
 
     @provide_bucket_name
+    @unify_bucket_name_and_key
     def load_file(self,
                   filename,
                   key,
@@ -413,8 +441,6 @@ def load_file(self,
             by S3 and will be stored in an encrypted form while at rest in S3.
         :type encrypt: bool
         """
-        if not bucket_name:
-            (bucket_name, key) = self.parse_s3_url(key)
 
         if not replace and self.check_for_key(key, bucket_name):
             raise ValueError("The key {key} already exists.".format(key=key))
@@ -427,6 +453,7 @@ def load_file(self,
         client.upload_file(filename, bucket_name, key, ExtraArgs=extra_args)
 
     @provide_bucket_name
+    @unify_bucket_name_and_key
     def load_string(self,
                     string_data,
                     key,
@@ -460,6 +487,7 @@ def load_string(self,
         self._upload_file_obj(file_obj, key, bucket_name, replace, encrypt)
 
     @provide_bucket_name
+    @unify_bucket_name_and_key
     def load_bytes(self,
                    bytes_data,
                    key,
@@ -489,6 +517,7 @@ def load_bytes(self,
         self._upload_file_obj(file_obj, key, bucket_name, replace, encrypt)
 
     @provide_bucket_name
+    @unify_bucket_name_and_key
     def load_file_obj(self,
                       file_obj,
                       key,
@@ -519,9 +548,6 @@ def _upload_file_obj(self,
                          bucket_name=None,
                          replace=False,
                          encrypt=False):
-        if not bucket_name:
-            (bucket_name, key) = self.parse_s3_url(key)
-
         if not replace and self.check_for_key(key, bucket_name):
             raise ValueError("The key {key} already exists.".format(key=key))
 
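One detail behind the repeated stacking above: Python applies decorators bottom-up, so listing @provide_bucket_name above @unify_bucket_name_and_key makes provide_bucket_name the outermost wrapper, and its connection-schema check runs before the key is parsed. A small standalone sketch of that ordering, with illustrative decorator names:

from functools import wraps

def outer(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        print('outer wrapper runs first')
        return func(*args, **kwargs)
    return wrapper

def inner(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        print('inner wrapper runs second')
        return func(*args, **kwargs)
    return wrapper

@outer
@inner
def method():
    print('original body runs last')

method()
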
49 changes: 33 additions & 16 deletions tests/providers/amazon/aws/hooks/test_s3.py
@@ -25,7 +25,7 @@

 from airflow.exceptions import AirflowException
 from airflow.models import Connection
-from airflow.providers.amazon.aws.hooks.s3 import S3Hook, provide_bucket_name
+from airflow.providers.amazon.aws.hooks.s3 import S3Hook, provide_bucket_name, unify_bucket_name_and_key
 
 try:
     from moto import mock_s3
@@ -252,24 +252,13 @@ class FakeS3Hook(S3Hook):
             def test_function(self, bucket_name=None):
                 return bucket_name
 
-            # pylint: disable=unused-argument
-            @provide_bucket_name
-            def test_function_with_key(self, key, bucket_name=None):
-                return bucket_name
-
-            # pylint: disable=unused-argument
-            @provide_bucket_name
-            def test_function_with_wildcard_key(self, wildcard_key, bucket_name=None):
-                return bucket_name
-
         fake_s3_hook = FakeS3Hook()
-        test_bucket_name = fake_s3_hook.test_function()
-        test_bucket_name_with_key = fake_s3_hook.test_function_with_key('test_key')
-        test_bucket_name_with_wildcard_key = fake_s3_hook.test_function_with_wildcard_key('test_*_key')
 
+        test_bucket_name = fake_s3_hook.test_function()
         assert test_bucket_name == mock_get_connection.return_value.schema
-        assert test_bucket_name_with_key is None
-        assert test_bucket_name_with_wildcard_key is None
+
+        test_bucket_name = fake_s3_hook.test_function(bucket_name='bucket')
+        assert test_bucket_name == 'bucket'
 
     def test_delete_objects_key_does_not_exist(self, s3_bucket):
         hook = S3Hook()
@@ -298,3 +287,31 @@ def test_delete_objects_many_keys(self, mocked_s3_res, s3_bucket):
         hook = S3Hook()
         hook.delete_objects(bucket=s3_bucket, keys=keys)
         assert [o.key for o in mocked_s3_res.Bucket(s3_bucket).objects.all()] == []
+
+    def test_unify_bucket_name_and_key(self):
+
+        class FakeS3Hook(S3Hook):
+
+            @unify_bucket_name_and_key
+            def test_function_with_wildcard_key(self, wildcard_key, bucket_name=None):
+                return bucket_name, wildcard_key
+
+            @unify_bucket_name_and_key
+            def test_function_with_key(self, key, bucket_name=None):
+                return bucket_name, key
+
+            @unify_bucket_name_and_key
+            def test_function_with_test_key(self, test_key, bucket_name=None):
+                return bucket_name, test_key
+
+        fake_s3_hook = FakeS3Hook()
+
+        test_bucket_name_with_wildcard_key = fake_s3_hook.test_function_with_wildcard_key('s3://foo/bar*.csv')
+        assert ('foo', 'bar*.csv') == test_bucket_name_with_wildcard_key
+
+        test_bucket_name_with_key = fake_s3_hook.test_function_with_key('s3://foo/bar.csv')
+        assert ('foo', 'bar.csv') == test_bucket_name_with_key
+
+        with pytest.raises(ValueError) as err:
+            fake_s3_hook.test_function_with_test_key('s3://foo/bar.csv')
+        assert isinstance(err.value, ValueError)
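
The ValueError asserted in the last case comes from get_key_name inside unify_bucket_name_and_key: the parameter test_key matches neither 'key' nor 'wildcard_key', so the decorator refuses to guess. Assuming moto and the provider test dependencies are installed, the new tests can be run in isolation with pytest's -k filter:

pytest tests/providers/amazon/aws/hooks/test_s3.py -k "provide_bucket_name or unify_bucket_name_and_key"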
