Skip to content

Commit

Permalink
Add test_connection method to AzureDataFactoryHook (#21924)
Browse files Browse the repository at this point in the history
  • Loading branch information
josh-fell committed Mar 2, 2022
1 parent a0e2eba commit f42559a
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 2 deletions.
22 changes: 21 additions & 1 deletion airflow/providers/microsoft/azure/hooks/data_factory.py
Expand Up @@ -30,7 +30,7 @@
import inspect
import time
from functools import wraps
from typing import Any, Callable, Dict, Optional, Set, Union
from typing import Any, Callable, Dict, Optional, Set, Tuple, Union

from azure.core.polling import LROPoller
from azure.identity import ClientSecretCredential, DefaultAzureCredential
Expand Down Expand Up @@ -891,3 +891,23 @@ def cancel_trigger(
:param config: Extra parameters for the ADF client.
"""
self.get_conn().trigger_runs.cancel(resource_group_name, factory_name, trigger_name, run_id, **config)

def test_connection(self) -> Tuple[bool, str]:
"""Test a configured Azure Data Factory connection."""
success = (True, "Successfully connected to Azure Data Factory.")

try:
# Attempt to list existing factories under the configured subscription and retrieve the first in
# the returned iterator. The Azure Data Factory API does allow for creation of a
# DataFactoryManagementClient with incorrect values but then will fail properly once items are
# retrieved using the client. We need to _actually_ try to retrieve an object to properly test the
# connection.
next(self.get_conn().factories.list())
return success
except StopIteration:
# If the iterator returned is empty it should still be considered a successful connection since
# it's possible to create a Data Factory via the ``AzureDataFactoryHook`` and none could
# legitimately exist yet.
return success
except Exception as e:
return False, str(e)
68 changes: 67 additions & 1 deletion tests/providers/microsoft/azure/hooks/test_azure_data_factory.py
Expand Up @@ -18,10 +18,11 @@

import json
from typing import Type
from unittest.mock import MagicMock, patch
from unittest.mock import MagicMock, PropertyMock, patch

import pytest
from azure.identity import ClientSecretCredential, DefaultAzureCredential
from azure.mgmt.datafactory.models import FactoryListResponse
from pytest import fixture

from airflow.exceptions import AirflowException
Expand Down Expand Up @@ -74,9 +75,37 @@ def setup_module():
}
),
)
connection_missing_subscription_id = Connection(
conn_id="azure_data_factory_missing_subscription_id",
conn_type="azure_data_factory",
login="clientId",
password="clientSecret",
extra=json.dumps(
{
"extra__azure_data_factory__tenantId": "tenantId",
"extra__azure_data_factory__resource_group_name": DEFAULT_RESOURCE_GROUP,
"extra__azure_data_factory__factory_name": DEFAULT_FACTORY,
}
),
)
connection_missing_tenant_id = Connection(
conn_id="azure_data_factory_missing_tenant_id",
conn_type="azure_data_factory",
login="clientId",
password="clientSecret",
extra=json.dumps(
{
"extra__azure_data_factory__subscriptionId": "subscriptionId",
"extra__azure_data_factory__resource_group_name": DEFAULT_RESOURCE_GROUP,
"extra__azure_data_factory__factory_name": DEFAULT_FACTORY,
}
),
)

db.merge_conn(connection_client_secret)
db.merge_conn(connection_default_credential)
db.merge_conn(connection_missing_subscription_id)
db.merge_conn(connection_missing_tenant_id)


@fixture
Expand Down Expand Up @@ -526,3 +555,40 @@ def test_cancel_trigger(hook: AzureDataFactoryHook, user_args, sdk_args):
hook.cancel_trigger(*user_args)

hook._conn.trigger_runs.cancel.assert_called_with(*sdk_args)


@pytest.mark.parametrize(
argnames="factory_list_result",
argvalues=[iter([FactoryListResponse]), iter([])],
ids=["factory_exists", "factory_does_not_exist"],
)
def test_connection_success(hook, factory_list_result):
hook.get_conn().factories.list.return_value = factory_list_result
status, msg = hook.test_connection()

assert status is True
assert msg == "Successfully connected to Azure Data Factory."


def test_connection_failure(hook):
hook.get_conn().factories.list = PropertyMock(side_effect=Exception("Authentication failed."))
status, msg = hook.test_connection()

assert status is False
assert msg == "Authentication failed."


def test_connection_failure_missing_subscription_id():
hook = AzureDataFactoryHook("azure_data_factory_missing_subscription_id")
status, msg = hook.test_connection()

assert status is False
assert msg == "A Subscription ID is required to connect to Azure Data Factory."


def test_connection_failure_missing_tenant_id():
hook = AzureDataFactoryHook("azure_data_factory_missing_tenant_id")
status, msg = hook.test_connection()

assert status is False
assert msg == "A Tenant ID is required when authenticating with Client ID and Secret."

0 comments on commit f42559a

Please sign in to comment.