diff --git a/msal/application.py b/msal/application.py index 610b41f8..6d4eeb58 100644 --- a/msal/application.py +++ b/msal/application.py @@ -21,7 +21,7 @@ # The __init__.py will import this. Not the other way around. -__version__ = "1.6.0" +__version__ = "1.7.0" logger = logging.getLogger(__name__) @@ -107,6 +107,7 @@ class ClientApplication(object): ACQUIRE_TOKEN_BY_DEVICE_FLOW_ID = "622" ACQUIRE_TOKEN_FOR_CLIENT_ID = "730" ACQUIRE_TOKEN_BY_AUTHORIZATION_CODE_ID = "832" + ACQUIRE_TOKEN_INTERACTIVE = "169" GET_ACCOUNTS_ID = "902" REMOVE_ACCOUNT_ID = "903" @@ -300,6 +301,78 @@ def _build_client(self, client_credential, authority): on_removing_rt=self.token_cache.remove_rt, on_updating_rt=self.token_cache.update_rt) + def initiate_auth_code_flow( + self, + scopes, # type: list[str] + redirect_uri=None, + state=None, # Recommended by OAuth2 for CSRF protection + prompt=None, + login_hint=None, # type: Optional[str] + domain_hint=None, # type: Optional[str] + claims_challenge=None, + ): + """Initiate an auth code flow. + + Later when the response reaches your redirect_uri, + you can use :func:`~acquire_token_by_auth_code_flow()` + to complete the authentication/authorization. + + :param list scope: + It is a list of case-sensitive strings. + :param str redirect_uri: + Optional. If not specified, server will use the pre-registered one. + :param str state: + An opaque value used by the client to + maintain state between the request and callback. + If absent, this library will automatically generate one internally. + :param str prompt: + By default, no prompt value will be sent, not even "none". + You will have to specify a value explicitly. + Its valid values are defined in Open ID Connect specs + https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest + :param str login_hint: + Optional. Identifier of the user. Generally a User Principal Name (UPN). + :param domain_hint: + Can be one of "consumers" or "organizations" or your tenant domain "contoso.com". + If included, it will skip the email-based discovery process that user goes + through on the sign-in page, leading to a slightly more streamlined user experience. + More information on possible values + `here `_ and + `here `_. + + :return: + The auth code flow. It is a dict in this form:: + + { + "auth_uri": "https://...", // Guide user to visit this + "state": "...", // You may choose to verify it by yourself, + // or just let acquire_token_by_auth_code_flow() + // do that for you. + "...": "...", // Everything else are reserved and internal + } + + The caller is expected to:: + + 1. somehow store this content, typically inside the current session, + 2. guide the end user (i.e. resource owner) to visit that auth_uri, + 3. and then relay this dict and subsequent auth response to + :func:`~acquire_token_by_auth_code_flow()`. + """ + client = Client( + {"authorization_endpoint": self.authority.authorization_endpoint}, + self.client_id, + http_client=self.http_client) + flow = client.initiate_auth_code_flow( + redirect_uri=redirect_uri, state=state, login_hint=login_hint, + prompt=prompt, + scope=decorate_scope(scopes, self.client_id), + domain_hint=domain_hint, + claims=_merge_claims_challenge_and_capabilities( + self._client_capabilities, claims_challenge), + ) + flow["claims_challenge"] = claims_challenge + return flow + def get_authorization_request_url( self, scopes, # type: list[str] @@ -386,6 +459,73 @@ def get_authorization_request_url( self._client_capabilities, claims_challenge), ) + def acquire_token_by_auth_code_flow( + self, auth_code_flow, auth_response, scopes=None, **kwargs): + """Validate the auth response being redirected back, and obtain tokens. + + It automatically provides nonce protection. + + :param dict auth_code_flow: + The same dict returned by :func:`~initiate_auth_code_flow()`. + :param dict auth_response: + A dict of the query string received from auth server. + :param list[str] scopes: + Scopes requested to access a protected API (a resource). + + Most of the time, you can leave it empty. + + If you requested user consent for multiple resources, here you will + need to provide a subset of what you required in + :func:`~initiate_auth_code_flow()`. + + OAuth2 was designed mostly for singleton services, + where tokens are always meant for the same resource and the only + changes are in the scopes. + In AAD, tokens can be issued for multiple 3rd party resources. + You can ask authorization code for multiple resources, + but when you redeem it, the token is for only one intended + recipient, called audience. + So the developer need to specify a scope so that we can restrict the + token to be issued for the corresponding audience. + + :return: + * A dict containing "access_token" and/or "id_token", among others, + depends on what scope was used. + (See https://tools.ietf.org/html/rfc6749#section-5.1) + * A dict containing "error", optionally "error_description", "error_uri". + (It is either `this `_ + or `that `_) + * Most client-side data error would result in ValueError exception. + So the usage pattern could be without any protocol details:: + + def authorize(): # A controller in a web app + try: + result = msal_app.acquire_token_by_auth_code_flow( + session.get("flow", {}), request.args) + if "error" in result: + return render_template("error.html", result) + use(result) # Token(s) are available in result and cache + except ValueError: # Usually caused by CSRF + pass # Simply ignore them + return redirect(url_for("index")) + """ + self._validate_ssh_cert_input_data(kwargs.get("data", {})) + return self.client.obtain_token_by_auth_code_flow( + auth_code_flow, + auth_response, + scope=decorate_scope(scopes, self.client_id) if scopes else None, + headers={ + CLIENT_REQUEST_ID: _get_new_correlation_id(), + CLIENT_CURRENT_TELEMETRY: _build_current_telemetry_request_header( + self.ACQUIRE_TOKEN_BY_AUTHORIZATION_CODE_ID), + }, + data=dict( + kwargs.pop("data", {}), + claims=_merge_claims_challenge_and_capabilities( + self._client_capabilities, + auth_code_flow.pop("claims_challenge", None))), + **kwargs) + def acquire_token_by_authorization_code( self, code, @@ -858,6 +998,80 @@ def __init__(self, client_id, client_credential=None, **kwargs): super(PublicClientApplication, self).__init__( client_id, client_credential=None, **kwargs) + def acquire_token_interactive( + self, + scopes, # type: list[str] + prompt=None, + login_hint=None, # type: Optional[str] + domain_hint=None, # type: Optional[str] + claims_challenge=None, + timeout=None, + port=None, + **kwargs): + """Acquire token interactively i.e. via a local browser. + + :param list scope: + It is a list of case-sensitive strings. + :param str prompt: + By default, no prompt value will be sent, not even "none". + You will have to specify a value explicitly. + Its valid values are defined in Open ID Connect specs + https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest + :param str login_hint: + Optional. Identifier of the user. Generally a User Principal Name (UPN). + :param domain_hint: + Can be one of "consumers" or "organizations" or your tenant domain "contoso.com". + If included, it will skip the email-based discovery process that user goes + through on the sign-in page, leading to a slightly more streamlined user experience. + More information on possible values + `here `_ and + `here `_. + + :param claims_challenge: + The claims_challenge parameter requests specific claims requested by the resource provider + in the form of a claims_challenge directive in the www-authenticate header to be + returned from the UserInfo Endpoint and/or in the ID Token and/or Access Token. + It is a string of a JSON object which contains lists of claims being requested from these locations. + + :param int timeout: + This method will block the current thread. + This parameter specifies the timeout value in seconds. + Default value ``None`` means wait indefinitely. + + :param int port: + The port to be used to listen to an incoming auth response. + By default we will use a system-allocated port. + (The rest of the redirect_uri is hard coded as ``http://localhost``.) + + :return: + - A dict containing no "error" key, + and typically contains an "access_token" key, + if cache lookup succeeded. + - A dict containing an "error" key, when token refresh failed. + """ + self._validate_ssh_cert_input_data(kwargs.get("data", {})) + claims = _merge_claims_challenge_and_capabilities( + self._client_capabilities, claims_challenge) + return self.client.obtain_token_by_browser( + scope=decorate_scope(scopes, self.client_id) if scopes else None, + redirect_uri="http://localhost:{port}".format( + # Hardcode the host, for now. AAD portal rejects 127.0.0.1 anyway + port=port or 0), + prompt=prompt, + login_hint=login_hint, + timeout=timeout, + auth_params={ + "claims": claims, + "domain_hint": domain_hint, + }, + data=dict(kwargs.pop("data", {}), claims=claims), + headers={ + CLIENT_REQUEST_ID: _get_new_correlation_id(), + CLIENT_CURRENT_TELEMETRY: _build_current_telemetry_request_header( + self.ACQUIRE_TOKEN_INTERACTIVE), + }, + **kwargs) + def initiate_device_flow(self, scopes=None, **kwargs): """Initiate a Device Flow instance, which will be used in :func:`~acquire_token_by_device_flow`. diff --git a/msal/mex.py b/msal/mex.py index a84f320b..edecba37 100644 --- a/msal/mex.py +++ b/msal/mex.py @@ -33,16 +33,25 @@ from xml.etree import cElementTree as ET except ImportError: from xml.etree import ElementTree as ET +import logging +logger = logging.getLogger(__name__) + def _xpath_of_root(route_to_leaf): # Construct an xpath suitable to find a root node which has a specified leaf return '/'.join(route_to_leaf + ['..'] * (len(route_to_leaf)-1)) def send_request(mex_endpoint, http_client, **kwargs): - mex_document = http_client.get(mex_endpoint, **kwargs).text - return Mex(mex_document).get_wstrust_username_password_endpoint() + mex_resp = http_client.get(mex_endpoint, **kwargs) + mex_resp.raise_for_status() + try: + return Mex(mex_resp.text).get_wstrust_username_password_endpoint() + except ET.ParseError: + logger.exception( + "Malformed MEX document: %s, %s", mex_resp.status_code, mex_resp.text) + raise class Mex(object): diff --git a/msal/oauth2cli/__init__.py b/msal/oauth2cli/__init__.py index b8941361..60bf2595 100644 --- a/msal/oauth2cli/__init__.py +++ b/msal/oauth2cli/__init__.py @@ -1,6 +1,7 @@ -__version__ = "0.3.0" +__version__ = "0.4.0" from .oidc import Client from .assertion import JwtAssertionCreator from .assertion import JwtSigner # Obsolete. For backward compatibility. +from .authcode import AuthCodeReceiver diff --git a/msal/oauth2cli/assertion.py b/msal/oauth2cli/assertion.py index e84400df..f8d3f16f 100644 --- a/msal/oauth2cli/assertion.py +++ b/msal/oauth2cli/assertion.py @@ -63,7 +63,11 @@ def __init__(self, key, algorithm, sha1_thumbprint=None, headers=None): Args: - key (str): The key for signing, e.g. a base64 encoded private key. + key (str): + An unencrypted private key for signing, in a base64 encoded string. + It can also be a cryptography ``PrivateKey`` object, + which is how you can work with a previously-encrypted key. + See also https://github.com/jpadilla/pyjwt/pull/525 algorithm (str): "RS256", etc.. See https://pyjwt.readthedocs.io/en/latest/algorithms.html RSA and ECDSA algorithms require "pip install cryptography". diff --git a/msal/oauth2cli/authcode.py b/msal/oauth2cli/authcode.py index 65eacfbf..6f45c13d 100644 --- a/msal/oauth2cli/authcode.py +++ b/msal/oauth2cli/authcode.py @@ -5,10 +5,10 @@ It optionally opens a browser window to guide a human user to manually login. After obtaining an auth code, the web server will automatically shut down. """ - -import argparse import webbrowser import logging +import socket +from string import Template try: # Python 3 from http.server import HTTPServer, BaseHTTPRequestHandler @@ -18,45 +18,23 @@ from urlparse import urlparse, parse_qs from urllib import urlencode -from .oauth2 import Client - logger = logging.getLogger(__name__) -def obtain_auth_code(listen_port, auth_uri=None): - """This function will start a web server listening on http://localhost:port - and then you need to open a browser on this device and visit your auth_uri. - When interaction finishes, this function will return the auth code, - and then shut down the local web server. - - :param listen_port: - The local web server will listen at http://localhost: - Unless the authorization server supports dynamic port, - you need to use the same port when you register with your app. - :param auth_uri: If provided, this function will try to open a local browser. - :return: Hang indefinitely, until it receives and then return the auth code. - """ - exit_hint = "Visit http://localhost:{p}?code=exit to abort".format(p=listen_port) - logger.warning(exit_hint) - if auth_uri: - page = "http://localhost:{p}?{q}".format(p=listen_port, q=urlencode({ - "text": "Open this link to sign in. You may use incognito window", - "link": auth_uri, - "exit_hint": exit_hint, - })) - browse(page) - server = HTTPServer(("", int(listen_port)), AuthCodeReceiver) - try: - server.authcode = None - while not server.authcode: - # Derived from - # https://docs.python.org/2/library/basehttpserver.html#more-examples - server.handle_request() - return server.authcode - finally: - server.server_close() -def browse(auth_uri): +def obtain_auth_code(listen_port, auth_uri=None): # Historically only used in testing + with AuthCodeReceiver(port=listen_port) as receiver: + return receiver.get_auth_response( + auth_uri=auth_uri, + welcome_template=""" + Open this link to Sign In + (You may want to use incognito window) +
Abort + """, + ).get("code") + + +def _browse(auth_uri): controller = webbrowser.get() # Get a default controller # Some Linux Distro does not setup default browser properly, # so we try to explicitly use some popular browser, if we found any. @@ -69,23 +47,28 @@ def browse(auth_uri): logger.info("Please open a browser on THIS device to visit: %s" % auth_uri) controller.open(auth_uri) -class AuthCodeReceiver(BaseHTTPRequestHandler): + +def _qs2kv(qs): + """Flatten parse_qs()'s single-item lists into the item itself""" + return {k: v[0] if isinstance(v, list) and len(v) == 1 else v + for k, v in qs.items()} + + +class _AuthCodeHandler(BaseHTTPRequestHandler): def do_GET(self): # For flexibility, we choose to not check self.path matching redirect_uri #assert self.path.startswith('/THE_PATH_REGISTERED_BY_THE_APP') qs = parse_qs(urlparse(self.path).query) - if qs.get('code'): # Then store it into the server instance - ac = self.server.authcode = qs['code'][0] - self._send_full_response('Authcode:\n{}'.format(ac)) - # NOTE: Don't do self.server.shutdown() here. It'll halt the server. - elif qs.get('text') and qs.get('link'): # Then display a landing page + if qs.get('code') or qs.get("error"): # So, it is an auth response + self.server.auth_response = _qs2kv(qs) + logger.debug("Got auth response: %s", self.server.auth_response) + template = (self.server.success_template + if "code" in qs else self.server.error_template) self._send_full_response( - '{text}
{exit_hint}'.format( - link=qs['link'][0], text=qs['text'][0], - exit_hint=qs.get("exit_hint", [''])[0], - )) + template.safe_substitute(**self.server.auth_response)) + # NOTE: Don't do self.server.shutdown() here. It'll halt the server. else: - self._send_full_response("This web service serves your redirect_uri") + self._send_full_response(self.server.welcome_page) def _send_full_response(self, body, is_ok=True): self.send_response(200 if is_ok else 400) @@ -94,18 +77,153 @@ def _send_full_response(self, body, is_ok=True): self.end_headers() self.wfile.write(body.encode("utf-8")) + def log_message(self, format, *args): + logger.debug(format, *args) # To override the default log-to-stderr behavior + + +class _AuthCodeHttpServer(HTTPServer): + def handle_timeout(self): + # It will be triggered when no request comes in self.timeout seconds. + # See https://docs.python.org/3/library/socketserver.html#socketserver.BaseServer.handle_timeout + raise RuntimeError("Timeout. No auth response arrived.") # Terminates this server + # We choose to not call self.server_close() here, + # because it would cause a socket.error exception in handle_request(), + # and likely end up the server being server_close() twice. + + +class _AuthCodeHttpServer6(_AuthCodeHttpServer): + address_family = socket.AF_INET6 + + +class AuthCodeReceiver(object): + # This class has (rather than is) an _AuthCodeHttpServer, so it does not leak API + def __init__(self, port=None): + """Create a Receiver waiting for incoming auth response. + + :param port: + The local web server will listen at http://...: + You need to use the same port when you register with your app. + If your Identity Provider supports dynamic port, you can use port=0 here. + Port 0 means to use an arbitrary unused port, per this official example: + https://docs.python.org/2.7/library/socketserver.html#asynchronous-mixins + """ + address = "127.0.0.1" # Hardcode, for now, Not sure what to expose, yet. + # Per RFC 8252 (https://tools.ietf.org/html/rfc8252#section-8.3): + # * Clients should listen on the loopback network interface only. + # (It is not recommended to use "" shortcut to bind all addr.) + # * the use of localhost is NOT RECOMMENDED. + # (Use) the loopback IP literal + # rather than localhost avoids inadvertently listening on network + # interfaces other than the loopback interface. + # Note: + # When this server physically listens to a specific IP (as it should), + # you will still be able to specify your redirect_uri using either + # IP (e.g. 127.0.0.1) or localhost, whichever matches your registration. + Server = _AuthCodeHttpServer6 if ":" in address else _AuthCodeHttpServer + # TODO: But, it would treat "localhost" or "" as IPv4. + # If pressed, we might just expose a family parameter to caller. + self._server = Server((address, port or 0), _AuthCodeHandler) + + def get_port(self): + """The port this server actually listening to""" + # https://docs.python.org/2.7/library/socketserver.html#SocketServer.BaseServer.server_address + return self._server.server_address[1] + + def get_auth_response(self, auth_uri=None, timeout=None, state=None, + welcome_template=None, success_template=None, error_template=None): + """Wait and return the auth response, or None when timeout. + + :param str auth_uri: + If provided, this function will try to open a local browser. + :param int timeout: In seconds. None means wait indefinitely. + :param str state: + You may provide the state you used in auth_url, + then we will use it to validate incoming response. + :param str welcome_template: + If provided, your end user will see it instead of the auth_uri. + When present, it shall be a plaintext or html template following + `Python Template string syntax `_, + and include some of these placeholders: $auth_uri and $abort_uri. + :param str success_template: + The page will be displayed when authentication was largely successful. + Placeholders can be any of these: + https://tools.ietf.org/html/rfc6749#section-5.1 + :param str error_template: + The page will be displayed when authentication encountered error. + Placeholders can be any of these: + https://tools.ietf.org/html/rfc6749#section-5.2 + :return: + The auth response of the first leg of Auth Code flow, + typically {"code": "...", "state": "..."} or {"error": "...", ...} + See https://tools.ietf.org/html/rfc6749#section-4.1.2 + and https://openid.net/specs/openid-connect-core-1_0.html#AuthResponse + Returns None when the state was mismatched, or when timeout occurred. + """ + welcome_uri = "http://localhost:{p}".format(p=self.get_port()) + abort_uri = "{loc}?error=abort".format(loc=welcome_uri) + logger.debug("Abort by visit %s", abort_uri) + self._server.welcome_page = Template(welcome_template or "").safe_substitute( + auth_uri=auth_uri, abort_uri=abort_uri) + if auth_uri: + _browse(welcome_uri if welcome_template else auth_uri) + self._server.success_template = Template(success_template or + "Authentication completed. You can close this window now.") + self._server.error_template = Template(error_template or + "Authentication failed. $error: $error_description. ($error_uri)") + + self._server.timeout = timeout # Otherwise its handle_timeout() won't work + self._server.auth_response = {} # Shared with _AuthCodeHandler + while True: + # Derived from + # https://docs.python.org/2/library/basehttpserver.html#more-examples + self._server.handle_request() + if self._server.auth_response: + if state and state != self._server.auth_response.get("state"): + logger.debug("State mismatch. Ignoring this noise.") + else: + break + return self._server.auth_response + + def close(self): + """Either call this eventually; or use the entire class as context manager""" + self._server.server_close() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() +# Note: Manually use or test this module by: +# python -m path.to.this.file -h if __name__ == '__main__': + import argparse, json + from .oauth2 import Client logging.basicConfig(level=logging.INFO) p = parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter, description=__doc__ + "The auth code received will be shown at stdout.") - p.add_argument('endpoint', - help="The auth endpoint for your app. For example: " - "https://login.microsoftonline.com/your_tenant/oauth2/authorize") + p.add_argument( + '--endpoint', help="The auth endpoint for your app.", + default="https://login.microsoftonline.com/common/oauth2/v2.0/authorize") p.add_argument('client_id', help="The client_id of your application") - p.add_argument('redirect_port', type=int, help="The port in redirect_uri") + p.add_argument('--port', type=int, default=0, help="The port in redirect_uri") + p.add_argument('--host', default="127.0.0.1", help="The host of redirect_uri") + p.add_argument('--scope', default=None, help="The scope list") args = parser.parse_args() - client = Client(args.client_id, authorization_endpoint=args.endpoint) - auth_uri = client.build_auth_request_uri("code") - print(obtain_auth_code(args.redirect_port, auth_uri)) + client = Client({"authorization_endpoint": args.endpoint}, args.client_id) + with AuthCodeReceiver(port=args.port) as receiver: + flow = client.initiate_auth_code_flow( + scope=args.scope.split() if args.scope else None, + redirect_uri="http://{h}:{p}".format(h=args.host, p=receiver.get_port()), + ) + print(json.dumps(receiver.get_auth_response( + auth_uri=flow["auth_uri"], + welcome_template= + "Sign In, or Abort= 3 else (basestring, ) @@ -129,6 +139,10 @@ def __init__( This does not apply if you have chosen to pass your own Http client. """ + if not server_configuration: + raise ValueError("Missing input parameter server_configuration") + if not client_id: + raise ValueError("Missing input parameter client_id") self.configuration = server_configuration self.client_id = client_id self.client_secret = client_secret @@ -252,6 +266,26 @@ def _stringify(self, sequence): return sequence # as-is +def _scope_set(scope): + assert scope is None or isinstance(scope, (list, set, tuple)) + return set(scope) if scope else set([]) + + +def _generate_pkce_code_verifier(length=43): + assert 43 <= length <= 128 + verifier = "".join( # https://tools.ietf.org/html/rfc7636#section-4.1 + random.sample(string.ascii_letters + string.digits + "-._~", length)) + code_challenge = ( + # https://tools.ietf.org/html/rfc7636#section-4.2 + base64.urlsafe_b64encode(hashlib.sha256(verifier.encode("ascii")).digest()) + .rstrip(b"=")) # Required by https://tools.ietf.org/html/rfc7636#section-3 + return { + "code_verifier": verifier, + "transformation": "S256", # In Python, sha256 is always available + "code_challenge": code_challenge, + } + + class Client(BaseClient): # We choose to implement all 4 grants in 1 class """This is the main API for oauth2 client. @@ -353,30 +387,9 @@ def obtain_token_by_device_flow(self, return result time.sleep(1) # Shorten each round, to make exit more responsive - def build_auth_request_uri( + def _build_auth_request_uri( self, response_type, redirect_uri=None, scope=None, state=None, **kwargs): - """Generate an authorization uri to be visited by resource owner. - - Later when the response reaches your redirect_uri, - you can use parse_auth_response() to check the returned state. - - This method could be named build_authorization_request_uri() instead, - but then there would be a build_authentication_request_uri() in the OIDC - subclass doing almost the same thing. So we use a loose term "auth" here. - - :param response_type: - Must be "code" when you are using Authorization Code Grant, - "token" when you are using Implicit Grant, or other - (possibly space-delimited) strings as registered extension value. - See https://tools.ietf.org/html/rfc6749#section-3.1.1 - :param redirect_uri: Optional. Server will use the pre-registered one. - :param scope: It is a space-delimited, case-sensitive string. - Some ID provider can accept empty string to represent default scope. - :param state: Recommended. An opaque value used by the client to - maintain state between the request and callback. - :param kwargs: Other parameters, typically defined in OpenID Connect. - """ if "authorization_endpoint" not in self.configuration: raise ValueError("authorization_endpoint not found in configuration") authorization_endpoint = self.configuration["authorization_endpoint"] @@ -386,6 +399,251 @@ def build_auth_request_uri( sep = '&' if '?' in authorization_endpoint else '?' return "%s%s%s" % (authorization_endpoint, sep, urlencode(params)) + def build_auth_request_uri( + self, + response_type, redirect_uri=None, scope=None, state=None, **kwargs): + # This method could be named build_authorization_request_uri() instead, + # but then there would be a build_authentication_request_uri() in the OIDC + # subclass doing almost the same thing. So we use a loose term "auth" here. + """Generate an authorization uri to be visited by resource owner. + + Parameters are the same as another method :func:`initiate_auth_code_flow()`, + whose functionality is a superset of this method. + + :return: The auth uri as a string. + """ + warnings.warn("Use initiate_auth_code_flow() instead. ", DeprecationWarning) + return self._build_auth_request_uri( + response_type, redirect_uri=redirect_uri, scope=scope, state=state, + **kwargs) + + def initiate_auth_code_flow( + # The name is influenced by OIDC + # https://openid.net/specs/openid-connect-core-1_0.html#CodeFlowAuth + self, + scope=None, redirect_uri=None, state=None, + **kwargs): + """Initiate an auth code flow. + + Later when the response reaches your redirect_uri, + you can use :func:`~obtain_token_by_auth_code_flow()` + to complete the authentication/authorization. + + This method also provides PKCE protection automatically. + + :param list scope: + It is a list of case-sensitive strings. + Some ID provider can accept empty string to represent default scope. + :param str redirect_uri: + Optional. If not specified, server will use the pre-registered one. + :param str state: + An opaque value used by the client to + maintain state between the request and callback. + If absent, this library will automatically generate one internally. + :param kwargs: Other parameters, typically defined in OpenID Connect. + + :return: + The auth code flow. It is a dict in this form:: + + { + "auth_uri": "https://...", // Guide user to visit this + "state": "...", // You may choose to verify it by yourself, + // or just let obtain_token_by_auth_code_flow() + // do that for you. + "...": "...", // Everything else are reserved and internal + } + + The caller is expected to:: + + 1. somehow store this content, typically inside the current session, + 2. guide the end user (i.e. resource owner) to visit that auth_uri, + 3. and then relay this dict and subsequent auth response to + :func:`~obtain_token_by_auth_code_flow()`. + """ + response_type = kwargs.pop("response_type", "code") # Auth Code flow + # Must be "code" when you are using Authorization Code Grant. + # The "token" for Implicit Grant is not applicable thus not allowed. + # It could theoretically be other + # (possibly space-delimited) strings as registered extension value. + # See https://tools.ietf.org/html/rfc6749#section-3.1.1 + if "token" in response_type: + # Implicit grant would cause auth response coming back in #fragment, + # but fragment won't reach a web service. + raise ValueError('response_type="token ..." is not allowed') + pkce = _generate_pkce_code_verifier() + flow = { # These data are required by obtain_token_by_auth_code_flow() + "state": state or "".join(random.sample(string.ascii_letters, 16)), + "redirect_uri": redirect_uri, + "scope": scope, + } + auth_uri = self._build_auth_request_uri( + response_type, + code_challenge=pkce["code_challenge"], + code_challenge_method=pkce["transformation"], + **dict(flow, **kwargs)) + flow["auth_uri"] = auth_uri + flow["code_verifier"] = pkce["code_verifier"] + return flow + + def obtain_token_by_auth_code_flow( + self, + auth_code_flow, + auth_response, + scope=None, + **kwargs): + """With the auth_response being redirected back, + validate it against auth_code_flow, and then obtain tokens. + + Internally, it implements PKCE to mitigate the auth code interception attack. + + :param dict auth_code_flow: + The same dict returned by :func:`~initiate_auth_code_flow()`. + :param dict auth_response: + A dict based on query string received from auth server. + + :param scope: + You don't usually need to use scope parameter here. + Some Identity Provider allows you to provide + a subset of what you specified during :func:`~initiate_auth_code_flow`. + :type scope: collections.Iterable[str] + + :return: + * A dict containing "access_token" and/or "id_token", among others, + depends on what scope was used. + (See https://tools.ietf.org/html/rfc6749#section-5.1) + * A dict containing "error", optionally "error_description", "error_uri". + (It is either `this `_ + or `that `_ + * Most client-side data error would result in ValueError exception. + So the usage pattern could be without any protocol details:: + + def authorize(): # A controller in a web app + try: + result = client.obtain_token_by_auth_code_flow( + session.get("flow", {}), auth_resp) + if "error" in result: + return render_template("error.html", result) + store_tokens() + except ValueError: # Usually caused by CSRF + pass # Simply ignore them + return redirect(url_for("index")) + """ + assert isinstance(auth_code_flow, dict) and isinstance(auth_response, dict) + # This is app developer's error which we do NOT want to map to ValueError + if not auth_code_flow.get("state"): + # initiate_auth_code_flow() already guarantees a state to be available. + # This check will also allow a web app to blindly call this method with + # obtain_token_by_auth_code_flow(session.get("flow", {}), auth_resp) + # which further simplifies their usage. + raise ValueError("state missing from auth_code_flow") + if auth_code_flow.get("state") != auth_response.get("state"): + raise ValueError("state mismatch: {} vs {}".format( + auth_code_flow.get("state"), auth_response.get("state"))) + if scope and set(scope) - set(auth_code_flow.get("scope", [])): + raise ValueError( + "scope must be None or a subset of %s" % auth_code_flow.get("scope")) + if auth_response.get("code"): # i.e. the first leg was successful + return self._obtain_token_by_authorization_code( + auth_response["code"], + redirect_uri=auth_code_flow.get("redirect_uri"), + # Required, if "redirect_uri" parameter was included in the + # authorization request, and their values MUST be identical. + scope=scope or auth_code_flow.get("scope"), + # It is both unnecessary and harmless, per RFC 6749. + # We use the same scope already used in auth request uri, + # thus token cache can know what scope the tokens are for. + data=dict( # Extract and update the data + kwargs.pop("data", {}), + code_verifier=auth_code_flow["code_verifier"], + ), + **kwargs) + if auth_response.get("error"): # It means the first leg encountered error + # Here we do NOT return original auth_response as-is, to prevent a + # potential {..., "access_token": "attacker's AT"} input being leaked + error = {"error": auth_response["error"]} + if auth_response.get("error_description"): + error["error_description"] = auth_response["error_description"] + if auth_response.get("error_uri"): + error["error_uri"] = auth_response["error_uri"] + return error + raise ValueError('auth_response must contain either "code" or "error"') + + def obtain_token_by_browser( + # Name influenced by RFC 8252: "native apps should (use) ... user's browser" + self, + scope=None, + extra_scope_to_consent=None, + redirect_uri=None, + timeout=None, + welcome_template=None, + success_template=None, + auth_params=None, + **kwargs): + """A native app can use this method to obtain token via a local browser. + + Internally, it implements PKCE to mitigate the auth code interception attack. + + :param scope: A list of scopes that you would like to obtain token for. + :type scope: collections.Iterable[str] + + :param extra_scope_to_consent: + Some IdP allows you to include more scopes for end user to consent. + The access token returned by this method will NOT include those scopes, + but the refresh token would record those extra consent, + so that your future :func:`~obtain_token_by_refresh_token()` call + would be able to obtain token for those additional scopes, silently. + :type scope: collections.Iterable[str] + + :param string redirect_uri: + The redirect_uri to be sent via auth request to Identity Provider (IdP), + to indicate where an auth response would come back to. + Such as ``http://127.0.0.1:0`` (default) or ``http://localhost:1234``. + + If port 0 is specified, this method will choose a system-allocated port, + then the actual redirect_uri will contain that port. + To use this behavior, your IdP would need to accept such dynamic port. + + Per HTTP convention, if port number is absent, it would mean port 80, + although you probably want to specify port 0 in this context. + + :param dict auth_params: + These parameters will be sent to authorization_endpoint. + + :param int timeout: In seconds. None means wait indefinitely. + :return: Same as :func:`~obtain_token_by_auth_code_flow()` + """ + _redirect_uri = urlparse(redirect_uri or "http://127.0.0.1:0") + if not _redirect_uri.hostname: + raise ValueError("redirect_uri should contain hostname") + if _redirect_uri.scheme == "https": + raise ValueError("Our local loopback server will not use https") + listen_port = _redirect_uri.port if _redirect_uri.port is not None else 80 + # This implementation allows port-less redirect_uri to mean port 80 + try: + with _AuthCodeReceiver(port=listen_port) as receiver: + flow = self.initiate_auth_code_flow( + redirect_uri="http://{host}:{port}".format( + host=_redirect_uri.hostname, port=receiver.get_port(), + ) if _redirect_uri.port is not None else "http://{host}".format( + host=_redirect_uri.hostname + ), # This implementation uses port-less redirect_uri as-is + scope=_scope_set(scope) | _scope_set(extra_scope_to_consent), + **(auth_params or {})) + auth_response = receiver.get_auth_response( + auth_uri=flow["auth_uri"], + state=flow["state"], # Optional but we choose to do it upfront + timeout=timeout, + welcome_template=welcome_template, + success_template=success_template, + ) + except PermissionError: + if 0 < listen_port < 1024: + self.logger.error( + "Can't listen on port %s. You may try port 0." % listen_port) + raise + return self.obtain_token_by_auth_code_flow( + flow, auth_response, scope=scope, **kwargs) + @staticmethod def parse_auth_response(params, state=None): """Parse the authorization response being redirected back. @@ -394,6 +652,8 @@ def parse_auth_response(params, state=None): :param state: REQUIRED if the state parameter was present in the client authorization request. This function will compare it with response. """ + warnings.warn( + "Use obtain_token_by_auth_code_flow() instead", DeprecationWarning) if not isinstance(params, dict): params = parse_qs(params) if params.get('state') != state: @@ -408,6 +668,9 @@ def obtain_token_by_authorization_code( but it can also be used by a device-side native app (Public Client). See more detail at https://tools.ietf.org/html/rfc6749#section-4.1.3 + You are encouraged to use its higher level method + :func:`~obtain_token_by_auth_code_flow` instead. + :param code: The authorization code received from authorization server. :param redirect_uri: Required, if the "redirect_uri" parameter was included in the @@ -417,6 +680,13 @@ def obtain_token_by_authorization_code( We suggest to use the same scope already used in auth request uri, so that this library can link the obtained tokens with their scope. """ + warnings.warn( + "Use obtain_token_by_auth_code_flow() instead", DeprecationWarning) + return self._obtain_token_by_authorization_code( + code, redirect_uri=redirect_uri, scope=scope, **kwargs) + + def _obtain_token_by_authorization_code( + self, code, redirect_uri=None, scope=None, **kwargs): data = kwargs.pop("data", {}) data.update(code=code, redirect_uri=redirect_uri) if scope: diff --git a/msal/oauth2cli/oidc.py b/msal/oauth2cli/oidc.py index 246bdae1..eb2e80aa 100644 --- a/msal/oauth2cli/oidc.py +++ b/msal/oauth2cli/oidc.py @@ -1,6 +1,10 @@ import json import base64 import time +import random +import string +import warnings +import hashlib from . import oauth2 @@ -56,7 +60,7 @@ def decode_id_token(id_token, client_id=None, issuer=None, nonce=None, now=None) err = "3. The aud (audience) Claim must contain this client's client_id." # Per specs: # 6. If the ID Token is received via direct communication between - # the Client and the Token Endpoint (which it is in this flow), + # the Client and the Token Endpoint (which it is during _obtain_token()), # the TLS server validation MAY be used to validate the issuer # in place of checking the token signature. if _now > decoded["exp"]: @@ -70,6 +74,11 @@ def decode_id_token(id_token, client_id=None, issuer=None, nonce=None, now=None) return decoded +def _nonce_hash(nonce): + # https://openid.net/specs/openid-connect-core-1_0.html#NonceNotes + return hashlib.sha256(nonce.encode("ascii")).hexdigest() + + class Client(oauth2.Client): """OpenID Connect is a layer on top of the OAuth2. @@ -101,6 +110,7 @@ def build_auth_request_uri(self, response_type, nonce=None, **kwargs): A hard-to-guess string used to mitigate replay attacks. See also `OIDC specs `_. """ + warnings.warn("Use initiate_auth_code_flow() instead", DeprecationWarning) return super(Client, self).build_auth_request_uri( response_type, nonce=nonce, **kwargs) @@ -116,6 +126,8 @@ def obtain_token_by_authorization_code(self, code, nonce=None, **kwargs): same nonce should also be provided here, so that we'll validate it. An exception will be raised if the nonce in id token mismatches. """ + warnings.warn( + "Use obtain_token_by_auth_code_flow() instead", DeprecationWarning) result = super(Client, self).obtain_token_by_authorization_code( code, **kwargs) nonce_in_id_token = result.get("id_token_claims", {}).get("nonce") @@ -125,3 +137,106 @@ def obtain_token_by_authorization_code(self, code, nonce=None, **kwargs): (nonce_in_id_token, nonce)) return result + def initiate_auth_code_flow( + self, + scope=None, + **kwargs): + """Initiate an auth code flow. + + It provides nonce protection automatically. + + :param list scope: + A list of strings, e.g. ["profile", "email", ...]. + This method will automatically send ["openid"] to the wire, + although it won't modify your input list. + + See :func:`oauth2.Client.initiate_auth_code_flow` in parent class + for descriptions on other parameters and return value. + """ + if "id_token" in kwargs.get("response_type", ""): + # Implicit grant would cause auth response coming back in #fragment, + # but fragment won't reach a web service. + raise ValueError('response_type="id_token ..." is not allowed') + _scope = list(scope) if scope else [] # We won't modify input parameter + if "openid" not in _scope: + # "If no openid scope value is present, + # the request may still be a valid OAuth 2.0 request, + # but is not an OpenID Connect request." -- OIDC Core Specs, 3.1.2.2 + # https://openid.net/specs/openid-connect-core-1_0.html#AuthRequestValidation + # Here we just automatically add it. If the caller do not want id_token, + # they should simply go with oauth2.Client. + _scope.append("openid") + nonce = "".join(random.sample(string.ascii_letters, 16)) + flow = super(Client, self).initiate_auth_code_flow( + scope=_scope, nonce=_nonce_hash(nonce), **kwargs) + flow["nonce"] = nonce + return flow + + def obtain_token_by_auth_code_flow(self, auth_code_flow, auth_response, **kwargs): + """Validate the auth_response being redirected back, and then obtain tokens, + including ID token which can be used for user sign in. + + Internally, it implements nonce to mitigate replay attack. + It also implements PKCE to mitigate the auth code interception attack. + + See :func:`oauth2.Client.obtain_token_by_auth_code_flow` in parent class + for descriptions on other parameters and return value. + """ + result = super(Client, self).obtain_token_by_auth_code_flow( + auth_code_flow, auth_response, **kwargs) + if "id_token_claims" in result: + nonce_in_id_token = result.get("id_token_claims", {}).get("nonce") + expected_hash = _nonce_hash(auth_code_flow["nonce"]) + if nonce_in_id_token != expected_hash: + raise RuntimeError( + 'The nonce in id token ("%s") should match our nonce ("%s")' % + (nonce_in_id_token, expected_hash)) + return result + + def obtain_token_by_browser( + self, + display=None, + prompt=None, + max_age=None, + ui_locales=None, + id_token_hint=None, # It is relevant, + # because this library exposes raw ID token + login_hint=None, + acr_values=None, + **kwargs): + """A native app can use this method to obtain token via a local browser. + + Internally, it implements nonce to mitigate replay attack. + It also implements PKCE to mitigate the auth code interception attack. + + :param string display: Defined in + `OIDC `_. + :param string prompt: Defined in + `OIDC `_. + :param int max_age: Defined in + `OIDC `_. + :param string ui_locales: Defined in + `OIDC `_. + :param string id_token_hint: Defined in + `OIDC `_. + :param string login_hint: Defined in + `OIDC `_. + :param string acr_values: Defined in + `OIDC `_. + + See :func:`oauth2.Client.obtain_token_by_browser` in parent class + for descriptions on other parameters and return value. + """ + filtered_params = {k:v for k, v in dict( + prompt=prompt, + display=display, + max_age=max_age, + ui_locales=ui_locales, + id_token_hint=id_token_hint, + login_hint=login_hint, + acr_values=acr_values, + ).items() if v is not None} # Filter out None values + return super(Client, self).obtain_token_by_browser( + auth_params=dict(kwargs.pop("auth_params", {}), **filtered_params), + **kwargs) + diff --git a/msal/token_cache.py b/msal/token_cache.py index b7ebbb99..34eff37c 100644 --- a/msal/token_cache.py +++ b/msal/token_cache.py @@ -234,8 +234,9 @@ def modify(self, credential_type, old_entry, new_key_value_pairs=None): with self._lock: if new_key_value_pairs: # Update with them entries = self._cache.setdefault(credential_type, {}) - entry = entries.setdefault(key, {}) # Create it if not yet exist - entry.update(new_key_value_pairs) + entries[key] = dict( + old_entry, # Do not use entries[key] b/c it might not exist + **new_key_value_pairs) else: # Remove old_entry self._cache.setdefault(credential_type, {}).pop(key, None) diff --git a/sample/interactive_sample.py b/sample/interactive_sample.py new file mode 100644 index 00000000..38593315 --- /dev/null +++ b/sample/interactive_sample.py @@ -0,0 +1,75 @@ +""" +The configuration file would look like this: + +{ + "authority": "https://login.microsoftonline.com/organizations", + "client_id": "your_client_id", + "scope": ["User.ReadBasic.All"], + // You can find the other permission names from this document + // https://docs.microsoft.com/en-us/graph/permissions-reference + "username": "your_username@your_tenant.com", // This is optional + "endpoint": "https://graph.microsoft.com/v1.0/users" + // You can find more Microsoft Graph API endpoints from Graph Explorer + // https://developer.microsoft.com/en-us/graph/graph-explorer +} + +You can then run this sample with a JSON configuration file: + + python sample.py parameters.json +""" + +import sys # For simplicity, we'll read config file from 1st CLI param sys.argv[1] +import json, logging, msal, requests + +# Optional logging +# logging.basicConfig(level=logging.DEBUG) # Enable DEBUG log for entire script +# logging.getLogger("msal").setLevel(logging.INFO) # Optionally disable MSAL DEBUG logs + +config = json.load(open(sys.argv[1])) + +# Create a preferably long-lived app instance which maintains a token cache. +app = msal.PublicClientApplication( + config["client_id"], authority=config["authority"], + # token_cache=... # Default cache is in memory only. + # You can learn how to use SerializableTokenCache from + # https://msal-python.rtfd.io/en/latest/#msal.SerializableTokenCache + ) + +# The pattern to acquire a token looks like this. +result = None + +# Firstly, check the cache to see if this end user has signed in before +accounts = app.get_accounts(username=config.get("username")) +if accounts: + logging.info("Account(s) exists in cache, probably with token too. Let's try.") + print("Account(s) already signed in:") + for a in accounts: + print(a["username"]) + chosen = accounts[0] # Assuming the end user chose this one to proceed + print("Proceed with account: %s" % chosen["username"]) + # Now let's try to find a token in cache for this account + result = app.acquire_token_silent(config["scope"], account=chosen) + +if not result: + logging.info("No suitable token exists in cache. Let's get a new one from AAD.") + print("A local browser window will be open for you to sign in. CTRL+C to cancel.") + result = app.acquire_token_interactive( + config["scope"], + login_hint=config.get("username"), # You can use this parameter to pre-fill + # the username (or email address) field of the sign-in page for the user, + # if you know the username ahead of time. + # Often, apps use this parameter during reauthentication, + # after already extracting the username from an earlier sign-in + # by using the preferred_username claim from returned id_token_claims. + ) + +if "access_token" in result: + # Calling graph using the access token + graph_response = requests.get( # Use token to call downstream service + config["endpoint"], + headers={'Authorization': 'Bearer ' + result['access_token']},) + print("Graph API call result: %s ..." % graph_response.text[:100]) +else: + print(result.get("error")) + print(result.get("error_description")) + print(result.get("correlation_id")) # You may need this when reporting a bug diff --git a/tests/http_client.py b/tests/http_client.py index 4bff9b45..a5587b70 100644 --- a/tests/http_client.py +++ b/tests/http_client.py @@ -10,11 +10,13 @@ def __init__(self, verify=True, proxies=None, timeout=None): self.timeout = timeout def post(self, url, params=None, data=None, headers=None, **kwargs): + assert not kwargs, "Our stack shouldn't leak extra kwargs: %s" % kwargs return MinimalResponse(requests_resp=self.session.post( url, params=params, data=data, headers=headers, timeout=self.timeout)) def get(self, url, params=None, headers=None, **kwargs): + assert not kwargs, "Our stack shouldn't leak extra kwargs: %s" % kwargs return MinimalResponse(requests_resp=self.session.get( url, params=params, headers=headers, timeout=self.timeout)) @@ -26,5 +28,6 @@ def __init__(self, requests_resp=None, status_code=None, text=None): self._raw_resp = requests_resp def raise_for_status(self): - if self._raw_resp: + if self._raw_resp is not None: # Turns out `if requests.response` won't work + # cause it would be True when 200<=status<400 self._raw_resp.raise_for_status() diff --git a/tests/test_client.py b/tests/test_client.py index ebce8e55..39fc9145 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -9,7 +9,7 @@ import requests -from msal.oauth2cli import Client, JwtSigner +from msal.oauth2cli import Client, JwtSigner, AuthCodeReceiver from msal.oauth2cli.authcode import obtain_auth_code from tests import unittest, Oauth2TestCase from tests.http_client import MinimalHttpClient, MinimalResponse @@ -153,6 +153,68 @@ def test_auth_code(self): redirect_uri=redirect_uri) self.assertLoosely(result, lambda: self.assertIn('access_token', result)) + @unittest.skipUnless( + "authorization_endpoint" in CONFIG.get("openid_configuration", {}), + "authorization_endpoint missing") + def test_auth_code_flow(self): + with AuthCodeReceiver(port=CONFIG.get("listen_port")) as receiver: + flow = self.client.initiate_auth_code_flow( + redirect_uri="http://localhost:%d" % receiver.get_port(), + scope=CONFIG.get("scope"), + login_hint=CONFIG.get("username"), # To skip the account selector + ) + auth_response = receiver.get_auth_response( + auth_uri=flow["auth_uri"], + state=flow["state"], # Optional but recommended + timeout=120, + welcome_template=""" + authorization_endpoint = {a}, client_id = {i} + Sign In or Abort + """.format( + a=CONFIG["openid_configuration"]["authorization_endpoint"], + i=CONFIG.get("client_id")), + ) + self.assertIsNotNone( + auth_response.get("code"), "Error: {}, Detail: {}".format( + auth_response.get("error"), auth_response)) + result = self.client.obtain_token_by_auth_code_flow(flow, auth_response) + #TBD: data={"resource": CONFIG.get("resource")}, # MSFT AAD v1 only + self.assertLoosely(result, lambda: self.assertIn('access_token', result)) + + def test_auth_code_flow_error_response(self): + with self.assertRaisesRegexp(ValueError, "state missing"): + self.client.obtain_token_by_auth_code_flow({}, {"code": "foo"}) + with self.assertRaisesRegexp(ValueError, "state mismatch"): + self.client.obtain_token_by_auth_code_flow({"state": "1"}, {"state": "2"}) + with self.assertRaisesRegexp(ValueError, "scope"): + self.client.obtain_token_by_auth_code_flow( + {"state": "s", "scope": ["foo"]}, {"state": "s"}, scope=["bar"]) + self.assertEqual( + {"error": "foo", "error_uri": "bar"}, + self.client.obtain_token_by_auth_code_flow( + {"state": "s"}, + {"state": "s", "error": "foo", "error_uri": "bar", "access_token": "fake"}), + "We should not leak malicious input into our output") + + @unittest.skipUnless( + "authorization_endpoint" in CONFIG.get("openid_configuration", {}), + "authorization_endpoint missing") + def test_obtain_token_by_browser(self): + result = self.client.obtain_token_by_browser( + scope=CONFIG.get("scope"), + redirect_uri=CONFIG.get("redirect_uri"), + welcome_template=""" + authorization_endpoint = {a}, client_id = {i} + Sign In or Abort + """.format( + a=CONFIG["openid_configuration"]["authorization_endpoint"], + i=CONFIG.get("client_id")), + success_template="Done. You can close this window now.", + login_hint=CONFIG.get("username"), # To skip the account selector + timeout=60, + ) + self.assertLoosely(result, lambda: self.assertIn('access_token', result)) + @unittest.skipUnless( CONFIG.get("openid_configuration", {}).get("device_authorization_endpoint"), "device_authorization_endpoint is missing") @@ -223,7 +285,7 @@ def test_rt_being_migrated(self): class TestSessionAccessibility(unittest.TestCase): def test_accessing_session_property_for_backward_compatibility(self): - client = Client({}, "client_id") + client = Client({"token_endpoint": "https://example.com"}, "client_id") client.session client.session.close() client.session = "something" diff --git a/tests/test_e2e.py b/tests/test_e2e.py index 957d01a4..adfe5d42 100644 --- a/tests/test_e2e.py +++ b/tests/test_e2e.py @@ -8,6 +8,7 @@ import msal from tests.http_client import MinimalHttpClient +from msal.oauth2cli import AuthCodeReceiver logger = logging.getLogger(__name__) logging.basicConfig(level=logging.INFO) @@ -297,14 +298,16 @@ def get_lab_app( Get it from environment variables if defined, otherwise fall back to use MSI. """ + logger.info( + "Reading ENV variables %s and %s for lab app defined at " + "https://docs.msidlab.com/accounts/confidentialclient.html", + env_client_id, env_client_secret) if os.getenv(env_client_id) and os.getenv(env_client_secret): # A shortcut mainly for running tests on developer's local development machine # or it could be setup on Travis CI # https://docs.travis-ci.com/user/environment-variables/#defining-variables-in-repository-settings # Data came from here # https://docs.msidlab.com/accounts/confidentialclient.html - logger.info("Using lab app defined by ENV variables %s and %s", - env_client_id, env_client_secret) client_id = os.getenv(env_client_id) client_secret = os.getenv(env_client_secret) else: @@ -399,6 +402,78 @@ def _test_acquire_token_by_auth_code( error_description=result.get("error_description"))) self.assertCacheWorksForUser(result, scope, username=None) + def _test_acquire_token_by_auth_code_flow( + self, client_id=None, authority=None, port=None, scope=None, + username_uri="", # But you would want to provide one + **ignored): + assert client_id and authority and scope + self.app = msal.ClientApplication( + client_id, authority=authority, http_client=MinimalHttpClient()) + with AuthCodeReceiver(port=port) as receiver: + flow = self.app.initiate_auth_code_flow( + redirect_uri="http://localhost:%d" % receiver.get_port(), + scopes=scope, + ) + auth_response = receiver.get_auth_response( + auth_uri=flow["auth_uri"], state=flow["state"], timeout=60, + welcome_template="""

{id}

    +
  1. Get a username from the upn shown at here
  2. +
  3. Get its password from https://aka.ms/GetLabUserSecret?Secret=msidlabXYZ + (replace the lab name with the labName from the link above).
  4. +
  5. Sign In or Abort
  6. +
""".format(id=self.id(), username_uri=username_uri), + ) + self.assertIsNotNone( + auth_response.get("code"), "Error: {}, Detail: {}".format( + auth_response.get("error"), auth_response)) + result = self.app.acquire_token_by_auth_code_flow(flow, auth_response) + logger.debug( + "%s: cache = %s, id_token_claims = %s", + self.id(), + json.dumps(self.app.token_cache._cache, indent=4), + json.dumps(result.get("id_token_claims"), indent=4), + ) + self.assertIn( + "access_token", result, + "{error}: {error_description}".format( + # Note: No interpolation here, cause error won't always present + error=result.get("error"), + error_description=result.get("error_description"))) + self.assertCacheWorksForUser(result, scope, username=None) + + def _test_acquire_token_interactive( + self, client_id=None, authority=None, scope=None, port=None, + username_uri="", # But you would want to provide one + **ignored): + assert client_id and authority and scope + self.app = msal.PublicClientApplication( + client_id, authority=authority, http_client=MinimalHttpClient()) + result = self.app.acquire_token_interactive( + scope, + timeout=60, + port=port, + welcome_template= # This is an undocumented feature for testing + """

{id}

    +
  1. Get a username from the upn shown at here
  2. +
  3. Get its password from https://aka.ms/GetLabUserSecret?Secret=msidlabXYZ + (replace the lab name with the labName from the link above).
  4. +
  5. Sign In or Abort
  6. +
""".format(id=self.id(), username_uri=username_uri), + ) + logger.debug( + "%s: cache = %s, id_token_claims = %s", + self.id(), + json.dumps(self.app.token_cache._cache, indent=4), + json.dumps(result.get("id_token_claims"), indent=4), + ) + self.assertIn( + "access_token", result, + "{error}: {error_description}".format( + # Note: No interpolation here, cause error won't always present + error=result.get("error"), + error_description=result.get("error_description"))) + self.assertCacheWorksForUser(result, scope, username=None) + def _test_acquire_token_obo(self, config_pca, config_cca): # 1. An app obtains a token representing a user, for our mid-tier service pca = msal.PublicClientApplication( @@ -474,9 +549,21 @@ def test_adfs2_fed_user(self): self._test_username_password(**config) def test_adfs2019_fed_user(self): - config = self.get_lab_user(usertype="federated", federationProvider="ADFSv2019") - config["password"] = self.get_lab_user_secret(config["lab_name"]) - self._test_username_password(**config) + try: + config = self.get_lab_user(usertype="federated", federationProvider="ADFSv2019") + config["password"] = self.get_lab_user_secret(config["lab_name"]) + self._test_username_password(**config) + except requests.exceptions.HTTPError: + if os.getenv("TRAVIS"): + self.skipTest("MEX endpoint in our test environment tends to fail") + raise + + @unittest.skipIf(os.getenv("TRAVIS"), "Browser automation is not yet implemented") + def test_cloud_acquire_token_interactive(self): + config = self.get_lab_user(usertype="cloud") + self._test_acquire_token_interactive( + username_uri="https://msidlab.com/api/user?usertype=cloud", + **config) def test_ropc_adfs2019_onprem(self): # Configuration is derived from https://github.com/AzureAD/microsoft-authentication-library-for-dotnet/blob/4.7.0/tests/Microsoft.Identity.Test.Common/TestConstants.cs#L250-L259 @@ -500,6 +587,26 @@ def test_adfs2019_onprem_acquire_token_by_auth_code(self): config["port"] = 8080 self._test_acquire_token_by_auth_code(**config) + @unittest.skipIf(os.getenv("TRAVIS"), "Browser automation is not yet implemented") + def test_adfs2019_onprem_acquire_token_by_auth_code_flow(self): + config = self.get_lab_user(usertype="onprem", federationProvider="ADFSv2019") + config["authority"] = "https://fs.%s.com/adfs" % config["lab_name"] + config["scope"] = self.adfs2019_scopes + config["port"] = 8080 + self._test_acquire_token_by_auth_code_flow( + username_uri="https://msidlab.com/api/user?usertype=onprem&federationprovider=ADFSv2019", + **config) + + @unittest.skipIf(os.getenv("TRAVIS"), "Browser automation is not yet implemented") + def test_adfs2019_onprem_acquire_token_interactive(self): + config = self.get_lab_user(usertype="onprem", federationProvider="ADFSv2019") + config["authority"] = "https://fs.%s.com/adfs" % config["lab_name"] + config["scope"] = self.adfs2019_scopes + config["port"] = 8080 + self._test_acquire_token_interactive( + username_uri="https://msidlab.com/api/user?usertype=onprem&federationprovider=ADFSv2019", + **config) + @unittest.skipUnless( os.getenv("LAB_OBO_CLIENT_SECRET"), "Need LAB_OBO_CLIENT SECRET from https://msidlabs.vault.azure.net/secrets/TodoListServiceV2-OBO/c58ba97c34ca4464886943a847d1db56") @@ -547,6 +654,17 @@ def test_b2c_acquire_token_by_auth_code(self): scope=config["defaultScopes"].split(','), ) + @unittest.skipIf(os.getenv("TRAVIS"), "Browser automation is not yet implemented") + def test_b2c_acquire_token_by_auth_code_flow(self): + config = self.get_lab_app_object(azureenvironment="azureb2ccloud") + self._test_acquire_token_by_auth_code_flow( + authority=self._build_b2c_authority("B2C_1_SignInPolicy"), + client_id=config["appId"], + port=3843, # Lab defines 4 of them: [3843, 4584, 4843, 60000] + scope=config["defaultScopes"].split(','), + username_uri="https://msidlab.com/api/user?usertype=b2c&b2cprovider=local", + ) + def test_b2c_acquire_token_by_ropc(self): config = self.get_lab_app_object(azureenvironment="azureb2ccloud") self._test_username_password( diff --git a/tests/test_token_cache.py b/tests/test_token_cache.py index 1666bba2..c846883d 100644 --- a/tests/test_token_cache.py +++ b/tests/test_token_cache.py @@ -222,6 +222,24 @@ def test_key_id_is_also_recorded(self): {}).get("key_id") self.assertEqual(my_key_id, cached_key_id, "AT should be bound to the key") + def test_old_rt_data_with_wrong_key_should_still_be_salvaged_into_new_rt(self): + sample = { + 'client_id': 'my_client_id', + 'credential_type': 'RefreshToken', + 'environment': 'login.example.com', + 'home_account_id': "uid.utid", + 'secret': 'a refresh token', + 'target': 's2 s1 s3', + } + new_rt = "this is a new RT" + self.cache._cache["RefreshToken"] = {"wrong-key": sample} + self.cache.modify( + self.cache.CredentialType.REFRESH_TOKEN, sample, {"secret": new_rt}) + self.assertEqual( + dict(sample, secret=new_rt), + self.cache._cache["RefreshToken"].get( + 'uid.utid-login.example.com-refreshtoken-my_client_id--s2 s1 s3') + ) class SerializableTokenCacheTestCase(TokenCacheTestCase): # Run all inherited test methods, and have extra check in tearDown()