update Blob trait to use SegmentedBytes, update all implementors
more fixes for non-S3

fixes for examples

update clients
ParkMyCar committed Apr 17, 2023
1 parent a4d6ba3 commit 6920a93
Showing 15 changed files with 90 additions and 48 deletions.
13 changes: 5 additions & 8 deletions src/ore/src/bytes.rs
@@ -27,11 +27,11 @@ use internal::SegmentedReader;
/// memory fragmentation) if you try to allocate a single very large chunk. Depending on the
/// application, you probably don't need a contiguous chunk of memory, just a way to store and
/// iterate over a collection of bytes.
-///
-/// Note: [`SegmentedBytes`] is generic over a `const N: usize`. Internally we use a
+///
+/// Note: [`SegmentedBytes`] is generic over a `const N: usize`. Internally we use a
/// [`smallvec::SmallVec`] to store our [`Bytes`] segments, and `N` is how many `Bytes` we'll
/// store inline before spilling to the heap. We default `N = 1`, so in the case of a single
-/// `Bytes` segment, we avoid one layer of indirection.
+/// `Bytes` segment, we avoid one layer of indirection.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct SegmentedBytes<const N: usize = 1> {
/// Collection of non-contiguous segments.
@@ -305,12 +305,9 @@ mod internal {

#[cfg(test)]
mod tests {
-use bytes::{
-    Buf,
-    Bytes,
-};
-use std::io::{Read, Seek, SeekFrom};
+use bytes::{Buf, Bytes};
use proptest::prelude::*;
+use std::io::{Read, Seek, SeekFrom};

use super::SegmentedBytes;

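The doc comment above is the heart of this change: `SegmentedBytes` stores a SmallVec of `Bytes` segments instead of one contiguous allocation. A minimal usage sketch, assuming the `From<Vec<u8>>` impl used elsewhere in this commit, a `bytes::Buf` impl (which the prost call sites below rely on), and a `reader()` adapter implementing `Read + Seek` (suggested by the test imports):

```rust
use std::io::{Read, Seek, SeekFrom};

use bytes::Buf;
use mz_ore::bytes::SegmentedBytes; // needs mz-ore's `bytes_` feature, enabled below

fn main() -> std::io::Result<()> {
    // Wrap an owned Vec<u8> as a single segment (the N = 1 default keeps it inline).
    let bytes = SegmentedBytes::from(b"hello world".to_vec());

    // As a `bytes::Buf` it can feed prost's `Message::decode` directly.
    assert_eq!(bytes.remaining(), 11);

    // As an `io::Read + Seek` (assumed per the test imports above), it can
    // feed decoders written against std IO, like the parquet path below.
    let mut reader = bytes.reader();
    reader.seek(SeekFrom::Start(6))?;
    let mut out = String::new();
    reader.read_to_string(&mut out)?;
    assert_eq!(out, "world");
    Ok(())
}
```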
2 changes: 1 addition & 1 deletion src/persist-client/Cargo.toml
@@ -38,7 +38,7 @@ differential-dataflow = { git = "https://github.com/TimelyDataflow/differential-
futures = "0.3.25"
futures-util = "0.3"
mz-build-info = { path = "../build-info" }
-mz-ore = { path = "../ore", features = ["tracing_"] }
+mz-ore = { path = "../ore", features = ["bytes_", "tracing_"] }
mz-persist = { path = "../persist" }
mz-persist-types = { path = "../persist-types" }
mz-proto = { path = "../proto" }
11 changes: 6 additions & 5 deletions src/persist-client/examples/maelstrom/services.rs
@@ -15,6 +15,7 @@ use std::sync::Arc;
use anyhow::anyhow;
use async_trait::async_trait;
use bytes::Bytes;
+use mz_ore::bytes::SegmentedBytes;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use tokio::sync::Mutex;
@@ -177,7 +178,7 @@ impl MaelstromBlob {

#[async_trait]
impl Blob for MaelstromBlob {
-    async fn get(&self, key: &str) -> Result<Option<Vec<u8>>, ExternalError> {
+    async fn get(&self, key: &str) -> Result<Option<SegmentedBytes>, ExternalError> {
let value = match self
.handle
.lin_kv_read(Value::from(format!("blob/{}", key)))
@@ -190,9 +191,9 @@
let value = value
.as_str()
.ok_or_else(|| anyhow!("invalid blob at {}: {:?}", key, value))?;
-let value = serde_json::from_str(value)
+let value: Vec<u8> = serde_json::from_str(value)
.map_err(|err| anyhow!("invalid blob at {}: {}", key, err))?;
-Ok(Some(value))
+Ok(Some(SegmentedBytes::from(value)))
}

async fn list_keys_and_metadata(
@@ -233,7 +234,7 @@
#[derive(Debug)]
pub struct CachingBlob {
blob: Arc<dyn Blob + Send + Sync>,
-    cache: Mutex<BTreeMap<String, Vec<u8>>>,
+    cache: Mutex<BTreeMap<String, SegmentedBytes>>,
}

impl CachingBlob {
@@ -247,7 +248,7 @@

#[async_trait]
impl Blob for CachingBlob {
-    async fn get(&self, key: &str) -> Result<Option<Vec<u8>>, ExternalError> {
+    async fn get(&self, key: &str) -> Result<Option<SegmentedBytes>, ExternalError> {
// Fetch the cached value if there is one.
let cache = self.cache.lock().await;
if let Some(value) = cache.get(key) {
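Storing `SegmentedBytes` instead of `Vec<u8>` in `CachingBlob` means a cache hit no longer deep-copies the payload: each segment is a `Bytes`, and cloning `Bytes` only bumps a reference count. A self-contained sketch of that property using the bytes crate directly (the `Cache` type here is illustrative, not persist's):

```rust
use std::collections::BTreeMap;

use bytes::Bytes;

struct Cache {
    entries: BTreeMap<String, Bytes>,
}

impl Cache {
    /// A hit clones a handle, not the payload: O(1), refcount bump only.
    fn get(&self, key: &str) -> Option<Bytes> {
        self.entries.get(key).cloned()
    }
}

fn main() {
    let mut entries = BTreeMap::new();
    entries.insert("blob/0".to_string(), Bytes::from(vec![0u8; 1 << 20]));
    let cache = Cache { entries };

    let a = cache.get("blob/0").unwrap();
    let b = cache.get("blob/0").unwrap();
    // Both handles point at the same 1 MiB allocation.
    assert_eq!(a.as_ptr(), b.as_ptr());
}
```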
3 changes: 2 additions & 1 deletion src/persist-client/src/cli/admin.rs
@@ -17,6 +17,7 @@ use std::time::Instant;
use anyhow::anyhow;
use async_trait::async_trait;
use bytes::Bytes;
+use mz_ore::bytes::SegmentedBytes;
use mz_ore::metrics::MetricsRegistry;
use mz_ore::now::SYSTEM_TIME;
use mz_persist::cfg::{BlobConfig, ConsensusConfig};
@@ -287,7 +288,7 @@ struct ReadOnly<T>(T);

#[async_trait]
impl Blob for ReadOnly<Arc<dyn Blob + Sync + Send>> {
-    async fn get(&self, key: &str) -> Result<Option<Vec<u8>>, ExternalError> {
+    async fn get(&self, key: &str) -> Result<Option<SegmentedBytes>, ExternalError> {
self.0.get(key).await
}

7 changes: 3 additions & 4 deletions src/persist-client/src/cli/inspect.rs
@@ -250,7 +250,7 @@ pub async fn fetch_state_rollup(
.get(&rollup_key.complete(&shard_id))
.await?
.expect("fetching the specified state rollup");
-let proto = ProtoStateRollup::decode(rollup_buf.as_slice()).expect("invalid encoded state");
+let proto = ProtoStateRollup::decode(rollup_buf).expect("invalid encoded state");
Ok(proto)
}

@@ -283,8 +283,7 @@ pub async fn fetch_state_rollups(args: &StateArgs) -> Result<impl serde::Seriali
.await
.unwrap();
if let Some(rollup_buf) = rollup_buf {
-let proto =
-    ProtoStateRollup::decode(rollup_buf.as_slice()).expect("invalid encoded state");
+let proto = ProtoStateRollup::decode(rollup_buf).expect("invalid encoded state");
rollup_states.insert(key.to_string(), proto);
}
}
@@ -500,7 +499,7 @@ pub async fn shard_stats(blob_uri: &str) -> anyhow::Result<()> {
};

let state: State<u64> =
-    UntypedState::decode(&cfg.build_version, &rollup).check_ts_codec(&shard)?;
+    UntypedState::decode(&cfg.build_version, rollup).check_ts_codec(&shard)?;

let leased_readers = state.collections.leased_readers.len();
let critical_readers = state.collections.critical_readers.len();
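Dropping the `.as_slice()` calls works because prost's `Message::decode` is generic over `bytes::Buf`, not tied to `&[u8]`, so a `SegmentedBytes` rollup decodes without being flattened first. A sketch of that genericity with a hypothetical `Envelope` message (ProtoStateRollup's real schema lives in persist's .proto files):

```rust
use bytes::{Buf, Bytes};
use prost::Message;

// Hypothetical stand-in for a generated message like ProtoStateRollup.
#[derive(Clone, PartialEq, Message)]
struct Envelope {
    #[prost(uint64, tag = "1")]
    seqno: u64,
    #[prost(bytes = "vec", tag = "2")]
    payload: Vec<u8>,
}

// Generic over any Buf: a slice, a Bytes, or a SegmentedBytes all work.
fn decode_envelope<B: Buf>(buf: B) -> Result<Envelope, prost::DecodeError> {
    Envelope::decode(buf)
}

fn main() {
    let msg = Envelope { seqno: 42, payload: vec![1, 2, 3] };
    let encoded = msg.encode_to_vec();

    // A contiguous slice still works (&[u8] implements Buf)...
    let a = decode_envelope(&encoded[..]).unwrap();
    // ...and so does a non-contiguous chain of two segments.
    let (left, right) = encoded.split_at(2);
    let chained = Bytes::copy_from_slice(left).chain(Bytes::copy_from_slice(right));
    let b = decode_envelope(chained).unwrap();
    assert_eq!(a, b);
}
```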
13 changes: 8 additions & 5 deletions src/persist-client/src/internal/encoding.rs
@@ -11,7 +11,7 @@ use std::collections::BTreeMap;
use std::marker::PhantomData;
use std::sync::Arc;

-use bytes::Bytes;
+use bytes::{Buf, Bytes};
use differential_dataflow::lattice::Lattice;
use differential_dataflow::trace::Description;
use mz_persist::location::{SeqNo, VersionedData};
@@ -622,7 +622,7 @@ impl<T: Timestamp + Lattice + Codec64> UntypedState<T> {
Ok(self.state)
}

-    pub fn decode(build_version: &Version, buf: &[u8]) -> Self {
+    pub fn decode(build_version: &Version, buf: impl Buf) -> Self {
let proto = ProtoStateRollup::decode(buf)
// We received a State that we couldn't decode. This could happen if
// persist messes up backward/forward compatibility, if the durable
@@ -1021,6 +1021,7 @@ impl<T: Timestamp + Codec64> From<SerdeWriterEnrichedHollowBatch> for WriterEnri
mod tests {
use std::sync::atomic::Ordering;

+use bytes::Bytes;
use mz_build_info::DUMMY_BUILD_INFO;
use mz_persist::location::SeqNo;

@@ -1043,16 +1044,17 @@
let state = TypedState::<(), (), u64, i64>::new(v2.clone(), shard_id, "".to_owned(), 0);
let mut buf = Vec::new();
state.encode(&mut buf);
+let bytes = Bytes::from(buf);

// We can read it back using persist code v2 and v3.
assert_eq!(
-    UntypedState::<u64>::decode(&v2, &buf)
+    UntypedState::<u64>::decode(&v2, bytes.clone())
.check_codecs(&shard_id)
.as_ref(),
Ok(&state)
);
assert_eq!(
-    UntypedState::<u64>::decode(&v3, &buf)
+    UntypedState::<u64>::decode(&v3, bytes.clone())
.check_codecs(&shard_id)
.as_ref(),
Ok(&state)
@@ -1062,7 +1064,8 @@
// losing or misinterpreting something written out by a future version
// of code.
mz_ore::process::PANIC_ON_HALT.store(true, Ordering::SeqCst);
-let v1_res = mz_ore::panic::catch_unwind(|| UntypedState::<u64>::decode(&v1, &buf));
+let v1_res =
+    mz_ore::panic::catch_unwind(|| UntypedState::<u64>::decode(&v1, bytes.clone()));
assert!(v1_res.is_err());
}

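Since `decode` now takes `impl Buf` by value and consumes it, the test encodes once into a `Bytes` and clones it per call; `Bytes::clone` bumps a refcount rather than copying the payload, so decoding three times stays cheap. A minimal sketch of the pattern with a stand-in consuming decoder:

```rust
use bytes::{Buf, Bytes};

// A consuming decoder, mirroring `UntypedState::decode(&version, impl Buf)`.
fn checksum(mut buf: impl Buf) -> u64 {
    let mut sum = 0u64;
    while buf.has_remaining() {
        sum = sum.wrapping_add(u64::from(buf.get_u8()));
    }
    sum
}

fn main() {
    let encoded: Vec<u8> = (0u8..=255).collect();
    let bytes = Bytes::from(encoded);

    // Each clone shares the same allocation, so the consuming
    // decoder can run any number of times without re-encoding.
    assert_eq!(checksum(bytes.clone()), checksum(bytes.clone()));
}
```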
3 changes: 2 additions & 1 deletion src/persist-client/src/internal/metrics.rs
@@ -15,6 +15,7 @@ use std::time::{Duration, Instant};

use async_trait::async_trait;
use bytes::Bytes;
+use mz_ore::bytes::SegmentedBytes;
use mz_ore::cast::{CastFrom, CastLossy};
use mz_ore::metric;
use mz_ore::metrics::{
@@ -1542,7 +1543,7 @@ impl MetricsBlob {

#[async_trait]
impl Blob for MetricsBlob {
-    async fn get(&self, key: &str) -> Result<Option<Vec<u8>>, ExternalError> {
+    async fn get(&self, key: &str) -> Result<Option<SegmentedBytes>, ExternalError> {
let res = self
.metrics
.blob
2 changes: 1 addition & 1 deletion src/persist-client/src/internal/state_versions.rs
@@ -747,7 +747,7 @@ impl StateVersions {
self.metrics
.codecs
.state
-    .decode(|| UntypedState::decode(&self.cfg.build_version, &buf))
+    .decode(|| UntypedState::decode(&self.cfg.build_version, buf))
})
}

2 changes: 1 addition & 1 deletion src/persist/Cargo.toml
@@ -35,7 +35,7 @@ futures-util = "0.3.25"
once_cell = "1.16.0"
md-5 = "0.10.5"
mz-aws-s3-util = { path = "../aws-s3-util" }
-mz-ore = { path = "../ore", default-features = false, features = ["metrics", "async"] }
+mz-ore = { path = "../ore", default-features = false, features = ["metrics", "async", "bytes_"] }
mz-persist-types = { path = "../persist-types" }
mz-proto = { path = "../proto" }
openssl = { version = "0.10.48", features = ["vendored"] }
5 changes: 3 additions & 2 deletions src/persist/src/file.rs
@@ -16,6 +16,7 @@ use anyhow::anyhow;
use async_trait::async_trait;
use bytes::Bytes;
use fail::fail_point;
+use mz_ore::bytes::SegmentedBytes;
use mz_ore::cast::CastFrom;
use tokio::fs::{self, File, OpenOptions};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
@@ -75,7 +76,7 @@ impl FileBlob {

#[async_trait]
impl Blob for FileBlob {
-    async fn get(&self, key: &str) -> Result<Option<Vec<u8>>, ExternalError> {
+    async fn get(&self, key: &str) -> Result<Option<SegmentedBytes>, ExternalError> {
let file_path = self.blob_path(&FileBlob::replace_forward_slashes(key));
let mut file = match File::open(file_path).await {
Ok(file) => file,
@@ -84,7 +85,7 @@
};
let mut buf = Vec::new();
file.read_to_end(&mut buf).await?;
-Ok(Some(buf))
+Ok(Some(SegmentedBytes::from(buf)))
}

async fn list_keys_and_metadata(
6 changes: 3 additions & 3 deletions src/persist/src/indexed/encoding.rs
@@ -14,11 +14,11 @@
// structs.

use std::fmt::{self, Debug};
-use std::io::Cursor;
use std::marker::PhantomData;

use bytes::BufMut;
use differential_dataflow::trace::Description;
+use mz_ore::bytes::SegmentedBytes;
use mz_ore::cast::CastFrom;
use mz_persist_types::Codec64;
use prost::Message;
@@ -223,8 +223,8 @@ impl<T: Timestamp + Codec64> BlobTraceBatchPart<T> {
}

/// Decodes a BlobTraceBatchPart from the Parquet format.
-pub fn decode<'a>(buf: &'a [u8]) -> Result<Self, Error> {
-    decode_trace_parquet(&mut Cursor::new(&buf))
+pub fn decode(buf: &SegmentedBytes) -> Result<Self, Error> {
+    decode_trace_parquet(&mut buf.clone().reader())
}
}

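The `buf.clone().reader()` here replaces `Cursor::new(&buf)`: the clone is cheap (shared `Bytes` segments), and the reader adapter lets the parquet decoder keep working against `io::Read` (plus `Seek`, judging by the bytes.rs test imports) without a contiguous copy. The bytes crate offers the same Buf-to-Read bridge generically; a sketch of the idea:

```rust
use std::io::Read;

use bytes::{Buf, Bytes};

fn main() -> std::io::Result<()> {
    // Two segments chained together behave as one logical buffer.
    let buf = Bytes::from_static(b"segmented ").chain(Bytes::from_static(b"bytes"));

    // `Buf::reader` adapts any Buf into an `io::Read`, so decoders written
    // against std IO work without flattening the segments first.
    let mut reader = buf.reader();
    let mut out = String::new();
    reader.read_to_string(&mut out)?;
    assert_eq!(out, "segmented bytes");
    Ok(())
}
```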
3 changes: 2 additions & 1 deletion src/persist/src/intercept.rs
@@ -14,6 +14,7 @@ use std::sync::{Arc, Mutex};

use async_trait::async_trait;
use bytes::Bytes;
+use mz_ore::bytes::SegmentedBytes;

use crate::location::{Atomicity, Blob, BlobMetadata, ExternalError};

@@ -76,7 +77,7 @@ impl InterceptBlob {

#[async_trait]
impl Blob for InterceptBlob {
-    async fn get(&self, key: &str) -> Result<Option<Vec<u8>>, ExternalError> {
+    async fn get(&self, key: &str) -> Result<Option<SegmentedBytes>, ExternalError> {
self.blob.get(key).await
}

58 changes: 46 additions & 12 deletions src/persist/src/location.rs
@@ -15,6 +15,7 @@ use std::time::Instant;
use anyhow::anyhow;
use async_trait::async_trait;
use bytes::Bytes;
+use mz_ore::bytes::SegmentedBytes;
use mz_ore::cast::u64_to_usize;
use mz_proto::RustType;
use serde::{Deserialize, Serialize};
@@ -394,7 +395,7 @@ pub const BLOB_GET_LIVENESS_KEY: &str = "LIVENESS";
#[async_trait]
pub trait Blob: std::fmt::Debug {
/// Returns a reference to the value corresponding to the key.
-    async fn get(&self, key: &str) -> Result<Option<Vec<u8>>, ExternalError>;
+    async fn get(&self, key: &str) -> Result<Option<SegmentedBytes>, ExternalError>;

/// List all of the keys in the map with metadata about the entry.
///
Expand Down Expand Up @@ -488,15 +489,27 @@ pub mod tests {
blob0
.set(k0, values[0].clone().into(), AllowNonAtomic)
.await?;
-assert_eq!(blob0.get(k0).await?, Some(values[0].clone()));
-assert_eq!(blob1.get(k0).await?, Some(values[0].clone()));
+assert_eq!(
+    blob0.get(k0).await?.map(|s| s.into_contiguous()),
+    Some(values[0].clone())
+);
+assert_eq!(
+    blob1.get(k0).await?.map(|s| s.into_contiguous()),
+    Some(values[0].clone())
+);

// Set a key with RequireAtomic and get it back.
blob0
.set("k0a", values[0].clone().into(), RequireAtomic)
.await?;
-assert_eq!(blob0.get("k0a").await?, Some(values[0].clone()));
-assert_eq!(blob1.get("k0a").await?, Some(values[0].clone()));
+assert_eq!(
+    blob0.get("k0a").await?.map(|s| s.into_contiguous()),
+    Some(values[0].clone())
+);
+assert_eq!(
+    blob1.get("k0a").await?.map(|s| s.into_contiguous()),
+    Some(values[0].clone())
+);

// Blob contains the key we just inserted.
let mut blob_keys = get_keys(&blob0).await?;
Expand All @@ -510,14 +523,26 @@ pub mod tests {
blob0
.set(k0, values[1].clone().into(), AllowNonAtomic)
.await?;
-assert_eq!(blob0.get(k0).await?, Some(values[1].clone()));
-assert_eq!(blob1.get(k0).await?, Some(values[1].clone()));
+assert_eq!(
+    blob0.get(k0).await?.map(|s| s.into_contiguous()),
+    Some(values[1].clone())
+);
+assert_eq!(
+    blob1.get(k0).await?.map(|s| s.into_contiguous()),
+    Some(values[1].clone())
+);
// Can overwrite a key with RequireAtomic.
blob0
.set("k0a", values[1].clone().into(), RequireAtomic)
.await?;
-assert_eq!(blob0.get("k0a").await?, Some(values[1].clone()));
-assert_eq!(blob1.get("k0a").await?, Some(values[1].clone()));
+assert_eq!(
+    blob0.get("k0a").await?.map(|s| s.into_contiguous()),
+    Some(values[1].clone())
+);
+assert_eq!(
+    blob1.get("k0a").await?.map(|s| s.into_contiguous()),
+    Some(values[1].clone())
+);

// Can delete a key.
assert_eq!(blob0.delete(k0).await, Ok(Some(2)));
@@ -510,14 +523,26 @@
blob0
.set(k0, values[1].clone().into(), AllowNonAtomic)
.await?;
-assert_eq!(blob1.get(k0).await?, Some(values[1].clone()));
-assert_eq!(blob0.get(k0).await?, Some(values[1].clone()));
+assert_eq!(
+    blob1.get(k0).await?.map(|s| s.into_contiguous()),
+    Some(values[1].clone())
+);
+assert_eq!(
+    blob0.get(k0).await?.map(|s| s.into_contiguous()),
+    Some(values[1].clone())
+);

// Insert multiple keys back to back and validate that we can list
// them all out.
@@ -545,8 +570,14 @@

// We can open a new blob to the same path and use it.
let blob3 = new_fn("path0").await?;
-assert_eq!(blob3.get(k0).await?, Some(values[1].clone()));
+assert_eq!(
+    blob3.get(k0).await?.map(|s| s.into_contiguous()),
+    Some(values[1].clone())
+);

Ok(())
}
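Taken together, this is what the new trait surface looks like from an implementor's side. A pared-down, self-contained sketch (persist's ExternalError, atomicity flag, and listing methods are omitted; `into_contiguous()` is assumed to collapse the segments into a single `Vec<u8>`, as the rewritten asserts above use it):

```rust
use std::collections::BTreeMap;
use std::sync::Mutex;

use async_trait::async_trait;
use bytes::Bytes;
use mz_ore::bytes::SegmentedBytes;

// Simplified stand-in for persist's Blob trait after this commit.
#[async_trait]
trait Blob {
    async fn get(&self, key: &str) -> Option<SegmentedBytes>;
    async fn set(&self, key: &str, value: Bytes);
}

#[derive(Default)]
struct MemBlob {
    data: Mutex<BTreeMap<String, Bytes>>,
}

#[async_trait]
impl Blob for MemBlob {
    async fn get(&self, key: &str) -> Option<SegmentedBytes> {
        // SegmentedBytes lets implementations return multiple chunks
        // (e.g. a multipart S3 download) without stitching them together;
        // this sketch takes the simple single-segment route via a copy.
        let value = self.data.lock().unwrap().get(key).cloned()?;
        Some(SegmentedBytes::from(value.to_vec()))
    }

    async fn set(&self, key: &str, value: Bytes) {
        self.data.lock().unwrap().insert(key.to_string(), value);
    }
}

#[tokio::main]
async fn main() {
    let blob = MemBlob::default();
    blob.set("k0", Bytes::from_static(b"v0")).await;
    // Callers that want contiguous bytes collapse the segments explicitly.
    let got = blob.get("k0").await.map(|s| s.into_contiguous());
    assert_eq!(got, Some(b"v0".to_vec()));
}
```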