Update return types of get_key methods on S3Hook (#30923)
`S3Hook` has two methods, `get_key` and `get_wildcard_key`, which use an
AWS `ServiceResource` to fetch object data from S3. Both methods are
correctly documented as returning an instance of `S3.Object`, but their
return types are annotated as `S3Transfer`, which is incorrect: `S3.Object`
is not a subtype of `S3Transfer`, and the two types expose largely
different methods.

This PR uses the `mypy-boto3-s3` package to give `get_key` and
`get_wildcard_key` their correct return type, the S3 resource `Object`.
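
To make the distinction concrete, here is a minimal usage sketch (the connection id, bucket, and key names are hypothetical, and it assumes the Amazon provider plus `mypy-boto3-s3` are installed). The attributes and methods used below exist on `S3.Object` but not on `S3Transfer`, so they only type-check cleanly once the annotations are corrected:

```python
from airflow.providers.amazon.aws.hooks.s3 import S3Hook

hook = S3Hook(aws_conn_id="aws_default")

# get_key() returns a boto3 S3.Object resource, not an S3Transfer.
obj = hook.get_key(key="data/report.csv", bucket_name="example-bucket")
print(obj.key, obj.content_length, obj.last_modified)
body = obj.get()["Body"].read()

# get_wildcard_key() may find nothing, hence the `S3ResourceObject | None` annotation.
maybe_obj = hook.get_wildcard_key(wildcard_key="data/*.csv", bucket_name="example-bucket")
if maybe_obj is not None:
    print(maybe_obj.key)
```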
jonshea committed May 3, 2023
1 parent 2d5166f commit cb71d41
Showing 5 changed files with 15 additions and 4 deletions.
11 changes: 7 additions & 4 deletions airflow/providers/amazon/aws/hooks/s3.py
@@ -33,11 +33,11 @@
from io import BytesIO
from pathlib import Path
from tempfile import NamedTemporaryFile, gettempdir
from typing import Any, Callable, TypeVar, cast
from typing import TYPE_CHECKING, Any, Callable, TypeVar, cast
from urllib.parse import urlsplit
from uuid import uuid4

from boto3.s3.transfer import S3Transfer, TransferConfig
from boto3.s3.transfer import TransferConfig
from botocore.exceptions import ClientError

from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
@@ -46,6 +46,9 @@
from airflow.providers.amazon.aws.utils.tags import format_tags
from airflow.utils.helpers import chunks

if TYPE_CHECKING:
from mypy_boto3_s3.service_resource import Object as S3ResourceObject

T = TypeVar("T", bound=Callable)

logger = logging.getLogger(__name__)
@@ -521,7 +524,7 @@ def check_for_key(self, key: str, bucket_name: str | None = None) -> bool:

@unify_bucket_name_and_key
@provide_bucket_name
def get_key(self, key: str, bucket_name: str | None = None) -> S3Transfer:
def get_key(self, key: str, bucket_name: str | None = None) -> S3ResourceObject:
"""
Returns a :py:class:`S3.Object`.
@@ -626,7 +629,7 @@ def check_for_wildcard_key(
@provide_bucket_name
def get_wildcard_key(
self, wildcard_key: str, bucket_name: str | None = None, delimiter: str = ""
) -> S3Transfer:
) -> S3ResourceObject | None:
"""
Returns a boto3.s3.Object object matching the wildcard expression
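
The hunk above imports the stub type only under `typing.TYPE_CHECKING`, which keeps `mypy-boto3-s3` a static-analysis dependency rather than a runtime one. A stripped-down sketch of that pattern, separate from the actual hook code:

```python
from __future__ import annotations  # keep annotations unevaluated at runtime

from typing import TYPE_CHECKING

import boto3

if TYPE_CHECKING:
    # Stubs-only package: imported for the type checker, never at runtime.
    from mypy_boto3_s3.service_resource import Object as S3ResourceObject


def get_object(bucket_name: str, key: str) -> S3ResourceObject:
    """Return a lazily constructed S3.Object resource (no API call is made here)."""
    return boto3.resource("s3").Object(bucket_name, key)
```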
2 changes: 2 additions & 0 deletions airflow/providers/amazon/provider.yaml
@@ -72,6 +72,8 @@ dependencies:
- mypy-boto3-rds>=1.24.0
- mypy-boto3-redshift-data>=1.24.0
- mypy-boto3-appflow>=1.24.0
- mypy-boto3-s3>=1.24.0


integrations:
- integration-name: Amazon Athena
4 changes: 4 additions & 0 deletions airflow/providers/apache/hive/transfers/s3_to_hive.py
@@ -149,6 +149,10 @@ def execute(self, context: Context):

else:
raise AirflowException(f"The key {self.s3_key} does not exists")

if TYPE_CHECKING:
assert s3_key_object

_, file_ext = os.path.splitext(s3_key_object.key)
if self.select_expression and self.input_compressed and file_ext.lower() != ".gz":
raise AirflowException("GZIP is the only compression format Amazon S3 Select supports")
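
Because `get_wildcard_key` is now annotated as `S3ResourceObject | None`, the transfer above narrows the optional away with an `assert` that only the type checker sees. A self-contained sketch of the same trick, using hypothetical stand-ins for the hook calls:

```python
from __future__ import annotations

from typing import TYPE_CHECKING


def key_exists(key: str) -> bool:
    """Stand-in for S3Hook.check_for_wildcard_key()."""
    return key.startswith("data/")


def get_object(key: str) -> str | None:
    """Stand-in for S3Hook.get_wildcard_key(), which may return None."""
    return key if key.startswith("data/") else None


def process(key: str) -> None:
    obj = None
    if key_exists(key):
        obj = get_object(key)
    else:
        raise RuntimeError(f"The key {key} does not exist")

    # key_exists() guarantees get_object() returned a value, but mypy cannot
    # prove that and still infers `str | None`.  TYPE_CHECKING is False at
    # runtime, so the assert never executes, yet mypy analyses the block and
    # narrows `obj` to `str` for the code that follows.
    if TYPE_CHECKING:
        assert obj
    print(obj.upper())
```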
1 change: 1 addition & 0 deletions docs/apache-airflow-providers-amazon/index.rst
@@ -106,6 +106,7 @@ PIP package Version required
``mypy-boto3-rds`` ``>=1.24.0``
``mypy-boto3-redshift-data`` ``>=1.24.0``
``mypy-boto3-appflow`` ``>=1.24.0``
``mypy-boto3-s3`` ``>=1.24.0``
======================================= ==================

Cross provider package dependencies
1 change: 1 addition & 0 deletions generated/provider_dependencies.json
@@ -25,6 +25,7 @@
"mypy-boto3-appflow>=1.24.0",
"mypy-boto3-rds>=1.24.0",
"mypy-boto3-redshift-data>=1.24.0",
"mypy-boto3-s3>=1.24.0",
"redshift_connector>=2.0.888",
"sqlalchemy_redshift>=0.8.6",
"watchtower~=2.0.1"
