Skip to content

Commit

Permalink
fix: only auto add kid of key set for "encode" action
Browse files Browse the repository at this point in the history
  • Loading branch information
lepture committed Aug 6, 2023
1 parent ab789cb commit 752fae7
Show file tree
Hide file tree
Showing 7 changed files with 27 additions and 18 deletions.
3 changes: 3 additions & 0 deletions src/joserfc/_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
"RSAKey",
"ECKey",
"OKPKey",
"Key",
"KeySet",
"JWKRegistry",
]

Key = t.Union[OctKey, RSAKey, ECKey, OKPKey]
Expand Down
20 changes: 11 additions & 9 deletions src/joserfc/jwe.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,11 +99,11 @@ def encrypt_compact(

obj = CompactEncryption(protected, to_bytes(plaintext))
recipient: Recipient[Key] = Recipient(obj)
key = guess_key(public_key, recipient)
key = guess_key(public_key, recipient, True)
key.check_use("enc")
recipient.recipient_key = key
if sender_key:
recipient.sender_key = _guess_sender_key(recipient, sender_key)
recipient.sender_key = _guess_sender_key(recipient, sender_key, True)
obj.recipient = recipient
perform_encrypt(obj, registry)
out = represent_compact(obj)
Expand Down Expand Up @@ -214,10 +214,10 @@ def encrypt_json(

for recipient in obj.recipients:
if sender_key and not recipient.sender_key:
recipient.sender_key = _guess_sender_key(recipient, sender_key)
recipient.sender_key = _guess_sender_key(recipient, sender_key, True)
if not recipient.recipient_key:
assert public_key is not None
key = guess_key(public_key, recipient)
key = guess_key(public_key, recipient, True)
key.check_use("enc")
recipient.recipient_key = key

Expand Down Expand Up @@ -274,15 +274,17 @@ def _attach_recipient_keys(

def _guess_sender_key(
recipient: Recipient[Key],
key: t.Union[ECKey, OKPKey, KeySet]) -> t.Union[ECKey, OKPKey]:
key: t.Union[ECKey, OKPKey, KeySet],
use_random: bool = False) -> t.Union[ECKey, OKPKey]:
if isinstance(key, KeySet):
headers = recipient.headers()
skid = headers.get('skid')
if skid:
return key.get_by_kid(skid) # type: ignore[return-value]
skey = key.pick_random_key(headers["alg"])
if skey is not None:
recipient.add_header("skid", skey.kid)
return skey # type: ignore[return-value]
if use_random:
skey = key.pick_random_key(headers["alg"])
if skey is not None:
recipient.add_header("skid", skey.kid)
return skey # type: ignore[return-value]
raise ValueError("Invalid key")
return key
5 changes: 3 additions & 2 deletions src/joserfc/jwk.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,12 @@ def set_kid(self, kid: str):
KeyFlexible = t.Union[str, bytes, Key, KeySet, KeyCallable]


def guess_key(key: KeyFlexible, obj: GuestProtocol) -> Key:
def guess_key(key: KeyFlexible, obj: GuestProtocol, use_random: bool = False) -> Key:
"""Guess key from a various sources.
:param key: a very flexible key
:param obj: a protocol that has ``headers`` and ``set_kid`` methods
:param use_random: pick a random key from key set
"""
headers = obj.headers()

Expand All @@ -57,7 +58,7 @@ def guess_key(key: KeyFlexible, obj: GuestProtocol) -> Key:

elif isinstance(key, KeySet):
kid = headers.get("kid")
if not kid:
if not kid and use_random:
# choose one key by random
rv_key: Key = key.pick_random_key(headers["alg"]) # type: ignore
if rv_key is None:
Expand Down
4 changes: 2 additions & 2 deletions src/joserfc/jws.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def serialize_compact(
registry.check_header(protected)
obj = CompactSignature(protected, to_bytes(payload))
alg: JWSAlgModel = registry.get_alg(protected["alg"])
key: Key = guess_key(private_key, obj)
key: Key = guess_key(private_key, obj, True)
key.check_use("sig")
alg.check_key_type(key)
key.check_alg(protected["alg"])
Expand Down Expand Up @@ -227,7 +227,7 @@ def serialize_json(
if registry is None:
registry = construct_registry(algorithms)

find_key = lambda d: guess_key(private_key, d)
find_key = lambda d: guess_key(private_key, d, True)
_payload = to_bytes(payload)
if isinstance(members, list):
return sign_general_json(members, _payload, registry, find_key)
Expand Down
2 changes: 1 addition & 1 deletion src/joserfc/rfc7797/compact.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def serialize_compact(
registry.check_header(protected)
obj = CompactSignature(protected, to_bytes(payload))
alg = registry.get_alg(protected["alg"])
key = guess_key(private_key, obj)
key = guess_key(private_key, obj, True)
key.check_use("sig")

header_segment = json_b64encode(protected)
Expand Down
2 changes: 1 addition & 1 deletion src/joserfc/rfc7797/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def serialize_json(

registry.check_header(headers)

key = guess_key(private_key, _member)
key = guess_key(private_key, _member, True)
key.check_use("sig")
alg = registry.get_alg(headers["alg"])

Expand Down
9 changes: 6 additions & 3 deletions tests/jwk/test_key_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,17 +42,20 @@ def test_guess_key_set(self):
key_set = KeySet([OctKey.generate_key(), RSAKey.generate_key()])
guest = Guest()
guest._headers["alg"] = "HS256"
key = guess_key(key_set, guest)
self.assertRaises(ValueError, guess_key, key_set, guest)
key = guess_key(key_set, guest, True)
self.assertIsInstance(key, OctKey)
key = guess_key(key_set, guest)

guest = Guest()
guest._headers["alg"] = "RS256"
key = guess_key(key_set, guest)
self.assertRaises(ValueError, guess_key, key_set, guest)
key = guess_key(key_set, guest, True)
self.assertIsInstance(key, RSAKey)

guest = Guest()
guest._headers["alg"] = "ES256"
self.assertRaises(ValueError, guess_key, key_set, guest)
self.assertRaises(ValueError, guess_key, key_set, guest, True)

def test_invalid_key(self):
self.assertRaises(ValueError, guess_key, {}, Guest())
Expand Down

0 comments on commit 752fae7

Please sign in to comment.