Skip to content

Commit

Permalink
Managed Identity for Machine Learning
Browse files Browse the repository at this point in the history
  • Loading branch information
rayluo committed May 24, 2024
1 parent 2c8c5ba commit 9350391
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 0 deletions.
43 changes: 43 additions & 0 deletions msal/managed_identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,15 @@ def _obtain_token(http_client, managed_identity, resource):
managed_identity,
resource,
)
if "MSI_ENDPOINT" in os.environ and "MSI_SECRET" in os.environ:
# Back ported from https://github.com/Azure/azure-sdk-for-python/blob/azure-identity_1.15.0/sdk/identity/azure-identity/azure/identity/_credentials/azure_ml.py
return _obtain_token_on_machine_learning(
http_client,
os.environ["MSI_ENDPOINT"],
os.environ["MSI_SECRET"],
managed_identity,
resource,
)
if "IDENTITY_ENDPOINT" in os.environ and "IMDS_ENDPOINT" in os.environ:
if ManagedIdentity.is_user_assigned(managed_identity):
raise ManagedIdentityError( # Note: Azure Identity for Python raised exception too
Expand All @@ -329,6 +338,7 @@ def _obtain_token(http_client, managed_identity, resource):


def _adjust_param(params, managed_identity):
# Modify the params dict in place
id_name = ManagedIdentity._types_mapping.get(
managed_identity.get(ManagedIdentity.ID_TYPE))
if id_name:
Expand Down Expand Up @@ -405,6 +415,39 @@ def _obtain_token_on_app_service(
logger.debug("IMDS emits unexpected payload: %s", resp.text)
raise

def _obtain_token_on_machine_learning(
http_client, endpoint, secret, managed_identity, resource,
):
# Could not find protocol docs from https://docs.microsoft.com/en-us/azure/machine-learning
# The following implementation is back ported from Azure Identity 1.15.0
logger.debug("Obtaining token via managed identity on Azure Machine Learning")
params = {"api-version": "2017-09-01", "resource": resource}
_adjust_param(params, managed_identity)
if params["api-version"] == "2017-09-01" and "client_id" in params:
# Workaround for a known bug in Azure ML 2017 API
params["clientid"] = params.pop("client_id")
resp = http_client.get(
endpoint,
params=params,
headers={"secret": secret},
)
try:
payload = json.loads(resp.text)
if payload.get("access_token") and payload.get("expires_on"):
return { # Normalizing the payload into OAuth2 format
"access_token": payload["access_token"],
"expires_in": int(payload["expires_on"]) - int(time.time()),
"resource": payload.get("resource"),
"token_type": payload.get("token_type", "Bearer"),
}
return {
"error": "invalid_scope", # TODO: To be tested
"error_description": "{}".format(payload),
}
except json.decoder.JSONDecodeError:
logger.debug("IMDS emits unexpected payload: %s", resp.text)
raise


def _obtain_token_on_service_fabric(
http_client, endpoint, identity_header, server_thumbprint, resource,
Expand Down
24 changes: 24 additions & 0 deletions tests/test_mi.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,30 @@ def test_app_service_error_should_be_normalized(self):
self.assertEqual({}, self.app._token_cache._cache)


@patch.dict(os.environ, {"MSI_ENDPOINT": "http://localhost", "MSI_SECRET": "foo"})
class MachineLearningTestCase(ClientTestCase):

def test_happy_path(self):
with patch.object(self.app._http_client, "get", return_value=MinimalResponse(
status_code=200,
text='{"access_token": "AT", "expires_on": "%s", "resource": "R"}' % (
int(time.time()) + 1234),
)) as mocked_method:
self._test_happy_path(self.app, mocked_method)

def test_machine_learning_error_should_be_normalized(self):
raw_error = '{"error": "placeholder", "message": "placeholder"}'
with patch.object(self.app._http_client, "get", return_value=MinimalResponse(
status_code=500,
text=raw_error,
)) as mocked_method:
self.assertEqual({
"error": "invalid_scope",
"error_description": "{'error': 'placeholder', 'message': 'placeholder'}",
}, self.app.acquire_token_for_client(resource="R"))
self.assertEqual({}, self.app._token_cache._cache)


@patch.dict(os.environ, {
"IDENTITY_ENDPOINT": "http://localhost",
"IDENTITY_HEADER": "foo",
Expand Down

0 comments on commit 9350391

Please sign in to comment.