Skip to content

Commit

Permalink
make DefaultAzureCredential configurable in AzureKeyVaultBackend (#3…
Browse files Browse the repository at this point in the history
…5052)

* feat(azure): make DefaultAzureCredential configurable in AzureKeyVaultBackend

* test(providers/azure): extract common module string as a variable

* test(providers/azure): add test case test_client_authenticate_with_default_azure_credential_and_customized_configuration

* docs(providers/microsoft): update document for azure secret backend kwargs
  • Loading branch information
Lee-W committed Oct 30, 2023
1 parent 2b011b2 commit ebcb162
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 20 deletions.
28 changes: 25 additions & 3 deletions airflow/providers/microsoft/azure/secrets/key_vault.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
This module contains Azure Key Vault Backend.
.. spelling:word-list::
Entra
"""

from __future__ import annotations

import logging
Expand Down Expand Up @@ -76,8 +84,11 @@ class AzureKeyVaultBackend(BaseSecretsBackend, LoggingMixin):
If not given, it falls back to ``DefaultAzureCredential``
:param client_id: The client id of an Azure Key Vault to use.
If not given, it falls back to ``DefaultAzureCredential``
:param client_secret: The client secret of an Azure Key Vault to use.
If not given, it falls back to ``DefaultAzureCredential``
:param managed_identity_client_id: The client ID of a user-assigned managed identity.
If provided with `workload_identity_tenant_id`, they'll pass to ``DefaultAzureCredential``.
:param workload_identity_tenant_id: ID of the application's Microsoft Entra tenant.
Also called its "directory" ID.
If provided with `managed_identity_client_id`, they'll pass to ``DefaultAzureCredential``.
"""

def __init__(
Expand All @@ -91,6 +102,8 @@ def __init__(
tenant_id: str = "",
client_id: str = "",
client_secret: str = "",
managed_identity_client_id: str = "",
workload_identity_tenant_id: str = "",
**kwargs,
) -> None:
super().__init__()
Expand Down Expand Up @@ -118,6 +131,8 @@ def __init__(
self.tenant_id = tenant_id
self.client_id = client_id
self.client_secret = client_secret
self.managed_identity_client_id = managed_identity_client_id
self.workload_identity_tenant_id = workload_identity_tenant_id
self.kwargs = kwargs

@cached_property
Expand All @@ -127,7 +142,14 @@ def client(self) -> SecretClient:
if all([self.tenant_id, self.client_id, self.client_secret]):
credential = ClientSecretCredential(self.tenant_id, self.client_id, self.client_secret)
else:
credential = DefaultAzureCredential()
if self.managed_identity_client_id and self.workload_identity_tenant_id:
credential = DefaultAzureCredential(
managed_identity_client_id=self.managed_identity_client_id,
workload_identity_tenant_id=self.workload_identity_tenant_id,
additionally_allowed_tenants=[self.workload_identity_tenant_id],
)
else:
credential = DefaultAzureCredential()
client = SecretClient(vault_url=self.vault_url, credential=credential, **self.kwargs)
return client

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,16 @@ Storing and Retrieving Variables
If you have set ``variables_prefix`` as ``airflow-variables``, then for an Variable key of ``hello``,
you would want to store your Variable at ``airflow-variables-hello``.


Authentication
""""""""""""""
There are 3 ways to authenticate Azure Key Vault backend.

1. Set ``tenant_id``, ``client_id``, ``client_secret`` (using `ClientSecretCredential <https://learn.microsoft.com/en-us/python/api/azure-identity/azure.identity.clientsecretcredential?view=azure-python>`_)
2. Set ``managed_identity_client_id``, ``workload_identity_tenant_id`` (using `DefaultAzureCredential <https://learn.microsoft.com/en-us/python/api/azure-identity/azure.identity.defaultazurecredential?view=azure-python>`_ with these arguments)
3. Not providing extra connection configuration for falling back to `DefaultAzureCredential <https://learn.microsoft.com/en-us/python/api/azure-identity/azure.identity.defaultazurecredential?view=azure-python>`_


Reference
"""""""""

Expand Down
61 changes: 44 additions & 17 deletions tests/providers/microsoft/azure/secrets/test_azure_key_vault.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,18 @@
from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.providers.microsoft.azure.secrets.key_vault import AzureKeyVaultBackend

KEY_VAULT_MODULE = "airflow.providers.microsoft.azure.secrets.key_vault"


class TestAzureKeyVaultBackend:
@mock.patch("airflow.providers.microsoft.azure.secrets.key_vault.AzureKeyVaultBackend.get_conn_value")
@mock.patch(f"{KEY_VAULT_MODULE}.AzureKeyVaultBackend.get_conn_value")
def test_get_connection(self, mock_get_value):
mock_get_value.return_value = "scheme://user:pass@host:100"
conn = AzureKeyVaultBackend().get_connection("fake_conn")
assert conn.host == "host"

@mock.patch("airflow.providers.microsoft.azure.secrets.key_vault.DefaultAzureCredential")
@mock.patch("airflow.providers.microsoft.azure.secrets.key_vault.SecretClient")
@mock.patch(f"{KEY_VAULT_MODULE}.DefaultAzureCredential")
@mock.patch(f"{KEY_VAULT_MODULE}.SecretClient")
def test_get_conn_uri(self, mock_secret_client, mock_azure_cred):
mock_cred = mock.Mock()
mock_sec_client = mock.Mock()
Expand All @@ -53,7 +55,7 @@ def test_get_conn_uri(self, mock_secret_client, mock_azure_cred):
)
assert returned_uri == "postgresql://airflow:airflow@host:5432/airflow"

@mock.patch("airflow.providers.microsoft.azure.secrets.key_vault.AzureKeyVaultBackend.client")
@mock.patch(f"{KEY_VAULT_MODULE}.AzureKeyVaultBackend.client")
def test_get_conn_uri_non_existent_key(self, mock_client):
"""
Test that if the key with connection ID is not present,
Expand All @@ -67,15 +69,15 @@ def test_get_conn_uri_non_existent_key(self, mock_client):
assert backend.get_conn_uri(conn_id=conn_id) is None
assert backend.get_connection(conn_id=conn_id) is None

@mock.patch("airflow.providers.microsoft.azure.secrets.key_vault.AzureKeyVaultBackend.client")
@mock.patch(f"{KEY_VAULT_MODULE}.AzureKeyVaultBackend.client")
def test_get_variable(self, mock_client):
mock_client.get_secret.return_value = mock.Mock(value="world")
backend = AzureKeyVaultBackend()
returned_uri = backend.get_variable("hello")
mock_client.get_secret.assert_called_with(name="airflow-variables-hello")
assert "world" == returned_uri

@mock.patch("airflow.providers.microsoft.azure.secrets.key_vault.AzureKeyVaultBackend.client")
@mock.patch(f"{KEY_VAULT_MODULE}.AzureKeyVaultBackend.client")
def test_get_variable_non_existent_key(self, mock_client):
"""
Test that if Variable key is not present,
Expand All @@ -85,7 +87,7 @@ def test_get_variable_non_existent_key(self, mock_client):
backend = AzureKeyVaultBackend()
assert backend.get_variable("test_mysql") is None

@mock.patch("airflow.providers.microsoft.azure.secrets.key_vault.AzureKeyVaultBackend.client")
@mock.patch(f"{KEY_VAULT_MODULE}.AzureKeyVaultBackend.client")
def test_get_secret_value_not_found(self, mock_client):
"""
Test that if a non-existent secret returns None
Expand All @@ -96,7 +98,7 @@ def test_get_secret_value_not_found(self, mock_client):
backend._get_secret(path_prefix=backend.connections_prefix, secret_id="test_non_existent") is None
)

@mock.patch("airflow.providers.microsoft.azure.secrets.key_vault.AzureKeyVaultBackend.client")
@mock.patch(f"{KEY_VAULT_MODULE}.AzureKeyVaultBackend.client")
def test_get_secret_value(self, mock_client):
"""
Test that get_secret returns the secret value
Expand All @@ -107,7 +109,7 @@ def test_get_secret_value(self, mock_client):
mock_client.get_secret.assert_called_with(name="af-secrets-test-mysql-password")
assert secret_val == "super-secret"

@mock.patch("airflow.providers.microsoft.azure.secrets.key_vault.AzureKeyVaultBackend._get_secret")
@mock.patch(f"{KEY_VAULT_MODULE}.AzureKeyVaultBackend._get_secret")
def test_connection_prefix_none_value(self, mock_get_secret):
"""
Test that if Connections prefix is None,
Expand All @@ -125,7 +127,7 @@ def test_connection_prefix_none_value(self, mock_get_secret):
assert backend.get_conn_uri("test_mysql") is None
mock_get_secret.assert_not_called()

@mock.patch("airflow.providers.microsoft.azure.secrets.key_vault.AzureKeyVaultBackend._get_secret")
@mock.patch(f"{KEY_VAULT_MODULE}.AzureKeyVaultBackend._get_secret")
def test_variable_prefix_none_value(self, mock_get_secret):
"""
Test that if Variables prefix is None,
Expand All @@ -138,7 +140,7 @@ def test_variable_prefix_none_value(self, mock_get_secret):
assert backend.get_variable("hello") is None
mock_get_secret.assert_not_called()

@mock.patch("airflow.providers.microsoft.azure.secrets.key_vault.AzureKeyVaultBackend._get_secret")
@mock.patch(f"{KEY_VAULT_MODULE}.AzureKeyVaultBackend._get_secret")
def test_config_prefix_none_value(self, mock_get_secret):
"""
Test that if Config prefix is None,
Expand All @@ -151,20 +153,45 @@ def test_config_prefix_none_value(self, mock_get_secret):
assert backend.get_config("test_mysql") is None
mock_get_secret.assert_not_called()

@mock.patch("airflow.providers.microsoft.azure.secrets.key_vault.DefaultAzureCredential")
@mock.patch("airflow.providers.microsoft.azure.secrets.key_vault.ClientSecretCredential")
@mock.patch("airflow.providers.microsoft.azure.secrets.key_vault.SecretClient")
@mock.patch(f"{KEY_VAULT_MODULE}.DefaultAzureCredential")
@mock.patch(f"{KEY_VAULT_MODULE}.ClientSecretCredential")
@mock.patch(f"{KEY_VAULT_MODULE}.SecretClient")
def test_client_authenticate_with_default_azure_credential(
self, mock_client, mock_client_secret_credential, mock_defaul_azure_credential
):
"""
Test that if AzureKeyValueBackend is authenticated with DefaultAzureCredential
tenant_id, client_id and client_secret are not provided
"""
backend = AzureKeyVaultBackend(vault_url="https://example-akv-resource-name.vault.azure.net/")
backend.client
assert not mock_client_secret_credential.called
mock_defaul_azure_credential.assert_called_once()

@mock.patch("airflow.providers.microsoft.azure.secrets.key_vault.DefaultAzureCredential")
@mock.patch("airflow.providers.microsoft.azure.secrets.key_vault.ClientSecretCredential")
@mock.patch("airflow.providers.microsoft.azure.secrets.key_vault.SecretClient")
@mock.patch(f"{KEY_VAULT_MODULE}.DefaultAzureCredential")
@mock.patch(f"{KEY_VAULT_MODULE}.ClientSecretCredential")
@mock.patch(f"{KEY_VAULT_MODULE}.SecretClient")
def test_client_authenticate_with_default_azure_credential_and_customized_configuration(
self, mock_client, mock_client_secret_credential, mock_defaul_azure_credential
):
backend = AzureKeyVaultBackend(
vault_url="https://example-akv-resource-name.vault.azure.net/",
managed_identity_client_id="managed_identity_client_id",
workload_identity_tenant_id="workload_identity_tenant_id",
additionally_allowed_tenants=["workload_identity_tenant_id"],
)
backend.client
assert not mock_client_secret_credential.called
mock_defaul_azure_credential.assert_called_once_with(
managed_identity_client_id="managed_identity_client_id",
workload_identity_tenant_id="workload_identity_tenant_id",
additionally_allowed_tenants=["workload_identity_tenant_id"],
)

@mock.patch(f"{KEY_VAULT_MODULE}.DefaultAzureCredential")
@mock.patch(f"{KEY_VAULT_MODULE}.ClientSecretCredential")
@mock.patch(f"{KEY_VAULT_MODULE}.SecretClient")
def test_client_authenticate_with_client_secret_credential(
self, mock_client, mock_client_secret_credential, mock_defaul_azure_credential
):
Expand Down

0 comments on commit ebcb162

Please sign in to comment.