Skip to content

Commit

Permalink
pk: Fix null-byte handling with PEM format
Browse files Browse the repository at this point in the history
Strip terminal null bytes on write and add it on read.

The stripping is consistent with other libraries such as OpenSSL.
mbedTLS nevertheless requires that the PEM strings be null-terminated
when importing PEM certificates.

Closes #75.
  • Loading branch information
Synss committed Jan 14, 2023
1 parent 3315886 commit bb766ca
Show file tree
Hide file tree
Showing 4 changed files with 117 additions and 27 deletions.
1 change: 1 addition & 0 deletions ChangeLog
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
* ci: Update wheels to mbedtls 2.28.2.
* md: Fixup cases where "algorithms_available" would return actually
unavailable algorithms.
* pk: Remove trailing null bytes from PEM format. (Issue #75)
* cipher: `get_supported_ciphers()` now returns a sequence (consistent
with the md "available_*" functions).

Expand Down
16 changes: 12 additions & 4 deletions src/mbedtls/pk.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -236,9 +236,13 @@ cdef class CipherBase:
password-protected private keys.
"""
bkey = bytes(key)
if bkey.startswith(b"-----") and bkey.endswith(b"-----\n"):
# PEM must be null-terminated.
bkey = bkey + b"\0"
if callable(password):
return cls(key=key, password=password())
return cls(key=key, password=password)
return cls(key=bkey, password=password())
return cls(key=bkey, password=password)

@classmethod
def from_file(cls, path, password=None):
Expand Down Expand Up @@ -426,7 +430,7 @@ cdef class CipherBase:
try:
_exc.check_error(
_pk.mbedtls_pk_write_key_pem(&self._ctx, output, osize))
return output[0:osize].decode("ascii")
return output[0:osize].rstrip(b"\0").decode("ascii")
finally:
free(output)

Expand Down Expand Up @@ -473,7 +477,7 @@ cdef class CipherBase:
try:
_exc.check_error(
_pk.mbedtls_pk_write_pubkey_pem(&self._ctx, output, osize))
return output[0:osize].decode("ascii")
return output[0:osize].rstrip(b"\0").decode("ascii")
finally:
free(output)

Expand Down Expand Up @@ -746,6 +750,10 @@ cdef class ECC(CipherBase):
"""
if format == "POINT":
return self._public_to_point()
elif self.curve in (Curve.CURVE25519, Curve.CURVE448):
raise ValueError(
"Curve25519 and Curve448 only support export as NUM"
)
return super().export_public_key(format)


Expand Down
125 changes: 103 additions & 22 deletions tests/test_pk.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import pickle
import sys
from pathlib import Path
from typing import Any, Callable, List, Tuple, Type, Union, cast

import pytest
Expand Down Expand Up @@ -157,50 +158,97 @@ def test_pickle(self, cipher: _CipherType) -> None:
def test_hash(self, cipher: _CipherType) -> None:
assert isinstance(hash(cipher), int)

def test_export_private_key(self, cipher: _CipherType) -> None:
def test_export_private_key(
self, cipher: _CipherType, tmp_path: Path
) -> None:
cipher_tag = {RSA: "RSA", ECC: "EC"}[type(cipher)]

assert cipher.export_key("DER") == b""
assert cipher.export_key("PEM") == ""
assert cipher.key_size == 0

der = do_generate(cipher)

assert der
assert cipher.export_key() == der
assert cipher.export_key("DER") == der
assert cipher.export_key("PEM") != ""
assert cipher == cipher.export_key("DER")
assert cipher == cipher.export_key("PEM")
assert type(cipher).from_DER(cipher.export_key("DER")) == cipher
assert type(cipher).from_PEM(cipher.export_key("PEM")) == cipher
assert cipher.key_size > 0

def test_import_private_key(self, cipher: _CipherType) -> None:
assert der == cipher.export_key()
assert der == cipher.export_key("DER")
assert der == bytes(cipher)
assert der == cipher
assert cipher == type(cipher).from_DER(der)

pem = cipher.export_key("PEM")
assert pem.startswith(f"-----BEGIN {cipher_tag} PRIVATE KEY-----\n")
assert pem.endswith(f"-----END {cipher_tag} PRIVATE KEY-----\n")
assert pem == str(cipher)
assert pem == cipher.to_PEM().private
assert cipher == type(cipher).from_PEM(pem)

@pytest.mark.parametrize(
"copy",
[
lambda cipher: type(cipher).from_DER(cipher.export_key("DER")),
lambda cipher: type(cipher).from_PEM(cipher.export_key("PEM")),
],
)
def test_import_private_key(
self, cipher: _CipherType, copy: Callable[[_CipherType], _CipherType]
) -> None:
assert not cipher.export_key()
assert not cipher.export_public_key()

key = do_generate(cipher)
assert key

other = copy(cipher)
other = type(cipher).from_buffer(key)
assert other.export_key()
assert other.export_public_key()
assert other == cipher
assert other.export_key() == cipher.export_key() == key
assert other.export_public_key() == cipher.export_public_key()
assert check_pair(cipher, other) is True # Test private half.
assert check_pair(other, cipher) is True # Test public half.
assert check_pair(other, other) is True
assert cipher == other

def test_import_from_file(
self, cipher: _CipherType, tmp_path: Path
) -> None:
do_generate(cipher)

prv_der_file = tmp_path / "crt.prv.der"
prv_der_file.write_bytes(cipher.export_key("DER"))
assert cipher == type(cipher).from_file(prv_der_file)
prv_der_file.unlink()

pub_der_file = tmp_path / "crt.pub.der"
pub_der = cipher.export_public_key("DER")
pub_der_file.write_bytes(pub_der)
assert pub_der == type(cipher).from_file(pub_der_file)
pub_der_file.unlink()

prv_pem_file = tmp_path / "crt.prv.pem"
prv_pem_file.write_text(cipher.export_key("PEM"))
assert cipher == type(cipher).from_file(prv_pem_file)
prv_pem_file.unlink()

pub_pem_file = tmp_path / "crt.pub.pem"
pub_pem = cipher.export_public_key("PEM")
pub_pem_file.write_text(pub_pem)
assert pub_pem == type(cipher).from_file(pub_pem_file)
pub_pem_file.unlink()

def test_export_public_key(self, cipher: _CipherType) -> None:
assert cipher.export_public_key("DER") == b""
assert cipher.export_public_key("PEM") == ""

der = do_generate(cipher)

do_generate(cipher)
der = cipher.export_public_key("DER")
assert der
assert type(cipher).from_DER(
cipher.export_public_key("DER")
) == cipher.export_public_key("DER")
assert type(cipher).from_PEM(
cipher.export_public_key("PEM")
) == cipher.export_public_key("PEM")
assert der == type(cipher).from_DER(der).export_public_key("DER")

pem = cipher.export_public_key("PEM")
assert pem.startswith("-----BEGIN PUBLIC KEY-----\n")
assert pem.endswith("-----END PUBLIC KEY-----\n")
assert pem == type(cipher).from_PEM(pem).export_public_key("PEM")

def test_import_public_key(self, cipher: _CipherType) -> None:
assert not cipher.export_key()
Expand Down Expand Up @@ -252,7 +300,24 @@ def test_export_private_key(self, curve: Curve) -> None:
ecc.generate()
assert ecc.export_key("NUM") != 0

def test_export_public_key_to_point(self, curve: Curve) -> None:
if curve in (Curve.CURVE25519, Curve.CURVE448):
with pytest.raises(ValueError):
ecc.export_key("DER")
else:
der = ecc.export_key("DER")
assert der
assert ECC(curve).from_DER(der) == der

if curve in (Curve.CURVE25519, Curve.CURVE448):
with pytest.raises(ValueError):
ecc.export_key("PEM")
else:
pem = ecc.export_key("PEM")
assert pem
assert pem.startswith("-----BEGIN EC PRIVATE KEY-----\n")
assert pem.endswith("-----END EC PRIVATE KEY-----\n")

def test_export_public(self, curve: Curve) -> None:
ecc = ECC(curve)
assert ecc.export_public_key("POINT") == 0
assert ecc.export_public_key("POINT") == ECPoint(0, 0, 0)
Expand All @@ -268,6 +333,22 @@ def test_export_public_key_to_point(self, curve: Curve) -> None:
assert pub.y not in (0, pub.x, pub.z)
assert pub.z in (0, 1)

if curve in (Curve.CURVE25519, Curve.CURVE448):
with pytest.raises(ValueError):
ecc.export_public_key("DER")
else:
der = ecc.export_public_key("DER")
assert der

if curve in (Curve.CURVE25519, Curve.CURVE448):
with pytest.raises(ValueError):
ecc.export_public_key("PEM")
else:
pem = ecc.export_public_key("PEM")
assert pem
assert pem.startswith("-----BEGIN PUBLIC KEY-----\n")
assert pem.endswith("-----END PUBLIC KEY-----\n")


class TestECCtoECDH:
# pylint: disable=protected-access
Expand Down
2 changes: 1 addition & 1 deletion tests/test_x509.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ def test_public_key(self) -> None:
crt, key = make_root_ca()
pem = crt.subject_public_key.export_public_key(format="PEM")
assert pem.startswith("-----BEGIN PUBLIC KEY-----\n")
assert pem.rstrip("\0").endswith("-----END PUBLIC KEY-----\n")
assert pem.endswith("-----END PUBLIC KEY-----\n")
assert pem == key.export_public_key(format="PEM")

def test_revocation_bad_cast(self) -> None:
Expand Down

0 comments on commit bb766ca

Please sign in to comment.