diff --git a/airflow/providers/amazon/aws/hooks/s3.py b/airflow/providers/amazon/aws/hooks/s3.py index 6e7f2a8d7f52f..fa1246e4efd19 100644 --- a/airflow/providers/amazon/aws/hooks/s3.py +++ b/airflow/providers/amazon/aws/hooks/s3.py @@ -24,6 +24,7 @@ import io import re from functools import wraps +from inspect import signature from urllib.parse import urlparse from botocore.exceptions import ClientError @@ -36,24 +37,52 @@ def provide_bucket_name(func): """ Function decorator that provides a bucket name taken from the connection - in case no bucket name has been passed to the function and, if available, also no key has been passed. + in case no bucket name has been passed to the function. """ + function_signature = signature(func) + @wraps(func) def wrapper(*args, **kwargs): - func_params = func.__code__.co_varnames - - def has_arg(name): - name_in_args = name in func_params and func_params.index(name) < len(args) - name_in_kwargs = name in kwargs - return name_in_args or name_in_kwargs + bound_args = function_signature.bind(*args, **kwargs) - if not has_arg('bucket_name') and not (has_arg('key') or has_arg('wildcard_key')): + if 'bucket_name' not in bound_args.arguments: self = args[0] - connection = self.get_connection(self.aws_conn_id) - kwargs['bucket_name'] = connection.schema + if self.aws_conn_id: + connection = self.get_connection(self.aws_conn_id) + if connection.schema: + bound_args.arguments['bucket_name'] = connection.schema + + return func(*bound_args.args, **bound_args.kwargs) + + return wrapper + + +def unify_bucket_name_and_key(func): + """ + Function decorator that unifies bucket name and key taken from the key + in case no bucket name and at least a key has been passed to the function. + """ + + function_signature = signature(func) + + @wraps(func) + def wrapper(*args, **kwargs): + bound_args = function_signature.bind(*args, **kwargs) + + def get_key_name(): + if 'wildcard_key' in bound_args.arguments: + return 'wildcard_key' + if 'key' in bound_args.arguments: + return 'key' + raise ValueError('Missing key parameter!') - return func(*args, **kwargs) + key_name = get_key_name() + if key_name and 'bucket_name' not in bound_args.arguments: + bound_args.arguments['bucket_name'], bound_args.arguments[key_name] = \ + S3Hook.parse_s3_url(bound_args.arguments[key_name]) + + return func(*bound_args.args, **bound_args.kwargs) return wrapper @@ -242,6 +271,7 @@ def list_keys(self, bucket_name=None, prefix='', delimiter='', return None @provide_bucket_name + @unify_bucket_name_and_key def check_for_key(self, key, bucket_name=None): """ Checks if a key exists in a bucket @@ -253,8 +283,6 @@ def check_for_key(self, key, bucket_name=None): :return: True if the key exists and False if not. :rtype: bool """ - if not bucket_name: - (bucket_name, key) = self.parse_s3_url(key) try: self.get_conn().head_object(Bucket=bucket_name, Key=key) @@ -264,6 +292,7 @@ def check_for_key(self, key, bucket_name=None): return False @provide_bucket_name + @unify_bucket_name_and_key def get_key(self, key, bucket_name=None): """ Returns a boto3.s3.Object @@ -275,14 +304,13 @@ def get_key(self, key, bucket_name=None): :return: the key object from the bucket :rtype: boto3.s3.Object """ - if not bucket_name: - (bucket_name, key) = self.parse_s3_url(key) obj = self.get_resource_type('s3').Object(bucket_name, key) obj.load() return obj @provide_bucket_name + @unify_bucket_name_and_key def read_key(self, key, bucket_name=None): """ Reads a key from S3 @@ -299,6 +327,7 @@ def read_key(self, key, bucket_name=None): return obj.get()['Body'].read().decode('utf-8') @provide_bucket_name + @unify_bucket_name_and_key def select_key(self, key, bucket_name=None, expression='SELECT * FROM S3Object', expression_type='SQL', @@ -330,8 +359,6 @@ def select_key(self, key, bucket_name=None, input_serialization = {'CSV': {}} if output_serialization is None: output_serialization = {'CSV': {}} - if not bucket_name: - (bucket_name, key) = self.parse_s3_url(key) response = self.get_conn().select_object_content( Bucket=bucket_name, @@ -346,6 +373,7 @@ def select_key(self, key, bucket_name=None, if 'Records' in event) @provide_bucket_name + @unify_bucket_name_and_key def check_for_wildcard_key(self, wildcard_key, bucket_name=None, delimiter=''): """ @@ -365,6 +393,7 @@ def check_for_wildcard_key(self, delimiter=delimiter) is not None @provide_bucket_name + @unify_bucket_name_and_key def get_wildcard_key(self, wildcard_key, bucket_name=None, delimiter=''): """ Returns a boto3.s3.Object object matching the wildcard expression @@ -378,8 +407,6 @@ def get_wildcard_key(self, wildcard_key, bucket_name=None, delimiter=''): :return: the key object from the bucket or None if none has been found. :rtype: boto3.s3.Object """ - if not bucket_name: - (bucket_name, wildcard_key) = self.parse_s3_url(wildcard_key) prefix = re.split(r'[*]', wildcard_key, 1)[0] key_list = self.list_keys(bucket_name, prefix=prefix, delimiter=delimiter) @@ -390,6 +417,7 @@ def get_wildcard_key(self, wildcard_key, bucket_name=None, delimiter=''): return None @provide_bucket_name + @unify_bucket_name_and_key def load_file(self, filename, key, @@ -413,8 +441,6 @@ def load_file(self, by S3 and will be stored in an encrypted form while at rest in S3. :type encrypt: bool """ - if not bucket_name: - (bucket_name, key) = self.parse_s3_url(key) if not replace and self.check_for_key(key, bucket_name): raise ValueError("The key {key} already exists.".format(key=key)) @@ -427,6 +453,7 @@ def load_file(self, client.upload_file(filename, bucket_name, key, ExtraArgs=extra_args) @provide_bucket_name + @unify_bucket_name_and_key def load_string(self, string_data, key, @@ -460,6 +487,7 @@ def load_string(self, self._upload_file_obj(file_obj, key, bucket_name, replace, encrypt) @provide_bucket_name + @unify_bucket_name_and_key def load_bytes(self, bytes_data, key, @@ -489,6 +517,7 @@ def load_bytes(self, self._upload_file_obj(file_obj, key, bucket_name, replace, encrypt) @provide_bucket_name + @unify_bucket_name_and_key def load_file_obj(self, file_obj, key, @@ -519,9 +548,6 @@ def _upload_file_obj(self, bucket_name=None, replace=False, encrypt=False): - if not bucket_name: - (bucket_name, key) = self.parse_s3_url(key) - if not replace and self.check_for_key(key, bucket_name): raise ValueError("The key {key} already exists.".format(key=key)) diff --git a/tests/providers/amazon/aws/hooks/test_s3.py b/tests/providers/amazon/aws/hooks/test_s3.py index 8b384f3032a0c..a1891f8c149ed 100644 --- a/tests/providers/amazon/aws/hooks/test_s3.py +++ b/tests/providers/amazon/aws/hooks/test_s3.py @@ -25,7 +25,7 @@ from airflow.exceptions import AirflowException from airflow.models import Connection -from airflow.providers.amazon.aws.hooks.s3 import S3Hook, provide_bucket_name +from airflow.providers.amazon.aws.hooks.s3 import S3Hook, provide_bucket_name, unify_bucket_name_and_key try: from moto import mock_s3 @@ -252,24 +252,13 @@ class FakeS3Hook(S3Hook): def test_function(self, bucket_name=None): return bucket_name - # pylint: disable=unused-argument - @provide_bucket_name - def test_function_with_key(self, key, bucket_name=None): - return bucket_name - - # pylint: disable=unused-argument - @provide_bucket_name - def test_function_with_wildcard_key(self, wildcard_key, bucket_name=None): - return bucket_name - fake_s3_hook = FakeS3Hook() - test_bucket_name = fake_s3_hook.test_function() - test_bucket_name_with_key = fake_s3_hook.test_function_with_key('test_key') - test_bucket_name_with_wildcard_key = fake_s3_hook.test_function_with_wildcard_key('test_*_key') + test_bucket_name = fake_s3_hook.test_function() assert test_bucket_name == mock_get_connection.return_value.schema - assert test_bucket_name_with_key is None - assert test_bucket_name_with_wildcard_key is None + + test_bucket_name = fake_s3_hook.test_function(bucket_name='bucket') + assert test_bucket_name == 'bucket' def test_delete_objects_key_does_not_exist(self, s3_bucket): hook = S3Hook() @@ -298,3 +287,31 @@ def test_delete_objects_many_keys(self, mocked_s3_res, s3_bucket): hook = S3Hook() hook.delete_objects(bucket=s3_bucket, keys=keys) assert [o.key for o in mocked_s3_res.Bucket(s3_bucket).objects.all()] == [] + + def test_unify_bucket_name_and_key(self): + + class FakeS3Hook(S3Hook): + + @unify_bucket_name_and_key + def test_function_with_wildcard_key(self, wildcard_key, bucket_name=None): + return bucket_name, wildcard_key + + @unify_bucket_name_and_key + def test_function_with_key(self, key, bucket_name=None): + return bucket_name, key + + @unify_bucket_name_and_key + def test_function_with_test_key(self, test_key, bucket_name=None): + return bucket_name, test_key + + fake_s3_hook = FakeS3Hook() + + test_bucket_name_with_wildcard_key = fake_s3_hook.test_function_with_wildcard_key('s3://foo/bar*.csv') + assert ('foo', 'bar*.csv') == test_bucket_name_with_wildcard_key + + test_bucket_name_with_key = fake_s3_hook.test_function_with_key('s3://foo/bar.csv') + assert ('foo', 'bar.csv') == test_bucket_name_with_key + + with pytest.raises(ValueError) as err: + fake_s3_hook.test_function_with_test_key('s3://foo/bar.csv') + assert isinstance(err.value, ValueError)