41 changes: 31 additions & 10 deletions msal/token_cache.py
@@ -43,6 +43,8 @@ def __init__(self):
         self._lock = threading.RLock()
         self._cache = {}
         self.key_makers = {
+            # Note: We have changed token key format before when ordering scopes;
+            # changing token key won't result in cache miss.
             self.CredentialType.REFRESH_TOKEN:
                 lambda home_account_id=None, environment=None, client_id=None,
                         target=None, **ignored_payload_from_a_real_token:
@@ -56,14 +58,18 @@ def __init__(self):
                     ]).lower(),
             self.CredentialType.ACCESS_TOKEN:
                 lambda home_account_id=None, environment=None, client_id=None,
-                        realm=None, target=None, **ignored_payload_from_a_real_token:
-                    "-".join([
+                        realm=None, target=None,
+                        # Note: New field(s) can be added here
+                        #key_id=None,
+                        **ignored_payload_from_a_real_token:
+                    "-".join([ # Note: Could use a hash here to shorten key length
                         home_account_id or "",
                         environment or "",
                         self.CredentialType.ACCESS_TOKEN,
                         client_id or "",
                         realm or "",
                         target or "",
+                        #key_id or "", # So ATs of different key_id can coexist
                     ]).lower(),
             self.CredentialType.ID_TOKEN:
                 lambda home_account_id=None, environment=None, client_id=None,
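For illustration (not part of this diff): the access-token cache key produced by this key maker is just the hyphen-joined, lowercased tuple of the fields above. A quick sketch, using the same sample values as the tests below:

fields = [
    "uid.utid",           # home_account_id
    "login.example.com",  # environment
    "AccessToken",        # credential type
    "my_client_id",       # client_id
    "contoso",            # realm
    "s1 s2 s3",           # target, i.e. the scopes, space-joined (sorted by add())
]
print("-".join(fields).lower())
# -> uid.utid-login.example.com-accesstoken-my_client_id-contoso-s1 s2 s3

Un-commenting the key_id field would append one more segment to that tuple, so ATs differing only in key_id would occupy distinct cache slots instead of overwriting one another; that is the coexistence the commented-out lines anticipate.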
@@ -124,7 +130,7 @@ def _is_matching(entry: dict, query: dict, target_set: set = None) -> bool:
                 target_set <= set(entry.get("target", "").split())
                 if target_set else True)

-    def search(self, credential_type, target=None, query=None): # O(n) generator
+    def search(self, credential_type, target=None, query=None, *, now=None): # O(n) generator
         """Returns a generator of matching entries.

         It is O(1) for AT hits, and O(n) for other types.
@@ -150,21 +156,33 @@ def search(self, credential_type, target=None, query=None): # O(n) generator

         target_set = set(target)
         with self._lock:
-            # Since the target inside token cache key is (per schema) unsorted,
-            # there is no point to attempt an O(1) key-value search here.
-            # So we always do an O(n) in-memory search.
+            # O(n) search. The key is NOT used in search.
+            now = int(time.time() if now is None else now)
+            expired_access_tokens = [
+                # Especially when/if we key ATs by ephemeral fields such as key_id,
+                # stale ATs keyed by an old key_id would stay forever.
+                # Here we collect them for their removal.
+            ]
             for entry in self._cache.get(credential_type, {}).values():
+                if ( # Automatically delete expired access tokens
+                        credential_type == self.CredentialType.ACCESS_TOKEN
+                        and int(entry["expires_on"]) < now
+                ):
+                    expired_access_tokens.append(entry) # Can't delete them within current for-loop
+                    continue
                 if (entry != preferred_result # Avoid yielding the same entry twice
                         and self._is_matching(entry, query, target_set=target_set)
                 ):
                     yield entry
+            for at in expired_access_tokens:
+                self.remove_at(at)

-    def find(self, credential_type, target=None, query=None):
+    def find(self, credential_type, target=None, query=None, *, now=None):
         """Equivalent to list(search(...))."""
         warnings.warn(
             "Use list(search(...)) instead to explicitly get a list.",
             DeprecationWarning)
-        return list(self.search(credential_type, target=target, query=query))
+        return list(self.search(credential_type, target=target, query=query, now=now))

     def add(self, event, now=None):
         """Handle a token obtaining event, and add tokens into cache."""
@@ -249,8 +267,11 @@ def __add(self, event, now=None):
"expires_on": str(now + expires_in), # Same here
"extended_expires_on": str(now + ext_expires_in) # Same here
}
if data.get("key_id"): # It happens in SSH-cert or POP scenario
at["key_id"] = data.get("key_id")
at.update({k: data[k] for k in data if k in {
# Also store extra data which we explicitly allow
# So that we won't accidentally store a user's password etc.
"key_id", # It happens in SSH-cert or POP scenario
}})
if "refresh_in" in response:
refresh_in = response["refresh_in"] # It is an integer
at["refresh_on"] = str(now + refresh_in) # Schema wants a string
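A hedged end-to-end sketch of the behavior this file introduces (not part of the diff; assumes the msal package is installed): search() now accepts an injectable now and lazily evicts expired access tokens while it runs. The event shape below mirrors what the tests' build_response() helper produces, minus the optional fields.

from msal.token_cache import TokenCache

cache = TokenCache()
cache.add({
    "client_id": "my_client_id",
    "scope": ["s1"],
    "token_endpoint": "https://login.example.com/contoso/v2/token",
    "response": {
        "token_type": "Bearer",
        "access_token": "an access token",
        "expires_in": 3600,  # so expires_on becomes 1000 + 3600 = 4600
    },
}, now=1000)

# Still valid at now=1000, so search() yields it:
assert list(cache.search(TokenCache.CredentialType.ACCESS_TOKEN, target=["s1"], now=1000))

# Expired at now=5000: search() skips it and, once the generator is exhausted,
# removes the stale entry from the cache via remove_at().
assert not list(cache.search(TokenCache.CredentialType.ACCESS_TOKEN, target=["s1"], now=5000))
assert not cache._cache.get("AccessToken")  # Peeking at a private member, purely for illustration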
7 changes: 5 additions & 2 deletions tests/test_application.py
@@ -340,6 +340,7 @@ class TestApplicationForRefreshInBehaviors(unittest.TestCase):
     account = {"home_account_id": "{}.{}".format(uid, utid)}
     rt = "this is a rt"
     client_id = "my_app"
+    soon = 60 # application.py considers tokens within 5 minutes as expired

     @classmethod
     def setUpClass(cls): # Initialization at runtime, not interpret-time
@@ -414,7 +415,8 @@ def mock_post(url, headers=None, *args, **kwargs):

     def test_expired_token_and_unavailable_aad_should_return_error(self):
         # a.k.a. Attempt refresh expired token when AAD unavailable
-        self.populate_cache(access_token="expired at", expires_in=-1, refresh_in=-900)
+        self.populate_cache(
+            access_token="expired at", expires_in=self.soon, refresh_in=-900)
         error = "something went wrong"
         def mock_post(url, headers=None, *args, **kwargs):
             self.assertEqual("4|84,3|", (headers or {}).get(CLIENT_CURRENT_TELEMETRY))
@@ -425,7 +427,8 @@ def mock_post(url, headers=None, *args, **kwargs):

     def test_expired_token_and_available_aad_should_return_new_token(self):
         # a.k.a. Attempt refresh expired token when AAD available
-        self.populate_cache(access_token="expired at", expires_in=-1, refresh_in=-900)
+        self.populate_cache(
+            access_token="expired at", expires_in=self.soon, refresh_in=-900)
         new_access_token = "new AT"
         new_refresh_in = 123
         def mock_post(url, headers=None, *args, **kwargs):
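The switch from expires_in=-1 to expires_in=self.soon above is presumably needed because search() now physically deletes ATs whose expires_on has passed: a token seeded as already-expired would be purged from the cache before the refresh logic under test ever saw it, while expires_in=60 keeps the entry cached yet still "expired" from application.py's point of view. A hypothetical sketch of the two thresholds; the names and the exact buffer are assumptions, not code from this PR:

EXPIRY_BUFFER = 5 * 60  # assumed from the test comment; the real value lives in application.py

def application_considers_expired(at: dict, now: int) -> bool:
    return int(at["expires_on"]) - EXPIRY_BUFFER <= now  # triggers the refresh attempt

def cache_evicts(at: dict, now: int) -> bool:
    return int(at["expires_on"]) < now  # search() deletes the entry outright

at = {"expires_on": str(1000 + 60)}  # i.e. populate_cache(expires_in=soon) at now=1000
assert application_considers_expired(at, now=1000)  # the refresh path is exercised...
assert not cache_evicts(at, now=1000)               # ...while the AT stays in the cache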
86 changes: 57 additions & 29 deletions tests/test_token_cache.py
@@ -3,7 +3,7 @@
 import json
 import time

-from msal.token_cache import *
+from msal.token_cache import TokenCache, SerializableTokenCache
 from tests import unittest


@@ -51,11 +51,14 @@ class TokenCacheTestCase(unittest.TestCase):

     def setUp(self):
         self.cache = TokenCache()
+        self.at_key_maker = self.cache.key_makers[
+            TokenCache.CredentialType.ACCESS_TOKEN]

     def testAddByAad(self):
         client_id = "my_client_id"
         id_token = build_id_token(
             oid="object1234", preferred_username="John Doe", aud=client_id)
+        now = 1000
         self.cache.add({
             "client_id": client_id,
             "scope": ["s2", "s1", "s3"], # Not in particular order
@@ -64,7 +67,7 @@ def testAddByAad(self):
uid="uid", utid="utid", # client_info
expires_in=3600, access_token="an access token",
id_token=id_token, refresh_token="a refresh token"),
}, now=1000)
}, now=now)
access_token_entry = {
'cached_at': "1000",
'client_id': 'my_client_id',
@@ -78,14 +81,11 @@ def testAddByAad(self):
             'target': 's1 s2 s3', # Sorted
             'token_type': 'some type',
         }
-        self.assertEqual(
-            access_token_entry,
-            self.cache._cache["AccessToken"].get(
-                'uid.utid-login.example.com-accesstoken-my_client_id-contoso-s1 s2 s3')
-        )
+        self.assertEqual(access_token_entry, self.cache._cache["AccessToken"].get(
+            self.at_key_maker(**access_token_entry)))
         self.assertIn(
             access_token_entry,
-            self.cache.find(self.cache.CredentialType.ACCESS_TOKEN),
+            self.cache.find(self.cache.CredentialType.ACCESS_TOKEN, now=now),
             "find(..., query=None) should not crash, even though MSAL does not use it")
         self.assertEqual(
             {
@@ -144,8 +144,7 @@ def testAddByAdfs(self):
                 expires_in=3600, access_token="an access token",
                 id_token=id_token, refresh_token="a refresh token"),
         }, now=1000)
-        self.assertEqual(
-            {
+        access_token_entry = {
             'cached_at': "1000",
             'client_id': 'my_client_id',
             'credential_type': 'AccessToken',
@@ -157,10 +156,9 @@ def testAddByAdfs(self):
             'secret': 'an access token',
             'target': 's1 s2 s3', # Sorted
             'token_type': 'some type',
-            },
-            self.cache._cache["AccessToken"].get(
-                'subject-fs.msidlab8.com-accesstoken-my_client_id-adfs-s1 s2 s3')
-        )
+        }
+        self.assertEqual(access_token_entry, self.cache._cache["AccessToken"].get(
+            self.at_key_maker(**access_token_entry)))
         self.assertEqual(
             {
                 'client_id': 'my_client_id',
@@ -206,37 +204,67 @@ def testAddByAdfs(self):
"appmetadata-fs.msidlab8.com-my_client_id")
)

def test_key_id_is_also_recorded(self):
my_key_id = "some_key_id_123"
def assertFoundAccessToken(self, *, scopes, query, data=None, now=None):
cached_at = None
for cached_at in self.cache.search(
TokenCache.CredentialType.ACCESS_TOKEN,
target=scopes, query=query, now=now,
):
for k, v in (data or {}).items(): # The extra data, if any
self.assertEqual(cached_at.get(k), v, f"AT should contain {k}={v}")
self.assertTrue(cached_at, "AT should be cached and searchable")
return cached_at

def _test_data_should_be_saved_and_searchable_in_access_token(self, data):
scopes = ["s2", "s1", "s3"] # Not in particular order
now = 1000
self.cache.add({
"data": {"key_id": my_key_id},
"data": data,
"client_id": "my_client_id",
"scope": ["s2", "s1", "s3"], # Not in particular order
"scope": scopes,
"token_endpoint": "https://login.example.com/contoso/v2/token",
"response": build_response(
uid="uid", utid="utid", # client_info
expires_in=3600, access_token="an access token",
refresh_token="a refresh token"),
}, now=1000)
cached_key_id = self.cache._cache["AccessToken"].get(
'uid.utid-login.example.com-accesstoken-my_client_id-contoso-s1 s2 s3',
{}).get("key_id")
self.assertEqual(my_key_id, cached_key_id, "AT should be bound to the key")
}, now=now)
self.assertFoundAccessToken(scopes=scopes, data=data, now=now, query=dict(
data, # Also use the extra data as a query criteria
client_id="my_client_id",
environment="login.example.com",
realm="contoso",
home_account_id="uid.utid",
))

def test_extra_data_should_also_be_recorded_and_searchable_in_access_token(self):
self._test_data_should_be_saved_and_searchable_in_access_token({"key_id": "1"})

def test_access_tokens_with_different_key_id(self):
self._test_data_should_be_saved_and_searchable_in_access_token({"key_id": "1"})
self._test_data_should_be_saved_and_searchable_in_access_token({"key_id": "2"})
self.assertEqual(
len(self.cache._cache["AccessToken"]),
1, """Historically, tokens are not keyed by key_id,
so a new token overwrites the old one, and we would end up with 1 token in cache""")

def test_refresh_in_should_be_recorded_as_refresh_on(self): # Sounds weird. Yep.
scopes = ["s2", "s1", "s3"] # Not in particular order
self.cache.add({
"client_id": "my_client_id",
"scope": ["s2", "s1", "s3"], # Not in particular order
"scope": scopes,
"token_endpoint": "https://login.example.com/contoso/v2/token",
"response": build_response(
uid="uid", utid="utid", # client_info
expires_in=3600, refresh_in=1800, access_token="an access token",
), #refresh_token="a refresh token"),
}, now=1000)
refresh_on = self.cache._cache["AccessToken"].get(
'uid.utid-login.example.com-accesstoken-my_client_id-contoso-s1 s2 s3',
{}).get("refresh_on")
self.assertEqual("2800", refresh_on, "Should save refresh_on")
at = self.assertFoundAccessToken(scopes=scopes, query=dict(
client_id="my_client_id",
environment="login.example.com",
realm="contoso",
home_account_id="uid.utid",
))
self.assertEqual("2800", at.get("refresh_on"), "Should save refresh_on")

def test_old_rt_data_with_wrong_key_should_still_be_salvaged_into_new_rt(self):
sample = {
@@ -258,7 +286,7 @@ def test_old_rt_data_with_wrong_key_should_still_be_salvaged_into_new_rt(self):
         )


-class SerializableTokenCacheTestCase(TokenCacheTestCase):
+class SerializableTokenCacheTestCase(unittest.TestCase):
     # Run all inherited test methods, and have extra check in tearDown()

     def setUp(self):
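A sketch of what the new tests above verify (not part of the diff; assumes the msal package is installed): whitelisted extra data such as key_id is now stored on the AT entry and usable as a search() criterion, while ATs are still not keyed by key_id, so the latest add() wins:

from msal.token_cache import TokenCache

cache = TokenCache()
for key_id in ("1", "2"):
    cache.add({
        "data": {"key_id": key_id},  # e.g. the SSH-cert or POP scenario
        "client_id": "my_client_id",
        "scope": ["s1"],
        "token_endpoint": "https://login.example.com/contoso/v2/token",
        "response": {
            "token_type": "pop",
            "access_token": "an AT bound to key " + key_id,
            "expires_in": 3600,
        },
    }, now=1000)

hits = list(cache.search(
    TokenCache.CredentialType.ACCESS_TOKEN,
    target=["s1"], query={"key_id": "2"}, now=1000))
assert [at["secret"] for at in hits] == ["an AT bound to key 2"]
# The key maker still ignores key_id, so the second add() overwrote the first;
# querying {"key_id": "1"} now finds nothing, which matches the expectation in
# test_access_tokens_with_different_key_id above.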