Skip to content

Commit

Permalink
Improve modules import in AWS probvider by move some of them into a t…
Browse files Browse the repository at this point in the history
…ype-checking block (#33780)

* Improve modules import in AWS probvider by move some of them into a type-checking block

* comment from code review and fix static checks
  • Loading branch information
hussein-awala committed Aug 28, 2023
1 parent 83d09c0 commit 667ab8c
Show file tree
Hide file tree
Showing 46 changed files with 182 additions and 73 deletions.
7 changes: 4 additions & 3 deletions airflow/providers/amazon/aws/hooks/athena.py
Expand Up @@ -25,14 +25,15 @@
from __future__ import annotations

import warnings
from typing import Any

from botocore.paginate import PageIterator
from typing import TYPE_CHECKING, Any

from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
from airflow.providers.amazon.aws.utils.waiter_with_logging import wait

if TYPE_CHECKING:
from botocore.paginate import PageIterator


class AthenaHook(AwsBaseHook):
"""Interact with Amazon Athena.
Expand Down
5 changes: 3 additions & 2 deletions airflow/providers/amazon/aws/hooks/base_aws.py
Expand Up @@ -43,9 +43,7 @@
import jinja2
import requests
import tenacity
from botocore.client import ClientMeta
from botocore.config import Config
from botocore.credentials import ReadOnlyCredentials
from botocore.waiter import Waiter, WaiterModel
from dateutil.tz import tzlocal
from slugify import slugify
Expand All @@ -66,6 +64,9 @@
BaseAwsConnection = TypeVar("BaseAwsConnection", bound=Union[boto3.client, boto3.resource])

if TYPE_CHECKING:
from botocore.client import ClientMeta
from botocore.credentials import ReadOnlyCredentials

from airflow.models.connection import Connection # Avoid circular imports.


Expand Down
6 changes: 4 additions & 2 deletions airflow/providers/amazon/aws/hooks/batch_client.py
Expand Up @@ -29,17 +29,19 @@
import itertools
from random import uniform
from time import sleep
from typing import Callable
from typing import TYPE_CHECKING, Callable

import botocore.client
import botocore.exceptions
import botocore.waiter

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
from airflow.providers.amazon.aws.utils.task_log_fetcher import AwsTaskLogFetcher
from airflow.typing_compat import Protocol, runtime_checkable

if TYPE_CHECKING:
from airflow.providers.amazon.aws.utils.task_log_fetcher import AwsTaskLogFetcher


@runtime_checkable
class BatchProtocol(Protocol):
Expand Down
6 changes: 4 additions & 2 deletions airflow/providers/amazon/aws/hooks/batch_waiters.py
Expand Up @@ -29,15 +29,17 @@
import sys
from copy import deepcopy
from pathlib import Path
from typing import Callable
from typing import TYPE_CHECKING, Callable

import botocore.client
import botocore.exceptions
import botocore.waiter

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.batch_client import BatchClientHook
from airflow.providers.amazon.aws.utils.task_log_fetcher import AwsTaskLogFetcher

if TYPE_CHECKING:
from airflow.providers.amazon.aws.utils.task_log_fetcher import AwsTaskLogFetcher


class BatchWaitersHook(BatchClientHook):
Expand Down
6 changes: 5 additions & 1 deletion airflow/providers/amazon/aws/hooks/cloud_formation.py
Expand Up @@ -18,11 +18,15 @@
"""This module contains AWS CloudFormation Hook."""
from __future__ import annotations

from boto3 import client, resource
from typing import TYPE_CHECKING

from botocore.exceptions import ClientError

from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook

if TYPE_CHECKING:
from boto3 import client, resource


class CloudFormationHook(AwsBaseHook):
"""
Expand Down
5 changes: 4 additions & 1 deletion airflow/providers/amazon/aws/hooks/ecr.py
Expand Up @@ -20,11 +20,14 @@
import base64
import logging
from dataclasses import dataclass
from datetime import datetime
from typing import TYPE_CHECKING

from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
from airflow.utils.log.secrets_masker import mask_secret

if TYPE_CHECKING:
from datetime import datetime

logger = logging.getLogger(__name__)


Expand Down
5 changes: 4 additions & 1 deletion airflow/providers/amazon/aws/hooks/ecs.py
Expand Up @@ -17,13 +17,16 @@
# under the License.
from __future__ import annotations

from botocore.waiter import Waiter
from typing import TYPE_CHECKING

from airflow.providers.amazon.aws.exceptions import EcsOperatorError, EcsTaskFailToStart
from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
from airflow.providers.amazon.aws.utils import _StringCompareEnum
from airflow.typing_compat import Protocol, runtime_checkable

if TYPE_CHECKING:
from botocore.waiter import Waiter


def should_retry(exception: Exception):
"""Check if exception is related to ECS resource quota (CPU, MEM)."""
Expand Down
7 changes: 4 additions & 3 deletions airflow/providers/amazon/aws/links/emr.py
Expand Up @@ -16,15 +16,16 @@
# under the License.
from __future__ import annotations

from typing import Any

import boto3
from typing import TYPE_CHECKING, Any

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
from airflow.providers.amazon.aws.links.base_aws import BASE_AWS_CONSOLE_LINK, BaseAwsLink
from airflow.utils.helpers import exactly_one

if TYPE_CHECKING:
import boto3


class EmrClusterLink(BaseAwsLink):
"""Helper class for constructing AWS EMR Cluster Link."""
Expand Down
5 changes: 4 additions & 1 deletion airflow/providers/amazon/aws/log/cloudwatch_task_handler.py
Expand Up @@ -19,16 +19,19 @@

from datetime import datetime, timedelta
from functools import cached_property
from typing import TYPE_CHECKING

import watchtower

from airflow.configuration import conf
from airflow.models import TaskInstance
from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook
from airflow.providers.amazon.aws.utils import datetime_to_epoch_utc_ms
from airflow.utils.log.file_task_handler import FileTaskHandler
from airflow.utils.log.logging_mixin import LoggingMixin

if TYPE_CHECKING:
from airflow.models import TaskInstance


class CloudwatchTaskHandler(FileTaskHandler, LoggingMixin):
"""
Expand Down
5 changes: 4 additions & 1 deletion airflow/providers/amazon/aws/notifications/chime.py
Expand Up @@ -18,10 +18,13 @@
from __future__ import annotations

from functools import cached_property
from typing import TYPE_CHECKING

from airflow.exceptions import AirflowOptionalProviderFeatureException
from airflow.providers.amazon.aws.hooks.chime import ChimeWebhookHook
from airflow.utils.context import Context

if TYPE_CHECKING:
from airflow.utils.context import Context

try:
from airflow.notifications.basenotifier import BaseNotifier
Expand Down
4 changes: 2 additions & 2 deletions airflow/providers/amazon/aws/operators/ecs.py
Expand Up @@ -24,8 +24,6 @@
from functools import cached_property
from typing import TYPE_CHECKING, Sequence

import boto3

from airflow.configuration import conf
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.models import BaseOperator
Expand All @@ -43,6 +41,8 @@
from airflow.utils.helpers import prune_dict

if TYPE_CHECKING:
import boto3

from airflow.models import TaskInstance
from airflow.utils.context import Context

Expand Down
4 changes: 2 additions & 2 deletions airflow/providers/amazon/aws/operators/rds.py
Expand Up @@ -23,8 +23,6 @@
from functools import cached_property
from typing import TYPE_CHECKING, Any, Sequence

from mypy_boto3_rds.type_defs import TagTypeDef

from airflow.configuration import conf
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.models import BaseOperator
Expand All @@ -39,6 +37,8 @@
from airflow.providers.amazon.aws.utils.waiter_with_logging import wait

if TYPE_CHECKING:
from mypy_boto3_rds.type_defs import TagTypeDef

from airflow.utils.context import Context


Expand Down
4 changes: 2 additions & 2 deletions airflow/providers/amazon/aws/sensors/ecs.py
Expand Up @@ -19,8 +19,6 @@
from functools import cached_property
from typing import TYPE_CHECKING, Sequence

import boto3

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.ecs import (
EcsClusterStates,
Expand All @@ -31,6 +29,8 @@
from airflow.sensors.base import BaseSensorOperator

if TYPE_CHECKING:
import boto3

from airflow.utils.context import Context

DEFAULT_CONN_ID: str = "aws_default"
Expand Down
2 changes: 1 addition & 1 deletion airflow/providers/amazon/aws/sensors/sqs.py
Expand Up @@ -26,13 +26,13 @@

from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.base_aws import BaseAwsConnection
from airflow.providers.amazon.aws.hooks.sqs import SqsHook
from airflow.providers.amazon.aws.triggers.sqs import SqsSensorTrigger
from airflow.providers.amazon.aws.utils.sqs import process_response
from airflow.sensors.base import BaseSensorOperator

if TYPE_CHECKING:
from airflow.providers.amazon.aws.hooks.base_aws import BaseAwsConnection
from airflow.utils.context import Context
from datetime import timedelta

Expand Down
5 changes: 3 additions & 2 deletions airflow/providers/amazon/aws/transfers/mongo_to_s3.py
Expand Up @@ -21,14 +21,15 @@
from typing import TYPE_CHECKING, Any, Iterable, Sequence, cast

from bson import json_util
from pymongo.command_cursor import CommandCursor
from pymongo.cursor import Cursor

from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
from airflow.providers.mongo.hooks.mongo import MongoHook

if TYPE_CHECKING:
from pymongo.command_cursor import CommandCursor
from pymongo.cursor import Cursor

from airflow.utils.context import Context


Expand Down
2 changes: 1 addition & 1 deletion airflow/providers/amazon/aws/transfers/sql_to_s3.py
Expand Up @@ -28,11 +28,11 @@
from airflow.hooks.base import BaseHook
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
from airflow.providers.common.sql.hooks.sql import DbApiHook

if TYPE_CHECKING:
import pandas as pd

from airflow.providers.common.sql.hooks.sql import DbApiHook
from airflow.utils.context import Context


Expand Down
6 changes: 5 additions & 1 deletion airflow/providers/amazon/aws/triggers/athena.py
Expand Up @@ -16,10 +16,14 @@
# under the License.
from __future__ import annotations

from typing import TYPE_CHECKING

from airflow.providers.amazon.aws.hooks.athena import AthenaHook
from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger

if TYPE_CHECKING:
from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook


class AthenaTrigger(AwsBaseWaiterTrigger):
"""
Expand Down
6 changes: 4 additions & 2 deletions airflow/providers/amazon/aws/triggers/base.py
Expand Up @@ -18,12 +18,14 @@
from __future__ import annotations

from abc import abstractmethod
from typing import Any, AsyncIterator
from typing import TYPE_CHECKING, Any, AsyncIterator

from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
from airflow.providers.amazon.aws.utils.waiter_with_logging import async_wait
from airflow.triggers.base import BaseTrigger, TriggerEvent

if TYPE_CHECKING:
from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook


class AwsBaseWaiterTrigger(BaseTrigger):
"""
Expand Down
6 changes: 4 additions & 2 deletions airflow/providers/amazon/aws/triggers/batch.py
Expand Up @@ -19,16 +19,18 @@
import asyncio
import itertools
from functools import cached_property
from typing import Any
from typing import TYPE_CHECKING, Any

from botocore.exceptions import WaiterError
from deprecated import deprecated

from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
from airflow.providers.amazon.aws.hooks.batch_client import BatchClientHook
from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger
from airflow.triggers.base import BaseTrigger, TriggerEvent

if TYPE_CHECKING:
from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook


@deprecated(reason="use BatchJobTrigger instead")
class BatchOperatorTrigger(BaseTrigger):
Expand Down
6 changes: 4 additions & 2 deletions airflow/providers/amazon/aws/triggers/ecs.py
Expand Up @@ -18,18 +18,20 @@
from __future__ import annotations

import asyncio
from typing import Any, AsyncIterator
from typing import TYPE_CHECKING, Any, AsyncIterator

from botocore.exceptions import ClientError, WaiterError

from airflow import AirflowException
from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
from airflow.providers.amazon.aws.hooks.ecs import EcsHook
from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook
from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger
from airflow.providers.amazon.aws.utils.task_log_fetcher import AwsTaskLogFetcher
from airflow.triggers.base import BaseTrigger, TriggerEvent

if TYPE_CHECKING:
from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook


class ClusterActiveTrigger(AwsBaseWaiterTrigger):
"""
Expand Down
6 changes: 4 additions & 2 deletions airflow/providers/amazon/aws/triggers/eks.py
Expand Up @@ -17,15 +17,17 @@
from __future__ import annotations

import warnings
from typing import Any
from typing import TYPE_CHECKING, Any

from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
from airflow.providers.amazon.aws.hooks.eks import EksHook
from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger
from airflow.providers.amazon.aws.utils.waiter_with_logging import async_wait
from airflow.triggers.base import TriggerEvent

if TYPE_CHECKING:
from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook


class EksCreateClusterTrigger(AwsBaseWaiterTrigger):
"""
Expand Down

0 comments on commit 667ab8c

Please sign in to comment.