Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-23090][SQL] polish ColumnVector #20277

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -688,17 +688,13 @@ class CodegenContext {
/**
* Returns the specialized code to access a value from a column vector for a given `DataType`.
*/
def getValue(vector: String, rowId: String, dataType: DataType): String = {
val jt = javaType(dataType)
dataType match {
case _ if isPrimitiveType(jt) =>
s"$vector.get${primitiveTypeName(jt)}($rowId)"
case t: DecimalType =>
s"$vector.getDecimal($rowId, ${t.precision}, ${t.scale})"
case StringType =>
s"$vector.getUTF8String($rowId)"
case _ =>
throw new IllegalArgumentException(s"cannot generate code for unsupported type: $dataType")
/**
 * Builds the Java source snippet that reads the value at `rowId` out of the column vector
 * named by `vector`.
 *
 * @param vector   name of the column vector variable in the generated code
 * @param dataType type of the value being read
 * @param rowId    name (or expression) of the row id in the generated code
 * @return the generated accessor expression as a string
 */
def getValueFromVector(vector: String, dataType: DataType, rowId: String): String = {
  dataType match {
    // Structs need their own accessor: unlike `InternalRow.getStruct`, the
    // `ColumnVector.getStruct` method takes only a single `ordinal` argument.
    case _: StructType =>
      s"$vector.getStruct($rowId)"
    // Everything else goes through the generic accessor generator.
    case _ =>
      getValue(vector, dataType, rowId)
  }
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,57 +110,21 @@ public boolean getBoolean(int rowId) {
return longData.vector[getRowIndex(rowId)] == 1;
}

/**
 * Reads {@code count} consecutive boolean values, starting at {@code rowId}.
 *
 * @param rowId row id of the first value to read
 * @param count number of values to read
 * @return a newly allocated array whose element {@code i} is {@code getBoolean(rowId + i)}
 */
@Override
public boolean[] getBooleans(int rowId, int count) {
  final boolean[] values = new boolean[count];
  // `slot` walks the output array while `row` walks the source rows in lock-step.
  for (int slot = 0, row = rowId; slot < count; slot++, row++) {
    values[slot] = getBoolean(row);
  }
  return values;
}

/**
 * Reads the entry of {@code longData.vector} at the index that {@code getRowIndex} yields for
 * {@code rowId}, narrowed to {@code byte}.
 *
 * @param rowId row id of the value to read
 * @return the stored value cast to {@code byte}
 */
@Override
public byte getByte(int rowId) {
  final int index = getRowIndex(rowId);
  return (byte) longData.vector[index];
}

/**
 * Reads {@code count} consecutive byte values, starting at {@code rowId}.
 *
 * @param rowId row id of the first value to read
 * @param count number of values to read
 * @return a newly allocated array whose element {@code i} is {@code getByte(rowId + i)}
 */
@Override
public byte[] getBytes(int rowId, int count) {
  final byte[] values = new byte[count];
  // `slot` walks the output array while `row` walks the source rows in lock-step.
  for (int slot = 0, row = rowId; slot < count; slot++, row++) {
    values[slot] = getByte(row);
  }
  return values;
}

/**
 * Reads the entry of {@code longData.vector} at the index that {@code getRowIndex} yields for
 * {@code rowId}, narrowed to {@code short}.
 *
 * @param rowId row id of the value to read
 * @return the stored value cast to {@code short}
 */
@Override
public short getShort(int rowId) {
  final int index = getRowIndex(rowId);
  return (short) longData.vector[index];
}

/**
 * Reads {@code count} consecutive short values, starting at {@code rowId}.
 *
 * @param rowId row id of the first value to read
 * @param count number of values to read
 * @return a newly allocated array whose element {@code i} is {@code getShort(rowId + i)}
 */
@Override
public short[] getShorts(int rowId, int count) {
  final short[] values = new short[count];
  // `slot` walks the output array while `row` walks the source rows in lock-step.
  for (int slot = 0, row = rowId; slot < count; slot++, row++) {
    values[slot] = getShort(row);
  }
  return values;
}

/**
 * Reads the entry of {@code longData.vector} at the index that {@code getRowIndex} yields for
 * {@code rowId}, narrowed to {@code int}.
 *
 * @param rowId row id of the value to read
 * @return the stored value cast to {@code int}
 */
@Override
public int getInt(int rowId) {
  final int index = getRowIndex(rowId);
  return (int) longData.vector[index];
}

/**
 * Reads {@code count} consecutive int values, starting at {@code rowId}.
 *
 * @param rowId row id of the first value to read
 * @param count number of values to read
 * @return a newly allocated array whose element {@code i} is {@code getInt(rowId + i)}
 */
@Override
public int[] getInts(int rowId, int count) {
  final int[] values = new int[count];
  // `slot` walks the output array while `row` walks the source rows in lock-step.
  for (int slot = 0, row = rowId; slot < count; slot++, row++) {
    values[slot] = getInt(row);
  }
  return values;
}

@Override
public long getLong(int rowId) {
int index = getRowIndex(rowId);
Expand All @@ -171,43 +135,16 @@ public long getLong(int rowId) {
}
}

/**
 * Reads {@code count} consecutive long values, starting at {@code rowId}.
 *
 * @param rowId row id of the first value to read
 * @param count number of values to read
 * @return a newly allocated array whose element {@code i} is {@code getLong(rowId + i)}
 */
@Override
public long[] getLongs(int rowId, int count) {
  final long[] values = new long[count];
  // `slot` walks the output array while `row` walks the source rows in lock-step.
  for (int slot = 0, row = rowId; slot < count; slot++, row++) {
    values[slot] = getLong(row);
  }
  return values;
}

/**
 * Reads the entry of {@code doubleData.vector} at the index that {@code getRowIndex} yields for
 * {@code rowId}, narrowed to {@code float}.
 *
 * @param rowId row id of the value to read
 * @return the stored value cast to {@code float}
 */
@Override
public float getFloat(int rowId) {
  final int index = getRowIndex(rowId);
  return (float) doubleData.vector[index];
}

/**
 * Reads {@code count} consecutive float values, starting at {@code rowId}.
 *
 * @param rowId row id of the first value to read
 * @param count number of values to read
 * @return a newly allocated array whose element {@code i} is {@code getFloat(rowId + i)}
 */
@Override
public float[] getFloats(int rowId, int count) {
  final float[] values = new float[count];
  // `slot` walks the output array while `row` walks the source rows in lock-step.
  for (int slot = 0, row = rowId; slot < count; slot++, row++) {
    values[slot] = getFloat(row);
  }
  return values;
}

/**
 * Reads the entry of {@code doubleData.vector} at the index that {@code getRowIndex} yields for
 * {@code rowId}.
 *
 * @param rowId row id of the value to read
 * @return the stored value, without any narrowing cast
 */
@Override
public double getDouble(int rowId) {
  final int index = getRowIndex(rowId);
  return doubleData.vector[index];
}

/**
 * Reads {@code count} consecutive double values, starting at {@code rowId}.
 *
 * @param rowId row id of the first value to read
 * @param count number of values to read
 * @return a newly allocated array whose element {@code i} is {@code getDouble(rowId + i)}
 */
@Override
public double[] getDoubles(int rowId, int count) {
  final double[] values = new double[count];
  // `slot` walks the output array while `row` walks the source rows in lock-step.
  for (int slot = 0, row = rowId; slot < count; slot++, row++) {
    values[slot] = getDouble(row);
  }
  return values;
}

@Override
public int getArrayLength(int rowId) {
throw new UnsupportedOperationException();
Expand Down Expand Up @@ -245,7 +182,7 @@ public org.apache.spark.sql.vectorized.ColumnVector arrayData() {
}

@Override
public org.apache.spark.sql.vectorized.ColumnVector getChildColumn(int ordinal) {
public org.apache.spark.sql.vectorized.ColumnVector getChild(int ordinal) {
throw new UnsupportedOperationException();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -302,10 +302,9 @@ private void putRepeatingValues(
toColumn.putDoubles(0, batchSize, ((DoubleColumnVector)fromColumn).vector[0]);
} else if (type instanceof StringType || type instanceof BinaryType) {
BytesColumnVector data = (BytesColumnVector)fromColumn;
WritableColumnVector arrayData = toColumn.getChildColumn(0);
int size = data.vector[0].length;
arrayData.reserve(size);
arrayData.putBytes(0, size, data.vector[0], 0);
toColumn.arrayData().reserve(size);
toColumn.arrayData().putBytes(0, size, data.vector[0], 0);
for (int index = 0; index < batchSize; index++) {
toColumn.putArray(index, 0, size);
}
Expand Down Expand Up @@ -365,7 +364,7 @@ private void putNonNullValues(
toColumn.putDoubles(0, batchSize, ((DoubleColumnVector)fromColumn).vector, 0);
} else if (type instanceof StringType || type instanceof BinaryType) {
BytesColumnVector data = ((BytesColumnVector)fromColumn);
WritableColumnVector arrayData = toColumn.getChildColumn(0);
WritableColumnVector arrayData = toColumn.arrayData();
int totalNumBytes = IntStream.of(data.length).sum();
arrayData.reserve(totalNumBytes);
for (int index = 0, pos = 0; index < batchSize; pos += data.length[index], index++) {
Expand All @@ -376,8 +375,7 @@ private void putNonNullValues(
DecimalType decimalType = (DecimalType)type;
DecimalColumnVector data = ((DecimalColumnVector)fromColumn);
if (decimalType.precision() > Decimal.MAX_LONG_DIGITS()) {
WritableColumnVector arrayData = toColumn.getChildColumn(0);
arrayData.reserve(batchSize * 16);
toColumn.arrayData().reserve(batchSize * 16);
}
for (int index = 0; index < batchSize; index++) {
putDecimalWritable(
Expand Down Expand Up @@ -472,7 +470,7 @@ private void putValues(
}
} else if (type instanceof StringType || type instanceof BinaryType) {
BytesColumnVector vector = (BytesColumnVector)fromColumn;
WritableColumnVector arrayData = toColumn.getChildColumn(0);
WritableColumnVector arrayData = toColumn.arrayData();
int totalNumBytes = IntStream.of(vector.length).sum();
arrayData.reserve(totalNumBytes);
for (int index = 0, pos = 0; index < batchSize; pos += vector.length[index], index++) {
Expand All @@ -487,8 +485,7 @@ private void putValues(
DecimalType decimalType = (DecimalType)type;
HiveDecimalWritable[] vector = ((DecimalColumnVector)fromColumn).vector;
if (decimalType.precision() > Decimal.MAX_LONG_DIGITS()) {
WritableColumnVector arrayData = toColumn.getChildColumn(0);
arrayData.reserve(batchSize * 16);
toColumn.arrayData().reserve(batchSize * 16);
}
for (int index = 0; index < batchSize; index++) {
if (fromColumn.isNull[index]) {
Expand Down Expand Up @@ -534,8 +531,7 @@ private static void putDecimalWritable(
toColumn.putLong(index, value.toUnscaledLong());
} else {
byte[] bytes = value.toJavaBigDecimal().unscaledValue().toByteArray();
WritableColumnVector arrayData = toColumn.getChildColumn(0);
arrayData.putBytes(index * 16, bytes.length, bytes, 0);
toColumn.arrayData().putBytes(index * 16, bytes.length, bytes, 0);
toColumn.putArray(index, index * 16, bytes.length);
}
}
Expand All @@ -560,9 +556,8 @@ private static void putDecimalWritables(
toColumn.putLongs(0, size, value.toUnscaledLong());
} else {
byte[] bytes = value.toJavaBigDecimal().unscaledValue().toByteArray();
WritableColumnVector arrayData = toColumn.getChildColumn(0);
arrayData.reserve(bytes.length);
arrayData.putBytes(0, bytes.length, bytes, 0);
toColumn.arrayData().reserve(bytes.length);
toColumn.arrayData().putBytes(0, bytes.length, bytes, 0);
for (int index = 0; index < size; index++) {
toColumn.putArray(index, 0, bytes.length);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,8 @@ public static void populate(WritableColumnVector col, InternalRow row, int field
}
} else if (t instanceof CalendarIntervalType) {
CalendarInterval c = (CalendarInterval)row.get(fieldIdx, t);
col.getChildColumn(0).putInts(0, capacity, c.months);
col.getChildColumn(1).putLongs(0, capacity, c.microseconds);
col.getChild(0).putInts(0, capacity, c.months);
col.getChild(1).putLongs(0, capacity, c.microseconds);
} else if (t instanceof DateType) {
col.putInts(0, capacity, row.getInt(fieldIdx));
} else if (t instanceof TimestampType) {
Expand Down Expand Up @@ -149,8 +149,8 @@ private static void appendValue(WritableColumnVector dst, DataType t, Object o)
} else if (t instanceof CalendarIntervalType) {
CalendarInterval c = (CalendarInterval)o;
dst.appendStruct(false);
dst.getChildColumn(0).appendInt(c.months);
dst.getChildColumn(1).appendLong(c.microseconds);
dst.getChild(0).appendInt(c.months);
dst.getChild(1).appendLong(c.microseconds);
} else if (t instanceof DateType) {
dst.appendInt(DateTimeUtils.fromJavaDate((Date)o));
} else {
Expand Down Expand Up @@ -179,7 +179,7 @@ private static void appendValue(WritableColumnVector dst, DataType t, Row src, i
dst.appendStruct(false);
Row c = src.getStruct(fieldIdx);
for (int i = 0; i < st.fields().length; i++) {
appendValue(dst.getChildColumn(i), st.fields()[i].dataType(), c, i);
appendValue(dst.getChild(i), st.fields()[i].dataType(), c, i);
}
}
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,8 @@ public byte[] getBinary(int ordinal) {
@Override
public CalendarInterval getInterval(int ordinal) {
if (columns[ordinal].isNullAt(rowId)) return null;
final int months = columns[ordinal].getChildColumn(0).getInt(rowId);
final long microseconds = columns[ordinal].getChildColumn(1).getLong(rowId);
final int months = columns[ordinal].getChild(0).getInt(rowId);
final long microseconds = columns[ordinal].getChild(1).getLong(rowId);
return new CalendarInterval(months, microseconds);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -599,17 +599,13 @@ public final int appendStruct(boolean isNull) {
return elementsAppended;
}

/**
* Returns the data for the underlying array.
*/
// `WritableColumnVector` stores the array's element data in its first child column vector, and
// stores the array offsets and lengths in the current column vector.
@Override
public WritableColumnVector arrayData() {
  // The array's element data is the first entry of `childColumns`.
  return childColumns[0];
}

/**
* Returns the ordinal's child data column.
*/
@Override
public WritableColumnVector getChildColumn(int ordinal) { return childColumns[ordinal]; }
public WritableColumnVector getChild(int ordinal) { return childColumns[ordinal]; }

/**
* Returns the elements appended.
Expand Down
Loading