diff --git a/msal/application.py b/msal/application.py index b667a4a8..31560f45 100644 --- a/msal/application.py +++ b/msal/application.py @@ -5,6 +5,9 @@ from urllib.parse import urljoin import logging import sys +import warnings + +import requests from .oauth2cli import Client, JwtSigner from .authority import Authority @@ -15,7 +18,7 @@ # The __init__.py will import this. Not the other way around. -__version__ = "0.2.0" +__version__ = "0.3.0" logger = logging.getLogger(__name__) @@ -101,6 +104,14 @@ def __init__( # Here the self.authority is not the same type as authority in input self.token_cache = token_cache or TokenCache() self.client = self._build_client(client_credential, self.authority) + self.authority_groups = self._get_authority_aliases() + + def _get_authority_aliases(self): + resp = requests.get( + "https://login.microsoftonline.com/common/discovery/instance?api-version=1.1&authorization_endpoint=https://login.microsoftonline.com/common/oauth2/authorize", + headers={'Accept': 'application/json'}) + resp.raise_for_status() + return [set(group['aliases']) for group in resp.json()['metadata']] def _build_client(self, client_credential, authority): client_assertion = None @@ -236,19 +247,33 @@ def get_accounts(self, username=None): Your app can choose to display those information to end user, and allow user to choose one of his/her accounts to proceed. """ - # The following implementation finds accounts only from saved accounts, - # but does NOT correlate them with saved RTs. It probably won't matter, - # because in MSAL universe, there are always Accounts and RTs together. - accounts = self.token_cache.find( - self.token_cache.CredentialType.ACCOUNT, - query={"environment": self.authority.instance}) + accounts = self._find_msal_accounts(environment=self.authority.instance) + if not accounts: # Now try other aliases of this authority instance + for group in self.authority_groups: + if self.authority.instance in group: + for alias in group: + if alias != self.authority.instance: + accounts = self._find_msal_accounts(environment=alias) + if accounts: + break if username: # Federated account["username"] from AAD could contain mixed case lowercase_username = username.lower() accounts = [a for a in accounts if a["username"].lower() == lowercase_username] + # Does not further filter by existing RTs here. It probably won't matter. + # Because in most cases Accounts and RTs co-exist. + # Even in the rare case when an RT is revoked and then removed, + # acquire_token_silent() would then yield no result, + # apps would fall back to other acquire methods. This is the standard pattern. return accounts + def _find_msal_accounts(self, environment): + return [a for a in self.token_cache.find( + TokenCache.CredentialType.ACCOUNT, query={"environment": environment}) + if a["authority_type"] in ( + TokenCache.AuthorityType.ADFS, TokenCache.AuthorityType.MSSTS)] + def acquire_token_silent( self, scopes, # type: List[str] @@ -275,19 +300,44 @@ def acquire_token_silent( - None when cache lookup does not yield anything. """ assert isinstance(scopes, list), "Invalid parameter type" - the_authority = Authority( - authority, - verify=self.verify, proxies=self.proxies, timeout=self.timeout, - ) if authority else self.authority - + if authority: + warnings.warn("We haven't decided how/if this method will accept authority parameter") + # the_authority = Authority( + # authority, + # verify=self.verify, proxies=self.proxies, timeout=self.timeout, + # ) if authority else self.authority + result = self._acquire_token_silent(scopes, account, self.authority, **kwargs) + if result: + return result + for group in self.authority_groups: + if self.authority.instance in group: + for alias in group: + if alias != self.authority.instance: + the_authority = Authority( + "https://" + alias + "/" + self.authority.tenant, + validate_authority=False, + verify=self.verify, proxies=self.proxies, + timeout=self.timeout,) + result = self._acquire_token_silent( + scopes, account, the_authority, **kwargs) + if result: + return result + + def _acquire_token_silent( + self, + scopes, # type: List[str] + account, # type: Optional[Account] + authority, # This can be different than self.authority + force_refresh=False, # type: Optional[boolean] + **kwargs): if not force_refresh: matches = self.token_cache.find( self.token_cache.CredentialType.ACCESS_TOKEN, target=scopes, query={ "client_id": self.client_id, - "environment": the_authority.instance, - "realm": the_authority.tenant, + "environment": authority.instance, + "realm": authority.tenant, "home_account_id": (account or {}).get("home_account_id"), }) now = time.time() @@ -301,26 +351,71 @@ def acquire_token_silent( "token_type": "Bearer", "expires_in": int(expires_in), # OAuth2 specs defines it as int } + return self._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family( + authority, decorate_scope(scopes, self.client_id), account, + **kwargs) + def _acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family( + self, authority, scopes, account, **kwargs): + query = { + "environment": authority.instance, + "home_account_id": (account or {}).get("home_account_id"), + # "realm": authority.tenant, # AAD RTs are tenant-independent + } + apps = self.token_cache.find( # Use find(), rather than token_cache.get(...) + TokenCache.CredentialType.APP_METADATA, query={ + "environment": authority.instance, "client_id": self.client_id}) + app_metadata = apps[0] if apps else {} + if not app_metadata: # Meaning this app is now used for the first time. + # When/if we have a way to directly detect current app's family, + # we'll rewrite this block, to support multiple families. + # For now, we try existing RTs (*). If it works, we are in that family. + # (*) RTs of a different app/family are not supposed to be + # shared with or accessible by us in the first place. + at = self._acquire_token_silent_by_finding_specific_refresh_token( + authority, scopes, + dict(query, family_id="1"), # A hack, we have only 1 family for now + rt_remover=lambda rt_item: None, # NO-OP b/c RTs are likely not mine + break_condition=lambda response: # Break loop when app not in family + # Based on an AAD-only behavior mentioned in internal doc here + # https://msazure.visualstudio.com/One/_git/ESTS-Docs/pullrequest/1138595 + "client_mismatch" in response.get("error_additional_info", []), + **kwargs) + if at: + return at + if app_metadata.get("family_id"): # Meaning this app belongs to this family + at = self._acquire_token_silent_by_finding_specific_refresh_token( + authority, scopes, dict(query, family_id=app_metadata["family_id"]), + **kwargs) + if at: + return at + # Either this app is an orphan, so we will naturally use its own RT; + # or all attempts above have failed, so we fall back to non-foci behavior. + return self._acquire_token_silent_by_finding_specific_refresh_token( + authority, scopes, dict(query, client_id=self.client_id), **kwargs) + + def _acquire_token_silent_by_finding_specific_refresh_token( + self, authority, scopes, query, + rt_remover=None, break_condition=lambda response: False, **kwargs): matches = self.token_cache.find( self.token_cache.CredentialType.REFRESH_TOKEN, # target=scopes, # AAD RTs are scope-independent - query={ - "client_id": self.client_id, - "environment": the_authority.instance, - "home_account_id": (account or {}).get("home_account_id"), - # "realm": the_authority.tenant, # AAD RTs are tenant-independent - }) - client = self._build_client(self.client_credential, the_authority) + query=query) + logger.debug("Found %d RTs matching %s", len(matches), query) + client = self._build_client(self.client_credential, authority) for entry in matches: - logger.debug("Cache hit an RT") + logger.debug("Cache attempts an RT") response = client.obtain_token_by_refresh_token( entry, rt_getter=lambda token_item: token_item["secret"], - scope=decorate_scope(scopes, self.client_id)) + on_removing_rt=rt_remover or self.token_cache.remove_rt, + scope=scopes, + **kwargs) if "error" not in response: return response logger.debug( "Refresh failed. {error}: {error_description}".format(**response)) + if break_condition(response): + break class PublicClientApplication(ClientApplication): # browser app or mobile app diff --git a/msal/oauth2cli/oauth2.py b/msal/oauth2cli/oauth2.py index e0096cb0..b9727cf5 100644 --- a/msal/oauth2cli/oauth2.py +++ b/msal/oauth2cli/oauth2.py @@ -108,6 +108,9 @@ def _obtain_token( # The verb "obtain" is influenced by OAUTH2 RFC 6749 data=None, # All relevant data, which will go into the http body headers=None, # a dict to be sent as request headers timeout=None, + post=None, # A callable to replace requests.post(), for testing. + # Such as: lambda url, **kwargs: + # Mock(status_code=200, json=Mock(return_value={})) **kwargs # Relay all extra parameters to underlying requests ): # Returns the json object came from the OAUTH2 response _data = {'client_id': self.client_id, 'grant_type': grant_type} @@ -133,7 +136,7 @@ def _obtain_token( # The verb "obtain" is influenced by OAUTH2 RFC 6749 raise ValueError("token_endpoint not found in configuration") _headers = {'Accept': 'application/json'} _headers.update(headers or {}) - resp = self.session.post( + resp = (post or self.session.post)( self.configuration["token_endpoint"], headers=_headers, params=params, data=_data, auth=auth, timeout=timeout or self.timeout, @@ -393,16 +396,18 @@ def _obtain_token(self, grant_type, params=None, data=None, *args, **kwargs): def obtain_token_by_refresh_token(self, token_item, scope=None, rt_getter=lambda token_item: token_item["refresh_token"], + on_removing_rt=None, **kwargs): # type: (Union[str, dict], Union[str, list, set, tuple], Callable) -> dict """This is an "overload" which accepts a refresh token item as a dict, therefore this method can relay refresh_token item to event listeners. - :param refresh_token_item: A refresh token item came from storage + :param token_item: A refresh token item came from storage :param scope: If omitted, is treated as equal to the scope originally granted by the resource ownser, according to https://tools.ietf.org/html/rfc6749#section-6 :param rt_getter: A callable used to extract the RT from token_item + :param on_removing_rt: If absent, fall back to the one defined in initialization """ if isinstance(token_item, str): # Satisfy the L of SOLID, although we expect caller uses a dict @@ -412,7 +417,7 @@ def obtain_token_by_refresh_token(self, token_item, scope=None, resp = super(Client, self).obtain_token_by_refresh_token( rt_getter(token_item), scope=scope, **kwargs) if resp.get('error') == 'invalid_grant': - self.on_removing_rt(token_item) # Discard old RT + (on_removing_rt or self.on_removing_rt)(token_item) # Discard old RT if 'refresh_token' in resp: self.on_updating_rt(token_item, resp['refresh_token']) return resp diff --git a/msal/token_cache.py b/msal/token_cache.py index 116be878..8fd79e59 100644 --- a/msal/token_cache.py +++ b/msal/token_cache.py @@ -30,6 +30,11 @@ class CredentialType: REFRESH_TOKEN = "RefreshToken" ACCOUNT = "Account" # Not exactly a credential type, but we put it here ID_TOKEN = "IdToken" + APP_METADATA = "AppMetadata" + + class AuthorityType: + ADFS = "ADFS" + MSSTS = "MSSTS" # MSSTS means AAD v2 for both AAD & MSA def __init__(self): self._lock = threading.RLock() @@ -118,8 +123,8 @@ def add(self, event, now=None): "oid", decoded_id_token.get("sub")), "username": decoded_id_token.get("preferred_username"), "authority_type": - "ADFS" if realm == "adfs" - else "MSSTS", # MSSTS means AAD v2 for both AAD & MSA + self.AuthorityType.ADFS if realm == "adfs" + else self.AuthorityType.MSSTS, # "client_info": response.get("client_info"), # Optional } @@ -158,6 +163,17 @@ def add(self, event, now=None): rt["family_id"] = response["foci"] self._cache.setdefault(self.CredentialType.REFRESH_TOKEN, {})[key] = rt + key = self._build_appmetadata_key(environment, event.get("client_id")) + self._cache.setdefault(self.CredentialType.APP_METADATA, {})[key] = { + "client_id": event.get("client_id"), + "environment": environment, + "family_id": response.get("foci"), # None is also valid + } + + @staticmethod + def _build_appmetadata_key(environment, client_id): + return "appmetadata-{}-{}".format(environment or "", client_id or "") + @classmethod def _build_rt_key( cls, @@ -192,21 +208,24 @@ class SerializableTokenCache(TokenCache): Depending on your need, the following simple recipe for file-based persistence may be sufficient:: - import atexit - cache = SerializableTokenCache() - cache.deserialize(open("my_cache.bin", "rb").read()) + import os, atexit, msal + cache = msal.SerializableTokenCache() + if os.path.exists("my_cache.bin"): + cache.deserialize(open("my_cache.bin", "r").read()) atexit.register(lambda: - open("my_cache.bin", "wb").write(cache.serialize()) + open("my_cache.bin", "w").write(cache.serialize()) # Hint: The following optional line persists only when state changed if cache.has_state_changed else None ) - app = ClientApplication(..., token_cache=cache) + app = msal.ClientApplication(..., token_cache=cache) ... :var bool has_state_changed: Indicates whether the cache state has changed since last :func:`~serialize` or :func:`~deserialize` call. """ + has_state_changed = False + def add(self, event, **kwargs): super(SerializableTokenCache, self).add(event, **kwargs) self.has_state_changed = True diff --git a/requirements.txt b/requirements.txt index 9c558e35..61a6510d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1 +1,2 @@ . +mock; python_version < '3.3' diff --git a/sample/client_credential_sample.py b/sample/client_credential_sample.py index 5f539465..59d90b5c 100644 --- a/sample/client_credential_sample.py +++ b/sample/client_credential_sample.py @@ -4,8 +4,8 @@ { "authority": "https://login.microsoftonline.com/organizations", "client_id": "your_client_id", + "scope": ["https://graph.microsoft.com/.default"], "secret": "This is a sample only. You better NOT persist your password." - "scope": ["https://graph.microsoft.com/.default"] } You can then run this sample with a JSON configuration file: diff --git a/sample/username_password_sample.py b/sample/username_password_sample.py index 0137ae6e..e7555edd 100644 --- a/sample/username_password_sample.py +++ b/sample/username_password_sample.py @@ -57,4 +57,6 @@ print(result.get("error")) print(result.get("error_description")) print(result.get("correlation_id")) # You may need this when reporting a bug - + if 65001 in result.get("error_codes", []): # Not mean to be coded programatically, but... + # AAD requires user consent for U/P flow + print("Visit this to consent:", app.get_authorization_request_url(scope)) diff --git a/tests/test_application.py b/tests/test_application.py index 180bef50..6346774a 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -2,8 +2,15 @@ import json import logging +try: + from unittest.mock import * # Python 3 +except: + from mock import * # Need an external mock package + from msal.application import * +import msal from tests import unittest +from tests.test_token_cache import TokenCacheTestCase THIS_FOLDER = os.path.dirname(__file__) @@ -110,10 +117,10 @@ def test_device_flow(self): CONFIG["client_id"], authority=CONFIG["authority"]) flow = self.app.initiate_device_flow(scopes=CONFIG.get("scope")) assert "user_code" in flow, str(flow) # Provision or policy might block DF - logging.warn(flow["message"]) + logging.warning(flow["message"]) duration = 30 - logging.warn("We will wait up to %d seconds for you to sign in" % duration) + logging.warning("We will wait up to %d seconds for you to sign in" % duration) flow["expires_at"] = time.time() + duration # Shorten the time for quick test result = self.app.acquire_token_by_device_flow(flow) self.assertLoosely( @@ -136,7 +143,7 @@ def setUpClass(cls): @unittest.skipUnless("scope" in CONFIG, "Missing scope") def test_auth_code(self): - from oauth2cli.authcode import obtain_auth_code + from msal.oauth2cli.authcode import obtain_auth_code port = CONFIG.get("listen_port", 44331) redirect_uri = "http://localhost:%s" % port auth_request_uri = self.app.get_authorization_request_url( @@ -155,3 +162,120 @@ def test_auth_code(self): error_description=result.get("error_description"))) self.assertCacheWorks(result) + +class TestClientApplicationAcquireTokenSilentFociBehaviors(unittest.TestCase): + + def setUp(self): + self.authority_url = "https://login.microsoftonline.com/common" + self.authority = msal.authority.Authority(self.authority_url) + self.scopes = ["s1", "s2"] + self.uid = "my_uid" + self.utid = "my_utid" + self.account = {"home_account_id": "{}.{}".format(self.uid, self.utid)} + self.frt = "what the frt" + self.cache = msal.SerializableTokenCache() + self.cache.add({ # Pre-populate a FRT + "client_id": "preexisting_family_app", + "scope": self.scopes, + "token_endpoint": "{}/oauth2/v2.0/token".format(self.authority_url), + "response": TokenCacheTestCase.build_response( + uid=self.uid, utid=self.utid, refresh_token=self.frt, foci="1"), + }) # The add(...) helper populates correct home_account_id for future searching + + def test_unknown_orphan_app_will_attempt_frt_and_not_remove_it(self): + app = ClientApplication( + "unknown_orphan", authority=self.authority_url, token_cache=self.cache) + logger.debug("%s.cache = %s", self.id(), self.cache.serialize()) + def tester(url, data=None, **kwargs): + self.assertEqual(self.frt, data.get("refresh_token"), "Should attempt the FRT") + return Mock(status_code=200, json=Mock(return_value={ + "error": "invalid_grant", + "error_description": "Was issued to another client"})) + app._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family( + self.authority, self.scopes, self.account, post=tester) + self.assertNotEqual([], app.token_cache.find( + msal.TokenCache.CredentialType.REFRESH_TOKEN, query={"secret": self.frt}), + "The FRT should not be removed from the cache") + + def test_known_orphan_app_will_skip_frt_and_only_use_its_own_rt(self): + app = ClientApplication( + "known_orphan", authority=self.authority_url, token_cache=self.cache) + rt = "RT for this orphan app. We will check it being used by this test case." + self.cache.add({ # Populate its RT and AppMetadata, so it becomes a known orphan app + "client_id": app.client_id, + "scope": self.scopes, + "token_endpoint": "{}/oauth2/v2.0/token".format(self.authority_url), + "response": TokenCacheTestCase.build_response( + uid=self.uid, utid=self.utid, refresh_token=rt), + }) + logger.debug("%s.cache = %s", self.id(), self.cache.serialize()) + def tester(url, data=None, **kwargs): + self.assertEqual(rt, data.get("refresh_token"), "Should attempt the RT") + return Mock(status_code=200, json=Mock(return_value={})) + app._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family( + self.authority, self.scopes, self.account, post=tester) + + def test_unknown_family_app_will_attempt_frt_and_join_family(self): + def tester(url, data=None, **kwargs): + self.assertEqual( + self.frt, data.get("refresh_token"), "Should attempt the FRT") + return Mock( + status_code=200, + json=Mock(return_value=TokenCacheTestCase.build_response( + uid=self.uid, utid=self.utid, foci="1", access_token="at"))) + app = ClientApplication( + "unknown_family_app", authority=self.authority_url, token_cache=self.cache) + at = app._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family( + self.authority, self.scopes, self.account, post=tester) + logger.debug("%s.cache = %s", self.id(), self.cache.serialize()) + self.assertEqual("at", at.get("access_token"), "New app should get a new AT") + app_metadata = app.token_cache.find( + msal.TokenCache.CredentialType.APP_METADATA, + query={"client_id": app.client_id}) + self.assertNotEqual([], app_metadata, "Should record new app's metadata") + self.assertEqual("1", app_metadata[0].get("family_id"), + "The new family app should be recorded as in the same family") + # Known family app will simply use FRT, which is largely the same as this one + + # Will not test scenario of app leaving family. Per specs, it won't happen. + +class TestClientApplicationForAuthorityMigration(unittest.TestCase): + + @classmethod + def setUp(self): + self.environment_in_cache = "sts.windows.net" + self.authority_url_in_app = "https://login.microsoftonline.com/common" + self.scopes = ["s1", "s2"] + uid = "uid" + utid = "utid" + self.account = {"home_account_id": "{}.{}".format(uid, utid)} + self.client_id = "my_app" + self.access_token = "access token for testing authority aliases" + self.cache = msal.SerializableTokenCache() + self.cache.add({ + "client_id": self.client_id, + "scope": self.scopes, + "token_endpoint": "https://{}/common/oauth2/v2.0/token".format( + self.environment_in_cache), + "response": TokenCacheTestCase.build_response( + uid=uid, utid=utid, + access_token=self.access_token, refresh_token="some refresh token"), + }) # The add(...) helper populates correct home_account_id for future searching + + def test_get_accounts(self): + app = ClientApplication( + self.client_id, + authority=self.authority_url_in_app, token_cache=self.cache) + accounts = app.get_accounts() + self.assertNotEqual([], accounts) + self.assertEqual(self.environment_in_cache, accounts[0].get("environment"), + "We should be able to find an account under an authority alias") + + def test_acquire_token_silent(self): + app = ClientApplication( + self.client_id, + authority=self.authority_url_in_app, token_cache=self.cache) + at = app.acquire_token_silent(self.scopes, self.account) + self.assertNotEqual(None, at) + self.assertEqual(self.access_token, at.get('access_token')) + diff --git a/tests/test_token_cache.py b/tests/test_token_cache.py index eebd751d..f3771c38 100644 --- a/tests/test_token_cache.py +++ b/tests/test_token_cache.py @@ -12,30 +12,57 @@ class TokenCacheTestCase(unittest.TestCase): + @staticmethod + def build_id_token(sub="sub", oid="oid", preferred_username="me", **kwargs): + return "header.%s.signature" % base64.b64encode(json.dumps(dict({ + "sub": sub, + "oid": oid, + "preferred_username": preferred_username, + }, **kwargs)).encode()).decode('utf-8') + + @staticmethod + def build_response( # simulate a response from AAD + uid="uid", utid="utid", # They will form client_info + access_token=None, expires_in=3600, token_type="some type", + refresh_token=None, + foci=None, + id_token=None, # or something generated by build_id_token() + error=None, + ): + response = { + "client_info": base64.b64encode(json.dumps({ + "uid": uid, "utid": utid, + }).encode()).decode('utf-8'), + } + if error: + response["error"] = error + if access_token: + response.update({ + "access_token": access_token, + "expires_in": expires_in, + "token_type": token_type, + }) + if refresh_token: + response["refresh_token"] = refresh_token + if id_token: + response["id_token"] = id_token + if foci: + response["foci"] = foci + return response + def setUp(self): self.cache = TokenCache() def testAdd(self): - client_info = base64.b64encode(b''' - {"uid": "uid", "utid": "utid"} - ''').decode('utf-8') - id_token = "header.%s.signature" % base64.b64encode(b'''{ - "sub": "subject", - "oid": "object1234", - "preferred_username": "John Doe" - }''').decode('utf-8') + id_token = self.build_id_token(oid="object1234", preferred_username="John Doe") self.cache.add({ "client_id": "my_client_id", "scope": ["s2", "s1", "s3"], # Not in particular order "token_endpoint": "https://login.example.com/contoso/v2/token", - "response": { - "access_token": "an access token", - "token_type": "some type", - "expires_in": 3600, - "refresh_token": "a refresh token", - "client_info": client_info, - "id_token": id_token, - }, + "response": self.build_response( + uid="uid", utid="utid", # client_info + expires_in=3600, access_token="an access token", + id_token=id_token, refresh_token="a refresh token"), }, now=1000) self.assertEqual( { @@ -88,6 +115,15 @@ def testAdd(self): self.cache._cache["IdToken"].get( 'uid.utid-login.example.com-idtoken-my_client_id-contoso-') ) + self.assertEqual( + { + "client_id": "my_client_id", + 'environment': 'login.example.com', + "family_id": None, + }, + self.cache._cache.get("AppMetadata", {}).get( + "appmetadata-login.example.com-my_client_id") + ) class SerializableTokenCacheTestCase(TokenCacheTestCase): @@ -106,6 +142,12 @@ def setUp(self): } """) + def test_has_state_changed(self): + cache = SerializableTokenCache() + self.assertFalse(cache.has_state_changed) + cache.add({}) # An NO-OP add() still counts as a state change. Good enough. + self.assertTrue(cache.has_state_changed) + def tearDown(self): state = self.cache.serialize() logger.debug("serialize() = %s", state)