Skip to content

Commit

Permalink
Merge from main
Browse files Browse the repository at this point in the history
  • Loading branch information
akoshelev committed Jul 17, 2024
2 parents ed55e69 + c160823 commit c020f73
Show file tree
Hide file tree
Showing 14 changed files with 371 additions and 114 deletions.
24 changes: 24 additions & 0 deletions .github/workflows/check.yml
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,30 @@ jobs:

- name: Run compact gate tests
run: cargo test --no-default-features --features "cli web-app real-world-infra test-fixture compact-gate"
slow:
name: Slow tests
env:
EXEC_SLOW_TESTS: 1
RUSTFLAGS: -C target-cpu=native

runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v4
- uses: ./.github/actions/rm
- uses: dtolnay/rust-toolchain@stable
- uses: actions/cache@v4
with:
path: |
~/.cargo/bin/
~/.cargo/registry/index/
~/.cargo/registry/cache/
~/.cargo/git/db/
target/
key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.toml') }}

- name: End-to-end tests
run: cargo test --release --test "*" --no-default-features --features "cli web-app real-world-infra test-fixture compact-gate"

# sanitizers currently require nightly https://github.com/rust-lang/rust/issues/39699
sanitize:
Expand Down
4 changes: 2 additions & 2 deletions ipa-core/src/bin/test_mpc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ enum TestAction {
/// Execute end-to-end simple addition circuit that uses prime fields.
/// All helpers add their shares locally and set the resulting share to be the
/// sum. No communication is required to run the circuit.
Add,
AddInPrimeField,
}

#[tokio::main]
Expand All @@ -101,7 +101,7 @@ async fn main() -> Result<(), Box<dyn Error>> {
let (clients, _) = make_clients(args.network.as_deref(), scheme, args.wait).await;
match args.action {
TestAction::Multiply => multiply(&args, &clients).await,
TestAction::Add => add(&args, &clients).await,
TestAction::AddInPrimeField => add(&args, &clients).await,
};

Ok(())
Expand Down
4 changes: 3 additions & 1 deletion ipa-core/src/helpers/buffers/unordered_receiver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,9 @@ where
.enumerate()
.filter_map(|(i, waker)| waker.as_ref().map(|_| i))
.map(move |i| {
if i < start {
// We don't save a waker at `self.next`, so `start` is actually the last waker, and
// `start + 1` is the first.
if i <= start {
self.next + (self.wakers.len() - start + i)
} else {
self.next + (i - start)
Expand Down
2 changes: 1 addition & 1 deletion ipa-core/src/helpers/gateway/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ impl Gateway {
impl Default for GatewayConfig {
fn default() -> Self {
Self {
active: 1024.try_into().unwrap(),
active: 32768.try_into().unwrap(),
read_size: 2048.try_into().unwrap(),
// In-memory tests are fast, so progress check intervals can be lower.
// Real world scenarios currently over-report stalls because of inefficiencies inside
Expand Down
25 changes: 21 additions & 4 deletions ipa-core/src/helpers/transport/query/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::{
cmp::{max, min},
fmt::{Debug, Display, Formatter},
num::NonZeroU32,
num::{NonZeroU32, NonZeroUsize},
};

use serde::{Deserialize, Deserializer, Serialize};
Expand Down Expand Up @@ -140,9 +141,25 @@ impl RouteParams<RouteId, NoQueryId, NoStep> for &QueryConfig {
}

impl From<&QueryConfig> for GatewayConfig {
fn from(_value: &QueryConfig) -> Self {
// TODO: pick the correct value for active and test it
Self::default()
fn from(value: &QueryConfig) -> Self {
let mut config = Self::default();
// Minimum size for active work is 2 because:
// * `UnorderedReceiver` wants capacity to be greater than 1
// * 1 is better represented by not using seq_join and/or indeterminate total records
let active = max(
2,
min(
config.active.get(),
// It makes sense to start with active work set to input size, but some protocols
// may want to change that, if their fanout factor per input row is greater than 1.
// we don't have capabilities (see #ipa/1171) to allow that currently.
usize::try_from(value.size.0).expect("u32 fits into usize"),
),
);
// we set active to be at least 2, so unwrap is fine.
config.active = NonZeroUsize::new(active).unwrap();

config
}
}

Expand Down
23 changes: 5 additions & 18 deletions ipa-core/src/helpers/transport/stream/axum_body.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,8 @@ use axum::body::{Body, BodyDataStream};
use bytes::Bytes;
use futures::{Stream, StreamExt};
use pin_project::pin_project;
use tokio_stream::wrappers::ReceiverStream;

use crate::error::BoxError;
use crate::{error::BoxError, helpers::BytesStream};

/// This struct is a simple wrapper so that both in-memory-infra and real-world-infra have a
/// unified interface for streams consumed by transport layer.
Expand All @@ -25,6 +24,10 @@ impl WrappedAxumBodyStream {
pub fn empty() -> Self {
Self::new(Body::empty())
}

pub fn from_bytes_stream<S: BytesStream + 'static>(stream: S) -> Self {
Self::new(axum::body::Body::from_stream(stream))
}
}

impl Stream for WrappedAxumBodyStream {
Expand All @@ -36,22 +39,6 @@ impl Stream for WrappedAxumBodyStream {
}
}

// Note that it is possible (although unlikely) that `from_body` panics.
#[cfg(any(test, feature = "test-fixture"))]
impl<Buf: Into<bytes::Bytes>> From<Buf> for WrappedAxumBodyStream {
fn from(buf: Buf) -> Self {
Self::new(Body::from(buf.into()))
}
}

impl WrappedAxumBodyStream {
#[must_use]
pub fn from_receiver_stream(receiver: Box<ReceiverStream<Result<Bytes, BoxError>>>) -> Self {
Self::new(Body::from_stream(receiver))
}
}

#[cfg(feature = "real-world-infra")]
#[async_trait::async_trait]
impl<S> axum::extract::FromRequest<S> for WrappedAxumBodyStream
where
Expand Down
8 changes: 2 additions & 6 deletions ipa-core/src/helpers/transport/stream/box_body.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@ use crate::helpers::{transport::stream::BoxBytesStream, BytesStream};
pub struct WrappedBoxBodyStream(BoxBytesStream);

impl WrappedBoxBodyStream {
/// Wrap an axum body stream, returning an instance of `crate::helpers::BodyStream`.
#[cfg(all(feature = "in-memory-infra", feature = "web-app"))]
#[must_use]
pub fn new(bytes: bytes::Bytes) -> Self {
let stream = futures::stream::once(futures::future::ready(Ok(bytes)));
Expand All @@ -29,7 +27,7 @@ impl WrappedBoxBodyStream {

#[must_use]
pub fn empty() -> Self {
WrappedBoxBodyStream(Box::pin(futures::stream::empty()))
Self(Box::pin(futures::stream::empty()))
}
}

Expand All @@ -45,9 +43,7 @@ impl Stream for WrappedBoxBodyStream {
#[cfg(any(test, feature = "test-fixture"))]
impl<Buf: Into<bytes::Bytes>> From<Buf> for WrappedBoxBodyStream {
fn from(buf: Buf) -> Self {
Self(Box::pin(futures::stream::once(futures::future::ready(Ok(
buf.into(),
)))))
Self::new(buf.into())
}
}

Expand Down
156 changes: 151 additions & 5 deletions ipa-core/src/helpers/transport/stream/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,21 @@ mod box_body;
mod collection;
mod input;

use std::pin::Pin;
use std::{
pin::Pin,
task::{Context, Poll},
};

#[cfg(feature = "web-app")]
pub use axum_body::WrappedAxumBodyStream;
pub use box_body::WrappedBoxBodyStream;
use bytes::Bytes;
pub use collection::{StreamCollection, StreamKey};
use futures::Stream;
use futures::{stream::iter, Stream};
use futures_util::StreamExt;
pub use input::{LengthDelimitedStream, RecordsStream, SingleRecordStream};

use crate::error::BoxError;
use crate::{const_assert, error::BoxError};

pub trait BytesStream: Stream<Item = Result<Bytes, BoxError>> + Send {
/// Collects the entire stream into a vec; only intended for use in tests
Expand Down Expand Up @@ -42,6 +46,148 @@ pub type BoxBytesStream = Pin<Box<dyn BytesStream>>;
// * Avoiding an extra level of boxing in the production configuration using axum, since
// the axum body stream type is already a `Pin<Box<dyn HttpBody>>`.
#[cfg(feature = "in-memory-infra")]
pub type BodyStream = WrappedBoxBodyStream;
type BodyStreamInner = WrappedBoxBodyStream;
#[cfg(feature = "real-world-infra")]
pub type BodyStream = WrappedAxumBodyStream;
type BodyStreamInner = WrappedAxumBodyStream;

/// Wrapper around [`BodyStreamInner`] that enforces checks relevant to both in-memory and
/// real-world implementations.
pub struct BodyStream {
inner: BodyStreamInner,
}

impl BodyStream {
/// Wrap a [`Bytes`] object, returning an instance of `crate::helpers::BodyStream`.
/// If the given byte chunk exceeds [`super::MAX_HTTP_CHUNK_SIZE`],
/// it will be split into multiple parts, each not exceeding that size.
/// See #ipa/1141
pub fn new(bytes: Bytes) -> Self {
let stream = iter(bytes.split().into_iter().map(Ok::<_, BoxError>));
Self::from_bytes_stream(stream)
}

#[must_use]
pub fn empty() -> Self {
Self {
inner: BodyStreamInner::empty(),
}
}

pub fn from_bytes_stream(stream: impl BytesStream + 'static) -> Self {
Self {
inner: BodyStreamInner::from_bytes_stream(stream),
}
}
}

impl Stream for BodyStream {
type Item = Result<Bytes, BoxError>;

fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let next = self.inner.poll_next_unpin(cx);
if let Poll::Ready(Some(Ok(v))) = &next {
debug_assert!(
v.len() <= MAX_HTTP_CHUNK_SIZE_BYTES,
"Chunk size {} is greater than maximum allowed {MAX_HTTP_CHUNK_SIZE_BYTES} bytes",
v.len()
);
};

next
}
}

impl From<Vec<u8>> for BodyStream {
fn from(value: Vec<u8>) -> Self {
Self::new(Bytes::from(value))
}
}

#[cfg(feature = "web-app")]
#[async_trait::async_trait]
impl<S> axum::extract::FromRequest<S> for BodyStream
where
S: Send + Sync,
{
type Rejection = crate::net::Error;

async fn from_request(req: axum::extract::Request, state: &S) -> Result<Self, Self::Rejection> {
Ok(Self {
inner: BodyStreamInner::from_request(req, state).await?,
})
}
}

/// The size is chosen somewhat arbitrary - feel free to change it, but don't go above 2Gb as
/// that will cause Hyper's HTTP2 to fail.
const MAX_HTTP_CHUNK_SIZE_BYTES: usize = 1024 * 1024; // 1MB
const_assert!(MAX_HTTP_CHUNK_SIZE_BYTES > 0 && MAX_HTTP_CHUNK_SIZE_BYTES < (1 << 31) - 1);

/// Trait for objects that can be split into multiple parts.
///
/// This trait is used to split the body of an HTTP request into multiple parts
/// when the request body is too large to fit in memory. This can happen if the
/// request body is being streamed from a file or other large source.
trait Split {
type Dest;

fn split(self) -> Self::Dest;
}

impl Split for Bytes {
type Dest = Vec<Self>;

fn split(self) -> Self::Dest {
tracing::trace!(
"Will split '{sz}' bytes buffer into {chunks} chunks of size {MAX_HTTP_CHUNK_SIZE_BYTES}",
sz = self.len(),
chunks = self.len() / MAX_HTTP_CHUNK_SIZE_BYTES,
);

let mut segments = Vec::with_capacity(self.len() / MAX_HTTP_CHUNK_SIZE_BYTES);
let mut segment = self;
while segment.len() > MAX_HTTP_CHUNK_SIZE_BYTES {
segments.push(segment.split_to(MAX_HTTP_CHUNK_SIZE_BYTES));
}
segments.push(segment);

segments
}
}

#[cfg(all(test, unit_test))]
mod tests {
use bytes::Bytes;
use futures::{future, stream, stream::TryStreamExt};

use crate::{
helpers::{transport::stream::MAX_HTTP_CHUNK_SIZE_BYTES, BodyStream},
test_executor::run,
};

#[test]
fn chunks_the_input() {
run(|| async {
let data = vec![0_u8; 2 * MAX_HTTP_CHUNK_SIZE_BYTES + 1];
let stream = BodyStream::new(data.into());
let chunks = stream.try_collect::<Vec<_>>().await.unwrap();

assert_eq!(3, chunks.len());
assert_eq!(MAX_HTTP_CHUNK_SIZE_BYTES, chunks[0].len());
assert_eq!(MAX_HTTP_CHUNK_SIZE_BYTES, chunks[1].len());
assert_eq!(1, chunks[2].len());
});
}

#[test]
#[should_panic(expected = "Chunk size 1048577 is greater than maximum allowed 1048576 bytes")]
fn rejects_large_chunks() {
run(|| async {
let data = vec![0_u8; MAX_HTTP_CHUNK_SIZE_BYTES + 1];
let stream =
BodyStream::from_bytes_stream(stream::once(future::ready(Ok(Bytes::from(data)))));

stream.try_collect::<Vec<_>>().await.unwrap()
});
}
}
2 changes: 1 addition & 1 deletion ipa-core/src/net/transport.rs
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ mod tests {

let TestServer { transport, .. } = TestServer::default().await;

let body = BodyStream::from_receiver_stream(Box::new(ReceiverStream::new(rx)));
let body = BodyStream::from_bytes_stream(ReceiverStream::new(rx));

// Register the stream with the transport (normally called by step data HTTP API handler)
Arc::clone(&transport).receive_stream(QueryId, STEP.clone(), HelperIdentity::TWO, body);
Expand Down
2 changes: 1 addition & 1 deletion ipa-core/src/protocol/prss/crypto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ impl Generator {
/// Generate the value at the given index.
/// This uses the MMO^{\pi} function described in <https://eprint.iacr.org/2019/074>.
#[must_use]
pub(crate) fn generate<I: Into<PrssIndex128>>(&self, index: I) -> u128 {
pub(super) fn generate<I: Into<PrssIndex128>>(&self, index: I) -> u128 {
let index = index.into();
#[cfg(debug_assertions)]
self.used.use_index(index).unwrap();
Expand Down
Loading

0 comments on commit c020f73

Please sign in to comment.