1+ from __future__ import annotations
12import typing as t
23from .models import (
34 BaseJSONEncryption ,
45 GeneralJSONEncryption ,
56 FlattenedJSONEncryption ,
67 Recipient ,
78)
9+ from .registry import JWERegistry
810from .types import (
911 JSONRecipientDict ,
1012 GeneralJSONSerialization ,
@@ -70,44 +72,49 @@ def __represent_json_serialization(obj: BaseJSONEncryption): # type: ignore[no-
7072 return data
7173
7274
73- def extract_general_json (data : GeneralJSONSerialization ) -> GeneralJSONEncryption :
74- protected = json_b64decode (data ["protected" ])
75+ def extract_general_json (data : GeneralJSONSerialization , registry : JWERegistry ) -> GeneralJSONEncryption :
76+ protected_segment = to_bytes (data ["protected" ])
77+ registry .validate_protected_header_size (protected_segment )
78+ protected = json_b64decode (protected_segment )
79+
7580 unprotected = data .get ("unprotected" )
76- base64_segments , bytes_segments , aad = __extract_segments (data )
81+ base64_segments , bytes_segments , aad = __extract_segments (data , registry )
82+
7783 obj = GeneralJSONEncryption (protected , None , unprotected , aad )
7884 obj .base64_segments = base64_segments
7985 obj .bytes_segments = bytes_segments
8086 for item in data ["recipients" ]:
81- recipient : Recipient [Key ] = Recipient (obj , item .get ("header" ))
82- if "encrypted_key" in item :
83- recipient .encrypted_key = urlsafe_b64decode (to_bytes (item ["encrypted_key" ]))
87+ recipient = __extract_recipient (obj , item , registry )
8488 obj .recipients .append (recipient )
8589 return obj
8690
8791
88- def extract_flattened_json (data : FlattenedJSONSerialization ) -> FlattenedJSONEncryption :
89- protected = json_b64decode (data ["protected" ])
92+ def extract_flattened_json (data : FlattenedJSONSerialization , registry : JWERegistry ) -> FlattenedJSONEncryption :
93+ protected_segment = to_bytes (data ["protected" ])
94+ registry .validate_protected_header_size (protected_segment )
95+ protected = json_b64decode (protected_segment )
9096 unprotected = data .get ("unprotected" )
91- base64_segments , bytes_segments , aad = __extract_segments (data )
97+ base64_segments , bytes_segments , aad = __extract_segments (data , registry )
9298 obj = FlattenedJSONEncryption (protected , None , unprotected , aad )
9399 obj .base64_segments = base64_segments
94100 obj .bytes_segments = bytes_segments
95-
96- recipient : Recipient [Key ] = Recipient (obj , data .get ("header" ))
97- if "encrypted_key" in data :
98- recipient .encrypted_key = urlsafe_b64decode (to_bytes (data ["encrypted_key" ]))
101+ recipient = __extract_recipient (obj , data , registry )
99102 obj .recipients .append (recipient )
100103 return obj
101104
102105
103106def __extract_segments (
104107 data : t .Union [GeneralJSONSerialization , FlattenedJSONSerialization ],
108+ registry : JWERegistry ,
105109) -> tuple [dict [str , bytes ], dict [str , bytes ], t .Optional [bytes ]]:
106110 base64_segments : dict [str , bytes ] = {
107111 "iv" : to_bytes (data ["iv" ]),
108112 "ciphertext" : to_bytes (data ["ciphertext" ]),
109113 "tag" : to_bytes (data ["tag" ]),
110114 }
115+ registry .validate_initialization_vector_size (base64_segments ["iv" ])
116+ registry .validate_ciphertext_size (base64_segments ["ciphertext" ])
117+ registry .validate_auth_tag_size (base64_segments ["tag" ])
111118 bytes_segments : dict [str , bytes ] = {
112119 "iv" : urlsafe_b64decode (base64_segments ["iv" ]),
113120 "ciphertext" : urlsafe_b64decode (base64_segments ["ciphertext" ]),
@@ -118,3 +125,16 @@ def __extract_segments(
118125 else :
119126 aad = None
120127 return base64_segments , bytes_segments , aad
128+
129+
130+ def __extract_recipient (
131+ obj : FlattenedJSONEncryption | GeneralJSONEncryption ,
132+ data : FlattenedJSONSerialization | JSONRecipientDict ,
133+ registry : JWERegistry ,
134+ ) -> Recipient [Key ]:
135+ recipient : Recipient [Key ] = Recipient (obj , data .get ("header" ))
136+ if "encrypted_key" in data :
137+ ek_segment = to_bytes (data ["encrypted_key" ])
138+ registry .validate_encrypted_key_size (ek_segment )
139+ recipient .encrypted_key = urlsafe_b64decode (ek_segment )
140+ return recipient
0 commit comments