[SPARK-19008][SQL] Improve performance of Dataset.map by eliminating boxing/unboxing #17172

Closed
wants to merge 12 commits
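For context on the title: Scala compiles Function1 with @specialized variants for the primitive types, so a lambda such as Int => Int carries an unboxed entry point named apply$mcII$sp alongside the generic apply(Object): Object. Before this patch, Catalyst's generated code called the generic apply for Dataset.map, boxing every element; the change below dispatches to the specialized method instead. A minimal standalone sketch (my illustration, not part of the PR) showing both entry points via reflection:

object SpecializationDemo {
  def main(args: Array[String]): Unit = {
    val f: Int => Int = _ + 1
    // Generic entry point: boxes both the argument and the result.
    f.getClass.getMethods.filter(_.getName == "apply").foreach(println)
    // Specialized entry point generated by scalac: operates on raw ints.
    f.getClass.getMethods.filter(_.getName == "apply$mcII$sp").foreach(println)
  }
}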
@@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.plans.logical.LogicalKeyedState
import org.apache.spark.sql.execution.streaming.KeyedStateImpl
-import org.apache.spark.sql.types.{DataType, ObjectType, StructType}
+import org.apache.spark.sql.types._


/**
@@ -219,7 +219,30 @@ case class MapElementsExec(
override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
val (funcClass, methodName) = func match {
Contributor:

let's put this thing in a util so that FilterExec can also use it

Member Author:

Sure. Now we can also generate a call to the specialized method for Dataset.filter().
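A hypothetical shape for that shared utility (the object and method names here are my illustration, not the code this PR adds); both MapElementsExec and FilterExec could resolve the callable class/method pair through it, with the full list of type pairs as enumerated in the diff below:

import org.apache.spark.sql.types._

object FunctionCallResolver {
  // Returns the class to cast the user function to and the method to invoke.
  def resolve(inType: DataType, outType: DataType): (Class[_], String) =
    (inType, outType) match {
      case (IntegerType, IntegerType) => classOf[Int => Int] -> "apply$mcII$sp"
      // ... the remaining primitive pairs, as in the match below ...
      case _ => classOf[Any => Any] -> "apply"
    }
}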

       case m: MapFunction[_, _] => classOf[MapFunction[_, _]] -> "call"
-      case _ => classOf[Any => Any] -> "apply"
+      case _ =>
+        (if (child.output.length == 1) child.output(0).dataType else NullType,
Contributor:

the if is not needed, see the assert in ObjectConsumerExec

+          outputObjAttr.dataType) match {
+          // If the argument type and the return type form one of the pairs for
+          // which scalac generates a specialized method (apply$mc..$sp),
+          // Catalyst generates a direct call to that specialized method.
Contributor:

can you link to some official document or blogpost?

+          case (IntegerType, IntegerType) => classOf[Int => Int] -> "apply$mcII$sp"
+          case (IntegerType, LongType) => classOf[Int => Long] -> "apply$mcJI$sp"
+          case (IntegerType, FloatType) => classOf[Int => Float] -> "apply$mcFI$sp"
+          case (IntegerType, DoubleType) => classOf[Int => Double] -> "apply$mcDI$sp"
+          case (LongType, IntegerType) => classOf[Long => Int] -> "apply$mcIJ$sp"
+          case (LongType, LongType) => classOf[Long => Long] -> "apply$mcJJ$sp"
+          case (LongType, FloatType) => classOf[Long => Float] -> "apply$mcFJ$sp"
+          case (LongType, DoubleType) => classOf[Long => Double] -> "apply$mcDJ$sp"
+          case (FloatType, IntegerType) => classOf[Float => Int] -> "apply$mcIF$sp"
+          case (FloatType, LongType) => classOf[Float => Long] -> "apply$mcJF$sp"
+          case (FloatType, FloatType) => classOf[Float => Float] -> "apply$mcFF$sp"
+          case (FloatType, DoubleType) => classOf[Float => Double] -> "apply$mcDF$sp"
+          case (DoubleType, IntegerType) => classOf[Double => Int] -> "apply$mcID$sp"
+          case (DoubleType, LongType) => classOf[Double => Long] -> "apply$mcJD$sp"
+          case (DoubleType, FloatType) => classOf[Double => Float] -> "apply$mcFD$sp"
+          case (DoubleType, DoubleType) => classOf[Double => Double] -> "apply$mcDD$sp"
+          case _ => classOf[Any => Any] -> "apply"
+        }
       }
val funcObj = Literal.create(func, ObjectType(funcClass))
val callFunc = Invoke(funcObj, methodName, outputObjAttr.dataType, child.output)
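On the documentation question above: the names follow scalac's @specialized encoding, apply$mc<R><A>$sp, where the return-type code precedes the argument-type code and the codes are the JVM descriptors I (Int), J (Long), F (Float), D (Double). A quick self-contained check of the encoding (my example, not from the PR):

// Long => Double specializes to apply$mcDJ$sp: D = Double return, J = Long argument.
val f: Long => Double = _ * 0.5
assert(f.getClass.getMethods.exists(_.getName == "apply$mcDJ$sp"))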
@@ -62,6 +62,33 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext {
2, 3, 4)
}

test("mapPrimitive") {
val dsInt = Seq(1, 2, 3).toDS()
checkDataset(dsInt.map(e => e + 1), 2, 3, 4)
checkDataset(dsInt.map(e => e + 8589934592L), 8589934593L, 8589934594L, 8589934595L)
checkDataset(dsInt.map(e => e + 1.1F), 2.1F, 3.1F, 4.1F)
checkDataset(dsInt.map(e => e + 1.23D), 2.23D, 3.23D, 4.23D)

val dsLong = Seq(1L, 2L, 3L).toDS()
checkDataset(dsLong.map(e => (e + 1).toInt), 2, 3, 4)
checkDataset(dsLong.map(e => e + 8589934592L), 8589934593L, 8589934594L, 8589934595L)
checkDataset(dsLong.map(e => e + 1.1F), 2.1F, 3.1F, 4.1F)
checkDataset(dsLong.map(e => e + 1.23D), 2.23D, 3.23D, 4.23D)

val dsFloat = Seq(1F, 2F, 3F).toDS()
checkDataset(dsFloat.map(e => (e + 1).toInt), 2, 3, 4)
checkDataset(dsFloat.map(e => (e + 123456L).toLong), 123457L, 123458L, 123459L)
checkDataset(dsFloat.map(e => e + 1.1F), 2.1F, 3.1F, 4.1F)
checkDataset(dsFloat.map(e => e + 1.23D), 2.23D, 3.23D, 4.23D)

val dsDouble = Seq(1D, 2D, 3D).toDS()
checkDataset(dsDouble.map(e => (e + 1).toInt), 2, 3, 4)
checkDataset(dsDouble.map(e => (e + 8589934592L).toLong),
8589934593L, 8589934594L, 8589934595L)
checkDataset(dsDouble.map(e => (e + 1.1F).toFloat), 2.1F, 3.1F, 4.1F)
checkDataset(dsDouble.map(e => e + 1.23D), 2.23D, 3.23D, 4.23D)
}
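Beyond these assertions, one way to see the patch take effect is to dump the generated code for a primitive map and look for a direct call to the specialized method instead of a boxing apply. A sketch assuming a local SparkSession named spark (debugCodegen comes from Spark's org.apache.spark.sql.execution.debug helpers):

import org.apache.spark.sql.execution.debug._
import spark.implicits._

// After this patch, the generated code should invoke apply$mcII$sp directly.
Seq(1, 2, 3).toDS().map(e => e + 1).debugCodegen()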

test("filter") {
val ds = Seq(1, 2, 3, 4).toDS()
checkDataset(