Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 52 additions & 8 deletions msal/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import requests

from .oauth2cli import Client, JwtSigner
from .oauth2cli import Client, JwtAssertionCreator
from .authority import Authority
from .mex import send_request as mex_send_request
from .wstrust_request import send_request as wst_send_request
Expand All @@ -18,7 +18,7 @@


# The __init__.py will import this. Not the other way around.
__version__ = "0.4.1"
__version__ = "0.5.0"

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -50,16 +50,33 @@ def decorate_scope(
return list(decorated)


def extract_certs(public_cert_content):
# Parses raw public certificate file contents and returns a list of strings
# Usage: headers = {"x5c": extract_certs(open("my_cert.pem").read())}
public_certificates = re.findall(
r'-----BEGIN CERTIFICATE-----(?P<cert_value>[^-]+)-----END CERTIFICATE-----',
public_cert_content, re.I)
if public_certificates:
return [cert.strip() for cert in public_certificates]
# The public cert tags are not found in the input,
# let's make best effort to exclude a private key pem file.
if "PRIVATE KEY" in public_cert_content:
raise ValueError(
"We expect your public key but detect a private key instead")
return [public_cert_content.strip()]


class ClientApplication(object):

def __init__(
self, client_id,
client_credential=None, authority=None, validate_authority=True,
token_cache=None,
verify=True, proxies=None, timeout=None):
verify=True, proxies=None, timeout=None,
client_claims=None):
"""Create an instance of application.

:param client_id: Your app has a clinet_id after you register it on AAD.
:param client_id: Your app has a client_id after you register it on AAD.
:param client_credential:
For :class:`PublicClientApplication`, you simply use `None` here.
For :class:`ConfidentialClientApplication`,
Expand All @@ -69,6 +86,28 @@ def __init__(
{
"private_key": "...-----BEGIN PRIVATE KEY-----...",
"thumbprint": "A1B2C3D4E5F6...",
"public_certificate": "...-----BEGIN CERTIFICATE-----..." (Optional. See below.)
}

*Added in version 0.5.0*:
public_certificate (optional) is public key certificate
which will be sent through 'x5c' JWT header only for
subject name and issuer authentication to support cert auto rolls.

:param dict client_claims:
*Added in version 0.5.0*:
It is a dictionary of extra claims that would be signed by
by this :class:`ConfidentialClientApplication` 's private key.
For example, you can use {"client_ip": "x.x.x.x"}.
You may also override any of the following default claims::

{
"aud": the_token_endpoint,
"iss": self.client_id,
"sub": same_as_issuer,
"exp": now + 10_min,
"iat": now,
"jti": a_random_uuid
}

:param str authority:
Expand All @@ -95,6 +134,7 @@ def __init__(
"""
self.client_id = client_id
self.client_credential = client_credential
self.client_claims = client_claims
self.verify = verify
self.proxies = proxies
self.timeout = timeout
Expand All @@ -113,11 +153,15 @@ def _build_client(self, client_credential, authority):
if isinstance(client_credential, dict):
assert ("private_key" in client_credential
and "thumbprint" in client_credential)
signer = JwtSigner(
headers = {}
if 'public_certificate' in client_credential:
headers["x5c"] = extract_certs(client_credential['public_certificate'])
assertion = JwtAssertionCreator(
client_credential["private_key"], algorithm="RS256",
sha1_thumbprint=client_credential.get("thumbprint"))
client_assertion = signer.sign_assertion(
audience=authority.token_endpoint, issuer=self.client_id)
sha1_thumbprint=client_credential.get("thumbprint"), headers=headers)
client_assertion = assertion.create_regenerative_assertion(
audience=authority.token_endpoint, issuer=self.client_id,
additional_claims=self.client_claims or {})
client_assertion_type = Client.CLIENT_ASSERTION_TYPE_JWT
else:
default_body['client_secret'] = client_credential
Expand Down
5 changes: 3 additions & 2 deletions msal/oauth2cli/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
__version__ = "0.2.0"
__version__ = "0.3.0"

from .oidc import Client
from .assertion import JwtSigner
from .assertion import JwtAssertionCreator
from .assertion import JwtSigner # Obsolete. For backward compatibility.

66 changes: 56 additions & 10 deletions msal/oauth2cli/assertion.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,57 @@

logger = logging.getLogger(__name__)

class Signer(object):
def sign_assertion(
self, audience, issuer, subject, expires_at,
class AssertionCreator(object):
def create_normal_assertion(
self, audience, issuer, subject, expires_at=None, expires_in=600,
issued_at=None, assertion_id=None, **kwargs):
# Names are defined in https://tools.ietf.org/html/rfc7521#section-5
"""Create an assertion in bytes, based on the provided claims.

All parameter names are defined in https://tools.ietf.org/html/rfc7521#section-5
except the expires_in is defined here as lifetime-in-seconds,
which will be automatically translated into expires_at in UTC.
"""
raise NotImplementedError("Will be implemented by sub-class")

def create_regenerative_assertion(
self, audience, issuer, subject=None, expires_in=600, **kwargs):
"""Create an assertion as a callable,
which will then compute the assertion later when necessary.

This is a useful optimization to reuse the client assertion.
"""
return AutoRefresher( # Returns a callable
lambda a=audience, i=issuer, s=subject, e=expires_in, kwargs=kwargs:
self.create_normal_assertion(a, i, s, expires_in=e, **kwargs),
expires_in=max(expires_in-60, 0))


class AutoRefresher(object):
"""Cache the output of a factory, and auto-refresh it when necessary. Usage::

class JwtSigner(Signer):
r = AutoRefresher(time.time, expires_in=5)
for i in range(15):
print(r()) # the timestamp change only after every 5 seconds
time.sleep(1)
"""
def __init__(self, factory, expires_in=540):
self._factory = factory
self._expires_in = expires_in
self._buf = {}
def __call__(self):
EXPIRES_AT, VALUE = "expires_at", "value"
now = time.time()
if self._buf.get(EXPIRES_AT, 0) <= now:
logger.debug("Regenerating new assertion")
self._buf = {VALUE: self._factory(), EXPIRES_AT: now + self._expires_in}
else:
logger.debug("Reusing still valid assertion")
return self._buf.get(VALUE)


class JwtAssertionCreator(AssertionCreator):
def __init__(self, key, algorithm, sha1_thumbprint=None, headers=None):
"""Create a signer.
"""Construct a Jwt assertion creator.

Args:

Expand All @@ -37,11 +77,11 @@ def __init__(self, key, algorithm, sha1_thumbprint=None, headers=None):
self.headers["x5t"] = base64.urlsafe_b64encode(
binascii.a2b_hex(sha1_thumbprint)).decode()

def sign_assertion(
self, audience, issuer, subject=None, expires_at=None,
def create_normal_assertion(
self, audience, issuer, subject=None, expires_at=None, expires_in=600,
issued_at=None, assertion_id=None, not_before=None,
additional_claims=None, **kwargs):
"""Sign a JWT Assertion.
"""Create a JWT Assertion.

Parameters are defined in https://tools.ietf.org/html/rfc7523#section-3
Key-value pairs in additional_claims will be added into payload as-is.
Expand All @@ -51,7 +91,7 @@ def sign_assertion(
'aud': audience,
'iss': issuer,
'sub': subject or issuer,
'exp': expires_at or (now + 10*60), # 10 minutes
'exp': expires_at or (now + expires_in),
'iat': issued_at or now,
'jti': assertion_id or str(uuid.uuid4()),
}
Expand All @@ -68,3 +108,9 @@ def sign_assertion(
'See https://pyjwt.readthedocs.io/en/latest/installation.html#cryptographic-dependencies-optional')
raise


# Obsolete. For backward compatibility. They will be removed in future versions.
Signer = AssertionCreator # For backward compatibility
JwtSigner = JwtAssertionCreator # For backward compatibility
JwtSigner.sign_assertion = JwtAssertionCreator.create_normal_assertion # For backward compatibility

21 changes: 15 additions & 6 deletions msal/oauth2cli/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def __init__(
server_configuration, # type: dict
client_id, # type: str
client_secret=None, # type: Optional[str]
client_assertion=None, # type: Optional[bytes]
client_assertion=None, # type: Union[bytes, callable, None]
client_assertion_type=None, # type: Optional[str]
default_headers=None, # type: Optional[dict]
default_body=None, # type: Optional[dict]
Expand All @@ -55,10 +55,12 @@ def __init__(
https://example.com/.../.well-known/openid-configuration
client_id (str): The client's id, issued by the authorization server
client_secret (str): Triggers HTTP AUTH for Confidential Client
client_assertion (bytes):
client_assertion (bytes, callable):
The client assertion to authenticate this client, per RFC 7521.
It can be a raw SAML2 assertion (this method will encode it for you),
or a raw JWT assertion.
It can also be a callable (recommended),
so that we will do lazy creation of an assertion.
client_assertion_type (str):
The type of your :attr:`client_assertion` parameter.
It is typically the value of :attr:`CLIENT_ASSERTION_TYPE_SAML2` or
Expand All @@ -75,11 +77,9 @@ def __init__(
self.configuration = server_configuration
self.client_id = client_id
self.client_secret = client_secret
self.client_assertion = client_assertion
self.default_body = default_body or {}
if client_assertion is not None and client_assertion_type is not None:
# See https://tools.ietf.org/html/rfc7521#section-4.2
encoder = self.client_assertion_encoders.get(client_assertion_type, lambda a: a)
self.default_body["client_assertion"] = encoder(client_assertion)
if client_assertion_type is not None:
self.default_body["client_assertion_type"] = client_assertion_type
self.logger = logging.getLogger(__name__)
self.session = s = requests.Session()
Expand Down Expand Up @@ -114,6 +114,15 @@ def _obtain_token( # The verb "obtain" is influenced by OAUTH2 RFC 6749
**kwargs # Relay all extra parameters to underlying requests
): # Returns the json object came from the OAUTH2 response
_data = {'client_id': self.client_id, 'grant_type': grant_type}

if self.default_body.get("client_assertion_type") and self.client_assertion:
# See https://tools.ietf.org/html/rfc7521#section-4.2
encoder = self.client_assertion_encoders.get(
self.default_body["client_assertion_type"], lambda a: a)
_data["client_assertion"] = encoder(
self.client_assertion() # Do lazy on-the-fly computation
if callable(self.client_assertion) else self.client_assertion)

_data.update(self.default_body) # It may contain authen parameters
_data.update(data or {}) # So the content in data param prevails
# We don't have to clean up None values here, because requests lib will.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,16 @@
import json
import logging
import uuid
import os

import flask

import msal

app = flask.Flask(__name__)
app.debug = True
app.secret_key = sys.argv[2] # In this demo, we expect a secret from 2nd CLI param
app.secret_key = os.environ.get("FLASK_SECRET")
assert app.secret_key, "This sample requires a FLASK_SECRET env var to enable session"


# Optional logging
Expand Down
4 changes: 4 additions & 0 deletions sample/device_flow_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@
# Ideally you should wait here, in order to save some unnecessary polling
# input("Press Enter after you successfully login from another device...")
result = app.acquire_token_by_device_flow(flow) # By default it will block
# You can follow this instruction to shorten the block time
# https://msal-python.readthedocs.io/en/latest/#msal.PublicClientApplication.acquire_token_by_device_flow
# or you may even turn off the blocking behavior,
# and then keep calling acquire_token_by_device_flow(flow) in your own customized loop.

if "access_token" in result:
print(result["access_token"])
Expand Down
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@
'License :: OSI Approved :: MIT License',
'Operating System :: OS Independent',
],
packages=find_packages(),
packages=find_packages(exclude=["tests"]),
data_files=[('', ['LICENSE'])],
install_requires=[
'requests>=2.0.0,<3',
'PyJWT[crypto]>=1.0.0,<2',
Expand Down
42 changes: 41 additions & 1 deletion tests/test_application.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
with open(CONFIG_FILE) as conf:
CONFIG = json.load(conf)

logger = logging.getLogger(__file__)
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.DEBUG)


Expand Down Expand Up @@ -99,6 +99,46 @@ def test_client_certificate(self):
self.assertIn('access_token', result)
self.assertCacheWorks(result, app.acquire_token_silent(scope, account=None))

def test_extract_a_tag_less_public_cert(self):
pem = "my_cert"
self.assertEqual(["my_cert"], extract_certs(pem))

def test_extract_a_tag_enclosed_cert(self):
pem = """
-----BEGIN CERTIFICATE-----
my_cert
-----END CERTIFICATE-----
"""
self.assertEqual(["my_cert"], extract_certs(pem))

def test_extract_multiple_tag_enclosed_certs(self):
pem = """
-----BEGIN CERTIFICATE-----
my_cert1
-----END CERTIFICATE-----

-----BEGIN CERTIFICATE-----
my_cert2
-----END CERTIFICATE-----
"""
self.assertEqual(["my_cert1", "my_cert2"], extract_certs(pem))

@unittest.skipUnless("public_certificate" in CONFIG, "Missing Public cert")
def test_subject_name_issuer_authentication(self):
assert ("private_key_file" in CONFIG
and "thumbprint" in CONFIG and "public_certificate" in CONFIG)
with open(os.path.join(THIS_FOLDER, CONFIG['private_key_file'])) as f:
pem = f.read()
with open(os.path.join(THIS_FOLDER, CONFIG['public_certificate'])) as f:
public_certificate = f.read()
app = ConfidentialClientApplication(
CONFIG['client_id'], authority=CONFIG["authority"],
client_credential={"private_key": pem, "thumbprint": CONFIG["thumbprint"],
"public_certificate": public_certificate})
scope = CONFIG.get("scope", [])
result = app.acquire_token_for_client(scope)
self.assertIn('access_token', result)
self.assertCacheWorks(result, app.acquire_token_silent(scope, account=None))

@unittest.skipUnless("client_id" in CONFIG, "client_id missing")
class TestPublicClientApplication(Oauth2TestCase):
Expand Down
15 changes: 15 additions & 0 deletions tests/test_assertion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import json

from msal.oauth2cli import JwtSigner
from msal.oauth2cli.oidc import base64decode

from tests import unittest


class AssertionTestCase(unittest.TestCase):
def test_extra_claims(self):
assertion = JwtSigner(key=None, algorithm="none").sign_assertion(
"audience", "issuer", additional_claims={"client_ip": "1.2.3.4"})
payload = json.loads(base64decode(assertion.split(b'.')[1].decode('utf-8')))
self.assertEqual("1.2.3.4", payload.get("client_ip"))