Skip to content

Commit

Permalink
GH-39778: [C++] Fix tail-byte access cross buffer boundary in key has…
Browse files Browse the repository at this point in the history
…h avx2 (#39800)

### Rationale for this change

Issue #39778 seems caused by a careless (but hard to spot) bug in key hash avx2.

### What changes are included in this PR?

Fix the careless bug.

### Are these changes tested?

UT included.

### Are there any user-facing changes?

No.

* Closes: #39778

Authored-by: Ruoxi Sun <zanmato1984@gmail.com>
Signed-off-by: Antoine Pitrou <antoine@python.org>
  • Loading branch information
zanmato1984 committed Jan 26, 2024
1 parent 667e917 commit 13b2234
Show file tree
Hide file tree
Showing 4 changed files with 145 additions and 80 deletions.
142 changes: 74 additions & 68 deletions cpp/src/arrow/compute/key_hash.cc
Original file line number Diff line number Diff line change
Expand Up @@ -105,23 +105,23 @@ inline void Hashing32::StripeMask(int i, uint32_t* mask1, uint32_t* mask2,
}

template <bool T_COMBINE_HASHES>
void Hashing32::HashFixedLenImp(uint32_t num_rows, uint64_t length, const uint8_t* keys,
uint32_t* hashes) {
void Hashing32::HashFixedLenImp(uint32_t num_rows, uint64_t key_length,
const uint8_t* keys, uint32_t* hashes) {
// Calculate the number of rows that skip the last 16 bytes
//
uint32_t num_rows_safe = num_rows;
while (num_rows_safe > 0 && (num_rows - num_rows_safe) * length < kStripeSize) {
while (num_rows_safe > 0 && (num_rows - num_rows_safe) * key_length < kStripeSize) {
--num_rows_safe;
}

// Compute masks for the last 16 byte stripe
//
uint64_t num_stripes = bit_util::CeilDiv(length, kStripeSize);
uint64_t num_stripes = bit_util::CeilDiv(key_length, kStripeSize);
uint32_t mask1, mask2, mask3, mask4;
StripeMask(((length - 1) & (kStripeSize - 1)) + 1, &mask1, &mask2, &mask3, &mask4);
StripeMask(((key_length - 1) & (kStripeSize - 1)) + 1, &mask1, &mask2, &mask3, &mask4);

for (uint32_t i = 0; i < num_rows_safe; ++i) {
const uint8_t* key = keys + static_cast<uint64_t>(i) * length;
const uint8_t* key = keys + static_cast<uint64_t>(i) * key_length;
uint32_t acc1, acc2, acc3, acc4;
ProcessFullStripes(num_stripes, key, &acc1, &acc2, &acc3, &acc4);
ProcessLastStripe(mask1, mask2, mask3, mask4, key + (num_stripes - 1) * kStripeSize,
Expand All @@ -138,11 +138,11 @@ void Hashing32::HashFixedLenImp(uint32_t num_rows, uint64_t length, const uint8_

uint32_t last_stripe_copy[4];
for (uint32_t i = num_rows_safe; i < num_rows; ++i) {
const uint8_t* key = keys + static_cast<uint64_t>(i) * length;
const uint8_t* key = keys + static_cast<uint64_t>(i) * key_length;
uint32_t acc1, acc2, acc3, acc4;
ProcessFullStripes(num_stripes, key, &acc1, &acc2, &acc3, &acc4);
memcpy(last_stripe_copy, key + (num_stripes - 1) * kStripeSize,
length - (num_stripes - 1) * kStripeSize);
key_length - (num_stripes - 1) * kStripeSize);
ProcessLastStripe(mask1, mask2, mask3, mask4,
reinterpret_cast<const uint8_t*>(last_stripe_copy), &acc1, &acc2,
&acc3, &acc4);
Expand All @@ -168,15 +168,16 @@ void Hashing32::HashVarLenImp(uint32_t num_rows, const T* offsets,
}

for (uint32_t i = 0; i < num_rows_safe; ++i) {
uint64_t length = offsets[i + 1] - offsets[i];
uint64_t key_length = offsets[i + 1] - offsets[i];

// Compute masks for the last 16 byte stripe.
// For an empty string set number of stripes to 1 but mask to all zeroes.
//
int is_non_empty = length == 0 ? 0 : 1;
uint64_t num_stripes = bit_util::CeilDiv(length, kStripeSize) + (1 - is_non_empty);
int is_non_empty = key_length == 0 ? 0 : 1;
uint64_t num_stripes =
bit_util::CeilDiv(key_length, kStripeSize) + (1 - is_non_empty);
uint32_t mask1, mask2, mask3, mask4;
StripeMask(((length - is_non_empty) & (kStripeSize - 1)) + is_non_empty, &mask1,
StripeMask(((key_length - is_non_empty) & (kStripeSize - 1)) + is_non_empty, &mask1,
&mask2, &mask3, &mask4);

const uint8_t* key = concatenated_keys + offsets[i];
Expand All @@ -198,23 +199,24 @@ void Hashing32::HashVarLenImp(uint32_t num_rows, const T* offsets,

uint32_t last_stripe_copy[4];
for (uint32_t i = num_rows_safe; i < num_rows; ++i) {
uint64_t length = offsets[i + 1] - offsets[i];
uint64_t key_length = offsets[i + 1] - offsets[i];

// Compute masks for the last 16 byte stripe.
// For an empty string set number of stripes to 1 but mask to all zeroes.
//
int is_non_empty = length == 0 ? 0 : 1;
uint64_t num_stripes = bit_util::CeilDiv(length, kStripeSize) + (1 - is_non_empty);
int is_non_empty = key_length == 0 ? 0 : 1;
uint64_t num_stripes =
bit_util::CeilDiv(key_length, kStripeSize) + (1 - is_non_empty);
uint32_t mask1, mask2, mask3, mask4;
StripeMask(((length - is_non_empty) & (kStripeSize - 1)) + is_non_empty, &mask1,
StripeMask(((key_length - is_non_empty) & (kStripeSize - 1)) + is_non_empty, &mask1,
&mask2, &mask3, &mask4);

const uint8_t* key = concatenated_keys + offsets[i];
uint32_t acc1, acc2, acc3, acc4;
ProcessFullStripes(num_stripes, key, &acc1, &acc2, &acc3, &acc4);
if (length > 0) {
if (key_length > 0) {
memcpy(last_stripe_copy, key + (num_stripes - 1) * kStripeSize,
length - (num_stripes - 1) * kStripeSize);
key_length - (num_stripes - 1) * kStripeSize);
}
if (num_stripes > 0) {
ProcessLastStripe(mask1, mask2, mask3, mask4,
Expand Down Expand Up @@ -309,9 +311,9 @@ void Hashing32::HashIntImp(uint32_t num_keys, const T* keys, uint32_t* hashes) {
}
}

void Hashing32::HashInt(bool combine_hashes, uint32_t num_keys, uint64_t length_key,
void Hashing32::HashInt(bool combine_hashes, uint32_t num_keys, uint64_t key_length,
const uint8_t* keys, uint32_t* hashes) {
switch (length_key) {
switch (key_length) {
case sizeof(uint8_t):
if (combine_hashes) {
HashIntImp<true, uint8_t>(num_keys, keys, hashes);
Expand Down Expand Up @@ -352,27 +354,27 @@ void Hashing32::HashInt(bool combine_hashes, uint32_t num_keys, uint64_t length_
}
}

void Hashing32::HashFixed(int64_t hardware_flags, bool combine_hashes, uint32_t num_rows,
uint64_t length, const uint8_t* keys, uint32_t* hashes,
uint32_t* hashes_temp_for_combine) {
if (ARROW_POPCOUNT64(length) == 1 && length <= sizeof(uint64_t)) {
HashInt(combine_hashes, num_rows, length, keys, hashes);
void Hashing32::HashFixed(int64_t hardware_flags, bool combine_hashes, uint32_t num_keys,
uint64_t key_length, const uint8_t* keys, uint32_t* hashes,
uint32_t* temp_hashes_for_combine) {
if (ARROW_POPCOUNT64(key_length) == 1 && key_length <= sizeof(uint64_t)) {
HashInt(combine_hashes, num_keys, key_length, keys, hashes);
return;
}

uint32_t num_processed = 0;
#if defined(ARROW_HAVE_RUNTIME_AVX2)
if (hardware_flags & arrow::internal::CpuInfo::AVX2) {
num_processed = HashFixedLen_avx2(combine_hashes, num_rows, length, keys, hashes,
hashes_temp_for_combine);
num_processed = HashFixedLen_avx2(combine_hashes, num_keys, key_length, keys, hashes,
temp_hashes_for_combine);
}
#endif
if (combine_hashes) {
HashFixedLenImp<true>(num_rows - num_processed, length, keys + length * num_processed,
hashes + num_processed);
HashFixedLenImp<true>(num_keys - num_processed, key_length,
keys + key_length * num_processed, hashes + num_processed);
} else {
HashFixedLenImp<false>(num_rows - num_processed, length,
keys + length * num_processed, hashes + num_processed);
HashFixedLenImp<false>(num_keys - num_processed, key_length,
keys + key_length * num_processed, hashes + num_processed);
}
}

Expand Down Expand Up @@ -423,13 +425,13 @@ void Hashing32::HashMultiColumn(const std::vector<KeyColumnArray>& cols,
}

if (cols[icol].metadata().is_fixed_length) {
uint32_t col_width = cols[icol].metadata().fixed_length;
if (col_width == 0) {
uint32_t key_length = cols[icol].metadata().fixed_length;
if (key_length == 0) {
HashBit(icol > 0, cols[icol].bit_offset(1), batch_size_next,
cols[icol].data(1) + first_row / 8, hashes + first_row);
} else {
HashFixed(ctx->hardware_flags, icol > 0, batch_size_next, col_width,
cols[icol].data(1) + first_row * col_width, hashes + first_row,
HashFixed(ctx->hardware_flags, icol > 0, batch_size_next, key_length,
cols[icol].data(1) + first_row * key_length, hashes + first_row,
hash_temp);
}
} else if (cols[icol].metadata().fixed_length == sizeof(uint32_t)) {
Expand Down Expand Up @@ -463,8 +465,9 @@ void Hashing32::HashMultiColumn(const std::vector<KeyColumnArray>& cols,
Status Hashing32::HashBatch(const ExecBatch& key_batch, uint32_t* hashes,
std::vector<KeyColumnArray>& column_arrays,
int64_t hardware_flags, util::TempVectorStack* temp_stack,
int64_t offset, int64_t length) {
RETURN_NOT_OK(ColumnArraysFromExecBatch(key_batch, offset, length, &column_arrays));
int64_t start_rows, int64_t num_rows) {
RETURN_NOT_OK(
ColumnArraysFromExecBatch(key_batch, start_rows, num_rows, &column_arrays));

LightContext ctx;
ctx.hardware_flags = hardware_flags;
Expand Down Expand Up @@ -574,23 +577,23 @@ inline void Hashing64::StripeMask(int i, uint64_t* mask1, uint64_t* mask2,
}

template <bool T_COMBINE_HASHES>
void Hashing64::HashFixedLenImp(uint32_t num_rows, uint64_t length, const uint8_t* keys,
uint64_t* hashes) {
void Hashing64::HashFixedLenImp(uint32_t num_rows, uint64_t key_length,
const uint8_t* keys, uint64_t* hashes) {
// Calculate the number of rows that skip the last 32 bytes
//
uint32_t num_rows_safe = num_rows;
while (num_rows_safe > 0 && (num_rows - num_rows_safe) * length < kStripeSize) {
while (num_rows_safe > 0 && (num_rows - num_rows_safe) * key_length < kStripeSize) {
--num_rows_safe;
}

// Compute masks for the last 32 byte stripe
//
uint64_t num_stripes = bit_util::CeilDiv(length, kStripeSize);
uint64_t num_stripes = bit_util::CeilDiv(key_length, kStripeSize);
uint64_t mask1, mask2, mask3, mask4;
StripeMask(((length - 1) & (kStripeSize - 1)) + 1, &mask1, &mask2, &mask3, &mask4);
StripeMask(((key_length - 1) & (kStripeSize - 1)) + 1, &mask1, &mask2, &mask3, &mask4);

for (uint32_t i = 0; i < num_rows_safe; ++i) {
const uint8_t* key = keys + static_cast<uint64_t>(i) * length;
const uint8_t* key = keys + static_cast<uint64_t>(i) * key_length;
uint64_t acc1, acc2, acc3, acc4;
ProcessFullStripes(num_stripes, key, &acc1, &acc2, &acc3, &acc4);
ProcessLastStripe(mask1, mask2, mask3, mask4, key + (num_stripes - 1) * kStripeSize,
Expand All @@ -607,11 +610,11 @@ void Hashing64::HashFixedLenImp(uint32_t num_rows, uint64_t length, const uint8_

uint64_t last_stripe_copy[4];
for (uint32_t i = num_rows_safe; i < num_rows; ++i) {
const uint8_t* key = keys + static_cast<uint64_t>(i) * length;
const uint8_t* key = keys + static_cast<uint64_t>(i) * key_length;
uint64_t acc1, acc2, acc3, acc4;
ProcessFullStripes(num_stripes, key, &acc1, &acc2, &acc3, &acc4);
memcpy(last_stripe_copy, key + (num_stripes - 1) * kStripeSize,
length - (num_stripes - 1) * kStripeSize);
key_length - (num_stripes - 1) * kStripeSize);
ProcessLastStripe(mask1, mask2, mask3, mask4,
reinterpret_cast<const uint8_t*>(last_stripe_copy), &acc1, &acc2,
&acc3, &acc4);
Expand All @@ -637,15 +640,16 @@ void Hashing64::HashVarLenImp(uint32_t num_rows, const T* offsets,
}

for (uint32_t i = 0; i < num_rows_safe; ++i) {
uint64_t length = offsets[i + 1] - offsets[i];
uint64_t key_length = offsets[i + 1] - offsets[i];

// Compute masks for the last 32 byte stripe.
// For an empty string set number of stripes to 1 but mask to all zeroes.
//
int is_non_empty = length == 0 ? 0 : 1;
uint64_t num_stripes = bit_util::CeilDiv(length, kStripeSize) + (1 - is_non_empty);
int is_non_empty = key_length == 0 ? 0 : 1;
uint64_t num_stripes =
bit_util::CeilDiv(key_length, kStripeSize) + (1 - is_non_empty);
uint64_t mask1, mask2, mask3, mask4;
StripeMask(((length - is_non_empty) & (kStripeSize - 1)) + is_non_empty, &mask1,
StripeMask(((key_length - is_non_empty) & (kStripeSize - 1)) + is_non_empty, &mask1,
&mask2, &mask3, &mask4);

const uint8_t* key = concatenated_keys + offsets[i];
Expand All @@ -667,22 +671,23 @@ void Hashing64::HashVarLenImp(uint32_t num_rows, const T* offsets,

uint64_t last_stripe_copy[4];
for (uint32_t i = num_rows_safe; i < num_rows; ++i) {
uint64_t length = offsets[i + 1] - offsets[i];
uint64_t key_length = offsets[i + 1] - offsets[i];

// Compute masks for the last 32 byte stripe
//
int is_non_empty = length == 0 ? 0 : 1;
uint64_t num_stripes = bit_util::CeilDiv(length, kStripeSize) + (1 - is_non_empty);
int is_non_empty = key_length == 0 ? 0 : 1;
uint64_t num_stripes =
bit_util::CeilDiv(key_length, kStripeSize) + (1 - is_non_empty);
uint64_t mask1, mask2, mask3, mask4;
StripeMask(((length - is_non_empty) & (kStripeSize - 1)) + is_non_empty, &mask1,
StripeMask(((key_length - is_non_empty) & (kStripeSize - 1)) + is_non_empty, &mask1,
&mask2, &mask3, &mask4);

const uint8_t* key = concatenated_keys + offsets[i];
uint64_t acc1, acc2, acc3, acc4;
ProcessFullStripes(num_stripes, key, &acc1, &acc2, &acc3, &acc4);
if (length > 0) {
if (key_length > 0) {
memcpy(last_stripe_copy, key + (num_stripes - 1) * kStripeSize,
length - (num_stripes - 1) * kStripeSize);
key_length - (num_stripes - 1) * kStripeSize);
}
if (num_stripes > 0) {
ProcessLastStripe(mask1, mask2, mask3, mask4,
Expand Down Expand Up @@ -759,9 +764,9 @@ void Hashing64::HashIntImp(uint32_t num_keys, const T* keys, uint64_t* hashes) {
}
}

void Hashing64::HashInt(bool combine_hashes, uint32_t num_keys, uint64_t length_key,
void Hashing64::HashInt(bool combine_hashes, uint32_t num_keys, uint64_t key_length,
const uint8_t* keys, uint64_t* hashes) {
switch (length_key) {
switch (key_length) {
case sizeof(uint8_t):
if (combine_hashes) {
HashIntImp<true, uint8_t>(num_keys, keys, hashes);
Expand Down Expand Up @@ -802,17 +807,17 @@ void Hashing64::HashInt(bool combine_hashes, uint32_t num_keys, uint64_t length_
}
}

void Hashing64::HashFixed(bool combine_hashes, uint32_t num_rows, uint64_t length,
void Hashing64::HashFixed(bool combine_hashes, uint32_t num_keys, uint64_t key_length,
const uint8_t* keys, uint64_t* hashes) {
if (ARROW_POPCOUNT64(length) == 1 && length <= sizeof(uint64_t)) {
HashInt(combine_hashes, num_rows, length, keys, hashes);
if (ARROW_POPCOUNT64(key_length) == 1 && key_length <= sizeof(uint64_t)) {
HashInt(combine_hashes, num_keys, key_length, keys, hashes);
return;
}

if (combine_hashes) {
HashFixedLenImp<true>(num_rows, length, keys, hashes);
HashFixedLenImp<true>(num_keys, key_length, keys, hashes);
} else {
HashFixedLenImp<false>(num_rows, length, keys, hashes);
HashFixedLenImp<false>(num_keys, key_length, keys, hashes);
}
}

Expand Down Expand Up @@ -860,13 +865,13 @@ void Hashing64::HashMultiColumn(const std::vector<KeyColumnArray>& cols,
}

if (cols[icol].metadata().is_fixed_length) {
uint64_t col_width = cols[icol].metadata().fixed_length;
if (col_width == 0) {
uint64_t key_length = cols[icol].metadata().fixed_length;
if (key_length == 0) {
HashBit(icol > 0, cols[icol].bit_offset(1), batch_size_next,
cols[icol].data(1) + first_row / 8, hashes + first_row);
} else {
HashFixed(icol > 0, batch_size_next, col_width,
cols[icol].data(1) + first_row * col_width, hashes + first_row);
HashFixed(icol > 0, batch_size_next, key_length,
cols[icol].data(1) + first_row * key_length, hashes + first_row);
}
} else if (cols[icol].metadata().fixed_length == sizeof(uint32_t)) {
HashVarLen(icol > 0, batch_size_next, cols[icol].offsets() + first_row,
Expand Down Expand Up @@ -897,8 +902,9 @@ void Hashing64::HashMultiColumn(const std::vector<KeyColumnArray>& cols,
Status Hashing64::HashBatch(const ExecBatch& key_batch, uint64_t* hashes,
std::vector<KeyColumnArray>& column_arrays,
int64_t hardware_flags, util::TempVectorStack* temp_stack,
int64_t offset, int64_t length) {
RETURN_NOT_OK(ColumnArraysFromExecBatch(key_batch, offset, length, &column_arrays));
int64_t start_row, int64_t num_rows) {
RETURN_NOT_OK(
ColumnArraysFromExecBatch(key_batch, start_row, num_rows, &column_arrays));

LightContext ctx;
ctx.hardware_flags = hardware_flags;
Expand Down
Loading

0 comments on commit 13b2234

Please sign in to comment.