Skip to content

Commit

Permalink
tests: Reorganize and bugfix TLS tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Synss committed Feb 19, 2020
1 parent 927fef3 commit c209e46
Showing 1 changed file with 105 additions and 82 deletions.
187 changes: 105 additions & 82 deletions tests/test_tls.py
Expand Up @@ -143,11 +143,15 @@ def test_timeout(self, cookie):
assert cookie.timeout == 1000


class _TestBaseConfiguration(Chain):
class _BaseConfiguration(Chain):
@pytest.fixture
def conf(self):
raise NotImplementedError

@pytest.fixture
def version(self):
raise NotImplementedError

@pytest.mark.parametrize("validate", [True, False])
def test_set_validate_certificates(self, conf, validate):
conf_ = conf.update(validate_certificates=validate)
Expand Down Expand Up @@ -177,6 +181,14 @@ def test_set_inner_protocols(self, conf, inner_protocols):
NextProtocol(_) for _ in inner_protocols
)

def test_lowest_supported_version(self, conf, version):
conf_ = conf.update(lowest_supported_version=version)
assert conf_.lowest_supported_version is version

def test_highest_supported_version(self, conf, version):
conf_ = conf.update(highest_supported_version=version)
assert conf_.highest_supported_version is version

@pytest.mark.parametrize("store", [TrustStore.system()])
def test_trust_store(self, conf, store):
conf_ = conf.update(trust_store=store)
Expand All @@ -188,36 +200,24 @@ def test_set_sni_callback(self, conf, callback):
assert conf.sni_callback is None


class TestTLSConfiguration(_TestBaseConfiguration):
class TestTLSConfiguration(_BaseConfiguration):
@pytest.fixture
def conf(self):
return TLSConfiguration()

@pytest.mark.parametrize("version", TLSVersion)
def test_lowest_supported_version(self, conf, version):
conf_ = conf.update(lowest_supported_version=version)
assert conf_.lowest_supported_version is version

@pytest.mark.parametrize("version", TLSVersion)
def test_highest_supported_version(self, conf, version):
conf_ = conf.update(highest_supported_version=version)
assert conf_.highest_supported_version is version
@pytest.fixture(params=TLSVersion)
def version(self, request):
return request.param


class TestDTLSConfiguration(_TestBaseConfiguration):
class TestDTLSConfiguration(_BaseConfiguration):
@pytest.fixture
def conf(self):
return DTLSConfiguration()

@pytest.mark.parametrize("version", DTLSVersion)
def test_lowest_supported_version(self, conf, version):
conf_ = conf.update(lowest_supported_version=version)
assert conf_.lowest_supported_version is version

@pytest.mark.parametrize("version", DTLSVersion)
def test_highest_supported_version(self, conf, version):
conf_ = conf.update(highest_supported_version=version)
assert conf_.highest_supported_version is version
@pytest.fixture(params=DTLSVersion)
def version(self, request):
return request.param

@pytest.mark.parametrize("anti_replay", [True, False])
def test_set_anti_replay(self, conf, anti_replay):
Expand Down Expand Up @@ -289,13 +289,19 @@ def test_wrap_buffers(self, context):
assert isinstance(context.wrap_buffers(), TLSWrappedBuffer)


class _TestCommunicationBase(Chain):
CLOSE_MESSAGE = b"bye"
class _CommunicationBase(Chain):
@pytest.fixture(scope="class")
def proto(self):
raise NotImplementedError

@pytest.fixture(scope="class")
def version(self):
raise NotImplementedError

@pytest.fixture(scope="class")
def srv_hostname(self):
return "End Entity"

@pytest.fixture(scope="class")
def srv_conf(self):
raise NotImplementedError
Expand All @@ -304,10 +310,6 @@ def srv_conf(self):
def cli_conf(self):
raise NotImplementedError

@pytest.fixture
def step(self):
raise NotImplementedError

def echo(self, sock):
raise NotImplementedError

Expand All @@ -328,34 +330,31 @@ def trust_store(self, ca0_crt):
return store

@pytest.fixture
def server(self, srv_conf, version):
def server(self, srv_conf, version, proto):
ctx = ServerContext(srv_conf)
sock = ctx.wrap_socket(socket.socket(socket.AF_INET, self.proto))
sock = ctx.wrap_socket(socket.socket(socket.AF_INET, proto))
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.bind(("", 0))
if self.proto == socket.SOCK_STREAM:
if proto == socket.SOCK_STREAM:
sock.listen(1)

runner = mp.Process(target=self.echo, args=(sock,))
runner.start()
yield sock
runner.join(0.1)
runner.terminate()
runner.join()
with suppress(OSError):
sock.close()
runner.terminate()

@pytest.fixture
def client(self, server, cli_conf):
def client(self, server, srv_hostname, cli_conf, proto):
ctx = ClientContext(cli_conf)
sock = ctx.wrap_socket(
socket.socket(socket.AF_INET, self.proto),
server_hostname="End Entity",
socket.socket(socket.AF_INET, proto), server_hostname=srv_hostname,
)
sock.connect(server.getsockname())
block(sock.do_handshake)
yield sock
with suppress(OSError):
block(sock.send, self.CLOSE_MESSAGE)
with suppress(TLSError, OSError):
sock.close()

def test_srv_conf(self, srv_conf, ca1_crt, ee0_crt, ee0_key, trust_store):
Expand All @@ -369,8 +368,10 @@ def test_cli_conf(self, cli_conf, trust_store):
assert cli_conf.validate_certificates == True


class TestTLSCommunication(_TestCommunicationBase):
proto = socket.SOCK_STREAM
class _TLSCommunicationBase(_CommunicationBase):
@pytest.fixture(scope="class")
def proto(self):
return socket.SOCK_STREAM

@pytest.fixture(
scope="class",
Expand Down Expand Up @@ -400,42 +401,23 @@ def cli_conf(self, version, trust_store):
validate_certificates=True,
)

@pytest.fixture(params=[1, 1000, 5000])
def step(self, request):
return request.param

def echo(self, sock):
conn, addr = sock.accept()
block(conn.do_handshake)
try:
block(conn.do_handshake)
except TLSError:
conn.close()
return
while True:
data = block(conn.recv, 2 << 13)
if data == self.CLOSE_MESSAGE:
break

amt = block(conn.send, data)
assert amt == len(data)

def test_server_hostname_fails_verification(self, server, cli_conf):
ctx = ClientContext(cli_conf)
sock = ctx.wrap_socket(
socket.socket(socket.AF_INET, self.proto),
server_hostname="Wrong End Entity",
)
sock.connect(server.getsockname())
with pytest.raises(TLSError):
block(sock.do_handshake)

def test_client_server(self, client, buffer, step):
received = bytearray()
for idx in range(0, len(buffer), step):
view = memoryview(buffer[idx : idx + step])
amt = block(client.send, view)
assert amt == len(view)
assert block(client.recv, 2 << 13) == view
conn.close()


class TestDTLSCommunication(_TestCommunicationBase):
proto = socket.SOCK_DGRAM
class _DTLSCommunicationBase(_CommunicationBase):
@pytest.fixture(scope="class")
def proto(self):
return socket.SOCK_DGRAM

@pytest.fixture(scope="class", params=DTLSVersion)
def version(self, request):
Expand All @@ -448,7 +430,7 @@ def srv_conf(
return DTLSConfiguration(
trust_store=trust_store,
certificate_chain=([ee0_crt, ca1_crt], ee0_key),
lowest_supported_version=TLSVersion.MINIMUM_SUPPORTED,
lowest_supported_version=DTLSVersion.MINIMUM_SUPPORTED,
highest_supported_version=version,
validate_certificates=False,
)
Expand All @@ -457,30 +439,71 @@ def srv_conf(
def cli_conf(self, version, trust_store):
return DTLSConfiguration(
trust_store=trust_store,
lowest_supported_version=TLSVersion.MINIMUM_SUPPORTED,
lowest_supported_version=DTLSVersion.MINIMUM_SUPPORTED,
highest_supported_version=version,
validate_certificates=True,
)

@pytest.fixture(params=[10, 1000, 5000])
def step(self, request):
return request.param

def echo(self, sock):
cli, addr = sock.accept()
cli.setcookieparam(addr[0].encode("ascii"))
with pytest.raises(_tls.HelloVerifyRequest):
with pytest.raises(HelloVerifyRequest):
block(cli.do_handshake)

cli, addr = cli.accept()
_, (cli, addr) = cli, cli.accept()
_.close()
cli.setcookieparam(addr[0].encode("ascii"))
block(cli.do_handshake)
try:
block(cli.do_handshake)
except TLSError:
cli.close()
return
while True:
data = block(cli.recv, 4096)
if data == self.CLOSE_MESSAGE:
break

# We must use `send()` instead of `sendto()` because the
# DTLS socket is connected.
amt = block(cli.send, data)
assert amt == len(data)
cli.close()


class TestTLSHostNameVerificationFailure(_TLSCommunicationBase):
@pytest.fixture(scope="class")
def srv_hostname(self):
return "Wrong End Entity"

@pytest.mark.usefixtures("server")
def test_handshake_raises_tlserror(self, client):
with pytest.raises(TLSError):
block(client.do_handshake)


class TestTLSCommunication(_TLSCommunicationBase):
@pytest.fixture(params=[1000, 5000])
def step(self, request):
return request.param

@pytest.mark.usefixtures("server")
def test_client_server(self, client, buffer, step):
block(client.do_handshake)
received = bytearray()
for idx in range(0, len(buffer), step):
view = memoryview(buffer[idx : idx + step])
amt = block(client.send, view)
assert amt == len(view)
assert block(client.recv, 2 << 13) == view


class TestDTLSCommunication(_DTLSCommunicationBase):
@pytest.fixture(params=[10, 1000])
def step(self, request):
return request.param

@pytest.mark.usefixtures("server")
def test_client_server(self, client, buffer, step):
block(client.do_handshake)
received = bytearray()
for idx in range(0, len(buffer), step):
view = memoryview(buffer[idx : idx + step])
amt = block(client.send, view)
assert amt == len(view)
assert block(client.recv, 2 << 13) == view

0 comments on commit c209e46

Please sign in to comment.