diff --git a/doc/source/index.rst b/doc/source/index.rst
index 0992ecf..a060e6d 100644
--- a/doc/source/index.rst
+++ b/doc/source/index.rst
@@ -11,6 +11,7 @@ Welcome to OidcEndpoint's documentation!
:caption: Contents:
oidcendpoint
+ session_manager
Indices and tables
==================
diff --git a/doc/source/oidcendpoint.rst b/doc/source/oidcendpoint.rst
index 203701d..6c359a2 100644
--- a/doc/source/oidcendpoint.rst
+++ b/doc/source/oidcendpoint.rst
@@ -62,6 +62,14 @@ oidcendpoint\.id_token module
:undoc-members:
:show-inheritance:
+oidcendpoint\.grant module
+--------------------------
+
+.. automodule:: oidcendpoint.grant
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
oidcendpoint\.in_memory_db module
---------------------------------
@@ -86,18 +94,26 @@ oidcendpoint\.login_hint module
:undoc-members:
:show-inheritance:
-oidcendpoint\.session module
-----------------------------
+oidcendpoint\.scopes module
+---------------------------
+
+.. automodule:: oidcendpoint.scopes
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+oidcendpoint\.session_management module
+---------------------------------------
-.. automodule:: oidcendpoint.session
+.. automodule:: oidcendpoint.session_management
:members:
:undoc-members:
:show-inheritance:
-oidcendpoint\.sso_db module
----------------------------
+oidcendpoint\.session_storage module
+------------------------------------
-.. automodule:: oidcendpoint.sso_db
+.. automodule:: oidcendpoint.session_management
:members:
:undoc-members:
:show-inheritance:
diff --git a/doc/source/session_manager.rst b/doc/source/session_manager.rst
new file mode 100644
index 0000000..403e041
--- /dev/null
+++ b/doc/source/session_manager.rst
@@ -0,0 +1,536 @@
+Session Management
+==================
+
+- `About session management`_
+ - `Design criteria`_
+ - `Database layout`_
+- `The information structure`_
+ - `Session key`_
+ - `User session information`_
+ - `Client session information`_
+ - `Grant information`_
+ - `Token`_
+- `Session Info API`_
+- `Grant API`_
+- `Token API`_
+
+- `Session Manager API`_
+ - `create_session`_
+ - `add_grant`_
+ - `find_token`_
+ - `get_authentication_event`_
+ - `get_client_session_info`_
+ - `get_grant_by_response_type`_
+ - `get_session_info`_
+ - `get_session_info_by_token`_
+ - `get_sids_by_user_id`_
+ - `get_user_info`_
+ - `grants`_
+ - `revoke_client_session`_
+ - `revoke_grant`_
+ - `revoke_token`_
+
+
+About session management
+------------------------
+.. _`About session management`:
+
+The OIDC Session Management draft defines session to be:
+
+ Continuous period of time during which an End-User accesses a Relying
+ Party relying on the Authentication of the End-User performed by the
+ OpenID Provider.
+
+Note that we are dealing with a Single Sign On (SSO) context here.
+If for some reason the OP does not want to support SSO then the
+session management has to be done a bit differently. In that case each
+session (user_id,client_id) would have its own authentication even. Not one
+shared between the sessions.
+
+Design criteria
++++++++++++++++
+.. _`Design criteria`:
+
+So a session is defined by a user and a Relying Party. If one adds to that
+that a user can have several sessions active at the same time each one against
+a unique Relying Party we have the bases for session management.
+
+Furthermore the user may well decide on different rules for different
+relying parties for releasing user
+attributes, where and how issued access tokens could be used and whether
+refresh tokens should be issued or not.
+
+We also need to keep track on which tokens where used to mint new tokens
+such that we can easily revoked a suite of tokens all with a common ancestor.
+
+Database layout
++++++++++++++++
+.. _`Database layout`:
+
+The database is organized in 3 levels. The top one being the users.
+Below that the Relying Parties and at the bottom what is called grants.
+
+Grants organize authorization codes, access tokens and refresh tokens (and
+possibly other types of tokens) in a comprehensive way. More about that below.
+
+There may be many Relying Parties below a user and many grants below a
+Relying Party.
+
+The information structure
+-------------------------
+.. _`The information structure`:
+
+As stated above there are 3 layers: user session information, client session
+information and grants. But first the keys to the information.
+
+Session key
++++++++++++
+.. _`Session key`:
+
+A key to the session information is based on a list. The first item being the
+user identifier, the second the client identifier and the third the grant
+identifier.
+If you only want the user session information then the key is a list with one
+item the user id. If you want the client session information the key is a
+list with 2 items (user_id, client_id). And lastly if you want a grant then
+the key is a list with 3 elements (user_id, client_id, grant_id).
+
+A *session identifier* is constructed using the **session_key** function.
+It takes as input the 3 elements list.::
+
+ session_id = session_key(user_id, client_id, grant_id)
+
+
+Using the function **unpack_session_key** you can get the elements from a
+session_id.::
+
+ user_id, client_id, grant_id = unpack_session_id(session_id)
+
+
+User session information
+++++++++++++++++++++++++
+.. _`User session information`:
+
+Houses the authentication event information which is the same for all session
+connected to a user.
+Here we also have a list of all the clients that this user has a session with.
+Expressed as a dictionary this can look like this::
+
+ {
+ 'authentication_event': {
+ 'uid': 'diana',
+ 'authn_info': "urn:oasis:names:tc:SAML:2.0:ac:classes:InternetProtocolPassword",
+ 'authn_time': 1605515787,
+ 'valid_until': 1605519387
+ },
+ 'subordinate': ['client_1']
+ }
+
+
+Client session information
+++++++++++++++++++++++++++
+.. _`Client session information`:
+
+The client specific information of the session information.
+Presently only the authorization request and the subject identifier (sub).
+The subordinates to this set of information are the grants::
+
+ {
+ 'authorization_request':{
+ 'client_id': 'client_1',
+ 'redirect_uri': 'https://example.com/cb',
+ 'scope': ['openid', 'research_and_scholarship'],
+ 'state': 'STATE',
+ 'response_type': ['code']
+ },
+ 'sub': '117afe8d7bb0ace8e7fb2706034ab2d3fbf17f0fd4c949aa9c23aedd051cc9e3',
+ 'subordinate': ['e996c61227e711eba173acde48001122'],
+ 'revoked': False
+ }
+
+Grant information
++++++++++++++++++
+.. _`Grant information`:
+
+Grants are created by an authorization subsystem in an OP. If the grant is
+created in connection with an user authentication the authorization system
+might normally ask the user for usage consent and then base the construction
+of the grant on that consent.
+
+If an authorization server can act as a Security Token Service (STS) as
+defined by https://tools.ietf.org/html/draft-ietf-oauth-token-exchange-16
+then no user is involved. In the context of session management the STS is
+equivalent to a user.
+
+Grant information contains information about user consent and issued tokens.::
+
+ {
+ "type": "grant",
+ "scope": ["openid", "research_and_scholarship"],
+ "authorization_details": null,
+ "claims": {
+ "userinfo": {
+ "sub": null,
+ "name": null,
+ "given_name": null,
+ "family_name": null,
+ "email": null,
+ "email_verified": null,
+ "eduperson_scoped_affiliation": null
+ }
+ },
+ "resources": ["client_1"],
+ "issued_at": 1605452123,
+ "not_before": 0,
+ "expires_at": 0,
+ "revoked": false,
+ "issued_token": [
+ {
+ "type": "authorization_code",
+ "issued_at": 1605452123,
+ "not_before": 0,
+ "expires_at": 1605452423,
+ "revoked": false,
+ "value": "Z0FBQUFBQmZzVUZieDFWZy1fbjE2ckxvZkFTVC1ZTHJIVlk0Z09tOVk1M0RsOVNDbkdfLTIxTUhILWs4T29kM1lmV015UEN1UGxrWkxLTkVXOEg1WVJLNjh3MGlhMVdSRWhYcUY4cGdBQkJEbzJUWUQ3UGxTUWlJVDNFUHFlb29PWUFKcjNXeHdRM1hDYzRIZnFrYjhVZnIyTFhvZ2Y0NUhjR1VBdzE0STVEWmJ3WkttTk1OYXQtTHNtdHJwYk1nWnl3MUJqSkdWZGFtdVNfY21VNXQxY3VzalpIczBWbGFueVk0TVZ2N2d2d0hVWTF4WG56TDJ6bz0=",
+ "usage_rules": {
+ "expires_in": 300,
+ "supports_minting": [
+ "access_token",
+ "refresh_token",
+ "id_token"
+ ],
+ "max_usage": 1
+ },
+ "used": 0,
+ "based_on": null,
+ "id": "96d19bea275211eba43bacde48001122"
+ },
+ {
+ "type": "access_token",
+ "issued_at": 1605452123,
+ "not_before": 0,
+ "expires_at": 1605452723,
+ "revoked": false,
+ "value": "Z0FBQUFBQmZzVUZiaWVRbi1IS2k0VW4wVDY1ZmJHeEVCR1hVODBaQXR6MWkzelNBRFpOS2tRM3p4WWY5Y1J6dk5IWWpnelRETGVpSG52b0d4RGhjOWphdWp4eW5xZEJwQzliaS16cXFCcmRFbVJqUldsR1Z3SHdTVVlWbkpHak54TmJaSTV2T3NEQ0Y1WFkxQkFyamZHbmd4V0RHQ3k1MVczYlYwakEyM010SGoyZk9tUVVxbWdYUzBvMmRRNVlZMUhRSnM4WFd2QzRkVmtWNVJ1aVdJSXQyWnpVTlRiZnMtcVhKTklGdzBzdDJ3RkRnc1A1UEw2Yz0=",
+ "usage_rules": {
+ "expires_in": 600,
+ },
+ "used": 0,
+ "based_on": "Z0FBQUFBQmZzVUZieDFWZy1fbjE2ckxvZkFTVC1ZTHJIVlk0Z09tOVk1M0RsOVNDbkdfLTIxTUhILWs4T29kM1lmV015UEN1UGxrWkxLTkVXOEg1WVJLNjh3MGlhMVdSRWhYcUY4cGdBQkJEbzJUWUQ3UGxTUWlJVDNFUHFlb29PWUFKcjNXeHdRM1hDYzRIZnFrYjhVZnIyTFhvZ2Y0NUhjR1VBdzE0STVEWmJ3WkttTk1OYXQtTHNtdHJwYk1nWnl3MUJqSkdWZGFtdVNfY21VNXQxY3VzalpIczBWbGFueVk0TVZ2N2d2d0hVWTF4WG56TDJ6bz0=",
+ "id": "96d1c840275211eba43bacde48001122"
+ }
+ ],
+ "id": "96d16d3c275211eba43bacde48001122"
+ }
+
+The parameters are described below
+
+scope
+:::::
+
+This is the scope that was chosen for this grant. Either by the user or by
+some rules that the Authorization Server runs by.
+
+authorization_details
+:::::::::::::::::::::
+
+Presently a place hold. But this is expected to be information on how the
+authorization was performed. What input was used and so on.
+
+claims
+::::::
+
+The set of claims that should be returned in different circumstances. The
+syntax that is defined in
+https://openid.net/specs/openid-connect-core-1_0.html#ClaimsParameter
+is used. With one addition, beside *userinfo* and *id_token* we have added
+*introspection*.
+
+resources
+:::::::::
+
+This are the resource servers and other entities that should be accepted
+as users of issued access tokens.
+
+issued_at
+:::::::::
+
+When the grant was created. Its value is a JSON number representing the number
+of seconds from 1970-01-01T0:0:0Z as measured in UTC until the date/time.
+
+not_before
+::::::::::
+If the usage of the grant should be delay, this is when it can start being used.
+Its value is a JSON number representing the number
+of seconds from 1970-01-01T0:0:0Z as measured in UTC until the date/time.
+
+expires_at
+::::::::::
+When the grant expires.
+Its value is a JSON number representing the number
+of seconds from 1970-01-01T0:0:0Z as measured in UTC until the date/time.
+
+revoked
+:::::::
+If the grant has been revoked.
+
+issued_token
+::::::::::::
+Tokens that has been issued based on this grant. There is no limitation
+as to which tokens can be issued. Though presently we only have:
+
+- authorization_code,
+- access_token and
+- refresh_token
+
+id
+::
+The grant identifier.
+
+Token
++++++
+.. _`Token`:
+
+As mention above there are presently only 3 token types that are defined:
+
+- authorization_code,
+- access_token and
+- refresh_token
+
+A token is described as follows::
+
+ {
+ "type": "authorization_code",
+ "issued_at": 1605452123,
+ "not_before": 0,
+ "expires_at": 1605452423,
+ "revoked": false,
+ "value": "Z0FBQUFBQmZzVUZieDFWZy1fbjE2ckxvZkFTVC1ZTHJIVlk0Z09tOVk1M0RsOVNDbkdfLTIxTUhILWs4T29kM1lmV015UEN1UGxrWkxLTkVXOEg1WVJLNjh3MGlhMVdSRWhYcUY4cGdBQkJEbzJUWUQ3UGxTUWlJVDNFUHFlb29PWUFKcjNXeHdRM1hDYzRIZnFrYjhVZnIyTFhvZ2Y0NUhjR1VBdzE0STVEWmJ3WkttTk1OYXQtTHNtdHJwYk1nWnl3MUJqSkdWZGFtdVNfY21VNXQxY3VzalpIczBWbGFueVk0TVZ2N2d2d0hVWTF4WG56TDJ6bz0=",
+ "usage_rules": {
+ "expires_in": 300,
+ "supports_minting": [
+ "access_token",
+ "refresh_token",
+ "id_token"
+ ],
+ "max_usage": 1
+ },
+ "used": 0,
+ "based_on": null,
+ "id": "96d19bea275211eba43bacde48001122"
+ }
+
+
+type
+::::
+The type of token.
+
+issued_at
+:::::::::
+When the token was created. Its value is a JSON number representing the number
+of seconds from 1970-01-01T0:0:0Z as measured in UTC until the date/time.
+
+not_before
+::::::::::
+If the start of the usage of the token is to be delay, this is until when.
+Its value is a JSON number representing the number
+of seconds from 1970-01-01T0:0:0Z as measured in UTC until the date/time.
+
+expires_at
+::::::::::
+When the token expires.
+Its value is a JSON number representing the number
+of seconds from 1970-01-01T0:0:0Z as measured in UTC until the date/time.
+
+revoked
+:::::::
+If the token has been revoked.
+
+value
+:::::
+This is the value that appears in OIDC protocol exchanges.
+
+usage_rules
+:::::::::::
+Rules as to how this token can be used:
+
+expires_in
+ Used to calculate expires_at
+
+supports_minting
+ The tokens types that can be minted based on this token. Typically a code
+ can be used to mint ID tokens and access and refresh tokens.
+
+max_usage
+ How many times this token can be used (being used is presently defined as
+ used to mint other tokens). An authorization_code token can according to
+ the OIDC standard only be used once but then to, in the same session,
+ mint more then one token.
+
+used
+::::
+How many times the token has been used
+
+based_on
+::::::::
+Reference to the token that was used to mint this token. Might be empty if the
+token was minted based on the grant it belongs to.
+
+id
+::
+Token identifier
+
+Session Info API
+----------------
+.. _`Session Info API`:
+
+add_subordinate
++++++++++++++++
+.. _`add_subordinate`:
+
+remove_subordinate
+++++++++++++++++++
+.. _`removed_subordinate`:
+
+revoke
+++++++
+.. _`revoke`:
+
+is_revoked
+++++++++++
+.. _`is_revoked`:
+
+to_json
++++++++
+.. _`to_json`:
+
+from_json
++++++++++
+.. _`from_json`:
+
+Grant API
+---------
+.. _`Grant API`:
+
+Token API
+---------
+.. _`Token API`:
+
+Session Manager API
+-------------------
+.. _`Session Manager API`:
+
+create_session
+++++++++++++++
+.. _create_session:
+
+Creating a new session is done by running the create_session method of
+the class SessionManager. The create_session methods takes the following
+arguments.
+
+authn_event
+ An AuthnEvent class instance that describes the authentication event.
+
+auth_req
+ The Authentication request
+
+client_id
+ The client Identifier
+
+user_id
+ The user identifier
+
+sector_identifier
+ A possible sector identifier to be used when constructing a pairwise
+ subject identifier
+
+sub_type
+ The type of subject identifier that should be constructed. It can either be
+ *pairwise* or *public*.
+
+So a typical command would look like this::
+
+
+ authn_event = create_authn_event(self.user_id)
+ session_manager.create_session(authn_event=authn_event, auth_req=auth_req,
+ user_id=self.user_id, client_id=client_id,
+ sub_type=sub_type, sector_identifier=sector_identifier)
+
+add_grant
++++++++++
+.. _add_grant:
+
+add_grant(self, user_id, client_id, **kwargs)
+
+find_token
+++++++++++
+.. _find_token:
+
+find_token(self, session_id, token_value)
+
+get_authentication_event
+++++++++++++++++++++++++
+.. _get_authentication_event:
+
+get_authentication_event(self, session_id)
+
+
+get_client_session_info
++++++++++++++++++++++++
+.. _get_client_session_info:
+
+get_client_session_info(self, session_id)
+
+get_grant_by_response_type
+++++++++++++++++++++++++++
+.. _get_grant_by_response_type:
+
+get_grant_by_response_type(self, user_id, client_id)
+
+get_session_info
+++++++++++++++++
+.. _get_session_info:
+
+get_session_info(self, session_id)
+
+get_session_info_by_token
++++++++++++++++++++++++++
+.. _get_session_info_by_token:
+
+get_session_info_by_token(self, token_value)
+
+get_sids_by_user_id
++++++++++++++++++++
+.. _get_sids_by_user_id:
+
+get_sids_by_user_id(self, user_id)
+
+get_user_info
++++++++++++++
+.. _get_user_info:
+
+get_user_info(self, uid)
+
+grants
+++++++
+.. _grants:
+
+grants(self, session_id)
+
+revoke_client_session
++++++++++++++++++++++
+.. _revoke_client_session:
+
+revoke_client_session(self, session_id)
+
+revoke_grant
+++++++++++++
+.. _revoke_grant:
+
+revoke_grant(self, session_id)
+
+revoke_token
+++++++++++++
+.. _revoke_token:
+
+revoke_token(self, session_id, token_value, recursive=False)
\ No newline at end of file
diff --git a/setup.py b/setup.py
index 40eb21d..daa1506 100755
--- a/setup.py
+++ b/setup.py
@@ -52,7 +52,7 @@ def run_tests(self):
packages=["oidcendpoint", 'oidcendpoint/oidc', 'oidcendpoint/authz',
'oidcendpoint/user_authn', 'oidcendpoint/user_info',
'oidcendpoint/oauth2', 'oidcendpoint/oidc/add_on',
- 'oidcendpoint/common'],
+ 'oidcendpoint/session', 'oidcendpoint/token'],
package_dir={"": "src"},
classifiers=[
"Development Status :: 4 - Beta",
diff --git a/src/oidcendpoint/__init__.py b/src/oidcendpoint/__init__.py
index 653aef2..2bb7941 100755
--- a/src/oidcendpoint/__init__.py
+++ b/src/oidcendpoint/__init__.py
@@ -1,7 +1,7 @@
import string
from secrets import choice
-__version__ = "1.1.1"
+__version__ = '2.0.0'
DEF_SIGN_ALG = {
"id_token": "RS256",
diff --git a/src/oidcendpoint/authn_event.py b/src/oidcendpoint/authn_event.py
index 952078f..d309195 100644
--- a/src/oidcendpoint/authn_event.py
+++ b/src/oidcendpoint/authn_event.py
@@ -1,20 +1,22 @@
from oidcmsg.message import SINGLE_OPTIONAL_INT
+from oidcmsg.message import SINGLE_OPTIONAL_STRING
from oidcmsg.message import SINGLE_REQUIRED_STRING
from oidcmsg.message import Message
from oidcmsg.time_util import time_sans_frac
DEFAULT_AUTHN_EXPIRES_IN = 3600
+
class AuthnEvent(Message):
c_param = {
"uid": SINGLE_REQUIRED_STRING,
- "salt": SINGLE_REQUIRED_STRING,
"authn_info": SINGLE_REQUIRED_STRING,
"authn_time": SINGLE_OPTIONAL_INT,
"valid_until": SINGLE_OPTIONAL_INT,
+ "sub": SINGLE_OPTIONAL_STRING
}
- def valid(self, now=0):
+ def is_valid(self, now=0):
if now:
return self["valid_until"] > now
else:
@@ -24,30 +26,42 @@ def expires_in(self):
return self["valid_until"] - time_sans_frac()
-def create_authn_event(uid, salt, authn_info=None, **kwargs):
+def create_authn_event(uid, authn_info=None, authn_time: int = 0,
+ valid_until: int = 0, expires_in: int = 0,
+ sub: str = "", **kwargs):
"""
- :param uid:
- :param salt:
- :param authn_info:
+ :param uid: User ID. This is the identifier used by the user DB
+ :param authn_time: When the authentication took place
+ :param authn_info: Information about the authentication
+ :param valid_until: Until when the authentication is valid
+ :param expires_in: How long before the authentication expires
+ :param sub: Subject identifier. The identifier for the user used between
+ the AS and the RP.
:param kwargs:
:return:
"""
- args = {"uid": uid, "salt": salt, "authn_info": authn_info}
+ args = {"uid": uid, "authn_info": authn_info}
+
+ if sub:
+ args["sub"] = sub
- try:
- args["authn_time"] = int(kwargs["authn_time"])
- except KeyError:
- try:
- args["authn_time"] = int(kwargs["timestamp"])
- except KeyError:
+ if authn_time:
+ args["authn_time"] = authn_time
+ else:
+ _ts = kwargs.get("timestamp")
+ if _ts:
+ args["authn_time"] = _ts
+ else:
args["authn_time"] = time_sans_frac()
- try:
- args["valid_until"] = kwargs["valid_until"]
- except KeyError:
- _expires_in = kwargs.get("expires_in", DEFAULT_AUTHN_EXPIRES_IN)
- args["valid_until"] = args["authn_time"] + _expires_in
+ if valid_until:
+ args["valid_until"] = valid_until
+ else:
+ if expires_in:
+ args["valid_until"] = args["authn_time"] + expires_in
+ else:
+ args["valid_until"] = args["authn_time"] + DEFAULT_AUTHN_EXPIRES_IN
return AuthnEvent(**args)
diff --git a/src/oidcendpoint/authz/__init__.py b/src/oidcendpoint/authz/__init__.py
index 3869cc2..2277df8 100755
--- a/src/oidcendpoint/authz/__init__.py
+++ b/src/oidcendpoint/authz/__init__.py
@@ -1,9 +1,13 @@
+import copy
import inspect
import logging
import sys
+from typing import Optional
+from typing import Union
-from oidcendpoint import sanitize
-from oidcendpoint.cookie import cookie_value
+from oidcmsg.message import Message
+
+from oidcendpoint.session.grant import Grant
logger = logging.getLogger(__name__)
@@ -11,53 +15,89 @@
class AuthzHandling(object):
""" Class that allow an entity to manage authorization """
- def __init__(self, endpoint_context, **kwargs):
+ def __init__(self, endpoint_context, grant_config=None, **kwargs):
self.endpoint_context = endpoint_context
self.cookie_dealer = endpoint_context.cookie_dealer
- self.permdb = {}
+ self.grant_config = grant_config or {}
self.kwargs = kwargs
- def __call__(self, *args, **kwargs):
- return ""
+ def usage_rules(self, client_id):
+ if "usage_rules" in self.grant_config:
+ _usage_rules = copy.deepcopy(self.grant_config["usage_rules"])
+ else:
+ _usage_rules = {}
- def set(self, uid, client_id, permission):
try:
- self.permdb[uid][client_id] = permission
+ _per_client = self.endpoint_context.cdb[client_id]["token_usage_rules"]
except KeyError:
- self.permdb[uid] = {client_id: permission}
-
- def permissions(self, cookie=None, **kwargs):
- if cookie is None:
- return None
+ pass
else:
- logger.debug("kwargs: %s" % sanitize(kwargs))
-
- val = self.cookie_dealer.get_cookie_value(cookie)
- if val is None:
- return None
+ if _usage_rules:
+ for _token_type, _rule in _usage_rules.items():
+ _pc = _per_client.get(_token_type)
+ if _pc:
+ _rule.update(_pc)
+ for _token_type, _rule in _per_client.items():
+ if _token_type not in _usage_rules:
+ _usage_rules[_token_type] = _rule
else:
- b64, _ts, typ = val
+ _usage_rules = _per_client
- info = cookie_value(b64)
- return self.get(info["sub"], info["client_id"])
+ return _usage_rules
- def get(self, uid, client_id):
+ def usage_rules_for(self, client_id, token_type):
+ _token_usage = self.usage_rules(client_id=client_id)
try:
- return self.permdb[uid][client_id]
+ return _token_usage[token_type]
except KeyError:
- return None
+ return {}
+
+ def __call__(self, session_id: str, request: Union[dict, Message],
+ resources: Optional[list] = None) -> Grant:
+ args = self.grant_config.copy()
+
+ scope = request.get("scope")
+ if scope:
+ args["scope"] = scope
+
+ claims = request.get("claims")
+ if claims:
+ if isinstance(request, Message):
+ claims = claims.to_dict()
+ args["claims"] = claims
+
+ session_info = self.endpoint_context.session_manager.get_session_info(
+ session_id=session_id, grant=True
+ )
+ grant = session_info["grant"]
+
+ for key, val in args.items():
+ if key == "expires_in":
+ grant.set_expires_at(val)
+ else:
+ setattr(grant, key, val)
+ if resources is None:
+ grant.resources = [session_info["client_id"]]
+ else:
+ grant.resources = resources
-class Implicit(AuthzHandling):
- def __init__(self, endpoint_context, permission="implicit"):
- AuthzHandling.__init__(self, endpoint_context)
- self.permission = permission
+ # This is where user consent should be handled
+ for interface in ["userinfo", "introspection", "id_token", "access_token"]:
+ grant.claims[interface] = self.endpoint_context.claims_interface.get_claims(
+ session_id=session_id, scopes=request["scope"], usage=interface
+ )
+ return grant
- def permissions(self, cookie=None, **kwargs):
- return self.permission
- def get(self, uid, client_id):
- return self.permission
+class Implicit(AuthzHandling):
+ def __call__(self, session_id: str, request: Union[dict, Message],
+ resources: Optional[list] = None) -> Grant:
+ args = self.grant_config.copy()
+ grant = self.endpoint_context.session_manager.get_grant(session_id=session_id)
+ for arg, val in args:
+ setattr(grant, arg, val)
+ return grant
def factory(msgtype, endpoint_context, **kwargs):
diff --git a/src/oidcendpoint/common/__init__.py b/src/oidcendpoint/common/__init__.py
deleted file mode 100644
index e69de29..0000000
diff --git a/src/oidcendpoint/common/authorization.py b/src/oidcendpoint/common/authorization.py
deleted file mode 100755
index 06f0158..0000000
--- a/src/oidcendpoint/common/authorization.py
+++ /dev/null
@@ -1,276 +0,0 @@
-import logging
-from urllib.parse import unquote
-from urllib.parse import urlencode
-from urllib.parse import urlparse
-
-from oidcmsg.exception import ParameterError
-from oidcmsg.exception import URIError
-from oidcmsg.oauth2 import AuthorizationErrorResponse
-from oidcmsg.oidc import AuthorizationResponse
-from oidcmsg.oidc import verified_claim_name
-
-from oidcendpoint import sanitize
-from oidcendpoint.exception import RedirectURIError
-from oidcendpoint.exception import UnknownClient
-from oidcendpoint.util import split_uri
-
-logger = logging.getLogger(__name__)
-
-FORM_POST = """
-
- Submit This Form
-
-
-
-
-"""
-
-
-def inputs(form_args):
- """
- Creates list of input elements
- """
- element = []
- html_field = ''
- for name, value in form_args.items():
- element.append(html_field.format(name, value))
- return "\n".join(element)
-
-
-def max_age(request):
- verified_request = verified_claim_name("request")
- return request.get(verified_request, {}).get("max_age") or request.get("max_age", 0)
-
-
-def verify_uri(endpoint_context, request, uri_type, client_id=None):
- """
- A redirect URI
- MUST NOT contain a fragment
- MAY contain query component
-
- :param endpoint_context: An EndpointContext instance
- :param request: The authorization request
- :param uri_type: redirect_uri or post_logout_redirect_uri
- :return: An error response if the redirect URI is faulty otherwise
- None
- """
- _cid = request.get("client_id", client_id)
-
- if not _cid:
- logger.error("No client id found")
- raise UnknownClient("No client_id provided")
- else:
- logger.debug("Client ID: {}".format(_cid))
-
- _redirect_uri = unquote(request[uri_type])
-
- part = urlparse(_redirect_uri)
- if part.fragment:
- raise URIError("Contains fragment")
-
- (_base, _query) = split_uri(_redirect_uri)
- # if _query:
- # _query = parse_qs(_query)
-
- match = False
- # Get the clients registered redirect uris
- client_info = endpoint_context.cdb.get(_cid, {})
- if not client_info:
- raise KeyError("No such client")
- logger.debug("Client info: {}".format(client_info))
- redirect_uris = client_info.get("{}s".format(uri_type))
- if not redirect_uris:
- if _cid not in endpoint_context.cdb:
- logger.debug("CIDs: {}".format(list(endpoint_context.cdb.keys())))
- raise KeyError("No such client")
- raise ValueError("No registered {}".format(uri_type))
- else:
- for regbase, rquery in redirect_uris:
- # The URI MUST exactly match one of the Redirection URI
- if _base == regbase:
- # every registered query component must exist in the uri
- if rquery:
- if not _query:
- raise ValueError("Missing query part")
-
- for key, vals in rquery.items():
- if key not in _query:
- raise ValueError('"{}" not in query part'.format(key))
-
- for val in vals:
- if val not in _query[key]:
- raise ValueError(
- "{}={} value not in query part".format(key, val)
- )
-
- # and vice versa, every query component in the uri
- # must be registered
- if _query:
- if not rquery:
- raise ValueError("No registered query part")
-
- for key, vals in _query.items():
- if key not in rquery:
- raise ValueError('"{}" extra in query part'.format(key))
- for val in vals:
- if val not in rquery[key]:
- raise ValueError(
- "Extra {}={} value in query part".format(key, val)
- )
- match = True
- break
- if not match:
- raise RedirectURIError("Doesn't match any registered uris")
-
-
-def join_query(base, query):
- """
-
- :param base: URL base
- :param query: query part as a dictionary
- :return:
- """
- if query:
- return "{}?{}".format(base, urlencode(query, doseq=True))
- else:
- return base
-
-
-def get_uri(endpoint_context, request, uri_type):
- """ verify that the redirect URI is reasonable.
-
- :param endpoint_context: An EndpointContext instance
- :param request: The Authorization request
- :param uri_type: 'redirect_uri' or 'post_logout_redirect_uri'
- :return: redirect_uri
- """
- uri = ""
-
- if uri_type in request:
- verify_uri(endpoint_context, request, uri_type)
- uri = request[uri_type]
- else:
- uris = "{}s".format(uri_type)
- client_id = str(request["client_id"])
- if client_id in endpoint_context.cdb:
- _specs = endpoint_context.cdb[client_id].get(uris)
- if not _specs:
- raise ParameterError("Missing {} and none registered".format(uri_type))
-
- if len(_specs) > 1:
- raise ParameterError(
- "Missing {} and more than one registered".format(uri_type)
- )
-
- uri = join_query(*_specs[0])
-
- return uri
-
-
-def authn_args_gather(request, authn_class_ref, cinfo, **kwargs):
- """
- Gather information to be used by the authentication method
- """
- authn_args = {
- "authn_class_ref": authn_class_ref,
- "query": request.to_urlencoded(),
- "return_uri": request["redirect_uri"],
- }
-
- if "req_user" in kwargs:
- authn_args["as_user"] = (kwargs["req_user"],)
-
- # Below are OIDC specific. Just ignore if OAuth2
- for attr in ["policy_uri", "logo_uri", "tos_uri"]:
- if cinfo.get(attr):
- authn_args[attr] = cinfo[attr]
-
- for attr in ["ui_locales", "acr_values", "login_hint"]:
- if request.get(attr):
- authn_args[attr] = request[attr]
-
- return authn_args
-
-
-def create_authn_response(endpoint, request, sid):
- """
-
- :param endpoint:
- :param request:
- :param sid:
- :return:
- """
- # create the response
- aresp = AuthorizationResponse()
- if request.get("state"):
- aresp["state"] = request["state"]
-
- if "response_type" in request and request["response_type"] == ["none"]:
- fragment_enc = False
- else:
- _context = endpoint.endpoint_context
- _sinfo = _context.sdb[sid]
-
- if request.get("scope"):
- aresp["scope"] = request["scope"]
-
- rtype = set(request["response_type"][:])
- handled_response_type = []
-
- fragment_enc = True
- if len(rtype) == 1 and "code" in rtype:
- fragment_enc = False
-
- if "code" in request["response_type"]:
- _code = aresp["code"] = _context.sdb[sid]["code"]
- handled_response_type.append("code")
- else:
- _context.sdb.update(sid, code=None)
- _code = None
-
- if "token" in rtype:
- _dic = _context.sdb.upgrade_to_token(issue_refresh=False, key=sid)
-
- logger.debug("_dic: %s" % sanitize(_dic))
- for key, val in _dic.items():
- if key in aresp.parameters() and val is not None:
- aresp[key] = val
-
- handled_response_type.append("token")
-
- _access_token = aresp.get("access_token", None)
-
- not_handled = rtype.difference(handled_response_type)
- if not_handled:
- resp = AuthorizationErrorResponse(
- error="invalid_request", error_description="unsupported_response_type"
- )
- return {"response_args": resp, "fragment_enc": fragment_enc}
-
- return {"response_args": aresp, "fragment_enc": fragment_enc}
-
-
-class AllowedAlgorithms:
- def __init__(self, algorithm_parameters):
- self.algorithm_parameters = algorithm_parameters
-
- def __call__(self, client_id, endpoint_context, alg, alg_type):
- _cinfo = endpoint_context.cdb[client_id]
- _pinfo = endpoint_context.provider_info
-
- _reg, _sup = self.algorithm_parameters[alg_type]
- _allowed = _cinfo.get(_reg)
- if _allowed is None:
- _allowed = _pinfo.get(_sup)
-
- if alg not in _allowed:
- logger.error(
- "Signing alg user: {} not among allowed: {}".format(alg, _allowed)
- )
- raise ValueError("Not allowed '%s' algorithm used", alg)
-
-
-def re_authenticate(request, authn):
- return False
diff --git a/src/oidcendpoint/cookie.py b/src/oidcendpoint/cookie.py
index bbbde9f..fa6b6cc 100755
--- a/src/oidcendpoint/cookie.py
+++ b/src/oidcendpoint/cookie.py
@@ -202,7 +202,7 @@ def make_cookie_content(
:param max_age: The time in seconds for when a cookie will be deleted
:type max_age: int
:param secure: A secure cookie is only sent to the server with an encrypted request over the
- HTTPS protocol.
+ HTTPS protocol.
:type secure: boolean
:param http_only: HttpOnly cookies are inaccessible to JavaScript's Document.cookie API
:type http_only: boolean
@@ -591,4 +591,7 @@ def new_cookie(endpoint_context, cookie_name=None, typ="sso", **kwargs):
def cookie_value(b64):
- return json.loads(as_unicode(b64d(as_bytes(b64))))
+ try:
+ return json.loads(as_unicode(b64d(as_bytes(b64))))
+ except Exception:
+ return b64
diff --git a/src/oidcendpoint/endpoint.py b/src/oidcendpoint/endpoint.py
index b853c22..58a882f 100755
--- a/src/oidcendpoint/endpoint.py
+++ b/src/oidcendpoint/endpoint.py
@@ -1,20 +1,22 @@
import logging
from functools import cmp_to_key
+from typing import Optional
+from typing import Union
from urllib.parse import urlparse
from cryptojwt import jwe
from cryptojwt.jws.jws import SIGNER_ALGS
-from oidcendpoint.token_handler import UnknownToken
from oidcmsg.exception import MissingRequiredAttribute
from oidcmsg.exception import MissingRequiredValue
from oidcmsg.message import Message
-from oidcmsg.oauth2 import ResponseMessage, AuthorizationErrorResponse
+from oidcmsg.oauth2 import ResponseMessage
from oidcendpoint import sanitize
from oidcendpoint.client_authn import UnknownOrNoAuthnMethod
from oidcendpoint.client_authn import client_auth_setup
from oidcendpoint.client_authn import verify_client
from oidcendpoint.exception import UnAuthorizedClient
+from oidcendpoint.token.exception import UnknownToken
from oidcendpoint.util import OAUTH2_NOCACHE_HEADERS
__author__ = "Roland Hedberg"
@@ -38,24 +40,25 @@
- post_construct (*)
- update_http_args
-do_response returns a dictionary that can look like this:
-{
- 'response':
- _response as a string or as a Message instance_
- 'http_headers': [
- ('Content-type', 'application/json'),
- ('Pragma', 'no-cache'),
- ('Cache-Control', 'no-store')
- ],
- 'cookie': _list of cookies_,
- 'response_placement': 'body'
-}
+do_response returns a dictionary that can look like this::
+
+ {
+ 'response':
+ _response as a string or as a Message instance_
+ 'http_headers': [
+ ('Content-type', 'application/json'),
+ ('Pragma', 'no-cache'),
+ ('Cache-Control', 'no-store')
+ ],
+ 'cookie': _list of cookies_,
+ 'response_placement': 'body'
+ }
"response" MUST be present
"http_headers" MAY be present
"cookie": MAY be present
"response_placement": If absent defaults to the endpoints response_placement
- parameter value or if that is also missing 'url'
+parameter value or if that is also missing 'url'
"""
@@ -118,8 +121,8 @@ def construct_endpoint_info(default_capabilities, **kwargs):
elif "signing_alg_values_supported" in attr:
_info[attr] = assign_algorithms("signing_alg")
if (
- attr
- == "token_endpoint_auth_signing_alg_values_supported"
+ attr
+ == "token_endpoint_auth_signing_alg_values_supported"
):
# none must not be in
# token_endpoint_auth_signing_alg_values_supported
@@ -191,7 +194,7 @@ def __init__(self, endpoint_context, **kwargs):
if _methods:
self.client_authn_method = client_auth_setup(_methods, endpoint_context)
elif (
- _methods is not None
+ _methods is not None
): # [] or '' or something not None but regarded as nothing.
self.client_authn_method = [None] # Ignore default value
elif self.default_capabilities:
@@ -205,8 +208,9 @@ def __init__(self, endpoint_context, **kwargs):
# This is for matching against aud in JWTs
# By default the endpoint's endpoint URL is an allowed target
self.allowed_targets = [self.name]
+ self.client_verification_method = []
- def parse_request(self, request, auth=None, **kwargs):
+ def parse_request(self, request: str, auth=None, **kwargs):
"""
:param request: The request the server got
@@ -307,9 +311,9 @@ def client_authentication(self, request, auth=None, **kwargs):
LOGGER.debug("authn_info: %s", authn_info)
if (
- authn_info == {}
- and self.client_authn_method
- and len(self.client_authn_method)
+ authn_info == {}
+ and self.client_authn_method
+ and len(self.client_authn_method)
):
LOGGER.debug("client_authn_method: %s", self.client_authn_method)
raise UnAuthorizedClient("Authorization failed")
@@ -341,7 +345,7 @@ def do_post_construct(self, response_args, request, **kwargs):
return response_args
- def process_request(self, request=None, **kwargs):
+ def process_request(self, request: Optional[Union[Message, dict]] = None, **kwargs):
"""
:param request: The request, can be in a number of formats
@@ -349,7 +353,10 @@ def process_request(self, request=None, **kwargs):
"""
return {}
- def construct(self, response_args, request, **kwargs):
+ def construct(self,
+ response_args: Optional[dict] = None,
+ request: Optional[Union[Message, dict]] = None,
+ **kwargs):
"""
Construct the response
@@ -365,12 +372,20 @@ def construct(self, response_args, request, **kwargs):
return self.do_post_construct(response, request, **kwargs)
- def response_info(self, response_args, request, **kwargs):
+ def response_info(self,
+ response_args: Optional[dict] = None,
+ request: Optional[Union[Message, dict]] = None,
+ **kwargs) -> dict:
return self.construct(response_args, request, **kwargs)
- def do_response(self, response_args=None, request=None, error="", **kwargs):
+ def do_response(self,
+ response_args: Optional[dict] = None,
+ request: Optional[Union[Message, dict]] = None,
+ error: Optional[str] = "", **kwargs) -> dict:
"""
-
+ :param response_args: Information to use when constructing the response
+ :param request: The original request
+ :param error: Possible error encountered while processing the request
"""
do_placement = True
content_type = "text/html"
diff --git a/src/oidcendpoint/endpoint_context.py b/src/oidcendpoint/endpoint_context.py
index 10b6430..2abfa8c 100755
--- a/src/oidcendpoint/endpoint_context.py
+++ b/src/oidcendpoint/endpoint_context.py
@@ -5,17 +5,15 @@
from jinja2 import Environment
from jinja2 import FileSystemLoader
from oidcmsg.context import OidcContext
-from oidcmsg.oidc import IdToken
from oidcendpoint import authz
from oidcendpoint import rndstr
from oidcendpoint.id_token import IDToken
from oidcendpoint.scopes import SCOPE2CLAIMS
-from oidcendpoint.scopes import STANDARD_CLAIMS
-from oidcendpoint.scopes import Claims
from oidcendpoint.scopes import Scopes
-from oidcendpoint.session import create_session_db
-from oidcendpoint.sso_db import SSODb
+from oidcendpoint.session.claims import STANDARD_CLAIMS
+from oidcendpoint.session.claims import ClaimsInterface
+from oidcendpoint.session.manager import create_session_manager
from oidcendpoint.template_handler import Jinja2TemplateHandler
from oidcendpoint.user_authn.authn_context import populate_authn_broker
from oidcendpoint.util import allow_refresh_token
@@ -87,34 +85,28 @@ def get_token_handlers(conf):
class EndpointContext(OidcContext):
def __init__(
- self,
- conf,
- keyjar=None,
- cwd="",
- cookie_dealer=None,
- httpc=None,
- cookie_name=None,
- jwks_uri_path=None,
+ self,
+ conf,
+ keyjar=None,
+ cwd="",
+ cookie_dealer=None,
+ httpc=None,
):
OidcContext.__init__(self, conf, keyjar, entity_id=conf.get("issuer", ""))
self.conf = conf
# For my Dev environment
- self.sso_db = None
- self.session_db = None
- self.state_db = None
self.cdb = None
self.jti_db = None
self.registration_access_token = None
+ self.session_db = None
self.add_boxes(
{
- "state": "state_db",
"client": "cdb",
"jti": "jti_db",
"registration_access_token": "registration_access_token",
- "sso": "sso_db",
- "session": "session_db",
+ "session": "session_db"
},
self.db_conf,
)
@@ -134,7 +126,7 @@ def __init__(
self.jwks_uri = None
self.sso_ttl = 14400 # 4h
self.symkey = rndstr(24)
- self.id_token_schema = IdToken
+ # self.id_token_schema = IdToken
self.idtoken = None
self.authn_broker = None
self.authz = None
@@ -152,7 +144,7 @@ def __init__(
"sso_ttl",
"symkey",
"client_authn",
- "id_token_schema",
+ # "id_token_schema",
]:
try:
setattr(self, param, conf[param])
@@ -172,9 +164,7 @@ def __init__(
# has to be after the above
self.set_session_db()
- if cookie_name:
- self.cookie_name = cookie_name
- elif "cookie_name" in conf:
+ if "cookie_name" in conf:
self.cookie_name = conf["cookie_name"]
else:
self.cookie_name = {
@@ -196,11 +186,7 @@ def __init__(
self.template_handler = Jinja2TemplateHandler(loader)
self.setup = {}
- if not jwks_uri_path:
- try:
- jwks_uri_path = conf["keys"]["uri_path"]
- except KeyError:
- pass
+ jwks_uri_path = conf["keys"]["uri_path"]
try:
if self.issuer.endswith("/"):
@@ -248,7 +234,7 @@ def __init__(
self.httpc_params = {"verify": conf.get("verify_ssl")}
self.set_scopes_handler()
- self.set_claims_handler()
+ # self.set_claims_handler()
# If pushed authorization is supported
if "pushed_authorization_request_endpoint" in self.provider_info:
@@ -260,6 +246,8 @@ def __init__(
self.dev_auth_db = None
self.add_boxes({"dev_auth": "dev_auth_db"}, self.db_conf)
+ self.claims_interface = ClaimsInterface(self)
+
def set_scopes_handler(self):
_spec = self.conf.get("scopes_handler")
if _spec:
@@ -269,20 +257,19 @@ def set_scopes_handler(self):
else:
self.scopes_handler = Scopes()
- def set_claims_handler(self):
- _spec = self.conf.get("claims_handler")
- if _spec:
- _kwargs = _spec.get("kwargs", {})
- _cls = importer(_spec["class"])(**_kwargs)
- self.claims_handler = _cls(_kwargs)
- else:
- self.claims_handler = Claims()
+ # def set_claims_handler(self):
+ # _spec = self.conf.get("claims_handler")
+ # if _spec:
+ # _kwargs = _spec.get("kwargs", {})
+ # _cls = importer(_spec["class"])(**_kwargs)
+ # self.claims_handler = _cls(_kwargs)
+ # else:
+ # self.claims_handler = Claims()
def set_session_db(self):
- self.do_session_db(SSODb(db=self.sso_db), self.session_db)
+ self.do_session_manager()
# append userinfo db to the session db
self.do_userinfo()
- logger.debug("Session DB: {}".format(self.sdb.__dict__))
def do_add_on(self):
if self.conf.get("add_on"):
@@ -311,9 +298,9 @@ def do_login_hint_lookup(self):
def do_userinfo(self):
_conf = self.conf.get("userinfo")
if _conf:
- if self.sdb:
+ if self.session_manager:
self.userinfo = init_user_info(_conf, self.cwd)
- self.sdb.userinfo = self.userinfo
+ self.session_manager.userinfo = self.userinfo
else:
logger.warning("Cannot init_user_info if no session_db was provided.")
@@ -357,10 +344,15 @@ def do_sub_func(self):
else:
self._sub_func[key] = args["function"]
- def do_session_db(self, sso_db, db=None):
- self.sdb = create_session_db(
- self, self.th_args, db=db, sso_db=sso_db, sub_func=self._sub_func
- )
+ def do_session_manager(self, db=None):
+ if self.session_db is None:
+ self.session_manager = create_session_manager(
+ self, self.th_args, db=db, sub_func=self._sub_func
+ )
+ else:
+ self.session_manager = create_session_manager(
+ self, self.th_args, db=self.session_db, sub_func=self._sub_func
+ )
def do_endpoints(self):
self.endpoint = build_endpoints(
diff --git a/src/oidcendpoint/id_token.py b/src/oidcendpoint/id_token.py
index 844d2d3..cbc4647 100755
--- a/src/oidcendpoint/id_token.py
+++ b/src/oidcendpoint/id_token.py
@@ -1,11 +1,13 @@
import logging
+import uuid
from cryptojwt.jws.utils import left_hash
from cryptojwt.jwt import JWT
from oidcendpoint.endpoint import construct_endpoint_info
-from oidcendpoint.userinfo import collect_user_info
-from oidcendpoint.userinfo import userinfo_in_id_token_claims
+from oidcendpoint.session import unpack_session_key
+from oidcendpoint.session.claims import claims_match
+from oidcendpoint.session.info import SessionInfo
logger = logging.getLogger(__name__)
@@ -50,7 +52,7 @@ def include_session_id(endpoint_context, client_id, where):
def get_sign_and_encrypt_algorithms(
- endpoint_context, client_info, payload_type, sign=False, encrypt=False
+ endpoint_context, client_info, payload_type, sign=False, encrypt=False
):
args = {"sign": sign, "encrypt": encrypt}
if sign:
@@ -112,49 +114,49 @@ class IDToken(object):
def __init__(self, endpoint_context, **kwargs):
self.endpoint_context = endpoint_context
self.kwargs = kwargs
- self.enable_claims_per_client = kwargs.get("enable_claims_per_client", False)
- self.add_c_hash = kwargs.get("add_c_hash", True)
- self.add_at_hash = kwargs.get("add_at_hash", True)
self.scope_to_claims = None
self.provider_info = construct_endpoint_info(
self.default_capabilities, **kwargs
)
def payload(
- self,
- session,
- acr="",
- alg="RS256",
- code=None,
- access_token=None,
- user_info=None,
- auth_time=0,
- lifetime=None,
- extra_claims=None,
+ self,
+ session_id,
+ alg="RS256",
+ code=None,
+ access_token=None,
+ extra_claims=None,
):
"""
- :param session: Session information
- :param acr: Default Assurance/Authentication context class reference
+ :param session_id: Session identifier
:param alg: Which signing algorithm to use for the IdToken
:param code: Access grant
:param access_token: Access Token
- :param user_info: If user info are to be part of the IdToken
- :param auth_time:
- :param lifetime: Life time of the ID Token
:param extra_claims: extra claims to be added to the ID Token
:return: IDToken instance
"""
- _args = {"sub": session["sub"]}
-
- if lifetime is None:
- lifetime = DEF_LIFETIME
-
- if auth_time:
- _args["auth_time"] = auth_time
- if acr:
- _args["acr"] = acr
+ _mngr = self.endpoint_context.session_manager
+ session_information = _mngr.get_session_info(session_id, grant=True)
+ grant = session_information["grant"]
+ _args = {"sub": grant.sub}
+ if grant.authentication_event:
+ for claim, attr in {"authn_time": "auth_time", "authn_info": "acr"}.items():
+ _val = grant.authentication_event.get(claim)
+ if _val:
+ _args[attr] = _val
+
+ _claims_restriction = grant.claims.get("id_token")
+ if _claims_restriction == {}:
+ user_info = None
+ else:
+ user_info = self.endpoint_context.claims_interface.get_user_claims(
+ user_id=session_information["user_id"],
+ claims_restriction=_claims_restriction)
+ if _claims_restriction and "acr" in _claims_restriction and "acr" in _args:
+ if claims_match(_args["acr"], _claims_restriction["acr"]) is False:
+ raise ValueError("Could not match expected 'acr'")
if user_info:
try:
@@ -178,46 +180,37 @@ def payload(
halg = "HS%s" % alg[-3:]
if code:
_args["c_hash"] = left_hash(code.encode("utf-8"), halg)
- elif self.add_c_hash and session.get("code"):
- _args["c_hash"] = left_hash(
- session.get("code").encode("utf-8"), halg
- )
if access_token:
_args["at_hash"] = left_hash(access_token.encode("utf-8"), halg)
- elif self.add_at_hash and session.get("access_token"):
- _args["at_hash"] = left_hash(
- session.get("access_token").encode("utf-8"), halg
- )
- authn_req = session["authn_req"]
+ authn_req = grant.authorization_request
if authn_req:
try:
_args["nonce"] = authn_req["nonce"]
except KeyError:
pass
- return {"payload": _args, "lifetime": lifetime}
+ return _args
def sign_encrypt(
- self,
- session_info,
- client_id,
- code=None,
- access_token=None,
- user_info=None,
- sign=True,
- encrypt=False,
- lifetime=None,
- extra_claims=None,
+ self,
+ session_id,
+ client_id,
+ code=None,
+ access_token=None,
+ sign=True,
+ encrypt=False,
+ lifetime=None,
+ extra_claims=None,
):
"""
Signed and or encrypt a IDToken
+ :param lifetime: How long the ID Token should be valid
:param session_info: Session information
:param client_id: Client ID
:param code: Access grant
:param access_token: Access Token
- :param user_info: User information
:param sign: If the JWT should be signed
:param encrypt: If the JWT should be encrypted
:param extra_claims: Extra claims to be added to the ID Token
@@ -231,69 +224,52 @@ def sign_encrypt(
_cntx, client_info, "id_token", sign=sign, encrypt=encrypt
)
- _authn_event = session_info["authn_event"]
-
- _idt_info = self.payload(
- session_info,
- acr=_authn_event["authn_info"],
+ _payload = self.payload(
+ session_id=session_id,
alg=alg_dict["sign_alg"],
code=code,
access_token=access_token,
- user_info=user_info,
- auth_time=_authn_event["authn_time"],
- lifetime=lifetime,
- extra_claims=extra_claims,
+ extra_claims=extra_claims
)
+ if lifetime is None:
+ lifetime = DEF_LIFETIME
+
_jwt = JWT(
- _cntx.keyjar, iss=_cntx.issuer, lifetime=_idt_info["lifetime"], **alg_dict
+ _cntx.keyjar, iss=_cntx.issuer, lifetime=lifetime, **alg_dict
)
- return _jwt.pack(_idt_info["payload"], recv=client_id)
+ return _jwt.pack(_payload, recv=client_id)
- def make(self, req, sess_info, authn_req=None, user_claims=False, **kwargs):
+ def make(self, session_id, **kwargs):
_context = self.endpoint_context
- if authn_req:
- _client_id = authn_req["client_id"]
+ user_id, client_id, grant_id = unpack_session_key(session_id)
+
+ # Should I add session ID. This is about Single Logout.
+ if include_session_id(_context, client_id, "back") or include_session_id(
+ _context, client_id, "front"):
+
+ # Note that this session ID is not the session ID the session manager is using.
+ # It must be possible to map from one to the other.
+ logout_session_id = uuid.uuid4().get_hex()
+ _item = SessionInfo()
+ _item.set("user_id", user_id)
+ _item.set("client_id", client_id)
+ # Store the map
+ _mngr = self.endpoint_context.session_manager
+ _mngr.set([logout_session_id], _item)
+ # add identifier to extra arguments
+ xargs = {"sid": logout_session_id}
else:
- _client_id = req["client_id"]
-
- _cinfo = _context.cdb[_client_id]
+ xargs = {}
- idtoken_claims = dict(self.kwargs.get("available_claims", {}))
- if self.enable_claims_per_client:
- idtoken_claims.update(_cinfo.get("id_token_claims", {}))
lifetime = self.kwargs.get("lifetime")
- userinfo = userinfo_in_id_token_claims(_context, sess_info, idtoken_claims)
-
- if user_claims:
- info = collect_user_info(_context, sess_info)
- if userinfo is None:
- userinfo = info
- else:
- userinfo.update(info)
-
- # Should I add session ID
- req_sid = include_session_id(
- _context, _client_id, "back"
- ) or include_session_id(_context, _client_id, "front")
-
- if req_sid:
- xargs = {
- "sid": _context.sdb.get_sid_by_sub_and_client_id(
- sess_info["sub"], _client_id
- )
- }
- else:
- xargs = {}
-
return self.sign_encrypt(
- sess_info,
- _client_id,
+ session_id,
+ client_id,
sign=True,
- user_info=userinfo,
lifetime=lifetime,
extra_claims=xargs,
**kwargs
diff --git a/src/oidcendpoint/jwt_token.py b/src/oidcendpoint/jwt_token.py
deleted file mode 100644
index e3f4ec8..0000000
--- a/src/oidcendpoint/jwt_token.py
+++ /dev/null
@@ -1,173 +0,0 @@
-from typing import Any
-from typing import Dict
-from typing import Optional
-
-from cryptojwt import JWT
-from cryptojwt.jws.exception import JWSException
-
-from oidcendpoint.exception import ToOld
-from oidcendpoint.scopes import convert_scopes2claims
-from oidcendpoint.token_handler import Token
-from oidcendpoint.token_handler import UnknownToken
-from oidcendpoint.token_handler import is_expired
-
-
-class JWTToken(Token):
- init_args = {
- "add_claims_by_scope": False,
- "enable_claims_per_client": False,
- "add_scope": False,
- "add_claims": {},
- }
-
- def __init__(
- self,
- typ,
- keyjar=None,
- issuer=None,
- aud=None,
- alg="ES256",
- lifetime=300,
- ec=None,
- token_type="Bearer",
- **kwargs
- ):
- Token.__init__(self, typ, **kwargs)
- self.token_type = token_type
- self.lifetime = lifetime
- self.args = {
- (k, v) for k, v in kwargs.items() if k not in self.init_args.keys()
- }
-
- self.key_jar = keyjar or ec.keyjar
- self.issuer = issuer or ec.issuer
- self.cdb = ec.cdb
- self.cntx = ec
-
- self.def_aud = aud or []
- self.alg = alg
- self.scope_claims_map = kwargs.get("scope_claims_map", ec.scope2claims)
-
- self.add_claims = self.init_args["add_claims"]
- self.add_claims_by_scope = self.init_args["add_claims_by_scope"]
- self.add_scope = self.init_args["add_scope"]
- self.enable_claims_per_client = self.init_args["enable_claims_per_client"]
-
- for param, default in self.init_args.items():
- setattr(self, param, kwargs.get(param, default))
-
- def do_add_claims(self, payload, uinfo, claims):
- for attr in claims:
- if attr == "sub":
- continue
- try:
- payload[attr] = uinfo[attr]
- except KeyError:
- pass
-
- def __call__(
- self,
- sid: str,
- uinfo: Dict,
- sinfo: Dict,
- aud: Optional[Any],
- client_id: Optional[str],
- **kwargs
- ):
- """
- Return a token.
-
- :param sid: Session id
- :param uinfo: User information
- :param sinfo: Session information
- :param aud: audience
- :param client_id: client_id
- :return:
- """
- payload = {"sid": sid, "ttype": self.type, "sub": sinfo["sub"]}
- scopes = sinfo["authn_req"]["scope"]
-
- if self.add_claims:
- self.do_add_claims(payload, uinfo, self.add_claims)
- if self.add_claims_by_scope:
- _allowed_claims = self.cntx.claims_handler.allowed_claims(
- client_id, self.cntx
- )
- self.do_add_claims(
- payload,
- uinfo,
- convert_scopes2claims(
- scopes,
- _allowed_claims,
- map=self.scope_claims_map,
- ).keys(),
- )
- if self.add_scope:
- payload["scope"] = self.cntx.scopes_handler.filter_scopes(
- client_id, self.cntx, scopes
- )
-
- # Add claims if is access token
- if self.type == "T" and self.enable_claims_per_client:
- client = self.cdb.get(client_id, {})
- client_claims = client.get("access_token_claims")
- if client_claims:
- self.do_add_claims(payload, uinfo, client_claims)
-
- payload.update(kwargs)
- signer = JWT(
- key_jar=self.key_jar,
- iss=self.issuer,
- lifetime=self.lifetime,
- sign_alg=self.alg,
- )
-
- if aud is None:
- _aud = self.def_aud
- else:
- _aud = aud if isinstance(aud, list) else [aud]
- _aud.extend(self.def_aud)
-
- return signer.pack(payload, aud=_aud)
-
- def info(self, token):
- """
- Return type of Token (A=Access code, T=Token, R=Refresh token) and
- the session id.
-
- :param token: A token
- :return: tuple of token type and session id
- """
- verifier = JWT(key_jar=self.key_jar, allowed_sign_algs=[self.alg])
- try:
- _payload = verifier.unpack(token)
- except JWSException:
- raise UnknownToken()
-
- if is_expired(_payload["exp"]):
- raise ToOld("Token has expired")
- # All the token metadata
- _res = {
- "sid": _payload["sid"],
- "type": _payload["ttype"],
- "exp": _payload["exp"],
- "handler": self,
- }
- return _res
-
- def is_expired(self, token, when=0):
- """
- Evaluate whether the token has expired or not
-
- :param token: The token
- :param when: The time against which to check the expiration
- 0 means now.
- :return: True/False
- """
- verifier = JWT(key_jar=self.key_jar, allowed_sign_algs=[self.alg])
- _payload = verifier.unpack(token)
- return is_expired(_payload["exp"], when)
-
- def gather_args(self, sid, sdb, udb):
- _sinfo = sdb[sid]
- return {}
diff --git a/src/oidcendpoint/oauth2/authorization.py b/src/oidcendpoint/oauth2/authorization.py
index e7896f8..d4b006e 100755
--- a/src/oidcendpoint/oauth2/authorization.py
+++ b/src/oidcendpoint/oauth2/authorization.py
@@ -1,27 +1,26 @@
import json
import logging
-import time
+from typing import Union
+from urllib.parse import unquote
+from urllib.parse import urlencode
+from urllib.parse import urlparse
from cryptojwt import BadSyntax
+from cryptojwt import as_unicode
+from cryptojwt import b64d
+from cryptojwt.jwe.exception import JWEException
+from cryptojwt.jws.exception import NoSuitableSigningKeys
from cryptojwt.utils import as_bytes
-from cryptojwt.utils import as_unicode
-from cryptojwt.utils import b64d
from cryptojwt.utils import b64e
-from oidcendpoint.token_handler import UnknownToken
from oidcmsg import oauth2
from oidcmsg.exception import ParameterError
-from oidcmsg.oidc import AuthorizationResponse
+from oidcmsg.exception import URIError
+from oidcmsg.message import Message
from oidcmsg.oidc import verified_claim_name
+from oidcmsg.time_util import utc_time_sans_frac
from oidcendpoint import rndstr
-from oidcendpoint import sanitize
from oidcendpoint.authn_event import create_authn_event
-from oidcendpoint.common.authorization import FORM_POST
-from oidcendpoint.common.authorization import AllowedAlgorithms
-from oidcendpoint.common.authorization import authn_args_gather
-from oidcendpoint.common.authorization import get_uri
-from oidcendpoint.common.authorization import inputs
-from oidcendpoint.common.authorization import max_age
from oidcendpoint.cookie import append_cookie
from oidcendpoint.cookie import compute_session_state
from oidcendpoint.cookie import new_cookie
@@ -34,12 +33,14 @@
from oidcendpoint.exception import ToOld
from oidcendpoint.exception import UnAuthorizedClientScope
from oidcendpoint.exception import UnknownClient
-from oidcendpoint.session import setup_session
+from oidcendpoint.session import Revoked
+from oidcendpoint.session import unpack_session_key
+from oidcendpoint.token.exception import UnknownToken
from oidcendpoint.user_authn.authn_context import pick_auth
+from oidcendpoint.util import split_uri
logger = logging.getLogger(__name__)
-
# For the time being. This is JAR specific and should probably be configurable.
ALG_PARAMS = {
"sign": [
@@ -56,9 +57,195 @@
],
}
+FORM_POST = """
+
+ Submit This Form
+
+
+
+
+"""
+
+
+def inputs(form_args):
+ """
+ Creates list of input elements
+ """
+ element = []
+ html_field = ''
+ for name, value in form_args.items():
+ element.append(html_field.format(name, value))
+ return "\n".join(element)
+
+
+def max_age(request):
+ verified_request = verified_claim_name("request")
+ return request.get(verified_request, {}).get("max_age") or request.get("max_age", 0)
+
+
+def verify_uri(endpoint_context, request, uri_type, client_id=None):
+ """
+ A redirect URI
+ MUST NOT contain a fragment
+ MAY contain query component
+
+ :param endpoint_context: An EndpointContext instance
+ :param request: The authorization request
+ :param uri_type: redirect_uri or post_logout_redirect_uri
+ :return: An error response if the redirect URI is faulty otherwise
+ None
+ """
+ _cid = request.get("client_id", client_id)
+
+ if not _cid:
+ logger.error("No client id found")
+ raise UnknownClient("No client_id provided")
+ else:
+ logger.debug("Client ID: {}".format(_cid))
+
+ _redirect_uri = unquote(request[uri_type])
+
+ part = urlparse(_redirect_uri)
+ if part.fragment:
+ raise URIError("Contains fragment")
+
+ (_base, _query) = split_uri(_redirect_uri)
+ # if _query:
+ # _query = parse_qs(_query)
+
+ match = False
+ # Get the clients registered redirect uris
+ client_info = endpoint_context.cdb.get(_cid, {})
+ if not client_info:
+ raise KeyError("No such client")
+ logger.debug("Client info: {}".format(client_info))
+ redirect_uris = client_info.get("{}s".format(uri_type))
+ if not redirect_uris:
+ if _cid not in endpoint_context.cdb:
+ logger.debug("CIDs: {}".format(list(endpoint_context.cdb.keys())))
+ raise KeyError("No such client")
+ raise ValueError("No registered {}".format(uri_type))
+ else:
+ for regbase, rquery in redirect_uris:
+ # The URI MUST exactly match one of the Redirection URI
+ if _base == regbase:
+ # every registered query component must exist in the uri
+ if rquery:
+ if not _query:
+ raise ValueError("Missing query part")
+
+ for key, vals in rquery.items():
+ if key not in _query:
+ raise ValueError('"{}" not in query part'.format(key))
+
+ for val in vals:
+ if val not in _query[key]:
+ raise ValueError(
+ "{}={} value not in query part".format(key, val)
+ )
+
+ # and vice versa, every query component in the uri
+ # must be registered
+ if _query:
+ if not rquery:
+ raise ValueError("No registered query part")
+
+ for key, vals in _query.items():
+ if key not in rquery:
+ raise ValueError('"{}" extra in query part'.format(key))
+ for val in vals:
+ if val not in rquery[key]:
+ raise ValueError(
+ "Extra {}={} value in query part".format(key, val)
+ )
+ match = True
+ break
+ if not match:
+ raise RedirectURIError("Doesn't match any registered uris")
+
+
+def join_query(base, query):
+ """
+
+ :param base: URL base
+ :param query: query part as a dictionary
+ :return:
+ """
+ if query:
+ return "{}?{}".format(base, urlencode(query, doseq=True))
+ else:
+ return base
+
+
+def get_uri(endpoint_context, request, uri_type):
+ """ verify that the redirect URI is reasonable.
+
+ :param endpoint_context: An EndpointContext instance
+ :param request: The Authorization request
+ :param uri_type: 'redirect_uri' or 'post_logout_redirect_uri'
+ :return: redirect_uri
+ """
+ uri = ""
+
+ if uri_type in request:
+ verify_uri(endpoint_context, request, uri_type)
+ uri = request[uri_type]
+ else:
+ uris = "{}s".format(uri_type)
+ client_id = str(request["client_id"])
+ if client_id in endpoint_context.cdb:
+ _specs = endpoint_context.cdb[client_id].get(uris)
+ if not _specs:
+ raise ParameterError("Missing {} and none registered".format(uri_type))
+
+ if len(_specs) > 1:
+ raise ParameterError(
+ "Missing {} and more than one registered".format(uri_type)
+ )
-def re_authenticate(request, authn):
- return False
+ uri = join_query(*_specs[0])
+
+ return uri
+
+
+def authn_args_gather(request, authn_class_ref, cinfo, **kwargs):
+ """
+ Gather information to be used by the authentication method
+
+ :param request: The request either as a dictionary or as a Message instance
+ :param authn_class_ref: Authentication class reference
+ :param cinfo: Client information
+ :param kwargs: Extra keyword arguments
+ :return: Authentication arguments
+ """
+ authn_args = {
+ "authn_class_ref": authn_class_ref,
+ "return_uri": request["redirect_uri"],
+ }
+
+ if isinstance(request, Message):
+ authn_args["query"] = request.to_urlencoded()
+ elif isinstance(request, dict):
+ authn_args["query"] = urlencode(request)
+ else:
+ ValueError("Wrong request format")
+
+ if "req_user" in kwargs:
+ authn_args["as_user"] = (kwargs["req_user"],)
+
+ # Below are OIDC specific. Just ignore if OAuth2
+ if cinfo:
+ for attr in ["policy_uri", "logo_uri", "tos_uri"]:
+ if cinfo.get(attr):
+ authn_args[attr] = cinfo[attr]
+
+ for attr in ["ui_locales", "acr_values", "login_hint"]:
+ if request.get(attr):
+ authn_args[attr] = request[attr]
+
+ return authn_args
def check_unknown_scopes_policy(request_info, cinfo, endpoint_context):
@@ -69,7 +256,7 @@ def check_unknown_scopes_policy(request_info, cinfo, endpoint_context):
# this prevents that authz would be released for unavailable scopes
for scope in request_info['scope']:
if op_capabilities.get('deny_unknown_scopes') and \
- scope not in client_allowed_scopes:
+ scope not in client_allowed_scopes:
_msg = '{} requested an unauthorized scope ({})'
logger.warning(_msg.format(cinfo['client_id'],
scope))
@@ -100,27 +287,50 @@ class Authorization(Endpoint):
def __init__(self, endpoint_context, **kwargs):
Endpoint.__init__(self, endpoint_context, **kwargs)
- # self.pre_construct.append(self._pre_construct)
self.post_parse_request.append(self._do_request_uri)
self.post_parse_request.append(self._post_parse_request)
self.allowed_request_algorithms = AllowedAlgorithms(ALG_PARAMS)
- # Has to be done elsewhere. To make sure things happen in order.
- # self.scopes_supported = available_scopes(endpoint_context)
def filter_request(self, endpoint_context, req):
return req
+ def extra_response_args(self, aresp):
+ return aresp
+
def verify_response_type(self, request, cinfo):
# Checking response types
_registered = [set(rt.split(" ")) for rt in cinfo.get("response_types", [])]
if not _registered:
# If no response_type is registered by the client then we'll
- # code which it the default according to the OIDC spec.
+ # use code.
_registered = [{"code"}]
# Is the asked for response_type among those that are permitted
return set(request["response_type"]) in _registered
+ def mint_token(self, token_type, grant, session_id, based_on=None):
+ _mngr = self.endpoint_context.session_manager
+ usage_rules = grant.usage_rules.get(token_type, {})
+
+ token = grant.mint_token(
+ session_id=session_id,
+ endpoint_context=self.endpoint_context,
+ token_type=token_type,
+ based_on=based_on,
+ usage_rules=usage_rules
+ )
+
+ _exp_in = usage_rules.get("expires_in")
+ if isinstance(_exp_in, str):
+ _exp_in = int(_exp_in)
+ if _exp_in:
+ token.expires_at = utc_time_sans_frac() + _exp_in
+
+ self.endpoint_context.session_manager.set(unpack_session_key(session_id),
+ grant)
+
+ return token
+
def _do_request_uri(self, request, client_id, endpoint_context, **kwargs):
_request_uri = request.get("request_uri")
if _request_uri:
@@ -144,35 +354,41 @@ def _do_request_uri(self, request, client_id, endpoint_context, **kwargs):
if _registered:
# Before matching remove a possible fragment
_p = _request_uri.split("#")
- if _p[0] not in _registered:
+ # ignore registered fragments for now.
+ if _p[0] not in [l[0] for l in _registered]:
raise ValueError("A request_uri outside the registered")
+
# Fetch the request
_resp = endpoint_context.httpc.get(
_request_uri, **endpoint_context.httpc_params
)
if _resp.status_code == 200:
- args = {"keyjar": endpoint_context.keyjar}
- request = self.request_cls().from_jwt(_resp.text, **args)
+ args = {"keyjar": endpoint_context.keyjar, "issuer": client_id}
+ _ver_request = self.request_cls().from_jwt(_resp.text, **args)
self.allowed_request_algorithms(
client_id,
endpoint_context,
- request.jws_header.get("alg", "RS256"),
+ _ver_request.jws_header.get("alg", "RS256"),
"sign",
)
- if request.jwe_header is not None:
+ if _ver_request.jwe_header is not None:
self.allowed_request_algorithms(
client_id,
endpoint_context,
- request.jws_header.get("alg"),
+ _ver_request.jws_header.get("alg"),
"enc_alg",
)
self.allowed_request_algorithms(
client_id,
endpoint_context,
- request.jws_header.get("enc"),
+ _ver_request.jws_header.get("enc"),
"enc_enc",
)
- request[verified_claim_name("request")] = request
+ # The protected info overwrites the non-protected
+ for k, v in _ver_request.items():
+ request[k] = v
+
+ request[verified_claim_name("request")] = _ver_request
else:
raise ServiceError("Got a %s response", _resp.status)
@@ -218,7 +434,7 @@ def _post_parse_request(self, request, client_id, endpoint_context, **kwargs):
except (RedirectURIError, ParameterError, UnknownClient) as err:
return self.error_cls(
error="invalid_request",
- error_description="{}: {}".format(err.__class__.__name__, err),
+ error_description="{}:{}".format(err.__class__.__name__, err),
)
else:
request["redirect_uri"] = redirect_uri
@@ -292,15 +508,25 @@ def setup_auth(self, request, redirect_uri, cinfo, cookie, acr=None, **kwargs):
else:
identity = json.loads(as_unicode(_id))
- session = self.endpoint_context.sdb[identity.get("sid")]
- if not session or "revoked" in session:
+ try:
+ _csi = self.endpoint_context.session_manager[identity.get("sid")]
+ except Revoked:
identity = None
+ else:
+ if _csi.is_active() is False:
+ identity = None
authn_args = authn_args_gather(request, authn_class_ref, cinfo, **kwargs)
+ _mngr = self.endpoint_context.session_manager
+ _session_id = ""
# To authenticate or Not
if identity is None: # No!
logger.info("No active authentication")
+ logger.debug(
+ "Known clients: {}".format(list(self.endpoint_context.cdb.keys()))
+ )
+
if "prompt" in request and "none" in request["prompt"]:
# Need to authenticate but not allowed
return {
@@ -319,14 +545,7 @@ def setup_auth(self, request, redirect_uri, cinfo, cookie, acr=None, **kwargs):
# I get back a dictionary
user = identity["uid"]
if "req_user" in kwargs:
- sids = self.endpoint_context.sdb.get_sids_by_sub(kwargs["req_user"])
- if (
- sids
- and user
- != self.endpoint_context.sdb.get_authentication_event(
- sids[-1]
- ).uid
- ):
+ if user != kwargs["req_user"]:
logger.debug("Wanted to be someone else!")
if "prompt" in request and "none" in request["prompt"]:
# Need to authenticate but not allowed
@@ -337,74 +556,49 @@ def setup_auth(self, request, redirect_uri, cinfo, cookie, acr=None, **kwargs):
else:
return {"function": authn, "args": authn_args}
- authn_event = create_authn_event(
- identity["uid"],
- identity.get("salt", ""),
- authn_info=authn_class_ref,
- time_stamp=_ts,
- )
- if "valid_until" in authn_event:
- vu = time.time() + authn.kwargs.get("expires_in", 0.0)
- authn_event["valid_until"] = vu
+ if "sid" in identity:
+ _session_id = identity["sid"]
- return {"authn_event": authn_event, "identity": identity, "user": user}
+ # make sure the client is the same
+ _uid, _cid, _gid = unpack_session_key(_session_id)
+ if request["client_id"] != _cid:
+ return {"function": authn, "args": authn_args}
- def create_authn_response(self, request, sid):
- """
+ grant = _mngr[_session_id]
+ if grant.is_active() is False:
+ return {"function": authn, "args": authn_args}
+ elif request != grant.authorization_request:
+ authn_event = _mngr.get_authentication_event(session_id=_session_id)
+ if authn_event.is_valid() is False: # if not valid, do new login
+ return {"function": authn, "args": authn_args}
- :param self:
- :param request:
- :param sid:
- :return:
- """
- # create the response
- aresp = AuthorizationResponse()
- if request.get("state"):
- aresp["state"] = request["state"]
+ # create new grant
+ _session_id = _mngr.create_grant(authn_event=authn_event,
+ auth_req=request,
+ user_id=user,
+ client_id=request["client_id"])
- if "response_type" in request and request["response_type"] == ["none"]:
- fragment_enc = False
+ if _session_id:
+ authn_event = _mngr.get_authentication_event(session_id=_session_id)
+ if authn_event.is_valid() is False: # if not valid, do new login
+ return {"function": authn, "args": authn_args}
else:
- _context = self.endpoint_context
- _sinfo = _context.sdb[sid]
-
- if request.get("scope"):
- aresp["scope"] = request["scope"]
-
- rtype = set(request["response_type"][:])
- handled_response_type = []
-
- fragment_enc = True
- if len(rtype) == 1 and "code" in rtype:
- fragment_enc = False
-
- if "code" in request["response_type"]:
- _code = aresp["code"] = _context.sdb[sid]["code"]
- handled_response_type.append("code")
- else:
- _context.sdb.update(sid, code=None)
- _code = None
-
- if "token" in rtype:
- _dic = _context.sdb.upgrade_to_token(issue_refresh=False, key=sid)
-
- logger.debug("_dic: %s" % sanitize(_dic))
- for key, val in _dic.items():
- if key in aresp.parameters() and val is not None:
- aresp[key] = val
-
- handled_response_type.append("token")
+ authn_event = create_authn_event(
+ identity["uid"],
+ authn_info=authn_class_ref,
+ time_stamp=_ts,
+ )
+ _exp_in = authn.kwargs.get("expires_in")
+ if _exp_in and "valid_until" in authn_event:
+ authn_event["valid_until"] = utc_time_sans_frac() + _exp_in
- _access_token = aresp.get("access_token", None)
+ _token_usage_rules = self.endpoint_context.authz.usage_rules(
+ request["client_id"])
+ _session_id = _mngr.create_session(authn_event=authn_event, auth_req=request,
+ user_id=user, client_id=request["client_id"],
+ token_usage_rules=_token_usage_rules)
- not_handled = rtype.difference(handled_response_type)
- if not_handled:
- resp = self.error_cls(
- error="invalid_request", error_description="unsupported_response_type"
- )
- return {"response_args": resp, "fragment_enc": fragment_enc}
-
- return {"response_args": aresp, "fragment_enc": fragment_enc}
+ return {"session_id": _session_id, "identity": identity, "user": user}
def aresp_check(self, aresp, request):
return ""
@@ -441,29 +635,122 @@ def response_mode(self, request, **kwargs):
def error_response(self, response_info, error, error_description):
resp = self.error_cls(
- error=error, error_description=error_description
+ error=error, error_description=str(error_description)
)
response_info["response_args"] = resp
return response_info
- def post_authentication(self, user, request, sid, **kwargs):
+ def create_authn_response(self, request: Union[dict, Message], sid: str) -> dict:
"""
- Things that are done after a successful authentication.
- :param user:
:param request:
:param sid:
+ :return:
+ """
+ # create the response
+ aresp = self.response_cls()
+ if request.get("state"):
+ aresp["state"] = request["state"]
+
+ if "response_type" in request and request["response_type"] == ["none"]:
+ fragment_enc = False
+ else:
+ _context = self.endpoint_context
+ _mngr = self.endpoint_context.session_manager
+ _sinfo = _mngr.get_session_info(sid, grant=True)
+
+ if request.get("scope"):
+ aresp["scope"] = request["scope"]
+
+ rtype = set(request["response_type"][:])
+ handled_response_type = []
+
+ fragment_enc = True
+ if len(rtype) == 1 and "code" in rtype:
+ fragment_enc = False
+
+ grant = _sinfo["grant"]
+
+ if "code" in request["response_type"]:
+ _code = self.mint_token(
+ token_type='authorization_code',
+ grant=grant,
+ session_id= _sinfo["session_id"])
+ aresp["code"] = _code.value
+ handled_response_type.append("code")
+ else:
+ _code = None
+
+ if "token" in rtype:
+ if _code:
+ based_on = _code
+ else:
+ based_on = None
+
+ _access_token = self.mint_token(token_type="access_token",
+ grant=grant,
+ session_id=_sinfo["session_id"],
+ based_on=based_on)
+ aresp['access_token'] = _access_token.value
+ aresp['token_type'] = "Bearer"
+ if _access_token.expires_at:
+ aresp["expires_in"] = _access_token.expires_at - utc_time_sans_frac()
+ handled_response_type.append("token")
+ else:
+ _access_token = None
+
+ if "id_token" in request["response_type"]:
+ kwargs = {}
+ if {"code", "id_token", "token"}.issubset(rtype):
+ kwargs = {"code": _code.value, "access_token": _access_token.value}
+ elif {"code", "id_token"}.issubset(rtype):
+ kwargs = {"code": _code.value}
+ elif {"id_token", "token"}.issubset(rtype):
+ kwargs = {"access_token": _access_token.value}
+
+ try:
+ id_token = _context.idtoken.make(sid, **kwargs)
+ except (JWEException, NoSuitableSigningKeys) as err:
+ logger.warning(str(err))
+ resp = self.error_cls(
+ error="invalid_request",
+ error_description="Could not sign/encrypt id_token",
+ )
+ return {"response_args": resp, "fragment_enc": fragment_enc}
+
+ aresp["id_token"] = id_token
+ _mngr.update([_sinfo["user_id"], _sinfo["client_id"]],
+ {"id_token": id_token})
+ handled_response_type.append("id_token")
+
+ not_handled = rtype.difference(handled_response_type)
+ if not_handled:
+ resp = self.error_cls(
+ error="invalid_request", error_description="unsupported_response_type"
+ )
+ return {"response_args": resp, "fragment_enc": fragment_enc}
+
+ aresp = self.extra_response_args(aresp)
+
+ return {"response_args": aresp, "fragment_enc": fragment_enc}
+
+ def post_authentication(self, request: Union[dict, Message],
+ session_id: str, **kwargs) -> dict:
+ """
+ Things that are done after a successful authentication.
+
+ :param request: The authorization request
+ :param session_id: Session identifier
:param kwargs:
:return: A dictionary with 'response_args'
"""
response_info = {}
+ _mngr = self.endpoint_context.session_manager
# Do the authorization
try:
- permission = self.endpoint_context.authz(
- user, client_id=request["client_id"]
- )
+ grant = self.endpoint_context.authz(session_id, request=request)
except ToOld as err:
return self.error_response(
response_info,
@@ -475,8 +762,9 @@ def post_authentication(self, user, request, sid, **kwargs):
response_info, "access_denied", "{}".format(err.args)
)
else:
+ user_id, client_id, grant_id = unpack_session_key(session_id)
try:
- self.endpoint_context.sdb.update(sid, permission=permission)
+ _mngr.set([user_id, client_id, grant_id], grant)
except Exception as err:
return self.error_response(
response_info, "server_error", "{}".format(err.args)
@@ -484,12 +772,10 @@ def post_authentication(self, user, request, sid, **kwargs):
logger.debug("response type: %s" % request["response_type"])
- if self.endpoint_context.sdb.is_session_revoked(sid):
- return self.error_response(
- response_info, "access_denied", "Session is revoked"
- )
+ response_info = self.create_authn_response(request, session_id)
+ response_info["session_id"] = session_id
- response_info = self.create_authn_response(request, sid)
+ logger.debug("Known clients: {}".format(list(self.endpoint_context.cdb.keys())))
try:
redirect_uri = get_uri(self.endpoint_context, request, "redirect_uri")
@@ -507,10 +793,8 @@ def post_authentication(self, user, request, sid, **kwargs):
_cookie = new_cookie(
self.endpoint_context,
- sub=user,
- sid=sid,
- state=request["state"],
- client_id=request["client_id"],
+ sid=session_id,
+ state=request.get("state"),
cookie_name=self.endpoint_context.cookie_name["session"],
)
@@ -530,77 +814,79 @@ def post_authentication(self, user, request, sid, **kwargs):
return response_info
- def authz_part2(
- self,
- user,
- authn_event,
- request,
- subject_type=None,
- acr=None,
- salt=None,
- sector_id=None,
- **kwargs,
- ):
+ # def setup_client_session(self, user_id: str, request: dict) -> str:
+ # _mngr = self.endpoint_context.session_manager
+ # client_id = request['client_id']
+ #
+ # client_info = ClientSessionInfo(
+ # authorization_request=request,
+ # sub=_mngr.sub_func['public'](user_id, salt=_mngr.salt)
+ # )
+ #
+ # _mngr.set([user_id, client_id], client_info)
+ # return session_key(user_id, client_id)
+
+ def authz_part2(self, request, session_id, **kwargs):
"""
After the authentication this is where you should end up
:param user:
- :param authn_event: The Authorization Event
:param request: The Authorization Request
- :param subject_type: The subject_type
- :param acr: The acr
- :param salt: The salt used to produce the sub
- :param sector_id: The sector_id used to produce the sub
+ :param session_id: Session identifier
:param kwargs: possible other parameters
:return: A redirect to the redirect_uri of the client
"""
- sid = setup_session(
- self.endpoint_context,
- request,
- user,
- acr=acr,
- salt=salt,
- authn_event=authn_event,
- subject_type=subject_type,
- sector_id=sector_id,
- )
try:
- resp_info = self.post_authentication(user, request, sid, **kwargs)
+ resp_info = self.post_authentication(request, session_id, **kwargs)
except Exception as err:
return self.error_response({}, "server_error", err)
if "check_session_iframe" in self.endpoint_context.provider_info:
ec = self.endpoint_context
salt = rndstr()
- if not ec.sdb.is_session_revoked(sid):
- authn_event = ec.sdb.get_authentication_event(
- sid
- ) # use the last session
- _state = b64e(
- as_bytes(json.dumps({"authn_time": authn_event["authn_time"]}))
- )
+ try:
+ authn_event = ec.session_manager.get_authentication_event(session_id)
+ except KeyError:
+ return self.error_response({}, "server_error", "No such session")
+ else:
+ if authn_event.is_valid() is False:
+ return self.error_response({}, "server_error", "Authentication has timed out")
+ _state = b64e(
+ as_bytes(json.dumps({"authn_time": authn_event["authn_time"]}))
+ )
+
+ opbs_value = ''
+ if hasattr(ec.cookie_dealer, 'create_cookie'):
session_cookie = ec.cookie_dealer.create_cookie(
as_unicode(_state),
typ="session",
cookie_name=ec.cookie_name["session_management"],
+ same_site="None",
+ http_only=False,
)
opbs = session_cookie[ec.cookie_name["session_management"]]
-
+ opbs_value = opbs.value
+ else:
+ session_cookie = None
logger.debug(
- "compute_session_state: client_id=%s, origin=%s, opbs=%s, salt=%s",
- request["client_id"],
- resp_info["return_uri"],
- opbs.value,
- salt,
- )
+ "Failed to set Cookie, that's not configured in main configuration.")
+
+ logger.debug(
+ "compute_session_state: client_id=%s, origin=%s, opbs=%s, salt=%s",
+ request["client_id"],
+ resp_info["return_uri"],
+ opbs_value,
+ salt,
+ )
- _session_state = compute_session_state(
- opbs.value, salt, request["client_id"], resp_info["return_uri"]
- )
+ _session_state = compute_session_state(
+ opbs_value, salt, request["client_id"], resp_info["return_uri"]
+ )
+ if opbs_value and session_cookie:
if "cookie" in resp_info:
if isinstance(resp_info["cookie"], list):
resp_info["cookie"].append(session_cookie)
@@ -609,7 +895,7 @@ def authz_part2(
else:
resp_info["cookie"] = session_cookie
- resp_info["response_args"]["session_state"] = _session_state
+ resp_info["response_args"]["session_state"] = _session_state
# Mix-Up mitigation
resp_info["response_args"]["iss"] = self.endpoint_context.issuer
@@ -617,30 +903,35 @@ def authz_part2(
return resp_info
- def process_request(self, request_info=None, **kwargs):
+ def do_request_user(self, request_info, **kwargs):
+ return kwargs
+
+ def process_request(self, request: Union[Message, dict], **kwargs):
""" The AuthorizationRequest endpoint
- :param request_info: The authorization request as a dictionary
+ :param request: The authorization request as a Message instance
:return: dictionary
"""
- if isinstance(request_info, self.error_cls):
- return request_info
+ if isinstance(request, self.error_cls):
+ return request
- _cid = request_info["client_id"]
+ _cid = request["client_id"]
cinfo = self.endpoint_context.cdb[_cid]
-
logger.debug("client {}: {}".format(_cid, cinfo))
# this apply the default optionally deny_unknown_scopes policy
- check_unknown_scopes_policy(request_info, cinfo, self.endpoint_context)
+ if cinfo:
+ check_unknown_scopes_policy(request, cinfo, self.endpoint_context)
cookie = kwargs.get("cookie", "")
if cookie:
del kwargs["cookie"]
+ kwargs = self.do_request_user(request_info=request, **kwargs)
+
info = self.setup_auth(
- request_info, request_info["redirect_uri"], cinfo, cookie, **kwargs
+ request, request["redirect_uri"], cinfo, cookie, **kwargs
)
if "error" in info:
@@ -649,18 +940,81 @@ def process_request(self, request_info=None, **kwargs):
_function = info.get("function")
if not _function:
logger.debug("- authenticated -")
- logger.debug("AREQ keys: %s" % request_info.keys())
- res = self.authz_part2(
- info["user"], info["authn_event"], request_info, cookie=cookie
- )
- return res
+ logger.debug("AREQ keys: %s" % request.keys())
+ return self.authz_part2(request=request, cookie=cookie, **info)
try:
# Run the authentication function
return {
"http_response": _function(**info["args"]),
- "return_uri": request_info["redirect_uri"],
+ "return_uri": request["redirect_uri"],
}
except Exception as err:
logger.exception(err)
return {"http_response": "Internal error: {}".format(err)}
+
+
+class AllowedAlgorithms:
+ def __init__(self, algorithm_parameters):
+ self.algorithm_parameters = algorithm_parameters
+
+ def __call__(self, client_id, endpoint_context, alg, alg_type):
+ _cinfo = endpoint_context.cdb[client_id]
+ _pinfo = endpoint_context.provider_info
+
+ _reg, _sup = self.algorithm_parameters[alg_type]
+ _allowed = _cinfo.get(_reg)
+ if _allowed is None:
+ _allowed = _pinfo.get(_sup)
+
+ if alg not in _allowed:
+ logger.error(
+ "Signing alg user: {} not among allowed: {}".format(alg, _allowed)
+ )
+ raise ValueError("Not allowed '%s' algorithm used", alg)
+
+
+def re_authenticate(request, authn):
+ return False
+
+# class Authorization(authorization.Authorization):
+# request_cls = oauth2.AuthorizationRequest
+# response_cls = oauth2.AuthorizationResponse
+# error_cls = oauth2.AuthorizationErrorResponse
+# request_format = "urlencoded"
+# response_format = "urlencoded"
+# response_placement = "url"
+# endpoint_name = "authorization_endpoint"
+# name = "authorization"
+# default_capabilities = {
+# "claims_parameter_supported": True,
+# "request_parameter_supported": True,
+# "request_uri_parameter_supported": True,
+# "response_types_supported": ["code", "token", "code token"],
+# "response_modes_supported": ["query", "fragment", "form_post"],
+# "request_object_signing_alg_values_supported": None,
+# "request_object_encryption_alg_values_supported": None,
+# "request_object_encryption_enc_values_supported": None,
+# "grant_types_supported": ["authorization_code", "implicit"],
+# "scopes_supported": [],
+# }
+#
+# def __init__(self, endpoint_context, **kwargs):
+# authorization.Authorization.__init__(self, endpoint_context, **kwargs)
+# # self.pre_construct.append(self._pre_construct)
+# self.post_parse_request.append(self._do_request_uri)
+# self.post_parse_request.append(self._post_parse_request)
+# # Has to be done elsewhere. To make sure things happen in order.
+# # self.scopes_supported = available_scopes(endpoint_context)
+#
+# def setup_client_session(self, user_id: str, request: dict) -> str:
+# _mngr = self.endpoint_context.session_manager
+# client_id = request['client_id']
+#
+# client_info = ClientSessionInfo(
+# authorization_request=request,
+# sub=_mngr.sub_func['public'](user_id, salt=_mngr.salt)
+# )
+#
+# _mngr.set([user_id, client_id], client_info)
+# return session_key(user_id, client_id)
diff --git a/src/oidcendpoint/oauth2/introspection.py b/src/oidcendpoint/oauth2/introspection.py
index 4211bd8..4126488 100644
--- a/src/oidcendpoint/oauth2/introspection.py
+++ b/src/oidcendpoint/oauth2/introspection.py
@@ -1,11 +1,10 @@
"""Implements RFC7662"""
import logging
-from oidcendpoint.token_handler import UnknownToken
from oidcmsg import oauth2
-from oidcmsg.time_util import utc_time_sans_frac
from oidcendpoint.endpoint import Endpoint
+from oidcendpoint.token.exception import UnknownToken
LOGGER = logging.getLogger(__name__)
@@ -23,59 +22,36 @@ class Introspection(Endpoint):
def __init__(self, **kwargs):
Endpoint.__init__(self, **kwargs)
self.offset = kwargs.get("offset", 0)
- self.enable_claims_per_client = kwargs.get("enable_claims_per_client", False)
-
- def get_client_id_from_token(self, endpoint_context, token, request=None):
- """
- Will try to match tokens against information in the session DB.
-
- :param endpoint_context:
- :param token:
- :param request:
- :return: client_id if there was a match
- """
- sinfo = endpoint_context.sdb[token]
- return sinfo["authn_req"]["client_id"]
-
- def _get_client_claims(self, token):
- client_id = self.get_client_id_from_token(self.endpoint_context, token)
- client = self.endpoint_context.cdb.get(client_id, {})
- return client.get("introspection_claims")
-
- def _get_user_info(self, token_info):
- user_id = self.endpoint_context.sdb.sso_db.get_uid_by_sid(token_info["sid"])
- return self.endpoint_context.userinfo(user_id, client_id=None)
-
- def _add_claims(self, token_info, claims, payload):
- user_info = self._get_user_info(token_info)
- for attr in claims:
- try:
- payload[attr] = user_info[attr]
- except KeyError:
- pass
-
- def _introspect(self, token):
- try:
- info = self.endpoint_context.sdb[token]
- except (KeyError, UnknownToken):
- return None
+ def _introspect(self, token, client_id, grant):
# Make sure that the token is an access_token or a refresh_token
- if token != info.get("access_token") and token != info.get("refresh_token"):
+ if token.type not in ["access_token", "refresh_token"]:
return None
- eat = info.get("expires_at")
- if eat and eat < utc_time_sans_frac():
+ if not token.is_active():
return None
- if info: # Now what can be returned ?
- ret = info.to_dict()
- ret["iss"] = self.endpoint_context.issuer
-
- if "scope" not in ret:
- ret["scope"] = " ".join(info["authn_req"]["scope"])
-
- return ret
+ scope = token.scope
+ if not scope:
+ scope = grant.scope
+ aud = token.resources
+ if not aud:
+ aud = grant.resources
+
+ ret = {
+ "active": True,
+ "scope": " ".join(scope),
+ "client_id": client_id,
+ "token_type": token.type,
+ "exp": token.expires_at,
+ "iat": token.issued_at,
+ "sub": grant.sub,
+ "iss": self.endpoint_context.issuer
+ }
+ if aud:
+ ret["aud"] = aud
+
+ return ret
def process_request(self, request=None, **kwargs):
"""
@@ -88,29 +64,38 @@ def process_request(self, request=None, **kwargs):
if "error" in _introspect_request:
return _introspect_request
- _token = _introspect_request["token"]
+ request_token = _introspect_request["token"]
_resp = self.response_cls(active=False)
- _info = self._introspect(_token)
+ try:
+ _session_info = self.endpoint_context.session_manager.get_session_info_by_token(
+ request_token, grant=True)
+ except UnknownToken:
+ return {"response_args": _resp}
+
+ _grant = _session_info["grant"]
+ _token = _grant.get_token(request_token)
+
+ _info = self._introspect(_token, _session_info["client_id"], _grant)
if _info is None:
return {"response_args": _resp}
if "release" in self.kwargs:
if "username" in self.kwargs["release"]:
try:
- _info["username"] = self.endpoint_context.userinfo.search(
- sub=_info["sub"]
- )
+ _info["username"] = _session_info["user_id"]
except KeyError:
pass
_resp.update(_info)
_resp.weed()
- if self.enable_claims_per_client:
- client_claims = self._get_client_claims(_token)
- if client_claims:
- self._add_claims(_info, client_claims, _resp)
+ _claims_restriction = _grant.claims.get("introspection")
+ if _claims_restriction:
+ user_info = self.endpoint_context.claims_interface.get_user_claims(
+ _session_info["user_id"], _claims_restriction)
+ if user_info:
+ _resp.update(user_info)
_resp["active"] = True
diff --git a/src/oidcendpoint/oidc/add_on/pkce.py b/src/oidcendpoint/oidc/add_on/pkce.py
index 3548938..e760dd0 100644
--- a/src/oidcendpoint/oidc/add_on/pkce.py
+++ b/src/oidcendpoint/oidc/add_on/pkce.py
@@ -5,9 +5,6 @@
from oidcmsg.oauth2 import AuthorizationErrorResponse
from oidcmsg.oidc import TokenErrorResponse
-from oidcendpoint.exception import MultipleCodeUsage
-from oidcendpoint.token_handler import UnknownToken
-
LOGGER = logging.getLogger(__name__)
@@ -96,11 +93,14 @@ def post_token_parse(request, client_id, endpoint_context, **kwargs):
return request
try:
- _info = endpoint_context.sdb[request["code"]]
- except (KeyError, UnknownToken, MultipleCodeUsage):
- # This will be handled by process_request
- return request
- _authn_req = _info["authn_req"]
+ _session_info = endpoint_context.session_manager.get_session_info_by_token(
+ request["code"], grant=True)
+ except KeyError:
+ return TokenErrorResponse(
+ error="invalid_grant", error_description="Unknown access grant"
+ )
+
+ _authn_req = _session_info["grant"].authorization_request
if "code_challenge" in _authn_req:
if "code_verifier" not in request:
@@ -109,11 +109,11 @@ def post_token_parse(request, client_id, endpoint_context, **kwargs):
error_description="Missing code_verifier",
)
- _method = _info["authn_req"]["code_challenge_method"]
+ _method = _authn_req["code_challenge_method"]
if not verify_code_challenge(
request["code_verifier"],
- _info["authn_req"]["code_challenge"],
+ _authn_req["code_challenge"],
_method,
):
return TokenErrorResponse(
@@ -139,6 +139,7 @@ def add_pkce_support(endpoint, **kwargs):
return
authn_endpoint.post_parse_request.append(post_authn_parse)
+ token_endpoint.post_parse_request.append(post_token_parse)
if "essential" not in kwargs:
kwargs["essential"] = False
@@ -154,5 +155,3 @@ def add_pkce_support(endpoint, **kwargs):
kwargs["code_challenge_methods"][method] = CC_METHOD[method]
authn_endpoint.endpoint_context.args["pkce"] = kwargs
-
- token_endpoint.post_parse_request.append(post_token_parse)
diff --git a/src/oidcendpoint/oidc/authorization.py b/src/oidcendpoint/oidc/authorization.py
index dbb70fb..be49d57 100755
--- a/src/oidcendpoint/oidc/authorization.py
+++ b/src/oidcendpoint/oidc/authorization.py
@@ -1,43 +1,13 @@
-import json
import logging
+from urllib.parse import urlsplit
-from cryptojwt import BadSyntax
-from cryptojwt.jwe.exception import JWEException
-from cryptojwt.jws.exception import NoSuitableSigningKeys
-from cryptojwt.jwt import utc_time_sans_frac
-from cryptojwt.utils import as_bytes
-from cryptojwt.utils import as_unicode
-from cryptojwt.utils import b64d
-from cryptojwt.utils import b64e
-from oidcendpoint.token_handler import UnknownToken
from oidcmsg import oidc
-from oidcmsg.exception import ParameterError
from oidcmsg.oidc import Claims
from oidcmsg.oidc import verified_claim_name
-from oidcendpoint import rndstr
-from oidcendpoint import sanitize
-from oidcendpoint.authn_event import create_authn_event
-from oidcendpoint.common.authorization import FORM_POST
-from oidcendpoint.common.authorization import AllowedAlgorithms
-from oidcendpoint.common.authorization import authn_args_gather
-from oidcendpoint.common.authorization import get_uri
-from oidcendpoint.common.authorization import inputs
-from oidcendpoint.common.authorization import max_age
-from oidcendpoint.cookie import append_cookie
-from oidcendpoint.cookie import compute_session_state
-from oidcendpoint.cookie import new_cookie
-from oidcendpoint.endpoint import Endpoint
-from oidcendpoint.exception import InvalidRequest
-from oidcendpoint.exception import NoSuchAuthentication
-from oidcendpoint.exception import RedirectURIError
-from oidcendpoint.exception import ServiceError
-from oidcendpoint.exception import TamperAllert
-from oidcendpoint.exception import ToOld
-from oidcendpoint.exception import UnknownClient
-from oidcendpoint.oauth2.authorization import check_unknown_scopes_policy
-from oidcendpoint.session import setup_session
-from oidcendpoint.user_authn.authn_context import pick_auth
+from oidcendpoint.oauth2 import authorization
+from oidcendpoint.session import session_key
+from oidcendpoint.session.info import ClientSessionInfo
logger = logging.getLogger(__name__)
@@ -68,6 +38,11 @@ def acr_claims(request):
return acrdef["values"]
+def host_component(url):
+ res = urlsplit(url)
+ return "{}://{}".format(res.scheme, res.netloc)
+
+
ALG_PARAMS = {
"sign": [
"request_object_signing_alg",
@@ -92,7 +67,7 @@ def re_authenticate(request, authn):
return False
-class Authorization(Endpoint):
+class Authorization(authorization.Authorization):
request_cls = oidc.AuthorizationRequest
response_cls = oidc.AuthorizationResponse
error_cls = oidc.AuthorizationErrorResponse
@@ -123,601 +98,41 @@ class Authorization(Endpoint):
}
def __init__(self, endpoint_context, **kwargs):
- Endpoint.__init__(self, endpoint_context, **kwargs)
+ authorization.Authorization.__init__(self, endpoint_context, **kwargs)
# self.pre_construct.append(self._pre_construct)
self.post_parse_request.append(self._do_request_uri)
self.post_parse_request.append(self._post_parse_request)
- self.allowed_request_algorithms = AllowedAlgorithms(ALG_PARAMS)
-
- def filter_request(self, endpoint_context, req):
- return req
-
- def verify_response_type(self, request, cinfo):
- # Checking response types
- _registered = [set(rt.split(" ")) for rt in cinfo.get("response_types", [])]
- if not _registered:
- # If no response_type is registered by the client then we'll
- # code which it the default according to the OIDC spec.
- _registered = [{"code"}]
-
- # Is the asked for response_type among those that are permitted
- return set(request["response_type"]) in _registered
-
- def _do_request_uri(self, request, client_id, endpoint_context, **kwargs):
- _request_uri = request.get("request_uri")
- if _request_uri:
- # Do I do pushed authorization requests ?
- if "pushed_authorization" in endpoint_context.endpoint:
- # Is it a UUID urn
- if _request_uri.startswith("urn:uuid:"):
- _req = endpoint_context.par_db.get(_request_uri)
- if _req:
- del endpoint_context.par_db[_request_uri] # One time
- # usage
- return _req
- else:
- raise ValueError("Got a request_uri I can not resolve")
-
- # Do I support request_uri ?
- _supported = endpoint_context.provider_info.get(
- "request_uri_parameter_supported", True
- )
- _registered = endpoint_context.cdb[client_id].get("request_uris")
- # Not registered should be handled else where
- if _registered:
- # Before matching remove a possible fragment
- _p = _request_uri.split("#")
- # ignore registered fragments for now.
- if _p[0] not in [l[0] for l in _registered]:
- raise ValueError("A request_uri outside the registered")
-
- # Fetch the request
- _resp = endpoint_context.httpc.get(
- _request_uri, **endpoint_context.httpc_params
- )
- if _resp.status_code == 200:
- args = {"keyjar": endpoint_context.keyjar, "issuer": client_id}
- _ver_request = self.request_cls().from_jwt(_resp.text, **args)
- self.allowed_request_algorithms(
- client_id,
- endpoint_context,
- _ver_request.jws_header.get("alg", "RS256"),
- "sign",
- )
- if _ver_request.jwe_header is not None:
- self.allowed_request_algorithms(
- client_id,
- endpoint_context,
- _ver_request.jws_header.get("alg"),
- "enc_alg",
- )
- self.allowed_request_algorithms(
- client_id,
- endpoint_context,
- _ver_request.jws_header.get("enc"),
- "enc_enc",
- )
- # The protected info overwrites the non-protected
- for k, v in _ver_request.items():
- request[k] = v
-
- request[verified_claim_name("request")] = _ver_request
- else:
- raise ServiceError("Got a %s response", _resp.status)
-
- return request
-
- def _post_parse_request(self, request, client_id, endpoint_context, **kwargs):
- """
- Verify the authorization request.
-
- :param endpoint_context:
- :param request:
- :param client_id:
- :param kwargs:
- :return:
- """
- if not request:
- logger.debug("No AuthzRequest")
- return self.error_cls(
- error="invalid_request", error_description="Can not parse AuthzRequest"
- )
-
- request = self.filter_request(endpoint_context, request)
-
- _cinfo = endpoint_context.cdb.get(client_id)
- if not _cinfo:
- logger.error(
- "Client ID ({}) not in client database".format(request["client_id"])
- )
- return self.error_cls(
- error="unauthorized_client", error_description="unknown client"
- )
-
- # Is the asked for response_type among those that are permitted
- if not self.verify_response_type(request, _cinfo):
- return self.error_cls(
- error="invalid_request",
- error_description="Trying to use unregistered response_type",
- )
-
- # Get a verified redirect URI
- try:
- redirect_uri = get_uri(endpoint_context, request, "redirect_uri")
- except (RedirectURIError, ParameterError, UnknownClient) as err:
- return self.error_cls(
- error="invalid_request",
- error_description="{}:{}".format(err.__class__.__name__, err),
- )
- else:
- request["redirect_uri"] = redirect_uri
-
- return request
-
- def pick_authn_method(self, request, redirect_uri, acr=None, **kwargs):
- auth_id = kwargs.get("auth_method_id")
- if auth_id:
- return self.endpoint_context.authn_broker[auth_id]
-
- if acr:
- res = self.endpoint_context.authn_broker.pick(acr)
- else:
- res = pick_auth(self.endpoint_context, request)
-
- if res:
- return res
- else:
- return {
- "error": "access_denied",
- "error_description": "ACR I do not support",
- "return_uri": redirect_uri,
- "return_type": request["response_type"],
- }
-
- def setup_auth(self, request, redirect_uri, cinfo, cookie, acr=None, **kwargs):
- """
-
- :param request: The authorization/authentication request
- :param redirect_uri:
- :param cinfo: client info
- :param cookie:
- :param acr: Default ACR, if nothing else is specified
- :param kwargs:
- :return:
- """
-
- res = self.pick_authn_method(request, redirect_uri, acr, **kwargs)
-
- authn = res["method"]
- authn_class_ref = res["acr"]
- session = None
-
- try:
- _auth_info = kwargs.get("authn", "")
- if "upm_answer" in request and request["upm_answer"] == "true":
- _max_age = 0
- else:
- _max_age = max_age(request)
-
- identity, _ts = authn.authenticated_as(
- cookie, authorization=_auth_info, max_age=_max_age
- )
- except (NoSuchAuthentication, TamperAllert):
- identity = None
- _ts = 0
- except ToOld:
- logger.info("Too old authentication")
- identity = None
- _ts = 0
- except UnknownToken:
- logger.info("Unknown Token")
- identity = None
- _ts = 0
- else:
- if identity:
- try: # If identity['uid'] is in fact a base64 encoded JSON string
- _id = b64d(as_bytes(identity["uid"]))
- except BadSyntax:
- pass
- else:
- identity = json.loads(as_unicode(_id))
-
- try:
- session = self.endpoint_context.sdb[identity.get("sid")]
- except UnknownToken:
- identity= None
- else:
- if not session or "revoked" in session:
- identity = None
-
- authn_args = authn_args_gather(request, authn_class_ref, cinfo, **kwargs)
-
- # To authenticate or Not
- if identity is None: # No!
- logger.info("No active authentication")
- logger.debug(
- "Known clients: {}".format(list(self.endpoint_context.cdb.keys()))
- )
-
- if "prompt" in request and "none" in request["prompt"]:
- # Need to authenticate but not allowed
- return {
- "error": "login_required",
- "return_uri": redirect_uri,
- "return_type": request["response_type"],
- }
- else:
- return {"function": authn, "args": authn_args}
- else:
- logger.info("Active authentication")
- if re_authenticate(request, authn):
- # demand re-authentication
- return {"function": authn, "args": authn_args}
- else:
- # I get back a dictionary
- user = identity["uid"]
- if "req_user" in kwargs:
- sids = self.endpoint_context.sdb.get_sids_by_sub(kwargs["req_user"])
- if (
- sids
- and user
- != self.endpoint_context.sdb.get_authentication_event(
- sids[-1]
- ).uid
- ):
- logger.debug("Wanted to be someone else!")
- if "prompt" in request and "none" in request["prompt"]:
- # Need to authenticate but not allowed
- return {
- "error": "login_required",
- "return_uri": redirect_uri,
- }
- else:
- return {"function": authn, "args": authn_args}
-
- authn_event = None
- if session:
- authn_event = session.get('authn_event')
-
- if authn_event is None:
- authn_event = create_authn_event(
- identity["uid"],
- identity.get("salt", ""),
- authn_info=authn_class_ref,
- time_stamp=_ts,
- )
-
- _exp_in = authn.kwargs.get("expires_in")
- if _exp_in and "valid_until" in authn_event:
- authn_event["valid_until"] = utc_time_sans_frac() + _exp_in
-
- return {"authn_event": authn_event, "identity": identity, "user": user}
-
- def create_authn_response(self, request, sid):
- """
-
- :param self:
- :param request:
- :param sid:
- :return:
- """
- # create the response
- aresp = self.response_cls()
- if request.get("state"):
- aresp["state"] = request["state"]
-
- if "response_type" in request and request["response_type"] == ["none"]:
- fragment_enc = False
- else:
- _context = self.endpoint_context
- _sinfo = _context.sdb[sid]
-
- if request.get("scope"):
- aresp["scope"] = request["scope"]
-
- rtype = set(request["response_type"][:])
- handled_response_type = []
-
- fragment_enc = True
- if len(rtype) == 1 and "code" in rtype:
- fragment_enc = False
- if "code" in request["response_type"]:
- _code = aresp["code"] = _context.sdb[sid]["code"]
- handled_response_type.append("code")
- else:
- _context.sdb.update(sid, code=None)
- _code = None
+ def setup_client_session(self, user_id: str, request: dict) -> str:
+ _mngr = self.endpoint_context.session_manager
+ client_id = request['client_id']
- if "token" in rtype:
- _dic = _context.sdb.upgrade_to_token(issue_refresh=False, key=sid)
+ _client_info = self.endpoint_context.cdb[client_id]
+ sub_type = _client_info.get("subject_type")
+ if sub_type and sub_type == "pairwise":
+ sector_identifier_uri = _client_info.get("sector_identifier_uri")
+ if sector_identifier_uri is None:
+ sector_identifier_uri = host_component(_client_info["redirect_uris"][0])
- logger.debug("_dic: %s" % sanitize(_dic))
- for key, val in _dic.items():
- if key in aresp.parameters() and val is not None:
- aresp[key] = val
-
- handled_response_type.append("token")
-
- _access_token = aresp.get("access_token", None)
-
- if "id_token" in request["response_type"]:
- kwargs = {}
- if {"code", "id_token", "token"}.issubset(rtype):
- kwargs = {"code": _code, "access_token": _access_token}
- elif {"code", "id_token"}.issubset(rtype):
- kwargs = {"code": _code}
- elif {"id_token", "token"}.issubset(rtype):
- kwargs = {"access_token": _access_token}
-
- if request["response_type"] == ["id_token"]:
- kwargs["user_claims"] = True
-
- try:
- id_token = _context.idtoken.make(request, _sinfo, **kwargs)
- except (JWEException, NoSuitableSigningKeys) as err:
- logger.warning(str(err))
- resp = self.error_cls(
- error="invalid_request",
- error_description="Could not sign/encrypt id_token",
- )
- return {"response_args": resp, "fragment_enc": fragment_enc}
-
- aresp["id_token"] = id_token
- _sinfo["id_token"] = id_token
- handled_response_type.append("id_token")
-
- not_handled = rtype.difference(handled_response_type)
- if not_handled:
- resp = self.error_cls(
- error="invalid_request", error_description="unsupported_response_type"
- )
- return {"response_args": resp, "fragment_enc": fragment_enc}
-
- return {"response_args": aresp, "fragment_enc": fragment_enc}
-
- def aresp_check(self, aresp, request):
- return ""
-
- def response_mode(self, request, **kwargs):
- resp_mode = request["response_mode"]
- if resp_mode == "form_post":
- msg = FORM_POST.format(
- inputs=inputs(kwargs["response_args"].to_dict()),
- action=kwargs["return_uri"],
- )
- kwargs.update(
- {
- "response_msg": msg,
- "content_type": "text/html",
- "response_placement": "body",
- }
- )
- elif resp_mode == "fragment":
- if "fragment_enc" in kwargs:
- if not kwargs["fragment_enc"]:
- # Can't be done
- raise InvalidRequest("wrong response_mode")
- else:
- kwargs["fragment_enc"] = True
- elif resp_mode == "query":
- if "fragment_enc" in kwargs:
- if kwargs["fragment_enc"]:
- # Can't be done
- raise InvalidRequest("wrong response_mode")
- else:
- raise InvalidRequest("Unknown response_mode")
- return kwargs
-
- def error_response(self, response_info, error, error_description):
- resp = self.error_cls(
- error=error, error_description=str(error_description)
- )
- response_info["response_args"] = resp
- return response_info
-
- def post_authentication(self, user, request, sid, **kwargs):
- """
- Things that are done after a successful authentication.
-
- :param user:
- :param request:
- :param sid:
- :param kwargs:
- :return: A dictionary with 'response_args'
- """
-
- response_info = {}
-
- # Do the authorization
- try:
- permission = self.endpoint_context.authz(
- user, client_id=request["client_id"]
- )
- except ToOld as err:
- return self.error_response(
- response_info,
- "access_denied",
- "Authentication to old {}".format(err.args),
- )
- except Exception as err:
- return self.error_response(
- response_info, "access_denied", "{}".format(err.args)
+ client_info = ClientSessionInfo(
+ authorization_request=request,
+ sub=_mngr.sub_func[sub_type](user_id, salt=_mngr.salt,
+ sector_identifier=sector_identifier_uri)
)
else:
- try:
- self.endpoint_context.sdb.update(sid, permission=permission)
- except Exception as err:
- return self.error_response(
- response_info, "server_error", "{}".format(err.args)
- )
+ sub_type = self.kwargs.get("subject_type")
+ if not sub_type:
+ sub_type = "public"
- logger.debug("response type: %s" % request["response_type"])
-
- if self.endpoint_context.sdb.is_session_revoked(sid):
- return self.error_response(
- response_info, "access_denied", "Session is revoked"
+ client_info = ClientSessionInfo(
+ authorization_request=request,
+ sub=_mngr.sub_func[sub_type](user_id, salt=_mngr.salt)
)
- response_info = self.create_authn_response(request, sid)
-
- logger.debug("Known clients: {}".format(list(self.endpoint_context.cdb.keys())))
-
- try:
- redirect_uri = get_uri(self.endpoint_context, request, "redirect_uri")
- except (RedirectURIError, ParameterError) as err:
- return self.error_response(
- response_info, "invalid_request", "{}".format(err.args)
- )
- else:
- response_info["return_uri"] = redirect_uri
-
- # Must not use HTTP unless implicit grant type and native application
- # info = self.aresp_check(response_info['response_args'], request)
- # if isinstance(info, ResponseMessage):
- # return info
-
- _cookie = new_cookie(
- self.endpoint_context,
- uid=user,
- sid=sid,
- state=request["state"],
- client_id=request["client_id"],
- cookie_name=self.endpoint_context.cookie_name["session"],
- )
-
- # Now about the response_mode. Should not be set if it's obvious
- # from the response_type. Knows about 'query', 'fragment' and
- # 'form_post'.
-
- if "response_mode" in request:
- try:
- response_info = self.response_mode(request, **response_info)
- except InvalidRequest as err:
- return self.error_response(
- response_info, "invalid_request", "{}".format(err.args)
- )
-
- response_info["cookie"] = [_cookie]
-
- return response_info
-
- def authz_part2(
- self,
- user,
- authn_event,
- request,
- subject_type=None,
- acr=None,
- salt=None,
- sector_id=None,
- **kwargs,
- ):
- """
- After the authentication this is where you should end up
-
- :param user:
- :param authn_event: The Authorization Event
- :param request: The Authorization Request
- :param subject_type: The subject_type
- :param acr: The acr
- :param salt: The salt used to produce the sub
- :param sector_id: The sector_id used to produce the sub
- :param kwargs: possible other parameters
- :return: A redirect to the redirect_uri of the client
- """
- sid = setup_session(
- self.endpoint_context,
- request,
- user,
- acr=acr,
- salt=salt,
- authn_event=authn_event,
- subject_type=subject_type,
- sector_id=sector_id,
- )
-
- try:
- resp_info = self.post_authentication(user, request, sid, **kwargs)
- except Exception as err:
- return self.error_response({}, "server_error", err)
-
- if "check_session_iframe" in self.endpoint_context.provider_info:
- ec = self.endpoint_context
- salt = rndstr()
- if not ec.sdb.is_session_revoked(sid):
- authn_event = ec.sdb.get_authentication_event(
- sid
- ) # use the last session
- _state = b64e(
- as_bytes(json.dumps({"authn_time": authn_event["authn_time"]}))
- )
-
- opbs_value = ''
- if hasattr(ec.cookie_dealer, 'create_cookie'):
- session_cookie = ec.cookie_dealer.create_cookie(
- as_unicode(_state),
- typ="session",
- cookie_name=ec.cookie_name["session_management"],
- same_site="None",
- http_only=False,
- )
-
- opbs = session_cookie[ec.cookie_name["session_management"]]
- opbs_value = opbs.value
- else:
- logger.debug("Failed to set Cookie, that's not configured in main configuration.")
-
- logger.debug(
- "compute_session_state: client_id=%s, origin=%s, opbs=%s, salt=%s",
- request["client_id"],
- resp_info["return_uri"],
- opbs_value,
- salt,
- )
-
- _session_state = compute_session_state(
- opbs_value, salt, request["client_id"], resp_info["return_uri"]
- )
-
- if opbs_value:
- if "cookie" in resp_info:
- if isinstance(resp_info["cookie"], list):
- resp_info["cookie"].append(session_cookie)
- else:
- append_cookie(resp_info["cookie"], session_cookie)
- else:
- resp_info["cookie"] = session_cookie
-
- resp_info["response_args"]["session_state"] = _session_state
-
- # Mix-Up mitigation
- resp_info["response_args"]["iss"] = self.endpoint_context.issuer
- resp_info["response_args"]["client_id"] = request["client_id"]
-
- return resp_info
-
- def process_request(self, request_info=None, **kwargs):
- """ The AuthorizationRequest endpoint
-
- :param request_info: The authorization request as a dictionary
- :return: dictionary
- """
-
- if isinstance(request_info, self.error_cls):
- return request_info
-
- _cid = request_info["client_id"]
- cinfo = self.endpoint_context.cdb[_cid]
- logger.debug("client {}: {}".format(_cid, cinfo))
-
- # this apply the default optionally deny_unknown_scopes policy
- check_unknown_scopes_policy(request_info, cinfo, self.endpoint_context)
-
- cookie = kwargs.get("cookie", "")
- if cookie:
- del kwargs["cookie"]
+ _mngr.set([user_id, client_id], client_info)
+ return session_key(user_id, client_id)
+ def do_request_user(self, request_info, **kwargs):
if proposed_user(request_info):
kwargs["req_user"] = proposed_user(request_info)
else:
@@ -727,29 +142,4 @@ def process_request(self, request_info=None, **kwargs):
kwargs["req_user"] = self.endpoint_context.login_hint_lookup[
_login_hint
]
-
- info = self.setup_auth(
- request_info, request_info["redirect_uri"], cinfo, cookie, **kwargs
- )
-
- if "error" in info:
- return info
-
- _function = info.get("function")
- if not _function:
- logger.debug("- authenticated -")
- logger.debug("AREQ keys: %s" % request_info.keys())
- res = self.authz_part2(
- info["user"], info["authn_event"], request_info, cookie=cookie
- )
- return res
-
- try:
- # Run the authentication function
- return {
- "http_response": _function(**info["args"]),
- "return_uri": request_info["redirect_uri"],
- }
- except Exception as err:
- logger.exception(err)
- return {"http_response": "Internal error: {}".format(err)}
+ return kwargs
diff --git a/src/oidcendpoint/oidc/refresh_token.py b/src/oidcendpoint/oidc/refresh_token.py
deleted file mode 100755
index a2ffea1..0000000
--- a/src/oidcendpoint/oidc/refresh_token.py
+++ /dev/null
@@ -1,127 +0,0 @@
-import logging
-
-from oidcmsg import oidc
-from oidcmsg.oauth2 import ResponseMessage
-from oidcmsg.oidc import AccessTokenResponse
-from oidcmsg.oidc import RefreshAccessTokenRequest
-from oidcmsg.oidc import TokenErrorResponse
-
-from oidcendpoint import sanitize
-from oidcendpoint.client_authn import verify_client
-from oidcendpoint.cookie import new_cookie
-from oidcendpoint.endpoint import Endpoint
-from oidcendpoint.token_handler import ExpiredToken
-from oidcendpoint.userinfo import by_schema
-
-logger = logging.getLogger(__name__)
-
-
-class RefreshAccessToken(Endpoint):
- request_cls = oidc.RefreshAccessTokenRequest
- response_cls = oidc.AccessTokenResponse
- error_cls = TokenErrorResponse
- request_format = "json"
- request_placement = "body"
- response_format = "json"
- response_placement = "body"
- endpoint_name = "token_endpoint"
- name = "refresh_token"
-
- def __init__(self, endpoint_context, **kwargs):
- Endpoint.__init__(self, endpoint_context, **kwargs)
- self.post_parse_request.append(self._post_parse_request)
-
- def _refresh_access_token(self, req, **kwargs):
- _sdb = self.endpoint_context.sdb
-
- # client_id = str(req["client_id"])
-
- if req["grant_type"] != "refresh_token":
- return self.error_cls(
- error="invalid_request", error_description="Wrong grant_type"
- )
-
- rtoken = req["refresh_token"]
- try:
- _info = _sdb.refresh_token(rtoken)
- except ExpiredToken:
- return self.error_cls(
- error="invalid_request", error_description="Refresh token is expired"
- )
-
- return by_schema(AccessTokenResponse, **_info)
-
- def client_authentication(self, request, auth=None, **kwargs):
- """
- Deal with client authentication
-
- :param request: The refresh access token request
- :param auth: Client authentication information
- :param kwargs: Extra keyword arguments
- :return: dictionary containing client id, client authentication method
- and possibly access token.
- """
- try:
- auth_info = verify_client(self.endpoint_context, request, auth)
- msg = ""
- except Exception as err:
- msg = "Failed to verify client due to: {}".format(err)
- logger.error(msg)
- return self.error_cls(error="unauthorized_client", error_description=msg)
- else:
- if "client_id" not in auth_info:
- logger.error("No client_id, authentication failed")
- return self.error_cls(
- error="unauthorized_client", error_description="unknown client"
- )
-
- return auth_info
-
- def _post_parse_request(self, request, client_id="", **kwargs):
- """
- This is where clients come to refresh their access tokens
-
- :param request: The request
- :param authn: Authentication info, comes from HTTP header
- :returns:
- """
-
- request = RefreshAccessTokenRequest(**request.to_dict())
-
- try:
- keyjar = self.endpoint_context.keyjar
- except AttributeError:
- keyjar = ""
-
- request.verify(keyjar=keyjar, opponent_id=client_id)
-
- if "client_id" not in request: # Optional for refresh access token request
- request["client_id"] = client_id
-
- logger.debug("%s: %s" % (request.__class__.__name__, sanitize(request)))
-
- return request
-
- def process_request(self, request=None, **kwargs):
- """
-
- :param request:
- :param kwargs:
- :return: Dictionary with response information
- """
- response_args = self._refresh_access_token(request, **kwargs)
-
- if isinstance(response_args, ResponseMessage):
- return response_args
-
- _token = request["refresh_token"].replace(" ", "+")
- _cookie = new_cookie(
- self.endpoint_context,
- sub=self.endpoint_context.sdb[_token]["sub"],
- cookie_name=self.endpoint_context.cookie_name["session"],
- )
- _headers = [("Content-type", "application/json")]
- resp = {"response_args": response_args, "http_headers": _headers}
- if _cookie:
- resp["cookie"] = _cookie
- return resp
diff --git a/src/oidcendpoint/oidc/session.py b/src/oidcendpoint/oidc/session.py
index fcd9351..72ea8b4 100644
--- a/src/oidcendpoint/oidc/session.py
+++ b/src/oidcendpoint/oidc/session.py
@@ -6,11 +6,14 @@
from cryptojwt import as_unicode
from cryptojwt import b64d
+from cryptojwt.jwe.aes import AES_GCMEncrypter
+from cryptojwt.jwe.utils import split_ctx_and_tag
from cryptojwt.jws.exception import JWSException
from cryptojwt.jws.jws import factory
from cryptojwt.jws.utils import alg2keytype
from cryptojwt.jwt import JWT
from cryptojwt.utils import as_bytes
+from cryptojwt.utils import b64e
from oidcmsg.exception import InvalidRequest
from oidcmsg.message import Message
from oidcmsg.oauth2 import ResponseMessage
@@ -18,11 +21,13 @@
from oidcmsg.oidc.session import BACK_CHANNEL_LOGOUT_EVENT
from oidcmsg.oidc.session import EndSessionRequest
+from oidcendpoint import rndstr
from oidcendpoint.client_authn import UnknownOrNoAuthnMethod
-from oidcendpoint.common.authorization import verify_uri
from oidcendpoint.cookie import append_cookie
from oidcendpoint.endpoint import Endpoint
from oidcendpoint.endpoint_context import add_path
+from oidcendpoint.oauth2.authorization import verify_uri
+from oidcendpoint.session import session_key
logger = logging.getLogger(__name__)
@@ -85,13 +90,24 @@ def __init__(self, endpoint_context, **kwargs):
if _csi and not _csi.startswith("http"):
kwargs["check_session_iframe"] = add_path(endpoint_context.issuer, _csi)
Endpoint.__init__(self, endpoint_context, **kwargs)
+ self.iv = as_bytes(rndstr(24))
- def do_back_channel_logout(self, cinfo, sub, sid):
+ def _encrypt_sid(self, sid):
+ encrypter = AES_GCMEncrypter(key=as_bytes(self.endpoint_context.symkey))
+ enc_msg = encrypter.encrypt(as_bytes(sid), iv=self.iv)
+ return as_unicode(b64e(enc_msg))
+
+ def _decrypt_sid(self, enc_msg):
+ _msg = b64d(as_bytes(enc_msg))
+ encrypter = AES_GCMEncrypter(key=as_bytes(self.endpoint_context.symkey))
+ ctx, tag = split_ctx_and_tag(_msg)
+ return as_unicode(encrypter.decrypt(as_bytes(ctx), iv=self.iv, tag=as_bytes(tag)))
+
+ def do_back_channel_logout(self, cinfo, sid):
"""
:param cinfo: Client information
- :param sub: Subject identifier
- :param sid: The Issuer ID
+ :param sid: The session ID
:return: Tuple with logout URI and signed logout token
"""
@@ -102,10 +118,16 @@ def do_back_channel_logout(self, cinfo, sub, sid):
except KeyError:
return None
+ # Create the logout token
# always include sub and sid so I don't check for
# backchannel_logout_session_required
- payload = {"sub": sub, "sid": sid, "events": {BACK_CHANNEL_LOGOUT_EVENT: {}}}
+ enc_msg = self._encrypt_sid(sid)
+
+ payload = {
+ "sid": enc_msg,
+ "events": {BACK_CHANNEL_LOGOUT_EVENT: {}}
+ }
try:
alg = cinfo["id_token_signed_response_alg"]
@@ -114,59 +136,43 @@ def do_back_channel_logout(self, cinfo, sub, sid):
_jws = JWT(_cntx.keyjar, iss=_cntx.issuer, lifetime=86400, sign_alg=alg)
_jws.with_jti = True
- sjwt = _jws.pack(payload=payload, recv=cinfo["client_id"])
+ _logout_token = _jws.pack(payload=payload, recv=cinfo["client_id"])
- return back_channel_logout_uri, sjwt
+ return back_channel_logout_uri, _logout_token
def clean_sessions(self, usids):
- # Clean out all sessions
- _sdb = self.endpoint_context.sdb
- _sso_db = self.endpoint_context.sdb.sso_db
+ # Revoke all sessions
for sid in usids:
- # remove session information
- del _sdb[sid]
- _sso_db.remove_session_id(sid)
-
- def logout_all_clients(self, sid, client_id):
- _sdb = self.endpoint_context.sdb
- _sso_db = self.endpoint_context.sdb.sso_db
-
- # Find all RPs this user has logged it from
- uid = _sso_db.get_uid_by_sid(sid)
- if uid is None:
- logger.debug("Can not translate sid:%s into a user id", sid)
- return {}
-
- _client_sid = {}
- usids = _sso_db.get_sids_by_uid(uid)
- if usids is None:
- logger.debug("No sessions found for uid: %s", uid)
- return {}
+ self.endpoint_context.session_manager.revoke_client_session(sid)
- for usid in usids:
- _client_sid[_sdb[usid]["authn_req"]["client_id"]] = usid
+ def logout_all_clients(self, sid):
+ _mngr = self.endpoint_context.session_manager
+ _session_info = _mngr.get_session_info(sid, user_session_info=True,
+ client_session_info=True)
# Front-/Backchannel logout ?
_cdb = self.endpoint_context.cdb
_iss = self.endpoint_context.issuer
+ _user_id = _session_info["user_id"]
bc_logouts = {}
fc_iframes = {}
- for _cid, _csid in _client_sid.items():
- if "backchannel_logout_uri" in _cdb[_cid]:
- _sub = _sso_db.get_sub_by_sid(_csid)
- _spec = self.do_back_channel_logout(_cdb[_cid], _sub, _csid)
+ _rel_sid = []
+ for _client_id in _session_info["user_session_info"]["subordinate"]:
+ if "backchannel_logout_uri" in _cdb[_client_id]:
+ _sid = session_key(_user_id, _client_id)
+ _rel_sid.append(_sid)
+ _spec = self.do_back_channel_logout(_cdb[_client_id], _sid)
if _spec:
- bc_logouts[_cid] = _spec
- elif "frontchannel_logout_uri" in _cdb[_cid]:
+ bc_logouts[_client_id] = _spec
+ elif "frontchannel_logout_uri" in _cdb[_client_id]:
# Construct an IFrame
- _spec = do_front_channel_logout_iframe(_cdb[_cid], _iss, _csid)
+ _sid = session_key(_user_id, _client_id)
+ _rel_sid.append(_sid)
+ _spec = do_front_channel_logout_iframe(_cdb[_client_id], _iss, _sid)
if _spec:
- fc_iframes[_cid] = _spec
+ fc_iframes[_client_id] = _spec
- try:
- self.clean_sessions(usids)
- except KeyError:
- pass
+ self.clean_sessions(_rel_sid)
res = {}
if bc_logouts:
@@ -189,33 +195,26 @@ def unpack_signed_jwt(self, sjwt, sig_alg=""):
else:
raise ValueError("Not a signed JWT")
- def logout_from_client(self, sid, client_id):
+ def logout_from_client(self, sid):
_cdb = self.endpoint_context.cdb
- _sso_db = self.endpoint_context.sdb.sso_db
-
- # Kill the session
- _sdb = self.endpoint_context.sdb
- _sdb.revoke_session(sid=sid)
+ _session_information = self.endpoint_context.session_manager.get_session_info(
+ sid, grant=True)
+ _client_id = _session_information["client_id"]
res = {}
- if "backchannel_logout_uri" in _cdb[client_id]:
- _sub = _sso_db.get_sub_by_sid(sid)
- _spec = self.do_back_channel_logout(_cdb[client_id], _sub, sid)
+ if "backchannel_logout_uri" in _cdb[_client_id]:
+ _spec = self.do_back_channel_logout(_cdb[_client_id], sid)
if _spec:
- res["blu"] = {client_id: _spec}
- elif "frontchannel_logout_uri" in _cdb[client_id]:
+ res["blu"] = {_client_id: _spec}
+ elif "frontchannel_logout_uri" in _cdb[_client_id]:
# Construct an IFrame
_spec = do_front_channel_logout_iframe(
- _cdb[client_id], self.endpoint_context.issuer, sid
+ _cdb[_client_id], self.endpoint_context.issuer, sid
)
if _spec:
- res["flu"] = {client_id: _spec}
-
- try:
- self.clean_sessions([sid])
- except KeyError:
- pass
+ res["flu"] = {_client_id: _spec}
+ self.clean_sessions([sid])
return res
def process_request(self, request=None, cookie=None, **kwargs):
@@ -228,7 +227,7 @@ def process_request(self, request=None, cookie=None, **kwargs):
:return:
"""
_cntx = self.endpoint_context
- _sdb = _cntx.sdb
+ _mngr = _cntx.session_manager
if "post_logout_redirect_uri" in request:
if "id_token_hint" not in request:
@@ -249,63 +248,31 @@ def process_request(self, request=None, cookie=None, **kwargs):
# value is a base64 encoded JSON document
_cookie_info = json.loads(as_unicode(b64d(as_bytes(part[0]))))
logger.debug("Cookie info: {}".format(_cookie_info))
- _sid = _cookie_info["sid"]
+ try:
+ _session_info = _mngr.get_session_info(_cookie_info["sid"],
+ grant=True)
+ except KeyError:
+ raise ValueError("Can't find any corresponding session")
else:
logger.debug("No relevant cookie")
- _sid = ""
- _cookie_info = {}
+ raise ValueError("Missing cookie")
- if "id_token_hint" in request:
+ if "id_token_hint" in request and _session_info:
+ _id_token = request[verified_claim_name("id_token_hint")]
logger.debug(
- "ID token hint: {}".format(
- request[verified_claim_name("id_token_hint")]
- )
+ "ID token hint: {}".format(_id_token)
)
- auds = request[verified_claim_name("id_token_hint")]["aud"]
- _ith_sid = ""
- _sids = _sdb.sso_db.get_sids_by_sub(
- request[verified_claim_name("id_token_hint")]["sub"]
- )
-
- if _sids is None:
- raise ValueError("Unknown subject identifier")
-
- for _isid in _sids:
- if _sdb[_isid]["authn_req"]["client_id"] in auds:
- _ith_sid = _isid
- break
+ _aud = _id_token["aud"]
+ if _session_info["client_id"] not in _aud:
+ raise ValueError("Client ID doesn't match")
- if not _ith_sid:
- raise ValueError("Unknown subject")
-
- if _sid:
- if _ith_sid != _sid: # someone's messing with me
- raise ValueError("Wrong ID Token hint")
- else:
- _sid = _ith_sid
+ if _id_token["sub"] != _session_info["grant"].sub:
+ raise ValueError("Sub doesn't match")
else:
- auds = []
-
- try:
- session = _sdb[_sid]
- except KeyError:
- raise ValueError("Can't find any corresponding session")
-
- client_id = session["authn_req"]["client_id"]
- # Does this match what's in the cookie ?
- if _cookie_info:
- if client_id != _cookie_info["client_id"]:
- logger.warning(
- "Client ID in authz request and in cookie does not match"
- )
- raise ValueError("Wrong Client")
-
- if auds:
- if client_id not in auds:
- raise ValueError("Incorrect ID Token hint")
+ _aud = []
- _cinfo = _cntx.cdb[client_id]
+ _cinfo = _cntx.cdb[_session_info["client_id"]]
# verify that the post_logout_redirect_uri if present are among the ones
# registered
@@ -320,12 +287,11 @@ def process_request(self, request=None, cookie=None, **kwargs):
plur = False
else:
plur = True
- verify_uri(_cntx, request, "post_logout_redirect_uri", client_id=client_id)
+ verify_uri(_cntx, request, "post_logout_redirect_uri",
+ client_id=_session_info["client_id"])
payload = {
- "sid": _sid,
- "client_id": client_id,
- "user": session["authn_event"]["uid"],
+ "sid": _session_info["session_id"],
}
# redirect user to OP logout verification page
@@ -387,20 +353,20 @@ def parse_request(self, request, auth=None, **kwargs):
pass
else:
if (
- _ith.jws_header["alg"]
- not in self.endpoint_context.provider_info[
- "id_token_signing_alg_values_supported"
- ]
+ _ith.jws_header["alg"]
+ not in self.endpoint_context.provider_info[
+ "id_token_signing_alg_values_supported"
+ ]
):
raise JWSException("Unsupported signing algorithm")
return request
- def do_verified_logout(self, sid, client_id, alla=False, **kwargs):
+ def do_verified_logout(self, sid, alla=False, **kwargs):
if alla:
- _res = self.logout_all_clients(sid=sid, client_id=client_id)
+ _res = self.logout_all_clients(sid=sid)
else:
- _res = self.logout_from_client(sid=sid, client_id=client_id)
+ _res = self.logout_from_client(sid=sid)
bcl = _res.get("blu")
if bcl:
diff --git a/src/oidcendpoint/oidc/token.py b/src/oidcendpoint/oidc/token.py
index 74eb7f9..1dffa2c 100755
--- a/src/oidcendpoint/oidc/token.py
+++ b/src/oidcendpoint/oidc/token.py
@@ -1,45 +1,94 @@
import logging
+from typing import Optional
+from typing import Union
from cryptojwt.jwe.exception import JWEException
from cryptojwt.jws.exception import NoSuitableSigningKeys
+from cryptojwt.jwt import utc_time_sans_frac
from oidcmsg import oidc
+from oidcmsg.message import Message
from oidcmsg.oauth2 import ResponseMessage
-from oidcmsg.oidc import AccessTokenResponse
+from oidcmsg.oidc import RefreshAccessTokenRequest
from oidcmsg.oidc import TokenErrorResponse
+from oidcmsg.time_util import time_sans_frac
from oidcendpoint import sanitize
from oidcendpoint.cookie import new_cookie
from oidcendpoint.endpoint import Endpoint
-from oidcendpoint.exception import MultipleCodeUsage
-from oidcendpoint.token_handler import AccessCodeUsed
-from oidcendpoint.userinfo import by_schema
+from oidcendpoint.exception import ProcessError
+from oidcendpoint.session import unpack_session_key
+from oidcendpoint.session.grant import AuthorizationCode
+from oidcendpoint.session.grant import Grant
+from oidcendpoint.session.grant import RefreshToken
+from oidcendpoint.session.token import Token as sessionToken
+from oidcendpoint.token.exception import UnknownToken
+from oidcendpoint.util import importer
logger = logging.getLogger(__name__)
-class AccessToken(Endpoint):
- request_cls = oidc.AccessTokenRequest
- response_cls = oidc.AccessTokenResponse
- error_cls = TokenErrorResponse
- request_format = "json"
- request_placement = "body"
- response_format = "json"
- response_placement = "body"
- endpoint_name = "token_endpoint"
- name = "token"
- default_capabilities = {"token_endpoint_auth_signing_alg_values_supported": None}
+class TokenEndpointHelper(object):
+ def __init__(self, endpoint, config=None):
+ self.endpoint = endpoint
+ self.config = config
+ self.endpoint_context = endpoint.endpoint_context
+ self.error_cls = self.endpoint.error_cls
- def __init__(self, endpoint_context, **kwargs):
- Endpoint.__init__(self, endpoint_context, **kwargs)
- self.post_parse_request.append(self._post_parse_request)
- if "client_authn_method" in kwargs:
- self.endpoint_info["token_endpoint_auth_methods_supported"] = kwargs[
- "client_authn_method"
- ]
+ def post_parse_request(self, request: Union[Message, dict],
+ client_id: Optional[str] = "",
+ **kwargs):
+ """Context specific parsing of the request.
+ This is done after general request parsing and before processing
+ the request.
+ """
+ raise NotImplementedError
+
+ def process_request(self, req: Union[Message, dict], **kwargs):
+ """Acts on a process request."""
+ raise NotImplementedError
+
+ def _mint_token(self, token_type: str, grant: Grant, session_id: str,
+ based_on: Optional[sessionToken] = None) -> sessionToken:
+ _mngr = self.endpoint_context.session_manager
+ usage_rules = grant.usage_rules.get(token_type)
+ if usage_rules:
+ _exp_in = usage_rules.get("expires_in")
+ else:
+ _exp_in = 0
+
+ token = grant.mint_token(
+ session_id,
+ endpoint_context=self.endpoint_context,
+ token_type=token_type,
+ token_handler=_mngr.token_handler["access_token"],
+ based_on=based_on,
+ usage_rules=usage_rules
+ )
+
+ if _exp_in:
+ if isinstance(_exp_in, str):
+ _exp_in = int(_exp_in)
+
+ if _exp_in:
+ token.expires_at = time_sans_frac() + _exp_in
+
+ self.endpoint_context.session_manager.set(
+ unpack_session_key(session_id), grant)
+
+ return token
+
+
+class AccessTokenHelper(TokenEndpointHelper):
+ def process_request(self, req: Union[Message, dict], **kwargs):
+ """
- def _access_token(self, req, **kwargs):
- _context = self.endpoint_context
- _sdb = _context.sdb
+ :param req:
+ :param kwargs:
+ :return:
+ """
+ _context = self.endpoint.endpoint_context
+
+ _mngr = _context.session_manager
_log_debug = logger.debug
if req["grant_type"] != "authorization_code":
@@ -54,21 +103,11 @@ def _access_token(self, req, **kwargs):
error="invalid_request", error_description="Missing code"
)
- # Session might not exist or _access_code malformed
- try:
- _info = _sdb[_access_code]
- except KeyError:
- return self.error_cls(
- error="invalid_grant", error_description="Code is invalid"
- )
-
- _authn_req = _info["authn_req"]
+ _session_info = _mngr.get_session_info_by_token(_access_code, grant=True)
+ grant = _session_info["grant"]
- # assert that the code is valid
- if _context.sdb.is_session_revoked(_access_code):
- return self.error_cls(
- error="invalid_grant", error_description="Session is revoked"
- )
+ code = grant.get_token(_access_code)
+ _authn_req = grant.authorization_request
# If redirect_uri was in the initial authorization request
# verify that the one given here is the correct one.
@@ -83,26 +122,132 @@ def _access_token(self, req, **kwargs):
issue_refresh = False
if "issue_refresh" in kwargs:
issue_refresh = kwargs["issue_refresh"]
+ else:
+ if "offline_access" in grant.scope:
+ issue_refresh = True
+
+ _response = {
+ "token_type": "Bearer",
+ "scope": grant.scope,
+ }
+
+ token = self._mint_token(token_type="access_token",
+ grant=grant,
+ session_id=_session_info["session_id"],
+ based_on=code)
+ _response["access_token"] = token.value
+ _response["expires_in"] = token.expires_at - utc_time_sans_frac()
+
+ if issue_refresh:
+ refresh_token = self._mint_token(token_type="refresh_token",
+ grant=grant,
+ session_id=_session_info["session_id"],
+ based_on=code)
+ _response["refresh_token"] = refresh_token.value
+
+ code.register_usage()
+
+ # since the grant content has changed
+ _mngr[_session_info["session_id"]] = grant
+
+ if "openid" in _authn_req["scope"]:
+ try:
+ _idtoken = _context.idtoken.make(_session_info["session_id"])
+ except (JWEException, NoSuitableSigningKeys) as err:
+ logger.warning(str(err))
+ resp = self.error_cls(
+ error="invalid_request",
+ error_description="Could not sign/encrypt id_token",
+ )
+ return resp
+
+ _response["id_token"] = _idtoken
- # offline_access the default if nothing is specified
- permissions = _info.get("permission", ["offline_access"])
+ return _response
- if "offline_access" in _authn_req["scope"] and "offline_access" in permissions:
- issue_refresh = True
+ def post_parse_request(self, request: Union[Message, dict],
+ client_id: Optional[str] = "",
+ **kwargs):
+ """
+ This is where clients come to get their access tokens
+ :param request: The request
+ :param client_id: Client identifier
+ :returns:
+ """
+
+ _mngr = self.endpoint_context.session_manager
try:
- _info = _sdb.upgrade_to_token(_access_code, issue_refresh=issue_refresh)
- except AccessCodeUsed as err:
- logger.error("%s" % err)
- # Should revoke the token issued to this access code
- _sdb.revoke_all_tokens(_access_code)
+ _session_info = _mngr.get_session_info_by_token(request["code"],
+ grant=True)
+ except (KeyError, UnknownToken):
+ logger.error("Access Code invalid")
+ return self.error_cls(error="invalid_grant",
+ error_description="Unknown code")
+
+ _grant = _session_info["grant"]
+ code = _grant.get_token(request["code"])
+ if not isinstance(code, AuthorizationCode):
return self.error_cls(
- error="access_denied", error_description="Access Code already used"
+ error="invalid_request", error_description="Wrong token type"
)
- if "openid" in _authn_req["scope"]:
+ if code.is_active() is False:
+ return self.error_cls(
+ error="invalid_request", error_description="Code inactive"
+ )
+
+ _auth_req = _grant.authorization_request
+
+ if "client_id" not in request: # Optional for access token request
+ request["client_id"] = _auth_req["client_id"]
+
+ logger.debug("%s: %s" % (request.__class__.__name__, sanitize(request)))
+
+ return request
+
+
+class RefreshTokenHelper(TokenEndpointHelper):
+ def process_request(self, req: Union[Message, dict], **kwargs):
+ _mngr = self.endpoint_context.session_manager
+
+ if req["grant_type"] != "refresh_token":
+ return self.error_cls(
+ error="invalid_request", error_description="Wrong grant_type"
+ )
+
+ token_value = req["refresh_token"]
+ _session_info = _mngr.get_session_info_by_token(token_value, grant=True)
+
+ _grant = _session_info["grant"]
+ token = _grant.get_token(token_value)
+
+ access_token = self._mint_token(token_type="access_token",
+ grant=_grant,
+ session_id=_session_info["session_id"],
+ based_on=token)
+
+ _resp = {
+ "access_token": access_token.value,
+ "token_type": "Bearer",
+ "scope": _grant.scope
+ }
+
+ if access_token.expires_at:
+ _resp["expires_in"] = access_token.expires_at - utc_time_sans_frac()
+
+ _mints = token.usage_rules.get("supports_minting")
+ if "refresh_token" in _mints:
+ refresh_token = self._mint_token(token_type="refresh_token",
+ grant=_grant,
+ session_id=_session_info["session_id"],
+ based_on=token)
+ refresh_token.usage_rules = token.usage_rules.copy()
+ _resp["refresh_token"] = refresh_token.value
+
+ if "id_token" in _mints:
try:
- _idtoken = _context.idtoken.make(req, _info, _authn_req)
+ _idtoken = self.endpoint_context.idtoken.make(_session_info["session_id"])
except (JWEException, NoSuitableSigningKeys) as err:
logger.warning(str(err))
resp = self.error_cls(
@@ -111,50 +256,135 @@ def _access_token(self, req, **kwargs):
)
return resp
- _sdb.update_by_token(_access_code, id_token=_idtoken)
- _info = _sdb[_info["sid"]]
-
- return by_schema(AccessTokenResponse, **_info)
+ _resp["id_token"] = _idtoken
- def get_client_id_from_token(self, endpoint_context, token, request=None):
- sinfo = endpoint_context.sdb[token]
- return sinfo["authn_req"]["client_id"]
+ return _resp
- def _post_parse_request(self, request, client_id="", **kwargs):
+ def post_parse_request(self, request: Union[Message, dict],
+ client_id: Optional[str] = "",
+ **kwargs):
"""
- This is where clients come to get their access tokens
+ This is where clients come to refresh their access tokens
:param request: The request
:param authn: Authentication info, comes from HTTP header
:returns:
"""
- if "state" in request:
- try:
- sinfo = self.endpoint_context.sdb[request["code"]]
- except KeyError:
- logger.error("Code not present in SessionDB")
- return self.error_cls(error="access_denied")
- except MultipleCodeUsage:
- logger.error("Access Code reused")
- # Remove any access tokens issued
- self.endpoint_context.sdb.revoke_all_tokens(request["code"])
- return self.error_cls(error="invalid_grant")
- else:
- state = sinfo["authn_req"]["state"]
+ request = RefreshAccessTokenRequest(**request.to_dict())
- if state != request["state"]:
- logger.error("State value mismatch")
- return self.error_cls(error="invalid_request")
+ try:
+ keyjar = self.endpoint_context.keyjar
+ except AttributeError:
+ keyjar = ""
- if "client_id" not in request: # Optional for access token request
- request["client_id"] = client_id
+ request.verify(keyjar=keyjar, opponent_id=client_id)
- logger.debug("%s: %s" % (request.__class__.__name__, sanitize(request)))
+ _mngr = self.endpoint_context.session_manager
+ try:
+ _session_info = _mngr.get_session_info_by_token(
+ request["refresh_token"], grant=True
+ )
+ except KeyError:
+ logger.error("Access Code invalid")
+ return self.error_cls(error="invalid_grant")
+
+ _grant = _session_info["grant"]
+ token = _grant.get_token(request["refresh_token"])
+
+ if not isinstance(token, RefreshToken):
+ return self.error_cls(
+ error="invalid_request", error_description="Wrong token type"
+ )
+
+ if token.is_active() is False:
+ return self.error_cls(
+ error="invalid_request", error_description="Refresh token inactive"
+ )
return request
- def process_request(self, request=None, **kwargs):
+
+HELPER_BY_GRANT_TYPE = {
+ "authorization_code": AccessTokenHelper,
+ "refresh_token": RefreshTokenHelper,
+}
+
+
+class Token(Endpoint):
+ request_cls = oidc.Message
+ response_cls = oidc.AccessTokenResponse
+ error_cls = TokenErrorResponse
+ request_format = "json"
+ request_placement = "body"
+ response_format = "json"
+ response_placement = "body"
+ endpoint_name = "token_endpoint"
+ name = "token"
+ default_capabilities = {"token_endpoint_auth_signing_alg_values_supported": None}
+
+ def __init__(self, endpoint_context, new_refresh_token=False, **kwargs):
+ Endpoint.__init__(self, endpoint_context, **kwargs)
+ self.post_parse_request.append(self._post_parse_request)
+ if "client_authn_method" in kwargs:
+ self.endpoint_info["token_endpoint_auth_methods_supported"] = kwargs[
+ "client_authn_method"
+ ]
+ self.allow_refresh = False
+ self.new_refresh_token = new_refresh_token
+ self.configure_grant_types(kwargs.get("grant_types_supported"))
+
+ def configure_grant_types(self, grant_types_supported):
+ if grant_types_supported is None:
+ self.helper = {k: v(self) for k, v in HELPER_BY_GRANT_TYPE.items()}
+ return
+
+ self.helper = {}
+ # TODO: do we want to allow any grant_type?
+ for grant_type, grant_type_options in grant_types_supported.items():
+ if (
+ grant_type_options in (None, True)
+ and grant_type in HELPER_BY_GRANT_TYPE
+ ):
+ self.helper[grant_type] = HELPER_BY_GRANT_TYPE[grant_type]
+ continue
+ elif grant_type_options is False:
+ continue
+
+ try:
+ grant_class = grant_type_options["class"]
+ except (KeyError, TypeError):
+ raise ProcessError(
+ "Token Endpoint's grant types must be True, None or a dict with a"
+ " 'class' key."
+ )
+ _conf = grant_type_options.get("kwargs", {})
+
+ if isinstance(grant_class, str):
+ try:
+ grant_class = importer(grant_class)
+ except (ValueError, AttributeError):
+ raise ProcessError(
+ f"Token Endpoint's grant type class {grant_class} can't"
+ " be imported."
+ )
+ try:
+ self.helper[grant_type] = grant_class(self, _conf)
+ except Exception as e:
+ raise ProcessError(f"Failed to initialize class {grant_class}: {e}")
+
+ def _post_parse_request(self, request: Union[Message, dict],
+ client_id: Optional[str] = "", **kwargs):
+ _helper = self.helper.get(request["grant_type"])
+ if _helper:
+ return _helper.post_parse_request(request, client_id, **kwargs)
+ else:
+ return self.error_cls(
+ error="invalid_request",
+ error_description=f"Unsupported grant_type: {request['grant_type']}"
+ )
+
+ def process_request(self, request: Optional[Union[Message, dict]] = None, **kwargs):
"""
:param request:
@@ -163,8 +393,19 @@ def process_request(self, request=None, **kwargs):
"""
if isinstance(request, self.error_cls):
return request
+
+ if request is None:
+ return self.error_cls(error="invalid_request")
+
try:
- response_args = self._access_token(request, **kwargs)
+ _helper = self.helper.get(request["grant_type"])
+ if _helper:
+ response_args = _helper.process_request(request, **kwargs)
+ else:
+ return self.error_cls(
+ error="invalid_request",
+ error_description=f"Unsupported grant_type: {request['grant_type']}"
+ )
except JWEException as err:
return self.error_cls(error="invalid_request", error_description="%s" % err)
@@ -172,11 +413,15 @@ def process_request(self, request=None, **kwargs):
return response_args
_access_token = response_args["access_token"]
+ _session_info = self.endpoint_context.session_manager.get_session_info_by_token(
+ _access_token, grant=True)
+
_cookie = new_cookie(
self.endpoint_context,
- sub=self.endpoint_context.sdb[_access_token]["sub"],
+ sub=_session_info["grant"].sub,
cookie_name=self.endpoint_context.cookie_name["session"],
)
+
_headers = [("Content-type", "application/json")]
resp = {"response_args": response_args, "http_headers": _headers}
if _cookie:
diff --git a/src/oidcendpoint/oidc/token_coop.py b/src/oidcendpoint/oidc/token_coop.py
deleted file mode 100755
index 56272b9..0000000
--- a/src/oidcendpoint/oidc/token_coop.py
+++ /dev/null
@@ -1,328 +0,0 @@
-import logging
-
-from cryptojwt.jwe.exception import JWEException
-from cryptojwt.jws.exception import NoSuitableSigningKeys
-from oidcmsg import oidc
-from oidcmsg.exception import MissingRequiredAttribute
-from oidcmsg.exception import MissingRequiredValue
-from oidcmsg.oauth2 import ResponseMessage
-from oidcmsg.oidc import AccessTokenRequest
-from oidcmsg.oidc import AccessTokenResponse
-from oidcmsg.oidc import RefreshAccessTokenRequest
-from oidcmsg.oidc import TokenErrorResponse
-from oidcmsg.storage import importer
-
-from oidcendpoint import sanitize
-from oidcendpoint.cookie import new_cookie
-from oidcendpoint.endpoint import Endpoint
-from oidcendpoint.exception import MultipleCodeUsage
-from oidcendpoint.exception import ProcessError
-from oidcendpoint.token_handler import AccessCodeUsed
-from oidcendpoint.token_handler import ExpiredToken
-from oidcendpoint.token_handler import UnknownToken
-from oidcendpoint.userinfo import by_schema
-
-logger = logging.getLogger(__name__)
-
-
-class TokenEndpointHelper:
- def __init__(self, endpoint, config=None):
- self.endpoint = endpoint
- self.config = config
-
- def post_parse_request(self, request, client_id="", **kwargs):
- """Context specific parsing of the request.
- This is done after general request parsing and before processing
- the request.
- """
- raise NotImplementedError
-
- def process_request(self, req, **kwargs):
- """Acts on a process request."""
- raise NotImplementedError
-
-
-class AccessToken(TokenEndpointHelper):
- def post_parse_request(self, request, client_id="", **kwargs):
- """
- This is where clients come to get their access tokens
-
- :param request: The request
- :param authn: Authentication info, comes from HTTP header
- :returns:
- """
-
- request = AccessTokenRequest(**request.to_dict())
-
- error_cls = self._validate_state(request)
- if error_cls is not None:
- return error_cls
-
- if "client_id" not in request: # Optional for access token request
- request["client_id"] = client_id
-
- logger.debug("%s: %s" % (request.__class__.__name__, sanitize(request)))
-
- return request
-
- def _validate_state(self, request):
- if "state" not in request:
- return
-
- if "code" not in request:
- return
-
- try:
- sinfo = self.endpoint.endpoint_context.sdb[request["code"]]
- except (KeyError, UnknownToken, MultipleCodeUsage):
- return
-
- state = sinfo["authn_req"]["state"]
-
- if state != request["state"]:
- logger.error("State value mismatch")
- return self.endpoint.error_cls(error="invalid_request")
-
- def process_request(self, req, **kwargs):
- _context = self.endpoint.endpoint_context
- _sdb = _context.sdb
- _log_debug = logger.debug
-
- try:
- _access_code = req["code"].replace(" ", "+")
- except KeyError: # Missing code parameter - absolutely fatal
- return self.endpoint.error_cls(
- error="invalid_request", error_description="Missing code"
- )
-
- try:
- _info = _sdb[_access_code]
- except (KeyError, UnknownToken):
- logger.error("Code not present in SessionDB")
- return self.endpoint.error_cls(
- error="invalid_grant", error_description="Invalid code"
- )
- except MultipleCodeUsage:
- return self.endpoint.error_cls(
- error="invalid_grant", error_description="Code is already used"
- )
-
- _authn_req = _info["authn_req"]
-
- # assert that the code is valid
- if _context.sdb.is_session_revoked(_access_code):
- return self.endpoint.error_cls(
- error="invalid_grant", error_description="Session is revoked"
- )
-
- # If redirect_uri was in the initial authorization request
- # verify that the one given here is the correct one.
- if "redirect_uri" in _authn_req:
- if req["redirect_uri"] != _authn_req["redirect_uri"]:
- return self.endpoint.error_cls(
- error="invalid_request", error_description="redirect_uri mismatch"
- )
-
- _log_debug("All checks OK")
-
- issue_refresh = False
- if "issue_refresh" in kwargs:
- issue_refresh = kwargs["issue_refresh"]
-
- # offline_access the default if nothing is specified
- permissions = _info.get("permission", ["offline_access"])
-
- if "offline_access" in _authn_req["scope"] and "offline_access" in permissions:
- issue_refresh = True
-
- try:
- _info = _sdb.upgrade_to_token(_access_code, issue_refresh=issue_refresh)
- except AccessCodeUsed as err:
- logger.error("%s" % err)
- # Should revoke the token issued to this access code
- _sdb.revoke_all_tokens(_access_code)
- return self.endpoint.error_cls(
- error="access_denied", error_description="Access Code already used"
- )
-
- if "openid" in _authn_req["scope"]:
- try:
- _idtoken = _context.idtoken.make(req, _info, _authn_req)
- except (JWEException, NoSuitableSigningKeys) as err:
- logger.warning(str(err))
- resp = self.error_cls(
- error="invalid_request",
- error_description="Could not sign/encrypt id_token",
- )
- return resp
-
- _sdb.update_by_token(_access_code, id_token=_idtoken)
- _info = _sdb[_info["sid"]]
-
- return by_schema(AccessTokenResponse, **_info)
-
-
-class RefreshToken(TokenEndpointHelper):
- def post_parse_request(self, request, client_id="", **kwargs):
- """
- This is where clients come to refresh their access tokens
-
- :param request: The request
- :param authn: Authentication info, comes from HTTP header
- :returns:
- """
- request = RefreshAccessTokenRequest(**request.to_dict())
-
- # verify that the request message is correct
- try:
- request.verify(
- keyjar=self.endpoint.endpoint_context.keyjar, opponent_id=client_id
- )
- except (MissingRequiredAttribute, ValueError, MissingRequiredValue) as err:
- return self.endpoint.error_cls(
- error="invalid_request", error_description="%s" % err
- )
-
- if "client_id" not in request: # Optional for refresh access token request
- request["client_id"] = client_id
-
- logger.debug("%s: %s" % (request.__class__.__name__, sanitize(request)))
-
- return request
-
- def process_request(self, req, **kwargs):
- _sdb = self.endpoint.endpoint_context.sdb
-
- rtoken = req["refresh_token"]
- try:
- _info = _sdb.refresh_token(
- rtoken, new_refresh=self.endpoint.new_refresh_token
- )
- except ExpiredToken:
- return self.error_cls(
- error="invalid_request", error_description="Refresh token is expired"
- )
-
- return by_schema(AccessTokenResponse, **_info)
-
-
-HELPER_BY_GRANT_TYPE = {
- "authorization_code": AccessToken,
- "refresh_token": RefreshToken,
-}
-
-
-class TokenCoop(Endpoint):
- request_cls = oidc.Message
- response_cls = oidc.AccessTokenResponse
- error_cls = TokenErrorResponse
- request_format = "json"
- request_placement = "body"
- response_format = "json"
- response_placement = "body"
- endpoint_name = "token_endpoint"
- name = "token"
- default_capabilities = {"token_endpoint_auth_signing_alg_values_supported": None}
-
- def __init__(self, endpoint_context, new_refresh_token=False, **kwargs):
- Endpoint.__init__(self, endpoint_context, **kwargs)
- self.post_parse_request.append(self._post_parse_request)
- if "client_authn_method" in kwargs:
- self.endpoint_info["token_endpoint_auth_methods_supported"] = kwargs[
- "client_authn_method"
- ]
-
- # self.allow_refresh = False
- self.new_refresh_token = new_refresh_token
-
- self.configure_grant_types(kwargs.get("grant_types_supported"))
-
- def configure_grant_types(self, grant_types_supported):
- if grant_types_supported is None:
- self.helper = {k: v(self) for k, v in HELPER_BY_GRANT_TYPE.items()}
- return
-
- self.helper = {}
- # TODO: do we want to allow any grant_type?
- for grant_type, grant_type_options in grant_types_supported.items():
- if (
- grant_type_options in (None, True)
- and grant_type in HELPER_BY_GRANT_TYPE
- ):
- self.helper[grant_type] = HELPER_BY_GRANT_TYPE[grant_type]
- continue
- elif grant_type_options is False:
- continue
-
- try:
- grant_class = grant_type_options["class"]
- except (KeyError, TypeError):
- raise ProcessError(
- "Token Endpoint's grant types must be True, None or a dict with a"
- " 'class' key."
- )
- _conf = grant_type_options.get("kwargs", {})
-
- if isinstance(grant_class, str):
- try:
- grant_class = importer(grant_class)
- except (ValueError, AttributeError):
- raise ProcessError(
- f"Token Endpoint's grant type class {grant_class} can't"
- " be imported."
- )
- try:
- self.helper[grant_type] = grant_class(self, _conf)
- except Exception as e:
- raise ProcessError(f"Failed to initialize class {grant_class}: {e}")
-
- def get_client_id_from_token(self, endpoint_context, token, request=None):
- sinfo = endpoint_context.sdb[token]
- return sinfo["authn_req"]["client_id"]
-
- def _post_parse_request(self, request, client_id="", **kwargs):
- _helper = self.helper.get(request["grant_type"])
- if _helper:
- return _helper.post_parse_request(request, client_id, **kwargs)
- else:
- return self.error_cls(
- error="invalid_request",
- error_description=f"Unsupported grant_type: {request['grant_type']}"
- )
-
- def process_request(self, request=None, **kwargs):
- """
-
- :param request:
- :param kwargs:
- :return: Dictionary with response information
- """
- if isinstance(request, self.error_cls):
- return request
- try:
- _helper = self.helper.get(request["grant_type"])
- if _helper:
- response_args = _helper.process_request(request, **kwargs)
- else:
- return self.error_cls(
- error="invalid_request",
- error_description=f"Unsupported grant_type: {request['grant_type']}"
- )
- except JWEException as err:
- return self.error_cls(error="invalid_request", error_description="%s" % err)
-
- if isinstance(response_args, ResponseMessage):
- return response_args
-
- _access_token = response_args["access_token"]
- _cookie = new_cookie(
- self.endpoint_context,
- sub=self.endpoint_context.sdb[_access_token]["sub"],
- cookie_name=self.endpoint_context.cookie_name["session"],
- )
-
- _headers = [("Content-type", "application/json")]
- resp = {"response_args": response_args, "http_headers": _headers}
- if _cookie:
- resp["cookie"] = _cookie
- return resp
diff --git a/src/oidcendpoint/oidc/userinfo.py b/src/oidcendpoint/oidc/userinfo.py
index ffe57e7..dd2990f 100755
--- a/src/oidcendpoint/oidc/userinfo.py
+++ b/src/oidcendpoint/oidc/userinfo.py
@@ -1,16 +1,19 @@
import json
import logging
+from typing import Optional
+from typing import Union
from cryptojwt.exception import MissingValue
from cryptojwt.jwt import JWT
from cryptojwt.jwt import utc_time_sans_frac
-from oidcendpoint.token_handler import UnknownToken
from oidcmsg import oidc
from oidcmsg.message import Message
from oidcmsg.oauth2 import ResponseMessage
from oidcendpoint.endpoint import Endpoint
-from oidcendpoint.userinfo import collect_user_info
+from oidcendpoint.endpoint_context import EndpointContext
+from oidcendpoint.session.token import AccessToken
+from oidcendpoint.token.exception import UnknownToken
from oidcendpoint.util import OAUTH2_NOCACHE_HEADERS
logger = logging.getLogger(__name__)
@@ -32,17 +35,25 @@ class UserInfo(Endpoint):
"client_authn_method": ["bearer_header"],
}
- def __init__(self, endpoint_context, **kwargs):
- Endpoint.__init__(self, endpoint_context, **kwargs)
- self.scope_to_claims = None
+ def __init__(self, endpoint_context: EndpointContext,
+ add_claims_by_scope: Optional[bool] = True,
+ **kwargs):
+ Endpoint.__init__(
+ self,
+ endpoint_context,
+ add_claims_by_scope=add_claims_by_scope,
+ **kwargs,
+ )
# Add the issuer ID as an allowed JWT target
self.allowed_targets.append("")
def get_client_id_from_token(self, endpoint_context, token, request=None):
- sinfo = self.endpoint_context.sdb[token]
- return sinfo["authn_req"]["client_id"]
+ _info = endpoint_context.session_manager.get_session_info_by_token(token)
+ return _info["client_id"]
- def do_response(self, response_args=None, request=None, client_id="", **kwargs):
+ def do_response(self, response_args: Optional[Union[Message, dict]] = None,
+ request: Optional[Union[Message, dict]] = None,
+ client_id: Optional[str] = "", **kwargs) -> dict:
if "error" in kwargs and kwargs["error"]:
return Endpoint.do_response(self, response_args, request, **kwargs)
@@ -96,23 +107,31 @@ def do_response(self, response_args=None, request=None, client_id="", **kwargs):
return {"response": resp, "http_headers": http_headers}
def process_request(self, request=None, **kwargs):
- _sdb = self.endpoint_context.sdb
-
+ _mngr = self.endpoint_context.session_manager
+ _session_info = _mngr.get_session_info_by_token(request["access_token"],
+ grant=True)
+ _grant = _session_info["grant"]
+ token = _grant.get_token(request["access_token"])
# should be an access token
- if not _sdb.is_token_valid(request["access_token"]):
+ if not isinstance(token, AccessToken):
return self.error_cls(
- error="invalid_token", error_description="Invalid Token"
+ error="invalid_token", error_description="Wrong type of token"
)
- session = _sdb.read(request["access_token"])
+ # And it should be valid
+ if token.is_active() is False:
+ return self.error_cls(
+ error="invalid_token", error_description="Invalid Token"
+ )
allowed = True
+ _auth_event = _grant.authentication_event
# if the authenticate is still active or offline_access is granted.
- if session["authn_event"]["valid_until"] > utc_time_sans_frac():
+ if _auth_event["valid_until"] > utc_time_sans_frac():
pass
else:
logger.debug("authentication not valid: {} > {}".format(
- session["authn_event"]["valid_until"], utc_time_sans_frac()
+ _auth_event["valid_until"], utc_time_sans_frac()
))
allowed = False
@@ -121,15 +140,18 @@ def process_request(self, request=None, **kwargs):
# pass
if allowed:
- # Scope can translate to userinfo_claims
- info = collect_user_info(self.endpoint_context, session)
+ _claims = _grant.claims.get("userinfo")
+ info = self.endpoint_context.claims_interface.get_user_claims(
+ user_id=_session_info["user_id"],
+ claims_restriction=_claims)
+ info["sub"] = _grant.sub
else:
info = {
"error": "invalid_request",
"error_description": "Access not granted",
}
- return {"response_args": info, "client_id": session["authn_req"]["client_id"]}
+ return {"response_args": info, "client_id": _session_info["client_id"]}
def parse_request(self, request, auth=None, **kwargs):
"""
diff --git a/src/oidcendpoint/scopes.py b/src/oidcendpoint/scopes.py
index 8a04544..80166b9 100644
--- a/src/oidcendpoint/scopes.py
+++ b/src/oidcendpoint/scopes.py
@@ -1,5 +1,4 @@
# default set can be changed by configuration
-from oidcmsg.oidc import OpenIDSchema
SCOPE2CLAIMS = {
"openid": ["sub"],
@@ -25,9 +24,6 @@
"offline_access": [],
}
-IGNORE = ["error", "error_description", "error_uri", "_claim_names", "_claim_sources"]
-STANDARD_CLAIMS = [c for c in OpenIDSchema.c_param.keys() if c not in IGNORE]
-
def available_scopes(endpoint_context):
_supported = endpoint_context.provider_info.get("scopes_supported")
@@ -37,27 +33,23 @@ def available_scopes(endpoint_context):
return [s for s in endpoint_context.scope2claims.keys()]
-def available_claims(endpoint_context):
- _supported = endpoint_context.provider_info.get("claims_supported")
- if _supported:
- return _supported
- else:
- return STANDARD_CLAIMS
-
-
-def convert_scopes2claims(scopes, allowed_claims, map=None):
+def convert_scopes2claims(scopes, allowed_claims=None, map=None):
if map is None:
map = SCOPE2CLAIMS
res = {}
- for scope in scopes:
- try:
- claims = dict(
- [(name, None) for name in map[scope] if name in allowed_claims]
- )
+ if allowed_claims is None:
+ for scope in scopes:
+ claims = {name: None for name in map[scope]}
res.update(claims)
- except KeyError:
- continue
+ else:
+ for scope in scopes:
+ try:
+ claims = {name: None for name in map[scope] if name in allowed_claims}
+ res.update(claims)
+ except KeyError:
+ continue
+
return res
@@ -85,25 +77,3 @@ def allowed_scopes(self, client_id, endpoint_context):
def filter_scopes(self, client_id, endpoint_context, scopes):
allowed_scopes = self.allowed_scopes(client_id, endpoint_context)
return [s for s in scopes if s in allowed_scopes]
-
-
-class Claims:
- def __init__(self):
- pass
-
- def allowed_claims(self, client_id, endpoint_context):
- """
- Returns the set of claims that a specific client can use.
-
- :param client_id: The client identifier
- :param endpoint_context: A EndpointContext instance
- :returns: List of claim names. Can be empty.
- """
- _cli = endpoint_context.cdb.get(client_id)
- if _cli is not None:
- _claims = _cli.get("allowed_claims")
- if _claims:
- return _claims
- else:
- return available_claims(endpoint_context)
- return []
diff --git a/src/oidcendpoint/session.py b/src/oidcendpoint/session.py
deleted file mode 100644
index 46f71d3..0000000
--- a/src/oidcendpoint/session.py
+++ /dev/null
@@ -1,621 +0,0 @@
-import hashlib
-import json
-import logging
-import time
-
-from oidcmsg.exception import MissingParameter
-from oidcmsg.message import OPTIONAL_LIST_OF_STRINGS
-from oidcmsg.message import SINGLE_OPTIONAL_STRING
-from oidcmsg.message import SINGLE_REQUIRED_STRING
-from oidcmsg.message import Message
-from oidcmsg.message import msg_ser
-from oidcmsg.oidc import AuthorizationRequest
-from oidcmsg.time_util import utc_time_sans_frac
-
-from oidcendpoint import token_handler
-from oidcendpoint.authn_event import AuthnEvent
-from oidcendpoint.exception import MultipleCodeUsage
-from oidcendpoint.token_handler import ExpiredToken
-from oidcendpoint.token_handler import UnknownToken
-from oidcendpoint.token_handler import WrongTokenType
-from oidcendpoint.token_handler import is_expired
-from oidcendpoint.util import sector_id_from_redirect_uris
-
-logger = logging.getLogger(__name__)
-
-
-def authorization_request_deser(val, sformat="urlencoded"):
- if sformat in ["dict", "json"]:
- if not isinstance(val, str):
- val = json.dumps(val)
- sformat = "json"
- return AuthorizationRequest().deserialize(val, sformat)
-
-
-SINGLE_REQUIRED_AUTHORIZATION_REQUEST = (
- Message,
- True,
- msg_ser,
- authorization_request_deser,
- False,
-)
-
-
-def authn_event_deser(val, sformat="urlencoded"):
- if sformat in ["dict", "json"]:
- if not isinstance(val, str):
- val = json.dumps(val)
- sformat = "json"
- return AuthnEvent().deserialize(val, sformat)
-
-
-def setup_session(
- endpoint_context,
- areq,
- uid,
- client_id="",
- acr="",
- salt="salt",
- authn_event=None,
- subject_type=None,
- sector_id=None
-):
- """
- Setting up a user session
-
- :param endpoint_context:
- :param areq:
- :param uid:
- :param acr:
- :param client_id:
- :param salt:
- :param authn_event: A already made AuthnEvent
- :return:
- """
- if authn_event is None and acr:
- authn_event = AuthnEvent(
- uid=uid, salt=salt, authn_info=acr, authn_time=time.time()
- )
-
- if not client_id:
- client_id = areq["client_id"]
-
- sid = endpoint_context.sdb.create_authz_session(
- authn_event, areq, client_id=client_id, uid=uid
- )
-
- client = endpoint_context.cdb.get(client_id)
- endpoint_context.sdb.do_sub(
- sid,
- uid,
- salt=salt,
- sector_id=sector_id,
- subject_type=subject_type,
- client=client,
- )
- return sid
-
-
-SINGLE_REQUIRED_AUTHN_EVENT = (Message, True, msg_ser, authn_event_deser, False)
-
-
-class SessionInfo(Message):
- c_param = {
- "oauth_state": SINGLE_REQUIRED_STRING,
- "code": SINGLE_OPTIONAL_STRING,
- "authn_req": SINGLE_REQUIRED_AUTHORIZATION_REQUEST,
- "client_id": SINGLE_REQUIRED_STRING,
- "authn_event": SINGLE_REQUIRED_AUTHN_EVENT,
- "si_redirects": OPTIONAL_LIST_OF_STRINGS,
- }
-
-
-def pairwise_id(uid, sector_identifier, salt, **kwargs):
- return hashlib.sha256(
- ("%s%s%s" % (uid, sector_identifier, salt)).encode("utf-8")
- ).hexdigest()
-
-
-def public_id(uid, salt="", **kwargs):
- return hashlib.sha256("{}{}".format(uid, salt).encode("utf-8")).hexdigest()
-
-
-def dict_match(a, b):
- """
- Check if all attribute/value pairs in a also appears in b
-
- :param a: A dictionary
- :param b: A dictionary
- :return: True/False
- """
- res = []
- for k, v in a.items():
- try:
- res.append(b[k] == v)
- except KeyError:
- pass
- return all(res)
-
-
-class SessionDB(object):
- def __init__(self, db, handler, sso_db, userinfo=None, sub_func=None):
- self._db = db
- self.handler = handler
- self.sso_db = sso_db
- self.userinfo = userinfo
-
- # this allows the subject identifier minters to be defined by someone
- # else then me.
- if sub_func is None:
- self.sub_func = {"public": public_id, "pairwise": pairwise_id}
- else:
- self.sub_func = sub_func
- if "public" not in sub_func:
- self.sub_func["public"] = public_id
- if "pairwise" not in sub_func:
- self.sub_func["pairwise"] = pairwise_id
-
- def __getitem__(self, item):
- _info = self._db.get(item)
-
- if _info is None:
- sid = self.handler.sid(item)
- _info = self._db.get(sid)
- if _info:
- _si = SessionInfo(**_info)
- if any(item == val for val in _si.values()):
- _si["sid"] = sid
- return _si
- else:
- _handler, _res = self.handler.get_handler(item)
- if _handler.type == 'A': # access grant
- if _si['code_is_used']:
- raise MultipleCodeUsage('Reused code')
- else:
- _si = SessionInfo(**_info)
- _si["sid"] = item
- return _si
- raise KeyError
-
- def __setitem__(self, sid, instance):
- if isinstance(instance, Message):
- _info = instance.to_dict()
- else:
- _info = instance
- self._db[sid] = _info
-
- def __delitem__(self, key):
- del self._db[key]
-
- def keys(self):
- return self._db.keys()
-
- def create_authz_session(self, authn_event, areq, client_id="", uid="", **kwargs):
- """
-
- :param authn_event:
- :param areq:
- :param client_id:
- :param uid:
- :param kwargs:
- :return:
- """
- try:
- _uid = authn_event["uid"]
- except (TypeError, KeyError):
- _uid = uid
-
- if not _uid:
- raise MissingParameter('Need a "uid"')
-
- sid = self.handler["code"].key(user=_uid, areq=areq)
-
- access_grant = self.handler["code"](sid=sid)
-
- _info = SessionInfo(code=access_grant, oauth_state="authz")
-
- if client_id:
- _info["client_id"] = client_id
-
- if areq:
- _info["authn_req"] = areq
- if authn_event:
- _info["authn_event"] = authn_event
-
- if kwargs:
- _info.update(kwargs)
-
- self[sid] = _info
- return sid
-
- def update(self, sid, **kwargs):
- """
- Add attribute value assertion to a special session
-
- :param sid: Session ID
- :param kwargs:
- """
- item = self[sid]
- for attribute, value in kwargs.items():
- item[attribute] = value
- self[sid] = item
-
- def update_by_token(self, token, **kwargs):
- """
- Updated the session info. Any type of known token can be used
-
- :param token: code/access token/refresh token/...
- :param kwargs: Key word arguements
- """
- _sid = self.handler.sid(token)
- return self.update(_sid, **kwargs)
-
- def set(self, label, key, value):
- logger.debug("Session set {} - {}: {}".format(label, key, value))
- # Try loading the key
- _dic = self._db.get(key, {})
- if label in _dic:
- _dic[label].append(value)
- else:
- _dic[label] = [value]
- self._db[key] = _dic
-
- def get(self, key, seckey):
- _dic = self._db.get(key, {})
- logger.debug("SSODb get {} - {}: {}".format(key, seckey, _dic))
- return _dic.get(seckey, None)
-
- def map_kv2sid(self, key, seckey, sid):
- self.sso_db.set(key, seckey, sid)
-
- def delete_kv2sid(self, key, seckey):
- self.sso_db.delete(key, seckey)
-
- def get_sid_by_kv(self, key, seckey):
- return self.sso_db.get(key, seckey)
-
- def get_token(self, sid):
- _sess_info = self[sid]
-
- if _sess_info["oauth_state"] == "authz":
- return _sess_info["code"]
- elif _sess_info["oauth_state"] == "token":
- return _sess_info["access_token"]
-
- def do_sub(
- self,
- sid,
- uid,
- salt=None,
- sector_id=None,
- subject_type=None,
- client=None,
- ):
- """
- Create and store a subject identifier
-
- :param sid: Session ID
- :param uid: User ID
- :param salt:
- :param sector_id: For pairwise identifiers, an Identifier for the RP group
- :param subject_type: 'pairwaise'/'public'
- :param client_id:
- :return:
- """
- if not client:
- client = {}
-
- salt = salt
- if not salt:
- salt = client.get("client_salt", salt)
- if not subject_type:
- # Default to public
- subject_type = client.get("subject_type", "public")
- if subject_type == "pairwise" and not sector_id:
- sector_id = client.get("sector_identifier_uri")
- # If no `sector_identifier_uri` is registered, then get the host
- # component of the registered redirect_uri
- if not sector_id:
- uris = [uri[0] for uri in client.get("redirect_uris", [])]
- sector_id = sector_id_from_redirect_uris(uris)
-
- sub = self.sub_func[subject_type](
- uid, salt=salt, sector_identifier=sector_id
- )
-
- self.sso_db.map_sid2uid(sid, uid)
- self.update(sid, sub=sub)
- self.sso_db.map_sid2sub(sid, sub)
-
- return sub
-
- def is_valid(self, typ, item):
- try:
- return typ in self[item]
- except (KeyError, MultipleCodeUsage):
- return False
-
- def get_sids_by_sub(self, sub):
- return self.sso_db.get_sids_by_sub(sub)
-
- def get_sid_by_sub_and_client_id(self, sub, client_id):
- for sid in self.sso_db.get_sids_by_sub(sub):
- if self[sid]["authn_req"]["client_id"] == client_id:
- return sid
- return None
-
- def replace_refresh_token(self, sid, sinfo):
- """
- Replace an old refresh_token with a new one
-
- :param sid: session ID
- :param sinfo: session info
- :return: Updated session info
- """
- refresh_token = self.handler["refresh_token"](sid, sinfo=sinfo)
- sinfo["refresh_token"] = refresh_token
- return sinfo
-
- def _make_at(self, sid, session_info, aud=None, client_id_aud=True):
- uid = self.sso_db.get_uid_by_sid(sid)
- client_id = session_info["client_id"]
- uinfo = self.userinfo(uid, client_id) or {}
- at_aud = aud or []
-
- if client_id_aud:
- at_aud.append(client_id)
- return self.handler["access_token"](
- sid=sid, sinfo=session_info, uinfo=uinfo, aud=at_aud, client_id=client_id
- )
-
- def upgrade_to_token(
- self,
- grant=None,
- issue_refresh=False,
- id_token="",
- oidreq=None,
- key=None,
- scope=None,
- ):
- """
-
- :param grant: The access grant
- :param issue_refresh: If a refresh token should be issued
- :param id_token: An IDToken instance
- :param oidreq: An OpenIDRequest instance
- :param key: The session key. One of grant or key must be given.
- :return: The session information as a SessionInfo instance
- """
- if grant:
- # The caller is responsible for checking if the access code exists.
- _tinfo = self.handler["code"].info(grant)
-
- key = _tinfo["sid"]
- session_info = self[key]
-
- # mint a new access token
- _at = self._make_at(_tinfo["sid"], session_info)
-
- # make sure the code can't be used again
- self.revoke_token(key, "code", session_info)
- else:
- session_info = self[key]
- _at = self._make_at(key, session_info)
-
- session_info["access_token"] = _at
- session_info["oauth_state"] = "token"
- session_info["token_type"] = self.handler["access_token"].token_type
-
- if scope:
- session_info["access_token_scope"] = scope
- if id_token:
- session_info["id_token"] = id_token
- if oidreq:
- session_info["oidreq"] = oidreq
-
- if self.handler["access_token"].lifetime:
- session_info["expires_in"] = self.handler["access_token"].lifetime
- session_info["expires_at"] = (
- self.handler["access_token"].lifetime + utc_time_sans_frac()
- )
-
- if issue_refresh and "refresh_token" in self.handler:
- session_info = self.replace_refresh_token(key, session_info)
-
- self[key] = session_info
- return session_info
-
- def refresh_token(self, token, new_refresh=False):
- """
- Issue a new access token using a valid refresh token
-
- :param token: Refresh token
- :param new_refresh: Whether a new refresh token should be minted or not
- :return: Dictionary with session info
- :raises: ExpiredToken for invalid refresh token
- WrongTokenType for wrong token type
- """
- try:
- _tinfo = self.handler["refresh_token"].info(token)
- except KeyError:
- return False
-
- _sid = _tinfo["sid"]
- session_info = self[_sid]
- if token != session_info.get("refresh_token"):
- raise UnknownToken()
- if is_expired(int(_tinfo["exp"])):
- raise ExpiredToken()
-
- session_info["access_token"] = self._make_at(_sid, session_info)
- session_info["token_type"] = self.handler["access_token"].token_type
-
- if new_refresh:
- session_info = self.replace_refresh_token(_sid, session_info)
-
- self[_sid] = session_info
- return session_info
-
- def is_token_valid(self, token):
- """
- Checks validity of a given token
-
- :param token: Access or refresh token
- """
-
- try:
- _tinfo = self.handler.info(token)
- except KeyError:
- return False
-
- # Dependent on what state the session is in.
- session_info = self[_tinfo["sid"]]
- if is_expired(int(_tinfo["exp"])):
- return False
-
- if session_info["oauth_state"] == "authz":
- if _tinfo["handler"] != self.handler["code"]:
- return False
- elif session_info["oauth_state"] == "token":
- if _tinfo["handler"] != self.handler["access_token"]:
- return False
-
- return True
-
- def revoke_token(self, sid, token_type, session_info=None):
- """
- Revokes token
-
- :param sid: session id
- :param token_type: token type, one of "code", "access_token" or
- "refresh_token"
- """
- if not session_info:
- session_info = self[sid]
- session_info.pop(token_type, None)
- if token_type == 'code':
- session_info['code_is_used'] = True
- self[sid] = session_info
-
- def revoke_all_tokens(self, token):
- sid = self.handler.sid(token)
- _sinfo = self[sid]
- for token_type in self.handler.keys():
- _sinfo.pop(token_type, None)
- self[sid] = _sinfo
-
- def revoke_session(self, sid="", token=""):
- """
- Mark session as revoked but also explicitly revoke all issued tokens
-
- :param token: any token connected to the session
- :param sid: Session identifier
- """
- if not sid:
- if token:
- sid = self.handler.sid(token)
- else:
- raise ValueError('Need one of "sid" or "token"')
-
- _sinfo = self[sid]
- for token_type in self.handler.keys():
- _sinfo.pop(token_type, None)
- _sinfo["revoked"] = True
- self[sid] = _sinfo
-
- def get_client_id_for_session(self, sid):
- return self[sid]["client_id"]
-
- def get_active_client_ids_for_uid(self, uid):
- res = []
- for sid in self.sso_db.get_sids_by_uid(uid):
- if "revoked" not in self[sid]:
- res.append(self[sid]["client_id"])
- return res
-
- def get_verified_logout(self, uid):
- res = {}
- for sid in self.sso_db.get_sids_by_uid(uid):
- session_info = self[sid]
- try:
- res[session_info["client_id"]] = session_info["verified_logout"]
- except KeyError:
- res[session_info["client_id"]] = False
- return res
-
- def match_session(self, uid, **kwargs):
- for sid in self.sso_db.get_sids_by_uid(uid):
- session_info = self[sid]
- if dict_match(kwargs, session_info):
- return sid
- return None
-
- def set_verify_logout(self, uid, client_id):
- sid = self.match_session(uid, client_id=client_id)
- self.update(sid, verified_logout=True)
-
- def get_id_token(self, uid, client_id):
- sid = self.match_session(uid, client_id=client_id)
- return self[sid]["id_token"]
-
- def is_session_revoked(self, key):
- try:
- session_info = self[key]
- except Exception:
- raise UnknownToken(key)
-
- try:
- return session_info["revoked"]
- except KeyError:
- return False
-
- def revoke_uid(self, uid):
- # Revoke all sessions
- for sid in self.sso_db.get_sids_by_uid(uid):
- self.update(sid, revoked=True)
-
- # Remove the uid from the SSO db
- self.sso_db.remove_uid(uid)
-
- def read(self, token):
- try:
- _tinfo = self.handler["access_token"].info(token)
- except WrongTokenType:
- return {}
- else:
- return self[_tinfo["sid"]]
-
- def find_sid(self, req):
- """
- Given a request with some info find the correct session.
- The useful claim is 'code'.
-
- :param req: An AccessTokenRequest instance
- :return: session ID or None
- """
-
- return self.get_sid_by_kv(req["code"], "code")
-
- def get_authentication_event(self, sid):
- try:
- session_info = self[sid]
- except Exception:
- raise UnknownToken(sid)
- else:
- sesinf = session_info.get("authn_event")
- return sesinf or ValueError("No Authn event info")
-
-
-def create_session_db(ec, token_handler_args, db=None, sso_db=None, sub_func=None):
- _token_handler = token_handler.factory(ec, **token_handler_args)
- # if db is None:
- # db = InMemoryDataBase()
- # else:
- # db = db
- #
- # if sso_db is None:
- # sso_db = SSODb()
- # else:
- # sso_db = sso_db
-
- return SessionDB(db, _token_handler, sso_db, sub_func=sub_func)
diff --git a/src/oidcendpoint/session/__init__.py b/src/oidcendpoint/session/__init__.py
new file mode 100644
index 0000000..0757c4e
--- /dev/null
+++ b/src/oidcendpoint/session/__init__.py
@@ -0,0 +1,19 @@
+from typing import List
+
+DIVIDER = ";;"
+
+
+def session_key(*args) -> str:
+ return DIVIDER.join(args)
+
+
+def unpack_session_key(key: str) -> List[str]:
+ return key.split(DIVIDER)
+
+
+class Revoked(Exception):
+ pass
+
+
+class MintingNotAllowed(Exception):
+ pass
diff --git a/src/oidcendpoint/session/claims.py b/src/oidcendpoint/session/claims.py
new file mode 100755
index 0000000..b389cb4
--- /dev/null
+++ b/src/oidcendpoint/session/claims.py
@@ -0,0 +1,185 @@
+import logging
+from typing import Optional
+from typing import Union
+
+from oidcmsg.oidc import OpenIDSchema
+
+from oidcendpoint.scopes import convert_scopes2claims
+from oidcendpoint.session import unpack_session_key
+
+logger = logging.getLogger(__name__)
+
+# USAGE = Literal["userinfo", "id_token", "introspection"]
+
+IGNORE = ["error", "error_description", "error_uri", "_claim_names", "_claim_sources"]
+STANDARD_CLAIMS = [c for c in OpenIDSchema.c_param.keys() if c not in IGNORE]
+
+
+def available_claims(endpoint_context):
+ _supported = endpoint_context.provider_info.get("claims_supported")
+ if _supported:
+ return _supported
+ else:
+ return STANDARD_CLAIMS
+
+
+class ClaimsInterface:
+ init_args = {
+ "add_claims_by_scope": False,
+ "enable_claims_per_client": False
+ }
+
+ def __init__(self, endpoint_context):
+ self.endpoint_context = endpoint_context
+
+ def authorization_request_claims(self, session_id: str, usage: Optional[str] = "") -> dict:
+ if usage in ["id_token", "userinfo"]:
+ _grant = self.endpoint_context.session_manager.get_grant(session_id)
+ if "claims" in _grant.authorization_request:
+ return _grant.authorization_request["claims"].get(usage, {})
+
+ return {}
+
+ def _get_client_claims(self, client_id, usage):
+ client_info = self.endpoint_context.cdb.get(client_id, {})
+ client_claims = client_info.get("{}_claims".format(usage), {})
+ if isinstance(client_claims, list):
+ client_claims = {k: None for k in client_claims}
+ return client_claims
+
+ def get_claims(self, session_id: str, scopes: str, usage: str) -> dict:
+ """
+
+ :param session_id: Session identifier
+ :param scopes: Scopes
+ :param usage: Where to use the claims. One of "userinfo"/"id_token"/"introspection"
+ :return: Claims specification as a dictionary.
+ """
+
+ # which endpoint module configuration to get the base claims from
+ module = None
+ if usage == "userinfo":
+ if "userinfo" in self.endpoint_context.endpoint:
+ module = self.endpoint_context.endpoint["userinfo"]
+ elif usage == "id_token":
+ if self.endpoint_context.idtoken:
+ module = self.endpoint_context.idtoken
+ elif usage == "introspection":
+ if "introspection" in self.endpoint_context.endpoint:
+ module = self.endpoint_context.endpoint["introspection"]
+ elif usage == "access_token":
+ try:
+ module = self.endpoint_context.session_manager.token_handler["access_token"]
+ except KeyError:
+ pass
+
+ if module:
+ base_claims = module.kwargs.get("base_claims", {})
+ else:
+ base_claims = {}
+
+ user_id, client_id, grant_id = unpack_session_key(session_id)
+
+ # Can there be per client specification of which claims to use.
+ if module and module.kwargs.get("enable_claims_per_client"):
+ claims = self._get_client_claims(client_id, usage)
+ else:
+ claims = {}
+
+ claims.update(base_claims)
+
+ # Scopes can in some cases equate to set of claims, is that used here ?
+ if module and module.kwargs.get("add_claims_by_scope"):
+ if scopes:
+ _scopes = self.endpoint_context.scopes_handler.filter_scopes(
+ client_id, self.endpoint_context, scopes
+ )
+
+ _claims = convert_scopes2claims(
+ _scopes, map=self.endpoint_context.scope2claims
+ )
+ claims.update(_claims)
+
+ # Bring in claims specification from the authorization request
+ request_claims = self.authorization_request_claims(session_id=session_id,
+ usage=usage)
+
+ # This will add claims that has not be added before and
+ # set filters on those claims that also appears in one of the sources above
+ if request_claims:
+ claims.update(request_claims)
+
+ return claims
+
+ def get_claims_all_usage(self, session_id: str, scopes: str) -> dict:
+ _claims = {}
+ for usage in ["userinfo", "introspection", "id_token", "token"]:
+ _claims.update(self.get_claims(session_id, scopes, usage))
+ return _claims
+
+ def get_user_claims(self, user_id: str, claims_restriction: dict) -> dict:
+ """
+
+ :param user_id: User identifier
+ :param claims_restriction: Specifies the upper limit of which claims can be returned
+ :return:
+ """
+ if claims_restriction:
+ # Get all possible claims
+ user_info = self.endpoint_context.userinfo(user_id, client_id=None)
+ # Filter out the once that can be returned
+ return {k: user_info.get(k) for k, v in claims_restriction.items() if
+ claims_match(user_info.get(k), v)}
+ else:
+ return {}
+
+
+def claims_match(value: Union[str, int], claimspec: Optional[dict]) -> bool:
+ """
+ Implements matching according to section 5.5.1 of
+ http://openid.net/specs/openid-connect-core-1_0.html
+ The lack of value is not checked here.
+ Also the text doesn't prohibit having both 'value' and 'values'.
+
+ :param value: single value
+ :param claimspec: None or dictionary with 'essential', 'value' or 'values'
+ as key
+ :return: Boolean
+ """
+ if value is None:
+ return False
+
+ if claimspec is None: # match anything
+ return True
+
+ matched = False
+ for key, val in claimspec.items():
+ if key == "value":
+ if value == val:
+ matched = True
+ elif key == "values":
+ if value in val:
+ matched = True
+ elif key == "essential":
+ # Whether it's essential or not doesn't change anything here
+ continue
+
+ if matched:
+ break
+
+ if matched is False:
+ if list(claimspec.keys()) == ["essential"]:
+ return True
+
+ return matched
+
+
+def by_schema(cls, **kwa):
+ """
+ Will return only those claims that are listed in the Class definition.
+
+ :param cls: A subclass of :py:class:´oidcmsg.message.Message`
+ :param kwa: Keyword arguments
+ :return: A dictionary with claims (keys) that meets the filter criteria
+ """
+ return dict([(key, val) for key, val in kwa.items() if key in cls.c_param])
diff --git a/src/oidcendpoint/session/database.py b/src/oidcendpoint/session/database.py
new file mode 100644
index 0000000..301a7fb
--- /dev/null
+++ b/src/oidcendpoint/session/database.py
@@ -0,0 +1,147 @@
+import logging
+from typing import List
+from typing import Union
+
+from . import session_key
+from .grant import Grant
+from .info import ClientSessionInfo
+from .info import SessionInfo
+from .info import UserSessionInfo
+
+logger = logging.getLogger(__name__)
+
+
+class NoSuchClientSession(KeyError):
+ pass
+
+
+class NoSuchGrant(KeyError):
+ pass
+
+
+class Database(object):
+ def __init__(self, storage=None):
+ if storage is None:
+ self._db = {}
+ else:
+ self._db = storage
+
+ @staticmethod
+ def _eval_path(path: List[str]):
+ uid = path[0]
+ client_id = None
+ grant_id = None
+ if len(path) > 1:
+ client_id = path[1]
+ if len(path) > 2:
+ grant_id = path[2]
+
+ return uid, client_id, grant_id
+
+ def set(self, path: List[str], value: Union[SessionInfo, Grant]):
+ """
+
+ :param path: a list of identifiers
+ :param value: Class instance to be stored
+ """
+
+ uid, client_id, grant_id = self._eval_path(path)
+
+ if grant_id:
+ gid_key = session_key(uid, client_id, grant_id)
+ self._db[gid_key] = value
+
+ if client_id:
+ cid_key = session_key(uid, client_id)
+ cid_info = self._db.get(cid_key, ClientSessionInfo())
+ if not grant_id:
+ self._db[cid_key] = value
+ elif grant_id not in cid_info["subordinate"]:
+ cid_info.add_subordinate(grant_id)
+ self._db[cid_key] = cid_info
+
+ userinfo = self._db.get(uid, UserSessionInfo())
+ if client_id is None:
+ self._db[uid] = value
+ if client_id and client_id not in userinfo["subordinate"]:
+ userinfo.add_subordinate(client_id)
+ self._db[uid] = userinfo
+
+ def get(self, path: List[str]) -> Union[SessionInfo, Grant]:
+ uid, client_id, grant_id = self._eval_path(path)
+ try:
+ user_info = self._db[uid]
+ except KeyError:
+ raise KeyError('No such UserID')
+ else:
+ if user_info is None:
+ raise KeyError('No such UserID')
+
+ if client_id is None:
+ return user_info
+ else:
+ if client_id not in user_info['subordinate']:
+ raise ValueError('No session from that client for that user')
+ else:
+ try:
+ client_session_info = self._db[session_key(uid, client_id)]
+ except KeyError:
+ raise NoSuchClientSession(client_id)
+ else:
+ if grant_id is None:
+ return client_session_info
+
+ if grant_id not in client_session_info['subordinate']:
+ raise ValueError('No such grant for that user and client')
+ else:
+ try:
+ return self._db[session_key(uid, client_id, grant_id)]
+ except KeyError:
+ raise NoSuchGrant(grant_id)
+
+ def delete(self, path: List[str]):
+ uid, client_id, grant_id = self._eval_path(path)
+ try:
+ _user_info = self._db[uid]
+ except KeyError:
+ pass
+ else:
+ if client_id:
+ if client_id in _user_info['subordinate']:
+ try:
+ _client_info = self._db[session_key(uid, client_id)]
+ except KeyError:
+ pass
+ else:
+ if grant_id:
+ if grant_id in _client_info['subordinate']:
+ try:
+ self._db.__delitem__(session_key(uid, client_id, grant_id))
+ except KeyError:
+ pass
+ _client_info["subordinate"].remove(grant_id)
+ else:
+ for grant_id in _client_info['subordinate']:
+ self._db.__delitem__(session_key(uid, client_id, grant_id))
+ _client_info['subordinate'] = []
+
+ if len(_client_info['subordinate']) == 0:
+ self._db.__delitem__(session_key(uid, client_id))
+ _user_info["subordinate"].remove(client_id)
+ else:
+ self._db[client_id] = _client_info
+
+ if len(_user_info["subordinate"]) == 0:
+ self._db.__delitem__(uid)
+ else:
+ self._db[uid] = _user_info
+ else:
+ pass
+ else:
+ self._db.__delitem__(uid)
+
+ def update(self, path: List[str], new_info: dict):
+ _info = self.get(path)
+ for key, val in new_info.items():
+ _info[key] = val
+ self.set(path, _info)
diff --git a/src/oidcendpoint/session/grant.py b/src/oidcendpoint/session/grant.py
new file mode 100644
index 0000000..ac846f1
--- /dev/null
+++ b/src/oidcendpoint/session/grant.py
@@ -0,0 +1,350 @@
+import json
+from typing import Optional
+from uuid import uuid1
+
+from oidcmsg.message import Message
+from oidcmsg.message import OPTIONAL_LIST_OF_SP_SEP_STRINGS
+from oidcmsg.message import OPTIONAL_LIST_OF_STRINGS
+from oidcmsg.message import SINGLE_OPTIONAL_JSON
+from oidcmsg.oauth2 import AuthorizationRequest
+
+from oidcendpoint.authn_event import AuthnEvent
+from oidcendpoint.session import MintingNotAllowed
+from oidcendpoint.session import unpack_session_key
+from oidcendpoint.session.token import AccessToken
+from oidcendpoint.session.token import AuthorizationCode
+from oidcendpoint.session.token import Item
+from oidcendpoint.session.token import RefreshToken
+from oidcendpoint.session.token import Token
+from oidcendpoint.token import Token as TokenHandler
+from oidcendpoint.util import importer
+
+
+class GrantMessage(Message):
+ c_param = {
+ "scope": OPTIONAL_LIST_OF_SP_SEP_STRINGS, # As defined in RFC6749
+ "authorization_details": SINGLE_OPTIONAL_JSON, # As defined in draft-lodderstedt-oauth-rar
+ "claims": SINGLE_OPTIONAL_JSON, # As defined in OIDC core
+ "resources": OPTIONAL_LIST_OF_STRINGS, # As defined in RFC8707
+ }
+
+
+GRANT_TYPE_MAP = {
+ "authorization_code": "code",
+ "access_token": "access_token",
+ "refresh_token": "refresh_token"
+}
+
+
+def find_token(issued, token_id):
+ for iss in issued:
+ if iss.id == token_id:
+ return iss
+ return None
+
+
+TOKEN_MAP = {
+ "authorization_code": AuthorizationCode,
+ "access_token": AccessToken,
+ "refresh_token": RefreshToken
+}
+
+
+class Grant(Item):
+ attributes = ["scope", "claim", "resources", "authorization_details",
+ "issued_token", "usage_rules", "revoked", "issued_at",
+ "expires_at", "sub", "authorization_request",
+ "authentication_event"]
+ type = "grant"
+
+ def __init__(self,
+ scope: Optional[list] = None,
+ claims: Optional[dict] = None,
+ resources: Optional[list] = None,
+ authorization_details: Optional[dict] = None,
+ authorization_request: Optional[Message] = None,
+ authentication_event: Optional[AuthnEvent] = None,
+ issued_token: Optional[list] = None,
+ usage_rules: Optional[dict] = None,
+ issued_at: int = 0,
+ expires_in: int = 0,
+ expires_at: int = 0,
+ revoked: bool = False,
+ token_map: Optional[dict] = None,
+ sub: Optional[str] = ""):
+ Item.__init__(self, usage_rules=usage_rules, issued_at=issued_at,
+ expires_in=expires_in, expires_at=expires_at, revoked=revoked)
+ self.scope = scope or []
+ self.authorization_details = authorization_details or None
+ self.authorization_request = authorization_request or None
+ self.authentication_event = authentication_event or None
+ self.claims = claims or {} # default is to not release any user information
+ self.resources = resources or []
+ self.issued_token = issued_token or []
+ self.id = uuid1().hex
+ self.sub = sub
+
+ if token_map is None:
+ self.token_map = TOKEN_MAP
+ else:
+ self.token_map = token_map
+
+ def get(self) -> object:
+ return GrantMessage(scope=self.scope, claims=self.claims,
+ authorization_details=self.authorization_details,
+ resources=self.resources)
+
+ def to_json(self) -> str:
+ d = {
+ "type": self.type,
+ "scope": self.scope,
+ "sub": self.sub,
+ "authorization_details": self.authorization_details,
+ "claims": self.claims,
+ "resources": self.resources,
+ "issued_at": self.issued_at,
+ "not_before": self.not_before,
+ "expires_at": self.expires_at,
+ "revoked": self.revoked,
+ "issued_token": [t.to_json() for t in self.issued_token],
+ "id": self.id,
+ "usage_rules": self.usage_rules,
+ "token_map": {k: ".".join([v.__module__, v.__name__]) for k, v in
+ self.token_map.items()}
+ }
+
+ if self.authorization_request:
+ if isinstance(self.authorization_request, Message):
+ d["authorization_request"] = self.authorization_request.to_dict()
+ elif isinstance(self.authorization_request, dict):
+ d["authorization_request"] = self.authorization_request
+
+ if self.authentication_event:
+ if isinstance(self.authentication_event, Message):
+ d["authentication_event"] = self.authentication_event.to_dict()
+ elif isinstance(self.authentication_event, dict):
+ d["authentication_event"] = self.authentication_event
+
+ return json.dumps(d)
+
+ def from_json(self, json_str: str) -> 'Grant':
+ d = json.loads(json_str)
+ for attr in ["scope", "authorization_details", "claims", "resources", "sub",
+ "issued_at", "not_before", "expires_at", "revoked", "id",
+ "usage_rules"]:
+ if attr in d:
+ setattr(self, attr, d[attr])
+
+ if "authentication_event" in d:
+ self.authentication_event = AuthnEvent(**d["authentication_event"])
+ if "authorization_request" in d:
+ self.authorization_request = AuthorizationRequest(**d["authorization_request"])
+
+ if "token_map" in d:
+ self.token_map = {k: importer(v) for k, v in d["token_map"].items()}
+ else:
+ self.token_map = TOKEN_MAP
+
+ if "issued_token" in d:
+ _it = []
+ for js in d["issued_token"]:
+ args = json.loads(js)
+ _it.append(self.token_map[args["type"]](**args))
+ setattr(self, "issued_token", _it)
+
+ return self
+
+ def payload_arguments(self, session_id: str, endpoint_context,
+ token_type: str, scope: Optional[dict] = None) -> dict:
+ """
+
+ :return: dictionary containing information to place in a token value
+ """
+ if not scope:
+ scope = self.scope
+
+ payload = {
+ "scope": scope,
+ "aud": self.resources
+ }
+
+ _claims_restriction = endpoint_context.claims_interface.get_claims(session_id,
+ scopes=scope,
+ usage=token_type)
+ user_id, _, _ = unpack_session_key(session_id)
+ user_info = endpoint_context.claims_interface.get_user_claims(user_id,
+ _claims_restriction)
+ payload.update(user_info)
+
+ return payload
+
+ def mint_token(self,
+ session_id: str,
+ endpoint_context: object,
+ token_type: str,
+ token_handler: TokenHandler = None,
+ based_on: Optional[Token] = None,
+ usage_rules: Optional[dict] = None,
+ scope: Optional[list] = None,
+ **kwargs) -> Optional[Token]:
+ """
+
+ :param session_id:
+ :param endpoint_context:
+ :param token_type:
+ :param token_handler:
+ :param based_on:
+ :param usage_rules:
+ :param scope:
+ :param kwargs:
+ :return:
+ """
+ if self.is_active() is False:
+ return None
+
+ if based_on:
+ if based_on.supports_minting(token_type) and based_on.is_active():
+ _base_on_ref = based_on.value
+ else:
+ raise MintingNotAllowed()
+ else:
+ _base_on_ref = None
+
+ if usage_rules is None and token_type in self.usage_rules:
+ usage_rules = self.usage_rules[token_type]
+
+ token_class = self.token_map.get(token_type)
+ if token_class:
+ item = token_class(type=token_type,
+ based_on=_base_on_ref,
+ usage_rules=usage_rules,
+ scope=scope,
+ **kwargs)
+ if token_handler is None:
+ token_handler = endpoint_context.session_manager.token_handler.handler[
+ GRANT_TYPE_MAP[token_type]]
+
+ item.value = token_handler(session_id=session_id,
+ **self.payload_arguments(session_id,
+ endpoint_context,
+ token_type=token_type,
+ scope=scope))
+ else:
+ raise ValueError("Can not mint that kind of token")
+
+ self.issued_token.append(item)
+ self.used += 1
+ return item
+
+ def get_token(self, value: str) -> Optional[Token]:
+ for t in self.issued_token:
+ if t.value == value:
+ return t
+ return None
+
+ def revoke_token(self,
+ value: Optional[str] = "",
+ based_on: Optional[str] = "",
+ recursive: bool = True):
+ for t in self.issued_token:
+ if not value and not based_on:
+ t.revoked = True
+ elif value and based_on:
+ if value == t.value and based_on == t.based_on:
+ t.revoked = True
+ elif value and t.value == value:
+ t.revoked = True
+ if recursive:
+ self.revoke_token(based_on=t.value)
+ elif based_on and t.based_on == based_on:
+ t.revoked = True
+ if recursive:
+ self.revoke_token(based_on=t.value)
+
+ def get_spec(self, token: Token) -> Optional[dict]:
+ if self.is_active() is False or token.is_active is False:
+ return None
+
+ res = {}
+ for attr in ["scope", "claims", "resources"]:
+ _val = getattr(token, attr)
+ if _val:
+ res[attr] = _val
+ else:
+ _val = getattr(self, attr)
+ if _val:
+ res[attr] = _val
+ return res
+
+
+DEFAULT_USAGE = {
+ "authorization_code": {
+ "max_usage": 1,
+ "supports_minting": ["access_token", "refresh_token", "id_token"],
+ "expires_in": 300
+ },
+ "access_token": {
+ "supports_minting": [],
+ "expires_in": 3600
+ },
+ "refresh_token": {
+ "supports_minting": ["access_token", "refresh_token", "id_token"]
+ }
+}
+
+
+def get_usage_rules(token_type, endpoint_context, grant, client_id):
+ """
+ The order of importance:
+ Grant, Client, EndPointContext, Default
+
+ :param token_type: The type of token
+ :param endpoint_context: An EndpointContext instance
+ :param grant: A Grant instance
+ :param client_id: The client identifier
+ :return: Usage specification
+ """
+
+ _usage = endpoint_context.authz.usage_rules_for(client_id, token_type)
+ if not _usage:
+ _usage = DEFAULT_USAGE[token_type]
+
+ _grant_usage = grant.usage_rules.get(token_type)
+ if _grant_usage:
+ _usage.update(_grant_usage)
+
+ return _usage
+
+
+class ExchangeGrant(Grant):
+ attributes = Grant.attributes
+ attributes.append("users")
+ type = "exchange_grant"
+
+ def __init__(self,
+ scope: Optional[list] = None,
+ claims: Optional[dict] = None,
+ resources: Optional[list] = None,
+ authorization_details: Optional[dict] = None,
+ issued_token: Optional[list] = None,
+ usage_rules: Optional[dict] = None,
+ issued_at: int = 0,
+ expires_in: int = 0,
+ expires_at: int = 0,
+ revoked: bool = False,
+ token_map: Optional[dict] = None,
+ users: list = None):
+ Grant.__init__(self, scope=scope, claims=claims, resources=resources,
+ authorization_details=authorization_details,
+ issued_token=issued_token, usage_rules=usage_rules,
+ issued_at=issued_at, expires_in=expires_in,
+ expires_at=expires_at, revoked=revoked,
+ token_map=token_map)
+
+ self.users = users or []
+ self.usage_rules = {
+ "access_token": {
+ "supports_minting": ["access_token"],
+ "expires_in": 60
+ }
+ }
diff --git a/src/oidcendpoint/session/info.py b/src/oidcendpoint/session/info.py
new file mode 100644
index 0000000..4126db0
--- /dev/null
+++ b/src/oidcendpoint/session/info.py
@@ -0,0 +1,70 @@
+from typing import Tuple
+
+from oidcmsg.message import OPTIONAL_LIST_OF_MESSAGES
+from oidcmsg.message import SINGLE_OPTIONAL_STRING
+from oidcmsg.message import SINGLE_REQUIRED_BOOLEAN
+from oidcmsg.message import SINGLE_REQUIRED_STRING
+from oidcmsg.message import Message
+from oidcmsg.message import list_deserializer
+from oidcmsg.message import list_serializer
+
+from oidcendpoint.session.grant import Grant
+from oidcendpoint.session.token import Token
+
+
+class SessionInfo(Message):
+ c_param = {
+ "subordinate": ([str], False, list_serializer, list_deserializer, True),
+ "revoked": SINGLE_REQUIRED_BOOLEAN,
+ "type": SINGLE_REQUIRED_STRING
+ }
+
+ def __init__(self, **kwargs):
+ Message.__init__(self, **kwargs)
+ if "subordinate" not in self:
+ self["subordinate"] = []
+ self["revoked"] = False
+
+ def add_subordinate(self, value: str) -> "SessionInfo":
+ if value not in self["subordinate"]:
+ self["subordinate"].append(value)
+ return self
+
+ def remove_subordinate(self, value: str) -> 'SessionInfo':
+ self["subordinate"].remove(value)
+ return self
+
+ def revoke(self) -> 'SessionInfo':
+ self["revoked"] = True
+ return self
+
+ def is_revoked(self) -> bool:
+ return self["revoked"]
+
+
+class UserSessionInfo(SessionInfo):
+ c_param = SessionInfo.c_param.copy()
+ c_param.update({
+ "user_id": SINGLE_REQUIRED_STRING,
+ })
+
+ def __init__(self, **kwargs):
+ SessionInfo.__init__(self, **kwargs)
+ self["type"] = "UserSessionInfo"
+
+
+class ClientSessionInfo(SessionInfo):
+ c_param = SessionInfo.c_param.copy()
+ c_param.update({
+ "client_id": SINGLE_REQUIRED_STRING
+ })
+
+ def __init__(self, **kwargs):
+ SessionInfo.__init__(self, **kwargs)
+ self["type"] = "ClientSessionInfo"
+
+ def find_grant_and_token(self, val: str) -> Tuple[Grant, Token]:
+ for grant in self["subordinate"]:
+ token = grant.get_token(val)
+ if token:
+ return grant, token
diff --git a/src/oidcendpoint/session/manager.py b/src/oidcendpoint/session/manager.py
new file mode 100644
index 0000000..b297ca3
--- /dev/null
+++ b/src/oidcendpoint/session/manager.py
@@ -0,0 +1,437 @@
+import hashlib
+import logging
+import uuid
+from typing import List
+from typing import Optional
+
+from oidcmsg.oauth2 import AuthorizationRequest
+
+from oidcendpoint import rndstr
+from oidcendpoint.authn_event import AuthnEvent
+from oidcendpoint.token import handler
+
+from ..token import UnknownToken
+from ..token.handler import TokenHandler
+from . import session_key
+from . import unpack_session_key
+from .database import Database
+from .grant import Grant
+from .grant import Token
+from .info import ClientSessionInfo
+from .info import UserSessionInfo
+
+logger = logging.getLogger(__name__)
+
+
+def pairwise_id(uid, sector_identifier, salt="", **kwargs):
+ return hashlib.sha256(("%s%s%s" % (uid, sector_identifier, salt)).encode("utf-8")).hexdigest()
+
+
+def public_id(uid, salt="", **kwargs):
+ return hashlib.sha256("{}{}".format(uid, salt).encode("utf-8")).hexdigest()
+
+
+def ephemeral_id(*args, **kwargs):
+ return uuid.uuid4().hex
+
+
+class SessionManager(Database):
+ def __init__(self, handler: TokenHandler, db: Optional[object] = None,
+ conf: Optional[dict] = None, sub_func: Optional[dict] = None):
+ Database.__init__(self, db)
+ self.token_handler = handler
+ self.salt = rndstr(32)
+ self.conf = conf or {}
+
+ # this allows the subject identifier minters to be defined by someone
+ # else then me.
+ if sub_func is None:
+ self.sub_func = {
+ "public": public_id,
+ "pairwise": pairwise_id,
+ "ephemeral": ephemeral_id
+ }
+ else:
+ self.sub_func = sub_func
+ if "public" not in sub_func:
+ self.sub_func["public"] = public_id
+ if "pairwise" not in sub_func:
+ self.sub_func["pairwise"] = pairwise_id
+ if "ephemeral" not in sub_func:
+ self.sub_func["ephemeral"] = ephemeral_id
+
+ def get_user_info(self, uid: str) -> UserSessionInfo:
+ usi = self.get([uid])
+ if isinstance(usi, UserSessionInfo):
+ return usi
+ else:
+ raise ValueError("Not UserSessionInfo")
+
+ def find_token(self, session_id: str, token_value: str) -> Optional[Token]:
+ """
+
+ :param session_id: Based on 3-tuple, user_id, client_id and grant_id
+ :param token_value:
+ :return:
+ """
+ user_id, client_id, grant_id = unpack_session_key(session_id)
+ grant = self.get([user_id, client_id, grant_id])
+ for token in grant.issued_token:
+ if token.value == token_value:
+ return token
+
+ return None
+
+ def create_grant(self,
+ authn_event: AuthnEvent,
+ auth_req: AuthorizationRequest,
+ user_id: str,
+ client_id: Optional[str] = "",
+ sub_type: Optional[str] = "public",
+ token_usage_rules: Optional[dict] = None,
+ scopes: Optional[list] = None
+ ) -> str:
+ """
+
+ :param scopes: Scopes
+ :param authn_event: AuthnEvent instance
+ :param auth_req:
+ :param user_id:
+ :param client_id:
+ :param sub_type:
+ :param token_usage_rules:
+ :return:
+ """
+ try:
+ sector_identifier = auth_req.get("sector_identifier_uri")
+ except AttributeError:
+ sector_identifier = ""
+
+ grant = Grant(authorization_request=auth_req,
+ authentication_event=authn_event,
+ sub=self.sub_func[sub_type](
+ user_id, salt=self.salt,
+ sector_identifier=sector_identifier),
+ usage_rules=token_usage_rules,
+ scope=scopes
+ )
+
+ self.set([user_id, client_id, grant.id], grant)
+
+ return session_key(user_id, client_id, grant.id)
+
+ def create_session(self,
+ authn_event: AuthnEvent,
+ auth_req: AuthorizationRequest,
+ user_id: str,
+ client_id: Optional[str] = "",
+ sub_type: Optional[str] = "public",
+ token_usage_rules: Optional[dict] = None,
+ scopes: Optional[list] = None
+ ) -> str:
+ """
+ Create part of a user session. The parts added are user- and client
+ information and a grant.
+
+ :param scopes:
+ :param authn_event: Authentication Event information
+ :param auth_req: Authorization Request
+ :param client_id: Client ID
+ :param user_id: User ID
+ :param sector_identifier: Identifier for a group of websites under common administrative
+ control.
+ :param sub_type: What kind of subject will be assigned
+ :param token_usage_rules: Rules for how tokens can be used
+ :return: Session key
+ """
+
+ try:
+ _usi = self.get([user_id])
+ except KeyError:
+ _usi = UserSessionInfo(user_id=user_id)
+ self.set([user_id], _usi)
+
+ if not client_id:
+ client_id = auth_req['client_id']
+
+ client_info = ClientSessionInfo(
+ client_id=client_id,
+ )
+
+ self.set([user_id, client_id], client_info)
+
+ return self.create_grant(
+ auth_req=auth_req,
+ authn_event=authn_event,
+ user_id=user_id,
+ client_id=client_id,
+ sub_type=sub_type,
+ token_usage_rules=token_usage_rules,
+ scopes=scopes
+ )
+
+ def __getitem__(self, session_id: str):
+ return self.get(unpack_session_key(session_id))
+
+ def __setitem__(self, session_id: str, value):
+ return self.set(unpack_session_key(session_id), value)
+
+ def get_client_session_info(self, session_id: str) -> ClientSessionInfo:
+ """
+ Return client connected information for a user session.
+
+ :param session_id: Session identifier
+ :return: ClientSessionInfo instance
+ """
+ _user_id, _client_id, _grant_id = unpack_session_key(session_id)
+ csi = self.get([_user_id, _client_id])
+ if isinstance(csi, ClientSessionInfo):
+ return csi
+ else:
+ raise ValueError("Wrong type of session info")
+
+ def get_user_session_info(self, session_id: str) -> UserSessionInfo:
+ """
+ Return client connected information for a user session.
+
+ :param session_id: Session identifier
+ :return: ClientSessionInfo instance
+ """
+ _user_id, _client_id, _grant_id = unpack_session_key(session_id)
+ usi = self.get([_user_id])
+ if isinstance(usi, UserSessionInfo):
+ return usi
+ else:
+ raise ValueError("Wrong type of session info")
+
+ def get_grant(self, session_id: str) -> Grant:
+ """
+ Return client connected information for a user session.
+
+ :param session_id: Session identifier
+ :return: ClientSessionInfo instance
+ """
+ _user_id, _client_id, _grant_id = unpack_session_key(session_id)
+ grant = self.get([_user_id, _client_id, _grant_id])
+ if isinstance(grant, Grant):
+ return grant
+ else:
+ raise ValueError("Wrong type of item")
+
+ def _revoke_dependent(self, grant: Grant, token: Token):
+ for t in grant.issued_token:
+ if t.based_on == token.value:
+ t.revoked = True
+ self._revoke_dependent(grant, t)
+
+ def revoke_token(self, session_id: str, token_value: str, recursive: bool = False):
+ """
+ Revoke a specific token that belongs to a specific user session.
+
+ :param session_id: Session identifier
+ :param token_value: Token value
+ :param recursive: Revoke all tokens that was minted using this token or
+ tokens minted by this token. Recursively.
+ """
+ token = self.find_token(session_id, token_value)
+ if token is None:
+ raise UnknownToken()
+
+ token.revoked = True
+ if recursive:
+ grant = self[session_id]
+ self._revoke_dependent(grant, token)
+
+ def get_authentication_events(self, session_id: Optional[str] = "",
+ user_id: Optional[str] = "",
+ client_id: Optional[str] = "") -> List[AuthnEvent]:
+ """
+ Return the authentication events that exists for a user/client combination.
+
+ :param session_id: A session identifier
+ :return: None if no authentication event could be found or an AuthnEvent instance.
+ """
+ if session_id:
+ user_id, client_id, _ = unpack_session_key(session_id)
+ elif user_id and client_id:
+ pass
+ else:
+ raise AttributeError("Must have session_id or user_id and client_id")
+
+ c_info = self.get([user_id, client_id])
+
+ _grants = [self.get([user_id, client_id, gid]) for gid in c_info['subordinate']]
+ return [g.authentication_event for g in _grants]
+
+ def get_authorization_request(self, session_id):
+ res = self.get_session_info(session_id=session_id, authorization_request=True)
+ return res["authorization_request"]
+
+ def get_authentication_event(self, session_id):
+ res = self.get_session_info(session_id=session_id, authentication_event=True)
+ return res["authentication_event"]
+
+ def revoke_client_session(self, session_id: str):
+ """
+ Revokes a client session
+
+ :param session_id: Session identifier
+ """
+ parts = unpack_session_key(session_id)
+ if len(parts) == 2:
+ _user_id, _client_id = parts
+ elif len(parts) == 3:
+ _user_id, _client_id, _ = parts
+ else:
+ raise ValueError("Invalid session ID")
+
+ _info = self.get([_user_id, _client_id])
+ self.set([_user_id, _client_id], _info.revoke())
+
+ def revoke_grant(self, session_id: str):
+ """
+ Revokes the grant pointed to by a session identifier.
+
+ :param session_id: A session identifier
+ """
+ _path = unpack_session_key(session_id)
+ _info = self.get(_path)
+ _info.revoke()
+ self.set(_path, _info)
+
+ def grants(self,
+ session_id: Optional[str] = "",
+ user_id: Optional[str] = "",
+ client_id: Optional[str] = "") -> List[Grant]:
+ """
+ Find all grant connected to a user session
+
+ :param session_id: A session identifier
+ :return: A list of grants
+ """
+ if session_id:
+ user_id, client_id, _ = unpack_session_key(session_id)
+ elif user_id and client_id:
+ pass
+ else:
+ raise AttributeError("Must have session_id or user_id and client_id")
+
+ _csi = self.get([user_id, client_id])
+ return [self.get([user_id, client_id, gid]) for gid in _csi['subordinate']]
+
+ def get_session_info(self,
+ session_id: str,
+ user_session_info: bool = False,
+ client_session_info: bool = False,
+ grant: bool = False,
+ authentication_event: bool = False,
+ authorization_request: bool = False) -> dict:
+ """
+ Returns information connected to a session.
+
+ :param session_id: The identifier of the session
+ :param user_session_info: Whether user session info should part of the response
+ :param client_session_info: Whether client session info should part of the response
+ :param grant: Whether the grant should part of the response
+ :param authentication_event: Whether the authentication event information should part of
+ the response
+ :param authorization_request: Whether the authorization_request should part of the response
+ :return: A dictionary with session information
+ """
+ _user_id, _client_id, _grant_id = unpack_session_key(session_id)
+ _grant = None
+ res = {
+ "session_id": session_id,
+ "user_id": _user_id,
+ "client_id": _client_id,
+ "grant_id": _grant_id
+ }
+ if user_session_info:
+ res["user_session_info"] = self.get([_user_id])
+ if client_session_info:
+ res["client_session_info"] = self.get([_user_id, _client_id])
+ if grant:
+ res["grant"] = self.get([_user_id, _client_id, _grant_id])
+
+ if authentication_event:
+ if grant:
+ res["authentication_event"] = res["grant"]["authentication_event"]
+ else:
+ _grant = self.get([_user_id, _client_id, _grant_id])
+ res["authentication_event"] = _grant.authentication_event
+
+ if authorization_request:
+ if grant:
+ res["authorization_request"] = res["grant"].authorization_request
+ elif _grant:
+ res["authorization_request"] = _grant.authorization_request
+ else:
+ _grant = self.get([_user_id, _client_id, _grant_id])
+ res["authorization_request"] = _grant.authorization_request
+
+ return res
+
+ def get_session_info_by_token(self,
+ token_value: str,
+ user_session_info: bool = False,
+ client_session_info: bool = False,
+ grant: bool = False,
+ authentication_event: bool = False,
+ authorization_request: bool = False
+ ) -> dict:
+ _token_info = self.token_handler.info(token_value)
+ return self.get_session_info(_token_info["sid"],
+ user_session_info=user_session_info,
+ client_session_info=client_session_info,
+ grant=grant,
+ authentication_event=authentication_event,
+ authorization_request=authorization_request)
+
+ def get_session_id_by_token(self, token_value: str) -> str:
+ _token_info = self.token_handler.info(token_value)
+ return _token_info["sid"]
+
+ def add_grant(self, user_id: str, client_id: str, **kwargs) -> Grant:
+ """
+ Creates and adds a grant to a user session.
+
+ :param user_id: User identifier
+ :param client_id: Client identifier
+ :param kwargs: Keyword arguments to the Grant class initialization
+ :return: A Grant instance
+ """
+ args = {k: v for k, v in kwargs.items() if k in Grant.attributes}
+ _grant = Grant(**args)
+ self.set([user_id, client_id, _grant.id], _grant)
+ _client_session_info = self.get([user_id, client_id])
+ _client_session_info["subordinate"].append(_grant.id)
+ self.set([user_id, client_id], _client_session_info)
+ return _grant
+
+ # def find_grant_by_type_and_target(
+ # self,
+ # token: str,
+ # token_type: Grant,
+ # resource_server: str) -> Optional[Grant]:
+ # """
+ #
+ # :param token:
+ # :param token_type:
+ # :param resource_server:
+ # :return:
+ # """
+ # session_id = self.get_session_id_by_token(token)
+ # for grant in self.grants(session_id=session_id):
+ # if isinstance(grant, token_type):
+ # if resource_server in grant.resources:
+ # return grant
+ # return None
+
+ def remove_session(self, session_id: str):
+ _user_id, _client_id, _grant_id = unpack_session_key(session_id)
+ self.delete([_user_id, _client_id, _grant_id])
+
+
+def create_session_manager(endpoint_context, token_handler_args, db=None, sub_func=None):
+ _token_handler = handler.factory(endpoint_context, **token_handler_args)
+ return SessionManager(_token_handler, db=db, sub_func=sub_func)
diff --git a/src/oidcendpoint/session/storage.py b/src/oidcendpoint/session/storage.py
new file mode 100644
index 0000000..150467f
--- /dev/null
+++ b/src/oidcendpoint/session/storage.py
@@ -0,0 +1,22 @@
+import json
+
+from .grant import ExchangeGrant
+from .grant import Grant
+from .info import ClientSessionInfo
+from .info import UserSessionInfo
+
+
+class JSON:
+ def serialize(self, instance):
+ return instance.to_json()
+
+ def deserialize(self, js):
+ args = json.loads(js)
+ if args["type"] == "UserSessionInfo":
+ return UserSessionInfo().from_json(js)
+ elif args["type"] == "ClientSessionInfo":
+ return ClientSessionInfo().from_json(js)
+ elif args["type"] == "grant":
+ return Grant().from_json(js)
+ elif args["type"] == "exchange_grant":
+ return ExchangeGrant().from_json(js)
diff --git a/src/oidcendpoint/session/token.py b/src/oidcendpoint/session/token.py
new file mode 100644
index 0000000..fd8b32b
--- /dev/null
+++ b/src/oidcendpoint/session/token.py
@@ -0,0 +1,163 @@
+import json
+from typing import Optional
+from uuid import uuid1
+
+from oidcmsg.time_util import time_sans_frac
+
+
+class MintingNotAllowed(Exception):
+ pass
+
+
+class Item:
+ def __init__(self,
+ usage_rules: Optional[dict] = None,
+ issued_at: int = 0,
+ expires_in: int = 0,
+ expires_at: int = 0,
+ not_before: int = 0,
+ revoked: bool = False,
+ used: int = 0
+ ):
+ self.issued_at = issued_at or time_sans_frac()
+ self.not_before = not_before
+ if expires_at == 0 and expires_in != 0:
+ self.set_expires_at(expires_in)
+ else:
+ self.expires_at = expires_at
+
+ self.revoked = revoked
+ self.used = used
+ self.usage_rules = usage_rules or {}
+
+ def set_expires_at(self, expires_in):
+ self.expires_at = time_sans_frac() + expires_in
+
+ def max_usage_reached(self):
+ if "max_usage" in self.usage_rules:
+ return self.used >= self.usage_rules['max_usage']
+ else:
+ return False
+
+ def is_active(self, now=0):
+ if self.max_usage_reached():
+ return False
+
+ if self.revoked:
+ return False
+
+ if now == 0:
+ now = time_sans_frac()
+
+ if self.not_before:
+ if now < self.not_before:
+ return False
+
+ if self.expires_at:
+ if now > self.expires_at:
+ return False
+
+ return True
+
+ def revoke(self):
+ self.revoked = True
+
+
+class Token(Item):
+ attributes = ["type", "issued_at", "not_before", "expires_at", "revoked", "value",
+ "usage_rules", "used", "based_on", "id", "scope", "claims",
+ "resources"]
+
+ def __init__(self,
+ type: str = '',
+ value: str = '',
+ based_on: Optional[str] = None,
+ usage_rules: Optional[dict] = None,
+ issued_at: int = 0,
+ expires_in: int = 0,
+ expires_at: int = 0,
+ not_before: int = 0,
+ revoked: bool = False,
+ used: int = 0,
+ id: str = "",
+ scope: Optional[list] = None,
+ claims: Optional[dict] = None,
+ resources: Optional[list] = None,
+ ):
+ Item.__init__(self, usage_rules=usage_rules, issued_at=issued_at, expires_in=expires_in,
+ expires_at=expires_at, not_before=not_before, revoked=revoked, used=used)
+
+ self.type = type
+ self.value = value
+ self.based_on = based_on
+ self.id = id or uuid1().hex
+ self.set_defaults()
+ self.scope = scope or []
+ self.claims = claims or {} # default is to not release any user information
+ self.resources = resources or []
+
+ def set_defaults(self):
+ pass
+
+ def register_usage(self):
+ self.used += 1
+
+ def has_been_used(self):
+ return self.used != 0
+
+ def to_json(self):
+ d = {
+ "type": self.type,
+ "issued_at": self.issued_at,
+ "not_before": self.not_before,
+ "expires_at": self.expires_at,
+ "revoked": self.revoked,
+ "value": self.value,
+ "usage_rules": self.usage_rules,
+ "used": self.used,
+ "based_on": self.based_on,
+ "id": self.id,
+ "scope": self.scope,
+ "claims": self.claims,
+ "resources": self.resources
+ }
+ return json.dumps(d)
+
+ def from_json(self, json_str):
+ d = json.loads(json_str)
+ for attr in self.attributes:
+ if attr in d:
+ setattr(self, attr, d[attr])
+ return self
+
+ def supports_minting(self, token_type):
+ _supports_minting = self.usage_rules.get("supports_minting")
+ if _supports_minting is None:
+ return False
+ else:
+ return token_type in _supports_minting
+
+
+class AccessToken(Token):
+ pass
+
+
+class AuthorizationCode(Token):
+ def set_defaults(self):
+ if "supports_minting" not in self.usage_rules:
+ self.usage_rules['supports_minting'] = ["access_token", "refresh_token"]
+
+ self.usage_rules['max_usage'] = 1
+
+
+class RefreshToken(Token):
+ def set_defaults(self):
+ if "supports_minting" not in self.usage_rules:
+ self.usage_rules['supports_minting'] = ["access_token", "refresh_token"]
+
+
+SHORT_TYPE_NAME = {
+ "authorization_code": "A",
+ "access_token": "T",
+ "refresh_token": "R"
+}
diff --git a/src/oidcendpoint/sso_db.py b/src/oidcendpoint/sso_db.py
deleted file mode 100644
index ea30364..0000000
--- a/src/oidcendpoint/sso_db.py
+++ /dev/null
@@ -1,196 +0,0 @@
-import logging
-
-from oidcmsg.storage.init import storage_factory
-
-logger = logging.getLogger(__name__)
-
-
-class DictDatabase:
- def __init__(self, db_conf=None, db=None):
- if db_conf:
- self._db = storage_factory(db_conf)
- else:
- self._db = db
-
- def set(self, key, sec_key, value):
- logger.debug("SSODb set {} - {}: {}".format(key, sec_key, value))
- # Try loading the key, that's a good place to put a debugger to
- # import pdb; pdb.set_trace()
- _dic = self._db.get(key, {})
- if sec_key in _dic:
- _dic[sec_key].append(value)
- else:
- _dic[sec_key] = [value]
- self._db[key] = _dic
-
- def get(self, key, sec_key):
- _dic = self._db.get(key, {})
- logger.debug("SSODb get {} - {}: {}".format(key, sec_key, _dic))
- return _dic.get(sec_key, [])
-
- def delete(self, key, sec_key):
- _dic = self._db.get(key, {})
- try:
- del _dic[sec_key]
- except KeyError as e:
- logger.warn('SSODb.delete _dic[{}] not found'.format(sec_key))
- if _dic == {}:
- del self._db[key]
- else:
- self._db[key] = _dic
-
- def remove(self, key, sec_key, value):
- _dic = self._db.get(key, {})
- _values = _dic[sec_key]
- # full clean up
- while value in _values:
- _values.remove(value)
- # if changes have been made -> update them
-
- if _values:
- _dic[sec_key] = _values
- self._db[key] = _dic
- else:
- del _dic[sec_key]
- self._db[key] = _dic
-
-
-class SSODb(DictDatabase):
- """
- Keeps the connection between an user, one or more sub claims and
- possibly several session IDs.
- Each user can be represented by more then one sub claim and every sub claim
- can appear in more the one session.
- So, we have chains like this::
-
- session id->subject id->user id
-
- """
-
- def map_sid2uid(self, sid, uid):
- """
- Store the connection between a Session ID and a User ID
-
- :param sid: Session ID
- :param uid: User ID
- """
- self.set(sid, "uid", uid)
- self.set(uid, "sid", sid)
-
- def map_sid2sub(self, sid, sub):
- """
- Store the connection between a Session ID and a subject ID.
-
- :param sid: Session ID
- :param sub: subject ID
- """
- self.set(sid, "sub", sub)
- self.set(sub, "sid", sid)
-
- def get_sids_by_uid(self, uid):
- """
- Return a list of session IDs that this user is connected to.
-
- :param uid: The subject ID
- :return: list of session IDs
- """
- return self.get(uid, "sid")
-
- def get_sids_by_sub(self, sub):
- return self.get(sub, "sid")
-
- def get_sub_by_sid(self, sid):
- _subs = self.get(sid, "sub")
- if _subs:
- return _subs[0]
- else:
- return None
-
- def get_uid_by_sid(self, sid):
- """
- Find the User ID that is connected to a Session ID.
-
- :param sid: A Session ID
- :return: A User ID, always just one
- """
- _uids = self.get(sid, "uid")
- if _uids:
- return _uids[0]
- else:
- return None
-
- def get_subs_by_uid(self, uid):
- """
- Find all subject identifiers that is connected to a User ID.
-
- :param uid: A User ID
- :return: A set of subject identifiers
- """
- res = set()
- for sid in self.get(uid, "sid"):
- res |= set(self.get(sid, "sub"))
- return res
-
- def remove_sid2sub(self, sid, sub):
- """
- Remove the connection between a session ID and a Subject
-
- :param sid: Session ID
- :param sub: Subject identifier
-´ """
- self.remove(sub, "sid", sid)
- self.remove(sid, "sub", sub)
-
- def remove_sid2uid(self, sid, uid):
- """
- Remove the connection between a session ID and a Subject
-
- :param sid: Session ID
- :param uid: User identifier
-´ """
- self.remove(uid, "sid", sid)
- self.remove(sid, "uid", uid)
-
- def remove_session_id(self, sid):
- """
- Remove all references to a specific Session ID
-
- :param sid: A Session ID
- """
- for uid in self.get(sid, "uid"):
- self.remove(uid, "sid", sid)
- if self._db.get(uid) == {}:
- del self._db[uid]
- self.delete(sid, "uid")
-
- for sub in self.get(sid, "sub"):
- self.remove(sub, "sid", sid)
- if self._db.get(sub) == {}:
- del self._db[sub]
- self.delete(sid, "sub")
-
- def remove_uid(self, uid):
- """
- Remove all references to a specific User ID
-
- :param uid: A User ID
- """
- for sid in self.get(uid, "sid"):
- self.remove(sid, "uid", uid)
- self.delete(uid, "sid")
-
- def remove_sub(self, sub):
- """
- Remove all references to a specific Subject ID
-
- :param sub: A Subject ID
- """
- for _sid in self.get(sub, "sid"):
- self.remove(_sid, "sub", sub)
- self.delete(sub, "sid")
-
- def close(self):
- self._db.close()
-
- def clear(self):
- self._db.clear()
diff --git a/src/oidcendpoint/token_handler.py b/src/oidcendpoint/token/__init__.py
similarity index 50%
rename from src/oidcendpoint/token_handler.py
rename to src/oidcendpoint/token/__init__.py
index 692d7d5..f6ef69a 100755
--- a/src/oidcendpoint/token_handler.py
+++ b/src/oidcendpoint/token/__init__.py
@@ -1,18 +1,16 @@
import base64
import hashlib
import logging
-import warnings
+from typing import Optional
from cryptography.fernet import Fernet
-from cryptography.fernet import InvalidToken
-from cryptojwt.exception import Invalid
-from cryptojwt.key_jar import init_key_jar
from cryptojwt.utils import as_bytes
from cryptojwt.utils import as_unicode
from oidcmsg.time_util import time_sans_frac
from oidcendpoint import rndstr
-from oidcendpoint.util import importer
+from oidcendpoint.token.exception import UnknownToken
+from oidcendpoint.token.exception import WrongTokenType
from oidcendpoint.util import lv_pack
from oidcendpoint.util import lv_unpack
@@ -21,26 +19,6 @@
logger = logging.getLogger(__name__)
-class ExpiredToken(Exception):
- pass
-
-
-class WrongTokenType(Exception):
- pass
-
-
-class AccessCodeUsed(Exception):
- pass
-
-
-class UnknownToken(Exception):
- pass
-
-
-class NotAllowed(Exception):
- pass
-
-
def is_expired(exp, when=0):
if exp < 0:
return False
@@ -74,13 +52,16 @@ class Token(object):
def __init__(self, typ, lifetime=300, **kwargs):
self.type = typ
self.lifetime = lifetime
- self.args = kwargs
+ self.kwargs = kwargs
- def __call__(self, sid):
+ def __call__(self,
+ session_id: Optional[str] = '',
+ ttype: Optional[str] = '',
+ **payload) -> str:
"""
Return a token.
- :param sid: Session id
+ :param payload: Information to place in the token if possible.
:return:
"""
raise NotImplementedError()
@@ -121,13 +102,14 @@ def __init__(self, password, typ="", token_type="Bearer", **kwargs):
self.crypt = Crypt(password)
self.token_type = token_type
- def __call__(self, sid="", ttype="", **kwargs):
+ def __call__(self,
+ session_id: Optional[str] = '',
+ ttype: Optional[str] = '',
+ **payload) -> str:
"""
Return a token.
- :param ttype: Type of token
- :param prev: Previous token, if there is one to go from
- :param sid: Session id
+ :param payload: Token information
:return:
"""
if not ttype and self.type:
@@ -146,7 +128,7 @@ def __call__(self, sid="", ttype="", **kwargs):
rnd = rndstr(32) # Ultimate length multiple of 16
return base64.b64encode(
- self.crypt.encrypt(lv_pack(rnd, ttype, sid, exp).encode())
+ self.crypt.encrypt(lv_pack(rnd, ttype, session_id, exp).encode())
).decode("utf-8")
def key(self, user="", areq=None):
@@ -169,7 +151,7 @@ def split_token(self, token):
# order: rnd, type, sid
return lv_unpack(plain)
- def info(self, token):
+ def info(self, token: str) -> dict:
"""
Return token information.
@@ -183,131 +165,10 @@ def info(self, token):
_res["handler"] = self
return _res
- def is_expired(self, token, when=0):
+ def is_expired(self, token: str, when: int = 0):
_exp = self.info(token)["exp"]
if _exp == "-1":
return False
else:
exp = int(_exp)
return is_expired(exp, when)
-
-
-class TokenHandler(object):
- def __init__(
- self, access_token_handler=None, code_handler=None, refresh_token_handler=None
- ):
-
- self.handler = {"code": code_handler, "access_token": access_token_handler}
-
- self.handler_order = ["code", "access_token"]
-
- if refresh_token_handler:
- self.handler["refresh_token"] = refresh_token_handler
- self.handler_order.append("refresh_token")
-
- def __getitem__(self, typ):
- return self.handler[typ]
-
- def __contains__(self, item):
- return item in self.handler
-
- def info(self, item, order=None):
- _handler, item_info = self.get_handler(item, order)
-
- if _handler is None:
- logger.info("Unknown token format")
- raise UnknownToken(item)
- else:
- return item_info
-
- def sid(self, token, order=None):
- return self.info(token, order)["sid"]
-
- def type(self, token, order=None):
- return self.info(token, order)["type"]
-
- def get_handler(self, token, order=None):
- if order is None:
- order = self.handler_order
-
- for typ in order:
- try:
- res = self.handler[typ].info(token)
- except (KeyError, WrongTokenType, InvalidToken, UnknownToken, Invalid):
- pass
- else:
- return self.handler[typ], res
-
- return None, None
-
- def keys(self):
- return self.handler.keys()
-
-
-def init_token_handler(ec, spec, typ):
- try:
- _cls = spec["class"]
- except KeyError:
- cls = DefaultToken
- else:
- cls = importer(_cls)
-
- _kwargs = spec.get("kwargs")
- if _kwargs is None:
- if cls != DefaultToken:
- warnings.warn(
- "Token initialisation arguments should be grouped under 'kwargs'.",
- DeprecationWarning,
- stacklevel=2,
- )
- _kwargs = spec
-
- return cls(typ=typ, ec=ec, **_kwargs)
-
-
-def _add_passwd(keyjar, conf, kid):
- if keyjar:
- _keys = keyjar.get_encrypt_key(key_type="oct", kid=kid)
- if _keys:
- pw = as_unicode(_keys[0].k)
- if "kwargs" in conf:
- conf["kwargs"]["password"] = pw
- else:
- conf["password"] = pw
-
-
-def factory(ec, code=None, token=None, refresh=None, jwks_def=None, **kwargs):
- """
- Create a token handler
-
- :param code:
- :param token:
- :param refresh:
- :param jwks_def:
- :return: TokenHandler instance
- """
-
- TTYPE = {"code": "A", "token": "T", "refresh": "R"}
-
- if jwks_def:
- kj = init_key_jar(**jwks_def)
- else:
- kj = None
-
- args = {}
-
- if code:
- _add_passwd(kj, code, "code")
- args["code_handler"] = init_token_handler(ec, code, TTYPE["code"])
-
- if token:
- _add_passwd(kj, token, "token")
- args["access_token_handler"] = init_token_handler(ec, token, TTYPE["token"])
-
- if refresh:
- _add_passwd(kj, refresh, "refresh")
- args["refresh_token_handler"] = init_token_handler(
- ec, refresh, TTYPE["refresh"]
- )
-
- return TokenHandler(**args)
diff --git a/src/oidcendpoint/token/exception.py b/src/oidcendpoint/token/exception.py
new file mode 100644
index 0000000..965b38c
--- /dev/null
+++ b/src/oidcendpoint/token/exception.py
@@ -0,0 +1,18 @@
+class ExpiredToken(Exception):
+ pass
+
+
+class WrongTokenType(Exception):
+ pass
+
+
+class AccessCodeUsed(Exception):
+ pass
+
+
+class UnknownToken(Exception):
+ pass
+
+
+class NotAllowed(Exception):
+ pass
diff --git a/src/oidcendpoint/token/handler.py b/src/oidcendpoint/token/handler.py
new file mode 100755
index 0000000..ba895b1
--- /dev/null
+++ b/src/oidcendpoint/token/handler.py
@@ -0,0 +1,137 @@
+import logging
+import warnings
+
+from cryptography.fernet import InvalidToken
+from cryptojwt.exception import Invalid
+from cryptojwt.key_jar import init_key_jar
+from cryptojwt.utils import as_unicode
+
+from oidcendpoint.token import DefaultToken
+from oidcendpoint.token import UnknownToken
+from oidcendpoint.token import WrongTokenType
+from oidcendpoint.util import importer
+
+__author__ = "Roland Hedberg"
+
+logger = logging.getLogger(__name__)
+
+
+class TokenHandler(object):
+ def __init__(
+ self, access_token_handler=None, code_handler=None, refresh_token_handler=None
+ ):
+
+ self.handler = {"code": code_handler, "access_token": access_token_handler}
+
+ self.handler_order = ["code", "access_token"]
+
+ if refresh_token_handler:
+ self.handler["refresh_token"] = refresh_token_handler
+ self.handler_order.append("refresh_token")
+
+ def __getitem__(self, typ):
+ return self.handler[typ]
+
+ def __contains__(self, item):
+ return item in self.handler
+
+ def info(self, item, order=None):
+ _handler, item_info = self.get_handler(item, order)
+
+ if _handler is None:
+ logger.info("Unknown token format")
+ raise UnknownToken(item)
+ else:
+ return item_info
+
+ def sid(self, token, order=None):
+ return self.info(token, order)["sid"]
+
+ def type(self, token, order=None):
+ return self.info(token, order)["type"]
+
+ def get_handler(self, token, order=None):
+ if order is None:
+ order = self.handler_order
+
+ for typ in order:
+ try:
+ res = self.handler[typ].info(token)
+ except (KeyError, WrongTokenType, InvalidToken, UnknownToken, Invalid):
+ pass
+ else:
+ return self.handler[typ], res
+
+ return None, None
+
+ def keys(self):
+ return self.handler.keys()
+
+
+def init_token_handler(ec, spec, typ):
+ try:
+ _cls = spec["class"]
+ except KeyError:
+ cls = DefaultToken
+ else:
+ cls = importer(_cls)
+
+ _kwargs = spec.get("kwargs")
+ if _kwargs is None:
+ if cls != DefaultToken:
+ warnings.warn(
+ "Token initialisation arguments should be grouped under 'kwargs'.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+ _kwargs = spec
+
+ return cls(typ=typ, endpoint_context=ec, **_kwargs)
+
+
+def _add_passwd(keyjar, conf, kid):
+ if keyjar:
+ _keys = keyjar.get_encrypt_key(key_type="oct", kid=kid)
+ if _keys:
+ pw = as_unicode(_keys[0].k)
+ if "kwargs" in conf:
+ conf["kwargs"]["password"] = pw
+ else:
+ conf["password"] = pw
+
+
+def factory(ec, code=None, token=None, refresh=None, jwks_def=None, **kwargs):
+ """
+ Create a token handler
+
+ :param code:
+ :param token:
+ :param refresh:
+ :param jwks_def:
+ :return: TokenHandler instance
+ """
+
+ TTYPE = {"code": "A", "token": "T", "refresh": "R"}
+
+ if jwks_def:
+ kj = init_key_jar(**jwks_def)
+ else:
+ kj = None
+
+ args = {}
+
+ if code:
+ _add_passwd(kj, code, "code")
+ args["code_handler"] = init_token_handler(ec, code, TTYPE["code"])
+
+ if token:
+ _add_passwd(kj, token, "token")
+ args["access_token_handler"] = init_token_handler(ec, token, TTYPE["token"])
+
+ if refresh is not None:
+ _add_passwd(kj, refresh, "refresh")
+ args["refresh_token_handler"] = init_token_handler(
+ ec, refresh, TTYPE["refresh"]
+ )
+
+ return TokenHandler(**args)
diff --git a/src/oidcendpoint/token/jwt_token.py b/src/oidcendpoint/token/jwt_token.py
new file mode 100644
index 0000000..aa868ae
--- /dev/null
+++ b/src/oidcendpoint/token/jwt_token.py
@@ -0,0 +1,115 @@
+from typing import Optional
+
+from cryptojwt import JWT
+from cryptojwt.jws.exception import JWSException
+
+from oidcendpoint.exception import ToOld
+
+from . import Token
+from . import is_expired
+from .exception import UnknownToken
+
+TYPE_MAP = {
+ "A": "code",
+ "T": "access_token",
+ "R": "refresh_token"
+}
+
+
+class JWTToken(Token):
+ def __init__(
+ self,
+ typ,
+ keyjar=None,
+ issuer: str = None,
+ aud: Optional[list] = None,
+ alg: str = "ES256",
+ lifetime: int = 300,
+ endpoint_context=None,
+ token_type: str = "Bearer",
+ **kwargs
+ ):
+ Token.__init__(self, typ, **kwargs)
+ self.token_type = token_type
+ self.lifetime = lifetime
+
+ self.kwargs = kwargs
+ self.key_jar = keyjar or endpoint_context.keyjar
+ self.issuer = issuer or endpoint_context.issuer
+ self.cdb = endpoint_context.cdb
+ self.endpoint_context = endpoint_context
+
+ self.def_aud = aud or []
+ self.alg = alg
+
+ def __call__(self,
+ session_id: Optional[str] = '',
+ ttype: Optional[str] = '',
+ **payload) -> str:
+ """
+ Return a token.
+
+ :param session_id: Session id
+ :param subject:
+ :param grant:
+ :param kwargs: KeyWord arguments
+ :return: Signed JSON Web Token
+ """
+ if not ttype and self.type:
+ ttype = self.type
+ else:
+ ttype = "A"
+
+ payload.update({"sid": session_id, "ttype": ttype})
+
+ # payload.update(kwargs)
+ signer = JWT(
+ key_jar=self.key_jar,
+ iss=self.issuer,
+ lifetime=self.lifetime,
+ sign_alg=self.alg,
+ )
+
+ return signer.pack(payload)
+
+ def info(self, token):
+ """
+ Return type of Token (A=Access code, T=Token, R=Refresh token) and
+ the session id.
+
+ :param token: A token
+ :return: tuple of token type and session id
+ """
+ verifier = JWT(key_jar=self.key_jar, allowed_sign_algs=[self.alg])
+ try:
+ _payload = verifier.unpack(token)
+ except JWSException:
+ raise UnknownToken()
+
+ if is_expired(_payload["exp"]):
+ raise ToOld("Token has expired")
+ # All the token metadata
+ _res = {
+ "sid": _payload["sid"],
+ "type": _payload["ttype"],
+ "exp": _payload["exp"],
+ "handler": self,
+ }
+ return _res
+
+ def is_expired(self, token, when=0):
+ """
+ Evaluate whether the token has expired or not
+
+ :param token: The token
+ :param when: The time against which to check the expiration
+ 0 means now.
+ :return: True/False
+ """
+ verifier = JWT(key_jar=self.key_jar, allowed_sign_algs=[self.alg])
+ _payload = verifier.unpack(token)
+ return is_expired(_payload["exp"], when)
+
+ def gather_args(self, sid, sdb, udb):
+ _sinfo = sdb[sid]
+ return {}
diff --git a/src/oidcendpoint/user_authn/user.py b/src/oidcendpoint/user_authn/user.py
index e197ef6..2eeac47 100755
--- a/src/oidcendpoint/user_authn/user.py
+++ b/src/oidcendpoint/user_authn/user.py
@@ -11,6 +11,7 @@
from cryptojwt.jwt import JWT
from oidcendpoint import sanitize
+from oidcendpoint.cookie import cookie_value
from oidcendpoint.exception import FailedAuthentication
from oidcendpoint.exception import ImproperlyConfigured
from oidcendpoint.exception import InvalidCookieSign
@@ -293,7 +294,28 @@ def authenticated_as(self, cookie=None, authorization="", **kwargs):
if self.fail:
raise self.fail()
- return {"uid": self.user}, time.time()
+ res = {"uid": self.user}
+
+ if cookie:
+ try:
+ val = self.cookie_dealer.get_cookie_value(
+ cookie, cookie_name=self.endpoint_context.cookie_name["session"]
+ )
+ except (InvalidCookieSign, AssertionError, AttributeError) as err:
+ logger.warning(err)
+ val = None
+
+ if val is None:
+ return None, 0
+ else:
+ b64val, _ts, typ = val
+ info = cookie_value(b64val)
+ if isinstance(info, dict):
+ res.update(info)
+ else:
+ res["value"] = b64val
+
+ return res, time.time()
def factory(cls, **kwargs):
diff --git a/src/oidcendpoint/userinfo.py b/src/oidcendpoint/userinfo.py
deleted file mode 100755
index 590c2ad..0000000
--- a/src/oidcendpoint/userinfo.py
+++ /dev/null
@@ -1,210 +0,0 @@
-import logging
-
-from oidcmsg.oidc import Claims
-
-from oidcendpoint import sanitize
-from oidcendpoint.exception import ImproperlyConfigured
-from oidcendpoint.scopes import convert_scopes2claims
-
-logger = logging.getLogger(__name__)
-
-
-def id_token_claims(session, provider_info):
- """
- Pick the IdToken claims from the request
-
- :param session: Session information
- :return: The IdToken claims
- """
- itc = update_claims(session, "id_token", provider_info=provider_info, old_claims={})
- return itc
-
-
-def update_claims(session, about, provider_info, old_claims=None):
- """
-
- :param session:
- :param about: userinfo or id_token
- :param old_claims:
- :return: claims or None
- """
-
- if old_claims is None:
- old_claims = {}
-
- req = None
- try:
- req = session["authn_req"]
- except KeyError:
- pass
-
- if req:
- try:
- _claims = req["claims"][about]
- except KeyError:
- pass
- else:
- if _claims:
- # Deal only with supported claims
- _unsup = [
- c
- for c in _claims.keys()
- if c not in provider_info["claims_supported"]
- ]
- for _c in _unsup:
- del _claims[_c]
-
- # update with old claims, do not overwrite
- for key, val in old_claims.items():
- if key not in _claims:
- _claims[key] = val
- return _claims
-
- return old_claims
-
-
-def claims_match(value, claimspec):
- """
- Implements matching according to section 5.5.1 of
- http://openid.net/specs/openid-connect-core-1_0.html
- The lack of value is not checked here.
- Also the text doesn't prohibit having both 'value' and 'values'.
-
- :param value: single value
- :param claimspec: None or dictionary with 'essential', 'value' or 'values'
- as key
- :return: Boolean
- """
- if claimspec is None: # match anything
- return True
-
- matched = False
- for key, val in claimspec.items():
- if key == "value":
- if value == val:
- matched = True
- elif key == "values":
- if value in val:
- matched = True
- elif key == "essential":
- # Whether it's essential or not doesn't change anything here
- continue
-
- if matched:
- break
-
- if matched is False:
- if list(claimspec.keys()) == ["essential"]:
- return True
-
- return matched
-
-
-def by_schema(cls, **kwa):
- """
- Will return only those claims that are listed in the Class definition.
-
- :param cls: A subclass of :py:class:´oidcmsg.message.Message`
- :param kwa: Keyword arguments
- :return: A dictionary with claims (keys) that meets the filter criteria
- """
- return dict([(key, val) for key, val in kwa.items() if key in cls.c_param])
-
-
-def collect_user_info(
- endpoint_context, session, userinfo_claims=None, scope_to_claims=None
-):
- """
- Collect information about a user.
- This can happen in two cases, either when constructing an IdToken or
- when returning user info through the UserInfo endpoint
-
- :param session: Session information
- :param userinfo_claims: user info claims
- :return: User info
- """
- authn_req = session["authn_req"]
- if scope_to_claims is None:
- scope_to_claims = endpoint_context.scope2claims
-
- supported_scopes = endpoint_context.scopes_handler.filter_scopes(
- authn_req["client_id"], endpoint_context, authn_req["scope"]
- )
- if userinfo_claims is None:
- _allowed_claims = endpoint_context.claims_handler.allowed_claims(
- authn_req["client_id"], endpoint_context
- )
- uic = convert_scopes2claims(
- supported_scopes, _allowed_claims, map=scope_to_claims
- )
-
- # Get only keys allowed by user and update the dict if such info
- # is stored in session
- perm_set = session.get("permission")
- if perm_set:
- uic = {key: uic[key] for key in uic if key in perm_set}
-
- uic = update_claims(
- session,
- "userinfo",
- provider_info=endpoint_context.provider_info,
- old_claims=uic,
- )
-
- if uic:
- userinfo_claims = Claims(**uic)
- logger.debug("userinfo_claim: %s" % sanitize(userinfo_claims.to_dict()))
- else:
- userinfo_claims = None
- logger.warning(("Client {} doesn't have any claims "
- "belonging to one or more scopes.").format(authn_req["client_id"]))
- raise ImproperlyConfigured("Some additional scopes doesn't have any claims.")
-
- logger.debug("Session info: %s" % sanitize(session))
-
- authn_event = session["authn_event"]
- if authn_event:
- uid = authn_event["uid"]
- else:
- uid = session["uid"]
-
- info = endpoint_context.userinfo(uid, authn_req["client_id"], userinfo_claims)
-
- if "sub" in userinfo_claims:
- if not claims_match(session["sub"], userinfo_claims["sub"]):
- raise FailedAuthentication("Unmatched sub claim")
-
- info["sub"] = session["sub"]
- try:
- logger.debug("user_info_response: {}".format(info))
- except UnicodeEncodeError:
- logger.debug("user_info_response: {}".format(info.encode("utf-8")))
-
- return info
-
-
-def userinfo_in_id_token_claims(endpoint_context, session, def_itc=None):
- """
- Collect user info claims that are to be placed in the id token.
-
- :param endpoint_context: Endpoint context
- :param session: Session information
- :param def_itc: Default ID Token claims
- :return: User information or None
- """
- if def_itc:
- itc = def_itc
- else:
- itc = {}
-
- itc.update(id_token_claims(session, provider_info=endpoint_context.provider_info))
-
- if not itc:
- return None
-
- _claims = by_schema(endpoint_context.id_token_schema, **itc)
-
- if _claims:
- return collect_user_info(endpoint_context, session, _claims)
- else:
- return None
diff --git a/src/oidcendpoint/util.py b/src/oidcendpoint/util.py
index c4dfa50..a6eafd2 100755
--- a/src/oidcendpoint/util.py
+++ b/src/oidcendpoint/util.py
@@ -2,9 +2,9 @@
import json
import logging
from urllib.parse import parse_qs
+from urllib.parse import urlparse
from urllib.parse import urlsplit
from urllib.parse import urlunsplit
-from urllib.parse import urlparse
from oidcendpoint.exception import OidcEndpointError
@@ -175,7 +175,8 @@ def split_uri(uri):
def allow_refresh_token(endpoint_context):
# Are there a refresh_token handler
- refresh_token_handler = endpoint_context.sdb.handler.handler.get("refresh_token")
+ refresh_token_handler = endpoint_context.session_manager.token_handler.handler[
+ "refresh_token"]
# Is refresh_token grant type supported
_token_supported = False
diff --git a/tests/test_00_endpoint_context.py b/tests/test_00_endpoint_context.py
index 66e5fb7..da5af76 100755
--- a/tests/test_00_endpoint_context.py
+++ b/tests/test_00_endpoint_context.py
@@ -12,7 +12,7 @@
from oidcendpoint.oidc.provider_config import ProviderConfiguration
from oidcendpoint.oidc.registration import Registration
from oidcendpoint.oidc.session import Session
-from oidcendpoint.oidc.token import AccessToken
+from oidcendpoint.oidc.token import Token
from oidcendpoint.oidc.userinfo import UserInfo
from oidcendpoint.user_authn.authn_context import INTERNETPROTOCOLPASSWORD
@@ -49,7 +49,7 @@
"class": Authorization,
"kwargs": {},
},
- "token_endpoint": {"path": "token", "class": AccessToken, "kwargs": {}},
+ "token_endpoint": {"path": "token", "class": Token, "kwargs": {}},
"userinfo_endpoint": {
"path": "userinfo",
"class": UserInfo,
diff --git a/tests/test_01_grant.py b/tests/test_01_grant.py
new file mode 100644
index 0000000..3e621dc
--- /dev/null
+++ b/tests/test_01_grant.py
@@ -0,0 +1,361 @@
+from cryptojwt.key_jar import build_keyjar
+import pytest
+
+from oidcendpoint.endpoint_context import EndpointContext
+from oidcendpoint.session import session_key
+from oidcendpoint.session.grant import Grant
+from oidcendpoint.session.grant import TOKEN_MAP
+from oidcendpoint.session.grant import find_token
+from oidcendpoint.session.grant import get_usage_rules
+from oidcendpoint.session.token import AuthorizationCode
+from oidcendpoint.session.token import Token
+from oidcendpoint.token import DefaultToken
+from oidcendpoint.user_authn.authn_context import INTERNETPROTOCOLPASSWORD
+
+KEYDEFS = [
+ {"type": "RSA", "key": "", "use": ["sig"]},
+ {"type": "EC", "crv": "P-256", "use": ["sig"]},
+]
+
+KEYJAR = build_keyjar(KEYDEFS)
+
+conf = {
+ "issuer": "https://example.com/",
+ "template_dir": "template",
+ "keys": {"uri_path": "static/jwks.json", "key_defs": KEYDEFS, "read_only": True},
+ "endpoint": {
+ "authorization_endpoint": {
+ "path": "authorization",
+ "class": "oidcendpoint.oidc.authorization.Authorization",
+ "kwargs": {},
+ },
+ "token_endpoint": {
+ "path": "token",
+ "class": "oidcendpoint.oidc.token.Token",
+ "kwargs": {}
+ },
+ },
+ "authentication": {
+ "anon": {
+ "acr": INTERNETPROTOCOLPASSWORD,
+ "class": "oidcendpoint.user_authn.user.NoAuthn",
+ "kwargs": {"user": "diana"},
+ }
+ },
+}
+
+ENDPOINT_CONTEXT = EndpointContext(conf)
+
+
+def test_access_code():
+ token = AuthorizationCode('authorization_code', value="ABCD")
+ assert token.issued_at
+ assert token.type == "authorization_code"
+ assert token.value == "ABCD"
+
+ token.register_usage()
+ # max_usage == 1
+ assert token.max_usage_reached() is True
+
+
+def test_access_token():
+ code = AuthorizationCode('authorization_code', value="ABCD")
+ token = Token('access_token', value="1234", based_on=code.id, usage_rules={"max_usage": 2})
+ assert token.issued_at
+ assert token.type == "access_token"
+ assert token.value == "1234"
+
+ token.register_usage()
+ # max_usage - undefined
+ assert token.max_usage_reached() is False
+
+ token.register_usage()
+ assert token.max_usage_reached() is True
+
+ t = find_token([code, token], token.based_on)
+ assert t.value == "ABCD"
+
+ token.revoked = True
+ assert token.revoked is True
+
+
+TOKEN_HANDLER = {
+ "authorization_code": DefaultToken("authorization_code", typ="A"),
+ "access_token": DefaultToken("access_token", typ="T"),
+ "refresh_token": DefaultToken("refresh_token", typ="R")
+}
+
+
+def test_mint_token():
+ grant = Grant()
+
+ code = grant.mint_token("user_id;;client_id;;grant_id",
+ endpoint_context=ENDPOINT_CONTEXT,
+ token_type="authorization_code",
+ token_handler=TOKEN_HANDLER["authorization_code"])
+
+ access_token = grant.mint_token(
+ "user_id;;client_id;;grant_id",
+ endpoint_context=ENDPOINT_CONTEXT,
+ token_type="access_token",
+ token_handler=TOKEN_HANDLER["access_token"],
+ based_on=code,
+ scope=["openid", "foo", "bar"]
+ )
+
+ assert access_token.scope == ["openid", "foo", "bar"]
+
+
+SESSION_ID = session_key('user_id', 'client_id', 'grant.id')
+
+
+def test_grant():
+ grant = Grant()
+ code = grant.mint_token(SESSION_ID, endpoint_context=ENDPOINT_CONTEXT,
+ token_type="authorization_code",
+ token_handler=TOKEN_HANDLER["authorization_code"])
+
+ access_token = grant.mint_token(
+ SESSION_ID,
+ endpoint_context=ENDPOINT_CONTEXT,
+ token_type="access_token",
+ token_handler=TOKEN_HANDLER["access_token"],
+ based_on=code)
+
+ refresh_token = grant.mint_token(
+ SESSION_ID,
+ endpoint_context=ENDPOINT_CONTEXT,
+ token_type="refresh_token",
+ token_handler=TOKEN_HANDLER["refresh_token"],
+ based_on=code)
+
+ grant.revoke_token()
+ assert code.revoked is True
+ assert access_token.revoked is True
+ assert refresh_token.revoked is True
+
+
+def test_get_token():
+ grant = Grant()
+ code = grant.mint_token(SESSION_ID, endpoint_context=ENDPOINT_CONTEXT,
+ token_type="authorization_code",
+ token_handler=TOKEN_HANDLER["authorization_code"])
+
+ access_token = grant.mint_token(
+ SESSION_ID,
+ endpoint_context=ENDPOINT_CONTEXT,
+ token_type="access_token",
+ token_handler=TOKEN_HANDLER["access_token"],
+ based_on=code,
+ scope=["openid", "foo", "bar"]
+ )
+
+ _code = grant.get_token(code.value)
+ assert _code.id == code.id
+
+ _token = grant.get_token(access_token.value)
+ assert _token.id == access_token.id
+ assert set(_token.scope) == {"openid", "foo", "bar"}
+
+
+def test_grant_revoked_based_on():
+ grant = Grant()
+ code = grant.mint_token(SESSION_ID,
+ endpoint_context=ENDPOINT_CONTEXT,
+ token_type="authorization_code",
+ token_handler=TOKEN_HANDLER["authorization_code"])
+
+ access_token = grant.mint_token(
+ SESSION_ID,
+ endpoint_context=ENDPOINT_CONTEXT,
+ token_type="access_token",
+ token_handler=TOKEN_HANDLER["access_token"],
+ based_on=code)
+
+ refresh_token = grant.mint_token(
+ SESSION_ID,
+ endpoint_context=ENDPOINT_CONTEXT,
+ token_type="refresh_token",
+ token_handler=TOKEN_HANDLER["refresh_token"],
+ based_on=code)
+
+ code.register_usage()
+ if code.max_usage_reached():
+ grant.revoke_token(based_on=code.value)
+
+ assert code.is_active() is False
+ assert access_token.is_active() is False
+ assert refresh_token.is_active() is False
+
+
+def test_revoke():
+ grant = Grant()
+ code = grant.mint_token(SESSION_ID,
+ endpoint_context=ENDPOINT_CONTEXT,
+ token_type="authorization_code",
+ token_handler=TOKEN_HANDLER["authorization_code"])
+
+ access_token = grant.mint_token(
+ SESSION_ID,
+ endpoint_context=ENDPOINT_CONTEXT,
+ token_type="access_token",
+ token_handler=TOKEN_HANDLER["access_token"],
+ based_on=code)
+
+ grant.revoke_token(based_on=code.value)
+
+ assert code.is_active() is True
+ assert access_token.is_active() is False
+
+ access_token_2 = grant.mint_token(
+ SESSION_ID,
+ endpoint_context=ENDPOINT_CONTEXT,
+ token_type="access_token",
+ token_handler=TOKEN_HANDLER["access_token"],
+ based_on=code)
+
+ grant.revoke_token(value=code.value, recursive=True)
+
+ assert code.is_active() is False
+ assert access_token_2.is_active() is False
+
+
+def test_json_conversion():
+ grant = Grant()
+ code = grant.mint_token(SESSION_ID,
+ endpoint_context=ENDPOINT_CONTEXT,
+ token_type="authorization_code",
+ token_handler=TOKEN_HANDLER["authorization_code"])
+
+ grant.mint_token(
+ SESSION_ID,
+ endpoint_context=ENDPOINT_CONTEXT,
+ token_type="access_token",
+ token_handler=TOKEN_HANDLER["access_token"],
+ based_on=code)
+
+ _jstr = grant.to_json()
+
+ _grant_copy = Grant().from_json(_jstr)
+
+ assert len(_grant_copy.issued_token) == 2
+
+ tt = {"code": 0, "access_token": 0}
+ for token in _grant_copy.issued_token:
+ if token.type == "authorization_code":
+ tt["code"] += 1
+ if token.type == "access_token":
+ tt["access_token"] += 1
+
+ assert tt == {"code": 1, "access_token": 1}
+
+
+def test_json_no_token_map():
+ grant = Grant(token_map={})
+ with pytest.raises(ValueError):
+ grant.mint_token(SESSION_ID,
+ endpoint_context=ENDPOINT_CONTEXT,
+ token_type="authorization_code",
+ token_handler=TOKEN_HANDLER["authorization_code"])
+
+
+class MyToken(Token):
+ pass
+
+
+TOKEN_HANDLER["my_token"] = DefaultToken("my_token", typ="M")
+
+
+def test_json_custom_token_map():
+ token_map = TOKEN_MAP.copy()
+ token_map["my_token"] = MyToken
+
+ grant = Grant(token_map=token_map)
+ code = grant.mint_token(SESSION_ID,
+ endpoint_context=ENDPOINT_CONTEXT,
+ token_type="authorization_code",
+ token_handler=TOKEN_HANDLER["authorization_code"])
+
+ grant.mint_token(
+ SESSION_ID,
+ endpoint_context=ENDPOINT_CONTEXT,
+ token_type="access_token",
+ token_handler=TOKEN_HANDLER["access_token"],
+ based_on=code)
+
+ grant.mint_token(
+ SESSION_ID,
+ endpoint_context=ENDPOINT_CONTEXT,
+ token_type="my_token",
+ token_handler=TOKEN_HANDLER["my_token"])
+
+ _jstr = grant.to_json()
+
+ _grant_copy = Grant().from_json(_jstr)
+
+ assert len(_grant_copy.issued_token) == 3
+
+ tt = {k: 0 for k, v in grant.token_map.items()}
+
+ for token in _grant_copy.issued_token:
+ for _type in tt.keys():
+ if token.type == _type:
+ tt[_type] += 1
+
+ assert tt == {
+ 'access_token': 1, 'authorization_code': 1,
+ 'my_token': 1, 'refresh_token': 0
+ }
+
+
+def test_get_spec():
+ grant = Grant(scope=["openid", "email", "address"],
+ claims={"userinfo": {"given_name": None, "email": None}},
+ resources=["https://api.example.com"]
+ )
+
+ code = grant.mint_token(SESSION_ID,
+ endpoint_context=ENDPOINT_CONTEXT,
+ token_type="authorization_code",
+ token_handler=TOKEN_HANDLER["authorization_code"])
+
+ access_token = grant.mint_token(
+ SESSION_ID,
+ endpoint_context=ENDPOINT_CONTEXT,
+ token_type="access_token",
+ token_handler=TOKEN_HANDLER["access_token"],
+ based_on=code,
+ scope=["openid", "email", "eduperson"],
+ claims={
+ "userinfo": {
+ "given_name": None,
+ "eduperson_affiliation": None
+ }
+ }
+ )
+
+ spec = grant.get_spec(access_token)
+ assert set(spec.keys()) == {"scope", "claims", "resources"}
+ assert spec["scope"] == ["openid", "email", "eduperson"]
+ assert spec["claims"] == {
+ "userinfo": {
+ "given_name": None,
+ "eduperson_affiliation": None
+ }
+ }
+ assert spec["resources"] == ["https://api.example.com"]
+
+
+def test_get_usage_rules():
+ grant = Grant(scope=["openid", "email", "address"],
+ claims={"userinfo": {"given_name": None, "email": None}},
+ resources=["https://api.example.com"]
+ )
+
+ # Default usage rules
+ ENDPOINT_CONTEXT.cdb["client_id"] = {}
+ rules = get_usage_rules("access_token", ENDPOINT_CONTEXT, grant, "client_id")
+ assert rules == {'supports_minting': [], 'expires_in': 3600}
+
+ # client specific usage rules
+ ENDPOINT_CONTEXT.cdb["client_id"] = {"access_token": {"expires_in": 600}}
\ No newline at end of file
diff --git a/tests/test_01_session_info.py b/tests/test_01_session_info.py
new file mode 100644
index 0000000..d048e87
--- /dev/null
+++ b/tests/test_01_session_info.py
@@ -0,0 +1,64 @@
+import pytest
+from oidcmsg.oauth2 import AuthorizationRequest
+
+from oidcendpoint.session.info import ClientSessionInfo
+from oidcendpoint.session.info import SessionInfo
+from oidcendpoint.session.info import UserSessionInfo
+
+AUTH_REQ = AuthorizationRequest(
+ client_id="client_1",
+ redirect_uri="https://example.com/cb",
+ scope=["openid"],
+ state="STATE",
+ response_type=["code"],
+)
+
+
+def test_session_info_subordinate():
+ si = SessionInfo()
+ si.add_subordinate("subordinate_1")
+ si.add_subordinate("subordinate_2")
+ assert set(si["subordinate"]) == {"subordinate_1", "subordinate_2"}
+ assert set(si["subordinate"]) == {"subordinate_1", "subordinate_2"}
+ assert si.is_revoked() is False
+
+ si.remove_subordinate("subordinate_1")
+ assert si["subordinate"] == ["subordinate_2"]
+
+ si.revoke()
+ assert si.is_revoked() is True
+
+
+def test_session_info_no_subordinate():
+ si = SessionInfo()
+ assert si["subordinate"] == []
+
+
+def test_user_session_info_to_json():
+ usi = UserSessionInfo(uid="uid")
+
+ _jstr = usi.to_json()
+
+ usi2 = UserSessionInfo().from_json(_jstr)
+
+ assert usi2["uid"] == "uid"
+
+
+def test_user_session_info_to_json_with_sub():
+ usi = UserSessionInfo(uid="uid")
+ usi.add_subordinate("client_id")
+
+ _jstr = usi.to_json()
+
+ usi2 = UserSessionInfo().from_json(_jstr)
+
+ assert usi2["subordinate"] == ["client_id"]
+
+
+def test_client_session_info():
+ csi = ClientSessionInfo(client_id="clientID")
+
+ _jstr = csi.to_json()
+
+ _csi2 = ClientSessionInfo().from_json(_jstr)
+ assert _csi2["client_id"] == "clientID"
diff --git a/tests/test_01_session_token.py b/tests/test_01_session_token.py
new file mode 100644
index 0000000..3675d1d
--- /dev/null
+++ b/tests/test_01_session_token.py
@@ -0,0 +1,90 @@
+from oidcmsg.time_util import time_sans_frac
+
+from oidcendpoint.session.token import AccessToken
+from oidcendpoint.session.token import AuthorizationCode
+
+
+def test_authorization_code_default():
+ code = AuthorizationCode(value="ABCD")
+ assert code.usage_rules["max_usage"] == 1
+ assert code.usage_rules["supports_minting"] == ["access_token",
+ "refresh_token"]
+
+
+def test_authorization_code_usage():
+ code = AuthorizationCode(value="ABCD",
+ usage_rules={
+ "supports_minting": ["access_token"],
+ "max_usage": 1
+ })
+
+ assert code.usage_rules["max_usage"] == 1
+ assert code.usage_rules["supports_minting"] == ["access_token"]
+
+
+def test_authorization_code_extras():
+ code = AuthorizationCode(value="ABCD",
+ scope=["openid", "foo", "bar"],
+ claims={"userinfo": {"given_name": None}},
+ resources=["https://api.example.com"])
+
+ assert code.scope == ["openid", "foo", "bar"]
+ assert code.claims == {"userinfo": {"given_name": None}}
+ assert code.resources == ["https://api.example.com"]
+
+
+def test_to_from_json():
+ code = AuthorizationCode(value="ABCD",
+ scope=["openid", "foo", "bar"],
+ claims={"userinfo": {"given_name": None}},
+ resources=["https://api.example.com"])
+
+ _json_str = code.to_json()
+
+ _new_code = AuthorizationCode().from_json(_json_str)
+
+ for attr in AuthorizationCode.attributes:
+ assert getattr(code, attr) == getattr(_new_code, attr)
+
+
+def test_supports_minting():
+ code = AuthorizationCode(value="ABCD")
+ assert code.supports_minting('access_token')
+ assert code.supports_minting('refresh_token')
+ assert code.supports_minting("authorization_code") is False
+
+
+def test_usage():
+ token = AccessToken(usage_rules={"max_usage": 2})
+
+ token.register_usage()
+ assert token.has_been_used()
+ assert token.used == 1
+ assert token.max_usage_reached() is False
+
+ token.register_usage()
+ assert token.max_usage_reached()
+
+ token.register_usage()
+ assert token.used == 3
+ assert token.max_usage_reached()
+
+
+def test_is_active_usage():
+ token = AccessToken(usage_rules={"max_usage": 2})
+
+ token.register_usage()
+ token.register_usage()
+ assert token.is_active() is False
+
+
+def test_is_active_revoke():
+ token = AccessToken(usage_rules={"max_usage": 2})
+ token.revoke()
+ assert token.is_active() is False
+
+
+def test_is_active_expired():
+ token = AccessToken(usage_rules={"max_usage": 2})
+ token.expires_at = time_sans_frac() - 60
+ assert token.is_active() is False
diff --git a/tests/test_01_util.py b/tests/test_01_util.py
index 75ad2a7..286322f 100644
--- a/tests/test_01_util.py
+++ b/tests/test_01_util.py
@@ -1,7 +1,7 @@
from oidcendpoint.oidc.authorization import Authorization
from oidcendpoint.oidc.provider_config import ProviderConfiguration
from oidcendpoint.oidc.registration import Registration
-from oidcendpoint.oidc.token import AccessToken
+from oidcendpoint.oidc.token import Token
from oidcendpoint.oidc.userinfo import UserInfo
from oidcendpoint.user_authn.authn_context import INTERNETPROTOCOLPASSWORD
@@ -40,7 +40,7 @@
"class": Authorization,
"kwargs": {},
},
- "token_endpoint": {"path": "token", "class": AccessToken, "kwargs": {}},
+ "token_endpoint": {"path": "token", "class": Token, "kwargs": {}},
"userinfo_endpoint": {
"path": "userinfo",
"class": UserInfo,
diff --git a/tests/test_02_client_authn.py b/tests/test_02_client_authn.py
index 4a59bd3..d1e05d7 100755
--- a/tests/test_02_client_authn.py
+++ b/tests/test_02_client_authn.py
@@ -24,7 +24,7 @@
from oidcendpoint.exception import NotForMe
from oidcendpoint.oidc.authorization import Authorization
from oidcendpoint.oidc.registration import Registration
-from oidcendpoint.oidc.token import AccessToken
+from oidcendpoint.oidc.token import Token
from oidcendpoint.oidc.userinfo import UserInfo
KEYDEFS = [
@@ -37,14 +37,12 @@
CONF = {
"issuer": "https://example.com/",
"password": "mycket hemligt",
- "token_expires_in": 600,
"grant_expires_in": 300,
- "refresh_token_expires_in": 86400,
"verify_ssl": False,
"endpoint": {
"token": {
"path": "token",
- "class": AccessToken,
+ "class": Token,
"kwargs": {
"client_authn_method": [
"private_key_jwt",
diff --git a/tests/test_02_client_authn.py.old b/tests/test_02_client_authn.py.old
deleted file mode 100755
index 35dd5b0..0000000
--- a/tests/test_02_client_authn.py.old
+++ /dev/null
@@ -1,396 +0,0 @@
-import base64
-
-import pytest
-from cryptojwt.jws.exception import NoSuitableSigningKeys
-from cryptojwt.jwt import JWT
-from cryptojwt.key_jar import KeyJar
-from cryptojwt.key_jar import build_keyjar
-from cryptojwt.utils import as_bytes
-from cryptojwt.utils import as_unicode
-from oidcendpoint import JWT_BEARER
-from oidcendpoint.client_authn import AuthnFailure
-from oidcendpoint.client_authn import BearerBody
-from oidcendpoint.client_authn import BearerHeader
-from oidcendpoint.client_authn import ClientSecretBasic
-from oidcendpoint.client_authn import ClientSecretJWT
-from oidcendpoint.client_authn import ClientSecretPost
-from oidcendpoint.client_authn import JWSAuthnMethod
-from oidcendpoint.client_authn import PrivateKeyJWT
-from oidcendpoint.client_authn import basic_authn
-from oidcendpoint.client_authn import verify_client
-from oidcendpoint.endpoint_context import EndpointContext
-from oidcendpoint.exception import MultipleUsage
-from oidcendpoint.exception import NotForMe
-from oidcendpoint.oidc.authorization import Authorization
-from oidcendpoint.oidc.token import AccessToken
-from oidcendpoint.oidc.userinfo import UserInfo
-
-KEYDEFS = [
- {"type": "RSA", "key": "", "use": ["sig"]},
- {"type": "EC", "crv": "P-256", "use": ["sig"]},
-]
-
-KEYJAR = build_keyjar(KEYDEFS)
-
-conf = {
- "issuer": "https://example.com/",
- "password": "mycket hemligt",
- "token_expires_in": 600,
- "grant_expires_in": 300,
- "refresh_token_expires_in": 86400,
- "verify_ssl": False,
- "endpoint": {
- "token": {"path": "token", "class": AccessToken, "kwargs": {}},
- "authorization": {"path": "auth", "class": Authorization, "kwargs": {}},
- "userinfo": {"path": "user", "class": UserInfo, "kwargs": {}}
- },
- "template_dir": "template",
- "jwks": {
- "private_path": "own/jwks.json",
- "key_defs": KEYDEFS,
- "uri_path": "static/jwks.json",
- },
-}
-client_id = "client_id"
-client_secret = "a_longer_client_secret"
-# Need to add the client_secret as a symmetric key bound to the client_id
-KEYJAR.add_symmetric(client_id, client_secret, ["sig"])
-
-endpoint_context = EndpointContext(conf, keyjar=KEYJAR)
-endpoint_context.cdb[client_id] = {"client_secret": client_secret}
-
-
-def get_client_id_from_token(endpoint_context, token, request=None):
- if "client_id" in request:
- if request["client_id"] == endpoint_context.registration_access_token[token]:
- return request["client_id"]
- return ""
-
-
-def test_client_secret_basic():
- _token = "{}:{}".format(client_id, client_secret)
- token = as_unicode(base64.b64encode(as_bytes(_token)))
-
- authz_token = "Basic {}".format(token)
-
- authn_info = ClientSecretBasic(endpoint_context).verify({}, authz_token)
-
- assert authn_info["client_id"] == client_id
-
-
-def test_client_secret_post():
- request = {"client_id": client_id, "client_secret": client_secret}
-
- authn_info = ClientSecretPost(endpoint_context).verify(request)
-
- assert authn_info["client_id"] == client_id
-
-
-def test_client_secret_jwt():
- client_keyjar = KeyJar()
- client_keyjar[conf["issuer"]] = KEYJAR.issuer_keys[""]
- # The only own key the client has a this point
- client_keyjar.add_symmetric("", client_secret, ["sig"])
-
- _jwt = JWT(client_keyjar, iss=client_id, sign_alg="HS256")
- _jwt.with_jti = True
- _assertion = _jwt.pack({"aud": [conf["issuer"]]})
-
- request = {"client_assertion": _assertion, "client_assertion_type": JWT_BEARER}
-
- authn_info = ClientSecretJWT(endpoint_context).verify(request)
-
- assert authn_info["client_id"] == client_id
- assert "jwt" in authn_info
-
-
-def test_private_key_jwt():
- # Own dynamic keys
- client_keyjar = build_keyjar(KEYDEFS)
- # The servers keys
- client_keyjar[conf["issuer"]] = KEYJAR.issuer_keys[""]
-
- _jwks = client_keyjar.export_jwks()
- endpoint_context.keyjar.import_jwks(_jwks, client_id)
-
- _jwt = JWT(client_keyjar, iss=client_id, sign_alg="RS256")
- _jwt.with_jti = True
- _assertion = _jwt.pack({"aud": [conf["issuer"]]})
-
- request = {"client_assertion": _assertion, "client_assertion_type": JWT_BEARER}
-
- authn_info = PrivateKeyJWT(endpoint_context).verify(request)
-
- assert authn_info["client_id"] == client_id
- assert "jwt" in authn_info
-
-
-def test_private_key_jwt_reusage_other_endpoint():
- # Own dynamic keys
- client_keyjar = build_keyjar(KEYDEFS)
- # The servers keys
- client_keyjar[conf["issuer"]] = KEYJAR.issuer_keys[""]
-
- _jwks = client_keyjar.export_jwks()
- endpoint_context.keyjar.import_jwks(_jwks, client_id)
-
- _jwt = JWT(client_keyjar, iss=client_id, sign_alg="RS256")
- _jwt.with_jti = True
- _assertion = _jwt.pack({"aud": [endpoint_context.endpoint["token"].full_path]})
-
- request = {"client_assertion": _assertion, "client_assertion_type": JWT_BEARER}
-
- # This should be OK
- PrivateKeyJWT(endpoint_context).verify(request, endpoint="token")
-
- # This should NOT be OK
- with pytest.raises(NotForMe):
- PrivateKeyJWT(endpoint_context).verify(request, endpoint="authorization")
-
- # This should NOT be OK
- with pytest.raises(MultipleUsage):
- PrivateKeyJWT(endpoint_context).verify(request, endpoint="token")
-
-
-def test_private_key_jwt_auth_endpoint():
- # Own dynamic keys
- client_keyjar = build_keyjar(KEYDEFS)
- # The servers keys
- client_keyjar[conf["issuer"]] = KEYJAR.issuer_keys[""]
-
- _jwks = client_keyjar.export_jwks()
- endpoint_context.keyjar.import_jwks(_jwks, client_id)
-
- _jwt = JWT(client_keyjar, iss=client_id, sign_alg="RS256")
- _jwt.with_jti = True
- _assertion = _jwt.pack({"aud": [endpoint_context.endpoint["authorization"].full_path]})
-
- request = {"client_assertion": _assertion, "client_assertion_type": JWT_BEARER}
-
- authn_info = PrivateKeyJWT(endpoint_context).verify(request, endpoint="authorization")
-
- assert authn_info["client_id"] == client_id
- assert "jwt" in authn_info
-
-
-def test_wrong_type():
- with pytest.raises(AuthnFailure):
- ClientSecretBasic(endpoint_context).verify({}, "Foppa toffel")
-
-
-def test_csb_wrong_secret():
- _token = "{}:{}".format(client_id, "pillow")
- token = as_unicode(base64.b64encode(as_bytes(_token)))
-
- authz_token = "Basic {}".format(token)
-
- with pytest.raises(AuthnFailure):
- ClientSecretBasic(endpoint_context).verify({}, authz_token)
-
-
-def test_client_secret_post_wrong_secret():
- request = {"client_id": client_id, "client_secret": "pillow"}
-
- with pytest.raises(AuthnFailure):
- ClientSecretPost(endpoint_context).verify(request)
-
-
-def test_bearerheader():
- request = {}
- authorization_info = "Bearer 1234567890"
- assert BearerHeader(endpoint_context).verify(request, authorization_info) == {
- "token": "1234567890"
- }
-
-
-def test_bearerheader_wrong_type():
- request = {}
- authorization_info = "Thrower 1234567890"
- with pytest.raises(AuthnFailure):
- BearerHeader(endpoint_context).verify(request, authorization_info)
-
-
-def test_bearer_body():
- request = {"access_token": "1234567890"}
- assert BearerBody(endpoint_context).verify(request) == {"token": "1234567890"}
-
-
-def test_bearer_body_no_token():
- request = {}
- with pytest.raises(AuthnFailure):
- BearerBody(endpoint_context).verify(request)
-
-
-def test_jws_authn_method_wrong_key():
- client_keyjar = KeyJar()
- client_keyjar[conf["issuer"]] = KEYJAR.issuer_keys[""]
- # Fake symmetric key
- client_keyjar.add_symmetric("", "client_secret:client_secret", ["sig"])
-
- _jwt = JWT(client_keyjar, iss=client_id, sign_alg="HS256")
- _assertion = _jwt.pack({"aud": [conf["issuer"]]})
-
- request = {"client_assertion": _assertion, "client_assertion_type": JWT_BEARER}
-
- with pytest.raises(NoSuitableSigningKeys):
- JWSAuthnMethod(endpoint_context).verify(request)
-
-
-def test_jws_authn_method_aud_iss():
- client_keyjar = KeyJar()
- client_keyjar[conf["issuer"]] = KEYJAR.issuer_keys[""]
- # The only own key the client has a this point
- client_keyjar.add_symmetric("", client_secret, ["sig"])
-
- _jwt = JWT(client_keyjar, iss=client_id, sign_alg="HS256")
- # Audience is OP issuer ID
- aud = conf["issuer"]
- _assertion = _jwt.pack({"aud": [aud]})
-
- request = {"client_assertion": _assertion, "client_assertion_type": JWT_BEARER}
-
- assert JWSAuthnMethod(endpoint_context).verify(request)
-
-
-def test_jws_authn_method_aud_token_endpoint():
- client_keyjar = KeyJar()
- client_keyjar[conf["issuer"]] = KEYJAR.issuer_keys[""]
- # The only own key the client has a this point
- client_keyjar.add_symmetric("", client_secret, ["sig"])
-
- _jwt = JWT(client_keyjar, iss=client_id, sign_alg="HS256")
-
- # audience is OP token endpoint - that's OK
- aud = "{}token".format(conf["issuer"])
- _assertion = _jwt.pack({"aud": [aud]})
-
- request = {"client_assertion": _assertion, "client_assertion_type": JWT_BEARER}
-
- assert JWSAuthnMethod(endpoint_context).verify(request, endpoint="token")
-
-
-def test_jws_authn_method_aud_not_me():
- client_keyjar = KeyJar()
- client_keyjar[conf["issuer"]] = KEYJAR.issuer_keys[""]
- # The only own key the client has a this point
- client_keyjar.add_symmetric("", client_secret, ["sig"])
-
- _jwt = JWT(client_keyjar, iss=client_id, sign_alg="HS256")
-
- # Other audiences not OK
- aud = "https://example.org"
-
- _assertion = _jwt.pack({"aud": [aud]})
-
- request = {"client_assertion": _assertion, "client_assertion_type": JWT_BEARER}
-
- with pytest.raises(NotForMe):
- JWSAuthnMethod(endpoint_context).verify(request)
-
-
-def test_basic_auth():
- _token = "{}:{}".format(client_id, client_secret)
- token = as_unicode(base64.b64encode(as_bytes(_token)))
-
- res = basic_authn("Basic {}".format(token))
- assert res
-
-
-def test_basic_auth_wrong_label():
- _token = "{}:{}".format(client_id, client_secret)
- token = as_unicode(base64.b64encode(as_bytes(_token)))
-
- with pytest.raises(AuthnFailure):
- basic_authn("Expanded {}".format(token))
-
-
-def test_basic_auth_wrong_token():
- _token = "{}:{}:foo".format(client_id, client_secret)
- token = as_unicode(base64.b64encode(as_bytes(_token)))
- with pytest.raises(ValueError):
- basic_authn("Basic {}".format(token))
-
- _token = "{}:{}".format(client_id, client_secret)
- with pytest.raises(ValueError):
- basic_authn("Basic {}".format(_token))
-
- _token = "{}{}".format(client_id, client_secret)
- token = as_unicode(base64.b64encode(as_bytes(_token)))
- with pytest.raises(ValueError):
- basic_authn("Basic {}".format(token))
-
-
-def test_verify_client_jws_authn_method():
- client_keyjar = KeyJar()
- client_keyjar[conf["issuer"]] = KEYJAR.issuer_keys[""]
- # The only own key the client has a this point
- client_keyjar.add_symmetric("", client_secret, ["sig"])
-
- _jwt = JWT(client_keyjar, iss=client_id, sign_alg="HS256")
- # Audience is OP issuer ID
- aud = conf["issuer"]
- _assertion = _jwt.pack({"aud": [aud]})
-
- request = {"client_assertion": _assertion, "client_assertion_type": JWT_BEARER}
-
- res = verify_client(endpoint_context, request)
- assert res["method"] == "private_key_jwt"
- assert res["client_id"] == "client_id"
-
-
-def test_verify_client_bearer_body():
- request = {"access_token": "1234567890", "client_id": client_id}
- endpoint_context.registration_access_token["1234567890"] = client_id
- res = verify_client(
- endpoint_context, request, get_client_id_from_token=get_client_id_from_token
- )
- assert set(res.keys()) == {"token", "method", "client_id"}
- assert res["method"] == "bearer_body"
-
-
-def test_verify_client_client_secret_post():
- request = {"client_id": client_id, "client_secret": client_secret}
- res = verify_client(endpoint_context, request)
- assert set(res.keys()) == {"method", "client_id"}
- assert res["method"] == "client_secret_post"
-
-
-def test_verify_client_client_secret_basic():
- _token = "{}:{}".format(client_id, client_secret)
- token = as_unicode(base64.b64encode(as_bytes(_token)))
- authz_token = "Basic {}".format(token)
- res = verify_client(endpoint_context, {}, authz_token)
- assert set(res.keys()) == {"method", "client_id"}
- assert res["method"] == "client_secret_basic"
-
-
-def test_verify_client_bearer_header():
- endpoint_context.registration_access_token["1234567890"] = client_id
- token = "Bearer 1234567890"
- request = {"client_id": client_id}
- res = verify_client(
- endpoint_context,
- request,
- authorization_info=token,
- get_client_id_from_token=get_client_id_from_token,
- )
-
- res = verify_client(endpoint_context, request, token, get_client_id_from_token)
- assert set(res.keys()) == {"token", "method", "client_id"}
- assert res["method"] == "bearer_header"
-
-
-def test_jws_authn_method_aud_userinfo_endpoint():
- client_keyjar = KeyJar()
- client_keyjar[conf["issuer"]] = KEYJAR.issuer_keys[""]
- # The only own key the client has a this point
- client_keyjar.add_symmetric("", client_secret, ["sig"])
-
- _jwt = JWT(client_keyjar, iss=client_id, sign_alg="HS256")
-
- # audience is the OP - not specifically the user info endpoint
- _assertion = _jwt.pack({"aud": [conf["issuer"]]})
-
- request = {"client_assertion": _assertion, "client_assertion_type": JWT_BEARER}
-
- assert JWSAuthnMethod(endpoint_context).verify(request, endpoint="userinfo")
diff --git a/tests/test_02_sess_mngm_db.py b/tests/test_02_sess_mngm_db.py
new file mode 100644
index 0000000..e68607c
--- /dev/null
+++ b/tests/test_02_sess_mngm_db.py
@@ -0,0 +1,264 @@
+# Database is organized in 3 layers. User-session-grant.
+import pytest
+from oidcmsg.oauth2 import AuthorizationRequest
+
+from oidcendpoint.authn_event import create_authn_event
+from oidcendpoint.session import session_key
+from oidcendpoint.session.database import Database
+from oidcendpoint.session.database import NoSuchClientSession
+from oidcendpoint.session.database import NoSuchGrant
+from oidcendpoint.session.grant import Grant
+from oidcendpoint.session.info import ClientSessionInfo
+from oidcendpoint.session.info import UserSessionInfo
+from oidcendpoint.session.manager import public_id
+from oidcendpoint.session.token import Token
+
+AUTHZ_REQ = AuthorizationRequest(
+ client_id="client_1",
+ redirect_uri="https://example.com/cb",
+ scope=["openid"],
+ state="STATE",
+ response_type="code",
+)
+
+
+class TestDB:
+ @pytest.fixture(autouse=True)
+ def setup_environment(self):
+ self.db = Database()
+
+ def test_user_info(self):
+ with pytest.raises(KeyError):
+ self.db.get(['diana'])
+
+ user_info = UserSessionInfo(user_id="diana", foo="bar")
+ self.db.set(['diana'], user_info)
+ stored_user_info = self.db.get(['diana'])
+ assert stored_user_info["foo"] == "bar"
+
+ def test_client_info(self):
+ user_info = UserSessionInfo(user_id="diana", foo="bar")
+ self.db.set(['diana'], user_info)
+ client_info = ClientSessionInfo(client_id="client_1")
+ self.db.set(['diana', "client_1"], client_info)
+
+ stored_user_info = self.db.get(['diana'])
+ assert stored_user_info['subordinate'] == ['client_1']
+ stored_client_info = self.db.get(['diana', "client_1"])
+ assert stored_client_info['client_id'] == "client_1"
+
+ def test_client_info_change(self):
+ user_info = UserSessionInfo(user_id="diana", foo="bar")
+ self.db.set(['diana'], user_info)
+ client_info = ClientSessionInfo(client_id="client_1", extra="snow")
+ self.db.set(['diana', "client_1"], client_info)
+
+ user_info = self.db.get(['diana'])
+ assert user_info['subordinate'] == ['client_1']
+ client_info = self.db.get(['diana', "client_1"])
+ assert client_info['client_id'] == "client_1"
+ assert client_info['extra'] == "snow"
+
+ client_info = ClientSessionInfo(client_id="client_1", extra="ice")
+ self.db.set(['diana', "client_1"], client_info)
+
+ stored_client_info = self.db.get(['diana', "client_1"])
+ assert stored_client_info['extra'] == "ice"
+
+ def test_client_info_add1(self):
+ user_info = UserSessionInfo(user_id="diana")
+ self.db.set(['diana'], user_info)
+ client_info = ClientSessionInfo(client_id="client_1")
+ self.db.set(['diana', "client_1"], client_info)
+
+ # The reference is there but not the value
+ del self.db._db[session_key('diana', "client_1")]
+
+ client_info = ClientSessionInfo(client_id="client_1", extra="ice")
+ self.db.set(['diana', "client_1"], client_info)
+
+ stored_client_info = self.db.get(['diana', "client_1"])
+ assert stored_client_info['extra'] == "ice"
+
+ def test_client_info_add2(self):
+ user_info = UserSessionInfo(foo="bar")
+ self.db.set(['diana'], user_info)
+ client_info = ClientSessionInfo(sid="abcdef")
+ self.db.set(['diana', "client_1"], client_info)
+
+ # The reference is there but not the value
+ del self.db._db[session_key('diana', "client_1")]
+
+ authn_event = create_authn_event(uid="diana",
+ expires_in=10,
+ authn_info="authn_class_ref")
+
+ grant = Grant(authentication_event=authn_event,
+ authorization_request=AUTHZ_REQ)
+
+ access_code = Token('access_code', value='1234567890')
+ grant.issued_token.append(access_code)
+
+ self.db.set(['diana', "client_1", "G1"], grant)
+ stored_client_info = self.db.get(['diana', "client_1"])
+ assert set(stored_client_info.keys()) == {"subordinate", "revoked", "type"}
+
+ stored_grant_info = self.db.get(['diana', 'client_1', 'G1'])
+ assert stored_grant_info.issued_at
+
+ def test_jump_ahead(self):
+ grant = Grant()
+ access_code = Token('access_code', value='1234567890')
+ grant.issued_token.append(access_code)
+
+ self.db.set(['diana', "client_1", "G1"], grant)
+
+ user_info = self.db.get(['diana'])
+ assert user_info['subordinate'] == ['client_1']
+ client_info = self.db.get(['diana', "client_1"])
+ assert client_info['subordinate'] == ["G1"]
+ grant_info = self.db.get(['diana', 'client_1', 'G1'])
+ assert grant_info.issued_at
+ assert len(grant_info.issued_token) == 1
+ token = grant_info.issued_token[0]
+ assert token.value == '1234567890'
+ assert token.type == "access_code"
+
+ def test_replace_grant_info_not_there(self):
+ grant = Grant()
+ access_code = Token('access_code', value='1234567890')
+ grant.issued_token.append(access_code)
+
+ self.db.set(['diana', "client_1", "G1"], grant)
+
+ # The reference is there but not the value
+ del self.db._db[session_key('diana', "client_1", "G1")]
+
+ grant = Grant()
+ access_code = Token('access_code', value='aaaaaaaaa')
+ grant.issued_token.append(access_code)
+
+ self.db.set(['diana', "client_1", "G1"], grant)
+
+ stored_grant_info = self.db.get(['diana', 'client_1', 'G1'])
+ assert stored_grant_info.issued_at
+ assert len(stored_grant_info.issued_token) == 1
+ token = stored_grant_info.issued_token[0]
+ assert token.value == 'aaaaaaaaa'
+
+ def test_replace_user_info(self):
+ # store user info
+ self.db.set(['diana'],
+ UserSessionInfo(authentication_event=create_authn_event('diana')))
+
+ self.db.set(['diana'],
+ UserSessionInfo(
+ authentication_event=create_authn_event('diana', authn_time=111111111)))
+
+ stored_user_info = self.db.get(['diana'])
+ assert stored_user_info["authentication_event"]["authn_time"] == 111111111
+
+ def test_add_client_info(self):
+ client_info = ClientSessionInfo(sid="abcdef")
+ self.db.set(['diana', "client_1"], client_info)
+
+ stored_client_info = self.db.get(['diana', "client_1"])
+ assert stored_client_info['sid'] == "abcdef"
+
+ def test_half_way(self):
+ # store user info
+ self.db.set(['diana'],
+ UserSessionInfo(authentication_event=create_authn_event('diana')))
+
+ grant = Grant()
+ access_code = Token('access_code', value='1234567890')
+ grant.issued_token.append(access_code)
+
+ self.db.set(['diana', "client_1", "G1"], grant)
+
+ stored_grant_info = self.db.get(['diana', 'client_1', 'G1'])
+ assert stored_grant_info.issued_at
+ assert len(stored_grant_info.issued_token) == 1
+
+ def test_step_wise(self):
+ salt = "natriumklorid"
+ # store user info
+ self.db.set(['diana'],
+ UserSessionInfo(authentication_event=create_authn_event('diana')))
+ # Client specific information
+ self.db.set(['diana', 'client_1'], ClientSessionInfo(sub=public_id(
+ 'diana', salt)))
+ # Grant
+ grant = Grant()
+ access_code = Token('access_code', value='1234567890')
+ grant.issued_token.append(access_code)
+
+ self.db.set(['diana', 'client_1', 'G1'], grant)
+
+ def test_removed(self):
+ grant = Grant()
+ access_code = Token('access_code', value='1234567890')
+ grant.issued_token.append(access_code)
+
+ self.db.set(['diana', "client_1", "G1"], grant)
+ self.db.delete(['diana', 'client_1'])
+ with pytest.raises(KeyError):
+ self.db.get(['diana', "client_1", "G1"])
+
+ def test_client_info_removed(self):
+ user_info = UserSessionInfo(foo="bar")
+ self.db.set(['diana'], user_info)
+ client_info = ClientSessionInfo(sid="abcdef")
+ self.db.set(['diana', "client_1"], client_info)
+
+ # The reference is there but not the value
+ del self.db._db[session_key('diana', "client_1")]
+
+ with pytest.raises(NoSuchClientSession):
+ self.db.get(['diana', "client_1"])
+
+ def test_grant_info(self):
+ user_info = UserSessionInfo(foo="bar")
+ self.db.set(['diana'], user_info)
+ client_info = ClientSessionInfo(sid="abcdef")
+ self.db.set(['diana', "client_1"], client_info)
+
+ with pytest.raises(ValueError):
+ self.db.get(['diana', "client_1", "G1"])
+
+ grant = Grant()
+ access_code = Token('access_code', value='1234567890')
+ grant.issued_token.append(access_code)
+
+ self.db.set(['diana', "client_1", "G1"], grant)
+
+ # removed value
+ del self.db._db[session_key('diana', "client_1", "G1")]
+
+ with pytest.raises(NoSuchGrant):
+ self.db.get(['diana', "client_1", "G1"])
+
+ def test_delete_non_existent_info(self):
+ # Does nothing
+ self.db.delete(["diana"])
+
+ user_info = UserSessionInfo(foo="bar")
+ user_info.add_subordinate('client')
+ self.db.set(['diana'], user_info)
+
+ # again silently does nothing
+ self.db.delete(["diana", "client"])
+
+ client_info = ClientSessionInfo(sid="abcdef")
+ client_info.add_subordinate('G1')
+ self.db.set(['diana', "client_1"], client_info)
+
+ # and finally
+ self.db.delete(["diana", "client_1", "G1"])
+
+ def test_delete_user_info(self):
+ user_info = UserSessionInfo(foo="bar")
+ self.db.set(['diana'], user_info)
+ self.db.delete(["diana"])
+ with pytest.raises(KeyError):
+ self.db.get(['diana'])
diff --git a/tests/test_03_id_token.py b/tests/test_03_id_token.py
index 50555b6..cd8ba95 100644
--- a/tests/test_03_id_token.py
+++ b/tests/test_03_id_token.py
@@ -1,21 +1,22 @@
import json
import os
-import time
-import pytest
from cryptojwt.jws import jws
from cryptojwt.jwt import JWT
from cryptojwt.key_jar import KeyJar
from oidcmsg.oidc import AuthorizationRequest
from oidcmsg.oidc import RegistrationResponse
+from oidcmsg.time_util import time_sans_frac
+import pytest
+from oidcendpoint.authn_event import create_authn_event
from oidcendpoint.client_authn import verify_client
from oidcendpoint.endpoint_context import EndpointContext
from oidcendpoint.id_token import IDToken
from oidcendpoint.id_token import get_sign_and_encrypt_algorithms
from oidcendpoint.oidc import userinfo
from oidcendpoint.oidc.authorization import Authorization
-from oidcendpoint.oidc.token import AccessToken
+from oidcendpoint.oidc.token import Token
from oidcendpoint.user_authn.authn_context import INTERNETPROTOCOLPASSWORD
from oidcendpoint.user_info import UserInfo
@@ -34,15 +35,34 @@ def full_path(local_file):
USERS = json.loads(open(full_path("users.json")).read())
USERINFO = UserInfo(USERS)
-AREQN = AuthorizationRequest(
+AREQ = AuthorizationRequest(
response_type="code",
- client_id="client1",
+ client_id="client_1",
redirect_uri="http://example.com/authz",
scope=["openid"],
state="state000",
nonce="nonce",
)
+AREQS = AuthorizationRequest(
+ response_type="code",
+ client_id="client_1",
+ redirect_uri="http://example.com/authz",
+ scope=["openid", "address", "email"],
+ state="state000",
+ nonce="nonce",
+)
+
+AREQRC = AuthorizationRequest(
+ response_type="code",
+ client_id="client_1",
+ redirect_uri="http://example.com/authz",
+ scope=["openid", "address", "email"],
+ state="state000",
+ nonce="nonce",
+ claims={"id_token": {"nickname": None}}
+)
+
conf = {
"issuer": "https://example.com/",
"password": "mycket hemligt",
@@ -58,7 +78,7 @@ def full_path(local_file):
"class": Authorization,
"kwargs": {},
},
- "token_endpoint": {"path": "{}/token", "class": AccessToken, "kwargs": {}},
+ "token_endpoint": {"path": "{}/token", "class": Token, "kwargs": {}},
"userinfo_endpoint": {
"path": "{}/userinfo",
"class": userinfo.UserInfo,
@@ -72,12 +92,14 @@ def full_path(local_file):
"kwargs": {"user": "diana"},
}
},
- "userinfo": {"class": "oidcendpoint.user_info.UserInfo", "kwargs": {"db": USERS},},
+ "userinfo": {"class": "oidcendpoint.user_info.UserInfo", "kwargs": {"db": USERS}, },
"client_authn": verify_client,
"template_dir": "template",
"id_token": {"class": IDToken, "kwargs": {"foo": "bar"}},
}
+USER_ID = "diana"
+
class TestEndpoint(object):
@pytest.fixture(autouse=True)
@@ -93,186 +115,114 @@ def create_idtoken(self):
self.endpoint_context.keyjar.add_symmetric(
"client_1", "hemligtochintekort", ["sig", "enc"]
)
+ self.session_manager = self.endpoint_context.session_manager
+ self.user_id = USER_ID
- def test_id_token_payload_0(self):
- session_info = {"authn_req": AREQN, "sub": "1234567890"}
- info = self.endpoint_context.idtoken.payload(session_info)
- assert info["payload"] == {"sub": "1234567890", "nonce": "nonce"}
- assert info["lifetime"] == 300
+ def _create_session(self, auth_req, sub_type="public", sector_identifier=''):
+ if sector_identifier:
+ authz_req = auth_req.copy()
+ authz_req["sector_identifier_uri"] = sector_identifier
+ else:
+ authz_req = auth_req
+
+ client_id = authz_req['client_id']
+ ae = create_authn_event(self.user_id)
+ return self.session_manager.create_session(ae, authz_req, self.user_id,
+ client_id=client_id,
+ sub_type=sub_type)
+
+ def _mint_code(self, grant, session_id):
+ # Constructing an authorization code is now done
+ return grant.mint_token(
+ session_id=session_id,
+ endpoint_context=self.endpoint_context,
+ token_type="authorization_code",
+ token_handler=self.session_manager.token_handler["code"],
+ expires_at=time_sans_frac() + 300 # 5 minutes from now
+ )
- def test_id_token_payload_1(self):
- session_info = {"authn_req": AREQN, "sub": "1234567890"}
+ def _mint_access_token(self, grant, session_id, token_ref):
+ return grant.mint_token(
+ session_id=session_id,
+ endpoint_context=self.endpoint_context,
+ token_type="access_token",
+ token_handler=self.session_manager.token_handler["access_token"],
+ expires_at=time_sans_frac() + 900, # 15 minutes from now
+ based_on=token_ref # Means the token (tok) was used to mint this token
+ )
- info = self.endpoint_context.idtoken.payload(session_info)
- assert info["payload"] == {"nonce": "nonce", "sub": "1234567890"}
- assert info["lifetime"] == 300
+ def test_id_token_payload_0(self):
+ session_id = self._create_session(AREQ)
+ payload = self.endpoint_context.idtoken.payload(session_id)
+ assert set(payload.keys()) == {"sub", "nonce", "auth_time"}
def test_id_token_payload_with_code(self):
- session_info = {"authn_req": AREQN, "sub": "1234567890"}
+ session_id = self._create_session(AREQ)
+ grant = self.session_manager[session_id]
- info = self.endpoint_context.idtoken.payload(
- session_info, code="ABCDEFGHIJKLMNOP"
+ code = self._mint_code(grant, session_id)
+ payload = self.endpoint_context.idtoken.payload(
+ session_id, AREQ["client_id"], code=code.value
)
- assert info["payload"] == {
- "nonce": "nonce",
- "c_hash": "5-i4nCch0pDMX1VCVJHs1g",
- "sub": "1234567890",
- }
- assert info["lifetime"] == 300
-
- @pytest.mark.parametrize("add_c_hash", [True, False])
- def test_id_token_payload_with_code_in_session(self, add_c_hash):
- self.endpoint_context.idtoken.add_c_hash = add_c_hash
- code = "ABCDEFGHIJKLMNOP"
- session_info = {"authn_req": AREQN, "sub": "1234567890", "code": code}
-
- info = self.endpoint_context.idtoken.payload(session_info)
- if add_c_hash:
- assert info["payload"] == {
- "nonce": "nonce",
- "sub": "1234567890",
- "c_hash": "5-i4nCch0pDMX1VCVJHs1g",
- }
- else:
- assert info["payload"] == {
- "nonce": "nonce",
- "sub": "1234567890",
- }
- assert info["lifetime"] == 300
+ assert set(payload.keys()) == {"nonce", "c_hash", "sub", "auth_time"}
def test_id_token_payload_with_access_token(self):
- session_info = {"authn_req": AREQN, "sub": "1234567890"}
+ session_id = self._create_session(AREQ)
+ grant = self.session_manager[session_id]
- info = self.endpoint_context.idtoken.payload(
- session_info, access_token="012ABCDEFGHIJKLMNOP"
- )
- assert info["payload"] == {
- "nonce": "nonce",
- "at_hash": "bKkyhbn1CC8IMdavzOV-Qg",
- "sub": "1234567890",
- }
- assert info["lifetime"] == 300
-
- @pytest.mark.parametrize("add_at_hash", [True, False])
- def test_id_token_payload_with_access_token_in_session(self, add_at_hash):
- self.endpoint_context.idtoken.add_at_hash = add_at_hash
- access_token = "012ABCDEFGHIJKLMNOP"
- session_info = {
- "authn_req": AREQN,
- "sub": "1234567890",
- "access_token": access_token
- }
+ code = self._mint_code(grant, session_id)
+ access_token = self._mint_access_token(grant, session_id, code)
- info = self.endpoint_context.idtoken.payload(session_info)
- if add_at_hash:
- assert info["payload"] == {
- "nonce": "nonce",
- "sub": "1234567890",
- "at_hash": "bKkyhbn1CC8IMdavzOV-Qg",
- }
- else:
- assert info["payload"] == {
- "nonce": "nonce",
- "sub": "1234567890",
- }
- assert info["lifetime"] == 300
+ payload = self.endpoint_context.idtoken.payload(
+ session_id, AREQ["client_id"], access_token=access_token.value
+ )
+ assert set(payload.keys()) == {"nonce", "at_hash", "sub", "auth_time"}
def test_id_token_payload_with_code_and_access_token(self):
- session_info = {"authn_req": AREQN, "sub": "1234567890"}
+ session_id = self._create_session(AREQ)
+ grant = self.session_manager[session_id]
- info = self.endpoint_context.idtoken.payload(
- session_info, access_token="012ABCDEFGHIJKLMNOP", code="ABCDEFGHIJKLMNOP"
- )
- assert info["payload"] == {
- "nonce": "nonce",
- "at_hash": "bKkyhbn1CC8IMdavzOV-Qg",
- "c_hash": "5-i4nCch0pDMX1VCVJHs1g",
- "sub": "1234567890",
- }
- assert info["lifetime"] == 300
-
- @pytest.mark.parametrize("add_hashes", [True, False])
- def test_id_token_payload_with_code_and_access_token_in_session(
- self, add_hashes
- ):
- self.endpoint_context.idtoken.add_c_hash = add_hashes
- self.endpoint_context.idtoken.add_at_hash = add_hashes
- code = "ABCDEFGHIJKLMNOP"
- access_token = "012ABCDEFGHIJKLMNOP"
- session_info = {
- "authn_req": AREQN,
- "sub": "1234567890",
- "access_token": access_token,
- "code": code,
- }
+ code = self._mint_code(grant, session_id)
+ access_token = self._mint_access_token(grant, session_id, code)
- info = self.endpoint_context.idtoken.payload(session_info)
- if add_hashes:
- assert info["payload"] == {
- "nonce": "nonce",
- "sub": "1234567890",
- "at_hash": "bKkyhbn1CC8IMdavzOV-Qg",
- "c_hash": "5-i4nCch0pDMX1VCVJHs1g",
- }
- else:
- assert info["payload"] == {
- "nonce": "nonce",
- "sub": "1234567890",
- }
- assert info["lifetime"] == 300
+ payload = self.endpoint_context.idtoken.payload(
+ session_id, AREQ["client_id"], access_token=access_token.value, code=code.value
+ )
+ assert set(payload.keys()) == {"nonce", "c_hash", "at_hash", "sub", "auth_time"}
def test_id_token_payload_with_userinfo(self):
- session_info = {"authn_req": AREQN, "sub": "1234567890"}
+ session_id = self._create_session(AREQ)
+ grant = self.session_manager[session_id]
+ grant.claims = {"id_token": {"given_name": None}}
- info = self.endpoint_context.idtoken.payload(
- session_info, user_info={"given_name": "Diana"}
- )
- assert info["payload"] == {
- "nonce": "nonce",
- "given_name": "Diana",
- "sub": "1234567890",
- }
- assert info["lifetime"] == 300
+ payload = self.endpoint_context.idtoken.payload(session_id=session_id)
+ assert set(payload.keys()) == {"nonce", "given_name", "sub", "auth_time"}
def test_id_token_payload_many_0(self):
- session_info = {"authn_req": AREQN, "sub": "1234567890"}
-
- info = self.endpoint_context.idtoken.payload(
- session_info,
- user_info={"given_name": "Diana"},
- access_token="012ABCDEFGHIJKLMNOP",
- code="ABCDEFGHIJKLMNOP",
+ session_id = self._create_session(AREQ)
+ grant = self.session_manager[session_id]
+ grant.claims = {"id_token": {"given_name": None}}
+ code = self._mint_code(grant, session_id)
+ access_token = self._mint_access_token(grant, session_id, code)
+
+ payload = self.endpoint_context.idtoken.payload(
+ session_id, AREQ["client_id"],
+ access_token=access_token.value,
+ code=code.value
)
- assert info["payload"] == {
- "nonce": "nonce",
- "given_name": "Diana",
- "at_hash": "bKkyhbn1CC8IMdavzOV-Qg",
- "c_hash": "5-i4nCch0pDMX1VCVJHs1g",
- "sub": "1234567890",
- }
- assert info["lifetime"] == 300
+ assert set(payload.keys()) == {"nonce", "c_hash", "at_hash", "sub", "auth_time",
+ "given_name"}
def test_sign_encrypt_id_token(self):
- client_info = RegistrationResponse(
- id_token_signed_response_alg="RS512", client_id="client_1"
- )
- session_info = {
- "authn_req": AREQN,
- "sub": "sub",
- "authn_event": {"authn_info": "loa2", "authn_time": time.time()},
- }
-
- self.endpoint_context.jwx_def["signing_alg"] = {"id_token": "RS384"}
- self.endpoint_context.cdb["client_1"] = client_info.to_dict()
+ session_id = self._create_session(AREQ)
- _token = self.endpoint_context.idtoken.sign_encrypt(
- session_info, "client_1", sign=True
- )
+ _token = self.endpoint_context.idtoken.sign_encrypt(session_id, AREQ['client_id'],
+ sign=True)
assert _token
_jws = jws.factory(_token)
- assert _jws.jwt.headers["alg"] == "RS512"
+ assert _jws.jwt.headers["alg"] == "RS256"
client_keyjar = KeyJar()
_jwks = self.endpoint_context.keyjar.export_jwks()
@@ -284,10 +234,9 @@ def test_sign_encrypt_id_token(self):
assert res["aud"] == ["client_1"]
def test_get_sign_algorithm(self):
- client_info = RegistrationResponse()
- endpoint_context = EndpointContext(conf)
+ client_info = self.endpoint_context.cdb[AREQ['client_id']]
algs = get_sign_and_encrypt_algorithms(
- endpoint_context, client_info, "id_token", sign=True
+ self.endpoint_context, client_info, "id_token", sign=True
)
# default signing alg
assert algs == {"sign": True, "encrypt": False, "sign_alg": "RS256"}
@@ -334,20 +283,11 @@ def test_get_sign_algorithm_4(self):
assert algs == {"sign": True, "encrypt": False, "sign_alg": "RS512"}
def test_available_claims(self):
- session_info = {
- "authn_req": AREQN,
- "sub": "sub",
- "authn_event": {
- "authn_info": "loa2",
- "authn_time": time.time(),
- "uid": "diana",
- },
- }
- self.endpoint_context.idtoken.kwargs["available_claims"] = {
- "nickname": {"essential": True}
- }
- req = {"client_id": "client_1"}
- _token = self.endpoint_context.idtoken.make(req, session_info)
+ session_id = self._create_session(AREQ)
+ grant = self.session_manager[session_id]
+ grant.claims = {"id_token": {"nickname": {"essential": True}}}
+
+ _token = self.endpoint_context.idtoken.make(session_id=session_id)
assert _token
client_keyjar = KeyJar()
_jwks = self.endpoint_context.keyjar.export_jwks()
@@ -357,39 +297,32 @@ def test_available_claims(self):
assert "nickname" in res
def test_no_available_claims(self):
- session_info = {
- "authn_req": AREQN,
- "sub": "sub",
- "authn_event": {
- "authn_info": "loa2",
- "authn_time": time.time(),
- "uid": "diana",
- },
- }
+ session_id = self._create_session(AREQ)
+ grant = self.session_manager[session_id]
+ grant.claims = {"id_token": {"foobar": None}}
+
req = {"client_id": "client_1"}
- _token = self.endpoint_context.idtoken.make(req, session_info)
+ _token = self.endpoint_context.idtoken.make(session_id=session_id)
assert _token
client_keyjar = KeyJar()
_jwks = self.endpoint_context.keyjar.export_jwks()
client_keyjar.import_jwks(_jwks, self.endpoint_context.issuer)
_jwt = JWT(key_jar=client_keyjar, iss="client_1")
res = _jwt.unpack(_token)
- assert "nickname" not in res
+ assert "foobar" not in res
def test_client_claims(self):
- session_info = {
- "authn_req": AREQN,
- "sub": "sub",
- "authn_event": {
- "authn_info": "loa2",
- "authn_time": time.time(),
- "uid": "diana",
- },
- }
- self.endpoint_context.idtoken.enable_claims_per_client = True
+ session_id = self._create_session(AREQ)
+ grant = self.session_manager[session_id]
+
+ self.endpoint_context.idtoken.kwargs["enable_claims_per_client"] = True
self.endpoint_context.cdb["client_1"]["id_token_claims"] = {"address": None}
- req = {"client_id": "client_1"}
- _token = self.endpoint_context.idtoken.make(req, session_info)
+
+ _claims = self.endpoint_context.claims_interface.get_claims(
+ session_id=session_id, scopes=AREQ["scope"], usage="id_token")
+ grant.claims = {'id_token': _claims}
+
+ _token = self.endpoint_context.idtoken.make(session_id=session_id)
assert _token
client_keyjar = KeyJar()
_jwks = self.endpoint_context.keyjar.export_jwks()
@@ -400,22 +333,38 @@ def test_client_claims(self):
assert "nickname" not in res
def test_client_claims_with_default(self):
- session_info = {
- "authn_req": AREQN,
- "sub": "sub",
- "authn_event": {
- "authn_info": "loa2",
- "authn_time": time.time(),
- "uid": "diana",
- },
- }
- self.endpoint_context.cdb["client_1"]["id_token_claims"] = {"address": None}
- self.endpoint_context.idtoken.kwargs["available_claims"] = {
- "nickname": {"essential": True}
- }
- self.endpoint_context.idtoken.enable_claims_per_client = True
- req = {"client_id": "client_1"}
- _token = self.endpoint_context.idtoken.make(req, session_info)
+ session_id = self._create_session(AREQ)
+ grant = self.session_manager[session_id]
+
+ # self.endpoint_context.cdb["client_1"]["id_token_claims"] = {"address": None}
+ # self.endpoint_context.idtoken.enable_claims_per_client = True
+
+ _claims = self.endpoint_context.claims_interface.get_claims(
+ session_id=session_id, scopes=AREQ["scope"], usage="id_token")
+ grant.claims = {"id_token": _claims}
+
+ _token = self.endpoint_context.idtoken.make(session_id=session_id)
+ assert _token
+ client_keyjar = KeyJar()
+ _jwks = self.endpoint_context.keyjar.export_jwks()
+ client_keyjar.import_jwks(_jwks, self.endpoint_context.issuer)
+ _jwt = JWT(key_jar=client_keyjar, iss="client_1")
+ res = _jwt.unpack(_token)
+
+ # No user info claims should be there
+ assert "address" not in res
+ assert "nickname" not in res
+
+ def test_client_claims_scopes(self):
+ session_id = self._create_session(AREQS)
+ grant = self.session_manager[session_id]
+
+ self.endpoint_context.idtoken.kwargs["add_claims_by_scope"] = True
+ _claims = self.endpoint_context.claims_interface.get_claims(
+ session_id=session_id, scopes=AREQS["scope"], usage="id_token")
+ grant.claims = {"id_token": _claims}
+
+ _token = self.endpoint_context.idtoken.make(session_id=session_id)
assert _token
client_keyjar = KeyJar()
_jwks = self.endpoint_context.keyjar.export_jwks()
@@ -423,27 +372,51 @@ def test_client_claims_with_default(self):
_jwt = JWT(key_jar=client_keyjar, iss="client_1")
res = _jwt.unpack(_token)
assert "address" in res
+ assert "email" in res
+ assert "nickname" not in res
+
+ def test_client_claims_scopes_and_request_claims_no_match(self):
+ session_id = self._create_session(AREQRC)
+ grant = self.session_manager[session_id]
+
+ self.endpoint_context.idtoken.kwargs["add_claims_by_scope"] = True
+ _claims = self.endpoint_context.claims_interface.get_claims(
+ session_id=session_id, scopes=AREQRC["scope"], usage="id_token")
+ grant.claims = {"id_token": _claims}
+
+ _token = self.endpoint_context.idtoken.make(session_id=session_id)
+ assert _token
+ client_keyjar = KeyJar()
+ _jwks = self.endpoint_context.keyjar.export_jwks()
+ client_keyjar.import_jwks(_jwks, self.endpoint_context.issuer)
+ _jwt = JWT(key_jar=client_keyjar, iss="client_1")
+ res = _jwt.unpack(_token)
+ # User information, from scopes -> claims
+ assert "address" in res
+ assert "email" in res
+ # User info, requested by claims parameter
assert "nickname" in res
- def test_client_claims_disabled(self):
- # enable_claims_per_client defaults to False
- session_info = {
- "authn_req": AREQN,
- "sub": "sub",
- "authn_event": {
- "authn_info": "loa2",
- "authn_time": time.time(),
- "uid": "diana",
- },
- }
- self.endpoint_context.cdb["client_1"]["id_token_claims"] = {"address": None}
- req = {"client_id": "client_1"}
- _token = self.endpoint_context.idtoken.make(req, session_info)
+ def test_client_claims_scopes_and_request_claims_one_match(self):
+ _req = AREQS.copy()
+ _req["claims"] = {"id_token": {"email": {"value": "diana@example.com"}}}
+
+ session_id = self._create_session(_req)
+ grant = self.session_manager[session_id]
+
+ self.endpoint_context.idtoken.kwargs["add_claims_by_scope"] = True
+ _claims = self.endpoint_context.claims_interface.get_claims(
+ session_id=session_id, scopes=_req["scope"], usage="id_token")
+ grant.claims = {"id_token": _claims}
+
+ _token = self.endpoint_context.idtoken.make(session_id=session_id)
assert _token
client_keyjar = KeyJar()
_jwks = self.endpoint_context.keyjar.export_jwks()
client_keyjar.import_jwks(_jwks, self.endpoint_context.issuer)
_jwt = JWT(key_jar=client_keyjar, iss="client_1")
res = _jwt.unpack(_token)
- assert "address" not in res
- assert "nickname" not in res
+ # Email didn't match
+ assert "email" not in res
+ # Scope -> claims
+ assert "address" in res
diff --git a/tests/test_04_token_handler.py b/tests/test_04_token_handler.py
index c0aab54..3a7de23 100644
--- a/tests/test_04_token_handler.py
+++ b/tests/test_04_token_handler.py
@@ -6,10 +6,12 @@
import pytest
-from oidcendpoint.token_handler import Crypt
-from oidcendpoint.token_handler import DefaultToken
-from oidcendpoint.token_handler import TokenHandler
-from oidcendpoint.token_handler import is_expired
+from oidcendpoint.token import Crypt
+from oidcendpoint.token import is_expired
+from oidcendpoint.token.handler import DefaultToken
+from oidcendpoint.token.handler import TokenHandler
+from oidcendpoint.token.handler import factory
+from oidcendpoint.token.jwt_token import JWTToken
def test_is_expired():
@@ -92,7 +94,7 @@ def test_is_expired(self):
assert self.th.is_expired(_token) is False
when = time.time() + 900
- assert self.th.is_expired(_token, when)
+ assert self.th.is_expired(_token, int(when))
class TestTokenHandler(object):
@@ -153,3 +155,60 @@ def test_get_handler(self):
def test_keys(self):
assert set(self.handler.keys()) == {"access_token", "code", "refresh_token"}
+
+
+class DummyEndpointContext():
+ def __init__(self):
+ self.keyjar = None
+ self.issuer = "issuer"
+ self.cdb = {}
+
+
+def test_token_handler_from_config():
+ conf = {
+ "token_handler_args": {
+ "jwks_def": {
+ "private_path": "private/token_jwks.json",
+ "read_only": False,
+ "key_defs": [
+ {"type": "oct", "bytes": "24", "use": ["enc"], "kid": "code"}
+ ],
+ },
+ "code": {"lifetime": 600},
+ "token": {
+ "class": "oidcendpoint.token.jwt_token.JWTToken",
+ "kwargs": {
+ "lifetime": 3600,
+ "add_claims_by_scope": True,
+ "aud": ["https://example.org/appl"],
+ },
+ },
+ "refresh": {
+ "class": "oidcendpoint.token.jwt_token.JWTToken",
+ "kwargs": {
+ "lifetime": 3600,
+ "aud": ["https://example.org/appl"],
+ }
+ }
+ }
+ }
+
+ token_handler = factory(DummyEndpointContext(), **conf["token_handler_args"])
+ assert token_handler
+ assert len(token_handler.handler) == 3
+ assert set(token_handler.handler.keys()) == {"code", "access_token", "refresh_token"}
+ assert isinstance(token_handler.handler["code"], DefaultToken)
+ assert isinstance(token_handler.handler["access_token"], JWTToken)
+ assert isinstance(token_handler.handler["refresh_token"], JWTToken)
+
+ assert token_handler.handler["code"].lifetime == 600
+
+ assert token_handler.handler["access_token"].alg == "ES256"
+ assert token_handler.handler["access_token"].kwargs == {"add_claims_by_scope": True}
+ assert token_handler.handler["access_token"].lifetime == 3600
+ assert token_handler.handler["access_token"].def_aud == ["https://example.org/appl"]
+
+ assert token_handler.handler["refresh_token"].alg == "ES256"
+ assert token_handler.handler["refresh_token"].kwargs == {}
+ assert token_handler.handler["refresh_token"].lifetime == 3600
+ assert token_handler.handler["refresh_token"].def_aud == ["https://example.org/appl"]
diff --git a/tests/test_05_session_manager.py b/tests/test_05_session_manager.py
new file mode 100644
index 0000000..a327edf
--- /dev/null
+++ b/tests/test_05_session_manager.py
@@ -0,0 +1,504 @@
+from oidcmsg.oidc import AuthorizationRequest
+from oidcmsg.time_util import time_sans_frac
+import pytest
+
+from oidcendpoint.authn_event import AuthnEvent
+from oidcendpoint.authz import AuthzHandling
+from oidcendpoint.endpoint_context import EndpointContext
+from oidcendpoint.oidc.authorization import Authorization
+from oidcendpoint.oidc.token import Token
+from oidcendpoint.session import MintingNotAllowed
+from oidcendpoint.session import session_key
+from oidcendpoint.session.grant import Grant
+from oidcendpoint.session.info import ClientSessionInfo
+from oidcendpoint.session.manager import SessionManager
+from oidcendpoint.session.token import AccessToken
+from oidcendpoint.session.token import AuthorizationCode
+from oidcendpoint.session.token import RefreshToken
+from oidcendpoint.token.handler import factory
+
+AUTH_REQ = AuthorizationRequest(
+ client_id="client_1",
+ redirect_uri="https://example.com/cb",
+ scope=["openid"],
+ state="STATE",
+ response_type="code",
+)
+
+MAP = {
+ "authorization_code": "code",
+ "access_token": "access_token",
+ "refresh_token": "refresh_token"
+}
+
+KEYDEFS = [
+ {"type": "RSA", "key": "", "use": ["sig"]},
+ {"type": "EC", "crv": "P-256", "use": ["sig"]},
+]
+
+DUMMY_SESSION_ID = session_key('user_id', 'client_id', 'grant.id')
+
+
+class TestSessionManager:
+ @pytest.fixture(autouse=True)
+ def create_session_manager(self):
+ conf = {
+ "issuer": "https://example.com/",
+ "password": "mycket hemligt",
+ "token_expires_in": 600,
+ "grant_expires_in": 300,
+ "refresh_token_expires_in": 86400,
+ "verify_ssl": False,
+ "keys": {"key_defs": KEYDEFS, "uri_path": "static/jwks.json"},
+ "jwks_uri": "https://example.com/jwks.json",
+ "token_handler_args": {
+ "jwks_def": {
+ "private_path": "private/token_jwks.json",
+ "read_only": False,
+ "key_defs": [
+ {"type": "oct", "bytes": "24", "use": ["enc"], "kid": "code"}
+ ],
+ },
+ "code": {"lifetime": 600},
+ "token": {
+ "class": "oidcendpoint.token.jwt_token.JWTToken",
+ "kwargs": {
+ "lifetime": 3600,
+ "add_claims": True,
+ "add_claim_by_scope": True,
+ "aud": ["https://example.org/appl"],
+ },
+ },
+ "refresh": {
+ "class": "oidcendpoint.token.jwt_token.JWTToken",
+ "kwargs": {
+ "lifetime": 3600,
+ "aud": ["https://example.org/appl"],
+ }
+ }
+ },
+ "endpoint": {
+ "authorization_endpoint": {
+ "path": "{}/authorization",
+ "class": Authorization,
+ "kwargs": {},
+ },
+ "token_endpoint": {"path": "{}/token", "class": Token, "kwargs": {}},
+ },
+ "template_dir": "template",
+ }
+
+ self.endpoint_context = EndpointContext(conf)
+ token_handler = factory(self.endpoint_context, **conf["token_handler_args"])
+
+ self.session_manager = SessionManager(handler=token_handler)
+ self.authn_event = AuthnEvent(uid="uid",
+ valid_until=time_sans_frac() + 1,
+ authn_info="authn_class_ref")
+
+ @pytest.mark.parametrize("sub_type, sector_identifier", [
+ ("pairwise", "https://all.example.com"),
+ ("public", ""), ("ephemeral", "")])
+ def test_create_session_sub_type(self, sub_type, sector_identifier):
+ # First session
+ authz_req = AUTH_REQ.copy()
+ if sub_type == "pairwise":
+ authz_req["sector_identifier_uri"] = sector_identifier
+
+ session_key_1 = self.session_manager.create_session(authn_event=self.authn_event,
+ auth_req=authz_req,
+ user_id='diana',
+ client_id="client_1",
+ sub_type=sub_type)
+
+ _user_info_1 = self.session_manager.get_user_session_info(session_key_1)
+ assert _user_info_1["subordinate"] == ["client_1"]
+ _client_info_1 = self.session_manager.get_client_session_info(session_key_1)
+ assert len(_client_info_1["subordinate"]) == 1
+ # grant = self.session_manager.get_grant(session_key_1)
+
+ # Second session
+ authn_req = AUTH_REQ.copy()
+ authn_req["client_id"] = "client_2"
+ if sub_type == "pairwise":
+ authn_req["sector_identifier_uri"] = sector_identifier
+
+ session_key_2 = self.session_manager.create_session(authn_event=self.authn_event,
+ auth_req=authn_req,
+ user_id='diana',
+ client_id="client_2",
+ sub_type=sub_type)
+
+ _user_info_2 = self.session_manager.get_user_session_info(session_key_2)
+ assert _user_info_2["subordinate"] == ["client_1", "client_2"]
+
+ grant_1 = self.session_manager.get_grant(session_key_1)
+ grant_2 = self.session_manager.get_grant(session_key_2)
+
+ if sub_type in ["pairwise", "public"]:
+ assert grant_1.sub == grant_2.sub
+ else:
+ assert grant_1.sub != grant_2.sub
+
+ # Third session
+ authn_req = AUTH_REQ.copy()
+ authn_req["client_id"] = "client_3"
+ if sub_type == "pairwise":
+ authn_req["sector_identifier_uri"] = sector_identifier
+
+ session_key_3 = self.session_manager.create_session(authn_event=self.authn_event,
+ auth_req=authn_req,
+ user_id='diana',
+ client_id="client_3",
+ sub_type=sub_type)
+
+ grant_3 = self.session_manager.get_grant(session_key_3)
+
+ if sub_type == "pairwise":
+ assert grant_1.sub == grant_2.sub
+ assert grant_1.sub == grant_3.sub
+ assert grant_3.sub == grant_2.sub
+ elif sub_type == "public":
+ assert grant_1.sub == grant_2.sub
+ assert grant_1.sub == grant_3.sub
+ assert grant_3.sub == grant_2.sub
+ else:
+ assert grant_1.sub != grant_2.sub
+ assert grant_1.sub != grant_3.sub
+ assert grant_3.sub != grant_2.sub
+
+ # Sub types differ so do authentication request
+
+ assert grant_1.authorization_request != grant_2.authorization_request
+ assert grant_1.authorization_request != grant_3.authorization_request
+ assert grant_3.authorization_request != grant_2.authorization_request
+
+ def _mint_token(self, type, grant, session_id, based_on=None):
+ # Constructing an authorization code is now done
+ return grant.mint_token(
+ session_id=session_id,
+ endpoint_context=self.endpoint_context,
+ token_type=type,
+ token_handler=self.session_manager.token_handler.handler[MAP[type]],
+ expires_at=time_sans_frac() + 300, # 5 minutes from now
+ based_on=based_on
+ )
+
+ def test_grant(self):
+ grant = Grant()
+ assert grant.issued_token == []
+ assert grant.is_active() is True
+
+ code = self._mint_token('authorization_code', grant, DUMMY_SESSION_ID)
+ assert isinstance(code, AuthorizationCode)
+ assert code.is_active()
+ assert len(grant.issued_token) == 1
+ assert code.max_usage_reached() is False
+
+ def test_code_usage(self):
+ grant = Grant()
+ assert grant.issued_token == []
+ assert grant.is_active() is True
+
+ code = self._mint_token('authorization_code', grant, DUMMY_SESSION_ID)
+ assert isinstance(code, AuthorizationCode)
+ assert code.is_active()
+ assert len(grant.issued_token) == 1
+
+ assert code.usage_rules["supports_minting"] == ['access_token', 'refresh_token']
+ access_token = self._mint_token('access_token', grant, DUMMY_SESSION_ID, code)
+ assert isinstance(access_token, AccessToken)
+ assert access_token.is_active()
+ assert len(grant.issued_token) == 2
+
+ refresh_token = self._mint_token('refresh_token', grant, DUMMY_SESSION_ID, code)
+ assert isinstance(refresh_token, RefreshToken)
+ assert refresh_token.is_active()
+ assert len(grant.issued_token) == 3
+
+ code.register_usage()
+ assert code.max_usage_reached() is True
+
+ with pytest.raises(MintingNotAllowed):
+ self._mint_token('access_token', grant, DUMMY_SESSION_ID, code)
+
+ grant.revoke_token(based_on=code.value)
+
+ assert access_token.revoked == True
+ assert refresh_token.revoked == True
+
+ def test_add_grant(self):
+ self.session_manager.create_session(authn_event=self.authn_event,
+ auth_req=AUTH_REQ,
+ user_id='diana',
+ client_id="client_1")
+
+ grant = self.session_manager.add_grant(
+ user_id="diana", client_id="client_1",
+ scope=["openid", "phoe"],
+ claims={"userinfo": {"given_name": None}})
+
+ assert grant.scope == ["openid", "phoe"]
+
+ _grant = self.session_manager.get(['diana', 'client_1', grant.id])
+
+ assert _grant.scope == ["openid", "phoe"]
+
+ def test_find_token(self):
+ self.session_manager.create_session(authn_event=self.authn_event,
+ auth_req=AUTH_REQ,
+ user_id='diana',
+ client_id="client_1")
+
+ grant = self.session_manager.add_grant(user_id="diana",
+ client_id="client_1")
+
+ code = self._mint_token('authorization_code', grant, DUMMY_SESSION_ID)
+ access_token = self._mint_token('access_token', grant, DUMMY_SESSION_ID, code)
+
+ _session_key = session_key('diana', 'client_1', grant.id)
+ _token = self.session_manager.find_token(_session_key, access_token.value)
+
+ assert _token.type == "access_token"
+ assert _token.id == access_token.id
+
+ def test_get_authentication_event(self):
+ session_id = self.session_manager.create_session(authn_event=self.authn_event,
+ auth_req=AUTH_REQ,
+ user_id='diana',
+ client_id="client_1")
+
+ _info = self.session_manager.get_session_info(session_id, authentication_event=True)
+ authn_event = _info["authentication_event"]
+
+ assert isinstance(authn_event, AuthnEvent)
+ assert authn_event["uid"] == "uid"
+ assert authn_event["authn_info"] == "authn_class_ref"
+
+ def test_get_client_session_info(self):
+ self.session_manager.create_session(authn_event=self.authn_event,
+ auth_req=AUTH_REQ,
+ user_id='diana',
+ client_id="client_1")
+
+ grant = self.session_manager.add_grant(user_id="diana",
+ client_id="client_1")
+
+ _session_id = session_key('diana', 'client_1', grant.id)
+ csi = self.session_manager.get_client_session_info(_session_id)
+
+ assert isinstance(csi, ClientSessionInfo)
+
+ def test_get_general_session_info(self):
+ _session_id = self.session_manager.create_session(authn_event=self.authn_event,
+ auth_req=AUTH_REQ,
+ user_id='diana',
+ client_id="client_1")
+
+ _session_info = self.session_manager.get_session_info(_session_id)
+
+ assert set(_session_info.keys()) == {'client_id',
+ 'grant_id',
+ 'session_id',
+ 'user_id'}
+ assert _session_info["user_id"] == "diana"
+ assert _session_info["client_id"] == "client_1"
+
+ def test_get_session_info_by_token(self):
+ _session_id = self.session_manager.create_session(authn_event=self.authn_event,
+ auth_req=AUTH_REQ,
+ user_id='diana',
+ client_id="client_1")
+
+ grant = self.session_manager.get_grant(_session_id)
+ code = self._mint_token('authorization_code', grant, _session_id)
+ _session_info = self.session_manager.get_session_info_by_token(code.value)
+
+ assert set(_session_info.keys()) == {'client_id',
+ 'session_id',
+ 'grant_id',
+ 'user_id'}
+ assert _session_info["user_id"] == "diana"
+ assert _session_info["client_id"] == "client_1"
+
+ def test_token_usage_default(self):
+ _session_id = self.session_manager.create_session(authn_event=self.authn_event,
+ auth_req=AUTH_REQ,
+ user_id='diana',
+ client_id="client_1")
+ grant = self.session_manager[_session_id]
+
+ code = self._mint_token('authorization_code', grant, _session_id)
+
+ assert code.usage_rules == {
+ 'max_usage': 1,
+ 'supports_minting': ['access_token', 'refresh_token']
+ }
+
+ token = self._mint_token("access_token", grant, _session_id, code)
+
+ assert token.usage_rules == {}
+
+ refresh_token = self._mint_token("refresh_token", grant, _session_id, code)
+
+ assert refresh_token.usage_rules == {
+ 'supports_minting': ['access_token', 'refresh_token']
+ }
+
+ def test_token_usage_grant(self):
+ _session_id = self.session_manager.create_session(authn_event=self.authn_event,
+ auth_req=AUTH_REQ,
+ user_id='diana',
+ client_id="client_1")
+ grant = self.session_manager[_session_id]
+ grant.usage_rules = {
+ "authorization_code": {
+ "max_usage": 1,
+ "supports_minting": ["access_token", "refresh_token", "id_token"],
+ "expires_in": 300
+ },
+ "access_token": {
+ "expires_in": 3600
+ },
+ "refresh_token": {
+ "supports_minting": ["access_token", "refresh_token", "id_token"]
+ }
+ }
+
+ code = self._mint_token('authorization_code', grant, _session_id)
+ assert code.usage_rules == {
+ 'max_usage': 1,
+ 'supports_minting': ['access_token', 'refresh_token', 'id_token'],
+ "expires_in": 300
+ }
+
+ token = self._mint_token("access_token", grant, _session_id, code)
+ assert token.usage_rules == {
+ "expires_in": 3600
+ }
+
+ refresh_token = self._mint_token("refresh_token", grant, _session_id, code)
+ assert refresh_token.usage_rules == {
+ 'supports_minting': ['access_token', 'refresh_token', "id_token"]
+ }
+
+ def test_token_usage_authz(self):
+ grant_config = {
+ "usage_rules": {
+ "authorization_code": {
+ 'supports_minting': ["access_token"],
+ "max_usage": 1,
+ "expires_in": 120
+ },
+ "access_token": {
+ "expires_in": 600
+ },
+ "refresh_token": {}
+ },
+ "expires_in": 43200
+ }
+
+ self.endpoint_context.authz = AuthzHandling(self.endpoint_context,
+ grant_config=grant_config)
+
+ self.endpoint_context.cdb["client_1"] = {}
+
+ token_usage_rules = self.endpoint_context.authz.usage_rules('client_1')
+
+ _session_id = self.session_manager.create_session(authn_event=self.authn_event,
+ auth_req=AUTH_REQ,
+ user_id='diana',
+ client_id="client_1",
+ token_usage_rules=token_usage_rules)
+ grant = self.session_manager[_session_id]
+
+ code = self._mint_token('authorization_code', grant, _session_id)
+ assert code.usage_rules == {
+ 'max_usage': 1,
+ 'supports_minting': ['access_token'],
+ "expires_in": 120
+ }
+
+ token = self._mint_token("access_token", grant, _session_id, code)
+ assert token.usage_rules == {
+ "expires_in": 600
+ }
+
+ with pytest.raises(MintingNotAllowed):
+ self._mint_token("refresh_token", grant, _session_id, code)
+
+ def test_token_usage_client_config(self):
+ grant_config = {
+ "usage_rules": {
+ "authorization_code": {
+ 'supports_minting': ["access_token"],
+ "max_usage": 1,
+ "expires_in": 120
+ },
+ "access_token": {
+ "expires_in": 600
+ },
+ "refresh_token": {}
+ },
+ "expires_in": 43200
+ }
+
+ self.endpoint_context.authz = AuthzHandling(self.endpoint_context,
+ grant_config=grant_config)
+
+ # Change expiration time for the code and allow refresh tokens for this
+ # specific client
+ self.endpoint_context.cdb["client_1"] = {
+ "token_usage_rules": {
+ "authorization_code": {
+ "expires_in": 600,
+ 'supports_minting': ['access_token', "refresh_token"]
+ },
+ "refresh_token": {
+ 'supports_minting': ['access_token']
+ }
+ }
+ }
+
+ token_usage_rules = self.endpoint_context.authz.usage_rules('client_1')
+
+ _session_id = self.session_manager.create_session(authn_event=self.authn_event,
+ auth_req=AUTH_REQ,
+ user_id='diana',
+ client_id="client_1",
+ token_usage_rules=token_usage_rules)
+ grant = self.session_manager[_session_id]
+
+ code = self._mint_token('authorization_code', grant, _session_id)
+ assert code.usage_rules == {
+ 'max_usage': 1,
+ 'supports_minting': ['access_token', 'refresh_token'],
+ "expires_in": 600
+ }
+
+ token = self._mint_token("access_token", grant, _session_id, code)
+ assert token.usage_rules == {
+ "expires_in": 600
+ }
+
+ refresh_token = self._mint_token("refresh_token", grant, _session_id, code)
+ assert refresh_token.usage_rules == {
+ 'supports_minting': ['access_token']
+ }
+
+ # Test with another client
+
+ self.endpoint_context.cdb["client_2"] = {}
+
+ token_usage_rules = self.endpoint_context.authz.usage_rules('client_2')
+
+ _session_id = self.session_manager.create_session(authn_event=self.authn_event,
+ auth_req=AUTH_REQ,
+ user_id='diana',
+ client_id="client_2",
+ token_usage_rules=token_usage_rules)
+ grant = self.session_manager[_session_id]
+ code = self._mint_token('authorization_code', grant, _session_id)
+ # Not allowed to mint refresh token for this client
+ with pytest.raises(MintingNotAllowed):
+ self._mint_token("refresh_token", grant, _session_id, code)
diff --git a/tests/test_05_sso_db.py b/tests/test_05_sso_db.py
deleted file mode 100644
index 5c8479f..0000000
--- a/tests/test_05_sso_db.py
+++ /dev/null
@@ -1,135 +0,0 @@
-import shutil
-
-import pytest
-
-from oidcendpoint.sso_db import SSODb
-
-DB_CONF = {
- "handler": "oidcmsg.storage.abfile.AbstractFileSystem",
- "fdir": "db/sso",
- "key_conv": "oidcmsg.storage.converter.QPKey",
- "value_conv": "oidcmsg.storage.converter.JSON",
-}
-
-
-def rmtree(item):
- try:
- shutil.rmtree(item)
- except FileNotFoundError:
- pass
-
-
-class TestSSODB(object):
- @pytest.fixture(autouse=True)
- def create_sdb(self):
- rmtree("db/sso")
- self.sso_db = SSODb(DB_CONF)
-
- def test_map_sid2uid(self):
- self.sso_db.map_sid2uid("session id 1", "Lizz")
- assert self.sso_db.get_sids_by_uid("Lizz") == ["session id 1"]
-
- def test_missing_map(self):
- assert self.sso_db.get_sids_by_uid("Lizz") == []
-
- def test_multiple_map_sid2uid(self):
- self.sso_db.map_sid2uid("session id 1", "Lizz")
- self.sso_db.map_sid2uid("session id 2", "Lizz")
- assert set(self.sso_db.get_sids_by_uid("Lizz")) == {
- "session id 1",
- "session id 2",
- }
-
- def test_map_unmap_sid2uid(self):
- self.sso_db.map_sid2uid("session id 1", "Lizz")
- self.sso_db.map_sid2uid("session id 2", "Lizz")
- assert set(self.sso_db.get_sids_by_uid("Lizz")) == {
- "session id 1",
- "session id 2",
- }
-
- self.sso_db.remove_sid2uid("session id 1", "Lizz")
- assert self.sso_db.get_sids_by_uid("Lizz") == ["session id 2"]
-
- def test_get_uid_by_sid(self):
- self.sso_db.map_sid2uid("session id 1", "Lizz")
- self.sso_db.map_sid2uid("session id 2", "Lizz")
-
- assert self.sso_db.get_uid_by_sid("session id 1") == "Lizz"
- assert self.sso_db.get_uid_by_sid("session id 2") == "Lizz"
-
- def test_remove_uid(self):
- self.sso_db.map_sid2uid("session id 1", "Lizz")
- self.sso_db.map_sid2uid("session id 2", "Diana")
-
- self.sso_db.remove_uid("Lizz")
- assert self.sso_db.get_uid_by_sid("session id 1") is None
- assert self.sso_db.get_sids_by_uid("Lizz") == []
-
- def test_map_sid2sub(self):
- self.sso_db.map_sid2sub("session id 1", "abcdefgh")
- assert self.sso_db.get_sids_by_sub("abcdefgh") == ["session id 1"]
-
- def test_missing_sid2sub_map(self):
- assert self.sso_db.get_sids_by_sub("abcdefgh") == []
-
- def test_multiple_map_sid2sub(self):
- self.sso_db.map_sid2sub("session id 1", "abcdefgh")
- self.sso_db.map_sid2sub("session id 2", "abcdefgh")
- assert set(self.sso_db.get_sids_by_sub("abcdefgh")) == {
- "session id 1",
- "session id 2",
- }
-
- def test_map_unmap_sid2sub(self):
- self.sso_db.map_sid2sub("session id 1", "abcdefgh")
- self.sso_db.map_sid2sub("session id 2", "abcdefgh")
- assert set(self.sso_db.get_sids_by_sub("abcdefgh")) == {
- "session id 1",
- "session id 2",
- }
-
- self.sso_db.remove_sid2sub("session id 1", "abcdefgh")
- assert self.sso_db.get_sids_by_sub("abcdefgh") == ["session id 2"]
-
- def test_get_sub_by_sid(self):
- self.sso_db.map_sid2sub("session id 1", "abcdefgh")
- self.sso_db.map_sid2sub("session id 2", "abcdefgh")
-
- assert set(self.sso_db.get_sids_by_sub("abcdefgh")) == {
- "session id 1",
- "session id 2",
- }
-
- def test_remove_sub(self):
- self.sso_db.map_sid2sub("session id 1", "abcdefgh")
- self.sso_db.map_sid2sub("session id 2", "012346789")
-
- self.sso_db.remove_sub("abcdefgh")
- assert self.sso_db.get_sub_by_sid("session id 1") is None
- assert self.sso_db.get_sids_by_sub("abcdefgh") == []
- # have not touched the others
- assert self.sso_db.get_sub_by_sid("session id 2") == "012346789"
- assert self.sso_db.get_sids_by_sub("012346789") == ["session id 2"]
-
- def test_get_sub_by_uid_same_sub(self):
- self.sso_db.map_sid2sub("session id 1", "abcdefgh")
- self.sso_db.map_sid2sub("session id 2", "abcdefgh")
-
- self.sso_db.map_sid2uid("session id 1", "Lizz")
- self.sso_db.map_sid2uid("session id 2", "Lizz")
-
- res = self.sso_db.get_subs_by_uid("Lizz")
-
- assert set(res) == {"abcdefgh"}
-
- def test_get_sub_by_uid_different_sub(self):
- self.sso_db.map_sid2sub("session id 1", "abcdefgh")
- self.sso_db.map_sid2sub("session id 2", "012346789")
-
- self.sso_db.map_sid2uid("session id 1", "Lizz")
- self.sso_db.map_sid2uid("session id 2", "Lizz")
-
- res = self.sso_db.get_subs_by_uid("Lizz")
-
- assert set(res) == {"abcdefgh", "012346789"}
diff --git a/tests/test_06_authn_context.py b/tests/test_06_authn_context.py
index 284bc23..d3f41d6 100644
--- a/tests/test_06_authn_context.py
+++ b/tests/test_06_authn_context.py
@@ -11,7 +11,7 @@
from oidcendpoint.id_token import IDToken
from oidcendpoint.oidc.authorization import Authorization
from oidcendpoint.oidc.provider_config import ProviderConfiguration
-from oidcendpoint.oidc.token import AccessToken
+from oidcendpoint.oidc.token import Token
from oidcendpoint.user_authn.authn_context import INTERNETPROTOCOLPASSWORD
from oidcendpoint.user_authn.authn_context import TIMESYNCTOKEN
from oidcendpoint.user_authn.authn_context import init_method
@@ -58,7 +58,7 @@
"private_key_jwt",
],
"response_modes_supported": ["query", "fragment", "form_post"],
- "subject_types_supported": ["public", "pairwise"],
+ "subject_types_supported": ["public", "pairwise", "ephemeral"],
"grant_types_supported": [
"authorization_code",
"implicit",
@@ -152,7 +152,7 @@ def create_authn_broker(self):
"class": Authorization,
"kwargs": {},
},
- "token": {"path": "{}/token", "class": AccessToken, "kwargs": {}},
+ "token": {"path": "{}/token", "class": Token, "kwargs": {}},
},
"authentication": METHOD,
"userinfo": {"class": UserInfo, "kwargs": {"db": USERINFO_db}},
@@ -202,15 +202,14 @@ def test_pick_authn_all(self):
def test_authn_event():
an = AuthnEvent(
uid="uid",
- salt="_salt_",
valid_until=time_sans_frac() + 1,
authn_info="authn_class_ref",
)
- assert an.valid()
+ assert an.is_valid()
n = time_sans_frac() + 3
- assert an.valid(n) is False
+ assert an.is_valid(n) is False
n = an.expires_in()
assert n == 1 # could possibly be 0
diff --git a/tests/test_07_userinfo.py b/tests/test_07_userinfo.py
index d3a0573..b5bc391 100644
--- a/tests/test_07_userinfo.py
+++ b/tests/test_07_userinfo.py
@@ -2,24 +2,24 @@
import os
import pytest
-from oidcmsg.message import Message
from oidcmsg.oidc import OpenIDRequest
-from oidcmsg.oidc import OpenIDSchema
from oidcendpoint.authn_event import create_authn_event
from oidcendpoint.endpoint_context import EndpointContext
+from oidcendpoint.id_token import IDToken
+from oidcendpoint.oidc import userinfo
from oidcendpoint.oidc.authorization import Authorization
from oidcendpoint.oidc.provider_config import ProviderConfiguration
from oidcendpoint.oidc.registration import Registration
from oidcendpoint.scopes import SCOPE2CLAIMS
-from oidcendpoint.scopes import STANDARD_CLAIMS
from oidcendpoint.scopes import convert_scopes2claims
+from oidcendpoint.session import session_key
+from oidcendpoint.session import unpack_session_key
+from oidcendpoint.session.claims import STANDARD_CLAIMS
+from oidcendpoint.session.claims import ClaimsInterface
+from oidcendpoint.session.grant import Grant
from oidcendpoint.user_authn.authn_context import INTERNETPROTOCOLPASSWORD
from oidcendpoint.user_info import UserInfo
-from oidcendpoint.userinfo import by_schema
-from oidcendpoint.userinfo import claims_match
-from oidcendpoint.userinfo import collect_user_info
-from oidcendpoint.userinfo import update_claims
CLAIMS = {
"userinfo": {
@@ -136,27 +136,27 @@ def test_custom_scopes():
assert set(
convert_scopes2claims(["email"], _available_claims, map=_scopes).keys()
- ) == {"email", "email_verified",}
+ ) == {"email", "email_verified", }
assert set(
convert_scopes2claims(["address"], _available_claims, map=_scopes).keys()
) == {"address"}
assert set(
convert_scopes2claims(["phone"], _available_claims, map=_scopes).keys()
- ) == {"phone_number", "phone_number_verified",}
+ ) == {"phone_number", "phone_number_verified", }
assert set(
convert_scopes2claims(
["research_and_scholarship"], _available_claims, map=_scopes
).keys()
) == {
- "name",
- "given_name",
- "family_name",
- "email",
- "email_verified",
- "sub",
- "eduperson_scoped_affiliation",
- }
+ "name",
+ "given_name",
+ "family_name",
+ "email",
+ "email_verified",
+ "sub",
+ "eduperson_scoped_affiliation",
+ }
PROVIDER_INFO = {
@@ -172,71 +172,6 @@ def test_custom_scopes():
]
}
-
-def test_update_claims_authn_req_id_token():
- _session_info = {"authn_req": OIDR}
- claims = update_claims(_session_info, "id_token", PROVIDER_INFO)
- assert set(claims.keys()) == {"auth_time", "acr"}
-
-
-def test_update_claims_authn_req_userinfo():
- _session_info = {"authn_req": OIDR}
- claims = update_claims(_session_info, "userinfo", PROVIDER_INFO)
- assert set(claims.keys()) == {
- "given_name",
- "nickname",
- "email",
- "email_verified",
- "picture",
- "http://example.info/claims/groups",
- }
-
-
-def test_update_claims_authzreq_id_token():
- _session_info = {"authn_req": OIDR}
- claims = update_claims(_session_info, "id_token", PROVIDER_INFO)
- assert set(claims.keys()) == {"auth_time", "acr"}
-
-
-def test_update_claims_authzreq_userinfo():
- _session_info = {"authn_req": OIDR}
- claims = update_claims(_session_info, "userinfo", PROVIDER_INFO)
- assert set(claims.keys()) == {
- "given_name",
- "nickname",
- "email",
- "email_verified",
- "picture",
- "http://example.info/claims/groups",
- }
-
-
-def test_clams_value():
- assert claims_match("red", CLAIMS["userinfo"]["http://example.info/claims/groups"])
-
-
-def test_clams_values():
- assert claims_match("urn:mace:incommon:iap:silver", CLAIMS["id_token"]["acr"])
-
-
-def test_clams_essential():
- assert claims_match(["foobar@example"], CLAIMS["userinfo"]["email"])
-
-
-def test_clams_none():
- assert claims_match(["angle"], CLAIMS["userinfo"]["nickname"])
-
-
-def test_by_schema():
- # There are no requested or optional claims defined for Message
- assert by_schema(Message, sub="John") == {}
-
- assert by_schema(OpenIDSchema, sub="John", given_name="John", age=34) == {
- "sub": "John",
- "given_name": "John",
- }
-
-
KEYDEFS = [
{"type": "RSA", "key": "", "use": ["sig"]},
{"type": "EC", "crv": "P-256", "use": ["sig"]},
@@ -282,6 +217,24 @@ def create_endpoint_context(self):
"request_uri_parameter_supported": True,
},
},
+ "userinfo": {
+ "path": "userinfo",
+ "class": userinfo.UserInfo,
+ "kwargs": {
+ "claim_types_supported": [
+ "normal",
+ "aggregated",
+ "distributed",
+ ],
+ "client_authn_method": ["bearer_header"],
+ "base_claims": {
+ "eduperson_scoped_affiliation": None,
+ "email": None,
+ },
+ "add_claims_by_scope": True,
+ "enable_claims_per_client": True
+ },
+ },
},
"keys": {
"public_path": "jwks.json",
@@ -296,78 +249,141 @@ def create_endpoint_context(self):
}
},
"template_dir": "template",
+ "id_token": {
+ "class": IDToken,
+ "kwargs": {
+ "base_claims": {
+ "email": None,
+ "email_verified": None,
+ },
+ "enable_claims_per_client": True
+ },
+ },
}
)
# Just has to be there
self.endpoint_context.cdb["client1"] = {}
+ self.session_manager = self.endpoint_context.session_manager
+ self.claims_interface = ClaimsInterface(self.endpoint_context)
+ self.user_id = "diana"
+
+ def _create_session(self, auth_req, sub_type="public", sector_identifier=''):
+ if sector_identifier:
+ authz_req = auth_req.copy()
+ authz_req["sector_identifier_uri"] = sector_identifier
+ else:
+ authz_req = auth_req
+ client_id = authz_req['client_id']
+ ae = create_authn_event(self.user_id)
+ return self.session_manager.create_session(ae, authz_req, self.user_id,
+ client_id=client_id,
+ sub_type=sub_type)
def test_collect_user_info(self):
_req = OIDR.copy()
_req["claims"] = CLAIMS_2
- _session_info = {"authn_req": _req}
- session = _session_info.copy()
- session["sub"] = "doe"
- session["uid"] = "diana"
- session["authn_event"] = create_authn_event("diana", "salt")
+ session_id = self._create_session(_req)
+
+ _userinfo_restriction = self.claims_interface.get_claims(session_id=session_id,
+ scopes=OIDR["scope"],
+ usage="userinfo")
- res = collect_user_info(self.endpoint_context, session)
+ res = self.claims_interface.get_user_claims("diana", _userinfo_restriction)
assert res == {
+ 'eduperson_scoped_affiliation': ['staff@example.org'],
+ "email": "diana@example.org",
"nickname": "Dina",
- "sub": "doe",
+ "email_verified": False
+ }
+
+ _id_token_restriction = self.claims_interface.get_claims(session_id=session_id,
+ scopes=OIDR["scope"],
+ usage="id_token")
+
+ res = self.claims_interface.get_user_claims("diana", _id_token_restriction)
+
+ assert res == {
"email": "diana@example.org",
"email_verified": False,
}
+ _introspection_restriction = self.claims_interface.get_claims(session_id=session_id,
+ scopes=OIDR["scope"],
+ usage="introspection")
+
+ res = self.claims_interface.get_user_claims("diana", _introspection_restriction)
+
+ assert res == {}
+
def test_collect_user_info_2(self):
_req = OIDR.copy()
- _req["scope"] = "openid email"
+ _req["scope"] = "openid email address"
del _req["claims"]
- _session_info = {"authn_req": _req}
- session = _session_info.copy()
- session["sub"] = "doe"
- session["uid"] = "diana"
- session["authn_event"] = create_authn_event("diana", "salt")
+ session_id = self._create_session(_req)
+ _uid, _cid, _gid = unpack_session_key(session_id)
- self.endpoint_context.provider_info["scopes_supported"] = [
- "openid",
- "email",
- "offline_access",
- ]
- res = collect_user_info(self.endpoint_context, session)
+ _userinfo_restriction = self.claims_interface.get_claims(session_id=session_id,
+ scopes=_req["scope"],
+ usage="userinfo")
+
+ res = self.claims_interface.get_user_claims("diana", _userinfo_restriction)
assert res == {
- "sub": "doe",
- "email": "diana@example.org",
- "email_verified": False,
+ 'address': {
+ 'country': 'Sweden',
+ 'locality': 'Umeå',
+ 'postal_code': 'SE-90187',
+ 'street_address': 'Umeå Universitet'
+ },
+ 'eduperson_scoped_affiliation': ['staff@example.org'],
+ 'email': 'diana@example.org',
+ 'email_verified': False
}
- def test_collect_user_info_scope_not_supported(self):
+ def test_collect_user_info_scope_not_supported_no_base_claims(self):
_req = OIDR.copy()
_req["scope"] = "openid email address"
del _req["claims"]
- _session_info = {"authn_req": _req}
- session = _session_info.copy()
- session["sub"] = "doe"
- session["uid"] = "diana"
- session["authn_event"] = create_authn_event("diana", "salt")
+ session_id = self._create_session(_req)
+ _uid, _cid, _gid = unpack_session_key(session_id)
- # Scope address not supported
- self.endpoint_context.provider_info["scopes_supported"] = [
- "openid",
- "email",
- "offline_access",
- ]
- res = collect_user_info(self.endpoint_context, session)
+ self.endpoint_context.endpoint["userinfo"].kwargs["add_claims_by_scope"] = False
+ self.endpoint_context.endpoint["userinfo"].kwargs["enable_claims_per_client"] = False
+ del self.endpoint_context.endpoint["userinfo"].kwargs["base_claims"]
- assert res == {
- "sub": "doe",
- "email": "diana@example.org",
- "email_verified": False,
- }
+ _userinfo_restriction = self.claims_interface.get_claims(session_id=session_id,
+ scopes=_req["scope"],
+ usage="userinfo")
+
+ res = self.claims_interface.get_user_claims("diana", _userinfo_restriction)
+
+ assert res == {}
+
+ def test_collect_user_info_enable_claims_per_client(self):
+ _req = OIDR.copy()
+ _req["scope"] = "openid email address"
+ del _req["claims"]
+
+ session_id = self._create_session(_req)
+ _uid, _cid, _gid = unpack_session_key(session_id)
+
+ self.endpoint_context.endpoint["userinfo"].kwargs["add_claims_by_scope"] = False
+ self.endpoint_context.endpoint["userinfo"].kwargs["enable_claims_per_client"] = True
+ del self.endpoint_context.endpoint["userinfo"].kwargs["base_claims"]
+
+ self.endpoint_context.cdb[_req["client_id"]]["userinfo_claims"] = {"phone_number": None}
+
+ _userinfo_restriction = self.claims_interface.get_claims(session_id=session_id,
+ scopes=_req["scope"],
+ usage="userinfo")
+
+ res = self.claims_interface.get_user_claims("diana", _userinfo_restriction)
+
+ assert res == {'phone_number': '+46907865000'}
class TestCollectUserInfoCustomScopes:
@@ -409,6 +425,24 @@ def create_endpoint_context(self):
"request_uri_parameter_supported": True,
},
},
+ "userinfo": {
+ "path": "userinfo",
+ "class": userinfo.UserInfo,
+ "kwargs": {
+ "claim_types_supported": [
+ "normal",
+ "aggregated",
+ "distributed",
+ ],
+ "client_authn_method": ["bearer_header"],
+ "base_claims": {
+ "eduperson_scoped_affiliation": None,
+ "email": None,
+ },
+ "add_claims_by_scope": True,
+ "enable_claims_per_client": True
+ },
+ },
},
"add_on": {
"custom_scopes": {
@@ -440,66 +474,62 @@ def create_endpoint_context(self):
}
},
"template_dir": "template",
+ "id_token": {
+ "class": IDToken,
+ "kwargs": {
+ "base_claims": {
+ "email": None,
+ "email_verified": None,
+ },
+ "enable_claims_per_client": True
+ },
+ },
}
)
self.endpoint_context.cdb["client1"] = {}
-
- def test_collect_user_info(self):
- _session_info = {"authn_req": OIDR}
- session = _session_info.copy()
- session["sub"] = "doe"
- session["uid"] = "diana"
- session["authn_event"] = create_authn_event("diana", "salt")
-
- res = collect_user_info(self.endpoint_context, session)
-
- assert res == {
- "email": "diana@example.org",
- "email_verified": False,
- "nickname": "Dina",
- "given_name": "Diana",
- "sub": "doe",
- }
-
- def test_collect_user_info_2(self):
+ self.session_manager = self.endpoint_context.session_manager
+ self.claims_interface = ClaimsInterface(self.endpoint_context)
+ self.user_id = "diana"
+
+ def _create_session(self, auth_req, sub_type="public", sector_identifier=''):
+ if sector_identifier:
+ authz_req = auth_req.copy()
+ authz_req["sector_identifier_uri"] = sector_identifier
+ else:
+ authz_req = auth_req
+ client_id = authz_req['client_id']
+ ae = create_authn_event(self.user_id)
+ return self.session_manager.create_session(ae, authz_req, self.user_id,
+ client_id=client_id,
+ sub_type=sub_type)
+
+ # def _do_grant(self, auth_req):
+ # client_id = auth_req['client_id']
+ # # The user consent module produces a Grant instance
+ # grant = Grant(scope=auth_req['scope'], resources=[client_id])
+ #
+ # # the grant is assigned to a session (user_id, client_id)
+ # self.session_manager.set([self.user_id, client_id, grant.id], grant)
+ # return session_key(self.user_id, client_id, grant.id)
+
+ def test_collect_user_info_custom_scope(self):
_req = OIDR.copy()
- _req["scope"] = "openid email"
+ _req["scope"] = "openid research_and_scholarship"
del _req["claims"]
- _session_info = {"authn_req": _req}
- session = _session_info.copy()
- session["sub"] = "doe"
- session["uid"] = "diana"
- session["authn_event"] = create_authn_event("diana", "salt")
-
- self.endpoint_context.provider_info["claims_supported"].remove("email")
- self.endpoint_context.provider_info["claims_supported"].remove("email_verified")
-
- res = collect_user_info(self.endpoint_context, session)
+ session_id = self._create_session(_req)
- assert res == {"sub": "doe"}
+ _restriction = self.claims_interface.get_claims(session_id=session_id,
+ scopes=_req["scope"],
+ usage="userinfo")
- def test_collect_user_info_scope_not_supported(self):
- _req = OIDR.copy()
- _req["scope"] = "openid email address"
- del _req["claims"]
-
- _session_info = {"authn_req": _req}
- session = _session_info.copy()
- session["sub"] = "doe"
- session["uid"] = "diana"
- session["authn_event"] = create_authn_event("diana", "salt")
-
- # Scope address not supported
- self.endpoint_context.provider_info["scopes_supported"] = [
- "openid",
- "email",
- "offline_access",
- ]
- res = collect_user_info(self.endpoint_context, session)
+ res = self.claims_interface.get_user_claims("diana", _restriction)
assert res == {
- "sub": "doe",
- "email": "diana@example.org",
- "email_verified": False,
+ 'eduperson_scoped_affiliation': ['staff@example.org'],
+ 'email': 'diana@example.org',
+ 'email_verified': False,
+ 'family_name': 'Krall',
+ 'given_name': 'Diana',
+ 'name': 'Diana Krall'
}
diff --git a/tests/test_08_session.py b/tests/test_08_session.py
deleted file mode 100644
index dac1dc3..0000000
--- a/tests/test_08_session.py
+++ /dev/null
@@ -1,520 +0,0 @@
-import os
-import shutil
-import time
-
-from oidcendpoint.token_handler import UnknownToken
-from oidcmsg.oidc import AuthorizationRequest
-from oidcmsg.oidc import OpenIDRequest
-from oidcmsg.storage.init import storage_factory
-import pytest
-
-from oidcendpoint import rndstr
-from oidcendpoint import token_handler
-from oidcendpoint.authn_event import create_authn_event
-from oidcendpoint.endpoint_context import EndpointContext
-from oidcendpoint.oidc.authorization import Authorization
-from oidcendpoint.oidc.provider_config import ProviderConfiguration
-from oidcendpoint.session import SessionDB
-from oidcendpoint.session import setup_session
-from oidcendpoint.sso_db import SSODb
-from oidcendpoint.token_handler import WrongTokenType
-from oidcendpoint.user_authn.authn_context import INTERNETPROTOCOLPASSWORD
-from oidcendpoint.user_info import UserInfo
-
-__author__ = "rohe0002"
-
-AREQ = AuthorizationRequest(
- response_type="code",
- client_id="client1",
- redirect_uri="http://example.com/authz",
- scope=["openid"],
- state="state000",
-)
-
-AREQN = AuthorizationRequest(
- response_type="code",
- client_id="client1",
- redirect_uri="http://example.com/authz",
- scope=["openid"],
- state="state000",
- nonce="something",
-)
-
-AREQO = AuthorizationRequest(
- response_type="code",
- client_id="client1",
- redirect_uri="http://example.com/authz",
- scope=["openid", "offline_access"],
- prompt="consent",
- state="state000",
-)
-
-OIDR = OpenIDRequest(
- response_type="code",
- client_id="client1",
- redirect_uri="http://example.com/authz",
- scope=["openid"],
- state="state000",
-)
-
-BASEDIR = os.path.abspath(os.path.dirname(__file__))
-
-
-def full_path(local_file):
- return os.path.join(BASEDIR, local_file)
-
-
-SSO_DB_CONF = {
- "handler": "oidcmsg.storage.abfile.AbstractFileSystem",
- "fdir": "db/sso",
- "key_conv": "oidcmsg.storage.converter.QPKey",
- "value_conv": "oidcmsg.storage.converter.JSON",
-}
-
-SESSION_DB_CONF = {
- "handler": "oidcmsg.storage.abfile.AbstractFileSystem",
- "fdir": "db/session",
- "key_conv": "oidcmsg.storage.converter.QPKey",
- "value_conv": "oidcmsg.storage.converter.JSON",
-}
-
-
-def rmtree(item):
- try:
- shutil.rmtree(item)
- except FileNotFoundError:
- pass
-
-
-class TestSessionDB(object):
- @pytest.fixture(autouse=True)
- def create_sdb(self):
- rmtree("db/sso")
- rmtree("db/session")
-
- passwd = rndstr(24)
- _th_args = {
- "code": {"lifetime": 600, "password": passwd},
- "token": {"lifetime": 3600, "password": passwd},
- "refresh": {"lifetime": 86400, "password": passwd},
- }
-
- _token_handler = token_handler.factory(None, **_th_args)
- userinfo = UserInfo(db_file=full_path("users.json"))
- self.sdb = SessionDB(
- storage_factory(SESSION_DB_CONF),
- _token_handler,
- SSODb(SSO_DB_CONF),
- userinfo,
- )
-
- def test_create_authz_session(self):
- ae = create_authn_event("uid", "salt")
- sid = self.sdb.create_authz_session(ae, AREQ, client_id="client_id")
- self.sdb.do_sub(sid, uid="user", salt="client_salt")
-
- info = self.sdb[sid]
- assert info["client_id"] == "client_id"
- assert set(info.keys()) == {
- "sid",
- "client_id",
- "authn_req",
- "authn_event",
- "sub",
- "oauth_state",
- "code",
- }
-
- def test_create_authz_session_without_nonce(self):
- ae = create_authn_event("sub", "salt")
- sid = self.sdb.create_authz_session(ae, AREQ, client_id="client_id")
- info = self.sdb[sid]
- assert info["oauth_state"] == "authz"
-
- def test_create_authz_session_with_nonce(self):
- ae = create_authn_event("sub", "salt")
- sid = self.sdb.create_authz_session(ae, AREQN, client_id="client_id")
- info = self.sdb[sid]
- authz_request = info["authn_req"]
- assert authz_request["nonce"] == "something"
-
- def test_create_authz_session_with_id_token(self):
- ae = create_authn_event("sub", "salt")
- sid = self.sdb.create_authz_session(
- ae, AREQ, client_id="client_id", id_token="id_token"
- )
-
- info = self.sdb[sid]
- assert info["id_token"] == "id_token"
-
- def test_create_authz_session_with_oidreq(self):
- ae = create_authn_event("sub", "salt")
- sid = self.sdb.create_authz_session(
- ae, AREQ, client_id="client_id", oidreq=OIDR
- )
- info = self.sdb[sid]
- assert "id_token" not in info
- assert "oidreq" in info
-
- def test_create_authz_session_with_sector_id(self):
- ae = create_authn_event("sub", "salt")
- sid = self.sdb.create_authz_session(
- ae, AREQ, client_id="client_id", oidreq=OIDR
- )
- self.sdb.do_sub(
- sid, "user1", "client_salt", "http://example.com/si.jwt", "pairwise"
- )
-
- info_1 = self.sdb[sid].copy()
- assert "id_token" not in info_1
- assert "oidreq" in info_1
- assert info_1["sub"] != "sub"
-
- self.sdb.do_sub(
- sid, "user2", "client_salt", "http://example.net/si.jwt", "pairwise"
- )
-
- info_2 = self.sdb[sid]
- assert info_2["sub"] != "sub"
- assert info_2["sub"] != info_1["sub"]
-
- def test_upgrade_to_token(self):
- ae1 = create_authn_event("uid", "salt")
- sid = self.sdb.create_authz_session(ae1, AREQ, client_id="client_id")
- self.sdb[sid]["sub"] = "sub"
- grant = self.sdb[sid]["code"]
- _dict = self.sdb.upgrade_to_token(grant)
-
- print(_dict.keys())
- assert set(_dict.keys()) == {
- "sid",
- "authn_event",
- "authn_req",
- "access_token",
- "token_type",
- "client_id",
- "oauth_state",
- "expires_in",
- "expires_at",
- "code_is_used"
- }
-
- # can't update again
- # with pytest.raises(AccessCodeUsed):
- print(self.sdb.upgrade_to_token(grant))
-
- def test_upgrade_to_token_refresh(self):
- ae1 = create_authn_event("sub", "salt")
- sid = self.sdb.create_authz_session(ae1, AREQO, client_id="client_id")
- self.sdb.do_sub(sid, "user", ae1["salt"])
- grant = self.sdb[sid]["code"]
- # Issue an access token trading in the access grant code
- _dict = self.sdb.upgrade_to_token(grant, issue_refresh=True)
-
- print(_dict.keys())
- assert set(_dict.keys()) == {
- "sid",
- "authn_event",
- "authn_req",
- "access_token",
- "sub",
- "token_type",
- "client_id",
- "oauth_state",
- "refresh_token",
- "expires_in",
- "expires_at",
- "code_is_used"
- }
-
- # You can't refresh a token using the token itself
- with pytest.raises(WrongTokenType):
- self.sdb.refresh_token(_dict["access_token"])
-
- def test_upgrade_to_token_with_id_token_and_oidreq(self):
- ae2 = create_authn_event("another_user_id", "salt")
- sid = self.sdb.create_authz_session(ae2, AREQ, client_id="client_id")
- self.sdb[sid]["sub"] = "sub"
- grant = self.sdb[sid]["code"]
-
- _dict = self.sdb.upgrade_to_token(grant, id_token="id_token", oidreq=OIDR)
- print(_dict.keys())
- assert set(_dict.keys()) == {
- "sid",
- "authn_event",
- "authn_req",
- "oidreq",
- "access_token",
- "id_token",
- "token_type",
- "client_id",
- "oauth_state",
- "expires_in",
- "expires_at",
- "code_is_used"
- }
-
- assert _dict["id_token"] == "id_token"
- assert isinstance(_dict["oidreq"], OpenIDRequest)
-
- def test_refresh_token(self):
- ae = create_authn_event("uid", "salt")
- sid = self.sdb.create_authz_session(ae, AREQ, client_id="client_id")
- self.sdb[sid]["sub"] = "sub"
- grant = self.sdb[sid]["code"]
-
- dict1 = self.sdb.upgrade_to_token(grant, issue_refresh=True).copy()
- rtoken = dict1["refresh_token"]
- dict2 = self.sdb.refresh_token(rtoken, AREQ["client_id"])
-
- assert dict1["access_token"] != dict2["access_token"]
-
- with pytest.raises(WrongTokenType):
- self.sdb.refresh_token(dict2["access_token"], AREQ["client_id"])
-
- def test_refresh_token_cleared_session(self):
- ae = create_authn_event("uid", "salt")
- sid = self.sdb.create_authz_session(ae, AREQ, client_id="client_id")
- self.sdb[sid]["sub"] = "sub"
- grant = self.sdb[sid]["code"]
- dict1 = self.sdb.upgrade_to_token(grant, issue_refresh=True)
- ac1 = dict1["access_token"]
-
- # Purge the SessionDB
- self.sdb._db = {}
-
- rtoken = dict1["refresh_token"]
- with pytest.raises(UnknownToken):
- self.sdb.refresh_token(rtoken, AREQ["client_id"])
-
- def test_is_valid(self):
- ae1 = create_authn_event("uid", "salt")
- sid = self.sdb.create_authz_session(ae1, AREQ, client_id="client_id")
- self.sdb[sid]["sub"] = "sub"
- grant = self.sdb[sid]["code"]
-
- assert self.sdb.is_valid("code", grant)
-
- sinfo = self.sdb.upgrade_to_token(grant, issue_refresh=True)
- assert not self.sdb.is_valid("code", grant)
- access_token = sinfo["access_token"]
- assert self.sdb.is_valid("access_token", access_token)
-
- refresh_token = sinfo["refresh_token"]
- sinfo = self.sdb.refresh_token(refresh_token, AREQ["client_id"])
- access_token2 = sinfo["access_token"]
- assert self.sdb.is_valid("access_token", access_token2)
-
- # The old access code should be invalid
- try:
- self.sdb.is_valid("access_token", access_token)
- except KeyError:
- pass
-
- def test_valid_grant(self):
- ae = create_authn_event("another:user", "salt")
- sid = self.sdb.create_authz_session(ae, AREQ, client_id="client_id")
- grant = self.sdb[sid]["code"]
-
- assert self.sdb.is_valid("code", grant)
-
- def test_revoke_token(self):
- ae1 = create_authn_event("uid", "salt")
- sid = self.sdb.create_authz_session(ae1, AREQ, client_id="client_id")
- self.sdb[sid]["sub"] = "sub"
-
- grant = self.sdb[sid]["code"]
- tokens = self.sdb.upgrade_to_token(grant, issue_refresh=True)
- access_token = tokens["access_token"]
- refresh_token = tokens["refresh_token"]
-
- assert self.sdb.is_valid("access_token", access_token)
-
- self.sdb.revoke_token(sid, "access_token")
- assert not self.sdb.is_valid("access_token", access_token)
-
- sinfo = self.sdb.refresh_token(refresh_token, AREQ["client_id"])
- access_token = sinfo["access_token"]
- assert self.sdb.is_valid("access_token", access_token)
-
- self.sdb.revoke_token(sid, "refresh_token")
- assert not self.sdb.is_valid("refresh_token", refresh_token)
-
- ae2 = create_authn_event("sub", "salt")
- sid = self.sdb.create_authz_session(ae2, AREQ, client_id="client_2")
-
- grant = self.sdb[sid]["code"]
- self.sdb.revoke_token(sid, "code")
- assert not self.sdb.is_valid("code", grant)
-
- def test_sub_to_authn_event(self):
- ae = create_authn_event("sub", "salt", time_stamp=time.time())
- sid = self.sdb.create_authz_session(ae, AREQ, client_id="client_id")
- sub = self.sdb.do_sub(sid, "user", "client_salt")
-
- # given the sub find out whether the authn event is still valid
- sids = self.sdb.get_sids_by_sub(sub)
- ae = self.sdb[sids[0]]["authn_event"]
- assert ae.valid()
-
- def test_do_sub_deterministic(self):
- ae = create_authn_event("tester", "random_value")
- sid = self.sdb.create_authz_session(ae, AREQ, client_id="client_id")
- self.sdb.do_sub(sid, "user", "other_random_value")
-
- info = self.sdb[sid]
- assert (
- info["sub"]
- == "d657bddf3d30970aa681663978ea84e26553ead03cb6fe8fcfa6523f2bcd0ad2"
- )
-
- self.sdb.do_sub(
- sid,
- "user",
- "other_random_value",
- sector_id="http://example.com",
- subject_type="pairwise",
- )
- info2 = self.sdb[sid]
- assert (
- info2["sub"]
- == "1442ceb13a822e802f85832ce93a8fda011e32a3363834dd1db3f9aa211065bd"
- )
-
- self.sdb.do_sub(
- sid,
- "user",
- "another_random_value",
- sector_id="http://other.example.com",
- subject_type="pairwise",
- )
-
- info2 = self.sdb[sid]
- assert (
- info2["sub"]
- == "56e0a53d41086e7b22d78d52ee461655e9b090d50a0663d16136ea49a56c9bec"
- )
-
- def test_match_session(self):
- ae1 = create_authn_event("uid", "salt")
- sid = self.sdb.create_authz_session(ae1, AREQ, client_id="client_id")
- self.sdb[sid]["sub"] = "sub"
- self.sdb.sso_db.map_sid2uid(sid, "uid")
-
- res = self.sdb.match_session("uid", client_id="client_id")
- assert res == sid
-
- def test_get_token(self):
- ae1 = create_authn_event("uid", "salt")
- sid = self.sdb.create_authz_session(ae1, AREQ, client_id="client_id")
- self.sdb[sid]["sub"] = "sub"
- self.sdb.sso_db.map_sid2uid(sid, "uid")
-
- grant = self.sdb.get_token(sid)
- assert self.sdb.is_valid("code", grant)
- assert self.sdb.handler.type(grant) == "A"
-
-
-KEYDEFS = [
- {"type": "RSA", "key": "", "use": ["sig"]},
- {"type": "EC", "crv": "P-256", "use": ["sig"]},
-]
-
-conf = {
- "issuer": "https://example.com/",
- "password": "mycket hemligt",
- "token_expires_in": 600,
- "grant_expires_in": 300,
- "refresh_token_expires_in": 86400,
- "verify_ssl": False,
- "capabilities": {},
- "keys": {
- "uri_path": "static/jwks.json",
- "key_defs": KEYDEFS,
- "private_path": "own/jwks.json",
- },
- "endpoint": {
- "provider_config": {
- "path": ".well-known/openid-configuration",
- "class": ProviderConfiguration,
- "kwargs": {},
- },
- "authorization_endpoint": {
- "path": "authorization",
- "class": Authorization,
- "kwargs": {},
- },
- },
- "authentication": {
- "anon": {
- "acr": INTERNETPROTOCOLPASSWORD,
- "class": "oidcendpoint.user_authn.user.NoAuthn",
- "kwargs": {"user": "diana"},
- }
- },
- "userinfo": {"class": UserInfo, "kwargs": {"db_file": full_path("users.json")}},
- "template_dir": "template",
- "sso_db": SSO_DB_CONF,
- "session_db": SESSION_DB_CONF,
-}
-
-
-def test_setup_session():
- endpoint_context = EndpointContext(conf)
- uid = "_user_"
- client_id = "EXTERNAL"
- areq = None
- acr = None
- sid = setup_session(endpoint_context, areq, uid, client_id, acr, salt="salt")
- assert sid
-
-
-def test_setup_session_upgrade_to_token():
- endpoint_context = EndpointContext(conf)
- uid = "_user_"
- client_id = "EXTERNAL"
- areq = None
- acr = None
- sid = setup_session(endpoint_context, areq, uid, client_id, acr, salt="salt")
- assert sid
- code = endpoint_context.sdb[sid]["code"]
- assert code
-
- res = endpoint_context.sdb.upgrade_to_token(code)
- assert "access_token" in res
-
- endpoint_context.sdb.revoke_uid("_user_")
- assert endpoint_context.sdb.is_session_revoked(sid)
-
-
-def make_sub_uid(uid, **kwargs):
- return uid
-
-
-def test_sub_minting_function():
- conf["sub_func"] = {"public": {"function": make_sub_uid}}
-
- endpoint_context = EndpointContext(conf)
- uid = "_user_"
- client_id = "EXTERNAL"
- areq = None
- acr = None
- sid = setup_session(endpoint_context, areq, uid, client_id, acr, salt="salt")
- assert endpoint_context.sdb[sid]["sub"] == uid
-
-
-class SubMinter(object):
- def __call__(self, *args, **kwargs):
- return args[0]
-
-
-def test_sub_minting_class():
- conf["sub_func"] = {"public": {"class": SubMinter}}
-
- endpoint_context = EndpointContext(conf)
- uid = "_user_"
- client_id = "EXTERNAL"
- areq = None
- acr = None
- sid = setup_session(endpoint_context, areq, uid, client_id, acr, salt="salt")
- assert endpoint_context.sdb[sid]["sub"] == uid
diff --git a/tests/test_08_session_life.py b/tests/test_08_session_life.py
new file mode 100644
index 0000000..ad63f27
--- /dev/null
+++ b/tests/test_08_session_life.py
@@ -0,0 +1,490 @@
+import os
+
+import pytest
+from cryptojwt.key_jar import init_key_jar
+from oidcmsg.oidc import AccessTokenRequest
+from oidcmsg.oidc import AuthorizationRequest
+from oidcmsg.oidc import RefreshAccessTokenRequest
+from oidcmsg.time_util import time_sans_frac
+
+from oidcendpoint import user_info
+from oidcendpoint.authn_event import create_authn_event
+from oidcendpoint.client_authn import verify_client
+from oidcendpoint.endpoint_context import EndpointContext
+from oidcendpoint.id_token import IDToken
+from oidcendpoint.oidc.authorization import Authorization
+from oidcendpoint.oidc.provider_config import ProviderConfiguration
+from oidcendpoint.oidc.registration import Registration
+from oidcendpoint.oidc.session import Session
+from oidcendpoint.oidc.token import Token
+from oidcendpoint.session import session_key
+from oidcendpoint.session import unpack_session_key
+from oidcendpoint.session.grant import Grant
+from oidcendpoint.session.info import ClientSessionInfo
+from oidcendpoint.session.info import UserSessionInfo
+from oidcendpoint.session.manager import SessionManager
+from oidcendpoint.session.manager import public_id
+from oidcendpoint.token import DefaultToken
+from oidcendpoint.token.handler import TokenHandler
+from oidcendpoint.token.handler import factory
+from oidcendpoint.user_authn.authn_context import INTERNETPROTOCOLPASSWORD
+
+
+class TestSession():
+ @pytest.fixture(autouse=True)
+ def setup_token_handler(self):
+ password = "The longer the better. Is this close to enough ?"
+ conf = {
+ "issuer": "https://example.com/",
+ "password": "mycket hemligt",
+ "token_expires_in": 600,
+ "grant_expires_in": 300,
+ "refresh_token_expires_in": 86400,
+ "verify_ssl": False,
+ "keys": {"key_defs": KEYDEFS, "uri_path": "static/jwks.json"},
+ "jwks_uri": "https://example.com/jwks.json",
+ "token_handler_args": {
+ "code": {
+ "kwargs": {
+ "lifetime": 600,
+ "password": password
+ }},
+ "token": {
+ "kwargs": {
+ "lifetime": 900,
+ "password": password
+ }
+ },
+ "refresh": {
+ "kwargs": {
+ "lifetime": 86400,
+ "password": password
+ }
+ }
+ },
+ "endpoint": {
+ "authorization_endpoint": {
+ "path": "{}/authorization",
+ "class": Authorization,
+ "kwargs": {},
+ },
+ "token_endpoint": {"path": "{}/token", "class": Token, "kwargs": {}},
+ },
+ "template_dir": "template",
+ }
+
+ self.endpoint_context = EndpointContext(conf)
+ token_handler = factory(self.endpoint_context, **conf["token_handler_args"])
+
+ self.session_manager = SessionManager(handler=token_handler)
+
+ def auth(self):
+ # Start with an authentication request
+ # The client ID appears in the request
+ AUTH_REQ = AuthorizationRequest(
+ client_id="client_1",
+ redirect_uri="https://example.com/cb",
+ scope=["openid", "mail", "address", "offline_access"],
+ state="STATE",
+ response_type="code",
+ )
+
+ # The authentication returns a user ID
+ user_id = "diana"
+
+ # User info is stored in the Session DB
+ authn_event = create_authn_event(
+ user_id,
+ authn_info=INTERNETPROTOCOLPASSWORD,
+ authn_time=time_sans_frac(),
+ )
+
+ user_info = UserSessionInfo(user_id=user_id)
+ self.session_manager.set([user_id], user_info)
+
+ # Now for client session information
+ client_id = AUTH_REQ['client_id']
+ client_info = ClientSessionInfo(client_id=client_id)
+ self.session_manager.set([user_id, client_id], client_info)
+
+ # The user consent module produces a Grant instance
+
+ grant = Grant(scope=AUTH_REQ['scope'],
+ resources=[client_id],
+ authorization_request=AUTH_REQ,
+ authentication_event=authn_event)
+
+ # the grant is assigned to a session (user_id, client_id)
+ session_id = session_key(user_id, client_id, grant.id)
+ self.session_manager.set([user_id, client_id, grant.id], grant)
+
+ # Constructing an authorization code is now done by
+
+ code = grant.mint_token(
+ session_id=session_id,
+ endpoint_context=self.endpoint_context,
+ token_type='authorization_code',
+ token_handler= self.session_manager.token_handler["code"],
+ expires_at=time_sans_frac() + 300 # 5 minutes from now
+ )
+
+ return grant.id, code
+
+ def test_code_flow(self):
+ # code is a Token instance
+ _grant_id, code = self.auth()
+
+ # next step is access token request
+
+ TOKEN_REQ = AccessTokenRequest(
+ client_id="client_1",
+ redirect_uri="https://example.com/cb",
+ state="STATE",
+ grant_type="authorization_code",
+ client_secret="hemligt",
+ code=code.value
+ )
+
+ # parse the token
+ session_id = self.session_manager.token_handler.sid(TOKEN_REQ['code'])
+
+ # Now given I have the client_id from the request and the user_id from the
+ # token I can easily find the grant
+
+ # client_info = self.session_manager.get([user_id, TOKEN_REQ['client_id']])
+ tok = self.session_manager.find_token(session_id, TOKEN_REQ['code'])
+
+ # Verify that it's of the correct type and can be used
+ assert tok.type == "authorization_code"
+ assert tok.is_active()
+
+ # Mint an access token and a refresh token and mark the code as used
+
+ assert tok.supports_minting("access_token")
+
+ grant = self.session_manager[session_id]
+
+ access_token = grant.mint_token(
+ session_id=session_id,
+ endpoint_context=self.endpoint_context,
+ token_type='access_token',
+ token_handler=self.session_manager.token_handler["access_token"],
+ expires_at=time_sans_frac() + 900, # 15 minutes from now
+ based_on=tok # Means the token (tok) was used to mint this token
+ )
+
+ assert tok.supports_minting("refresh_token")
+
+ refresh_token = grant.mint_token(
+ session_id=session_id,
+ endpoint_context=self.endpoint_context,
+ token_type='refresh_token',
+ token_handler=self.session_manager.token_handler["refresh_token"],
+ based_on=tok
+ )
+
+ tok.register_usage()
+
+ assert tok.max_usage_reached() is True
+
+ # A bit later a refresh token is used to mint a new access token
+
+ REFRESH_TOKEN_REQ = RefreshAccessTokenRequest(
+ grant_type="refresh_token",
+ client_id="client_1",
+ client_secret="hemligt",
+ refresh_token=refresh_token.value,
+ scope=["openid", "mail", "offline_access"]
+ )
+
+ reftok = self.session_manager.find_token(session_id,
+ REFRESH_TOKEN_REQ['refresh_token'])
+
+ assert reftok.supports_minting("access_token")
+
+ access_token_2 = grant.mint_token(
+ session_id=session_id,
+ endpoint_context=self.endpoint_context,
+ token_type='access_token',
+ token_handler=self.session_manager.token_handler["access_token"],
+ expires_at=time_sans_frac() + 900, # 15 minutes from now
+ based_on=reftok # Means the token (tok) was used to mint this token
+ )
+
+ assert access_token_2.is_active()
+
+
+KEYDEFS = [
+ {"type": "RSA", "key": "", "use": ["sig"]},
+ {"type": "EC", "crv": "P-256", "use": ["sig"]},
+]
+
+ISSUER = "https://example.com/"
+
+KEYJAR = init_key_jar(key_defs=KEYDEFS, issuer_id=ISSUER)
+KEYJAR.import_jwks(KEYJAR.export_jwks(True, ISSUER), "")
+RESPONSE_TYPES_SUPPORTED = [
+ ["code"],
+ ["token"],
+ ["id_token"],
+ ["code", "token"],
+ ["code", "id_token"],
+ ["id_token", "token"],
+ ["code", "token", "id_token"],
+ ["none"],
+]
+CAPABILITIES = {
+ "response_types_supported": [" ".join(x) for x in RESPONSE_TYPES_SUPPORTED],
+ "token_endpoint_auth_methods_supported": [
+ "client_secret_post",
+ "client_secret_basic",
+ "client_secret_jwt",
+ "private_key_jwt",
+ ],
+ "response_modes_supported": ["query", "fragment", "form_post"],
+ "subject_types_supported": ["public", "pairwise", "ephemeral"],
+ "grant_types_supported": [
+ "authorization_code",
+ "implicit",
+ "urn:ietf:params:oauth:grant-type:jwt-bearer",
+ ],
+ "claim_types_supported": ["normal", "aggregated", "distributed"],
+ "claims_parameter_supported": True,
+ "request_parameter_supported": True,
+ "request_uri_parameter_supported": True,
+}
+BASEDIR = os.path.abspath(os.path.dirname(__file__))
+
+
+def full_path(local_file):
+ return os.path.join(BASEDIR, local_file)
+
+
+class TestSessionJWTToken():
+ @pytest.fixture(autouse=True)
+ def setup_session_manager(self):
+ conf = {
+ "issuer": ISSUER,
+ "password": "mycket hemligt",
+ "token_expires_in": 600,
+ "grant_expires_in": 300,
+ "refresh_token_expires_in": 86400,
+ "verify_ssl": False,
+ "capabilities": CAPABILITIES,
+ "keys": {"uri_path": "jwks.json", "key_defs": KEYDEFS},
+ "token_handler_args": {
+ "jwks_def": {
+ "private_path": "private/token_jwks.json",
+ "read_only": False,
+ "key_defs": [
+ {"type": "oct", "bytes": "24", "use": ["enc"], "kid": "code"},
+ {"type": "oct", "bytes": "24", "use": ["enc"], "kid": "refresh"}
+ ],
+ },
+ "code": {"lifetime": 600},
+ "token": {
+ "class": "oidcendpoint.token.jwt_token.JWTToken",
+ "kwargs": {
+ "lifetime": 3600,
+ "add_claims": [
+ "email",
+ "email_verified",
+ "phone_number",
+ "phone_number_verified",
+ ],
+ "add_claim_by_scope": True,
+ "aud": ["https://example.org/appl"],
+ },
+ },
+ "refresh": {},
+ },
+ "endpoint": {
+ "provider_config": {
+ "path": "{}/.well-known/openid-configuration",
+ "class": ProviderConfiguration,
+ "kwargs": {},
+ },
+ "registration": {
+ "path": "{}/registration",
+ "class": Registration,
+ "kwargs": {},
+ },
+ "authorization": {
+ "path": "{}/authorization",
+ "class": Authorization,
+ "kwargs": {},
+ },
+ "token": {"path": "{}/token", "class": Token, "kwargs": {}},
+ "session": {"path": "{}/end_session", "class": Session},
+ },
+ "client_authn": verify_client,
+ "authentication": {
+ "anon": {
+ "acr": INTERNETPROTOCOLPASSWORD,
+ "class": "oidcendpoint.user_authn.user.NoAuthn",
+ "kwargs": {"user": "diana"},
+ }
+ },
+ "template_dir": "template",
+ "userinfo": {
+ "class": user_info.UserInfo,
+ "kwargs": {"db_file": full_path("users.json")},
+ },
+ "id_token": {"class": IDToken},
+ }
+
+ self.endpoint_context = EndpointContext(conf, keyjar=KEYJAR)
+ self.session_manager = self.endpoint_context.session_manager
+ # self.session_manager = SessionManager(handler=self.endpoint_context.sdb.handler)
+ # self.endpoint_context.session_manager = self.session_manager
+
+ def auth(self):
+ # Start with an authentication request
+ # The client ID appears in the request
+ AUTH_REQ = AuthorizationRequest(
+ client_id="client_1",
+ redirect_uri="https://example.com/cb",
+ scope=["openid", "mail", "address", "offline_access"],
+ state="STATE",
+ response_type="code",
+ )
+
+ # The authentication returns a user ID
+ user_id = "diana"
+
+ # User info is stored in the Session DB
+
+ user_info = UserSessionInfo(user_id=user_id)
+ self.session_manager.set([user_id], user_info)
+
+ # Now for client session information
+ authn_event = create_authn_event(
+ user_id,
+ authn_info=INTERNETPROTOCOLPASSWORD,
+ authn_time=time_sans_frac(),
+ )
+
+ client_id = AUTH_REQ["client_id"]
+ client_info = ClientSessionInfo(client_id=client_id)
+ self.session_manager.set([user_id, client_id], client_info)
+
+ # The user consent module produces a Grant instance
+
+ grant = Grant(
+ scope=AUTH_REQ['scope'],
+ resources=[client_id],
+ authentication_event=authn_event,
+ authorization_request=AUTH_REQ
+ )
+
+ # the grant is assigned to a session (user_id, client_id)
+ session_id = session_key(user_id, client_id, grant.id)
+ self.session_manager.set([user_id, client_id, grant.id], grant)
+
+ # Constructing an authorization code is now done by
+
+ code = grant.mint_token(
+ session_id=session_id,
+ endpoint_context=self.endpoint_context,
+ token_type='authorization_code',
+ token_handler=self.session_manager.token_handler["code"],
+ expires_at=time_sans_frac() + 300 # 5 minutes from now
+ )
+ return code
+
+ def test_code_flow(self):
+ # code is a Token instance
+ code = self.auth()
+
+ # next step is access token request
+
+ TOKEN_REQ = AccessTokenRequest(
+ client_id="client_1",
+ redirect_uri="https://example.com/cb",
+ state="STATE",
+ grant_type="authorization_code",
+ client_secret="hemligt",
+ code=code.value
+ )
+
+ # parse the token
+ session_id = self.session_manager.token_handler.sid(TOKEN_REQ['code'])
+ user_id, client_id, grant_id = unpack_session_key(session_id)
+
+ # Now given I have the client_id from the request and the user_id from the
+ # token I can easily find the grant
+
+ # client_info = self.session_manager.get([user_id, TOKEN_REQ['client_id']])
+ tok = self.session_manager.find_token(session_id, TOKEN_REQ['code'])
+
+ # Verify that it's of the correct type and can be used
+ assert tok.type == "authorization_code"
+ assert tok.is_active()
+
+ # Mint an access token and a refresh token and mark the code as used
+
+ assert tok.supports_minting("access_token")
+
+ client_info = self.session_manager.get([user_id, TOKEN_REQ["client_id"]])
+
+ assert tok.supports_minting("access_token")
+
+ grant = self.session_manager[session_id]
+
+ access_token = grant.mint_token(
+ session_id=session_id,
+ endpoint_context=self.endpoint_context,
+ token_type='access_token',
+ token_handler=self.session_manager.token_handler["access_token"],
+ expires_at=time_sans_frac() + 900, # 15 minutes from now
+ based_on=tok # Means the token (tok) was used to mint this token
+ )
+
+ # this test is include in the mint_token methods
+ # assert tok.supports_minting("refresh_token")
+
+ refresh_token = grant.mint_token(
+ session_id=session_id,
+ endpoint_context=self.endpoint_context,
+ token_type='refresh_token',
+ token_handler=self.session_manager.token_handler["refresh_token"],
+ based_on=tok
+ )
+
+ tok.register_usage()
+
+ assert tok.max_usage_reached() is True
+
+ # A bit later a refresh token is used to mint a new access token
+
+ REFRESH_TOKEN_REQ = RefreshAccessTokenRequest(
+ grant_type="refresh_token",
+ client_id="client_1",
+ client_secret="hemligt",
+ refresh_token=refresh_token.value,
+ scope=["openid", "mail", "offline_access"]
+ )
+
+ session_id = session_key(user_id, REFRESH_TOKEN_REQ['client_id'], grant_id)
+ reftok = self.session_manager.find_token(session_id,
+ REFRESH_TOKEN_REQ['refresh_token'])
+
+ # Can I use this token to mint another token ?
+ assert grant.is_active()
+
+ user_claims = self.endpoint_context.userinfo(user_id, client_id=TOKEN_REQ["client_id"],
+ user_info_claims=grant.claims)
+
+ access_token_2 = grant.mint_token(
+ session_id=session_id,
+ endpoint_context=self.endpoint_context,
+ token_type='access_token',
+ token_handler=self.session_manager.token_handler["access_token"],
+ expires_at=time_sans_frac() + 900, # 15 minutes from now
+ based_on=reftok # Means the refresh token (reftok) was used to mint this token
+ )
+
+ assert access_token_2.is_active()
+
+ token_info = self.session_manager.token_handler.info(access_token_2.value)
+ assert token_info
diff --git a/tests/test_09_cookie_dealer.py b/tests/test_09_cookie_dealer.py
index b1db1ae..9c5e41c 100644
--- a/tests/test_09_cookie_dealer.py
+++ b/tests/test_09_cookie_dealer.py
@@ -1,9 +1,12 @@
+import json
+import time
from http.cookies import SimpleCookie
import pytest
from cryptojwt.jwk.hmac import SYMKey
from cryptojwt.key_jar import init_key_jar
+from oidcendpoint import rndstr
from oidcendpoint.cookie import CookieDealer
from oidcendpoint.cookie import append_cookie
from oidcendpoint.cookie import compute_session_state
@@ -11,8 +14,10 @@
from oidcendpoint.cookie import create_session_cookie
from oidcendpoint.cookie import make_cookie
from oidcendpoint.cookie import new_cookie
+from oidcendpoint.cookie import sign_enc_payload
+from oidcendpoint.cookie import ver_dec_content
from oidcendpoint.endpoint_context import EndpointContext
-from oidcendpoint.oidc.token import AccessToken
+from oidcendpoint.oidc.token import Token
KEYDEFS = [
{"type": "RSA", "key": "", "use": ["sig"]},
@@ -228,7 +233,7 @@ def test_append_cookie():
"grant_expires_in": 300,
"refresh_token_expires_in": 86400,
"verify_ssl": False,
- "endpoint": {"token": {"path": "token", "class": AccessToken, "kwargs": {}}},
+ "endpoint": {"token": {"path": "token", "class": Token, "kwargs": {}}},
"template_dir": "template",
"keys": {
"private_path": "own/jwks.json",
@@ -256,6 +261,8 @@ def test_append_cookie():
endpoint_context.cdb[client_id] = {"client_secret": client_secret}
endpoint_context.cookie_dealer = CookieDealer(**cookie_conf)
+enc_key = rndstr(32)
+
def test_new_cookie():
kaka = new_cookie(
@@ -301,3 +308,10 @@ def test_cookie_same_site_none():
assert kaka["test"]["secure"] is True
assert kaka["test"]["httponly"] is True
assert kaka["test"]["samesite"] is "None"
+
+
+def test_cookie_enc():
+ _key = SYMKey(k=enc_key)
+ _enc_data = sign_enc_payload(json.dumps({"test": "data"}), timestamp=time.time(), enc_key=_key)
+ _data, _timestamp = ver_dec_content(_enc_data.split('|'), enc_key=_key)
+ assert json.loads(_data) == {"test": "data"}
diff --git a/tests/test_10_oidc_authz.py b/tests/test_10_oidc_authz.py
deleted file mode 100644
index 535bb81..0000000
--- a/tests/test_10_oidc_authz.py
+++ /dev/null
@@ -1,206 +0,0 @@
-import json
-import os
-
-import pytest
-from cryptojwt.key_jar import build_keyjar
-
-from oidcendpoint.authz import AuthzHandling
-from oidcendpoint.authz import Implicit
-from oidcendpoint.authz import factory
-from oidcendpoint.cookie import CookieDealer
-from oidcendpoint.cookie import new_cookie
-from oidcendpoint.endpoint_context import EndpointContext
-from oidcendpoint.oidc import userinfo
-from oidcendpoint.oidc.authorization import Authorization
-from oidcendpoint.oidc.provider_config import ProviderConfiguration
-from oidcendpoint.oidc.registration import Registration
-from oidcendpoint.oidc.session import Session
-from oidcendpoint.oidc.token import AccessToken
-from oidcendpoint.user_authn.authn_context import INTERNETPROTOCOLPASSWORD
-from oidcendpoint.user_info import UserInfo
-
-ISS = "https://example.com/"
-
-KEYDEFS = [
- {"type": "RSA", "use": ["sig"]},
- {"type": "EC", "crv": "P-256", "use": ["sig"]},
-]
-
-KEYJAR = build_keyjar(KEYDEFS)
-KEYJAR.import_jwks(KEYJAR.export_jwks(private=True), ISS)
-
-RESPONSE_TYPES_SUPPORTED = [
- ["code"],
- ["token"],
- ["id_token"],
- ["code", "token"],
- ["code", "id_token"],
- ["id_token", "token"],
- ["code", "id_token", "token"],
- ["none"],
-]
-
-CAPABILITIES = {
- "response_types_supported": [" ".join(x) for x in RESPONSE_TYPES_SUPPORTED],
- "token_endpoint_auth_methods_supported": [
- "client_secret_post",
- "client_secret_basic",
- "client_secret_jwt",
- "private_key_jwt",
- ],
- "response_modes_supported": ["query", "fragment", "form_post"],
- "subject_types_supported": ["public", "pairwise"],
- "grant_types_supported": [
- "authorization_code",
- "implicit",
- "urn:ietf:params:oauth:grant-type:jwt-bearer",
- "refresh_token",
- ],
- "claim_types_supported": ["normal", "aggregated", "distributed"],
- "claims_parameter_supported": True,
- "request_parameter_supported": True,
- "request_uri_parameter_supported": True,
-}
-
-BASEDIR = os.path.abspath(os.path.dirname(__file__))
-
-
-def full_path(local_file):
- return os.path.join(BASEDIR, local_file)
-
-
-USERINFO_db = json.loads(open(full_path("users.json")).read())
-
-
-class TestAuthz(object):
- @pytest.fixture(autouse=True)
- def create_ec(self):
- conf = {
- "issuer": ISS,
- "password": "mycket hemlig zebra",
- "token_expires_in": 600,
- "grant_expires_in": 300,
- "refresh_token_expires_in": 86400,
- "verify_ssl": False,
- "capabilities": CAPABILITIES,
- "keys": {"uri_path": "jwks.json", "key_defs": KEYDEFS},
- "endpoint": {
- "provider_config": {
- "path": "{}/.well-known/openid-configuration",
- "class": ProviderConfiguration,
- "kwargs": {"client_authn_method": None},
- },
- "registration": {
- "path": "{}/registration",
- "class": Registration,
- "kwargs": {"client_authn_method": None},
- },
- "authorization": {
- "path": "{}/authorization",
- "class": Authorization,
- "kwargs": {"client_authn_method": None},
- },
- "token": {"path": "{}/token", "class": AccessToken, "kwargs": {}},
- "userinfo": {
- "path": "{}/userinfo",
- "class": userinfo.UserInfo,
- "kwargs": {"db_file": "users.json"},
- },
- "session": {
- "path": "{}/end_session",
- "class": Session,
- "kwargs": {
- "signing_alg": "ES256",
- "logout_verify_url": "{}/verify_logout".format(ISS),
- },
- },
- },
- "authentication": {
- "anon": {
- "acr": INTERNETPROTOCOLPASSWORD,
- "class": "oidcendpoint.user_authn.user.NoAuthn",
- "kwargs": {"user": "diana"},
- }
- },
- "userinfo": {"class": UserInfo, "kwargs": {"db": USERINFO_db}},
- "template_dir": "template",
- "authz": {"class": AuthzHandling, "kwargs": {}},
- "cookie_dealer": {
- "class": CookieDealer,
- "kwargs": {
- "sign_key": "ghsNKDDLshZTPn974nOsIGhedULrsqnsGoBFBLwUKuJhE2ch",
- "default_values": {
- "name": "oidcop",
- "domain": "127.0.0.1",
- "path": "/",
- "max_age": 3600,
- },
- },
- },
- }
-
- self.endpoint_context = EndpointContext(conf, keyjar=KEYJAR)
-
- def _create_cookie(self, user, sid, state, client_id, typ="sso", name=""):
- ec = self.endpoint_context
- if not name:
- name = ec.cookie_name["session"]
- return new_cookie(
- ec,
- sub=user,
- sid=sid,
- state=state,
- client_id=client_id,
- typ=typ,
- cookie_name=name,
- )
-
- def test_init_authz(self):
- authz = AuthzHandling(self.endpoint_context)
- assert authz
-
- def test_authz_set_get(self):
- authz = self.endpoint_context.authz
- authz.set("diana", "client_1", ["email", "phone"])
- assert authz.get("diana", "client_1") == ["email", "phone"]
-
- def test_authz_cookie(self):
- authz = self.endpoint_context.authz
- authz.set("diana", "client_1", ["email", "phone"])
- cookie = self._create_cookie("diana", "_sid_", "1234567", "client_1")
- perm = authz.permissions(cookie)
- assert set(perm) == {"email", "phone"}
-
- def test_authz_cookie_wrong_client(self):
- authz = self.endpoint_context.authz
- authz.set("diana", "client_1", ["email", "phone"])
- cookie = self._create_cookie("diana", "_sid_", "1234567", "client_2")
- perm = authz.permissions(cookie)
- assert perm is None
-
- def tests_implicit(self):
- authz = Implicit(self.endpoint_context, "any")
- perm = authz.get("foo", "bar")
- assert perm == "any"
-
- def test_factory_implicit(self):
- authz = factory("Implicit", self.endpoint_context, permission="all")
- assert authz.get("foo", "bar") == "all"
-
- def test_factory_authz_handling(self):
- authz = factory("AuthzHandling", self.endpoint_context)
- authz.set("diana", "client_1", ["email", "phone"])
- assert authz.get("diana", "client_1") == ["email", "phone"]
-
- def test_authz_cookie_none(self):
- authz = self.endpoint_context.authz
- authz.set("diana", "client_1", ["email", "phone"])
- assert authz.permissions(None) is None
-
- def test_authz_cookie_other(self):
- authz = self.endpoint_context.authz
- authz.set("diana", "client_1", ["email", "phone"])
- cookie = self._create_cookie(
- "diana", "_sid_", "1234567", "client_1", name="foo"
- )
- assert authz.permissions(cookie) is None
diff --git a/tests/test_12_user_authn.py b/tests/test_12_user_authn.py
index dafbc5d..03845c9 100644
--- a/tests/test_12_user_authn.py
+++ b/tests/test_12_user_authn.py
@@ -29,9 +29,7 @@ def create_endpoint_context(self):
conf = {
"issuer": "https://example.com/",
"password": "mycket hemligt",
- "token_expires_in": 600,
"grant_expires_in": 300,
- "refresh_token_expires_in": 86400,
"verify_ssl": False,
"endpoint": {},
"keys": {"uri_path": "static/jwks.json", "key_defs": KEYDEFS},
diff --git a/tests/test_22_oidc_provider_config_endpoint.py b/tests/test_22_oidc_provider_config_endpoint.py
index 47995fb..5c46441 100755
--- a/tests/test_22_oidc_provider_config_endpoint.py
+++ b/tests/test_22_oidc_provider_config_endpoint.py
@@ -4,7 +4,7 @@
from oidcendpoint.endpoint_context import EndpointContext
from oidcendpoint.oidc.provider_config import ProviderConfiguration
-from oidcendpoint.oidc.token import AccessToken
+from oidcendpoint.oidc.token import Token
KEYDEFS = [
{"type": "RSA", "key": "", "use": ["sig"]},
@@ -31,7 +31,7 @@
"private_key_jwt",
],
"response_modes_supported": ["query", "fragment", "form_post"],
- "subject_types_supported": ["public", "pairwise"],
+ "subject_types_supported": ["public", "pairwise""ephemeral"],
"grant_types_supported": [
"authorization_code",
"implicit",
@@ -63,7 +63,7 @@ def create_endpoint(self):
"class": ProviderConfiguration,
"kwargs": {},
},
- "token": {"path": "token", "class": AccessToken, "kwargs": {}},
+ "token": {"path": "token", "class": Token, "kwargs": {}},
},
"template_dir": "template",
}
diff --git a/tests/test_23_oidc_registration_endpoint.py b/tests/test_23_oidc_registration_endpoint.py
index 0d59e45..71d03ea 100755
--- a/tests/test_23_oidc_registration_endpoint.py
+++ b/tests/test_23_oidc_registration_endpoint.py
@@ -1,9 +1,9 @@
# -*- coding: latin-1 -*-
import json
-from cryptojwt.key_jar import init_key_jar
import pytest
import responses
+from cryptojwt.key_jar import init_key_jar
from oidcmsg.oidc import RegistrationRequest
from oidcmsg.oidc import RegistrationResponse
@@ -12,7 +12,7 @@
from oidcendpoint.oidc.authorization import Authorization
from oidcendpoint.oidc.registration import Registration
from oidcendpoint.oidc.registration import match_sp_sep
-from oidcendpoint.oidc.token import AccessToken
+from oidcendpoint.oidc.token import Token
from oidcendpoint.oidc.userinfo import UserInfo
KEYDEFS = [
@@ -72,7 +72,7 @@ def create_endpoint(self):
"refresh_token_expires_in": 86400,
"verify_ssl": False,
"capabilities": {
- "subject_types_supported": ["public", "pairwise"],
+ "subject_types_supported": ["public", "pairwise", "ephemeral"],
"grant_types_supported": [
"authorization_code",
"implicit",
@@ -108,7 +108,7 @@ def create_endpoint(self):
},
"token": {
"path": "token",
- "class": AccessToken,
+ "class": Token,
"kwargs": {
"client_authn_method": [
"client_secret_post",
diff --git a/tests/test_24_oauth2_authorization_endpoint.py b/tests/test_24_oauth2_authorization_endpoint.py
index 599269d..f2aeb81 100755
--- a/tests/test_24_oauth2_authorization_endpoint.py
+++ b/tests/test_24_oauth2_authorization_endpoint.py
@@ -18,23 +18,23 @@
from oidcmsg.oauth2 import AuthorizationResponse
from oidcmsg.time_util import in_a_while
-from oidcendpoint.common.authorization import FORM_POST
-from oidcendpoint.common.authorization import get_uri
-from oidcendpoint.common.authorization import inputs
-from oidcendpoint.common.authorization import join_query
-from oidcendpoint.common.authorization import verify_uri
+from oidcendpoint.authn_event import create_authn_event
+from oidcendpoint.authz import AuthzHandling
from oidcendpoint.cookie import CookieDealer
from oidcendpoint.endpoint_context import EndpointContext
from oidcendpoint.exception import InvalidRequest
from oidcendpoint.exception import NoSuchAuthentication
from oidcendpoint.exception import RedirectURIError
from oidcendpoint.exception import ToOld
-from oidcendpoint.exception import UnknownClient
from oidcendpoint.exception import UnAuthorizedClientScope
-from oidcendpoint.exception import UnAuthorizedClient
+from oidcendpoint.exception import UnknownClient
from oidcendpoint.id_token import IDToken
+from oidcendpoint.oauth2.authorization import FORM_POST
from oidcendpoint.oauth2.authorization import Authorization
-from oidcendpoint.session import SessionInfo
+from oidcendpoint.oauth2.authorization import get_uri
+from oidcendpoint.oauth2.authorization import inputs
+from oidcendpoint.oauth2.authorization import join_query
+from oidcendpoint.oauth2.authorization import verify_uri
from oidcendpoint.user_info import UserInfo
KEYDEFS = [
@@ -140,9 +140,6 @@ def create_endpoint(self):
conf = {
"issuer": "https://example.com/",
"password": "mycket hemligt zebra",
- "token_expires_in": 600,
- "grant_expires_in": 300,
- "refresh_token_expires_in": 86400,
"verify_ssl": False,
"capabilities": CAPABILITIES,
"keys": {"uri_path": "static/jwks.json", "key_defs": KEYDEFS},
@@ -191,6 +188,24 @@ def create_endpoint(self):
},
},
},
+ "authz": {
+ "class": AuthzHandling,
+ "kwargs": {
+ "grant_config": {
+ "usage_rules": {
+ "authorization_code": {
+ 'supports_minting': ["access_token", "refresh_token", "id_token"],
+ "max_usage": 1
+ },
+ "access_token": {},
+ "refresh_token": {
+ 'supports_minting': ["access_token", "refresh_token", "id_token"],
+ }
+ },
+ "expires_in": 43200
+ }
+ }
+ }
}
endpoint_context = EndpointContext(conf)
_clients = yaml.safe_load(io.StringIO(client_yaml))
@@ -199,6 +214,8 @@ def create_endpoint(self):
endpoint_context.keyjar.export_jwks(True, ""), conf["issuer"]
)
self.endpoint = endpoint_context.endpoint["authorization"]
+ self.session_manager = endpoint_context.session_manager
+ self.user_id = "diana"
self.rp_keyjar = KeyJar()
self.rp_keyjar.add_symmetric("client_1", "hemligtkodord1234567890")
@@ -206,12 +223,24 @@ def create_endpoint(self):
"client_1", "hemligtkodord1234567890"
)
+ def _create_session(self, auth_req, sub_type="public", sector_identifier=''):
+ if sector_identifier:
+ areq = auth_req.copy()
+ areq["sector_identifier_uri"] = sector_identifier
+ else:
+ areq = auth_req
+
+ client_id = areq['client_id']
+ ae = create_authn_event(self.user_id)
+ return self.session_manager.create_session(ae, areq, self.user_id,
+ client_id=client_id,
+ sub_type=sub_type)
+
def test_init(self):
assert self.endpoint
def test_parse(self):
_req = self.endpoint.parse_request(AUTH_REQ_DICT)
-
assert isinstance(_req, AuthorizationRequest)
assert set(_req.keys()) == set(AUTH_REQ.keys())
@@ -223,6 +252,7 @@ def test_process_request(self):
"fragment_enc",
"return_uri",
"cookie",
+ "session_id"
}
def test_do_response_code(self):
@@ -393,24 +423,15 @@ def test_create_authn_response(self):
scope="openid",
)
- _ec = self.endpoint.endpoint_context
- _ec.sdb["session_id"] = SessionInfo(
- authn_req=request,
- uid="diana",
- sub="abcdefghijkl",
- authn_event={
- "authn_info": "loa1",
- "uid": "diana",
- "authn_time": utc_time_sans_frac(),
- },
- )
- _ec.cdb["client_id"] = {
+ self.endpoint.endpoint_context.cdb["client_id"] = {
"client_id": "client_id",
"redirect_uris": [("https://rp.example.com/cb", {})],
"id_token_signed_response_alg": "ES256",
}
- resp = self.endpoint.create_authn_response(request, "session_id")
+ session_id = self._create_session(request)
+
+ resp = self.endpoint.create_authn_response(request, session_id)
assert isinstance(resp["response_args"], AuthorizationErrorResponse)
def test_setup_auth(self):
@@ -434,7 +455,7 @@ def test_setup_auth(self):
)
res = self.endpoint.setup_auth(request, redirect_uri, cinfo, kaka)
- assert set(res.keys()) == {"authn_event", "identity", "user"}
+ assert set(res.keys()) == {"session_id", "identity", "user"}
def test_setup_auth_error(self):
request = AuthorizationRequest(
@@ -512,25 +533,16 @@ def test_setup_auth_user(self):
"redirect_uris": [("https://rp.example.com/cb", {})],
"id_token_signed_response_alg": "RS256",
}
- _ec = self.endpoint.endpoint_context
- _ec.sdb["session_id"] = SessionInfo(
- authn_req=request,
- uid="diana",
- sub="abcdefghijkl",
- authn_event={
- "authn_info": "loa1",
- "uid": "diana",
- "authn_time": utc_time_sans_frac(),
- },
- )
- item = _ec.authn_broker.db["anon"]
+ session_id = self._create_session(request)
+
+ item = self.endpoint.endpoint_context.authn_broker.db["anon"]
item["method"].user = b64e(
- as_bytes(json.dumps({"uid": "krall", "sid": "session_id"}))
+ as_bytes(json.dumps({"uid": "krall", "sid": session_id}))
)
res = self.endpoint.setup_auth(request, redirect_uri, cinfo, None)
- assert set(res.keys()) == {"authn_event", "identity", "user"}
+ assert set(res.keys()) == {"session_id", "identity", "user"}
assert res["identity"]["uid"] == "krall"
def test_setup_auth_session_revoked(self):
@@ -548,22 +560,17 @@ def test_setup_auth_session_revoked(self):
"redirect_uris": [("https://rp.example.com/cb", {})],
"id_token_signed_response_alg": "RS256",
}
- _ec = self.endpoint.endpoint_context
- _ec.sdb["session_id"] = SessionInfo(
- authn_req=request,
- uid="diana",
- sub="abcdefghijkl",
- authn_event={
- "authn_info": "loa1",
- "uid": "diana",
- "authn_time": utc_time_sans_frac(),
- },
- revoked=True,
- )
+ session_id = self._create_session(request)
+
+ _mngr = self.endpoint.endpoint_context.session_manager
+ _csi = _mngr[session_id]
+ _csi.revoked = True
+
+ _ec = self.endpoint.endpoint_context
item = _ec.authn_broker.db["anon"]
item["method"].user = b64e(
- as_bytes(json.dumps({"uid": "krall", "sid": "session_id"}))
+ as_bytes(json.dumps({"uid": "krall", "sid": session_id}))
)
res = self.endpoint.setup_auth(request, redirect_uri, cinfo, None)
@@ -598,6 +605,30 @@ def test_response_mode_fragment(self):
info = self.endpoint.response_mode(request)
assert set(info.keys()) == {"fragment_enc"}
+ # def test_sso(self):
+ # _pr_resp = self.endpoint.parse_request(AUTH_REQ_DICT)
+ # _resp = self.endpoint.process_request(_pr_resp)
+ # msg = self.endpoint.do_response(**_resp)
+ #
+ # request = AuthorizationRequest(
+ # client_id="client_2",
+ # redirect_uri="https://rp.example.org/cb",
+ # response_type=["code"],
+ # state="state",
+ # scope="openid",
+ # )
+ #
+ # cinfo = {
+ # "client_id": "client_2",
+ # "redirect_uris": [(request["redirect_uri"], {})]
+ # }
+ #
+ # _pr_resp = self.endpoint.parse_request(AUTH_REQ_DICT, cookie="kaka")
+ # _resp = self.endpoint.process_request(_pr_resp)
+ # msg = self.endpoint.do_response(**_resp)
+ #
+ # assert set(res.keys()) == {"authn_event", "identity", "user"}
+
def test_inputs():
elems = inputs(dict(foo="bar", home="stead"))
diff --git a/tests/test_24_oauth2_authorization_endpoint_jar.py b/tests/test_24_oauth2_authorization_endpoint_jar.py
index efc1854..24b7644 100755
--- a/tests/test_24_oauth2_authorization_endpoint_jar.py
+++ b/tests/test_24_oauth2_authorization_endpoint_jar.py
@@ -182,6 +182,8 @@ def create_endpoint(self):
endpoint_context.keyjar.export_jwks(True, ""), conf["issuer"]
)
self.endpoint = endpoint_context.endpoint["authorization"]
+ self.session_manager = endpoint_context.session_manager
+ self.user_id = "diana"
self.rp_keyjar = KeyJar()
self.rp_keyjar.add_symmetric("client_1", "hemligtkodord1234567890")
diff --git a/tests/test_24_oidc_authorization_endpoint.py b/tests/test_24_oidc_authorization_endpoint.py
index 9de55a9..03acfa6 100755
--- a/tests/test_24_oidc_authorization_endpoint.py
+++ b/tests/test_24_oidc_authorization_endpoint.py
@@ -9,7 +9,6 @@
import yaml
from cryptojwt import JWT
from cryptojwt import KeyJar
-from cryptojwt.jwt import utc_time_sans_frac
from cryptojwt.utils import as_bytes
from cryptojwt.utils import b64e
from oidcmsg.exception import ParameterError
@@ -21,9 +20,8 @@
from oidcmsg.oidc import verified_claim_name
from oidcmsg.oidc import verify_id_token
-from oidcendpoint.common.authorization import FORM_POST
-from oidcendpoint.common.authorization import join_query
-from oidcendpoint.common.authorization import verify_uri
+from oidcendpoint.authn_event import create_authn_event
+from oidcendpoint.authz import AuthzHandling
from oidcendpoint.cookie import CookieDealer
from oidcendpoint.cookie import cookie_value
from oidcendpoint.cookie import new_cookie
@@ -33,19 +31,20 @@
from oidcendpoint.exception import RedirectURIError
from oidcendpoint.exception import ToOld
from oidcendpoint.exception import UnknownClient
-from oidcendpoint.exception import UnAuthorizedClient
from oidcendpoint.id_token import IDToken
from oidcendpoint.login_hint import LoginHint2Acrs
+from oidcendpoint.oauth2.authorization import FORM_POST
+from oidcendpoint.oauth2.authorization import get_uri
+from oidcendpoint.oauth2.authorization import inputs
+from oidcendpoint.oauth2.authorization import join_query
+from oidcendpoint.oauth2.authorization import verify_uri
from oidcendpoint.oidc import userinfo
from oidcendpoint.oidc.authorization import Authorization
from oidcendpoint.oidc.authorization import acr_claims
-from oidcendpoint.oidc.authorization import get_uri
-from oidcendpoint.oidc.authorization import inputs
from oidcendpoint.oidc.authorization import re_authenticate
from oidcendpoint.oidc.provider_config import ProviderConfiguration
from oidcendpoint.oidc.registration import Registration
-from oidcendpoint.oidc.token import AccessToken
-from oidcendpoint.session import SessionInfo
+from oidcendpoint.oidc.token import Token
from oidcendpoint.user_authn.authn_context import INTERNETPROTOCOLPASSWORD
from oidcendpoint.user_authn.authn_context import UNSPECIFIED
from oidcendpoint.user_authn.authn_context import init_method
@@ -72,7 +71,7 @@
]
CAPABILITIES = {
- "subject_types_supported": ["public", "pairwise"],
+ "subject_types_supported": ["public", "pairwise", "ephemeral"],
"grant_types_supported": [
"authorization_code",
"implicit",
@@ -110,7 +109,6 @@ def full_path(local_file):
USERINFO_db = json.loads(open(full_path("users.json")).read())
-
client_yaml = """
oidc_clients:
client_1:
@@ -149,18 +147,17 @@ def create_endpoint(self):
conf = {
"issuer": "https://example.com/",
"password": "mycket hemligt zebra",
- "token_expires_in": 600,
- "grant_expires_in": 300,
- "refresh_token_expires_in": 86400,
"verify_ssl": False,
"capabilities": CAPABILITIES,
"keys": {"uri_path": "static/jwks.json", "key_defs": KEYDEFS},
"id_token": {
"class": IDToken,
"kwargs": {
- "available_claims": {
+ "base_claims": {
"email": {"essential": True},
"email_verified": {"essential": True},
+ "given_name": {"essential": True},
+ "nickname": None
}
},
},
@@ -190,7 +187,7 @@ def create_endpoint(self):
},
"token": {
"path": "token",
- "class": AccessToken,
+ "class": Token,
"kwargs": {
"client_authn_method": [
"client_secret_post",
@@ -222,6 +219,24 @@ def create_endpoint(self):
},
"userinfo": {"class": UserInfo, "kwargs": {"db": USERINFO_db}},
"template_dir": "template",
+ "authz": {
+ "class": AuthzHandling,
+ "kwargs": {
+ "grant_config": {
+ "usage_rules": {
+ "authorization_code": {
+ 'supports_minting': ["access_token", "refresh_token", "id_token"],
+ "max_usage": 1
+ },
+ "access_token": {},
+ "refresh_token": {
+ 'supports_minting': ["access_token", "refresh_token"],
+ }
+ },
+ "expires_in": 43200
+ }
+ }
+ },
"cookie_dealer": {
"class": CookieDealer,
"kwargs": {
@@ -240,12 +255,15 @@ def create_endpoint(self):
},
}
endpoint_context = EndpointContext(conf)
+
_clients = yaml.safe_load(io.StringIO(client_yaml))
endpoint_context.cdb = _clients["oidc_clients"]
endpoint_context.keyjar.import_jwks(
endpoint_context.keyjar.export_jwks(True, ""), conf["issuer"]
)
self.endpoint = endpoint_context.endpoint["authorization"]
+ self.session_manager = endpoint_context.session_manager
+ self.user_id = "diana"
self.rp_keyjar = KeyJar()
self.rp_keyjar.add_symmetric("client_1", "hemligtkodord1234567890")
@@ -256,6 +274,18 @@ def create_endpoint(self):
def test_init(self):
assert self.endpoint
+ def _create_session(self, auth_req, sub_type="public", sector_identifier=''):
+ if sector_identifier:
+ authz_req = auth_req.copy()
+ authz_req["sector_identifier_uri"] = sector_identifier
+ else:
+ authz_req = auth_req
+ client_id = authz_req['client_id']
+ ae = create_authn_event(self.user_id)
+ return self.session_manager.create_session(ae, authz_req, self.user_id,
+ client_id=client_id,
+ sub_type=sub_type)
+
def test_parse(self):
_req = self.endpoint.parse_request(AUTH_REQ_DICT)
@@ -270,6 +300,7 @@ def test_process_request(self):
"fragment_enc",
"return_uri",
"cookie",
+ "session_id"
}
def test_do_response_code(self):
@@ -366,13 +397,12 @@ def test_id_token_claims(self):
_req["nonce"] = "rnd_nonce"
_pr_resp = self.endpoint.parse_request(_req)
_resp = self.endpoint.process_request(_pr_resp)
- idt = verify_id_token(
- _resp["response_args"], keyjar=self.endpoint.endpoint_context.keyjar
- )
+ idt = verify_id_token(_resp["response_args"], keyjar=self.endpoint.endpoint_context.keyjar)
assert idt
- # from claims
- assert "given_name" in _resp["response_args"]["__verified_id_token"]
# from config
+ assert "given_name" in _resp["response_args"]["__verified_id_token"]
+ assert "nickname" in _resp["response_args"]["__verified_id_token"]
+ # Could have gotten email but didn't ask for it
assert "email" in _resp["response_args"]["__verified_id_token"]
def test_re_authenticate(self):
@@ -541,23 +571,15 @@ def test_create_authn_response(self):
)
_ec = self.endpoint.endpoint_context
- _ec.sdb["session_id"] = SessionInfo(
- authn_req=request,
- uid="diana",
- sub="abcdefghijkl",
- authn_event={
- "authn_info": "loa1",
- "uid": "diana",
- "authn_time": utc_time_sans_frac(),
- },
- )
_ec.cdb["client_id"] = {
"client_id": "client_id",
"redirect_uris": [("https://rp.example.com/cb", {})],
"id_token_signed_response_alg": "ES256",
}
- resp = self.endpoint.create_authn_response(request, "session_id")
+ session_id = self._create_session(request)
+
+ resp = self.endpoint.create_authn_response(request, session_id)
assert isinstance(resp["response_args"], AuthorizationErrorResponse)
def test_setup_auth(self):
@@ -581,7 +603,7 @@ def test_setup_auth(self):
)
res = self.endpoint.setup_auth(request, redirect_uri, cinfo, kaka)
- assert set(res.keys()) == {"authn_event", "identity", "user"}
+ assert set(res.keys()) == {"session_id", "identity", "user"}
def test_setup_auth_error(self):
request = AuthorizationRequest(
@@ -628,24 +650,16 @@ def test_setup_auth_user(self):
"id_token_signed_response_alg": "RS256",
}
_ec = self.endpoint.endpoint_context
- _ec.sdb["session_id"] = SessionInfo(
- authn_req=request,
- uid="diana",
- sub="abcdefghijkl",
- authn_event={
- "authn_info": "loa1",
- "uid": "diana",
- "authn_time": utc_time_sans_frac(),
- },
- )
+
+ session_id = self._create_session(request)
item = _ec.authn_broker.db["anon"]
item["method"].user = b64e(
- as_bytes(json.dumps({"uid": "krall", "sid": "session_id"}))
+ as_bytes(json.dumps({"uid": "krall", "sid": session_id}))
)
res = self.endpoint.setup_auth(request, redirect_uri, cinfo, None)
- assert set(res.keys()) == {"authn_event", "identity", "user"}
+ assert set(res.keys()) == {"session_id", "identity", "user"}
assert res["identity"]["uid"] == "krall"
def test_setup_auth_session_revoked(self):
@@ -664,23 +678,17 @@ def test_setup_auth_session_revoked(self):
"id_token_signed_response_alg": "RS256",
}
_ec = self.endpoint.endpoint_context
- _ec.sdb["session_id"] = SessionInfo(
- authn_req=request,
- uid="diana",
- sub="abcdefghijkl",
- authn_event={
- "authn_info": "loa1",
- "uid": "diana",
- "authn_time": utc_time_sans_frac(),
- },
- revoked=True,
- )
+
+ session_id = self._create_session(request)
item = _ec.authn_broker.db["anon"]
item["method"].user = b64e(
- as_bytes(json.dumps({"uid": "krall", "sid": "session_id"}))
+ as_bytes(json.dumps({"uid": "krall", "sid": session_id}))
)
+ grant = _ec.session_manager[session_id]
+ grant.revoked = True
+
res = self.endpoint.setup_auth(request, redirect_uri, cinfo, None)
assert set(res.keys()) == {"args", "function"}
diff --git a/tests/test_25_oidc_token_endpoint.py b/tests/test_25_oidc_token_endpoint.py
deleted file mode 100755
index 43be33f..0000000
--- a/tests/test_25_oidc_token_endpoint.py
+++ /dev/null
@@ -1,266 +0,0 @@
-import json
-import os
-
-import pytest
-from cryptojwt import JWT
-from cryptojwt.key_jar import build_keyjar
-from oidcmsg.oidc import AccessTokenRequest
-from oidcmsg.oidc import AuthorizationRequest
-from oidcmsg.oidc import RefreshAccessTokenRequest
-
-from oidcendpoint import JWT_BEARER
-from oidcendpoint.client_authn import verify_client
-from oidcendpoint.endpoint_context import EndpointContext
-from oidcendpoint.exception import MultipleUsage
-from oidcendpoint.exception import UnAuthorizedClient
-from oidcendpoint.oidc import userinfo
-from oidcendpoint.oidc.authorization import Authorization
-from oidcendpoint.oidc.provider_config import ProviderConfiguration
-from oidcendpoint.oidc.refresh_token import RefreshAccessToken
-from oidcendpoint.oidc.registration import Registration
-from oidcendpoint.oidc.token import AccessToken
-from oidcendpoint.session import setup_session
-from oidcendpoint.user_authn.authn_context import INTERNETPROTOCOLPASSWORD
-from oidcendpoint.user_info import UserInfo
-
-KEYDEFS = [
- {"type": "RSA", "key": "", "use": ["sig"]},
- {"type": "EC", "crv": "P-256", "use": ["sig"]},
-]
-
-CLIENT_KEYJAR = build_keyjar(KEYDEFS)
-
-
-RESPONSE_TYPES_SUPPORTED = [
- ["code"],
- ["token"],
- ["id_token"],
- ["code", "token"],
- ["code", "id_token"],
- ["id_token", "token"],
- ["code", "token", "id_token"],
- ["none"],
-]
-
-CAPABILITIES = {
- "subject_types_supported": ["public", "pairwise"],
- "grant_types_supported": [
- "authorization_code",
- "implicit",
- "urn:ietf:params:oauth:grant-type:jwt-bearer",
- "refresh_token",
- ],
-}
-
-AUTH_REQ = AuthorizationRequest(
- client_id="client_1",
- redirect_uri="https://example.com/cb",
- scope=["openid"],
- state="STATE",
- response_type="code",
-)
-
-TOKEN_REQ = AccessTokenRequest(
- client_id="client_1",
- redirect_uri="https://example.com/cb",
- state="STATE",
- grant_type="authorization_code",
- client_secret="hemligt",
-)
-
-REFRESH_TOKEN_REQ = RefreshAccessTokenRequest(
- grant_type="refresh_token", client_id="client_1", client_secret="hemligt"
-)
-
-TOKEN_REQ_DICT = TOKEN_REQ.to_dict()
-
-BASEDIR = os.path.abspath(os.path.dirname(__file__))
-
-
-def full_path(local_file):
- return os.path.join(BASEDIR, local_file)
-
-
-USERINFO = UserInfo(json.loads(open(full_path("users.json")).read()))
-
-
-class TestEndpoint(object):
- @pytest.fixture(autouse=True)
- def create_endpoint(self):
- conf = {
- "issuer": "https://example.com/",
- "password": "mycket hemligt",
- "token_expires_in": 600,
- "grant_expires_in": 300,
- "refresh_token_expires_in": 86400,
- "verify_ssl": False,
- "capabilities": CAPABILITIES,
- "keys": {"uri_path": "jwks.json", "key_defs": KEYDEFS},
- "endpoint": {
- "provider_config": {
- "path": ".well-known/openid-configuration",
- "class": ProviderConfiguration,
- "kwargs": {},
- },
- "registration": {
- "path": "registration",
- "class": Registration,
- "kwargs": {},
- },
- "authorization": {
- "path": "authorization",
- "class": Authorization,
- "kwargs": {},
- },
- "token": {
- "path": "token",
- "class": AccessToken,
- "kwargs": {
- "client_authn_method": [
- "client_secret_basic",
- "client_secret_post",
- "client_secret_jwt",
- "private_key_jwt",
- ]
- },
- },
- "refresh_token": {
- "path": "token",
- "class": RefreshAccessToken,
- "kwargs": {},
- },
- "userinfo": {
- "path": "userinfo",
- "class": userinfo.UserInfo,
- "kwargs": {"db_file": "users.json"},
- },
- },
- "authentication": {
- "anon": {
- "acr": INTERNETPROTOCOLPASSWORD,
- "class": "oidcendpoint.user_authn.user.NoAuthn",
- "kwargs": {"user": "diana"},
- }
- },
- "userinfo": {"class": UserInfo, "kwargs": {"db": {}}},
- "client_authn": verify_client,
- "template_dir": "template",
- }
- endpoint_context = EndpointContext(conf)
- endpoint_context.cdb["client_1"] = {
- "client_secret": "hemligt",
- "redirect_uris": [("https://example.com/cb", None)],
- "client_salt": "salted",
- "token_endpoint_auth_method": "client_secret_post",
- "response_types": ["code", "token", "code id_token", "id_token"],
- }
- endpoint_context.keyjar.import_jwks(CLIENT_KEYJAR.export_jwks(), "client_1")
- self.endpoint = endpoint_context.endpoint["token"]
-
- def test_init(self):
- assert self.endpoint
-
- def test_signing_alg_values(self):
- """
- According to
- https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderMetadata
- The value "none" MUST NOT be used for the
- token_endpoint_auth_signing_alg_values_supported field.
- """
- assert "none" not in self.endpoint.endpoint_info[
- "token_endpoint_auth_signing_alg_values_supported"
- ]
-
- def test_parse(self):
- session_id = setup_session(self.endpoint.endpoint_context, AUTH_REQ, uid="user")
- _token_request = TOKEN_REQ_DICT.copy()
- _token_request["code"] = self.endpoint.endpoint_context.sdb[session_id]["code"]
- _req = self.endpoint.parse_request(_token_request)
-
- assert isinstance(_req, AccessTokenRequest)
- assert set(_req.keys()) == set(_token_request.keys())
-
- def test_process_request(self):
- session_id = setup_session(
- self.endpoint.endpoint_context,
- AUTH_REQ,
- uid="user",
- acr=INTERNETPROTOCOLPASSWORD,
- )
- _token_request = TOKEN_REQ_DICT.copy()
- _context = self.endpoint.endpoint_context
- _token_request["code"] = _context.sdb[session_id]["code"]
- _context.sdb.update(session_id, user="diana")
- _req = self.endpoint.parse_request(_token_request)
-
- _resp = self.endpoint.process_request(request=_req)
-
- assert _resp
- assert set(_resp.keys()) == {"http_headers", "response_args"}
-
- def test_process_request_using_code_twice(self):
- session_id = setup_session(
- self.endpoint.endpoint_context,
- AUTH_REQ,
- uid="user",
- acr=INTERNETPROTOCOLPASSWORD,
- )
- _token_request = TOKEN_REQ_DICT.copy()
- _context = self.endpoint.endpoint_context
- _token_request["code"] = _context.sdb[session_id]["code"]
- _context.sdb.update(session_id, user="diana")
- _req = self.endpoint.parse_request(_token_request)
- _resp = self.endpoint.process_request(request=_req)
-
- # 2nd time used
- # TODO: There is a bug in _post_parse_request, the returned error
- # should be invalid_grant, not invalid_client
- _req = self.endpoint.parse_request(_token_request)
- _resp = self.endpoint.process_request(request=_req)
-
- assert _resp
- assert set(_resp.keys()) == {"error"}
-
- def test_do_response(self):
- session_id = setup_session(
- self.endpoint.endpoint_context,
- AUTH_REQ,
- uid="user",
- acr=INTERNETPROTOCOLPASSWORD,
- )
- self.endpoint.endpoint_context.sdb.update(session_id, user="diana")
- _token_request = TOKEN_REQ_DICT.copy()
- _token_request["code"] = self.endpoint.endpoint_context.sdb[session_id]["code"]
- _req = self.endpoint.parse_request(_token_request)
-
- _resp = self.endpoint.process_request(request=_req)
- msg = self.endpoint.do_response(request=_req, **_resp)
- assert isinstance(msg, dict)
-
- def test_process_request_using_private_key_jwt(self):
- session_id = setup_session(
- self.endpoint.endpoint_context,
- AUTH_REQ,
- uid="user",
- acr=INTERNETPROTOCOLPASSWORD,
- )
- _token_request = TOKEN_REQ_DICT.copy()
- del _token_request["client_id"]
- del _token_request["client_secret"]
- _context = self.endpoint.endpoint_context
-
- _jwt = JWT(CLIENT_KEYJAR, iss=AUTH_REQ["client_id"], sign_alg="RS256")
- _jwt.with_jti = True
- _assertion = _jwt.pack({"aud": [_context.endpoint["token"].full_path]})
- _token_request.update(
- {"client_assertion": _assertion, "client_assertion_type": JWT_BEARER}
- )
- _token_request["code"] = self.endpoint.endpoint_context.sdb[session_id]["code"]
-
- _context.sdb.update(session_id, user="diana")
- _req = self.endpoint.parse_request(_token_request)
- _resp = self.endpoint.process_request(request=_req)
-
- # 2nd time used
- with pytest.raises(UnAuthorizedClient):
- self.endpoint.parse_request(_token_request)
diff --git a/tests/test_26_oidc_userinfo_endpoint.py b/tests/test_26_oidc_userinfo_endpoint.py
index 41f2bb9..08e0757 100755
--- a/tests/test_26_oidc_userinfo_endpoint.py
+++ b/tests/test_26_oidc_userinfo_endpoint.py
@@ -2,19 +2,20 @@
import os
import pytest
-from cryptojwt.jwt import utc_time_sans_frac
+from oidcmsg.oauth2 import ResponseMessage
from oidcmsg.oidc import AccessTokenRequest
from oidcmsg.oidc import AuthorizationRequest
+from oidcmsg.time_util import time_sans_frac
from oidcendpoint import user_info
+from oidcendpoint.authn_event import create_authn_event
from oidcendpoint.endpoint_context import EndpointContext
from oidcendpoint.id_token import IDToken
from oidcendpoint.oidc import userinfo
from oidcendpoint.oidc.authorization import Authorization
from oidcendpoint.oidc.provider_config import ProviderConfiguration
from oidcendpoint.oidc.registration import Registration
-from oidcendpoint.oidc.token import AccessToken
-from oidcendpoint.session import setup_session
+from oidcendpoint.oidc.token import Token
from oidcendpoint.user_authn.authn_context import INTERNETPROTOCOLPASSWORD
from oidcendpoint.user_info import UserInfo
@@ -35,7 +36,7 @@
]
CAPABILITIES = {
- "subject_types_supported": ["public", "pairwise"],
+ "subject_types_supported": ["public", "pairwise", "ephemeral"],
"grant_types_supported": [
"authorization_code",
"implicit",
@@ -103,7 +104,7 @@ def create_endpoint(self):
},
"token": {
"path": "token",
- "class": AccessToken,
+ "class": Token,
"kwargs": {
"client_authn_methods": [
"client_secret_post",
@@ -165,133 +166,123 @@ def create_endpoint(self):
"response_types": ["code", "token", "code id_token", "id_token"],
}
self.endpoint = endpoint_context.endpoint["userinfo"]
+ self.session_manager = endpoint_context.session_manager
+ self.user_id = "diana"
+
+ def _create_session(self, auth_req, sub_type="public", sector_identifier=''):
+ if sector_identifier:
+ authz_req = auth_req.copy()
+ authz_req["sector_identifier_uri"] = sector_identifier
+ else:
+ authz_req = auth_req
+ client_id = authz_req['client_id']
+ ae = create_authn_event(self.user_id)
+ return self.session_manager.create_session(ae, authz_req, self.user_id,
+ client_id=client_id,
+ sub_type=sub_type)
+
+ def _mint_code(self, grant, session_id):
+ # Constructing an authorization code is now done
+ return grant.mint_token(
+ session_id=session_id,
+ endpoint_context=self.endpoint.endpoint_context,
+ token_type='authorization_code',
+ token_handler=self.session_manager.token_handler["code"],
+ expires_at=time_sans_frac() + 300 # 5 minutes from now
+ )
+
+ def _mint_token(self, token_type, grant, session_id, token_ref=None):
+ _session_info = self.session_manager.get_session_info(session_id, grant=True)
+ return grant.mint_token(
+ session_id=session_id,
+ endpoint_context=self.endpoint.endpoint_context,
+ token_type=token_type,
+ token_handler=self.session_manager.token_handler[token_type],
+ expires_at=time_sans_frac() + 900, # 15 minutes from now
+ based_on=token_ref # Means the token (tok) was used to mint this token
+ )
def test_init(self):
assert self.endpoint
assert set(
self.endpoint.endpoint_context.provider_info["claims_supported"]
) == {
- "address",
- "birthdate",
- "email",
- "email_verified",
- "eduperson_scoped_affiliation",
- "family_name",
- "gender",
- "given_name",
- "locale",
- "middle_name",
- "name",
- "nickname",
- "phone_number",
- "phone_number_verified",
- "picture",
- "preferred_username",
- "profile",
- "sub",
- "updated_at",
- "website",
- "zoneinfo",
- }
+ "address",
+ "birthdate",
+ "email",
+ "email_verified",
+ "eduperson_scoped_affiliation",
+ "family_name",
+ "gender",
+ "given_name",
+ "locale",
+ "middle_name",
+ "name",
+ "nickname",
+ "phone_number",
+ "phone_number_verified",
+ "picture",
+ "preferred_username",
+ "profile",
+ "sub",
+ "updated_at",
+ "website",
+ "zoneinfo",
+ }
def test_parse(self):
- session_id = setup_session(
- self.endpoint.endpoint_context,
- AUTH_REQ,
- uid="userID",
- authn_event={
- "authn_info": "loa1",
- "uid": "diana",
- "authn_time": utc_time_sans_frac(),
- "valid_until": utc_time_sans_frac() + 3600,
- },
- )
- _dic = self.endpoint.endpoint_context.sdb.upgrade_to_token(key=session_id)
- _req = self.endpoint.parse_request(
- {}, auth="Bearer {}".format(_dic["access_token"])
- )
+ session_id = self._create_session(AUTH_REQ)
+ grant = self.session_manager[session_id]
+ # Free standing access token, not based on an authorization code
+ access_token = self._mint_token("access_token", grant, session_id)
+ _req = self.endpoint.parse_request({}, auth="Bearer {}".format(access_token.value))
assert set(_req.keys()) == {"client_id", "access_token"}
+ assert _req["client_id"] == AUTH_REQ['client_id']
+ assert _req["access_token"] == access_token.value
def test_parse_invalid_token(self):
_req = self.endpoint.parse_request({}, auth="Bearer invalid")
-
assert _req['error'] == "invalid_token"
def test_process_request(self):
- session_id = setup_session(
- self.endpoint.endpoint_context,
- AUTH_REQ,
- uid="userID",
- authn_event={
- "authn_info": "loa1",
- "uid": "diana",
- "authn_time": utc_time_sans_frac(),
- "valid_until": utc_time_sans_frac() + 3600,
- },
- )
- _dic = self.endpoint.endpoint_context.sdb.upgrade_to_token(key=session_id)
+ session_id = self._create_session(AUTH_REQ)
+ grant = self.session_manager[session_id]
+ code = self._mint_code(grant, session_id)
+ access_token = self._mint_token("access_token", grant, session_id, code)
+
_req = self.endpoint.parse_request(
- {}, auth="Bearer {}".format(_dic["access_token"])
+ {}, auth="Bearer {}".format(access_token.value)
)
args = self.endpoint.process_request(_req)
assert args
def test_process_request_not_allowed(self):
- session_id = setup_session(
- self.endpoint.endpoint_context,
- AUTH_REQ,
- uid="userID",
- authn_event={
- "authn_info": "loa1",
- "uid": "diana",
- "authn_time": utc_time_sans_frac() - 7200,
- "valid_until": utc_time_sans_frac() - 3600,
- },
- )
- _dic = self.endpoint.endpoint_context.sdb.upgrade_to_token(key=session_id)
- _req = self.endpoint.parse_request(
- {}, auth="Bearer {}".format(_dic["access_token"])
- )
- args = self.endpoint.process_request(_req)
- assert set(args["response_args"].keys()) == {"error", "error_description"}
+ session_id = self._create_session(AUTH_REQ)
+ grant = self.session_manager[session_id]
+ code = self._mint_code(grant, session_id)
+ access_token = self._mint_token("access_token", grant, session_id, code)
+
+ # 2 things can make the request invalid.
+ # 1) The token is not valid anymore or 2) The event is not valid.
+ _event = grant.authentication_event
+ _event['authn_time'] -= 9000
+ _event['valid_until'] -= 9000
- def test_process_request_offline_access(self):
- auth_req = AUTH_REQ.copy()
- auth_req["scope"] = ["openid", "offline_access"]
- session_id = setup_session(
- self.endpoint.endpoint_context,
- auth_req,
- uid="userID",
- authn_event={
- "authn_info": "loa1",
- "uid": "diana",
- "authn_time": utc_time_sans_frac() ,
- "valid_until": utc_time_sans_frac() + 3600,
- },
- )
- _dic = self.endpoint.endpoint_context.sdb.upgrade_to_token(key=session_id)
_req = self.endpoint.parse_request(
- {}, auth="Bearer {}".format(_dic["access_token"])
+ {}, auth="Bearer {}".format(access_token.value)
)
args = self.endpoint.process_request(_req)
- assert set(args["response_args"].keys()) == {"sub"}
+ assert set(args["response_args"].keys()) == {"error", "error_description"}
def test_do_response(self):
- session_id = setup_session(
- self.endpoint.endpoint_context,
- AUTH_REQ,
- uid="userID",
- authn_event={
- "authn_info": "loa1",
- "uid": "diana",
- "authn_time": utc_time_sans_frac(),
- "valid_until": utc_time_sans_frac() + 3600,
- },
- )
- _dic = self.endpoint.endpoint_context.sdb.upgrade_to_token(key=session_id)
+ session_id = self._create_session(AUTH_REQ)
+ grant = self.session_manager[session_id]
+ code = self._mint_code(grant,session_id)
+ access_token = self._mint_token("access_token", grant, session_id, code)
+
_req = self.endpoint.parse_request(
- {}, auth="Bearer {}".format(_dic["access_token"])
+ {}, auth="Bearer {}".format(access_token.value)
)
args = self.endpoint.process_request(_req)
assert args
@@ -299,24 +290,15 @@ def test_do_response(self):
assert res
def test_do_signed_response(self):
- self.endpoint.endpoint_context.cdb["client_1"][
- "userinfo_signed_response_alg"
- ] = "ES256"
-
- session_id = setup_session(
- self.endpoint.endpoint_context,
- AUTH_REQ,
- uid="userID",
- authn_event={
- "authn_info": "loa1",
- "uid": "diana",
- "authn_time": utc_time_sans_frac(),
- "valid_until": utc_time_sans_frac() + 3600,
- },
- )
- _dic = self.endpoint.endpoint_context.sdb.upgrade_to_token(key=session_id)
+ self.endpoint.endpoint_context.cdb["client_1"]["userinfo_signed_response_alg"] = "ES256"
+
+ session_id = self._create_session(AUTH_REQ)
+ grant = self.session_manager[session_id]
+ code = self._mint_code(grant, session_id)
+ access_token = self._mint_token("access_token", grant, session_id, code)
+
_req = self.endpoint.parse_request(
- {}, auth="Bearer {}".format(_dic["access_token"])
+ {}, auth="Bearer {}".format(access_token.value)
)
args = self.endpoint.process_request(_req)
assert args
@@ -326,28 +308,55 @@ def test_do_signed_response(self):
def test_custom_scope(self):
_auth_req = AUTH_REQ.copy()
_auth_req["scope"] = ["openid", "research_and_scholarship"]
- session_id = setup_session(
- self.endpoint.endpoint_context,
- _auth_req,
- uid="userID",
- authn_event={
- "authn_info": "loa1",
- "uid": "diana",
- "authn_time": utc_time_sans_frac(),
- "valid_until": utc_time_sans_frac() + 3600,
- },
+
+ session_id = self._create_session(_auth_req)
+ grant = self.session_manager[session_id]
+ access_token = self._mint_token("access_token", grant, session_id)
+
+ self.endpoint.kwargs["add_claims_by_scope"] = True
+ self.endpoint.endpoint_context.claims_interface.add_claims_by_scope = True
+ grant.claims = {
+ "userinfo": self.endpoint.endpoint_context.claims_interface.get_claims(
+ session_id=session_id, scopes=_auth_req["scope"], usage="userinfo")
+ }
+
+ _req = self.endpoint.parse_request(
+ {}, auth="Bearer {}".format(access_token.value)
)
- _dic = self.endpoint.endpoint_context.sdb.upgrade_to_token(key=session_id)
+ args = self.endpoint.process_request(_req)
+ assert set(args["response_args"].keys()) == {'eduperson_scoped_affiliation', 'given_name',
+ 'email_verified', 'email', 'family_name',
+ 'name', "sub"}
+
+ def test_wrong_type_of_token(self):
+ _auth_req = AUTH_REQ.copy()
+ _auth_req["scope"] = ["openid", "research_and_scholarship"]
+
+ session_id = self._create_session(_auth_req)
+ grant = self.session_manager[session_id]
+ refresh_token = self._mint_token("refresh_token", grant, session_id)
+
_req = self.endpoint.parse_request(
- {}, auth="Bearer {}".format(_dic["access_token"])
+ {}, auth="Bearer {}".format(refresh_token.value)
)
args = self.endpoint.process_request(_req)
- assert set(args["response_args"].keys()) == {
- "sub",
- "name",
- "given_name",
- "family_name",
- "email",
- "email_verified",
- "eduperson_scoped_affiliation",
- }
+
+ assert isinstance(args, ResponseMessage)
+ assert args['error_description'] == "Wrong type of token"
+
+ def test_invalid_token(self):
+ _auth_req = AUTH_REQ.copy()
+ _auth_req["scope"] = ["openid", "research_and_scholarship"]
+
+ session_id = self._create_session(_auth_req)
+ grant = self.session_manager[session_id]
+ access_token = self._mint_token("access_token", grant, session_id)
+
+ _req = self.endpoint.parse_request(
+ {}, auth="Bearer {}".format(access_token.value)
+ )
+ access_token.expires_at = time_sans_frac() - 10
+ args = self.endpoint.process_request(_req)
+
+ assert isinstance(args, ResponseMessage)
+ assert args['error_description'] == "Invalid Token"
diff --git a/tests/test_27_jwt_token.py b/tests/test_27_jwt_token.py
index da0336d..13ef7f0 100644
--- a/tests/test_27_jwt_token.py
+++ b/tests/test_27_jwt_token.py
@@ -2,21 +2,24 @@
import pytest
from cryptojwt.jwt import JWT
-from cryptojwt.jwt import utc_time_sans_frac
from cryptojwt.key_jar import init_key_jar
from oidcmsg.oidc import AccessTokenRequest
from oidcmsg.oidc import AuthorizationRequest
+from oidcmsg.time_util import time_sans_frac
from oidcendpoint import user_info
+from oidcendpoint.authn_event import create_authn_event
+from oidcendpoint.authz import AuthzHandling
from oidcendpoint.client_authn import verify_client
from oidcendpoint.endpoint_context import EndpointContext
from oidcendpoint.id_token import IDToken
+from oidcendpoint.oauth2.introspection import Introspection
from oidcendpoint.oidc.authorization import Authorization
from oidcendpoint.oidc.provider_config import ProviderConfiguration
from oidcendpoint.oidc.registration import Registration
from oidcendpoint.oidc.session import Session
-from oidcendpoint.oidc.token import AccessToken
-from oidcendpoint.session import setup_session
+from oidcendpoint.oidc.token import Token
+from oidcendpoint.session import session_key
from oidcendpoint.user_authn.authn_context import INTERNETPROTOCOLPASSWORD
KEYDEFS = [
@@ -49,7 +52,7 @@
"private_key_jwt",
],
"response_modes_supported": ["query", "fragment", "form_post"],
- "subject_types_supported": ["public", "pairwise"],
+ "subject_types_supported": ["public", "pairwise", "ephemeral"],
"grant_types_supported": [
"authorization_code",
"implicit",
@@ -81,6 +84,12 @@
BASEDIR = os.path.abspath(os.path.dirname(__file__))
+MAP = {
+ "authorization_code": "code",
+ "access_token": "access_token",
+ "refresh_token": "refresh_token"
+}
+
def full_path(local_file):
return os.path.join(BASEDIR, local_file)
@@ -92,9 +101,6 @@ def create_endpoint(self):
conf = {
"issuer": ISSUER,
"password": "mycket hemligt",
- "token_expires_in": 600,
- "grant_expires_in": 300,
- "refresh_token_expires_in": 86400,
"verify_ssl": False,
"capabilities": CAPABILITIES,
"keys": {"uri_path": "jwks.json", "key_defs": KEYDEFS},
@@ -108,16 +114,18 @@ def create_endpoint(self):
},
"code": {"lifetime": 600},
"token": {
- "class": "oidcendpoint.jwt_token.JWTToken",
+ "class": "oidcendpoint.token.jwt_token.JWTToken",
+ "kwargs": {
+ "lifetime": 3600,
+ "base_claims": {"eduperson_scoped_affiliation": None},
+ "add_claims_by_scope": True,
+ "aud": ["https://example.org/appl"],
+ },
+ },
+ "refresh": {
+ "class": "oidcendpoint.token.jwt_token.JWTToken",
"kwargs": {
"lifetime": 3600,
- "add_claims": [
- "email",
- "email_verified",
- "phone_number",
- "phone_number_verified",
- ],
- "add_claim_by_scope": True,
"aud": ["https://example.org/appl"],
},
},
@@ -138,8 +146,12 @@ def create_endpoint(self):
"class": Authorization,
"kwargs": {},
},
- "token": {"path": "{}/token", "class": AccessToken, "kwargs": {}},
+ "token": {"path": "{}/token", "class": Token, "kwargs": {}},
"session": {"path": "{}/end_session", "class": Session},
+ "introspection": {
+ "path": "{}/introspection",
+ "class": Introspection
+ }
},
"client_authn": verify_client,
"authentication": {
@@ -155,81 +167,114 @@ def create_endpoint(self):
"kwargs": {"db_file": full_path("users.json")},
},
"id_token": {"class": IDToken},
+ "authz": {
+ "class": AuthzHandling,
+ "kwargs": {
+ "grant_config": {
+ "usage_rules": {
+ "authorization_code": {
+ 'supports_minting': ["access_token", "refresh_token", "id_token"],
+ "max_usage": 1
+ },
+ "access_token": {},
+ "refresh_token": {
+ 'supports_minting': ["access_token", "refresh_token"],
+ }
+ },
+ "expires_in": 43200
+ }
+ }
+ },
}
- endpoint_context = EndpointContext(conf, keyjar=KEYJAR)
- endpoint_context.cdb["client_1"] = {
+ self.endpoint_context = EndpointContext(conf, keyjar=KEYJAR)
+ self.endpoint_context.cdb["client_1"] = {
"client_secret": "hemligt",
"redirect_uris": [("https://example.com/cb", None)],
"client_salt": "salted",
"token_endpoint_auth_method": "client_secret_post",
"response_types": ["code", "token", "code id_token", "id_token"],
}
- self.endpoint = Session(endpoint_context)
+ self.session_manager = self.endpoint_context.session_manager
+ self.user_id = "diana"
+ self.endpoint = self.endpoint_context.endpoint["session"]
- def test_parse(self):
- session_id = setup_session(
- self.endpoint.endpoint_context, AUTH_REQ, uid="diana"
+ def _create_session(self, auth_req, sub_type="public", sector_identifier=''):
+ if sector_identifier:
+ authz_req = auth_req.copy()
+ authz_req["sector_identifier_uri"] = sector_identifier
+ else:
+ authz_req = auth_req
+ client_id = authz_req['client_id']
+ ae = create_authn_event(self.user_id)
+ return self.session_manager.create_session(ae, authz_req, self.user_id,
+ client_id=client_id,
+ sub_type=sub_type)
+
+ def _mint_token(self, type, grant, session_id, based_on=None, **kwargs):
+ # Constructing an authorization code is now done
+ return grant.mint_token(
+ session_id=session_id,
+ endpoint_context=self.endpoint_context,
+ token_type=type,
+ token_handler=self.session_manager.token_handler.handler[MAP[type]],
+ expires_at=time_sans_frac() + 300, # 5 minutes from now
+ based_on=based_on,
+ **kwargs
)
- _dic = self.endpoint.endpoint_context.sdb.upgrade_to_token(key=session_id)
+
+ def test_parse(self):
+ session_id = self._create_session(AUTH_REQ)
+ # apply consent
+ grant = self.endpoint_context.authz(session_id=session_id, request=AUTH_REQ)
+ # grant = self.session_manager[session_id]
+ code = self._mint_token("authorization_code", grant, session_id)
+ access_token = self._mint_token("access_token", grant, session_id, code,
+ resources=[AUTH_REQ["client_id"]])
_verifier = JWT(self.endpoint.endpoint_context.keyjar)
- _info = _verifier.unpack(_dic["access_token"])
+ _info = _verifier.unpack(access_token.value)
assert _info["ttype"] == "T"
- assert _info["phone_number"] == "+46907865000"
- assert set(_info["aud"]) == {"client_1", "https://example.org/appl"}
+ # assert _info["eduperson_scoped_affiliation"] == ["staff@example.org"]
+ assert set(_info["aud"]) == {"client_1"}
def test_info(self):
- session_id = setup_session(
- self.endpoint.endpoint_context, AUTH_REQ, uid="diana"
- )
- _dic = self.endpoint.endpoint_context.sdb.upgrade_to_token(key=session_id)
+ session_id = self._create_session(AUTH_REQ)
+ # apply consent
+ grant = self.endpoint_context.authz(session_id=session_id, request=AUTH_REQ)
+ #
+ code = self._mint_token("authorization_code", grant, session_id)
+ access_token = self._mint_token("access_token", grant, session_id, code)
- handler = self.endpoint.endpoint_context.sdb.handler.handler["access_token"]
- _info = handler.info(_dic["access_token"])
+ _info = self.session_manager.token_handler.info(access_token.value)
assert _info["type"] == "T"
assert _info["sid"] == session_id
@pytest.mark.parametrize("enable_claims_per_client", [True, False])
- def test_client_claims(self, enable_claims_per_client):
- ec = self.endpoint.endpoint_context
- handler = ec.sdb.handler.handler["access_token"]
- session_id = setup_session(ec, AUTH_REQ, uid="diana")
- ec.cdb["client_1"]["access_token_claims"] = {"address": None}
- handler.enable_claims_per_client = enable_claims_per_client
- _dic = ec.sdb.upgrade_to_token(key=session_id)
-
- token = _dic["access_token"]
- _jwt = JWT(key_jar=KEYJAR, iss="client_1")
- res = _jwt.unpack(token)
- assert enable_claims_per_client is ("address" in res)
+ def test_enable_claims_per_client(self, enable_claims_per_client):
+ # Set up configuration
+ self.endpoint.endpoint_context.cdb["client_1"]["access_token_claims"] = {"address": None}
+ self.endpoint_context.session_manager.token_handler.handler["access_token"].kwargs[
+ "enable_claims_per_client"] = enable_claims_per_client
+
+ session_id = self._create_session(AUTH_REQ)
+ # apply consent
+ grant = self.endpoint_context.authz(session_id=session_id, request=AUTH_REQ)
+ #
+ code = self._mint_token("authorization_code", grant, session_id)
+ access_token = self._mint_token("access_token", grant, session_id, code)
- @pytest.mark.parametrize("add_scope", [True, False])
- def test_add_scopes(self, add_scope):
- ec = self.endpoint.endpoint_context
- handler = ec.sdb.handler.handler["access_token"]
- auth_req = dict(AUTH_REQ)
- auth_req["scope"] = ["openid", "profile", "aba"]
- session_id = setup_session(ec, auth_req, uid="diana")
- handler.add_scope = add_scope
- _dic = ec.sdb.upgrade_to_token(key=session_id)
-
- token = _dic["access_token"]
_jwt = JWT(key_jar=KEYJAR, iss="client_1")
- res = _jwt.unpack(token)
- assert add_scope is (res.get("scope") == ["openid", "profile"])
+ res = _jwt.unpack(access_token.value)
+ assert enable_claims_per_client is ("address" in res)
def test_is_expired(self):
- session_id = setup_session(
- self.endpoint.endpoint_context, AUTH_REQ, uid="diana"
- )
- _dic = self.endpoint.endpoint_context.sdb.upgrade_to_token(key=session_id)
+ session_id = self._create_session(AUTH_REQ)
+ grant = self.session_manager[session_id]
+ code = self._mint_token("authorization_code", grant, session_id)
+ access_token = self._mint_token("access_token", grant, session_id, code)
- handler = self.endpoint.endpoint_context.sdb.handler.handler["access_token"]
- assert handler.is_expired(_dic["access_token"]) is False
-
- assert (
- handler.is_expired(_dic["access_token"], utc_time_sans_frac() + 4000)
- is True
- )
+ assert access_token.is_active()
+ # 4000 seconds in the future. Passed the lifetime.
+ assert access_token.is_active(now=time_sans_frac() + 4000) is False
diff --git a/tests/test_28_oidc_refresh_token_endpoint.py b/tests/test_28_oidc_refresh_token_endpoint.py
deleted file mode 100755
index 04a5f7e..0000000
--- a/tests/test_28_oidc_refresh_token_endpoint.py
+++ /dev/null
@@ -1,238 +0,0 @@
-import json
-import os
-
-import pytest
-from oidcmsg.oidc import AccessTokenRequest
-from oidcmsg.oidc import AuthorizationRequest
-from oidcmsg.oidc import RefreshAccessTokenRequest
-
-from oidcendpoint.client_authn import ClientSecretBasic
-from oidcendpoint.client_authn import ClientSecretJWT
-from oidcendpoint.client_authn import ClientSecretPost
-from oidcendpoint.client_authn import PrivateKeyJWT
-from oidcendpoint.client_authn import verify_client
-from oidcendpoint.endpoint_context import EndpointContext
-from oidcendpoint.oidc import userinfo
-from oidcendpoint.oidc.authorization import Authorization
-from oidcendpoint.oidc.provider_config import ProviderConfiguration
-from oidcendpoint.oidc.refresh_token import RefreshAccessToken
-from oidcendpoint.oidc.registration import Registration
-from oidcendpoint.oidc.token import AccessToken
-from oidcendpoint.session import setup_session
-from oidcendpoint.user_authn.authn_context import INTERNETPROTOCOLPASSWORD
-from oidcendpoint.user_info import UserInfo
-
-KEYDEFS = [
- {"type": "RSA", "key": "", "use": ["sig"]},
- {"type": "EC", "crv": "P-256", "use": ["sig"]},
-]
-
-RESPONSE_TYPES_SUPPORTED = [
- ["code"],
- ["token"],
- ["id_token"],
- ["code", "token"],
- ["code", "id_token"],
- ["id_token", "token"],
- ["code", "token", "id_token"],
- ["none"],
-]
-
-CAPABILITIES = {
- "response_types_supported": [" ".join(x) for x in RESPONSE_TYPES_SUPPORTED],
- "token_endpoint_auth_methods_supported": [
- "client_secret_post",
- "client_secret_basic",
- "client_secret_jwt",
- "private_key_jwt",
- ],
- "response_modes_supported": ["query", "fragment", "form_post"],
- "subject_types_supported": ["public", "pairwise"],
- "grant_types_supported": [
- "authorization_code",
- "implicit",
- "urn:ietf:params:oauth:grant-type:jwt-bearer",
- "refresh_token",
- ],
- "claim_types_supported": ["normal", "aggregated", "distributed"],
- "claims_parameter_supported": True,
- "request_parameter_supported": True,
- "request_uri_parameter_supported": True,
-}
-
-AUTH_REQ = AuthorizationRequest(
- client_id="client_1",
- redirect_uri="https://example.com/cb",
- scope=["openid"],
- state="STATE",
- response_type="code",
-)
-
-TOKEN_REQ = AccessTokenRequest(
- client_id="client_1",
- redirect_uri="https://example.com/cb",
- state="STATE",
- grant_type="authorization_code",
- client_secret="hemligt",
-)
-
-REFRESH_TOKEN_REQ = RefreshAccessTokenRequest(
- grant_type="refresh_token", client_id="client_1", client_secret="hemligt"
-)
-
-TOKEN_REQ_DICT = TOKEN_REQ.to_dict()
-
-BASEDIR = os.path.abspath(os.path.dirname(__file__))
-
-
-def full_path(local_file):
- return os.path.join(BASEDIR, local_file)
-
-
-USERINFO = UserInfo(json.loads(open(full_path("users.json")).read()))
-
-
-class TestEndpoint(object):
- @pytest.fixture(autouse=True)
- def create_endpoint(self):
- conf = {
- "issuer": "https://example.com/",
- "password": "mycket hemligt",
- "token_expires_in": 600,
- "grant_expires_in": 300,
- "refresh_token_expires_in": 86400,
- "verify_ssl": False,
- "capabilities": CAPABILITIES,
- "keys": {"uri_path": "jwks.json", "key_defs": KEYDEFS},
- "endpoint": {
- "provider_config": {
- "path": "{}/.well-known/openid-configuration",
- "class": ProviderConfiguration,
- "kwargs": {},
- },
- "registration": {
- "path": "{}/registration",
- "class": Registration,
- "kwargs": {},
- },
- "authorization": {
- "path": "{}/authorization",
- "class": Authorization,
- "kwargs": {},
- },
- "token": {
- "path": "{}/token",
- "class": AccessToken,
- "kwargs": {
- "client_authn_method": {
- "client_secret_basic": ClientSecretBasic,
- "client_secret_post": ClientSecretPost,
- "client_secret_jwt": ClientSecretJWT,
- "private_key_jwt": PrivateKeyJWT,
- }
- },
- },
- "refresh_token": {
- "path": "{}/token",
- "class": RefreshAccessToken,
- "kwargs": {
- "client_authn_method": {
- "client_secret_basic": ClientSecretBasic,
- "client_secret_post": ClientSecretPost,
- "client_secret_jwt": ClientSecretJWT,
- "private_key_jwt": PrivateKeyJWT,
- }
- },
- },
- "userinfo": {
- "path": "{}/userinfo",
- "class": userinfo.UserInfo,
- "kwargs": {"db_file": "users.json"},
- },
- },
- "authentication": {
- "anon": {
- "acr": INTERNETPROTOCOLPASSWORD,
- "class": "oidcendpoint.user_authn.user.NoAuthn",
- "kwargs": {"user": "diana"},
- }
- },
- "userinfo": {"class": UserInfo, "kwargs": {"db": {}}},
- "client_authn": verify_client,
- "template_dir": "template",
- }
- endpoint_context = EndpointContext(conf)
- endpoint_context.cdb["client_1"] = {
- "client_secret": "hemligt",
- "redirect_uris": [("https://example.com/cb", None)],
- "client_salt": "salted",
- "token_endpoint_auth_method": "client_secret_post",
- "response_types": ["code", "token", "code id_token", "id_token"],
- }
- self.token_endpoint = endpoint_context.endpoint["token"]
- self.refresh_token_endpoint = endpoint_context.endpoint["refresh_token"]
-
- def test_init(self):
- assert self.refresh_token_endpoint
-
- def test_do_refresh_access_token(self):
- areq = AUTH_REQ.copy()
- areq["scope"] = ["openid", "offline_access"]
- _cntx = self.token_endpoint.endpoint_context
- session_id = setup_session(
- _cntx, areq, uid="user", acr=INTERNETPROTOCOLPASSWORD
- )
- _cntx.sdb.update(session_id, user="diana")
- _token_request = TOKEN_REQ_DICT.copy()
- _token_request["code"] = _cntx.sdb[session_id]["code"]
- _req = self.token_endpoint.parse_request(_token_request)
- _resp = self.token_endpoint.process_request(request=_req)
-
- _request = REFRESH_TOKEN_REQ.copy()
- _request["refresh_token"] = _resp["response_args"]["refresh_token"]
- _req = self.refresh_token_endpoint.parse_request(_request.to_json())
- _resp = self.refresh_token_endpoint.process_request(request=_req)
- assert set(_resp.keys()) == {"response_args", "http_headers"}
- assert set(_resp["response_args"].keys()) == {
- "access_token",
- "token_type",
- "expires_in",
- "refresh_token",
- "id_token",
- }
- msg = self.refresh_token_endpoint.do_response(request=_req, **_resp)
- assert isinstance(msg, dict)
-
- def test_do_2nd_refresh_access_token(self):
- areq = AUTH_REQ.copy()
- areq["scope"] = ["openid", "offline_access"]
- _cntx = self.token_endpoint.endpoint_context
- session_id = setup_session(
- _cntx, areq, uid="user", acr=INTERNETPROTOCOLPASSWORD
- )
- _cntx.sdb.update(session_id, user="diana")
- _token_request = TOKEN_REQ_DICT.copy()
- _token_request["code"] = _cntx.sdb[session_id]["code"]
- _req = self.token_endpoint.parse_request(_token_request)
- _resp = self.token_endpoint.process_request(request=_req)
-
- _request = REFRESH_TOKEN_REQ.copy()
- _request["refresh_token"] = _resp["response_args"]["refresh_token"]
- _req = self.refresh_token_endpoint.parse_request(_request.to_json())
- _resp = self.refresh_token_endpoint.process_request(request=_req)
-
- _request = REFRESH_TOKEN_REQ.copy()
- _request["refresh_token"] = _resp["response_args"]["refresh_token"]
- _req = self.refresh_token_endpoint.parse_request(_request.to_json())
- _resp = self.refresh_token_endpoint.process_request(request=_req)
-
- assert set(_resp.keys()) == {"response_args", "http_headers"}
- assert set(_resp["response_args"].keys()) == {
- "access_token",
- "token_type",
- "expires_in",
- "refresh_token",
- "id_token",
- }
- msg = self.refresh_token_endpoint.do_response(request=_req, **_resp)
- assert isinstance(msg, dict)
diff --git a/tests/test_30_oidc_end_session.py b/tests/test_30_oidc_end_session.py
index 57125de..04c29dd 100644
--- a/tests/test_30_oidc_end_session.py
+++ b/tests/test_30_oidc_end_session.py
@@ -4,28 +4,35 @@
from urllib.parse import parse_qs
from urllib.parse import urlparse
-from oidcendpoint.token_handler import UnknownToken
import pytest
import responses
+from cryptojwt import as_unicode
+from cryptojwt import b64d
from cryptojwt.key_jar import build_keyjar
+from cryptojwt.utils import as_bytes
from oidcmsg.exception import InvalidRequest
from oidcmsg.message import Message
from oidcmsg.oidc import AuthorizationRequest
from oidcmsg.oidc import verified_claim_name
from oidcmsg.oidc import verify_id_token
+from oidcmsg.time_util import time_sans_frac
-from oidcendpoint.common.authorization import join_query
+from oidcendpoint.authn_event import create_authn_event
from oidcendpoint.cookie import CookieDealer
from oidcendpoint.cookie import new_cookie
from oidcendpoint.endpoint_context import EndpointContext
from oidcendpoint.exception import RedirectURIError
+from oidcendpoint.oauth2.authorization import join_query
from oidcendpoint.oidc import userinfo
from oidcendpoint.oidc.authorization import Authorization
from oidcendpoint.oidc.provider_config import ProviderConfiguration
from oidcendpoint.oidc.registration import Registration
from oidcendpoint.oidc.session import Session
from oidcendpoint.oidc.session import do_front_channel_logout_iframe
-from oidcendpoint.oidc.token import AccessToken
+from oidcendpoint.oidc.token import Token
+from oidcendpoint.session import session_key
+from oidcendpoint.session import unpack_session_key
+from oidcendpoint.session.grant import Grant
from oidcendpoint.user_authn.authn_context import INTERNETPROTOCOLPASSWORD
from oidcendpoint.user_info import UserInfo
@@ -62,7 +69,7 @@
"private_key_jwt",
],
"response_modes_supported": ["query", "fragment", "form_post"],
- "subject_types_supported": ["public", "pairwise"],
+ "subject_types_supported": ["public", "pairwise", "ephemeral"],
"grant_types_supported": [
"authorization_code",
"implicit",
@@ -124,7 +131,7 @@ def create_endpoint(self):
"class": Authorization,
"kwargs": {"client_authn_method": None},
},
- "token": {"path": "{}/token", "class": AccessToken, "kwargs": {}},
+ "token": {"path": "{}/token", "class": Token, "kwargs": {}},
"userinfo": {
"path": "{}/userinfo",
"class": userinfo.UserInfo,
@@ -185,24 +192,23 @@ def create_endpoint(self):
"post_logout_redirect_uris": [("{}logout_cb".format(CLI2), "")],
},
}
+ self.session_manager = endpoint_context.session_manager
self.authn_endpoint = endpoint_context.endpoint["authorization"]
self.session_endpoint = endpoint_context.endpoint["session"]
self.token_endpoint = endpoint_context.endpoint["token"]
+ self.user_id = "diana"
def test_end_session_endpoint(self):
# End session not allowed if no cookie and no id_token_hint is sent
# (can't determine session)
- with pytest.raises(UnknownToken):
+ with pytest.raises(ValueError):
_ = self.session_endpoint.process_request("", cookie="FAIL")
- def _create_cookie(self, user, sid, state, client_id):
+ def _create_cookie(self, session_id):
ec = self.session_endpoint.endpoint_context
return new_cookie(
ec,
- sub=user,
- sid=sid,
- state=state,
- client_id=client_id,
+ sid=session_id,
cookie_name=ec.cookie_name["session"],
)
@@ -215,7 +221,7 @@ def _code_auth(self, state):
client_id="client_1",
)
_pr_resp = self.authn_endpoint.parse_request(req.to_dict())
- _resp = self.authn_endpoint.process_request(_pr_resp)
+ return self.authn_endpoint.process_request(_pr_resp)
def _code_auth2(self, state):
req = AuthorizationRequest(
@@ -226,16 +232,7 @@ def _code_auth2(self, state):
client_id="client_2",
)
_pr_resp = self.authn_endpoint.parse_request(req.to_dict())
- _resp = self.authn_endpoint.process_request(_pr_resp)
-
- def _get_sid(self):
- _sdb = self.session_endpoint.endpoint_context.sdb
-
- for _sid in _sdb.keys():
- if _sid.startswith("__state__"):
- continue
- else:
- return _sid
+ return self.authn_endpoint.process_request(_pr_resp)
def _auth_with_id_token(self, state):
req = AuthorizationRequest(
@@ -249,12 +246,19 @@ def _auth_with_id_token(self, state):
_pr_resp = self.authn_endpoint.parse_request(req.to_dict())
_resp = self.authn_endpoint.process_request(_pr_resp)
- return _resp["response_args"]["id_token"]
+ part = self.session_endpoint.endpoint_context.cookie_dealer.get_cookie_value(
+ _resp["cookie"][0], cookie_name="oidcop"
+ )
+ # value is a base64 encoded JSON document
+ _cookie_info = json.loads(as_unicode(b64d(as_bytes(part[0]))))
+
+ return _resp["response_args"], _cookie_info["sid"]
def test_end_session_endpoint_with_cookie(self):
- self._code_auth("1234567")
- _sid = self._get_sid()
- cookie = self._create_cookie("diana", _sid, "1234567", "client_1")
+ _resp = self._code_auth("1234567")
+ _code = _resp["response_args"]["code"]
+ _session_info = self.session_manager.get_session_info_by_token(_code)
+ cookie = self._create_cookie(_session_info["session_id"])
_req_args = self.session_endpoint.parse_request({"state": "1234567"})
resp = self.session_endpoint.process_request(_req_args, cookie=cookie)
@@ -266,39 +270,27 @@ def test_end_session_endpoint_with_cookie(self):
qs = parse_qs(p.query)
jwt_info = self.session_endpoint.unpack_signed_jwt(qs["sjwt"][0])
- assert jwt_info["user"] == "diana"
- assert jwt_info["client_id"] == "client_1"
+ assert jwt_info["sid"] == _session_info["session_id"]
assert jwt_info["redirect_uri"] == "https://example.com/post_logout"
- def test_end_session_endpoint_with_wrong_cookie(self):
- self._code_auth("1234567")
- cookie = self._create_cookie("diana", "client_2", "abcdefg", "client_1")
-
- with pytest.raises(UnknownToken):
- self.session_endpoint.process_request({"state": "abcde"}, cookie=cookie)
-
- def test_end_session_endpoint_with_cookie_wrong_user(self):
+ def test_end_session_endpoint_with_cookie_and_unknown_sid(self):
# Need cookie and ID Token to figure this out
- id_token = self._auth_with_id_token("1234567")
+ resp_args, _session_id = self._auth_with_id_token("1234567")
+ id_token = resp_args["id_token"]
- cookie = self._create_cookie("diggins", "_sid_", "1234567", "client_1")
+ _uid, _cid, _gid = unpack_session_key(_session_id)
+ cookie = self._create_cookie(session_key(_uid, "client_66", _gid))
- msg = Message(id_token=id_token)
- verify_id_token(msg, keyjar=self.session_endpoint.endpoint_context.keyjar)
-
- msg2 = Message(id_token_hint=id_token)
- msg2[verified_claim_name("id_token_hint")] = msg[
- verified_claim_name("id_token")
- ]
with pytest.raises(ValueError):
- self.session_endpoint.process_request(msg2, cookie=cookie)
+ self.session_endpoint.process_request({"state": "foo"}, cookie=cookie)
- def test_end_session_endpoint_with_cookie_unknown_sid(self):
+ def test_end_session_endpoint_with_cookie_id_token_and_unknown_sid(self):
# Need cookie and ID Token to figure this out
- id_token = self._auth_with_id_token("1234567")
+ resp_args, _session_id = self._auth_with_id_token("1234567")
+ id_token = resp_args["id_token"]
- # Wrong client_id
- cookie = self._create_cookie("diana", "_sid_", "state", "client_1")
+ _uid, _cid, _gid = unpack_session_key(_session_id)
+ cookie = self._create_cookie(session_key(_uid, "client_66", _gid))
msg = Message(id_token=id_token)
verify_id_token(msg, keyjar=self.session_endpoint.endpoint_context.keyjar)
@@ -311,11 +303,11 @@ def test_end_session_endpoint_with_cookie_unknown_sid(self):
self.session_endpoint.process_request(msg2, cookie=cookie)
def test_end_session_endpoint_with_cookie_dual_login(self):
- self._code_auth("1234567")
+ _resp = self._code_auth("1234567")
self._code_auth2("abcdefg")
- _sdb = self.session_endpoint.endpoint_context.sdb
- _sid = self._get_sid()
- cookie = self._create_cookie("diana", _sid, "1234567", "client_1")
+ _code = _resp["response_args"]["code"]
+ _session_info = self.session_manager.get_session_info_by_token(_code)
+ cookie = self._create_cookie(_session_info["session_id"])
resp = self.session_endpoint.process_request({"state": "abcde"}, cookie=cookie)
@@ -326,16 +318,15 @@ def test_end_session_endpoint_with_cookie_dual_login(self):
qs = parse_qs(p.query)
jwt_info = self.session_endpoint.unpack_signed_jwt(qs["sjwt"][0])
- assert jwt_info["user"] == "diana"
- assert jwt_info["client_id"] == "client_1"
+ assert jwt_info["sid"] == _session_info["session_id"]
assert jwt_info["redirect_uri"] == "https://example.com/post_logout"
def test_end_session_endpoint_with_post_logout_redirect_uri(self):
- self._code_auth("1234567")
+ _resp = self._code_auth("1234567")
self._code_auth2("abcdefg")
- _sdb = self.session_endpoint.endpoint_context.sdb
- _sid = self._get_sid()
- cookie = self._create_cookie("diana", _sid, "1234567", "client_1")
+ _code = _resp["response_args"]["code"]
+ _session_info = self.session_manager.get_session_info_by_token(_code)
+ cookie = self._create_cookie(_session_info["session_id"])
post_logout_redirect_uri = join_query(
*self.session_endpoint.endpoint_context.cdb["client_1"][
@@ -353,14 +344,13 @@ def test_end_session_endpoint_with_post_logout_redirect_uri(self):
)
def test_end_session_endpoint_with_wrong_post_logout_redirect_uri(self):
- self._code_auth("1234567")
+ _resp = self._code_auth("1234567")
self._code_auth2("abcdefg")
- id_token = self._auth_with_id_token("1234567")
+ resp_args, _session_id = self._auth_with_id_token("1234567")
+ id_token = resp_args["id_token"]
- _sdb = self.session_endpoint.endpoint_context.sdb
- _sid = self._get_sid()
- cookie = self._create_cookie("diana", _sid, "1234567", "client_1")
+ cookie = self._create_cookie(_session_id)
post_logout_redirect_uri = "https://demo.example.com/log_out"
@@ -384,7 +374,7 @@ def test_back_channel_logout_no_uri(self):
self._code_auth("1234567")
res = self.session_endpoint.do_back_channel_logout(
- self.session_endpoint.endpoint_context.cdb["client_1"], "username", 0
+ self.session_endpoint.endpoint_context.cdb["client_1"], 0
)
assert res is None
@@ -394,15 +384,14 @@ def test_back_channel_logout(self):
_cdb = copy.copy(self.session_endpoint.endpoint_context.cdb["client_1"])
_cdb["backchannel_logout_uri"] = "https://example.com/bc_logout"
_cdb["client_id"] = "client_1"
- res = self.session_endpoint.do_back_channel_logout(_cdb, "username", "_sid_")
+ res = self.session_endpoint.do_back_channel_logout(_cdb, "_sid_")
assert isinstance(res, tuple)
assert res[0] == "https://example.com/bc_logout"
_jwt = self.session_endpoint.unpack_signed_jwt(res[1], "RS256")
assert _jwt
assert _jwt["iss"] == ISS
assert _jwt["aud"] == ["client_1"]
- assert _jwt["sub"] == "username"
- assert _jwt["sid"] == "_sid_"
+ assert "sid" in _jwt
def test_front_channel_logout(self):
self._code_auth("1234567")
@@ -448,13 +437,17 @@ def test_front_channel_logout_with_query(self):
assert i in res
def test_logout_from_client_bc(self):
- self._code_auth("1234567")
+ _resp = self._code_auth("1234567")
+ _code = _resp["response_args"]["code"]
+ _session_info = self.session_manager.get_session_info_by_token(
+ _code, client_session_info=True)
+
self.session_endpoint.endpoint_context.cdb["client_1"][
"backchannel_logout_uri"
] = "https://example.com/bc_logout"
self.session_endpoint.endpoint_context.cdb["client_1"]["client_id"] = "client_1"
- _sid = self._get_sid()
- res = self.session_endpoint.logout_from_client(_sid, "client_1")
+
+ res = self.session_endpoint.logout_from_client(_session_info["session_id"])
assert set(res.keys()) == {"blu"}
assert set(res["blu"].keys()) == {"client_1"}
_spec = res["blu"]["client_1"]
@@ -463,30 +456,37 @@ def test_logout_from_client_bc(self):
assert _jwt
assert _jwt["iss"] == ISS
assert _jwt["aud"] == ["client_1"]
- assert _jwt["sid"] == _sid
+ assert "sid" in _jwt # This session ID is not the same as the session_id mentioned above
- with pytest.raises(UnknownToken):
- _ = self.session_endpoint.endpoint_context.sdb[_sid]
+ _sid = self.session_endpoint._decrypt_sid(_jwt["sid"])
+ assert _sid == _session_info["session_id"]
+ assert _session_info["client_session_info"].is_revoked()
def test_logout_from_client_fc(self):
- self._code_auth("1234567")
+ _resp = self._code_auth("1234567")
+ _code = _resp["response_args"]["code"]
+ _session_info = self.session_manager.get_session_info_by_token(
+ _code, client_session_info=True)
+
# del self.session_endpoint.endpoint_context.cdb['client_1']['backchannel_logout_uri']
self.session_endpoint.endpoint_context.cdb["client_1"][
"frontchannel_logout_uri"
] = "https://example.com/fc_logout"
self.session_endpoint.endpoint_context.cdb["client_1"]["client_id"] = "client_1"
- _sid = self._get_sid()
- res = self.session_endpoint.logout_from_client(_sid, "client_1")
+
+ res = self.session_endpoint.logout_from_client(_session_info["session_id"])
assert set(res.keys()) == {"flu"}
assert set(res["flu"].keys()) == {"client_1"}
_spec = res["flu"]["client_1"]
assert _spec == '