Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change type definition for provider_info_cache decorator #39750

Merged
merged 1 commit into from
May 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 10 additions & 13 deletions airflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import os
import sys
import warnings
from typing import TYPE_CHECKING

if os.environ.get("_AIRFLOW_PATCH_GEVENT"):
# If you are using gevents and start airflow webserver, you might want to run gevent monkeypatching
Expand Down Expand Up @@ -81,6 +82,13 @@
# Deprecated lazy imports
"AirflowException": (".exceptions", "AirflowException", True),
}
if TYPE_CHECKING:
# These objects are imported by PEP-562, however, static analyzers and IDE's
# have no idea about typing of these objects.
# Add it under TYPE_CHECKING block should help with it.
from airflow.models.dag import DAG
from airflow.models.dataset import Dataset
from airflow.models.xcom_arg import XComArg


def __getattr__(name: str):
Expand Down Expand Up @@ -119,24 +127,13 @@ def __getattr__(name: str):


if not settings.LAZY_LOAD_PROVIDERS:
from airflow import providers_manager
from airflow.providers_manager import ProvidersManager

manager = providers_manager.ProvidersManager()
manager = ProvidersManager()
manager.initialize_providers_list()
manager.initialize_providers_hooks()
manager.initialize_providers_extra_links()
if not settings.LAZY_LOAD_PLUGINS:
from airflow import plugins_manager

plugins_manager.ensure_plugins_loaded()


# This is never executed, but tricks static analyzers (PyDev, PyCharm,)
# into knowing the types of these symbols, and what
# they contain.
STATICA_HACK = True
globals()["kcah_acitats"[::-1].upper()] = False
if STATICA_HACK: # pragma: no cover
from airflow.models.dag import DAG
from airflow.models.dataset import Dataset
from airflow.models.xcom_arg import XComArg
26 changes: 14 additions & 12 deletions airflow/providers_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,14 @@
from dataclasses import dataclass
from functools import wraps
from time import perf_counter
from typing import TYPE_CHECKING, Any, Callable, MutableMapping, NamedTuple, TypeVar, cast
from typing import TYPE_CHECKING, Any, Callable, MutableMapping, NamedTuple, NoReturn, TypeVar

from packaging.utils import canonicalize_name

from airflow.exceptions import AirflowOptionalProviderFeatureException
from airflow.hooks.filesystem import FSHook
from airflow.hooks.package_index import PackageIndexHook
from airflow.typing_compat import ParamSpec
from airflow.utils import yaml
from airflow.utils.entry_points import entry_points_with_dist
from airflow.utils.log.logging_mixin import LoggingMixin
Expand All @@ -51,6 +52,9 @@
else:
from importlib_resources import files as resource_files

PS = ParamSpec("PS")
RT = TypeVar("RT")

MIN_PROVIDER_VERSIONS = {
"apache-airflow-providers-celery": "2.1.0",
}
Expand Down Expand Up @@ -261,11 +265,6 @@ class ConnectionFormWidgetInfo(NamedTuple):
is_sensitive: bool


T = TypeVar("T", bound=Callable)

logger = logging.getLogger(__name__)


def log_debug_import_from_sources(class_name, e, provider_package):
"""Log debug imports from sources."""
log.debug(
Expand Down Expand Up @@ -362,31 +361,34 @@ def _correctness_check(provider_package: str, class_name: str, provider_info: Pr

# We want to have better control over initialization of parameters and be able to debug and test it
# So we add our own decorator
def provider_info_cache(cache_name: str) -> Callable[[T], T]:
def provider_info_cache(cache_name: str) -> Callable[[Callable[PS, NoReturn]], Callable[PS, None]]:
"""
Decorate and cache provider info.

Decorator factory that create decorator that caches initialization of provider's parameters
:param cache_name: Name of the cache
"""

def provider_info_cache_decorator(func: T):
def provider_info_cache_decorator(func: Callable[PS, NoReturn]) -> Callable[PS, None]:
@wraps(func)
def wrapped_function(*args, **kwargs):
def wrapped_function(*args: PS.args, **kwargs: PS.kwargs) -> None:
providers_manager_instance = args[0]
if TYPE_CHECKING:
assert isinstance(providers_manager_instance, ProvidersManager)

if cache_name in providers_manager_instance._initialized_cache:
return
start_time = perf_counter()
logger.debug("Initializing Providers Manager[%s]", cache_name)
log.debug("Initializing Providers Manager[%s]", cache_name)
func(*args, **kwargs)
providers_manager_instance._initialized_cache[cache_name] = True
logger.debug(
log.debug(
"Initialization of Providers Manager[%s] took %.2f seconds",
cache_name,
perf_counter() - start_time,
)

return cast(T, wrapped_function)
return wrapped_function

return provider_info_cache_decorator

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ required-imports = ["from __future__ import annotations"]
combine-as-imports = true

[tool.ruff.lint.per-file-ignores]
"airflow/__init__.py" = ["F401"]
"airflow/__init__.py" = ["F401", "TCH004"]
"airflow/models/__init__.py" = ["F401", "TCH004"]
"airflow/models/sqla_models.py" = ["F401"]

Expand Down