Skip to content
This repository was archived by the owner on Jun 23, 2023. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions src/oidcop/token/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,13 @@

logger = logging.getLogger(__name__)

ALT_TOKEN_NAME = {
"authorization_code": "A",
"access_token": "T",
"refresh_token": "R",
"id_token": "I"
}


def is_expired(exp, when=0):
if exp < 0:
Expand All @@ -28,6 +35,11 @@ def is_expired(exp, when=0):
class Token(object):
def __init__(self, token_class, lifetime=300, **kwargs):
self.token_class = token_class
try:
self.alt_token_name = ALT_TOKEN_NAME[token_class]
except KeyError:
self.alt_token_name = ""

self.lifetime = lifetime
self.kwargs = kwargs

Expand Down Expand Up @@ -70,7 +82,8 @@ def __init__(self, password, token_class="", token_type="Bearer", **kwargs):
self.crypt = Crypt(password)
self.token_type = token_type

def __call__(self, session_id: Optional[str] = "", token_class: Optional[str] = "", **payload) -> str:
def __call__(self, session_id: Optional[str] = "", token_class: Optional[str] = "",
**payload) -> str:
"""
Return a token.

Expand Down Expand Up @@ -112,9 +125,10 @@ def info(self, token: str) -> dict:
:return: dictionary with info about the token
"""
_res = dict(zip(["_id", "token_class", "sid", "exp"], self.split_token(token)))
if _res["token_class"] != self.token_class:
if _res["token_class"] not in [self.token_class, self.alt_token_name]:
raise WrongTokenClass(_res["token_class"])
else:
_res["token_class"] = self.token_class
_res["handler"] = self
return _res

Expand Down
26 changes: 18 additions & 8 deletions src/oidcop/token/jwt_token.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@

from oidcop.exception import ToOld
from oidcop.token import Crypt
from oidcop.token.exception import WrongTokenClass

from . import Token
from . import is_expired
from .exception import UnknownToken


# TYPE_MAP = {"A": "code", "T": "access_token", "R": "refresh_token"}


Expand Down Expand Up @@ -58,10 +59,11 @@ def __call__(self, session_id: Optional[str] = "", token_class: Optional[str] =
:param payload: A dictionary with information that is part of the payload of the JWT.
:return: Signed JSON Web Token
"""
if not token_class and self.token_class:
token_class = self.token_class
else:
token_class = "authorization_code"
if not token_class:
if self.token_class:
token_class = self.token_class
else:
token_class = "authorization_code"

payload.update({"sid": session_id, "token_class": token_class})
payload = self.load_custom_claims(payload)
Expand All @@ -86,14 +88,22 @@ def get_payload(self, token):

def info(self, token):
"""
Return type of Token (A=Access code, T=Token, R=Refresh token) and
the session id.
Return token information

:param token: A token
:return: tuple of token type and session id
:return: dictionary with token information
"""
_payload = self.get_payload(token)

_class = _payload.get("ttype")
if _class is None:
_class = _payload.get("token_class")

if _class not in [self.token_class, self.alt_token_name]:
raise WrongTokenClass(_payload["token_class"])
else:
_payload["token_class"] = self.token_class

if is_expired(_payload["exp"]):
raise ToOld("Token has expired")
# All the token metadata
Expand Down
106 changes: 101 additions & 5 deletions tests/test_35_oidc_token_endpoint.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import base64
import json
import os

from oidcop.configure import OPConfiguration
import pytest
from cryptojwt import JWT
from cryptojwt.key_jar import build_keyjar
Expand All @@ -15,6 +15,7 @@
from oidcop.authn_event import create_authn_event
from oidcop.authz import AuthzHandling
from oidcop.client_authn import verify_client
from oidcop.configure import OPConfiguration
from oidcop.cookie_handler import CookieHandler
from oidcop.exception import UnAuthorizedClient
from oidcop.oidc import userinfo
Expand All @@ -26,6 +27,7 @@
from oidcop.session import MintingNotAllowed
from oidcop.user_authn.authn_context import INTERNETPROTOCOLPASSWORD
from oidcop.user_info import UserInfo
from oidcop.util import lv_pack

KEYDEFS = [
{"type": "RSA", "key": "", "use": ["sig"]},
Expand Down Expand Up @@ -113,7 +115,7 @@ def conf():
},
"refresh": {
"class": "oidcop.token.jwt_token.JWTToken",
"kwargs": {"lifetime": 3600, "aud": ["https://example.org/appl"],},
"kwargs": {"lifetime": 3600, "aud": ["https://example.org/appl"], },
},
"id_token": {"class": "oidcop.token.id_token.IDToken", "kwargs": {}},
},
Expand All @@ -127,8 +129,8 @@ def conf():
"class": ProviderConfiguration,
"kwargs": {},
},
"registration": {"path": "registration", "class": Registration, "kwargs": {},},
"authorization": {"path": "authorization", "class": Authorization, "kwargs": {},},
"registration": {"path": "registration", "class": Registration, "kwargs": {}, },
"authorization": {"path": "authorization", "class": Authorization, "kwargs": {}, },
"token": {
"path": "token",
"class": Token,
Expand Down Expand Up @@ -164,7 +166,7 @@ def conf():
"usage_rules": {
"authorization_code": {
"expires_in": 300,
"supports_minting": ["access_token", "refresh_token", "id_token",],
"supports_minting": ["access_token", "refresh_token", "id_token", ],
"max_usage": 1,
},
"access_token": {"expires_in": 600},
Expand Down Expand Up @@ -741,3 +743,97 @@ def test_configure_grant_types(self):
assert len(self.token_endpoint.helper) == 1
assert "access_token" in self.token_endpoint.helper
assert "refresh_token" not in self.token_endpoint.helper


class TestOldTokens(object):
@pytest.fixture(autouse=True)
def create_endpoint(self, conf):
server = Server(OPConfiguration(conf=conf, base_path=BASEDIR), cwd=BASEDIR)

endpoint_context = server.endpoint_context
endpoint_context.cdb["client_1"] = {
"client_secret": "hemligt",
"redirect_uris": [("https://example.com/cb", None)],
"client_salt": "salted",
"endpoint_auth_method": "client_secret_post",
"response_types": ["code", "token", "code id_token", "id_token"],
}
endpoint_context.keyjar.import_jwks(CLIENT_KEYJAR.export_jwks(), "client_1")
self.session_manager = endpoint_context.session_manager
self.token_endpoint = server.server_get("endpoint", "token")
self.user_id = "diana"
self.endpoint_context = endpoint_context

def _create_session(self, auth_req, sub_type="public", sector_identifier=""):
if sector_identifier:
authz_req = auth_req.copy()
authz_req["sector_identifier_uri"] = sector_identifier
else:
authz_req = auth_req
client_id = authz_req["client_id"]
ae = create_authn_event(self.user_id)
return self.session_manager.create_session(
ae, authz_req, self.user_id, client_id=client_id, sub_type=sub_type
)

def _mint_code(self, grant, client_id):
session_id = self.session_manager.encrypted_session_id(self.user_id, client_id, grant.id)
usage_rules = grant.usage_rules.get("authorization_code", {})
_exp_in = usage_rules.get("expires_in")

# Constructing an authorization code is now done
_code = grant.mint_token(
session_id=session_id,
endpoint_context=self.endpoint_context,
token_class="authorization_code",
token_handler=self.session_manager.token_handler["authorization_code"],
usage_rules=usage_rules,
)

if _exp_in:
if isinstance(_exp_in, str):
_exp_in = int(_exp_in)
if _exp_in:
_code.expires_at = utc_time_sans_frac() + _exp_in
return _code

def test_old_default_token(self):
session_id = self._create_session(AUTH_REQ)
grant = self.session_manager[session_id]
code = self._mint_code(grant, AUTH_REQ["client_id"])

# pack and unpack
_handler = self.session_manager.token_handler.handler["authorization_code"]
_res = dict(zip(["_id", "token_class", "sid", "exp"], _handler.split_token(code.value)))

_old_type_value = base64.b64encode(
_handler.crypt.encrypt(lv_pack(_res["_id"], "A", _res["sid"], _res["exp"]).encode())
).decode("utf-8")

_info = self.session_manager.token_handler.info(_old_type_value)
assert _info["token_class"] == "authorization_code"

def test_old_jwt_token(self):
session_id = self._create_session(AUTH_REQ)
grant = self.session_manager[session_id]
code = self._mint_code(grant, AUTH_REQ["client_id"])

_handler = self.session_manager.token_handler.handler["access_token"]
_old_type_token = _handler(session_id=session_id, token_class="T")

_info = self.session_manager.token_handler.info(_old_type_token)
assert _info["token_class"] == "access_token"

payload = {"sid": session_id, "ttype": "T"}
payload = _handler.load_custom_claims(payload)

# payload.update(kwargs)
_context = _handler.server_get("endpoint_context")
signer = JWT(
key_jar=_context.keyjar, iss=_handler.issuer, lifetime=300, sign_alg=_handler.alg,
)

_old_type_token = signer.pack(payload)

_info = self.session_manager.token_handler.info(_old_type_token)
assert _info["token_class"] == "access_token"