11from __future__ import annotations
2- from typing import Any , Literal , cast
2+ from typing import Any , Literal , Tuple , cast
33from abc import ABCMeta , abstractmethod
44from cryptography .x509 import load_pem_x509_certificate
55from cryptography .hazmat .primitives .serialization import (
1919from cryptography .hazmat .backends import default_backend
2020from .models import NativeKeyBinding , GenericKey
2121from .types import DictKey
22+ from ..errors import InvalidKeyTypeError
2223from ..util import to_bytes
2324
2425
25- def load_pem_key (raw : bytes , ssh_type : bytes | None = None , password : bytes | None = None ) -> Any :
26+ def load_pem_key (raw : bytes , password : bytes | None = None ) -> Any :
2627 key : Any
27- if ssh_type and raw .startswith (ssh_type ):
28- key = load_ssh_public_key (raw , backend = default_backend ())
2928
30- elif b"OPENSSH PRIVATE" in raw :
29+ if b"OPENSSH PRIVATE" in raw :
3130 key = load_ssh_private_key (raw , password = password , backend = default_backend ())
3231
3332 elif b"PUBLIC" in raw :
@@ -49,7 +48,10 @@ def load_pem_key(raw: bytes, ssh_type: bytes | None = None, password: bytes | No
4948
5049
5150def dump_pem_key (
52- key : Any , encoding : Literal ["PEM" , "DER" ] | None = None , private : bool | None = False , password : Any | None = None
51+ key : Any ,
52+ encoding : Literal ["PEM" , "DER" ] | None = None ,
53+ private : bool | None = False ,
54+ password : Any | None = None ,
5355) -> bytes :
5456 """Export key into PEM/DER format bytes.
5557
@@ -87,7 +89,17 @@ def dump_pem_key(
8789
8890
8991class CryptographyBinding (NativeKeyBinding , metaclass = ABCMeta ):
92+ key_type : str
9093 ssh_type : bytes
94+ cryptography_native_keys : Tuple [Any ]
95+
96+ @classmethod
97+ def check_ssh_type (cls , value : bytes ):
98+ return cls .ssh_type and value .startswith (cls .ssh_type )
99+
100+ @classmethod
101+ def check_cryptography_native_key (cls , native_key : Any ):
102+ return isinstance (native_key , cls .cryptography_native_keys )
91103
92104 @classmethod
93105 def convert_raw_key_to_dict (cls , raw_key : Any , private : bool ) -> DictKey :
@@ -105,9 +117,16 @@ def import_from_dict(cls, value: DictKey) -> Any:
105117
106118 @classmethod
107119 def import_from_bytes (cls , value : bytes , password : Any | None = None ) -> Any :
120+ if cls .check_ssh_type (value ):
121+ return load_ssh_public_key (value , backend = default_backend ())
122+
108123 if password is not None :
109124 password = to_bytes (password )
110- return load_pem_key (value , cls .ssh_type , password )
125+
126+ key = load_pem_key (value , password )
127+ if not cls .check_cryptography_native_key (key ):
128+ raise InvalidKeyTypeError (f"Not a key of: '{ cls .key_type } '" )
129+ return key
111130
112131 @staticmethod
113132 def as_bytes (
@@ -116,7 +135,7 @@ def as_bytes(
116135 private : bool | None = False ,
117136 password : Any | None = None ,
118137 ) -> bytes :
119- if private is True :
138+ if private :
120139 return dump_pem_key (key .private_key , encoding , private , password )
121140 elif private is False :
122141 return dump_pem_key (key .public_key , encoding , private , password )
0 commit comments