diff --git a/msal/token_cache.py b/msal/token_cache.py index e554e118..66be5c9f 100644 --- a/msal/token_cache.py +++ b/msal/token_cache.py @@ -43,6 +43,8 @@ def __init__(self): self._lock = threading.RLock() self._cache = {} self.key_makers = { + # Note: We have changed token key format before when ordering scopes; + # changing token key won't result in cache miss. self.CredentialType.REFRESH_TOKEN: lambda home_account_id=None, environment=None, client_id=None, target=None, **ignored_payload_from_a_real_token: @@ -56,14 +58,18 @@ def __init__(self): ]).lower(), self.CredentialType.ACCESS_TOKEN: lambda home_account_id=None, environment=None, client_id=None, - realm=None, target=None, **ignored_payload_from_a_real_token: - "-".join([ + realm=None, target=None, + # Note: New field(s) can be added here + #key_id=None, + **ignored_payload_from_a_real_token: + "-".join([ # Note: Could use a hash here to shorten key length home_account_id or "", environment or "", self.CredentialType.ACCESS_TOKEN, client_id or "", realm or "", target or "", + #key_id or "", # So ATs of different key_id can coexist ]).lower(), self.CredentialType.ID_TOKEN: lambda home_account_id=None, environment=None, client_id=None, @@ -124,7 +130,7 @@ def _is_matching(entry: dict, query: dict, target_set: set = None) -> bool: target_set <= set(entry.get("target", "").split()) if target_set else True) - def search(self, credential_type, target=None, query=None): # O(n) generator + def search(self, credential_type, target=None, query=None, *, now=None): # O(n) generator """Returns a generator of matching entries. It is O(1) for AT hits, and O(n) for other types. @@ -150,21 +156,33 @@ def search(self, credential_type, target=None, query=None): # O(n) generator target_set = set(target) with self._lock: - # Since the target inside token cache key is (per schema) unsorted, - # there is no point to attempt an O(1) key-value search here. - # So we always do an O(n) in-memory search. + # O(n) search. The key is NOT used in search. + now = int(time.time() if now is None else now) + expired_access_tokens = [ + # Especially when/if we key ATs by ephemeral fields such as key_id, + # stale ATs keyed by an old key_id would stay forever. + # Here we collect them for their removal. + ] for entry in self._cache.get(credential_type, {}).values(): + if ( # Automatically delete expired access tokens + credential_type == self.CredentialType.ACCESS_TOKEN + and int(entry["expires_on"]) < now + ): + expired_access_tokens.append(entry) # Can't delete them within current for-loop + continue if (entry != preferred_result # Avoid yielding the same entry twice and self._is_matching(entry, query, target_set=target_set) ): yield entry + for at in expired_access_tokens: + self.remove_at(at) - def find(self, credential_type, target=None, query=None): + def find(self, credential_type, target=None, query=None, *, now=None): """Equivalent to list(search(...)).""" warnings.warn( "Use list(search(...)) instead to explicitly get a list.", DeprecationWarning) - return list(self.search(credential_type, target=target, query=query)) + return list(self.search(credential_type, target=target, query=query, now=now)) def add(self, event, now=None): """Handle a token obtaining event, and add tokens into cache.""" @@ -249,8 +267,11 @@ def __add(self, event, now=None): "expires_on": str(now + expires_in), # Same here "extended_expires_on": str(now + ext_expires_in) # Same here } - if data.get("key_id"): # It happens in SSH-cert or POP scenario - at["key_id"] = data.get("key_id") + at.update({k: data[k] for k in data if k in { + # Also store extra data which we explicitly allow + # So that we won't accidentally store a user's password etc. + "key_id", # It happens in SSH-cert or POP scenario + }}) if "refresh_in" in response: refresh_in = response["refresh_in"] # It is an integer at["refresh_on"] = str(now + refresh_in) # Schema wants a string diff --git a/tests/test_application.py b/tests/test_application.py index de916153..0736164c 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -340,6 +340,7 @@ class TestApplicationForRefreshInBehaviors(unittest.TestCase): account = {"home_account_id": "{}.{}".format(uid, utid)} rt = "this is a rt" client_id = "my_app" + soon = 60 # application.py considers tokens within 5 minutes as expired @classmethod def setUpClass(cls): # Initialization at runtime, not interpret-time @@ -414,7 +415,8 @@ def mock_post(url, headers=None, *args, **kwargs): def test_expired_token_and_unavailable_aad_should_return_error(self): # a.k.a. Attempt refresh expired token when AAD unavailable - self.populate_cache(access_token="expired at", expires_in=-1, refresh_in=-900) + self.populate_cache( + access_token="expired at", expires_in=self.soon, refresh_in=-900) error = "something went wrong" def mock_post(url, headers=None, *args, **kwargs): self.assertEqual("4|84,3|", (headers or {}).get(CLIENT_CURRENT_TELEMETRY)) @@ -425,7 +427,8 @@ def mock_post(url, headers=None, *args, **kwargs): def test_expired_token_and_available_aad_should_return_new_token(self): # a.k.a. Attempt refresh expired token when AAD available - self.populate_cache(access_token="expired at", expires_in=-1, refresh_in=-900) + self.populate_cache( + access_token="expired at", expires_in=self.soon, refresh_in=-900) new_access_token = "new AT" new_refresh_in = 123 def mock_post(url, headers=None, *args, **kwargs): diff --git a/tests/test_token_cache.py b/tests/test_token_cache.py index 4e301fa3..494d6daf 100644 --- a/tests/test_token_cache.py +++ b/tests/test_token_cache.py @@ -3,7 +3,7 @@ import json import time -from msal.token_cache import * +from msal.token_cache import TokenCache, SerializableTokenCache from tests import unittest @@ -51,11 +51,14 @@ class TokenCacheTestCase(unittest.TestCase): def setUp(self): self.cache = TokenCache() + self.at_key_maker = self.cache.key_makers[ + TokenCache.CredentialType.ACCESS_TOKEN] def testAddByAad(self): client_id = "my_client_id" id_token = build_id_token( oid="object1234", preferred_username="John Doe", aud=client_id) + now = 1000 self.cache.add({ "client_id": client_id, "scope": ["s2", "s1", "s3"], # Not in particular order @@ -64,7 +67,7 @@ def testAddByAad(self): uid="uid", utid="utid", # client_info expires_in=3600, access_token="an access token", id_token=id_token, refresh_token="a refresh token"), - }, now=1000) + }, now=now) access_token_entry = { 'cached_at': "1000", 'client_id': 'my_client_id', @@ -78,14 +81,11 @@ def testAddByAad(self): 'target': 's1 s2 s3', # Sorted 'token_type': 'some type', } - self.assertEqual( - access_token_entry, - self.cache._cache["AccessToken"].get( - 'uid.utid-login.example.com-accesstoken-my_client_id-contoso-s1 s2 s3') - ) + self.assertEqual(access_token_entry, self.cache._cache["AccessToken"].get( + self.at_key_maker(**access_token_entry))) self.assertIn( access_token_entry, - self.cache.find(self.cache.CredentialType.ACCESS_TOKEN), + self.cache.find(self.cache.CredentialType.ACCESS_TOKEN, now=now), "find(..., query=None) should not crash, even though MSAL does not use it") self.assertEqual( { @@ -144,8 +144,7 @@ def testAddByAdfs(self): expires_in=3600, access_token="an access token", id_token=id_token, refresh_token="a refresh token"), }, now=1000) - self.assertEqual( - { + access_token_entry = { 'cached_at': "1000", 'client_id': 'my_client_id', 'credential_type': 'AccessToken', @@ -157,10 +156,9 @@ def testAddByAdfs(self): 'secret': 'an access token', 'target': 's1 s2 s3', # Sorted 'token_type': 'some type', - }, - self.cache._cache["AccessToken"].get( - 'subject-fs.msidlab8.com-accesstoken-my_client_id-adfs-s1 s2 s3') - ) + } + self.assertEqual(access_token_entry, self.cache._cache["AccessToken"].get( + self.at_key_maker(**access_token_entry))) self.assertEqual( { 'client_id': 'my_client_id', @@ -206,37 +204,67 @@ def testAddByAdfs(self): "appmetadata-fs.msidlab8.com-my_client_id") ) - def test_key_id_is_also_recorded(self): - my_key_id = "some_key_id_123" + def assertFoundAccessToken(self, *, scopes, query, data=None, now=None): + cached_at = None + for cached_at in self.cache.search( + TokenCache.CredentialType.ACCESS_TOKEN, + target=scopes, query=query, now=now, + ): + for k, v in (data or {}).items(): # The extra data, if any + self.assertEqual(cached_at.get(k), v, f"AT should contain {k}={v}") + self.assertTrue(cached_at, "AT should be cached and searchable") + return cached_at + + def _test_data_should_be_saved_and_searchable_in_access_token(self, data): + scopes = ["s2", "s1", "s3"] # Not in particular order + now = 1000 self.cache.add({ - "data": {"key_id": my_key_id}, + "data": data, "client_id": "my_client_id", - "scope": ["s2", "s1", "s3"], # Not in particular order + "scope": scopes, "token_endpoint": "https://login.example.com/contoso/v2/token", "response": build_response( uid="uid", utid="utid", # client_info expires_in=3600, access_token="an access token", refresh_token="a refresh token"), - }, now=1000) - cached_key_id = self.cache._cache["AccessToken"].get( - 'uid.utid-login.example.com-accesstoken-my_client_id-contoso-s1 s2 s3', - {}).get("key_id") - self.assertEqual(my_key_id, cached_key_id, "AT should be bound to the key") + }, now=now) + self.assertFoundAccessToken(scopes=scopes, data=data, now=now, query=dict( + data, # Also use the extra data as a query criteria + client_id="my_client_id", + environment="login.example.com", + realm="contoso", + home_account_id="uid.utid", + )) + + def test_extra_data_should_also_be_recorded_and_searchable_in_access_token(self): + self._test_data_should_be_saved_and_searchable_in_access_token({"key_id": "1"}) + + def test_access_tokens_with_different_key_id(self): + self._test_data_should_be_saved_and_searchable_in_access_token({"key_id": "1"}) + self._test_data_should_be_saved_and_searchable_in_access_token({"key_id": "2"}) + self.assertEqual( + len(self.cache._cache["AccessToken"]), + 1, """Historically, tokens are not keyed by key_id, +so a new token overwrites the old one, and we would end up with 1 token in cache""") def test_refresh_in_should_be_recorded_as_refresh_on(self): # Sounds weird. Yep. + scopes = ["s2", "s1", "s3"] # Not in particular order self.cache.add({ "client_id": "my_client_id", - "scope": ["s2", "s1", "s3"], # Not in particular order + "scope": scopes, "token_endpoint": "https://login.example.com/contoso/v2/token", "response": build_response( uid="uid", utid="utid", # client_info expires_in=3600, refresh_in=1800, access_token="an access token", ), #refresh_token="a refresh token"), }, now=1000) - refresh_on = self.cache._cache["AccessToken"].get( - 'uid.utid-login.example.com-accesstoken-my_client_id-contoso-s1 s2 s3', - {}).get("refresh_on") - self.assertEqual("2800", refresh_on, "Should save refresh_on") + at = self.assertFoundAccessToken(scopes=scopes, query=dict( + client_id="my_client_id", + environment="login.example.com", + realm="contoso", + home_account_id="uid.utid", + )) + self.assertEqual("2800", at.get("refresh_on"), "Should save refresh_on") def test_old_rt_data_with_wrong_key_should_still_be_salvaged_into_new_rt(self): sample = { @@ -258,7 +286,7 @@ def test_old_rt_data_with_wrong_key_should_still_be_salvaged_into_new_rt(self): ) -class SerializableTokenCacheTestCase(TokenCacheTestCase): +class SerializableTokenCacheTestCase(unittest.TestCase): # Run all inherited test methods, and have extra check in tearDown() def setUp(self):