Skip to content

Commit

Permalink
get_managed_identity_source() for Azure Identity
Browse files Browse the repository at this point in the history
  • Loading branch information
rayluo committed Jun 21, 2024
1 parent cb9cfe2 commit b443edd
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 1 deletion.
30 changes: 30 additions & 0 deletions msal/managed_identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from .token_cache import TokenCache
from .individual_cache import _IndividualCache as IndividualCache
from .throttled_http_client import ThrottledHttpClientBase, RetryAfterParser
from .cloudshell import _is_running_in_cloud_shell


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -305,6 +306,35 @@ def _scope_to_resource(scope): # This is an experimental reasonable-effort appr
return scope # There is no much else we can do here


APP_SERVICE = object()
AZURE_ARC = object()
CLOUD_SHELL = object() # In MSAL Python, token acquisition was done by
# PublicClientApplication(...).acquire_token_interactive(..., prompt="none")
MACHINE_LEARNING = object()
SERVICE_FABRIC = object()
DEFAULT_TO_VM = object() # Unknown environment; default to VM; you may want to probe
def get_managed_identity_source():
"""Detect the current environment and return the likely identity source.
When this function returns ``CLOUD_SHELL``, you should use
:func:`msal.PublicClientApplication.acquire_token_interactive` with ``prompt="none"``
to obtain a token.
"""
if ("IDENTITY_ENDPOINT" in os.environ and "IDENTITY_HEADER" in os.environ
and "IDENTITY_SERVER_THUMBPRINT" in os.environ
):
return SERVICE_FABRIC
if "IDENTITY_ENDPOINT" in os.environ and "IDENTITY_HEADER" in os.environ:
return APP_SERVICE
if "MSI_ENDPOINT" in os.environ and "MSI_SECRET" in os.environ:
return MACHINE_LEARNING
if "IDENTITY_ENDPOINT" in os.environ and "IMDS_ENDPOINT" in os.environ:
return AZURE_ARC
if _is_running_in_cloud_shell():
return CLOUD_SHELL
return DEFAULT_TO_VM


def _obtain_token(http_client, managed_identity, resource):
# A unified low-level API that talks to different Managed Identity
if ("IDENTITY_ENDPOINT" in os.environ and "IDENTITY_HEADER" in os.environ
Expand Down
52 changes: 51 additions & 1 deletion tests/test_mi.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,16 @@
ManagedIdentityError,
ArcPlatformNotSupportedError,
)
from msal.managed_identity import _supported_arc_platforms_and_their_prefixes
from msal.managed_identity import (
_supported_arc_platforms_and_their_prefixes,
get_managed_identity_source,
APP_SERVICE,
AZURE_ARC,
CLOUD_SHELL,
MACHINE_LEARNING,
SERVICE_FABRIC,
DEFAULT_TO_VM,
)


class ManagedIdentityTestCase(unittest.TestCase):
Expand Down Expand Up @@ -234,3 +243,44 @@ def test_arc_error_should_be_normalized(self, mocked_stat):
if sys.platform in _supported_arc_platforms_and_their_prefixes:
self.fail("Should not raise ArcPlatformNotSupportedError")


class GetManagedIdentitySourceTestCase(unittest.TestCase):

@patch.dict(os.environ, {
"IDENTITY_ENDPOINT": "http://localhost",
"IDENTITY_HEADER": "foo",
"IDENTITY_SERVER_THUMBPRINT": "bar",
})
def test_service_fabric(self):
self.assertEqual(get_managed_identity_source(), SERVICE_FABRIC)

@patch.dict(os.environ, {
"IDENTITY_ENDPOINT": "http://localhost",
"IDENTITY_HEADER": "foo",
})
def test_app_service(self):
self.assertEqual(get_managed_identity_source(), APP_SERVICE)

@patch.dict(os.environ, {
"MSI_ENDPOINT": "http://localhost",
"MSI_SECRET": "foo",
})
def test_machine_learning(self):
self.assertEqual(get_managed_identity_source(), MACHINE_LEARNING)

@patch.dict(os.environ, {
"IDENTITY_ENDPOINT": "http://localhost",
"IMDS_ENDPOINT": "http://localhost",
})
def test_arc(self):
self.assertEqual(get_managed_identity_source(), AZURE_ARC)

@patch.dict(os.environ, {
"AZUREPS_HOST_ENVIRONMENT": "cloud-shell-foo",
})
def test_cloud_shell(self):
self.assertEqual(get_managed_identity_source(), CLOUD_SHELL)

def test_default_to_vm(self):
self.assertEqual(get_managed_identity_source(), DEFAULT_TO_VM)

0 comments on commit b443edd

Please sign in to comment.