From dae0cbba328bc429cfcd06011ca90171ab6a305c Mon Sep 17 00:00:00 2001 From: Mathias Laurin Date: Sun, 18 Feb 2024 10:51:16 +0100 Subject: [PATCH] pk: Track cipher content internally mbedtls-3.x --- src/mbedtls/pk.pxd | 1 + src/mbedtls/pk.pyx | 78 +++++++++++++++++++++++++--------------------- tests/test_pk.py | 2 ++ 3 files changed, 45 insertions(+), 36 deletions(-) diff --git a/src/mbedtls/pk.pxd b/src/mbedtls/pk.pxd index 214957e5..bae5f7d9 100644 --- a/src/mbedtls/pk.pxd +++ b/src/mbedtls/pk.pxd @@ -108,6 +108,7 @@ cdef extern from "mbedtls/pk.h" nogil: cdef class CipherBase: cdef mbedtls_pk_context _ctx + cdef object __state cdef class RSA(CipherBase): diff --git a/src/mbedtls/pk.pyx b/src/mbedtls/pk.pyx index b85e538f..8f201330 100644 --- a/src/mbedtls/pk.pyx +++ b/src/mbedtls/pk.pyx @@ -155,6 +155,12 @@ def _get_md_alg(digestmod): raise TypeError("a valid digestmod is required, got %r" % digestmod) +class CipherState(enum.Flag): + UNSET = enum.auto() + PUBLIC = enum.auto() + PRIVATE = enum.auto() + + cdef class CipherBase: """Base class to RSA and ECC ciphers. @@ -168,6 +174,7 @@ cdef class CipherBase: name, const unsigned char[:] key=None, const unsigned char[:] password=None): + self.__state = CipherState.UNSET _exc.check_error(_pk.mbedtls_pk_setup( &self._ctx, _pk.mbedtls_pk_info_from_type( @@ -189,6 +196,11 @@ cdef class CipherBase: except _exc.TLSError: _exc.check_error(_pk.mbedtls_pk_parse_public_key( &self._ctx, &key[0], key.size)) + pub = self._public_to_PEM() + if "PUBLIC" in pub: + self.__state |= CipherState.PUBLIC + if _pk.mbedtls_pk_check_pair(&self._ctx, &self._ctx) == 0: + self.__state |= CipherState.PRIVATE def __cinit__(self): """Initialize the context.""" @@ -256,8 +268,10 @@ cdef class CipherBase: # PEM must be null-terminated. bkey = bkey + b"\0" if callable(password): - return cls(key=bkey, password=password()) - return cls(key=bkey, password=password) + self = cls(key=bkey, password=password()) + else: + self = cls(key=bkey, password=password) + return self @classmethod def from_file(cls, path, password=None): @@ -292,13 +306,19 @@ cdef class CipherBase: """Return the size of the key, in bytes.""" return _pk.mbedtls_pk_get_len(&self._ctx) + def _set_private(self): + self.__state |= CipherState.PRIVATE + def _has_private(self): """Return `True` if the key contains a valid private half.""" - raise NotImplementedError + return CipherState.PRIVATE in self.__state + + def _set_public(self): + self.__state |= CipherState.PUBLIC def _has_public(self): """Return `True` if the key contains a valid public half.""" - raise NotImplementedError + return CipherState.PUBLIC in self.__state def sign(self, const unsigned char[:] message not None, @@ -418,8 +438,6 @@ cdef class CipherBase: raise NotImplementedError def _private_to_DER(self): - if not self._has_private(): - return b"" cdef int olen cdef size_t osize = PRV_DER_MAX_BYTES cdef unsigned char *output = malloc( @@ -430,12 +448,15 @@ cdef class CipherBase: olen = _exc.check_error( _pk.mbedtls_pk_write_key_der(&self._ctx, output, osize)) return output[osize - olen:osize] + except _exc.TLSError as exc: + if exc.err == 0x4080: + # no private key + return b"" + raise finally: free(output) def _private_to_PEM(self): - if not self._has_private(): - return "" cdef size_t osize = PRV_DER_MAX_BYTES * 4 // 3 + 100 cdef unsigned char *output = malloc( osize * sizeof(unsigned char)) @@ -446,6 +467,11 @@ cdef class CipherBase: _exc.check_error( _pk.mbedtls_pk_write_key_pem(&self._ctx, output, osize)) return output[0:osize].rstrip(b"\0").decode("ascii") + except _exc.TLSError as exc: + if exc.err == 0x4080: + # no private key + return "" + raise finally: free(output) @@ -459,14 +485,12 @@ cdef class CipherBase: """ if format == "DER": - return self._private_to_DER() + return self._private_to_DER() if self._has_private() else b"" if format == "PEM": - return self._private_to_PEM() + return self._private_to_PEM() if self._has_private() else "" raise ValueError(format) def _public_to_DER(self): - if not self._has_public(): - return b"" cdef int olen cdef size_t osize = PRV_DER_MAX_BYTES cdef unsigned char *output = malloc( @@ -481,8 +505,6 @@ cdef class CipherBase: free(output) def _public_to_PEM(self): - if not self._has_public(): - return "" cdef size_t osize = PRV_DER_MAX_BYTES * 4 // 3 + 100 cdef unsigned char *output = malloc( osize * sizeof(unsigned char)) @@ -506,9 +528,9 @@ cdef class CipherBase: """ if format == "DER": - return self._public_to_DER() + return self._public_to_DER() if self._has_public() else b"" if format == "PEM": - return self._public_to_PEM() + return self._public_to_PEM() if self._has_public() else "" raise ValueError(format) @@ -521,16 +543,6 @@ cdef class RSA(CipherBase): const unsigned char[:] password=None): super().__init__(b"RSA", key, password) - def _has_private(self): - """Return `True` if the key contains a valid private half.""" - return _rsa.mbedtls_rsa_check_privkey( - _pk.mbedtls_pk_rsa(self._ctx) - ) == 0 - - def _has_public(self): - """Return `True` if the key contains a valid public half.""" - return _rsa.mbedtls_rsa_check_pubkey(_pk.mbedtls_pk_rsa(self._ctx)) == 0 - def generate(self, unsigned int key_size=2048, int exponent=65537): """Generate an RSA keypair. @@ -545,6 +557,8 @@ cdef class RSA(CipherBase): _exc.check_error(_rsa.mbedtls_rsa_gen_key( _pk.mbedtls_pk_rsa(self._ctx), &_rnd.mbedtls_ctr_drbg_random, &__rng._ctx, key_size, exponent)) + self._set_public() + self._set_private() return self.export_key("DER") @@ -643,16 +657,6 @@ cdef class ECC(CipherBase): def curve(self): return self._curve - def _has_private(self): - """Return `True` if the key contains a valid private half.""" - cdef const _ecp.mbedtls_ecp_keypair* ecp = _pk.mbedtls_pk_ec(self._ctx) - return _mpi.mbedtls_mpi_cmp_mpi(&ecp.d, &_mpi.MPI()._ctx) != 0 - - def _has_public(self): - """Return `True` if the key contains a valid public half.""" - cdef _ecp.mbedtls_ecp_keypair* ecp = _pk.mbedtls_pk_ec(self._ctx) - return not _ecp.mbedtls_ecp_is_zero(&ecp.Q) - def sign(self, const unsigned char[:] message not None, digestmod=None): @@ -678,6 +682,8 @@ cdef class ECC(CipherBase): if self.curve in (Curve.CURVE25519, Curve.CURVE448) else "DER" ) + self._set_public() + self._set_private() return self.export_key(format) def _private_to_num(self): diff --git a/tests/test_pk.py b/tests/test_pk.py index 7e373da7..6a522b4c 100644 --- a/tests/test_pk.py +++ b/tests/test_pk.py @@ -193,6 +193,8 @@ def test_import_private_key( other = copy(cipher) other = type(cipher).from_buffer(key) + assert bytes(other) == key + assert other == cipher assert other.export_key() == cipher.export_key() == key assert other.export_public_key() == cipher.export_public_key()