Skip to content

Commit 8553afb

Browse files
committed
feat: add class method JWSRegistry.guess_alg
ref #49
1 parent c84aa98 commit 8553afb

File tree

6 files changed

+101
-10
lines changed

6 files changed

+101
-10
lines changed

src/joserfc/_rfc7515/model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ class JWSAlgModel(object, metaclass=ABCMeta):
112112
key_type = "oct"
113113
algorithm_type: Literal["JWS"] = "JWS"
114114
algorithm_location = "sig"
115+
algorithm_security = 0
115116

116117
def check_key(self, key: Any) -> None:
117118
key.check_use("sig")

src/joserfc/_rfc7515/registry.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
from __future__ import annotations
22
import warnings
3+
from typing import Any
4+
from enum import Enum
35
from .model import JWSAlgModel
4-
from ..errors import UnsupportedAlgorithmError, SecurityWarning
6+
from ..errors import UnsupportedAlgorithmError, SecurityWarning, JoseError
57
from ..registry import (
68
JWS_HEADER_REGISTRY,
79
Header,
@@ -28,6 +30,10 @@ class JWSRegistry:
2830
:param strict_check_header: only allow header key in the registry to be used
2931
"""
3032

33+
class Strategy(Enum):
34+
RECOMMENDED = 1
35+
SECURITY = 2
36+
3137
default_header_registry: HeaderRegistryDict = JWS_HEADER_REGISTRY
3238
algorithms: dict[str, JWSAlgModel] = {}
3339
recommended: list[str] = []
@@ -79,6 +85,45 @@ def check_header(self, header: Header) -> None:
7985
if self.strict_check_header:
8086
check_supported_header(self.header_registry, header)
8187

88+
@classmethod
89+
def guess_alg(cls, key: Any, strategy: Strategy) -> str | None:
90+
"""Guess the JWS algorithm for a given key.
91+
92+
:param key: key instance
93+
:param strategy: the strategy for guessing the JWS algorithm
94+
"""
95+
if strategy == cls.Strategy.RECOMMENDED:
96+
algorithms = cls.filter_algorithms(key, cls.recommended)
97+
elif strategy == cls.Strategy.SECURITY:
98+
names = list(cls.algorithms.keys())
99+
algorithms = cls.filter_algorithms(key, names)
100+
# sort by security level
101+
algorithms.sort(key=lambda alg: alg.algorithm_security, reverse=True)
102+
else:
103+
raise NotImplementedError(f"Unknown algorithm strategy '{strategy}'")
104+
105+
if algorithms:
106+
return algorithms[0].name
107+
else:
108+
return None
109+
110+
@classmethod
111+
def filter_algorithms(cls, key: Any, names: list[str]) -> list[JWSAlgModel]:
112+
"""Filter JWS algorithms based on the given algorithm names.
113+
114+
:param key: key instance
115+
:param names: list of algorithm names
116+
"""
117+
rv: list[JWSAlgModel] = []
118+
for name in names:
119+
alg = cls.algorithms[name]
120+
try:
121+
alg.check_key(key)
122+
rv.append(alg)
123+
except JoseError:
124+
pass
125+
return rv
126+
82127

83128
#: default JWS registry
84129
default_registry = JWSRegistry()

src/joserfc/_rfc7518/jws_algs.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ def __init__(self, sha_type: t.Literal[256, 384, 512], recommended: bool = False
5757
self.description = f"HMAC using SHA-{sha_type}"
5858
self.recommended = recommended
5959
self.hash_alg = getattr(self, f"SHA{sha_type}")
60+
self.algorithm_security = sha_type
6061

6162
def sign(self, msg: bytes, key: OctKey) -> bytes:
6263
# it is faster than the one in cryptography
@@ -89,6 +90,7 @@ def __init__(self, sha_type: t.Literal[256, 384, 512], recommended: bool = False
8990
self.description = f"RSASSA-PKCS1-v1_5 using SHA-{sha_type}"
9091
self.recommended = recommended
9192
self.hash_alg = getattr(self, f"SHA{sha_type}")
93+
self.algorithm_security = sha_type
9294

9395
def sign(self, msg: bytes, key: RSAKey) -> bytes:
9496
op_key = key.get_op_key("sign")
@@ -103,7 +105,7 @@ def verify(self, msg: bytes, sig: bytes, key: RSAKey) -> bool:
103105
return False
104106

105107

106-
class ECAlgorithm(JWSAlgModel):
108+
class ESAlgorithm(JWSAlgModel):
107109
"""ECDSA using SHA algorithms for JWS. Available algorithms:
108110
109111
- ES256: ECDSA using P-256 and SHA-256
@@ -123,6 +125,7 @@ def __init__(self, name: str, curve: str, sha_type: t.Literal[256, 384, 512], re
123125
self.description = f"ECDSA using {self.curve} and SHA-{sha_type}"
124126
self.recommended = recommended
125127
self.hash_alg = getattr(self, f"SHA{sha_type}")
128+
self.algorithm_security = sha_type
126129

127130
def check_key(self, key: ECKey) -> None:
128131
super().check_key(key)
@@ -174,6 +177,7 @@ def __init__(self, sha_type: t.Literal[256, 384, 512]):
174177
self.description = f"RSASSA-PSS using SHA-{sha_type} and MGF1 with SHA-{sha_type}"
175178
self.hash_alg = getattr(self, f"SHA{sha_type}")
176179
self.padding = padding.PSS(mgf=padding.MGF1(self.hash_alg()), salt_length=self.hash_alg.digest_size)
180+
self.algorithm_security = sha_type
177181

178182
def sign(self, msg: bytes, key: RSAKey) -> bytes:
179183
op_key = key.get_op_key("sign")
@@ -196,9 +200,9 @@ def verify(self, msg: bytes, sig: bytes, key: RSAKey) -> bool:
196200
RSAAlgorithm(256, True), # RS256
197201
RSAAlgorithm(384), # RS384
198202
RSAAlgorithm(512), # RS512
199-
ECAlgorithm("ES256", "P-256", 256, True),
200-
ECAlgorithm("ES384", "P-384", 384),
201-
ECAlgorithm("ES512", "P-521", 512),
203+
ESAlgorithm("ES256", "P-256", 256, True),
204+
ESAlgorithm("ES384", "P-384", 384),
205+
ESAlgorithm("ES512", "P-521", 512),
202206
RSAPSSAlgorithm(256), # PS256
203207
RSAPSSAlgorithm(384), # PS384
204208
RSAPSSAlgorithm(512), # PS512
@@ -208,5 +212,5 @@ def verify(self, msg: bytes, sig: bytes, key: RSAKey) -> bool:
208212
NoneAlgModel = NoneAlgorithm
209213
HMACAlgModel = HMACAlgorithm
210214
RSAAlgModel = RSAAlgorithm
211-
ECAlgModel = ECAlgorithm
215+
ECAlgModel = ESAlgorithm
212216
RSAPSSAlgModel = RSAPSSAlgorithm

src/joserfc/_rfc8812/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
from cryptography.hazmat.primitives.asymmetric.ec import SECP256K1
22
from .._rfc7518.ec_key import ECKey
3-
from .._rfc7518.jws_algs import ECAlgorithm
3+
from .._rfc7518.jws_algs import ESAlgorithm
44

5-
ES256K = ECAlgorithm("ES256K", "secp256k1", 256)
5+
ES256K = ESAlgorithm("ES256K", "secp256k1", 256)
66

77

88
def register_secp256k1() -> None:

src/joserfc/jwa.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
NoneAlgorithm,
1515
HMACAlgorithm,
1616
RSAAlgorithm,
17-
ECAlgorithm,
17+
ESAlgorithm,
1818
RSAPSSAlgorithm,
1919
JWS_ALGORITHMS as _JWS_ALGORITHMS,
2020
)
@@ -46,7 +46,7 @@
4646
"NoneAlgorithm",
4747
"HMACAlgorithm",
4848
"RSAAlgorithm",
49-
"ECAlgorithm",
49+
"ESAlgorithm",
5050
"RSAPSSAlgorithm",
5151
"EdDSAAlgorithm",
5252
# JWE algorithms

tests/jws/test_registry.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import unittest
2+
3+
from joserfc.jws import JWSRegistry
4+
from joserfc.jwk import OctKey, RSAKey, ECKey, OKPKey
5+
6+
7+
class JWSRegistryTest(unittest.TestCase):
8+
oct_key = OctKey.generate_key()
9+
rsa_key = RSAKey.generate_key()
10+
ec_key = ECKey.generate_key()
11+
okp_key = OKPKey.generate_key()
12+
13+
def test_guess_recommended_algorithm(self):
14+
name = JWSRegistry.guess_alg(self.oct_key, JWSRegistry.Strategy.RECOMMENDED)
15+
self.assertEqual(name, "HS256")
16+
17+
name = JWSRegistry.guess_alg(self.rsa_key, JWSRegistry.Strategy.RECOMMENDED)
18+
self.assertEqual(name, "RS256")
19+
20+
name = JWSRegistry.guess_alg(self.ec_key, JWSRegistry.Strategy.RECOMMENDED)
21+
self.assertEqual(name, "ES256")
22+
23+
name = JWSRegistry.guess_alg(self.okp_key, JWSRegistry.Strategy.RECOMMENDED)
24+
self.assertEqual(name, None)
25+
26+
def test_guess_security_algorithm(self):
27+
name = JWSRegistry.guess_alg(self.oct_key, JWSRegistry.Strategy.SECURITY)
28+
self.assertEqual(name, "HS512")
29+
30+
name = JWSRegistry.guess_alg(self.rsa_key, JWSRegistry.Strategy.SECURITY)
31+
self.assertEqual(name, "RS512")
32+
33+
name = JWSRegistry.guess_alg(self.ec_key, JWSRegistry.Strategy.SECURITY)
34+
self.assertEqual(name, "ES256")
35+
36+
ec521 = ECKey.generate_key("P-521")
37+
name = JWSRegistry.guess_alg(ec521, JWSRegistry.Strategy.SECURITY)
38+
self.assertEqual(name, "ES512")
39+
40+
name = JWSRegistry.guess_alg(self.okp_key, JWSRegistry.Strategy.SECURITY)
41+
self.assertEqual(name, "EdDSA")

0 commit comments

Comments
 (0)