Skip to content

Commit 45cd616

Browse files
committed
feat: check ssh_type and cryptography key types in bindings
1 parent a06ab98 commit 45cd616

File tree

5 files changed

+42
-9
lines changed

5 files changed

+42
-9
lines changed

docs/migrations/authlib.rst

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,6 @@ When using methods such as ``.as_dict``, ``.as_bytes``, ``.as_pem``, and others,
8080

8181
.. code-block:: python
8282
:caption: Authlib
83-
:emphasize-lines: 1,2
8483
8584
key.as_dict(is_private=True)
8685

src/joserfc/_rfc7517/pem.py

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from __future__ import annotations
2-
from typing import Any, Literal, cast
2+
from typing import Any, Literal, Tuple, cast
33
from abc import ABCMeta, abstractmethod
44
from cryptography.x509 import load_pem_x509_certificate
55
from cryptography.hazmat.primitives.serialization import (
@@ -19,15 +19,14 @@
1919
from cryptography.hazmat.backends import default_backend
2020
from .models import NativeKeyBinding, GenericKey
2121
from .types import DictKey
22+
from ..errors import InvalidKeyTypeError
2223
from ..util import to_bytes
2324

2425

25-
def load_pem_key(raw: bytes, ssh_type: bytes | None = None, password: bytes | None = None) -> Any:
26+
def load_pem_key(raw: bytes, password: bytes | None = None) -> Any:
2627
key: Any
27-
if ssh_type and raw.startswith(ssh_type):
28-
key = load_ssh_public_key(raw, backend=default_backend())
2928

30-
elif b"OPENSSH PRIVATE" in raw:
29+
if b"OPENSSH PRIVATE" in raw:
3130
key = load_ssh_private_key(raw, password=password, backend=default_backend())
3231

3332
elif b"PUBLIC" in raw:
@@ -49,7 +48,10 @@ def load_pem_key(raw: bytes, ssh_type: bytes | None = None, password: bytes | No
4948

5049

5150
def dump_pem_key(
52-
key: Any, encoding: Literal["PEM", "DER"] | None = None, private: bool | None = False, password: Any | None = None
51+
key: Any,
52+
encoding: Literal["PEM", "DER"] | None = None,
53+
private: bool | None = False,
54+
password: Any | None = None,
5355
) -> bytes:
5456
"""Export key into PEM/DER format bytes.
5557
@@ -87,7 +89,17 @@ def dump_pem_key(
8789

8890

8991
class CryptographyBinding(NativeKeyBinding, metaclass=ABCMeta):
92+
key_type: str
9093
ssh_type: bytes
94+
cryptography_native_keys: Tuple[Any]
95+
96+
@classmethod
97+
def check_ssh_type(cls, value: bytes):
98+
return cls.ssh_type and value.startswith(cls.ssh_type)
99+
100+
@classmethod
101+
def check_cryptography_native_key(cls, native_key: Any):
102+
return isinstance(native_key, cls.cryptography_native_keys)
91103

92104
@classmethod
93105
def convert_raw_key_to_dict(cls, raw_key: Any, private: bool) -> DictKey:
@@ -105,9 +117,16 @@ def import_from_dict(cls, value: DictKey) -> Any:
105117

106118
@classmethod
107119
def import_from_bytes(cls, value: bytes, password: Any | None = None) -> Any:
120+
if cls.check_ssh_type(value):
121+
return load_ssh_public_key(value, backend=default_backend())
122+
108123
if password is not None:
109124
password = to_bytes(password)
110-
return load_pem_key(value, cls.ssh_type, password)
125+
126+
key = load_pem_key(value, password)
127+
if not cls.check_cryptography_native_key(key):
128+
raise InvalidKeyTypeError(f"Not a key of: '{cls.key_type}'")
129+
return key
111130

112131
@staticmethod
113132
def as_bytes(
@@ -116,7 +135,7 @@ def as_bytes(
116135
private: bool | None = False,
117136
password: Any | None = None,
118137
) -> bytes:
119-
if private is True:
138+
if private:
120139
return dump_pem_key(key.private_key, encoding, private, password)
121140
elif private is False:
122141
return dump_pem_key(key.public_key, encoding, private, password)

src/joserfc/_rfc7518/ec_key.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,9 @@
3636

3737

3838
class ECBinding(CryptographyBinding):
39+
key_type = "EC"
3940
ssh_type = b"ecdsa-sha2-"
41+
cryptography_native_keys = (EllipticCurvePrivateKey, EllipticCurvePublicKey)
4042

4143
_dss_curves: dict[str, t.Type[EllipticCurve]] = {}
4244
_curves_dss: dict[str, str] = {}

src/joserfc/_rfc7518/rsa_key.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,9 @@
3939

4040

4141
class RSABinding(CryptographyBinding):
42+
key_type = "RSA"
4243
ssh_type = b"ssh-rsa"
44+
cryptography_native_keys = (RSAPrivateKey, RSAPublicKey)
4345

4446
@staticmethod
4547
def import_private_key(obj: RSADictKey) -> RSAPrivateKey:

src/joserfc/_rfc8037/okp_key.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,18 @@
4646

4747

4848
class OKPBinding(CryptographyBinding):
49+
key_type = "OKP"
4950
ssh_type = b"ssh-ed25519"
51+
cryptography_native_keys = (
52+
Ed25519PublicKey,
53+
Ed25519PrivateKey,
54+
Ed448PublicKey,
55+
Ed448PrivateKey,
56+
X25519PublicKey,
57+
X25519PrivateKey,
58+
X448PublicKey,
59+
X448PrivateKey,
60+
)
5061

5162
@staticmethod
5263
def import_private_key(obj: OKPDictKey) -> PrivateOKPKey:

0 commit comments

Comments
 (0)