Skip to content

Commit

Permalink
std: Fix overflow of HashMap's capacity
Browse files Browse the repository at this point in the history
  • Loading branch information
pczarn committed Sep 4, 2014
1 parent ae7342a commit 27f87c6
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 32 deletions.
81 changes: 49 additions & 32 deletions src/libstd/collections/hashmap/table.rs
Expand Up @@ -526,32 +526,45 @@ fn test_rounding() {
assert_eq!(round_up_to_next(5, 4), 8);
}

// Returns a tuple of (minimum required malloc alignment, hash_offset,
// key_offset, val_offset, array_size), from the start of a mallocated array.
fn calculate_offsets(
hash_size: uint, hash_align: uint,
keys_size: uint, keys_align: uint,
vals_size: uint, vals_align: uint) -> (uint, uint, uint, uint, uint) {
// Returns a tuple of (key_offset, val_offset),
// from the start of a mallocated array.
fn calculate_offsets(hashes_size: uint,
keys_size: uint, keys_align: uint,
vals_align: uint)
-> (uint, uint) {
let keys_offset = round_up_to_next(hashes_size, keys_align);
let end_of_keys = keys_offset + keys_size;

let hash_offset = 0;
let end_of_hashes = hash_offset + hash_size;
let vals_offset = round_up_to_next(end_of_keys, vals_align);

let keys_offset = round_up_to_next(end_of_hashes, keys_align);
let end_of_keys = keys_offset + keys_size;
(keys_offset, vals_offset)
}

let vals_offset = round_up_to_next(end_of_keys, vals_align);
let end_of_vals = vals_offset + vals_size;
// Returns a tuple of (minimum required malloc alignment, hash_offset,
// array_size), from the start of a mallocated array.
fn calculate_allocation(hash_size: uint, hash_align: uint,
keys_size: uint, keys_align: uint,
vals_size: uint, vals_align: uint)
-> (uint, uint, uint) {
let hash_offset = 0;
let (_, vals_offset) = calculate_offsets(hash_size,
keys_size, keys_align,
vals_align);
let end_of_vals = vals_offset + vals_size;

let min_align = cmp::max(hash_align, cmp::max(keys_align, vals_align));

(min_align, hash_offset, keys_offset, vals_offset, end_of_vals)
(min_align, hash_offset, end_of_vals)
}

#[test]
fn test_offset_calculation() {
assert_eq!(calculate_offsets(128, 8, 15, 1, 4, 4 ), (8, 0, 128, 144, 148));
assert_eq!(calculate_offsets(3, 1, 2, 1, 1, 1 ), (1, 0, 3, 5, 6));
assert_eq!(calculate_offsets(6, 2, 12, 4, 24, 8), (8, 0, 8, 24, 48));
assert_eq!(calculate_allocation(128, 8, 15, 1, 4, 4), (8, 0, 148));
assert_eq!(calculate_allocation(3, 1, 2, 1, 1, 1), (1, 0, 6));
assert_eq!(calculate_allocation(6, 2, 12, 4, 24, 8), (8, 0, 48));
assert_eq!(calculate_offsets(128, 15, 1, 4), (128, 144));
assert_eq!(calculate_offsets(3, 2, 1, 1), (3, 5));
assert_eq!(calculate_offsets(6, 12, 4, 8), (8, 24));
}

impl<K, V> RawTable<K, V> {
Expand All @@ -566,12 +579,11 @@ impl<K, V> RawTable<K, V> {
marker: marker::CovariantType,
};
}
let hashes_size = capacity.checked_mul(&size_of::<u64>())
.expect("capacity overflow");
let keys_size = capacity.checked_mul(&size_of::< K >())
.expect("capacity overflow");
let vals_size = capacity.checked_mul(&size_of::< V >())
.expect("capacity overflow");
// No need for `checked_mul` before a more restrictive check performed
// later in this method.
let hashes_size = capacity * size_of::<u64>();
let keys_size = capacity * size_of::< K >();
let vals_size = capacity * size_of::< V >();

// Allocating hashmaps is a little tricky. We need to allocate three
// arrays, but since we know their sizes and alignments up front,
Expand All @@ -581,12 +593,19 @@ impl<K, V> RawTable<K, V> {
// This is great in theory, but in practice getting the alignment
// right is a little subtle. Therefore, calculating offsets has been
// factored out into a different function.
let (malloc_alignment, hash_offset, _, _, size) =
calculate_offsets(
let (malloc_alignment, hash_offset, size) =
calculate_allocation(
hashes_size, min_align_of::<u64>(),
keys_size, min_align_of::< K >(),
vals_size, min_align_of::< V >());

// One check for overflow that covers calculation and rounding of size.
let size_of_bucket = size_of::<u64>().checked_add(&size_of::<K>()).unwrap()
.checked_add(&size_of::<V>()).unwrap();
assert!(size >= capacity.checked_mul(&size_of_bucket)
.expect("capacity overflow"),
"capacity overflow");

let buffer = allocate(size, malloc_alignment);

let hashes = buffer.offset(hash_offset as int) as *mut u64;
Expand All @@ -603,12 +622,10 @@ impl<K, V> RawTable<K, V> {
let hashes_size = self.capacity * size_of::<u64>();
let keys_size = self.capacity * size_of::<K>();

let keys_offset = (hashes_size + min_align_of::<K>() - 1) & !(min_align_of::<K>() - 1);
let end_of_keys = keys_offset + keys_size;

let vals_offset = (end_of_keys + min_align_of::<V>() - 1) & !(min_align_of::<V>() - 1);

let buffer = self.hashes as *mut u8;
let (keys_offset, vals_offset) = calculate_offsets(hashes_size,
keys_size, min_align_of::<K>(),
min_align_of::<V>());

unsafe {
RawBucket {
Expand Down Expand Up @@ -866,9 +883,9 @@ impl<K, V> Drop for RawTable<K, V> {
let hashes_size = self.capacity * size_of::<u64>();
let keys_size = self.capacity * size_of::<K>();
let vals_size = self.capacity * size_of::<V>();
let (align, _, _, _, size) = calculate_offsets(hashes_size, min_align_of::<u64>(),
keys_size, min_align_of::<K>(),
vals_size, min_align_of::<V>());
let (align, _, size) = calculate_allocation(hashes_size, min_align_of::<u64>(),
keys_size, min_align_of::<K>(),
vals_size, min_align_of::<V>());

unsafe {
deallocate(self.hashes as *mut u8, size, align);
Expand Down
21 changes: 21 additions & 0 deletions src/test/run-fail/hashmap-capacity-overflow.rs
@@ -0,0 +1,21 @@
// Copyright 2014 The Rust Project Developers. See the COPYRIGHT
// file at the top-level directory of this distribution and at
// http://rust-lang.org/COPYRIGHT.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

// error-pattern:capacity overflow

use std::collections::hashmap::HashMap;
use std::uint;
use std::mem::size_of;

fn main() {
let threshold = uint::MAX / size_of::<(u64, u64, u64)>();
let mut h = HashMap::<u64, u64>::with_capacity(threshold + 100);
h.insert(0, 0);
}

0 comments on commit 27f87c6

Please sign in to comment.