diff --git a/msal/application.py b/msal/application.py index 6efde65b..0881c613 100644 --- a/msal/application.py +++ b/msal/application.py @@ -9,7 +9,7 @@ import requests -from .oauth2cli import Client, JwtSigner +from .oauth2cli import Client, JwtAssertionCreator from .authority import Authority from .mex import send_request as mex_send_request from .wstrust_request import send_request as wst_send_request @@ -18,7 +18,7 @@ # The __init__.py will import this. Not the other way around. -__version__ = "0.4.1" +__version__ = "0.5.0" logger = logging.getLogger(__name__) @@ -50,16 +50,33 @@ def decorate_scope( return list(decorated) +def extract_certs(public_cert_content): + # Parses raw public certificate file contents and returns a list of strings + # Usage: headers = {"x5c": extract_certs(open("my_cert.pem").read())} + public_certificates = re.findall( + r'-----BEGIN CERTIFICATE-----(?P[^-]+)-----END CERTIFICATE-----', + public_cert_content, re.I) + if public_certificates: + return [cert.strip() for cert in public_certificates] + # The public cert tags are not found in the input, + # let's make best effort to exclude a private key pem file. + if "PRIVATE KEY" in public_cert_content: + raise ValueError( + "We expect your public key but detect a private key instead") + return [public_cert_content.strip()] + + class ClientApplication(object): def __init__( self, client_id, client_credential=None, authority=None, validate_authority=True, token_cache=None, - verify=True, proxies=None, timeout=None): + verify=True, proxies=None, timeout=None, + client_claims=None): """Create an instance of application. - :param client_id: Your app has a clinet_id after you register it on AAD. + :param client_id: Your app has a client_id after you register it on AAD. :param client_credential: For :class:`PublicClientApplication`, you simply use `None` here. For :class:`ConfidentialClientApplication`, @@ -69,6 +86,28 @@ def __init__( { "private_key": "...-----BEGIN PRIVATE KEY-----...", "thumbprint": "A1B2C3D4E5F6...", + "public_certificate": "...-----BEGIN CERTIFICATE-----..." (Optional. See below.) + } + + *Added in version 0.5.0*: + public_certificate (optional) is public key certificate + which will be sent through 'x5c' JWT header only for + subject name and issuer authentication to support cert auto rolls. + + :param dict client_claims: + *Added in version 0.5.0*: + It is a dictionary of extra claims that would be signed by + by this :class:`ConfidentialClientApplication` 's private key. + For example, you can use {"client_ip": "x.x.x.x"}. + You may also override any of the following default claims:: + + { + "aud": the_token_endpoint, + "iss": self.client_id, + "sub": same_as_issuer, + "exp": now + 10_min, + "iat": now, + "jti": a_random_uuid } :param str authority: @@ -95,6 +134,7 @@ def __init__( """ self.client_id = client_id self.client_credential = client_credential + self.client_claims = client_claims self.verify = verify self.proxies = proxies self.timeout = timeout @@ -113,11 +153,15 @@ def _build_client(self, client_credential, authority): if isinstance(client_credential, dict): assert ("private_key" in client_credential and "thumbprint" in client_credential) - signer = JwtSigner( + headers = {} + if 'public_certificate' in client_credential: + headers["x5c"] = extract_certs(client_credential['public_certificate']) + assertion = JwtAssertionCreator( client_credential["private_key"], algorithm="RS256", - sha1_thumbprint=client_credential.get("thumbprint")) - client_assertion = signer.sign_assertion( - audience=authority.token_endpoint, issuer=self.client_id) + sha1_thumbprint=client_credential.get("thumbprint"), headers=headers) + client_assertion = assertion.create_regenerative_assertion( + audience=authority.token_endpoint, issuer=self.client_id, + additional_claims=self.client_claims or {}) client_assertion_type = Client.CLIENT_ASSERTION_TYPE_JWT else: default_body['client_secret'] = client_credential diff --git a/msal/oauth2cli/__init__.py b/msal/oauth2cli/__init__.py index 2129912b..b8941361 100644 --- a/msal/oauth2cli/__init__.py +++ b/msal/oauth2cli/__init__.py @@ -1,5 +1,6 @@ -__version__ = "0.2.0" +__version__ = "0.3.0" from .oidc import Client -from .assertion import JwtSigner +from .assertion import JwtAssertionCreator +from .assertion import JwtSigner # Obsolete. For backward compatibility. diff --git a/msal/oauth2cli/assertion.py b/msal/oauth2cli/assertion.py index bd2373a7..e84400df 100644 --- a/msal/oauth2cli/assertion.py +++ b/msal/oauth2cli/assertion.py @@ -9,17 +9,57 @@ logger = logging.getLogger(__name__) -class Signer(object): - def sign_assertion( - self, audience, issuer, subject, expires_at, +class AssertionCreator(object): + def create_normal_assertion( + self, audience, issuer, subject, expires_at=None, expires_in=600, issued_at=None, assertion_id=None, **kwargs): - # Names are defined in https://tools.ietf.org/html/rfc7521#section-5 + """Create an assertion in bytes, based on the provided claims. + + All parameter names are defined in https://tools.ietf.org/html/rfc7521#section-5 + except the expires_in is defined here as lifetime-in-seconds, + which will be automatically translated into expires_at in UTC. + """ raise NotImplementedError("Will be implemented by sub-class") + def create_regenerative_assertion( + self, audience, issuer, subject=None, expires_in=600, **kwargs): + """Create an assertion as a callable, + which will then compute the assertion later when necessary. + + This is a useful optimization to reuse the client assertion. + """ + return AutoRefresher( # Returns a callable + lambda a=audience, i=issuer, s=subject, e=expires_in, kwargs=kwargs: + self.create_normal_assertion(a, i, s, expires_in=e, **kwargs), + expires_in=max(expires_in-60, 0)) + + +class AutoRefresher(object): + """Cache the output of a factory, and auto-refresh it when necessary. Usage:: -class JwtSigner(Signer): + r = AutoRefresher(time.time, expires_in=5) + for i in range(15): + print(r()) # the timestamp change only after every 5 seconds + time.sleep(1) + """ + def __init__(self, factory, expires_in=540): + self._factory = factory + self._expires_in = expires_in + self._buf = {} + def __call__(self): + EXPIRES_AT, VALUE = "expires_at", "value" + now = time.time() + if self._buf.get(EXPIRES_AT, 0) <= now: + logger.debug("Regenerating new assertion") + self._buf = {VALUE: self._factory(), EXPIRES_AT: now + self._expires_in} + else: + logger.debug("Reusing still valid assertion") + return self._buf.get(VALUE) + + +class JwtAssertionCreator(AssertionCreator): def __init__(self, key, algorithm, sha1_thumbprint=None, headers=None): - """Create a signer. + """Construct a Jwt assertion creator. Args: @@ -37,11 +77,11 @@ def __init__(self, key, algorithm, sha1_thumbprint=None, headers=None): self.headers["x5t"] = base64.urlsafe_b64encode( binascii.a2b_hex(sha1_thumbprint)).decode() - def sign_assertion( - self, audience, issuer, subject=None, expires_at=None, + def create_normal_assertion( + self, audience, issuer, subject=None, expires_at=None, expires_in=600, issued_at=None, assertion_id=None, not_before=None, additional_claims=None, **kwargs): - """Sign a JWT Assertion. + """Create a JWT Assertion. Parameters are defined in https://tools.ietf.org/html/rfc7523#section-3 Key-value pairs in additional_claims will be added into payload as-is. @@ -51,7 +91,7 @@ def sign_assertion( 'aud': audience, 'iss': issuer, 'sub': subject or issuer, - 'exp': expires_at or (now + 10*60), # 10 minutes + 'exp': expires_at or (now + expires_in), 'iat': issued_at or now, 'jti': assertion_id or str(uuid.uuid4()), } @@ -68,3 +108,9 @@ def sign_assertion( 'See https://pyjwt.readthedocs.io/en/latest/installation.html#cryptographic-dependencies-optional') raise + +# Obsolete. For backward compatibility. They will be removed in future versions. +Signer = AssertionCreator # For backward compatibility +JwtSigner = JwtAssertionCreator # For backward compatibility +JwtSigner.sign_assertion = JwtAssertionCreator.create_normal_assertion # For backward compatibility + diff --git a/msal/oauth2cli/oauth2.py b/msal/oauth2cli/oauth2.py index b9727cf5..918fb806 100644 --- a/msal/oauth2cli/oauth2.py +++ b/msal/oauth2cli/oauth2.py @@ -33,7 +33,7 @@ def __init__( server_configuration, # type: dict client_id, # type: str client_secret=None, # type: Optional[str] - client_assertion=None, # type: Optional[bytes] + client_assertion=None, # type: Union[bytes, callable, None] client_assertion_type=None, # type: Optional[str] default_headers=None, # type: Optional[dict] default_body=None, # type: Optional[dict] @@ -55,10 +55,12 @@ def __init__( https://example.com/.../.well-known/openid-configuration client_id (str): The client's id, issued by the authorization server client_secret (str): Triggers HTTP AUTH for Confidential Client - client_assertion (bytes): + client_assertion (bytes, callable): The client assertion to authenticate this client, per RFC 7521. It can be a raw SAML2 assertion (this method will encode it for you), or a raw JWT assertion. + It can also be a callable (recommended), + so that we will do lazy creation of an assertion. client_assertion_type (str): The type of your :attr:`client_assertion` parameter. It is typically the value of :attr:`CLIENT_ASSERTION_TYPE_SAML2` or @@ -75,11 +77,9 @@ def __init__( self.configuration = server_configuration self.client_id = client_id self.client_secret = client_secret + self.client_assertion = client_assertion self.default_body = default_body or {} - if client_assertion is not None and client_assertion_type is not None: - # See https://tools.ietf.org/html/rfc7521#section-4.2 - encoder = self.client_assertion_encoders.get(client_assertion_type, lambda a: a) - self.default_body["client_assertion"] = encoder(client_assertion) + if client_assertion_type is not None: self.default_body["client_assertion_type"] = client_assertion_type self.logger = logging.getLogger(__name__) self.session = s = requests.Session() @@ -114,6 +114,15 @@ def _obtain_token( # The verb "obtain" is influenced by OAUTH2 RFC 6749 **kwargs # Relay all extra parameters to underlying requests ): # Returns the json object came from the OAUTH2 response _data = {'client_id': self.client_id, 'grant_type': grant_type} + + if self.default_body.get("client_assertion_type") and self.client_assertion: + # See https://tools.ietf.org/html/rfc7521#section-4.2 + encoder = self.client_assertion_encoders.get( + self.default_body["client_assertion_type"], lambda a: a) + _data["client_assertion"] = encoder( + self.client_assertion() # Do lazy on-the-fly computation + if callable(self.client_assertion) else self.client_assertion) + _data.update(self.default_body) # It may contain authen parameters _data.update(data or {}) # So the content in data param prevails # We don't have to clean up None values here, because requests lib will. diff --git a/sample/authorization-code-flow-sample/authorization_code_flow_sample.py b/sample/authorization-code-flow-sample/authorization_code_flow_sample.py index 48d32e80..eea11dff 100644 --- a/sample/authorization-code-flow-sample/authorization_code_flow_sample.py +++ b/sample/authorization-code-flow-sample/authorization_code_flow_sample.py @@ -24,6 +24,7 @@ import json import logging import uuid +import os import flask @@ -31,7 +32,8 @@ app = flask.Flask(__name__) app.debug = True -app.secret_key = sys.argv[2] # In this demo, we expect a secret from 2nd CLI param +app.secret_key = os.environ.get("FLASK_SECRET") +assert app.secret_key, "This sample requires a FLASK_SECRET env var to enable session" # Optional logging diff --git a/sample/device_flow_sample.py b/sample/device_flow_sample.py index 8c46c6b0..21923ef7 100644 --- a/sample/device_flow_sample.py +++ b/sample/device_flow_sample.py @@ -56,6 +56,10 @@ # Ideally you should wait here, in order to save some unnecessary polling # input("Press Enter after you successfully login from another device...") result = app.acquire_token_by_device_flow(flow) # By default it will block + # You can follow this instruction to shorten the block time + # https://msal-python.readthedocs.io/en/latest/#msal.PublicClientApplication.acquire_token_by_device_flow + # or you may even turn off the blocking behavior, + # and then keep calling acquire_token_by_device_flow(flow) in your own customized loop. if "access_token" in result: print(result["access_token"]) diff --git a/setup.py b/setup.py index c1e0b7ac..06256333 100644 --- a/setup.py +++ b/setup.py @@ -66,7 +66,8 @@ 'License :: OSI Approved :: MIT License', 'Operating System :: OS Independent', ], - packages=find_packages(), + packages=find_packages(exclude=["tests"]), + data_files=[('', ['LICENSE'])], install_requires=[ 'requests>=2.0.0,<3', 'PyJWT[crypto]>=1.0.0,<2', diff --git a/tests/test_application.py b/tests/test_application.py index 3860c735..5e4c3b3a 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -20,7 +20,7 @@ with open(CONFIG_FILE) as conf: CONFIG = json.load(conf) -logger = logging.getLogger(__file__) +logger = logging.getLogger(__name__) logging.basicConfig(level=logging.DEBUG) @@ -99,6 +99,46 @@ def test_client_certificate(self): self.assertIn('access_token', result) self.assertCacheWorks(result, app.acquire_token_silent(scope, account=None)) + def test_extract_a_tag_less_public_cert(self): + pem = "my_cert" + self.assertEqual(["my_cert"], extract_certs(pem)) + + def test_extract_a_tag_enclosed_cert(self): + pem = """ + -----BEGIN CERTIFICATE----- + my_cert + -----END CERTIFICATE----- + """ + self.assertEqual(["my_cert"], extract_certs(pem)) + + def test_extract_multiple_tag_enclosed_certs(self): + pem = """ + -----BEGIN CERTIFICATE----- + my_cert1 + -----END CERTIFICATE----- + + -----BEGIN CERTIFICATE----- + my_cert2 + -----END CERTIFICATE----- + """ + self.assertEqual(["my_cert1", "my_cert2"], extract_certs(pem)) + + @unittest.skipUnless("public_certificate" in CONFIG, "Missing Public cert") + def test_subject_name_issuer_authentication(self): + assert ("private_key_file" in CONFIG + and "thumbprint" in CONFIG and "public_certificate" in CONFIG) + with open(os.path.join(THIS_FOLDER, CONFIG['private_key_file'])) as f: + pem = f.read() + with open(os.path.join(THIS_FOLDER, CONFIG['public_certificate'])) as f: + public_certificate = f.read() + app = ConfidentialClientApplication( + CONFIG['client_id'], authority=CONFIG["authority"], + client_credential={"private_key": pem, "thumbprint": CONFIG["thumbprint"], + "public_certificate": public_certificate}) + scope = CONFIG.get("scope", []) + result = app.acquire_token_for_client(scope) + self.assertIn('access_token', result) + self.assertCacheWorks(result, app.acquire_token_silent(scope, account=None)) @unittest.skipUnless("client_id" in CONFIG, "client_id missing") class TestPublicClientApplication(Oauth2TestCase): diff --git a/tests/test_assertion.py b/tests/test_assertion.py new file mode 100644 index 00000000..a4921138 --- /dev/null +++ b/tests/test_assertion.py @@ -0,0 +1,15 @@ +import json + +from msal.oauth2cli import JwtSigner +from msal.oauth2cli.oidc import base64decode + +from tests import unittest + + +class AssertionTestCase(unittest.TestCase): + def test_extra_claims(self): + assertion = JwtSigner(key=None, algorithm="none").sign_assertion( + "audience", "issuer", additional_claims={"client_ip": "1.2.3.4"}) + payload = json.loads(base64decode(assertion.split(b'.')[1].decode('utf-8'))) + self.assertEqual("1.2.3.4", payload.get("client_ip")) +