diff --git a/common/src/main/scala/org/apache/comet/vector/ExportedBatch.scala b/common/src/main/scala/org/apache/comet/vector/ExportedBatch.scala
new file mode 100644
index 0000000000..2e97a0dcc6
--- /dev/null
+++ b/common/src/main/scala/org/apache/comet/vector/ExportedBatch.scala
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet.vector
+
+import org.apache.arrow.c.ArrowArray
+import org.apache.arrow.c.ArrowSchema
+
+/**
+ * A wrapper class to hold the exported Arrow arrays and schemas.
+ *
+ * @param batch
+ *   an array containing the number of rows followed by pairs of memory addresses in the format
+ *   of (address of Arrow array, address of Arrow schema)
+ * @param arrowSchemas
+ *   the exported Arrow schemas, which need to be deallocated after being moved by the native executor
+ * @param arrowArrays
+ *   the exported Arrow arrays, which need to be deallocated after being moved by the native executor
+ */
+case class ExportedBatch(
+    batch: Array[Long],
+    arrowSchemas: Array[ArrowSchema],
+    arrowArrays: Array[ArrowArray]) {
+  def close(): Unit = {
+    arrowSchemas.foreach(_.close())
+    arrowArrays.foreach(_.close())
+  }
+}
diff --git a/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala b/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala
index 89f79c9cdf..eed8fd05b1 100644
--- a/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala
+++ b/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala
@@ -44,43 +44,53 @@ class NativeUtil {
    * @param batch
    *   the input Comet columnar batch
    * @return
-   *   a list containing number of rows + pairs of memory addresses in the format of (address of
-   *   Arrow array, address of Arrow schema)
+   *   an ExportedBatch object whose batch array contains the number of rows followed by pairs
+   *   of memory addresses in the format of (address of Arrow array, address of Arrow schema)
    */
-  def exportBatch(batch: ColumnarBatch): Array[Long] = {
+  def exportBatch(batch: ColumnarBatch): ExportedBatch = {
     val exportedVectors = mutable.ArrayBuffer.empty[Long]
     exportedVectors += batch.numRows()
 
+    // Run checks prior to exporting the batch
+    (0 until batch.numCols()).foreach { index =>
+      val c = batch.column(index)
+      if (!c.isInstanceOf[CometVector]) {
+        batch.close()
+        throw new SparkException(
+          "Comet execution only takes Arrow Arrays, but got " +
+            s"${c.getClass}")
+      }
+    }
+
+    val arrowSchemas = mutable.ArrayBuffer.empty[ArrowSchema]
+    val arrowArrays = mutable.ArrayBuffer.empty[ArrowArray]
+
     (0 until batch.numCols()).foreach { index =>
-      batch.column(index) match {
-        case a: CometVector =>
-          val valueVector = a.getValueVector
-
-          val provider = if (valueVector.getField.getDictionary != null) {
-            a.getDictionaryProvider
-          } else {
-            null
-          }
-
-          val arrowSchema = ArrowSchema.allocateNew(allocator)
-          val arrowArray = ArrowArray.allocateNew(allocator)
-          Data.exportVector(
-            allocator,
-            getFieldVector(valueVector, "export"),
-            provider,
-            arrowArray,
-            arrowSchema)
-
-          exportedVectors += arrowArray.memoryAddress()
-          exportedVectors += arrowSchema.memoryAddress()
-        case c =>
-          throw new SparkException(
-            "Comet execution only takes Arrow Arrays, but got " +
-              s"${c.getClass}")
+      val cometVector = batch.column(index).asInstanceOf[CometVector]
+      val valueVector = cometVector.getValueVector
+
+      val provider = if (valueVector.getField.getDictionary != null) {
+        cometVector.getDictionaryProvider
+      } else {
+        null
       }
+
+      val arrowSchema = ArrowSchema.allocateNew(allocator)
+      val arrowArray = ArrowArray.allocateNew(allocator)
+      arrowSchemas += arrowSchema
+      arrowArrays += arrowArray
+      Data.exportVector(
+        allocator,
+        getFieldVector(valueVector, "export"),
+        provider,
+        arrowArray,
+        arrowSchema)
+
+      exportedVectors += arrowArray.memoryAddress()
+      exportedVectors += arrowSchema.memoryAddress()
     }
 
-    exportedVectors.toArray
+    ExportedBatch(exportedVectors.toArray, arrowSchemas.toArray, arrowArrays.toArray)
   }
 
   /**
diff --git a/common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala b/common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala
index 8d6a633431..207474286e 100644
--- a/common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala
+++ b/common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala
@@ -215,12 +215,8 @@ object Utils {
     val writer = new ArrowStreamWriter(root, provider, Channels.newChannel(out))
     writer.start()
     writer.writeBatch()
-    root.clear()
-    writer.end()
-
-    out.flush()
-    out.close()
+    writer.close()
 
     if (out.size() > 0) {
       (batch.numRows(), cbbos.toChunkedByteBuffer)
diff --git a/spark/src/main/java/org/apache/comet/CometBatchIterator.java b/spark/src/main/java/org/apache/comet/CometBatchIterator.java
index 33603290ce..eb7506b889 100644
--- a/spark/src/main/java/org/apache/comet/CometBatchIterator.java
+++ b/spark/src/main/java/org/apache/comet/CometBatchIterator.java
@@ -23,6 +23,7 @@
 
 import org.apache.spark.sql.vectorized.ColumnarBatch;
 
+import org.apache.comet.vector.ExportedBatch;
 import org.apache.comet.vector.NativeUtil;
 
 /**
@@ -34,9 +35,12 @@ public class CometBatchIterator {
   final Iterator<ColumnarBatch> input;
   final NativeUtil nativeUtil;
 
+  private ExportedBatch lastBatch;
+
   CometBatchIterator(Iterator<ColumnarBatch> input, NativeUtil nativeUtil) {
     this.input = input;
     this.nativeUtil = nativeUtil;
+    this.lastBatch = null;
   }
 
   /**
@@ -45,12 +49,27 @@ public class CometBatchIterator {
    * indicating the end of the iterator.
    */
   public long[] next() {
+    // The native side has already copied the contents of ArrowSchema and ArrowArray, so we
+    // deallocate the ArrowSchema and ArrowArray base structures allocated in the JVM.
+    if (lastBatch != null) {
+      lastBatch.close();
+      lastBatch = null;
+    }
+
     boolean hasBatch = input.hasNext();
 
     if (!hasBatch) {
       return new long[] {-1};
     }
 
-    return nativeUtil.exportBatch(input.next());
+    lastBatch = nativeUtil.exportBatch(input.next());
+    return lastBatch.batch();
+  }
+
+  public void close() {
+    if (lastBatch != null) {
+      lastBatch.close();
+      lastBatch = null;
+    }
   }
 }
diff --git a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala
index 29eb2f0ca9..f1e77fb5d1 100644
--- a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala
+++ b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala
@@ -159,6 +159,8 @@ class CometExecIterator(
     }
     nativeLib.releasePlan(plan)
 
+    cometBatchIterators.foreach(_.close())
+
     // The allocator thinks the exported ArrowArray and ArrowSchema structs are not released,
     // so it will report:
     // Caused by: java.lang.IllegalStateException: Memory was leaked by query.
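For reviewers unfamiliar with the Arrow C Data Interface ownership rules this change relies on, here is a minimal, self-contained Scala sketch of the export/move/close lifecycle that `ExportedBatch` now tracks. It is illustrative only: the `ExportLifecycleSketch` object, the `demo` vector, and the in-JVM `Data.importVector` call are stand-ins for Comet's native executor, not code from this PR.

```scala
import org.apache.arrow.c.{ArrowArray, ArrowSchema, Data}
import org.apache.arrow.memory.RootAllocator
import org.apache.arrow.vector.IntVector

object ExportLifecycleSketch {
  def main(args: Array[String]): Unit = {
    val allocator = new RootAllocator()

    // A small vector standing in for one column of a Comet batch.
    val vector = new IntVector("demo", allocator)
    vector.allocateNew(3)
    (0 until 3).foreach(i => vector.setSafe(i, i * 10))
    vector.setValueCount(3)

    // Allocate the C Data Interface structs in JVM memory and export the
    // column into them, as exportBatch does per column.
    val arrowSchema = ArrowSchema.allocateNew(allocator)
    val arrowArray = ArrowArray.allocateNew(allocator)
    Data.exportVector(allocator, vector, null, arrowArray, arrowSchema)

    // Only these raw addresses cross the JNI boundary.
    println(s"array @ ${arrowArray.memoryAddress()}, schema @ ${arrowSchema.memoryAddress()}")

    // The consumer (here the same JVM, standing in for the native executor)
    // imports the structs, which moves the Arrow buffers out of them.
    val imported = Data.importVector(allocator, arrowArray, arrowSchema, null)
    println(s"imported ${imported.getValueCount} values")

    // After the move, the empty struct shells still hold allocator memory.
    // Freeing them is what ExportedBatch.close() does for a whole batch;
    // skipping this step is what the leak report complained about.
    arrowSchema.close()
    arrowArray.close()

    imported.close()
    vector.close()
    allocator.close() // closes cleanly: nothing outstanding
  }
}
```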
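And a sketch of the contract the reworked `CometBatchIterator` offers its caller. The `DrainSketch` object and the `consume` callback are hypothetical (in the real code path the native executor drives `next()` over JNI); the point is that each `next()` call frees the structs of the previous batch, and `close()` frees whatever the last `next()` exported.

```scala
import org.apache.comet.CometBatchIterator

object DrainSketch {
  // Hypothetical driver loop; `consume` stands in for the native side moving
  // the exported ArrowArray/ArrowSchema contents out of the JVM structs.
  def drain(iter: CometBatchIterator, consume: Array[Long] => Unit): Unit = {
    var addrs = iter.next() // exports the first batch
    while (addrs(0) != -1) {
      consume(addrs) // the native move must happen before the next call...
      addrs = iter.next() // ...because this closes the previous batch's structs
    }
    iter.close() // frees the structs of the final exported batch
  }
}
```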