Skip to content
This repository was archived by the owner on Jun 12, 2021. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
265 changes: 166 additions & 99 deletions src/oidcendpoint/oidc/token_coop.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from oidcmsg.oidc import AccessTokenResponse
from oidcmsg.oidc import RefreshAccessTokenRequest
from oidcmsg.oidc import TokenErrorResponse
from oidcmsg.storage import importer

from oidcendpoint import sanitize
from oidcendpoint.cookie import new_cookie
Expand All @@ -22,74 +23,95 @@
logger = logging.getLogger(__name__)


class TokenCoop(Endpoint):
request_cls = oidc.Message
response_cls = oidc.AccessTokenResponse
error_cls = TokenErrorResponse
request_format = "json"
request_placement = "body"
response_format = "json"
response_placement = "body"
endpoint_name = "token_endpoint"
name = "token"
default_capabilities = {"token_endpoint_auth_signing_alg_values_supported": None}
class TokenEndpointHelper:
def __init__(self, endpoint, config=None):
self.endpoint = endpoint
self.config = config

def __init__(self, endpoint_context, new_refresh_token=False, **kwargs):
Endpoint.__init__(self, endpoint_context, **kwargs)
self.post_parse_request.append(self._post_parse_request)
if "client_authn_method" in kwargs:
self.endpoint_info["token_endpoint_auth_methods_supported"] = kwargs[
"client_authn_method"
]
self.allow_refresh = False
self.new_refresh_token = new_refresh_token
def post_parse_request(self, request, client_id="", **kwargs):
"""Context specific parsing of the request.
This is done after general request parsing and before processing
the request.
"""
raise NotImplementedError

def _refresh_access_token(self, req, **kwargs):
_sdb = self.endpoint_context.sdb
def process_request(self, req, **kwargs):
"""Acts on a process request."""
raise NotImplementedError


class AccessToken(TokenEndpointHelper):
def post_parse_request(self, request, client_id="", **kwargs):
"""
This is where clients come to get their access tokens

:param request: The request
:param authn: Authentication info, comes from HTTP header
:returns:
"""

request = AccessTokenRequest(**request.to_dict())

error_cls = self._validate_state(request)
if error_cls is not None:
return error_cls

if "client_id" not in request: # Optional for access token request
request["client_id"] = client_id

logger.debug("%s: %s" % (request.__class__.__name__, sanitize(request)))

return request

def _validate_state(self, request):
if "state" not in request:
return

rtoken = req["refresh_token"]
try:
_info = _sdb.refresh_token(rtoken, new_refresh=self.new_refresh_token)
except ExpiredToken:
return self.error_cls(
error="invalid_request", error_description="Refresh token is expired"
)
sinfo = self.endpoint.endpoint_context.sdb[request["code"]]
except KeyError:
logger.error("Code not present in SessionDB")
return self.endpoint.error_cls(error="unauthorized_client")
else:
state = sinfo["authn_req"]["state"]

return by_schema(AccessTokenResponse, **_info)
if state != request["state"]:
logger.error("State value mismatch")
return self.endpoint.error_cls(error="unauthorized_client")

def _access_token(self, req, **kwargs):
_context = self.endpoint_context
def process_request(self, req, **kwargs):
_context = self.endpoint.endpoint_context
_sdb = _context.sdb
_log_debug = logger.debug

try:
_access_code = req["code"].replace(" ", "+")
except KeyError: # Missing code parameter - absolutely fatal
return self.error_cls(
return self.endpoint.error_cls(
error="invalid_request", error_description="Missing code"
)

# Session might not exist or _access_code malformed
try:
_info = _sdb[_access_code]
except KeyError:
return self.error_cls(
return self.endpoint.error_cls(
error="invalid_grant", error_description="Code is invalid"
)

_authn_req = _info["authn_req"]

# assert that the code is valid
if _context.sdb.is_session_revoked(_access_code):
return self.error_cls(
return self.endpoint.error_cls(
error="invalid_grant", error_description="Session is revoked"
)

# If redirect_uri was in the initial authorization request
# verify that the one given here is the correct one.
if "redirect_uri" in _authn_req:
if req["redirect_uri"] != _authn_req["redirect_uri"]:
return self.error_cls(
return self.endpoint.error_cls(
error="invalid_request", error_description="redirect_uri mismatch"
)

Expand All @@ -111,7 +133,7 @@ def _access_token(self, req, **kwargs):
logger.error("%s" % err)
# Should revoke the token issued to this access code
_sdb.revoke_all_tokens(_access_code)
return self.error_cls(
return self.endpoint.error_cls(
error="access_denied", error_description="Access Code already used"
)

Expand All @@ -131,82 +153,134 @@ def _access_token(self, req, **kwargs):

return by_schema(AccessTokenResponse, **_info)

def get_client_id_from_token(self, endpoint_context, token, request=None):
sinfo = endpoint_context.sdb[token]
return sinfo["authn_req"]["client_id"]

def _access_token_post_parse_request(self, request, client_id="", **kwargs):
class RefreshToken(TokenEndpointHelper):
def post_parse_request(self, request, client_id="", **kwargs):
"""
This is where clients come to get their access tokens
This is where clients come to refresh their access tokens

:param request: The request
:param authn: Authentication info, comes from HTTP header
:returns:
"""
request = RefreshAccessTokenRequest(**request.to_dict())

request = AccessTokenRequest(**request.to_dict())

if "state" in request:
try:
sinfo = self.endpoint_context.sdb[request["code"]]
except KeyError:
logger.error("Code not present in SessionDB")
return self.error_cls(error="unauthorized_client")
else:
state = sinfo["authn_req"]["state"]

if state != request["state"]:
logger.error("State value mismatch")
return self.error_cls(error="unauthorized_client")
# verify that the request message is correct
try:
request.verify(
keyjar=self.endpoint.endpoint_context.keyjar, opponent_id=client_id
)
except (MissingRequiredAttribute, ValueError, MissingRequiredValue) as err:
return self.endpoint.error_cls(
error="invalid_request", error_description="%s" % err
)

if "client_id" not in request: # Optional for access token request
if "client_id" not in request: # Optional for refresh access token request
request["client_id"] = client_id

logger.debug("%s: %s" % (request.__class__.__name__, sanitize(request)))

return request

def _refresh_token_post_parse_request(self, request, client_id="", **kwargs):
"""
This is where clients come to refresh their access tokens
def process_request(self, req, **kwargs):
_sdb = self.endpoint.endpoint_context.sdb

:param request: The request
:param authn: Authentication info, comes from HTTP header
:returns:
"""
rtoken = req["refresh_token"]
try:
_info = _sdb.refresh_token(
rtoken, new_refresh=self.endpoint.new_refresh_token
)
except ExpiredToken:
return self.error_cls(
error="invalid_request", error_description="Refresh token is expired"
)

request = RefreshAccessTokenRequest(**request.to_dict())
return by_schema(AccessTokenResponse, **_info)

# verify that the request message is correct
try:
request.verify(keyjar=self.endpoint_context.keyjar)
except (MissingRequiredAttribute, ValueError, MissingRequiredValue) as err:
return self.error_cls(error="invalid_request", error_description="%s" % err)

try:
keyjar = self.endpoint_context.keyjar
except AttributeError:
keyjar = ""
HELPER_BY_GRANT_TYPE = {
"authorization_code": AccessToken,
"refresh_token": RefreshToken,
}

request.verify(keyjar=keyjar, opponent_id=client_id)

if "client_id" not in request: # Optional for refresh access token request
request["client_id"] = client_id
class TokenCoop(Endpoint):
request_cls = oidc.Message
response_cls = oidc.AccessTokenResponse
error_cls = TokenErrorResponse
request_format = "json"
request_placement = "body"
response_format = "json"
response_placement = "body"
endpoint_name = "token_endpoint"
name = "token"
default_capabilities = {"token_endpoint_auth_signing_alg_values_supported": None}

logger.debug("%s: %s" % (request.__class__.__name__, sanitize(request)))
def __init__(self, endpoint_context, new_refresh_token=False, **kwargs):
Endpoint.__init__(self, endpoint_context, **kwargs)
self.post_parse_request.append(self._post_parse_request)
if "client_authn_method" in kwargs:
self.endpoint_info["token_endpoint_auth_methods_supported"] = kwargs[
"client_authn_method"
]

return request
# self.allow_refresh = False
self.new_refresh_token = new_refresh_token

def _post_parse_request(self, request, client_id="", **kwargs):
if request["grant_type"] == "authorization_code":
return self._access_token_post_parse_request(request, client_id, **kwargs)
else: # request["grant_type"] == "refresh_token":
if self.allow_refresh:
return self._refresh_token_post_parse_request(
request, client_id, **kwargs
self.configure_grant_types(kwargs.get("grant_types_supported"))

def configure_grant_types(self, grant_types_supported):
if grant_types_supported is None:
self.helper = {k: v(self) for k, v in HELPER_BY_GRANT_TYPE.items()}
return

self.helper = {}
# TODO: do we want to allow any grant_type?
for grant_type, grant_type_options in grant_types_supported.items():
if (
grant_type_options in (None, True)
and grant_type in HELPER_BY_GRANT_TYPE
):
self.helper[grant_type] = HELPER_BY_GRANT_TYPE[grant_type]
continue
elif grant_type_options is False:
continue

try:
grant_class = grant_type_options["class"]
except (KeyError, TypeError):
raise ProcessError(
"Token Endpoint's grant types must be True, None or a dict with a"
" 'class' key."
)
else:
raise ProcessError("Refresh Token not allowed")
_conf = grant_type_options.get("kwargs", {})

if isinstance(grant_class, str):
try:
grant_class = importer(grant_class)
except (ValueError, AttributeError):
raise ProcessError(
f"Token Endpoint's grant type class {grant_class} can't"
" be imported."
)
try:
self.helper[grant_type] = grant_class(self, _conf)
except Exception as e:
raise ProcessError(f"Failed to initialize class {grant_class}: {e}")

def get_client_id_from_token(self, endpoint_context, token, request=None):
sinfo = endpoint_context.sdb[token]
return sinfo["authn_req"]["client_id"]

def _post_parse_request(self, request, client_id="", **kwargs):
_helper = self.helper.get(request["grant_type"])
if _helper:
return _helper.post_parse_request(request, client_id, **kwargs)
else:
return self.error_cls(
error="invalid_request",
error_description=f"Unsupported grant_type: {request['grant_type']}"
)

def process_request(self, request=None, **kwargs):
"""
Expand All @@ -218,27 +292,20 @@ def process_request(self, request=None, **kwargs):
if isinstance(request, self.error_cls):
return request
try:
if request["grant_type"] == "authorization_code":
logger.debug("Access Token Request")
response_args = self._access_token(request, **kwargs)
elif request["grant_type"] == "refresh_token":
logger.debug("Refresh Access Token Request")
response_args = self._refresh_access_token(request, **kwargs)
_helper = self.helper.get(request["grant_type"])
if _helper:
response_args = _helper.process_request(request, **kwargs)
else:
return self.error_cls(
error="invalid_request", error_description="Wrong grant_type"
error="invalid_request",
error_description=f"Unsupported grant_type: {request['grant_type']}"
)
except JWEException as err:
return self.error_cls(error="invalid_request", error_description="%s" % err)

if isinstance(response_args, ResponseMessage):
return response_args

if request["grant_type"] == "authorization_code":
_token = request["code"].replace(" ", "+")
else:
_token = request["refresh_token"].replace(" ", "+")

_access_token = response_args["access_token"]
_cookie = new_cookie(
self.endpoint_context,
Expand Down
Loading