Skip to content

Commit

Permalink
GH-34223: [Java] Java Substrait Consumer JNI call to ACERO C++ (#34227)
Browse files Browse the repository at this point in the history
* Closes: #34223

The purpose of this PR is to implement:

1. JNI Wrappers to consume Acero capabilities that execute Substrait Plans. 
2. Java base code to offer an API that consumes Substrait Plans. 
3. Initial Substrait documentation

Lead-authored-by: david dali susanibar arce <davi.sarces@gmail.com>
Co-authored-by: David Li <li.davidm96@gmail.com>
Signed-off-by: David Li <li.davidm96@gmail.com>
  • Loading branch information
davisusanibar and lidavidm committed May 24, 2023
1 parent c4ea194 commit 95c33d8
Show file tree
Hide file tree
Showing 15 changed files with 809 additions and 5 deletions.
1 change: 1 addition & 0 deletions ci/scripts/java_jni_macos_build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ cmake \
-DARROW_BUILD_TESTS=${ARROW_BUILD_TESTS} \
-DARROW_CSV=${ARROW_DATASET} \
-DARROW_DATASET=${ARROW_DATASET} \
-DARROW_SUBSTRAIT=${ARROW_DATASET} \
-DARROW_DEPENDENCY_USE_SHARED=OFF \
-DARROW_GANDIVA=${ARROW_GANDIVA} \
-DARROW_GANDIVA_STATIC_LIBSTDCPP=ON \
Expand Down
1 change: 1 addition & 0 deletions ci/scripts/java_jni_manylinux_build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ cmake \
-DARROW_BUILD_TESTS=ON \
-DARROW_CSV=${ARROW_DATASET} \
-DARROW_DATASET=${ARROW_DATASET} \
-DARROW_SUBSTRAIT=${ARROW_DATASET} \
-DARROW_DEPENDENCY_SOURCE="VCPKG" \
-DARROW_DEPENDENCY_USE_SHARED=OFF \
-DARROW_GANDIVA_PC_CXX_FLAGS=${GANDIVA_CXX_FLAGS} \
Expand Down
1 change: 1 addition & 0 deletions ci/scripts/java_jni_windows_build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ cmake \
-DARROW_BUILD_TESTS=ON \
-DARROW_CSV=${ARROW_DATASET} \
-DARROW_DATASET=${ARROW_DATASET} \
-DARROW_SUBSTRAIT=${ARROW_DATASET} \
-DARROW_DEPENDENCY_USE_SHARED=OFF \
-DARROW_ORC=${ARROW_ORC} \
-DARROW_PARQUET=${ARROW_PARQUET} \
Expand Down
1 change: 1 addition & 0 deletions docs/source/java/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ on the Arrow format and other language bindings see the :doc:`parent documentati
flight_sql
flight_sql_jdbc_driver
dataset
substrait
cdata
jdbc
Reference (javadoc) <reference/index>
107 changes: 107 additions & 0 deletions docs/source/java/substrait.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
.. Licensed to the Apache Software Foundation (ASF) under one
.. or more contributor license agreements. See the NOTICE file
.. distributed with this work for additional information
.. regarding copyright ownership. The ASF licenses this file
.. to you under the Apache License, Version 2.0 (the
.. "License"); you may not use this file except in compliance
.. with the License. You may obtain a copy of the License at
.. http://www.apache.org/licenses/LICENSE-2.0
.. Unless required by applicable law or agreed to in writing,
.. software distributed under the License is distributed on an
.. "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
.. KIND, either express or implied. See the License for the
.. specific language governing permissions and limitations
.. under the License.
=========
Substrait
=========

The ``arrow-dataset`` module can execute Substrait_ plans via the :doc:`Acero <../cpp/streaming_execution>`
query engine.

Executing Substrait Plans
=========================

Plans can reference data in files via URIs, or "named tables" that must be provided along with the plan.

Here is an example of a Java program that queries a Parquet file using Java Substrait
(this example uses the `Substrait Java`_ project to compile a SQL query into a Substrait plan):

.. code-block:: Java
import com.google.common.collect.ImmutableList;
import io.substrait.isthmus.SqlToSubstrait;
import io.substrait.proto.Plan;
import org.apache.arrow.dataset.file.FileFormat;
import org.apache.arrow.dataset.file.FileSystemDatasetFactory;
import org.apache.arrow.dataset.jni.NativeMemoryPool;
import org.apache.arrow.dataset.scanner.ScanOptions;
import org.apache.arrow.dataset.scanner.Scanner;
import org.apache.arrow.dataset.source.Dataset;
import org.apache.arrow.dataset.source.DatasetFactory;
import org.apache.arrow.dataset.substrait.AceroSubstraitConsumer;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.ipc.ArrowReader;
import org.apache.calcite.sql.parser.SqlParseException;
import java.nio.ByteBuffer;
import java.util.HashMap;
import java.util.Map;
// Example client: compiles a SQL query to a Substrait plan (via the Substrait
// Java "isthmus" compiler) and executes it against a Parquet file with Acero.
public class ClientSubstrait {
    public static void main(String[] args) {
        String uri = "file:///data/tpch_parquet/nation.parquet";
        ScanOptions options = new ScanOptions(/*batchSize*/ 32768);
        try (
            BufferAllocator allocator = new RootAllocator();
            DatasetFactory datasetFactory = new FileSystemDatasetFactory(allocator, NativeMemoryPool.getDefault(),
                FileFormat.PARQUET, uri);
            Dataset dataset = datasetFactory.finish();
            Scanner scanner = dataset.newScan(options);
            ArrowReader reader = scanner.scanBatches()
        ) {
            // map table to reader: the plan refers to the named table "NATION",
            // which is satisfied here by the Parquet-backed reader
            Map<String, ArrowReader> mapTableToArrowReader = new HashMap<>();
            mapTableToArrowReader.put("NATION", reader);
            // get binary plan and copy it into a direct ByteBuffer, as required
            // by the JNI consumer
            Plan plan = getPlan();
            ByteBuffer substraitPlan = ByteBuffer.allocateDirect(plan.toByteArray().length);
            substraitPlan.put(plan.toByteArray());
            // run query; the returned reader streams the query results
            try (ArrowReader arrowReader = new AceroSubstraitConsumer(allocator).runQuery(
                substraitPlan,
                mapTableToArrowReader
            )) {
                while (arrowReader.loadNextBatch()) {
                    System.out.println(arrowReader.getVectorSchemaRoot().contentToTSVString());
                }
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    // Compile a SQL query into a Substrait Plan, supplying the table schema
    // as a CREATE TABLE statement for the SQL-to-Substrait compiler.
    static Plan getPlan() throws SqlParseException {
        String sql = "SELECT * from nation";
        String nation = "CREATE TABLE NATION (N_NATIONKEY BIGINT NOT NULL, N_NAME CHAR(25), " +
            "N_REGIONKEY BIGINT NOT NULL, N_COMMENT VARCHAR(152))";
        SqlToSubstrait sqlToSubstrait = new SqlToSubstrait();
        Plan plan = sqlToSubstrait.execute(sql, ImmutableList.of(nation));
        return plan;
    }
}
.. code-block:: text
// Results example:
FieldPath(0) FieldPath(1) FieldPath(2) FieldPath(3)
0 ALGERIA 0 haggle. carefully final deposits detect slyly agai
1 ARGENTINA 1 al foxes promise slyly according to the regular accounts. bold requests alon
.. _`Substrait`: https://substrait.io/
.. _`Substrait Java`: https://github.com/substrait-io/substrait-java
.. _`Acero`: https://arrow.apache.org/docs/cpp/streaming_execution.html
9 changes: 7 additions & 2 deletions java/dataset/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.

find_package(ArrowDataset REQUIRED)
find_package(ArrowSubstrait REQUIRED)

include_directories(${CMAKE_CURRENT_BINARY_DIR} ${CMAKE_CURRENT_SOURCE_DIR}
${JNI_INCLUDE_DIRS} ${JNI_HEADERS_DIR})
Expand All @@ -26,14 +27,18 @@ add_jar(arrow_java_jni_dataset_jar
src/main/java/org/apache/arrow/dataset/file/JniWrapper.java
src/main/java/org/apache/arrow/dataset/jni/NativeMemoryPool.java
src/main/java/org/apache/arrow/dataset/jni/ReservationListener.java
src/main/java/org/apache/arrow/dataset/substrait/JniWrapper.java
GENERATE_NATIVE_HEADERS
arrow_java_jni_dataset_headers)

add_library(arrow_java_jni_dataset SHARED src/main/cpp/jni_wrapper.cc
src/main/cpp/jni_util.cc)
set_property(TARGET arrow_java_jni_dataset PROPERTY OUTPUT_NAME "arrow_dataset_jni")
target_link_libraries(arrow_java_jni_dataset arrow_java_jni_dataset_headers jni
ArrowDataset::arrow_dataset_static)
target_link_libraries(arrow_java_jni_dataset
arrow_java_jni_dataset_headers
jni
ArrowDataset::arrow_dataset_static
ArrowSubstrait::arrow_substrait_static)

if(BUILD_TESTING)
add_executable(arrow-java-jni-dataset-test src/main/cpp/jni_util_test.cc
Expand Down
118 changes: 118 additions & 0 deletions java/dataset/src/main/cpp/jni_wrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
// under the License.

#include <mutex>
#include <unordered_map>

#include "arrow/array.h"
#include "arrow/array/concatenate.h"
Expand All @@ -24,12 +25,14 @@
#include "arrow/dataset/api.h"
#include "arrow/dataset/file_base.h"
#include "arrow/filesystem/localfs.h"
#include "arrow/engine/substrait/util.h"
#include "arrow/ipc/api.h"
#include "arrow/util/iterator.h"
#include "jni_util.h"
#include "org_apache_arrow_dataset_file_JniWrapper.h"
#include "org_apache_arrow_dataset_jni_JniWrapper.h"
#include "org_apache_arrow_dataset_jni_NativeMemoryPool.h"
#include "org_apache_arrow_dataset_substrait_JniWrapper.h"

namespace {

Expand Down Expand Up @@ -261,6 +264,52 @@ void JNI_OnUnload(JavaVM* vm, void* reserved) {
default_memory_pool_id = -1L;
}

/// Unpack the named tables passed through JNI.
///
/// Named tables are encoded as a string array, where every two elements
/// encode (1) the table name and (2) the address of an ArrowArrayStream
/// containing the table data. This function will eagerly read all
/// tables into Tables.
///
/// Throws (via JniThrow) if the array length is odd or an address fails to
/// parse; propagates Arrow errors via JniGetOrThrow.
std::unordered_map<std::string, std::shared_ptr<arrow::Table>> LoadNamedTables(JNIEnv* env, const jobjectArray& str_array) {
  std::unordered_map<std::string, std::shared_ptr<arrow::Table>> map_table_to_record_batch_reader;
  int length = env->GetArrayLength(str_array);
  if (length % 2 != 0) {
    JniThrow("Can not map odd number of array elements to key/value pairs");
  }
  // Walk the array two elements at a time: (table name, stream address).
  for (int pos = 0; pos < length; pos += 2) {
    auto j_string_key = reinterpret_cast<jstring>(env->GetObjectArrayElement(str_array, pos));
    auto j_string_value = reinterpret_cast<jstring>(env->GetObjectArrayElement(str_array, pos + 1));
    uintptr_t memory_address = 0;
    try {
      // Use stoull, not stol: std::stol returns long, which is 32 bits on
      // LLP64 platforms (e.g. Windows) and would truncate 64-bit addresses.
      memory_address = static_cast<uintptr_t>(std::stoull(JStringToCString(env, j_string_value)));
    } catch (const std::exception& ex) {
      JniThrow("Failed to parse memory address from string value. Error: " + std::string(ex.what()));
    } catch (...) {
      JniThrow("Failed to parse memory address from string value.");
    }
    // Import the C-data-interface stream and eagerly materialize it as a Table.
    auto* arrow_stream_in = reinterpret_cast<ArrowArrayStream*>(memory_address);
    std::shared_ptr<arrow::RecordBatchReader> readerIn = JniGetOrThrow(arrow::ImportRecordBatchReader(arrow_stream_in));
    std::shared_ptr<arrow::Table> output_table = JniGetOrThrow(readerIn->ToTable());
    map_table_to_record_batch_reader[JStringToCString(env, j_string_key)] = std::move(output_table);
  }
  return map_table_to_record_batch_reader;
}

/// Look up the arrow Table registered under the given table name.
///
/// Only single-component (non-hierarchical) names are supported; throws via
/// JniThrow if the name is hierarchical or was not provided by the caller.
std::shared_ptr<arrow::Table> GetTableByName(const std::vector<std::string>& names,
    const std::unordered_map<std::string, std::shared_ptr<arrow::Table>>& tables) {
  if (names.size() != 1) {
    JniThrow("Tables with hierarchical names are not supported");
  }
  auto entry = tables.find(names[0]);
  if (entry == tables.end()) {
    JniThrow("Table is referenced, but not provided: " + names[0]);
  }
  return entry->second;
}

/*
* Class: org_apache_arrow_dataset_jni_NativeMemoryPool
* Method: getDefaultMemoryPool
Expand Down Expand Up @@ -578,3 +627,72 @@ Java_org_apache_arrow_dataset_file_JniWrapper_writeFromScannerToFile(
JniAssertOkOrThrow(arrow::dataset::FileSystemDataset::Write(options, scanner));
JNI_METHOD_END()
}

/*
 * Class: org_apache_arrow_dataset_substrait_JniWrapper
 * Method: executeSerializedPlan
 * Signature: (Ljava/lang/String;[Ljava/lang/String;J)V
 *
 * Execute a Substrait plan supplied as a JSON string. Named tables referenced
 * by the plan are resolved from (name, ArrowArrayStream address) pairs in
 * table_to_memory_address_input, and the result stream is exported to the
 * ArrowArrayStream at memory_address_output.
 */
JNIEXPORT void JNICALL
Java_org_apache_arrow_dataset_substrait_JniWrapper_executeSerializedPlan__Ljava_lang_String_2_3Ljava_lang_String_2J (
    JNIEnv* env, jobject, jstring plan, jobjectArray table_to_memory_address_input,
    jlong memory_address_output) {
  JNI_METHOD_START
  // get mapping of table name to memory address (tables are eagerly loaded)
  std::unordered_map<std::string, std::shared_ptr<arrow::Table>> map_table_to_reader =
      LoadNamedTables(env, table_to_memory_address_input);
  // create table provider: invoked by the Substrait consumer to turn each
  // named table in the plan into a table_source node over the loaded Table
  arrow::engine::NamedTableProvider table_provider =
      [&map_table_to_reader](const std::vector<std::string>& names, const arrow::Schema&) {
        std::shared_ptr<arrow::Table> output_table = GetTableByName(names, map_table_to_reader);
        std::shared_ptr<arrow::acero::ExecNodeOptions> options =
            std::make_shared<arrow::acero::TableSourceNodeOptions>(std::move(output_table));
        return arrow::acero::Declaration("table_source", {}, options, "java_source");
      };
  arrow::engine::ConversionOptions conversion_options;
  conversion_options.named_table_provider = std::move(table_provider);
  // execute plan: first convert the JSON plan text to the binary Substrait
  // representation, then run it through Acero
  std::shared_ptr<arrow::Buffer> buffer = JniGetOrThrow(arrow::engine::SerializeJsonPlan(
      JStringToCString(env, plan)));
  std::shared_ptr<arrow::RecordBatchReader> reader_out =
      JniGetOrThrow(arrow::engine::ExecuteSerializedPlan(*buffer, nullptr, nullptr, conversion_options));
  // hand the result back to Java through the C data interface
  auto* arrow_stream_out = reinterpret_cast<ArrowArrayStream*>(memory_address_output);
  JniAssertOkOrThrow(arrow::ExportRecordBatchReader(reader_out, arrow_stream_out));
  JNI_METHOD_END()
}

/*
 * Class: org_apache_arrow_dataset_substrait_JniWrapper
 * Method: executeSerializedPlan
 * Signature: (Ljava/nio/ByteBuffer;[Ljava/lang/String;J)V
 *
 * Execute a Substrait plan supplied as a direct ByteBuffer holding the binary
 * Substrait representation. Named tables referenced by the plan are resolved
 * from (name, ArrowArrayStream address) pairs in table_to_memory_address_input,
 * and the result stream is exported to the ArrowArrayStream at
 * memory_address_output.
 */
JNIEXPORT void JNICALL
Java_org_apache_arrow_dataset_substrait_JniWrapper_executeSerializedPlan__Ljava_nio_ByteBuffer_2_3Ljava_lang_String_2J (
    JNIEnv* env, jobject, jobject plan, jobjectArray table_to_memory_address_input,
    jlong memory_address_output) {
  JNI_METHOD_START
  // get mapping of table name to memory address (tables are eagerly loaded)
  std::unordered_map<std::string, std::shared_ptr<arrow::Table>> map_table_to_reader =
      LoadNamedTables(env, table_to_memory_address_input);
  // create table provider: invoked by the Substrait consumer to turn each
  // named table in the plan into a table_source node over the loaded Table
  arrow::engine::NamedTableProvider table_provider =
      [&map_table_to_reader](const std::vector<std::string>& names, const arrow::Schema&) {
        std::shared_ptr<arrow::Table> output_table = GetTableByName(names, map_table_to_reader);
        std::shared_ptr<arrow::acero::ExecNodeOptions> options =
            std::make_shared<arrow::acero::TableSourceNodeOptions>(std::move(output_table));
        return arrow::acero::Declaration("table_source", {}, options, "java_source");
      };
  arrow::engine::ConversionOptions conversion_options;
  conversion_options.named_table_provider = std::move(table_provider);
  // copy the plan bytes into an arrow::Buffer. GetDirectBufferAddress returns
  // nullptr and GetDirectBufferCapacity returns -1 when the buffer is not a
  // direct buffer, so validate both before touching the memory; capacity is a
  // jlong and must not be truncated to int.
  auto* plan_bytes = env->GetDirectBufferAddress(plan);
  jlong plan_length = env->GetDirectBufferCapacity(plan);
  if (plan_bytes == nullptr || plan_length < 0) {
    JniThrow("Substrait plan must be provided as a direct ByteBuffer");
  }
  std::shared_ptr<arrow::Buffer> buffer = JniGetOrThrow(arrow::AllocateBuffer(plan_length));
  std::memcpy(buffer->mutable_data(), plan_bytes, static_cast<size_t>(plan_length));
  // execute plan
  std::shared_ptr<arrow::RecordBatchReader> reader_out =
      JniGetOrThrow(arrow::engine::ExecuteSerializedPlan(*buffer, nullptr, nullptr, conversion_options));
  // hand the result back to Java through the C data interface
  auto* arrow_stream_out = reinterpret_cast<ArrowArrayStream*>(memory_address_output);
  JniAssertOkOrThrow(arrow::ExportRecordBatchReader(reader_out, arrow_stream_out));
  JNI_METHOD_END()
}

0 comments on commit 95c33d8

Please sign in to comment.