Commit

Merge b7453fa into 2f7bf36
nlhepler committed Oct 31, 2019
2 parents 2f7bf36 + b7453fa commit 312e141
Showing 3 changed files with 91 additions and 41 deletions.
9 changes: 6 additions & 3 deletions Cargo.toml
@@ -26,15 +26,18 @@ failure = "0.1.5"
 crossbeam-channel = "0.3.8"
 failure_derive = "0.1"
 min-max-heap = "1.2.2"
+log = "^0.4"
+rustc_version_runtime = "^0.1"
+semver = "^0.9"
 
 [dev-dependencies]
-tempfile = "3.0"
+tempfile = "^3.1"
 criterion = "0.2.3"
 fxhash = "0.2.1"
 lazy_static = "1.2"
 pretty_assertions = "0.5.1"
-quickcheck = "0.8.5"
-rand = "0.6.5"
+quickcheck = "^0.9"
+rand = "^0.7"
 is_sorted = "0.1"
 
 [[bench]]
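The three new runtime dependencies back the version stamping added to src/lib.rs below: rustc_version_runtime reports which rustc built the running binary, semver makes that report a comparable Version, and log carries the mismatch warning. A minimal sketch of how they combine — check_compiler is a hypothetical helper, not part of this commit:

use log::warn;
use rustc_version_runtime::version;
use semver::Version;

// Warn (via whatever logger the host application installed) when a recorded
// compiler version does not match the one that built this binary.
fn check_compiler(recorded: &Version) {
    let current = version(); // the rustc that compiled this binary
    if *recorded != current {
        warn!("written by rustc {}, reading with rustc {}", recorded, current);
    }
}

fn main() {
    let recorded = version();
    check_compiler(&recorded); // same version here, so no warning is emitted
}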
1 change: 0 additions & 1 deletion src/helper.rs
@@ -69,7 +69,6 @@ pub struct ThreadProxyWriter<T: Send + Write> {
 }
 
 impl<T: 'static + Send + Write> ThreadProxyWriter<T> {
-
     /// Create a new `ThreadProxyWriter` that will write to `writer` on a newly created thread
     pub fn new(mut writer: T, buffer_size: usize) -> ThreadProxyWriter<T> {
         let (tx, rx) = bounded::<Option<Vec<u8>>>(10);
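For context on the surrounding code: ThreadProxyWriter batches writes into Vec<u8> buffers and ships them over a bounded crossbeam channel to a dedicated I/O thread, with None as the shutdown sentinel — the bounded::<Option<Vec<u8>>>(10) line above is exactly that channel. A rough sketch of the pattern (names are illustrative, not shardio's internals):

use crossbeam_channel::bounded;
use std::io::Write;
use std::thread;

fn spawn_writer_thread<W: Write + Send + 'static>(
    mut writer: W,
) -> (crossbeam_channel::Sender<Option<Vec<u8>>>, thread::JoinHandle<()>) {
    // A bounded channel applies backpressure when the writer thread falls behind.
    let (tx, rx) = bounded::<Option<Vec<u8>>>(10);
    let handle = thread::spawn(move || {
        // `Some(buf)` is a chunk to write; `None` is the shutdown sentinel.
        while let Some(buf) = rx.recv().unwrap() {
            writer.write_all(&buf).unwrap();
        }
    });
    (tx, handle)
}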
122 changes: 85 additions & 37 deletions src/lib.rs
@@ -68,6 +68,7 @@
 //! let mut all_items_sorted = all_items.clone();
 //! all_items.sort();
 //! assert_eq!(all_items, all_items_sorted);
+//! std::fs::remove_file(filename)?;
 //! Ok(())
 //! }
 //! ```
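(The added remove_file line keeps the crate-level doc test from leaving the file named by filename behind; the trailing ? works because the example's enclosing function returns Result, as the Ok(()) above it shows.)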
@@ -77,10 +78,11 @@
 use crossbeam_channel;
 use lz4;
 
-#[macro_use]
+#[macro_use]
 extern crate serde_derive;
-use serde::{Serialize, de::DeserializeOwned};
+use serde::{de::DeserializeOwned, Serialize};
 
+use std::any::type_name;
 use std::borrow::Cow;
 use std::collections::BTreeSet;
 use std::fs::File;
@@ -95,14 +97,16 @@ use std::marker::PhantomData;
 use std::thread;
 use std::thread::JoinHandle;
 
-
-use bincode::{deserialize_from, serialize_into};
+use bincode::serialize_into;
 use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
 
 use libc::{c_void, off_t, pread, pwrite, size_t, ssize_t};
 
 use failure::{format_err, Error};
 
+use log::warn;
+use rustc_version_runtime::version;
+use semver::Version;
 
 /// Represent a range of key space
 pub mod range;
@@ -274,7 +278,6 @@ impl<K: Ord + Serialize> FileManager<K> {
 /// }
 /// ```
 pub trait SortKey<T> {
-
     /// The type of the key that will be sorted.
     type Key: Ord + Clone;
 
@@ -624,7 +627,15 @@ where
     fn write_index_block(&mut self) -> Result<(), Error> {
         let mut buf = Vec::new();
 
-        serialize_into(&mut buf, &self.writer.regions)?;
+        serialize_into(
+            &mut buf,
+            &(
+                version(),
+                type_name::<T>(),
+                type_name::<S>(),
+                &self.writer.regions,
+            ),
+        )?;
 
         let index_block_position = self.writer.cursor;
         let index_block_size = buf.len();
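The index block payload changes here from the bare region records to a 4-tuple of (compiler version, type name of T, type name of S, regions). Note the asymmetry: the writer serializes borrowed values, while the reader below recovers owned ones — bincode encodes &str and String, or &Vec<_> and Vec<_>, identically on the wire. A standalone sketch of that round trip, with Vec<u64> standing in for the region records and semver's serde support assumed to be enabled:

use rustc_version_runtime::version;
use semver::Version;
use std::any::type_name;

fn main() -> Result<(), bincode::Error> {
    let regions: Vec<u64> = vec![1, 2, 3];

    // Writer side: serialize borrowed values into the index block buffer.
    let mut buf = Vec::new();
    bincode::serialize_into(&mut buf, &(version(), type_name::<u64>(), "SortStandIn", &regions))?;

    // Reader side: deserialize into owned values.
    let (ver, t_typ, s_typ, recs): (Version, String, String, Vec<u64>) =
        bincode::deserialize_from(&buf[..])?;
    assert_eq!(recs, regions);
    assert_eq!(t_typ, type_name::<u64>());
    println!("stamped with rustc {}, sort marker {}", ver, s_typ);
    Ok(())
}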
@@ -728,6 +739,7 @@ where
 {
     file: File,
     index: Vec<ShardRecord<<S as SortKey<T>>::Key>>,
+    binconfig: bincode::Config,
     p1: PhantomData<T>,
 }
 
@@ -741,27 +753,55 @@ where
     fn open<P: AsRef<Path>>(path: P) -> Result<ShardReaderSingle<T, S>, Error> {
         let mut f = File::open(path).unwrap();
 
-        let mut index = Self::read_index_block(&mut f)?;
+        let mut binconfig = bincode::config();
+        // limit ourselves to decoding no more than 268MB at a time
+        binconfig.limit(1 << 28);
 
+        let mut index = Self::read_index_block(&mut f, &binconfig)?;
         index.sort();
 
         Ok(ShardReaderSingle {
             file: f,
             index,
+            binconfig,
             p1: PhantomData,
         })
     }
 
     /// Read shard index
     fn read_index_block(
         file: &mut File,
+        binconfig: &bincode::Config,
     ) -> Result<Vec<ShardRecord<<S as SortKey<T>>::Key>>, Error> {
         let _ = file.seek(SeekFrom::End(-24))?;
         let _num_shards = file.read_u64::<BigEndian>()? as usize;
         let index_block_position = file.read_u64::<BigEndian>()?;
         let _ = file.read_u64::<BigEndian>()?;
         file.seek(SeekFrom::Start(index_block_position as u64))?;
 
-        Ok(deserialize_from(file)?)
+        let (ver, t_typ, s_typ, recs): (Version, String, String, _) =
+            binconfig.deserialize_from(file)?;
+        if ver != version() {
+            warn!(
+                "expected compiler version {}, got {}; types may be incompatible",
+                version(),
+                ver
+            );
+        }
+        if ver == version() && t_typ != type_name::<T>() {
+            return Err(format_err!(
+                "expected shardio type {}, got {}",
+                type_name::<T>(),
+                t_typ
+            ));
+        }
+        if ver == version() && s_typ != type_name::<S>() {
+            return Err(format_err!(
+                "expected shardio sort {}, got {}",
+                type_name::<S>(),
+                s_typ
+            ));
+        }
+        Ok(recs)
     }
 
     fn get_decoder(buffer: &mut Vec<u8>) -> lz4::Decoder<&[u8]> {
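Read as a policy, the new checks in read_index_block split two failure modes: a different compiler version only warns, since the type layouts may still agree, while a matching compiler with a mismatched type name is a hard error — the file then provably holds a different T or S. The type-name comparisons are gated on ver == version() because type_name's output is explicitly not guaranteed stable across rustc releases, so comparing names produced by two different compilers could yield false mismatches.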
@@ -790,7 +830,7 @@ where
             assert_eq!(read_len, rec.len_bytes);
 
             let mut decoder = Self::get_decoder(buf);
-            let r: Vec<T> = deserialize_from(&mut decoder)?;
+            let r: Vec<T> = self.binconfig.deserialize_from(&mut decoder)?;
             data.extend(r.into_iter().filter(|x| range.contains(&S::sort_key(x))));
         }
 
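On the limit configured in open: 1 << 28 bytes is the 268MB from the comment, and it bounds how much any single deserialize_from call may consume. Without it, a corrupt or truncated index block whose length prefix claims an enormous collection could drive bincode into huge reads. A small sketch of the guard with a fabricated length prefix (illustrative; the exact error kind depends on bincode's internals):

fn main() {
    // Eight 0xFF bytes: a bincode length prefix claiming ~1.8e19 elements.
    let bogus = u64::max_value().to_le_bytes().to_vec();

    let mut config = bincode::config();
    config.limit(1 << 28); // same 268MB cap as ShardReaderSingle::open

    // The claimed length cannot be satisfied within the limit, so this
    // errors out quickly rather than trusting the bogus size.
    let res: Result<Vec<u8>, _> = config.deserialize_from(&bogus[..]);
    assert!(res.is_err());
}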
@@ -858,6 +898,7 @@ where
     next_item: Option<T>,
     decoder: lz4::Decoder<BufReader<ReadAdapter<'a>>>,
     items_remaining: usize,
+    binconfig: &'a bincode::Config,
     phantom_s: PhantomData<S>,
 }
 
@@ -875,13 +916,15 @@
         let buf_reader = BufReader::new(adp_reader);
         let mut lz4_reader = lz4::Decoder::new(buf_reader)?;
 
-        let first_item: T = deserialize_from(&mut lz4_reader)?;
+        let binconfig = &reader.binconfig;
+        let first_item: T = binconfig.deserialize_from(&mut lz4_reader)?;
         let items_remaining = rec.len_items - 1;
 
         Ok(ShardIter {
             next_item: Some(first_item),
             decoder: lz4_reader,
             items_remaining,
+            binconfig,
             phantom_s: PhantomData,
         })
     }
@@ -895,7 +938,7 @@
         if self.items_remaining == 0 {
             Ok((item, None))
         } else {
-            self.next_item = Some(deserialize_from(&mut self.decoder)?);
+            self.next_item = Some(self.binconfig.deserialize_from(&mut self.decoder)?);
             self.items_remaining -= 1;
             Ok((item, Some(self)))
         }
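Both ShardIter call sites rely on the same property: bincode items are laid end to end in the lz4 stream, and each deserialize_from call consumes exactly one of them, so items_remaining is the only cursor the iterator needs. A self-contained sketch of that streaming pattern, with a plain in-memory buffer in place of the lz4 decoder:

use std::io::Cursor;

fn main() -> Result<(), bincode::Error> {
    // Lay three items end to end, as the shard writer does within a block.
    let mut buf = Vec::new();
    for item in &[10u32, 20, 30] {
        bincode::serialize_into(&mut buf, item)?;
    }

    // Each deserialize_from call consumes exactly one item from the stream,
    // so a count is all the state the reader needs to track.
    let mut reader = Cursor::new(buf);
    let mut items_remaining = 3;
    while items_remaining > 0 {
        let item: u32 = bincode::deserialize_from(&mut reader)?;
        println!("{}", item);
        items_remaining -= 1;
    }
    Ok(())
}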
@@ -1316,16 +1359,16 @@ where
 #[cfg(test)]
 mod shard_tests {
     use super::*;
+    use is_sorted::IsSorted;
+    use pretty_assertions::assert_eq;
+    use quickcheck::{Arbitrary, Gen, QuickCheck, StdThreadGen};
+    use rand::Rng;
     use std::collections::HashSet;
     use std::fmt::Debug;
     use std::hash::Hash;
     use std::iter::{repeat, FromIterator};
     use std::u8;
     use tempfile;
-    use pretty_assertions::assert_eq;
-    use quickcheck::{QuickCheck, Arbitrary, Gen, StdThreadGen};
-    use rand::Rng;
-    use is_sorted::IsSorted;
 
     #[derive(Copy, Clone, Eq, PartialEq, Serialize, Deserialize, Debug, PartialOrd, Ord, Hash)]
     struct T1 {
@@ -1337,8 +1380,8 @@
 
     impl Arbitrary for T1 {
         fn arbitrary<G: Gen>(g: &mut G) -> T1 {
-            T1 {
-                a: g.gen(),
+            T1 {
+                a: g.gen(),
                 b: g.gen(),
                 c: g.gen(),
                 d: g.gen(),
@@ -1508,7 +1551,7 @@
 
         let mut data = Vec::new();
 
-        for _ in 0 .. slices {
+        for _ in 0..slices {
            let slice = Vec::<T>::arbitrary(g);
            data.push(slice);
        }
@@ -1517,26 +1560,29 @@
         }
     }
 
-
-    fn test_multi_slice<T, S>(items: MultiSlice<T>, disk_chunk_size: usize, producer_chunk_size: usize, buffer_size: usize) -> Result<Vec<T>, Error>
+    fn test_multi_slice<T, S>(
+        items: MultiSlice<T>,
+        disk_chunk_size: usize,
+        producer_chunk_size: usize,
+        buffer_size: usize,
+    ) -> Result<Vec<T>, Error>
     where
         T: 'static + Serialize + DeserializeOwned + Clone + Send,
         S: SortKey<T>,
         <S as SortKey<T>>::Key: 'static + Send + Serialize + DeserializeOwned,
     {
-
         let mut files = Vec::new();
 
         for item_chunk in &items.0 {
             let tmp = tempfile::NamedTempFile::new()?;
 
             let writer: ShardWriter<T, S> = ShardWriter::new(
-                tmp.path(),
-                producer_chunk_size,
-                disk_chunk_size,
-                buffer_size)?;
-
+                tmp.path(),
+                producer_chunk_size,
+                disk_chunk_size,
+                buffer_size,
+            )?;
 
             let mut sender = writer.get_sender();
             for item in item_chunk {
                 sender.send(item.clone())?;
@@ -1545,7 +1591,7 @@ mod shard_tests {
             files.push(tmp);
         }
 
-        let reader = ShardReader::<T,S>::open_set(&files)?;
+        let reader = ShardReader::<T, S>::open_set(&files)?;
         let mut out_items = Vec::new();
 
         for r in reader.iter()? {
@@ -1555,27 +1601,29 @@ mod shard_tests {
         Ok(out_items)
     }
 
-
     #[test]
     fn multi_slice_correctness_quickcheck() {
-
         fn check_t1(v: MultiSlice<T1>) -> bool {
-            let sorted = test_multi_slice::<T1, FieldDSort>(v.clone(), 1024, 1<<17, 16).unwrap();
+            let sorted = test_multi_slice::<T1, FieldDSort>(v.clone(), 1024, 1 << 17, 16).unwrap();
 
             let mut vall = Vec::new();
             for chunk in v.0 {
                 vall.extend(chunk);
             }
 
-            if sorted.len() != vall.len() { return false; }
-            if !set_compare(&sorted, &vall) { return false; }
+            if sorted.len() != vall.len() {
+                return false;
+            }
+            if !set_compare(&sorted, &vall) {
+                return false;
+            }
             IsSorted::is_sorted_by_key(&mut sorted.iter(), |x| FieldDSort::sort_key(x).into_owned())
         }
 
-
-        QuickCheck::with_gen(StdThreadGen::new(500000)).tests(4).quickcheck(check_t1 as fn(MultiSlice<T1>) -> bool);
+        QuickCheck::with_gen(StdThreadGen::new(500000))
+            .tests(4)
+            .quickcheck(check_t1 as fn(MultiSlice<T1>) -> bool);
     }
 
     fn check_round_trip(
         disk_chunk_size: usize,
