Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# Licensed under the MIT License.
import os

from azure.identity import DefaultAzureCredential, ManagedIdentityCredential
from azure.identity import ManagedIdentityCredential
from azure.storage.blob import BlobServiceClient


Expand Down Expand Up @@ -35,14 +35,19 @@ def get_connection_string(connection_name: str) -> str:
1. Not using managed identity: the environment variable exists as is
2. Using managed identity for blob input: __serviceUri must be appended
3. Using managed identity for blob trigger: __blobServiceUri must be appended
4. None of these cases existed, so the connection variable is invalid.
4. Using managed identity with host storage:
connection name must be AzureWebJobsStorage and __accountName must be appended
5. None of these cases existed, so the connection variable is invalid.
"""
if connection_name in os.environ:
return os.getenv(connection_name)
elif connection_name + "__serviceUri" in os.environ:
return os.getenv(connection_name + "__serviceUri")
elif connection_name + "__blobServiceUri" in os.environ:
return os.getenv(connection_name + "__blobServiceUri")
elif (connection_name == "AzureWebJobsStorage"
and connection_name + "__accountName" in os.environ):
return f"https://{os.getenv('AzureWebJobsStorage__accountName')}.blob.core.windows.net" # noqa
else:
raise ValueError(
f"Storage account connection name {connection_name} does not exist. "
Expand All @@ -54,10 +59,12 @@ def using_system_managed_identity(connection_name: str) -> bool:
"""
To determine if system-assigned managed identity is being used, we check if
the provided connection string has either of the two suffixes:
__serviceUri or __blobServiceUri.
__serviceUri or __blobServiceUri OR if the identity is using host storage.
"""
return (os.getenv(connection_name + "__serviceUri") is not None) or (
os.getenv(connection_name + "__blobServiceUri") is not None
os.getenv(connection_name + "__blobServiceUri") is not None or (
connection_name == "AzureWebJobsStorage" and (
os.getenv(connection_name + "__accountName")) is not None)
)


Expand All @@ -81,9 +88,9 @@ def service_client_factory(connection: str):

There are 3 cases:
1. The customer is using user-assigned managed identity -> the BlobServiceClient
must be created using a ManagedIdentityCredential.
must be created using a ManagedIdentityCredential with specified arguments.
2. The customer is using system based managed identity -> the BlobServiceClient
must be created using a DefaultAzureCredential.
must be created using a ManagedIdentityCredential with default arguments.
3. The customer is not using managed identity -> the BlobServiceClient must
be created using a connection string.
"""
Expand All @@ -94,6 +101,6 @@ def service_client_factory(connection: str):
client_id=os.getenv(connection + "__clientId")))
elif using_system_managed_identity(connection):
return BlobServiceClient(account_url=connection_string,
credential=DefaultAzureCredential())
credential=ManagedIdentityCredential())
else:
return BlobServiceClient.from_connection_string(connection_string)
14 changes: 14 additions & 0 deletions azurefunctions-extensions-bindings-blob/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,14 @@ def test_blob_service_uri_exists(self):
result = get_connection_string("MY_CONNECTION")
self.assertEqual(result, "blob_service_uri_string")

def test_host_storage_exists(self):
with patch.dict(os.environ, {
"AzureWebJobsStorage__accountName": "account_name_string"}):
os.environ.pop("AzureWebJobsStorage", None)
result = get_connection_string("AzureWebJobsStorage")
self.assertEqual(result,
"https://account_name_string.blob.core.windows.net")

def test_connection_string_missing_raises_value_error(self):
with patch.dict(os.environ, {}, clear=True):
with self.assertRaises(ValueError) as context:
Expand All @@ -64,6 +72,12 @@ def test_blob_service_uri_present(self):
result = using_system_managed_identity("MY_CONNECTION")
self.assertTrue(result)

def test_host_storage_present(self):
with patch.dict(os.environ, {"AzureWebJobsStorage__accountName":
"https://example.blob.core.windows.net/"}):
result = using_system_managed_identity("AzureWebJobsStorage")
self.assertTrue(result)

def test_both_uris_present(self):
with patch.dict(os.environ, {
"MY_CONNECTION__serviceUri": "https://example.service.core.windows.net/",
Expand Down