Skip to content

Commit

Permalink
Rework provider manager to treat Airflow core hooks like other provid…
Browse files Browse the repository at this point in the history
…er hooks (#33051)
  • Loading branch information
jscheffl committed Aug 6, 2023
1 parent 6b21b79 commit 3461a23
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 10 deletions.
2 changes: 1 addition & 1 deletion airflow/cli/commands/connection_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def _valid_uri(uri: str) -> bool:
@cache
def _get_connection_types() -> list[str]:
"""Returns connection types available."""
_connection_types = ["fs", "mesos_framework-id", "email", "generic"]
_connection_types = []
providers_manager = ProvidersManager()
for connection_type, provider_info in providers_manager.hooks.items():
if provider_info:
Expand Down
42 changes: 40 additions & 2 deletions airflow/hooks/filesystem.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
# under the License.
from __future__ import annotations

from pathlib import Path
from typing import Any

from airflow.hooks.base import BaseHook


Expand All @@ -33,9 +36,32 @@ class FSHook(BaseHook):
Extra: {"path": "/tmp"}
"""

def __init__(self, conn_id: str = "fs_default"):
conn_name_attr = "fs_conn_id"
default_conn_name = "fs_default"
conn_type = "fs"
hook_name = "File (path)"

@staticmethod
def get_connection_form_widgets() -> dict[str, Any]:
"""Returns connection widgets to add to connection form."""
from flask_appbuilder.fieldwidgets import BS3TextFieldWidget
from flask_babel import lazy_gettext
from wtforms import StringField

return {"path": StringField(lazy_gettext("Path"), widget=BS3TextFieldWidget())}

@staticmethod
def get_ui_field_behaviour() -> dict[str, Any]:
"""Returns custom field behaviour."""
return {
"hidden_fields": ["host", "schema", "port", "login", "password", "extra"],
"relabeling": {},
"placeholders": {},
}

def __init__(self, fs_conn_id: str = default_conn_name):
super().__init__()
conn = self.get_connection(conn_id)
conn = self.get_connection(fs_conn_id)
self.basepath = conn.extra_dejson.get("path", "")
self.conn = conn

Expand All @@ -49,3 +75,15 @@ def get_path(self) -> str:
:return: the path.
"""
return self.basepath

def test_connection(self):
"""Test File connection."""
try:
p = self.get_path()
if not p:
return False, "File Path is undefined."
if not Path(p).exists():
return False, f"Path {p} does not exist."
return True, f"Path {p} is existing."
except Exception as e:
return False, str(e)
32 changes: 32 additions & 0 deletions airflow/providers_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from packaging.utils import canonicalize_name

from airflow.exceptions import AirflowOptionalProviderFeatureException
from airflow.hooks.filesystem import FSHook
from airflow.typing_compat import Literal
from airflow.utils import yaml
from airflow.utils.entry_points import entry_points_with_dist
Expand Down Expand Up @@ -431,6 +432,37 @@ def __init__(self):
)
# Set of plugins contained in providers
self._plugins_set: set[PluginInfo] = set()
self._init_airflow_core_hooks()

def _init_airflow_core_hooks(self):
"""Initializes the hooks dict with default hooks from Airflow core."""
core_dummy_hooks = {
"generic": "Generic",
"email": "Email",
"mesos_framework-id": "Mesos Framework ID",
}
for key, display in core_dummy_hooks.items():
self._hooks_lazy_dict[key] = HookInfo(
hook_class_name=None,
connection_id_attribute_name=None,
package_name=None,
hook_name=display,
connection_type=None,
connection_testable=False,
)
for cls in [FSHook]:
package_name = cls.__module__
hook_class_name = f"{cls.__module__}.{cls.__name__}"
hook_info = self._import_hook(
connection_type=None,
provider_info=None,
hook_class_name=hook_class_name,
package_name=package_name,
)
self._hook_provider_dict[hook_info.connection_type] = HookClassProvider(
hook_class_name=hook_class_name, package_name=package_name
)
self._hooks_lazy_dict[hook_info.connection_type] = hook_info

@provider_info_cache("list")
def initialize_providers_list(self):
Expand Down
4 changes: 0 additions & 4 deletions airflow/www/forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,10 +189,6 @@ def create_connection_form_class() -> type[DynamicForm]:

def _iter_connection_types() -> Iterator[tuple[str, str]]:
"""List available connection types."""
yield ("email", "Email")
yield ("fs", "File (path)")
yield ("generic", "Generic")
yield ("mesos_framework-id", "Mesos Framework ID")
for connection_type, provider_info in providers_manager.hooks.items():
if provider_info:
yield (connection_type, provider_info.hook_name)
Expand Down
6 changes: 3 additions & 3 deletions tests/always/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -753,14 +753,14 @@ def test_connection_test_success(self):
@mock.patch.dict(
"os.environ",
{
"AIRFLOW_CONN_TEST_URI_NO_HOOK": "fs://",
"AIRFLOW_CONN_TEST_URI_NO_HOOK": "unknown://",
},
)
def test_connection_test_no_hook(self):
conn = Connection(conn_id="test_uri_no_hook", conn_type="fs")
conn = Connection(conn_id="test_uri_no_hook", conn_type="unknown")
res = conn.test_connection()
assert res[0] is False
assert res[1] == 'Unknown hook type "fs"'
assert res[1] == 'Unknown hook type "unknown"'

@mock.patch.dict(
"os.environ",
Expand Down

0 comments on commit 3461a23

Please sign in to comment.