Skip to content

Commit

Permalink
fix: improve for mypy
Browse files Browse the repository at this point in the history
  • Loading branch information
lepture committed Jul 28, 2023
1 parent 5af2dc6 commit a0ba737
Show file tree
Hide file tree
Showing 13 changed files with 152 additions and 86 deletions.
48 changes: 36 additions & 12 deletions src/joserfc/drafts/jwe_ecdh_1pu.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,38 +55,62 @@ def _check_enc(self, enc: JWEEncModel) -> None:
)
raise InvalidEncryptionAlgorithmError(description)

def encrypt_agreed_upon_key(self, enc: JWEEncModel, recipient: Recipient) -> bytes:
def encrypt_agreed_upon_key(self, enc: JWEEncModel, recipient: Recipient[CurveKey]) -> bytes:
self._check_enc(enc)
return self.__encrypt_agreed_upon_key(enc, recipient, None)

def encrypt_agreed_upon_key_with_tag(self, enc: JWEEncModel, recipient: Recipient, tag: bytes) -> bytes:
def encrypt_agreed_upon_key_with_tag(
self,
enc: JWEEncModel,
recipient: Recipient[CurveKey],
tag: bytes) -> bytes:
self._check_enc(enc)
return self.__encrypt_agreed_upon_key(enc, recipient, tag)

def decrypt_agreed_upon_key(self, enc: JWEEncModel, recipient: Recipient) -> bytes:
def decrypt_agreed_upon_key(self, enc: JWEEncModel, recipient: Recipient[CurveKey]) -> bytes:
return self.__decrypt_agreed_upon_key(enc, recipient, None)

def decrypt_agreed_upon_key_with_tag(self, enc: JWEEncModel, recipient: Recipient, tag: bytes) -> bytes:
def decrypt_agreed_upon_key_with_tag(
self,
enc: JWEEncModel,
recipient: Recipient[CurveKey],
tag: bytes) -> bytes:
return self.__decrypt_agreed_upon_key(enc, recipient, tag)

def __encrypt_agreed_upon_key(self, enc: JWEEncModel, recipient: Recipient, tag: t.Optional[bytes]) -> bytes:
sender_key: CurveKey = recipient.sender_key # type: ignore
recipient_key: CurveKey = recipient.recipient_key # type: ignore
ephemeral_key: CurveKey = recipient.ephemeral_key # type: ignore
def __encrypt_agreed_upon_key(
self,
enc: JWEEncModel,
recipient: Recipient[CurveKey],
tag: t.Optional[bytes]) -> bytes:
sender_key = recipient.sender_key
recipient_key = recipient.recipient_key
ephemeral_key = recipient.ephemeral_key
assert sender_key is not None
assert recipient_key is not None
assert ephemeral_key is not None

sender_shared_key = sender_key.exchange_derive_key(recipient_key)
ephemeral_shared_key = ephemeral_key.exchange_derive_key(recipient_key)
shared_key = ephemeral_shared_key + sender_shared_key
headers = recipient.headers()
return derive_key_for_concat_kdf(shared_key, headers, enc.cek_size, self.key_size, tag)

def __decrypt_agreed_upon_key(self, enc: JWEEncModel, recipient: Recipient, tag: t.Optional[bytes]) -> bytes:
def __decrypt_agreed_upon_key(
self,
enc: JWEEncModel,
recipient: Recipient[CurveKey],
tag: t.Optional[bytes]) -> bytes:

self._check_enc(enc)
headers = recipient.headers()
assert "epk" in headers

recipient_key: CurveKey = recipient.recipient_key # type: ignore
ephemeral_key: CurveKey = recipient_key.import_key(headers["epk"]) # type: ignore
sender_key: CurveKey = recipient.sender_key # type: ignore
sender_key = recipient.sender_key
recipient_key = recipient.recipient_key
assert sender_key is not None
assert recipient_key is not None

ephemeral_key: CurveKey = recipient_key.import_key(headers["epk"]) # type: ignore[assignment]
sender_shared_key = recipient_key.exchange_derive_key(sender_key)
ephemeral_shared_key = recipient_key.exchange_derive_key(ephemeral_key)
shared_key = ephemeral_shared_key + sender_shared_key
Expand Down
4 changes: 2 additions & 2 deletions src/joserfc/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ class InvalidCEKLengthError(JoseError):
class InvalidClaimError(JoseError):
error = "invalid_claim"

def __init__(self, claim):
def __init__(self, claim: str):
description = f'Invalid claim: "{claim}"'
super(InvalidClaimError, self).__init__(description=description)

Expand All @@ -98,7 +98,7 @@ def __init__(self, claim):
class InsecureClaimError(JoseError):
error = "insecure_claim"

def __init__(self, claim):
def __init__(self, claim: str):
description = f'Insecure claim "{claim}"'
super(InsecureClaimError, self).__init__(description=description)

Expand Down
4 changes: 2 additions & 2 deletions src/joserfc/jwe.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from .rfc7518.jwe_algs import JWE_ALG_MODELS
from .rfc7518.jwe_encs import JWE_ENC_MODELS
from .rfc7518.jwe_zips import JWE_ZIP_MODELS
from .jwk import KeySet, CurveKey, KeyFlexible, guess_key
from .jwk import Key, KeySet, CurveKey, KeyFlexible, guess_key
from .util import to_bytes
from .registry import Header

Expand Down Expand Up @@ -93,7 +93,7 @@ def encrypt_compact(
registry = default_registry

obj = CompactEncryption(protected, to_bytes(plaintext))
recipient = Recipient(obj)
recipient: Recipient[Key] = Recipient(obj)
key = guess_key(public_key, recipient)
key.check_use("enc")
recipient.recipient_key = key
Expand Down
33 changes: 25 additions & 8 deletions src/joserfc/jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,29 +85,46 @@ def decode(
:raise: BadSignatureError
"""
_value = to_bytes(value)
_header: Header
header: Header
payload: bytes
if _value.count(b".") == 4:
if registry is not None:
assert isinstance(registry, JWERegistry)
jwe_obj = decrypt_compact(_value, key, algorithms, registry)
_header = jwe_obj.headers()
payload: bytes = jwe_obj.plaintext # type: ignore
header, payload = _decode_jwe(_value, key, algorithms, registry)
else:
if registry is not None:
assert isinstance(registry, JWSRegistry)
jws_obj = deserialize_compact(_value, key, algorithms, registry)
_header = jws_obj.headers()
payload: bytes = jws_obj.payload # type: ignore
header, payload = _decode_jws(_value, key, algorithms, registry)

try:
claims: Claims = json.loads(payload)
except (TypeError, ValueError):
raise InvalidPayloadError()

token = Token(_header, claims)
token = Token(header, claims)
typ = token.header.get("typ")
# https://www.rfc-editor.org/rfc/rfc7519#section-5.1
# If present, it is RECOMMENDED that its value be "JWT".
if typ and typ != "JWT":
raise InvalidTypeError()
return token


def _decode_jwe(
value: bytes,
key: KeyFlexible,
algorithms: t.Optional[t.List[str]] = None,
registry: t.Optional[JWERegistry] = None) -> t.Tuple[Header, bytes]:
jwe_obj = decrypt_compact(value, key, algorithms, registry)
assert jwe_obj.plaintext is not None
return jwe_obj.headers(), jwe_obj.plaintext


def _decode_jws(
value: bytes,
key: KeyFlexible,
algorithms: t.Optional[t.List[str]] = None,
registry: t.Optional[JWSRegistry] = None) -> t.Tuple[Header, bytes]:
jws_obj = deserialize_compact(value, key, algorithms, registry)
assert jws_obj.payload is not None
return jws_obj.headers(), jws_obj.payload
22 changes: 11 additions & 11 deletions src/joserfc/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,42 +3,42 @@
Header = t.Dict[str, t.Any]


def is_str(value: str):
def is_str(value: str) -> None:
if not isinstance(value, str):
raise ValueError("must be a str")


def is_url(value: str):
def is_url(value: str) -> None:
is_str(value)
if not value.startswith(("http://", "https://")):
raise ValueError("must be a URL")


def is_int(value: int):
def is_int(value: int) -> None:
if not isinstance(value, int):
raise ValueError("must be an int")


def is_bool(value: bool):
def is_bool(value: bool) -> None:
if not isinstance(value, bool):
raise ValueError("must be an bool")


def is_list_str(values):
def is_list_str(values: t.List[str]) -> None:
if not isinstance(values, list):
raise ValueError("must be a list[str]")

if not all(isinstance(value, str) for value in values):
raise ValueError("must be a list[str]")


def is_jwk(value):
def is_jwk(value: t.Dict[str, t.Any]) -> None:
if not isinstance(value, dict):
raise ValueError("must be a JWK")


def in_choices(choices: t.List[str]):
def _is_one_of(value):
def _is_one_of(value: t.Union[str, t.List[str]]) -> None:
if isinstance(value, list):
if not all(v in choices for v in value):
raise ValueError(f"must be one of {choices}")
Expand All @@ -49,7 +49,7 @@ def _is_one_of(value):
return _is_one_of


def not_support(_):
def not_support(_) -> None:
raise ValueError("is not supported")


Expand Down Expand Up @@ -170,7 +170,7 @@ def __init__(self, description: str, use: str, private: t.Optional[bool]):
}


def check_supported_header(registry: HeaderRegistryDict, header: Header):
def check_supported_header(registry: HeaderRegistryDict, header: Header) -> None:
allowed_keys = set(registry.keys())
unsupported_keys = set(header.keys()) - allowed_keys
if unsupported_keys:
Expand All @@ -180,7 +180,7 @@ def check_supported_header(registry: HeaderRegistryDict, header: Header):
def validate_registry_header(
registry: HeaderRegistryDict,
header: Header,
check_required: bool = True):
check_required: bool = True) -> None:
for key, reg in registry.items():
if check_required and reg.required and key not in header:
raise ValueError(f'Required "{key}" is missing in header')
Expand All @@ -191,7 +191,7 @@ def validate_registry_header(
raise ValueError(f'"{key}" in header {error}')


def check_crit_header(header: Header):
def check_crit_header(header: Header) -> None:
# check crit header
if "crit" in header:
for k in header["crit"]:
Expand Down
3 changes: 2 additions & 1 deletion src/joserfc/rfc7516/compact.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .models import CompactEncryption, Recipient
from ..rfc7517.models import BaseKey
from ..errors import (
MissingAlgorithmError,
MissingEncryptionError,
Expand Down Expand Up @@ -51,7 +52,7 @@ def extract_compact(value: bytes) -> CompactEncryption:
"ciphertext": urlsafe_b64decode(ciphertext_segment),
"tag": urlsafe_b64decode(tag_segment),
})
recipient = Recipient(obj)
recipient: Recipient[BaseKey] = Recipient(obj)
recipient.encrypted_key = urlsafe_b64decode(ek_segment)
obj.recipient = recipient
return obj
9 changes: 5 additions & 4 deletions src/joserfc/rfc7516/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
GeneralJSONSerialization,
FlattenedJSONSerialization,
)
from ..rfc7517.models import BaseKey
from ..registry import Header
from ..util import (
to_bytes,
Expand All @@ -32,7 +33,7 @@ def represent_general_json(obj: GeneralJSONEncryption) -> GeneralJSONSerializati
item["encrypted_key"] = to_str(urlsafe_b64encode(recipient.encrypted_key))
recipients.append(item)
data["recipients"] = recipients
return data # type: ignore
return data # type: ignore[assignment]


def represent_flattened_json(obj: FlattenedJSONEncryption) -> FlattenedJSONSerialization:
Expand All @@ -47,7 +48,7 @@ def represent_flattened_json(obj: FlattenedJSONEncryption) -> FlattenedJSONSeria


def __represent_json_serialization(obj: BaseJSONEncryption):
data: t.Dict[str, t.Union[str, Header]] = {
data: t.Dict[str, t.Union[str, Header, t.List[Header]]] = {
"protected": to_str(json_b64encode(obj.protected)),
"iv": to_str(obj.base64_segments["iv"]),
"ciphertext": to_str(obj.base64_segments["ciphertext"]),
Expand All @@ -69,7 +70,7 @@ def extract_general_json(data: GeneralJSONSerialization) -> GeneralJSONEncryptio
obj.base64_segments = base64_segments
obj.bytes_segments = bytes_segments
for item in data["recipients"]:
recipient = Recipient(obj, item.get("header"))
recipient: Recipient[BaseKey] = Recipient(obj, item.get("header"))
if "encrypted_key" in item:
recipient.encrypted_key = urlsafe_b64decode(to_bytes(item["encrypted_key"]))
obj.recipients.append(recipient)
Expand All @@ -84,7 +85,7 @@ def extract_flattened_json(data: FlattenedJSONSerialization) -> FlattenedJSONEnc
obj.base64_segments = base64_segments
obj.bytes_segments = bytes_segments

recipient = Recipient(obj, data.get("header"))
recipient: Recipient[BaseKey] = Recipient(obj, data.get("header"))
if "encrypted_key" in data:
recipient.encrypted_key = urlsafe_b64decode(to_bytes(data["encrypted_key"]))
obj.recipients.append(recipient)
Expand Down
21 changes: 12 additions & 9 deletions src/joserfc/rfc7516/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,19 @@
from ..registry import Header, HeaderRegistryDict
from ..errors import InvalidKeyTypeError, InvalidKeyLengthError

KeyType = t.TypeVar("KeyType", bound=BaseKey)

class Recipient:

class Recipient(t.Generic[KeyType]):
def __init__(
self,
parent: t.Union["CompactEncryption", "GeneralJSONEncryption", "FlattenedJSONEncryption"],
header: t.Optional[Header] = None,
recipient_key: t.Optional[BaseKey] = None):
recipient_key: t.Optional[KeyType] = None):
self.__parent = parent
self.header = header
self.recipient_key = recipient_key
self.sender_key: t.Optional[BaseKey] = None
self.sender_key: t.Optional[KeyType] = None
self.encrypted_key: t.Optional[bytes] = None
self.ephemeral_key: t.Optional[CurveKey] = None

Expand Down Expand Up @@ -278,8 +280,9 @@ class JWEKeyAgreement(KeyManagement, metaclass=ABCMeta):
tag_aware: bool = False
key_wrapping: t.Optional[JWEKeyWrapping]

def prepare_ephemeral_key(self, recipient: Recipient):
recipient_key: CurveKey = recipient.recipient_key # type: ignore
def prepare_ephemeral_key(self, recipient: Recipient[CurveKey]):
recipient_key = recipient.recipient_key
assert recipient_key is not None
self.check_key_type(recipient_key)
if recipient.ephemeral_key is None:
ephemeral_key: CurveKey = recipient_key.generate_key(
Expand All @@ -288,11 +291,11 @@ def prepare_ephemeral_key(self, recipient: Recipient):
recipient.add_header("epk", recipient.ephemeral_key.as_dict(private=False))

@abstractmethod
def encrypt_agreed_upon_key(self, enc: JWEEncModel, recipient: Recipient) -> bytes:
def encrypt_agreed_upon_key(self, enc: JWEEncModel, recipient: Recipient[CurveKey]) -> bytes:
pass

@abstractmethod
def decrypt_agreed_upon_key(self, enc: JWEEncModel, recipient: Recipient) -> bytes:
def decrypt_agreed_upon_key(self, enc: JWEEncModel, recipient: Recipient[CurveKey]) -> bytes:
pass

def wrap_cek_with_auk(self, cek: bytes, key: bytes) -> bytes:
Expand All @@ -303,10 +306,10 @@ def unwrap_cek_with_auk(self, ek: bytes, key: bytes) -> bytes:
assert self.key_wrapping is not None
return self.key_wrapping.unwrap_cek(ek, key)

def encrypt_agreed_upon_key_with_tag(self, enc: JWEEncModel, recipient: Recipient, tag: bytes) -> bytes:
def encrypt_agreed_upon_key_with_tag(self, enc: JWEEncModel, recipient: Recipient[CurveKey], tag: bytes) -> bytes:
raise NotImplementedError()

def decrypt_agreed_upon_key_with_tag(self, enc: JWEEncModel, recipient: Recipient, tag: bytes) -> bytes:
def decrypt_agreed_upon_key_with_tag(self, enc: JWEEncModel, recipient: Recipient[CurveKey], tag: bytes) -> bytes:
raise NotImplementedError()


Expand Down
1 change: 1 addition & 0 deletions src/joserfc/rfc7516/types.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import typing as t

__all__ = [
"JSONRecipientDict",
"FlattenedJSONSerialization",
"GeneralJSONSerialization",
]
Expand Down
9 changes: 5 additions & 4 deletions src/joserfc/rfc7518/derive_key.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,12 @@ def derive_key_for_concat_kdf(
return ckdf.derive(shared_key)


def u32be_len_input(s, use_base64=False) -> bytes:
def u32be_len_input(s: t.Optional[t.AnyStr], use_base64=False) -> bytes:
if not s:
return b"\x00\x00\x00\x00"
sb: bytes
if use_base64:
s = urlsafe_b64decode(to_bytes(s))
sb = urlsafe_b64decode(to_bytes(s))
else:
s = to_bytes(s)
return struct.pack(">I", len(s)) + s
sb = to_bytes(s)
return struct.pack(">I", len(sb)) + sb
Loading

0 comments on commit a0ba737

Please sign in to comment.