diff --git a/bindings/rust/bench/.cargo/config.toml b/bindings/rust/bench/.cargo/config.toml deleted file mode 100644 index 1165fb46d80..00000000000 --- a/bindings/rust/bench/.cargo/config.toml +++ /dev/null @@ -1,3 +0,0 @@ -[env] -S2N_TLS_LIB_DIR = "/home/ubuntu/s2n-tls/bindings/rust/bench/target/s2n-tls-build/lib" -LD_LIBRARY_PATH = "/home/ubuntu/s2n-tls/bindings/rust/bench/target/s2n-tls-build/lib" diff --git a/bindings/rust/bench/Cargo.toml b/bindings/rust/bench/Cargo.toml index c00b34335a1..7fb601b56e7 100644 --- a/bindings/rust/bench/Cargo.toml +++ b/bindings/rust/bench/Cargo.toml @@ -10,7 +10,7 @@ historical-perf = [] s2n-tls = { path = "../s2n-tls" } rustls = "0.21" rustls-pemfile = "1.0" -openssl = "0.10" +openssl = { version = "0.10", features = ["vendored"] } errno = "0.3" libc = "0.2" crabgrind = "0.1" @@ -19,9 +19,10 @@ rand_distr = "0.4" plotters = "0.3" serde_json = "1.0" semver = "1.0" +strum = { version = "0.25", features = ["derive"] } [dev-dependencies] -criterion = "0.3" +criterion = "0.5" [[bench]] name = "handshake" diff --git a/bindings/rust/bench/benches/handshake.rs b/bindings/rust/bench/benches/handshake.rs index 39c09a8df74..99eff7617c0 100644 --- a/bindings/rust/bench/benches/handshake.rs +++ b/bindings/rust/bench/benches/handshake.rs @@ -2,77 +2,72 @@ // SPDX-License-Identifier: Apache-2.0 use bench::{ - CipherSuite, CryptoConfig, - ECGroup::{self, *}, - HandshakeType::{self, *}, - OpenSslHarness, RustlsHarness, S2NHarness, - SigType::{self, *}, - TlsBenchHarness, + CipherSuite, CryptoConfig, HandshakeType, KXGroup, OpenSslConnection, RustlsConnection, + S2NConnection, SigType, TlsConnPair, TlsConnection, }; use criterion::{ criterion_group, criterion_main, measurement::WallTime, BatchSize, BenchmarkGroup, Criterion, }; +use strum::IntoEnumIterator; + +fn bench_handshake_for_library( + bench_group: &mut BenchmarkGroup, + handshake_type: HandshakeType, + kx_group: KXGroup, + sig_type: SigType, +) { + // generate all harnesses (TlsConnPair structs) beforehand so that benchmarks + // only include negotiation and not config/connection initialization + bench_group.bench_function(T::name(), |b| { + b.iter_batched_ref( + || { + TlsConnPair::::new( + CryptoConfig::new(CipherSuite::default(), kx_group, sig_type), + handshake_type, + Default::default(), + ) + }, + |conn_pair_res| { + // harnesses with certain parameters fail to initialize for + // some past versions of s2n-tls, but missing data can be + // visually interpolated in the historical performance graph + if let Ok(conn_pair) = conn_pair_res { + let _ = conn_pair.handshake(); + } + }, + BatchSize::SmallInput, + ) + }); +} pub fn bench_handshake_params(c: &mut Criterion) { - fn bench_handshake_for_library( - bench_group: &mut BenchmarkGroup, - name: &str, - handshake_type: HandshakeType, - ec_group: ECGroup, - sig_type: SigType, - ) { - // generate all harnesses (TlsBenchHarness structs) beforehand so that benchmarks - // only include negotiation and not config/connection initialization - bench_group.bench_function(name, |b| { - b.iter_batched_ref( - || { - T::new( - CryptoConfig::new(CipherSuite::default(), ec_group, sig_type), - handshake_type, - Default::default(), - ) - }, - |harness| { - // harnesses with certain parameters fail to initialize for - // some past versions of s2n-tls, but missing data can be - // visually interpolated in the historical performance graph - if let Ok(harness) = harness { - let _ = harness.handshake(); + for handshake_type in HandshakeType::iter() { + for kx_group in KXGroup::iter() { + for sig_type in SigType::iter() { + let mut bench_group = c.benchmark_group(match handshake_type { + HandshakeType::ServerAuth => format!("handshake-{:?}-{:?}", kx_group, sig_type), + HandshakeType::MutualAuth => { + format!("handshake-mTLS-{:?}-{:?}", kx_group, sig_type) } - }, - BatchSize::SmallInput, - ) - }); - } - - for handshake_type in [ServerAuth, MutualAuth] { - for ec_group in [SECP256R1, X25519] { - for sig_type in [Rsa2048, Rsa3072, Rsa4096, Ec384] { - let mut bench_group = c.benchmark_group(format!( - "handshake-{:?}-{:?}-{:?}", - handshake_type, ec_group, sig_type - )); - bench_handshake_for_library::( + }); + bench_handshake_for_library::( &mut bench_group, - "s2n-tls", handshake_type, - ec_group, + kx_group, sig_type, ); #[cfg(not(feature = "historical-perf"))] { - bench_handshake_for_library::( + bench_handshake_for_library::( &mut bench_group, - "rustls", handshake_type, - ec_group, + kx_group, sig_type, ); - bench_handshake_for_library::( + bench_handshake_for_library::( &mut bench_group, - "openssl", handshake_type, - ec_group, + kx_group, sig_type, ); } diff --git a/bindings/rust/bench/benches/throughput.rs b/bindings/rust/bench/benches/throughput.rs index 7ff9576b525..bbb2121f586 100644 --- a/bindings/rust/bench/benches/throughput.rs +++ b/bindings/rust/bench/benches/throughput.rs @@ -2,68 +2,64 @@ // SPDX-License-Identifier: Apache-2.0 use bench::{ - CipherSuite::{self, *}, - CryptoConfig, ECGroup, HandshakeType, OpenSslHarness, RustlsHarness, S2NHarness, SigType, - TlsBenchHarness, harness::ConnectedBuffer, + harness::ConnectedBuffer, CipherSuite, CryptoConfig, HandshakeType, KXGroup, OpenSslConnection, + RustlsConnection, S2NConnection, SigType, TlsConnPair, TlsConnection, }; use criterion::{ criterion_group, criterion_main, measurement::WallTime, BatchSize, BenchmarkGroup, Criterion, Throughput, }; +use strum::IntoEnumIterator; + +fn bench_throughput_for_library( + bench_group: &mut BenchmarkGroup, + shared_buf: &mut [u8], + cipher_suite: CipherSuite, +) { + bench_group.bench_function(T::name(), |b| { + b.iter_batched_ref( + || { + TlsConnPair::::new( + CryptoConfig::new(cipher_suite, KXGroup::default(), SigType::default()), + HandshakeType::default(), + ConnectedBuffer::default(), + ) + .map(|mut h| { + let _ = h.handshake(); + h + }) + }, + |conn_pair_res| { + if let Ok(conn_pair) = conn_pair_res { + let _ = conn_pair.round_trip_transfer(shared_buf); + } + }, + BatchSize::SmallInput, + ) + }); +} pub fn bench_throughput_cipher_suite(c: &mut Criterion) { // arbitrarily large to cut across TLS record boundaries let mut shared_buf = [0u8; 100000]; - fn bench_throughput_for_library( - bench_group: &mut BenchmarkGroup, - name: &str, - shared_buf: &mut [u8], - cipher_suite: CipherSuite, - ) { - bench_group.bench_function(name, |b| { - b.iter_batched_ref( - || { - T::new( - CryptoConfig::new(cipher_suite, ECGroup::default(), SigType::default()), - HandshakeType::default(), - ConnectedBuffer::default(), - ) - .map(|mut h| { - let _ = h.handshake(); - h - }) - }, - |harness| { - if let Ok(harness) = harness { - let _ = harness.round_trip_transfer(shared_buf); - } - }, - BatchSize::SmallInput, - ) - }); - } - - for cipher_suite in [AES_128_GCM_SHA256, AES_256_GCM_SHA384] { + for cipher_suite in CipherSuite::iter() { let mut bench_group = c.benchmark_group(format!("throughput-{:?}", cipher_suite)); bench_group.throughput(Throughput::Bytes(shared_buf.len() as u64)); - bench_throughput_for_library::( + bench_throughput_for_library::( &mut bench_group, - "s2n-tls", &mut shared_buf, cipher_suite, ); #[cfg(not(feature = "historical-perf"))] { - bench_throughput_for_library::( + bench_throughput_for_library::( &mut bench_group, - "rustls", &mut shared_buf, cipher_suite, ); - bench_throughput_for_library::( + bench_throughput_for_library::( &mut bench_group, - "openssl", &mut shared_buf, cipher_suite, ); diff --git a/bindings/rust/bench/certs/generate_certs.sh b/bindings/rust/bench/certs/generate_certs.sh index 11b68284e7b..36f6807c0c0 100755 --- a/bindings/rust/bench/certs/generate_certs.sh +++ b/bindings/rust/bench/certs/generate_certs.sh @@ -16,11 +16,13 @@ pushd "$(dirname "$0")" > /dev/null # Generates certs with given algorithms and bits in $1$2/, ex. ec384/ # $1: rsa or ec # $2: number of bits +# $3: directory under the `certs/` directory to put certs in cert-gen () { echo -e "\n----- generating certs for $1$2 -----\n" key_family=$1 key_size=$2 + dir_name=$3 # set openssl argument name if [[ $key_family == rsa ]]; then @@ -30,8 +32,8 @@ cert-gen () { fi # make directory for certs - mkdir -p $key_family$key_size - cd $key_family$key_size + mkdir -p $dir_name + cd $dir_name echo "generating CA private key and certificate" openssl req -new -nodes -x509 -newkey $key_family -pkeyopt $argname$key_size -keyout ca-key.pem -out ca-cert.pem -days 65536 -config ../config/ca.cnf @@ -62,13 +64,13 @@ cert-gen () { if [[ $1 != "clean" ]] then - cert-gen ec 384 - cert-gen rsa 2048 - cert-gen rsa 3072 - cert-gen rsa 4096 + cert-gen ec 384 ecdsa384 + cert-gen rsa 2048 rsa2048 + cert-gen rsa 3072 rsa3072 + cert-gen rsa 4096 rsa4096 else echo "cleaning certs" - rm -rf ec*/ rsa*/ + rm -rf ecdsa*/ rsa*/ fi popd > /dev/null diff --git a/bindings/rust/bench/src/bin/graph_memory.rs b/bindings/rust/bench/src/bin/graph_memory.rs index a4cb56bc8b1..20c0fd0789d 100644 --- a/bindings/rust/bench/src/bin/graph_memory.rs +++ b/bindings/rust/bench/src/bin/graph_memory.rs @@ -28,7 +28,7 @@ fn get_bytes_from_snapshot(name: &str, i: i32) -> i32 { } /// Get the difference in bytes between two snapshots, which is memory of the -/// `i`th TlsBenchHarness (client and server) +/// `i`th TlsConnPair (client and server) fn get_bytes_diff(name: &str, i: i32) -> i32 { get_bytes_from_snapshot(name, i + 1) - get_bytes_from_snapshot(name, i) } diff --git a/bindings/rust/bench/src/bin/memory.rs b/bindings/rust/bench/src/bin/memory.rs index e10aa35ff78..6a09b548b9e 100644 --- a/bindings/rust/bench/src/bin/memory.rs +++ b/bindings/rust/bench/src/bin/memory.rs @@ -1,18 +1,21 @@ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 -use bench::{harness::ConnectedBuffer, OpenSslHarness, RustlsHarness, S2NHarness, TlsBenchHarness}; +use bench::{ + harness::ConnectedBuffer, CryptoConfig, HandshakeType, OpenSslConnection, RustlsConnection, + S2NConnection, TlsConnPair, TlsConnection, +}; use std::{fs::create_dir_all, path::Path}; -fn memory_bench(dir_name: &str) { +fn memory_bench(dir_name: &str) { println!("testing {dir_name}"); if !Path::new(&format!("target/memory/{dir_name}")).is_dir() { create_dir_all(format!("target/memory/{dir_name}")).unwrap(); } - let mut harnesses = Vec::new(); - harnesses.reserve(100); + let mut conn_pairs = Vec::new(); + conn_pairs.reserve(100); // reserve space for buffers before benching let mut buffers = Vec::new(); @@ -22,27 +25,27 @@ fn memory_bench(dir_name: &str) { } // handshake one harness to initalize libraries - let mut harness = T::default().unwrap(); - harness.handshake().unwrap(); + let mut conn_pair = TlsConnPair::::default(); + conn_pair.handshake().unwrap(); // tell massif to take initial memory snapshot crabgrind::monitor_command(format!("snapshot target/memory/{dir_name}/0.snapshot")).unwrap(); - // make and handshake 100 harness + // make and handshake 100 connection pairs // memory usage stabilizes after first few handshakes for i in 1..101 { // put new harness directly into harness vec - harnesses.push( - T::new( - Default::default(), - Default::default(), + conn_pairs.push( + TlsConnPair::::new( + CryptoConfig::default(), + HandshakeType::default(), buffers.pop().unwrap(), // take ownership of buffer ) .unwrap(), ); // handshake last harness added - harnesses + conn_pairs .as_mut_slice() .last_mut() .unwrap() @@ -58,7 +61,7 @@ fn memory_bench(dir_name: &str) { fn main() { assert!(!cfg!(debug_assertions), "need to run in release mode"); - memory_bench::("s2n-tls"); - memory_bench::("rustls"); - memory_bench::("openssl"); + memory_bench::("s2n-tls"); + memory_bench::("rustls"); + memory_bench::("openssl"); } diff --git a/bindings/rust/bench/src/harness.rs b/bindings/rust/bench/src/harness.rs index 02756378596..3e0e9d6e1df 100644 --- a/bindings/rust/bench/src/harness.rs +++ b/bindings/rust/bench/src/harness.rs @@ -1,15 +1,71 @@ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 -use crate::{get_cert_path, PemType}; use std::{ cell::RefCell, collections::VecDeque, error::Error, + fmt::Debug, fs::read_to_string, io::{ErrorKind, Read, Write}, rc::Rc, }; +use strum::EnumIter; + +#[derive(Clone, Copy, EnumIter)] +pub enum PemType { + ServerKey, + ServerCertChain, + ClientKey, + ClientCertChain, + CACert, +} + +impl PemType { + fn get_filename(&self) -> &str { + match self { + PemType::ServerKey => "server-key.pem", + PemType::ServerCertChain => "server-cert.pem", + PemType::ClientKey => "client-key.pem", + PemType::ClientCertChain => "client-cert.pem", + PemType::CACert => "ca-cert.pem", + } + } +} + +#[derive(Clone, Copy, Default, EnumIter)] +pub enum SigType { + Rsa2048, + Rsa3072, + Rsa4096, + #[default] + Ecdsa384, +} + +impl SigType { + pub fn get_dir_name(&self) -> &str { + match self { + SigType::Rsa2048 => "rsa2048", + SigType::Rsa3072 => "rsa3072", + SigType::Rsa4096 => "rsa4096", + SigType::Ecdsa384 => "ecdsa384", + } + } +} + +impl Debug for SigType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.get_dir_name()) + } +} + +pub fn get_cert_path(pem_type: PemType, sig_type: SigType) -> String { + format!( + "certs/{}/{}", + sig_type.get_dir_name(), + pem_type.get_filename() + ) +} pub fn read_to_bytes(pem_type: PemType, sig_type: SigType) -> Vec { read_to_string(get_cert_path(pem_type, sig_type)) @@ -17,13 +73,13 @@ pub fn read_to_bytes(pem_type: PemType, sig_type: SigType) -> Vec { .into_bytes() } -#[derive(Clone, Copy, Debug, Eq, PartialEq)] +#[derive(Clone, Copy)] pub enum Mode { Client, Server, } -#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)] +#[derive(Clone, Copy, Default, EnumIter, Eq, PartialEq)] pub enum HandshakeType { #[default] ServerAuth, @@ -33,131 +89,266 @@ pub enum HandshakeType { // these parameters were the only ones readily usable for all three libaries: // s2n-tls, rustls, and openssl #[allow(non_camel_case_types)] -#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)] +#[derive(Clone, Copy, Debug, Default, EnumIter, Eq, PartialEq)] pub enum CipherSuite { #[default] AES_128_GCM_SHA256, AES_256_GCM_SHA384, } -#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)] -pub enum ECGroup { - SECP256R1, +#[derive(Clone, Copy, Default, EnumIter)] +pub enum KXGroup { + Secp256R1, #[default] X25519, } -#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)] -pub enum SigType { - Rsa2048, - Rsa3072, - Rsa4096, - #[default] - Ec384, -} - -impl SigType { - pub fn get_dir_name(&self) -> &str { - match self { - SigType::Rsa2048 => "rsa2048", - SigType::Rsa3072 => "rsa3072", - SigType::Rsa4096 => "rsa4096", - SigType::Ec384 => "ec384", - } +impl Debug for KXGroup { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{}", + match self { + Self::Secp256R1 => "secp256r1", + Self::X25519 => "x25519", + } + ) } } -#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)] +#[derive(Clone, Copy, Debug, Default)] pub struct CryptoConfig { pub cipher_suite: CipherSuite, - pub ec_group: ECGroup, + pub kx_group: KXGroup, pub sig_type: SigType, } impl CryptoConfig { - pub fn new(cipher_suite: CipherSuite, ec_group: ECGroup, sig_type: SigType) -> Self { + pub fn new(cipher_suite: CipherSuite, kx_group: KXGroup, sig_type: SigType) -> Self { Self { cipher_suite, - ec_group, + kx_group, sig_type, } } } -pub trait TlsBenchHarness: Sized { - /// Default harness - fn default() -> Result> { - Self::new(CryptoConfig::default(), HandshakeType::default(), ConnectedBuffer::default()) +pub trait TlsConnection: Sized { + /// Library-specific config struct + type Config; + + /// Name of the connection type + fn name() -> String; + + /// Default connection (client or server) + fn default(mode: Mode) -> Result> { + Self::new( + mode, + CryptoConfig::default(), + HandshakeType::default(), + ConnectedBuffer::default(), + ) } + /// Make a config with given parameters + fn make_config( + mode: Mode, + crypto_config: CryptoConfig, + handshake_type: HandshakeType, + ) -> Result>; + + /// Make connection from existing config and buffer + fn new_from_config( + config: &Self::Config, + connected_buffer: ConnectedBuffer, + ) -> Result>; + /// Initialize buffers, configs, and connections (pre-handshake) fn new( + mode: Mode, crypto_config: CryptoConfig, handshake_type: HandshakeType, buffer: ConnectedBuffer, - ) -> Result>; + ) -> Result> { + Self::new_from_config( + &Self::make_config(mode, crypto_config, handshake_type)?, + buffer, + ) + } - /// Run handshake on initialized connections - /// Returns error if handshake has already completed + /// Run one handshake step: receive msgs from other connection, process, and send new msgs fn handshake(&mut self) -> Result<(), Box>; - /// Checks if handshake is finished for both client and server fn handshake_completed(&self) -> bool; - /// Get negotiated cipher suite fn get_negotiated_cipher_suite(&self) -> CipherSuite; - /// Get whether or negotiated version is TLS1.3 fn negotiated_tls13(&self) -> bool; - /// Send application data from connection in harness pair - fn send(&mut self, sender: Mode, data: &[u8]) -> Result<(), Box>; + /// Send application data to ConnectedBuffer + fn send(&mut self, data: &[u8]) -> Result<(), Box>; + + /// Read application data from ConnectedBuffer + fn recv(&mut self, data: &mut [u8]) -> Result<(), Box>; + + /// Shrink buffers owned by the connection + fn shrink_connection_buffers(&mut self); + + /// Clear and shrink buffers used for IO with another connection + fn shrink_connected_buffer(&mut self); - /// Receive application data sent to connection in harness pair - fn recv(&mut self, receiver: Mode, data: &mut [u8]) -> Result<(), Box>; + /// Get reference to internal connected buffer + fn connected_buffer(&self) -> &ConnectedBuffer; +} + +pub struct TlsConnPair { + client: C, + server: S, +} + +impl Default for TlsConnPair { + fn default() -> Self { + Self::new(Default::default(), Default::default(), Default::default()).unwrap() + } +} + +impl TlsConnPair { + /// Wrap two TlsConnections into a TlsConnPair + pub fn wrap(client: C, server: S) -> Self { + assert!( + client.connected_buffer() == &server.connected_buffer().clone_inverse(), + "connected buffers don't match" + ); + Self { client, server } + } + + /// Take back ownership of individual connections in the TlsConnPair + pub fn split(self) -> (C, S) { + (self.client, self.server) + } + + /// Initialize buffers, configs, and connections (pre-handshake) + pub fn new( + crypto_config: CryptoConfig, + handshake_type: HandshakeType, + connected_buffer: ConnectedBuffer, + ) -> Result> { + Ok(Self { + client: C::new( + Mode::Client, + crypto_config, + handshake_type, + connected_buffer.clone_inverse(), + )?, + server: S::new( + Mode::Server, + crypto_config, + handshake_type, + connected_buffer, + )?, + }) + } + + /// Run handshake on connections + /// Two round trips are needed for the server to receive the Finished message + /// from the client and be ready to send data + pub fn handshake(&mut self) -> Result<(), Box> { + for _ in 0..2 { + self.client.handshake()?; + self.server.handshake()?; + } + Ok(()) + } + + /// Checks if handshake is finished for both client and server + pub fn handshake_completed(&self) -> bool { + self.client.handshake_completed() && self.server.handshake_completed() + } + + pub fn get_negotiated_cipher_suite(&self) -> CipherSuite { + assert!(self.handshake_completed()); + assert!( + self.client.get_negotiated_cipher_suite() == self.server.get_negotiated_cipher_suite() + ); + self.client.get_negotiated_cipher_suite() + } + + pub fn negotiated_tls13(&self) -> bool { + self.client.negotiated_tls13() && self.server.negotiated_tls13() + } - /// Send data from client to server and then from server to client - fn round_trip_transfer(&mut self, data: &mut [u8]) -> Result<(), Box> { + /// Send data from client to server, and then from server to client + pub fn round_trip_transfer(&mut self, data: &mut [u8]) -> Result<(), Box> { // send data from client to server - self.send(Mode::Client, data)?; - self.recv(Mode::Server, data)?; + self.client.send(data)?; + self.server.recv(data)?; // send data from server to client - self.send(Mode::Server, data)?; - self.recv(Mode::Client, data)?; + self.server.send(data)?; + self.client.recv(data)?; Ok(()) } + + /// Shrink buffers owned by the connections + pub fn shrink_connection_buffers(&mut self) { + self.client.shrink_connection_buffers(); + self.server.shrink_connection_buffers(); + } + + /// Clear and shrink buffers used for IO between the connections + pub fn shrink_connected_buffers(&mut self) { + self.client.shrink_connected_buffer(); + self.server.shrink_connected_buffer(); + } } /// Wrapper of two shared buffers to pass as stream /// This wrapper `read()`s into one buffer and `write()`s to another -#[derive(Clone)] +/// `Rc>>` allows sharing of references to the buffers for two connections +#[derive(Clone, Eq)] pub struct ConnectedBuffer { recv: Rc>>, send: Rc>>, } +impl PartialEq for ConnectedBuffer { + /// ConnectedBuffers are equal if and only if they point to the same VecDeques + fn eq(&self, other: &ConnectedBuffer) -> bool { + Rc::ptr_eq(&self.recv, &other.recv) && Rc::ptr_eq(&self.send, &other.send) + } +} + impl ConnectedBuffer { /// Make a new struct with new internal buffers pub fn new() -> Self { let recv = Rc::new(RefCell::new(VecDeque::new())); let send = Rc::new(RefCell::new(VecDeque::new())); - // prevent resizing of buffers, useful for memory bench + // prevent (potentially slow) resizing of buffers for small data transfers, + // like with handshake recv.borrow_mut().reserve(10000); send.borrow_mut().reserve(10000); Self { recv, send } } - /// Make a new struct that shares internal buffers but swapped, ex. - /// `write()` writes to the buffer that the inverse `read()`s from + + /// Makes a new ConnectedBuffer that shares internal buffers but swapped, + /// ex. `write()` writes to the buffer that the inverse `read()`s from pub fn clone_inverse(&self) -> Self { Self { - recv: Rc::clone(&self.send), - send: Rc::clone(&self.recv), + recv: self.send.clone(), + send: self.recv.clone(), } } + + /// Clears and shrinks buffers + pub fn shrink(&mut self) { + self.recv.borrow_mut().clear(); + self.recv.borrow_mut().shrink_to_fit(); + self.send.borrow_mut().clear(); + self.send.borrow_mut().shrink_to_fit(); + } } impl Read for ConnectedBuffer { @@ -188,61 +379,79 @@ impl Default for ConnectedBuffer { } #[cfg(test)] -macro_rules! test_tls_bench_harnesses { - ($($lib_name:ident: $harness_type:ty,)*) => { - $( - mod $lib_name { - use super::*; - use CipherSuite::*; - use ECGroup::*; - use HandshakeType::*; - use SigType::*; - - #[test] - fn test_handshake_config() { - for handshake_type in [ServerAuth, MutualAuth] { - for cipher_suite in [AES_128_GCM_SHA256, AES_256_GCM_SHA384] { - for ec_group in [SECP256R1, X25519] { - for sig_type in [Ec384, Rsa2048, Rsa3072, Rsa4096] { - let crypto_config = CryptoConfig::new(cipher_suite, ec_group, sig_type); - let mut harness = <$harness_type>::new(crypto_config, handshake_type, ConnectedBuffer::default()).unwrap(); - - assert!(!harness.handshake_completed()); - harness.handshake().unwrap(); - assert!(harness.handshake_completed()); - - assert!(harness.negotiated_tls13()); - assert_eq!(cipher_suite, harness.get_negotiated_cipher_suite()); - } - } - } - } +mod tests { + use super::*; + use crate::{OpenSslConnection, RustlsConnection, S2NConnection, TlsConnPair}; + use std::path::Path; + use strum::IntoEnumIterator; + + #[test] + fn test_cert_paths_valid() { + for pem_type in PemType::iter() { + for sig_type in SigType::iter() { + assert!( + Path::new(&get_cert_path(pem_type, sig_type)).exists(), + "cert not found" + ); } + } + } + + #[test] + fn test_all() { + test_type::(); + test_type::(); + test_type::(); + } - #[test] - fn test_transfer() { - // use a large buffer to test across TLS record boundaries - let mut buf = [0x56u8; 1000000]; - for cipher_suite in [AES_128_GCM_SHA256, AES_256_GCM_SHA384] { - let crypto_config = CryptoConfig::new(cipher_suite, ECGroup::default(), SigType::default()); - let mut harness = <$harness_type>::new(crypto_config, HandshakeType::default(), ConnectedBuffer::default()).unwrap(); - harness.handshake().unwrap(); - harness.round_trip_transfer(&mut buf).unwrap(); + fn test_type() { + eprintln!("{} client --- {} server", C::name(), S::name()); + eprintln!("testing handshake..."); + test_handshake_configs::(); + eprintln!("testing transfer..."); + test_transfer::(); + eprintln!(); + } + + fn test_handshake_configs() { + for handshake_type in HandshakeType::iter() { + for cipher_suite in CipherSuite::iter() { + for kx_group in KXGroup::iter() { + for sig_type in SigType::iter() { + let crypto_config = CryptoConfig::new(cipher_suite, kx_group, sig_type); + let mut conn_pair = TlsConnPair::::new( + crypto_config, + handshake_type, + ConnectedBuffer::default(), + ) + .unwrap(); + + assert!(!conn_pair.handshake_completed()); + conn_pair.handshake().unwrap(); + assert!(conn_pair.handshake_completed()); + + assert!(conn_pair.negotiated_tls13()); + assert_eq!(cipher_suite, conn_pair.get_negotiated_cipher_suite()); + } } } } - )* } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::{OpenSslHarness, RustlsHarness, S2NHarness, TlsBenchHarness}; - test_tls_bench_harnesses! { - s2n_tls: S2NHarness, - rustls: RustlsHarness, - openssl: OpenSslHarness, + fn test_transfer() { + // use a large buffer to test across TLS record boundaries + let mut buf = [0x56u8; 1000000]; + for cipher_suite in CipherSuite::iter() { + let crypto_config = + CryptoConfig::new(cipher_suite, KXGroup::default(), SigType::default()); + let mut conn_pair = TlsConnPair::::new( + crypto_config, + HandshakeType::default(), + ConnectedBuffer::default(), + ) + .unwrap(); + conn_pair.handshake().unwrap(); + conn_pair.round_trip_transfer(&mut buf).unwrap(); + } } } diff --git a/bindings/rust/bench/src/lib.rs b/bindings/rust/bench/src/lib.rs index e262c0fe6f3..4c9dad96754 100644 --- a/bindings/rust/bench/src/lib.rs +++ b/bindings/rust/bench/src/lib.rs @@ -7,63 +7,11 @@ pub mod rustls; pub mod s2n_tls; pub use crate::{ - harness::{CipherSuite, CryptoConfig, ECGroup, HandshakeType, Mode, SigType, TlsBenchHarness}, - openssl::OpenSslHarness, - rustls::RustlsHarness, - s2n_tls::S2NHarness, + harness::{ + get_cert_path, CipherSuite, CryptoConfig, HandshakeType, KXGroup, Mode, PemType, SigType, + TlsConnPair, TlsConnection, + }, + openssl::OpenSslConnection, + rustls::RustlsConnection, + s2n_tls::S2NConnection, }; - -#[derive(Clone, Copy)] -pub enum PemType { - ServerKey, - ServerCertChain, - ClientKey, - ClientCertChain, - CACert, -} - -impl PemType { - fn get_filename(&self) -> &str { - match self { - PemType::ServerKey => "server-key.pem", - PemType::ServerCertChain => "server-cert.pem", - PemType::ClientKey => "client-key.pem", - PemType::ClientCertChain => "client-cert.pem", - PemType::CACert => "ca-cert.pem", - } - } -} - -fn get_cert_path(pem_type: PemType, sig_type: SigType) -> String { - format!( - "certs/{}/{}", - sig_type.get_dir_name(), - pem_type.get_filename() - ) -} - -#[cfg(test)] -mod tests { - use super::*; - use std::path::Path; - use PemType::*; - use SigType::*; - - #[test] - fn cert_paths_valid() { - for pem_type in [ - ServerKey, - ServerCertChain, - ClientKey, - ClientCertChain, - CACert, - ] { - for sig_type in [Rsa2048, Rsa3072, Rsa4096, Ec384] { - assert!( - Path::new(&get_cert_path(pem_type, sig_type)).exists(), - "cert not found" - ); - } - } - } -} diff --git a/bindings/rust/bench/src/openssl.rs b/bindings/rust/bench/src/openssl.rs index baee03a1517..0c30664e380 100644 --- a/bindings/rust/bench/src/openssl.rs +++ b/bindings/rust/bench/src/openssl.rs @@ -4,126 +4,140 @@ use crate::{ get_cert_path, harness::{ - CipherSuite, ConnectedBuffer, CryptoConfig, ECGroup, HandshakeType, Mode, TlsBenchHarness, + CipherSuite, ConnectedBuffer, CryptoConfig, HandshakeType, KXGroup, Mode, TlsConnection, }, PemType::*, }; use openssl::ssl::{ - ErrorCode, Ssl, SslContext, SslContextBuilder, SslFiletype, SslMethod, SslStream, - SslVerifyMode, SslVersion, + ErrorCode, Ssl, SslContext, SslFiletype, SslMethod, SslStream, SslVerifyMode, SslVersion, }; use std::{ error::Error, io::{Read, Write}, }; -pub struct OpenSslHarness { - client_conn: SslStream, - server_conn: SslStream, +pub struct OpenSslConnection { + connected_buffer: ConnectedBuffer, + connection: SslStream, } -impl OpenSslHarness { - fn common_config( - builder: &mut SslContextBuilder, - cipher_suite: &str, - ec_key: &str, - ) -> Result<(), Box> { - builder.set_min_proto_version(Some(SslVersion::TLS1_3))?; - builder.set_ciphersuites(cipher_suite)?; - builder.set_groups_list(ec_key)?; - Ok(()) - } - /// Process handshake for one connection, treating blocking errors as `Ok` - fn handshake_conn(&mut self, mode: Mode) -> Result<(), Box> { - let result = match mode { - Mode::Client => self.client_conn.connect(), - Mode::Server => self.server_conn.accept(), - }; - match result { - Ok(_) => Ok(()), - Err(err) => { - if err.code() != ErrorCode::WANT_READ { - Err(err.into()) - } else { - Ok(()) - } - } - } +impl TlsConnection for OpenSslConnection { + type Config = SslContext; + + fn name() -> String { + let version_num = openssl::version::number() as u64; + let patch: u8 = (version_num >> 4) as u8; + let fix = (version_num >> 12) as u8; + let minor = (version_num >> 20) as u8; + let major = (version_num >> 28) as u8; + format!( + "openssl{}.{}.{}{}", + major, + minor, + fix, + (b'a' + patch - 1) as char + ) } -} -impl TlsBenchHarness for OpenSslHarness { - fn new( + fn make_config( + mode: Mode, crypto_config: CryptoConfig, handshake_type: HandshakeType, - buffer: ConnectedBuffer, - ) -> Result> { - let client_buf = buffer; - let server_buf = client_buf.clone_inverse(); - + ) -> Result> { let cipher_suite = match crypto_config.cipher_suite { CipherSuite::AES_128_GCM_SHA256 => "TLS_AES_128_GCM_SHA256", CipherSuite::AES_256_GCM_SHA384 => "TLS_AES_256_GCM_SHA384", }; - let ec_key = match crypto_config.ec_group { - ECGroup::SECP256R1 => "P-256", - ECGroup::X25519 => "X25519", + let ec_key = match crypto_config.kx_group { + KXGroup::Secp256R1 => "P-256", + KXGroup::X25519 => "X25519", }; - let mut client_builder = SslContext::builder(SslMethod::tls_client())?; - Self::common_config(&mut client_builder, cipher_suite, ec_key)?; - client_builder.set_ca_file(get_cert_path(CACert, crypto_config.sig_type))?; - - let mut server_builder = SslContext::builder(SslMethod::tls_server())?; - Self::common_config(&mut server_builder, cipher_suite, ec_key)?; - server_builder - .set_certificate_chain_file(get_cert_path(ServerCertChain, crypto_config.sig_type))?; - server_builder.set_private_key_file( - get_cert_path(ServerKey, crypto_config.sig_type), - SslFiletype::PEM, - )?; - - if handshake_type == HandshakeType::MutualAuth { - client_builder.set_certificate_chain_file(get_cert_path( - ClientCertChain, - crypto_config.sig_type, - ))?; - client_builder.set_private_key_file( - get_cert_path(ClientKey, crypto_config.sig_type), - SslFiletype::PEM, - )?; - server_builder.set_ca_file(get_cert_path(CACert, crypto_config.sig_type))?; - server_builder.set_verify(SslVerifyMode::FAIL_IF_NO_PEER_CERT | SslVerifyMode::PEER); - } + let ssl_method = match mode { + Mode::Client => SslMethod::tls_client(), + Mode::Server => SslMethod::tls_server(), + }; - let client_config = client_builder.build(); - let server_config = server_builder.build(); + let mut builder = SslContext::builder(ssl_method)?; + builder.set_min_proto_version(Some(SslVersion::TLS1_3))?; + builder.set_ciphersuites(cipher_suite)?; + builder.set_groups_list(ec_key)?; - let client_conn = SslStream::new(Ssl::new(&client_config)?, client_buf)?; - let server_conn = SslStream::new(Ssl::new(&server_config)?, server_buf)?; + match mode { + Mode::Client => { + builder.set_ca_file(get_cert_path(CACert, crypto_config.sig_type))?; + builder.set_verify(SslVerifyMode::FAIL_IF_NO_PEER_CERT | SslVerifyMode::PEER); + + if handshake_type == HandshakeType::MutualAuth { + builder.set_certificate_chain_file(get_cert_path( + ClientCertChain, + crypto_config.sig_type, + ))?; + builder.set_private_key_file( + get_cert_path(ClientKey, crypto_config.sig_type), + SslFiletype::PEM, + )?; + } + } + Mode::Server => { + builder.set_certificate_chain_file(get_cert_path( + ServerCertChain, + crypto_config.sig_type, + ))?; + builder.set_private_key_file( + get_cert_path(ServerKey, crypto_config.sig_type), + SslFiletype::PEM, + )?; + + if handshake_type == HandshakeType::MutualAuth { + builder.set_ca_file(get_cert_path(CACert, crypto_config.sig_type))?; + builder.set_verify(SslVerifyMode::FAIL_IF_NO_PEER_CERT | SslVerifyMode::PEER); + } + } + } + + Ok(builder.build()) + } + fn new_from_config( + config: &Self::Config, + connected_buffer: ConnectedBuffer, + ) -> Result> { + let connection = SslStream::new(Ssl::new(config)?, connected_buffer.clone())?; Ok(Self { - client_conn, - server_conn, + connected_buffer, + connection, }) } fn handshake(&mut self) -> Result<(), Box> { - for _ in 0..2 { - self.handshake_conn(Mode::Client)?; - self.handshake_conn(Mode::Server)?; + let result = if self.connection.ssl().is_server() { + self.connection.accept() + } else { + self.connection.connect() + }; + + // treat blocking (`ErrorCode::WANT_READ`) as `Ok`, expected during handshake + match result { + Ok(_) => Ok(()), + Err(err) => { + if err.code() != ErrorCode::WANT_READ { + Err(err.into()) + } else { + Ok(()) + } + } } - Ok(()) } fn handshake_completed(&self) -> bool { - self.client_conn.ssl().is_init_finished() && self.server_conn.ssl().is_init_finished() + self.connection.ssl().is_init_finished() } fn get_negotiated_cipher_suite(&self) -> CipherSuite { let cipher_suite = self - .client_conn + .connection .ssl() .current_cipher() .expect("Handshake not completed") @@ -136,39 +150,42 @@ impl TlsBenchHarness for OpenSslHarness { } fn negotiated_tls13(&self) -> bool { - self.client_conn + self.connection .ssl() .version2() // version() -> &str is deprecated, version2() returns an enum instead .expect("Handshake not completed") == SslVersion::TLS1_3 } - fn send(&mut self, sender: Mode, data: &[u8]) -> Result<(), Box> { - let send_conn = match sender { - Mode::Client => &mut self.client_conn, - Mode::Server => &mut self.server_conn, - }; - + fn send(&mut self, data: &[u8]) -> Result<(), Box> { let mut write_offset = 0; while write_offset < data.len() { - write_offset += send_conn.write(&data[write_offset..data.len()])?; - send_conn.flush()?; // make sure internal buffers don't fill up + write_offset += self.connection.write(&data[write_offset..data.len()])?; + self.connection.flush()?; // make sure internal buffers don't fill up } - Ok(()) } - fn recv(&mut self, receiver: Mode, data: &mut [u8]) -> Result<(), Box> { - let recv_conn = match receiver { - Mode::Client => &mut self.client_conn, - Mode::Server => &mut self.server_conn, - }; - + fn recv(&mut self, data: &mut [u8]) -> Result<(), Box> { + let data_len = data.len(); let mut read_offset = 0; while read_offset < data.len() { - read_offset += recv_conn.read(data)? + read_offset += self.connection.read(&mut data[read_offset..data_len])? } - Ok(()) } + + /// With OpenSSL's API, not possible after connection initialization: + /// In order to shrink buffers owned by the connection, config has to built + /// with `builder.set_mode(SslMode::RELEASE_BUFFERS);`, which tells the + /// connection to release buffers only when it's idle + fn shrink_connection_buffers(&mut self) {} + + fn shrink_connected_buffer(&mut self) { + self.connected_buffer.shrink(); + } + + fn connected_buffer(&self) -> &ConnectedBuffer { + &self.connected_buffer + } } diff --git a/bindings/rust/bench/src/rustls.rs b/bindings/rust/bench/src/rustls.rs index 5b567bc71e5..e16f2279dfe 100644 --- a/bindings/rust/bench/src/rustls.rs +++ b/bindings/rust/bench/src/rustls.rs @@ -3,8 +3,8 @@ use crate::{ harness::{ - read_to_bytes, CipherSuite, ConnectedBuffer, CryptoConfig, ECGroup, HandshakeType, Mode, - TlsBenchHarness, + read_to_bytes, CipherSuite, ConnectedBuffer, CryptoConfig, HandshakeType, KXGroup, Mode, + TlsConnection, }, PemType::{self, *}, SigType, @@ -14,9 +14,7 @@ use rustls::{ kx_group::{SECP256R1, X25519}, server::AllowAnyAuthenticatedClient, version::TLS13, - Certificate, ClientConfig, ClientConnection, Connection, - Connection::{Client, Server}, - PrivateKey, + Certificate, ClientConfig, ClientConnection, Connection, PrivateKey, ProtocolVersion::TLSv1_3, RootCertStore, ServerConfig, ServerConnection, ServerName, }; @@ -27,14 +25,12 @@ use std::{ sync::Arc, }; -pub struct RustlsHarness { - client_buf: ConnectedBuffer, - server_buf: ConnectedBuffer, - client_conn: Connection, - server_conn: Connection, +pub struct RustlsConnection { + connected_buffer: ConnectedBuffer, + connection: Connection, } -impl RustlsHarness { +impl RustlsConnection { fn get_root_cert_store(sig_type: SigType) -> Result> { let root_cert = Certificate(certs(&mut BufReader::new(&*read_to_bytes(CACert, sig_type)))?.remove(0)); @@ -60,7 +56,7 @@ impl RustlsHarness { )) } - /// Treat `WouldBlock` as `Ok` for when blocking is expected + /// Treat `WouldBlock` as an `Ok` value for when blocking is expected fn ignore_block(res: Result) -> Result { match res { Ok(t) => Ok(t), @@ -72,88 +68,108 @@ impl RustlsHarness { } } -impl TlsBenchHarness for RustlsHarness { - fn new( +/// Clients and servers have different config types in Rustls, so wrap them in an enum +pub enum RustlsConfig { + Client(Arc), + Server(Arc), +} + +impl TlsConnection for RustlsConnection { + type Config = RustlsConfig; + + fn name() -> String { + "rustls".to_string() + } + + fn make_config( + mode: Mode, crypto_config: CryptoConfig, handshake_type: HandshakeType, - buffer: ConnectedBuffer, - ) -> Result> { - let client_buf = buffer; - let server_buf = client_buf.clone_inverse(); - + ) -> Result> { let cipher_suite = match crypto_config.cipher_suite { CipherSuite::AES_128_GCM_SHA256 => TLS13_AES_128_GCM_SHA256, CipherSuite::AES_256_GCM_SHA384 => TLS13_AES_256_GCM_SHA384, }; - let kx_group = match crypto_config.ec_group { - ECGroup::SECP256R1 => &SECP256R1, - ECGroup::X25519 => &X25519, + let kx_group = match crypto_config.kx_group { + KXGroup::Secp256R1 => &SECP256R1, + KXGroup::X25519 => &X25519, }; - let client_builder = ClientConfig::builder() - .with_cipher_suites(&[cipher_suite]) - .with_kx_groups(&[kx_group]) - .with_protocol_versions(&[&TLS13])? - .with_root_certificates(Self::get_root_cert_store(crypto_config.sig_type)?); - - let server_builder = ServerConfig::builder() - .with_cipher_suites(&[cipher_suite]) - .with_kx_groups(&[kx_group]) - .with_protocol_versions(&[&TLS13])?; - - let (client_builder, server_builder) = match handshake_type { - HandshakeType::MutualAuth => ( - client_builder.with_client_auth_cert( - Self::get_cert_chain(ClientCertChain, crypto_config.sig_type)?, - Self::get_key(ClientKey, crypto_config.sig_type)?, - )?, - server_builder.with_client_cert_verifier(Arc::new( - AllowAnyAuthenticatedClient::new(Self::get_root_cert_store( - crypto_config.sig_type, - )?), - )), - ), - HandshakeType::ServerAuth => ( - client_builder.with_no_client_auth(), - server_builder.with_no_client_auth(), - ), - }; - - let client_config = Arc::new(client_builder); - let server_config = Arc::new(server_builder.with_single_cert( - Self::get_cert_chain(ServerCertChain, crypto_config.sig_type)?, - Self::get_key(ServerKey, crypto_config.sig_type)?, - )?); + match mode { + Mode::Client => { + let builder = ClientConfig::builder() + .with_cipher_suites(&[cipher_suite]) + .with_kx_groups(&[kx_group]) + .with_protocol_versions(&[&TLS13])? + .with_root_certificates(Self::get_root_cert_store(crypto_config.sig_type)?); + + let config = match handshake_type { + HandshakeType::ServerAuth => builder.with_no_client_auth(), + HandshakeType::MutualAuth => builder.with_client_auth_cert( + Self::get_cert_chain(ClientCertChain, crypto_config.sig_type)?, + Self::get_key(ClientKey, crypto_config.sig_type)?, + )?, + }; + + Ok(RustlsConfig::Client(Arc::new(config))) + } + Mode::Server => { + let builder = ServerConfig::builder() + .with_cipher_suites(&[cipher_suite]) + .with_kx_groups(&[kx_group]) + .with_protocol_versions(&[&TLS13])?; + + let builder = match handshake_type { + HandshakeType::ServerAuth => builder.with_no_client_auth(), + HandshakeType::MutualAuth => builder.with_client_cert_verifier(Arc::new( + AllowAnyAuthenticatedClient::new(Self::get_root_cert_store( + crypto_config.sig_type, + )?), + )), + }; + + let config = builder.with_single_cert( + Self::get_cert_chain(ServerCertChain, crypto_config.sig_type)?, + Self::get_key(ServerKey, crypto_config.sig_type)?, + )?; + + Ok(RustlsConfig::Server(Arc::new(config))) + } + } + } - let client_conn = Client(ClientConnection::new( - client_config, - ServerName::try_from("localhost")?, - )?); - let server_conn = Server(ServerConnection::new(server_config)?); + fn new_from_config( + config: &Self::Config, + connected_buffer: ConnectedBuffer, + ) -> Result> { + let connection = match config { + RustlsConfig::Client(config) => Connection::Client(ClientConnection::new( + config.clone(), + ServerName::try_from("localhost")?, + )?), + RustlsConfig::Server(config) => { + Connection::Server(ServerConnection::new(config.clone())?) + } + }; Ok(Self { - client_buf, - server_buf, - client_conn, - server_conn, + connected_buffer, + connection, }) } fn handshake(&mut self) -> Result<(), Box> { - for _ in 0..2 { - Self::ignore_block(self.client_conn.complete_io(&mut self.client_buf))?; - Self::ignore_block(self.server_conn.complete_io(&mut self.server_buf))?; - } + Self::ignore_block(self.connection.complete_io(&mut self.connected_buffer))?; Ok(()) } fn handshake_completed(&self) -> bool { - !self.client_conn.is_handshaking() && !self.server_conn.is_handshaking() + !self.connection.is_handshaking() } fn get_negotiated_cipher_suite(&self) -> CipherSuite { - match self.client_conn.negotiated_cipher_suite().unwrap().suite() { + match self.connection.negotiated_cipher_suite().unwrap().suite() { rustls::CipherSuite::TLS13_AES_128_GCM_SHA256 => CipherSuite::AES_128_GCM_SHA256, rustls::CipherSuite::TLS13_AES_256_GCM_SHA384 => CipherSuite::AES_256_GCM_SHA384, _ => panic!("Unknown cipher suite"), @@ -161,42 +177,48 @@ impl TlsBenchHarness for RustlsHarness { } fn negotiated_tls13(&self) -> bool { - self.client_conn + self.connection .protocol_version() .expect("Handshake not completed") == TLSv1_3 } - fn send(&mut self, sender: Mode, data: &[u8]) -> Result<(), Box> { - let (send_conn, send_buf) = match sender { - Mode::Client => (&mut self.client_conn, &mut self.client_buf), - Mode::Server => (&mut self.server_conn, &mut self.server_buf), - }; - + fn send(&mut self, data: &[u8]) -> Result<(), Box> { let mut write_offset = 0; while write_offset < data.len() { - write_offset += send_conn.writer().write(&data[write_offset..data.len()])?; - send_conn.writer().flush()?; - send_conn.complete_io(send_buf)?; + write_offset += self + .connection + .writer() + .write(&data[write_offset..data.len()])?; + self.connection.writer().flush()?; + self.connection.complete_io(&mut self.connected_buffer)?; } - Ok(()) } - fn recv(&mut self, receiver: Mode, data: &mut [u8]) -> Result<(), Box> { - let (recv_conn, recv_buf) = match receiver { - Mode::Client => (&mut self.client_conn, &mut self.client_buf), - Mode::Server => (&mut self.server_conn, &mut self.server_buf), - }; - + fn recv(&mut self, data: &mut [u8]) -> Result<(), Box> { let data_len = data.len(); let mut read_offset = 0; while read_offset < data.len() { - recv_conn.complete_io(recv_buf)?; - read_offset += - Self::ignore_block(recv_conn.reader().read(&mut data[read_offset..data_len]))?; + self.connection.complete_io(&mut self.connected_buffer)?; + read_offset += Self::ignore_block( + self.connection + .reader() + .read(&mut data[read_offset..data_len]), + )?; } - Ok(()) } + + fn shrink_connection_buffers(&mut self) { + self.connection.set_buffer_limit(Some(1)); + } + + fn shrink_connected_buffer(&mut self) { + self.connected_buffer.shrink(); + } + + fn connected_buffer(&self) -> &ConnectedBuffer { + &self.connected_buffer + } } diff --git a/bindings/rust/bench/src/s2n_tls.rs b/bindings/rust/bench/src/s2n_tls.rs index 8a2ba7e6b6a..caa30d92d8c 100644 --- a/bindings/rust/bench/src/s2n_tls.rs +++ b/bindings/rust/bench/src/s2n_tls.rs @@ -3,14 +3,14 @@ use crate::{ harness::{ - read_to_bytes, CipherSuite, ConnectedBuffer, CryptoConfig, ECGroup, HandshakeType, Mode, - TlsBenchHarness, + read_to_bytes, CipherSuite, ConnectedBuffer, CryptoConfig, HandshakeType, KXGroup, Mode, + TlsConnection, }, PemType::*, }; use s2n_tls::{ callbacks::VerifyHostNameCallback, - config::{Builder, Config}, + config::Builder, connection::Connection, enums::{Blinding, ClientAuthType, Version}, security::Policy, @@ -21,21 +21,8 @@ use std::{ io::{ErrorKind, Read, Write}, os::raw::c_int, pin::Pin, - task::Poll::Ready, }; -#[allow(dead_code)] -pub struct S2NHarness { - // UnsafeCell is needed b/c client and server share *mut to IO buffers - // Pin> is to ensure long-term *mut to IO buffers remain valid - client_buf: Pin>, - server_buf: Pin>, - client_conn: Connection, - server_conn: Connection, - client_handshake_completed: bool, - server_handshake_completed: bool, -} - /// Custom callback for verifying hostnames. Rustls requires checking hostnames, /// so this is to make a fair comparison struct HostNameHandler<'a> { @@ -47,7 +34,20 @@ impl VerifyHostNameCallback for HostNameHandler<'_> { } } -impl S2NHarness { +/// s2n-tls has mode-independent configs, so this struct wraps the config with the mode +pub struct S2NConfig { + mode: Mode, + config: s2n_tls::config::Config, +} + +pub struct S2NConnection { + // Pin> is to ensure long-term *mut to IO buffers remains valid + connected_buffer: Pin>, + connection: Connection, + handshake_completed: bool, +} + +impl S2NConnection { /// Unsafe callback for custom IO C API /// /// s2n-tls IO is usually used with file descriptors to a TCP socket, but we @@ -65,6 +65,7 @@ impl S2NHarness { context.flush().unwrap(); match context.read(data) { Err(err) => { + // s2n-tls requires the callback to set errno if blocking happens if let ErrorKind::WouldBlock = err.kind() { errno::set_errno(errno::Errno(libc::EWOULDBLOCK)); -1 @@ -75,16 +76,27 @@ impl S2NHarness { Ok(len) => len as _, } } +} + +impl TlsConnection for S2NConnection { + type Config = S2NConfig; + + fn name() -> String { + "s2n-tls".to_string() + } - fn create_common_config_builder( + fn make_config( + mode: Mode, crypto_config: CryptoConfig, handshake_type: HandshakeType, - ) -> Result> { - let security_policy = match (crypto_config.cipher_suite, crypto_config.ec_group) { - (CipherSuite::AES_128_GCM_SHA256, ECGroup::SECP256R1) => "20230317", - (CipherSuite::AES_256_GCM_SHA384, ECGroup::SECP256R1) => "20190802", - (CipherSuite::AES_128_GCM_SHA256, ECGroup::X25519) => "default_tls13", - (CipherSuite::AES_256_GCM_SHA384, ECGroup::X25519) => "20190801", + ) -> Result> { + // these security policies negotiate the given cipher suite and key + // exchange group as their top choice + let security_policy = match (crypto_config.cipher_suite, crypto_config.kx_group) { + (CipherSuite::AES_128_GCM_SHA256, KXGroup::Secp256R1) => "20230317", + (CipherSuite::AES_256_GCM_SHA384, KXGroup::Secp256R1) => "20190802", + (CipherSuite::AES_128_GCM_SHA256, KXGroup::X25519) => "default_tls13", + (CipherSuite::AES_256_GCM_SHA384, KXGroup::X25519) => "20190801", }; let mut builder = Builder::new(); @@ -96,130 +108,90 @@ impl S2NHarness { HandshakeType::MutualAuth => ClientAuthType::Required, })?; - Ok(builder) - } - - fn create_client_config( - crypto_config: CryptoConfig, - handshake_type: HandshakeType, - ) -> Result> { - let mut builder = Self::create_common_config_builder(crypto_config, handshake_type)?; - builder - .trust_pem(read_to_bytes(CACert, crypto_config.sig_type).as_slice())? - .set_verify_host_callback(HostNameHandler { - expected_server_name: "localhost", - })?; - - if handshake_type == HandshakeType::MutualAuth { - builder.load_pem( - read_to_bytes(ClientCertChain, crypto_config.sig_type).as_slice(), - read_to_bytes(ClientKey, crypto_config.sig_type).as_slice(), - )?; + match mode { + Mode::Client => { + builder + .trust_pem(read_to_bytes(CACert, crypto_config.sig_type).as_slice())? + .set_verify_host_callback(HostNameHandler { + expected_server_name: "localhost", + })?; + + if handshake_type == HandshakeType::MutualAuth { + builder.load_pem( + read_to_bytes(ClientCertChain, crypto_config.sig_type).as_slice(), + read_to_bytes(ClientKey, crypto_config.sig_type).as_slice(), + )?; + } + } + Mode::Server => { + builder.load_pem( + read_to_bytes(ServerCertChain, crypto_config.sig_type).as_slice(), + read_to_bytes(ServerKey, crypto_config.sig_type).as_slice(), + )?; + + if handshake_type == HandshakeType::MutualAuth { + builder + .trust_pem(read_to_bytes(CACert, crypto_config.sig_type).as_slice())? + .set_verify_host_callback(HostNameHandler { + expected_server_name: "localhost", + })?; + } + } } - Ok(builder.build()?) + Ok(S2NConfig { + mode, + config: builder.build()?, + }) } - fn create_server_config( - crypto_config: CryptoConfig, - handshake_type: HandshakeType, - ) -> Result> { - let mut builder = Self::create_common_config_builder(crypto_config, handshake_type)?; - builder.load_pem( - read_to_bytes(ServerCertChain, crypto_config.sig_type).as_slice(), - read_to_bytes(ServerKey, crypto_config.sig_type).as_slice(), - )?; - - if handshake_type == HandshakeType::MutualAuth { - builder - .trust_pem(read_to_bytes(CACert, crypto_config.sig_type).as_slice())? - .set_verify_host_callback(HostNameHandler { - expected_server_name: "localhost", - })?; - } + fn new_from_config( + config: &Self::Config, + connected_buffer: ConnectedBuffer, + ) -> Result> { + let mode = match config.mode { + Mode::Client => s2n_tls::enums::Mode::Client, + Mode::Server => s2n_tls::enums::Mode::Server, + }; - Ok(builder.build()?) - } + let mut connected_buffer = Box::pin(connected_buffer); - /// Set up connections with config and custom IO - fn init_conn( - conn: &mut Connection, - buffer: &mut Pin>, - config: Config, - ) -> Result<(), Box> { - conn.set_blinding(Blinding::SelfService)? - .set_config(config)? + let mut connection = Connection::new(mode); + connection + .set_blinding(Blinding::SelfService)? + .set_config(config.config.clone())? .set_send_callback(Some(Self::send_cb))? .set_receive_callback(Some(Self::recv_cb))?; unsafe { - conn.set_send_context(&mut **buffer as *mut ConnectedBuffer as *mut c_void)? - .set_receive_context(&mut **buffer as *mut ConnectedBuffer as *mut c_void)?; - } - - Ok(()) - } - - /// Handshake step for one connection - fn handshake_conn(&mut self, mode: Mode) -> Result<(), Box> { - let (conn, handshake_completed) = match mode { - Mode::Client => (&mut self.client_conn, &mut self.client_handshake_completed), - Mode::Server => (&mut self.server_conn, &mut self.server_handshake_completed), - }; - - if let Ready(res) = conn.poll_negotiate() { - res?; - *handshake_completed = true; - } else { - *handshake_completed = false; + connection + .set_send_context(&mut *connected_buffer as *mut ConnectedBuffer as *mut c_void)? + .set_receive_context( + &mut *connected_buffer as *mut ConnectedBuffer as *mut c_void, + )?; } - Ok(()) - } -} - -impl TlsBenchHarness for S2NHarness { - fn new( - crypto_config: CryptoConfig, - handshake_type: HandshakeType, - buffer: ConnectedBuffer, - ) -> Result> { - let mut client_buf = Box::pin(buffer); - let mut server_buf = Box::pin(client_buf.clone_inverse()); - - let client_config = Self::create_client_config(crypto_config, handshake_type)?; - let server_config = Self::create_server_config(crypto_config, handshake_type)?; - let mut client_conn = Connection::new_client(); - let mut server_conn = Connection::new_server(); - - Self::init_conn(&mut client_conn, &mut client_buf, client_config)?; - Self::init_conn(&mut server_conn, &mut server_buf, server_config)?; - - let harness = Self { - client_buf, - server_buf, - client_conn, - server_conn, - client_handshake_completed: false, - server_handshake_completed: false, - }; - - Ok(harness) + Ok(Self { + connected_buffer, + connection, + handshake_completed: false, + }) } fn handshake(&mut self) -> Result<(), Box> { - for _ in 0..2 { - self.handshake_conn(Mode::Client)?; - self.handshake_conn(Mode::Server)?; - } + self.handshake_completed = self + .connection + .poll_negotiate() + .map(|res| res.unwrap()) // unwrap `Err` if present + .is_ready(); Ok(()) } fn handshake_completed(&self) -> bool { - self.client_handshake_completed && self.server_handshake_completed + self.handshake_completed } fn get_negotiated_cipher_suite(&self) -> CipherSuite { - match self.client_conn.cipher_suite().unwrap() { + match self.connection.cipher_suite().unwrap() { "TLS_AES_128_GCM_SHA256" => CipherSuite::AES_128_GCM_SHA256, "TLS_AES_256_GCM_SHA384" => CipherSuite::AES_256_GCM_SHA384, _ => panic!("Unknown cipher suite"), @@ -227,28 +199,29 @@ impl TlsBenchHarness for S2NHarness { } fn negotiated_tls13(&self) -> bool { - self.client_conn.actual_protocol_version().unwrap() == Version::TLS13 + self.connection.actual_protocol_version().unwrap() == Version::TLS13 } - fn send(&mut self, sender: Mode, data: &[u8]) -> Result<(), Box> { - let send_conn = match sender { - Mode::Client => &mut self.client_conn, - Mode::Server => &mut self.server_conn, - }; - - assert!(send_conn.poll_send(data).is_ready()); - assert!(send_conn.poll_flush().is_ready()); + fn send(&mut self, data: &[u8]) -> Result<(), Box> { + assert!(self.connection.poll_send(data).is_ready()); + assert!(self.connection.poll_flush().is_ready()); + Ok(()) + } + fn recv(&mut self, data: &mut [u8]) -> Result<(), Box> { + assert!(self.connection.poll_recv(data).is_ready()); Ok(()) } - fn recv(&mut self, receiver: Mode, data: &mut [u8]) -> Result<(), Box> { - let recv_conn = match receiver { - Mode::Client => &mut self.client_conn, - Mode::Server => &mut self.server_conn, - }; + fn shrink_connection_buffers(&mut self) { + self.connection.release_buffers().unwrap(); + } - assert!(recv_conn.poll_recv(data).is_ready()); - Ok(()) + fn shrink_connected_buffer(&mut self) { + self.connected_buffer.shrink(); + } + + fn connected_buffer(&self) -> &ConnectedBuffer { + &self.connected_buffer } }