Skip to content

Commit

Permalink
[WIP][VL] Support celeborn sort based shuffle
Browse files Browse the repository at this point in the history
  • Loading branch information
kerwin-zk committed May 10, 2024
1 parent 24d949e commit 34e4b51
Show file tree
Hide file tree
Showing 36 changed files with 739 additions and 53 deletions.
18 changes: 16 additions & 2 deletions cpp/core/jni/JniCommon.h
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,20 @@ static inline arrow::Compression::type getCompressionType(JNIEnv* env, jstring c
return compressionType;
}

/// Returns the codec name carried by `codecJstr`, converted to lowercase.
///
/// @param env       JNI environment used to read the Java string.
/// @param codecJstr Java string holding the codec name; may be null.
/// @return "none" when `codecJstr` is null, otherwise the lower-cased codec name.
static inline const std::string getCompressionTypeStr(JNIEnv* env, jstring codecJstr) {
  if (codecJstr == nullptr) {
    return "none";
  }
  // The second argument is the optional `isCopy` out-pointer; we do not need it.
  const char* codec = env->GetStringUTFChars(codecJstr, nullptr);

  // Copy first so the JNI buffer can be released before any further work is done.
  std::string codecLower(codec);
  env->ReleaseStringUTFChars(codecJstr, codec);

  // Convert codec string into lowercase. The character is widened through `unsigned char`
  // because passing a negative `char` (any non-ASCII byte on signed-char platforms) to
  // tolower() is undefined behaviour.
  std::transform(codecLower.begin(), codecLower.end(), codecLower.begin(), [](unsigned char c) {
    return static_cast<char>(::tolower(c));
  });
  return codecLower;
}

static inline gluten::CodecBackend getCodecBackend(JNIEnv* env, jstring codecJstr) {
if (codecJstr == nullptr) {
return gluten::CodecBackend::NONE;
Expand Down Expand Up @@ -444,7 +458,7 @@ class JavaRssClient : public RssClient {
env->DeleteGlobalRef(array_);
}

int32_t pushPartitionData(int32_t partitionId, char* bytes, int64_t size) override {
int32_t pushPartitionData(int32_t partitionId, const char* bytes, int64_t size) override {
JNIEnv* env;
if (vm_->GetEnv(reinterpret_cast<void**>(&env), jniVersion) != JNI_OK) {
throw gluten::GlutenException("JNIEnv was not attached to current thread");
Expand All @@ -457,7 +471,7 @@ class JavaRssClient : public RssClient {
array_ = env->NewByteArray(size);
array_ = static_cast<jbyteArray>(env->NewGlobalRef(array_));
}
env->SetByteArrayRegion(array_, 0, size, reinterpret_cast<jbyte*>(bytes));
env->SetByteArrayRegion(array_, 0, size, (jbyte*)bytes);
jint javaBytesSize = env->CallIntMethod(javaRssShuffleWriter_, javaPushPartitionData_, partitionId, array_, size);
checkException(env);
return static_cast<int32_t>(javaBytesSize);
Expand Down
110 changes: 104 additions & 6 deletions cpp/core/jni/JniWrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include <optional>
#include "memory/AllocationListener.h"
#include "operators/serializer/ColumnarBatchSerializer.h"
#include "shuffle/JavaInputStreamWrapper.h"
#include "shuffle/LocalPartitionWriter.h"
#include "shuffle/Partitioning.h"
#include "shuffle/ShuffleReader.h"
Expand Down Expand Up @@ -148,6 +149,69 @@ class JavaInputStreamAdaptor final : public arrow::io::InputStream {
bool closed_ = false;
};

class JavaInputStreamVeloxWrapper final : public JavaInputStreamWrapper {
public:
JavaInputStreamVeloxWrapper(JNIEnv* env, jobject jniIn) {
// IMPORTANT: DO NOT USE LOCAL REF IN DIFFERENT THREAD
if (env->GetJavaVM(&vm_) != JNI_OK) {
std::string errorMessage = "Unable to get JavaVM instance";
throw gluten::GlutenException(errorMessage);
}
jniIn_ = env->NewGlobalRef(jniIn);
}

~JavaInputStreamVeloxWrapper() {
try {
auto status = JavaInputStreamVeloxWrapper::close();
if (!status.ok()) {
LOG(WARNING) << __func__ << " call JavaInputStreamVeloxWrapper::close() failed, status:" << status.ToString();
}
} catch (std::exception& e) {
LOG(WARNING) << __func__ << " call JavaInputStreamVeloxWrapper::close() got exception:" << e.what();
}
}

// not thread safe
arrow::Status close() override {
if (closed_) {
return arrow::Status::OK();
}
JNIEnv* env;
attachCurrentThreadAsDaemonOrThrow(vm_, &env);
env->CallVoidMethod(jniIn_, jniByteInputStreamClose);
checkException(env);
env->DeleteGlobalRef(jniIn_);
vm_->DetachCurrentThread();
closed_ = true;
return arrow::Status::OK();
}

int64_t tell() override {
JNIEnv* env;
attachCurrentThreadAsDaemonOrThrow(vm_, &env);
jlong told = env->CallLongMethod(jniIn_, jniByteInputStreamTell);
checkException(env);
return told;
}

bool closed() {
return closed_;
}

int64_t read(int64_t nbytes, void* out) override {
JNIEnv* env;
attachCurrentThreadAsDaemonOrThrow(vm_, &env);
jlong read = env->CallLongMethod(jniIn_, jniByteInputStreamRead, reinterpret_cast<jlong>(out), nbytes);
checkException(env);
return read;
}

private:
JavaVM* vm_;
jobject jniIn_;
bool closed_ = false;
};

class JniColumnarBatchIterator : public ColumnarBatchIterator {
public:
explicit JniColumnarBatchIterator(
Expand Down Expand Up @@ -831,8 +895,10 @@ JNIEXPORT jlong JNICALL Java_org_apache_gluten_vectorized_ShuffleWriterJniWrappe
jlong taskAttemptId,
jint startPartitionId,
jint pushBufferMaxSize,
jlong sortBufferMaxSize,
jobject partitionPusher,
jstring partitionWriterTypeJstr) {
jstring partitionWriterTypeJstr,
jstring shuffleWriterTypeJstr) {
JNI_METHOD_START
auto ctx = gluten::getRuntime(env, wrapper);
auto memoryManager = jniCastOrThrow<MemoryManager>(memoryManagerHandle);
Expand Down Expand Up @@ -866,10 +932,12 @@ JNIEXPORT jlong JNICALL Java_org_apache_gluten_vectorized_ShuffleWriterJniWrappe
.mergeThreshold = mergeThreshold,
.compressionThreshold = compressionThreshold,
.compressionType = getCompressionType(env, codecJstr),
.compressionTypeStr = getCompressionTypeStr(env, codecJstr),
.compressionLevel = compressionLevel,
.bufferedWrite = true,
.numSubDirs = numSubDirs,
.pushBufferMaxSize = pushBufferMaxSize > 0 ? pushBufferMaxSize : kDefaultShuffleWriterBufferSize};
.pushBufferMaxSize = pushBufferMaxSize > 0 ? pushBufferMaxSize : kDefaultShuffleWriterBufferSize,
.sortBufferMaxSize = sortBufferMaxSize > 0 ? sortBufferMaxSize : kDefaultShuffleWriterBufferSize};
if (codecJstr != NULL) {
partitionWriterOptions.codecBackend = getCodecBackend(env, codecBackendJstr);
partitionWriterOptions.compressionMode = getCompressionMode(env, compressionModeJstr);
Expand All @@ -879,6 +947,15 @@ JNIEXPORT jlong JNICALL Java_org_apache_gluten_vectorized_ShuffleWriterJniWrappe
auto partitionWriterTypeC = env->GetStringUTFChars(partitionWriterTypeJstr, JNI_FALSE);
auto partitionWriterType = std::string(partitionWriterTypeC);
env->ReleaseStringUTFChars(partitionWriterTypeJstr, partitionWriterTypeC);

auto shuffleWriterTypeC = env->GetStringUTFChars(shuffleWriterTypeJstr, JNI_FALSE);
auto shuffleWriterType = std::string(shuffleWriterTypeC);
env->ReleaseStringUTFChars(shuffleWriterTypeJstr, shuffleWriterTypeC);

if (shuffleWriterType == "sort") {
shuffleWriterOptions.shuffleWriterType = kSortShuffle;
}

if (partitionWriterType == "local") {
if (dataFileJstr == NULL) {
throw gluten::GlutenException(std::string("Shuffle DataFile can't be null"));
Expand Down Expand Up @@ -981,7 +1058,11 @@ JNIEXPORT jlong JNICALL Java_org_apache_gluten_vectorized_ShuffleWriterJniWrappe
// The column batch may be a VeloxColumnBatch or an ArrowCStructColumnarBatch (FallbackRangeShuffleWriter)
auto batch = ctx->objectStore()->retrieve<ColumnarBatch>(batchHandle);
auto numBytes = batch->numBytes();
gluten::arrowAssertOkOrThrow(shuffleWriter->split(batch, memLimit), "Native split: shuffle writer split failed");
if (shuffleWriter->options().shuffleWriterType == kHashShuffle) {
gluten::arrowAssertOkOrThrow(shuffleWriter->split(batch, memLimit), "Native split: shuffle writer split failed");
} else if (shuffleWriter->options().shuffleWriterType == kSortShuffle) {
gluten::arrowAssertOkOrThrow(shuffleWriter->sort(batch, memLimit), "Native split: shuffle writer split failed");
}
return numBytes;
JNI_METHOD_END(kInvalidResourceHandle)
}
Expand Down Expand Up @@ -1058,19 +1139,25 @@ JNIEXPORT jlong JNICALL Java_org_apache_gluten_vectorized_ShuffleReaderJniWrappe
jlong memoryManagerHandle,
jstring compressionType,
jstring compressionBackend,
jint batchSize) {
jint batchSize,
jstring shuffleWriterType) {
JNI_METHOD_START
auto ctx = gluten::getRuntime(env, wrapper);
auto memoryManager = jniCastOrThrow<MemoryManager>(memoryManagerHandle);

auto pool = memoryManager->getArrowMemoryPool();
ShuffleReaderOptions options = ShuffleReaderOptions{};
options.compressionType = getCompressionType(env, compressionType);
options.compressionTypeStr = getCompressionTypeStr(env, compressionType);
if (compressionType != nullptr) {
options.codecBackend = getCodecBackend(env, compressionBackend);
}
options.batchSize = batchSize;
// TODO: Add coalesce option and maximum coalesced size.

if (jStringToCString(env, shuffleWriterType) == "sort") {
options.shuffleWriterType = kSortShuffle;
}
std::shared_ptr<arrow::Schema> schema =
gluten::arrowGetOrThrow(arrow::ImportSchema(reinterpret_cast<struct ArrowSchema*>(cSchema)));

Expand All @@ -1087,8 +1174,19 @@ JNIEXPORT jlong JNICALL Java_org_apache_gluten_vectorized_ShuffleReaderJniWrappe
auto ctx = gluten::getRuntime(env, wrapper);

auto reader = ctx->objectStore()->retrieve<ShuffleReader>(shuffleReaderHandle);
std::shared_ptr<arrow::io::InputStream> in = std::make_shared<JavaInputStreamAdaptor>(env, reader->getPool(), jniIn);
auto outItr = reader->readStream(in);
std::shared_ptr<ResultIterator> outItr;
const auto shuffleWriterType = reader->getShuffleWriterType();
if (shuffleWriterType == kHashShuffle) {
std::shared_ptr<arrow::io::InputStream> in =
std::make_shared<JavaInputStreamAdaptor>(env, reader->getPool(), jniIn);
outItr = reader->readStream(in);
} else if (shuffleWriterType == kSortShuffle) {
std::shared_ptr<JavaInputStreamWrapper> in = std::make_shared<JavaInputStreamVeloxWrapper>(env, jniIn);
outItr = reader->readStream(in);
} else {
std::string errorMessage = "Invalid shuffle writer type " + std::to_string(shuffleWriterType);
throw gluten::GlutenException(errorMessage);
}
return ctx->objectStore()->save(outItr);
JNI_METHOD_END(kInvalidResourceHandle)
}
Expand Down
18 changes: 18 additions & 0 deletions cpp/core/shuffle/FallbackRangePartitioner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,22 @@ arrow::Status gluten::FallbackRangePartitioner::compute(
return arrow::Status::OK();
}

/// Groups the rows of one input vector by the partition id already stored in `pidArr`.
///
/// For every row `i` in `[0, numRows)` the value `(vectorIndex << 32) | i` is appended to
/// `rowVectorIndexMap[pidArr[i]]`, so each entry carries the vector index in its upper
/// 32 bits and the row number in its lower 32 bits.
///
/// @param pidArr            Partition id of each row; must hold at least `numRows` entries.
/// @param numRows           Number of rows to process.
/// @param vectorIndex       Index of the input vector the rows belong to.
/// @param rowVectorIndexMap Output map from partition id to the encoded row locations.
/// @return `Invalid` as soon as a partition id is not below `numPartitions_`, otherwise OK.
///         Rows processed before the offending one remain in the map.
arrow::Status gluten::FallbackRangePartitioner::compute(
    const int32_t* pidArr,
    const int64_t numRows,
    const int32_t vectorIndex,
    std::unordered_map<int32_t, std::vector<int64_t>>& rowVectorIndexMap) {
  const auto index = static_cast<int64_t>(vectorIndex) << 32;
  // The index is 64-bit like numRows so the comparison cannot overflow the counter.
  for (int64_t i = 0; i < numRows; ++i) {
    const auto pid = pidArr[i];
    // Validate before touching the map so an out-of-range id never creates an entry.
    if (pid >= numPartitions_) {
      return arrow::Status::Invalid(
          "Partition id ", std::to_string(pid), " is equal or greater than ", std::to_string(numPartitions_));
    }
    rowVectorIndexMap[pid].push_back(index | (i & 0xFFFFFFFFLL));
  }
  return arrow::Status::OK();
}
} // namespace gluten
6 changes: 6 additions & 0 deletions cpp/core/shuffle/FallbackRangePartitioner.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,12 @@ class FallbackRangePartitioner final : public Partitioner {
const int64_t numRows,
std::vector<uint32_t>& row2partition,
std::vector<uint32_t>& partition2RowCount) override;

/// Overload that groups rows by partition id instead of filling per-row arrays.
/// For each row `i` the definition appends `(vectorIndex << 32) | i` to
/// `rowVectorIndexMap[pidArr[i]]` and returns an Invalid status when a partition id
/// is not below the partitioner's partition count.
arrow::Status compute(
const int32_t* pidArr,
const int64_t numRows,
const int32_t vectorIndex,
std::unordered_map<int32_t, std::vector<int64_t>>& rowVectorIndexMap) override;
};

} // namespace gluten
29 changes: 29 additions & 0 deletions cpp/core/shuffle/HashPartitioner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,4 +52,33 @@ arrow::Status gluten::HashPartitioner::compute(
return arrow::Status::OK();
}

/// Groups the rows of one input vector by partition id derived from the values in `pidArr`.
///
/// Each value is reduced with `% numPartitions_`; a negative remainder is shifted into the
/// non-negative range by adding `numPartitions_`. For every row `i` the value
/// `(vectorIndex << 32) | i` is then appended to `rowVectorIndexMap[pid]`, i.e. the vector
/// index occupies the upper 32 bits and the row number the lower 32 bits.
///
/// @param pidArr            Per-row input values; must hold at least `numRows` entries.
/// @param numRows           Number of rows to process.
/// @param vectorIndex       Index of the input vector the rows belong to.
/// @param rowVectorIndexMap Output map from partition id to the encoded row locations.
/// @return Always OK.
arrow::Status gluten::HashPartitioner::compute(
    const int32_t* pidArr,
    const int64_t numRows,
    const int32_t vectorIndex,
    std::unordered_map<int32_t, std::vector<int64_t>>& rowVectorIndexMap) {
  const auto index = static_cast<int64_t>(vectorIndex) << 32;
  // The index is 64-bit like numRows so the comparison cannot overflow the counter.
  for (int64_t i = 0; i < numRows; ++i) {
    auto pid = pidArr[i] % numPartitions_;
#if defined(__x86_64__)
    // force to generate ASM
    // `tmp` is written by the lea, so it has to be declared as an (early-clobber) output:
    // an asm statement is not allowed to modify an input-only operand. The test instruction
    // changes the flags, hence the "cc" clobber.
    decltype(pid) tmp;
    __asm__(
        "lea (%[num_partitions],%[pid],1),%[tmp]\n"
        "test %[pid],%[pid]\n"
        "cmovs %[tmp],%[pid]\n"
        : [pid] "+r"(pid), [tmp] "=&r"(tmp)
        : [num_partitions] "r"(numPartitions_)
        : "cc");
#else
    if (pid < 0) {
      pid += numPartitions_;
    }
#endif
    rowVectorIndexMap[pid].push_back(index | (i & 0xFFFFFFFFLL));
  }

  return arrow::Status::OK();
}

} // namespace gluten
6 changes: 6 additions & 0 deletions cpp/core/shuffle/HashPartitioner.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,12 @@ class HashPartitioner final : public Partitioner {
const int64_t numRows,
std::vector<uint32_t>& row2partition,
std::vector<uint32_t>& partition2RowCount) override;

/// Overload that groups rows by partition id instead of filling per-row arrays.
/// The definition reduces each `pidArr[i]` modulo the partition count (adjusting negative
/// remainders to be non-negative) and appends `(vectorIndex << 32) | i` to
/// `rowVectorIndexMap[pid]`.
arrow::Status compute(
const int32_t* pidArr,
const int64_t numRows,
const int32_t vectorIndex,
std::unordered_map<int32_t, std::vector<int64_t>>& rowVectorIndexMap) override;
};

} // namespace gluten
29 changes: 29 additions & 0 deletions cpp/core/shuffle/JavaInputStreamWrapper.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

// Abstract interface for a readable stream with close/tell/read operations.
// Going by the name, implementations are expected to wrap a Java-side input stream
// (NOTE(review): confirm against the concrete subclasses).
class JavaInputStreamWrapper {
public:
// Virtual so that implementations can be destroyed through a base pointer.
virtual ~JavaInputStreamWrapper() = default;

// Closes the stream and reports the outcome as an arrow::Status.
virtual arrow::Status close() = 0;

// Returns the stream position as reported by the implementation
// (presumably a byte offset - TODO confirm).
virtual int64_t tell() = 0;

// Reads up to `nbytes` into the caller-provided buffer `out`.
// Presumably returns the number of bytes actually read - TODO confirm with implementations.
virtual int64_t read(int64_t nbytes, void* out) = 0;
};
4 changes: 4 additions & 0 deletions cpp/core/shuffle/LocalPartitionWriter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -544,6 +544,10 @@ arrow::Status LocalPartitionWriter::evict(
return arrow::Status::OK();
}

/// Raw-buffer eviction overload. The local partition writer does not implement it.
///
/// @param partitionId Target partition of the buffer.
/// @param rawSize     Size value supplied by the caller (unused here).
/// @param data        Start of the byte range to evict (unused here).
/// @param length      Number of bytes in `data`.
/// @return Always a NotImplemented status.
arrow::Status LocalPartitionWriter::evict(uint32_t partitionId, int64_t rawSize, const char* data, int64_t length) {
  // Returning OK here would tell the caller its bytes were written while they are in fact
  // dropped; surface the missing implementation instead of losing shuffle data silently.
  return arrow::Status::NotImplemented(
      "LocalPartitionWriter does not support evicting raw partition data (partitionId=",
      partitionId,
      ", length=",
      length,
      ")");
}

arrow::Status LocalPartitionWriter::reclaimFixedSize(int64_t size, int64_t* actual) {
// Finish last spiller.
RETURN_NOT_OK(finishSpill());
Expand Down
2 changes: 2 additions & 0 deletions cpp/core/shuffle/LocalPartitionWriter.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ class LocalPartitionWriter : public PartitionWriter {
bool reuseBuffers,
bool hasComplexType) override;

/// Overload of evict() that takes a raw byte range (`data`, `length`) for `partitionId`.
/// `rawSize` is presumably the size of the data before any encoding - TODO confirm with callers.
arrow::Status evict(uint32_t partitionId, int64_t rawSize, const char* data, int64_t length) override;

/// The stop function performs several tasks:
/// 1. Opens the final data file.
/// 2. Iterates over each partition ID (pid) to:
Expand Down
Loading

0 comments on commit 34e4b51

Please sign in to comment.