diff --git a/Makefile b/Makefile index 5a8610a4..a6ed0774 100644 --- a/Makefile +++ b/Makefile @@ -4,16 +4,20 @@ PYX = $(wildcard mbedtls/*.pyx) PYX += $(wildcard mbedtls/cipher/*.pyx) PYX += $(wildcard mbedtls/pk/*.pyx) +LIBMBEDTLS = $(HOME)/lib/mbedtls-2.5.2 + release: cython $(PYX) python setup.py build_ext debug: cython -a -X linetrace=True $(PYX) - CFLAGS='-DCYTHON_TRACE=1' python setup.py build_ext --inplace + CFLAGS='-DCYTHON_TRACE=1' python setup.py build_ext --inplace \ + -L$(LIBMBEDTLS)/lib \ + -I$(LIBMBEDTLS)/include test: - nosetests -v --with-coverage --cover-package=mbedtls + pytest --cov mbedtls tests html: cd docs && make html diff --git a/mbedtls/__init__.py b/mbedtls/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/mbedtls/cipher/__init__.py b/mbedtls/cipher/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/mbedtls/pk/__init__.py b/mbedtls/pk/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/__init__.py b/tests/__init__.py index 85ddfb08..16cc1450 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,3 +1,5 @@ +# pylint: disable=missing-docstring + import random @@ -10,28 +12,13 @@ def assert_canonical_repr(obj): # pylint: disable=eval-used newobj = eval(repr(obj), frame.f_globals, frame.f_locals) except TypeError: - raise AssertionError("Cannot eval '%r'" % obj) from None + raise AssertionError("Cannot eval '%r'" % obj) finally: # explicitely delete the frame to avoid memory leaks, see also # https://docs.python.org/3/library/inspect.html#the-interpreter-stack del frame assert isinstance(newobj, type(obj)) -assert_canonical_repr.__test__ = False def _rnd(length): return bytes(random.randrange(0, 256) for _ in range(length)) -_rnd.__test__ = False - - -class TestRnd: - - @staticmethod - def test_key_length(): - for length in range(1024 + 1, 8): - assert len(_rnd(length)) == length - - @staticmethod - def test_values_fit_in_latin1(): - k = _rnd(2048) - assert k.decode("latin1") diff --git a/tests/test_cipher.py b/tests/test_cipher.py index a7f3d7d8..f1ae5b90 100644 --- a/tests/test_cipher.py +++ b/tests/test_cipher.py @@ -3,13 +3,12 @@ # Disable checks for violations that are acceptable in tests. # pylint: disable=missing-docstring # pylint: disable=attribute-defined-outside-init -# pylint: disable=invalid-name +# pylint: disable=invalid-name, redefined-outer-name +from collections import namedtuple from functools import partial -from nose.plugins.skip import SkipTest -from nose.tools import assert_equal, assert_raises -from nose.tools import raises +import pytest # pylint: disable=import-error from mbedtls.cipher._cipher import * @@ -32,129 +31,85 @@ def test_get_supported_ciphers(): assert cl and set(cl).issubset(set(CIPHER_NAME)) -@raises(CipherError) def test_wrong_size_raises_cipher_error(): - Cipher(b"AES-512-ECB", b"", 0, b"") + with pytest.raises(CipherError): + Cipher(b"AES-512-ECB", b"", 0, b"") -@raises(CipherError) def test_random_name_raises_cipher_error(): - Cipher(b"RANDOM TEXT IS NOT A CIPHER", b"", 0, b"") + with pytest.raises(CipherError): + Cipher(b"RANDOM TEXT IS NOT A CIPHER", b"", 0, b"") -@raises(CipherError) def test_zero_length_raises_cipher_error(): - Cipher(b"", b"", 0, b"") + with pytest.raises(CipherError): + Cipher(b"", b"", 0, b"") -@raises(ValueError) def test_cbc_raises_value_error_without_iv(): - Cipher(b"AES-512-CBC", b"", MODE_CBC, b"") + with pytest.raises(ValueError): + Cipher(b"AES-512-CBC", b"", MODE_CBC, b"") -@raises(ValueError) def test_cfb_raises_value_error_without_iv(): - Cipher(b"AES-512-CFB", b"", MODE_CFB, b"") - - -def setup_cipher(name): + with pytest.raises(ValueError): + Cipher(b"AES-512-CFB", b"", MODE_CFB, b"") + + +def module_from_name(name): + for cipher, mod in ( + (b"AES", mb.AES), + (b"ARC4", mb.ARC4), + (b"BLOWFISH", mb.Blowfish), + (b"CAMELLIA", mb.Camellia), + (b"DES-EDE3", mb.DES3), + (b"DES-EDE", mb.DES3dbl), + (b"DES", mb.DES)): + if name.startswith(cipher): + return mod + raise NotImplementedError + + +@pytest.fixture(params=(name for name in sorted(get_supported_ciphers()) + if not name.endswith(b"CCM"))) # Not compiled by default. +def cipher(request): + name = request.param cipher = Cipher(name, key=None, mode=None, iv=b"\x00") key = _rnd(cipher.key_size) iv = _rnd(cipher.iv_size) - block = _rnd(cipher.block_size) - if name.startswith(b"AES"): - mod = mb.AES - elif name.startswith(b"ARC4"): - mod = mb.ARC4 - elif name.startswith(b"BLOWFISH"): - mod = mb.Blowfish - elif name.startswith(b"CAMELLIA"): - mod = mb.Camellia - elif name.startswith(b"DES-EDE3"): - mod = mb.DES3 - elif name.startswith(b"DES-EDE"): - mod = mb.DES3dbl - elif name.startswith(b"DES"): - mod = mb.DES - else: - raise NotImplementedError - return mod, key, cipher.mode, iv, block - - -def get_ciphers(): - return (name for name in sorted(get_supported_ciphers()) - if not name.endswith(b"CCM")) # Not compiled by default. + return module_from_name(name).new(key, cipher.mode, iv) def is_streaming(cipher): return cipher.name.startswith(b"ARC") or cipher.mode is not MODE_ECB -def skip_test(message): - raise SkipTest(message) -skip_test.__test__ = False - - -def fail_test(message): - assert False, message -fail_test.__test__ = False - - -def check_encrypt_decrypt(cipher, block): - assert_equal(cipher.decrypt(cipher.encrypt(block)), block) - - -def test_encrypt_decrypt(): - for name in get_ciphers(): - description = "check_encrypt_decrypt(%s)" % name.decode() - mod, key, mode, iv, block = setup_cipher(name) - cipher = mod.new(key, mode, iv) - test = partial(check_encrypt_decrypt, cipher, block) - test.description = description - yield test +def test_encrypt_decrypt(cipher): + block = _rnd(cipher.block_size) + assert cipher.decrypt(cipher.encrypt(block)) == block -def test_module_level_block_size_variable(): - for name in get_ciphers(): - description = ("test_module_level_block_size_variable(%s)" % - name.decode()) - mod, key, mode, iv, block = setup_cipher(name) - cipher = mod.new(key, mode, iv) - test = partial(assert_equal, cipher.block_size, mod.block_size) - test.description = description - yield test +def test_module_level_block_size_variable(cipher): + mod = module_from_name(cipher.name) + assert cipher.block_size == mod.block_size -def test_module_level_key_size_variable(): - for name in get_ciphers(): - description = ("test_module_level_block_size_variable(%s)" % - name.decode()) - mod, key, mode, iv, block = setup_cipher(name) - if mod.key_size is None: - skip_test.description = description - yield skip_test, "module defines variable-length key" - continue - cipher = mod.new(key, mode, iv) - test = partial(assert_equal, cipher.key_size, mod.key_size) - test.description = description - yield test +def test_module_level_key_size_variable(cipher): + mod = module_from_name(cipher.name) + if mod.key_size is None: + pytest.skip("module defines variable-length key") + assert cipher.key_size == mod.key_size -def test_wrong_key_size_raises_invalid_key_size_error(): - for name in get_ciphers(): - description = "wrong_key_size_raises(%s)" % name.decode() - mod, key, mode, iv, block = setup_cipher(name) - if mod.key_size is None: - skip_test.description = description - yield skip_test, "module defines variable-length key" - continue - test = partial(assert_raises, InvalidKeyLengthError, - mod.new, key + b"\x00", mode, iv) - test.description = description - yield test +def test_wrong_key_size_raises_invalid_key_size_error(cipher): + mod = module_from_name(cipher.name) + if mod.key_size is None: + pytest.skip("module defines variable-length key") + with pytest.raises(InvalidKeyLengthError): + mod.new(_rnd(cipher.key_size) + b"\x00", cipher.mode, _rnd(cipher.iv_size)) -def test_check_against_pycrypto(): +def test_check_against_pycrypto(cipher): try: import Crypto.Cipher as pc # We must import the following to have them in scope. @@ -166,7 +121,7 @@ def test_check_against_pycrypto(): from Crypto.Cipher import DES3 # pylint: enable=unused-import except ImportError as exc: - raise SkipTest(str(exc)) + pytest.skip(str(exc)) pc_supported_modes = { MODE_ECB, @@ -175,51 +130,61 @@ def test_check_against_pycrypto(): MODE_CTR, } - def check_against_pycrypto(cipher, ref, block): - assert_equal(cipher.encrypt(block), ref.encrypt(block)) + mod = module_from_name(cipher.name) + if cipher.mode not in pc_supported_modes.difference( + {MODE_CTR, MODE_CFB}): + # Counter actually requires the counter. + pytest.skip("encryption mode unsupported") - for name in get_ciphers(): - description = "check_against_pycrypto(%s)" % name.decode() - mod, key, mode, iv, block = setup_cipher(name) - if mode not in pc_supported_modes.difference( - {MODE_CTR, MODE_CFB}): - # Counter actually requires the counter. - skip_test.description = description - yield skip_test, "encryption mode unsupported" - continue + key = _rnd(cipher.key_size) + iv = _rnd(cipher.iv_size) + block = _rnd(cipher.block_size) - ref_mod = { + cipher = mod.new(key, cipher.mode, iv) # A new cipher... + try: + ref = { mb.AES: pc.AES, mb.ARC4: pc.ARC4, mb.Blowfish: pc.Blowfish, mb.DES: pc.DES, mb.DES3: pc.DES3, - }.get(mod, None) - if ref_mod is None: - skip_test.description = description - yield skip_test, "%s not available in pyCrypto" % cipher - continue - cipher = mod.new(key, mode, iv) - ref = ref_mod.new(key, mode, iv) - - # Use partial to avoid late binding in report. - if cipher.mode is MODE_CBC: - # mbed TLS adds a block to CBC (probably due to padding) so - # that pyCrypto returns one block less. - test = partial(assert_equal, cipher.encrypt(block)[:len(block)], - ref.encrypt(block)) - else: - test = partial(assert_equal, cipher.encrypt(block), - ref.encrypt(block)) - test.description = description - yield test - - -def test_check_against_openssl(): + }[mod].new(key, cipher.mode, iv) + except KeyError: + pytest.skip("%s not available in pyCrypto" % cipher) + + # Use partial to avoid late binding in report. + if cipher.mode is MODE_CBC: + # mbed TLS adds a block to CBC (probably due to padding) so + # that pyCrypto returns one block less. + assert cipher.encrypt(block)[:len(block)] == ref.encrypt(block) + else: + assert cipher.encrypt(block) == ref.encrypt(block) + + +def test_check_against_openssl(cipher): from binascii import hexlify from subprocess import PIPE, Popen - CIPHER_LOOKUP = { + if cipher.mode is MODE_GCM: + pytest.skip("encryption mode unsupported") + + if cipher.name in { + b"ARC4-128", b"DES-EDE3-ECB", b"DES-EDE-ECB", + b"CAMELLIA-256-ECB", + b"CAMELLIA-128-CTR", b"CAMELLIA-192-CTR", + b"CAMELLIA-256-CTR", + b"BLOWFISH-CTR", + }: + pytest.skip("not available in openssl") + + key = _rnd(cipher.key_size) + iv = _rnd(cipher.iv_size) + block = _rnd(cipher.block_size) + + # A new cipher... + cipher = module_from_name(cipher.name).new(key, cipher.mode, iv) + + openssl_cipher = { b"AES-128-CFB128": "aes-128-cfb", b"AES-192-CFB128": "aes-192-cfb", b"AES-256-CFB128": "aes-256-cfb", @@ -230,74 +195,41 @@ def test_check_against_openssl(): b"BLOWFISH-CBC": "bf-cbc", b"BLOWFISH-CFB64": "bf-cfb", b"ARC4-128": "rc4", - } + }.get(cipher.name, cipher.name.decode("ascii").lower()) + + cmd = ["openssl", "enc", "-%s" % openssl_cipher, "-nosalt", + "-K", hexlify(key).decode("ascii")] + if cipher.mode is MODE_ECB: + cmd.append("-nopad") + else: + cmd.extend(["-iv", hexlify(iv).decode("ascii")]) + openssl = Popen(cmd, stdin=PIPE, stdout=PIPE, stderr=PIPE) + out, err = openssl.communicate(input=block) + if err: + pytest.fail(":".join((str(cipher), openssl_cipher, + err.decode().splitlines()[0]))) + else: + assert cipher.encrypt(block) == out + + +def test_streaming_ciphers(cipher): + if not is_streaming(cipher): + pytest.skip("not a streaming cipher") + block = _rnd(20000) + assert cipher.decrypt(cipher.encrypt(block)) == block + + +def test_fixed_block_size_ciphers_long_block_raise_ciphererror(cipher): + if is_streaming(cipher): + pytest.skip("streaming cipher") + with pytest.raises(CipherError): + block = _rnd(cipher.block_size) + _rnd(1) + cipher.encrypt(block) + - for name in get_ciphers(): - description = "check_against_openssl(%s)" % name.decode() - mod, key, mode, iv, block = setup_cipher(name) - if mode == MODE_GCM: - skip_test.description = description - yield skip_test, "encryption mode unsupported" - continue - if name in {b"ARC4-128", b"DES-EDE3-ECB", b"DES-EDE-ECB", - b"CAMELLIA-256-ECB", - b"CAMELLIA-128-CTR", b"CAMELLIA-192-CTR", - b"CAMELLIA-256-CTR", - b"BLOWFISH-CTR", - }: - yield skip_test, "%s not available in openssl" % name - continue - - cipher = mod.new(key, mode, iv) - openssl_cipher = CIPHER_LOOKUP.get( - cipher.name, cipher.name.decode("ascii").lower()) - - cmd = ["openssl", "enc", "-%s" % openssl_cipher, "-nosalt", - "-K", hexlify(key).decode("ascii")] - if cipher.mode is MODE_ECB: - cmd.append("-nopad") - else: - cmd.extend(["-iv", hexlify(iv).decode("ascii")]) - openssl = Popen(cmd, stdin=PIPE, stdout=PIPE, stderr=PIPE) - out, err = openssl.communicate(input=block) - if err: - test = partial(fail_test, ":".join((str(cipher), openssl_cipher, - err.decode().splitlines()[0]))) - else: - test = partial(assert_equal, cipher.encrypt(block), out) - test.description = description - yield test - - -def test_streaming_ciphers(): - for name in get_ciphers(): - description = "check_stream_cipher(%s)" % name.decode() - mod, key, mode, iv, block = setup_cipher(name) - cipher = mod.new(key, mode, iv) - if is_streaming(cipher): - block = _rnd(20000) - check_encrypt_decrypt.description = description - yield check_encrypt_decrypt, cipher, block - - -def test_fixed_block_size_ciphers(): - - def check_encrypt_raises(cipher, block, exc): - with assert_raises(exc): - cipher.encrypt(block) - - for name in get_ciphers(): - mod, key, mode, iv, block = setup_cipher(name) - cipher = mod.new(key, mode, iv) - if not is_streaming(cipher): - description = "long_block_raises(%s)" % name.decode() - test = partial(check_encrypt_raises, cipher, block + _rnd(1), - CipherError) - test.description = description - yield test - - description = "short_block_raises(%s)" % name.decode() - test = partial(check_encrypt_raises, cipher, block[1:], - CipherError) - test.description = description - yield test +def test_fixed_block_size_ciphers_short_block_raise_ciphererror(cipher): + if is_streaming(cipher): + pytest.skip("streaming cipher") + with pytest.raises(CipherError): + block = _rnd(cipher.block_size)[1:] + cipher.encrypt(block) diff --git a/tests/test_md.py b/tests/test_md.py index b07e3733..3c5cb92c 100644 --- a/tests/test_md.py +++ b/tests/test_md.py @@ -7,9 +7,9 @@ from functools import partial import hashlib import hmac +import inspect -from nose.plugins.skip import SkipTest -from nose.tools import assert_equal, assert_greater_equal, assert_less +import pytest # pylint: disable=import-error from mbedtls._md import MD_NAME @@ -20,6 +20,12 @@ from . import _rnd +@pytest.fixture(params=md_hash.algorithms_available) +def algorithm(request): + name = request.param + return md_hash.new(name) + + def make_chunks(buffer, size): for i in range(0, len(buffer), size): yield buffer[i:i+size] @@ -27,8 +33,7 @@ def make_chunks(buffer, size): def test_make_chunks(): buffer = _rnd(1024) - assert_equal(b"".join(buf for buf in make_chunks(buffer, 100)), - buffer) + assert b"".join(buf for buf in make_chunks(buffer, 100)) == buffer def test_md_list(): @@ -40,142 +45,99 @@ def test_algorithms(): md_hash.algorithms_available) -def test_type_accessor(): - def assert_in_bounds(value, lower, higher): - assert_greater_equal(value, lower) - assert_less(value, higher) - - for name in md_hash.algorithms_available: - alg = md_hash.new(name) - # pylint: disable=protected-access - test = partial(assert_in_bounds, alg._type, 0, len(MD_NAME)) - test.description = "test_type_accessor(%s)" % name - yield test - - -def test_copy_hash(): - for name in md_hash.algorithms_available: - buf0 = _rnd(512) - buf1 = _rnd(512) - alg = md_hash.new(name, buf0) - copy = alg.copy() - alg.update(buf1) - copy.update(buf1) - # Use partial to have the correct name in failed reports (by - # avoiding late bindings). - test = partial(assert_equal, alg.digest(), copy.digest()) - test.description = "test_copy_hash(%s)" % name - yield test - - -def test_check_hexdigest_against_hashlib(): - for name in md_hash.algorithms_available: - buf = _rnd(1024) - try: - alg = md_hash.new(name, buf) - ref = hashlib.new(name, buf) - except ValueError as exc: - # Unsupported hash type. - raise SkipTest(str(exc)) from exc - test = partial(assert_equal, alg.hexdigest(), ref.hexdigest()) - test.description = "check_hexdigest_against_hashlib(%s)" % name - yield test - - -def test_check_against_hashlib_nobuf(): - for name in md_hash.algorithms_available: - buf = _rnd(1024) - try: - alg = md_hash.new(name, buf) - ref = hashlib.new(name, buf) - except ValueError as exc: - # Unsupported hash type. - raise SkipTest(str(exc)) from exc - test = partial(assert_equal, alg.digest(), ref.digest()) - test.description = "check_against_hashlib_nobuf(%s)" % name - yield test - - -def test_check_against_hashlib_buf(): - for name in md_hash.algorithms_available: - buf = _rnd(4096) - try: - alg = md_hash.new(name) - ref = hashlib.new(name) - except ValueError as exc: - # Unsupported hash type. - raise SkipTest(str(exc)) from exc - for chunk in make_chunks(buf, 500): - alg.update(chunk) - ref.update(chunk) - test = partial(assert_equal, alg.digest(), ref.digest()) - test.description = "check_against_hashlib_buf(%s)" % name - yield test - - -def test_check_against_hmac_nobuf(): - for name in md_hmac.algorithms_available: - buf = _rnd(1024) - key = _rnd(16) - try: - alg = md_hmac.new(key, buf, digestmod=name) - ref = hmac.new(key, buf, digestmod=name) - except ValueError as exc: - # Unsupported hash type. - raise SkipTest(str(exc)) from exc - # Use partial to have the correct name in failed reports (by - # avoiding late bindings). - test = partial(assert_equal, alg.digest(), ref.digest()) - test.description = "check_against_hmac_nobuf(%s)" % name - yield test - - -def test_check_against_hmac_buf(): - for name in md_hmac.algorithms_available: - buf = _rnd(4096) - key = _rnd(16) - try: - alg = md_hmac.new(key, digestmod=name) - ref = hmac.new(key, digestmod=name) - except ValueError as exc: - # Unsupported hash type. - raise SkipTest(str(exc)) from exc - for chunk in make_chunks(buf, 500): - alg.update(chunk) - ref.update(chunk) - test = partial(assert_equal, alg.digest(), ref.digest()) - test.description = "check_against_hmac_buf(%s)" % name - yield test - - -def test_hash_instantiation(): - import inspect - - def check_instantiation(fun, name): - alg1 = fun() - alg2 = md_hash.new(name) - assert_equal(type(alg1), type(alg2)) - assert_equal(alg1.name, alg2.name) - - for name, member in inspect.getmembers(md_hash): - if name in md_hash.algorithms_available: - test = partial(check_instantiation, member, name) - test.description = "check_hash_instantiation(%s)" % name - yield test - - -def test_hmac_instantiation(): - import inspect - - def check_instantiation(fun, name): - key = _rnd(16) - alg1 = fun(key) - alg2 = md_hmac.new(key, digestmod=name) - assert_equal(type(alg1), type(alg2)) - assert_equal(alg1.name, alg2.name) - - for name, member in inspect.getmembers(md_hmac): - if name in md_hmac.algorithms_available: - test = partial(check_instantiation, member, name) - test.description = "check_hmac_instantiation(%s)" % name - yield test +def test_type_accessor(algorithm): + # pylint: disable=protected-access + assert 0 <= algorithm._type < len(MD_NAME) + + +def test_copy_hash(algorithm): + buf0 = _rnd(512) + buf1 = _rnd(512) + copy = algorithm.copy() + algorithm.update(buf1) + copy.update(buf1) + assert algorithm.digest() == copy.digest() + + +def test_check_hexdigest_against_hashlib(algorithm): + buf = _rnd(1024) + try: + alg = md_hash.new(algorithm.name, buf) + ref = hashlib.new(algorithm.name, buf) + except ValueError as exc: + # Unsupported hash type. + pytest.skip(str(exc)) + assert alg.hexdigest() == ref.hexdigest() + + +def test_check_against_hashlib_nobuf(algorithm): + buf = _rnd(1024) + try: + alg = md_hash.new(algorithm.name, buf) + ref = hashlib.new(algorithm.name, buf) + except ValueError as exc: + # Unsupported hash type. + pytest.skip(str(exc)) + assert alg.digest() == ref.digest() + + +def test_check_against_hashlib_buf(algorithm): + buf = _rnd(4096) + try: + alg = md_hash.new(algorithm.name) + ref = hashlib.new(algorithm.name) + except ValueError as exc: + # Unsupported hash type. + pytest.skip(str(exc)) + for chunk in make_chunks(buf, 500): + alg.update(chunk) + ref.update(chunk) + assert alg.digest() == ref.digest() + + +def test_check_against_hmac_nobuf(algorithm): + buf = _rnd(1024) + key = _rnd(16) + try: + alg = md_hmac.new(key, buf, digestmod=algorithm.name) + ref = hmac.new(key, buf, digestmod=algorithm.name) + except ValueError as exc: + # Unsupported hash type. + pytest.skip(str(exc)) + assert alg.digest() == ref.digest() + + +def test_check_against_hmac_buf(algorithm): + buf = _rnd(4096) + key = _rnd(16) + try: + alg = md_hmac.new(key, digestmod=algorithm.name) + ref = hmac.new(key, digestmod=algorithm.name) + except ValueError as exc: + # Unsupported hash type. + pytest.skip(str(exc)) + for chunk in make_chunks(buf, 500): + alg.update(chunk) + ref.update(chunk) + assert alg.digest() == ref.digest() + + +@pytest.mark.parametrize("name, algcls", inspect.getmembers(md_hash)) +def test_hash_instantiation(name, algcls): + if name not in md_hash.algorithms_available: + pytest.skip("not a hash algorithm") + alg1 = algcls() + alg2 = md_hash.new(name) + assert type(alg1) is type(alg2) + assert alg1.name == alg2.name + + +@pytest.mark.parametrize("name, algcls", inspect.getmembers(md_hmac)) +def test_hmac_instantiation(name, algcls): + if name not in md_hash.algorithms_available: + pytest.skip("not an hmac algorithm") + key = _rnd(16) + alg1 = algcls(key) + alg2 = md_hmac.new(key, digestmod=name) + assert type(alg1) is type(alg2) + assert alg1.name == alg2.name diff --git a/tests/test_pk.py b/tests/test_pk.py index 28d5564b..49f4266b 100644 --- a/tests/test_pk.py +++ b/tests/test_pk.py @@ -4,14 +4,9 @@ from functools import partial from tempfile import TemporaryFile -from nose.plugins.skip import SkipTest -from nose.tools import (assert_equal, assert_is_instance, - assert_true, assert_false, - assert_is_none, assert_is_not_none, - raises, - ) - -import mbedtls.hash as hash +import pytest + +import mbedtls.hash as _hash from mbedtls.exceptions import * from mbedtls.exceptions import _ErrorBase from mbedtls.pk._pk import _type_from_name, _get_md_alg @@ -20,9 +15,19 @@ from . import _rnd -def fail_test(message): - assert False, message -fail_test.__test__ = False +@pytest.fixture(params=(name for name in sorted(get_supported_ciphers()) + if name != b"NONE")) +def cipher(request): + name = request.param + return CipherBase(name) + + +@pytest.fixture(params=(1024, 2048, 4096)) +def rsa(request): + key_size = request.param + cipher = RSA() + cipher.generate(key_size) + return cipher def test_cipher_list(): @@ -35,231 +40,163 @@ def test_get_supported_ciphers(): def test_type_from_name(): - assert_equal( - tuple(_type_from_name(name) for name in CIPHER_NAME), - tuple(range(len(CIPHER_NAME)))) + assert tuple(_type_from_name(name) + for name in CIPHER_NAME) == tuple(range(len(CIPHER_NAME))) -def get_ciphers(): - return (name for name in sorted(get_supported_ciphers()) - if name not in {b"NONE"}) +def test_type_accessor(cipher): + assert cipher._type == _type_from_name(cipher.name) -def test_type_accessor(): - for name in get_ciphers(): - description = "test_type_accessor(%s)" % name - cipher = CipherBase(name) - test = partial(assert_equal, cipher._type, _type_from_name(name)) - test.description = description - yield test +def test_key_size_accessor(cipher): + assert cipher.key_size == 0 -def test_name_accessor(): - for name in get_ciphers(): - description = "test_name_accessor(%s)" % name - cipher = CipherBase(name) - test = partial(assert_equal, cipher.name, name) - test.description = description - yield test +@pytest.mark.parametrize( + "algorithm", (_get_md_alg(name) for name in _hash.algorithms_available)) +def test_digestmod(algorithm): + assert isinstance(algorithm(), _hash.Hash) -def test_key_size_accessor(): - for name in get_ciphers(): - description = "test_key_size_accessor(%s)" % name - cipher = CipherBase(name) - test = partial(assert_equal, cipher.key_size, 0) - test.description = description - yield test +@pytest.mark.parametrize( + "md_algorithm", (vars(_hash)[name] for name in _hash.algorithms_available)) +def test_digestmod_from_ctor(md_algorithm): + assert callable(md_algorithm) + algorithm = _get_md_alg(md_algorithm) + assert isinstance(algorithm(), _hash.Hash) -def test_digestmod(): - for name in hash.algorithms_available: - alg = _get_md_alg(name) - test = partial(assert_is_instance, alg(), hash.Hash) - test.description = "test_digestmod_from_string(%s)" % name - yield test +def test_rsa_encrypt_decrypt(rsa): + msg = _rnd(rsa.key_size - 11) + assert rsa.decrypt(rsa.encrypt(msg)) == msg -def test_digestmod_from_ctor(): - for name in hash.algorithms_available: - md_alg = vars(hash)[name] - assert callable(md_alg) - alg = _get_md_alg(md_alg) - test = partial(assert_is_instance, alg(), hash.Hash) - test.description = "test_digestmod_from_ctor(%s)" % name - yield test +def test_rsa_sign_without_key_returns_none(): + rsa = RSA() + message = _rnd(4096) + assert rsa.sign(message, _hash.md5) is None -def test_rsa_encrypt_decrypt(): - for key_size in (1024, 2048, 4096): - cipher = RSA() - cipher.generate(key_size) - msg = _rnd(cipher.key_size - 11) - enc = cipher.encrypt(msg) - dec = cipher.decrypt(enc) - test = partial(assert_equal, dec, msg) - test.description = "test_encrypt_decrypt(%s:%s)" % ("RSA", key_size) - yield test +def test_rsa_check_pair(rsa): + assert check_pair(rsa, rsa) is True -def test_rsa_sign_without_key_returns_none(): +def test_rsa_has_private_and_has_public_with_private_key(rsa): cipher = RSA() - message = _rnd(4096) - assert_is_none(cipher.sign(message, hash.md5)) - - -class _TestRsaBase: - - def setup(self): - key_size = 2048 - self.cipher = RSA() - self.cipher.generate(key_size) - - -class TestRsa(_TestRsaBase): - - def test_keypair(self): - assert_true(check_pair(self.cipher, self.cipher)) - - def test_has_private_and_has_public_with_private_key(self): - cipher = RSA() - assert_false(cipher.has_private()) - assert_false(cipher.has_public()) - - cipher.import_(self.cipher._write_private_key_der()) - assert_true(cipher.has_private()) - assert_true(cipher.has_public()) - - def test_has_private_and_has_public_with_public_key(self): - cipher = RSA() - assert_false(cipher.has_private()) - assert_false(cipher.has_public()) - - cipher.import_(self.cipher._write_public_key_der()) - assert_false(cipher.has_private()) - assert_true(cipher.has_public()) - - -class TestRsaWriteParse(_TestRsaBase): - - def test_write_and_parse_private_key_der(self): - prv = self.cipher._write_private_key_der() - cipher = RSA() - cipher._parse_private_key(prv) - assert_true(check_pair(self.cipher, cipher)) # Test private half. - assert_true(check_pair(cipher, self.cipher)) # Test public half. - assert_true(check_pair(cipher, cipher)) - - def test_write_and_parse_private_key_pem(self): - prv = self.cipher._write_private_key_pem() - cipher = RSA() - cipher._parse_private_key(prv) - assert_true(check_pair(self.cipher, cipher)) # Test private half. - assert_true(check_pair(cipher, self.cipher)) # Test public half. - assert_true(check_pair(cipher, cipher)) - - def test_write_and_parse_public_key_der(self): - pub = self.cipher._write_public_key_der() - cipher = RSA() - cipher._parse_public_key(pub) - assert_false(check_pair(self.cipher, cipher)) # Test private half. - assert_true(check_pair(cipher, self.cipher)) # Test public half. - assert_false(check_pair(cipher, cipher)) - - def test_write_and_parse_public_key_pem(self): - pub = self.cipher._write_public_key_pem() - cipher = RSA() - cipher._parse_public_key(pub) - assert_false(check_pair(self.cipher, cipher)) # Test private half. - assert_true(check_pair(cipher, self.cipher)) # Test public half. - assert_false(check_pair(cipher, cipher)) - - @raises(PkError) - def test_write_public_der_in_private_raises(self): - pub = self.cipher._write_public_key_der() - cipher = RSA() + assert cipher.has_private() is False + assert cipher.has_public() is False + + cipher.import_(rsa._write_private_key_der()) + assert cipher.has_private() is True + assert cipher.has_public() is True + + +def test_rsa_has_private_and_has_public_with_public_key(rsa): + cipher = RSA() + assert cipher.has_private() is False + assert cipher.has_public() is False + + cipher.import_(rsa._write_public_key_der()) + assert cipher.has_private() is False + assert cipher.has_public() is True + + +def test_rsa_write_and_parse_private_key_der(rsa): + prv = rsa._write_private_key_der() + cipher = RSA() + cipher._parse_private_key(prv) + assert check_pair(rsa, cipher) is True # Test private half. + assert check_pair(cipher, rsa) is True # Test public half. + assert check_pair(cipher, cipher) is True + + +def test_rsa_write_and_parse_private_key_pem(rsa): + prv = rsa._write_private_key_pem() + cipher = RSA() + cipher._parse_private_key(prv) + assert check_pair(rsa, cipher) is True # Test private half. + assert check_pair(cipher, rsa) is True # Test public half. + assert check_pair(cipher, cipher) is True + + +def test_rsa_write_and_parse_public_key_der(rsa): + pub = rsa._write_public_key_der() + cipher = RSA() + cipher._parse_public_key(pub) + assert check_pair(rsa, cipher) is False # Test private half. + assert check_pair(cipher, rsa) is True # Test public half. + assert check_pair(cipher, cipher) is False + + +def test_rsa_write_and_parse_public_key_pem(rsa): + pub = rsa._write_public_key_pem() + cipher = RSA() + cipher._parse_public_key(pub) + assert check_pair(rsa, cipher) is False # Test private half. + assert check_pair(cipher, rsa) is True # Test public half. + assert check_pair(cipher, cipher) is False + + +def test_rsa_write_public_der_in_private_raises(rsa): + pub = rsa._write_public_key_der() + cipher = RSA() + with pytest.raises(PkError): cipher._parse_private_key(pub) - @raises(_ErrorBase) - def test_write_private_der_in_public_raises(self): - prv = self.cipher._write_private_key_der() - cipher = RSA() + +def test_rsa_write_private_der_in_public_raises(rsa): + prv = rsa._write_private_key_der() + cipher = RSA() + with pytest.raises(_ErrorBase): cipher._parse_public_key(prv) -class TestRsaImportExport(_TestRsaBase): - - def test_import_public_key(self): - cipher = RSA() - cipher.import_(self.cipher._write_public_key_der()) - assert_false(check_pair(self.cipher, cipher)) # Test private half. - assert_true(check_pair(cipher, self.cipher)) # Test public half. - assert_false(check_pair(cipher, cipher)) - - def test_import_private_key(self): - cipher = RSA() - cipher.import_(self.cipher._write_private_key_der()) - assert_true(check_pair(self.cipher, cipher)) # Test private half. - assert_true(check_pair(cipher, self.cipher)) # Test public half. - assert_true(check_pair(cipher, cipher)) - - def test_export_private_key_pem(self): - cipher = RSA() - prv, pub = self.cipher.export(format="PEM") - cipher.import_(prv) - assert_true(cipher.has_private()) - assert_true(cipher.has_public()) - assert_true(check_pair(self.cipher, cipher)) # Test private half. - assert_true(check_pair(cipher, self.cipher)) # Test public half. - assert_true(check_pair(cipher, cipher)) - - def test_export_private_key_der(self): - cipher = RSA() - prv, pub = self.cipher.export(format="DER") - cipher.import_(prv) - assert_true(cipher.has_private()) - assert_true(cipher.has_public()) - assert_true(check_pair(self.cipher, cipher)) # Test private half. - assert_true(check_pair(cipher, self.cipher)) # Test public half. - assert_true(check_pair(cipher, cipher)) - - def test_export_private_key_to_file_pem(self): - cipher = RSA() - with TemporaryFile() as prv: - prv.write(self.cipher.export(format="PEM")[0]) - prv.seek(0) - cipher.import_(prv.read()) - assert_true(cipher.has_private()) - assert_true(cipher.has_public()) - assert_true(check_pair(self.cipher, cipher)) # Test private half. - assert_true(check_pair(cipher, self.cipher)) # Test public half. - assert_true(check_pair(cipher, cipher)) - - def test_export_private_key_to_file_der(self): - cipher = RSA() - with TemporaryFile() as prv: - prv.write(self.cipher.export(format="DER")[0]) - prv.seek(0) - cipher.import_(prv.read()) - assert_true(cipher.has_private()) - assert_true(cipher.has_public()) - assert_true(check_pair(self.cipher, cipher)) # Test private half. - assert_true(check_pair(cipher, self.cipher)) # Test public half. - assert_true(check_pair(cipher, cipher)) - - -class TestRsaSignature(_TestRsaBase): - - def test_sign_verify(self): - message = _rnd(4096) - sig = self.cipher.sign(message, hash.md5) - assert_is_not_none(sig) - assert_true(self.cipher.verify(message, sig, hash.md5)) - assert_false(self.cipher.verify(message + b"\0", sig, hash.md5)) - - def test_sign_verify_default_digestmod(self): - message = _rnd(4096) - sig = self.cipher.sign(message) - assert_is_not_none(sig) - assert_true(self.cipher.verify(message, sig)) - assert_false(self.cipher.verify(message + b"\0", sig)) +def test_rsa_import_public_key(rsa): + cipher = RSA() + cipher.import_(rsa._write_public_key_der()) + assert check_pair(rsa, cipher) is False # Test private half. + assert check_pair(cipher, rsa) is True # Test public half. + assert check_pair(cipher, cipher) is False + + +def test_rsa_import_private_key(rsa): + cipher = RSA() + cipher.import_(rsa._write_private_key_der()) + assert check_pair(rsa, cipher) is True # Test private half. + assert check_pair(cipher, rsa) is True # Test public half. + assert check_pair(cipher, cipher) is True + + +@pytest.mark.parametrize("format", ("PEM", "DER")) +def test_rsa_export_private_key(rsa, format): + cipher = RSA() + prv, pub = rsa.export(format=format) + cipher.import_(prv) + assert cipher.has_private() is True + assert cipher.has_public() is True + assert check_pair(rsa, cipher) is True # Test private half. + assert check_pair(cipher, rsa) is True # Test public half. + assert check_pair(cipher, cipher) is True + + +@pytest.mark.parametrize("format", ("PEM", "DER")) +def test_rsa_export_private_key_to_file(tmpdir, rsa, format): + prv = tmpdir.join("key.prv") + prv.write_binary(rsa.export(format=format)[0]) + + cipher = RSA() + cipher.import_(prv.read_binary()) + assert cipher.has_private() is True + assert cipher.has_public() is True + assert check_pair(rsa, cipher) is True # Test private half. + assert check_pair(cipher, rsa) is True # Test public half. + assert check_pair(cipher, cipher) is True + + +@pytest.mark.parametrize("digestmod", (_hash.md5, None)) +def test_rsa_sign_verify(rsa, digestmod): + message = _rnd(4096) + sig = rsa.sign(message, digestmod) + assert sig is not None + assert rsa.verify(message, sig, digestmod) is True + assert rsa.verify(message + b"\0", sig, digestmod) is False diff --git a/tests/test_random.py b/tests/test_random.py index fb090507..96455ccd 100644 --- a/tests/test_random.py +++ b/tests/test_random.py @@ -5,74 +5,77 @@ # pylint: disable=import-error import mbedtls.random as _drbg # pylint: enable=import-error -from nose.tools import assert_equal, assert_not_equal, raises + +import pytest + from mbedtls.exceptions import EntropyError from . import _rnd -def assert_length(collection, length): - assert_equal(len(collection), length) -assert_length.__test__ = False +@pytest.fixture +def entropy(): + return _drbg.Entropy() + + +@pytest.fixture +def random(): + return _drbg.Random() + + +def test_entropy_gather(entropy): + # Only test that this does not raise. + entropy.gather() + + +@pytest.mark.parametrize("length", range(64)) +def test_entropy_retrieve(entropy, length): + assert len(entropy.retrieve(length)) == length + + +@pytest.mark.parametrize("length", (100, )) +def test_entropy_retrieve_long_block_raises_entropyerror(entropy, length): + with pytest.raises(EntropyError): + entropy.retrieve(length) -class TestEntropy: +def test_entropy_update(entropy): + # Only test that this does not raise. + buf = _rnd(64) + entropy.update(buf) - def setup(self): - # pylint: disable=attribute-defined-outside-init - # pylint: disable=invalid-name - self.s = _drbg.Entropy() - def test_gather(self): - # Only test that this does not raise. - self.s.gather() +def test_entropy_not_reproducible(entropy): + assert entropy.retrieve(8) != entropy.retrieve(8) - def test_retrieve(self): - for length in range(64): - assert_length(self.s.retrieve(length), length) - @raises(EntropyError) - def test_retrieve_long_block_raises(self): - self.s.retrieve(100) +def test_entropy_random_initial_values(entropy): + # pylint: disable=invalid-name + other = _drbg.Entropy() + assert entropy.retrieve(8) != other.retrieve(8) - def test_update(self): - # Only test that this does not raise. - buf = _rnd(64) - self.s.update(buf) - def test_not_reproducible(self): - assert_not_equal(self.s.retrieve(8), self.s.retrieve(8)) +def test_reseed(random): + random.reseed() - def test_random_initial_values(self): - # pylint: disable=invalid-name - s = _drbg.Entropy() - assert_not_equal(self.s.retrieve(8), s.retrieve(8)) +def test_not_reproducible(random): + assert random.token_bytes(8) != random.token_bytes(8) -class TestRandom: - def setup(self): - # pylint: disable=attribute-defined-outside-init - self.rnd = _drbg.Random() +def test_update(random): + random.update(b"additional data") - def test_reseed(self): - self.rnd.reseed() - def test_not_reproducible(self): - assert_not_equal(self.rnd.token_bytes(8), - self.rnd.token_bytes(8)) +def test_initial_values(random): + other = _drbg.Random() + assert random.token_bytes(8) != other.token_bytes(8) - def test_update(self): - self.rnd.update(b"additional data") - def test_initial_values(self): - rnd = _drbg.Random() - assert_not_equal(self.rnd.token_bytes(8), - rnd.token_bytes(8)) +@pytest.mark.parametrize("length", range(1024)) +def test_token_bytes(random, length): + assert len(random.token_bytes(length)) == length - def test_token_bytes(self): - for length in range(1024): - assert_length(self.rnd.token_bytes(length), length) - def test_token_hex(self): - for length in range(1024): - assert_length(self.rnd.token_hex(length), 2 * length) +@pytest.mark.parametrize("length", range(1024)) +def test_token_hex(random, length): + assert len(random.token_hex(length)) == 2 * length