Skip to content

Commit

Permalink
Support Service Fabric
Browse files Browse the repository at this point in the history
  • Loading branch information
rayluo committed Apr 11, 2023
1 parent 23e2e15 commit 75601db
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 30 deletions.
55 changes: 54 additions & 1 deletion msal/imds.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,18 @@ def _scope_to_resource(scope): # This is an experimental reasonable-effort appr


def _obtain_token(http_client, resource, client_id=None, object_id=None, mi_res_id=None):
if ("IDENTITY_ENDPOINT" in os.environ and "IDENTITY_HEADER" in os.environ
and "IDENTITY_SERVER_THUMBPRINT" in os.environ
):
if client_id or object_id or mi_res_id:
logger.debug(
"Ignoring client_id/object_id/mi_res_id. "
"Managed Identity in Service Fabric is configured in the cluster, "
"not during runtime. See also "
"https://learn.microsoft.com/en-us/azure/service-fabric/configure-existing-cluster-enable-managed-identity-token-service")
return _obtain_token_on_service_fabric(
http_client, os.environ["IDENTITY_ENDPOINT"], os.environ["IDENTITY_HEADER"],
os.environ["IDENTITY_SERVER_THUMBPRINT"], resource)
if "IDENTITY_ENDPOINT" in os.environ and "IDENTITY_HEADER" in os.environ:
return _obtain_token_on_app_service(
http_client, os.environ["IDENTITY_ENDPOINT"], os.environ["IDENTITY_HEADER"],
Expand Down Expand Up @@ -69,7 +81,8 @@ def _obtain_token_on_app_service(http_client, endpoint, identity_header, resourc
client_id=None, object_id=None, mi_res_id=None,
):
"""Obtains token for
`App Service <https://learn.microsoft.com/en-us/azure/app-service/overview-managed-identity?tabs=portal%2Chttp#rest-endpoint-reference>`_
`App Service <https://learn.microsoft.com/en-us/azure/app-service/overview-managed-identity?tabs=portal%2Chttp#rest-endpoint-reference>`_,
Azure Functions, and Azure Automation.
"""
# Prerequisite: Create your app service https://docs.microsoft.com/en-us/azure/app-service/quickstart-python
# Assign it a managed identity https://docs.microsoft.com/en-us/azure/app-service/overview-managed-identity?tabs=portal%2Chttp
Expand Down Expand Up @@ -114,6 +127,46 @@ def _obtain_token_on_app_service(http_client, endpoint, identity_header, resourc
raise


def _obtain_token_on_service_fabric(
http_client, endpoint, identity_header, server_thumbprint, resource,
):
"""Obtains token for
`Service Fabric <https://learn.microsoft.com/en-us/azure/service-fabric/>`_
"""
# Deployment https://learn.microsoft.com/en-us/azure/service-fabric/service-fabric-get-started-containers-linux
# See also https://github.com/Azure/azure-sdk-for-python/blob/main/sdk/identity/azure-identity/tests/managed-identity-live/service-fabric/service_fabric.md
# Protocol https://learn.microsoft.com/en-us/azure/service-fabric/how-to-managed-identity-service-fabric-app-code#acquiring-an-access-token-using-rest-api
logger.debug("Obtaining token via managed identity on Azure Service Fabric")
resp = http_client.get(
endpoint,
params={"api-version": "2019-07-01-preview", "resource": resource},
headers={"Secret": identity_header},
)
try:
payload = json.loads(resp.text)
if payload.get("access_token") and payload.get("expires_on"):
return { # Normalizing the payload into OAuth2 format
"access_token": payload["access_token"],
"expires_in": payload["expires_on"] - int(time.time()),
"resource": payload.get("resource"),
"token_type": payload["token_type"],
}
error = payload.get("error", {}) # https://learn.microsoft.com/en-us/azure/service-fabric/how-to-managed-identity-service-fabric-app-code#error-handling
error_mapping = { # Map Service Fabric errors into OAuth2 errors https://www.rfc-editor.org/rfc/rfc6749#section-5.2
"SecretHeaderNotFound": "unauthorized_client",
"ManagedIdentityNotFound": "invalid_client",
"ArgumentNullOrEmpty": "invalid_scope",
}
return {
"error": error_mapping.get(payload["error"]["code"], "invalid_request"),
"error_description": resp.text,
}
except ValueError:
logger.debug("IMDS emits unexpected payload: %s", resp.text)
raise



class ManagedIdentity(object):
_instance, _tenant = socket.getfqdn(), "managed_identity" # Placeholders

Expand Down
89 changes: 60 additions & 29 deletions tests/test_mi.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,21 @@ def _test_token_cache(self, app):
"Should have expected client_id")
self.assertEqual("managed_identity", at["realm"], "Should have expected realm")

def _test_happy_path(self, app, mocked_http):
result = app.acquire_token(resource="R")
mocked_http.assert_called_once()
self.assertEqual({
"access_token": "AT",
"expires_in": 1234,
"resource": "R",
"token_type": "Bearer",
}, result, "Should obtain a token response")
self.assertEqual(
result["access_token"],
app.acquire_token(resource="R").get("access_token"),
"Should hit the same token from cache")
self._test_token_cache(app)


class VmTestCase(ManagedIdentityTestCase):

Expand All @@ -34,19 +49,7 @@ def test_happy_path(self):
status_code=200,
text='{"access_token": "AT", "expires_in": "1234", "resource": "R"}',
)) as mocked_method:
result = app.acquire_token(resource="R")
mocked_method.assert_called_once()
self.assertEqual({
"access_token": "AT",
"expires_in": 1234,
"resource": "R",
"token_type": "Bearer",
}, result, "Should obtain a token response")
self.assertEqual(
result["access_token"],
app.acquire_token(resource="R").get("access_token"),
"Should hit the same token from cache")
self._test_token_cache(app)
self._test_happy_path(app, mocked_method)

def test_vm_error_should_be_returned_as_is(self):
raw_error = '{"raw": "error format is undefined"}'
Expand All @@ -63,26 +66,13 @@ def test_vm_error_should_be_returned_as_is(self):
class AppServiceTestCase(ManagedIdentityTestCase):

def test_happy_path(self):
# TODO: Combine this with VM's test case, and move it into base class
app = ManagedIdentity(requests.Session(), token_cache=TokenCache())
now = int(time.time())
with patch.object(app._http_client, "get", return_value=MinimalResponse(
status_code=200,
text='{"access_token": "AT", "expires_on": "%s", "resource": "R"}' % (now + 100),
text='{"access_token": "AT", "expires_on": "%s", "resource": "R"}' % (
int(time.time()) + 1234),
)) as mocked_method:
result = app.acquire_token(resource="R")
mocked_method.assert_called_once()
self.assertEqual({
"access_token": "AT",
"expires_in": 100,
"resource": "R",
"token_type": "Bearer",
}, result, "Should obtain a token response")
self.assertEqual(
result["access_token"],
app.acquire_token(resource="R").get("access_token"),
"Should hit the same token from cache")
self._test_token_cache(app)
self._test_happy_path(app, mocked_method)

def test_app_service_error_should_be_normalized(self):
raw_error = '{"statusCode": 500, "message": "error content is undefined"}'
Expand All @@ -97,3 +87,44 @@ def test_app_service_error_should_be_normalized(self):
}, app.acquire_token(resource="R"))
self.assertEqual({}, app._token_cache._cache)

@patch.dict(os.environ, {
"IDENTITY_ENDPOINT": "http://localhost",
"IDENTITY_HEADER": "foo",
"IDENTITY_SERVER_THUMBPRINT": "bar",
})
class ServiceFabricTestCase(ManagedIdentityTestCase):

def _test_happy_path(self, app):
with patch.object(app._http_client, "get", return_value=MinimalResponse(
status_code=200,
text='{"access_token": "AT", "expires_on": %s, "resource": "R", "token_type": "Bearer"}' % (
int(time.time()) + 1234),
)) as mocked_method:
super(ServiceFabricTestCase, self)._test_happy_path(app, mocked_method)

def test_happy_path(self):
self._test_happy_path(ManagedIdentity(
requests.Session(), token_cache=TokenCache()))

def test_unified_api_service_should_ignore_unnecessary_client_id(self):
self._test_happy_path(ManagedIdentity(
requests.Session(), client_id="foo", token_cache=TokenCache()))

def test_app_service_error_should_be_normalized(self):
raw_error = '''
{"error": {
"correlationId": "foo",
"code": "SecretHeaderNotFound",
"message": "Secret is not found in the request headers."
}}''' # https://learn.microsoft.com/en-us/azure/service-fabric/how-to-managed-identity-service-fabric-app-code#error-handling
app = ManagedIdentity(requests.Session(), token_cache=TokenCache())
with patch.object(app._http_client, "get", return_value=MinimalResponse(
status_code=404,
text=raw_error,
)) as mocked_method:
self.assertEqual({
"error": "unauthorized_client",
"error_description": raw_error,
}, app.acquire_token(resource="R"))
self.assertEqual({}, app._token_cache._cache)

0 comments on commit 75601db

Please sign in to comment.