diff --git a/msal/application.py b/msal/application.py index 686cc95d..13015d09 100644 --- a/msal/application.py +++ b/msal/application.py @@ -26,7 +26,7 @@ # The __init__.py will import this. Not the other way around. -__version__ = "1.15.0" +__version__ = "1.16.0" logger = logging.getLogger(__name__) @@ -170,6 +170,7 @@ def __init__( # This way, it holds the same positional param place for PCA, # when we would eventually want to add this feature to PCA in future. exclude_scopes=None, + http_cache=None, ): """Create an instance of application. @@ -285,7 +286,8 @@ def __init__( which you will later provide via one of the acquire-token request. :param str azure_region: - Added since MSAL Python 1.12.0. + AAD provides regional endpoints for apps to opt in + to keep their traffic remain inside that region. As of 2021 May, regional service is only available for ``acquire_token_for_client()`` sent by any of the following scenarios:: @@ -302,9 +304,7 @@ def __init__( 4. An app which already onboard to the region's allow-list. - MSAL's default value is None, which means region behavior remains off. - If enabled, the `acquire_token_for_client()`-relevant traffic - would remain inside that region. + This parameter defaults to None, which means region behavior remains off. App developer can opt in to a regional endpoint, by provide its region name, such as "westus", "eastus2". @@ -330,12 +330,69 @@ def __init__( or provide a custom http_client which has a short timeout. That way, the latency would be under your control, but still less performant than opting out of region feature. + + New in version 1.12.0. + :param list[str] exclude_scopes: (optional) Historically MSAL hardcodes `offline_access` scope, which would allow your app to have prolonged access to user's data. If that is unnecessary or undesirable for your app, now you can use this parameter to supply an exclusion list of scopes, such as ``exclude_scopes = ["offline_access"]``. + + :param dict http_cache: + MSAL has long been caching tokens in the ``token_cache``. + Recently, MSAL also introduced a concept of ``http_cache``, + by automatically caching some finite amount of non-token http responses, + so that *long-lived* + ``PublicClientApplication`` and ``ConfidentialClientApplication`` + would be more performant and responsive in some situations. + + This ``http_cache`` parameter accepts any dict-like object. + If not provided, MSAL will use an in-memory dict. + + If your app is a command-line app (CLI), + you would want to persist your http_cache across different CLI runs. + The following recipe shows a way to do so:: + + # Just add the following lines at the beginning of your CLI script + import sys, atexit, pickle + http_cache_filename = sys.argv[0] + ".http_cache" + try: + with open(http_cache_filename, "rb") as f: + persisted_http_cache = pickle.load(f) # Take a snapshot + except ( + IOError, # A non-exist http cache file + pickle.UnpicklingError, # A corrupted http cache file + EOFError, # An empty http cache file + AttributeError, ImportError, IndexError, # Other corruption + ): + persisted_http_cache = {} # Recover by starting afresh + atexit.register(lambda: pickle.dump( + # When exit, flush it back to the file. + # It may occasionally overwrite another process's concurrent write, + # but that is fine. Subsequent runs will reach eventual consistency. + persisted_http_cache, open(http_cache_file, "wb"))) + + # And then you can implement your app as you normally would + app = msal.PublicClientApplication( + "your_client_id", + ..., + http_cache=persisted_http_cache, # Utilize persisted_http_cache + ..., + #token_cache=..., # You may combine the old token_cache trick + # Please refer to token_cache recipe at + # https://msal-python.readthedocs.io/en/latest/#msal.SerializableTokenCache + ) + app.acquire_token_interactive(["your", "scope"], ...) + + Content inside ``http_cache`` are cheap to obtain. + There is no need to share them among different apps. + + Content inside ``http_cache`` will contain no tokens nor + Personally Identifiable Information (PII). Encryption is unnecessary. + + New in version 1.16.0. """ self.client_id = client_id self.client_credential = client_credential @@ -370,7 +427,7 @@ def __init__( self.http_client.mount("https://", a) self.http_client = ThrottledHttpClient( self.http_client, - {} # Hard code an in-memory cache, for now + {} if http_cache is None else http_cache, # Default to an in-memory dict ) self.app_name = app_name @@ -437,17 +494,18 @@ def _build_telemetry_context( correlation_id=correlation_id, refresh_reason=refresh_reason) def _get_regional_authority(self, central_authority): - is_region_specified = bool(self._region_configured - and self._region_configured != self.ATTEMPT_REGION_DISCOVERY) self._region_detected = self._region_detected or _detect_region( self.http_client if self._region_configured is not None else None) - if (is_region_specified and self._region_configured != self._region_detected): + if (self._region_configured != self.ATTEMPT_REGION_DISCOVERY + and self._region_configured != self._region_detected): logger.warning('Region configured ({}) != region detected ({})'.format( repr(self._region_configured), repr(self._region_detected))) region_to_use = ( - self._region_configured if is_region_specified else self._region_detected) + self._region_detected + if self._region_configured == self.ATTEMPT_REGION_DISCOVERY + else self._region_configured) # It will retain the None i.e. opted out + logger.debug('Region to be used: {}'.format(repr(region_to_use))) if region_to_use: - logger.info('Region to be used: {}'.format(repr(region_to_use))) regional_host = ("{}.r.login.microsoftonline.com".format(region_to_use) if central_authority.instance in ( # The list came from https://github.com/AzureAD/microsoft-authentication-library-for-python/pull/358/files#r629400328 diff --git a/msal/oauth2cli/authcode.py b/msal/oauth2cli/authcode.py index 85bbd889..bcef60b8 100644 --- a/msal/oauth2cli/authcode.py +++ b/msal/oauth2cli/authcode.py @@ -7,6 +7,7 @@ """ import logging import socket +import sys from string import Template import threading import time @@ -103,7 +104,17 @@ def log_message(self, format, *args): logger.debug(format, *args) # To override the default log-to-stderr behavior -class _AuthCodeHttpServer(HTTPServer): +class _AuthCodeHttpServer(HTTPServer, object): + def __init__(self, server_address, *args, **kwargs): + _, port = server_address + if port and (sys.platform == "win32" or is_wsl()): + # The default allow_reuse_address is True. It works fine on non-Windows. + # On Windows, it undesirably allows multiple servers listening on same port, + # yet the second server would not receive any incoming request. + # So, we need to turn it off. + self.allow_reuse_address = False + super(_AuthCodeHttpServer, self).__init__(server_address, *args, **kwargs) + def handle_timeout(self): # It will be triggered when no request comes in self.timeout seconds. # See https://docs.python.org/3/library/socketserver.html#socketserver.BaseServer.handle_timeout @@ -119,7 +130,7 @@ class _AuthCodeHttpServer6(_AuthCodeHttpServer): class AuthCodeReceiver(object): # This class has (rather than is) an _AuthCodeHttpServer, so it does not leak API - def __init__(self, port=None): + def __init__(self, port=None, scheduled_actions=None): """Create a Receiver waiting for incoming auth response. :param port: @@ -128,6 +139,12 @@ def __init__(self, port=None): If your Identity Provider supports dynamic port, you can use port=0 here. Port 0 means to use an arbitrary unused port, per this official example: https://docs.python.org/2.7/library/socketserver.html#asynchronous-mixins + + :param scheduled_actions: + For example, if the input is + ``[(10, lambda: print("Got stuck during sign in? Call 800-000-0000"))]`` + then the receiver would call that lambda function after + waiting the response for 10 seconds. """ address = "127.0.0.1" # Hardcode, for now, Not sure what to expose, yet. # Per RFC 8252 (https://tools.ietf.org/html/rfc8252#section-8.3): @@ -141,6 +158,7 @@ def __init__(self, port=None): # When this server physically listens to a specific IP (as it should), # you will still be able to specify your redirect_uri using either # IP (e.g. 127.0.0.1) or localhost, whichever matches your registration. + self._scheduled_actions = sorted(scheduled_actions or []) # Make a copy Server = _AuthCodeHttpServer6 if ":" in address else _AuthCodeHttpServer # TODO: But, it would treat "localhost" or "" as IPv4. # If pressed, we might just expose a family parameter to caller. @@ -215,6 +233,10 @@ def get_auth_response(self, timeout=None, **kwargs): time.sleep(1) # Short detection interval to make happy path responsive if not t.is_alive(): # Then the thread has finished its job and exited break + while (self._scheduled_actions + and time.time() - begin > self._scheduled_actions[0][0]): + _, callback = self._scheduled_actions.pop(0) + callback() return result or None def _get_auth_response(self, result, auth_uri=None, timeout=None, state=None, diff --git a/msal/oauth2cli/oauth2.py b/msal/oauth2cli/oauth2.py index 8d337bb9..e092b3dd 100644 --- a/msal/oauth2cli/oauth2.py +++ b/msal/oauth2cli/oauth2.py @@ -3,9 +3,9 @@ import json try: - from urllib.parse import urlencode, parse_qs, quote_plus, urlparse + from urllib.parse import urlencode, parse_qs, quote_plus, urlparse, urlunparse except ImportError: - from urlparse import parse_qs, urlparse + from urlparse import parse_qs, urlparse, urlunparse from urllib import urlencode, quote_plus import logging import warnings @@ -573,16 +573,8 @@ def authorize(): # A controller in a web app def obtain_token_by_browser( # Name influenced by RFC 8252: "native apps should (use) ... user's browser" self, - scope=None, - extra_scope_to_consent=None, redirect_uri=None, - timeout=None, - welcome_template=None, - success_template=None, - error_template=None, - auth_params=None, - auth_uri_callback=None, - browser_name=None, + auth_code_receiver=None, **kwargs): """A native app can use this method to obtain token via a local browser. @@ -625,38 +617,64 @@ def obtain_token_by_browser( :return: Same as :func:`~obtain_token_by_auth_code_flow()` """ + if auth_code_receiver: # Then caller already knows the listen port + return self._obtain_token_by_browser( # Use all input param as-is + auth_code_receiver, redirect_uri=redirect_uri, **kwargs) + # Otherwise we will listen on _redirect_uri.port _redirect_uri = urlparse(redirect_uri or "http://127.0.0.1:0") if not _redirect_uri.hostname: raise ValueError("redirect_uri should contain hostname") - if _redirect_uri.scheme == "https": - raise ValueError("Our local loopback server will not use https") - listen_port = _redirect_uri.port if _redirect_uri.port is not None else 80 - # This implementation allows port-less redirect_uri to mean port 80 + listen_port = ( # Conventionally, port-less uri would mean port 80 + 80 if _redirect_uri.port is None else _redirect_uri.port) try: with _AuthCodeReceiver(port=listen_port) as receiver: - flow = self.initiate_auth_code_flow( - redirect_uri="http://{host}:{port}".format( - host=_redirect_uri.hostname, port=receiver.get_port(), - ) if _redirect_uri.port is not None else "http://{host}".format( - host=_redirect_uri.hostname - ), # This implementation uses port-less redirect_uri as-is - scope=_scope_set(scope) | _scope_set(extra_scope_to_consent), - **(auth_params or {})) - auth_response = receiver.get_auth_response( - auth_uri=flow["auth_uri"], - state=flow["state"], # Optional but we choose to do it upfront - timeout=timeout, - welcome_template=welcome_template, - success_template=success_template, - error_template=error_template, - auth_uri_callback=auth_uri_callback, - browser_name=browser_name, - ) + uri = redirect_uri if _redirect_uri.port != 0 else urlunparse(( + _redirect_uri.scheme, + "{}:{}".format(_redirect_uri.hostname, receiver.get_port()), + _redirect_uri.path, + _redirect_uri.params, + _redirect_uri.query, + _redirect_uri.fragment, + )) # It could be slightly different than raw redirect_uri + self.logger.debug("Using {} as redirect_uri".format(uri)) + return self._obtain_token_by_browser( + receiver, redirect_uri=uri, **kwargs) except PermissionError: - if 0 < listen_port < 1024: - self.logger.error( - "Can't listen on port %s. You may try port 0." % listen_port) - raise + raise ValueError( + "Can't listen on port %s. You may try port 0." % listen_port) + + def _obtain_token_by_browser( + self, + auth_code_receiver, + scope=None, + extra_scope_to_consent=None, + redirect_uri=None, + timeout=None, + welcome_template=None, + success_template=None, + error_template=None, + auth_params=None, + auth_uri_callback=None, + browser_name=None, + **kwargs): + # Internally, it calls self.initiate_auth_code_flow() and + # self.obtain_token_by_auth_code_flow(). + # + # Parameters are documented in public method obtain_token_by_browser(). + flow = self.initiate_auth_code_flow( + redirect_uri=redirect_uri, + scope=_scope_set(scope) | _scope_set(extra_scope_to_consent), + **(auth_params or {})) + auth_response = auth_code_receiver.get_auth_response( + auth_uri=flow["auth_uri"], + state=flow["state"], # Optional but we choose to do it upfront + timeout=timeout, + welcome_template=welcome_template, + success_template=success_template, + error_template=error_template, + auth_uri_callback=auth_uri_callback, + browser_name=browser_name, + ) return self.obtain_token_by_auth_code_flow( flow, auth_response, scope=scope, **kwargs) diff --git a/msal/region.py b/msal/region.py index dacd49d7..c540dc71 100644 --- a/msal/region.py +++ b/msal/region.py @@ -5,6 +5,9 @@ def _detect_region(http_client=None): + region = os.environ.get("REGION_NAME", "").replace(" ", "").lower() # e.g. westus2 + if region: + return region if http_client: return _detect_region_of_azure_vm(http_client) # It could hang for minutes return None diff --git a/tests/authcode.py b/tests/authcode.py deleted file mode 100644 index 4973d4c2..00000000 --- a/tests/authcode.py +++ /dev/null @@ -1,77 +0,0 @@ -import argparse -import webbrowser -import logging - -try: # Python 3 - from http.server import HTTPServer, BaseHTTPRequestHandler - from urllib.parse import urlparse, parse_qs, urlencode -except ImportError: # Fall back to Python 2 - from BaseHTTPServer import HTTPServer, BaseHTTPRequestHandler - from urlparse import urlparse, parse_qs - from urllib import urlencode - - -def build_auth_url(authority, client_id): - # Lucky that redirect_uri can be omitted, so it works for any app - return "{a}/oauth2/authorize?response_type=code&client_id={c}".format( - a=authority, c=client_id) - -class AuthCodeReceiver(BaseHTTPRequestHandler): - """A one-stop solution to acquire an authorization code. - - This helper starts a web server as redirect_uri, waiting for auth code. - It also opens a browser window to guide a human tester to manually login. - After obtaining an auth code, the web server will be shut down. - """ # Note: This docstring is also used by this script's command line help. - @classmethod - def acquire(cls, auth_endpoint, redirect_port): - """Usage: ac = AuthCodeReceiver.acquire('http://.../authorize', 8088)""" - webbrowser.open( - "http://localhost:{p}?{q}".format(p=redirect_port, q=urlencode({ - "text": """Open this link to acquire auth code. - If you prefer, you may want to use incognito window.""", - "link": auth_endpoint,}))) - logging.warn( - """Listening on http://localhost:{}, and a browser window is opened - for you on THIS machine, and waiting for human interaction. - This function call will hang until an auth code is received. - """.format(redirect_port)) - server = HTTPServer(("", int(redirect_port)), cls) - server.authcode = None - while not server.authcode: # https://docs.python.org/2/library/basehttpserver.html#more-examples - server.handle_request() - return server.authcode - - def do_GET(self): - # For flexibility, we choose to not check self.path matching redirect_uri - #assert self.path.startswith('/THE_PATH_REGISTERED_BY_THE_APP') - qs = parse_qs(urlparse(self.path).query) - if qs.get('code'): # Then store it into the server instance - ac = self.server.authcode = qs['code'][0] - self.send_full_response('Authcode:\n{}'.format(ac)) - # NOTE: Don't do self.server.shutdown() here. It'll halt the server. - elif qs.get('text') and qs.get('link'): # Then display a landing page - self.send_full_response('{text}'.format( - link=qs['link'][0], text=qs['text'][0])) - else: - self.send_full_response("This web service serves your redirect_uri") - - def send_full_response(self, body, is_ok=True): - self.send_response(200 if is_ok else 400) - content_type = 'text/html' if body.startswith('<') else 'text/plain' - self.send_header('Content-type', content_type) - self.end_headers() - self.wfile.write(body) - -if __name__ == '__main__': - p = parser = argparse.ArgumentParser( - description=AuthCodeReceiver.__doc__ - + "The auth code received will be dumped into stdout.") - p.add_argument('client_id', help="The client_id of your web service app") - p.add_argument('redirect_port', type=int, help="The port in redirect_uri") - p.add_argument( - "--authority", default="https://login.microsoftonline.com/common") - args = parser.parse_args() - print(AuthCodeReceiver.acquire( - build_auth_url(args.authority, args.client_id), args.redirect_port)) - diff --git a/tests/test_authcode.py b/tests/test_authcode.py new file mode 100644 index 00000000..c7e7565f --- /dev/null +++ b/tests/test_authcode.py @@ -0,0 +1,26 @@ +import unittest +import socket +import sys + +from msal.oauth2cli.authcode import AuthCodeReceiver + + +class TestAuthCodeReceiver(unittest.TestCase): + def test_setup_at_a_given_port_and_teardown(self): + port = 12345 # Assuming this port is available + with AuthCodeReceiver(port=port) as receiver: + self.assertEqual(port, receiver.get_port()) + + def test_setup_at_a_ephemeral_port_and_teardown(self): + port = 0 + with AuthCodeReceiver(port=port) as receiver: + self.assertNotEqual(port, receiver.get_port()) + + def test_no_two_concurrent_receivers_can_listen_on_same_port(self): + port = 12345 # Assuming this port is available + with AuthCodeReceiver(port=port) as receiver: + expected_error = OSError if sys.version_info[0] > 2 else socket.error + with self.assertRaises(expected_error): + with AuthCodeReceiver(port=port) as receiver2: + pass + diff --git a/tests/test_e2e.py b/tests/test_e2e.py index 2defecd6..a23806ed 100644 --- a/tests/test_e2e.py +++ b/tests/test_e2e.py @@ -791,7 +791,7 @@ class WorldWideRegionalEndpointTestCase(LabBasedTestCase): region = "westus" timeout = 2 # Short timeout makes this test case responsive on non-VM - def test_acquire_token_for_client_should_hit_regional_endpoint(self): + def _test_acquire_token_for_client(self, configured_region, expected_region): """This is the only grant supported by regional endpoint, for now""" self.app = get_lab_app( # Regional endpoint only supports confidential client @@ -799,8 +799,7 @@ def test_acquire_token_for_client_should_hit_regional_endpoint(self): #authority="https://westus.login.microsoft.com/microsoft.onmicrosoft.com", #validate_authority=False, authority="https://login.microsoftonline.com/microsoft.onmicrosoft.com", - azure_region=self.region, # Explicitly use this region, regardless of detection - + azure_region=configured_region, timeout=2, # Short timeout makes this test case responsive on non-VM ) scopes = ["https://graph.microsoft.com/.default"] @@ -809,9 +808,11 @@ def test_acquire_token_for_client_should_hit_regional_endpoint(self): self.app.http_client, "post", return_value=MinimalResponse( status_code=400, text='{"error": "mock"}')) as mocked_method: self.app.acquire_token_for_client(scopes) + expected_host = '{}.r.login.microsoftonline.com'.format( + expected_region) if expected_region else 'login.microsoftonline.com' mocked_method.assert_called_with( - 'https://westus.r.login.microsoftonline.com/{}/oauth2/v2.0/token'.format( - self.app.authority.tenant), + 'https://{}/{}/oauth2/v2.0/token'.format( + expected_host, self.app.authority.tenant), params=ANY, data=ANY, headers=ANY) result = self.app.acquire_token_for_client( scopes, @@ -820,6 +821,29 @@ def test_acquire_token_for_client_should_hit_regional_endpoint(self): self.assertIn('access_token', result) self.assertCacheWorksForApp(result, scopes) + def test_acquire_token_for_client_should_hit_global_endpoint_by_default(self): + self._test_acquire_token_for_client(None, None) + + def test_acquire_token_for_client_should_ignore_env_var_by_default(self): + os.environ["REGION_NAME"] = "eastus" + self._test_acquire_token_for_client(None, None) + del os.environ["REGION_NAME"] + + def test_acquire_token_for_client_should_use_a_specified_region(self): + self._test_acquire_token_for_client("westus", "westus") + + def test_acquire_token_for_client_should_use_an_env_var_with_short_region_name(self): + os.environ["REGION_NAME"] = "eastus" + self._test_acquire_token_for_client( + msal.ConfidentialClientApplication.ATTEMPT_REGION_DISCOVERY, "eastus") + del os.environ["REGION_NAME"] + + def test_acquire_token_for_client_should_use_an_env_var_with_long_region_name(self): + os.environ["REGION_NAME"] = "East Us 2" + self._test_acquire_token_for_client( + msal.ConfidentialClientApplication.ATTEMPT_REGION_DISCOVERY, "eastus2") + del os.environ["REGION_NAME"] + @unittest.skipUnless( os.getenv("LAB_OBO_CLIENT_SECRET"), "Need LAB_OBO_CLIENT_SECRET from https://aka.ms/GetLabSecret?Secret=TodoListServiceV2-OBO")