Skip to content

Commit 102a7a7

Browse files
committed
fix(typing): accept any Collection for algorithms, not just list
The `algorithms` argument is only ever tested with truthiness and the `in` operator (it is stored verbatim as `self.allowed` on the JWS/JWE registries), so a tuple, set, or frozenset works at runtime. Widen the type hint from `list[str] | None` to `Collection[str] | None` across the jwt/jws/jwe public APIs and the registry constructors, and add a regression test covering list/tuple/set/frozenset.
1 parent 8b869e8 commit 102a7a7

6 files changed

Lines changed: 57 additions & 32 deletions

File tree

src/joserfc/_rfc7515/registry.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22
import warnings
33
from typing import Any
4+
from collections.abc import Collection
45
from enum import Enum
56
from .model import JWSAlgModel
67
from ..errors import (
@@ -56,7 +57,7 @@ class Strategy(Enum):
5657
def __init__(
5758
self,
5859
header_registry: HeaderRegistryDict | None = None,
59-
algorithms: list[str] | None = None,
60+
algorithms: Collection[str] | None = None,
6061
strict_check_header: bool = True,
6162
):
6263
self.header_registry: HeaderRegistryDict = {}
@@ -173,7 +174,7 @@ def filter_algorithms(cls, key: Any, names: list[str] | None = None) -> list[JWS
173174
default_registry = JWSRegistry()
174175

175176

176-
def construct_registry(algorithms: list[str] | None = None) -> JWSRegistry:
177+
def construct_registry(algorithms: Collection[str] | None = None) -> JWSRegistry:
177178
if algorithms:
178179
registry = JWSRegistry(algorithms=algorithms)
179180
else:

src/joserfc/_rfc7516/registry.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22
import warnings
33
import typing as t
4+
from collections.abc import Collection
45
from .models import JWEAlgModel, JWEEncModel, JWEZipModel
56
from ..errors import (
67
UnsupportedAlgorithmError,
@@ -66,7 +67,7 @@ class JWERegistry:
6667
def __init__(
6768
self,
6869
header_registry: HeaderRegistryDict | None = None,
69-
algorithms: list[str] | None = None,
70+
algorithms: Collection[str] | None = None,
7071
verify_all_recipients: bool = True,
7172
strict_check_header: bool = True,
7273
):

src/joserfc/jwe.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22
from typing import overload
3+
from collections.abc import Collection
34
from ._rfc7516.types import (
45
GeneralJSONSerialization,
56
FlattenedJSONSerialization,
@@ -52,7 +53,7 @@ def encrypt_compact(
5253
protected: Header,
5354
plaintext: bytes | str,
5455
public_key: KeyFlexible,
55-
algorithms: list[str] | None = None,
56+
algorithms: Collection[str] | None = None,
5657
registry: JWERegistry | None = None,
5758
sender_key: ECKey | OKPKey | KeySet | None = None,
5859
) -> str:
@@ -68,7 +69,7 @@ def encrypt_compact(
6869
:param protected: protected header part of the JWE, in dict
6970
:param plaintext: the content (message) to be encrypted
7071
:param public_key: a public key used to encrypt the CEK
71-
:param algorithms: a list of allowed algorithms
72+
:param algorithms: a collection (list, tuple, or set) of allowed algorithms
7273
:param registry: a JWERegistry to use
7374
:param sender_key: only required when using ECDH-1PU
7475
:return: JWE Compact Serialization in bytes
@@ -95,7 +96,7 @@ def encrypt_compact(
9596
def decrypt_compact(
9697
value: bytes | str,
9798
private_key: KeyFlexible,
98-
algorithms: list[str] | None = None,
99+
algorithms: Collection[str] | None = None,
99100
registry: JWERegistry | None = None,
100101
sender_key: ECKey | OKPKey | KeySet | None = None,
101102
) -> CompactEncryption:
@@ -114,7 +115,7 @@ def decrypt_compact(
114115
115116
:param value: a string (or bytes) of the JWE Compact Serialization
116117
:param private_key: a flexible private key to decrypt the serialization
117-
:param algorithms: a list of allowed algorithms
118+
:param algorithms: a collection (list, tuple, or set) of allowed algorithms
118119
:param registry: a JWERegistry to use
119120
:param sender_key: only required when using ECDH-1PU
120121
:return: object of the ``CompactEncryption``
@@ -140,7 +141,7 @@ def decrypt_compact(
140141
def encrypt_json(
141142
obj: GeneralJSONEncryption,
142143
public_key: KeyFlexible | None,
143-
algorithms: list[str] | None = None,
144+
algorithms: Collection[str] | None = None,
144145
registry: JWERegistry | None = None,
145146
sender_key: ECKey | OKPKey | KeySet | None = None,
146147
) -> GeneralJSONSerialization: ...
@@ -150,7 +151,7 @@ def encrypt_json(
150151
def encrypt_json(
151152
obj: FlattenedJSONEncryption,
152153
public_key: KeyFlexible | None,
153-
algorithms: list[str] | None = None,
154+
algorithms: Collection[str] | None = None,
154155
registry: JWERegistry | None = None,
155156
sender_key: ECKey | OKPKey | KeySet | None = None,
156157
) -> FlattenedJSONSerialization: ...
@@ -159,7 +160,7 @@ def encrypt_json(
159160
def encrypt_json(
160161
obj: GeneralJSONEncryption | FlattenedJSONEncryption,
161162
public_key: KeyFlexible | None,
162-
algorithms: list[str] | None = None,
163+
algorithms: Collection[str] | None = None,
163164
registry: JWERegistry | None = None,
164165
sender_key: ECKey | OKPKey | KeySet | None = None,
165166
) -> GeneralJSONSerialization | FlattenedJSONSerialization:
@@ -184,7 +185,7 @@ def encrypt_json(
184185
185186
:param obj: an instance of ``GeneralJSONEncryption`` or ``FlattenedJSONEncryption``
186187
:param public_key: a public key used to encrypt the CEK
187-
:param algorithms: a list of allowed algorithms
188+
:param algorithms: a collection (list, tuple, or set) of allowed algorithms
188189
:param registry: a JWERegistry to use
189190
:param sender_key: only required when using ECDH-1PU
190191
:return: JWE JSON Serialization in dict
@@ -214,7 +215,7 @@ def encrypt_json(
214215
def decrypt_json(
215216
data: GeneralJSONSerialization | FlattenedJSONSerialization,
216217
private_key: KeyFlexible,
217-
algorithms: list[str] | None = None,
218+
algorithms: Collection[str] | None = None,
218219
registry: JWERegistry | None = None,
219220
sender_key: ECKey | OKPKey | KeySet | None = None,
220221
) -> GeneralJSONEncryption | FlattenedJSONEncryption:
@@ -223,7 +224,7 @@ def decrypt_json(
223224
224225
:param data: JWE JSON Serialization in dict
225226
:param private_key: a flexible private key to decrypt the CEK
226-
:param algorithms: a list of allowed algorithms
227+
:param algorithms: a collection (list, tuple, or set) of allowed algorithms
227228
:param registry: a JWERegistry to use
228229
:param sender_key: only required when using ECDH-1PU
229230
:return: an instance of ``GeneralJSONEncryption`` or ``FlattenedJSONEncryption``

src/joserfc/jws.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22
from typing import overload, Any
3+
from collections.abc import Collection
34
from ._rfc7515.model import (
45
JWSAlgModel,
56
HeaderMember,
@@ -75,7 +76,7 @@ def serialize_compact(
7576
protected: Header,
7677
payload: bytes | str,
7778
private_key: KeyFlexible | None,
78-
algorithms: list[str] | None = None,
79+
algorithms: Collection[str] | None = None,
7980
registry: JWSRegistry | None = None,
8081
) -> str:
8182
"""Generate a JWS Compact Serialization. The JWS Compact Serialization
@@ -91,7 +92,7 @@ def serialize_compact(
9192
:param protected: protected header part of the JWS, in dict
9293
:param payload: payload data of the JWS, in bytes
9394
:param private_key: a flexible private key to sign the signature
94-
:param algorithms: a list of allowed algorithms
95+
:param algorithms: a collection (list, tuple, or set) of allowed algorithms
9596
:param registry: a JWSRegistry to use
9697
:return: JWS in str
9798
"""
@@ -123,15 +124,15 @@ def serialize_compact(
123124
def validate_compact(
124125
obj: CompactSignature,
125126
public_key: KeyFlexible | None,
126-
algorithms: list[str] | None = None,
127+
algorithms: Collection[str] | None = None,
127128
registry: JWSRegistry | None = None,
128129
) -> bool:
129130
"""Validate the JWS Compact Serialization with the given key.
130131
This method is usually used together with ``extract_compact``.
131132
132133
:param obj: object of the JWS Compact Serialization
133134
:param public_key: a flexible public key to verify the signature
134-
:param algorithms: a list of allowed algorithms
135+
:param algorithms: a collection (list, tuple, or set) of allowed algorithms
135136
:param registry: a JWSRegistry to use
136137
"""
137138
if registry is None:
@@ -156,7 +157,7 @@ def validate_compact(
156157
def deserialize_compact(
157158
value: bytes | str,
158159
public_key: KeyFlexible | None,
159-
algorithms: list[str] | None = None,
160+
algorithms: Collection[str] | None = None,
160161
registry: JWSRegistry | None = None,
161162
payload: bytes | str | None = None,
162163
) -> CompactSignature:
@@ -175,7 +176,7 @@ def deserialize_compact(
175176
176177
:param value: a string (or bytes) of the JWS Compact Serialization
177178
:param public_key: a flexible public key to verify the signature
178-
:param algorithms: a list of allowed algorithms
179+
:param algorithms: a collection (list, tuple, or set) of allowed algorithms
179180
:param registry: a JWSRegistry to use
180181
:param payload: optional payload, required with detached content
181182
:raises BadSignatureError: when signature verification fails
@@ -192,7 +193,7 @@ def serialize_json(
192193
members: list[HeaderDict],
193194
payload: bytes | str,
194195
private_key: KeyFlexible,
195-
algorithms: list[str] | None = None,
196+
algorithms: Collection[str] | None = None,
196197
registry: JWSRegistry | None = None,
197198
) -> GeneralJSONSerialization: ...
198199

@@ -202,7 +203,7 @@ def serialize_json(
202203
members: HeaderDict,
203204
payload: bytes | str,
204205
private_key: KeyFlexible,
205-
algorithms: list[str] | None = None,
206+
algorithms: Collection[str] | None = None,
206207
registry: JWSRegistry | None = None,
207208
) -> FlattenedJSONSerialization: ...
208209

@@ -211,7 +212,7 @@ def serialize_json(
211212
members: HeaderDict | list[HeaderDict],
212213
payload: bytes | str,
213214
private_key: KeyFlexible,
214-
algorithms: list[str] | None = None,
215+
algorithms: Collection[str] | None = None,
215216
registry: JWSRegistry | None = None,
216217
) -> GeneralJSONSerialization | FlattenedJSONSerialization:
217218
"""Generate a JWS JSON Serialization (in dict). The JWS JSON Serialization
@@ -261,7 +262,7 @@ def find_key(obj: HeaderMember) -> Key:
261262
def deserialize_json(
262263
value: GeneralJSONSerialization,
263264
public_key: KeyFlexible,
264-
algorithms: list[str] | None = None,
265+
algorithms: Collection[str] | None = None,
265266
registry: JWSRegistry | None = None,
266267
) -> GeneralJSONSignature: ...
267268

@@ -270,22 +271,22 @@ def deserialize_json(
270271
def deserialize_json(
271272
value: FlattenedJSONSerialization,
272273
public_key: KeyFlexible,
273-
algorithms: list[str] | None = None,
274+
algorithms: Collection[str] | None = None,
274275
registry: JWSRegistry | None = None,
275276
) -> FlattenedJSONSignature: ...
276277

277278

278279
def deserialize_json(
279280
value: GeneralJSONSerialization | FlattenedJSONSerialization,
280281
public_key: KeyFlexible,
281-
algorithms: list[str] | None = None,
282+
algorithms: Collection[str] | None = None,
282283
registry: JWSRegistry | None = None,
283284
) -> GeneralJSONSignature | FlattenedJSONSignature:
284285
"""Extract and validate the JWS (in string) with the given key.
285286
286287
:param value: a dict of the JSON signature
287288
:param public_key: a flexible public key to verify the signature
288-
:param algorithms: a list of allowed algorithms
289+
:param algorithms: a collection (list, tuple, or set) of allowed algorithms
289290
:param registry: a JWSRegistry to use
290291
:return: object of GeneralJSONSignature or FlattenedJSONSignature
291292
:raises BadSignatureError: when signature verification fails

src/joserfc/jwt.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import json
33
from json import JSONEncoder, JSONDecoder
44
from typing import Type
5+
from collections.abc import Collection
56
from ._rfc7519.claims import (
67
convert_claims,
78
Claims,
@@ -58,7 +59,7 @@ def encode(
5859
header: Header,
5960
claims: Claims,
6061
key: KeyFlexible,
61-
algorithms: list[str] | None = None,
62+
algorithms: Collection[str] | None = None,
6263
registry: JWSRegistry | JWERegistry | None = None,
6364
encoder_cls: Type[JSONEncoder] | None = None,
6465
default_type: str | None = "JWT",
@@ -68,7 +69,7 @@ def encode(
6869
:param header: A dict of the JWT header
6970
:param claims: A dict of the JWT claims to be encoded
7071
:param key: key used to sign the signature
71-
:param algorithms: a list of allowed algorithms
72+
:param algorithms: a collection (list, tuple, or set) of allowed algorithms
7273
:param registry: a ``JWSRegistry`` or ``JWERegistry`` to use
7374
:param encoder_cls: A JSONEncoder subclass to use
7475
:param default_type: default value of the ``typ`` header parameter
@@ -87,7 +88,7 @@ def encode(
8788
def decode(
8889
value: bytes | str,
8990
key: KeyFlexible,
90-
algorithms: list[str] | None = None,
91+
algorithms: Collection[str] | None = None,
9192
registry: JWSRegistry | JWERegistry | None = None,
9293
decoder_cls: Type[JSONDecoder] | None = None,
9394
) -> Token:
@@ -96,7 +97,7 @@ def decode(
9697
9798
:param value: text of the JWT
9899
:param key: key used to verify the signature
99-
:param algorithms: a list of allowed algorithms
100+
:param algorithms: a collection (list, tuple, or set) of allowed algorithms
100101
:param registry: a ``JWSRegistry`` or ``JWERegistry`` to use
101102
:param decoder_cls: A JSONDecoder subclass to use
102103
:raise BadSignatureError: when signature verification fails
@@ -119,15 +120,15 @@ def decode(
119120

120121

121122
def _decode_jwe(
122-
value: bytes, key: KeyFlexible, algorithms: list[str] | None = None, registry: JWERegistry | None = None
123+
value: bytes, key: KeyFlexible, algorithms: Collection[str] | None = None, registry: JWERegistry | None = None
123124
) -> tuple[Header, bytes]:
124125
jwe_obj = decrypt_compact(value, key, algorithms, registry)
125126
assert jwe_obj.plaintext is not None
126127
return jwe_obj.headers(), jwe_obj.plaintext
127128

128129

129130
def _decode_jws(
130-
value: bytes, key: KeyFlexible, algorithms: list[str] | None = None, registry: JWSRegistry | None = None
131+
value: bytes, key: KeyFlexible, algorithms: Collection[str] | None = None, registry: JWSRegistry | None = None
131132
) -> tuple[Header, bytes]:
132133
jws_obj = deserialize_compact(value, key, algorithms, registry)
133134
assert jws_obj.payload is not None

tests/jwt/test_jwt.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
InvalidPayloadError,
66
MissingClaimError,
77
UnsupportedHeaderError,
8+
UnsupportedAlgorithmError,
89
DecodeError,
910
)
1011

@@ -94,6 +95,25 @@ def test_using_registry(self):
9495
registry=jws.JWSRegistry(),
9596
)
9697

98+
def test_algorithms_accepts_any_collection(self):
99+
# ``algorithms`` only needs to support ``in`` and truthiness, so any
100+
# collection works, not just a list.
101+
data = jwt.encode({"alg": "HS256"}, {"sub": "a"}, self.oct_key)
102+
for algorithms in (["HS256"], ("HS256",), {"HS256"}, frozenset({"HS256"})):
103+
token = jwt.decode(data, self.oct_key, algorithms=algorithms)
104+
self.assertEqual(token.claims["sub"], "a")
105+
106+
def test_algorithms_collection_rejects_disallowed(self):
107+
data = jwt.encode({"alg": "HS256"}, {"sub": "a"}, self.oct_key)
108+
# a disallowed algorithm is still rejected when given as a non-list
109+
self.assertRaises(
110+
UnsupportedAlgorithmError,
111+
jwt.decode,
112+
data,
113+
self.oct_key,
114+
("HS384",),
115+
)
116+
97117
def test_with_embedded_jwk(self):
98118
value = (
99119
"eyJqd2siOnsiY3J2IjoiUC0yNTYiLCJ4IjoiVU05ZzVuS25aWFlvdldBbE"

0 commit comments

Comments
 (0)