From d5bc366cdb66bf8e06ca27adcb4f74c59b321f68 Mon Sep 17 00:00:00 2001 From: Nikos Sklikas Date: Thu, 29 Oct 2020 14:02:07 +0200 Subject: [PATCH] Add add_scope option for JWT access token Add add_scope option JWT access tokens, enabling this will add a list with the allowed scopes that were requested in the returned JWT. --- src/oidcendpoint/jwt_token.py | 10 +++++++++- src/oidcendpoint/scopes.py | 4 ++++ src/oidcendpoint/userinfo.py | 5 ++--- tests/test_27_jwt_token.py | 15 +++++++++++++++ 4 files changed, 30 insertions(+), 4 deletions(-) diff --git a/src/oidcendpoint/jwt_token.py b/src/oidcendpoint/jwt_token.py index d211eea..e3f4ec8 100644 --- a/src/oidcendpoint/jwt_token.py +++ b/src/oidcendpoint/jwt_token.py @@ -16,6 +16,7 @@ class JWTToken(Token): init_args = { "add_claims_by_scope": False, "enable_claims_per_client": False, + "add_scope": False, "add_claims": {}, } @@ -49,6 +50,7 @@ def __init__( self.add_claims = self.init_args["add_claims"] self.add_claims_by_scope = self.init_args["add_claims_by_scope"] + self.add_scope = self.init_args["add_scope"] self.enable_claims_per_client = self.init_args["enable_claims_per_client"] for param, default in self.init_args.items(): @@ -83,6 +85,7 @@ def __call__( :return: """ payload = {"sid": sid, "ttype": self.type, "sub": sinfo["sub"]} + scopes = sinfo["authn_req"]["scope"] if self.add_claims: self.do_add_claims(payload, uinfo, self.add_claims) @@ -94,11 +97,16 @@ def __call__( payload, uinfo, convert_scopes2claims( - sinfo["authn_req"]["scope"], + scopes, _allowed_claims, map=self.scope_claims_map, ).keys(), ) + if self.add_scope: + payload["scope"] = self.cntx.scopes_handler.filter_scopes( + client_id, self.cntx, scopes + ) + # Add claims if is access token if self.type == "T" and self.enable_claims_per_client: client = self.cdb.get(client_id, {}) diff --git a/src/oidcendpoint/scopes.py b/src/oidcendpoint/scopes.py index b3413ae..8a04544 100644 --- a/src/oidcendpoint/scopes.py +++ b/src/oidcendpoint/scopes.py @@ -82,6 +82,10 @@ def allowed_scopes(self, client_id, endpoint_context): return available_scopes(endpoint_context) return [] + def filter_scopes(self, client_id, endpoint_context, scopes): + allowed_scopes = self.allowed_scopes(client_id, endpoint_context) + return [s for s in scopes if s in allowed_scopes] + class Claims: def __init__(self): diff --git a/src/oidcendpoint/userinfo.py b/src/oidcendpoint/userinfo.py index 56684fd..590c2ad 100755 --- a/src/oidcendpoint/userinfo.py +++ b/src/oidcendpoint/userinfo.py @@ -127,10 +127,9 @@ def collect_user_info( if scope_to_claims is None: scope_to_claims = endpoint_context.scope2claims - _allowed = endpoint_context.scopes_handler.allowed_scopes( - authn_req["client_id"], endpoint_context + supported_scopes = endpoint_context.scopes_handler.filter_scopes( + authn_req["client_id"], endpoint_context, authn_req["scope"] ) - supported_scopes = [s for s in authn_req["scope"] if s in _allowed] if userinfo_claims is None: _allowed_claims = endpoint_context.claims_handler.allowed_claims( authn_req["client_id"], endpoint_context diff --git a/tests/test_27_jwt_token.py b/tests/test_27_jwt_token.py index cd37c27..da0336d 100644 --- a/tests/test_27_jwt_token.py +++ b/tests/test_27_jwt_token.py @@ -205,6 +205,21 @@ def test_client_claims(self, enable_claims_per_client): res = _jwt.unpack(token) assert enable_claims_per_client is ("address" in res) + @pytest.mark.parametrize("add_scope", [True, False]) + def test_add_scopes(self, add_scope): + ec = self.endpoint.endpoint_context + handler = ec.sdb.handler.handler["access_token"] + auth_req = dict(AUTH_REQ) + auth_req["scope"] = ["openid", "profile", "aba"] + session_id = setup_session(ec, auth_req, uid="diana") + handler.add_scope = add_scope + _dic = ec.sdb.upgrade_to_token(key=session_id) + + token = _dic["access_token"] + _jwt = JWT(key_jar=KEYJAR, iss="client_1") + res = _jwt.unpack(token) + assert add_scope is (res.get("scope") == ["openid", "profile"]) + def test_is_expired(self): session_id = setup_session( self.endpoint.endpoint_context, AUTH_REQ, uid="diana"