diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 2c714c228e6c9..f96ed7628fda1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -688,17 +688,13 @@ class CodegenContext { /** * Returns the specialized code to access a value from a column vector for a given `DataType`. */ - def getValue(vector: String, rowId: String, dataType: DataType): String = { - val jt = javaType(dataType) - dataType match { - case _ if isPrimitiveType(jt) => - s"$vector.get${primitiveTypeName(jt)}($rowId)" - case t: DecimalType => - s"$vector.getDecimal($rowId, ${t.precision}, ${t.scale})" - case StringType => - s"$vector.getUTF8String($rowId)" - case _ => - throw new IllegalArgumentException(s"cannot generate code for unsupported type: $dataType") + def getValueFromVector(vector: String, dataType: DataType, rowId: String): String = { + if (dataType.isInstanceOf[StructType]) { + // `ColumnVector.getStruct` is different from `InternalRow.getStruct`, it only takes an + // `ordinal` parameter. + s"$vector.getStruct($rowId)" + } else { + getValue(vector, dataType, rowId) } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java index b6e792274da11..aaf2a380034a9 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java @@ -110,57 +110,21 @@ public boolean getBoolean(int rowId) { return longData.vector[getRowIndex(rowId)] == 1; } - @Override - public boolean[] getBooleans(int rowId, int count) { - boolean[] res = new boolean[count]; - for (int i = 0; i < count; i++) { - res[i] = getBoolean(rowId + i); - } - return res; - } - @Override public byte getByte(int rowId) { return (byte) longData.vector[getRowIndex(rowId)]; } - @Override - public byte[] getBytes(int rowId, int count) { - byte[] res = new byte[count]; - for (int i = 0; i < count; i++) { - res[i] = getByte(rowId + i); - } - return res; - } - @Override public short getShort(int rowId) { return (short) longData.vector[getRowIndex(rowId)]; } - @Override - public short[] getShorts(int rowId, int count) { - short[] res = new short[count]; - for (int i = 0; i < count; i++) { - res[i] = getShort(rowId + i); - } - return res; - } - @Override public int getInt(int rowId) { return (int) longData.vector[getRowIndex(rowId)]; } - @Override - public int[] getInts(int rowId, int count) { - int[] res = new int[count]; - for (int i = 0; i < count; i++) { - res[i] = getInt(rowId + i); - } - return res; - } - @Override public long getLong(int rowId) { int index = getRowIndex(rowId); @@ -171,43 +135,16 @@ public long getLong(int rowId) { } } - @Override - public long[] getLongs(int rowId, int count) { - long[] res = new long[count]; - for (int i = 0; i < count; i++) { - res[i] = getLong(rowId + i); - } - return res; - } - @Override public float getFloat(int rowId) { return (float) doubleData.vector[getRowIndex(rowId)]; } - @Override - public float[] getFloats(int rowId, int count) { - float[] res = new float[count]; - for (int i = 0; i < count; i++) { - res[i] = 
getFloat(rowId + i); - } - return res; - } - @Override public double getDouble(int rowId) { return doubleData.vector[getRowIndex(rowId)]; } - @Override - public double[] getDoubles(int rowId, int count) { - double[] res = new double[count]; - for (int i = 0; i < count; i++) { - res[i] = getDouble(rowId + i); - } - return res; - } - @Override public int getArrayLength(int rowId) { throw new UnsupportedOperationException(); @@ -245,7 +182,7 @@ public org.apache.spark.sql.vectorized.ColumnVector arrayData() { } @Override - public org.apache.spark.sql.vectorized.ColumnVector getChildColumn(int ordinal) { + public org.apache.spark.sql.vectorized.ColumnVector getChild(int ordinal) { throw new UnsupportedOperationException(); } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java index 36fdf2bdf84d2..8612510f4f7bc 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java @@ -302,10 +302,9 @@ private void putRepeatingValues( toColumn.putDoubles(0, batchSize, ((DoubleColumnVector)fromColumn).vector[0]); } else if (type instanceof StringType || type instanceof BinaryType) { BytesColumnVector data = (BytesColumnVector)fromColumn; - WritableColumnVector arrayData = toColumn.getChildColumn(0); int size = data.vector[0].length; - arrayData.reserve(size); - arrayData.putBytes(0, size, data.vector[0], 0); + toColumn.arrayData().reserve(size); + toColumn.arrayData().putBytes(0, size, data.vector[0], 0); for (int index = 0; index < batchSize; index++) { toColumn.putArray(index, 0, size); } @@ -365,7 +364,7 @@ private void putNonNullValues( toColumn.putDoubles(0, batchSize, ((DoubleColumnVector)fromColumn).vector, 0); } else if (type instanceof StringType || type instanceof BinaryType) { BytesColumnVector data = ((BytesColumnVector)fromColumn); - WritableColumnVector arrayData = toColumn.getChildColumn(0); + WritableColumnVector arrayData = toColumn.arrayData(); int totalNumBytes = IntStream.of(data.length).sum(); arrayData.reserve(totalNumBytes); for (int index = 0, pos = 0; index < batchSize; pos += data.length[index], index++) { @@ -376,8 +375,7 @@ private void putNonNullValues( DecimalType decimalType = (DecimalType)type; DecimalColumnVector data = ((DecimalColumnVector)fromColumn); if (decimalType.precision() > Decimal.MAX_LONG_DIGITS()) { - WritableColumnVector arrayData = toColumn.getChildColumn(0); - arrayData.reserve(batchSize * 16); + toColumn.arrayData().reserve(batchSize * 16); } for (int index = 0; index < batchSize; index++) { putDecimalWritable( @@ -472,7 +470,7 @@ private void putValues( } } else if (type instanceof StringType || type instanceof BinaryType) { BytesColumnVector vector = (BytesColumnVector)fromColumn; - WritableColumnVector arrayData = toColumn.getChildColumn(0); + WritableColumnVector arrayData = toColumn.arrayData(); int totalNumBytes = IntStream.of(vector.length).sum(); arrayData.reserve(totalNumBytes); for (int index = 0, pos = 0; index < batchSize; pos += vector.length[index], index++) { @@ -487,8 +485,7 @@ private void putValues( DecimalType decimalType = (DecimalType)type; HiveDecimalWritable[] vector = ((DecimalColumnVector)fromColumn).vector; if (decimalType.precision() > Decimal.MAX_LONG_DIGITS()) { - WritableColumnVector arrayData = 
toColumn.getChildColumn(0); - arrayData.reserve(batchSize * 16); + toColumn.arrayData().reserve(batchSize * 16); } for (int index = 0; index < batchSize; index++) { if (fromColumn.isNull[index]) { @@ -534,8 +531,7 @@ private static void putDecimalWritable( toColumn.putLong(index, value.toUnscaledLong()); } else { byte[] bytes = value.toJavaBigDecimal().unscaledValue().toByteArray(); - WritableColumnVector arrayData = toColumn.getChildColumn(0); - arrayData.putBytes(index * 16, bytes.length, bytes, 0); + toColumn.arrayData().putBytes(index * 16, bytes.length, bytes, 0); toColumn.putArray(index, index * 16, bytes.length); } } @@ -560,9 +556,8 @@ private static void putDecimalWritables( toColumn.putLongs(0, size, value.toUnscaledLong()); } else { byte[] bytes = value.toJavaBigDecimal().unscaledValue().toByteArray(); - WritableColumnVector arrayData = toColumn.getChildColumn(0); - arrayData.reserve(bytes.length); - arrayData.putBytes(0, bytes.length, bytes, 0); + toColumn.arrayData().reserve(bytes.length); + toColumn.arrayData().putBytes(0, bytes.length, bytes, 0); for (int index = 0; index < size; index++) { toColumn.putArray(index, 0, bytes.length); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java index b5cbe8e2839ba..5108fc211a0d8 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java @@ -85,8 +85,8 @@ public static void populate(WritableColumnVector col, InternalRow row, int field } } else if (t instanceof CalendarIntervalType) { CalendarInterval c = (CalendarInterval)row.get(fieldIdx, t); - col.getChildColumn(0).putInts(0, capacity, c.months); - col.getChildColumn(1).putLongs(0, capacity, c.microseconds); + col.getChild(0).putInts(0, capacity, c.months); + col.getChild(1).putLongs(0, capacity, c.microseconds); } else if (t instanceof DateType) { col.putInts(0, capacity, row.getInt(fieldIdx)); } else if (t instanceof TimestampType) { @@ -149,8 +149,8 @@ private static void appendValue(WritableColumnVector dst, DataType t, Object o) } else if (t instanceof CalendarIntervalType) { CalendarInterval c = (CalendarInterval)o; dst.appendStruct(false); - dst.getChildColumn(0).appendInt(c.months); - dst.getChildColumn(1).appendLong(c.microseconds); + dst.getChild(0).appendInt(c.months); + dst.getChild(1).appendLong(c.microseconds); } else if (t instanceof DateType) { dst.appendInt(DateTimeUtils.fromJavaDate((Date)o)); } else { @@ -179,7 +179,7 @@ private static void appendValue(WritableColumnVector dst, DataType t, Row src, i dst.appendStruct(false); Row c = src.getStruct(fieldIdx); for (int i = 0; i < st.fields().length; i++) { - appendValue(dst.getChildColumn(i), st.fields()[i].dataType(), c, i); + appendValue(dst.getChild(i), st.fields()[i].dataType(), c, i); } } } else { diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java index 70057a9def6c0..2bab095d4d951 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java @@ -146,8 +146,8 @@ public byte[] getBinary(int ordinal) { @Override public CalendarInterval getInterval(int ordinal) { if 
(columns[ordinal].isNullAt(rowId)) return null; - final int months = columns[ordinal].getChildColumn(0).getInt(rowId); - final long microseconds = columns[ordinal].getChildColumn(1).getLong(rowId); + final int months = columns[ordinal].getChild(0).getInt(rowId); + final long microseconds = columns[ordinal].getChild(1).getLong(rowId); return new CalendarInterval(months, microseconds); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java index d2ae32b06f83b..ca4f00985c2a3 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java @@ -599,17 +599,13 @@ public final int appendStruct(boolean isNull) { return elementsAppended; } - /** - * Returns the data for the underlying array. - */ + // `WritableColumnVector` puts the data of array in the first child column vector, and puts the + // array offsets and lengths in the current column vector. @Override public WritableColumnVector arrayData() { return childColumns[0]; } - /** - * Returns the ordinal's child data column. - */ @Override - public WritableColumnVector getChildColumn(int ordinal) { return childColumns[ordinal]; } + public WritableColumnVector getChild(int ordinal) { return childColumns[ordinal]; } /** * Returns the elements appended. diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java index eb69001fe677e..ff163c2220041 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java @@ -33,18 +33,6 @@ public final class ArrowColumnVector extends ColumnVector { private final ArrowVectorAccessor accessor; private ArrowColumnVector[] childColumns; - private void ensureAccessible(int index) { - ensureAccessible(index, 1); - } - - private void ensureAccessible(int index, int count) { - int valueCount = accessor.getValueCount(); - if (index < 0 || index + count > valueCount) { - throw new IndexOutOfBoundsException( - String.format("index range: [%d, %d), valueCount: %d", index, index + count, valueCount)); - } - } - @Override public int numNulls() { return accessor.getNullCount(); @@ -55,156 +43,75 @@ public void close() { if (childColumns != null) { for (int i = 0; i < childColumns.length; i++) { childColumns[i].close(); + childColumns[i] = null; } + childColumns = null; } accessor.close(); } @Override public boolean isNullAt(int rowId) { - ensureAccessible(rowId); return accessor.isNullAt(rowId); } @Override public boolean getBoolean(int rowId) { - ensureAccessible(rowId); return accessor.getBoolean(rowId); } - @Override - public boolean[] getBooleans(int rowId, int count) { - ensureAccessible(rowId, count); - boolean[] array = new boolean[count]; - for (int i = 0; i < count; ++i) { - array[i] = accessor.getBoolean(rowId + i); - } - return array; - } - @Override public byte getByte(int rowId) { - ensureAccessible(rowId); return accessor.getByte(rowId); } - @Override - public byte[] getBytes(int rowId, int count) { - ensureAccessible(rowId, count); - byte[] array = new byte[count]; - for (int i = 0; i < count; ++i) { - array[i] = accessor.getByte(rowId + i); - } - return array; - } - @Override public short getShort(int rowId) { - 
ensureAccessible(rowId); return accessor.getShort(rowId); } - @Override - public short[] getShorts(int rowId, int count) { - ensureAccessible(rowId, count); - short[] array = new short[count]; - for (int i = 0; i < count; ++i) { - array[i] = accessor.getShort(rowId + i); - } - return array; - } - @Override public int getInt(int rowId) { - ensureAccessible(rowId); return accessor.getInt(rowId); } - @Override - public int[] getInts(int rowId, int count) { - ensureAccessible(rowId, count); - int[] array = new int[count]; - for (int i = 0; i < count; ++i) { - array[i] = accessor.getInt(rowId + i); - } - return array; - } - @Override public long getLong(int rowId) { - ensureAccessible(rowId); return accessor.getLong(rowId); } - @Override - public long[] getLongs(int rowId, int count) { - ensureAccessible(rowId, count); - long[] array = new long[count]; - for (int i = 0; i < count; ++i) { - array[i] = accessor.getLong(rowId + i); - } - return array; - } - @Override public float getFloat(int rowId) { - ensureAccessible(rowId); return accessor.getFloat(rowId); } - @Override - public float[] getFloats(int rowId, int count) { - ensureAccessible(rowId, count); - float[] array = new float[count]; - for (int i = 0; i < count; ++i) { - array[i] = accessor.getFloat(rowId + i); - } - return array; - } - @Override public double getDouble(int rowId) { - ensureAccessible(rowId); return accessor.getDouble(rowId); } - @Override - public double[] getDoubles(int rowId, int count) { - ensureAccessible(rowId, count); - double[] array = new double[count]; - for (int i = 0; i < count; ++i) { - array[i] = accessor.getDouble(rowId + i); - } - return array; - } - @Override public int getArrayLength(int rowId) { - ensureAccessible(rowId); return accessor.getArrayLength(rowId); } @Override public int getArrayOffset(int rowId) { - ensureAccessible(rowId); return accessor.getArrayOffset(rowId); } @Override public Decimal getDecimal(int rowId, int precision, int scale) { - ensureAccessible(rowId); return accessor.getDecimal(rowId, precision, scale); } @Override public UTF8String getUTF8String(int rowId) { - ensureAccessible(rowId); return accessor.getUTF8String(rowId); } @Override public byte[] getBinary(int rowId) { - ensureAccessible(rowId); return accessor.getBinary(rowId); } @@ -212,7 +119,7 @@ public byte[] getBinary(int rowId) { public ArrowColumnVector arrayData() { return childColumns[0]; } @Override - public ArrowColumnVector getChildColumn(int ordinal) { return childColumns[ordinal]; } + public ArrowColumnVector getChild(int ordinal) { return childColumns[ordinal]; } public ArrowColumnVector(ValueVector vector) { super(ArrowUtils.fromArrowField(vector.getField())); diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java index d1196e1299fee..f9936214035b6 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java @@ -51,12 +51,16 @@ public abstract class ColumnVector implements AutoCloseable { public final DataType dataType() { return type; } /** - * Cleans up memory for this column. The column is not usable after this. + * Cleans up memory for this column vector. The column vector is not usable after this. + * + * This overrides `AutoCloseable.close` to remove the `throws` clause, as column vectors are + * in-memory and we don't expect any exception to happen during closing. 
*/ + @Override public abstract void close(); /** - * Returns the number of nulls in this column. + * Returns the number of nulls in this column vector. */ public abstract int numNulls(); @@ -73,7 +77,13 @@ public abstract class ColumnVector implements AutoCloseable { /** * Gets values from [rowId, rowId + count) */ - public abstract boolean[] getBooleans(int rowId, int count); + public boolean[] getBooleans(int rowId, int count) { + boolean[] res = new boolean[count]; + for (int i = 0; i < count; i++) { + res[i] = getBoolean(rowId + i); + } + return res; + } /** * Returns the value for rowId. @@ -83,7 +93,13 @@ public abstract class ColumnVector implements AutoCloseable { /** * Gets values from [rowId, rowId + count) */ - public abstract byte[] getBytes(int rowId, int count); + public byte[] getBytes(int rowId, int count) { + byte[] res = new byte[count]; + for (int i = 0; i < count; i++) { + res[i] = getByte(rowId + i); + } + return res; + } /** * Returns the value for rowId. @@ -93,7 +109,13 @@ public abstract class ColumnVector implements AutoCloseable { /** * Gets values from [rowId, rowId + count) */ - public abstract short[] getShorts(int rowId, int count); + public short[] getShorts(int rowId, int count) { + short[] res = new short[count]; + for (int i = 0; i < count; i++) { + res[i] = getShort(rowId + i); + } + return res; + } /** * Returns the value for rowId. @@ -103,7 +125,13 @@ public abstract class ColumnVector implements AutoCloseable { /** * Gets values from [rowId, rowId + count) */ - public abstract int[] getInts(int rowId, int count); + public int[] getInts(int rowId, int count) { + int[] res = new int[count]; + for (int i = 0; i < count; i++) { + res[i] = getInt(rowId + i); + } + return res; + } /** * Returns the value for rowId. @@ -113,7 +141,13 @@ public abstract class ColumnVector implements AutoCloseable { /** * Gets values from [rowId, rowId + count) */ - public abstract long[] getLongs(int rowId, int count); + public long[] getLongs(int rowId, int count) { + long[] res = new long[count]; + for (int i = 0; i < count; i++) { + res[i] = getLong(rowId + i); + } + return res; + } /** * Returns the value for rowId. @@ -123,7 +157,13 @@ public abstract class ColumnVector implements AutoCloseable { /** * Gets values from [rowId, rowId + count) */ - public abstract float[] getFloats(int rowId, int count); + public float[] getFloats(int rowId, int count) { + float[] res = new float[count]; + for (int i = 0; i < count; i++) { + res[i] = getFloat(rowId + i); + } + return res; + } /** * Returns the value for rowId. @@ -133,7 +173,13 @@ public abstract class ColumnVector implements AutoCloseable { /** * Gets values from [rowId, rowId + count) */ - public abstract double[] getDoubles(int rowId, int count); + public double[] getDoubles(int rowId, int count) { + double[] res = new double[count]; + for (int i = 0; i < count; i++) { + res[i] = getDouble(rowId + i); + } + return res; + } /** * Returns the length of the array for rowId. @@ -152,14 +198,6 @@ public final ColumnarRow getStruct(int rowId) { return new ColumnarRow(this, rowId); } - /** - * A special version of {@link #getStruct(int)}, which is only used as an adapter for Spark - * codegen framework, the second parameter is totally ignored. - */ - public final ColumnarRow getStruct(int rowId, int size) { - return getStruct(rowId); - } - /** * Returns the array for rowId. */ @@ -196,9 +234,9 @@ public MapData getMap(int ordinal) { public abstract ColumnVector arrayData(); /** - * Returns the ordinal's child data column. 
+ * Returns the ordinal's child column vector. */ - public abstract ColumnVector getChildColumn(int ordinal); + public abstract ColumnVector getChild(int ordinal); /** * Data type for this column. @@ -206,8 +244,7 @@ public MapData getMap(int ordinal) { protected DataType type; /** - * Sets up the common state and also handles creating the child columns if this is a nested - * type. + * Sets up the data type of this column vector. */ protected ColumnVector(DataType type) { this.type = type; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java index 0d89a52e7a4fe..522c39580389f 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java @@ -133,8 +133,8 @@ public byte[] getBinary(int ordinal) { @Override public CalendarInterval getInterval(int ordinal) { - int month = data.getChildColumn(0).getInt(offset + ordinal); - long microseconds = data.getChildColumn(1).getLong(offset + ordinal); + int month = data.getChild(0).getInt(offset + ordinal); + long microseconds = data.getChild(1).getLong(offset + ordinal); return new CalendarInterval(month, microseconds); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java index 3c6656dec77cd..2e59085a82768 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java @@ -28,7 +28,7 @@ */ public final class ColumnarRow extends InternalRow { // The data for this row. - // E.g. the value of 3rd int field is `data.getChildColumn(3).getInt(rowId)`. + // E.g. the value of the int field at ordinal 3 is `data.getChild(3).getInt(rowId)`. 
private final ColumnVector data; private final int rowId; private final int numFields; @@ -53,7 +53,7 @@ public InternalRow copy() { if (isNullAt(i)) { row.setNullAt(i); } else { - DataType dt = data.getChildColumn(i).dataType(); + DataType dt = data.getChild(i).dataType(); if (dt instanceof BooleanType) { row.setBoolean(i, getBoolean(i)); } else if (dt instanceof ByteType) { @@ -93,65 +93,65 @@ public boolean anyNull() { } @Override - public boolean isNullAt(int ordinal) { return data.getChildColumn(ordinal).isNullAt(rowId); } + public boolean isNullAt(int ordinal) { return data.getChild(ordinal).isNullAt(rowId); } @Override - public boolean getBoolean(int ordinal) { return data.getChildColumn(ordinal).getBoolean(rowId); } + public boolean getBoolean(int ordinal) { return data.getChild(ordinal).getBoolean(rowId); } @Override - public byte getByte(int ordinal) { return data.getChildColumn(ordinal).getByte(rowId); } + public byte getByte(int ordinal) { return data.getChild(ordinal).getByte(rowId); } @Override - public short getShort(int ordinal) { return data.getChildColumn(ordinal).getShort(rowId); } + public short getShort(int ordinal) { return data.getChild(ordinal).getShort(rowId); } @Override - public int getInt(int ordinal) { return data.getChildColumn(ordinal).getInt(rowId); } + public int getInt(int ordinal) { return data.getChild(ordinal).getInt(rowId); } @Override - public long getLong(int ordinal) { return data.getChildColumn(ordinal).getLong(rowId); } + public long getLong(int ordinal) { return data.getChild(ordinal).getLong(rowId); } @Override - public float getFloat(int ordinal) { return data.getChildColumn(ordinal).getFloat(rowId); } + public float getFloat(int ordinal) { return data.getChild(ordinal).getFloat(rowId); } @Override - public double getDouble(int ordinal) { return data.getChildColumn(ordinal).getDouble(rowId); } + public double getDouble(int ordinal) { return data.getChild(ordinal).getDouble(rowId); } @Override public Decimal getDecimal(int ordinal, int precision, int scale) { - if (data.getChildColumn(ordinal).isNullAt(rowId)) return null; - return data.getChildColumn(ordinal).getDecimal(rowId, precision, scale); + if (data.getChild(ordinal).isNullAt(rowId)) return null; + return data.getChild(ordinal).getDecimal(rowId, precision, scale); } @Override public UTF8String getUTF8String(int ordinal) { - if (data.getChildColumn(ordinal).isNullAt(rowId)) return null; - return data.getChildColumn(ordinal).getUTF8String(rowId); + if (data.getChild(ordinal).isNullAt(rowId)) return null; + return data.getChild(ordinal).getUTF8String(rowId); } @Override public byte[] getBinary(int ordinal) { - if (data.getChildColumn(ordinal).isNullAt(rowId)) return null; - return data.getChildColumn(ordinal).getBinary(rowId); + if (data.getChild(ordinal).isNullAt(rowId)) return null; + return data.getChild(ordinal).getBinary(rowId); } @Override public CalendarInterval getInterval(int ordinal) { - if (data.getChildColumn(ordinal).isNullAt(rowId)) return null; - final int months = data.getChildColumn(ordinal).getChildColumn(0).getInt(rowId); - final long microseconds = data.getChildColumn(ordinal).getChildColumn(1).getLong(rowId); + if (data.getChild(ordinal).isNullAt(rowId)) return null; + final int months = data.getChild(ordinal).getChild(0).getInt(rowId); + final long microseconds = data.getChild(ordinal).getChild(1).getLong(rowId); return new CalendarInterval(months, microseconds); } @Override public ColumnarRow getStruct(int ordinal, int numFields) { - if 
(data.getChildColumn(ordinal).isNullAt(rowId)) return null; - return data.getChildColumn(ordinal).getStruct(rowId); + if (data.getChild(ordinal).isNullAt(rowId)) return null; + return data.getChild(ordinal).getStruct(rowId); } @Override public ColumnarArray getArray(int ordinal) { - if (data.getChildColumn(ordinal).isNullAt(rowId)) return null; - return data.getChildColumn(ordinal).getArray(rowId); + if (data.getChild(ordinal).isNullAt(rowId)) return null; + return data.getChild(ordinal).getArray(rowId); } @Override diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala index dd68df9686691..04f2619ed7541 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala @@ -50,7 +50,7 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { dataType: DataType, nullable: Boolean): ExprCode = { val javaType = ctx.javaType(dataType) - val value = ctx.getValue(columnVar, dataType, ordinal) + val value = ctx.getValueFromVector(columnVar, dataType, ordinal) val isNullVar = if (nullable) { ctx.freshName("isNull") } else { "false" } val valueVar = ctx.freshName("value") val str = s"columnVector[$columnVar, $ordinal, ${dataType.simpleString}]" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala index 0cf9b53ce1d5d..40b1d7b016dd8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala @@ -127,8 +127,8 @@ class VectorizedHashMapGenerator( def genEqualsForKeys(groupingKeys: Seq[Buffer]): String = { groupingKeys.zipWithIndex.map { case (key: Buffer, ordinal: Int) => - s"""(${ctx.genEqual(key.dataType, ctx.getValue(s"vectors[$ordinal]", "buckets[idx]", - key.dataType), key.name)})""" + val value = ctx.getValueFromVector(s"vectors[$ordinal]", key.dataType, "buckets[idx]") + s"(${ctx.genEqual(key.dataType, value, key.name)})" }.mkString(" && ") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala index c42bc60a59d67..92506032ab2e5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala @@ -217,21 +217,21 @@ class ArrowWriterSuite extends SparkFunSuite { val reader = new ArrowColumnVector(writer.root.getFieldVectors().get(0)) - val struct0 = reader.getStruct(0, 2) + val struct0 = reader.getStruct(0) assert(struct0.getInt(0) === 1) assert(struct0.getUTF8String(1) === UTF8String.fromString("str1")) - val struct1 = reader.getStruct(1, 2) + val struct1 = reader.getStruct(1) assert(struct1.isNullAt(0)) assert(struct1.isNullAt(1)) assert(reader.isNullAt(2)) - val struct3 = reader.getStruct(3, 2) + val struct3 = reader.getStruct(3) assert(struct3.getInt(0) === 4) assert(struct3.isNullAt(1)) - val struct4 = reader.getStruct(4, 2) + val struct4 = reader.getStruct(4) assert(struct4.isNullAt(0)) assert(struct4.getUTF8String(1) === UTF8String.fromString("str5")) @@ -252,15 +252,15 @@ class ArrowWriterSuite extends 
SparkFunSuite { val reader = new ArrowColumnVector(writer.root.getFieldVectors().get(0)) - val struct00 = reader.getStruct(0, 1).getStruct(0, 2) + val struct00 = reader.getStruct(0).getStruct(0, 2) assert(struct00.getInt(0) === 1) assert(struct00.getUTF8String(1) === UTF8String.fromString("str1")) - val struct10 = reader.getStruct(1, 1).getStruct(0, 2) + val struct10 = reader.getStruct(1).getStruct(0, 2) assert(struct10.isNullAt(0)) assert(struct10.isNullAt(1)) - val struct2 = reader.getStruct(2, 1) + val struct2 = reader.getStruct(2) assert(struct2.isNullAt(0)) assert(reader.isNullAt(3)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala index 53432669e215d..e794f50781ff2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala @@ -346,11 +346,11 @@ class ArrowColumnVectorSuite extends SparkFunSuite { assert(columnVector.dataType === schema) assert(columnVector.numNulls === 0) - val row0 = columnVector.getStruct(0, 2) + val row0 = columnVector.getStruct(0) assert(row0.getInt(0) === 1) assert(row0.getLong(1) === 1L) - val row1 = columnVector.getStruct(1, 2) + val row1 = columnVector.getStruct(1) assert(row1.getInt(0) === 2) assert(row1.isNullAt(1)) @@ -398,21 +398,21 @@ class ArrowColumnVectorSuite extends SparkFunSuite { assert(columnVector.dataType === schema) assert(columnVector.numNulls === 1) - val row0 = columnVector.getStruct(0, 2) + val row0 = columnVector.getStruct(0) assert(row0.getInt(0) === 1) assert(row0.getLong(1) === 1L) - val row1 = columnVector.getStruct(1, 2) + val row1 = columnVector.getStruct(1) assert(row1.getInt(0) === 2) assert(row1.isNullAt(1)) - val row2 = columnVector.getStruct(2, 2) + val row2 = columnVector.getStruct(2) assert(row2.isNullAt(0)) assert(row2.getLong(1) === 3L) assert(columnVector.isNullAt(3)) - val row4 = columnVector.getStruct(4, 2) + val row4 = columnVector.getStruct(4) assert(row4.getInt(0) === 5) assert(row4.getLong(1) === 5L) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala index 944240f3bade5..2d1ad4b456783 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala @@ -199,17 +199,17 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { val structType: StructType = new StructType().add("int", IntegerType).add("double", DoubleType) testVectors("struct", 10, structType) { testVector => - val c1 = testVector.getChildColumn(0) - val c2 = testVector.getChildColumn(1) + val c1 = testVector.getChild(0) + val c2 = testVector.getChild(1) c1.putInt(0, 123) c2.putDouble(0, 3.45) c1.putInt(1, 456) c2.putDouble(1, 5.67) - assert(testVector.getStruct(0, structType.length).get(0, IntegerType) === 123) - assert(testVector.getStruct(0, structType.length).get(1, DoubleType) === 3.45) - assert(testVector.getStruct(1, structType.length).get(0, IntegerType) === 456) - assert(testVector.getStruct(1, structType.length).get(1, DoubleType) === 5.67) + assert(testVector.getStruct(0).get(0, IntegerType) === 123) + assert(testVector.getStruct(0).get(1, DoubleType) === 
3.45) + assert(testVector.getStruct(1).get(0, IntegerType) === 456) + assert(testVector.getStruct(1).get(1, DoubleType) === 5.67) } test("[SPARK-22092] off-heap column vector reallocation corrupts array data") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala index 38ea2e47fdef8..ad74fb99b0c73 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala @@ -268,17 +268,17 @@ object ColumnarBatchBenchmark { Int Read/Write: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ - Java Array 177 / 181 1856.4 0.5 1.0X - ByteBuffer Unsafe 318 / 322 1032.0 1.0 0.6X - ByteBuffer API 1411 / 1418 232.2 4.3 0.1X - DirectByteBuffer 467 / 474 701.8 1.4 0.4X - Unsafe Buffer 178 / 185 1843.6 0.5 1.0X - Column(on heap) 178 / 184 1840.8 0.5 1.0X - Column(off heap) 341 / 344 961.8 1.0 0.5X - Column(off heap direct) 178 / 184 1845.4 0.5 1.0X - UnsafeRow (on heap) 378 / 389 866.3 1.2 0.5X - UnsafeRow (off heap) 393 / 402 834.0 1.2 0.4X - Column On Heap Append 309 / 318 1059.1 0.9 0.6X + Java Array 177 / 183 1851.1 0.5 1.0X + ByteBuffer Unsafe 314 / 330 1043.7 1.0 0.6X + ByteBuffer API 1298 / 1307 252.4 4.0 0.1X + DirectByteBuffer 465 / 483 704.2 1.4 0.4X + Unsafe Buffer 179 / 183 1835.5 0.5 1.0X + Column(on heap) 181 / 186 1815.2 0.6 1.0X + Column(off heap) 344 / 349 951.7 1.1 0.5X + Column(off heap direct) 178 / 186 1838.6 0.5 1.0X + UnsafeRow (on heap) 388 / 394 844.8 1.2 0.5X + UnsafeRow (off heap) 400 / 403 819.4 1.2 0.4X + Column On Heap Append 315 / 325 1041.8 1.0 0.6X */ val benchmark = new Benchmark("Int Read/Write", count * iters) benchmark.addCase("Java Array")(javaArray) @@ -337,8 +337,8 @@ object ColumnarBatchBenchmark { Boolean Read/Write: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ - Bitset 726 / 727 462.4 2.2 1.0X - Byte Array 530 / 542 632.7 1.6 1.4X + Bitset 741 / 747 452.6 2.2 1.0X + Byte Array 531 / 542 631.6 1.6 1.4X */ benchmark.run() } @@ -394,8 +394,8 @@ object ColumnarBatchBenchmark { String Read/Write: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ - On Heap 332 / 338 49.3 20.3 1.0X - Off Heap 466 / 467 35.2 28.4 0.7X + On Heap 351 / 362 46.6 21.4 1.0X + Off Heap 456 / 466 35.9 27.8 0.8X */ val benchmark = new Benchmark("String Read/Write", count * iters) benchmark.addCase("On Heap")(column(MemoryMode.ON_HEAP)) @@ -479,10 +479,10 @@ object ColumnarBatchBenchmark { Array Vector Read: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ - On Heap Read Size Only 415 / 422 394.7 2.5 1.0X - Off Heap Read Size Only 394 / 402 415.9 2.4 1.1X - On Heap Read Elements 2558 / 2593 64.0 15.6 0.2X - Off Heap Read Elements 3316 / 3317 49.4 20.2 0.1X + On Heap Read Size Only 416 / 423 393.5 2.5 1.0X + Off Heap Read Size Only 396 / 404 413.6 2.4 1.1X + On Heap Read Elements 2569 / 2590 63.8 15.7 0.2X + Off Heap Read Elements 3302 / 3333 49.6 20.2 0.1X */ benchmark.run } diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index 675f06b31b970..9ec19a8a5dbe1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -732,8 +732,8 @@ class ColumnarBatchSuite extends SparkFunSuite { "Struct Column", 10, new StructType().add("int", IntegerType).add("double", DoubleType)) { column => - val c1 = column.getChildColumn(0) - val c2 = column.getChildColumn(1) + val c1 = column.getChild(0) + val c2 = column.getChild(1) assert(c1.dataType() == IntegerType) assert(c2.dataType() == DoubleType) @@ -787,8 +787,8 @@ class ColumnarBatchSuite extends SparkFunSuite { 10, new ArrayType(structType, true)) { column => val data = column.arrayData() - val c0 = data.getChildColumn(0) - val c1 = data.getChildColumn(1) + val c0 = data.getChild(0) + val c1 = data.getChild(1) // Structs in child column: (0, 0), (1, 10), (2, 20), (3, 30), (4, 40), (5, 50) (0 until 6).foreach { i => c0.putInt(i, i) @@ -815,8 +815,8 @@ class ColumnarBatchSuite extends SparkFunSuite { new StructType() .add("int", IntegerType) .add("array", new ArrayType(IntegerType, true))) { column => - val c0 = column.getChildColumn(0) - val c1 = column.getChildColumn(1) + val c0 = column.getChild(0) + val c1 = column.getChild(1) c0.putInt(0, 0) c0.putInt(1, 1) c0.putInt(2, 2) @@ -844,13 +844,13 @@ class ColumnarBatchSuite extends SparkFunSuite { "Nest Struct in Struct", 10, new StructType().add("int", IntegerType).add("struct", subSchema)) { column => - val c0 = column.getChildColumn(0) - val c1 = column.getChildColumn(1) + val c0 = column.getChild(0) + val c1 = column.getChild(1) c0.putInt(0, 0) c0.putInt(1, 1) c0.putInt(2, 2) - val c1c0 = c1.getChildColumn(0) - val c1c1 = c1.getChildColumn(1) + val c1c0 = c1.getChild(0) + val c1c1 = c1.getChild(1) // Structs in c1: (7, 70), (8, 80), (9, 90) c1c0.putInt(0, 7) c1c0.putInt(1, 8)
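Editor's note (illustration only, not part of the patch): the diff renames `getChildColumn(ordinal)` to `getChild(ordinal)`, drops the two-argument `getStruct(rowId, size)` codegen adapter in favor of `getStruct(rowId)`, and moves the batched getters (`getInts`, `getDoubles`, ...) into `ColumnVector` as default row-at-a-time loops, so implementations such as `OrcColumnVector` and `ArrowColumnVector` only need to provide the single-row accessors. Below is a minimal Java sketch of the resulting API, assuming the on-heap writable implementation (`OnHeapColumnVector`) exercised by the test suites above; the class name and the values are hypothetical.

    import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector;
    import org.apache.spark.sql.types.DataTypes;
    import org.apache.spark.sql.types.StructType;
    import org.apache.spark.sql.vectorized.ColumnarRow;

    class GetChildSketch {
      public static void main(String[] args) {
        StructType schema = new StructType()
          .add("i", DataTypes.IntegerType)
          .add("d", DataTypes.DoubleType);
        // Writable vector backing a struct column; one child vector is created per field.
        OnHeapColumnVector vector = new OnHeapColumnVector(4, schema);

        // Write through the per-field child vectors (getChildColumn(...) before this patch).
        vector.getChild(0).putInt(0, 123);
        vector.getChild(1).putDouble(0, 3.45);

        // getStruct now takes only the rowId; the old two-argument codegen adapter is gone.
        ColumnarRow row = vector.getStruct(0);
        int i = row.getInt(0);        // reads vector.getChild(0).getInt(0)
        double d = row.getDouble(1);  // reads vector.getChild(1).getDouble(0)

        // Batched reads fall back to the default loop in ColumnVector unless overridden.
        int[] ints = vector.getChild(0).getInts(0, 1);

        vector.close();
      }
    }

Implementations remain free to override the batched getters when they can serve a contiguous range more efficiently; the defaults simply keep the abstract surface of `ColumnVector` small.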