Skip to content
This repository was archived by the owner on Jun 23, 2023. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/oidcop/oidc/userinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from oidcmsg import oidc
from oidcmsg.message import Message
from oidcmsg.oauth2 import ResponseMessage
from oidcop.session.claims import claims_match

from oidcop.endpoint import Endpoint
from oidcop.token.exception import UnknownToken
Expand Down Expand Up @@ -140,6 +141,8 @@ def process_request(self, request=None, **kwargs):
user_id=_session_info["user_id"], claims_restriction=_claims
)
info["sub"] = _grant.sub
if _grant.add_acr_value("userinfo"):
info["acr"] = _grant.authentication_event["authn_info"]
else:
info = {
"error": "invalid_request",
Expand Down
13 changes: 13 additions & 0 deletions src/oidcop/session/grant.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from oidcop.authn_event import AuthnEvent
from oidcop.session import MintingNotAllowed
from oidcop.session.claims import claims_match
from oidcop.session.token import AccessToken
from oidcop.session.token import AuthorizationCode
from oidcop.session.token import IDToken
Expand Down Expand Up @@ -180,6 +181,14 @@ def find_scope(self, based_on):

return self.scope

def add_acr_value(self, claims_release_point):
_release = self.claims.get(claims_release_point)
if _release:
_acr_request = _release.get("acr")
_used_acr = self.authentication_event.get("authn_info")
return claims_match(_used_acr, _acr_request)
return False

def payload_arguments(
self,
session_id: str,
Expand Down Expand Up @@ -221,6 +230,10 @@ def payload_arguments(
user_info = endpoint_context.claims_interface.get_user_claims(user_id, _claims_restriction)
payload.update(user_info)

# Should I add the acr value
if self.add_acr_value(claims_release_point):
payload["acr"] = self.authentication_event["authn_info"]

return payload

def mint_token(
Expand Down
95 changes: 49 additions & 46 deletions src/oidcop/session/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def __init__(self, salt: Optional[str] = "", filename: Optional[str] = ""):
if os.path.isfile(filename):
self.salt = open(filename).read()
elif not os.path.isfile(filename) and os.path.exists(
filename
filename
): # Not a file, Something else
raise ConfigurationError("Salt filename points to something that is not a file")
else:
Expand Down Expand Up @@ -73,7 +73,8 @@ class SessionManager(Database):
init_args = ["handler"]

def __init__(
self, handler: TokenHandler, conf: Optional[dict] = None, sub_func: Optional[dict] = None,
self, handler: TokenHandler, conf: Optional[dict] = None,
sub_func: Optional[dict] = None,
):
self.conf = conf or {}

Expand Down Expand Up @@ -125,9 +126,9 @@ def __setattr__(self, key, value):

def _init_db(self):
Database.__init__(
self,
key=self.load_key(),
salt=self.load_salt()
self,
key=self.load_key(),
salt=self.load_salt()
)

def get_user_info(self, uid: str) -> UserSessionInfo:
Expand All @@ -153,14 +154,14 @@ def find_token(self, session_id: str, token_value: str) -> Optional[SessionToken
return None # pragma: no cover

def create_grant(
self,
authn_event: AuthnEvent,
auth_req: AuthorizationRequest,
user_id: str,
client_id: Optional[str] = "",
sub_type: Optional[str] = "public",
token_usage_rules: Optional[dict] = None,
scopes: Optional[list] = None,
self,
authn_event: AuthnEvent,
auth_req: AuthorizationRequest,
user_id: str,
client_id: Optional[str] = "",
sub_type: Optional[str] = "public",
token_usage_rules: Optional[dict] = None,
scopes: Optional[list] = None,
) -> str:
"""

Expand All @@ -175,29 +176,31 @@ def create_grant(
"""
sector_identifier = auth_req.get("sector_identifier_uri", "")

_claims = auth_req.get("claims", {})

grant = Grant(
authorization_request=auth_req,
authentication_event=authn_event,
sub=self.sub_func[sub_type](
user_id, salt=self.salt, sector_identifier=sector_identifier
),
sub=self.sub_func[sub_type](user_id, salt=self.salt,
sector_identifier=sector_identifier),
usage_rules=token_usage_rules,
scope=scopes,
claims=_claims
)

self.set([user_id, client_id, grant.id], grant)

return self.encrypted_session_id(user_id, client_id, grant.id)

def create_session(
self,
authn_event: AuthnEvent,
auth_req: AuthorizationRequest,
user_id: str,
client_id: Optional[str] = "",
sub_type: Optional[str] = "public",
token_usage_rules: Optional[dict] = None,
scopes: Optional[list] = None,
self,
authn_event: AuthnEvent,
auth_req: AuthorizationRequest,
user_id: str,
client_id: Optional[str] = "",
sub_type: Optional[str] = "public",
token_usage_rules: Optional[dict] = None,
scopes: Optional[list] = None,
) -> str:
"""
Create part of a user session. The parts added are user- and client
Expand Down Expand Up @@ -309,10 +312,10 @@ def revoke_token(self, session_id: str, token_value: str, recursive: bool = Fals
self._revoke_dependent(grant, token)

def get_authentication_events(
self,
session_id: Optional[str] = "",
user_id: Optional[str] = "",
client_id: Optional[str] = "",
self,
session_id: Optional[str] = "",
user_id: Optional[str] = "",
client_id: Optional[str] = "",
) -> List[AuthnEvent]:
"""
Return the authentication events that exists for a user/client combination.
Expand Down Expand Up @@ -371,10 +374,10 @@ def revoke_grant(self, session_id: str):
self.set(_path, _info)

def grants(
self,
session_id: Optional[str] = "",
user_id: Optional[str] = "",
client_id: Optional[str] = "",
self,
session_id: Optional[str] = "",
user_id: Optional[str] = "",
client_id: Optional[str] = "",
) -> List[Grant]:
"""
Find all grant connected to a user session
Expand All @@ -395,13 +398,13 @@ def grants(
return [self.get([user_id, client_id, gid]) for gid in _csi.subordinate]

def get_session_info(
self,
session_id: str,
user_session_info: bool = False,
client_session_info: bool = False,
grant: bool = False,
authentication_event: bool = False,
authorization_request: bool = False,
self,
session_id: str,
user_session_info: bool = False,
client_session_info: bool = False,
grant: bool = False,
authentication_event: bool = False,
authorization_request: bool = False,
) -> dict:
"""
Returns information connected to a session.
Expand Down Expand Up @@ -449,13 +452,13 @@ def get_session_info(
return res

def get_session_info_by_token(
self,
token_value: str,
user_session_info: bool = False,
client_session_info: bool = False,
grant: bool = False,
authentication_event: bool = False,
authorization_request: bool = False,
self,
token_value: str,
user_session_info: bool = False,
client_session_info: bool = False,
grant: bool = False,
authentication_event: bool = False,
authorization_request: bool = False,
) -> dict:
_token_info = self.token_handler.info(token_value)
sid = _token_info.get("sid")
Expand Down
26 changes: 24 additions & 2 deletions tests/test_05_id_token.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,11 @@ def full_path(local_file):
"acr": INTERNETPROTOCOLPASSWORD,
"class": "oidcop.user_authn.user.NoAuthn",
"kwargs": {"user": "diana"},
},
"mfa": {
"acr": 'https://refeds.org/profile/mfa',
"class": "oidcop.user_authn.user.NoAuthn",
"kwargs": {"user": "diana"},
}
},
"session_manager": {
Expand Down Expand Up @@ -170,15 +175,15 @@ def create_session_manager(self):
self.session_manager = self.endpoint_context.session_manager
self.user_id = USER_ID

def _create_session(self, auth_req, sub_type="public", sector_identifier=""):
def _create_session(self, auth_req, sub_type="public", sector_identifier="", authn_info=''):
if sector_identifier:
authz_req = auth_req.copy()
authz_req["sector_identifier_uri"] = sector_identifier
else:
authz_req = auth_req

client_id = authz_req["client_id"]
ae = create_authn_event(self.user_id)
ae = create_authn_event(self.user_id, authn_info=authn_info)
return self.session_manager.create_session(
ae, authz_req, self.user_id, client_id=client_id, sub_type=sub_type
)
Expand Down Expand Up @@ -587,3 +592,20 @@ def test_id_token_info(self):
get_sign_and_encrypt_algorithms(
endpoint_context, client_info, payload_type="id_token", sign=True, encrypt=True
)

def test_id_token_acr_claim(self):
_req = AREQS.copy()
_req["claims"] = {"id_token": {"acr": {"value": "https://refeds.org/profile/mfa"}}}

session_id = self._create_session(_req,authn_info="https://refeds.org/profile/mfa")
grant = self.session_manager[session_id]
code = self._mint_code(grant, session_id)
access_token = self._mint_access_token(grant, session_id, code)

id_token = self._mint_id_token(
grant, session_id, token_ref=code, access_token=access_token.value
)

_jwt = factory(id_token.value)
_id_token_content = _jwt.jwt.payload()
assert _id_token_content["acr"] == "https://refeds.org/profile/mfa"
Loading