Skip to content

Commit

Permalink
SharedTokenCacheCredential uses MSAL when given an AuthenticationReco…
Browse files Browse the repository at this point in the history
…rd (#13490)
  • Loading branch information
chlowell committed Sep 4, 2020
1 parent f01aca1 commit 635b820
Show file tree
Hide file tree
Showing 7 changed files with 141 additions and 176 deletions.
4 changes: 4 additions & 0 deletions sdk/identity/azure-identity/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@
authentication with a user-specified application having a custom redirect URI
([#13344](https://github.com/Azure/azure-sdk-for-python/issues/13344))

### Breaking changes
- Removed `authentication_record` keyword argument from the async
`SharedTokenCacheCredential`, i.e. `azure.identity.aio.SharedTokenCacheCredential`

## 1.4.0 (2020-08-10)
### Added
- `DefaultAzureCredential` uses the value of environment variable
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,18 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
import time

from msal.application import PublicClientApplication

from azure.core.credentials import AccessToken
from azure.core.exceptions import ClientAuthenticationError

from .. import CredentialUnavailableError
from .._constants import AZURE_CLI_CLIENT_ID
from .._internal import AadClient
from .._internal.decorators import log_get_token
from .._internal.decorators import log_get_token, wrap_exceptions
from .._internal.msal_client import MsalClient
from .._internal.shared_token_cache import NO_TOKEN, SharedTokenCacheBase

try:
Expand All @@ -15,7 +23,8 @@

if TYPE_CHECKING:
# pylint:disable=unused-import,ungrouped-imports
from typing import Any
from typing import Any, Optional
from .. import AuthenticationRecord
from .._internal import AadClientBase


Expand All @@ -37,6 +46,20 @@ class SharedTokenCacheCredential(SharedTokenCacheBase):
is unavailable. Defaults to False.
"""

def __init__(self, username=None, **kwargs):
# type: (Optional[str], **Any) -> None

self._auth_record = kwargs.pop("authentication_record", None) # type: Optional[AuthenticationRecord]
if self._auth_record:
# authenticate in the tenant that produced the record unless "tenant_id" specifies another
self._tenant_id = kwargs.pop("tenant_id", None) or self._auth_record.tenant_id
self._cache = kwargs.pop("_cache", None)
self._app = None
self._client_kwargs = kwargs
self._initialized = False
else:
super(SharedTokenCacheCredential, self).__init__(username=username, **kwargs)

@log_get_token("SharedTokenCacheCredential")
def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument
# type (*str, **Any) -> AccessToken
Expand All @@ -51,18 +74,20 @@ def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument
:raises ~azure.identity.CredentialUnavailableError: the cache is unavailable or contains insufficient user
information
:raises ~azure.core.exceptions.ClientAuthenticationError: authentication failed. The error's ``message``
attribute gives a reason. Any error response from Azure Active Directory is available as the error's
``response`` attribute.
attribute gives a reason.
"""
if not scopes:
raise ValueError("'get_token' requires at least one scope")

if not self._initialized:
self._initialize()

if not self._client:
if not self._cache:
raise CredentialUnavailableError(message="Shared token cache unavailable")

if self._auth_record:
return self._acquire_token_silent(*scopes)

account = self._get_account(self._username, self._tenant_id)

token = self._get_cached_access_token(scopes, account)
Expand All @@ -79,3 +104,54 @@ def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument
def _get_auth_client(self, **kwargs):
# type: (**Any) -> AadClientBase
return AadClient(client_id=AZURE_CLI_CLIENT_ID, **kwargs)

def _initialize(self):
if self._initialized:
return

if not self._auth_record:
super(SharedTokenCacheCredential, self)._initialize()
return

self._load_cache()
if self._cache:
self._app = PublicClientApplication(
client_id=self._auth_record.client_id,
authority="https://{}/{}".format(self._auth_record.authority, self._tenant_id),
token_cache=self._cache,
http_client=MsalClient(**self._client_kwargs),
)

self._initialized = True

@wrap_exceptions
def _acquire_token_silent(self, *scopes, **kwargs):
# type: (*str, **Any) -> AccessToken
"""Silently acquire a token from MSAL. Requires an AuthenticationRecord."""

result = None

accounts_for_user = self._app.get_accounts(username=self._auth_record.username)
if not accounts_for_user:
raise CredentialUnavailableError("The cache contains no account matching the given AuthenticationRecord.")

for account in accounts_for_user:
if account.get("home_account_id") != self._auth_record.home_account_id:
continue

now = int(time.time())
result = self._app.acquire_token_silent_with_error(list(scopes), account=account, **kwargs)
if result and "access_token" in result and "expires_in" in result:
return AccessToken(result["access_token"], now + int(result["expires_in"]))

# if we get this far, the cache contained a matching account but MSAL failed to authenticate it silently
if result:
# cache contains a matching refresh token but STS returned an error response when MSAL tried to use it
message = "Token acquisition failed"
details = result.get("error_description") or result.get("error")
if details:
message += ": {}".format(details)
raise ClientAuthenticationError(message=message)

# cache doesn't contain a matching refresh (or access) token
raise CredentialUnavailableError(message=NO_TOKEN.format(self._auth_record.username))
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
# pylint:disable=unused-import,ungrouped-imports
from typing import Any, Iterable, List, Mapping, Optional
from .._internal import AadClientBase
from azure.identity import AuthenticationRecord

CacheItem = Mapping[str, str]

Expand Down Expand Up @@ -89,46 +88,36 @@ def _filtered_accounts(accounts, username=None, tenant_id=None):
class SharedTokenCacheBase(ABC):
def __init__(self, username=None, **kwargs): # pylint:disable=unused-argument
# type: (Optional[str], **Any) -> None

self._auth_record = kwargs.pop("authentication_record", None) # type: Optional[AuthenticationRecord]
if self._auth_record:
# authenticate in the tenant that produced the record unless 'tenant_id' specifies another
authenticating_tenant = kwargs.pop("tenant_id", None) or self._auth_record.tenant_id
self._tenant_id = self._auth_record.tenant_id
self._authority = self._auth_record.authority
self._username = self._auth_record.username
self._environment_aliases = frozenset((self._authority,))
else:
authenticating_tenant = "organizations"
authority = kwargs.pop("authority", None)
self._authority = normalize_authority(authority) if authority else get_default_authority()
environment = urlparse(self._authority).netloc
self._environment_aliases = KNOWN_ALIASES.get(environment) or frozenset((environment,))
self._username = username
self._tenant_id = kwargs.pop("tenant_id", None)

authority = kwargs.pop("authority", None)
self._authority = normalize_authority(authority) if authority else get_default_authority()
environment = urlparse(self._authority).netloc
self._environment_aliases = KNOWN_ALIASES.get(environment) or frozenset((environment,))
self._username = username
self._tenant_id = kwargs.pop("tenant_id", None)
self._cache = kwargs.pop("_cache", None)
self._client = None # type: Optional[AadClientBase]
self._client_kwargs = kwargs
self._client_kwargs["tenant_id"] = authenticating_tenant
self._client_kwargs["tenant_id"] = "organizations"
self._initialized = False

def _initialize(self):
if self._initialized:
return

self._load_cache()
if self._cache:
self._client = self._get_auth_client(authority=self._authority, cache=self._cache, **self._client_kwargs)

self._initialized = True

def _load_cache(self):
if not self._cache and self.supported():
allow_unencrypted = self._client_kwargs.get("allow_unencrypted_cache", False)
try:
self._cache = load_user_cache(allow_unencrypted)
except Exception: # pylint:disable=broad-except
pass

if self._cache:
self._client = self._get_auth_client(authority=self._authority, cache=self._cache, **self._client_kwargs)

self._initialized = True

@abc.abstractmethod
def _get_auth_client(self, **kwargs):
# type: (**Any) -> AadClientBase
Expand Down Expand Up @@ -176,14 +165,6 @@ def _get_account(self, username=None, tenant_id=None):
# cache is empty or contains no refresh token -> user needs to sign in
raise CredentialUnavailableError(message=NO_ACCOUNTS)

if self._auth_record:
for account in accounts:
if account.get("home_account_id") == self._auth_record.home_account_id:
return account
raise CredentialUnavailableError(
message="The cache contains no account matching the given AuthenticationRecord."
)

filtered_accounts = _filtered_accounts(accounts, username, tenant_id)
if len(filtered_accounts) == 1:
return filtered_accounts[0]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@ class SharedTokenCacheCredential(SharedTokenCacheBase, AsyncContextManager):
defines authorities for other clouds.
:keyword str tenant_id: an Azure Active Directory tenant ID. Used to select an account when the cache contains
tokens for multiple identities.
:keyword AuthenticationRecord authentication_record: an authentication record returned by a user credential such as
:class:`DeviceCodeCredential` or :class:`InteractiveBrowserCredential`
:keyword bool allow_unencrypted_cache: if True, the credential will fall back to a plaintext cache when encryption
is unavailable. Defaults to False.
"""
Expand Down
2 changes: 1 addition & 1 deletion sdk/identity/azure-identity/tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def validate_request(request, **_):
try:
expected_request, response = next(sessions)
except StopIteration:
assert False, "unexpected request: {}".format(request)
assert False, "unexpected request: {} {}".format(request.method, request.url)
expected_request.assert_matches(request)
return response

Expand Down
58 changes: 40 additions & 18 deletions sdk/identity/azure-identity/tests/test_shared_cache_credential.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,15 @@
except ImportError: # python < 3.3
from mock import Mock, patch # type: ignore

from helpers import build_aad_response, build_id_token, mock_response, Request, validating_transport
from helpers import (
build_aad_response,
build_id_token,
get_discovery_response,
mock_response,
msal_validating_transport,
Request,
validating_transport,
)


def test_supported():
Expand Down Expand Up @@ -513,8 +521,13 @@ def test_authority_environment_variable():

def test_authentication_record_empty_cache():
record = AuthenticationRecord("tenant_id", "client_id", "authority", "home_account_id", "username")
transport = Mock(side_effect=Exception("the credential shouldn't send a request"))
credential = SharedTokenCacheCredential(authentication_record=record, transport=transport, _cache=TokenCache())

def send(request, **_):
# expecting only MSAL discovery requests
assert request.method == 'GET'
return get_discovery_response()

credential = SharedTokenCacheCredential(authentication_record=record, transport=Mock(send=send), _cache=TokenCache())

with pytest.raises(CredentialUnavailableError):
credential.get_token("scope")
Expand All @@ -529,13 +542,17 @@ def test_authentication_record_no_match():
username = "me"
record = AuthenticationRecord(tenant_id, client_id, authority, home_account_id, username)

transport = Mock(side_effect=Exception("the credential shouldn't send a request"))
def send(request, **_):
# expecting only MSAL discovery requests
assert request.method == 'GET'
return get_discovery_response()

cache = populated_cache(
get_account_event(
"not-" + username, "not-" + object_id, "different-" + tenant_id, client_id="not-" + client_id,
),
)
credential = SharedTokenCacheCredential(authentication_record=record, transport=transport, _cache=cache)
credential = SharedTokenCacheCredential(authentication_record=record, transport=Mock(send=send), _cache=cache)

with pytest.raises(CredentialUnavailableError):
credential.get_token("scope")
Expand All @@ -557,7 +574,8 @@ def test_authentication_record():
)
cache = populated_cache(account)

transport = validating_transport(
transport = msal_validating_transport(
endpoint="https://{}/{}".format(authority, tenant_id),
requests=[Request(authority=authority, required_data={"refresh_token": expected_refresh_token})],
responses=[mock_response(json_payload=build_aad_response(access_token=expected_access_token))],
)
Expand Down Expand Up @@ -593,7 +611,8 @@ def test_auth_record_multiple_accounts_for_username():
),
)

transport = validating_transport(
transport = msal_validating_transport(
endpoint="https://{}/{}".format(authority, tenant_id),
requests=[Request(authority=authority, required_data={"refresh_token": expected_refresh_token})],
responses=[mock_response(json_payload=build_aad_response(access_token=expected_access_token))],
)
Expand Down Expand Up @@ -741,19 +760,22 @@ def test_authentication_record_authenticating_tenant():
"""when given a record and 'tenant_id', the credential should authenticate in the latter"""

expected_tenant_id = "tenant-id"
record = AuthenticationRecord("not- " + expected_tenant_id, "...", "...", "...", "...")
record = AuthenticationRecord("not- " + expected_tenant_id, "...", "localhost", "...", "...")

with patch.object(SharedTokenCacheCredential, "_get_auth_client") as get_auth_client:
credential = SharedTokenCacheCredential(
authentication_record=record, _cache=TokenCache(), tenant_id=expected_tenant_id
)
with pytest.raises(CredentialUnavailableError):
# this raises because the cache is empty
credential.get_token("scope")
def mock_send(request, **_):
if not request.body:
return get_discovery_response()
assert request.url.startswith("https://localhost/" + expected_tenant_id)
return mock_response(json_payload=build_aad_response(access_token="*"))

transport = Mock(send=Mock(wraps=mock_send))
credential = SharedTokenCacheCredential(
authentication_record=record, _cache=TokenCache(), tenant_id=expected_tenant_id, transport=transport
)
with pytest.raises(CredentialUnavailableError):
credential.get_token("scope") # this raises because the cache is empty

assert get_auth_client.call_count == 1
_, kwargs = get_auth_client.call_args
assert kwargs["tenant_id"] == expected_tenant_id
assert transport.send.called


def get_account_event(
Expand Down

0 comments on commit 635b820

Please sign in to comment.