[SPARK-23090][SQL] polish ColumnVector

## What changes were proposed in this pull request?

Several improvements:
* provide a default implementation for the batch get methods (see the sketch after this list)
* rename `getChildColumn` to `getChild`, which is more concise
* remove `getStruct(int, int)`: it exists only to simplify codegen, which is an internal concern, so we should not expose a public API for that purpose.
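For illustration, a minimal sketch of the first point, assuming a simplified base class (`ExampleColumnVector` is a hypothetical stand-in for the real `ColumnVector`; the actual change lives in `ColumnVector.java` below):

```java
// Hedged sketch, not the committed code: a batch getter implemented once in
// the base class by delegating to the single-value getter, so subclasses such
// as OrcColumnVector can drop their hand-written copy loops.
public abstract class ExampleColumnVector {
  public abstract int getInt(int rowId);

  // Default batch getter; a subclass may still override this when a bulk
  // memory copy is faster than a per-row loop.
  public int[] getInts(int rowId, int count) {
    int[] res = new int[count];
    for (int i = 0; i < count; i++) {
      res[i] = getInt(rowId + i);
    }
    return res;
  }
}
```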

## How was this patch tested?

Existing tests.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #20277 from cloud-fan/column-vector.

(cherry picked from commit 5d680ca)
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
cloud-fan committed Jan 22, 2018
1 parent 1069fad commit d963ba031748711ec7847ad0b702911eb7319c63
Showing with 164 additions and 296 deletions.
  1. +7 −11 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
  2. +1 −64 sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java
  3. +9 −14 sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java
  4. +5 −5 sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java
  5. +2 −2 sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java
  6. +3 −7 sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java
  7. +3 −96 sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java
  8. +58 −21 sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java
  9. +2 −2 sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java
  10. +23 −23 sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java
  11. +1 −1 sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala
  12. +2 −2 sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala
  13. +7 −7 sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala
  14. +6 −6 sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala
  15. +6 −6 sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala
  16. +19 −19 sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala
  17. +10 −10 sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala
@@ -688,17 +688,13 @@ class CodegenContext {
   /**
    * Returns the specialized code to access a value from a column vector for a given `DataType`.
    */
-  def getValue(vector: String, rowId: String, dataType: DataType): String = {
-    val jt = javaType(dataType)
-    dataType match {
-      case _ if isPrimitiveType(jt) =>
-        s"$vector.get${primitiveTypeName(jt)}($rowId)"
-      case t: DecimalType =>
-        s"$vector.getDecimal($rowId, ${t.precision}, ${t.scale})"
-      case StringType =>
-        s"$vector.getUTF8String($rowId)"
-      case _ =>
-        throw new IllegalArgumentException(s"cannot generate code for unsupported type: $dataType")
+  def getValueFromVector(vector: String, dataType: DataType, rowId: String): String = {
+    if (dataType.isInstanceOf[StructType]) {
+      // `ColumnVector.getStruct` is different from `InternalRow.getStruct`, it only takes an
+      // `ordinal` parameter.
+      s"$vector.getStruct($rowId)"
+    } else {
+      getValue(vector, dataType, rowId)
     }
   }
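To make the struct special case concrete, here is roughly the shape of the Java the two branches generate, with hypothetical variable names (`colInstance0`, `colInstance1`, `rowId`); treat it as an illustration, not verbatim codegen output:

```java
// Primitive column (e.g. IntegerType): emitted by the else branch.
int value0 = colInstance0.getInt(rowId);
// Struct column: ColumnVector.getStruct takes only a row id and returns a
// ColumnarRow, unlike InternalRow.getStruct(ordinal, numFields).
org.apache.spark.sql.vectorized.ColumnarRow value1 = colInstance1.getStruct(rowId);
```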

@@ -110,57 +110,21 @@ public boolean getBoolean(int rowId) {
     return longData.vector[getRowIndex(rowId)] == 1;
   }
 
-  @Override
-  public boolean[] getBooleans(int rowId, int count) {
-    boolean[] res = new boolean[count];
-    for (int i = 0; i < count; i++) {
-      res[i] = getBoolean(rowId + i);
-    }
-    return res;
-  }
-
   @Override
   public byte getByte(int rowId) {
     return (byte) longData.vector[getRowIndex(rowId)];
   }
 
-  @Override
-  public byte[] getBytes(int rowId, int count) {
-    byte[] res = new byte[count];
-    for (int i = 0; i < count; i++) {
-      res[i] = getByte(rowId + i);
-    }
-    return res;
-  }
-
   @Override
   public short getShort(int rowId) {
     return (short) longData.vector[getRowIndex(rowId)];
   }
 
-  @Override
-  public short[] getShorts(int rowId, int count) {
-    short[] res = new short[count];
-    for (int i = 0; i < count; i++) {
-      res[i] = getShort(rowId + i);
-    }
-    return res;
-  }
-
   @Override
   public int getInt(int rowId) {
     return (int) longData.vector[getRowIndex(rowId)];
   }
 
-  @Override
-  public int[] getInts(int rowId, int count) {
-    int[] res = new int[count];
-    for (int i = 0; i < count; i++) {
-      res[i] = getInt(rowId + i);
-    }
-    return res;
-  }
-
   @Override
   public long getLong(int rowId) {
     int index = getRowIndex(rowId);
@@ -171,43 +135,16 @@ public long getLong(int rowId) {
     }
   }
 
-  @Override
-  public long[] getLongs(int rowId, int count) {
-    long[] res = new long[count];
-    for (int i = 0; i < count; i++) {
-      res[i] = getLong(rowId + i);
-    }
-    return res;
-  }
-
   @Override
   public float getFloat(int rowId) {
     return (float) doubleData.vector[getRowIndex(rowId)];
   }
 
-  @Override
-  public float[] getFloats(int rowId, int count) {
-    float[] res = new float[count];
-    for (int i = 0; i < count; i++) {
-      res[i] = getFloat(rowId + i);
-    }
-    return res;
-  }
-
   @Override
   public double getDouble(int rowId) {
     return doubleData.vector[getRowIndex(rowId)];
   }
 
-  @Override
-  public double[] getDoubles(int rowId, int count) {
-    double[] res = new double[count];
-    for (int i = 0; i < count; i++) {
-      res[i] = getDouble(rowId + i);
-    }
-    return res;
-  }
-
   @Override
   public int getArrayLength(int rowId) {
     throw new UnsupportedOperationException();
@@ -245,7 +182,7 @@ public UTF8String getUTF8String(int rowId) {
   }
 
   @Override
-  public org.apache.spark.sql.vectorized.ColumnVector getChildColumn(int ordinal) {
+  public org.apache.spark.sql.vectorized.ColumnVector getChild(int ordinal) {
     throw new UnsupportedOperationException();
   }
 }
@@ -289,10 +289,9 @@ private void putRepeatingValues(
       toColumn.putDoubles(0, batchSize, ((DoubleColumnVector)fromColumn).vector[0]);
     } else if (type instanceof StringType || type instanceof BinaryType) {
       BytesColumnVector data = (BytesColumnVector)fromColumn;
-      WritableColumnVector arrayData = toColumn.getChildColumn(0);
       int size = data.vector[0].length;
-      arrayData.reserve(size);
-      arrayData.putBytes(0, size, data.vector[0], 0);
+      toColumn.arrayData().reserve(size);
+      toColumn.arrayData().putBytes(0, size, data.vector[0], 0);
       for (int index = 0; index < batchSize; index++) {
         toColumn.putArray(index, 0, size);
       }
@@ -352,7 +351,7 @@ private void putNonNullValues(
       toColumn.putDoubles(0, batchSize, ((DoubleColumnVector)fromColumn).vector, 0);
     } else if (type instanceof StringType || type instanceof BinaryType) {
       BytesColumnVector data = ((BytesColumnVector)fromColumn);
-      WritableColumnVector arrayData = toColumn.getChildColumn(0);
+      WritableColumnVector arrayData = toColumn.arrayData();
       int totalNumBytes = IntStream.of(data.length).sum();
       arrayData.reserve(totalNumBytes);
       for (int index = 0, pos = 0; index < batchSize; pos += data.length[index], index++) {
@@ -363,8 +362,7 @@ private void putNonNullValues(
       DecimalType decimalType = (DecimalType)type;
       DecimalColumnVector data = ((DecimalColumnVector)fromColumn);
       if (decimalType.precision() > Decimal.MAX_LONG_DIGITS()) {
-        WritableColumnVector arrayData = toColumn.getChildColumn(0);
-        arrayData.reserve(batchSize * 16);
+        toColumn.arrayData().reserve(batchSize * 16);
       }
       for (int index = 0; index < batchSize; index++) {
         putDecimalWritable(
@@ -459,7 +457,7 @@ private void putValues(
       }
     } else if (type instanceof StringType || type instanceof BinaryType) {
       BytesColumnVector vector = (BytesColumnVector)fromColumn;
-      WritableColumnVector arrayData = toColumn.getChildColumn(0);
+      WritableColumnVector arrayData = toColumn.arrayData();
       int totalNumBytes = IntStream.of(vector.length).sum();
       arrayData.reserve(totalNumBytes);
       for (int index = 0, pos = 0; index < batchSize; pos += vector.length[index], index++) {
@@ -474,8 +472,7 @@ private void putValues(
       DecimalType decimalType = (DecimalType)type;
       HiveDecimalWritable[] vector = ((DecimalColumnVector)fromColumn).vector;
       if (decimalType.precision() > Decimal.MAX_LONG_DIGITS()) {
-        WritableColumnVector arrayData = toColumn.getChildColumn(0);
-        arrayData.reserve(batchSize * 16);
+        toColumn.arrayData().reserve(batchSize * 16);
       }
       for (int index = 0; index < batchSize; index++) {
         if (fromColumn.isNull[index]) {
@@ -521,8 +518,7 @@ private static void putDecimalWritable(
       toColumn.putLong(index, value.toUnscaledLong());
     } else {
       byte[] bytes = value.toJavaBigDecimal().unscaledValue().toByteArray();
-      WritableColumnVector arrayData = toColumn.getChildColumn(0);
-      arrayData.putBytes(index * 16, bytes.length, bytes, 0);
+      toColumn.arrayData().putBytes(index * 16, bytes.length, bytes, 0);
       toColumn.putArray(index, index * 16, bytes.length);
     }
   }
@@ -547,9 +543,8 @@ private static void putDecimalWritables(
       toColumn.putLongs(0, size, value.toUnscaledLong());
     } else {
       byte[] bytes = value.toJavaBigDecimal().unscaledValue().toByteArray();
-      WritableColumnVector arrayData = toColumn.getChildColumn(0);
-      arrayData.reserve(bytes.length);
-      arrayData.putBytes(0, bytes.length, bytes, 0);
+      toColumn.arrayData().reserve(bytes.length);
+      toColumn.arrayData().putBytes(0, bytes.length, bytes, 0);
       for (int index = 0; index < size; index++) {
         toColumn.putArray(index, 0, bytes.length);
       }
@@ -85,8 +85,8 @@ public static void populate(WritableColumnVector col, InternalRow row, int field
       }
     } else if (t instanceof CalendarIntervalType) {
       CalendarInterval c = (CalendarInterval)row.get(fieldIdx, t);
-      col.getChildColumn(0).putInts(0, capacity, c.months);
-      col.getChildColumn(1).putLongs(0, capacity, c.microseconds);
+      col.getChild(0).putInts(0, capacity, c.months);
+      col.getChild(1).putLongs(0, capacity, c.microseconds);
     } else if (t instanceof DateType) {
       col.putInts(0, capacity, row.getInt(fieldIdx));
     } else if (t instanceof TimestampType) {
@@ -149,8 +149,8 @@ private static void appendValue(WritableColumnVector dst, DataType t, Object o)
     } else if (t instanceof CalendarIntervalType) {
       CalendarInterval c = (CalendarInterval)o;
       dst.appendStruct(false);
-      dst.getChildColumn(0).appendInt(c.months);
-      dst.getChildColumn(1).appendLong(c.microseconds);
+      dst.getChild(0).appendInt(c.months);
+      dst.getChild(1).appendLong(c.microseconds);
     } else if (t instanceof DateType) {
       dst.appendInt(DateTimeUtils.fromJavaDate((Date)o));
     } else {
@@ -179,7 +179,7 @@ private static void appendValue(WritableColumnVector dst, DataType t, Row src, i
       dst.appendStruct(false);
       Row c = src.getStruct(fieldIdx);
       for (int i = 0; i < st.fields().length; i++) {
-        appendValue(dst.getChildColumn(i), st.fields()[i].dataType(), c, i);
+        appendValue(dst.getChild(i), st.fields()[i].dataType(), c, i);
       }
     }
   } else {
@@ -146,8 +146,8 @@ public UTF8String getUTF8String(int ordinal) {
   @Override
   public CalendarInterval getInterval(int ordinal) {
     if (columns[ordinal].isNullAt(rowId)) return null;
-    final int months = columns[ordinal].getChildColumn(0).getInt(rowId);
-    final long microseconds = columns[ordinal].getChildColumn(1).getLong(rowId);
+    final int months = columns[ordinal].getChild(0).getInt(rowId);
+    final long microseconds = columns[ordinal].getChild(1).getLong(rowId);
     return new CalendarInterval(months, microseconds);
   }
@@ -599,17 +599,13 @@ public final int appendStruct(boolean isNull) {
     return elementsAppended;
   }
 
-  /**
-   * Returns the data for the underlying array.
-   */
+  // `WritableColumnVector` puts the data of array in the first child column vector, and puts the
+  // array offsets and lengths in the current column vector.
   @Override
   public WritableColumnVector arrayData() { return childColumns[0]; }
 
-  /**
-   * Returns the ordinal's child data column.
-   */
   @Override
-  public WritableColumnVector getChildColumn(int ordinal) { return childColumns[ordinal]; }
+  public WritableColumnVector getChild(int ordinal) { return childColumns[ordinal]; }
 
   /**
    * Returns the elements appended.
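A hedged usage sketch of the layout the new comment describes, assuming `col` is a `WritableColumnVector` holding a binary column (only calls that already appear in this diff are used):

```java
// Element bytes live in the first child vector, reached via arrayData();
// the parent vector stores each row's (offset, length) pair.
byte[] bytes = "hello".getBytes(java.nio.charset.StandardCharsets.UTF_8);
col.arrayData().reserve(bytes.length);
col.arrayData().putBytes(0, bytes.length, bytes, 0); // data into child 0
col.putArray(0, 0, bytes.length);                    // row 0: offset 0, length 5
```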
