GH-35270: [C++] Use Buffer instead of raw buffer in hash join internals (apache#35347)

### Rationale for this change

The current code has two storage buffers in the key map which are allocated with `MemoryPool::Allocate`, which does not use smart pointers. This could have led to a memory leak in an OOM scenario where the first allocation succeeds but the second fails, and it also led to some convoluted code that tracked the previously allocated sizes in order to call `Free` correctly.
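
To illustrate the hazard, here is a minimal sketch of the old pattern (the helper name and signature are hypothetical, not the actual key map code):

```cpp
#include "arrow/memory_pool.h"
#include "arrow/status.h"

// Hypothetical sketch of the raw-pointer pattern being replaced: two
// consecutive raw allocations from a MemoryPool. If the second Allocate
// fails under memory pressure, the first allocation is leaked unless every
// error path remembers to Free it, and the owner must also record both
// sizes so that Free can be called correctly later.
arrow::Status InitRaw(arrow::MemoryPool* pool, int64_t block_bytes,
                      int64_t hash_bytes, uint8_t** blocks, uint8_t** hashes) {
  ARROW_RETURN_NOT_OK(pool->Allocate(block_bytes, blocks));
  // An OOM failure here returns early and leaks *blocks.
  ARROW_RETURN_NOT_OK(pool->Allocate(hash_bytes, hashes));
  return arrow::Status::OK();
}
```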

Furthermore, it seems this key map could potentially have been copied in the swiss join code. While that was probably not actually happening (the copy occurred before the key map was initialized), it is still an easy recipe for an accidental double free later on as we maintain the class.
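
For example, a hypothetical reduction of the copy hazard (not the actual class) shows why a raw owning pointer plus an implicit copy is dangerous:

```cpp
#include <cstdint>
#include <cstdlib>

// Hypothetical reduction of the double-free hazard: the compiler-generated
// copy constructor duplicates the raw pointer, so two objects end up
// "owning" the same allocation.
struct RawTable {
  uint8_t* blocks_ = nullptr;  // owned
  ~RawTable() { std::free(blocks_); }
};

int main() {
  RawTable a;
  a.blocks_ = static_cast<uint8_t*>(std::malloc(64));
  {
    RawTable b = a;  // implicit copy: b.blocks_ aliases a.blocks_
  }                  // b's destructor frees the allocation here
  return 0;          // a's destructor frees it again: double free
}
```

With a `std::shared_ptr<Buffer>` member, the same copy merely bumps a reference count and the buffer is freed exactly once.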

### What changes are included in this PR?

Those raw buffers are changed to `std::shared_ptr<Buffer>` to avoid both issues.
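
The resulting allocation pattern, sketched below with a hypothetical helper (the real change lives in `SwissTable::init` and `SwissTable::grow_double`), propagates allocation failures without leaking and needs no size bookkeeping for cleanup:

```cpp
#include <memory>

#include "arrow/buffer.h"
#include "arrow/memory_pool.h"
#include "arrow/result.h"
#include "arrow/status.h"

// Hypothetical helper mirroring the new pattern: AllocateBuffer returns a
// Result<std::unique_ptr<Buffer>>, which ARROW_ASSIGN_OR_RAISE moves into
// the shared_ptr members. If the second allocation fails, the first Buffer
// is released automatically by its destructor.
arrow::Status InitBuffers(arrow::MemoryPool* pool, int64_t block_bytes,
                          int64_t hash_bytes,
                          std::shared_ptr<arrow::Buffer>* blocks,
                          std::shared_ptr<arrow::Buffer>* hashes) {
  ARROW_ASSIGN_OR_RAISE(*blocks, arrow::AllocateBuffer(block_bytes, pool));
  ARROW_ASSIGN_OR_RAISE(*hashes, arrow::AllocateBuffer(hash_bytes, pool));
  return arrow::Status::OK();
}
```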

### Are these changes tested?

Somewhat: the existing unit tests should ensure we didn't introduce a regression. I didn't add a regression test that reproduces the potential bug, because it would be very difficult to do so.

### Are there any user-facing changes?

No

* Closes: apache#35270

Authored-by: Weston Pace <weston.pace@gmail.com>
Signed-off-by: Antoine Pitrou <antoine@python.org>
westonpace authored and liujiacheng777 committed May 11, 2023
1 parent f5a71dd commit 69e5224
Showing 3 changed files with 43 additions and 53 deletions.
75 changes: 31 additions & 44 deletions cpp/src/arrow/compute/key_map.cc
@@ -97,12 +97,12 @@ void SwissTable::extract_group_ids_imp(const int num_keys, const uint16_t* selec
                                        const uint32_t* hashes, const uint8_t* local_slots,
                                        uint32_t* out_group_ids, int element_offset,
                                        int element_multiplier) const {
-  const T* elements = reinterpret_cast<const T*>(blocks_) + element_offset;
+  const T* elements = reinterpret_cast<const T*>(blocks_->data()) + element_offset;
   if (log_blocks_ == 0) {
     ARROW_DCHECK(sizeof(T) == sizeof(uint8_t));
     for (int i = 0; i < num_keys; ++i) {
       uint32_t id = use_selection ? selection[i] : i;
-      uint32_t group_id = blocks_[8 + local_slots[id]];
+      uint32_t group_id = blocks()[8 + local_slots[id]];
       out_group_ids[id] = group_id;
     }
   } else {
@@ -206,7 +206,7 @@ void SwissTable::init_slot_ids_for_new_keys(uint32_t num_ids, const uint16_t* id
   int num_groupid_bits = num_groupid_bits_from_log_blocks(log_blocks_);
   uint32_t num_block_bytes = num_groupid_bits + 8;
   if (log_blocks_ == 0) {
-    uint64_t block = *reinterpret_cast<const uint64_t*>(blocks_);
+    uint64_t block = *reinterpret_cast<const uint64_t*>(blocks_->mutable_data());
     uint32_t empty_slot =
         static_cast<uint32_t>(8 - ARROW_POPCOUNT64(block & kHighBitOfEachByte));
     for (uint32_t i = 0; i < num_ids; ++i) {
@@ -220,7 +220,8 @@ void SwissTable::init_slot_ids_for_new_keys(uint32_t num_ids, const uint16_t* id
       uint32_t iblock = hash >> (bits_hash_ - log_blocks_);
       uint64_t block;
       for (;;) {
-        block = *reinterpret_cast<const uint64_t*>(blocks_ + num_block_bytes * iblock);
+        block = *reinterpret_cast<const uint64_t*>(blocks_->mutable_data() +
+                                                   num_block_bytes * iblock);
         block &= kHighBitOfEachByte;
         if (block) {
           break;
@@ -255,8 +256,8 @@ void SwissTable::early_filter_imp(const int num_keys, const uint32_t* hashes,
     iblock >>= bits_stamp_;

     uint32_t num_block_bytes = num_groupid_bits + 8;
-    const uint8_t* blockbase = reinterpret_cast<const uint8_t*>(blocks_) +
-                               static_cast<uint64_t>(iblock) * num_block_bytes;
+    const uint8_t* blockbase =
+        blocks_->data() + static_cast<uint64_t>(iblock) * num_block_bytes;
     ARROW_DCHECK(num_block_bytes % sizeof(uint64_t) == 0);
     uint64_t block = *reinterpret_cast<const uint64_t*>(blockbase);
@@ -397,7 +398,7 @@ bool SwissTable::find_next_stamp_match(const uint32_t hash, const uint32_t in_sl
   uint8_t* blockbase;
   for (;;) {
     const uint64_t num_block_bytes = (8 + num_groupid_bits);
-    blockbase = blocks_ + num_block_bytes * (start_slot_id >> 3);
+    blockbase = blocks_->mutable_data() + num_block_bytes * (start_slot_id >> 3);
     uint64_t block = *reinterpret_cast<uint64_t*>(blockbase);

     search_block<true>(block, stamp, (start_slot_id & 7), &local_slot, &match_found);
@@ -544,7 +545,7 @@ Status SwissTable::map_new_keys_helper(
       //
       out_group_ids[id] = num_inserted_ + num_inserted_new;
       insert_into_empty_slot(inout_next_slot_ids[id], hashes[id], out_group_ids[id]);
-      hashes_[inout_next_slot_ids[id]] = hashes[id];
+      this->hashes()[inout_next_slot_ids[id]] = hashes[id];
       ::arrow::bit_util::ClearBit(match_bitvector, num_processed);
       ++num_inserted_new;
@@ -649,34 +650,30 @@ Status SwissTable::grow_double() {
   int num_group_id_bits_before = num_groupid_bits_from_log_blocks(log_blocks_);
   int num_group_id_bits_after = num_groupid_bits_from_log_blocks(log_blocks_ + 1);
   uint64_t group_id_mask_before = ~0ULL >> (64 - num_group_id_bits_before);
-  int log_blocks_before = log_blocks_;
   int log_blocks_after = log_blocks_ + 1;
   uint64_t block_size_before = (8 + num_group_id_bits_before);
   uint64_t block_size_after = (8 + num_group_id_bits_after);
-  uint64_t block_size_total_before = (block_size_before << log_blocks_before) + padding_;
   uint64_t block_size_total_after = (block_size_after << log_blocks_after) + padding_;
-  uint64_t hashes_size_total_before =
-      (bits_hash_ / 8 * (1 << (log_blocks_before + 3))) + padding_;
   uint64_t hashes_size_total_after =
       (bits_hash_ / 8 * (1 << (log_blocks_after + 3))) + padding_;
   constexpr uint32_t stamp_mask = (1 << bits_stamp_) - 1;

   // Allocate new buffers
-  uint8_t* blocks_new;
-  RETURN_NOT_OK(pool_->Allocate(block_size_total_after, &blocks_new));
-  memset(blocks_new, 0, block_size_total_after);
-  uint8_t* hashes_new_8B;
-  uint32_t* hashes_new;
-  RETURN_NOT_OK(pool_->Allocate(hashes_size_total_after, &hashes_new_8B));
-  hashes_new = reinterpret_cast<uint32_t*>(hashes_new_8B);
+  ARROW_ASSIGN_OR_RAISE(std::unique_ptr<Buffer> blocks_new,
+                        AllocateBuffer(block_size_total_after, pool_));
+  memset(blocks_new->mutable_data(), 0, block_size_total_after);
+  ARROW_ASSIGN_OR_RAISE(std::unique_ptr<Buffer> hashes_new_buffer,
+                        AllocateBuffer(hashes_size_total_after, pool_));
+  auto hashes_new = reinterpret_cast<uint32_t*>(hashes_new_buffer->mutable_data());

   // First pass over all old blocks.
   // Reinsert entries that were not in the overflow block
   // (block other than selected by hash bits corresponding to the entry).
   for (int i = 0; i < (1 << log_blocks_); ++i) {
     // How many full slots in this block
-    uint8_t* block_base = blocks_ + i * block_size_before;
-    uint8_t* double_block_base_new = blocks_new + 2 * i * block_size_after;
+    uint8_t* block_base = blocks_->mutable_data() + i * block_size_before;
+    uint8_t* double_block_base_new =
+        blocks_new->mutable_data() + 2 * i * block_size_after;
     uint64_t block = *reinterpret_cast<const uint64_t*>(block_base);

     auto full_slots =
@@ -688,7 +685,7 @@

     for (int j = 0; j < full_slots; ++j) {
       uint64_t slot_id = i * 8 + j;
-      uint32_t hash = hashes_[slot_id];
+      uint32_t hash = hashes()[slot_id];
       uint64_t block_id_new = hash >> (bits_hash_ - log_blocks_after);
       bool is_overflow_entry = ((block_id_new >> 1) != static_cast<uint64_t>(i));
       if (is_overflow_entry) {
@@ -721,13 +718,13 @@
   // Reinsert entries that were in an overflow block.
   for (int i = 0; i < (1 << log_blocks_); ++i) {
     // How many full slots in this block
-    uint8_t* block_base = blocks_ + i * block_size_before;
+    uint8_t* block_base = blocks_->mutable_data() + i * block_size_before;
     uint64_t block = util::SafeLoadAs<uint64_t>(block_base);
     int full_slots = static_cast<int>(CountLeadingZeros(block & kHighBitOfEachByte) >> 3);

     for (int j = 0; j < full_slots; ++j) {
       uint64_t slot_id = i * 8 + j;
-      uint32_t hash = hashes_[slot_id];
+      uint32_t hash = hashes()[slot_id];
       uint64_t block_id_new = hash >> (bits_hash_ - log_blocks_after);
       bool is_overflow_entry = ((block_id_new >> 1) != static_cast<uint64_t>(i));
       if (!is_overflow_entry) {
@@ -742,13 +739,14 @@
       uint8_t stamp_new =
           hash >> ((bits_hash_ - log_blocks_after - bits_stamp_)) & stamp_mask;

-      uint8_t* block_base_new = blocks_new + block_id_new * block_size_after;
+      uint8_t* block_base_new =
+          blocks_new->mutable_data() + block_id_new * block_size_after;
       uint64_t block_new = util::SafeLoadAs<uint64_t>(block_base_new);
       int full_slots_new =
           static_cast<int>(CountLeadingZeros(block_new & kHighBitOfEachByte) >> 3);
       while (full_slots_new == 8) {
         block_id_new = (block_id_new + 1) & ((1 << log_blocks_after) - 1);
-        block_base_new = blocks_new + block_id_new * block_size_after;
+        block_base_new = blocks_new->mutable_data() + block_id_new * block_size_after;
         block_new = util::SafeLoadAs<uint64_t>(block_base_new);
         full_slots_new =
             static_cast<int>(CountLeadingZeros(block_new & kHighBitOfEachByte) >> 3);
@@ -764,11 +762,9 @@
     }
   }

-  pool_->Free(blocks_, block_size_total_before);
-  pool_->Free(reinterpret_cast<uint8_t*>(hashes_), hashes_size_total_before);
+  blocks_ = std::move(blocks_new);
+  hashes_ = std::move(hashes_new_buffer);
   log_blocks_ = log_blocks_after;
-  blocks_ = blocks_new;
-  hashes_ = hashes_new;

   return Status::OK();
 }
@@ -785,14 +781,15 @@ Status SwissTable::init(int64_t hardware_flags, MemoryPool* pool, int log_blocks

   const uint64_t block_bytes = 8 + num_groupid_bits;
   const uint64_t slot_bytes = (block_bytes << log_blocks_) + padding_;
-  RETURN_NOT_OK(pool_->Allocate(slot_bytes, &blocks_));
+  ARROW_ASSIGN_OR_RAISE(blocks_, AllocateBuffer(slot_bytes, pool_));

   // Make sure group ids are initially set to zero for all slots.
-  memset(blocks_, 0, slot_bytes);
+  memset(blocks_->mutable_data(), 0, slot_bytes);

   // Initialize all status bytes to represent an empty slot.
+  uint8_t* blocks_ptr = blocks_->mutable_data();
   for (uint64_t i = 0; i < (static_cast<uint64_t>(1) << log_blocks_); ++i) {
-    util::SafeStore(blocks_ + i * block_bytes, kHighBitOfEachByte);
+    util::SafeStore(blocks_ptr + i * block_bytes, kHighBitOfEachByte);
   }

   if (no_hash_array) {
@@ -801,27 +798,17 @@
     uint64_t num_slots = 1ULL << (log_blocks_ + 3);
     const uint64_t hash_size = sizeof(uint32_t);
     const uint64_t hash_bytes = hash_size * num_slots + padding_;
-    uint8_t* hashes8;
-    RETURN_NOT_OK(pool_->Allocate(hash_bytes, &hashes8));
-    hashes_ = reinterpret_cast<uint32_t*>(hashes8);
+    ARROW_ASSIGN_OR_RAISE(hashes_, AllocateBuffer(hash_bytes, pool_));
  }

   return Status::OK();
 }

 void SwissTable::cleanup() {
-  if (blocks_) {
-    int num_groupid_bits = num_groupid_bits_from_log_blocks(log_blocks_);
-    const uint64_t block_bytes = 8 + num_groupid_bits;
-    const uint64_t slot_bytes = (block_bytes << log_blocks_) + padding_;
-    pool_->Free(blocks_, slot_bytes);
-    blocks_ = nullptr;
-  }
-  if (hashes_) {
-    uint64_t num_slots = 1ULL << (log_blocks_ + 3);
-    const uint64_t hash_size = sizeof(uint32_t);
-    const uint64_t hash_bytes = hash_size * num_slots + padding_;
-    pool_->Free(reinterpret_cast<uint8_t*>(hashes_), hash_bytes);
-    hashes_ = nullptr;
-  }
   log_blocks_ = 0;
12 changes: 7 additions & 5 deletions cpp/src/arrow/compute/key_map.h
@@ -80,9 +80,11 @@ class ARROW_EXPORT SwissTable {

   void num_inserted(uint32_t i) { num_inserted_ = i; }

-  uint8_t* blocks() const { return blocks_; }
+  uint8_t* blocks() const { return blocks_->mutable_data(); }

-  uint32_t* hashes() const { return hashes_; }
+  uint32_t* hashes() const {
+    return reinterpret_cast<uint32_t*>(hashes_->mutable_data());
+  }

   /// \brief Extract group id for a given slot in a given block.
   ///
@@ -226,12 +228,12 @@
   // ---------------------------------------------------
   // * Empty bucket has value 0x80. Non-empty bucket has highest bit set to 0.
   //
-  uint8_t* blocks_;
+  std::shared_ptr<Buffer> blocks_;

   // Array of hashes of values inserted into slots.
   // Undefined if the corresponding slot is empty.
   // There is 64B padding at the end.
-  uint32_t* hashes_;
+  std::shared_ptr<Buffer> hashes_;

   int64_t hardware_flags_;
   MemoryPool* pool_;
@@ -270,7 +272,7 @@ void SwissTable::insert_into_empty_slot(uint32_t slot_id, uint32_t hash,
   int stamp =
       static_cast<int>((hash >> (bits_hash_ - log_blocks_ - bits_stamp_)) & stamp_mask);
   uint64_t block_id = slot_id >> 3;
-  uint8_t* blockbase = blocks_ + num_block_bytes * block_id;
+  uint8_t* blockbase = blocks_->mutable_data() + num_block_bytes * block_id;

   blockbase[7 - start_slot] = static_cast<uint8_t>(stamp);
   int groupid_bit_offset = static_cast<int>(start_slot * num_groupid_bits);
9 changes: 5 additions & 4 deletions cpp/src/arrow/compute/key_map_avx2.cc
@@ -62,7 +62,8 @@ int SwissTable::early_filter_imp_avx2_x8(const int num_hashes, const uint32_t* h
   __m256i voffset_B = _mm256_srli_epi64(vblock_offset, 32);
   __m256i vstamp_B = _mm256_srli_epi64(vstamp, 32);

-  auto blocks_i64 = reinterpret_cast<arrow::util::int64_for_gather_t*>(blocks_);
+  auto blocks_i64 =
+      reinterpret_cast<arrow::util::int64_for_gather_t*>(blocks_->mutable_data());
   auto vblock_A = _mm256_i64gather_epi64(blocks_i64, voffset_A, 1);
   auto vblock_B = _mm256_i64gather_epi64(blocks_i64, voffset_B, 1);
   __m256i vblock_highbits_A =
@@ -234,7 +235,7 @@ int SwissTable::early_filter_imp_avx2_x32(const int num_hashes, const uint32_t*
   const int num_groupid_bits = num_groupid_bits_from_log_blocks(log_blocks_);
   for (int i = 0; i < (1 << log_blocks_); ++i) {
     uint64_t in_blockbytes =
-        *reinterpret_cast<const uint64_t*>(blocks_ + (8 + num_groupid_bits) * i);
+        *reinterpret_cast<const uint64_t*>(blocks_->data() + (8 + num_groupid_bits) * i);
     block_bytes[i] = in_blockbytes;
   }

@@ -375,12 +376,12 @@ int SwissTable::extract_group_ids_avx2(const int num_keys, const uint32_t* hashe
                                        int byte_multiplier, int byte_size) const {
   ARROW_DCHECK(byte_size == 1 || byte_size == 2 || byte_size == 4);
   uint32_t mask = byte_size == 1 ? 0xFF : byte_size == 2 ? 0xFFFF : 0xFFFFFFFF;
-  auto elements = reinterpret_cast<const int*>(blocks_ + byte_offset);
+  auto elements = reinterpret_cast<const int*>(blocks_->data() + byte_offset);
   constexpr int unroll = 8;
   if (log_blocks_ == 0) {
     ARROW_DCHECK(byte_size == 1 && byte_offset == 8 && byte_multiplier == 16);
     __m256i block_group_ids =
-        _mm256_set1_epi64x(reinterpret_cast<const uint64_t*>(blocks_)[1]);
+        _mm256_set1_epi64x(reinterpret_cast<const uint64_t*>(blocks_->data())[1]);
     for (int i = 0; i < num_keys / unroll; ++i) {
       __m256i local_slot =
           _mm256_set1_epi64x(reinterpret_cast<const uint64_t*>(local_slots)[i]);
