diff --git a/src/oidcop/token/__init__.py b/src/oidcop/token/__init__.py index b3309afa..a9bcd791 100755 --- a/src/oidcop/token/__init__.py +++ b/src/oidcop/token/__init__.py @@ -15,6 +15,13 @@ logger = logging.getLogger(__name__) +ALT_TOKEN_NAME = { + "authorization_code": "A", + "access_token": "T", + "refresh_token": "R", + "id_token": "I" +} + def is_expired(exp, when=0): if exp < 0: @@ -28,6 +35,11 @@ def is_expired(exp, when=0): class Token(object): def __init__(self, token_class, lifetime=300, **kwargs): self.token_class = token_class + try: + self.alt_token_name = ALT_TOKEN_NAME[token_class] + except KeyError: + self.alt_token_name = "" + self.lifetime = lifetime self.kwargs = kwargs @@ -70,7 +82,8 @@ def __init__(self, password, token_class="", token_type="Bearer", **kwargs): self.crypt = Crypt(password) self.token_type = token_type - def __call__(self, session_id: Optional[str] = "", token_class: Optional[str] = "", **payload) -> str: + def __call__(self, session_id: Optional[str] = "", token_class: Optional[str] = "", + **payload) -> str: """ Return a token. @@ -112,9 +125,10 @@ def info(self, token: str) -> dict: :return: dictionary with info about the token """ _res = dict(zip(["_id", "token_class", "sid", "exp"], self.split_token(token))) - if _res["token_class"] != self.token_class: + if _res["token_class"] not in [self.token_class, self.alt_token_name]: raise WrongTokenClass(_res["token_class"]) else: + _res["token_class"] = self.token_class _res["handler"] = self return _res diff --git a/src/oidcop/token/jwt_token.py b/src/oidcop/token/jwt_token.py index d002e416..1329f64d 100644 --- a/src/oidcop/token/jwt_token.py +++ b/src/oidcop/token/jwt_token.py @@ -6,11 +6,12 @@ from oidcop.exception import ToOld from oidcop.token import Crypt +from oidcop.token.exception import WrongTokenClass + from . import Token from . import is_expired from .exception import UnknownToken - # TYPE_MAP = {"A": "code", "T": "access_token", "R": "refresh_token"} @@ -58,10 +59,11 @@ def __call__(self, session_id: Optional[str] = "", token_class: Optional[str] = :param payload: A dictionary with information that is part of the payload of the JWT. :return: Signed JSON Web Token """ - if not token_class and self.token_class: - token_class = self.token_class - else: - token_class = "authorization_code" + if not token_class: + if self.token_class: + token_class = self.token_class + else: + token_class = "authorization_code" payload.update({"sid": session_id, "token_class": token_class}) payload = self.load_custom_claims(payload) @@ -86,14 +88,22 @@ def get_payload(self, token): def info(self, token): """ - Return type of Token (A=Access code, T=Token, R=Refresh token) and - the session id. + Return token information :param token: A token - :return: tuple of token type and session id + :return: dictionary with token information """ _payload = self.get_payload(token) + _class = _payload.get("ttype") + if _class is None: + _class = _payload.get("token_class") + + if _class not in [self.token_class, self.alt_token_name]: + raise WrongTokenClass(_payload["token_class"]) + else: + _payload["token_class"] = self.token_class + if is_expired(_payload["exp"]): raise ToOld("Token has expired") # All the token metadata diff --git a/tests/test_35_oidc_token_endpoint.py b/tests/test_35_oidc_token_endpoint.py index ba4c1dc4..054602cd 100755 --- a/tests/test_35_oidc_token_endpoint.py +++ b/tests/test_35_oidc_token_endpoint.py @@ -1,7 +1,7 @@ +import base64 import json import os -from oidcop.configure import OPConfiguration import pytest from cryptojwt import JWT from cryptojwt.key_jar import build_keyjar @@ -15,6 +15,7 @@ from oidcop.authn_event import create_authn_event from oidcop.authz import AuthzHandling from oidcop.client_authn import verify_client +from oidcop.configure import OPConfiguration from oidcop.cookie_handler import CookieHandler from oidcop.exception import UnAuthorizedClient from oidcop.oidc import userinfo @@ -26,6 +27,7 @@ from oidcop.session import MintingNotAllowed from oidcop.user_authn.authn_context import INTERNETPROTOCOLPASSWORD from oidcop.user_info import UserInfo +from oidcop.util import lv_pack KEYDEFS = [ {"type": "RSA", "key": "", "use": ["sig"]}, @@ -113,7 +115,7 @@ def conf(): }, "refresh": { "class": "oidcop.token.jwt_token.JWTToken", - "kwargs": {"lifetime": 3600, "aud": ["https://example.org/appl"],}, + "kwargs": {"lifetime": 3600, "aud": ["https://example.org/appl"], }, }, "id_token": {"class": "oidcop.token.id_token.IDToken", "kwargs": {}}, }, @@ -127,8 +129,8 @@ def conf(): "class": ProviderConfiguration, "kwargs": {}, }, - "registration": {"path": "registration", "class": Registration, "kwargs": {},}, - "authorization": {"path": "authorization", "class": Authorization, "kwargs": {},}, + "registration": {"path": "registration", "class": Registration, "kwargs": {}, }, + "authorization": {"path": "authorization", "class": Authorization, "kwargs": {}, }, "token": { "path": "token", "class": Token, @@ -164,7 +166,7 @@ def conf(): "usage_rules": { "authorization_code": { "expires_in": 300, - "supports_minting": ["access_token", "refresh_token", "id_token",], + "supports_minting": ["access_token", "refresh_token", "id_token", ], "max_usage": 1, }, "access_token": {"expires_in": 600}, @@ -741,3 +743,97 @@ def test_configure_grant_types(self): assert len(self.token_endpoint.helper) == 1 assert "access_token" in self.token_endpoint.helper assert "refresh_token" not in self.token_endpoint.helper + + +class TestOldTokens(object): + @pytest.fixture(autouse=True) + def create_endpoint(self, conf): + server = Server(OPConfiguration(conf=conf, base_path=BASEDIR), cwd=BASEDIR) + + endpoint_context = server.endpoint_context + endpoint_context.cdb["client_1"] = { + "client_secret": "hemligt", + "redirect_uris": [("https://example.com/cb", None)], + "client_salt": "salted", + "endpoint_auth_method": "client_secret_post", + "response_types": ["code", "token", "code id_token", "id_token"], + } + endpoint_context.keyjar.import_jwks(CLIENT_KEYJAR.export_jwks(), "client_1") + self.session_manager = endpoint_context.session_manager + self.token_endpoint = server.server_get("endpoint", "token") + self.user_id = "diana" + self.endpoint_context = endpoint_context + + def _create_session(self, auth_req, sub_type="public", sector_identifier=""): + if sector_identifier: + authz_req = auth_req.copy() + authz_req["sector_identifier_uri"] = sector_identifier + else: + authz_req = auth_req + client_id = authz_req["client_id"] + ae = create_authn_event(self.user_id) + return self.session_manager.create_session( + ae, authz_req, self.user_id, client_id=client_id, sub_type=sub_type + ) + + def _mint_code(self, grant, client_id): + session_id = self.session_manager.encrypted_session_id(self.user_id, client_id, grant.id) + usage_rules = grant.usage_rules.get("authorization_code", {}) + _exp_in = usage_rules.get("expires_in") + + # Constructing an authorization code is now done + _code = grant.mint_token( + session_id=session_id, + endpoint_context=self.endpoint_context, + token_class="authorization_code", + token_handler=self.session_manager.token_handler["authorization_code"], + usage_rules=usage_rules, + ) + + if _exp_in: + if isinstance(_exp_in, str): + _exp_in = int(_exp_in) + if _exp_in: + _code.expires_at = utc_time_sans_frac() + _exp_in + return _code + + def test_old_default_token(self): + session_id = self._create_session(AUTH_REQ) + grant = self.session_manager[session_id] + code = self._mint_code(grant, AUTH_REQ["client_id"]) + + # pack and unpack + _handler = self.session_manager.token_handler.handler["authorization_code"] + _res = dict(zip(["_id", "token_class", "sid", "exp"], _handler.split_token(code.value))) + + _old_type_value = base64.b64encode( + _handler.crypt.encrypt(lv_pack(_res["_id"], "A", _res["sid"], _res["exp"]).encode()) + ).decode("utf-8") + + _info = self.session_manager.token_handler.info(_old_type_value) + assert _info["token_class"] == "authorization_code" + + def test_old_jwt_token(self): + session_id = self._create_session(AUTH_REQ) + grant = self.session_manager[session_id] + code = self._mint_code(grant, AUTH_REQ["client_id"]) + + _handler = self.session_manager.token_handler.handler["access_token"] + _old_type_token = _handler(session_id=session_id, token_class="T") + + _info = self.session_manager.token_handler.info(_old_type_token) + assert _info["token_class"] == "access_token" + + payload = {"sid": session_id, "ttype": "T"} + payload = _handler.load_custom_claims(payload) + + # payload.update(kwargs) + _context = _handler.server_get("endpoint_context") + signer = JWT( + key_jar=_context.keyjar, iss=_handler.issuer, lifetime=300, sign_alg=_handler.alg, + ) + + _old_type_token = signer.pack(payload) + + _info = self.session_manager.token_handler.info(_old_type_token) + assert _info["token_class"] == "access_token"