
Order scopes on save, and optimize the happy path for access token read #644

Merged 3 commits on Jan 9, 2024
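A hedged sketch of the core idea behind this change, using made-up `_demo_*` names rather than MSAL's actual API: if the scopes are sorted before being joined into the access-token cache key at save time, the same sorted join at read time reproduces the key exactly, so the happy path becomes a single O(1) dict lookup instead of a linear scan over every cached access token.

```python
_demo_cache = {}

def _demo_key(home_account_id, environment, client_id, realm, scopes):
    # Sorting the scopes makes the key canonical, regardless of request order.
    return "-".join([
        home_account_id, environment, "accesstoken",
        client_id, realm, " ".join(sorted(scopes)),
    ])

def _demo_save(token, **key_parts):
    _demo_cache[_demo_key(**key_parts)] = token      # write with the canonical key

def _demo_read(**key_parts):
    return _demo_cache.get(_demo_key(**key_parts))   # O(1) happy path, no scan

_demo_save({"secret": "an access token"},
           home_account_id="uid.utid", environment="login.example.com",
           client_id="my_client_id", realm="contoso", scopes=["s2", "s1", "s3"])
assert _demo_read(home_account_id="uid.utid", environment="login.example.com",
                  client_id="my_client_id", realm="contoso",
                  scopes=["s3", "s2", "s1"])["secret"] == "an access token"
```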
17 changes: 8 additions & 9 deletions msal/application.py
@@ -1357,13 +1357,14 @@ def _acquire_token_silent_from_cache_and_possibly_refresh_it(
key_id = kwargs.get("data", {}).get("key_id")
if key_id: # Some token types (SSH-certs, POP) are bound to a key
query["key_id"] = key_id
matches = self.token_cache.find(
self.token_cache.CredentialType.ACCESS_TOKEN,
target=scopes,
query=query)
now = time.time()
refresh_reason = msal.telemetry.AT_ABSENT
for entry in matches:
for entry in self.token_cache._find( # It returns a generator
self.token_cache.CredentialType.ACCESS_TOKEN,
target=scopes,
query=query,
): # Note that _find() holds a lock during this for loop;
# that is fine because this loop is fast
expires_in = int(entry["expires_on"]) - now
if expires_in < 5*60: # Then consider it expired
refresh_reason = msal.telemetry.AT_EXPIRED
@@ -1492,10 +1493,8 @@ def _acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family(
**kwargs) or last_resp

def _get_app_metadata(self, environment):
apps = self.token_cache.find( # Use find(), rather than token_cache.get(...)
TokenCache.CredentialType.APP_METADATA, query={
"environment": environment, "client_id": self.client_id})
return apps[0] if apps else {}
return self.token_cache._get_app_metadata(
environment=environment, client_id=self.client_id, default={})

def _acquire_token_silent_by_finding_specific_refresh_token(
self, authority, scopes, query,
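In the first hunk above, the caller switches from the list returned by `find()` to the lazy `_find()` generator, which holds the cache lock for as long as the consuming loop runs; that is acceptable only because the loop body is quick. A minimal, self-contained sketch of that pattern (a hypothetical `_DemoCache`, not MSAL's implementation):

```python
import threading

class _DemoCache:
    def __init__(self):
        self._lock = threading.RLock()
        self._entries = {"k1": {"secret": "t1"}, "k2": {"secret": "t2"}}

    def _find(self):
        with self._lock:            # acquired on first next(), held across every yield
            for entry in self._entries.values():
                yield entry

    def find(self):                 # eager wrapper, mirrors the obsolete find()
        return list(self._find())   # the lock is released before this list is returned

cache = _DemoCache()
for entry in cache._find():         # keep this loop fast; the lock is held throughout
    print(entry["secret"])
```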
67 changes: 58 additions & 9 deletions msal/token_cache.py
@@ -88,20 +88,69 @@ def __init__(self):
"appmetadata-{}-{}".format(environment or "", client_id or ""),
}

def find(self, credential_type, target=None, query=None):
target = target or []
def _get_access_token(
self,
home_account_id, environment, client_id, realm, target, # Together they form a compound key
default=None,
): # O(1)
return self._get(
self.CredentialType.ACCESS_TOKEN,
self.key_makers[TokenCache.CredentialType.ACCESS_TOKEN](
home_account_id=home_account_id,
environment=environment,
client_id=client_id,
realm=realm,
target=" ".join(target),
),
default=default)

def _get_app_metadata(self, environment, client_id, default=None): # O(1)
return self._get(
self.CredentialType.APP_METADATA,
self.key_makers[TokenCache.CredentialType.APP_METADATA](
environment=environment,
client_id=client_id,
),
default=default)

def _get(self, credential_type, key, default=None): # O(1)
with self._lock:
return self._cache.get(credential_type, {}).get(key, default)

def _find(self, credential_type, target=None, query=None): # O(n) generator
"""Returns a generator of matching entries.

It is O(1) for AT hits, and O(n) for other types.
Note that it holds a lock during the entire search.
"""
target = sorted(target or []) # Match the order sorted by add()
assert isinstance(target, list), "Invalid parameter type"

preferred_result = None
if (credential_type == self.CredentialType.ACCESS_TOKEN
and "home_account_id" in query and "environment" in query
and "client_id" in query and "realm" in query and target
): # Special case for O(1) AT lookup
preferred_result = self._get_access_token(
query["home_account_id"], query["environment"],
query["client_id"], query["realm"], target)
if preferred_result:
yield preferred_result

target_set = set(target)
with self._lock:
# Since the target inside token cache key is (per schema) unsorted,
# there is no point to attempt an O(1) key-value search here.
# So we always do an O(n) in-memory search.
return [entry
for entry in self._cache.get(credential_type, {}).values()
if is_subdict_of(query or {}, entry)
and (target_set <= set(entry.get("target", "").split())
if target else True)
]
for entry in self._cache.get(credential_type, {}).values():
if is_subdict_of(query or {}, entry) and (
target_set <= set(entry.get("target", "").split())
if target else True):
if entry != preferred_result: # Avoid yielding the same entry twice
yield entry

def find(self, credential_type, target=None, query=None): # Obsolete. Use _find() instead.
return list(self._find(credential_type, target=target, query=query))

def add(self, event, now=None):
"""Handle a token obtaining event, and add tokens into cache."""
@@ -160,7 +209,7 @@ def __add(self, event, now=None):
decode_id_token(id_token, client_id=event["client_id"]) if id_token else {})
client_info, home_account_id = self.__parse_account(response, id_token_claims)

target = ' '.join(event.get("scope") or []) # Per schema, we don't sort it
target = ' '.join(sorted(event.get("scope") or [])) # Schema should have required sorting

with self._lock:
now = int(time.time() if now is None else now)
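Taken together, the `token_cache.py` changes above make the read path try the O(1) compound-key hit first, and only then fall back to the O(n) scan, skipping the entry it has already yielded. A hedged sketch with made-up data (the key format mimics the one exercised by the tests below, but is illustrative, not a schema reference):

```python
def _demo_find(cache, query, target):
    key = "{home_account_id}-{environment}-accesstoken-{client_id}-{realm}-{scopes}".format(
        scopes=" ".join(sorted(target)), **query)
    preferred = cache.get(key)
    if preferred:
        yield preferred                              # O(1) happy path
    target_set = set(target)
    for entry in cache.values():                     # O(n) fallback for other lookups
        if entry is not preferred and target_set <= set(entry.get("target", "").split()):
            yield entry

demo_cache = {
    "uid.utid-login.example.com-accesstoken-my_client_id-contoso-s1 s2 s3":
        {"secret": "an access token", "target": "s1 s2 s3"},
}
hits = list(_demo_find(
    demo_cache,
    {"home_account_id": "uid.utid", "environment": "login.example.com",
     "client_id": "my_client_id", "realm": "contoso"},
    ["s2", "s1", "s3"],
))
assert hits == [{"secret": "an access token", "target": "s1 s2 s3"}]
```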
20 changes: 10 additions & 10 deletions tests/test_token_cache.py
@@ -76,11 +76,11 @@ def testAddByAad(self):
'home_account_id': "uid.utid",
'realm': 'contoso',
'secret': 'an access token',
'target': 's2 s1 s3',
'target': 's1 s2 s3', # Sorted
'token_type': 'some type',
},
self.cache._cache["AccessToken"].get(
'uid.utid-login.example.com-accesstoken-my_client_id-contoso-s2 s1 s3')
'uid.utid-login.example.com-accesstoken-my_client_id-contoso-s1 s2 s3')
)
self.assertEqual(
{
@@ -90,10 +90,10 @@ def testAddByAdfs(self):
'home_account_id': "uid.utid",
'last_modification_time': '1000',
'secret': 'a refresh token',
'target': 's2 s1 s3',
'target': 's1 s2 s3', # Sorted
},
self.cache._cache["RefreshToken"].get(
'uid.utid-login.example.com-refreshtoken-my_client_id--s2 s1 s3')
'uid.utid-login.example.com-refreshtoken-my_client_id--s1 s2 s3')
)
self.assertEqual(
{
@@ -150,11 +150,11 @@ def testAddByAdfs(self):
'home_account_id': "subject",
'realm': 'adfs',
'secret': 'an access token',
'target': 's2 s1 s3',
'target': 's1 s2 s3', # Sorted
'token_type': 'some type',
},
self.cache._cache["AccessToken"].get(
'subject-fs.msidlab8.com-accesstoken-my_client_id-adfs-s2 s1 s3')
'subject-fs.msidlab8.com-accesstoken-my_client_id-adfs-s1 s2 s3')
)
self.assertEqual(
{
@@ -164,10 +164,10 @@ def testAddByAdfs(self):
'home_account_id': "subject",
'last_modification_time': "1000",
'secret': 'a refresh token',
'target': 's2 s1 s3',
'target': 's1 s2 s3', # Sorted
},
self.cache._cache["RefreshToken"].get(
'subject-fs.msidlab8.com-refreshtoken-my_client_id--s2 s1 s3')
'subject-fs.msidlab8.com-refreshtoken-my_client_id--s1 s2 s3')
)
self.assertEqual(
{
@@ -214,7 +214,7 @@ def test_key_id_is_also_recorded(self):
refresh_token="a refresh token"),
}, now=1000)
cached_key_id = self.cache._cache["AccessToken"].get(
'uid.utid-login.example.com-accesstoken-my_client_id-contoso-s2 s1 s3',
'uid.utid-login.example.com-accesstoken-my_client_id-contoso-s1 s2 s3',
{}).get("key_id")
self.assertEqual(my_key_id, cached_key_id, "AT should be bound to the key")

@@ -229,7 +229,7 @@ def test_refresh_in_should_be_recorded_as_refresh_on(self): # Sounds weird. Yep
), #refresh_token="a refresh token"),
}, now=1000)
refresh_on = self.cache._cache["AccessToken"].get(
'uid.utid-login.example.com-accesstoken-my_client_id-contoso-s2 s1 s3',
'uid.utid-login.example.com-accesstoken-my_client_id-contoso-s1 s2 s3',
{}).get("refresh_on")
self.assertEqual("2800", refresh_on, "Should save refresh_on")

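The expected cache keys in these tests change from the request-order suffix `s2 s1 s3` to the sorted `s1 s2 s3` simply because `add()` now joins `sorted(scope)`:

```python
scopes = ["s2", "s1", "s3"]                    # request order, as passed to add()
assert " ".join(scopes) == "s2 s1 s3"          # old key suffix (order-sensitive)
assert " ".join(sorted(scopes)) == "s1 s2 s3"  # new key suffix (canonical)
```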