Skip to content

Commit

Permalink
remove ColumnVector.getStruct(int, int)
Browse files Browse the repository at this point in the history
  • Loading branch information
cloud-fan committed Jan 19, 2018
1 parent f3f9d5e commit eccdca1
Show file tree
Hide file tree
Showing 6 changed files with 32 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -198,14 +198,6 @@ public final ColumnarRow getStruct(int rowId) {
return new ColumnarRow(this, rowId);
}

/**
* A special version of {@link #getStruct(int)} that is used only as an adapter for the Spark
* codegen framework. The second parameter is ignored entirely.
*
* @param rowId the row to read; forwarded unchanged to {@link #getStruct(int)}
* @param size ignored; it is never read by this method
* @return the result of {@code getStruct(rowId)}
*/
public final ColumnarRow getStruct(int rowId, int size) {
return getStruct(rowId);
}

/**
* Returns the array for rowId.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution
import org.apache.spark.sql.catalyst.expressions.{BoundReference, UnsafeRow}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}


Expand Down Expand Up @@ -50,7 +50,14 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport {
dataType: DataType,
nullable: Boolean): ExprCode = {
val javaType = ctx.javaType(dataType)
val value = ctx.getValue(columnVar, dataType, ordinal)
val value = if (dataType.isInstanceOf[StructType]) {
// `ColumnVector.getStruct` differs from `InternalRow.getStruct`: it takes only an
// `ordinal` parameter.
s"$columnVar.getStruct($ordinal)"
} else {
ctx.getValue(columnVar, dataType, ordinal)
}

val isNullVar = if (nullable) { ctx.freshName("isNull") } else { "false" }
val valueVar = ctx.freshName("value")
val str = s"columnVector[$columnVar, $ordinal, ${dataType.simpleString}]"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,14 @@ class VectorizedHashMapGenerator(

/**
 * Generates the source text of a boolean expression that compares every grouping key against
 * the value read from its column vector at position `buckets[idx]`.
 *
 * One parenthesised comparison is produced per key via `ctx.genEqual`, and the comparisons are
 * joined with `" && "`. The vector for the key at position `ordinal` is referred to as
 * `vectors[ordinal]` in the generated text.
 *
 * @param groupingKeys the keys to compare; each contributes its `dataType` and `name`
 * @return the per-key comparisons joined by `" && "` (an empty string for an empty sequence)
 */
def genEqualsForKeys(groupingKeys: Seq[Buffer]): String = {
groupingKeys.zipWithIndex.map { case (key: Buffer, ordinal: Int) =>
// NOTE(review): the value of the next expression is discarded because further statements
// follow it in this block; it looks like the pre-change line kept by the diff view — confirm.
s"""(${ctx.genEqual(key.dataType, ctx.getValue(s"vectors[$ordinal]", "buckets[idx]",
key.dataType), key.name)})"""
// `ColumnVector.getStruct` differs from `InternalRow.getStruct`: it takes only an
// `ordinal` parameter, so for struct keys the call is written out directly instead of
// going through `ctx.getValue`.
val value = if (key.dataType.isInstanceOf[StructType]) {
s"vectors[$ordinal].getStruct(buckets[idx])"
} else {
ctx.getValue(s"vectors[$ordinal]", "buckets[idx]", key.dataType)
}
// Wrap each per-key comparison in parentheses so the joined `&&` chain stays unambiguous.
s"(${ctx.genEqual(key.dataType, value, key.name)})"
}.mkString(" && ")
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -217,21 +217,21 @@ class ArrowWriterSuite extends SparkFunSuite {

val reader = new ArrowColumnVector(writer.root.getFieldVectors().get(0))

val struct0 = reader.getStruct(0, 2)
val struct0 = reader.getStruct(0)
assert(struct0.getInt(0) === 1)
assert(struct0.getUTF8String(1) === UTF8String.fromString("str1"))

val struct1 = reader.getStruct(1, 2)
val struct1 = reader.getStruct(1)
assert(struct1.isNullAt(0))
assert(struct1.isNullAt(1))

assert(reader.isNullAt(2))

val struct3 = reader.getStruct(3, 2)
val struct3 = reader.getStruct(3)
assert(struct3.getInt(0) === 4)
assert(struct3.isNullAt(1))

val struct4 = reader.getStruct(4, 2)
val struct4 = reader.getStruct(4)
assert(struct4.isNullAt(0))
assert(struct4.getUTF8String(1) === UTF8String.fromString("str5"))

Expand All @@ -252,15 +252,15 @@ class ArrowWriterSuite extends SparkFunSuite {

val reader = new ArrowColumnVector(writer.root.getFieldVectors().get(0))

val struct00 = reader.getStruct(0, 1).getStruct(0, 2)
val struct00 = reader.getStruct(0).getStruct(0, 2)
assert(struct00.getInt(0) === 1)
assert(struct00.getUTF8String(1) === UTF8String.fromString("str1"))

val struct10 = reader.getStruct(1, 1).getStruct(0, 2)
val struct10 = reader.getStruct(1).getStruct(0, 2)
assert(struct10.isNullAt(0))
assert(struct10.isNullAt(1))

val struct2 = reader.getStruct(2, 1)
val struct2 = reader.getStruct(2)
assert(struct2.isNullAt(0))

assert(reader.isNullAt(3))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -362,21 +362,21 @@ class ArrowColumnVectorSuite extends SparkFunSuite {
assert(columnVector.dataType === schema)
assert(columnVector.numNulls === 1)

val row0 = columnVector.getStruct(0, 2)
val row0 = columnVector.getStruct(0)
assert(row0.getInt(0) === 1)
assert(row0.getLong(1) === 1L)

val row1 = columnVector.getStruct(1, 2)
val row1 = columnVector.getStruct(1)
assert(row1.getInt(0) === 2)
assert(row1.isNullAt(1))

val row2 = columnVector.getStruct(2, 2)
val row2 = columnVector.getStruct(2)
assert(row2.isNullAt(0))
assert(row2.getLong(1) === 3L)

assert(columnVector.isNullAt(3))

val row4 = columnVector.getStruct(4, 2)
val row4 = columnVector.getStruct(4)
assert(row4.getInt(0) === 5)
assert(row4.getLong(1) === 5L)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -206,10 +206,10 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach {
c1.putInt(1, 456)
c2.putDouble(1, 5.67)

assert(testVector.getStruct(0, structType.length).get(0, IntegerType) === 123)
assert(testVector.getStruct(0, structType.length).get(1, DoubleType) === 3.45)
assert(testVector.getStruct(1, structType.length).get(0, IntegerType) === 456)
assert(testVector.getStruct(1, structType.length).get(1, DoubleType) === 5.67)
assert(testVector.getStruct(0).get(0, IntegerType) === 123)
assert(testVector.getStruct(0).get(1, DoubleType) === 3.45)
assert(testVector.getStruct(1).get(0, IntegerType) === 456)
assert(testVector.getStruct(1).get(1, DoubleType) === 5.67)
}

test("[SPARK-22092] off-heap column vector reallocation corrupts array data") {
Expand Down

0 comments on commit eccdca1

Please sign in to comment.