From 82c87305f9447bc8885ae7aa9d11063e10924a39 Mon Sep 17 00:00:00 2001 From: roland Date: Thu, 22 Oct 2020 09:10:26 +0200 Subject: [PATCH 001/127] Handle different types of input. There might not be any client information at this time. This is an effect of automatic client registration as defined in the OIDC federation specification. --- src/oidcendpoint/common/authorization.py | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/src/oidcendpoint/common/authorization.py b/src/oidcendpoint/common/authorization.py index 06f0158..2529b72 100755 --- a/src/oidcendpoint/common/authorization.py +++ b/src/oidcendpoint/common/authorization.py @@ -5,6 +5,7 @@ from oidcmsg.exception import ParameterError from oidcmsg.exception import URIError +from oidcmsg.message import Message from oidcmsg.oauth2 import AuthorizationErrorResponse from oidcmsg.oidc import AuthorizationResponse from oidcmsg.oidc import verified_claim_name @@ -172,20 +173,33 @@ def get_uri(endpoint_context, request, uri_type): def authn_args_gather(request, authn_class_ref, cinfo, **kwargs): """ Gather information to be used by the authentication method + + :param request: The request either as a dictionary or as a Message instance + :param authn_class_ref: Authentication class reference + :param cinfo: Client information + :param kwargs: Extra keyword arguments + :return: Authentication arguments """ authn_args = { "authn_class_ref": authn_class_ref, - "query": request.to_urlencoded(), "return_uri": request["redirect_uri"], } + if isinstance(request, Message): + authn_args["query"] = request.to_urlencoded() + elif isinstance(request, dict): + authn_args["query"] = urlencode(request) + else: + ValueError("Wrong request format") + if "req_user" in kwargs: authn_args["as_user"] = (kwargs["req_user"],) # Below are OIDC specific. Just ignore if OAuth2 - for attr in ["policy_uri", "logo_uri", "tos_uri"]: - if cinfo.get(attr): - authn_args[attr] = cinfo[attr] + if cinfo: + for attr in ["policy_uri", "logo_uri", "tos_uri"]: + if cinfo.get(attr): + authn_args[attr] = cinfo[attr] for attr in ["ui_locales", "acr_values", "login_hint"]: if request.get(attr): From 5d52f0bf0187a2255017fc3f5b660135310e1b02 Mon Sep 17 00:00:00 2001 From: roland Date: Thu, 22 Oct 2020 09:12:22 +0200 Subject: [PATCH 002/127] Allow for adding extra response arguments. Not used here but used in classes that are built on this. --- src/oidcendpoint/oidc/authorization.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/oidcendpoint/oidc/authorization.py b/src/oidcendpoint/oidc/authorization.py index 8408084..0955534 100755 --- a/src/oidcendpoint/oidc/authorization.py +++ b/src/oidcendpoint/oidc/authorization.py @@ -9,6 +9,7 @@ from cryptojwt.utils import as_unicode from cryptojwt.utils import b64d from cryptojwt.utils import b64e + from oidcendpoint.token_handler import UnknownToken from oidcmsg import oidc from oidcmsg.exception import ParameterError @@ -393,6 +394,9 @@ def setup_auth(self, request, redirect_uri, cinfo, cookie, acr=None, **kwargs): return {"authn_event": authn_event, "identity": identity, "user": user} + def extra_response_args(self, aresp): + return aresp + def create_authn_response(self, request, sid): """ @@ -474,6 +478,8 @@ def create_authn_response(self, request, sid): ) return {"response_args": resp, "fragment_enc": fragment_enc} + aresp = self.extra_response_args(aresp) + return {"response_args": aresp, "fragment_enc": fragment_enc} def aresp_check(self, aresp, request): @@ -679,7 +685,7 @@ def authz_part2(self, user, authn_event, request, **kwargs): def process_request(self, request_info=None, **kwargs): """ The AuthorizationRequest endpoint - :param request_info: The authorization request as a dictionary + :param request_info: The authorization request as a Message instance :return: dictionary """ @@ -691,7 +697,8 @@ def process_request(self, request_info=None, **kwargs): logger.debug("client {}: {}".format(_cid, cinfo)) # this apply the default optionally deny_unknown_scopes policy - check_unknown_scopes_policy(request_info, cinfo, self.endpoint_context) + if cinfo: + check_unknown_scopes_policy(request_info, cinfo, self.endpoint_context) cookie = kwargs.get("cookie", "") if cookie: From dc6f366def220e15579654b5420c912353bd39b4 Mon Sep 17 00:00:00 2001 From: roland Date: Thu, 22 Oct 2020 09:12:47 +0200 Subject: [PATCH 003/127] Bumped version. --- src/oidcendpoint/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/oidcendpoint/__init__.py b/src/oidcendpoint/__init__.py index 653aef2..ef93b0a 100755 --- a/src/oidcendpoint/__init__.py +++ b/src/oidcendpoint/__init__.py @@ -1,7 +1,7 @@ import string from secrets import choice -__version__ = "1.1.1" +__version__ = "1.1.2" DEF_SIGN_ALG = { "id_token": "RS256", From 7f383d702e76b56ecf6d8df9dac1d79f84211745 Mon Sep 17 00:00:00 2001 From: roland Date: Tue, 3 Nov 2020 10:50:15 +0100 Subject: [PATCH 004/127] clean up --- src/oidcendpoint/endpoint.py | 5 +++-- src/oidcendpoint/oauth2/authorization.py | 2 +- src/oidcendpoint/oauth2/introspection.py | 2 +- src/oidcendpoint/oidc/authorization.py | 3 +-- src/oidcendpoint/oidc/userinfo.py | 2 +- src/oidcendpoint/token_handler.py | 2 +- src/oidcendpoint/userinfo.py | 1 + tests/test_08_session.py | 4 ++-- tests/test_23_oidc_registration_endpoint.py | 2 +- tests/test_24_oauth2_authorization_endpoint.py | 4 ++-- tests/test_24_oidc_authorization_endpoint.py | 2 +- tests/test_30_oidc_end_session.py | 2 +- tests/test_31_introspection.py | 3 --- 13 files changed, 16 insertions(+), 18 deletions(-) diff --git a/src/oidcendpoint/endpoint.py b/src/oidcendpoint/endpoint.py index f60ab1d..ba91b6f 100755 --- a/src/oidcendpoint/endpoint.py +++ b/src/oidcendpoint/endpoint.py @@ -4,17 +4,18 @@ from cryptojwt import jwe from cryptojwt.jws.jws import SIGNER_ALGS -from oidcendpoint.token_handler import UnknownToken from oidcmsg.exception import MissingRequiredAttribute from oidcmsg.exception import MissingRequiredValue from oidcmsg.message import Message -from oidcmsg.oauth2 import ResponseMessage, AuthorizationErrorResponse +from oidcmsg.oauth2 import AuthorizationErrorResponse +from oidcmsg.oauth2 import ResponseMessage from oidcendpoint import sanitize from oidcendpoint.client_authn import UnknownOrNoAuthnMethod from oidcendpoint.client_authn import client_auth_setup from oidcendpoint.client_authn import verify_client from oidcendpoint.exception import UnAuthorizedClient +from oidcendpoint.token_handler import UnknownToken from oidcendpoint.util import OAUTH2_NOCACHE_HEADERS __author__ = "Roland Hedberg" diff --git a/src/oidcendpoint/oauth2/authorization.py b/src/oidcendpoint/oauth2/authorization.py index 6e72b17..c4f6c23 100755 --- a/src/oidcendpoint/oauth2/authorization.py +++ b/src/oidcendpoint/oauth2/authorization.py @@ -7,7 +7,6 @@ from cryptojwt.utils import as_unicode from cryptojwt.utils import b64d from cryptojwt.utils import b64e -from oidcendpoint.token_handler import UnknownToken from oidcmsg import oauth2 from oidcmsg.exception import ParameterError from oidcmsg.oidc import AuthorizationResponse @@ -35,6 +34,7 @@ from oidcendpoint.exception import UnAuthorizedClientScope from oidcendpoint.exception import UnknownClient from oidcendpoint.session import setup_session +from oidcendpoint.token_handler import UnknownToken from oidcendpoint.user_authn.authn_context import pick_auth logger = logging.getLogger(__name__) diff --git a/src/oidcendpoint/oauth2/introspection.py b/src/oidcendpoint/oauth2/introspection.py index 4211bd8..8e83a4c 100644 --- a/src/oidcendpoint/oauth2/introspection.py +++ b/src/oidcendpoint/oauth2/introspection.py @@ -1,11 +1,11 @@ """Implements RFC7662""" import logging -from oidcendpoint.token_handler import UnknownToken from oidcmsg import oauth2 from oidcmsg.time_util import utc_time_sans_frac from oidcendpoint.endpoint import Endpoint +from oidcendpoint.token_handler import UnknownToken LOGGER = logging.getLogger(__name__) diff --git a/src/oidcendpoint/oidc/authorization.py b/src/oidcendpoint/oidc/authorization.py index 0955534..050c56e 100755 --- a/src/oidcendpoint/oidc/authorization.py +++ b/src/oidcendpoint/oidc/authorization.py @@ -9,8 +9,6 @@ from cryptojwt.utils import as_unicode from cryptojwt.utils import b64d from cryptojwt.utils import b64e - -from oidcendpoint.token_handler import UnknownToken from oidcmsg import oidc from oidcmsg.exception import ParameterError from oidcmsg.oidc import Claims @@ -38,6 +36,7 @@ from oidcendpoint.exception import UnknownClient from oidcendpoint.oauth2.authorization import check_unknown_scopes_policy from oidcendpoint.session import setup_session +from oidcendpoint.token_handler import UnknownToken from oidcendpoint.user_authn.authn_context import pick_auth logger = logging.getLogger(__name__) diff --git a/src/oidcendpoint/oidc/userinfo.py b/src/oidcendpoint/oidc/userinfo.py index ffe57e7..a63cdc5 100755 --- a/src/oidcendpoint/oidc/userinfo.py +++ b/src/oidcendpoint/oidc/userinfo.py @@ -4,12 +4,12 @@ from cryptojwt.exception import MissingValue from cryptojwt.jwt import JWT from cryptojwt.jwt import utc_time_sans_frac -from oidcendpoint.token_handler import UnknownToken from oidcmsg import oidc from oidcmsg.message import Message from oidcmsg.oauth2 import ResponseMessage from oidcendpoint.endpoint import Endpoint +from oidcendpoint.token_handler import UnknownToken from oidcendpoint.userinfo import collect_user_info from oidcendpoint.util import OAUTH2_NOCACHE_HEADERS diff --git a/src/oidcendpoint/token_handler.py b/src/oidcendpoint/token_handler.py index 692d7d5..3f04a1e 100755 --- a/src/oidcendpoint/token_handler.py +++ b/src/oidcendpoint/token_handler.py @@ -304,7 +304,7 @@ def factory(ec, code=None, token=None, refresh=None, jwks_def=None, **kwargs): _add_passwd(kj, token, "token") args["access_token_handler"] = init_token_handler(ec, token, TTYPE["token"]) - if refresh: + if refresh is not None: _add_passwd(kj, refresh, "refresh") args["refresh_token_handler"] = init_token_handler( ec, refresh, TTYPE["refresh"] diff --git a/src/oidcendpoint/userinfo.py b/src/oidcendpoint/userinfo.py index 56684fd..bd2fa01 100755 --- a/src/oidcendpoint/userinfo.py +++ b/src/oidcendpoint/userinfo.py @@ -3,6 +3,7 @@ from oidcmsg.oidc import Claims from oidcendpoint import sanitize +from oidcendpoint.exception import FailedAuthentication from oidcendpoint.exception import ImproperlyConfigured from oidcendpoint.scopes import convert_scopes2claims diff --git a/tests/test_08_session.py b/tests/test_08_session.py index 3952c55..15ca52c 100644 --- a/tests/test_08_session.py +++ b/tests/test_08_session.py @@ -2,11 +2,10 @@ import shutil import time -from oidcendpoint.token_handler import UnknownToken +import pytest from oidcmsg.oidc import AuthorizationRequest from oidcmsg.oidc import OpenIDRequest from oidcmsg.storage.init import storage_factory -import pytest from oidcendpoint import rndstr from oidcendpoint import token_handler @@ -17,6 +16,7 @@ from oidcendpoint.session import SessionDB from oidcendpoint.session import setup_session from oidcendpoint.sso_db import SSODb +from oidcendpoint.token_handler import UnknownToken from oidcendpoint.token_handler import WrongTokenType from oidcendpoint.user_authn.authn_context import INTERNETPROTOCOLPASSWORD from oidcendpoint.user_info import UserInfo diff --git a/tests/test_23_oidc_registration_endpoint.py b/tests/test_23_oidc_registration_endpoint.py index 0d59e45..90d7969 100755 --- a/tests/test_23_oidc_registration_endpoint.py +++ b/tests/test_23_oidc_registration_endpoint.py @@ -1,9 +1,9 @@ # -*- coding: latin-1 -*- import json -from cryptojwt.key_jar import init_key_jar import pytest import responses +from cryptojwt.key_jar import init_key_jar from oidcmsg.oidc import RegistrationRequest from oidcmsg.oidc import RegistrationResponse diff --git a/tests/test_24_oauth2_authorization_endpoint.py b/tests/test_24_oauth2_authorization_endpoint.py index 599269d..8d4cbcd 100755 --- a/tests/test_24_oauth2_authorization_endpoint.py +++ b/tests/test_24_oauth2_authorization_endpoint.py @@ -29,9 +29,9 @@ from oidcendpoint.exception import NoSuchAuthentication from oidcendpoint.exception import RedirectURIError from oidcendpoint.exception import ToOld -from oidcendpoint.exception import UnknownClient -from oidcendpoint.exception import UnAuthorizedClientScope from oidcendpoint.exception import UnAuthorizedClient +from oidcendpoint.exception import UnAuthorizedClientScope +from oidcendpoint.exception import UnknownClient from oidcendpoint.id_token import IDToken from oidcendpoint.oauth2.authorization import Authorization from oidcendpoint.session import SessionInfo diff --git a/tests/test_24_oidc_authorization_endpoint.py b/tests/test_24_oidc_authorization_endpoint.py index 9de55a9..2130b5d 100755 --- a/tests/test_24_oidc_authorization_endpoint.py +++ b/tests/test_24_oidc_authorization_endpoint.py @@ -32,8 +32,8 @@ from oidcendpoint.exception import NoSuchAuthentication from oidcendpoint.exception import RedirectURIError from oidcendpoint.exception import ToOld -from oidcendpoint.exception import UnknownClient from oidcendpoint.exception import UnAuthorizedClient +from oidcendpoint.exception import UnknownClient from oidcendpoint.id_token import IDToken from oidcendpoint.login_hint import LoginHint2Acrs from oidcendpoint.oidc import userinfo diff --git a/tests/test_30_oidc_end_session.py b/tests/test_30_oidc_end_session.py index 57125de..088e591 100644 --- a/tests/test_30_oidc_end_session.py +++ b/tests/test_30_oidc_end_session.py @@ -4,7 +4,6 @@ from urllib.parse import parse_qs from urllib.parse import urlparse -from oidcendpoint.token_handler import UnknownToken import pytest import responses from cryptojwt.key_jar import build_keyjar @@ -26,6 +25,7 @@ from oidcendpoint.oidc.session import Session from oidcendpoint.oidc.session import do_front_channel_logout_iframe from oidcendpoint.oidc.token import AccessToken +from oidcendpoint.token_handler import UnknownToken from oidcendpoint.user_authn.authn_context import INTERNETPROTOCOLPASSWORD from oidcendpoint.user_info import UserInfo diff --git a/tests/test_31_introspection.py b/tests/test_31_introspection.py index 4327eb8..2a31542 100644 --- a/tests/test_31_introspection.py +++ b/tests/test_31_introspection.py @@ -13,9 +13,6 @@ from oidcmsg.oidc import AuthorizationRequest from oidcmsg.time_util import utc_time_sans_frac -from oidcendpoint.client_authn import ClientSecretPost -from oidcendpoint.client_authn import UnknownOrNoAuthnMethod -from oidcendpoint.client_authn import WrongAuthnMethod from oidcendpoint.client_authn import verify_client from oidcendpoint.endpoint_context import EndpointContext from oidcendpoint.exception import UnAuthorizedClient From 2f2e388499eb196322de6e88b9f9d4f586c71e9d Mon Sep 17 00:00:00 2001 From: roland Date: Tue, 3 Nov 2020 10:51:24 +0100 Subject: [PATCH 005/127] Change so the method definition is the same. --- src/oidcendpoint/jwt_token.py | 70 +++++++++++++++-------------------- 1 file changed, 29 insertions(+), 41 deletions(-) diff --git a/src/oidcendpoint/jwt_token.py b/src/oidcendpoint/jwt_token.py index d211eea..28b39b1 100644 --- a/src/oidcendpoint/jwt_token.py +++ b/src/oidcendpoint/jwt_token.py @@ -1,7 +1,3 @@ -from typing import Any -from typing import Dict -from typing import Optional - from cryptojwt import JWT from cryptojwt.jws.exception import JWSException @@ -20,16 +16,16 @@ class JWTToken(Token): } def __init__( - self, - typ, - keyjar=None, - issuer=None, - aud=None, - alg="ES256", - lifetime=300, - ec=None, - token_type="Bearer", - **kwargs + self, + typ, + keyjar=None, + issuer=None, + aud=None, + alg="ES256", + lifetime=300, + ec=None, + token_type="Bearer", + **kwargs ): Token.__init__(self, typ, **kwargs) self.token_type = token_type @@ -64,47 +60,38 @@ def do_add_claims(self, payload, uinfo, claims): pass def __call__( - self, - sid: str, - uinfo: Dict, - sinfo: Dict, - aud: Optional[Any], - client_id: Optional[str], - **kwargs + self, + sid: str, + **kwargs ): """ Return a token. :param sid: Session id - :param uinfo: User information - :param sinfo: Session information - :param aud: audience - :param client_id: client_id - :return: + :return: Signed JSON Web Token """ - payload = {"sid": sid, "ttype": self.type, "sub": sinfo["sub"]} + + payload = {"sid": sid, "ttype": self.type, "sub": kwargs["sinfo"]["sub"]} + + _user_claims = kwargs.get('user_claims') + _client_id = kwargs.get('client_id') + _scopes = kwargs.get('scope') if self.add_claims: - self.do_add_claims(payload, uinfo, self.add_claims) + self.do_add_claims(payload, _user_claims, self.add_claims) if self.add_claims_by_scope: - _allowed_claims = self.cntx.claims_handler.allowed_claims( - client_id, self.cntx - ) + _allowed_claims = self.cntx.claims_handler.allowed_claims(_client_id, self.cntx) self.do_add_claims( payload, - uinfo, - convert_scopes2claims( - sinfo["authn_req"]["scope"], - _allowed_claims, - map=self.scope_claims_map, - ).keys(), + _user_claims, + convert_scopes2claims(_scopes, _allowed_claims, map=self.scope_claims_map).keys(), ) # Add claims if is access token if self.type == "T" and self.enable_claims_per_client: - client = self.cdb.get(client_id, {}) + client = self.cdb.get(_client_id, {}) client_claims = client.get("access_token_claims") if client_claims: - self.do_add_claims(payload, uinfo, client_claims) + self.do_add_claims(payload, _user_claims, client_claims) payload.update(kwargs) signer = JWT( @@ -114,10 +101,11 @@ def __call__( sign_alg=self.alg, ) - if aud is None: + _aud = kwargs.get('aud') + if _aud is None: _aud = self.def_aud else: - _aud = aud if isinstance(aud, list) else [aud] + _aud = _aud if isinstance(_aud, list) else [_aud] _aud.extend(self.def_aud) return signer.pack(payload, aud=_aud) From 5cca27bd1bd223ad004d993b24fd5cad8f2af11e Mon Sep 17 00:00:00 2001 From: roland Date: Tue, 3 Nov 2020 11:00:08 +0100 Subject: [PATCH 006/127] Basic functionality. --- src/oidcendpoint/grant.py | 231 ++++++++++++ src/oidcendpoint/session_management.py | 230 ++++++++++++ tests/test_70_grant.py | 61 ++++ tests/test_71_identity_db.py | 67 ++++ tests/test_72_session_life.py | 464 +++++++++++++++++++++++++ 5 files changed, 1053 insertions(+) create mode 100644 src/oidcendpoint/grant.py create mode 100644 src/oidcendpoint/session_management.py create mode 100644 tests/test_70_grant.py create mode 100644 tests/test_71_identity_db.py create mode 100644 tests/test_72_session_life.py diff --git a/src/oidcendpoint/grant.py b/src/oidcendpoint/grant.py new file mode 100644 index 0000000..83e2f0a --- /dev/null +++ b/src/oidcendpoint/grant.py @@ -0,0 +1,231 @@ +import json +import time +from typing import Optional +from uuid import uuid1 + +from oidcmsg.message import OPTIONAL_LIST_OF_SP_SEP_STRINGS +from oidcmsg.message import OPTIONAL_LIST_OF_STRINGS +from oidcmsg.message import SINGLE_OPTIONAL_JSON +from oidcmsg.message import Message +from oidcmsg.time_util import utc_time_sans_frac + + +class GrantMessage(Message): + c_param = { + "scope": OPTIONAL_LIST_OF_SP_SEP_STRINGS, # As defined in RFC6749 + "authorization_details": SINGLE_OPTIONAL_JSON, # As defined in draft-lodderstedt-oauth-rar + "claims": SINGLE_OPTIONAL_JSON, # As defined in OIDC core + "resources": OPTIONAL_LIST_OF_STRINGS, # As defined in RFC8707 + } + + +GRANT_TYPE_MAP = { + "authorization_code": "code", + "access_token": "access_token", + "refresh_token": "refresh_token" +} + + +def find_token(issued, id): + for iss in issued: + if iss.id == id: + return iss + return None + + +class Item: + def __init__(self, + usage_rules: Optional[dict] = None, + issued_at: int = 0, + expires_at: int = 0, + not_before: int = 0 + ): + self.issued_at = issued_at or utc_time_sans_frac() + self.not_before = not_before + self.expires_at = expires_at + self.revoked = False + self.used = 0 + self.usage_rules = usage_rules or {} + + def max_usage_reached(self): + if "max_usage" in self.usage_rules: + return self.used >= self.usage_rules['max_usage'] + else: + return False + + def is_active(self): + if self.max_usage_reached(): + return False + + if self.revoked: + return False + + if self.not_before: + if time.time() < self.not_before: + return False + + if self.expires_at: + if time.time() > self.expires_at: + return False + + return True + + +class Token(Item): + attributes = ["type", "issued_at", "not_before", "expires_at", "revoked", "value", + "usage_rules", "used", "based_on", "id"] + + def __init__(self, + typ: str = '', + based_on: Optional[str] = None, + usage_rules: Optional[dict] = None, + value: Optional[str] = '', + issued_at: int = 0, + expires_at: int = 0, + not_before: int = 0 + ): + Item.__init__(self, usage_rules=usage_rules, issued_at=issued_at, expires_at=expires_at, + not_before=not_before) + + self.type = typ + self.value = value + self.based_on = based_on + self.id = uuid1().hex + + self.set_defaults() + + def set_defaults(self): + pass + + def register_usage(self): + self.used += 1 + + def has_been_used(self): + return self.used != 0 + + def to_json(self): + d = { + "type": self.type, + "issued_at": self.issued_at, + "not_before": self.not_before, + "expires_at": self.expires_at, + "revoked": self.revoked, + "value": self.value, + "usage_rules": self.usage_rules, + "used": self.used, + "based_on": self.based_on, + "id": self.id + } + return json.dumps(d) + + def from_json(self, json_str): + d = json.loads(json_str) + for attr in self.attributes: + if attr in d: + setattr(self, attr, d[attr]) + return self + + def supports_minting(self, token_type): + return token_type in self.usage_rules['supports_minting'] + + +class AuthorizationCode(Token): + def set_defaults(self): + if "supports_minting" not in self.usage_rules: + self.usage_rules['supports_minting'] = ["access_token", "refresh_token"] + + self.usage_rules['max_usage'] = 1 + + +class RefreshToken(Token): + def set_defaults(self): + if "supports_minting" not in self.usage_rules: + self.usage_rules['supports_minting'] = ["access_token", "refresh_token"] + + +TOKEN_MAP = { + "authorization_code": AuthorizationCode, + "access_token": Token, + "refresh_token": RefreshToken +} + + +class Grant(Item): + def __init__(self, + scopes: Optional[list] = None, + claims: Optional[dict] = None, + resources: Optional[list] = None, + authorization_details: Optional[dict] = None, + token: Optional[list] = None, + usage_rules: Optional[dict] = None, + issued_at: int = 0, + expires_at: int = 0): + Item.__init__(self, usage_rules=usage_rules, issued_at=issued_at, expires_at=expires_at) + self.scope = scopes or [] + self.authorization_details = authorization_details or None + self.claims = claims or None + self.resources = resources or [] + self.issued_token = token or [] + self.id = uuid1().hex + + def update(self, item: dict): + for attr in ['scope', 'authorization_details', 'claims', 'resources']: + val = item.get(attr) + if val: + setattr(self, attr, val) + + def replace(self, item: dict): + for attr in ['scope', 'authorization_details', 'claims', 'resources']: + setattr(self, attr, item.get(attr)) + + def revoke(self): + self.revoked = True + for t in self.issued_token: + t.revoked = True + + def get(self): + return GrantMessage(scope=self.scope, claims=self.claims, + authorization_details=self.authorization_details, + resources=self.resources) + + def to_json(self): + d = { + "scope": self.scope, + "authorization_details": self.authorization_details, + "claims": self.claims, + "resources": self.resources, + "issued_at": self.issued_at, + "not_before": self.not_before, + "expires_at": self.expires_at, + "revoked": self.revoked, + "issued_token": [t.to_json for t in self.issued_token], + "id": self.id + } + return json.dumps(d) + + def from_json(self, json_str): + d = json.loads(json_str) + for attr in ["scope", "authorization_details", "claims", "resources", "issued_at", + "not_before", + "expires_at", "revoked", "id"]: + if attr in d: + setattr(self, attr, d[attr]) + if "issued_token" in d: + setattr(self, "issued_token", [Token(**t) for t in d['issued_token']]) + + def mint_token(self, token_type, **kwargs): + item = TOKEN_MAP[token_type](typ=token_type, **kwargs) + self.issued_token.append(item) + return item + + def revoke_all_based_on(self, id): + for t in self.issued_token: + if t.based_on == id: + t.revoked = True + self.revoke_all_based_on(t.id) + + def get_token(self, val): + for t in self.issued_token: + if t.value == val: + return t + return None diff --git a/src/oidcendpoint/session_management.py b/src/oidcendpoint/session_management.py new file mode 100644 index 0000000..44979d6 --- /dev/null +++ b/src/oidcendpoint/session_management.py @@ -0,0 +1,230 @@ +import hashlib +import logging + +logger = logging.getLogger(__name__) + + +def db_key(*args): + return ':'.join(args) + + +def unpack_db_key(key): + return key.split(':') + + +def pairwise_id(uid, sector_identifier, salt, **kwargs): + return hashlib.sha256(("%s%s%s" % (uid, sector_identifier, salt)).encode("utf-8")).hexdigest() + + +def public_id(uid, salt="", **kwargs): + return hashlib.sha256("{}{}".format(uid, salt).encode("utf-8")).hexdigest() + + +class Info(object): + def __init__(self, **kwargs): + self._db = kwargs or {} + if "subordinate" not in self._db: + self._db["subordinate"] = [] + + def set(self, key, value): + self._db[key] = value + + def get(self, key): + return self._db[key] + + def update(self, ava): + self._db.update(ava) + return self + + def add_subordinate(self, value): + self._db["subordinate"].append(value) + return self + + def remove_subordinate(self, value): + self._db["subordinate"].remove(value) + return self + + def __setitem__(self, key, value): + self._db[key] = value + + def __getitem__(self, key): + return self._db[key] + + def keys(self): + return self._db.keys() + + def values(self): + return self._db.values() + + def items(self): + return self._db.items() + + +class UserInfo(Info): + pass + + +class ClientInfo(Info): + def find_grant(self, val): + for grant in self._db["subordinate"]: + token = grant.get_token(val) + if token: + return grant, token + + +class Database(object): + def __init__(self, storage=None): + self._db = storage or {} + + def eval_path(self, path): + uid = path[0] + client_id = None + grant_id = None + if len(path) > 1: + client_id = path[1] + if len(path) > 2: + grant_id = path[2] + + return uid, client_id, grant_id + + def set(self, path: list, value: object): + """ + + :param path: a list of identifiers + :param value: Class instance to be stored + """ + # Try loading the key, that's a good place to put a debugger to + # import pdb; pdb.set_trace() + uid, client_id, grant_id = self.eval_path(path) + + _userinfo = self._db.get(uid) + if _userinfo: + if client_id: + if client_id in _userinfo['subordinate']: + _cid_key = db_key(uid, client_id) + _cid_info = self._db[db_key(uid, client_id)] + if _cid_info: + if grant_id: + _gid_key = db_key(uid, client_id, grant_id) + if grant_id in _cid_info['subordinate']: + _gid_info = self._db[_gid_key] + if not _gid_info: + self._db[_cid_key] = _cid_info.add_subordinate(grant_id) + self._db[_gid_key] = value + else: + self._db[_cid_key] = _cid_info.add_subordinate(grant_id) + self._db[_gid_key] = value + else: + self._db.set[_cid_key] = value + else: + _userinfo.add_subordinate(client_id) + if grant_id: + _cid_info = ClientInfo() + _cid_info.add_subordinate(grant_id) + self._db[_cid_key] = _cid_info + self._db[db_key(uid, client_id, grant_id)] = value + else: + _cid_info = ClientInfo() + self._db[_cid_key] = _cid_info + self._db[uid] = _userinfo + else: + _userinfo.add_subordinate(client_id) + self._db[uid] = _userinfo + if grant_id: + _cid_info = ClientInfo() + _cid_info.add_subordinate(grant_id) + self._db[db_key(uid, client_id, grant_id)] = value + else: + _cid_info = value + + _cid_key = db_key(uid, client_id) + self._db[_cid_key] = _cid_info + else: + self._db[uid] = value + else: + if client_id: + _user_info = UserInfo() + _user_info.add_subordinate(client_id) + if grant_id: + _cid_info = ClientInfo() + _cid_info.add_subordinate(grant_id) + self._db[db_key(uid, client_id, grant_id)] = value + else: + _cid_info = value + self._db[db_key(uid, client_id)] = _cid_info + else: + _user_info = value + + self._db[uid] = _user_info + + def get(self, path: list): + uid, client_id, grant_id = self.eval_path(path) + try: + user_info = self._db[uid] + except KeyError: + raise KeyError('No such UserID') + + if client_id is None: + return user_info + else: + if client_id not in user_info['subordinate']: + raise ValueError('No session from that client for that user') + else: + try: + client_session_info = self._db[db_key(uid, client_id)] + except KeyError: + return {} + else: + if grant_id is None: + return client_session_info + + if grant_id not in client_session_info['subordinate']: + raise ValueError('No such grant for that user and client') + else: + try: + return self._db[db_key(uid, client_id, grant_id)] + except KeyError: + return {} + + def delete(self, path): + uid, client_id, grant_id = self.eval_path(path) + try: + _dic = self._db[uid] + except KeyError: + pass + else: + if client_id: + if client_id in _dic['client_id']: + try: + _cinfo = self._db[db_key(uid, client_id)] + except KeyError: + pass + else: + if grant_id: + if grant_id in _cinfo['grant_id']: + self._db.__delitem__(db_key(uid, client_id, grant_id)) + else: + self._db.__delitem__(db_key(uid, client_id)) + else: + pass + else: + self._db.__delitem__(uid) + + +class SessionManager(Database): + def __init__(self, handler, storage=None): + Database.__init__(self, storage) + self.token_handler = handler + + def get_user(self, uid): + user = self.get(uid) + + def find_grant(self, user_id, client_id, token_value): + client_info = self.get([user_id, client_id]) + for grant_id in client_info["subordinate"]: + grant = self.get([user_id, client_id, grant_id]) + for token in grant.issued_token: + if token.value == token_value: + return grant, token + + return None diff --git a/tests/test_70_grant.py b/tests/test_70_grant.py new file mode 100644 index 0000000..4635541 --- /dev/null +++ b/tests/test_70_grant.py @@ -0,0 +1,61 @@ +from oidcendpoint.grant import Grant +from oidcendpoint.grant import Token +from oidcendpoint.grant import find_token + + +def test_access_code(): + token = Token('access_code', value="ABCD") + assert token.issued_at + assert token.type == "access_code" + assert token.value == "ABCD" + + token.register_usage() + # max_usage == 1 + assert token.max_usage_reached() is True + + +def test_access_token(): + code = Token('access_code', value="ABCD") + token = Token('access_token', value="1234", based_on=code.id) + assert token.issued_at + assert token.type == "access_token" + assert token.value == "1234" + + token.register_usage() + # max_usage - undefined + assert token.max_usage_reached() is False + + token.max_usage = 2 + token.register_usage() + assert token.max_usage_reached() is True + + t = find_token([code, token], token.based_on) + assert t.value == "ABCD" + + token.revoked = True + assert token.revoked is True + + +def test_grant(): + grant = Grant() + code = grant.mint_token("authorization_code", value="ABCD") + access_token = grant.mint_token("access_token", value="1234", based_on=code.id) + refresh_token = grant.mint_token("refresh_token", value="1234", based_on=code.id) + grant.revoke() + assert code.revoked is True + assert access_token.revoked is True + assert refresh_token.revoked is True + +def test_grant_revoked_based_on(): + grant = Grant() + code = grant.mint_token("authorization_code", value="ABCD") + access_token = grant.mint_token("access_token", value="1234", based_on=code.id) + refresh_token = grant.mint_token("refresh_token", value="1234", based_on=code.id) + + code.register_usage() + if code.max_usage_reached(): + grant.revoke_all_based_on(code.id) + + assert code.is_active() is False + assert access_token.is_active() is False + assert refresh_token.is_active() is False diff --git a/tests/test_71_identity_db.py b/tests/test_71_identity_db.py new file mode 100644 index 0000000..4f1a15f --- /dev/null +++ b/tests/test_71_identity_db.py @@ -0,0 +1,67 @@ +# Database is organized in 3 layers. User-session-grant. +import pytest + +from oidcendpoint.authn_event import create_authn_event +from oidcendpoint.grant import Grant +from oidcendpoint.grant import Token +from oidcendpoint.session_management import ClientInfo +from oidcendpoint.session_management import Database +from oidcendpoint.session_management import UserInfo +from oidcendpoint.session_management import public_id + + +class TestDB: + @pytest.fixture(autouse=True) + def setup_environment(self): + self.db = Database() + + def test_user_info(self): + with pytest.raises(KeyError): + self.db.get(['diana']) + + user_info = UserInfo(foo="bar") + self.db.set(['diana'], user_info) + user_info = self.db.get(['diana']) + assert user_info["foo"] == "bar" + + def test_client_info(self): + user_info = UserInfo(foo="bar") + self.db.set(['diana'], user_info) + client_info = ClientInfo(sid= "abcdef") + self.db.set(['diana', "client_1"], client_info) + + user_info = self.db.get(['diana']) + assert user_info['client_id'] == ['client_1'] + client_info = self.db.get(['diana', "client_1"]) + assert client_info['sid'] == "abcdef" + + def test_jump_ahead(self): + grant = Grant() + access_code = Token('access_code', value='1234567890') + grant.issued_token.append(access_code) + + self.db.set(['diana', "client_1", "G1"], grant) + + user_info = self.db.get(['diana']) + assert user_info['client_id'] == ['client_1'] + client_info = self.db.get(['diana', "client_1"]) + assert client_info['grant_id'] == ["G1"] + grant_info = self.db.get(['diana', 'client_1', 'G1']) + assert grant_info.issued_at + assert len(grant_info.issued_token) == 1 + token = grant_info.issued_token[0] + assert token.value == '1234567890' + assert token.type == "access_code" + + def test_step_wise(self): + salt = "natriumklorid" + # store user info + self.db.set(['diana'], UserInfo(authn_event = create_authn_event('diana', salt))) + # Client specific information + self.db.set(['diana', 'client_1'], ClientInfo(sub= public_id('diana', salt))) + # Grant + grant = Grant() + access_code = Token('access_code', value='1234567890') + grant.issued_token.append(access_code) + + self.db.set(['diana', 'client_1', 'G1'], grant) diff --git a/tests/test_72_session_life.py b/tests/test_72_session_life.py new file mode 100644 index 0000000..dcfa961 --- /dev/null +++ b/tests/test_72_session_life.py @@ -0,0 +1,464 @@ +import os + +import pytest +from cryptojwt.key_jar import init_key_jar +from oidcmsg.oidc import AccessTokenRequest +from oidcmsg.oidc import AuthorizationRequest +from oidcmsg.oidc import RefreshAccessTokenRequest +from oidcmsg.time_util import time_sans_frac + +from oidcendpoint import user_info +from oidcendpoint.authn_event import create_authn_event +from oidcendpoint.client_authn import verify_client +from oidcendpoint.endpoint_context import EndpointContext +from oidcendpoint.grant import Grant +from oidcendpoint.id_token import IDToken +from oidcendpoint.session_management import ClientInfo +from oidcendpoint.session_management import SessionManager +from oidcendpoint.session_management import UserInfo +from oidcendpoint.session_management import db_key +from oidcendpoint.session_management import public_id +from oidcendpoint.session_management import unpack_db_key +from oidcendpoint.oidc.authorization import Authorization +from oidcendpoint.oidc.provider_config import ProviderConfiguration +from oidcendpoint.oidc.registration import Registration +from oidcendpoint.oidc.session import Session +from oidcendpoint.oidc.token import AccessToken +from oidcendpoint.token_handler import DefaultToken +from oidcendpoint.token_handler import TokenHandler +from oidcendpoint.user_authn.authn_context import INTERNETPROTOCOLPASSWORD + + +class TestSession(): + @pytest.fixture(autouse=True) + def setup_token_handler(self): + password = "The longer the better. Is this close to enough ?" + grant_expires_in = 600 + token_expires_in = 900 + refresh_token_expires_in = 86400 + + code_handler = DefaultToken(password, typ="A", lifetime=grant_expires_in) + access_token_handler = DefaultToken( + password, typ="T", lifetime=token_expires_in + ) + refresh_token_handler = DefaultToken( + password, typ="R", lifetime=refresh_token_expires_in + ) + + handler = TokenHandler( + code_handler=code_handler, + access_token_handler=access_token_handler, + refresh_token_handler=refresh_token_handler, + ) + + self.session_manager = SessionManager(handler) + + def auth(self): + # Start with an authentication request + # The client ID appears in the request + AUTH_REQ = AuthorizationRequest( + client_id="client_1", + redirect_uri="https://example.com/cb", + scope=["openid", "mail", "address", "offline_access"], + state="STATE", + response_type="code", + ) + + # The authentication returns a user ID + user_id = "diana" + + # User info is stored in the Session DB + + user_info = UserInfo() + self.session_manager.set([user_id], user_info) + + # Now for client session information + salt = "natriumklorid" + authn_event = create_authn_event( + user_id, + salt, + authn_info=INTERNETPROTOCOLPASSWORD, + authn_time=time_sans_frac(), + ) + + client_info = ClientInfo( + authorization_request=AUTH_REQ, + authenticationEvent=authn_event, + sub=public_id(user_id, salt) + ) + self.session_manager.set([user_id, AUTH_REQ['client_id']], client_info) + + # The user consent module produces a Grant instance + + grant = Grant(scopes=AUTH_REQ['scope'], resources=[AUTH_REQ['client_id']]) + + # the grant is assigned to a session (user_id, client_id) + + self.session_manager.set([user_id, AUTH_REQ['client_id'], grant.id], grant) + + # Constructing an authorization code is now done by + + code = grant.mint_token( + 'authorization_code', + value=self.session_manager.token_handler["code"](user_id), + expires_at=time_sans_frac() + 300 # 5 minutes from now + ) + + return code + + def test_code_flow(self): + # code is a Token instance + code = self.auth() + + # next step is access token request + + TOKEN_REQ = AccessTokenRequest( + client_id="client_1", + redirect_uri="https://example.com/cb", + state="STATE", + grant_type="authorization_code", + client_secret="hemligt", + code=code.value + ) + + # parse the token + user_id = self.session_manager.token_handler.sid(TOKEN_REQ['code']) + + # Now given I have the client_id from the request and the user_id from the + # token I can easily find the grant + + # client_info = self.session_manager.get([user_id, TOKEN_REQ['client_id']]) + grant, tok = self.session_manager.find_grant(user_id, TOKEN_REQ['client_id'], + TOKEN_REQ['code']) + + # Verify that it's of the correct type and can be used + assert tok.type == "authorization_code" + assert tok.is_active() + + # Mint an access token and a refresh token and mark the code as used + + assert tok.supports_minting("access_token") + + access_token = grant.mint_token( + 'access_token', + value=self.session_manager.token_handler["access_token"](user_id), + expires_at=time_sans_frac() + 900, # 15 minutes from now + based_on=tok.id # Means the token (tok) was used to mint this token + ) + + assert tok.supports_minting("refresh_token") + + refresh_token = grant.mint_token( + 'refresh_token', + value=self.session_manager.token_handler["refresh_token"](user_id), + based_on=tok.id + ) + + tok.register_usage() + + assert tok.max_usage_reached() is True + + # A bit later a refresh token is used to mint a new access token + + REFRESH_TOKEN_REQ = RefreshAccessTokenRequest( + grant_type="refresh_token", + client_id="client_1", + client_secret="hemligt", + refresh_token=refresh_token.value, + scope=["openid", "mail", "offline_access"] + ) + + grant, reftok = self.session_manager.find_grant(user_id, + REFRESH_TOKEN_REQ['client_id'], + REFRESH_TOKEN_REQ['refresh_token']) + + assert reftok.supports_minting("access_token") + + access_token_2 = grant.mint_token( + 'access_token', + value=self.session_manager.token_handler["access_token"](user_id), + expires_at=time_sans_frac() + 900, # 15 minutes from now + based_on=reftok.id # Means the token (tok) was used to mint this token + ) + + assert access_token_2.is_active() + + +KEYDEFS = [ + {"type": "RSA", "key": "", "use": ["sig"]}, + {"type": "EC", "crv": "P-256", "use": ["sig"]}, +] + +ISSUER = "https://example.com/" + +KEYJAR = init_key_jar(key_defs=KEYDEFS, issuer_id=ISSUER) +KEYJAR.import_jwks(KEYJAR.export_jwks(True, ISSUER), "") +RESPONSE_TYPES_SUPPORTED = [ + ["code"], + ["token"], + ["id_token"], + ["code", "token"], + ["code", "id_token"], + ["id_token", "token"], + ["code", "token", "id_token"], + ["none"], +] +CAPABILITIES = { + "response_types_supported": [" ".join(x) for x in RESPONSE_TYPES_SUPPORTED], + "token_endpoint_auth_methods_supported": [ + "client_secret_post", + "client_secret_basic", + "client_secret_jwt", + "private_key_jwt", + ], + "response_modes_supported": ["query", "fragment", "form_post"], + "subject_types_supported": ["public", "pairwise"], + "grant_types_supported": [ + "authorization_code", + "implicit", + "urn:ietf:params:oauth:grant-type:jwt-bearer", + ], + "claim_types_supported": ["normal", "aggregated", "distributed"], + "claims_parameter_supported": True, + "request_parameter_supported": True, + "request_uri_parameter_supported": True, +} +BASEDIR = os.path.abspath(os.path.dirname(__file__)) + + +def full_path(local_file): + return os.path.join(BASEDIR, local_file) + + +class TestSessionJWTToken(): + @pytest.fixture(autouse=True) + def setup_session_manager(self): + conf = { + "issuer": ISSUER, + "password": "mycket hemligt", + "token_expires_in": 600, + "grant_expires_in": 300, + "refresh_token_expires_in": 86400, + "verify_ssl": False, + "capabilities": CAPABILITIES, + "keys": {"uri_path": "jwks.json", "key_defs": KEYDEFS}, + "token_handler_args": { + "jwks_def": { + "private_path": "private/token_jwks.json", + "read_only": False, + "key_defs": [ + {"type": "oct", "bytes": "24", "use": ["enc"], "kid": "code"}, + {"type": "oct", "bytes": "24", "use": ["enc"], "kid": "refresh"} + ], + }, + "code": {"lifetime": 600}, + "token": { + "class": "oidcendpoint.jwt_token.JWTToken", + "kwargs": { + "lifetime": 3600, + "add_claims": [ + "email", + "email_verified", + "phone_number", + "phone_number_verified", + ], + "add_claim_by_scope": True, + "aud": ["https://example.org/appl"], + }, + }, + "refresh": {}, + }, + "endpoint": { + "provider_config": { + "path": "{}/.well-known/openid-configuration", + "class": ProviderConfiguration, + "kwargs": {}, + }, + "registration": { + "path": "{}/registration", + "class": Registration, + "kwargs": {}, + }, + "authorization": { + "path": "{}/authorization", + "class": Authorization, + "kwargs": {}, + }, + "token": {"path": "{}/token", "class": AccessToken, "kwargs": {}}, + "session": {"path": "{}/end_session", "class": Session}, + }, + "client_authn": verify_client, + "authentication": { + "anon": { + "acr": INTERNETPROTOCOLPASSWORD, + "class": "oidcendpoint.user_authn.user.NoAuthn", + "kwargs": {"user": "diana"}, + } + }, + "template_dir": "template", + "userinfo": { + "class": user_info.UserInfo, + "kwargs": {"db_file": full_path("users.json")}, + }, + "id_token": {"class": IDToken}, + } + + self.endpoint_context = EndpointContext(conf, keyjar=KEYJAR) + self.session_manager = SessionManager(self.endpoint_context.sdb.handler) + + def auth(self): + # Start with an authentication request + # The client ID appears in the request + AUTH_REQ = AuthorizationRequest( + client_id="client_1", + redirect_uri="https://example.com/cb", + scope=["openid", "mail", "address", "offline_access"], + state="STATE", + response_type="code", + ) + + # The authentication returns a user ID + user_id = "diana" + + # User info is stored in the Session DB + + user_info = UserInfo() + self.session_manager.set([user_id], user_info) + + # Now for client session information + salt = "natriumklorid" + authn_event = create_authn_event( + user_id, + salt, + authn_info=INTERNETPROTOCOLPASSWORD, + authn_time=time_sans_frac(), + ) + + client_info = ClientInfo( + authorization_request=AUTH_REQ, + authenticationEvent=authn_event, + sub=public_id(user_id, salt) + ) + self.session_manager.set([user_id, AUTH_REQ['client_id']], client_info) + + # The user consent module produces a Grant instance + + grant = Grant(scopes=AUTH_REQ['scope'], resources=[AUTH_REQ['client_id']]) + + # the grant is assigned to a session (user_id, client_id) + + self.session_manager.set([user_id, AUTH_REQ['client_id'], grant.id], grant) + + # Constructing an authorization code is now done by + + code = grant.mint_token( + 'authorization_code', + value=self.session_manager.token_handler["code"]( + db_key(user_id, AUTH_REQ['client_id'])), + expires_at=time_sans_frac() + 300 # 5 minutes from now + ) + + return code + + def test_code_flow(self): + # code is a Token instance + code = self.auth() + + # next step is access token request + + TOKEN_REQ = AccessTokenRequest( + client_id="client_1", + redirect_uri="https://example.com/cb", + state="STATE", + grant_type="authorization_code", + client_secret="hemligt", + code=code.value + ) + + # parse the token + session_id = self.session_manager.token_handler.sid(TOKEN_REQ['code']) + user_id, client_id = unpack_db_key(session_id) + + # Now given I have the client_id from the request and the user_id from the + # token I can easily find the grant + + # client_info = self.session_manager.get([user_id, TOKEN_REQ['client_id']]) + grant, tok = self.session_manager.find_grant(user_id, TOKEN_REQ['client_id'], + TOKEN_REQ['code']) + + # Verify that it's of the correct type and can be used + assert tok.type == "authorization_code" + assert tok.is_active() + + # Mint an access token and a refresh token and mark the code as used + + assert tok.supports_minting("access_token") + + client_info = self.session_manager.get([user_id, TOKEN_REQ["client_id"]]) + + assert tok.supports_minting("access_token") + + user_claims = self.endpoint_context.userinfo(user_id, client_id=TOKEN_REQ["client_id"], + user_info_claims=grant.claims) + + access_token = grant.mint_token( + 'access_token', + value=self.session_manager.token_handler["access_token"]( + db_key(user_id, client_id), + sinfo=client_info, + client_id=TOKEN_REQ['client_id'], + aud=grant.resources, + uinfo=user_claims + ), + expires_at=time_sans_frac() + 900, # 15 minutes from now + based_on=tok.id # Means the token (tok) was used to mint this token + ) + + assert tok.supports_minting("refresh_token") + + refresh_token = grant.mint_token( + 'refresh_token', + value=self.session_manager.token_handler["refresh_token"](db_key(user_id, client_id)), + based_on=tok.id + ) + + tok.register_usage() + + assert tok.max_usage_reached() is True + + # A bit later a refresh token is used to mint a new access token + + REFRESH_TOKEN_REQ = RefreshAccessTokenRequest( + grant_type="refresh_token", + client_id="client_1", + client_secret="hemligt", + refresh_token=refresh_token.value, + scope=["openid", "mail", "offline_access"] + ) + + grant, reftok = self.session_manager.find_grant(user_id, + REFRESH_TOKEN_REQ['client_id'], + REFRESH_TOKEN_REQ['refresh_token']) + + # Can I use this token to mint another token ? + assert reftok.supports_minting("access_token") + assert reftok.is_active() + assert grant.is_active() + + user_claims = self.endpoint_context.userinfo(user_id, client_id=TOKEN_REQ["client_id"], + user_info_claims=grant.claims) + + access_token_2 = grant.mint_token( + 'access_token', + value=self.session_manager.token_handler["access_token"]( + db_key(user_id, client_id), + sinfo=client_info, + client_id=TOKEN_REQ['client_id'], + aud=grant.resources, + uinfo=user_claims + ), + expires_at=time_sans_frac() + 900, # 15 minutes from now + based_on=reftok.id # Means the refresh token (reftok) was used to mint this token + ) + + assert access_token_2.is_active() From acbe78004c3951ce9bf3284c2a8008d38ca3a8f5 Mon Sep 17 00:00:00 2001 From: Roland Hedberg Date: Tue, 3 Nov 2020 11:30:00 +0100 Subject: [PATCH 007/127] Fixed tests. --- src/oidcendpoint/__init__.py | 2 +- src/oidcendpoint/jwt_token.py | 4 +- tests/test_70_grant.py | 13 ++- ...identity_db.py => test_71_sess_mngm_db.py} | 6 +- tests/test_72_session_life.py | 103 +++++++++--------- 5 files changed, 65 insertions(+), 63 deletions(-) rename tests/{test_71_identity_db.py => test_71_sess_mngm_db.py} (93%) diff --git a/src/oidcendpoint/__init__.py b/src/oidcendpoint/__init__.py index ef93b0a..2bb7941 100755 --- a/src/oidcendpoint/__init__.py +++ b/src/oidcendpoint/__init__.py @@ -1,7 +1,7 @@ import string from secrets import choice -__version__ = "1.1.2" +__version__ = '2.0.0' DEF_SIGN_ALG = { "id_token": "RS256", diff --git a/src/oidcendpoint/jwt_token.py b/src/oidcendpoint/jwt_token.py index 28b39b1..a33d765 100644 --- a/src/oidcendpoint/jwt_token.py +++ b/src/oidcendpoint/jwt_token.py @@ -71,7 +71,7 @@ def __call__( :return: Signed JSON Web Token """ - payload = {"sid": sid, "ttype": self.type, "sub": kwargs["sinfo"]["sub"]} + payload = {"sid": sid, "ttype": self.type, "sub": kwargs['sub']} _user_claims = kwargs.get('user_claims') _client_id = kwargs.get('client_id') @@ -93,7 +93,7 @@ def __call__( if client_claims: self.do_add_claims(payload, _user_claims, client_claims) - payload.update(kwargs) + # payload.update(kwargs) signer = JWT( key_jar=self.key_jar, iss=self.issuer, diff --git a/tests/test_70_grant.py b/tests/test_70_grant.py index 4635541..e3bce23 100644 --- a/tests/test_70_grant.py +++ b/tests/test_70_grant.py @@ -1,12 +1,13 @@ +from oidcendpoint.grant import AuthorizationCode +from oidcendpoint.grant import find_token from oidcendpoint.grant import Grant from oidcendpoint.grant import Token -from oidcendpoint.grant import find_token def test_access_code(): - token = Token('access_code', value="ABCD") + token = AuthorizationCode('authorization_code', value="ABCD") assert token.issued_at - assert token.type == "access_code" + assert token.type == "authorization_code" assert token.value == "ABCD" token.register_usage() @@ -15,8 +16,8 @@ def test_access_code(): def test_access_token(): - code = Token('access_code', value="ABCD") - token = Token('access_token', value="1234", based_on=code.id) + code = AuthorizationCode('authorization_code', value="ABCD") + token = Token('access_token', value="1234", based_on=code.id, usage_rules={"max_usage": 2}) assert token.issued_at assert token.type == "access_token" assert token.value == "1234" @@ -25,7 +26,6 @@ def test_access_token(): # max_usage - undefined assert token.max_usage_reached() is False - token.max_usage = 2 token.register_usage() assert token.max_usage_reached() is True @@ -46,6 +46,7 @@ def test_grant(): assert access_token.revoked is True assert refresh_token.revoked is True + def test_grant_revoked_based_on(): grant = Grant() code = grant.mint_token("authorization_code", value="ABCD") diff --git a/tests/test_71_identity_db.py b/tests/test_71_sess_mngm_db.py similarity index 93% rename from tests/test_71_identity_db.py rename to tests/test_71_sess_mngm_db.py index 4f1a15f..be93cb0 100644 --- a/tests/test_71_identity_db.py +++ b/tests/test_71_sess_mngm_db.py @@ -31,7 +31,7 @@ def test_client_info(self): self.db.set(['diana', "client_1"], client_info) user_info = self.db.get(['diana']) - assert user_info['client_id'] == ['client_1'] + assert user_info['subordinate'] == ['client_1'] client_info = self.db.get(['diana', "client_1"]) assert client_info['sid'] == "abcdef" @@ -43,9 +43,9 @@ def test_jump_ahead(self): self.db.set(['diana', "client_1", "G1"], grant) user_info = self.db.get(['diana']) - assert user_info['client_id'] == ['client_1'] + assert user_info['subordinate'] == ['client_1'] client_info = self.db.get(['diana', "client_1"]) - assert client_info['grant_id'] == ["G1"] + assert client_info['subordinate'] == ["G1"] grant_info = self.db.get(['diana', 'client_1', 'G1']) assert grant_info.issued_at assert len(grant_info.issued_token) == 1 diff --git a/tests/test_72_session_life.py b/tests/test_72_session_life.py index dcfa961..42f6ebf 100644 --- a/tests/test_72_session_life.py +++ b/tests/test_72_session_life.py @@ -13,17 +13,17 @@ from oidcendpoint.endpoint_context import EndpointContext from oidcendpoint.grant import Grant from oidcendpoint.id_token import IDToken -from oidcendpoint.session_management import ClientInfo -from oidcendpoint.session_management import SessionManager -from oidcendpoint.session_management import UserInfo -from oidcendpoint.session_management import db_key -from oidcendpoint.session_management import public_id -from oidcendpoint.session_management import unpack_db_key from oidcendpoint.oidc.authorization import Authorization from oidcendpoint.oidc.provider_config import ProviderConfiguration from oidcendpoint.oidc.registration import Registration from oidcendpoint.oidc.session import Session from oidcendpoint.oidc.token import AccessToken +from oidcendpoint.session_management import ClientInfo +from oidcendpoint.session_management import db_key +from oidcendpoint.session_management import public_id +from oidcendpoint.session_management import SessionManager +from oidcendpoint.session_management import unpack_db_key +from oidcendpoint.session_management import UserInfo from oidcendpoint.token_handler import DefaultToken from oidcendpoint.token_handler import TokenHandler from oidcendpoint.user_authn.authn_context import INTERNETPROTOCOLPASSWORD @@ -40,16 +40,16 @@ def setup_token_handler(self): code_handler = DefaultToken(password, typ="A", lifetime=grant_expires_in) access_token_handler = DefaultToken( password, typ="T", lifetime=token_expires_in - ) + ) refresh_token_handler = DefaultToken( password, typ="R", lifetime=refresh_token_expires_in - ) + ) handler = TokenHandler( code_handler=code_handler, access_token_handler=access_token_handler, refresh_token_handler=refresh_token_handler, - ) + ) self.session_manager = SessionManager(handler) @@ -62,7 +62,7 @@ def auth(self): scope=["openid", "mail", "address", "offline_access"], state="STATE", response_type="code", - ) + ) # The authentication returns a user ID user_id = "diana" @@ -79,13 +79,13 @@ def auth(self): salt, authn_info=INTERNETPROTOCOLPASSWORD, authn_time=time_sans_frac(), - ) + ) client_info = ClientInfo( authorization_request=AUTH_REQ, authenticationEvent=authn_event, sub=public_id(user_id, salt) - ) + ) self.session_manager.set([user_id, AUTH_REQ['client_id']], client_info) # The user consent module produces a Grant instance @@ -102,7 +102,7 @@ def auth(self): 'authorization_code', value=self.session_manager.token_handler["code"](user_id), expires_at=time_sans_frac() + 300 # 5 minutes from now - ) + ) return code @@ -119,7 +119,7 @@ def test_code_flow(self): grant_type="authorization_code", client_secret="hemligt", code=code.value - ) + ) # parse the token user_id = self.session_manager.token_handler.sid(TOKEN_REQ['code']) @@ -144,7 +144,7 @@ def test_code_flow(self): value=self.session_manager.token_handler["access_token"](user_id), expires_at=time_sans_frac() + 900, # 15 minutes from now based_on=tok.id # Means the token (tok) was used to mint this token - ) + ) assert tok.supports_minting("refresh_token") @@ -152,7 +152,7 @@ def test_code_flow(self): 'refresh_token', value=self.session_manager.token_handler["refresh_token"](user_id), based_on=tok.id - ) + ) tok.register_usage() @@ -166,7 +166,7 @@ def test_code_flow(self): client_secret="hemligt", refresh_token=refresh_token.value, scope=["openid", "mail", "offline_access"] - ) + ) grant, reftok = self.session_manager.find_grant(user_id, REFRESH_TOKEN_REQ['client_id'], @@ -179,7 +179,7 @@ def test_code_flow(self): value=self.session_manager.token_handler["access_token"](user_id), expires_at=time_sans_frac() + 900, # 15 minutes from now based_on=reftok.id # Means the token (tok) was used to mint this token - ) + ) assert access_token_2.is_active() @@ -187,7 +187,7 @@ def test_code_flow(self): KEYDEFS = [ {"type": "RSA", "key": "", "use": ["sig"]}, {"type": "EC", "crv": "P-256", "use": ["sig"]}, -] + ] ISSUER = "https://example.com/" @@ -202,7 +202,7 @@ def test_code_flow(self): ["id_token", "token"], ["code", "token", "id_token"], ["none"], -] + ] CAPABILITIES = { "response_types_supported": [" ".join(x) for x in RESPONSE_TYPES_SUPPORTED], "token_endpoint_auth_methods_supported": [ @@ -210,19 +210,19 @@ def test_code_flow(self): "client_secret_basic", "client_secret_jwt", "private_key_jwt", - ], + ], "response_modes_supported": ["query", "fragment", "form_post"], "subject_types_supported": ["public", "pairwise"], "grant_types_supported": [ "authorization_code", "implicit", "urn:ietf:params:oauth:grant-type:jwt-bearer", - ], + ], "claim_types_supported": ["normal", "aggregated", "distributed"], "claims_parameter_supported": True, "request_parameter_supported": True, "request_uri_parameter_supported": True, -} + } BASEDIR = os.path.abspath(os.path.dirname(__file__)) @@ -249,8 +249,8 @@ def setup_session_manager(self): "key_defs": [ {"type": "oct", "bytes": "24", "use": ["enc"], "kid": "code"}, {"type": "oct", "bytes": "24", "use": ["enc"], "kid": "refresh"} - ], - }, + ], + }, "code": {"lifetime": 600}, "token": { "class": "oidcendpoint.jwt_token.JWTToken", @@ -261,47 +261,47 @@ def setup_session_manager(self): "email_verified", "phone_number", "phone_number_verified", - ], + ], "add_claim_by_scope": True, "aud": ["https://example.org/appl"], + }, }, - }, "refresh": {}, - }, + }, "endpoint": { "provider_config": { "path": "{}/.well-known/openid-configuration", "class": ProviderConfiguration, "kwargs": {}, - }, + }, "registration": { "path": "{}/registration", "class": Registration, "kwargs": {}, - }, + }, "authorization": { "path": "{}/authorization", "class": Authorization, "kwargs": {}, - }, + }, "token": {"path": "{}/token", "class": AccessToken, "kwargs": {}}, "session": {"path": "{}/end_session", "class": Session}, - }, + }, "client_authn": verify_client, "authentication": { "anon": { "acr": INTERNETPROTOCOLPASSWORD, "class": "oidcendpoint.user_authn.user.NoAuthn", "kwargs": {"user": "diana"}, - } - }, + } + }, "template_dir": "template", "userinfo": { "class": user_info.UserInfo, "kwargs": {"db_file": full_path("users.json")}, - }, + }, "id_token": {"class": IDToken}, - } + } self.endpoint_context = EndpointContext(conf, keyjar=KEYJAR) self.session_manager = SessionManager(self.endpoint_context.sdb.handler) @@ -315,7 +315,7 @@ def auth(self): scope=["openid", "mail", "address", "offline_access"], state="STATE", response_type="code", - ) + ) # The authentication returns a user ID user_id = "diana" @@ -332,13 +332,13 @@ def auth(self): salt, authn_info=INTERNETPROTOCOLPASSWORD, authn_time=time_sans_frac(), - ) + ) client_info = ClientInfo( authorization_request=AUTH_REQ, authenticationEvent=authn_event, sub=public_id(user_id, salt) - ) + ) self.session_manager.set([user_id, AUTH_REQ['client_id']], client_info) # The user consent module produces a Grant instance @@ -356,7 +356,7 @@ def auth(self): value=self.session_manager.token_handler["code"]( db_key(user_id, AUTH_REQ['client_id'])), expires_at=time_sans_frac() + 300 # 5 minutes from now - ) + ) return code @@ -373,7 +373,7 @@ def test_code_flow(self): grant_type="authorization_code", client_secret="hemligt", code=code.value - ) + ) # parse the token session_id = self.session_manager.token_handler.sid(TOKEN_REQ['code']) @@ -405,14 +405,15 @@ def test_code_flow(self): 'access_token', value=self.session_manager.token_handler["access_token"]( db_key(user_id, client_id), - sinfo=client_info, client_id=TOKEN_REQ['client_id'], aud=grant.resources, - uinfo=user_claims - ), + user_claims=user_claims, + scope=grant.scope, + sub=client_info['sub'] + ), expires_at=time_sans_frac() + 900, # 15 minutes from now based_on=tok.id # Means the token (tok) was used to mint this token - ) + ) assert tok.supports_minting("refresh_token") @@ -420,7 +421,7 @@ def test_code_flow(self): 'refresh_token', value=self.session_manager.token_handler["refresh_token"](db_key(user_id, client_id)), based_on=tok.id - ) + ) tok.register_usage() @@ -434,7 +435,7 @@ def test_code_flow(self): client_secret="hemligt", refresh_token=refresh_token.value, scope=["openid", "mail", "offline_access"] - ) + ) grant, reftok = self.session_manager.find_grant(user_id, REFRESH_TOKEN_REQ['client_id'], @@ -452,13 +453,13 @@ def test_code_flow(self): 'access_token', value=self.session_manager.token_handler["access_token"]( db_key(user_id, client_id), - sinfo=client_info, + sub=client_info['sub'], client_id=TOKEN_REQ['client_id'], aud=grant.resources, - uinfo=user_claims - ), + user_claims=user_claims + ), expires_at=time_sans_frac() + 900, # 15 minutes from now based_on=reftok.id # Means the refresh token (reftok) was used to mint this token - ) + ) assert access_token_2.is_active() From 35e667e299502446563239f6e23eb940cc8aa85f Mon Sep 17 00:00:00 2001 From: roland Date: Fri, 6 Nov 2020 14:20:34 +0100 Subject: [PATCH 008/127] Refactoring and adding more usecases. --- src/oidcendpoint/grant.py | 43 +++++-- src/oidcendpoint/session_management.py | 158 ++++++++++++++++++++++--- src/oidcendpoint/user_authn/user.py | 3 +- src/oidcendpoint/userinfo.py | 2 +- tests/test_70_grant.py | 12 +- tests/test_71_sess_mngm_db.py | 26 ++-- tests/test_72_session_life.py | 65 +++++----- 7 files changed, 233 insertions(+), 76 deletions(-) diff --git a/src/oidcendpoint/grant.py b/src/oidcendpoint/grant.py index 83e2f0a..9a20c83 100644 --- a/src/oidcendpoint/grant.py +++ b/src/oidcendpoint/grant.py @@ -3,13 +3,17 @@ from typing import Optional from uuid import uuid1 +from oidcmsg.message import Message from oidcmsg.message import OPTIONAL_LIST_OF_SP_SEP_STRINGS from oidcmsg.message import OPTIONAL_LIST_OF_STRINGS from oidcmsg.message import SINGLE_OPTIONAL_JSON -from oidcmsg.message import Message from oidcmsg.time_util import utc_time_sans_frac +class MintingNotAllowed(Exception): + pass + + class GrantMessage(Message): c_param = { "scope": OPTIONAL_LIST_OF_SP_SEP_STRINGS, # As defined in RFC6749 @@ -126,7 +130,11 @@ def from_json(self, json_str): return self def supports_minting(self, token_type): - return token_type in self.usage_rules['supports_minting'] + _supports_minting = self.usage_rules.get("supports_minting") + if _supports_minting is None: + return False + else: + return token_type in _supports_minting class AuthorizationCode(Token): @@ -152,8 +160,8 @@ def set_defaults(self): class Grant(Item): def __init__(self, - scopes: Optional[list] = None, - claims: Optional[dict] = None, + scope: Optional[list] = None, + claim: Optional[dict] = None, resources: Optional[list] = None, authorization_details: Optional[dict] = None, token: Optional[list] = None, @@ -161,9 +169,9 @@ def __init__(self, issued_at: int = 0, expires_at: int = 0): Item.__init__(self, usage_rules=usage_rules, issued_at=issued_at, expires_at=expires_at) - self.scope = scopes or [] + self.scope = scope or [] self.authorization_details = authorization_details or None - self.claims = claims or None + self.claims = claim or None self.resources = resources or [] self.issued_token = token or [] self.id = uuid1().hex @@ -178,7 +186,7 @@ def replace(self, item: dict): for attr in ['scope', 'authorization_details', 'claims', 'resources']: setattr(self, attr, item.get(attr)) - def revoke(self): + def revoke_all(self): self.revoked = True for t in self.issued_token: t.revoked = True @@ -206,16 +214,24 @@ def to_json(self): def from_json(self, json_str): d = json.loads(json_str) for attr in ["scope", "authorization_details", "claims", "resources", "issued_at", - "not_before", - "expires_at", "revoked", "id"]: + "not_before", "expires_at", "revoked", "id"]: if attr in d: setattr(self, attr, d[attr]) if "issued_token" in d: setattr(self, "issued_token", [Token(**t) for t in d['issued_token']]) - def mint_token(self, token_type, **kwargs): - item = TOKEN_MAP[token_type](typ=token_type, **kwargs) + def mint_token(self, token_type: str, based_on: Optional[Token] = None, **kwargs) -> Token: + if based_on: + if based_on.supports_minting(token_type) and based_on.is_active(): + _base_on_ref = based_on.value + else: + raise MintingNotAllowed() + else: + _base_on_ref = None + + item = TOKEN_MAP[token_type](typ=token_type, based_on=_base_on_ref, **kwargs) self.issued_token.append(item) + return item def revoke_all_based_on(self, id): @@ -229,3 +245,8 @@ def get_token(self, val): if t.value == val: return t return None + + def revoke_token(self, val): + for t in self.issued_token: + if t.value == val: + t.revoked = True diff --git a/src/oidcendpoint/session_management.py b/src/oidcendpoint/session_management.py index 44979d6..cc65f2f 100644 --- a/src/oidcendpoint/session_management.py +++ b/src/oidcendpoint/session_management.py @@ -1,9 +1,15 @@ import hashlib import logging +from oidcendpoint import rndstr + logger = logging.getLogger(__name__) +class Revoked(Exception): + pass + + def db_key(*args): return ':'.join(args) @@ -12,7 +18,7 @@ def unpack_db_key(key): return key.split(':') -def pairwise_id(uid, sector_identifier, salt, **kwargs): +def pairwise_id(uid, sector_identifier, salt="", **kwargs): return hashlib.sha256(("%s%s%s" % (uid, sector_identifier, salt)).encode("utf-8")).hexdigest() @@ -20,9 +26,10 @@ def public_id(uid, salt="", **kwargs): return hashlib.sha256("{}{}".format(uid, salt).encode("utf-8")).hexdigest() -class Info(object): +class SessionInfo(object): def __init__(self, **kwargs): self._db = kwargs or {} + self._revoked = False if "subordinate" not in self._db: self._db["subordinate"] = [] @@ -59,12 +66,21 @@ def values(self): def items(self): return self._db.items() + def __contains__(self, item): + return item in self._db + + def revoke(self): + self._revoked = True -class UserInfo(Info): + def is_revoked(self): + return self._revoked + + +class UserSessionInfo(SessionInfo): pass -class ClientInfo(Info): +class ClientSessionInfo(SessionInfo): def find_grant(self, val): for grant in self._db["subordinate"]: token = grant.get_token(val) @@ -103,6 +119,8 @@ def set(self, path: list, value: object): if client_id in _userinfo['subordinate']: _cid_key = db_key(uid, client_id) _cid_info = self._db[db_key(uid, client_id)] + if _cid_info.is_revoked(): + raise Revoked("Session is revoked") if _cid_info: if grant_id: _gid_key = db_key(uid, client_id, grant_id) @@ -115,23 +133,23 @@ def set(self, path: list, value: object): self._db[_cid_key] = _cid_info.add_subordinate(grant_id) self._db[_gid_key] = value else: - self._db.set[_cid_key] = value + self._db[_cid_key] = value else: _userinfo.add_subordinate(client_id) if grant_id: - _cid_info = ClientInfo() + _cid_info = ClientSessionInfo() _cid_info.add_subordinate(grant_id) self._db[_cid_key] = _cid_info self._db[db_key(uid, client_id, grant_id)] = value else: - _cid_info = ClientInfo() + _cid_info = ClientSessionInfo() self._db[_cid_key] = _cid_info self._db[uid] = _userinfo else: _userinfo.add_subordinate(client_id) self._db[uid] = _userinfo if grant_id: - _cid_info = ClientInfo() + _cid_info = ClientSessionInfo() _cid_info.add_subordinate(grant_id) self._db[db_key(uid, client_id, grant_id)] = value else: @@ -143,10 +161,10 @@ def set(self, path: list, value: object): self._db[uid] = value else: if client_id: - _user_info = UserInfo() + _user_info = UserSessionInfo() _user_info.add_subordinate(client_id) if grant_id: - _cid_info = ClientInfo() + _cid_info = ClientSessionInfo() _cid_info.add_subordinate(grant_id) self._db[db_key(uid, client_id, grant_id)] = value else: @@ -175,6 +193,9 @@ def get(self, path: list): except KeyError: return {} else: + if client_session_info.is_revoked(): + raise Revoked("Session is revoked") + if grant_id is None: return client_session_info @@ -194,32 +215,56 @@ def delete(self, path): pass else: if client_id: - if client_id in _dic['client_id']: + if client_id in _dic['subordinate']: try: _cinfo = self._db[db_key(uid, client_id)] except KeyError: pass else: if grant_id: - if grant_id in _cinfo['grant_id']: + if grant_id in _cinfo['subordinate']: self._db.__delitem__(db_key(uid, client_id, grant_id)) else: + for grant_id in _cinfo['subordinate']: + self._db.__delitem__(db_key(uid, client_id, grant_id)) self._db.__delitem__(db_key(uid, client_id)) + + _dic["subordinate"].remove(client_id) + self._db[uid] = _dic else: pass else: self._db.__delitem__(uid) + def update(self, path, new_info): + _info = self.get(path) + _info.update(new_info) + self.set(path, _info) + class SessionManager(Database): - def __init__(self, handler, storage=None): - Database.__init__(self, storage) + def __init__(self, db, handler, userinfo=None, sub_func=None): + Database.__init__(self, db) self.token_handler = handler + self.userinfo = userinfo + self.salt = rndstr(32) + + # this allows the subject identifier minters to be defined by someone + # else then me. + if sub_func is None: + self.sub_func = {"public": public_id, "pairwise": pairwise_id} + else: + self.sub_func = sub_func + if "public" not in sub_func: + self.sub_func["public"] = public_id + if "pairwise" not in sub_func: + self.sub_func["pairwise"] = pairwise_id - def get_user(self, uid): - user = self.get(uid) + def get_user_info(self, uid): + return self.get(uid) - def find_grant(self, user_id, client_id, token_value): + def find_grant(self, session_id, token_value): + user_id, client_id = unpack_db_key(session_id) client_info = self.get([user_id, client_id]) for grant_id in client_info["subordinate"]: grant = self.get([user_id, client_id, grant_id]) @@ -228,3 +273,82 @@ def find_grant(self, user_id, client_id, token_value): return grant, token return None + + def create_session(self, authn_event, auth_req, user_id, client_id="", + sub_type="public", sector_identifier='', **kwargs): + """ + + :param authn_event: + :param auth_req: Authorization Request + :param client_id: Client ID + :param user_id: User ID + :param kwargs: extra keyword arguments + :return: + """ + + try: + _ = self.get([user_id]) + except KeyError: + user_info = UserSessionInfo(authentication_event=authn_event) + self.set([user_id], user_info) + + client_info = ClientSessionInfo( + authorization_request=auth_req, + sub=self.sub_func[sub_type](user_id, salt=self.salt, + sector_identifier=sector_identifier) + ) + + if not client_id: + client_id = auth_req['client_id'] + + self.set([user_id, client_id], client_info) + + def _update_client_info(self, session_id, new_information): + _path = unpack_db_key(session_id) + _client_info = self.get(_path) + _client_info.update(new_information) + self.set(_path, _client_info) + + def do_sub(self, session_id, sector_id="", subject_type="public"): + """ + Create and store a subject identifier + + :param session_id: Session ID + :param sector_id: For pairwise identifiers, an Identifier for the RP group + :param subject_type: 'pairwise'/'public' + :return: + """ + _path = unpack_db_key(session_id) + sub = self.sub_func[subject_type](_path[0], salt=self.salt, sector_identifier=sector_id) + self._update_client_info(session_id, {'sub': sub}) + return sub + + def __getitem__(self, item): + return self.get(unpack_db_key(item)) + + def revoke_token(self, session_id, token_value): + grant, token = self.find_grant(session_id, token_value) + token.revoked = True + + def get_sids_by_user_id(self, user_id): + user_info = self.get([user_id]) + return [db_key(user_id, c) for c in user_info['subordinate']] + + def get_authentication_event(self, user_id): + try: + user_info = self.get([user_id]) + except KeyError: + return None + + return user_info["authentication_event"] + + def revoke_session(self, session_id): + _path = unpack_db_key(session_id) + _info = self.get(_path) + _info.revoke() + self.set(_path, _info) + + def grants(self, session_id): + uid, cid = unpack_db_key(session_id) + _csi = self.get([uid, cid]) + return [self.get([uid, cid, gid]) for gid in _csi['subordinate']] diff --git a/src/oidcendpoint/user_authn/user.py b/src/oidcendpoint/user_authn/user.py index e197ef6..dce3116 100755 --- a/src/oidcendpoint/user_authn/user.py +++ b/src/oidcendpoint/user_authn/user.py @@ -4,13 +4,14 @@ import logging import sys import time -import warnings from urllib.parse import unquote +import warnings from cryptography.hazmat.primitives.ciphers.aead import AESGCM from cryptojwt.jwt import JWT from oidcendpoint import sanitize +from oidcendpoint.authn_event import create_authn_event from oidcendpoint.exception import FailedAuthentication from oidcendpoint.exception import ImproperlyConfigured from oidcendpoint.exception import InvalidCookieSign diff --git a/src/oidcendpoint/userinfo.py b/src/oidcendpoint/userinfo.py index bd2fa01..33b7160 100755 --- a/src/oidcendpoint/userinfo.py +++ b/src/oidcendpoint/userinfo.py @@ -124,7 +124,7 @@ def collect_user_info( :param userinfo_claims: user info claims :return: User info """ - authn_req = session["authn_req"] + authn_req = session["authorization_request"] if scope_to_claims is None: scope_to_claims = endpoint_context.scope2claims diff --git a/tests/test_70_grant.py b/tests/test_70_grant.py index e3bce23..1d75043 100644 --- a/tests/test_70_grant.py +++ b/tests/test_70_grant.py @@ -39,9 +39,9 @@ def test_access_token(): def test_grant(): grant = Grant() code = grant.mint_token("authorization_code", value="ABCD") - access_token = grant.mint_token("access_token", value="1234", based_on=code.id) - refresh_token = grant.mint_token("refresh_token", value="1234", based_on=code.id) - grant.revoke() + access_token = grant.mint_token("access_token", value="1234", based_on=code) + refresh_token = grant.mint_token("refresh_token", value="1234", based_on=code) + grant.revoke_all() assert code.revoked is True assert access_token.revoked is True assert refresh_token.revoked is True @@ -50,12 +50,12 @@ def test_grant(): def test_grant_revoked_based_on(): grant = Grant() code = grant.mint_token("authorization_code", value="ABCD") - access_token = grant.mint_token("access_token", value="1234", based_on=code.id) - refresh_token = grant.mint_token("refresh_token", value="1234", based_on=code.id) + access_token = grant.mint_token("access_token", value="1234", based_on=code) + refresh_token = grant.mint_token("refresh_token", value="1234", based_on=code) code.register_usage() if code.max_usage_reached(): - grant.revoke_all_based_on(code.id) + grant.revoke_all_based_on(code.value) assert code.is_active() is False assert access_token.is_active() is False diff --git a/tests/test_71_sess_mngm_db.py b/tests/test_71_sess_mngm_db.py index be93cb0..8f88d06 100644 --- a/tests/test_71_sess_mngm_db.py +++ b/tests/test_71_sess_mngm_db.py @@ -4,9 +4,9 @@ from oidcendpoint.authn_event import create_authn_event from oidcendpoint.grant import Grant from oidcendpoint.grant import Token -from oidcendpoint.session_management import ClientInfo +from oidcendpoint.session_management import ClientSessionInfo from oidcendpoint.session_management import Database -from oidcendpoint.session_management import UserInfo +from oidcendpoint.session_management import UserSessionInfo from oidcendpoint.session_management import public_id @@ -19,15 +19,15 @@ def test_user_info(self): with pytest.raises(KeyError): self.db.get(['diana']) - user_info = UserInfo(foo="bar") + user_info = UserSessionInfo(foo="bar") self.db.set(['diana'], user_info) user_info = self.db.get(['diana']) assert user_info["foo"] == "bar" def test_client_info(self): - user_info = UserInfo(foo="bar") + user_info = UserSessionInfo(foo="bar") self.db.set(['diana'], user_info) - client_info = ClientInfo(sid= "abcdef") + client_info = ClientSessionInfo(sid="abcdef") self.db.set(['diana', "client_1"], client_info) user_info = self.db.get(['diana']) @@ -56,12 +56,24 @@ def test_jump_ahead(self): def test_step_wise(self): salt = "natriumklorid" # store user info - self.db.set(['diana'], UserInfo(authn_event = create_authn_event('diana', salt))) + self.db.set(['diana'], + UserSessionInfo(authentication_event=create_authn_event('diana', salt))) # Client specific information - self.db.set(['diana', 'client_1'], ClientInfo(sub= public_id('diana', salt))) + self.db.set(['diana', 'client_1'], ClientSessionInfo(sub=public_id( + 'diana', salt))) # Grant grant = Grant() access_code = Token('access_code', value='1234567890') grant.issued_token.append(access_code) self.db.set(['diana', 'client_1', 'G1'], grant) + + def test_removed(self): + grant = Grant() + access_code = Token('access_code', value='1234567890') + grant.issued_token.append(access_code) + + self.db.set(['diana', "client_1", "G1"], grant) + self.db.delete(['diana', 'client_1']) + with pytest.raises(ValueError): + self.db.get(['diana', "client_1", "G1"]) diff --git a/tests/test_72_session_life.py b/tests/test_72_session_life.py index 42f6ebf..d93cf4e 100644 --- a/tests/test_72_session_life.py +++ b/tests/test_72_session_life.py @@ -18,12 +18,12 @@ from oidcendpoint.oidc.registration import Registration from oidcendpoint.oidc.session import Session from oidcendpoint.oidc.token import AccessToken -from oidcendpoint.session_management import ClientInfo +from oidcendpoint.session_management import ClientSessionInfo from oidcendpoint.session_management import db_key from oidcendpoint.session_management import public_id from oidcendpoint.session_management import SessionManager from oidcendpoint.session_management import unpack_db_key -from oidcendpoint.session_management import UserInfo +from oidcendpoint.session_management import UserSessionInfo from oidcendpoint.token_handler import DefaultToken from oidcendpoint.token_handler import TokenHandler from oidcendpoint.user_authn.authn_context import INTERNETPROTOCOLPASSWORD @@ -51,7 +51,7 @@ def setup_token_handler(self): refresh_token_handler=refresh_token_handler, ) - self.session_manager = SessionManager(handler) + self.session_manager = SessionManager({}, handler=handler) def auth(self): # Start with an authentication request @@ -68,29 +68,27 @@ def auth(self): user_id = "diana" # User info is stored in the Session DB - - user_info = UserInfo() - self.session_manager.set([user_id], user_info) - - # Now for client session information - salt = "natriumklorid" authn_event = create_authn_event( user_id, - salt, + self.session_manager.salt, authn_info=INTERNETPROTOCOLPASSWORD, authn_time=time_sans_frac(), ) - client_info = ClientInfo( + user_info = UserSessionInfo(authenticationEvent=authn_event) + self.session_manager.set([user_id], user_info) + + # Now for client session information + + client_info = ClientSessionInfo( authorization_request=AUTH_REQ, - authenticationEvent=authn_event, - sub=public_id(user_id, salt) + sub=public_id(user_id, self.session_manager.salt) ) self.session_manager.set([user_id, AUTH_REQ['client_id']], client_info) # The user consent module produces a Grant instance - grant = Grant(scopes=AUTH_REQ['scope'], resources=[AUTH_REQ['client_id']]) + grant = Grant(scope=AUTH_REQ['scope'], resources=[AUTH_REQ['client_id']]) # the grant is assigned to a session (user_id, client_id) @@ -128,7 +126,8 @@ def test_code_flow(self): # token I can easily find the grant # client_info = self.session_manager.get([user_id, TOKEN_REQ['client_id']]) - grant, tok = self.session_manager.find_grant(user_id, TOKEN_REQ['client_id'], + session_id = db_key(user_id, TOKEN_REQ['client_id']) + grant, tok = self.session_manager.find_grant(session_id, TOKEN_REQ['code']) # Verify that it's of the correct type and can be used @@ -143,7 +142,7 @@ def test_code_flow(self): 'access_token', value=self.session_manager.token_handler["access_token"](user_id), expires_at=time_sans_frac() + 900, # 15 minutes from now - based_on=tok.id # Means the token (tok) was used to mint this token + based_on=tok # Means the token (tok) was used to mint this token ) assert tok.supports_minting("refresh_token") @@ -151,7 +150,7 @@ def test_code_flow(self): refresh_token = grant.mint_token( 'refresh_token', value=self.session_manager.token_handler["refresh_token"](user_id), - based_on=tok.id + based_on=tok ) tok.register_usage() @@ -168,8 +167,8 @@ def test_code_flow(self): scope=["openid", "mail", "offline_access"] ) - grant, reftok = self.session_manager.find_grant(user_id, - REFRESH_TOKEN_REQ['client_id'], + session_id = db_key(user_id,REFRESH_TOKEN_REQ['client_id']) + grant, reftok = self.session_manager.find_grant(session_id, REFRESH_TOKEN_REQ['refresh_token']) assert reftok.supports_minting("access_token") @@ -178,7 +177,7 @@ def test_code_flow(self): 'access_token', value=self.session_manager.token_handler["access_token"](user_id), expires_at=time_sans_frac() + 900, # 15 minutes from now - based_on=reftok.id # Means the token (tok) was used to mint this token + based_on=reftok # Means the token (tok) was used to mint this token ) assert access_token_2.is_active() @@ -304,7 +303,7 @@ def setup_session_manager(self): } self.endpoint_context = EndpointContext(conf, keyjar=KEYJAR) - self.session_manager = SessionManager(self.endpoint_context.sdb.handler) + self.session_manager = SessionManager({}, handler=self.endpoint_context.sdb.handler) def auth(self): # Start with an authentication request @@ -322,7 +321,7 @@ def auth(self): # User info is stored in the Session DB - user_info = UserInfo() + user_info = UserSessionInfo() self.session_manager.set([user_id], user_info) # Now for client session information @@ -334,7 +333,7 @@ def auth(self): authn_time=time_sans_frac(), ) - client_info = ClientInfo( + client_info = ClientSessionInfo( authorization_request=AUTH_REQ, authenticationEvent=authn_event, sub=public_id(user_id, salt) @@ -343,7 +342,7 @@ def auth(self): # The user consent module produces a Grant instance - grant = Grant(scopes=AUTH_REQ['scope'], resources=[AUTH_REQ['client_id']]) + grant = Grant(scope=AUTH_REQ['scope'], resources=[AUTH_REQ['client_id']]) # the grant is assigned to a session (user_id, client_id) @@ -383,7 +382,8 @@ def test_code_flow(self): # token I can easily find the grant # client_info = self.session_manager.get([user_id, TOKEN_REQ['client_id']]) - grant, tok = self.session_manager.find_grant(user_id, TOKEN_REQ['client_id'], + session_id = db_key(user_id, TOKEN_REQ['client_id']) + grant, tok = self.session_manager.find_grant(session_id, TOKEN_REQ['code']) # Verify that it's of the correct type and can be used @@ -412,15 +412,16 @@ def test_code_flow(self): sub=client_info['sub'] ), expires_at=time_sans_frac() + 900, # 15 minutes from now - based_on=tok.id # Means the token (tok) was used to mint this token + based_on=tok # Means the token (tok) was used to mint this token ) - assert tok.supports_minting("refresh_token") + # this test is include in the mint_token methods + # assert tok.supports_minting("refresh_token") refresh_token = grant.mint_token( 'refresh_token', value=self.session_manager.token_handler["refresh_token"](db_key(user_id, client_id)), - based_on=tok.id + based_on=tok ) tok.register_usage() @@ -437,13 +438,11 @@ def test_code_flow(self): scope=["openid", "mail", "offline_access"] ) - grant, reftok = self.session_manager.find_grant(user_id, - REFRESH_TOKEN_REQ['client_id'], + session_id = db_key(user_id, REFRESH_TOKEN_REQ['client_id']) + grant, reftok = self.session_manager.find_grant(session_id, REFRESH_TOKEN_REQ['refresh_token']) # Can I use this token to mint another token ? - assert reftok.supports_minting("access_token") - assert reftok.is_active() assert grant.is_active() user_claims = self.endpoint_context.userinfo(user_id, client_id=TOKEN_REQ["client_id"], @@ -459,7 +458,7 @@ def test_code_flow(self): user_claims=user_claims ), expires_at=time_sans_frac() + 900, # 15 minutes from now - based_on=reftok.id # Means the refresh token (reftok) was used to mint this token + based_on=reftok # Means the refresh token (reftok) was used to mint this token ) assert access_token_2.is_active() From cbd4d32c6e83e933f9af3ef2fc638529a37685ec Mon Sep 17 00:00:00 2001 From: Roland Hedberg Date: Fri, 6 Nov 2020 14:23:24 +0100 Subject: [PATCH 009/127] Deal with commonality. --- src/oidcendpoint/jwt_token.py | 124 ++++++++++++++++++++-------------- 1 file changed, 75 insertions(+), 49 deletions(-) diff --git a/src/oidcendpoint/jwt_token.py b/src/oidcendpoint/jwt_token.py index a33d765..4cc494c 100644 --- a/src/oidcendpoint/jwt_token.py +++ b/src/oidcendpoint/jwt_token.py @@ -1,38 +1,91 @@ +from typing import Optional + from cryptojwt import JWT from cryptojwt.jws.exception import JWSException from oidcendpoint.exception import ToOld from oidcendpoint.scopes import convert_scopes2claims +from oidcendpoint.token_handler import is_expired from oidcendpoint.token_handler import Token from oidcendpoint.token_handler import UnknownToken -from oidcendpoint.token_handler import is_expired -class JWTToken(Token): +class ClaimsInterface: init_args = { "add_claims_by_scope": False, - "enable_claims_per_client": False, - "add_claims": {}, - } + "enable_claims_per_client": False + } + def __init__(self, endpoint_context, **kwargs): + self.endpoint_context = endpoint_context + self.scope_claims_map = kwargs.get("scope_claims_map", endpoint_context.scope2claims) + self.add_claims_by_scope = kwargs.get("add_claims_by_scope", + self.init_args["add_claims_by_scope"]) + self.enable_claims_per_client = kwargs.get("enable_claims_per_client", + self.init_args["enable_claims_per_client"]) + + def _get_client_claims(self, client_id): + if self.enable_claims_per_client: + client_info = self.endpoint_context.cdb.get(client_id, {}) + return client_info.get("introspection_claims") + else: + return [] + + def _get_user_info(self, token_info): + user_id = self.endpoint_context.sdb.sso_db.get_uid_by_sid(token_info["sid"]) + return self.endpoint_context.userinfo(user_id, client_id=None) + + def add_claims(self, client_id, user_id, payload, scopes, claims_restriction): + if claims_restriction is None: + user_info = self.endpoint_context.userinfo(user_id, client_id=None) + payload.update(user_info) + elif claims_restriction == {}: # Nothing is allowed + pass + else: + possible_claims = self._get_client_claims(client_id) + if self.add_claims_by_scope: + _claims = convert_scopes2claims(scopes, map=self.scope_claims_map).keys() + possible_claims = list(set(possible_claims).union(_claims)) + + if possible_claims: + _claims = {c: None for c in + set(possible_claims).intersection(set(claims_restriction.key()))} + _claims.update(claims_restriction) + else: + _claims = claims_restriction + + if _claims: + user_info = self.endpoint_context.userinfo(user_id, client_id=None, + user_info_claims=_claims) + for attr in _claims: + try: + payload[attr] = user_info[attr] + except KeyError: + pass + + +class JWTToken(Token): def __init__( self, typ, keyjar=None, - issuer=None, - aud=None, - alg="ES256", - lifetime=300, + issuer: str = None, + aud: Optional[list] = None, + alg: str = "ES256", + lifetime: int = 300, ec=None, - token_type="Bearer", + token_type: str = "Bearer", + add_claims: bool = False, **kwargs - ): + ): Token.__init__(self, typ, **kwargs) self.token_type = token_type self.lifetime = lifetime + self.claims_interface = ClaimsInterface(ec, **kwargs) + self.args = { - (k, v) for k, v in kwargs.items() if k not in self.init_args.keys() - } + (k, v) for k, v in kwargs.items() if k not in self.claims_interface.init_args.keys() + } self.key_jar = keyjar or ec.keyjar self.issuer = issuer or ec.issuer @@ -41,29 +94,13 @@ def __init__( self.def_aud = aud or [] self.alg = alg - self.scope_claims_map = kwargs.get("scope_claims_map", ec.scope2claims) - - self.add_claims = self.init_args["add_claims"] - self.add_claims_by_scope = self.init_args["add_claims_by_scope"] - self.enable_claims_per_client = self.init_args["enable_claims_per_client"] - - for param, default in self.init_args.items(): - setattr(self, param, kwargs.get(param, default)) - - def do_add_claims(self, payload, uinfo, claims): - for attr in claims: - if attr == "sub": - continue - try: - payload[attr] = uinfo[attr] - except KeyError: - pass + self.add_claims = add_claims def __call__( self, sid: str, **kwargs - ): + ): """ Return a token. @@ -73,25 +110,14 @@ def __call__( payload = {"sid": sid, "ttype": self.type, "sub": kwargs['sub']} - _user_claims = kwargs.get('user_claims') - _client_id = kwargs.get('client_id') _scopes = kwargs.get('scope') + _client_id = kwargs.get('client_id') + _user_id = kwargs.get('user_id') + _claims = kwargs.get('claims') if self.add_claims: - self.do_add_claims(payload, _user_claims, self.add_claims) - if self.add_claims_by_scope: - _allowed_claims = self.cntx.claims_handler.allowed_claims(_client_id, self.cntx) - self.do_add_claims( - payload, - _user_claims, - convert_scopes2claims(_scopes, _allowed_claims, map=self.scope_claims_map).keys(), - ) - # Add claims if is access token - if self.type == "T" and self.enable_claims_per_client: - client = self.cdb.get(_client_id, {}) - client_claims = client.get("access_token_claims") - if client_claims: - self.do_add_claims(payload, _user_claims, client_claims) + self.claims_interface.add_claims(_client_id, _user_id, payload, claims=_claims, + scopes=_scopes) # payload.update(kwargs) signer = JWT( @@ -99,7 +125,7 @@ def __call__( iss=self.issuer, lifetime=self.lifetime, sign_alg=self.alg, - ) + ) _aud = kwargs.get('aud') if _aud is None: @@ -132,7 +158,7 @@ def info(self, token): "type": _payload["ttype"], "exp": _payload["exp"], "handler": self, - } + } return _res def is_expired(self, token, when=0): From 4c09c3a90182359a038e55136a59e382ef11e61b Mon Sep 17 00:00:00 2001 From: roland Date: Tue, 10 Nov 2020 15:25:38 +0100 Subject: [PATCH 010/127] Getting closer --- src/oidcendpoint/authz/__init__.py | 46 +- src/oidcendpoint/authz/old.init.py | 78 ++ src/oidcendpoint/common/authorization.py | 52 +- src/oidcendpoint/endpoint_context.py | 22 +- src/oidcendpoint/grant.py | 2 +- src/oidcendpoint/id_token.py | 149 ++- src/oidcendpoint/jwt_token.py | 77 +- src/oidcendpoint/oauth2/authorization.py | 115 ++- src/oidcendpoint/oauth2/introspection.py | 23 +- src/oidcendpoint/oauth2/old_introspection.py | 117 +++ src/oidcendpoint/oidc/add_on/pkce.py | 9 +- src/oidcendpoint/oidc/authorization.py | 227 +++-- src/oidcendpoint/oidc/old_authorization.py | 740 ++++++++++++++ src/oidcendpoint/oidc/refresh_token.py | 37 +- src/oidcendpoint/oidc/session.py | 75 +- src/oidcendpoint/oidc/token.py | 130 ++- src/oidcendpoint/oidc/userinfo.py | 34 +- src/oidcendpoint/old_id_token.py | 288 ++++++ .../{session.py => old_session.py} | 0 src/oidcendpoint/scopes.py | 20 +- src/oidcendpoint/session_management.py | 106 +- src/oidcendpoint/user_authn/user.py | 1 - src/oidcendpoint/userinfo.py | 219 ++-- src/oidcendpoint/util.py | 3 +- tests/{test_70_grant.py => test_01_grant.py} | 0 ...ess_mngm_db.py => test_01_sess_mngm_db.py} | 0 ...ession_life.py => test_01_session_life.py} | 137 +-- tests/test_03_id_token.py | 356 ++++--- tests/test_05_sso_db.py | 135 --- tests/test_07_userinfo.py | 251 ++--- tests/test_08_session.py | 520 ---------- ...oidc_authz.py => test_10_oidc_authz.py.no} | 0 .../test_24_oauth2_authorization_endpoint.py | 113 ++- ...st_24_oauth2_authorization_endpoint_jar.py | 2 + tests/test_24_oidc_authorization_endpoint.py | 37 +- .../test_24_oidc_authorization_endpoint.py.no | 933 ++++++++++++++++++ tests/test_25_oidc_token_endpoint.py | 118 ++- tests/test_25_oidc_token_endpoint.py.no | 255 +++++ tests/test_26_oidc_userinfo_endpoint.py | 285 +++--- tests/test_26_oidc_userinfo_endpoint.py.no | 353 +++++++ tests/test_30_oidc_end_session.py | 60 +- tests/test_30_oidc_end_session.py.no | 573 +++++++++++ tests/users.json | 3 - 43 files changed, 4736 insertions(+), 1965 deletions(-) create mode 100755 src/oidcendpoint/authz/old.init.py create mode 100644 src/oidcendpoint/oauth2/old_introspection.py create mode 100755 src/oidcendpoint/oidc/old_authorization.py create mode 100755 src/oidcendpoint/old_id_token.py rename src/oidcendpoint/{session.py => old_session.py} (100%) rename tests/{test_70_grant.py => test_01_grant.py} (100%) rename tests/{test_71_sess_mngm_db.py => test_01_sess_mngm_db.py} (100%) rename tests/{test_72_session_life.py => test_01_session_life.py} (89%) delete mode 100644 tests/test_05_sso_db.py delete mode 100644 tests/test_08_session.py rename tests/{test_10_oidc_authz.py => test_10_oidc_authz.py.no} (100%) create mode 100755 tests/test_24_oidc_authorization_endpoint.py.no create mode 100755 tests/test_25_oidc_token_endpoint.py.no create mode 100755 tests/test_26_oidc_userinfo_endpoint.py.no create mode 100644 tests/test_30_oidc_end_session.py.no diff --git a/src/oidcendpoint/authz/__init__.py b/src/oidcendpoint/authz/__init__.py index 3869cc2..e25eb25 100755 --- a/src/oidcendpoint/authz/__init__.py +++ b/src/oidcendpoint/authz/__init__.py @@ -2,8 +2,7 @@ import logging import sys -from oidcendpoint import sanitize -from oidcendpoint.cookie import cookie_value +from oidcendpoint.grant import Grant logger = logging.getLogger(__name__) @@ -14,50 +13,19 @@ class AuthzHandling(object): def __init__(self, endpoint_context, **kwargs): self.endpoint_context = endpoint_context self.cookie_dealer = endpoint_context.cookie_dealer - self.permdb = {} self.kwargs = kwargs - def __call__(self, *args, **kwargs): - return "" - - def set(self, uid, client_id, permission): - try: - self.permdb[uid][client_id] = permission - except KeyError: - self.permdb[uid] = {client_id: permission} - - def permissions(self, cookie=None, **kwargs): - if cookie is None: - return None - else: - logger.debug("kwargs: %s" % sanitize(kwargs)) - - val = self.cookie_dealer.get_cookie_value(cookie) - if val is None: - return None - else: - b64, _ts, typ = val - - info = cookie_value(b64) - return self.get(info["sub"], info["client_id"]) - - def get(self, uid, client_id): - try: - return self.permdb[uid][client_id] - except KeyError: - return None + def __call__(self, user_id, client_id, request): + permission = {k: v for k, v in request.items() if k in ["scope", "claims"]} + return Grant(**permission) class Implicit(AuthzHandling): - def __init__(self, endpoint_context, permission="implicit"): + def __init__(self, endpoint_context): AuthzHandling.__init__(self, endpoint_context) - self.permission = permission - - def permissions(self, cookie=None, **kwargs): - return self.permission - def get(self, uid, client_id): - return self.permission + def __call__(self, user_id, client_id, request): + return Grant() def factory(msgtype, endpoint_context, **kwargs): diff --git a/src/oidcendpoint/authz/old.init.py b/src/oidcendpoint/authz/old.init.py new file mode 100755 index 0000000..3869cc2 --- /dev/null +++ b/src/oidcendpoint/authz/old.init.py @@ -0,0 +1,78 @@ +import inspect +import logging +import sys + +from oidcendpoint import sanitize +from oidcendpoint.cookie import cookie_value + +logger = logging.getLogger(__name__) + + +class AuthzHandling(object): + """ Class that allow an entity to manage authorization """ + + def __init__(self, endpoint_context, **kwargs): + self.endpoint_context = endpoint_context + self.cookie_dealer = endpoint_context.cookie_dealer + self.permdb = {} + self.kwargs = kwargs + + def __call__(self, *args, **kwargs): + return "" + + def set(self, uid, client_id, permission): + try: + self.permdb[uid][client_id] = permission + except KeyError: + self.permdb[uid] = {client_id: permission} + + def permissions(self, cookie=None, **kwargs): + if cookie is None: + return None + else: + logger.debug("kwargs: %s" % sanitize(kwargs)) + + val = self.cookie_dealer.get_cookie_value(cookie) + if val is None: + return None + else: + b64, _ts, typ = val + + info = cookie_value(b64) + return self.get(info["sub"], info["client_id"]) + + def get(self, uid, client_id): + try: + return self.permdb[uid][client_id] + except KeyError: + return None + + +class Implicit(AuthzHandling): + def __init__(self, endpoint_context, permission="implicit"): + AuthzHandling.__init__(self, endpoint_context) + self.permission = permission + + def permissions(self, cookie=None, **kwargs): + return self.permission + + def get(self, uid, client_id): + return self.permission + + +def factory(msgtype, endpoint_context, **kwargs): + """ + Factory method that can be used to easily instantiate a class instance + + :param msgtype: The name of the class + :param kwargs: Keyword arguments + :return: An instance of the class or None if the name doesn't match any + known class. + """ + for name, obj in inspect.getmembers(sys.modules[__name__]): + if inspect.isclass(obj) and issubclass(obj, AuthzHandling): + try: + if obj.__name__ == msgtype: + return obj(endpoint_context, **kwargs) + except AttributeError: + pass diff --git a/src/oidcendpoint/common/authorization.py b/src/oidcendpoint/common/authorization.py index 2529b72..6564869 100755 --- a/src/oidcendpoint/common/authorization.py +++ b/src/oidcendpoint/common/authorization.py @@ -3,6 +3,9 @@ from urllib.parse import urlencode from urllib.parse import urlparse +from oidcmsg.time_util import time_sans_frac + +from oidcendpoint.session_management import unpack_db_key from oidcmsg.exception import ParameterError from oidcmsg.exception import URIError from oidcmsg.message import Message @@ -225,7 +228,8 @@ def create_authn_response(endpoint, request, sid): fragment_enc = False else: _context = endpoint.endpoint_context - _sinfo = _context.sdb[sid] + _mngr = endpoint.endpoint_context.session_manager + _session_info = _mngr[sid] if request.get("scope"): aresp["scope"] = request["scope"] @@ -234,27 +238,41 @@ def create_authn_response(endpoint, request, sid): handled_response_type = [] fragment_enc = True + if len(rtype) == 1 and "code" in rtype: fragment_enc = False - if "code" in request["response_type"]: - _code = aresp["code"] = _context.sdb[sid]["code"] - handled_response_type.append("code") - else: - _context.sdb.update(sid, code=None) - _code = None - - if "token" in rtype: - _dic = _context.sdb.upgrade_to_token(issue_refresh=False, key=sid) + grant = _mngr.grants(sid)[0] + user_id, client_id = unpack_db_key(sid) - logger.debug("_dic: %s" % sanitize(_dic)) - for key, val in _dic.items(): - if key in aresp.parameters() and val is not None: - aresp[key] = val - - handled_response_type.append("token") + if "code" in request["response_type"]: + _code = grant.mint_token( + 'authorization_code', + value=_mngr.token_handler["code"](user_id), + expires_at=time_sans_frac() + 300 # 5 minutes from now + ) + aresp["code"] = _code.value + handled_response_type.append("code") + else: + _code = None + + if "token" in rtype: + _access_token = grant.mint_token( + "access_token", + value=_mngr.token_handler["access_token"]( + sid, + client_id=client_id, + aud=grant.resources, + user_claims=None, + scope=grant.scope, + sub=_session_info['sub'], + based_on=_code + ), + expires_at=time_sans_frac() + 900 + ) - _access_token = aresp.get("access_token", None) + aresp['token'] = _access_token + handled_response_type.append("token") not_handled = rtype.difference(handled_response_type) if not_handled: diff --git a/src/oidcendpoint/endpoint_context.py b/src/oidcendpoint/endpoint_context.py index 10b6430..591bffd 100755 --- a/src/oidcendpoint/endpoint_context.py +++ b/src/oidcendpoint/endpoint_context.py @@ -14,8 +14,7 @@ from oidcendpoint.scopes import STANDARD_CLAIMS from oidcendpoint.scopes import Claims from oidcendpoint.scopes import Scopes -from oidcendpoint.session import create_session_db -from oidcendpoint.sso_db import SSODb +from oidcendpoint.session_management import create_session_manager from oidcendpoint.template_handler import Jinja2TemplateHandler from oidcendpoint.user_authn.authn_context import populate_authn_broker from oidcendpoint.util import allow_refresh_token @@ -100,21 +99,15 @@ def __init__( self.conf = conf # For my Dev environment - self.sso_db = None - self.session_db = None - self.state_db = None self.cdb = None self.jti_db = None self.registration_access_token = None self.add_boxes( { - "state": "state_db", "client": "cdb", "jti": "jti_db", "registration_access_token": "registration_access_token", - "sso": "sso_db", - "session": "session_db", }, self.db_conf, ) @@ -279,10 +272,9 @@ def set_claims_handler(self): self.claims_handler = Claims() def set_session_db(self): - self.do_session_db(SSODb(db=self.sso_db), self.session_db) + self.do_session_manager() # append userinfo db to the session db self.do_userinfo() - logger.debug("Session DB: {}".format(self.sdb.__dict__)) def do_add_on(self): if self.conf.get("add_on"): @@ -311,9 +303,9 @@ def do_login_hint_lookup(self): def do_userinfo(self): _conf = self.conf.get("userinfo") if _conf: - if self.sdb: + if self.session_manager: self.userinfo = init_user_info(_conf, self.cwd) - self.sdb.userinfo = self.userinfo + self.session_manager.userinfo = self.userinfo else: logger.warning("Cannot init_user_info if no session_db was provided.") @@ -357,9 +349,9 @@ def do_sub_func(self): else: self._sub_func[key] = args["function"] - def do_session_db(self, sso_db, db=None): - self.sdb = create_session_db( - self, self.th_args, db=db, sso_db=sso_db, sub_func=self._sub_func + def do_session_manager(self, db=None): + self.session_manager = create_session_manager( + self, self.th_args, db=db, sub_func=self._sub_func ) def do_endpoints(self): diff --git a/src/oidcendpoint/grant.py b/src/oidcendpoint/grant.py index 9a20c83..8f85ef5 100644 --- a/src/oidcendpoint/grant.py +++ b/src/oidcendpoint/grant.py @@ -171,7 +171,7 @@ def __init__(self, Item.__init__(self, usage_rules=usage_rules, issued_at=issued_at, expires_at=expires_at) self.scope = scope or [] self.authorization_details = authorization_details or None - self.claims = claim or None + self.claims = claim or {} # default is to not release any user information self.resources = resources or [] self.issued_token = token or [] self.id = uuid1().hex diff --git a/src/oidcendpoint/id_token.py b/src/oidcendpoint/id_token.py index a9747bb..984a19d 100755 --- a/src/oidcendpoint/id_token.py +++ b/src/oidcendpoint/id_token.py @@ -1,11 +1,17 @@ import logging +import uuid from cryptojwt.jws.utils import left_hash from cryptojwt.jwt import JWT +from oidcendpoint.session_management import unpack_db_key +from oidcendpoint.session_management import SessionInfo + +from oidcendpoint import rndstr from oidcendpoint.endpoint import construct_endpoint_info -from oidcendpoint.userinfo import collect_user_info -from oidcendpoint.userinfo import userinfo_in_id_token_claims +from oidcendpoint.grant import Item +from oidcendpoint.session_management import db_key +from oidcendpoint.userinfo import ClaimsInterface logger = logging.getLogger(__name__) @@ -50,7 +56,7 @@ def include_session_id(endpoint_context, client_id, where): def get_sign_and_encrypt_algorithms( - endpoint_context, client_info, payload_type, sign=False, encrypt=False + endpoint_context, client_info, payload_type, sign=False, encrypt=False ): args = {"sign": sign, "encrypt": encrypt} if sign: @@ -112,23 +118,19 @@ class IDToken(object): def __init__(self, endpoint_context, **kwargs): self.endpoint_context = endpoint_context self.kwargs = kwargs - self.enable_claims_per_client = kwargs.get("enable_claims_per_client", False) self.scope_to_claims = None self.provider_info = construct_endpoint_info( self.default_capabilities, **kwargs ) + self.claims_interface = ClaimsInterface(endpoint_context, "id_token", **kwargs) def payload( - self, - session, - acr="", - alg="RS256", - code=None, - access_token=None, - user_info=None, - auth_time=0, - lifetime=None, - extra_claims=None, + self, + session_id, + alg="RS256", + code=None, + access_token=None, + extra_claims=None, ): """ @@ -144,15 +146,18 @@ def payload( :return: IDToken instance """ - _args = {"sub": session["sub"]} - - if lifetime is None: - lifetime = DEF_LIFETIME + _mngr = self.endpoint_context.session_manager + session_information = _mngr.get_session_info(session_id) + _args = {"sub": session_information["client_session_info"]["sub"]} + for claim, attr in {"authn_time": "auth_time", "acr": "acr"}.items(): + _val = session_information["user_session_info"]["authentication_event"].get(claim) + if _val: + _args[attr] = _val - if auth_time: - _args["auth_time"] = auth_time - if acr: - _args["acr"] = acr + grant = _mngr.grants(session_id)[0] + _claims_restriction = grant.claims.get(self.claims_interface.usage) + user_info = self.claims_interface.get_user_claims(user_id=session_information["user_id"], + claims_restriction=_claims_restriction) if user_info: try: @@ -179,35 +184,34 @@ def payload( if access_token: _args["at_hash"] = left_hash(access_token.encode("utf-8"), halg) - authn_req = session["authn_req"] + authn_req = session_information["client_session_info"]["authorization_request"] if authn_req: try: _args["nonce"] = authn_req["nonce"] except KeyError: pass - return {"payload": _args, "lifetime": lifetime} + return _args def sign_encrypt( - self, - session_info, - client_id, - code=None, - access_token=None, - user_info=None, - sign=True, - encrypt=False, - lifetime=None, - extra_claims=None, + self, + session_id, + client_id, + code=None, + access_token=None, + sign=True, + encrypt=False, + lifetime=None, + extra_claims=None, ): """ Signed and or encrypt a IDToken + :param lifetime: How long the ID Token should be valid :param session_info: Session information :param client_id: Client ID :param code: Access grant :param access_token: Access Token - :param user_info: User information :param sign: If the JWT should be signed :param encrypt: If the JWT should be encrypted :param extra_claims: Extra claims to be added to the ID Token @@ -221,69 +225,52 @@ def sign_encrypt( _cntx, client_info, "id_token", sign=sign, encrypt=encrypt ) - _authn_event = session_info["authn_event"] - - _idt_info = self.payload( - session_info, - acr=_authn_event["authn_info"], + _payload = self.payload( + session_id=session_id, alg=alg_dict["sign_alg"], code=code, access_token=access_token, - user_info=user_info, - auth_time=_authn_event["authn_time"], - lifetime=lifetime, - extra_claims=extra_claims, + extra_claims=extra_claims ) + if lifetime is None: + lifetime = DEF_LIFETIME + _jwt = JWT( - _cntx.keyjar, iss=_cntx.issuer, lifetime=_idt_info["lifetime"], **alg_dict + _cntx.keyjar, iss=_cntx.issuer, lifetime=lifetime, **alg_dict ) - return _jwt.pack(_idt_info["payload"], recv=client_id) + return _jwt.pack(_payload, recv=client_id) - def make(self, req, sess_info, authn_req=None, user_claims=False, **kwargs): + def make(self, session_id, **kwargs): _context = self.endpoint_context - if authn_req: - _client_id = authn_req["client_id"] + user_id, client_id, grant_id = unpack_db_key(session_id) + + # Should I add session ID. This is about Single Logout. + if include_session_id(_context, client_id, "back") or include_session_id( + _context, client_id, "front"): + + # Note that this session ID is not the session ID the session manager is using. + # It must be possible to map from one to the other. + logout_session_id = uuid.uuid4().get_hex() + _item = SessionInfo() + _item.set("user_id", user_id) + _item.set("client_id", client_id) + # Store the map + _mngr = self.endpoint_context.session_manager + _mngr.set([logout_session_id], _item) + # add identifier to extra arguments + xargs = {"sid": logout_session_id} else: - _client_id = req["client_id"] - - _cinfo = _context.cdb[_client_id] + xargs = {} - idtoken_claims = dict(self.kwargs.get("available_claims", {})) - if self.enable_claims_per_client: - idtoken_claims.update(_cinfo.get("id_token_claims", {})) lifetime = self.kwargs.get("lifetime") - userinfo = userinfo_in_id_token_claims(_context, sess_info, idtoken_claims) - - if user_claims: - info = collect_user_info(_context, sess_info) - if userinfo is None: - userinfo = info - else: - userinfo.update(info) - - # Should I add session ID - req_sid = include_session_id( - _context, _client_id, "back" - ) or include_session_id(_context, _client_id, "front") - - if req_sid: - xargs = { - "sid": _context.sdb.get_sid_by_sub_and_client_id( - sess_info["sub"], _client_id - ) - } - else: - xargs = {} - return self.sign_encrypt( - sess_info, - _client_id, + session_id, + client_id, sign=True, - user_info=userinfo, lifetime=lifetime, extra_claims=xargs, **kwargs diff --git a/src/oidcendpoint/jwt_token.py b/src/oidcendpoint/jwt_token.py index 4cc494c..25464b3 100644 --- a/src/oidcendpoint/jwt_token.py +++ b/src/oidcendpoint/jwt_token.py @@ -4,64 +4,11 @@ from cryptojwt.jws.exception import JWSException from oidcendpoint.exception import ToOld -from oidcendpoint.scopes import convert_scopes2claims -from oidcendpoint.token_handler import is_expired +from oidcendpoint.session_management import db_key from oidcendpoint.token_handler import Token from oidcendpoint.token_handler import UnknownToken - - -class ClaimsInterface: - init_args = { - "add_claims_by_scope": False, - "enable_claims_per_client": False - } - - def __init__(self, endpoint_context, **kwargs): - self.endpoint_context = endpoint_context - self.scope_claims_map = kwargs.get("scope_claims_map", endpoint_context.scope2claims) - self.add_claims_by_scope = kwargs.get("add_claims_by_scope", - self.init_args["add_claims_by_scope"]) - self.enable_claims_per_client = kwargs.get("enable_claims_per_client", - self.init_args["enable_claims_per_client"]) - - def _get_client_claims(self, client_id): - if self.enable_claims_per_client: - client_info = self.endpoint_context.cdb.get(client_id, {}) - return client_info.get("introspection_claims") - else: - return [] - - def _get_user_info(self, token_info): - user_id = self.endpoint_context.sdb.sso_db.get_uid_by_sid(token_info["sid"]) - return self.endpoint_context.userinfo(user_id, client_id=None) - - def add_claims(self, client_id, user_id, payload, scopes, claims_restriction): - if claims_restriction is None: - user_info = self.endpoint_context.userinfo(user_id, client_id=None) - payload.update(user_info) - elif claims_restriction == {}: # Nothing is allowed - pass - else: - possible_claims = self._get_client_claims(client_id) - if self.add_claims_by_scope: - _claims = convert_scopes2claims(scopes, map=self.scope_claims_map).keys() - possible_claims = list(set(possible_claims).union(_claims)) - - if possible_claims: - _claims = {c: None for c in - set(possible_claims).intersection(set(claims_restriction.key()))} - _claims.update(claims_restriction) - else: - _claims = claims_restriction - - if _claims: - user_info = self.endpoint_context.userinfo(user_id, client_id=None, - user_info_claims=_claims) - for attr in _claims: - try: - payload[attr] = user_info[attr] - except KeyError: - pass +from oidcendpoint.token_handler import is_expired +from oidcendpoint.userinfo import ClaimsInterface class JWTToken(Token): @@ -77,15 +24,15 @@ def __init__( token_type: str = "Bearer", add_claims: bool = False, **kwargs - ): + ): Token.__init__(self, typ, **kwargs) self.token_type = token_type self.lifetime = lifetime - self.claims_interface = ClaimsInterface(ec, **kwargs) + self.claims_interface = ClaimsInterface(ec, "jwt_token", **kwargs) self.args = { (k, v) for k, v in kwargs.items() if k not in self.claims_interface.init_args.keys() - } + } self.key_jar = keyjar or ec.keyjar self.issuer = issuer or ec.issuer @@ -100,7 +47,7 @@ def __call__( self, sid: str, **kwargs - ): + ): """ Return a token. @@ -116,8 +63,10 @@ def __call__( _claims = kwargs.get('claims') if self.add_claims: - self.claims_interface.add_claims(_client_id, _user_id, payload, claims=_claims, - scopes=_scopes) + grant = self.claims_interface.endpoint_context.session_manager.grants(sid)[0] + user_info = self.claims_interface.get_user_claims( + _user_id, grant.claims.get("id_token", {})) + payload.update(user_info) # payload.update(kwargs) signer = JWT( @@ -125,7 +74,7 @@ def __call__( iss=self.issuer, lifetime=self.lifetime, sign_alg=self.alg, - ) + ) _aud = kwargs.get('aud') if _aud is None: @@ -158,7 +107,7 @@ def info(self, token): "type": _payload["ttype"], "exp": _payload["exp"], "handler": self, - } + } return _res def is_expired(self, token, when=0): diff --git a/src/oidcendpoint/oauth2/authorization.py b/src/oidcendpoint/oauth2/authorization.py index c4f6c23..162ad03 100755 --- a/src/oidcendpoint/oauth2/authorization.py +++ b/src/oidcendpoint/oauth2/authorization.py @@ -7,6 +7,14 @@ from cryptojwt.utils import as_unicode from cryptojwt.utils import b64d from cryptojwt.utils import b64e +from oidcendpoint.session_management import Revoked + +from oidcendpoint.session_management import db_key + +from oidcendpoint.session_management import ClientSessionInfo +from oidcmsg.time_util import time_sans_frac + +from oidcendpoint.session_management import unpack_db_key from oidcmsg import oauth2 from oidcmsg.exception import ParameterError from oidcmsg.oidc import AuthorizationResponse @@ -33,7 +41,6 @@ from oidcendpoint.exception import ToOld from oidcendpoint.exception import UnAuthorizedClientScope from oidcendpoint.exception import UnknownClient -from oidcendpoint.session import setup_session from oidcendpoint.token_handler import UnknownToken from oidcendpoint.user_authn.authn_context import pick_auth @@ -292,9 +299,13 @@ def setup_auth(self, request, redirect_uri, cinfo, cookie, acr=None, **kwargs): else: identity = json.loads(as_unicode(_id)) - session = self.endpoint_context.sdb[identity.get("sid")] - if not session or "revoked" in session: + try: + _csi = self.endpoint_context.session_manager[identity.get("sid")] + except Revoked: identity = None + else: + if _csi.is_active() is False: + identity = None authn_args = authn_args_gather(request, authn_class_ref, cinfo, **kwargs) @@ -316,16 +327,14 @@ def setup_auth(self, request, redirect_uri, cinfo, cookie, acr=None, **kwargs): # demand re-authentication return {"function": authn, "args": authn_args} else: + _mngr = self.endpoint_context.session_manager # I get back a dictionary user = identity["uid"] if "req_user" in kwargs: - sids = self.endpoint_context.sdb.get_sids_by_sub(kwargs["req_user"]) + sids = _mngr.get_sids_by_user_id(kwargs["req_user"]) if ( - sids - and user - != self.endpoint_context.sdb.get_authentication_event( - sids[-1] - ).uid + sids + and user != _mngr.get_authentication_event(sids[-1]).uid ): logger.debug("Wanted to be someone else!") if "prompt" in request and "none" in request["prompt"]: @@ -365,8 +374,9 @@ def create_authn_response(self, request, sid): if "response_type" in request and request["response_type"] == ["none"]: fragment_enc = False else: + _mngr = self.endpoint_context.session_manager _context = self.endpoint_context - _sinfo = _context.sdb[sid] + _session_info = _mngr.get_session_info(sid) if request.get("scope"): aresp["scope"] = request["scope"] @@ -378,25 +388,37 @@ def create_authn_response(self, request, sid): if len(rtype) == 1 and "code" in rtype: fragment_enc = False + grant = _mngr[sid] + if "code" in request["response_type"]: - _code = aresp["code"] = _context.sdb[sid]["code"] + _code = grant.mint_token( + 'authorization_code', + value=_mngr.token_handler["code"](_session_info["user_id"]), + expires_at=time_sans_frac() + 300 # 5 minutes from now + ) + aresp["code"] = _code.value handled_response_type.append("code") else: - _context.sdb.update(sid, code=None) _code = None if "token" in rtype: - _dic = _context.sdb.upgrade_to_token(issue_refresh=False, key=sid) - - logger.debug("_dic: %s" % sanitize(_dic)) - for key, val in _dic.items(): - if key in aresp.parameters() and val is not None: - aresp[key] = val + _access_token = grant.mint_token( + "access_token", + value=_mngr.token_handler["access_token"]( + sid, + client_id=_session_info["client_id"], + aud=grant.resources, + user_claims=None, + scope=grant.scope, + sub=_session_info["client_session_info"]['sub'], + based_on=_code + ), + expires_at=time_sans_frac() + 900 + ) + aresp['token'] = _access_token handled_response_type.append("token") - _access_token = aresp.get("access_token", None) - not_handled = rtype.difference(handled_response_type) if not_handled: resp = self.error_cls( @@ -446,13 +468,13 @@ def error_response(self, response_info, error, error_description): response_info["response_args"] = resp return response_info - def post_authentication(self, user, request, sid, **kwargs): + def post_authentication(self, user, request, pre_sid, **kwargs): """ Things that are done after a successful authentication. :param user: :param request: - :param sid: + :param pre_sid: :param kwargs: :return: A dictionary with 'response_args' """ @@ -461,8 +483,8 @@ def post_authentication(self, user, request, sid, **kwargs): # Do the authorization try: - permission = self.endpoint_context.authz( - user, client_id=request["client_id"] + grant = self.endpoint_context.authz( + user, client_id=request["client_id"], request=request ) except ToOld as err: return self.error_response( @@ -475,8 +497,10 @@ def post_authentication(self, user, request, sid, **kwargs): response_info, "access_denied", "{}".format(err.args) ) else: + session_id = db_key(user, request["client_id"], grant.id) try: - self.endpoint_context.sdb.update(sid, permission=permission) + self.endpoint_context.session_manager.set([user, request["client_id"], + grant.id], grant) except Exception as err: return self.error_response( response_info, "server_error", "{}".format(err.args) @@ -484,12 +508,7 @@ def post_authentication(self, user, request, sid, **kwargs): logger.debug("response type: %s" % request["response_type"]) - if self.endpoint_context.sdb.is_session_revoked(sid): - return self.error_response( - response_info, "access_denied", "Session is revoked" - ) - - response_info = self.create_authn_response(request, sid) + response_info = self.create_authn_response(request, session_id) try: redirect_uri = get_uri(self.endpoint_context, request, "redirect_uri") @@ -507,10 +526,8 @@ def post_authentication(self, user, request, sid, **kwargs): _cookie = new_cookie( self.endpoint_context, - sub=user, - sid=sid, + sid=session_id, state=request["state"], - client_id=request["client_id"], cookie_name=self.endpoint_context.cookie_name["session"], ) @@ -528,7 +545,22 @@ def post_authentication(self, user, request, sid, **kwargs): response_info["cookie"] = [_cookie] - return response_info + return response_info, session_id + + def setup_client_session(self, user_id: str, request: dict) -> str: + _mngr = self.endpoint_context.session_manager + client_id = request['client_id'] + + _client_info = self.endpoint_context.cdb[client_id] + sub_type = _client_info.get("subject_type") + + client_info = ClientSessionInfo( + authorization_request=request, + sub=_mngr.sub_func['public'](user_id, salt=_mngr.salt) + ) + + _mngr.set([user_id, client_id], client_info) + return db_key(user_id, client_id) def authz_part2(self, user, authn_event, request, **kwargs): """ @@ -540,22 +572,19 @@ def authz_part2(self, user, authn_event, request, **kwargs): :param kwargs: possible other parameters :return: A redirect to the redirect_uri of the client """ - sid = setup_session( - self.endpoint_context, request, user, authn_event=authn_event - ) + pre_sid = self.setup_client_session(user, request) try: - resp_info = self.post_authentication(user, request, sid, **kwargs) + resp_info, session_id = self.post_authentication(user, request, pre_sid, **kwargs) except Exception as err: return self.error_response({}, "server_error", err) if "check_session_iframe" in self.endpoint_context.provider_info: ec = self.endpoint_context salt = rndstr() - if not ec.sdb.is_session_revoked(sid): - authn_event = ec.sdb.get_authentication_event( - sid - ) # use the last session + grant = ec.session_manager[session_id] + if grant.is_active() is False: + authn_event = ec.session_manager.get_authentication_event(session_id) _state = b64e( as_bytes(json.dumps({"authn_time": authn_event["authn_time"]})) ) diff --git a/src/oidcendpoint/oauth2/introspection.py b/src/oidcendpoint/oauth2/introspection.py index 8e83a4c..124c9b0 100644 --- a/src/oidcendpoint/oauth2/introspection.py +++ b/src/oidcendpoint/oauth2/introspection.py @@ -2,10 +2,9 @@ import logging from oidcmsg import oauth2 -from oidcmsg.time_util import utc_time_sans_frac from oidcendpoint.endpoint import Endpoint -from oidcendpoint.token_handler import UnknownToken +from oidcendpoint.session_management import unpack_db_key LOGGER = logging.getLogger(__name__) @@ -54,18 +53,12 @@ def _add_claims(self, token_info, claims, payload): except KeyError: pass - def _introspect(self, token): - try: - info = self.endpoint_context.sdb[token] - except (KeyError, UnknownToken): - return None - + def _introspect(self, token, grant): # Make sure that the token is an access_token or a refresh_token - if token != info.get("access_token") and token != info.get("refresh_token"): + if token.type not in ["access_token", "refresh_token"]: return None - eat = info.get("expires_at") - if eat and eat < utc_time_sans_frac(): + if not token.is_active(): return None if info: # Now what can be returned ? @@ -88,10 +81,14 @@ def process_request(self, request=None, **kwargs): if "error" in _introspect_request: return _introspect_request - _token = _introspect_request["token"] + request_token = _introspect_request["token"] + session_id = self.endpoint_context.session_manager.token_handler.sid(request_token) + grant, token = self.endpoint_context.session_manager.find_grant(session_id, + request_token) + _resp = self.response_cls(active=False) - _info = self._introspect(_token) + _info = self._introspect(token) if _info is None: return {"response_args": _resp} diff --git a/src/oidcendpoint/oauth2/old_introspection.py b/src/oidcendpoint/oauth2/old_introspection.py new file mode 100644 index 0000000..8e83a4c --- /dev/null +++ b/src/oidcendpoint/oauth2/old_introspection.py @@ -0,0 +1,117 @@ +"""Implements RFC7662""" +import logging + +from oidcmsg import oauth2 +from oidcmsg.time_util import utc_time_sans_frac + +from oidcendpoint.endpoint import Endpoint +from oidcendpoint.token_handler import UnknownToken + +LOGGER = logging.getLogger(__name__) + + +class Introspection(Endpoint): + """Implements RFC 7662""" + + request_cls = oauth2.TokenIntrospectionRequest + response_cls = oauth2.TokenIntrospectionResponse + request_format = "urlencoded" + response_format = "json" + endpoint_name = "introspection_endpoint" + name = "introspection" + + def __init__(self, **kwargs): + Endpoint.__init__(self, **kwargs) + self.offset = kwargs.get("offset", 0) + self.enable_claims_per_client = kwargs.get("enable_claims_per_client", False) + + def get_client_id_from_token(self, endpoint_context, token, request=None): + """ + Will try to match tokens against information in the session DB. + + :param endpoint_context: + :param token: + :param request: + :return: client_id if there was a match + """ + sinfo = endpoint_context.sdb[token] + return sinfo["authn_req"]["client_id"] + + def _get_client_claims(self, token): + client_id = self.get_client_id_from_token(self.endpoint_context, token) + client = self.endpoint_context.cdb.get(client_id, {}) + return client.get("introspection_claims") + + def _get_user_info(self, token_info): + user_id = self.endpoint_context.sdb.sso_db.get_uid_by_sid(token_info["sid"]) + return self.endpoint_context.userinfo(user_id, client_id=None) + + def _add_claims(self, token_info, claims, payload): + user_info = self._get_user_info(token_info) + for attr in claims: + try: + payload[attr] = user_info[attr] + except KeyError: + pass + + def _introspect(self, token): + try: + info = self.endpoint_context.sdb[token] + except (KeyError, UnknownToken): + return None + + # Make sure that the token is an access_token or a refresh_token + if token != info.get("access_token") and token != info.get("refresh_token"): + return None + + eat = info.get("expires_at") + if eat and eat < utc_time_sans_frac(): + return None + + if info: # Now what can be returned ? + ret = info.to_dict() + ret["iss"] = self.endpoint_context.issuer + + if "scope" not in ret: + ret["scope"] = " ".join(info["authn_req"]["scope"]) + + return ret + + def process_request(self, request=None, **kwargs): + """ + + :param request: The authorization request as a dictionary + :param kwargs: + :return: + """ + _introspect_request = self.request_cls(**request) + if "error" in _introspect_request: + return _introspect_request + + _token = _introspect_request["token"] + _resp = self.response_cls(active=False) + + _info = self._introspect(_token) + if _info is None: + return {"response_args": _resp} + + if "release" in self.kwargs: + if "username" in self.kwargs["release"]: + try: + _info["username"] = self.endpoint_context.userinfo.search( + sub=_info["sub"] + ) + except KeyError: + pass + + _resp.update(_info) + _resp.weed() + + if self.enable_claims_per_client: + client_claims = self._get_client_claims(_token) + if client_claims: + self._add_claims(_info, client_claims, _resp) + + _resp["active"] = True + + return {"response_args": _resp} diff --git a/src/oidcendpoint/oidc/add_on/pkce.py b/src/oidcendpoint/oidc/add_on/pkce.py index 07bc9f7..f53b7e4 100644 --- a/src/oidcendpoint/oidc/add_on/pkce.py +++ b/src/oidcendpoint/oidc/add_on/pkce.py @@ -90,12 +90,13 @@ def post_token_parse(request, client_id, endpoint_context, **kwargs): return request try: - _info = endpoint_context.sdb[request["code"]] + _session_info = endpoint_context.session_manager.get_session_info_by_token(request["code"]) except KeyError: return TokenErrorResponse( error="invalid_grant", error_description="Unknown access grant" ) - _authn_req = _info["authn_req"] + + _authn_req = _session_info["client_session_info"]["authorization_request"] if "code_challenge" in _authn_req: if "code_verifier" not in request: @@ -104,11 +105,11 @@ def post_token_parse(request, client_id, endpoint_context, **kwargs): error_description="Missing code_verifier", ) - _method = _info["authn_req"]["code_challenge_method"] + _method = _authn_req["code_challenge_method"] if not verify_code_challenge( request["code_verifier"], - _info["authn_req"]["code_challenge"], + _authn_req["code_challenge"], _method, ): return TokenErrorResponse( diff --git a/src/oidcendpoint/oidc/authorization.py b/src/oidcendpoint/oidc/authorization.py index 050c56e..9f1a5c5 100755 --- a/src/oidcendpoint/oidc/authorization.py +++ b/src/oidcendpoint/oidc/authorization.py @@ -1,5 +1,6 @@ import json import logging +from urllib.parse import urlsplit from cryptojwt import BadSyntax from cryptojwt.jwe.exception import JWEException @@ -13,12 +14,12 @@ from oidcmsg.exception import ParameterError from oidcmsg.oidc import Claims from oidcmsg.oidc import verified_claim_name +from oidcmsg.time_util import time_sans_frac from oidcendpoint import rndstr -from oidcendpoint import sanitize from oidcendpoint.authn_event import create_authn_event -from oidcendpoint.common.authorization import FORM_POST from oidcendpoint.common.authorization import AllowedAlgorithms +from oidcendpoint.common.authorization import FORM_POST from oidcendpoint.common.authorization import authn_args_gather from oidcendpoint.common.authorization import get_uri from oidcendpoint.common.authorization import inputs @@ -35,7 +36,10 @@ from oidcendpoint.exception import ToOld from oidcendpoint.exception import UnknownClient from oidcendpoint.oauth2.authorization import check_unknown_scopes_policy -from oidcendpoint.session import setup_session +from oidcendpoint.session_management import ClientSessionInfo +from oidcendpoint.session_management import UserSessionInfo +from oidcendpoint.session_management import db_key +from oidcendpoint.session_management import unpack_db_key from oidcendpoint.token_handler import UnknownToken from oidcendpoint.user_authn.authn_context import pick_auth @@ -68,6 +72,11 @@ def acr_claims(request): return acrdef["values"] +def host_component(url): + res = urlsplit(url) + return "{}://{}".format(res.scheme, res.netloc) + + ALG_PARAMS = { "sign": [ "request_object_signing_alg", @@ -290,7 +299,6 @@ def setup_auth(self, request, redirect_uri, cinfo, cookie, acr=None, **kwargs): authn = res["method"] authn_class_ref = res["acr"] - session = None try: _auth_info = kwargs.get("authn", "") @@ -322,15 +330,8 @@ def setup_auth(self, request, redirect_uri, cinfo, cookie, acr=None, **kwargs): else: identity = json.loads(as_unicode(_id)) - try: - session = self.endpoint_context.sdb[identity.get("sid")] - except UnknownToken: - identity= None - else: - if not session or "revoked" in session: - identity = None - authn_args = authn_args_gather(request, authn_class_ref, cinfo, **kwargs) + _mngr = self.endpoint_context.session_manager # To authenticate or Not if identity is None: # No! @@ -357,13 +358,10 @@ def setup_auth(self, request, redirect_uri, cinfo, cookie, acr=None, **kwargs): # I get back a dictionary user = identity["uid"] if "req_user" in kwargs: - sids = self.endpoint_context.sdb.get_sids_by_sub(kwargs["req_user"]) + sids = _mngr.get_sids_by_user_id(kwargs["req_user"]) if ( sids - and user - != self.endpoint_context.sdb.get_authentication_event( - sids[-1] - ).uid + and user != _mngr.get_authentication_event(sids[-1]).uid ): logger.debug("Wanted to be someone else!") if "prompt" in request and "none" in request["prompt"]: @@ -375,17 +373,16 @@ def setup_auth(self, request, redirect_uri, cinfo, cookie, acr=None, **kwargs): else: return {"function": authn, "args": authn_args} - authn_event = None - if session: - authn_event = session.get('authn_event') + authn_event = _mngr.get_authentication_event(user) if authn_event is None: authn_event = create_authn_event( identity["uid"], - identity.get("salt", ""), + _mngr.salt, authn_info=authn_class_ref, time_stamp=_ts, ) + _mngr.set([identity["uid"]], UserSessionInfo(authentication_event=authn_event)) _exp_in = authn.kwargs.get("expires_in") if _exp_in and "valid_until" in authn_event: @@ -413,7 +410,8 @@ def create_authn_response(self, request, sid): fragment_enc = False else: _context = self.endpoint_context - _sinfo = _context.sdb[sid] + _mngr = self.endpoint_context.session_manager + _sinfo = _mngr[sid] if request.get("scope"): aresp["scope"] = request["scope"] @@ -425,24 +423,39 @@ def create_authn_response(self, request, sid): if len(rtype) == 1 and "code" in rtype: fragment_enc = False + grant = _mngr.grants(sid)[0] + user_id, client_id = unpack_db_key(sid) + if "code" in request["response_type"]: - _code = aresp["code"] = _context.sdb[sid]["code"] + _code = grant.mint_token( + 'authorization_code', + value=_mngr.token_handler["code"](user_id), + expires_at=time_sans_frac() + 300 # 5 minutes from now + ) + aresp["code"] = _code.value handled_response_type.append("code") else: - _context.sdb.update(sid, code=None) _code = None if "token" in rtype: - _dic = _context.sdb.upgrade_to_token(issue_refresh=False, key=sid) - - logger.debug("_dic: %s" % sanitize(_dic)) - for key, val in _dic.items(): - if key in aresp.parameters() and val is not None: - aresp[key] = val + _access_token = grant.mint_token( + "access_token", + value=_mngr.token_handler["access_token"]( + sid, + client_id=client_id, + aud=grant.resources, + user_claims=None, + scope=grant.scope, + sub=_sinfo['sub'], + based_on=_code + ), + expires_at=time_sans_frac() + 900 + ) + aresp['token'] = _access_token handled_response_type.append("token") - - _access_token = aresp.get("access_token", None) + else: + _access_token = None if "id_token" in request["response_type"]: kwargs = {} @@ -453,11 +466,11 @@ def create_authn_response(self, request, sid): elif {"id_token", "token"}.issubset(rtype): kwargs = {"access_token": _access_token} - if request["response_type"] == ["id_token"]: - kwargs["user_claims"] = True + # if request["response_type"] == ["id_token"]: + # kwargs["user_claims"] = True try: - id_token = _context.idtoken.make(request, _sinfo, **kwargs) + id_token = _context.idtoken.make(user_id=user_id, client_id=client_id, **kwargs) except (JWEException, NoSuitableSigningKeys) as err: logger.warning(str(err)) resp = self.error_cls( @@ -467,7 +480,7 @@ def create_authn_response(self, request, sid): return {"response_args": resp, "fragment_enc": fragment_enc} aresp["id_token"] = id_token - _sinfo["id_token"] = id_token + _mngr.update([user_id, client_id], {"id_token": id_token}) handled_response_type.append("id_token") not_handled = rtype.difference(handled_response_type) @@ -521,24 +534,24 @@ def error_response(self, response_info, error, error_description): response_info["response_args"] = resp return response_info - def post_authentication(self, user, request, sid, **kwargs): + def post_authentication(self, user, request, pre_sid, **kwargs): """ Things that are done after a successful authentication. :param user: - :param request: + :param request: The authorization request :param sid: :param kwargs: :return: A dictionary with 'response_args' """ response_info = {} + _mngr = self.endpoint_context.session_manager + user_id, client_id = unpack_db_key(pre_sid) # Do the authorization try: - permission = self.endpoint_context.authz( - user, client_id=request["client_id"] - ) + grant = self.endpoint_context.authz(user_id, client_id, request=request) except ToOld as err: return self.error_response( response_info, @@ -551,20 +564,17 @@ def post_authentication(self, user, request, sid, **kwargs): ) else: try: - self.endpoint_context.sdb.update(sid, permission=permission) + _mngr.set([user_id, client_id, grant.id], grant) except Exception as err: return self.error_response( response_info, "server_error", "{}".format(err.args) ) + else: + session_id = db_key(user_id, client_id, grant.id) logger.debug("response type: %s" % request["response_type"]) - if self.endpoint_context.sdb.is_session_revoked(sid): - return self.error_response( - response_info, "access_denied", "Session is revoked" - ) - - response_info = self.create_authn_response(request, sid) + response_info = self.create_authn_response(request, session_id) logger.debug("Known clients: {}".format(list(self.endpoint_context.cdb.keys()))) @@ -585,7 +595,6 @@ def post_authentication(self, user, request, sid, **kwargs): _cookie = new_cookie( self.endpoint_context, uid=user, - sid=sid, state=request["state"], client_id=request["client_id"], cookie_name=self.endpoint_context.cookie_name["session"], @@ -605,9 +614,34 @@ def post_authentication(self, user, request, sid, **kwargs): response_info["cookie"] = [_cookie] - return response_info + return response_info, session_id + + def setup_client_session(self, user_id: str, request: dict) -> str: + _mngr = self.endpoint_context.session_manager + client_id = request['client_id'] + + _client_info = self.endpoint_context.cdb[client_id] + sub_type = _client_info.get("subject_type") + if sub_type and sub_type == "pairwise": + sector_identifier_uri = _client_info.get("sector_identifier_uri") + if sector_identifier_uri is None: + sector_identifier_uri = host_component(_client_info["redirect_uris"][0]) + + client_info = ClientSessionInfo( + authorization_request=request, + sub=_mngr.sub_func[sub_type](user_id, salt=_mngr.salt, + sector_identifier=sector_identifier_uri) + ) + else: + client_info = ClientSessionInfo( + authorization_request=request, + sub=_mngr.sub_func['public'](user_id, salt=_mngr.salt) + ) + + _mngr.set([user_id, client_id], client_info) + return db_key(user_id, client_id) - def authz_part2(self, user, authn_event, request, **kwargs): + def authz_part2(self, user, request, **kwargs): """ After the authentication this is where you should end up @@ -617,63 +651,67 @@ def authz_part2(self, user, authn_event, request, **kwargs): :param kwargs: possible other parameters :return: A redirect to the redirect_uri of the client """ - sid = setup_session( - self.endpoint_context, request, user, authn_event=authn_event - ) + + pre_sid = self.setup_client_session(user, request) try: - resp_info = self.post_authentication(user, request, sid, **kwargs) + resp_info, session_id = self.post_authentication(user, request, pre_sid, **kwargs) except Exception as err: return self.error_response({}, "server_error", err) if "check_session_iframe" in self.endpoint_context.provider_info: ec = self.endpoint_context salt = rndstr() - if not ec.sdb.is_session_revoked(sid): - authn_event = ec.sdb.get_authentication_event( - sid - ) # use the last session - _state = b64e( - as_bytes(json.dumps({"authn_time": authn_event["authn_time"]})) - ) + try: + authn_event = ec.session_manager.get_authentication_event(session_id) + except KeyError: + return self.error_response({}, "server_error", "No such session") + else: + if authn_event.is_active() is False: + return self.error_response({}, "server_error", "Authentication has timed out") - opbs_value = '' - if hasattr(ec.cookie_dealer, 'create_cookie'): - session_cookie = ec.cookie_dealer.create_cookie( - as_unicode(_state), - typ="session", - cookie_name=ec.cookie_name["session_management"], - same_site="None", - http_only=False, - ) + _state = b64e( + as_bytes(json.dumps({"authn_time": authn_event["authn_time"]})) + ) - opbs = session_cookie[ec.cookie_name["session_management"]] - opbs_value = opbs.value - else: - logger.debug("Failed to set Cookie, that's not configured in main configuration.") + opbs_value = '' + if hasattr(ec.cookie_dealer, 'create_cookie'): + session_cookie = ec.cookie_dealer.create_cookie( + as_unicode(_state), + typ="session", + cookie_name=ec.cookie_name["session_management"], + same_site="None", + http_only=False, + ) + opbs = session_cookie[ec.cookie_name["session_management"]] + opbs_value = opbs.value + else: logger.debug( - "compute_session_state: client_id=%s, origin=%s, opbs=%s, salt=%s", - request["client_id"], - resp_info["return_uri"], - opbs_value, - salt, - ) + "Failed to set Cookie, that's not configured in main configuration.") - _session_state = compute_session_state( - opbs_value, salt, request["client_id"], resp_info["return_uri"] - ) + logger.debug( + "compute_session_state: client_id=%s, origin=%s, opbs=%s, salt=%s", + request["client_id"], + resp_info["return_uri"], + opbs_value, + salt, + ) - if opbs_value: - if "cookie" in resp_info: - if isinstance(resp_info["cookie"], list): - resp_info["cookie"].append(session_cookie) - else: - append_cookie(resp_info["cookie"], session_cookie) + _session_state = compute_session_state( + opbs_value, salt, request["client_id"], resp_info["return_uri"] + ) + + if opbs_value: + if "cookie" in resp_info: + if isinstance(resp_info["cookie"], list): + resp_info["cookie"].append(session_cookie) else: - resp_info["cookie"] = session_cookie + append_cookie(resp_info["cookie"], session_cookie) + else: + resp_info["cookie"] = session_cookie - resp_info["response_args"]["session_state"] = _session_state + resp_info["response_args"]["session_state"] = _session_state # Mix-Up mitigation resp_info["response_args"]["iss"] = self.endpoint_context.issuer @@ -724,10 +762,7 @@ def process_request(self, request_info=None, **kwargs): if not _function: logger.debug("- authenticated -") logger.debug("AREQ keys: %s" % request_info.keys()) - res = self.authz_part2( - info["user"], info["authn_event"], request_info, cookie=cookie - ) - return res + return self.authz_part2(user=info["user"], request=request_info, cookie=cookie) try: # Run the authentication function diff --git a/src/oidcendpoint/oidc/old_authorization.py b/src/oidcendpoint/oidc/old_authorization.py new file mode 100755 index 0000000..050c56e --- /dev/null +++ b/src/oidcendpoint/oidc/old_authorization.py @@ -0,0 +1,740 @@ +import json +import logging + +from cryptojwt import BadSyntax +from cryptojwt.jwe.exception import JWEException +from cryptojwt.jws.exception import NoSuitableSigningKeys +from cryptojwt.jwt import utc_time_sans_frac +from cryptojwt.utils import as_bytes +from cryptojwt.utils import as_unicode +from cryptojwt.utils import b64d +from cryptojwt.utils import b64e +from oidcmsg import oidc +from oidcmsg.exception import ParameterError +from oidcmsg.oidc import Claims +from oidcmsg.oidc import verified_claim_name + +from oidcendpoint import rndstr +from oidcendpoint import sanitize +from oidcendpoint.authn_event import create_authn_event +from oidcendpoint.common.authorization import FORM_POST +from oidcendpoint.common.authorization import AllowedAlgorithms +from oidcendpoint.common.authorization import authn_args_gather +from oidcendpoint.common.authorization import get_uri +from oidcendpoint.common.authorization import inputs +from oidcendpoint.common.authorization import max_age +from oidcendpoint.cookie import append_cookie +from oidcendpoint.cookie import compute_session_state +from oidcendpoint.cookie import new_cookie +from oidcendpoint.endpoint import Endpoint +from oidcendpoint.exception import InvalidRequest +from oidcendpoint.exception import NoSuchAuthentication +from oidcendpoint.exception import RedirectURIError +from oidcendpoint.exception import ServiceError +from oidcendpoint.exception import TamperAllert +from oidcendpoint.exception import ToOld +from oidcendpoint.exception import UnknownClient +from oidcendpoint.oauth2.authorization import check_unknown_scopes_policy +from oidcendpoint.session import setup_session +from oidcendpoint.token_handler import UnknownToken +from oidcendpoint.user_authn.authn_context import pick_auth + +logger = logging.getLogger(__name__) + + +def proposed_user(request): + cn = verified_claim_name("it_token_hint") + if request.get(cn): + return request[cn].get("sub", "") + return "" + + +def acr_claims(request): + acrdef = None + + _claims = request.get("claims") + if isinstance(_claims, str): + _claims = Claims().from_json(_claims) + + if _claims: + _id_token_claim = _claims.get("id_token") + if _id_token_claim: + acrdef = _id_token_claim.get("acr") + + if isinstance(acrdef, dict): + if acrdef.get("value"): + return [acrdef["value"]] + elif acrdef.get("values"): + return acrdef["values"] + + +ALG_PARAMS = { + "sign": [ + "request_object_signing_alg", + "request_object_signing_alg_values_supported", + ], + "enc_alg": [ + "request_object_encryption_alg", + "request_object_encryption_alg_values_supported", + ], + "enc_enc": [ + "request_object_encryption_enc", + "request_object_encryption_enc_values_supported", + ], +} + + +def re_authenticate(request, authn): + if "prompt" in request and "login" in request["prompt"]: + if authn.done(request): + return True + + return False + + +class Authorization(Endpoint): + request_cls = oidc.AuthorizationRequest + response_cls = oidc.AuthorizationResponse + error_cls = oidc.AuthorizationErrorResponse + request_format = "urlencoded" + response_format = "urlencoded" + response_placement = "url" + endpoint_name = "authorization_endpoint" + name = "authorization" + default_capabilities = { + "claims_parameter_supported": True, + "request_parameter_supported": True, + "request_uri_parameter_supported": True, + "response_types_supported": [ + "code", + "token", + "id_token", + "code token", + "code id_token", + "id_token token", + "code id_token token", + ], + "response_modes_supported": ["query", "fragment", "form_post"], + "request_object_signing_alg_values_supported": None, + "request_object_encryption_alg_values_supported": None, + "request_object_encryption_enc_values_supported": None, + "grant_types_supported": ["authorization_code", "implicit"], + "claim_types_supported": ["normal", "aggregated", "distributed"], + } + + def __init__(self, endpoint_context, **kwargs): + Endpoint.__init__(self, endpoint_context, **kwargs) + # self.pre_construct.append(self._pre_construct) + self.post_parse_request.append(self._do_request_uri) + self.post_parse_request.append(self._post_parse_request) + self.allowed_request_algorithms = AllowedAlgorithms(ALG_PARAMS) + + def filter_request(self, endpoint_context, req): + return req + + def verify_response_type(self, request, cinfo): + # Checking response types + _registered = [set(rt.split(" ")) for rt in cinfo.get("response_types", [])] + if not _registered: + # If no response_type is registered by the client then we'll + # code which it the default according to the OIDC spec. + _registered = [{"code"}] + + # Is the asked for response_type among those that are permitted + return set(request["response_type"]) in _registered + + def _do_request_uri(self, request, client_id, endpoint_context, **kwargs): + _request_uri = request.get("request_uri") + if _request_uri: + # Do I do pushed authorization requests ? + if "pushed_authorization" in endpoint_context.endpoint: + # Is it a UUID urn + if _request_uri.startswith("urn:uuid:"): + _req = endpoint_context.par_db.get(_request_uri) + if _req: + del endpoint_context.par_db[_request_uri] # One time + # usage + return _req + else: + raise ValueError("Got a request_uri I can not resolve") + + # Do I support request_uri ? + _supported = endpoint_context.provider_info.get( + "request_uri_parameter_supported", True + ) + _registered = endpoint_context.cdb[client_id].get("request_uris") + # Not registered should be handled else where + if _registered: + # Before matching remove a possible fragment + _p = _request_uri.split("#") + # ignore registered fragments for now. + if _p[0] not in [l[0] for l in _registered]: + raise ValueError("A request_uri outside the registered") + + # Fetch the request + _resp = endpoint_context.httpc.get( + _request_uri, **endpoint_context.httpc_params + ) + if _resp.status_code == 200: + args = {"keyjar": endpoint_context.keyjar, "issuer": client_id} + _ver_request = self.request_cls().from_jwt(_resp.text, **args) + self.allowed_request_algorithms( + client_id, + endpoint_context, + _ver_request.jws_header.get("alg", "RS256"), + "sign", + ) + if _ver_request.jwe_header is not None: + self.allowed_request_algorithms( + client_id, + endpoint_context, + _ver_request.jws_header.get("alg"), + "enc_alg", + ) + self.allowed_request_algorithms( + client_id, + endpoint_context, + _ver_request.jws_header.get("enc"), + "enc_enc", + ) + # The protected info overwrites the non-protected + for k, v in _ver_request.items(): + request[k] = v + + request[verified_claim_name("request")] = _ver_request + else: + raise ServiceError("Got a %s response", _resp.status) + + return request + + def _post_parse_request(self, request, client_id, endpoint_context, **kwargs): + """ + Verify the authorization request. + + :param endpoint_context: + :param request: + :param client_id: + :param kwargs: + :return: + """ + if not request: + logger.debug("No AuthzRequest") + return self.error_cls( + error="invalid_request", error_description="Can not parse AuthzRequest" + ) + + request = self.filter_request(endpoint_context, request) + + _cinfo = endpoint_context.cdb.get(client_id) + if not _cinfo: + logger.error( + "Client ID ({}) not in client database".format(request["client_id"]) + ) + return self.error_cls( + error="unauthorized_client", error_description="unknown client" + ) + + # Is the asked for response_type among those that are permitted + if not self.verify_response_type(request, _cinfo): + return self.error_cls( + error="invalid_request", + error_description="Trying to use unregistered response_type", + ) + + # Get a verified redirect URI + try: + redirect_uri = get_uri(endpoint_context, request, "redirect_uri") + except (RedirectURIError, ParameterError, UnknownClient) as err: + return self.error_cls( + error="invalid_request", + error_description="{}:{}".format(err.__class__.__name__, err), + ) + else: + request["redirect_uri"] = redirect_uri + + return request + + def pick_authn_method(self, request, redirect_uri, acr=None, **kwargs): + auth_id = kwargs.get("auth_method_id") + if auth_id: + return self.endpoint_context.authn_broker[auth_id] + + if acr: + res = self.endpoint_context.authn_broker.pick(acr) + else: + res = pick_auth(self.endpoint_context, request) + + if res: + return res + else: + return { + "error": "access_denied", + "error_description": "ACR I do not support", + "return_uri": redirect_uri, + "return_type": request["response_type"], + } + + def setup_auth(self, request, redirect_uri, cinfo, cookie, acr=None, **kwargs): + """ + + :param request: The authorization/authentication request + :param redirect_uri: + :param cinfo: client info + :param cookie: + :param acr: Default ACR, if nothing else is specified + :param kwargs: + :return: + """ + + res = self.pick_authn_method(request, redirect_uri, acr, **kwargs) + + authn = res["method"] + authn_class_ref = res["acr"] + session = None + + try: + _auth_info = kwargs.get("authn", "") + if "upm_answer" in request and request["upm_answer"] == "true": + _max_age = 0 + else: + _max_age = max_age(request) + + identity, _ts = authn.authenticated_as( + cookie, authorization=_auth_info, max_age=_max_age + ) + except (NoSuchAuthentication, TamperAllert): + identity = None + _ts = 0 + except ToOld: + logger.info("Too old authentication") + identity = None + _ts = 0 + except UnknownToken: + logger.info("Unknown Token") + identity = None + _ts = 0 + else: + if identity: + try: # If identity['uid'] is in fact a base64 encoded JSON string + _id = b64d(as_bytes(identity["uid"])) + except BadSyntax: + pass + else: + identity = json.loads(as_unicode(_id)) + + try: + session = self.endpoint_context.sdb[identity.get("sid")] + except UnknownToken: + identity= None + else: + if not session or "revoked" in session: + identity = None + + authn_args = authn_args_gather(request, authn_class_ref, cinfo, **kwargs) + + # To authenticate or Not + if identity is None: # No! + logger.info("No active authentication") + logger.debug( + "Known clients: {}".format(list(self.endpoint_context.cdb.keys())) + ) + + if "prompt" in request and "none" in request["prompt"]: + # Need to authenticate but not allowed + return { + "error": "login_required", + "return_uri": redirect_uri, + "return_type": request["response_type"], + } + else: + return {"function": authn, "args": authn_args} + else: + logger.info("Active authentication") + if re_authenticate(request, authn): + # demand re-authentication + return {"function": authn, "args": authn_args} + else: + # I get back a dictionary + user = identity["uid"] + if "req_user" in kwargs: + sids = self.endpoint_context.sdb.get_sids_by_sub(kwargs["req_user"]) + if ( + sids + and user + != self.endpoint_context.sdb.get_authentication_event( + sids[-1] + ).uid + ): + logger.debug("Wanted to be someone else!") + if "prompt" in request and "none" in request["prompt"]: + # Need to authenticate but not allowed + return { + "error": "login_required", + "return_uri": redirect_uri, + } + else: + return {"function": authn, "args": authn_args} + + authn_event = None + if session: + authn_event = session.get('authn_event') + + if authn_event is None: + authn_event = create_authn_event( + identity["uid"], + identity.get("salt", ""), + authn_info=authn_class_ref, + time_stamp=_ts, + ) + + _exp_in = authn.kwargs.get("expires_in") + if _exp_in and "valid_until" in authn_event: + authn_event["valid_until"] = utc_time_sans_frac() + _exp_in + + return {"authn_event": authn_event, "identity": identity, "user": user} + + def extra_response_args(self, aresp): + return aresp + + def create_authn_response(self, request, sid): + """ + + :param self: + :param request: + :param sid: + :return: + """ + # create the response + aresp = self.response_cls() + if request.get("state"): + aresp["state"] = request["state"] + + if "response_type" in request and request["response_type"] == ["none"]: + fragment_enc = False + else: + _context = self.endpoint_context + _sinfo = _context.sdb[sid] + + if request.get("scope"): + aresp["scope"] = request["scope"] + + rtype = set(request["response_type"][:]) + handled_response_type = [] + + fragment_enc = True + if len(rtype) == 1 and "code" in rtype: + fragment_enc = False + + if "code" in request["response_type"]: + _code = aresp["code"] = _context.sdb[sid]["code"] + handled_response_type.append("code") + else: + _context.sdb.update(sid, code=None) + _code = None + + if "token" in rtype: + _dic = _context.sdb.upgrade_to_token(issue_refresh=False, key=sid) + + logger.debug("_dic: %s" % sanitize(_dic)) + for key, val in _dic.items(): + if key in aresp.parameters() and val is not None: + aresp[key] = val + + handled_response_type.append("token") + + _access_token = aresp.get("access_token", None) + + if "id_token" in request["response_type"]: + kwargs = {} + if {"code", "id_token", "token"}.issubset(rtype): + kwargs = {"code": _code, "access_token": _access_token} + elif {"code", "id_token"}.issubset(rtype): + kwargs = {"code": _code} + elif {"id_token", "token"}.issubset(rtype): + kwargs = {"access_token": _access_token} + + if request["response_type"] == ["id_token"]: + kwargs["user_claims"] = True + + try: + id_token = _context.idtoken.make(request, _sinfo, **kwargs) + except (JWEException, NoSuitableSigningKeys) as err: + logger.warning(str(err)) + resp = self.error_cls( + error="invalid_request", + error_description="Could not sign/encrypt id_token", + ) + return {"response_args": resp, "fragment_enc": fragment_enc} + + aresp["id_token"] = id_token + _sinfo["id_token"] = id_token + handled_response_type.append("id_token") + + not_handled = rtype.difference(handled_response_type) + if not_handled: + resp = self.error_cls( + error="invalid_request", error_description="unsupported_response_type" + ) + return {"response_args": resp, "fragment_enc": fragment_enc} + + aresp = self.extra_response_args(aresp) + + return {"response_args": aresp, "fragment_enc": fragment_enc} + + def aresp_check(self, aresp, request): + return "" + + def response_mode(self, request, **kwargs): + resp_mode = request["response_mode"] + if resp_mode == "form_post": + msg = FORM_POST.format( + inputs=inputs(kwargs["response_args"].to_dict()), + action=kwargs["return_uri"], + ) + kwargs.update( + { + "response_msg": msg, + "content_type": "text/html", + "response_placement": "body", + } + ) + elif resp_mode == "fragment": + if "fragment_enc" in kwargs: + if not kwargs["fragment_enc"]: + # Can't be done + raise InvalidRequest("wrong response_mode") + else: + kwargs["fragment_enc"] = True + elif resp_mode == "query": + if "fragment_enc" in kwargs: + if kwargs["fragment_enc"]: + # Can't be done + raise InvalidRequest("wrong response_mode") + else: + raise InvalidRequest("Unknown response_mode") + return kwargs + + def error_response(self, response_info, error, error_description): + resp = self.error_cls( + error=error, error_description=str(error_description) + ) + response_info["response_args"] = resp + return response_info + + def post_authentication(self, user, request, sid, **kwargs): + """ + Things that are done after a successful authentication. + + :param user: + :param request: + :param sid: + :param kwargs: + :return: A dictionary with 'response_args' + """ + + response_info = {} + + # Do the authorization + try: + permission = self.endpoint_context.authz( + user, client_id=request["client_id"] + ) + except ToOld as err: + return self.error_response( + response_info, + "access_denied", + "Authentication to old {}".format(err.args), + ) + except Exception as err: + return self.error_response( + response_info, "access_denied", "{}".format(err.args) + ) + else: + try: + self.endpoint_context.sdb.update(sid, permission=permission) + except Exception as err: + return self.error_response( + response_info, "server_error", "{}".format(err.args) + ) + + logger.debug("response type: %s" % request["response_type"]) + + if self.endpoint_context.sdb.is_session_revoked(sid): + return self.error_response( + response_info, "access_denied", "Session is revoked" + ) + + response_info = self.create_authn_response(request, sid) + + logger.debug("Known clients: {}".format(list(self.endpoint_context.cdb.keys()))) + + try: + redirect_uri = get_uri(self.endpoint_context, request, "redirect_uri") + except (RedirectURIError, ParameterError) as err: + return self.error_response( + response_info, "invalid_request", "{}".format(err.args) + ) + else: + response_info["return_uri"] = redirect_uri + + # Must not use HTTP unless implicit grant type and native application + # info = self.aresp_check(response_info['response_args'], request) + # if isinstance(info, ResponseMessage): + # return info + + _cookie = new_cookie( + self.endpoint_context, + uid=user, + sid=sid, + state=request["state"], + client_id=request["client_id"], + cookie_name=self.endpoint_context.cookie_name["session"], + ) + + # Now about the response_mode. Should not be set if it's obvious + # from the response_type. Knows about 'query', 'fragment' and + # 'form_post'. + + if "response_mode" in request: + try: + response_info = self.response_mode(request, **response_info) + except InvalidRequest as err: + return self.error_response( + response_info, "invalid_request", "{}".format(err.args) + ) + + response_info["cookie"] = [_cookie] + + return response_info + + def authz_part2(self, user, authn_event, request, **kwargs): + """ + After the authentication this is where you should end up + + :param user: + :param request: The Authorization Request + :param sid: Session key + :param kwargs: possible other parameters + :return: A redirect to the redirect_uri of the client + """ + sid = setup_session( + self.endpoint_context, request, user, authn_event=authn_event + ) + + try: + resp_info = self.post_authentication(user, request, sid, **kwargs) + except Exception as err: + return self.error_response({}, "server_error", err) + + if "check_session_iframe" in self.endpoint_context.provider_info: + ec = self.endpoint_context + salt = rndstr() + if not ec.sdb.is_session_revoked(sid): + authn_event = ec.sdb.get_authentication_event( + sid + ) # use the last session + _state = b64e( + as_bytes(json.dumps({"authn_time": authn_event["authn_time"]})) + ) + + opbs_value = '' + if hasattr(ec.cookie_dealer, 'create_cookie'): + session_cookie = ec.cookie_dealer.create_cookie( + as_unicode(_state), + typ="session", + cookie_name=ec.cookie_name["session_management"], + same_site="None", + http_only=False, + ) + + opbs = session_cookie[ec.cookie_name["session_management"]] + opbs_value = opbs.value + else: + logger.debug("Failed to set Cookie, that's not configured in main configuration.") + + logger.debug( + "compute_session_state: client_id=%s, origin=%s, opbs=%s, salt=%s", + request["client_id"], + resp_info["return_uri"], + opbs_value, + salt, + ) + + _session_state = compute_session_state( + opbs_value, salt, request["client_id"], resp_info["return_uri"] + ) + + if opbs_value: + if "cookie" in resp_info: + if isinstance(resp_info["cookie"], list): + resp_info["cookie"].append(session_cookie) + else: + append_cookie(resp_info["cookie"], session_cookie) + else: + resp_info["cookie"] = session_cookie + + resp_info["response_args"]["session_state"] = _session_state + + # Mix-Up mitigation + resp_info["response_args"]["iss"] = self.endpoint_context.issuer + resp_info["response_args"]["client_id"] = request["client_id"] + + return resp_info + + def process_request(self, request_info=None, **kwargs): + """ The AuthorizationRequest endpoint + + :param request_info: The authorization request as a Message instance + :return: dictionary + """ + + if isinstance(request_info, self.error_cls): + return request_info + + _cid = request_info["client_id"] + cinfo = self.endpoint_context.cdb[_cid] + logger.debug("client {}: {}".format(_cid, cinfo)) + + # this apply the default optionally deny_unknown_scopes policy + if cinfo: + check_unknown_scopes_policy(request_info, cinfo, self.endpoint_context) + + cookie = kwargs.get("cookie", "") + if cookie: + del kwargs["cookie"] + + if proposed_user(request_info): + kwargs["req_user"] = proposed_user(request_info) + else: + if request_info.get("login_hint"): + _login_hint = request_info["login_hint"] + if self.endpoint_context.login_hint_lookup: + kwargs["req_user"] = self.endpoint_context.login_hint_lookup[ + _login_hint + ] + + info = self.setup_auth( + request_info, request_info["redirect_uri"], cinfo, cookie, **kwargs + ) + + if "error" in info: + return info + + _function = info.get("function") + if not _function: + logger.debug("- authenticated -") + logger.debug("AREQ keys: %s" % request_info.keys()) + res = self.authz_part2( + info["user"], info["authn_event"], request_info, cookie=cookie + ) + return res + + try: + # Run the authentication function + return { + "http_response": _function(**info["args"]), + "return_uri": request_info["redirect_uri"], + } + except Exception as err: + logger.exception(err) + return {"http_response": "Internal error: {}".format(err)} diff --git a/src/oidcendpoint/oidc/refresh_token.py b/src/oidcendpoint/oidc/refresh_token.py index a2ffea1..759a120 100755 --- a/src/oidcendpoint/oidc/refresh_token.py +++ b/src/oidcendpoint/oidc/refresh_token.py @@ -2,16 +2,14 @@ from oidcmsg import oidc from oidcmsg.oauth2 import ResponseMessage -from oidcmsg.oidc import AccessTokenResponse from oidcmsg.oidc import RefreshAccessTokenRequest from oidcmsg.oidc import TokenErrorResponse +from oidcmsg.time_util import time_sans_frac from oidcendpoint import sanitize from oidcendpoint.client_authn import verify_client from oidcendpoint.cookie import new_cookie from oidcendpoint.endpoint import Endpoint -from oidcendpoint.token_handler import ExpiredToken -from oidcendpoint.userinfo import by_schema logger = logging.getLogger(__name__) @@ -32,9 +30,7 @@ def __init__(self, endpoint_context, **kwargs): self.post_parse_request.append(self._post_parse_request) def _refresh_access_token(self, req, **kwargs): - _sdb = self.endpoint_context.sdb - - # client_id = str(req["client_id"]) + _mngr = self.endpoint_context.session_manager if req["grant_type"] != "refresh_token": return self.error_cls( @@ -42,14 +38,29 @@ def _refresh_access_token(self, req, **kwargs): ) rtoken = req["refresh_token"] - try: - _info = _sdb.refresh_token(rtoken) - except ExpiredToken: + _session_info = _mngr.get_session_info_by_token(rtoken) + grant, token = _mngr.find_grant(_session_info["session_id"], rtoken) + if token.is_active is False: return self.error_cls( - error="invalid_request", error_description="Refresh token is expired" + error="invalid_request", error_description="Refresh token inactive" ) - return by_schema(AccessTokenResponse, **_info) + access_token = grant.mint_token( + 'access_token', + value=_mngr.token_handler["access_token"]( + _session_info["session_id"], + client_id=_session_info["client_id"], + aud=grant.resources, + user_claims=None, + scope=grant.scope, + sub=_session_info["client_session_info"]['sub'] + ), + expires_at=time_sans_frac() + 900, # 15 minutes from now + based_on=rtoken # Means the token (tok) was used to mint this token + ) + + return {"access_token": access_token, "token_type": "Bearer", + "expires_in": 900, "scope": grant.scope} def client_authentication(self, request, auth=None, **kwargs): """ @@ -114,10 +125,12 @@ def process_request(self, request=None, **kwargs): if isinstance(response_args, ResponseMessage): return response_args + _token = request["refresh_token"].replace(" ", "+") + _session_info = self.endpoint_context.session_manager.get_session_info_by_token(_token) _cookie = new_cookie( self.endpoint_context, - sub=self.endpoint_context.sdb[_token]["sub"], + sub=_session_info["client_session_info"]["sub"], cookie_name=self.endpoint_context.cookie_name["session"], ) _headers = [("Content-type", "application/json")] diff --git a/src/oidcendpoint/oidc/session.py b/src/oidcendpoint/oidc/session.py index e49a648..16b7c8a 100644 --- a/src/oidcendpoint/oidc/session.py +++ b/src/oidcendpoint/oidc/session.py @@ -11,6 +11,9 @@ from cryptojwt.jws.utils import alg2keytype from cryptojwt.jwt import JWT from cryptojwt.utils import as_bytes +from oidcendpoint.session_management import db_key + +from oidcendpoint.session_management import unpack_db_key from oidcmsg.exception import InvalidRequest from oidcmsg.message import Message from oidcmsg.oauth2 import ResponseMessage @@ -120,53 +123,40 @@ def do_back_channel_logout(self, cinfo, sub, sid): def clean_sessions(self, usids): # Clean out all sessions - _sdb = self.endpoint_context.sdb - _sso_db = self.endpoint_context.sdb.sso_db + _mngr = self.endpoint_context.session_manager + for sid in usids: - _state = _sdb[sid]["authn_req"]["state"] - # remove session information - del _sdb[sid] - # remove all states connected to this session id - _sdb.delete_kv2sid(_state, "state") - _sso_db.remove_session_id(sid) + _mngr.revoke_session(sid) def logout_all_clients(self, sid, client_id): - _sdb = self.endpoint_context.sdb - _sso_db = self.endpoint_context.sdb.sso_db - + _mngr = self.endpoint_context.session_manager + _user_id, _client_id = unpack_db_key(sid) # Find all RPs this user has logged it from - uid = _sso_db.get_uid_by_sid(sid) - if uid is None: - logger.debug("Can not translate sid:%s into a user id", sid) - return {} - - _client_sid = {} - usids = _sso_db.get_sids_by_uid(uid) - if usids is None: - logger.debug("No sessions found for uid: %s", uid) - return {} - - for usid in usids: - _client_sid[_sdb[usid]["authn_req"]["client_id"]] = usid + _user_session_info = _mngr.get([_user_id]) # Front-/Backchannel logout ? _cdb = self.endpoint_context.cdb _iss = self.endpoint_context.issuer bc_logouts = {} fc_iframes = {} - for _cid, _csid in _client_sid.items(): - if "backchannel_logout_uri" in _cdb[_cid]: - _sub = _sso_db.get_sub_by_sid(_csid) - _spec = self.do_back_channel_logout(_cdb[_cid], _sub, _csid) + sids = [] + for _client_id in _user_session_info["subordinate"]: + if "backchannel_logout_uri" in _cdb[_client_id]: + _sid = db_key(_user_id, _client_id) + _sub = _mngr.get([_user_id, _client_id])["sub"] + sids.append(_sid) + _spec = self.do_back_channel_logout(_cdb[_client_id], _sub, _sid) if _spec: - bc_logouts[_cid] = _spec - elif "frontchannel_logout_uri" in _cdb[_cid]: + bc_logouts[_client_id] = _spec + elif "frontchannel_logout_uri" in _cdb[_client_id]: # Construct an IFrame - _spec = do_front_channel_logout_iframe(_cdb[_cid], _iss, _csid) + _sid = db_key(_user_id, _client_id) + sids.append(_sid) + _spec = do_front_channel_logout_iframe(_cdb[_client_id], _iss, _sid) if _spec: - fc_iframes[_cid] = _spec + fc_iframes[_client_id] = _spec - self.clean_sessions(usids) + self.clean_sessions(sids) res = {} if bc_logouts: @@ -191,15 +181,14 @@ def unpack_signed_jwt(self, sjwt, sig_alg=""): def logout_from_client(self, sid, client_id): _cdb = self.endpoint_context.cdb - _sso_db = self.endpoint_context.sdb.sso_db + _mngr = self.endpoint_context.session_manager # Kill the session - _sdb = self.endpoint_context.sdb - _sdb.revoke_session(sid=sid) + _mngr.revoke_session(sid) res = {} if "backchannel_logout_uri" in _cdb[client_id]: - _sub = _sso_db.get_sub_by_sid(sid) + _sub = _mngr[sid]["sub"] _spec = self.do_back_channel_logout(_cdb[client_id], _sub, sid) if _spec: res["blu"] = {client_id: _spec} @@ -224,7 +213,7 @@ def process_request(self, request=None, cookie=None, **kwargs): :return: """ _cntx = self.endpoint_context - _sdb = _cntx.sdb + _mngr = _cntx.session_manager if "post_logout_redirect_uri" in request: if "id_token_hint" not in request: @@ -246,6 +235,7 @@ def process_request(self, request=None, cookie=None, **kwargs): _cookie_info = json.loads(as_unicode(b64d(as_bytes(part[0])))) logger.debug("Cookie info: {}".format(_cookie_info)) _sid = _cookie_info["sid"] + _user_id, _client_id = unpack_db_key(_sid) else: logger.debug("No relevant cookie") _sid = "" @@ -283,12 +273,15 @@ def process_request(self, request=None, cookie=None, **kwargs): else: auds = [] + if not _sid: + raise KeyError("Unknown session") + try: - session = _sdb[_sid] + session = _mngr[_sid] except KeyError: raise ValueError("Can't find any corresponding session") - client_id = session["authn_req"]["client_id"] + client_id = session["authorization_request"]["client_id"] # Does this match what's in the cookie ? if _cookie_info: if client_id != _cookie_info["client_id"]: @@ -321,7 +314,7 @@ def process_request(self, request=None, cookie=None, **kwargs): payload = { "sid": _sid, "client_id": client_id, - "user": session["authn_event"]["uid"], + "user": _user_id } # redirect user to OP logout verification page diff --git a/src/oidcendpoint/oidc/token.py b/src/oidcendpoint/oidc/token.py index 74eb7f9..a53dad6 100755 --- a/src/oidcendpoint/oidc/token.py +++ b/src/oidcendpoint/oidc/token.py @@ -4,15 +4,13 @@ from cryptojwt.jws.exception import NoSuitableSigningKeys from oidcmsg import oidc from oidcmsg.oauth2 import ResponseMessage -from oidcmsg.oidc import AccessTokenResponse from oidcmsg.oidc import TokenErrorResponse +from oidcmsg.time_util import time_sans_frac from oidcendpoint import sanitize from oidcendpoint.cookie import new_cookie from oidcendpoint.endpoint import Endpoint -from oidcendpoint.exception import MultipleCodeUsage -from oidcendpoint.token_handler import AccessCodeUsed -from oidcendpoint.userinfo import by_schema +from oidcendpoint.session_management import unpack_db_key logger = logging.getLogger(__name__) @@ -39,7 +37,7 @@ def __init__(self, endpoint_context, **kwargs): def _access_token(self, req, **kwargs): _context = self.endpoint_context - _sdb = _context.sdb + _mngr = _context.session_manager _log_debug = logger.debug if req["grant_type"] != "authorization_code": @@ -54,22 +52,17 @@ def _access_token(self, req, **kwargs): error="invalid_request", error_description="Missing code" ) - # Session might not exist or _access_code malformed - try: - _info = _sdb[_access_code] - except KeyError: - return self.error_cls( - error="invalid_grant", error_description="Code is invalid" - ) - - _authn_req = _info["authn_req"] + _session_info = _mngr.get_session_info_by_token(_access_code) + grant, code = _mngr.find_grant(_session_info["session_id"], _access_code) # assert that the code is valid - if _context.sdb.is_session_revoked(_access_code): + if code.is_active() is False: return self.error_cls( - error="invalid_grant", error_description="Session is revoked" + error="invalid_grant", error_description="Code is invalid" ) + _authn_req = _session_info["client_session_info"]["authorization_request"] + # If redirect_uri was in the initial authorization request # verify that the one given here is the correct one. if "redirect_uri" in _authn_req: @@ -83,26 +76,53 @@ def _access_token(self, req, **kwargs): issue_refresh = False if "issue_refresh" in kwargs: issue_refresh = kwargs["issue_refresh"] + else: + if "offline_access" in grant.scope: + issue_refresh = True + + token = grant.mint_token( + "access_token", + value=_mngr.token_handler["access_token"]( + _session_info["session_id"], + client_id=_session_info["client_id"], + aud=grant.resources, + user_claims=None, + scope=grant.scope, + sub=_session_info["client_session_info"]['sub'] + ), + expires_at=time_sans_frac() + 900, # 15 minutes from now + based_on=code + ) - # offline_access the default if nothing is specified - permissions = _info.get("permission", ["offline_access"]) - - if "offline_access" in _authn_req["scope"] and "offline_access" in permissions: - issue_refresh = True - - try: - _info = _sdb.upgrade_to_token(_access_code, issue_refresh=issue_refresh) - except AccessCodeUsed as err: - logger.error("%s" % err) - # Should revoke the token issued to this access code - _sdb.revoke_all_tokens(_access_code) - return self.error_cls( - error="access_denied", error_description="Access Code already used" + _response = { + "access_token": token.value, + "token_type": "Bearer", + "expires_in": 900, + "scope": grant.scope, + "state": _authn_req["state"] + } + + if issue_refresh: + refresh_token = grant.mint_token( + "refresh_token", + value=_mngr.token_handler["refresh_token"]( + _session_info["session_id"], + client_id=_session_info["client_id"], + aud=grant.resources, + user_claims=None, + scope=grant.scope, + sub=_session_info["client_session_info"]['sub'] + ), + based_on=code ) + _response["refresh_token"] = refresh_token.value + + code.register_usage() if "openid" in _authn_req["scope"]: try: - _idtoken = _context.idtoken.make(req, _info, _authn_req) + _idtoken = _context.idtoken.make(_session_info["user_id"], + _session_info["client_id"]) except (JWEException, NoSuitableSigningKeys) as err: logger.warning(str(err)) resp = self.error_cls( @@ -111,10 +131,9 @@ def _access_token(self, req, **kwargs): ) return resp - _sdb.update_by_token(_access_code, id_token=_idtoken) - _info = _sdb[_info["sid"]] + _response["id_token"] = _idtoken - return by_schema(AccessTokenResponse, **_info) + return _response def get_client_id_from_token(self, endpoint_context, token, request=None): sinfo = endpoint_context.sdb[token] @@ -129,26 +148,34 @@ def _post_parse_request(self, request, client_id="", **kwargs): :returns: """ - if "state" in request: - try: - sinfo = self.endpoint_context.sdb[request["code"]] - except KeyError: - logger.error("Code not present in SessionDB") - return self.error_cls(error="access_denied") - except MultipleCodeUsage: - logger.error("Access Code reused") - # Remove any access tokens issued - self.endpoint_context.sdb.revoke_all_tokens(request["code"]) - return self.error_cls(error="invalid_grant") - else: - state = sinfo["authn_req"]["state"] + _mngr = self.endpoint_context.session_manager + try: + _session_info = _mngr.get_session_info_by_token(request["code"]) + except KeyError: + logger.error("Access Code invalid") + return self.error_cls(error="invalid_grant") + + grant, code = _mngr.find_grant(_session_info["session_id"], request["code"]) + _auth_req = _session_info["client_session_info"]["authorization_request"] + if code.is_active(): + state = _auth_req["state"] + else: + logger.error("Access Code inactive") + # Remove any access tokens issued + if code.max_usage_reached(): + _mngr.revoke_token(_session_info["session_id"], + code.value, recursive=True) + return self.error_cls(error="invalid_grant") + if "state" in request: + # verify that state in this request is the same as the one in the + # authorization request if state != request["state"]: logger.error("State value mismatch") return self.error_cls(error="invalid_request") if "client_id" not in request: # Optional for access token request - request["client_id"] = client_id + request["client_id"] = _auth_req["client_id"] logger.debug("%s: %s" % (request.__class__.__name__, sanitize(request))) @@ -171,10 +198,13 @@ def process_request(self, request=None, **kwargs): if isinstance(response_args, ResponseMessage): return response_args - _access_token = response_args["access_token"] + _mngr = self.endpoint_context.session_manager + _tinfo = _mngr.token_handler.info(request["code"]) + _cs_info = _mngr[_tinfo["sid"]] + _cookie = new_cookie( self.endpoint_context, - sub=self.endpoint_context.sdb[_access_token]["sub"], + sub=_cs_info["sub"], cookie_name=self.endpoint_context.cookie_name["session"], ) _headers = [("Content-type", "application/json")] diff --git a/src/oidcendpoint/oidc/userinfo.py b/src/oidcendpoint/oidc/userinfo.py index a63cdc5..6b23495 100755 --- a/src/oidcendpoint/oidc/userinfo.py +++ b/src/oidcendpoint/oidc/userinfo.py @@ -9,8 +9,10 @@ from oidcmsg.oauth2 import ResponseMessage from oidcendpoint.endpoint import Endpoint +from oidcendpoint.session_management import unpack_db_key from oidcendpoint.token_handler import UnknownToken -from oidcendpoint.userinfo import collect_user_info +# from oidcendpoint.userinfo import collect_user_info +from oidcendpoint.userinfo import ClaimsInterface from oidcendpoint.util import OAUTH2_NOCACHE_HEADERS logger = logging.getLogger(__name__) @@ -34,13 +36,14 @@ class UserInfo(Endpoint): def __init__(self, endpoint_context, **kwargs): Endpoint.__init__(self, endpoint_context, **kwargs) - self.scope_to_claims = None # Add the issuer ID as an allowed JWT target self.allowed_targets.append("") + self.claims_interface = ClaimsInterface(endpoint_context, "userinfo", **kwargs) def get_client_id_from_token(self, endpoint_context, token, request=None): - sinfo = self.endpoint_context.sdb[token] - return sinfo["authn_req"]["client_id"] + _info = endpoint_context.session_manager.token_handler.info(token) + sinfo = self.endpoint_context.session_manager[_info["sid"]] + return sinfo["authorization_request"]["client_id"] def do_response(self, response_args=None, request=None, client_id="", **kwargs): @@ -96,23 +99,24 @@ def do_response(self, response_args=None, request=None, client_id="", **kwargs): return {"response": resp, "http_headers": http_headers} def process_request(self, request=None, **kwargs): - _sdb = self.endpoint_context.sdb - + _mngr = self.endpoint_context.session_manager + _info = _mngr.token_handler.info(request["access_token"]) + grant, token = _mngr.find_grant(_info['sid'], request["access_token"]) # should be an access token - if not _sdb.is_token_valid(request["access_token"]): + if token.is_active() is False: return self.error_cls( error="invalid_token", error_description="Invalid Token" ) - - session = _sdb.read(request["access_token"]) - + _user_id, _client_id = unpack_db_key(_info['sid']) + _cs_info = _mngr.get([_user_id, _client_id]) + _us_info = _mngr.get([_user_id]) allowed = True # if the authenticate is still active or offline_access is granted. - if session["authn_event"]["valid_until"] > utc_time_sans_frac(): + if _us_info["authentication_event"]["valid_until"] > utc_time_sans_frac(): pass else: logger.debug("authentication not valid: {} > {}".format( - session["authn_event"]["valid_until"], utc_time_sans_frac() + _us_info["authentication_event"]["valid_until"], utc_time_sans_frac() )) allowed = False @@ -122,14 +126,16 @@ def process_request(self, request=None, **kwargs): if allowed: # Scope can translate to userinfo_claims - info = collect_user_info(self.endpoint_context, session) + _restrictions = grant.claims.get("userinfo") + info = self.claims_interface.get_user_claims( + user_id=_user_id, claims_restriction=_restrictions) else: info = { "error": "invalid_request", "error_description": "Access not granted", } - return {"response_args": info, "client_id": session["authn_req"]["client_id"]} + return {"response_args": info, "client_id": _client_id} def parse_request(self, request, auth=None, **kwargs): """ diff --git a/src/oidcendpoint/old_id_token.py b/src/oidcendpoint/old_id_token.py new file mode 100755 index 0000000..293a5fe --- /dev/null +++ b/src/oidcendpoint/old_id_token.py @@ -0,0 +1,288 @@ +import logging + +from cryptojwt.jws.utils import left_hash +from cryptojwt.jwt import JWT + +from oidcendpoint.endpoint import construct_endpoint_info + +logger = logging.getLogger(__name__) + +DEF_SIGN_ALG = { + "id_token": "RS256", + "userinfo": "RS256", + "request_object": "RS256", + "client_secret_jwt": "HS256", + "private_key_jwt": "RS256", +} +DEF_LIFETIME = 300 + + +def include_session_id(endpoint_context, client_id, where): + """ + + :param endpoint_context: + :param client_id: + :param dir: front or back + :return: + """ + _pinfo = endpoint_context.provider_info + + # Am the OP supposed to support {dir}-channel log out and if so can + # it pass sid in logout token and ID Token + for param in ["{}channel_logout_supported", "{}channel_logout_session_supported"]: + try: + _supported = _pinfo[param.format(where)] + except KeyError: + return False + else: + if not _supported: + return False + + # Does the client support back-channel logout ? + try: + _val = endpoint_context.cdb[client_id]["{}channel_logout_uri".format(where)] + except KeyError: + return False + + return True + + +def get_sign_and_encrypt_algorithms( + endpoint_context, client_info, payload_type, sign=False, encrypt=False +): + args = {"sign": sign, "encrypt": encrypt} + if sign: + try: + args["sign_alg"] = client_info[ + "{}_signed_response_alg".format(payload_type) + ] + except KeyError: # Fall back to default + try: + args["sign_alg"] = endpoint_context.jwx_def["signing_alg"][payload_type] + except KeyError: + _def_sign_alg = DEF_SIGN_ALG[payload_type] + _supported = endpoint_context.provider_info[ + "{}_signing_alg_values_supported".format(payload_type) + ] + + if _def_sign_alg in _supported: + args["sign_alg"] = _def_sign_alg + else: + args["sign_alg"] = _supported[0] + + if encrypt: + try: + args["enc_alg"] = client_info["%s_encrypted_response_alg" % payload_type] + except KeyError: + try: + args["enc_alg"] = endpoint_context.jwx_def["encryption_alg"][ + payload_type + ] + except KeyError: + _supported = endpoint_context.provider_info[ + "{}_encryption_alg_values_supported".format(payload_type) + ] + args["enc_alg"] = _supported[0] + + try: + args["enc_enc"] = client_info["%s_encrypted_response_enc" % payload_type] + except KeyError: + try: + args["enc_enc"] = endpoint_context.jwx_def["encryption_enc"][ + payload_type + ] + except KeyError: + _supported = endpoint_context.provider_info[ + "{}_encryption_enc_values_supported".format(payload_type) + ] + args["enc_enc"] = _supported[0] + + return args + + +class IDToken(object): + default_capabilities = { + "id_token_signing_alg_values_supported": None, + "id_token_encryption_alg_values_supported": None, + "id_token_encryption_enc_values_supported": None, + } + + def __init__(self, endpoint_context, **kwargs): + self.endpoint_context = endpoint_context + self.kwargs = kwargs + self.enable_claims_per_client = kwargs.get("enable_claims_per_client", False) + self.scope_to_claims = None + self.provider_info = construct_endpoint_info( + self.default_capabilities, **kwargs + ) + + def payload( + self, + session, + acr="", + alg="RS256", + code=None, + access_token=None, + user_info=None, + auth_time=0, + lifetime=None, + extra_claims=None, + ): + """ + + :param session: Session information + :param acr: Default Assurance/Authentication context class reference + :param alg: Which signing algorithm to use for the IdToken + :param code: Access grant + :param access_token: Access Token + :param user_info: If user info are to be part of the IdToken + :param auth_time: + :param lifetime: Life time of the ID Token + :param extra_claims: extra claims to be added to the ID Token + :return: IDToken instance + """ + + _args = {"sub": session["sub"]} + + if lifetime is None: + lifetime = DEF_LIFETIME + + if auth_time: + _args["auth_time"] = auth_time + if acr: + _args["acr"] = acr + + if user_info: + try: + user_info = user_info.to_dict() + except AttributeError: + pass + + # Make sure that there are no name clashes + for key in ["iss", "sub", "aud", "exp", "acr", "nonce", "auth_time"]: + try: + del user_info[key] + except KeyError: + pass + + _args.update(user_info) + + if extra_claims is not None: + _args.update(extra_claims) + + # Left hashes of code and/or access_token + halg = "HS%s" % alg[-3:] + if code: + _args["c_hash"] = left_hash(code.encode("utf-8"), halg) + if access_token: + _args["at_hash"] = left_hash(access_token.encode("utf-8"), halg) + + authn_req = session["authn_req"] + if authn_req: + try: + _args["nonce"] = authn_req["nonce"] + except KeyError: + pass + + return {"payload": _args, "lifetime": lifetime} + + def sign_encrypt( + self, + session_info, + client_id, + code=None, + access_token=None, + user_info=None, + sign=True, + encrypt=False, + lifetime=None, + extra_claims=None, + ): + """ + Signed and or encrypt a IDToken + + :param session_info: Session information + :param client_id: Client ID + :param code: Access grant + :param access_token: Access Token + :param user_info: User information + :param sign: If the JWT should be signed + :param encrypt: If the JWT should be encrypted + :param extra_claims: Extra claims to be added to the ID Token + :return: IDToken as a signed and/or encrypted JWT + """ + + _cntx = self.endpoint_context + + client_info = _cntx.cdb[client_id] + alg_dict = get_sign_and_encrypt_algorithms( + _cntx, client_info, "id_token", sign=sign, encrypt=encrypt + ) + + _authn_event = session_info["authn_event"] + + _idt_info = self.payload( + session_info, + acr=_authn_event["authn_info"], + alg=alg_dict["sign_alg"], + code=code, + access_token=access_token, + user_info=user_info, + auth_time=_authn_event["authn_time"], + lifetime=lifetime, + extra_claims=extra_claims, + ) + + _jwt = JWT( + _cntx.keyjar, iss=_cntx.issuer, lifetime=_idt_info["lifetime"], **alg_dict + ) + + return _jwt.pack(_idt_info["payload"], recv=client_id) + + def make(self, req, sess_info, authn_req=None, user_claims=False, **kwargs): + _context = self.endpoint_context + + if authn_req: + _client_id = authn_req["client_id"] + else: + _client_id = req["client_id"] + + _cinfo = _context.cdb[_client_id] + + idtoken_claims = dict(self.kwargs.get("available_claims", {})) + if self.enable_claims_per_client: + idtoken_claims.update(_cinfo.get("id_token_claims", {})) + lifetime = self.kwargs.get("lifetime") + + userinfo = userinfo_in_id_token_claims(_context, sess_info, idtoken_claims) + + if user_claims: + info = collect_user_info(_context, sess_info) + if userinfo is None: + userinfo = info + else: + userinfo.update(info) + + # Should I add session ID + req_sid = include_session_id( + _context, _client_id, "back" + ) or include_session_id(_context, _client_id, "front") + + if req_sid: + xargs = { + "sid": _context.sdb.get_sid_by_sub_and_client_id( + sess_info["sub"], _client_id + ) + } + else: + xargs = {} + + return self.sign_encrypt( + sess_info, + _client_id, + sign=True, + user_info=userinfo, + lifetime=lifetime, + extra_claims=xargs, + **kwargs + ) diff --git a/src/oidcendpoint/session.py b/src/oidcendpoint/old_session.py similarity index 100% rename from src/oidcendpoint/session.py rename to src/oidcendpoint/old_session.py diff --git a/src/oidcendpoint/scopes.py b/src/oidcendpoint/scopes.py index b3413ae..e23dfdf 100644 --- a/src/oidcendpoint/scopes.py +++ b/src/oidcendpoint/scopes.py @@ -45,19 +45,23 @@ def available_claims(endpoint_context): return STANDARD_CLAIMS -def convert_scopes2claims(scopes, allowed_claims, map=None): +def convert_scopes2claims(scopes, allowed_claims=None, map=None): if map is None: map = SCOPE2CLAIMS res = {} - for scope in scopes: - try: - claims = dict( - [(name, None) for name in map[scope] if name in allowed_claims] - ) + if allowed_claims is None: + for scope in scopes: + claims = {name: None for name in map[scope]} res.update(claims) - except KeyError: - continue + else: + for scope in scopes: + try: + claims = {name: None for name in map[scope] if name in allowed_claims} + res.update(claims) + except KeyError: + continue + return res diff --git a/src/oidcendpoint/session_management.py b/src/oidcendpoint/session_management.py index cc65f2f..67ef2c1 100644 --- a/src/oidcendpoint/session_management.py +++ b/src/oidcendpoint/session_management.py @@ -2,6 +2,8 @@ import logging from oidcendpoint import rndstr +from oidcendpoint import token_handler +from oidcendpoint.token_handler import UnknownToken logger = logging.getLogger(__name__) @@ -77,7 +79,10 @@ def is_revoked(self): class UserSessionInfo(SessionInfo): - pass + def __init__(self, **kwargs): + SessionInfo.__init__(self, **kwargs) + if "logout_sid" not in self._db: + self._db["logout_sid"] = {} class ClientSessionInfo(SessionInfo): @@ -181,6 +186,9 @@ def get(self, path: list): user_info = self._db[uid] except KeyError: raise KeyError('No such UserID') + else: + if user_info is None: + raise KeyError('No such UserID') if client_id is None: return user_info @@ -243,10 +251,9 @@ def update(self, path, new_info): class SessionManager(Database): - def __init__(self, db, handler, userinfo=None, sub_func=None): + def __init__(self, db, handler, sub_func=None): Database.__init__(self, db) self.token_handler = handler - self.userinfo = userinfo self.salt = rndstr(32) # this allows the subject identifier minters to be defined by someone @@ -263,14 +270,18 @@ def __init__(self, db, handler, userinfo=None, sub_func=None): def get_user_info(self, uid): return self.get(uid) - def find_grant(self, session_id, token_value): - user_id, client_id = unpack_db_key(session_id) - client_info = self.get([user_id, client_id]) - for grant_id in client_info["subordinate"]: - grant = self.get([user_id, client_id, grant_id]) - for token in grant.issued_token: - if token.value == token_value: - return grant, token + def find_token(self, session_id, token_value): + """ + + :param session_id: Based on 3-tuple, user_id, client_id and grant_id + :param token_value: + :return: + """ + user_id, client_id, grant_id = unpack_db_key(session_id) + grant = self.get([user_id, client_id, grant_id]) + for token in grant.issued_token: + if token.value == token_value: + return token return None @@ -282,6 +293,8 @@ def create_session(self, authn_event, auth_req, user_id, client_id="", :param auth_req: Authorization Request :param client_id: Client ID :param user_id: User ID + :param sector_identifier: + :param sub_type: :param kwargs: extra keyword arguments :return: """ @@ -304,10 +317,16 @@ def create_session(self, authn_event, auth_req, user_id, client_id="", self.set([user_id, client_id], client_info) def _update_client_info(self, session_id, new_information): - _path = unpack_db_key(session_id) - _client_info = self.get(_path) + """ + + :param session_id: + :param new_information: + :return: + """ + _user_id, _client_id, _grant_id = unpack_db_key(session_id) + _client_info = self.get([_user_id, _client_id]) _client_info.update(new_information) - self.set(_path, _client_info) + self.set([_user_id, _client_id], _client_info) def do_sub(self, session_id, sector_id="", subject_type="public"): """ @@ -318,37 +337,80 @@ def do_sub(self, session_id, sector_id="", subject_type="public"): :param subject_type: 'pairwise'/'public' :return: """ - _path = unpack_db_key(session_id) - sub = self.sub_func[subject_type](_path[0], salt=self.salt, sector_identifier=sector_id) + _user_id, _client_id, _grant_id = unpack_db_key(session_id) + sub = self.sub_func[subject_type](_user_id, salt=self.salt, sector_identifier=sector_id) self._update_client_info(session_id, {'sub': sub}) return sub def __getitem__(self, item): return self.get(unpack_db_key(item)) - def revoke_token(self, session_id, token_value): - grant, token = self.find_grant(session_id, token_value) + def get_client_session_info(self, session_id): + _user_id, _client_id, _grant_id = unpack_db_key(session_id) + self.get([_user_id, _client_id]) + + def _revoke_dependent(self, grant, token): + for t in grant.issued_token: + if t.based_on == token.value: + t.revoked = True + self._revoke_dependent(grant, t) + + def revoke_token(self, session_id, token_value, recursive=False): + token = self.find_token(session_id, token_value) + if token is None: + raise UnknownToken() + token.revoked = True + if recursive: + grant = self[session_id] + self._revoke_dependent(grant, token) def get_sids_by_user_id(self, user_id): user_info = self.get([user_id]) return [db_key(user_id, c) for c in user_info['subordinate']] - def get_authentication_event(self, user_id): + def get_authentication_event(self, session_id): + _user_id = unpack_db_key(session_id)[0] try: - user_info = self.get([user_id]) + user_info = self.get([_user_id]) except KeyError: return None return user_info["authentication_event"] - def revoke_session(self, session_id): + def revoke_client_session(self, session_id): + _user_id, _client_id, _ = unpack_db_key(session_id) + _info = self.get([_user_id, _client_id]) + _info.revoke() + self.set([_user_id, _client_id], _info) + + def revoke_grant(self, session_id): _path = unpack_db_key(session_id) _info = self.get(_path) _info.revoke() self.set(_path, _info) def grants(self, session_id): - uid, cid = unpack_db_key(session_id) + uid, cid, _gid = unpack_db_key(session_id) _csi = self.get([uid, cid]) return [self.get([uid, cid, gid]) for gid in _csi['subordinate']] + + def get_session_info(self, session_id): + _user_id, _client_id, _grant_id = unpack_db_key(session_id) + return { + "session_id": session_id, + "user_id": _user_id, + "client_id": _client_id, + "user_session_info": self.get([_user_id]), + "client_session_info": self.get([_user_id, _client_id]), + "grant": self.get([_user_id, _client_id, _grant_id]) + } + + def get_session_info_by_token(self, token_value): + _token_info = self.token_handler.info(token_value) + return self.get_session_info(_token_info["sid"]) + + +def create_session_manager(endpoint_context, token_handler_args, db=None, sub_func=None): + _token_handler = token_handler.factory(endpoint_context, **token_handler_args) + return SessionManager(db, _token_handler, sub_func=sub_func) diff --git a/src/oidcendpoint/user_authn/user.py b/src/oidcendpoint/user_authn/user.py index dce3116..0f1a628 100755 --- a/src/oidcendpoint/user_authn/user.py +++ b/src/oidcendpoint/user_authn/user.py @@ -11,7 +11,6 @@ from cryptojwt.jwt import JWT from oidcendpoint import sanitize -from oidcendpoint.authn_event import create_authn_event from oidcendpoint.exception import FailedAuthentication from oidcendpoint.exception import ImproperlyConfigured from oidcendpoint.exception import InvalidCookieSign diff --git a/src/oidcendpoint/userinfo.py b/src/oidcendpoint/userinfo.py index 33b7160..e896adf 100755 --- a/src/oidcendpoint/userinfo.py +++ b/src/oidcendpoint/userinfo.py @@ -1,67 +1,71 @@ import logging -from oidcmsg.oidc import Claims - -from oidcendpoint import sanitize -from oidcendpoint.exception import FailedAuthentication -from oidcendpoint.exception import ImproperlyConfigured from oidcendpoint.scopes import convert_scopes2claims logger = logging.getLogger(__name__) -def id_token_claims(session, provider_info): - """ - Pick the IdToken claims from the request - - :param session: Session information - :return: The IdToken claims - """ - itc = update_claims(session, "id_token", provider_info=provider_info, old_claims={}) - return itc - - -def update_claims(session, about, provider_info, old_claims=None): - """ - - :param session: - :param about: userinfo or id_token - :param old_claims: - :return: claims or None - """ - - if old_claims is None: - old_claims = {} - - req = None - try: - req = session["authn_req"] - except KeyError: - pass - - if req: - try: - _claims = req["claims"][about] - except KeyError: - pass +class ClaimsInterface: + init_args = { + "add_claims_by_scope": False, + "enable_claims_per_client": False + } + + def __init__(self, endpoint_context, usage, **kwargs): + self.usage = usage # for instance introspection, id_token, userinfo + self.endpoint_context = endpoint_context + self.add_claims_by_scope = kwargs.get("add_claims_by_scope", + self.init_args["add_claims_by_scope"]) + self.enable_claims_per_client = kwargs.get("enable_claims_per_client", + self.init_args["enable_claims_per_client"]) + + def request_claims(self, user_id, client_id): + if self.usage in ["id_token", "userinfo"]: + _csi = self.endpoint_context.session_manager.get([user_id, client_id]) + if "claims" in _csi["authorization_request"]: + return _csi["authorization_request"]["claims"].get(self.usage) + + return {} + + def _get_client_claims(self, client_id): + client_info = self.endpoint_context.cdb.get(client_id, {}) + return client_info.get("{}_claims".format(self.usage), {}) + + def get_claims(self, client_id, user_id, scopes): + """ + + :param client_id: + :param user_id: + :param scopes: + :return: + """ + claims = self._get_client_claims(client_id) + if self.add_claims_by_scope: + _supported = self.endpoint_context.provider_info.get("scopes_supported", []) + if _supported: + _scopes = set(_supported).intersection(set(scopes)) + else: + _scopes = scopes + + _claims = convert_scopes2claims(_scopes, map=self.endpoint_context.scope2claims) + claims.update(_claims) + request_claims = self.request_claims(user_id=user_id, client_id=client_id) + claims.update(request_claims) + return claims + + def get_user_claims(self, user_id, claims_restriction): + """ + + :param user_id: User identifier + :param claims_restriction: Specifies the upper limit of what claims can be returned + :return: + """ + if claims_restriction: + user_info = self.endpoint_context.userinfo(user_id, client_id=None) + return {k: user_info.get(k) for k, v in claims_restriction.items() if + claims_match(user_info.get(k), v)} else: - if _claims: - # Deal only with supported claims - _unsup = [ - c - for c in _claims.keys() - if c not in provider_info["claims_supported"] - ] - for _c in _unsup: - del _claims[_c] - - # update with old claims, do not overwrite - for key, val in old_claims.items(): - if key not in _claims: - _claims[key] = val - return _claims - - return old_claims + return {} def claims_match(value, claimspec): @@ -76,6 +80,9 @@ def claims_match(value, claimspec): as key :return: Boolean """ + if value is None: + return False + if claimspec is None: # match anything return True @@ -110,103 +117,3 @@ def by_schema(cls, **kwa): :return: A dictionary with claims (keys) that meets the filter criteria """ return dict([(key, val) for key, val in kwa.items() if key in cls.c_param]) - - -def collect_user_info( - endpoint_context, session, userinfo_claims=None, scope_to_claims=None -): - """ - Collect information about a user. - This can happen in two cases, either when constructing an IdToken or - when returning user info through the UserInfo endpoint - - :param session: Session information - :param userinfo_claims: user info claims - :return: User info - """ - authn_req = session["authorization_request"] - if scope_to_claims is None: - scope_to_claims = endpoint_context.scope2claims - - _allowed = endpoint_context.scopes_handler.allowed_scopes( - authn_req["client_id"], endpoint_context - ) - supported_scopes = [s for s in authn_req["scope"] if s in _allowed] - if userinfo_claims is None: - _allowed_claims = endpoint_context.claims_handler.allowed_claims( - authn_req["client_id"], endpoint_context - ) - uic = convert_scopes2claims( - supported_scopes, _allowed_claims, map=scope_to_claims - ) - - # Get only keys allowed by user and update the dict if such info - # is stored in session - perm_set = session.get("permission") - if perm_set: - uic = {key: uic[key] for key in uic if key in perm_set} - - uic = update_claims( - session, - "userinfo", - provider_info=endpoint_context.provider_info, - old_claims=uic, - ) - - if uic: - userinfo_claims = Claims(**uic) - logger.debug("userinfo_claim: %s" % sanitize(userinfo_claims.to_dict())) - else: - userinfo_claims = None - logger.warning(("Client {} doesn't have any claims " - "belonging to one or more scopes.").format(authn_req["client_id"])) - raise ImproperlyConfigured("Some additional scopes doesn't have any claims.") - - logger.debug("Session info: %s" % sanitize(session)) - - authn_event = session["authn_event"] - if authn_event: - uid = authn_event["uid"] - else: - uid = session["uid"] - - info = endpoint_context.userinfo(uid, authn_req["client_id"], userinfo_claims) - - if "sub" in userinfo_claims: - if not claims_match(session["sub"], userinfo_claims["sub"]): - raise FailedAuthentication("Unmatched sub claim") - - info["sub"] = session["sub"] - try: - logger.debug("user_info_response: {}".format(info)) - except UnicodeEncodeError: - logger.debug("user_info_response: {}".format(info.encode("utf-8"))) - - return info - - -def userinfo_in_id_token_claims(endpoint_context, session, def_itc=None): - """ - Collect user info claims that are to be placed in the id token. - - :param endpoint_context: Endpoint context - :param session: Session information - :param def_itc: Default ID Token claims - :return: User information or None - """ - if def_itc: - itc = def_itc - else: - itc = {} - - itc.update(id_token_claims(session, provider_info=endpoint_context.provider_info)) - - if not itc: - return None - - _claims = by_schema(endpoint_context.id_token_schema, **itc) - - if _claims: - return collect_user_info(endpoint_context, session, _claims) - else: - return None diff --git a/src/oidcendpoint/util.py b/src/oidcendpoint/util.py index 89ac46c..64372e6 100755 --- a/src/oidcendpoint/util.py +++ b/src/oidcendpoint/util.py @@ -174,7 +174,8 @@ def split_uri(uri): def allow_refresh_token(endpoint_context): # Are there a refresh_token handler - refresh_token_handler = endpoint_context.sdb.handler.handler.get("refresh_token") + refresh_token_handler = endpoint_context.session_manager.token_handler.handler[ + "refresh_token"] # Is refresh_token grant type supported _token_supported = False diff --git a/tests/test_70_grant.py b/tests/test_01_grant.py similarity index 100% rename from tests/test_70_grant.py rename to tests/test_01_grant.py diff --git a/tests/test_71_sess_mngm_db.py b/tests/test_01_sess_mngm_db.py similarity index 100% rename from tests/test_71_sess_mngm_db.py rename to tests/test_01_sess_mngm_db.py diff --git a/tests/test_72_session_life.py b/tests/test_01_session_life.py similarity index 89% rename from tests/test_72_session_life.py rename to tests/test_01_session_life.py index d93cf4e..6a6771b 100644 --- a/tests/test_72_session_life.py +++ b/tests/test_01_session_life.py @@ -1,11 +1,11 @@ import os -import pytest from cryptojwt.key_jar import init_key_jar from oidcmsg.oidc import AccessTokenRequest from oidcmsg.oidc import AuthorizationRequest from oidcmsg.oidc import RefreshAccessTokenRequest from oidcmsg.time_util import time_sans_frac +import pytest from oidcendpoint import user_info from oidcendpoint.authn_event import create_authn_event @@ -19,11 +19,11 @@ from oidcendpoint.oidc.session import Session from oidcendpoint.oidc.token import AccessToken from oidcendpoint.session_management import ClientSessionInfo +from oidcendpoint.session_management import SessionManager +from oidcendpoint.session_management import UserSessionInfo from oidcendpoint.session_management import db_key from oidcendpoint.session_management import public_id -from oidcendpoint.session_management import SessionManager from oidcendpoint.session_management import unpack_db_key -from oidcendpoint.session_management import UserSessionInfo from oidcendpoint.token_handler import DefaultToken from oidcendpoint.token_handler import TokenHandler from oidcendpoint.user_authn.authn_context import INTERNETPROTOCOLPASSWORD @@ -40,16 +40,16 @@ def setup_token_handler(self): code_handler = DefaultToken(password, typ="A", lifetime=grant_expires_in) access_token_handler = DefaultToken( password, typ="T", lifetime=token_expires_in - ) + ) refresh_token_handler = DefaultToken( password, typ="R", lifetime=refresh_token_expires_in - ) + ) handler = TokenHandler( code_handler=code_handler, access_token_handler=access_token_handler, refresh_token_handler=refresh_token_handler, - ) + ) self.session_manager = SessionManager({}, handler=handler) @@ -62,7 +62,7 @@ def auth(self): scope=["openid", "mail", "address", "offline_access"], state="STATE", response_type="code", - ) + ) # The authentication returns a user ID user_id = "diana" @@ -73,9 +73,9 @@ def auth(self): self.session_manager.salt, authn_info=INTERNETPROTOCOLPASSWORD, authn_time=time_sans_frac(), - ) + ) - user_info = UserSessionInfo(authenticationEvent=authn_event) + user_info = UserSessionInfo(authentication_event=authn_event) self.session_manager.set([user_id], user_info) # Now for client session information @@ -83,7 +83,7 @@ def auth(self): client_info = ClientSessionInfo( authorization_request=AUTH_REQ, sub=public_id(user_id, self.session_manager.salt) - ) + ) self.session_manager.set([user_id, AUTH_REQ['client_id']], client_info) # The user consent module produces a Grant instance @@ -100,13 +100,13 @@ def auth(self): 'authorization_code', value=self.session_manager.token_handler["code"](user_id), expires_at=time_sans_frac() + 300 # 5 minutes from now - ) + ) - return code + return grant.id, code def test_code_flow(self): # code is a Token instance - code = self.auth() + _grant_id, code = self.auth() # next step is access token request @@ -117,7 +117,7 @@ def test_code_flow(self): grant_type="authorization_code", client_secret="hemligt", code=code.value - ) + ) # parse the token user_id = self.session_manager.token_handler.sid(TOKEN_REQ['code']) @@ -126,9 +126,9 @@ def test_code_flow(self): # token I can easily find the grant # client_info = self.session_manager.get([user_id, TOKEN_REQ['client_id']]) - session_id = db_key(user_id, TOKEN_REQ['client_id']) - grant, tok = self.session_manager.find_grant(session_id, - TOKEN_REQ['code']) + session_id = db_key(user_id, TOKEN_REQ['client_id'], _grant_id) + tok = self.session_manager.find_token(session_id, + TOKEN_REQ['code']) # Verify that it's of the correct type and can be used assert tok.type == "authorization_code" @@ -138,12 +138,14 @@ def test_code_flow(self): assert tok.supports_minting("access_token") + grant = self.session_manager[session_id] + access_token = grant.mint_token( 'access_token', value=self.session_manager.token_handler["access_token"](user_id), expires_at=time_sans_frac() + 900, # 15 minutes from now based_on=tok # Means the token (tok) was used to mint this token - ) + ) assert tok.supports_minting("refresh_token") @@ -151,7 +153,7 @@ def test_code_flow(self): 'refresh_token', value=self.session_manager.token_handler["refresh_token"](user_id), based_on=tok - ) + ) tok.register_usage() @@ -165,11 +167,11 @@ def test_code_flow(self): client_secret="hemligt", refresh_token=refresh_token.value, scope=["openid", "mail", "offline_access"] - ) + ) - session_id = db_key(user_id,REFRESH_TOKEN_REQ['client_id']) - grant, reftok = self.session_manager.find_grant(session_id, - REFRESH_TOKEN_REQ['refresh_token']) + session_id = db_key(user_id, REFRESH_TOKEN_REQ['client_id'], _grant_id) + reftok = self.session_manager.find_token(session_id, + REFRESH_TOKEN_REQ['refresh_token']) assert reftok.supports_minting("access_token") @@ -178,7 +180,7 @@ def test_code_flow(self): value=self.session_manager.token_handler["access_token"](user_id), expires_at=time_sans_frac() + 900, # 15 minutes from now based_on=reftok # Means the token (tok) was used to mint this token - ) + ) assert access_token_2.is_active() @@ -186,7 +188,7 @@ def test_code_flow(self): KEYDEFS = [ {"type": "RSA", "key": "", "use": ["sig"]}, {"type": "EC", "crv": "P-256", "use": ["sig"]}, - ] +] ISSUER = "https://example.com/" @@ -201,7 +203,7 @@ def test_code_flow(self): ["id_token", "token"], ["code", "token", "id_token"], ["none"], - ] +] CAPABILITIES = { "response_types_supported": [" ".join(x) for x in RESPONSE_TYPES_SUPPORTED], "token_endpoint_auth_methods_supported": [ @@ -209,19 +211,19 @@ def test_code_flow(self): "client_secret_basic", "client_secret_jwt", "private_key_jwt", - ], + ], "response_modes_supported": ["query", "fragment", "form_post"], "subject_types_supported": ["public", "pairwise"], "grant_types_supported": [ "authorization_code", "implicit", "urn:ietf:params:oauth:grant-type:jwt-bearer", - ], + ], "claim_types_supported": ["normal", "aggregated", "distributed"], "claims_parameter_supported": True, "request_parameter_supported": True, "request_uri_parameter_supported": True, - } +} BASEDIR = os.path.abspath(os.path.dirname(__file__)) @@ -248,8 +250,8 @@ def setup_session_manager(self): "key_defs": [ {"type": "oct", "bytes": "24", "use": ["enc"], "kid": "code"}, {"type": "oct", "bytes": "24", "use": ["enc"], "kid": "refresh"} - ], - }, + ], + }, "code": {"lifetime": 600}, "token": { "class": "oidcendpoint.jwt_token.JWTToken", @@ -260,50 +262,52 @@ def setup_session_manager(self): "email_verified", "phone_number", "phone_number_verified", - ], + ], "add_claim_by_scope": True, "aud": ["https://example.org/appl"], - }, }, - "refresh": {}, }, + "refresh": {}, + }, "endpoint": { "provider_config": { "path": "{}/.well-known/openid-configuration", "class": ProviderConfiguration, "kwargs": {}, - }, + }, "registration": { "path": "{}/registration", "class": Registration, "kwargs": {}, - }, + }, "authorization": { "path": "{}/authorization", "class": Authorization, "kwargs": {}, - }, + }, "token": {"path": "{}/token", "class": AccessToken, "kwargs": {}}, "session": {"path": "{}/end_session", "class": Session}, - }, + }, "client_authn": verify_client, "authentication": { "anon": { "acr": INTERNETPROTOCOLPASSWORD, "class": "oidcendpoint.user_authn.user.NoAuthn", "kwargs": {"user": "diana"}, - } - }, + } + }, "template_dir": "template", "userinfo": { "class": user_info.UserInfo, "kwargs": {"db_file": full_path("users.json")}, - }, + }, "id_token": {"class": IDToken}, - } + } self.endpoint_context = EndpointContext(conf, keyjar=KEYJAR) - self.session_manager = SessionManager({}, handler=self.endpoint_context.sdb.handler) + self.session_manager = self.endpoint_context.session_manager + # self.session_manager = SessionManager({}, handler=self.endpoint_context.sdb.handler) + # self.endpoint_context.session_manager = self.session_manager def auth(self): # Start with an authentication request @@ -314,7 +318,7 @@ def auth(self): scope=["openid", "mail", "address", "offline_access"], state="STATE", response_type="code", - ) + ) # The authentication returns a user ID user_id = "diana" @@ -331,13 +335,13 @@ def auth(self): salt, authn_info=INTERNETPROTOCOLPASSWORD, authn_time=time_sans_frac(), - ) + ) client_info = ClientSessionInfo( authorization_request=AUTH_REQ, - authenticationEvent=authn_event, + authentication_event=authn_event, sub=public_id(user_id, salt) - ) + ) self.session_manager.set([user_id, AUTH_REQ['client_id']], client_info) # The user consent module produces a Grant instance @@ -353,9 +357,9 @@ def auth(self): code = grant.mint_token( 'authorization_code', value=self.session_manager.token_handler["code"]( - db_key(user_id, AUTH_REQ['client_id'])), + db_key(user_id, AUTH_REQ['client_id'], grant.id)), expires_at=time_sans_frac() + 300 # 5 minutes from now - ) + ) return code @@ -372,19 +376,17 @@ def test_code_flow(self): grant_type="authorization_code", client_secret="hemligt", code=code.value - ) + ) # parse the token session_id = self.session_manager.token_handler.sid(TOKEN_REQ['code']) - user_id, client_id = unpack_db_key(session_id) + user_id, client_id, grant_id = unpack_db_key(session_id) # Now given I have the client_id from the request and the user_id from the # token I can easily find the grant # client_info = self.session_manager.get([user_id, TOKEN_REQ['client_id']]) - session_id = db_key(user_id, TOKEN_REQ['client_id']) - grant, tok = self.session_manager.find_grant(session_id, - TOKEN_REQ['code']) + tok = self.session_manager.find_token(session_id, TOKEN_REQ['code']) # Verify that it's of the correct type and can be used assert tok.type == "authorization_code" @@ -398,31 +400,33 @@ def test_code_flow(self): assert tok.supports_minting("access_token") + grant = self.session_manager[session_id] + user_claims = self.endpoint_context.userinfo(user_id, client_id=TOKEN_REQ["client_id"], user_info_claims=grant.claims) access_token = grant.mint_token( 'access_token', value=self.session_manager.token_handler["access_token"]( - db_key(user_id, client_id), + session_id, client_id=TOKEN_REQ['client_id'], aud=grant.resources, user_claims=user_claims, scope=grant.scope, sub=client_info['sub'] - ), + ), expires_at=time_sans_frac() + 900, # 15 minutes from now based_on=tok # Means the token (tok) was used to mint this token - ) + ) # this test is include in the mint_token methods # assert tok.supports_minting("refresh_token") refresh_token = grant.mint_token( 'refresh_token', - value=self.session_manager.token_handler["refresh_token"](db_key(user_id, client_id)), + value=self.session_manager.token_handler["refresh_token"](session_id), based_on=tok - ) + ) tok.register_usage() @@ -436,10 +440,10 @@ def test_code_flow(self): client_secret="hemligt", refresh_token=refresh_token.value, scope=["openid", "mail", "offline_access"] - ) + ) - session_id = db_key(user_id, REFRESH_TOKEN_REQ['client_id']) - grant, reftok = self.session_manager.find_grant(session_id, + session_id = db_key(user_id, REFRESH_TOKEN_REQ['client_id'], grant_id) + reftok = self.session_manager.find_token(session_id, REFRESH_TOKEN_REQ['refresh_token']) # Can I use this token to mint another token ? @@ -451,14 +455,17 @@ def test_code_flow(self): access_token_2 = grant.mint_token( 'access_token', value=self.session_manager.token_handler["access_token"]( - db_key(user_id, client_id), + session_id, sub=client_info['sub'], client_id=TOKEN_REQ['client_id'], aud=grant.resources, user_claims=user_claims - ), + ), expires_at=time_sans_frac() + 900, # 15 minutes from now based_on=reftok # Means the refresh token (reftok) was used to mint this token - ) + ) assert access_token_2.is_active() + + token_info = self.session_manager.token_handler.info(access_token_2.value) + assert token_info diff --git a/tests/test_03_id_token.py b/tests/test_03_id_token.py index 7ab1da0..37bfcf0 100644 --- a/tests/test_03_id_token.py +++ b/tests/test_03_id_token.py @@ -2,20 +2,25 @@ import os import time -import pytest from cryptojwt.jws import jws from cryptojwt.jwt import JWT from cryptojwt.key_jar import KeyJar from oidcmsg.oidc import AuthorizationRequest from oidcmsg.oidc import RegistrationResponse +from oidcmsg.time_util import time_sans_frac +import pytest +from oidcendpoint.authn_event import create_authn_event from oidcendpoint.client_authn import verify_client from oidcendpoint.endpoint_context import EndpointContext +from oidcendpoint.grant import Grant from oidcendpoint.id_token import IDToken from oidcendpoint.id_token import get_sign_and_encrypt_algorithms from oidcendpoint.oidc import userinfo from oidcendpoint.oidc.authorization import Authorization from oidcendpoint.oidc.token import AccessToken +from oidcendpoint.session_management import SessionManager +from oidcendpoint.session_management import db_key from oidcendpoint.user_authn.authn_context import INTERNETPROTOCOLPASSWORD from oidcendpoint.user_info import UserInfo @@ -34,15 +39,34 @@ def full_path(local_file): USERS = json.loads(open(full_path("users.json")).read()) USERINFO = UserInfo(USERS) -AREQN = AuthorizationRequest( +AREQ = AuthorizationRequest( response_type="code", - client_id="client1", + client_id="client_1", redirect_uri="http://example.com/authz", scope=["openid"], state="state000", nonce="nonce", ) +AREQS = AuthorizationRequest( + response_type="code", + client_id="client_1", + redirect_uri="http://example.com/authz", + scope=["openid", "address", "email"], + state="state000", + nonce="nonce", +) + +AREQRC = AuthorizationRequest( + response_type="code", + client_id="client_1", + redirect_uri="http://example.com/authz", + scope=["openid", "address", "email"], + state="state000", + nonce="nonce", + claims={"id_token": {"nickname": None}} +) + conf = { "issuer": "https://example.com/", "password": "mycket hemligt", @@ -72,12 +96,14 @@ def full_path(local_file): "kwargs": {"user": "diana"}, } }, - "userinfo": {"class": "oidcendpoint.user_info.UserInfo", "kwargs": {"db": USERS},}, + "userinfo": {"class": "oidcendpoint.user_info.UserInfo", "kwargs": {"db": USERS}, }, "client_authn": verify_client, "template_dir": "template", "id_token": {"class": IDToken, "kwargs": {"foo": "bar"}}, } +USER_ID = "diana" + class TestEndpoint(object): @pytest.fixture(autouse=True) @@ -93,112 +119,128 @@ def create_idtoken(self): self.endpoint_context.keyjar.add_symmetric( "client_1", "hemligtochintekort", ["sig", "enc"] ) + self.session_manager = self.endpoint_context.session_manager + self.user_id = USER_ID + + def _create_session(self, auth_req, sub_type="public", sector_identifier=''): + client_id = auth_req['client_id'] + ae = create_authn_event(self.user_id, self.session_manager.salt) + self.session_manager.create_session(ae, auth_req, self.user_id, client_id=client_id, + sub_type=sub_type, sector_identifier=sector_identifier) + return db_key(self.user_id, client_id) + + def _do_grant(self, auth_req): + client_id = auth_req['client_id'] + # The user consent module produces a Grant instance + grant = Grant(scope=AREQ['scope'], resources=[client_id]) + + # the grant is assigned to a session (user_id, client_id) + self.session_manager.set([self.user_id, client_id, grant.id], grant) + return db_key(self.user_id, client_id, grant.id) + + def _mint_code(self, grant): + # Constructing an authorization code is now done + return grant.mint_token( + 'authorization_code', + value=self.session_manager.token_handler["code"](self.user_id), + expires_at=time_sans_frac() + 300 # 5 minutes from now + ) - def test_id_token_payload_0(self): - session_info = {"authn_req": AREQN, "sub": "1234567890"} - info = self.endpoint_context.idtoken.payload(session_info) - assert info["payload"] == {"sub": "1234567890", "nonce": "nonce"} - assert info["lifetime"] == 300 + def _mint_access_token(self, grant, client_id, token_ref): + _csi = self.session_manager.get([self.user_id, client_id]) + return grant.mint_token( + 'access_token', + value=self.session_manager.token_handler["access_token"]( + db_key(self.user_id, client_id, grant.id), + client_id=client_id, + aud=grant.resources, + user_claims=None, + scope=grant.scope, + sub=_csi['sub'] + ), + expires_at=time_sans_frac() + 900, # 15 minutes from now + based_on=token_ref # Means the token (tok) was used to mint this token + ) - def test_id_token_payload_1(self): - session_info = {"authn_req": AREQN, "sub": "1234567890"} + def test_id_token_payload_0(self): + self._create_session(AREQ) + session_id = self._do_grant(AREQ) - info = self.endpoint_context.idtoken.payload(session_info) - assert info["payload"] == {"nonce": "nonce", "sub": "1234567890"} - assert info["lifetime"] == 300 + payload = self.endpoint_context.idtoken.payload(session_id) + assert set(payload.keys()) == {"sub", "nonce", "auth_time"} def test_id_token_payload_with_code(self): - session_info = {"authn_req": AREQN, "sub": "1234567890"} + self._create_session(AREQ) + session_id = self._do_grant(AREQ) + grant = self.session_manager[session_id] - info = self.endpoint_context.idtoken.payload( - session_info, code="ABCDEFGHIJKLMNOP" + code = self._mint_code(grant) + payload = self.endpoint_context.idtoken.payload( + session_id, AREQ["client_id"], code=code.value ) - assert info["payload"] == { - "nonce": "nonce", - "c_hash": "5-i4nCch0pDMX1VCVJHs1g", - "sub": "1234567890", - } - assert info["lifetime"] == 300 + assert set(payload.keys()) == {"nonce", "c_hash", "sub", "auth_time"} def test_id_token_payload_with_access_token(self): - session_info = {"authn_req": AREQN, "sub": "1234567890"} + self._create_session(AREQ) + session_id = self._do_grant(AREQ) + grant = self.session_manager[session_id] - info = self.endpoint_context.idtoken.payload( - session_info, access_token="012ABCDEFGHIJKLMNOP" + code = self._mint_code(grant) + access_token = self._mint_access_token(grant, AREQ['client_id'], code) + + payload = self.endpoint_context.idtoken.payload( + session_id, AREQ["client_id"], access_token=access_token.value ) - assert info["payload"] == { - "nonce": "nonce", - "at_hash": "bKkyhbn1CC8IMdavzOV-Qg", - "sub": "1234567890", - } - assert info["lifetime"] == 300 + assert set(payload.keys()) == {"nonce", "at_hash", "sub", "auth_time"} def test_id_token_payload_with_code_and_access_token(self): - session_info = {"authn_req": AREQN, "sub": "1234567890"} + self._create_session(AREQ) + session_id = self._do_grant(AREQ) + grant = self.session_manager[session_id] + + code = self._mint_code(grant) + access_token = self._mint_access_token(grant, AREQ['client_id'], code) - info = self.endpoint_context.idtoken.payload( - session_info, access_token="012ABCDEFGHIJKLMNOP", code="ABCDEFGHIJKLMNOP" + payload = self.endpoint_context.idtoken.payload( + session_id, AREQ["client_id"], access_token=access_token.value, code=code.value ) - assert info["payload"] == { - "nonce": "nonce", - "at_hash": "bKkyhbn1CC8IMdavzOV-Qg", - "c_hash": "5-i4nCch0pDMX1VCVJHs1g", - "sub": "1234567890", - } - assert info["lifetime"] == 300 + assert set(payload.keys()) == {"nonce", "c_hash", "at_hash", "sub", "auth_time"} def test_id_token_payload_with_userinfo(self): - session_info = {"authn_req": AREQN, "sub": "1234567890"} + self._create_session(AREQ) + session_id = self._do_grant(AREQ) + grant = self.session_manager[session_id] + grant.claims = {"id_token": {"given_name": None}} - info = self.endpoint_context.idtoken.payload( - session_info, user_info={"given_name": "Diana"} - ) - assert info["payload"] == { - "nonce": "nonce", - "given_name": "Diana", - "sub": "1234567890", - } - assert info["lifetime"] == 300 + payload = self.endpoint_context.idtoken.payload(session_id=session_id) + assert set(payload.keys()) == {"nonce", "given_name", "sub", "auth_time"} def test_id_token_payload_many_0(self): - session_info = {"authn_req": AREQN, "sub": "1234567890"} - - info = self.endpoint_context.idtoken.payload( - session_info, - user_info={"given_name": "Diana"}, - access_token="012ABCDEFGHIJKLMNOP", - code="ABCDEFGHIJKLMNOP", + self._create_session(AREQ) + session_id = self._do_grant(AREQ) + grant = self.session_manager[session_id] + grant.claims = {"id_token": {"given_name": None}} + code = self._mint_code(grant) + access_token = self._mint_access_token(grant, AREQ['client_id'], code) + + payload = self.endpoint_context.idtoken.payload( + session_id, AREQ["client_id"], + access_token=access_token.value, + code=code.value ) - assert info["payload"] == { - "nonce": "nonce", - "given_name": "Diana", - "at_hash": "bKkyhbn1CC8IMdavzOV-Qg", - "c_hash": "5-i4nCch0pDMX1VCVJHs1g", - "sub": "1234567890", - } - assert info["lifetime"] == 300 + assert set(payload.keys()) == {"nonce", "c_hash", "at_hash", "sub", "auth_time", + "given_name"} def test_sign_encrypt_id_token(self): - client_info = RegistrationResponse( - id_token_signed_response_alg="RS512", client_id="client_1" - ) - session_info = { - "authn_req": AREQN, - "sub": "sub", - "authn_event": {"authn_info": "loa2", "authn_time": time.time()}, - } + self._create_session(AREQ) + session_id = self._do_grant(AREQ) - self.endpoint_context.jwx_def["signing_alg"] = {"id_token": "RS384"} - self.endpoint_context.cdb["client_1"] = client_info.to_dict() - - _token = self.endpoint_context.idtoken.sign_encrypt( - session_info, "client_1", sign=True - ) + _token = self.endpoint_context.idtoken.sign_encrypt(session_id, AREQ['client_id'], sign=True) assert _token _jws = jws.factory(_token) - assert _jws.jwt.headers["alg"] == "RS512" + assert _jws.jwt.headers["alg"] == "RS256" client_keyjar = KeyJar() _jwks = self.endpoint_context.keyjar.export_jwks() @@ -210,10 +252,9 @@ def test_sign_encrypt_id_token(self): assert res["aud"] == ["client_1"] def test_get_sign_algorithm(self): - client_info = RegistrationResponse() - endpoint_context = EndpointContext(conf) + client_info = self.endpoint_context.cdb[AREQ['client_id']] algs = get_sign_and_encrypt_algorithms( - endpoint_context, client_info, "id_token", sign=True + self.endpoint_context, client_info, "id_token", sign=True ) # default signing alg assert algs == {"sign": True, "encrypt": False, "sign_alg": "RS256"} @@ -260,20 +301,12 @@ def test_get_sign_algorithm_4(self): assert algs == {"sign": True, "encrypt": False, "sign_alg": "RS512"} def test_available_claims(self): - session_info = { - "authn_req": AREQN, - "sub": "sub", - "authn_event": { - "authn_info": "loa2", - "authn_time": time.time(), - "uid": "diana", - }, - } - self.endpoint_context.idtoken.kwargs["available_claims"] = { - "nickname": {"essential": True} - } - req = {"client_id": "client_1"} - _token = self.endpoint_context.idtoken.make(req, session_info) + self._create_session(AREQ) + session_id = self._do_grant(AREQ) + grant = self.session_manager[session_id] + grant.claims = {"id_token": {"nickname": {"essential": True}}} + + _token = self.endpoint_context.idtoken.make(session_id=session_id) assert _token client_keyjar = KeyJar() _jwks = self.endpoint_context.keyjar.export_jwks() @@ -283,39 +316,34 @@ def test_available_claims(self): assert "nickname" in res def test_no_available_claims(self): - session_info = { - "authn_req": AREQN, - "sub": "sub", - "authn_event": { - "authn_info": "loa2", - "authn_time": time.time(), - "uid": "diana", - }, - } + self._create_session(AREQ) + session_id = self._do_grant(AREQ) + grant = self.session_manager[session_id] + grant.claims = {"id_token":{"foobar": None}} + req = {"client_id": "client_1"} - _token = self.endpoint_context.idtoken.make(req, session_info) + _token = self.endpoint_context.idtoken.make(session_id=session_id) assert _token client_keyjar = KeyJar() _jwks = self.endpoint_context.keyjar.export_jwks() client_keyjar.import_jwks(_jwks, self.endpoint_context.issuer) _jwt = JWT(key_jar=client_keyjar, iss="client_1") res = _jwt.unpack(_token) - assert "nickname" not in res + assert "foobar" not in res def test_client_claims(self): - session_info = { - "authn_req": AREQN, - "sub": "sub", - "authn_event": { - "authn_info": "loa2", - "authn_time": time.time(), - "uid": "diana", - }, - } - self.endpoint_context.idtoken.enable_claims_per_client = True + self._create_session(AREQ) + session_id = self._do_grant(AREQ) + grant = self.session_manager[session_id] + + self.endpoint_context.idtoken.claims_interface.enable_claims_per_client = True self.endpoint_context.cdb["client_1"]["id_token_claims"] = {"address": None} - req = {"client_id": "client_1"} - _token = self.endpoint_context.idtoken.make(req, session_info) + + _claims = self.endpoint_context.idtoken.claims_interface.get_claims( + client_id=AREQ["client_id"], user_id=USER_ID, scopes=AREQ["scope"]) + grant.claims = {'id_token': _claims} + + _token = self.endpoint_context.idtoken.make(session_id=session_id) assert _token client_keyjar = KeyJar() _jwks = self.endpoint_context.keyjar.export_jwks() @@ -326,50 +354,68 @@ def test_client_claims(self): assert "nickname" not in res def test_client_claims_with_default(self): - session_info = { - "authn_req": AREQN, - "sub": "sub", - "authn_event": { - "authn_info": "loa2", - "authn_time": time.time(), - "uid": "diana", - }, - } - self.endpoint_context.cdb["client_1"]["id_token_claims"] = {"address": None} - self.endpoint_context.idtoken.kwargs["available_claims"] = { - "nickname": {"essential": True} - } - self.endpoint_context.idtoken.enable_claims_per_client = True - req = {"client_id": "client_1"} - _token = self.endpoint_context.idtoken.make(req, session_info) + self._create_session(AREQ) + session_id = self._do_grant(AREQ) + grant = self.session_manager[session_id] + + # self.endpoint_context.cdb["client_1"]["id_token_claims"] = {"address": None} + # self.endpoint_context.idtoken.enable_claims_per_client = True + + _claims = self.endpoint_context.idtoken.claims_interface.get_claims( + client_id=AREQ["client_id"], user_id=USER_ID, scopes=AREQ["scope"]) + grant.claims = {"id_token": _claims} + + _token = self.endpoint_context.idtoken.make(session_id=session_id) assert _token client_keyjar = KeyJar() _jwks = self.endpoint_context.keyjar.export_jwks() client_keyjar.import_jwks(_jwks, self.endpoint_context.issuer) _jwt = JWT(key_jar=client_keyjar, iss="client_1") res = _jwt.unpack(_token) - assert "address" in res - assert "nickname" in res - def test_client_claims_disabled(self): - # enable_claims_per_client defaults to False - session_info = { - "authn_req": AREQN, - "sub": "sub", - "authn_event": { - "authn_info": "loa2", - "authn_time": time.time(), - "uid": "diana", - }, - } - self.endpoint_context.cdb["client_1"]["id_token_claims"] = {"address": None} - req = {"client_id": "client_1"} - _token = self.endpoint_context.idtoken.make(req, session_info) + # No user info claims should be there + assert "address" not in res + assert "nickname" not in res + + def test_client_claims_scopes(self): + self._create_session(AREQS) + session_id = self._do_grant(AREQS) + grant = self.session_manager[session_id] + + self.endpoint_context.idtoken.claims_interface.add_claims_by_scope = True + _claims = self.endpoint_context.idtoken.claims_interface.get_claims( + client_id=AREQS["client_id"], user_id=USER_ID, scopes=AREQS["scope"]) + grant.claims = {"id_token": _claims} + + _token = self.endpoint_context.idtoken.make(session_id=session_id) assert _token client_keyjar = KeyJar() _jwks = self.endpoint_context.keyjar.export_jwks() client_keyjar.import_jwks(_jwks, self.endpoint_context.issuer) _jwt = JWT(key_jar=client_keyjar, iss="client_1") res = _jwt.unpack(_token) - assert "address" not in res + assert "address" in res + assert "email" in res assert "nickname" not in res + + def test_client_claims_scopes_and_request_claims(self): + self._create_session(AREQRC) + session_id = self._do_grant(AREQRC) + grant = self.session_manager[session_id] + + self.endpoint_context.idtoken.claims_interface.add_claims_by_scope = True + _claims = self.endpoint_context.idtoken.claims_interface.get_claims( + client_id=AREQRC["client_id"], user_id=USER_ID, scopes=AREQRC["scope"]) + grant.claims = {"id_token": _claims} + + _token = self.endpoint_context.idtoken.make(session_id=session_id) + assert _token + client_keyjar = KeyJar() + _jwks = self.endpoint_context.keyjar.export_jwks() + client_keyjar.import_jwks(_jwks, self.endpoint_context.issuer) + _jwt = JWT(key_jar=client_keyjar, iss="client_1") + res = _jwt.unpack(_token) + assert "address" in res + assert "email" in res + assert "nickname" in res + diff --git a/tests/test_05_sso_db.py b/tests/test_05_sso_db.py deleted file mode 100644 index 5c8479f..0000000 --- a/tests/test_05_sso_db.py +++ /dev/null @@ -1,135 +0,0 @@ -import shutil - -import pytest - -from oidcendpoint.sso_db import SSODb - -DB_CONF = { - "handler": "oidcmsg.storage.abfile.AbstractFileSystem", - "fdir": "db/sso", - "key_conv": "oidcmsg.storage.converter.QPKey", - "value_conv": "oidcmsg.storage.converter.JSON", -} - - -def rmtree(item): - try: - shutil.rmtree(item) - except FileNotFoundError: - pass - - -class TestSSODB(object): - @pytest.fixture(autouse=True) - def create_sdb(self): - rmtree("db/sso") - self.sso_db = SSODb(DB_CONF) - - def test_map_sid2uid(self): - self.sso_db.map_sid2uid("session id 1", "Lizz") - assert self.sso_db.get_sids_by_uid("Lizz") == ["session id 1"] - - def test_missing_map(self): - assert self.sso_db.get_sids_by_uid("Lizz") == [] - - def test_multiple_map_sid2uid(self): - self.sso_db.map_sid2uid("session id 1", "Lizz") - self.sso_db.map_sid2uid("session id 2", "Lizz") - assert set(self.sso_db.get_sids_by_uid("Lizz")) == { - "session id 1", - "session id 2", - } - - def test_map_unmap_sid2uid(self): - self.sso_db.map_sid2uid("session id 1", "Lizz") - self.sso_db.map_sid2uid("session id 2", "Lizz") - assert set(self.sso_db.get_sids_by_uid("Lizz")) == { - "session id 1", - "session id 2", - } - - self.sso_db.remove_sid2uid("session id 1", "Lizz") - assert self.sso_db.get_sids_by_uid("Lizz") == ["session id 2"] - - def test_get_uid_by_sid(self): - self.sso_db.map_sid2uid("session id 1", "Lizz") - self.sso_db.map_sid2uid("session id 2", "Lizz") - - assert self.sso_db.get_uid_by_sid("session id 1") == "Lizz" - assert self.sso_db.get_uid_by_sid("session id 2") == "Lizz" - - def test_remove_uid(self): - self.sso_db.map_sid2uid("session id 1", "Lizz") - self.sso_db.map_sid2uid("session id 2", "Diana") - - self.sso_db.remove_uid("Lizz") - assert self.sso_db.get_uid_by_sid("session id 1") is None - assert self.sso_db.get_sids_by_uid("Lizz") == [] - - def test_map_sid2sub(self): - self.sso_db.map_sid2sub("session id 1", "abcdefgh") - assert self.sso_db.get_sids_by_sub("abcdefgh") == ["session id 1"] - - def test_missing_sid2sub_map(self): - assert self.sso_db.get_sids_by_sub("abcdefgh") == [] - - def test_multiple_map_sid2sub(self): - self.sso_db.map_sid2sub("session id 1", "abcdefgh") - self.sso_db.map_sid2sub("session id 2", "abcdefgh") - assert set(self.sso_db.get_sids_by_sub("abcdefgh")) == { - "session id 1", - "session id 2", - } - - def test_map_unmap_sid2sub(self): - self.sso_db.map_sid2sub("session id 1", "abcdefgh") - self.sso_db.map_sid2sub("session id 2", "abcdefgh") - assert set(self.sso_db.get_sids_by_sub("abcdefgh")) == { - "session id 1", - "session id 2", - } - - self.sso_db.remove_sid2sub("session id 1", "abcdefgh") - assert self.sso_db.get_sids_by_sub("abcdefgh") == ["session id 2"] - - def test_get_sub_by_sid(self): - self.sso_db.map_sid2sub("session id 1", "abcdefgh") - self.sso_db.map_sid2sub("session id 2", "abcdefgh") - - assert set(self.sso_db.get_sids_by_sub("abcdefgh")) == { - "session id 1", - "session id 2", - } - - def test_remove_sub(self): - self.sso_db.map_sid2sub("session id 1", "abcdefgh") - self.sso_db.map_sid2sub("session id 2", "012346789") - - self.sso_db.remove_sub("abcdefgh") - assert self.sso_db.get_sub_by_sid("session id 1") is None - assert self.sso_db.get_sids_by_sub("abcdefgh") == [] - # have not touched the others - assert self.sso_db.get_sub_by_sid("session id 2") == "012346789" - assert self.sso_db.get_sids_by_sub("012346789") == ["session id 2"] - - def test_get_sub_by_uid_same_sub(self): - self.sso_db.map_sid2sub("session id 1", "abcdefgh") - self.sso_db.map_sid2sub("session id 2", "abcdefgh") - - self.sso_db.map_sid2uid("session id 1", "Lizz") - self.sso_db.map_sid2uid("session id 2", "Lizz") - - res = self.sso_db.get_subs_by_uid("Lizz") - - assert set(res) == {"abcdefgh"} - - def test_get_sub_by_uid_different_sub(self): - self.sso_db.map_sid2sub("session id 1", "abcdefgh") - self.sso_db.map_sid2sub("session id 2", "012346789") - - self.sso_db.map_sid2uid("session id 1", "Lizz") - self.sso_db.map_sid2uid("session id 2", "Lizz") - - res = self.sso_db.get_subs_by_uid("Lizz") - - assert set(res) == {"abcdefgh", "012346789"} diff --git a/tests/test_07_userinfo.py b/tests/test_07_userinfo.py index d3a0573..f94e859 100644 --- a/tests/test_07_userinfo.py +++ b/tests/test_07_userinfo.py @@ -1,25 +1,23 @@ import json import os -import pytest -from oidcmsg.message import Message from oidcmsg.oidc import OpenIDRequest -from oidcmsg.oidc import OpenIDSchema +import pytest from oidcendpoint.authn_event import create_authn_event from oidcendpoint.endpoint_context import EndpointContext +from oidcendpoint.grant import Grant from oidcendpoint.oidc.authorization import Authorization from oidcendpoint.oidc.provider_config import ProviderConfiguration from oidcendpoint.oidc.registration import Registration from oidcendpoint.scopes import SCOPE2CLAIMS from oidcendpoint.scopes import STANDARD_CLAIMS from oidcendpoint.scopes import convert_scopes2claims +from oidcendpoint.session_management import db_key +from oidcendpoint.session_management import unpack_db_key from oidcendpoint.user_authn.authn_context import INTERNETPROTOCOLPASSWORD from oidcendpoint.user_info import UserInfo -from oidcendpoint.userinfo import by_schema -from oidcendpoint.userinfo import claims_match -from oidcendpoint.userinfo import collect_user_info -from oidcendpoint.userinfo import update_claims +from oidcendpoint.userinfo import ClaimsInterface CLAIMS = { "userinfo": { @@ -136,27 +134,27 @@ def test_custom_scopes(): assert set( convert_scopes2claims(["email"], _available_claims, map=_scopes).keys() - ) == {"email", "email_verified",} + ) == {"email", "email_verified", } assert set( convert_scopes2claims(["address"], _available_claims, map=_scopes).keys() ) == {"address"} assert set( convert_scopes2claims(["phone"], _available_claims, map=_scopes).keys() - ) == {"phone_number", "phone_number_verified",} + ) == {"phone_number", "phone_number_verified", } assert set( convert_scopes2claims( ["research_and_scholarship"], _available_claims, map=_scopes ).keys() ) == { - "name", - "given_name", - "family_name", - "email", - "email_verified", - "sub", - "eduperson_scoped_affiliation", - } + "name", + "given_name", + "family_name", + "email", + "email_verified", + "sub", + "eduperson_scoped_affiliation", + } PROVIDER_INFO = { @@ -172,71 +170,6 @@ def test_custom_scopes(): ] } - -def test_update_claims_authn_req_id_token(): - _session_info = {"authn_req": OIDR} - claims = update_claims(_session_info, "id_token", PROVIDER_INFO) - assert set(claims.keys()) == {"auth_time", "acr"} - - -def test_update_claims_authn_req_userinfo(): - _session_info = {"authn_req": OIDR} - claims = update_claims(_session_info, "userinfo", PROVIDER_INFO) - assert set(claims.keys()) == { - "given_name", - "nickname", - "email", - "email_verified", - "picture", - "http://example.info/claims/groups", - } - - -def test_update_claims_authzreq_id_token(): - _session_info = {"authn_req": OIDR} - claims = update_claims(_session_info, "id_token", PROVIDER_INFO) - assert set(claims.keys()) == {"auth_time", "acr"} - - -def test_update_claims_authzreq_userinfo(): - _session_info = {"authn_req": OIDR} - claims = update_claims(_session_info, "userinfo", PROVIDER_INFO) - assert set(claims.keys()) == { - "given_name", - "nickname", - "email", - "email_verified", - "picture", - "http://example.info/claims/groups", - } - - -def test_clams_value(): - assert claims_match("red", CLAIMS["userinfo"]["http://example.info/claims/groups"]) - - -def test_clams_values(): - assert claims_match("urn:mace:incommon:iap:silver", CLAIMS["id_token"]["acr"]) - - -def test_clams_essential(): - assert claims_match(["foobar@example"], CLAIMS["userinfo"]["email"]) - - -def test_clams_none(): - assert claims_match(["angle"], CLAIMS["userinfo"]["nickname"]) - - -def test_by_schema(): - # There are no requested or optional claims defined for Message - assert by_schema(Message, sub="John") == {} - - assert by_schema(OpenIDSchema, sub="John", given_name="John", age=34) == { - "sub": "John", - "given_name": "John", - } - - KEYDEFS = [ {"type": "RSA", "key": "", "use": ["sig"]}, {"type": "EC", "crv": "P-256", "use": ["sig"]}, @@ -300,22 +233,42 @@ def create_endpoint_context(self): ) # Just has to be there self.endpoint_context.cdb["client1"] = {} + self.session_manager = self.endpoint_context.session_manager + self.claims_interface = ClaimsInterface(self.endpoint_context, "userinfo") + self.user_id = "diana" + + def _create_session(self, auth_req, sub_type="public", sector_identifier=''): + client_id = auth_req['client_id'] + ae = create_authn_event(self.user_id, self.session_manager.salt) + self.session_manager.create_session(ae, auth_req, self.user_id, client_id=client_id, + sub_type=sub_type, sector_identifier=sector_identifier) + return db_key(self.user_id, client_id) + + def _do_grant(self, auth_req): + client_id = auth_req['client_id'] + # The user consent module produces a Grant instance + grant = Grant(scope=auth_req['scope'], resources=[client_id]) + + # the grant is assigned to a session (user_id, client_id) + self.session_manager.set([self.user_id, client_id, grant.id], grant) + return db_key(self.user_id, client_id, grant.id) def test_collect_user_info(self): _req = OIDR.copy() _req["claims"] = CLAIMS_2 - _session_info = {"authn_req": _req} - session = _session_info.copy() - session["sub"] = "doe" - session["uid"] = "diana" - session["authn_event"] = create_authn_event("diana", "salt") + self._create_session(_req) + session_id = self._do_grant(_req) + _uid, _cid, _gid = unpack_db_key(session_id) + + _restriction = self.claims_interface.get_claims(client_id=_cid, user_id=_uid, + scopes=OIDR["scope"]) - res = collect_user_info(self.endpoint_context, session) + res = self.claims_interface.get_user_claims("diana", _restriction) assert res == { + 'eduperson_scoped_affiliation': ['staff@example.org'], "nickname": "Dina", - "sub": "doe", "email": "diana@example.org", "email_verified": False, } @@ -325,21 +278,17 @@ def test_collect_user_info_2(self): _req["scope"] = "openid email" del _req["claims"] - _session_info = {"authn_req": _req} - session = _session_info.copy() - session["sub"] = "doe" - session["uid"] = "diana" - session["authn_event"] = create_authn_event("diana", "salt") + self._create_session(_req) + session_id = self._do_grant(_req) + _uid, _cid, _gid = unpack_db_key(session_id) - self.endpoint_context.provider_info["scopes_supported"] = [ - "openid", - "email", - "offline_access", - ] - res = collect_user_info(self.endpoint_context, session) + self.claims_interface.add_claims_by_scope = True + _restriction = self.claims_interface.get_claims(client_id=_cid, user_id=_uid, + scopes=_req["scope"]) + + res = self.claims_interface.get_user_claims("diana", _restriction) assert res == { - "sub": "doe", "email": "diana@example.org", "email_verified": False, } @@ -349,25 +298,17 @@ def test_collect_user_info_scope_not_supported(self): _req["scope"] = "openid email address" del _req["claims"] - _session_info = {"authn_req": _req} - session = _session_info.copy() - session["sub"] = "doe" - session["uid"] = "diana" - session["authn_event"] = create_authn_event("diana", "salt") + self._create_session(_req) + session_id = self._do_grant(_req) + _uid, _cid, _gid = unpack_db_key(session_id) - # Scope address not supported - self.endpoint_context.provider_info["scopes_supported"] = [ - "openid", - "email", - "offline_access", - ] - res = collect_user_info(self.endpoint_context, session) + self.claims_interface.add_claims_by_scope = False + _restriction = self.claims_interface.get_claims(client_id=_cid, user_id=_uid, + scopes=_req["scope"]) - assert res == { - "sub": "doe", - "email": "diana@example.org", - "email_verified": False, - } + res = self.claims_interface.get_user_claims("diana", _restriction) + + assert res == {} class TestCollectUserInfoCustomScopes: @@ -443,22 +384,43 @@ def create_endpoint_context(self): } ) self.endpoint_context.cdb["client1"] = {} + self.endpoint_context.cdb["client1"] = {} + self.session_manager = self.endpoint_context.session_manager + self.claims_interface = ClaimsInterface(self.endpoint_context, "userinfo") + self.user_id = "diana" + + def _create_session(self, auth_req, sub_type="public", sector_identifier=''): + client_id = auth_req['client_id'] + ae = create_authn_event(self.user_id, self.session_manager.salt) + self.session_manager.create_session(ae, auth_req, self.user_id, client_id=client_id, + sub_type=sub_type, sector_identifier=sector_identifier) + return db_key(self.user_id, client_id) + + def _do_grant(self, auth_req): + client_id = auth_req['client_id'] + # The user consent module produces a Grant instance + grant = Grant(scope=auth_req['scope'], resources=[client_id]) + + # the grant is assigned to a session (user_id, client_id) + self.session_manager.set([self.user_id, client_id, grant.id], grant) + return db_key(self.user_id, client_id, grant.id) def test_collect_user_info(self): - _session_info = {"authn_req": OIDR} - session = _session_info.copy() - session["sub"] = "doe" - session["uid"] = "diana" - session["authn_event"] = create_authn_event("diana", "salt") + self._create_session(OIDR) + session_id = self._do_grant(OIDR) + _uid, _cid, _gid = unpack_db_key(session_id) + + self.claims_interface.add_claims_by_scope = False + _restriction = self.claims_interface.get_claims(client_id=_cid, user_id=_uid, + scopes=OIDR["scope"]) - res = collect_user_info(self.endpoint_context, session) + res = self.claims_interface.get_user_claims("diana", _restriction) assert res == { "email": "diana@example.org", "email_verified": False, "nickname": "Dina", "given_name": "Diana", - "sub": "doe", } def test_collect_user_info_2(self): @@ -466,40 +428,41 @@ def test_collect_user_info_2(self): _req["scope"] = "openid email" del _req["claims"] - _session_info = {"authn_req": _req} - session = _session_info.copy() - session["sub"] = "doe" - session["uid"] = "diana" - session["authn_event"] = create_authn_event("diana", "salt") + self._create_session(_req) + session_id = self._do_grant(_req) + _uid, _cid, _gid = unpack_db_key(session_id) - self.endpoint_context.provider_info["claims_supported"].remove("email") - self.endpoint_context.provider_info["claims_supported"].remove("email_verified") + self.claims_interface.add_claims_by_scope = True + _restriction = self.claims_interface.get_claims(client_id=_cid, user_id=_uid, + scopes=_req["scope"]) - res = collect_user_info(self.endpoint_context, session) + res = self.claims_interface.get_user_claims("diana", _restriction) - assert res == {"sub": "doe"} + assert res == {'email': 'diana@example.org', 'email_verified': False} def test_collect_user_info_scope_not_supported(self): _req = OIDR.copy() _req["scope"] = "openid email address" del _req["claims"] - _session_info = {"authn_req": _req} - session = _session_info.copy() - session["sub"] = "doe" - session["uid"] = "diana" - session["authn_event"] = create_authn_event("diana", "salt") + self._create_session(_req) + session_id = self._do_grant(_req) + _uid, _cid, _gid = unpack_db_key(session_id) - # Scope address not supported + # Address asked for but not supported self.endpoint_context.provider_info["scopes_supported"] = [ "openid", "email", "offline_access", ] - res = collect_user_info(self.endpoint_context, session) + + self.claims_interface.add_claims_by_scope = True + _restriction = self.claims_interface.get_claims(client_id=_cid, user_id=_uid, + scopes=_req["scope"]) + + res = self.claims_interface.get_user_claims("diana", _restriction) assert res == { - "sub": "doe", - "email": "diana@example.org", - "email_verified": False, + 'email': 'diana@example.org', + 'email_verified': False } diff --git a/tests/test_08_session.py b/tests/test_08_session.py deleted file mode 100644 index 15ca52c..0000000 --- a/tests/test_08_session.py +++ /dev/null @@ -1,520 +0,0 @@ -import os -import shutil -import time - -import pytest -from oidcmsg.oidc import AuthorizationRequest -from oidcmsg.oidc import OpenIDRequest -from oidcmsg.storage.init import storage_factory - -from oidcendpoint import rndstr -from oidcendpoint import token_handler -from oidcendpoint.authn_event import create_authn_event -from oidcendpoint.endpoint_context import EndpointContext -from oidcendpoint.oidc.authorization import Authorization -from oidcendpoint.oidc.provider_config import ProviderConfiguration -from oidcendpoint.session import SessionDB -from oidcendpoint.session import setup_session -from oidcendpoint.sso_db import SSODb -from oidcendpoint.token_handler import UnknownToken -from oidcendpoint.token_handler import WrongTokenType -from oidcendpoint.user_authn.authn_context import INTERNETPROTOCOLPASSWORD -from oidcendpoint.user_info import UserInfo - -__author__ = "rohe0002" - -AREQ = AuthorizationRequest( - response_type="code", - client_id="client1", - redirect_uri="http://example.com/authz", - scope=["openid"], - state="state000", -) - -AREQN = AuthorizationRequest( - response_type="code", - client_id="client1", - redirect_uri="http://example.com/authz", - scope=["openid"], - state="state000", - nonce="something", -) - -AREQO = AuthorizationRequest( - response_type="code", - client_id="client1", - redirect_uri="http://example.com/authz", - scope=["openid", "offline_access"], - prompt="consent", - state="state000", -) - -OIDR = OpenIDRequest( - response_type="code", - client_id="client1", - redirect_uri="http://example.com/authz", - scope=["openid"], - state="state000", -) - -BASEDIR = os.path.abspath(os.path.dirname(__file__)) - - -def full_path(local_file): - return os.path.join(BASEDIR, local_file) - - -SSO_DB_CONF = { - "handler": "oidcmsg.storage.abfile.AbstractFileSystem", - "fdir": "db/sso", - "key_conv": "oidcmsg.storage.converter.QPKey", - "value_conv": "oidcmsg.storage.converter.JSON", -} - -SESSION_DB_CONF = { - "handler": "oidcmsg.storage.abfile.AbstractFileSystem", - "fdir": "db/session", - "key_conv": "oidcmsg.storage.converter.QPKey", - "value_conv": "oidcmsg.storage.converter.JSON", -} - - -def rmtree(item): - try: - shutil.rmtree(item) - except FileNotFoundError: - pass - - -class TestSessionDB(object): - @pytest.fixture(autouse=True) - def create_sdb(self): - rmtree("db/sso") - rmtree("db/session") - - passwd = rndstr(24) - _th_args = { - "code": {"lifetime": 600, "password": passwd}, - "token": {"lifetime": 3600, "password": passwd}, - "refresh": {"lifetime": 86400, "password": passwd}, - } - - _token_handler = token_handler.factory(None, **_th_args) - userinfo = UserInfo(db_file=full_path("users.json")) - self.sdb = SessionDB( - storage_factory(SESSION_DB_CONF), - _token_handler, - SSODb(SSO_DB_CONF), - userinfo, - ) - - def test_create_authz_session(self): - ae = create_authn_event("uid", "salt") - sid = self.sdb.create_authz_session(ae, AREQ, client_id="client_id") - self.sdb.do_sub(sid, uid="user", client_salt="client_salt") - - info = self.sdb[sid] - assert info["client_id"] == "client_id" - assert set(info.keys()) == { - "sid", - "client_id", - "authn_req", - "authn_event", - "sub", - "oauth_state", - "code", - } - - def test_create_authz_session_without_nonce(self): - ae = create_authn_event("sub", "salt") - sid = self.sdb.create_authz_session(ae, AREQ, client_id="client_id") - info = self.sdb[sid] - assert info["oauth_state"] == "authz" - - def test_create_authz_session_with_nonce(self): - ae = create_authn_event("sub", "salt") - sid = self.sdb.create_authz_session(ae, AREQN, client_id="client_id") - info = self.sdb[sid] - authz_request = info["authn_req"] - assert authz_request["nonce"] == "something" - - def test_create_authz_session_with_id_token(self): - ae = create_authn_event("sub", "salt") - sid = self.sdb.create_authz_session( - ae, AREQ, client_id="client_id", id_token="id_token" - ) - - info = self.sdb[sid] - assert info["id_token"] == "id_token" - - def test_create_authz_session_with_oidreq(self): - ae = create_authn_event("sub", "salt") - sid = self.sdb.create_authz_session( - ae, AREQ, client_id="client_id", oidreq=OIDR - ) - info = self.sdb[sid] - assert "id_token" not in info - assert "oidreq" in info - - def test_create_authz_session_with_sector_id(self): - ae = create_authn_event("sub", "salt") - sid = self.sdb.create_authz_session( - ae, AREQ, client_id="client_id", oidreq=OIDR - ) - self.sdb.do_sub( - sid, "user1", "client_salt", "http://example.com/si.jwt", "pairwise" - ) - - info_1 = self.sdb[sid].copy() - assert "id_token" not in info_1 - assert "oidreq" in info_1 - assert info_1["sub"] != "sub" - - self.sdb.do_sub( - sid, "user2", "client_salt", "http://example.net/si.jwt", "pairwise" - ) - - info_2 = self.sdb[sid] - assert info_2["sub"] != "sub" - assert info_2["sub"] != info_1["sub"] - - def test_upgrade_to_token(self): - ae1 = create_authn_event("uid", "salt") - sid = self.sdb.create_authz_session(ae1, AREQ, client_id="client_id") - self.sdb[sid]["sub"] = "sub" - grant = self.sdb[sid]["code"] - _dict = self.sdb.upgrade_to_token(grant) - - print(_dict.keys()) - assert set(_dict.keys()) == { - "sid", - "authn_event", - "authn_req", - "access_token", - "token_type", - "client_id", - "oauth_state", - "expires_in", - "expires_at", - "code_is_used" - } - - # can't update again - # with pytest.raises(AccessCodeUsed): - print(self.sdb.upgrade_to_token(grant)) - - def test_upgrade_to_token_refresh(self): - ae1 = create_authn_event("sub", "salt") - sid = self.sdb.create_authz_session(ae1, AREQO, client_id="client_id") - self.sdb.do_sub(sid, "user", ae1["salt"]) - grant = self.sdb[sid]["code"] - # Issue an access token trading in the access grant code - _dict = self.sdb.upgrade_to_token(grant, issue_refresh=True) - - print(_dict.keys()) - assert set(_dict.keys()) == { - "sid", - "authn_event", - "authn_req", - "access_token", - "sub", - "token_type", - "client_id", - "oauth_state", - "refresh_token", - "expires_in", - "expires_at", - "code_is_used" - } - - # You can't refresh a token using the token itself - with pytest.raises(WrongTokenType): - self.sdb.refresh_token(_dict["access_token"]) - - def test_upgrade_to_token_with_id_token_and_oidreq(self): - ae2 = create_authn_event("another_user_id", "salt") - sid = self.sdb.create_authz_session(ae2, AREQ, client_id="client_id") - self.sdb[sid]["sub"] = "sub" - grant = self.sdb[sid]["code"] - - _dict = self.sdb.upgrade_to_token(grant, id_token="id_token", oidreq=OIDR) - print(_dict.keys()) - assert set(_dict.keys()) == { - "sid", - "authn_event", - "authn_req", - "oidreq", - "access_token", - "id_token", - "token_type", - "client_id", - "oauth_state", - "expires_in", - "expires_at", - "code_is_used" - } - - assert _dict["id_token"] == "id_token" - assert isinstance(_dict["oidreq"], OpenIDRequest) - - def test_refresh_token(self): - ae = create_authn_event("uid", "salt") - sid = self.sdb.create_authz_session(ae, AREQ, client_id="client_id") - self.sdb[sid]["sub"] = "sub" - grant = self.sdb[sid]["code"] - - dict1 = self.sdb.upgrade_to_token(grant, issue_refresh=True).copy() - rtoken = dict1["refresh_token"] - dict2 = self.sdb.refresh_token(rtoken, AREQ["client_id"]) - - assert dict1["access_token"] != dict2["access_token"] - - with pytest.raises(WrongTokenType): - self.sdb.refresh_token(dict2["access_token"], AREQ["client_id"]) - - def test_refresh_token_cleared_session(self): - ae = create_authn_event("uid", "salt") - sid = self.sdb.create_authz_session(ae, AREQ, client_id="client_id") - self.sdb[sid]["sub"] = "sub" - grant = self.sdb[sid]["code"] - dict1 = self.sdb.upgrade_to_token(grant, issue_refresh=True) - ac1 = dict1["access_token"] - - # Purge the SessionDB - self.sdb._db = {} - - rtoken = dict1["refresh_token"] - with pytest.raises(UnknownToken): - self.sdb.refresh_token(rtoken, AREQ["client_id"]) - - def test_is_valid(self): - ae1 = create_authn_event("uid", "salt") - sid = self.sdb.create_authz_session(ae1, AREQ, client_id="client_id") - self.sdb[sid]["sub"] = "sub" - grant = self.sdb[sid]["code"] - - assert self.sdb.is_valid("code", grant) - - sinfo = self.sdb.upgrade_to_token(grant, issue_refresh=True) - assert not self.sdb.is_valid("code", grant) - access_token = sinfo["access_token"] - assert self.sdb.is_valid("access_token", access_token) - - refresh_token = sinfo["refresh_token"] - sinfo = self.sdb.refresh_token(refresh_token, AREQ["client_id"]) - access_token2 = sinfo["access_token"] - assert self.sdb.is_valid("access_token", access_token2) - - # The old access code should be invalid - try: - self.sdb.is_valid("access_token", access_token) - except KeyError: - pass - - def test_valid_grant(self): - ae = create_authn_event("another:user", "salt") - sid = self.sdb.create_authz_session(ae, AREQ, client_id="client_id") - grant = self.sdb[sid]["code"] - - assert self.sdb.is_valid("code", grant) - - def test_revoke_token(self): - ae1 = create_authn_event("uid", "salt") - sid = self.sdb.create_authz_session(ae1, AREQ, client_id="client_id") - self.sdb[sid]["sub"] = "sub" - - grant = self.sdb[sid]["code"] - tokens = self.sdb.upgrade_to_token(grant, issue_refresh=True) - access_token = tokens["access_token"] - refresh_token = tokens["refresh_token"] - - assert self.sdb.is_valid("access_token", access_token) - - self.sdb.revoke_token(sid, "access_token") - assert not self.sdb.is_valid("access_token", access_token) - - sinfo = self.sdb.refresh_token(refresh_token, AREQ["client_id"]) - access_token = sinfo["access_token"] - assert self.sdb.is_valid("access_token", access_token) - - self.sdb.revoke_token(sid, "refresh_token") - assert not self.sdb.is_valid("refresh_token", refresh_token) - - ae2 = create_authn_event("sub", "salt") - sid = self.sdb.create_authz_session(ae2, AREQ, client_id="client_2") - - grant = self.sdb[sid]["code"] - self.sdb.revoke_token(sid, "code") - assert not self.sdb.is_valid("code", grant) - - def test_sub_to_authn_event(self): - ae = create_authn_event("sub", "salt", time_stamp=time.time()) - sid = self.sdb.create_authz_session(ae, AREQ, client_id="client_id") - sub = self.sdb.do_sub(sid, "user", "client_salt") - - # given the sub find out whether the authn event is still valid - sids = self.sdb.get_sids_by_sub(sub) - ae = self.sdb[sids[0]]["authn_event"] - assert ae.valid() - - def test_do_sub_deterministic(self): - ae = create_authn_event("tester", "random_value") - sid = self.sdb.create_authz_session(ae, AREQ, client_id="client_id") - self.sdb.do_sub(sid, "user", "other_random_value") - - info = self.sdb[sid] - assert ( - info["sub"] - == "d657bddf3d30970aa681663978ea84e26553ead03cb6fe8fcfa6523f2bcd0ad2" - ) - - self.sdb.do_sub( - sid, - "user", - "other_random_value", - sector_id="http://example.com", - subject_type="pairwise", - ) - info2 = self.sdb[sid] - assert ( - info2["sub"] - == "1442ceb13a822e802f85832ce93a8fda011e32a3363834dd1db3f9aa211065bd" - ) - - self.sdb.do_sub( - sid, - "user", - "another_random_value", - sector_id="http://other.example.com", - subject_type="pairwise", - ) - - info2 = self.sdb[sid] - assert ( - info2["sub"] - == "56e0a53d41086e7b22d78d52ee461655e9b090d50a0663d16136ea49a56c9bec" - ) - - def test_match_session(self): - ae1 = create_authn_event("uid", "salt") - sid = self.sdb.create_authz_session(ae1, AREQ, client_id="client_id") - self.sdb[sid]["sub"] = "sub" - self.sdb.sso_db.map_sid2uid(sid, "uid") - - res = self.sdb.match_session("uid", client_id="client_id") - assert res == sid - - def test_get_token(self): - ae1 = create_authn_event("uid", "salt") - sid = self.sdb.create_authz_session(ae1, AREQ, client_id="client_id") - self.sdb[sid]["sub"] = "sub" - self.sdb.sso_db.map_sid2uid(sid, "uid") - - grant = self.sdb.get_token(sid) - assert self.sdb.is_valid("code", grant) - assert self.sdb.handler.type(grant) == "A" - - -KEYDEFS = [ - {"type": "RSA", "key": "", "use": ["sig"]}, - {"type": "EC", "crv": "P-256", "use": ["sig"]}, -] - -conf = { - "issuer": "https://example.com/", - "password": "mycket hemligt", - "token_expires_in": 600, - "grant_expires_in": 300, - "refresh_token_expires_in": 86400, - "verify_ssl": False, - "capabilities": {}, - "keys": { - "uri_path": "static/jwks.json", - "key_defs": KEYDEFS, - "private_path": "own/jwks.json", - }, - "endpoint": { - "provider_config": { - "path": ".well-known/openid-configuration", - "class": ProviderConfiguration, - "kwargs": {}, - }, - "authorization_endpoint": { - "path": "authorization", - "class": Authorization, - "kwargs": {}, - }, - }, - "authentication": { - "anon": { - "acr": INTERNETPROTOCOLPASSWORD, - "class": "oidcendpoint.user_authn.user.NoAuthn", - "kwargs": {"user": "diana"}, - } - }, - "userinfo": {"class": UserInfo, "kwargs": {"db_file": full_path("users.json")}}, - "template_dir": "template", - "sso_db": SSO_DB_CONF, - "session_db": SESSION_DB_CONF, -} - - -def test_setup_session(): - endpoint_context = EndpointContext(conf) - uid = "_user_" - client_id = "EXTERNAL" - areq = None - acr = None - sid = setup_session(endpoint_context, areq, uid, client_id, acr, salt="salt") - assert sid - - -def test_setup_session_upgrade_to_token(): - endpoint_context = EndpointContext(conf) - uid = "_user_" - client_id = "EXTERNAL" - areq = None - acr = None - sid = setup_session(endpoint_context, areq, uid, client_id, acr, salt="salt") - assert sid - code = endpoint_context.sdb[sid]["code"] - assert code - - res = endpoint_context.sdb.upgrade_to_token(code) - assert "access_token" in res - - endpoint_context.sdb.revoke_uid("_user_") - assert endpoint_context.sdb.is_session_revoked(sid) - - -def make_sub_uid(uid, **kwargs): - return uid - - -def test_sub_minting_function(): - conf["sub_func"] = {"public": {"function": make_sub_uid}} - - endpoint_context = EndpointContext(conf) - uid = "_user_" - client_id = "EXTERNAL" - areq = None - acr = None - sid = setup_session(endpoint_context, areq, uid, client_id, acr, salt="salt") - assert endpoint_context.sdb[sid]["sub"] == uid - - -class SubMinter(object): - def __call__(self, *args, **kwargs): - return args[0] - - -def test_sub_minting_class(): - conf["sub_func"] = {"public": {"class": SubMinter}} - - endpoint_context = EndpointContext(conf) - uid = "_user_" - client_id = "EXTERNAL" - areq = None - acr = None - sid = setup_session(endpoint_context, areq, uid, client_id, acr, salt="salt") - assert endpoint_context.sdb[sid]["sub"] == uid diff --git a/tests/test_10_oidc_authz.py b/tests/test_10_oidc_authz.py.no similarity index 100% rename from tests/test_10_oidc_authz.py rename to tests/test_10_oidc_authz.py.no diff --git a/tests/test_24_oauth2_authorization_endpoint.py b/tests/test_24_oauth2_authorization_endpoint.py index 8d4cbcd..cbda2fc 100755 --- a/tests/test_24_oauth2_authorization_endpoint.py +++ b/tests/test_24_oauth2_authorization_endpoint.py @@ -1,12 +1,10 @@ +from http.cookies import SimpleCookie import io import json import os -from http.cookies import SimpleCookie from urllib.parse import parse_qs from urllib.parse import urlparse -import pytest -import yaml from cryptojwt import KeyJar from cryptojwt.jwt import utc_time_sans_frac from cryptojwt.utils import as_bytes @@ -17,7 +15,10 @@ from oidcmsg.oauth2 import AuthorizationRequest from oidcmsg.oauth2 import AuthorizationResponse from oidcmsg.time_util import in_a_while +import pytest +import yaml +from oidcendpoint.authn_event import create_authn_event from oidcendpoint.common.authorization import FORM_POST from oidcendpoint.common.authorization import get_uri from oidcendpoint.common.authorization import inputs @@ -29,12 +30,12 @@ from oidcendpoint.exception import NoSuchAuthentication from oidcendpoint.exception import RedirectURIError from oidcendpoint.exception import ToOld -from oidcendpoint.exception import UnAuthorizedClient from oidcendpoint.exception import UnAuthorizedClientScope from oidcendpoint.exception import UnknownClient +from oidcendpoint.grant import Grant from oidcendpoint.id_token import IDToken from oidcendpoint.oauth2.authorization import Authorization -from oidcendpoint.session import SessionInfo +from oidcendpoint.session_management import db_key from oidcendpoint.user_info import UserInfo KEYDEFS = [ @@ -199,6 +200,8 @@ def create_endpoint(self): endpoint_context.keyjar.export_jwks(True, ""), conf["issuer"] ) self.endpoint = endpoint_context.endpoint["authorization"] + self.session_manager = endpoint_context.session_manager + self.user_id = "diana" self.rp_keyjar = KeyJar() self.rp_keyjar.add_symmetric("client_1", "hemligtkodord1234567890") @@ -206,12 +209,52 @@ def create_endpoint(self): "client_1", "hemligtkodord1234567890" ) + def _create_session(self, auth_req, sub_type="public", sector_identifier=''): + client_id = auth_req['client_id'] + ae = create_authn_event(self.user_id, self.session_manager.salt) + self.session_manager.create_session(ae, auth_req, self.user_id, client_id=client_id, + sub_type=sub_type, sector_identifier=sector_identifier) + return db_key(self.user_id, client_id) + + def _do_grant(self, auth_req): + client_id = auth_req['client_id'] + # The user consent module produces a Grant instance + grant = Grant(scope=auth_req['scope'], resources=[client_id]) + + # the grant is assigned to a session (user_id, client_id) + self.session_manager.set([self.user_id, client_id, grant.id], grant) + return db_key(self.user_id, client_id, grant.id) + + # def _mint_code(self, grant, client_id): + # sid = db_key(self.user_id, client_id) + # # Constructing an authorization code is now done + # return grant.mint_token( + # 'authorization_code', + # value=self.session_manager.token_handler["code"](sid), + # expires_at=time_sans_frac() + 300 # 5 minutes from now + # ) + # + # def _mint_access_token(self, grant, client_id, token_ref=None): + # _csi = self.session_manager.get([self.user_id, client_id]) + # return grant.mint_token( + # 'access_token', + # value=self.session_manager.token_handler["access_token"]( + # db_key(self.user_id, client_id), + # client_id=client_id, + # aud=grant.resources, + # user_claims=None, + # scope=grant.scope, + # sub=_csi['sub'] + # ), + # expires_at=time_sans_frac() + 900, # 15 minutes from now + # based_on=token_ref # Means the token (tok) was used to mint this token + # ) + def test_init(self): assert self.endpoint def test_parse(self): _req = self.endpoint.parse_request(AUTH_REQ_DICT) - assert isinstance(_req, AuthorizationRequest) assert set(_req.keys()) == set(AUTH_REQ.keys()) @@ -393,24 +436,16 @@ def test_create_authn_response(self): scope="openid", ) - _ec = self.endpoint.endpoint_context - _ec.sdb["session_id"] = SessionInfo( - authn_req=request, - uid="diana", - sub="abcdefghijkl", - authn_event={ - "authn_info": "loa1", - "uid": "diana", - "authn_time": utc_time_sans_frac(), - }, - ) - _ec.cdb["client_id"] = { + self.endpoint.endpoint_context.cdb["client_id"] = { "client_id": "client_id", "redirect_uris": [("https://rp.example.com/cb", {})], "id_token_signed_response_alg": "ES256", } - resp = self.endpoint.create_authn_response(request, "session_id") + self._create_session(request) + session_id = self._do_grant(request) + + resp = self.endpoint.create_authn_response(request, session_id) assert isinstance(resp["response_args"], AuthorizationErrorResponse) def test_setup_auth(self): @@ -512,21 +547,13 @@ def test_setup_auth_user(self): "redirect_uris": [("https://rp.example.com/cb", {})], "id_token_signed_response_alg": "RS256", } - _ec = self.endpoint.endpoint_context - _ec.sdb["session_id"] = SessionInfo( - authn_req=request, - uid="diana", - sub="abcdefghijkl", - authn_event={ - "authn_info": "loa1", - "uid": "diana", - "authn_time": utc_time_sans_frac(), - }, - ) - item = _ec.authn_broker.db["anon"] + pre_sid = self._create_session(request) + session_id = self._do_grant(request) + + item = self.endpoint.endpoint_context.authn_broker.db["anon"] item["method"].user = b64e( - as_bytes(json.dumps({"uid": "krall", "sid": "session_id"})) + as_bytes(json.dumps({"uid": "krall", "sid": session_id})) ) res = self.endpoint.setup_auth(request, redirect_uri, cinfo, None) @@ -548,22 +575,18 @@ def test_setup_auth_session_revoked(self): "redirect_uris": [("https://rp.example.com/cb", {})], "id_token_signed_response_alg": "RS256", } - _ec = self.endpoint.endpoint_context - _ec.sdb["session_id"] = SessionInfo( - authn_req=request, - uid="diana", - sub="abcdefghijkl", - authn_event={ - "authn_info": "loa1", - "uid": "diana", - "authn_time": utc_time_sans_frac(), - }, - revoked=True, - ) + pre_sid = self._create_session(request) + session_id = self._do_grant(request) + + _mngr = self.endpoint.endpoint_context.session_manager + _csi = _mngr[session_id] + _csi.revoked = True + + _ec = self.endpoint.endpoint_context item = _ec.authn_broker.db["anon"] item["method"].user = b64e( - as_bytes(json.dumps({"uid": "krall", "sid": "session_id"})) + as_bytes(json.dumps({"uid": "krall", "sid": session_id})) ) res = self.endpoint.setup_auth(request, redirect_uri, cinfo, None) diff --git a/tests/test_24_oauth2_authorization_endpoint_jar.py b/tests/test_24_oauth2_authorization_endpoint_jar.py index efc1854..24b7644 100755 --- a/tests/test_24_oauth2_authorization_endpoint_jar.py +++ b/tests/test_24_oauth2_authorization_endpoint_jar.py @@ -182,6 +182,8 @@ def create_endpoint(self): endpoint_context.keyjar.export_jwks(True, ""), conf["issuer"] ) self.endpoint = endpoint_context.endpoint["authorization"] + self.session_manager = endpoint_context.session_manager + self.user_id = "diana" self.rp_keyjar = KeyJar() self.rp_keyjar.add_symmetric("client_1", "hemligtkodord1234567890") diff --git a/tests/test_24_oidc_authorization_endpoint.py b/tests/test_24_oidc_authorization_endpoint.py index 2130b5d..214b74a 100755 --- a/tests/test_24_oidc_authorization_endpoint.py +++ b/tests/test_24_oidc_authorization_endpoint.py @@ -4,14 +4,17 @@ from urllib.parse import parse_qs from urllib.parse import urlparse -import pytest -import responses -import yaml from cryptojwt import JWT from cryptojwt import KeyJar from cryptojwt.jwt import utc_time_sans_frac from cryptojwt.utils import as_bytes from cryptojwt.utils import b64e + +from oidcendpoint.grant import Grant +from oidcendpoint.session_management import db_key + +from oidcendpoint.authn_event import create_authn_event +from oidcendpoint.oidc.authorization import Authorization from oidcmsg.exception import ParameterError from oidcmsg.exception import URIError from oidcmsg.oauth2 import AuthorizationErrorResponse @@ -20,7 +23,11 @@ from oidcmsg.oidc import AuthorizationResponse from oidcmsg.oidc import verified_claim_name from oidcmsg.oidc import verify_id_token +import pytest +import responses +import yaml +from oidcendpoint.authz import AuthzHandling from oidcendpoint.common.authorization import FORM_POST from oidcendpoint.common.authorization import join_query from oidcendpoint.common.authorization import verify_uri @@ -32,12 +39,10 @@ from oidcendpoint.exception import NoSuchAuthentication from oidcendpoint.exception import RedirectURIError from oidcendpoint.exception import ToOld -from oidcendpoint.exception import UnAuthorizedClient from oidcendpoint.exception import UnknownClient from oidcendpoint.id_token import IDToken from oidcendpoint.login_hint import LoginHint2Acrs from oidcendpoint.oidc import userinfo -from oidcendpoint.oidc.authorization import Authorization from oidcendpoint.oidc.authorization import acr_claims from oidcendpoint.oidc.authorization import get_uri from oidcendpoint.oidc.authorization import inputs @@ -45,7 +50,6 @@ from oidcendpoint.oidc.provider_config import ProviderConfiguration from oidcendpoint.oidc.registration import Registration from oidcendpoint.oidc.token import AccessToken -from oidcendpoint.session import SessionInfo from oidcendpoint.user_authn.authn_context import INTERNETPROTOCOLPASSWORD from oidcendpoint.user_authn.authn_context import UNSPECIFIED from oidcendpoint.user_authn.authn_context import init_method @@ -110,7 +114,6 @@ def full_path(local_file): USERINFO_db = json.loads(open(full_path("users.json")).read()) - client_yaml = """ oidc_clients: client_1: @@ -222,6 +225,7 @@ def create_endpoint(self): }, "userinfo": {"class": UserInfo, "kwargs": {"db": USERINFO_db}}, "template_dir": "template", + "authz": {"class": AuthzHandling, "kwargs": {}}, "cookie_dealer": { "class": CookieDealer, "kwargs": { @@ -240,12 +244,15 @@ def create_endpoint(self): }, } endpoint_context = EndpointContext(conf) + _clients = yaml.safe_load(io.StringIO(client_yaml)) endpoint_context.cdb = _clients["oidc_clients"] endpoint_context.keyjar.import_jwks( endpoint_context.keyjar.export_jwks(True, ""), conf["issuer"] ) self.endpoint = endpoint_context.endpoint["authorization"] + self.session_manager = endpoint_context.session_manager + self.user_id = "diana" self.rp_keyjar = KeyJar() self.rp_keyjar.add_symmetric("client_1", "hemligtkodord1234567890") @@ -256,6 +263,22 @@ def create_endpoint(self): def test_init(self): assert self.endpoint + def _create_session(self, auth_req, sub_type="public", sector_identifier=''): + client_id = auth_req['client_id'] + ae = create_authn_event(self.user_id, self.session_manager.salt) + self.session_manager.create_session(ae, auth_req, self.user_id, client_id=client_id, + sub_type=sub_type, sector_identifier=sector_identifier) + return db_key(self.user_id, client_id) + + def _do_grant(self, auth_req): + client_id = auth_req['client_id'] + # The user consent module produces a Grant instance + grant = Grant(scope=auth_req['scope'], resources=[client_id]) + + # the grant is assigned to a session (user_id, client_id) + self.session_manager.set([self.user_id, client_id, grant.id], grant) + return db_key(self.user_id, client_id, grant.id) + def test_parse(self): _req = self.endpoint.parse_request(AUTH_REQ_DICT) diff --git a/tests/test_24_oidc_authorization_endpoint.py.no b/tests/test_24_oidc_authorization_endpoint.py.no new file mode 100755 index 0000000..2130b5d --- /dev/null +++ b/tests/test_24_oidc_authorization_endpoint.py.no @@ -0,0 +1,933 @@ +import io +import json +import os +from urllib.parse import parse_qs +from urllib.parse import urlparse + +import pytest +import responses +import yaml +from cryptojwt import JWT +from cryptojwt import KeyJar +from cryptojwt.jwt import utc_time_sans_frac +from cryptojwt.utils import as_bytes +from cryptojwt.utils import b64e +from oidcmsg.exception import ParameterError +from oidcmsg.exception import URIError +from oidcmsg.oauth2 import AuthorizationErrorResponse +from oidcmsg.oauth2 import ResponseMessage +from oidcmsg.oidc import AuthorizationRequest +from oidcmsg.oidc import AuthorizationResponse +from oidcmsg.oidc import verified_claim_name +from oidcmsg.oidc import verify_id_token + +from oidcendpoint.common.authorization import FORM_POST +from oidcendpoint.common.authorization import join_query +from oidcendpoint.common.authorization import verify_uri +from oidcendpoint.cookie import CookieDealer +from oidcendpoint.cookie import cookie_value +from oidcendpoint.cookie import new_cookie +from oidcendpoint.endpoint_context import EndpointContext +from oidcendpoint.exception import InvalidRequest +from oidcendpoint.exception import NoSuchAuthentication +from oidcendpoint.exception import RedirectURIError +from oidcendpoint.exception import ToOld +from oidcendpoint.exception import UnAuthorizedClient +from oidcendpoint.exception import UnknownClient +from oidcendpoint.id_token import IDToken +from oidcendpoint.login_hint import LoginHint2Acrs +from oidcendpoint.oidc import userinfo +from oidcendpoint.oidc.authorization import Authorization +from oidcendpoint.oidc.authorization import acr_claims +from oidcendpoint.oidc.authorization import get_uri +from oidcendpoint.oidc.authorization import inputs +from oidcendpoint.oidc.authorization import re_authenticate +from oidcendpoint.oidc.provider_config import ProviderConfiguration +from oidcendpoint.oidc.registration import Registration +from oidcendpoint.oidc.token import AccessToken +from oidcendpoint.session import SessionInfo +from oidcendpoint.user_authn.authn_context import INTERNETPROTOCOLPASSWORD +from oidcendpoint.user_authn.authn_context import UNSPECIFIED +from oidcendpoint.user_authn.authn_context import init_method +from oidcendpoint.user_authn.user import NoAuthn +from oidcendpoint.user_authn.user import UserAuthnMethod +from oidcendpoint.user_authn.user import UserPassJinja2 +from oidcendpoint.user_info import UserInfo +from oidcendpoint.util import JSONDictDB + +KEYDEFS = [ + {"type": "RSA", "key": "", "use": ["sig"]} + # {"type": "EC", "crv": "P-256", "use": ["sig"]} +] + +RESPONSE_TYPES_SUPPORTED = [ + ["code"], + ["token"], + ["id_token"], + ["code", "token"], + ["code", "id_token"], + ["id_token", "token"], + ["code", "token", "id_token"], + ["none"], +] + +CAPABILITIES = { + "subject_types_supported": ["public", "pairwise"], + "grant_types_supported": [ + "authorization_code", + "implicit", + "urn:ietf:params:oauth:grant-type:jwt-bearer", + "refresh_token", + ], +} + +CLAIMS = {"id_token": {"given_name": {"essential": True}, "nickname": None}} + +AUTH_REQ = AuthorizationRequest( + client_id="client_1", + redirect_uri="https://example.com/cb", + scope=["openid"], + state="STATE", + response_type="code", +) + +AUTH_REQ_DICT = AUTH_REQ.to_dict() + +AUTH_REQ_2 = AuthorizationRequest( + client_id="client3", + redirect_uri="https://127.0.0.1:8090/authz_cb/bobcat", + scope=["openid"], + state="STATE2", + response_type="code", +) + +BASEDIR = os.path.abspath(os.path.dirname(__file__)) + + +def full_path(local_file): + return os.path.join(BASEDIR, local_file) + + +USERINFO_db = json.loads(open(full_path("users.json")).read()) + + +client_yaml = """ +oidc_clients: + client_1: + "client_secret": 'hemligtkodord' + "redirect_uris": + - ['https://example.com/cb', ''] + "client_salt": "salted" + 'token_endpoint_auth_method': 'client_secret_post' + 'response_types': + - 'code' + - 'token' + - 'code id_token' + - 'id_token' + - 'code id_token token' + client2: + client_secret: "spraket_sr.se" + redirect_uris: + - ['https://app1.example.net/foo', ''] + - ['https://app2.example.net/bar', ''] + response_types: + - code + client3: + client_secret: '2222222222222222222222222222222222222222' + redirect_uris: + - ['https://127.0.0.1:8090/authz_cb/bobcat', ''] + post_logout_redirect_uris: + - ['https://openidconnect.net/', ''] + response_types: + - code +""" + + +class TestEndpoint(object): + @pytest.fixture(autouse=True) + def create_endpoint(self): + conf = { + "issuer": "https://example.com/", + "password": "mycket hemligt zebra", + "token_expires_in": 600, + "grant_expires_in": 300, + "refresh_token_expires_in": 86400, + "verify_ssl": False, + "capabilities": CAPABILITIES, + "keys": {"uri_path": "static/jwks.json", "key_defs": KEYDEFS}, + "id_token": { + "class": IDToken, + "kwargs": { + "available_claims": { + "email": {"essential": True}, + "email_verified": {"essential": True}, + } + }, + }, + "endpoint": { + "provider_config": { + "path": "{}/.well-known/openid-configuration", + "class": ProviderConfiguration, + "kwargs": {}, + }, + "registration": { + "path": "{}/registration", + "class": Registration, + "kwargs": {}, + }, + "authorization": { + "path": "{}/authorization", + "class": Authorization, + "kwargs": { + "response_types_supported": [ + " ".join(x) for x in RESPONSE_TYPES_SUPPORTED + ], + "response_modes_supported": ["query", "fragment", "form_post"], + "claims_parameter_supported": True, + "request_parameter_supported": True, + "request_uri_parameter_supported": True, + }, + }, + "token": { + "path": "token", + "class": AccessToken, + "kwargs": { + "client_authn_method": [ + "client_secret_post", + "client_secret_basic", + "client_secret_jwt", + "private_key_jwt", + ] + }, + }, + "userinfo": { + "path": "userinfo", + "class": userinfo.UserInfo, + "kwargs": { + "db_file": "users.json", + "claim_types_supported": [ + "normal", + "aggregated", + "distributed", + ], + }, + }, + }, + "authentication": { + "anon": { + "acr": "http://www.swamid.se/policy/assurance/al1", + "class": "oidcendpoint.user_authn.user.NoAuthn", + "kwargs": {"user": "diana"}, + } + }, + "userinfo": {"class": UserInfo, "kwargs": {"db": USERINFO_db}}, + "template_dir": "template", + "cookie_dealer": { + "class": CookieDealer, + "kwargs": { + "sign_key": "ghsNKDDLshZTPn974nOsIGhedULrsqnsGoBFBLwUKuJhE2ch", + "default_values": { + "name": "oidcop", + "domain": "127.0.0.1", + "path": "/", + "max_age": 3600, + }, + }, + }, + "login_hint2acrs": { + "class": LoginHint2Acrs, + "kwargs": {"scheme_map": {"email": [INTERNETPROTOCOLPASSWORD]}}, + }, + } + endpoint_context = EndpointContext(conf) + _clients = yaml.safe_load(io.StringIO(client_yaml)) + endpoint_context.cdb = _clients["oidc_clients"] + endpoint_context.keyjar.import_jwks( + endpoint_context.keyjar.export_jwks(True, ""), conf["issuer"] + ) + self.endpoint = endpoint_context.endpoint["authorization"] + + self.rp_keyjar = KeyJar() + self.rp_keyjar.add_symmetric("client_1", "hemligtkodord1234567890") + self.endpoint.endpoint_context.keyjar.add_symmetric( + "client_1", "hemligtkodord1234567890" + ) + + def test_init(self): + assert self.endpoint + + def test_parse(self): + _req = self.endpoint.parse_request(AUTH_REQ_DICT) + + assert isinstance(_req, AuthorizationRequest) + assert set(_req.keys()) == set(AUTH_REQ.keys()) + + def test_process_request(self): + _pr_resp = self.endpoint.parse_request(AUTH_REQ_DICT) + _resp = self.endpoint.process_request(_pr_resp) + assert set(_resp.keys()) == { + "response_args", + "fragment_enc", + "return_uri", + "cookie", + } + + def test_do_response_code(self): + _pr_resp = self.endpoint.parse_request(AUTH_REQ_DICT) + _resp = self.endpoint.process_request(_pr_resp) + msg = self.endpoint.do_response(**_resp) + assert isinstance(msg, dict) + _msg = parse_qs(msg["response"]) + assert _msg + part = urlparse(msg["response"]) + assert part.fragment == "" + assert part.query + _query = parse_qs(part.query) + assert _query + assert "code" in _query + + def test_do_response_id_token_no_nonce(self): + _orig_req = AUTH_REQ_DICT.copy() + _orig_req["response_type"] = "id_token" + _pr_resp = self.endpoint.parse_request(_orig_req) + # Missing nonce + assert isinstance(_pr_resp, ResponseMessage) + + def test_do_response_id_token(self): + _orig_req = AUTH_REQ_DICT.copy() + _orig_req["response_type"] = "id_token" + _orig_req["nonce"] = "rnd_nonce" + _pr_resp = self.endpoint.parse_request(_orig_req) + _resp = self.endpoint.process_request(_pr_resp) + msg = self.endpoint.do_response(**_resp) + assert isinstance(msg, dict) + part = urlparse(msg["response"]) + assert part.query == "" + assert part.fragment + _frag_msg = parse_qs(part.fragment) + assert _frag_msg + assert "id_token" in _frag_msg + assert "code" not in _frag_msg + assert "token" not in _frag_msg + + def test_do_response_id_token_token(self): + _orig_req = AUTH_REQ_DICT.copy() + _orig_req["response_type"] = "id_token token" + _orig_req["nonce"] = "rnd_nonce" + _pr_resp = self.endpoint.parse_request(_orig_req) + assert isinstance(_pr_resp, AuthorizationErrorResponse) + assert _pr_resp["error"] == "invalid_request" + + def test_do_response_code_token(self): + _orig_req = AUTH_REQ_DICT.copy() + _orig_req["response_type"] = "code token" + _pr_resp = self.endpoint.parse_request(_orig_req) + assert isinstance(_pr_resp, AuthorizationErrorResponse) + assert _pr_resp["error"] == "invalid_request" + + def test_do_response_code_id_token(self): + _orig_req = AUTH_REQ_DICT.copy() + _orig_req["response_type"] = "code id_token" + _orig_req["nonce"] = "rnd_nonce" + _pr_resp = self.endpoint.parse_request(_orig_req) + _resp = self.endpoint.process_request(_pr_resp) + msg = self.endpoint.do_response(**_resp) + assert isinstance(msg, dict) + part = urlparse(msg["response"]) + assert part.query == "" + assert part.fragment + _frag_msg = parse_qs(part.fragment) + assert _frag_msg + assert "id_token" in _frag_msg + assert "code" in _frag_msg + assert "access_token" not in _frag_msg + + def test_do_response_code_id_token_token(self): + _orig_req = AUTH_REQ_DICT.copy() + _orig_req["response_type"] = "code id_token token" + _orig_req["nonce"] = "rnd_nonce" + _pr_resp = self.endpoint.parse_request(_orig_req) + _resp = self.endpoint.process_request(_pr_resp) + msg = self.endpoint.do_response(**_resp) + assert isinstance(msg, dict) + part = urlparse(msg["response"]) + assert part.query == "" + assert part.fragment + _frag_msg = parse_qs(part.fragment) + assert _frag_msg + assert "id_token" in _frag_msg + assert "code" in _frag_msg + assert "access_token" in _frag_msg + + def test_id_token_claims(self): + _req = AUTH_REQ_DICT.copy() + _req["claims"] = CLAIMS + _req["response_type"] = "code id_token token" + _req["nonce"] = "rnd_nonce" + _pr_resp = self.endpoint.parse_request(_req) + _resp = self.endpoint.process_request(_pr_resp) + idt = verify_id_token( + _resp["response_args"], keyjar=self.endpoint.endpoint_context.keyjar + ) + assert idt + # from claims + assert "given_name" in _resp["response_args"]["__verified_id_token"] + # from config + assert "email" in _resp["response_args"]["__verified_id_token"] + + def test_re_authenticate(self): + request = {"prompt": "login"} + authn = UserAuthnMethod(self.endpoint.endpoint_context) + assert re_authenticate(request, authn) + + def test_id_token_acr(self): + _req = AUTH_REQ_DICT.copy() + _req["claims"] = { + "id_token": {"acr": {"value": "http://www.swamid.se/policy/assurance/al1"}} + } + _req["response_type"] = "code id_token token" + _req["nonce"] = "rnd_nonce" + _pr_resp = self.endpoint.parse_request(_req) + _resp = self.endpoint.process_request(_pr_resp) + res = verify_id_token( + _resp["response_args"], keyjar=self.endpoint.endpoint_context.keyjar + ) + assert res + res = _resp["response_args"][verified_claim_name("id_token")] + assert res["acr"] == "http://www.swamid.se/policy/assurance/al1" + + def test_verify_uri_unknown_client(self): + request = {"redirect_uri": "https://rp.example.com/cb"} + with pytest.raises(UnknownClient): + verify_uri(self.endpoint.endpoint_context, request, "redirect_uri") + + def test_verify_uri_fragment(self): + _ec = self.endpoint.endpoint_context + _ec.cdb["client_id"] = {"redirect_uri": ["https://rp.example.com/auth_cb"]} + request = {"redirect_uri": "https://rp.example.com/cb#foobar"} + with pytest.raises(URIError): + verify_uri(_ec, request, "redirect_uri", "client_id") + + def test_verify_uri_noregistered(self): + _ec = self.endpoint.endpoint_context + request = {"redirect_uri": "https://rp.example.com/cb"} + + with pytest.raises(KeyError): + verify_uri(_ec, request, "redirect_uri", "client_id") + + def test_verify_uri_unregistered(self): + _ec = self.endpoint.endpoint_context + _ec.cdb["client_id"] = { + "redirect_uris": [("https://rp.example.com/auth_cb", {})] + } + + request = {"redirect_uri": "https://rp.example.com/cb"} + + with pytest.raises(RedirectURIError): + verify_uri(_ec, request, "redirect_uri", "client_id") + + def test_verify_uri_qp_match(self): + _ec = self.endpoint.endpoint_context + _ec.cdb["client_id"] = { + "redirect_uris": [("https://rp.example.com/cb", {"foo": ["bar"]})] + } + + request = {"redirect_uri": "https://rp.example.com/cb?foo=bar"} + + verify_uri(_ec, request, "redirect_uri", "client_id") + + def test_verify_uri_qp_mismatch(self): + _ec = self.endpoint.endpoint_context + _ec.cdb["client_id"] = { + "redirect_uris": [("https://rp.example.com/cb", {"foo": ["bar"]})] + } + + request = {"redirect_uri": "https://rp.example.com/cb?foo=bob"} + with pytest.raises(ValueError): + verify_uri(_ec, request, "redirect_uri", "client_id") + + request = {"redirect_uri": "https://rp.example.com/cb?foo=bar&foo=kex"} + with pytest.raises(ValueError): + verify_uri(_ec, request, "redirect_uri", "client_id") + + request = {"redirect_uri": "https://rp.example.com/cb"} + with pytest.raises(ValueError): + verify_uri(_ec, request, "redirect_uri", "client_id") + + request = {"redirect_uri": "https://rp.example.com/cb?foo=bar&level=low"} + with pytest.raises(ValueError): + verify_uri(_ec, request, "redirect_uri", "client_id") + + def test_verify_uri_qp_missing(self): + _ec = self.endpoint.endpoint_context + _ec.cdb["client_id"] = { + "redirect_uris": [ + ("https://rp.example.com/cb", {"foo": ["bar"], "state": ["low"]}) + ] + } + + request = {"redirect_uri": "https://rp.example.com/cb?foo=bar"} + with pytest.raises(ValueError): + verify_uri(_ec, request, "redirect_uri", "client_id") + + def test_verify_uri_qp_missing_val(self): + _ec = self.endpoint.endpoint_context + _ec.cdb["client_id"] = { + "redirect_uris": [("https://rp.example.com/cb", {"foo": ["bar", "low"]})] + } + + request = {"redirect_uri": "https://rp.example.com/cb?foo=bar"} + with pytest.raises(ValueError): + verify_uri(_ec, request, "redirect_uri", "client_id") + + def test_verify_uri_no_registered_qp(self): + _ec = self.endpoint.endpoint_context + _ec.cdb["client_id"] = {"redirect_uris": [("https://rp.example.com/cb", {})]} + + request = {"redirect_uri": "https://rp.example.com/cb?foo=bob"} + with pytest.raises(ValueError): + verify_uri(_ec, request, "redirect_uri", "client_id") + + def test_get_uri(self): + _ec = self.endpoint.endpoint_context + _ec.cdb["client_id"] = {"redirect_uris": [("https://rp.example.com/cb", {})]} + + request = { + "redirect_uri": "https://rp.example.com/cb", + "client_id": "client_id", + } + + assert get_uri(_ec, request, "redirect_uri") == "https://rp.example.com/cb" + + def test_get_uri_no_redirect_uri(self): + _ec = self.endpoint.endpoint_context + _ec.cdb["client_id"] = {"redirect_uris": [("https://rp.example.com/cb", {})]} + + request = {"client_id": "client_id"} + + assert get_uri(_ec, request, "redirect_uri") == "https://rp.example.com/cb" + + def test_get_uri_no_registered(self): + _ec = self.endpoint.endpoint_context + _ec.cdb["client_id"] = {"redirect_uris": [("https://rp.example.com/cb", {})]} + + request = {"client_id": "client_id"} + + with pytest.raises(ParameterError): + get_uri(_ec, request, "post_logout_redirect_uri") + + def test_get_uri_more_then_one_registered(self): + _ec = self.endpoint.endpoint_context + _ec.cdb["client_id"] = { + "redirect_uris": [ + ("https://rp.example.com/cb", {}), + ("https://rp.example.org/authz_cb", {"foo": "bar"}), + ] + } + + request = {"client_id": "client_id"} + + with pytest.raises(ParameterError): + get_uri(_ec, request, "redirect_uri") + + def test_create_authn_response(self): + request = AuthorizationRequest( + client_id="client_id", + redirect_uri="https://rp.example.com/cb", + response_type=["id_token"], + state="state", + nonce="nonce", + scope="openid", + ) + + _ec = self.endpoint.endpoint_context + _ec.sdb["session_id"] = SessionInfo( + authn_req=request, + uid="diana", + sub="abcdefghijkl", + authn_event={ + "authn_info": "loa1", + "uid": "diana", + "authn_time": utc_time_sans_frac(), + }, + ) + _ec.cdb["client_id"] = { + "client_id": "client_id", + "redirect_uris": [("https://rp.example.com/cb", {})], + "id_token_signed_response_alg": "ES256", + } + + resp = self.endpoint.create_authn_response(request, "session_id") + assert isinstance(resp["response_args"], AuthorizationErrorResponse) + + def test_setup_auth(self): + request = AuthorizationRequest( + client_id="client_id", + redirect_uri="https://rp.example.com/cb", + response_type=["id_token"], + state="state", + nonce="nonce", + scope="openid", + ) + redirect_uri = request["redirect_uri"] + cinfo = { + "client_id": "client_id", + "redirect_uris": [("https://rp.example.com/cb", {})], + "id_token_signed_response_alg": "RS256", + } + + kaka = self.endpoint.endpoint_context.cookie_dealer.create_cookie( + "value", "sso" + ) + + res = self.endpoint.setup_auth(request, redirect_uri, cinfo, kaka) + assert set(res.keys()) == {"authn_event", "identity", "user"} + + def test_setup_auth_error(self): + request = AuthorizationRequest( + client_id="client_id", + redirect_uri="https://rp.example.com/cb", + response_type=["id_token"], + state="state", + nonce="nonce", + scope="openid", + ) + redirect_uri = request["redirect_uri"] + cinfo = { + "client_id": "client_id", + "redirect_uris": [("https://rp.example.com/cb", {})], + "id_token_signed_response_alg": "RS256", + } + + item = self.endpoint.endpoint_context.authn_broker.db["anon"] + item["method"].fail = NoSuchAuthentication + + res = self.endpoint.setup_auth(request, redirect_uri, cinfo, None) + assert set(res.keys()) == {"function", "args"} + + item["method"].fail = ToOld + + res = self.endpoint.setup_auth(request, redirect_uri, cinfo, None) + assert set(res.keys()) == {"function", "args"} + + item["method"].file = "" + + def test_setup_auth_user(self): + request = AuthorizationRequest( + client_id="client_id", + redirect_uri="https://rp.example.com/cb", + response_type=["id_token"], + state="state", + nonce="nonce", + scope="openid", + ) + redirect_uri = request["redirect_uri"] + cinfo = { + "client_id": "client_id", + "redirect_uris": [("https://rp.example.com/cb", {})], + "id_token_signed_response_alg": "RS256", + } + _ec = self.endpoint.endpoint_context + _ec.sdb["session_id"] = SessionInfo( + authn_req=request, + uid="diana", + sub="abcdefghijkl", + authn_event={ + "authn_info": "loa1", + "uid": "diana", + "authn_time": utc_time_sans_frac(), + }, + ) + + item = _ec.authn_broker.db["anon"] + item["method"].user = b64e( + as_bytes(json.dumps({"uid": "krall", "sid": "session_id"})) + ) + + res = self.endpoint.setup_auth(request, redirect_uri, cinfo, None) + assert set(res.keys()) == {"authn_event", "identity", "user"} + assert res["identity"]["uid"] == "krall" + + def test_setup_auth_session_revoked(self): + request = AuthorizationRequest( + client_id="client_id", + redirect_uri="https://rp.example.com/cb", + response_type=["id_token"], + state="state", + nonce="nonce", + scope="openid", + ) + redirect_uri = request["redirect_uri"] + cinfo = { + "client_id": "client_id", + "redirect_uris": [("https://rp.example.com/cb", {})], + "id_token_signed_response_alg": "RS256", + } + _ec = self.endpoint.endpoint_context + _ec.sdb["session_id"] = SessionInfo( + authn_req=request, + uid="diana", + sub="abcdefghijkl", + authn_event={ + "authn_info": "loa1", + "uid": "diana", + "authn_time": utc_time_sans_frac(), + }, + revoked=True, + ) + + item = _ec.authn_broker.db["anon"] + item["method"].user = b64e( + as_bytes(json.dumps({"uid": "krall", "sid": "session_id"})) + ) + + res = self.endpoint.setup_auth(request, redirect_uri, cinfo, None) + assert set(res.keys()) == {"args", "function"} + + def test_response_mode_form_post(self): + request = {"response_mode": "form_post"} + info = { + "response_args": AuthorizationResponse(foo="bar"), + "return_uri": "https://example.com/cb", + } + info = self.endpoint.response_mode(request, **info) + assert set(info.keys()) == { + "response_args", + "return_uri", + "response_msg", + "content_type", + "response_placement", + } + assert info["response_msg"] == FORM_POST.format( + action="https://example.com/cb", + inputs='', + ) + + def test_do_response_code_form_post(self): + _req = AUTH_REQ_DICT.copy() + _req["response_mode"] = "form_post" + _pr_resp = self.endpoint.parse_request(_req) + _resp = self.endpoint.process_request(_pr_resp) + msg = self.endpoint.do_response(**_resp) + assert ("Content-type", "text/html") in msg["http_headers"] + assert "response_placement" in msg + + def test_response_mode_fragment(self): + request = {"response_mode": "fragment"} + self.endpoint.response_mode(request, fragment_enc=True) + + with pytest.raises(InvalidRequest): + self.endpoint.response_mode(request, fragment_enc=False) + + info = self.endpoint.response_mode(request) + assert set(info.keys()) == {"fragment_enc"} + + def test_check_session_iframe(self): + self.endpoint.endpoint_context.provider_info[ + "check_session_iframe" + ] = "https://example.com/csi" + _pr_resp = self.endpoint.parse_request(AUTH_REQ_DICT) + _resp = self.endpoint.process_request(_pr_resp) + assert "session_state" in _resp["response_args"] + + def test_setup_auth_login_hint(self): + request = AuthorizationRequest( + client_id="client_id", + redirect_uri="https://rp.example.com/cb", + response_type=["id_token"], + state="state", + nonce="nonce", + scope="openid", + login_hint="tel:0907865204", + ) + redirect_uri = request["redirect_uri"] + cinfo = { + "client_id": "client_id", + "redirect_uris": [("https://rp.example.com/cb", {})], + "id_token_signed_response_alg": "RS256", + } + + item = self.endpoint.endpoint_context.authn_broker.db["anon"] + item["method"].fail = NoSuchAuthentication + + res = self.endpoint.setup_auth(request, redirect_uri, cinfo, None) + assert set(res.keys()) == {"function", "args"} + assert "login_hint" in res["args"] + + def test_setup_auth_login_hint2acrs(self): + request = AuthorizationRequest( + client_id="client_id", + redirect_uri="https://rp.example.com/cb", + response_type=["id_token"], + state="state", + nonce="nonce", + scope="openid", + login_hint="email:foo@bar", + ) + redirect_uri = request["redirect_uri"] + + method_spec = { + "acr": INTERNETPROTOCOLPASSWORD, + "kwargs": {"user": "knoll"}, + "class": NoAuthn, + } + self.endpoint.endpoint_context.authn_broker["foo"] = init_method( + method_spec, None + ) + + item = self.endpoint.endpoint_context.authn_broker.db["anon"] + item["method"].fail = NoSuchAuthentication + item = self.endpoint.endpoint_context.authn_broker.db["foo"] + item["method"].fail = NoSuchAuthentication + + res = self.endpoint.pick_authn_method(request, redirect_uri) + assert set(res.keys()) == {"method", "acr"} + assert res["acr"] == INTERNETPROTOCOLPASSWORD + assert isinstance(res["method"], NoAuthn) + assert res["method"].user == "knoll" + + def test_post_logout_uri(self): + pass + + def test_parse_request(self): + _jwt = JWT(key_jar=self.rp_keyjar, iss="client_1", sign_alg="HS256") + _jws = _jwt.pack( + AUTH_REQ_DICT, aud=self.endpoint.endpoint_context.provider_info["issuer"] + ) + # ----------------- + _req = self.endpoint.parse_request( + { + "request": _jws, + "redirect_uri": AUTH_REQ.get("redirect_uri"), + "response_type": AUTH_REQ.get("response_type"), + "client_id": AUTH_REQ.get("client_id"), + "scope": AUTH_REQ.get("scope"), + } + ) + assert "__verified_request" in _req + + def test_parse_request_uri(self): + _jwt = JWT(key_jar=self.rp_keyjar, iss="client_1", sign_alg="HS256") + _jws = _jwt.pack( + AUTH_REQ_DICT, aud=self.endpoint.endpoint_context.provider_info["issuer"] + ) + + request_uri = "https://client.example.com/req" + # ----------------- + with responses.RequestsMock() as rsps: + rsps.add("GET", request_uri, body=_jws, status=200) + _req = self.endpoint.parse_request( + { + "request_uri": request_uri, + "redirect_uri": AUTH_REQ.get("redirect_uri"), + "response_type": AUTH_REQ.get("response_type"), + "client_id": AUTH_REQ.get("client_id"), + "scope": AUTH_REQ.get("scope"), + } + ) + + assert "__verified_request" in _req + + +def test_inputs(): + elems = inputs(dict(foo="bar", home="stead")) + test_elems = ( + '', + '', + ) + assert test_elems[0] in elems and test_elems[1] in elems + + +def test_acr_claims(): + assert acr_claims({"claims": {"id_token": {"acr": {"value": "foo"}}}}) == ["foo"] + assert acr_claims( + {"claims": {"id_token": {"acr": {"values": ["foo", "bar"]}}}} + ) == ["foo", "bar"] + assert acr_claims({"claims": {"id_token": {"acr": {"values": ["foo"]}}}}) == ["foo"] + assert acr_claims({"claims": {"id_token": {"acr": {"essential": True}}}}) is None + + +def test_join_query(): + redirect_uris = [("https://rp.example.com/cb", {"foo": ["bar"], "state": ["low"]})] + uri = join_query(*redirect_uris[0]) + test_uri = ("https://rp.example.com/cb?", "foo=bar", "state=low") + for i in test_uri: + assert i in uri + + +class TestUserAuthn(object): + @pytest.fixture(autouse=True) + def create_endpoint_context(self): + conf = { + "issuer": "https://example.com/", + "password": "mycket hemligt", + "token_expires_in": 600, + "grant_expires_in": 300, + "refresh_token_expires_in": 86400, + "verify_ssl": False, + "endpoint": {}, + "keys": {"uri_path": "static/jwks.json", "key_defs": KEYDEFS}, + "authentication": { + "user": { + "acr": INTERNETPROTOCOLPASSWORD, + "class": UserPassJinja2, + "verify_endpoint": "verify/user", + "kwargs": { + "template": "user_pass.jinja2", + "sym_key": "24AA/LR6HighEnergy", + "db": { + "class": JSONDictDB, + "kwargs": {"json_path": full_path("passwd.json")}, + }, + "page_header": "Testing log in", + "submit_btn": "Get me in!", + "user_label": "Nickname", + "passwd_label": "Secret sauce", + }, + }, + "anon": { + "acr": UNSPECIFIED, + "class": NoAuthn, + "kwargs": {"user": "diana"}, + }, + }, + "cookie_dealer": { + "class": "oidcendpoint.cookie.CookieDealer", + "kwargs": { + "sign_key": "ghsNKDDLshZTPn974nOsIGhedULrsqnsGoBFBLwUKuJhE2ch", + "default_values": { + "name": "oidc_xx", + "domain": "example.com", + "path": "/", + "max_age": 3600, + }, + }, + }, + "template_dir": "template", + } + self.endpoint_context = EndpointContext(conf) + + def test_authenticated_as_without_cookie(self): + authn_item = self.endpoint_context.authn_broker.pick(INTERNETPROTOCOLPASSWORD) + method = authn_item[0]["method"] + + _info, _time_stamp = method.authenticated_as(None) + assert _info is None + + def test_authenticated_as_with_cookie(self): + authn_item = self.endpoint_context.authn_broker.pick(INTERNETPROTOCOLPASSWORD) + method = authn_item[0]["method"] + + authn_req = {"state": "state_identifier", "client_id": "client 12345"} + _cookie = new_cookie( + self.endpoint_context, + sub="diana", + sid="session_identifier", + state=authn_req["state"], + client_id=authn_req["client_id"], + cookie_name=self.endpoint_context.cookie_name["session"], + ) + + _info, _time_stamp = method.authenticated_as(_cookie) + _info = cookie_value(_info["uid"]) + assert _info["sub"] == "diana" diff --git a/tests/test_25_oidc_token_endpoint.py b/tests/test_25_oidc_token_endpoint.py index dd130c7..9d77be3 100755 --- a/tests/test_25_oidc_token_endpoint.py +++ b/tests/test_25_oidc_token_endpoint.py @@ -1,25 +1,28 @@ import json import os -import pytest from cryptojwt import JWT from cryptojwt.key_jar import build_keyjar from oidcmsg.oidc import AccessTokenRequest from oidcmsg.oidc import AuthorizationRequest from oidcmsg.oidc import RefreshAccessTokenRequest +from oidcmsg.time_util import time_sans_frac +import pytest from oidcendpoint import JWT_BEARER +from oidcendpoint.authn_event import create_authn_event from oidcendpoint.client_authn import verify_client from oidcendpoint.endpoint_context import EndpointContext -from oidcendpoint.exception import MultipleUsage from oidcendpoint.exception import UnAuthorizedClient +from oidcendpoint.grant import Grant +from oidcendpoint.id_token import IDToken from oidcendpoint.oidc import userinfo from oidcendpoint.oidc.authorization import Authorization from oidcendpoint.oidc.provider_config import ProviderConfiguration from oidcendpoint.oidc.refresh_token import RefreshAccessToken from oidcendpoint.oidc.registration import Registration from oidcendpoint.oidc.token import AccessToken -from oidcendpoint.session import setup_session +from oidcendpoint.session_management import db_key from oidcendpoint.user_authn.authn_context import INTERNETPROTOCOLPASSWORD from oidcendpoint.user_info import UserInfo @@ -30,7 +33,6 @@ CLIENT_KEYJAR = build_keyjar(KEYDEFS) - RESPONSE_TYPES_SUPPORTED = [ ["code"], ["token"], @@ -145,6 +147,7 @@ def create_endpoint(self): "userinfo": {"class": UserInfo, "kwargs": {"db": {}}}, "client_authn": verify_client, "template_dir": "template", + "id_token": {"class": IDToken, "kwargs": {}}, } endpoint_context = EndpointContext(conf) endpoint_context.cdb["client_1"] = { @@ -156,48 +159,87 @@ def create_endpoint(self): } endpoint_context.keyjar.import_jwks(CLIENT_KEYJAR.export_jwks(), "client_1") self.endpoint = endpoint_context.endpoint["token"] + self.session_manager = endpoint_context.session_manager + self.user_id = "diana" + + def _create_session(self, auth_req, sub_type="public", sector_identifier=''): + client_id = auth_req['client_id'] + ae = create_authn_event(self.user_id, self.session_manager.salt) + self.session_manager.create_session(ae, auth_req, self.user_id, client_id=client_id, + sub_type=sub_type, sector_identifier=sector_identifier) + return db_key(self.user_id, client_id) + + def _do_grant(self, auth_req): + client_id = auth_req['client_id'] + # The user consent module produces a Grant instance + grant = Grant(scope=auth_req['scope'], resources=[client_id]) + + # the grant is assigned to a session (user_id, client_id) + self.session_manager.set([self.user_id, client_id, grant.id], grant) + return grant + + def _mint_code(self, grant, client_id): + sid = db_key(self.user_id, client_id) + # Constructing an authorization code is now done + return grant.mint_token( + 'authorization_code', + value=self.session_manager.token_handler["code"](sid), + expires_at=time_sans_frac() + 300 # 5 minutes from now + ) + + def _mint_access_token(self, grant, client_id, token_ref=None): + _csi = self.session_manager.get([self.user_id, client_id]) + return grant.mint_token( + 'access_token', + value=self.session_manager.token_handler["access_token"]( + db_key(self.user_id, client_id), + client_id=client_id, + aud=grant.resources, + user_claims=None, + scope=grant.scope, + sub=_csi['sub'] + ), + expires_at=time_sans_frac() + 900, # 15 minutes from now + based_on=token_ref # Means the token (tok) was used to mint this token + ) def test_init(self): assert self.endpoint def test_parse(self): - session_id = setup_session(self.endpoint.endpoint_context, AUTH_REQ, uid="user") + self._create_session(AUTH_REQ) + grant = self._do_grant(AUTH_REQ) + code = self._mint_code(grant, AUTH_REQ['client_id']) + _token_request = TOKEN_REQ_DICT.copy() - _token_request["code"] = self.endpoint.endpoint_context.sdb[session_id]["code"] + _token_request["code"] = code.value _req = self.endpoint.parse_request(_token_request) assert isinstance(_req, AccessTokenRequest) assert set(_req.keys()) == set(_token_request.keys()) def test_process_request(self): - session_id = setup_session( - self.endpoint.endpoint_context, - AUTH_REQ, - uid="user", - acr=INTERNETPROTOCOLPASSWORD, - ) + self._create_session(AUTH_REQ) + grant = self._do_grant(AUTH_REQ) + code = self._mint_code(grant, AUTH_REQ['client_id']) + _token_request = TOKEN_REQ_DICT.copy() - _context = self.endpoint.endpoint_context - _token_request["code"] = _context.sdb[session_id]["code"] - _context.sdb.update(session_id, user="diana") - _req = self.endpoint.parse_request(_token_request) + _token_request["code"] = code.value + _req = self.endpoint.parse_request(_token_request) _resp = self.endpoint.process_request(request=_req) assert _resp assert set(_resp.keys()) == {"http_headers", "response_args"} def test_process_request_using_code_twice(self): - session_id = setup_session( - self.endpoint.endpoint_context, - AUTH_REQ, - uid="user", - acr=INTERNETPROTOCOLPASSWORD, - ) + self._create_session(AUTH_REQ) + grant = self._do_grant(AUTH_REQ) + code = self._mint_code(grant, AUTH_REQ['client_id']) + _token_request = TOKEN_REQ_DICT.copy() - _context = self.endpoint.endpoint_context - _token_request["code"] = _context.sdb[session_id]["code"] - _context.sdb.update(session_id, user="diana") + _token_request["code"] = code.value + _req = self.endpoint.parse_request(_token_request) _resp = self.endpoint.process_request(request=_req) @@ -211,15 +253,12 @@ def test_process_request_using_code_twice(self): assert set(_resp.keys()) == {"error"} def test_do_response(self): - session_id = setup_session( - self.endpoint.endpoint_context, - AUTH_REQ, - uid="user", - acr=INTERNETPROTOCOLPASSWORD, - ) - self.endpoint.endpoint_context.sdb.update(session_id, user="diana") + self._create_session(AUTH_REQ) + grant = self._do_grant(AUTH_REQ) + code = self._mint_code(grant, AUTH_REQ['client_id']) + _token_request = TOKEN_REQ_DICT.copy() - _token_request["code"] = self.endpoint.endpoint_context.sdb[session_id]["code"] + _token_request["code"] = code.value _req = self.endpoint.parse_request(_token_request) _resp = self.endpoint.process_request(request=_req) @@ -227,12 +266,10 @@ def test_do_response(self): assert isinstance(msg, dict) def test_process_request_using_private_key_jwt(self): - session_id = setup_session( - self.endpoint.endpoint_context, - AUTH_REQ, - uid="user", - acr=INTERNETPROTOCOLPASSWORD, - ) + self._create_session(AUTH_REQ) + grant = self._do_grant(AUTH_REQ) + code = self._mint_code(grant, AUTH_REQ['client_id']) + _token_request = TOKEN_REQ_DICT.copy() del _token_request["client_id"] del _token_request["client_secret"] @@ -244,9 +281,8 @@ def test_process_request_using_private_key_jwt(self): _token_request.update( {"client_assertion": _assertion, "client_assertion_type": JWT_BEARER} ) - _token_request["code"] = self.endpoint.endpoint_context.sdb[session_id]["code"] + _token_request["code"] = code.value - _context.sdb.update(session_id, user="diana") _req = self.endpoint.parse_request(_token_request) _resp = self.endpoint.process_request(request=_req) diff --git a/tests/test_25_oidc_token_endpoint.py.no b/tests/test_25_oidc_token_endpoint.py.no new file mode 100755 index 0000000..dd130c7 --- /dev/null +++ b/tests/test_25_oidc_token_endpoint.py.no @@ -0,0 +1,255 @@ +import json +import os + +import pytest +from cryptojwt import JWT +from cryptojwt.key_jar import build_keyjar +from oidcmsg.oidc import AccessTokenRequest +from oidcmsg.oidc import AuthorizationRequest +from oidcmsg.oidc import RefreshAccessTokenRequest + +from oidcendpoint import JWT_BEARER +from oidcendpoint.client_authn import verify_client +from oidcendpoint.endpoint_context import EndpointContext +from oidcendpoint.exception import MultipleUsage +from oidcendpoint.exception import UnAuthorizedClient +from oidcendpoint.oidc import userinfo +from oidcendpoint.oidc.authorization import Authorization +from oidcendpoint.oidc.provider_config import ProviderConfiguration +from oidcendpoint.oidc.refresh_token import RefreshAccessToken +from oidcendpoint.oidc.registration import Registration +from oidcendpoint.oidc.token import AccessToken +from oidcendpoint.session import setup_session +from oidcendpoint.user_authn.authn_context import INTERNETPROTOCOLPASSWORD +from oidcendpoint.user_info import UserInfo + +KEYDEFS = [ + {"type": "RSA", "key": "", "use": ["sig"]}, + {"type": "EC", "crv": "P-256", "use": ["sig"]}, +] + +CLIENT_KEYJAR = build_keyjar(KEYDEFS) + + +RESPONSE_TYPES_SUPPORTED = [ + ["code"], + ["token"], + ["id_token"], + ["code", "token"], + ["code", "id_token"], + ["id_token", "token"], + ["code", "token", "id_token"], + ["none"], +] + +CAPABILITIES = { + "subject_types_supported": ["public", "pairwise"], + "grant_types_supported": [ + "authorization_code", + "implicit", + "urn:ietf:params:oauth:grant-type:jwt-bearer", + "refresh_token", + ], +} + +AUTH_REQ = AuthorizationRequest( + client_id="client_1", + redirect_uri="https://example.com/cb", + scope=["openid"], + state="STATE", + response_type="code", +) + +TOKEN_REQ = AccessTokenRequest( + client_id="client_1", + redirect_uri="https://example.com/cb", + state="STATE", + grant_type="authorization_code", + client_secret="hemligt", +) + +REFRESH_TOKEN_REQ = RefreshAccessTokenRequest( + grant_type="refresh_token", client_id="client_1", client_secret="hemligt" +) + +TOKEN_REQ_DICT = TOKEN_REQ.to_dict() + +BASEDIR = os.path.abspath(os.path.dirname(__file__)) + + +def full_path(local_file): + return os.path.join(BASEDIR, local_file) + + +USERINFO = UserInfo(json.loads(open(full_path("users.json")).read())) + + +class TestEndpoint(object): + @pytest.fixture(autouse=True) + def create_endpoint(self): + conf = { + "issuer": "https://example.com/", + "password": "mycket hemligt", + "token_expires_in": 600, + "grant_expires_in": 300, + "refresh_token_expires_in": 86400, + "verify_ssl": False, + "capabilities": CAPABILITIES, + "keys": {"uri_path": "jwks.json", "key_defs": KEYDEFS}, + "endpoint": { + "provider_config": { + "path": ".well-known/openid-configuration", + "class": ProviderConfiguration, + "kwargs": {}, + }, + "registration": { + "path": "registration", + "class": Registration, + "kwargs": {}, + }, + "authorization": { + "path": "authorization", + "class": Authorization, + "kwargs": {}, + }, + "token": { + "path": "token", + "class": AccessToken, + "kwargs": { + "client_authn_method": [ + "client_secret_basic", + "client_secret_post", + "client_secret_jwt", + "private_key_jwt", + ] + }, + }, + "refresh_token": { + "path": "token", + "class": RefreshAccessToken, + "kwargs": {}, + }, + "userinfo": { + "path": "userinfo", + "class": userinfo.UserInfo, + "kwargs": {"db_file": "users.json"}, + }, + }, + "authentication": { + "anon": { + "acr": INTERNETPROTOCOLPASSWORD, + "class": "oidcendpoint.user_authn.user.NoAuthn", + "kwargs": {"user": "diana"}, + } + }, + "userinfo": {"class": UserInfo, "kwargs": {"db": {}}}, + "client_authn": verify_client, + "template_dir": "template", + } + endpoint_context = EndpointContext(conf) + endpoint_context.cdb["client_1"] = { + "client_secret": "hemligt", + "redirect_uris": [("https://example.com/cb", None)], + "client_salt": "salted", + "token_endpoint_auth_method": "client_secret_post", + "response_types": ["code", "token", "code id_token", "id_token"], + } + endpoint_context.keyjar.import_jwks(CLIENT_KEYJAR.export_jwks(), "client_1") + self.endpoint = endpoint_context.endpoint["token"] + + def test_init(self): + assert self.endpoint + + def test_parse(self): + session_id = setup_session(self.endpoint.endpoint_context, AUTH_REQ, uid="user") + _token_request = TOKEN_REQ_DICT.copy() + _token_request["code"] = self.endpoint.endpoint_context.sdb[session_id]["code"] + _req = self.endpoint.parse_request(_token_request) + + assert isinstance(_req, AccessTokenRequest) + assert set(_req.keys()) == set(_token_request.keys()) + + def test_process_request(self): + session_id = setup_session( + self.endpoint.endpoint_context, + AUTH_REQ, + uid="user", + acr=INTERNETPROTOCOLPASSWORD, + ) + _token_request = TOKEN_REQ_DICT.copy() + _context = self.endpoint.endpoint_context + _token_request["code"] = _context.sdb[session_id]["code"] + _context.sdb.update(session_id, user="diana") + _req = self.endpoint.parse_request(_token_request) + + _resp = self.endpoint.process_request(request=_req) + + assert _resp + assert set(_resp.keys()) == {"http_headers", "response_args"} + + def test_process_request_using_code_twice(self): + session_id = setup_session( + self.endpoint.endpoint_context, + AUTH_REQ, + uid="user", + acr=INTERNETPROTOCOLPASSWORD, + ) + _token_request = TOKEN_REQ_DICT.copy() + _context = self.endpoint.endpoint_context + _token_request["code"] = _context.sdb[session_id]["code"] + _context.sdb.update(session_id, user="diana") + _req = self.endpoint.parse_request(_token_request) + _resp = self.endpoint.process_request(request=_req) + + # 2nd time used + # TODO: There is a bug in _post_parse_request, the returned error + # should be invalid_grant, not invalid_client + _req = self.endpoint.parse_request(_token_request) + _resp = self.endpoint.process_request(request=_req) + + assert _resp + assert set(_resp.keys()) == {"error"} + + def test_do_response(self): + session_id = setup_session( + self.endpoint.endpoint_context, + AUTH_REQ, + uid="user", + acr=INTERNETPROTOCOLPASSWORD, + ) + self.endpoint.endpoint_context.sdb.update(session_id, user="diana") + _token_request = TOKEN_REQ_DICT.copy() + _token_request["code"] = self.endpoint.endpoint_context.sdb[session_id]["code"] + _req = self.endpoint.parse_request(_token_request) + + _resp = self.endpoint.process_request(request=_req) + msg = self.endpoint.do_response(request=_req, **_resp) + assert isinstance(msg, dict) + + def test_process_request_using_private_key_jwt(self): + session_id = setup_session( + self.endpoint.endpoint_context, + AUTH_REQ, + uid="user", + acr=INTERNETPROTOCOLPASSWORD, + ) + _token_request = TOKEN_REQ_DICT.copy() + del _token_request["client_id"] + del _token_request["client_secret"] + _context = self.endpoint.endpoint_context + + _jwt = JWT(CLIENT_KEYJAR, iss=AUTH_REQ["client_id"], sign_alg="RS256") + _jwt.with_jti = True + _assertion = _jwt.pack({"aud": [_context.endpoint["token"].full_path]}) + _token_request.update( + {"client_assertion": _assertion, "client_assertion_type": JWT_BEARER} + ) + _token_request["code"] = self.endpoint.endpoint_context.sdb[session_id]["code"] + + _context.sdb.update(session_id, user="diana") + _req = self.endpoint.parse_request(_token_request) + _resp = self.endpoint.process_request(request=_req) + + # 2nd time used + with pytest.raises(UnAuthorizedClient): + self.endpoint.parse_request(_token_request) diff --git a/tests/test_26_oidc_userinfo_endpoint.py b/tests/test_26_oidc_userinfo_endpoint.py index 41f2bb9..88fc2d2 100755 --- a/tests/test_26_oidc_userinfo_endpoint.py +++ b/tests/test_26_oidc_userinfo_endpoint.py @@ -1,20 +1,24 @@ import json import os -import pytest -from cryptojwt.jwt import utc_time_sans_frac from oidcmsg.oidc import AccessTokenRequest from oidcmsg.oidc import AuthorizationRequest +from oidcmsg.time_util import time_sans_frac +import pytest from oidcendpoint import user_info +from oidcendpoint.authn_event import create_authn_event from oidcendpoint.endpoint_context import EndpointContext +from oidcendpoint.grant import Grant from oidcendpoint.id_token import IDToken from oidcendpoint.oidc import userinfo from oidcendpoint.oidc.authorization import Authorization from oidcendpoint.oidc.provider_config import ProviderConfiguration from oidcendpoint.oidc.registration import Registration from oidcendpoint.oidc.token import AccessToken -from oidcendpoint.session import setup_session +from oidcendpoint.session_management import SessionManager +from oidcendpoint.session_management import db_key +from oidcendpoint.session_management import unpack_db_key from oidcendpoint.user_authn.authn_context import INTERNETPROTOCOLPASSWORD from oidcendpoint.user_info import UserInfo @@ -165,133 +169,147 @@ def create_endpoint(self): "response_types": ["code", "token", "code id_token", "id_token"], } self.endpoint = endpoint_context.endpoint["userinfo"] + self.session_manager = SessionManager({}, endpoint_context.sdb.handler) + endpoint_context.session_manager = self.session_manager + self.user_id = "diana" + + def _create_session(self, auth_req, sub_type="public", sector_identifier=''): + client_id = auth_req['client_id'] + ae = create_authn_event(self.user_id, self.session_manager.salt) + self.session_manager.create_session(ae, auth_req, self.user_id, client_id=client_id, + sub_type=sub_type, sector_identifier=sector_identifier) + return db_key(self.user_id, client_id) + + def _do_grant(self, auth_req): + client_id = auth_req['client_id'] + # The user consent module produces a Grant instance + grant = Grant(scope=auth_req['scope'], resources=[client_id]) + + # the grant is assigned to a session (user_id, client_id) + self.session_manager.set([self.user_id, client_id, grant.id], grant) + return grant + + def _mint_code(self, grant): + # Constructing an authorization code is now done + return grant.mint_token( + 'authorization_code', + value=self.session_manager.token_handler["code"](self.user_id), + expires_at=time_sans_frac() + 300 # 5 minutes from now + ) + + def _mint_access_token(self, grant, client_id, token_ref=None): + _csi = self.session_manager.get([self.user_id, client_id]) + return grant.mint_token( + 'access_token', + value=self.session_manager.token_handler["access_token"]( + db_key(self.user_id, client_id), + client_id=client_id, + aud=grant.resources, + user_claims=None, + scope=grant.scope, + sub=_csi['sub'] + ), + expires_at=time_sans_frac() + 900, # 15 minutes from now + based_on=token_ref # Means the token (tok) was used to mint this token + ) def test_init(self): assert self.endpoint assert set( self.endpoint.endpoint_context.provider_info["claims_supported"] ) == { - "address", - "birthdate", - "email", - "email_verified", - "eduperson_scoped_affiliation", - "family_name", - "gender", - "given_name", - "locale", - "middle_name", - "name", - "nickname", - "phone_number", - "phone_number_verified", - "picture", - "preferred_username", - "profile", - "sub", - "updated_at", - "website", - "zoneinfo", - } + "address", + "birthdate", + "email", + "email_verified", + "eduperson_scoped_affiliation", + "family_name", + "gender", + "given_name", + "locale", + "middle_name", + "name", + "nickname", + "phone_number", + "phone_number_verified", + "picture", + "preferred_username", + "profile", + "sub", + "updated_at", + "website", + "zoneinfo", + } def test_parse(self): - session_id = setup_session( - self.endpoint.endpoint_context, - AUTH_REQ, - uid="userID", - authn_event={ - "authn_info": "loa1", - "uid": "diana", - "authn_time": utc_time_sans_frac(), - "valid_until": utc_time_sans_frac() + 3600, - }, - ) - _dic = self.endpoint.endpoint_context.sdb.upgrade_to_token(key=session_id) - _req = self.endpoint.parse_request( - {}, auth="Bearer {}".format(_dic["access_token"]) - ) - + self._create_session(AUTH_REQ) + grant = self._do_grant(AUTH_REQ) + # Free standing access token, not based on an authorization code + access_token = self._mint_access_token(grant, AUTH_REQ['client_id'], None) + _req = self.endpoint.parse_request({}, auth="Bearer {}".format(access_token.value)) assert set(_req.keys()) == {"client_id", "access_token"} + assert _req["client_id"] == AUTH_REQ['client_id'] + assert _req["access_token"] == access_token.value def test_parse_invalid_token(self): _req = self.endpoint.parse_request({}, auth="Bearer invalid") - assert _req['error'] == "invalid_token" def test_process_request(self): - session_id = setup_session( - self.endpoint.endpoint_context, - AUTH_REQ, - uid="userID", - authn_event={ - "authn_info": "loa1", - "uid": "diana", - "authn_time": utc_time_sans_frac(), - "valid_until": utc_time_sans_frac() + 3600, - }, - ) - _dic = self.endpoint.endpoint_context.sdb.upgrade_to_token(key=session_id) + self._create_session(AUTH_REQ) + grant = self._do_grant(AUTH_REQ) + code = self._mint_code(grant) + access_token = self._mint_access_token(grant, AUTH_REQ['client_id'], code) + _req = self.endpoint.parse_request( - {}, auth="Bearer {}".format(_dic["access_token"]) + {}, auth="Bearer {}".format(access_token.value) ) args = self.endpoint.process_request(_req) assert args def test_process_request_not_allowed(self): - session_id = setup_session( - self.endpoint.endpoint_context, - AUTH_REQ, - uid="userID", - authn_event={ - "authn_info": "loa1", - "uid": "diana", - "authn_time": utc_time_sans_frac() - 7200, - "valid_until": utc_time_sans_frac() - 3600, - }, - ) - _dic = self.endpoint.endpoint_context.sdb.upgrade_to_token(key=session_id) + self._create_session(AUTH_REQ) + grant = self._do_grant(AUTH_REQ) + code = self._mint_code(grant) + access_token = self._mint_access_token(grant, AUTH_REQ['client_id'], code) + + _us_info = self.session_manager.get([self.user_id]) + # 2 things can make the request invalid. + # 1) The token is not valid anymore or 2) The event is not valid. + _event = _us_info["authentication_event"] + _event['authn_time'] -= 9000 + _event['valid_until'] -= 9000 + _req = self.endpoint.parse_request( - {}, auth="Bearer {}".format(_dic["access_token"]) + {}, auth="Bearer {}".format(access_token.value) ) args = self.endpoint.process_request(_req) assert set(args["response_args"].keys()) == {"error", "error_description"} - def test_process_request_offline_access(self): - auth_req = AUTH_REQ.copy() - auth_req["scope"] = ["openid", "offline_access"] - session_id = setup_session( - self.endpoint.endpoint_context, - auth_req, - uid="userID", - authn_event={ - "authn_info": "loa1", - "uid": "diana", - "authn_time": utc_time_sans_frac() , - "valid_until": utc_time_sans_frac() + 3600, - }, - ) - _dic = self.endpoint.endpoint_context.sdb.upgrade_to_token(key=session_id) - _req = self.endpoint.parse_request( - {}, auth="Bearer {}".format(_dic["access_token"]) - ) - args = self.endpoint.process_request(_req) - assert set(args["response_args"].keys()) == {"sub"} + # Offline access is presently not checked. + # + # def test_process_request_offline_access(self): + # auth_req = AUTH_REQ.copy() + # auth_req["scope"] = ["openid", "offline_access"] + # self._create_session(auth_req) + # grant = self._do_grant(auth_req) + # code = self._mint_code(grant) + # access_token = self._mint_access_token(grant, auth_req['client_id'], code) + # + # _req = self.endpoint.parse_request( + # {}, auth="Bearer {}".format(access_token.value) + # ) + # args = self.endpoint.process_request(_req) + # assert set(args["response_args"].keys()) =={'response_args', 'client_id'} def test_do_response(self): - session_id = setup_session( - self.endpoint.endpoint_context, - AUTH_REQ, - uid="userID", - authn_event={ - "authn_info": "loa1", - "uid": "diana", - "authn_time": utc_time_sans_frac(), - "valid_until": utc_time_sans_frac() + 3600, - }, - ) - _dic = self.endpoint.endpoint_context.sdb.upgrade_to_token(key=session_id) + self._create_session(AUTH_REQ) + grant = self._do_grant(AUTH_REQ) + code = self._mint_code(grant) + access_token = self._mint_access_token(grant, AUTH_REQ['client_id'], code) + _req = self.endpoint.parse_request( - {}, auth="Bearer {}".format(_dic["access_token"]) + {}, auth="Bearer {}".format(access_token.value) ) args = self.endpoint.process_request(_req) assert args @@ -299,24 +317,15 @@ def test_do_response(self): assert res def test_do_signed_response(self): - self.endpoint.endpoint_context.cdb["client_1"][ - "userinfo_signed_response_alg" - ] = "ES256" - - session_id = setup_session( - self.endpoint.endpoint_context, - AUTH_REQ, - uid="userID", - authn_event={ - "authn_info": "loa1", - "uid": "diana", - "authn_time": utc_time_sans_frac(), - "valid_until": utc_time_sans_frac() + 3600, - }, - ) - _dic = self.endpoint.endpoint_context.sdb.upgrade_to_token(key=session_id) + self.endpoint.endpoint_context.cdb["client_1"]["userinfo_signed_response_alg"] = "ES256" + + self._create_session(AUTH_REQ) + grant = self._do_grant(AUTH_REQ) + code = self._mint_code(grant) + access_token = self._mint_access_token(grant, AUTH_REQ['client_id'], code) + _req = self.endpoint.parse_request( - {}, auth="Bearer {}".format(_dic["access_token"]) + {}, auth="Bearer {}".format(access_token.value) ) args = self.endpoint.process_request(_req) assert args @@ -326,28 +335,24 @@ def test_do_signed_response(self): def test_custom_scope(self): _auth_req = AUTH_REQ.copy() _auth_req["scope"] = ["openid", "research_and_scholarship"] - session_id = setup_session( - self.endpoint.endpoint_context, - _auth_req, - uid="userID", - authn_event={ - "authn_info": "loa1", - "uid": "diana", - "authn_time": utc_time_sans_frac(), - "valid_until": utc_time_sans_frac() + 3600, - }, - ) - _dic = self.endpoint.endpoint_context.sdb.upgrade_to_token(key=session_id) + + _sid = self._create_session(_auth_req) + grant = self._do_grant(_auth_req) + code = self._mint_code(grant) + access_token = self._mint_access_token(grant, AUTH_REQ['client_id'], code) + + user_id, client_id = unpack_db_key(_sid) + self.endpoint.claims_interface.add_claims_by_scope = True + grant.claims = { + "userinfo": self.endpoint.claims_interface.get_claims(client_id=client_id, + user_id=user_id, + scopes=_auth_req["scope"]) + } + _req = self.endpoint.parse_request( - {}, auth="Bearer {}".format(_dic["access_token"]) + {}, auth="Bearer {}".format(access_token.value) ) args = self.endpoint.process_request(_req) - assert set(args["response_args"].keys()) == { - "sub", - "name", - "given_name", - "family_name", - "email", - "email_verified", - "eduperson_scoped_affiliation", - } + assert set(args["response_args"].keys()) == {'eduperson_scoped_affiliation', 'given_name', + 'email_verified', 'email', 'family_name', + 'name', 'sub'} diff --git a/tests/test_26_oidc_userinfo_endpoint.py.no b/tests/test_26_oidc_userinfo_endpoint.py.no new file mode 100755 index 0000000..41f2bb9 --- /dev/null +++ b/tests/test_26_oidc_userinfo_endpoint.py.no @@ -0,0 +1,353 @@ +import json +import os + +import pytest +from cryptojwt.jwt import utc_time_sans_frac +from oidcmsg.oidc import AccessTokenRequest +from oidcmsg.oidc import AuthorizationRequest + +from oidcendpoint import user_info +from oidcendpoint.endpoint_context import EndpointContext +from oidcendpoint.id_token import IDToken +from oidcendpoint.oidc import userinfo +from oidcendpoint.oidc.authorization import Authorization +from oidcendpoint.oidc.provider_config import ProviderConfiguration +from oidcendpoint.oidc.registration import Registration +from oidcendpoint.oidc.token import AccessToken +from oidcendpoint.session import setup_session +from oidcendpoint.user_authn.authn_context import INTERNETPROTOCOLPASSWORD +from oidcendpoint.user_info import UserInfo + +KEYDEFS = [ + {"type": "RSA", "key": "", "use": ["sig"]}, + {"type": "EC", "crv": "P-256", "use": ["sig"]}, +] + +RESPONSE_TYPES_SUPPORTED = [ + ["code"], + ["token"], + ["id_token"], + ["code", "token"], + ["code", "id_token"], + ["id_token", "token"], + ["code", "token", "id_token"], + ["none"], +] + +CAPABILITIES = { + "subject_types_supported": ["public", "pairwise"], + "grant_types_supported": [ + "authorization_code", + "implicit", + "urn:ietf:params:oauth:grant-type:jwt-bearer", + "refresh_token", + ], +} + +AUTH_REQ = AuthorizationRequest( + client_id="client_1", + redirect_uri="https://example.com/cb", + scope=["openid"], + state="STATE", + response_type="code", +) + +TOKEN_REQ = AccessTokenRequest( + client_id="client_1", + redirect_uri="https://example.com/cb", + state="STATE", + grant_type="authorization_code", + client_secret="hemligt", +) + +TOKEN_REQ_DICT = TOKEN_REQ.to_dict() + +BASEDIR = os.path.abspath(os.path.dirname(__file__)) + + +def full_path(local_file): + return os.path.join(BASEDIR, local_file) + + +USERINFO = UserInfo(json.loads(open(full_path("users.json")).read())) + + +class TestEndpoint(object): + @pytest.fixture(autouse=True) + def create_endpoint(self): + conf = { + "issuer": "https://example.com/", + "password": "mycket hemligt", + "token_expires_in": 600, + "grant_expires_in": 300, + "refresh_token_expires_in": 86400, + "verify_ssl": False, + "capabilities": CAPABILITIES, + "keys": {"uri_path": "jwks.json", "key_defs": KEYDEFS}, + "id_token": {"class": IDToken, "kwargs": {}}, + "endpoint": { + "provider_config": { + "path": ".well-known/openid-configuration", + "class": ProviderConfiguration, + "kwargs": {}, + }, + "registration": { + "path": "registration", + "class": Registration, + "kwargs": {}, + }, + "authorization": { + "path": "authorization", + "class": Authorization, + "kwargs": {}, + }, + "token": { + "path": "token", + "class": AccessToken, + "kwargs": { + "client_authn_methods": [ + "client_secret_post", + "client_secret_basic", + "client_secret_jwt", + "private_key_jwt", + ] + }, + }, + "userinfo": { + "path": "userinfo", + "class": userinfo.UserInfo, + "kwargs": { + "claim_types_supported": [ + "normal", + "aggregated", + "distributed", + ], + "client_authn_method": ["bearer_header"], + }, + }, + }, + "userinfo": { + "class": user_info.UserInfo, + "kwargs": {"db_file": full_path("users.json")}, + }, + # "client_authn": verify_client, + "authentication": { + "anon": { + "acr": INTERNETPROTOCOLPASSWORD, + "class": "oidcendpoint.user_authn.user.NoAuthn", + "kwargs": {"user": "diana"}, + } + }, + "template_dir": "template", + "add_on": { + "custom_scopes": { + "function": "oidcendpoint.oidc.add_on.custom_scopes.add_custom_scopes", + "kwargs": { + "research_and_scholarship": [ + "name", + "given_name", + "family_name", + "email", + "email_verified", + "sub", + "eduperson_scoped_affiliation", + ] + }, + } + }, + } + endpoint_context = EndpointContext(conf) + endpoint_context.cdb["client_1"] = { + "client_secret": "hemligt", + "redirect_uris": [("https://example.com/cb", None)], + "client_salt": "salted", + "token_endpoint_auth_method": "client_secret_post", + "response_types": ["code", "token", "code id_token", "id_token"], + } + self.endpoint = endpoint_context.endpoint["userinfo"] + + def test_init(self): + assert self.endpoint + assert set( + self.endpoint.endpoint_context.provider_info["claims_supported"] + ) == { + "address", + "birthdate", + "email", + "email_verified", + "eduperson_scoped_affiliation", + "family_name", + "gender", + "given_name", + "locale", + "middle_name", + "name", + "nickname", + "phone_number", + "phone_number_verified", + "picture", + "preferred_username", + "profile", + "sub", + "updated_at", + "website", + "zoneinfo", + } + + def test_parse(self): + session_id = setup_session( + self.endpoint.endpoint_context, + AUTH_REQ, + uid="userID", + authn_event={ + "authn_info": "loa1", + "uid": "diana", + "authn_time": utc_time_sans_frac(), + "valid_until": utc_time_sans_frac() + 3600, + }, + ) + _dic = self.endpoint.endpoint_context.sdb.upgrade_to_token(key=session_id) + _req = self.endpoint.parse_request( + {}, auth="Bearer {}".format(_dic["access_token"]) + ) + + assert set(_req.keys()) == {"client_id", "access_token"} + + def test_parse_invalid_token(self): + _req = self.endpoint.parse_request({}, auth="Bearer invalid") + + assert _req['error'] == "invalid_token" + + def test_process_request(self): + session_id = setup_session( + self.endpoint.endpoint_context, + AUTH_REQ, + uid="userID", + authn_event={ + "authn_info": "loa1", + "uid": "diana", + "authn_time": utc_time_sans_frac(), + "valid_until": utc_time_sans_frac() + 3600, + }, + ) + _dic = self.endpoint.endpoint_context.sdb.upgrade_to_token(key=session_id) + _req = self.endpoint.parse_request( + {}, auth="Bearer {}".format(_dic["access_token"]) + ) + args = self.endpoint.process_request(_req) + assert args + + def test_process_request_not_allowed(self): + session_id = setup_session( + self.endpoint.endpoint_context, + AUTH_REQ, + uid="userID", + authn_event={ + "authn_info": "loa1", + "uid": "diana", + "authn_time": utc_time_sans_frac() - 7200, + "valid_until": utc_time_sans_frac() - 3600, + }, + ) + _dic = self.endpoint.endpoint_context.sdb.upgrade_to_token(key=session_id) + _req = self.endpoint.parse_request( + {}, auth="Bearer {}".format(_dic["access_token"]) + ) + args = self.endpoint.process_request(_req) + assert set(args["response_args"].keys()) == {"error", "error_description"} + + def test_process_request_offline_access(self): + auth_req = AUTH_REQ.copy() + auth_req["scope"] = ["openid", "offline_access"] + session_id = setup_session( + self.endpoint.endpoint_context, + auth_req, + uid="userID", + authn_event={ + "authn_info": "loa1", + "uid": "diana", + "authn_time": utc_time_sans_frac() , + "valid_until": utc_time_sans_frac() + 3600, + }, + ) + _dic = self.endpoint.endpoint_context.sdb.upgrade_to_token(key=session_id) + _req = self.endpoint.parse_request( + {}, auth="Bearer {}".format(_dic["access_token"]) + ) + args = self.endpoint.process_request(_req) + assert set(args["response_args"].keys()) == {"sub"} + + def test_do_response(self): + session_id = setup_session( + self.endpoint.endpoint_context, + AUTH_REQ, + uid="userID", + authn_event={ + "authn_info": "loa1", + "uid": "diana", + "authn_time": utc_time_sans_frac(), + "valid_until": utc_time_sans_frac() + 3600, + }, + ) + _dic = self.endpoint.endpoint_context.sdb.upgrade_to_token(key=session_id) + _req = self.endpoint.parse_request( + {}, auth="Bearer {}".format(_dic["access_token"]) + ) + args = self.endpoint.process_request(_req) + assert args + res = self.endpoint.do_response(request=_req, **args) + assert res + + def test_do_signed_response(self): + self.endpoint.endpoint_context.cdb["client_1"][ + "userinfo_signed_response_alg" + ] = "ES256" + + session_id = setup_session( + self.endpoint.endpoint_context, + AUTH_REQ, + uid="userID", + authn_event={ + "authn_info": "loa1", + "uid": "diana", + "authn_time": utc_time_sans_frac(), + "valid_until": utc_time_sans_frac() + 3600, + }, + ) + _dic = self.endpoint.endpoint_context.sdb.upgrade_to_token(key=session_id) + _req = self.endpoint.parse_request( + {}, auth="Bearer {}".format(_dic["access_token"]) + ) + args = self.endpoint.process_request(_req) + assert args + res = self.endpoint.do_response(request=_req, **args) + assert res + + def test_custom_scope(self): + _auth_req = AUTH_REQ.copy() + _auth_req["scope"] = ["openid", "research_and_scholarship"] + session_id = setup_session( + self.endpoint.endpoint_context, + _auth_req, + uid="userID", + authn_event={ + "authn_info": "loa1", + "uid": "diana", + "authn_time": utc_time_sans_frac(), + "valid_until": utc_time_sans_frac() + 3600, + }, + ) + _dic = self.endpoint.endpoint_context.sdb.upgrade_to_token(key=session_id) + _req = self.endpoint.parse_request( + {}, auth="Bearer {}".format(_dic["access_token"]) + ) + args = self.endpoint.process_request(_req) + assert set(args["response_args"].keys()) == { + "sub", + "name", + "given_name", + "family_name", + "email", + "email_verified", + "eduperson_scoped_affiliation", + } diff --git a/tests/test_30_oidc_end_session.py b/tests/test_30_oidc_end_session.py index 088e591..6e73bf7 100644 --- a/tests/test_30_oidc_end_session.py +++ b/tests/test_30_oidc_end_session.py @@ -4,14 +4,14 @@ from urllib.parse import parse_qs from urllib.parse import urlparse -import pytest -import responses from cryptojwt.key_jar import build_keyjar from oidcmsg.exception import InvalidRequest from oidcmsg.message import Message from oidcmsg.oidc import AuthorizationRequest from oidcmsg.oidc import verified_claim_name from oidcmsg.oidc import verify_id_token +import pytest +import responses from oidcendpoint.common.authorization import join_query from oidcendpoint.cookie import CookieDealer @@ -25,6 +25,7 @@ from oidcendpoint.oidc.session import Session from oidcendpoint.oidc.session import do_front_channel_logout_iframe from oidcendpoint.oidc.token import AccessToken +from oidcendpoint.session_management import db_key from oidcendpoint.token_handler import UnknownToken from oidcendpoint.user_authn.authn_context import INTERNETPROTOCOLPASSWORD from oidcendpoint.user_info import UserInfo @@ -192,15 +193,14 @@ def create_endpoint(self): def test_end_session_endpoint(self): # End session not allowed if no cookie and no id_token_hint is sent # (can't determine session) - with pytest.raises(UnknownToken): + with pytest.raises(ValueError): _ = self.session_endpoint.process_request("", cookie="FAIL") - def _create_cookie(self, user, sid, state, client_id): + def _create_cookie(self, session_id, state, client_id): ec = self.session_endpoint.endpoint_context return new_cookie( ec, - sub=user, - sid=sid, + sid=session_id, state=state, client_id=client_id, cookie_name=ec.cookie_name["session"], @@ -228,14 +228,14 @@ def _code_auth2(self, state): _pr_resp = self.authn_endpoint.parse_request(req.to_dict()) _resp = self.authn_endpoint.process_request(_pr_resp) - def _get_sid(self): - _sdb = self.session_endpoint.endpoint_context.sdb - - for _sid in _sdb.keys(): - if _sid.startswith("__state__"): - continue - else: - return _sid + # def _get_sid(self): + # _mngr = self.session_endpoint.endpoint_context.session_manager + # + # for _sid in _sdb.keys(): + # if _sid.startswith("__state__"): + # continue + # else: + # return _sid def _auth_with_id_token(self, state): req = AuthorizationRequest( @@ -253,8 +253,7 @@ def _auth_with_id_token(self, state): def test_end_session_endpoint_with_cookie(self): self._code_auth("1234567") - _sid = self._get_sid() - cookie = self._create_cookie("diana", _sid, "1234567", "client_1") + cookie = self._create_cookie(db_key("diana", "client_1"), "1234567", "client_1") _req_args = self.session_endpoint.parse_request({"state": "1234567"}) resp = self.session_endpoint.process_request(_req_args, cookie=cookie) @@ -272,7 +271,7 @@ def test_end_session_endpoint_with_cookie(self): def test_end_session_endpoint_with_wrong_cookie(self): self._code_auth("1234567") - cookie = self._create_cookie("diana", "client_2", "abcdefg", "client_1") + cookie = self._create_cookie(db_key("diana", "client_1"), "abcdefg", "client_2") with pytest.raises(UnknownToken): self.session_endpoint.process_request({"state": "abcde"}, cookie=cookie) @@ -281,7 +280,7 @@ def test_end_session_endpoint_with_cookie_wrong_user(self): # Need cookie and ID Token to figure this out id_token = self._auth_with_id_token("1234567") - cookie = self._create_cookie("diggins", "_sid_", "1234567", "client_1") + cookie = self._create_cookie(db_key("diggins", "client_1"), "1234567", "client_1") msg = Message(id_token=id_token) verify_id_token(msg, keyjar=self.session_endpoint.endpoint_context.keyjar) @@ -298,7 +297,7 @@ def test_end_session_endpoint_with_cookie_unknown_sid(self): id_token = self._auth_with_id_token("1234567") # Wrong client_id - cookie = self._create_cookie("diana", "_sid_", "state", "client_1") + cookie = self._create_cookie(db_key("diana", "client_1"), "state", "client_1") msg = Message(id_token=id_token) verify_id_token(msg, keyjar=self.session_endpoint.endpoint_context.keyjar) @@ -314,8 +313,7 @@ def test_end_session_endpoint_with_cookie_dual_login(self): self._code_auth("1234567") self._code_auth2("abcdefg") _sdb = self.session_endpoint.endpoint_context.sdb - _sid = self._get_sid() - cookie = self._create_cookie("diana", _sid, "1234567", "client_1") + cookie = self._create_cookie(db_key("diana", "client_1"), "1234567", "client_1") resp = self.session_endpoint.process_request({"state": "abcde"}, cookie=cookie) @@ -334,8 +332,7 @@ def test_end_session_endpoint_with_post_logout_redirect_uri(self): self._code_auth("1234567") self._code_auth2("abcdefg") _sdb = self.session_endpoint.endpoint_context.sdb - _sid = self._get_sid() - cookie = self._create_cookie("diana", _sid, "1234567", "client_1") + cookie = self._create_cookie(db_key("diana", "client_1"), "1234567", "client_1") post_logout_redirect_uri = join_query( *self.session_endpoint.endpoint_context.cdb["client_1"][ @@ -359,8 +356,7 @@ def test_end_session_endpoint_with_wrong_post_logout_redirect_uri(self): id_token = self._auth_with_id_token("1234567") _sdb = self.session_endpoint.endpoint_context.sdb - _sid = self._get_sid() - cookie = self._create_cookie("diana", _sid, "1234567", "client_1") + cookie = self._create_cookie(db_key("diana", "client_1"), "1234567", "client_1") post_logout_redirect_uri = "https://demo.example.com/log_out" @@ -453,8 +449,8 @@ def test_logout_from_client_bc(self): "backchannel_logout_uri" ] = "https://example.com/bc_logout" self.session_endpoint.endpoint_context.cdb["client_1"]["client_id"] = "client_1" - _sid = self._get_sid() - res = self.session_endpoint.logout_from_client(_sid, "client_1") + _sid = db_key() + res = self.session_endpoint.logout_from_client(db_key(), "client_1") assert set(res.keys()) == {"blu"} assert set(res["blu"].keys()) == {"client_1"} _spec = res["blu"]["client_1"] @@ -475,7 +471,7 @@ def test_logout_from_client_fc(self): "frontchannel_logout_uri" ] = "https://example.com/fc_logout" self.session_endpoint.endpoint_context.cdb["client_1"]["client_id"] = "client_1" - _sid = self._get_sid() + _sid = db_key("diana", "client_1") res = self.session_endpoint.logout_from_client(_sid, "client_1") assert set(res.keys()) == {"flu"} assert set(res["flu"].keys()) == {"client_1"} @@ -499,7 +495,7 @@ def test_logout_from_client(self): ] = "https://example.com/fc_logout" self.session_endpoint.endpoint_context.cdb["client_2"]["client_id"] = "client_2" - _sid = self._get_sid() + _sid = db_key("diana", "client_1") res = self.session_endpoint.logout_all_clients(_sid, "client_1") assert res @@ -527,7 +523,7 @@ def test_do_verified_logout(self): _cdb["client_1"]["backchannel_logout_uri"] = "https://example.com/bc_logout" _cdb["client_1"]["client_id"] = "client_1" - _sid = self._get_sid() + _sid = db_key("diana", "client_1") res = self.session_endpoint.do_verified_logout(_sid, "client_1") assert res == [] @@ -565,9 +561,9 @@ def test_logout_from_client_no_session(self): ] = "https://example.com/fc_logout" self.session_endpoint.endpoint_context.cdb["client_2"]["client_id"] = "client_2" - _sid = self._get_sid() + _sid = db_key("diana", "client_1") - self.session_endpoint.endpoint_context.sdb.sso_db.delete("diana", "sid") + self.session_endpoint.endpoint_context.session_manager.delete("diana", "client_1") res = self.session_endpoint.logout_all_clients(_sid, "client_1") assert res == {} diff --git a/tests/test_30_oidc_end_session.py.no b/tests/test_30_oidc_end_session.py.no new file mode 100644 index 0000000..088e591 --- /dev/null +++ b/tests/test_30_oidc_end_session.py.no @@ -0,0 +1,573 @@ +import copy +import json +import os +from urllib.parse import parse_qs +from urllib.parse import urlparse + +import pytest +import responses +from cryptojwt.key_jar import build_keyjar +from oidcmsg.exception import InvalidRequest +from oidcmsg.message import Message +from oidcmsg.oidc import AuthorizationRequest +from oidcmsg.oidc import verified_claim_name +from oidcmsg.oidc import verify_id_token + +from oidcendpoint.common.authorization import join_query +from oidcendpoint.cookie import CookieDealer +from oidcendpoint.cookie import new_cookie +from oidcendpoint.endpoint_context import EndpointContext +from oidcendpoint.exception import RedirectURIError +from oidcendpoint.oidc import userinfo +from oidcendpoint.oidc.authorization import Authorization +from oidcendpoint.oidc.provider_config import ProviderConfiguration +from oidcendpoint.oidc.registration import Registration +from oidcendpoint.oidc.session import Session +from oidcendpoint.oidc.session import do_front_channel_logout_iframe +from oidcendpoint.oidc.token import AccessToken +from oidcendpoint.token_handler import UnknownToken +from oidcendpoint.user_authn.authn_context import INTERNETPROTOCOLPASSWORD +from oidcendpoint.user_info import UserInfo + +ISS = "https://example.com/" + +CLI1 = "https://client1.example.com/" +CLI2 = "https://client2.example.com/" + +KEYDEFS = [ + {"type": "RSA", "use": ["sig"]}, + {"type": "EC", "crv": "P-256", "use": ["sig"]}, +] + +KEYJAR = build_keyjar(KEYDEFS) +KEYJAR.import_jwks(KEYJAR.export_jwks(private=True), ISS) + +RESPONSE_TYPES_SUPPORTED = [ + ["code"], + ["token"], + ["id_token"], + ["code", "token"], + ["code", "id_token"], + ["id_token", "token"], + ["code", "token", "id_token"], + ["none"], +] + +CAPABILITIES = { + "response_types_supported": [" ".join(x) for x in RESPONSE_TYPES_SUPPORTED], + "token_endpoint_auth_methods_supported": [ + "client_secret_post", + "client_secret_basic", + "client_secret_jwt", + "private_key_jwt", + ], + "response_modes_supported": ["query", "fragment", "form_post"], + "subject_types_supported": ["public", "pairwise"], + "grant_types_supported": [ + "authorization_code", + "implicit", + "urn:ietf:params:oauth:grant-type:jwt-bearer", + "refresh_token", + ], + "claim_types_supported": ["normal", "aggregated", "distributed"], + "claims_parameter_supported": True, + "request_parameter_supported": True, + "request_uri_parameter_supported": True, +} + +AUTH_REQ = AuthorizationRequest( + client_id="client_1", + redirect_uri="{}cb".format(ISS), + scope=["openid"], + state="STATE", + response_type="code", + client_secret="hemligt", +) + +AUTH_REQ_DICT = AUTH_REQ.to_dict() + +BASEDIR = os.path.abspath(os.path.dirname(__file__)) + + +def full_path(local_file): + return os.path.join(BASEDIR, local_file) + + +USERINFO_db = json.loads(open(full_path("users.json")).read()) + + +class TestEndpoint(object): + @pytest.fixture(autouse=True) + def create_endpoint(self): + conf = { + "issuer": ISS, + "password": "mycket hemlig zebra", + "token_expires_in": 600, + "grant_expires_in": 300, + "refresh_token_expires_in": 86400, + "verify_ssl": False, + "capabilities": CAPABILITIES, + "keys": {"uri_path": "jwks.json", "key_defs": KEYDEFS}, + "endpoint": { + "provider_config": { + "path": "{}/.well-known/openid-configuration", + "class": ProviderConfiguration, + "kwargs": {"client_authn_method": None}, + }, + "registration": { + "path": "{}/registration", + "class": Registration, + "kwargs": {"client_authn_method": None}, + }, + "authorization": { + "path": "{}/authorization", + "class": Authorization, + "kwargs": {"client_authn_method": None}, + }, + "token": {"path": "{}/token", "class": AccessToken, "kwargs": {}}, + "userinfo": { + "path": "{}/userinfo", + "class": userinfo.UserInfo, + "kwargs": {"db_file": "users.json"}, + }, + "session": { + "path": "{}/end_session", + "class": Session, + "kwargs": { + "post_logout_uri_path": "post_logout", + "signing_alg": "ES256", + "logout_verify_url": "{}/verify_logout".format(ISS), + "client_authn_method": None, + }, + }, + }, + "authentication": { + "anon": { + "acr": INTERNETPROTOCOLPASSWORD, + "class": "oidcendpoint.user_authn.user.NoAuthn", + "kwargs": {"user": "diana"}, + } + }, + "userinfo": {"class": UserInfo, "kwargs": {"db": USERINFO_db}}, + "template_dir": "template", + # 'cookie_name':{ + # 'session': 'oidcop', + # 'register': 'oidcreg' + # } + } + cookie_conf = { + "sign_key": "ghsNKDDLshZTPn974nOsIGhedULrsqnsGoBFBLwUKuJhE2ch", + "default_values": { + "name": "oidcop", + "domain": "127.0.0.1", + "path": "/", + "max_age": 3600, + }, + } + + self.cd = CookieDealer(**cookie_conf) + endpoint_context = EndpointContext(conf, cookie_dealer=self.cd, keyjar=KEYJAR) + endpoint_context.cdb = { + "client_1": { + "client_secret": "hemligt", + "redirect_uris": [("{}cb".format(CLI1), None)], + "client_salt": "salted", + "token_endpoint_auth_method": "client_secret_post", + "response_types": ["code", "token", "code id_token", "id_token"], + "post_logout_redirect_uris": [("{}logout_cb".format(CLI1), "")], + }, + "client_2": { + "client_secret": "hemligare", + "redirect_uris": [("{}cb".format(CLI2), None)], + "client_salt": "saltare", + "token_endpoint_auth_method": "client_secret_post", + "response_types": ["code", "token", "code id_token", "id_token"], + "post_logout_redirect_uris": [("{}logout_cb".format(CLI2), "")], + }, + } + self.authn_endpoint = endpoint_context.endpoint["authorization"] + self.session_endpoint = endpoint_context.endpoint["session"] + self.token_endpoint = endpoint_context.endpoint["token"] + + def test_end_session_endpoint(self): + # End session not allowed if no cookie and no id_token_hint is sent + # (can't determine session) + with pytest.raises(UnknownToken): + _ = self.session_endpoint.process_request("", cookie="FAIL") + + def _create_cookie(self, user, sid, state, client_id): + ec = self.session_endpoint.endpoint_context + return new_cookie( + ec, + sub=user, + sid=sid, + state=state, + client_id=client_id, + cookie_name=ec.cookie_name["session"], + ) + + def _code_auth(self, state): + req = AuthorizationRequest( + state=state, + response_type="code", + redirect_uri="{}cb".format(CLI1), + scope=["openid"], + client_id="client_1", + ) + _pr_resp = self.authn_endpoint.parse_request(req.to_dict()) + _resp = self.authn_endpoint.process_request(_pr_resp) + + def _code_auth2(self, state): + req = AuthorizationRequest( + state=state, + response_type="code", + redirect_uri="{}cb".format(CLI2), + scope=["openid"], + client_id="client_2", + ) + _pr_resp = self.authn_endpoint.parse_request(req.to_dict()) + _resp = self.authn_endpoint.process_request(_pr_resp) + + def _get_sid(self): + _sdb = self.session_endpoint.endpoint_context.sdb + + for _sid in _sdb.keys(): + if _sid.startswith("__state__"): + continue + else: + return _sid + + def _auth_with_id_token(self, state): + req = AuthorizationRequest( + state=state, + response_type="id_token", + redirect_uri="{}cb".format(CLI1), + scope=["openid"], + client_id="client_1", + nonce="_nonce_", + ) + _pr_resp = self.authn_endpoint.parse_request(req.to_dict()) + _resp = self.authn_endpoint.process_request(_pr_resp) + + return _resp["response_args"]["id_token"] + + def test_end_session_endpoint_with_cookie(self): + self._code_auth("1234567") + _sid = self._get_sid() + cookie = self._create_cookie("diana", _sid, "1234567", "client_1") + + _req_args = self.session_endpoint.parse_request({"state": "1234567"}) + resp = self.session_endpoint.process_request(_req_args, cookie=cookie) + + # returns a signed JWT to be put in a verification web page shown to + # the user + + p = urlparse(resp["redirect_location"]) + qs = parse_qs(p.query) + jwt_info = self.session_endpoint.unpack_signed_jwt(qs["sjwt"][0]) + + assert jwt_info["user"] == "diana" + assert jwt_info["client_id"] == "client_1" + assert jwt_info["redirect_uri"] == "https://example.com/post_logout" + + def test_end_session_endpoint_with_wrong_cookie(self): + self._code_auth("1234567") + cookie = self._create_cookie("diana", "client_2", "abcdefg", "client_1") + + with pytest.raises(UnknownToken): + self.session_endpoint.process_request({"state": "abcde"}, cookie=cookie) + + def test_end_session_endpoint_with_cookie_wrong_user(self): + # Need cookie and ID Token to figure this out + id_token = self._auth_with_id_token("1234567") + + cookie = self._create_cookie("diggins", "_sid_", "1234567", "client_1") + + msg = Message(id_token=id_token) + verify_id_token(msg, keyjar=self.session_endpoint.endpoint_context.keyjar) + + msg2 = Message(id_token_hint=id_token) + msg2[verified_claim_name("id_token_hint")] = msg[ + verified_claim_name("id_token") + ] + with pytest.raises(ValueError): + self.session_endpoint.process_request(msg2, cookie=cookie) + + def test_end_session_endpoint_with_cookie_unknown_sid(self): + # Need cookie and ID Token to figure this out + id_token = self._auth_with_id_token("1234567") + + # Wrong client_id + cookie = self._create_cookie("diana", "_sid_", "state", "client_1") + + msg = Message(id_token=id_token) + verify_id_token(msg, keyjar=self.session_endpoint.endpoint_context.keyjar) + + msg2 = Message(id_token_hint=id_token) + msg2[verified_claim_name("id_token_hint")] = msg[ + verified_claim_name("id_token") + ] + with pytest.raises(ValueError): + self.session_endpoint.process_request(msg2, cookie=cookie) + + def test_end_session_endpoint_with_cookie_dual_login(self): + self._code_auth("1234567") + self._code_auth2("abcdefg") + _sdb = self.session_endpoint.endpoint_context.sdb + _sid = self._get_sid() + cookie = self._create_cookie("diana", _sid, "1234567", "client_1") + + resp = self.session_endpoint.process_request({"state": "abcde"}, cookie=cookie) + + # returns a signed JWT to be put in a verification web page shown to + # the user + + p = urlparse(resp["redirect_location"]) + qs = parse_qs(p.query) + jwt_info = self.session_endpoint.unpack_signed_jwt(qs["sjwt"][0]) + + assert jwt_info["user"] == "diana" + assert jwt_info["client_id"] == "client_1" + assert jwt_info["redirect_uri"] == "https://example.com/post_logout" + + def test_end_session_endpoint_with_post_logout_redirect_uri(self): + self._code_auth("1234567") + self._code_auth2("abcdefg") + _sdb = self.session_endpoint.endpoint_context.sdb + _sid = self._get_sid() + cookie = self._create_cookie("diana", _sid, "1234567", "client_1") + + post_logout_redirect_uri = join_query( + *self.session_endpoint.endpoint_context.cdb["client_1"][ + "post_logout_redirect_uris" + ][0] + ) + + with pytest.raises(InvalidRequest): + self.session_endpoint.process_request( + { + "post_logout_redirect_uri": post_logout_redirect_uri, + "state": "abcde", + }, + cookie=cookie, + ) + + def test_end_session_endpoint_with_wrong_post_logout_redirect_uri(self): + self._code_auth("1234567") + self._code_auth2("abcdefg") + + id_token = self._auth_with_id_token("1234567") + + _sdb = self.session_endpoint.endpoint_context.sdb + _sid = self._get_sid() + cookie = self._create_cookie("diana", _sid, "1234567", "client_1") + + post_logout_redirect_uri = "https://demo.example.com/log_out" + + msg = Message(id_token=id_token) + verify_id_token(msg, keyjar=self.session_endpoint.endpoint_context.keyjar) + + with pytest.raises(RedirectURIError): + self.session_endpoint.process_request( + { + "post_logout_redirect_uri": post_logout_redirect_uri, + "state": "abcde", + "id_token_hint": id_token, + verified_claim_name("id_token_hint"): msg[ + verified_claim_name("id_token") + ], + }, + cookie=cookie, + ) + + def test_back_channel_logout_no_uri(self): + self._code_auth("1234567") + + res = self.session_endpoint.do_back_channel_logout( + self.session_endpoint.endpoint_context.cdb["client_1"], "username", 0 + ) + assert res is None + + def test_back_channel_logout(self): + self._code_auth("1234567") + + _cdb = copy.copy(self.session_endpoint.endpoint_context.cdb["client_1"]) + _cdb["backchannel_logout_uri"] = "https://example.com/bc_logout" + _cdb["client_id"] = "client_1" + res = self.session_endpoint.do_back_channel_logout(_cdb, "username", "_sid_") + assert isinstance(res, tuple) + assert res[0] == "https://example.com/bc_logout" + _jwt = self.session_endpoint.unpack_signed_jwt(res[1], "RS256") + assert _jwt + assert _jwt["iss"] == ISS + assert _jwt["aud"] == ["client_1"] + assert _jwt["sub"] == "username" + assert _jwt["sid"] == "_sid_" + + def test_front_channel_logout(self): + self._code_auth("1234567") + + _cdb = copy.copy(self.session_endpoint.endpoint_context.cdb["client_1"]) + _cdb["frontchannel_logout_uri"] = "https://example.com/fc_logout" + _cdb["client_id"] = "client_1" + res = do_front_channel_logout_iframe(_cdb, ISS, "_sid_") + assert res == '