diff --git a/azurefunctions-extensions-bindings-blob/azurefunctions/extensions/bindings/blob/utils.py b/azurefunctions-extensions-bindings-blob/azurefunctions/extensions/bindings/blob/utils.py index 5e05599..69ce04a 100644 --- a/azurefunctions-extensions-bindings-blob/azurefunctions/extensions/bindings/blob/utils.py +++ b/azurefunctions-extensions-bindings-blob/azurefunctions/extensions/bindings/blob/utils.py @@ -2,7 +2,7 @@ # Licensed under the MIT License. import os -from azure.identity import DefaultAzureCredential, ManagedIdentityCredential +from azure.identity import ManagedIdentityCredential from azure.storage.blob import BlobServiceClient @@ -35,7 +35,9 @@ def get_connection_string(connection_name: str) -> str: 1. Not using managed identity: the environment variable exists as is 2. Using managed identity for blob input: __serviceUri must be appended 3. Using managed identity for blob trigger: __blobServiceUri must be appended - 4. None of these cases existed, so the connection variable is invalid. + 4. Using managed identity with host storage: + connection name must be AzureWebJobsStorage and __accountName must be appended + 5. None of these cases existed, so the connection variable is invalid. """ if connection_name in os.environ: return os.getenv(connection_name) @@ -43,6 +45,9 @@ def get_connection_string(connection_name: str) -> str: return os.getenv(connection_name + "__serviceUri") elif connection_name + "__blobServiceUri" in os.environ: return os.getenv(connection_name + "__blobServiceUri") + elif (connection_name == "AzureWebJobsStorage" + and connection_name + "__accountName" in os.environ): + return f"https://{os.getenv('AzureWebJobsStorage__accountName')}.blob.core.windows.net" # noqa else: raise ValueError( f"Storage account connection name {connection_name} does not exist. " @@ -54,10 +59,12 @@ def using_system_managed_identity(connection_name: str) -> bool: """ To determine if system-assigned managed identity is being used, we check if the provided connection string has either of the two suffixes: - __serviceUri or __blobServiceUri. + __serviceUri or __blobServiceUri OR if the identity is using host storage. """ return (os.getenv(connection_name + "__serviceUri") is not None) or ( - os.getenv(connection_name + "__blobServiceUri") is not None + os.getenv(connection_name + "__blobServiceUri") is not None or ( + connection_name == "AzureWebJobsStorage" and ( + os.getenv(connection_name + "__accountName")) is not None) ) @@ -81,9 +88,9 @@ def service_client_factory(connection: str): There are 3 cases: 1. The customer is using user-assigned managed identity -> the BlobServiceClient - must be created using a ManagedIdentityCredential. + must be created using a ManagedIdentityCredential with specified arguments. 2. The customer is using system based managed identity -> the BlobServiceClient - must be created using a DefaultAzureCredential. + must be created using a ManagedIdentityCredential with default arguments. 3. The customer is not using managed identity -> the BlobServiceClient must be created using a connection string. """ @@ -94,6 +101,6 @@ def service_client_factory(connection: str): client_id=os.getenv(connection + "__clientId"))) elif using_system_managed_identity(connection): return BlobServiceClient(account_url=connection_string, - credential=DefaultAzureCredential()) + credential=ManagedIdentityCredential()) else: return BlobServiceClient.from_connection_string(connection_string) diff --git a/azurefunctions-extensions-bindings-blob/tests/test_utils.py b/azurefunctions-extensions-bindings-blob/tests/test_utils.py index b867cdf..4cc11ff 100644 --- a/azurefunctions-extensions-bindings-blob/tests/test_utils.py +++ b/azurefunctions-extensions-bindings-blob/tests/test_utils.py @@ -44,6 +44,14 @@ def test_blob_service_uri_exists(self): result = get_connection_string("MY_CONNECTION") self.assertEqual(result, "blob_service_uri_string") + def test_host_storage_exists(self): + with patch.dict(os.environ, { + "AzureWebJobsStorage__accountName": "account_name_string"}): + os.environ.pop("AzureWebJobsStorage", None) + result = get_connection_string("AzureWebJobsStorage") + self.assertEqual(result, + "https://account_name_string.blob.core.windows.net") + def test_connection_string_missing_raises_value_error(self): with patch.dict(os.environ, {}, clear=True): with self.assertRaises(ValueError) as context: @@ -64,6 +72,12 @@ def test_blob_service_uri_present(self): result = using_system_managed_identity("MY_CONNECTION") self.assertTrue(result) + def test_host_storage_present(self): + with patch.dict(os.environ, {"AzureWebJobsStorage__accountName": + "https://example.blob.core.windows.net/"}): + result = using_system_managed_identity("AzureWebJobsStorage") + self.assertTrue(result) + def test_both_uris_present(self): with patch.dict(os.environ, { "MY_CONNECTION__serviceUri": "https://example.service.core.windows.net/",