Skip to content

Commit 62d968b

Browse files
committed
feat(jwe): add content size validation to avoid DoS
1 parent 05ccff5 commit 62d968b

File tree

6 files changed

+115
-21
lines changed

6 files changed

+115
-21
lines changed

src/joserfc/_rfc7515/registry.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@
44
from enum import Enum
55
from .model import JWSAlgModel
66
from ..errors import (
7+
JoseError,
78
UnsupportedAlgorithmError,
89
SecurityWarning,
9-
JoseError,
1010
ExceededSizeError,
1111
)
1212
from ..registry import (

src/joserfc/_rfc7516/compact.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from .models import CompactEncryption, Recipient
2+
from .registry import JWERegistry
23
from .._keys import Key
34
from ..errors import (
45
MissingAlgorithmError,
@@ -32,12 +33,17 @@ def represent_compact(obj: CompactEncryption) -> bytes:
3233
)
3334

3435

35-
def extract_compact(value: bytes) -> CompactEncryption:
36+
def extract_compact(value: bytes, registry: JWERegistry) -> CompactEncryption:
3637
parts = value.split(b".")
3738
if len(parts) != 5:
3839
raise ValueError("Invalid JSON Web Encryption")
3940

4041
header_segment, ek_segment, iv_segment, ciphertext_segment, tag_segment = parts
42+
registry.validate_protected_header_size(header_segment)
43+
registry.validate_encrypted_key_size(ek_segment)
44+
registry.validate_initialization_vector_size(iv_segment)
45+
registry.validate_ciphertext_size(ciphertext_segment)
46+
registry.validate_auth_tag_size(tag_segment)
4147
try:
4248
protected = json_b64decode(header_segment)
4349
if "alg" not in protected:

src/joserfc/_rfc7516/json.py

Lines changed: 33 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
1+
from __future__ import annotations
12
import typing as t
23
from .models import (
34
BaseJSONEncryption,
45
GeneralJSONEncryption,
56
FlattenedJSONEncryption,
67
Recipient,
78
)
9+
from .registry import JWERegistry
810
from .types import (
911
JSONRecipientDict,
1012
GeneralJSONSerialization,
@@ -70,44 +72,49 @@ def __represent_json_serialization(obj: BaseJSONEncryption): # type: ignore[no-
7072
return data
7173

7274

73-
def extract_general_json(data: GeneralJSONSerialization) -> GeneralJSONEncryption:
74-
protected = json_b64decode(data["protected"])
75+
def extract_general_json(data: GeneralJSONSerialization, registry: JWERegistry) -> GeneralJSONEncryption:
76+
protected_segment = to_bytes(data["protected"])
77+
registry.validate_protected_header_size(protected_segment)
78+
protected = json_b64decode(protected_segment)
79+
7580
unprotected = data.get("unprotected")
76-
base64_segments, bytes_segments, aad = __extract_segments(data)
81+
base64_segments, bytes_segments, aad = __extract_segments(data, registry)
82+
7783
obj = GeneralJSONEncryption(protected, None, unprotected, aad)
7884
obj.base64_segments = base64_segments
7985
obj.bytes_segments = bytes_segments
8086
for item in data["recipients"]:
81-
recipient: Recipient[Key] = Recipient(obj, item.get("header"))
82-
if "encrypted_key" in item:
83-
recipient.encrypted_key = urlsafe_b64decode(to_bytes(item["encrypted_key"]))
87+
recipient = __extract_recipient(obj, item, registry)
8488
obj.recipients.append(recipient)
8589
return obj
8690

8791

88-
def extract_flattened_json(data: FlattenedJSONSerialization) -> FlattenedJSONEncryption:
89-
protected = json_b64decode(data["protected"])
92+
def extract_flattened_json(data: FlattenedJSONSerialization, registry: JWERegistry) -> FlattenedJSONEncryption:
93+
protected_segment = to_bytes(data["protected"])
94+
registry.validate_protected_header_size(protected_segment)
95+
protected = json_b64decode(protected_segment)
9096
unprotected = data.get("unprotected")
91-
base64_segments, bytes_segments, aad = __extract_segments(data)
97+
base64_segments, bytes_segments, aad = __extract_segments(data, registry)
9298
obj = FlattenedJSONEncryption(protected, None, unprotected, aad)
9399
obj.base64_segments = base64_segments
94100
obj.bytes_segments = bytes_segments
95-
96-
recipient: Recipient[Key] = Recipient(obj, data.get("header"))
97-
if "encrypted_key" in data:
98-
recipient.encrypted_key = urlsafe_b64decode(to_bytes(data["encrypted_key"]))
101+
recipient = __extract_recipient(obj, data, registry)
99102
obj.recipients.append(recipient)
100103
return obj
101104

102105

103106
def __extract_segments(
104107
data: t.Union[GeneralJSONSerialization, FlattenedJSONSerialization],
108+
registry: JWERegistry,
105109
) -> tuple[dict[str, bytes], dict[str, bytes], t.Optional[bytes]]:
106110
base64_segments: dict[str, bytes] = {
107111
"iv": to_bytes(data["iv"]),
108112
"ciphertext": to_bytes(data["ciphertext"]),
109113
"tag": to_bytes(data["tag"]),
110114
}
115+
registry.validate_initialization_vector_size(base64_segments["iv"])
116+
registry.validate_ciphertext_size(base64_segments["ciphertext"])
117+
registry.validate_auth_tag_size(base64_segments["tag"])
111118
bytes_segments: dict[str, bytes] = {
112119
"iv": urlsafe_b64decode(base64_segments["iv"]),
113120
"ciphertext": urlsafe_b64decode(base64_segments["ciphertext"]),
@@ -118,3 +125,16 @@ def __extract_segments(
118125
else:
119126
aad = None
120127
return base64_segments, bytes_segments, aad
128+
129+
130+
def __extract_recipient(
131+
obj: FlattenedJSONEncryption | GeneralJSONEncryption,
132+
data: FlattenedJSONSerialization | JSONRecipientDict,
133+
registry: JWERegistry,
134+
) -> Recipient[Key]:
135+
recipient: Recipient[Key] = Recipient(obj, data.get("header"))
136+
if "encrypted_key" in data:
137+
ek_segment = to_bytes(data["encrypted_key"])
138+
registry.validate_encrypted_key_size(ek_segment)
139+
recipient.encrypted_key = urlsafe_b64decode(ek_segment)
140+
return recipient

src/joserfc/_rfc7516/registry.py

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,11 @@
22
import warnings
33
import typing as t
44
from .models import JWEAlgModel, JWEEncModel, JWEZipModel
5-
from ..errors import UnsupportedAlgorithmError, SecurityWarning
5+
from ..errors import (
6+
UnsupportedAlgorithmError,
7+
SecurityWarning,
8+
ExceededSizeError,
9+
)
610
from ..registry import (
711
Header,
812
HeaderRegistryDict,
@@ -41,13 +45,24 @@ class JWERegistry:
4145
:param strict_check_header: only allow header key in the registry to be used
4246
"""
4347

44-
algorithms: t.ClassVar[AlgorithmsDict] = {
48+
algorithms: AlgorithmsDict = {
4549
"alg": {},
4650
"enc": {},
4751
"zip": {},
4852
}
4953
recommended: t.ClassVar[list[str]] = []
5054

55+
#: max protected header content's size in bytes
56+
max_protected_header_length: int = 1024
57+
#: max encrypted key's size in bytes
58+
max_encrypted_key_length: int = 1024
59+
#: max initialization vector's size in bytes
60+
max_initialization_vector_length: int = 64
61+
#: max ciphertext's size in bytes
62+
max_ciphertext_length: int = 65536 # 64KB
63+
#: max auth tag's size in bytes
64+
max_auth_tag_length: int = 64
65+
5166
def __init__(
5267
self,
5368
header_registry: t.Optional[HeaderRegistryDict] = None,
@@ -85,6 +100,28 @@ def check_header(self, header: Header, check_more: bool = False) -> None:
85100
elif self.strict_check_header:
86101
check_supported_header(self.header_registry, header)
87102

103+
def validate_protected_header_size(self, header: bytes) -> None:
104+
if header and len(header) > self.max_protected_header_length:
105+
raise ExceededSizeError(f"Header size of '{header!r}' exceeds {self.max_protected_header_length} bytes.")
106+
107+
def validate_encrypted_key_size(self, ek: bytes) -> None:
108+
if ek and len(ek) > self.max_encrypted_key_length:
109+
raise ExceededSizeError(f"Encrypted key size of '{ek!r}' exceeds {self.max_encrypted_key_length} bytes.")
110+
111+
def validate_initialization_vector_size(self, iv: bytes) -> None:
112+
if iv and len(iv) > self.max_initialization_vector_length:
113+
raise ExceededSizeError(
114+
f"Initialization vector size of '{iv!r}' exceeds {self.max_initialization_vector_length} bytes."
115+
)
116+
117+
def validate_ciphertext_size(self, ciphertext: bytes) -> None:
118+
if ciphertext and len(ciphertext) > self.max_ciphertext_length:
119+
raise ExceededSizeError(f"Ciphertext size of '{ciphertext!r}' exceeds {self.max_ciphertext_length} bytes.")
120+
121+
def validate_auth_tag_size(self, tag: bytes) -> None:
122+
if tag and len(tag) > self.max_auth_tag_length:
123+
raise ExceededSizeError(f"Auth tag size of '{tag!r}' exceeds {self.max_auth_tag_length} bytes.")
124+
88125
def get_alg(self, name: str) -> JWEAlgModel:
89126
"""Get the allowed ("alg") algorithm instance of the given name.
90127

src/joserfc/jwe.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -119,12 +119,12 @@ def decrypt_compact(
119119
:param sender_key: only required when using ECDH-1PU
120120
:return: object of the ``CompactEncryption``
121121
"""
122-
obj = extract_compact(to_bytes(value))
123122
if algorithms:
124123
registry = JWERegistry(algorithms=algorithms)
125124
elif registry is None:
126125
registry = default_registry
127126

127+
obj = extract_compact(to_bytes(value), registry)
128128
recipient = obj.recipient
129129
assert recipient is not None
130130
key = guess_key(private_key, recipient, use="enc")
@@ -235,12 +235,12 @@ def decrypt_json(
235235

236236
reject_unprotected_crit_header(data.get("unprotected"))
237237
if "recipients" in data:
238-
general_obj = extract_general_json(data) # type: ignore[arg-type]
238+
general_obj = extract_general_json(data, registry) # type: ignore[arg-type]
239239
_attach_recipient_keys(general_obj.recipients, private_key, sender_key)
240240
perform_decrypt(general_obj, registry)
241241
return general_obj
242242
else:
243-
flattened_obj = extract_flattened_json(data) # type: ignore[arg-type]
243+
flattened_obj = extract_flattened_json(data, registry) # type: ignore[arg-type]
244244
_attach_recipient_keys(flattened_obj.recipients, private_key, sender_key)
245245
perform_decrypt(flattened_obj, registry)
246246
return flattened_obj

tests/jwe/test_compact.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
ExceededSizeError,
1616
InvalidHeaderValueError,
1717
)
18-
from joserfc.util import json_b64encode
18+
from joserfc.util import json_b64encode, urlsafe_b64encode
1919
from tests.base import load_key
2020

2121

@@ -204,6 +204,37 @@ def test_decompress_zip_exceeds_size(self):
204204
result = encrypt_compact({"alg": "dir", "enc": "A128GCM", "zip": "DEF"}, b"h" * 300000, key)
205205
self.assertRaises(ExceededSizeError, decrypt_compact, result, key)
206206

207+
def test_header_exceeds_size(self):
208+
header = json_b64encode({f"a{i}": "a" * i for i in range(1000)}).decode("utf-8")
209+
s = header + "..YbDfdYa6p-wAEFul.YK7j0MsH-Dko6ifsEg.wES6-QAOEbErZqXiS0JHRw"
210+
self.assertRaises(ExceededSizeError, decrypt_compact, s, OctKey.import_key("secret"))
211+
212+
def test_encrypted_key_exceeds_size(self):
213+
header = json_b64encode({"alg": "dir", "enc": "A128GCM"}).decode("utf-8")
214+
ek = urlsafe_b64encode(("a" * 1000).encode("utf-8")).decode("utf-8")
215+
s = header + "." + ek + ".YbDfdYa6p-wAEFul.YK7j0MsH-Dko6ifsEg.wES6-QAOEbErZqXiS0JHRw"
216+
key = OctKey.import_key({"k": "pyL42ncDFSYnenl-GiZjRw", "kty": "oct"})
217+
self.assertRaises(ExceededSizeError, decrypt_compact, s, key)
218+
219+
def test_initialization_vector_size(self):
220+
header = json_b64encode({"alg": "dir", "enc": "A128GCM"}).decode("utf-8")
221+
iv = urlsafe_b64encode(("a" * 1000).encode("utf-8")).decode("utf-8")
222+
s = header + ".." + iv + ".YK7j0MsH-Dko6ifsEg.wES6-QAOEbErZqXiS0JHRw"
223+
key = OctKey.import_key({"k": "pyL42ncDFSYnenl-GiZjRw", "kty": "oct"})
224+
self.assertRaises(ExceededSizeError, decrypt_compact, s, key)
225+
226+
def test_ciphertext_exceeds_size(self):
227+
header = json_b64encode({"alg": "dir", "enc": "A128GCM"}).decode("utf-8")
228+
ciphertext = urlsafe_b64encode(("a" * 70000).encode("utf-8")).decode("utf-8")
229+
s = header + "..YbDfdYa6p-wAEFul." + ciphertext + ".wES6-QAOEbErZqXiS0JHRw"
230+
self.assertRaises(ExceededSizeError, decrypt_compact, s, OctKey.import_key("secret"))
231+
232+
def test_auth_tag_exceeds_size(self):
233+
header = json_b64encode({"alg": "dir", "enc": "A128GCM"}).decode("utf-8")
234+
tag = urlsafe_b64encode(("a" * 80).encode("utf-8")).decode("utf-8")
235+
s = header + "..YbDfdYa6p-wAEFul.YK7j0MsH-Dko6ifsEg." + tag
236+
self.assertRaises(ExceededSizeError, decrypt_compact, s, OctKey.import_key("secret"))
237+
207238
def test_invalid_compact_data(self):
208239
private_key: RSAKey = load_key("rsa-openssl-private.pem")
209240
value = b"a.b.c.d.e.f.g"

0 commit comments

Comments
 (0)