From 22cc85f2f14cc40438d9d333a876c1c37171c74e Mon Sep 17 00:00:00 2001 From: David Li Date: Wed, 29 Jun 2022 10:09:35 -0400 Subject: [PATCH 1/5] ARROW-16913: [Java] Implement ArrowArrayStream --- java/c/CMakeLists.txt | 1 + java/c/pom.xml | 5 + java/c/src/main/cpp/jni_wrapper.cc | 261 ++++++++++++-- .../apache/arrow/c/ArrayStreamExporter.java | 116 ++++++ .../org/apache/arrow/c/ArrowArrayStream.java | 194 ++++++++++ .../arrow/c/ArrowArrayStreamReader.java | 95 +++++ .../arrow/c/CDataDictionaryProvider.java | 1 + .../main/java/org/apache/arrow/c/Data.java | 21 ++ .../java/org/apache/arrow/c/NativeUtil.java | 4 +- .../apache/arrow/c/jni/CDataJniException.java | 45 +++ .../org/apache/arrow/c/jni/JniWrapper.java | 8 + .../org/apache/arrow/c/RoundtripTest.java | 3 - .../java/org/apache/arrow/c/StreamTest.java | 332 ++++++++++++++++++ java/pom.xml | 6 + .../vector/dictionary/DictionaryProvider.java | 4 + .../apache/arrow/vector/ipc/ArrowReader.java | 6 + .../arrow/vector/ipc/JsonFileReader.java | 6 + .../vector/ipc/message/MessageSerializer.java | 3 +- 18 files changed, 1077 insertions(+), 34 deletions(-) create mode 100644 java/c/src/main/java/org/apache/arrow/c/ArrayStreamExporter.java create mode 100644 java/c/src/main/java/org/apache/arrow/c/ArrowArrayStream.java create mode 100644 java/c/src/main/java/org/apache/arrow/c/ArrowArrayStreamReader.java create mode 100644 java/c/src/main/java/org/apache/arrow/c/jni/CDataJniException.java create mode 100644 java/c/src/test/java/org/apache/arrow/c/StreamTest.java diff --git a/java/c/CMakeLists.txt b/java/c/CMakeLists.txt index 1025f87afbc68..05938508de29d 100644 --- a/java/c/CMakeLists.txt +++ b/java/c/CMakeLists.txt @@ -35,6 +35,7 @@ include_directories(${CMAKE_CURRENT_BINARY_DIR} ${CMAKE_CURRENT_SOURCE_DIR} ${JNI_INCLUDE_DIRS} ${JNI_HEADERS_DIR}) add_jar(${PROJECT_NAME} + src/main/java/org/apache/arrow/c/jni/CDataJniException.java src/main/java/org/apache/arrow/c/jni/JniLoader.java src/main/java/org/apache/arrow/c/jni/JniWrapper.java src/main/java/org/apache/arrow/c/jni/PrivateData.java diff --git a/java/c/pom.xml b/java/c/pom.xml index 930c5b22d6df6..6d0632ea16584 100644 --- a/java/c/pom.xml +++ b/java/c/pom.xml @@ -62,6 +62,11 @@ ${dep.guava.version} test + + org.assertj + assertj-core + test + diff --git a/java/c/src/main/cpp/jni_wrapper.cc b/java/c/src/main/cpp/jni_wrapper.cc index cfb0af9bcbb0b..604c428224631 100644 --- a/java/c/src/main/cpp/jni_wrapper.cc +++ b/java/c/src/main/cpp/jni_wrapper.cc @@ -18,6 +18,7 @@ #include #include +#include #include #include #include @@ -27,19 +28,17 @@ namespace { -jclass CreateGlobalClassReference(JNIEnv* env, const char* class_name) { - jclass local_class = env->FindClass(class_name); - jclass global_class = (jclass)env->NewGlobalRef(local_class); - env->DeleteLocalRef(local_class); - return global_class; -} +jclass kRuntimeExceptionClass; +jclass kPrivateDataClass; +jclass kCDataExceptionClass; +jclass kStreamPrivateDataClass; -jclass illegal_access_exception_class; -jclass illegal_argument_exception_class; -jclass runtime_exception_class; -jclass private_data_class; +jfieldID kPrivateDataLastErrorField; -jmethodID private_data_close_method; +jmethodID kPrivateDataCloseMethod; +jmethodID kPrivateDataGetNextMethod; +jmethodID kPrivateDataGetSchemaMethod; +jmethodID kCDataExceptionConstructor; jint JNI_VERSION = JNI_VERSION_1_6; @@ -54,16 +53,43 @@ void ThrowPendingException(const std::string& message) { void JniThrow(std::string message) { ThrowPendingException(message); } +jclass CreateGlobalClassReference(JNIEnv* env, const char* class_name) { + jclass local_class = env->FindClass(class_name); + if (!local_class) { + std::string message = "Could not find class "; + message += class_name; + ThrowPendingException(message); + } + jclass global_class = (jclass)env->NewGlobalRef(local_class); + if (!local_class) { + std::string message = "Could not create global reference to class "; + message += class_name; + ThrowPendingException(message); + } + env->DeleteLocalRef(local_class); + return global_class; +} + jmethodID GetMethodID(JNIEnv* env, jclass this_class, const char* name, const char* sig) { jmethodID ret = env->GetMethodID(this_class, name, sig); if (ret == nullptr) { std::string error_message = "Unable to find method " + std::string(name) + - " within signature " + std::string(sig); + " with signature " + std::string(sig); ThrowPendingException(error_message); } return ret; } +jfieldID GetFieldID(JNIEnv* env, jclass this_class, const char* name, const char* sig) { + jfieldID fieldId = env->GetFieldID(this_class, name, sig); + if (fieldId == nullptr) { + std::string error_message = "Unable to find field " + std::string(name) + + " with signature " + std::string(sig); + ThrowPendingException(error_message); + } + return fieldId; +} + class InnerPrivateData { public: InnerPrivateData(JavaVM* vm, jobject private_data) @@ -71,6 +97,8 @@ class InnerPrivateData { JavaVM* vm_; jobject j_private_data_; + // Only for ArrowArrayStream + std::string last_error_; }; class JNIEnvGuard { @@ -135,7 +163,7 @@ void release_exported(T* base) { JNIEnvGuard guard(private_data->vm_); JNIEnv* env = guard.env(); - env->CallObjectMethod(private_data->j_private_data_, private_data_close_method); + env->CallObjectMethod(private_data->j_private_data_, kPrivateDataCloseMethod); if (env->ExceptionCheck()) { env->ExceptionDescribe(); env->ExceptionClear(); @@ -148,16 +176,99 @@ void release_exported(T* base) { // Mark released base->release = nullptr; } + +int ArrowArrayStreamGetSchema(ArrowArrayStream* stream, ArrowSchema* out) { + assert(stream->private_data != nullptr); + InnerPrivateData* private_data = + reinterpret_cast(stream->private_data); + JNIEnvGuard guard(private_data->vm_); + JNIEnv* env = guard.env(); + + const long out_addr = static_cast(reinterpret_cast(out)); + const int err_code = env->CallIntMethod(private_data->j_private_data_, + kPrivateDataGetSchemaMethod, out_addr); + if (env->ExceptionCheck()) { + env->ExceptionDescribe(); + env->ExceptionClear(); + return EIO; + } + return err_code; +} + +int ArrowArrayStreamGetNext(ArrowArrayStream* stream, ArrowArray* out) { + assert(stream->private_data != nullptr); + InnerPrivateData* private_data = + reinterpret_cast(stream->private_data); + JNIEnvGuard guard(private_data->vm_); + JNIEnv* env = guard.env(); + + const long out_addr = static_cast(reinterpret_cast(out)); + const int err_code = env->CallIntMethod(private_data->j_private_data_, + kPrivateDataGetNextMethod, out_addr); + if (env->ExceptionCheck()) { + env->ExceptionDescribe(); + env->ExceptionClear(); + return EIO; + } + return err_code; +} + +const char* ArrowArrayStreamGetLastError(ArrowArrayStream* stream) { + assert(stream->private_data != nullptr); + InnerPrivateData* private_data = + reinterpret_cast(stream->private_data); + JNIEnvGuard guard(private_data->vm_); + JNIEnv* env = guard.env(); + + jobject error_data = + env->GetObjectField(private_data->j_private_data_, kPrivateDataLastErrorField); + if (!error_data) return nullptr; + + auto arr = reinterpret_cast(error_data); + jbyte* error_bytes = env->GetByteArrayElements(arr, nullptr); + if (!error_bytes) return nullptr; + + char* error_str = reinterpret_cast(error_bytes); + private_data->last_error_ = std::string(error_str, std::strlen(error_str)); + + env->ReleaseByteArrayElements(arr, error_bytes, JNI_ABORT); + return private_data->last_error_.c_str(); +} + +void ArrowArrayStreamRelease(ArrowArrayStream* stream) { + // This should not be called on already released structure + assert(stream->release != nullptr); + // Release all data directly owned by the struct + InnerPrivateData* private_data = + reinterpret_cast(stream->private_data); + + JNIEnvGuard guard(private_data->vm_); + JNIEnv* env = guard.env(); + + env->CallObjectMethod(private_data->j_private_data_, kPrivateDataCloseMethod); + if (env->ExceptionCheck()) { + env->ExceptionDescribe(); + env->ExceptionClear(); + ThrowPendingException("Error calling close of private data"); + } + env->DeleteGlobalRef(private_data->j_private_data_); + delete private_data; + stream->private_data = nullptr; + + // Mark released + stream->release = nullptr; +} + } // namespace #define JNI_METHOD_START try { // macro ended -#define JNI_METHOD_END(fallback_expr) \ - } \ - catch (JniPendingException & e) { \ - env->ThrowNew(runtime_exception_class, e.what()); \ - return fallback_expr; \ +#define JNI_METHOD_END(fallback_expr) \ + } \ + catch (JniPendingException & e) { \ + env->ThrowNew(kRuntimeExceptionClass, e.what()); \ + return fallback_expr; \ } // macro ended @@ -167,16 +278,25 @@ jint JNI_OnLoad(JavaVM* vm, void* reserved) { return JNI_ERR; } JNI_METHOD_START - illegal_access_exception_class = - CreateGlobalClassReference(env, "Ljava/lang/IllegalAccessException;"); - illegal_argument_exception_class = - CreateGlobalClassReference(env, "Ljava/lang/IllegalArgumentException;"); - runtime_exception_class = + kRuntimeExceptionClass = CreateGlobalClassReference(env, "Ljava/lang/RuntimeException;"); - private_data_class = + kPrivateDataClass = CreateGlobalClassReference(env, "Lorg/apache/arrow/c/jni/PrivateData;"); - - private_data_close_method = GetMethodID(env, private_data_class, "close", "()V"); + kCDataExceptionClass = + CreateGlobalClassReference(env, "Lorg/apache/arrow/c/jni/CDataJniException;"); + kStreamPrivateDataClass = CreateGlobalClassReference( + env, "Lorg/apache/arrow/c/ArrayStreamExporter$ExportedArrayStreamPrivateData;"); + + kPrivateDataLastErrorField = + GetFieldID(env, kStreamPrivateDataClass, "lastError", "[B"); + + kPrivateDataCloseMethod = GetMethodID(env, kPrivateDataClass, "close", "()V"); + kPrivateDataGetNextMethod = + GetMethodID(env, kStreamPrivateDataClass, "getNext", "(J)I"); + kPrivateDataGetSchemaMethod = + GetMethodID(env, kStreamPrivateDataClass, "getSchema", "(J)I"); + kCDataExceptionConstructor = + GetMethodID(env, kCDataExceptionClass, "", "(ILjava/lang/String;)V"); return JNI_VERSION; JNI_METHOD_END(JNI_ERR) @@ -185,9 +305,9 @@ jint JNI_OnLoad(JavaVM* vm, void* reserved) { void JNI_OnUnload(JavaVM* vm, void* reserved) { JNIEnv* env; vm->GetEnv(reinterpret_cast(&env), JNI_VERSION); - env->DeleteGlobalRef(illegal_access_exception_class); - env->DeleteGlobalRef(illegal_argument_exception_class); - env->DeleteGlobalRef(runtime_exception_class); + env->DeleteGlobalRef(kRuntimeExceptionClass); + env->DeleteGlobalRef(kPrivateDataClass); + env->DeleteGlobalRef(kCDataExceptionClass); } /* @@ -220,6 +340,65 @@ Java_org_apache_arrow_c_jni_JniWrapper_releaseArray(JNIEnv* env, jobject, jlong JNI_METHOD_END() } +/* + * Class: org_apache_arrow_c_jni_JniWrapper + * Method: getNextArrayStream + * Signature: (JJ)V + */ +JNIEXPORT void JNICALL Java_org_apache_arrow_c_jni_JniWrapper_getNextArrayStream( + JNIEnv* env, jobject, jlong address, jlong out_address) { + JNI_METHOD_START + auto* stream = reinterpret_cast(address); + auto* out = reinterpret_cast(out_address); + const int err_code = stream->get_next(stream, out); + if (err_code != 0) { + const char* message = stream->get_last_error(stream); + if (!message) message = std::strerror(err_code); + jstring java_message = env->NewStringUTF(message); + jthrowable exception = static_cast(env->NewObject( + kCDataExceptionClass, kCDataExceptionConstructor, err_code, java_message)); + env->Throw(exception); + } + JNI_METHOD_END() +} + +/* + * Class: org_apache_arrow_c_jni_JniWrapper + * Method: getSchemaArrayStream + * Signature: (JJ)V + */ +JNIEXPORT void JNICALL Java_org_apache_arrow_c_jni_JniWrapper_getSchemaArrayStream( + JNIEnv* env, jobject, jlong address, jlong out_address) { + JNI_METHOD_START + auto* stream = reinterpret_cast(address); + auto* out = reinterpret_cast(out_address); + const int err_code = stream->get_schema(stream, out); + if (err_code != 0) { + const char* message = stream->get_last_error(stream); + if (!message) message = std::strerror(err_code); + jstring java_message = env->NewStringUTF(message); + jthrowable exception = static_cast(env->NewObject( + kCDataExceptionClass, kCDataExceptionConstructor, err_code, java_message)); + env->Throw(exception); + } + JNI_METHOD_END() +} + +/* + * Class: org_apache_arrow_c_jni_JniWrapper + * Method: releaseArrayStream + * Signature: (J)V + */ +JNIEXPORT void JNICALL Java_org_apache_arrow_c_jni_JniWrapper_releaseArrayStream( + JNIEnv* env, jobject, jlong address) { + JNI_METHOD_START + auto* stream = reinterpret_cast(address); + if (stream->release != nullptr) { + stream->release(stream); + } + JNI_METHOD_END() +} + /* * Class: org_apache_arrow_c_jni_JniWrapper * Method: exportSchema @@ -261,3 +440,27 @@ JNIEXPORT void JNICALL Java_org_apache_arrow_c_jni_JniWrapper_exportArray( array->release = &release_exported; JNI_METHOD_END() } + +/* + * Class: org_apache_arrow_c_jni_JniWrapper + * Method: exportArrayStream + * Signature: (JLorg/apache/arrow/c/jni/PrivateData;)V + */ +JNIEXPORT void JNICALL Java_org_apache_arrow_c_jni_JniWrapper_exportArrayStream( + JNIEnv* env, jobject, jlong address, jobject private_data) { + JNI_METHOD_START + auto* stream = reinterpret_cast(address); + + JavaVM* vm; + if (env->GetJavaVM(&vm) != JNI_OK) { + JniThrow("Unable to get JavaVM instance"); + } + jobject private_data_ref = env->NewGlobalRef(private_data); + + stream->get_schema = &ArrowArrayStreamGetSchema; + stream->get_next = &ArrowArrayStreamGetNext; + stream->get_last_error = &ArrowArrayStreamGetLastError; + stream->release = &ArrowArrayStreamRelease; + stream->private_data = new InnerPrivateData(vm, private_data_ref); + JNI_METHOD_END() +} diff --git a/java/c/src/main/java/org/apache/arrow/c/ArrayStreamExporter.java b/java/c/src/main/java/org/apache/arrow/c/ArrayStreamExporter.java new file mode 100644 index 0000000000000..81f495683590c --- /dev/null +++ b/java/c/src/main/java/org/apache/arrow/c/ArrayStreamExporter.java @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.c; + +import java.io.IOException; +import java.io.PrintWriter; +import java.io.StringWriter; +import java.nio.charset.StandardCharsets; + +import org.apache.arrow.c.jni.JniWrapper; +import org.apache.arrow.c.jni.PrivateData; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.ipc.ArrowReader; +import org.apache.arrow.vector.types.pojo.Schema; + +/** + * Utility to export an {@link ArrowReader} as an ArrowArrayStream. + */ +final class ArrayStreamExporter { + private final BufferAllocator allocator; + + ArrayStreamExporter(BufferAllocator allocator) { + this.allocator = allocator; + } + + /** + * Java-side state for the exported stream. + */ + static class ExportedArrayStreamPrivateData implements PrivateData { + final BufferAllocator allocator; + final ArrowReader reader; + int nextDictionary; + byte[] lastError; + + ExportedArrayStreamPrivateData(BufferAllocator allocator, ArrowReader reader) { + this.allocator = allocator; + this.reader = reader; + this.nextDictionary = 0; + } + + private int setLastError(Throwable err) { + // Do not let exceptions propagate up to JNI + try { + StringWriter buf = new StringWriter(); + PrintWriter writer = new PrintWriter(buf); + err.printStackTrace(writer); + lastError = buf.toString().getBytes(StandardCharsets.UTF_8); + } catch (Throwable e) { + // Bail out of setting the error message - we'll still return an error code + lastError = null; + } + return 5; // = EIO + } + + @SuppressWarnings("unused") // Used by JNI + int getNext(long arrayAddress) { + try (ArrowArray out = ArrowArray.wrap(arrayAddress)) { + if (reader.loadNextBatch()) { + Data.exportVectorSchemaRoot(allocator, reader.getVectorSchemaRoot(), reader, out); + } else { + out.markReleased(); + } + return 0; + } catch (Throwable e) { + return setLastError(e); + } + } + + @SuppressWarnings("unused") // Used by JNI + int getSchema(long schemaAddress) { + try (ArrowSchema out = ArrowSchema.wrap(schemaAddress)) { + final Schema schema = reader.getVectorSchemaRoot().getSchema(); + Data.exportSchema(allocator, schema, reader, out); + return 0; + } catch (Throwable e) { + return setLastError(e); + } + } + + @Override + public void close() { + try { + reader.close(); + } catch (IOException e) { + // XXX: C Data Interface gives us no way to signal this to the caller, + // but the JNI side will catch this and log an error. + throw new RuntimeException(e); + } + } + } + + void export(ArrowArrayStream stream, ArrowReader reader) { + ExportedArrayStreamPrivateData data = new ExportedArrayStreamPrivateData(allocator, reader); + try { + JniWrapper.get().exportArrayStream(stream.memoryAddress(), data); + } catch (Exception e) { + data.close(); + throw e; + } + } +} diff --git a/java/c/src/main/java/org/apache/arrow/c/ArrowArrayStream.java b/java/c/src/main/java/org/apache/arrow/c/ArrowArrayStream.java new file mode 100644 index 0000000000000..caf1f2fe965d1 --- /dev/null +++ b/java/c/src/main/java/org/apache/arrow/c/ArrowArrayStream.java @@ -0,0 +1,194 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.c; + +import static org.apache.arrow.c.NativeUtil.NULL; +import static org.apache.arrow.util.Preconditions.checkNotNull; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; + +import org.apache.arrow.c.jni.CDataJniException; +import org.apache.arrow.c.jni.JniWrapper; +import org.apache.arrow.memory.ArrowBuf; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.ReferenceManager; +import org.apache.arrow.memory.util.MemoryUtil; + +/** + * C Stream Interface ArrowArrayStream. + *

+ * Represents a wrapper for the following C structure: + * + *

+ * struct ArrowArrayStream {
+ *   int (*get_schema)(struct ArrowArrayStream*, struct ArrowSchema* out);
+ *   int (*get_next)(struct ArrowArrayStream*, struct ArrowArray* out);
+ *   const char* (*get_last_error)(struct ArrowArrayStream*);
+ *   void (*release)(struct ArrowArrayStream*);
+ *   void* private_data;
+ * };
+ * 
+ */ +public class ArrowArrayStream implements BaseStruct { + private static final int SIZE_OF = 40; + private static final int INDEX_RELEASE_CALLBACK = 24; + + private ArrowBuf data; + + /** + * Snapshot of the ArrowArrayStream raw data. + */ + public static class Snapshot { + public long get_schema; + public long get_next; + public long get_last_error; + public long release; + public long private_data; + + /** + * Initialize empty ArrowArray snapshot. + */ + public Snapshot() { + get_schema = NULL; + get_next = NULL; + get_last_error = NULL; + release = NULL; + private_data = NULL; + } + } + + /** + * Create ArrowArrayStream from an existing memory address. + *

+ * The resulting ArrowArrayStream does not own the memory. + * + * @param memoryAddress Memory address to wrap + * @return A new ArrowArrayStream instance + */ + public static ArrowArrayStream wrap(long memoryAddress) { + return new ArrowArrayStream(new ArrowBuf(ReferenceManager.NO_OP, null, ArrowArrayStream.SIZE_OF, memoryAddress)); + } + + /** + * Create ArrowArrayStream by allocating memory. + *

+ * The resulting ArrowArrayStream owns the memory. + * + * @param allocator Allocator for memory allocations + * @return A new ArrowArrayStream instance + */ + public static ArrowArrayStream allocateNew(BufferAllocator allocator) { + ArrowArrayStream array = new ArrowArrayStream(allocator.buffer(ArrowArrayStream.SIZE_OF)); + array.markReleased(); + return array; + } + + ArrowArrayStream(ArrowBuf data) { + checkNotNull(data, "ArrowArrayStream initialized with a null buffer"); + this.data = data; + } + + /** + * Mark the array as released. + */ + public void markReleased() { + directBuffer().putLong(INDEX_RELEASE_CALLBACK, NULL); + } + + @Override + public long memoryAddress() { + checkNotNull(data, "ArrowArrayStream is already closed"); + return data.memoryAddress(); + } + + @Override + public void release() { + long address = memoryAddress(); + JniWrapper.get().releaseArrayStream(address); + } + + /** + * Get the schema of the stream. + * @param schema The ArrowSchema struct to output to + * @throws IOException if the stream returns an error + */ + public void getSchema(ArrowSchema schema) throws IOException { + long address = memoryAddress(); + try { + JniWrapper.get().getSchemaArrayStream(address, schema.memoryAddress()); + } catch (CDataJniException e) { + throw new IOException("[errno " + e.getErrno() + "] " + e.getMessage()); + } + } + + /** + * Get the next batch in the stream. + * @param array The ArrowArray struct to output to + * @throws IOException if the stream returns an error + */ + public void getNext(ArrowArray array) throws IOException { + long address = memoryAddress(); + try { + JniWrapper.get().getNextArrayStream(address, array.memoryAddress()); + } catch (CDataJniException e) { + throw new IOException("[errno " + e.getErrno() + "] " + e.getMessage()); + } + } + + @Override + public void close() { + if (data != null) { + data.close(); + data = null; + } + } + + private ByteBuffer directBuffer() { + return MemoryUtil.directBuffer(memoryAddress(), ArrowArrayStream.SIZE_OF).order(ByteOrder.nativeOrder()); + } + + /** + * Take a snapshot of the ArrowArrayStream raw values. + * + * @return snapshot + */ + public ArrowArrayStream.Snapshot snapshot() { + ByteBuffer data = directBuffer(); + ArrowArrayStream.Snapshot snapshot = new ArrowArrayStream.Snapshot(); + snapshot.get_schema = data.getLong(); + snapshot.get_next = data.getLong(); + snapshot.get_last_error = data.getLong(); + snapshot.release = data.getLong(); + snapshot.private_data = data.getLong(); + return snapshot; + } + + /** + * Write values from Snapshot to the underlying ArrowArrayStream memory buffer. + */ + public void save(ArrowArrayStream.Snapshot snapshot) { + directBuffer() + .putLong(snapshot.get_schema) + .putLong(snapshot.get_next) + .putLong(snapshot.get_last_error) + .putLong(snapshot.release) + .putLong(snapshot.private_data); + } +} diff --git a/java/c/src/main/java/org/apache/arrow/c/ArrowArrayStreamReader.java b/java/c/src/main/java/org/apache/arrow/c/ArrowArrayStreamReader.java new file mode 100644 index 0000000000000..b39a3be9b842f --- /dev/null +++ b/java/c/src/main/java/org/apache/arrow/c/ArrowArrayStreamReader.java @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.c; + +import static org.apache.arrow.c.NativeUtil.NULL; +import static org.apache.arrow.util.Preconditions.checkState; + +import java.io.IOException; +import java.util.Map; +import java.util.function.Function; +import java.util.stream.Collectors; + +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.dictionary.Dictionary; +import org.apache.arrow.vector.ipc.ArrowReader; +import org.apache.arrow.vector.types.pojo.Schema; + +/** + * An implementation of an {@link ArrowReader} backed by an ArrowArrayStream. + */ +final class ArrowArrayStreamReader extends ArrowReader { + private final ArrowArrayStream ownedStream; + private final CDataDictionaryProvider provider; + + ArrowArrayStreamReader(BufferAllocator allocator, ArrowArrayStream stream) { + super(allocator); + this.provider = new CDataDictionaryProvider(); + + ArrowArrayStream.Snapshot snapshot = stream.snapshot(); + checkState(snapshot.release != NULL, "Cannot import released ArrowArrayStream"); + + // Move imported stream + this.ownedStream = ArrowArrayStream.allocateNew(allocator); + this.ownedStream.save(snapshot); + stream.markReleased(); + stream.close(); + } + + @Override + public Map getDictionaryVectors() { + return provider.getDictionaryIds().stream().collect(Collectors.toMap(Function.identity(), provider::lookup)); + } + + @Override + public Dictionary lookup(long id) { + return provider.lookup(id); + } + + @Override + public boolean loadNextBatch() throws IOException { + try (ArrowArray array = ArrowArray.allocateNew(allocator)) { + ownedStream.getNext(array); + if (array.snapshot().release == NULL) { + return false; + } + Data.importIntoVectorSchemaRoot(allocator, array, getVectorSchemaRoot(), provider); + return true; + } + } + + @Override + public long bytesRead() { + return 0; + } + + @Override + protected void closeReadSource() { + ownedStream.release(); + ownedStream.close(); + provider.close(); + } + + @Override + protected Schema readSchema() throws IOException { + try (ArrowSchema schema = ArrowSchema.allocateNew(allocator)) { + ownedStream.getSchema(schema); + return Data.importSchema(allocator, schema, provider); + } + } +} diff --git a/java/c/src/main/java/org/apache/arrow/c/CDataDictionaryProvider.java b/java/c/src/main/java/org/apache/arrow/c/CDataDictionaryProvider.java index 43bcda276ef99..4a84f11704c9a 100644 --- a/java/c/src/main/java/org/apache/arrow/c/CDataDictionaryProvider.java +++ b/java/c/src/main/java/org/apache/arrow/c/CDataDictionaryProvider.java @@ -52,6 +52,7 @@ void put(Dictionary dictionary) { } } + @Override public final Set getDictionaryIds() { return map.keySet(); } diff --git a/java/c/src/main/java/org/apache/arrow/c/Data.java b/java/c/src/main/java/org/apache/arrow/c/Data.java index 7151bff94be94..523a4c555a120 100644 --- a/java/c/src/main/java/org/apache/arrow/c/Data.java +++ b/java/c/src/main/java/org/apache/arrow/c/Data.java @@ -26,6 +26,7 @@ import org.apache.arrow.vector.VectorUnloader; import org.apache.arrow.vector.complex.StructVector; import org.apache.arrow.vector.dictionary.DictionaryProvider; +import org.apache.arrow.vector.ipc.ArrowReader; import org.apache.arrow.vector.ipc.message.ArrowRecordBatch; import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.ArrowType.ArrowTypeID; @@ -162,6 +163,16 @@ public static void exportVectorSchemaRoot(BufferAllocator allocator, VectorSchem } } + /** + * Export a reader as an ArrowArrayStream using the C Stream Interface. + * @param allocator Buffer allocator for allocating C data inteface fields + * @param reader Reader to export + * @param out C struct to export the stream + */ + public static void exportArrayStream(BufferAllocator allocator, ArrowReader reader, ArrowArrayStream out) { + new ArrayStreamExporter(allocator).export(out, reader); + } + /** * Import Java Field from the C data interface. *

@@ -314,4 +325,14 @@ public static VectorSchemaRoot importVectorSchemaRoot(BufferAllocator allocator, } return vsr; } + + /** + * Import an ArrowArrayStream as an {@link ArrowReader}. + * @param allocator Buffer allocator for allocating the output data. + * @param stream C stream interface struct to import. + * @return Imported reader + */ + public static ArrowReader importStream(BufferAllocator allocator, ArrowArrayStream stream) { + return new ArrowArrayStreamReader(allocator, stream); + } } diff --git a/java/c/src/main/java/org/apache/arrow/c/NativeUtil.java b/java/c/src/main/java/org/apache/arrow/c/NativeUtil.java index e2feda1e5dcc6..b152ea4e7c9fd 100644 --- a/java/c/src/main/java/org/apache/arrow/c/NativeUtil.java +++ b/java/c/src/main/java/org/apache/arrow/c/NativeUtil.java @@ -17,6 +17,7 @@ package org.apache.arrow.c; +import java.nio.Buffer; import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.nio.charset.StandardCharsets; @@ -52,7 +53,8 @@ public static String toJavaString(long cstringPtr) { length++; } byte[] bytes = new byte[length]; - ((ByteBuffer) reader.rewind()).get(bytes); + // Force use of base class rewind() to avoid breaking change of ByteBuffer.rewind in JDK9+ + ((ByteBuffer) ((Buffer) reader).rewind()).get(bytes); return new String(bytes, 0, length, StandardCharsets.UTF_8); } diff --git a/java/c/src/main/java/org/apache/arrow/c/jni/CDataJniException.java b/java/c/src/main/java/org/apache/arrow/c/jni/CDataJniException.java new file mode 100644 index 0000000000000..bebd434f3db3e --- /dev/null +++ b/java/c/src/main/java/org/apache/arrow/c/jni/CDataJniException.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.c.jni; + +/** + * An exception raised by the JNI side of the C Data bridge. + */ +public final class CDataJniException extends Exception { + private final int errno; + + public CDataJniException(int errno, String message) { + super(message); + this.errno = errno; + } + + /** + * The original error code returned from C. + */ + public int getErrno() { + return errno; + } + + @Override + public String toString() { + return "CDataJniException{" + + "errno=" + errno + + ", message=" + getMessage() + + '}'; + } +} diff --git a/java/c/src/main/java/org/apache/arrow/c/jni/JniWrapper.java b/java/c/src/main/java/org/apache/arrow/c/jni/JniWrapper.java index 9e1c19b100e98..eb299b65f003b 100644 --- a/java/c/src/main/java/org/apache/arrow/c/jni/JniWrapper.java +++ b/java/c/src/main/java/org/apache/arrow/c/jni/JniWrapper.java @@ -41,7 +41,15 @@ private JniWrapper() { public native void releaseArray(long memoryAddress); + public native void getNextArrayStream(long streamAddress, long arrayAddress) throws CDataJniException; + + public native void getSchemaArrayStream(long streamAddress, long arrayAddress) throws CDataJniException; + + public native void releaseArrayStream(long memoryAddress); + public native void exportSchema(long memoryAddress, PrivateData privateData); public native void exportArray(long memoryAddress, PrivateData data); + + public native void exportArrayStream(long memoryAddress, PrivateData data); } diff --git a/java/c/src/test/java/org/apache/arrow/c/RoundtripTest.java b/java/c/src/test/java/org/apache/arrow/c/RoundtripTest.java index 6aa6e889ba347..6a2b476b0c395 100644 --- a/java/c/src/test/java/org/apache/arrow/c/RoundtripTest.java +++ b/java/c/src/test/java/org/apache/arrow/c/RoundtripTest.java @@ -34,9 +34,6 @@ import java.util.stream.Collectors; import java.util.stream.Stream; -import org.apache.arrow.c.ArrowArray; -import org.apache.arrow.c.ArrowSchema; -import org.apache.arrow.c.Data; import org.apache.arrow.memory.ArrowBuf; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; diff --git a/java/c/src/test/java/org/apache/arrow/c/StreamTest.java b/java/c/src/test/java/org/apache/arrow/c/StreamTest.java new file mode 100644 index 0000000000000..57de88de5431c --- /dev/null +++ b/java/c/src/test/java/org/apache/arrow/c/StreamTest.java @@ -0,0 +1,332 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.c; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.function.Function; +import java.util.stream.Collectors; + +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.util.AutoCloseables; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.VectorLoader; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.VectorUnloader; +import org.apache.arrow.vector.compare.Range; +import org.apache.arrow.vector.compare.RangeEqualsVisitor; +import org.apache.arrow.vector.dictionary.Dictionary; +import org.apache.arrow.vector.dictionary.DictionaryProvider; +import org.apache.arrow.vector.ipc.ArrowReader; +import org.apache.arrow.vector.ipc.message.ArrowRecordBatch; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.DictionaryEncoding; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.FieldType; +import org.apache.arrow.vector.types.pojo.Schema; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +final class StreamTest { + private RootAllocator allocator = null; + + @BeforeEach + public void setUp() { + allocator = new RootAllocator(Long.MAX_VALUE); + } + + @AfterEach + public void tearDown() { + allocator.close(); + } + + @Test + public void testRoundtripInts() throws Exception { + final Schema schema = new Schema(Collections.singletonList(Field.nullable("ints", new ArrowType.Int(32, true)))); + final List batches = new ArrayList<>(); + try (final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) { + final IntVector ints = (IntVector) root.getVector(0); + VectorUnloader unloader = new VectorUnloader(root); + + root.allocateNew(); + ints.setSafe(0, 1); + ints.setSafe(1, 2); + ints.setSafe(2, 4); + ints.setSafe(3, 8); + root.setRowCount(4); + batches.add(unloader.getRecordBatch()); + + root.allocateNew(); + ints.setSafe(0, 1); + ints.setNull(1); + ints.setSafe(2, 4); + ints.setNull(3); + root.setRowCount(4); + batches.add(unloader.getRecordBatch()); + roundtrip(schema, batches); + } + } + + @Test + public void roundtripStrings() throws Exception { + final Schema schema = new Schema(Arrays.asList(Field.nullable("ints", new ArrowType.Int(32, true)), + Field.nullable("strs", new ArrowType.Utf8()))); + final List batches = new ArrayList<>(); + try (final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) { + final IntVector ints = (IntVector) root.getVector(0); + final VarCharVector strs = (VarCharVector) root.getVector(1); + VectorUnloader unloader = new VectorUnloader(root); + + root.allocateNew(); + ints.setSafe(0, 1); + ints.setSafe(1, 2); + ints.setSafe(2, 4); + ints.setSafe(3, 8); + strs.setSafe(0, "".getBytes(StandardCharsets.UTF_8)); + strs.setSafe(1, "a".getBytes(StandardCharsets.UTF_8)); + strs.setSafe(2, "bc".getBytes(StandardCharsets.UTF_8)); + strs.setSafe(3, "defg".getBytes(StandardCharsets.UTF_8)); + root.setRowCount(4); + batches.add(unloader.getRecordBatch()); + + root.allocateNew(); + ints.setSafe(0, 1); + ints.setNull(1); + ints.setSafe(2, 4); + ints.setNull(3); + strs.setSafe(0, "".getBytes(StandardCharsets.UTF_8)); + strs.setNull(1); + strs.setSafe(2, "bc".getBytes(StandardCharsets.UTF_8)); + strs.setNull(3); + root.setRowCount(4); + batches.add(unloader.getRecordBatch()); + roundtrip(schema, batches); + } + } + + @Test + public void roundtripDictionary() throws Exception { + final ArrowType.Int indexType = new ArrowType.Int(32, true); + final DictionaryEncoding encoding = new DictionaryEncoding(1L, false, indexType); + final Schema schema = new Schema(Collections.singletonList( + new Field("dict", new FieldType(/*nullable=*/true, indexType, encoding), Collections.emptyList()))); + final List batches = new ArrayList<>(); + try (final CDataDictionaryProvider provider = new CDataDictionaryProvider(); + final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) { + final VarCharVector dictionary = new VarCharVector("values", allocator); + dictionary.allocateNew(); + dictionary.setSafe(0, "foo".getBytes(StandardCharsets.UTF_8)); + dictionary.setSafe(1, "bar".getBytes(StandardCharsets.UTF_8)); + dictionary.setNull(2); + dictionary.setValueCount(3); + provider.put(new Dictionary(dictionary, encoding)); + final IntVector encoded = (IntVector) root.getVector(0); + VectorUnloader unloader = new VectorUnloader(root); + + root.allocateNew(); + encoded.setSafe(0, 0); + encoded.setSafe(1, 1); + encoded.setSafe(2, 0); + encoded.setSafe(3, 2); + root.setRowCount(4); + batches.add(unloader.getRecordBatch()); + + root.allocateNew(); + encoded.setSafe(0, 0); + encoded.setNull(1); + encoded.setSafe(2, 1); + encoded.setNull(3); + root.setRowCount(4); + batches.add(unloader.getRecordBatch()); + roundtrip(schema, batches, provider); + } + } + + @Test + public void importReleasedStream() { + try (final ArrowArrayStream stream = ArrowArrayStream.allocateNew(allocator)) { + Exception e = assertThrows(IllegalStateException.class, () -> Data.importStream(allocator, stream)); + assertThat(e).hasMessageContaining("Cannot import released ArrowArrayStream"); + } + } + + @Test + public void getNextError() throws Exception { + final Schema schema = new Schema(Collections.singletonList(Field.nullable("ints", new ArrowType.Int(32, true)))); + final List batches = new ArrayList<>(); + try (final ArrowReader source = new InMemoryArrowReader(allocator, schema, batches, + new DictionaryProvider.MapDictionaryProvider()) { + @Override + public boolean loadNextBatch() throws IOException { + throw new IOException("Failed to load batch!"); + } + }; final ArrowArrayStream stream = ArrowArrayStream.allocateNew(allocator)) { + Data.exportArrayStream(allocator, source, stream); + try (final ArrowReader reader = Data.importStream(allocator, stream)) { + assertThat(reader.getVectorSchemaRoot().getSchema()).isEqualTo(schema); + final IOException e = assertThrows(IOException.class, reader::loadNextBatch); + assertThat(e).hasMessageContaining("Failed to load batch!"); + assertThat(e).hasMessageContaining("[errno "); + } + } + } + + @Test + public void getSchemaError() throws Exception { + final Schema schema = new Schema(Collections.singletonList(Field.nullable("ints", new ArrowType.Int(32, true)))); + final List batches = new ArrayList<>(); + try (final ArrowReader source = new InMemoryArrowReader(allocator, schema, batches, + new DictionaryProvider.MapDictionaryProvider()) { + @Override + protected Schema readSchema() { + throw new IllegalArgumentException("Failed to read schema!"); + } + }; final ArrowArrayStream stream = ArrowArrayStream.allocateNew(allocator)) { + Data.exportArrayStream(allocator, source, stream); + try (final ArrowReader reader = Data.importStream(allocator, stream)) { + final IOException e = assertThrows(IOException.class, reader::getVectorSchemaRoot); + assertThat(e).hasMessageContaining("Failed to read schema!"); + assertThat(e).hasMessageContaining("[errno "); + } + } + } + + void roundtrip(Schema schema, List batches, DictionaryProvider provider) throws Exception { + ArrowReader source = new InMemoryArrowReader(allocator, schema, batches, provider); + + try (final ArrowArrayStream stream = ArrowArrayStream.allocateNew(allocator); + final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) { + final VectorLoader loader = new VectorLoader(root); + Data.exportArrayStream(allocator, source, stream); + + try (final ArrowReader reader = Data.importStream(allocator, stream)) { + assertThat(reader.getVectorSchemaRoot().getSchema()).isEqualTo(schema); + + for (ArrowRecordBatch batch : batches) { + assertThat(reader.loadNextBatch()).isTrue(); + loader.load(batch); + + assertThat(reader.getVectorSchemaRoot().getRowCount()).isEqualTo(root.getRowCount()); + + for (int i = 0; i < root.getFieldVectors().size(); i++) { + final FieldVector expected = root.getVector(i); + final FieldVector actual = reader.getVectorSchemaRoot().getVector(i); + assertVectorsEqual(expected, actual); + } + } + assertThat(reader.loadNextBatch()).isFalse(); + assertThat(reader.getDictionaryIds()).isEqualTo(provider.getDictionaryIds()); + for (Map.Entry entry : reader.getDictionaryVectors().entrySet()) { + final FieldVector expected = provider.lookup(entry.getKey()).getVector(); + final FieldVector actual = entry.getValue().getVector(); + assertVectorsEqual(expected, actual); + } + } + } + } + + void roundtrip(Schema schema, List batches) throws Exception { + roundtrip(schema, batches, new CDataDictionaryProvider()); + } + + private static void assertVectorsEqual(FieldVector expected, FieldVector actual) { + assertThat(actual.getField().getType()).isEqualTo(expected.getField().getType()); + assertThat(actual.getValueCount()).isEqualTo(expected.getValueCount()); + final Range range = new Range(/*leftStart=*/0, /*rightStart=*/0, expected.getValueCount()); + assertThat(new RangeEqualsVisitor(expected, actual) + .rangeEquals(range)) + .as("Vectors were not equal.\nExpected: %s\nGot: %s", expected, actual) + .isTrue(); + } + + /** + * An ArrowReader backed by a fixed list of batches. + */ + static class InMemoryArrowReader extends ArrowReader { + private final Schema schema; + private final List batches; + private final DictionaryProvider provider; + private int nextBatch; + + InMemoryArrowReader(BufferAllocator allocator, Schema schema, List batches, + DictionaryProvider provider) { + super(allocator); + this.schema = schema; + this.batches = batches; + this.provider = provider; + this.nextBatch = 0; + } + + @Override + public Dictionary lookup(long id) { + return provider.lookup(id); + } + + @Override + public Set getDictionaryIds() { + return provider.getDictionaryIds(); + } + + @Override + public Map getDictionaryVectors() { + return getDictionaryIds().stream().collect(Collectors.toMap(Function.identity(), this::lookup)); + } + + @Override + public boolean loadNextBatch() throws IOException { + if (nextBatch < batches.size()) { + VectorLoader loader = new VectorLoader(getVectorSchemaRoot()); + loader.load(batches.get(nextBatch++)); + return true; + } + return false; + } + + @Override + public long bytesRead() { + return 0; + } + + @Override + protected void closeReadSource() throws IOException { + try { + AutoCloseables.close(batches); + } catch (Exception e) { + throw new IOException(e); + } + } + + @Override + protected Schema readSchema() { + return schema; + } + } +} diff --git a/java/pom.xml b/java/pom.xml index 28afabc344dd4..6f2ed823cfe2d 100644 --- a/java/pom.xml +++ b/java/pom.xml @@ -554,6 +554,12 @@ javax.annotation-api 1.3.2 + + org.assertj + assertj-core + 3.23.1 + test + org.immutables value diff --git a/java/vector/src/main/java/org/apache/arrow/vector/dictionary/DictionaryProvider.java b/java/vector/src/main/java/org/apache/arrow/vector/dictionary/DictionaryProvider.java index 21165c07d9b1e..76e1eb9f66d25 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/dictionary/DictionaryProvider.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/dictionary/DictionaryProvider.java @@ -29,6 +29,9 @@ public interface DictionaryProvider { /** Return the dictionary for the given ID. */ Dictionary lookup(long id); + /** Get all dictionary IDs. */ + Set getDictionaryIds(); + /** * Implementation of {@link DictionaryProvider} that is backed by a hash-map. */ @@ -50,6 +53,7 @@ public void put(Dictionary dictionary) { map.put(dictionary.getEncoding().getId(), dictionary); } + @Override public final Set getDictionaryIds() { return map.keySet(); } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowReader.java b/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowReader.java index 9d940deecfe20..04c57d7e82fef 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowReader.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowReader.java @@ -23,6 +23,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Set; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.vector.FieldVector; @@ -99,6 +100,11 @@ public Dictionary lookup(long id) { return dictionaries.get(id); } + @Override + public Set getDictionaryIds() { + return dictionaries.keySet(); + } + /** * Load the next ArrowRecordBatch to the vector schema root if available. * diff --git a/java/vector/src/main/java/org/apache/arrow/vector/ipc/JsonFileReader.java b/java/vector/src/main/java/org/apache/arrow/vector/ipc/JsonFileReader.java index d093e840ab3a5..6455857c275c3 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/ipc/JsonFileReader.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/ipc/JsonFileReader.java @@ -38,6 +38,7 @@ import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.Set; import org.apache.arrow.memory.ArrowBuf; import org.apache.arrow.memory.BufferAllocator; @@ -115,6 +116,11 @@ public Dictionary lookup(long id) { return dictionaries.get(id); } + @Override + public Set getDictionaryIds() { + return dictionaries.keySet(); + } + /** Reads the beginning (schema section) of the json file and returns it. */ public Schema start() throws JsonParseException, IOException { readToken(START_OBJECT); diff --git a/java/vector/src/main/java/org/apache/arrow/vector/ipc/message/MessageSerializer.java b/java/vector/src/main/java/org/apache/arrow/vector/ipc/message/MessageSerializer.java index 6597e0302c72c..9deb42c498cbb 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/ipc/message/MessageSerializer.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/ipc/message/MessageSerializer.java @@ -684,7 +684,8 @@ public static MessageMetadataResult readMessage(ReadChannel in) throws IOExcepti int messageLength = MessageSerializer.bytesToInt(buffer.array()); if (messageLength == IPC_CONTINUATION_TOKEN) { - buffer.clear(); + // Avoid breaking change in signature of ByteBuffer.clear() in JDK9+ + ((java.nio.Buffer) buffer).clear(); // ARROW-6313, if the first 4 bytes are continuation message, read the next 4 for the length if (in.readFully(buffer) == 4) { messageLength = MessageSerializer.bytesToInt(buffer.array()); From 8090dbf85acd4d89de2d846d781c469c6f0ed2eb Mon Sep 17 00:00:00 2001 From: David Li Date: Wed, 29 Jun 2022 14:12:32 -0400 Subject: [PATCH 2/5] Address feedback --- java/c/src/main/cpp/jni_wrapper.cc | 86 ++++++++++++++----- .../apache/arrow/c/ArrayStreamExporter.java | 5 +- .../main/java/org/apache/arrow/c/Data.java | 2 +- .../java/org/apache/arrow/c/StreamTest.java | 8 +- 4 files changed, 72 insertions(+), 29 deletions(-) diff --git a/java/c/src/main/cpp/jni_wrapper.cc b/java/c/src/main/cpp/jni_wrapper.cc index 604c428224631..ffe4d2ba715ca 100644 --- a/java/c/src/main/cpp/jni_wrapper.cc +++ b/java/c/src/main/cpp/jni_wrapper.cc @@ -28,6 +28,7 @@ namespace { +jclass kObjectClass; jclass kRuntimeExceptionClass; jclass kPrivateDataClass; jclass kCDataExceptionClass; @@ -35,6 +36,7 @@ jclass kStreamPrivateDataClass; jfieldID kPrivateDataLastErrorField; +jmethodID kObjectToStringMethod; jmethodID kPrivateDataCloseMethod; jmethodID kPrivateDataGetNextMethod; jmethodID kPrivateDataGetSchemaMethod; @@ -61,7 +63,7 @@ jclass CreateGlobalClassReference(JNIEnv* env, const char* class_name) { ThrowPendingException(message); } jclass global_class = (jclass)env->NewGlobalRef(local_class); - if (!local_class) { + if (!global_class) { std::string message = "Could not create global reference to class "; message += class_name; ThrowPendingException(message); @@ -165,9 +167,10 @@ void release_exported(T* base) { env->CallObjectMethod(private_data->j_private_data_, kPrivateDataCloseMethod); if (env->ExceptionCheck()) { + // Can't signal this to caller, so log and then try to free things + // as best we can env->ExceptionDescribe(); env->ExceptionClear(); - ThrowPendingException("Error calling close of private data"); } env->DeleteGlobalRef(private_data->j_private_data_); delete private_data; @@ -177,6 +180,48 @@ void release_exported(T* base) { base->release = nullptr; } +// Attempt to copy the JVM-side lastError to the C++ side +void TryCopyLastError(JNIEnv* env, InnerPrivateData* private_data) { + jobject error_data = + env->GetObjectField(private_data->j_private_data_, kPrivateDataLastErrorField); + if (!error_data) { + private_data->last_error_.clear(); + return; + } + + auto arr = reinterpret_cast(error_data); + jbyte* error_bytes = env->GetByteArrayElements(arr, nullptr); + if (!error_bytes) { + private_data->last_error_.clear(); + return; + } + + char* error_str = reinterpret_cast(error_bytes); + private_data->last_error_ = std::string(error_str, std::strlen(error_str)); + + env->ReleaseByteArrayElements(arr, error_bytes, JNI_ABORT); +} + +// Normally the Java side catches all exceptions and populates +// lastError. If that fails we check for an exception and try to +// populate last_error_ ourselves. +void TryHandleUncaughtException(JNIEnv* env, InnerPrivateData* private_data, + jthrowable exc) { + jstring message = + reinterpret_cast(env->CallObjectMethod(exc, kObjectToStringMethod)); + if (!message) { + private_data->last_error_.clear(); + return; + } + const char* str = env->GetStringUTFChars(message, 0); + if (!str) { + private_data->last_error_.clear(); + return; + } + private_data->last_error_ = str; + env->ReleaseStringUTFChars(message, 0); +} + int ArrowArrayStreamGetSchema(ArrowArrayStream* stream, ArrowSchema* out) { assert(stream->private_data != nullptr); InnerPrivateData* private_data = @@ -184,13 +229,15 @@ int ArrowArrayStreamGetSchema(ArrowArrayStream* stream, ArrowSchema* out) { JNIEnvGuard guard(private_data->vm_); JNIEnv* env = guard.env(); - const long out_addr = static_cast(reinterpret_cast(out)); + const jlong out_addr = static_cast(reinterpret_cast(out)); const int err_code = env->CallIntMethod(private_data->j_private_data_, kPrivateDataGetSchemaMethod, out_addr); - if (env->ExceptionCheck()) { - env->ExceptionDescribe(); + if (jthrowable exc = env->ExceptionOccurred()) { + TryHandleUncaughtException(env, private_data, exc); env->ExceptionClear(); return EIO; + } else if (err_code != 0) { + TryCopyLastError(env, private_data); } return err_code; } @@ -202,13 +249,15 @@ int ArrowArrayStreamGetNext(ArrowArrayStream* stream, ArrowArray* out) { JNIEnvGuard guard(private_data->vm_); JNIEnv* env = guard.env(); - const long out_addr = static_cast(reinterpret_cast(out)); + const jlong out_addr = static_cast(reinterpret_cast(out)); const int err_code = env->CallIntMethod(private_data->j_private_data_, kPrivateDataGetNextMethod, out_addr); - if (env->ExceptionCheck()) { - env->ExceptionDescribe(); + if (jthrowable exc = env->ExceptionOccurred()) { + TryHandleUncaughtException(env, private_data, exc); env->ExceptionClear(); return EIO; + } else if (err_code != 0) { + TryCopyLastError(env, private_data); } return err_code; } @@ -220,18 +269,7 @@ const char* ArrowArrayStreamGetLastError(ArrowArrayStream* stream) { JNIEnvGuard guard(private_data->vm_); JNIEnv* env = guard.env(); - jobject error_data = - env->GetObjectField(private_data->j_private_data_, kPrivateDataLastErrorField); - if (!error_data) return nullptr; - - auto arr = reinterpret_cast(error_data); - jbyte* error_bytes = env->GetByteArrayElements(arr, nullptr); - if (!error_bytes) return nullptr; - - char* error_str = reinterpret_cast(error_bytes); - private_data->last_error_ = std::string(error_str, std::strlen(error_str)); - - env->ReleaseByteArrayElements(arr, error_bytes, JNI_ABORT); + if (private_data->last_error_.empty()) return nullptr; return private_data->last_error_.c_str(); } @@ -247,9 +285,10 @@ void ArrowArrayStreamRelease(ArrowArrayStream* stream) { env->CallObjectMethod(private_data->j_private_data_, kPrivateDataCloseMethod); if (env->ExceptionCheck()) { + // Can't signal this to caller, so log and then try to free things + // as best we can env->ExceptionDescribe(); env->ExceptionClear(); - ThrowPendingException("Error calling close of private data"); } env->DeleteGlobalRef(private_data->j_private_data_); delete private_data; @@ -278,6 +317,7 @@ jint JNI_OnLoad(JavaVM* vm, void* reserved) { return JNI_ERR; } JNI_METHOD_START + kObjectClass = CreateGlobalClassReference(env, "Ljava/lang/Object;"); kRuntimeExceptionClass = CreateGlobalClassReference(env, "Ljava/lang/RuntimeException;"); kPrivateDataClass = @@ -290,6 +330,8 @@ jint JNI_OnLoad(JavaVM* vm, void* reserved) { kPrivateDataLastErrorField = GetFieldID(env, kStreamPrivateDataClass, "lastError", "[B"); + kObjectToStringMethod = + GetMethodID(env, kObjectClass, "toString", "()Ljava/lang/String;"); kPrivateDataCloseMethod = GetMethodID(env, kPrivateDataClass, "close", "()V"); kPrivateDataGetNextMethod = GetMethodID(env, kStreamPrivateDataClass, "getNext", "(J)I"); @@ -305,9 +347,11 @@ jint JNI_OnLoad(JavaVM* vm, void* reserved) { void JNI_OnUnload(JavaVM* vm, void* reserved) { JNIEnv* env; vm->GetEnv(reinterpret_cast(&env), JNI_VERSION); + env->DeleteGlobalRef(kObjectClass); env->DeleteGlobalRef(kRuntimeExceptionClass); env->DeleteGlobalRef(kPrivateDataClass); env->DeleteGlobalRef(kCDataExceptionClass); + env->DeleteGlobalRef(kStreamPrivateDataClass); } /* diff --git a/java/c/src/main/java/org/apache/arrow/c/ArrayStreamExporter.java b/java/c/src/main/java/org/apache/arrow/c/ArrayStreamExporter.java index 81f495683590c..2c5ca08e717fd 100644 --- a/java/c/src/main/java/org/apache/arrow/c/ArrayStreamExporter.java +++ b/java/c/src/main/java/org/apache/arrow/c/ArrayStreamExporter.java @@ -44,13 +44,12 @@ final class ArrayStreamExporter { static class ExportedArrayStreamPrivateData implements PrivateData { final BufferAllocator allocator; final ArrowReader reader; - int nextDictionary; + // Read by the JNI side for get_last_error byte[] lastError; ExportedArrayStreamPrivateData(BufferAllocator allocator, ArrowReader reader) { this.allocator = allocator; this.reader = reader; - this.nextDictionary = 0; } private int setLastError(Throwable err) { @@ -97,7 +96,7 @@ public void close() { try { reader.close(); } catch (IOException e) { - // XXX: C Data Interface gives us no way to signal this to the caller, + // XXX: C Data Interface gives us no way to signal errors to the caller, // but the JNI side will catch this and log an error. throw new RuntimeException(e); } diff --git a/java/c/src/main/java/org/apache/arrow/c/Data.java b/java/c/src/main/java/org/apache/arrow/c/Data.java index 523a4c555a120..9ee5a6c757cab 100644 --- a/java/c/src/main/java/org/apache/arrow/c/Data.java +++ b/java/c/src/main/java/org/apache/arrow/c/Data.java @@ -332,7 +332,7 @@ public static VectorSchemaRoot importVectorSchemaRoot(BufferAllocator allocator, * @param stream C stream interface struct to import. * @return Imported reader */ - public static ArrowReader importStream(BufferAllocator allocator, ArrowArrayStream stream) { + public static ArrowReader importArrayStream(BufferAllocator allocator, ArrowArrayStream stream) { return new ArrowArrayStreamReader(allocator, stream); } } diff --git a/java/c/src/test/java/org/apache/arrow/c/StreamTest.java b/java/c/src/test/java/org/apache/arrow/c/StreamTest.java index 57de88de5431c..06401687a5a66 100644 --- a/java/c/src/test/java/org/apache/arrow/c/StreamTest.java +++ b/java/c/src/test/java/org/apache/arrow/c/StreamTest.java @@ -173,7 +173,7 @@ public void roundtripDictionary() throws Exception { @Test public void importReleasedStream() { try (final ArrowArrayStream stream = ArrowArrayStream.allocateNew(allocator)) { - Exception e = assertThrows(IllegalStateException.class, () -> Data.importStream(allocator, stream)); + Exception e = assertThrows(IllegalStateException.class, () -> Data.importArrayStream(allocator, stream)); assertThat(e).hasMessageContaining("Cannot import released ArrowArrayStream"); } } @@ -190,7 +190,7 @@ public boolean loadNextBatch() throws IOException { } }; final ArrowArrayStream stream = ArrowArrayStream.allocateNew(allocator)) { Data.exportArrayStream(allocator, source, stream); - try (final ArrowReader reader = Data.importStream(allocator, stream)) { + try (final ArrowReader reader = Data.importArrayStream(allocator, stream)) { assertThat(reader.getVectorSchemaRoot().getSchema()).isEqualTo(schema); final IOException e = assertThrows(IOException.class, reader::loadNextBatch); assertThat(e).hasMessageContaining("Failed to load batch!"); @@ -211,7 +211,7 @@ protected Schema readSchema() { } }; final ArrowArrayStream stream = ArrowArrayStream.allocateNew(allocator)) { Data.exportArrayStream(allocator, source, stream); - try (final ArrowReader reader = Data.importStream(allocator, stream)) { + try (final ArrowReader reader = Data.importArrayStream(allocator, stream)) { final IOException e = assertThrows(IOException.class, reader::getVectorSchemaRoot); assertThat(e).hasMessageContaining("Failed to read schema!"); assertThat(e).hasMessageContaining("[errno "); @@ -227,7 +227,7 @@ void roundtrip(Schema schema, List batches, DictionaryProvider final VectorLoader loader = new VectorLoader(root); Data.exportArrayStream(allocator, source, stream); - try (final ArrowReader reader = Data.importStream(allocator, stream)) { + try (final ArrowReader reader = Data.importArrayStream(allocator, stream)) { assertThat(reader.getVectorSchemaRoot().getSchema()).isEqualTo(schema); for (ArrowRecordBatch batch : batches) { From 3e81c769ae8f61ad0ab76b09569b886f9a984e9f Mon Sep 17 00:00:00 2001 From: David Li Date: Thu, 30 Jun 2022 10:18:47 -0400 Subject: [PATCH 3/5] Extend Python/Java C Data integration tests --- java/c/src/test/python/integration_tests.py | 51 +++++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/java/c/src/test/python/integration_tests.py b/java/c/src/test/python/integration_tests.py index c1f130f21d47a..33ff1cf4a9af5 100644 --- a/java/c/src/test/python/integration_tests.py +++ b/java/c/src/test/python/integration_tests.py @@ -84,6 +84,13 @@ def java_to_python_record_batch(self, root): ptr_array), self.java_c.ArrowSchema.wrap(ptr_schema)) return pa.RecordBatch._import_from_c(ptr_array, ptr_schema) + def java_to_python_reader(self, reader): + c_stream = ffi.new("struct ArrowArrayStream*") + ptr_stream = int(ffi.cast("uintptr_t", c_stream)) + self.java_c.Data.exportArrayStream(self.java_allocator, reader, + self.java_c.ArrowArrayStream.wrap(ptr_stream)) + return pa.RecordBatchReader._import_from_c(ptr_stream) + def python_to_java_field(self, field): c_schema = self.java_c.ArrowSchema.allocateNew(self.java_allocator) field._export_to_c(c_schema.memoryAddress()) @@ -102,6 +109,11 @@ def python_to_java_record_batch(self, record_batch): c_array.memoryAddress(), c_schema.memoryAddress()) return self.java_c.Data.importVectorSchemaRoot(self.java_allocator, c_array, c_schema, None) + def python_to_java_reader(self, reader): + c_stream = self.java_c.ArrowArrayStream.allocateNew(self.java_allocator) + reader._export_to_c(c_stream.memoryAddress()) + return self.java_c.Data.importArrayStream(self.java_allocator, c_stream) + def close(self): self.java_allocator.close() @@ -151,6 +163,17 @@ def round_trip_record_batch(self, rb_generator): expected = rb_generator() self.assertEqual(expected, new_rb) + def round_trip_reader(self, schema, batches): + reader = pa.RecordBatchReader.from_batches(schema, batches) + + java_reader = self.bridge.python_to_java_reader(reader) + del reader + py_reader = self.bridge.java_to_python_reader(java_reader) + del java_reader + + actual = list(py_reader) + self.assertEqual(batches, actual) + def test_string_array(self): self.round_trip_array(lambda: pa.array([None, "a", "bb", "ccc"])) @@ -217,6 +240,34 @@ def test_record_batch_with_list(self): self.round_trip_record_batch( lambda: pa.RecordBatch.from_arrays(data, ['f0', 'f1', 'f2', 'f3'])) + def test_reader_roundtrip(self): + schema = pa.schema([("ints", pa.int64()), ("strs", pa.string())]) + data = [ + pa.record_batch([[1, 2, 3, None], + ["a", "bc", None, ""]], + schema=schema), + pa.record_batch([[None, 4, 5, 6], + [None, "", "def", "g"]], + schema=schema), + ] + self.round_trip_reader(schema, data) + + def test_reader_complex_roundtrip(self): + schema = pa.schema([ + ("str_dict", pa.dictionary(pa.int8(), pa.string())), + ("int_list", pa.list_(pa.int64())), + ]) + dictionary = pa.array(["a", "bc", None]) + data = [ + pa.record_batch([pa.DictionaryArray.from_arrays([0, 2], dictionary), + [[1, 2, 3], None]], + schema=schema), + pa.record_batch([pa.DictionaryArray.from_arrays([None, 1], dictionary), + [[], [4]]], + schema=schema), + ] + self.round_trip_reader(schema, data) + if __name__ == '__main__': setup_jvm() From dad5d3c1f97dc0f20a0d3fb0cbd6a0e8401672d0 Mon Sep 17 00:00:00 2001 From: David Li Date: Thu, 30 Jun 2022 11:31:59 -0400 Subject: [PATCH 4/5] Safeguard against JVM shutdown --- cpp/src/arrow/c/bridge.cc | 11 +++++- java/c/src/main/cpp/jni_wrapper.cc | 51 +++++++++++++++++----------- python/pyarrow/includes/libarrow.pxd | 1 + python/pyarrow/ipc.pxi | 13 +++++-- 4 files changed, 53 insertions(+), 23 deletions(-) diff --git a/cpp/src/arrow/c/bridge.cc b/cpp/src/arrow/c/bridge.cc index f2671b5016122..de531dbc6078d 100644 --- a/cpp/src/arrow/c/bridge.cc +++ b/cpp/src/arrow/c/bridge.cc @@ -1748,7 +1748,9 @@ class ArrayStreamBatchReader : public RecordBatchReader { } ~ArrayStreamBatchReader() { - ArrowArrayStreamRelease(&stream_); + if (!ArrowArrayStreamIsReleased(&stream_)) { + ArrowArrayStreamRelease(&stream_); + } DCHECK(ArrowArrayStreamIsReleased(&stream_)); } @@ -1766,6 +1768,13 @@ class ArrayStreamBatchReader : public RecordBatchReader { } } + Status Close() override { + if (!ArrowArrayStreamIsReleased(&stream_)) { + ArrowArrayStreamRelease(&stream_); + } + return Status::OK(); + } + private: std::shared_ptr CacheSchema() const { if (!schema_) { diff --git a/java/c/src/main/cpp/jni_wrapper.cc b/java/c/src/main/cpp/jni_wrapper.cc index ffe4d2ba715ca..fea53aff49f40 100644 --- a/java/c/src/main/cpp/jni_wrapper.cc +++ b/java/c/src/main/cpp/jni_wrapper.cc @@ -19,6 +19,7 @@ #include #include +#include #include #include #include @@ -162,17 +163,25 @@ void release_exported(T* base) { InnerPrivateData* private_data = reinterpret_cast(base->private_data); - JNIEnvGuard guard(private_data->vm_); - JNIEnv* env = guard.env(); - - env->CallObjectMethod(private_data->j_private_data_, kPrivateDataCloseMethod); - if (env->ExceptionCheck()) { - // Can't signal this to caller, so log and then try to free things - // as best we can - env->ExceptionDescribe(); - env->ExceptionClear(); + // It is possible for the JVM to be shut down when this is called; + // guard against that. Example: Python code using JPype may shut + // down the JVM before releasing the stream. + try { + JNIEnvGuard guard(private_data->vm_); + JNIEnv* env = guard.env(); + + env->CallObjectMethod(private_data->j_private_data_, kPrivateDataCloseMethod); + if (env->ExceptionCheck()) { + // Can't signal this to caller, so log and then try to free things + // as best we can + env->ExceptionDescribe(); + env->ExceptionClear(); + } + env->DeleteGlobalRef(private_data->j_private_data_); + } catch (const JniPendingException& e) { + std::cerr << "WARNING: Failed to release Java C Data resource: " << e.what() + << std::endl; } - env->DeleteGlobalRef(private_data->j_private_data_); delete private_data; base->private_data = nullptr; @@ -280,17 +289,21 @@ void ArrowArrayStreamRelease(ArrowArrayStream* stream) { InnerPrivateData* private_data = reinterpret_cast(stream->private_data); - JNIEnvGuard guard(private_data->vm_); - JNIEnv* env = guard.env(); + // It is possible for the JVM to be shut down (see above) + try { + JNIEnvGuard guard(private_data->vm_); + JNIEnv* env = guard.env(); - env->CallObjectMethod(private_data->j_private_data_, kPrivateDataCloseMethod); - if (env->ExceptionCheck()) { - // Can't signal this to caller, so log and then try to free things - // as best we can - env->ExceptionDescribe(); - env->ExceptionClear(); + env->CallObjectMethod(private_data->j_private_data_, kPrivateDataCloseMethod); + if (env->ExceptionCheck()) { + env->ExceptionDescribe(); + env->ExceptionClear(); + } + env->DeleteGlobalRef(private_data->j_private_data_); + } catch (const JniPendingException& e) { + std::cerr << "WARNING: Failed to release Java ArrowArrayStream: " << e.what() + << std::endl; } - env->DeleteGlobalRef(private_data->j_private_data_); delete private_data; stream->private_data = nullptr; diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index 9e43eb4eb9c76..ee5446fd57042 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -869,6 +869,7 @@ cdef extern from "arrow/api.h" namespace "arrow" nogil: cdef cppclass CRecordBatchReader" arrow::RecordBatchReader": shared_ptr[CSchema] schema() + CStatus Close() CStatus ReadNext(shared_ptr[CRecordBatch]* batch) CResult[shared_ptr[CTable]] ToTable() diff --git a/python/pyarrow/ipc.pxi b/python/pyarrow/ipc.pxi index f0297ff004d03..b5cbbfb62cf83 100644 --- a/python/pyarrow/ipc.pxi +++ b/python/pyarrow/ipc.pxi @@ -598,7 +598,7 @@ class _ReadPandasMixin: cdef class RecordBatchReader(_Weakrefable): """Base class for reading stream of record batches. - Record batch readers function as iterators of record batches that also + Record batch readers function as iterators of record batches that also provide the schema (without the need to get any batches). Warnings @@ -608,7 +608,7 @@ cdef class RecordBatchReader(_Weakrefable): Notes ----- - To import and export using the Arrow C stream interface, use the + To import and export using the Arrow C stream interface, use the ``_import_from_c`` and ``_export_from_c`` methods. However, keep in mind this interface is intended for expert users. @@ -702,11 +702,18 @@ cdef class RecordBatchReader(_Weakrefable): read_pandas = _ReadPandasMixin.read_pandas + def close(self): + """ + Release any resources associated with the reader. + """ + with nogil: + check_status(self.reader.get().Close()) + def __enter__(self): return self def __exit__(self, exc_type, exc_val, exc_tb): - pass + self.close() def _export_to_c(self, out_ptr): """ From fec64cba5b8436b7bb81692b9a5a7a93c746f2e9 Mon Sep 17 00:00:00 2001 From: David Li Date: Thu, 30 Jun 2022 11:38:23 -0400 Subject: [PATCH 5/5] Update docs with example --- .../source/python/integration/python_java.rst | 163 ++++++++++++++++-- 1 file changed, 145 insertions(+), 18 deletions(-) diff --git a/docs/source/python/integration/python_java.rst b/docs/source/python/integration/python_java.rst index c191682721caa..a524fe9b48bb7 100644 --- a/docs/source/python/integration/python_java.rst +++ b/docs/source/python/integration/python_java.rst @@ -29,7 +29,7 @@ marshaling and unmarshaling data. The article takes for granted that you have a ``Python`` environment with ``pyarrow`` correctly installed and a ``Java`` environment with - ``arrow`` library correctly installed. + ``arrow`` library correctly installed. The ``Arrow Java`` version must have been compiled with ``mvn -Parrow-c-data`` to ensure CData exchange support is enabled. See `Python Install Instructions `_ @@ -53,7 +53,7 @@ We would save such class in the ``Simple.java`` file and proceed with compiling it to ``Simple.class`` using ``javac Simple.java``. Once the ``Simple.class`` file is created we can use the class -from Python using the +from Python using the `JPype `_ library which enables a Java runtime within the Python interpreter. @@ -64,11 +64,11 @@ enables a Java runtime within the Python interpreter. $ pip install jpype1 The most basic thing we can do with our ``Simple`` class is to -use the ``Simple.getNumber`` method from Python and see +use the ``Simple.getNumber`` method from Python and see if it will return the result. To do so, we can create a ``simple.py`` file which uses ``jpype`` to -import the ``Simple`` class from ``Simple.class`` file and invoke +import the ``Simple`` class from ``Simple.class`` file and invoke the ``Simple.getNumber`` method: .. code-block:: python @@ -87,7 +87,7 @@ to access the ``Java`` method and print the expected result: .. code-block:: console - $ python simple.py + $ python simple.py 4 Java to Python using pyarrow.jvm @@ -132,7 +132,7 @@ class, named ``FillTen.java`` } This class provides a public ``createArray`` method that anyone can invoke -to get back an array containing numbers from 1 to 10. +to get back an array containing numbers from 1 to 10. Given that this class now has a dependency on a bunch of packages, compiling it with ``javac`` is not enough anymore. We need to create @@ -142,7 +142,7 @@ a dedicated ``pom.xml`` file where we can collect the dependencies: 4.0.0 - + org.apache.arrow.py2java FillTen 1 @@ -150,7 +150,7 @@ a dedicated ``pom.xml`` file where we can collect the dependencies: 8 8 - + @@ -170,7 +170,7 @@ a dedicated ``pom.xml`` file where we can collect the dependencies: arrow-vector 8.0.0 pom - + org.apache.arrow arrow-c-data @@ -182,22 +182,22 @@ a dedicated ``pom.xml`` file where we can collect the dependencies: Once the ``FillTen.java`` file with the class is created as ``src/main/java/FillTen.java`` we can use ``maven`` to -compile the project with ``mvn package`` and get it +compile the project with ``mvn package`` and get it available in the ``target`` directory. .. code-block:: console $ mvn package [INFO] Scanning for projects... - [INFO] + [INFO] [INFO] ------------------< org.apache.arrow.py2java:FillTen >------------------ [INFO] Building FillTen 1 [INFO] --------------------------------[ jar ]--------------------------------- - [INFO] + [INFO] [INFO] --- maven-compiler-plugin:3.1:compile (default-compile) @ FillTen --- [INFO] Changes detected - recompiling the module! [INFO] Compiling 1 source file to /experiments/java2py/target/classes - [INFO] + [INFO] [INFO] --- maven-jar-plugin:2.4:jar (default-jar) @ FillTen --- [INFO] Building jar: /experiments/java2py/target/FillTen-1.jar [INFO] ------------------------------------------------------------------------ @@ -215,11 +215,11 @@ We can use ``maven`` to collect all dependencies and make them available in a si $ mvn org.apache.maven.plugins:maven-dependency-plugin:2.7:copy-dependencies -DoutputDirectory=dependencies [INFO] Scanning for projects... - [INFO] + [INFO] [INFO] ------------------< org.apache.arrow.py2java:FillTen >------------------ [INFO] Building FillTen 1 [INFO] --------------------------------[ jar ]--------------------------------- - [INFO] + [INFO] [INFO] --- maven-dependency-plugin:2.7:copy-dependencies (default-cli) @ FillTen --- [INFO] Copying jsr305-3.0.2.jar to /experiments/java2py/dependencies/jsr305-3.0.2.jar [INFO] Copying netty-common-4.1.72.Final.jar to /experiments/java2py/dependencies/netty-common-4.1.72.Final.jar @@ -246,9 +246,9 @@ We can use ``maven`` to collect all dependencies and make them available in a si Instead of manually collecting dependencies, you could also rely on the ``maven-assembly-plugin`` to build a single ``jar`` with all dependencies. -Once our package and all its depdendencies are available, +Once our package and all its depdendencies are available, we can invoke it from ``fillten_pyarrowjvm.py`` script that will -import the ``FillTen`` class and print out the result of invoking ``FillTen.createArray`` +import the ``FillTen`` class and print out the result of invoking ``FillTen.createArray`` .. code-block:: python @@ -291,7 +291,7 @@ Running the python script will lead to two lines getting printed: The first line is the raw result of invoking the ``FillTen.createArray`` method. The resulting object is a proxy to the actual Java object, so it's not really a pyarrow -Array, it will lack most of its capabilities and methods. +Array, it will lack most of its capabilities and methods. That's why we subsequently use ``pyarrow.jvm.array`` to convert it to an actual ``pyarrow`` array. That allows us to treat it like any other ``pyarrow`` array. The result is the second line in the output where the array is correctly reported @@ -441,3 +441,130 @@ values printed by the Python script have been properly changed by the Java code: 9, 10 ] + +We can also use the C Stream Interface to exchange +:py:class:`pyarrow.RecordBatchReader`s between Java and Python. We'll +use this Java class as a demo, which lets you read an Arrow IPC file +via Java's implementation, or write data to a JSON file: + +.. code-block:: java + + import java.io.File; + import java.nio.file.Files; + import java.nio.file.Paths; + + import org.apache.arrow.c.ArrowArrayStream; + import org.apache.arrow.c.Data; + import org.apache.arrow.memory.BufferAllocator; + import org.apache.arrow.memory.RootAllocator; + import org.apache.arrow.vector.ipc.ArrowFileReader; + import org.apache.arrow.vector.ipc.ArrowReader; + import org.apache.arrow.vector.ipc.JsonFileWriter; + + public class PythonInteropDemo implements AutoCloseable { + private final BufferAllocator allocator; + + public PythonInteropDemo() { + this.allocator = new RootAllocator(); + } + + public void exportStream(String path, long cStreamPointer) throws Exception { + try (final ArrowArrayStream stream = ArrowArrayStream.wrap(cStreamPointer)) { + ArrowFileReader reader = new ArrowFileReader(Files.newByteChannel(Paths.get(path)), allocator); + Data.exportArrayStream(allocator, reader, stream); + } + } + + public void importStream(String path, long cStreamPointer) throws Exception { + try (final ArrowArrayStream stream = ArrowArrayStream.wrap(cStreamPointer); + final ArrowReader input = Data.importArrayStream(allocator, stream); + JsonFileWriter writer = new JsonFileWriter(new File(path))) { + writer.start(input.getVectorSchemaRoot().getSchema(), input); + while (input.loadNextBatch()) { + writer.write(input.getVectorSchemaRoot()); + } + } + } + + @Override + public void close() throws Exception { + allocator.close(); + } + } + +On the Python side, we'll use JPype as before, except this time we'll +send RecordBatchReaders back and forth: + +.. code-block:: python + + import tempfile + + import jpype + import jpype.imports + from jpype.types import * + + # Init the JVM and make demo class available to Python. + jpype.startJVM(classpath=["./dependencies/*", "./target/*"]) + PythonInteropDemo = JClass("PythonInteropDemo") + demo = PythonInteropDemo() + + # Create a Python record batch reader + import pyarrow as pa + schema = pa.schema([ + ("ints", pa.int64()), + ("strs", pa.string()) + ]) + batches = [ + pa.record_batch([ + [0, 2, 4, 8], + ["a", "b", "c", None], + ], schema=schema), + pa.record_batch([ + [None, 32, 64, None], + ["e", None, None, "h"], + ], schema=schema), + ] + reader = pa.RecordBatchReader.from_batches(schema, batches) + + from pyarrow.cffi import ffi as arrow_c + + # Export the Python reader through C Data + c_stream = arrow_c.new("struct ArrowArrayStream*") + c_stream_ptr = int(arrow_c.cast("uintptr_t", c_stream)) + reader._export_to_c(c_stream_ptr) + + # Send reader to the Java function that writes a JSON file + with tempfile.NamedTemporaryFile() as temp: + demo.importStream(temp.name, c_stream_ptr) + + # Read the JSON file back + with open(temp.name) as source: + print("JSON file written by Java:") + print(source.read()) + + + # Write an Arrow IPC file for Java to read + with tempfile.NamedTemporaryFile() as temp: + with pa.ipc.new_file(temp.name, schema) as sink: + for batch in batches: + sink.write_batch(batch) + + demo.exportStream(temp.name, c_stream_ptr) + with pa.RecordBatchReader._import_from_c(c_stream_ptr) as source: + print("IPC file read by Java:") + print(source.read_all()) + +.. code-block:: console + + $ mvn package + $ mvn org.apache.maven.plugins:maven-dependency-plugin:2.7:copy-dependencies -DoutputDirectory=dependencies + $ python demo.py + JSON file written by Java: + {"schema":{"fields":[{"name":"ints","nullable":true,"type":{"name":"int","bitWidth":64,"isSigned":true},"children":[]},{"name":"strs","nullable":true,"type":{"name":"utf8"},"children":[]}]},"batches":[{"count":4,"columns":[{"name":"ints","count":4,"VALIDITY":[1,1,1,1],"DATA":["0","2","4","8"]},{"name":"strs","count":4,"VALIDITY":[1,1,1,0],"OFFSET":[0,1,2,3,3],"DATA":["a","b","c",""]}]},{"count":4,"columns":[{"name":"ints","count":4,"VALIDITY":[0,1,1,0],"DATA":["0","32","64","0"]},{"name":"strs","count":4,"VALIDITY":[1,0,0,1],"OFFSET":[0,1,1,1,2],"DATA":["e","","","h"]}]}]} + IPC file read by Java: + pyarrow.Table + ints: int64 + strs: string + ---- + ints: [[0,2,4,8],[null,32,64,null]] + strs: [["a","b","c",null],["e",null,null,"h"]]