Skip to content

Commit

Permalink
Refactored the sdb module.
Browse files Browse the repository at this point in the history
  • Loading branch information
Roland Hedberg committed Nov 24, 2014
1 parent c6b8b8e commit 2f22b2f
Show file tree
Hide file tree
Showing 5 changed files with 155 additions and 77 deletions.
4 changes: 2 additions & 2 deletions oidc_example/op2/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,8 +526,8 @@ def application(environ, start_response):
else:
kwargs["verify_ssl"] = True

OAS = Provider(config.issuer, SessionDB(), cdb, ac, None, authz,
verify_client, config.SYM_KEY, **kwargs)
OAS = Provider(config.issuer, SessionDB(config.baseurl), cdb, ac, None,
authz, verify_client, config.SYM_KEY, **kwargs)

for authn in ac:
authn.srv = OAS
Expand Down
11 changes: 5 additions & 6 deletions src/oic/oauth2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,13 @@
import requests
import random
import string
import copy
import cookielib
import logging
from Cookie import SimpleCookie

from oic.utils.keyio import KeyJar
from oic.utils.time_util import utc_time_sans_frac
from oic.utils.time_util import utc_now
from oic.exception import UnSupported
import logging

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -668,8 +666,9 @@ def construct_ResourceRequest(self, request=ResourceRequest,
request_args["access_token"] = token.access_token
return self.construct_request(request, request_args, extra_args)

def get_or_post(self, uri, method, req,
content_type=DEFAULT_POST_CONTENT_TYPE, accept=None, **kwargs):
@staticmethod
def get_or_post(uri, method, req, content_type=DEFAULT_POST_CONTENT_TYPE,
accept=None, **kwargs):
if method == "GET":
_qp = req.to_urlencoded()
if _qp:
Expand All @@ -688,7 +687,7 @@ def get_or_post(self, uri, method, req,
"Unsupported content type: '%s'" % content_type)

header_ext = {"Content-type": content_type}
if (accept):
if accept:
header_ext = {"Accept": accept}

if "headers" in kwargs.keys():
Expand Down
5 changes: 3 additions & 2 deletions src/oic/oauth2/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import sys
import urllib
import urlparse
from oic.utils.sdb import AccessCodeUsed
from oic.utils.sdb import AccessCodeUsed, AuthnEvent

__author__ = 'rohe0002'

Expand Down Expand Up @@ -483,6 +483,7 @@ def authorization_endpoint(self, request="", cookie="", authn="", **kwargs):
return _authn(**authn_args)
else:
user = identity["uid"]
aevent = AuthnEvent(user, authn_info=acr)

# If I get this far the person is already authenticated
logger.debug("- authenticated -")
Expand All @@ -493,7 +494,7 @@ def authorization_endpoint(self, request="", cookie="", authn="", **kwargs):
except KeyError:
oidc_req = None

skey = self.sdb.create_authz_session(user, areq, oidreq=oidc_req)
skey = self.sdb.create_authz_session(aevent, areq, oidreq=oidc_req)

# Now about the authorization step.
try:
Expand Down
86 changes: 60 additions & 26 deletions src/oic/oic/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
import urllib
import sys
from jwkest.jwe import JWE
from oic.utils import time_util
from oic.utils.authn.user import NoSuchAuthentication
from oic.utils.authn.user import ToOld
from oic.utils.authn.user import TamperAllert
from oic.utils.sdb import AuthnEvent
from oic.utils.time_util import utc_time_sans_frac
from oic.utils.keyio import KeyBundle, dump_jwks
from oic.utils.keyio import key_export
Expand Down Expand Up @@ -370,6 +370,11 @@ def required_user(areq):
req_user = areq["id_token"]["sub"]
except KeyError:
pass
else:
try:
assert areq["client_id"] in areq["id_token"]["aud"]
except AssertionError:
req_user = "" # Not allow to use

return req_user

Expand Down Expand Up @@ -539,6 +544,23 @@ def verify_endpoint(self, request="", cookie=None, **kwargs):
kwargs["cookie"] = cookie
return authn.verify(_req, **kwargs)

def setup_session(self, areq, authn_event, cinfo):
try:
oidc_req = areq["request"]
except KeyError:
oidc_req = None

sid = self.sdb.create_authz_session(authn_event, areq, oidreq=oidc_req)
kwargs = {}
for param in ["sector_id", "preferred_id_type"]:
try:
kwargs[param] = cinfo[param]
except KeyError:
pass

self.sdb.do_sub(sid, **kwargs)
return sid

def authorization_endpoint(self, request="", cookie=None, **kwargs):
""" The AuthorizationRequest endpoint
Expand Down Expand Up @@ -587,15 +609,27 @@ def authorization_endpoint(self, request="", cookie=None, **kwargs):
return areq
logger.debug("AuthzRequest+oidc_request: %s" % (areq.to_dict(),))

cinfo = self.cdb[areq["client_id"]]
req_user = self.required_user(areq)
if req_user:
try:
sids = self.sdb.sub2sid[req_user]
except KeyError:
pass
else:
# anyone will do
authn_event = self.sdb[sids[0]]["authn_event"]
# Is the authentication event to be regarded as valid ?
if authn_event.valid():
sid = self.setup_session(areq, authn_event, cinfo)
return self.authz_part2(authn_event.uid, areq, sid)

authn, authn_class_ref = self.pick_auth(areq)
if not authn:
authn, authn_class_ref = self.pick_auth(areq, "better")
if not authn:
authn, authn_class_ref = self.pick_auth(areq, "any")

logger.debug("Cookie: %s" % cookie)
try:
try:
_auth_info = kwargs["authn"]
Expand All @@ -604,7 +638,6 @@ def authorization_endpoint(self, request="", cookie=None, **kwargs):
identity = authn.authenticated_as(cookie,
authorization=_auth_info,
max_age=self.max_age(areq))
auth_time = time_util.utc_now()
except (NoSuchAuthentication, ToOld, TamperAllert):
identity = None

Expand All @@ -613,7 +646,6 @@ def authorization_endpoint(self, request="", cookie=None, **kwargs):
"as_user": req_user,
"authn_class_ref": authn_class_ref}

cinfo = self.cdb[areq["client_id"]]
for attr in ["policy_uri", "logo_uri"]:
try:
authn_args[attr] = cinfo[attr]
Expand Down Expand Up @@ -644,16 +676,12 @@ def authorization_endpoint(self, request="", cookie=None, **kwargs):
else:
return authn(**authn_args)

authn_event = AuthnEvent(identity["uid"], authn_info=authn_class_ref)

logger.debug("- authenticated -")
logger.debug("AREQ keys: %s" % areq.keys())

try:
oidc_req = areq["request"]
except KeyError:
oidc_req = None

sid = self.sdb.create_authz_session(user, areq, oidreq=oidc_req,
auth_time=auth_time)
sid = self.setup_session(areq, authn_event, cinfo)
return self.authz_part2(user, areq, sid)

def userinfo_in_id_token_claims(self, session):
Expand Down Expand Up @@ -773,7 +801,7 @@ def _access_token_endpoint(self, req, **kwargs):
try:
_idtoken = self.sign_encrypt_id_token(
_info, client_info, req, user_info=userinfo,
auth_time=_info["auth_time"])
auth_time=_info["authn_event"].authn_time)
except (JWEException, NoSuitableSigningKeys) as err:
logger.warning(str(err))
return self._error(error="access_denied",
Expand Down Expand Up @@ -801,9 +829,9 @@ def _refresh_access_token_endpoint(self, req, **kwargs):

if "openid" in _info["scope"]:
userinfo = self.userinfo_in_id_token_claims(_info)
_idtoken = self.sign_encrypt_id_token(_info, client_info, req,
user_info=userinfo,
auth_time=_info["auth_time"])
_idtoken = self.sign_encrypt_id_token(
_info, client_info, req, user_info=userinfo,
auth_time=_info["authn_event"].authn_time)
sid = _sdb.token.get_key(rtoken)
_sdb.update(sid, "id_token", _idtoken)

Expand Down Expand Up @@ -890,7 +918,7 @@ def _collect_user_info(self, session, userinfo_claims=None):
logger.debug("userinfo_claim: %s" % userinfo_claims.to_dict())

logger.debug("Session info: %s" % session)
info = self.userinfo(session["local_sub"], userinfo_claims)
info = self.userinfo(session["authn_event"].uid, userinfo_claims)

info["sub"] = session["sub"]
logger.debug("user_info_response: %s" % (info,))
Expand All @@ -911,11 +939,15 @@ def signed_userinfo(self, client_info, userinfo, session):
except KeyError: # Fall back to default
algo = self.jwx_def["sign_alg"]["userinfo"]

# Use my key for signing
key = self.keyjar.get_signing_key(alg2keytype(algo), "", alg=algo)
if not key:
return self._error(error="access_denied",
descr="Missing signing key")
if algo == "none":
key = []
else:
# Use my key for signing
key = self.keyjar.get_signing_key(alg2keytype(algo), "", alg=algo)
if not key:
return self._error(error="access_denied",
descr="Missing signing key")

jinfo = userinfo.to_jwt(key, algo)
if "userinfo_encrypted_response_alg" in client_info:
# encrypt with clients public key
Expand Down Expand Up @@ -1172,16 +1204,18 @@ def do_client_registration(self, request, client_id, ignore=None):
"invalid_configuration_parameter",
descr="%s pointed to illegal URL" % item)

# necessary keys ?
# Do I have the necessary keys
for item in ["id_token_signed_response_alg",
"userinfo_signed_response_alg"]:
if item in request:
if request[item] in self.capabilities[PREFERENCE2PROVIDER[item]]:
ktyp = jws.alg2keytype(request[item])
# do I have this ktyp and for EC type keys the curve
_k = self.keyjar.get_signing_key(ktyp, alg=request[item])
if not _k:
del _cinfo[item]
if ktyp not in ["none", "OCT"]:
_k = self.keyjar.get_signing_key(ktyp,
alg=request[item])
if not _k:
del _cinfo[item]

try:
self.keyjar.load_keys(request, client_id)
Expand Down Expand Up @@ -1616,7 +1650,7 @@ def authz_part2(self, user, areq, sid, **kwargs):
# or 'code id_token'
id_token = self.sign_encrypt_id_token(
_sinfo, client_info, areq, user_info=user_info,
auth_time=_sinfo["auth_time"], **hargs)
auth_time=_sinfo["authn_event"].authn_time, **hargs)

aresp["id_token"] = id_token
_sinfo["id_token"] = id_token
Expand Down
Loading

0 comments on commit 2f22b2f

Please sign in to comment.