Skip to content
This repository was archived by the owner on Jun 12, 2021. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions src/oidcendpoint/id_token.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,9 @@ class IDToken(object):
def __init__(self, endpoint_context, **kwargs):
self.endpoint_context = endpoint_context
self.kwargs = kwargs
self.enable_claims_per_client = kwargs.get(
'enable_claims_per_client', False
)
self.scope_to_claims = None
self.provider_info = construct_endpoint_info(
self.default_capabilities, **kwargs
Expand Down Expand Up @@ -242,7 +245,6 @@ def sign_encrypt(

def make(self, req, sess_info, authn_req=None, user_claims=False, **kwargs):
_context = self.endpoint_context
_sdb = _context.sdb

if authn_req:
_client_id = authn_req["client_id"]
Expand All @@ -251,11 +253,13 @@ def make(self, req, sess_info, authn_req=None, user_claims=False, **kwargs):

_cinfo = _context.cdb[_client_id]

default_idtoken_claims = dict(self.kwargs.get("default_claims", {}))
idtoken_claims = dict(self.kwargs.get("default_claims", {}))
if self.enable_claims_per_client:
idtoken_claims.update(_cinfo.get("id_token_claims", {}))
lifetime = self.kwargs.get("lifetime")

userinfo = userinfo_in_id_token_claims(
_context, sess_info, default_idtoken_claims
_context, sess_info, idtoken_claims
)

if user_claims:
Expand Down
128 changes: 127 additions & 1 deletion tests/test_03_id_token.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ def full_path(local_file):
return os.path.join(BASEDIR, local_file)


USERINFO = UserInfo(json.loads(open(full_path("users.json")).read()))
USERS = json.loads(open(full_path("users.json")).read())
USERINFO = UserInfo(USERS)

AREQN = AuthorizationRequest(
response_type="code",
Expand Down Expand Up @@ -70,6 +71,10 @@ def full_path(local_file):
"kwargs": {"user": "diana"},
}
},
"userinfo": {
"class": "oidcendpoint.user_info.UserInfo",
"kwargs": {"db": USERS},
},
"client_authn": verify_client,
"template_dir": "template",
"id_token": {"class": IDToken, "kwargs": {"foo": "bar"}},
Expand Down Expand Up @@ -252,3 +257,124 @@ def test_get_sign_algorithm_4(self):
)
# default signing alg
assert algs == {"sign": True, "encrypt": False, "sign_alg": "RS512"}

def test_default_claims(self):
session_info = {
"authn_req": AREQN,
"sub": "sub",
"authn_event": {
"authn_info": "loa2",
"authn_time": time.time(),
"uid": "diana"
},
}
self.endpoint_context.idtoken.kwargs['default_claims'] = {
"nickname": {"essential": True}
}
req = {"client_id": "client_1"}
_token = self.endpoint_context.idtoken.make(req, session_info)
assert _token
client_keyjar = KeyJar()
_jwks = self.endpoint_context.keyjar.export_jwks()
client_keyjar.import_jwks(_jwks, self.endpoint_context.issuer)
_jwt = JWT(key_jar=client_keyjar, iss="client_1")
res = _jwt.unpack(_token)
assert "nickname" in res

def test_no_default_claims(self):
session_info = {
"authn_req": AREQN,
"sub": "sub",
"authn_event": {
"authn_info": "loa2",
"authn_time": time.time(),
"uid": "diana"
},
}
req = {"client_id": "client_1"}
_token = self.endpoint_context.idtoken.make(req, session_info)
assert _token
client_keyjar = KeyJar()
_jwks = self.endpoint_context.keyjar.export_jwks()
client_keyjar.import_jwks(_jwks, self.endpoint_context.issuer)
_jwt = JWT(key_jar=client_keyjar, iss="client_1")
res = _jwt.unpack(_token)
assert "nickname" not in res

def test_client_claims(self):
session_info = {
"authn_req": AREQN,
"sub": "sub",
"authn_event": {
"authn_info": "loa2",
"authn_time": time.time(),
"uid": "diana"
},
}
self.endpoint_context.idtoken.enable_claims_per_client = True
self.endpoint_context.cdb["client_1"]['id_token_claims'] = {
"address": None
}
req = {"client_id": "client_1"}
_token = self.endpoint_context.idtoken.make(req, session_info)
assert _token
client_keyjar = KeyJar()
_jwks = self.endpoint_context.keyjar.export_jwks()
client_keyjar.import_jwks(_jwks, self.endpoint_context.issuer)
_jwt = JWT(key_jar=client_keyjar, iss="client_1")
res = _jwt.unpack(_token)
assert "address" in res
assert "nickname" not in res

def test_client_claims_with_default(self):
session_info = {
"authn_req": AREQN,
"sub": "sub",
"authn_event": {
"authn_info": "loa2",
"authn_time": time.time(),
"uid": "diana"
},
}
self.endpoint_context.cdb["client_1"]['id_token_claims'] = {
"address": None
}
self.endpoint_context.idtoken.kwargs['default_claims'] = {
"nickname": {"essential": True}
}
self.endpoint_context.idtoken.enable_claims_per_client = True
req = {"client_id": "client_1"}
_token = self.endpoint_context.idtoken.make(req, session_info)
assert _token
client_keyjar = KeyJar()
_jwks = self.endpoint_context.keyjar.export_jwks()
client_keyjar.import_jwks(_jwks, self.endpoint_context.issuer)
_jwt = JWT(key_jar=client_keyjar, iss="client_1")
res = _jwt.unpack(_token)
assert "address" in res
assert "nickname" in res

def test_client_claims_disabled(self):
# enable_claims_per_client defaults to False
session_info = {
"authn_req": AREQN,
"sub": "sub",
"authn_event": {
"authn_info": "loa2",
"authn_time": time.time(),
"uid": "diana"
},
}
self.endpoint_context.cdb["client_1"]['id_token_claims'] = {
"address": None
}
req = {"client_id": "client_1"}
_token = self.endpoint_context.idtoken.make(req, session_info)
assert _token
client_keyjar = KeyJar()
_jwks = self.endpoint_context.keyjar.export_jwks()
client_keyjar.import_jwks(_jwks, self.endpoint_context.issuer)
_jwt = JWT(key_jar=client_keyjar, iss="client_1")
res = _jwt.unpack(_token)
assert "address" not in res
assert "nickname" not in res