Skip to content

Commit

Permalink
fix(jwk): mypy for jwk
Browse files Browse the repository at this point in the history
  • Loading branch information
lepture committed Jul 17, 2023
1 parent 710d566 commit b052bb0
Show file tree
Hide file tree
Showing 12 changed files with 124 additions and 89 deletions.
18 changes: 9 additions & 9 deletions src/joserfc/jwk.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import typing as t
from .rfc7517 import (
BaseKey,
SymmetricKey,
AsymmetricKey,
CurveKey,
Expand All @@ -22,6 +23,7 @@
__all__ = [
"types",
"JWKRegistry",
"BaseKey",
"SymmetricKey",
"AsymmetricKey",
"CurveKey",
Expand Down Expand Up @@ -66,30 +68,28 @@ def guess_key(key: KeyFlexible, obj: GuestProtocol) -> Key:
"""
headers = obj.headers()

rv_key: Key

if isinstance(key, (str, bytes)):
rv_key = OctKey.import_key(key)
rv_key = OctKey.import_key(key) # type: ignore

elif isinstance(key, (SymmetricKey, AsymmetricKey)):
rv_key = key
elif isinstance(key, BaseKey):
rv_key: Key = key # type: ignore

elif isinstance(key, KeySet):
kid = headers.get("kid")
if not kid:
# choose one key by random
rv_key = key.pick_random_key(headers["alg"])
rv_key: Key = key.pick_random_key(headers["alg"]) # type: ignore
if rv_key is None:
raise ValueError("Invalid key")
# use side effect to add kid information
obj.set_kid(rv_key.kid)
else:
rv_key = key.get_by_kid(kid)
rv_key: Key = key.get_by_kid(kid) # type: ignore

elif callable(key):
rv_key = key(obj)
rv_key = key(obj) # type: ignore

else:
raise ValueError("Invalid key")

return rv_key
return rv_key # type: ignore
8 changes: 7 additions & 1 deletion src/joserfc/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,15 @@ def __init__(
self.required = required


class KeyOperation:
def __init__(self, description: str, use: str, private: bool):
self.description = description
self.use = use
self.private = private


#: Define parameters for JWK
KeyParameterRegistryDict = t.Dict[str, KeyParameter]
KeyOperation = namedtuple("KeyOperation", ["description", "use", "private"])
KeyOperationRegistryDict = t.Dict[str, KeyOperation]

#: Basic JWS header registry
Expand Down
3 changes: 2 additions & 1 deletion src/joserfc/rfc7517/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from .models import SymmetricKey, AsymmetricKey, CurveKey
from .models import BaseKey, SymmetricKey, AsymmetricKey, CurveKey
from .registry import JWKRegistry
from .keyset import KeySet

__all__ = [
"BaseKey",
"SymmetricKey",
"AsymmetricKey",
"CurveKey",
Expand Down
9 changes: 6 additions & 3 deletions src/joserfc/rfc7517/keyset.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,10 @@ def pick_random_key(self, algorithm: str) -> t.Optional[BaseKey]:
return None

@classmethod
def import_key_set(cls, value: KeySetDict, parameters: KeyParameters = None) -> "KeySet":
def import_key_set(
cls,
value: KeySetDict,
parameters: t.Optional[KeyParameters] = None) -> "KeySet":
keys = []

for data in value["keys"]:
Expand All @@ -54,12 +57,12 @@ def generate_key_set(
cls,
key_type: str,
crv_or_size: t.Union[str, int],
parameters: KeyParameters = None,
parameters: t.Optional[KeyParameters] = None,
private: bool = True,
count: int = 4) -> "KeySet":

keys = []
for i in range(count):
for _ in range(count):
key = JWKRegistry.generate_key(key_type, crv_or_size, parameters, private)
keys.append(key)

Expand Down
27 changes: 20 additions & 7 deletions src/joserfc/rfc7517/models.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import typing as t
from typing import overload
from abc import ABCMeta, abstractmethod
from .types import KeyDict, KeyAny, KeyParameters
from ..registry import (
Expand Down Expand Up @@ -60,20 +61,25 @@ def validate_dict_key_registry(cls, dict_key: KeyDict, registry: KeyParameterReg
@classmethod
def validate_dict_key_use_operations(cls, dict_key: KeyDict):
if "use" in dict_key and "key_ops" in dict_key:
operations = cls.use_key_ops_registry[dict_key["use"]]
_use: str = dict_key["use"] # type: ignore
operations = cls.use_key_ops_registry[_use]
for op in dict_key["key_ops"]:
if op not in operations:
raise ValueError('"use" and "key_ops" does not match')


class BaseKey(t.Generic[NativePrivateKey, NativePublicKey]):
key_type: t.ClassVar[str]
binding: t.ClassVar[t.Type[NativeKeyBinding]]
value_registry: t.ClassVar[KeyParameterRegistryDict]
param_registry: t.ClassVar[KeyParameterRegistryDict] = JWK_PARAMETER_REGISTRY
operation_registry: t.ClassVar[KeyOperationRegistryDict] = JWK_OPERATION_REGISTRY
binding: t.ClassVar[t.Type[NativeKeyBinding]] = NativeKeyBinding

def __init__(self, raw_value: t.Any, original_value: t.Any, parameters: t.Optional[KeyParameters] = None):
def __init__(
self,
raw_value: t.Union[NativePrivateKey, NativePublicKey],
original_value: t.Any,
parameters: t.Optional[KeyParameters] = None):
self._raw_value = raw_value
self.original_value = original_value
self.extra_parameters = parameters
Expand Down Expand Up @@ -119,8 +125,8 @@ def dict_value(self) -> KeyDict:
return self._dict_value

data = self.binding.convert_raw_key_to_dict(self.raw_value, self.is_private)
if self.extra_parameters:
data.update(dict(self.extra_parameters))
if self.extra_parameters is not None:
data.update(self.extra_parameters) # type: ignore
data["kty"] = self.key_type
self.validate_dict_key(data)
self._dict_value = data
Expand Down Expand Up @@ -201,10 +207,17 @@ def check_key_op(self, operation: str):
if reg.private and not self.is_private:
raise UnsupportedKeyOperationError(f'Invalid key_op "{operation}" for public key')

@overload
def get_op_key(self, operation: t.Literal["verify", "encrypt", "wrapKey", "deriveKey"]) -> NativePublicKey: ...

@overload
def get_op_key(self, operation: t.Literal["sign", "decrypt", "unwrapKey"]) -> NativePrivateKey: ...

def get_op_key(self, operation: str) -> t.Union[NativePublicKey, NativePrivateKey]:
self.check_key_op(operation)
reg = self.operation_registry[operation]
if reg.private:
assert self.private_key is not None
return self.private_key
return self.public_key

Expand Down Expand Up @@ -237,7 +250,7 @@ def generate_key(
raise NotImplementedError()


class SymmetricKey(BaseKey[NativePrivateKey, NativePublicKey], metaclass=ABCMeta):
class SymmetricKey(BaseKey[bytes, bytes], metaclass=ABCMeta):
@property
def raw_value(self) -> bytes:
"""The raw key in bytes."""
Expand Down Expand Up @@ -285,5 +298,5 @@ def curve_name(self) -> str:
pass

@abstractmethod
def exchange_derive_key(self, key: "CurveKey") -> bytes:
def exchange_derive_key(self, key) -> bytes:
pass
14 changes: 8 additions & 6 deletions src/joserfc/rfc7517/pem.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
Encoding,
PrivateFormat,
PublicFormat,
KeySerializationEncryption,
BestAvailableEncryption,
NoEncryption,
)
Expand All @@ -21,22 +22,22 @@

def load_pem_key(raw: bytes, ssh_type: t.Optional[bytes] = None, password: t.Optional[bytes] = None):
if ssh_type and raw.startswith(ssh_type):
key = load_ssh_public_key(raw, backend=default_backend())
key = load_ssh_public_key(raw, backend=default_backend()) # type: ignore

elif b"OPENSSH PRIVATE" in raw:
key = load_ssh_private_key(raw, password=password, backend=default_backend())
key = load_ssh_private_key(raw, password=password, backend=default_backend()) # type: ignore

elif b"PUBLIC" in raw:
key = load_pem_public_key(raw, backend=default_backend())
key = load_pem_public_key(raw, backend=default_backend()) # type: ignore

elif b"PRIVATE" in raw:
key = load_pem_private_key(raw, password=password, backend=default_backend())
key = load_pem_private_key(raw, password=password, backend=default_backend()) # type: ignore

else:
try:
key = load_der_private_key(raw, password=password, backend=default_backend())
key = load_der_private_key(raw, password=password, backend=default_backend()) # type: ignore
except ValueError:
key = load_der_public_key(raw, backend=default_backend())
key = load_der_public_key(raw, backend=default_backend()) # type: ignore
return key


Expand All @@ -58,6 +59,7 @@ def dump_pem_key(key, encoding=None, private=False, password=None) -> bytes:
raise ValueError("Invalid encoding: {!r}".format(encoding))

if private:
encryption_algorithm: KeySerializationEncryption
if password is None:
encryption_algorithm = NoEncryption()
else:
Expand Down
2 changes: 1 addition & 1 deletion src/joserfc/rfc7517/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def import_key(
"""
if isinstance(data, dict) and key_type is None:
if "kty" in data:
key_type = data["kty"]
key_type = data["kty"] # type: ignore
else:
raise ValueError("Missing key type")

Expand Down
5 changes: 2 additions & 3 deletions src/joserfc/rfc7517/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,5 @@
}, total=False)

#: JWKs in dict
KeySetDict = t.TypedDict("KeySetDict", {
"keys": t.List[KeyDict],
})
class KeySetDict(t.TypedDict):
keys: t.List[KeyDict]
48 changes: 25 additions & 23 deletions src/joserfc/rfc7518/ec_key.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,31 +36,32 @@ class ECBinding(CryptographyBinding):

@staticmethod
def import_private_key(obj: KeyDict) -> EllipticCurvePrivateKey:
curve = DSS_CURVES[obj["crv"]]()
curve = DSS_CURVES[obj["crv"]]() # type: ignore
public_numbers = EllipticCurvePublicNumbers(
base64_to_int(obj["x"]),
base64_to_int(obj["y"]),
base64_to_int(obj["x"]), # type: ignore
base64_to_int(obj["y"]), # type: ignore
curve,
)
private_numbers = EllipticCurvePrivateNumbers(base64_to_int(obj["d"]), public_numbers)
d = base64_to_int(obj["d"]) # type: ignore
private_numbers = EllipticCurvePrivateNumbers(d, public_numbers)
return private_numbers.private_key(default_backend())

@staticmethod
def export_private_key(key: EllipticCurvePrivateKey) -> Dict[str, str]:
numbers = key.private_numbers()
return {
"crv": CURVES_DSS[key.curve.name],
"crv": CURVES_DSS[key.curve.name], # type: ignore
"x": int_to_base64(numbers.public_numbers.x),
"y": int_to_base64(numbers.public_numbers.y),
"d": int_to_base64(numbers.private_value),
}

@staticmethod
def import_public_key(obj: KeyDict) -> EllipticCurvePublicKey:
curve = DSS_CURVES[obj["crv"]]()
curve = DSS_CURVES[obj["crv"]]() # type: ignore
public_numbers = EllipticCurvePublicNumbers(
base64_to_int(obj["x"]),
base64_to_int(obj["y"]),
base64_to_int(obj["x"]), # type: ignore
base64_to_int(obj["y"]), # type: ignore
curve,
)
return public_numbers.public_key(default_backend())
Expand All @@ -69,14 +70,14 @@ def import_public_key(obj: KeyDict) -> EllipticCurvePublicKey:
def export_public_key(key: EllipticCurvePublicKey) -> Dict[str, str]:
numbers = key.public_numbers()
return {
"crv": CURVES_DSS[numbers.curve.name],
"crv": CURVES_DSS[numbers.curve.name], # type: ignore
"x": int_to_base64(numbers.x),
"y": int_to_base64(numbers.y),
}


class ECKey(CurveKey[EllipticCurvePrivateKey, EllipticCurvePublicKey]):
key_type: str = "EC"
key_type = "EC"
#: Registry definition for EC Key
#: https://www.rfc-editor.org/rfc/rfc7518#section-6.2
value_registry = {
Expand All @@ -87,12 +88,6 @@ class ECKey(CurveKey[EllipticCurvePrivateKey, EllipticCurvePublicKey]):
}
binding = ECBinding

def exchange_derive_key(self, key: "ECKey") -> bytes:
pubkey = key.get_op_key("deriveKey")
if self.private_key and self.curve_name == key.curve_name:
return self.private_key.exchange(ECDH(), pubkey)
raise ValueError("Invalid key for exchanging shared key")

@property
def is_private(self) -> bool:
return isinstance(self.raw_value, EllipticCurvePrivateKey)
Expand All @@ -105,13 +100,19 @@ def public_key(self) -> EllipticCurvePublicKey:

@property
def private_key(self) -> Optional[EllipticCurvePrivateKey]:
if self.is_private:
if isinstance(self.raw_value, EllipticCurvePrivateKey):
return self.raw_value
return None

def exchange_derive_key(self, key: "ECKey") -> bytes:
pubkey = key.get_op_key("deriveKey")
if self.private_key and self.curve_name == key.curve_name:
return self.private_key.exchange(ECDH(), pubkey)
raise ValueError("Invalid key for exchanging shared key")

@property
def curve_name(self) -> str:
return CURVES_DSS[self.raw_value.curve.name]
return CURVES_DSS[self.raw_value.curve.name] # type: ignore

@property
def curve_key_size(self) -> int:
Expand All @@ -121,14 +122,15 @@ def curve_key_size(self) -> int:
def generate_key(
cls,
crv: str = "P-256",
parameters: KeyParameters = None,
parameters: Optional[KeyParameters] = None,
private: bool = True) -> "ECKey":
if crv not in DSS_CURVES:
raise ValueError('Invalid crv value: "{}"'.format(crv))
raw_key = generate_private_key(
curve=DSS_CURVES[crv](),
curve=DSS_CURVES[crv](), # type: ignore
backend=default_backend(),
)
if not private:
raw_key = raw_key.public_key()
return cls(raw_key, raw_key, parameters)
if private:
return cls(raw_key, raw_key, parameters)
pub_key = raw_key.public_key()
return cls(pub_key, pub_key, parameters)
11 changes: 8 additions & 3 deletions src/joserfc/rfc7518/oct_key.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import random
import string
from typing import Optional
from ..util import (
to_bytes,
urlsafe_b64decode,
Expand Down Expand Up @@ -36,17 +37,21 @@ def import_from_bytes(cls, value: bytes, password=None):
return value


class OctKey(SymmetricKey[bytes, bytes]):
class OctKey(SymmetricKey):
"""OctKey is a symmetric key, defined by RFC7518 Section 6.4.
"""
key_type: str = "oct"
key_type = "oct"
binding = OctBinding

#: https://www.rfc-editor.org/rfc/rfc7518#section-6.4
value_registry = {"k": KeyParameter("Key Value", "str", True, True)}

@classmethod
def generate_key(cls, key_size=256, parameters: KeyParameters = None, private: bool = True) -> "OctKey":
def generate_key(
cls,
key_size=256,
parameters: Optional[KeyParameters] = None,
private: bool = True) -> "OctKey":
"""Generate a ``OctKey`` with the given bit size (not bytes).
:param key_size: size in bit
Expand Down
Loading

0 comments on commit b052bb0

Please sign in to comment.