1
+ from __future__ import annotations
1
2
import os
2
3
import typing as t
3
4
from abc import ABCMeta , abstractmethod
4
5
from ..registry import Header , HeaderRegistryDict
5
6
from ..errors import InvalidKeyTypeError , InvalidKeyLengthError
6
- from .._keys import Key , ECKey
7
+ from .._keys import Key , ECKey , OctKey
7
8
8
9
KeyType = t .TypeVar ("KeyType" )
9
10
@@ -12,8 +13,8 @@ class Recipient(t.Generic[KeyType]):
12
13
def __init__ (
13
14
self ,
14
15
parent : t .Union ["CompactEncryption" , "GeneralJSONEncryption" , "FlattenedJSONEncryption" ],
15
- header : t . Optional [ Header ] = None ,
16
- recipient_key : t . Optional [ KeyType ] = None ):
16
+ header : Header | None = None ,
17
+ recipient_key : KeyType | None = None ):
17
18
self .__parent = parent
18
19
self .header = header
19
20
self .recipient_key = recipient_key
@@ -30,35 +31,35 @@ def headers(self) -> Header:
30
31
rv .update (self .header )
31
32
return rv
32
33
33
- def add_header (self , k : str , v : t .Any ):
34
+ def add_header (self , k : str , v : t .Any ) -> None :
34
35
if isinstance (self .__parent , CompactEncryption ):
35
36
self .__parent .protected .update ({k : v })
36
37
elif self .header :
37
38
self .header .update ({k : v })
38
39
else :
39
40
self .header = {k : v }
40
41
41
- def set_kid (self , kid : str ):
42
+ def set_kid (self , kid : str ) -> None :
42
43
self .add_header ("kid" , kid )
43
44
44
45
45
46
class CompactEncryption :
46
47
"""An object to represent the JWE Compact Serialization. It is usually returned by
47
48
``decrypt_compact`` method.
48
49
"""
49
- def __init__ (self , protected : Header , plaintext : t . Optional [ bytes ] = None ):
50
+ def __init__ (self , protected : Header , plaintext : bytes | None = None ):
50
51
#: protected header in dict
51
52
self .protected = protected
52
53
#: the plaintext in bytes
53
54
self .plaintext = plaintext
54
- self .recipient : t . Optional [ Recipient ] = None
55
+ self .recipient : Recipient [ t . Any ] | None = None
55
56
self .bytes_segments : t .Dict [str , bytes ] = {} # store the decoded segments
56
57
self .base64_segments : t .Dict [str , bytes ] = {} # store the encoded segments
57
58
58
59
def headers (self ) -> Header :
59
60
return self .protected
60
61
61
- def attach_recipient (self , key : Key , header : t . Optional [ Header ] = None ):
62
+ def attach_recipient (self , key : Key , header : Header | None = None ) -> None :
62
63
"""Add a recipient to the JWE Compact Serialization. Please add a key that
63
64
comply with the given "alg" value.
64
65
@@ -71,7 +72,7 @@ def attach_recipient(self, key: Key, header: t.Optional[Header] = None):
71
72
self .recipient = recipient
72
73
73
74
@property
74
- def recipients (self ) -> t . List [Recipient ]:
75
+ def recipients (self ) -> list [Recipient [ t . Any ] ]:
75
76
if self .recipient is not None :
76
77
return [self .recipient ]
77
78
return []
@@ -89,14 +90,14 @@ class BaseJSONEncryption(metaclass=ABCMeta):
89
90
#: an optional additional authenticated data
90
91
aad : t .Optional [bytes ]
91
92
#: a list of recipients
92
- recipients : t .List [Recipient ]
93
+ recipients : t .List [Recipient [ t . Any ] ]
93
94
94
95
def __init__ (
95
96
self ,
96
97
protected : Header ,
97
- plaintext : t . Optional [ bytes ] = None ,
98
- unprotected : t . Optional [ Header ] = None ,
99
- aad : t . Optional [ bytes ] = None ):
98
+ plaintext : bytes | None = None ,
99
+ unprotected : Header | None = None ,
100
+ aad : bytes | None = None ):
100
101
self .protected = protected
101
102
self .plaintext = plaintext
102
103
self .unprotected = unprotected
@@ -106,7 +107,7 @@ def __init__(
106
107
self .base64_segments : t .Dict [str , bytes ] = {} # store the encoded segments
107
108
108
109
@abstractmethod
109
- def add_recipient (self , header : t . Optional [ Header ] = None , key : t . Optional [ Key ] = None ):
110
+ def add_recipient (self , header : Header | None = None , key : Key | None = None ) -> None :
110
111
"""Add a recipient to the JWE JSON Serialization. Please add a key that
111
112
comply with the "alg" to this recipient.
112
113
@@ -131,7 +132,7 @@ class GeneralJSONEncryption(BaseJSONEncryption):
131
132
"""
132
133
flattened = False
133
134
134
- def add_recipient (self , header : t . Optional [ Header ] = None , key : t . Optional [ Key ] = None ):
135
+ def add_recipient (self , header : Header | None = None , key : Key | None = None ) -> None :
135
136
recipient = Recipient (self , header , key )
136
137
self .recipients .append (recipient )
137
138
@@ -152,7 +153,7 @@ class FlattenedJSONEncryption(BaseJSONEncryption):
152
153
"""
153
154
flattened = True
154
155
155
- def add_recipient (self , header : t . Optional [ Header ] = None , key : t . Optional [ Key ] = None ):
156
+ def add_recipient (self , header : Header | None = None , key : Key | None = None ) -> None :
156
157
self .recipients = [Recipient (self , header , key )]
157
158
158
159
@@ -178,7 +179,7 @@ def check_iv(self, iv: bytes) -> bytes:
178
179
return iv
179
180
180
181
@abstractmethod
181
- def encrypt (self , plaintext : bytes , cek : bytes , iv : bytes , aad : bytes ) -> t . Tuple [bytes , bytes ]:
182
+ def encrypt (self , plaintext : bytes , cek : bytes , iv : bytes , aad : bytes ) -> tuple [bytes , bytes ]:
182
183
pass
183
184
184
185
@abstractmethod
@@ -216,19 +217,19 @@ class KeyManagement:
216
217
def direct_mode (self ) -> bool :
217
218
return self .key_size is None
218
219
219
- def check_key_type (self , key : Key ):
220
+ def check_key_type (self , key : Key ) -> None :
220
221
if key .key_type not in self .key_types :
221
222
raise InvalidKeyTypeError ()
222
223
223
- def prepare_recipient_header (self , recipient : Recipient ) :
224
+ def prepare_recipient_header (self , recipient : Recipient [ t . Any ]) -> None :
224
225
raise NotImplementedError ()
225
226
226
227
227
228
class JWEDirectEncryption (KeyManagement , metaclass = ABCMeta ):
228
229
key_types = ["oct" ]
229
230
230
231
@abstractmethod
231
- def compute_cek (self , size : int , recipient : Recipient ) -> bytes :
232
+ def compute_cek (self , size : int , recipient : Recipient [ OctKey ] ) -> bytes :
232
233
pass
233
234
234
235
@@ -238,11 +239,11 @@ def direct_mode(self) -> bool:
238
239
return False
239
240
240
241
@abstractmethod
241
- def encrypt_cek (self , cek : bytes , recipient : Recipient ) -> bytes :
242
+ def encrypt_cek (self , cek : bytes , recipient : Recipient [ t . Any ] ) -> bytes :
242
243
pass
243
244
244
245
@abstractmethod
245
- def decrypt_cek (self , recipient : Recipient ) -> bytes :
246
+ def decrypt_cek (self , recipient : Recipient [ t . Any ] ) -> bytes :
246
247
pass
247
248
248
249
@@ -254,7 +255,7 @@ class JWEKeyWrapping(KeyManagement, metaclass=ABCMeta):
254
255
def direct_mode (self ) -> bool :
255
256
return False
256
257
257
- def check_op_key (self , op_key : bytes ):
258
+ def check_op_key (self , op_key : bytes ) -> None :
258
259
if len (op_key ) * 8 != self .key_size :
259
260
raise InvalidKeyLengthError (f"A key of size { self .key_size } bits MUST be used" )
260
261
@@ -267,11 +268,11 @@ def unwrap_cek(self, ek: bytes, key: bytes) -> bytes:
267
268
pass
268
269
269
270
@abstractmethod
270
- def encrypt_cek (self , cek : bytes , recipient : Recipient ) -> bytes :
271
+ def encrypt_cek (self , cek : bytes , recipient : Recipient [ OctKey ] ) -> bytes :
271
272
pass
272
273
273
274
@abstractmethod
274
- def decrypt_cek (self , recipient : Recipient ) -> bytes :
275
+ def decrypt_cek (self , recipient : Recipient [ OctKey ] ) -> bytes :
275
276
pass
276
277
277
278
@@ -280,7 +281,7 @@ class JWEKeyAgreement(KeyManagement, metaclass=ABCMeta):
280
281
tag_aware : bool = False
281
282
key_wrapping : t .Optional [JWEKeyWrapping ]
282
283
283
- def prepare_ephemeral_key (self , recipient : Recipient [ECKey ]):
284
+ def prepare_ephemeral_key (self , recipient : Recipient [ECKey ]) -> None :
284
285
recipient_key = recipient .recipient_key
285
286
assert recipient_key is not None
286
287
self .check_key_type (recipient_key )
0 commit comments