Skip to content

Commit

Permalink
Add DefaultAzureCredential support to AzureContainerRegistryHook (#33825
Browse files Browse the repository at this point in the history
)

* feat(providers/microsoft): add DefaultAzureCredential support to AzureContainerRegistryHook

* feat(providers/microsoft): pin azure-mgmt-containerregistry to >= 8.0.0

* docs(providers/microsoft): update connection documnetation for Azure Container Registry connection
  • Loading branch information
Lee-W committed Aug 30, 2023
1 parent 23b15e6 commit 539797f
Show file tree
Hide file tree
Showing 5 changed files with 94 additions and 1 deletion.
44 changes: 43 additions & 1 deletion airflow/providers/microsoft/azure/hooks/container_registry.py
Expand Up @@ -21,9 +21,12 @@
from functools import cached_property
from typing import Any

from azure.identity import DefaultAzureCredential
from azure.mgmt.containerinstance.models import ImageRegistryCredential
from azure.mgmt.containerregistry import ContainerRegistryManagementClient

from airflow.hooks.base import BaseHook
from airflow.providers.microsoft.azure.utils import get_field


class AzureContainerRegistryHook(BaseHook):
Expand All @@ -40,6 +43,24 @@ class AzureContainerRegistryHook(BaseHook):
conn_type = "azure_container_registry"
hook_name = "Azure Container Registry"

@staticmethod
def get_connection_form_widgets() -> dict[str, Any]:
"""Returns connection widgets to add to connection form."""
from flask_appbuilder.fieldwidgets import BS3TextFieldWidget
from flask_babel import lazy_gettext
from wtforms import StringField

return {
"subscription_id": StringField(
lazy_gettext("Subscription ID (optional)"),
widget=BS3TextFieldWidget(),
),
"resource_group": StringField(
lazy_gettext("Resource group name (optional)"),
widget=BS3TextFieldWidget(),
),
}

@classmethod
def get_ui_field_behaviour(cls) -> dict[str, Any]:
"""Returns custom field behaviour."""
Expand All @@ -54,17 +75,38 @@ def get_ui_field_behaviour(cls) -> dict[str, Any]:
"login": "private registry username",
"password": "private registry password",
"host": "docker image registry server",
"subscription_id": "Subscription id (required for Azure AD authentication)",
"resource_group": "Resource group name (required for Azure AD authentication)",
},
}

def __init__(self, conn_id: str = "azure_registry") -> None:
super().__init__()
self.conn_id = conn_id

def _get_field(self, extras, name):
return get_field(
conn_id=self.conn_id,
conn_type=self.conn_type,
extras=extras,
field_name=name,
)

@cached_property
def connection(self) -> ImageRegistryCredential:
return self.get_conn()

def get_conn(self) -> ImageRegistryCredential:
conn = self.get_connection(self.conn_id)
return ImageRegistryCredential(server=conn.host, username=conn.login, password=conn.password)
password = conn.password
if not password:
extras = conn.extra_dejson
subscription_id = self._get_field(extras, "subscription_id")
resource_group = self._get_field(extras, "resource_group")
client = ContainerRegistryManagementClient(
credential=DefaultAzureCredential(), subscription_id=subscription_id
)
credentials = client.registries.list_credentials(resource_group, conn.login).as_dict()
password = credentials["passwords"][0]["value"]

return ImageRegistryCredential(server=conn.host, username=conn.login, password=password)
2 changes: 2 additions & 0 deletions airflow/providers/microsoft/azure/provider.yaml
Expand Up @@ -81,6 +81,8 @@ dependencies:
- adal>=1.2.7
- azure-storage-file-datalake>=12.9.1
- azure-kusto-data>=4.1.0

- azure-mgmt-containerregistry>=8.0.0
# TODO: upgrade to newer versions of all the below libraries.
# See issue https://github.com/apache/airflow/issues/30199
- azure-mgmt-containerinstance>=7.0.0,<9.0.0
Expand Down
10 changes: 10 additions & 0 deletions docs/apache-airflow-providers-microsoft-azure/connections/acr.rst
Expand Up @@ -50,6 +50,16 @@ Password
Host
Specify the Image Registry Server used for the initial connection.

Subscription ID
Specify the ID of the subscription used for the initial connection.
This is needed for Azure Active Directory (Azure AD) authentication.
Use extra param ``subscription_id`` to pass in the Azure subscription ID.

Resource Group Name (optional)
Specify the Azure Resource Group Name under which the desired Azure container registry resides.
This is needed for Azure Active Directory (Azure AD) authentication.
Use extra param ``resource_group`` to pass in the resource group name.

When specifying the connection in environment variable you should specify
it using URI syntax.

Expand Down
1 change: 1 addition & 0 deletions generated/provider_dependencies.json
Expand Up @@ -560,6 +560,7 @@
"azure-keyvault-secrets>=4.1.0",
"azure-kusto-data>=4.1.0",
"azure-mgmt-containerinstance>=7.0.0,<9.0.0",
"azure-mgmt-containerregistry>=8.0.0",
"azure-mgmt-cosmosdb",
"azure-mgmt-datafactory>=1.0.0,<2.0",
"azure-mgmt-datalake-store>=0.5.0",
Expand Down
Expand Up @@ -17,6 +17,8 @@
# under the License.
from __future__ import annotations

from unittest import mock

import pytest

from airflow.models import Connection
Expand All @@ -43,3 +45,39 @@ def test_get_conn(self, mocked_connection):
assert hook.connection.username == "myuser"
assert hook.connection.password == "password"
assert hook.connection.server == "test.cr"

@pytest.mark.parametrize(
"mocked_connection",
[
Connection(
conn_id="azure_container_registry",
conn_type="azure_container_registry",
login="myuser",
password="",
host="test.cr",
extra={"subscription_id": "subscription_id", "resource_group": "resource_group"},
)
],
indirect=True,
)
@mock.patch(
"airflow.providers.microsoft.azure.hooks.container_registry.ContainerRegistryManagementClient"
)
@mock.patch("airflow.providers.microsoft.azure.hooks.container_registry.DefaultAzureCredential")
def test_get_conn_with_default_azure_credential(
self, mocked_default_azure_credential, mocked_client, mocked_connection
):
mocked_client.return_value.registries.list_credentials.return_value.as_dict.return_value = {
"username": "myuser",
"passwords": [
{"name": "password", "value": "password"},
],
}

hook = AzureContainerRegistryHook(conn_id=mocked_connection.conn_id)
assert hook.connection is not None
assert hook.connection.username == "myuser"
assert hook.connection.password == "password"
assert hook.connection.server == "test.cr"

mocked_default_azure_credential.assert_called_with()

0 comments on commit 539797f

Please sign in to comment.