Skip to content

Commit

Permalink
pk: Track cipher content internally
Browse files Browse the repository at this point in the history
mbedtls-3.x
  • Loading branch information
Synss committed Feb 22, 2024
1 parent 3068d97 commit dae0cbb
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 36 deletions.
1 change: 1 addition & 0 deletions src/mbedtls/pk.pxd
Expand Up @@ -108,6 +108,7 @@ cdef extern from "mbedtls/pk.h" nogil:

cdef class CipherBase:
cdef mbedtls_pk_context _ctx
cdef object __state


cdef class RSA(CipherBase):
Expand Down
78 changes: 42 additions & 36 deletions src/mbedtls/pk.pyx
Expand Up @@ -155,6 +155,12 @@ def _get_md_alg(digestmod):
raise TypeError("a valid digestmod is required, got %r" % digestmod)


class CipherState(enum.Flag):
UNSET = enum.auto()
PUBLIC = enum.auto()
PRIVATE = enum.auto()


cdef class CipherBase:
"""Base class to RSA and ECC ciphers.
Expand All @@ -168,6 +174,7 @@ cdef class CipherBase:
name,
const unsigned char[:] key=None,
const unsigned char[:] password=None):
self.__state = CipherState.UNSET
_exc.check_error(_pk.mbedtls_pk_setup(
&self._ctx,
_pk.mbedtls_pk_info_from_type(
Expand All @@ -189,6 +196,11 @@ cdef class CipherBase:
except _exc.TLSError:
_exc.check_error(_pk.mbedtls_pk_parse_public_key(
&self._ctx, &key[0], key.size))
pub = self._public_to_PEM()
if "PUBLIC" in pub:
self.__state |= CipherState.PUBLIC
if _pk.mbedtls_pk_check_pair(&self._ctx, &self._ctx) == 0:
self.__state |= CipherState.PRIVATE

def __cinit__(self):
"""Initialize the context."""
Expand Down Expand Up @@ -256,8 +268,10 @@ cdef class CipherBase:
# PEM must be null-terminated.
bkey = bkey + b"\0"
if callable(password):
return cls(key=bkey, password=password())
return cls(key=bkey, password=password)
self = cls(key=bkey, password=password())
else:
self = cls(key=bkey, password=password)
return self

@classmethod
def from_file(cls, path, password=None):
Expand Down Expand Up @@ -292,13 +306,19 @@ cdef class CipherBase:
"""Return the size of the key, in bytes."""
return _pk.mbedtls_pk_get_len(&self._ctx)

def _set_private(self):
self.__state |= CipherState.PRIVATE

def _has_private(self):
"""Return `True` if the key contains a valid private half."""
raise NotImplementedError
return CipherState.PRIVATE in self.__state

def _set_public(self):
self.__state |= CipherState.PUBLIC

def _has_public(self):
"""Return `True` if the key contains a valid public half."""
raise NotImplementedError
return CipherState.PUBLIC in self.__state

def sign(self,
const unsigned char[:] message not None,
Expand Down Expand Up @@ -418,8 +438,6 @@ cdef class CipherBase:
raise NotImplementedError

def _private_to_DER(self):
if not self._has_private():
return b""
cdef int olen
cdef size_t osize = PRV_DER_MAX_BYTES
cdef unsigned char *output = <unsigned char *>malloc(
Expand All @@ -430,12 +448,15 @@ cdef class CipherBase:
olen = _exc.check_error(
_pk.mbedtls_pk_write_key_der(&self._ctx, output, osize))
return output[osize - olen:osize]
except _exc.TLSError as exc:
if exc.err == 0x4080:
# no private key
return b""
raise
finally:
free(output)

def _private_to_PEM(self):
if not self._has_private():
return ""
cdef size_t osize = PRV_DER_MAX_BYTES * 4 // 3 + 100
cdef unsigned char *output = <unsigned char *>malloc(
osize * sizeof(unsigned char))
Expand All @@ -446,6 +467,11 @@ cdef class CipherBase:
_exc.check_error(
_pk.mbedtls_pk_write_key_pem(&self._ctx, output, osize))
return output[0:osize].rstrip(b"\0").decode("ascii")
except _exc.TLSError as exc:
if exc.err == 0x4080:
# no private key
return ""
raise
finally:
free(output)

Expand All @@ -459,14 +485,12 @@ cdef class CipherBase:
"""
if format == "DER":
return self._private_to_DER()
return self._private_to_DER() if self._has_private() else b""
if format == "PEM":
return self._private_to_PEM()
return self._private_to_PEM() if self._has_private() else ""
raise ValueError(format)

def _public_to_DER(self):
if not self._has_public():
return b""
cdef int olen
cdef size_t osize = PRV_DER_MAX_BYTES
cdef unsigned char *output = <unsigned char *>malloc(
Expand All @@ -481,8 +505,6 @@ cdef class CipherBase:
free(output)

def _public_to_PEM(self):
if not self._has_public():
return ""
cdef size_t osize = PRV_DER_MAX_BYTES * 4 // 3 + 100
cdef unsigned char *output = <unsigned char *>malloc(
osize * sizeof(unsigned char))
Expand All @@ -506,9 +528,9 @@ cdef class CipherBase:
"""
if format == "DER":
return self._public_to_DER()
return self._public_to_DER() if self._has_public() else b""
if format == "PEM":
return self._public_to_PEM()
return self._public_to_PEM() if self._has_public() else ""
raise ValueError(format)


Expand All @@ -521,16 +543,6 @@ cdef class RSA(CipherBase):
const unsigned char[:] password=None):
super().__init__(b"RSA", key, password)

def _has_private(self):
"""Return `True` if the key contains a valid private half."""
return _rsa.mbedtls_rsa_check_privkey(
_pk.mbedtls_pk_rsa(self._ctx)
) == 0

def _has_public(self):
"""Return `True` if the key contains a valid public half."""
return _rsa.mbedtls_rsa_check_pubkey(_pk.mbedtls_pk_rsa(self._ctx)) == 0

def generate(self, unsigned int key_size=2048, int exponent=65537):
"""Generate an RSA keypair.
Expand All @@ -545,6 +557,8 @@ cdef class RSA(CipherBase):
_exc.check_error(_rsa.mbedtls_rsa_gen_key(
_pk.mbedtls_pk_rsa(self._ctx), &_rnd.mbedtls_ctr_drbg_random,
&__rng._ctx, key_size, exponent))
self._set_public()
self._set_private()
return self.export_key("DER")


Expand Down Expand Up @@ -643,16 +657,6 @@ cdef class ECC(CipherBase):
def curve(self):
return self._curve

def _has_private(self):
"""Return `True` if the key contains a valid private half."""
cdef const _ecp.mbedtls_ecp_keypair* ecp = _pk.mbedtls_pk_ec(self._ctx)
return _mpi.mbedtls_mpi_cmp_mpi(&ecp.d, &_mpi.MPI()._ctx) != 0

def _has_public(self):
"""Return `True` if the key contains a valid public half."""
cdef _ecp.mbedtls_ecp_keypair* ecp = _pk.mbedtls_pk_ec(self._ctx)
return not _ecp.mbedtls_ecp_is_zero(&ecp.Q)

def sign(self,
const unsigned char[:] message not None,
digestmod=None):
Expand All @@ -678,6 +682,8 @@ cdef class ECC(CipherBase):
if self.curve in (Curve.CURVE25519, Curve.CURVE448)
else "DER"
)
self._set_public()
self._set_private()
return self.export_key(format)

def _private_to_num(self):
Expand Down
2 changes: 2 additions & 0 deletions tests/test_pk.py
Expand Up @@ -193,6 +193,8 @@ def test_import_private_key(

other = copy(cipher)
other = type(cipher).from_buffer(key)
assert bytes(other) == key

assert other == cipher
assert other.export_key() == cipher.export_key() == key
assert other.export_public_key() == cipher.export_public_key()
Expand Down

0 comments on commit dae0cbb

Please sign in to comment.