Skip to content

Commit

Permalink
*: fix potential UBs in memory table and tests. DO NOT USE UNSAFE RUS…
Browse files Browse the repository at this point in the history
…T UNLESS YOU ARE CONFIDENT TO THIS

Signed-off-by: Fullstop000 <fullstop1005@gmail.com>
  • Loading branch information
Fullstop000 committed Sep 20, 2019
1 parent c01183a commit bf8f4b3
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 48 deletions.
2 changes: 1 addition & 1 deletion src/db/format.rs
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ impl Comparator for InternalKeyComparator {
let ua = extract_user_key(a);
let ub = extract_user_key(b);
// compare user key first
match ua.compare(&ub) {
match self.user_comparator.compare(ua.as_slice(), ub.as_slice()) {
Ordering::Greater => Ordering::Greater,
Ordering::Less => Ordering::Less,
Ordering::Equal => {
Expand Down
21 changes: 10 additions & 11 deletions src/mem/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,33 +72,34 @@ pub trait MemoryTable {
// KeyComparator is a wrapper for InternalKeyComparator. It will convert the input mem key
// to the internal key before comparing.
struct KeyComparator {
cmp: Arc<InternalKeyComparator>,
icmp: Arc<InternalKeyComparator>,
}

impl Comparator for KeyComparator {
fn compare(&self, a: &[u8], b: &[u8]) -> Ordering {
let ia = extract_varint32_encoded_slice(&mut Slice::from(a));
let ib = extract_varint32_encoded_slice(&mut Slice::from(b));
if ia.is_empty() || ib.is_empty() {
// Use memcmp directly
ia.compare(&ib)
} else {
self.cmp.compare(ia.as_slice(), ib.as_slice())
self.icmp.compare(ia.as_slice(), ib.as_slice())
}
}

fn name(&self) -> &str {
self.cmp.name()
self.icmp.name()
}

fn separator(&self, a: &[u8], b: &[u8]) -> Vec<u8> {
let ia = extract_varint32_encoded_slice(&mut Slice::from(a));
let ib = extract_varint32_encoded_slice(&mut Slice::from(b));
self.cmp.separator(ia.as_slice(), ib.as_slice())
self.icmp.separator(ia.as_slice(), ib.as_slice())
}

fn successor(&self, key: &[u8]) -> Vec<u8> {
let ia = extract_varint32_encoded_slice(&mut Slice::from(key));
self.cmp.successor(ia.as_slice())
self.icmp.successor(ia.as_slice())
}
}

Expand All @@ -109,9 +110,9 @@ pub struct MemTable {
}

impl MemTable {
pub fn new(cmp: Arc<InternalKeyComparator>) -> Self {
pub fn new(icmp: Arc<InternalKeyComparator>) -> Self {
let arena = BlockArena::new();
let kcmp = Arc::new(KeyComparator { cmp });
let kcmp = Arc::new(KeyComparator { icmp });
let table = Arc::new(Skiplist::new(kcmp.clone(), Box::new(arena)));
Self { cmp: kcmp, table }
}
Expand All @@ -129,14 +130,12 @@ impl MemoryTable for MemTable {
fn add(&self, seq_number: u64, val_type: ValueType, key: &[u8], value: &[u8]) {
let key_size = key.len();
let internal_key_size = key_size + 8;
// TODO: use pre-allocated buf
let mut buf = vec![];
VarintU32::put_varint(&mut buf, internal_key_size as u32);
buf.extend_from_slice(key);
put_fixed_64(&mut buf, (seq_number << 8) | val_type as u64);
VarintU32::put_varint_prefixed_slice(&mut buf, value);
// TODO: remove redundant copying
self.table.insert(Slice::from(buf.as_slice()))
self.table.insert(buf);
}

fn get(&self, key: &LookupKey) -> Option<Result<Slice>> {
Expand All @@ -147,7 +146,7 @@ impl MemoryTable for MemTable {
if iter.valid() {
let internal_key = iter.key();
// only check the user key here
match self.cmp.cmp.user_comparator.compare(
match self.cmp.icmp.user_comparator.compare(
Slice::new(internal_key.as_ptr(), internal_key.size() - 8).as_slice(),
key.user_key().as_slice(),
) {
Expand Down
27 changes: 14 additions & 13 deletions src/mem/skiplist.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,24 +111,25 @@ impl Skiplist {
}
}

/// Insert a node into the skiplist by given key.
/// The key must be unique otherwise this method panic.
/// Insert the given key as a node into the skiplist.
/// The key must be unique otherwise this method panics.
///
/// # NOTICE:
///
/// Concurrent insertion is not thread safe but concurrent reading with a
/// single writer is safe.
///
pub fn insert(&self, key: Slice) {
pub fn insert(&self, key: Vec<u8>) {
let mut prev = [ptr::null_mut(); MAX_HEIGHT];
let node = self.find_greater_or_equal(&key, Some(&mut prev));
let slc = Slice::from(&key);
let node = self.find_greater_or_equal(&slc, Some(&mut prev));
if !node.is_null() {
unsafe {
assert_ne!(
(&(*node)).key().compare(&key),
(&(*node)).key().compare(&slc),
CmpOrdering::Equal,
"[skiplist] duplicate insertion [key={:?}] is not allowed",
key
&key
);
}
}
Expand All @@ -142,13 +143,13 @@ impl Skiplist {
self.max_height.store(height, Ordering::Release);
}
// allocate the key
let k = self.arena.allocate(key.size());
let k = self.arena.allocate(key.len());
unsafe {
copy_nonoverlapping(key.as_ptr(), k, key.size());
copy_nonoverlapping(key.as_ptr(), k, key.len());
}
// allocate the node
let new_node = Node::new(
Slice::new(k as *const u8, key.size()),
Slice::new(k as *const u8, key.len()),
height,
self.arena.as_ref(),
);
Expand Down Expand Up @@ -531,7 +532,7 @@ mod tests {
let inputs = vec!["key1", "key3", "key5", "key7", "key9"];
let skl = new_test_skl();
for key in inputs.clone().drain(..) {
skl.insert(Slice::from(key));
skl.insert(Vec::from(key));
}

let mut node = skl.head;
Expand All @@ -555,7 +556,7 @@ mod tests {
let mut inputs = vec!["key1", "key1"];
let skl = new_test_skl();
for key in inputs.drain(..) {
skl.insert(Slice::from(key));
skl.insert(Vec::from(key));
}
}

Expand All @@ -572,7 +573,7 @@ mod tests {
let skl = new_test_skl();
let inputs = vec!["key1", "key11", "key13", "key3", "key5", "key7", "key9"];
for key in inputs.clone().drain(..) {
skl.insert(Slice::from(key))
skl.insert(Vec::from(key))
}
let mut iter = SkiplistIterator::new(Arc::new(skl));
assert_eq!(ptr::null_mut(), iter.node,);
Expand Down Expand Up @@ -736,7 +737,7 @@ mod tests {
let key = make_key(k as u64, g as u64);
let mut bytes = vec![];
put_fixed_64(&mut bytes, key);
self.list.insert(Slice::from(&bytes));
self.list.insert(bytes);
self.current.set(k, g);
}

Expand Down
4 changes: 2 additions & 2 deletions src/sstable/block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,11 +128,11 @@ pub struct BlockIterator {
not_shared: u32, // not shared length
value_len: u32, // value length
key_offset: u32, // the offset of the key in the block
// TODO: remove this buffer
// TODO: remove this buffer
// Removing this buffer might be difficult because the key
// could be formed by multiple segments which means we should
// maintain predictable amount of offsets for each key.
key: Vec<u8>, // buffer for a completed key
key: Vec<u8>, // buffer for a completed key
}

impl BlockIterator {
Expand Down
58 changes: 37 additions & 21 deletions src/sstable/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@ mod test_footer {
#[cfg(test)]
mod tests {
use crate::db::format::{
InternalKeyComparator, ParsedInternalKey, ValueType, MAX_KEY_SEQUENCE,
InternalKeyComparator, LookupKey, ParsedInternalKey, ValueType, MAX_KEY_SEQUENCE,
};
use crate::db::{WickDB, DB};
use crate::iterator::{EmptyIterator, Iterator};
Expand Down Expand Up @@ -524,7 +524,7 @@ mod tests {
}
}

// A helper struct to convert user key into internal key for inner iterator
// A helper struct to convert user key into lookup key for inner iterator
struct KeyConvertingIterator {
inner: Box<dyn Iterator>,
err: Cell<Option<WickErr>>,
Expand Down Expand Up @@ -553,8 +553,8 @@ mod tests {
}

fn seek(&mut self, target: &Slice) {
let ikey = ParsedInternalKey::new(target.clone(), MAX_KEY_SEQUENCE, ValueType::Value);
self.inner.seek(&Slice::from(ikey.encode().data()))
let lkey = LookupKey::new(target.as_slice(), MAX_KEY_SEQUENCE);
self.inner.seek(&lkey.mem_key());
}

fn next(&mut self) {
Expand Down Expand Up @@ -603,7 +603,7 @@ mod tests {
impl EntryIterator {
fn new(cmp: Arc<dyn Comparator>, data: Vec<(Vec<u8>, Vec<u8>)>) -> Self {
Self {
current: 0,
current: data.len(),
data,
cmp,
}
Expand Down Expand Up @@ -673,25 +673,22 @@ mod tests {
}

struct MemTableConstructor {
icmp: Arc<InternalKeyComparator>,
memtable: MemTable,
inner: MemTable,
}

impl MemTableConstructor {
fn new(cmp: Arc<dyn Comparator>) -> Self {
let icmp = Arc::new(InternalKeyComparator::new(cmp));
Self {
icmp: icmp.clone(),
memtable: MemTable::new(icmp),
inner: MemTable::new(icmp),
}
}
}

impl Constructor for MemTableConstructor {
fn finish(&mut self, _options: Arc<Options>, data: &[(Vec<u8>, Vec<u8>)]) -> Result<()> {
let memtable = MemTable::new(self.icmp.clone());
for (seq, (key, value)) in data.iter().enumerate() {
memtable.add(
self.inner.add(
seq as u64 + 1,
ValueType::Value,
key.as_slice(),
Expand All @@ -702,7 +699,7 @@ mod tests {
}

fn iter(&self) -> Box<dyn Iterator> {
Box::new(KeyConvertingIterator::new(self.memtable.iter()))
Box::new(KeyConvertingIterator::new(self.inner.iter()))
}
}

Expand Down Expand Up @@ -779,6 +776,7 @@ mod tests {

struct TestHarness {
options: Arc<Options>,
reverse_cmp: bool,
inner: CommonConstructor,
rand: ThreadRng,
}
Expand All @@ -805,6 +803,7 @@ mod tests {
};
TestHarness {
inner: CommonConstructor::new(constructor),
reverse_cmp,
rand: rand::thread_rng(),
options: Arc::new(options),
}
Expand Down Expand Up @@ -861,7 +860,7 @@ mod tests {
"iterator should be invalid after being initialized"
);
let mut expected_iter = EntryIterator::new(self.options.comparator.clone(), expected);
for _ in 0..100 {
for _ in 0..1000 {
match self.rand.gen_range(0, 5) {
// case for `next`
0 => {
Expand All @@ -873,6 +872,8 @@ mod tests {
format_entry(iter.as_ref()),
format_entry(&expected_iter)
);
} else {
assert_eq!(iter.valid(), expected_iter.valid());
}
}
}
Expand All @@ -882,15 +883,20 @@ mod tests {
expected_iter.seek_to_first();
if iter.valid() {
assert_eq!(format_entry(iter.as_ref()), format_entry(&expected_iter));
} else {
assert_eq!(iter.valid(), expected_iter.valid());
}
}
// case for `seek`
2 => {
let key = Slice::from(random_key(keys).as_slice());
let rkey = random_key(keys, self.reverse_cmp);
let key = Slice::from(rkey.as_slice());
iter.seek(&key);
expected_iter.seek(&key);
if iter.valid() {
assert_eq!(format_entry(iter.as_ref()), format_entry(&expected_iter));
} else {
assert_eq!(iter.valid(), expected_iter.valid());
}
}
// case for `prev`
Expand All @@ -903,6 +909,8 @@ mod tests {
format_entry(iter.as_ref()),
format_entry(&expected_iter)
);
} else {
assert_eq!(iter.valid(), expected_iter.valid());
}
}
}
Expand All @@ -912,6 +920,8 @@ mod tests {
expected_iter.seek_to_last();
if iter.valid() {
assert_eq!(format_entry(iter.as_ref()), format_entry(&expected_iter));
} else {
assert_eq!(iter.valid(), expected_iter.valid());
}
}
_ => { /* ignore */ }
Expand Down Expand Up @@ -945,15 +955,15 @@ mod tests {
format!("'{:?}->{:?}'", iter.key(), iter.value())
}

fn random_key(keys: &[Vec<u8>]) -> Vec<u8> {
fn random_key(keys: &[Vec<u8>], reverse_cmp: bool) -> Vec<u8> {
if keys.is_empty() {
b"foo".to_vec()
} else {
let mut rnd = rand::thread_rng();
let result = keys.get(rnd.gen_range(0, keys.len())).unwrap();
match rnd.gen_range(0, 3) {
0 => result.clone(),
1 => {
// Attempt to return something smaller than an existing key
let mut cloned = result.clone();
if !cloned.is_empty() && *cloned.last().unwrap() > 0u8 {
let last = cloned.last_mut().unwrap();
Expand All @@ -962,11 +972,16 @@ mod tests {
cloned
}
2 => {
// Return something larger than an existing key
let mut cloned = result.clone();
cloned.push(0);
if reverse_cmp {
cloned.insert(0, 0)
} else {
cloned.push(0);
}
cloned
}
_ => result.clone(),
_ => result.clone(), // Return an existing key
}
}
}
Expand All @@ -975,7 +990,8 @@ mod tests {
Table,
Block,
Memtable,
DB,
#[allow(dead_code)]
DB, // Disabled for now; re-enable the DB tests once the fundamental components are stable
}

fn new_test_suits() -> Vec<TestHarness> {
Expand All @@ -996,8 +1012,8 @@ mod tests {
(TestType::Memtable, false, 16),
(TestType::Memtable, true, 16),
// Do not bother with restart interval variations for DB
(TestType::DB, false, 16),
(TestType::DB, true, 16),
// (TestType::DB, false, 16),
// (TestType::DB, true, 16),
];
let mut results = vec![];
for (t, reverse_cmp, restart_interval) in tests.drain(..) {
Expand Down

0 comments on commit bf8f4b3

Please sign in to comment.