Skip to content

Commit

Permalink
ARROW-14211: [C++][Compute] Fixing thread sanitizer problems in hash …
Browse files Browse the repository at this point in the history
…join node

Fixing 3 issues:
- one in SchemaProjectionMaps - I simplified the code to get rid of thread synchronization altogether
- one in TaskScheduler - added an (unnecessary) mutex
- one in HashJoinImpl - switched from a shared byte vector to thread-local bit vectors that are merged afterwards (used for recording whether a match for a hash table row has been seen)

Closes #11350 from michalursa/ARROW-14211-hash-join-tsan

Authored-by: michalursa <michal@ursacomputing.com>
Signed-off-by: Sutou Kouhei <kou@clear-code.com>
  • Loading branch information
michalursa authored and kou committed Oct 19, 2021
1 parent 6e1293b commit d9ef519
Show file tree
Hide file tree
Showing 6 changed files with 160 additions and 89 deletions.
104 changes: 70 additions & 34 deletions cpp/src/arrow/compute/exec/hash_join.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ namespace compute {
using internal::RowEncoder;

class HashJoinBasicImpl : public HashJoinImpl {
private:
struct ThreadLocalState;

public:
Status InputReceived(size_t thread_index, int side, ExecBatch batch) override {
if (cancelled_) {
Expand Down Expand Up @@ -91,6 +94,7 @@ class HashJoinBasicImpl : public HashJoinImpl {
local_states_.resize(num_threads);
for (size_t i = 0; i < local_states_.size(); ++i) {
local_states_[i].is_initialized = false;
local_states_[i].is_has_match_initialized = false;
}

has_hash_table_ = false;
Expand Down Expand Up @@ -150,23 +154,26 @@ class HashJoinBasicImpl : public HashJoinImpl {
int num_cols = schema_mgr_->proj_maps[side].num_cols(projection_handle);
projected.values.resize(num_cols);

const int* to_input =
auto to_input =
schema_mgr_->proj_maps[side].map(projection_handle, HashJoinProjection::INPUT);
for (int icol = 0; icol < num_cols; ++icol) {
projected.values[icol] = batch.values[to_input[icol]];
projected.values[icol] = batch.values[to_input.get(icol)];
}

return encoder->EncodeAndAppend(projected);
}

void ProbeBatch_Lookup(const RowEncoder& exec_batch_keys,
void ProbeBatch_Lookup(ThreadLocalState* local_state, const RowEncoder& exec_batch_keys,
const std::vector<const uint8_t*>& non_null_bit_vectors,
const std::vector<int64_t>& non_null_bit_vector_offsets,
std::vector<int32_t>* output_match,
std::vector<int32_t>* output_no_match,
std::vector<int32_t>* output_match_left,
std::vector<int32_t>* output_match_right) {
ARROW_DCHECK(has_hash_table_);

InitHasMatchIfNeeded(local_state);

int num_cols = static_cast<int>(non_null_bit_vectors.size());
for (int32_t irow = 0; irow < exec_batch_keys.num_rows(); ++irow) {
// Apply null key filtering
Expand All @@ -191,7 +198,8 @@ class HashJoinBasicImpl : public HashJoinImpl {
for (auto it = range.first; it != range.second; ++it) {
output_match_left->push_back(irow);
output_match_right->push_back(it->second);
has_match_[it->second] = 0xFF;
// Mark row in hash table as having a match
BitUtil::SetBit(local_state->has_match.data(), it->second);
has_match = true;
}
if (!has_match) {
Expand All @@ -215,46 +223,47 @@ class HashJoinBasicImpl : public HashJoinImpl {
ARROW_DCHECK((opt_right_payload == nullptr) ==
(schema_mgr_->proj_maps[1].num_cols(HashJoinProjection::PAYLOAD) == 0));
result.values.resize(num_out_cols_left + num_out_cols_right);
const int* from_key = schema_mgr_->proj_maps[0].map(HashJoinProjection::OUTPUT,
HashJoinProjection::KEY);
const int* from_payload = schema_mgr_->proj_maps[0].map(HashJoinProjection::OUTPUT,
HashJoinProjection::PAYLOAD);
auto from_key = schema_mgr_->proj_maps[0].map(HashJoinProjection::OUTPUT,
HashJoinProjection::KEY);
auto from_payload = schema_mgr_->proj_maps[0].map(HashJoinProjection::OUTPUT,
HashJoinProjection::PAYLOAD);
for (int icol = 0; icol < num_out_cols_left; ++icol) {
bool is_from_key = (from_key[icol] != HashJoinSchema::kMissingField());
bool is_from_payload = (from_payload[icol] != HashJoinSchema::kMissingField());
bool is_from_key = (from_key.get(icol) != HashJoinSchema::kMissingField());
bool is_from_payload = (from_payload.get(icol) != HashJoinSchema::kMissingField());
ARROW_DCHECK(is_from_key != is_from_payload);
ARROW_DCHECK(!is_from_key ||
(opt_left_key &&
from_key[icol] < static_cast<int>(opt_left_key->values.size()) &&
from_key.get(icol) < static_cast<int>(opt_left_key->values.size()) &&
opt_left_key->length == batch_size_next));
ARROW_DCHECK(
!is_from_payload ||
(opt_left_payload &&
from_payload[icol] < static_cast<int>(opt_left_payload->values.size()) &&
from_payload.get(icol) < static_cast<int>(opt_left_payload->values.size()) &&
opt_left_payload->length == batch_size_next));
result.values[icol] = is_from_key ? opt_left_key->values[from_key[icol]]
: opt_left_payload->values[from_payload[icol]];
result.values[icol] = is_from_key
? opt_left_key->values[from_key.get(icol)]
: opt_left_payload->values[from_payload.get(icol)];
}
from_key = schema_mgr_->proj_maps[1].map(HashJoinProjection::OUTPUT,
HashJoinProjection::KEY);
from_payload = schema_mgr_->proj_maps[1].map(HashJoinProjection::OUTPUT,
HashJoinProjection::PAYLOAD);
for (int icol = 0; icol < num_out_cols_right; ++icol) {
bool is_from_key = (from_key[icol] != HashJoinSchema::kMissingField());
bool is_from_payload = (from_payload[icol] != HashJoinSchema::kMissingField());
bool is_from_key = (from_key.get(icol) != HashJoinSchema::kMissingField());
bool is_from_payload = (from_payload.get(icol) != HashJoinSchema::kMissingField());
ARROW_DCHECK(is_from_key != is_from_payload);
ARROW_DCHECK(!is_from_key ||
(opt_right_key &&
from_key[icol] < static_cast<int>(opt_right_key->values.size()) &&
from_key.get(icol) < static_cast<int>(opt_right_key->values.size()) &&
opt_right_key->length == batch_size_next));
ARROW_DCHECK(
!is_from_payload ||
(opt_right_payload &&
from_payload[icol] < static_cast<int>(opt_right_payload->values.size()) &&
from_payload.get(icol) < static_cast<int>(opt_right_payload->values.size()) &&
opt_right_payload->length == batch_size_next));
result.values[num_out_cols_left + icol] =
is_from_key ? opt_right_key->values[from_key[icol]]
: opt_right_payload->values[from_payload[icol]];
is_from_key ? opt_right_key->values[from_key.get(icol)]
: opt_right_payload->values[from_payload.get(icol)];
}

output_batch_callback_(std::move(result));
Expand Down Expand Up @@ -384,10 +393,10 @@ class HashJoinBasicImpl : public HashJoinImpl {
int num_key_cols = schema_mgr_->proj_maps[0].num_cols(HashJoinProjection::KEY);
non_null_bit_vectors.resize(num_key_cols);
non_null_bit_vector_offsets.resize(num_key_cols);
const int* from_batch =
auto from_batch =
schema_mgr_->proj_maps[0].map(HashJoinProjection::KEY, HashJoinProjection::INPUT);
for (int i = 0; i < num_key_cols; ++i) {
int input_col_id = from_batch[i];
int input_col_id = from_batch.get(i);
const uint8_t* non_nulls = nullptr;
int64_t offset = 0;
if (batch[input_col_id].array()->buffers[0] != NULLPTR) {
Expand All @@ -398,7 +407,7 @@ class HashJoinBasicImpl : public HashJoinImpl {
non_null_bit_vector_offsets[i] = offset;
}

ProbeBatch_Lookup(local_state.exec_batch_keys, non_null_bit_vectors,
ProbeBatch_Lookup(&local_state, local_state.exec_batch_keys, non_null_bit_vectors,
non_null_bit_vector_offsets, &local_state.match,
&local_state.no_match, &local_state.match_left,
&local_state.match_right);
Expand Down Expand Up @@ -446,11 +455,6 @@ class HashJoinBasicImpl : public HashJoinImpl {
hash_table_.insert(std::make_pair(hash_table_keys_.encoded_row(irow), irow));
}
}
if (!hash_table_empty_) {
int32_t num_rows = hash_table_keys_.num_rows();
has_match_.resize(num_rows);
memset(has_match_.data(), 0, num_rows);
}
}
return Status::OK();
}
Expand Down Expand Up @@ -563,9 +567,9 @@ class HashJoinBasicImpl : public HashJoinImpl {
id_right.clear();
bool use_left = false;

uint8_t match_search_value = (join_type_ == JoinType::RIGHT_SEMI) ? 0xFF : 0x00;
bool match_search_value = (join_type_ == JoinType::RIGHT_SEMI);
for (int32_t row_id = start_row_id; row_id < end_row_id; ++row_id) {
if (has_match_[row_id] == match_search_value) {
if (BitUtil::GetBit(has_match_.data(), row_id) == match_search_value) {
id_right.push_back(row_id);
}
}
Expand Down Expand Up @@ -607,16 +611,13 @@ class HashJoinBasicImpl : public HashJoinImpl {
}

// Starts the hash table scan phase.
//
// First folds the per-thread match bits into the shared has_match_ bit vector
// (see MergeHasMatch), then asks the scheduler to start the scan task group
// with ScanHashTable_num_tasks() tasks.
//
// \param thread_index index of the calling thread, forwarded to the scheduler
// \return the Status returned by scheduler_->StartTaskGroup
Status ScanHashTable(size_t thread_index) {
// The scan tasks read has_match_, so it has to be populated before they start.
MergeHasMatch();
return scheduler_->StartTaskGroup(thread_index, task_group_scan_,
ScanHashTable_num_tasks());
}

bool QueueBatchIfNeeded(int side, ExecBatch batch) {
if (side == 0) {
if (has_hash_table_) {
return false;
}

std::lock_guard<std::mutex> lock(left_batches_mutex_);
if (has_hash_table_) {
return false;
Expand All @@ -636,6 +637,39 @@ class HashJoinBasicImpl : public HashJoinImpl {
return ScanHashTable(thread_index);
}

// Lazily prepares the has_match bit vector stored in the given thread-local state.
//
// Does nothing when local_state->is_has_match_initialized is already set.
// Otherwise, if the hash table is not empty, has_match is sized to hold one bit
// per row of hash_table_keys_ and every byte is cleared. The flag is set in
// either case, so later calls return immediately.
//
// \param local_state thread-local state whose has_match vector is initialized
void InitHasMatchIfNeeded(ThreadLocalState* local_state) {
if (local_state->is_has_match_initialized) {
return;
}
if (!hash_table_empty_) {
int32_t num_rows = hash_table_keys_.num_rows();
// One bit per hash table row, rounded up to whole bytes.
local_state->has_match.resize(BitUtil::BytesForBits(num_rows));
// Start with all bits cleared, i.e. no row has been matched yet.
memset(local_state->has_match.data(), 0, BitUtil::BytesForBits(num_rows));
}
// Marked as initialized even for an empty hash table; has_match is left untouched then.
local_state->is_has_match_initialized = true;
}

// Combines the has_match bit vectors of all local states into the shared
// has_match_ bit vector.
//
// Returns immediately when the hash table is empty. Otherwise has_match_ is
// resized to one bit per row of hash_table_keys_, cleared, and then OR-ed with
// the has_match bits of every local state that has both is_initialized and
// is_has_match_initialized set.
void MergeHasMatch() {
if (hash_table_empty_) {
return;
}

int32_t num_rows = hash_table_keys_.num_rows();
// One bit per hash table row, rounded up to whole bytes, all initially cleared.
has_match_.resize(BitUtil::BytesForBits(num_rows));
memset(has_match_.data(), 0, BitUtil::BytesForBits(num_rows));

for (size_t tid = 0; tid < local_states_.size(); ++tid) {
// Skip local states that were never set up.
if (!local_states_[tid].is_initialized) {
continue;
}
// Skip local states whose has_match vector was never prepared by
// InitHasMatchIfNeeded.
if (!local_states_[tid].is_has_match_initialized) {
continue;
}
// has_match_ is both an input and the output here, so the bits accumulate
// in place across iterations.
arrow::internal::BitmapOr(has_match_.data(), 0, local_states_[tid].has_match.data(),
0, num_rows, 0, has_match_.data());
}
}

static constexpr int64_t hash_table_scan_unit_ = 32 * 1024;
static constexpr int64_t output_batch_size_ = 32 * 1024;

Expand Down Expand Up @@ -666,6 +700,8 @@ class HashJoinBasicImpl : public HashJoinImpl {
std::vector<int32_t> no_match;
std::vector<int32_t> match_left;
std::vector<int32_t> match_right;
bool is_has_match_initialized;
std::vector<uint8_t> has_match;
};
std::vector<ThreadLocalState> local_states_;

Expand Down
6 changes: 3 additions & 3 deletions cpp/src/arrow/compute/exec/hash_join_node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -253,9 +253,9 @@ std::shared_ptr<Schema> HashJoinSchema::MakeOutputSchema(
for (int i = 0; i < left_size + right_size; ++i) {
bool is_left = (i < left_size);
int side = (is_left ? 0 : 1);
int input_field_id =
proj_maps[side].map(HashJoinProjection::OUTPUT,
HashJoinProjection::INPUT)[is_left ? i : i - left_size];
int input_field_id = proj_maps[side]
.map(HashJoinProjection::OUTPUT, HashJoinProjection::INPUT)
.get(is_left ? i : i - left_size);
const std::string& input_field_name =
proj_maps[side].field_name(HashJoinProjection::INPUT, input_field_id);
const std::shared_ptr<DataType>& input_data_type =
Expand Down
33 changes: 22 additions & 11 deletions cpp/src/arrow/compute/exec/hash_join_node_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -921,24 +921,31 @@ void HashJoinWithExecPlan(Random64Bit& rng, bool parallel,

ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(exec_ctx.get()));

Declaration join{"hashjoin", join_options};

// add left source
BatchesWithSchema l_batches = TableToBatches(rng, num_batches_l, l, "l_");
join.inputs.emplace_back(Declaration{
"source", SourceNodeOptions{l_batches.schema, l_batches.gen(parallel,
/*slow=*/false)}});
ASSERT_OK_AND_ASSIGN(
ExecNode * l_source,
MakeExecNode("source", plan.get(), {},
SourceNodeOptions{l_batches.schema, l_batches.gen(parallel,
/*slow=*/false)}));

// add right source
BatchesWithSchema r_batches = TableToBatches(rng, num_batches_r, r, "r_");
join.inputs.emplace_back(Declaration{
"source", SourceNodeOptions{r_batches.schema, r_batches.gen(parallel,
/*slow=*/false)}});
AsyncGenerator<util::optional<ExecBatch>> sink_gen;
ASSERT_OK_AND_ASSIGN(
ExecNode * r_source,
MakeExecNode("source", plan.get(), {},
SourceNodeOptions{r_batches.schema, r_batches.gen(parallel,
/*slow=*/false)}));

ASSERT_OK(Declaration::Sequence({join, {"sink", SinkNodeOptions{&sink_gen}}})
.AddToPlan(plan.get()));
ASSERT_OK_AND_ASSIGN(ExecNode * join, MakeExecNode("hashjoin", plan.get(),
{l_source, r_source}, join_options));

AsyncGenerator<util::optional<ExecBatch>> sink_gen;
ASSERT_OK_AND_ASSIGN(
std::ignore, MakeExecNode("sink", plan.get(), {join}, SinkNodeOptions{&sink_gen}));

ASSERT_FINISHES_OK_AND_ASSIGN(auto res, StartAndCollect(plan.get(), sink_gen));

ASSERT_OK_AND_ASSIGN(*output, TableFromExecBatches(output_schema, res));
}

Expand Down Expand Up @@ -1056,6 +1063,10 @@ TEST(HashJoin, Random) {
// print num_rows, batch_size, join_type, join_cmp
std::cout << join_type_name << " " << key_cmp_str << " ";
key_types.Print();
std::cout << " payload_l: ";
payload_types[0].Print();
std::cout << " payload_r: ";
payload_types[1].Print();
std::cout << " num_rows_l = " << num_rows_l << " num_rows_r = " << num_rows_r
<< " batch size = " << batch_size
<< " parallel = " << (parallel ? "true" : "false");
Expand Down
Loading

0 comments on commit d9ef519

Please sign in to comment.