Skip to content

Commit 9854013

Browse files
committed
feat(types): use strict type hint
1 parent 3aed632 commit 9854013

File tree

9 files changed

+82
-70
lines changed

9 files changed

+82
-70
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ exclude_lines = [
7272
]
7373

7474
[tool.mypy]
75+
strict = true
7576
python_version = "3.8"
7677
files = ["src/joserfc"]
7778
show_error_codes = true

src/joserfc/jws.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,7 @@ def serialize_json(
227227
if registry is None:
228228
registry = construct_registry(algorithms)
229229

230-
def find_key(obj: Any):
230+
def find_key(obj: Any) -> Key:
231231
return guess_key(private_key, obj, True)
232232

233233
_payload = to_bytes(payload)
@@ -271,7 +271,7 @@ def deserialize_json(
271271
if registry is None:
272272
registry = construct_registry(algorithms)
273273

274-
def find_key(obj: Any):
274+
def find_key(obj: Any) -> Key:
275275
return guess_key(public_key, obj)
276276

277277
if "signatures" in value:

src/joserfc/rfc7516/json.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,8 @@ def extract_flattened_json(data: FlattenedJSONSerialization) -> FlattenedJSONEnc
9393

9494

9595
def __extract_segments(
96-
data: t.Union[GeneralJSONSerialization, FlattenedJSONSerialization]): # type: ignore[no-untyped-def]
96+
data: t.Union[GeneralJSONSerialization, FlattenedJSONSerialization]
97+
) -> t.Tuple[t.Dict[str, bytes], t.Dict[str, bytes], t.Optional[bytes]]:
9798
base64_segments: t.Dict[str, bytes] = {
9899
"iv": to_bytes(data["iv"]),
99100
"ciphertext": to_bytes(data["ciphertext"]),

src/joserfc/rfc7516/models.py

Lines changed: 27 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1+
from __future__ import annotations
12
import os
23
import typing as t
34
from abc import ABCMeta, abstractmethod
45
from ..registry import Header, HeaderRegistryDict
56
from ..errors import InvalidKeyTypeError, InvalidKeyLengthError
6-
from .._keys import Key, ECKey
7+
from .._keys import Key, ECKey, OctKey
78

89
KeyType = t.TypeVar("KeyType")
910

@@ -12,8 +13,8 @@ class Recipient(t.Generic[KeyType]):
1213
def __init__(
1314
self,
1415
parent: t.Union["CompactEncryption", "GeneralJSONEncryption", "FlattenedJSONEncryption"],
15-
header: t.Optional[Header] = None,
16-
recipient_key: t.Optional[KeyType] = None):
16+
header: Header | None = None,
17+
recipient_key: KeyType | None = None):
1718
self.__parent = parent
1819
self.header = header
1920
self.recipient_key = recipient_key
@@ -30,35 +31,35 @@ def headers(self) -> Header:
3031
rv.update(self.header)
3132
return rv
3233

33-
def add_header(self, k: str, v: t.Any):
34+
def add_header(self, k: str, v: t.Any) -> None:
3435
if isinstance(self.__parent, CompactEncryption):
3536
self.__parent.protected.update({k: v})
3637
elif self.header:
3738
self.header.update({k: v})
3839
else:
3940
self.header = {k: v}
4041

41-
def set_kid(self, kid: str):
42+
def set_kid(self, kid: str) -> None:
4243
self.add_header("kid", kid)
4344

4445

4546
class CompactEncryption:
4647
"""An object to represent the JWE Compact Serialization. It is usually returned by
4748
``decrypt_compact`` method.
4849
"""
49-
def __init__(self, protected: Header, plaintext: t.Optional[bytes] = None):
50+
def __init__(self, protected: Header, plaintext: bytes | None = None):
5051
#: protected header in dict
5152
self.protected = protected
5253
#: the plaintext in bytes
5354
self.plaintext = plaintext
54-
self.recipient: t.Optional[Recipient] = None
55+
self.recipient: Recipient[t.Any] | None = None
5556
self.bytes_segments: t.Dict[str, bytes] = {} # store the decoded segments
5657
self.base64_segments: t.Dict[str, bytes] = {} # store the encoded segments
5758

5859
def headers(self) -> Header:
5960
return self.protected
6061

61-
def attach_recipient(self, key: Key, header: t.Optional[Header] = None):
62+
def attach_recipient(self, key: Key, header: Header | None = None) -> None:
6263
"""Add a recipient to the JWE Compact Serialization. Please add a key that
6364
comply with the given "alg" value.
6465
@@ -71,7 +72,7 @@ def attach_recipient(self, key: Key, header: t.Optional[Header] = None):
7172
self.recipient = recipient
7273

7374
@property
74-
def recipients(self) -> t.List[Recipient]:
75+
def recipients(self) -> list[Recipient[t.Any]]:
7576
if self.recipient is not None:
7677
return [self.recipient]
7778
return []
@@ -89,14 +90,14 @@ class BaseJSONEncryption(metaclass=ABCMeta):
8990
#: an optional additional authenticated data
9091
aad: t.Optional[bytes]
9192
#: a list of recipients
92-
recipients: t.List[Recipient]
93+
recipients: t.List[Recipient[t.Any]]
9394

9495
def __init__(
9596
self,
9697
protected: Header,
97-
plaintext: t.Optional[bytes] = None,
98-
unprotected: t.Optional[Header] = None,
99-
aad: t.Optional[bytes] = None):
98+
plaintext: bytes | None = None,
99+
unprotected: Header | None = None,
100+
aad: bytes | None = None):
100101
self.protected = protected
101102
self.plaintext = plaintext
102103
self.unprotected = unprotected
@@ -106,7 +107,7 @@ def __init__(
106107
self.base64_segments: t.Dict[str, bytes] = {} # store the encoded segments
107108

108109
@abstractmethod
109-
def add_recipient(self, header: t.Optional[Header] = None, key: t.Optional[Key] = None):
110+
def add_recipient(self, header: Header | None = None, key: Key | None = None) -> None:
110111
"""Add a recipient to the JWE JSON Serialization. Please add a key that
111112
comply with the "alg" to this recipient.
112113
@@ -131,7 +132,7 @@ class GeneralJSONEncryption(BaseJSONEncryption):
131132
"""
132133
flattened = False
133134

134-
def add_recipient(self, header: t.Optional[Header] = None, key: t.Optional[Key] = None):
135+
def add_recipient(self, header: Header | None = None, key: Key | None = None) -> None:
135136
recipient = Recipient(self, header, key)
136137
self.recipients.append(recipient)
137138

@@ -152,7 +153,7 @@ class FlattenedJSONEncryption(BaseJSONEncryption):
152153
"""
153154
flattened = True
154155

155-
def add_recipient(self, header: t.Optional[Header] = None, key: t.Optional[Key] = None):
156+
def add_recipient(self, header: Header | None = None, key: Key | None = None) -> None:
156157
self.recipients = [Recipient(self, header, key)]
157158

158159

@@ -178,7 +179,7 @@ def check_iv(self, iv: bytes) -> bytes:
178179
return iv
179180

180181
@abstractmethod
181-
def encrypt(self, plaintext: bytes, cek: bytes, iv: bytes, aad: bytes) -> t.Tuple[bytes, bytes]:
182+
def encrypt(self, plaintext: bytes, cek: bytes, iv: bytes, aad: bytes) -> tuple[bytes, bytes]:
182183
pass
183184

184185
@abstractmethod
@@ -216,19 +217,19 @@ class KeyManagement:
216217
def direct_mode(self) -> bool:
217218
return self.key_size is None
218219

219-
def check_key_type(self, key: Key):
220+
def check_key_type(self, key: Key) -> None:
220221
if key.key_type not in self.key_types:
221222
raise InvalidKeyTypeError()
222223

223-
def prepare_recipient_header(self, recipient: Recipient):
224+
def prepare_recipient_header(self, recipient: Recipient[t.Any]) -> None:
224225
raise NotImplementedError()
225226

226227

227228
class JWEDirectEncryption(KeyManagement, metaclass=ABCMeta):
228229
key_types = ["oct"]
229230

230231
@abstractmethod
231-
def compute_cek(self, size: int, recipient: Recipient) -> bytes:
232+
def compute_cek(self, size: int, recipient: Recipient[OctKey]) -> bytes:
232233
pass
233234

234235

@@ -238,11 +239,11 @@ def direct_mode(self) -> bool:
238239
return False
239240

240241
@abstractmethod
241-
def encrypt_cek(self, cek: bytes, recipient: Recipient) -> bytes:
242+
def encrypt_cek(self, cek: bytes, recipient: Recipient[t.Any]) -> bytes:
242243
pass
243244

244245
@abstractmethod
245-
def decrypt_cek(self, recipient: Recipient) -> bytes:
246+
def decrypt_cek(self, recipient: Recipient[t.Any]) -> bytes:
246247
pass
247248

248249

@@ -254,7 +255,7 @@ class JWEKeyWrapping(KeyManagement, metaclass=ABCMeta):
254255
def direct_mode(self) -> bool:
255256
return False
256257

257-
def check_op_key(self, op_key: bytes):
258+
def check_op_key(self, op_key: bytes) -> None:
258259
if len(op_key) * 8 != self.key_size:
259260
raise InvalidKeyLengthError(f"A key of size {self.key_size} bits MUST be used")
260261

@@ -267,11 +268,11 @@ def unwrap_cek(self, ek: bytes, key: bytes) -> bytes:
267268
pass
268269

269270
@abstractmethod
270-
def encrypt_cek(self, cek: bytes, recipient: Recipient) -> bytes:
271+
def encrypt_cek(self, cek: bytes, recipient: Recipient[OctKey]) -> bytes:
271272
pass
272273

273274
@abstractmethod
274-
def decrypt_cek(self, recipient: Recipient) -> bytes:
275+
def decrypt_cek(self, recipient: Recipient[OctKey]) -> bytes:
275276
pass
276277

277278

@@ -280,7 +281,7 @@ class JWEKeyAgreement(KeyManagement, metaclass=ABCMeta):
280281
tag_aware: bool = False
281282
key_wrapping: t.Optional[JWEKeyWrapping]
282283

283-
def prepare_ephemeral_key(self, recipient: Recipient[ECKey]):
284+
def prepare_ephemeral_key(self, recipient: Recipient[ECKey]) -> None:
284285
recipient_key = recipient.recipient_key
285286
assert recipient_key is not None
286287
self.check_key_type(recipient_key)

src/joserfc/rfc7516/registry.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -51,12 +51,12 @@ def __init__(
5151
self.strict_check_header = strict_check_header
5252

5353
@classmethod
54-
def register(cls, model: JWEAlgorithm):
54+
def register(cls, model: JWEAlgorithm) -> None:
5555
cls.algorithms[model.algorithm_location][model.name] = model # type: ignore
5656
if model.recommended:
5757
cls.recommended.append(model.name)
5858

59-
def check_header(self, header: Header, check_more=False):
59+
def check_header(self, header: Header, check_more: bool = False) -> None:
6060
"""Check and validate the fields in header part of a JWS object."""
6161
check_crit_header(header)
6262
validate_registry_header(self.header_registry, header)
@@ -77,24 +77,29 @@ def get_alg(self, name: str) -> JWEAlgModel:
7777
7878
:param name: value of the ``alg``, e.g. ``ECDH-ES``, ``A128KW``
7979
"""
80-
return self._get_algorithm("alg", name)
80+
registry = self.algorithms["alg"]
81+
self._check_algorithm(name, registry)
82+
return registry[name]
8183

8284
def get_enc(self, name: str) -> JWEEncModel:
8385
"""Get the allowed ("enc") algorithm instance of the given name.
8486
8587
:param name: value of the ``enc``, e.g. ``A128CBC-HS256``, ``A128GCM``
8688
"""
87-
return self._get_algorithm("enc", name)
89+
registry = self.algorithms["enc"]
90+
self._check_algorithm(name, registry)
91+
return registry[name]
8892

8993
def get_zip(self, name: str) -> JWEZipModel:
9094
"""Get the allowed ("zip") algorithm instance of the given name.
9195
9296
:param name: value of the ``zip``, e.g. ``DEF``
9397
"""
94-
return self._get_algorithm("zip", name)
98+
registry = self.algorithms["zip"]
99+
self._check_algorithm(name, registry)
100+
return registry[name]
95101

96-
def _get_algorithm(self, location: t.Literal["alg", "enc", "zip"], name: str):
97-
registry: t.Dict[str, JWEAlgorithm] = self.algorithms[location] # type: ignore
102+
def _check_algorithm(self, name: str, registry: dict[str, t.Any]) -> None:
98103
if name not in registry:
99104
raise ValueError(f'Algorithm of "{name}" is not supported')
100105

@@ -105,7 +110,6 @@ def _get_algorithm(self, location: t.Literal["alg", "enc", "zip"], name: str):
105110

106111
if name not in allowed:
107112
raise ValueError(f'Algorithm of "{name}" is not allowed')
108-
return registry[name]
109113

110114

111115
default_registry = JWERegistry()

src/joserfc/rfc7517/models.py

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1+
from __future__ import annotations
12
import typing as t
2-
from typing import overload
3+
from collections.abc import KeysView
34
from abc import ABCMeta, abstractmethod
45
from .types import DictKey, AnyKey, KeyParameters
56
from ..registry import (
@@ -82,7 +83,7 @@ class BaseKey(t.Generic[NativePrivateKey, NativePublicKey], metaclass=ABCMeta):
8283

8384
def __init__(
8485
self,
85-
raw_value: t.Union[NativePrivateKey, NativePublicKey],
86+
raw_value: NativePrivateKey | NativePublicKey,
8687
original_value: t.Any,
8788
parameters: t.Optional[KeyParameters] = None):
8889
self._raw_value = raw_value
@@ -97,13 +98,13 @@ def __init__(
9798
self.validate_dict_key(data)
9899
self._dict_value = data
99100

100-
def keys(self):
101+
def keys(self) -> KeysView[str]:
101102
return self.dict_value.keys()
102103

103-
def __getitem__(self, k: str):
104+
def __getitem__(self, k: str) -> str | list[str]:
104105
return self.dict_value[k]
105106

106-
def get(self, k: str, default=None):
107+
def get(self, k: str, default: str | None = None) -> str | list[str] | None:
107108
return self.dict_value.get(k, default)
108109

109110
def ensure_kid(self) -> None:
@@ -114,17 +115,17 @@ def ensure_kid(self) -> None:
114115
self._dict_value["kid"] = self.thumbprint()
115116

116117
@property
117-
def kid(self) -> t.Optional[str]:
118+
def kid(self) -> str | None:
118119
"""The "kid" value of the JSON Web Key."""
119-
return self.get("kid")
120+
return t.cast(t.Optional[str], self.get("kid"))
120121

121122
@property
122-
def alg(self) -> t.Optional[str]:
123+
def alg(self) -> str | None:
123124
"""The "alg" value of the JSON Web Key."""
124-
return self.get("alg")
125+
return t.cast(t.Optional[str], self.get("alg"))
125126

126127
@property
127-
def raw_value(self):
128+
def raw_value(self) -> t.Any:
128129
raise NotImplementedError()
129130

130131
@property
@@ -220,13 +221,13 @@ def check_key_op(self, operation: str) -> None:
220221
if reg.private and not self.is_private:
221222
raise UnsupportedKeyOperationError(f'Invalid key_op "{operation}" for public key')
222223

223-
@overload
224+
@t.overload
224225
def get_op_key(self, operation: t.Literal["verify", "encrypt", "wrapKey", "deriveKey"]) -> NativePublicKey: ...
225226

226-
@overload
227+
@t.overload
227228
def get_op_key(self, operation: t.Literal["sign", "decrypt", "unwrapKey"]) -> NativePrivateKey: ...
228229

229-
def get_op_key(self, operation: str) -> t.Union[NativePublicKey, NativePrivateKey]:
230+
def get_op_key(self, operation: str) -> NativePublicKey | NativePrivateKey:
230231
self.check_key_op(operation)
231232
reg = self.operation_registry[operation]
232233
if reg.private:
@@ -235,7 +236,7 @@ def get_op_key(self, operation: str) -> t.Union[NativePublicKey, NativePrivateKe
235236
return self.public_key
236237

237238
@classmethod
238-
def validate_dict_key(cls, data: DictKey):
239+
def validate_dict_key(cls, data: DictKey) -> None:
239240
cls.binding.validate_dict_key_registry(data, cls.param_registry)
240241
cls.binding.validate_dict_key_registry(data, cls.value_registry)
241242
cls.binding.validate_dict_key_use_operations(data)
@@ -257,7 +258,7 @@ def import_key(
257258
@classmethod
258259
def generate_key(
259260
cls: t.Type[GenericKey],
260-
size_or_crv,
261+
size_or_crv: t.Any,
261262
parameters: t.Optional[KeyParameters] = None,
262263
private: bool = True,
263264
auto_kid: bool = False) -> GenericKey:
@@ -312,5 +313,5 @@ def curve_name(self) -> str:
312313
pass
313314

314315
@abstractmethod
315-
def exchange_derive_key(self, key) -> bytes:
316+
def exchange_derive_key(self, key: t.Any) -> bytes:
316317
pass

0 commit comments

Comments
 (0)