Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 41 additions & 21 deletions airflow/providers/microsoft/azure/hooks/adx.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
"""
from __future__ import annotations

import warnings
from typing import Any

from azure.kusto.data.exceptions import KustoServiceError
Expand All @@ -33,6 +34,7 @@

from airflow.exceptions import AirflowException
from airflow.hooks.base import BaseHook
from airflow.providers.microsoft.azure.utils import _ensure_prefixes


class AzureDataExplorerHook(BaseHook):
Expand Down Expand Up @@ -82,21 +84,18 @@ def get_connection_form_widgets() -> dict[str, Any]:
from wtforms import PasswordField, StringField

return {
"extra__azure_data_explorer__tenant": StringField(
lazy_gettext("Tenant ID"), widget=BS3TextFieldWidget()
),
"extra__azure_data_explorer__auth_method": StringField(
lazy_gettext("Authentication Method"), widget=BS3TextFieldWidget()
),
"extra__azure_data_explorer__certificate": PasswordField(
"tenant": StringField(lazy_gettext("Tenant ID"), widget=BS3TextFieldWidget()),
"auth_method": StringField(lazy_gettext("Authentication Method"), widget=BS3TextFieldWidget()),
"certificate": PasswordField(
lazy_gettext("Application PEM Certificate"), widget=BS3PasswordFieldWidget()
),
"extra__azure_data_explorer__thumbprint": PasswordField(
"thumbprint": PasswordField(
lazy_gettext("Application Certificate Thumbprint"), widget=BS3PasswordFieldWidget()
),
}

@staticmethod
@_ensure_prefixes(conn_type="azure_data_explorer")
def get_ui_field_behaviour() -> dict[str, Any]:
"""Returns custom field behaviour"""
return {
Expand All @@ -108,43 +107,64 @@ def get_ui_field_behaviour() -> dict[str, Any]:
"placeholders": {
"login": "Varies with authentication method",
"password": "Varies with authentication method",
"extra__azure_data_explorer__auth_method": "AAD_APP/AAD_APP_CERT/AAD_CREDS/AAD_DEVICE",
"extra__azure_data_explorer__tenant": "Used with AAD_APP/AAD_APP_CERT/AAD_CREDS",
"extra__azure_data_explorer__certificate": "Used with AAD_APP_CERT",
"extra__azure_data_explorer__thumbprint": "Used with AAD_APP_CERT",
"auth_method": "AAD_APP/AAD_APP_CERT/AAD_CREDS/AAD_DEVICE",
"tenant": "Used with AAD_APP/AAD_APP_CERT/AAD_CREDS",
"certificate": "Used with AAD_APP_CERT",
"thumbprint": "Used with AAD_APP_CERT",
},
}

def __init__(self, azure_data_explorer_conn_id: str = default_conn_name) -> None:
super().__init__()
self.conn_id = azure_data_explorer_conn_id
self.connection = self.get_conn()
self.connection = self.get_conn() # todo: make this a property, or just delete

def get_conn(self) -> KustoClient:
"""Return a KustoClient object."""
conn = self.get_connection(self.conn_id)
extras = conn.extra_dejson
cluster = conn.host
if not cluster:
raise AirflowException("Host connection option is required")

def warn_if_collison(key, backcompat_key):
if backcompat_key in extras:
warnings.warn(
f"Conflicting params `{key}` and `{backcompat_key}` found in extras for conn "
f"{self.conn_id}. Using value for `{key}`. Please ensure this is the correct value "
f"and remove the backcompat key `{backcompat_key}`."
)

def get_required_param(name: str) -> str:
"""Extract required parameter value from connection, raise exception if not found"""
value = conn.extra_dejson.get(name)
"""
Extract required parameter value from connection, raise exception if not found.

Warns if both ``foo`` and ``extra__azure_data_explorer__foo`` found in conn extra.

Prefers unprefixed field.
"""
backcompat_prefix = "extra__azure_data_explorer__"
backcompat_key = f"{backcompat_prefix}{name}"
value = extras.get(name)
if value:
warn_if_collison(name, backcompat_key)
if not value:
value = extras.get(backcompat_key)
if not value:
raise AirflowException(f"Required connection parameter is missing: `{name}`")
return value

auth_method = get_required_param("extra__azure_data_explorer__auth_method")
auth_method = get_required_param("auth_method")

if auth_method == "AAD_APP":
tenant = get_required_param("extra__azure_data_explorer__tenant")
tenant = get_required_param("tenant")
kcsb = KustoConnectionStringBuilder.with_aad_application_key_authentication(
cluster, conn.login, conn.password, tenant
)
elif auth_method == "AAD_APP_CERT":
certificate = get_required_param("extra__azure_data_explorer__certificate")
thumbprint = get_required_param("extra__azure_data_explorer__thumbprint")
tenant = get_required_param("extra__azure_data_explorer__tenant")
certificate = get_required_param("certificate")
thumbprint = get_required_param("thumbprint")
tenant = get_required_param("tenant")
kcsb = KustoConnectionStringBuilder.with_aad_application_certificate_authentication(
cluster,
conn.login,
Expand All @@ -153,7 +173,7 @@ def get_required_param(name: str) -> str:
tenant,
)
elif auth_method == "AAD_CREDS":
tenant = get_required_param("extra__azure_data_explorer__tenant")
tenant = get_required_param("tenant")
kcsb = KustoConnectionStringBuilder.with_aad_user_password_authentication(
cluster, conn.login, conn.password, tenant
)
Expand Down
48 changes: 48 additions & 0 deletions airflow/providers/microsoft/azure/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

from functools import wraps


def _ensure_prefixes(conn_type):
"""
Remove when provider min airflow version >= 2.5.0 since this is handled by
provider manager from that version.
"""

def dec(func):
@wraps(func)
def inner():
field_behaviors = func()
conn_attrs = {"host", "schema", "login", "password", "port", "extra"}

def _ensure_prefix(field):
if field not in conn_attrs and not field.startswith("extra__"):
return f"extra__{conn_type}__{field}"
else:
return field

if "placeholders" in field_behaviors:
placeholders = field_behaviors["placeholders"]
field_behaviors["placeholders"] = {_ensure_prefix(k): v for k, v in placeholders.items()}
return field_behaviors

return inner

return dec
60 changes: 56 additions & 4 deletions tests/providers/microsoft/azure/hooks/test_adx.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,24 +18,26 @@
from __future__ import annotations

import json
import unittest
import os
from unittest import mock
from unittest.mock import patch

import pytest
from azure.kusto.data.request import ClientRequestProperties, KustoClient, KustoConnectionStringBuilder
from pytest import param

from airflow.exceptions import AirflowException
from airflow.models import Connection
from airflow.providers.microsoft.azure.hooks.adx import AzureDataExplorerHook
from airflow.utils import db
from airflow.utils.session import create_session
from tests.test_utils.providers import get_provider_min_airflow_version

ADX_TEST_CONN_ID = "adx_test_connection_id"


class TestAzureDataExplorerHook(unittest.TestCase):
def tearDown(self):
super().tearDown()
class TestAzureDataExplorerHook:
def teardown(self):
with create_session() as session:
session.query(Connection).filter(Connection.conn_id == ADX_TEST_CONN_ID).delete()

Expand Down Expand Up @@ -191,3 +193,53 @@ def test_run_query(self, mock_execute):
properties = ClientRequestProperties()
properties.set_option("option1", "option_value")
assert mock_execute.called_with("Database", "Logs | schema", properties=properties)

def test_get_ui_field_behaviour_placeholders(self):
"""
Check that ensure_prefixes decorator working properly

Note: remove this test and the _ensure_prefixes decorator after min airflow version >= 2.5.0
"""
assert list(AzureDataExplorerHook.get_ui_field_behaviour()["placeholders"].keys()) == [
"login",
"password",
"extra__azure_data_explorer__auth_method",
"extra__azure_data_explorer__tenant",
"extra__azure_data_explorer__certificate",
"extra__azure_data_explorer__thumbprint",
]
if get_provider_min_airflow_version("apache-airflow-providers-microsoft-azure") >= (2, 5):
raise Exception(
"You must now remove `_ensure_prefixes` from azure utils."
" The functionality is now taken care of by providers manager."
)

@pytest.mark.parametrize(
"uri",
[
param(
"a://usr:pw@host?extra__azure_data_explorer__tenant=my-tenant"
"&extra__azure_data_explorer__auth_method=AAD_APP",
id="prefix",
),
param("a://usr:pw@host?tenant=my-tenant&auth_method=AAD_APP", id="no-prefix"),
],
)
@patch("airflow.providers.microsoft.azure.hooks.adx.KustoConnectionStringBuilder")
def test_backcompat_prefix_works(self, mock_client, uri):
mock_with = mock_client.with_aad_application_key_authentication
with patch.dict(os.environ, AIRFLOW_CONN_MY_CONN=uri):
AzureDataExplorerHook(azure_data_explorer_conn_id="my_conn") # get_conn is called in init
mock_with.assert_called_once_with("host", "usr", "pw", "my-tenant")

@patch("airflow.providers.microsoft.azure.hooks.adx.KustoConnectionStringBuilder")
def test_backcompat_prefix_both_causes_warning(self, mock_client):
mock_with = mock_client.with_aad_application_key_authentication
with patch.dict(
in_dict=os.environ,
AIRFLOW_CONN_MY_CONN="a://usr:pw@host?tenant=my-tenant&auth_method=AAD_APP"
"&extra__azure_data_explorer__auth_method=AAD_APP",
):
with pytest.warns(Warning, match="Using value for `auth_method`"):
AzureDataExplorerHook(azure_data_explorer_conn_id="my_conn") # get_conn is called in init
mock_with.assert_called_once_with("host", "usr", "pw", "my-tenant")