From 06d35b74fd3c46cf73f8a976571180d9503e8ca5 Mon Sep 17 00:00:00 2001 From: Daniel Standish <15932138+dstandish@users.noreply.github.com> Date: Sun, 23 Oct 2022 11:38:43 -0700 Subject: [PATCH 1/2] Allow and prefer non-prefixed extra fields for AzureDataExplorerHook From Airflow version 2.3, extra-field prefixes are no longer required, so we accept non-prefixed field names here (prefixed names are still honored for backward compatibility). --- .../providers/microsoft/azure/hooks/adx.py | 62 ++++++++++++------- .../microsoft/azure/hooks/test_adx.py | 60 ++++++++++++++++-- 2 files changed, 97 insertions(+), 25 deletions(-) diff --git a/airflow/providers/microsoft/azure/hooks/adx.py b/airflow/providers/microsoft/azure/hooks/adx.py index 129f8464ac4e4..634c18c172faa 100644 --- a/airflow/providers/microsoft/azure/hooks/adx.py +++ b/airflow/providers/microsoft/azure/hooks/adx.py @@ -25,6 +25,7 @@ """ from __future__ import annotations +import warnings from typing import Any from azure.kusto.data.exceptions import KustoServiceError @@ -33,6 +34,7 @@ from airflow.exceptions import AirflowException from airflow.hooks.base import BaseHook +from airflow.providers.microsoft.azure.utils import _ensure_prefixes class AzureDataExplorerHook(BaseHook): @@ -82,21 +84,18 @@ def get_connection_form_widgets() -> dict[str, Any]: from wtforms import PasswordField, StringField return { - "extra__azure_data_explorer__tenant": StringField( - lazy_gettext("Tenant ID"), widget=BS3TextFieldWidget() - ), - "extra__azure_data_explorer__auth_method": StringField( - lazy_gettext("Authentication Method"), widget=BS3TextFieldWidget() - ), - "extra__azure_data_explorer__certificate": PasswordField( + "tenant": StringField(lazy_gettext("Tenant ID"), widget=BS3TextFieldWidget()), + "auth_method": StringField(lazy_gettext("Authentication Method"), widget=BS3TextFieldWidget()), + "certificate": PasswordField( lazy_gettext("Application PEM Certificate"), widget=BS3PasswordFieldWidget() ), - "extra__azure_data_explorer__thumbprint": PasswordField( + "thumbprint": PasswordField( 
lazy_gettext("Application Certificate Thumbprint"), widget=BS3PasswordFieldWidget() ), } @staticmethod + @_ensure_prefixes(conn_type="azure_data_explorer") def get_ui_field_behaviour() -> dict[str, Any]: """Returns custom field behaviour""" return { @@ -108,43 +107,64 @@ def get_ui_field_behaviour() -> dict[str, Any]: "placeholders": { "login": "Varies with authentication method", "password": "Varies with authentication method", - "extra__azure_data_explorer__auth_method": "AAD_APP/AAD_APP_CERT/AAD_CREDS/AAD_DEVICE", - "extra__azure_data_explorer__tenant": "Used with AAD_APP/AAD_APP_CERT/AAD_CREDS", - "extra__azure_data_explorer__certificate": "Used with AAD_APP_CERT", - "extra__azure_data_explorer__thumbprint": "Used with AAD_APP_CERT", + "auth_method": "AAD_APP/AAD_APP_CERT/AAD_CREDS/AAD_DEVICE", + "tenant": "Used with AAD_APP/AAD_APP_CERT/AAD_CREDS", + "certificate": "Used with AAD_APP_CERT", + "thumbprint": "Used with AAD_APP_CERT", }, } def __init__(self, azure_data_explorer_conn_id: str = default_conn_name) -> None: super().__init__() self.conn_id = azure_data_explorer_conn_id - self.connection = self.get_conn() + self.connection = self.get_conn() # todo: make this a property, or just delete def get_conn(self) -> KustoClient: """Return a KustoClient object.""" conn = self.get_connection(self.conn_id) + extras = conn.extra_dejson cluster = conn.host if not cluster: raise AirflowException("Host connection option is required") + def warn_if_collison(key, backcompat_key): + if backcompat_key in extras: + warnings.warn( + f"Conflicting params `{key}` and `{backcompat_key}` found in extras for conn " + f"{self.conn_id}. Using value for `{key}`. Please ensure this is the correct value " + f"and remove the backcompat key `{backcompat_key}`." 
+ ) + def get_required_param(name: str) -> str: - """Extract required parameter value from connection, raise exception if not found""" - value = conn.extra_dejson.get(name) + """ + Extract required parameter value from connection, raise exception if not found. + + Warns if both ``foo`` and ``extra__azure_data_explorer__foo`` found in conn extra. + + Prefers unprefixed field. + """ + backcompat_prefix = "extra__azure_data_explorer__" + backcompat_key = f"{backcompat_prefix}{name}" + value = extras.get(name) + if value: + warn_if_collison(name, backcompat_key) + if not value: + value = extras.get(backcompat_key) if not value: raise AirflowException(f"Required connection parameter is missing: `{name}`") return value - auth_method = get_required_param("extra__azure_data_explorer__auth_method") + auth_method = get_required_param("auth_method") if auth_method == "AAD_APP": - tenant = get_required_param("extra__azure_data_explorer__tenant") + tenant = get_required_param("tenant") kcsb = KustoConnectionStringBuilder.with_aad_application_key_authentication( cluster, conn.login, conn.password, tenant ) elif auth_method == "AAD_APP_CERT": - certificate = get_required_param("extra__azure_data_explorer__certificate") - thumbprint = get_required_param("extra__azure_data_explorer__thumbprint") - tenant = get_required_param("extra__azure_data_explorer__tenant") + certificate = get_required_param("certificate") + thumbprint = get_required_param("thumbprint") + tenant = get_required_param("tenant") kcsb = KustoConnectionStringBuilder.with_aad_application_certificate_authentication( cluster, conn.login, @@ -153,7 +173,7 @@ def get_required_param(name: str) -> str: tenant, ) elif auth_method == "AAD_CREDS": - tenant = get_required_param("extra__azure_data_explorer__tenant") + tenant = get_required_param("tenant") kcsb = KustoConnectionStringBuilder.with_aad_user_password_authentication( cluster, conn.login, conn.password, tenant ) diff --git 
a/tests/providers/microsoft/azure/hooks/test_adx.py b/tests/providers/microsoft/azure/hooks/test_adx.py index 4e0c7d37e2f12..e6e9bc3189778 100644 --- a/tests/providers/microsoft/azure/hooks/test_adx.py +++ b/tests/providers/microsoft/azure/hooks/test_adx.py @@ -18,24 +18,26 @@ from __future__ import annotations import json -import unittest +import os from unittest import mock +from unittest.mock import patch import pytest from azure.kusto.data.request import ClientRequestProperties, KustoClient, KustoConnectionStringBuilder +from pytest import param from airflow.exceptions import AirflowException from airflow.models import Connection from airflow.providers.microsoft.azure.hooks.adx import AzureDataExplorerHook from airflow.utils import db from airflow.utils.session import create_session +from tests.test_utils.providers import get_provider_min_airflow_version ADX_TEST_CONN_ID = "adx_test_connection_id" -class TestAzureDataExplorerHook(unittest.TestCase): - def tearDown(self): - super().tearDown() +class TestAzureDataExplorerHook: + def teardown(self): with create_session() as session: session.query(Connection).filter(Connection.conn_id == ADX_TEST_CONN_ID).delete() @@ -191,3 +193,53 @@ def test_run_query(self, mock_execute): properties = ClientRequestProperties() properties.set_option("option1", "option_value") assert mock_execute.called_with("Database", "Logs | schema", properties=properties) + + def test_get_ui_field_behaviour_placeholders(self): + """ + Check that ensure_prefixes decorator working properly + + Note: remove this test and the _ensure_prefixes decorator after min airflow version >= 2.5.0 + """ + assert list(AzureDataExplorerHook.get_ui_field_behaviour()["placeholders"].keys()) == [ + "login", + "password", + "extra__azure_data_explorer__auth_method", + "extra__azure_data_explorer__tenant", + "extra__azure_data_explorer__certificate", + "extra__azure_data_explorer__thumbprint", + ] + if 
get_provider_min_airflow_version("apache-airflow-providers-microsoft-azure") >= (2, 5): + raise Exception( + "You must now remove `_ensure_prefixes` from azure utils." + " The functionality is now taken care of by providers manager." + ) + + @pytest.mark.parametrize( + "uri", + [ + param( + "a://usr:pw@host?extra__azure_data_explorer__tenant=my-tenant" + "&extra__azure_data_explorer__auth_method=AAD_APP", + id="prefix", + ), + param("a://usr:pw@host?tenant=my-tenant&auth_method=AAD_APP", id="no-prefix"), + ], + ) + @patch("airflow.providers.microsoft.azure.hooks.adx.KustoConnectionStringBuilder") + def test_backcompat_prefix_works(self, mock_client, uri): + mock_with = mock_client.with_aad_application_key_authentication + with patch.dict(os.environ, AIRFLOW_CONN_MY_CONN=uri): + AzureDataExplorerHook(azure_data_explorer_conn_id="my_conn") # get_conn is called in init + mock_with.assert_called_once_with("host", "usr", "pw", "my-tenant") + + @patch("airflow.providers.microsoft.azure.hooks.adx.KustoConnectionStringBuilder") + def test_backcompat_prefix_both_causes_warning(self, mock_client): + mock_with = mock_client.with_aad_application_key_authentication + with patch.dict( + in_dict=os.environ, + AIRFLOW_CONN_MY_CONN="a://usr:pw@host?tenant=my-tenant&auth_method=AAD_APP" + "&extra__azure_data_explorer__auth_method=AAD_APP", + ): + with pytest.warns(Warning, match="Using value for `auth_method`"): + AzureDataExplorerHook(azure_data_explorer_conn_id="my_conn") # get_conn is called in init + mock_with.assert_called_once_with("host", "usr", "pw", "my-tenant") From c95b2535b41ea3f8f10e10b81df72c09cacbd68a Mon Sep 17 00:00:00 2001 From: Daniel Standish <15932138+dstandish@users.noreply.github.com> Date: Sun, 23 Oct 2022 21:20:24 -0700 Subject: [PATCH 2/2] add missing util --- airflow/providers/microsoft/azure/utils.py | 48 ++++++++++++++++++++++ 1 file changed, 48 insertions(+) create mode 100644 airflow/providers/microsoft/azure/utils.py diff --git 
a/airflow/providers/microsoft/azure/utils.py b/airflow/providers/microsoft/azure/utils.py new file mode 100644 index 0000000000000..e3d24da897641 --- /dev/null +++ b/airflow/providers/microsoft/azure/utils.py @@ -0,0 +1,48 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from functools import wraps + + +def _ensure_prefixes(conn_type): + """ + Remove when provider min airflow version >= 2.5.0 since this is handled by + provider manager from that version. + """ + + def dec(func): + @wraps(func) + def inner(): + field_behaviors = func() + conn_attrs = {"host", "schema", "login", "password", "port", "extra"} + + def _ensure_prefix(field): + if field not in conn_attrs and not field.startswith("extra__"): + return f"extra__{conn_type}__{field}" + else: + return field + + if "placeholders" in field_behaviors: + placeholders = field_behaviors["placeholders"] + field_behaviors["placeholders"] = {_ensure_prefix(k): v for k, v in placeholders.items()} + return field_behaviors + + return inner + + return dec