From d51de50e5ce897223b0367bc03f458d6c1f0b7a2 Mon Sep 17 00:00:00 2001 From: Daniel Standish <15932138+dstandish@users.noreply.github.com> Date: Sat, 22 Oct 2022 14:21:31 -0700 Subject: [PATCH] Update WasbHook to reflect preference for unprefixed extra (#27024) Since 2.3 we don't need the extra prefix to edit custom fields with UI form. From 2.5 we won't need to use the prefix in the ui form behaviors method either. --- .../providers/microsoft/azure/hooks/wasb.py | 67 +++++++++++++++---- .../connections/wasb.rst | 14 ++-- .../microsoft/azure/hooks/test_wasb.py | 32 +++++++++ 3 files changed, 92 insertions(+), 21 deletions(-) diff --git a/airflow/providers/microsoft/azure/hooks/wasb.py b/airflow/providers/microsoft/azure/hooks/wasb.py index 21caa3d6dc96f..27e30f420fa59 100644 --- a/airflow/providers/microsoft/azure/hooks/wasb.py +++ b/airflow/providers/microsoft/azure/hooks/wasb.py @@ -27,6 +27,7 @@ import logging import os +from functools import wraps from typing import Any from azure.core.exceptions import HttpResponseError, ResourceExistsError, ResourceNotFoundError @@ -37,6 +38,34 @@ from airflow.hooks.base import BaseHook +def _ensure_prefixes(conn_type): + """ + Remove when provider min airflow version >= 2.5.0 since this is handled by + provider manager from that version. + """ + + def dec(func): + @wraps(func) + def inner(): + field_behaviors = func() + conn_attrs = {'host', 'schema', 'login', 'password', 'port', 'extra'} + + def _ensure_prefix(field): + if field not in conn_attrs and not field.startswith('extra__'): + return f"extra__{conn_type}__{field}" + else: + return field + + if 'placeholders' in field_behaviors: + placeholders = field_behaviors['placeholders'] + field_behaviors['placeholders'] = {_ensure_prefix(k): v for k, v in placeholders.items()} + return field_behaviors + + return inner + + return dec + + class WasbHook(BaseHook): """ Interacts with Azure Blob Storage through the ``wasb://`` protocol. @@ -67,21 +96,20 @@ def get_connection_form_widgets() -> dict[str, Any]: from wtforms import PasswordField, StringField return { - "extra__wasb__connection_string": PasswordField( + "connection_string": PasswordField( lazy_gettext('Blob Storage Connection String (optional)'), widget=BS3PasswordFieldWidget() ), - "extra__wasb__shared_access_key": PasswordField( + "shared_access_key": PasswordField( lazy_gettext('Blob Storage Shared Access Key (optional)'), widget=BS3PasswordFieldWidget() ), - "extra__wasb__tenant_id": StringField( + "tenant_id": StringField( lazy_gettext('Tenant Id (Active Directory Auth)'), widget=BS3TextFieldWidget() ), - "extra__wasb__sas_token": PasswordField( - lazy_gettext('SAS Token (optional)'), widget=BS3PasswordFieldWidget() - ), + "sas_token": PasswordField(lazy_gettext('SAS Token (optional)'), widget=BS3PasswordFieldWidget()), } @staticmethod + @_ensure_prefixes(conn_type='wasb') def get_ui_field_behaviour() -> dict[str, Any]: """Returns custom field behaviour""" return { @@ -96,10 +124,10 @@ def get_ui_field_behaviour() -> dict[str, Any]: 'login': 'account name', 'password': 'secret', 'host': 'account url', - 'extra__wasb__connection_string': 'connection string auth', - 'extra__wasb__tenant_id': 'tenant', - 'extra__wasb__shared_access_key': 'shared access key', - 'extra__wasb__sas_token': 'account url or token', + 'connection_string': 'connection string auth', + 'tenant_id': 'tenant', + 'shared_access_key': 'shared access key', + 'sas_token': 'account url or token', }, } @@ -119,6 +147,17 @@ def __init__( except ValueError: logger.setLevel(logging.WARNING) + def _get_field(self, extra_dict, field_name): + prefix = 'extra__wasb__' + if field_name.startswith('extra_'): + raise ValueError( + f"Got prefixed name {field_name}; please remove the '{prefix}' prefix " + f"when using this method." + ) + if field_name in extra_dict: + return extra_dict[field_name] or None + return extra_dict.get(f"{prefix}{field_name}") or None + def get_conn(self) -> BlobServiceClient: """Return the BlobServiceClient object.""" conn = self.get_connection(self.conn_id) @@ -130,17 +169,17 @@ def get_conn(self) -> BlobServiceClient: # https://docs.microsoft.com/en-us/azure/storage/blobs/storage-manage-access-to-resources return BlobServiceClient(account_url=conn.host, **extra) - connection_string = extra.pop('connection_string', extra.pop('extra__wasb__connection_string', None)) + connection_string = self._get_field(extra, 'connection_string') if connection_string: # connection_string auth takes priority return BlobServiceClient.from_connection_string(connection_string, **extra) - shared_access_key = extra.pop('shared_access_key', extra.pop('extra__wasb__shared_access_key', None)) + shared_access_key = self._get_field(extra, 'shared_access_key') if shared_access_key: # using shared access key return BlobServiceClient(account_url=conn.host, credential=shared_access_key, **extra) - tenant = extra.pop('tenant_id', extra.pop('extra__wasb__tenant_id', None)) + tenant = self._get_field(extra, 'tenant_id') if tenant: # use Active Directory auth app_id = conn.login @@ -148,7 +187,7 @@ def get_conn(self) -> BlobServiceClient: token_credential = ClientSecretCredential(tenant, app_id, app_secret) return BlobServiceClient(account_url=conn.host, credential=token_credential, **extra) - sas_token = extra.pop('sas_token', extra.pop('extra__wasb__sas_token', None)) + sas_token = self._get_field(extra, 'sas_token') if sas_token: if sas_token.startswith('https'): return BlobServiceClient(account_url=sas_token, **extra) diff --git a/docs/apache-airflow-providers-microsoft-azure/connections/wasb.rst b/docs/apache-airflow-providers-microsoft-azure/connections/wasb.rst index 9dff97a00968a..823cc85c22dc1 100644 --- a/docs/apache-airflow-providers-microsoft-azure/connections/wasb.rst +++ b/docs/apache-airflow-providers-microsoft-azure/connections/wasb.rst @@ -34,13 +34,13 @@ There are four ways to connect to Azure Blob Storage using Airflow. i.e. add specific credentials (client_id, secret, tenant) and subscription id to the Airflow connection. 2. Use `Azure Shared Key Credential `_ - i.e. add shared key credentials to ``extra__wasb__shared_access_key`` the Airflow connection. + i.e. add shared key credentials to ``shared_access_key`` the Airflow connection. 3. Use a `SAS Token `_ - i.e. add a key config to ``extra__wasb__sas_token`` in the Airflow connection. + i.e. add a key config to ``sas_token`` in the Airflow connection. 4. Use a `Connection String `_ - i.e. add connection string to ``extra__wasb__connection_string`` in the Airflow connection. + i.e. add connection string to ``connection_string`` in the Airflow connection. Only one authorization method can be used at a time. If you need to manage multiple credentials or keys then you should configure multiple connections. @@ -67,10 +67,10 @@ Extra (optional) Specify the extra parameters (as json dictionary) that can be used in Azure connection. The following parameters are all optional: - * ``extra__wasb__tenant_id``: Specify the tenant to use. Needed for Active Directory (token) authentication. - * ``extra__wasb__shared_access_key``: Specify the shared access key. Needed for shared access key authentication. - * ``extra__wasb__connection_string``: Connection string for use with connection string authentication. - * ``extra__wasb__sas_token``: SAS Token for use with SAS Token authentication. + * ``tenant_id``: Specify the tenant to use. Needed for Active Directory (token) authentication. + * ``shared_access_key``: Specify the shared access key. Needed for shared access key authentication. + * ``connection_string``: Connection string for use with connection string authentication. + * ``sas_token``: SAS Token for use with SAS Token authentication. When specifying the connection in environment variable you should specify it using URI syntax. diff --git a/tests/providers/microsoft/azure/hooks/test_wasb.py b/tests/providers/microsoft/azure/hooks/test_wasb.py index f18c1da2f75d1..9d02b0dec1bb4 100644 --- a/tests/providers/microsoft/azure/hooks/test_wasb.py +++ b/tests/providers/microsoft/azure/hooks/test_wasb.py @@ -28,6 +28,7 @@ from airflow.models import Connection from airflow.providers.microsoft.azure.hooks.wasb import WasbHook from airflow.utils import db +from tests.test_utils.providers import get_provider_min_airflow_version, object_exists # connection_string has a format CONN_STRING = ( @@ -390,3 +391,34 @@ def test_connection_failure(self, mock_service): status, msg = hook.test_connection() assert status is False assert msg == "Authentication failed." + + def test__ensure_prefixes_removal(self): + """Ensure that _ensure_prefixes is removed from snowflake when airflow min version >= 2.5.0.""" + path = 'airflow.providers.microsoft.azure.hooks.wasb._ensure_prefixes' + if not object_exists(path): + raise Exception( + "You must remove this test. It only exists to " + "remind us to remove decorator `_ensure_prefixes`." + ) + + if get_provider_min_airflow_version('apache-airflow-providers-microsoft-azure') >= (2, 5): + raise Exception( + "You must now remove `_ensure_prefixes` from WasbHook. The functionality is now taken" + "care of by providers manager." + ) + + def test___ensure_prefixes(self): + """ + Check that ensure_prefixes decorator working properly + Note: remove this test when removing ensure_prefixes (after min airflow version >= 2.5.0 + """ + assert list(WasbHook.get_ui_field_behaviour()['placeholders'].keys()) == [ + 'extra', + 'login', + 'password', + 'host', + 'extra__wasb__connection_string', + 'extra__wasb__tenant_id', + 'extra__wasb__shared_access_key', + 'extra__wasb__sas_token', + ]