Improve modules import in Airflow providers by moving some of them into a type-checking block (#33754)
hussein-awala committed Aug 27, 2023
1 parent 6802d41 commit 9d8c77e
Showing 61 changed files with 223 additions and 120 deletions.
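
The change repeated across these files follows one pattern: imports that are only needed for type annotations move out of module scope and into an "if TYPE_CHECKING:" block, which static type checkers evaluate but the interpreter skips at runtime (typing.TYPE_CHECKING is False outside of static analysis), so importing the provider module no longer pays for those dependencies. A minimal sketch of the pattern, using an illustrative function rather than code taken from the diff:

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated only by static type checkers (mypy, pyright); skipped at
    # runtime, so importing this module does not import sqlalchemy.
    from sqlalchemy.engine import Connection


def describe_connection(conn: Connection) -> str:
    # With "from __future__ import annotations", the annotation above is
    # stored as a string and never evaluated, so no runtime import is needed.
    return str(conn)

Names used at runtime (exception classes in except blocks, isinstance checks, default values) cannot be deferred this way, which is why some imports in the hunks below stay at module level.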
6 changes: 4 additions & 2 deletions airflow/providers/apache/beam/hooks/beam.py
@@ -23,15 +23,14 @@
 import copy
 import functools
 import json
-import logging
 import os
 import select
 import shlex
 import shutil
 import subprocess
 import tempfile
 import textwrap
-from typing import Callable
+from typing import TYPE_CHECKING, Callable

 from packaging.version import Version

@@ -40,6 +39,9 @@
 from airflow.providers.google.go_module_utils import init_module, install_dependencies
 from airflow.utils.python_virtualenv import prepare_virtualenv

+if TYPE_CHECKING:
+    import logging
+

 class BeamRunnerType:
     """
6 changes: 4 additions & 2 deletions airflow/providers/apache/drill/hooks/drill.py
@@ -17,13 +17,15 @@
 # under the License.
 from __future__ import annotations

-from typing import Any, Iterable
+from typing import TYPE_CHECKING, Any, Iterable

 from sqlalchemy import create_engine
-from sqlalchemy.engine import Connection

 from airflow.providers.common.sql.hooks.sql import DbApiHook

+if TYPE_CHECKING:
+    from sqlalchemy.engine import Connection
+

 class DrillHook(DbApiHook):
     """
4 changes: 2 additions & 2 deletions airflow/providers/apache/flink/operators/flink_kubernetes.py
@@ -20,12 +20,12 @@
 from functools import cached_property
 from typing import TYPE_CHECKING, Sequence

-from kubernetes.client import CoreV1Api
-
 from airflow.models import BaseOperator
 from airflow.providers.cncf.kubernetes.hooks.kubernetes import KubernetesHook

 if TYPE_CHECKING:
+    from kubernetes.client import CoreV1Api
+
     from airflow.utils.context import Context

6 changes: 5 additions & 1 deletion airflow/providers/apache/impala/hooks/impala.py
@@ -16,11 +16,15 @@
 # under the License.
 from __future__ import annotations

+from typing import TYPE_CHECKING
+
 from impala.dbapi import connect
-from impala.interface import Connection

 from airflow.providers.common.sql.hooks.sql import DbApiHook

+if TYPE_CHECKING:
+    from impala.interface import Connection
+

 class ImpalaHook(DbApiHook):
     """Interact with Apache Impala through impyla."""
6 changes: 4 additions & 2 deletions airflow/providers/apache/livy/hooks/livy.py
@@ -21,18 +21,20 @@
 import json
 import re
 from enum import Enum
-from typing import Any, Sequence
+from typing import TYPE_CHECKING, Any, Sequence

 import aiohttp
 import requests
 from aiohttp import ClientResponseError
 from asgiref.sync import sync_to_async

 from airflow.exceptions import AirflowException
-from airflow.models import Connection
 from airflow.providers.http.hooks.http import HttpAsyncHook, HttpHook
 from airflow.utils.log.logging_mixin import LoggingMixin

+if TYPE_CHECKING:
+    from airflow.models import Connection
+

 class BatchState(Enum):
     """Batch session states."""
6 changes: 4 additions & 2 deletions airflow/providers/apache/pinot/hooks/pinot.py
@@ -19,15 +19,17 @@

 import os
 import subprocess
-from typing import Any, Iterable, Mapping
+from typing import TYPE_CHECKING, Any, Iterable, Mapping

 from pinotdb import connect

 from airflow.exceptions import AirflowException
 from airflow.hooks.base import BaseHook
-from airflow.models import Connection
 from airflow.providers.common.sql.hooks.sql import DbApiHook

+if TYPE_CHECKING:
+    from airflow.models import Connection
+

 class PinotAdminHook(BaseHook):
     """
6 changes: 4 additions & 2 deletions airflow/providers/apprise/notifications/apprise.py
@@ -18,7 +18,7 @@
 from __future__ import annotations

 from functools import cached_property
-from typing import Iterable
+from typing import TYPE_CHECKING, Iterable

 from airflow.exceptions import AirflowOptionalProviderFeatureException

@@ -29,10 +29,12 @@
         "Failed to import BaseNotifier. This feature is only available in Airflow versions >= 2.6.0"
     )

-from apprise import AppriseConfig, NotifyFormat, NotifyType

 from airflow.providers.apprise.hooks.apprise import AppriseHook

+if TYPE_CHECKING:
+    from apprise import AppriseConfig, NotifyFormat, NotifyType
+

 class AppriseNotifier(BaseNotifier):
     """
3 changes: 2 additions & 1 deletion airflow/providers/celery/executors/celery_executor.py
@@ -23,7 +23,6 @@
 """
 from __future__ import annotations

-import argparse
 import logging
 import math
 import operator
@@ -83,6 +82,8 @@


 if TYPE_CHECKING:
+    import argparse
+
     from celery import Task

     from airflow.executors.base_executor import CommandType, TaskTuple
5 changes: 3 additions & 2 deletions airflow/providers/celery/executors/celery_executor_utils.py
@@ -34,7 +34,6 @@
 from celery import Celery, Task, states as celery_states
 from celery.backends.base import BaseKeyValueStoreBackend
 from celery.backends.database import DatabaseBackend, Task as TaskDb, retry, session_cleanup
-from celery.result import AsyncResult
 from celery.signals import import_modules as celery_import_modules
 from setproctitle import setproctitle
 from sqlalchemy import select
@@ -43,7 +42,6 @@
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException, RemovedInAirflow3Warning
 from airflow.executors.base_executor import BaseExecutor
-from airflow.models.taskinstance import TaskInstanceKey
 from airflow.providers.celery.executors.default_celery import DEFAULT_CELERY_CONFIG
 from airflow.stats import Stats
 from airflow.utils.dag_parsing_context import _airflow_parsing_context_manager
@@ -54,7 +52,10 @@
 log = logging.getLogger(__name__)

 if TYPE_CHECKING:
+    from celery.result import AsyncResult
+
     from airflow.executors.base_executor import CommandType, EventBufferValueType
+    from airflow.models.taskinstance import TaskInstanceKey

     TaskInstanceInCelery = Tuple[TaskInstanceKey, CommandType, Optional[str], Task]

@@ -20,8 +20,6 @@
 from functools import cached_property
 from typing import TYPE_CHECKING, Sequence

-from airflow.callbacks.base_callback_sink import BaseCallbackSink
-from airflow.callbacks.callback_requests import CallbackRequest
 from airflow.configuration import conf
 from airflow.providers.celery.executors.celery_executor import CeleryExecutor

@@ -36,6 +34,8 @@
 from airflow.utils.providers_configuration_loader import providers_configuration_loaded

 if TYPE_CHECKING:
+    from airflow.callbacks.base_callback_sink import BaseCallbackSink
+    from airflow.callbacks.callback_requests import CallbackRequest
     from airflow.executors.base_executor import CommandType, EventBufferValueType, QueuedTaskInstanceType
     from airflow.models.taskinstance import SimpleTaskInstance, TaskInstance
     from airflow.models.taskinstancekey import TaskInstanceKey
6 changes: 4 additions & 2 deletions airflow/providers/databricks/hooks/databricks_base.py
@@ -28,7 +28,7 @@
 import platform
 import time
 from functools import cached_property
-from typing import Any
+from typing import TYPE_CHECKING, Any
 from urllib.parse import urlsplit

 import aiohttp
@@ -48,9 +48,11 @@
 from airflow import __version__
 from airflow.exceptions import AirflowException
 from airflow.hooks.base import BaseHook
-from airflow.models import Connection
 from airflow.providers_manager import ProvidersManager

+if TYPE_CHECKING:
+    from airflow.models import Connection
+
 # https://docs.microsoft.com/en-us/azure/databricks/dev-tools/api/latest/aad/service-prin-aad-token#--get-an-azure-active-directory-access-token
 # https://docs.microsoft.com/en-us/graph/deployments#app-registration-and-token-service-root-endpoints
 AZURE_DEFAULT_AD_ENDPOINT = "https://login.microsoftonline.com"
6 changes: 4 additions & 2 deletions airflow/providers/databricks/hooks/databricks_sql.py
@@ -18,15 +18,17 @@

 from contextlib import closing
 from copy import copy
-from typing import Any, Callable, Iterable, Mapping, TypeVar, overload
+from typing import TYPE_CHECKING, Any, Callable, Iterable, Mapping, TypeVar, overload

 from databricks import sql  # type: ignore[attr-defined]
-from databricks.sql.client import Connection  # type: ignore[attr-defined]

 from airflow.exceptions import AirflowException
 from airflow.providers.common.sql.hooks.sql import DbApiHook, return_single_query_results
 from airflow.providers.databricks.hooks.databricks_base import BaseDatabricksHook

+if TYPE_CHECKING:
+    from databricks.sql.client import Connection
+
 LIST_SQL_ENDPOINTS_ENDPOINT = ("GET", "api/2.0/sql/endpoints")

3 changes: 2 additions & 1 deletion airflow/providers/databricks/operators/databricks.py
@@ -21,7 +21,6 @@
 import time
 import warnings
 from functools import cached_property
-from logging import Logger
 from typing import TYPE_CHECKING, Any, Sequence

 from airflow.configuration import conf
@@ -32,6 +31,8 @@
 from airflow.providers.databricks.utils.databricks import normalise_json_content, validate_trigger_event

 if TYPE_CHECKING:
+    from logging import Logger
+
     from airflow.models.taskinstancekey import TaskInstanceKey
     from airflow.utils.context import Context

3 changes: 2 additions & 1 deletion airflow/providers/databricks/operators/databricks_sql.py
@@ -22,7 +22,6 @@
 import json
 from typing import TYPE_CHECKING, Any, Sequence

-from databricks.sql.types import Row
 from databricks.sql.utils import ParamEscaper

 from airflow.exceptions import AirflowException
@@ -31,6 +30,8 @@
 from airflow.providers.databricks.hooks.databricks_sql import DatabricksSqlHook

 if TYPE_CHECKING:
+    from databricks.sql.types import Row
+
     from airflow.utils.context import Context

6 changes: 4 additions & 2 deletions airflow/providers/elasticsearch/hooks/elasticsearch.py
@@ -19,16 +19,18 @@

 import warnings
 from functools import cached_property
-from typing import Any
+from typing import TYPE_CHECKING, Any
 from urllib import parse

 from elasticsearch import Elasticsearch

 from airflow.exceptions import AirflowProviderDeprecationWarning
 from airflow.hooks.base import BaseHook
-from airflow.models.connection import Connection as AirflowConnection
 from airflow.providers.common.sql.hooks.sql import DbApiHook

+if TYPE_CHECKING:
+    from airflow.models.connection import Connection as AirflowConnection
+

 def connect(
     host: str = "localhost",
7 changes: 5 additions & 2 deletions airflow/providers/elasticsearch/log/es_task_handler.py
@@ -21,7 +21,6 @@
 import sys
 import warnings
 from collections import defaultdict
-from datetime import datetime
 from operator import attrgetter
 from time import time
 from typing import TYPE_CHECKING, Any, Callable, List, Tuple
@@ -35,14 +34,18 @@
 from airflow.configuration import conf
 from airflow.exceptions import AirflowProviderDeprecationWarning
 from airflow.models.dagrun import DagRun
-from airflow.models.taskinstance import TaskInstance
 from airflow.providers.elasticsearch.log.es_json_formatter import ElasticsearchJSONFormatter
 from airflow.providers.elasticsearch.log.es_response import ElasticSearchResponse, Hit
 from airflow.utils import timezone
 from airflow.utils.log.file_task_handler import FileTaskHandler
 from airflow.utils.log.logging_mixin import ExternalLoggingMixin, LoggingMixin
 from airflow.utils.session import create_session

+if TYPE_CHECKING:
+    from datetime import datetime
+
+    from airflow.models.taskinstance import TaskInstance
+
 LOG_LINE_DEFAULTS = {"exc_text": "", "stack_info": ""}
 # Elasticsearch hosted log type
 EsLogMsgType = List[Tuple[str, str]]
6 changes: 4 additions & 2 deletions airflow/providers/facebook/ads/hooks/ads.py
@@ -21,16 +21,18 @@
 import time
 from enum import Enum
 from functools import cached_property
-from typing import Any
+from typing import TYPE_CHECKING, Any

 from facebook_business.adobjects.adaccount import AdAccount
 from facebook_business.adobjects.adreportrun import AdReportRun
-from facebook_business.adobjects.adsinsights import AdsInsights
 from facebook_business.api import FacebookAdsApi

 from airflow.exceptions import AirflowException
 from airflow.hooks.base import BaseHook

+if TYPE_CHECKING:
+    from facebook_business.adobjects.adsinsights import AdsInsights
+

 class JobStatus(Enum):
     """Available options for facebook async task status."""
8 changes: 5 additions & 3 deletions airflow/providers/hashicorp/hooks/vault.py
@@ -19,11 +19,9 @@

 import json
 import warnings
-from typing import Any
+from typing import TYPE_CHECKING, Any

-import hvac
 from hvac.exceptions import VaultError
-from requests import Response

 from airflow.exceptions import AirflowProviderDeprecationWarning
 from airflow.hooks.base import BaseHook
@@ -34,6 +32,10 @@
 )
 from airflow.utils.helpers import merge_dicts

+if TYPE_CHECKING:
+    import hvac
+    from requests import Response
+

 class VaultHook(BaseHook):
     """
6 changes: 3 additions & 3 deletions airflow/providers/http/operators/http.py
@@ -21,16 +21,16 @@
 import pickle
 from typing import TYPE_CHECKING, Any, Callable, Sequence

-from requests import Response
-from requests.auth import AuthBase
-
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException
 from airflow.models import BaseOperator
 from airflow.providers.http.hooks.http import HttpHook
 from airflow.providers.http.triggers.http import HttpTrigger

 if TYPE_CHECKING:
+    from requests import Response
+    from requests.auth import AuthBase
+
     from airflow.utils.context import Context

