Skip to content

Commit

Permalink
Implement remove_tokens_for_client()
Browse files Browse the repository at this point in the history
  • Loading branch information
rayluo committed Feb 12, 2024
1 parent b286540 commit 3b96de6
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 0 deletions.
13 changes: 13 additions & 0 deletions msal/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -2178,6 +2178,19 @@ def _acquire_token_for_client(
telemetry_context.update_telemetry(response)
return response

def remove_tokens_for_client(self):
"""Remove all tokens that were previously acquired via
:func:`~acquire_token_for_client()` for the current client."""
for env in [self.authority.instance] + self._get_authority_aliases(
self.authority.instance):
for at in self.token_cache.find(TokenCache.CredentialType.ACCESS_TOKEN, query={
"client_id": self.client_id,
"environment": env,
"home_account_id": None, # These are mostly app-only tokens
}):
self.token_cache.remove_at(at)
# acquire_token_for_client() obtains no RTs, so we have no RT to remove

def acquire_token_on_behalf_of(self, user_assertion, scopes, claims_challenge=None, **kwargs):
"""Acquires token using on-behalf-of (OBO) flow.
Expand Down
29 changes: 29 additions & 0 deletions tests/test_application.py
Original file line number Diff line number Diff line change
Expand Up @@ -662,6 +662,35 @@ def test_organizations_authority_should_emit_warning(self):
authority="https://login.microsoftonline.com/organizations")


class TestRemoveTokensForClient(unittest.TestCase):
def test_remove_tokens_for_client_should_remove_client_tokens_only(self):
at_for_user = "AT for user"
cca = msal.ConfidentialClientApplication(
"client_id", client_credential="secret",
authority="https://login.microsoftonline.com/microsoft.onmicrosoft.com")
self.assertEqual(
0, len(cca.token_cache.find(msal.TokenCache.CredentialType.ACCESS_TOKEN)))
cca.acquire_token_for_client(
["scope"],
post=lambda url, **kwargs: MinimalResponse(
status_code=200, text=json.dumps({"access_token": "AT for client"})))
self.assertEqual(
1, len(cca.token_cache.find(msal.TokenCache.CredentialType.ACCESS_TOKEN)))
cca.acquire_token_by_username_password(
"johndoe", "password", ["scope"],
post=lambda url, **kwargs: MinimalResponse(
status_code=200, text=json.dumps(build_response(
access_token=at_for_user, expires_in=3600,
uid="uid", utid="utid", # This populates home_account_id
))))
self.assertEqual(
2, len(cca.token_cache.find(msal.TokenCache.CredentialType.ACCESS_TOKEN)))
cca.remove_tokens_for_client()
remaining_tokens = cca.token_cache.find(msal.TokenCache.CredentialType.ACCESS_TOKEN)
self.assertEqual(1, len(remaining_tokens))
self.assertEqual(at_for_user, remaining_tokens[0].get("secret"))


class TestScopeDecoration(unittest.TestCase):
def _test_client_id_should_be_a_valid_scope(self, client_id, other_scopes):
# B2C needs this https://learn.microsoft.com/en-us/azure/active-directory-b2c/access-tokens#openid-connect-scopes
Expand Down

0 comments on commit 3b96de6

Please sign in to comment.