diff --git a/doc/source/index.rst b/doc/source/index.rst index 0992ecf..a060e6d 100644 --- a/doc/source/index.rst +++ b/doc/source/index.rst @@ -11,6 +11,7 @@ Welcome to OidcEndpoint's documentation! :caption: Contents: oidcendpoint + session_manager Indices and tables ================== diff --git a/doc/source/oidcendpoint.rst b/doc/source/oidcendpoint.rst index 203701d..6c359a2 100644 --- a/doc/source/oidcendpoint.rst +++ b/doc/source/oidcendpoint.rst @@ -62,6 +62,14 @@ oidcendpoint\.id_token module :undoc-members: :show-inheritance: +oidcendpoint\.grant module +-------------------------- + +.. automodule:: oidcendpoint.grant + :members: + :undoc-members: + :show-inheritance: + oidcendpoint\.in_memory_db module --------------------------------- @@ -86,18 +94,26 @@ oidcendpoint\.login_hint module :undoc-members: :show-inheritance: -oidcendpoint\.session module ----------------------------- +oidcendpoint\.scopes module +--------------------------- + +.. automodule:: oidcendpoint.scopes + :members: + :undoc-members: + :show-inheritance: + +oidcendpoint\.session_management module +--------------------------------------- -.. automodule:: oidcendpoint.session +.. automodule:: oidcendpoint.session_management :members: :undoc-members: :show-inheritance: -oidcendpoint\.sso_db module ---------------------------- +oidcendpoint\.session_storage module +------------------------------------ -.. automodule:: oidcendpoint.sso_db +.. automodule:: oidcendpoint.session_management :members: :undoc-members: :show-inheritance: diff --git a/doc/source/session_manager.rst b/doc/source/session_manager.rst new file mode 100644 index 0000000..403e041 --- /dev/null +++ b/doc/source/session_manager.rst @@ -0,0 +1,536 @@ +Session Management +================== + +- `About session management`_ + - `Design criteria`_ + - `Database layout`_ +- `The information structure`_ + - `Session key`_ + - `User session information`_ + - `Client session information`_ + - `Grant information`_ + - `Token`_ +- `Session Info API`_ +- `Grant API`_ +- `Token API`_ + +- `Session Manager API`_ + - `create_session`_ + - `add_grant`_ + - `find_token`_ + - `get_authentication_event`_ + - `get_client_session_info`_ + - `get_grant_by_response_type`_ + - `get_session_info`_ + - `get_session_info_by_token`_ + - `get_sids_by_user_id`_ + - `get_user_info`_ + - `grants`_ + - `revoke_client_session`_ + - `revoke_grant`_ + - `revoke_token`_ + + +About session management +------------------------ +.. _`About session management`: + +The OIDC Session Management draft defines session to be: + + Continuous period of time during which an End-User accesses a Relying + Party relying on the Authentication of the End-User performed by the + OpenID Provider. + +Note that we are dealing with a Single Sign On (SSO) context here. +If for some reason the OP does not want to support SSO then the +session management has to be done a bit differently. In that case each +session (user_id,client_id) would have its own authentication even. Not one +shared between the sessions. + +Design criteria ++++++++++++++++ +.. _`Design criteria`: + +So a session is defined by a user and a Relying Party. If one adds to that +that a user can have several sessions active at the same time each one against +a unique Relying Party we have the bases for session management. + +Furthermore the user may well decide on different rules for different +relying parties for releasing user +attributes, where and how issued access tokens could be used and whether +refresh tokens should be issued or not. + +We also need to keep track on which tokens where used to mint new tokens +such that we can easily revoked a suite of tokens all with a common ancestor. + +Database layout ++++++++++++++++ +.. _`Database layout`: + +The database is organized in 3 levels. The top one being the users. +Below that the Relying Parties and at the bottom what is called grants. + +Grants organize authorization codes, access tokens and refresh tokens (and +possibly other types of tokens) in a comprehensive way. More about that below. + +There may be many Relying Parties below a user and many grants below a +Relying Party. + +The information structure +------------------------- +.. _`The information structure`: + +As stated above there are 3 layers: user session information, client session +information and grants. But first the keys to the information. + +Session key ++++++++++++ +.. _`Session key`: + +A key to the session information is based on a list. The first item being the +user identifier, the second the client identifier and the third the grant +identifier. +If you only want the user session information then the key is a list with one +item the user id. If you want the client session information the key is a +list with 2 items (user_id, client_id). And lastly if you want a grant then +the key is a list with 3 elements (user_id, client_id, grant_id). + +A *session identifier* is constructed using the **session_key** function. +It takes as input the 3 elements list.:: + + session_id = session_key(user_id, client_id, grant_id) + + +Using the function **unpack_session_key** you can get the elements from a +session_id.:: + + user_id, client_id, grant_id = unpack_session_id(session_id) + + +User session information +++++++++++++++++++++++++ +.. _`User session information`: + +Houses the authentication event information which is the same for all session +connected to a user. +Here we also have a list of all the clients that this user has a session with. +Expressed as a dictionary this can look like this:: + + { + 'authentication_event': { + 'uid': 'diana', + 'authn_info': "urn:oasis:names:tc:SAML:2.0:ac:classes:InternetProtocolPassword", + 'authn_time': 1605515787, + 'valid_until': 1605519387 + }, + 'subordinate': ['client_1'] + } + + +Client session information +++++++++++++++++++++++++++ +.. _`Client session information`: + +The client specific information of the session information. +Presently only the authorization request and the subject identifier (sub). +The subordinates to this set of information are the grants:: + + { + 'authorization_request':{ + 'client_id': 'client_1', + 'redirect_uri': 'https://example.com/cb', + 'scope': ['openid', 'research_and_scholarship'], + 'state': 'STATE', + 'response_type': ['code'] + }, + 'sub': '117afe8d7bb0ace8e7fb2706034ab2d3fbf17f0fd4c949aa9c23aedd051cc9e3', + 'subordinate': ['e996c61227e711eba173acde48001122'], + 'revoked': False + } + +Grant information ++++++++++++++++++ +.. _`Grant information`: + +Grants are created by an authorization subsystem in an OP. If the grant is +created in connection with an user authentication the authorization system +might normally ask the user for usage consent and then base the construction +of the grant on that consent. + +If an authorization server can act as a Security Token Service (STS) as +defined by https://tools.ietf.org/html/draft-ietf-oauth-token-exchange-16 +then no user is involved. In the context of session management the STS is +equivalent to a user. + +Grant information contains information about user consent and issued tokens.:: + + { + "type": "grant", + "scope": ["openid", "research_and_scholarship"], + "authorization_details": null, + "claims": { + "userinfo": { + "sub": null, + "name": null, + "given_name": null, + "family_name": null, + "email": null, + "email_verified": null, + "eduperson_scoped_affiliation": null + } + }, + "resources": ["client_1"], + "issued_at": 1605452123, + "not_before": 0, + "expires_at": 0, + "revoked": false, + "issued_token": [ + { + "type": "authorization_code", + "issued_at": 1605452123, + "not_before": 0, + "expires_at": 1605452423, + "revoked": false, + "value": "Z0FBQUFBQmZzVUZieDFWZy1fbjE2ckxvZkFTVC1ZTHJIVlk0Z09tOVk1M0RsOVNDbkdfLTIxTUhILWs4T29kM1lmV015UEN1UGxrWkxLTkVXOEg1WVJLNjh3MGlhMVdSRWhYcUY4cGdBQkJEbzJUWUQ3UGxTUWlJVDNFUHFlb29PWUFKcjNXeHdRM1hDYzRIZnFrYjhVZnIyTFhvZ2Y0NUhjR1VBdzE0STVEWmJ3WkttTk1OYXQtTHNtdHJwYk1nWnl3MUJqSkdWZGFtdVNfY21VNXQxY3VzalpIczBWbGFueVk0TVZ2N2d2d0hVWTF4WG56TDJ6bz0=", + "usage_rules": { + "expires_in": 300, + "supports_minting": [ + "access_token", + "refresh_token", + "id_token" + ], + "max_usage": 1 + }, + "used": 0, + "based_on": null, + "id": "96d19bea275211eba43bacde48001122" + }, + { + "type": "access_token", + "issued_at": 1605452123, + "not_before": 0, + "expires_at": 1605452723, + "revoked": false, + "value": "Z0FBQUFBQmZzVUZiaWVRbi1IS2k0VW4wVDY1ZmJHeEVCR1hVODBaQXR6MWkzelNBRFpOS2tRM3p4WWY5Y1J6dk5IWWpnelRETGVpSG52b0d4RGhjOWphdWp4eW5xZEJwQzliaS16cXFCcmRFbVJqUldsR1Z3SHdTVVlWbkpHak54TmJaSTV2T3NEQ0Y1WFkxQkFyamZHbmd4V0RHQ3k1MVczYlYwakEyM010SGoyZk9tUVVxbWdYUzBvMmRRNVlZMUhRSnM4WFd2QzRkVmtWNVJ1aVdJSXQyWnpVTlRiZnMtcVhKTklGdzBzdDJ3RkRnc1A1UEw2Yz0=", + "usage_rules": { + "expires_in": 600, + }, + "used": 0, + "based_on": "Z0FBQUFBQmZzVUZieDFWZy1fbjE2ckxvZkFTVC1ZTHJIVlk0Z09tOVk1M0RsOVNDbkdfLTIxTUhILWs4T29kM1lmV015UEN1UGxrWkxLTkVXOEg1WVJLNjh3MGlhMVdSRWhYcUY4cGdBQkJEbzJUWUQ3UGxTUWlJVDNFUHFlb29PWUFKcjNXeHdRM1hDYzRIZnFrYjhVZnIyTFhvZ2Y0NUhjR1VBdzE0STVEWmJ3WkttTk1OYXQtTHNtdHJwYk1nWnl3MUJqSkdWZGFtdVNfY21VNXQxY3VzalpIczBWbGFueVk0TVZ2N2d2d0hVWTF4WG56TDJ6bz0=", + "id": "96d1c840275211eba43bacde48001122" + } + ], + "id": "96d16d3c275211eba43bacde48001122" + } + +The parameters are described below + +scope +::::: + +This is the scope that was chosen for this grant. Either by the user or by +some rules that the Authorization Server runs by. + +authorization_details +::::::::::::::::::::: + +Presently a place hold. But this is expected to be information on how the +authorization was performed. What input was used and so on. + +claims +:::::: + +The set of claims that should be returned in different circumstances. The +syntax that is defined in +https://openid.net/specs/openid-connect-core-1_0.html#ClaimsParameter +is used. With one addition, beside *userinfo* and *id_token* we have added +*introspection*. + +resources +::::::::: + +This are the resource servers and other entities that should be accepted +as users of issued access tokens. + +issued_at +::::::::: + +When the grant was created. Its value is a JSON number representing the number +of seconds from 1970-01-01T0:0:0Z as measured in UTC until the date/time. + +not_before +:::::::::: +If the usage of the grant should be delay, this is when it can start being used. +Its value is a JSON number representing the number +of seconds from 1970-01-01T0:0:0Z as measured in UTC until the date/time. + +expires_at +:::::::::: +When the grant expires. +Its value is a JSON number representing the number +of seconds from 1970-01-01T0:0:0Z as measured in UTC until the date/time. + +revoked +::::::: +If the grant has been revoked. + +issued_token +:::::::::::: +Tokens that has been issued based on this grant. There is no limitation +as to which tokens can be issued. Though presently we only have: + +- authorization_code, +- access_token and +- refresh_token + +id +:: +The grant identifier. + +Token ++++++ +.. _`Token`: + +As mention above there are presently only 3 token types that are defined: + +- authorization_code, +- access_token and +- refresh_token + +A token is described as follows:: + + { + "type": "authorization_code", + "issued_at": 1605452123, + "not_before": 0, + "expires_at": 1605452423, + "revoked": false, + "value": "Z0FBQUFBQmZzVUZieDFWZy1fbjE2ckxvZkFTVC1ZTHJIVlk0Z09tOVk1M0RsOVNDbkdfLTIxTUhILWs4T29kM1lmV015UEN1UGxrWkxLTkVXOEg1WVJLNjh3MGlhMVdSRWhYcUY4cGdBQkJEbzJUWUQ3UGxTUWlJVDNFUHFlb29PWUFKcjNXeHdRM1hDYzRIZnFrYjhVZnIyTFhvZ2Y0NUhjR1VBdzE0STVEWmJ3WkttTk1OYXQtTHNtdHJwYk1nWnl3MUJqSkdWZGFtdVNfY21VNXQxY3VzalpIczBWbGFueVk0TVZ2N2d2d0hVWTF4WG56TDJ6bz0=", + "usage_rules": { + "expires_in": 300, + "supports_minting": [ + "access_token", + "refresh_token", + "id_token" + ], + "max_usage": 1 + }, + "used": 0, + "based_on": null, + "id": "96d19bea275211eba43bacde48001122" + } + + +type +:::: +The type of token. + +issued_at +::::::::: +When the token was created. Its value is a JSON number representing the number +of seconds from 1970-01-01T0:0:0Z as measured in UTC until the date/time. + +not_before +:::::::::: +If the start of the usage of the token is to be delay, this is until when. +Its value is a JSON number representing the number +of seconds from 1970-01-01T0:0:0Z as measured in UTC until the date/time. + +expires_at +:::::::::: +When the token expires. +Its value is a JSON number representing the number +of seconds from 1970-01-01T0:0:0Z as measured in UTC until the date/time. + +revoked +::::::: +If the token has been revoked. + +value +::::: +This is the value that appears in OIDC protocol exchanges. + +usage_rules +::::::::::: +Rules as to how this token can be used: + +expires_in + Used to calculate expires_at + +supports_minting + The tokens types that can be minted based on this token. Typically a code + can be used to mint ID tokens and access and refresh tokens. + +max_usage + How many times this token can be used (being used is presently defined as + used to mint other tokens). An authorization_code token can according to + the OIDC standard only be used once but then to, in the same session, + mint more then one token. + +used +:::: +How many times the token has been used + +based_on +:::::::: +Reference to the token that was used to mint this token. Might be empty if the +token was minted based on the grant it belongs to. + +id +:: +Token identifier + +Session Info API +---------------- +.. _`Session Info API`: + +add_subordinate ++++++++++++++++ +.. _`add_subordinate`: + +remove_subordinate +++++++++++++++++++ +.. _`removed_subordinate`: + +revoke +++++++ +.. _`revoke`: + +is_revoked +++++++++++ +.. _`is_revoked`: + +to_json ++++++++ +.. _`to_json`: + +from_json ++++++++++ +.. _`from_json`: + +Grant API +--------- +.. _`Grant API`: + +Token API +--------- +.. _`Token API`: + +Session Manager API +------------------- +.. _`Session Manager API`: + +create_session +++++++++++++++ +.. _create_session: + +Creating a new session is done by running the create_session method of +the class SessionManager. The create_session methods takes the following +arguments. + +authn_event + An AuthnEvent class instance that describes the authentication event. + +auth_req + The Authentication request + +client_id + The client Identifier + +user_id + The user identifier + +sector_identifier + A possible sector identifier to be used when constructing a pairwise + subject identifier + +sub_type + The type of subject identifier that should be constructed. It can either be + *pairwise* or *public*. + +So a typical command would look like this:: + + + authn_event = create_authn_event(self.user_id) + session_manager.create_session(authn_event=authn_event, auth_req=auth_req, + user_id=self.user_id, client_id=client_id, + sub_type=sub_type, sector_identifier=sector_identifier) + +add_grant ++++++++++ +.. _add_grant: + +add_grant(self, user_id, client_id, **kwargs) + +find_token +++++++++++ +.. _find_token: + +find_token(self, session_id, token_value) + +get_authentication_event +++++++++++++++++++++++++ +.. _get_authentication_event: + +get_authentication_event(self, session_id) + + +get_client_session_info ++++++++++++++++++++++++ +.. _get_client_session_info: + +get_client_session_info(self, session_id) + +get_grant_by_response_type +++++++++++++++++++++++++++ +.. _get_grant_by_response_type: + +get_grant_by_response_type(self, user_id, client_id) + +get_session_info +++++++++++++++++ +.. _get_session_info: + +get_session_info(self, session_id) + +get_session_info_by_token ++++++++++++++++++++++++++ +.. _get_session_info_by_token: + +get_session_info_by_token(self, token_value) + +get_sids_by_user_id ++++++++++++++++++++ +.. _get_sids_by_user_id: + +get_sids_by_user_id(self, user_id) + +get_user_info ++++++++++++++ +.. _get_user_info: + +get_user_info(self, uid) + +grants +++++++ +.. _grants: + +grants(self, session_id) + +revoke_client_session ++++++++++++++++++++++ +.. _revoke_client_session: + +revoke_client_session(self, session_id) + +revoke_grant +++++++++++++ +.. _revoke_grant: + +revoke_grant(self, session_id) + +revoke_token +++++++++++++ +.. _revoke_token: + +revoke_token(self, session_id, token_value, recursive=False) \ No newline at end of file diff --git a/setup.py b/setup.py index 40eb21d..daa1506 100755 --- a/setup.py +++ b/setup.py @@ -52,7 +52,7 @@ def run_tests(self): packages=["oidcendpoint", 'oidcendpoint/oidc', 'oidcendpoint/authz', 'oidcendpoint/user_authn', 'oidcendpoint/user_info', 'oidcendpoint/oauth2', 'oidcendpoint/oidc/add_on', - 'oidcendpoint/common'], + 'oidcendpoint/session', 'oidcendpoint/token'], package_dir={"": "src"}, classifiers=[ "Development Status :: 4 - Beta", diff --git a/src/oidcendpoint/__init__.py b/src/oidcendpoint/__init__.py index 653aef2..2bb7941 100755 --- a/src/oidcendpoint/__init__.py +++ b/src/oidcendpoint/__init__.py @@ -1,7 +1,7 @@ import string from secrets import choice -__version__ = "1.1.1" +__version__ = '2.0.0' DEF_SIGN_ALG = { "id_token": "RS256", diff --git a/src/oidcendpoint/authn_event.py b/src/oidcendpoint/authn_event.py index 952078f..d309195 100644 --- a/src/oidcendpoint/authn_event.py +++ b/src/oidcendpoint/authn_event.py @@ -1,20 +1,22 @@ from oidcmsg.message import SINGLE_OPTIONAL_INT +from oidcmsg.message import SINGLE_OPTIONAL_STRING from oidcmsg.message import SINGLE_REQUIRED_STRING from oidcmsg.message import Message from oidcmsg.time_util import time_sans_frac DEFAULT_AUTHN_EXPIRES_IN = 3600 + class AuthnEvent(Message): c_param = { "uid": SINGLE_REQUIRED_STRING, - "salt": SINGLE_REQUIRED_STRING, "authn_info": SINGLE_REQUIRED_STRING, "authn_time": SINGLE_OPTIONAL_INT, "valid_until": SINGLE_OPTIONAL_INT, + "sub": SINGLE_OPTIONAL_STRING } - def valid(self, now=0): + def is_valid(self, now=0): if now: return self["valid_until"] > now else: @@ -24,30 +26,42 @@ def expires_in(self): return self["valid_until"] - time_sans_frac() -def create_authn_event(uid, salt, authn_info=None, **kwargs): +def create_authn_event(uid, authn_info=None, authn_time: int = 0, + valid_until: int = 0, expires_in: int = 0, + sub: str = "", **kwargs): """ - :param uid: - :param salt: - :param authn_info: + :param uid: User ID. This is the identifier used by the user DB + :param authn_time: When the authentication took place + :param authn_info: Information about the authentication + :param valid_until: Until when the authentication is valid + :param expires_in: How long before the authentication expires + :param sub: Subject identifier. The identifier for the user used between + the AS and the RP. :param kwargs: :return: """ - args = {"uid": uid, "salt": salt, "authn_info": authn_info} + args = {"uid": uid, "authn_info": authn_info} + + if sub: + args["sub"] = sub - try: - args["authn_time"] = int(kwargs["authn_time"]) - except KeyError: - try: - args["authn_time"] = int(kwargs["timestamp"]) - except KeyError: + if authn_time: + args["authn_time"] = authn_time + else: + _ts = kwargs.get("timestamp") + if _ts: + args["authn_time"] = _ts + else: args["authn_time"] = time_sans_frac() - try: - args["valid_until"] = kwargs["valid_until"] - except KeyError: - _expires_in = kwargs.get("expires_in", DEFAULT_AUTHN_EXPIRES_IN) - args["valid_until"] = args["authn_time"] + _expires_in + if valid_until: + args["valid_until"] = valid_until + else: + if expires_in: + args["valid_until"] = args["authn_time"] + expires_in + else: + args["valid_until"] = args["authn_time"] + DEFAULT_AUTHN_EXPIRES_IN return AuthnEvent(**args) diff --git a/src/oidcendpoint/authz/__init__.py b/src/oidcendpoint/authz/__init__.py index 3869cc2..2277df8 100755 --- a/src/oidcendpoint/authz/__init__.py +++ b/src/oidcendpoint/authz/__init__.py @@ -1,9 +1,13 @@ +import copy import inspect import logging import sys +from typing import Optional +from typing import Union -from oidcendpoint import sanitize -from oidcendpoint.cookie import cookie_value +from oidcmsg.message import Message + +from oidcendpoint.session.grant import Grant logger = logging.getLogger(__name__) @@ -11,53 +15,89 @@ class AuthzHandling(object): """ Class that allow an entity to manage authorization """ - def __init__(self, endpoint_context, **kwargs): + def __init__(self, endpoint_context, grant_config=None, **kwargs): self.endpoint_context = endpoint_context self.cookie_dealer = endpoint_context.cookie_dealer - self.permdb = {} + self.grant_config = grant_config or {} self.kwargs = kwargs - def __call__(self, *args, **kwargs): - return "" + def usage_rules(self, client_id): + if "usage_rules" in self.grant_config: + _usage_rules = copy.deepcopy(self.grant_config["usage_rules"]) + else: + _usage_rules = {} - def set(self, uid, client_id, permission): try: - self.permdb[uid][client_id] = permission + _per_client = self.endpoint_context.cdb[client_id]["token_usage_rules"] except KeyError: - self.permdb[uid] = {client_id: permission} - - def permissions(self, cookie=None, **kwargs): - if cookie is None: - return None + pass else: - logger.debug("kwargs: %s" % sanitize(kwargs)) - - val = self.cookie_dealer.get_cookie_value(cookie) - if val is None: - return None + if _usage_rules: + for _token_type, _rule in _usage_rules.items(): + _pc = _per_client.get(_token_type) + if _pc: + _rule.update(_pc) + for _token_type, _rule in _per_client.items(): + if _token_type not in _usage_rules: + _usage_rules[_token_type] = _rule else: - b64, _ts, typ = val + _usage_rules = _per_client - info = cookie_value(b64) - return self.get(info["sub"], info["client_id"]) + return _usage_rules - def get(self, uid, client_id): + def usage_rules_for(self, client_id, token_type): + _token_usage = self.usage_rules(client_id=client_id) try: - return self.permdb[uid][client_id] + return _token_usage[token_type] except KeyError: - return None + return {} + + def __call__(self, session_id: str, request: Union[dict, Message], + resources: Optional[list] = None) -> Grant: + args = self.grant_config.copy() + + scope = request.get("scope") + if scope: + args["scope"] = scope + + claims = request.get("claims") + if claims: + if isinstance(request, Message): + claims = claims.to_dict() + args["claims"] = claims + + session_info = self.endpoint_context.session_manager.get_session_info( + session_id=session_id, grant=True + ) + grant = session_info["grant"] + + for key, val in args.items(): + if key == "expires_in": + grant.set_expires_at(val) + else: + setattr(grant, key, val) + if resources is None: + grant.resources = [session_info["client_id"]] + else: + grant.resources = resources -class Implicit(AuthzHandling): - def __init__(self, endpoint_context, permission="implicit"): - AuthzHandling.__init__(self, endpoint_context) - self.permission = permission + # This is where user consent should be handled + for interface in ["userinfo", "introspection", "id_token", "access_token"]: + grant.claims[interface] = self.endpoint_context.claims_interface.get_claims( + session_id=session_id, scopes=request["scope"], usage=interface + ) + return grant - def permissions(self, cookie=None, **kwargs): - return self.permission - def get(self, uid, client_id): - return self.permission +class Implicit(AuthzHandling): + def __call__(self, session_id: str, request: Union[dict, Message], + resources: Optional[list] = None) -> Grant: + args = self.grant_config.copy() + grant = self.endpoint_context.session_manager.get_grant(session_id=session_id) + for arg, val in args: + setattr(grant, arg, val) + return grant def factory(msgtype, endpoint_context, **kwargs): diff --git a/src/oidcendpoint/common/__init__.py b/src/oidcendpoint/common/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/src/oidcendpoint/common/authorization.py b/src/oidcendpoint/common/authorization.py deleted file mode 100755 index 06f0158..0000000 --- a/src/oidcendpoint/common/authorization.py +++ /dev/null @@ -1,276 +0,0 @@ -import logging -from urllib.parse import unquote -from urllib.parse import urlencode -from urllib.parse import urlparse - -from oidcmsg.exception import ParameterError -from oidcmsg.exception import URIError -from oidcmsg.oauth2 import AuthorizationErrorResponse -from oidcmsg.oidc import AuthorizationResponse -from oidcmsg.oidc import verified_claim_name - -from oidcendpoint import sanitize -from oidcendpoint.exception import RedirectURIError -from oidcendpoint.exception import UnknownClient -from oidcendpoint.util import split_uri - -logger = logging.getLogger(__name__) - -FORM_POST = """ - - Submit This Form - - -
- {inputs} -
- -""" - - -def inputs(form_args): - """ - Creates list of input elements - """ - element = [] - html_field = '' - for name, value in form_args.items(): - element.append(html_field.format(name, value)) - return "\n".join(element) - - -def max_age(request): - verified_request = verified_claim_name("request") - return request.get(verified_request, {}).get("max_age") or request.get("max_age", 0) - - -def verify_uri(endpoint_context, request, uri_type, client_id=None): - """ - A redirect URI - MUST NOT contain a fragment - MAY contain query component - - :param endpoint_context: An EndpointContext instance - :param request: The authorization request - :param uri_type: redirect_uri or post_logout_redirect_uri - :return: An error response if the redirect URI is faulty otherwise - None - """ - _cid = request.get("client_id", client_id) - - if not _cid: - logger.error("No client id found") - raise UnknownClient("No client_id provided") - else: - logger.debug("Client ID: {}".format(_cid)) - - _redirect_uri = unquote(request[uri_type]) - - part = urlparse(_redirect_uri) - if part.fragment: - raise URIError("Contains fragment") - - (_base, _query) = split_uri(_redirect_uri) - # if _query: - # _query = parse_qs(_query) - - match = False - # Get the clients registered redirect uris - client_info = endpoint_context.cdb.get(_cid, {}) - if not client_info: - raise KeyError("No such client") - logger.debug("Client info: {}".format(client_info)) - redirect_uris = client_info.get("{}s".format(uri_type)) - if not redirect_uris: - if _cid not in endpoint_context.cdb: - logger.debug("CIDs: {}".format(list(endpoint_context.cdb.keys()))) - raise KeyError("No such client") - raise ValueError("No registered {}".format(uri_type)) - else: - for regbase, rquery in redirect_uris: - # The URI MUST exactly match one of the Redirection URI - if _base == regbase: - # every registered query component must exist in the uri - if rquery: - if not _query: - raise ValueError("Missing query part") - - for key, vals in rquery.items(): - if key not in _query: - raise ValueError('"{}" not in query part'.format(key)) - - for val in vals: - if val not in _query[key]: - raise ValueError( - "{}={} value not in query part".format(key, val) - ) - - # and vice versa, every query component in the uri - # must be registered - if _query: - if not rquery: - raise ValueError("No registered query part") - - for key, vals in _query.items(): - if key not in rquery: - raise ValueError('"{}" extra in query part'.format(key)) - for val in vals: - if val not in rquery[key]: - raise ValueError( - "Extra {}={} value in query part".format(key, val) - ) - match = True - break - if not match: - raise RedirectURIError("Doesn't match any registered uris") - - -def join_query(base, query): - """ - - :param base: URL base - :param query: query part as a dictionary - :return: - """ - if query: - return "{}?{}".format(base, urlencode(query, doseq=True)) - else: - return base - - -def get_uri(endpoint_context, request, uri_type): - """ verify that the redirect URI is reasonable. - - :param endpoint_context: An EndpointContext instance - :param request: The Authorization request - :param uri_type: 'redirect_uri' or 'post_logout_redirect_uri' - :return: redirect_uri - """ - uri = "" - - if uri_type in request: - verify_uri(endpoint_context, request, uri_type) - uri = request[uri_type] - else: - uris = "{}s".format(uri_type) - client_id = str(request["client_id"]) - if client_id in endpoint_context.cdb: - _specs = endpoint_context.cdb[client_id].get(uris) - if not _specs: - raise ParameterError("Missing {} and none registered".format(uri_type)) - - if len(_specs) > 1: - raise ParameterError( - "Missing {} and more than one registered".format(uri_type) - ) - - uri = join_query(*_specs[0]) - - return uri - - -def authn_args_gather(request, authn_class_ref, cinfo, **kwargs): - """ - Gather information to be used by the authentication method - """ - authn_args = { - "authn_class_ref": authn_class_ref, - "query": request.to_urlencoded(), - "return_uri": request["redirect_uri"], - } - - if "req_user" in kwargs: - authn_args["as_user"] = (kwargs["req_user"],) - - # Below are OIDC specific. Just ignore if OAuth2 - for attr in ["policy_uri", "logo_uri", "tos_uri"]: - if cinfo.get(attr): - authn_args[attr] = cinfo[attr] - - for attr in ["ui_locales", "acr_values", "login_hint"]: - if request.get(attr): - authn_args[attr] = request[attr] - - return authn_args - - -def create_authn_response(endpoint, request, sid): - """ - - :param endpoint: - :param request: - :param sid: - :return: - """ - # create the response - aresp = AuthorizationResponse() - if request.get("state"): - aresp["state"] = request["state"] - - if "response_type" in request and request["response_type"] == ["none"]: - fragment_enc = False - else: - _context = endpoint.endpoint_context - _sinfo = _context.sdb[sid] - - if request.get("scope"): - aresp["scope"] = request["scope"] - - rtype = set(request["response_type"][:]) - handled_response_type = [] - - fragment_enc = True - if len(rtype) == 1 and "code" in rtype: - fragment_enc = False - - if "code" in request["response_type"]: - _code = aresp["code"] = _context.sdb[sid]["code"] - handled_response_type.append("code") - else: - _context.sdb.update(sid, code=None) - _code = None - - if "token" in rtype: - _dic = _context.sdb.upgrade_to_token(issue_refresh=False, key=sid) - - logger.debug("_dic: %s" % sanitize(_dic)) - for key, val in _dic.items(): - if key in aresp.parameters() and val is not None: - aresp[key] = val - - handled_response_type.append("token") - - _access_token = aresp.get("access_token", None) - - not_handled = rtype.difference(handled_response_type) - if not_handled: - resp = AuthorizationErrorResponse( - error="invalid_request", error_description="unsupported_response_type" - ) - return {"response_args": resp, "fragment_enc": fragment_enc} - - return {"response_args": aresp, "fragment_enc": fragment_enc} - - -class AllowedAlgorithms: - def __init__(self, algorithm_parameters): - self.algorithm_parameters = algorithm_parameters - - def __call__(self, client_id, endpoint_context, alg, alg_type): - _cinfo = endpoint_context.cdb[client_id] - _pinfo = endpoint_context.provider_info - - _reg, _sup = self.algorithm_parameters[alg_type] - _allowed = _cinfo.get(_reg) - if _allowed is None: - _allowed = _pinfo.get(_sup) - - if alg not in _allowed: - logger.error( - "Signing alg user: {} not among allowed: {}".format(alg, _allowed) - ) - raise ValueError("Not allowed '%s' algorithm used", alg) - - -def re_authenticate(request, authn): - return False diff --git a/src/oidcendpoint/cookie.py b/src/oidcendpoint/cookie.py index bbbde9f..fa6b6cc 100755 --- a/src/oidcendpoint/cookie.py +++ b/src/oidcendpoint/cookie.py @@ -202,7 +202,7 @@ def make_cookie_content( :param max_age: The time in seconds for when a cookie will be deleted :type max_age: int :param secure: A secure cookie is only sent to the server with an encrypted request over the - HTTPS protocol. + HTTPS protocol. :type secure: boolean :param http_only: HttpOnly cookies are inaccessible to JavaScript's Document.cookie API :type http_only: boolean @@ -591,4 +591,7 @@ def new_cookie(endpoint_context, cookie_name=None, typ="sso", **kwargs): def cookie_value(b64): - return json.loads(as_unicode(b64d(as_bytes(b64)))) + try: + return json.loads(as_unicode(b64d(as_bytes(b64)))) + except Exception: + return b64 diff --git a/src/oidcendpoint/endpoint.py b/src/oidcendpoint/endpoint.py index b853c22..58a882f 100755 --- a/src/oidcendpoint/endpoint.py +++ b/src/oidcendpoint/endpoint.py @@ -1,20 +1,22 @@ import logging from functools import cmp_to_key +from typing import Optional +from typing import Union from urllib.parse import urlparse from cryptojwt import jwe from cryptojwt.jws.jws import SIGNER_ALGS -from oidcendpoint.token_handler import UnknownToken from oidcmsg.exception import MissingRequiredAttribute from oidcmsg.exception import MissingRequiredValue from oidcmsg.message import Message -from oidcmsg.oauth2 import ResponseMessage, AuthorizationErrorResponse +from oidcmsg.oauth2 import ResponseMessage from oidcendpoint import sanitize from oidcendpoint.client_authn import UnknownOrNoAuthnMethod from oidcendpoint.client_authn import client_auth_setup from oidcendpoint.client_authn import verify_client from oidcendpoint.exception import UnAuthorizedClient +from oidcendpoint.token.exception import UnknownToken from oidcendpoint.util import OAUTH2_NOCACHE_HEADERS __author__ = "Roland Hedberg" @@ -38,24 +40,25 @@ - post_construct (*) - update_http_args -do_response returns a dictionary that can look like this: -{ - 'response': - _response as a string or as a Message instance_ - 'http_headers': [ - ('Content-type', 'application/json'), - ('Pragma', 'no-cache'), - ('Cache-Control', 'no-store') - ], - 'cookie': _list of cookies_, - 'response_placement': 'body' -} +do_response returns a dictionary that can look like this:: + + { + 'response': + _response as a string or as a Message instance_ + 'http_headers': [ + ('Content-type', 'application/json'), + ('Pragma', 'no-cache'), + ('Cache-Control', 'no-store') + ], + 'cookie': _list of cookies_, + 'response_placement': 'body' + } "response" MUST be present "http_headers" MAY be present "cookie": MAY be present "response_placement": If absent defaults to the endpoints response_placement - parameter value or if that is also missing 'url' +parameter value or if that is also missing 'url' """ @@ -118,8 +121,8 @@ def construct_endpoint_info(default_capabilities, **kwargs): elif "signing_alg_values_supported" in attr: _info[attr] = assign_algorithms("signing_alg") if ( - attr - == "token_endpoint_auth_signing_alg_values_supported" + attr + == "token_endpoint_auth_signing_alg_values_supported" ): # none must not be in # token_endpoint_auth_signing_alg_values_supported @@ -191,7 +194,7 @@ def __init__(self, endpoint_context, **kwargs): if _methods: self.client_authn_method = client_auth_setup(_methods, endpoint_context) elif ( - _methods is not None + _methods is not None ): # [] or '' or something not None but regarded as nothing. self.client_authn_method = [None] # Ignore default value elif self.default_capabilities: @@ -205,8 +208,9 @@ def __init__(self, endpoint_context, **kwargs): # This is for matching against aud in JWTs # By default the endpoint's endpoint URL is an allowed target self.allowed_targets = [self.name] + self.client_verification_method = [] - def parse_request(self, request, auth=None, **kwargs): + def parse_request(self, request: str, auth=None, **kwargs): """ :param request: The request the server got @@ -307,9 +311,9 @@ def client_authentication(self, request, auth=None, **kwargs): LOGGER.debug("authn_info: %s", authn_info) if ( - authn_info == {} - and self.client_authn_method - and len(self.client_authn_method) + authn_info == {} + and self.client_authn_method + and len(self.client_authn_method) ): LOGGER.debug("client_authn_method: %s", self.client_authn_method) raise UnAuthorizedClient("Authorization failed") @@ -341,7 +345,7 @@ def do_post_construct(self, response_args, request, **kwargs): return response_args - def process_request(self, request=None, **kwargs): + def process_request(self, request: Optional[Union[Message, dict]] = None, **kwargs): """ :param request: The request, can be in a number of formats @@ -349,7 +353,10 @@ def process_request(self, request=None, **kwargs): """ return {} - def construct(self, response_args, request, **kwargs): + def construct(self, + response_args: Optional[dict] = None, + request: Optional[Union[Message, dict]] = None, + **kwargs): """ Construct the response @@ -365,12 +372,20 @@ def construct(self, response_args, request, **kwargs): return self.do_post_construct(response, request, **kwargs) - def response_info(self, response_args, request, **kwargs): + def response_info(self, + response_args: Optional[dict] = None, + request: Optional[Union[Message, dict]] = None, + **kwargs) -> dict: return self.construct(response_args, request, **kwargs) - def do_response(self, response_args=None, request=None, error="", **kwargs): + def do_response(self, + response_args: Optional[dict] = None, + request: Optional[Union[Message, dict]] = None, + error: Optional[str] = "", **kwargs) -> dict: """ - + :param response_args: Information to use when constructing the response + :param request: The original request + :param error: Possible error encountered while processing the request """ do_placement = True content_type = "text/html" diff --git a/src/oidcendpoint/endpoint_context.py b/src/oidcendpoint/endpoint_context.py index 10b6430..2abfa8c 100755 --- a/src/oidcendpoint/endpoint_context.py +++ b/src/oidcendpoint/endpoint_context.py @@ -5,17 +5,15 @@ from jinja2 import Environment from jinja2 import FileSystemLoader from oidcmsg.context import OidcContext -from oidcmsg.oidc import IdToken from oidcendpoint import authz from oidcendpoint import rndstr from oidcendpoint.id_token import IDToken from oidcendpoint.scopes import SCOPE2CLAIMS -from oidcendpoint.scopes import STANDARD_CLAIMS -from oidcendpoint.scopes import Claims from oidcendpoint.scopes import Scopes -from oidcendpoint.session import create_session_db -from oidcendpoint.sso_db import SSODb +from oidcendpoint.session.claims import STANDARD_CLAIMS +from oidcendpoint.session.claims import ClaimsInterface +from oidcendpoint.session.manager import create_session_manager from oidcendpoint.template_handler import Jinja2TemplateHandler from oidcendpoint.user_authn.authn_context import populate_authn_broker from oidcendpoint.util import allow_refresh_token @@ -87,34 +85,28 @@ def get_token_handlers(conf): class EndpointContext(OidcContext): def __init__( - self, - conf, - keyjar=None, - cwd="", - cookie_dealer=None, - httpc=None, - cookie_name=None, - jwks_uri_path=None, + self, + conf, + keyjar=None, + cwd="", + cookie_dealer=None, + httpc=None, ): OidcContext.__init__(self, conf, keyjar, entity_id=conf.get("issuer", "")) self.conf = conf # For my Dev environment - self.sso_db = None - self.session_db = None - self.state_db = None self.cdb = None self.jti_db = None self.registration_access_token = None + self.session_db = None self.add_boxes( { - "state": "state_db", "client": "cdb", "jti": "jti_db", "registration_access_token": "registration_access_token", - "sso": "sso_db", - "session": "session_db", + "session": "session_db" }, self.db_conf, ) @@ -134,7 +126,7 @@ def __init__( self.jwks_uri = None self.sso_ttl = 14400 # 4h self.symkey = rndstr(24) - self.id_token_schema = IdToken + # self.id_token_schema = IdToken self.idtoken = None self.authn_broker = None self.authz = None @@ -152,7 +144,7 @@ def __init__( "sso_ttl", "symkey", "client_authn", - "id_token_schema", + # "id_token_schema", ]: try: setattr(self, param, conf[param]) @@ -172,9 +164,7 @@ def __init__( # has to be after the above self.set_session_db() - if cookie_name: - self.cookie_name = cookie_name - elif "cookie_name" in conf: + if "cookie_name" in conf: self.cookie_name = conf["cookie_name"] else: self.cookie_name = { @@ -196,11 +186,7 @@ def __init__( self.template_handler = Jinja2TemplateHandler(loader) self.setup = {} - if not jwks_uri_path: - try: - jwks_uri_path = conf["keys"]["uri_path"] - except KeyError: - pass + jwks_uri_path = conf["keys"]["uri_path"] try: if self.issuer.endswith("/"): @@ -248,7 +234,7 @@ def __init__( self.httpc_params = {"verify": conf.get("verify_ssl")} self.set_scopes_handler() - self.set_claims_handler() + # self.set_claims_handler() # If pushed authorization is supported if "pushed_authorization_request_endpoint" in self.provider_info: @@ -260,6 +246,8 @@ def __init__( self.dev_auth_db = None self.add_boxes({"dev_auth": "dev_auth_db"}, self.db_conf) + self.claims_interface = ClaimsInterface(self) + def set_scopes_handler(self): _spec = self.conf.get("scopes_handler") if _spec: @@ -269,20 +257,19 @@ def set_scopes_handler(self): else: self.scopes_handler = Scopes() - def set_claims_handler(self): - _spec = self.conf.get("claims_handler") - if _spec: - _kwargs = _spec.get("kwargs", {}) - _cls = importer(_spec["class"])(**_kwargs) - self.claims_handler = _cls(_kwargs) - else: - self.claims_handler = Claims() + # def set_claims_handler(self): + # _spec = self.conf.get("claims_handler") + # if _spec: + # _kwargs = _spec.get("kwargs", {}) + # _cls = importer(_spec["class"])(**_kwargs) + # self.claims_handler = _cls(_kwargs) + # else: + # self.claims_handler = Claims() def set_session_db(self): - self.do_session_db(SSODb(db=self.sso_db), self.session_db) + self.do_session_manager() # append userinfo db to the session db self.do_userinfo() - logger.debug("Session DB: {}".format(self.sdb.__dict__)) def do_add_on(self): if self.conf.get("add_on"): @@ -311,9 +298,9 @@ def do_login_hint_lookup(self): def do_userinfo(self): _conf = self.conf.get("userinfo") if _conf: - if self.sdb: + if self.session_manager: self.userinfo = init_user_info(_conf, self.cwd) - self.sdb.userinfo = self.userinfo + self.session_manager.userinfo = self.userinfo else: logger.warning("Cannot init_user_info if no session_db was provided.") @@ -357,10 +344,15 @@ def do_sub_func(self): else: self._sub_func[key] = args["function"] - def do_session_db(self, sso_db, db=None): - self.sdb = create_session_db( - self, self.th_args, db=db, sso_db=sso_db, sub_func=self._sub_func - ) + def do_session_manager(self, db=None): + if self.session_db is None: + self.session_manager = create_session_manager( + self, self.th_args, db=db, sub_func=self._sub_func + ) + else: + self.session_manager = create_session_manager( + self, self.th_args, db=self.session_db, sub_func=self._sub_func + ) def do_endpoints(self): self.endpoint = build_endpoints( diff --git a/src/oidcendpoint/id_token.py b/src/oidcendpoint/id_token.py index 844d2d3..cbc4647 100755 --- a/src/oidcendpoint/id_token.py +++ b/src/oidcendpoint/id_token.py @@ -1,11 +1,13 @@ import logging +import uuid from cryptojwt.jws.utils import left_hash from cryptojwt.jwt import JWT from oidcendpoint.endpoint import construct_endpoint_info -from oidcendpoint.userinfo import collect_user_info -from oidcendpoint.userinfo import userinfo_in_id_token_claims +from oidcendpoint.session import unpack_session_key +from oidcendpoint.session.claims import claims_match +from oidcendpoint.session.info import SessionInfo logger = logging.getLogger(__name__) @@ -50,7 +52,7 @@ def include_session_id(endpoint_context, client_id, where): def get_sign_and_encrypt_algorithms( - endpoint_context, client_info, payload_type, sign=False, encrypt=False + endpoint_context, client_info, payload_type, sign=False, encrypt=False ): args = {"sign": sign, "encrypt": encrypt} if sign: @@ -112,49 +114,49 @@ class IDToken(object): def __init__(self, endpoint_context, **kwargs): self.endpoint_context = endpoint_context self.kwargs = kwargs - self.enable_claims_per_client = kwargs.get("enable_claims_per_client", False) - self.add_c_hash = kwargs.get("add_c_hash", True) - self.add_at_hash = kwargs.get("add_at_hash", True) self.scope_to_claims = None self.provider_info = construct_endpoint_info( self.default_capabilities, **kwargs ) def payload( - self, - session, - acr="", - alg="RS256", - code=None, - access_token=None, - user_info=None, - auth_time=0, - lifetime=None, - extra_claims=None, + self, + session_id, + alg="RS256", + code=None, + access_token=None, + extra_claims=None, ): """ - :param session: Session information - :param acr: Default Assurance/Authentication context class reference + :param session_id: Session identifier :param alg: Which signing algorithm to use for the IdToken :param code: Access grant :param access_token: Access Token - :param user_info: If user info are to be part of the IdToken - :param auth_time: - :param lifetime: Life time of the ID Token :param extra_claims: extra claims to be added to the ID Token :return: IDToken instance """ - _args = {"sub": session["sub"]} - - if lifetime is None: - lifetime = DEF_LIFETIME - - if auth_time: - _args["auth_time"] = auth_time - if acr: - _args["acr"] = acr + _mngr = self.endpoint_context.session_manager + session_information = _mngr.get_session_info(session_id, grant=True) + grant = session_information["grant"] + _args = {"sub": grant.sub} + if grant.authentication_event: + for claim, attr in {"authn_time": "auth_time", "authn_info": "acr"}.items(): + _val = grant.authentication_event.get(claim) + if _val: + _args[attr] = _val + + _claims_restriction = grant.claims.get("id_token") + if _claims_restriction == {}: + user_info = None + else: + user_info = self.endpoint_context.claims_interface.get_user_claims( + user_id=session_information["user_id"], + claims_restriction=_claims_restriction) + if _claims_restriction and "acr" in _claims_restriction and "acr" in _args: + if claims_match(_args["acr"], _claims_restriction["acr"]) is False: + raise ValueError("Could not match expected 'acr'") if user_info: try: @@ -178,46 +180,37 @@ def payload( halg = "HS%s" % alg[-3:] if code: _args["c_hash"] = left_hash(code.encode("utf-8"), halg) - elif self.add_c_hash and session.get("code"): - _args["c_hash"] = left_hash( - session.get("code").encode("utf-8"), halg - ) if access_token: _args["at_hash"] = left_hash(access_token.encode("utf-8"), halg) - elif self.add_at_hash and session.get("access_token"): - _args["at_hash"] = left_hash( - session.get("access_token").encode("utf-8"), halg - ) - authn_req = session["authn_req"] + authn_req = grant.authorization_request if authn_req: try: _args["nonce"] = authn_req["nonce"] except KeyError: pass - return {"payload": _args, "lifetime": lifetime} + return _args def sign_encrypt( - self, - session_info, - client_id, - code=None, - access_token=None, - user_info=None, - sign=True, - encrypt=False, - lifetime=None, - extra_claims=None, + self, + session_id, + client_id, + code=None, + access_token=None, + sign=True, + encrypt=False, + lifetime=None, + extra_claims=None, ): """ Signed and or encrypt a IDToken + :param lifetime: How long the ID Token should be valid :param session_info: Session information :param client_id: Client ID :param code: Access grant :param access_token: Access Token - :param user_info: User information :param sign: If the JWT should be signed :param encrypt: If the JWT should be encrypted :param extra_claims: Extra claims to be added to the ID Token @@ -231,69 +224,52 @@ def sign_encrypt( _cntx, client_info, "id_token", sign=sign, encrypt=encrypt ) - _authn_event = session_info["authn_event"] - - _idt_info = self.payload( - session_info, - acr=_authn_event["authn_info"], + _payload = self.payload( + session_id=session_id, alg=alg_dict["sign_alg"], code=code, access_token=access_token, - user_info=user_info, - auth_time=_authn_event["authn_time"], - lifetime=lifetime, - extra_claims=extra_claims, + extra_claims=extra_claims ) + if lifetime is None: + lifetime = DEF_LIFETIME + _jwt = JWT( - _cntx.keyjar, iss=_cntx.issuer, lifetime=_idt_info["lifetime"], **alg_dict + _cntx.keyjar, iss=_cntx.issuer, lifetime=lifetime, **alg_dict ) - return _jwt.pack(_idt_info["payload"], recv=client_id) + return _jwt.pack(_payload, recv=client_id) - def make(self, req, sess_info, authn_req=None, user_claims=False, **kwargs): + def make(self, session_id, **kwargs): _context = self.endpoint_context - if authn_req: - _client_id = authn_req["client_id"] + user_id, client_id, grant_id = unpack_session_key(session_id) + + # Should I add session ID. This is about Single Logout. + if include_session_id(_context, client_id, "back") or include_session_id( + _context, client_id, "front"): + + # Note that this session ID is not the session ID the session manager is using. + # It must be possible to map from one to the other. + logout_session_id = uuid.uuid4().get_hex() + _item = SessionInfo() + _item.set("user_id", user_id) + _item.set("client_id", client_id) + # Store the map + _mngr = self.endpoint_context.session_manager + _mngr.set([logout_session_id], _item) + # add identifier to extra arguments + xargs = {"sid": logout_session_id} else: - _client_id = req["client_id"] - - _cinfo = _context.cdb[_client_id] + xargs = {} - idtoken_claims = dict(self.kwargs.get("available_claims", {})) - if self.enable_claims_per_client: - idtoken_claims.update(_cinfo.get("id_token_claims", {})) lifetime = self.kwargs.get("lifetime") - userinfo = userinfo_in_id_token_claims(_context, sess_info, idtoken_claims) - - if user_claims: - info = collect_user_info(_context, sess_info) - if userinfo is None: - userinfo = info - else: - userinfo.update(info) - - # Should I add session ID - req_sid = include_session_id( - _context, _client_id, "back" - ) or include_session_id(_context, _client_id, "front") - - if req_sid: - xargs = { - "sid": _context.sdb.get_sid_by_sub_and_client_id( - sess_info["sub"], _client_id - ) - } - else: - xargs = {} - return self.sign_encrypt( - sess_info, - _client_id, + session_id, + client_id, sign=True, - user_info=userinfo, lifetime=lifetime, extra_claims=xargs, **kwargs diff --git a/src/oidcendpoint/jwt_token.py b/src/oidcendpoint/jwt_token.py deleted file mode 100644 index e3f4ec8..0000000 --- a/src/oidcendpoint/jwt_token.py +++ /dev/null @@ -1,173 +0,0 @@ -from typing import Any -from typing import Dict -from typing import Optional - -from cryptojwt import JWT -from cryptojwt.jws.exception import JWSException - -from oidcendpoint.exception import ToOld -from oidcendpoint.scopes import convert_scopes2claims -from oidcendpoint.token_handler import Token -from oidcendpoint.token_handler import UnknownToken -from oidcendpoint.token_handler import is_expired - - -class JWTToken(Token): - init_args = { - "add_claims_by_scope": False, - "enable_claims_per_client": False, - "add_scope": False, - "add_claims": {}, - } - - def __init__( - self, - typ, - keyjar=None, - issuer=None, - aud=None, - alg="ES256", - lifetime=300, - ec=None, - token_type="Bearer", - **kwargs - ): - Token.__init__(self, typ, **kwargs) - self.token_type = token_type - self.lifetime = lifetime - self.args = { - (k, v) for k, v in kwargs.items() if k not in self.init_args.keys() - } - - self.key_jar = keyjar or ec.keyjar - self.issuer = issuer or ec.issuer - self.cdb = ec.cdb - self.cntx = ec - - self.def_aud = aud or [] - self.alg = alg - self.scope_claims_map = kwargs.get("scope_claims_map", ec.scope2claims) - - self.add_claims = self.init_args["add_claims"] - self.add_claims_by_scope = self.init_args["add_claims_by_scope"] - self.add_scope = self.init_args["add_scope"] - self.enable_claims_per_client = self.init_args["enable_claims_per_client"] - - for param, default in self.init_args.items(): - setattr(self, param, kwargs.get(param, default)) - - def do_add_claims(self, payload, uinfo, claims): - for attr in claims: - if attr == "sub": - continue - try: - payload[attr] = uinfo[attr] - except KeyError: - pass - - def __call__( - self, - sid: str, - uinfo: Dict, - sinfo: Dict, - aud: Optional[Any], - client_id: Optional[str], - **kwargs - ): - """ - Return a token. - - :param sid: Session id - :param uinfo: User information - :param sinfo: Session information - :param aud: audience - :param client_id: client_id - :return: - """ - payload = {"sid": sid, "ttype": self.type, "sub": sinfo["sub"]} - scopes = sinfo["authn_req"]["scope"] - - if self.add_claims: - self.do_add_claims(payload, uinfo, self.add_claims) - if self.add_claims_by_scope: - _allowed_claims = self.cntx.claims_handler.allowed_claims( - client_id, self.cntx - ) - self.do_add_claims( - payload, - uinfo, - convert_scopes2claims( - scopes, - _allowed_claims, - map=self.scope_claims_map, - ).keys(), - ) - if self.add_scope: - payload["scope"] = self.cntx.scopes_handler.filter_scopes( - client_id, self.cntx, scopes - ) - - # Add claims if is access token - if self.type == "T" and self.enable_claims_per_client: - client = self.cdb.get(client_id, {}) - client_claims = client.get("access_token_claims") - if client_claims: - self.do_add_claims(payload, uinfo, client_claims) - - payload.update(kwargs) - signer = JWT( - key_jar=self.key_jar, - iss=self.issuer, - lifetime=self.lifetime, - sign_alg=self.alg, - ) - - if aud is None: - _aud = self.def_aud - else: - _aud = aud if isinstance(aud, list) else [aud] - _aud.extend(self.def_aud) - - return signer.pack(payload, aud=_aud) - - def info(self, token): - """ - Return type of Token (A=Access code, T=Token, R=Refresh token) and - the session id. - - :param token: A token - :return: tuple of token type and session id - """ - verifier = JWT(key_jar=self.key_jar, allowed_sign_algs=[self.alg]) - try: - _payload = verifier.unpack(token) - except JWSException: - raise UnknownToken() - - if is_expired(_payload["exp"]): - raise ToOld("Token has expired") - # All the token metadata - _res = { - "sid": _payload["sid"], - "type": _payload["ttype"], - "exp": _payload["exp"], - "handler": self, - } - return _res - - def is_expired(self, token, when=0): - """ - Evaluate whether the token has expired or not - - :param token: The token - :param when: The time against which to check the expiration - 0 means now. - :return: True/False - """ - verifier = JWT(key_jar=self.key_jar, allowed_sign_algs=[self.alg]) - _payload = verifier.unpack(token) - return is_expired(_payload["exp"], when) - - def gather_args(self, sid, sdb, udb): - _sinfo = sdb[sid] - return {} diff --git a/src/oidcendpoint/oauth2/authorization.py b/src/oidcendpoint/oauth2/authorization.py index e7896f8..d4b006e 100755 --- a/src/oidcendpoint/oauth2/authorization.py +++ b/src/oidcendpoint/oauth2/authorization.py @@ -1,27 +1,26 @@ import json import logging -import time +from typing import Union +from urllib.parse import unquote +from urllib.parse import urlencode +from urllib.parse import urlparse from cryptojwt import BadSyntax +from cryptojwt import as_unicode +from cryptojwt import b64d +from cryptojwt.jwe.exception import JWEException +from cryptojwt.jws.exception import NoSuitableSigningKeys from cryptojwt.utils import as_bytes -from cryptojwt.utils import as_unicode -from cryptojwt.utils import b64d from cryptojwt.utils import b64e -from oidcendpoint.token_handler import UnknownToken from oidcmsg import oauth2 from oidcmsg.exception import ParameterError -from oidcmsg.oidc import AuthorizationResponse +from oidcmsg.exception import URIError +from oidcmsg.message import Message from oidcmsg.oidc import verified_claim_name +from oidcmsg.time_util import utc_time_sans_frac from oidcendpoint import rndstr -from oidcendpoint import sanitize from oidcendpoint.authn_event import create_authn_event -from oidcendpoint.common.authorization import FORM_POST -from oidcendpoint.common.authorization import AllowedAlgorithms -from oidcendpoint.common.authorization import authn_args_gather -from oidcendpoint.common.authorization import get_uri -from oidcendpoint.common.authorization import inputs -from oidcendpoint.common.authorization import max_age from oidcendpoint.cookie import append_cookie from oidcendpoint.cookie import compute_session_state from oidcendpoint.cookie import new_cookie @@ -34,12 +33,14 @@ from oidcendpoint.exception import ToOld from oidcendpoint.exception import UnAuthorizedClientScope from oidcendpoint.exception import UnknownClient -from oidcendpoint.session import setup_session +from oidcendpoint.session import Revoked +from oidcendpoint.session import unpack_session_key +from oidcendpoint.token.exception import UnknownToken from oidcendpoint.user_authn.authn_context import pick_auth +from oidcendpoint.util import split_uri logger = logging.getLogger(__name__) - # For the time being. This is JAR specific and should probably be configurable. ALG_PARAMS = { "sign": [ @@ -56,9 +57,195 @@ ], } +FORM_POST = """ + + Submit This Form + + +
+ {inputs} +
+ +""" + + +def inputs(form_args): + """ + Creates list of input elements + """ + element = [] + html_field = '' + for name, value in form_args.items(): + element.append(html_field.format(name, value)) + return "\n".join(element) + + +def max_age(request): + verified_request = verified_claim_name("request") + return request.get(verified_request, {}).get("max_age") or request.get("max_age", 0) + + +def verify_uri(endpoint_context, request, uri_type, client_id=None): + """ + A redirect URI + MUST NOT contain a fragment + MAY contain query component + + :param endpoint_context: An EndpointContext instance + :param request: The authorization request + :param uri_type: redirect_uri or post_logout_redirect_uri + :return: An error response if the redirect URI is faulty otherwise + None + """ + _cid = request.get("client_id", client_id) + + if not _cid: + logger.error("No client id found") + raise UnknownClient("No client_id provided") + else: + logger.debug("Client ID: {}".format(_cid)) + + _redirect_uri = unquote(request[uri_type]) + + part = urlparse(_redirect_uri) + if part.fragment: + raise URIError("Contains fragment") + + (_base, _query) = split_uri(_redirect_uri) + # if _query: + # _query = parse_qs(_query) + + match = False + # Get the clients registered redirect uris + client_info = endpoint_context.cdb.get(_cid, {}) + if not client_info: + raise KeyError("No such client") + logger.debug("Client info: {}".format(client_info)) + redirect_uris = client_info.get("{}s".format(uri_type)) + if not redirect_uris: + if _cid not in endpoint_context.cdb: + logger.debug("CIDs: {}".format(list(endpoint_context.cdb.keys()))) + raise KeyError("No such client") + raise ValueError("No registered {}".format(uri_type)) + else: + for regbase, rquery in redirect_uris: + # The URI MUST exactly match one of the Redirection URI + if _base == regbase: + # every registered query component must exist in the uri + if rquery: + if not _query: + raise ValueError("Missing query part") + + for key, vals in rquery.items(): + if key not in _query: + raise ValueError('"{}" not in query part'.format(key)) + + for val in vals: + if val not in _query[key]: + raise ValueError( + "{}={} value not in query part".format(key, val) + ) + + # and vice versa, every query component in the uri + # must be registered + if _query: + if not rquery: + raise ValueError("No registered query part") + + for key, vals in _query.items(): + if key not in rquery: + raise ValueError('"{}" extra in query part'.format(key)) + for val in vals: + if val not in rquery[key]: + raise ValueError( + "Extra {}={} value in query part".format(key, val) + ) + match = True + break + if not match: + raise RedirectURIError("Doesn't match any registered uris") + + +def join_query(base, query): + """ + + :param base: URL base + :param query: query part as a dictionary + :return: + """ + if query: + return "{}?{}".format(base, urlencode(query, doseq=True)) + else: + return base + + +def get_uri(endpoint_context, request, uri_type): + """ verify that the redirect URI is reasonable. + + :param endpoint_context: An EndpointContext instance + :param request: The Authorization request + :param uri_type: 'redirect_uri' or 'post_logout_redirect_uri' + :return: redirect_uri + """ + uri = "" + + if uri_type in request: + verify_uri(endpoint_context, request, uri_type) + uri = request[uri_type] + else: + uris = "{}s".format(uri_type) + client_id = str(request["client_id"]) + if client_id in endpoint_context.cdb: + _specs = endpoint_context.cdb[client_id].get(uris) + if not _specs: + raise ParameterError("Missing {} and none registered".format(uri_type)) + + if len(_specs) > 1: + raise ParameterError( + "Missing {} and more than one registered".format(uri_type) + ) -def re_authenticate(request, authn): - return False + uri = join_query(*_specs[0]) + + return uri + + +def authn_args_gather(request, authn_class_ref, cinfo, **kwargs): + """ + Gather information to be used by the authentication method + + :param request: The request either as a dictionary or as a Message instance + :param authn_class_ref: Authentication class reference + :param cinfo: Client information + :param kwargs: Extra keyword arguments + :return: Authentication arguments + """ + authn_args = { + "authn_class_ref": authn_class_ref, + "return_uri": request["redirect_uri"], + } + + if isinstance(request, Message): + authn_args["query"] = request.to_urlencoded() + elif isinstance(request, dict): + authn_args["query"] = urlencode(request) + else: + ValueError("Wrong request format") + + if "req_user" in kwargs: + authn_args["as_user"] = (kwargs["req_user"],) + + # Below are OIDC specific. Just ignore if OAuth2 + if cinfo: + for attr in ["policy_uri", "logo_uri", "tos_uri"]: + if cinfo.get(attr): + authn_args[attr] = cinfo[attr] + + for attr in ["ui_locales", "acr_values", "login_hint"]: + if request.get(attr): + authn_args[attr] = request[attr] + + return authn_args def check_unknown_scopes_policy(request_info, cinfo, endpoint_context): @@ -69,7 +256,7 @@ def check_unknown_scopes_policy(request_info, cinfo, endpoint_context): # this prevents that authz would be released for unavailable scopes for scope in request_info['scope']: if op_capabilities.get('deny_unknown_scopes') and \ - scope not in client_allowed_scopes: + scope not in client_allowed_scopes: _msg = '{} requested an unauthorized scope ({})' logger.warning(_msg.format(cinfo['client_id'], scope)) @@ -100,27 +287,50 @@ class Authorization(Endpoint): def __init__(self, endpoint_context, **kwargs): Endpoint.__init__(self, endpoint_context, **kwargs) - # self.pre_construct.append(self._pre_construct) self.post_parse_request.append(self._do_request_uri) self.post_parse_request.append(self._post_parse_request) self.allowed_request_algorithms = AllowedAlgorithms(ALG_PARAMS) - # Has to be done elsewhere. To make sure things happen in order. - # self.scopes_supported = available_scopes(endpoint_context) def filter_request(self, endpoint_context, req): return req + def extra_response_args(self, aresp): + return aresp + def verify_response_type(self, request, cinfo): # Checking response types _registered = [set(rt.split(" ")) for rt in cinfo.get("response_types", [])] if not _registered: # If no response_type is registered by the client then we'll - # code which it the default according to the OIDC spec. + # use code. _registered = [{"code"}] # Is the asked for response_type among those that are permitted return set(request["response_type"]) in _registered + def mint_token(self, token_type, grant, session_id, based_on=None): + _mngr = self.endpoint_context.session_manager + usage_rules = grant.usage_rules.get(token_type, {}) + + token = grant.mint_token( + session_id=session_id, + endpoint_context=self.endpoint_context, + token_type=token_type, + based_on=based_on, + usage_rules=usage_rules + ) + + _exp_in = usage_rules.get("expires_in") + if isinstance(_exp_in, str): + _exp_in = int(_exp_in) + if _exp_in: + token.expires_at = utc_time_sans_frac() + _exp_in + + self.endpoint_context.session_manager.set(unpack_session_key(session_id), + grant) + + return token + def _do_request_uri(self, request, client_id, endpoint_context, **kwargs): _request_uri = request.get("request_uri") if _request_uri: @@ -144,35 +354,41 @@ def _do_request_uri(self, request, client_id, endpoint_context, **kwargs): if _registered: # Before matching remove a possible fragment _p = _request_uri.split("#") - if _p[0] not in _registered: + # ignore registered fragments for now. + if _p[0] not in [l[0] for l in _registered]: raise ValueError("A request_uri outside the registered") + # Fetch the request _resp = endpoint_context.httpc.get( _request_uri, **endpoint_context.httpc_params ) if _resp.status_code == 200: - args = {"keyjar": endpoint_context.keyjar} - request = self.request_cls().from_jwt(_resp.text, **args) + args = {"keyjar": endpoint_context.keyjar, "issuer": client_id} + _ver_request = self.request_cls().from_jwt(_resp.text, **args) self.allowed_request_algorithms( client_id, endpoint_context, - request.jws_header.get("alg", "RS256"), + _ver_request.jws_header.get("alg", "RS256"), "sign", ) - if request.jwe_header is not None: + if _ver_request.jwe_header is not None: self.allowed_request_algorithms( client_id, endpoint_context, - request.jws_header.get("alg"), + _ver_request.jws_header.get("alg"), "enc_alg", ) self.allowed_request_algorithms( client_id, endpoint_context, - request.jws_header.get("enc"), + _ver_request.jws_header.get("enc"), "enc_enc", ) - request[verified_claim_name("request")] = request + # The protected info overwrites the non-protected + for k, v in _ver_request.items(): + request[k] = v + + request[verified_claim_name("request")] = _ver_request else: raise ServiceError("Got a %s response", _resp.status) @@ -218,7 +434,7 @@ def _post_parse_request(self, request, client_id, endpoint_context, **kwargs): except (RedirectURIError, ParameterError, UnknownClient) as err: return self.error_cls( error="invalid_request", - error_description="{}: {}".format(err.__class__.__name__, err), + error_description="{}:{}".format(err.__class__.__name__, err), ) else: request["redirect_uri"] = redirect_uri @@ -292,15 +508,25 @@ def setup_auth(self, request, redirect_uri, cinfo, cookie, acr=None, **kwargs): else: identity = json.loads(as_unicode(_id)) - session = self.endpoint_context.sdb[identity.get("sid")] - if not session or "revoked" in session: + try: + _csi = self.endpoint_context.session_manager[identity.get("sid")] + except Revoked: identity = None + else: + if _csi.is_active() is False: + identity = None authn_args = authn_args_gather(request, authn_class_ref, cinfo, **kwargs) + _mngr = self.endpoint_context.session_manager + _session_id = "" # To authenticate or Not if identity is None: # No! logger.info("No active authentication") + logger.debug( + "Known clients: {}".format(list(self.endpoint_context.cdb.keys())) + ) + if "prompt" in request and "none" in request["prompt"]: # Need to authenticate but not allowed return { @@ -319,14 +545,7 @@ def setup_auth(self, request, redirect_uri, cinfo, cookie, acr=None, **kwargs): # I get back a dictionary user = identity["uid"] if "req_user" in kwargs: - sids = self.endpoint_context.sdb.get_sids_by_sub(kwargs["req_user"]) - if ( - sids - and user - != self.endpoint_context.sdb.get_authentication_event( - sids[-1] - ).uid - ): + if user != kwargs["req_user"]: logger.debug("Wanted to be someone else!") if "prompt" in request and "none" in request["prompt"]: # Need to authenticate but not allowed @@ -337,74 +556,49 @@ def setup_auth(self, request, redirect_uri, cinfo, cookie, acr=None, **kwargs): else: return {"function": authn, "args": authn_args} - authn_event = create_authn_event( - identity["uid"], - identity.get("salt", ""), - authn_info=authn_class_ref, - time_stamp=_ts, - ) - if "valid_until" in authn_event: - vu = time.time() + authn.kwargs.get("expires_in", 0.0) - authn_event["valid_until"] = vu + if "sid" in identity: + _session_id = identity["sid"] - return {"authn_event": authn_event, "identity": identity, "user": user} + # make sure the client is the same + _uid, _cid, _gid = unpack_session_key(_session_id) + if request["client_id"] != _cid: + return {"function": authn, "args": authn_args} - def create_authn_response(self, request, sid): - """ + grant = _mngr[_session_id] + if grant.is_active() is False: + return {"function": authn, "args": authn_args} + elif request != grant.authorization_request: + authn_event = _mngr.get_authentication_event(session_id=_session_id) + if authn_event.is_valid() is False: # if not valid, do new login + return {"function": authn, "args": authn_args} - :param self: - :param request: - :param sid: - :return: - """ - # create the response - aresp = AuthorizationResponse() - if request.get("state"): - aresp["state"] = request["state"] + # create new grant + _session_id = _mngr.create_grant(authn_event=authn_event, + auth_req=request, + user_id=user, + client_id=request["client_id"]) - if "response_type" in request and request["response_type"] == ["none"]: - fragment_enc = False + if _session_id: + authn_event = _mngr.get_authentication_event(session_id=_session_id) + if authn_event.is_valid() is False: # if not valid, do new login + return {"function": authn, "args": authn_args} else: - _context = self.endpoint_context - _sinfo = _context.sdb[sid] - - if request.get("scope"): - aresp["scope"] = request["scope"] - - rtype = set(request["response_type"][:]) - handled_response_type = [] - - fragment_enc = True - if len(rtype) == 1 and "code" in rtype: - fragment_enc = False - - if "code" in request["response_type"]: - _code = aresp["code"] = _context.sdb[sid]["code"] - handled_response_type.append("code") - else: - _context.sdb.update(sid, code=None) - _code = None - - if "token" in rtype: - _dic = _context.sdb.upgrade_to_token(issue_refresh=False, key=sid) - - logger.debug("_dic: %s" % sanitize(_dic)) - for key, val in _dic.items(): - if key in aresp.parameters() and val is not None: - aresp[key] = val - - handled_response_type.append("token") + authn_event = create_authn_event( + identity["uid"], + authn_info=authn_class_ref, + time_stamp=_ts, + ) + _exp_in = authn.kwargs.get("expires_in") + if _exp_in and "valid_until" in authn_event: + authn_event["valid_until"] = utc_time_sans_frac() + _exp_in - _access_token = aresp.get("access_token", None) + _token_usage_rules = self.endpoint_context.authz.usage_rules( + request["client_id"]) + _session_id = _mngr.create_session(authn_event=authn_event, auth_req=request, + user_id=user, client_id=request["client_id"], + token_usage_rules=_token_usage_rules) - not_handled = rtype.difference(handled_response_type) - if not_handled: - resp = self.error_cls( - error="invalid_request", error_description="unsupported_response_type" - ) - return {"response_args": resp, "fragment_enc": fragment_enc} - - return {"response_args": aresp, "fragment_enc": fragment_enc} + return {"session_id": _session_id, "identity": identity, "user": user} def aresp_check(self, aresp, request): return "" @@ -441,29 +635,122 @@ def response_mode(self, request, **kwargs): def error_response(self, response_info, error, error_description): resp = self.error_cls( - error=error, error_description=error_description + error=error, error_description=str(error_description) ) response_info["response_args"] = resp return response_info - def post_authentication(self, user, request, sid, **kwargs): + def create_authn_response(self, request: Union[dict, Message], sid: str) -> dict: """ - Things that are done after a successful authentication. - :param user: :param request: :param sid: + :return: + """ + # create the response + aresp = self.response_cls() + if request.get("state"): + aresp["state"] = request["state"] + + if "response_type" in request and request["response_type"] == ["none"]: + fragment_enc = False + else: + _context = self.endpoint_context + _mngr = self.endpoint_context.session_manager + _sinfo = _mngr.get_session_info(sid, grant=True) + + if request.get("scope"): + aresp["scope"] = request["scope"] + + rtype = set(request["response_type"][:]) + handled_response_type = [] + + fragment_enc = True + if len(rtype) == 1 and "code" in rtype: + fragment_enc = False + + grant = _sinfo["grant"] + + if "code" in request["response_type"]: + _code = self.mint_token( + token_type='authorization_code', + grant=grant, + session_id= _sinfo["session_id"]) + aresp["code"] = _code.value + handled_response_type.append("code") + else: + _code = None + + if "token" in rtype: + if _code: + based_on = _code + else: + based_on = None + + _access_token = self.mint_token(token_type="access_token", + grant=grant, + session_id=_sinfo["session_id"], + based_on=based_on) + aresp['access_token'] = _access_token.value + aresp['token_type'] = "Bearer" + if _access_token.expires_at: + aresp["expires_in"] = _access_token.expires_at - utc_time_sans_frac() + handled_response_type.append("token") + else: + _access_token = None + + if "id_token" in request["response_type"]: + kwargs = {} + if {"code", "id_token", "token"}.issubset(rtype): + kwargs = {"code": _code.value, "access_token": _access_token.value} + elif {"code", "id_token"}.issubset(rtype): + kwargs = {"code": _code.value} + elif {"id_token", "token"}.issubset(rtype): + kwargs = {"access_token": _access_token.value} + + try: + id_token = _context.idtoken.make(sid, **kwargs) + except (JWEException, NoSuitableSigningKeys) as err: + logger.warning(str(err)) + resp = self.error_cls( + error="invalid_request", + error_description="Could not sign/encrypt id_token", + ) + return {"response_args": resp, "fragment_enc": fragment_enc} + + aresp["id_token"] = id_token + _mngr.update([_sinfo["user_id"], _sinfo["client_id"]], + {"id_token": id_token}) + handled_response_type.append("id_token") + + not_handled = rtype.difference(handled_response_type) + if not_handled: + resp = self.error_cls( + error="invalid_request", error_description="unsupported_response_type" + ) + return {"response_args": resp, "fragment_enc": fragment_enc} + + aresp = self.extra_response_args(aresp) + + return {"response_args": aresp, "fragment_enc": fragment_enc} + + def post_authentication(self, request: Union[dict, Message], + session_id: str, **kwargs) -> dict: + """ + Things that are done after a successful authentication. + + :param request: The authorization request + :param session_id: Session identifier :param kwargs: :return: A dictionary with 'response_args' """ response_info = {} + _mngr = self.endpoint_context.session_manager # Do the authorization try: - permission = self.endpoint_context.authz( - user, client_id=request["client_id"] - ) + grant = self.endpoint_context.authz(session_id, request=request) except ToOld as err: return self.error_response( response_info, @@ -475,8 +762,9 @@ def post_authentication(self, user, request, sid, **kwargs): response_info, "access_denied", "{}".format(err.args) ) else: + user_id, client_id, grant_id = unpack_session_key(session_id) try: - self.endpoint_context.sdb.update(sid, permission=permission) + _mngr.set([user_id, client_id, grant_id], grant) except Exception as err: return self.error_response( response_info, "server_error", "{}".format(err.args) @@ -484,12 +772,10 @@ def post_authentication(self, user, request, sid, **kwargs): logger.debug("response type: %s" % request["response_type"]) - if self.endpoint_context.sdb.is_session_revoked(sid): - return self.error_response( - response_info, "access_denied", "Session is revoked" - ) + response_info = self.create_authn_response(request, session_id) + response_info["session_id"] = session_id - response_info = self.create_authn_response(request, sid) + logger.debug("Known clients: {}".format(list(self.endpoint_context.cdb.keys()))) try: redirect_uri = get_uri(self.endpoint_context, request, "redirect_uri") @@ -507,10 +793,8 @@ def post_authentication(self, user, request, sid, **kwargs): _cookie = new_cookie( self.endpoint_context, - sub=user, - sid=sid, - state=request["state"], - client_id=request["client_id"], + sid=session_id, + state=request.get("state"), cookie_name=self.endpoint_context.cookie_name["session"], ) @@ -530,77 +814,79 @@ def post_authentication(self, user, request, sid, **kwargs): return response_info - def authz_part2( - self, - user, - authn_event, - request, - subject_type=None, - acr=None, - salt=None, - sector_id=None, - **kwargs, - ): + # def setup_client_session(self, user_id: str, request: dict) -> str: + # _mngr = self.endpoint_context.session_manager + # client_id = request['client_id'] + # + # client_info = ClientSessionInfo( + # authorization_request=request, + # sub=_mngr.sub_func['public'](user_id, salt=_mngr.salt) + # ) + # + # _mngr.set([user_id, client_id], client_info) + # return session_key(user_id, client_id) + + def authz_part2(self, request, session_id, **kwargs): """ After the authentication this is where you should end up :param user: - :param authn_event: The Authorization Event :param request: The Authorization Request - :param subject_type: The subject_type - :param acr: The acr - :param salt: The salt used to produce the sub - :param sector_id: The sector_id used to produce the sub + :param session_id: Session identifier :param kwargs: possible other parameters :return: A redirect to the redirect_uri of the client """ - sid = setup_session( - self.endpoint_context, - request, - user, - acr=acr, - salt=salt, - authn_event=authn_event, - subject_type=subject_type, - sector_id=sector_id, - ) try: - resp_info = self.post_authentication(user, request, sid, **kwargs) + resp_info = self.post_authentication(request, session_id, **kwargs) except Exception as err: return self.error_response({}, "server_error", err) if "check_session_iframe" in self.endpoint_context.provider_info: ec = self.endpoint_context salt = rndstr() - if not ec.sdb.is_session_revoked(sid): - authn_event = ec.sdb.get_authentication_event( - sid - ) # use the last session - _state = b64e( - as_bytes(json.dumps({"authn_time": authn_event["authn_time"]})) - ) + try: + authn_event = ec.session_manager.get_authentication_event(session_id) + except KeyError: + return self.error_response({}, "server_error", "No such session") + else: + if authn_event.is_valid() is False: + return self.error_response({}, "server_error", "Authentication has timed out") + _state = b64e( + as_bytes(json.dumps({"authn_time": authn_event["authn_time"]})) + ) + + opbs_value = '' + if hasattr(ec.cookie_dealer, 'create_cookie'): session_cookie = ec.cookie_dealer.create_cookie( as_unicode(_state), typ="session", cookie_name=ec.cookie_name["session_management"], + same_site="None", + http_only=False, ) opbs = session_cookie[ec.cookie_name["session_management"]] - + opbs_value = opbs.value + else: + session_cookie = None logger.debug( - "compute_session_state: client_id=%s, origin=%s, opbs=%s, salt=%s", - request["client_id"], - resp_info["return_uri"], - opbs.value, - salt, - ) + "Failed to set Cookie, that's not configured in main configuration.") + + logger.debug( + "compute_session_state: client_id=%s, origin=%s, opbs=%s, salt=%s", + request["client_id"], + resp_info["return_uri"], + opbs_value, + salt, + ) - _session_state = compute_session_state( - opbs.value, salt, request["client_id"], resp_info["return_uri"] - ) + _session_state = compute_session_state( + opbs_value, salt, request["client_id"], resp_info["return_uri"] + ) + if opbs_value and session_cookie: if "cookie" in resp_info: if isinstance(resp_info["cookie"], list): resp_info["cookie"].append(session_cookie) @@ -609,7 +895,7 @@ def authz_part2( else: resp_info["cookie"] = session_cookie - resp_info["response_args"]["session_state"] = _session_state + resp_info["response_args"]["session_state"] = _session_state # Mix-Up mitigation resp_info["response_args"]["iss"] = self.endpoint_context.issuer @@ -617,30 +903,35 @@ def authz_part2( return resp_info - def process_request(self, request_info=None, **kwargs): + def do_request_user(self, request_info, **kwargs): + return kwargs + + def process_request(self, request: Union[Message, dict], **kwargs): """ The AuthorizationRequest endpoint - :param request_info: The authorization request as a dictionary + :param request: The authorization request as a Message instance :return: dictionary """ - if isinstance(request_info, self.error_cls): - return request_info + if isinstance(request, self.error_cls): + return request - _cid = request_info["client_id"] + _cid = request["client_id"] cinfo = self.endpoint_context.cdb[_cid] - logger.debug("client {}: {}".format(_cid, cinfo)) # this apply the default optionally deny_unknown_scopes policy - check_unknown_scopes_policy(request_info, cinfo, self.endpoint_context) + if cinfo: + check_unknown_scopes_policy(request, cinfo, self.endpoint_context) cookie = kwargs.get("cookie", "") if cookie: del kwargs["cookie"] + kwargs = self.do_request_user(request_info=request, **kwargs) + info = self.setup_auth( - request_info, request_info["redirect_uri"], cinfo, cookie, **kwargs + request, request["redirect_uri"], cinfo, cookie, **kwargs ) if "error" in info: @@ -649,18 +940,81 @@ def process_request(self, request_info=None, **kwargs): _function = info.get("function") if not _function: logger.debug("- authenticated -") - logger.debug("AREQ keys: %s" % request_info.keys()) - res = self.authz_part2( - info["user"], info["authn_event"], request_info, cookie=cookie - ) - return res + logger.debug("AREQ keys: %s" % request.keys()) + return self.authz_part2(request=request, cookie=cookie, **info) try: # Run the authentication function return { "http_response": _function(**info["args"]), - "return_uri": request_info["redirect_uri"], + "return_uri": request["redirect_uri"], } except Exception as err: logger.exception(err) return {"http_response": "Internal error: {}".format(err)} + + +class AllowedAlgorithms: + def __init__(self, algorithm_parameters): + self.algorithm_parameters = algorithm_parameters + + def __call__(self, client_id, endpoint_context, alg, alg_type): + _cinfo = endpoint_context.cdb[client_id] + _pinfo = endpoint_context.provider_info + + _reg, _sup = self.algorithm_parameters[alg_type] + _allowed = _cinfo.get(_reg) + if _allowed is None: + _allowed = _pinfo.get(_sup) + + if alg not in _allowed: + logger.error( + "Signing alg user: {} not among allowed: {}".format(alg, _allowed) + ) + raise ValueError("Not allowed '%s' algorithm used", alg) + + +def re_authenticate(request, authn): + return False + +# class Authorization(authorization.Authorization): +# request_cls = oauth2.AuthorizationRequest +# response_cls = oauth2.AuthorizationResponse +# error_cls = oauth2.AuthorizationErrorResponse +# request_format = "urlencoded" +# response_format = "urlencoded" +# response_placement = "url" +# endpoint_name = "authorization_endpoint" +# name = "authorization" +# default_capabilities = { +# "claims_parameter_supported": True, +# "request_parameter_supported": True, +# "request_uri_parameter_supported": True, +# "response_types_supported": ["code", "token", "code token"], +# "response_modes_supported": ["query", "fragment", "form_post"], +# "request_object_signing_alg_values_supported": None, +# "request_object_encryption_alg_values_supported": None, +# "request_object_encryption_enc_values_supported": None, +# "grant_types_supported": ["authorization_code", "implicit"], +# "scopes_supported": [], +# } +# +# def __init__(self, endpoint_context, **kwargs): +# authorization.Authorization.__init__(self, endpoint_context, **kwargs) +# # self.pre_construct.append(self._pre_construct) +# self.post_parse_request.append(self._do_request_uri) +# self.post_parse_request.append(self._post_parse_request) +# # Has to be done elsewhere. To make sure things happen in order. +# # self.scopes_supported = available_scopes(endpoint_context) +# +# def setup_client_session(self, user_id: str, request: dict) -> str: +# _mngr = self.endpoint_context.session_manager +# client_id = request['client_id'] +# +# client_info = ClientSessionInfo( +# authorization_request=request, +# sub=_mngr.sub_func['public'](user_id, salt=_mngr.salt) +# ) +# +# _mngr.set([user_id, client_id], client_info) +# return session_key(user_id, client_id) diff --git a/src/oidcendpoint/oauth2/introspection.py b/src/oidcendpoint/oauth2/introspection.py index 4211bd8..4126488 100644 --- a/src/oidcendpoint/oauth2/introspection.py +++ b/src/oidcendpoint/oauth2/introspection.py @@ -1,11 +1,10 @@ """Implements RFC7662""" import logging -from oidcendpoint.token_handler import UnknownToken from oidcmsg import oauth2 -from oidcmsg.time_util import utc_time_sans_frac from oidcendpoint.endpoint import Endpoint +from oidcendpoint.token.exception import UnknownToken LOGGER = logging.getLogger(__name__) @@ -23,59 +22,36 @@ class Introspection(Endpoint): def __init__(self, **kwargs): Endpoint.__init__(self, **kwargs) self.offset = kwargs.get("offset", 0) - self.enable_claims_per_client = kwargs.get("enable_claims_per_client", False) - - def get_client_id_from_token(self, endpoint_context, token, request=None): - """ - Will try to match tokens against information in the session DB. - - :param endpoint_context: - :param token: - :param request: - :return: client_id if there was a match - """ - sinfo = endpoint_context.sdb[token] - return sinfo["authn_req"]["client_id"] - - def _get_client_claims(self, token): - client_id = self.get_client_id_from_token(self.endpoint_context, token) - client = self.endpoint_context.cdb.get(client_id, {}) - return client.get("introspection_claims") - - def _get_user_info(self, token_info): - user_id = self.endpoint_context.sdb.sso_db.get_uid_by_sid(token_info["sid"]) - return self.endpoint_context.userinfo(user_id, client_id=None) - - def _add_claims(self, token_info, claims, payload): - user_info = self._get_user_info(token_info) - for attr in claims: - try: - payload[attr] = user_info[attr] - except KeyError: - pass - - def _introspect(self, token): - try: - info = self.endpoint_context.sdb[token] - except (KeyError, UnknownToken): - return None + def _introspect(self, token, client_id, grant): # Make sure that the token is an access_token or a refresh_token - if token != info.get("access_token") and token != info.get("refresh_token"): + if token.type not in ["access_token", "refresh_token"]: return None - eat = info.get("expires_at") - if eat and eat < utc_time_sans_frac(): + if not token.is_active(): return None - if info: # Now what can be returned ? - ret = info.to_dict() - ret["iss"] = self.endpoint_context.issuer - - if "scope" not in ret: - ret["scope"] = " ".join(info["authn_req"]["scope"]) - - return ret + scope = token.scope + if not scope: + scope = grant.scope + aud = token.resources + if not aud: + aud = grant.resources + + ret = { + "active": True, + "scope": " ".join(scope), + "client_id": client_id, + "token_type": token.type, + "exp": token.expires_at, + "iat": token.issued_at, + "sub": grant.sub, + "iss": self.endpoint_context.issuer + } + if aud: + ret["aud"] = aud + + return ret def process_request(self, request=None, **kwargs): """ @@ -88,29 +64,38 @@ def process_request(self, request=None, **kwargs): if "error" in _introspect_request: return _introspect_request - _token = _introspect_request["token"] + request_token = _introspect_request["token"] _resp = self.response_cls(active=False) - _info = self._introspect(_token) + try: + _session_info = self.endpoint_context.session_manager.get_session_info_by_token( + request_token, grant=True) + except UnknownToken: + return {"response_args": _resp} + + _grant = _session_info["grant"] + _token = _grant.get_token(request_token) + + _info = self._introspect(_token, _session_info["client_id"], _grant) if _info is None: return {"response_args": _resp} if "release" in self.kwargs: if "username" in self.kwargs["release"]: try: - _info["username"] = self.endpoint_context.userinfo.search( - sub=_info["sub"] - ) + _info["username"] = _session_info["user_id"] except KeyError: pass _resp.update(_info) _resp.weed() - if self.enable_claims_per_client: - client_claims = self._get_client_claims(_token) - if client_claims: - self._add_claims(_info, client_claims, _resp) + _claims_restriction = _grant.claims.get("introspection") + if _claims_restriction: + user_info = self.endpoint_context.claims_interface.get_user_claims( + _session_info["user_id"], _claims_restriction) + if user_info: + _resp.update(user_info) _resp["active"] = True diff --git a/src/oidcendpoint/oidc/add_on/pkce.py b/src/oidcendpoint/oidc/add_on/pkce.py index 3548938..e760dd0 100644 --- a/src/oidcendpoint/oidc/add_on/pkce.py +++ b/src/oidcendpoint/oidc/add_on/pkce.py @@ -5,9 +5,6 @@ from oidcmsg.oauth2 import AuthorizationErrorResponse from oidcmsg.oidc import TokenErrorResponse -from oidcendpoint.exception import MultipleCodeUsage -from oidcendpoint.token_handler import UnknownToken - LOGGER = logging.getLogger(__name__) @@ -96,11 +93,14 @@ def post_token_parse(request, client_id, endpoint_context, **kwargs): return request try: - _info = endpoint_context.sdb[request["code"]] - except (KeyError, UnknownToken, MultipleCodeUsage): - # This will be handled by process_request - return request - _authn_req = _info["authn_req"] + _session_info = endpoint_context.session_manager.get_session_info_by_token( + request["code"], grant=True) + except KeyError: + return TokenErrorResponse( + error="invalid_grant", error_description="Unknown access grant" + ) + + _authn_req = _session_info["grant"].authorization_request if "code_challenge" in _authn_req: if "code_verifier" not in request: @@ -109,11 +109,11 @@ def post_token_parse(request, client_id, endpoint_context, **kwargs): error_description="Missing code_verifier", ) - _method = _info["authn_req"]["code_challenge_method"] + _method = _authn_req["code_challenge_method"] if not verify_code_challenge( request["code_verifier"], - _info["authn_req"]["code_challenge"], + _authn_req["code_challenge"], _method, ): return TokenErrorResponse( @@ -139,6 +139,7 @@ def add_pkce_support(endpoint, **kwargs): return authn_endpoint.post_parse_request.append(post_authn_parse) + token_endpoint.post_parse_request.append(post_token_parse) if "essential" not in kwargs: kwargs["essential"] = False @@ -154,5 +155,3 @@ def add_pkce_support(endpoint, **kwargs): kwargs["code_challenge_methods"][method] = CC_METHOD[method] authn_endpoint.endpoint_context.args["pkce"] = kwargs - - token_endpoint.post_parse_request.append(post_token_parse) diff --git a/src/oidcendpoint/oidc/authorization.py b/src/oidcendpoint/oidc/authorization.py index dbb70fb..be49d57 100755 --- a/src/oidcendpoint/oidc/authorization.py +++ b/src/oidcendpoint/oidc/authorization.py @@ -1,43 +1,13 @@ -import json import logging +from urllib.parse import urlsplit -from cryptojwt import BadSyntax -from cryptojwt.jwe.exception import JWEException -from cryptojwt.jws.exception import NoSuitableSigningKeys -from cryptojwt.jwt import utc_time_sans_frac -from cryptojwt.utils import as_bytes -from cryptojwt.utils import as_unicode -from cryptojwt.utils import b64d -from cryptojwt.utils import b64e -from oidcendpoint.token_handler import UnknownToken from oidcmsg import oidc -from oidcmsg.exception import ParameterError from oidcmsg.oidc import Claims from oidcmsg.oidc import verified_claim_name -from oidcendpoint import rndstr -from oidcendpoint import sanitize -from oidcendpoint.authn_event import create_authn_event -from oidcendpoint.common.authorization import FORM_POST -from oidcendpoint.common.authorization import AllowedAlgorithms -from oidcendpoint.common.authorization import authn_args_gather -from oidcendpoint.common.authorization import get_uri -from oidcendpoint.common.authorization import inputs -from oidcendpoint.common.authorization import max_age -from oidcendpoint.cookie import append_cookie -from oidcendpoint.cookie import compute_session_state -from oidcendpoint.cookie import new_cookie -from oidcendpoint.endpoint import Endpoint -from oidcendpoint.exception import InvalidRequest -from oidcendpoint.exception import NoSuchAuthentication -from oidcendpoint.exception import RedirectURIError -from oidcendpoint.exception import ServiceError -from oidcendpoint.exception import TamperAllert -from oidcendpoint.exception import ToOld -from oidcendpoint.exception import UnknownClient -from oidcendpoint.oauth2.authorization import check_unknown_scopes_policy -from oidcendpoint.session import setup_session -from oidcendpoint.user_authn.authn_context import pick_auth +from oidcendpoint.oauth2 import authorization +from oidcendpoint.session import session_key +from oidcendpoint.session.info import ClientSessionInfo logger = logging.getLogger(__name__) @@ -68,6 +38,11 @@ def acr_claims(request): return acrdef["values"] +def host_component(url): + res = urlsplit(url) + return "{}://{}".format(res.scheme, res.netloc) + + ALG_PARAMS = { "sign": [ "request_object_signing_alg", @@ -92,7 +67,7 @@ def re_authenticate(request, authn): return False -class Authorization(Endpoint): +class Authorization(authorization.Authorization): request_cls = oidc.AuthorizationRequest response_cls = oidc.AuthorizationResponse error_cls = oidc.AuthorizationErrorResponse @@ -123,601 +98,41 @@ class Authorization(Endpoint): } def __init__(self, endpoint_context, **kwargs): - Endpoint.__init__(self, endpoint_context, **kwargs) + authorization.Authorization.__init__(self, endpoint_context, **kwargs) # self.pre_construct.append(self._pre_construct) self.post_parse_request.append(self._do_request_uri) self.post_parse_request.append(self._post_parse_request) - self.allowed_request_algorithms = AllowedAlgorithms(ALG_PARAMS) - - def filter_request(self, endpoint_context, req): - return req - - def verify_response_type(self, request, cinfo): - # Checking response types - _registered = [set(rt.split(" ")) for rt in cinfo.get("response_types", [])] - if not _registered: - # If no response_type is registered by the client then we'll - # code which it the default according to the OIDC spec. - _registered = [{"code"}] - - # Is the asked for response_type among those that are permitted - return set(request["response_type"]) in _registered - - def _do_request_uri(self, request, client_id, endpoint_context, **kwargs): - _request_uri = request.get("request_uri") - if _request_uri: - # Do I do pushed authorization requests ? - if "pushed_authorization" in endpoint_context.endpoint: - # Is it a UUID urn - if _request_uri.startswith("urn:uuid:"): - _req = endpoint_context.par_db.get(_request_uri) - if _req: - del endpoint_context.par_db[_request_uri] # One time - # usage - return _req - else: - raise ValueError("Got a request_uri I can not resolve") - - # Do I support request_uri ? - _supported = endpoint_context.provider_info.get( - "request_uri_parameter_supported", True - ) - _registered = endpoint_context.cdb[client_id].get("request_uris") - # Not registered should be handled else where - if _registered: - # Before matching remove a possible fragment - _p = _request_uri.split("#") - # ignore registered fragments for now. - if _p[0] not in [l[0] for l in _registered]: - raise ValueError("A request_uri outside the registered") - - # Fetch the request - _resp = endpoint_context.httpc.get( - _request_uri, **endpoint_context.httpc_params - ) - if _resp.status_code == 200: - args = {"keyjar": endpoint_context.keyjar, "issuer": client_id} - _ver_request = self.request_cls().from_jwt(_resp.text, **args) - self.allowed_request_algorithms( - client_id, - endpoint_context, - _ver_request.jws_header.get("alg", "RS256"), - "sign", - ) - if _ver_request.jwe_header is not None: - self.allowed_request_algorithms( - client_id, - endpoint_context, - _ver_request.jws_header.get("alg"), - "enc_alg", - ) - self.allowed_request_algorithms( - client_id, - endpoint_context, - _ver_request.jws_header.get("enc"), - "enc_enc", - ) - # The protected info overwrites the non-protected - for k, v in _ver_request.items(): - request[k] = v - - request[verified_claim_name("request")] = _ver_request - else: - raise ServiceError("Got a %s response", _resp.status) - - return request - - def _post_parse_request(self, request, client_id, endpoint_context, **kwargs): - """ - Verify the authorization request. - - :param endpoint_context: - :param request: - :param client_id: - :param kwargs: - :return: - """ - if not request: - logger.debug("No AuthzRequest") - return self.error_cls( - error="invalid_request", error_description="Can not parse AuthzRequest" - ) - - request = self.filter_request(endpoint_context, request) - - _cinfo = endpoint_context.cdb.get(client_id) - if not _cinfo: - logger.error( - "Client ID ({}) not in client database".format(request["client_id"]) - ) - return self.error_cls( - error="unauthorized_client", error_description="unknown client" - ) - - # Is the asked for response_type among those that are permitted - if not self.verify_response_type(request, _cinfo): - return self.error_cls( - error="invalid_request", - error_description="Trying to use unregistered response_type", - ) - - # Get a verified redirect URI - try: - redirect_uri = get_uri(endpoint_context, request, "redirect_uri") - except (RedirectURIError, ParameterError, UnknownClient) as err: - return self.error_cls( - error="invalid_request", - error_description="{}:{}".format(err.__class__.__name__, err), - ) - else: - request["redirect_uri"] = redirect_uri - - return request - - def pick_authn_method(self, request, redirect_uri, acr=None, **kwargs): - auth_id = kwargs.get("auth_method_id") - if auth_id: - return self.endpoint_context.authn_broker[auth_id] - - if acr: - res = self.endpoint_context.authn_broker.pick(acr) - else: - res = pick_auth(self.endpoint_context, request) - - if res: - return res - else: - return { - "error": "access_denied", - "error_description": "ACR I do not support", - "return_uri": redirect_uri, - "return_type": request["response_type"], - } - - def setup_auth(self, request, redirect_uri, cinfo, cookie, acr=None, **kwargs): - """ - - :param request: The authorization/authentication request - :param redirect_uri: - :param cinfo: client info - :param cookie: - :param acr: Default ACR, if nothing else is specified - :param kwargs: - :return: - """ - - res = self.pick_authn_method(request, redirect_uri, acr, **kwargs) - - authn = res["method"] - authn_class_ref = res["acr"] - session = None - - try: - _auth_info = kwargs.get("authn", "") - if "upm_answer" in request and request["upm_answer"] == "true": - _max_age = 0 - else: - _max_age = max_age(request) - - identity, _ts = authn.authenticated_as( - cookie, authorization=_auth_info, max_age=_max_age - ) - except (NoSuchAuthentication, TamperAllert): - identity = None - _ts = 0 - except ToOld: - logger.info("Too old authentication") - identity = None - _ts = 0 - except UnknownToken: - logger.info("Unknown Token") - identity = None - _ts = 0 - else: - if identity: - try: # If identity['uid'] is in fact a base64 encoded JSON string - _id = b64d(as_bytes(identity["uid"])) - except BadSyntax: - pass - else: - identity = json.loads(as_unicode(_id)) - - try: - session = self.endpoint_context.sdb[identity.get("sid")] - except UnknownToken: - identity= None - else: - if not session or "revoked" in session: - identity = None - - authn_args = authn_args_gather(request, authn_class_ref, cinfo, **kwargs) - - # To authenticate or Not - if identity is None: # No! - logger.info("No active authentication") - logger.debug( - "Known clients: {}".format(list(self.endpoint_context.cdb.keys())) - ) - - if "prompt" in request and "none" in request["prompt"]: - # Need to authenticate but not allowed - return { - "error": "login_required", - "return_uri": redirect_uri, - "return_type": request["response_type"], - } - else: - return {"function": authn, "args": authn_args} - else: - logger.info("Active authentication") - if re_authenticate(request, authn): - # demand re-authentication - return {"function": authn, "args": authn_args} - else: - # I get back a dictionary - user = identity["uid"] - if "req_user" in kwargs: - sids = self.endpoint_context.sdb.get_sids_by_sub(kwargs["req_user"]) - if ( - sids - and user - != self.endpoint_context.sdb.get_authentication_event( - sids[-1] - ).uid - ): - logger.debug("Wanted to be someone else!") - if "prompt" in request and "none" in request["prompt"]: - # Need to authenticate but not allowed - return { - "error": "login_required", - "return_uri": redirect_uri, - } - else: - return {"function": authn, "args": authn_args} - - authn_event = None - if session: - authn_event = session.get('authn_event') - - if authn_event is None: - authn_event = create_authn_event( - identity["uid"], - identity.get("salt", ""), - authn_info=authn_class_ref, - time_stamp=_ts, - ) - - _exp_in = authn.kwargs.get("expires_in") - if _exp_in and "valid_until" in authn_event: - authn_event["valid_until"] = utc_time_sans_frac() + _exp_in - - return {"authn_event": authn_event, "identity": identity, "user": user} - - def create_authn_response(self, request, sid): - """ - - :param self: - :param request: - :param sid: - :return: - """ - # create the response - aresp = self.response_cls() - if request.get("state"): - aresp["state"] = request["state"] - - if "response_type" in request and request["response_type"] == ["none"]: - fragment_enc = False - else: - _context = self.endpoint_context - _sinfo = _context.sdb[sid] - - if request.get("scope"): - aresp["scope"] = request["scope"] - - rtype = set(request["response_type"][:]) - handled_response_type = [] - - fragment_enc = True - if len(rtype) == 1 and "code" in rtype: - fragment_enc = False - if "code" in request["response_type"]: - _code = aresp["code"] = _context.sdb[sid]["code"] - handled_response_type.append("code") - else: - _context.sdb.update(sid, code=None) - _code = None + def setup_client_session(self, user_id: str, request: dict) -> str: + _mngr = self.endpoint_context.session_manager + client_id = request['client_id'] - if "token" in rtype: - _dic = _context.sdb.upgrade_to_token(issue_refresh=False, key=sid) + _client_info = self.endpoint_context.cdb[client_id] + sub_type = _client_info.get("subject_type") + if sub_type and sub_type == "pairwise": + sector_identifier_uri = _client_info.get("sector_identifier_uri") + if sector_identifier_uri is None: + sector_identifier_uri = host_component(_client_info["redirect_uris"][0]) - logger.debug("_dic: %s" % sanitize(_dic)) - for key, val in _dic.items(): - if key in aresp.parameters() and val is not None: - aresp[key] = val - - handled_response_type.append("token") - - _access_token = aresp.get("access_token", None) - - if "id_token" in request["response_type"]: - kwargs = {} - if {"code", "id_token", "token"}.issubset(rtype): - kwargs = {"code": _code, "access_token": _access_token} - elif {"code", "id_token"}.issubset(rtype): - kwargs = {"code": _code} - elif {"id_token", "token"}.issubset(rtype): - kwargs = {"access_token": _access_token} - - if request["response_type"] == ["id_token"]: - kwargs["user_claims"] = True - - try: - id_token = _context.idtoken.make(request, _sinfo, **kwargs) - except (JWEException, NoSuitableSigningKeys) as err: - logger.warning(str(err)) - resp = self.error_cls( - error="invalid_request", - error_description="Could not sign/encrypt id_token", - ) - return {"response_args": resp, "fragment_enc": fragment_enc} - - aresp["id_token"] = id_token - _sinfo["id_token"] = id_token - handled_response_type.append("id_token") - - not_handled = rtype.difference(handled_response_type) - if not_handled: - resp = self.error_cls( - error="invalid_request", error_description="unsupported_response_type" - ) - return {"response_args": resp, "fragment_enc": fragment_enc} - - return {"response_args": aresp, "fragment_enc": fragment_enc} - - def aresp_check(self, aresp, request): - return "" - - def response_mode(self, request, **kwargs): - resp_mode = request["response_mode"] - if resp_mode == "form_post": - msg = FORM_POST.format( - inputs=inputs(kwargs["response_args"].to_dict()), - action=kwargs["return_uri"], - ) - kwargs.update( - { - "response_msg": msg, - "content_type": "text/html", - "response_placement": "body", - } - ) - elif resp_mode == "fragment": - if "fragment_enc" in kwargs: - if not kwargs["fragment_enc"]: - # Can't be done - raise InvalidRequest("wrong response_mode") - else: - kwargs["fragment_enc"] = True - elif resp_mode == "query": - if "fragment_enc" in kwargs: - if kwargs["fragment_enc"]: - # Can't be done - raise InvalidRequest("wrong response_mode") - else: - raise InvalidRequest("Unknown response_mode") - return kwargs - - def error_response(self, response_info, error, error_description): - resp = self.error_cls( - error=error, error_description=str(error_description) - ) - response_info["response_args"] = resp - return response_info - - def post_authentication(self, user, request, sid, **kwargs): - """ - Things that are done after a successful authentication. - - :param user: - :param request: - :param sid: - :param kwargs: - :return: A dictionary with 'response_args' - """ - - response_info = {} - - # Do the authorization - try: - permission = self.endpoint_context.authz( - user, client_id=request["client_id"] - ) - except ToOld as err: - return self.error_response( - response_info, - "access_denied", - "Authentication to old {}".format(err.args), - ) - except Exception as err: - return self.error_response( - response_info, "access_denied", "{}".format(err.args) + client_info = ClientSessionInfo( + authorization_request=request, + sub=_mngr.sub_func[sub_type](user_id, salt=_mngr.salt, + sector_identifier=sector_identifier_uri) ) else: - try: - self.endpoint_context.sdb.update(sid, permission=permission) - except Exception as err: - return self.error_response( - response_info, "server_error", "{}".format(err.args) - ) + sub_type = self.kwargs.get("subject_type") + if not sub_type: + sub_type = "public" - logger.debug("response type: %s" % request["response_type"]) - - if self.endpoint_context.sdb.is_session_revoked(sid): - return self.error_response( - response_info, "access_denied", "Session is revoked" + client_info = ClientSessionInfo( + authorization_request=request, + sub=_mngr.sub_func[sub_type](user_id, salt=_mngr.salt) ) - response_info = self.create_authn_response(request, sid) - - logger.debug("Known clients: {}".format(list(self.endpoint_context.cdb.keys()))) - - try: - redirect_uri = get_uri(self.endpoint_context, request, "redirect_uri") - except (RedirectURIError, ParameterError) as err: - return self.error_response( - response_info, "invalid_request", "{}".format(err.args) - ) - else: - response_info["return_uri"] = redirect_uri - - # Must not use HTTP unless implicit grant type and native application - # info = self.aresp_check(response_info['response_args'], request) - # if isinstance(info, ResponseMessage): - # return info - - _cookie = new_cookie( - self.endpoint_context, - uid=user, - sid=sid, - state=request["state"], - client_id=request["client_id"], - cookie_name=self.endpoint_context.cookie_name["session"], - ) - - # Now about the response_mode. Should not be set if it's obvious - # from the response_type. Knows about 'query', 'fragment' and - # 'form_post'. - - if "response_mode" in request: - try: - response_info = self.response_mode(request, **response_info) - except InvalidRequest as err: - return self.error_response( - response_info, "invalid_request", "{}".format(err.args) - ) - - response_info["cookie"] = [_cookie] - - return response_info - - def authz_part2( - self, - user, - authn_event, - request, - subject_type=None, - acr=None, - salt=None, - sector_id=None, - **kwargs, - ): - """ - After the authentication this is where you should end up - - :param user: - :param authn_event: The Authorization Event - :param request: The Authorization Request - :param subject_type: The subject_type - :param acr: The acr - :param salt: The salt used to produce the sub - :param sector_id: The sector_id used to produce the sub - :param kwargs: possible other parameters - :return: A redirect to the redirect_uri of the client - """ - sid = setup_session( - self.endpoint_context, - request, - user, - acr=acr, - salt=salt, - authn_event=authn_event, - subject_type=subject_type, - sector_id=sector_id, - ) - - try: - resp_info = self.post_authentication(user, request, sid, **kwargs) - except Exception as err: - return self.error_response({}, "server_error", err) - - if "check_session_iframe" in self.endpoint_context.provider_info: - ec = self.endpoint_context - salt = rndstr() - if not ec.sdb.is_session_revoked(sid): - authn_event = ec.sdb.get_authentication_event( - sid - ) # use the last session - _state = b64e( - as_bytes(json.dumps({"authn_time": authn_event["authn_time"]})) - ) - - opbs_value = '' - if hasattr(ec.cookie_dealer, 'create_cookie'): - session_cookie = ec.cookie_dealer.create_cookie( - as_unicode(_state), - typ="session", - cookie_name=ec.cookie_name["session_management"], - same_site="None", - http_only=False, - ) - - opbs = session_cookie[ec.cookie_name["session_management"]] - opbs_value = opbs.value - else: - logger.debug("Failed to set Cookie, that's not configured in main configuration.") - - logger.debug( - "compute_session_state: client_id=%s, origin=%s, opbs=%s, salt=%s", - request["client_id"], - resp_info["return_uri"], - opbs_value, - salt, - ) - - _session_state = compute_session_state( - opbs_value, salt, request["client_id"], resp_info["return_uri"] - ) - - if opbs_value: - if "cookie" in resp_info: - if isinstance(resp_info["cookie"], list): - resp_info["cookie"].append(session_cookie) - else: - append_cookie(resp_info["cookie"], session_cookie) - else: - resp_info["cookie"] = session_cookie - - resp_info["response_args"]["session_state"] = _session_state - - # Mix-Up mitigation - resp_info["response_args"]["iss"] = self.endpoint_context.issuer - resp_info["response_args"]["client_id"] = request["client_id"] - - return resp_info - - def process_request(self, request_info=None, **kwargs): - """ The AuthorizationRequest endpoint - - :param request_info: The authorization request as a dictionary - :return: dictionary - """ - - if isinstance(request_info, self.error_cls): - return request_info - - _cid = request_info["client_id"] - cinfo = self.endpoint_context.cdb[_cid] - logger.debug("client {}: {}".format(_cid, cinfo)) - - # this apply the default optionally deny_unknown_scopes policy - check_unknown_scopes_policy(request_info, cinfo, self.endpoint_context) - - cookie = kwargs.get("cookie", "") - if cookie: - del kwargs["cookie"] + _mngr.set([user_id, client_id], client_info) + return session_key(user_id, client_id) + def do_request_user(self, request_info, **kwargs): if proposed_user(request_info): kwargs["req_user"] = proposed_user(request_info) else: @@ -727,29 +142,4 @@ def process_request(self, request_info=None, **kwargs): kwargs["req_user"] = self.endpoint_context.login_hint_lookup[ _login_hint ] - - info = self.setup_auth( - request_info, request_info["redirect_uri"], cinfo, cookie, **kwargs - ) - - if "error" in info: - return info - - _function = info.get("function") - if not _function: - logger.debug("- authenticated -") - logger.debug("AREQ keys: %s" % request_info.keys()) - res = self.authz_part2( - info["user"], info["authn_event"], request_info, cookie=cookie - ) - return res - - try: - # Run the authentication function - return { - "http_response": _function(**info["args"]), - "return_uri": request_info["redirect_uri"], - } - except Exception as err: - logger.exception(err) - return {"http_response": "Internal error: {}".format(err)} + return kwargs diff --git a/src/oidcendpoint/oidc/refresh_token.py b/src/oidcendpoint/oidc/refresh_token.py deleted file mode 100755 index a2ffea1..0000000 --- a/src/oidcendpoint/oidc/refresh_token.py +++ /dev/null @@ -1,127 +0,0 @@ -import logging - -from oidcmsg import oidc -from oidcmsg.oauth2 import ResponseMessage -from oidcmsg.oidc import AccessTokenResponse -from oidcmsg.oidc import RefreshAccessTokenRequest -from oidcmsg.oidc import TokenErrorResponse - -from oidcendpoint import sanitize -from oidcendpoint.client_authn import verify_client -from oidcendpoint.cookie import new_cookie -from oidcendpoint.endpoint import Endpoint -from oidcendpoint.token_handler import ExpiredToken -from oidcendpoint.userinfo import by_schema - -logger = logging.getLogger(__name__) - - -class RefreshAccessToken(Endpoint): - request_cls = oidc.RefreshAccessTokenRequest - response_cls = oidc.AccessTokenResponse - error_cls = TokenErrorResponse - request_format = "json" - request_placement = "body" - response_format = "json" - response_placement = "body" - endpoint_name = "token_endpoint" - name = "refresh_token" - - def __init__(self, endpoint_context, **kwargs): - Endpoint.__init__(self, endpoint_context, **kwargs) - self.post_parse_request.append(self._post_parse_request) - - def _refresh_access_token(self, req, **kwargs): - _sdb = self.endpoint_context.sdb - - # client_id = str(req["client_id"]) - - if req["grant_type"] != "refresh_token": - return self.error_cls( - error="invalid_request", error_description="Wrong grant_type" - ) - - rtoken = req["refresh_token"] - try: - _info = _sdb.refresh_token(rtoken) - except ExpiredToken: - return self.error_cls( - error="invalid_request", error_description="Refresh token is expired" - ) - - return by_schema(AccessTokenResponse, **_info) - - def client_authentication(self, request, auth=None, **kwargs): - """ - Deal with client authentication - - :param request: The refresh access token request - :param auth: Client authentication information - :param kwargs: Extra keyword arguments - :return: dictionary containing client id, client authentication method - and possibly access token. - """ - try: - auth_info = verify_client(self.endpoint_context, request, auth) - msg = "" - except Exception as err: - msg = "Failed to verify client due to: {}".format(err) - logger.error(msg) - return self.error_cls(error="unauthorized_client", error_description=msg) - else: - if "client_id" not in auth_info: - logger.error("No client_id, authentication failed") - return self.error_cls( - error="unauthorized_client", error_description="unknown client" - ) - - return auth_info - - def _post_parse_request(self, request, client_id="", **kwargs): - """ - This is where clients come to refresh their access tokens - - :param request: The request - :param authn: Authentication info, comes from HTTP header - :returns: - """ - - request = RefreshAccessTokenRequest(**request.to_dict()) - - try: - keyjar = self.endpoint_context.keyjar - except AttributeError: - keyjar = "" - - request.verify(keyjar=keyjar, opponent_id=client_id) - - if "client_id" not in request: # Optional for refresh access token request - request["client_id"] = client_id - - logger.debug("%s: %s" % (request.__class__.__name__, sanitize(request))) - - return request - - def process_request(self, request=None, **kwargs): - """ - - :param request: - :param kwargs: - :return: Dictionary with response information - """ - response_args = self._refresh_access_token(request, **kwargs) - - if isinstance(response_args, ResponseMessage): - return response_args - - _token = request["refresh_token"].replace(" ", "+") - _cookie = new_cookie( - self.endpoint_context, - sub=self.endpoint_context.sdb[_token]["sub"], - cookie_name=self.endpoint_context.cookie_name["session"], - ) - _headers = [("Content-type", "application/json")] - resp = {"response_args": response_args, "http_headers": _headers} - if _cookie: - resp["cookie"] = _cookie - return resp diff --git a/src/oidcendpoint/oidc/session.py b/src/oidcendpoint/oidc/session.py index fcd9351..72ea8b4 100644 --- a/src/oidcendpoint/oidc/session.py +++ b/src/oidcendpoint/oidc/session.py @@ -6,11 +6,14 @@ from cryptojwt import as_unicode from cryptojwt import b64d +from cryptojwt.jwe.aes import AES_GCMEncrypter +from cryptojwt.jwe.utils import split_ctx_and_tag from cryptojwt.jws.exception import JWSException from cryptojwt.jws.jws import factory from cryptojwt.jws.utils import alg2keytype from cryptojwt.jwt import JWT from cryptojwt.utils import as_bytes +from cryptojwt.utils import b64e from oidcmsg.exception import InvalidRequest from oidcmsg.message import Message from oidcmsg.oauth2 import ResponseMessage @@ -18,11 +21,13 @@ from oidcmsg.oidc.session import BACK_CHANNEL_LOGOUT_EVENT from oidcmsg.oidc.session import EndSessionRequest +from oidcendpoint import rndstr from oidcendpoint.client_authn import UnknownOrNoAuthnMethod -from oidcendpoint.common.authorization import verify_uri from oidcendpoint.cookie import append_cookie from oidcendpoint.endpoint import Endpoint from oidcendpoint.endpoint_context import add_path +from oidcendpoint.oauth2.authorization import verify_uri +from oidcendpoint.session import session_key logger = logging.getLogger(__name__) @@ -85,13 +90,24 @@ def __init__(self, endpoint_context, **kwargs): if _csi and not _csi.startswith("http"): kwargs["check_session_iframe"] = add_path(endpoint_context.issuer, _csi) Endpoint.__init__(self, endpoint_context, **kwargs) + self.iv = as_bytes(rndstr(24)) - def do_back_channel_logout(self, cinfo, sub, sid): + def _encrypt_sid(self, sid): + encrypter = AES_GCMEncrypter(key=as_bytes(self.endpoint_context.symkey)) + enc_msg = encrypter.encrypt(as_bytes(sid), iv=self.iv) + return as_unicode(b64e(enc_msg)) + + def _decrypt_sid(self, enc_msg): + _msg = b64d(as_bytes(enc_msg)) + encrypter = AES_GCMEncrypter(key=as_bytes(self.endpoint_context.symkey)) + ctx, tag = split_ctx_and_tag(_msg) + return as_unicode(encrypter.decrypt(as_bytes(ctx), iv=self.iv, tag=as_bytes(tag))) + + def do_back_channel_logout(self, cinfo, sid): """ :param cinfo: Client information - :param sub: Subject identifier - :param sid: The Issuer ID + :param sid: The session ID :return: Tuple with logout URI and signed logout token """ @@ -102,10 +118,16 @@ def do_back_channel_logout(self, cinfo, sub, sid): except KeyError: return None + # Create the logout token # always include sub and sid so I don't check for # backchannel_logout_session_required - payload = {"sub": sub, "sid": sid, "events": {BACK_CHANNEL_LOGOUT_EVENT: {}}} + enc_msg = self._encrypt_sid(sid) + + payload = { + "sid": enc_msg, + "events": {BACK_CHANNEL_LOGOUT_EVENT: {}} + } try: alg = cinfo["id_token_signed_response_alg"] @@ -114,59 +136,43 @@ def do_back_channel_logout(self, cinfo, sub, sid): _jws = JWT(_cntx.keyjar, iss=_cntx.issuer, lifetime=86400, sign_alg=alg) _jws.with_jti = True - sjwt = _jws.pack(payload=payload, recv=cinfo["client_id"]) + _logout_token = _jws.pack(payload=payload, recv=cinfo["client_id"]) - return back_channel_logout_uri, sjwt + return back_channel_logout_uri, _logout_token def clean_sessions(self, usids): - # Clean out all sessions - _sdb = self.endpoint_context.sdb - _sso_db = self.endpoint_context.sdb.sso_db + # Revoke all sessions for sid in usids: - # remove session information - del _sdb[sid] - _sso_db.remove_session_id(sid) - - def logout_all_clients(self, sid, client_id): - _sdb = self.endpoint_context.sdb - _sso_db = self.endpoint_context.sdb.sso_db - - # Find all RPs this user has logged it from - uid = _sso_db.get_uid_by_sid(sid) - if uid is None: - logger.debug("Can not translate sid:%s into a user id", sid) - return {} - - _client_sid = {} - usids = _sso_db.get_sids_by_uid(uid) - if usids is None: - logger.debug("No sessions found for uid: %s", uid) - return {} + self.endpoint_context.session_manager.revoke_client_session(sid) - for usid in usids: - _client_sid[_sdb[usid]["authn_req"]["client_id"]] = usid + def logout_all_clients(self, sid): + _mngr = self.endpoint_context.session_manager + _session_info = _mngr.get_session_info(sid, user_session_info=True, + client_session_info=True) # Front-/Backchannel logout ? _cdb = self.endpoint_context.cdb _iss = self.endpoint_context.issuer + _user_id = _session_info["user_id"] bc_logouts = {} fc_iframes = {} - for _cid, _csid in _client_sid.items(): - if "backchannel_logout_uri" in _cdb[_cid]: - _sub = _sso_db.get_sub_by_sid(_csid) - _spec = self.do_back_channel_logout(_cdb[_cid], _sub, _csid) + _rel_sid = [] + for _client_id in _session_info["user_session_info"]["subordinate"]: + if "backchannel_logout_uri" in _cdb[_client_id]: + _sid = session_key(_user_id, _client_id) + _rel_sid.append(_sid) + _spec = self.do_back_channel_logout(_cdb[_client_id], _sid) if _spec: - bc_logouts[_cid] = _spec - elif "frontchannel_logout_uri" in _cdb[_cid]: + bc_logouts[_client_id] = _spec + elif "frontchannel_logout_uri" in _cdb[_client_id]: # Construct an IFrame - _spec = do_front_channel_logout_iframe(_cdb[_cid], _iss, _csid) + _sid = session_key(_user_id, _client_id) + _rel_sid.append(_sid) + _spec = do_front_channel_logout_iframe(_cdb[_client_id], _iss, _sid) if _spec: - fc_iframes[_cid] = _spec + fc_iframes[_client_id] = _spec - try: - self.clean_sessions(usids) - except KeyError: - pass + self.clean_sessions(_rel_sid) res = {} if bc_logouts: @@ -189,33 +195,26 @@ def unpack_signed_jwt(self, sjwt, sig_alg=""): else: raise ValueError("Not a signed JWT") - def logout_from_client(self, sid, client_id): + def logout_from_client(self, sid): _cdb = self.endpoint_context.cdb - _sso_db = self.endpoint_context.sdb.sso_db - - # Kill the session - _sdb = self.endpoint_context.sdb - _sdb.revoke_session(sid=sid) + _session_information = self.endpoint_context.session_manager.get_session_info( + sid, grant=True) + _client_id = _session_information["client_id"] res = {} - if "backchannel_logout_uri" in _cdb[client_id]: - _sub = _sso_db.get_sub_by_sid(sid) - _spec = self.do_back_channel_logout(_cdb[client_id], _sub, sid) + if "backchannel_logout_uri" in _cdb[_client_id]: + _spec = self.do_back_channel_logout(_cdb[_client_id], sid) if _spec: - res["blu"] = {client_id: _spec} - elif "frontchannel_logout_uri" in _cdb[client_id]: + res["blu"] = {_client_id: _spec} + elif "frontchannel_logout_uri" in _cdb[_client_id]: # Construct an IFrame _spec = do_front_channel_logout_iframe( - _cdb[client_id], self.endpoint_context.issuer, sid + _cdb[_client_id], self.endpoint_context.issuer, sid ) if _spec: - res["flu"] = {client_id: _spec} - - try: - self.clean_sessions([sid]) - except KeyError: - pass + res["flu"] = {_client_id: _spec} + self.clean_sessions([sid]) return res def process_request(self, request=None, cookie=None, **kwargs): @@ -228,7 +227,7 @@ def process_request(self, request=None, cookie=None, **kwargs): :return: """ _cntx = self.endpoint_context - _sdb = _cntx.sdb + _mngr = _cntx.session_manager if "post_logout_redirect_uri" in request: if "id_token_hint" not in request: @@ -249,63 +248,31 @@ def process_request(self, request=None, cookie=None, **kwargs): # value is a base64 encoded JSON document _cookie_info = json.loads(as_unicode(b64d(as_bytes(part[0])))) logger.debug("Cookie info: {}".format(_cookie_info)) - _sid = _cookie_info["sid"] + try: + _session_info = _mngr.get_session_info(_cookie_info["sid"], + grant=True) + except KeyError: + raise ValueError("Can't find any corresponding session") else: logger.debug("No relevant cookie") - _sid = "" - _cookie_info = {} + raise ValueError("Missing cookie") - if "id_token_hint" in request: + if "id_token_hint" in request and _session_info: + _id_token = request[verified_claim_name("id_token_hint")] logger.debug( - "ID token hint: {}".format( - request[verified_claim_name("id_token_hint")] - ) + "ID token hint: {}".format(_id_token) ) - auds = request[verified_claim_name("id_token_hint")]["aud"] - _ith_sid = "" - _sids = _sdb.sso_db.get_sids_by_sub( - request[verified_claim_name("id_token_hint")]["sub"] - ) - - if _sids is None: - raise ValueError("Unknown subject identifier") - - for _isid in _sids: - if _sdb[_isid]["authn_req"]["client_id"] in auds: - _ith_sid = _isid - break + _aud = _id_token["aud"] + if _session_info["client_id"] not in _aud: + raise ValueError("Client ID doesn't match") - if not _ith_sid: - raise ValueError("Unknown subject") - - if _sid: - if _ith_sid != _sid: # someone's messing with me - raise ValueError("Wrong ID Token hint") - else: - _sid = _ith_sid + if _id_token["sub"] != _session_info["grant"].sub: + raise ValueError("Sub doesn't match") else: - auds = [] - - try: - session = _sdb[_sid] - except KeyError: - raise ValueError("Can't find any corresponding session") - - client_id = session["authn_req"]["client_id"] - # Does this match what's in the cookie ? - if _cookie_info: - if client_id != _cookie_info["client_id"]: - logger.warning( - "Client ID in authz request and in cookie does not match" - ) - raise ValueError("Wrong Client") - - if auds: - if client_id not in auds: - raise ValueError("Incorrect ID Token hint") + _aud = [] - _cinfo = _cntx.cdb[client_id] + _cinfo = _cntx.cdb[_session_info["client_id"]] # verify that the post_logout_redirect_uri if present are among the ones # registered @@ -320,12 +287,11 @@ def process_request(self, request=None, cookie=None, **kwargs): plur = False else: plur = True - verify_uri(_cntx, request, "post_logout_redirect_uri", client_id=client_id) + verify_uri(_cntx, request, "post_logout_redirect_uri", + client_id=_session_info["client_id"]) payload = { - "sid": _sid, - "client_id": client_id, - "user": session["authn_event"]["uid"], + "sid": _session_info["session_id"], } # redirect user to OP logout verification page @@ -387,20 +353,20 @@ def parse_request(self, request, auth=None, **kwargs): pass else: if ( - _ith.jws_header["alg"] - not in self.endpoint_context.provider_info[ - "id_token_signing_alg_values_supported" - ] + _ith.jws_header["alg"] + not in self.endpoint_context.provider_info[ + "id_token_signing_alg_values_supported" + ] ): raise JWSException("Unsupported signing algorithm") return request - def do_verified_logout(self, sid, client_id, alla=False, **kwargs): + def do_verified_logout(self, sid, alla=False, **kwargs): if alla: - _res = self.logout_all_clients(sid=sid, client_id=client_id) + _res = self.logout_all_clients(sid=sid) else: - _res = self.logout_from_client(sid=sid, client_id=client_id) + _res = self.logout_from_client(sid=sid) bcl = _res.get("blu") if bcl: diff --git a/src/oidcendpoint/oidc/token.py b/src/oidcendpoint/oidc/token.py index 74eb7f9..1dffa2c 100755 --- a/src/oidcendpoint/oidc/token.py +++ b/src/oidcendpoint/oidc/token.py @@ -1,45 +1,94 @@ import logging +from typing import Optional +from typing import Union from cryptojwt.jwe.exception import JWEException from cryptojwt.jws.exception import NoSuitableSigningKeys +from cryptojwt.jwt import utc_time_sans_frac from oidcmsg import oidc +from oidcmsg.message import Message from oidcmsg.oauth2 import ResponseMessage -from oidcmsg.oidc import AccessTokenResponse +from oidcmsg.oidc import RefreshAccessTokenRequest from oidcmsg.oidc import TokenErrorResponse +from oidcmsg.time_util import time_sans_frac from oidcendpoint import sanitize from oidcendpoint.cookie import new_cookie from oidcendpoint.endpoint import Endpoint -from oidcendpoint.exception import MultipleCodeUsage -from oidcendpoint.token_handler import AccessCodeUsed -from oidcendpoint.userinfo import by_schema +from oidcendpoint.exception import ProcessError +from oidcendpoint.session import unpack_session_key +from oidcendpoint.session.grant import AuthorizationCode +from oidcendpoint.session.grant import Grant +from oidcendpoint.session.grant import RefreshToken +from oidcendpoint.session.token import Token as sessionToken +from oidcendpoint.token.exception import UnknownToken +from oidcendpoint.util import importer logger = logging.getLogger(__name__) -class AccessToken(Endpoint): - request_cls = oidc.AccessTokenRequest - response_cls = oidc.AccessTokenResponse - error_cls = TokenErrorResponse - request_format = "json" - request_placement = "body" - response_format = "json" - response_placement = "body" - endpoint_name = "token_endpoint" - name = "token" - default_capabilities = {"token_endpoint_auth_signing_alg_values_supported": None} +class TokenEndpointHelper(object): + def __init__(self, endpoint, config=None): + self.endpoint = endpoint + self.config = config + self.endpoint_context = endpoint.endpoint_context + self.error_cls = self.endpoint.error_cls - def __init__(self, endpoint_context, **kwargs): - Endpoint.__init__(self, endpoint_context, **kwargs) - self.post_parse_request.append(self._post_parse_request) - if "client_authn_method" in kwargs: - self.endpoint_info["token_endpoint_auth_methods_supported"] = kwargs[ - "client_authn_method" - ] + def post_parse_request(self, request: Union[Message, dict], + client_id: Optional[str] = "", + **kwargs): + """Context specific parsing of the request. + This is done after general request parsing and before processing + the request. + """ + raise NotImplementedError + + def process_request(self, req: Union[Message, dict], **kwargs): + """Acts on a process request.""" + raise NotImplementedError + + def _mint_token(self, token_type: str, grant: Grant, session_id: str, + based_on: Optional[sessionToken] = None) -> sessionToken: + _mngr = self.endpoint_context.session_manager + usage_rules = grant.usage_rules.get(token_type) + if usage_rules: + _exp_in = usage_rules.get("expires_in") + else: + _exp_in = 0 + + token = grant.mint_token( + session_id, + endpoint_context=self.endpoint_context, + token_type=token_type, + token_handler=_mngr.token_handler["access_token"], + based_on=based_on, + usage_rules=usage_rules + ) + + if _exp_in: + if isinstance(_exp_in, str): + _exp_in = int(_exp_in) + + if _exp_in: + token.expires_at = time_sans_frac() + _exp_in + + self.endpoint_context.session_manager.set( + unpack_session_key(session_id), grant) + + return token + + +class AccessTokenHelper(TokenEndpointHelper): + def process_request(self, req: Union[Message, dict], **kwargs): + """ - def _access_token(self, req, **kwargs): - _context = self.endpoint_context - _sdb = _context.sdb + :param req: + :param kwargs: + :return: + """ + _context = self.endpoint.endpoint_context + + _mngr = _context.session_manager _log_debug = logger.debug if req["grant_type"] != "authorization_code": @@ -54,21 +103,11 @@ def _access_token(self, req, **kwargs): error="invalid_request", error_description="Missing code" ) - # Session might not exist or _access_code malformed - try: - _info = _sdb[_access_code] - except KeyError: - return self.error_cls( - error="invalid_grant", error_description="Code is invalid" - ) - - _authn_req = _info["authn_req"] + _session_info = _mngr.get_session_info_by_token(_access_code, grant=True) + grant = _session_info["grant"] - # assert that the code is valid - if _context.sdb.is_session_revoked(_access_code): - return self.error_cls( - error="invalid_grant", error_description="Session is revoked" - ) + code = grant.get_token(_access_code) + _authn_req = grant.authorization_request # If redirect_uri was in the initial authorization request # verify that the one given here is the correct one. @@ -83,26 +122,132 @@ def _access_token(self, req, **kwargs): issue_refresh = False if "issue_refresh" in kwargs: issue_refresh = kwargs["issue_refresh"] + else: + if "offline_access" in grant.scope: + issue_refresh = True + + _response = { + "token_type": "Bearer", + "scope": grant.scope, + } + + token = self._mint_token(token_type="access_token", + grant=grant, + session_id=_session_info["session_id"], + based_on=code) + _response["access_token"] = token.value + _response["expires_in"] = token.expires_at - utc_time_sans_frac() + + if issue_refresh: + refresh_token = self._mint_token(token_type="refresh_token", + grant=grant, + session_id=_session_info["session_id"], + based_on=code) + _response["refresh_token"] = refresh_token.value + + code.register_usage() + + # since the grant content has changed + _mngr[_session_info["session_id"]] = grant + + if "openid" in _authn_req["scope"]: + try: + _idtoken = _context.idtoken.make(_session_info["session_id"]) + except (JWEException, NoSuitableSigningKeys) as err: + logger.warning(str(err)) + resp = self.error_cls( + error="invalid_request", + error_description="Could not sign/encrypt id_token", + ) + return resp + + _response["id_token"] = _idtoken - # offline_access the default if nothing is specified - permissions = _info.get("permission", ["offline_access"]) + return _response - if "offline_access" in _authn_req["scope"] and "offline_access" in permissions: - issue_refresh = True + def post_parse_request(self, request: Union[Message, dict], + client_id: Optional[str] = "", + **kwargs): + """ + This is where clients come to get their access tokens + :param request: The request + :param client_id: Client identifier + :returns: + """ + + _mngr = self.endpoint_context.session_manager try: - _info = _sdb.upgrade_to_token(_access_code, issue_refresh=issue_refresh) - except AccessCodeUsed as err: - logger.error("%s" % err) - # Should revoke the token issued to this access code - _sdb.revoke_all_tokens(_access_code) + _session_info = _mngr.get_session_info_by_token(request["code"], + grant=True) + except (KeyError, UnknownToken): + logger.error("Access Code invalid") + return self.error_cls(error="invalid_grant", + error_description="Unknown code") + + _grant = _session_info["grant"] + code = _grant.get_token(request["code"]) + if not isinstance(code, AuthorizationCode): return self.error_cls( - error="access_denied", error_description="Access Code already used" + error="invalid_request", error_description="Wrong token type" ) - if "openid" in _authn_req["scope"]: + if code.is_active() is False: + return self.error_cls( + error="invalid_request", error_description="Code inactive" + ) + + _auth_req = _grant.authorization_request + + if "client_id" not in request: # Optional for access token request + request["client_id"] = _auth_req["client_id"] + + logger.debug("%s: %s" % (request.__class__.__name__, sanitize(request))) + + return request + + +class RefreshTokenHelper(TokenEndpointHelper): + def process_request(self, req: Union[Message, dict], **kwargs): + _mngr = self.endpoint_context.session_manager + + if req["grant_type"] != "refresh_token": + return self.error_cls( + error="invalid_request", error_description="Wrong grant_type" + ) + + token_value = req["refresh_token"] + _session_info = _mngr.get_session_info_by_token(token_value, grant=True) + + _grant = _session_info["grant"] + token = _grant.get_token(token_value) + + access_token = self._mint_token(token_type="access_token", + grant=_grant, + session_id=_session_info["session_id"], + based_on=token) + + _resp = { + "access_token": access_token.value, + "token_type": "Bearer", + "scope": _grant.scope + } + + if access_token.expires_at: + _resp["expires_in"] = access_token.expires_at - utc_time_sans_frac() + + _mints = token.usage_rules.get("supports_minting") + if "refresh_token" in _mints: + refresh_token = self._mint_token(token_type="refresh_token", + grant=_grant, + session_id=_session_info["session_id"], + based_on=token) + refresh_token.usage_rules = token.usage_rules.copy() + _resp["refresh_token"] = refresh_token.value + + if "id_token" in _mints: try: - _idtoken = _context.idtoken.make(req, _info, _authn_req) + _idtoken = self.endpoint_context.idtoken.make(_session_info["session_id"]) except (JWEException, NoSuitableSigningKeys) as err: logger.warning(str(err)) resp = self.error_cls( @@ -111,50 +256,135 @@ def _access_token(self, req, **kwargs): ) return resp - _sdb.update_by_token(_access_code, id_token=_idtoken) - _info = _sdb[_info["sid"]] - - return by_schema(AccessTokenResponse, **_info) + _resp["id_token"] = _idtoken - def get_client_id_from_token(self, endpoint_context, token, request=None): - sinfo = endpoint_context.sdb[token] - return sinfo["authn_req"]["client_id"] + return _resp - def _post_parse_request(self, request, client_id="", **kwargs): + def post_parse_request(self, request: Union[Message, dict], + client_id: Optional[str] = "", + **kwargs): """ - This is where clients come to get their access tokens + This is where clients come to refresh their access tokens :param request: The request :param authn: Authentication info, comes from HTTP header :returns: """ - if "state" in request: - try: - sinfo = self.endpoint_context.sdb[request["code"]] - except KeyError: - logger.error("Code not present in SessionDB") - return self.error_cls(error="access_denied") - except MultipleCodeUsage: - logger.error("Access Code reused") - # Remove any access tokens issued - self.endpoint_context.sdb.revoke_all_tokens(request["code"]) - return self.error_cls(error="invalid_grant") - else: - state = sinfo["authn_req"]["state"] + request = RefreshAccessTokenRequest(**request.to_dict()) - if state != request["state"]: - logger.error("State value mismatch") - return self.error_cls(error="invalid_request") + try: + keyjar = self.endpoint_context.keyjar + except AttributeError: + keyjar = "" - if "client_id" not in request: # Optional for access token request - request["client_id"] = client_id + request.verify(keyjar=keyjar, opponent_id=client_id) - logger.debug("%s: %s" % (request.__class__.__name__, sanitize(request))) + _mngr = self.endpoint_context.session_manager + try: + _session_info = _mngr.get_session_info_by_token( + request["refresh_token"], grant=True + ) + except KeyError: + logger.error("Access Code invalid") + return self.error_cls(error="invalid_grant") + + _grant = _session_info["grant"] + token = _grant.get_token(request["refresh_token"]) + + if not isinstance(token, RefreshToken): + return self.error_cls( + error="invalid_request", error_description="Wrong token type" + ) + + if token.is_active() is False: + return self.error_cls( + error="invalid_request", error_description="Refresh token inactive" + ) return request - def process_request(self, request=None, **kwargs): + +HELPER_BY_GRANT_TYPE = { + "authorization_code": AccessTokenHelper, + "refresh_token": RefreshTokenHelper, +} + + +class Token(Endpoint): + request_cls = oidc.Message + response_cls = oidc.AccessTokenResponse + error_cls = TokenErrorResponse + request_format = "json" + request_placement = "body" + response_format = "json" + response_placement = "body" + endpoint_name = "token_endpoint" + name = "token" + default_capabilities = {"token_endpoint_auth_signing_alg_values_supported": None} + + def __init__(self, endpoint_context, new_refresh_token=False, **kwargs): + Endpoint.__init__(self, endpoint_context, **kwargs) + self.post_parse_request.append(self._post_parse_request) + if "client_authn_method" in kwargs: + self.endpoint_info["token_endpoint_auth_methods_supported"] = kwargs[ + "client_authn_method" + ] + self.allow_refresh = False + self.new_refresh_token = new_refresh_token + self.configure_grant_types(kwargs.get("grant_types_supported")) + + def configure_grant_types(self, grant_types_supported): + if grant_types_supported is None: + self.helper = {k: v(self) for k, v in HELPER_BY_GRANT_TYPE.items()} + return + + self.helper = {} + # TODO: do we want to allow any grant_type? + for grant_type, grant_type_options in grant_types_supported.items(): + if ( + grant_type_options in (None, True) + and grant_type in HELPER_BY_GRANT_TYPE + ): + self.helper[grant_type] = HELPER_BY_GRANT_TYPE[grant_type] + continue + elif grant_type_options is False: + continue + + try: + grant_class = grant_type_options["class"] + except (KeyError, TypeError): + raise ProcessError( + "Token Endpoint's grant types must be True, None or a dict with a" + " 'class' key." + ) + _conf = grant_type_options.get("kwargs", {}) + + if isinstance(grant_class, str): + try: + grant_class = importer(grant_class) + except (ValueError, AttributeError): + raise ProcessError( + f"Token Endpoint's grant type class {grant_class} can't" + " be imported." + ) + try: + self.helper[grant_type] = grant_class(self, _conf) + except Exception as e: + raise ProcessError(f"Failed to initialize class {grant_class}: {e}") + + def _post_parse_request(self, request: Union[Message, dict], + client_id: Optional[str] = "", **kwargs): + _helper = self.helper.get(request["grant_type"]) + if _helper: + return _helper.post_parse_request(request, client_id, **kwargs) + else: + return self.error_cls( + error="invalid_request", + error_description=f"Unsupported grant_type: {request['grant_type']}" + ) + + def process_request(self, request: Optional[Union[Message, dict]] = None, **kwargs): """ :param request: @@ -163,8 +393,19 @@ def process_request(self, request=None, **kwargs): """ if isinstance(request, self.error_cls): return request + + if request is None: + return self.error_cls(error="invalid_request") + try: - response_args = self._access_token(request, **kwargs) + _helper = self.helper.get(request["grant_type"]) + if _helper: + response_args = _helper.process_request(request, **kwargs) + else: + return self.error_cls( + error="invalid_request", + error_description=f"Unsupported grant_type: {request['grant_type']}" + ) except JWEException as err: return self.error_cls(error="invalid_request", error_description="%s" % err) @@ -172,11 +413,15 @@ def process_request(self, request=None, **kwargs): return response_args _access_token = response_args["access_token"] + _session_info = self.endpoint_context.session_manager.get_session_info_by_token( + _access_token, grant=True) + _cookie = new_cookie( self.endpoint_context, - sub=self.endpoint_context.sdb[_access_token]["sub"], + sub=_session_info["grant"].sub, cookie_name=self.endpoint_context.cookie_name["session"], ) + _headers = [("Content-type", "application/json")] resp = {"response_args": response_args, "http_headers": _headers} if _cookie: diff --git a/src/oidcendpoint/oidc/token_coop.py b/src/oidcendpoint/oidc/token_coop.py deleted file mode 100755 index 56272b9..0000000 --- a/src/oidcendpoint/oidc/token_coop.py +++ /dev/null @@ -1,328 +0,0 @@ -import logging - -from cryptojwt.jwe.exception import JWEException -from cryptojwt.jws.exception import NoSuitableSigningKeys -from oidcmsg import oidc -from oidcmsg.exception import MissingRequiredAttribute -from oidcmsg.exception import MissingRequiredValue -from oidcmsg.oauth2 import ResponseMessage -from oidcmsg.oidc import AccessTokenRequest -from oidcmsg.oidc import AccessTokenResponse -from oidcmsg.oidc import RefreshAccessTokenRequest -from oidcmsg.oidc import TokenErrorResponse -from oidcmsg.storage import importer - -from oidcendpoint import sanitize -from oidcendpoint.cookie import new_cookie -from oidcendpoint.endpoint import Endpoint -from oidcendpoint.exception import MultipleCodeUsage -from oidcendpoint.exception import ProcessError -from oidcendpoint.token_handler import AccessCodeUsed -from oidcendpoint.token_handler import ExpiredToken -from oidcendpoint.token_handler import UnknownToken -from oidcendpoint.userinfo import by_schema - -logger = logging.getLogger(__name__) - - -class TokenEndpointHelper: - def __init__(self, endpoint, config=None): - self.endpoint = endpoint - self.config = config - - def post_parse_request(self, request, client_id="", **kwargs): - """Context specific parsing of the request. - This is done after general request parsing and before processing - the request. - """ - raise NotImplementedError - - def process_request(self, req, **kwargs): - """Acts on a process request.""" - raise NotImplementedError - - -class AccessToken(TokenEndpointHelper): - def post_parse_request(self, request, client_id="", **kwargs): - """ - This is where clients come to get their access tokens - - :param request: The request - :param authn: Authentication info, comes from HTTP header - :returns: - """ - - request = AccessTokenRequest(**request.to_dict()) - - error_cls = self._validate_state(request) - if error_cls is not None: - return error_cls - - if "client_id" not in request: # Optional for access token request - request["client_id"] = client_id - - logger.debug("%s: %s" % (request.__class__.__name__, sanitize(request))) - - return request - - def _validate_state(self, request): - if "state" not in request: - return - - if "code" not in request: - return - - try: - sinfo = self.endpoint.endpoint_context.sdb[request["code"]] - except (KeyError, UnknownToken, MultipleCodeUsage): - return - - state = sinfo["authn_req"]["state"] - - if state != request["state"]: - logger.error("State value mismatch") - return self.endpoint.error_cls(error="invalid_request") - - def process_request(self, req, **kwargs): - _context = self.endpoint.endpoint_context - _sdb = _context.sdb - _log_debug = logger.debug - - try: - _access_code = req["code"].replace(" ", "+") - except KeyError: # Missing code parameter - absolutely fatal - return self.endpoint.error_cls( - error="invalid_request", error_description="Missing code" - ) - - try: - _info = _sdb[_access_code] - except (KeyError, UnknownToken): - logger.error("Code not present in SessionDB") - return self.endpoint.error_cls( - error="invalid_grant", error_description="Invalid code" - ) - except MultipleCodeUsage: - return self.endpoint.error_cls( - error="invalid_grant", error_description="Code is already used" - ) - - _authn_req = _info["authn_req"] - - # assert that the code is valid - if _context.sdb.is_session_revoked(_access_code): - return self.endpoint.error_cls( - error="invalid_grant", error_description="Session is revoked" - ) - - # If redirect_uri was in the initial authorization request - # verify that the one given here is the correct one. - if "redirect_uri" in _authn_req: - if req["redirect_uri"] != _authn_req["redirect_uri"]: - return self.endpoint.error_cls( - error="invalid_request", error_description="redirect_uri mismatch" - ) - - _log_debug("All checks OK") - - issue_refresh = False - if "issue_refresh" in kwargs: - issue_refresh = kwargs["issue_refresh"] - - # offline_access the default if nothing is specified - permissions = _info.get("permission", ["offline_access"]) - - if "offline_access" in _authn_req["scope"] and "offline_access" in permissions: - issue_refresh = True - - try: - _info = _sdb.upgrade_to_token(_access_code, issue_refresh=issue_refresh) - except AccessCodeUsed as err: - logger.error("%s" % err) - # Should revoke the token issued to this access code - _sdb.revoke_all_tokens(_access_code) - return self.endpoint.error_cls( - error="access_denied", error_description="Access Code already used" - ) - - if "openid" in _authn_req["scope"]: - try: - _idtoken = _context.idtoken.make(req, _info, _authn_req) - except (JWEException, NoSuitableSigningKeys) as err: - logger.warning(str(err)) - resp = self.error_cls( - error="invalid_request", - error_description="Could not sign/encrypt id_token", - ) - return resp - - _sdb.update_by_token(_access_code, id_token=_idtoken) - _info = _sdb[_info["sid"]] - - return by_schema(AccessTokenResponse, **_info) - - -class RefreshToken(TokenEndpointHelper): - def post_parse_request(self, request, client_id="", **kwargs): - """ - This is where clients come to refresh their access tokens - - :param request: The request - :param authn: Authentication info, comes from HTTP header - :returns: - """ - request = RefreshAccessTokenRequest(**request.to_dict()) - - # verify that the request message is correct - try: - request.verify( - keyjar=self.endpoint.endpoint_context.keyjar, opponent_id=client_id - ) - except (MissingRequiredAttribute, ValueError, MissingRequiredValue) as err: - return self.endpoint.error_cls( - error="invalid_request", error_description="%s" % err - ) - - if "client_id" not in request: # Optional for refresh access token request - request["client_id"] = client_id - - logger.debug("%s: %s" % (request.__class__.__name__, sanitize(request))) - - return request - - def process_request(self, req, **kwargs): - _sdb = self.endpoint.endpoint_context.sdb - - rtoken = req["refresh_token"] - try: - _info = _sdb.refresh_token( - rtoken, new_refresh=self.endpoint.new_refresh_token - ) - except ExpiredToken: - return self.error_cls( - error="invalid_request", error_description="Refresh token is expired" - ) - - return by_schema(AccessTokenResponse, **_info) - - -HELPER_BY_GRANT_TYPE = { - "authorization_code": AccessToken, - "refresh_token": RefreshToken, -} - - -class TokenCoop(Endpoint): - request_cls = oidc.Message - response_cls = oidc.AccessTokenResponse - error_cls = TokenErrorResponse - request_format = "json" - request_placement = "body" - response_format = "json" - response_placement = "body" - endpoint_name = "token_endpoint" - name = "token" - default_capabilities = {"token_endpoint_auth_signing_alg_values_supported": None} - - def __init__(self, endpoint_context, new_refresh_token=False, **kwargs): - Endpoint.__init__(self, endpoint_context, **kwargs) - self.post_parse_request.append(self._post_parse_request) - if "client_authn_method" in kwargs: - self.endpoint_info["token_endpoint_auth_methods_supported"] = kwargs[ - "client_authn_method" - ] - - # self.allow_refresh = False - self.new_refresh_token = new_refresh_token - - self.configure_grant_types(kwargs.get("grant_types_supported")) - - def configure_grant_types(self, grant_types_supported): - if grant_types_supported is None: - self.helper = {k: v(self) for k, v in HELPER_BY_GRANT_TYPE.items()} - return - - self.helper = {} - # TODO: do we want to allow any grant_type? - for grant_type, grant_type_options in grant_types_supported.items(): - if ( - grant_type_options in (None, True) - and grant_type in HELPER_BY_GRANT_TYPE - ): - self.helper[grant_type] = HELPER_BY_GRANT_TYPE[grant_type] - continue - elif grant_type_options is False: - continue - - try: - grant_class = grant_type_options["class"] - except (KeyError, TypeError): - raise ProcessError( - "Token Endpoint's grant types must be True, None or a dict with a" - " 'class' key." - ) - _conf = grant_type_options.get("kwargs", {}) - - if isinstance(grant_class, str): - try: - grant_class = importer(grant_class) - except (ValueError, AttributeError): - raise ProcessError( - f"Token Endpoint's grant type class {grant_class} can't" - " be imported." - ) - try: - self.helper[grant_type] = grant_class(self, _conf) - except Exception as e: - raise ProcessError(f"Failed to initialize class {grant_class}: {e}") - - def get_client_id_from_token(self, endpoint_context, token, request=None): - sinfo = endpoint_context.sdb[token] - return sinfo["authn_req"]["client_id"] - - def _post_parse_request(self, request, client_id="", **kwargs): - _helper = self.helper.get(request["grant_type"]) - if _helper: - return _helper.post_parse_request(request, client_id, **kwargs) - else: - return self.error_cls( - error="invalid_request", - error_description=f"Unsupported grant_type: {request['grant_type']}" - ) - - def process_request(self, request=None, **kwargs): - """ - - :param request: - :param kwargs: - :return: Dictionary with response information - """ - if isinstance(request, self.error_cls): - return request - try: - _helper = self.helper.get(request["grant_type"]) - if _helper: - response_args = _helper.process_request(request, **kwargs) - else: - return self.error_cls( - error="invalid_request", - error_description=f"Unsupported grant_type: {request['grant_type']}" - ) - except JWEException as err: - return self.error_cls(error="invalid_request", error_description="%s" % err) - - if isinstance(response_args, ResponseMessage): - return response_args - - _access_token = response_args["access_token"] - _cookie = new_cookie( - self.endpoint_context, - sub=self.endpoint_context.sdb[_access_token]["sub"], - cookie_name=self.endpoint_context.cookie_name["session"], - ) - - _headers = [("Content-type", "application/json")] - resp = {"response_args": response_args, "http_headers": _headers} - if _cookie: - resp["cookie"] = _cookie - return resp diff --git a/src/oidcendpoint/oidc/userinfo.py b/src/oidcendpoint/oidc/userinfo.py index ffe57e7..dd2990f 100755 --- a/src/oidcendpoint/oidc/userinfo.py +++ b/src/oidcendpoint/oidc/userinfo.py @@ -1,16 +1,19 @@ import json import logging +from typing import Optional +from typing import Union from cryptojwt.exception import MissingValue from cryptojwt.jwt import JWT from cryptojwt.jwt import utc_time_sans_frac -from oidcendpoint.token_handler import UnknownToken from oidcmsg import oidc from oidcmsg.message import Message from oidcmsg.oauth2 import ResponseMessage from oidcendpoint.endpoint import Endpoint -from oidcendpoint.userinfo import collect_user_info +from oidcendpoint.endpoint_context import EndpointContext +from oidcendpoint.session.token import AccessToken +from oidcendpoint.token.exception import UnknownToken from oidcendpoint.util import OAUTH2_NOCACHE_HEADERS logger = logging.getLogger(__name__) @@ -32,17 +35,25 @@ class UserInfo(Endpoint): "client_authn_method": ["bearer_header"], } - def __init__(self, endpoint_context, **kwargs): - Endpoint.__init__(self, endpoint_context, **kwargs) - self.scope_to_claims = None + def __init__(self, endpoint_context: EndpointContext, + add_claims_by_scope: Optional[bool] = True, + **kwargs): + Endpoint.__init__( + self, + endpoint_context, + add_claims_by_scope=add_claims_by_scope, + **kwargs, + ) # Add the issuer ID as an allowed JWT target self.allowed_targets.append("") def get_client_id_from_token(self, endpoint_context, token, request=None): - sinfo = self.endpoint_context.sdb[token] - return sinfo["authn_req"]["client_id"] + _info = endpoint_context.session_manager.get_session_info_by_token(token) + return _info["client_id"] - def do_response(self, response_args=None, request=None, client_id="", **kwargs): + def do_response(self, response_args: Optional[Union[Message, dict]] = None, + request: Optional[Union[Message, dict]] = None, + client_id: Optional[str] = "", **kwargs) -> dict: if "error" in kwargs and kwargs["error"]: return Endpoint.do_response(self, response_args, request, **kwargs) @@ -96,23 +107,31 @@ def do_response(self, response_args=None, request=None, client_id="", **kwargs): return {"response": resp, "http_headers": http_headers} def process_request(self, request=None, **kwargs): - _sdb = self.endpoint_context.sdb - + _mngr = self.endpoint_context.session_manager + _session_info = _mngr.get_session_info_by_token(request["access_token"], + grant=True) + _grant = _session_info["grant"] + token = _grant.get_token(request["access_token"]) # should be an access token - if not _sdb.is_token_valid(request["access_token"]): + if not isinstance(token, AccessToken): return self.error_cls( - error="invalid_token", error_description="Invalid Token" + error="invalid_token", error_description="Wrong type of token" ) - session = _sdb.read(request["access_token"]) + # And it should be valid + if token.is_active() is False: + return self.error_cls( + error="invalid_token", error_description="Invalid Token" + ) allowed = True + _auth_event = _grant.authentication_event # if the authenticate is still active or offline_access is granted. - if session["authn_event"]["valid_until"] > utc_time_sans_frac(): + if _auth_event["valid_until"] > utc_time_sans_frac(): pass else: logger.debug("authentication not valid: {} > {}".format( - session["authn_event"]["valid_until"], utc_time_sans_frac() + _auth_event["valid_until"], utc_time_sans_frac() )) allowed = False @@ -121,15 +140,18 @@ def process_request(self, request=None, **kwargs): # pass if allowed: - # Scope can translate to userinfo_claims - info = collect_user_info(self.endpoint_context, session) + _claims = _grant.claims.get("userinfo") + info = self.endpoint_context.claims_interface.get_user_claims( + user_id=_session_info["user_id"], + claims_restriction=_claims) + info["sub"] = _grant.sub else: info = { "error": "invalid_request", "error_description": "Access not granted", } - return {"response_args": info, "client_id": session["authn_req"]["client_id"]} + return {"response_args": info, "client_id": _session_info["client_id"]} def parse_request(self, request, auth=None, **kwargs): """ diff --git a/src/oidcendpoint/scopes.py b/src/oidcendpoint/scopes.py index 8a04544..80166b9 100644 --- a/src/oidcendpoint/scopes.py +++ b/src/oidcendpoint/scopes.py @@ -1,5 +1,4 @@ # default set can be changed by configuration -from oidcmsg.oidc import OpenIDSchema SCOPE2CLAIMS = { "openid": ["sub"], @@ -25,9 +24,6 @@ "offline_access": [], } -IGNORE = ["error", "error_description", "error_uri", "_claim_names", "_claim_sources"] -STANDARD_CLAIMS = [c for c in OpenIDSchema.c_param.keys() if c not in IGNORE] - def available_scopes(endpoint_context): _supported = endpoint_context.provider_info.get("scopes_supported") @@ -37,27 +33,23 @@ def available_scopes(endpoint_context): return [s for s in endpoint_context.scope2claims.keys()] -def available_claims(endpoint_context): - _supported = endpoint_context.provider_info.get("claims_supported") - if _supported: - return _supported - else: - return STANDARD_CLAIMS - - -def convert_scopes2claims(scopes, allowed_claims, map=None): +def convert_scopes2claims(scopes, allowed_claims=None, map=None): if map is None: map = SCOPE2CLAIMS res = {} - for scope in scopes: - try: - claims = dict( - [(name, None) for name in map[scope] if name in allowed_claims] - ) + if allowed_claims is None: + for scope in scopes: + claims = {name: None for name in map[scope]} res.update(claims) - except KeyError: - continue + else: + for scope in scopes: + try: + claims = {name: None for name in map[scope] if name in allowed_claims} + res.update(claims) + except KeyError: + continue + return res @@ -85,25 +77,3 @@ def allowed_scopes(self, client_id, endpoint_context): def filter_scopes(self, client_id, endpoint_context, scopes): allowed_scopes = self.allowed_scopes(client_id, endpoint_context) return [s for s in scopes if s in allowed_scopes] - - -class Claims: - def __init__(self): - pass - - def allowed_claims(self, client_id, endpoint_context): - """ - Returns the set of claims that a specific client can use. - - :param client_id: The client identifier - :param endpoint_context: A EndpointContext instance - :returns: List of claim names. Can be empty. - """ - _cli = endpoint_context.cdb.get(client_id) - if _cli is not None: - _claims = _cli.get("allowed_claims") - if _claims: - return _claims - else: - return available_claims(endpoint_context) - return [] diff --git a/src/oidcendpoint/session.py b/src/oidcendpoint/session.py deleted file mode 100644 index 46f71d3..0000000 --- a/src/oidcendpoint/session.py +++ /dev/null @@ -1,621 +0,0 @@ -import hashlib -import json -import logging -import time - -from oidcmsg.exception import MissingParameter -from oidcmsg.message import OPTIONAL_LIST_OF_STRINGS -from oidcmsg.message import SINGLE_OPTIONAL_STRING -from oidcmsg.message import SINGLE_REQUIRED_STRING -from oidcmsg.message import Message -from oidcmsg.message import msg_ser -from oidcmsg.oidc import AuthorizationRequest -from oidcmsg.time_util import utc_time_sans_frac - -from oidcendpoint import token_handler -from oidcendpoint.authn_event import AuthnEvent -from oidcendpoint.exception import MultipleCodeUsage -from oidcendpoint.token_handler import ExpiredToken -from oidcendpoint.token_handler import UnknownToken -from oidcendpoint.token_handler import WrongTokenType -from oidcendpoint.token_handler import is_expired -from oidcendpoint.util import sector_id_from_redirect_uris - -logger = logging.getLogger(__name__) - - -def authorization_request_deser(val, sformat="urlencoded"): - if sformat in ["dict", "json"]: - if not isinstance(val, str): - val = json.dumps(val) - sformat = "json" - return AuthorizationRequest().deserialize(val, sformat) - - -SINGLE_REQUIRED_AUTHORIZATION_REQUEST = ( - Message, - True, - msg_ser, - authorization_request_deser, - False, -) - - -def authn_event_deser(val, sformat="urlencoded"): - if sformat in ["dict", "json"]: - if not isinstance(val, str): - val = json.dumps(val) - sformat = "json" - return AuthnEvent().deserialize(val, sformat) - - -def setup_session( - endpoint_context, - areq, - uid, - client_id="", - acr="", - salt="salt", - authn_event=None, - subject_type=None, - sector_id=None -): - """ - Setting up a user session - - :param endpoint_context: - :param areq: - :param uid: - :param acr: - :param client_id: - :param salt: - :param authn_event: A already made AuthnEvent - :return: - """ - if authn_event is None and acr: - authn_event = AuthnEvent( - uid=uid, salt=salt, authn_info=acr, authn_time=time.time() - ) - - if not client_id: - client_id = areq["client_id"] - - sid = endpoint_context.sdb.create_authz_session( - authn_event, areq, client_id=client_id, uid=uid - ) - - client = endpoint_context.cdb.get(client_id) - endpoint_context.sdb.do_sub( - sid, - uid, - salt=salt, - sector_id=sector_id, - subject_type=subject_type, - client=client, - ) - return sid - - -SINGLE_REQUIRED_AUTHN_EVENT = (Message, True, msg_ser, authn_event_deser, False) - - -class SessionInfo(Message): - c_param = { - "oauth_state": SINGLE_REQUIRED_STRING, - "code": SINGLE_OPTIONAL_STRING, - "authn_req": SINGLE_REQUIRED_AUTHORIZATION_REQUEST, - "client_id": SINGLE_REQUIRED_STRING, - "authn_event": SINGLE_REQUIRED_AUTHN_EVENT, - "si_redirects": OPTIONAL_LIST_OF_STRINGS, - } - - -def pairwise_id(uid, sector_identifier, salt, **kwargs): - return hashlib.sha256( - ("%s%s%s" % (uid, sector_identifier, salt)).encode("utf-8") - ).hexdigest() - - -def public_id(uid, salt="", **kwargs): - return hashlib.sha256("{}{}".format(uid, salt).encode("utf-8")).hexdigest() - - -def dict_match(a, b): - """ - Check if all attribute/value pairs in a also appears in b - - :param a: A dictionary - :param b: A dictionary - :return: True/False - """ - res = [] - for k, v in a.items(): - try: - res.append(b[k] == v) - except KeyError: - pass - return all(res) - - -class SessionDB(object): - def __init__(self, db, handler, sso_db, userinfo=None, sub_func=None): - self._db = db - self.handler = handler - self.sso_db = sso_db - self.userinfo = userinfo - - # this allows the subject identifier minters to be defined by someone - # else then me. - if sub_func is None: - self.sub_func = {"public": public_id, "pairwise": pairwise_id} - else: - self.sub_func = sub_func - if "public" not in sub_func: - self.sub_func["public"] = public_id - if "pairwise" not in sub_func: - self.sub_func["pairwise"] = pairwise_id - - def __getitem__(self, item): - _info = self._db.get(item) - - if _info is None: - sid = self.handler.sid(item) - _info = self._db.get(sid) - if _info: - _si = SessionInfo(**_info) - if any(item == val for val in _si.values()): - _si["sid"] = sid - return _si - else: - _handler, _res = self.handler.get_handler(item) - if _handler.type == 'A': # access grant - if _si['code_is_used']: - raise MultipleCodeUsage('Reused code') - else: - _si = SessionInfo(**_info) - _si["sid"] = item - return _si - raise KeyError - - def __setitem__(self, sid, instance): - if isinstance(instance, Message): - _info = instance.to_dict() - else: - _info = instance - self._db[sid] = _info - - def __delitem__(self, key): - del self._db[key] - - def keys(self): - return self._db.keys() - - def create_authz_session(self, authn_event, areq, client_id="", uid="", **kwargs): - """ - - :param authn_event: - :param areq: - :param client_id: - :param uid: - :param kwargs: - :return: - """ - try: - _uid = authn_event["uid"] - except (TypeError, KeyError): - _uid = uid - - if not _uid: - raise MissingParameter('Need a "uid"') - - sid = self.handler["code"].key(user=_uid, areq=areq) - - access_grant = self.handler["code"](sid=sid) - - _info = SessionInfo(code=access_grant, oauth_state="authz") - - if client_id: - _info["client_id"] = client_id - - if areq: - _info["authn_req"] = areq - if authn_event: - _info["authn_event"] = authn_event - - if kwargs: - _info.update(kwargs) - - self[sid] = _info - return sid - - def update(self, sid, **kwargs): - """ - Add attribute value assertion to a special session - - :param sid: Session ID - :param kwargs: - """ - item = self[sid] - for attribute, value in kwargs.items(): - item[attribute] = value - self[sid] = item - - def update_by_token(self, token, **kwargs): - """ - Updated the session info. Any type of known token can be used - - :param token: code/access token/refresh token/... - :param kwargs: Key word arguements - """ - _sid = self.handler.sid(token) - return self.update(_sid, **kwargs) - - def set(self, label, key, value): - logger.debug("Session set {} - {}: {}".format(label, key, value)) - # Try loading the key - _dic = self._db.get(key, {}) - if label in _dic: - _dic[label].append(value) - else: - _dic[label] = [value] - self._db[key] = _dic - - def get(self, key, seckey): - _dic = self._db.get(key, {}) - logger.debug("SSODb get {} - {}: {}".format(key, seckey, _dic)) - return _dic.get(seckey, None) - - def map_kv2sid(self, key, seckey, sid): - self.sso_db.set(key, seckey, sid) - - def delete_kv2sid(self, key, seckey): - self.sso_db.delete(key, seckey) - - def get_sid_by_kv(self, key, seckey): - return self.sso_db.get(key, seckey) - - def get_token(self, sid): - _sess_info = self[sid] - - if _sess_info["oauth_state"] == "authz": - return _sess_info["code"] - elif _sess_info["oauth_state"] == "token": - return _sess_info["access_token"] - - def do_sub( - self, - sid, - uid, - salt=None, - sector_id=None, - subject_type=None, - client=None, - ): - """ - Create and store a subject identifier - - :param sid: Session ID - :param uid: User ID - :param salt: - :param sector_id: For pairwise identifiers, an Identifier for the RP group - :param subject_type: 'pairwaise'/'public' - :param client_id: - :return: - """ - if not client: - client = {} - - salt = salt - if not salt: - salt = client.get("client_salt", salt) - if not subject_type: - # Default to public - subject_type = client.get("subject_type", "public") - if subject_type == "pairwise" and not sector_id: - sector_id = client.get("sector_identifier_uri") - # If no `sector_identifier_uri` is registered, then get the host - # component of the registered redirect_uri - if not sector_id: - uris = [uri[0] for uri in client.get("redirect_uris", [])] - sector_id = sector_id_from_redirect_uris(uris) - - sub = self.sub_func[subject_type]( - uid, salt=salt, sector_identifier=sector_id - ) - - self.sso_db.map_sid2uid(sid, uid) - self.update(sid, sub=sub) - self.sso_db.map_sid2sub(sid, sub) - - return sub - - def is_valid(self, typ, item): - try: - return typ in self[item] - except (KeyError, MultipleCodeUsage): - return False - - def get_sids_by_sub(self, sub): - return self.sso_db.get_sids_by_sub(sub) - - def get_sid_by_sub_and_client_id(self, sub, client_id): - for sid in self.sso_db.get_sids_by_sub(sub): - if self[sid]["authn_req"]["client_id"] == client_id: - return sid - return None - - def replace_refresh_token(self, sid, sinfo): - """ - Replace an old refresh_token with a new one - - :param sid: session ID - :param sinfo: session info - :return: Updated session info - """ - refresh_token = self.handler["refresh_token"](sid, sinfo=sinfo) - sinfo["refresh_token"] = refresh_token - return sinfo - - def _make_at(self, sid, session_info, aud=None, client_id_aud=True): - uid = self.sso_db.get_uid_by_sid(sid) - client_id = session_info["client_id"] - uinfo = self.userinfo(uid, client_id) or {} - at_aud = aud or [] - - if client_id_aud: - at_aud.append(client_id) - return self.handler["access_token"]( - sid=sid, sinfo=session_info, uinfo=uinfo, aud=at_aud, client_id=client_id - ) - - def upgrade_to_token( - self, - grant=None, - issue_refresh=False, - id_token="", - oidreq=None, - key=None, - scope=None, - ): - """ - - :param grant: The access grant - :param issue_refresh: If a refresh token should be issued - :param id_token: An IDToken instance - :param oidreq: An OpenIDRequest instance - :param key: The session key. One of grant or key must be given. - :return: The session information as a SessionInfo instance - """ - if grant: - # The caller is responsible for checking if the access code exists. - _tinfo = self.handler["code"].info(grant) - - key = _tinfo["sid"] - session_info = self[key] - - # mint a new access token - _at = self._make_at(_tinfo["sid"], session_info) - - # make sure the code can't be used again - self.revoke_token(key, "code", session_info) - else: - session_info = self[key] - _at = self._make_at(key, session_info) - - session_info["access_token"] = _at - session_info["oauth_state"] = "token" - session_info["token_type"] = self.handler["access_token"].token_type - - if scope: - session_info["access_token_scope"] = scope - if id_token: - session_info["id_token"] = id_token - if oidreq: - session_info["oidreq"] = oidreq - - if self.handler["access_token"].lifetime: - session_info["expires_in"] = self.handler["access_token"].lifetime - session_info["expires_at"] = ( - self.handler["access_token"].lifetime + utc_time_sans_frac() - ) - - if issue_refresh and "refresh_token" in self.handler: - session_info = self.replace_refresh_token(key, session_info) - - self[key] = session_info - return session_info - - def refresh_token(self, token, new_refresh=False): - """ - Issue a new access token using a valid refresh token - - :param token: Refresh token - :param new_refresh: Whether a new refresh token should be minted or not - :return: Dictionary with session info - :raises: ExpiredToken for invalid refresh token - WrongTokenType for wrong token type - """ - try: - _tinfo = self.handler["refresh_token"].info(token) - except KeyError: - return False - - _sid = _tinfo["sid"] - session_info = self[_sid] - if token != session_info.get("refresh_token"): - raise UnknownToken() - if is_expired(int(_tinfo["exp"])): - raise ExpiredToken() - - session_info["access_token"] = self._make_at(_sid, session_info) - session_info["token_type"] = self.handler["access_token"].token_type - - if new_refresh: - session_info = self.replace_refresh_token(_sid, session_info) - - self[_sid] = session_info - return session_info - - def is_token_valid(self, token): - """ - Checks validity of a given token - - :param token: Access or refresh token - """ - - try: - _tinfo = self.handler.info(token) - except KeyError: - return False - - # Dependent on what state the session is in. - session_info = self[_tinfo["sid"]] - if is_expired(int(_tinfo["exp"])): - return False - - if session_info["oauth_state"] == "authz": - if _tinfo["handler"] != self.handler["code"]: - return False - elif session_info["oauth_state"] == "token": - if _tinfo["handler"] != self.handler["access_token"]: - return False - - return True - - def revoke_token(self, sid, token_type, session_info=None): - """ - Revokes token - - :param sid: session id - :param token_type: token type, one of "code", "access_token" or - "refresh_token" - """ - if not session_info: - session_info = self[sid] - session_info.pop(token_type, None) - if token_type == 'code': - session_info['code_is_used'] = True - self[sid] = session_info - - def revoke_all_tokens(self, token): - sid = self.handler.sid(token) - _sinfo = self[sid] - for token_type in self.handler.keys(): - _sinfo.pop(token_type, None) - self[sid] = _sinfo - - def revoke_session(self, sid="", token=""): - """ - Mark session as revoked but also explicitly revoke all issued tokens - - :param token: any token connected to the session - :param sid: Session identifier - """ - if not sid: - if token: - sid = self.handler.sid(token) - else: - raise ValueError('Need one of "sid" or "token"') - - _sinfo = self[sid] - for token_type in self.handler.keys(): - _sinfo.pop(token_type, None) - _sinfo["revoked"] = True - self[sid] = _sinfo - - def get_client_id_for_session(self, sid): - return self[sid]["client_id"] - - def get_active_client_ids_for_uid(self, uid): - res = [] - for sid in self.sso_db.get_sids_by_uid(uid): - if "revoked" not in self[sid]: - res.append(self[sid]["client_id"]) - return res - - def get_verified_logout(self, uid): - res = {} - for sid in self.sso_db.get_sids_by_uid(uid): - session_info = self[sid] - try: - res[session_info["client_id"]] = session_info["verified_logout"] - except KeyError: - res[session_info["client_id"]] = False - return res - - def match_session(self, uid, **kwargs): - for sid in self.sso_db.get_sids_by_uid(uid): - session_info = self[sid] - if dict_match(kwargs, session_info): - return sid - return None - - def set_verify_logout(self, uid, client_id): - sid = self.match_session(uid, client_id=client_id) - self.update(sid, verified_logout=True) - - def get_id_token(self, uid, client_id): - sid = self.match_session(uid, client_id=client_id) - return self[sid]["id_token"] - - def is_session_revoked(self, key): - try: - session_info = self[key] - except Exception: - raise UnknownToken(key) - - try: - return session_info["revoked"] - except KeyError: - return False - - def revoke_uid(self, uid): - # Revoke all sessions - for sid in self.sso_db.get_sids_by_uid(uid): - self.update(sid, revoked=True) - - # Remove the uid from the SSO db - self.sso_db.remove_uid(uid) - - def read(self, token): - try: - _tinfo = self.handler["access_token"].info(token) - except WrongTokenType: - return {} - else: - return self[_tinfo["sid"]] - - def find_sid(self, req): - """ - Given a request with some info find the correct session. - The useful claim is 'code'. - - :param req: An AccessTokenRequest instance - :return: session ID or None - """ - - return self.get_sid_by_kv(req["code"], "code") - - def get_authentication_event(self, sid): - try: - session_info = self[sid] - except Exception: - raise UnknownToken(sid) - else: - sesinf = session_info.get("authn_event") - return sesinf or ValueError("No Authn event info") - - -def create_session_db(ec, token_handler_args, db=None, sso_db=None, sub_func=None): - _token_handler = token_handler.factory(ec, **token_handler_args) - # if db is None: - # db = InMemoryDataBase() - # else: - # db = db - # - # if sso_db is None: - # sso_db = SSODb() - # else: - # sso_db = sso_db - - return SessionDB(db, _token_handler, sso_db, sub_func=sub_func) diff --git a/src/oidcendpoint/session/__init__.py b/src/oidcendpoint/session/__init__.py new file mode 100644 index 0000000..0757c4e --- /dev/null +++ b/src/oidcendpoint/session/__init__.py @@ -0,0 +1,19 @@ +from typing import List + +DIVIDER = ";;" + + +def session_key(*args) -> str: + return DIVIDER.join(args) + + +def unpack_session_key(key: str) -> List[str]: + return key.split(DIVIDER) + + +class Revoked(Exception): + pass + + +class MintingNotAllowed(Exception): + pass diff --git a/src/oidcendpoint/session/claims.py b/src/oidcendpoint/session/claims.py new file mode 100755 index 0000000..b389cb4 --- /dev/null +++ b/src/oidcendpoint/session/claims.py @@ -0,0 +1,185 @@ +import logging +from typing import Optional +from typing import Union + +from oidcmsg.oidc import OpenIDSchema + +from oidcendpoint.scopes import convert_scopes2claims +from oidcendpoint.session import unpack_session_key + +logger = logging.getLogger(__name__) + +# USAGE = Literal["userinfo", "id_token", "introspection"] + +IGNORE = ["error", "error_description", "error_uri", "_claim_names", "_claim_sources"] +STANDARD_CLAIMS = [c for c in OpenIDSchema.c_param.keys() if c not in IGNORE] + + +def available_claims(endpoint_context): + _supported = endpoint_context.provider_info.get("claims_supported") + if _supported: + return _supported + else: + return STANDARD_CLAIMS + + +class ClaimsInterface: + init_args = { + "add_claims_by_scope": False, + "enable_claims_per_client": False + } + + def __init__(self, endpoint_context): + self.endpoint_context = endpoint_context + + def authorization_request_claims(self, session_id: str, usage: Optional[str] = "") -> dict: + if usage in ["id_token", "userinfo"]: + _grant = self.endpoint_context.session_manager.get_grant(session_id) + if "claims" in _grant.authorization_request: + return _grant.authorization_request["claims"].get(usage, {}) + + return {} + + def _get_client_claims(self, client_id, usage): + client_info = self.endpoint_context.cdb.get(client_id, {}) + client_claims = client_info.get("{}_claims".format(usage), {}) + if isinstance(client_claims, list): + client_claims = {k: None for k in client_claims} + return client_claims + + def get_claims(self, session_id: str, scopes: str, usage: str) -> dict: + """ + + :param session_id: Session identifier + :param scopes: Scopes + :param usage: Where to use the claims. One of "userinfo"/"id_token"/"introspection" + :return: Claims specification as a dictionary. + """ + + # which endpoint module configuration to get the base claims from + module = None + if usage == "userinfo": + if "userinfo" in self.endpoint_context.endpoint: + module = self.endpoint_context.endpoint["userinfo"] + elif usage == "id_token": + if self.endpoint_context.idtoken: + module = self.endpoint_context.idtoken + elif usage == "introspection": + if "introspection" in self.endpoint_context.endpoint: + module = self.endpoint_context.endpoint["introspection"] + elif usage == "access_token": + try: + module = self.endpoint_context.session_manager.token_handler["access_token"] + except KeyError: + pass + + if module: + base_claims = module.kwargs.get("base_claims", {}) + else: + base_claims = {} + + user_id, client_id, grant_id = unpack_session_key(session_id) + + # Can there be per client specification of which claims to use. + if module and module.kwargs.get("enable_claims_per_client"): + claims = self._get_client_claims(client_id, usage) + else: + claims = {} + + claims.update(base_claims) + + # Scopes can in some cases equate to set of claims, is that used here ? + if module and module.kwargs.get("add_claims_by_scope"): + if scopes: + _scopes = self.endpoint_context.scopes_handler.filter_scopes( + client_id, self.endpoint_context, scopes + ) + + _claims = convert_scopes2claims( + _scopes, map=self.endpoint_context.scope2claims + ) + claims.update(_claims) + + # Bring in claims specification from the authorization request + request_claims = self.authorization_request_claims(session_id=session_id, + usage=usage) + + # This will add claims that has not be added before and + # set filters on those claims that also appears in one of the sources above + if request_claims: + claims.update(request_claims) + + return claims + + def get_claims_all_usage(self, session_id: str, scopes: str) -> dict: + _claims = {} + for usage in ["userinfo", "introspection", "id_token", "token"]: + _claims.update(self.get_claims(session_id, scopes, usage)) + return _claims + + def get_user_claims(self, user_id: str, claims_restriction: dict) -> dict: + """ + + :param user_id: User identifier + :param claims_restriction: Specifies the upper limit of which claims can be returned + :return: + """ + if claims_restriction: + # Get all possible claims + user_info = self.endpoint_context.userinfo(user_id, client_id=None) + # Filter out the once that can be returned + return {k: user_info.get(k) for k, v in claims_restriction.items() if + claims_match(user_info.get(k), v)} + else: + return {} + + +def claims_match(value: Union[str, int], claimspec: Optional[dict]) -> bool: + """ + Implements matching according to section 5.5.1 of + http://openid.net/specs/openid-connect-core-1_0.html + The lack of value is not checked here. + Also the text doesn't prohibit having both 'value' and 'values'. + + :param value: single value + :param claimspec: None or dictionary with 'essential', 'value' or 'values' + as key + :return: Boolean + """ + if value is None: + return False + + if claimspec is None: # match anything + return True + + matched = False + for key, val in claimspec.items(): + if key == "value": + if value == val: + matched = True + elif key == "values": + if value in val: + matched = True + elif key == "essential": + # Whether it's essential or not doesn't change anything here + continue + + if matched: + break + + if matched is False: + if list(claimspec.keys()) == ["essential"]: + return True + + return matched + + +def by_schema(cls, **kwa): + """ + Will return only those claims that are listed in the Class definition. + + :param cls: A subclass of :py:class:´oidcmsg.message.Message` + :param kwa: Keyword arguments + :return: A dictionary with claims (keys) that meets the filter criteria + """ + return dict([(key, val) for key, val in kwa.items() if key in cls.c_param]) diff --git a/src/oidcendpoint/session/database.py b/src/oidcendpoint/session/database.py new file mode 100644 index 0000000..301a7fb --- /dev/null +++ b/src/oidcendpoint/session/database.py @@ -0,0 +1,147 @@ +import logging +from typing import List +from typing import Union + +from . import session_key +from .grant import Grant +from .info import ClientSessionInfo +from .info import SessionInfo +from .info import UserSessionInfo + +logger = logging.getLogger(__name__) + + +class NoSuchClientSession(KeyError): + pass + + +class NoSuchGrant(KeyError): + pass + + +class Database(object): + def __init__(self, storage=None): + if storage is None: + self._db = {} + else: + self._db = storage + + @staticmethod + def _eval_path(path: List[str]): + uid = path[0] + client_id = None + grant_id = None + if len(path) > 1: + client_id = path[1] + if len(path) > 2: + grant_id = path[2] + + return uid, client_id, grant_id + + def set(self, path: List[str], value: Union[SessionInfo, Grant]): + """ + + :param path: a list of identifiers + :param value: Class instance to be stored + """ + + uid, client_id, grant_id = self._eval_path(path) + + if grant_id: + gid_key = session_key(uid, client_id, grant_id) + self._db[gid_key] = value + + if client_id: + cid_key = session_key(uid, client_id) + cid_info = self._db.get(cid_key, ClientSessionInfo()) + if not grant_id: + self._db[cid_key] = value + elif grant_id not in cid_info["subordinate"]: + cid_info.add_subordinate(grant_id) + self._db[cid_key] = cid_info + + userinfo = self._db.get(uid, UserSessionInfo()) + if client_id is None: + self._db[uid] = value + if client_id and client_id not in userinfo["subordinate"]: + userinfo.add_subordinate(client_id) + self._db[uid] = userinfo + + def get(self, path: List[str]) -> Union[SessionInfo, Grant]: + uid, client_id, grant_id = self._eval_path(path) + try: + user_info = self._db[uid] + except KeyError: + raise KeyError('No such UserID') + else: + if user_info is None: + raise KeyError('No such UserID') + + if client_id is None: + return user_info + else: + if client_id not in user_info['subordinate']: + raise ValueError('No session from that client for that user') + else: + try: + client_session_info = self._db[session_key(uid, client_id)] + except KeyError: + raise NoSuchClientSession(client_id) + else: + if grant_id is None: + return client_session_info + + if grant_id not in client_session_info['subordinate']: + raise ValueError('No such grant for that user and client') + else: + try: + return self._db[session_key(uid, client_id, grant_id)] + except KeyError: + raise NoSuchGrant(grant_id) + + def delete(self, path: List[str]): + uid, client_id, grant_id = self._eval_path(path) + try: + _user_info = self._db[uid] + except KeyError: + pass + else: + if client_id: + if client_id in _user_info['subordinate']: + try: + _client_info = self._db[session_key(uid, client_id)] + except KeyError: + pass + else: + if grant_id: + if grant_id in _client_info['subordinate']: + try: + self._db.__delitem__(session_key(uid, client_id, grant_id)) + except KeyError: + pass + _client_info["subordinate"].remove(grant_id) + else: + for grant_id in _client_info['subordinate']: + self._db.__delitem__(session_key(uid, client_id, grant_id)) + _client_info['subordinate'] = [] + + if len(_client_info['subordinate']) == 0: + self._db.__delitem__(session_key(uid, client_id)) + _user_info["subordinate"].remove(client_id) + else: + self._db[client_id] = _client_info + + if len(_user_info["subordinate"]) == 0: + self._db.__delitem__(uid) + else: + self._db[uid] = _user_info + else: + pass + else: + self._db.__delitem__(uid) + + def update(self, path: List[str], new_info: dict): + _info = self.get(path) + for key, val in new_info.items(): + _info[key] = val + self.set(path, _info) diff --git a/src/oidcendpoint/session/grant.py b/src/oidcendpoint/session/grant.py new file mode 100644 index 0000000..ac846f1 --- /dev/null +++ b/src/oidcendpoint/session/grant.py @@ -0,0 +1,350 @@ +import json +from typing import Optional +from uuid import uuid1 + +from oidcmsg.message import Message +from oidcmsg.message import OPTIONAL_LIST_OF_SP_SEP_STRINGS +from oidcmsg.message import OPTIONAL_LIST_OF_STRINGS +from oidcmsg.message import SINGLE_OPTIONAL_JSON +from oidcmsg.oauth2 import AuthorizationRequest + +from oidcendpoint.authn_event import AuthnEvent +from oidcendpoint.session import MintingNotAllowed +from oidcendpoint.session import unpack_session_key +from oidcendpoint.session.token import AccessToken +from oidcendpoint.session.token import AuthorizationCode +from oidcendpoint.session.token import Item +from oidcendpoint.session.token import RefreshToken +from oidcendpoint.session.token import Token +from oidcendpoint.token import Token as TokenHandler +from oidcendpoint.util import importer + + +class GrantMessage(Message): + c_param = { + "scope": OPTIONAL_LIST_OF_SP_SEP_STRINGS, # As defined in RFC6749 + "authorization_details": SINGLE_OPTIONAL_JSON, # As defined in draft-lodderstedt-oauth-rar + "claims": SINGLE_OPTIONAL_JSON, # As defined in OIDC core + "resources": OPTIONAL_LIST_OF_STRINGS, # As defined in RFC8707 + } + + +GRANT_TYPE_MAP = { + "authorization_code": "code", + "access_token": "access_token", + "refresh_token": "refresh_token" +} + + +def find_token(issued, token_id): + for iss in issued: + if iss.id == token_id: + return iss + return None + + +TOKEN_MAP = { + "authorization_code": AuthorizationCode, + "access_token": AccessToken, + "refresh_token": RefreshToken +} + + +class Grant(Item): + attributes = ["scope", "claim", "resources", "authorization_details", + "issued_token", "usage_rules", "revoked", "issued_at", + "expires_at", "sub", "authorization_request", + "authentication_event"] + type = "grant" + + def __init__(self, + scope: Optional[list] = None, + claims: Optional[dict] = None, + resources: Optional[list] = None, + authorization_details: Optional[dict] = None, + authorization_request: Optional[Message] = None, + authentication_event: Optional[AuthnEvent] = None, + issued_token: Optional[list] = None, + usage_rules: Optional[dict] = None, + issued_at: int = 0, + expires_in: int = 0, + expires_at: int = 0, + revoked: bool = False, + token_map: Optional[dict] = None, + sub: Optional[str] = ""): + Item.__init__(self, usage_rules=usage_rules, issued_at=issued_at, + expires_in=expires_in, expires_at=expires_at, revoked=revoked) + self.scope = scope or [] + self.authorization_details = authorization_details or None + self.authorization_request = authorization_request or None + self.authentication_event = authentication_event or None + self.claims = claims or {} # default is to not release any user information + self.resources = resources or [] + self.issued_token = issued_token or [] + self.id = uuid1().hex + self.sub = sub + + if token_map is None: + self.token_map = TOKEN_MAP + else: + self.token_map = token_map + + def get(self) -> object: + return GrantMessage(scope=self.scope, claims=self.claims, + authorization_details=self.authorization_details, + resources=self.resources) + + def to_json(self) -> str: + d = { + "type": self.type, + "scope": self.scope, + "sub": self.sub, + "authorization_details": self.authorization_details, + "claims": self.claims, + "resources": self.resources, + "issued_at": self.issued_at, + "not_before": self.not_before, + "expires_at": self.expires_at, + "revoked": self.revoked, + "issued_token": [t.to_json() for t in self.issued_token], + "id": self.id, + "usage_rules": self.usage_rules, + "token_map": {k: ".".join([v.__module__, v.__name__]) for k, v in + self.token_map.items()} + } + + if self.authorization_request: + if isinstance(self.authorization_request, Message): + d["authorization_request"] = self.authorization_request.to_dict() + elif isinstance(self.authorization_request, dict): + d["authorization_request"] = self.authorization_request + + if self.authentication_event: + if isinstance(self.authentication_event, Message): + d["authentication_event"] = self.authentication_event.to_dict() + elif isinstance(self.authentication_event, dict): + d["authentication_event"] = self.authentication_event + + return json.dumps(d) + + def from_json(self, json_str: str) -> 'Grant': + d = json.loads(json_str) + for attr in ["scope", "authorization_details", "claims", "resources", "sub", + "issued_at", "not_before", "expires_at", "revoked", "id", + "usage_rules"]: + if attr in d: + setattr(self, attr, d[attr]) + + if "authentication_event" in d: + self.authentication_event = AuthnEvent(**d["authentication_event"]) + if "authorization_request" in d: + self.authorization_request = AuthorizationRequest(**d["authorization_request"]) + + if "token_map" in d: + self.token_map = {k: importer(v) for k, v in d["token_map"].items()} + else: + self.token_map = TOKEN_MAP + + if "issued_token" in d: + _it = [] + for js in d["issued_token"]: + args = json.loads(js) + _it.append(self.token_map[args["type"]](**args)) + setattr(self, "issued_token", _it) + + return self + + def payload_arguments(self, session_id: str, endpoint_context, + token_type: str, scope: Optional[dict] = None) -> dict: + """ + + :return: dictionary containing information to place in a token value + """ + if not scope: + scope = self.scope + + payload = { + "scope": scope, + "aud": self.resources + } + + _claims_restriction = endpoint_context.claims_interface.get_claims(session_id, + scopes=scope, + usage=token_type) + user_id, _, _ = unpack_session_key(session_id) + user_info = endpoint_context.claims_interface.get_user_claims(user_id, + _claims_restriction) + payload.update(user_info) + + return payload + + def mint_token(self, + session_id: str, + endpoint_context: object, + token_type: str, + token_handler: TokenHandler = None, + based_on: Optional[Token] = None, + usage_rules: Optional[dict] = None, + scope: Optional[list] = None, + **kwargs) -> Optional[Token]: + """ + + :param session_id: + :param endpoint_context: + :param token_type: + :param token_handler: + :param based_on: + :param usage_rules: + :param scope: + :param kwargs: + :return: + """ + if self.is_active() is False: + return None + + if based_on: + if based_on.supports_minting(token_type) and based_on.is_active(): + _base_on_ref = based_on.value + else: + raise MintingNotAllowed() + else: + _base_on_ref = None + + if usage_rules is None and token_type in self.usage_rules: + usage_rules = self.usage_rules[token_type] + + token_class = self.token_map.get(token_type) + if token_class: + item = token_class(type=token_type, + based_on=_base_on_ref, + usage_rules=usage_rules, + scope=scope, + **kwargs) + if token_handler is None: + token_handler = endpoint_context.session_manager.token_handler.handler[ + GRANT_TYPE_MAP[token_type]] + + item.value = token_handler(session_id=session_id, + **self.payload_arguments(session_id, + endpoint_context, + token_type=token_type, + scope=scope)) + else: + raise ValueError("Can not mint that kind of token") + + self.issued_token.append(item) + self.used += 1 + return item + + def get_token(self, value: str) -> Optional[Token]: + for t in self.issued_token: + if t.value == value: + return t + return None + + def revoke_token(self, + value: Optional[str] = "", + based_on: Optional[str] = "", + recursive: bool = True): + for t in self.issued_token: + if not value and not based_on: + t.revoked = True + elif value and based_on: + if value == t.value and based_on == t.based_on: + t.revoked = True + elif value and t.value == value: + t.revoked = True + if recursive: + self.revoke_token(based_on=t.value) + elif based_on and t.based_on == based_on: + t.revoked = True + if recursive: + self.revoke_token(based_on=t.value) + + def get_spec(self, token: Token) -> Optional[dict]: + if self.is_active() is False or token.is_active is False: + return None + + res = {} + for attr in ["scope", "claims", "resources"]: + _val = getattr(token, attr) + if _val: + res[attr] = _val + else: + _val = getattr(self, attr) + if _val: + res[attr] = _val + return res + + +DEFAULT_USAGE = { + "authorization_code": { + "max_usage": 1, + "supports_minting": ["access_token", "refresh_token", "id_token"], + "expires_in": 300 + }, + "access_token": { + "supports_minting": [], + "expires_in": 3600 + }, + "refresh_token": { + "supports_minting": ["access_token", "refresh_token", "id_token"] + } +} + + +def get_usage_rules(token_type, endpoint_context, grant, client_id): + """ + The order of importance: + Grant, Client, EndPointContext, Default + + :param token_type: The type of token + :param endpoint_context: An EndpointContext instance + :param grant: A Grant instance + :param client_id: The client identifier + :return: Usage specification + """ + + _usage = endpoint_context.authz.usage_rules_for(client_id, token_type) + if not _usage: + _usage = DEFAULT_USAGE[token_type] + + _grant_usage = grant.usage_rules.get(token_type) + if _grant_usage: + _usage.update(_grant_usage) + + return _usage + + +class ExchangeGrant(Grant): + attributes = Grant.attributes + attributes.append("users") + type = "exchange_grant" + + def __init__(self, + scope: Optional[list] = None, + claims: Optional[dict] = None, + resources: Optional[list] = None, + authorization_details: Optional[dict] = None, + issued_token: Optional[list] = None, + usage_rules: Optional[dict] = None, + issued_at: int = 0, + expires_in: int = 0, + expires_at: int = 0, + revoked: bool = False, + token_map: Optional[dict] = None, + users: list = None): + Grant.__init__(self, scope=scope, claims=claims, resources=resources, + authorization_details=authorization_details, + issued_token=issued_token, usage_rules=usage_rules, + issued_at=issued_at, expires_in=expires_in, + expires_at=expires_at, revoked=revoked, + token_map=token_map) + + self.users = users or [] + self.usage_rules = { + "access_token": { + "supports_minting": ["access_token"], + "expires_in": 60 + } + } diff --git a/src/oidcendpoint/session/info.py b/src/oidcendpoint/session/info.py new file mode 100644 index 0000000..4126db0 --- /dev/null +++ b/src/oidcendpoint/session/info.py @@ -0,0 +1,70 @@ +from typing import Tuple + +from oidcmsg.message import OPTIONAL_LIST_OF_MESSAGES +from oidcmsg.message import SINGLE_OPTIONAL_STRING +from oidcmsg.message import SINGLE_REQUIRED_BOOLEAN +from oidcmsg.message import SINGLE_REQUIRED_STRING +from oidcmsg.message import Message +from oidcmsg.message import list_deserializer +from oidcmsg.message import list_serializer + +from oidcendpoint.session.grant import Grant +from oidcendpoint.session.token import Token + + +class SessionInfo(Message): + c_param = { + "subordinate": ([str], False, list_serializer, list_deserializer, True), + "revoked": SINGLE_REQUIRED_BOOLEAN, + "type": SINGLE_REQUIRED_STRING + } + + def __init__(self, **kwargs): + Message.__init__(self, **kwargs) + if "subordinate" not in self: + self["subordinate"] = [] + self["revoked"] = False + + def add_subordinate(self, value: str) -> "SessionInfo": + if value not in self["subordinate"]: + self["subordinate"].append(value) + return self + + def remove_subordinate(self, value: str) -> 'SessionInfo': + self["subordinate"].remove(value) + return self + + def revoke(self) -> 'SessionInfo': + self["revoked"] = True + return self + + def is_revoked(self) -> bool: + return self["revoked"] + + +class UserSessionInfo(SessionInfo): + c_param = SessionInfo.c_param.copy() + c_param.update({ + "user_id": SINGLE_REQUIRED_STRING, + }) + + def __init__(self, **kwargs): + SessionInfo.__init__(self, **kwargs) + self["type"] = "UserSessionInfo" + + +class ClientSessionInfo(SessionInfo): + c_param = SessionInfo.c_param.copy() + c_param.update({ + "client_id": SINGLE_REQUIRED_STRING + }) + + def __init__(self, **kwargs): + SessionInfo.__init__(self, **kwargs) + self["type"] = "ClientSessionInfo" + + def find_grant_and_token(self, val: str) -> Tuple[Grant, Token]: + for grant in self["subordinate"]: + token = grant.get_token(val) + if token: + return grant, token diff --git a/src/oidcendpoint/session/manager.py b/src/oidcendpoint/session/manager.py new file mode 100644 index 0000000..b297ca3 --- /dev/null +++ b/src/oidcendpoint/session/manager.py @@ -0,0 +1,437 @@ +import hashlib +import logging +import uuid +from typing import List +from typing import Optional + +from oidcmsg.oauth2 import AuthorizationRequest + +from oidcendpoint import rndstr +from oidcendpoint.authn_event import AuthnEvent +from oidcendpoint.token import handler + +from ..token import UnknownToken +from ..token.handler import TokenHandler +from . import session_key +from . import unpack_session_key +from .database import Database +from .grant import Grant +from .grant import Token +from .info import ClientSessionInfo +from .info import UserSessionInfo + +logger = logging.getLogger(__name__) + + +def pairwise_id(uid, sector_identifier, salt="", **kwargs): + return hashlib.sha256(("%s%s%s" % (uid, sector_identifier, salt)).encode("utf-8")).hexdigest() + + +def public_id(uid, salt="", **kwargs): + return hashlib.sha256("{}{}".format(uid, salt).encode("utf-8")).hexdigest() + + +def ephemeral_id(*args, **kwargs): + return uuid.uuid4().hex + + +class SessionManager(Database): + def __init__(self, handler: TokenHandler, db: Optional[object] = None, + conf: Optional[dict] = None, sub_func: Optional[dict] = None): + Database.__init__(self, db) + self.token_handler = handler + self.salt = rndstr(32) + self.conf = conf or {} + + # this allows the subject identifier minters to be defined by someone + # else then me. + if sub_func is None: + self.sub_func = { + "public": public_id, + "pairwise": pairwise_id, + "ephemeral": ephemeral_id + } + else: + self.sub_func = sub_func + if "public" not in sub_func: + self.sub_func["public"] = public_id + if "pairwise" not in sub_func: + self.sub_func["pairwise"] = pairwise_id + if "ephemeral" not in sub_func: + self.sub_func["ephemeral"] = ephemeral_id + + def get_user_info(self, uid: str) -> UserSessionInfo: + usi = self.get([uid]) + if isinstance(usi, UserSessionInfo): + return usi + else: + raise ValueError("Not UserSessionInfo") + + def find_token(self, session_id: str, token_value: str) -> Optional[Token]: + """ + + :param session_id: Based on 3-tuple, user_id, client_id and grant_id + :param token_value: + :return: + """ + user_id, client_id, grant_id = unpack_session_key(session_id) + grant = self.get([user_id, client_id, grant_id]) + for token in grant.issued_token: + if token.value == token_value: + return token + + return None + + def create_grant(self, + authn_event: AuthnEvent, + auth_req: AuthorizationRequest, + user_id: str, + client_id: Optional[str] = "", + sub_type: Optional[str] = "public", + token_usage_rules: Optional[dict] = None, + scopes: Optional[list] = None + ) -> str: + """ + + :param scopes: Scopes + :param authn_event: AuthnEvent instance + :param auth_req: + :param user_id: + :param client_id: + :param sub_type: + :param token_usage_rules: + :return: + """ + try: + sector_identifier = auth_req.get("sector_identifier_uri") + except AttributeError: + sector_identifier = "" + + grant = Grant(authorization_request=auth_req, + authentication_event=authn_event, + sub=self.sub_func[sub_type]( + user_id, salt=self.salt, + sector_identifier=sector_identifier), + usage_rules=token_usage_rules, + scope=scopes + ) + + self.set([user_id, client_id, grant.id], grant) + + return session_key(user_id, client_id, grant.id) + + def create_session(self, + authn_event: AuthnEvent, + auth_req: AuthorizationRequest, + user_id: str, + client_id: Optional[str] = "", + sub_type: Optional[str] = "public", + token_usage_rules: Optional[dict] = None, + scopes: Optional[list] = None + ) -> str: + """ + Create part of a user session. The parts added are user- and client + information and a grant. + + :param scopes: + :param authn_event: Authentication Event information + :param auth_req: Authorization Request + :param client_id: Client ID + :param user_id: User ID + :param sector_identifier: Identifier for a group of websites under common administrative + control. + :param sub_type: What kind of subject will be assigned + :param token_usage_rules: Rules for how tokens can be used + :return: Session key + """ + + try: + _usi = self.get([user_id]) + except KeyError: + _usi = UserSessionInfo(user_id=user_id) + self.set([user_id], _usi) + + if not client_id: + client_id = auth_req['client_id'] + + client_info = ClientSessionInfo( + client_id=client_id, + ) + + self.set([user_id, client_id], client_info) + + return self.create_grant( + auth_req=auth_req, + authn_event=authn_event, + user_id=user_id, + client_id=client_id, + sub_type=sub_type, + token_usage_rules=token_usage_rules, + scopes=scopes + ) + + def __getitem__(self, session_id: str): + return self.get(unpack_session_key(session_id)) + + def __setitem__(self, session_id: str, value): + return self.set(unpack_session_key(session_id), value) + + def get_client_session_info(self, session_id: str) -> ClientSessionInfo: + """ + Return client connected information for a user session. + + :param session_id: Session identifier + :return: ClientSessionInfo instance + """ + _user_id, _client_id, _grant_id = unpack_session_key(session_id) + csi = self.get([_user_id, _client_id]) + if isinstance(csi, ClientSessionInfo): + return csi + else: + raise ValueError("Wrong type of session info") + + def get_user_session_info(self, session_id: str) -> UserSessionInfo: + """ + Return client connected information for a user session. + + :param session_id: Session identifier + :return: ClientSessionInfo instance + """ + _user_id, _client_id, _grant_id = unpack_session_key(session_id) + usi = self.get([_user_id]) + if isinstance(usi, UserSessionInfo): + return usi + else: + raise ValueError("Wrong type of session info") + + def get_grant(self, session_id: str) -> Grant: + """ + Return client connected information for a user session. + + :param session_id: Session identifier + :return: ClientSessionInfo instance + """ + _user_id, _client_id, _grant_id = unpack_session_key(session_id) + grant = self.get([_user_id, _client_id, _grant_id]) + if isinstance(grant, Grant): + return grant + else: + raise ValueError("Wrong type of item") + + def _revoke_dependent(self, grant: Grant, token: Token): + for t in grant.issued_token: + if t.based_on == token.value: + t.revoked = True + self._revoke_dependent(grant, t) + + def revoke_token(self, session_id: str, token_value: str, recursive: bool = False): + """ + Revoke a specific token that belongs to a specific user session. + + :param session_id: Session identifier + :param token_value: Token value + :param recursive: Revoke all tokens that was minted using this token or + tokens minted by this token. Recursively. + """ + token = self.find_token(session_id, token_value) + if token is None: + raise UnknownToken() + + token.revoked = True + if recursive: + grant = self[session_id] + self._revoke_dependent(grant, token) + + def get_authentication_events(self, session_id: Optional[str] = "", + user_id: Optional[str] = "", + client_id: Optional[str] = "") -> List[AuthnEvent]: + """ + Return the authentication events that exists for a user/client combination. + + :param session_id: A session identifier + :return: None if no authentication event could be found or an AuthnEvent instance. + """ + if session_id: + user_id, client_id, _ = unpack_session_key(session_id) + elif user_id and client_id: + pass + else: + raise AttributeError("Must have session_id or user_id and client_id") + + c_info = self.get([user_id, client_id]) + + _grants = [self.get([user_id, client_id, gid]) for gid in c_info['subordinate']] + return [g.authentication_event for g in _grants] + + def get_authorization_request(self, session_id): + res = self.get_session_info(session_id=session_id, authorization_request=True) + return res["authorization_request"] + + def get_authentication_event(self, session_id): + res = self.get_session_info(session_id=session_id, authentication_event=True) + return res["authentication_event"] + + def revoke_client_session(self, session_id: str): + """ + Revokes a client session + + :param session_id: Session identifier + """ + parts = unpack_session_key(session_id) + if len(parts) == 2: + _user_id, _client_id = parts + elif len(parts) == 3: + _user_id, _client_id, _ = parts + else: + raise ValueError("Invalid session ID") + + _info = self.get([_user_id, _client_id]) + self.set([_user_id, _client_id], _info.revoke()) + + def revoke_grant(self, session_id: str): + """ + Revokes the grant pointed to by a session identifier. + + :param session_id: A session identifier + """ + _path = unpack_session_key(session_id) + _info = self.get(_path) + _info.revoke() + self.set(_path, _info) + + def grants(self, + session_id: Optional[str] = "", + user_id: Optional[str] = "", + client_id: Optional[str] = "") -> List[Grant]: + """ + Find all grant connected to a user session + + :param session_id: A session identifier + :return: A list of grants + """ + if session_id: + user_id, client_id, _ = unpack_session_key(session_id) + elif user_id and client_id: + pass + else: + raise AttributeError("Must have session_id or user_id and client_id") + + _csi = self.get([user_id, client_id]) + return [self.get([user_id, client_id, gid]) for gid in _csi['subordinate']] + + def get_session_info(self, + session_id: str, + user_session_info: bool = False, + client_session_info: bool = False, + grant: bool = False, + authentication_event: bool = False, + authorization_request: bool = False) -> dict: + """ + Returns information connected to a session. + + :param session_id: The identifier of the session + :param user_session_info: Whether user session info should part of the response + :param client_session_info: Whether client session info should part of the response + :param grant: Whether the grant should part of the response + :param authentication_event: Whether the authentication event information should part of + the response + :param authorization_request: Whether the authorization_request should part of the response + :return: A dictionary with session information + """ + _user_id, _client_id, _grant_id = unpack_session_key(session_id) + _grant = None + res = { + "session_id": session_id, + "user_id": _user_id, + "client_id": _client_id, + "grant_id": _grant_id + } + if user_session_info: + res["user_session_info"] = self.get([_user_id]) + if client_session_info: + res["client_session_info"] = self.get([_user_id, _client_id]) + if grant: + res["grant"] = self.get([_user_id, _client_id, _grant_id]) + + if authentication_event: + if grant: + res["authentication_event"] = res["grant"]["authentication_event"] + else: + _grant = self.get([_user_id, _client_id, _grant_id]) + res["authentication_event"] = _grant.authentication_event + + if authorization_request: + if grant: + res["authorization_request"] = res["grant"].authorization_request + elif _grant: + res["authorization_request"] = _grant.authorization_request + else: + _grant = self.get([_user_id, _client_id, _grant_id]) + res["authorization_request"] = _grant.authorization_request + + return res + + def get_session_info_by_token(self, + token_value: str, + user_session_info: bool = False, + client_session_info: bool = False, + grant: bool = False, + authentication_event: bool = False, + authorization_request: bool = False + ) -> dict: + _token_info = self.token_handler.info(token_value) + return self.get_session_info(_token_info["sid"], + user_session_info=user_session_info, + client_session_info=client_session_info, + grant=grant, + authentication_event=authentication_event, + authorization_request=authorization_request) + + def get_session_id_by_token(self, token_value: str) -> str: + _token_info = self.token_handler.info(token_value) + return _token_info["sid"] + + def add_grant(self, user_id: str, client_id: str, **kwargs) -> Grant: + """ + Creates and adds a grant to a user session. + + :param user_id: User identifier + :param client_id: Client identifier + :param kwargs: Keyword arguments to the Grant class initialization + :return: A Grant instance + """ + args = {k: v for k, v in kwargs.items() if k in Grant.attributes} + _grant = Grant(**args) + self.set([user_id, client_id, _grant.id], _grant) + _client_session_info = self.get([user_id, client_id]) + _client_session_info["subordinate"].append(_grant.id) + self.set([user_id, client_id], _client_session_info) + return _grant + + # def find_grant_by_type_and_target( + # self, + # token: str, + # token_type: Grant, + # resource_server: str) -> Optional[Grant]: + # """ + # + # :param token: + # :param token_type: + # :param resource_server: + # :return: + # """ + # session_id = self.get_session_id_by_token(token) + # for grant in self.grants(session_id=session_id): + # if isinstance(grant, token_type): + # if resource_server in grant.resources: + # return grant + # return None + + def remove_session(self, session_id: str): + _user_id, _client_id, _grant_id = unpack_session_key(session_id) + self.delete([_user_id, _client_id, _grant_id]) + + +def create_session_manager(endpoint_context, token_handler_args, db=None, sub_func=None): + _token_handler = handler.factory(endpoint_context, **token_handler_args) + return SessionManager(_token_handler, db=db, sub_func=sub_func) diff --git a/src/oidcendpoint/session/storage.py b/src/oidcendpoint/session/storage.py new file mode 100644 index 0000000..150467f --- /dev/null +++ b/src/oidcendpoint/session/storage.py @@ -0,0 +1,22 @@ +import json + +from .grant import ExchangeGrant +from .grant import Grant +from .info import ClientSessionInfo +from .info import UserSessionInfo + + +class JSON: + def serialize(self, instance): + return instance.to_json() + + def deserialize(self, js): + args = json.loads(js) + if args["type"] == "UserSessionInfo": + return UserSessionInfo().from_json(js) + elif args["type"] == "ClientSessionInfo": + return ClientSessionInfo().from_json(js) + elif args["type"] == "grant": + return Grant().from_json(js) + elif args["type"] == "exchange_grant": + return ExchangeGrant().from_json(js) diff --git a/src/oidcendpoint/session/token.py b/src/oidcendpoint/session/token.py new file mode 100644 index 0000000..fd8b32b --- /dev/null +++ b/src/oidcendpoint/session/token.py @@ -0,0 +1,163 @@ +import json +from typing import Optional +from uuid import uuid1 + +from oidcmsg.time_util import time_sans_frac + + +class MintingNotAllowed(Exception): + pass + + +class Item: + def __init__(self, + usage_rules: Optional[dict] = None, + issued_at: int = 0, + expires_in: int = 0, + expires_at: int = 0, + not_before: int = 0, + revoked: bool = False, + used: int = 0 + ): + self.issued_at = issued_at or time_sans_frac() + self.not_before = not_before + if expires_at == 0 and expires_in != 0: + self.set_expires_at(expires_in) + else: + self.expires_at = expires_at + + self.revoked = revoked + self.used = used + self.usage_rules = usage_rules or {} + + def set_expires_at(self, expires_in): + self.expires_at = time_sans_frac() + expires_in + + def max_usage_reached(self): + if "max_usage" in self.usage_rules: + return self.used >= self.usage_rules['max_usage'] + else: + return False + + def is_active(self, now=0): + if self.max_usage_reached(): + return False + + if self.revoked: + return False + + if now == 0: + now = time_sans_frac() + + if self.not_before: + if now < self.not_before: + return False + + if self.expires_at: + if now > self.expires_at: + return False + + return True + + def revoke(self): + self.revoked = True + + +class Token(Item): + attributes = ["type", "issued_at", "not_before", "expires_at", "revoked", "value", + "usage_rules", "used", "based_on", "id", "scope", "claims", + "resources"] + + def __init__(self, + type: str = '', + value: str = '', + based_on: Optional[str] = None, + usage_rules: Optional[dict] = None, + issued_at: int = 0, + expires_in: int = 0, + expires_at: int = 0, + not_before: int = 0, + revoked: bool = False, + used: int = 0, + id: str = "", + scope: Optional[list] = None, + claims: Optional[dict] = None, + resources: Optional[list] = None, + ): + Item.__init__(self, usage_rules=usage_rules, issued_at=issued_at, expires_in=expires_in, + expires_at=expires_at, not_before=not_before, revoked=revoked, used=used) + + self.type = type + self.value = value + self.based_on = based_on + self.id = id or uuid1().hex + self.set_defaults() + self.scope = scope or [] + self.claims = claims or {} # default is to not release any user information + self.resources = resources or [] + + def set_defaults(self): + pass + + def register_usage(self): + self.used += 1 + + def has_been_used(self): + return self.used != 0 + + def to_json(self): + d = { + "type": self.type, + "issued_at": self.issued_at, + "not_before": self.not_before, + "expires_at": self.expires_at, + "revoked": self.revoked, + "value": self.value, + "usage_rules": self.usage_rules, + "used": self.used, + "based_on": self.based_on, + "id": self.id, + "scope": self.scope, + "claims": self.claims, + "resources": self.resources + } + return json.dumps(d) + + def from_json(self, json_str): + d = json.loads(json_str) + for attr in self.attributes: + if attr in d: + setattr(self, attr, d[attr]) + return self + + def supports_minting(self, token_type): + _supports_minting = self.usage_rules.get("supports_minting") + if _supports_minting is None: + return False + else: + return token_type in _supports_minting + + +class AccessToken(Token): + pass + + +class AuthorizationCode(Token): + def set_defaults(self): + if "supports_minting" not in self.usage_rules: + self.usage_rules['supports_minting'] = ["access_token", "refresh_token"] + + self.usage_rules['max_usage'] = 1 + + +class RefreshToken(Token): + def set_defaults(self): + if "supports_minting" not in self.usage_rules: + self.usage_rules['supports_minting'] = ["access_token", "refresh_token"] + + +SHORT_TYPE_NAME = { + "authorization_code": "A", + "access_token": "T", + "refresh_token": "R" +} diff --git a/src/oidcendpoint/sso_db.py b/src/oidcendpoint/sso_db.py deleted file mode 100644 index ea30364..0000000 --- a/src/oidcendpoint/sso_db.py +++ /dev/null @@ -1,196 +0,0 @@ -import logging - -from oidcmsg.storage.init import storage_factory - -logger = logging.getLogger(__name__) - - -class DictDatabase: - def __init__(self, db_conf=None, db=None): - if db_conf: - self._db = storage_factory(db_conf) - else: - self._db = db - - def set(self, key, sec_key, value): - logger.debug("SSODb set {} - {}: {}".format(key, sec_key, value)) - # Try loading the key, that's a good place to put a debugger to - # import pdb; pdb.set_trace() - _dic = self._db.get(key, {}) - if sec_key in _dic: - _dic[sec_key].append(value) - else: - _dic[sec_key] = [value] - self._db[key] = _dic - - def get(self, key, sec_key): - _dic = self._db.get(key, {}) - logger.debug("SSODb get {} - {}: {}".format(key, sec_key, _dic)) - return _dic.get(sec_key, []) - - def delete(self, key, sec_key): - _dic = self._db.get(key, {}) - try: - del _dic[sec_key] - except KeyError as e: - logger.warn('SSODb.delete _dic[{}] not found'.format(sec_key)) - if _dic == {}: - del self._db[key] - else: - self._db[key] = _dic - - def remove(self, key, sec_key, value): - _dic = self._db.get(key, {}) - _values = _dic[sec_key] - # full clean up - while value in _values: - _values.remove(value) - # if changes have been made -> update them - - if _values: - _dic[sec_key] = _values - self._db[key] = _dic - else: - del _dic[sec_key] - self._db[key] = _dic - - -class SSODb(DictDatabase): - """ - Keeps the connection between an user, one or more sub claims and - possibly several session IDs. - Each user can be represented by more then one sub claim and every sub claim - can appear in more the one session. - So, we have chains like this:: - - session id->subject id->user id - - """ - - def map_sid2uid(self, sid, uid): - """ - Store the connection between a Session ID and a User ID - - :param sid: Session ID - :param uid: User ID - """ - self.set(sid, "uid", uid) - self.set(uid, "sid", sid) - - def map_sid2sub(self, sid, sub): - """ - Store the connection between a Session ID and a subject ID. - - :param sid: Session ID - :param sub: subject ID - """ - self.set(sid, "sub", sub) - self.set(sub, "sid", sid) - - def get_sids_by_uid(self, uid): - """ - Return a list of session IDs that this user is connected to. - - :param uid: The subject ID - :return: list of session IDs - """ - return self.get(uid, "sid") - - def get_sids_by_sub(self, sub): - return self.get(sub, "sid") - - def get_sub_by_sid(self, sid): - _subs = self.get(sid, "sub") - if _subs: - return _subs[0] - else: - return None - - def get_uid_by_sid(self, sid): - """ - Find the User ID that is connected to a Session ID. - - :param sid: A Session ID - :return: A User ID, always just one - """ - _uids = self.get(sid, "uid") - if _uids: - return _uids[0] - else: - return None - - def get_subs_by_uid(self, uid): - """ - Find all subject identifiers that is connected to a User ID. - - :param uid: A User ID - :return: A set of subject identifiers - """ - res = set() - for sid in self.get(uid, "sid"): - res |= set(self.get(sid, "sub")) - return res - - def remove_sid2sub(self, sid, sub): - """ - Remove the connection between a session ID and a Subject - - :param sid: Session ID - :param sub: Subject identifier -´ """ - self.remove(sub, "sid", sid) - self.remove(sid, "sub", sub) - - def remove_sid2uid(self, sid, uid): - """ - Remove the connection between a session ID and a Subject - - :param sid: Session ID - :param uid: User identifier -´ """ - self.remove(uid, "sid", sid) - self.remove(sid, "uid", uid) - - def remove_session_id(self, sid): - """ - Remove all references to a specific Session ID - - :param sid: A Session ID - """ - for uid in self.get(sid, "uid"): - self.remove(uid, "sid", sid) - if self._db.get(uid) == {}: - del self._db[uid] - self.delete(sid, "uid") - - for sub in self.get(sid, "sub"): - self.remove(sub, "sid", sid) - if self._db.get(sub) == {}: - del self._db[sub] - self.delete(sid, "sub") - - def remove_uid(self, uid): - """ - Remove all references to a specific User ID - - :param uid: A User ID - """ - for sid in self.get(uid, "sid"): - self.remove(sid, "uid", uid) - self.delete(uid, "sid") - - def remove_sub(self, sub): - """ - Remove all references to a specific Subject ID - - :param sub: A Subject ID - """ - for _sid in self.get(sub, "sid"): - self.remove(_sid, "sub", sub) - self.delete(sub, "sid") - - def close(self): - self._db.close() - - def clear(self): - self._db.clear() diff --git a/src/oidcendpoint/token_handler.py b/src/oidcendpoint/token/__init__.py similarity index 50% rename from src/oidcendpoint/token_handler.py rename to src/oidcendpoint/token/__init__.py index 692d7d5..f6ef69a 100755 --- a/src/oidcendpoint/token_handler.py +++ b/src/oidcendpoint/token/__init__.py @@ -1,18 +1,16 @@ import base64 import hashlib import logging -import warnings +from typing import Optional from cryptography.fernet import Fernet -from cryptography.fernet import InvalidToken -from cryptojwt.exception import Invalid -from cryptojwt.key_jar import init_key_jar from cryptojwt.utils import as_bytes from cryptojwt.utils import as_unicode from oidcmsg.time_util import time_sans_frac from oidcendpoint import rndstr -from oidcendpoint.util import importer +from oidcendpoint.token.exception import UnknownToken +from oidcendpoint.token.exception import WrongTokenType from oidcendpoint.util import lv_pack from oidcendpoint.util import lv_unpack @@ -21,26 +19,6 @@ logger = logging.getLogger(__name__) -class ExpiredToken(Exception): - pass - - -class WrongTokenType(Exception): - pass - - -class AccessCodeUsed(Exception): - pass - - -class UnknownToken(Exception): - pass - - -class NotAllowed(Exception): - pass - - def is_expired(exp, when=0): if exp < 0: return False @@ -74,13 +52,16 @@ class Token(object): def __init__(self, typ, lifetime=300, **kwargs): self.type = typ self.lifetime = lifetime - self.args = kwargs + self.kwargs = kwargs - def __call__(self, sid): + def __call__(self, + session_id: Optional[str] = '', + ttype: Optional[str] = '', + **payload) -> str: """ Return a token. - :param sid: Session id + :param payload: Information to place in the token if possible. :return: """ raise NotImplementedError() @@ -121,13 +102,14 @@ def __init__(self, password, typ="", token_type="Bearer", **kwargs): self.crypt = Crypt(password) self.token_type = token_type - def __call__(self, sid="", ttype="", **kwargs): + def __call__(self, + session_id: Optional[str] = '', + ttype: Optional[str] = '', + **payload) -> str: """ Return a token. - :param ttype: Type of token - :param prev: Previous token, if there is one to go from - :param sid: Session id + :param payload: Token information :return: """ if not ttype and self.type: @@ -146,7 +128,7 @@ def __call__(self, sid="", ttype="", **kwargs): rnd = rndstr(32) # Ultimate length multiple of 16 return base64.b64encode( - self.crypt.encrypt(lv_pack(rnd, ttype, sid, exp).encode()) + self.crypt.encrypt(lv_pack(rnd, ttype, session_id, exp).encode()) ).decode("utf-8") def key(self, user="", areq=None): @@ -169,7 +151,7 @@ def split_token(self, token): # order: rnd, type, sid return lv_unpack(plain) - def info(self, token): + def info(self, token: str) -> dict: """ Return token information. @@ -183,131 +165,10 @@ def info(self, token): _res["handler"] = self return _res - def is_expired(self, token, when=0): + def is_expired(self, token: str, when: int = 0): _exp = self.info(token)["exp"] if _exp == "-1": return False else: exp = int(_exp) return is_expired(exp, when) - - -class TokenHandler(object): - def __init__( - self, access_token_handler=None, code_handler=None, refresh_token_handler=None - ): - - self.handler = {"code": code_handler, "access_token": access_token_handler} - - self.handler_order = ["code", "access_token"] - - if refresh_token_handler: - self.handler["refresh_token"] = refresh_token_handler - self.handler_order.append("refresh_token") - - def __getitem__(self, typ): - return self.handler[typ] - - def __contains__(self, item): - return item in self.handler - - def info(self, item, order=None): - _handler, item_info = self.get_handler(item, order) - - if _handler is None: - logger.info("Unknown token format") - raise UnknownToken(item) - else: - return item_info - - def sid(self, token, order=None): - return self.info(token, order)["sid"] - - def type(self, token, order=None): - return self.info(token, order)["type"] - - def get_handler(self, token, order=None): - if order is None: - order = self.handler_order - - for typ in order: - try: - res = self.handler[typ].info(token) - except (KeyError, WrongTokenType, InvalidToken, UnknownToken, Invalid): - pass - else: - return self.handler[typ], res - - return None, None - - def keys(self): - return self.handler.keys() - - -def init_token_handler(ec, spec, typ): - try: - _cls = spec["class"] - except KeyError: - cls = DefaultToken - else: - cls = importer(_cls) - - _kwargs = spec.get("kwargs") - if _kwargs is None: - if cls != DefaultToken: - warnings.warn( - "Token initialisation arguments should be grouped under 'kwargs'.", - DeprecationWarning, - stacklevel=2, - ) - _kwargs = spec - - return cls(typ=typ, ec=ec, **_kwargs) - - -def _add_passwd(keyjar, conf, kid): - if keyjar: - _keys = keyjar.get_encrypt_key(key_type="oct", kid=kid) - if _keys: - pw = as_unicode(_keys[0].k) - if "kwargs" in conf: - conf["kwargs"]["password"] = pw - else: - conf["password"] = pw - - -def factory(ec, code=None, token=None, refresh=None, jwks_def=None, **kwargs): - """ - Create a token handler - - :param code: - :param token: - :param refresh: - :param jwks_def: - :return: TokenHandler instance - """ - - TTYPE = {"code": "A", "token": "T", "refresh": "R"} - - if jwks_def: - kj = init_key_jar(**jwks_def) - else: - kj = None - - args = {} - - if code: - _add_passwd(kj, code, "code") - args["code_handler"] = init_token_handler(ec, code, TTYPE["code"]) - - if token: - _add_passwd(kj, token, "token") - args["access_token_handler"] = init_token_handler(ec, token, TTYPE["token"]) - - if refresh: - _add_passwd(kj, refresh, "refresh") - args["refresh_token_handler"] = init_token_handler( - ec, refresh, TTYPE["refresh"] - ) - - return TokenHandler(**args) diff --git a/src/oidcendpoint/token/exception.py b/src/oidcendpoint/token/exception.py new file mode 100644 index 0000000..965b38c --- /dev/null +++ b/src/oidcendpoint/token/exception.py @@ -0,0 +1,18 @@ +class ExpiredToken(Exception): + pass + + +class WrongTokenType(Exception): + pass + + +class AccessCodeUsed(Exception): + pass + + +class UnknownToken(Exception): + pass + + +class NotAllowed(Exception): + pass diff --git a/src/oidcendpoint/token/handler.py b/src/oidcendpoint/token/handler.py new file mode 100755 index 0000000..ba895b1 --- /dev/null +++ b/src/oidcendpoint/token/handler.py @@ -0,0 +1,137 @@ +import logging +import warnings + +from cryptography.fernet import InvalidToken +from cryptojwt.exception import Invalid +from cryptojwt.key_jar import init_key_jar +from cryptojwt.utils import as_unicode + +from oidcendpoint.token import DefaultToken +from oidcendpoint.token import UnknownToken +from oidcendpoint.token import WrongTokenType +from oidcendpoint.util import importer + +__author__ = "Roland Hedberg" + +logger = logging.getLogger(__name__) + + +class TokenHandler(object): + def __init__( + self, access_token_handler=None, code_handler=None, refresh_token_handler=None + ): + + self.handler = {"code": code_handler, "access_token": access_token_handler} + + self.handler_order = ["code", "access_token"] + + if refresh_token_handler: + self.handler["refresh_token"] = refresh_token_handler + self.handler_order.append("refresh_token") + + def __getitem__(self, typ): + return self.handler[typ] + + def __contains__(self, item): + return item in self.handler + + def info(self, item, order=None): + _handler, item_info = self.get_handler(item, order) + + if _handler is None: + logger.info("Unknown token format") + raise UnknownToken(item) + else: + return item_info + + def sid(self, token, order=None): + return self.info(token, order)["sid"] + + def type(self, token, order=None): + return self.info(token, order)["type"] + + def get_handler(self, token, order=None): + if order is None: + order = self.handler_order + + for typ in order: + try: + res = self.handler[typ].info(token) + except (KeyError, WrongTokenType, InvalidToken, UnknownToken, Invalid): + pass + else: + return self.handler[typ], res + + return None, None + + def keys(self): + return self.handler.keys() + + +def init_token_handler(ec, spec, typ): + try: + _cls = spec["class"] + except KeyError: + cls = DefaultToken + else: + cls = importer(_cls) + + _kwargs = spec.get("kwargs") + if _kwargs is None: + if cls != DefaultToken: + warnings.warn( + "Token initialisation arguments should be grouped under 'kwargs'.", + DeprecationWarning, + stacklevel=2, + ) + _kwargs = spec + + return cls(typ=typ, endpoint_context=ec, **_kwargs) + + +def _add_passwd(keyjar, conf, kid): + if keyjar: + _keys = keyjar.get_encrypt_key(key_type="oct", kid=kid) + if _keys: + pw = as_unicode(_keys[0].k) + if "kwargs" in conf: + conf["kwargs"]["password"] = pw + else: + conf["password"] = pw + + +def factory(ec, code=None, token=None, refresh=None, jwks_def=None, **kwargs): + """ + Create a token handler + + :param code: + :param token: + :param refresh: + :param jwks_def: + :return: TokenHandler instance + """ + + TTYPE = {"code": "A", "token": "T", "refresh": "R"} + + if jwks_def: + kj = init_key_jar(**jwks_def) + else: + kj = None + + args = {} + + if code: + _add_passwd(kj, code, "code") + args["code_handler"] = init_token_handler(ec, code, TTYPE["code"]) + + if token: + _add_passwd(kj, token, "token") + args["access_token_handler"] = init_token_handler(ec, token, TTYPE["token"]) + + if refresh is not None: + _add_passwd(kj, refresh, "refresh") + args["refresh_token_handler"] = init_token_handler( + ec, refresh, TTYPE["refresh"] + ) + + return TokenHandler(**args) diff --git a/src/oidcendpoint/token/jwt_token.py b/src/oidcendpoint/token/jwt_token.py new file mode 100644 index 0000000..aa868ae --- /dev/null +++ b/src/oidcendpoint/token/jwt_token.py @@ -0,0 +1,115 @@ +from typing import Optional + +from cryptojwt import JWT +from cryptojwt.jws.exception import JWSException + +from oidcendpoint.exception import ToOld + +from . import Token +from . import is_expired +from .exception import UnknownToken + +TYPE_MAP = { + "A": "code", + "T": "access_token", + "R": "refresh_token" +} + + +class JWTToken(Token): + def __init__( + self, + typ, + keyjar=None, + issuer: str = None, + aud: Optional[list] = None, + alg: str = "ES256", + lifetime: int = 300, + endpoint_context=None, + token_type: str = "Bearer", + **kwargs + ): + Token.__init__(self, typ, **kwargs) + self.token_type = token_type + self.lifetime = lifetime + + self.kwargs = kwargs + self.key_jar = keyjar or endpoint_context.keyjar + self.issuer = issuer or endpoint_context.issuer + self.cdb = endpoint_context.cdb + self.endpoint_context = endpoint_context + + self.def_aud = aud or [] + self.alg = alg + + def __call__(self, + session_id: Optional[str] = '', + ttype: Optional[str] = '', + **payload) -> str: + """ + Return a token. + + :param session_id: Session id + :param subject: + :param grant: + :param kwargs: KeyWord arguments + :return: Signed JSON Web Token + """ + if not ttype and self.type: + ttype = self.type + else: + ttype = "A" + + payload.update({"sid": session_id, "ttype": ttype}) + + # payload.update(kwargs) + signer = JWT( + key_jar=self.key_jar, + iss=self.issuer, + lifetime=self.lifetime, + sign_alg=self.alg, + ) + + return signer.pack(payload) + + def info(self, token): + """ + Return type of Token (A=Access code, T=Token, R=Refresh token) and + the session id. + + :param token: A token + :return: tuple of token type and session id + """ + verifier = JWT(key_jar=self.key_jar, allowed_sign_algs=[self.alg]) + try: + _payload = verifier.unpack(token) + except JWSException: + raise UnknownToken() + + if is_expired(_payload["exp"]): + raise ToOld("Token has expired") + # All the token metadata + _res = { + "sid": _payload["sid"], + "type": _payload["ttype"], + "exp": _payload["exp"], + "handler": self, + } + return _res + + def is_expired(self, token, when=0): + """ + Evaluate whether the token has expired or not + + :param token: The token + :param when: The time against which to check the expiration + 0 means now. + :return: True/False + """ + verifier = JWT(key_jar=self.key_jar, allowed_sign_algs=[self.alg]) + _payload = verifier.unpack(token) + return is_expired(_payload["exp"], when) + + def gather_args(self, sid, sdb, udb): + _sinfo = sdb[sid] + return {} diff --git a/src/oidcendpoint/user_authn/user.py b/src/oidcendpoint/user_authn/user.py index e197ef6..2eeac47 100755 --- a/src/oidcendpoint/user_authn/user.py +++ b/src/oidcendpoint/user_authn/user.py @@ -11,6 +11,7 @@ from cryptojwt.jwt import JWT from oidcendpoint import sanitize +from oidcendpoint.cookie import cookie_value from oidcendpoint.exception import FailedAuthentication from oidcendpoint.exception import ImproperlyConfigured from oidcendpoint.exception import InvalidCookieSign @@ -293,7 +294,28 @@ def authenticated_as(self, cookie=None, authorization="", **kwargs): if self.fail: raise self.fail() - return {"uid": self.user}, time.time() + res = {"uid": self.user} + + if cookie: + try: + val = self.cookie_dealer.get_cookie_value( + cookie, cookie_name=self.endpoint_context.cookie_name["session"] + ) + except (InvalidCookieSign, AssertionError, AttributeError) as err: + logger.warning(err) + val = None + + if val is None: + return None, 0 + else: + b64val, _ts, typ = val + info = cookie_value(b64val) + if isinstance(info, dict): + res.update(info) + else: + res["value"] = b64val + + return res, time.time() def factory(cls, **kwargs): diff --git a/src/oidcendpoint/userinfo.py b/src/oidcendpoint/userinfo.py deleted file mode 100755 index 590c2ad..0000000 --- a/src/oidcendpoint/userinfo.py +++ /dev/null @@ -1,210 +0,0 @@ -import logging - -from oidcmsg.oidc import Claims - -from oidcendpoint import sanitize -from oidcendpoint.exception import ImproperlyConfigured -from oidcendpoint.scopes import convert_scopes2claims - -logger = logging.getLogger(__name__) - - -def id_token_claims(session, provider_info): - """ - Pick the IdToken claims from the request - - :param session: Session information - :return: The IdToken claims - """ - itc = update_claims(session, "id_token", provider_info=provider_info, old_claims={}) - return itc - - -def update_claims(session, about, provider_info, old_claims=None): - """ - - :param session: - :param about: userinfo or id_token - :param old_claims: - :return: claims or None - """ - - if old_claims is None: - old_claims = {} - - req = None - try: - req = session["authn_req"] - except KeyError: - pass - - if req: - try: - _claims = req["claims"][about] - except KeyError: - pass - else: - if _claims: - # Deal only with supported claims - _unsup = [ - c - for c in _claims.keys() - if c not in provider_info["claims_supported"] - ] - for _c in _unsup: - del _claims[_c] - - # update with old claims, do not overwrite - for key, val in old_claims.items(): - if key not in _claims: - _claims[key] = val - return _claims - - return old_claims - - -def claims_match(value, claimspec): - """ - Implements matching according to section 5.5.1 of - http://openid.net/specs/openid-connect-core-1_0.html - The lack of value is not checked here. - Also the text doesn't prohibit having both 'value' and 'values'. - - :param value: single value - :param claimspec: None or dictionary with 'essential', 'value' or 'values' - as key - :return: Boolean - """ - if claimspec is None: # match anything - return True - - matched = False - for key, val in claimspec.items(): - if key == "value": - if value == val: - matched = True - elif key == "values": - if value in val: - matched = True - elif key == "essential": - # Whether it's essential or not doesn't change anything here - continue - - if matched: - break - - if matched is False: - if list(claimspec.keys()) == ["essential"]: - return True - - return matched - - -def by_schema(cls, **kwa): - """ - Will return only those claims that are listed in the Class definition. - - :param cls: A subclass of :py:class:´oidcmsg.message.Message` - :param kwa: Keyword arguments - :return: A dictionary with claims (keys) that meets the filter criteria - """ - return dict([(key, val) for key, val in kwa.items() if key in cls.c_param]) - - -def collect_user_info( - endpoint_context, session, userinfo_claims=None, scope_to_claims=None -): - """ - Collect information about a user. - This can happen in two cases, either when constructing an IdToken or - when returning user info through the UserInfo endpoint - - :param session: Session information - :param userinfo_claims: user info claims - :return: User info - """ - authn_req = session["authn_req"] - if scope_to_claims is None: - scope_to_claims = endpoint_context.scope2claims - - supported_scopes = endpoint_context.scopes_handler.filter_scopes( - authn_req["client_id"], endpoint_context, authn_req["scope"] - ) - if userinfo_claims is None: - _allowed_claims = endpoint_context.claims_handler.allowed_claims( - authn_req["client_id"], endpoint_context - ) - uic = convert_scopes2claims( - supported_scopes, _allowed_claims, map=scope_to_claims - ) - - # Get only keys allowed by user and update the dict if such info - # is stored in session - perm_set = session.get("permission") - if perm_set: - uic = {key: uic[key] for key in uic if key in perm_set} - - uic = update_claims( - session, - "userinfo", - provider_info=endpoint_context.provider_info, - old_claims=uic, - ) - - if uic: - userinfo_claims = Claims(**uic) - logger.debug("userinfo_claim: %s" % sanitize(userinfo_claims.to_dict())) - else: - userinfo_claims = None - logger.warning(("Client {} doesn't have any claims " - "belonging to one or more scopes.").format(authn_req["client_id"])) - raise ImproperlyConfigured("Some additional scopes doesn't have any claims.") - - logger.debug("Session info: %s" % sanitize(session)) - - authn_event = session["authn_event"] - if authn_event: - uid = authn_event["uid"] - else: - uid = session["uid"] - - info = endpoint_context.userinfo(uid, authn_req["client_id"], userinfo_claims) - - if "sub" in userinfo_claims: - if not claims_match(session["sub"], userinfo_claims["sub"]): - raise FailedAuthentication("Unmatched sub claim") - - info["sub"] = session["sub"] - try: - logger.debug("user_info_response: {}".format(info)) - except UnicodeEncodeError: - logger.debug("user_info_response: {}".format(info.encode("utf-8"))) - - return info - - -def userinfo_in_id_token_claims(endpoint_context, session, def_itc=None): - """ - Collect user info claims that are to be placed in the id token. - - :param endpoint_context: Endpoint context - :param session: Session information - :param def_itc: Default ID Token claims - :return: User information or None - """ - if def_itc: - itc = def_itc - else: - itc = {} - - itc.update(id_token_claims(session, provider_info=endpoint_context.provider_info)) - - if not itc: - return None - - _claims = by_schema(endpoint_context.id_token_schema, **itc) - - if _claims: - return collect_user_info(endpoint_context, session, _claims) - else: - return None diff --git a/src/oidcendpoint/util.py b/src/oidcendpoint/util.py index c4dfa50..a6eafd2 100755 --- a/src/oidcendpoint/util.py +++ b/src/oidcendpoint/util.py @@ -2,9 +2,9 @@ import json import logging from urllib.parse import parse_qs +from urllib.parse import urlparse from urllib.parse import urlsplit from urllib.parse import urlunsplit -from urllib.parse import urlparse from oidcendpoint.exception import OidcEndpointError @@ -175,7 +175,8 @@ def split_uri(uri): def allow_refresh_token(endpoint_context): # Are there a refresh_token handler - refresh_token_handler = endpoint_context.sdb.handler.handler.get("refresh_token") + refresh_token_handler = endpoint_context.session_manager.token_handler.handler[ + "refresh_token"] # Is refresh_token grant type supported _token_supported = False diff --git a/tests/test_00_endpoint_context.py b/tests/test_00_endpoint_context.py index 66e5fb7..da5af76 100755 --- a/tests/test_00_endpoint_context.py +++ b/tests/test_00_endpoint_context.py @@ -12,7 +12,7 @@ from oidcendpoint.oidc.provider_config import ProviderConfiguration from oidcendpoint.oidc.registration import Registration from oidcendpoint.oidc.session import Session -from oidcendpoint.oidc.token import AccessToken +from oidcendpoint.oidc.token import Token from oidcendpoint.oidc.userinfo import UserInfo from oidcendpoint.user_authn.authn_context import INTERNETPROTOCOLPASSWORD @@ -49,7 +49,7 @@ "class": Authorization, "kwargs": {}, }, - "token_endpoint": {"path": "token", "class": AccessToken, "kwargs": {}}, + "token_endpoint": {"path": "token", "class": Token, "kwargs": {}}, "userinfo_endpoint": { "path": "userinfo", "class": UserInfo, diff --git a/tests/test_01_grant.py b/tests/test_01_grant.py new file mode 100644 index 0000000..3e621dc --- /dev/null +++ b/tests/test_01_grant.py @@ -0,0 +1,361 @@ +from cryptojwt.key_jar import build_keyjar +import pytest + +from oidcendpoint.endpoint_context import EndpointContext +from oidcendpoint.session import session_key +from oidcendpoint.session.grant import Grant +from oidcendpoint.session.grant import TOKEN_MAP +from oidcendpoint.session.grant import find_token +from oidcendpoint.session.grant import get_usage_rules +from oidcendpoint.session.token import AuthorizationCode +from oidcendpoint.session.token import Token +from oidcendpoint.token import DefaultToken +from oidcendpoint.user_authn.authn_context import INTERNETPROTOCOLPASSWORD + +KEYDEFS = [ + {"type": "RSA", "key": "", "use": ["sig"]}, + {"type": "EC", "crv": "P-256", "use": ["sig"]}, +] + +KEYJAR = build_keyjar(KEYDEFS) + +conf = { + "issuer": "https://example.com/", + "template_dir": "template", + "keys": {"uri_path": "static/jwks.json", "key_defs": KEYDEFS, "read_only": True}, + "endpoint": { + "authorization_endpoint": { + "path": "authorization", + "class": "oidcendpoint.oidc.authorization.Authorization", + "kwargs": {}, + }, + "token_endpoint": { + "path": "token", + "class": "oidcendpoint.oidc.token.Token", + "kwargs": {} + }, + }, + "authentication": { + "anon": { + "acr": INTERNETPROTOCOLPASSWORD, + "class": "oidcendpoint.user_authn.user.NoAuthn", + "kwargs": {"user": "diana"}, + } + }, +} + +ENDPOINT_CONTEXT = EndpointContext(conf) + + +def test_access_code(): + token = AuthorizationCode('authorization_code', value="ABCD") + assert token.issued_at + assert token.type == "authorization_code" + assert token.value == "ABCD" + + token.register_usage() + # max_usage == 1 + assert token.max_usage_reached() is True + + +def test_access_token(): + code = AuthorizationCode('authorization_code', value="ABCD") + token = Token('access_token', value="1234", based_on=code.id, usage_rules={"max_usage": 2}) + assert token.issued_at + assert token.type == "access_token" + assert token.value == "1234" + + token.register_usage() + # max_usage - undefined + assert token.max_usage_reached() is False + + token.register_usage() + assert token.max_usage_reached() is True + + t = find_token([code, token], token.based_on) + assert t.value == "ABCD" + + token.revoked = True + assert token.revoked is True + + +TOKEN_HANDLER = { + "authorization_code": DefaultToken("authorization_code", typ="A"), + "access_token": DefaultToken("access_token", typ="T"), + "refresh_token": DefaultToken("refresh_token", typ="R") +} + + +def test_mint_token(): + grant = Grant() + + code = grant.mint_token("user_id;;client_id;;grant_id", + endpoint_context=ENDPOINT_CONTEXT, + token_type="authorization_code", + token_handler=TOKEN_HANDLER["authorization_code"]) + + access_token = grant.mint_token( + "user_id;;client_id;;grant_id", + endpoint_context=ENDPOINT_CONTEXT, + token_type="access_token", + token_handler=TOKEN_HANDLER["access_token"], + based_on=code, + scope=["openid", "foo", "bar"] + ) + + assert access_token.scope == ["openid", "foo", "bar"] + + +SESSION_ID = session_key('user_id', 'client_id', 'grant.id') + + +def test_grant(): + grant = Grant() + code = grant.mint_token(SESSION_ID, endpoint_context=ENDPOINT_CONTEXT, + token_type="authorization_code", + token_handler=TOKEN_HANDLER["authorization_code"]) + + access_token = grant.mint_token( + SESSION_ID, + endpoint_context=ENDPOINT_CONTEXT, + token_type="access_token", + token_handler=TOKEN_HANDLER["access_token"], + based_on=code) + + refresh_token = grant.mint_token( + SESSION_ID, + endpoint_context=ENDPOINT_CONTEXT, + token_type="refresh_token", + token_handler=TOKEN_HANDLER["refresh_token"], + based_on=code) + + grant.revoke_token() + assert code.revoked is True + assert access_token.revoked is True + assert refresh_token.revoked is True + + +def test_get_token(): + grant = Grant() + code = grant.mint_token(SESSION_ID, endpoint_context=ENDPOINT_CONTEXT, + token_type="authorization_code", + token_handler=TOKEN_HANDLER["authorization_code"]) + + access_token = grant.mint_token( + SESSION_ID, + endpoint_context=ENDPOINT_CONTEXT, + token_type="access_token", + token_handler=TOKEN_HANDLER["access_token"], + based_on=code, + scope=["openid", "foo", "bar"] + ) + + _code = grant.get_token(code.value) + assert _code.id == code.id + + _token = grant.get_token(access_token.value) + assert _token.id == access_token.id + assert set(_token.scope) == {"openid", "foo", "bar"} + + +def test_grant_revoked_based_on(): + grant = Grant() + code = grant.mint_token(SESSION_ID, + endpoint_context=ENDPOINT_CONTEXT, + token_type="authorization_code", + token_handler=TOKEN_HANDLER["authorization_code"]) + + access_token = grant.mint_token( + SESSION_ID, + endpoint_context=ENDPOINT_CONTEXT, + token_type="access_token", + token_handler=TOKEN_HANDLER["access_token"], + based_on=code) + + refresh_token = grant.mint_token( + SESSION_ID, + endpoint_context=ENDPOINT_CONTEXT, + token_type="refresh_token", + token_handler=TOKEN_HANDLER["refresh_token"], + based_on=code) + + code.register_usage() + if code.max_usage_reached(): + grant.revoke_token(based_on=code.value) + + assert code.is_active() is False + assert access_token.is_active() is False + assert refresh_token.is_active() is False + + +def test_revoke(): + grant = Grant() + code = grant.mint_token(SESSION_ID, + endpoint_context=ENDPOINT_CONTEXT, + token_type="authorization_code", + token_handler=TOKEN_HANDLER["authorization_code"]) + + access_token = grant.mint_token( + SESSION_ID, + endpoint_context=ENDPOINT_CONTEXT, + token_type="access_token", + token_handler=TOKEN_HANDLER["access_token"], + based_on=code) + + grant.revoke_token(based_on=code.value) + + assert code.is_active() is True + assert access_token.is_active() is False + + access_token_2 = grant.mint_token( + SESSION_ID, + endpoint_context=ENDPOINT_CONTEXT, + token_type="access_token", + token_handler=TOKEN_HANDLER["access_token"], + based_on=code) + + grant.revoke_token(value=code.value, recursive=True) + + assert code.is_active() is False + assert access_token_2.is_active() is False + + +def test_json_conversion(): + grant = Grant() + code = grant.mint_token(SESSION_ID, + endpoint_context=ENDPOINT_CONTEXT, + token_type="authorization_code", + token_handler=TOKEN_HANDLER["authorization_code"]) + + grant.mint_token( + SESSION_ID, + endpoint_context=ENDPOINT_CONTEXT, + token_type="access_token", + token_handler=TOKEN_HANDLER["access_token"], + based_on=code) + + _jstr = grant.to_json() + + _grant_copy = Grant().from_json(_jstr) + + assert len(_grant_copy.issued_token) == 2 + + tt = {"code": 0, "access_token": 0} + for token in _grant_copy.issued_token: + if token.type == "authorization_code": + tt["code"] += 1 + if token.type == "access_token": + tt["access_token"] += 1 + + assert tt == {"code": 1, "access_token": 1} + + +def test_json_no_token_map(): + grant = Grant(token_map={}) + with pytest.raises(ValueError): + grant.mint_token(SESSION_ID, + endpoint_context=ENDPOINT_CONTEXT, + token_type="authorization_code", + token_handler=TOKEN_HANDLER["authorization_code"]) + + +class MyToken(Token): + pass + + +TOKEN_HANDLER["my_token"] = DefaultToken("my_token", typ="M") + + +def test_json_custom_token_map(): + token_map = TOKEN_MAP.copy() + token_map["my_token"] = MyToken + + grant = Grant(token_map=token_map) + code = grant.mint_token(SESSION_ID, + endpoint_context=ENDPOINT_CONTEXT, + token_type="authorization_code", + token_handler=TOKEN_HANDLER["authorization_code"]) + + grant.mint_token( + SESSION_ID, + endpoint_context=ENDPOINT_CONTEXT, + token_type="access_token", + token_handler=TOKEN_HANDLER["access_token"], + based_on=code) + + grant.mint_token( + SESSION_ID, + endpoint_context=ENDPOINT_CONTEXT, + token_type="my_token", + token_handler=TOKEN_HANDLER["my_token"]) + + _jstr = grant.to_json() + + _grant_copy = Grant().from_json(_jstr) + + assert len(_grant_copy.issued_token) == 3 + + tt = {k: 0 for k, v in grant.token_map.items()} + + for token in _grant_copy.issued_token: + for _type in tt.keys(): + if token.type == _type: + tt[_type] += 1 + + assert tt == { + 'access_token': 1, 'authorization_code': 1, + 'my_token': 1, 'refresh_token': 0 + } + + +def test_get_spec(): + grant = Grant(scope=["openid", "email", "address"], + claims={"userinfo": {"given_name": None, "email": None}}, + resources=["https://api.example.com"] + ) + + code = grant.mint_token(SESSION_ID, + endpoint_context=ENDPOINT_CONTEXT, + token_type="authorization_code", + token_handler=TOKEN_HANDLER["authorization_code"]) + + access_token = grant.mint_token( + SESSION_ID, + endpoint_context=ENDPOINT_CONTEXT, + token_type="access_token", + token_handler=TOKEN_HANDLER["access_token"], + based_on=code, + scope=["openid", "email", "eduperson"], + claims={ + "userinfo": { + "given_name": None, + "eduperson_affiliation": None + } + } + ) + + spec = grant.get_spec(access_token) + assert set(spec.keys()) == {"scope", "claims", "resources"} + assert spec["scope"] == ["openid", "email", "eduperson"] + assert spec["claims"] == { + "userinfo": { + "given_name": None, + "eduperson_affiliation": None + } + } + assert spec["resources"] == ["https://api.example.com"] + + +def test_get_usage_rules(): + grant = Grant(scope=["openid", "email", "address"], + claims={"userinfo": {"given_name": None, "email": None}}, + resources=["https://api.example.com"] + ) + + # Default usage rules + ENDPOINT_CONTEXT.cdb["client_id"] = {} + rules = get_usage_rules("access_token", ENDPOINT_CONTEXT, grant, "client_id") + assert rules == {'supports_minting': [], 'expires_in': 3600} + + # client specific usage rules + ENDPOINT_CONTEXT.cdb["client_id"] = {"access_token": {"expires_in": 600}} \ No newline at end of file diff --git a/tests/test_01_session_info.py b/tests/test_01_session_info.py new file mode 100644 index 0000000..d048e87 --- /dev/null +++ b/tests/test_01_session_info.py @@ -0,0 +1,64 @@ +import pytest +from oidcmsg.oauth2 import AuthorizationRequest + +from oidcendpoint.session.info import ClientSessionInfo +from oidcendpoint.session.info import SessionInfo +from oidcendpoint.session.info import UserSessionInfo + +AUTH_REQ = AuthorizationRequest( + client_id="client_1", + redirect_uri="https://example.com/cb", + scope=["openid"], + state="STATE", + response_type=["code"], +) + + +def test_session_info_subordinate(): + si = SessionInfo() + si.add_subordinate("subordinate_1") + si.add_subordinate("subordinate_2") + assert set(si["subordinate"]) == {"subordinate_1", "subordinate_2"} + assert set(si["subordinate"]) == {"subordinate_1", "subordinate_2"} + assert si.is_revoked() is False + + si.remove_subordinate("subordinate_1") + assert si["subordinate"] == ["subordinate_2"] + + si.revoke() + assert si.is_revoked() is True + + +def test_session_info_no_subordinate(): + si = SessionInfo() + assert si["subordinate"] == [] + + +def test_user_session_info_to_json(): + usi = UserSessionInfo(uid="uid") + + _jstr = usi.to_json() + + usi2 = UserSessionInfo().from_json(_jstr) + + assert usi2["uid"] == "uid" + + +def test_user_session_info_to_json_with_sub(): + usi = UserSessionInfo(uid="uid") + usi.add_subordinate("client_id") + + _jstr = usi.to_json() + + usi2 = UserSessionInfo().from_json(_jstr) + + assert usi2["subordinate"] == ["client_id"] + + +def test_client_session_info(): + csi = ClientSessionInfo(client_id="clientID") + + _jstr = csi.to_json() + + _csi2 = ClientSessionInfo().from_json(_jstr) + assert _csi2["client_id"] == "clientID" diff --git a/tests/test_01_session_token.py b/tests/test_01_session_token.py new file mode 100644 index 0000000..3675d1d --- /dev/null +++ b/tests/test_01_session_token.py @@ -0,0 +1,90 @@ +from oidcmsg.time_util import time_sans_frac + +from oidcendpoint.session.token import AccessToken +from oidcendpoint.session.token import AuthorizationCode + + +def test_authorization_code_default(): + code = AuthorizationCode(value="ABCD") + assert code.usage_rules["max_usage"] == 1 + assert code.usage_rules["supports_minting"] == ["access_token", + "refresh_token"] + + +def test_authorization_code_usage(): + code = AuthorizationCode(value="ABCD", + usage_rules={ + "supports_minting": ["access_token"], + "max_usage": 1 + }) + + assert code.usage_rules["max_usage"] == 1 + assert code.usage_rules["supports_minting"] == ["access_token"] + + +def test_authorization_code_extras(): + code = AuthorizationCode(value="ABCD", + scope=["openid", "foo", "bar"], + claims={"userinfo": {"given_name": None}}, + resources=["https://api.example.com"]) + + assert code.scope == ["openid", "foo", "bar"] + assert code.claims == {"userinfo": {"given_name": None}} + assert code.resources == ["https://api.example.com"] + + +def test_to_from_json(): + code = AuthorizationCode(value="ABCD", + scope=["openid", "foo", "bar"], + claims={"userinfo": {"given_name": None}}, + resources=["https://api.example.com"]) + + _json_str = code.to_json() + + _new_code = AuthorizationCode().from_json(_json_str) + + for attr in AuthorizationCode.attributes: + assert getattr(code, attr) == getattr(_new_code, attr) + + +def test_supports_minting(): + code = AuthorizationCode(value="ABCD") + assert code.supports_minting('access_token') + assert code.supports_minting('refresh_token') + assert code.supports_minting("authorization_code") is False + + +def test_usage(): + token = AccessToken(usage_rules={"max_usage": 2}) + + token.register_usage() + assert token.has_been_used() + assert token.used == 1 + assert token.max_usage_reached() is False + + token.register_usage() + assert token.max_usage_reached() + + token.register_usage() + assert token.used == 3 + assert token.max_usage_reached() + + +def test_is_active_usage(): + token = AccessToken(usage_rules={"max_usage": 2}) + + token.register_usage() + token.register_usage() + assert token.is_active() is False + + +def test_is_active_revoke(): + token = AccessToken(usage_rules={"max_usage": 2}) + token.revoke() + assert token.is_active() is False + + +def test_is_active_expired(): + token = AccessToken(usage_rules={"max_usage": 2}) + token.expires_at = time_sans_frac() - 60 + assert token.is_active() is False diff --git a/tests/test_01_util.py b/tests/test_01_util.py index 75ad2a7..286322f 100644 --- a/tests/test_01_util.py +++ b/tests/test_01_util.py @@ -1,7 +1,7 @@ from oidcendpoint.oidc.authorization import Authorization from oidcendpoint.oidc.provider_config import ProviderConfiguration from oidcendpoint.oidc.registration import Registration -from oidcendpoint.oidc.token import AccessToken +from oidcendpoint.oidc.token import Token from oidcendpoint.oidc.userinfo import UserInfo from oidcendpoint.user_authn.authn_context import INTERNETPROTOCOLPASSWORD @@ -40,7 +40,7 @@ "class": Authorization, "kwargs": {}, }, - "token_endpoint": {"path": "token", "class": AccessToken, "kwargs": {}}, + "token_endpoint": {"path": "token", "class": Token, "kwargs": {}}, "userinfo_endpoint": { "path": "userinfo", "class": UserInfo, diff --git a/tests/test_02_client_authn.py b/tests/test_02_client_authn.py index 4a59bd3..d1e05d7 100755 --- a/tests/test_02_client_authn.py +++ b/tests/test_02_client_authn.py @@ -24,7 +24,7 @@ from oidcendpoint.exception import NotForMe from oidcendpoint.oidc.authorization import Authorization from oidcendpoint.oidc.registration import Registration -from oidcendpoint.oidc.token import AccessToken +from oidcendpoint.oidc.token import Token from oidcendpoint.oidc.userinfo import UserInfo KEYDEFS = [ @@ -37,14 +37,12 @@ CONF = { "issuer": "https://example.com/", "password": "mycket hemligt", - "token_expires_in": 600, "grant_expires_in": 300, - "refresh_token_expires_in": 86400, "verify_ssl": False, "endpoint": { "token": { "path": "token", - "class": AccessToken, + "class": Token, "kwargs": { "client_authn_method": [ "private_key_jwt", diff --git a/tests/test_02_client_authn.py.old b/tests/test_02_client_authn.py.old deleted file mode 100755 index 35dd5b0..0000000 --- a/tests/test_02_client_authn.py.old +++ /dev/null @@ -1,396 +0,0 @@ -import base64 - -import pytest -from cryptojwt.jws.exception import NoSuitableSigningKeys -from cryptojwt.jwt import JWT -from cryptojwt.key_jar import KeyJar -from cryptojwt.key_jar import build_keyjar -from cryptojwt.utils import as_bytes -from cryptojwt.utils import as_unicode -from oidcendpoint import JWT_BEARER -from oidcendpoint.client_authn import AuthnFailure -from oidcendpoint.client_authn import BearerBody -from oidcendpoint.client_authn import BearerHeader -from oidcendpoint.client_authn import ClientSecretBasic -from oidcendpoint.client_authn import ClientSecretJWT -from oidcendpoint.client_authn import ClientSecretPost -from oidcendpoint.client_authn import JWSAuthnMethod -from oidcendpoint.client_authn import PrivateKeyJWT -from oidcendpoint.client_authn import basic_authn -from oidcendpoint.client_authn import verify_client -from oidcendpoint.endpoint_context import EndpointContext -from oidcendpoint.exception import MultipleUsage -from oidcendpoint.exception import NotForMe -from oidcendpoint.oidc.authorization import Authorization -from oidcendpoint.oidc.token import AccessToken -from oidcendpoint.oidc.userinfo import UserInfo - -KEYDEFS = [ - {"type": "RSA", "key": "", "use": ["sig"]}, - {"type": "EC", "crv": "P-256", "use": ["sig"]}, -] - -KEYJAR = build_keyjar(KEYDEFS) - -conf = { - "issuer": "https://example.com/", - "password": "mycket hemligt", - "token_expires_in": 600, - "grant_expires_in": 300, - "refresh_token_expires_in": 86400, - "verify_ssl": False, - "endpoint": { - "token": {"path": "token", "class": AccessToken, "kwargs": {}}, - "authorization": {"path": "auth", "class": Authorization, "kwargs": {}}, - "userinfo": {"path": "user", "class": UserInfo, "kwargs": {}} - }, - "template_dir": "template", - "jwks": { - "private_path": "own/jwks.json", - "key_defs": KEYDEFS, - "uri_path": "static/jwks.json", - }, -} -client_id = "client_id" -client_secret = "a_longer_client_secret" -# Need to add the client_secret as a symmetric key bound to the client_id -KEYJAR.add_symmetric(client_id, client_secret, ["sig"]) - -endpoint_context = EndpointContext(conf, keyjar=KEYJAR) -endpoint_context.cdb[client_id] = {"client_secret": client_secret} - - -def get_client_id_from_token(endpoint_context, token, request=None): - if "client_id" in request: - if request["client_id"] == endpoint_context.registration_access_token[token]: - return request["client_id"] - return "" - - -def test_client_secret_basic(): - _token = "{}:{}".format(client_id, client_secret) - token = as_unicode(base64.b64encode(as_bytes(_token))) - - authz_token = "Basic {}".format(token) - - authn_info = ClientSecretBasic(endpoint_context).verify({}, authz_token) - - assert authn_info["client_id"] == client_id - - -def test_client_secret_post(): - request = {"client_id": client_id, "client_secret": client_secret} - - authn_info = ClientSecretPost(endpoint_context).verify(request) - - assert authn_info["client_id"] == client_id - - -def test_client_secret_jwt(): - client_keyjar = KeyJar() - client_keyjar[conf["issuer"]] = KEYJAR.issuer_keys[""] - # The only own key the client has a this point - client_keyjar.add_symmetric("", client_secret, ["sig"]) - - _jwt = JWT(client_keyjar, iss=client_id, sign_alg="HS256") - _jwt.with_jti = True - _assertion = _jwt.pack({"aud": [conf["issuer"]]}) - - request = {"client_assertion": _assertion, "client_assertion_type": JWT_BEARER} - - authn_info = ClientSecretJWT(endpoint_context).verify(request) - - assert authn_info["client_id"] == client_id - assert "jwt" in authn_info - - -def test_private_key_jwt(): - # Own dynamic keys - client_keyjar = build_keyjar(KEYDEFS) - # The servers keys - client_keyjar[conf["issuer"]] = KEYJAR.issuer_keys[""] - - _jwks = client_keyjar.export_jwks() - endpoint_context.keyjar.import_jwks(_jwks, client_id) - - _jwt = JWT(client_keyjar, iss=client_id, sign_alg="RS256") - _jwt.with_jti = True - _assertion = _jwt.pack({"aud": [conf["issuer"]]}) - - request = {"client_assertion": _assertion, "client_assertion_type": JWT_BEARER} - - authn_info = PrivateKeyJWT(endpoint_context).verify(request) - - assert authn_info["client_id"] == client_id - assert "jwt" in authn_info - - -def test_private_key_jwt_reusage_other_endpoint(): - # Own dynamic keys - client_keyjar = build_keyjar(KEYDEFS) - # The servers keys - client_keyjar[conf["issuer"]] = KEYJAR.issuer_keys[""] - - _jwks = client_keyjar.export_jwks() - endpoint_context.keyjar.import_jwks(_jwks, client_id) - - _jwt = JWT(client_keyjar, iss=client_id, sign_alg="RS256") - _jwt.with_jti = True - _assertion = _jwt.pack({"aud": [endpoint_context.endpoint["token"].full_path]}) - - request = {"client_assertion": _assertion, "client_assertion_type": JWT_BEARER} - - # This should be OK - PrivateKeyJWT(endpoint_context).verify(request, endpoint="token") - - # This should NOT be OK - with pytest.raises(NotForMe): - PrivateKeyJWT(endpoint_context).verify(request, endpoint="authorization") - - # This should NOT be OK - with pytest.raises(MultipleUsage): - PrivateKeyJWT(endpoint_context).verify(request, endpoint="token") - - -def test_private_key_jwt_auth_endpoint(): - # Own dynamic keys - client_keyjar = build_keyjar(KEYDEFS) - # The servers keys - client_keyjar[conf["issuer"]] = KEYJAR.issuer_keys[""] - - _jwks = client_keyjar.export_jwks() - endpoint_context.keyjar.import_jwks(_jwks, client_id) - - _jwt = JWT(client_keyjar, iss=client_id, sign_alg="RS256") - _jwt.with_jti = True - _assertion = _jwt.pack({"aud": [endpoint_context.endpoint["authorization"].full_path]}) - - request = {"client_assertion": _assertion, "client_assertion_type": JWT_BEARER} - - authn_info = PrivateKeyJWT(endpoint_context).verify(request, endpoint="authorization") - - assert authn_info["client_id"] == client_id - assert "jwt" in authn_info - - -def test_wrong_type(): - with pytest.raises(AuthnFailure): - ClientSecretBasic(endpoint_context).verify({}, "Foppa toffel") - - -def test_csb_wrong_secret(): - _token = "{}:{}".format(client_id, "pillow") - token = as_unicode(base64.b64encode(as_bytes(_token))) - - authz_token = "Basic {}".format(token) - - with pytest.raises(AuthnFailure): - ClientSecretBasic(endpoint_context).verify({}, authz_token) - - -def test_client_secret_post_wrong_secret(): - request = {"client_id": client_id, "client_secret": "pillow"} - - with pytest.raises(AuthnFailure): - ClientSecretPost(endpoint_context).verify(request) - - -def test_bearerheader(): - request = {} - authorization_info = "Bearer 1234567890" - assert BearerHeader(endpoint_context).verify(request, authorization_info) == { - "token": "1234567890" - } - - -def test_bearerheader_wrong_type(): - request = {} - authorization_info = "Thrower 1234567890" - with pytest.raises(AuthnFailure): - BearerHeader(endpoint_context).verify(request, authorization_info) - - -def test_bearer_body(): - request = {"access_token": "1234567890"} - assert BearerBody(endpoint_context).verify(request) == {"token": "1234567890"} - - -def test_bearer_body_no_token(): - request = {} - with pytest.raises(AuthnFailure): - BearerBody(endpoint_context).verify(request) - - -def test_jws_authn_method_wrong_key(): - client_keyjar = KeyJar() - client_keyjar[conf["issuer"]] = KEYJAR.issuer_keys[""] - # Fake symmetric key - client_keyjar.add_symmetric("", "client_secret:client_secret", ["sig"]) - - _jwt = JWT(client_keyjar, iss=client_id, sign_alg="HS256") - _assertion = _jwt.pack({"aud": [conf["issuer"]]}) - - request = {"client_assertion": _assertion, "client_assertion_type": JWT_BEARER} - - with pytest.raises(NoSuitableSigningKeys): - JWSAuthnMethod(endpoint_context).verify(request) - - -def test_jws_authn_method_aud_iss(): - client_keyjar = KeyJar() - client_keyjar[conf["issuer"]] = KEYJAR.issuer_keys[""] - # The only own key the client has a this point - client_keyjar.add_symmetric("", client_secret, ["sig"]) - - _jwt = JWT(client_keyjar, iss=client_id, sign_alg="HS256") - # Audience is OP issuer ID - aud = conf["issuer"] - _assertion = _jwt.pack({"aud": [aud]}) - - request = {"client_assertion": _assertion, "client_assertion_type": JWT_BEARER} - - assert JWSAuthnMethod(endpoint_context).verify(request) - - -def test_jws_authn_method_aud_token_endpoint(): - client_keyjar = KeyJar() - client_keyjar[conf["issuer"]] = KEYJAR.issuer_keys[""] - # The only own key the client has a this point - client_keyjar.add_symmetric("", client_secret, ["sig"]) - - _jwt = JWT(client_keyjar, iss=client_id, sign_alg="HS256") - - # audience is OP token endpoint - that's OK - aud = "{}token".format(conf["issuer"]) - _assertion = _jwt.pack({"aud": [aud]}) - - request = {"client_assertion": _assertion, "client_assertion_type": JWT_BEARER} - - assert JWSAuthnMethod(endpoint_context).verify(request, endpoint="token") - - -def test_jws_authn_method_aud_not_me(): - client_keyjar = KeyJar() - client_keyjar[conf["issuer"]] = KEYJAR.issuer_keys[""] - # The only own key the client has a this point - client_keyjar.add_symmetric("", client_secret, ["sig"]) - - _jwt = JWT(client_keyjar, iss=client_id, sign_alg="HS256") - - # Other audiences not OK - aud = "https://example.org" - - _assertion = _jwt.pack({"aud": [aud]}) - - request = {"client_assertion": _assertion, "client_assertion_type": JWT_BEARER} - - with pytest.raises(NotForMe): - JWSAuthnMethod(endpoint_context).verify(request) - - -def test_basic_auth(): - _token = "{}:{}".format(client_id, client_secret) - token = as_unicode(base64.b64encode(as_bytes(_token))) - - res = basic_authn("Basic {}".format(token)) - assert res - - -def test_basic_auth_wrong_label(): - _token = "{}:{}".format(client_id, client_secret) - token = as_unicode(base64.b64encode(as_bytes(_token))) - - with pytest.raises(AuthnFailure): - basic_authn("Expanded {}".format(token)) - - -def test_basic_auth_wrong_token(): - _token = "{}:{}:foo".format(client_id, client_secret) - token = as_unicode(base64.b64encode(as_bytes(_token))) - with pytest.raises(ValueError): - basic_authn("Basic {}".format(token)) - - _token = "{}:{}".format(client_id, client_secret) - with pytest.raises(ValueError): - basic_authn("Basic {}".format(_token)) - - _token = "{}{}".format(client_id, client_secret) - token = as_unicode(base64.b64encode(as_bytes(_token))) - with pytest.raises(ValueError): - basic_authn("Basic {}".format(token)) - - -def test_verify_client_jws_authn_method(): - client_keyjar = KeyJar() - client_keyjar[conf["issuer"]] = KEYJAR.issuer_keys[""] - # The only own key the client has a this point - client_keyjar.add_symmetric("", client_secret, ["sig"]) - - _jwt = JWT(client_keyjar, iss=client_id, sign_alg="HS256") - # Audience is OP issuer ID - aud = conf["issuer"] - _assertion = _jwt.pack({"aud": [aud]}) - - request = {"client_assertion": _assertion, "client_assertion_type": JWT_BEARER} - - res = verify_client(endpoint_context, request) - assert res["method"] == "private_key_jwt" - assert res["client_id"] == "client_id" - - -def test_verify_client_bearer_body(): - request = {"access_token": "1234567890", "client_id": client_id} - endpoint_context.registration_access_token["1234567890"] = client_id - res = verify_client( - endpoint_context, request, get_client_id_from_token=get_client_id_from_token - ) - assert set(res.keys()) == {"token", "method", "client_id"} - assert res["method"] == "bearer_body" - - -def test_verify_client_client_secret_post(): - request = {"client_id": client_id, "client_secret": client_secret} - res = verify_client(endpoint_context, request) - assert set(res.keys()) == {"method", "client_id"} - assert res["method"] == "client_secret_post" - - -def test_verify_client_client_secret_basic(): - _token = "{}:{}".format(client_id, client_secret) - token = as_unicode(base64.b64encode(as_bytes(_token))) - authz_token = "Basic {}".format(token) - res = verify_client(endpoint_context, {}, authz_token) - assert set(res.keys()) == {"method", "client_id"} - assert res["method"] == "client_secret_basic" - - -def test_verify_client_bearer_header(): - endpoint_context.registration_access_token["1234567890"] = client_id - token = "Bearer 1234567890" - request = {"client_id": client_id} - res = verify_client( - endpoint_context, - request, - authorization_info=token, - get_client_id_from_token=get_client_id_from_token, - ) - - res = verify_client(endpoint_context, request, token, get_client_id_from_token) - assert set(res.keys()) == {"token", "method", "client_id"} - assert res["method"] == "bearer_header" - - -def test_jws_authn_method_aud_userinfo_endpoint(): - client_keyjar = KeyJar() - client_keyjar[conf["issuer"]] = KEYJAR.issuer_keys[""] - # The only own key the client has a this point - client_keyjar.add_symmetric("", client_secret, ["sig"]) - - _jwt = JWT(client_keyjar, iss=client_id, sign_alg="HS256") - - # audience is the OP - not specifically the user info endpoint - _assertion = _jwt.pack({"aud": [conf["issuer"]]}) - - request = {"client_assertion": _assertion, "client_assertion_type": JWT_BEARER} - - assert JWSAuthnMethod(endpoint_context).verify(request, endpoint="userinfo") diff --git a/tests/test_02_sess_mngm_db.py b/tests/test_02_sess_mngm_db.py new file mode 100644 index 0000000..e68607c --- /dev/null +++ b/tests/test_02_sess_mngm_db.py @@ -0,0 +1,264 @@ +# Database is organized in 3 layers. User-session-grant. +import pytest +from oidcmsg.oauth2 import AuthorizationRequest + +from oidcendpoint.authn_event import create_authn_event +from oidcendpoint.session import session_key +from oidcendpoint.session.database import Database +from oidcendpoint.session.database import NoSuchClientSession +from oidcendpoint.session.database import NoSuchGrant +from oidcendpoint.session.grant import Grant +from oidcendpoint.session.info import ClientSessionInfo +from oidcendpoint.session.info import UserSessionInfo +from oidcendpoint.session.manager import public_id +from oidcendpoint.session.token import Token + +AUTHZ_REQ = AuthorizationRequest( + client_id="client_1", + redirect_uri="https://example.com/cb", + scope=["openid"], + state="STATE", + response_type="code", +) + + +class TestDB: + @pytest.fixture(autouse=True) + def setup_environment(self): + self.db = Database() + + def test_user_info(self): + with pytest.raises(KeyError): + self.db.get(['diana']) + + user_info = UserSessionInfo(user_id="diana", foo="bar") + self.db.set(['diana'], user_info) + stored_user_info = self.db.get(['diana']) + assert stored_user_info["foo"] == "bar" + + def test_client_info(self): + user_info = UserSessionInfo(user_id="diana", foo="bar") + self.db.set(['diana'], user_info) + client_info = ClientSessionInfo(client_id="client_1") + self.db.set(['diana', "client_1"], client_info) + + stored_user_info = self.db.get(['diana']) + assert stored_user_info['subordinate'] == ['client_1'] + stored_client_info = self.db.get(['diana', "client_1"]) + assert stored_client_info['client_id'] == "client_1" + + def test_client_info_change(self): + user_info = UserSessionInfo(user_id="diana", foo="bar") + self.db.set(['diana'], user_info) + client_info = ClientSessionInfo(client_id="client_1", extra="snow") + self.db.set(['diana', "client_1"], client_info) + + user_info = self.db.get(['diana']) + assert user_info['subordinate'] == ['client_1'] + client_info = self.db.get(['diana', "client_1"]) + assert client_info['client_id'] == "client_1" + assert client_info['extra'] == "snow" + + client_info = ClientSessionInfo(client_id="client_1", extra="ice") + self.db.set(['diana', "client_1"], client_info) + + stored_client_info = self.db.get(['diana', "client_1"]) + assert stored_client_info['extra'] == "ice" + + def test_client_info_add1(self): + user_info = UserSessionInfo(user_id="diana") + self.db.set(['diana'], user_info) + client_info = ClientSessionInfo(client_id="client_1") + self.db.set(['diana', "client_1"], client_info) + + # The reference is there but not the value + del self.db._db[session_key('diana', "client_1")] + + client_info = ClientSessionInfo(client_id="client_1", extra="ice") + self.db.set(['diana', "client_1"], client_info) + + stored_client_info = self.db.get(['diana', "client_1"]) + assert stored_client_info['extra'] == "ice" + + def test_client_info_add2(self): + user_info = UserSessionInfo(foo="bar") + self.db.set(['diana'], user_info) + client_info = ClientSessionInfo(sid="abcdef") + self.db.set(['diana', "client_1"], client_info) + + # The reference is there but not the value + del self.db._db[session_key('diana', "client_1")] + + authn_event = create_authn_event(uid="diana", + expires_in=10, + authn_info="authn_class_ref") + + grant = Grant(authentication_event=authn_event, + authorization_request=AUTHZ_REQ) + + access_code = Token('access_code', value='1234567890') + grant.issued_token.append(access_code) + + self.db.set(['diana', "client_1", "G1"], grant) + stored_client_info = self.db.get(['diana', "client_1"]) + assert set(stored_client_info.keys()) == {"subordinate", "revoked", "type"} + + stored_grant_info = self.db.get(['diana', 'client_1', 'G1']) + assert stored_grant_info.issued_at + + def test_jump_ahead(self): + grant = Grant() + access_code = Token('access_code', value='1234567890') + grant.issued_token.append(access_code) + + self.db.set(['diana', "client_1", "G1"], grant) + + user_info = self.db.get(['diana']) + assert user_info['subordinate'] == ['client_1'] + client_info = self.db.get(['diana', "client_1"]) + assert client_info['subordinate'] == ["G1"] + grant_info = self.db.get(['diana', 'client_1', 'G1']) + assert grant_info.issued_at + assert len(grant_info.issued_token) == 1 + token = grant_info.issued_token[0] + assert token.value == '1234567890' + assert token.type == "access_code" + + def test_replace_grant_info_not_there(self): + grant = Grant() + access_code = Token('access_code', value='1234567890') + grant.issued_token.append(access_code) + + self.db.set(['diana', "client_1", "G1"], grant) + + # The reference is there but not the value + del self.db._db[session_key('diana', "client_1", "G1")] + + grant = Grant() + access_code = Token('access_code', value='aaaaaaaaa') + grant.issued_token.append(access_code) + + self.db.set(['diana', "client_1", "G1"], grant) + + stored_grant_info = self.db.get(['diana', 'client_1', 'G1']) + assert stored_grant_info.issued_at + assert len(stored_grant_info.issued_token) == 1 + token = stored_grant_info.issued_token[0] + assert token.value == 'aaaaaaaaa' + + def test_replace_user_info(self): + # store user info + self.db.set(['diana'], + UserSessionInfo(authentication_event=create_authn_event('diana'))) + + self.db.set(['diana'], + UserSessionInfo( + authentication_event=create_authn_event('diana', authn_time=111111111))) + + stored_user_info = self.db.get(['diana']) + assert stored_user_info["authentication_event"]["authn_time"] == 111111111 + + def test_add_client_info(self): + client_info = ClientSessionInfo(sid="abcdef") + self.db.set(['diana', "client_1"], client_info) + + stored_client_info = self.db.get(['diana', "client_1"]) + assert stored_client_info['sid'] == "abcdef" + + def test_half_way(self): + # store user info + self.db.set(['diana'], + UserSessionInfo(authentication_event=create_authn_event('diana'))) + + grant = Grant() + access_code = Token('access_code', value='1234567890') + grant.issued_token.append(access_code) + + self.db.set(['diana', "client_1", "G1"], grant) + + stored_grant_info = self.db.get(['diana', 'client_1', 'G1']) + assert stored_grant_info.issued_at + assert len(stored_grant_info.issued_token) == 1 + + def test_step_wise(self): + salt = "natriumklorid" + # store user info + self.db.set(['diana'], + UserSessionInfo(authentication_event=create_authn_event('diana'))) + # Client specific information + self.db.set(['diana', 'client_1'], ClientSessionInfo(sub=public_id( + 'diana', salt))) + # Grant + grant = Grant() + access_code = Token('access_code', value='1234567890') + grant.issued_token.append(access_code) + + self.db.set(['diana', 'client_1', 'G1'], grant) + + def test_removed(self): + grant = Grant() + access_code = Token('access_code', value='1234567890') + grant.issued_token.append(access_code) + + self.db.set(['diana', "client_1", "G1"], grant) + self.db.delete(['diana', 'client_1']) + with pytest.raises(KeyError): + self.db.get(['diana', "client_1", "G1"]) + + def test_client_info_removed(self): + user_info = UserSessionInfo(foo="bar") + self.db.set(['diana'], user_info) + client_info = ClientSessionInfo(sid="abcdef") + self.db.set(['diana', "client_1"], client_info) + + # The reference is there but not the value + del self.db._db[session_key('diana', "client_1")] + + with pytest.raises(NoSuchClientSession): + self.db.get(['diana', "client_1"]) + + def test_grant_info(self): + user_info = UserSessionInfo(foo="bar") + self.db.set(['diana'], user_info) + client_info = ClientSessionInfo(sid="abcdef") + self.db.set(['diana', "client_1"], client_info) + + with pytest.raises(ValueError): + self.db.get(['diana', "client_1", "G1"]) + + grant = Grant() + access_code = Token('access_code', value='1234567890') + grant.issued_token.append(access_code) + + self.db.set(['diana', "client_1", "G1"], grant) + + # removed value + del self.db._db[session_key('diana', "client_1", "G1")] + + with pytest.raises(NoSuchGrant): + self.db.get(['diana', "client_1", "G1"]) + + def test_delete_non_existent_info(self): + # Does nothing + self.db.delete(["diana"]) + + user_info = UserSessionInfo(foo="bar") + user_info.add_subordinate('client') + self.db.set(['diana'], user_info) + + # again silently does nothing + self.db.delete(["diana", "client"]) + + client_info = ClientSessionInfo(sid="abcdef") + client_info.add_subordinate('G1') + self.db.set(['diana', "client_1"], client_info) + + # and finally + self.db.delete(["diana", "client_1", "G1"]) + + def test_delete_user_info(self): + user_info = UserSessionInfo(foo="bar") + self.db.set(['diana'], user_info) + self.db.delete(["diana"]) + with pytest.raises(KeyError): + self.db.get(['diana']) diff --git a/tests/test_03_id_token.py b/tests/test_03_id_token.py index 50555b6..cd8ba95 100644 --- a/tests/test_03_id_token.py +++ b/tests/test_03_id_token.py @@ -1,21 +1,22 @@ import json import os -import time -import pytest from cryptojwt.jws import jws from cryptojwt.jwt import JWT from cryptojwt.key_jar import KeyJar from oidcmsg.oidc import AuthorizationRequest from oidcmsg.oidc import RegistrationResponse +from oidcmsg.time_util import time_sans_frac +import pytest +from oidcendpoint.authn_event import create_authn_event from oidcendpoint.client_authn import verify_client from oidcendpoint.endpoint_context import EndpointContext from oidcendpoint.id_token import IDToken from oidcendpoint.id_token import get_sign_and_encrypt_algorithms from oidcendpoint.oidc import userinfo from oidcendpoint.oidc.authorization import Authorization -from oidcendpoint.oidc.token import AccessToken +from oidcendpoint.oidc.token import Token from oidcendpoint.user_authn.authn_context import INTERNETPROTOCOLPASSWORD from oidcendpoint.user_info import UserInfo @@ -34,15 +35,34 @@ def full_path(local_file): USERS = json.loads(open(full_path("users.json")).read()) USERINFO = UserInfo(USERS) -AREQN = AuthorizationRequest( +AREQ = AuthorizationRequest( response_type="code", - client_id="client1", + client_id="client_1", redirect_uri="http://example.com/authz", scope=["openid"], state="state000", nonce="nonce", ) +AREQS = AuthorizationRequest( + response_type="code", + client_id="client_1", + redirect_uri="http://example.com/authz", + scope=["openid", "address", "email"], + state="state000", + nonce="nonce", +) + +AREQRC = AuthorizationRequest( + response_type="code", + client_id="client_1", + redirect_uri="http://example.com/authz", + scope=["openid", "address", "email"], + state="state000", + nonce="nonce", + claims={"id_token": {"nickname": None}} +) + conf = { "issuer": "https://example.com/", "password": "mycket hemligt", @@ -58,7 +78,7 @@ def full_path(local_file): "class": Authorization, "kwargs": {}, }, - "token_endpoint": {"path": "{}/token", "class": AccessToken, "kwargs": {}}, + "token_endpoint": {"path": "{}/token", "class": Token, "kwargs": {}}, "userinfo_endpoint": { "path": "{}/userinfo", "class": userinfo.UserInfo, @@ -72,12 +92,14 @@ def full_path(local_file): "kwargs": {"user": "diana"}, } }, - "userinfo": {"class": "oidcendpoint.user_info.UserInfo", "kwargs": {"db": USERS},}, + "userinfo": {"class": "oidcendpoint.user_info.UserInfo", "kwargs": {"db": USERS}, }, "client_authn": verify_client, "template_dir": "template", "id_token": {"class": IDToken, "kwargs": {"foo": "bar"}}, } +USER_ID = "diana" + class TestEndpoint(object): @pytest.fixture(autouse=True) @@ -93,186 +115,114 @@ def create_idtoken(self): self.endpoint_context.keyjar.add_symmetric( "client_1", "hemligtochintekort", ["sig", "enc"] ) + self.session_manager = self.endpoint_context.session_manager + self.user_id = USER_ID - def test_id_token_payload_0(self): - session_info = {"authn_req": AREQN, "sub": "1234567890"} - info = self.endpoint_context.idtoken.payload(session_info) - assert info["payload"] == {"sub": "1234567890", "nonce": "nonce"} - assert info["lifetime"] == 300 + def _create_session(self, auth_req, sub_type="public", sector_identifier=''): + if sector_identifier: + authz_req = auth_req.copy() + authz_req["sector_identifier_uri"] = sector_identifier + else: + authz_req = auth_req + + client_id = authz_req['client_id'] + ae = create_authn_event(self.user_id) + return self.session_manager.create_session(ae, authz_req, self.user_id, + client_id=client_id, + sub_type=sub_type) + + def _mint_code(self, grant, session_id): + # Constructing an authorization code is now done + return grant.mint_token( + session_id=session_id, + endpoint_context=self.endpoint_context, + token_type="authorization_code", + token_handler=self.session_manager.token_handler["code"], + expires_at=time_sans_frac() + 300 # 5 minutes from now + ) - def test_id_token_payload_1(self): - session_info = {"authn_req": AREQN, "sub": "1234567890"} + def _mint_access_token(self, grant, session_id, token_ref): + return grant.mint_token( + session_id=session_id, + endpoint_context=self.endpoint_context, + token_type="access_token", + token_handler=self.session_manager.token_handler["access_token"], + expires_at=time_sans_frac() + 900, # 15 minutes from now + based_on=token_ref # Means the token (tok) was used to mint this token + ) - info = self.endpoint_context.idtoken.payload(session_info) - assert info["payload"] == {"nonce": "nonce", "sub": "1234567890"} - assert info["lifetime"] == 300 + def test_id_token_payload_0(self): + session_id = self._create_session(AREQ) + payload = self.endpoint_context.idtoken.payload(session_id) + assert set(payload.keys()) == {"sub", "nonce", "auth_time"} def test_id_token_payload_with_code(self): - session_info = {"authn_req": AREQN, "sub": "1234567890"} + session_id = self._create_session(AREQ) + grant = self.session_manager[session_id] - info = self.endpoint_context.idtoken.payload( - session_info, code="ABCDEFGHIJKLMNOP" + code = self._mint_code(grant, session_id) + payload = self.endpoint_context.idtoken.payload( + session_id, AREQ["client_id"], code=code.value ) - assert info["payload"] == { - "nonce": "nonce", - "c_hash": "5-i4nCch0pDMX1VCVJHs1g", - "sub": "1234567890", - } - assert info["lifetime"] == 300 - - @pytest.mark.parametrize("add_c_hash", [True, False]) - def test_id_token_payload_with_code_in_session(self, add_c_hash): - self.endpoint_context.idtoken.add_c_hash = add_c_hash - code = "ABCDEFGHIJKLMNOP" - session_info = {"authn_req": AREQN, "sub": "1234567890", "code": code} - - info = self.endpoint_context.idtoken.payload(session_info) - if add_c_hash: - assert info["payload"] == { - "nonce": "nonce", - "sub": "1234567890", - "c_hash": "5-i4nCch0pDMX1VCVJHs1g", - } - else: - assert info["payload"] == { - "nonce": "nonce", - "sub": "1234567890", - } - assert info["lifetime"] == 300 + assert set(payload.keys()) == {"nonce", "c_hash", "sub", "auth_time"} def test_id_token_payload_with_access_token(self): - session_info = {"authn_req": AREQN, "sub": "1234567890"} + session_id = self._create_session(AREQ) + grant = self.session_manager[session_id] - info = self.endpoint_context.idtoken.payload( - session_info, access_token="012ABCDEFGHIJKLMNOP" - ) - assert info["payload"] == { - "nonce": "nonce", - "at_hash": "bKkyhbn1CC8IMdavzOV-Qg", - "sub": "1234567890", - } - assert info["lifetime"] == 300 - - @pytest.mark.parametrize("add_at_hash", [True, False]) - def test_id_token_payload_with_access_token_in_session(self, add_at_hash): - self.endpoint_context.idtoken.add_at_hash = add_at_hash - access_token = "012ABCDEFGHIJKLMNOP" - session_info = { - "authn_req": AREQN, - "sub": "1234567890", - "access_token": access_token - } + code = self._mint_code(grant, session_id) + access_token = self._mint_access_token(grant, session_id, code) - info = self.endpoint_context.idtoken.payload(session_info) - if add_at_hash: - assert info["payload"] == { - "nonce": "nonce", - "sub": "1234567890", - "at_hash": "bKkyhbn1CC8IMdavzOV-Qg", - } - else: - assert info["payload"] == { - "nonce": "nonce", - "sub": "1234567890", - } - assert info["lifetime"] == 300 + payload = self.endpoint_context.idtoken.payload( + session_id, AREQ["client_id"], access_token=access_token.value + ) + assert set(payload.keys()) == {"nonce", "at_hash", "sub", "auth_time"} def test_id_token_payload_with_code_and_access_token(self): - session_info = {"authn_req": AREQN, "sub": "1234567890"} + session_id = self._create_session(AREQ) + grant = self.session_manager[session_id] - info = self.endpoint_context.idtoken.payload( - session_info, access_token="012ABCDEFGHIJKLMNOP", code="ABCDEFGHIJKLMNOP" - ) - assert info["payload"] == { - "nonce": "nonce", - "at_hash": "bKkyhbn1CC8IMdavzOV-Qg", - "c_hash": "5-i4nCch0pDMX1VCVJHs1g", - "sub": "1234567890", - } - assert info["lifetime"] == 300 - - @pytest.mark.parametrize("add_hashes", [True, False]) - def test_id_token_payload_with_code_and_access_token_in_session( - self, add_hashes - ): - self.endpoint_context.idtoken.add_c_hash = add_hashes - self.endpoint_context.idtoken.add_at_hash = add_hashes - code = "ABCDEFGHIJKLMNOP" - access_token = "012ABCDEFGHIJKLMNOP" - session_info = { - "authn_req": AREQN, - "sub": "1234567890", - "access_token": access_token, - "code": code, - } + code = self._mint_code(grant, session_id) + access_token = self._mint_access_token(grant, session_id, code) - info = self.endpoint_context.idtoken.payload(session_info) - if add_hashes: - assert info["payload"] == { - "nonce": "nonce", - "sub": "1234567890", - "at_hash": "bKkyhbn1CC8IMdavzOV-Qg", - "c_hash": "5-i4nCch0pDMX1VCVJHs1g", - } - else: - assert info["payload"] == { - "nonce": "nonce", - "sub": "1234567890", - } - assert info["lifetime"] == 300 + payload = self.endpoint_context.idtoken.payload( + session_id, AREQ["client_id"], access_token=access_token.value, code=code.value + ) + assert set(payload.keys()) == {"nonce", "c_hash", "at_hash", "sub", "auth_time"} def test_id_token_payload_with_userinfo(self): - session_info = {"authn_req": AREQN, "sub": "1234567890"} + session_id = self._create_session(AREQ) + grant = self.session_manager[session_id] + grant.claims = {"id_token": {"given_name": None}} - info = self.endpoint_context.idtoken.payload( - session_info, user_info={"given_name": "Diana"} - ) - assert info["payload"] == { - "nonce": "nonce", - "given_name": "Diana", - "sub": "1234567890", - } - assert info["lifetime"] == 300 + payload = self.endpoint_context.idtoken.payload(session_id=session_id) + assert set(payload.keys()) == {"nonce", "given_name", "sub", "auth_time"} def test_id_token_payload_many_0(self): - session_info = {"authn_req": AREQN, "sub": "1234567890"} - - info = self.endpoint_context.idtoken.payload( - session_info, - user_info={"given_name": "Diana"}, - access_token="012ABCDEFGHIJKLMNOP", - code="ABCDEFGHIJKLMNOP", + session_id = self._create_session(AREQ) + grant = self.session_manager[session_id] + grant.claims = {"id_token": {"given_name": None}} + code = self._mint_code(grant, session_id) + access_token = self._mint_access_token(grant, session_id, code) + + payload = self.endpoint_context.idtoken.payload( + session_id, AREQ["client_id"], + access_token=access_token.value, + code=code.value ) - assert info["payload"] == { - "nonce": "nonce", - "given_name": "Diana", - "at_hash": "bKkyhbn1CC8IMdavzOV-Qg", - "c_hash": "5-i4nCch0pDMX1VCVJHs1g", - "sub": "1234567890", - } - assert info["lifetime"] == 300 + assert set(payload.keys()) == {"nonce", "c_hash", "at_hash", "sub", "auth_time", + "given_name"} def test_sign_encrypt_id_token(self): - client_info = RegistrationResponse( - id_token_signed_response_alg="RS512", client_id="client_1" - ) - session_info = { - "authn_req": AREQN, - "sub": "sub", - "authn_event": {"authn_info": "loa2", "authn_time": time.time()}, - } - - self.endpoint_context.jwx_def["signing_alg"] = {"id_token": "RS384"} - self.endpoint_context.cdb["client_1"] = client_info.to_dict() + session_id = self._create_session(AREQ) - _token = self.endpoint_context.idtoken.sign_encrypt( - session_info, "client_1", sign=True - ) + _token = self.endpoint_context.idtoken.sign_encrypt(session_id, AREQ['client_id'], + sign=True) assert _token _jws = jws.factory(_token) - assert _jws.jwt.headers["alg"] == "RS512" + assert _jws.jwt.headers["alg"] == "RS256" client_keyjar = KeyJar() _jwks = self.endpoint_context.keyjar.export_jwks() @@ -284,10 +234,9 @@ def test_sign_encrypt_id_token(self): assert res["aud"] == ["client_1"] def test_get_sign_algorithm(self): - client_info = RegistrationResponse() - endpoint_context = EndpointContext(conf) + client_info = self.endpoint_context.cdb[AREQ['client_id']] algs = get_sign_and_encrypt_algorithms( - endpoint_context, client_info, "id_token", sign=True + self.endpoint_context, client_info, "id_token", sign=True ) # default signing alg assert algs == {"sign": True, "encrypt": False, "sign_alg": "RS256"} @@ -334,20 +283,11 @@ def test_get_sign_algorithm_4(self): assert algs == {"sign": True, "encrypt": False, "sign_alg": "RS512"} def test_available_claims(self): - session_info = { - "authn_req": AREQN, - "sub": "sub", - "authn_event": { - "authn_info": "loa2", - "authn_time": time.time(), - "uid": "diana", - }, - } - self.endpoint_context.idtoken.kwargs["available_claims"] = { - "nickname": {"essential": True} - } - req = {"client_id": "client_1"} - _token = self.endpoint_context.idtoken.make(req, session_info) + session_id = self._create_session(AREQ) + grant = self.session_manager[session_id] + grant.claims = {"id_token": {"nickname": {"essential": True}}} + + _token = self.endpoint_context.idtoken.make(session_id=session_id) assert _token client_keyjar = KeyJar() _jwks = self.endpoint_context.keyjar.export_jwks() @@ -357,39 +297,32 @@ def test_available_claims(self): assert "nickname" in res def test_no_available_claims(self): - session_info = { - "authn_req": AREQN, - "sub": "sub", - "authn_event": { - "authn_info": "loa2", - "authn_time": time.time(), - "uid": "diana", - }, - } + session_id = self._create_session(AREQ) + grant = self.session_manager[session_id] + grant.claims = {"id_token": {"foobar": None}} + req = {"client_id": "client_1"} - _token = self.endpoint_context.idtoken.make(req, session_info) + _token = self.endpoint_context.idtoken.make(session_id=session_id) assert _token client_keyjar = KeyJar() _jwks = self.endpoint_context.keyjar.export_jwks() client_keyjar.import_jwks(_jwks, self.endpoint_context.issuer) _jwt = JWT(key_jar=client_keyjar, iss="client_1") res = _jwt.unpack(_token) - assert "nickname" not in res + assert "foobar" not in res def test_client_claims(self): - session_info = { - "authn_req": AREQN, - "sub": "sub", - "authn_event": { - "authn_info": "loa2", - "authn_time": time.time(), - "uid": "diana", - }, - } - self.endpoint_context.idtoken.enable_claims_per_client = True + session_id = self._create_session(AREQ) + grant = self.session_manager[session_id] + + self.endpoint_context.idtoken.kwargs["enable_claims_per_client"] = True self.endpoint_context.cdb["client_1"]["id_token_claims"] = {"address": None} - req = {"client_id": "client_1"} - _token = self.endpoint_context.idtoken.make(req, session_info) + + _claims = self.endpoint_context.claims_interface.get_claims( + session_id=session_id, scopes=AREQ["scope"], usage="id_token") + grant.claims = {'id_token': _claims} + + _token = self.endpoint_context.idtoken.make(session_id=session_id) assert _token client_keyjar = KeyJar() _jwks = self.endpoint_context.keyjar.export_jwks() @@ -400,22 +333,38 @@ def test_client_claims(self): assert "nickname" not in res def test_client_claims_with_default(self): - session_info = { - "authn_req": AREQN, - "sub": "sub", - "authn_event": { - "authn_info": "loa2", - "authn_time": time.time(), - "uid": "diana", - }, - } - self.endpoint_context.cdb["client_1"]["id_token_claims"] = {"address": None} - self.endpoint_context.idtoken.kwargs["available_claims"] = { - "nickname": {"essential": True} - } - self.endpoint_context.idtoken.enable_claims_per_client = True - req = {"client_id": "client_1"} - _token = self.endpoint_context.idtoken.make(req, session_info) + session_id = self._create_session(AREQ) + grant = self.session_manager[session_id] + + # self.endpoint_context.cdb["client_1"]["id_token_claims"] = {"address": None} + # self.endpoint_context.idtoken.enable_claims_per_client = True + + _claims = self.endpoint_context.claims_interface.get_claims( + session_id=session_id, scopes=AREQ["scope"], usage="id_token") + grant.claims = {"id_token": _claims} + + _token = self.endpoint_context.idtoken.make(session_id=session_id) + assert _token + client_keyjar = KeyJar() + _jwks = self.endpoint_context.keyjar.export_jwks() + client_keyjar.import_jwks(_jwks, self.endpoint_context.issuer) + _jwt = JWT(key_jar=client_keyjar, iss="client_1") + res = _jwt.unpack(_token) + + # No user info claims should be there + assert "address" not in res + assert "nickname" not in res + + def test_client_claims_scopes(self): + session_id = self._create_session(AREQS) + grant = self.session_manager[session_id] + + self.endpoint_context.idtoken.kwargs["add_claims_by_scope"] = True + _claims = self.endpoint_context.claims_interface.get_claims( + session_id=session_id, scopes=AREQS["scope"], usage="id_token") + grant.claims = {"id_token": _claims} + + _token = self.endpoint_context.idtoken.make(session_id=session_id) assert _token client_keyjar = KeyJar() _jwks = self.endpoint_context.keyjar.export_jwks() @@ -423,27 +372,51 @@ def test_client_claims_with_default(self): _jwt = JWT(key_jar=client_keyjar, iss="client_1") res = _jwt.unpack(_token) assert "address" in res + assert "email" in res + assert "nickname" not in res + + def test_client_claims_scopes_and_request_claims_no_match(self): + session_id = self._create_session(AREQRC) + grant = self.session_manager[session_id] + + self.endpoint_context.idtoken.kwargs["add_claims_by_scope"] = True + _claims = self.endpoint_context.claims_interface.get_claims( + session_id=session_id, scopes=AREQRC["scope"], usage="id_token") + grant.claims = {"id_token": _claims} + + _token = self.endpoint_context.idtoken.make(session_id=session_id) + assert _token + client_keyjar = KeyJar() + _jwks = self.endpoint_context.keyjar.export_jwks() + client_keyjar.import_jwks(_jwks, self.endpoint_context.issuer) + _jwt = JWT(key_jar=client_keyjar, iss="client_1") + res = _jwt.unpack(_token) + # User information, from scopes -> claims + assert "address" in res + assert "email" in res + # User info, requested by claims parameter assert "nickname" in res - def test_client_claims_disabled(self): - # enable_claims_per_client defaults to False - session_info = { - "authn_req": AREQN, - "sub": "sub", - "authn_event": { - "authn_info": "loa2", - "authn_time": time.time(), - "uid": "diana", - }, - } - self.endpoint_context.cdb["client_1"]["id_token_claims"] = {"address": None} - req = {"client_id": "client_1"} - _token = self.endpoint_context.idtoken.make(req, session_info) + def test_client_claims_scopes_and_request_claims_one_match(self): + _req = AREQS.copy() + _req["claims"] = {"id_token": {"email": {"value": "diana@example.com"}}} + + session_id = self._create_session(_req) + grant = self.session_manager[session_id] + + self.endpoint_context.idtoken.kwargs["add_claims_by_scope"] = True + _claims = self.endpoint_context.claims_interface.get_claims( + session_id=session_id, scopes=_req["scope"], usage="id_token") + grant.claims = {"id_token": _claims} + + _token = self.endpoint_context.idtoken.make(session_id=session_id) assert _token client_keyjar = KeyJar() _jwks = self.endpoint_context.keyjar.export_jwks() client_keyjar.import_jwks(_jwks, self.endpoint_context.issuer) _jwt = JWT(key_jar=client_keyjar, iss="client_1") res = _jwt.unpack(_token) - assert "address" not in res - assert "nickname" not in res + # Email didn't match + assert "email" not in res + # Scope -> claims + assert "address" in res diff --git a/tests/test_04_token_handler.py b/tests/test_04_token_handler.py index c0aab54..3a7de23 100644 --- a/tests/test_04_token_handler.py +++ b/tests/test_04_token_handler.py @@ -6,10 +6,12 @@ import pytest -from oidcendpoint.token_handler import Crypt -from oidcendpoint.token_handler import DefaultToken -from oidcendpoint.token_handler import TokenHandler -from oidcendpoint.token_handler import is_expired +from oidcendpoint.token import Crypt +from oidcendpoint.token import is_expired +from oidcendpoint.token.handler import DefaultToken +from oidcendpoint.token.handler import TokenHandler +from oidcendpoint.token.handler import factory +from oidcendpoint.token.jwt_token import JWTToken def test_is_expired(): @@ -92,7 +94,7 @@ def test_is_expired(self): assert self.th.is_expired(_token) is False when = time.time() + 900 - assert self.th.is_expired(_token, when) + assert self.th.is_expired(_token, int(when)) class TestTokenHandler(object): @@ -153,3 +155,60 @@ def test_get_handler(self): def test_keys(self): assert set(self.handler.keys()) == {"access_token", "code", "refresh_token"} + + +class DummyEndpointContext(): + def __init__(self): + self.keyjar = None + self.issuer = "issuer" + self.cdb = {} + + +def test_token_handler_from_config(): + conf = { + "token_handler_args": { + "jwks_def": { + "private_path": "private/token_jwks.json", + "read_only": False, + "key_defs": [ + {"type": "oct", "bytes": "24", "use": ["enc"], "kid": "code"} + ], + }, + "code": {"lifetime": 600}, + "token": { + "class": "oidcendpoint.token.jwt_token.JWTToken", + "kwargs": { + "lifetime": 3600, + "add_claims_by_scope": True, + "aud": ["https://example.org/appl"], + }, + }, + "refresh": { + "class": "oidcendpoint.token.jwt_token.JWTToken", + "kwargs": { + "lifetime": 3600, + "aud": ["https://example.org/appl"], + } + } + } + } + + token_handler = factory(DummyEndpointContext(), **conf["token_handler_args"]) + assert token_handler + assert len(token_handler.handler) == 3 + assert set(token_handler.handler.keys()) == {"code", "access_token", "refresh_token"} + assert isinstance(token_handler.handler["code"], DefaultToken) + assert isinstance(token_handler.handler["access_token"], JWTToken) + assert isinstance(token_handler.handler["refresh_token"], JWTToken) + + assert token_handler.handler["code"].lifetime == 600 + + assert token_handler.handler["access_token"].alg == "ES256" + assert token_handler.handler["access_token"].kwargs == {"add_claims_by_scope": True} + assert token_handler.handler["access_token"].lifetime == 3600 + assert token_handler.handler["access_token"].def_aud == ["https://example.org/appl"] + + assert token_handler.handler["refresh_token"].alg == "ES256" + assert token_handler.handler["refresh_token"].kwargs == {} + assert token_handler.handler["refresh_token"].lifetime == 3600 + assert token_handler.handler["refresh_token"].def_aud == ["https://example.org/appl"] diff --git a/tests/test_05_session_manager.py b/tests/test_05_session_manager.py new file mode 100644 index 0000000..a327edf --- /dev/null +++ b/tests/test_05_session_manager.py @@ -0,0 +1,504 @@ +from oidcmsg.oidc import AuthorizationRequest +from oidcmsg.time_util import time_sans_frac +import pytest + +from oidcendpoint.authn_event import AuthnEvent +from oidcendpoint.authz import AuthzHandling +from oidcendpoint.endpoint_context import EndpointContext +from oidcendpoint.oidc.authorization import Authorization +from oidcendpoint.oidc.token import Token +from oidcendpoint.session import MintingNotAllowed +from oidcendpoint.session import session_key +from oidcendpoint.session.grant import Grant +from oidcendpoint.session.info import ClientSessionInfo +from oidcendpoint.session.manager import SessionManager +from oidcendpoint.session.token import AccessToken +from oidcendpoint.session.token import AuthorizationCode +from oidcendpoint.session.token import RefreshToken +from oidcendpoint.token.handler import factory + +AUTH_REQ = AuthorizationRequest( + client_id="client_1", + redirect_uri="https://example.com/cb", + scope=["openid"], + state="STATE", + response_type="code", +) + +MAP = { + "authorization_code": "code", + "access_token": "access_token", + "refresh_token": "refresh_token" +} + +KEYDEFS = [ + {"type": "RSA", "key": "", "use": ["sig"]}, + {"type": "EC", "crv": "P-256", "use": ["sig"]}, +] + +DUMMY_SESSION_ID = session_key('user_id', 'client_id', 'grant.id') + + +class TestSessionManager: + @pytest.fixture(autouse=True) + def create_session_manager(self): + conf = { + "issuer": "https://example.com/", + "password": "mycket hemligt", + "token_expires_in": 600, + "grant_expires_in": 300, + "refresh_token_expires_in": 86400, + "verify_ssl": False, + "keys": {"key_defs": KEYDEFS, "uri_path": "static/jwks.json"}, + "jwks_uri": "https://example.com/jwks.json", + "token_handler_args": { + "jwks_def": { + "private_path": "private/token_jwks.json", + "read_only": False, + "key_defs": [ + {"type": "oct", "bytes": "24", "use": ["enc"], "kid": "code"} + ], + }, + "code": {"lifetime": 600}, + "token": { + "class": "oidcendpoint.token.jwt_token.JWTToken", + "kwargs": { + "lifetime": 3600, + "add_claims": True, + "add_claim_by_scope": True, + "aud": ["https://example.org/appl"], + }, + }, + "refresh": { + "class": "oidcendpoint.token.jwt_token.JWTToken", + "kwargs": { + "lifetime": 3600, + "aud": ["https://example.org/appl"], + } + } + }, + "endpoint": { + "authorization_endpoint": { + "path": "{}/authorization", + "class": Authorization, + "kwargs": {}, + }, + "token_endpoint": {"path": "{}/token", "class": Token, "kwargs": {}}, + }, + "template_dir": "template", + } + + self.endpoint_context = EndpointContext(conf) + token_handler = factory(self.endpoint_context, **conf["token_handler_args"]) + + self.session_manager = SessionManager(handler=token_handler) + self.authn_event = AuthnEvent(uid="uid", + valid_until=time_sans_frac() + 1, + authn_info="authn_class_ref") + + @pytest.mark.parametrize("sub_type, sector_identifier", [ + ("pairwise", "https://all.example.com"), + ("public", ""), ("ephemeral", "")]) + def test_create_session_sub_type(self, sub_type, sector_identifier): + # First session + authz_req = AUTH_REQ.copy() + if sub_type == "pairwise": + authz_req["sector_identifier_uri"] = sector_identifier + + session_key_1 = self.session_manager.create_session(authn_event=self.authn_event, + auth_req=authz_req, + user_id='diana', + client_id="client_1", + sub_type=sub_type) + + _user_info_1 = self.session_manager.get_user_session_info(session_key_1) + assert _user_info_1["subordinate"] == ["client_1"] + _client_info_1 = self.session_manager.get_client_session_info(session_key_1) + assert len(_client_info_1["subordinate"]) == 1 + # grant = self.session_manager.get_grant(session_key_1) + + # Second session + authn_req = AUTH_REQ.copy() + authn_req["client_id"] = "client_2" + if sub_type == "pairwise": + authn_req["sector_identifier_uri"] = sector_identifier + + session_key_2 = self.session_manager.create_session(authn_event=self.authn_event, + auth_req=authn_req, + user_id='diana', + client_id="client_2", + sub_type=sub_type) + + _user_info_2 = self.session_manager.get_user_session_info(session_key_2) + assert _user_info_2["subordinate"] == ["client_1", "client_2"] + + grant_1 = self.session_manager.get_grant(session_key_1) + grant_2 = self.session_manager.get_grant(session_key_2) + + if sub_type in ["pairwise", "public"]: + assert grant_1.sub == grant_2.sub + else: + assert grant_1.sub != grant_2.sub + + # Third session + authn_req = AUTH_REQ.copy() + authn_req["client_id"] = "client_3" + if sub_type == "pairwise": + authn_req["sector_identifier_uri"] = sector_identifier + + session_key_3 = self.session_manager.create_session(authn_event=self.authn_event, + auth_req=authn_req, + user_id='diana', + client_id="client_3", + sub_type=sub_type) + + grant_3 = self.session_manager.get_grant(session_key_3) + + if sub_type == "pairwise": + assert grant_1.sub == grant_2.sub + assert grant_1.sub == grant_3.sub + assert grant_3.sub == grant_2.sub + elif sub_type == "public": + assert grant_1.sub == grant_2.sub + assert grant_1.sub == grant_3.sub + assert grant_3.sub == grant_2.sub + else: + assert grant_1.sub != grant_2.sub + assert grant_1.sub != grant_3.sub + assert grant_3.sub != grant_2.sub + + # Sub types differ so do authentication request + + assert grant_1.authorization_request != grant_2.authorization_request + assert grant_1.authorization_request != grant_3.authorization_request + assert grant_3.authorization_request != grant_2.authorization_request + + def _mint_token(self, type, grant, session_id, based_on=None): + # Constructing an authorization code is now done + return grant.mint_token( + session_id=session_id, + endpoint_context=self.endpoint_context, + token_type=type, + token_handler=self.session_manager.token_handler.handler[MAP[type]], + expires_at=time_sans_frac() + 300, # 5 minutes from now + based_on=based_on + ) + + def test_grant(self): + grant = Grant() + assert grant.issued_token == [] + assert grant.is_active() is True + + code = self._mint_token('authorization_code', grant, DUMMY_SESSION_ID) + assert isinstance(code, AuthorizationCode) + assert code.is_active() + assert len(grant.issued_token) == 1 + assert code.max_usage_reached() is False + + def test_code_usage(self): + grant = Grant() + assert grant.issued_token == [] + assert grant.is_active() is True + + code = self._mint_token('authorization_code', grant, DUMMY_SESSION_ID) + assert isinstance(code, AuthorizationCode) + assert code.is_active() + assert len(grant.issued_token) == 1 + + assert code.usage_rules["supports_minting"] == ['access_token', 'refresh_token'] + access_token = self._mint_token('access_token', grant, DUMMY_SESSION_ID, code) + assert isinstance(access_token, AccessToken) + assert access_token.is_active() + assert len(grant.issued_token) == 2 + + refresh_token = self._mint_token('refresh_token', grant, DUMMY_SESSION_ID, code) + assert isinstance(refresh_token, RefreshToken) + assert refresh_token.is_active() + assert len(grant.issued_token) == 3 + + code.register_usage() + assert code.max_usage_reached() is True + + with pytest.raises(MintingNotAllowed): + self._mint_token('access_token', grant, DUMMY_SESSION_ID, code) + + grant.revoke_token(based_on=code.value) + + assert access_token.revoked == True + assert refresh_token.revoked == True + + def test_add_grant(self): + self.session_manager.create_session(authn_event=self.authn_event, + auth_req=AUTH_REQ, + user_id='diana', + client_id="client_1") + + grant = self.session_manager.add_grant( + user_id="diana", client_id="client_1", + scope=["openid", "phoe"], + claims={"userinfo": {"given_name": None}}) + + assert grant.scope == ["openid", "phoe"] + + _grant = self.session_manager.get(['diana', 'client_1', grant.id]) + + assert _grant.scope == ["openid", "phoe"] + + def test_find_token(self): + self.session_manager.create_session(authn_event=self.authn_event, + auth_req=AUTH_REQ, + user_id='diana', + client_id="client_1") + + grant = self.session_manager.add_grant(user_id="diana", + client_id="client_1") + + code = self._mint_token('authorization_code', grant, DUMMY_SESSION_ID) + access_token = self._mint_token('access_token', grant, DUMMY_SESSION_ID, code) + + _session_key = session_key('diana', 'client_1', grant.id) + _token = self.session_manager.find_token(_session_key, access_token.value) + + assert _token.type == "access_token" + assert _token.id == access_token.id + + def test_get_authentication_event(self): + session_id = self.session_manager.create_session(authn_event=self.authn_event, + auth_req=AUTH_REQ, + user_id='diana', + client_id="client_1") + + _info = self.session_manager.get_session_info(session_id, authentication_event=True) + authn_event = _info["authentication_event"] + + assert isinstance(authn_event, AuthnEvent) + assert authn_event["uid"] == "uid" + assert authn_event["authn_info"] == "authn_class_ref" + + def test_get_client_session_info(self): + self.session_manager.create_session(authn_event=self.authn_event, + auth_req=AUTH_REQ, + user_id='diana', + client_id="client_1") + + grant = self.session_manager.add_grant(user_id="diana", + client_id="client_1") + + _session_id = session_key('diana', 'client_1', grant.id) + csi = self.session_manager.get_client_session_info(_session_id) + + assert isinstance(csi, ClientSessionInfo) + + def test_get_general_session_info(self): + _session_id = self.session_manager.create_session(authn_event=self.authn_event, + auth_req=AUTH_REQ, + user_id='diana', + client_id="client_1") + + _session_info = self.session_manager.get_session_info(_session_id) + + assert set(_session_info.keys()) == {'client_id', + 'grant_id', + 'session_id', + 'user_id'} + assert _session_info["user_id"] == "diana" + assert _session_info["client_id"] == "client_1" + + def test_get_session_info_by_token(self): + _session_id = self.session_manager.create_session(authn_event=self.authn_event, + auth_req=AUTH_REQ, + user_id='diana', + client_id="client_1") + + grant = self.session_manager.get_grant(_session_id) + code = self._mint_token('authorization_code', grant, _session_id) + _session_info = self.session_manager.get_session_info_by_token(code.value) + + assert set(_session_info.keys()) == {'client_id', + 'session_id', + 'grant_id', + 'user_id'} + assert _session_info["user_id"] == "diana" + assert _session_info["client_id"] == "client_1" + + def test_token_usage_default(self): + _session_id = self.session_manager.create_session(authn_event=self.authn_event, + auth_req=AUTH_REQ, + user_id='diana', + client_id="client_1") + grant = self.session_manager[_session_id] + + code = self._mint_token('authorization_code', grant, _session_id) + + assert code.usage_rules == { + 'max_usage': 1, + 'supports_minting': ['access_token', 'refresh_token'] + } + + token = self._mint_token("access_token", grant, _session_id, code) + + assert token.usage_rules == {} + + refresh_token = self._mint_token("refresh_token", grant, _session_id, code) + + assert refresh_token.usage_rules == { + 'supports_minting': ['access_token', 'refresh_token'] + } + + def test_token_usage_grant(self): + _session_id = self.session_manager.create_session(authn_event=self.authn_event, + auth_req=AUTH_REQ, + user_id='diana', + client_id="client_1") + grant = self.session_manager[_session_id] + grant.usage_rules = { + "authorization_code": { + "max_usage": 1, + "supports_minting": ["access_token", "refresh_token", "id_token"], + "expires_in": 300 + }, + "access_token": { + "expires_in": 3600 + }, + "refresh_token": { + "supports_minting": ["access_token", "refresh_token", "id_token"] + } + } + + code = self._mint_token('authorization_code', grant, _session_id) + assert code.usage_rules == { + 'max_usage': 1, + 'supports_minting': ['access_token', 'refresh_token', 'id_token'], + "expires_in": 300 + } + + token = self._mint_token("access_token", grant, _session_id, code) + assert token.usage_rules == { + "expires_in": 3600 + } + + refresh_token = self._mint_token("refresh_token", grant, _session_id, code) + assert refresh_token.usage_rules == { + 'supports_minting': ['access_token', 'refresh_token', "id_token"] + } + + def test_token_usage_authz(self): + grant_config = { + "usage_rules": { + "authorization_code": { + 'supports_minting': ["access_token"], + "max_usage": 1, + "expires_in": 120 + }, + "access_token": { + "expires_in": 600 + }, + "refresh_token": {} + }, + "expires_in": 43200 + } + + self.endpoint_context.authz = AuthzHandling(self.endpoint_context, + grant_config=grant_config) + + self.endpoint_context.cdb["client_1"] = {} + + token_usage_rules = self.endpoint_context.authz.usage_rules('client_1') + + _session_id = self.session_manager.create_session(authn_event=self.authn_event, + auth_req=AUTH_REQ, + user_id='diana', + client_id="client_1", + token_usage_rules=token_usage_rules) + grant = self.session_manager[_session_id] + + code = self._mint_token('authorization_code', grant, _session_id) + assert code.usage_rules == { + 'max_usage': 1, + 'supports_minting': ['access_token'], + "expires_in": 120 + } + + token = self._mint_token("access_token", grant, _session_id, code) + assert token.usage_rules == { + "expires_in": 600 + } + + with pytest.raises(MintingNotAllowed): + self._mint_token("refresh_token", grant, _session_id, code) + + def test_token_usage_client_config(self): + grant_config = { + "usage_rules": { + "authorization_code": { + 'supports_minting': ["access_token"], + "max_usage": 1, + "expires_in": 120 + }, + "access_token": { + "expires_in": 600 + }, + "refresh_token": {} + }, + "expires_in": 43200 + } + + self.endpoint_context.authz = AuthzHandling(self.endpoint_context, + grant_config=grant_config) + + # Change expiration time for the code and allow refresh tokens for this + # specific client + self.endpoint_context.cdb["client_1"] = { + "token_usage_rules": { + "authorization_code": { + "expires_in": 600, + 'supports_minting': ['access_token', "refresh_token"] + }, + "refresh_token": { + 'supports_minting': ['access_token'] + } + } + } + + token_usage_rules = self.endpoint_context.authz.usage_rules('client_1') + + _session_id = self.session_manager.create_session(authn_event=self.authn_event, + auth_req=AUTH_REQ, + user_id='diana', + client_id="client_1", + token_usage_rules=token_usage_rules) + grant = self.session_manager[_session_id] + + code = self._mint_token('authorization_code', grant, _session_id) + assert code.usage_rules == { + 'max_usage': 1, + 'supports_minting': ['access_token', 'refresh_token'], + "expires_in": 600 + } + + token = self._mint_token("access_token", grant, _session_id, code) + assert token.usage_rules == { + "expires_in": 600 + } + + refresh_token = self._mint_token("refresh_token", grant, _session_id, code) + assert refresh_token.usage_rules == { + 'supports_minting': ['access_token'] + } + + # Test with another client + + self.endpoint_context.cdb["client_2"] = {} + + token_usage_rules = self.endpoint_context.authz.usage_rules('client_2') + + _session_id = self.session_manager.create_session(authn_event=self.authn_event, + auth_req=AUTH_REQ, + user_id='diana', + client_id="client_2", + token_usage_rules=token_usage_rules) + grant = self.session_manager[_session_id] + code = self._mint_token('authorization_code', grant, _session_id) + # Not allowed to mint refresh token for this client + with pytest.raises(MintingNotAllowed): + self._mint_token("refresh_token", grant, _session_id, code) diff --git a/tests/test_05_sso_db.py b/tests/test_05_sso_db.py deleted file mode 100644 index 5c8479f..0000000 --- a/tests/test_05_sso_db.py +++ /dev/null @@ -1,135 +0,0 @@ -import shutil - -import pytest - -from oidcendpoint.sso_db import SSODb - -DB_CONF = { - "handler": "oidcmsg.storage.abfile.AbstractFileSystem", - "fdir": "db/sso", - "key_conv": "oidcmsg.storage.converter.QPKey", - "value_conv": "oidcmsg.storage.converter.JSON", -} - - -def rmtree(item): - try: - shutil.rmtree(item) - except FileNotFoundError: - pass - - -class TestSSODB(object): - @pytest.fixture(autouse=True) - def create_sdb(self): - rmtree("db/sso") - self.sso_db = SSODb(DB_CONF) - - def test_map_sid2uid(self): - self.sso_db.map_sid2uid("session id 1", "Lizz") - assert self.sso_db.get_sids_by_uid("Lizz") == ["session id 1"] - - def test_missing_map(self): - assert self.sso_db.get_sids_by_uid("Lizz") == [] - - def test_multiple_map_sid2uid(self): - self.sso_db.map_sid2uid("session id 1", "Lizz") - self.sso_db.map_sid2uid("session id 2", "Lizz") - assert set(self.sso_db.get_sids_by_uid("Lizz")) == { - "session id 1", - "session id 2", - } - - def test_map_unmap_sid2uid(self): - self.sso_db.map_sid2uid("session id 1", "Lizz") - self.sso_db.map_sid2uid("session id 2", "Lizz") - assert set(self.sso_db.get_sids_by_uid("Lizz")) == { - "session id 1", - "session id 2", - } - - self.sso_db.remove_sid2uid("session id 1", "Lizz") - assert self.sso_db.get_sids_by_uid("Lizz") == ["session id 2"] - - def test_get_uid_by_sid(self): - self.sso_db.map_sid2uid("session id 1", "Lizz") - self.sso_db.map_sid2uid("session id 2", "Lizz") - - assert self.sso_db.get_uid_by_sid("session id 1") == "Lizz" - assert self.sso_db.get_uid_by_sid("session id 2") == "Lizz" - - def test_remove_uid(self): - self.sso_db.map_sid2uid("session id 1", "Lizz") - self.sso_db.map_sid2uid("session id 2", "Diana") - - self.sso_db.remove_uid("Lizz") - assert self.sso_db.get_uid_by_sid("session id 1") is None - assert self.sso_db.get_sids_by_uid("Lizz") == [] - - def test_map_sid2sub(self): - self.sso_db.map_sid2sub("session id 1", "abcdefgh") - assert self.sso_db.get_sids_by_sub("abcdefgh") == ["session id 1"] - - def test_missing_sid2sub_map(self): - assert self.sso_db.get_sids_by_sub("abcdefgh") == [] - - def test_multiple_map_sid2sub(self): - self.sso_db.map_sid2sub("session id 1", "abcdefgh") - self.sso_db.map_sid2sub("session id 2", "abcdefgh") - assert set(self.sso_db.get_sids_by_sub("abcdefgh")) == { - "session id 1", - "session id 2", - } - - def test_map_unmap_sid2sub(self): - self.sso_db.map_sid2sub("session id 1", "abcdefgh") - self.sso_db.map_sid2sub("session id 2", "abcdefgh") - assert set(self.sso_db.get_sids_by_sub("abcdefgh")) == { - "session id 1", - "session id 2", - } - - self.sso_db.remove_sid2sub("session id 1", "abcdefgh") - assert self.sso_db.get_sids_by_sub("abcdefgh") == ["session id 2"] - - def test_get_sub_by_sid(self): - self.sso_db.map_sid2sub("session id 1", "abcdefgh") - self.sso_db.map_sid2sub("session id 2", "abcdefgh") - - assert set(self.sso_db.get_sids_by_sub("abcdefgh")) == { - "session id 1", - "session id 2", - } - - def test_remove_sub(self): - self.sso_db.map_sid2sub("session id 1", "abcdefgh") - self.sso_db.map_sid2sub("session id 2", "012346789") - - self.sso_db.remove_sub("abcdefgh") - assert self.sso_db.get_sub_by_sid("session id 1") is None - assert self.sso_db.get_sids_by_sub("abcdefgh") == [] - # have not touched the others - assert self.sso_db.get_sub_by_sid("session id 2") == "012346789" - assert self.sso_db.get_sids_by_sub("012346789") == ["session id 2"] - - def test_get_sub_by_uid_same_sub(self): - self.sso_db.map_sid2sub("session id 1", "abcdefgh") - self.sso_db.map_sid2sub("session id 2", "abcdefgh") - - self.sso_db.map_sid2uid("session id 1", "Lizz") - self.sso_db.map_sid2uid("session id 2", "Lizz") - - res = self.sso_db.get_subs_by_uid("Lizz") - - assert set(res) == {"abcdefgh"} - - def test_get_sub_by_uid_different_sub(self): - self.sso_db.map_sid2sub("session id 1", "abcdefgh") - self.sso_db.map_sid2sub("session id 2", "012346789") - - self.sso_db.map_sid2uid("session id 1", "Lizz") - self.sso_db.map_sid2uid("session id 2", "Lizz") - - res = self.sso_db.get_subs_by_uid("Lizz") - - assert set(res) == {"abcdefgh", "012346789"} diff --git a/tests/test_06_authn_context.py b/tests/test_06_authn_context.py index 284bc23..d3f41d6 100644 --- a/tests/test_06_authn_context.py +++ b/tests/test_06_authn_context.py @@ -11,7 +11,7 @@ from oidcendpoint.id_token import IDToken from oidcendpoint.oidc.authorization import Authorization from oidcendpoint.oidc.provider_config import ProviderConfiguration -from oidcendpoint.oidc.token import AccessToken +from oidcendpoint.oidc.token import Token from oidcendpoint.user_authn.authn_context import INTERNETPROTOCOLPASSWORD from oidcendpoint.user_authn.authn_context import TIMESYNCTOKEN from oidcendpoint.user_authn.authn_context import init_method @@ -58,7 +58,7 @@ "private_key_jwt", ], "response_modes_supported": ["query", "fragment", "form_post"], - "subject_types_supported": ["public", "pairwise"], + "subject_types_supported": ["public", "pairwise", "ephemeral"], "grant_types_supported": [ "authorization_code", "implicit", @@ -152,7 +152,7 @@ def create_authn_broker(self): "class": Authorization, "kwargs": {}, }, - "token": {"path": "{}/token", "class": AccessToken, "kwargs": {}}, + "token": {"path": "{}/token", "class": Token, "kwargs": {}}, }, "authentication": METHOD, "userinfo": {"class": UserInfo, "kwargs": {"db": USERINFO_db}}, @@ -202,15 +202,14 @@ def test_pick_authn_all(self): def test_authn_event(): an = AuthnEvent( uid="uid", - salt="_salt_", valid_until=time_sans_frac() + 1, authn_info="authn_class_ref", ) - assert an.valid() + assert an.is_valid() n = time_sans_frac() + 3 - assert an.valid(n) is False + assert an.is_valid(n) is False n = an.expires_in() assert n == 1 # could possibly be 0 diff --git a/tests/test_07_userinfo.py b/tests/test_07_userinfo.py index d3a0573..b5bc391 100644 --- a/tests/test_07_userinfo.py +++ b/tests/test_07_userinfo.py @@ -2,24 +2,24 @@ import os import pytest -from oidcmsg.message import Message from oidcmsg.oidc import OpenIDRequest -from oidcmsg.oidc import OpenIDSchema from oidcendpoint.authn_event import create_authn_event from oidcendpoint.endpoint_context import EndpointContext +from oidcendpoint.id_token import IDToken +from oidcendpoint.oidc import userinfo from oidcendpoint.oidc.authorization import Authorization from oidcendpoint.oidc.provider_config import ProviderConfiguration from oidcendpoint.oidc.registration import Registration from oidcendpoint.scopes import SCOPE2CLAIMS -from oidcendpoint.scopes import STANDARD_CLAIMS from oidcendpoint.scopes import convert_scopes2claims +from oidcendpoint.session import session_key +from oidcendpoint.session import unpack_session_key +from oidcendpoint.session.claims import STANDARD_CLAIMS +from oidcendpoint.session.claims import ClaimsInterface +from oidcendpoint.session.grant import Grant from oidcendpoint.user_authn.authn_context import INTERNETPROTOCOLPASSWORD from oidcendpoint.user_info import UserInfo -from oidcendpoint.userinfo import by_schema -from oidcendpoint.userinfo import claims_match -from oidcendpoint.userinfo import collect_user_info -from oidcendpoint.userinfo import update_claims CLAIMS = { "userinfo": { @@ -136,27 +136,27 @@ def test_custom_scopes(): assert set( convert_scopes2claims(["email"], _available_claims, map=_scopes).keys() - ) == {"email", "email_verified",} + ) == {"email", "email_verified", } assert set( convert_scopes2claims(["address"], _available_claims, map=_scopes).keys() ) == {"address"} assert set( convert_scopes2claims(["phone"], _available_claims, map=_scopes).keys() - ) == {"phone_number", "phone_number_verified",} + ) == {"phone_number", "phone_number_verified", } assert set( convert_scopes2claims( ["research_and_scholarship"], _available_claims, map=_scopes ).keys() ) == { - "name", - "given_name", - "family_name", - "email", - "email_verified", - "sub", - "eduperson_scoped_affiliation", - } + "name", + "given_name", + "family_name", + "email", + "email_verified", + "sub", + "eduperson_scoped_affiliation", + } PROVIDER_INFO = { @@ -172,71 +172,6 @@ def test_custom_scopes(): ] } - -def test_update_claims_authn_req_id_token(): - _session_info = {"authn_req": OIDR} - claims = update_claims(_session_info, "id_token", PROVIDER_INFO) - assert set(claims.keys()) == {"auth_time", "acr"} - - -def test_update_claims_authn_req_userinfo(): - _session_info = {"authn_req": OIDR} - claims = update_claims(_session_info, "userinfo", PROVIDER_INFO) - assert set(claims.keys()) == { - "given_name", - "nickname", - "email", - "email_verified", - "picture", - "http://example.info/claims/groups", - } - - -def test_update_claims_authzreq_id_token(): - _session_info = {"authn_req": OIDR} - claims = update_claims(_session_info, "id_token", PROVIDER_INFO) - assert set(claims.keys()) == {"auth_time", "acr"} - - -def test_update_claims_authzreq_userinfo(): - _session_info = {"authn_req": OIDR} - claims = update_claims(_session_info, "userinfo", PROVIDER_INFO) - assert set(claims.keys()) == { - "given_name", - "nickname", - "email", - "email_verified", - "picture", - "http://example.info/claims/groups", - } - - -def test_clams_value(): - assert claims_match("red", CLAIMS["userinfo"]["http://example.info/claims/groups"]) - - -def test_clams_values(): - assert claims_match("urn:mace:incommon:iap:silver", CLAIMS["id_token"]["acr"]) - - -def test_clams_essential(): - assert claims_match(["foobar@example"], CLAIMS["userinfo"]["email"]) - - -def test_clams_none(): - assert claims_match(["angle"], CLAIMS["userinfo"]["nickname"]) - - -def test_by_schema(): - # There are no requested or optional claims defined for Message - assert by_schema(Message, sub="John") == {} - - assert by_schema(OpenIDSchema, sub="John", given_name="John", age=34) == { - "sub": "John", - "given_name": "John", - } - - KEYDEFS = [ {"type": "RSA", "key": "", "use": ["sig"]}, {"type": "EC", "crv": "P-256", "use": ["sig"]}, @@ -282,6 +217,24 @@ def create_endpoint_context(self): "request_uri_parameter_supported": True, }, }, + "userinfo": { + "path": "userinfo", + "class": userinfo.UserInfo, + "kwargs": { + "claim_types_supported": [ + "normal", + "aggregated", + "distributed", + ], + "client_authn_method": ["bearer_header"], + "base_claims": { + "eduperson_scoped_affiliation": None, + "email": None, + }, + "add_claims_by_scope": True, + "enable_claims_per_client": True + }, + }, }, "keys": { "public_path": "jwks.json", @@ -296,78 +249,141 @@ def create_endpoint_context(self): } }, "template_dir": "template", + "id_token": { + "class": IDToken, + "kwargs": { + "base_claims": { + "email": None, + "email_verified": None, + }, + "enable_claims_per_client": True + }, + }, } ) # Just has to be there self.endpoint_context.cdb["client1"] = {} + self.session_manager = self.endpoint_context.session_manager + self.claims_interface = ClaimsInterface(self.endpoint_context) + self.user_id = "diana" + + def _create_session(self, auth_req, sub_type="public", sector_identifier=''): + if sector_identifier: + authz_req = auth_req.copy() + authz_req["sector_identifier_uri"] = sector_identifier + else: + authz_req = auth_req + client_id = authz_req['client_id'] + ae = create_authn_event(self.user_id) + return self.session_manager.create_session(ae, authz_req, self.user_id, + client_id=client_id, + sub_type=sub_type) def test_collect_user_info(self): _req = OIDR.copy() _req["claims"] = CLAIMS_2 - _session_info = {"authn_req": _req} - session = _session_info.copy() - session["sub"] = "doe" - session["uid"] = "diana" - session["authn_event"] = create_authn_event("diana", "salt") + session_id = self._create_session(_req) + + _userinfo_restriction = self.claims_interface.get_claims(session_id=session_id, + scopes=OIDR["scope"], + usage="userinfo") - res = collect_user_info(self.endpoint_context, session) + res = self.claims_interface.get_user_claims("diana", _userinfo_restriction) assert res == { + 'eduperson_scoped_affiliation': ['staff@example.org'], + "email": "diana@example.org", "nickname": "Dina", - "sub": "doe", + "email_verified": False + } + + _id_token_restriction = self.claims_interface.get_claims(session_id=session_id, + scopes=OIDR["scope"], + usage="id_token") + + res = self.claims_interface.get_user_claims("diana", _id_token_restriction) + + assert res == { "email": "diana@example.org", "email_verified": False, } + _introspection_restriction = self.claims_interface.get_claims(session_id=session_id, + scopes=OIDR["scope"], + usage="introspection") + + res = self.claims_interface.get_user_claims("diana", _introspection_restriction) + + assert res == {} + def test_collect_user_info_2(self): _req = OIDR.copy() - _req["scope"] = "openid email" + _req["scope"] = "openid email address" del _req["claims"] - _session_info = {"authn_req": _req} - session = _session_info.copy() - session["sub"] = "doe" - session["uid"] = "diana" - session["authn_event"] = create_authn_event("diana", "salt") + session_id = self._create_session(_req) + _uid, _cid, _gid = unpack_session_key(session_id) - self.endpoint_context.provider_info["scopes_supported"] = [ - "openid", - "email", - "offline_access", - ] - res = collect_user_info(self.endpoint_context, session) + _userinfo_restriction = self.claims_interface.get_claims(session_id=session_id, + scopes=_req["scope"], + usage="userinfo") + + res = self.claims_interface.get_user_claims("diana", _userinfo_restriction) assert res == { - "sub": "doe", - "email": "diana@example.org", - "email_verified": False, + 'address': { + 'country': 'Sweden', + 'locality': 'Umeå', + 'postal_code': 'SE-90187', + 'street_address': 'Umeå Universitet' + }, + 'eduperson_scoped_affiliation': ['staff@example.org'], + 'email': 'diana@example.org', + 'email_verified': False } - def test_collect_user_info_scope_not_supported(self): + def test_collect_user_info_scope_not_supported_no_base_claims(self): _req = OIDR.copy() _req["scope"] = "openid email address" del _req["claims"] - _session_info = {"authn_req": _req} - session = _session_info.copy() - session["sub"] = "doe" - session["uid"] = "diana" - session["authn_event"] = create_authn_event("diana", "salt") + session_id = self._create_session(_req) + _uid, _cid, _gid = unpack_session_key(session_id) - # Scope address not supported - self.endpoint_context.provider_info["scopes_supported"] = [ - "openid", - "email", - "offline_access", - ] - res = collect_user_info(self.endpoint_context, session) + self.endpoint_context.endpoint["userinfo"].kwargs["add_claims_by_scope"] = False + self.endpoint_context.endpoint["userinfo"].kwargs["enable_claims_per_client"] = False + del self.endpoint_context.endpoint["userinfo"].kwargs["base_claims"] - assert res == { - "sub": "doe", - "email": "diana@example.org", - "email_verified": False, - } + _userinfo_restriction = self.claims_interface.get_claims(session_id=session_id, + scopes=_req["scope"], + usage="userinfo") + + res = self.claims_interface.get_user_claims("diana", _userinfo_restriction) + + assert res == {} + + def test_collect_user_info_enable_claims_per_client(self): + _req = OIDR.copy() + _req["scope"] = "openid email address" + del _req["claims"] + + session_id = self._create_session(_req) + _uid, _cid, _gid = unpack_session_key(session_id) + + self.endpoint_context.endpoint["userinfo"].kwargs["add_claims_by_scope"] = False + self.endpoint_context.endpoint["userinfo"].kwargs["enable_claims_per_client"] = True + del self.endpoint_context.endpoint["userinfo"].kwargs["base_claims"] + + self.endpoint_context.cdb[_req["client_id"]]["userinfo_claims"] = {"phone_number": None} + + _userinfo_restriction = self.claims_interface.get_claims(session_id=session_id, + scopes=_req["scope"], + usage="userinfo") + + res = self.claims_interface.get_user_claims("diana", _userinfo_restriction) + + assert res == {'phone_number': '+46907865000'} class TestCollectUserInfoCustomScopes: @@ -409,6 +425,24 @@ def create_endpoint_context(self): "request_uri_parameter_supported": True, }, }, + "userinfo": { + "path": "userinfo", + "class": userinfo.UserInfo, + "kwargs": { + "claim_types_supported": [ + "normal", + "aggregated", + "distributed", + ], + "client_authn_method": ["bearer_header"], + "base_claims": { + "eduperson_scoped_affiliation": None, + "email": None, + }, + "add_claims_by_scope": True, + "enable_claims_per_client": True + }, + }, }, "add_on": { "custom_scopes": { @@ -440,66 +474,62 @@ def create_endpoint_context(self): } }, "template_dir": "template", + "id_token": { + "class": IDToken, + "kwargs": { + "base_claims": { + "email": None, + "email_verified": None, + }, + "enable_claims_per_client": True + }, + }, } ) self.endpoint_context.cdb["client1"] = {} - - def test_collect_user_info(self): - _session_info = {"authn_req": OIDR} - session = _session_info.copy() - session["sub"] = "doe" - session["uid"] = "diana" - session["authn_event"] = create_authn_event("diana", "salt") - - res = collect_user_info(self.endpoint_context, session) - - assert res == { - "email": "diana@example.org", - "email_verified": False, - "nickname": "Dina", - "given_name": "Diana", - "sub": "doe", - } - - def test_collect_user_info_2(self): + self.session_manager = self.endpoint_context.session_manager + self.claims_interface = ClaimsInterface(self.endpoint_context) + self.user_id = "diana" + + def _create_session(self, auth_req, sub_type="public", sector_identifier=''): + if sector_identifier: + authz_req = auth_req.copy() + authz_req["sector_identifier_uri"] = sector_identifier + else: + authz_req = auth_req + client_id = authz_req['client_id'] + ae = create_authn_event(self.user_id) + return self.session_manager.create_session(ae, authz_req, self.user_id, + client_id=client_id, + sub_type=sub_type) + + # def _do_grant(self, auth_req): + # client_id = auth_req['client_id'] + # # The user consent module produces a Grant instance + # grant = Grant(scope=auth_req['scope'], resources=[client_id]) + # + # # the grant is assigned to a session (user_id, client_id) + # self.session_manager.set([self.user_id, client_id, grant.id], grant) + # return session_key(self.user_id, client_id, grant.id) + + def test_collect_user_info_custom_scope(self): _req = OIDR.copy() - _req["scope"] = "openid email" + _req["scope"] = "openid research_and_scholarship" del _req["claims"] - _session_info = {"authn_req": _req} - session = _session_info.copy() - session["sub"] = "doe" - session["uid"] = "diana" - session["authn_event"] = create_authn_event("diana", "salt") - - self.endpoint_context.provider_info["claims_supported"].remove("email") - self.endpoint_context.provider_info["claims_supported"].remove("email_verified") - - res = collect_user_info(self.endpoint_context, session) + session_id = self._create_session(_req) - assert res == {"sub": "doe"} + _restriction = self.claims_interface.get_claims(session_id=session_id, + scopes=_req["scope"], + usage="userinfo") - def test_collect_user_info_scope_not_supported(self): - _req = OIDR.copy() - _req["scope"] = "openid email address" - del _req["claims"] - - _session_info = {"authn_req": _req} - session = _session_info.copy() - session["sub"] = "doe" - session["uid"] = "diana" - session["authn_event"] = create_authn_event("diana", "salt") - - # Scope address not supported - self.endpoint_context.provider_info["scopes_supported"] = [ - "openid", - "email", - "offline_access", - ] - res = collect_user_info(self.endpoint_context, session) + res = self.claims_interface.get_user_claims("diana", _restriction) assert res == { - "sub": "doe", - "email": "diana@example.org", - "email_verified": False, + 'eduperson_scoped_affiliation': ['staff@example.org'], + 'email': 'diana@example.org', + 'email_verified': False, + 'family_name': 'Krall', + 'given_name': 'Diana', + 'name': 'Diana Krall' } diff --git a/tests/test_08_session.py b/tests/test_08_session.py deleted file mode 100644 index dac1dc3..0000000 --- a/tests/test_08_session.py +++ /dev/null @@ -1,520 +0,0 @@ -import os -import shutil -import time - -from oidcendpoint.token_handler import UnknownToken -from oidcmsg.oidc import AuthorizationRequest -from oidcmsg.oidc import OpenIDRequest -from oidcmsg.storage.init import storage_factory -import pytest - -from oidcendpoint import rndstr -from oidcendpoint import token_handler -from oidcendpoint.authn_event import create_authn_event -from oidcendpoint.endpoint_context import EndpointContext -from oidcendpoint.oidc.authorization import Authorization -from oidcendpoint.oidc.provider_config import ProviderConfiguration -from oidcendpoint.session import SessionDB -from oidcendpoint.session import setup_session -from oidcendpoint.sso_db import SSODb -from oidcendpoint.token_handler import WrongTokenType -from oidcendpoint.user_authn.authn_context import INTERNETPROTOCOLPASSWORD -from oidcendpoint.user_info import UserInfo - -__author__ = "rohe0002" - -AREQ = AuthorizationRequest( - response_type="code", - client_id="client1", - redirect_uri="http://example.com/authz", - scope=["openid"], - state="state000", -) - -AREQN = AuthorizationRequest( - response_type="code", - client_id="client1", - redirect_uri="http://example.com/authz", - scope=["openid"], - state="state000", - nonce="something", -) - -AREQO = AuthorizationRequest( - response_type="code", - client_id="client1", - redirect_uri="http://example.com/authz", - scope=["openid", "offline_access"], - prompt="consent", - state="state000", -) - -OIDR = OpenIDRequest( - response_type="code", - client_id="client1", - redirect_uri="http://example.com/authz", - scope=["openid"], - state="state000", -) - -BASEDIR = os.path.abspath(os.path.dirname(__file__)) - - -def full_path(local_file): - return os.path.join(BASEDIR, local_file) - - -SSO_DB_CONF = { - "handler": "oidcmsg.storage.abfile.AbstractFileSystem", - "fdir": "db/sso", - "key_conv": "oidcmsg.storage.converter.QPKey", - "value_conv": "oidcmsg.storage.converter.JSON", -} - -SESSION_DB_CONF = { - "handler": "oidcmsg.storage.abfile.AbstractFileSystem", - "fdir": "db/session", - "key_conv": "oidcmsg.storage.converter.QPKey", - "value_conv": "oidcmsg.storage.converter.JSON", -} - - -def rmtree(item): - try: - shutil.rmtree(item) - except FileNotFoundError: - pass - - -class TestSessionDB(object): - @pytest.fixture(autouse=True) - def create_sdb(self): - rmtree("db/sso") - rmtree("db/session") - - passwd = rndstr(24) - _th_args = { - "code": {"lifetime": 600, "password": passwd}, - "token": {"lifetime": 3600, "password": passwd}, - "refresh": {"lifetime": 86400, "password": passwd}, - } - - _token_handler = token_handler.factory(None, **_th_args) - userinfo = UserInfo(db_file=full_path("users.json")) - self.sdb = SessionDB( - storage_factory(SESSION_DB_CONF), - _token_handler, - SSODb(SSO_DB_CONF), - userinfo, - ) - - def test_create_authz_session(self): - ae = create_authn_event("uid", "salt") - sid = self.sdb.create_authz_session(ae, AREQ, client_id="client_id") - self.sdb.do_sub(sid, uid="user", salt="client_salt") - - info = self.sdb[sid] - assert info["client_id"] == "client_id" - assert set(info.keys()) == { - "sid", - "client_id", - "authn_req", - "authn_event", - "sub", - "oauth_state", - "code", - } - - def test_create_authz_session_without_nonce(self): - ae = create_authn_event("sub", "salt") - sid = self.sdb.create_authz_session(ae, AREQ, client_id="client_id") - info = self.sdb[sid] - assert info["oauth_state"] == "authz" - - def test_create_authz_session_with_nonce(self): - ae = create_authn_event("sub", "salt") - sid = self.sdb.create_authz_session(ae, AREQN, client_id="client_id") - info = self.sdb[sid] - authz_request = info["authn_req"] - assert authz_request["nonce"] == "something" - - def test_create_authz_session_with_id_token(self): - ae = create_authn_event("sub", "salt") - sid = self.sdb.create_authz_session( - ae, AREQ, client_id="client_id", id_token="id_token" - ) - - info = self.sdb[sid] - assert info["id_token"] == "id_token" - - def test_create_authz_session_with_oidreq(self): - ae = create_authn_event("sub", "salt") - sid = self.sdb.create_authz_session( - ae, AREQ, client_id="client_id", oidreq=OIDR - ) - info = self.sdb[sid] - assert "id_token" not in info - assert "oidreq" in info - - def test_create_authz_session_with_sector_id(self): - ae = create_authn_event("sub", "salt") - sid = self.sdb.create_authz_session( - ae, AREQ, client_id="client_id", oidreq=OIDR - ) - self.sdb.do_sub( - sid, "user1", "client_salt", "http://example.com/si.jwt", "pairwise" - ) - - info_1 = self.sdb[sid].copy() - assert "id_token" not in info_1 - assert "oidreq" in info_1 - assert info_1["sub"] != "sub" - - self.sdb.do_sub( - sid, "user2", "client_salt", "http://example.net/si.jwt", "pairwise" - ) - - info_2 = self.sdb[sid] - assert info_2["sub"] != "sub" - assert info_2["sub"] != info_1["sub"] - - def test_upgrade_to_token(self): - ae1 = create_authn_event("uid", "salt") - sid = self.sdb.create_authz_session(ae1, AREQ, client_id="client_id") - self.sdb[sid]["sub"] = "sub" - grant = self.sdb[sid]["code"] - _dict = self.sdb.upgrade_to_token(grant) - - print(_dict.keys()) - assert set(_dict.keys()) == { - "sid", - "authn_event", - "authn_req", - "access_token", - "token_type", - "client_id", - "oauth_state", - "expires_in", - "expires_at", - "code_is_used" - } - - # can't update again - # with pytest.raises(AccessCodeUsed): - print(self.sdb.upgrade_to_token(grant)) - - def test_upgrade_to_token_refresh(self): - ae1 = create_authn_event("sub", "salt") - sid = self.sdb.create_authz_session(ae1, AREQO, client_id="client_id") - self.sdb.do_sub(sid, "user", ae1["salt"]) - grant = self.sdb[sid]["code"] - # Issue an access token trading in the access grant code - _dict = self.sdb.upgrade_to_token(grant, issue_refresh=True) - - print(_dict.keys()) - assert set(_dict.keys()) == { - "sid", - "authn_event", - "authn_req", - "access_token", - "sub", - "token_type", - "client_id", - "oauth_state", - "refresh_token", - "expires_in", - "expires_at", - "code_is_used" - } - - # You can't refresh a token using the token itself - with pytest.raises(WrongTokenType): - self.sdb.refresh_token(_dict["access_token"]) - - def test_upgrade_to_token_with_id_token_and_oidreq(self): - ae2 = create_authn_event("another_user_id", "salt") - sid = self.sdb.create_authz_session(ae2, AREQ, client_id="client_id") - self.sdb[sid]["sub"] = "sub" - grant = self.sdb[sid]["code"] - - _dict = self.sdb.upgrade_to_token(grant, id_token="id_token", oidreq=OIDR) - print(_dict.keys()) - assert set(_dict.keys()) == { - "sid", - "authn_event", - "authn_req", - "oidreq", - "access_token", - "id_token", - "token_type", - "client_id", - "oauth_state", - "expires_in", - "expires_at", - "code_is_used" - } - - assert _dict["id_token"] == "id_token" - assert isinstance(_dict["oidreq"], OpenIDRequest) - - def test_refresh_token(self): - ae = create_authn_event("uid", "salt") - sid = self.sdb.create_authz_session(ae, AREQ, client_id="client_id") - self.sdb[sid]["sub"] = "sub" - grant = self.sdb[sid]["code"] - - dict1 = self.sdb.upgrade_to_token(grant, issue_refresh=True).copy() - rtoken = dict1["refresh_token"] - dict2 = self.sdb.refresh_token(rtoken, AREQ["client_id"]) - - assert dict1["access_token"] != dict2["access_token"] - - with pytest.raises(WrongTokenType): - self.sdb.refresh_token(dict2["access_token"], AREQ["client_id"]) - - def test_refresh_token_cleared_session(self): - ae = create_authn_event("uid", "salt") - sid = self.sdb.create_authz_session(ae, AREQ, client_id="client_id") - self.sdb[sid]["sub"] = "sub" - grant = self.sdb[sid]["code"] - dict1 = self.sdb.upgrade_to_token(grant, issue_refresh=True) - ac1 = dict1["access_token"] - - # Purge the SessionDB - self.sdb._db = {} - - rtoken = dict1["refresh_token"] - with pytest.raises(UnknownToken): - self.sdb.refresh_token(rtoken, AREQ["client_id"]) - - def test_is_valid(self): - ae1 = create_authn_event("uid", "salt") - sid = self.sdb.create_authz_session(ae1, AREQ, client_id="client_id") - self.sdb[sid]["sub"] = "sub" - grant = self.sdb[sid]["code"] - - assert self.sdb.is_valid("code", grant) - - sinfo = self.sdb.upgrade_to_token(grant, issue_refresh=True) - assert not self.sdb.is_valid("code", grant) - access_token = sinfo["access_token"] - assert self.sdb.is_valid("access_token", access_token) - - refresh_token = sinfo["refresh_token"] - sinfo = self.sdb.refresh_token(refresh_token, AREQ["client_id"]) - access_token2 = sinfo["access_token"] - assert self.sdb.is_valid("access_token", access_token2) - - # The old access code should be invalid - try: - self.sdb.is_valid("access_token", access_token) - except KeyError: - pass - - def test_valid_grant(self): - ae = create_authn_event("another:user", "salt") - sid = self.sdb.create_authz_session(ae, AREQ, client_id="client_id") - grant = self.sdb[sid]["code"] - - assert self.sdb.is_valid("code", grant) - - def test_revoke_token(self): - ae1 = create_authn_event("uid", "salt") - sid = self.sdb.create_authz_session(ae1, AREQ, client_id="client_id") - self.sdb[sid]["sub"] = "sub" - - grant = self.sdb[sid]["code"] - tokens = self.sdb.upgrade_to_token(grant, issue_refresh=True) - access_token = tokens["access_token"] - refresh_token = tokens["refresh_token"] - - assert self.sdb.is_valid("access_token", access_token) - - self.sdb.revoke_token(sid, "access_token") - assert not self.sdb.is_valid("access_token", access_token) - - sinfo = self.sdb.refresh_token(refresh_token, AREQ["client_id"]) - access_token = sinfo["access_token"] - assert self.sdb.is_valid("access_token", access_token) - - self.sdb.revoke_token(sid, "refresh_token") - assert not self.sdb.is_valid("refresh_token", refresh_token) - - ae2 = create_authn_event("sub", "salt") - sid = self.sdb.create_authz_session(ae2, AREQ, client_id="client_2") - - grant = self.sdb[sid]["code"] - self.sdb.revoke_token(sid, "code") - assert not self.sdb.is_valid("code", grant) - - def test_sub_to_authn_event(self): - ae = create_authn_event("sub", "salt", time_stamp=time.time()) - sid = self.sdb.create_authz_session(ae, AREQ, client_id="client_id") - sub = self.sdb.do_sub(sid, "user", "client_salt") - - # given the sub find out whether the authn event is still valid - sids = self.sdb.get_sids_by_sub(sub) - ae = self.sdb[sids[0]]["authn_event"] - assert ae.valid() - - def test_do_sub_deterministic(self): - ae = create_authn_event("tester", "random_value") - sid = self.sdb.create_authz_session(ae, AREQ, client_id="client_id") - self.sdb.do_sub(sid, "user", "other_random_value") - - info = self.sdb[sid] - assert ( - info["sub"] - == "d657bddf3d30970aa681663978ea84e26553ead03cb6fe8fcfa6523f2bcd0ad2" - ) - - self.sdb.do_sub( - sid, - "user", - "other_random_value", - sector_id="http://example.com", - subject_type="pairwise", - ) - info2 = self.sdb[sid] - assert ( - info2["sub"] - == "1442ceb13a822e802f85832ce93a8fda011e32a3363834dd1db3f9aa211065bd" - ) - - self.sdb.do_sub( - sid, - "user", - "another_random_value", - sector_id="http://other.example.com", - subject_type="pairwise", - ) - - info2 = self.sdb[sid] - assert ( - info2["sub"] - == "56e0a53d41086e7b22d78d52ee461655e9b090d50a0663d16136ea49a56c9bec" - ) - - def test_match_session(self): - ae1 = create_authn_event("uid", "salt") - sid = self.sdb.create_authz_session(ae1, AREQ, client_id="client_id") - self.sdb[sid]["sub"] = "sub" - self.sdb.sso_db.map_sid2uid(sid, "uid") - - res = self.sdb.match_session("uid", client_id="client_id") - assert res == sid - - def test_get_token(self): - ae1 = create_authn_event("uid", "salt") - sid = self.sdb.create_authz_session(ae1, AREQ, client_id="client_id") - self.sdb[sid]["sub"] = "sub" - self.sdb.sso_db.map_sid2uid(sid, "uid") - - grant = self.sdb.get_token(sid) - assert self.sdb.is_valid("code", grant) - assert self.sdb.handler.type(grant) == "A" - - -KEYDEFS = [ - {"type": "RSA", "key": "", "use": ["sig"]}, - {"type": "EC", "crv": "P-256", "use": ["sig"]}, -] - -conf = { - "issuer": "https://example.com/", - "password": "mycket hemligt", - "token_expires_in": 600, - "grant_expires_in": 300, - "refresh_token_expires_in": 86400, - "verify_ssl": False, - "capabilities": {}, - "keys": { - "uri_path": "static/jwks.json", - "key_defs": KEYDEFS, - "private_path": "own/jwks.json", - }, - "endpoint": { - "provider_config": { - "path": ".well-known/openid-configuration", - "class": ProviderConfiguration, - "kwargs": {}, - }, - "authorization_endpoint": { - "path": "authorization", - "class": Authorization, - "kwargs": {}, - }, - }, - "authentication": { - "anon": { - "acr": INTERNETPROTOCOLPASSWORD, - "class": "oidcendpoint.user_authn.user.NoAuthn", - "kwargs": {"user": "diana"}, - } - }, - "userinfo": {"class": UserInfo, "kwargs": {"db_file": full_path("users.json")}}, - "template_dir": "template", - "sso_db": SSO_DB_CONF, - "session_db": SESSION_DB_CONF, -} - - -def test_setup_session(): - endpoint_context = EndpointContext(conf) - uid = "_user_" - client_id = "EXTERNAL" - areq = None - acr = None - sid = setup_session(endpoint_context, areq, uid, client_id, acr, salt="salt") - assert sid - - -def test_setup_session_upgrade_to_token(): - endpoint_context = EndpointContext(conf) - uid = "_user_" - client_id = "EXTERNAL" - areq = None - acr = None - sid = setup_session(endpoint_context, areq, uid, client_id, acr, salt="salt") - assert sid - code = endpoint_context.sdb[sid]["code"] - assert code - - res = endpoint_context.sdb.upgrade_to_token(code) - assert "access_token" in res - - endpoint_context.sdb.revoke_uid("_user_") - assert endpoint_context.sdb.is_session_revoked(sid) - - -def make_sub_uid(uid, **kwargs): - return uid - - -def test_sub_minting_function(): - conf["sub_func"] = {"public": {"function": make_sub_uid}} - - endpoint_context = EndpointContext(conf) - uid = "_user_" - client_id = "EXTERNAL" - areq = None - acr = None - sid = setup_session(endpoint_context, areq, uid, client_id, acr, salt="salt") - assert endpoint_context.sdb[sid]["sub"] == uid - - -class SubMinter(object): - def __call__(self, *args, **kwargs): - return args[0] - - -def test_sub_minting_class(): - conf["sub_func"] = {"public": {"class": SubMinter}} - - endpoint_context = EndpointContext(conf) - uid = "_user_" - client_id = "EXTERNAL" - areq = None - acr = None - sid = setup_session(endpoint_context, areq, uid, client_id, acr, salt="salt") - assert endpoint_context.sdb[sid]["sub"] == uid diff --git a/tests/test_08_session_life.py b/tests/test_08_session_life.py new file mode 100644 index 0000000..ad63f27 --- /dev/null +++ b/tests/test_08_session_life.py @@ -0,0 +1,490 @@ +import os + +import pytest +from cryptojwt.key_jar import init_key_jar +from oidcmsg.oidc import AccessTokenRequest +from oidcmsg.oidc import AuthorizationRequest +from oidcmsg.oidc import RefreshAccessTokenRequest +from oidcmsg.time_util import time_sans_frac + +from oidcendpoint import user_info +from oidcendpoint.authn_event import create_authn_event +from oidcendpoint.client_authn import verify_client +from oidcendpoint.endpoint_context import EndpointContext +from oidcendpoint.id_token import IDToken +from oidcendpoint.oidc.authorization import Authorization +from oidcendpoint.oidc.provider_config import ProviderConfiguration +from oidcendpoint.oidc.registration import Registration +from oidcendpoint.oidc.session import Session +from oidcendpoint.oidc.token import Token +from oidcendpoint.session import session_key +from oidcendpoint.session import unpack_session_key +from oidcendpoint.session.grant import Grant +from oidcendpoint.session.info import ClientSessionInfo +from oidcendpoint.session.info import UserSessionInfo +from oidcendpoint.session.manager import SessionManager +from oidcendpoint.session.manager import public_id +from oidcendpoint.token import DefaultToken +from oidcendpoint.token.handler import TokenHandler +from oidcendpoint.token.handler import factory +from oidcendpoint.user_authn.authn_context import INTERNETPROTOCOLPASSWORD + + +class TestSession(): + @pytest.fixture(autouse=True) + def setup_token_handler(self): + password = "The longer the better. Is this close to enough ?" + conf = { + "issuer": "https://example.com/", + "password": "mycket hemligt", + "token_expires_in": 600, + "grant_expires_in": 300, + "refresh_token_expires_in": 86400, + "verify_ssl": False, + "keys": {"key_defs": KEYDEFS, "uri_path": "static/jwks.json"}, + "jwks_uri": "https://example.com/jwks.json", + "token_handler_args": { + "code": { + "kwargs": { + "lifetime": 600, + "password": password + }}, + "token": { + "kwargs": { + "lifetime": 900, + "password": password + } + }, + "refresh": { + "kwargs": { + "lifetime": 86400, + "password": password + } + } + }, + "endpoint": { + "authorization_endpoint": { + "path": "{}/authorization", + "class": Authorization, + "kwargs": {}, + }, + "token_endpoint": {"path": "{}/token", "class": Token, "kwargs": {}}, + }, + "template_dir": "template", + } + + self.endpoint_context = EndpointContext(conf) + token_handler = factory(self.endpoint_context, **conf["token_handler_args"]) + + self.session_manager = SessionManager(handler=token_handler) + + def auth(self): + # Start with an authentication request + # The client ID appears in the request + AUTH_REQ = AuthorizationRequest( + client_id="client_1", + redirect_uri="https://example.com/cb", + scope=["openid", "mail", "address", "offline_access"], + state="STATE", + response_type="code", + ) + + # The authentication returns a user ID + user_id = "diana" + + # User info is stored in the Session DB + authn_event = create_authn_event( + user_id, + authn_info=INTERNETPROTOCOLPASSWORD, + authn_time=time_sans_frac(), + ) + + user_info = UserSessionInfo(user_id=user_id) + self.session_manager.set([user_id], user_info) + + # Now for client session information + client_id = AUTH_REQ['client_id'] + client_info = ClientSessionInfo(client_id=client_id) + self.session_manager.set([user_id, client_id], client_info) + + # The user consent module produces a Grant instance + + grant = Grant(scope=AUTH_REQ['scope'], + resources=[client_id], + authorization_request=AUTH_REQ, + authentication_event=authn_event) + + # the grant is assigned to a session (user_id, client_id) + session_id = session_key(user_id, client_id, grant.id) + self.session_manager.set([user_id, client_id, grant.id], grant) + + # Constructing an authorization code is now done by + + code = grant.mint_token( + session_id=session_id, + endpoint_context=self.endpoint_context, + token_type='authorization_code', + token_handler= self.session_manager.token_handler["code"], + expires_at=time_sans_frac() + 300 # 5 minutes from now + ) + + return grant.id, code + + def test_code_flow(self): + # code is a Token instance + _grant_id, code = self.auth() + + # next step is access token request + + TOKEN_REQ = AccessTokenRequest( + client_id="client_1", + redirect_uri="https://example.com/cb", + state="STATE", + grant_type="authorization_code", + client_secret="hemligt", + code=code.value + ) + + # parse the token + session_id = self.session_manager.token_handler.sid(TOKEN_REQ['code']) + + # Now given I have the client_id from the request and the user_id from the + # token I can easily find the grant + + # client_info = self.session_manager.get([user_id, TOKEN_REQ['client_id']]) + tok = self.session_manager.find_token(session_id, TOKEN_REQ['code']) + + # Verify that it's of the correct type and can be used + assert tok.type == "authorization_code" + assert tok.is_active() + + # Mint an access token and a refresh token and mark the code as used + + assert tok.supports_minting("access_token") + + grant = self.session_manager[session_id] + + access_token = grant.mint_token( + session_id=session_id, + endpoint_context=self.endpoint_context, + token_type='access_token', + token_handler=self.session_manager.token_handler["access_token"], + expires_at=time_sans_frac() + 900, # 15 minutes from now + based_on=tok # Means the token (tok) was used to mint this token + ) + + assert tok.supports_minting("refresh_token") + + refresh_token = grant.mint_token( + session_id=session_id, + endpoint_context=self.endpoint_context, + token_type='refresh_token', + token_handler=self.session_manager.token_handler["refresh_token"], + based_on=tok + ) + + tok.register_usage() + + assert tok.max_usage_reached() is True + + # A bit later a refresh token is used to mint a new access token + + REFRESH_TOKEN_REQ = RefreshAccessTokenRequest( + grant_type="refresh_token", + client_id="client_1", + client_secret="hemligt", + refresh_token=refresh_token.value, + scope=["openid", "mail", "offline_access"] + ) + + reftok = self.session_manager.find_token(session_id, + REFRESH_TOKEN_REQ['refresh_token']) + + assert reftok.supports_minting("access_token") + + access_token_2 = grant.mint_token( + session_id=session_id, + endpoint_context=self.endpoint_context, + token_type='access_token', + token_handler=self.session_manager.token_handler["access_token"], + expires_at=time_sans_frac() + 900, # 15 minutes from now + based_on=reftok # Means the token (tok) was used to mint this token + ) + + assert access_token_2.is_active() + + +KEYDEFS = [ + {"type": "RSA", "key": "", "use": ["sig"]}, + {"type": "EC", "crv": "P-256", "use": ["sig"]}, +] + +ISSUER = "https://example.com/" + +KEYJAR = init_key_jar(key_defs=KEYDEFS, issuer_id=ISSUER) +KEYJAR.import_jwks(KEYJAR.export_jwks(True, ISSUER), "") +RESPONSE_TYPES_SUPPORTED = [ + ["code"], + ["token"], + ["id_token"], + ["code", "token"], + ["code", "id_token"], + ["id_token", "token"], + ["code", "token", "id_token"], + ["none"], +] +CAPABILITIES = { + "response_types_supported": [" ".join(x) for x in RESPONSE_TYPES_SUPPORTED], + "token_endpoint_auth_methods_supported": [ + "client_secret_post", + "client_secret_basic", + "client_secret_jwt", + "private_key_jwt", + ], + "response_modes_supported": ["query", "fragment", "form_post"], + "subject_types_supported": ["public", "pairwise", "ephemeral"], + "grant_types_supported": [ + "authorization_code", + "implicit", + "urn:ietf:params:oauth:grant-type:jwt-bearer", + ], + "claim_types_supported": ["normal", "aggregated", "distributed"], + "claims_parameter_supported": True, + "request_parameter_supported": True, + "request_uri_parameter_supported": True, +} +BASEDIR = os.path.abspath(os.path.dirname(__file__)) + + +def full_path(local_file): + return os.path.join(BASEDIR, local_file) + + +class TestSessionJWTToken(): + @pytest.fixture(autouse=True) + def setup_session_manager(self): + conf = { + "issuer": ISSUER, + "password": "mycket hemligt", + "token_expires_in": 600, + "grant_expires_in": 300, + "refresh_token_expires_in": 86400, + "verify_ssl": False, + "capabilities": CAPABILITIES, + "keys": {"uri_path": "jwks.json", "key_defs": KEYDEFS}, + "token_handler_args": { + "jwks_def": { + "private_path": "private/token_jwks.json", + "read_only": False, + "key_defs": [ + {"type": "oct", "bytes": "24", "use": ["enc"], "kid": "code"}, + {"type": "oct", "bytes": "24", "use": ["enc"], "kid": "refresh"} + ], + }, + "code": {"lifetime": 600}, + "token": { + "class": "oidcendpoint.token.jwt_token.JWTToken", + "kwargs": { + "lifetime": 3600, + "add_claims": [ + "email", + "email_verified", + "phone_number", + "phone_number_verified", + ], + "add_claim_by_scope": True, + "aud": ["https://example.org/appl"], + }, + }, + "refresh": {}, + }, + "endpoint": { + "provider_config": { + "path": "{}/.well-known/openid-configuration", + "class": ProviderConfiguration, + "kwargs": {}, + }, + "registration": { + "path": "{}/registration", + "class": Registration, + "kwargs": {}, + }, + "authorization": { + "path": "{}/authorization", + "class": Authorization, + "kwargs": {}, + }, + "token": {"path": "{}/token", "class": Token, "kwargs": {}}, + "session": {"path": "{}/end_session", "class": Session}, + }, + "client_authn": verify_client, + "authentication": { + "anon": { + "acr": INTERNETPROTOCOLPASSWORD, + "class": "oidcendpoint.user_authn.user.NoAuthn", + "kwargs": {"user": "diana"}, + } + }, + "template_dir": "template", + "userinfo": { + "class": user_info.UserInfo, + "kwargs": {"db_file": full_path("users.json")}, + }, + "id_token": {"class": IDToken}, + } + + self.endpoint_context = EndpointContext(conf, keyjar=KEYJAR) + self.session_manager = self.endpoint_context.session_manager + # self.session_manager = SessionManager(handler=self.endpoint_context.sdb.handler) + # self.endpoint_context.session_manager = self.session_manager + + def auth(self): + # Start with an authentication request + # The client ID appears in the request + AUTH_REQ = AuthorizationRequest( + client_id="client_1", + redirect_uri="https://example.com/cb", + scope=["openid", "mail", "address", "offline_access"], + state="STATE", + response_type="code", + ) + + # The authentication returns a user ID + user_id = "diana" + + # User info is stored in the Session DB + + user_info = UserSessionInfo(user_id=user_id) + self.session_manager.set([user_id], user_info) + + # Now for client session information + authn_event = create_authn_event( + user_id, + authn_info=INTERNETPROTOCOLPASSWORD, + authn_time=time_sans_frac(), + ) + + client_id = AUTH_REQ["client_id"] + client_info = ClientSessionInfo(client_id=client_id) + self.session_manager.set([user_id, client_id], client_info) + + # The user consent module produces a Grant instance + + grant = Grant( + scope=AUTH_REQ['scope'], + resources=[client_id], + authentication_event=authn_event, + authorization_request=AUTH_REQ + ) + + # the grant is assigned to a session (user_id, client_id) + session_id = session_key(user_id, client_id, grant.id) + self.session_manager.set([user_id, client_id, grant.id], grant) + + # Constructing an authorization code is now done by + + code = grant.mint_token( + session_id=session_id, + endpoint_context=self.endpoint_context, + token_type='authorization_code', + token_handler=self.session_manager.token_handler["code"], + expires_at=time_sans_frac() + 300 # 5 minutes from now + ) + return code + + def test_code_flow(self): + # code is a Token instance + code = self.auth() + + # next step is access token request + + TOKEN_REQ = AccessTokenRequest( + client_id="client_1", + redirect_uri="https://example.com/cb", + state="STATE", + grant_type="authorization_code", + client_secret="hemligt", + code=code.value + ) + + # parse the token + session_id = self.session_manager.token_handler.sid(TOKEN_REQ['code']) + user_id, client_id, grant_id = unpack_session_key(session_id) + + # Now given I have the client_id from the request and the user_id from the + # token I can easily find the grant + + # client_info = self.session_manager.get([user_id, TOKEN_REQ['client_id']]) + tok = self.session_manager.find_token(session_id, TOKEN_REQ['code']) + + # Verify that it's of the correct type and can be used + assert tok.type == "authorization_code" + assert tok.is_active() + + # Mint an access token and a refresh token and mark the code as used + + assert tok.supports_minting("access_token") + + client_info = self.session_manager.get([user_id, TOKEN_REQ["client_id"]]) + + assert tok.supports_minting("access_token") + + grant = self.session_manager[session_id] + + access_token = grant.mint_token( + session_id=session_id, + endpoint_context=self.endpoint_context, + token_type='access_token', + token_handler=self.session_manager.token_handler["access_token"], + expires_at=time_sans_frac() + 900, # 15 minutes from now + based_on=tok # Means the token (tok) was used to mint this token + ) + + # this test is include in the mint_token methods + # assert tok.supports_minting("refresh_token") + + refresh_token = grant.mint_token( + session_id=session_id, + endpoint_context=self.endpoint_context, + token_type='refresh_token', + token_handler=self.session_manager.token_handler["refresh_token"], + based_on=tok + ) + + tok.register_usage() + + assert tok.max_usage_reached() is True + + # A bit later a refresh token is used to mint a new access token + + REFRESH_TOKEN_REQ = RefreshAccessTokenRequest( + grant_type="refresh_token", + client_id="client_1", + client_secret="hemligt", + refresh_token=refresh_token.value, + scope=["openid", "mail", "offline_access"] + ) + + session_id = session_key(user_id, REFRESH_TOKEN_REQ['client_id'], grant_id) + reftok = self.session_manager.find_token(session_id, + REFRESH_TOKEN_REQ['refresh_token']) + + # Can I use this token to mint another token ? + assert grant.is_active() + + user_claims = self.endpoint_context.userinfo(user_id, client_id=TOKEN_REQ["client_id"], + user_info_claims=grant.claims) + + access_token_2 = grant.mint_token( + session_id=session_id, + endpoint_context=self.endpoint_context, + token_type='access_token', + token_handler=self.session_manager.token_handler["access_token"], + expires_at=time_sans_frac() + 900, # 15 minutes from now + based_on=reftok # Means the refresh token (reftok) was used to mint this token + ) + + assert access_token_2.is_active() + + token_info = self.session_manager.token_handler.info(access_token_2.value) + assert token_info diff --git a/tests/test_09_cookie_dealer.py b/tests/test_09_cookie_dealer.py index b1db1ae..9c5e41c 100644 --- a/tests/test_09_cookie_dealer.py +++ b/tests/test_09_cookie_dealer.py @@ -1,9 +1,12 @@ +import json +import time from http.cookies import SimpleCookie import pytest from cryptojwt.jwk.hmac import SYMKey from cryptojwt.key_jar import init_key_jar +from oidcendpoint import rndstr from oidcendpoint.cookie import CookieDealer from oidcendpoint.cookie import append_cookie from oidcendpoint.cookie import compute_session_state @@ -11,8 +14,10 @@ from oidcendpoint.cookie import create_session_cookie from oidcendpoint.cookie import make_cookie from oidcendpoint.cookie import new_cookie +from oidcendpoint.cookie import sign_enc_payload +from oidcendpoint.cookie import ver_dec_content from oidcendpoint.endpoint_context import EndpointContext -from oidcendpoint.oidc.token import AccessToken +from oidcendpoint.oidc.token import Token KEYDEFS = [ {"type": "RSA", "key": "", "use": ["sig"]}, @@ -228,7 +233,7 @@ def test_append_cookie(): "grant_expires_in": 300, "refresh_token_expires_in": 86400, "verify_ssl": False, - "endpoint": {"token": {"path": "token", "class": AccessToken, "kwargs": {}}}, + "endpoint": {"token": {"path": "token", "class": Token, "kwargs": {}}}, "template_dir": "template", "keys": { "private_path": "own/jwks.json", @@ -256,6 +261,8 @@ def test_append_cookie(): endpoint_context.cdb[client_id] = {"client_secret": client_secret} endpoint_context.cookie_dealer = CookieDealer(**cookie_conf) +enc_key = rndstr(32) + def test_new_cookie(): kaka = new_cookie( @@ -301,3 +308,10 @@ def test_cookie_same_site_none(): assert kaka["test"]["secure"] is True assert kaka["test"]["httponly"] is True assert kaka["test"]["samesite"] is "None" + + +def test_cookie_enc(): + _key = SYMKey(k=enc_key) + _enc_data = sign_enc_payload(json.dumps({"test": "data"}), timestamp=time.time(), enc_key=_key) + _data, _timestamp = ver_dec_content(_enc_data.split('|'), enc_key=_key) + assert json.loads(_data) == {"test": "data"} diff --git a/tests/test_10_oidc_authz.py b/tests/test_10_oidc_authz.py deleted file mode 100644 index 535bb81..0000000 --- a/tests/test_10_oidc_authz.py +++ /dev/null @@ -1,206 +0,0 @@ -import json -import os - -import pytest -from cryptojwt.key_jar import build_keyjar - -from oidcendpoint.authz import AuthzHandling -from oidcendpoint.authz import Implicit -from oidcendpoint.authz import factory -from oidcendpoint.cookie import CookieDealer -from oidcendpoint.cookie import new_cookie -from oidcendpoint.endpoint_context import EndpointContext -from oidcendpoint.oidc import userinfo -from oidcendpoint.oidc.authorization import Authorization -from oidcendpoint.oidc.provider_config import ProviderConfiguration -from oidcendpoint.oidc.registration import Registration -from oidcendpoint.oidc.session import Session -from oidcendpoint.oidc.token import AccessToken -from oidcendpoint.user_authn.authn_context import INTERNETPROTOCOLPASSWORD -from oidcendpoint.user_info import UserInfo - -ISS = "https://example.com/" - -KEYDEFS = [ - {"type": "RSA", "use": ["sig"]}, - {"type": "EC", "crv": "P-256", "use": ["sig"]}, -] - -KEYJAR = build_keyjar(KEYDEFS) -KEYJAR.import_jwks(KEYJAR.export_jwks(private=True), ISS) - -RESPONSE_TYPES_SUPPORTED = [ - ["code"], - ["token"], - ["id_token"], - ["code", "token"], - ["code", "id_token"], - ["id_token", "token"], - ["code", "id_token", "token"], - ["none"], -] - -CAPABILITIES = { - "response_types_supported": [" ".join(x) for x in RESPONSE_TYPES_SUPPORTED], - "token_endpoint_auth_methods_supported": [ - "client_secret_post", - "client_secret_basic", - "client_secret_jwt", - "private_key_jwt", - ], - "response_modes_supported": ["query", "fragment", "form_post"], - "subject_types_supported": ["public", "pairwise"], - "grant_types_supported": [ - "authorization_code", - "implicit", - "urn:ietf:params:oauth:grant-type:jwt-bearer", - "refresh_token", - ], - "claim_types_supported": ["normal", "aggregated", "distributed"], - "claims_parameter_supported": True, - "request_parameter_supported": True, - "request_uri_parameter_supported": True, -} - -BASEDIR = os.path.abspath(os.path.dirname(__file__)) - - -def full_path(local_file): - return os.path.join(BASEDIR, local_file) - - -USERINFO_db = json.loads(open(full_path("users.json")).read()) - - -class TestAuthz(object): - @pytest.fixture(autouse=True) - def create_ec(self): - conf = { - "issuer": ISS, - "password": "mycket hemlig zebra", - "token_expires_in": 600, - "grant_expires_in": 300, - "refresh_token_expires_in": 86400, - "verify_ssl": False, - "capabilities": CAPABILITIES, - "keys": {"uri_path": "jwks.json", "key_defs": KEYDEFS}, - "endpoint": { - "provider_config": { - "path": "{}/.well-known/openid-configuration", - "class": ProviderConfiguration, - "kwargs": {"client_authn_method": None}, - }, - "registration": { - "path": "{}/registration", - "class": Registration, - "kwargs": {"client_authn_method": None}, - }, - "authorization": { - "path": "{}/authorization", - "class": Authorization, - "kwargs": {"client_authn_method": None}, - }, - "token": {"path": "{}/token", "class": AccessToken, "kwargs": {}}, - "userinfo": { - "path": "{}/userinfo", - "class": userinfo.UserInfo, - "kwargs": {"db_file": "users.json"}, - }, - "session": { - "path": "{}/end_session", - "class": Session, - "kwargs": { - "signing_alg": "ES256", - "logout_verify_url": "{}/verify_logout".format(ISS), - }, - }, - }, - "authentication": { - "anon": { - "acr": INTERNETPROTOCOLPASSWORD, - "class": "oidcendpoint.user_authn.user.NoAuthn", - "kwargs": {"user": "diana"}, - } - }, - "userinfo": {"class": UserInfo, "kwargs": {"db": USERINFO_db}}, - "template_dir": "template", - "authz": {"class": AuthzHandling, "kwargs": {}}, - "cookie_dealer": { - "class": CookieDealer, - "kwargs": { - "sign_key": "ghsNKDDLshZTPn974nOsIGhedULrsqnsGoBFBLwUKuJhE2ch", - "default_values": { - "name": "oidcop", - "domain": "127.0.0.1", - "path": "/", - "max_age": 3600, - }, - }, - }, - } - - self.endpoint_context = EndpointContext(conf, keyjar=KEYJAR) - - def _create_cookie(self, user, sid, state, client_id, typ="sso", name=""): - ec = self.endpoint_context - if not name: - name = ec.cookie_name["session"] - return new_cookie( - ec, - sub=user, - sid=sid, - state=state, - client_id=client_id, - typ=typ, - cookie_name=name, - ) - - def test_init_authz(self): - authz = AuthzHandling(self.endpoint_context) - assert authz - - def test_authz_set_get(self): - authz = self.endpoint_context.authz - authz.set("diana", "client_1", ["email", "phone"]) - assert authz.get("diana", "client_1") == ["email", "phone"] - - def test_authz_cookie(self): - authz = self.endpoint_context.authz - authz.set("diana", "client_1", ["email", "phone"]) - cookie = self._create_cookie("diana", "_sid_", "1234567", "client_1") - perm = authz.permissions(cookie) - assert set(perm) == {"email", "phone"} - - def test_authz_cookie_wrong_client(self): - authz = self.endpoint_context.authz - authz.set("diana", "client_1", ["email", "phone"]) - cookie = self._create_cookie("diana", "_sid_", "1234567", "client_2") - perm = authz.permissions(cookie) - assert perm is None - - def tests_implicit(self): - authz = Implicit(self.endpoint_context, "any") - perm = authz.get("foo", "bar") - assert perm == "any" - - def test_factory_implicit(self): - authz = factory("Implicit", self.endpoint_context, permission="all") - assert authz.get("foo", "bar") == "all" - - def test_factory_authz_handling(self): - authz = factory("AuthzHandling", self.endpoint_context) - authz.set("diana", "client_1", ["email", "phone"]) - assert authz.get("diana", "client_1") == ["email", "phone"] - - def test_authz_cookie_none(self): - authz = self.endpoint_context.authz - authz.set("diana", "client_1", ["email", "phone"]) - assert authz.permissions(None) is None - - def test_authz_cookie_other(self): - authz = self.endpoint_context.authz - authz.set("diana", "client_1", ["email", "phone"]) - cookie = self._create_cookie( - "diana", "_sid_", "1234567", "client_1", name="foo" - ) - assert authz.permissions(cookie) is None diff --git a/tests/test_12_user_authn.py b/tests/test_12_user_authn.py index dafbc5d..03845c9 100644 --- a/tests/test_12_user_authn.py +++ b/tests/test_12_user_authn.py @@ -29,9 +29,7 @@ def create_endpoint_context(self): conf = { "issuer": "https://example.com/", "password": "mycket hemligt", - "token_expires_in": 600, "grant_expires_in": 300, - "refresh_token_expires_in": 86400, "verify_ssl": False, "endpoint": {}, "keys": {"uri_path": "static/jwks.json", "key_defs": KEYDEFS}, diff --git a/tests/test_22_oidc_provider_config_endpoint.py b/tests/test_22_oidc_provider_config_endpoint.py index 47995fb..5c46441 100755 --- a/tests/test_22_oidc_provider_config_endpoint.py +++ b/tests/test_22_oidc_provider_config_endpoint.py @@ -4,7 +4,7 @@ from oidcendpoint.endpoint_context import EndpointContext from oidcendpoint.oidc.provider_config import ProviderConfiguration -from oidcendpoint.oidc.token import AccessToken +from oidcendpoint.oidc.token import Token KEYDEFS = [ {"type": "RSA", "key": "", "use": ["sig"]}, @@ -31,7 +31,7 @@ "private_key_jwt", ], "response_modes_supported": ["query", "fragment", "form_post"], - "subject_types_supported": ["public", "pairwise"], + "subject_types_supported": ["public", "pairwise""ephemeral"], "grant_types_supported": [ "authorization_code", "implicit", @@ -63,7 +63,7 @@ def create_endpoint(self): "class": ProviderConfiguration, "kwargs": {}, }, - "token": {"path": "token", "class": AccessToken, "kwargs": {}}, + "token": {"path": "token", "class": Token, "kwargs": {}}, }, "template_dir": "template", } diff --git a/tests/test_23_oidc_registration_endpoint.py b/tests/test_23_oidc_registration_endpoint.py index 0d59e45..71d03ea 100755 --- a/tests/test_23_oidc_registration_endpoint.py +++ b/tests/test_23_oidc_registration_endpoint.py @@ -1,9 +1,9 @@ # -*- coding: latin-1 -*- import json -from cryptojwt.key_jar import init_key_jar import pytest import responses +from cryptojwt.key_jar import init_key_jar from oidcmsg.oidc import RegistrationRequest from oidcmsg.oidc import RegistrationResponse @@ -12,7 +12,7 @@ from oidcendpoint.oidc.authorization import Authorization from oidcendpoint.oidc.registration import Registration from oidcendpoint.oidc.registration import match_sp_sep -from oidcendpoint.oidc.token import AccessToken +from oidcendpoint.oidc.token import Token from oidcendpoint.oidc.userinfo import UserInfo KEYDEFS = [ @@ -72,7 +72,7 @@ def create_endpoint(self): "refresh_token_expires_in": 86400, "verify_ssl": False, "capabilities": { - "subject_types_supported": ["public", "pairwise"], + "subject_types_supported": ["public", "pairwise", "ephemeral"], "grant_types_supported": [ "authorization_code", "implicit", @@ -108,7 +108,7 @@ def create_endpoint(self): }, "token": { "path": "token", - "class": AccessToken, + "class": Token, "kwargs": { "client_authn_method": [ "client_secret_post", diff --git a/tests/test_24_oauth2_authorization_endpoint.py b/tests/test_24_oauth2_authorization_endpoint.py index 599269d..f2aeb81 100755 --- a/tests/test_24_oauth2_authorization_endpoint.py +++ b/tests/test_24_oauth2_authorization_endpoint.py @@ -18,23 +18,23 @@ from oidcmsg.oauth2 import AuthorizationResponse from oidcmsg.time_util import in_a_while -from oidcendpoint.common.authorization import FORM_POST -from oidcendpoint.common.authorization import get_uri -from oidcendpoint.common.authorization import inputs -from oidcendpoint.common.authorization import join_query -from oidcendpoint.common.authorization import verify_uri +from oidcendpoint.authn_event import create_authn_event +from oidcendpoint.authz import AuthzHandling from oidcendpoint.cookie import CookieDealer from oidcendpoint.endpoint_context import EndpointContext from oidcendpoint.exception import InvalidRequest from oidcendpoint.exception import NoSuchAuthentication from oidcendpoint.exception import RedirectURIError from oidcendpoint.exception import ToOld -from oidcendpoint.exception import UnknownClient from oidcendpoint.exception import UnAuthorizedClientScope -from oidcendpoint.exception import UnAuthorizedClient +from oidcendpoint.exception import UnknownClient from oidcendpoint.id_token import IDToken +from oidcendpoint.oauth2.authorization import FORM_POST from oidcendpoint.oauth2.authorization import Authorization -from oidcendpoint.session import SessionInfo +from oidcendpoint.oauth2.authorization import get_uri +from oidcendpoint.oauth2.authorization import inputs +from oidcendpoint.oauth2.authorization import join_query +from oidcendpoint.oauth2.authorization import verify_uri from oidcendpoint.user_info import UserInfo KEYDEFS = [ @@ -140,9 +140,6 @@ def create_endpoint(self): conf = { "issuer": "https://example.com/", "password": "mycket hemligt zebra", - "token_expires_in": 600, - "grant_expires_in": 300, - "refresh_token_expires_in": 86400, "verify_ssl": False, "capabilities": CAPABILITIES, "keys": {"uri_path": "static/jwks.json", "key_defs": KEYDEFS}, @@ -191,6 +188,24 @@ def create_endpoint(self): }, }, }, + "authz": { + "class": AuthzHandling, + "kwargs": { + "grant_config": { + "usage_rules": { + "authorization_code": { + 'supports_minting': ["access_token", "refresh_token", "id_token"], + "max_usage": 1 + }, + "access_token": {}, + "refresh_token": { + 'supports_minting': ["access_token", "refresh_token", "id_token"], + } + }, + "expires_in": 43200 + } + } + } } endpoint_context = EndpointContext(conf) _clients = yaml.safe_load(io.StringIO(client_yaml)) @@ -199,6 +214,8 @@ def create_endpoint(self): endpoint_context.keyjar.export_jwks(True, ""), conf["issuer"] ) self.endpoint = endpoint_context.endpoint["authorization"] + self.session_manager = endpoint_context.session_manager + self.user_id = "diana" self.rp_keyjar = KeyJar() self.rp_keyjar.add_symmetric("client_1", "hemligtkodord1234567890") @@ -206,12 +223,24 @@ def create_endpoint(self): "client_1", "hemligtkodord1234567890" ) + def _create_session(self, auth_req, sub_type="public", sector_identifier=''): + if sector_identifier: + areq = auth_req.copy() + areq["sector_identifier_uri"] = sector_identifier + else: + areq = auth_req + + client_id = areq['client_id'] + ae = create_authn_event(self.user_id) + return self.session_manager.create_session(ae, areq, self.user_id, + client_id=client_id, + sub_type=sub_type) + def test_init(self): assert self.endpoint def test_parse(self): _req = self.endpoint.parse_request(AUTH_REQ_DICT) - assert isinstance(_req, AuthorizationRequest) assert set(_req.keys()) == set(AUTH_REQ.keys()) @@ -223,6 +252,7 @@ def test_process_request(self): "fragment_enc", "return_uri", "cookie", + "session_id" } def test_do_response_code(self): @@ -393,24 +423,15 @@ def test_create_authn_response(self): scope="openid", ) - _ec = self.endpoint.endpoint_context - _ec.sdb["session_id"] = SessionInfo( - authn_req=request, - uid="diana", - sub="abcdefghijkl", - authn_event={ - "authn_info": "loa1", - "uid": "diana", - "authn_time": utc_time_sans_frac(), - }, - ) - _ec.cdb["client_id"] = { + self.endpoint.endpoint_context.cdb["client_id"] = { "client_id": "client_id", "redirect_uris": [("https://rp.example.com/cb", {})], "id_token_signed_response_alg": "ES256", } - resp = self.endpoint.create_authn_response(request, "session_id") + session_id = self._create_session(request) + + resp = self.endpoint.create_authn_response(request, session_id) assert isinstance(resp["response_args"], AuthorizationErrorResponse) def test_setup_auth(self): @@ -434,7 +455,7 @@ def test_setup_auth(self): ) res = self.endpoint.setup_auth(request, redirect_uri, cinfo, kaka) - assert set(res.keys()) == {"authn_event", "identity", "user"} + assert set(res.keys()) == {"session_id", "identity", "user"} def test_setup_auth_error(self): request = AuthorizationRequest( @@ -512,25 +533,16 @@ def test_setup_auth_user(self): "redirect_uris": [("https://rp.example.com/cb", {})], "id_token_signed_response_alg": "RS256", } - _ec = self.endpoint.endpoint_context - _ec.sdb["session_id"] = SessionInfo( - authn_req=request, - uid="diana", - sub="abcdefghijkl", - authn_event={ - "authn_info": "loa1", - "uid": "diana", - "authn_time": utc_time_sans_frac(), - }, - ) - item = _ec.authn_broker.db["anon"] + session_id = self._create_session(request) + + item = self.endpoint.endpoint_context.authn_broker.db["anon"] item["method"].user = b64e( - as_bytes(json.dumps({"uid": "krall", "sid": "session_id"})) + as_bytes(json.dumps({"uid": "krall", "sid": session_id})) ) res = self.endpoint.setup_auth(request, redirect_uri, cinfo, None) - assert set(res.keys()) == {"authn_event", "identity", "user"} + assert set(res.keys()) == {"session_id", "identity", "user"} assert res["identity"]["uid"] == "krall" def test_setup_auth_session_revoked(self): @@ -548,22 +560,17 @@ def test_setup_auth_session_revoked(self): "redirect_uris": [("https://rp.example.com/cb", {})], "id_token_signed_response_alg": "RS256", } - _ec = self.endpoint.endpoint_context - _ec.sdb["session_id"] = SessionInfo( - authn_req=request, - uid="diana", - sub="abcdefghijkl", - authn_event={ - "authn_info": "loa1", - "uid": "diana", - "authn_time": utc_time_sans_frac(), - }, - revoked=True, - ) + session_id = self._create_session(request) + + _mngr = self.endpoint.endpoint_context.session_manager + _csi = _mngr[session_id] + _csi.revoked = True + + _ec = self.endpoint.endpoint_context item = _ec.authn_broker.db["anon"] item["method"].user = b64e( - as_bytes(json.dumps({"uid": "krall", "sid": "session_id"})) + as_bytes(json.dumps({"uid": "krall", "sid": session_id})) ) res = self.endpoint.setup_auth(request, redirect_uri, cinfo, None) @@ -598,6 +605,30 @@ def test_response_mode_fragment(self): info = self.endpoint.response_mode(request) assert set(info.keys()) == {"fragment_enc"} + # def test_sso(self): + # _pr_resp = self.endpoint.parse_request(AUTH_REQ_DICT) + # _resp = self.endpoint.process_request(_pr_resp) + # msg = self.endpoint.do_response(**_resp) + # + # request = AuthorizationRequest( + # client_id="client_2", + # redirect_uri="https://rp.example.org/cb", + # response_type=["code"], + # state="state", + # scope="openid", + # ) + # + # cinfo = { + # "client_id": "client_2", + # "redirect_uris": [(request["redirect_uri"], {})] + # } + # + # _pr_resp = self.endpoint.parse_request(AUTH_REQ_DICT, cookie="kaka") + # _resp = self.endpoint.process_request(_pr_resp) + # msg = self.endpoint.do_response(**_resp) + # + # assert set(res.keys()) == {"authn_event", "identity", "user"} + def test_inputs(): elems = inputs(dict(foo="bar", home="stead")) diff --git a/tests/test_24_oauth2_authorization_endpoint_jar.py b/tests/test_24_oauth2_authorization_endpoint_jar.py index efc1854..24b7644 100755 --- a/tests/test_24_oauth2_authorization_endpoint_jar.py +++ b/tests/test_24_oauth2_authorization_endpoint_jar.py @@ -182,6 +182,8 @@ def create_endpoint(self): endpoint_context.keyjar.export_jwks(True, ""), conf["issuer"] ) self.endpoint = endpoint_context.endpoint["authorization"] + self.session_manager = endpoint_context.session_manager + self.user_id = "diana" self.rp_keyjar = KeyJar() self.rp_keyjar.add_symmetric("client_1", "hemligtkodord1234567890") diff --git a/tests/test_24_oidc_authorization_endpoint.py b/tests/test_24_oidc_authorization_endpoint.py index 9de55a9..03acfa6 100755 --- a/tests/test_24_oidc_authorization_endpoint.py +++ b/tests/test_24_oidc_authorization_endpoint.py @@ -9,7 +9,6 @@ import yaml from cryptojwt import JWT from cryptojwt import KeyJar -from cryptojwt.jwt import utc_time_sans_frac from cryptojwt.utils import as_bytes from cryptojwt.utils import b64e from oidcmsg.exception import ParameterError @@ -21,9 +20,8 @@ from oidcmsg.oidc import verified_claim_name from oidcmsg.oidc import verify_id_token -from oidcendpoint.common.authorization import FORM_POST -from oidcendpoint.common.authorization import join_query -from oidcendpoint.common.authorization import verify_uri +from oidcendpoint.authn_event import create_authn_event +from oidcendpoint.authz import AuthzHandling from oidcendpoint.cookie import CookieDealer from oidcendpoint.cookie import cookie_value from oidcendpoint.cookie import new_cookie @@ -33,19 +31,20 @@ from oidcendpoint.exception import RedirectURIError from oidcendpoint.exception import ToOld from oidcendpoint.exception import UnknownClient -from oidcendpoint.exception import UnAuthorizedClient from oidcendpoint.id_token import IDToken from oidcendpoint.login_hint import LoginHint2Acrs +from oidcendpoint.oauth2.authorization import FORM_POST +from oidcendpoint.oauth2.authorization import get_uri +from oidcendpoint.oauth2.authorization import inputs +from oidcendpoint.oauth2.authorization import join_query +from oidcendpoint.oauth2.authorization import verify_uri from oidcendpoint.oidc import userinfo from oidcendpoint.oidc.authorization import Authorization from oidcendpoint.oidc.authorization import acr_claims -from oidcendpoint.oidc.authorization import get_uri -from oidcendpoint.oidc.authorization import inputs from oidcendpoint.oidc.authorization import re_authenticate from oidcendpoint.oidc.provider_config import ProviderConfiguration from oidcendpoint.oidc.registration import Registration -from oidcendpoint.oidc.token import AccessToken -from oidcendpoint.session import SessionInfo +from oidcendpoint.oidc.token import Token from oidcendpoint.user_authn.authn_context import INTERNETPROTOCOLPASSWORD from oidcendpoint.user_authn.authn_context import UNSPECIFIED from oidcendpoint.user_authn.authn_context import init_method @@ -72,7 +71,7 @@ ] CAPABILITIES = { - "subject_types_supported": ["public", "pairwise"], + "subject_types_supported": ["public", "pairwise", "ephemeral"], "grant_types_supported": [ "authorization_code", "implicit", @@ -110,7 +109,6 @@ def full_path(local_file): USERINFO_db = json.loads(open(full_path("users.json")).read()) - client_yaml = """ oidc_clients: client_1: @@ -149,18 +147,17 @@ def create_endpoint(self): conf = { "issuer": "https://example.com/", "password": "mycket hemligt zebra", - "token_expires_in": 600, - "grant_expires_in": 300, - "refresh_token_expires_in": 86400, "verify_ssl": False, "capabilities": CAPABILITIES, "keys": {"uri_path": "static/jwks.json", "key_defs": KEYDEFS}, "id_token": { "class": IDToken, "kwargs": { - "available_claims": { + "base_claims": { "email": {"essential": True}, "email_verified": {"essential": True}, + "given_name": {"essential": True}, + "nickname": None } }, }, @@ -190,7 +187,7 @@ def create_endpoint(self): }, "token": { "path": "token", - "class": AccessToken, + "class": Token, "kwargs": { "client_authn_method": [ "client_secret_post", @@ -222,6 +219,24 @@ def create_endpoint(self): }, "userinfo": {"class": UserInfo, "kwargs": {"db": USERINFO_db}}, "template_dir": "template", + "authz": { + "class": AuthzHandling, + "kwargs": { + "grant_config": { + "usage_rules": { + "authorization_code": { + 'supports_minting': ["access_token", "refresh_token", "id_token"], + "max_usage": 1 + }, + "access_token": {}, + "refresh_token": { + 'supports_minting': ["access_token", "refresh_token"], + } + }, + "expires_in": 43200 + } + } + }, "cookie_dealer": { "class": CookieDealer, "kwargs": { @@ -240,12 +255,15 @@ def create_endpoint(self): }, } endpoint_context = EndpointContext(conf) + _clients = yaml.safe_load(io.StringIO(client_yaml)) endpoint_context.cdb = _clients["oidc_clients"] endpoint_context.keyjar.import_jwks( endpoint_context.keyjar.export_jwks(True, ""), conf["issuer"] ) self.endpoint = endpoint_context.endpoint["authorization"] + self.session_manager = endpoint_context.session_manager + self.user_id = "diana" self.rp_keyjar = KeyJar() self.rp_keyjar.add_symmetric("client_1", "hemligtkodord1234567890") @@ -256,6 +274,18 @@ def create_endpoint(self): def test_init(self): assert self.endpoint + def _create_session(self, auth_req, sub_type="public", sector_identifier=''): + if sector_identifier: + authz_req = auth_req.copy() + authz_req["sector_identifier_uri"] = sector_identifier + else: + authz_req = auth_req + client_id = authz_req['client_id'] + ae = create_authn_event(self.user_id) + return self.session_manager.create_session(ae, authz_req, self.user_id, + client_id=client_id, + sub_type=sub_type) + def test_parse(self): _req = self.endpoint.parse_request(AUTH_REQ_DICT) @@ -270,6 +300,7 @@ def test_process_request(self): "fragment_enc", "return_uri", "cookie", + "session_id" } def test_do_response_code(self): @@ -366,13 +397,12 @@ def test_id_token_claims(self): _req["nonce"] = "rnd_nonce" _pr_resp = self.endpoint.parse_request(_req) _resp = self.endpoint.process_request(_pr_resp) - idt = verify_id_token( - _resp["response_args"], keyjar=self.endpoint.endpoint_context.keyjar - ) + idt = verify_id_token(_resp["response_args"], keyjar=self.endpoint.endpoint_context.keyjar) assert idt - # from claims - assert "given_name" in _resp["response_args"]["__verified_id_token"] # from config + assert "given_name" in _resp["response_args"]["__verified_id_token"] + assert "nickname" in _resp["response_args"]["__verified_id_token"] + # Could have gotten email but didn't ask for it assert "email" in _resp["response_args"]["__verified_id_token"] def test_re_authenticate(self): @@ -541,23 +571,15 @@ def test_create_authn_response(self): ) _ec = self.endpoint.endpoint_context - _ec.sdb["session_id"] = SessionInfo( - authn_req=request, - uid="diana", - sub="abcdefghijkl", - authn_event={ - "authn_info": "loa1", - "uid": "diana", - "authn_time": utc_time_sans_frac(), - }, - ) _ec.cdb["client_id"] = { "client_id": "client_id", "redirect_uris": [("https://rp.example.com/cb", {})], "id_token_signed_response_alg": "ES256", } - resp = self.endpoint.create_authn_response(request, "session_id") + session_id = self._create_session(request) + + resp = self.endpoint.create_authn_response(request, session_id) assert isinstance(resp["response_args"], AuthorizationErrorResponse) def test_setup_auth(self): @@ -581,7 +603,7 @@ def test_setup_auth(self): ) res = self.endpoint.setup_auth(request, redirect_uri, cinfo, kaka) - assert set(res.keys()) == {"authn_event", "identity", "user"} + assert set(res.keys()) == {"session_id", "identity", "user"} def test_setup_auth_error(self): request = AuthorizationRequest( @@ -628,24 +650,16 @@ def test_setup_auth_user(self): "id_token_signed_response_alg": "RS256", } _ec = self.endpoint.endpoint_context - _ec.sdb["session_id"] = SessionInfo( - authn_req=request, - uid="diana", - sub="abcdefghijkl", - authn_event={ - "authn_info": "loa1", - "uid": "diana", - "authn_time": utc_time_sans_frac(), - }, - ) + + session_id = self._create_session(request) item = _ec.authn_broker.db["anon"] item["method"].user = b64e( - as_bytes(json.dumps({"uid": "krall", "sid": "session_id"})) + as_bytes(json.dumps({"uid": "krall", "sid": session_id})) ) res = self.endpoint.setup_auth(request, redirect_uri, cinfo, None) - assert set(res.keys()) == {"authn_event", "identity", "user"} + assert set(res.keys()) == {"session_id", "identity", "user"} assert res["identity"]["uid"] == "krall" def test_setup_auth_session_revoked(self): @@ -664,23 +678,17 @@ def test_setup_auth_session_revoked(self): "id_token_signed_response_alg": "RS256", } _ec = self.endpoint.endpoint_context - _ec.sdb["session_id"] = SessionInfo( - authn_req=request, - uid="diana", - sub="abcdefghijkl", - authn_event={ - "authn_info": "loa1", - "uid": "diana", - "authn_time": utc_time_sans_frac(), - }, - revoked=True, - ) + + session_id = self._create_session(request) item = _ec.authn_broker.db["anon"] item["method"].user = b64e( - as_bytes(json.dumps({"uid": "krall", "sid": "session_id"})) + as_bytes(json.dumps({"uid": "krall", "sid": session_id})) ) + grant = _ec.session_manager[session_id] + grant.revoked = True + res = self.endpoint.setup_auth(request, redirect_uri, cinfo, None) assert set(res.keys()) == {"args", "function"} diff --git a/tests/test_25_oidc_token_endpoint.py b/tests/test_25_oidc_token_endpoint.py deleted file mode 100755 index 43be33f..0000000 --- a/tests/test_25_oidc_token_endpoint.py +++ /dev/null @@ -1,266 +0,0 @@ -import json -import os - -import pytest -from cryptojwt import JWT -from cryptojwt.key_jar import build_keyjar -from oidcmsg.oidc import AccessTokenRequest -from oidcmsg.oidc import AuthorizationRequest -from oidcmsg.oidc import RefreshAccessTokenRequest - -from oidcendpoint import JWT_BEARER -from oidcendpoint.client_authn import verify_client -from oidcendpoint.endpoint_context import EndpointContext -from oidcendpoint.exception import MultipleUsage -from oidcendpoint.exception import UnAuthorizedClient -from oidcendpoint.oidc import userinfo -from oidcendpoint.oidc.authorization import Authorization -from oidcendpoint.oidc.provider_config import ProviderConfiguration -from oidcendpoint.oidc.refresh_token import RefreshAccessToken -from oidcendpoint.oidc.registration import Registration -from oidcendpoint.oidc.token import AccessToken -from oidcendpoint.session import setup_session -from oidcendpoint.user_authn.authn_context import INTERNETPROTOCOLPASSWORD -from oidcendpoint.user_info import UserInfo - -KEYDEFS = [ - {"type": "RSA", "key": "", "use": ["sig"]}, - {"type": "EC", "crv": "P-256", "use": ["sig"]}, -] - -CLIENT_KEYJAR = build_keyjar(KEYDEFS) - - -RESPONSE_TYPES_SUPPORTED = [ - ["code"], - ["token"], - ["id_token"], - ["code", "token"], - ["code", "id_token"], - ["id_token", "token"], - ["code", "token", "id_token"], - ["none"], -] - -CAPABILITIES = { - "subject_types_supported": ["public", "pairwise"], - "grant_types_supported": [ - "authorization_code", - "implicit", - "urn:ietf:params:oauth:grant-type:jwt-bearer", - "refresh_token", - ], -} - -AUTH_REQ = AuthorizationRequest( - client_id="client_1", - redirect_uri="https://example.com/cb", - scope=["openid"], - state="STATE", - response_type="code", -) - -TOKEN_REQ = AccessTokenRequest( - client_id="client_1", - redirect_uri="https://example.com/cb", - state="STATE", - grant_type="authorization_code", - client_secret="hemligt", -) - -REFRESH_TOKEN_REQ = RefreshAccessTokenRequest( - grant_type="refresh_token", client_id="client_1", client_secret="hemligt" -) - -TOKEN_REQ_DICT = TOKEN_REQ.to_dict() - -BASEDIR = os.path.abspath(os.path.dirname(__file__)) - - -def full_path(local_file): - return os.path.join(BASEDIR, local_file) - - -USERINFO = UserInfo(json.loads(open(full_path("users.json")).read())) - - -class TestEndpoint(object): - @pytest.fixture(autouse=True) - def create_endpoint(self): - conf = { - "issuer": "https://example.com/", - "password": "mycket hemligt", - "token_expires_in": 600, - "grant_expires_in": 300, - "refresh_token_expires_in": 86400, - "verify_ssl": False, - "capabilities": CAPABILITIES, - "keys": {"uri_path": "jwks.json", "key_defs": KEYDEFS}, - "endpoint": { - "provider_config": { - "path": ".well-known/openid-configuration", - "class": ProviderConfiguration, - "kwargs": {}, - }, - "registration": { - "path": "registration", - "class": Registration, - "kwargs": {}, - }, - "authorization": { - "path": "authorization", - "class": Authorization, - "kwargs": {}, - }, - "token": { - "path": "token", - "class": AccessToken, - "kwargs": { - "client_authn_method": [ - "client_secret_basic", - "client_secret_post", - "client_secret_jwt", - "private_key_jwt", - ] - }, - }, - "refresh_token": { - "path": "token", - "class": RefreshAccessToken, - "kwargs": {}, - }, - "userinfo": { - "path": "userinfo", - "class": userinfo.UserInfo, - "kwargs": {"db_file": "users.json"}, - }, - }, - "authentication": { - "anon": { - "acr": INTERNETPROTOCOLPASSWORD, - "class": "oidcendpoint.user_authn.user.NoAuthn", - "kwargs": {"user": "diana"}, - } - }, - "userinfo": {"class": UserInfo, "kwargs": {"db": {}}}, - "client_authn": verify_client, - "template_dir": "template", - } - endpoint_context = EndpointContext(conf) - endpoint_context.cdb["client_1"] = { - "client_secret": "hemligt", - "redirect_uris": [("https://example.com/cb", None)], - "client_salt": "salted", - "token_endpoint_auth_method": "client_secret_post", - "response_types": ["code", "token", "code id_token", "id_token"], - } - endpoint_context.keyjar.import_jwks(CLIENT_KEYJAR.export_jwks(), "client_1") - self.endpoint = endpoint_context.endpoint["token"] - - def test_init(self): - assert self.endpoint - - def test_signing_alg_values(self): - """ - According to - https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderMetadata - The value "none" MUST NOT be used for the - token_endpoint_auth_signing_alg_values_supported field. - """ - assert "none" not in self.endpoint.endpoint_info[ - "token_endpoint_auth_signing_alg_values_supported" - ] - - def test_parse(self): - session_id = setup_session(self.endpoint.endpoint_context, AUTH_REQ, uid="user") - _token_request = TOKEN_REQ_DICT.copy() - _token_request["code"] = self.endpoint.endpoint_context.sdb[session_id]["code"] - _req = self.endpoint.parse_request(_token_request) - - assert isinstance(_req, AccessTokenRequest) - assert set(_req.keys()) == set(_token_request.keys()) - - def test_process_request(self): - session_id = setup_session( - self.endpoint.endpoint_context, - AUTH_REQ, - uid="user", - acr=INTERNETPROTOCOLPASSWORD, - ) - _token_request = TOKEN_REQ_DICT.copy() - _context = self.endpoint.endpoint_context - _token_request["code"] = _context.sdb[session_id]["code"] - _context.sdb.update(session_id, user="diana") - _req = self.endpoint.parse_request(_token_request) - - _resp = self.endpoint.process_request(request=_req) - - assert _resp - assert set(_resp.keys()) == {"http_headers", "response_args"} - - def test_process_request_using_code_twice(self): - session_id = setup_session( - self.endpoint.endpoint_context, - AUTH_REQ, - uid="user", - acr=INTERNETPROTOCOLPASSWORD, - ) - _token_request = TOKEN_REQ_DICT.copy() - _context = self.endpoint.endpoint_context - _token_request["code"] = _context.sdb[session_id]["code"] - _context.sdb.update(session_id, user="diana") - _req = self.endpoint.parse_request(_token_request) - _resp = self.endpoint.process_request(request=_req) - - # 2nd time used - # TODO: There is a bug in _post_parse_request, the returned error - # should be invalid_grant, not invalid_client - _req = self.endpoint.parse_request(_token_request) - _resp = self.endpoint.process_request(request=_req) - - assert _resp - assert set(_resp.keys()) == {"error"} - - def test_do_response(self): - session_id = setup_session( - self.endpoint.endpoint_context, - AUTH_REQ, - uid="user", - acr=INTERNETPROTOCOLPASSWORD, - ) - self.endpoint.endpoint_context.sdb.update(session_id, user="diana") - _token_request = TOKEN_REQ_DICT.copy() - _token_request["code"] = self.endpoint.endpoint_context.sdb[session_id]["code"] - _req = self.endpoint.parse_request(_token_request) - - _resp = self.endpoint.process_request(request=_req) - msg = self.endpoint.do_response(request=_req, **_resp) - assert isinstance(msg, dict) - - def test_process_request_using_private_key_jwt(self): - session_id = setup_session( - self.endpoint.endpoint_context, - AUTH_REQ, - uid="user", - acr=INTERNETPROTOCOLPASSWORD, - ) - _token_request = TOKEN_REQ_DICT.copy() - del _token_request["client_id"] - del _token_request["client_secret"] - _context = self.endpoint.endpoint_context - - _jwt = JWT(CLIENT_KEYJAR, iss=AUTH_REQ["client_id"], sign_alg="RS256") - _jwt.with_jti = True - _assertion = _jwt.pack({"aud": [_context.endpoint["token"].full_path]}) - _token_request.update( - {"client_assertion": _assertion, "client_assertion_type": JWT_BEARER} - ) - _token_request["code"] = self.endpoint.endpoint_context.sdb[session_id]["code"] - - _context.sdb.update(session_id, user="diana") - _req = self.endpoint.parse_request(_token_request) - _resp = self.endpoint.process_request(request=_req) - - # 2nd time used - with pytest.raises(UnAuthorizedClient): - self.endpoint.parse_request(_token_request) diff --git a/tests/test_26_oidc_userinfo_endpoint.py b/tests/test_26_oidc_userinfo_endpoint.py index 41f2bb9..08e0757 100755 --- a/tests/test_26_oidc_userinfo_endpoint.py +++ b/tests/test_26_oidc_userinfo_endpoint.py @@ -2,19 +2,20 @@ import os import pytest -from cryptojwt.jwt import utc_time_sans_frac +from oidcmsg.oauth2 import ResponseMessage from oidcmsg.oidc import AccessTokenRequest from oidcmsg.oidc import AuthorizationRequest +from oidcmsg.time_util import time_sans_frac from oidcendpoint import user_info +from oidcendpoint.authn_event import create_authn_event from oidcendpoint.endpoint_context import EndpointContext from oidcendpoint.id_token import IDToken from oidcendpoint.oidc import userinfo from oidcendpoint.oidc.authorization import Authorization from oidcendpoint.oidc.provider_config import ProviderConfiguration from oidcendpoint.oidc.registration import Registration -from oidcendpoint.oidc.token import AccessToken -from oidcendpoint.session import setup_session +from oidcendpoint.oidc.token import Token from oidcendpoint.user_authn.authn_context import INTERNETPROTOCOLPASSWORD from oidcendpoint.user_info import UserInfo @@ -35,7 +36,7 @@ ] CAPABILITIES = { - "subject_types_supported": ["public", "pairwise"], + "subject_types_supported": ["public", "pairwise", "ephemeral"], "grant_types_supported": [ "authorization_code", "implicit", @@ -103,7 +104,7 @@ def create_endpoint(self): }, "token": { "path": "token", - "class": AccessToken, + "class": Token, "kwargs": { "client_authn_methods": [ "client_secret_post", @@ -165,133 +166,123 @@ def create_endpoint(self): "response_types": ["code", "token", "code id_token", "id_token"], } self.endpoint = endpoint_context.endpoint["userinfo"] + self.session_manager = endpoint_context.session_manager + self.user_id = "diana" + + def _create_session(self, auth_req, sub_type="public", sector_identifier=''): + if sector_identifier: + authz_req = auth_req.copy() + authz_req["sector_identifier_uri"] = sector_identifier + else: + authz_req = auth_req + client_id = authz_req['client_id'] + ae = create_authn_event(self.user_id) + return self.session_manager.create_session(ae, authz_req, self.user_id, + client_id=client_id, + sub_type=sub_type) + + def _mint_code(self, grant, session_id): + # Constructing an authorization code is now done + return grant.mint_token( + session_id=session_id, + endpoint_context=self.endpoint.endpoint_context, + token_type='authorization_code', + token_handler=self.session_manager.token_handler["code"], + expires_at=time_sans_frac() + 300 # 5 minutes from now + ) + + def _mint_token(self, token_type, grant, session_id, token_ref=None): + _session_info = self.session_manager.get_session_info(session_id, grant=True) + return grant.mint_token( + session_id=session_id, + endpoint_context=self.endpoint.endpoint_context, + token_type=token_type, + token_handler=self.session_manager.token_handler[token_type], + expires_at=time_sans_frac() + 900, # 15 minutes from now + based_on=token_ref # Means the token (tok) was used to mint this token + ) def test_init(self): assert self.endpoint assert set( self.endpoint.endpoint_context.provider_info["claims_supported"] ) == { - "address", - "birthdate", - "email", - "email_verified", - "eduperson_scoped_affiliation", - "family_name", - "gender", - "given_name", - "locale", - "middle_name", - "name", - "nickname", - "phone_number", - "phone_number_verified", - "picture", - "preferred_username", - "profile", - "sub", - "updated_at", - "website", - "zoneinfo", - } + "address", + "birthdate", + "email", + "email_verified", + "eduperson_scoped_affiliation", + "family_name", + "gender", + "given_name", + "locale", + "middle_name", + "name", + "nickname", + "phone_number", + "phone_number_verified", + "picture", + "preferred_username", + "profile", + "sub", + "updated_at", + "website", + "zoneinfo", + } def test_parse(self): - session_id = setup_session( - self.endpoint.endpoint_context, - AUTH_REQ, - uid="userID", - authn_event={ - "authn_info": "loa1", - "uid": "diana", - "authn_time": utc_time_sans_frac(), - "valid_until": utc_time_sans_frac() + 3600, - }, - ) - _dic = self.endpoint.endpoint_context.sdb.upgrade_to_token(key=session_id) - _req = self.endpoint.parse_request( - {}, auth="Bearer {}".format(_dic["access_token"]) - ) + session_id = self._create_session(AUTH_REQ) + grant = self.session_manager[session_id] + # Free standing access token, not based on an authorization code + access_token = self._mint_token("access_token", grant, session_id) + _req = self.endpoint.parse_request({}, auth="Bearer {}".format(access_token.value)) assert set(_req.keys()) == {"client_id", "access_token"} + assert _req["client_id"] == AUTH_REQ['client_id'] + assert _req["access_token"] == access_token.value def test_parse_invalid_token(self): _req = self.endpoint.parse_request({}, auth="Bearer invalid") - assert _req['error'] == "invalid_token" def test_process_request(self): - session_id = setup_session( - self.endpoint.endpoint_context, - AUTH_REQ, - uid="userID", - authn_event={ - "authn_info": "loa1", - "uid": "diana", - "authn_time": utc_time_sans_frac(), - "valid_until": utc_time_sans_frac() + 3600, - }, - ) - _dic = self.endpoint.endpoint_context.sdb.upgrade_to_token(key=session_id) + session_id = self._create_session(AUTH_REQ) + grant = self.session_manager[session_id] + code = self._mint_code(grant, session_id) + access_token = self._mint_token("access_token", grant, session_id, code) + _req = self.endpoint.parse_request( - {}, auth="Bearer {}".format(_dic["access_token"]) + {}, auth="Bearer {}".format(access_token.value) ) args = self.endpoint.process_request(_req) assert args def test_process_request_not_allowed(self): - session_id = setup_session( - self.endpoint.endpoint_context, - AUTH_REQ, - uid="userID", - authn_event={ - "authn_info": "loa1", - "uid": "diana", - "authn_time": utc_time_sans_frac() - 7200, - "valid_until": utc_time_sans_frac() - 3600, - }, - ) - _dic = self.endpoint.endpoint_context.sdb.upgrade_to_token(key=session_id) - _req = self.endpoint.parse_request( - {}, auth="Bearer {}".format(_dic["access_token"]) - ) - args = self.endpoint.process_request(_req) - assert set(args["response_args"].keys()) == {"error", "error_description"} + session_id = self._create_session(AUTH_REQ) + grant = self.session_manager[session_id] + code = self._mint_code(grant, session_id) + access_token = self._mint_token("access_token", grant, session_id, code) + + # 2 things can make the request invalid. + # 1) The token is not valid anymore or 2) The event is not valid. + _event = grant.authentication_event + _event['authn_time'] -= 9000 + _event['valid_until'] -= 9000 - def test_process_request_offline_access(self): - auth_req = AUTH_REQ.copy() - auth_req["scope"] = ["openid", "offline_access"] - session_id = setup_session( - self.endpoint.endpoint_context, - auth_req, - uid="userID", - authn_event={ - "authn_info": "loa1", - "uid": "diana", - "authn_time": utc_time_sans_frac() , - "valid_until": utc_time_sans_frac() + 3600, - }, - ) - _dic = self.endpoint.endpoint_context.sdb.upgrade_to_token(key=session_id) _req = self.endpoint.parse_request( - {}, auth="Bearer {}".format(_dic["access_token"]) + {}, auth="Bearer {}".format(access_token.value) ) args = self.endpoint.process_request(_req) - assert set(args["response_args"].keys()) == {"sub"} + assert set(args["response_args"].keys()) == {"error", "error_description"} def test_do_response(self): - session_id = setup_session( - self.endpoint.endpoint_context, - AUTH_REQ, - uid="userID", - authn_event={ - "authn_info": "loa1", - "uid": "diana", - "authn_time": utc_time_sans_frac(), - "valid_until": utc_time_sans_frac() + 3600, - }, - ) - _dic = self.endpoint.endpoint_context.sdb.upgrade_to_token(key=session_id) + session_id = self._create_session(AUTH_REQ) + grant = self.session_manager[session_id] + code = self._mint_code(grant,session_id) + access_token = self._mint_token("access_token", grant, session_id, code) + _req = self.endpoint.parse_request( - {}, auth="Bearer {}".format(_dic["access_token"]) + {}, auth="Bearer {}".format(access_token.value) ) args = self.endpoint.process_request(_req) assert args @@ -299,24 +290,15 @@ def test_do_response(self): assert res def test_do_signed_response(self): - self.endpoint.endpoint_context.cdb["client_1"][ - "userinfo_signed_response_alg" - ] = "ES256" - - session_id = setup_session( - self.endpoint.endpoint_context, - AUTH_REQ, - uid="userID", - authn_event={ - "authn_info": "loa1", - "uid": "diana", - "authn_time": utc_time_sans_frac(), - "valid_until": utc_time_sans_frac() + 3600, - }, - ) - _dic = self.endpoint.endpoint_context.sdb.upgrade_to_token(key=session_id) + self.endpoint.endpoint_context.cdb["client_1"]["userinfo_signed_response_alg"] = "ES256" + + session_id = self._create_session(AUTH_REQ) + grant = self.session_manager[session_id] + code = self._mint_code(grant, session_id) + access_token = self._mint_token("access_token", grant, session_id, code) + _req = self.endpoint.parse_request( - {}, auth="Bearer {}".format(_dic["access_token"]) + {}, auth="Bearer {}".format(access_token.value) ) args = self.endpoint.process_request(_req) assert args @@ -326,28 +308,55 @@ def test_do_signed_response(self): def test_custom_scope(self): _auth_req = AUTH_REQ.copy() _auth_req["scope"] = ["openid", "research_and_scholarship"] - session_id = setup_session( - self.endpoint.endpoint_context, - _auth_req, - uid="userID", - authn_event={ - "authn_info": "loa1", - "uid": "diana", - "authn_time": utc_time_sans_frac(), - "valid_until": utc_time_sans_frac() + 3600, - }, + + session_id = self._create_session(_auth_req) + grant = self.session_manager[session_id] + access_token = self._mint_token("access_token", grant, session_id) + + self.endpoint.kwargs["add_claims_by_scope"] = True + self.endpoint.endpoint_context.claims_interface.add_claims_by_scope = True + grant.claims = { + "userinfo": self.endpoint.endpoint_context.claims_interface.get_claims( + session_id=session_id, scopes=_auth_req["scope"], usage="userinfo") + } + + _req = self.endpoint.parse_request( + {}, auth="Bearer {}".format(access_token.value) ) - _dic = self.endpoint.endpoint_context.sdb.upgrade_to_token(key=session_id) + args = self.endpoint.process_request(_req) + assert set(args["response_args"].keys()) == {'eduperson_scoped_affiliation', 'given_name', + 'email_verified', 'email', 'family_name', + 'name', "sub"} + + def test_wrong_type_of_token(self): + _auth_req = AUTH_REQ.copy() + _auth_req["scope"] = ["openid", "research_and_scholarship"] + + session_id = self._create_session(_auth_req) + grant = self.session_manager[session_id] + refresh_token = self._mint_token("refresh_token", grant, session_id) + _req = self.endpoint.parse_request( - {}, auth="Bearer {}".format(_dic["access_token"]) + {}, auth="Bearer {}".format(refresh_token.value) ) args = self.endpoint.process_request(_req) - assert set(args["response_args"].keys()) == { - "sub", - "name", - "given_name", - "family_name", - "email", - "email_verified", - "eduperson_scoped_affiliation", - } + + assert isinstance(args, ResponseMessage) + assert args['error_description'] == "Wrong type of token" + + def test_invalid_token(self): + _auth_req = AUTH_REQ.copy() + _auth_req["scope"] = ["openid", "research_and_scholarship"] + + session_id = self._create_session(_auth_req) + grant = self.session_manager[session_id] + access_token = self._mint_token("access_token", grant, session_id) + + _req = self.endpoint.parse_request( + {}, auth="Bearer {}".format(access_token.value) + ) + access_token.expires_at = time_sans_frac() - 10 + args = self.endpoint.process_request(_req) + + assert isinstance(args, ResponseMessage) + assert args['error_description'] == "Invalid Token" diff --git a/tests/test_27_jwt_token.py b/tests/test_27_jwt_token.py index da0336d..13ef7f0 100644 --- a/tests/test_27_jwt_token.py +++ b/tests/test_27_jwt_token.py @@ -2,21 +2,24 @@ import pytest from cryptojwt.jwt import JWT -from cryptojwt.jwt import utc_time_sans_frac from cryptojwt.key_jar import init_key_jar from oidcmsg.oidc import AccessTokenRequest from oidcmsg.oidc import AuthorizationRequest +from oidcmsg.time_util import time_sans_frac from oidcendpoint import user_info +from oidcendpoint.authn_event import create_authn_event +from oidcendpoint.authz import AuthzHandling from oidcendpoint.client_authn import verify_client from oidcendpoint.endpoint_context import EndpointContext from oidcendpoint.id_token import IDToken +from oidcendpoint.oauth2.introspection import Introspection from oidcendpoint.oidc.authorization import Authorization from oidcendpoint.oidc.provider_config import ProviderConfiguration from oidcendpoint.oidc.registration import Registration from oidcendpoint.oidc.session import Session -from oidcendpoint.oidc.token import AccessToken -from oidcendpoint.session import setup_session +from oidcendpoint.oidc.token import Token +from oidcendpoint.session import session_key from oidcendpoint.user_authn.authn_context import INTERNETPROTOCOLPASSWORD KEYDEFS = [ @@ -49,7 +52,7 @@ "private_key_jwt", ], "response_modes_supported": ["query", "fragment", "form_post"], - "subject_types_supported": ["public", "pairwise"], + "subject_types_supported": ["public", "pairwise", "ephemeral"], "grant_types_supported": [ "authorization_code", "implicit", @@ -81,6 +84,12 @@ BASEDIR = os.path.abspath(os.path.dirname(__file__)) +MAP = { + "authorization_code": "code", + "access_token": "access_token", + "refresh_token": "refresh_token" +} + def full_path(local_file): return os.path.join(BASEDIR, local_file) @@ -92,9 +101,6 @@ def create_endpoint(self): conf = { "issuer": ISSUER, "password": "mycket hemligt", - "token_expires_in": 600, - "grant_expires_in": 300, - "refresh_token_expires_in": 86400, "verify_ssl": False, "capabilities": CAPABILITIES, "keys": {"uri_path": "jwks.json", "key_defs": KEYDEFS}, @@ -108,16 +114,18 @@ def create_endpoint(self): }, "code": {"lifetime": 600}, "token": { - "class": "oidcendpoint.jwt_token.JWTToken", + "class": "oidcendpoint.token.jwt_token.JWTToken", + "kwargs": { + "lifetime": 3600, + "base_claims": {"eduperson_scoped_affiliation": None}, + "add_claims_by_scope": True, + "aud": ["https://example.org/appl"], + }, + }, + "refresh": { + "class": "oidcendpoint.token.jwt_token.JWTToken", "kwargs": { "lifetime": 3600, - "add_claims": [ - "email", - "email_verified", - "phone_number", - "phone_number_verified", - ], - "add_claim_by_scope": True, "aud": ["https://example.org/appl"], }, }, @@ -138,8 +146,12 @@ def create_endpoint(self): "class": Authorization, "kwargs": {}, }, - "token": {"path": "{}/token", "class": AccessToken, "kwargs": {}}, + "token": {"path": "{}/token", "class": Token, "kwargs": {}}, "session": {"path": "{}/end_session", "class": Session}, + "introspection": { + "path": "{}/introspection", + "class": Introspection + } }, "client_authn": verify_client, "authentication": { @@ -155,81 +167,114 @@ def create_endpoint(self): "kwargs": {"db_file": full_path("users.json")}, }, "id_token": {"class": IDToken}, + "authz": { + "class": AuthzHandling, + "kwargs": { + "grant_config": { + "usage_rules": { + "authorization_code": { + 'supports_minting': ["access_token", "refresh_token", "id_token"], + "max_usage": 1 + }, + "access_token": {}, + "refresh_token": { + 'supports_minting': ["access_token", "refresh_token"], + } + }, + "expires_in": 43200 + } + } + }, } - endpoint_context = EndpointContext(conf, keyjar=KEYJAR) - endpoint_context.cdb["client_1"] = { + self.endpoint_context = EndpointContext(conf, keyjar=KEYJAR) + self.endpoint_context.cdb["client_1"] = { "client_secret": "hemligt", "redirect_uris": [("https://example.com/cb", None)], "client_salt": "salted", "token_endpoint_auth_method": "client_secret_post", "response_types": ["code", "token", "code id_token", "id_token"], } - self.endpoint = Session(endpoint_context) + self.session_manager = self.endpoint_context.session_manager + self.user_id = "diana" + self.endpoint = self.endpoint_context.endpoint["session"] - def test_parse(self): - session_id = setup_session( - self.endpoint.endpoint_context, AUTH_REQ, uid="diana" + def _create_session(self, auth_req, sub_type="public", sector_identifier=''): + if sector_identifier: + authz_req = auth_req.copy() + authz_req["sector_identifier_uri"] = sector_identifier + else: + authz_req = auth_req + client_id = authz_req['client_id'] + ae = create_authn_event(self.user_id) + return self.session_manager.create_session(ae, authz_req, self.user_id, + client_id=client_id, + sub_type=sub_type) + + def _mint_token(self, type, grant, session_id, based_on=None, **kwargs): + # Constructing an authorization code is now done + return grant.mint_token( + session_id=session_id, + endpoint_context=self.endpoint_context, + token_type=type, + token_handler=self.session_manager.token_handler.handler[MAP[type]], + expires_at=time_sans_frac() + 300, # 5 minutes from now + based_on=based_on, + **kwargs ) - _dic = self.endpoint.endpoint_context.sdb.upgrade_to_token(key=session_id) + + def test_parse(self): + session_id = self._create_session(AUTH_REQ) + # apply consent + grant = self.endpoint_context.authz(session_id=session_id, request=AUTH_REQ) + # grant = self.session_manager[session_id] + code = self._mint_token("authorization_code", grant, session_id) + access_token = self._mint_token("access_token", grant, session_id, code, + resources=[AUTH_REQ["client_id"]]) _verifier = JWT(self.endpoint.endpoint_context.keyjar) - _info = _verifier.unpack(_dic["access_token"]) + _info = _verifier.unpack(access_token.value) assert _info["ttype"] == "T" - assert _info["phone_number"] == "+46907865000" - assert set(_info["aud"]) == {"client_1", "https://example.org/appl"} + # assert _info["eduperson_scoped_affiliation"] == ["staff@example.org"] + assert set(_info["aud"]) == {"client_1"} def test_info(self): - session_id = setup_session( - self.endpoint.endpoint_context, AUTH_REQ, uid="diana" - ) - _dic = self.endpoint.endpoint_context.sdb.upgrade_to_token(key=session_id) + session_id = self._create_session(AUTH_REQ) + # apply consent + grant = self.endpoint_context.authz(session_id=session_id, request=AUTH_REQ) + # + code = self._mint_token("authorization_code", grant, session_id) + access_token = self._mint_token("access_token", grant, session_id, code) - handler = self.endpoint.endpoint_context.sdb.handler.handler["access_token"] - _info = handler.info(_dic["access_token"]) + _info = self.session_manager.token_handler.info(access_token.value) assert _info["type"] == "T" assert _info["sid"] == session_id @pytest.mark.parametrize("enable_claims_per_client", [True, False]) - def test_client_claims(self, enable_claims_per_client): - ec = self.endpoint.endpoint_context - handler = ec.sdb.handler.handler["access_token"] - session_id = setup_session(ec, AUTH_REQ, uid="diana") - ec.cdb["client_1"]["access_token_claims"] = {"address": None} - handler.enable_claims_per_client = enable_claims_per_client - _dic = ec.sdb.upgrade_to_token(key=session_id) - - token = _dic["access_token"] - _jwt = JWT(key_jar=KEYJAR, iss="client_1") - res = _jwt.unpack(token) - assert enable_claims_per_client is ("address" in res) + def test_enable_claims_per_client(self, enable_claims_per_client): + # Set up configuration + self.endpoint.endpoint_context.cdb["client_1"]["access_token_claims"] = {"address": None} + self.endpoint_context.session_manager.token_handler.handler["access_token"].kwargs[ + "enable_claims_per_client"] = enable_claims_per_client + + session_id = self._create_session(AUTH_REQ) + # apply consent + grant = self.endpoint_context.authz(session_id=session_id, request=AUTH_REQ) + # + code = self._mint_token("authorization_code", grant, session_id) + access_token = self._mint_token("access_token", grant, session_id, code) - @pytest.mark.parametrize("add_scope", [True, False]) - def test_add_scopes(self, add_scope): - ec = self.endpoint.endpoint_context - handler = ec.sdb.handler.handler["access_token"] - auth_req = dict(AUTH_REQ) - auth_req["scope"] = ["openid", "profile", "aba"] - session_id = setup_session(ec, auth_req, uid="diana") - handler.add_scope = add_scope - _dic = ec.sdb.upgrade_to_token(key=session_id) - - token = _dic["access_token"] _jwt = JWT(key_jar=KEYJAR, iss="client_1") - res = _jwt.unpack(token) - assert add_scope is (res.get("scope") == ["openid", "profile"]) + res = _jwt.unpack(access_token.value) + assert enable_claims_per_client is ("address" in res) def test_is_expired(self): - session_id = setup_session( - self.endpoint.endpoint_context, AUTH_REQ, uid="diana" - ) - _dic = self.endpoint.endpoint_context.sdb.upgrade_to_token(key=session_id) + session_id = self._create_session(AUTH_REQ) + grant = self.session_manager[session_id] + code = self._mint_token("authorization_code", grant, session_id) + access_token = self._mint_token("access_token", grant, session_id, code) - handler = self.endpoint.endpoint_context.sdb.handler.handler["access_token"] - assert handler.is_expired(_dic["access_token"]) is False - - assert ( - handler.is_expired(_dic["access_token"], utc_time_sans_frac() + 4000) - is True - ) + assert access_token.is_active() + # 4000 seconds in the future. Passed the lifetime. + assert access_token.is_active(now=time_sans_frac() + 4000) is False diff --git a/tests/test_28_oidc_refresh_token_endpoint.py b/tests/test_28_oidc_refresh_token_endpoint.py deleted file mode 100755 index 04a5f7e..0000000 --- a/tests/test_28_oidc_refresh_token_endpoint.py +++ /dev/null @@ -1,238 +0,0 @@ -import json -import os - -import pytest -from oidcmsg.oidc import AccessTokenRequest -from oidcmsg.oidc import AuthorizationRequest -from oidcmsg.oidc import RefreshAccessTokenRequest - -from oidcendpoint.client_authn import ClientSecretBasic -from oidcendpoint.client_authn import ClientSecretJWT -from oidcendpoint.client_authn import ClientSecretPost -from oidcendpoint.client_authn import PrivateKeyJWT -from oidcendpoint.client_authn import verify_client -from oidcendpoint.endpoint_context import EndpointContext -from oidcendpoint.oidc import userinfo -from oidcendpoint.oidc.authorization import Authorization -from oidcendpoint.oidc.provider_config import ProviderConfiguration -from oidcendpoint.oidc.refresh_token import RefreshAccessToken -from oidcendpoint.oidc.registration import Registration -from oidcendpoint.oidc.token import AccessToken -from oidcendpoint.session import setup_session -from oidcendpoint.user_authn.authn_context import INTERNETPROTOCOLPASSWORD -from oidcendpoint.user_info import UserInfo - -KEYDEFS = [ - {"type": "RSA", "key": "", "use": ["sig"]}, - {"type": "EC", "crv": "P-256", "use": ["sig"]}, -] - -RESPONSE_TYPES_SUPPORTED = [ - ["code"], - ["token"], - ["id_token"], - ["code", "token"], - ["code", "id_token"], - ["id_token", "token"], - ["code", "token", "id_token"], - ["none"], -] - -CAPABILITIES = { - "response_types_supported": [" ".join(x) for x in RESPONSE_TYPES_SUPPORTED], - "token_endpoint_auth_methods_supported": [ - "client_secret_post", - "client_secret_basic", - "client_secret_jwt", - "private_key_jwt", - ], - "response_modes_supported": ["query", "fragment", "form_post"], - "subject_types_supported": ["public", "pairwise"], - "grant_types_supported": [ - "authorization_code", - "implicit", - "urn:ietf:params:oauth:grant-type:jwt-bearer", - "refresh_token", - ], - "claim_types_supported": ["normal", "aggregated", "distributed"], - "claims_parameter_supported": True, - "request_parameter_supported": True, - "request_uri_parameter_supported": True, -} - -AUTH_REQ = AuthorizationRequest( - client_id="client_1", - redirect_uri="https://example.com/cb", - scope=["openid"], - state="STATE", - response_type="code", -) - -TOKEN_REQ = AccessTokenRequest( - client_id="client_1", - redirect_uri="https://example.com/cb", - state="STATE", - grant_type="authorization_code", - client_secret="hemligt", -) - -REFRESH_TOKEN_REQ = RefreshAccessTokenRequest( - grant_type="refresh_token", client_id="client_1", client_secret="hemligt" -) - -TOKEN_REQ_DICT = TOKEN_REQ.to_dict() - -BASEDIR = os.path.abspath(os.path.dirname(__file__)) - - -def full_path(local_file): - return os.path.join(BASEDIR, local_file) - - -USERINFO = UserInfo(json.loads(open(full_path("users.json")).read())) - - -class TestEndpoint(object): - @pytest.fixture(autouse=True) - def create_endpoint(self): - conf = { - "issuer": "https://example.com/", - "password": "mycket hemligt", - "token_expires_in": 600, - "grant_expires_in": 300, - "refresh_token_expires_in": 86400, - "verify_ssl": False, - "capabilities": CAPABILITIES, - "keys": {"uri_path": "jwks.json", "key_defs": KEYDEFS}, - "endpoint": { - "provider_config": { - "path": "{}/.well-known/openid-configuration", - "class": ProviderConfiguration, - "kwargs": {}, - }, - "registration": { - "path": "{}/registration", - "class": Registration, - "kwargs": {}, - }, - "authorization": { - "path": "{}/authorization", - "class": Authorization, - "kwargs": {}, - }, - "token": { - "path": "{}/token", - "class": AccessToken, - "kwargs": { - "client_authn_method": { - "client_secret_basic": ClientSecretBasic, - "client_secret_post": ClientSecretPost, - "client_secret_jwt": ClientSecretJWT, - "private_key_jwt": PrivateKeyJWT, - } - }, - }, - "refresh_token": { - "path": "{}/token", - "class": RefreshAccessToken, - "kwargs": { - "client_authn_method": { - "client_secret_basic": ClientSecretBasic, - "client_secret_post": ClientSecretPost, - "client_secret_jwt": ClientSecretJWT, - "private_key_jwt": PrivateKeyJWT, - } - }, - }, - "userinfo": { - "path": "{}/userinfo", - "class": userinfo.UserInfo, - "kwargs": {"db_file": "users.json"}, - }, - }, - "authentication": { - "anon": { - "acr": INTERNETPROTOCOLPASSWORD, - "class": "oidcendpoint.user_authn.user.NoAuthn", - "kwargs": {"user": "diana"}, - } - }, - "userinfo": {"class": UserInfo, "kwargs": {"db": {}}}, - "client_authn": verify_client, - "template_dir": "template", - } - endpoint_context = EndpointContext(conf) - endpoint_context.cdb["client_1"] = { - "client_secret": "hemligt", - "redirect_uris": [("https://example.com/cb", None)], - "client_salt": "salted", - "token_endpoint_auth_method": "client_secret_post", - "response_types": ["code", "token", "code id_token", "id_token"], - } - self.token_endpoint = endpoint_context.endpoint["token"] - self.refresh_token_endpoint = endpoint_context.endpoint["refresh_token"] - - def test_init(self): - assert self.refresh_token_endpoint - - def test_do_refresh_access_token(self): - areq = AUTH_REQ.copy() - areq["scope"] = ["openid", "offline_access"] - _cntx = self.token_endpoint.endpoint_context - session_id = setup_session( - _cntx, areq, uid="user", acr=INTERNETPROTOCOLPASSWORD - ) - _cntx.sdb.update(session_id, user="diana") - _token_request = TOKEN_REQ_DICT.copy() - _token_request["code"] = _cntx.sdb[session_id]["code"] - _req = self.token_endpoint.parse_request(_token_request) - _resp = self.token_endpoint.process_request(request=_req) - - _request = REFRESH_TOKEN_REQ.copy() - _request["refresh_token"] = _resp["response_args"]["refresh_token"] - _req = self.refresh_token_endpoint.parse_request(_request.to_json()) - _resp = self.refresh_token_endpoint.process_request(request=_req) - assert set(_resp.keys()) == {"response_args", "http_headers"} - assert set(_resp["response_args"].keys()) == { - "access_token", - "token_type", - "expires_in", - "refresh_token", - "id_token", - } - msg = self.refresh_token_endpoint.do_response(request=_req, **_resp) - assert isinstance(msg, dict) - - def test_do_2nd_refresh_access_token(self): - areq = AUTH_REQ.copy() - areq["scope"] = ["openid", "offline_access"] - _cntx = self.token_endpoint.endpoint_context - session_id = setup_session( - _cntx, areq, uid="user", acr=INTERNETPROTOCOLPASSWORD - ) - _cntx.sdb.update(session_id, user="diana") - _token_request = TOKEN_REQ_DICT.copy() - _token_request["code"] = _cntx.sdb[session_id]["code"] - _req = self.token_endpoint.parse_request(_token_request) - _resp = self.token_endpoint.process_request(request=_req) - - _request = REFRESH_TOKEN_REQ.copy() - _request["refresh_token"] = _resp["response_args"]["refresh_token"] - _req = self.refresh_token_endpoint.parse_request(_request.to_json()) - _resp = self.refresh_token_endpoint.process_request(request=_req) - - _request = REFRESH_TOKEN_REQ.copy() - _request["refresh_token"] = _resp["response_args"]["refresh_token"] - _req = self.refresh_token_endpoint.parse_request(_request.to_json()) - _resp = self.refresh_token_endpoint.process_request(request=_req) - - assert set(_resp.keys()) == {"response_args", "http_headers"} - assert set(_resp["response_args"].keys()) == { - "access_token", - "token_type", - "expires_in", - "refresh_token", - "id_token", - } - msg = self.refresh_token_endpoint.do_response(request=_req, **_resp) - assert isinstance(msg, dict) diff --git a/tests/test_30_oidc_end_session.py b/tests/test_30_oidc_end_session.py index 57125de..04c29dd 100644 --- a/tests/test_30_oidc_end_session.py +++ b/tests/test_30_oidc_end_session.py @@ -4,28 +4,35 @@ from urllib.parse import parse_qs from urllib.parse import urlparse -from oidcendpoint.token_handler import UnknownToken import pytest import responses +from cryptojwt import as_unicode +from cryptojwt import b64d from cryptojwt.key_jar import build_keyjar +from cryptojwt.utils import as_bytes from oidcmsg.exception import InvalidRequest from oidcmsg.message import Message from oidcmsg.oidc import AuthorizationRequest from oidcmsg.oidc import verified_claim_name from oidcmsg.oidc import verify_id_token +from oidcmsg.time_util import time_sans_frac -from oidcendpoint.common.authorization import join_query +from oidcendpoint.authn_event import create_authn_event from oidcendpoint.cookie import CookieDealer from oidcendpoint.cookie import new_cookie from oidcendpoint.endpoint_context import EndpointContext from oidcendpoint.exception import RedirectURIError +from oidcendpoint.oauth2.authorization import join_query from oidcendpoint.oidc import userinfo from oidcendpoint.oidc.authorization import Authorization from oidcendpoint.oidc.provider_config import ProviderConfiguration from oidcendpoint.oidc.registration import Registration from oidcendpoint.oidc.session import Session from oidcendpoint.oidc.session import do_front_channel_logout_iframe -from oidcendpoint.oidc.token import AccessToken +from oidcendpoint.oidc.token import Token +from oidcendpoint.session import session_key +from oidcendpoint.session import unpack_session_key +from oidcendpoint.session.grant import Grant from oidcendpoint.user_authn.authn_context import INTERNETPROTOCOLPASSWORD from oidcendpoint.user_info import UserInfo @@ -62,7 +69,7 @@ "private_key_jwt", ], "response_modes_supported": ["query", "fragment", "form_post"], - "subject_types_supported": ["public", "pairwise"], + "subject_types_supported": ["public", "pairwise", "ephemeral"], "grant_types_supported": [ "authorization_code", "implicit", @@ -124,7 +131,7 @@ def create_endpoint(self): "class": Authorization, "kwargs": {"client_authn_method": None}, }, - "token": {"path": "{}/token", "class": AccessToken, "kwargs": {}}, + "token": {"path": "{}/token", "class": Token, "kwargs": {}}, "userinfo": { "path": "{}/userinfo", "class": userinfo.UserInfo, @@ -185,24 +192,23 @@ def create_endpoint(self): "post_logout_redirect_uris": [("{}logout_cb".format(CLI2), "")], }, } + self.session_manager = endpoint_context.session_manager self.authn_endpoint = endpoint_context.endpoint["authorization"] self.session_endpoint = endpoint_context.endpoint["session"] self.token_endpoint = endpoint_context.endpoint["token"] + self.user_id = "diana" def test_end_session_endpoint(self): # End session not allowed if no cookie and no id_token_hint is sent # (can't determine session) - with pytest.raises(UnknownToken): + with pytest.raises(ValueError): _ = self.session_endpoint.process_request("", cookie="FAIL") - def _create_cookie(self, user, sid, state, client_id): + def _create_cookie(self, session_id): ec = self.session_endpoint.endpoint_context return new_cookie( ec, - sub=user, - sid=sid, - state=state, - client_id=client_id, + sid=session_id, cookie_name=ec.cookie_name["session"], ) @@ -215,7 +221,7 @@ def _code_auth(self, state): client_id="client_1", ) _pr_resp = self.authn_endpoint.parse_request(req.to_dict()) - _resp = self.authn_endpoint.process_request(_pr_resp) + return self.authn_endpoint.process_request(_pr_resp) def _code_auth2(self, state): req = AuthorizationRequest( @@ -226,16 +232,7 @@ def _code_auth2(self, state): client_id="client_2", ) _pr_resp = self.authn_endpoint.parse_request(req.to_dict()) - _resp = self.authn_endpoint.process_request(_pr_resp) - - def _get_sid(self): - _sdb = self.session_endpoint.endpoint_context.sdb - - for _sid in _sdb.keys(): - if _sid.startswith("__state__"): - continue - else: - return _sid + return self.authn_endpoint.process_request(_pr_resp) def _auth_with_id_token(self, state): req = AuthorizationRequest( @@ -249,12 +246,19 @@ def _auth_with_id_token(self, state): _pr_resp = self.authn_endpoint.parse_request(req.to_dict()) _resp = self.authn_endpoint.process_request(_pr_resp) - return _resp["response_args"]["id_token"] + part = self.session_endpoint.endpoint_context.cookie_dealer.get_cookie_value( + _resp["cookie"][0], cookie_name="oidcop" + ) + # value is a base64 encoded JSON document + _cookie_info = json.loads(as_unicode(b64d(as_bytes(part[0])))) + + return _resp["response_args"], _cookie_info["sid"] def test_end_session_endpoint_with_cookie(self): - self._code_auth("1234567") - _sid = self._get_sid() - cookie = self._create_cookie("diana", _sid, "1234567", "client_1") + _resp = self._code_auth("1234567") + _code = _resp["response_args"]["code"] + _session_info = self.session_manager.get_session_info_by_token(_code) + cookie = self._create_cookie(_session_info["session_id"]) _req_args = self.session_endpoint.parse_request({"state": "1234567"}) resp = self.session_endpoint.process_request(_req_args, cookie=cookie) @@ -266,39 +270,27 @@ def test_end_session_endpoint_with_cookie(self): qs = parse_qs(p.query) jwt_info = self.session_endpoint.unpack_signed_jwt(qs["sjwt"][0]) - assert jwt_info["user"] == "diana" - assert jwt_info["client_id"] == "client_1" + assert jwt_info["sid"] == _session_info["session_id"] assert jwt_info["redirect_uri"] == "https://example.com/post_logout" - def test_end_session_endpoint_with_wrong_cookie(self): - self._code_auth("1234567") - cookie = self._create_cookie("diana", "client_2", "abcdefg", "client_1") - - with pytest.raises(UnknownToken): - self.session_endpoint.process_request({"state": "abcde"}, cookie=cookie) - - def test_end_session_endpoint_with_cookie_wrong_user(self): + def test_end_session_endpoint_with_cookie_and_unknown_sid(self): # Need cookie and ID Token to figure this out - id_token = self._auth_with_id_token("1234567") + resp_args, _session_id = self._auth_with_id_token("1234567") + id_token = resp_args["id_token"] - cookie = self._create_cookie("diggins", "_sid_", "1234567", "client_1") + _uid, _cid, _gid = unpack_session_key(_session_id) + cookie = self._create_cookie(session_key(_uid, "client_66", _gid)) - msg = Message(id_token=id_token) - verify_id_token(msg, keyjar=self.session_endpoint.endpoint_context.keyjar) - - msg2 = Message(id_token_hint=id_token) - msg2[verified_claim_name("id_token_hint")] = msg[ - verified_claim_name("id_token") - ] with pytest.raises(ValueError): - self.session_endpoint.process_request(msg2, cookie=cookie) + self.session_endpoint.process_request({"state": "foo"}, cookie=cookie) - def test_end_session_endpoint_with_cookie_unknown_sid(self): + def test_end_session_endpoint_with_cookie_id_token_and_unknown_sid(self): # Need cookie and ID Token to figure this out - id_token = self._auth_with_id_token("1234567") + resp_args, _session_id = self._auth_with_id_token("1234567") + id_token = resp_args["id_token"] - # Wrong client_id - cookie = self._create_cookie("diana", "_sid_", "state", "client_1") + _uid, _cid, _gid = unpack_session_key(_session_id) + cookie = self._create_cookie(session_key(_uid, "client_66", _gid)) msg = Message(id_token=id_token) verify_id_token(msg, keyjar=self.session_endpoint.endpoint_context.keyjar) @@ -311,11 +303,11 @@ def test_end_session_endpoint_with_cookie_unknown_sid(self): self.session_endpoint.process_request(msg2, cookie=cookie) def test_end_session_endpoint_with_cookie_dual_login(self): - self._code_auth("1234567") + _resp = self._code_auth("1234567") self._code_auth2("abcdefg") - _sdb = self.session_endpoint.endpoint_context.sdb - _sid = self._get_sid() - cookie = self._create_cookie("diana", _sid, "1234567", "client_1") + _code = _resp["response_args"]["code"] + _session_info = self.session_manager.get_session_info_by_token(_code) + cookie = self._create_cookie(_session_info["session_id"]) resp = self.session_endpoint.process_request({"state": "abcde"}, cookie=cookie) @@ -326,16 +318,15 @@ def test_end_session_endpoint_with_cookie_dual_login(self): qs = parse_qs(p.query) jwt_info = self.session_endpoint.unpack_signed_jwt(qs["sjwt"][0]) - assert jwt_info["user"] == "diana" - assert jwt_info["client_id"] == "client_1" + assert jwt_info["sid"] == _session_info["session_id"] assert jwt_info["redirect_uri"] == "https://example.com/post_logout" def test_end_session_endpoint_with_post_logout_redirect_uri(self): - self._code_auth("1234567") + _resp = self._code_auth("1234567") self._code_auth2("abcdefg") - _sdb = self.session_endpoint.endpoint_context.sdb - _sid = self._get_sid() - cookie = self._create_cookie("diana", _sid, "1234567", "client_1") + _code = _resp["response_args"]["code"] + _session_info = self.session_manager.get_session_info_by_token(_code) + cookie = self._create_cookie(_session_info["session_id"]) post_logout_redirect_uri = join_query( *self.session_endpoint.endpoint_context.cdb["client_1"][ @@ -353,14 +344,13 @@ def test_end_session_endpoint_with_post_logout_redirect_uri(self): ) def test_end_session_endpoint_with_wrong_post_logout_redirect_uri(self): - self._code_auth("1234567") + _resp = self._code_auth("1234567") self._code_auth2("abcdefg") - id_token = self._auth_with_id_token("1234567") + resp_args, _session_id = self._auth_with_id_token("1234567") + id_token = resp_args["id_token"] - _sdb = self.session_endpoint.endpoint_context.sdb - _sid = self._get_sid() - cookie = self._create_cookie("diana", _sid, "1234567", "client_1") + cookie = self._create_cookie(_session_id) post_logout_redirect_uri = "https://demo.example.com/log_out" @@ -384,7 +374,7 @@ def test_back_channel_logout_no_uri(self): self._code_auth("1234567") res = self.session_endpoint.do_back_channel_logout( - self.session_endpoint.endpoint_context.cdb["client_1"], "username", 0 + self.session_endpoint.endpoint_context.cdb["client_1"], 0 ) assert res is None @@ -394,15 +384,14 @@ def test_back_channel_logout(self): _cdb = copy.copy(self.session_endpoint.endpoint_context.cdb["client_1"]) _cdb["backchannel_logout_uri"] = "https://example.com/bc_logout" _cdb["client_id"] = "client_1" - res = self.session_endpoint.do_back_channel_logout(_cdb, "username", "_sid_") + res = self.session_endpoint.do_back_channel_logout(_cdb, "_sid_") assert isinstance(res, tuple) assert res[0] == "https://example.com/bc_logout" _jwt = self.session_endpoint.unpack_signed_jwt(res[1], "RS256") assert _jwt assert _jwt["iss"] == ISS assert _jwt["aud"] == ["client_1"] - assert _jwt["sub"] == "username" - assert _jwt["sid"] == "_sid_" + assert "sid" in _jwt def test_front_channel_logout(self): self._code_auth("1234567") @@ -448,13 +437,17 @@ def test_front_channel_logout_with_query(self): assert i in res def test_logout_from_client_bc(self): - self._code_auth("1234567") + _resp = self._code_auth("1234567") + _code = _resp["response_args"]["code"] + _session_info = self.session_manager.get_session_info_by_token( + _code, client_session_info=True) + self.session_endpoint.endpoint_context.cdb["client_1"][ "backchannel_logout_uri" ] = "https://example.com/bc_logout" self.session_endpoint.endpoint_context.cdb["client_1"]["client_id"] = "client_1" - _sid = self._get_sid() - res = self.session_endpoint.logout_from_client(_sid, "client_1") + + res = self.session_endpoint.logout_from_client(_session_info["session_id"]) assert set(res.keys()) == {"blu"} assert set(res["blu"].keys()) == {"client_1"} _spec = res["blu"]["client_1"] @@ -463,30 +456,37 @@ def test_logout_from_client_bc(self): assert _jwt assert _jwt["iss"] == ISS assert _jwt["aud"] == ["client_1"] - assert _jwt["sid"] == _sid + assert "sid" in _jwt # This session ID is not the same as the session_id mentioned above - with pytest.raises(UnknownToken): - _ = self.session_endpoint.endpoint_context.sdb[_sid] + _sid = self.session_endpoint._decrypt_sid(_jwt["sid"]) + assert _sid == _session_info["session_id"] + assert _session_info["client_session_info"].is_revoked() def test_logout_from_client_fc(self): - self._code_auth("1234567") + _resp = self._code_auth("1234567") + _code = _resp["response_args"]["code"] + _session_info = self.session_manager.get_session_info_by_token( + _code, client_session_info=True) + # del self.session_endpoint.endpoint_context.cdb['client_1']['backchannel_logout_uri'] self.session_endpoint.endpoint_context.cdb["client_1"][ "frontchannel_logout_uri" ] = "https://example.com/fc_logout" self.session_endpoint.endpoint_context.cdb["client_1"]["client_id"] = "client_1" - _sid = self._get_sid() - res = self.session_endpoint.logout_from_client(_sid, "client_1") + + res = self.session_endpoint.logout_from_client(_session_info["session_id"]) assert set(res.keys()) == {"flu"} assert set(res["flu"].keys()) == {"client_1"} _spec = res["flu"]["client_1"] assert _spec == '