Skip to content

Commit

Permalink
More generalize accessor. Make sure testing on all types.
Browse files Browse the repository at this point in the history
  • Loading branch information
viirya committed Apr 10, 2018
1 parent 912c2c2 commit 54dd939
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -121,24 +121,26 @@ object InternalRow {
}

/**
* Returns an accessor for an InternalRow with given data type and ordinal.
* Returns an accessor for an `InternalRow` with given data type. The returned accessor
* actually takes a `SpecializedGetters` input because it can be generalized to other classes
* that implement `SpecializedGetters` (e.g., `ArrayData`) too.
*/
def getAccessor(dataType: DataType, ordinal: Int): (InternalRow) => Any = dataType match {
case BooleanType => (input) => input.getBoolean(ordinal)
case ByteType => (input) => input.getByte(ordinal)
case ShortType => (input) => input.getShort(ordinal)
case IntegerType | DateType => (input) => input.getInt(ordinal)
case LongType | TimestampType => (input) => input.getLong(ordinal)
case FloatType => (input) => input.getFloat(ordinal)
case DoubleType => (input) => input.getDouble(ordinal)
case StringType => (input) => input.getUTF8String(ordinal)
case BinaryType => (input) => input.getBinary(ordinal)
case CalendarIntervalType => (input) => input.getInterval(ordinal)
case t: DecimalType => (input) => input.getDecimal(ordinal, t.precision, t.scale)
case t: StructType => (input) => input.getStruct(ordinal, t.size)
case _: ArrayType => (input) => input.getArray(ordinal)
case _: MapType => (input) => input.getMap(ordinal)
case u: UserDefinedType[_] => getAccessor(u.sqlType, ordinal)
case _ => (input) => input.get(ordinal, dataType)
/**
 * Builds a reader function for values of the given `dataType`.
 *
 * The returned function takes the `SpecializedGetters` to read from and the ordinal to read at,
 * and forwards to the getter that matches `dataType`.
 *
 * @param dataType the data type of the value to read
 * @return a function `(getters, ordinal) => value`
 */
def getAccessor(dataType: DataType): (SpecializedGetters, Int) => Any = {
  dataType match {
    case BooleanType => (getters, idx) => getters.getBoolean(idx)
    case ByteType => (getters, idx) => getters.getByte(idx)
    case ShortType => (getters, idx) => getters.getShort(idx)
    // Dates are read through the int getter, timestamps through the long getter.
    case IntegerType | DateType => (getters, idx) => getters.getInt(idx)
    case LongType | TimestampType => (getters, idx) => getters.getLong(idx)
    case FloatType => (getters, idx) => getters.getFloat(idx)
    case DoubleType => (getters, idx) => getters.getDouble(idx)
    case StringType => (getters, idx) => getters.getUTF8String(idx)
    case BinaryType => (getters, idx) => getters.getBinary(idx)
    case CalendarIntervalType => (getters, idx) => getters.getInterval(idx)
    // Decimal and struct getters need extra information taken from the matched type.
    case decimal: DecimalType =>
      (getters, idx) => getters.getDecimal(idx, decimal.precision, decimal.scale)
    case struct: StructType =>
      (getters, idx) => getters.getStruct(idx, struct.size)
    case _: ArrayType => (getters, idx) => getters.getArray(idx)
    case _: MapType => (getters, idx) => getters.getMap(idx)
    // A user-defined type is read with the accessor of its underlying SQL type.
    case udt: UserDefinedType[_] => getAccessor(udt.sqlType)
    // Any other type falls back to the generic getter, which takes the data type itself.
    case _ => (getters, idx) => getters.get(idx, dataType)
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,14 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)

override def toString: String = s"input[$ordinal, ${dataType.simpleString}, $nullable]"

private val accessor: InternalRow => Any = InternalRow.getAccessor(dataType, ordinal)
private val accessor: (InternalRow, Int) => Any = InternalRow.getAccessor(dataType)

// Use special getter for primitive types (for UnsafeRow)
override def eval(input: InternalRow): Any = {
if (nullable && input.isNullAt(ordinal)) {
null
} else {
accessor(input)
accessor(input, ordinal)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -550,7 +550,7 @@ case class LambdaVariable(
dataType: DataType,
nullable: Boolean = true) extends LeafExpression with NonSQLExpression {

private val accessor: InternalRow => Any = InternalRow.getAccessor(dataType, 0)
private val accessor: (InternalRow, Int) => Any = InternalRow.getAccessor(dataType)

// Interpreted execution of `LambdaVariable` always gets the 0-index element from input row.
override def eval(input: InternalRow): Any = {
Expand All @@ -559,7 +559,7 @@ case class LambdaVariable(
if (nullable && input.isNullAt(0)) {
null
} else {
accessor(input)
accessor(input, 0)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,11 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
}

test("LambdaVariable should support interpreted execution") {
def genSchema(dt: DataType): Seq[StructType] = {
Seq(StructType(StructField("col_1", dt, nullable = false) :: Nil),
StructType(StructField("col_1", dt, nullable = true) :: Nil))
}

val elementTypes = Seq(BooleanType, ByteType, ShortType, IntegerType, LongType, FloatType,
DoubleType, DecimalType.USER_DEFAULT, StringType, BinaryType, DateType, TimestampType,
CalendarIntervalType, new ExamplePointUDT())
Expand All @@ -294,15 +299,16 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
StructType(StructField("col1", elementType, true) :: Nil))
}

val acceptedTypes = elementTypes ++ arrayTypes ++ mapTypes ++ structTypes
val testTypes = elementTypes ++ arrayTypes ++ mapTypes ++ structTypes
val random = new Random(100)
(0 until 100).foreach { _ =>
val schema = RandomDataGenerator.randomSchema(random, 1, acceptedTypes)
val row = RandomDataGenerator.randomRow(random, schema)
val rowConverter = RowEncoder(schema)
val internalRow = rowConverter.toRow(row)
val lambda = LambdaVariable("dummy", "dummuIsNull", schema(0).dataType, schema(0).nullable)
checkEvaluationWithoutCodegen(lambda, internalRow.get(0, schema(0).dataType), internalRow)
testTypes.foreach { dt =>
genSchema(dt).map { schema =>
val row = RandomDataGenerator.randomRow(random, schema)
val rowConverter = RowEncoder(schema)
val internalRow = rowConverter.toRow(row)
val lambda = LambdaVariable("dummy", "dummuIsNull", schema(0).dataType, schema(0).nullable)
checkEvaluationWithoutCodegen(lambda, internalRow.get(0, schema(0).dataType), internalRow)
}
}
}
}

0 comments on commit 54dd939

Please sign in to comment.