diff --git a/src/oidcendpoint/jwt_token.py b/src/oidcendpoint/jwt_token.py index 361fe56..a032dff 100644 --- a/src/oidcendpoint/jwt_token.py +++ b/src/oidcendpoint/jwt_token.py @@ -3,10 +3,12 @@ from typing import Optional from cryptojwt import JWT +from cryptojwt.jws.exception import JWSException from oidcendpoint.exception import ToOld from oidcendpoint.token_handler import Token from oidcendpoint.token_handler import is_expired +from oidcendpoint.token_handler import UnknownToken from oidcendpoint.user_info import scope2claims @@ -94,7 +96,10 @@ def info(self, token): :return: tuple of token type and session id """ verifier = JWT(key_jar=self.key_jar, allowed_sign_algs=[self.alg]) - _payload = verifier.unpack(token) + try: + _payload = verifier.unpack(token) + except JWSException: + raise UnknownToken() if is_expired(_payload["exp"]): raise ToOld("Token has expired") diff --git a/src/oidcendpoint/oauth2/introspection.py b/src/oidcendpoint/oauth2/introspection.py index 1422cb5..c0d3a2b 100644 --- a/src/oidcendpoint/oauth2/introspection.py +++ b/src/oidcendpoint/oauth2/introspection.py @@ -1,8 +1,6 @@ """Implements RFC7662""" import logging -from cryptojwt import JWT -from cryptojwt.jws.jws import factory from oidcmsg import oauth2 from oidcmsg.time_util import utc_time_sans_frac @@ -11,18 +9,6 @@ LOGGER = logging.getLogger(__name__) -def before(t1, t2, range): - if t1 < (t2 - range): - return True - return False - - -def after(t1, t2, range): - if t1 > (t2 + range): - return True - return False - - class Introspection(Endpoint): """Implements RFC 7662""" @@ -49,49 +35,30 @@ def get_client_id_from_token(self, endpoint_context, token, request=None): sinfo = endpoint_context.sdb[token] return sinfo["authn_req"]["client_id"] - def do_jws(self, token): - _jwt = JWT(key_jar=self.endpoint_context.keyjar) - + def _introspect(self, token): try: - _jwt_info = _jwt.unpack(token) - except Exception as err: - return None - - return _jwt_info - - def do_access_token(self, token): - try: - _info = self.endpoint_context.sdb[token] + info = self.endpoint_context.sdb[token] except KeyError: return None - _revoked = _info.get("revoked", False) - if _revoked: + # Make sure that the token is an access_token or a refresh_token + if token not in info.get("access_token") and token != info.get( + "refresh_token" + ): return None - _eat = _info.get("expires_at") - if _eat < utc_time_sans_frac(): + eat = info.get("expires_at") + if eat and eat < utc_time_sans_frac(): return None - if _info: # Now what can be returned ? - _ret = { - "sub": _info["sub"], - "client_id": _info["client_id"], - "token_type": "bearer", - "iss": self.endpoint_context.issuer - } - _authn = _info.get("authn_event") - if _authn: - _ret["authn_info"] = _authn["authn_info"] - _ret["authn_time"] = _authn["authn_time"] - - _scope = _info.get("scope") - if not _scope: - _ret["scope"] = " ".join(_info["authn_req"]["scope"]) - - return _ret - else: - return _info + if info: # Now what can be returned ? + ret = info.to_dict() + ret["iss"] = self.endpoint_context.issuer + + if "scope" not in ret: + ret["scope"] = " ".join(info["authn_req"]["scope"]) + + return ret def process_request(self, request=None, **kwargs): """ @@ -107,29 +74,8 @@ def process_request(self, request=None, **kwargs): _token = _introspect_request["token"] _resp = self.response_cls(active=False) - if factory(_token): - _info = self.do_jws(_token) - if _info is None: - return {"response_args": _resp} - _now = utc_time_sans_frac() - - # Time checks - if "exp" in _info: - if after(_now, _info["exp"], self.offset): - return {"response_args": _resp} - if 'iat' in _info: - if after(_info["iat"], _now, self.offset): - return {"response_args": _resp} - if 'nbf' in _info: - if before(_now, _info["nbf"], self.offset): - return {"response_args": _resp} - else: - # A non-jws access token - _info = self.do_access_token(_token) - if _info is None: - return {"response_args": _resp} - - if not _info: + _info = self._introspect(_token) + if _info is None: return {"response_args": _resp} if "release" in self.kwargs: diff --git a/src/oidcendpoint/session.py b/src/oidcendpoint/session.py index 05182f1..1ccc994 100644 --- a/src/oidcendpoint/session.py +++ b/src/oidcendpoint/session.py @@ -461,8 +461,8 @@ def revoke_session(self, sid="", token=""): _sinfo = self[sid] for token_type in self.handler.keys(): _sinfo.pop(token_type, None) - - self.update(sid, revoked=True) + _sinfo["revoked"] = True + self[sid] = _sinfo def get_client_id_for_session(self, sid): return self[sid]["client_id"] diff --git a/src/oidcendpoint/token_handler.py b/src/oidcendpoint/token_handler.py index 95f1a9c..f6c1fe0 100755 --- a/src/oidcendpoint/token_handler.py +++ b/src/oidcendpoint/token_handler.py @@ -7,6 +7,7 @@ from cryptojwt.key_jar import init_key_jar from cryptojwt.utils import as_bytes from cryptojwt.utils import as_unicode +from cryptojwt.exception import BadSyntax from oidcmsg.time_util import time_sans_frac from oidcendpoint import rndstr @@ -218,7 +219,8 @@ def info(self, item, order=None): for typ in order: try: return self.handler[typ].info(item) - except (KeyError, WrongTokenType, InvalidToken, UnknownToken): + except (KeyError, WrongTokenType, InvalidToken, UnknownToken, + BadSyntax): pass logger.info("Unknown token format") diff --git a/tests/test_31_introspection.py b/tests/test_31_introspection.py index 6810a53..71a454f 100644 --- a/tests/test_31_introspection.py +++ b/tests/test_31_introspection.py @@ -9,6 +9,7 @@ from cryptojwt.utils import as_bytes from oidcmsg.oauth2 import TokenIntrospectionRequest from oidcmsg.oidc import AccessTokenRequest +from oidcmsg.oidc import AccessTokenResponse from oidcmsg.oidc import AuthorizationRequest from oidcmsg.time_util import utc_time_sans_frac @@ -88,9 +89,10 @@ def full_path(local_file): return os.path.join(BASEDIR, local_file) -class TestEndpoint(object): +@pytest.mark.parametrize("jwt_token", [True, False]) +class TestEndpoint: @pytest.fixture(autouse=True) - def create_endpoint(self): + def create_endpoint(self, jwt_token): conf = { "issuer": "https://example.com/", "password": "mycket hemligt", @@ -100,6 +102,26 @@ def create_endpoint(self): "verify_ssl": False, "capabilities": CAPABILITIES, "jwks": {"uri_path": "jwks.json", "key_defs": KEYDEFS}, + "token_handler_args": { + "jwks_def": { + # These keys are used for encrypting the access code, the + # refresh token and the access token(when it's not a JWT), + # I don't think these should be configurable. + # Should these keys be stored somewhere? + "read_only": False, + "key_defs": [ + {"type": "oct", "bytes": 24, "use": ["enc"], + "kid": "code"}, + {"type": "oct", "bytes": 24, "use": ["enc"], + "kid": "refresh"}, + {"type": "oct", "bytes": 24, "use": ["enc"], + "kid": "token"}, + ], + }, + "code": {"lifetime": 600}, + "token": {"lifetime": 3600}, + "refresh": {"lifetime": 86400}, + }, "endpoint": { "authorization": { "path": "{}/authorization", @@ -142,6 +164,9 @@ def create_endpoint(self): "client_authn": verify_client, "template_dir": "template", } + if jwt_token: + conf["token_handler_args"]["token"]["class"] = \ + "oidcendpoint.jwt_token.JWTToken" endpoint_context = EndpointContext(conf) endpoint_context.cdb["client_1"] = { "client_secret": "hemligt", @@ -157,32 +182,32 @@ def create_endpoint(self): self.introspection_endpoint = endpoint_context.endpoint["introspection"] self.token_endpoint = endpoint_context.endpoint["token"] - def _create_jwt(self, uid, lifetime=0, with_jti=False): - _jwt = JWT( - self.introspection_endpoint.endpoint_context.keyjar, - iss=self.introspection_endpoint.endpoint_context.issuer, - lifetime=lifetime, - ) + def _create_at(self, uid, lifetime=0, with_jti=False): + _context = self.introspection_endpoint.endpoint_context - if with_jti: - _jwt.with_jti = with_jti + session_id = setup_session( + _context, + AUTH_REQ, + uid=uid, + acr=INTERNETPROTOCOLPASSWORD, + ) + _token_request = TOKEN_REQ_DICT.copy() + _token_request["code"] = _context.sdb[session_id]["code"] + _context.sdb.update(session_id, user=uid) - _info = self.introspection_endpoint.endpoint_context.userinfo.db[uid] - _payload = {"sub": _info["sub"]} - return _jwt.pack(_payload, aud="client_1") + _req = self.token_endpoint.parse_request(_token_request) + _resp = self.token_endpoint.process_request(request=_req) + _resp = AccessTokenResponse(**_resp["response_args"]) + return _resp["access_token"] def test_parse_no_authn(self): - _ = setup_session( - self.introspection_endpoint.endpoint_context, AUTH_REQ, uid="diana" - ) - _token = self._create_jwt("diana") - with pytest.raises(UnAuthorizedClient): + _token = self._create_at("diana") + with pytest.raises(UnknownOrNoAuthnMethod): self.introspection_endpoint.parse_request({"token": _token}) def test_parse_with_client_auth_in_req(self): _context = self.introspection_endpoint.endpoint_context - _ = setup_session(_context, AUTH_REQ, uid="diana") - _token = self._create_jwt("diana") + _token = self._create_at("diana") _req = self.introspection_endpoint.parse_request( { "token": _token, @@ -196,8 +221,7 @@ def test_parse_with_client_auth_in_req(self): def test_parse_with_wrong_client_authn(self): _context = self.introspection_endpoint.endpoint_context - _ = setup_session(_context, AUTH_REQ, uid="diana") - _token = self._create_jwt("diana") + _token = self._create_at("diana") _basic_token = "{}:{}".format( "client_1", _context.cdb["client_1"]["client_secret"] ) @@ -209,8 +233,7 @@ def test_parse_with_wrong_client_authn(self): def test_process_request(self): _context = self.introspection_endpoint.endpoint_context - _ = setup_session(_context, AUTH_REQ, uid="diana") - _token = self._create_jwt("diana", lifetime=6000) + _token = self._create_at("diana", lifetime=6000) _req = self.introspection_endpoint.parse_request( { "token": _token, @@ -225,8 +248,7 @@ def test_process_request(self): def test_do_response(self): _context = self.introspection_endpoint.endpoint_context - _ = setup_session(_context, AUTH_REQ, uid="diana") - _token = self._create_jwt("diana", lifetime=6000, with_jti=True) + _token = self._create_at("diana", lifetime=6000, with_jti=True) _req = self.introspection_endpoint.parse_request( { "token": _token, @@ -245,20 +267,18 @@ def test_do_response(self): ] _payload = json.loads(msg_info["response"]) assert set(_payload.keys()) == { - "sub", - "username", - "exp", - "iat", - "aud", "active", "iss", - "jti", + "scope", + "token_type", + "sub", + "client_id", } assert _payload["active"] is True def test_do_response_no_token(self): _context = self.introspection_endpoint.endpoint_context - _ = setup_session(_context, AUTH_REQ, uid="diana") + self._create_at("diana", lifetime=6000, with_jti=True) _req = self.introspection_endpoint.parse_request( { "client_id": "client_1", @@ -270,23 +290,10 @@ def test_do_response_no_token(self): def test_access_token(self): _context = self.introspection_endpoint.endpoint_context - - session_id = setup_session( - _context, - AUTH_REQ, - uid="user", - acr=INTERNETPROTOCOLPASSWORD, - ) - _token_request = TOKEN_REQ_DICT.copy() - _token_request["code"] = _context.sdb[session_id]["code"] - _context.sdb.update(session_id, user="diana") - - _req = self.token_endpoint.parse_request(_token_request) - _resp = self.token_endpoint.process_request(request=_req) - + _token = self._create_at("diana", lifetime=6000, with_jti=True) _req = self.introspection_endpoint.parse_request( { - "token": _resp["response_args"]["access_token"], + "token": _token, "client_id": "client_1", "client_secret": _context.cdb["client_1"]["client_secret"], } @@ -326,27 +333,14 @@ def test_jwt_unknown_key(self): def test_expired_access_token(self): _context = self.introspection_endpoint.endpoint_context - - session_id = setup_session( - _context, - AUTH_REQ, - uid="user", - acr=INTERNETPROTOCOLPASSWORD, - ) - _token_request = TOKEN_REQ_DICT.copy() - _token_request["code"] = _context.sdb[session_id]["code"] - _context.sdb.update(session_id, user="diana") - - _req = self.token_endpoint.parse_request(_token_request) - _resp = self.token_endpoint.process_request(request=_req) - - _info = self.token_endpoint.endpoint_context.sdb[_resp["response_args"]["access_token"]] + _token = self._create_at("diana", lifetime=6000, with_jti=True) + _info = self.token_endpoint.endpoint_context.sdb[_token] _info['expires_at'] = utc_time_sans_frac() - 1000 - self.token_endpoint.endpoint_context.sdb[_resp["response_args"]["access_token"]] = _info + self.token_endpoint.endpoint_context.sdb[_token] = _info _req = self.introspection_endpoint.parse_request( { - "token": _resp["response_args"]["access_token"], + "token": _token, "client_id": "client_1", "client_secret": _context.cdb["client_1"]["client_secret"], } @@ -355,30 +349,18 @@ def test_expired_access_token(self): assert _resp["response_args"]["active"] is False def test_revoked_access_token(self): + _token = self._create_at("diana", lifetime=6000, with_jti=True) _context = self.introspection_endpoint.endpoint_context - session_id = setup_session( - _context, - AUTH_REQ, - uid="user", - acr=INTERNETPROTOCOLPASSWORD, - ) - _token_request = TOKEN_REQ_DICT.copy() - _token_request["code"] = _context.sdb[session_id]["code"] - _context.sdb.update(session_id, user="diana") - - _req = self.token_endpoint.parse_request(_token_request) - _resp = self.token_endpoint.process_request(request=_req) - self.token_endpoint.endpoint_context.sdb.revoke_session( - token=_resp["response_args"]["access_token"]) + token=_token) _req = self.introspection_endpoint.parse_request( { - "token": _resp["response_args"]["access_token"], + "token": _token, "client_id": "client_1", "client_secret": _context.cdb["client_1"]["client_secret"], } ) _resp = self.introspection_endpoint.process_request(_req) - assert _resp["response_args"]["active"] is False \ No newline at end of file + assert _resp["response_args"]["active"] is False