Skip to content

Commit

Permalink
fix: improve type hints for mypy
Browse files Browse the repository at this point in the history
  • Loading branch information
lepture committed Jul 30, 2023
1 parent 3d3b595 commit 5d77591
Show file tree
Hide file tree
Showing 16 changed files with 81 additions and 87 deletions.
15 changes: 15 additions & 0 deletions src/joserfc/_keys.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import typing as t
from .rfc7518.oct_key import OctKey
from .rfc7518.rsa_key import RSAKey
from .rfc7518.ec_key import ECKey
from .rfc8037.okp_key import OKPKey

Key = t.Union[OctKey, RSAKey, ECKey, OKPKey]


__all__ = [
"OctKey",
"RSAKey",
"ECKey",
"OKPKey",
]
16 changes: 8 additions & 8 deletions src/joserfc/drafts/jwe_ecdh_1pu.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@
JWEKeyWrapping,
JWEEncModel
)
from ..rfc7517.models import CurveKey
from ..rfc7518.jwe_algs import (
A128KW,
A192KW,
A256KW,
)
from ..rfc7518.ec_key import ECKey
from ..rfc7518.derive_key import (
derive_key_for_concat_kdf,
)
Expand Down Expand Up @@ -55,32 +55,32 @@ def _check_enc(self, enc: JWEEncModel) -> None:
)
raise InvalidEncryptionAlgorithmError(description)

def encrypt_agreed_upon_key(self, enc: JWEEncModel, recipient: Recipient[CurveKey]) -> bytes:
def encrypt_agreed_upon_key(self, enc: JWEEncModel, recipient: Recipient[ECKey]) -> bytes:
self._check_enc(enc)
return self.__encrypt_agreed_upon_key(enc, recipient, None)

def encrypt_agreed_upon_key_with_tag(
self,
enc: JWEEncModel,
recipient: Recipient[CurveKey],
recipient: Recipient[ECKey],
tag: bytes) -> bytes:
self._check_enc(enc)
return self.__encrypt_agreed_upon_key(enc, recipient, tag)

def decrypt_agreed_upon_key(self, enc: JWEEncModel, recipient: Recipient[CurveKey]) -> bytes:
def decrypt_agreed_upon_key(self, enc: JWEEncModel, recipient: Recipient[ECKey]) -> bytes:
return self.__decrypt_agreed_upon_key(enc, recipient, None)

def decrypt_agreed_upon_key_with_tag(
self,
enc: JWEEncModel,
recipient: Recipient[CurveKey],
recipient: Recipient[ECKey],
tag: bytes) -> bytes:
return self.__decrypt_agreed_upon_key(enc, recipient, tag)

def __encrypt_agreed_upon_key(
self,
enc: JWEEncModel,
recipient: Recipient[CurveKey],
recipient: Recipient[ECKey],
tag: t.Optional[bytes]) -> bytes:
sender_key = recipient.sender_key
recipient_key = recipient.recipient_key
Expand All @@ -98,7 +98,7 @@ def __encrypt_agreed_upon_key(
def __decrypt_agreed_upon_key(
self,
enc: JWEEncModel,
recipient: Recipient[CurveKey],
recipient: Recipient[ECKey],
tag: t.Optional[bytes]) -> bytes:

self._check_enc(enc)
Expand All @@ -110,7 +110,7 @@ def __decrypt_agreed_upon_key(
assert sender_key is not None
assert recipient_key is not None

ephemeral_key: CurveKey = recipient_key.import_key(headers["epk"]) # type: ignore[assignment]
ephemeral_key = recipient_key.import_key(headers["epk"])
sender_shared_key = recipient_key.exchange_derive_key(sender_key)
ephemeral_shared_key = recipient_key.exchange_derive_key(ephemeral_key)
shared_key = ephemeral_shared_key + sender_shared_key
Expand Down
4 changes: 2 additions & 2 deletions src/joserfc/jwe.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def register_algorithms() -> None:

def encrypt_compact(
protected: Header,
plaintext: t.AnyStr,
plaintext: t.Union[bytes, str],
public_key: KeyFlexible,
algorithms: t.Optional[t.List[str]] = None,
registry: t.Optional[JWERegistry] = None,
Expand Down Expand Up @@ -111,7 +111,7 @@ def encrypt_compact(


def decrypt_compact(
value: t.AnyStr,
value: t.Union[bytes, str],
private_key: KeyFlexible,
algorithms: t.Optional[t.List[str]] = None,
registry: t.Optional[JWERegistry] = None,
Expand Down
23 changes: 6 additions & 17 deletions src/joserfc/jwk.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,6 @@
import typing as t
from .rfc7517 import (
BaseKey,
SymmetricKey,
AsymmetricKey,
CurveKey,
JWKRegistry as _JWKRegistry,
KeySet as _KeySet,
)
from .rfc7517 import types
from .rfc7517.registry import JWKRegistry as _JWKRegistry
from .rfc7517.keyset import KeySet as _KeySet
from .rfc7518.oct_key import OctKey
from .rfc7518.rsa_key import RSAKey
from .rfc7518.ec_key import ECKey
Expand All @@ -17,12 +10,7 @@


__all__ = [
"types",
"JWKRegistry",
"BaseKey",
"SymmetricKey",
"AsymmetricKey",
"CurveKey",
"Key",
"KeyCallable",
"KeyFlexible",
Expand Down Expand Up @@ -87,11 +75,12 @@ def guess_key(key: KeyFlexible, obj: GuestProtocol) -> Key:
"""
headers = obj.headers()

rv_key: Key
if isinstance(key, (str, bytes)):
rv_key = OctKey.import_key(key) # type: ignore
rv_key = OctKey.import_key(key)

elif isinstance(key, BaseKey):
rv_key: Key = key # type: ignore
elif isinstance(key, (OctKey, RSAKey, ECKey, OKPKey)):
rv_key = key

elif isinstance(key, KeySet):
kid = headers.get("kid")
Expand Down
10 changes: 5 additions & 5 deletions src/joserfc/jws.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def register_algorithms() -> None:

def serialize_compact(
protected: Header,
payload: t.AnyStr,
payload: t.Union[bytes, str],
private_key: KeyFlexible,
algorithms: t.Optional[t.List[str]] = None,
registry: t.Optional[JWSRegistry] = None) -> str:
Expand Down Expand Up @@ -144,7 +144,7 @@ def validate_compact(


def deserialize_compact(
value: t.AnyStr,
value: t.Union[bytes, str],
public_key: KeyFlexible,
algorithms: t.Optional[t.List[str]] = None,
registry: t.Optional[JWSRegistry] = None) -> CompactSignature:
Expand Down Expand Up @@ -176,7 +176,7 @@ def deserialize_compact(
@overload
def serialize_json(
members: t.List[HeaderDict],
payload: t.AnyStr,
payload: t.Union[bytes, str],
private_key: KeyFlexible,
algorithms: t.Optional[t.List[str]] = None,
registry: t.Optional[JWSRegistry] = None) -> GeneralJSONSerialization: ...
Expand All @@ -185,15 +185,15 @@ def serialize_json(
@overload
def serialize_json(
members: HeaderDict,
payload: t.AnyStr,
payload: t.Union[bytes, str],
private_key: KeyFlexible,
algorithms: t.Optional[t.List[str]] = None,
registry: t.Optional[JWSRegistry] = None) -> FlattenedJSONSerialization: ...


def serialize_json(
members: t.Union[HeaderDict, t.List[HeaderDict]],
payload: t.AnyStr,
payload: t.Union[bytes, str],
private_key: KeyFlexible,
algorithms: t.Optional[t.List[str]] = None,
registry: t.Optional[JWSRegistry] = None,
Expand Down
2 changes: 1 addition & 1 deletion src/joserfc/jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def encode(


def decode(
value: t.AnyStr,
value: t.Union[bytes, str],
key: KeyFlexible,
algorithms: t.Optional[t.List[str]] = None,
registry: t.Optional[JWTRegistry] = None) -> Token:
Expand Down
3 changes: 2 additions & 1 deletion src/joserfc/rfc7515/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ def sign_flattened_json(
find_key: FindKey) -> FlattenedJSONSerialization:
payload_segment = urlsafe_b64encode(payload)
signature = __sign_member(payload_segment, HeaderMember(**member), registry, find_key)
return {"payload": payload_segment.decode("utf-8"), **signature} # type: ignore
data = {"payload": payload_segment.decode("utf-8"), **signature}
return data # type: ignore[return-value]


def __sign_member(
Expand Down
4 changes: 2 additions & 2 deletions src/joserfc/rfc7516/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def represent_general_json(obj: GeneralJSONEncryption) -> GeneralJSONSerializati
item["encrypted_key"] = to_str(urlsafe_b64encode(recipient.encrypted_key))
recipients.append(item)
data["recipients"] = recipients
return data # type: ignore[assignment]
return data


def represent_flattened_json(obj: FlattenedJSONEncryption) -> FlattenedJSONSerialization:
Expand All @@ -44,7 +44,7 @@ def represent_flattened_json(obj: FlattenedJSONEncryption) -> FlattenedJSONSeria
data["header"] = recipient.header
if recipient.encrypted_key:
data["encrypted_key"] = to_str(urlsafe_b64encode(recipient.encrypted_key))
return data # type: ignore
return data


def __represent_json_serialization(obj: BaseJSONEncryption):
Expand Down
29 changes: 14 additions & 15 deletions src/joserfc/rfc7516/models.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import os
import typing as t
from abc import ABCMeta, abstractmethod
from ..rfc7517.models import BaseKey, CurveKey
from ..registry import Header, HeaderRegistryDict
from ..errors import InvalidKeyTypeError, InvalidKeyLengthError
from .._keys import Key, ECKey

KeyType = t.TypeVar("KeyType", bound=BaseKey)
KeyType = t.TypeVar("KeyType")


class Recipient(t.Generic[KeyType]):
Expand All @@ -19,7 +19,7 @@ def __init__(
self.recipient_key = recipient_key
self.sender_key: t.Optional[KeyType] = None
self.encrypted_key: t.Optional[bytes] = None
self.ephemeral_key: t.Optional[CurveKey] = None
self.ephemeral_key: t.Optional[KeyType] = None

def headers(self) -> Header:
rv: Header = {}
Expand Down Expand Up @@ -58,7 +58,7 @@ def __init__(self, protected: Header, plaintext: t.Optional[bytes] = None):
def headers(self) -> Header:
return self.protected

def attach_recipient(self, key: BaseKey, header: t.Optional[Header] = None):
def attach_recipient(self, key: Key, header: t.Optional[Header] = None):
"""Add a recipient to the JWE Compact Serialization. Please add a key that
comply with the given "alg" value.
Expand Down Expand Up @@ -106,7 +106,7 @@ def __init__(
self.base64_segments: t.Dict[str, bytes] = {} # store the encoded segments

@abstractmethod
def add_recipient(self, header: t.Optional[Header] = None, key: t.Optional[BaseKey] = None):
def add_recipient(self, header: t.Optional[Header] = None, key: t.Optional[Key] = None):
"""Add a recipient to the JWE JSON Serialization. Please add a key that
comply with the "alg" to this recipient.
Expand All @@ -131,7 +131,7 @@ class GeneralJSONEncryption(BaseJSONEncryption):
"""
flattened = False

def add_recipient(self, header: t.Optional[Header] = None, key: t.Optional[BaseKey] = None):
def add_recipient(self, header: t.Optional[Header] = None, key: t.Optional[Key] = None):
recipient = Recipient(self, header, key)
self.recipients.append(recipient)

Expand All @@ -152,7 +152,7 @@ class FlattenedJSONEncryption(BaseJSONEncryption):
"""
flattened = True

def add_recipient(self, header: t.Optional[Header] = None, key: t.Optional[BaseKey] = None):
def add_recipient(self, header: t.Optional[Header] = None, key: t.Optional[Key] = None):
self.recipients = [Recipient(self, header, key)]


Expand Down Expand Up @@ -216,7 +216,7 @@ class KeyManagement:
def direct_mode(self) -> bool:
return self.key_size is None

def check_key_type(self, key: BaseKey):
def check_key_type(self, key: Key):
if key.key_type not in self.key_types:
raise InvalidKeyTypeError()

Expand Down Expand Up @@ -280,22 +280,21 @@ class JWEKeyAgreement(KeyManagement, metaclass=ABCMeta):
tag_aware: bool = False
key_wrapping: t.Optional[JWEKeyWrapping]

def prepare_ephemeral_key(self, recipient: Recipient[CurveKey]):
def prepare_ephemeral_key(self, recipient: Recipient[ECKey]):
recipient_key = recipient.recipient_key
assert recipient_key is not None
self.check_key_type(recipient_key)
if recipient.ephemeral_key is None:
ephemeral_key: CurveKey = recipient_key.generate_key(
recipient_key.curve_name, private=True) # type: ignore
ephemeral_key = recipient_key.generate_key(recipient_key.curve_name, private=True)
recipient.ephemeral_key = ephemeral_key
recipient.add_header("epk", recipient.ephemeral_key.as_dict(private=False))

@abstractmethod
def encrypt_agreed_upon_key(self, enc: JWEEncModel, recipient: Recipient[CurveKey]) -> bytes:
def encrypt_agreed_upon_key(self, enc: JWEEncModel, recipient: Recipient[ECKey]) -> bytes:
pass

@abstractmethod
def decrypt_agreed_upon_key(self, enc: JWEEncModel, recipient: Recipient[CurveKey]) -> bytes:
def decrypt_agreed_upon_key(self, enc: JWEEncModel, recipient: Recipient[ECKey]) -> bytes:
pass

def wrap_cek_with_auk(self, cek: bytes, key: bytes) -> bytes:
Expand All @@ -306,10 +305,10 @@ def unwrap_cek_with_auk(self, ek: bytes, key: bytes) -> bytes:
assert self.key_wrapping is not None
return self.key_wrapping.unwrap_cek(ek, key)

def encrypt_agreed_upon_key_with_tag(self, enc: JWEEncModel, recipient: Recipient[CurveKey], tag: bytes) -> bytes:
def encrypt_agreed_upon_key_with_tag(self, enc: JWEEncModel, recipient: Recipient[ECKey], tag: bytes) -> bytes:
raise NotImplementedError()

def decrypt_agreed_upon_key_with_tag(self, enc: JWEEncModel, recipient: Recipient[CurveKey], tag: bytes) -> bytes:
def decrypt_agreed_upon_key_with_tag(self, enc: JWEEncModel, recipient: Recipient[ECKey], tag: bytes) -> bytes:
raise NotImplementedError()


Expand Down
12 changes: 0 additions & 12 deletions src/joserfc/rfc7517/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +0,0 @@
from .models import BaseKey, SymmetricKey, AsymmetricKey, CurveKey
from .registry import JWKRegistry
from .keyset import KeySet

__all__ = [
"BaseKey",
"SymmetricKey",
"AsymmetricKey",
"CurveKey",
"JWKRegistry",
"KeySet",
]
11 changes: 6 additions & 5 deletions src/joserfc/rfc7517/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
)


GenericKey = t.TypeVar("GenericKey", bound="BaseKey")
NativePrivateKey = t.TypeVar("NativePrivateKey")
NativePublicKey = t.TypeVar("NativePublicKey")

Expand All @@ -43,7 +44,7 @@ def import_from_bytes(cls, value: bytes, password: t.Optional[t.Any] = None):
pass

@staticmethod
def as_bytes(key: "BaseKey", encoding=None, private=None, password=None) -> bytes: # pragma: no cover
def as_bytes(key: GenericKey, encoding=None, private=None, password=None) -> bytes: # pragma: no cover
return key.raw_value

@classmethod
Expand Down Expand Up @@ -228,10 +229,10 @@ def validate_dict_key(cls, data: KeyDict):

@classmethod
def import_key(
cls,
cls: t.Type[GenericKey],
value: KeyAny,
parameters: t.Optional[KeyParameters] = None,
password: t.Optional[t.Any] = None) -> "BaseKey":
password: t.Optional[t.Any] = None) -> GenericKey:
if isinstance(value, dict):
cls.validate_dict_key(value)
raw_key = cls.binding.import_from_dict(value)
Expand All @@ -242,10 +243,10 @@ def import_key(

@classmethod
def generate_key(
cls,
cls: t.Type[GenericKey],
size_or_crv,
parameters: t.Optional[KeyParameters] = None,
private: bool = True) -> "BaseKey":
private: bool = True) -> GenericKey:
raise NotImplementedError()


Expand Down
Loading

0 comments on commit 5d77591

Please sign in to comment.