Skip to content

Commit

Permalink
Update WasbHook to reflect preference for unprefixed extra (#27024)
Browse files Browse the repository at this point in the history
Since 2.3 we don't need the extra prefix to edit custom fields with UI form.  From 2.5 we won't need to use the prefix in the ui form behaviors method either.
  • Loading branch information
dstandish committed Oct 22, 2022
1 parent 837e463 commit d51de50
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 21 deletions.
67 changes: 53 additions & 14 deletions airflow/providers/microsoft/azure/hooks/wasb.py
Expand Up @@ -27,6 +27,7 @@

import logging
import os
from functools import wraps
from typing import Any

from azure.core.exceptions import HttpResponseError, ResourceExistsError, ResourceNotFoundError
Expand All @@ -37,6 +38,34 @@
from airflow.hooks.base import BaseHook


def _ensure_prefixes(conn_type):
"""
Remove when provider min airflow version >= 2.5.0 since this is handled by
provider manager from that version.
"""

def dec(func):
@wraps(func)
def inner():
field_behaviors = func()
conn_attrs = {'host', 'schema', 'login', 'password', 'port', 'extra'}

def _ensure_prefix(field):
if field not in conn_attrs and not field.startswith('extra__'):
return f"extra__{conn_type}__{field}"
else:
return field

if 'placeholders' in field_behaviors:
placeholders = field_behaviors['placeholders']
field_behaviors['placeholders'] = {_ensure_prefix(k): v for k, v in placeholders.items()}
return field_behaviors

return inner

return dec


class WasbHook(BaseHook):
"""
Interacts with Azure Blob Storage through the ``wasb://`` protocol.
Expand Down Expand Up @@ -67,21 +96,20 @@ def get_connection_form_widgets() -> dict[str, Any]:
from wtforms import PasswordField, StringField

return {
"extra__wasb__connection_string": PasswordField(
"connection_string": PasswordField(
lazy_gettext('Blob Storage Connection String (optional)'), widget=BS3PasswordFieldWidget()
),
"extra__wasb__shared_access_key": PasswordField(
"shared_access_key": PasswordField(
lazy_gettext('Blob Storage Shared Access Key (optional)'), widget=BS3PasswordFieldWidget()
),
"extra__wasb__tenant_id": StringField(
"tenant_id": StringField(
lazy_gettext('Tenant Id (Active Directory Auth)'), widget=BS3TextFieldWidget()
),
"extra__wasb__sas_token": PasswordField(
lazy_gettext('SAS Token (optional)'), widget=BS3PasswordFieldWidget()
),
"sas_token": PasswordField(lazy_gettext('SAS Token (optional)'), widget=BS3PasswordFieldWidget()),
}

@staticmethod
@_ensure_prefixes(conn_type='wasb')
def get_ui_field_behaviour() -> dict[str, Any]:
"""Returns custom field behaviour"""
return {
Expand All @@ -96,10 +124,10 @@ def get_ui_field_behaviour() -> dict[str, Any]:
'login': 'account name',
'password': 'secret',
'host': 'account url',
'extra__wasb__connection_string': 'connection string auth',
'extra__wasb__tenant_id': 'tenant',
'extra__wasb__shared_access_key': 'shared access key',
'extra__wasb__sas_token': 'account url or token',
'connection_string': 'connection string auth',
'tenant_id': 'tenant',
'shared_access_key': 'shared access key',
'sas_token': 'account url or token',
},
}

Expand All @@ -119,6 +147,17 @@ def __init__(
except ValueError:
logger.setLevel(logging.WARNING)

def _get_field(self, extra_dict, field_name):
prefix = 'extra__wasb__'
if field_name.startswith('extra_'):
raise ValueError(
f"Got prefixed name {field_name}; please remove the '{prefix}' prefix "
f"when using this method."
)
if field_name in extra_dict:
return extra_dict[field_name] or None
return extra_dict.get(f"{prefix}{field_name}") or None

def get_conn(self) -> BlobServiceClient:
"""Return the BlobServiceClient object."""
conn = self.get_connection(self.conn_id)
Expand All @@ -130,25 +169,25 @@ def get_conn(self) -> BlobServiceClient:
# https://docs.microsoft.com/en-us/azure/storage/blobs/storage-manage-access-to-resources
return BlobServiceClient(account_url=conn.host, **extra)

connection_string = extra.pop('connection_string', extra.pop('extra__wasb__connection_string', None))
connection_string = self._get_field(extra, 'connection_string')
if connection_string:
# connection_string auth takes priority
return BlobServiceClient.from_connection_string(connection_string, **extra)

shared_access_key = extra.pop('shared_access_key', extra.pop('extra__wasb__shared_access_key', None))
shared_access_key = self._get_field(extra, 'shared_access_key')
if shared_access_key:
# using shared access key
return BlobServiceClient(account_url=conn.host, credential=shared_access_key, **extra)

tenant = extra.pop('tenant_id', extra.pop('extra__wasb__tenant_id', None))
tenant = self._get_field(extra, 'tenant_id')
if tenant:
# use Active Directory auth
app_id = conn.login
app_secret = conn.password
token_credential = ClientSecretCredential(tenant, app_id, app_secret)
return BlobServiceClient(account_url=conn.host, credential=token_credential, **extra)

sas_token = extra.pop('sas_token', extra.pop('extra__wasb__sas_token', None))
sas_token = self._get_field(extra, 'sas_token')
if sas_token:
if sas_token.startswith('https'):
return BlobServiceClient(account_url=sas_token, **extra)
Expand Down
Expand Up @@ -34,13 +34,13 @@ There are four ways to connect to Azure Blob Storage using Airflow.
i.e. add specific credentials (client_id, secret, tenant) and subscription id to the Airflow connection.
2. Use `Azure Shared Key Credential
<https://docs.microsoft.com/en-us/rest/api/storageservices/authorize-with-shared-key>`_
i.e. add shared key credentials to ``extra__wasb__shared_access_key`` the Airflow connection.
i.e. add shared key credentials to ``shared_access_key`` the Airflow connection.
3. Use a `SAS Token
<https://docs.microsoft.com/en-us/rest/api/storageservices/create-account-sas>`_
i.e. add a key config to ``extra__wasb__sas_token`` in the Airflow connection.
i.e. add a key config to ``sas_token`` in the Airflow connection.
4. Use a `Connection String
<https://docs.microsoft.com/en-us/azure/data-explorer/kusto/api/connection-strings/storage>`_
i.e. add connection string to ``extra__wasb__connection_string`` in the Airflow connection.
i.e. add connection string to ``connection_string`` in the Airflow connection.

Only one authorization method can be used at a time. If you need to manage multiple credentials or keys then you should
configure multiple connections.
Expand All @@ -67,10 +67,10 @@ Extra (optional)
Specify the extra parameters (as json dictionary) that can be used in Azure connection.
The following parameters are all optional:

* ``extra__wasb__tenant_id``: Specify the tenant to use. Needed for Active Directory (token) authentication.
* ``extra__wasb__shared_access_key``: Specify the shared access key. Needed for shared access key authentication.
* ``extra__wasb__connection_string``: Connection string for use with connection string authentication.
* ``extra__wasb__sas_token``: SAS Token for use with SAS Token authentication.
* ``tenant_id``: Specify the tenant to use. Needed for Active Directory (token) authentication.
* ``shared_access_key``: Specify the shared access key. Needed for shared access key authentication.
* ``connection_string``: Connection string for use with connection string authentication.
* ``sas_token``: SAS Token for use with SAS Token authentication.

When specifying the connection in environment variable you should specify
it using URI syntax.
Expand Down
32 changes: 32 additions & 0 deletions tests/providers/microsoft/azure/hooks/test_wasb.py
Expand Up @@ -28,6 +28,7 @@
from airflow.models import Connection
from airflow.providers.microsoft.azure.hooks.wasb import WasbHook
from airflow.utils import db
from tests.test_utils.providers import get_provider_min_airflow_version, object_exists

# connection_string has a format
CONN_STRING = (
Expand Down Expand Up @@ -390,3 +391,34 @@ def test_connection_failure(self, mock_service):
status, msg = hook.test_connection()
assert status is False
assert msg == "Authentication failed."

def test__ensure_prefixes_removal(self):
"""Ensure that _ensure_prefixes is removed from snowflake when airflow min version >= 2.5.0."""
path = 'airflow.providers.microsoft.azure.hooks.wasb._ensure_prefixes'
if not object_exists(path):
raise Exception(
"You must remove this test. It only exists to "
"remind us to remove decorator `_ensure_prefixes`."
)

if get_provider_min_airflow_version('apache-airflow-providers-microsoft-azure') >= (2, 5):
raise Exception(
"You must now remove `_ensure_prefixes` from WasbHook. The functionality is now taken"
"care of by providers manager."
)

def test___ensure_prefixes(self):
"""
Check that ensure_prefixes decorator working properly
Note: remove this test when removing ensure_prefixes (after min airflow version >= 2.5.0
"""
assert list(WasbHook.get_ui_field_behaviour()['placeholders'].keys()) == [
'extra',
'login',
'password',
'host',
'extra__wasb__connection_string',
'extra__wasb__tenant_id',
'extra__wasb__shared_access_key',
'extra__wasb__sas_token',
]

0 comments on commit d51de50

Please sign in to comment.