Skip to content

Commit 05ccff5

Browse files
committed
fix(jws): validating content size to avoid DoS
1 parent 7d7733b commit 05ccff5

File tree

9 files changed

+163
-47
lines changed

9 files changed

+163
-47
lines changed

.pre-commit-config.yaml

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,6 @@ repos:
77
args: [--fix, --exit-non-zero-on-fix]
88
- id: ruff-format
99

10-
- repo: https://github.com/pre-commit/mirrors-mypy
11-
rev: 'v1.18.1'
12-
hooks:
13-
- id: mypy
14-
1510
- repo: https://github.com/codespell-project/codespell
1611
rev: 'v2.4.1'
1712
hooks:

src/joserfc/_rfc7515/json.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from .registry import JWSRegistry
1515
from ..registry import reject_unprotected_crit_header
1616
from ..util import (
17+
to_bytes,
1718
json_b64encode,
1819
json_b64decode,
1920
urlsafe_b64encode,
@@ -86,25 +87,27 @@ def sign_json_member(
8687
return rv
8788

8889

89-
def extract_general_json(value: GeneralJSONSerialization) -> GeneralJSONSignature:
90+
def extract_general_json(value: GeneralJSONSerialization, registry: JWSRegistry) -> GeneralJSONSignature:
9091
payload_segment: bytes = value["payload"].encode("utf-8")
92+
registry.validate_payload_size(payload_segment)
9193
try:
9294
payload = urlsafe_b64decode(payload_segment)
9395
except (TypeError, ValueError):
9496
raise DecodeError("Invalid payload")
9597

9698
signatures: list[JSONSignatureDict] = value["signatures"]
97-
members = [__signature_to_member(sig) for sig in signatures]
99+
members = [__signature_to_member(sig, registry) for sig in signatures]
98100
obj = GeneralJSONSignature(members, payload)
99101
obj.signatures = signatures
100102
obj.segments = {"payload": payload_segment}
101103
return obj
102104

103105

104-
def __signature_to_member(sig: JSONSignatureDict) -> HeaderMember:
106+
def __signature_to_member(sig: JSONSignatureDict, registry: JWSRegistry) -> HeaderMember:
105107
member = HeaderMember()
106108
if "protected" in sig:
107-
protected_segment = sig["protected"]
109+
protected_segment = to_bytes(sig["protected"])
110+
registry.validate_header_size(protected_segment)
108111
member.protected = json_b64decode(protected_segment)
109112
if "header" in sig:
110113
member.header = sig["header"]
@@ -139,11 +142,16 @@ def verify_signature(
139142
alg = registry.get_alg(headers["alg"])
140143
key = find_key(member)
141144
alg.check_key(key)
145+
142146
if "protected" in signature:
143-
protected_segment = signature["protected"].encode("utf-8")
147+
protected_segment = to_bytes(signature["protected"])
144148
else:
145149
protected_segment = b""
146-
sig = urlsafe_b64decode(signature["signature"].encode("utf-8"))
150+
151+
signature_segment = to_bytes(signature["signature"])
152+
registry.validate_signature_size(signature_segment)
153+
154+
sig = urlsafe_b64decode(signature_segment)
147155
signing_input = b".".join([protected_segment, payload_segment])
148156
return alg.verify(signing_input, sig, key)
149157

src/joserfc/_rfc7515/registry.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,12 @@
33
from typing import Any
44
from enum import Enum
55
from .model import JWSAlgModel
6-
from ..errors import UnsupportedAlgorithmError, SecurityWarning, JoseError
6+
from ..errors import (
7+
UnsupportedAlgorithmError,
8+
SecurityWarning,
9+
JoseError,
10+
ExceededSizeError,
11+
)
712
from ..registry import (
813
JWS_HEADER_REGISTRY,
914
Header,
@@ -40,6 +45,13 @@ class Strategy(Enum):
4045
algorithms: dict[str, JWSAlgModel] = {}
4146
recommended: list[str] = []
4247

48+
#: max header content's size in bytes
49+
max_header_length: int = 512
50+
#: max payload content's size in bytes
51+
max_payload_length: int = 8000
52+
#: max signature's size in bytes
53+
max_signature_length: int = 1024
54+
4355
def __init__(
4456
self,
4557
header_registry: HeaderRegistryDict | None = None,
@@ -87,6 +99,18 @@ def check_header(self, header: Header) -> None:
8799
if self.strict_check_header:
88100
check_supported_header(self.header_registry, header)
89101

102+
def validate_header_size(self, header: bytes) -> None:
103+
if header and len(header) > self.max_header_length:
104+
raise ExceededSizeError(f"Header size of '{header!r}' exceeds {self.max_header_length} bytes.")
105+
106+
def validate_payload_size(self, payload: bytes) -> None:
107+
if payload and len(payload) > self.max_payload_length:
108+
raise ExceededSizeError(f"Payload size of '{payload!r}' exceeds {self.max_payload_length} bytes.")
109+
110+
def validate_signature_size(self, signature: bytes) -> None:
111+
if len(signature) > self.max_signature_length:
112+
raise ExceededSizeError(f"Signature of '{signature!r}' exceeds {self.max_signature_length} bytes.")
113+
90114
@classmethod
91115
def guess_alg(cls, key: Any, strategy: Strategy) -> str | None:
92116
"""Guess the JWS algorithm for a given key.

src/joserfc/_rfc7797/compact.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
)
1111
from .._rfc7515.model import JWSAlgModel, CompactSignature
1212
from .._rfc7515.compact import decode_header
13+
from .._rfc7515.registry import JWSRegistry, default_registry
1314
from ..errors import DecodeError
1415
from .util import is_rfc7797_enabled
1516

@@ -27,18 +28,28 @@ def sign_rfc7515_compact(obj: CompactSignature, alg: JWSAlgModel, key: Any) -> b
2728
return out
2829

2930

30-
def extract_rfc7515_compact(value: bytes, payload: bytes | str | None = None) -> CompactSignature:
31+
def extract_rfc7515_compact(
32+
value: bytes, payload: bytes | str | None = None, registry: JWSRegistry | None = None
33+
) -> CompactSignature:
3134
"""Extract the JWS Compact Serialization from bytes to object.
3235
3336
:param value: JWS in bytes
3437
:param payload: optional payload, required with detached content
38+
:param registry: optional JWSRegistry instance
3539
:raise: DecodeError
3640
"""
3741
parts = value.split(b".")
3842
if len(parts) != 3:
3943
raise DecodeError("Invalid JSON Web Signature")
4044

45+
if registry is None:
46+
registry = default_registry
47+
4148
header_segment, payload_segment, signature_segment = parts
49+
50+
registry.validate_header_size(header_segment)
51+
registry.validate_signature_size(signature_segment)
52+
4253
protected = decode_header(header_segment)
4354

4455
if is_rfc7797_enabled(protected):
@@ -50,6 +61,7 @@ def extract_rfc7515_compact(value: bytes, payload: bytes | str | None = None) ->
5061
payload = to_bytes(payload)
5162
payload_segment = urlsafe_b64encode(payload)
5263
else:
64+
registry.validate_payload_size(payload_segment)
5365
try:
5466
payload = urlsafe_b64decode(payload_segment)
5567
except (TypeError, ValueError):

src/joserfc/_rfc7797/json.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from __future__ import annotations
12
from .._rfc7515.types import FlattenedJSONSerialization, JSONSignatureDict
23
from .._rfc7515.model import HeaderMember, FlattenedJSONSignature
34
from .._rfc7515.registry import JWSRegistry
@@ -18,9 +19,10 @@ def sign_rfc7797_json(
1819
return data
1920

2021

21-
def extract_rfc7797_json(value: FlattenedJSONSerialization) -> FlattenedJSONSignature:
22+
def extract_rfc7797_json(value: FlattenedJSONSerialization, registry: JWSRegistry) -> FlattenedJSONSignature:
2223
if "protected" in value:
2324
protected_segment = to_bytes(value["protected"])
25+
registry.validate_header_size(protected_segment)
2426
protected = json_b64decode(protected_segment)
2527
else:
2628
protected = None
@@ -32,6 +34,7 @@ def extract_rfc7797_json(value: FlattenedJSONSerialization) -> FlattenedJSONSign
3234
if is_rfc7797_enabled(member.headers()):
3335
payload = payload_segment
3436
else:
37+
registry.validate_payload_size(payload_segment)
3538
try:
3639
payload = urlsafe_b64decode(payload_segment)
3740
except (TypeError, ValueError):

src/joserfc/errors.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -135,8 +135,8 @@ class BadSignatureError(JoseError):
135135

136136

137137
class ExceededSizeError(JoseError):
138-
"""This error is designed for DEF zip algorithm. It raised when the
139-
compressed data exceeds the maximum allowed length."""
138+
"""This error is designed for validating the token's content size.
139+
It raised when the data exceeds the maximum allowed length."""
140140

141141
error = "exceeded_size"
142142

src/joserfc/jws.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ def deserialize_compact(
180180
:param payload: optional payload, required with detached content
181181
:return: object of the ``CompactSignature``
182182
"""
183-
obj = extract_compact(to_bytes(value), payload)
183+
obj = extract_compact(to_bytes(value), payload, registry)
184184
if not validate_compact(obj, public_key, algorithms, registry):
185185
raise BadSignatureError()
186186
return obj
@@ -296,12 +296,12 @@ def find_key(obj: HeaderMember) -> Key:
296296
return guess_key(public_key, obj, use="sig")
297297

298298
if "signatures" in value:
299-
general_obj = extract_general_json(value)
299+
general_obj = extract_general_json(value, registry)
300300
if not verify_general_json(general_obj, registry, find_key):
301301
raise BadSignatureError()
302302
return general_obj
303303
else:
304-
flattened_obj = extract_flattened_json(value)
304+
flattened_obj = extract_flattened_json(value, registry)
305305
if not verify_flattened_json(flattened_obj, registry, find_key):
306306
raise BadSignatureError()
307307
return flattened_obj

tests/jws/test_compact.py

Lines changed: 34 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,20 @@
1212
MissingAlgorithmError,
1313
UnsupportedAlgorithmError,
1414
UnsupportedHeaderError,
15+
ExceededSizeError,
1516
)
17+
from joserfc.util import urlsafe_b64encode, json_b64encode
1618

1719

1820
class TestCompact(TestCase):
21+
key = OctKey.import_key("secret")
22+
1923
def test_registry_is_none(self):
20-
key = OctKey.import_key("secret")
21-
value = serialize_compact({"alg": "HS256"}, b"foo", key)
24+
value = serialize_compact({"alg": "HS256"}, b"foo", self.key)
2225
expected = "eyJhbGciOiJIUzI1NiJ9.Zm9v.0pehoi-RMZM1jl-4TP_C4Y6BJ-bcmsuzfDyQpkpJkh0"
2326
self.assertEqual(value, expected)
2427

25-
obj = deserialize_compact(value, key)
28+
obj = deserialize_compact(value, self.key)
2629
self.assertEqual(obj.payload, b"foo")
2730

2831
def test_bad_signature_error(self):
@@ -31,28 +34,42 @@ def test_bad_signature_error(self):
3134
self.assertRaises(BadSignatureError, deserialize_compact, value, key)
3235

3336
def test_raise_unsupported_algorithm_error(self):
34-
key = OctKey.import_key("secret")
35-
self.assertRaises(UnsupportedAlgorithmError, serialize_compact, {"alg": "HS512"}, b"foo", key)
36-
self.assertRaises(UnsupportedAlgorithmError, serialize_compact, {"alg": "NOT"}, b"foo", key)
37+
self.assertRaises(UnsupportedAlgorithmError, serialize_compact, {"alg": "HS512"}, b"foo", self.key)
38+
self.assertRaises(UnsupportedAlgorithmError, serialize_compact, {"alg": "NOT"}, b"foo", self.key)
3739

3840
def test_invalid_length(self):
39-
key = OctKey.import_key("secret")
40-
self.assertRaises(DecodeError, deserialize_compact, b"a.b.c.d", key)
41+
self.assertRaises(DecodeError, deserialize_compact, b"a.b.c.d", self.key)
4142

4243
def test_no_invalid_header(self):
4344
# invalid base64
4445
value = b"abc.Zm9v.0pehoi"
45-
key = OctKey.import_key("secret")
46-
self.assertRaises(DecodeError, deserialize_compact, value, key)
46+
self.assertRaises(DecodeError, deserialize_compact, value, self.key)
4747

4848
# no alg value
4949
value = b"eyJhIjoiYiJ9.Zm9v.0pehoi"
50-
self.assertRaises(MissingAlgorithmError, deserialize_compact, value, key)
50+
self.assertRaises(MissingAlgorithmError, deserialize_compact, value, self.key)
5151

5252
def test_invalid_payload(self):
5353
value = b"eyJhbGciOiJIUzI1NiJ9.a$b.0pehoi"
54-
key = OctKey.import_key("secret")
55-
self.assertRaises(DecodeError, deserialize_compact, value, key)
54+
self.assertRaises(DecodeError, deserialize_compact, value, self.key)
55+
56+
def test_header_exceeded_size_error(self):
57+
exceeded_header = json_b64encode({f"a{i}": f"a{i}" for i in range(1000)})
58+
other = urlsafe_b64encode(b"o")
59+
fake_jws = exceeded_header + b"." + other + b"." + other
60+
self.assertRaises(ExceededSizeError, deserialize_compact, fake_jws, self.key)
61+
62+
def test_payload_exceeded_size_error(self):
63+
header = json_b64encode({"alg": "HS256"})
64+
exceeded_payload = urlsafe_b64encode(("o" * 10000).encode("utf8"))
65+
fake_jws = header + b"." + exceeded_payload + b"." + urlsafe_b64encode(b"o")
66+
self.assertRaises(ExceededSizeError, deserialize_compact, fake_jws, self.key)
67+
68+
def test_signature_exceeded_size_error(self):
69+
header = json_b64encode({"alg": "HS256"})
70+
exceeded_signature = urlsafe_b64encode(("o" * 1000).encode("utf8"))
71+
fake_jws = header + b"." + urlsafe_b64encode(b"o") + b"." + exceeded_signature
72+
self.assertRaises(ExceededSizeError, deserialize_compact, fake_jws, self.key)
5673

5774
def test_with_key_set(self):
5875
keys = KeySet(
@@ -73,20 +90,18 @@ def test_with_key_set(self):
7390

7491
def test_strict_check_header(self):
7592
header = {"alg": "HS256", "custom": "hi"}
76-
key = OctKey.import_key("secret")
77-
self.assertRaises(UnsupportedHeaderError, serialize_compact, header, b"hi", key)
93+
self.assertRaises(UnsupportedHeaderError, serialize_compact, header, b"hi", self.key)
7894

7995
registry = JWSRegistry(strict_check_header=False)
80-
serialize_compact(header, b"hi", key, registry=registry)
96+
serialize_compact(header, b"hi", self.key, registry=registry)
8197

8298
def test_non_canonical_signature_encoding(self):
8399
text = "eyJhbGciOiJIUzI1NiJ9.eyJ1c2VyIjoiYWRtaW4ifQ.VI29GgHzuh2xfF0bkRYvZIsSuQnbTXSIvuRyt7RDrwo"[:-1] + "p"
84100
self.assertRaises(BadSignatureError, deserialize_compact, text, OctKey.import_key("secret"))
85101

86102
def test_detached_content(self):
87-
key = OctKey.import_key("secret")
88-
value = detach_content(serialize_compact({"alg": "HS256"}, b"foo", key))
103+
value = detach_content(serialize_compact({"alg": "HS256"}, b"foo", self.key))
89104
expected = "eyJhbGciOiJIUzI1NiJ9..0pehoi-RMZM1jl-4TP_C4Y6BJ-bcmsuzfDyQpkpJkh0"
90105
self.assertEqual(value, expected)
91-
obj = deserialize_compact(value, key, payload=b"foo")
106+
obj = deserialize_compact(value, self.key, payload=b"foo")
92107
self.assertEqual(obj.payload, b"foo")

0 commit comments

Comments
 (0)