Skip to content

Commit

Permalink
Dedicate ManagedIdentity API
Browse files Browse the repository at this point in the history
  • Loading branch information
rayluo committed Mar 10, 2023
1 parent 0c57056 commit dd75a28
Show file tree
Hide file tree
Showing 4 changed files with 118 additions and 43 deletions.
1 change: 1 addition & 0 deletions msal/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,5 @@
)
from .oauth2cli.oidc import Prompt
from .token_cache import TokenCache, SerializableTokenCache
from .imds import ManagedIdentity

15 changes: 0 additions & 15 deletions msal/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -2001,21 +2001,6 @@ def acquire_token_for_client(self, scopes, claims_challenge=None, **kwargs):
- an error response would contain "error" and usually "error_description".
"""
# TBD: force_refresh behavior
if self.client_credential is None:
from .imds import _scope_to_resource, _obtain_token
response = _obtain_token(
self.http_client,
" ".join(map(_scope_to_resource, scopes)),
client_id=self.client_id, # None for system-assigned, GUID for user-assigned
)
if "error" not in response:
self.token_cache.add(dict(
client_id=self.client_id,
scope=response["scope"].split() if "scope" in response else scopes,
token_endpoint=self.authority.token_endpoint,
response=response.copy(),
))
return response
if self.authority.tenant.lower() in ["common", "organizations"]:
warnings.warn(
"Using /common or /organizations authority "
Expand Down
69 changes: 68 additions & 1 deletion msal/imds.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import json
import logging
import os
import socket
import time
try: # Python 2
from urlparse import urlparse
Expand Down Expand Up @@ -57,6 +58,9 @@ def _obtain_token_on_azure_vm(http_client, resource, client_id=None):
raise

def _obtain_token_on_app_service(http_client, endpoint, identity_header, resource, client_id=None):
"""Obtains token for
`App Service <https://learn.microsoft.com/en-us/azure/app-service/overview-managed-identity?tabs=portal%2Chttp#rest-endpoint-reference>`_
"""
# Prerequisite: Create your app service https://docs.microsoft.com/en-us/azure/app-service/quickstart-python
# Assign it a managed identity https://docs.microsoft.com/en-us/azure/app-service/overview-managed-identity?tabs=portal%2Chttp
# SSH into your container for testing https://docs.microsoft.com/en-us/azure/app-service/configure-linux-open-ssh-session
Expand All @@ -73,7 +77,7 @@ def _obtain_token_on_app_service(http_client, endpoint, identity_header, resourc
headers={
"X-IDENTITY-HEADER": identity_header,
"Metadata": "true", # Unnecessary yet harmless for App Service,
# It will be needed by Azure Automation
# It will be needed by Azure Automation
# https://docs.microsoft.com/en-us/azure/automation/enable-managed-identity-for-automation#get-access-token-for-system-assigned-managed-identity-using-http-get
},
)
Expand All @@ -95,3 +99,66 @@ def _obtain_token_on_app_service(http_client, endpoint, identity_header, resourc
logger.debug("IMDS emits unexpected payload: %s", resp.text)
raise


class ManagedIdentity(object):
_instance, _tenant = socket.getfqdn(), "managed_identity" # Placeholders

def __init__(self, http_client, client_id=None, token_cache=None):
"""Create a managed identity object.
:param http_client:
An http client object. For example, you can use `requests.Session()`.
:param str client_id:
Optional.
It accepts the Client ID (NOT the Object ID) of your user-assigned managed identity.
If it is None, it means to use a system-assigned managed identity.
:param token_cache:
Optional. It accepts a :class:`msal.TokenCache` instance to store tokens.
"""
self._http_client = http_client
self._client_id = client_id
self._token_cache = token_cache

def acquire_token(self, resource):
access_token_from_cache = None
if self._token_cache:
matches = self._token_cache.find(
self._token_cache.CredentialType.ACCESS_TOKEN,
target=[resource],
query=dict(
client_id=self._client_id,
environment=self._instance,
realm=self._tenant,
home_account_id=None,
),
)
now = time.time()
for entry in matches:
expires_in = int(entry["expires_on"]) - now
if expires_in < 5*60: # Then consider it expired
continue # Removal is not necessary, it will be overwritten
logger.debug("Cache hit an AT")
access_token_from_cache = { # Mimic a real response
"access_token": entry["secret"],
"token_type": entry.get("token_type", "Bearer"),
"expires_in": int(expires_in), # OAuth2 specs defines it as int
}
if "refresh_on" in entry and int(entry["refresh_on"]) < now: # aging
break # With a fallback in hand, we break here to go refresh
return access_token_from_cache # It is still good as new
result = _obtain_token(self._http_client, resource, client_id=self._client_id)
if self._token_cache and "access_token" in result:
self._token_cache.add(dict(
client_id=self._client_id,
scope=[resource],
token_endpoint="https://{}/{}".format(self._instance, self._tenant),
response=result,
params={},
data={},
#grant_type="placeholder",
))
return result
return access_token_from_cache or result

76 changes: 49 additions & 27 deletions tests/msaltest.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import getpass, logging, pprint, sys, msal
import functools, getpass, logging, pprint, sys, requests, msal


AZURE_CLI = "04b07795-8ddb-461a-bbee-02f9e1bf7b46"
Expand Down Expand Up @@ -141,32 +141,55 @@ def remove_account(app):
app.remove_account(account)
print('Account "{}" and/or its token(s) are signed out from MSAL Python'.format(account["username"]))

def acquire_token_for_client(app):
"""acquire_token_for_client() - Only for confidential client"""
pprint.pprint(app.acquire_token_for_client(_input_scopes()))
def acquire_token_for_managed_identity(app):
"""acquire_token() - Only for managed identity"""
pprint.pprint(app.acquire_token(_select_options([
"https://management.azure.com",
"https://graph.microsoft.com",
],
header="Acquire token for this resource",
accept_nonempty_string=True)))

def exit(app):
"""Exit"""
bug_link = (
"https://identitydivision.visualstudio.com/Engineering/_queries/query/79b3a352-a775-406f-87cd-a487c382a8ed/"
if app._enable_broker else
if getattr(app, "_enable_broker", None) else
"https://github.com/AzureAD/microsoft-authentication-library-for-python/issues/new/choose"
)
print("Bye. If you found a bug, please report it here: {}".format(bug_link))
sys.exit()

def _managed_identity():
client_id = _select_options([
{"client_id": None, "name": "System-assigned managed identity"},
],
option_renderer=lambda a: a["name"],
header="Choose the system-assigned managed identity "
"(or type in your user-assigned managed identity)",
accept_nonempty_string=True)
return msal.ManagedIdentity(
requests.Session(),
client_id=client_id["client_id"]
if isinstance(client_id, dict) else client_id,
token_cache=msal.TokenCache(),
)

def main():
print("Welcome to the Msal Python Console Test App, committed at 2022-5-2\n")
print("Welcome to the Console Test App for MSAL Python {}\n".format(msal.__version__))
chosen_app = _select_options([
{"client_id": AZURE_CLI, "name": "Azure CLI (Correctly configured for MSA-PT)"},
{"client_id": VISUAL_STUDIO, "name": "Visual Studio (Correctly configured for MSA-PT)"},
{"client_id": "95de633a-083e-42f5-b444-a4295d8e9314", "name": "Whiteboard Services (Non MSA-PT app. Accepts AAD & MSA accounts.)"},
{"client_id": None, "client_secret": None, "name": "System-assigned Managed Identity (Only works when running inside a supported environment, such as Azure VM, Azure App Service, Azure Automation)"},
{"test_managed_identity": None, "name": "Managed Identity (Only works when running inside a supported environment, such as Azure VM, Azure App Service, Azure Automation)"},
],
option_renderer=lambda a: a["name"],
header="Impersonate this app (or you can type in the client_id of your own app)",
accept_nonempty_string=True)
authority = _select_options([
if isinstance(chosen_app, dict) and "test_managed_identity" in chosen_app:
app = _managed_identity()
else:
authority = _select_options([
"https://login.microsoftonline.com/common",
"https://login.microsoftonline.com/organizations",
"https://login.microsoftonline.com/microsoft.onmicrosoft.com",
Expand All @@ -175,33 +198,32 @@ def main():
],
header="Input authority (Note that MSA-PT apps would NOT use the /common authority)",
accept_nonempty_string=True,
)
if isinstance(chosen_app, dict) and "client_secret" in chosen_app:
app = msal.ConfidentialClientApplication(
chosen_app["client_id"],
client_credential=chosen_app["client_secret"],
authority=authority,
)
else:
)
app = msal.PublicClientApplication(
chosen_app["client_id"] if isinstance(chosen_app, dict) else chosen_app,
authority=authority,
allow_broker=_input_boolean("Allow broker? (Azure CLI currently only supports @microsoft.com accounts when enabling broker)"),
)
if _input_boolean("Enable MSAL Python's DEBUG log?"):
logging.basicConfig(level=logging.DEBUG)
methods_to_be_tested = functools.reduce(lambda x, y: x + y, [
methods for app_type, methods in {
msal.PublicClientApplication: [
acquire_token_interactive,
acquire_ssh_cert_silently,
acquire_ssh_cert_interactive,
],
msal.ClientApplication: [
acquire_token_silent,
acquire_token_by_username_password,
remove_account,
],
msal.ManagedIdentity: [acquire_token_for_managed_identity],
}.items() if isinstance(app, app_type)])
while True:
func = _select_options(list(filter(None, [
acquire_token_silent,
acquire_token_interactive,
acquire_token_by_username_password,
acquire_ssh_cert_silently,
acquire_ssh_cert_interactive,
remove_account,
acquire_token_for_client if isinstance(
app, msal.ConfidentialClientApplication) else None,
exit,
])), option_renderer=lambda f: f.__doc__, header="MSAL Python APIs:")
func = _select_options(
methods_to_be_tested + [exit],
option_renderer=lambda f: f.__doc__, header="MSAL Python APIs:")
try:
func(app)
except ValueError as e:
Expand Down

0 comments on commit dd75a28

Please sign in to comment.