From 108258e86db2ece9d9f8db2b6d6fd95ab07b02d3 Mon Sep 17 00:00:00 2001 From: Ray Luo Date: Wed, 6 Mar 2019 15:45:14 -0800 Subject: [PATCH 01/14] Mention import for an inline sample --- msal/token_cache.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/msal/token_cache.py b/msal/token_cache.py index 116be878..bdecef53 100644 --- a/msal/token_cache.py +++ b/msal/token_cache.py @@ -192,15 +192,15 @@ class SerializableTokenCache(TokenCache): Depending on your need, the following simple recipe for file-based persistence may be sufficient:: - import atexit - cache = SerializableTokenCache() + import atexit, msal + cache = msal.SerializableTokenCache() cache.deserialize(open("my_cache.bin", "rb").read()) atexit.register(lambda: open("my_cache.bin", "wb").write(cache.serialize()) # Hint: The following optional line persists only when state changed if cache.has_state_changed else None ) - app = ClientApplication(..., token_cache=cache) + app = msal.ClientApplication(..., token_cache=cache) ... :var bool has_state_changed: From b187b80d107e61aa10923ec482c659d97f169429 Mon Sep 17 00:00:00 2001 From: Ray Luo Date: Thu, 7 Mar 2019 15:11:33 -0800 Subject: [PATCH 02/14] Fix bug of cache.has_state_changed not being initialized --- msal/token_cache.py | 2 ++ tests/test_token_cache.py | 6 ++++++ 2 files changed, 8 insertions(+) diff --git a/msal/token_cache.py b/msal/token_cache.py index bdecef53..9990db9d 100644 --- a/msal/token_cache.py +++ b/msal/token_cache.py @@ -207,6 +207,8 @@ class SerializableTokenCache(TokenCache): Indicates whether the cache state has changed since last :func:`~serialize` or :func:`~deserialize` call. """ + has_state_changed = False + def add(self, event, **kwargs): super(SerializableTokenCache, self).add(event, **kwargs) self.has_state_changed = True diff --git a/tests/test_token_cache.py b/tests/test_token_cache.py index eebd751d..ce5c3063 100644 --- a/tests/test_token_cache.py +++ b/tests/test_token_cache.py @@ -106,6 +106,12 @@ def setUp(self): } """) + def test_has_state_changed(self): + cache = SerializableTokenCache() + self.assertFalse(cache.has_state_changed) + cache.add({}) # An NO-OP add() still counts as a state change. Good enough. + self.assertTrue(cache.has_state_changed) + def tearDown(self): state = self.cache.serialize() logger.debug("serialize() = %s", state) From 81e8ddcb3dd0b11e704775595887b8a1d1e31908 Mon Sep 17 00:00:00 2001 From: Ray Luo Date: Thu, 7 Mar 2019 19:52:45 -0800 Subject: [PATCH 03/14] Fine tune documentation in SerializableTokenCache --- msal/token_cache.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/msal/token_cache.py b/msal/token_cache.py index 9990db9d..c945345a 100644 --- a/msal/token_cache.py +++ b/msal/token_cache.py @@ -192,11 +192,12 @@ class SerializableTokenCache(TokenCache): Depending on your need, the following simple recipe for file-based persistence may be sufficient:: - import atexit, msal + import os, atexit, msal cache = msal.SerializableTokenCache() - cache.deserialize(open("my_cache.bin", "rb").read()) + if os.path.exists("my_cache.bin"): + cache.deserialize(open("my_cache.bin", "r").read()) atexit.register(lambda: - open("my_cache.bin", "wb").write(cache.serialize()) + open("my_cache.bin", "w").write(cache.serialize()) # Hint: The following optional line persists only when state changed if cache.has_state_changed else None ) From ed244eec646da477f6f3857e05d1f92418d7d6b1 Mon Sep 17 00:00:00 2001 From: Ray Luo Date: Tue, 12 Mar 2019 10:49:05 -0700 Subject: [PATCH 04/14] Fix a missing comma in the inline documentation That line was presumably copied from here https://github.com/AzureAD/microsoft-authentication-library-for-python/blob/v0.2.0/sample/username_password_sample.py So this time we switch the sequence of the 2 lines, so that it will be less likely to go wrong in future documentation copy & paste work flow. --- sample/client_credential_sample.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sample/client_credential_sample.py b/sample/client_credential_sample.py index 5f539465..59d90b5c 100644 --- a/sample/client_credential_sample.py +++ b/sample/client_credential_sample.py @@ -4,8 +4,8 @@ { "authority": "https://login.microsoftonline.com/organizations", "client_id": "your_client_id", + "scope": ["https://graph.microsoft.com/.default"], "secret": "This is a sample only. You better NOT persist your password." - "scope": ["https://graph.microsoft.com/.default"] } You can then run this sample with a JSON configuration file: From 9311ccee741f6ff83aa509a88d4573e37a137701 Mon Sep 17 00:00:00 2001 From: Ray Luo Date: Wed, 20 Mar 2019 11:30:38 -0700 Subject: [PATCH 05/14] Add some convenient hint for username password flow --- sample/username_password_sample.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sample/username_password_sample.py b/sample/username_password_sample.py index 0137ae6e..e7555edd 100644 --- a/sample/username_password_sample.py +++ b/sample/username_password_sample.py @@ -57,4 +57,6 @@ print(result.get("error")) print(result.get("error_description")) print(result.get("correlation_id")) # You may need this when reporting a bug - + if 65001 in result.get("error_codes", []): # Not mean to be coded programatically, but... + # AAD requires user consent for U/P flow + print("Visit this to consent:", app.get_authorization_request_url(scope)) From cf9e30fa7024fba838625d0ece1c31262f095bae Mon Sep 17 00:00:00 2001 From: Ray Luo Date: Thu, 21 Mar 2019 14:29:23 -0700 Subject: [PATCH 06/14] Tidy up test_application.py --- tests/test_application.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_application.py b/tests/test_application.py index 180bef50..359086ee 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -110,10 +110,10 @@ def test_device_flow(self): CONFIG["client_id"], authority=CONFIG["authority"]) flow = self.app.initiate_device_flow(scopes=CONFIG.get("scope")) assert "user_code" in flow, str(flow) # Provision or policy might block DF - logging.warn(flow["message"]) + logging.warning(flow["message"]) duration = 30 - logging.warn("We will wait up to %d seconds for you to sign in" % duration) + logging.warning("We will wait up to %d seconds for you to sign in" % duration) flow["expires_at"] = time.time() + duration # Shorten the time for quick test result = self.app.acquire_token_by_device_flow(flow) self.assertLoosely( @@ -136,7 +136,7 @@ def setUpClass(cls): @unittest.skipUnless("scope" in CONFIG, "Missing scope") def test_auth_code(self): - from oauth2cli.authcode import obtain_auth_code + from msal.oauth2cli.authcode import obtain_auth_code port = CONFIG.get("listen_port", 44331) redirect_uri = "http://localhost:%s" % port auth_request_uri = self.app.get_authorization_request_url( From 30affebb3832252b060172a9b2405ed8be2c4a11 Mon Sep 17 00:00:00 2001 From: Ray Luo Date: Thu, 21 Mar 2019 14:31:10 -0700 Subject: [PATCH 07/14] Now get_accounts() ensures proper authority type --- msal/application.py | 12 ++++++++---- msal/token_cache.py | 8 ++++++-- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/msal/application.py b/msal/application.py index b667a4a8..9e08227a 100644 --- a/msal/application.py +++ b/msal/application.py @@ -236,17 +236,21 @@ def get_accounts(self, username=None): Your app can choose to display those information to end user, and allow user to choose one of his/her accounts to proceed. """ - # The following implementation finds accounts only from saved accounts, - # but does NOT correlate them with saved RTs. It probably won't matter, - # because in MSAL universe, there are always Accounts and RTs together. - accounts = self.token_cache.find( + accounts = [a for a in self.token_cache.find( # Find all useful accounts self.token_cache.CredentialType.ACCOUNT, query={"environment": self.authority.instance}) + if a["authority_type"] in ( + TokenCache.AuthorityType.ADFS, TokenCache.AuthorityType.MSSTS)] if username: # Federated account["username"] from AAD could contain mixed case lowercase_username = username.lower() accounts = [a for a in accounts if a["username"].lower() == lowercase_username] + # Does not further filter by existing RTs here. It probably won't matter. + # Because in most cases Accounts and RTs co-exist. + # Even in the rare case when an RT is revoked and then removed, + # acquire_token_silent() would then yield no result, + # apps would fall back to other acquire methods. This is the standard pattern. return accounts def acquire_token_silent( diff --git a/msal/token_cache.py b/msal/token_cache.py index c945345a..0353f9d0 100644 --- a/msal/token_cache.py +++ b/msal/token_cache.py @@ -31,6 +31,10 @@ class CredentialType: ACCOUNT = "Account" # Not exactly a credential type, but we put it here ID_TOKEN = "IdToken" + class AuthorityType: + ADFS = "ADFS" + MSSTS = "MSSTS" # MSSTS means AAD v2 for both AAD & MSA + def __init__(self): self._lock = threading.RLock() self._cache = {} @@ -118,8 +122,8 @@ def add(self, event, now=None): "oid", decoded_id_token.get("sub")), "username": decoded_id_token.get("preferred_username"), "authority_type": - "ADFS" if realm == "adfs" - else "MSSTS", # MSSTS means AAD v2 for both AAD & MSA + self.AuthorityType.ADFS if realm == "adfs" + else self.AuthorityType.MSSTS, # "client_info": response.get("client_info"), # Optional } From 699b2dc64c7d2c315bc194987bb8eb810a75be6a Mon Sep 17 00:00:00 2001 From: Ray Luo Date: Thu, 28 Mar 2019 16:41:30 -0700 Subject: [PATCH 08/14] Add _obtain_token(..., post=lambda ...) so you don't need to patch --- oauth2cli/oauth2.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/oauth2cli/oauth2.py b/oauth2cli/oauth2.py index e0096cb0..4d0cf61c 100644 --- a/oauth2cli/oauth2.py +++ b/oauth2cli/oauth2.py @@ -108,6 +108,9 @@ def _obtain_token( # The verb "obtain" is influenced by OAUTH2 RFC 6749 data=None, # All relevant data, which will go into the http body headers=None, # a dict to be sent as request headers timeout=None, + post=None, # A callable to replace requests.post(), for testing. + # Such as: lambda url, **kwargs: + # Mock(status_code=200, json=Mock(return_value={})) **kwargs # Relay all extra parameters to underlying requests ): # Returns the json object came from the OAUTH2 response _data = {'client_id': self.client_id, 'grant_type': grant_type} @@ -133,7 +136,7 @@ def _obtain_token( # The verb "obtain" is influenced by OAUTH2 RFC 6749 raise ValueError("token_endpoint not found in configuration") _headers = {'Accept': 'application/json'} _headers.update(headers or {}) - resp = self.session.post( + resp = (post or self.session.post)( self.configuration["token_endpoint"], headers=_headers, params=params, data=_data, auth=auth, timeout=timeout or self.timeout, From dac051a6a4ceece32d7ae2828073aeef7493c33a Mon Sep 17 00:00:00 2001 From: Ray Luo Date: Thu, 28 Mar 2019 16:43:09 -0700 Subject: [PATCH 09/14] Add method-level on_removing_rt trigger --- oauth2cli/oauth2.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/oauth2cli/oauth2.py b/oauth2cli/oauth2.py index 4d0cf61c..b9727cf5 100644 --- a/oauth2cli/oauth2.py +++ b/oauth2cli/oauth2.py @@ -396,16 +396,18 @@ def _obtain_token(self, grant_type, params=None, data=None, *args, **kwargs): def obtain_token_by_refresh_token(self, token_item, scope=None, rt_getter=lambda token_item: token_item["refresh_token"], + on_removing_rt=None, **kwargs): # type: (Union[str, dict], Union[str, list, set, tuple], Callable) -> dict """This is an "overload" which accepts a refresh token item as a dict, therefore this method can relay refresh_token item to event listeners. - :param refresh_token_item: A refresh token item came from storage + :param token_item: A refresh token item came from storage :param scope: If omitted, is treated as equal to the scope originally granted by the resource ownser, according to https://tools.ietf.org/html/rfc6749#section-6 :param rt_getter: A callable used to extract the RT from token_item + :param on_removing_rt: If absent, fall back to the one defined in initialization """ if isinstance(token_item, str): # Satisfy the L of SOLID, although we expect caller uses a dict @@ -415,7 +417,7 @@ def obtain_token_by_refresh_token(self, token_item, scope=None, resp = super(Client, self).obtain_token_by_refresh_token( rt_getter(token_item), scope=scope, **kwargs) if resp.get('error') == 'invalid_grant': - self.on_removing_rt(token_item) # Discard old RT + (on_removing_rt or self.on_removing_rt)(token_item) # Discard old RT if 'refresh_token' in resp: self.on_updating_rt(token_item, resp['refresh_token']) return resp From f5d30d0e8701029ccec5a927b9981b72a55b0b95 Mon Sep 17 00:00:00 2001 From: Ray Luo Date: Thu, 28 Mar 2019 15:12:59 -0700 Subject: [PATCH 10/14] Cache will record AppMetadata from now on --- msal/token_cache.py | 12 ++++++++++++ tests/test_token_cache.py | 9 +++++++++ 2 files changed, 21 insertions(+) diff --git a/msal/token_cache.py b/msal/token_cache.py index 0353f9d0..8fd79e59 100644 --- a/msal/token_cache.py +++ b/msal/token_cache.py @@ -30,6 +30,7 @@ class CredentialType: REFRESH_TOKEN = "RefreshToken" ACCOUNT = "Account" # Not exactly a credential type, but we put it here ID_TOKEN = "IdToken" + APP_METADATA = "AppMetadata" class AuthorityType: ADFS = "ADFS" @@ -162,6 +163,17 @@ def add(self, event, now=None): rt["family_id"] = response["foci"] self._cache.setdefault(self.CredentialType.REFRESH_TOKEN, {})[key] = rt + key = self._build_appmetadata_key(environment, event.get("client_id")) + self._cache.setdefault(self.CredentialType.APP_METADATA, {})[key] = { + "client_id": event.get("client_id"), + "environment": environment, + "family_id": response.get("foci"), # None is also valid + } + + @staticmethod + def _build_appmetadata_key(environment, client_id): + return "appmetadata-{}-{}".format(environment or "", client_id or "") + @classmethod def _build_rt_key( cls, diff --git a/tests/test_token_cache.py b/tests/test_token_cache.py index ce5c3063..1fac231e 100644 --- a/tests/test_token_cache.py +++ b/tests/test_token_cache.py @@ -88,6 +88,15 @@ def testAdd(self): self.cache._cache["IdToken"].get( 'uid.utid-login.example.com-idtoken-my_client_id-contoso-') ) + self.assertEqual( + { + "client_id": "my_client_id", + 'environment': 'login.example.com', + "family_id": None, + }, + self.cache._cache.get("AppMetadata", {}).get( + "appmetadata-login.example.com-my_client_id") + ) class SerializableTokenCacheTestCase(TokenCacheTestCase): From 56d8ddee9043d3cafb676b8d7bfbf102a06681ae Mon Sep 17 00:00:00 2001 From: Ray Luo Date: Thu, 28 Mar 2019 15:20:50 -0700 Subject: [PATCH 11/14] Refactor tests to provide 2 helpers --- tests/test_token_cache.py | 59 ++++++++++++++++++++++++++++----------- 1 file changed, 43 insertions(+), 16 deletions(-) diff --git a/tests/test_token_cache.py b/tests/test_token_cache.py index 1fac231e..f3771c38 100644 --- a/tests/test_token_cache.py +++ b/tests/test_token_cache.py @@ -12,30 +12,57 @@ class TokenCacheTestCase(unittest.TestCase): + @staticmethod + def build_id_token(sub="sub", oid="oid", preferred_username="me", **kwargs): + return "header.%s.signature" % base64.b64encode(json.dumps(dict({ + "sub": sub, + "oid": oid, + "preferred_username": preferred_username, + }, **kwargs)).encode()).decode('utf-8') + + @staticmethod + def build_response( # simulate a response from AAD + uid="uid", utid="utid", # They will form client_info + access_token=None, expires_in=3600, token_type="some type", + refresh_token=None, + foci=None, + id_token=None, # or something generated by build_id_token() + error=None, + ): + response = { + "client_info": base64.b64encode(json.dumps({ + "uid": uid, "utid": utid, + }).encode()).decode('utf-8'), + } + if error: + response["error"] = error + if access_token: + response.update({ + "access_token": access_token, + "expires_in": expires_in, + "token_type": token_type, + }) + if refresh_token: + response["refresh_token"] = refresh_token + if id_token: + response["id_token"] = id_token + if foci: + response["foci"] = foci + return response + def setUp(self): self.cache = TokenCache() def testAdd(self): - client_info = base64.b64encode(b''' - {"uid": "uid", "utid": "utid"} - ''').decode('utf-8') - id_token = "header.%s.signature" % base64.b64encode(b'''{ - "sub": "subject", - "oid": "object1234", - "preferred_username": "John Doe" - }''').decode('utf-8') + id_token = self.build_id_token(oid="object1234", preferred_username="John Doe") self.cache.add({ "client_id": "my_client_id", "scope": ["s2", "s1", "s3"], # Not in particular order "token_endpoint": "https://login.example.com/contoso/v2/token", - "response": { - "access_token": "an access token", - "token_type": "some type", - "expires_in": 3600, - "refresh_token": "a refresh token", - "client_info": client_info, - "id_token": id_token, - }, + "response": self.build_response( + uid="uid", utid="utid", # client_info + expires_in=3600, access_token="an access token", + id_token=id_token, refresh_token="a refresh token"), }, now=1000) self.assertEqual( { From afa37b1e3e4baab8ab6bf0a0ed5ad470cbb51f85 Mon Sep 17 00:00:00 2001 From: Ray Luo Date: Thu, 28 Mar 2019 16:28:41 -0700 Subject: [PATCH 12/14] FOCI Single Sign On --- msal/application.py | 63 ++++++++++++++++++++++++----- requirements.txt | 1 + tests/test_application.py | 84 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 139 insertions(+), 9 deletions(-) diff --git a/msal/application.py b/msal/application.py index 9e08227a..ee440a6d 100644 --- a/msal/application.py +++ b/msal/application.py @@ -305,26 +305,71 @@ def acquire_token_silent( "token_type": "Bearer", "expires_in": int(expires_in), # OAuth2 specs defines it as int } + return self._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family( + the_authority, decorate_scope(scopes, self.client_id), account, + **kwargs) + def _acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family( + self, authority, scopes, account, **kwargs): + query = { + "environment": authority.instance, + "home_account_id": (account or {}).get("home_account_id"), + # "realm": authority.tenant, # AAD RTs are tenant-independent + } + apps = self.token_cache.find( # Use find(), rather than token_cache.get(...) + TokenCache.CredentialType.APP_METADATA, query={ + "environment": authority.instance, "client_id": self.client_id}) + app_metadata = apps[0] if apps else {} + if not app_metadata: # Meaning this app is now used for the first time. + # When/if we have a way to directly detect current app's family, + # we'll rewrite this block, to support multiple families. + # For now, we try existing RTs (*). If it works, we are in that family. + # (*) RTs of a different app/family are not supposed to be + # shared with or accessible by us in the first place. + at = self._acquire_token_silent_by_finding_specific_refresh_token( + authority, scopes, + dict(query, family_id="1"), # A hack, we have only 1 family for now + rt_remover=lambda rt_item: None, # NO-OP b/c RTs are likely not mine + break_condition=lambda response: # Break loop when app not in family + # Based on an AAD-only behavior mentioned in internal doc here + # https://msazure.visualstudio.com/One/_git/ESTS-Docs/pullrequest/1138595 + "client_mismatch" in response.get("error_additional_info", []), + **kwargs) + if at: + return at + if app_metadata.get("family_id"): # Meaning this app belongs to this family + at = self._acquire_token_silent_by_finding_specific_refresh_token( + authority, scopes, dict(query, family_id=app_metadata["family_id"]), + **kwargs) + if at: + return at + # Either this app is an orphan, so we will naturally use its own RT; + # or all attempts above have failed, so we fall back to non-foci behavior. + return self._acquire_token_silent_by_finding_specific_refresh_token( + authority, scopes, dict(query, client_id=self.client_id), **kwargs) + + def _acquire_token_silent_by_finding_specific_refresh_token( + self, authority, scopes, query, + rt_remover=None, break_condition=lambda response: False, **kwargs): matches = self.token_cache.find( self.token_cache.CredentialType.REFRESH_TOKEN, # target=scopes, # AAD RTs are scope-independent - query={ - "client_id": self.client_id, - "environment": the_authority.instance, - "home_account_id": (account or {}).get("home_account_id"), - # "realm": the_authority.tenant, # AAD RTs are tenant-independent - }) - client = self._build_client(self.client_credential, the_authority) + query=query) + logger.debug("Found %d RTs matching %s", len(matches), query) + client = self._build_client(self.client_credential, authority) for entry in matches: - logger.debug("Cache hit an RT") + logger.debug("Cache attempts an RT") response = client.obtain_token_by_refresh_token( entry, rt_getter=lambda token_item: token_item["secret"], - scope=decorate_scope(scopes, self.client_id)) + on_removing_rt=rt_remover or self.token_cache.remove_rt, + scope=scopes, + **kwargs) if "error" not in response: return response logger.debug( "Refresh failed. {error}: {error_description}".format(**response)) + if break_condition(response): + break class PublicClientApplication(ClientApplication): # browser app or mobile app diff --git a/requirements.txt b/requirements.txt index 9c558e35..61a6510d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1 +1,2 @@ . +mock; python_version < '3.3' diff --git a/tests/test_application.py b/tests/test_application.py index 359086ee..29707acd 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -2,8 +2,15 @@ import json import logging +try: + from unittest.mock import * # Python 3 +except: + from mock import * # Need an external mock package + from msal.application import * +import msal from tests import unittest +from tests.test_token_cache import TokenCacheTestCase THIS_FOLDER = os.path.dirname(__file__) @@ -155,3 +162,80 @@ def test_auth_code(self): error_description=result.get("error_description"))) self.assertCacheWorks(result) + +class TestClientApplicationAcquireTokenSilentFociBehaviors(unittest.TestCase): + + def setUp(self): + self.authority_url = "https://login.microsoftonline.com/common" + self.authority = msal.authority.Authority(self.authority_url) + self.scopes = ["s1", "s2"] + self.uid = "my_uid" + self.utid = "my_utid" + self.account = {"home_account_id": "{}.{}".format(self.uid, self.utid)} + self.frt = "what the frt" + self.cache = msal.SerializableTokenCache() + self.cache.add({ # Pre-populate a FRT + "client_id": "preexisting_family_app", + "scope": self.scopes, + "token_endpoint": "{}/oauth2/v2.0/token".format(self.authority_url), + "response": TokenCacheTestCase.build_response( + uid=self.uid, utid=self.utid, refresh_token=self.frt, foci="1"), + }) # The add(...) helper populates correct home_account_id for future searching + + def test_unknown_orphan_app_will_attempt_frt_and_not_remove_it(self): + app = ClientApplication( + "unknown_orphan", authority=self.authority_url, token_cache=self.cache) + logger.debug("%s.cache = %s", self.id(), self.cache.serialize()) + def tester(url, data=None, **kwargs): + self.assertEqual(self.frt, data.get("refresh_token"), "Should attempt the FRT") + return Mock(status_code=200, json=Mock(return_value={ + "error": "invalid_grant", + "error_description": "Was issued to another client"})) + app._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family( + self.authority, self.scopes, self.account, post=tester) + self.assertNotEqual([], app.token_cache.find( + msal.TokenCache.CredentialType.REFRESH_TOKEN, query={"secret": self.frt}), + "The FRT should not be removed from the cache") + + def test_known_orphan_app_will_skip_frt_and_only_use_its_own_rt(self): + app = ClientApplication( + "known_orphan", authority=self.authority_url, token_cache=self.cache) + rt = "RT for this orphan app. We will check it being used by this test case." + self.cache.add({ # Populate its RT and AppMetadata, so it becomes a known orphan app + "client_id": app.client_id, + "scope": self.scopes, + "token_endpoint": "{}/oauth2/v2.0/token".format(self.authority_url), + "response": TokenCacheTestCase.build_response( + uid=self.uid, utid=self.utid, refresh_token=rt), + }) + logger.debug("%s.cache = %s", self.id(), self.cache.serialize()) + def tester(url, data=None, **kwargs): + self.assertEqual(rt, data.get("refresh_token"), "Should attempt the RT") + return Mock(status_code=200, json=Mock(return_value={})) + app._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family( + self.authority, self.scopes, self.account, post=tester) + + def test_unknown_family_app_will_attempt_frt_and_join_family(self): + def tester(url, data=None, **kwargs): + self.assertEqual( + self.frt, data.get("refresh_token"), "Should attempt the FRT") + return Mock( + status_code=200, + json=Mock(return_value=TokenCacheTestCase.build_response( + uid=self.uid, utid=self.utid, foci="1", access_token="at"))) + app = ClientApplication( + "unknown_family_app", authority=self.authority_url, token_cache=self.cache) + at = app._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family( + self.authority, self.scopes, self.account, post=tester) + logger.debug("%s.cache = %s", self.id(), self.cache.serialize()) + self.assertEqual("at", at.get("access_token"), "New app should get a new AT") + app_metadata = app.token_cache.find( + msal.TokenCache.CredentialType.APP_METADATA, + query={"client_id": app.client_id}) + self.assertNotEqual([], app_metadata, "Should record new app's metadata") + self.assertEqual("1", app_metadata[0].get("family_id"), + "The new family app should be recorded as in the same family") + # Known family app will simply use FRT, which is largely the same as this one + + # Will not test scenario of app leaving family. Per specs, it won't happen. + From f2c2fa11c1c5a9c10837fc20590d6e39006fa01b Mon Sep 17 00:00:00 2001 From: Abhidnya Date: Tue, 2 Apr 2019 13:35:20 -0700 Subject: [PATCH 13/14] Reading Authority Aliases (#25) After testing it manually with .NET for cross platform sharing for token cache, we confirmed that the .NET test cases pass :) --- msal/application.py | 72 ++++++++++++++++++++++++++++++++------- tests/test_application.py | 40 ++++++++++++++++++++++ 2 files changed, 99 insertions(+), 13 deletions(-) diff --git a/msal/application.py b/msal/application.py index ee440a6d..624fbe3d 100644 --- a/msal/application.py +++ b/msal/application.py @@ -5,6 +5,9 @@ from urllib.parse import urljoin import logging import sys +import warnings + +import requests from .oauth2cli import Client, JwtSigner from .authority import Authority @@ -101,6 +104,14 @@ def __init__( # Here the self.authority is not the same type as authority in input self.token_cache = token_cache or TokenCache() self.client = self._build_client(client_credential, self.authority) + self.authority_groups = self._get_authority_aliases() + + def _get_authority_aliases(self): + resp = requests.get( + "https://login.microsoftonline.com/common/discovery/instance?api-version=1.1&authorization_endpoint=https://login.microsoftonline.com/common/oauth2/authorize", + headers={'Accept': 'application/json'}) + resp.raise_for_status() + return [set(group['aliases']) for group in resp.json()['metadata']] def _build_client(self, client_credential, authority): client_assertion = None @@ -236,11 +247,15 @@ def get_accounts(self, username=None): Your app can choose to display those information to end user, and allow user to choose one of his/her accounts to proceed. """ - accounts = [a for a in self.token_cache.find( # Find all useful accounts - self.token_cache.CredentialType.ACCOUNT, - query={"environment": self.authority.instance}) - if a["authority_type"] in ( - TokenCache.AuthorityType.ADFS, TokenCache.AuthorityType.MSSTS)] + accounts = self._find_msal_accounts(environment=self.authority.instance) + if not accounts: # Now try other aliases of this authority instance + for group in self.authority_groups: + if self.authority.instance in group: + for alias in group: + if alias != self.authority.instance: + accounts = self._find_msal_accounts(environment=alias) + if accounts: + break if username: # Federated account["username"] from AAD could contain mixed case lowercase_username = username.lower() @@ -253,6 +268,12 @@ def get_accounts(self, username=None): # apps would fall back to other acquire methods. This is the standard pattern. return accounts + def _find_msal_accounts(self, environment): + return [a for a in self.token_cache.find( + TokenCache.CredentialType.ACCOUNT, query={"environment": environment}) + if a["authority_type"] in ( + TokenCache.AuthorityType.ADFS, TokenCache.AuthorityType.MSSTS)] + def acquire_token_silent( self, scopes, # type: List[str] @@ -279,19 +300,44 @@ def acquire_token_silent( - None when cache lookup does not yield anything. """ assert isinstance(scopes, list), "Invalid parameter type" - the_authority = Authority( - authority, - verify=self.verify, proxies=self.proxies, timeout=self.timeout, - ) if authority else self.authority - + if authority: + warnings.warn("We haven't decided how/if this method will accept authority parameter") + # the_authority = Authority( + # authority, + # verify=self.verify, proxies=self.proxies, timeout=self.timeout, + # ) if authority else self.authority + result = self._acquire_token_silent(scopes, account, self.authority, **kwargs) + if result: + return result + for group in self.authority_groups: + if self.authority.instance in group: + for alias in group: + if alias != self.authority.instance: + the_authority = Authority( + "https://" + alias + "/" + self.authority.tenant, + validate_authority=False, + verify=self.verify, proxies=self.proxies, + timeout=self.timeout,) + result = self._acquire_token_silent( + scopes, account, the_authority, **kwargs) + if result: + return result + + def _acquire_token_silent( + self, + scopes, # type: List[str] + account, # type: Optional[Account] + authority, # This can be different than self.authority + force_refresh=False, # type: Optional[boolean] + **kwargs): if not force_refresh: matches = self.token_cache.find( self.token_cache.CredentialType.ACCESS_TOKEN, target=scopes, query={ "client_id": self.client_id, - "environment": the_authority.instance, - "realm": the_authority.tenant, + "environment": authority.instance, + "realm": authority.tenant, "home_account_id": (account or {}).get("home_account_id"), }) now = time.time() @@ -306,7 +352,7 @@ def acquire_token_silent( "expires_in": int(expires_in), # OAuth2 specs defines it as int } return self._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family( - the_authority, decorate_scope(scopes, self.client_id), account, + authority, decorate_scope(scopes, self.client_id), account, **kwargs) def _acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family( diff --git a/tests/test_application.py b/tests/test_application.py index 29707acd..6346774a 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -239,3 +239,43 @@ def tester(url, data=None, **kwargs): # Will not test scenario of app leaving family. Per specs, it won't happen. +class TestClientApplicationForAuthorityMigration(unittest.TestCase): + + @classmethod + def setUp(self): + self.environment_in_cache = "sts.windows.net" + self.authority_url_in_app = "https://login.microsoftonline.com/common" + self.scopes = ["s1", "s2"] + uid = "uid" + utid = "utid" + self.account = {"home_account_id": "{}.{}".format(uid, utid)} + self.client_id = "my_app" + self.access_token = "access token for testing authority aliases" + self.cache = msal.SerializableTokenCache() + self.cache.add({ + "client_id": self.client_id, + "scope": self.scopes, + "token_endpoint": "https://{}/common/oauth2/v2.0/token".format( + self.environment_in_cache), + "response": TokenCacheTestCase.build_response( + uid=uid, utid=utid, + access_token=self.access_token, refresh_token="some refresh token"), + }) # The add(...) helper populates correct home_account_id for future searching + + def test_get_accounts(self): + app = ClientApplication( + self.client_id, + authority=self.authority_url_in_app, token_cache=self.cache) + accounts = app.get_accounts() + self.assertNotEqual([], accounts) + self.assertEqual(self.environment_in_cache, accounts[0].get("environment"), + "We should be able to find an account under an authority alias") + + def test_acquire_token_silent(self): + app = ClientApplication( + self.client_id, + authority=self.authority_url_in_app, token_cache=self.cache) + at = app.acquire_token_silent(self.scopes, self.account) + self.assertNotEqual(None, at) + self.assertEqual(self.access_token, at.get('access_token')) + From 48ba43bd5c06c6c6054afbb19872bcb47be787ca Mon Sep 17 00:00:00 2001 From: Abhidnya Date: Tue, 2 Apr 2019 13:48:51 -0700 Subject: [PATCH 14/14] MSAL Python 0.3.0 Bumping version number --- msal/application.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/msal/application.py b/msal/application.py index 624fbe3d..31560f45 100644 --- a/msal/application.py +++ b/msal/application.py @@ -18,7 +18,7 @@ # The __init__.py will import this. Not the other way around. -__version__ = "0.2.0" +__version__ = "0.3.0" logger = logging.getLogger(__name__)