From 689f86f761009d2220d9679102770a8763e55573 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Wed, 19 Jul 2017 15:22:24 +0900 Subject: [PATCH 1/8] Import ArrowUtils and use it. --- .../sql/execution/arrow/ArrowConverters.scala | 32 +---- .../sql/execution/arrow/ArrowUtils.scala | 109 ++++++++++++++++++ .../arrow/ArrowConvertersSuite.scala | 2 +- .../sql/execution/arrow/ArrowUtilsSuite.scala | 65 +++++++++++ 4 files changed, 177 insertions(+), 31 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowUtilsSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala index 6af5c73422377..c913efe52a41c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -70,34 +70,6 @@ private[sql] object ArrowPayload { private[sql] object ArrowConverters { - /** - * Map a Spark DataType to ArrowType. - */ - private[arrow] def sparkTypeToArrowType(dataType: DataType): ArrowType = { - dataType match { - case BooleanType => ArrowType.Bool.INSTANCE - case ShortType => new ArrowType.Int(8 * ShortType.defaultSize, true) - case IntegerType => new ArrowType.Int(8 * IntegerType.defaultSize, true) - case LongType => new ArrowType.Int(8 * LongType.defaultSize, true) - case FloatType => new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE) - case DoubleType => new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE) - case ByteType => new ArrowType.Int(8, true) - case StringType => ArrowType.Utf8.INSTANCE - case BinaryType => ArrowType.Binary.INSTANCE - case _ => throw new UnsupportedOperationException(s"Unsupported data type: $dataType") - } - } - - /** - * Convert a Spark Dataset schema to Arrow schema. - */ - private[arrow] def schemaToArrowSchema(schema: StructType): Schema = { - val arrowFields = schema.fields.map { f => - new Field(f.name, f.nullable, sparkTypeToArrowType(f.dataType), List.empty[Field].asJava) - } - new Schema(arrowFields.toList.asJava) - } - /** * Maps Iterator from InternalRow to ArrowPayload. Limit ArrowRecordBatch size in ArrowPayload * by setting maxRecordsPerBatch or use 0 to fully consume rowIter. @@ -178,7 +150,7 @@ private[sql] object ArrowConverters { batch: ArrowRecordBatch, schema: StructType, allocator: BufferAllocator): Array[Byte] = { - val arrowSchema = ArrowConverters.schemaToArrowSchema(schema) + val arrowSchema = ArrowUtils.toArrowSchema(schema) val root = VectorSchemaRoot.create(arrowSchema, allocator) val out = new ByteArrayOutputStream() val writer = new ArrowFileWriter(root, null, Channels.newChannel(out)) @@ -410,7 +382,7 @@ private[arrow] object ColumnWriter { * Create an Arrow ColumnWriter given the type and ordinal of row. 
*/ def apply(dataType: DataType, ordinal: Int, allocator: BufferAllocator): ColumnWriter = { - val dtype = ArrowConverters.sparkTypeToArrowType(dataType) + val dtype = ArrowUtils.toArrowType(dataType) dataType match { case BooleanType => new BooleanColumnWriter(dtype, ordinal, allocator) case ShortType => new ShortColumnWriter(dtype, ordinal, allocator) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala new file mode 100644 index 0000000000000..2caf1ef02909a --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.arrow + +import scala.collection.JavaConverters._ + +import org.apache.arrow.memory.RootAllocator +import org.apache.arrow.vector.types.FloatingPointPrecision +import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema} + +import org.apache.spark.sql.types._ + +object ArrowUtils { + + val rootAllocator = new RootAllocator(Long.MaxValue) + + // todo: support more types. 
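+  // Illustrative note: toArrowType(IntegerType) yields a signed 32-bit ArrowType.Int and
+  // fromArrowType inverts that mapping; toArrowSchema/fromArrowSchema apply the conversion
+  // field by field, so any supported StructType round-trips unchanged (see ArrowUtilsSuite).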
+ + def toArrowType(dt: DataType): ArrowType = dt match { + case BooleanType => ArrowType.Bool.INSTANCE + case ByteType => new ArrowType.Int(8, true) + case ShortType => new ArrowType.Int(8 * 2, true) + case IntegerType => new ArrowType.Int(8 * 4, true) + case LongType => new ArrowType.Int(8 * 8, true) + case FloatType => new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE) + case DoubleType => new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE) + case StringType => ArrowType.Utf8.INSTANCE + case BinaryType => ArrowType.Binary.INSTANCE + case DecimalType.Fixed(precision, scale) => new ArrowType.Decimal(precision, scale) + case _ => throw new UnsupportedOperationException(s"Unsupported data type: ${dt.simpleString}") + } + + def fromArrowType(dt: ArrowType): DataType = dt match { + case ArrowType.Bool.INSTANCE => BooleanType + case int: ArrowType.Int if int.getIsSigned && int.getBitWidth == 8 => ByteType + case int: ArrowType.Int if int.getIsSigned && int.getBitWidth == 8 * 2 => ShortType + case int: ArrowType.Int if int.getIsSigned && int.getBitWidth == 8 * 4 => IntegerType + case int: ArrowType.Int if int.getIsSigned && int.getBitWidth == 8 * 8 => LongType + case float: ArrowType.FloatingPoint + if float.getPrecision() == FloatingPointPrecision.SINGLE => FloatType + case float: ArrowType.FloatingPoint + if float.getPrecision() == FloatingPointPrecision.DOUBLE => DoubleType + case ArrowType.Utf8.INSTANCE => StringType + case ArrowType.Binary.INSTANCE => BinaryType + case d: ArrowType.Decimal => DecimalType(d.getPrecision, d.getScale) + case _ => throw new UnsupportedOperationException(s"Unsupported data type: $dt") + } + + def toArrowField(name: String, dt: DataType, nullable: Boolean): Field = { + dt match { + case ArrayType(elementType, containsNull) => + val fieldType = new FieldType(nullable, ArrowType.List.INSTANCE, null) + new Field(name, fieldType, Seq(toArrowField("element", elementType, containsNull)).asJava) + case StructType(fields) => + val fieldType = new FieldType(nullable, ArrowType.Struct.INSTANCE, null) + new Field(name, fieldType, + fields.map { field => + toArrowField(field.name, field.dataType, field.nullable) + }.toSeq.asJava) + case dataType => + val fieldType = new FieldType(nullable, toArrowType(dataType), null) + new Field(name, fieldType, Seq.empty[Field].asJava) + } + } + + def fromArrowField(field: Field): DataType = { + field.getType match { + case ArrowType.List.INSTANCE => + val elementField = field.getChildren().get(0) + val elementType = fromArrowField(elementField) + ArrayType(elementType, containsNull = elementField.isNullable) + case ArrowType.Struct.INSTANCE => + val fields = field.getChildren().asScala.map { child => + val dt = fromArrowField(child) + StructField(child.getName, dt, child.isNullable) + } + StructType(fields) + case arrowType => fromArrowType(arrowType) + } + } + + def toArrowSchema(schema: StructType): Schema = { + new Schema(schema.map { field => + toArrowField(field.name, field.dataType, field.nullable) + }.asJava) + } + + def fromArrowSchema(schema: Schema): StructType = { + StructType(schema.getFields.asScala.map { field => + val dt = fromArrowField(field) + StructField(field.getName, dt, field.isNullable) + }) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala index 159328cc0d958..55b465578a42d 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala @@ -1202,7 +1202,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { val allocator = new RootAllocator(Long.MaxValue) val jsonReader = new JsonFileReader(jsonFile, allocator) - val arrowSchema = ArrowConverters.schemaToArrowSchema(sparkSchema) + val arrowSchema = ArrowUtils.toArrowSchema(sparkSchema) val jsonSchema = jsonReader.start() Validator.compareSchemas(arrowSchema, jsonSchema) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowUtilsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowUtilsSuite.scala new file mode 100644 index 0000000000000..638619fd39d06 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowUtilsSuite.scala @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.arrow + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.types._ + +class ArrowUtilsSuite extends SparkFunSuite { + + def roundtrip(dt: DataType): Unit = { + dt match { + case schema: StructType => + assert(ArrowUtils.fromArrowSchema(ArrowUtils.toArrowSchema(schema)) === schema) + case _ => + roundtrip(new StructType().add("value", dt)) + } + } + + test("simple") { + roundtrip(BooleanType) + roundtrip(ByteType) + roundtrip(ShortType) + roundtrip(IntegerType) + roundtrip(LongType) + roundtrip(FloatType) + roundtrip(DoubleType) + roundtrip(StringType) + roundtrip(BinaryType) + roundtrip(DecimalType.SYSTEM_DEFAULT) + } + + test("array") { + roundtrip(ArrayType(IntegerType, containsNull = true)) + roundtrip(ArrayType(IntegerType, containsNull = false)) + roundtrip(ArrayType(ArrayType(IntegerType, containsNull = true), containsNull = true)) + roundtrip(ArrayType(ArrayType(IntegerType, containsNull = false), containsNull = true)) + roundtrip(ArrayType(ArrayType(IntegerType, containsNull = true), containsNull = false)) + roundtrip(ArrayType(ArrayType(IntegerType, containsNull = false), containsNull = false)) + } + + test("struct") { + roundtrip(new StructType()) + roundtrip(new StructType().add("i", IntegerType)) + roundtrip(new StructType().add("arr", ArrayType(IntegerType))) + roundtrip(new StructType().add("i", IntegerType).add("arr", ArrayType(IntegerType))) + roundtrip(new StructType().add( + "struct", + new StructType().add("i", IntegerType).add("arr", ArrayType(IntegerType)))) + } +} From 73899b26d10ed3763569f4aa2e836643a5ce941a Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Wed, 19 Jul 2017 14:50:22 +0900 Subject: [PATCH 2/8] Introduce ArrowColumnVector as a reader for Arrow vectors. 
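
ArrowColumnVector wraps an Arrow ValueVector and exposes it through the existing
ColumnVector read API. A minimal usage sketch, mirroring the "int" case in the
ArrowColumnVectorSuite added below (the allocator name "example" is illustrative):

    import org.apache.arrow.vector.NullableIntVector
    import org.apache.spark.sql.execution.arrow.ArrowUtils
    import org.apache.spark.sql.execution.vectorized.ArrowColumnVector
    import org.apache.spark.sql.types.IntegerType

    val allocator = ArrowUtils.rootAllocator.newChildAllocator("example", 0, Long.MaxValue)
    val vector = ArrowUtils.toArrowField("int", IntegerType, nullable = true)
      .createVector(allocator).asInstanceOf[NullableIntVector]
    vector.allocateNew()
    val mutator = vector.getMutator()
    (0 until 10).foreach(i => mutator.setSafe(i, i))
    mutator.setNull(10)
    mutator.setValueCount(11)

    val columnVector = new ArrowColumnVector(vector)  // read-only view over the Arrow data
    assert(columnVector.dataType == IntegerType)
    assert(columnVector.getInt(0) == 0)
    assert(columnVector.isNullAt(10))
    columnVector.close()
    allocator.close()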
--- .../vectorized/ArrowColumnVector.java | 510 ++++++++++++++++++ .../execution/vectorized/ColumnVector.java | 16 +- .../vectorized/ArrowColumnVectorSuite.scala | 396 ++++++++++++++ 3 files changed, 914 insertions(+), 8 deletions(-) create mode 100644 sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java new file mode 100644 index 0000000000000..35a8412085ca1 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java @@ -0,0 +1,510 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.vectorized; + +import org.apache.arrow.vector.*; +import org.apache.arrow.vector.complex.*; +import org.apache.arrow.vector.holders.NullableVarCharHolder; + +import org.apache.spark.memory.MemoryMode; +import org.apache.spark.sql.execution.arrow.ArrowUtils; +import org.apache.spark.sql.types.*; +import org.apache.spark.unsafe.types.UTF8String; + +/** + * A column backed by Apache Arrow. 
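+ * Values are read through the underlying Arrow vector's accessor; the inherited put*
+ * mutation APIs are not supported and throw UnsupportedOperationException, since the
+ * data is written by Arrow rather than by Spark.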
+ */ +public final class ArrowColumnVector extends ColumnVector { + + private ValueVector vector; + private ValueVector.Accessor nulls; + + private NullableBitVector boolData; + private NullableTinyIntVector byteData; + private NullableSmallIntVector shortData; + private NullableIntVector intData; + private NullableBigIntVector longData; + + private NullableFloat4Vector floatData; + private NullableFloat8Vector doubleData; + private NullableDecimalVector decimalData; + + private NullableVarCharVector stringData; + + private NullableVarBinaryVector binaryData; + + private UInt4Vector listOffsetData; + + public ArrowColumnVector(ValueVector vector) { + super(vector.getValueCapacity(), DataTypes.NullType, MemoryMode.OFF_HEAP); + initialize(vector); + } + + @Override + public long nullsNativeAddress() { + throw new RuntimeException("Cannot get native address for arrow column"); + } + + @Override + public long valuesNativeAddress() { + throw new RuntimeException("Cannot get native address for arrow column"); + } + + @Override + public void close() { + if (childColumns != null) { + for (int i = 0; i < childColumns.length; i++) { + childColumns[i].close(); + } + } + vector.close(); + } + + // + // APIs dealing with nulls + // + + @Override + public void putNotNull(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public void putNull(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public void putNulls(int rowId, int count) { + throw new UnsupportedOperationException(); + } + + @Override + public void putNotNulls(int rowId, int count) { + throw new UnsupportedOperationException(); + } + + @Override + public boolean isNullAt(int rowId) { + return nulls.isNull(rowId); + } + + // + // APIs dealing with Booleans + // + + @Override + public void putBoolean(int rowId, boolean value) { + throw new UnsupportedOperationException(); + } + + @Override + public void putBooleans(int rowId, int count, boolean value) { + throw new UnsupportedOperationException(); + } + + @Override + public boolean getBoolean(int rowId) { + return boolData.getAccessor().get(rowId) == 1; + } + + @Override + public boolean[] getBooleans(int rowId, int count) { + assert(dictionary == null); + NullableBitVector.Accessor accessor = boolData.getAccessor(); + boolean[] array = new boolean[count]; + for (int i = 0; i < count; ++i) { + array[i] = (accessor.get(rowId + i) == 1); + } + return array; + } + + // + // APIs dealing with Bytes + // + + @Override + public void putByte(int rowId, byte value) { + throw new UnsupportedOperationException(); + } + + @Override + public void putBytes(int rowId, int count, byte value) { + throw new UnsupportedOperationException(); + } + + @Override + public void putBytes(int rowId, int count, byte[] src, int srcIndex) { + throw new UnsupportedOperationException(); + } + + @Override + public byte getByte(int rowId) { + return byteData.getAccessor().get(rowId); + } + + @Override + public byte[] getBytes(int rowId, int count) { + assert(dictionary == null); + NullableTinyIntVector.Accessor accessor = byteData.getAccessor(); + byte[] array = new byte[count]; + for (int i = 0; i < count; ++i) { + array[i] = accessor.get(rowId + i); + } + return array; + } + + // + // APIs dealing with Shorts + // + + @Override + public void putShort(int rowId, short value) { + throw new UnsupportedOperationException(); + } + + @Override + public void putShorts(int rowId, int count, short value) { + throw new UnsupportedOperationException(); + } + + @Override + public 
void putShorts(int rowId, int count, short[] src, int srcIndex) { + throw new UnsupportedOperationException(); + } + + @Override + public short getShort(int rowId) { + return shortData.getAccessor().get(rowId); + } + + @Override + public short[] getShorts(int rowId, int count) { + assert(dictionary == null); + NullableSmallIntVector.Accessor accessor = shortData.getAccessor(); + short[] array = new short[count]; + for (int i = 0; i < count; ++i) { + array[i] = accessor.get(rowId + i); + } + return array; + } + + // + // APIs dealing with Ints + // + + @Override + public void putInt(int rowId, int value) { + throw new UnsupportedOperationException(); + } + + @Override + public void putInts(int rowId, int count, int value) { + throw new UnsupportedOperationException(); + } + + @Override + public void putInts(int rowId, int count, int[] src, int srcIndex) { + throw new UnsupportedOperationException(); + } + + @Override + public void putIntsLittleEndian(int rowId, int count, byte[] src, int srcIndex) { + throw new UnsupportedOperationException(); + } + + @Override + public int getInt(int rowId) { + return intData.getAccessor().get(rowId); + } + + @Override + public int[] getInts(int rowId, int count) { + assert(dictionary == null); + NullableIntVector.Accessor accessor = intData.getAccessor(); + int[] array = new int[count]; + for (int i = 0; i < count; ++i) { + array[i] = accessor.get(rowId + i); + } + return array; + } + + @Override + public int getDictId(int rowId) { + throw new UnsupportedOperationException(); + } + + // + // APIs dealing with Longs + // + + @Override + public void putLong(int rowId, long value) { + throw new UnsupportedOperationException(); + } + + @Override + public void putLongs(int rowId, int count, long value) { + throw new UnsupportedOperationException(); + } + + @Override + public void putLongs(int rowId, int count, long[] src, int srcIndex) { + throw new UnsupportedOperationException(); + } + + @Override + public void putLongsLittleEndian(int rowId, int count, byte[] src, int srcIndex) { + throw new UnsupportedOperationException(); + } + + @Override + public long getLong(int rowId) { + return longData.getAccessor().get(rowId); + } + + @Override + public long[] getLongs(int rowId, int count) { + assert(dictionary == null); + NullableBigIntVector.Accessor accessor = longData.getAccessor(); + long[] array = new long[count]; + for (int i = 0; i < count; ++i) { + array[i] = accessor.get(rowId + i); + } + return array; + } + + // + // APIs dealing with floats + // + + @Override + public void putFloat(int rowId, float value) { + throw new UnsupportedOperationException(); + } + + @Override + public void putFloats(int rowId, int count, float value) { + throw new UnsupportedOperationException(); + } + + @Override + public void putFloats(int rowId, int count, float[] src, int srcIndex) { + throw new UnsupportedOperationException(); + } + + @Override + public void putFloats(int rowId, int count, byte[] src, int srcIndex) { + throw new UnsupportedOperationException(); + } + + @Override + public float getFloat(int rowId) { + return floatData.getAccessor().get(rowId); + } + + @Override + public float[] getFloats(int rowId, int count) { + assert(dictionary == null); + NullableFloat4Vector.Accessor accessor = floatData.getAccessor(); + float[] array = new float[count]; + for (int i = 0; i < count; ++i) { + array[i] = accessor.get(rowId + i); + } + return array; + } + + // + // APIs dealing with doubles + // + + @Override + public void putDouble(int rowId, double value) { + throw 
new UnsupportedOperationException(); + } + + @Override + public void putDoubles(int rowId, int count, double value) { + throw new UnsupportedOperationException(); + } + + @Override + public void putDoubles(int rowId, int count, double[] src, int srcIndex) { + throw new UnsupportedOperationException(); + } + + @Override + public void putDoubles(int rowId, int count, byte[] src, int srcIndex) { + throw new UnsupportedOperationException(); + } + + @Override + public double getDouble(int rowId) { + return doubleData.getAccessor().get(rowId); + } + + @Override + public double[] getDoubles(int rowId, int count) { + assert(dictionary == null); + NullableFloat8Vector.Accessor accessor = doubleData.getAccessor(); + double[] array = new double[count]; + for (int i = 0; i < count; ++i) { + array[i] = accessor.get(rowId + i); + } + return array; + } + + // + // APIs dealing with Arrays + // + + @Override + public int getArrayLength(int rowId) { + return listOffsetData.getAccessor().get(rowId + 1) - listOffsetData.getAccessor().get(rowId); + } + + @Override + public int getArrayOffset(int rowId) { + return listOffsetData.getAccessor().get(rowId); + } + + @Override + public void putArray(int rowId, int offset, int length) { + throw new UnsupportedOperationException(); + } + + @Override + public void loadBytes(Array array) { + throw new UnsupportedOperationException(); + } + + // + // APIs dealing with Byte Arrays + // + + @Override + public int putByteArray(int rowId, byte[] value, int offset, int count) { + throw new UnsupportedOperationException(); + } + + // + // APIs dealing with Decimals + // + + @Override + public Decimal getDecimal(int rowId, int precision, int scale) { + if (isNullAt(rowId)) return null; + return Decimal.apply(decimalData.getAccessor().getObject(rowId), precision, scale); + } + + @Override + public final void putDecimal(int rowId, Decimal value, int precision) { + throw new UnsupportedOperationException(); + } + + // + // APIs dealing with UTF8Strings + // + + private NullableVarCharHolder stringResult = new NullableVarCharHolder(); + + @Override + public UTF8String getUTF8String(int rowId) { + stringData.getAccessor().get(rowId, stringResult); + if (stringResult.isSet == 0) { + return null; + } else { + return UTF8String.fromAddress(null, + stringResult.buffer.memoryAddress() + stringResult.start, + stringResult.end - stringResult.start); + } + } + + // + // APIs dealing with Binaries + // + + @Override + public byte[] getBinary(int rowId) { + return binaryData.getAccessor().getObject(rowId); + } + + @Override + protected void reserveInternal(int newCapacity) { + while (vector.getValueCapacity() <= newCapacity) { + vector.reAlloc(); + } + capacity = vector.getValueCapacity(); + } + + private void initialize(ValueVector vector) { + this.vector = vector; + this.type = ArrowUtils.fromArrowField(vector.getField()); + if (vector instanceof NullableBitVector) { + boolData = (NullableBitVector) vector; + nulls = boolData.getAccessor(); + } else if (vector instanceof NullableTinyIntVector) { + byteData = (NullableTinyIntVector) vector; + nulls = byteData.getAccessor(); + } else if (vector instanceof NullableSmallIntVector) { + shortData = (NullableSmallIntVector) vector; + nulls = shortData.getAccessor(); + } else if (vector instanceof NullableIntVector) { + intData = (NullableIntVector) vector; + nulls = intData.getAccessor(); + } else if (vector instanceof NullableBigIntVector) { + longData = (NullableBigIntVector) vector; + nulls = longData.getAccessor(); + } else if (vector 
instanceof NullableFloat4Vector) { + floatData = (NullableFloat4Vector) vector; + nulls = floatData.getAccessor(); + } else if (vector instanceof NullableFloat8Vector) { + doubleData = (NullableFloat8Vector) vector; + nulls = doubleData.getAccessor(); + } else if (vector instanceof NullableDecimalVector) { + decimalData = (NullableDecimalVector) vector; + nulls = decimalData.getAccessor(); + } else if (vector instanceof NullableVarCharVector) { + stringData = (NullableVarCharVector) vector; + nulls = stringData.getAccessor(); + } else if (vector instanceof NullableVarBinaryVector) { + binaryData = (NullableVarBinaryVector) vector; + nulls = binaryData.getAccessor(); + } else if (vector instanceof ListVector) { + ListVector listVector = (ListVector) vector; + listOffsetData = listVector.getOffsetVector(); + nulls = listVector.getAccessor(); + + childColumns = new ColumnVector[1]; + childColumns[0] = new ArrowColumnVector(listVector.getDataVector()); + resultArray = new Array(childColumns[0]); + } else if (vector instanceof MapVector) { + MapVector mapVector = (MapVector) vector; + nulls = mapVector.getAccessor(); + + childColumns = new ArrowColumnVector[mapVector.size()]; + for (int i = 0; i < childColumns.length; ++i) { + childColumns[i] = new ArrowColumnVector(mapVector.getVectorById(i)); + } + resultStruct = new ColumnarBatch.Row(childColumns); + } + numNulls = nulls.getNullCount(); + anyNullsSet = numNulls > 0; + isConstant = true; + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java index 0c027f80d48cc..77966382881b8 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java @@ -646,7 +646,7 @@ public MapData getMap(int ordinal) { /** * Returns the decimal for rowId. */ - public final Decimal getDecimal(int rowId, int precision, int scale) { + public Decimal getDecimal(int rowId, int precision, int scale) { if (precision <= Decimal.MAX_INT_DIGITS()) { return Decimal.createUnsafe(getInt(rowId), precision, scale); } else if (precision <= Decimal.MAX_LONG_DIGITS()) { @@ -661,7 +661,7 @@ public final Decimal getDecimal(int rowId, int precision, int scale) { } - public final void putDecimal(int rowId, Decimal value, int precision) { + public void putDecimal(int rowId, Decimal value, int precision) { if (precision <= Decimal.MAX_INT_DIGITS()) { putInt(rowId, (int) value.toUnscaledLong()); } else if (precision <= Decimal.MAX_LONG_DIGITS()) { @@ -675,7 +675,7 @@ public final void putDecimal(int rowId, Decimal value, int precision) { /** * Returns the UTF8String for rowId. */ - public final UTF8String getUTF8String(int rowId) { + public UTF8String getUTF8String(int rowId) { if (dictionary == null) { ColumnVector.Array a = getByteArray(rowId); return UTF8String.fromBytes(a.byteArray, a.byteArrayOffset, a.length); @@ -688,7 +688,7 @@ public final UTF8String getUTF8String(int rowId) { /** * Returns the byte array for rowId. */ - public final byte[] getBinary(int rowId) { + public byte[] getBinary(int rowId) { if (dictionary == null) { ColumnVector.Array array = getByteArray(rowId); byte[] bytes = new byte[array.length]; @@ -956,7 +956,7 @@ public final int appendStruct(boolean isNull) { /** * Data type for this column. */ - protected final DataType type; + protected DataType type; /** * Number of nulls in this column. 
This is an optimization for the reader, to skip NULL checks. @@ -988,17 +988,17 @@ public final int appendStruct(boolean isNull) { /** * If this is a nested type (array or struct), the column for the child data. */ - protected final ColumnVector[] childColumns; + protected ColumnVector[] childColumns; /** * Reusable Array holder for getArray(). */ - protected final Array resultArray; + protected Array resultArray; /** * Reusable Struct holder for getStruct(). */ - protected final ColumnarBatch.Row resultStruct; + protected ColumnarBatch.Row resultStruct; /** * The Dictionary for this column. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala new file mode 100644 index 0000000000000..433a65a84117a --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala @@ -0,0 +1,396 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.vectorized + +import org.apache.arrow.vector._ +import org.apache.arrow.vector.complex._ + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.execution.arrow.ArrowUtils +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +class ArrowColumnVectorSuite extends SparkFunSuite { + + test("boolean") { + val allocator = ArrowUtils.rootAllocator.newChildAllocator("boolean", 0, Long.MaxValue) + val vector = ArrowUtils.toArrowField("boolean", BooleanType, nullable = true) + .createVector(allocator).asInstanceOf[NullableBitVector] + vector.allocateNew() + val mutator = vector.getMutator() + + (0 until 10).foreach { i => + mutator.setSafe(i, if (i % 2 == 0) 1 else 0) + } + mutator.setNull(10) + mutator.setValueCount(11) + + val columnVector = new ArrowColumnVector(vector) + assert(columnVector.dataType === BooleanType) + assert(columnVector.anyNullsSet) + assert(columnVector.numNulls === 1) + + (0 until 10).foreach { i => + assert(columnVector.getBoolean(i) === (i % 2 == 0)) + } + assert(columnVector.isNullAt(10)) + + columnVector.close() + allocator.close() + } + + test("byte") { + val allocator = ArrowUtils.rootAllocator.newChildAllocator("byte", 0, Long.MaxValue) + val vector = ArrowUtils.toArrowField("byte", ByteType, nullable = true) + .createVector(allocator).asInstanceOf[NullableTinyIntVector] + vector.allocateNew() + val mutator = vector.getMutator() + + (0 until 10).foreach { i => + mutator.setSafe(i, i.toByte) + } + mutator.setNull(10) + mutator.setValueCount(11) + + val columnVector = new ArrowColumnVector(vector) + assert(columnVector.dataType === ByteType) + assert(columnVector.anyNullsSet) + assert(columnVector.numNulls === 1) + + (0 until 10).foreach { i => + assert(columnVector.getByte(i) === i.toByte) + } + assert(columnVector.isNullAt(10)) + + columnVector.close() + allocator.close() + } + + test("short") { + val allocator = ArrowUtils.rootAllocator.newChildAllocator("short", 0, Long.MaxValue) + val vector = ArrowUtils.toArrowField("short", ShortType, nullable = true) + .createVector(allocator).asInstanceOf[NullableSmallIntVector] + vector.allocateNew() + val mutator = vector.getMutator() + + (0 until 10).foreach { i => + mutator.setSafe(i, i.toShort) + } + mutator.setNull(10) + mutator.setValueCount(11) + + val columnVector = new ArrowColumnVector(vector) + assert(columnVector.dataType === ShortType) + assert(columnVector.anyNullsSet) + assert(columnVector.numNulls === 1) + + (0 until 10).foreach { i => + assert(columnVector.getShort(i) === i.toShort) + } + assert(columnVector.isNullAt(10)) + + columnVector.close() + allocator.close() + } + + test("int") { + val allocator = ArrowUtils.rootAllocator.newChildAllocator("int", 0, Long.MaxValue) + val vector = ArrowUtils.toArrowField("int", IntegerType, nullable = true) + .createVector(allocator).asInstanceOf[NullableIntVector] + vector.allocateNew() + val mutator = vector.getMutator() + + (0 until 10).foreach { i => + mutator.setSafe(i, i) + } + mutator.setNull(10) + mutator.setValueCount(11) + + val columnVector = new ArrowColumnVector(vector) + assert(columnVector.dataType === IntegerType) + assert(columnVector.anyNullsSet) + assert(columnVector.numNulls === 1) + + (0 until 10).foreach { i => + assert(columnVector.getInt(i) === i) + } + assert(columnVector.isNullAt(10)) + + columnVector.close() + allocator.close() + } + + test("long") { + val allocator = ArrowUtils.rootAllocator.newChildAllocator("long", 0, Long.MaxValue) + val vector = 
ArrowUtils.toArrowField("long", LongType, nullable = true) + .createVector(allocator).asInstanceOf[NullableBigIntVector] + vector.allocateNew() + val mutator = vector.getMutator() + + (0 until 10).foreach { i => + mutator.setSafe(i, i.toLong) + } + mutator.setNull(10) + mutator.setValueCount(11) + + val columnVector = new ArrowColumnVector(vector) + assert(columnVector.dataType === LongType) + assert(columnVector.anyNullsSet) + assert(columnVector.numNulls === 1) + + (0 until 10).foreach { i => + assert(columnVector.getLong(i) === i.toLong) + } + assert(columnVector.isNullAt(10)) + + columnVector.close() + allocator.close() + } + + test("float") { + val allocator = ArrowUtils.rootAllocator.newChildAllocator("float", 0, Long.MaxValue) + val vector = ArrowUtils.toArrowField("float", FloatType, nullable = true) + .createVector(allocator).asInstanceOf[NullableFloat4Vector] + vector.allocateNew() + val mutator = vector.getMutator() + + (0 until 10).foreach { i => + mutator.setSafe(i, i.toFloat) + } + mutator.setNull(10) + mutator.setValueCount(11) + + val columnVector = new ArrowColumnVector(vector) + assert(columnVector.dataType === FloatType) + assert(columnVector.anyNullsSet) + assert(columnVector.numNulls === 1) + + (0 until 10).foreach { i => + assert(columnVector.getFloat(i) === i.toFloat) + } + assert(columnVector.isNullAt(10)) + + columnVector.close() + allocator.close() + } + + test("double") { + val allocator = ArrowUtils.rootAllocator.newChildAllocator("double", 0, Long.MaxValue) + val vector = ArrowUtils.toArrowField("double", DoubleType, nullable = true) + .createVector(allocator).asInstanceOf[NullableFloat8Vector] + vector.allocateNew() + val mutator = vector.getMutator() + + (0 until 10).foreach { i => + mutator.setSafe(i, i.toDouble) + } + mutator.setNull(10) + mutator.setValueCount(11) + + val columnVector = new ArrowColumnVector(vector) + assert(columnVector.dataType === DoubleType) + assert(columnVector.anyNullsSet) + assert(columnVector.numNulls === 1) + + (0 until 10).foreach { i => + assert(columnVector.getDouble(i) === i.toDouble) + } + assert(columnVector.isNullAt(10)) + + columnVector.close() + allocator.close() + } + + test("string") { + val allocator = ArrowUtils.rootAllocator.newChildAllocator("string", 0, Long.MaxValue) + val vector = ArrowUtils.toArrowField("string", StringType, nullable = true) + .createVector(allocator).asInstanceOf[NullableVarCharVector] + vector.allocateNew() + val mutator = vector.getMutator() + + (0 until 10).foreach { i => + val utf8 = s"str$i".getBytes("utf8") + mutator.setSafe(i, utf8, 0, utf8.length) + } + mutator.setNull(10) + mutator.setValueCount(11) + + val columnVector = new ArrowColumnVector(vector) + assert(columnVector.dataType === StringType) + assert(columnVector.anyNullsSet) + assert(columnVector.numNulls === 1) + + (0 until 10).foreach { i => + assert(columnVector.getUTF8String(i) === UTF8String.fromString(s"str$i")) + } + assert(columnVector.isNullAt(10)) + + columnVector.close() + allocator.close() + } + + test("binary") { + val allocator = ArrowUtils.rootAllocator.newChildAllocator("binary", 0, Long.MaxValue) + val vector = ArrowUtils.toArrowField("binary", BinaryType, nullable = true) + .createVector(allocator).asInstanceOf[NullableVarBinaryVector] + vector.allocateNew() + val mutator = vector.getMutator() + + (0 until 10).foreach { i => + val utf8 = s"str$i".getBytes("utf8") + mutator.setSafe(i, utf8, 0, utf8.length) + } + mutator.setNull(10) + mutator.setValueCount(11) + + val columnVector = new ArrowColumnVector(vector) 
+ assert(columnVector.dataType === BinaryType) + assert(columnVector.anyNullsSet) + assert(columnVector.numNulls === 1) + + (0 until 10).foreach { i => + assert(columnVector.getBinary(i) === s"str$i".getBytes("utf8")) + } + assert(columnVector.isNullAt(10)) + + columnVector.close() + allocator.close() + } + + test("array") { + val allocator = ArrowUtils.rootAllocator.newChildAllocator("array", 0, Long.MaxValue) + val vector = ArrowUtils.toArrowField("array", ArrayType(IntegerType), nullable = true) + .createVector(allocator).asInstanceOf[ListVector] + vector.allocateNew() + val mutator = vector.getMutator() + val elementVector = vector.getDataVector().asInstanceOf[NullableIntVector] + val elementMutator = elementVector.getMutator() + + // [1, 2] + mutator.startNewValue(0) + elementMutator.setSafe(0, 1) + elementMutator.setSafe(1, 2) + mutator.endValue(0, 2) + + // [3, null, 5] + mutator.startNewValue(1) + elementMutator.setSafe(2, 3) + elementMutator.setNull(3) + elementMutator.setSafe(4, 5) + mutator.endValue(1, 3) + + // null + + // [] + mutator.startNewValue(3) + mutator.endValue(3, 0) + + elementMutator.setValueCount(5) + mutator.setValueCount(4) + + val columnVector = new ArrowColumnVector(vector) + assert(columnVector.dataType === ArrayType(IntegerType)) + assert(columnVector.anyNullsSet) + assert(columnVector.numNulls === 1) + + val array0 = columnVector.getArray(0) + assert(array0.numElements() === 2) + assert(array0.getInt(0) === 1) + assert(array0.getInt(1) === 2) + + val array1 = columnVector.getArray(1) + assert(array1.numElements() === 3) + assert(array1.getInt(0) === 3) + assert(array1.isNullAt(1)) + assert(array1.getInt(2) === 5) + + assert(columnVector.isNullAt(2)) + + val array3 = columnVector.getArray(3) + assert(array3.numElements() === 0) + + columnVector.close() + allocator.close() + } + + test("struct") { + val allocator = ArrowUtils.rootAllocator.newChildAllocator("struct", 0, Long.MaxValue) + val schema = new StructType().add("int", IntegerType).add("long", LongType) + val vector = ArrowUtils.toArrowField("struct", schema, nullable = true) + .createVector(allocator).asInstanceOf[NullableMapVector] + vector.allocateNew() + val mutator = vector.getMutator() + val intVector = vector.getChildByOrdinal(0).asInstanceOf[NullableIntVector] + val intMutator = intVector.getMutator() + val longVector = vector.getChildByOrdinal(1).asInstanceOf[NullableBigIntVector] + val longMutator = longVector.getMutator() + + // (1, 1L) + mutator.setIndexDefined(0) + intMutator.setSafe(0, 1) + longMutator.setSafe(0, 1L) + + // (2, null) + mutator.setIndexDefined(1) + intMutator.setSafe(1, 2) + longMutator.setNull(1) + + // (null, 3L) + mutator.setIndexDefined(2) + intMutator.setNull(2) + longMutator.setSafe(2, 3L) + + // null + mutator.setNull(3) + + // (5, 5L) + mutator.setIndexDefined(4) + intMutator.setSafe(4, 5) + longMutator.setSafe(4, 5L) + + intMutator.setValueCount(5) + longMutator.setValueCount(5) + mutator.setValueCount(5) + + val columnVector = new ArrowColumnVector(vector) + assert(columnVector.dataType === schema) + assert(columnVector.anyNullsSet) + assert(columnVector.numNulls === 1) + + val row0 = columnVector.getStruct(0, 2) + assert(row0.getInt(0) === 1) + assert(row0.getLong(1) === 1L) + + val row1 = columnVector.getStruct(1, 2) + assert(row1.getInt(0) === 2) + assert(row1.isNullAt(1)) + + val row2 = columnVector.getStruct(2, 2) + assert(row2.isNullAt(0)) + assert(row2.getLong(1) === 3L) + + assert(columnVector.isNullAt(3)) + + val row4 = columnVector.getStruct(4, 2) + 
assert(row4.getInt(0) === 5) + assert(row4.getLong(1) === 5L) + + columnVector.close() + allocator.close() + } +} From c912e78657ad2e4971862828d88d0e9deb73c0d0 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Wed, 19 Jul 2017 23:01:19 +0900 Subject: [PATCH 3/8] Extract ReadOnlyColumnVector. --- .../vectorized/ArrowColumnVector.java | 172 +----------- .../vectorized/ReadOnlyColumnVector.java | 250 ++++++++++++++++++ 2 files changed, 252 insertions(+), 170 deletions(-) create mode 100644 sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ReadOnlyColumnVector.java diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java index 35a8412085ca1..7dca38553a39d 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java @@ -29,7 +29,7 @@ /** * A column backed by Apache Arrow. */ -public final class ArrowColumnVector extends ColumnVector { +public final class ArrowColumnVector extends ReadOnlyColumnVector { private ValueVector vector; private ValueVector.Accessor nulls; @@ -51,7 +51,7 @@ public final class ArrowColumnVector extends ColumnVector { private UInt4Vector listOffsetData; public ArrowColumnVector(ValueVector vector) { - super(vector.getValueCapacity(), DataTypes.NullType, MemoryMode.OFF_HEAP); + super(vector.getValueCapacity(), MemoryMode.OFF_HEAP); initialize(vector); } @@ -79,26 +79,6 @@ public void close() { // APIs dealing with nulls // - @Override - public void putNotNull(int rowId) { - throw new UnsupportedOperationException(); - } - - @Override - public void putNull(int rowId) { - throw new UnsupportedOperationException(); - } - - @Override - public void putNulls(int rowId, int count) { - throw new UnsupportedOperationException(); - } - - @Override - public void putNotNulls(int rowId, int count) { - throw new UnsupportedOperationException(); - } - @Override public boolean isNullAt(int rowId) { return nulls.isNull(rowId); @@ -108,16 +88,6 @@ public boolean isNullAt(int rowId) { // APIs dealing with Booleans // - @Override - public void putBoolean(int rowId, boolean value) { - throw new UnsupportedOperationException(); - } - - @Override - public void putBooleans(int rowId, int count, boolean value) { - throw new UnsupportedOperationException(); - } - @Override public boolean getBoolean(int rowId) { return boolData.getAccessor().get(rowId) == 1; @@ -138,21 +108,6 @@ public boolean[] getBooleans(int rowId, int count) { // APIs dealing with Bytes // - @Override - public void putByte(int rowId, byte value) { - throw new UnsupportedOperationException(); - } - - @Override - public void putBytes(int rowId, int count, byte value) { - throw new UnsupportedOperationException(); - } - - @Override - public void putBytes(int rowId, int count, byte[] src, int srcIndex) { - throw new UnsupportedOperationException(); - } - @Override public byte getByte(int rowId) { return byteData.getAccessor().get(rowId); @@ -173,21 +128,6 @@ public byte[] getBytes(int rowId, int count) { // APIs dealing with Shorts // - @Override - public void putShort(int rowId, short value) { - throw new UnsupportedOperationException(); - } - - @Override - public void putShorts(int rowId, int count, short value) { - throw new UnsupportedOperationException(); - } - - @Override - public void putShorts(int rowId, int count, short[] src, int srcIndex) { - 
throw new UnsupportedOperationException(); - } - @Override public short getShort(int rowId) { return shortData.getAccessor().get(rowId); @@ -208,26 +148,6 @@ public short[] getShorts(int rowId, int count) { // APIs dealing with Ints // - @Override - public void putInt(int rowId, int value) { - throw new UnsupportedOperationException(); - } - - @Override - public void putInts(int rowId, int count, int value) { - throw new UnsupportedOperationException(); - } - - @Override - public void putInts(int rowId, int count, int[] src, int srcIndex) { - throw new UnsupportedOperationException(); - } - - @Override - public void putIntsLittleEndian(int rowId, int count, byte[] src, int srcIndex) { - throw new UnsupportedOperationException(); - } - @Override public int getInt(int rowId) { return intData.getAccessor().get(rowId); @@ -253,26 +173,6 @@ public int getDictId(int rowId) { // APIs dealing with Longs // - @Override - public void putLong(int rowId, long value) { - throw new UnsupportedOperationException(); - } - - @Override - public void putLongs(int rowId, int count, long value) { - throw new UnsupportedOperationException(); - } - - @Override - public void putLongs(int rowId, int count, long[] src, int srcIndex) { - throw new UnsupportedOperationException(); - } - - @Override - public void putLongsLittleEndian(int rowId, int count, byte[] src, int srcIndex) { - throw new UnsupportedOperationException(); - } - @Override public long getLong(int rowId) { return longData.getAccessor().get(rowId); @@ -293,26 +193,6 @@ public long[] getLongs(int rowId, int count) { // APIs dealing with floats // - @Override - public void putFloat(int rowId, float value) { - throw new UnsupportedOperationException(); - } - - @Override - public void putFloats(int rowId, int count, float value) { - throw new UnsupportedOperationException(); - } - - @Override - public void putFloats(int rowId, int count, float[] src, int srcIndex) { - throw new UnsupportedOperationException(); - } - - @Override - public void putFloats(int rowId, int count, byte[] src, int srcIndex) { - throw new UnsupportedOperationException(); - } - @Override public float getFloat(int rowId) { return floatData.getAccessor().get(rowId); @@ -333,26 +213,6 @@ public float[] getFloats(int rowId, int count) { // APIs dealing with doubles // - @Override - public void putDouble(int rowId, double value) { - throw new UnsupportedOperationException(); - } - - @Override - public void putDoubles(int rowId, int count, double value) { - throw new UnsupportedOperationException(); - } - - @Override - public void putDoubles(int rowId, int count, double[] src, int srcIndex) { - throw new UnsupportedOperationException(); - } - - @Override - public void putDoubles(int rowId, int count, byte[] src, int srcIndex) { - throw new UnsupportedOperationException(); - } - @Override public double getDouble(int rowId) { return doubleData.getAccessor().get(rowId); @@ -383,25 +243,11 @@ public int getArrayOffset(int rowId) { return listOffsetData.getAccessor().get(rowId); } - @Override - public void putArray(int rowId, int offset, int length) { - throw new UnsupportedOperationException(); - } - @Override public void loadBytes(Array array) { throw new UnsupportedOperationException(); } - // - // APIs dealing with Byte Arrays - // - - @Override - public int putByteArray(int rowId, byte[] value, int offset, int count) { - throw new UnsupportedOperationException(); - } - // // APIs dealing with Decimals // @@ -412,11 +258,6 @@ public Decimal getDecimal(int rowId, int precision, int scale) 
{ return Decimal.apply(decimalData.getAccessor().getObject(rowId), precision, scale); } - @Override - public final void putDecimal(int rowId, Decimal value, int precision) { - throw new UnsupportedOperationException(); - } - // // APIs dealing with UTF8Strings // @@ -444,14 +285,6 @@ public byte[] getBinary(int rowId) { return binaryData.getAccessor().getObject(rowId); } - @Override - protected void reserveInternal(int newCapacity) { - while (vector.getValueCapacity() <= newCapacity) { - vector.reAlloc(); - } - capacity = vector.getValueCapacity(); - } - private void initialize(ValueVector vector) { this.vector = vector; this.type = ArrowUtils.fromArrowField(vector.getField()); @@ -505,6 +338,5 @@ private void initialize(ValueVector vector) { } numNulls = nulls.getNullCount(); anyNullsSet = numNulls > 0; - isConstant = true; } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ReadOnlyColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ReadOnlyColumnVector.java new file mode 100644 index 0000000000000..8396f99d1329b --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ReadOnlyColumnVector.java @@ -0,0 +1,250 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.vectorized; + +import org.apache.spark.memory.MemoryMode; +import org.apache.spark.sql.types.*; + +/** + * An abstract class for read-only column vector. 
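+ * All mutating APIs (the put* methods, putByteArray, setDictionary, reserveDictionaryIds
+ * and reserveInternal) are final and throw UnsupportedOperationException, so subclasses
+ * only need to implement the read path.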
+ */ +public abstract class ReadOnlyColumnVector extends ColumnVector { + + protected ReadOnlyColumnVector(int capacity, MemoryMode memMode) { + super(capacity, DataTypes.NullType, memMode); + isConstant = true; + } + + // + // APIs dealing with nulls + // + + @Override + public final void putNotNull(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public final void putNull(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public final void putNulls(int rowId, int count) { + throw new UnsupportedOperationException(); + } + + @Override + public final void putNotNulls(int rowId, int count) { + throw new UnsupportedOperationException(); + } + + // + // APIs dealing with Booleans + // + + @Override + public final void putBoolean(int rowId, boolean value) { + throw new UnsupportedOperationException(); + } + + @Override + public final void putBooleans(int rowId, int count, boolean value) { + throw new UnsupportedOperationException(); + } + + // + // APIs dealing with Bytes + // + + @Override + public final void putByte(int rowId, byte value) { + throw new UnsupportedOperationException(); + } + + @Override + public final void putBytes(int rowId, int count, byte value) { + throw new UnsupportedOperationException(); + } + + @Override + public final void putBytes(int rowId, int count, byte[] src, int srcIndex) { + throw new UnsupportedOperationException(); + } + + // + // APIs dealing with Shorts + // + + @Override + public final void putShort(int rowId, short value) { + throw new UnsupportedOperationException(); + } + + @Override + public final void putShorts(int rowId, int count, short value) { + throw new UnsupportedOperationException(); + } + + @Override + public final void putShorts(int rowId, int count, short[] src, int srcIndex) { + throw new UnsupportedOperationException(); + } + + // + // APIs dealing with Ints + // + + @Override + public final void putInt(int rowId, int value) { + throw new UnsupportedOperationException(); + } + + @Override + public final void putInts(int rowId, int count, int value) { + throw new UnsupportedOperationException(); + } + + @Override + public final void putInts(int rowId, int count, int[] src, int srcIndex) { + throw new UnsupportedOperationException(); + } + + @Override + public final void putIntsLittleEndian(int rowId, int count, byte[] src, int srcIndex) { + throw new UnsupportedOperationException(); + } + + // + // APIs dealing with Longs + // + + @Override + public final void putLong(int rowId, long value) { + throw new UnsupportedOperationException(); + } + + @Override + public final void putLongs(int rowId, int count, long value) { + throw new UnsupportedOperationException(); + } + + @Override + public final void putLongs(int rowId, int count, long[] src, int srcIndex) { + throw new UnsupportedOperationException(); + } + + @Override + public final void putLongsLittleEndian(int rowId, int count, byte[] src, int srcIndex) { + throw new UnsupportedOperationException(); + } + + // + // APIs dealing with floats + // + + @Override + public final void putFloat(int rowId, float value) { + throw new UnsupportedOperationException(); + } + + @Override + public final void putFloats(int rowId, int count, float value) { + throw new UnsupportedOperationException(); + } + + @Override + public final void putFloats(int rowId, int count, float[] src, int srcIndex) { + throw new UnsupportedOperationException(); + } + + @Override + public final void putFloats(int rowId, int count, byte[] src, int srcIndex) { + throw 
new UnsupportedOperationException(); + } + + // + // APIs dealing with doubles + // + + @Override + public final void putDouble(int rowId, double value) { + throw new UnsupportedOperationException(); + } + + @Override + public final void putDoubles(int rowId, int count, double value) { + throw new UnsupportedOperationException(); + } + + @Override + public final void putDoubles(int rowId, int count, double[] src, int srcIndex) { + throw new UnsupportedOperationException(); + } + + @Override + public final void putDoubles(int rowId, int count, byte[] src, int srcIndex) { + throw new UnsupportedOperationException(); + } + + // + // APIs dealing with Arrays + // + + @Override + public final void putArray(int rowId, int offset, int length) { + throw new UnsupportedOperationException(); + } + + // + // APIs dealing with Byte Arrays + // + + @Override + public final int putByteArray(int rowId, byte[] value, int offset, int count) { + throw new UnsupportedOperationException(); + } + + // + // APIs dealing with Decimals + // + + @Override + public final void putDecimal(int rowId, Decimal value, int precision) { + throw new UnsupportedOperationException(); + } + + // + // Other APIs + // + + @Override + public final void setDictionary(Dictionary dictionary) { + throw new UnsupportedOperationException(); + } + + @Override + public final ColumnVector reserveDictionaryIds(int capacity) { + throw new UnsupportedOperationException(); + } + + @Override + protected final void reserveInternal(int newCapacity) { + throw new UnsupportedOperationException(); + } +} From 2215922185babe20c2a58412efa3aae86b6aadd0 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Thu, 20 Jul 2017 00:07:21 +0900 Subject: [PATCH 4/8] Refactor ArrowColumnVector. --- .../vectorized/ArrowColumnVector.java | 393 +++++++++++++----- 1 file changed, 298 insertions(+), 95 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java index 7dca38553a39d..332a644b4656c 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java @@ -31,29 +31,7 @@ */ public final class ArrowColumnVector extends ReadOnlyColumnVector { - private ValueVector vector; - private ValueVector.Accessor nulls; - - private NullableBitVector boolData; - private NullableTinyIntVector byteData; - private NullableSmallIntVector shortData; - private NullableIntVector intData; - private NullableBigIntVector longData; - - private NullableFloat4Vector floatData; - private NullableFloat8Vector doubleData; - private NullableDecimalVector decimalData; - - private NullableVarCharVector stringData; - - private NullableVarBinaryVector binaryData; - - private UInt4Vector listOffsetData; - - public ArrowColumnVector(ValueVector vector) { - super(vector.getValueCapacity(), MemoryMode.OFF_HEAP); - initialize(vector); - } + private final ArrowVectorAccessor accessor; @Override public long nullsNativeAddress() { @@ -72,7 +50,7 @@ public void close() { childColumns[i].close(); } } - vector.close(); + accessor.close(); } // @@ -81,7 +59,7 @@ public void close() { @Override public boolean isNullAt(int rowId) { - return nulls.isNull(rowId); + return accessor.isNullAt(rowId); } // @@ -90,16 +68,14 @@ public boolean isNullAt(int rowId) { @Override public boolean getBoolean(int rowId) { - return 
boolData.getAccessor().get(rowId) == 1; + return accessor.getBoolean(rowId); } @Override public boolean[] getBooleans(int rowId, int count) { - assert(dictionary == null); - NullableBitVector.Accessor accessor = boolData.getAccessor(); boolean[] array = new boolean[count]; for (int i = 0; i < count; ++i) { - array[i] = (accessor.get(rowId + i) == 1); + array[i] = accessor.getBoolean(rowId + i); } return array; } @@ -110,16 +86,14 @@ public boolean[] getBooleans(int rowId, int count) { @Override public byte getByte(int rowId) { - return byteData.getAccessor().get(rowId); + return accessor.getByte(rowId); } @Override public byte[] getBytes(int rowId, int count) { - assert(dictionary == null); - NullableTinyIntVector.Accessor accessor = byteData.getAccessor(); byte[] array = new byte[count]; for (int i = 0; i < count; ++i) { - array[i] = accessor.get(rowId + i); + array[i] = accessor.getByte(rowId + i); } return array; } @@ -130,16 +104,14 @@ public byte[] getBytes(int rowId, int count) { @Override public short getShort(int rowId) { - return shortData.getAccessor().get(rowId); + return accessor.getShort(rowId); } @Override public short[] getShorts(int rowId, int count) { - assert(dictionary == null); - NullableSmallIntVector.Accessor accessor = shortData.getAccessor(); short[] array = new short[count]; for (int i = 0; i < count; ++i) { - array[i] = accessor.get(rowId + i); + array[i] = accessor.getShort(rowId + i); } return array; } @@ -150,16 +122,14 @@ public short[] getShorts(int rowId, int count) { @Override public int getInt(int rowId) { - return intData.getAccessor().get(rowId); + return accessor.getInt(rowId); } @Override public int[] getInts(int rowId, int count) { - assert(dictionary == null); - NullableIntVector.Accessor accessor = intData.getAccessor(); int[] array = new int[count]; for (int i = 0; i < count; ++i) { - array[i] = accessor.get(rowId + i); + array[i] = accessor.getInt(rowId + i); } return array; } @@ -175,16 +145,14 @@ public int getDictId(int rowId) { @Override public long getLong(int rowId) { - return longData.getAccessor().get(rowId); + return accessor.getLong(rowId); } @Override public long[] getLongs(int rowId, int count) { - assert(dictionary == null); - NullableBigIntVector.Accessor accessor = longData.getAccessor(); long[] array = new long[count]; for (int i = 0; i < count; ++i) { - array[i] = accessor.get(rowId + i); + array[i] = accessor.getLong(rowId + i); } return array; } @@ -195,16 +163,14 @@ public long[] getLongs(int rowId, int count) { @Override public float getFloat(int rowId) { - return floatData.getAccessor().get(rowId); + return accessor.getFloat(rowId); } @Override public float[] getFloats(int rowId, int count) { - assert(dictionary == null); - NullableFloat4Vector.Accessor accessor = floatData.getAccessor(); float[] array = new float[count]; for (int i = 0; i < count; ++i) { - array[i] = accessor.get(rowId + i); + array[i] = accessor.getFloat(rowId + i); } return array; } @@ -215,16 +181,14 @@ public float[] getFloats(int rowId, int count) { @Override public double getDouble(int rowId) { - return doubleData.getAccessor().get(rowId); + return accessor.getDouble(rowId); } @Override public double[] getDoubles(int rowId, int count) { - assert(dictionary == null); - NullableFloat8Vector.Accessor accessor = doubleData.getAccessor(); double[] array = new double[count]; for (int i = 0; i < count; ++i) { - array[i] = accessor.get(rowId + i); + array[i] = accessor.getDouble(rowId + i); } return array; } @@ -235,12 +199,12 @@ public double[] getDoubles(int 
rowId, int count) { @Override public int getArrayLength(int rowId) { - return listOffsetData.getAccessor().get(rowId + 1) - listOffsetData.getAccessor().get(rowId); + return accessor.getArrayLength(rowId); } @Override public int getArrayOffset(int rowId) { - return listOffsetData.getAccessor().get(rowId); + return accessor.getArrayOffset(rowId); } @Override @@ -254,26 +218,16 @@ public void loadBytes(Array array) { @Override public Decimal getDecimal(int rowId, int precision, int scale) { - if (isNullAt(rowId)) return null; - return Decimal.apply(decimalData.getAccessor().getObject(rowId), precision, scale); + return accessor.getDecimal(rowId, precision, scale); } // // APIs dealing with UTF8Strings // - private NullableVarCharHolder stringResult = new NullableVarCharHolder(); - @Override public UTF8String getUTF8String(int rowId) { - stringData.getAccessor().get(rowId, stringResult); - if (stringResult.isSet == 0) { - return null; - } else { - return UTF8String.fromAddress(null, - stringResult.buffer.memoryAddress() + stringResult.start, - stringResult.end - stringResult.start); - } + return accessor.getUTF8String(rowId); } // @@ -282,61 +236,310 @@ public UTF8String getUTF8String(int rowId) { @Override public byte[] getBinary(int rowId) { - return binaryData.getAccessor().getObject(rowId); + return accessor.getBinary(rowId); } - private void initialize(ValueVector vector) { - this.vector = vector; - this.type = ArrowUtils.fromArrowField(vector.getField()); + public ArrowColumnVector(ValueVector vector) { + super(vector.getValueCapacity(), MemoryMode.OFF_HEAP); + + type = ArrowUtils.fromArrowField(vector.getField()); if (vector instanceof NullableBitVector) { - boolData = (NullableBitVector) vector; - nulls = boolData.getAccessor(); + accessor = new BooleanAccessor((NullableBitVector) vector); } else if (vector instanceof NullableTinyIntVector) { - byteData = (NullableTinyIntVector) vector; - nulls = byteData.getAccessor(); + accessor = new ByteAccessor((NullableTinyIntVector) vector); } else if (vector instanceof NullableSmallIntVector) { - shortData = (NullableSmallIntVector) vector; - nulls = shortData.getAccessor(); + accessor = new ShortAccessor((NullableSmallIntVector) vector); } else if (vector instanceof NullableIntVector) { - intData = (NullableIntVector) vector; - nulls = intData.getAccessor(); + accessor = new IntAccessor((NullableIntVector) vector); } else if (vector instanceof NullableBigIntVector) { - longData = (NullableBigIntVector) vector; - nulls = longData.getAccessor(); + accessor = new LongAccessor((NullableBigIntVector) vector); } else if (vector instanceof NullableFloat4Vector) { - floatData = (NullableFloat4Vector) vector; - nulls = floatData.getAccessor(); + accessor = new FloatAccessor((NullableFloat4Vector) vector); } else if (vector instanceof NullableFloat8Vector) { - doubleData = (NullableFloat8Vector) vector; - nulls = doubleData.getAccessor(); + accessor = new DoubleAccessor((NullableFloat8Vector) vector); } else if (vector instanceof NullableDecimalVector) { - decimalData = (NullableDecimalVector) vector; - nulls = decimalData.getAccessor(); + accessor = new DecimalAccessor((NullableDecimalVector) vector); } else if (vector instanceof NullableVarCharVector) { - stringData = (NullableVarCharVector) vector; - nulls = stringData.getAccessor(); + accessor = new StringAccessor((NullableVarCharVector) vector); } else if (vector instanceof NullableVarBinaryVector) { - binaryData = (NullableVarBinaryVector) vector; - nulls = binaryData.getAccessor(); + accessor = 
new BinaryAccessor((NullableVarBinaryVector) vector); } else if (vector instanceof ListVector) { ListVector listVector = (ListVector) vector; - listOffsetData = listVector.getOffsetVector(); - nulls = listVector.getAccessor(); + accessor = new ArrayAccessor(listVector); childColumns = new ColumnVector[1]; childColumns[0] = new ArrowColumnVector(listVector.getDataVector()); resultArray = new Array(childColumns[0]); } else if (vector instanceof MapVector) { MapVector mapVector = (MapVector) vector; - nulls = mapVector.getAccessor(); + accessor = new StructAccessor(mapVector); childColumns = new ArrowColumnVector[mapVector.size()]; for (int i = 0; i < childColumns.length; ++i) { childColumns[i] = new ArrowColumnVector(mapVector.getVectorById(i)); } resultStruct = new ColumnarBatch.Row(childColumns); + } else { + throw new UnsupportedOperationException(); } - numNulls = nulls.getNullCount(); + numNulls = accessor.getNullCount(); anyNullsSet = numNulls > 0; } + + private static abstract class ArrowVectorAccessor { + + private final ValueVector vector; + private final ValueVector.Accessor nulls; + + ArrowVectorAccessor(ValueVector vector) { + this.vector = vector; + this.nulls = vector.getAccessor(); + } + + final boolean isNullAt(int rowId) { + return nulls.isNull(rowId); + } + + final int getNullCount() { + return nulls.getNullCount(); + } + + final void close() { + vector.close(); + } + + boolean getBoolean(int rowId) { + throw new UnsupportedOperationException(); + } + + byte getByte(int rowId) { + throw new UnsupportedOperationException(); + } + + short getShort(int rowId) { + throw new UnsupportedOperationException(); + } + + int getInt(int rowId) { + throw new UnsupportedOperationException(); + } + + long getLong(int rowId) { + throw new UnsupportedOperationException(); + } + + float getFloat(int rowId) { + throw new UnsupportedOperationException(); + } + + double getDouble(int rowId) { + throw new UnsupportedOperationException(); + } + + Decimal getDecimal(int rowId, int precision, int scale) { + throw new UnsupportedOperationException(); + } + + UTF8String getUTF8String(int rowId) { + throw new UnsupportedOperationException(); + } + + byte[] getBinary(int rowId) { + throw new UnsupportedOperationException(); + } + + int getArrayLength(int rowId) { + throw new UnsupportedOperationException(); + } + + int getArrayOffset(int rowId) { + throw new UnsupportedOperationException(); + } + } + + private static class BooleanAccessor extends ArrowVectorAccessor { + + private final NullableBitVector.Accessor accessor; + + BooleanAccessor(NullableBitVector vector) { + super(vector); + this.accessor = vector.getAccessor(); + } + + @Override + final boolean getBoolean(int rowId) { + return accessor.get(rowId) == 1; + } + } + + private static class ByteAccessor extends ArrowVectorAccessor { + + private final NullableTinyIntVector.Accessor accessor; + + ByteAccessor(NullableTinyIntVector vector) { + super(vector); + this.accessor = vector.getAccessor(); + } + + @Override + final byte getByte(int rowId) { + return accessor.get(rowId); + } + } + + private static class ShortAccessor extends ArrowVectorAccessor { + + private final NullableSmallIntVector.Accessor accessor; + + ShortAccessor(NullableSmallIntVector vector) { + super(vector); + this.accessor = vector.getAccessor(); + } + + @Override + final short getShort(int rowId) { + return accessor.get(rowId); + } + } + + private static class IntAccessor extends ArrowVectorAccessor { + + private final NullableIntVector.Accessor accessor; + + 
IntAccessor(NullableIntVector vector) { + super(vector); + this.accessor = vector.getAccessor(); + } + + @Override + final int getInt(int rowId) { + return accessor.get(rowId); + } + } + + private static class LongAccessor extends ArrowVectorAccessor { + + private final NullableBigIntVector.Accessor accessor; + + LongAccessor(NullableBigIntVector vector) { + super(vector); + this.accessor = vector.getAccessor(); + } + + @Override + final long getLong(int rowId) { + return accessor.get(rowId); + } + } + + private static class FloatAccessor extends ArrowVectorAccessor { + + private final NullableFloat4Vector.Accessor accessor; + + FloatAccessor(NullableFloat4Vector vector) { + super(vector); + this.accessor = vector.getAccessor(); + } + + @Override + final float getFloat(int rowId) { + return accessor.get(rowId); + } + } + + private static class DoubleAccessor extends ArrowVectorAccessor { + + private final NullableFloat8Vector.Accessor accessor; + + DoubleAccessor(NullableFloat8Vector vector) { + super(vector); + this.accessor = vector.getAccessor(); + } + + @Override + final double getDouble(int rowId) { + return accessor.get(rowId); + } + } + + private static class DecimalAccessor extends ArrowVectorAccessor { + + private final NullableDecimalVector.Accessor accessor; + + DecimalAccessor(NullableDecimalVector vector) { + super(vector); + this.accessor = vector.getAccessor(); + } + + @Override + final Decimal getDecimal(int rowId, int precision, int scale) { + if (isNullAt(rowId)) return null; + return Decimal.apply(accessor.getObject(rowId), precision, scale); + } + } + + private static class StringAccessor extends ArrowVectorAccessor { + + private final NullableVarCharVector.Accessor accessor; + private final NullableVarCharHolder stringResult = new NullableVarCharHolder(); + + StringAccessor(NullableVarCharVector vector) { + super(vector); + this.accessor = vector.getAccessor(); + } + + @Override + final UTF8String getUTF8String(int rowId) { + accessor.get(rowId, stringResult); + if (stringResult.isSet == 0) { + return null; + } else { + return UTF8String.fromAddress(null, + stringResult.buffer.memoryAddress() + stringResult.start, + stringResult.end - stringResult.start); + } + } + } + + private static class BinaryAccessor extends ArrowVectorAccessor { + + private final NullableVarBinaryVector.Accessor accessor; + + BinaryAccessor(NullableVarBinaryVector vector) { + super(vector); + this.accessor = vector.getAccessor(); + } + + @Override + final byte[] getBinary(int rowId) { + return accessor.getObject(rowId); + } + } + + private static class ArrayAccessor extends ArrowVectorAccessor { + + private final UInt4Vector.Accessor accessor; + + ArrayAccessor(ListVector vector) { + super(vector); + this.accessor = vector.getOffsetVector().getAccessor(); + } + + @Override + final int getArrayLength(int rowId) { + return accessor.get(rowId + 1) - accessor.get(rowId); + } + + @Override + final int getArrayOffset(int rowId) { + return accessor.get(rowId); + } + } + + private static class StructAccessor extends ArrowVectorAccessor { + + StructAccessor(MapVector vector) { + super(vector); + } + } } From ddfcf3670c86c7d0498f2193df1525fc60662e40 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Thu, 20 Jul 2017 00:17:10 +0900 Subject: [PATCH 5/8] Add tests to check getting multiple values. 
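
The added assertions exercise the bulk getters (getBooleans, getBytes, getShorts, getInts, getLongs, getFloats, getDoubles) that now route through the accessor classes, checking that one ranged read returns the same values that were written into the vector and already verified row by row just above. A minimal Java sketch of the invariant being tested (illustrative only; readIntsOneByOne is not part of the suite, which builds its vectors in Scala):

    // Bulk reads must agree with element-wise reads over the same range.
    static int[] readIntsOneByOne(ArrowColumnVector vector, int rowId, int count) {
      int[] expected = new int[count];
      for (int i = 0; i < count; ++i) {
        expected[i] = vector.getInt(rowId + i);  // single-value accessor path
      }
      return expected;  // should equal vector.getInts(rowId, count)
    }
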
--- .../vectorized/ArrowColumnVectorSuite.scala | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala index 433a65a84117a..d24a9e1f4bd16 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala @@ -50,6 +50,8 @@ class ArrowColumnVectorSuite extends SparkFunSuite { } assert(columnVector.isNullAt(10)) + assert(columnVector.getBooleans(0, 10) === (0 until 10).map(i => (i % 2 == 0))) + columnVector.close() allocator.close() } @@ -77,6 +79,8 @@ class ArrowColumnVectorSuite extends SparkFunSuite { } assert(columnVector.isNullAt(10)) + assert(columnVector.getBytes(0, 10) === (0 until 10).map(i => i.toByte)) + columnVector.close() allocator.close() } @@ -104,6 +108,8 @@ class ArrowColumnVectorSuite extends SparkFunSuite { } assert(columnVector.isNullAt(10)) + assert(columnVector.getShorts(0, 10) === (0 until 10).map(i => i.toShort)) + columnVector.close() allocator.close() } @@ -131,6 +137,8 @@ class ArrowColumnVectorSuite extends SparkFunSuite { } assert(columnVector.isNullAt(10)) + assert(columnVector.getInts(0, 10) === (0 until 10)) + columnVector.close() allocator.close() } @@ -158,6 +166,8 @@ class ArrowColumnVectorSuite extends SparkFunSuite { } assert(columnVector.isNullAt(10)) + assert(columnVector.getLongs(0, 10) === (0 until 10).map(i => i.toLong)) + columnVector.close() allocator.close() } @@ -185,6 +195,8 @@ class ArrowColumnVectorSuite extends SparkFunSuite { } assert(columnVector.isNullAt(10)) + assert(columnVector.getFloats(0, 10) === (0 until 10).map(i => i.toFloat)) + columnVector.close() allocator.close() } @@ -212,6 +224,8 @@ class ArrowColumnVectorSuite extends SparkFunSuite { } assert(columnVector.isNullAt(10)) + assert(columnVector.getDoubles(0, 10) === (0 until 10).map(i => i.toDouble)) + columnVector.close() allocator.close() } From 91b94ef6d08771fe8e5eb5d41f43153af9a75f06 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Thu, 20 Jul 2017 13:18:48 +0900 Subject: [PATCH 6/8] Modify ReadOnlyColumnVector to accept dataType for the constructor argument. 
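
Passing the Spark DataType into the ReadOnlyColumnVector constructor means subclasses no longer assign the protected type field after calling super; ArrowColumnVector derives it from the Arrow field up front via ArrowUtils.fromArrowField. A hedged sketch of what a caller observes (fragment only; it assumes an already populated Arrow NullableIntVector named vector and uses ColumnVector's dataType() accessor):

    ArrowColumnVector columnVector = new ArrowColumnVector(vector);
    // The element type is fixed at construction time from the Arrow field,
    // e.g. a signed 32-bit Arrow Int maps to Spark's IntegerType.
    assert columnVector.dataType() == DataTypes.IntegerType;
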
--- .../spark/sql/execution/vectorized/ArrowColumnVector.java | 4 ++-- .../spark/sql/execution/vectorized/ReadOnlyColumnVector.java | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java index 332a644b4656c..202b19e8df9d5 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java @@ -240,9 +240,9 @@ public byte[] getBinary(int rowId) { } public ArrowColumnVector(ValueVector vector) { - super(vector.getValueCapacity(), MemoryMode.OFF_HEAP); + super(vector.getValueCapacity(), ArrowUtils.fromArrowField(vector.getField()), + MemoryMode.OFF_HEAP); - type = ArrowUtils.fromArrowField(vector.getField()); if (vector instanceof NullableBitVector) { accessor = new BooleanAccessor((NullableBitVector) vector); } else if (vector instanceof NullableTinyIntVector) { diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ReadOnlyColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ReadOnlyColumnVector.java index 8396f99d1329b..e9f6e7c631fd4 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ReadOnlyColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ReadOnlyColumnVector.java @@ -25,8 +25,9 @@ */ public abstract class ReadOnlyColumnVector extends ColumnVector { - protected ReadOnlyColumnVector(int capacity, MemoryMode memMode) { + protected ReadOnlyColumnVector(int capacity, DataType type, MemoryMode memMode) { super(capacity, DataTypes.NullType, memMode); + this.type = type; isConstant = true; } From afdaf5a3a952e1a06ccf7f13fcc2e1d1318aee99 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Thu, 20 Jul 2017 17:35:49 +0900 Subject: [PATCH 7/8] Fix a comment. --- .../spark/sql/execution/vectorized/ArrowColumnVector.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java index 202b19e8df9d5..b874f7399b808 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java @@ -27,7 +27,7 @@ import org.apache.spark.unsafe.types.UTF8String; /** - * A column backed by Apache Arrow. + * A column vector backed by Apache Arrow. */ public final class ArrowColumnVector extends ReadOnlyColumnVector { From 2d1dad9ac6bc2cfa4a4dcad32ef99464bc7f6541 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Thu, 20 Jul 2017 19:16:33 +0900 Subject: [PATCH 8/8] Add boundary check. 
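
Every row-indexed read now goes through ensureAccessible, which validates the index (or index range) against the vector's value count, so out-of-range indices fail fast with an IndexOutOfBoundsException and a clear message rather than being passed straight to the underlying Arrow accessor. A small sketch of the expected behaviour (assumes columnVector wraps an Arrow vector holding n values; not actual test code):

    try {
      columnVector.getInt(n);  // one past the last valid row
    } catch (IndexOutOfBoundsException e) {
      // message formatted as "index: <n>, valueCount: <n>"
    }
    // Ranged reads check the end of the range too:
    // getInts(n - 1, 2) throws because (n - 1) + 2 > valueCount.
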
--- .../vectorized/ArrowColumnVector.java | 47 ++++++++++++++++++- 1 file changed, 46 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java index b874f7399b808..68e0abc11c39d 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java @@ -32,6 +32,21 @@ public final class ArrowColumnVector extends ReadOnlyColumnVector { private final ArrowVectorAccessor accessor; + private final int valueCount; + + private void ensureAccessible(int index) { + if (index < 0 || index >= valueCount) { + throw new IndexOutOfBoundsException( + String.format("index: %d, valueCount: %d", index, valueCount)); + } + } + + private void ensureAccessible(int index, int count) { + if (index < 0 || index + count > valueCount) { + throw new IndexOutOfBoundsException( + String.format("index range: [%d, %d), valueCount: %d", index, index + count, valueCount)); + } + } @Override public long nullsNativeAddress() { @@ -59,6 +74,7 @@ public void close() { @Override public boolean isNullAt(int rowId) { + ensureAccessible(rowId); return accessor.isNullAt(rowId); } @@ -68,11 +84,13 @@ public boolean isNullAt(int rowId) { @Override public boolean getBoolean(int rowId) { + ensureAccessible(rowId); return accessor.getBoolean(rowId); } @Override public boolean[] getBooleans(int rowId, int count) { + ensureAccessible(rowId, count); boolean[] array = new boolean[count]; for (int i = 0; i < count; ++i) { array[i] = accessor.getBoolean(rowId + i); @@ -86,11 +104,13 @@ public boolean[] getBooleans(int rowId, int count) { @Override public byte getByte(int rowId) { + ensureAccessible(rowId); return accessor.getByte(rowId); } @Override public byte[] getBytes(int rowId, int count) { + ensureAccessible(rowId, count); byte[] array = new byte[count]; for (int i = 0; i < count; ++i) { array[i] = accessor.getByte(rowId + i); @@ -104,11 +124,13 @@ public byte[] getBytes(int rowId, int count) { @Override public short getShort(int rowId) { + ensureAccessible(rowId); return accessor.getShort(rowId); } @Override public short[] getShorts(int rowId, int count) { + ensureAccessible(rowId, count); short[] array = new short[count]; for (int i = 0; i < count; ++i) { array[i] = accessor.getShort(rowId + i); @@ -122,11 +144,13 @@ public short[] getShorts(int rowId, int count) { @Override public int getInt(int rowId) { + ensureAccessible(rowId); return accessor.getInt(rowId); } @Override public int[] getInts(int rowId, int count) { + ensureAccessible(rowId, count); int[] array = new int[count]; for (int i = 0; i < count; ++i) { array[i] = accessor.getInt(rowId + i); @@ -145,11 +169,13 @@ public int getDictId(int rowId) { @Override public long getLong(int rowId) { + ensureAccessible(rowId); return accessor.getLong(rowId); } @Override public long[] getLongs(int rowId, int count) { + ensureAccessible(rowId, count); long[] array = new long[count]; for (int i = 0; i < count; ++i) { array[i] = accessor.getLong(rowId + i); @@ -163,11 +189,13 @@ public long[] getLongs(int rowId, int count) { @Override public float getFloat(int rowId) { + ensureAccessible(rowId); return accessor.getFloat(rowId); } @Override public float[] getFloats(int rowId, int count) { + ensureAccessible(rowId, count); float[] array = new float[count]; for (int i = 0; i < count; ++i) { array[i] = 
accessor.getFloat(rowId + i); @@ -181,11 +209,13 @@ public float[] getFloats(int rowId, int count) { @Override public double getDouble(int rowId) { + ensureAccessible(rowId); return accessor.getDouble(rowId); } @Override public double[] getDoubles(int rowId, int count) { + ensureAccessible(rowId, count); double[] array = new double[count]; for (int i = 0; i < count; ++i) { array[i] = accessor.getDouble(rowId + i); @@ -199,11 +229,13 @@ public double[] getDoubles(int rowId, int count) { @Override public int getArrayLength(int rowId) { + ensureAccessible(rowId); return accessor.getArrayLength(rowId); } @Override public int getArrayOffset(int rowId) { + ensureAccessible(rowId); return accessor.getArrayOffset(rowId); } @@ -218,6 +250,7 @@ public void loadBytes(Array array) { @Override public Decimal getDecimal(int rowId, int precision, int scale) { + ensureAccessible(rowId); return accessor.getDecimal(rowId, precision, scale); } @@ -227,6 +260,7 @@ public Decimal getDecimal(int rowId, int precision, int scale) { @Override public UTF8String getUTF8String(int rowId) { + ensureAccessible(rowId); return accessor.getUTF8String(rowId); } @@ -236,6 +270,7 @@ public UTF8String getUTF8String(int rowId) { @Override public byte[] getBinary(int rowId) { + ensureAccessible(rowId); return accessor.getBinary(rowId); } @@ -282,6 +317,7 @@ public ArrowColumnVector(ValueVector vector) { } else { throw new UnsupportedOperationException(); } + valueCount = accessor.getValueCount(); numNulls = accessor.getNullCount(); anyNullsSet = numNulls > 0; } @@ -291,17 +327,26 @@ private static abstract class ArrowVectorAccessor { private final ValueVector vector; private final ValueVector.Accessor nulls; + private final int valueCount; + private final int nullCount; + ArrowVectorAccessor(ValueVector vector) { this.vector = vector; this.nulls = vector.getAccessor(); + this.valueCount = nulls.getValueCount(); + this.nullCount = nulls.getNullCount(); } final boolean isNullAt(int rowId) { return nulls.isNull(rowId); } + final int getValueCount() { + return valueCount; + } + final int getNullCount() { - return nulls.getNullCount(); + return nullCount; } final void close() {