Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

test: backwards compatibility test for the serialization feature #4548

Merged
merged 17 commits into from
May 20, 2024
7 changes: 6 additions & 1 deletion tests/integrationv2/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
import subprocess
import threading
import itertools

import random
import string

from constants import TEST_CERT_DIRECTORY
from global_flags import get_flag, S2N_PROVIDER_VERSION
Expand All @@ -29,6 +30,10 @@ def data_bytes(n_bytes):
return bytes(byte_array)


def random_str(n):
return "".join(random.choice(string.ascii_uppercase + string.digits) for _ in range(n))


def pq_enabled():
"""
Returns true or false to indicate whether PQ crypto is enabled in s2n
Expand Down
12 changes: 3 additions & 9 deletions tests/integrationv2/test_key_update.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,14 @@
import copy
import random
import string
import pytest

from configuration import available_ports, TLS13_CIPHERS
from common import ProviderOptions, Protocols
from common import ProviderOptions, Protocols, random_str
from fixtures import managed_process # lgtm [py/unused-import]
from providers import Provider, S2N, OpenSSL
from utils import invalid_test_parameters, get_parameter_name

SERVER_DATA = f"Some random data from the server:" + "".join(
random.choice(string.ascii_uppercase + string.digits) for _ in range(10)
)
CLIENT_DATA = f"Some random data from the client:" + "".join(
random.choice(string.ascii_uppercase + string.digits) for _ in range(10)
)
SERVER_DATA = f"Some random data from the server:" + random_str(10)
CLIENT_DATA = f"Some random data from the client:" + random_str(10)


def test_nothing():
Expand Down
151 changes: 151 additions & 0 deletions tests/integrationv2/test_serialization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
from enum import Enum, auto
Fixed Show fixed Hide fixed
import pytest
import copy
import os

from configuration import available_ports
from common import ProviderOptions, Protocols, random_str
from fixtures import managed_process # lgtm [py/unused-import]

Check notice

Code scanning / CodeQL

Unused import Note

Import of 'managed_process' is not used.
from providers import Provider, S2N
from utils import invalid_test_parameters, get_parameter_name, to_bytes

SERVER_STATE_FILE = 'server_state'
CLIENT_STATE_FILE = 'client_state'

SERVER_DATA = f"Some random data from the server:" + random_str(10)
CLIENT_DATA = f"Some random data from the client:" + random_str(10)


"""
This test file checks that a serialized connection can be deserialized by an older version of
s2n-tls and vice versa. This ensures that any future changes we make to the handshake are backwards-compatible
with an older version of s2n-tls.

This feature requires an uninterrupted TCP connection with the peer in-between serialization and
deserialization. Our integration test setup can't easily provide that while also using two different
s2n-tls versions. To get around that we do a hack and serialize/deserialize both peers in the TLS connection.
This prevents one peer from receiving a TCP FIN message and shutting the connection down early.
"""


@pytest.mark.uncollect_if(func=invalid_test_parameters)
@pytest.mark.parametrize("protocol", [Protocols.TLS13, Protocols.TLS12], ids=get_parameter_name)
@pytest.mark.parametrize("serialize_older_version", [True, False], ids=get_parameter_name)
def test_server_serialization_backwards_compat(managed_process, tmp_path, protocol, serialize_older_version):
server_state_file = str(tmp_path / SERVER_STATE_FILE)
client_state_file = str(tmp_path / CLIENT_STATE_FILE)
assert not os.path.exists(server_state_file)
assert not os.path.exists(client_state_file)

options = ProviderOptions(
port=next(available_ports),
protocol=protocol,
insecure=True,
)

client_options = copy.copy(options)
client_options.mode = Provider.ClientMode
client_options.extra_flags = ['--serialize-out', client_state_file]

server_options = copy.copy(options)
server_options.mode = Provider.ServerMode
server_options.extra_flags = ['--serialize-out', server_state_file]
server_options.use_mainline_version = serialize_older_version

server = managed_process(
S2N, server_options, send_marker=S2N.get_send_marker())
client = managed_process(S2N, client_options, send_marker=S2N.get_send_marker())

for results in client.get_results():
results.assert_success()
assert to_bytes("Actual protocol version: {}".format(protocol.value)) in results.stdout

for results in server.get_results():
results.assert_success()
assert to_bytes("Actual protocol version: {}".format(protocol.value)) in results.stdout

assert os.path.exists(server_state_file)
assert os.path.exists(client_state_file)

client_options.extra_flags = ['--deserialize-in', client_state_file]
server_options.extra_flags = ['--deserialize-in', server_state_file]
server_options.use_mainline_version = not serialize_older_version

server_options.data_to_send = SERVER_DATA.encode()
client_options.data_to_send = CLIENT_DATA.encode()

server = managed_process(S2N, server_options, send_marker=CLIENT_DATA)
client = managed_process(S2N, client_options, send_marker="Connected to localhost", close_marker=SERVER_DATA)

for results in server.get_results():
results.assert_success()
# No protocol version printout since deserialization means skipping the handshake
assert to_bytes("Actual protocol version:") not in results.stdout
assert CLIENT_DATA.encode() in results.stdout

for results in client.get_results():
results.assert_success()
assert to_bytes("Actual protocol version:") not in results.stdout
assert SERVER_DATA.encode() in results.stdout

maddeleine marked this conversation as resolved.
Show resolved Hide resolved

@pytest.mark.uncollect_if(func=invalid_test_parameters)
@pytest.mark.parametrize("protocol", [Protocols.TLS13, Protocols.TLS12], ids=get_parameter_name)
@pytest.mark.parametrize("serialize_older_version", [True, False], ids=get_parameter_name)
def test_client_serialization_backwards_compat(managed_process, tmp_path, protocol, serialize_older_version):
server_state_file = str(tmp_path / SERVER_STATE_FILE)
client_state_file = str(tmp_path / CLIENT_STATE_FILE)
assert not os.path.exists(server_state_file)
assert not os.path.exists(client_state_file)

options = ProviderOptions(
port=next(available_ports),
protocol=protocol,
insecure=True,
)

client_options = copy.copy(options)
client_options.mode = Provider.ClientMode
client_options.extra_flags = ['--serialize-out', client_state_file]
client_options.use_mainline_version = serialize_older_version

server_options = copy.copy(options)
server_options.mode = Provider.ServerMode
server_options.extra_flags = ['--serialize-out', server_state_file]

server = managed_process(
S2N, server_options, send_marker=S2N.get_send_marker())
client = managed_process(S2N, client_options, send_marker=S2N.get_send_marker())

for results in client.get_results():
results.assert_success()
assert to_bytes("Actual protocol version: {}".format(protocol.value)) in results.stdout

for results in server.get_results():
results.assert_success()
assert to_bytes("Actual protocol version: {}".format(protocol.value)) in results.stdout

assert os.path.exists(server_state_file)
assert os.path.exists(client_state_file)

client_options.extra_flags = ['--deserialize-in', client_state_file]
client_options.use_mainline_version = not serialize_older_version

server_options.extra_flags = ['--deserialize-in', server_state_file]

server_options.data_to_send = SERVER_DATA.encode()
client_options.data_to_send = CLIENT_DATA.encode()

server = managed_process(S2N, server_options, send_marker=CLIENT_DATA)
client = managed_process(S2N, client_options, send_marker="Connected to localhost", close_marker=SERVER_DATA)

for results in server.get_results():
results.assert_success()
# No protocol version printout since deserialization means skipping the handshake
assert to_bytes("Actual protocol version:") not in results.stdout
assert CLIENT_DATA.encode() in results.stdout

for results in client.get_results():
results.assert_success()
assert to_bytes("Actual protocol version:") not in results.stdout
assert SERVER_DATA.encode() in results.stdout
Loading