Skip to content

Commit

Permalink
Reduce s3hook memory usage (#37886)
Browse files Browse the repository at this point in the history
  • Loading branch information
ellisms committed Mar 6, 2024
1 parent 1b04430 commit e7214fd
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 30 deletions.
53 changes: 23 additions & 30 deletions airflow/providers/amazon/aws/hooks/s3.py
Expand Up @@ -30,16 +30,18 @@
from contextlib import suppress
from copy import deepcopy
from datetime import datetime
from functools import wraps
from functools import cached_property, wraps
from inspect import signature
from io import BytesIO
from pathlib import Path
from tempfile import NamedTemporaryFile, gettempdir
from typing import TYPE_CHECKING, Any, Callable, TypeVar, cast
from typing import TYPE_CHECKING, Any, Callable
from urllib.parse import urlsplit
from uuid import uuid4

if TYPE_CHECKING:
from mypy_boto3_s3.service_resource import Bucket as S3Bucket, Object as S3ResourceObject

from airflow.utils.types import ArgNotSet

with suppress(ImportError):
Expand All @@ -55,22 +57,17 @@
from airflow.providers.amazon.aws.utils.tags import format_tags
from airflow.utils.helpers import chunks

if TYPE_CHECKING:
from mypy_boto3_s3.service_resource import Bucket as S3Bucket, Object as S3ResourceObject

T = TypeVar("T", bound=Callable)

logger = logging.getLogger(__name__)


def provide_bucket_name(func: T) -> T:
def provide_bucket_name(func: Callable) -> Callable:
"""Provide a bucket name taken from the connection if no bucket name has been passed to the function."""
if hasattr(func, "_unify_bucket_name_and_key_wrapped"):
logger.warning("`unify_bucket_name_and_key` should wrap `provide_bucket_name`.")
function_signature = signature(func)

@wraps(func)
def wrapper(*args, **kwargs) -> T:
def wrapper(*args, **kwargs) -> Callable:
bound_args = function_signature.bind(*args, **kwargs)

if "bucket_name" not in bound_args.arguments:
Expand All @@ -90,10 +87,10 @@ def wrapper(*args, **kwargs) -> T:

return func(*bound_args.args, **bound_args.kwargs)

return cast(T, wrapper)
return wrapper


def provide_bucket_name_async(func: T) -> T:
def provide_bucket_name_async(func: Callable) -> Callable:
"""Provide a bucket name taken from the connection if no bucket name has been passed to the function."""
function_signature = signature(func)

Expand All @@ -110,15 +107,15 @@ async def wrapper(*args: Any, **kwargs: Any) -> Any:

return await func(*bound_args.args, **bound_args.kwargs)

return cast(T, wrapper)
return wrapper


def unify_bucket_name_and_key(func: T) -> T:
def unify_bucket_name_and_key(func: Callable) -> Callable:
"""Unify bucket name and key in case no bucket name and at least a key has been passed to the function."""
function_signature = signature(func)

@wraps(func)
def wrapper(*args, **kwargs) -> T:
def wrapper(*args, **kwargs) -> Callable:
bound_args = function_signature.bind(*args, **kwargs)

if "wildcard_key" in bound_args.arguments:
Expand All @@ -141,7 +138,7 @@ def wrapper(*args, **kwargs) -> T:
# if provide_bucket_name is applied first, and there's a bucket defined in conn
# then if user supplies full key, bucket in key is not respected
wrapper._unify_bucket_name_and_key_wrapped = True # type: ignore[attr-defined]
return cast(T, wrapper)
return wrapper


class S3Hook(AwsBaseHook):
Expand Down Expand Up @@ -188,6 +185,15 @@ def __init__(

super().__init__(*args, **kwargs)

@cached_property
def resource(self):
return self.get_session().resource(
self.service_name,
endpoint_url=self.conn_config.get_service_endpoint_url(service_name=self.service_name),
config=self.config,
verify=self.verify,
)

@property
def extra_args(self):
"""Return hook's extra arguments (immutable)."""
Expand Down Expand Up @@ -307,13 +313,7 @@ def get_bucket(self, bucket_name: str | None = None) -> S3Bucket:
:param bucket_name: the name of the bucket
:return: the bucket object to the bucket name.
"""
s3_resource = self.get_session().resource(
"s3",
endpoint_url=self.conn_config.endpoint_url,
config=self.config,
verify=self.verify,
)
return s3_resource.Bucket(bucket_name)
return self.resource.Bucket(bucket_name)

@provide_bucket_name
def create_bucket(self, bucket_name: str | None = None, region_name: str | None = None) -> None:
Expand Down Expand Up @@ -943,14 +943,7 @@ def sanitize_extra_args() -> dict[str, str]:
if arg_name in S3Transfer.ALLOWED_DOWNLOAD_ARGS
}

s3_resource = self.get_session().resource(
"s3",
endpoint_url=self.conn_config.endpoint_url,
config=self.config,
verify=self.verify,
)
obj = s3_resource.Object(bucket_name, key)

obj = self.resource.Object(bucket_name, key)
obj.load(**sanitize_extra_args())
return obj

Expand Down
4 changes: 4 additions & 0 deletions tests/providers/amazon/aws/hooks/test_s3.py
Expand Up @@ -62,6 +62,10 @@ def test_get_conn(self):
hook = S3Hook()
assert hook.get_conn() is not None

def test_resource(self):
hook = S3Hook()
assert hook.resource is not None

def test_use_threads_default_value(self):
hook = S3Hook()
assert hook.transfer_config.use_threads is True
Expand Down

0 comments on commit e7214fd

Please sign in to comment.