Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 118 additions & 23 deletions msal/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
from urllib.parse import urljoin
import logging
import sys
import warnings

import requests

from .oauth2cli import Client, JwtSigner
from .authority import Authority
Expand All @@ -15,7 +18,7 @@


# The __init__.py will import this. Not the other way around.
__version__ = "0.2.0"
__version__ = "0.3.0"

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -101,6 +104,14 @@ def __init__(
# Here the self.authority is not the same type as authority in input
self.token_cache = token_cache or TokenCache()
self.client = self._build_client(client_credential, self.authority)
self.authority_groups = self._get_authority_aliases()

def _get_authority_aliases(self):
resp = requests.get(
"https://login.microsoftonline.com/common/discovery/instance?api-version=1.1&authorization_endpoint=https://login.microsoftonline.com/common/oauth2/authorize",
headers={'Accept': 'application/json'})
resp.raise_for_status()
return [set(group['aliases']) for group in resp.json()['metadata']]

def _build_client(self, client_credential, authority):
client_assertion = None
Expand Down Expand Up @@ -236,19 +247,33 @@ def get_accounts(self, username=None):
Your app can choose to display those information to end user,
and allow user to choose one of his/her accounts to proceed.
"""
# The following implementation finds accounts only from saved accounts,
# but does NOT correlate them with saved RTs. It probably won't matter,
# because in MSAL universe, there are always Accounts and RTs together.
accounts = self.token_cache.find(
self.token_cache.CredentialType.ACCOUNT,
query={"environment": self.authority.instance})
accounts = self._find_msal_accounts(environment=self.authority.instance)
if not accounts: # Now try other aliases of this authority instance
for group in self.authority_groups:
if self.authority.instance in group:
for alias in group:
if alias != self.authority.instance:
accounts = self._find_msal_accounts(environment=alias)
if accounts:
break
if username:
# Federated account["username"] from AAD could contain mixed case
lowercase_username = username.lower()
accounts = [a for a in accounts
if a["username"].lower() == lowercase_username]
# Does not further filter by existing RTs here. It probably won't matter.
# Because in most cases Accounts and RTs co-exist.
# Even in the rare case when an RT is revoked and then removed,
# acquire_token_silent() would then yield no result,
# apps would fall back to other acquire methods. This is the standard pattern.
return accounts

def _find_msal_accounts(self, environment):
return [a for a in self.token_cache.find(
TokenCache.CredentialType.ACCOUNT, query={"environment": environment})
if a["authority_type"] in (
TokenCache.AuthorityType.ADFS, TokenCache.AuthorityType.MSSTS)]

def acquire_token_silent(
self,
scopes, # type: List[str]
Expand All @@ -275,19 +300,44 @@ def acquire_token_silent(
- None when cache lookup does not yield anything.
"""
assert isinstance(scopes, list), "Invalid parameter type"
the_authority = Authority(
authority,
verify=self.verify, proxies=self.proxies, timeout=self.timeout,
) if authority else self.authority

if authority:
warnings.warn("We haven't decided how/if this method will accept authority parameter")
# the_authority = Authority(
# authority,
# verify=self.verify, proxies=self.proxies, timeout=self.timeout,
# ) if authority else self.authority
result = self._acquire_token_silent(scopes, account, self.authority, **kwargs)
if result:
return result
for group in self.authority_groups:
if self.authority.instance in group:
for alias in group:
if alias != self.authority.instance:
the_authority = Authority(
"https://" + alias + "/" + self.authority.tenant,
validate_authority=False,
verify=self.verify, proxies=self.proxies,
timeout=self.timeout,)
result = self._acquire_token_silent(
scopes, account, the_authority, **kwargs)
if result:
return result

def _acquire_token_silent(
self,
scopes, # type: List[str]
account, # type: Optional[Account]
authority, # This can be different than self.authority
force_refresh=False, # type: Optional[boolean]
**kwargs):
if not force_refresh:
matches = self.token_cache.find(
self.token_cache.CredentialType.ACCESS_TOKEN,
target=scopes,
query={
"client_id": self.client_id,
"environment": the_authority.instance,
"realm": the_authority.tenant,
"environment": authority.instance,
"realm": authority.tenant,
"home_account_id": (account or {}).get("home_account_id"),
})
now = time.time()
Expand All @@ -301,26 +351,71 @@ def acquire_token_silent(
"token_type": "Bearer",
"expires_in": int(expires_in), # OAuth2 specs defines it as int
}
return self._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family(
authority, decorate_scope(scopes, self.client_id), account,
**kwargs)

def _acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family(
self, authority, scopes, account, **kwargs):
query = {
"environment": authority.instance,
"home_account_id": (account or {}).get("home_account_id"),
# "realm": authority.tenant, # AAD RTs are tenant-independent
}
apps = self.token_cache.find( # Use find(), rather than token_cache.get(...)
TokenCache.CredentialType.APP_METADATA, query={
"environment": authority.instance, "client_id": self.client_id})
app_metadata = apps[0] if apps else {}
if not app_metadata: # Meaning this app is now used for the first time.
# When/if we have a way to directly detect current app's family,
# we'll rewrite this block, to support multiple families.
# For now, we try existing RTs (*). If it works, we are in that family.
# (*) RTs of a different app/family are not supposed to be
# shared with or accessible by us in the first place.
at = self._acquire_token_silent_by_finding_specific_refresh_token(
authority, scopes,
dict(query, family_id="1"), # A hack, we have only 1 family for now
rt_remover=lambda rt_item: None, # NO-OP b/c RTs are likely not mine
break_condition=lambda response: # Break loop when app not in family
# Based on an AAD-only behavior mentioned in internal doc here
# https://msazure.visualstudio.com/One/_git/ESTS-Docs/pullrequest/1138595
"client_mismatch" in response.get("error_additional_info", []),
**kwargs)
if at:
return at
if app_metadata.get("family_id"): # Meaning this app belongs to this family
at = self._acquire_token_silent_by_finding_specific_refresh_token(
authority, scopes, dict(query, family_id=app_metadata["family_id"]),
**kwargs)
if at:
return at
# Either this app is an orphan, so we will naturally use its own RT;
# or all attempts above have failed, so we fall back to non-foci behavior.
return self._acquire_token_silent_by_finding_specific_refresh_token(
authority, scopes, dict(query, client_id=self.client_id), **kwargs)

def _acquire_token_silent_by_finding_specific_refresh_token(
self, authority, scopes, query,
rt_remover=None, break_condition=lambda response: False, **kwargs):
matches = self.token_cache.find(
self.token_cache.CredentialType.REFRESH_TOKEN,
# target=scopes, # AAD RTs are scope-independent
query={
"client_id": self.client_id,
"environment": the_authority.instance,
"home_account_id": (account or {}).get("home_account_id"),
# "realm": the_authority.tenant, # AAD RTs are tenant-independent
})
client = self._build_client(self.client_credential, the_authority)
query=query)
logger.debug("Found %d RTs matching %s", len(matches), query)
client = self._build_client(self.client_credential, authority)
for entry in matches:
logger.debug("Cache hit an RT")
logger.debug("Cache attempts an RT")
response = client.obtain_token_by_refresh_token(
entry, rt_getter=lambda token_item: token_item["secret"],
scope=decorate_scope(scopes, self.client_id))
on_removing_rt=rt_remover or self.token_cache.remove_rt,
scope=scopes,
**kwargs)
if "error" not in response:
return response
logger.debug(
"Refresh failed. {error}: {error_description}".format(**response))
if break_condition(response):
break


class PublicClientApplication(ClientApplication): # browser app or mobile app
Expand Down
11 changes: 8 additions & 3 deletions msal/oauth2cli/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,9 @@ def _obtain_token( # The verb "obtain" is influenced by OAUTH2 RFC 6749
data=None, # All relevant data, which will go into the http body
headers=None, # a dict to be sent as request headers
timeout=None,
post=None, # A callable to replace requests.post(), for testing.
# Such as: lambda url, **kwargs:
# Mock(status_code=200, json=Mock(return_value={}))
**kwargs # Relay all extra parameters to underlying requests
): # Returns the json object came from the OAUTH2 response
_data = {'client_id': self.client_id, 'grant_type': grant_type}
Expand All @@ -133,7 +136,7 @@ def _obtain_token( # The verb "obtain" is influenced by OAUTH2 RFC 6749
raise ValueError("token_endpoint not found in configuration")
_headers = {'Accept': 'application/json'}
_headers.update(headers or {})
resp = self.session.post(
resp = (post or self.session.post)(
self.configuration["token_endpoint"],
headers=_headers, params=params, data=_data, auth=auth,
timeout=timeout or self.timeout,
Expand Down Expand Up @@ -393,16 +396,18 @@ def _obtain_token(self, grant_type, params=None, data=None, *args, **kwargs):

def obtain_token_by_refresh_token(self, token_item, scope=None,
rt_getter=lambda token_item: token_item["refresh_token"],
on_removing_rt=None,
**kwargs):
# type: (Union[str, dict], Union[str, list, set, tuple], Callable) -> dict
"""This is an "overload" which accepts a refresh token item as a dict,
therefore this method can relay refresh_token item to event listeners.

:param refresh_token_item: A refresh token item came from storage
:param token_item: A refresh token item came from storage
:param scope: If omitted, is treated as equal to the scope originally
granted by the resource ownser,
according to https://tools.ietf.org/html/rfc6749#section-6
:param rt_getter: A callable used to extract the RT from token_item
:param on_removing_rt: If absent, fall back to the one defined in initialization
"""
if isinstance(token_item, str):
# Satisfy the L of SOLID, although we expect caller uses a dict
Expand All @@ -412,7 +417,7 @@ def obtain_token_by_refresh_token(self, token_item, scope=None,
resp = super(Client, self).obtain_token_by_refresh_token(
rt_getter(token_item), scope=scope, **kwargs)
if resp.get('error') == 'invalid_grant':
self.on_removing_rt(token_item) # Discard old RT
(on_removing_rt or self.on_removing_rt)(token_item) # Discard old RT
if 'refresh_token' in resp:
self.on_updating_rt(token_item, resp['refresh_token'])
return resp
Expand Down
33 changes: 26 additions & 7 deletions msal/token_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,11 @@ class CredentialType:
REFRESH_TOKEN = "RefreshToken"
ACCOUNT = "Account" # Not exactly a credential type, but we put it here
ID_TOKEN = "IdToken"
APP_METADATA = "AppMetadata"

class AuthorityType:
ADFS = "ADFS"
MSSTS = "MSSTS" # MSSTS means AAD v2 for both AAD & MSA

def __init__(self):
self._lock = threading.RLock()
Expand Down Expand Up @@ -118,8 +123,8 @@ def add(self, event, now=None):
"oid", decoded_id_token.get("sub")),
"username": decoded_id_token.get("preferred_username"),
"authority_type":
"ADFS" if realm == "adfs"
else "MSSTS", # MSSTS means AAD v2 for both AAD & MSA
self.AuthorityType.ADFS if realm == "adfs"
else self.AuthorityType.MSSTS,
# "client_info": response.get("client_info"), # Optional
}

Expand Down Expand Up @@ -158,6 +163,17 @@ def add(self, event, now=None):
rt["family_id"] = response["foci"]
self._cache.setdefault(self.CredentialType.REFRESH_TOKEN, {})[key] = rt

key = self._build_appmetadata_key(environment, event.get("client_id"))
self._cache.setdefault(self.CredentialType.APP_METADATA, {})[key] = {
"client_id": event.get("client_id"),
"environment": environment,
"family_id": response.get("foci"), # None is also valid
}

@staticmethod
def _build_appmetadata_key(environment, client_id):
return "appmetadata-{}-{}".format(environment or "", client_id or "")

@classmethod
def _build_rt_key(
cls,
Expand Down Expand Up @@ -192,21 +208,24 @@ class SerializableTokenCache(TokenCache):
Depending on your need,
the following simple recipe for file-based persistence may be sufficient::

import atexit
cache = SerializableTokenCache()
cache.deserialize(open("my_cache.bin", "rb").read())
import os, atexit, msal
cache = msal.SerializableTokenCache()
if os.path.exists("my_cache.bin"):
cache.deserialize(open("my_cache.bin", "r").read())
atexit.register(lambda:
open("my_cache.bin", "wb").write(cache.serialize())
open("my_cache.bin", "w").write(cache.serialize())
# Hint: The following optional line persists only when state changed
if cache.has_state_changed else None
)
app = ClientApplication(..., token_cache=cache)
app = msal.ClientApplication(..., token_cache=cache)
...

:var bool has_state_changed:
Indicates whether the cache state has changed since last
:func:`~serialize` or :func:`~deserialize` call.
"""
has_state_changed = False

def add(self, event, **kwargs):
super(SerializableTokenCache, self).add(event, **kwargs)
self.has_state_changed = True
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
.
mock; python_version < '3.3'
2 changes: 1 addition & 1 deletion sample/client_credential_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
{
"authority": "https://login.microsoftonline.com/organizations",
"client_id": "your_client_id",
"scope": ["https://graph.microsoft.com/.default"],
"secret": "This is a sample only. You better NOT persist your password."
"scope": ["https://graph.microsoft.com/.default"]
}

You can then run this sample with a JSON configuration file:
Expand Down
4 changes: 3 additions & 1 deletion sample/username_password_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,4 +57,6 @@
print(result.get("error"))
print(result.get("error_description"))
print(result.get("correlation_id")) # You may need this when reporting a bug

if 65001 in result.get("error_codes", []): # Not mean to be coded programatically, but...
# AAD requires user consent for U/P flow
print("Visit this to consent:", app.get_authorization_request_url(scope))
Loading