-
Notifications
You must be signed in to change notification settings - Fork 13.7k
/
providers_manager.py
1270 lines (1096 loc) · 54.8 KB
/
providers_manager.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Manages all providers."""
from __future__ import annotations
import fnmatch
import functools
import inspect
import json
import logging
import os
import sys
import traceback
import warnings
from dataclasses import dataclass
from functools import wraps
from time import perf_counter
from typing import TYPE_CHECKING, Any, Callable, MutableMapping, NamedTuple, TypeVar, cast
from packaging.utils import canonicalize_name
from airflow.exceptions import AirflowOptionalProviderFeatureException
from airflow.hooks.filesystem import FSHook
from airflow.hooks.package_index import PackageIndexHook
from airflow.utils import yaml
from airflow.utils.entry_points import entry_points_with_dist
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.module_loading import import_string
from airflow.utils.singleton import Singleton
log = logging.getLogger(__name__)

# ``importlib.resources.files`` is available in the standard library from Python 3.9; older
# interpreters fall back to the ``importlib_resources`` package.
if sys.version_info >= (3, 9):
    from importlib.resources import files as resource_files
else:
    from importlib_resources import files as resource_files

# Minimum supported version per provider package. Installed providers older than this are reported
# with a warning by ``ProvidersManager._verify_all_providers_all_compatible``.
MIN_PROVIDER_VERSIONS = {
    "apache-airflow-providers-celery": "2.1.0",
}
def _ensure_prefix_for_placeholders(field_behaviors: dict[str, Any], conn_type: str):
    """
    Verify the correct placeholder prefix.

    Custom connection fields of every connection type share one dictionary, so internally they
    are namespaced as ``extra__<conn_type>__<field>``. Users may nevertheless write the short,
    unprefixed name in the ``placeholders`` section returned by ``get_ui_field_behaviour``. This
    function rewrites such short names to the prefixed form. Built-in connection attributes and
    names that already start with ``extra__`` are left untouched.

    :param field_behaviors: field behaviour dict; its ``placeholders`` entry (if any) is replaced
        in place
    :param conn_type: connection type used to build the prefix
    :return: the same ``field_behaviors`` object
    """
    builtin_attributes = {"host", "schema", "login", "password", "port", "extra"}

    def _with_prefix(name):
        if name in builtin_attributes or name.startswith("extra__"):
            return name
        return f"extra__{conn_type}__{name}"

    if "placeholders" not in field_behaviors:
        return field_behaviors
    original_placeholders = field_behaviors["placeholders"]
    field_behaviors["placeholders"] = {
        _with_prefix(name): text for name, text in original_placeholders.items()
    }
    return field_behaviors
if TYPE_CHECKING:
    # Only needed for type annotations; this module uses ``from __future__ import annotations``,
    # so these names are never required at runtime.
    from airflow.decorators.base import TaskDecorator
    from airflow.hooks.base import BaseHook
    from airflow.typing_compat import Literal
class LazyDictWithCache(MutableMapping):
    """
    Lazy-loaded cached dictionary.

    When a callable is stored as a value, it is called (with no arguments) the first time its key
    is read. The result replaces the callable in the dictionary and is returned on that and every
    later read. A resolver may itself return a callable; that result is not called again.

    Assigning a new value to a key resets its resolution state, so a callable stored over a
    previously resolved key is called on its next read.
    """

    __slots__ = ["_resolved", "_raw_dict"]

    def __init__(self, *args, **kw):
        # Keys whose stored value is already the final (resolved) value
        self._resolved = set()
        self._raw_dict = dict(*args, **kw)

    def __setitem__(self, key, value):
        self._raw_dict.__setitem__(key, value)
        # The new value has not been resolved yet. Without this, a callable assigned to a key that
        # was resolved earlier would be returned raw instead of being called.
        self._resolved.discard(key)

    def __getitem__(self, key):
        value = self._raw_dict.__getitem__(key)
        if key not in self._resolved and callable(value):
            # exchange callable with result of calling it -- but only once! allow resolver to return a
            # callable itself
            value = value()
            self._resolved.add(key)
            self._raw_dict.__setitem__(key, value)
        return value

    def __delitem__(self, key):
        self._raw_dict.__delitem__(key)
        self._resolved.discard(key)

    def __iter__(self):
        return iter(self._raw_dict)

    def __len__(self):
        return len(self._raw_dict)

    def __contains__(self, key):
        # Membership must not trigger resolution of a lazy value
        return key in self._raw_dict
def _read_schema_from_resources_or_local_file(filename: str) -> dict:
    """
    Load a JSON schema shipped with the ``airflow`` package.

    The file is read through package resources first. If that raises ``TypeError`` or
    ``FileNotFoundError``, it is read from the directory containing this module instead.

    :param filename: name of the schema file, relative to the ``airflow`` package
    :return: the parsed JSON document
    """
    try:
        with resource_files("airflow").joinpath(filename).open("rb") as schema_file:
            return json.load(schema_file)
    except (TypeError, FileNotFoundError):
        import pathlib

        local_schema_path = pathlib.Path(__file__).parent / filename
        with local_schema_path.open("rb") as schema_file:
            return json.load(schema_file)
def _create_schema_validator(schema_filename: str):
    """
    Create a JSON schema validator for one of the schema files shipped with Airflow.

    The validator class is the one ``jsonschema.validators.validator_for`` selects for the schema.

    :param schema_filename: schema file name, resolved by ``_read_schema_from_resources_or_local_file``
    :return: a validator instance bound to the loaded schema
    """
    import jsonschema

    schema = _read_schema_from_resources_or_local_file(schema_filename)
    validator_cls = jsonschema.validators.validator_for(schema)
    return validator_cls(schema)


def _create_provider_info_schema_validator():
    """Creates JSON schema validator from the provider_info.schema.json."""
    return _create_schema_validator("provider_info.schema.json")


def _create_customized_form_field_behaviours_schema_validator():
    """Creates JSON schema validator from the customized_form_field_behaviours.schema.json."""
    return _create_schema_validator("customized_form_field_behaviours.schema.json")
def _check_builtin_provider_prefix(provider_package: str, class_name: str) -> bool:
    """
    Check that a class of an ``apache-airflow*`` package lives in that package's module tree.

    The expected module prefix is the package name without the leading ``apache-``, with dashes
    turned into dots. For example, ``apache-airflow-providers-google`` gives
    ``airflow.providers.google``. Packages not starting with ``apache-airflow`` always pass.

    :param provider_package: distribution name of the provider package
    :param class_name: dotted path of the class to check
    :return: True when the check passes; otherwise a warning is logged and False is returned
    """
    if not provider_package.startswith("apache-airflow"):
        return True
    expected_prefix = provider_package[len("apache-") :].replace("-", ".")
    if class_name.startswith(expected_prefix):
        return True
    log.warning(
        "Coherence check failed when importing '%s' from '%s' package. It should start with '%s'",
        class_name,
        provider_package,
        expected_prefix,
    )
    return False
@dataclass
class ProviderInfo:
    """
    Provider information.

    :param version: version string
    :param data: dictionary with information about the provider
    :param package_or_source: whether the provider is source files (``"source"``) or a PyPI package
        (``"package"``). When installed from sources we suppress provider import errors.

    After construction, ``is_source`` is True when ``package_or_source == "source"``.

    :raises ValueError: if ``package_or_source`` is neither ``"source"`` nor ``"package"``
    """

    version: str
    data: dict
    package_or_source: Literal["source"] | Literal["package"]

    def __post_init__(self):
        if self.package_or_source not in ("source", "package"):
            raise ValueError(
                f"Received {self.package_or_source!r} for `package_or_source`. "
                "Must be either 'package' or 'source'."
            )
        self.is_source = self.package_or_source == "source"
class HookClassProvider(NamedTuple):
    """Hook class and Provider it comes from."""

    # Dotted path of the hook class (``module.ClassName``)
    hook_class_name: str
    # Name of the provider package the hook comes from (for Airflow core hooks: the hook's module)
    package_name: str
class TriggerInfo(NamedTuple):
    """Trigger class and provider it comes from."""

    # Name of the trigger class (presumably its dotted path - confirm with the discovery code)
    trigger_class_name: str
    # Name of the provider package the trigger comes from
    package_name: str
    # Integration the trigger is listed under
    integration_name: str
class NotificationInfo(NamedTuple):
    """Notification class and provider it comes from."""

    # Name of the notification class (presumably its dotted path - confirm with the discovery code)
    notification_class_name: str
    # Name of the provider package the notification comes from
    package_name: str
class PluginInfo(NamedTuple):
    """Plugin class, name and provider it comes from."""

    # Name of the plugin
    name: str
    # Plugin class, stored as a string (presumably its dotted path - confirm with the discovery code)
    plugin_class: str
    # Name of the provider the plugin comes from
    provider_name: str
class HookInfo(NamedTuple):
    """
    Hook information.

    NOTE: the placeholder entries for Airflow core's "generic" and "email" connection types are
    built with ``None`` for ``hook_class_name``, ``connection_id_attribute_name``,
    ``package_name`` and ``connection_type``, despite the ``str`` annotations.
    """

    # Dotted path of the hook class
    hook_class_name: str
    # Name of the hook attribute holding the connection id
    connection_id_attribute_name: str
    # Name of the package the hook comes from
    package_name: str
    # Human readable hook name
    hook_name: str
    # Connection type served by the hook
    connection_type: str
    # Whether connections of this type can be tested
    connection_testable: bool
class ConnectionFormWidgetInfo(NamedTuple):
    """Connection Form Widget information."""

    # Dotted path of the hook class defining the widget
    hook_class_name: str
    # Name of the package the hook comes from
    package_name: str
    # The widget/field object itself (type not constrained here)
    field: Any
    # Name of the connection field the widget is for
    field_name: str
    # Whether the field holds sensitive data
    is_sensitive: bool
T = TypeVar("T", bound=Callable)
logger = logging.getLogger(__name__)
def log_debug_import_from_sources(class_name, e, provider_package):
    """
    Log debug imports from sources.

    Records, at debug level only, that ``class_name`` from ``provider_package`` failed to import.
    The exception ``e`` is attached to the log record.
    """
    message = "Optional feature disabled on exception when importing '%s' from '%s' package"
    log.debug(message, class_name, provider_package, exc_info=e)
def log_optional_feature_disabled(class_name, e, provider_package):
    """
    Log optional feature disabled.

    The exception ``e`` goes to the debug log. A short notice without a traceback is logged at
    info level so that users can see the feature is unavailable.
    """
    details = (class_name, provider_package)
    log.debug(
        "Optional feature disabled on exception when importing '%s' from '%s' package",
        *details,
        exc_info=e,
    )
    log.info("Optional provider feature disabled when importing '%s' from '%s' package", *details)
def log_import_warning(class_name, e, provider_package):
    """Log import warning: ``class_name`` from ``provider_package`` failed with exception ``e``."""
    template = "Exception when importing '%s' from '%s' package"
    log.warning(template, class_name, provider_package, exc_info=e)
# This is a temporary measure until all community providers add AirflowOptionalProviderFeatureException
# where they have optional features. We are going to add tests in our CI to catch all such cases and will
# fix them, but until then all "known unhandled optional feature errors" from community providers
# should be added here.
# Each entry is a (provider package name, substring of the ImportError message) pair; a match makes
# ``_correctness_check`` treat the import failure as a disabled optional feature.
KNOWN_UNHANDLED_OPTIONAL_FEATURE_ERRORS = [("apache-airflow-providers-google", "No module named 'paramiko'")]
def _correctness_check(
    provider_package: str, class_name: str, provider_info: ProviderInfo
) -> type[BaseHook] | None:
    """
    Performs coherence check on provider classes.

    For apache-airflow providers - it checks if it starts with appropriate package. For all providers
    it tries to import the provider - checking that there are no exceptions during importing.
    It logs appropriate warning in case it detects any problems.

    :param provider_package: name of the provider package
    :param class_name: name of the class to import
    :param provider_info: information about the provider; for providers installed from sources,
        import errors are only logged at debug level
    :return: the class if the class is OK, None otherwise.
    """
    if not _check_builtin_provider_prefix(provider_package, class_name):
        return None
    try:
        imported_class = import_string(class_name)
    except AirflowOptionalProviderFeatureException as e:
        # When the provider class raises AirflowOptionalProviderFeatureException
        # this is an expected case when only some classes in provider are
        # available. We just log debug level here and print info message in logs so that
        # the user is aware of it
        log_optional_feature_disabled(class_name, e, provider_package)
        return None
    except ImportError as e:
        if provider_info.is_source:
            # When we have providers from sources, then we just turn all import logs to debug logs
            # As this is pretty expected that you have a number of dependencies not installed
            # (we always have all providers from sources until we split providers to separate repo)
            log_debug_import_from_sources(class_name, e, provider_package)
            return None
        # ``e.msg`` is None unless the ImportError was raised with exactly one argument, and
        # ``"..." in None`` would raise TypeError inside this handler. ``str(e)`` equals ``e.msg``
        # when that is a string and is always a string otherwise.
        error_message = str(e)
        if "No module named 'airflow.providers." in error_message:
            # handle cases where another provider is missing. This can only happen if
            # there is an optional feature, so we log debug and print information about it
            log_optional_feature_disabled(class_name, e, provider_package)
            return None
        # Until we convert all providers to use AirflowOptionalProviderFeatureException
        # we assume any problem with importing another "provider" is because this is an
        # optional feature, so we log debug and print information about it
        if any(
            known_package == provider_package and known_message in error_message
            for known_package, known_message in KNOWN_UNHANDLED_OPTIONAL_FEATURE_ERRORS
        ):
            log_optional_feature_disabled(class_name, e, provider_package)
            return None
        # But when we have no idea - we print warning to logs
        log_import_warning(class_name, e, provider_package)
        return None
    except Exception as e:
        log_import_warning(class_name, e, provider_package)
        return None
    return imported_class
# We want better control over the initialization of parameters, and to be able to debug and test
# it, so we use our own decorator.
def provider_info_cache(cache_name: str) -> Callable[[T], T]:
    """
    Decorate and cache provider info.

    Decorator factory. The decorated method runs only while ``cache_name`` is absent from the
    manager's ``_initialized_cache``. The manager is taken from the first positional argument.
    After the method completes, ``cache_name`` is recorded, so later calls return ``None``
    immediately. If the method raises, nothing is recorded. The duration of each run is logged at
    debug level.

    :param cache_name: Name of the cache
    """

    def decorator(func: T):
        @wraps(func)
        def wrapper(*args, **kwargs):
            manager = args[0]
            if cache_name in manager._initialized_cache:
                return
            started_at = perf_counter()
            logger.debug("Initializing Providers Manager[%s]", cache_name)
            func(*args, **kwargs)
            manager._initialized_cache[cache_name] = True
            elapsed = perf_counter() - started_at
            logger.debug("Initialization of Providers Manager[%s] took %.2f seconds", cache_name, elapsed)

        return cast(T, wrapper)

    return decorator
class ProvidersManager(LoggingMixin, metaclass=Singleton):
"""
Manages all provider packages.
This is a Singleton class. The first time it is
instantiated, it discovers all available providers in installed packages and
local source folders (if airflow is run from sources).
"""
resource_version = "0"
_initialized: bool = False
_initialization_stack_trace = None
@staticmethod
def initialized() -> bool:
    """Return True once a ``ProvidersManager`` instance has been constructed (set in ``__init__``)."""
    return ProvidersManager._initialized
@staticmethod
def initialization_stack_trace() -> str | None:
    """Return the stack trace captured when the manager was constructed, or None if it never was."""
    return ProvidersManager._initialization_stack_trace
def __init__(self):
    """
    Initializes the manager.

    This method:

    - marks the (singleton) manager as initialized and records where that happened;
    - creates the empty containers that the lazy ``initialize_providers_*`` methods fill;
    - builds the JSON schema validators;
    - registers the hooks that come with Airflow core.
    """
    super().__init__()
    ProvidersManager._initialized = True
    # Stack trace of this construction, returned by ``initialization_stack_trace()``
    ProvidersManager._initialization_stack_trace = "".join(traceback.format_stack(inspect.currentframe()))
    # Names of the sections already initialized by methods decorated with ``provider_info_cache``
    self._initialized_cache: dict[str, bool] = {}
    # Keeps dict of providers keyed by provider package name
    self._provider_dict: dict[str, ProviderInfo] = {}
    # Keeps dict of hooks keyed by connection type
    self._hooks_dict: dict[str, HookInfo] = {}
    # Module names of filesystems defined by providers
    self._fs_set: set[str] = set()
    self._taskflow_decorators: dict[str, Callable] = LazyDictWithCache()  # type: ignore[assignment]
    # keeps mapping between connection_types and hook class, package they come from
    self._hook_provider_dict: dict[str, HookClassProvider] = {}
    # Keeps dict of hooks keyed by connection type. They are lazy evaluated at access time
    self._hooks_lazy_dict: LazyDictWithCache[str, HookInfo | Callable] = LazyDictWithCache()
    # Custom connection form widgets, keyed by the name of the extra field
    self._connection_form_widgets: dict[str, ConnectionFormWidgetInfo] = {}
    # Customizations for javascript fields are kept here
    self._field_behaviours: dict[str, dict] = {}
    self._extra_link_class_name_set: set[str] = set()
    self._logging_class_name_set: set[str] = set()
    self._auth_manager_class_name_set: set[str] = set()
    self._secrets_backend_class_name_set: set[str] = set()
    self._executor_class_name_set: set[str] = set()
    self._provider_configs: dict[str, dict[str, Any]] = {}
    self._api_auth_backend_module_names: set[str] = set()
    self._trigger_info_set: set[TriggerInfo] = set()
    self._notification_info_set: set[NotificationInfo] = set()
    self._provider_schema_validator = _create_provider_info_schema_validator()
    self._customized_form_fields_schema_validator = (
        _create_customized_form_field_behaviours_schema_validator()
    )
    # Set of plugins contained in providers
    self._plugins_set: set[PluginInfo] = set()
    self._init_airflow_core_hooks()
def _init_airflow_core_hooks(self):
    """
    Initializes the hooks dict with default hooks from Airflow core.

    This method:

    - registers placeholder ``HookInfo`` entries (no hook class) for the ``generic`` and ``email``
      connection types;
    - imports ``FSHook`` and ``PackageIndexHook`` and registers each under the connection type
      reported by ``_import_hook``.

    A hook for which ``_import_hook`` returns a falsy value is skipped rather than crashing the
    manager's construction. The same check is applied during hook-class-names discovery.
    """
    core_dummy_hooks = {
        "generic": "Generic",
        "email": "Email",
    }
    for key, display in core_dummy_hooks.items():
        self._hooks_lazy_dict[key] = HookInfo(
            hook_class_name=None,
            connection_id_attribute_name=None,
            package_name=None,
            hook_name=display,
            connection_type=None,
            connection_testable=False,
        )
    for cls in [FSHook, PackageIndexHook]:
        package_name = cls.__module__
        hook_class_name = f"{cls.__module__}.{cls.__name__}"
        hook_info = self._import_hook(
            connection_type=None,
            provider_info=None,
            hook_class_name=hook_class_name,
            package_name=package_name,
        )
        if not hook_info:
            # Previously this fell through to ``hook_info.connection_type`` and raised
            # AttributeError from ``__init__``.
            continue
        self._hook_provider_dict[hook_info.connection_type] = HookClassProvider(
            hook_class_name=hook_class_name, package_name=package_name
        )
        self._hooks_lazy_dict[hook_info.connection_type] = hook_info
@provider_info_cache("list")
def initialize_providers_list(self):
"""Lazy initialization of providers list."""
# Local source folders are loaded first. They should take precedence over the package ones for
# Development purpose. In production provider.yaml files are not present in the 'airflow" directory
# So there is no risk we are going to override package provider accidentally. This can only happen
# in case of local development
self._discover_all_airflow_builtin_providers_from_local_sources()
self._discover_all_providers_from_packages()
self._verify_all_providers_all_compatible()
self._provider_dict = dict(sorted(self._provider_dict.items()))
def _verify_all_providers_all_compatible(self):
    """
    Warn about providers older than the minimum version this Airflow supports.

    Only providers listed in ``MIN_PROVIDER_VERSIONS`` are checked. Incompatible ones are reported
    with a warning and are not removed.
    """
    from packaging import version as packaging_version

    for provider_id, info in self._provider_dict.items():
        min_version = MIN_PROVIDER_VERSIONS.get(provider_id)
        if not min_version:
            continue
        too_old = packaging_version.parse(min_version) > packaging_version.parse(info.version)
        if too_old:
            log.warning(
                "The package %s is not compatible with this version of Airflow. "
                "The package has version %s but the minimum supported version "
                "of the package is %s",
                provider_id,
                info.version,
                min_version,
            )
@provider_info_cache("hooks")
def initialize_providers_hooks(self):
"""Lazy initialization of providers hooks."""
self.initialize_providers_list()
self._discover_hooks()
self._hook_provider_dict = dict(sorted(self._hook_provider_dict.items()))
@provider_info_cache("filesystems")
def initialize_providers_filesystems(self):
"""Lazy initialization of providers filesystems."""
self.initialize_providers_list()
self._discover_filesystems()
@provider_info_cache("taskflow_decorators")
def initialize_providers_taskflow_decorator(self):
"""Lazy initialization of providers hooks."""
self.initialize_providers_list()
self._discover_taskflow_decorators()
@provider_info_cache("extra_links")
def initialize_providers_extra_links(self):
"""Lazy initialization of providers extra links."""
self.initialize_providers_list()
self._discover_extra_links()
@provider_info_cache("logging")
def initialize_providers_logging(self):
"""Lazy initialization of providers logging information."""
self.initialize_providers_list()
self._discover_logging()
@provider_info_cache("secrets_backends")
def initialize_providers_secrets_backends(self):
"""Lazy initialization of providers secrets_backends information."""
self.initialize_providers_list()
self._discover_secrets_backends()
@provider_info_cache("executors")
def initialize_providers_executors(self):
"""Lazy initialization of providers executors information."""
self.initialize_providers_list()
self._discover_executors()
@provider_info_cache("notifications")
def initialize_providers_notifications(self):
"""Lazy initialization of providers notifications information."""
self.initialize_providers_list()
self._discover_notifications()
@provider_info_cache("auth_managers")
def initialize_providers_auth_managers(self):
"""Lazy initialization of providers notifications information."""
self.initialize_providers_list()
self._discover_auth_managers()
@provider_info_cache("config")
def initialize_providers_configuration(self):
"""Lazy initialization of providers configuration information."""
self._initialize_providers_configuration()
def _initialize_providers_configuration(self):
    """
    Internal method to initialize providers configuration information.

    Use it when the caching of ``initialize_providers_configuration`` must not be triggered. For
    example, we may want to write configuration that includes providers in a context where no
    providers are loaded yet. In that case the original configuration is restored afterwards. We
    then want a subsequent ``initialize_providers_configuration`` call to load the providers'
    configuration again, so this initialization must not be cached.
    """
    self.initialize_providers_list()
    self._discover_config()
    # Now update conf with the new provider configuration from providers
    from airflow.configuration import conf

    conf.load_providers_configuration()
@provider_info_cache("auth_backends")
def initialize_providers_auth_backends(self):
"""Lazy initialization of providers API auth_backends information."""
self.initialize_providers_list()
self._discover_auth_backends()
@provider_info_cache("plugins")
def initialize_providers_plugins(self):
self.initialize_providers_list()
self._discover_plugins()
def _discover_all_providers_from_packages(self) -> None:
    """
    Discover all providers by scanning packages installed.

    The list of providers should be returned via the 'apache_airflow_provider'
    entrypoint as a dictionary conforming to the 'airflow/provider_info.schema.json'
    schema. Note that the schema is different at runtime than provider.yaml.schema.json.
    The development version of provider schema is more strict and changes together with
    the code. The runtime version is more relaxed (allows for additional properties)
    and verifies only the subset of fields that are needed at runtime.

    Packages already present in the providers dict (e.g. discovered from local sources) are skipped.

    :raises Exception: if the distribution name and the ``package-name`` reported by the provider
        info do not match
    """
    for entry_point, dist in entry_points_with_dist("apache_airflow_provider"):
        package_name = canonicalize_name(dist.metadata["name"])
        if package_name in self._provider_dict:
            continue
        log.debug("Loading %s from package %s", entry_point, package_name)
        version = dist.version
        provider_info = entry_point.load()()
        self._provider_schema_validator.validate(provider_info)
        provider_info_package_name = provider_info["package-name"]
        if package_name != provider_info_package_name:
            raise Exception(
                f"The package '{package_name}' from setuptools and "
                f"{provider_info_package_name} do not match. Please make sure they are aligned"
            )
        if package_name not in self._provider_dict:
            self._provider_dict[package_name] = ProviderInfo(version, provider_info, "package")
        else:
            # NOTE(review): only reachable if loading the entry point above registered this package
            # as a side effect; kept as a safety net.
            log.warning(
                "The provider for package '%s' could not be registered because providers for that "
                "package name have already been registered",
                package_name,
            )
def _discover_all_airflow_builtin_providers_from_local_sources(self) -> None:
    """
    Finds all built-in airflow providers if airflow is run from the local sources.

    It finds `provider.yaml` files for all such providers and registers the providers using those.

    This 'provider.yaml' scanning takes precedence over scanning packages installed
    in case you have both sources and packages installed, the providers will be loaded from
    the "airflow" sources rather than from the packages.

    Errors while scanning one path are logged as warnings and do not stop the scan of other paths.
    """
    try:
        import airflow.providers
    except ImportError:
        log.info("You have no providers installed.")
        return

    seen = set()
    for path in airflow.providers.__path__:  # type: ignore[attr-defined]
        try:
            # The same path can appear in the __path__ twice, under non-normalized paths (ie.
            # /path/to/repo/airflow/providers and /path/to/repo/./airflow/providers)
            path = os.path.realpath(path)
            if path not in seen:
                seen.add(path)
                self._add_provider_info_from_local_source_files_on_path(path)
        except Exception as e:
            # Lazy %-formatting (instead of an f-string) so the message is only built when emitted
            log.warning("Error when loading 'provider.yaml' files from %s airflow sources: %s", path, e)
def _add_provider_info_from_local_source_files_on_path(self, path) -> None:
    """
    Finds all the provider.yaml files in the directory specified.

    The package name is derived from the folder's location relative to ``path``. For example,
    ``<path>/apache/hive/provider.yaml`` is registered as ``apache-airflow-providers-apache-hive``.
    Once a ``provider.yaml`` has been loaded from a folder, its sub-folders are not scanned.

    :param path: path where to look for provider.yaml files
    """
    for folder, subdirs, files in os.walk(path, topdown=True):
        for filename in fnmatch.filter(files, "provider.yaml"):
            try:
                package_name = "apache-airflow-providers" + folder[len(path) :].replace(os.sep, "-")
                self._add_provider_info_from_local_source_file(os.path.join(folder, filename), package_name)
                # With topdown=True, clearing subdirs in place stops os.walk from descending further
                subdirs[:] = []
            except Exception as e:
                # "%s" - the previous "%e" is a float conversion and fails when formatting an
                # exception, so the warning was never emitted.
                log.warning("Error when loading 'provider.yaml' file from %s %s", folder, e)
def _add_provider_info_from_local_source_file(self, path, package_name) -> None:
    """
    Parses found provider.yaml file and adds found provider to the dictionary.

    The parsed content is validated against the provider info schema. The first entry of its
    ``versions`` list is used as the provider version. A provider already registered under
    ``package_name`` is not overridden. Any error is logged as a warning instead of being raised.

    :param path: full file path of the provider.yaml file
    :param package_name: name of the package
    """
    try:
        log.debug("Loading %s from %s", package_name, path)
        # Decode explicitly as UTF-8 (YAML's default) rather than with the locale's encoding
        with open(path, encoding="utf-8") as provider_yaml_file:
            provider_info = yaml.safe_load(provider_yaml_file)
        self._provider_schema_validator.validate(provider_info)
        version = provider_info["versions"][0]
        if package_name not in self._provider_dict:
            self._provider_dict[package_name] = ProviderInfo(version, provider_info, "source")
        else:
            log.warning(
                "The providers for package '%s' could not be registered because providers for that "
                "package name have already been registered",
                package_name,
            )
    except Exception as e:
        log.warning("Error when loading '%s'", path, exc_info=e)
def _discover_hooks_from_connection_types(
    self,
    hook_class_names_registered: set[str],
    already_registered_warning_connection_types: set[str],
    package_name: str,
    provider: ProviderInfo,
):
    """
    Discover hooks from the "connection-types" property.

    This is a newer, better method that replaces discovery from hook-class-names, as it
    allows lazy import of individual Hook classes when they are accessed.
    The "connection-types" keeps information about both - connection type and class
    name so we can discover all connection-types without importing the classes.

    :param hook_class_names_registered: set of registered hook class names for this provider; every
        hook class listed under "connection-types" is added to it
    :param already_registered_warning_connection_types: set of connections for which warning should be
        printed in logs as they were already registered before by another package
    :param package_name: name of the provider package
    :param provider: information about the provider; its ``data`` holds the provider info dict
    :return: True if the provider has a non-empty "connection-types" property, False otherwise
    """
    provider_uses_connection_types = False
    connection_types = provider.data.get("connection-types")
    if connection_types:
        for connection_type_dict in connection_types:
            connection_type = connection_type_dict["connection-type"]
            hook_class_name = connection_type_dict["hook-class-name"]
            hook_class_names_registered.add(hook_class_name)
            already_registered = self._hook_provider_dict.get(connection_type)
            if already_registered:
                if already_registered.package_name != package_name:
                    already_registered_warning_connection_types.add(connection_type)
                elif already_registered.hook_class_name != hook_class_name:
                    # Only warn when the same package maps the connection type to a *different*
                    # class, as the message states. This matches the hook-class-names discovery.
                    log.warning(
                        "The connection type '%s' is already registered in the"
                        " package '%s' with different class names: '%s' and '%s'. ",
                        connection_type,
                        package_name,
                        already_registered.hook_class_name,
                        hook_class_name,
                    )
            else:
                self._hook_provider_dict[connection_type] = HookClassProvider(
                    hook_class_name=hook_class_name, package_name=package_name
                )
                # Defer importing hook to access time by setting import hook method as dict value
                self._hooks_lazy_dict[connection_type] = functools.partial(
                    self._import_hook,
                    connection_type=connection_type,
                    provider_info=provider,
                )
        provider_uses_connection_types = True
    return provider_uses_connection_types
def _discover_hooks_from_hook_class_names(
    self,
    hook_class_names_registered: set[str],
    already_registered_warning_connection_types: set[str],
    package_name: str,
    provider: ProviderInfo,
    provider_uses_connection_types: bool,
):
    """
    Discover hooks from "hook-class-names' property.

    This property is deprecated but we should support it in Airflow 2.
    The hook-class-names array contained just Hook names without connection type,
    therefore we need to import all those classes immediately to know which connection types
    are supported. This makes it impossible to selectively only import those hooks that are used.

    :param hook_class_names_registered: hook class names already registered for this provider via
        "connection-types"; those are skipped here
    :param already_registered_warning_connection_types: list of connection hooks that we should warn
        about when finished discovery
    :param package_name: name of the provider package
    :param provider: class that keeps information about version and details of the provider
    :param provider_uses_connection_types: determines whether the provider uses "connection-types" new
        form of passing connection types
    :return: None
    """
    hook_class_names = provider.data.get("hook-class-names")
    if hook_class_names:
        for hook_class_name in hook_class_names:
            if hook_class_name in hook_class_names_registered:
                # Silently ignore the hook class - it's already marked for lazy-import by
                # connection-types discovery
                continue
            # Eager import: the connection type is only known after importing the hook
            hook_info = self._import_hook(
                connection_type=None,
                provider_info=provider,
                hook_class_name=hook_class_name,
                package_name=package_name,
            )
            if not hook_info:
                # Problem while importing the class - we ignore it. Log is written at import time
                continue
            already_registered = self._hook_provider_dict.get(hook_info.connection_type)
            if already_registered:
                if already_registered.package_name != package_name:
                    already_registered_warning_connection_types.add(hook_info.connection_type)
                else:
                    if already_registered.hook_class_name != hook_class_name:
                        log.warning(
                            "The hook connection type '%s' is registered twice in the"
                            " package '%s' with different class names: '%s' and '%s'. "
                            " Please fix it!",
                            hook_info.connection_type,
                            package_name,
                            already_registered.hook_class_name,
                            hook_class_name,
                        )
            else:
                self._hook_provider_dict[hook_info.connection_type] = HookClassProvider(
                    hook_class_name=hook_class_name, package_name=package_name
                )
                self._hooks_lazy_dict[hook_info.connection_type] = hook_info

        if not provider_uses_connection_types:
            warnings.warn(
                f"The provider {package_name} uses `hook-class-names` "
                "property in provider-info and has no `connection-types` one. "
                "The 'hook-class-names' property has been deprecated in favour "
                "of 'connection-types' in Airflow 2.2. Use **both** in case you want to "
                "have backwards compatibility with Airflow < 2.2",
                DeprecationWarning,
            )
    # Report connection types (collected by both discovery methods) that another package registered first
    for already_registered_connection_type in already_registered_warning_connection_types:
        log.warning(
            "The connection_type '%s' has been already registered by provider '%s.'",
            already_registered_connection_type,
            self._hook_provider_dict[already_registered_connection_type].package_name,
        )
def _discover_hooks(self) -> None:
    """
    Discover the hooks contributed by every provider in ``_provider_dict``.

    For each provider, ``_discover_hooks_from_connection_types`` runs first. Its result
    (whether the provider uses the ``connection-types`` form) is passed on to
    ``_discover_hooks_from_hook_class_names``. Both passes share two sets created fresh
    for that provider: the hook class names already registered and the duplicated
    connection types. After all providers are processed, ``_hook_provider_dict`` is
    rebuilt so that its keys are in sorted order.
    """
    for pkg_name, provider_info in self._provider_dict.items():
        # Scoped to one provider; the second pass consults what the first one recorded.
        registered_hook_classes: set[str] = set()
        duplicate_conn_types: set[str] = set()
        uses_connection_types = self._discover_hooks_from_connection_types(
            registered_hook_classes, duplicate_conn_types, pkg_name, provider_info
        )
        self._discover_hooks_from_hook_class_names(
            registered_hook_classes,
            duplicate_conn_types,
            pkg_name,
            provider_info,
            uses_connection_types,
        )
    ordered_conn_types = sorted(self._hook_provider_dict)
    self._hook_provider_dict = {
        conn_type: self._hook_provider_dict[conn_type] for conn_type in ordered_conn_types
    }
@provider_info_cache("import_all_hooks")
def _import_info_from_all_hooks(self):
    """
    Force-import every hook and initialize the connection form fields.

    Every value of ``_hooks_lazy_dict`` is read so that each hook gets imported.
    ``_field_behaviours`` is then rebuilt in sorted order.

    ``_connection_form_widgets`` is deliberately left unsorted. The UI Connections
    form expects the widgets in the order the hook defines them. The
    ``airflow providers widgets`` CLI command expects alphabetical order and does
    that sorting itself. The definition order could not be recovered if it were
    sorted here.
    """
    # Materialising the values is what forces the lazy imports; the list itself is discarded.
    list(self._hooks_lazy_dict.values())
    self._field_behaviours = {
        key: value for key, value in sorted(self._field_behaviours.items())
    }
def _discover_filesystems(self) -> None:
    """
    Retrieve all filesystems defined in the providers.

    Each module name listed under a provider's ``filesystems`` key is added to
    ``_fs_set``, but only when ``_correctness_check`` accepts ``<module>.get_fs``
    for that provider package.
    """
    for provider_package, provider in self._provider_dict.items():
        for fs_module_name in provider.data.get("filesystems", []):
            if _correctness_check(provider_package, fs_module_name + ".get_fs", provider):
                self._fs_set.add(fs_module_name)
    # ``_fs_set`` is a plain set and carries no ordering, so sorting it here would be
    # thrown away immediately. Callers that need a stable order must sort when reading.
def _discover_taskflow_decorators(self) -> None:
    """
    Register every decorator listed under a provider's ``task-decorators`` key.

    Each entry's ``name`` and ``class-name`` are forwarded, together with the owning
    provider package name, to ``_add_taskflow_decorator``.
    """
    for provider_package, provider in self._provider_dict.items():
        declared = provider.data.get("task-decorators", [])
        for spec in declared:
            decorator_name = spec["name"]
            decorator_path = spec["class-name"]
            self._add_taskflow_decorator(decorator_name, decorator_path, provider_package)
def _add_taskflow_decorator(self, name, decorator_class_name: str, provider_package: str) -> None:
    """
    Register a taskflow decorator under ``name`` as a deferred ``import_string`` call.

    Nothing is registered when ``_check_builtin_provider_prefix`` rejects the class
    path. If ``name`` is already taken, the first registration is kept and a warning
    naming the existing decorator is logged.
    """
    if not _check_builtin_provider_prefix(provider_package, decorator_class_name):
        return
    if name not in self._taskflow_decorators:
        # Store the import as a partial instead of importing the class right now.
        self._taskflow_decorators[name] = functools.partial(import_string, decorator_class_name)
        return
    try:
        registered = self._taskflow_decorators[name]
        registered_path = f"{registered.__module__}.{registered.__name__}"
    except Exception:
        # Resolving the existing entry failed (e.g. its import broke). Fall back to the
        # class path held as the first positional argument of the stored functools.partial.
        registered_path = self._taskflow_decorators._raw_dict[name].args[0]  # type: ignore[attr-defined]
    log.warning(
        "The taskflow decorator '%s' has been already registered (by %s).",
        name,
        registered_path,
    )
@staticmethod
def _get_attr(obj: Any, attr_name: str):
    """
    Return ``obj.<attr_name>``, or log a warning and return ``None`` when it is missing.

    Presence is tested with ``hasattr`` before the attribute is read.
    """
    if hasattr(obj, attr_name):
        return getattr(obj, attr_name)
    log.warning("The object '%s' is missing %s attribute and cannot be registered", obj, attr_name)
    return None
def _import_hook(
self,
connection_type: str | None,
provider_info: ProviderInfo,
hook_class_name: str | None = None,
package_name: str | None = None,
) -> HookInfo | None:
"""
Import hook and retrieve hook information.
Either connection_type (for lazy loading) or hook_class_name must be set - but not both).
Only needs package_name if hook_class_name is passed (for lazy loading, package_name
is retrieved from _connection_type_class_provider_dict together with hook_class_name).
:param connection_type: type of the connection
:param hook_class_name: name of the hook class
:param package_name: provider package - only needed in case connection_type is missing
: return
"""
from wtforms import BooleanField, IntegerField, PasswordField, StringField
if connection_type is None and hook_class_name is None:
raise ValueError("Either connection_type or hook_class_name must be set")
if connection_type is not None and hook_class_name is not None:
raise ValueError(
f"Both connection_type ({connection_type} and "
f"hook_class_name {hook_class_name} are set. Only one should be set!"
)
if connection_type is not None:
class_provider = self._hook_provider_dict[connection_type]
package_name = class_provider.package_name
hook_class_name = class_provider.hook_class_name
else:
if not hook_class_name:
raise ValueError("Either connection_type or hook_class_name must be set")
if not package_name:
raise ValueError(
f"Provider package name is not set when hook_class_name ({hook_class_name}) is used"
)
allowed_field_classes = [IntegerField, PasswordField, StringField, BooleanField]
hook_class = _correctness_check(package_name, hook_class_name, provider_info)
if hook_class is None:
return None
try:
module, class_name = hook_class_name.rsplit(".", maxsplit=1)
# Do not use attr here. We want to check only direct class fields not those
# inherited from parent hook. This way we add form fields only once for the whole
# hierarchy and we add it only from the parent hook that provides those!
if "get_connection_form_widgets" in hook_class.__dict__:
widgets = hook_class.get_connection_form_widgets()
if widgets:
for widget in widgets.values():
if widget.field_class not in allowed_field_classes:
log.warning(
"The hook_class '%s' uses field of unsupported class '%s'. "
"Only '%s' field classes are supported",
hook_class_name,
widget.field_class,
allowed_field_classes,
)
return None
self._add_widgets(package_name, hook_class, widgets)
if "get_ui_field_behaviour" in hook_class.__dict__:
field_behaviours = hook_class.get_ui_field_behaviour()
if field_behaviours:
self._add_customized_fields(package_name, hook_class, field_behaviours)
except ImportError as e:
if "No module named 'flask_appbuilder'" in e.msg:
log.warning(
"The hook_class '%s' is not fully initialized (UI widgets will be missing), because "
"the 'flask_appbuilder' package is not installed, however it is not required for "
"Airflow components to work",
hook_class_name,
)
except Exception as e:
log.warning(
"Exception when importing '%s' from '%s' package: %s",
hook_class_name,
package_name,
e,
)
return None
hook_connection_type = self._get_attr(hook_class, "conn_type")
if connection_type:
if hook_connection_type != connection_type:
log.warning(
"Inconsistency! The hook class '%s' declares connection type '%s'"
" but it is added by provider '%s' as connection_type '%s' in provider info. "
"This should be fixed!",
hook_class,
hook_connection_type,
package_name,
connection_type,
)
connection_type = hook_connection_type
connection_id_attribute_name: str = self._get_attr(hook_class, "conn_name_attr")
hook_name: str = self._get_attr(hook_class, "hook_name")