diff --git a/src/oidcop/authz/__init__.py b/src/oidcop/authz/__init__.py index 39df89af..04203cff 100755 --- a/src/oidcop/authz/__init__.py +++ b/src/oidcop/authz/__init__.py @@ -30,9 +30,7 @@ def usage_rules(self, client_id: Optional[str] = ""): return _usage_rules try: - _per_client = self.server_get("endpoint_context").cdb[client_id][ - "token_usage_rules" - ] + _per_client = self.server_get("endpoint_context").cdb[client_id]["token_usage_rules"] except KeyError: pass else: @@ -59,14 +57,11 @@ def usage_rules_for(self, client_id, token_type): return {} def __call__( - self, - session_id: str, - request: Union[dict, Message], - resources: Optional[list] = None, + self, session_id: str, request: Union[dict, Message], resources: Optional[list] = None, ) -> Grant: - session_info = self.server_get( - "endpoint_context" - ).session_manager.get_session_info(session_id=session_id, grant=True) + session_info = self.server_get("endpoint_context").session_manager.get_session_info( + session_id=session_id, grant=True + ) grant = session_info["grant"] args = self.grant_config.copy() @@ -87,24 +82,19 @@ def __call__( # After this is where user consent should be handled scopes = request.get("scope", []) grant.scope = scopes - grant.claims = self.server_get( - "endpoint_context" - ).claims_interface.get_claims_all_usage(session_id=session_id, scopes=scopes) + grant.claims = self.server_get("endpoint_context").claims_interface.get_claims_all_usage( + session_id=session_id, scopes=scopes + ) return grant class Implicit(AuthzHandling): def __call__( - self, - session_id: str, - request: Union[dict, Message], - resources: Optional[list] = None, + self, session_id: str, request: Union[dict, Message], resources: Optional[list] = None, ) -> Grant: args = self.grant_config.copy() - grant = self.server_get("endpoint_context").session_manager.get_grant( - session_id=session_id - ) + grant = self.server_get("endpoint_context").session_manager.get_grant(session_id=session_id) for arg, val in args: setattr(grant, arg, val) return grant diff --git a/src/oidcop/client_authn.py b/src/oidcop/client_authn.py index 57086be9..be37ffa5 100755 --- a/src/oidcop/client_authn.py +++ b/src/oidcop/client_authn.py @@ -131,9 +131,7 @@ def is_usable(self, request=None, authorization_token=None): def verify(self, request, **kwargs): if ( - self.server_get("endpoint_context").cdb[request["client_id"]][ - "client_secret" - ] + self.server_get("endpoint_context").cdb[request["client_id"]]["client_secret"] == request["client_secret"] ): return {"client_id": request["client_id"]} @@ -148,9 +146,7 @@ class BearerHeader(ClientSecretBasic): tag = "bearer_header" def is_usable(self, request=None, authorization_token=None): - if authorization_token is not None and authorization_token.startswith( - "Bearer " - ): + if authorization_token is not None and authorization_token.startswith("Bearer "): return True return False @@ -203,9 +199,7 @@ def verify(self, request, key_type, **kwargs): if _sign_alg and _sign_alg.startswith("HS"): if key_type == "private_key": raise AttributeError("Wrong key type") - keys = _context.keyjar.get( - "sig", "oct", ca_jwt["iss"], ca_jwt.jws_header.get("kid") - ) + keys = _context.keyjar.get("sig", "oct", ca_jwt["iss"], ca_jwt.jws_header.get("kid")) _secret = _context.cdb[ca_jwt["iss"]].get("client_secret") if _secret and keys[0].key != as_bytes(_secret): raise AttributeError("Oct key used for signing not client_secret") @@ -366,14 +360,10 @@ def verify_client( if _method.is_usable(request, authorization_token): try: auth_info = _method.verify( - request=request, - authorization_token=authorization_token, - endpoint=endpoint, + request=request, authorization_token=authorization_token, endpoint=endpoint, ) except Exception as err: - logger.warning( - "Verifying auth using {} failed: {}".format(_method.tag, err) - ) + logger.warning("Verifying auth using {} failed: {}".format(_method.tag, err)) else: if "method" not in auth_info: auth_info["method"] = _method.tag @@ -403,9 +393,7 @@ def verify_client( raise UnknownClient("Unknown Client ID") if not valid_client_info(_cinfo): - logger.warning( - "Client registration has timed out or " "client secret is expired." - ) + logger.warning("Client registration has timed out or " "client secret is expired.") raise InvalidClient("Not valid client") # store what authn method was used @@ -413,9 +401,7 @@ def verify_client( _request_type = request.__class__.__name__ _used_authn_method = endpoint_context.cdb[client_id].get("auth_method") if _used_authn_method: - endpoint_context.cdb[client_id]["auth_method"][ - _request_type - ] = auth_info["method"] + endpoint_context.cdb[client_id]["auth_method"][_request_type] = auth_info["method"] else: endpoint_context.cdb[client_id]["auth_method"] = { _request_type: auth_info["method"] @@ -427,9 +413,7 @@ def verify_client( try: # get_client_id_from_token is a callback... Do not abuse for code readability. - auth_info["client_id"] = get_client_id_from_token( - endpoint_context, _token, request - ) + auth_info["client_id"] = get_client_id_from_token(endpoint_context, _token, request) except KeyError: raise ValueError("Unknown token") diff --git a/src/oidcop/configure.py b/src/oidcop/configure.py index 0e9f67ef..9532a89b 100755 --- a/src/oidcop/configure.py +++ b/src/oidcop/configure.py @@ -23,7 +23,16 @@ "jwks_file" ] -DEFAULT_CONFIG = { +OP_DEFAULT_CONFIG = { + "capabilities": { + "subject_types_supported": ["public", "pairwise"], + "grant_types_supported": [ + "authorization_code", + "implicit", + "urn:ietf:params:oauth:grant-type:jwt-bearer", + "refresh_token", + ], + }, "cookie_handler": { "class": "oidcop.cookie_handler.CookieHandler", "kwargs": { @@ -35,30 +44,22 @@ ], "read_only": False, }, - "name": { - "session": "oidc_op", - "register": "oidc_op_rp", - "session_management": "sman", - }, + "name": {"session": "oidc_op", "register": "oidc_op_rp", + "session_management": "sman", }, }, }, + "claims_interface": {"class": "oidcop.session.claims.ClaimsInterface", "kwargs": {}}, "authz": { "class": "oidcop.authz.AuthzHandling", "kwargs": { "grant_config": { "usage_rules": { "authorization_code": { - "supports_minting": [ - "access_token", - "refresh_token", - "id_token", - ], + "supports_minting": ["access_token", "refresh_token", "id_token", ], "max_usage": 1, }, "access_token": {}, - "refresh_token": { - "supports_minting": ["access_token", "refresh_token"] - }, + "refresh_token": {"supports_minting": ["access_token", "refresh_token"]}, }, "expires_in": 43200, } @@ -66,27 +67,14 @@ }, "httpc_params": {"verify": False}, "issuer": "https://{domain}:{port}", - "session_key": { - "filename": "private/session_jwk.json", - "type": "OCT", - "use": "sig", - }, + "session_key": {"filename": "private/session_jwk.json", "type": "OCT", "use": "sig", }, "template_dir": "templates", "token_handler_args": { "jwks_file": "private/token_jwks.json", "code": {"kwargs": {"lifetime": 600}}, - "token": { - "class": "oidcop.token.jwt_token.JWTToken", - "kwargs": {"lifetime": 3600}, - }, - "refresh": { - "class": "oidcop.token.jwt_token.JWTToken", - "kwargs": {"lifetime": 86400}, - }, - "id_token": { - "class": "oidcop.token.id_token.IDToken", - "kwargs": {} - }, + "token": {"class": "oidcop.token.jwt_token.JWTToken", "kwargs": {"lifetime": 3600}, }, + "refresh": {"class": "oidcop.token.jwt_token.JWTToken", "kwargs": {"lifetime": 86400}, }, + "id_token": {"class": "oidcop.token.id_token.IDToken", "kwargs": {}}, }, } @@ -169,10 +157,7 @@ class Base: parameter = {} def __init__( - self, - conf: Dict, - base_path: str = "", - file_attributes: Optional[List[str]] = None, + self, conf: Dict, base_path: str = "", file_attributes: Optional[List[str]] = None, ): if file_attributes is None: file_attributes = DEFAULT_FILE_ATTRIBUTE_NAMES @@ -221,6 +206,7 @@ def __init__( self.authentication = None self.base_url = "" self.capabilities = None + self.claims_interface = None self.cookie_handler = None self.endpoint = {} self.httpc_params = {} @@ -241,8 +227,73 @@ def __init__( for key in self.__dict__.keys(): _val = conf.get(key) if not _val: - if key in DEFAULT_CONFIG: - _dc = copy.deepcopy(DEFAULT_CONFIG[key]) + if key in OP_DEFAULT_CONFIG: + _dc = copy.deepcopy(OP_DEFAULT_CONFIG[key]) + add_base_path(_dc, base_path, file_attributes) + _val = _dc + else: + continue + setattr(self, key, _val) + + if self.template_dir is None: + self.template_dir = os.path.abspath("templates") + else: + self.template_dir = os.path.abspath(self.template_dir) + + if not domain: + domain = conf.get("domain", "127.0.0.1") + + if not port: + port = conf.get("port", 80) + + set_domain_and_port(conf, URIS, domain=domain, port=port) + + +AS_DEFAULT_CONFIG = copy.deepcopy(OP_DEFAULT_CONFIG) +AS_DEFAULT_CONFIG["claims_interface"] = { + "class": "oidcop.session.claims.OAuth2ClaimsInterface", "kwargs": {}} + + +class ASConfiguration(Base): + "Authorization server configuration" + + def __init__( + self, + conf: Dict, + base_path: Optional[str] = "", + entity_conf: Optional[List[dict]] = None, + domain: Optional[str] = "", + port: Optional[int] = 0, + file_attributes: Optional[List[str]] = None, + ): + + conf = copy.deepcopy(conf) + Base.__init__(self, conf, base_path, file_attributes) + + self.add_on = None + self.authz = None + self.authentication = None + self.base_url = "" + self.capabilities = None + self.claims_interface = None + self.cookie_handler = None + self.endpoint = {} + self.httpc_params = {} + self.issuer = "" + self.keys = None + self.session_key = None + self.template_dir = None + self.token_handler_args = {} + self.userinfo = None + + if file_attributes is None: + file_attributes = DEFAULT_FILE_ATTRIBUTE_NAMES + + for key in self.__dict__.keys(): + _val = conf.get(key) + if not _val: + if key in AS_DEFAULT_CONFIG: + _dc = copy.deepcopy(AS_DEFAULT_CONFIG[key]) add_base_path(_dc, base_path, file_attributes) _val = _dc else: @@ -343,17 +394,11 @@ def __init__( "grant_config": { "usage_rules": { "authorization_code": { - "supports_minting": [ - "access_token", - "refresh_token", - "id_token", - ], + "supports_minting": ["access_token", "refresh_token", "id_token", ], "max_usage": 1, }, "access_token": {}, - "refresh_token": { - "supports_minting": ["access_token", "refresh_token"] - }, + "refresh_token": {"supports_minting": ["access_token", "refresh_token"]}, }, "expires_in": 43200, } @@ -366,10 +411,7 @@ def __init__( "kwargs": { "verify_endpoint": "verify/user", "template": "user_pass.jinja2", - "db": { - "class": "oidcop.util.JSONDictDB", - "kwargs": {"filename": "passwd.json"}, - }, + "db": {"class": "oidcop.util.JSONDictDB", "kwargs": {"filename": "passwd.json"}, }, "page_header": "Testing log in", "submit_btn": "Get me in!", "user_label": "Nickname", @@ -397,11 +439,8 @@ def __init__( ], "read_only": False, }, - "name": { - "session": "oidc_op", - "register": "oidc_op_rp", - "session_management": "sman", - }, + "name": {"session": "oidc_op", "register": "oidc_op_rp", + "session_management": "sman", }, }, }, "endpoint": { @@ -418,10 +457,7 @@ def __init__( "registration": { "path": "registration", "class": "oidcop.oidc.registration.Registration", - "kwargs": { - "client_authn_method": None, - "client_secret_expiration_time": 432000, - }, + "kwargs": {"client_authn_method": None, "client_secret_expiration_time": 432000, }, }, "registration_api": { "path": "registration_api", @@ -431,10 +467,7 @@ def __init__( "introspection": { "path": "introspection", "class": "oidcop.oauth2.introspection.Introspection", - "kwargs": { - "client_authn_method": ["client_secret_post"], - "release": ["username"], - }, + "kwargs": {"client_authn_method": ["client_secret_post"], "release": ["username"], }, }, "authorization": { "path": "authorization", @@ -472,9 +505,7 @@ def __init__( "userinfo": { "path": "userinfo", "class": "oidcop.oidc.userinfo.UserInfo", - "kwargs": { - "claim_types_supported": ["normal", "aggregated", "distributed"] - }, + "kwargs": {"claim_types_supported": ["normal", "aggregated", "distributed"]}, }, "end_session": { "path": "session", @@ -506,16 +537,10 @@ def __init__( "login_hint2acrs": { "class": "oidcop.login_hint.LoginHint2Acrs", "kwargs": { - "scheme_map": { - "email": ["oidcop.user_authn.authn_context.INTERNETPROTOCOLPASSWORD"] - } + "scheme_map": {"email": ["oidcop.user_authn.authn_context.INTERNETPROTOCOLPASSWORD"]} }, }, - "session_key": { - "filename": "private/session_jwk.json", - "type": "OCT", - "use": "sig", - }, + "session_key": {"filename": "private/session_jwk.json", "type": "OCT", "use": "sig", }, "template_dir": "templates", "token_handler_args": { "jwks_def": { @@ -546,8 +571,5 @@ def __init__( }, }, }, - "userinfo": { - "class": "oidcop.user_info.UserInfo", - "kwargs": {"db_file": "users.json"}, - }, + "userinfo": {"class": "oidcop.user_info.UserInfo", "kwargs": {"db_file": "users.json"}, }, } diff --git a/src/oidcop/construct.py b/src/oidcop/construct.py index e0374c59..3ef728f7 100644 --- a/src/oidcop/construct.py +++ b/src/oidcop/construct.py @@ -7,7 +7,7 @@ from cryptojwt.jws.jws import SIGNER_ALGS ALG_SORT_ORDER = {"RS": 0, "ES": 1, "HS": 2, "PS": 3, "no": 4} -WEAK_ALGS = ['RSA1_5', 'none'] +WEAK_ALGS = ["RSA1_5", "none"] logger = logging.getLogger(__name__) @@ -76,7 +76,7 @@ def construct_endpoint_info(default_capabilities, **kwargs): elif "encryption_enc_values_supported" in attr: _info[attr] = assign_algorithms("encryption_enc") - if re.match(r'.*(alg|enc).*_values_supported', attr): + if re.match(r".*(alg|enc).*_values_supported", attr): for i in _info[attr]: if i in WEAK_ALGS: logger.warning( diff --git a/src/oidcop/cookie_handler.py b/src/oidcop/cookie_handler.py index 1e533fa4..4524c94f 100755 --- a/src/oidcop/cookie_handler.py +++ b/src/oidcop/cookie_handler.py @@ -143,9 +143,7 @@ def _ver_dec_content(self, parts): mac = base64.b64decode(b64_mac) verifier = HMACSigner(algorithm=self.sign_alg) if verifier.verify( - payload.encode("utf-8") + timestamp.encode("utf-8"), - mac, - self.sign_key.key, + payload.encode("utf-8") + timestamp.encode("utf-8"), mac, self.sign_key.key, ): return payload, timestamp else: diff --git a/src/oidcop/endpoint.py b/src/oidcop/endpoint.py index 52d13491..89fb8191 100755 --- a/src/oidcop/endpoint.py +++ b/src/oidcop/endpoint.py @@ -115,17 +115,13 @@ def __init__(self, server_get: Callable, **kwargs): self.client_authn_method = [] if _methods: self.client_authn_method = client_auth_setup(_methods, server_get) - elif ( - _methods is not None - ): # [] or '' or something not None but regarded as nothing. + elif _methods is not None: # [] or '' or something not None but regarded as nothing. self.client_authn_method = [None] # Ignore default value elif self.default_capabilities: _methods = self.default_capabilities.get("client_authn_method") if _methods: self.client_authn_method = client_auth_setup(_methods, server_get) - self.endpoint_info = construct_endpoint_info( - self.default_capabilities, **kwargs - ) + self.endpoint_info = construct_endpoint_info(self.default_capabilities, **kwargs) # This is for matching against aud in JWTs # By default the endpoint's endpoint URL is an allowed target @@ -137,10 +133,7 @@ def parse_cookies(self, cookies: List[dict], context: EndpointContext, name: str return res def parse_request( - self, - request: Union[Message, dict, str], - http_info: Optional[dict] = None, - **kwargs + self, request: Union[Message, dict, str], http_info: Optional[dict] = None, **kwargs ): """ @@ -211,9 +204,7 @@ def get_client_id_from_token( ): return "" - def client_authentication( - self, request: Message, http_info: Optional[dict] = None, **kwargs - ): + def client_authentication(self, request: Message, http_info: Optional[dict] = None, **kwargs): """ Do client authentication @@ -234,11 +225,7 @@ def client_authentication( ) LOGGER.debug("authn_info: %s", authn_info) - if ( - authn_info == {} - and self.client_authn_method - and len(self.client_authn_method) - ): + if authn_info == {} and self.client_authn_method and len(self.client_authn_method): LOGGER.debug("client_authn_method: %s", self.client_authn_method) raise UnAuthorizedClient("Authorization failed") @@ -255,16 +242,11 @@ def do_post_parse_request( return request def do_pre_construct( - self, - response_args: dict, - request: Optional[Union[Message, dict]] = None, - **kwargs + self, response_args: dict, request: Optional[Union[Message, dict]] = None, **kwargs ) -> dict: _context = self.server_get("endpoint_context") for meth in self.pre_construct: - response_args = meth( - response_args, request, endpoint_context=_context, **kwargs - ) + response_args = meth(response_args, request, endpoint_context=_context, **kwargs) return response_args @@ -276,9 +258,7 @@ def do_post_construct( ) -> dict: _context = self.server_get("endpoint_context") for meth in self.post_construct: - response_args = meth( - response_args, request, endpoint_context=_context, **kwargs - ) + response_args = meth(response_args, request, endpoint_context=_context, **kwargs) return response_args diff --git a/src/oidcop/endpoint_context.py b/src/oidcop/endpoint_context.py index a00668cb..dab66128 100755 --- a/src/oidcop/endpoint_context.py +++ b/src/oidcop/endpoint_context.py @@ -195,9 +195,7 @@ def __init__( if _loader is None: _template_dir = conf.get("template_dir") if _template_dir: - _loader = Environment( - loader=FileSystemLoader(_template_dir), autoescape=True - ) + _loader = Environment(loader=FileSystemLoader(_template_dir), autoescape=True) if _loader: self.template_handler = Jinja2TemplateHandler(_loader) @@ -280,9 +278,7 @@ def do_userinfo(self): self.userinfo = init_user_info(_conf, self.cwd) self.session_manager.userinfo = self.userinfo else: - logger.warning( - "Cannot init_user_info if no session manager was provided." - ) + logger.warning("Cannot init_user_info if no session manager was provided.") def do_cookie_handler(self): _conf = self.conf.get("cookie_handler") @@ -328,9 +324,7 @@ def create_providerinfo(self, capabilities): _provider_info["jwks_uri"] = self.jwks_uri if "scopes_supported" not in _provider_info: - _provider_info["scopes_supported"] = [ - s for s in self.scope2claims.keys() - ] + _provider_info["scopes_supported"] = [s for s in self.scope2claims.keys()] if "claims_supported" not in _provider_info: _provider_info["claims_supported"] = STANDARD_CLAIMS[:] diff --git a/src/oidcop/exception.py b/src/oidcop/exception.py index c9aa652d..48a25093 100755 --- a/src/oidcop/exception.py +++ b/src/oidcop/exception.py @@ -96,3 +96,7 @@ class CapabilitiesMisMatch(OidcEndpointError): class MultipleCodeUsage(OidcEndpointError): pass + + +class InvalidToken(Exception): + pass diff --git a/src/oidcop/logging.py b/src/oidcop/logging.py index bd07fbd8..d04c577e 100755 --- a/src/oidcop/logging.py +++ b/src/oidcop/logging.py @@ -10,18 +10,14 @@ LOGGING_DEFAULT = { "version": 1, - "formatters": { - "default": {"format": "%(asctime)s %(name)s %(levelname)s %(message)s"} - }, + "formatters": {"default": {"format": "%(asctime)s %(name)s %(levelname)s %(message)s"}}, "handlers": {"default": {"class": "logging.StreamHandler", "formatter": "default"}}, "root": {"handlers": ["default"], "level": "INFO"}, } def configure_logging( - debug: Optional[bool] = False, - config: Optional[dict] = None, - filename: Optional[str] = "", + debug: Optional[bool] = False, config: Optional[dict] = None, filename: Optional[str] = "", ) -> logging.Logger: """Configure logging""" diff --git a/src/oidcop/oauth2/add_on/dpop_token.py b/src/oidcop/oauth2/add_on/dpop_token.py index 359bc25a..21ace934 100644 --- a/src/oidcop/oauth2/add_on/dpop_token.py +++ b/src/oidcop/oauth2/add_on/dpop_token.py @@ -26,16 +26,12 @@ def process_request(self, req: Union[Message, dict], **kwargs): _log_debug = logger.debug if req["grant_type"] != "authorization_code": - return self.error_cls( - error="invalid_request", error_description="Unknown grant_type" - ) + return self.error_cls(error="invalid_request", error_description="Unknown grant_type") try: _access_code = req["code"].replace(" ", "+") except KeyError: # Missing code parameter - absolutely fatal - return self.error_cls( - error="invalid_request", error_description="Missing code" - ) + return self.error_cls(error="invalid_request", error_description="Missing code") _session_info = _mngr.get_session_info_by_token(_access_code, grant=True) grant = _session_info["grant"] @@ -114,8 +110,7 @@ def process_request(self, req: Union[Message, dict], **kwargs): except (JWEException, NoSuitableSigningKeys) as err: logger.warning(str(err)) resp = self.error_cls( - error="invalid_request", - error_description="Could not sign/encrypt id_token", + error="invalid_request", error_description="Could not sign/encrypt id_token", ) return resp @@ -130,9 +125,7 @@ def process_request(self, req: Union[Message, dict], **kwargs): _mngr = _context.session_manager if req["grant_type"] != "refresh_token": - return self.error_cls( - error="invalid_request", error_description="Wrong grant_type" - ) + return self.error_cls(error="invalid_request", error_description="Wrong grant_type") token_value = req["refresh_token"] _session_info = _mngr.get_session_info_by_token(token_value, grant=True) @@ -186,8 +179,7 @@ def process_request(self, req: Union[Message, dict], **kwargs): except (JWEException, NoSuitableSigningKeys) as err: logger.warning(str(err)) resp = self.error_cls( - error="invalid_request", - error_description="Could not sign/encrypt id_token", + error="invalid_request", error_description="Could not sign/encrypt id_token", ) return resp diff --git a/src/oidcop/oauth2/authorization.py b/src/oidcop/oauth2/authorization.py index 138a6128..72552c0f 100755 --- a/src/oidcop/oauth2/authorization.py +++ b/src/oidcop/oauth2/authorization.py @@ -45,18 +45,9 @@ # For the time being. This is JAR specific and should probably be configurable. ALG_PARAMS = { - "sign": [ - "request_object_signing_alg", - "request_object_signing_alg_values_supported", - ], - "enc_alg": [ - "request_object_encryption_alg", - "request_object_encryption_alg_values_supported", - ], - "enc_enc": [ - "request_object_encryption_enc", - "request_object_encryption_enc_values_supported", - ], + "sign": ["request_object_signing_alg", "request_object_signing_alg_values_supported",], + "enc_alg": ["request_object_encryption_alg", "request_object_encryption_alg_values_supported",], + "enc_enc": ["request_object_encryption_enc", "request_object_encryption_enc_values_supported",], } FORM_POST = """ @@ -151,9 +142,7 @@ def verify_uri( for val in vals: if val not in _query[key]: - raise ValueError( - "{}={} value not in query part".format(key, val) - ) + raise ValueError("{}={} value not in query part".format(key, val)) # and vice versa, every query component in the uri # must be registered @@ -166,9 +155,7 @@ def verify_uri( raise ValueError('"{}" extra in query part'.format(key)) for val in vals: if val not in rquery[key]: - raise ValueError( - "Extra {}={} value in query part".format(key, val) - ) + raise ValueError("Extra {}={} value in query part".format(key, val)) match = True break if not match: @@ -210,9 +197,7 @@ def get_uri(endpoint_context, request, uri_type): raise ParameterError(f"Missing '{uri_type}' and none registered") if len(_specs) > 1: - raise ParameterError( - f"Missing '{uri_type}' and more than one registered" - ) + raise ParameterError(f"Missing '{uri_type}' and more than one registered") uri = join_query(*_specs[0]) else: @@ -222,10 +207,7 @@ def get_uri(endpoint_context, request, uri_type): def authn_args_gather( - request: Union[AuthorizationRequest, dict], - authn_class_ref: str, - cinfo: dict, - **kwargs, + request: Union[AuthorizationRequest, dict], authn_class_ref: str, cinfo: dict, **kwargs, ): """ Gather information to be used by the authentication method @@ -245,9 +227,7 @@ def authn_args_gather( else: raise ValueError("Wrong request format") - authn_args.update( - {"authn_class_ref": authn_class_ref, "return_uri": request["redirect_uri"]} - ) + authn_args.update({"authn_class_ref": authn_class_ref, "return_uri": request["redirect_uri"]}) if "req_user" in kwargs: authn_args["as_user"] = (kwargs["req_user"],) @@ -267,16 +247,11 @@ def authn_args_gather( def check_unknown_scopes_policy(request_info, cinfo, endpoint_context): op_capabilities = endpoint_context.conf["capabilities"] - client_allowed_scopes = ( - cinfo.get("allowed_scopes") or op_capabilities["scopes_supported"] - ) + client_allowed_scopes = cinfo.get("allowed_scopes") or op_capabilities["scopes_supported"] # this prevents that authz would be released for unavailable scopes for scope in request_info["scope"]: - if ( - op_capabilities.get("deny_unknown_scopes") - and scope not in client_allowed_scopes - ): + if op_capabilities.get("deny_unknown_scopes") and scope not in client_allowed_scopes: _msg = "{} requested an unauthorized scope ({})" logger.warning(_msg.format(cinfo["client_id"], scope)) raise UnAuthorizedClientScope() @@ -365,15 +340,8 @@ def _do_request_uri(self, request, client_id, endpoint_context, **kwargs): raise ValueError("Got a request_uri I can not resolve") # Do I support request_uri ? - if ( - endpoint_context.provider_info.get( - "request_uri_parameter_supported", True - ) - is False - ): - raise ServiceError( - "Someone is using request_uri which I'm not supporting" - ) + if endpoint_context.provider_info.get("request_uri_parameter_supported", True) is False: + raise ServiceError("Someone is using request_uri which I'm not supporting") _registered = endpoint_context.cdb[client_id].get("request_uris") # Not registered should be handled else where @@ -385,9 +353,7 @@ def _do_request_uri(self, request, client_id, endpoint_context, **kwargs): raise ValueError("A request_uri outside the registered") # Fetch the request - _resp = endpoint_context.httpc.get( - _request_uri, **endpoint_context.httpc_params - ) + _resp = endpoint_context.httpc.get(_request_uri, **endpoint_context.httpc_params) if _resp.status_code == 200: args = {"keyjar": endpoint_context.keyjar, "issuer": client_id} _ver_request = self.request_cls().from_jwt(_resp.text, **args) @@ -399,16 +365,10 @@ def _do_request_uri(self, request, client_id, endpoint_context, **kwargs): ) if _ver_request.jwe_header is not None: self.allowed_request_algorithms( - client_id, - endpoint_context, - _ver_request.jws_header.get("alg"), - "enc_alg", + client_id, endpoint_context, _ver_request.jws_header.get("alg"), "enc_alg", ) self.allowed_request_algorithms( - client_id, - endpoint_context, - _ver_request.jws_header.get("enc"), - "enc_enc", + client_id, endpoint_context, _ver_request.jws_header.get("enc"), "enc_enc", ) # The protected info overwrites the non-protected for k, v in _ver_request.items(): @@ -440,12 +400,8 @@ def _post_parse_request(self, request, client_id, endpoint_context, **kwargs): _cinfo = endpoint_context.cdb.get(client_id) if not _cinfo: - logger.error( - "Client ID ({}) not in client database".format(request["client_id"]) - ) - return self.error_cls( - error="unauthorized_client", error_description="unknown client" - ) + logger.error("Client ID ({}) not in client database".format(request["client_id"])) + return self.error_cls(error="unauthorized_client", error_description="unknown client") # Is the asked for response_type among those that are permitted if not self.verify_response_type(request, _cinfo): @@ -491,9 +447,7 @@ def pick_authn_method(self, request, redirect_uri, acr=None, **kwargs): def create_session(self, request, user_id, acr, time_stamp, authn_method): _context = self.server_get("endpoint_context") _mngr = _context.session_manager - authn_event = create_authn_event( - user_id, authn_info=acr, time_stamp=time_stamp, - ) + authn_event = create_authn_event(user_id, authn_info=acr, time_stamp=time_stamp,) _exp_in = authn_method.kwargs.get("expires_in") if _exp_in and "valid_until" in authn_event: authn_event["valid_until"] = utc_time_sans_frac() + _exp_in @@ -621,12 +575,8 @@ def setup_auth( if grant.is_active() is False: return {"function": authn, "args": authn_args} elif request != grant.authorization_request: - authn_event = _mngr.get_authentication_event( - session_id=_session_id - ) - if ( - authn_event.is_valid() is False - ): # if not valid, do new login + authn_event = _mngr.get_authentication_event(session_id=_session_id) + if authn_event.is_valid() is False: # if not valid, do new login return {"function": authn, "args": authn_args} # create new grant @@ -642,9 +592,7 @@ def setup_auth( if authn_event.is_valid() is False: # if not valid, do new login return {"function": authn, "args": authn_args} else: - _session_id = self.create_session( - request, identity["uid"], authn_class_ref, _ts, authn - ) + _session_id = self.create_session(request, identity["uid"], authn_class_ref, _ts, authn) return {"session_id": _session_id, "identity": identity, "user": user} @@ -667,11 +615,7 @@ def response_mode( _args = response_args msg = FORM_POST.format(inputs=inputs(_args), action=return_uri,) kwargs.update( - { - "response_msg": msg, - "content_type": "text/html", - "response_placement": "body", - } + {"response_msg": msg, "content_type": "text/html", "response_placement": "body",} ) elif resp_mode == "fragment": if fragment_enc is False: @@ -728,9 +672,7 @@ def create_authn_response(self, request: Union[dict, Message], sid: str) -> dict if "code" in request["response_type"]: _code = self.mint_token( - token_type="authorization_code", - grant=grant, - session_id=_sinfo["session_id"], + token_type="authorization_code", grant=grant, session_id=_sinfo["session_id"], ) aresp["code"] = _code.value handled_response_type.append("code") @@ -739,16 +681,12 @@ def create_authn_response(self, request: Union[dict, Message], sid: str) -> dict if "token" in rtype: _access_token = self.mint_token( - token_type="access_token", - grant=grant, - session_id=_sinfo["session_id"], + token_type="access_token", grant=grant, session_id=_sinfo["session_id"], ) aresp["access_token"] = _access_token.value aresp["token_type"] = "Bearer" if _access_token.expires_at: - aresp["expires_in"] = ( - _access_token.expires_at - utc_time_sans_frac() - ) + aresp["expires_in"] = _access_token.expires_at - utc_time_sans_frac() handled_response_type.append("token") else: _access_token = None @@ -784,8 +722,7 @@ def create_authn_response(self, request: Union[dict, Message], sid: str) -> dict not_handled = rtype.difference(handled_response_type) if not_handled: resp = self.error_cls( - error="invalid_request", - error_description="unsupported_response_type", + error="invalid_request", error_description="unsupported_response_type", ) return {"response_args": resp, "fragment_enc": fragment_enc} @@ -793,9 +730,7 @@ def create_authn_response(self, request: Union[dict, Message], sid: str) -> dict return {"response_args": aresp, "fragment_enc": fragment_enc} - def post_authentication( - self, request: Union[dict, Message], session_id: str, **kwargs - ) -> dict: + def post_authentication(self, request: Union[dict, Message], session_id: str, **kwargs) -> dict: """ Things that are done after a successful authentication. @@ -813,17 +748,13 @@ def post_authentication( grant = _context.authz(session_id, request=request) if grant.is_active() is False: - return self.error_response( - response_info, "server_error", "Grant not usable" - ) + return self.error_response(response_info, "server_error", "Grant not usable") user_id, client_id, grant_id = _mngr.decrypt_session_id(session_id) try: _mngr.set([user_id, client_id, grant_id], grant) except Exception as err: - return self.error_response( - response_info, "server_error", "{}".format(err.args) - ) + return self.error_response(response_info, "server_error", "{}".format(err.args)) logger.debug("response type: %s" % request["response_type"]) @@ -835,9 +766,7 @@ def post_authentication( try: redirect_uri = get_uri(_context, request, "redirect_uri") except (RedirectURIError, ParameterError) as err: - return self.error_response( - response_info, "invalid_request", "{}".format(err.args) - ) + return self.error_response(response_info, "invalid_request", "{}".format(err.args)) else: response_info["return_uri"] = redirect_uri @@ -849,9 +778,7 @@ def post_authentication( try: response_info = self.response_mode(request, **response_info) except InvalidRequest as err: - return self.error_response( - response_info, "invalid_request", "{}".format(err.args) - ) + return self.error_response(response_info, "invalid_request", "{}".format(err.args)) _cookie_info = _context.new_cookie( name=_context.cookie_handler.name["session"], @@ -882,24 +809,17 @@ def authz_part2(self, request, session_id, **kwargs): if "check_session_iframe" in _context.provider_info: salt = rndstr() try: - authn_event = _context.session_manager.get_authentication_event( - session_id - ) + authn_event = _context.session_manager.get_authentication_event(session_id) except KeyError: return self.error_response({}, "server_error", "No such session") else: if authn_event.is_valid() is False: - return self.error_response( - {}, "server_error", "Authentication has timed out" - ) + return self.error_response({}, "server_error", "Authentication has timed out") - _state = b64e( - as_bytes(json.dumps({"authn_time": authn_event["authn_time"]})) - ) + _state = b64e(as_bytes(json.dumps({"authn_time": authn_event["authn_time"]}))) _session_cookie_content = _context.new_cookie( - name=_context.cookie_handler.name["session_management"], - state=as_unicode(_state), + name=_context.cookie_handler.name["session_management"], state=as_unicode(_state), ) opbs_value = _session_cookie_content["value"] @@ -967,9 +887,7 @@ def process_request( kwargs = self.do_request_user(request_info=request, **kwargs) - info = self.setup_auth( - request, request["redirect_uri"], cinfo, _cookies, **kwargs - ) + info = self.setup_auth(request, request["redirect_uri"], cinfo, _cookies, **kwargs) if "error" in info: return info @@ -1005,9 +923,7 @@ def __call__(self, client_id, endpoint_context, alg, alg_type): _allowed = _pinfo.get(_sup) if alg not in _allowed: - logger.error( - "Signing alg user: {} not among allowed: {}".format(alg, _allowed) - ) + logger.error("Signing alg user: {} not among allowed: {}".format(alg, _allowed)) raise ValueError("Not allowed '%s' algorithm used", alg) diff --git a/src/oidcop/oauth2/introspection.py b/src/oidcop/oauth2/introspection.py index b7474a9c..0ae077f7 100644 --- a/src/oidcop/oauth2/introspection.py +++ b/src/oidcop/oauth2/introspection.py @@ -89,9 +89,7 @@ def process_request(self, request=None, release: Optional[list] = None, **kwargs grant = _session_info["grant"] _token = grant.get_token(request_token) - _info = self._introspect( - _token, _session_info["client_id"], _session_info["grant"] - ) + _info = self._introspect(_token, _session_info["client_id"], _session_info["grant"]) if _info is None: return {"response_args": _resp} @@ -105,9 +103,7 @@ def process_request(self, request=None, release: Optional[list] = None, **kwargs _resp.update(_info) _resp.weed() - _claims_restriction = grant.claims.get( - "introspection" - ) + _claims_restriction = grant.claims.get("introspection") if _claims_restriction: user_info = _context.claims_interface.get_user_claims( _session_info["user_id"], _claims_restriction diff --git a/src/oidcop/oauth2/token.py b/src/oidcop/oauth2/token.py new file mode 100755 index 00000000..e7e89b3d --- /dev/null +++ b/src/oidcop/oauth2/token.py @@ -0,0 +1,420 @@ +import logging +from typing import Optional +from typing import Union + +from cryptojwt.jwe.exception import JWEException +from cryptojwt.jwt import utc_time_sans_frac + +from oidcmsg import oidc +from oidcmsg.message import Message +from oidcmsg.oauth2 import AccessTokenResponse +from oidcmsg.oauth2 import ResponseMessage +from oidcmsg.oidc import RefreshAccessTokenRequest +from oidcmsg.oidc import TokenErrorResponse +from oidcmsg.time_util import time_sans_frac + +from oidcop import sanitize +from oidcop.endpoint import Endpoint +from oidcop.exception import ProcessError +from oidcop.session.grant import AuthorizationCode +from oidcop.session.grant import Grant +from oidcop.session.grant import RefreshToken +from oidcop.session.token import MintingNotAllowed +from oidcop.session.token import SessionToken +from oidcop.token.exception import UnknownToken +from oidcop.util import importer + +logger = logging.getLogger(__name__) + + +class TokenEndpointHelper(object): + def __init__(self, endpoint, config=None): + self.endpoint = endpoint + self.config = config + self.error_cls = self.endpoint.error_cls + + def post_parse_request( + self, request: Union[Message, dict], client_id: Optional[str] = "", **kwargs + ): + """Context specific parsing of the request. + This is done after general request parsing and before processing + the request. + """ + raise NotImplementedError + + def process_request(self, req: Union[Message, dict], **kwargs): + """Acts on a process request.""" + raise NotImplementedError + + def _mint_token( + self, + type: str, + grant: Grant, + session_id: str, + client_id: str, + based_on: Optional[SessionToken] = None, + token_args: Optional[dict] = None, + ) -> SessionToken: + _context = self.endpoint.server_get("endpoint_context") + _mngr = _context.session_manager + usage_rules = grant.usage_rules.get(type) + if usage_rules: + _exp_in = usage_rules.get("expires_in") + else: + _exp_in = 0 + + token_args = token_args or {} + for meth in _context.token_args_methods: + token_args = meth(_context, client_id, token_args) + + if token_args: + _args = {"token_args": token_args} + else: + _args = {} + + token = grant.mint_token( + session_id, + endpoint_context=_context, + token_type=type, + token_handler=_mngr.token_handler[type], + based_on=based_on, + usage_rules=usage_rules, + **_args, + ) + + if _exp_in: + if isinstance(_exp_in, str): + _exp_in = int(_exp_in) + + if _exp_in: + token.expires_at = time_sans_frac() + _exp_in + + _context.session_manager.set(_context.session_manager.unpack_session_key(session_id), grant) + + return token + + +class AccessTokenHelper(TokenEndpointHelper): + def process_request(self, req: Union[Message, dict], **kwargs): + """ + + :param req: + :param kwargs: + :return: + """ + _context = self.endpoint.server_get("endpoint_context") + + _mngr = _context.session_manager + _log_debug = logger.debug + + if req["grant_type"] != "authorization_code": + return self.error_cls(error="invalid_request", error_description="Unknown grant_type") + + try: + _access_code = req["code"].replace(" ", "+") + except KeyError: # Missing code parameter - absolutely fatal + return self.error_cls(error="invalid_request", error_description="Missing code") + + _session_info = _mngr.get_session_info_by_token(_access_code, grant=True) + grant = _session_info["grant"] + + _based_on = grant.get_token(_access_code) + _supports_minting = _based_on.usage_rules.get("supports_minting", []) + + _authn_req = grant.authorization_request + + # If redirect_uri was in the initial authorization request + # verify that the one given here is the correct one. + if "redirect_uri" in _authn_req: + if req["redirect_uri"] != _authn_req["redirect_uri"]: + return self.error_cls( + error="invalid_request", error_description="redirect_uri mismatch" + ) + + _log_debug("All checks OK") + + issue_refresh = kwargs.get("issue_refresh", False) + + _response = { + "token_type": "Bearer", + "scope": grant.scope, + } + + if "access_token" in _supports_minting: + try: + token = self._mint_token( + type="access_token", + grant=grant, + session_id=_session_info["session_id"], + client_id=_session_info["client_id"], + based_on=_based_on, + ) + except MintingNotAllowed as err: + logger.warning(err) + else: + _response["access_token"] = token.value + if token.expires_at: + _response["expires_in"] = token.expires_at - utc_time_sans_frac() + + if issue_refresh and "refresh_token" in _supports_minting: + try: + refresh_token = self._mint_token( + type="refresh_token", + grant=grant, + session_id=_session_info["session_id"], + client_id=_session_info["client_id"], + based_on=_based_on, + ) + except MintingNotAllowed as err: + logger.warning(err) + else: + _response["refresh_token"] = refresh_token.value + + # since the grant content has changed. Make sure it's stored + _mngr[_session_info["session_id"]] = grant + + _based_on.register_usage() + + return _response + + def post_parse_request( + self, request: Union[Message, dict], client_id: Optional[str] = "", **kwargs + ): + """ + This is where clients come to get their access tokens + + :param request: The request + :param client_id: Client identifier + :returns: + """ + + _mngr = self.endpoint.server_get("endpoint_context").session_manager + try: + _session_info = _mngr.get_session_info_by_token(request["code"], grant=True) + except (KeyError, UnknownToken): + logger.error("Access Code invalid") + return self.error_cls(error="invalid_grant", error_description="Unknown code") + + grant = _session_info["grant"] + code = grant.get_token(request["code"]) + if not isinstance(code, AuthorizationCode): + return self.error_cls(error="invalid_request", error_description="Wrong token type") + + if code.is_active() is False: + return self.error_cls(error="invalid_request", error_description="Code inactive") + + _auth_req = grant.authorization_request + + if "client_id" not in request: # Optional for access token request + request["client_id"] = _auth_req["client_id"] + + logger.debug("%s: %s" % (request.__class__.__name__, sanitize(request))) + + return request + + +class RefreshTokenHelper(TokenEndpointHelper): + def process_request(self, req: Union[Message, dict], **kwargs): + _context = self.endpoint.server_get("endpoint_context") + _mngr = _context.session_manager + + if req["grant_type"] != "refresh_token": + return self.error_cls(error="invalid_request", error_description="Wrong grant_type") + + token_value = req["refresh_token"] + _session_info = _mngr.get_session_info_by_token(token_value, grant=True) + + _grant = _session_info["grant"] + token = _grant.get_token(token_value) + access_token = self._mint_token( + type="access_token", + grant=_grant, + session_id=_session_info["session_id"], + client_id=_session_info["client_id"], + based_on=token, + ) + + _resp = { + "access_token": access_token.value, + "token_type": access_token.type, + "scope": _grant.scope, + } + + if access_token.expires_at: + _resp["expires_in"] = access_token.expires_at - utc_time_sans_frac() + + _mints = token.usage_rules.get("supports_minting") + if "refresh_token" in _mints: + refresh_token = self._mint_token( + type="refresh_token", + grant=_grant, + session_id=_session_info["session_id"], + client_id=_session_info["client_id"], + based_on=token, + ) + refresh_token.usage_rules = token.usage_rules.copy() + _resp["refresh_token"] = refresh_token.value + + token.register_usage() + + return _resp + + def post_parse_request( + self, request: Union[Message, dict], client_id: Optional[str] = "", **kwargs + ): + """ + This is where clients come to refresh their access tokens + + :param request: The request + :param client_id: Client identifier + :returns: + """ + + request = RefreshAccessTokenRequest(**request.to_dict()) + _context = self.endpoint.server_get("endpoint_context") + try: + keyjar = _context.keyjar + except AttributeError: + keyjar = "" + + request.verify(keyjar=keyjar, opponent_id=client_id) + + _mngr = _context.session_manager + try: + _session_info = _mngr.get_session_info_by_token(request["refresh_token"], grant=True) + except KeyError: + logger.error("Access Code invalid") + return self.error_cls(error="invalid_grant") + + token = _session_info["grant"].get_token(request["refresh_token"]) + + if not isinstance(token, RefreshToken): + return self.error_cls(error="invalid_request", error_description="Wrong token type") + + if token.is_active() is False: + return self.error_cls( + error="invalid_request", error_description="Refresh token inactive" + ) + + return request + + +class Token(Endpoint): + request_cls = Message + response_cls = AccessTokenResponse + error_cls = TokenErrorResponse + request_format = "json" + request_placement = "body" + response_format = "json" + response_placement = "body" + endpoint_name = "token_endpoint" + name = "token" + default_capabilities = {"token_endpoint_auth_signing_alg_values_supported": None} + helper_by_grant_type = { + "authorization_code": AccessTokenHelper, + "refresh_token": RefreshTokenHelper, + } + + def __init__(self, server_get, new_refresh_token=False, **kwargs): + Endpoint.__init__(self, server_get, **kwargs) + self.post_parse_request.append(self._post_parse_request) + if "client_authn_method" in kwargs: + self.endpoint_info["token_endpoint_auth_methods_supported"] = kwargs[ + "client_authn_method" + ] + self.allow_refresh = False + self.new_refresh_token = new_refresh_token + self.configure_grant_types(kwargs.get("grant_types_supported")) + + def configure_grant_types(self, grant_types_supported): + if grant_types_supported is None: + self.helper = {k: v(self) for k, v in self.helper_by_grant_type.items()} + return + + self.helper = {} + # TODO: do we want to allow any grant_type? + for grant_type, grant_type_options in grant_types_supported.items(): + _conf = grant_type_options.get("kwargs", {}) + if _conf is False: + continue + + try: + grant_class = grant_type_options["class"] + except (KeyError, TypeError): + raise ProcessError( + "Token Endpoint's grant types must be True, None or a dict with a" + " 'class' key." + ) + + if isinstance(grant_class, str): + try: + grant_class = importer(grant_class) + except (ValueError, AttributeError): + raise ProcessError( + f"Token Endpoint's grant type class {grant_class} can't" " be imported." + ) + + try: + self.helper[grant_type] = grant_class(self, _conf) + except Exception as e: + raise ProcessError(f"Failed to initialize class {grant_class}: {e}") + + def _post_parse_request( + self, request: Union[Message, dict], client_id: Optional[str] = "", **kwargs + ): + _helper = self.helper.get(request["grant_type"]) + if _helper: + return _helper.post_parse_request(request, client_id, **kwargs) + else: + return self.error_cls( + error="invalid_request", + error_description=f"Unsupported grant_type: {request['grant_type']}", + ) + + def process_request(self, request: Optional[Union[Message, dict]] = None, **kwargs): + """ + + :param request: + :param kwargs: + :return: Dictionary with response information + """ + if isinstance(request, self.error_cls): + return request + + if request is None: + return self.error_cls(error="invalid_request") + + try: + _helper = self.helper.get(request["grant_type"]) + if _helper: + response_args = _helper.process_request(request, **kwargs) + else: + return self.error_cls( + error="invalid_request", + error_description=f"Unsupported grant_type: {request['grant_type']}", + ) + except JWEException as err: + return self.error_cls(error="invalid_request", error_description="%s" % err) + + if isinstance(response_args, ResponseMessage): + return response_args + + _access_token = response_args["access_token"] + _context = self.server_get("endpoint_context") + _session_info = _context.session_manager.get_session_info_by_token( + _access_token, grant=True + ) + + _cookie = _context.new_cookie( + name=_context.cookie_handler.name["session"], + sub=_session_info["grant"].sub, + sid=_context.session_manager.session_key( + _session_info["user_id"], _session_info["user_id"], _session_info["grant"].id, + ), + ) + + _headers = [("Content-type", "application/json")] + resp = {"response_args": response_args, "http_headers": _headers} + if _cookie: + resp["cookie"] = [_cookie] + return resp diff --git a/src/oidcop/oidc/add_on/pkce.py b/src/oidcop/oidc/add_on/pkce.py index 301f315d..75b541d6 100644 --- a/src/oidcop/oidc/add_on/pkce.py +++ b/src/oidcop/oidc/add_on/pkce.py @@ -4,7 +4,9 @@ from cryptojwt.utils import b64e from oidcmsg.oauth2 import ( - AuthorizationErrorResponse, RefreshAccessTokenRequest, TokenExchangeRequest + AuthorizationErrorResponse, + RefreshAccessTokenRequest, + TokenExchangeRequest, ) from oidcmsg.oidc import TokenErrorResponse @@ -41,8 +43,7 @@ def post_authn_parse(request, client_id, endpoint_context, **kwargs): """ if endpoint_context.args["pkce"]["essential"] and "code_challenge" not in request: return AuthorizationErrorResponse( - error="invalid_request", - error_description="Missing required code_challenge", + error="invalid_request", error_description="Missing required code_challenge", ) if "code_challenge_method" not in request: @@ -87,8 +88,7 @@ def post_token_parse(request, client_id, endpoint_context, **kwargs): :return: """ if isinstance( - request, - (AuthorizationErrorResponse, RefreshAccessTokenRequest, TokenExchangeRequest), + request, (AuthorizationErrorResponse, RefreshAccessTokenRequest, TokenExchangeRequest), ): return request @@ -97,9 +97,7 @@ def post_token_parse(request, client_id, endpoint_context, **kwargs): request["code"], grant=True ) except KeyError: - return TokenErrorResponse( - error="invalid_grant", error_description="Unknown access grant" - ) + return TokenErrorResponse(error="invalid_grant", error_description="Unknown access grant") _authn_req = _session_info["grant"].authorization_request @@ -114,9 +112,7 @@ def post_token_parse(request, client_id, endpoint_context, **kwargs): if not verify_code_challenge( request["code_verifier"], _authn_req["code_challenge"], _method, ): - return TokenErrorResponse( - error="invalid_grant", error_description="PKCE check failed" - ) + return TokenErrorResponse(error="invalid_grant", error_description="PKCE check failed") return request diff --git a/src/oidcop/oidc/authorization.py b/src/oidcop/oidc/authorization.py index fb4d73e8..670b42a7 100755 --- a/src/oidcop/oidc/authorization.py +++ b/src/oidcop/oidc/authorization.py @@ -43,18 +43,9 @@ def host_component(url): ALG_PARAMS = { - "sign": [ - "request_object_signing_alg", - "request_object_signing_alg_values_supported", - ], - "enc_alg": [ - "request_object_encryption_alg", - "request_object_encryption_alg_values_supported", - ], - "enc_enc": [ - "request_object_encryption_enc", - "request_object_encryption_enc_values_supported", - ], + "sign": ["request_object_signing_alg", "request_object_signing_alg_values_supported",], + "enc_alg": ["request_object_encryption_alg", "request_object_encryption_alg_values_supported",], + "enc_enc": ["request_object_encryption_enc", "request_object_encryption_enc_values_supported",], } diff --git a/src/oidcop/oidc/registration.py b/src/oidcop/oidc/registration.py index 11e4d49c..260043bf 100755 --- a/src/oidcop/oidc/registration.py +++ b/src/oidcop/oidc/registration.py @@ -99,9 +99,7 @@ def comb_uri(args): val = [] for base, query_dict in args[param]: if query_dict: - query_string = urlencode( - [(key, v) for key in query_dict for v in query_dict[key]] - ) + query_string = urlencode([(key, v) for key in query_dict for v in query_dict[key]]) val.append("%s?%s" % (base, query_string)) else: val.append(base) @@ -153,9 +151,7 @@ def match_client_request(self, request): if request[_pref] not in _context.provider_info[_prov]: raise CapabilitiesMisMatch(_pref) else: - if not set(request[_pref]).issubset( - set(_context.provider_info[_prov]) - ): + if not set(request[_pref]).issubset(set(_context.provider_info[_prov])): raise CapabilitiesMisMatch(_pref) def do_client_registration(self, request, client_id, ignore=None): @@ -186,9 +182,7 @@ def do_client_registration(self, request, client_id, ignore=None): ruri = self.verify_redirect_uris(request) _cinfo["redirect_uris"] = ruri except InvalidRedirectURIError as e: - return self.error_cls( - error="invalid_redirect_uri", error_description=str(e) - ) + return self.error_cls(error="invalid_redirect_uri", error_description=str(e)) if "request_uris" in request: _uris = [] @@ -209,10 +203,9 @@ def do_client_registration(self, request, client_id, ignore=None): if "sector_identifier_uri" in request: try: - ( - _cinfo["si_redirects"], - _cinfo["sector_id"], - ) = self._verify_sector_identifier(request) + (_cinfo["si_redirects"], _cinfo["sector_id"],) = self._verify_sector_identifier( + request + ) except InvalidSectorIdentifier as err: return ResponseMessage( error="invalid_configuration_parameter", error_description=str(err) @@ -238,14 +231,10 @@ def do_client_registration(self, request, client_id, ignore=None): _k = [] for iss in ["", _context.issuer]: _k.extend( - _context.keyjar.get_signing_key( - ktyp, alg=request[item], owner=iss - ) + _context.keyjar.get_signing_key(ktyp, alg=request[item], owner=iss) ) if not _k: - logger.warning( - 'Lacking support for "{}"'.format(request[item]) - ) + logger.warning('Lacking support for "{}"'.format(request[item])) del _cinfo[item] t = {"jwks_uri": "", "jwks": None} @@ -287,9 +276,7 @@ def verify_redirect_uris(registration_request): pass else: logger.error( - "InvalidRedirectURI: scheme:%s, hostname:%s", - p.scheme, - p.hostname, + "InvalidRedirectURI: scheme:%s, hostname:%s", p.scheme, p.hostname, ) raise InvalidRedirectURIError( "Redirect_uri must use custom " "scheme or http and localhost" @@ -299,9 +286,7 @@ def verify_redirect_uris(registration_request): raise InvalidRedirectURIError(msg) elif p.scheme not in ["http", "https"]: # Custom scheme - raise InvalidRedirectURIError( - "Custom redirect_uri not allowed for web client" - ) + raise InvalidRedirectURIError("Custom redirect_uri not allowed for web client") elif p.fragment: raise InvalidRedirectURIError("redirect_uri contains fragment") @@ -339,17 +324,13 @@ def _verify_sector_identifier(self, request): try: si_redirects = json.loads(res.text) except ValueError: - raise InvalidSectorIdentifier( - "Error deserializing sector_identifier_uri content" - ) + raise InvalidSectorIdentifier("Error deserializing sector_identifier_uri content") if "redirect_uris" in request: logger.debug("redirect_uris: %s", request["redirect_uris"]) for uri in request["redirect_uris"]: if uri not in si_redirects: - raise InvalidSectorIdentifier( - "redirect_uri missing from sector_identifiers" - ) + raise InvalidSectorIdentifier("redirect_uri missing from sector_identifiers") return si_redirects, si_url @@ -397,8 +378,7 @@ def client_registration_setup(self, request, new_id=True, set_secret=True): self.match_client_request(request) except CapabilitiesMisMatch as err: return ResponseMessage( - error="invalid_request", - error_description="Don't support proposed %s" % err, + error="invalid_request", error_description="Don't support proposed %s" % err, ) _context = self.server_get("endpoint_context") @@ -433,16 +413,12 @@ def client_registration_setup(self, request, new_id=True, set_secret=True): _context.cdb[client_id] = _cinfo _cinfo = self.do_client_registration( - request, - client_id, - ignore=["redirect_uris", "policy_uri", "logo_uri", "tos_uri"], + request, client_id, ignore=["redirect_uris", "policy_uri", "logo_uri", "tos_uri"], ) if isinstance(_cinfo, ResponseMessage): return _cinfo - args = dict( - [(k, v) for k, v in _cinfo.items() if k in self.response_cls.c_param] - ) + args = dict([(k, v) for k, v in _cinfo.items() if k in self.response_cls.c_param]) comb_uri(args) response = self.response_cls(**args) @@ -478,8 +454,7 @@ def process_request(self, request=None, new_id=True, set_secret=True, **kwargs): else: _context = self.server_get("endpoint_context") _cookie = _context.new_cookie( - name=_context.cookie_handler.name["register"], - client_id=reg_resp["client_id"], + name=_context.cookie_handler.name["register"], client_id=reg_resp["client_id"], ) return {"response_args": reg_resp, "cookie": _cookie} diff --git a/src/oidcop/oidc/session.py b/src/oidcop/oidc/session.py index ef2604a0..b3db0e42 100644 --- a/src/oidcop/oidc/session.py +++ b/src/oidcop/oidc/session.py @@ -89,28 +89,20 @@ class Session(Endpoint): def __init__(self, server_get, **kwargs): _csi = kwargs.get("check_session_iframe") if _csi and not _csi.startswith("http"): - kwargs["check_session_iframe"] = add_path( - server_get("endpoint_context").issuer, _csi - ) + kwargs["check_session_iframe"] = add_path(server_get("endpoint_context").issuer, _csi) Endpoint.__init__(self, server_get, **kwargs) self.iv = as_bytes(rndstr(24)) def _encrypt_sid(self, sid): - encrypter = AES_GCMEncrypter( - key=as_bytes(self.server_get("endpoint_context").symkey) - ) + encrypter = AES_GCMEncrypter(key=as_bytes(self.server_get("endpoint_context").symkey)) enc_msg = encrypter.encrypt(as_bytes(sid), iv=self.iv) return as_unicode(b64e(enc_msg)) def _decrypt_sid(self, enc_msg): _msg = b64d(as_bytes(enc_msg)) - encrypter = AES_GCMEncrypter( - key=as_bytes(self.server_get("endpoint_context").symkey) - ) + encrypter = AES_GCMEncrypter(key=as_bytes(self.server_get("endpoint_context").symkey)) ctx, tag = split_ctx_and_tag(_msg) - return as_unicode( - encrypter.decrypt(as_bytes(ctx), iv=self.iv, tag=as_bytes(tag)) - ) + return as_unicode(encrypter.decrypt(as_bytes(ctx), iv=self.iv, tag=as_bytes(tag))) def do_back_channel_logout(self, cinfo, sid): """ @@ -198,9 +190,7 @@ def unpack_signed_jwt(self, sjwt, sig_alg=""): else: alg = self.kwargs["signing_alg"] - sign_keys = self.server_get("endpoint_context").keyjar.get_signing_key( - alg2keytype(alg) - ) + sign_keys = self.server_get("endpoint_context").keyjar.get_signing_key(alg2keytype(alg)) _info = _jwt.verify_compact(keys=sign_keys, sigalg=alg) return _info else: @@ -209,9 +199,7 @@ def unpack_signed_jwt(self, sjwt, sig_alg=""): def logout_from_client(self, sid): _context = self.server_get("endpoint_context") _cdb = _context.cdb - _session_information = _context.session_manager.get_session_info( - sid, grant=True - ) + _session_information = _context.session_manager.get_session_info(sid, grant=True) _client_id = _session_information["client_id"] res = {} @@ -221,9 +209,7 @@ def logout_from_client(self, sid): res["blu"] = {_client_id: _spec} elif "frontchannel_logout_uri" in _cdb[_client_id]: # Construct an IFrame - _spec = do_front_channel_logout_iframe( - _cdb[_client_id], _context.issuer, sid - ) + _spec = do_front_channel_logout_iframe(_cdb[_client_id], _context.issuer, sid) if _spec: res["flu"] = {_client_id: _spec} @@ -249,9 +235,7 @@ def process_request( if "post_logout_redirect_uri" in request: if "id_token_hint" not in request: - raise InvalidRequest( - "If post_logout_redirect_uri then id_token_hint is a MUST" - ) + raise InvalidRequest("If post_logout_redirect_uri then id_token_hint is a MUST") _cookies = http_info.get("cookie") _session_info = None @@ -269,9 +253,7 @@ def process_request( _cookie_info = json.loads(_cookie_infos[0]["value"]) logger.debug("Cookie info: {}".format(_cookie_info)) try: - _session_info = _mngr.get_session_info( - _cookie_info["sid"], grant=True - ) + _session_info = _mngr.get_session_info(_cookie_info["sid"], grant=True) except KeyError: raise ValueError("Can't find any corresponding session") @@ -301,21 +283,14 @@ def process_request( _uri = request["post_logout_redirect_uri"] except KeyError: if _context.issuer.endswith("/"): - _uri = "{}{}".format( - _context.issuer, self.kwargs["post_logout_uri_path"] - ) + _uri = "{}{}".format(_context.issuer, self.kwargs["post_logout_uri_path"]) else: - _uri = "{}/{}".format( - _context.issuer, self.kwargs["post_logout_uri_path"] - ) + _uri = "{}/{}".format(_context.issuer, self.kwargs["post_logout_uri_path"]) plur = False else: plur = True verify_uri( - _context, - request, - "post_logout_redirect_uri", - client_id=_session_info["client_id"], + _context, request, "post_logout_redirect_uri", client_id=_session_info["client_id"], ) payload = { @@ -339,9 +314,7 @@ def process_request( ) sjwt = _jws.pack(payload=payload, recv=_context.issuer) - location = "{}?{}".format( - self.kwargs["logout_verify_url"], urlencode({"sjwt": sjwt}) - ) + location = "{}?{}".format(self.kwargs["logout_verify_url"], urlencode({"sjwt": sjwt})) return {"redirect_location": location} def parse_request(self, request, http_info=None, **kwargs): @@ -383,9 +356,7 @@ def parse_request(self, request, http_info=None, **kwargs): else: if ( _ith.jws_header["alg"] - not in _context.provider_info[ - "id_token_signing_alg_values_supported" - ] + not in _context.provider_info["id_token_signing_alg_values_supported"] ): raise JWSException("Unsupported signing algorithm") @@ -424,7 +395,5 @@ def kill_cookies(self): session_mngmnt = _handler.make_cookie_content( value="", name=_handler.name["session_management"], max_age=-1 ) - session = _handler.make_cookie_content( - value="", name=_handler.name["session"], max_age=-1 - ) + session = _handler.make_cookie_content(value="", name=_handler.name["session"], max_age=-1) return [session_mngmnt, session] diff --git a/src/oidcop/oidc/token.py b/src/oidcop/oidc/token.py index a528d7ad..58dd90f7 100755 --- a/src/oidcop/oidc/token.py +++ b/src/oidcop/oidc/token.py @@ -7,94 +7,20 @@ from cryptojwt.jwt import utc_time_sans_frac from oidcmsg import oidc from oidcmsg.message import Message -from oidcmsg.oauth2 import ResponseMessage from oidcmsg.oidc import RefreshAccessTokenRequest from oidcmsg.oidc import TokenErrorResponse -from oidcmsg.time_util import time_sans_frac +from oidcop import oauth2 from oidcop import sanitize -from oidcop.endpoint import Endpoint -from oidcop.exception import ProcessError +from oidcop.oauth2.token import TokenEndpointHelper from oidcop.session.grant import AuthorizationCode -from oidcop.session.grant import Grant from oidcop.session.grant import RefreshToken from oidcop.session.token import MintingNotAllowed -from oidcop.session.token import SessionToken from oidcop.token.exception import UnknownToken -from oidcop.util import importer logger = logging.getLogger(__name__) -class TokenEndpointHelper(object): - def __init__(self, endpoint, config=None): - self.endpoint = endpoint - self.config = config - self.error_cls = self.endpoint.error_cls - - def post_parse_request(self, request: Union[Message, dict], - client_id: Optional[str] = "", - **kwargs): - """Context specific parsing of the request. - This is done after general request parsing and before processing - the request. - """ - raise NotImplementedError - - def process_request(self, req: Union[Message, dict], **kwargs): - """Acts on a process request.""" - raise NotImplementedError - - def _mint_token( - self, - type: str, - grant: Grant, - session_id: str, - client_id: str, - based_on: Optional[SessionToken] = None, - token_args: Optional[dict] = None, - ) -> SessionToken: - _context = self.endpoint.server_get("endpoint_context") - _mngr = _context.session_manager - usage_rules = grant.usage_rules.get(type) - if usage_rules: - _exp_in = usage_rules.get("expires_in") - else: - _exp_in = 0 - - token_args = token_args or {} - for meth in _context.token_args_methods: - token_args = meth(_context, client_id, token_args) - - if token_args: - _args = {"token_args": token_args} - else: - _args = {} - - token = grant.mint_token( - session_id, - endpoint_context=_context, - token_type=type, - token_handler=_mngr.token_handler[type], - based_on=based_on, - usage_rules=usage_rules, - **_args, - ) - - if _exp_in: - if isinstance(_exp_in, str): - _exp_in = int(_exp_in) - - if _exp_in: - token.expires_at = time_sans_frac() + _exp_in - - _context.session_manager.set( - _context.session_manager.unpack_session_key(session_id), grant - ) - - return token - - class AccessTokenHelper(TokenEndpointHelper): def process_request(self, req: Union[Message, dict], **kwargs): """ @@ -109,16 +35,12 @@ def process_request(self, req: Union[Message, dict], **kwargs): _log_debug = logger.debug if req["grant_type"] != "authorization_code": - return self.error_cls( - error="invalid_request", error_description="Unknown grant_type" - ) + return self.error_cls(error="invalid_request", error_description="Unknown grant_type") try: _access_code = req["code"].replace(" ", "+") except KeyError: # Missing code parameter - absolutely fatal - return self.error_cls( - error="invalid_request", error_description="Missing code" - ) + return self.error_cls(error="invalid_request", error_description="Missing code") _session_info = _mngr.get_session_info_by_token(_access_code, grant=True) grant = _session_info["grant"] @@ -223,21 +145,15 @@ def post_parse_request( _session_info = _mngr.get_session_info_by_token(request["code"], grant=True) except (KeyError, UnknownToken): logger.error("Access Code invalid") - return self.error_cls( - error="invalid_grant", error_description="Unknown code" - ) + return self.error_cls(error="invalid_grant", error_description="Unknown code") grant = _session_info["grant"] code = grant.get_token(request["code"]) if not isinstance(code, AuthorizationCode): - return self.error_cls( - error="invalid_request", error_description="Wrong token type" - ) + return self.error_cls(error="invalid_request", error_description="Wrong token type") if code.is_active() is False: - return self.error_cls( - error="invalid_request", error_description="Code inactive" - ) + return self.error_cls(error="invalid_request", error_description="Code inactive") _auth_req = grant.authorization_request @@ -255,14 +171,10 @@ def process_request(self, req: Union[Message, dict], **kwargs): _mngr = _context.session_manager if req["grant_type"] != "refresh_token": - return self.error_cls( - error="invalid_request", error_description="Wrong grant_type" - ) + return self.error_cls(error="invalid_request", error_description="Wrong grant_type") token_value = req["refresh_token"] - _session_info = _mngr.get_session_info_by_token( - token_value, grant=True - ) + _session_info = _mngr.get_session_info_by_token(token_value, grant=True) _grant = _session_info["grant"] token = _grant.get_token(token_value) @@ -307,8 +219,7 @@ def process_request(self, req: Union[Message, dict], **kwargs): except (JWEException, NoSuitableSigningKeys) as err: logger.warning(str(err)) resp = self.error_cls( - error="invalid_request", - error_description="Could not sign/encrypt id_token", + error="invalid_request", error_description="Could not sign/encrypt id_token", ) return resp @@ -340,9 +251,7 @@ def post_parse_request( _mngr = _context.session_manager try: - _session_info = _mngr.get_session_info_by_token( - request["refresh_token"], grant=True - ) + _session_info = _mngr.get_session_info_by_token(request["refresh_token"], grant=True) except KeyError: logger.error("Access Code invalid") return self.error_cls(error="invalid_grant") @@ -350,9 +259,7 @@ def post_parse_request( token = _session_info["grant"].get_token(request["refresh_token"]) if not isinstance(token, RefreshToken): - return self.error_cls( - error="invalid_request", error_description="Wrong token type" - ) + return self.error_cls(error="invalid_request", error_description="Wrong token type") if token.is_active() is False: return self.error_cls( @@ -362,14 +269,8 @@ def post_parse_request( return request -HELPER_BY_GRANT_TYPE = { - "authorization_code": AccessTokenHelper, - "refresh_token": RefreshTokenHelper, -} - - -class Token(Endpoint): - request_cls = oidc.Message +class Token(oauth2.token.Token): + request_cls = Message response_cls = oidc.AccessTokenResponse error_cls = TokenErrorResponse request_format = "json" @@ -379,110 +280,7 @@ class Token(Endpoint): endpoint_name = "token_endpoint" name = "token" default_capabilities = {"token_endpoint_auth_signing_alg_values_supported": None} - - def __init__(self, server_get, new_refresh_token=False, **kwargs): - Endpoint.__init__(self, server_get, **kwargs) - self.post_parse_request.append(self._post_parse_request) - if "client_authn_method" in kwargs: - self.endpoint_info["token_endpoint_auth_methods_supported"] = kwargs[ - "client_authn_method" - ] - self.allow_refresh = False - self.new_refresh_token = new_refresh_token - self.configure_grant_types(kwargs.get("grant_types_supported")) - - def configure_grant_types(self, grant_types_supported): - if grant_types_supported is None: - self.helper = {k: v(self) for k, v in HELPER_BY_GRANT_TYPE.items()} - return - - self.helper = {} - # TODO: do we want to allow any grant_type? - for grant_type, grant_type_options in grant_types_supported.items(): - _conf = grant_type_options.get("kwargs", {}) - if _conf is False: - continue - - try: - grant_class = grant_type_options["class"] - except (KeyError, TypeError): - raise ProcessError( - "Token Endpoint's grant types must be True, None or a dict with a" - " 'class' key." - ) - - if isinstance(grant_class, str): - try: - grant_class = importer(grant_class) - except (ValueError, AttributeError): - raise ProcessError( - f"Token Endpoint's grant type class {grant_class} can't" - " be imported." - ) - - try: - self.helper[grant_type] = grant_class(self, _conf) - except Exception as e: - raise ProcessError(f"Failed to initialize class {grant_class}: {e}") - - def _post_parse_request( - self, request: Union[Message, dict], client_id: Optional[str] = "", **kwargs - ): - _helper = self.helper.get(request["grant_type"]) - if _helper: - return _helper.post_parse_request(request, client_id, **kwargs) - else: - return self.error_cls( - error="invalid_request", - error_description=f"Unsupported grant_type: {request['grant_type']}", - ) - - def process_request(self, request: Optional[Union[Message, dict]] = None, **kwargs): - """ - - :param request: - :param kwargs: - :return: Dictionary with response information - """ - if isinstance(request, self.error_cls): - return request - - if request is None: - return self.error_cls(error="invalid_request") - - try: - _helper = self.helper.get(request["grant_type"]) - if _helper: - response_args = _helper.process_request(request, **kwargs) - else: - return self.error_cls( - error="invalid_request", - error_description=f"Unsupported grant_type: {request['grant_type']}", - ) - except JWEException as err: - return self.error_cls(error="invalid_request", error_description="%s" % err) - - if isinstance(response_args, ResponseMessage): - return response_args - - _access_token = response_args["access_token"] - _context = self.server_get("endpoint_context") - _session_info = _context.session_manager.get_session_info_by_token( - _access_token, grant=True - ) - - _cookie = _context.new_cookie( - name=_context.cookie_handler.name["session"], - sub=_session_info["grant"].sub, - sid=_context.session_manager.session_key( - _session_info["user_id"], - _session_info["user_id"], - _session_info["grant"].id, - ), - ) - - _headers = [("Content-type", "application/json")] - resp = {"response_args": response_args, "http_headers": _headers} - if _cookie: - resp["cookie"] = [_cookie] - return resp + helper_by_grant_type = { + "authorization_code": AccessTokenHelper, + "refresh_token": RefreshTokenHelper, + } diff --git a/src/oidcop/oidc/userinfo.py b/src/oidcop/oidc/userinfo.py index 455860e8..3d935c47 100755 --- a/src/oidcop/oidc/userinfo.py +++ b/src/oidcop/oidc/userinfo.py @@ -34,9 +34,7 @@ class UserInfo(Endpoint): "client_authn_method": ["bearer_header"], } - def __init__( - self, server_get: Callable, add_claims_by_scope: Optional[bool] = True, **kwargs - ): + def __init__(self, server_get: Callable, add_claims_by_scope: Optional[bool] = True, **kwargs): Endpoint.__init__( self, server_get, add_claims_by_scope=add_claims_by_scope, **kwargs, ) @@ -108,22 +106,16 @@ def do_response( def process_request(self, request=None, **kwargs): _mngr = self.server_get("endpoint_context").session_manager - _session_info = _mngr.get_session_info_by_token( - request["access_token"], grant=True - ) + _session_info = _mngr.get_session_info_by_token(request["access_token"], grant=True) _grant = _session_info["grant"] token = _grant.get_token(request["access_token"]) # should be an access token if token.type != "access_token": - return self.error_cls( - error="invalid_token", error_description="Wrong type of token" - ) + return self.error_cls(error="invalid_token", error_description="Wrong type of token") # And it should be valid if token.is_active() is False: - return self.error_cls( - error="invalid_token", error_description="Invalid Token" - ) + return self.error_cls(error="invalid_token", error_description="Invalid Token") allowed = True _auth_event = _grant.authentication_event diff --git a/src/oidcop/scopes.py b/src/oidcop/scopes.py index b12e4db4..ec772040 100644 --- a/src/oidcop/scopes.py +++ b/src/oidcop/scopes.py @@ -44,11 +44,7 @@ def convert_scopes2claims(scopes, allowed_claims=None, scope2claim_map=None): else: for scope in scopes: try: - claims = { - name: None - for name in scope2claim_map[scope] - if name in allowed_claims - } + claims = {name: None for name in scope2claim_map[scope] if name in allowed_claims} res.update(claims) except KeyError: continue diff --git a/src/oidcop/server.py b/src/oidcop/server.py index eae929bd..6a06d4ba 100644 --- a/src/oidcop/server.py +++ b/src/oidcop/server.py @@ -7,12 +7,12 @@ from oidcop import authz from oidcop.client_authn import client_auth_setup +from oidcop.configure import ASConfiguration from oidcop.configure import OPConfiguration from oidcop.endpoint import Endpoint from oidcop.endpoint_context import EndpointContext from oidcop.endpoint_context import init_service from oidcop.endpoint_context import init_user_info -from oidcop.session.claims import ClaimsInterface from oidcop.session.manager import create_session_manager from oidcop.user_authn.authn_context import populate_authn_broker from oidcop.util import allow_refresh_token @@ -20,9 +20,7 @@ def do_endpoints(conf, server_get): - endpoints = build_endpoints( - conf["endpoint"], server_get=server_get, issuer=conf["issuer"] - ) + endpoints = build_endpoints(conf["endpoint"], server_get=server_get, issuer=conf["issuer"]) _cap = conf.get("capabilities", {}) @@ -60,7 +58,7 @@ class Server(ImpExp): def __init__( self, - conf: Union[dict, OPConfiguration], + conf: Union[dict, OPConfiguration, ASConfiguration], keyjar: Optional[KeyJar] = None, cwd: Optional[str] = "", cookie_handler: Optional[Any] = None, @@ -69,11 +67,7 @@ def __init__( ImpExp.__init__(self) self.conf = conf self.endpoint_context = EndpointContext( - conf=conf, - keyjar=keyjar, - cwd=cwd, - cookie_handler=cookie_handler, - httpc=httpc, + conf=conf, keyjar=keyjar, cwd=cwd, cookie_handler=cookie_handler, httpc=httpc, ) self.endpoint_context.authz = self.do_authz() @@ -82,9 +76,7 @@ def __init__( self.endpoint = do_endpoints(conf, self.server_get) _cap = get_capabilities(conf, self.endpoint) - self.endpoint_context.provider_info = self.endpoint_context.create_providerinfo( - _cap - ) + self.endpoint_context.provider_info = self.endpoint_context.create_providerinfo(_cap) self.endpoint_context.do_add_on(endpoints=self.endpoint) self.endpoint_context.session_manager = create_session_manager( @@ -102,12 +94,8 @@ def __init__( self.client_authn_method = [] if _methods: - _endpoint.client_authn_method = client_auth_setup( - _methods, self.server_get - ) - elif ( - _methods is not None - ): # [] or '' or something not None but regarded as nothing. + _endpoint.client_authn_method = client_auth_setup(_methods, self.server_get) + elif _methods is not None: # [] or '' or something not None but regarded as nothing. _endpoint.client_authn_method = [None] # Ignore default value elif _endpoint.default_capabilities: _methods = _endpoint.default_capabilities.get("client_authn_method") @@ -122,8 +110,8 @@ def __init__( if _token_endp: _token_endp.allow_refresh = allow_refresh_token(self.endpoint_context) - self.endpoint_context.claims_interface = ClaimsInterface( - server_get=self.server_get + self.endpoint_context.claims_interface = init_service( + conf["claims_interface"], self.server_get ) _id_token_handler = self.endpoint_context.session_manager.token_handler.handler.get( @@ -181,9 +169,7 @@ def do_login_hint_lookup(self): if _kwargs: _userinfo_conf = _kwargs.get("userinfo") if _userinfo_conf: - _userinfo = init_user_info( - _userinfo_conf, self.endpoint_context.cwd - ) + _userinfo = init_user_info(_userinfo_conf, self.endpoint_context.cwd) if _userinfo is None: _userinfo = self.endpoint_context.userinfo diff --git a/src/oidcop/session/claims.py b/src/oidcop/session/claims.py index ba7e8484..603d8de8 100755 --- a/src/oidcop/session/claims.py +++ b/src/oidcop/session/claims.py @@ -25,16 +25,13 @@ def available_claims(endpoint_context): class ClaimsInterface: init_args = {"add_claims_by_scope": False, "enable_claims_per_client": False} + claims_types = ["userinfo", "introspection", "id_token", "access_token"] def __init__(self, server_get): self.server_get = server_get - def authorization_request_claims( - self, session_id: str, usage: Optional[str] = "" - ) -> dict: - _grant = self.server_get("endpoint_context").session_manager.get_grant( - session_id - ) + def authorization_request_claims(self, session_id: str, usage: Optional[str] = "") -> dict: + _grant = self.server_get("endpoint_context").session_manager.get_grant(session_id) if _grant.authorization_request and "claims" in _grant.authorization_request: return _grant.authorization_request["claims"].get(usage, {}) @@ -47,42 +44,45 @@ def _get_client_claims(self, client_id, usage): client_claims = {k: None for k in client_claims} return client_claims - def get_claims(self, session_id: str, scopes: str, usage: str) -> dict: - """ - - :param session_id: Session identifier - :param scopes: Scopes - :param usage: Where to use the claims. One of - "userinfo"/"id_token"/"introspection"/"access_token" - :return: Claims specification as a dictionary. - """ - - _context = self.server_get("endpoint_context") - # which endpoint module configuration to get the base claims from + def _get_module(self, usage, endpoint_context): module = None if usage == "userinfo": module = self.server_get("endpoint", "userinfo") elif usage == "id_token": try: - module = _context.session_manager.token_handler["id_token"] + module = endpoint_context.session_manager.token_handler["id_token"] except KeyError: raise ServiceError("No support for ID Tokens") elif usage == "introspection": module = self.server_get("endpoint", "introspection") elif usage == "access_token": try: - module = _context.session_manager.token_handler["access_token"] + module = endpoint_context.session_manager.token_handler["access_token"] except KeyError: raise ServiceError("No support for Access Tokens") + return module + + def get_claims(self, session_id: str, scopes: str, usage: str) -> dict: + """ + + :param session_id: Session identifier + :param scopes: Scopes + :param usage: Where to use the claims. One of + "userinfo"/"id_token"/"introspection"/"access_token" + :return: Claims specification as a dictionary. + """ + + _context = self.server_get("endpoint_context") + # which endpoint module configuration to get the base claims from + module = self._get_module(usage, _context) + if module: base_claims = module.kwargs.get("base_claims", {}) else: return {} - user_id, client_id, grant_id = _context.session_manager.decrypt_session_id( - session_id - ) + user_id, client_id, grant_id = _context.session_manager.decrypt_session_id(session_id) # Can there be per client specification of which claims to use. if module.kwargs.get("enable_claims_per_client"): @@ -95,20 +95,14 @@ def get_claims(self, session_id: str, scopes: str, usage: str) -> dict: # Scopes can in some cases equate to set of claims, is that used here ? if module.kwargs.get("add_claims_by_scope"): if scopes: - _scopes = _context.scopes_handler.filter_scopes( - client_id, _context, scopes - ) + _scopes = _context.scopes_handler.filter_scopes(client_id, _context, scopes) - _claims = convert_scopes2claims( - _scopes, scope2claim_map=_context.scope2claims - ) + _claims = convert_scopes2claims(_scopes, scope2claim_map=_context.scope2claims) claims.update(_claims) # Bring in claims specification from the authorization request # This only goes for ID Token and user info - request_claims = self.authorization_request_claims( - session_id=session_id, usage=usage - ) + request_claims = self.authorization_request_claims(session_id=session_id, usage=usage) # This will add claims that has not be added before and # set filters on those claims that also appears in one of the sources above @@ -119,7 +113,7 @@ def get_claims(self, session_id: str, scopes: str, usage: str) -> dict: def get_claims_all_usage(self, session_id: str, scopes: str) -> dict: _claims = {} - for usage in ["userinfo", "introspection", "id_token", "access_token"]: + for usage in self.claims_types: _claims[usage] = self.get_claims(session_id, scopes, usage) return _claims @@ -132,9 +126,7 @@ def get_user_claims(self, user_id: str, claims_restriction: dict) -> dict: """ if claims_restriction: # Get all possible claims - user_info = self.server_get("endpoint_context").userinfo( - user_id, client_id=None - ) + user_info = self.server_get("endpoint_context").userinfo(user_id, client_id=None) # Filter out the claims that can be returned return { k: user_info.get(k) @@ -194,3 +186,25 @@ def by_schema(cls, **kwa): :return: A dictionary with claims (keys) that meets the filter criteria """ return dict([(key, val) for key, val in kwa.items() if key in cls.c_param]) + + +class OAuth2ClaimsInterface(ClaimsInterface): + claims_types = ["introspection", "access_token"] + + def _get_module(self, usage, endpoint_context): + module = None + if usage == "introspection": + module = self.server_get("endpoint", "introspection") + elif usage == "access_token": + try: + module = endpoint_context.session_manager.token_handler["access_token"] + except KeyError: + raise ServiceError("No support for Access Tokens") + + return module + + def get_claims_all_usage(self, session_id: str, scopes: str) -> dict: + _claims = {} + for usage in self.claims_types: + _claims[usage] = self.get_claims(session_id, scopes, usage) + return _claims diff --git a/src/oidcop/session/database.py b/src/oidcop/session/database.py index bf75991b..4f5aad7d 100644 --- a/src/oidcop/session/database.py +++ b/src/oidcop/session/database.py @@ -92,18 +92,18 @@ def get(self, path: List[str]) -> Union[SessionInfo, Grant]: try: user_info = self.db[uid] except KeyError: - raise KeyError('No such UserID') + raise KeyError("No such UserID") except TypeError: - raise InconsistentDatabase('Missing session db') + raise InconsistentDatabase("Missing session db") else: if user_info is None: - raise KeyError('No such UserID') + raise KeyError("No such UserID") if client_id is None: return user_info if client_id not in user_info.subordinate: - raise ValueError('No session from that client for that user') + raise ValueError("No session from that client for that user") try: skey = self.session_key(uid, client_id) @@ -115,8 +115,7 @@ def get(self, path: List[str]) -> Union[SessionInfo, Grant]: return client_session_info if grant_id not in client_session_info.subordinate: - raise ValueError( - 'No such grant for that user and client') + raise ValueError("No such grant for that user and client") else: try: skey = self.session_key(uid, client_id, grant_id) @@ -135,9 +134,7 @@ def delete(self, path: List[str]): _user_info = self.db[uid] skey_uid_client = self.session_key(uid, client_id) - skey_uid_client_grant = self.session_key( - uid, client_id, grant_id or '' - ) + skey_uid_client_grant = self.session_key(uid, client_id, grant_id or "") if client_id not in _user_info.subordinate: self.db.__delitem__(client_id) diff --git a/src/oidcop/session/grant.py b/src/oidcop/session/grant.py index 6591c722..1139214c 100644 --- a/src/oidcop/session/grant.py +++ b/src/oidcop/session/grant.py @@ -186,30 +186,21 @@ def payload_arguments( if not scope: scope = self.scope - payload = { - "scope": scope, - "aud": self.resources, - "jti": uuid1().hex - } + payload = {"scope": scope, "aud": self.resources, "jti": uuid1().hex} if extra_payload: payload.update(extra_payload) if self.authorization_request: - client_id = self.authorization_request.get('client_id') + client_id = self.authorization_request.get("client_id") if client_id: - payload.update({ - "client_id": client_id, - 'sub': client_id - }) + payload.update({"client_id": client_id, "sub": client_id}) _claims_restriction = endpoint_context.claims_interface.get_claims( session_id, scopes=scope, usage=token_type ) user_id, _, _ = endpoint_context.session_manager.decrypt_session_id(session_id) - user_info = endpoint_context.claims_interface.get_user_claims( - user_id, _claims_restriction - ) + user_info = endpoint_context.claims_interface.get_user_claims(user_id, _claims_restriction) payload.update(user_info) return payload @@ -254,12 +245,8 @@ def mint_token( token_class = self.token_map.get(token_type) if token_type == "id_token": - class_args = { - k: v for k, v in kwargs.items() if k not in ["code", "access_token"] - } - handler_args = { - k: v for k, v in kwargs.items() if k in ["code", "access_token"] - } + class_args = {k: v for k, v in kwargs.items() if k not in ["code", "access_token"]} + handler_args = {k: v for k, v in kwargs.items() if k in ["code", "access_token"]} else: class_args = kwargs handler_args = {} @@ -274,15 +261,17 @@ def mint_token( ) if token_handler is None: token_handler = endpoint_context.session_manager.token_handler.handler[ - GRANT_TYPE_MAP[token_type]] + GRANT_TYPE_MAP[token_type] + ] - token_payload = self.payload_arguments(session_id, - endpoint_context, - token_type=token_type, - scope=scope, - extra_payload=handler_args) - item.value = token_handler(session_id=session_id, - **token_payload) + token_payload = self.payload_arguments( + session_id, + endpoint_context, + token_type=token_type, + scope=scope, + extra_payload=handler_args, + ) + item.value = token_handler(session_id=session_id, **token_payload) else: raise ValueError("Can not mint that kind of token") @@ -298,10 +287,7 @@ def get_token(self, value: str) -> Optional[SessionToken]: return None def revoke_token( - self, - value: Optional[str] = "", - based_on: Optional[str] = "", - recursive: bool = True, + self, value: Optional[str] = "", based_on: Optional[str] = "", recursive: bool = True, ): for t in self.issued_token: if not value and not based_on: @@ -341,9 +327,7 @@ def get_spec(self, token: SessionToken) -> Optional[dict]: "expires_in": 300, }, "access_token": {"supports_minting": [], "expires_in": 3600}, - "refresh_token": { - "supports_minting": ["access_token", "refresh_token", "id_token"] - }, + "refresh_token": {"supports_minting": ["access_token", "refresh_token", "id_token"]}, } diff --git a/src/oidcop/session/info.py b/src/oidcop/session/info.py index 1b96acfa..3d93c02d 100644 --- a/src/oidcop/session/info.py +++ b/src/oidcop/session/info.py @@ -9,11 +9,11 @@ class SessionInfo(ImpExp): parameter = {"subordinate": [], "revoked": bool, "type": "", "extra_args": {}} def __init__( - self, - subordinate: Optional[List[str]] = None, - revoked: Optional[bool] = False, - type: Optional[str] = "", - **kwargs + self, + subordinate: Optional[List[str]] = None, + revoked: Optional[bool] = False, + type: Optional[str] = "", + **kwargs ): ImpExp.__init__(self) self.subordinate = subordinate or [] @@ -44,7 +44,7 @@ def keys(self): class UserSessionInfo(SessionInfo): parameter = SessionInfo.parameter.copy() parameter.update( - {"user_id": "", } + {"user_id": "",} ) def __init__(self, **kwargs): diff --git a/src/oidcop/session/manager.py b/src/oidcop/session/manager.py index 23c6a255..7bfb0271 100644 --- a/src/oidcop/session/manager.py +++ b/src/oidcop/session/manager.py @@ -37,10 +37,10 @@ def __init__(self, salt: Optional[str] = "", filename: Optional[str] = ""): elif filename: if os.path.isfile(filename): self.salt = open(filename).read() - elif not os.path.isfile(filename) and os.path.exists(filename): # Not a file, Something else - raise ConfigurationError( - "Salt filename points to something that is not a file" - ) + elif not os.path.isfile(filename) and os.path.exists( + filename + ): # Not a file, Something else + raise ConfigurationError("Salt filename points to something that is not a file") else: self.salt = rndstr(24) # May raise an exception @@ -73,10 +73,7 @@ class SessionManager(Database): init_args = ["handler"] def __init__( - self, - handler: TokenHandler, - conf: Optional[dict] = None, - sub_func: Optional[dict] = None, + self, handler: TokenHandler, conf: Optional[dict] = None, sub_func: Optional[dict] = None, ): if conf: _key = conf.get("password", rndstr(24)) @@ -109,7 +106,7 @@ def get_user_info(self, uid: str) -> UserSessionInfo: usi = self.get([uid]) if isinstance(usi, UserSessionInfo): return usi - else: # pragma: no cover + else: # pragma: no cover raise ValueError("Not UserSessionInfo") def find_token(self, session_id: str, token_value: str) -> Optional[SessionToken]: @@ -125,17 +122,17 @@ def find_token(self, session_id: str, token_value: str) -> Optional[SessionToken if token.value == token_value: return token - return None # pragma: no cover + return None # pragma: no cover def create_grant( - self, - authn_event: AuthnEvent, - auth_req: AuthorizationRequest, - user_id: str, - client_id: Optional[str] = "", - sub_type: Optional[str] = "public", - token_usage_rules: Optional[dict] = None, - scopes: Optional[list] = None, + self, + authn_event: AuthnEvent, + auth_req: AuthorizationRequest, + user_id: str, + client_id: Optional[str] = "", + sub_type: Optional[str] = "public", + token_usage_rules: Optional[dict] = None, + scopes: Optional[list] = None, ) -> str: """ @@ -165,14 +162,14 @@ def create_grant( return self.encrypted_session_id(user_id, client_id, grant.id) def create_session( - self, - authn_event: AuthnEvent, - auth_req: AuthorizationRequest, - user_id: str, - client_id: Optional[str] = "", - sub_type: Optional[str] = "public", - token_usage_rules: Optional[dict] = None, - scopes: Optional[list] = None, + self, + authn_event: AuthnEvent, + auth_req: AuthorizationRequest, + user_id: str, + client_id: Optional[str] = "", + sub_type: Optional[str] = "public", + token_usage_rules: Optional[dict] = None, + scopes: Optional[list] = None, ) -> str: """ Create part of a user session. The parts added are user- and client @@ -228,7 +225,7 @@ def get_client_session_info(self, session_id: str) -> ClientSessionInfo: csi = self.get([_user_id, _client_id]) if isinstance(csi, ClientSessionInfo): return csi - else: # pragma: no cover + else: # pragma: no cover raise ValueError("Wrong type of session info") def get_user_session_info(self, session_id: str) -> UserSessionInfo: @@ -242,7 +239,7 @@ def get_user_session_info(self, session_id: str) -> UserSessionInfo: usi = self.get([_user_id]) if isinstance(usi, UserSessionInfo): return usi - else: # pragma: no cover + else: # pragma: no cover raise ValueError("Wrong type of session info") def get_grant(self, session_id: str) -> Grant: @@ -256,13 +253,13 @@ def get_grant(self, session_id: str) -> Grant: grant = self.get([_user_id, _client_id, _grant_id]) if isinstance(grant, Grant): return grant - else: # pragma: no cover + else: # pragma: no cover raise ValueError("Wrong type of item") def _revoke_dependent(self, grant: Grant, token: SessionToken): for t in grant.issued_token: if t.based_on == token.value: - t.revoked = True # TODO: not covered yet! + t.revoked = True # TODO: not covered yet! self._revoke_dependent(grant, t) def revoke_token(self, session_id: str, token_value: str, recursive: bool = False): @@ -275,19 +272,19 @@ def revoke_token(self, session_id: str, token_value: str, recursive: bool = Fals tokens minted by this token. Recursively. """ token = self.find_token(session_id, token_value) - if token is None: # pragma: no cover + if token is None: # pragma: no cover raise UnknownToken() token.revoked = True - if recursive: # TODO: not covered yet! + if recursive: # TODO: not covered yet! grant = self[session_id] self._revoke_dependent(grant, token) def get_authentication_events( - self, - session_id: Optional[str] = "", - user_id: Optional[str] = "", - client_id: Optional[str] = "", + self, + session_id: Optional[str] = "", + user_id: Optional[str] = "", + client_id: Optional[str] = "", ) -> List[AuthnEvent]: """ Return the authentication events that exists for a user/client combination. @@ -346,10 +343,10 @@ def revoke_grant(self, session_id: str): self.set(_path, _info) def grants( - self, - session_id: Optional[str] = "", - user_id: Optional[str] = "", - client_id: Optional[str] = "", + self, + session_id: Optional[str] = "", + user_id: Optional[str] = "", + client_id: Optional[str] = "", ) -> List[Grant]: """ Find all grant connected to a user session @@ -370,13 +367,13 @@ def grants( return [self.get([user_id, client_id, gid]) for gid in _csi.subordinate] def get_session_info( - self, - session_id: str, - user_session_info: bool = False, - client_session_info: bool = False, - grant: bool = False, - authentication_event: bool = False, - authorization_request: bool = False, + self, + session_id: str, + user_session_info: bool = False, + client_session_info: bool = False, + grant: bool = False, + authentication_event: bool = False, + authorization_request: bool = False, ) -> dict: """ Returns information connected to a session. @@ -424,13 +421,13 @@ def get_session_info( return res def get_session_info_by_token( - self, - token_value: str, - user_session_info: bool = False, - client_session_info: bool = False, - grant: bool = False, - authentication_event: bool = False, - authorization_request: bool = False, + self, + token_value: str, + user_session_info: bool = False, + client_session_info: bool = False, + grant: bool = False, + authentication_event: bool = False, + authorization_request: bool = False, ) -> dict: _token_info = self.token_handler.info(token_value) return self.get_session_info( diff --git a/src/oidcop/token/__init__.py b/src/oidcop/token/__init__.py index 50d8dc01..ba3bb1ea 100755 --- a/src/oidcop/token/__init__.py +++ b/src/oidcop/token/__init__.py @@ -31,9 +31,7 @@ def __init__(self, typ, lifetime=300, **kwargs): self.lifetime = lifetime self.kwargs = kwargs - def __call__( - self, session_id: Optional[str] = "", ttype: Optional[str] = "", **payload - ) -> str: + def __call__(self, session_id: Optional[str] = "", ttype: Optional[str] = "", **payload) -> str: """ Return a token. @@ -72,9 +70,7 @@ def __init__(self, password, typ="", token_type="Bearer", **kwargs): self.crypt = Crypt(password) self.token_type = token_type - def __call__( - self, session_id: Optional[str] = "", ttype: Optional[str] = "", **payload - ) -> str: + def __call__(self, session_id: Optional[str] = "", ttype: Optional[str] = "", **payload) -> str: """ Return a token. diff --git a/src/oidcop/token/handler.py b/src/oidcop/token/handler.py index 782ae8dd..85800c8f 100755 --- a/src/oidcop/token/handler.py +++ b/src/oidcop/token/handler.py @@ -3,13 +3,13 @@ from typing import Optional import warnings -from cryptography.fernet import InvalidToken from cryptojwt.exception import Invalid from cryptojwt.key_jar import init_key_jar from cryptojwt.utils import as_unicode from oidcmsg.impexp import ImpExp from oidcmsg.item import DLDict +from oidcop.exception import InvalidToken from oidcop.token import DefaultToken from oidcop.token import Token from oidcop.token import UnknownToken @@ -25,11 +25,11 @@ class TokenHandler(ImpExp): parameter = {"handler": DLDict, "handler_order": [""]} def __init__( - self, - access_token_handler: Optional[Token] = None, - code_handler: Optional[Token] = None, - refresh_token_handler: Optional[Token] = None, - id_token_handler: Optional[Token] = None, + self, + access_token_handler: Optional[Token] = None, + code_handler: Optional[Token] = None, + refresh_token_handler: Optional[Token] = None, + id_token_handler: Optional[Token] = None, ): ImpExp.__init__(self) self.handler = {"code": code_handler, "access_token": access_token_handler} @@ -72,7 +72,7 @@ def get_handler(self, token, order=None): for typ in order: try: res = self.handler[typ].info(token) - except (KeyError, WrongTokenType, InvalidToken, UnknownToken, Invalid): + except (KeyError, WrongTokenType, InvalidToken, UnknownToken, Invalid, AttributeError): pass else: return self.handler[typ], res @@ -142,13 +142,13 @@ def default_token(spec): def factory( - server_get, - code: Optional[dict] = None, - token: Optional[dict] = None, - refresh: Optional[dict] = None, - id_token: Optional[dict] = None, - jwks_file: Optional[str] = "", - **kwargs + server_get, + code: Optional[dict] = None, + token: Optional[dict] = None, + refresh: Optional[dict] = None, + id_token: Optional[dict] = None, + jwks_file: Optional[str] = "", + **kwargs ) -> TokenHandler: """ Create a token handler @@ -165,12 +165,12 @@ def factory( key_defs = [] read_only = False cwd = server_get("endpoint_context").cwd - if kwargs.get('jwks_def'): - defs = kwargs['jwks_def'] + if kwargs.get("jwks_def"): + defs = kwargs["jwks_def"] if not jwks_file: - jwks_file = defs.get('private_path', os.path.join(cwd, JWKS_FILE)) - read_only = defs.get('read_only', read_only) - key_defs = defs.get('key_defs', []) + jwks_file = defs.get("private_path", os.path.join(cwd, JWKS_FILE)) + read_only = defs.get("read_only", read_only) + key_defs = defs.get("key_defs", []) if not jwks_file: jwks_file = os.path.join(cwd, JWKS_FILE) @@ -179,22 +179,20 @@ def factory( for kid, cnf in [("code", code), ("refresh", refresh), ("token", token)]: if cnf is not None: if default_token(cnf): - key_defs.append( - {"type": "oct", "bytes": 24, "use": ["enc"], "kid": kid} - ) + key_defs.append({"type": "oct", "bytes": 24, "use": ["enc"], "kid": kid}) kj = init_key_jar(key_defs=key_defs, private_path=jwks_file, read_only=read_only) args = {} - for typ, cnf, attr in [("code", code, "code_handler"), - ("token", token, "access_token_handler"), - ("refresh", refresh, "refresh_token_handler")]: + for typ, cnf, attr in [ + ("code", code, "code_handler"), + ("token", token, "access_token_handler"), + ("refresh", refresh, "refresh_token_handler"), + ]: if cnf is not None: if default_token(cnf): _add_passwd(kj, cnf, typ) - args[attr] = init_token_handler( - server_get, cnf, TTYPE[typ] - ) + args[attr] = init_token_handler(server_get, cnf, TTYPE[typ]) if id_token is not None: args["id_token_handler"] = init_token_handler(server_get, id_token, typ="") diff --git a/src/oidcop/token/id_token.py b/src/oidcop/token/id_token.py index 8e137331..a6d1d9a0 100755 --- a/src/oidcop/token/id_token.py +++ b/src/oidcop/token/id_token.py @@ -13,6 +13,7 @@ from oidcop.token import is_expired from . import Token from . import UnknownToken +from ..exception import InvalidToken from ..util import get_logout_id logger = logging.getLogger(__name__) @@ -58,14 +59,12 @@ def include_session_id(endpoint_context, client_id, where): def get_sign_and_encrypt_algorithms( - endpoint_context, client_info, payload_type, sign=False, encrypt=False + endpoint_context, client_info, payload_type, sign=False, encrypt=False ): args = {"sign": sign, "encrypt": encrypt} if sign: try: - args["sign_alg"] = client_info[ - "{}_signed_response_alg".format(payload_type) - ] + args["sign_alg"] = client_info["{}_signed_response_alg".format(payload_type)] except KeyError: # Fall back to default try: args["sign_alg"] = endpoint_context.jwx_def["signing_alg"][payload_type] @@ -88,9 +87,7 @@ def get_sign_and_encrypt_algorithms( args["enc_alg"] = client_info["%s_encrypted_response_alg" % payload_type] except KeyError: try: - args["enc_alg"] = endpoint_context.jwx_def["encryption_alg"][ - payload_type - ] + args["enc_alg"] = endpoint_context.jwx_def["encryption_alg"][payload_type] except KeyError: _supported = endpoint_context.provider_info.get( "{}_encryption_alg_values_supported".format(payload_type) @@ -102,9 +99,7 @@ def get_sign_and_encrypt_algorithms( args["enc_enc"] = client_info["%s_encrypted_response_enc" % payload_type] except KeyError: try: - args["enc_enc"] = endpoint_context.jwx_def["encryption_enc"][ - payload_type - ] + args["enc_enc"] = endpoint_context.jwx_def["encryption_enc"][payload_type] except KeyError: _supported = endpoint_context.provider_info.get( "{}_encryption_enc_values_supported".format(payload_type) @@ -123,23 +118,21 @@ class IDToken(Token): } def __init__( - self, - typ: Optional[str] = "I", - lifetime: Optional[int] = 300, - server_get: Callable = None, - **kwargs + self, + typ: Optional[str] = "I", + lifetime: Optional[int] = 300, + server_get: Callable = None, + **kwargs ): Token.__init__(self, typ, **kwargs) self.lifetime = lifetime self.server_get = server_get self.kwargs = kwargs self.scope_to_claims = None - self.provider_info = construct_endpoint_info( - self.default_capabilities, **kwargs - ) + self.provider_info = construct_endpoint_info(self.default_capabilities, **kwargs) def payload( - self, session_id, alg="RS256", code=None, access_token=None, extra_claims=None, + self, session_id, alg="RS256", code=None, access_token=None, extra_claims=None, ): """ @@ -167,8 +160,7 @@ def payload( user_info = None else: user_info = _context.claims_interface.get_user_claims( - user_id=session_information["user_id"], - claims_restriction=_claims_restriction, + user_id=session_information["user_id"], claims_restriction=_claims_restriction, ) if _claims_restriction and "acr" in _claims_restriction and "acr" in _args: if claims_match(_args["acr"], _claims_restriction["acr"]) is False: @@ -209,15 +201,15 @@ def payload( return _args def sign_encrypt( - self, - session_id, - client_id, - code=None, - access_token=None, - sign=True, - encrypt=False, - lifetime=None, - extra_claims=None, + self, + session_id, + client_id, + code=None, + access_token=None, + sign=True, + encrypt=False, + lifetime=None, + extra_claims=None, ) -> str: """ Signed and or encrypt a IDToken @@ -255,18 +247,15 @@ def sign_encrypt( return _jwt.pack(_payload, recv=client_id) - def __call__( - self, session_id: Optional[str] = "", ttype: Optional[str] = "", **kwargs - ) -> str: + def __call__(self, session_id: Optional[str] = "", ttype: Optional[str] = "", **kwargs) -> str: _context = self.server_get("endpoint_context") - user_id, client_id, grant_id = _context.session_manager.decrypt_session_id( - session_id - ) + user_id, client_id, grant_id = _context.session_manager.decrypt_session_id(session_id) # Should I add session ID. This is about Single Logout. - if include_session_id(_context, client_id, "back") or \ - include_session_id(_context, client_id, "front"): + if include_session_id(_context, client_id, "back") or include_session_id( + _context, client_id, "front" + ): xargs = {"sid": get_logout_id(_context, user_id=user_id, client_id=client_id)} else: @@ -275,17 +264,10 @@ def __call__( lifetime = self.kwargs.get("lifetime") # Weed out stuff that doesn't belong here - kwargs = { - k: v for k, v in kwargs.items() if k in ["encrypt", "code", "access_token"] - } + kwargs = {k: v for k, v in kwargs.items() if k in ["encrypt", "code", "access_token"]} id_token = self.sign_encrypt( - session_id, - client_id, - sign=True, - lifetime=lifetime, - extra_claims=xargs, - **kwargs + session_id, client_id, sign=True, lifetime=lifetime, extra_claims=xargs, **kwargs ) return id_token @@ -302,17 +284,15 @@ def info(self, token): _context = self.server_get("endpoint_context") _jwt = factory(token) + if not _jwt: + raise InvalidToken("Not valid token") + _payload = _jwt.jwt.payload() client_id = _payload["aud"][0] client_info = _context.cdb[client_id] - alg_dict = get_sign_and_encrypt_algorithms( - _context, client_info, "id_token", sign=True - ) + alg_dict = get_sign_and_encrypt_algorithms(_context, client_info, "id_token", sign=True) - verifier = JWT( - key_jar=_context.keyjar, - allowed_sign_algs=alg_dict["sign_alg"] - ) + verifier = JWT(key_jar=_context.keyjar, allowed_sign_algs=alg_dict["sign_alg"]) try: _payload = verifier.unpack(token) except JWSException: @@ -322,7 +302,7 @@ def info(self, token): raise ToOld("Token has expired") # All the token metadata return { - "sid": _payload.get("sid", ''), # TODO: would sid be there? + "sid": _payload.get("sid", ""), # TODO: would sid be there? # "type": _payload["ttype"], "exp": _payload["exp"], "aud": client_id, diff --git a/src/oidcop/token/jwt_token.py b/src/oidcop/token/jwt_token.py index 4e93b742..312c460d 100644 --- a/src/oidcop/token/jwt_token.py +++ b/src/oidcop/token/jwt_token.py @@ -15,17 +15,17 @@ class JWTToken(Token): def __init__( - self, - typ, - # keyjar: KeyJar = None, - issuer: str = None, - aud: Optional[list] = None, - alg: str = "ES256", - lifetime: int = 300, - server_get: Callable = None, - token_type: str = "Bearer", - password: str = "", - **kwargs + self, + typ, + # keyjar: KeyJar = None, + issuer: str = None, + aud: Optional[list] = None, + alg: str = "ES256", + lifetime: int = 300, + server_get: Callable = None, + token_type: str = "Bearer", + password: str = "", + **kwargs ): Token.__init__(self, typ, **kwargs) self.token_type = token_type @@ -46,10 +46,7 @@ def load_custom_claims(self, payload: dict = None): # inherit me and do your things here return payload - def __call__(self, - session_id: Optional[str] = '', - ttype: Optional[str] = '', - **payload) -> str: + def __call__(self, session_id: Optional[str] = "", ttype: Optional[str] = "", **payload) -> str: """ Return a token. @@ -65,20 +62,13 @@ def __call__(self, else: ttype = "A" - payload.update( - {"sid": session_id, - "ttype": ttype - } - ) + payload.update({"sid": session_id, "ttype": ttype}) payload = self.load_custom_claims(payload) # payload.update(kwargs) _context = self.server_get("endpoint_context") signer = JWT( - key_jar=_context.keyjar, - iss=self.issuer, - lifetime=self.lifetime, - sign_alg=self.alg, + key_jar=_context.keyjar, iss=self.issuer, lifetime=self.lifetime, sign_alg=self.alg, ) return signer.pack(payload) diff --git a/src/oidcop/user_authn/user.py b/src/oidcop/user_authn/user.py index 8ff48f33..b31238a4 100755 --- a/src/oidcop/user_authn/user.py +++ b/src/oidcop/user_authn/user.py @@ -3,6 +3,7 @@ import inspect import json import logging +import os import sys import time import warnings @@ -32,9 +33,9 @@ }, "se": { "title": "Logga in", - "login_title": u"Användarnamn", - "passwd_title": u"Lösenord", - "submit_text": u"Sänd", + "login_title": "Användarnamn", + "passwd_title": "Lösenord", + "submit_text": "Sänd", "client_policy_title": "Klientens sekretesspolicy", }, } @@ -75,9 +76,7 @@ def verify(self, *args, **kwargs): raise NotImplementedError def unpack_token(self, token): - return verify_signed_jwt( - token=token, keyjar=self.server_get("endpoint_context").keyjar - ) + return verify_signed_jwt(token=token, keyjar=self.server_get("endpoint_context").keyjar) def done(self, areq): """ @@ -138,7 +137,7 @@ def __init__( template="user_pass.jinja2", server_get=None, verify_endpoint="", - **kwargs + **kwargs, ): super(UserPassJinja2, self).__init__(server_get=server_get) @@ -170,9 +169,7 @@ def __call__(self, **kwargs): OnlyForTestingWarning, ) if not self.server_get: - raise Exception( - f"{self.__class__.__name__} doesn't have a working server_get" - ) + raise Exception(f"{self.__class__.__name__} doesn't have a working server_get") _context = self.server_get("endpoint_context") # Stores information need afterwards in a signed JWT that then # appears as a hidden input in the form @@ -188,9 +185,7 @@ def __call__(self, **kwargs): _label = "{}_label".format(attr) _kwargs[_label] = LABELS[_uri] - return self.template_handler.render( - self.template, action=self.action, token=jws, **_kwargs - ) + return self.template_handler.render(self.template, action=self.action, token=jws, **_kwargs) def verify(self, *args, **kwargs): username = kwargs["username"] @@ -254,7 +249,7 @@ def authenticated_as(self, client_id, cookie=None, authorization="", **kwargs): try: aesgcm = AESGCM(self.symkey) user = aesgcm.decrypt(iv, encmsg, None) - except (AssertionError, KeyError): # pragma: no-cover + except (AssertionError, KeyError): # pragma: no-cover raise FailedAuthentication("Decryption failed") res = {"uid": user} @@ -280,7 +275,7 @@ def authenticated_as(self, client_id="", cookie=None, authorization="", **kwargs :param kwargs: extra key word arguments :return: """ - if self.fail: # pragma: no-cover + if self.fail: # pragma: no-cover raise self.fail() res = {"uid": self.user} diff --git a/src/oidcop/util.py b/src/oidcop/util.py index ac47bd8e..613de27c 100755 --- a/src/oidcop/util.py +++ b/src/oidcop/util.py @@ -142,9 +142,7 @@ def lv_unpack(txt): class Crypt(object): def __init__(self, password, mode=None): - self.key = base64.urlsafe_b64encode( - hashlib.sha256(password.encode("utf-8")).digest() - ) + self.key = base64.urlsafe_b64encode(hashlib.sha256(password.encode("utf-8")).digest()) self.core = Fernet(self.key) def encrypt(self, text): @@ -200,9 +198,7 @@ def split_uri(uri): def allow_refresh_token(endpoint_context): # Are there a refresh_token handler - refresh_token_handler = endpoint_context.session_manager.token_handler.handler[ - "refresh_token" - ] + refresh_token_handler = endpoint_context.session_manager.token_handler.handler["refresh_token"] # Is refresh_token grant type supported _token_supported = False @@ -254,4 +250,4 @@ def get_logout_id(endpoint_context, user_id, client_id): _mngr = endpoint_context.session_manager _mngr.set([logout_session_id], _item) - return logout_session_id \ No newline at end of file + return logout_session_id diff --git a/src/oidcop/utils.py b/src/oidcop/utils.py index f292b497..1e30dd9a 100644 --- a/src/oidcop/utils.py +++ b/src/oidcop/utils.py @@ -7,7 +7,7 @@ import yaml -def load_json(file_name): # pragma: no cover +def load_json(file_name): # pragma: no cover with open(file_name) as fp: js = json.load(fp) return js @@ -19,7 +19,7 @@ def load_yaml_config(file_name): return c -def yaml_to_py_stream(file_name): # pragma: no cover +def yaml_to_py_stream(file_name): # pragma: no cover d = load_yaml_config(file_name) fstream = io.StringIO() for i in d: @@ -29,14 +29,14 @@ def yaml_to_py_stream(file_name): # pragma: no cover return fstream -def lower_or_upper(config, param, default=None): # pragma: no cover +def lower_or_upper(config, param, default=None): # pragma: no cover res = config.get(param.lower(), default) if not res: res = config.get(param.upper(), default) return res -def create_context(dir_path, config, **kwargs): # pragma: no cover +def create_context(dir_path, config, **kwargs): # pragma: no cover _fname = lower_or_upper(config, "server_cert") if _fname: if _fname.startswith("/"): diff --git a/tests/op_config_defaults.py b/tests/op_config_defaults.py index d148d3ae..429b42ec 100644 --- a/tests/op_config_defaults.py +++ b/tests/op_config_defaults.py @@ -6,10 +6,7 @@ "kwargs": { "verify_endpoint": "verify/user", "template": "user_pass.jinja2", - "db": { - "class": "oidcop.util.JSONDictDB", - "kwargs": {"filename": "passwd.json"}, - }, + "db": {"class": "oidcop.util.JSONDictDB", "kwargs": {"filename": "passwd.json"},}, "page_header": "Testing log in", "submit_btn": "Get me in!", "user_label": "Nickname", @@ -40,10 +37,7 @@ "registration": { "path": "registration", "class": "oidcop.oidc.registration.Registration", - "kwargs": { - "client_authn_method": None, - "client_secret_expiration_time": 432000, - }, + "kwargs": {"client_authn_method": None, "client_secret_expiration_time": 432000,}, }, "registration_api": { "path": "registration_api", @@ -53,10 +47,7 @@ "introspection": { "path": "introspection", "class": "oidcop.oauth2.introspection.Introspection", - "kwargs": { - "client_authn_method": ["client_secret_post"], - "release": ["username"], - }, + "kwargs": {"client_authn_method": ["client_secret_post"], "release": ["username"],}, }, "authorization": { "path": "authorization", @@ -94,9 +85,7 @@ "userinfo": { "path": "userinfo", "class": "oidcop.oidc.userinfo.UserInfo", - "kwargs": { - "claim_types_supported": ["normal", "aggregated", "distributed"] - }, + "kwargs": {"claim_types_supported": ["normal", "aggregated", "distributed"]}, }, "end_session": { "path": "session", @@ -126,9 +115,7 @@ "login_hint2acrs": { "class": "oidcop.login_hint.LoginHint2Acrs", "kwargs": { - "scheme_map": { - "email": ["oidcop.user_authn.authn_context.INTERNETPROTOCOLPASSWORD"] - } + "scheme_map": {"email": ["oidcop.user_authn.authn_context.INTERNETPROTOCOLPASSWORD"]} }, }, "token_handler_args": { @@ -145,20 +132,12 @@ "class": "oidcop.token.jwt_token.JWTToken", "kwargs": { "lifetime": 3600, - "add_claims": [ - "email", - "email_verified", - "phone_number", - "phone_number_verified", - ], + "add_claims": ["email", "email_verified", "phone_number", "phone_number_verified",], "add_claim_by_scope": True, "aud": ["https://example.org/appl"], }, }, "refresh": {"kwargs": {"lifetime": 86400}}, }, - "userinfo": { - "class": "oidcop.user_info.UserInfo", - "kwargs": {"filename": "users.json"}, - }, + "userinfo": {"class": "oidcop.user_info.UserInfo", "kwargs": {"filename": "users.json"},}, } diff --git a/tests/test_00_configure.py b/tests/test_00_configure.py index 1b130ee0..d74540af 100644 --- a/tests/test_00_configure.py +++ b/tests/test_00_configure.py @@ -19,9 +19,7 @@ def test_op_configure(): _str = open(full_path("op_config.json")).read() _conf = json.loads(_str) - configuration = OPConfiguration( - conf=_conf, base_path=BASEDIR, domain="127.0.0.1", port=443 - ) + configuration = OPConfiguration(conf=_conf, base_path=BASEDIR, domain="127.0.0.1", port=443) assert configuration assert "add_on" in configuration authz_conf = configuration["authz"] @@ -68,9 +66,7 @@ def test_op_configure_default(): _str = open(full_path("op_config.json")).read() _conf = json.loads(_str) - configuration = OPConfiguration( - conf=_conf, base_path=BASEDIR, domain="127.0.0.1", port=443 - ) + configuration = OPConfiguration(conf=_conf, base_path=BASEDIR, domain="127.0.0.1", port=443) assert configuration assert "add_on" in configuration authz = configuration["authz"] @@ -78,10 +74,7 @@ def test_op_configure_default(): id_token_conf = configuration.get("id_token", {}) assert set(id_token_conf.keys()) == {"kwargs", "class"} assert id_token_conf["kwargs"] == { - "base_claims": { - "email": {"essential": True}, - "email_verified": {"essential": True}, - } + "base_claims": {"email": {"essential": True}, "email_verified": {"essential": True},} } @@ -100,19 +93,14 @@ def test_op_configure_default_from_file(): id_token_conf = configuration.get("id_token", {}) assert set(id_token_conf.keys()) == {"kwargs", "class"} assert id_token_conf["kwargs"] == { - "base_claims": { - "email": {"essential": True}, - "email_verified": {"essential": True}, - } + "base_claims": {"email": {"essential": True}, "email_verified": {"essential": True},} } def test_server_configure(): configuration = create_from_config_file( Configuration, - entity_conf=[ - {"class": OPConfiguration, "attr": "op", "path": ["op", "server_info"]} - ], + entity_conf=[{"class": OPConfiguration, "attr": "op", "path": ["op", "server_info"]}], filename=full_path("srv_config.yaml"), base_path=BASEDIR, ) @@ -156,9 +144,7 @@ def test_loggin_conf_default(): "formatter": "default", }, }, - "formatters": { - "default": {"format": "%(asctime)s %(name)s %(levelname)s %(message)s"} - }, + "formatters": {"default": {"format": "%(asctime)s %(name)s %(levelname)s %(message)s"}}, } diff --git a/tests/test_00_server.py b/tests/test_00_server.py index cb3e855b..d6bd726c 100755 --- a/tests/test_00_server.py +++ b/tests/test_00_server.py @@ -44,16 +44,8 @@ def full_path(local_file): "class": ProviderConfiguration, "kwargs": {}, }, - "registration_endpoint": { - "path": "registration", - "class": Registration, - "kwargs": {}, - }, - "authorization_endpoint": { - "path": "authorization", - "class": Authorization, - "kwargs": {}, - }, + "registration_endpoint": {"path": "registration", "class": Registration, "kwargs": {},}, + "authorization_endpoint": {"path": "authorization", "class": Authorization, "kwargs": {},}, "token_endpoint": {"path": "token", "class": Token, "kwargs": {}}, "userinfo_endpoint": { "path": "userinfo", @@ -69,6 +61,7 @@ def full_path(local_file): "kwargs": {"user": "diana"}, } }, + "claims_interface": {"class": "oidcop.session.claims.ClaimsInterface", "kwargs": {}}, "add_on": {"pkce": {"function": add_pkce_support, "kwargs": {"essential": True}}}, "template_dir": "template", "login_hint_lookup": {"class": oidcop.login_hint.LoginHintLookup, "kwargs": {}}, @@ -105,9 +98,7 @@ def test_capabilities_default(): _str = open(full_path("op_config.json")).read() _conf = json.loads(_str) - configuration = OPConfiguration( - conf=_conf, base_path=BASEDIR, domain="127.0.0.1", port=443 - ) + configuration = OPConfiguration(conf=_conf, base_path=BASEDIR, domain="127.0.0.1", port=443) server = Server(configuration) assert set(server.endpoint_context.provider_info["response_types_supported"]) == { @@ -119,9 +110,7 @@ def test_capabilities_default(): "id_token token", "code id_token token", } - assert ( - server.endpoint_context.provider_info["request_uri_parameter_supported"] is True - ) + assert server.endpoint_context.provider_info["request_uri_parameter_supported"] is True def test_capabilities_subset1(): @@ -145,10 +134,7 @@ def test_capabilities_bool(): _cnf = copy(conf) _cnf["capabilities"] = {"request_uri_parameter_supported": False} server = Server(_cnf) - assert ( - server.endpoint_context.provider_info["request_uri_parameter_supported"] - is False - ) + assert server.endpoint_context.provider_info["request_uri_parameter_supported"] is False def test_cdb(): diff --git a/tests/test_01_claims.py b/tests/test_01_claims.py index 0a2ab77e..fa5f8dae 100644 --- a/tests/test_01_claims.py +++ b/tests/test_01_claims.py @@ -61,11 +61,7 @@ def full_path(local_file): "class": "oidcop.oidc.authorization.Authorization", "kwargs": {}, }, - "token_endpoint": { - "path": "token", - "class": "oidcop.oidc.token.Token", - "kwargs": {}, - }, + "token_endpoint": {"path": "token", "class": "oidcop.oidc.token.Token", "kwargs": {},}, "userinfo_endpoint": { "path": "userinfo", "class": "oidcop.oidc.userinfo.UserInfo", @@ -102,6 +98,7 @@ def full_path(local_file): "class": "oidcop.user_info.UserInfo", "kwargs": {"db_file": full_path("users.json")}, }, + "claims_interface": {"class": "oidcop.session.claims.ClaimsInterface", "kwargs": {}}, } USER_ID = "diana" @@ -143,46 +140,34 @@ def _create_session(self, auth_req, sub_type="public", sector_identifier=""): def test_authorization_request_id_token_claims(self): session_id = self._create_session(AREQ) - claims = self.claims_interface.authorization_request_claims( - session_id, "id_token" - ) + claims = self.claims_interface.authorization_request_claims(session_id, "id_token") assert claims == {} def test_authorization_request_id_token_claims_2(self): session_id = self._create_session(AREQ_2) - claims = self.claims_interface.authorization_request_claims( - session_id, "id_token" - ) + claims = self.claims_interface.authorization_request_claims(session_id, "id_token") assert claims assert set(claims.keys()) == {"nickname"} def test_authorization_request_userinfo_claims(self): session_id = self._create_session(AREQ) - claims = self.claims_interface.authorization_request_claims( - session_id, "userinfo" - ) + claims = self.claims_interface.authorization_request_claims(session_id, "userinfo") assert claims == {} def test_authorization_request_userinfo_claims_2(self): session_id = self._create_session(AREQ_2) - claims = self.claims_interface.authorization_request_claims( - session_id, "userinfo" - ) + claims = self.claims_interface.authorization_request_claims(session_id, "userinfo") assert claims == {} def test_authorization_request_userinfo_claims_3(self): session_id = self._create_session(AREQ_3) - claims = self.claims_interface.authorization_request_claims( - session_id, "userinfo" - ) + claims = self.claims_interface.authorization_request_claims(session_id, "userinfo") assert set(claims.keys()) == {"name", "email", "email_verified"} - @pytest.mark.parametrize( - "usage", ["id_token", "userinfo", "introspection", "token"] - ) + @pytest.mark.parametrize("usage", ["id_token", "userinfo", "introspection", "token"]) def test_get_client_claims_0(self, usage): claims = self.claims_interface._get_client_claims("client_1", usage) assert claims == {} @@ -202,9 +187,7 @@ def test_get_client_claims_introspection_1(self): claims = self.claims_interface._get_client_claims("client_1", "introspection") assert set(claims.keys()) == {"email"} - @pytest.mark.parametrize( - "usage", ["id_token", "userinfo", "introspection", "token"] - ) + @pytest.mark.parametrize("usage", ["id_token", "userinfo", "introspection", "token"]) def test_get_claims(self, usage): session_id = self._create_session(AREQ) claims = self.claims_interface.get_claims(session_id, [], usage) @@ -238,9 +221,7 @@ def test_get_claims_id_token_3(self): } self.endpoint_context.cdb["client_1"]["id_token_claims"] = ["name", "email"] - claims = self.claims_interface.get_claims( - session_id, ["openid", "address"], "id_token" - ) + claims = self.claims_interface.get_claims(session_id, ["openid", "address"], "id_token") assert set(claims.keys()) == { "name", "email", @@ -259,9 +240,7 @@ def test_get_claims_userinfo_3(self): } self.endpoint_context.cdb["client_1"]["userinfo_claims"] = ["name", "email"] - claims = self.claims_interface.get_claims( - session_id, ["openid", "address"], "userinfo" - ) + claims = self.claims_interface.get_claims(session_id, ["openid", "address"], "userinfo") assert set(claims.keys()) == { "name", "email", @@ -304,9 +283,7 @@ def test_get_claims_access_token_3(self): self.endpoint_context.cdb["client_1"]["access_token_claims"] = ["name", "email"] session_id = self._create_session(AREQ) - claims = self.claims_interface.get_claims( - session_id, ["openid", "address"], "access_token" - ) + claims = self.claims_interface.get_claims(session_id, ["openid", "address"], "access_token") assert set(claims.keys()) == { "name", "email", @@ -324,9 +301,7 @@ def test_get_claims_all_usage(self): self.server.server_get("endpoint", "introspection").kwargs = {} session_id = self._create_session(AREQ) - claims = self.claims_interface.get_claims_all_usage( - session_id, ["openid", "address"] - ) + claims = self.claims_interface.get_claims_all_usage(session_id, ["openid", "address"]) assert set(claims.keys()) == { "id_token", "userinfo", @@ -347,16 +322,12 @@ def test_get_claims_all_usage_2(self): } self.endpoint_context.cdb["client_1"]["userinfo_claims"] = ["name", "email"] - self.server.server_get("endpoint", "introspection").kwargs = { - "add_claims_by_scope": True - } + self.server.server_get("endpoint", "introspection").kwargs = {"add_claims_by_scope": True} self.endpoint_context.session_manager.token_handler["access_token"].kwargs = {} session_id = self._create_session(AREQ) - claims = self.claims_interface.get_claims_all_usage( - session_id, ["openid", "address"] - ) + claims = self.claims_interface.get_claims_all_usage(session_id, ["openid", "address"]) assert set(claims.keys()) == { "id_token", @@ -380,9 +351,7 @@ def test_get_user_claims(self): } self.endpoint_context.cdb["client_1"]["userinfo_claims"] = ["name", "email"] - self.server.server_get("endpoint", "introspection").kwargs = { - "add_claims_by_scope": True - } + self.server.server_get("endpoint", "introspection").kwargs = {"add_claims_by_scope": True} self.endpoint_context.session_manager.token_handler["access_token"].kwargs = {} @@ -391,14 +360,10 @@ def test_get_user_claims(self): session_id, ["openid", "address"] ) - _claims = self.claims_interface.get_user_claims( - USER_ID, claims_restriction["userinfo"] - ) + _claims = self.claims_interface.get_user_claims(USER_ID, claims_restriction["userinfo"]) assert _claims == {"name": "Diana Krall", "email": "diana@example.org"} - _claims = self.claims_interface.get_user_claims( - USER_ID, claims_restriction["id_token"] - ) + _claims = self.claims_interface.get_user_claims(USER_ID, claims_restriction["id_token"]) assert _claims == {"email_verified": False, "email": "diana@example.org"} _claims = self.claims_interface.get_user_claims( @@ -414,7 +379,5 @@ def test_get_user_claims(self): } } - _claims = self.claims_interface.get_user_claims( - USER_ID, claims_restriction["access_token"] - ) + _claims = self.claims_interface.get_user_claims(USER_ID, claims_restriction["access_token"]) assert _claims == {} diff --git a/tests/test_01_grant.py b/tests/test_01_grant.py index edb91e01..84302480 100644 --- a/tests/test_01_grant.py +++ b/tests/test_01_grant.py @@ -30,11 +30,7 @@ "class": "oidcop.oidc.authorization.Authorization", "kwargs": {}, }, - "token_endpoint": { - "path": "token", - "class": "oidcop.oidc.token.Token", - "kwargs": {}, - }, + "token_endpoint": {"path": "token", "class": "oidcop.oidc.token.Token", "kwargs": {},}, }, "authentication": { "anon": { @@ -43,6 +39,7 @@ "kwargs": {"user": "diana"}, } }, + "claims_interface": {"class": "oidcop.session.claims.ClaimsInterface", "kwargs": {}}, } USER_ID = "diana" @@ -420,9 +417,7 @@ def test_get_spec(self): spec = grant.get_spec(access_token) assert set(spec.keys()) == {"scope", "claims", "resources"} assert spec["scope"] == ["openid", "email", "eduperson"] - assert spec["claims"] == { - "userinfo": {"given_name": None, "eduperson_affiliation": None} - } + assert spec["claims"] == {"userinfo": {"given_name": None, "eduperson_affiliation": None}} assert spec["resources"] == ["https://api.example.com"] def test_get_usage_rules(self): @@ -438,9 +433,7 @@ def test_get_usage_rules(self): # Default usage rules self.endpoint_context.cdb["client_id"] = {} - rules = get_usage_rules( - "access_token", self.endpoint_context, grant, "client_id" - ) + rules = get_usage_rules("access_token", self.endpoint_context, grant, "client_id") assert rules == {"supports_minting": [], "expires_in": 3600} # client specific usage rules diff --git a/tests/test_01_session_token.py b/tests/test_01_session_token.py index 34b0510d..622e4b6d 100644 --- a/tests/test_01_session_token.py +++ b/tests/test_01_session_token.py @@ -37,14 +37,15 @@ def test_authorization_code_extras(): assert code.resources == ["https://api.example.com"] -def test_dump_load(cls=AuthorizationCode, - kwargs=dict( - value="ABCD", - scope=["openid", "foo", "bar"], - claims={"userinfo": {"given_name": None}}, - resources=["https://api.example.com"], - ) - ): +def test_dump_load( + cls=AuthorizationCode, + kwargs=dict( + value="ABCD", + scope=["openid", "foo", "bar"], + claims={"userinfo": {"given_name": None}}, + resources=["https://api.example.com"], + ), +): code = cls(**kwargs) _item = code.dump() @@ -54,17 +55,13 @@ def test_dump_load(cls=AuthorizationCode, if val: assert val == getattr(_new_code, attr) + def test_dump_load_access_token(): - test_dump_load( - cls=AccessToken, - kwargs={} - ) + test_dump_load(cls=AccessToken, kwargs={}) + def test_dump_load_idtoken(): - test_dump_load( - cls=IDToken, - kwargs={} - ) + test_dump_load(cls=IDToken, kwargs={}) def test_supports_minting(): diff --git a/tests/test_01_util.py b/tests/test_01_util.py index fa246104..ea8bab84 100644 --- a/tests/test_01_util.py +++ b/tests/test_01_util.py @@ -19,27 +19,15 @@ "verify_ssl": False, "capabilities": {}, "jwks_uri": "https://example.com/jwks.json", - "keys": { - "private_path": "own/jwks.json", - "key_defs": KEYDEFS, - "uri_path": "static/jwks.json", - }, + "keys": {"private_path": "own/jwks.json", "key_defs": KEYDEFS, "uri_path": "static/jwks.json",}, "endpoint": { "provider_config": { "path": ".well-known/openid-configuration", "class": ProviderConfiguration, "kwargs": {}, }, - "registration_endpoint": { - "path": "registration", - "class": Registration, - "kwargs": {}, - }, - "authorization_endpoint": { - "path": "authorization", - "class": Authorization, - "kwargs": {}, - }, + "registration_endpoint": {"path": "registration", "class": Registration, "kwargs": {},}, + "authorization_endpoint": {"path": "authorization", "class": Authorization, "kwargs": {},}, "token_endpoint": {"path": "token", "class": Token, "kwargs": {}}, "userinfo_endpoint": { "path": "userinfo", diff --git a/tests/test_02_authz_handling.py b/tests/test_02_authz_handling.py index 8b7bc619..c3183556 100644 --- a/tests/test_02_authz_handling.py +++ b/tests/test_02_authz_handling.py @@ -52,17 +52,11 @@ "grant_config": { "usage_rules": { "authorization_code": { - "supports_minting": [ - "access_token", - "refresh_token", - "id_token", - ], + "supports_minting": ["access_token", "refresh_token", "id_token",], "max_usage": 1, }, "access_token": {}, - "refresh_token": { - "supports_minting": ["access_token", "refresh_token"] - }, + "refresh_token": {"supports_minting": ["access_token", "refresh_token"]}, }, "expires_in": 43200, } @@ -74,11 +68,7 @@ "class": "oidcop.oidc.authorization.Authorization", "kwargs": {}, }, - "token_endpoint": { - "path": "token", - "class": "oidcop.oidc.token.Token", - "kwargs": {}, - }, + "token_endpoint": {"path": "token", "class": "oidcop.oidc.token.Token", "kwargs": {},}, "userinfo_endpoint": { "path": "userinfo", "class": "oidcop.oidc.userinfo.UserInfo", @@ -111,6 +101,7 @@ }, "id_token": {"class": "oidcop.token.id_token.IDToken", "kwargs": {}}, }, + "claims_interface": {"class": "oidcop.session.claims.ClaimsInterface", "kwargs": {}}, } USER_ID = "diana" diff --git a/tests/test_02_client_authn.py b/tests/test_02_client_authn.py index a3942a58..fae199d7 100755 --- a/tests/test_02_client_authn.py +++ b/tests/test_02_client_authn.py @@ -65,11 +65,8 @@ }, }, "template_dir": "template", - "keys": { - "private_path": "own/jwks.json", - "key_defs": KEYDEFS, - "uri_path": "static/jwks.json", - }, + "keys": {"private_path": "own/jwks.json", "key_defs": KEYDEFS, "uri_path": "static/jwks.json",}, + "claims_interface": {"class": "oidcop.session.claims.ClaimsInterface", "kwargs": {}}, } client_id = "client_id" @@ -205,17 +202,13 @@ def test_private_key_jwt_reusage_other_endpoint(self): _jwt = JWT(client_keyjar, iss=client_id, sign_alg="RS256") _jwt.with_jti = True - _assertion = _jwt.pack( - {"aud": [self.method.server_get("endpoint", "token").full_path]} - ) + _assertion = _jwt.pack({"aud": [self.method.server_get("endpoint", "token").full_path]}) request = {"client_assertion": _assertion, "client_assertion_type": JWT_BEARER} # This should be OK assert self.method.is_usable(request=request) - self.method.verify( - request=request, endpoint=self.method.server_get("endpoint", "token") - ) + self.method.verify(request=request, endpoint=self.method.server_get("endpoint", "token")) # This should NOT be OK with pytest.raises(NotForMe): @@ -225,9 +218,7 @@ def test_private_key_jwt_reusage_other_endpoint(self): # This should NOT be OK because this is the second time the token appears with pytest.raises(MultipleUsage): - self.method.verify( - request, endpoint=self.method.server_get("endpoint", "token") - ) + self.method.verify(request, endpoint=self.method.server_get("endpoint", "token")) def test_private_key_jwt_auth_endpoint(self): # Own dynamic keys @@ -248,8 +239,7 @@ def test_private_key_jwt_auth_endpoint(self): assert self.method.is_usable(request=request) authn_info = self.method.verify( - request=request, - endpoint=self.method.server_get("endpoint", "authorization"), + request=request, endpoint=self.method.server_get("endpoint", "authorization"), ) assert authn_info["client_id"] == client_id @@ -265,9 +255,7 @@ def create_method(self): def test_bearerheader(self): authorization_info = "Bearer 1234567890" - assert self.method.verify(authorization_token=authorization_info) == { - "token": "1234567890" - } + assert self.method.verify(authorization_token=authorization_info) == {"token": "1234567890"} def test_bearerheader_wrong_type(self): authorization_info = "Thrower 1234567890" @@ -461,9 +449,7 @@ def test_verify_client_bearer_body(self): def test_verify_client_client_secret_post(self): request = {"client_id": client_id, "client_secret": client_secret} res = verify_client( - self.endpoint_context, - request, - endpoint=self.server.server_get("endpoint", "token"), + self.endpoint_context, request, endpoint=self.server.server_get("endpoint", "token"), ) assert set(res.keys()) == {"method", "client_id"} assert res["method"] == "client_secret_post" @@ -522,9 +508,7 @@ def test_verify_client_jws_authn_method(self): request = {"client_assertion": _assertion, "client_assertion_type": JWT_BEARER} res = verify_client( - self.endpoint_context, - request, - endpoint=self.server.server_get("endpoint", "token"), + self.endpoint_context, request, endpoint=self.server.server_get("endpoint", "token"), ) assert res["method"] == "client_secret_jwt" assert res["client_id"] == "client_id" @@ -544,9 +528,7 @@ def test_verify_client_bearer_body(self): def test_verify_client_client_secret_post(self): request = {"client_id": client_id, "client_secret": client_secret} res = verify_client( - self.endpoint_context, - request, - endpoint=self.server.server_get("endpoint", "token"), + self.endpoint_context, request, endpoint=self.server.server_get("endpoint", "token"), ) assert set(res.keys()) == {"method", "client_id"} assert res["method"] == "client_secret_post" diff --git a/tests/test_02_sess_mngm_db.py b/tests/test_02_sess_mngm_db.py index ce492edc..d5c3f0c3 100644 --- a/tests/test_02_sess_mngm_db.py +++ b/tests/test_02_sess_mngm_db.py @@ -88,9 +88,7 @@ def test_client_info_add2(self): # The reference is there but not the value del self.db.db[self.db.session_key("diana", "client_1")] - authn_event = create_authn_event( - uid="diana", expires_in=10, authn_info="authn_class_ref" - ) + authn_event = create_authn_event(uid="diana", expires_in=10, authn_info="authn_class_ref") grant = Grant(authentication_event=authn_event, authorization_request=AUTHZ_REQ) @@ -183,9 +181,7 @@ def test_step_wise(self): # store user info self.db.set(["diana"], UserSessionInfo(user_id="diana")) # Client specific information - self.db.set( - ["diana", "client_1"], ClientSessionInfo(sub=public_id("diana", salt)) - ) + self.db.set(["diana", "client_1"], ClientSessionInfo(sub=public_id("diana", salt))) # Grant grant = Grant() access_code = SessionToken("access_code", value="1234567890") diff --git a/tests/test_04_token_handler.py b/tests/test_04_token_handler.py index a1e802a6..3e263a66 100644 --- a/tests/test_04_token_handler.py +++ b/tests/test_04_token_handler.py @@ -116,12 +116,8 @@ def setup_token_handler(self): refresh_token_expires_in = 86400 code_handler = DefaultToken(password, typ="A", lifetime=grant_expires_in) - access_token_handler = DefaultToken( - password, typ="T", lifetime=token_expires_in - ) - refresh_token_handler = DefaultToken( - password, typ="R", lifetime=refresh_token_expires_in - ) + access_token_handler = DefaultToken(password, typ="T", lifetime=token_expires_in) + refresh_token_handler = DefaultToken(password, typ="R", lifetime=refresh_token_expires_in) self.handler = TokenHandler( code_handler=code_handler, @@ -177,16 +173,12 @@ def test_token_handler_from_config(): conf = { "issuer": "https://example.com/op", "keys": {"uri_path": "static/jwks.json", "key_defs": KEYDEFS}, - "endpoint": { - "endpoint": {"path": "endpoint", "class": Endpoint, "kwargs": {}}, - }, + "endpoint": {"endpoint": {"path": "endpoint", "class": Endpoint, "kwargs": {}},}, "token_handler_args": { "jwks_def": { "private_path": "private/token_jwks.json", "read_only": False, - "key_defs": [ - {"type": "oct", "bytes": "24", "use": ["enc"], "kid": "code"} - ], + "key_defs": [{"type": "oct", "bytes": "24", "use": ["enc"], "kid": "code"}], }, "code": {"kwargs": {"lifetime": 600}}, "token": { @@ -199,7 +191,7 @@ def test_token_handler_from_config(): }, "refresh": { "class": "oidcop.token.jwt_token.JWTToken", - "kwargs": {"lifetime": 3600, "aud": ["https://example.org/appl"], }, + "kwargs": {"lifetime": 3600, "aud": ["https://example.org/appl"],}, }, "id_token": { "class": "oidcop.token.id_token.IDToken", @@ -238,42 +230,39 @@ def test_token_handler_from_config(): assert token_handler.handler["refresh_token"].alg == "ES256" assert token_handler.handler["refresh_token"].kwargs == {} assert token_handler.handler["refresh_token"].lifetime == 3600 - assert token_handler.handler["refresh_token"].def_aud == [ - "https://example.org/appl" - ] + assert token_handler.handler["refresh_token"].def_aud == ["https://example.org/appl"] assert token_handler.handler["id_token"].lifetime == 300 assert "base_claims" in token_handler.handler["id_token"].kwargs -@pytest.mark.parametrize("jwks", [ - {"jwks_file": "private/token_jwks_1.json"}, - {"jwks_def": { - "private_path": "private/token_jwks_2.json", - "read_only": False, - "key_defs": [ - {"type": "oct", "bytes": "24", "use": ["enc"], "kid": "code"} - ], - }}, - { - "jwks_file": "private/token_jwks_1.json", - "jwks_def": { - "private_path": "private/token_jwks_2.json", - "read_only": False, - "key_defs": [ - {"type": "oct", "bytes": "24", "use": ["enc"], "kid": "code"} - ], - } - }, - None -]) +@pytest.mark.parametrize( + "jwks", + [ + {"jwks_file": "private/token_jwks_1.json"}, + { + "jwks_def": { + "private_path": "private/token_jwks_2.json", + "read_only": False, + "key_defs": [{"type": "oct", "bytes": "24", "use": ["enc"], "kid": "code"}], + } + }, + { + "jwks_file": "private/token_jwks_1.json", + "jwks_def": { + "private_path": "private/token_jwks_2.json", + "read_only": False, + "key_defs": [{"type": "oct", "bytes": "24", "use": ["enc"], "kid": "code"}], + }, + }, + None, + ], +) def test_file(jwks): conf = { "issuer": "https://example.com/op", "keys": {"uri_path": "static/jwks.json", "key_defs": KEYDEFS}, - "endpoint": { - "endpoint": {"path": "endpoint", "class": Endpoint, "kwargs": {}}, - }, + "endpoint": {"endpoint": {"path": "endpoint", "class": Endpoint, "kwargs": {}},}, "token_handler_args": { "code": {"kwargs": {"lifetime": 600}}, "token": { @@ -286,7 +275,7 @@ def test_file(jwks): }, "refresh": { "class": "oidcop.token.jwt_token.JWTToken", - "kwargs": {"lifetime": 3600, "aud": ["https://example.org/appl"], }, + "kwargs": {"lifetime": 3600, "aud": ["https://example.org/appl"],}, }, "id_token": { "class": "oidcop.token.id_token.IDToken", @@ -305,8 +294,11 @@ def test_file(jwks): except KeyError: pass - for _file in ["private/token_jwks_1.json", "private/token_jwks_2.json", - "private/token_jwks.json"]: + for _file in [ + "private/token_jwks_1.json", + "private/token_jwks_2.json", + "private/token_jwks.json", + ]: if os.path.exists(full_path(_file)): os.unlink(full_path(_file)) diff --git a/tests/test_05_id_token.py b/tests/test_05_id_token.py index 76ba827d..8696f76c 100644 --- a/tests/test_05_id_token.py +++ b/tests/test_05_id_token.py @@ -123,9 +123,7 @@ def full_path(local_file): "max_usage": 1, }, "access_token": {}, - "refresh_token": { - "supports_minting": ["access_token", "refresh_token"] - }, + "refresh_token": {"supports_minting": ["access_token", "refresh_token"]}, }, "expires_in": 43200, } @@ -136,17 +134,11 @@ def full_path(local_file): "grant_config": { "usage_rules": { "authorization_code": { - "supports_minting": [ - "access_token", - "refresh_token", - "id_token", - ], + "supports_minting": ["access_token", "refresh_token", "id_token",], "max_usage": 1, }, "access_token": {}, - "refresh_token": { - "supports_minting": ["access_token", "refresh_token"] - }, + "refresh_token": {"supports_minting": ["access_token", "refresh_token"]}, }, "expires_in": 43200, } @@ -154,6 +146,7 @@ def full_path(local_file): }, "userinfo": {"class": "oidcop.user_info.UserInfo", "kwargs": {"db": USERS},}, "client_authn": verify_client, + "claims_interface": {"class": "oidcop.session.claims.ClaimsInterface", "kwargs": {}}, } USER_ID = "diana" @@ -171,9 +164,7 @@ def create_session_manager(self): "token_endpoint_auth_method": "client_secret_post", "response_types": ["code", "token", "code id_token", "id_token"], } - self.endpoint_context.keyjar.add_symmetric( - "client_1", "hemligtochintekort", ["sig", "enc"] - ) + self.endpoint_context.keyjar.add_symmetric("client_1", "hemligtochintekort", ["sig", "enc"]) self.session_manager = self.endpoint_context.session_manager self.user_id = USER_ID @@ -211,9 +202,7 @@ def _mint_access_token(self, grant, session_id, token_ref): ) return access_token - def _mint_id_token( - self, grant, session_id, token_ref=None, code=None, access_token=None - ): + def _mint_id_token(self, grant, session_id, token_ref=None, code=None, access_token=None): return grant.mint_token( session_id=session_id, endpoint_context=self.endpoint_context, @@ -247,9 +236,7 @@ def test_id_token_payload_with_code(self): grant = self.session_manager[session_id] code = self._mint_code(grant, session_id) - id_token = self._mint_id_token( - grant, session_id, token_ref=code, code=code.value - ) + id_token = self._mint_id_token(grant, session_id, token_ref=code, code=code.value) _jwt = factory(id_token.value) payload = _jwt.jwt.payload() @@ -296,11 +283,7 @@ def test_id_token_payload_with_code_and_access_token(self): access_token = self._mint_access_token(grant, session_id, code) id_token = self._mint_id_token( - grant, - session_id, - token_ref=code, - code=code.value, - access_token=access_token.value, + grant, session_id, token_ref=code, code=code.value, access_token=access_token.value, ) _jwt = factory(id_token.value) @@ -345,11 +328,7 @@ def test_id_token_payload_many_0(self): access_token = self._mint_access_token(grant, session_id, code) id_token = self._mint_id_token( - grant, - session_id, - token_ref=code, - code=code.value, - access_token=access_token.value, + grant, session_id, token_ref=code, code=code.value, access_token=access_token.value, ) _jwt = factory(id_token.value) @@ -397,8 +376,11 @@ def test_get_sign_algorithm(self): ) # default signing alg assert algs == { - 'sign': True, 'encrypt': True, 'sign_alg': 'RS256', - 'enc_alg': 'RSA-OAEP', 'enc_enc': 'A128CBC-HS256' + "sign": True, + "encrypt": True, + "sign_alg": "RS256", + "enc_alg": "RSA-OAEP", + "enc_enc": "A128CBC-HS256", } def test_available_claims(self): @@ -433,9 +415,7 @@ def test_client_claims(self): session_id = self._create_session(AREQ) grant = self.session_manager[session_id] - self.session_manager.token_handler["id_token"].kwargs[ - "enable_claims_per_client" - ] = True + self.session_manager.token_handler["id_token"].kwargs["enable_claims_per_client"] = True self.endpoint_context.cdb["client_1"]["id_token_claims"] = {"address": None} _claims = self.endpoint_context.claims_interface.get_claims( @@ -478,9 +458,7 @@ def test_client_claims_scopes(self): session_id = self._create_session(AREQS) grant = self.session_manager[session_id] - self.session_manager.token_handler["id_token"].kwargs[ - "add_claims_by_scope" - ] = True + self.session_manager.token_handler["id_token"].kwargs["add_claims_by_scope"] = True _claims = self.endpoint_context.claims_interface.get_claims( session_id=session_id, scopes=AREQS["scope"], usage="id_token" @@ -502,9 +480,7 @@ def test_client_claims_scopes_and_request_claims_no_match(self): session_id = self._create_session(AREQRC) grant = self.session_manager[session_id] - self.session_manager.token_handler["id_token"].kwargs[ - "add_claims_by_scope" - ] = True + self.session_manager.token_handler["id_token"].kwargs["add_claims_by_scope"] = True _claims = self.endpoint_context.claims_interface.get_claims( session_id=session_id, scopes=AREQRC["scope"], usage="id_token" @@ -531,9 +507,7 @@ def test_client_claims_scopes_and_request_claims_one_match(self): session_id = self._create_session(_req) grant = self.session_manager[session_id] - self.session_manager.token_handler["id_token"].kwargs[ - "add_claims_by_scope" - ] = True + self.session_manager.token_handler["id_token"].kwargs["add_claims_by_scope"] = True _claims = self.endpoint_context.claims_interface.get_claims( session_id=session_id, scopes=_req["scope"], usage="id_token" @@ -552,7 +526,6 @@ def test_client_claims_scopes_and_request_claims_one_match(self): # Scope -> claims assert "address" in res - def test_id_token_info(self): session_id = self._create_session(AREQ) grant = self.session_manager[session_id] @@ -565,14 +538,14 @@ def test_id_token_info(self): endpoint_context = self.endpoint_context sman = endpoint_context.session_manager - server_get = sman.token_handler.handler['id_token'].server_get + server_get = sman.token_handler.handler["id_token"].server_get _info = self.session_manager.token_handler.info(id_token.value) - assert 'sid' in _info - assert 'exp' in _info - assert 'aud' in _info + assert "sid" in _info + assert "exp" in _info + assert "aud" in _info - client_id = AREQ.get('client_id') - _id_token = sman.token_handler.handler['id_token'] + client_id = AREQ.get("client_id") + _id_token = sman.token_handler.handler["id_token"] _id_token.sign_encrypt(session_id, client_id) # TODO: we need an authentication event for this id_token for a better coverage @@ -580,6 +553,5 @@ def test_id_token_info(self): client_info = endpoint_context.cdb[client_id] get_sign_and_encrypt_algorithms( - endpoint_context, client_info, payload_type="id_token", - sign=True, encrypt=True + endpoint_context, client_info, payload_type="id_token", sign=True, encrypt=True ) diff --git a/tests/test_05_jwt_token.py b/tests/test_05_jwt_token.py index fab24e71..40f5cb7f 100644 --- a/tests/test_05_jwt_token.py +++ b/tests/test_05_jwt_token.py @@ -86,7 +86,7 @@ "authorization_code": "code", "access_token": "access_token", "refresh_token": "refresh_token", - "id_token": "id_token" + "id_token": "id_token", } @@ -135,11 +135,7 @@ def create_endpoint(self): "class": ProviderConfiguration, "kwargs": {}, }, - "registration": { - "path": "{}/registration", - "class": Registration, - "kwargs": {}, - }, + "registration": {"path": "{}/registration", "class": Registration, "kwargs": {},}, "authorization": { "path": "{}/authorization", "class": Authorization, @@ -168,11 +164,7 @@ def create_endpoint(self): "grant_config": { "usage_rules": { "authorization_code": { - "supports_minting": [ - "access_token", - "refresh_token", - "id_token", - ], + "supports_minting": ["access_token", "refresh_token", "id_token",], "max_usage": 1, }, "access_token": {}, @@ -184,6 +176,7 @@ def create_endpoint(self): } }, }, + "claims_interface": {"class": "oidcop.session.claims.ClaimsInterface", "kwargs": {}}, } server = Server(conf, keyjar=KEYJAR) self.endpoint_context = server.endpoint_context @@ -255,9 +248,9 @@ def test_info(self): def test_enable_claims_per_client(self, enable_claims_per_client): # Set up configuration self.endpoint_context.cdb["client_1"]["access_token_claims"] = {"address": None} - self.endpoint_context.session_manager.token_handler.handler[ - "access_token" - ].kwargs["enable_claims_per_client"] = enable_claims_per_client + self.endpoint_context.session_manager.token_handler.handler["access_token"].kwargs[ + "enable_claims_per_client" + ] = enable_claims_per_client session_id = self._create_session(AUTH_REQ) # apply consent diff --git a/tests/test_06_authn_context.py b/tests/test_06_authn_context.py index fa022f70..b59e5635 100644 --- a/tests/test_06_authn_context.py +++ b/tests/test_06_authn_context.py @@ -25,11 +25,7 @@ "kwargs": {"user": "diana"}, "class": "oidcop.user_authn.user.NoAuthn", }, - "krall": { - "acr": INTERNETPROTOCOLPASSWORD, - "kwargs": {"user": "krall"}, - "class": NoAuthn, - }, + "krall": {"acr": INTERNETPROTOCOLPASSWORD, "kwargs": {"user": "krall"}, "class": NoAuthn,}, } KEYDEFS = [ @@ -144,6 +140,7 @@ def create_authn_broker(self): "authentication": METHOD, "userinfo": {"class": UserInfo, "kwargs": {"db": USERINFO_db}}, "template_dir": "template", + "claims_interface": {"class": "oidcop.session.claims.ClaimsInterface", "kwargs": {}}, } cookie_conf = { "sign_key": SYMKey(k="ghsNKDDLshZTPn974nOsIGhedULrsqnsGoBFBLwUKuJhE2ch"), @@ -187,9 +184,7 @@ def test_pick_authn_all(self): def test_authn_event(): - an = AuthnEvent( - uid="uid", valid_until=time_sans_frac() + 1, authn_info="authn_class_ref", - ) + an = AuthnEvent(uid="uid", valid_until=time_sans_frac() + 1, authn_info="authn_class_ref",) assert an.is_valid() diff --git a/tests/test_06_session_manager.py b/tests/test_06_session_manager.py index d0d62363..6acb4458 100644 --- a/tests/test_06_session_manager.py +++ b/tests/test_06_session_manager.py @@ -54,9 +54,7 @@ def create_session_manager(self): "jwks_def": { "private_path": "private/token_jwks.json", "read_only": False, - "key_defs": [ - {"type": "oct", "bytes": "24", "use": ["enc"], "kid": "code"} - ], + "key_defs": [{"type": "oct", "bytes": "24", "use": ["enc"], "kid": "code"}], }, "code": {"lifetime": 600}, "token": { @@ -82,6 +80,7 @@ def create_session_manager(self): "token_endpoint": {"path": "{}/token", "class": Token, "kwargs": {}}, }, "template_dir": "template", + "claims_interface": {"class": "oidcop.session.claims.ClaimsInterface", "kwargs": {}}, } server = Server(conf) self.server = server @@ -263,9 +262,7 @@ def test_check_grant(self): scopes=["openid", "phoe"], ) - _session_info = self.session_manager.get_session_info( - session_id=session_id, grant=True - ) + _session_info = self.session_manager.get_session_info(session_id=session_id, grant=True) grant = _session_info["grant"] assert grant.scope == ["openid", "phoe"] @@ -275,10 +272,7 @@ def test_check_grant(self): def test_find_token(self): session_id = self.session_manager.create_session( - authn_event=self.authn_event, - auth_req=AUTH_REQ, - user_id="diana", - client_id="client_1", + authn_event=self.authn_event, auth_req=AUTH_REQ, user_id="diana", client_id="client_1", ) _info = self.session_manager.get_session_info(session_id=session_id, grant=True) @@ -287,9 +281,7 @@ def test_find_token(self): code = self._mint_token("authorization_code", grant, session_id) access_token = self._mint_token("access_token", grant, session_id, code) - _session_id = self.session_manager.encrypted_session_id( - "diana", "client_1", grant.id - ) + _session_id = self.session_manager.encrypted_session_id("diana", "client_1", grant.id) _token = self.session_manager.find_token(_session_id, access_token.value) assert _token.type == "access_token" @@ -304,9 +296,7 @@ def test_get_authentication_event(self): # client_id="client_1", ) - _info = self.session_manager.get_session_info( - session_id, authentication_event=True - ) + _info = self.session_manager.get_session_info(session_id, authentication_event=True) authn_event = _info["authentication_event"] assert isinstance(authn_event, AuthnEvent) @@ -314,16 +304,11 @@ def test_get_authentication_event(self): assert authn_event["authn_info"] == "authn_class_ref" # cover the remaining one ... - _info = self.session_manager.get_session_info( - session_id, authorization_request=True - ) + _info = self.session_manager.get_session_info(session_id, authorization_request=True) def test_get_client_session_info(self): _session_id = self.session_manager.create_session( - authn_event=self.authn_event, - auth_req=AUTH_REQ, - user_id="diana", - client_id="client_1", + authn_event=self.authn_event, auth_req=AUTH_REQ, user_id="diana", client_id="client_1", ) csi = self.session_manager.get_client_session_info(_session_id) @@ -332,10 +317,7 @@ def test_get_client_session_info(self): def test_get_general_session_info(self): _session_id = self.session_manager.create_session( - authn_event=self.authn_event, - auth_req=AUTH_REQ, - user_id="diana", - client_id="client_1", + authn_event=self.authn_event, auth_req=AUTH_REQ, user_id="diana", client_id="client_1", ) _session_info = self.session_manager.get_session_info(_session_id) @@ -351,10 +333,7 @@ def test_get_general_session_info(self): def test_get_session_info_by_token(self): _session_id = self.session_manager.create_session( - authn_event=self.authn_event, - auth_req=AUTH_REQ, - user_id="diana", - client_id="client_1", + authn_event=self.authn_event, auth_req=AUTH_REQ, user_id="diana", client_id="client_1", ) grant = self.session_manager.get_grant(_session_id) @@ -372,10 +351,7 @@ def test_get_session_info_by_token(self): def test_token_usage_default(self): _session_id = self.session_manager.create_session( - authn_event=self.authn_event, - auth_req=AUTH_REQ, - user_id="diana", - client_id="client_1", + authn_event=self.authn_event, auth_req=AUTH_REQ, user_id="diana", client_id="client_1", ) grant = self.session_manager[_session_id] @@ -392,16 +368,11 @@ def test_token_usage_default(self): refresh_token = self._mint_token("refresh_token", grant, _session_id, code) - assert refresh_token.usage_rules == { - "supports_minting": ["access_token", "refresh_token"] - } + assert refresh_token.usage_rules == {"supports_minting": ["access_token", "refresh_token"]} def test_token_usage_grant(self): _session_id = self.session_manager.create_session( - authn_event=self.authn_event, - auth_req=AUTH_REQ, - user_id="diana", - client_id="client_1", + authn_event=self.authn_event, auth_req=AUTH_REQ, user_id="diana", client_id="client_1", ) grant = self.session_manager[_session_id] grant.usage_rules = { @@ -411,9 +382,7 @@ def test_token_usage_grant(self): "expires_in": 300, }, "access_token": {"expires_in": 3600}, - "refresh_token": { - "supports_minting": ["access_token", "refresh_token", "id_token"] - }, + "refresh_token": {"supports_minting": ["access_token", "refresh_token", "id_token"]}, } code = self._mint_token("authorization_code", grant, _session_id) @@ -564,23 +533,16 @@ def test_authentication_events(self): assert isinstance(res[0], AuthnEvent) - res = self.session_manager.get_authentication_events( - user_id= "diana", - client_id="client_1" - ) + res = self.session_manager.get_authentication_events(user_id="diana", client_id="client_1") assert isinstance(res[0], AuthnEvent) try: - self.session_manager.get_authentication_events( - user_id="diana", - ) + self.session_manager.get_authentication_events(user_id="diana",) except AttributeError: pass else: - raise Exception( - "get_authentication_events MUST return a list of AuthnEvent" - ) + raise Exception("get_authentication_events MUST return a list of AuthnEvent") def test_user_info(self): token_usage_rules = self.endpoint_context.authz.usage_rules("client_1") @@ -604,7 +566,6 @@ def test_revoke_client_session(self): ) self.session_manager.revoke_client_session(_session_id) - def test_revoke_grant(self): token_usage_rules = self.endpoint_context.authz.usage_rules("client_1") _session_id = self.session_manager.create_session( @@ -645,27 +606,20 @@ def test_grants(self): assert isinstance(res, list) - res = self.session_manager.grants( - user_id= "diana", - client_id="client_1" - ) + res = self.session_manager.grants(user_id="diana", client_id="client_1") assert isinstance(res, list) try: - self.session_manager.grants( - user_id="diana", - ) + self.session_manager.grants(user_id="diana",) except AttributeError: pass else: - raise Exception( - "get_authentication_events MUST return a list of AuthnEvent" - ) + raise Exception("get_authentication_events MUST return a list of AuthnEvent") # and now cove add_grant grant = self.session_manager[_session_id] grant_kwargs = grant.parameter - for i in ('not_before', 'used'): + for i in ("not_before", "used"): grant_kwargs.pop(i) - self.session_manager.add_grant("diana", "client_1", **grant_kwargs ) + self.session_manager.add_grant("diana", "client_1", **grant_kwargs) diff --git a/tests/test_06_session_manager_pairwise.py b/tests/test_06_session_manager_pairwise.py index 333c230b..c2c377fd 100644 --- a/tests/test_06_session_manager_pairwise.py +++ b/tests/test_06_session_manager_pairwise.py @@ -1,6 +1,4 @@ import os -import pytest - from oidcop.exception import ConfigurationError from oidcop.session.manager import PairWiseID @@ -12,42 +10,39 @@ class TestSessionManagerPairWiseID: def test_paiwise_id(self): # as param - pw = PairWiseID(salt='salt') - pw('diana', 'that-sector') + pw = PairWiseID(salt="salt") + pw("diana", "that-sector") # as file - pw = PairWiseID(filename='salt.txt') - pw('diana', 'that-sector') + pw = PairWiseID(filename="salt.txt") + pw("diana", "that-sector") # prune - os.remove('salt.txt') + os.remove("salt.txt") # again to test if a preexistent file going ot be used ... - pw = PairWiseID(filename='salt.txt') + pw = PairWiseID(filename="salt.txt") try: - pw = PairWiseID(filename='/tmp') + pw = PairWiseID(filename="/tmp") except ConfigurationError: - pass # that's ok + pass # that's ok # as random pw = PairWiseID() - pw('diana', 'that-sector') + pw("diana", "that-sector") self.cleanup() def cleanup(self): - if os.path.isfile('salt.txt'): - os.remove('salt.txt') + if os.path.isfile("salt.txt"): + os.remove("salt.txt") class TestSessionManagerPublicID: - pw = PublicID() - pw('diana', 'that-sector') + pw = PublicID() + pw("diana", "that-sector") class TestSessionManagerConf: - sman = SessionManager( - handler = TokenHandler(), - conf={'password': 'hola!'} - ) + sman = SessionManager(handler=TokenHandler(), conf={"password": "hola!"}) diff --git a/tests/test_07_userinfo.py b/tests/test_07_userinfo.py index 4ab0394e..d163e945 100644 --- a/tests/test_07_userinfo.py +++ b/tests/test_07_userinfo.py @@ -1,6 +1,7 @@ import json import os +from oidcop.configure import OPConfiguration import pytest from oidcmsg.oidc import OpenIDRequest @@ -93,9 +94,7 @@ def test_default_scope2claims(): "email", "email_verified", } - assert set(convert_scopes2claims(["address"], STANDARD_CLAIMS).keys()) == { - "address" - } + assert set(convert_scopes2claims(["address"], STANDARD_CLAIMS).keys()) == {"address"} assert set(convert_scopes2claims(["phone"], STANDARD_CLAIMS).keys()) == { "phone_number", "phone_number_verified", @@ -185,9 +184,7 @@ def create_endpoint_context(self): "jwks_def": { "private_path": "private/token_jwks.json", "read_only": False, - "key_defs": [ - {"type": "oct", "bytes": "24", "use": ["enc"], "kid": "code"} - ], + "key_defs": [{"type": "oct", "bytes": "24", "use": ["enc"], "kid": "code"}], }, "code": {"kwargs": {"lifetime": 600}}, "token": { @@ -216,18 +213,12 @@ def create_endpoint_context(self): "class": ProviderConfiguration, "kwargs": {}, }, - "registration": { - "path": "{}/registration", - "class": Registration, - "kwargs": {}, - }, + "registration": {"path": "{}/registration", "class": Registration, "kwargs": {},}, "authorization": { "path": "{}/authorization", "class": Authorization, "kwargs": { - "response_types_supported": [ - " ".join(x) for x in RESPONSE_TYPES_SUPPORTED - ], + "response_types_supported": [" ".join(x) for x in RESPONSE_TYPES_SUPPORTED], "response_modes_supported": ["query", "fragment", "form_post",], "claims_parameter_supported": True, "request_parameter_supported": True, @@ -238,16 +229,9 @@ def create_endpoint_context(self): "path": "userinfo", "class": userinfo.UserInfo, "kwargs": { - "claim_types_supported": [ - "normal", - "aggregated", - "distributed", - ], + "claim_types_supported": ["normal", "aggregated", "distributed",], "client_authn_method": ["bearer_header"], - "base_claims": { - "eduperson_scoped_affiliation": None, - "email": None, - }, + "base_claims": {"eduperson_scoped_affiliation": None, "email": None,}, "add_claims_by_scope": True, "enable_claims_per_client": True, }, @@ -268,7 +252,7 @@ def create_endpoint_context(self): "template_dir": "template", } - server = Server(conf) + server = Server(OPConfiguration(conf=conf, base_path=BASEDIR), cwd=BASEDIR) self.endpoint_context = server.endpoint_context # Just has to be there self.endpoint_context.cdb["client1"] = {} @@ -387,9 +371,7 @@ def test_collect_user_info_enable_claims_per_client(self): _userinfo_endpoint.kwargs["enable_claims_per_client"] = True del _userinfo_endpoint.kwargs["base_claims"] - self.endpoint_context.cdb[_req["client_id"]]["userinfo_claims"] = { - "phone_number": None - } + self.endpoint_context.cdb[_req["client_id"]]["userinfo_claims"] = {"phone_number": None} _userinfo_restriction = self.claims_interface.get_claims( session_id=session_id, scopes=_req["scope"], usage="userinfo" @@ -408,9 +390,8 @@ def create_endpoint_context(self): "userinfo": {"class": UserInfo, "kwargs": {"db": USERINFO_DB}}, "password": "we didn't start the fire", "issuer": "https://example.com/op", - "token_expires_in": 900, - "grant_expires_in": 600, - "refresh_token_expires_in": 86400, + "claims_interface": {"class": "oidcop.session.claims.OAuth2ClaimsInterface", + "kwargs": {}}, "endpoint": { "provider_config": { "path": "{}/.well-known/openid-configuration", @@ -429,11 +410,7 @@ def create_endpoint_context(self): "response_types_supported": [ " ".join(x) for x in RESPONSE_TYPES_SUPPORTED ], - "response_modes_supported": [ - "query", - "fragment", - "form_post", - ], + "response_modes_supported": ["query", "fragment", "form_post",], "claims_parameter_supported": True, "request_parameter_supported": True, "request_uri_parameter_supported": True, @@ -443,16 +420,9 @@ def create_endpoint_context(self): "path": "userinfo", "class": userinfo.UserInfo, "kwargs": { - "claim_types_supported": [ - "normal", - "aggregated", - "distributed", - ], + "claim_types_supported": ["normal", "aggregated", "distributed",], "client_authn_method": ["bearer_header"], - "base_claims": { - "eduperson_scoped_affiliation": None, - "email": None, - }, + "base_claims": {"eduperson_scoped_affiliation": None, "email": None,}, "add_claims_by_scope": True, "enable_claims_per_client": True, }, diff --git a/tests/test_08_session_life.py b/tests/test_08_session_life.py index 4e759098..9aca9394 100644 --- a/tests/test_08_session_life.py +++ b/tests/test_08_session_life.py @@ -1,5 +1,6 @@ import os +from oidcop.configure import OPConfiguration import pytest from cryptojwt.key_jar import init_key_jar from oidcmsg.oidc import AccessTokenRequest @@ -50,7 +51,8 @@ def setup_token_handler(self): }, "template_dir": "template", } - server = Server(conf) + server = Server(OPConfiguration(conf=conf, base_path=BASEDIR), cwd=BASEDIR) + self.endpoint_context = server.endpoint_context self.session_manager = self.endpoint_context.session_manager @@ -92,9 +94,7 @@ def auth(self): # the grant is assigned to a session (user_id, client_id) self.session_manager.set([user_id, client_id, grant.id], grant) - session_id = self.session_manager.encrypted_session_id( - user_id, client_id, grant.id - ) + session_id = self.session_manager.encrypted_session_id(user_id, client_id, grant.id) # Constructing an authorization code is now done by @@ -107,9 +107,7 @@ def auth(self): ) # get user info - user_info = self.session_manager.get_user_info( - uid = user_id, - ) + user_info = self.session_manager.get_user_info(uid=user_id,) return grant.id, code def test_code_flow(self): @@ -179,9 +177,7 @@ def test_code_flow(self): scope=["openid", "mail", "offline_access"], ) - reftok = self.session_manager.find_token( - session_id, REFRESH_TOKEN_REQ["refresh_token"] - ) + reftok = self.session_manager.find_token(session_id, REFRESH_TOKEN_REQ["refresh_token"]) assert reftok.supports_minting("access_token") @@ -280,11 +276,7 @@ def setup_session_manager(self): "class": ProviderConfiguration, "kwargs": {}, }, - "registration": { - "path": "{}/registration", - "class": Registration, - "kwargs": {}, - }, + "registration": {"path": "{}/registration", "class": Registration, "kwargs": {},}, "authorization": { "path": "{}/authorization", "class": Authorization, @@ -307,7 +299,7 @@ def setup_session_manager(self): "kwargs": {"db_file": full_path("users.json")}, }, } - server = Server(conf, keyjar=KEYJAR) + server = Server(OPConfiguration(conf=conf, base_path=BASEDIR), keyjar=KEYJAR, cwd=BASEDIR) self.endpoint_context = server.endpoint_context self.session_manager = self.endpoint_context.session_manager # self.session_manager = SessionManager(handler=self.endpoint_context.sdb.handler) @@ -353,9 +345,7 @@ def auth(self): # the grant is assigned to a session (user_id, client_id) self.session_manager.set([user_id, client_id, grant.id], grant) - session_id = self.session_manager.encrypted_session_id( - user_id, client_id, grant.id - ) + session_id = self.session_manager.encrypted_session_id(user_id, client_id, grant.id) # Constructing an authorization code is now done by code = grant.mint_token( session_id=session_id, @@ -383,9 +373,7 @@ def test_code_flow(self): # parse the token session_id = self.session_manager.token_handler.sid(TOKEN_REQ["code"]) - user_id, client_id, grant_id = self.session_manager.decrypt_session_id( - session_id - ) + user_id, client_id, grant_id = self.session_manager.decrypt_session_id(session_id) # Now given I have the client_id from the request and the user_id from the # token I can easily find the grant @@ -444,9 +432,7 @@ def test_code_flow(self): session_id = self.session_manager.encrypted_session_id( user_id, REFRESH_TOKEN_REQ["client_id"], grant_id ) - reftok = self.session_manager.find_token( - session_id, REFRESH_TOKEN_REQ["refresh_token"] - ) + reftok = self.session_manager.find_token(session_id, REFRESH_TOKEN_REQ["refresh_token"]) # Can I use this token to mint another token ? assert grant.is_active() diff --git a/tests/test_09_cookie_handler.py b/tests/test_09_cookie_handler.py index 5b120c9c..b170faf2 100644 --- a/tests/test_09_cookie_handler.py +++ b/tests/test_09_cookie_handler.py @@ -37,9 +37,7 @@ def test_make_cookie_content_max_age(self): assert len(_cookie_info["value"].split("|")) == 3 def test_read_cookie_info(self): - _cookie_info = [ - self.cookie_handler.make_cookie_content("oidcop", "value", "sso") - ] + _cookie_info = [self.cookie_handler.make_cookie_content("oidcop", "value", "sso")] returned = [{"name": c["name"], "value": c["value"]} for c in _cookie_info] _info = self.cookie_handler.parse_cookie("oidcop", returned) assert len(_info) == 1 @@ -50,9 +48,7 @@ def test_read_cookie_info(self): def test_mult_cookie(self): _cookie = [ self.cookie_handler.make_cookie_content("oidcop", "value", "sso"), - self.cookie_handler.make_cookie_content( - "oidcop", "session_state", "session" - ), + self.cookie_handler.make_cookie_content("oidcop", "session_state", "session"), ] assert len(_cookie) == 2 _c_info = self.cookie_handler.parse_cookie("oidcop", _cookie) @@ -88,9 +84,7 @@ def test_make_cookie_content_max_age(self): assert len(_cookie_info["value"].split("|")) == 4 def test_read_cookie_info(self): - _cookie_info = [ - self.cookie_handler.make_cookie_content("oidcop", "value", "sso") - ] + _cookie_info = [self.cookie_handler.make_cookie_content("oidcop", "value", "sso")] returned = [{"name": c["name"], "value": c["value"]} for c in _cookie_info] _info = self.cookie_handler.parse_cookie("oidcop", returned) assert len(_info) == 1 @@ -101,9 +95,7 @@ def test_read_cookie_info(self): def test_mult_cookie(self): _cookie = [ self.cookie_handler.make_cookie_content("oidcop", "value", "sso"), - self.cookie_handler.make_cookie_content( - "oidcop", "session_state", "session" - ), + self.cookie_handler.make_cookie_content("oidcop", "session_state", "session"), ] assert len(_cookie) == 2 _c_info = self.cookie_handler.parse_cookie("oidcop", _cookie) @@ -138,9 +130,7 @@ def test_make_cookie_content_max_age(self): assert len(_cookie_info["value"].split("|")) == 4 def test_read_cookie_info(self): - _cookie_info = [ - self.cookie_handler.make_cookie_content("oidcop", "value", "sso") - ] + _cookie_info = [self.cookie_handler.make_cookie_content("oidcop", "value", "sso")] returned = [{"name": c["name"], "value": c["value"]} for c in _cookie_info] _info = self.cookie_handler.parse_cookie("oidcop", returned) assert len(_info) == 1 @@ -151,9 +141,7 @@ def test_read_cookie_info(self): def test_mult_cookie(self): _cookie = [ self.cookie_handler.make_cookie_content("oidcop", "value", "sso"), - self.cookie_handler.make_cookie_content( - "oidcop", "session_state", "session" - ), + self.cookie_handler.make_cookie_content("oidcop", "session_state", "session"), ] assert len(_cookie) == 2 _c_info = self.cookie_handler.parse_cookie("oidcop", _cookie) @@ -192,9 +180,7 @@ def test_make_cookie_content_max_age(self): assert len(_cookie_info["value"].split("|")) == 4 def test_read_cookie_info(self): - _cookie_info = [ - self.cookie_handler.make_cookie_content("oidcop", "value", "sso") - ] + _cookie_info = [self.cookie_handler.make_cookie_content("oidcop", "value", "sso")] returned = [{"name": c["name"], "value": c["value"]} for c in _cookie_info] _info = self.cookie_handler.parse_cookie("oidcop", returned) assert len(_info) == 1 @@ -205,9 +191,7 @@ def test_read_cookie_info(self): def test_mult_cookie(self): _cookie = [ self.cookie_handler.make_cookie_content("oidcop", "value", "sso"), - self.cookie_handler.make_cookie_content( - "oidcop", "session_state", "session" - ), + self.cookie_handler.make_cookie_content("oidcop", "session_state", "session"), ] assert len(_cookie) == 2 _c_info = self.cookie_handler.parse_cookie("oidcop", _cookie) @@ -219,9 +203,7 @@ def test_mult_cookie(self): def test_compute_session_state(): - hv = compute_session_state( - "state", "salt", "client_id", "https://example.com/redirect" - ) + hv = compute_session_state("state", "salt", "client_id", "https://example.com/redirect") assert hv == "d21113fbe4b54661ae45f3a3233b0f865ccc646af248274b6fa5664267540e29.salt" diff --git a/tests/test_12_user_authn.py b/tests/test_12_user_authn.py index e0fe2cf2..37e88374 100644 --- a/tests/test_12_user_authn.py +++ b/tests/test_12_user_authn.py @@ -3,17 +3,16 @@ import pytest +from oidcop.configure import OPConfiguration from oidcop.server import Server from oidcop.user_authn.authn_context import INTERNETPROTOCOLPASSWORD from oidcop.user_authn.authn_context import UNSPECIFIED -from oidcop.user_authn.user import NoAuthn -from oidcop.user_authn.user import UserPassJinja2 from oidcop.user_authn.user import BasicAuthn from oidcop.user_authn.user import NoAuthn from oidcop.user_authn.user import SymKeyAuthn +from oidcop.user_authn.user import UserPassJinja2 from oidcop.util import JSONDictDB - KEYDEFS = [ {"type": "RSA", "key": "", "use": ["sig"]}, {"type": "EC", "crv": "P-256", "use": ["sig"]}, @@ -54,12 +53,9 @@ def create_endpoint_context(self): "passwd_label": "Secret sauce", }, }, - "anon": { - "acr": UNSPECIFIED, - "class": NoAuthn, - "kwargs": {"user": "diana"}, - }, + "anon": {"acr": UNSPECIFIED, "class": NoAuthn, "kwargs": {"user": "diana"}, }, }, + "template_dir": "templates", "cookie_handler": { "class": "oidcop.cookie_handler.CookieHandler", "kwargs": { @@ -71,10 +67,9 @@ def create_endpoint_context(self): }, }, }, - "template_dir": "tests/templates", } - server = Server(conf) - self.endpoint_context = server.endpoint_context + self.server = Server(OPConfiguration(conf=conf, base_path=BASEDIR), cwd=BASEDIR) + self.endpoint_context = self.server.endpoint_context def test_authenticated_as_without_cookie(self): authn_item = self.endpoint_context.authn_broker.pick(INTERNETPROTOCOLPASSWORD) @@ -104,29 +99,23 @@ def test_authenticated_as_with_cookie(self): def test_userpassjinja2(self): db = { - "class": JSONDictDB, - "kwargs": {"filename": full_path("passwd.json")}, + "class": JSONDictDB, + "kwargs": {"filename": full_path("passwd.json")}, } template_handler = self.endpoint_context.template_handler - sg = self.endpoint_context.session_manager.token_handler.handler['access_token'].kwargs['server_get'] - res = UserPassJinja2(db, template_handler, server_get=sg) + res = UserPassJinja2(db, template_handler, + server_get=self.server.server_get) res() - assert 'page_header' in res.kwargs - + assert "page_header" in res.kwargs def test_basic_auth(self): - sg = self.endpoint_context.session_manager.token_handler.handler['access_token'].kwargs['server_get'] - basic_auth = base64.b64encode(b'diana:krall').decode() - ba = BasicAuthn(pwd={'diana': 'krall'}, - server_get=sg) - ba.authenticated_as( - client_id='', authorization=f"Basic {basic_auth}" - ) + basic_auth = base64.b64encode(b"diana:krall").decode() + ba = BasicAuthn(pwd={"diana": "krall"}, server_get=self.server.server_get) + ba.authenticated_as(client_id="", authorization=f"Basic {basic_auth}") def test_no_auth(self): - sg = self.endpoint_context.session_manager.token_handler.handler['access_token'].kwargs['server_get'] basic_auth = base64.b64encode( - b'D\xfd\x8a\x85\xa6\xd1\x16\xe4\\6\x1e\x9ds~\xc3\t\x95\x99\x83\x91\x1f\xfb:iviviviv' + b"D\xfd\x8a\x85\xa6\xd1\x16\xe4\\6\x1e\x9ds~\xc3\t\x95\x99\x83\x91\x1f\xfb:iviviviv" ) - ba = SymKeyAuthn(symkey=b'0'*32, ttl=600, server_get=sg) - ba.authenticated_as(client_id='', authorization=basic_auth) + ba = SymKeyAuthn(symkey=b"0" * 32, ttl=600, server_get=self.server.server_get) + ba.authenticated_as(client_id="", authorization=basic_auth) diff --git a/tests/test_13_login_hint.py b/tests/test_13_login_hint.py index 025a572d..7213a27a 100644 --- a/tests/test_13_login_hint.py +++ b/tests/test_13_login_hint.py @@ -2,7 +2,6 @@ import os from oidcop.configure import OPConfiguration -from oidcop.configure import create_from_config_file from oidcop.endpoint_context import init_service from oidcop.endpoint_context import init_user_info from oidcop.login_hint import LoginHint2Acrs @@ -17,15 +16,9 @@ def full_path(local_file): def test_login_hint(): userinfo = init_user_info( - { - "class": "oidcop.user_info.UserInfo", - "kwargs": {"db_file": full_path("users.json")}, - }, - "", - ) - login_hint_lookup = init_service( - {"class": "oidcop.login_hint.LoginHintLookup"}, None + {"class": "oidcop.user_info.UserInfo", "kwargs": {"db_file": full_path("users.json")},}, "", ) + login_hint_lookup = init_service({"class": "oidcop.login_hint.LoginHintLookup"}, None) login_hint_lookup.userinfo = userinfo assert login_hint_lookup("tel:0907865000") == "diana" @@ -46,9 +39,7 @@ def test_login_hint2acrs_unmatched_schema(): def test_server_login_hint_lookup(): _str = open(full_path("op_config.json")).read() _conf = json.loads(_str) - configuration = OPConfiguration( - conf=_conf, base_path=BASEDIR, domain="127.0.0.1", port=443 - ) + configuration = OPConfiguration(conf=_conf, base_path=BASEDIR, domain="127.0.0.1", port=443) server = Server(configuration) assert server.endpoint_context.login_hint_lookup("tel:0907865000") == "diana" diff --git a/tests/test_20_endpoint.py b/tests/test_20_endpoint.py index 82373e7f..91e37a99 100755 --- a/tests/test_20_endpoint.py +++ b/tests/test_20_endpoint.py @@ -1,6 +1,8 @@ import json +import os from urllib.parse import urlparse +from oidcop.configure import OPConfiguration import pytest from oidcmsg.message import Message @@ -8,6 +10,8 @@ from oidcop.server import Server from oidcop.user_authn.authn_context import INTERNETPROTOCOLPASSWORD +BASEDIR = os.path.abspath(os.path.dirname(__file__)) + KEYDEFS = [ {"type": "RSA", "key": "", "use": ["sig"]}, {"type": "EC", "crv": "P-256", "use": ["sig"]}, @@ -44,9 +48,7 @@ def create_endpoint(self): "grant_expires_in": 300, "refresh_token_expires_in": 86400, "verify_ssl": False, - "endpoint": { - "endpoint": {"path": "endpoint", "class": Endpoint, "kwargs": {}}, - }, + "endpoint": {"endpoint": {"path": "endpoint", "class": Endpoint, "kwargs": {}},}, "keys": { "public_path": "jwks.json", "key_defs": KEYDEFS, @@ -62,7 +64,8 @@ def create_endpoint(self): }, "template_dir": "template", } - server = Server(conf) + server = Server(OPConfiguration(conf=conf, base_path=BASEDIR), cwd=BASEDIR) + self.endpoint_context = server.endpoint_context self.endpoint = server.server_get("endpoint", "") @@ -166,9 +169,7 @@ def test_do_response_response_msg_1(self): assert ("Content-type", "application/json") in info["http_headers"] self.endpoint.response_format = "jws" - info = self.endpoint.do_response( - EXAMPLE_MSG, response_msg="header.payload.sign" - ) + info = self.endpoint.do_response(EXAMPLE_MSG, response_msg="header.payload.sign") assert info["response"] == "header.payload.sign" assert ("Content-type", "application/jose") in info["http_headers"] @@ -177,9 +178,7 @@ def test_do_response_response_msg_1(self): info = self.endpoint.do_response(EXAMPLE_MSG, response_msg="foo=bar") assert info["response"] == "foo=bar" - assert ("Content-type", "application/x-www-form-urlencoded") in info[ - "http_headers" - ] + assert ("Content-type", "application/x-www-form-urlencoded") in info["http_headers"] info = self.endpoint.do_response( EXAMPLE_MSG, response_msg="{foo=bar}", content_type="application/json" @@ -188,9 +187,7 @@ def test_do_response_response_msg_1(self): assert ("Content-type", "application/json") in info["http_headers"] info = self.endpoint.do_response( - EXAMPLE_MSG, - response_msg="header.payload.sign", - content_type="application/jose", + EXAMPLE_MSG, response_msg="header.payload.sign", content_type="application/jose", ) assert info["response"] == "header.payload.sign" assert ("Content-type", "application/jose") in info["http_headers"] @@ -198,22 +195,15 @@ def test_do_response_response_msg_1(self): def test_do_response_placement_body(self): self.endpoint.response_placement = "body" info = self.endpoint.do_response(EXAMPLE_MSG) - assert ("Content-type", "application/json; charset=utf-8") in info[ - "http_headers" - ] + assert ("Content-type", "application/json; charset=utf-8") in info["http_headers"] assert ( - info["response"] - == '{"name": "Doe, Jane", "given_name": "Jane", "family_name": "Doe"}' + info["response"] == '{"name": "Doe, Jane", "given_name": "Jane", "family_name": "Doe"}' ) def test_do_response_placement_url(self): self.endpoint.response_placement = "url" - info = self.endpoint.do_response( - EXAMPLE_MSG, return_uri="https://example.org/cb" - ) - assert ("Content-type", "application/x-www-form-urlencoded") in info[ - "http_headers" - ] + info = self.endpoint.do_response(EXAMPLE_MSG, return_uri="https://example.org/cb") + assert ("Content-type", "application/x-www-form-urlencoded") in info["http_headers"] assert ( info["response"] == "https://example.org/cb?name=Doe%2C+Jane&given_name=Jane&family_name=Doe" @@ -222,9 +212,7 @@ def test_do_response_placement_url(self): info = self.endpoint.do_response( EXAMPLE_MSG, return_uri="https://example.org/cb", fragment_enc=True ) - assert ("Content-type", "application/x-www-form-urlencoded") in info[ - "http_headers" - ] + assert ("Content-type", "application/x-www-form-urlencoded") in info["http_headers"] assert ( info["response"] == "https://example.org/cb#name=Doe%2C+Jane&given_name=Jane&family_name=Doe" diff --git a/tests/test_21_oidc_discovery_endpoint.py b/tests/test_21_oidc_discovery_endpoint.py index 24f7b38a..448d16e5 100755 --- a/tests/test_21_oidc_discovery_endpoint.py +++ b/tests/test_21_oidc_discovery_endpoint.py @@ -1,5 +1,7 @@ import json +import os +from oidcop.configure import OPConfiguration import pytest from oidcop.oidc.discovery import Discovery @@ -11,6 +13,8 @@ {"type": "EC", "crv": "P-256", "use": ["sig"]}, ] +BASEDIR = os.path.abspath(os.path.dirname(__file__)) + class TestEndpoint(object): @pytest.fixture(autouse=True) @@ -39,7 +43,7 @@ def create_endpoint(self): }, "template_dir": "template", } - server = Server(conf) + server = Server(OPConfiguration(conf=conf, base_path=BASEDIR), cwd=BASEDIR) self.endpoint = server.server_get("endpoint", "discovery") def test_do_response(self): diff --git a/tests/test_22_oidc_provider_config_endpoint.py b/tests/test_22_oidc_provider_config_endpoint.py index 5a393a31..df950023 100755 --- a/tests/test_22_oidc_provider_config_endpoint.py +++ b/tests/test_22_oidc_provider_config_endpoint.py @@ -1,11 +1,15 @@ import json +import os +from oidcop.configure import OPConfiguration import pytest from oidcop.oidc.provider_config import ProviderConfiguration from oidcop.oidc.token import Token from oidcop.server import Server +BASEDIR = os.path.abspath(os.path.dirname(__file__)) + KEYDEFS = [ {"type": "RSA", "key": "", "use": ["sig"]}, {"type": "EC", "crv": "P-256", "use": ["sig"]}, @@ -64,7 +68,8 @@ def create_endpoint(self): }, "template_dir": "template", } - server = Server(conf) + server = Server(OPConfiguration(conf=conf, base_path=BASEDIR), cwd=BASEDIR) + self.endpoint_context = server.endpoint_context self.endpoint = server.server_get("endpoint", "provider_config") @@ -98,6 +103,4 @@ def test_do_response(self): "updated_at", "birthdate", } - assert ("Content-type", "application/json; charset=utf-8") in msg[ - "http_headers" - ] + assert ("Content-type", "application/json; charset=utf-8") in msg["http_headers"] diff --git a/tests/test_23_oidc_registration_endpoint.py b/tests/test_23_oidc_registration_endpoint.py index ceb3e7db..eb2c23f2 100755 --- a/tests/test_23_oidc_registration_endpoint.py +++ b/tests/test_23_oidc_registration_endpoint.py @@ -1,6 +1,8 @@ # -*- coding: latin-1 -*- import json +import os +from oidcop.configure import OPConfiguration import pytest import responses from cryptojwt.key_jar import init_key_jar @@ -16,6 +18,8 @@ from oidcop.oidc.userinfo import UserInfo from oidcop.server import Server +BASEDIR = os.path.abspath(os.path.dirname(__file__)) + KEYDEFS = [ {"type": "RSA", "key": "", "use": ["sig"]}, {"type": "EC", "crv": "P-256", "use": ["sig"]}, @@ -87,9 +91,7 @@ def create_endpoint(self): "jwks_def": { "private_path": "private/token_jwks.json", "read_only": False, - "key_defs": [ - {"type": "oct", "bytes": "24", "use": ["enc"], "kid": "code"} - ], + "key_defs": [{"type": "oct", "bytes": "24", "use": ["enc"], "kid": "code"}], }, "code": {"kwargs": {"lifetime": 600}}, "token": { @@ -129,15 +131,9 @@ def create_endpoint(self): "path": "authorization", "class": Authorization, "kwargs": { - "response_types_supported": [ - " ".join(x) for x in RESPONSE_TYPES_SUPPORTED - ], + "response_types_supported": [" ".join(x) for x in RESPONSE_TYPES_SUPPORTED], "response_modes_supported": ["query", "fragment", "form_post"], - "claim_types_supported": [ - "normal", - "aggregated", - "distributed", - ], + "claim_types_supported": ["normal", "aggregated", "distributed",], "claims_parameter_supported": True, "request_parameter_supported": True, "request_uri_parameter_supported": True, @@ -159,7 +155,7 @@ def create_endpoint(self): }, "template_dir": "template", } - server = Server(conf) + server = Server(OPConfiguration(conf=conf, base_path=BASEDIR), cwd=BASEDIR) self.endpoint = server.server_get("endpoint", "registration") def test_parse(self): diff --git a/tests/test_24_oauth2_authorization_endpoint.py b/tests/test_24_oauth2_authorization_endpoint.py index fbdb52b9..319a9a15 100755 --- a/tests/test_24_oauth2_authorization_endpoint.py +++ b/tests/test_24_oauth2_authorization_endpoint.py @@ -5,6 +5,7 @@ from urllib.parse import parse_qs from urllib.parse import urlparse +from oidcop.configure import ASConfiguration import pytest import yaml from cryptojwt import KeyJar @@ -151,9 +152,7 @@ def create_endpoint(self): "jwks_def": { "private_path": "private/token_jwks.json", "read_only": False, - "key_defs": [ - {"type": "oct", "bytes": "24", "use": ["enc"], "kid": "code"} - ], + "key_defs": [{"type": "oct", "bytes": "24", "use": ["enc"], "kid": "code"}], }, "code": {"kwargs": {"lifetime": 600}}, "token": { @@ -183,9 +182,7 @@ def create_endpoint(self): "path": "{}/authorization", "class": Authorization, "kwargs": { - "response_types_supported": [ - " ".join(x) for x in RESPONSE_TYPES_SUPPORTED - ], + "response_types_supported": [" ".join(x) for x in RESPONSE_TYPES_SUPPORTED], "response_modes_supported": ["query", "fragment", "form_post"], "claims_parameter_supported": True, "request_parameter_supported": True, @@ -219,20 +216,12 @@ def create_endpoint(self): "grant_config": { "usage_rules": { "authorization_code": { - "supports_minting": [ - "access_token", - "refresh_token", - "id_token", - ], + "supports_minting": ["access_token", "refresh_token", "id_token",], "max_usage": 1, }, "access_token": {}, "refresh_token": { - "supports_minting": [ - "access_token", - "refresh_token", - "id_token", - ], + "supports_minting": ["access_token", "refresh_token", "id_token",], }, }, "expires_in": 43200, @@ -240,7 +229,8 @@ def create_endpoint(self): }, }, } - server = Server(conf) + server = Server(ASConfiguration(conf=conf, base_path=BASEDIR), cwd=BASEDIR) + endpoint_context = server.endpoint_context _clients = yaml.safe_load(io.StringIO(client_yaml)) endpoint_context.cdb = _clients["clients"] @@ -316,9 +306,7 @@ def test_do_response_code_token(self): def test_verify_uri_unknown_client(self): request = {"redirect_uri": "https://rp.example.com/cb"} with pytest.raises(UnknownClient): - verify_uri( - self.endpoint.server_get("endpoint_context"), request, "redirect_uri" - ) + verify_uri(self.endpoint.server_get("endpoint_context"), request, "redirect_uri") def test_verify_uri_fragment(self): _context = self.endpoint.server_get("endpoint_context") @@ -336,9 +324,7 @@ def test_verify_uri_noregistered(self): def test_verify_uri_unregistered(self): _context = self.endpoint.server_get("endpoint_context") - _context.cdb["client_id"] = { - "redirect_uris": [("https://rp.example.com/auth_cb", {})] - } + _context.cdb["client_id"] = {"redirect_uris": [("https://rp.example.com/auth_cb", {})]} request = {"redirect_uri": "https://rp.example.com/cb"} @@ -380,9 +366,7 @@ def test_verify_uri_qp_mismatch(self): def test_verify_uri_qp_missing(self): _context = self.endpoint.server_get("endpoint_context") _context.cdb["client_id"] = { - "redirect_uris": [ - ("https://rp.example.com/cb", {"foo": ["bar"], "state": ["low"]}) - ] + "redirect_uris": [("https://rp.example.com/cb", {"foo": ["bar"], "state": ["low"]})] } request = {"redirect_uri": "https://rp.example.com/cb?foo=bar"} @@ -401,9 +385,7 @@ def test_verify_uri_qp_missing_val(self): def test_verify_uri_no_registered_qp(self): _context = self.endpoint.server_get("endpoint_context") - _context.cdb["client_id"] = { - "redirect_uris": [("https://rp.example.com/cb", {})] - } + _context.cdb["client_id"] = {"redirect_uris": [("https://rp.example.com/cb", {})]} request = {"redirect_uri": "https://rp.example.com/cb?foo=bob"} with pytest.raises(ValueError): @@ -411,9 +393,7 @@ def test_verify_uri_no_registered_qp(self): def test_verify_uri_wrong_uri_type(self): _context = self.endpoint.server_get("endpoint_context") - _context.cdb["client_id"] = { - "redirect_uris": [("https://rp.example.com/cb", {})] - } + _context.cdb["client_id"] = {"redirect_uris": [("https://rp.example.com/cb", {})]} request = {"redirect_uri": "https://rp.example.com/cb?foo=bob"} with pytest.raises(ValueError): @@ -431,9 +411,7 @@ def test_verify_uri_none_registered(self): def test_get_uri(self): _context = self.endpoint.server_get("endpoint_context") - _context.cdb["client_id"] = { - "redirect_uris": [("https://rp.example.com/cb", {})] - } + _context.cdb["client_id"] = {"redirect_uris": [("https://rp.example.com/cb", {})]} request = { "redirect_uri": "https://rp.example.com/cb", @@ -444,9 +422,7 @@ def test_get_uri(self): def test_get_uri_no_redirect_uri(self): _context = self.endpoint.server_get("endpoint_context") - _context.cdb["client_id"] = { - "redirect_uris": [("https://rp.example.com/cb", {})] - } + _context.cdb["client_id"] = {"redirect_uris": [("https://rp.example.com/cb", {})]} request = {"client_id": "client_id"} @@ -454,9 +430,7 @@ def test_get_uri_no_redirect_uri(self): def test_get_uri_no_registered(self): _context = self.endpoint.server_get("endpoint_context") - _context.cdb["client_id"] = { - "redirect_uris": [("https://rp.example.com/cb", {})] - } + _context.cdb["client_id"] = {"redirect_uris": [("https://rp.example.com/cb", {})]} request = {"client_id": "client_id"} @@ -514,9 +488,9 @@ def test_setup_auth(self): "id_token_signed_response_alg": "RS256", } - kaka = self.endpoint.server_get( - "endpoint_context" - ).cookie_handler.make_cookie_content("value", "sso") + kaka = self.endpoint.server_get("endpoint_context").cookie_handler.make_cookie_content( + "value", "sso" + ) res = self.endpoint.setup_auth(request, redirect_uri, cinfo, [kaka]) assert set(res.keys()) == {"session_id", "identity", "user"} @@ -575,9 +549,7 @@ def test_setup_auth_invalid_scope(self): _context.conf["capabilities"]["deny_unknown_scopes"] = True excp = None try: - res = self.endpoint.process_request( - request, http_info={"headers": {"cookie": [kaka]}} - ) + res = self.endpoint.process_request(request, http_info={"headers": {"cookie": [kaka]}}) except UnAuthorizedClientScope as e: excp = e assert excp @@ -602,9 +574,7 @@ def test_setup_auth_user(self): session_id = self._create_session(request) item = self.endpoint.server_get("endpoint_context").authn_broker.db["anon"] - item["method"].user = b64e( - as_bytes(json.dumps({"uid": "krall", "sid": session_id})) - ) + item["method"].user = b64e(as_bytes(json.dumps({"uid": "krall", "sid": session_id}))) res = self.endpoint.setup_auth(request, redirect_uri, cinfo, None) assert set(res.keys()) == {"session_id", "identity", "user"} @@ -634,9 +604,7 @@ def test_setup_auth_session_revoked(self): _csi.revoked = True item = _context.authn_broker.db["anon"] - item["method"].user = b64e( - as_bytes(json.dumps({"uid": "krall", "sid": session_id})) - ) + item["method"].user = b64e(as_bytes(json.dumps({"uid": "krall", "sid": session_id}))) res = self.endpoint.setup_auth(request, redirect_uri, cinfo, None) assert set(res.keys()) == {"args", "function"} @@ -654,8 +622,7 @@ def test_response_mode_form_post(self): "response_placement", } assert info["response_msg"] == FORM_POST.format( - action="https://example.com/cb", - inputs='', + action="https://example.com/cb", inputs='', ) def test_response_mode_fragment(self): @@ -683,9 +650,7 @@ def test_req_user(self): "redirect_uris": [("https://rp.example.com/cb", {})], "id_token_signed_response_alg": "RS256", } - res = self.endpoint.setup_auth( - request, redirect_uri, cinfo, None, req_user="adam" - ) + res = self.endpoint.setup_auth(request, redirect_uri, cinfo, None, req_user="adam") assert "function" in res def test_req_user_no_prompt(self): @@ -704,9 +669,7 @@ def test_req_user_no_prompt(self): "redirect_uris": [("https://rp.example.com/cb", {})], "id_token_signed_response_alg": "RS256", } - res = self.endpoint.setup_auth( - request, redirect_uri, cinfo, None, req_user="adam" - ) + res = self.endpoint.setup_auth(request, redirect_uri, cinfo, None, req_user="adam") assert "error" in res # def test_sso(self): diff --git a/tests/test_24_oauth2_authorization_endpoint_jar.py b/tests/test_24_oauth2_authorization_endpoint_jar.py index 0d0c629f..ad4f4523 100755 --- a/tests/test_24_oauth2_authorization_endpoint_jar.py +++ b/tests/test_24_oauth2_authorization_endpoint_jar.py @@ -3,6 +3,7 @@ import os from http.cookies import SimpleCookie +from oidcop.configure import ASConfiguration import pytest import responses import yaml @@ -129,9 +130,7 @@ def create_endpoint(self): "path": "{}/authorization", "class": Authorization, "kwargs": { - "response_types_supported": [ - " ".join(x) for x in RESPONSE_TYPES_SUPPORTED - ], + "response_types_supported": [" ".join(x) for x in RESPONSE_TYPES_SUPPORTED], "response_modes_supported": ["query", "fragment", "form_post"], "claims_parameter_supported": True, "request_parameter_supported": True, @@ -161,7 +160,7 @@ def create_endpoint(self): }, }, } - server = Server(conf) + server = Server(ASConfiguration(conf=conf, base_path=BASEDIR), cwd=BASEDIR) endpoint_context = server.endpoint_context _clients = yaml.safe_load(io.StringIO(client_yaml)) endpoint_context.cdb = _clients["clients"] @@ -179,8 +178,7 @@ def create_endpoint(self): def test_parse_request_parameter(self): _jwt = JWT(key_jar=self.rp_keyjar, iss="client_1", sign_alg="HS256") _jws = _jwt.pack( - AUTH_REQ_DICT, - aud=self.endpoint.server_get("endpoint_context").provider_info["issuer"], + AUTH_REQ_DICT, aud=self.endpoint.server_get("endpoint_context").provider_info["issuer"], ) # ----------------- _req = self.endpoint.parse_request( @@ -197,8 +195,7 @@ def test_parse_request_parameter(self): def test_parse_request_uri(self): _jwt = JWT(key_jar=self.rp_keyjar, iss="client_1", sign_alg="HS256") _jws = _jwt.pack( - AUTH_REQ_DICT, - aud=self.endpoint.server_get("endpoint_context").provider_info["issuer"], + AUTH_REQ_DICT, aud=self.endpoint.server_get("endpoint_context").provider_info["issuer"], ) request_uri = "https://client.example.com/req" diff --git a/tests/test_24_oauth2_token_endpoint.py b/tests/test_24_oauth2_token_endpoint.py new file mode 100644 index 00000000..2222447d --- /dev/null +++ b/tests/test_24_oauth2_token_endpoint.py @@ -0,0 +1,480 @@ +import json +import os + +from cryptojwt import JWT +from cryptojwt.key_jar import build_keyjar +from oidcmsg.oidc import AccessTokenRequest +from oidcmsg.oidc import AuthorizationRequest +from oidcmsg.oidc import RefreshAccessTokenRequest +from oidcmsg.oidc import TokenErrorResponse +from oidcmsg.time_util import utc_time_sans_frac +import pytest + +from oidcop import JWT_BEARER +from oidcop.authn_event import create_authn_event +from oidcop.authz import AuthzHandling +from oidcop.client_authn import verify_client +from oidcop.configure import ASConfiguration +from oidcop.exception import UnAuthorizedClient +from oidcop.oauth2.authorization import Authorization +from oidcop.oauth2.token import Token +from oidcop.server import Server +from oidcop.session import MintingNotAllowed +from oidcop.user_authn.authn_context import INTERNETPROTOCOLPASSWORD +from oidcop.user_info import UserInfo + +KEYDEFS = [ + {"type": "RSA", "key": "", "use": ["sig"]}, + {"type": "EC", "crv": "P-256", "use": ["sig"]}, +] + +CLIENT_KEYJAR = build_keyjar(KEYDEFS) + +RESPONSE_TYPES_SUPPORTED = [ + ["code"], + ["token"], +] + +CAPABILITIES = { + "grant_types_supported": [ + "authorization_code", + "implicit", + "urn:ietf:params:oauth:grant-type:jwt-bearer", + "refresh_token", + ], +} + +AUTH_REQ = AuthorizationRequest( + client_id="client_1", + redirect_uri="https://example.com/cb", + scope=["openid"], + state="STATE", + response_type="code", +) + +TOKEN_REQ = AccessTokenRequest( + client_id="client_1", + redirect_uri="https://example.com/cb", + state="STATE", + grant_type="authorization_code", + client_secret="hemligt", +) + +REFRESH_TOKEN_REQ = RefreshAccessTokenRequest( + grant_type="refresh_token", client_id="client_1", client_secret="hemligt" +) + +TOKEN_REQ_DICT = TOKEN_REQ.to_dict() + +BASEDIR = os.path.abspath(os.path.dirname(__file__)) + + +def full_path(local_file): + return os.path.join(BASEDIR, local_file) + + +USERINFO = UserInfo(json.loads(open(full_path("users.json")).read())) + + +@pytest.fixture +def conf(): + return { + "issuer": "https://example.com/", + "password": "mycket hemligt", + "verify_ssl": False, + "capabilities": CAPABILITIES, + "keys": {"uri_path": "jwks.json", "key_defs": KEYDEFS}, + "token_handler_args": { + "jwks_file": "private/token_jwks.json", + "code": {"kwargs": {"lifetime": 600}}, + "token": { + "class": "oidcop.token.jwt_token.JWTToken", + "kwargs": { + "lifetime": 3600, + "add_claims_by_scope": True, + "aud": ["https://example.org/appl"], + }, + }, + "refresh": { + "class": "oidcop.token.jwt_token.JWTToken", + "kwargs": {"lifetime": 3600, "aud": ["https://example.org/appl"],}, + }, + }, + "endpoint": { + "authorization": {"path": "authorization", "class": Authorization, "kwargs": {},}, + "token": { + "path": "token", + "class": Token, + "kwargs": { + "client_authn_method": [ + "client_secret_basic", + "client_secret_post", + "client_secret_jwt", + "private_key_jwt", + ] + }, + }, + }, + "authentication": { + "anon": { + "acr": INTERNETPROTOCOLPASSWORD, + "class": "oidcop.user_authn.user.NoAuthn", + "kwargs": {"user": "diana"}, + } + }, + "userinfo": {"class": UserInfo, "kwargs": {"db": {}}}, + "client_authn": verify_client, + "template_dir": "template", + "claims_interface": {"class": "oidcop.session.claims.OAuth2ClaimsInterface", "kwargs": {}}, + "authz": { + "class": AuthzHandling, + "kwargs": { + "grant_config": { + "usage_rules": { + "authorization_code": { + "expires_in": 300, + "supports_minting": ["access_token", "refresh_token"], + "max_usage": 1, + }, + "access_token": {"expires_in": 600}, + "refresh_token": { + "expires_in": 86400, + "supports_minting": ["access_token", "refresh_token"], + }, + }, + "expires_in": 43200, + } + }, + }, + } + + +class TestEndpoint(object): + @pytest.fixture(autouse=True) + def create_endpoint(self, conf): + server = Server(ASConfiguration(conf=conf, base_path=BASEDIR), cwd=BASEDIR) + endpoint_context = server.endpoint_context + endpoint_context.cdb["client_1"] = { + "client_secret": "hemligt", + "redirect_uris": [("https://example.com/cb", None)], + "client_salt": "salted", + "endpoint_auth_method": "client_secret_post", + "response_types": ["code", "token", "code id_token", "id_token"], + } + endpoint_context.keyjar.import_jwks(CLIENT_KEYJAR.export_jwks(), "client_1") + self.session_manager = endpoint_context.session_manager + self.token_endpoint = server.server_get("endpoint", "token") + self.user_id = "diana" + self.endpoint_context = endpoint_context + + def test_init(self): + assert self.token_endpoint + + def _create_session(self, auth_req, sub_type="public", sector_identifier=""): + if sector_identifier: + authz_req = auth_req.copy() + authz_req["sector_identifier_uri"] = sector_identifier + else: + authz_req = auth_req + client_id = authz_req["client_id"] + ae = create_authn_event(self.user_id) + return self.session_manager.create_session( + ae, authz_req, self.user_id, client_id=client_id, sub_type=sub_type + ) + + def _mint_code(self, grant, client_id): + session_id = self.session_manager.encrypted_session_id(self.user_id, client_id, grant.id) + usage_rules = grant.usage_rules.get("authorization_code", {}) + _exp_in = usage_rules.get("expires_in") + + # Constructing an authorization code is now done + _code = grant.mint_token( + session_id=session_id, + endpoint_context=self.endpoint_context, + token_type="authorization_code", + token_handler=self.session_manager.token_handler["code"], + usage_rules=usage_rules, + ) + + if _exp_in: + if isinstance(_exp_in, str): + _exp_in = int(_exp_in) + if _exp_in: + _code.expires_at = utc_time_sans_frac() + _exp_in + return _code + + def _mint_access_token(self, grant, session_id, token_ref=None): + _session_info = self.session_manager.get_session_info(session_id) + usage_rules = grant.usage_rules.get("access_token", {}) + _exp_in = usage_rules.get("expires_in", 0) + + _token = grant.mint_token( + _session_info, + endpoint_context=self.endpoint_context, + token_type="access_token", + token_handler=self.session_manager.token_handler["access_token"], + based_on=token_ref, # Means the token (tok) was used to mint this token + usage_rules=usage_rules, + ) + if isinstance(_exp_in, str): + _exp_in = int(_exp_in) + if _exp_in: + _token.expires_at = utc_time_sans_frac() + _exp_in + + return _token + + def test_parse(self): + session_id = self._create_session(AUTH_REQ) + grant = self.session_manager[session_id] + code = self._mint_code(grant, AUTH_REQ["client_id"]) + + _token_request = TOKEN_REQ_DICT.copy() + _token_request["code"] = code.value + _req = self.token_endpoint.parse_request(_token_request) + + assert set(_req.keys()) == set(_token_request.keys()) + + def test_process_request(self): + session_id = self._create_session(AUTH_REQ) + grant = self.session_manager[session_id] + code = self._mint_code(grant, AUTH_REQ["client_id"]) + + _token_request = TOKEN_REQ_DICT.copy() + _context = self.endpoint_context + _token_request["code"] = code.value + _req = self.token_endpoint.parse_request(_token_request) + _resp = self.token_endpoint.process_request(request=_req) + + assert _resp + assert set(_resp.keys()) == {"cookie", "http_headers", "response_args"} + + def test_process_request_using_code_twice(self): + session_id = self._create_session(AUTH_REQ) + grant = self.session_manager[session_id] + code = self._mint_code(grant, AUTH_REQ["client_id"]) + + _token_request = TOKEN_REQ_DICT.copy() + _context = self.endpoint_context + _token_request["code"] = code.value + + _req = self.token_endpoint.parse_request(_token_request) + _resp = self.token_endpoint.process_request(request=_req) + + # 2nd time used + _2nd_response = self.token_endpoint.parse_request(_token_request) + assert "error" in _2nd_response + + def test_do_response(self): + session_id = self._create_session(AUTH_REQ) + grant = self.session_manager[session_id] + code = self._mint_code(grant, AUTH_REQ["client_id"]) + + _token_request = TOKEN_REQ_DICT.copy() + _token_request["code"] = code.value + _req = self.token_endpoint.parse_request(_token_request) + + _resp = self.token_endpoint.process_request(request=_req) + msg = self.token_endpoint.do_response(request=_req, **_resp) + assert isinstance(msg, dict) + + def test_process_request_using_private_key_jwt(self): + session_id = self._create_session(AUTH_REQ) + grant = self.session_manager[session_id] + code = self._mint_code(grant, AUTH_REQ["client_id"]) + + _token_request = TOKEN_REQ_DICT.copy() + del _token_request["client_id"] + del _token_request["client_secret"] + _context = self.endpoint_context + + _jwt = JWT(CLIENT_KEYJAR, iss=AUTH_REQ["client_id"], sign_alg="RS256") + _jwt.with_jti = True + _assertion = _jwt.pack({"aud": [self.token_endpoint.full_path]}) + _token_request.update({"client_assertion": _assertion, "client_assertion_type": JWT_BEARER}) + _token_request["code"] = code.value + + _req = self.token_endpoint.parse_request(_token_request) + _resp = self.token_endpoint.process_request(request=_req) + + # 2nd time used + with pytest.raises(UnAuthorizedClient): + self.token_endpoint.parse_request(_token_request) + + def test_do_refresh_access_token(self): + areq = AUTH_REQ.copy() + areq["scope"] = ["openid"] + + session_id = self._create_session(areq) + grant = self.endpoint_context.authz(session_id, areq) + code = self._mint_code(grant, areq["client_id"]) + + _cntx = self.endpoint_context + + _token_request = TOKEN_REQ_DICT.copy() + _token_request["code"] = code.value + _req = self.token_endpoint.parse_request(_token_request) + _resp = self.token_endpoint.process_request(request=_req, issue_refresh=True) + + _request = REFRESH_TOKEN_REQ.copy() + _request["refresh_token"] = _resp["response_args"]["refresh_token"] + + _token_value = _resp["response_args"]["refresh_token"] + _session_info = self.session_manager.get_session_info_by_token(_token_value) + _token = self.session_manager.find_token(_session_info["session_id"], _token_value) + _token.usage_rules["supports_minting"] = ["access_token", "refresh_token"] + + _req = self.token_endpoint.parse_request(_request.to_json()) + _resp = self.token_endpoint.process_request(request=_req) + assert set(_resp.keys()) == {"cookie", "response_args", "http_headers"} + assert set(_resp["response_args"].keys()) == { + "access_token", + "token_type", + "expires_in", + "refresh_token", + "scope", + } + msg = self.token_endpoint.do_response(request=_req, **_resp) + assert isinstance(msg, dict) + + def test_do_2nd_refresh_access_token(self): + areq = AUTH_REQ.copy() + areq["scope"] = ["openid", "offline_access"] + + session_id = self._create_session(areq) + grant = self.endpoint_context.authz(session_id, areq) + code = self._mint_code(grant, areq["client_id"]) + + _cntx = self.endpoint_context + + _token_request = TOKEN_REQ_DICT.copy() + _token_request["code"] = code.value + _req = self.token_endpoint.parse_request(_token_request) + _resp = self.token_endpoint.process_request(request=_req, issue_refresh=True) + + _request = REFRESH_TOKEN_REQ.copy() + _request["refresh_token"] = _resp["response_args"]["refresh_token"] + + # Make sure ID Tokens can also be used by this refesh token + _token_value = _resp["response_args"]["refresh_token"] + _session_info = self.session_manager.get_session_info_by_token(_token_value) + _token = self.session_manager.find_token(_session_info["session_id"], _token_value) + _token.usage_rules["supports_minting"] = [ + "access_token", + "refresh_token", + "id_token", + ] + + _req = self.token_endpoint.parse_request(_request.to_json()) + _resp = self.token_endpoint.process_request(request=_req) + + _2nd_request = REFRESH_TOKEN_REQ.copy() + _2nd_request["refresh_token"] = _resp["response_args"]["refresh_token"] + _2nd_req = self.token_endpoint.parse_request(_request.to_json()) + _2nd_resp = self.token_endpoint.process_request(request=_req) + + assert set(_2nd_resp.keys()) == {"cookie", "response_args", "http_headers"} + assert set(_2nd_resp["response_args"].keys()) == { + "access_token", + "token_type", + "expires_in", + "refresh_token", + "scope", + } + msg = self.token_endpoint.do_response(request=_req, **_resp) + assert isinstance(msg, dict) + + def test_new_refresh_token(self, conf): + self.endpoint_context.cdb["client_1"] = { + "client_secret": "hemligt", + "redirect_uris": [("https://example.com/cb", None)], + "client_salt": "salted", + "endpoint_auth_method": "client_secret_post", + "response_types": ["code", "token", "code id_token", "id_token"], + } + + areq = AUTH_REQ.copy() + areq["scope"] = ["openid", "offline_access"] + + session_id = self._create_session(areq) + grant = self.endpoint_context.authz(session_id, areq) + code = self._mint_code(grant, areq["client_id"]) + + _token_request = TOKEN_REQ_DICT.copy() + _token_request["code"] = code.value + _req = self.token_endpoint.parse_request(_token_request) + _resp = self.token_endpoint.process_request(request=_req, issue_refresh=True) + assert "refresh_token" in _resp["response_args"] + first_refresh_token = _resp["response_args"]["refresh_token"] + + _refresh_request = REFRESH_TOKEN_REQ.copy() + _refresh_request["refresh_token"] = first_refresh_token + _2nd_req = self.token_endpoint.parse_request(_refresh_request.to_json()) + _2nd_resp = self.token_endpoint.process_request(request=_2nd_req, issue_refresh=True) + assert "refresh_token" in _2nd_resp["response_args"] + second_refresh_token = _2nd_resp["response_args"]["refresh_token"] + + _2d_refresh_request = REFRESH_TOKEN_REQ.copy() + _2d_refresh_request["refresh_token"] = second_refresh_token + _3rd_req = self.token_endpoint.parse_request(_2d_refresh_request.to_json()) + _3rd_resp = self.token_endpoint.process_request(request=_3rd_req, issue_refresh=True) + assert "access_token" in _3rd_resp["response_args"] + assert "refresh_token" in _3rd_resp["response_args"] + + assert first_refresh_token != second_refresh_token + + def test_do_refresh_access_token_not_allowed(self): + areq = AUTH_REQ.copy() + areq["scope"] = ["openid", "offline_access"] + + session_id = self._create_session(areq) + grant = self.endpoint_context.authz(session_id, areq) + code = self._mint_code(grant, areq["client_id"]) + + _cntx = self.token_endpoint.server_get("endpoint_context") + + _token_request = TOKEN_REQ_DICT.copy() + _token_request["code"] = code.value + # This is weird, issuing a refresh token that can't be used to mint anything + # but it's testing so anything goes. + grant.usage_rules["refresh_token"] = {"supports_minting": []} + _req = self.token_endpoint.parse_request(_token_request) + _resp = self.token_endpoint.process_request(request=_req, issue_refresh=True) + + _request = REFRESH_TOKEN_REQ.copy() + _request["refresh_token"] = _resp["response_args"]["refresh_token"] + _req = self.token_endpoint.parse_request(_request.to_json()) + with pytest.raises(MintingNotAllowed): + self.token_endpoint.process_request(_req) + + def test_do_refresh_access_token_revoked(self): + areq = AUTH_REQ.copy() + areq["scope"] = ["openid"] + + session_id = self._create_session(areq) + grant = self.endpoint_context.authz(session_id, areq) + code = self._mint_code(grant, areq["client_id"]) + + _cntx = self.token_endpoint.server_get("endpoint_context") + + _token_request = TOKEN_REQ_DICT.copy() + _token_request["code"] = code.value + _req = self.token_endpoint.parse_request(_token_request) + _resp = self.token_endpoint.process_request(request=_req, issue_refresh=True) + + _refresh_token = _resp["response_args"]["refresh_token"] + _cntx.session_manager.revoke_token(session_id, _refresh_token) + + _request = REFRESH_TOKEN_REQ.copy() + _request["refresh_token"] = _refresh_token + _req = self.token_endpoint.parse_request(_request.to_json()) + # A revoked token is caught already when parsing the query. + assert isinstance(_req, TokenErrorResponse) + + def test_configure_grant_types(self): + conf = {"access_token": {"class": "oidcop.oidc.token.AccessTokenHelper"}} + + self.token_endpoint.configure_grant_types(conf) + + assert len(self.token_endpoint.helper) == 1 + assert "access_token" in self.token_endpoint.helper + assert "refresh_token" not in self.token_endpoint.helper diff --git a/tests/test_24_oidc_authorization_endpoint.py b/tests/test_24_oidc_authorization_endpoint.py index 4df743ba..e75e0447 100755 --- a/tests/test_24_oidc_authorization_endpoint.py +++ b/tests/test_24_oidc_authorization_endpoint.py @@ -4,6 +4,7 @@ from urllib.parse import parse_qs from urllib.parse import urlparse +from oidcop.configure import OPConfiguration import pytest import responses import yaml @@ -183,18 +184,12 @@ def create_endpoint(self): "class": ProviderConfiguration, "kwargs": {}, }, - "registration": { - "path": "{}/registration", - "class": Registration, - "kwargs": {}, - }, + "registration": {"path": "{}/registration", "class": Registration, "kwargs": {},}, "authorization": { "path": "{}/authorization", "class": Authorization, "kwargs": { - "response_types_supported": [ - " ".join(x) for x in RESPONSE_TYPES_SUPPORTED - ], + "response_types_supported": [" ".join(x) for x in RESPONSE_TYPES_SUPPORTED], "response_modes_supported": ["query", "fragment", "form_post"], "claims_parameter_supported": True, "request_parameter_supported": True, @@ -218,11 +213,7 @@ def create_endpoint(self): "class": userinfo.UserInfo, "kwargs": { "db_file": "users.json", - "claim_types_supported": [ - "normal", - "aggregated", - "distributed", - ], + "claim_types_supported": ["normal", "aggregated", "distributed",], }, }, }, @@ -241,11 +232,7 @@ def create_endpoint(self): "grant_config": { "usage_rules": { "authorization_code": { - "supports_minting": [ - "access_token", - "refresh_token", - "id_token", - ], + "supports_minting": ["access_token", "refresh_token", "id_token",], "max_usage": 1, }, "access_token": {}, @@ -273,7 +260,8 @@ def create_endpoint(self): "kwargs": {"scheme_map": {"email": [INTERNETPROTOCOLPASSWORD]}}, }, } - server = Server(conf) + server = Server(OPConfiguration(conf=conf, base_path=BASEDIR), cwd=BASEDIR) + endpoint_context = server.endpoint_context _clients = yaml.safe_load(io.StringIO(client_yaml)) @@ -416,8 +404,7 @@ def test_id_token_claims(self): _pr_resp = self.endpoint.parse_request(_req) _resp = self.endpoint.process_request(_pr_resp) idt = verify_id_token( - _resp["response_args"], - keyjar=self.endpoint.server_get("endpoint_context").keyjar, + _resp["response_args"], keyjar=self.endpoint.server_get("endpoint_context").keyjar, ) assert idt # from config @@ -441,8 +428,7 @@ def test_id_token_acr(self): _pr_resp = self.endpoint.parse_request(_req) _resp = self.endpoint.process_request(_pr_resp) res = verify_id_token( - _resp["response_args"], - keyjar=self.endpoint.server_get("endpoint_context").keyjar, + _resp["response_args"], keyjar=self.endpoint.server_get("endpoint_context").keyjar, ) assert res res = _resp["response_args"][verified_claim_name("id_token")] @@ -451,9 +437,7 @@ def test_id_token_acr(self): def test_verify_uri_unknown_client(self): request = {"redirect_uri": "https://rp.example.com/cb"} with pytest.raises(UnknownClient): - verify_uri( - self.endpoint.server_get("endpoint_context"), request, "redirect_uri" - ) + verify_uri(self.endpoint.server_get("endpoint_context"), request, "redirect_uri") def test_verify_uri_fragment(self): _ec = self.endpoint.server_get("endpoint_context") @@ -471,9 +455,7 @@ def test_verify_uri_noregistered(self): def test_verify_uri_unregistered(self): _ec = self.endpoint.server_get("endpoint_context") - _ec.cdb["client_id"] = { - "redirect_uris": [("https://rp.example.com/auth_cb", {})] - } + _ec.cdb["client_id"] = {"redirect_uris": [("https://rp.example.com/auth_cb", {})]} request = {"redirect_uri": "https://rp.example.com/cb"} @@ -482,9 +464,7 @@ def test_verify_uri_unregistered(self): def test_verify_uri_qp_match(self): _ec = self.endpoint.server_get("endpoint_context") - _ec.cdb["client_id"] = { - "redirect_uris": [("https://rp.example.com/cb", {"foo": ["bar"]})] - } + _ec.cdb["client_id"] = {"redirect_uris": [("https://rp.example.com/cb", {"foo": ["bar"]})]} request = {"redirect_uri": "https://rp.example.com/cb?foo=bar"} @@ -492,9 +472,7 @@ def test_verify_uri_qp_match(self): def test_verify_uri_qp_mismatch(self): _ec = self.endpoint.server_get("endpoint_context") - _ec.cdb["client_id"] = { - "redirect_uris": [("https://rp.example.com/cb", {"foo": ["bar"]})] - } + _ec.cdb["client_id"] = {"redirect_uris": [("https://rp.example.com/cb", {"foo": ["bar"]})]} request = {"redirect_uri": "https://rp.example.com/cb?foo=bob"} with pytest.raises(ValueError): @@ -515,9 +493,7 @@ def test_verify_uri_qp_mismatch(self): def test_verify_uri_qp_missing(self): _ec = self.endpoint.server_get("endpoint_context") _ec.cdb["client_id"] = { - "redirect_uris": [ - ("https://rp.example.com/cb", {"foo": ["bar"], "state": ["low"]}) - ] + "redirect_uris": [("https://rp.example.com/cb", {"foo": ["bar"], "state": ["low"]})] } request = {"redirect_uri": "https://rp.example.com/cb?foo=bar"} @@ -622,9 +598,9 @@ def test_setup_auth(self): "id_token_signed_response_alg": "RS256", } - kaka = self.endpoint.server_get( - "endpoint_context" - ).cookie_handler.make_cookie_content("value", "sso") + kaka = self.endpoint.server_get("endpoint_context").cookie_handler.make_cookie_content( + "value", "sso" + ) res = self.endpoint.setup_auth(request, redirect_uri, cinfo, [kaka]) assert set(res.keys()) == {"session_id", "identity", "user"} @@ -678,9 +654,7 @@ def test_setup_auth_user(self): session_id = self._create_session(request) item = _ec.authn_broker.db["anon"] - item["method"].user = b64e( - as_bytes(json.dumps({"uid": "krall", "sid": session_id})) - ) + item["method"].user = b64e(as_bytes(json.dumps({"uid": "krall", "sid": session_id}))) res = self.endpoint.setup_auth(request, redirect_uri, cinfo, None) assert set(res.keys()) == {"session_id", "identity", "user"} @@ -706,9 +680,7 @@ def test_setup_auth_session_revoked(self): session_id = self._create_session(request) item = _ec.authn_broker.db["anon"] - item["method"].user = b64e( - as_bytes(json.dumps({"uid": "krall", "sid": session_id})) - ) + item["method"].user = b64e(as_bytes(json.dumps({"uid": "krall", "sid": session_id}))) grant = _ec.session_manager[session_id] grant.revoked = True @@ -786,8 +758,7 @@ def test_post_logout_uri(self): def test_parse_request(self): _jwt = JWT(key_jar=self.rp_keyjar, iss="client_1", sign_alg="HS256") _jws = _jwt.pack( - AUTH_REQ_DICT, - aud=self.endpoint.server_get("endpoint_context").provider_info["issuer"], + AUTH_REQ_DICT, aud=self.endpoint.server_get("endpoint_context").provider_info["issuer"], ) # ----------------- _req = self.endpoint.parse_request( @@ -804,8 +775,7 @@ def test_parse_request(self): def test_parse_request_uri(self): _jwt = JWT(key_jar=self.rp_keyjar, iss="client_1", sign_alg="HS256") _jws = _jwt.pack( - AUTH_REQ_DICT, - aud=self.endpoint.server_get("endpoint_context").provider_info["issuer"], + AUTH_REQ_DICT, aud=self.endpoint.server_get("endpoint_context").provider_info["issuer"], ) request_uri = "https://client.example.com/req" @@ -868,8 +838,7 @@ def test_mint_token_exp_at(self, exp_in): def test_do_request_uri(self): request = AuthorizationRequest( - redirect_uri="https://rp.example.com/cb", - request_uri="https://example.com/request", + redirect_uri="https://rp.example.com/cb", request_uri="https://example.com/request", ) orig_request = AuthorizationRequest( @@ -887,9 +856,7 @@ def test_do_request_uri(self): ) endpoint_context = self.endpoint.server_get("endpoint_context") - endpoint_context.cdb["client_1"]["request_uris"] = [ - ("https://example.com/request", {}) - ] + endpoint_context.cdb["client_1"]["request_uris"] = [("https://example.com/request", {})] with responses.RequestsMock() as rsps: rsps.add( @@ -968,9 +935,7 @@ def test_response_mode(self, response_mode): request, response_args, request["redirect_uri"], fragment_enc=True ) else: - info = self.endpoint.response_mode( - request, response_args, request["redirect_uri"] - ) + info = self.endpoint.response_mode(request, response_args, request["redirect_uri"]) if response_mode == "form_post": assert set(info.keys()) == { @@ -1013,10 +978,7 @@ def test_do_request_user(self): endpoint_context = self.endpoint.server_get("endpoint_context") # userinfo _userinfo = init_user_info( - { - "class": "oidcop.user_info.UserInfo", - "kwargs": {"db_file": full_path("users.json")}, - }, + {"class": "oidcop.user_info.UserInfo", "kwargs": {"db_file": full_path("users.json")},}, "", ) # login_hint @@ -1052,9 +1014,7 @@ def test_authn_args_gather_message(): assert set(args.keys()) == {"query", "authn_class_ref", "return_uri", "policy_uri"} with pytest.raises(ValueError): - authn_args_gather( - request.to_urlencoded(), INTERNETPROTOCOLPASSWORD, client_info - ) + authn_args_gather(request.to_urlencoded(), INTERNETPROTOCOLPASSWORD, client_info) def test_inputs(): @@ -1068,9 +1028,10 @@ def test_inputs(): def test_acr_claims(): assert acr_claims({"claims": {"id_token": {"acr": {"value": "foo"}}}}) == ["foo"] - assert acr_claims( - {"claims": {"id_token": {"acr": {"values": ["foo", "bar"]}}}} - ) == ["foo", "bar"] + assert acr_claims({"claims": {"id_token": {"acr": {"values": ["foo", "bar"]}}}}) == [ + "foo", + "bar", + ] assert acr_claims({"claims": {"id_token": {"acr": {"values": ["foo"]}}}}) == ["foo"] assert acr_claims({"claims": {"id_token": {"acr": {"essential": True}}}}) is None @@ -1113,21 +1074,15 @@ def create_endpoint_context(self): "passwd_label": "Secret sauce", }, }, - "anon": { - "acr": UNSPECIFIED, - "class": NoAuthn, - "kwargs": {"user": "diana"}, - }, + "anon": {"acr": UNSPECIFIED, "class": NoAuthn, "kwargs": {"user": "diana"},}, }, "cookie_handler": { "class": "oidcop.cookie_handler.CookieHandler", - "kwargs": { - "sign_key": "ghsNKDDLshZTPn974nOsIGhedULrsqnsGoBFBLwUKuJhE2ch" - }, + "kwargs": {"sign_key": "ghsNKDDLshZTPn974nOsIGhedULrsqnsGoBFBLwUKuJhE2ch"}, }, "template_dir": "template", } - server = Server(conf) + server = Server(OPConfiguration(conf=conf, base_path=BASEDIR), cwd=BASEDIR) self.endpoint_context = server.endpoint_context def test_authenticated_as_without_cookie(self): @@ -1152,7 +1107,5 @@ def test_authenticated_as_with_cookie(self): client_id=authn_req["client_id"], ) - _info, _time_stamp = method.authenticated_as( - client_id="client 12345", cookie=[_cookie] - ) + _info, _time_stamp = method.authenticated_as(client_id="client 12345", cookie=[_cookie]) assert _info["sub"] == "diana" diff --git a/tests/test_26_oidc_userinfo_endpoint.py b/tests/test_26_oidc_userinfo_endpoint.py index 42245810..e607a230 100755 --- a/tests/test_26_oidc_userinfo_endpoint.py +++ b/tests/test_26_oidc_userinfo_endpoint.py @@ -1,6 +1,7 @@ import json import os +from oidcop.configure import OPConfiguration import pytest from oidcmsg.oauth2 import ResponseMessage from oidcmsg.oidc import AccessTokenRequest @@ -104,16 +105,8 @@ def create_endpoint(self): "class": ProviderConfiguration, "kwargs": {}, }, - "registration": { - "path": "registration", - "class": Registration, - "kwargs": {}, - }, - "authorization": { - "path": "authorization", - "class": Authorization, - "kwargs": {}, - }, + "registration": {"path": "registration", "class": Registration, "kwargs": {},}, + "authorization": {"path": "authorization", "class": Authorization, "kwargs": {},}, "token": { "path": "token", "class": Token, @@ -130,11 +123,7 @@ def create_endpoint(self): "path": "userinfo", "class": userinfo.UserInfo, "kwargs": { - "claim_types_supported": [ - "normal", - "aggregated", - "distributed", - ], + "claim_types_supported": ["normal", "aggregated", "distributed",], "client_authn_method": ["bearer_header"], }, }, @@ -169,7 +158,8 @@ def create_endpoint(self): } }, } - server = Server(conf) + server = Server(OPConfiguration(conf=conf, base_path=BASEDIR), cwd=BASEDIR) + endpoint_context = server.endpoint_context endpoint_context.cdb["client_1"] = { "client_secret": "hemligt", @@ -218,9 +208,7 @@ def _mint_token(self, token_type, grant, session_id, token_ref=None): def test_init(self): assert self.endpoint assert set( - self.endpoint.server_get("endpoint_context").provider_info[ - "claims_supported" - ] + self.endpoint.server_get("endpoint_context").provider_info["claims_supported"] ) == { "address", "birthdate", @@ -251,9 +239,7 @@ def test_parse(self): # Free standing access token, not based on an authorization code access_token = self._mint_token("access_token", grant, session_id) - http_info = { - "headers": {"authorization": "Bearer {}".format(access_token.value)} - } + http_info = {"headers": {"authorization": "Bearer {}".format(access_token.value)}} _req = self.endpoint.parse_request({}, http_info=http_info) assert set(_req.keys()) == {"client_id", "access_token"} assert _req["client_id"] == AUTH_REQ["client_id"] @@ -270,9 +256,7 @@ def test_process_request(self): code = self._mint_code(grant, session_id) access_token = self._mint_token("access_token", grant, session_id, code) - http_info = { - "headers": {"authorization": "Bearer {}".format(access_token.value)} - } + http_info = {"headers": {"authorization": "Bearer {}".format(access_token.value)}} _req = self.endpoint.parse_request({}, http_info=http_info) args = self.endpoint.process_request(_req, http_info=http_info) @@ -290,9 +274,7 @@ def test_process_request_not_allowed(self): _event["authn_time"] -= 9000 _event["valid_until"] -= 9000 - http_info = { - "headers": {"authorization": "Bearer {}".format(access_token.value)} - } + http_info = {"headers": {"authorization": "Bearer {}".format(access_token.value)}} _req = self.endpoint.parse_request({}, http_info=http_info) args = self.endpoint.process_request(_req, http_info=http_info) @@ -304,9 +286,7 @@ def test_do_response(self): code = self._mint_code(grant, session_id) access_token = self._mint_token("access_token", grant, session_id, code) - http_info = { - "headers": {"authorization": "Bearer {}".format(access_token.value)} - } + http_info = {"headers": {"authorization": "Bearer {}".format(access_token.value)}} _req = self.endpoint.parse_request({}, http_info=http_info) args = self.endpoint.process_request(_req) @@ -324,9 +304,7 @@ def test_do_signed_response(self): code = self._mint_code(grant, session_id) access_token = self._mint_token("access_token", grant, session_id, code) - http_info = { - "headers": {"authorization": "Bearer {}".format(access_token.value)} - } + http_info = {"headers": {"authorization": "Bearer {}".format(access_token.value)}} _req = self.endpoint.parse_request({}, http_info=http_info) args = self.endpoint.process_request(_req) @@ -343,20 +321,14 @@ def test_custom_scope(self): access_token = self._mint_token("access_token", grant, session_id) self.endpoint.kwargs["add_claims_by_scope"] = True - self.endpoint.server_get( - "endpoint_context" - ).claims_interface.add_claims_by_scope = True + self.endpoint.server_get("endpoint_context").claims_interface.add_claims_by_scope = True grant.claims = { - "userinfo": self.endpoint.server_get( - "endpoint_context" - ).claims_interface.get_claims( + "userinfo": self.endpoint.server_get("endpoint_context").claims_interface.get_claims( session_id=session_id, scopes=_auth_req["scope"], usage="userinfo" ) } - http_info = { - "headers": {"authorization": "Bearer {}".format(access_token.value)} - } + http_info = {"headers": {"authorization": "Bearer {}".format(access_token.value)}} _req = self.endpoint.parse_request({}, http_info=http_info) args = self.endpoint.process_request(_req, http_info=http_info) @@ -378,9 +350,7 @@ def test_wrong_type_of_token(self): grant = self.session_manager[session_id] refresh_token = self._mint_token("refresh_token", grant, session_id) - http_info = { - "headers": {"authorization": "Bearer {}".format(refresh_token.value)} - } + http_info = {"headers": {"authorization": "Bearer {}".format(refresh_token.value)}} _req = self.endpoint.parse_request({}, http_info=http_info) args = self.endpoint.process_request(_req, http_info=http_info) @@ -395,9 +365,7 @@ def test_invalid_token(self): grant = self.session_manager[session_id] access_token = self._mint_token("access_token", grant, session_id) - http_info = { - "headers": {"authorization": "Bearer {}".format(access_token.value)} - } + http_info = {"headers": {"authorization": "Bearer {}".format(access_token.value)}} _req = self.endpoint.parse_request({}, http_info=http_info) access_token.expires_at = time_sans_frac() - 10 diff --git a/tests/test_30_oidc_end_session.py b/tests/test_30_oidc_end_session.py index 6634b53f..1d827a48 100644 --- a/tests/test_30_oidc_end_session.py +++ b/tests/test_30_oidc_end_session.py @@ -4,15 +4,16 @@ from urllib.parse import parse_qs from urllib.parse import urlparse -import pytest -import responses from cryptojwt.key_jar import build_keyjar from oidcmsg.exception import InvalidRequest from oidcmsg.message import Message from oidcmsg.oidc import AuthorizationRequest from oidcmsg.oidc import verified_claim_name from oidcmsg.oidc import verify_id_token +import pytest +import responses +from oidcop.configure import OPConfiguration from oidcop.cookie_handler import CookieHandler from oidcop.exception import RedirectURIError from oidcop.oauth2.authorization import join_query @@ -149,9 +150,7 @@ def create_endpoint(self): "jwks_def": { "private_path": "private/token_jwks.json", "read_only": False, - "key_defs": [ - {"type": "oct", "bytes": "24", "use": ["enc"], "kid": "code"} - ], + "key_defs": [{"type": "oct", "bytes": "24", "use": ["enc"], "kid": "code"}], }, "code": {"kwargs": {"lifetime": 600}}, "token": { @@ -187,7 +186,12 @@ def create_endpoint(self): }, } self.cd = CookieHandler(**cookie_conf) - server = Server(conf, cookie_handler=self.cd, keyjar=KEYJAR) + server = Server( + OPConfiguration(conf=conf, base_path=BASEDIR), + cwd=BASEDIR, + cookie_handler=self.cd, + keyjar=KEYJAR, + ) endpoint_context = server.endpoint_context endpoint_context.cdb = { "client_1": { @@ -258,9 +262,9 @@ def _auth_with_id_token(self, state): _pr_resp = self.authn_endpoint.parse_request(req.to_dict()) _resp = self.authn_endpoint.process_request(_pr_resp) - _info = self.session_endpoint.server_get( - "endpoint_context" - ).cookie_handler.parse_cookie("oidc_op", _resp["cookie"]) + _info = self.session_endpoint.server_get("endpoint_context").cookie_handler.parse_cookie( + "oidc_op", _resp["cookie"] + ) # value is a JSON document _cookie_info = json.loads(_info[0]["value"]) @@ -272,9 +276,7 @@ def test_end_session_endpoint_with_cookie(self): _session_info = self.session_manager.get_session_info_by_token(_code) cookie = self._create_cookie(_session_info["session_id"]) http_info = {"cookie": [cookie]} - _req_args = self.session_endpoint.parse_request( - {"state": "1234567"}, http_info=http_info - ) + _req_args = self.session_endpoint.parse_request({"state": "1234567"}, http_info=http_info) resp = self.session_endpoint.process_request(_req_args, http_info=http_info) # returns a signed JWT to be put in a verification web page shown to @@ -293,9 +295,7 @@ def test_end_session_endpoint_with_cookie_and_unknown_sid(self): id_token = resp_args["id_token"] _uid, _cid, _gid = self.session_manager.decrypt_session_id(_session_id) - cookie = self._create_cookie( - self.session_manager.session_key(_uid, "client_66", _gid) - ) + cookie = self._create_cookie(self.session_manager.session_key(_uid, "client_66", _gid)) http_info = {"cookie": [cookie]} with pytest.raises(ValueError): @@ -307,20 +307,14 @@ def test_end_session_endpoint_with_cookie_id_token_and_unknown_sid(self): id_token = resp_args["id_token"] _uid, _cid, _gid = self.session_manager.decrypt_session_id(_session_id) - cookie = self._create_cookie( - self.session_manager.session_key(_uid, "client_66", _gid) - ) + cookie = self._create_cookie(self.session_manager.session_key(_uid, "client_66", _gid)) http_info = {"cookie": [cookie]} msg = Message(id_token=id_token) - verify_id_token( - msg, keyjar=self.session_endpoint.server_get("endpoint_context").keyjar - ) + verify_id_token(msg, keyjar=self.session_endpoint.server_get("endpoint_context").keyjar) msg2 = Message(id_token_hint=id_token) - msg2[verified_claim_name("id_token_hint")] = msg[ - verified_claim_name("id_token") - ] + msg2[verified_claim_name("id_token_hint")] = msg[verified_claim_name("id_token")] with pytest.raises(ValueError): self.session_endpoint.process_request(msg2, http_info=http_info) @@ -332,9 +326,7 @@ def test_end_session_endpoint_with_cookie_dual_login(self): cookie = self._create_cookie(_session_info["session_id"]) http_info = {"cookie": [cookie]} - resp = self.session_endpoint.process_request( - {"state": "abcde"}, http_info=http_info - ) + resp = self.session_endpoint.process_request({"state": "abcde"}, http_info=http_info) # returns a signed JWT to be put in a verification web page shown to # the user @@ -362,10 +354,7 @@ def test_end_session_endpoint_with_post_logout_redirect_uri(self): with pytest.raises(InvalidRequest): self.session_endpoint.process_request( - { - "post_logout_redirect_uri": post_logout_redirect_uri, - "state": "abcde", - }, + {"post_logout_redirect_uri": post_logout_redirect_uri, "state": "abcde",}, http_info=http_info, ) @@ -382,9 +371,7 @@ def test_end_session_endpoint_with_wrong_post_logout_redirect_uri(self): post_logout_redirect_uri = "https://demo.example.com/log_out" msg = Message(id_token=id_token) - verify_id_token( - msg, keyjar=self.session_endpoint.server_get("endpoint_context").keyjar - ) + verify_id_token(msg, keyjar=self.session_endpoint.server_get("endpoint_context").keyjar) with pytest.raises(RedirectURIError): self.session_endpoint.process_request( @@ -392,9 +379,7 @@ def test_end_session_endpoint_with_wrong_post_logout_redirect_uri(self): "post_logout_redirect_uri": post_logout_redirect_uri, "state": "abcde", "id_token_hint": id_token, - verified_claim_name("id_token_hint"): msg[ - verified_claim_name("id_token") - ], + verified_claim_name("id_token_hint"): msg[verified_claim_name("id_token")], }, http_info=http_info, ) @@ -410,9 +395,7 @@ def test_back_channel_logout_no_uri(self): def test_back_channel_logout(self): self._code_auth("1234567") - _cdb = copy.copy( - self.session_endpoint.server_get("endpoint_context").cdb["client_1"] - ) + _cdb = copy.copy(self.session_endpoint.server_get("endpoint_context").cdb["client_1"]) _cdb["backchannel_logout_uri"] = "https://example.com/bc_logout" _cdb["client_id"] = "client_1" res = self.session_endpoint.do_back_channel_logout(_cdb, "_sid_") @@ -427,9 +410,7 @@ def test_back_channel_logout(self): def test_front_channel_logout(self): self._code_auth("1234567") - _cdb = copy.copy( - self.session_endpoint.server_get("endpoint_context").cdb["client_1"] - ) + _cdb = copy.copy(self.session_endpoint.server_get("endpoint_context").cdb["client_1"]) _cdb["frontchannel_logout_uri"] = "https://example.com/fc_logout" _cdb["client_id"] = "client_1" res = do_front_channel_logout_iframe(_cdb, ISS, "_sid_") @@ -438,9 +419,7 @@ def test_front_channel_logout(self): def test_front_channel_logout_session_required(self): self._code_auth("1234567") - _cdb = copy.copy( - self.session_endpoint.server_get("endpoint_context").cdb["client_1"] - ) + _cdb = copy.copy(self.session_endpoint.server_get("endpoint_context").cdb["client_1"]) _cdb["frontchannel_logout_uri"] = "https://example.com/fc_logout" _cdb["frontchannel_logout_session_required"] = True _cdb["client_id"] = "client_1" @@ -456,9 +435,7 @@ def test_front_channel_logout_session_required(self): def test_front_channel_logout_with_query(self): self._code_auth("1234567") - _cdb = copy.copy( - self.session_endpoint.server_get("endpoint_context").cdb["client_1"] - ) + _cdb = copy.copy(self.session_endpoint.server_get("endpoint_context").cdb["client_1"]) _cdb["frontchannel_logout_uri"] = "https://example.com/fc_logout?entity_id=foo" _cdb["frontchannel_logout_session_required"] = True _cdb["client_id"] = "client_1" @@ -496,9 +473,7 @@ def test_logout_from_client_bc(self): assert _jwt assert _jwt["iss"] == ISS assert _jwt["aud"] == ["client_1"] - assert ( - "sid" in _jwt - ) # This session ID is not the same as the session_id mentioned above + assert "sid" in _jwt # This session ID is not the same as the session_id mentioned above _sid = self.session_endpoint._decrypt_sid(_jwt["sid"]) assert _sid == _session_info["session_id"] @@ -592,9 +567,7 @@ def test_logout_from_client_unknow_sid(self): _session_info = self.session_manager.get_session_info_by_token(_code) self._code_auth2("abcdefg") - _uid, _cid, _gid = self.session_manager.decrypt_session_id( - _session_info["session_id"] - ) + _uid, _cid, _gid = self.session_manager.decrypt_session_id(_session_info["session_id"]) _sid = self.session_manager.encrypted_session_id("babs", _cid, _gid) with pytest.raises(KeyError): res = self.session_endpoint.logout_all_clients(_sid) @@ -619,12 +592,8 @@ def test_logout_from_client_no_session(self): "client_id" ] = "client_2" - _uid, _cid, _gid = self.session_manager.decrypt_session_id( - _session_info["session_id"] - ) - self.session_endpoint.server_get("endpoint_context").session_manager.delete( - [_uid, _cid] - ) + _uid, _cid, _gid = self.session_manager.decrypt_session_id(_session_info["session_id"]) + self.session_endpoint.server_get("endpoint_context").session_manager.delete([_uid, _cid]) with pytest.raises(ValueError): self.session_endpoint.logout_all_clients(_session_info["session_id"]) diff --git a/tests/test_31_introspection.py b/tests/test_31_oauth2_introspection.py similarity index 92% rename from tests/test_31_introspection.py rename to tests/test_31_oauth2_introspection.py index a948e8a9..ae497157 100644 --- a/tests/test_31_introspection.py +++ b/tests/test_31_oauth2_introspection.py @@ -2,6 +2,9 @@ import json import os +from oidcop.configure import ASConfiguration + +from oidcop.configure import OPConfiguration import pytest from cryptojwt import JWT from cryptojwt import as_unicode @@ -118,15 +121,6 @@ def create_endpoint(self, jwt_token): "class": "oidcop.token.jwt_token.JWTToken", "kwargs": {"lifetime": 3600, "aud": ["https://example.org/appl"],}, }, - "id_token": { - "class": "oidcop.token.id_token.IDToken", - "kwargs": { - "base_claims": { - "email": {"essential": True}, - "email_verified": {"essential": True}, - } - }, - }, }, "endpoint": { "authorization": { @@ -175,11 +169,7 @@ def create_endpoint(self, jwt_token): "grant_config": { "usage_rules": { "authorization_code": { - "supports_minting": [ - "access_token", - "refresh_token", - "id_token", - ], + "supports_minting": ["access_token", "refresh_token", "id_token",], "max_usage": 1, }, "access_token": {}, @@ -197,7 +187,7 @@ def create_endpoint(self, jwt_token): "class": "oidcop.token.jwt_token.JWTToken", "kwargs": {}, } - server = Server(conf) + server = Server(ASConfiguration(conf=conf, base_path=BASEDIR), cwd=BASEDIR) endpoint_context = server.endpoint_context endpoint_context.cdb["client_1"] = { "client_secret": "hemligt", @@ -205,14 +195,10 @@ def create_endpoint(self, jwt_token): "client_salt": "salted", "token_endpoint_auth_method": "client_secret_post", "response_types": ["code", "token", "code id_token", "id_token"], - "introspection_claims": { - "nickname": None, - "eduperson_scoped_affiliation": None, - }, + "introspection_claims": {"nickname": None, "eduperson_scoped_affiliation": None,}, } endpoint_context.keyjar.import_jwks_as_json( - endpoint_context.keyjar.export_jwks_as_json(private=True), - endpoint_context.issuer, + endpoint_context.keyjar.export_jwks_as_json(private=True), endpoint_context.issuer, ) self.introspection_endpoint = server.server_get("endpoint", "introspection") self.token_endpoint = server.server_get("endpoint", "token") @@ -246,9 +232,7 @@ def _mint_token(self, type, grant, session_id, based_on=None, **kwargs): def _get_access_token(self, areq): session_id = self._create_session(areq) # Consent handling - grant = self.token_endpoint.server_get("endpoint_context").authz( - session_id, areq - ) + grant = self.token_endpoint.server_get("endpoint_context").authz(session_id, areq) self.session_manager[session_id] = grant # grant = self.session_manager[session_id] code = self._mint_token("authorization_code", grant, session_id) @@ -299,9 +283,9 @@ def test_process_request(self): { "token": access_token.value, "client_id": "client_1", - "client_secret": self.introspection_endpoint.server_get( - "endpoint_context" - ).cdb["client_1"]["client_secret"], + "client_secret": self.introspection_endpoint.server_get("endpoint_context").cdb[ + "client_1" + ]["client_secret"], } ) _resp = self.introspection_endpoint.process_request(_req) @@ -323,9 +307,9 @@ def test_do_response(self): { "token": access_token.value, "client_id": "client_1", - "client_secret": self.introspection_endpoint.server_get( - "endpoint_context" - ).cdb["client_1"]["client_secret"], + "client_secret": self.introspection_endpoint.server_get("endpoint_context").cdb[ + "client_1" + ]["client_secret"], } ) _resp = self.introspection_endpoint.process_request(_req) @@ -355,10 +339,7 @@ def test_do_response_no_token(self): # access_token = self._get_access_token(AUTH_REQ) _context = self.introspection_endpoint.server_get("endpoint_context") _req = self.introspection_endpoint.parse_request( - { - "client_id": "client_1", - "client_secret": _context.cdb["client_1"]["client_secret"], - } + {"client_id": "client_1", "client_secret": _context.cdb["client_1"]["client_secret"],} ) _resp = self.introspection_endpoint.process_request(_req) assert "error" in _resp @@ -383,9 +364,7 @@ def test_code(self): session_id = self._create_session(AUTH_REQ) # Apply consent - grant = self.token_endpoint.server_get("endpoint_context").authz( - session_id, AUTH_REQ - ) + grant = self.token_endpoint.server_get("endpoint_context").authz(session_id, AUTH_REQ) self.session_manager[session_id] = grant code = self._mint_token("authorization_code", grant, session_id) @@ -406,9 +385,7 @@ def test_code(self): def test_introspection_claims(self): session_id = self._create_session(AUTH_REQ) # Apply consent - grant = self.token_endpoint.server_get("endpoint_context").authz( - session_id, AUTH_REQ - ) + grant = self.token_endpoint.server_get("endpoint_context").authz(session_id, AUTH_REQ) self.session_manager[session_id] = grant code = self._mint_token("authorization_code", grant, session_id) @@ -416,9 +393,7 @@ def test_introspection_claims(self): self.introspection_endpoint.kwargs["enable_claims_per_client"] = True - _c_interface = self.introspection_endpoint.server_get( - "endpoint_context" - ).claims_interface + _c_interface = self.introspection_endpoint.server_get("endpoint_context").claims_interface grant.claims = { "introspection": _c_interface.get_claims( session_id, scopes=AUTH_REQ["scope"], usage="introspection" diff --git a/tests/test_32_read_registration.py b/tests/test_32_oidc_read_registration.py similarity index 91% rename from tests/test_32_read_registration.py rename to tests/test_32_oidc_read_registration.py index fcb37dcd..c687a0ec 100644 --- a/tests/test_32_read_registration.py +++ b/tests/test_32_oidc_read_registration.py @@ -1,6 +1,8 @@ # -*- coding: latin-1 -*- import json +import os +from oidcop.configure import OPConfiguration import pytest from oidcmsg.oidc import RegistrationRequest @@ -12,6 +14,8 @@ from oidcop.oidc.userinfo import UserInfo from oidcop.server import Server +BASEDIR = os.path.abspath(os.path.dirname(__file__)) + KEYDEFS = [ {"type": "RSA", "key": "", "use": ["sig"]}, {"type": "EC", "crv": "P-256", "use": ["sig"]}, @@ -97,11 +101,7 @@ def create_endpoint(self): "class": RegistrationRead, "kwargs": {"client_authn_method": ["bearer_header"]}, }, - "authorization": { - "path": "authorization", - "class": Authorization, - "kwargs": {}, - }, + "authorization": {"path": "authorization", "class": Authorization, "kwargs": {},}, "token": { "path": "token", "class": Token, @@ -118,11 +118,9 @@ def create_endpoint(self): }, "template_dir": "template", } - server = Server(conf) + server = Server(OPConfiguration(conf=conf, base_path=BASEDIR), cwd=BASEDIR) self.registration_endpoint = server.server_get("endpoint", "registration") - self.registration_api_endpoint = server.server_get( - "endpoint", "registration_read" - ) + self.registration_api_endpoint = server.server_get("endpoint", "registration_read") def test_do_response(self): _req = self.registration_endpoint.parse_request(CLI_REQ.to_json()) @@ -141,8 +139,7 @@ def test_do_response(self): } _api_req = self.registration_api_endpoint.parse_request( - "client_id={}".format(_resp["response_args"]["client_id"]), - http_info=http_info, + "client_id={}".format(_resp["response_args"]["client_id"]), http_info=http_info, ) assert set(_api_req.keys()) == {"client_id"} @@ -152,6 +149,4 @@ def test_do_response(self): _endp_response = self.registration_api_endpoint.do_response(_info) assert set(_endp_response.keys()) == {"response", "http_headers"} - assert ("Content-type", "application/json; charset=utf-8") in _endp_response[ - "http_headers" - ] + assert ("Content-type", "application/json; charset=utf-8") in _endp_response["http_headers"] diff --git a/tests/test_33_pkce.py b/tests/test_33_oauth2_pkce.py similarity index 95% rename from tests/test_33_pkce.py rename to tests/test_33_oauth2_pkce.py index b920fa7e..08c24d7d 100644 --- a/tests/test_33_pkce.py +++ b/tests/test_33_oauth2_pkce.py @@ -4,6 +4,7 @@ import secrets import string +from oidcop.configure import ASConfiguration import pytest import yaml from oidcmsg.message import Message @@ -121,11 +122,7 @@ def conf(): "capabilities": CAPABILITIES, "keys": {"uri_path": "static/jwks.json", "key_defs": KEYDEFS}, "endpoint": { - "authorization": { - "path": "{}/authorization", - "class": Authorization, - "kwargs": {}, - }, + "authorization": {"path": "{}/authorization", "class": Authorization, "kwargs": {},}, "token": { "path": "{}/token", "class": Token, @@ -193,7 +190,8 @@ def _code_challenge(): def create_server(config): - server = Server(config) + server = Server(ASConfiguration(conf=config, base_path=BASEDIR), cwd=BASEDIR) + endpoint_context = server.endpoint_context _clients = yaml.safe_load(io.StringIO(client_yaml)) endpoint_context.cdb = _clients["oidc_clients"] @@ -296,9 +294,7 @@ def test_unknown_code_challenge_method(self): assert isinstance(_pr_resp, AuthorizationErrorResponse) assert _pr_resp["error"] == "invalid_request" - assert _pr_resp[ - "error_description" - ] == "Unsupported code_challenge_method={}".format( + assert _pr_resp["error_description"] == "Unsupported code_challenge_method={}".format( _authn_req["code_challenge_method"] ) @@ -316,9 +312,7 @@ def test_unsupported_code_challenge_method(self, conf): assert isinstance(_pr_resp, AuthorizationErrorResponse) assert _pr_resp["error"] == "invalid_request" - assert _pr_resp[ - "error_description" - ] == "Unsupported code_challenge_method={}".format( + assert _pr_resp["error_description"] == "Unsupported code_challenge_method={}".format( _authn_req["code_challenge_method"] ) @@ -371,9 +365,7 @@ def test_missing_authz_endpoint(): } }, } - configuration = OPConfiguration( - conf, base_path=BASEDIR, domain="127.0.0.1", port=443 - ) + configuration = OPConfiguration(conf, base_path=BASEDIR, domain="127.0.0.1", port=443) server = Server(configuration) add_pkce_support(server.server_get("endpoints")) @@ -398,9 +390,7 @@ def test_missing_token_endpoint(): }, }, } - configuration = OPConfiguration( - conf, base_path=BASEDIR, domain="127.0.0.1", port=443 - ) + configuration = OPConfiguration(conf, base_path=BASEDIR, domain="127.0.0.1", port=443) server = Server(configuration) add_pkce_support(server.server_get("endpoints")) diff --git a/tests/test_34_sso.py b/tests/test_34_oidc_sso.py similarity index 94% rename from tests/test_34_sso.py rename to tests/test_34_oidc_sso.py index 075bd58a..dca54792 100755 --- a/tests/test_34_sso.py +++ b/tests/test_34_oidc_sso.py @@ -2,6 +2,7 @@ import json import os +from oidcop.configure import OPConfiguration import pytest import yaml from cryptojwt import KeyJar @@ -133,9 +134,7 @@ def create_endpoint_context(self): "path": "{}/authorization", "class": Authorization, "kwargs": { - "response_types_supported": [ - " ".join(x) for x in RESPONSE_TYPES_SUPPORTED - ], + "response_types_supported": [" ".join(x) for x in RESPONSE_TYPES_SUPPORTED], "response_modes_supported": ["query", "fragment", "form_post"], "claims_parameter_supported": True, "request_parameter_supported": True, @@ -145,11 +144,7 @@ def create_endpoint_context(self): }, "keys": {"uri_path": "static/jwks.json", "key_defs": KEYDEFS}, "authentication": { - "anon": { - "acr": UNSPECIFIED, - "class": NoAuthn, - "kwargs": {"user": "diana"}, - }, + "anon": {"acr": UNSPECIFIED, "class": NoAuthn, "kwargs": {"user": "diana"},}, }, "cookie_handler": { "class": "oidcop.cookie_handler.CookieHandler", @@ -164,7 +159,8 @@ def create_endpoint_context(self): }, "template_dir": "template", } - server = Server(conf) + server = Server(OPConfiguration(conf=conf, base_path=BASEDIR), cwd=BASEDIR) + endpoint_context = server.endpoint_context _clients = yaml.safe_load(io.StringIO(client_yaml)) endpoint_context.cdb = _clients["oidc_clients"] @@ -234,9 +230,9 @@ def test_sso(self): # No valid login cookie so new session assert info["session_id"] != sid2 - user_session_info = self.endpoint.server_get( - "endpoint_context" - ).session_manager.get(["diana"]) + user_session_info = self.endpoint.server_get("endpoint_context").session_manager.get( + ["diana"] + ) assert len(user_session_info.subordinate) == 3 assert set(user_session_info.subordinate) == { "client_1", diff --git a/tests/test_35_oidc_token_endpoint.py b/tests/test_35_oidc_token_endpoint.py index 4e3c1694..ec0bd6a4 100755 --- a/tests/test_35_oidc_token_endpoint.py +++ b/tests/test_35_oidc_token_endpoint.py @@ -1,6 +1,7 @@ import json import os +from oidcop.configure import OPConfiguration import pytest from cryptojwt import JWT from cryptojwt.key_jar import build_keyjar @@ -126,16 +127,8 @@ def conf(): "class": ProviderConfiguration, "kwargs": {}, }, - "registration": { - "path": "registration", - "class": Registration, - "kwargs": {}, - }, - "authorization": { - "path": "authorization", - "class": Authorization, - "kwargs": {}, - }, + "registration": {"path": "registration", "class": Registration, "kwargs": {},}, + "authorization": {"path": "authorization", "class": Authorization, "kwargs": {},}, "token": { "path": "token", "class": Token, @@ -171,11 +164,7 @@ def conf(): "usage_rules": { "authorization_code": { "expires_in": 300, - "supports_minting": [ - "access_token", - "refresh_token", - "id_token", - ], + "supports_minting": ["access_token", "refresh_token", "id_token",], "max_usage": 1, }, "access_token": {"expires_in": 600}, @@ -194,7 +183,8 @@ def conf(): class TestEndpoint(object): @pytest.fixture(autouse=True) def create_endpoint(self, conf): - server = Server(conf) + server = Server(OPConfiguration(conf=conf, base_path=BASEDIR), cwd=BASEDIR) + endpoint_context = server.endpoint_context endpoint_context.cdb["client_1"] = { "client_secret": "hemligt", @@ -225,9 +215,7 @@ def _create_session(self, auth_req, sub_type="public", sector_identifier=""): ) def _mint_code(self, grant, client_id): - session_id = self.session_manager.encrypted_session_id( - self.user_id, client_id, grant.id - ) + session_id = self.session_manager.encrypted_session_id(self.user_id, client_id, grant.id) usage_rules = grant.usage_rules.get("authorization_code", {}) _exp_in = usage_rules.get("expires_in") @@ -334,9 +322,7 @@ def test_process_request_using_private_key_jwt(self): _jwt = JWT(CLIENT_KEYJAR, iss=AUTH_REQ["client_id"], sign_alg="RS256") _jwt.with_jti = True _assertion = _jwt.pack({"aud": [self.token_endpoint.full_path]}) - _token_request.update( - {"client_assertion": _assertion, "client_assertion_type": JWT_BEARER} - ) + _token_request.update({"client_assertion": _assertion, "client_assertion_type": JWT_BEARER}) _token_request["code"] = code.value _req = self.token_endpoint.parse_request(_token_request) @@ -366,9 +352,7 @@ def test_do_refresh_access_token(self): _token_value = _resp["response_args"]["refresh_token"] _session_info = self.session_manager.get_session_info_by_token(_token_value) - _token = self.session_manager.find_token( - _session_info["session_id"], _token_value - ) + _token = self.session_manager.find_token(_session_info["session_id"], _token_value) _token.usage_rules["supports_minting"] = [ "access_token", "refresh_token", @@ -410,9 +394,7 @@ def test_do_2nd_refresh_access_token(self): # Make sure ID Tokens can also be used by this refesh token _token_value = _resp["response_args"]["refresh_token"] _session_info = self.session_manager.get_session_info_by_token(_token_value) - _token = self.session_manager.find_token( - _session_info["session_id"], _token_value - ) + _token = self.session_manager.find_token(_session_info["session_id"], _token_value) _token.usage_rules["supports_minting"] = [ "access_token", "refresh_token", diff --git a/tests/test_36_token_exchange.py b/tests/test_36_oauth2_token_exchange.py similarity index 91% rename from tests/test_36_token_exchange.py rename to tests/test_36_oauth2_token_exchange.py index 7e6904e9..5e7641c1 100644 --- a/tests/test_36_token_exchange.py +++ b/tests/test_36_oauth2_token_exchange.py @@ -1,18 +1,19 @@ import json import os -import pytest from cryptojwt.key_jar import build_keyjar from oidcmsg.oauth2 import TokenExchangeRequest from oidcmsg.oidc import AccessTokenRequest from oidcmsg.oidc import AuthorizationRequest +import pytest from oidcop.authn_event import create_authn_event from oidcop.authz import AuthzHandling from oidcop.client_authn import verify_client +from oidcop.configure import ASConfiguration from oidcop.cookie_handler import CookieHandler -from oidcop.oidc.authorization import Authorization -from oidcop.oidc.token import Token +from oidcop.oauth2.authorization import Authorization +from oidcop.oauth2.token import Token from oidcop.server import Server from oidcop.session.grant import ExchangeGrant from oidcop.user_authn.authn_context import INTERNETPROTOCOLPASSWORD @@ -95,12 +96,12 @@ def create_endpoint(self): "endpoint": { "authorization": { "path": "authorization", - "class": Authorization, + "class": 'oidcop.oauth2.authorization.Authorization', "kwargs": {}, }, "token": { "path": "token", - "class": Token, + "class": 'oidcop.oauth2.token.Token', "kwargs": { "client_authn_method": [ "client_secret_basic", @@ -132,11 +133,7 @@ def create_endpoint(self): "grant_config": { "usage_rules": { "authorization_code": { - "supports_minting": [ - "access_token", - "refresh_token", - "id_token", - ], + "supports_minting": ["access_token", "refresh_token", "id_token", ], "max_usage": 1, }, "access_token": {}, @@ -161,12 +158,11 @@ def create_endpoint(self): }, "refresh": { "class": "oidcop.token.jwt_token.JWTToken", - "kwargs": {"lifetime": 3600, "aud": ["https://example.org/appl"],}, + "kwargs": {"lifetime": 3600, "aud": ["https://example.org/appl"], }, }, - "id_token": {"class": "oidcop.token.id_token.IDToken", "kwargs": {}}, }, } - server = Server(conf) + server = Server(ASConfiguration(conf=conf, base_path=BASEDIR), cwd=BASEDIR) endpoint_context = server.endpoint_context endpoint_context.cdb["client_1"] = { "client_secret": "hemligt", @@ -261,22 +257,15 @@ def test_do_response(self): assert exch_grants exch_grant = exch_grants[0] - session_info = self.session_manager.get_session_info_by_token( - ter["subject_token"] - ) - _token = self.session_manager.find_token( - session_info["session_id"], ter["subject_token"] - ) + session_info = self.session_manager.get_session_info_by_token(ter["subject_token"]) + _token = self.session_manager.find_token(session_info["session_id"], ter["subject_token"]) session_id = self.session_manager.encrypted_session_id( session_info["user_id"], session_info["client_id"], exch_grant.id ) _token = self._mint_access_token( - exch_grant, - session_id, - token_ref=_token, - resources=["https://backend.example.com"], + exch_grant, session_id, token_ref=_token, resources=["https://backend.example.com"], ) print(_token.value) @@ -284,9 +273,9 @@ def test_do_response(self): { "token": _token.value, "client_id": "client_1", - "client_secret": self.introspection_endpoint.server_get( - "endpoint_context" - ).cdb["client_1"]["client_secret"], + "client_secret": self.introspection_endpoint.server_get("endpoint_context").cdb[ + "client_1" + ]["client_secret"], } ) _resp = self.introspection_endpoint.process_request(_req) diff --git a/tests/test_40_oauth2_pushed_authorization.py b/tests/test_40_oauth2_pushed_authorization.py index 2e4a17bb..56ef1d2d 100644 --- a/tests/test_40_oauth2_pushed_authorization.py +++ b/tests/test_40_oauth2_pushed_authorization.py @@ -1,5 +1,7 @@ import io +import os +from oidcop.configure import ASConfiguration import pytest import yaml from cryptojwt import JWT @@ -15,6 +17,8 @@ from oidcop.oidc.registration import Registration from oidcop.server import Server +BASEDIR = os.path.abspath(os.path.dirname(__file__)) + CAPABILITIES = { "subject_types_supported": ["public", "pairwise", "ephemeral"], "grant_types_supported": [ @@ -106,18 +110,12 @@ def create_endpoint(self): "class": ProviderConfiguration, "kwargs": {}, }, - "registration": { - "path": "registration", - "class": Registration, - "kwargs": {}, - }, + "registration": {"path": "registration", "class": Registration, "kwargs": {},}, "authorization": { "path": "authorization", "class": Authorization, "kwargs": { - "response_types_supported": [ - " ".join(x) for x in RESPONSE_TYPES_SUPPORTED - ], + "response_types_supported": [" ".join(x) for x in RESPONSE_TYPES_SUPPORTED], "response_modes_supported": ["query", "fragment", "form_post"], "claims_parameter_supported": True, "request_parameter_supported": True, @@ -157,7 +155,7 @@ def create_endpoint(self): }, }, } - server = Server(conf) + server = Server(ASConfiguration(conf=conf, base_path=BASEDIR), cwd=BASEDIR) endpoint_context = server.endpoint_context _clients = yaml.safe_load(io.StringIO(client_yaml)) endpoint_context.cdb = _clients["oidc_clients"] @@ -171,9 +169,7 @@ def create_endpoint(self): self.rp_keyjar.export_jwks(issuer_id="s6BhdRkqt3"), "s6BhdRkqt3" ) - self.pushed_authorization_endpoint = server.server_get( - "endpoint", "pushed_authorization" - ) + self.pushed_authorization_endpoint = server.server_get("endpoint", "pushed_authorization") self.authorization_endpoint = server.server_get("endpoint", "authorization") def test_init(self): @@ -181,14 +177,10 @@ def test_init(self): def test_pushed_auth_urlencoded(self): http_info = { - "headers": { - "authorization": "Basic czZCaGRSa3F0Mzo3RmpmcDBaQnIxS3REUmJuZlZkbUl3" - } + "headers": {"authorization": "Basic czZCaGRSa3F0Mzo3RmpmcDBaQnIxS3REUmJuZlZkbUl3"} } - _req = self.pushed_authorization_endpoint.parse_request( - AUTHN_REQUEST, http_info=http_info - ) + _req = self.pushed_authorization_endpoint.parse_request(AUTHN_REQUEST, http_info=http_info) assert isinstance(_req, AuthorizationRequest) assert set(_req.keys()) == { @@ -208,14 +200,10 @@ def test_pushed_auth_request(self): authn_request = "request={}".format(_jws) http_info = { - "headers": { - "authorization": "Basic czZCaGRSa3F0Mzo3RmpmcDBaQnIxS3REUmJuZlZkbUl3" - } + "headers": {"authorization": "Basic czZCaGRSa3F0Mzo3RmpmcDBaQnIxS3REUmJuZlZkbUl3"} } - _req = self.pushed_authorization_endpoint.parse_request( - authn_request, http_info=http_info - ) + _req = self.pushed_authorization_endpoint.parse_request(authn_request, http_info=http_info) assert isinstance(_req, AuthorizationRequest) _req = remove_jwt_parameters(_req) @@ -233,14 +221,10 @@ def test_pushed_auth_request(self): def test_pushed_auth_urlencoded_process(self): http_info = { - "headers": { - "authorization": "Basic czZCaGRSa3F0Mzo3RmpmcDBaQnIxS3REUmJuZlZkbUl3" - } + "headers": {"authorization": "Basic czZCaGRSa3F0Mzo3RmpmcDBaQnIxS3REUmJuZlZkbUl3"} } - _req = self.pushed_authorization_endpoint.parse_request( - AUTHN_REQUEST, http_info=http_info - ) + _req = self.pushed_authorization_endpoint.parse_request(AUTHN_REQUEST, http_info=http_info) assert isinstance(_req, AuthorizationRequest) assert set(_req.keys()) == { diff --git a/tests/test_50_persistence.py b/tests/test_50_persistence.py index 419258da..535a3444 100644 --- a/tests/test_50_persistence.py +++ b/tests/test_50_persistence.py @@ -2,14 +2,15 @@ import os import shutil -import pytest from cryptojwt.jwt import utc_time_sans_frac from oidcmsg.oidc import AccessTokenRequest from oidcmsg.oidc import AuthorizationRequest +import pytest from oidcop import user_info from oidcop.authn_event import create_authn_event from oidcop.authz import AuthzHandling +from oidcop.configure import OPConfiguration from oidcop.oidc import userinfo from oidcop.oidc.authorization import Authorization from oidcop.oidc.provider_config import ProviderConfiguration @@ -102,11 +103,7 @@ def full_path(local_file): "kwargs": {}, }, "registration": {"path": "registration", "class": Registration, "kwargs": {},}, - "authorization": { - "path": "authorization", - "class": Authorization, - "kwargs": {}, - }, + "authorization": {"path": "authorization", "class": Authorization, "kwargs": {},}, "token": { "path": "token", "class": Token, @@ -129,10 +126,7 @@ def full_path(local_file): }, }, }, - "userinfo": { - "class": user_info.UserInfo, - "kwargs": {"db_file": full_path("users.json")}, - }, + "userinfo": {"class": user_info.UserInfo, "kwargs": {"db_file": full_path("users.json")},}, # "client_authn": verify_client, "authentication": { "anon": { @@ -164,17 +158,11 @@ def full_path(local_file): "grant_config": { "usage_rules": { "authorization_code": { - "supports_minting": [ - "access_token", - "refresh_token", - "id_token", - ], + "supports_minting": ["access_token", "refresh_token", "id_token",], "max_usage": 1, }, "access_token": {}, - "refresh_token": { - "supports_minting": ["access_token", "refresh_token"], - }, + "refresh_token": {"supports_minting": ["access_token", "refresh_token"],}, }, "expires_in": 43200, } @@ -191,8 +179,12 @@ def create_endpoint(self): except FileNotFoundError: pass - server1 = Server(ENDPOINT_CONTEXT_CONFIG) - server2 = Server(ENDPOINT_CONTEXT_CONFIG) + server1 = Server( + OPConfiguration(conf=ENDPOINT_CONTEXT_CONFIG, base_path=BASEDIR), cwd=BASEDIR + ) + server2 = Server( + OPConfiguration(conf=ENDPOINT_CONTEXT_CONFIG, base_path=BASEDIR), cwd=BASEDIR + ) server1.endpoint_context.cdb["client_1"] = { "client_secret": "hemligt", @@ -222,9 +214,7 @@ def create_endpoint(self): } self.user_id = "diana" - def _create_session( - self, auth_req, sub_type="public", sector_identifier="", index=1 - ): + def _create_session(self, auth_req, sub_type="public", sector_identifier="", index=1): if sector_identifier: authz_req = auth_req.copy() authz_req["sector_identifier_uri"] = sector_identifier @@ -264,9 +254,7 @@ def _mint_access_token(self, grant, session_id, token_ref=None, index=1): based_on=token_ref, # Means the token (tok) was used to mint this token ) - self.session_manager[index].set( - [self.user_id, _session_info["client_id"], grant.id], grant - ) + self.session_manager[index].set([self.user_id, _session_info["client_id"], grant.id], grant) return _token @@ -279,9 +267,7 @@ def _dump_restore(self, fro, to): def test_init(self): assert self.endpoint[1] assert set( - self.endpoint[1] - .server_get("endpoint_context") - .provider_info["claims_supported"] + self.endpoint[1].server_get("endpoint_context").provider_info["claims_supported"] ) == { "address", "birthdate", @@ -306,20 +292,12 @@ def test_init(self): "zoneinfo", } assert set( - self.endpoint[1] - .server_get("endpoint_context") - .provider_info["claims_supported"] - ) == set( - self.endpoint[2] - .server_get("endpoint_context") - .provider_info["claims_supported"] - ) + self.endpoint[1].server_get("endpoint_context").provider_info["claims_supported"] + ) == set(self.endpoint[2].server_get("endpoint_context").provider_info["claims_supported"]) def test_parse(self): session_id = self._create_session(AUTH_REQ, index=1) - grant = ( - self.endpoint[1].server_get("endpoint_context").authz(session_id, AUTH_REQ) - ) + grant = self.endpoint[1].server_get("endpoint_context").authz(session_id, AUTH_REQ) # grant, session_id = self._do_grant(AUTH_REQ, index=1) code = self._mint_code(grant, session_id, index=1) access_token = self._mint_access_token(grant, session_id, code, 1) @@ -328,48 +306,36 @@ def test_parse(self): self._dump_restore(1, 2) - http_info = { - "headers": {"authorization": "Bearer {}".format(access_token.value)} - } + http_info = {"headers": {"authorization": "Bearer {}".format(access_token.value)}} _req = self.endpoint[2].parse_request({}, http_info=http_info) assert set(_req.keys()) == {"client_id", "access_token"} def test_process_request(self): session_id = self._create_session(AUTH_REQ, index=1) - grant = ( - self.endpoint[1].server_get("endpoint_context").authz(session_id, AUTH_REQ) - ) + grant = self.endpoint[1].server_get("endpoint_context").authz(session_id, AUTH_REQ) code = self._mint_code(grant, session_id, index=1) access_token = self._mint_access_token(grant, session_id, code, 1) self._dump_restore(1, 2) - http_info = { - "headers": {"authorization": "Bearer {}".format(access_token.value)} - } + http_info = {"headers": {"authorization": "Bearer {}".format(access_token.value)}} _req = self.endpoint[2].parse_request({}, http_info=http_info) args = self.endpoint[2].process_request(_req) assert args def test_process_request_not_allowed(self): session_id = self._create_session(AUTH_REQ, index=2) - grant = ( - self.endpoint[2].server_get("endpoint_context").authz(session_id, AUTH_REQ) - ) + grant = self.endpoint[2].server_get("endpoint_context").authz(session_id, AUTH_REQ) code = self._mint_code(grant, session_id, index=2) access_token = self._mint_access_token(grant, session_id, code, 2) access_token.expires_at = utc_time_sans_frac() - 60 - self.session_manager[2].set( - [self.user_id, AUTH_REQ["client_id"], grant.id], grant - ) + self.session_manager[2].set([self.user_id, AUTH_REQ["client_id"], grant.id], grant) self._dump_restore(2, 1) - http_info = { - "headers": {"authorization": "Bearer {}".format(access_token.value)} - } + http_info = {"headers": {"authorization": "Bearer {}".format(access_token.value)}} _req = self.endpoint[1].parse_request({}, http_info=http_info) @@ -394,17 +360,13 @@ def test_process_request_not_allowed(self): def test_do_response(self): session_id = self._create_session(AUTH_REQ, index=2) - grant = ( - self.endpoint[2].server_get("endpoint_context").authz(session_id, AUTH_REQ) - ) + grant = self.endpoint[2].server_get("endpoint_context").authz(session_id, AUTH_REQ) code = self._mint_code(grant, session_id, index=2) access_token = self._mint_access_token(grant, session_id, code, 2) self._dump_restore(2, 1) - http_info = { - "headers": {"authorization": "Bearer {}".format(access_token.value)} - } + http_info = {"headers": {"authorization": "Bearer {}".format(access_token.value)}} _req = self.endpoint[1].parse_request({}, http_info=http_info) args = self.endpoint[1].process_request(_req) @@ -421,17 +383,13 @@ def test_do_signed_response(self): ] = "ES256" session_id = self._create_session(AUTH_REQ, index=2) - grant = ( - self.endpoint[2].server_get("endpoint_context").authz(session_id, AUTH_REQ) - ) + grant = self.endpoint[2].server_get("endpoint_context").authz(session_id, AUTH_REQ) code = self._mint_code(grant, session_id, index=2) access_token = self._mint_access_token(grant, session_id, code, 2) self._dump_restore(2, 1) - http_info = { - "headers": {"authorization": "Bearer {}".format(access_token.value)} - } + http_info = {"headers": {"authorization": "Bearer {}".format(access_token.value)}} _req = self.endpoint[1].parse_request({}, http_info=http_info) args = self.endpoint[1].process_request(_req) @@ -444,34 +402,26 @@ def test_custom_scope(self): _auth_req["scope"] = ["openid", "research_and_scholarship"] session_id = self._create_session(AUTH_REQ, index=2) - grant = ( - self.endpoint[2].server_get("endpoint_context").authz(session_id, AUTH_REQ) - ) + grant = self.endpoint[2].server_get("endpoint_context").authz(session_id, AUTH_REQ) self._dump_restore(2, 1) grant.claims = { "userinfo": self.endpoint[1] .server_get("endpoint_context") - .claims_interface.get_claims( - session_id, scopes=_auth_req["scope"], usage="userinfo" - ) + .claims_interface.get_claims(session_id, scopes=_auth_req["scope"], usage="userinfo") } self._dump_restore(1, 2) - self.session_manager[2].set( - self.session_manager[2].decrypt_session_id(session_id), grant - ) + self.session_manager[2].set(self.session_manager[2].decrypt_session_id(session_id), grant) code = self._mint_code(grant, session_id, index=2) access_token = self._mint_access_token(grant, session_id, code, 2) self._dump_restore(2, 1) - http_info = { - "headers": {"authorization": "Bearer {}".format(access_token.value)} - } + http_info = {"headers": {"authorization": "Bearer {}".format(access_token.value)}} _req = self.endpoint[1].parse_request({}, http_info=http_info) args = self.endpoint[1].process_request(_req) @@ -484,4 +434,3 @@ def test_custom_scope(self): "email_verified", "eduperson_scoped_affiliation", } -