From 166cf5c156365fb67923d8e700ff2d415ef8d85a Mon Sep 17 00:00:00 2001 From: roland Date: Wed, 24 Jun 2020 11:47:37 +0200 Subject: [PATCH 1/7] Refactored module. --- src/oidcendpoint/oidc/token_coop.py | 237 +++++++++++++--------- tests/test_35_oidc_token_coop_endpoint.py | 3 +- 2 files changed, 139 insertions(+), 101 deletions(-) diff --git a/src/oidcendpoint/oidc/token_coop.py b/src/oidcendpoint/oidc/token_coop.py index 2e812df..2e54e4e 100755 --- a/src/oidcendpoint/oidc/token_coop.py +++ b/src/oidcendpoint/oidc/token_coop.py @@ -1,4 +1,5 @@ import logging +from urllib.parse import urlparse from cryptojwt.jwe.exception import JWEException from cryptojwt.jws.exception import NoSuitableSigningKeys @@ -10,6 +11,7 @@ from oidcmsg.oidc import AccessTokenResponse from oidcmsg.oidc import RefreshAccessTokenRequest from oidcmsg.oidc import TokenErrorResponse +from oidcmsg.storage import importer from oidcendpoint import sanitize from oidcendpoint.cookie import new_cookie @@ -22,50 +24,69 @@ logger = logging.getLogger(__name__) -class TokenCoop(Endpoint): - request_cls = oidc.Message - response_cls = oidc.AccessTokenResponse - error_cls = TokenErrorResponse - request_format = "json" - request_placement = "body" - response_format = "json" - response_placement = "body" - endpoint_name = "token_endpoint" - name = "token" - default_capabilities = {"token_endpoint_auth_signing_alg_values_supported": None} +def aud_and_scope(url): + p = urlparse(url) + return "{}://{}".format(p.scheme, p.netloc), p.path - def __init__(self, endpoint_context, new_refresh_token=False, **kwargs): - Endpoint.__init__(self, endpoint_context, **kwargs) - self.post_parse_request.append(self._post_parse_request) - if "client_authn_method" in kwargs: - self.endpoint_info["token_endpoint_auth_methods_supported"] = kwargs[ - "client_authn_method" - ] - self.allow_refresh = False - self.new_refresh_token = new_refresh_token - def _refresh_access_token(self, req, **kwargs): - _sdb = self.endpoint_context.sdb +class EndpointHelper: + def __init__(self, endpoint, config=None): + self.endpoint = endpoint + self.config = config - rtoken = req["refresh_token"] - try: - _info = _sdb.refresh_token(rtoken, new_refresh=self.new_refresh_token) - except ExpiredToken: - return self.error_cls( - error="invalid_request", error_description="Refresh token is expired" - ) + def post_parse_request(self, request, client_id="", **kwargs): + """Context specific parsing of the request. + This is done after general request parsing and before processing + the request. + """ + raise NotImplemented() - return by_schema(AccessTokenResponse, **_info) + def process_request(self, req, **kwargs): + """Acts on a process request.""" + raise NotImplemented() + + +class AccessToken(EndpointHelper): + def post_parse_request(self, request, client_id="", **kwargs): + """ + This is where clients come to get their access tokens + + :param request: The request + :param authn: Authentication info, comes from HTTP header + :returns: + """ + + request = AccessTokenRequest(**request.to_dict()) + + if "state" in request: + try: + sinfo = self.endpoint.endpoint_context.sdb[request["code"]] + except KeyError: + logger.error("Code not present in SessionDB") + return self.endpoint.error_cls(error="unauthorized_client") + else: + state = sinfo["authn_req"]["state"] + + if state != request["state"]: + logger.error("State value mismatch") + return self.endpoint.error_cls(error="unauthorized_client") + + if "client_id" not in request: # Optional for access token request + request["client_id"] = client_id + + logger.debug("%s: %s" % (request.__class__.__name__, sanitize(request))) - def _access_token(self, req, **kwargs): - _context = self.endpoint_context + return request + + def process_request(self, req, **kwargs): + _context = self.endpoint.endpoint_context _sdb = _context.sdb _log_debug = logger.debug try: _access_code = req["code"].replace(" ", "+") except KeyError: # Missing code parameter - absolutely fatal - return self.error_cls( + return self.endpoint.error_cls( error="invalid_request", error_description="Missing code" ) @@ -73,7 +94,7 @@ def _access_token(self, req, **kwargs): try: _info = _sdb[_access_code] except KeyError: - return self.error_cls( + return self.endpoint.error_cls( error="invalid_grant", error_description="Code is invalid" ) @@ -81,7 +102,7 @@ def _access_token(self, req, **kwargs): # assert that the code is valid if _context.sdb.is_session_revoked(_access_code): - return self.error_cls( + return self.endpoint.error_cls( error="invalid_grant", error_description="Session is revoked" ) @@ -89,7 +110,7 @@ def _access_token(self, req, **kwargs): # verify that the one given here is the correct one. if "redirect_uri" in _authn_req: if req["redirect_uri"] != _authn_req["redirect_uri"]: - return self.error_cls( + return self.endpoint.error_cls( error="invalid_request", error_description="redirect_uri mismatch" ) @@ -111,7 +132,7 @@ def _access_token(self, req, **kwargs): logger.error("%s" % err) # Should revoke the token issued to this access code _sdb.revoke_all_tokens(_access_code) - return self.error_cls( + return self.endpoint.error_cls( error="access_denied", error_description="Access Code already used" ) @@ -131,82 +152,103 @@ def _access_token(self, req, **kwargs): return by_schema(AccessTokenResponse, **_info) - def get_client_id_from_token(self, endpoint_context, token, request=None): - sinfo = endpoint_context.sdb[token] - return sinfo["authn_req"]["client_id"] - def _access_token_post_parse_request(self, request, client_id="", **kwargs): +class RefreshToken(EndpointHelper): + def post_parse_request(self, request, client_id="", **kwargs): """ - This is where clients come to get their access tokens + This is where clients come to refresh their access tokens :param request: The request :param authn: Authentication info, comes from HTTP header :returns: """ + request = RefreshAccessTokenRequest(**request.to_dict()) - request = AccessTokenRequest(**request.to_dict()) - - if "state" in request: - try: - sinfo = self.endpoint_context.sdb[request["code"]] - except KeyError: - logger.error("Code not present in SessionDB") - return self.error_cls(error="unauthorized_client") - else: - state = sinfo["authn_req"]["state"] - - if state != request["state"]: - logger.error("State value mismatch") - return self.error_cls(error="unauthorized_client") + # verify that the request message is correct + try: + request.verify( + keyjar=self.endpoint.endpoint_context.keyjar, opponent_id=client_id + ) + except (MissingRequiredAttribute, ValueError, MissingRequiredValue) as err: + return self.endpoint.error_cls( + error="invalid_request", error_description="%s" % err + ) - if "client_id" not in request: # Optional for access token request + if "client_id" not in request: # Optional for refresh access token request request["client_id"] = client_id logger.debug("%s: %s" % (request.__class__.__name__, sanitize(request))) return request - def _refresh_token_post_parse_request(self, request, client_id="", **kwargs): - """ - This is where clients come to refresh their access tokens + def process_request(self, req, **kwargs): + _sdb = self.endpoint.endpoint_context.sdb - :param request: The request - :param authn: Authentication info, comes from HTTP header - :returns: - """ + rtoken = req["refresh_token"] + try: + _info = _sdb.refresh_token( + rtoken, new_refresh=self.endpoint.new_refresh_token + ) + except ExpiredToken: + return self.error_cls( + error="invalid_request", error_description="Refresh token is expired" + ) - request = RefreshAccessTokenRequest(**request.to_dict()) + return by_schema(AccessTokenResponse, **_info) - # verify that the request message is correct - try: - request.verify(keyjar=self.endpoint_context.keyjar) - except (MissingRequiredAttribute, ValueError, MissingRequiredValue) as err: - return self.error_cls(error="invalid_request", error_description="%s" % err) - try: - keyjar = self.endpoint_context.keyjar - except AttributeError: - keyjar = "" +HELPER_BY_GRANT_TYPE = { + "authorization_code": AccessToken, + "refresh_token": RefreshToken, +} - request.verify(keyjar=keyjar, opponent_id=client_id) - if "client_id" not in request: # Optional for refresh access token request - request["client_id"] = client_id +class TokenCoop(Endpoint): + request_cls = oidc.Message + response_cls = oidc.AccessTokenResponse + error_cls = TokenErrorResponse + request_format = "json" + request_placement = "body" + response_format = "json" + response_placement = "body" + endpoint_name = "token_endpoint" + name = "token" + default_capabilities = {"token_endpoint_auth_signing_alg_values_supported": None} - logger.debug("%s: %s" % (request.__class__.__name__, sanitize(request))) + def __init__(self, endpoint_context, new_refresh_token=False, **kwargs): + Endpoint.__init__(self, endpoint_context, **kwargs) + self.post_parse_request.append(self._post_parse_request) + if "client_authn_method" in kwargs: + self.endpoint_info["token_endpoint_auth_methods_supported"] = kwargs[ + "client_authn_method" + ] - return request + # self.allow_refresh = False + self.new_refresh_token = new_refresh_token + + if "grant_types_support" in kwargs: + _supported = kwargs["grant_types_support"] + self.helper = {} + for key, spec in _supported.items(): + _conf = spec.get("kwargs", {}) + if isinstance(spec["class"], str): + _instance = importer(spec["class"])(self, _conf) + else: + _instance = spec["class"](self, _conf) + self.helper[key] = _instance + else: + self.helper = {k: v(self) for k, v in HELPER_BY_GRANT_TYPE.items()} + + def get_client_id_from_token(self, endpoint_context, token, request=None): + sinfo = endpoint_context.sdb[token] + return sinfo["authn_req"]["client_id"] def _post_parse_request(self, request, client_id="", **kwargs): - if request["grant_type"] == "authorization_code": - return self._access_token_post_parse_request(request, client_id, **kwargs) - else: # request["grant_type"] == "refresh_token": - if self.allow_refresh: - return self._refresh_token_post_parse_request( - request, client_id, **kwargs - ) - else: - raise ProcessError("Refresh Token not allowed") + _helper = self.helper.get(request["grant_type"]) + if _helper: + return _helper.post_parse_request(request, client_id, **kwargs) + else: + raise ProcessError("No support for grant_type: {}", request["grant_type"]) def process_request(self, request=None, **kwargs): """ @@ -218,15 +260,15 @@ def process_request(self, request=None, **kwargs): if isinstance(request, self.error_cls): return request try: - if request["grant_type"] == "authorization_code": - logger.debug("Access Token Request") - response_args = self._access_token(request, **kwargs) - elif request["grant_type"] == "refresh_token": - logger.debug("Refresh Access Token Request") - response_args = self._refresh_access_token(request, **kwargs) + _helper = self.helper.get(request["grant_type"]) + if _helper: + response_args = _helper.process_request(request, **kwargs) else: return self.error_cls( - error="invalid_request", error_description="Wrong grant_type" + error="invalid_request", + error_description="Unsupported grant_type {}".format( + request["grant_type"] + ), ) except JWEException as err: return self.error_cls(error="invalid_request", error_description="%s" % err) @@ -234,11 +276,6 @@ def process_request(self, request=None, **kwargs): if isinstance(response_args, ResponseMessage): return response_args - if request["grant_type"] == "authorization_code": - _token = request["code"].replace(" ", "+") - else: - _token = request["refresh_token"].replace(" ", "+") - _access_token = response_args["access_token"] _cookie = new_cookie( self.endpoint_context, diff --git a/tests/test_35_oidc_token_coop_endpoint.py b/tests/test_35_oidc_token_coop_endpoint.py index 5bf50a5..5ca135e 100755 --- a/tests/test_35_oidc_token_coop_endpoint.py +++ b/tests/test_35_oidc_token_coop_endpoint.py @@ -18,6 +18,7 @@ from oidcendpoint.oidc.authorization import Authorization from oidcendpoint.oidc.provider_config import ProviderConfiguration from oidcendpoint.oidc.registration import Registration +from oidcendpoint.oidc.token_coop import AccessToken from oidcendpoint.oidc.token_coop import TokenCoop from oidcendpoint.session import setup_session from oidcendpoint.user_authn.authn_context import INTERNETPROTOCOLPASSWORD @@ -363,7 +364,7 @@ def test_do_refresh_access_token_not_allowed(self): _req = self.endpoint.parse_request(_token_request) _resp = self.endpoint.process_request(request=_req) - self.endpoint.allow_refresh = False + self.endpoint.helper = {"authorization_code": AccessToken} _request = REFRESH_TOKEN_REQ.copy() _request["refresh_token"] = _resp["response_args"]["refresh_token"] From 68b3b63894f28203e35c9d383383c04a76939a64 Mon Sep 17 00:00:00 2001 From: Antonis Angelakis Date: Mon, 31 Aug 2020 19:14:15 +0300 Subject: [PATCH 2/7] Change how supported grant types are configured The user can now set the usual grant types to have a default configuration and class, without the need to set the oidcendpoint class path. This can be done by setting a value of "default" or None (no value in yaml dict). --- src/oidcendpoint/oidc/token_coop.py | 37 ++++++++++++++++++++--------- 1 file changed, 26 insertions(+), 11 deletions(-) diff --git a/src/oidcendpoint/oidc/token_coop.py b/src/oidcendpoint/oidc/token_coop.py index 2e54e4e..fbcb81b 100755 --- a/src/oidcendpoint/oidc/token_coop.py +++ b/src/oidcendpoint/oidc/token_coop.py @@ -226,18 +226,33 @@ def __init__(self, endpoint_context, new_refresh_token=False, **kwargs): # self.allow_refresh = False self.new_refresh_token = new_refresh_token - if "grant_types_support" in kwargs: - _supported = kwargs["grant_types_support"] - self.helper = {} - for key, spec in _supported.items(): - _conf = spec.get("kwargs", {}) - if isinstance(spec["class"], str): - _instance = importer(spec["class"])(self, _conf) - else: - _instance = spec["class"](self, _conf) - self.helper[key] = _instance - else: + self.configure_grant_types(kwargs.get("grant_types_supported")) + + def configure_grant_types(self, grant_types_supported): + if grant_types_supported is None: self.helper = {k: v(self) for k, v in HELPER_BY_GRANT_TYPE.items()} + return + + self.helper = {} + # TODO: do we want to allow any grant_type? + for grant_type, grant_type_options in grant_types_supported.items(): + if ( + grant_type_options in ("default", None) + and grant_type in HELPER_BY_GRANT_TYPE + ): + self.helper[grant_type] = HELPER_BY_GRANT_TYPE[grant_type] + continue + _conf = grant_type_options.get('kwargs', {}) + try: + if isinstance(grant_type_options["class"], str): + grant_class = importer(grant_type_options["class"]) + else: + grant_class = grant_type_options["class"] + self.helper[grant_type] = grant_class(self, _conf) + except KeyError: + raise ProcessError( + "Grant type is invalid or missing a valid class to import." + ) def get_client_id_from_token(self, endpoint_context, token, request=None): sinfo = endpoint_context.sdb[token] From a8ab2d8524c05ae1e7508ca3d7239aa2b671a545 Mon Sep 17 00:00:00 2001 From: Antonis Angelakis Date: Mon, 14 Sep 2020 19:05:28 +0300 Subject: [PATCH 3/7] Make multiple fixes --- src/oidcendpoint/oidc/token_coop.py | 9 +++++---- tests/test_35_oidc_token_coop_endpoint.py | 8 +++++--- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/src/oidcendpoint/oidc/token_coop.py b/src/oidcendpoint/oidc/token_coop.py index fbcb81b..299c0e5 100755 --- a/src/oidcendpoint/oidc/token_coop.py +++ b/src/oidcendpoint/oidc/token_coop.py @@ -263,7 +263,10 @@ def _post_parse_request(self, request, client_id="", **kwargs): if _helper: return _helper.post_parse_request(request, client_id, **kwargs) else: - raise ProcessError("No support for grant_type: {}", request["grant_type"]) + return self.error_cls( + error="invalid_request", + error_description=f"Unsupported grant_type: {request['grant_type']}" + ) def process_request(self, request=None, **kwargs): """ @@ -281,9 +284,7 @@ def process_request(self, request=None, **kwargs): else: return self.error_cls( error="invalid_request", - error_description="Unsupported grant_type {}".format( - request["grant_type"] - ), + error_description=f"Unsupported grant_type: {request['grant_type']}" ) except JWEException as err: return self.error_cls(error="invalid_request", error_description="%s" % err) diff --git a/tests/test_35_oidc_token_coop_endpoint.py b/tests/test_35_oidc_token_coop_endpoint.py index 5ca135e..ad9f201 100755 --- a/tests/test_35_oidc_token_coop_endpoint.py +++ b/tests/test_35_oidc_token_coop_endpoint.py @@ -12,7 +12,6 @@ from oidcendpoint.client_authn import verify_client from oidcendpoint.endpoint_context import EndpointContext from oidcendpoint.exception import MultipleCodeUsage -from oidcendpoint.exception import ProcessError from oidcendpoint.exception import UnAuthorizedClient from oidcendpoint.oidc import userinfo from oidcendpoint.oidc.authorization import Authorization @@ -368,5 +367,8 @@ def test_do_refresh_access_token_not_allowed(self): _request = REFRESH_TOKEN_REQ.copy() _request["refresh_token"] = _resp["response_args"]["refresh_token"] - with pytest.raises(ProcessError): - self.endpoint.parse_request(_request.to_json()) + _resp = self.endpoint.parse_request(_request.to_json()) + assert "error" in _resp + assert "error_description" in _resp + assert _resp["error"] == "invalid_request" + assert _resp["error_description"] == "Unsupported grant_type: refresh_token" From 9f8c94ebbf90ce0a7e2696d9c9de0c6eba9c194d Mon Sep 17 00:00:00 2001 From: Antonis Angelakis Date: Tue, 15 Sep 2020 15:44:08 +0300 Subject: [PATCH 4/7] Add tests for configurable grant types --- tests/test_35_oidc_token_coop_endpoint.py | 68 ++++++++++++++++++++--- 1 file changed, 59 insertions(+), 9 deletions(-) diff --git a/tests/test_35_oidc_token_coop_endpoint.py b/tests/test_35_oidc_token_coop_endpoint.py index ad9f201..85efadd 100755 --- a/tests/test_35_oidc_token_coop_endpoint.py +++ b/tests/test_35_oidc_token_coop_endpoint.py @@ -17,7 +17,7 @@ from oidcendpoint.oidc.authorization import Authorization from oidcendpoint.oidc.provider_config import ProviderConfiguration from oidcendpoint.oidc.registration import Registration -from oidcendpoint.oidc.token_coop import AccessToken +from oidcendpoint.oidc.token_coop import AccessToken, RefreshToken from oidcendpoint.oidc.token_coop import TokenCoop from oidcendpoint.session import setup_session from oidcendpoint.user_authn.authn_context import INTERNETPROTOCOLPASSWORD @@ -43,12 +43,6 @@ CAPABILITIES = { "subject_types_supported": ["public", "pairwise"], - "grant_types_supported": [ - "authorization_code", - "implicit", - "urn:ietf:params:oauth:grant-type:jwt-bearer", - "refresh_token", - ], } AUTH_REQ = AuthorizationRequest( @@ -119,7 +113,7 @@ def conf(): "client_secret_post", "client_secret_jwt", "private_key_jwt", - ] + ], }, }, "userinfo": { @@ -158,6 +152,62 @@ def create_endpoint(self, conf): def test_init(self): assert self.endpoint + @pytest.mark.parametrize("grant_types_supported", [ + { + "authorization_code": { + "class": AccessToken + }, + "refresh_token": { + "class": RefreshToken, + }, + }, + {}, # Empty dict should work with the default grant types + { + "authorization_code": { + "class": "oidcendpoint.oidc.token_coop.AccessToken" + }, + }, + { + "authorization_code": { + "class": "oidcendpoint.oidc.token_coop.AccessToken", + "kwargs": {}, + }, + }, + { + "authorization_code": "default", + "refresh_token": None, # This represents a key w/o value in the YAML conf + }, + ]) + def test_init_with_grant_types_supported(self, conf, grant_types_supported): + token_conf = conf["endpoint"]["token"] + token_conf["kwargs"]["grant_types_supported"] = grant_types_supported + endpoint_context = EndpointContext(conf) + assert endpoint_context + + @pytest.mark.parametrize("grant_types_supported", [ + { + "authorization_code": AccessToken, + "refresh_token": { + "class": RefreshToken, + }, + }, + { + "authorization_code": { + "class": "oidcendpoint.UnknownModule" + }, + }, + { + "authorization_code": { + "kwargs": {}, + }, + }, + ]) + def test_errors_in_grant_types_supported(self, conf, grant_types_supported): + token_conf = conf["endpoint"]["token"] + token_conf["kwargs"]["grant_types_supported"] = grant_types_supported + with pytest.raises(Exception): + EndpointContext(conf) + def test_parse(self): session_id = setup_session(self.endpoint.endpoint_context, AUTH_REQ, uid="user") _token_request = TOKEN_REQ_DICT.copy() @@ -241,7 +291,7 @@ def test_process_request_using_private_key_jwt(self): _context.sdb.update(session_id, user="diana") _req = self.endpoint.parse_request(_token_request) - _resp = self.endpoint.process_request(request=_req) + self.endpoint.process_request(request=_req) # 2nd time used with pytest.raises(UnAuthorizedClient): From da4e5d83326a5f4594eac8795e8d185e01dc2be3 Mon Sep 17 00:00:00 2001 From: Antonis Angelakis Date: Wed, 7 Oct 2020 17:20:46 +0300 Subject: [PATCH 5/7] Clean up things in token_coop and minor changes Removed "default" value for grant types and added True instead. So now None or True enables the grant_type with the default class, while False disables it and of course a dict with a custom class uses that instead. --- src/oidcendpoint/oidc/token_coop.py | 78 +++++++++++++---------- tests/test_35_oidc_token_coop_endpoint.py | 2 +- 2 files changed, 47 insertions(+), 33 deletions(-) diff --git a/src/oidcendpoint/oidc/token_coop.py b/src/oidcendpoint/oidc/token_coop.py index 299c0e5..c9dd778 100755 --- a/src/oidcendpoint/oidc/token_coop.py +++ b/src/oidcendpoint/oidc/token_coop.py @@ -1,5 +1,4 @@ import logging -from urllib.parse import urlparse from cryptojwt.jwe.exception import JWEException from cryptojwt.jws.exception import NoSuitableSigningKeys @@ -24,12 +23,7 @@ logger = logging.getLogger(__name__) -def aud_and_scope(url): - p = urlparse(url) - return "{}://{}".format(p.scheme, p.netloc), p.path - - -class EndpointHelper: +class TokenEndpointHelper: def __init__(self, endpoint, config=None): self.endpoint = endpoint self.config = config @@ -39,14 +33,14 @@ def post_parse_request(self, request, client_id="", **kwargs): This is done after general request parsing and before processing the request. """ - raise NotImplemented() + raise NotImplementedError def process_request(self, req, **kwargs): """Acts on a process request.""" - raise NotImplemented() + raise NotImplementedError -class AccessToken(EndpointHelper): +class AccessToken(TokenEndpointHelper): def post_parse_request(self, request, client_id="", **kwargs): """ This is where clients come to get their access tokens @@ -58,18 +52,9 @@ def post_parse_request(self, request, client_id="", **kwargs): request = AccessTokenRequest(**request.to_dict()) - if "state" in request: - try: - sinfo = self.endpoint.endpoint_context.sdb[request["code"]] - except KeyError: - logger.error("Code not present in SessionDB") - return self.endpoint.error_cls(error="unauthorized_client") - else: - state = sinfo["authn_req"]["state"] - - if state != request["state"]: - logger.error("State value mismatch") - return self.endpoint.error_cls(error="unauthorized_client") + error_cls = self._validate_state(request) + if error_cls is not None: + return error_cls if "client_id" not in request: # Optional for access token request request["client_id"] = client_id @@ -78,6 +63,22 @@ def post_parse_request(self, request, client_id="", **kwargs): return request + def _validate_state(self, request): + if "state" not in request: + return + + try: + sinfo = self.endpoint.endpoint_context.sdb[request["code"]] + except KeyError: + logger.error("Code not present in SessionDB") + return self.endpoint.error_cls(error="unauthorized_client") + else: + state = sinfo["authn_req"]["state"] + + if state != request["state"]: + logger.error("State value mismatch") + return self.endpoint.error_cls(error="unauthorized_client") + def process_request(self, req, **kwargs): _context = self.endpoint.endpoint_context _sdb = _context.sdb @@ -153,7 +154,7 @@ def process_request(self, req, **kwargs): return by_schema(AccessTokenResponse, **_info) -class RefreshToken(EndpointHelper): +class RefreshToken(TokenEndpointHelper): def post_parse_request(self, request, client_id="", **kwargs): """ This is where clients come to refresh their access tokens @@ -237,22 +238,35 @@ def configure_grant_types(self, grant_types_supported): # TODO: do we want to allow any grant_type? for grant_type, grant_type_options in grant_types_supported.items(): if ( - grant_type_options in ("default", None) - and grant_type in HELPER_BY_GRANT_TYPE + grant_type_options in (None, True) + and grant_type in HELPER_BY_GRANT_TYPE ): self.helper[grant_type] = HELPER_BY_GRANT_TYPE[grant_type] continue - _conf = grant_type_options.get('kwargs', {}) + elif grant_type_options is False: + continue + try: - if isinstance(grant_type_options["class"], str): - grant_class = importer(grant_type_options["class"]) - else: - grant_class = grant_type_options["class"] - self.helper[grant_type] = grant_class(self, _conf) + grant_class = grant_type_options["class"] except KeyError: raise ProcessError( - "Grant type is invalid or missing a valid class to import." + "Token Endpoint's grant types must be True, None or a dict with a" + " 'class' key." ) + _conf = grant_type_options.get("kwargs", {}) + + if isinstance(grant_class, str): + try: + grant_class = importer(grant_class) + except ValueError: + raise ProcessError( + f"Token Endpoint's grant type class {grant_class} can't" + " be imported." + ) + try: + self.helper[grant_type] = grant_class(self, _conf) + except Exception as e: + raise ProcessError(f"Failed to initialize class {grant_class}: {e}") def get_client_id_from_token(self, endpoint_context, token, request=None): sinfo = endpoint_context.sdb[token] diff --git a/tests/test_35_oidc_token_coop_endpoint.py b/tests/test_35_oidc_token_coop_endpoint.py index 85efadd..db5c675 100755 --- a/tests/test_35_oidc_token_coop_endpoint.py +++ b/tests/test_35_oidc_token_coop_endpoint.py @@ -174,7 +174,7 @@ def test_init(self): }, }, { - "authorization_code": "default", + "authorization_code": True, # Both True and None end up using the defaults "refresh_token": None, # This represents a key w/o value in the YAML conf }, ]) From 02a186862d14e3e751682b20bd19ed86b3303d28 Mon Sep 17 00:00:00 2001 From: Antonis Angelakis Date: Wed, 11 Nov 2020 14:49:19 +0200 Subject: [PATCH 6/7] Improve exception handling in TokenCoop --- src/oidcendpoint/oidc/token_coop.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/oidcendpoint/oidc/token_coop.py b/src/oidcendpoint/oidc/token_coop.py index c9dd778..8e5d73b 100755 --- a/src/oidcendpoint/oidc/token_coop.py +++ b/src/oidcendpoint/oidc/token_coop.py @@ -248,7 +248,7 @@ def configure_grant_types(self, grant_types_supported): try: grant_class = grant_type_options["class"] - except KeyError: + except (KeyError, TypeError): raise ProcessError( "Token Endpoint's grant types must be True, None or a dict with a" " 'class' key." @@ -258,7 +258,7 @@ def configure_grant_types(self, grant_types_supported): if isinstance(grant_class, str): try: grant_class = importer(grant_class) - except ValueError: + except (ValueError, AttributeError): raise ProcessError( f"Token Endpoint's grant type class {grant_class} can't" " be imported." From ec006ad370ebca7f28fa48d11c0b3e34e2cbcd29 Mon Sep 17 00:00:00 2001 From: Antonis Angelakis Date: Wed, 11 Nov 2020 14:49:48 +0200 Subject: [PATCH 7/7] Improve TokenCoop tests and add more --- tests/test_35_oidc_token_coop_endpoint.py | 51 ++++++++++++++++++++++- 1 file changed, 50 insertions(+), 1 deletion(-) diff --git a/tests/test_35_oidc_token_coop_endpoint.py b/tests/test_35_oidc_token_coop_endpoint.py index db5c675..3bda198 100755 --- a/tests/test_35_oidc_token_coop_endpoint.py +++ b/tests/test_35_oidc_token_coop_endpoint.py @@ -7,6 +7,7 @@ from oidcmsg.oidc import AccessTokenRequest from oidcmsg.oidc import AuthorizationRequest from oidcmsg.oidc import RefreshAccessTokenRequest +from oidcmsg.oidc import ResponseMessage from oidcendpoint import JWT_BEARER from oidcendpoint.client_authn import verify_client @@ -205,8 +206,10 @@ def test_init_with_grant_types_supported(self, conf, grant_types_supported): def test_errors_in_grant_types_supported(self, conf, grant_types_supported): token_conf = conf["endpoint"]["token"] token_conf["kwargs"]["grant_types_supported"] = grant_types_supported - with pytest.raises(Exception): + with pytest.raises(Exception) as exception_info: EndpointContext(conf) + assert exception_info.typename == "ProcessError" + assert "Token Endpoint" in str(exception_info.value) def test_parse(self): session_id = setup_session(self.endpoint.endpoint_context, AUTH_REQ, uid="user") @@ -422,3 +425,49 @@ def test_do_refresh_access_token_not_allowed(self): assert "error_description" in _resp assert _resp["error"] == "invalid_request" assert _resp["error_description"] == "Unsupported grant_type: refresh_token" + + def test_custom_grant_class(self, conf): + """ + Register a custom grant type supported and see if it works as it should. + """ + class CustomGrant: + def __init__(self, endpoint, config=None): + self.endpoint = endpoint + + def post_parse_request(self, request, client_id="", **kwargs): + request.testvalue = "test" + return request + + def process_request(self, request, **kwargs): + """ + All grant types should return a ResponseMessage class or inherit it. + """ + return ResponseMessage(test="successful") + + token_conf = conf["endpoint"]["token"] + token_conf["kwargs"]["grant_types_supported"] = { + "authorization_code": True, + "test_grant": { + "class": CustomGrant + } + } + endpoint_context = EndpointContext(conf) + token_endpoint = endpoint_context.endpoint["token"] + token_endpoint.client_authn_method = [None] + endpoint_context.cdb["client_1"] = { + "client_secret": "hemligt", + "redirect_uris": [("https://example.com/cb", None)], + "client_salt": "salted", + "endpoint_auth_method": "client_secret_post", + "response_types": ["code", "token", "code id_token", "id_token"], + } + endpoint_context.keyjar.import_jwks(CLIENT_KEYJAR.export_jwks(), "client_1") + + request = dict(grant_type="test_grant", client_id="client_1") + + parsed_request = token_endpoint.parse_request(request) + assert parsed_request.testvalue == "test" + + response = token_endpoint.process_request(parsed_request) + assert "test" in response + assert response["test"] == "successful"