[SPARK-19008][SQL] Improve performance of Dataset.map by eliminating boxing/unboxing #17172

Closed
wants to merge 12 commits
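For context, a minimal sketch (not taken from this patch) of the kind of typed operations the change targets: Dataset.map and Dataset.filter over primitive element types, where each per-row lambda call previously went through the generic, boxing apply(Object): Object entry point. The local SparkSession setup below is only illustrative.

import org.apache.spark.sql.SparkSession

object BoxingDemo {
  def main(args: Array[String]): Unit = {
    // Local session purely for illustration; any SparkSession works.
    val spark = SparkSession.builder().master("local[*]").appName("boxing-demo").getOrCreate()
    import spark.implicits._

    val ds = spark.range(0, 1000000).as[Long]

    // Long => Long and Long => Boolean lambdas: scalac emits specialized
    // Function1 methods (apply$mcJJ$sp, apply$mcZJ$sp) for these signatures,
    // which the generated code can now invoke instead of the boxing apply.
    val mapped = ds.map(_ + 1)
    val filtered = ds.filter(_ % 2 == 0)

    println(mapped.count() + filtered.count())
    spark.stop()
  }
}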
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.encoders._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.objects.Invoke
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils

object CatalystSerde {
def deserialize[T : Encoder](child: LogicalPlan): DeserializeToObject = {
@@ -210,13 +211,48 @@ case class TypedFilter(
def typedCondition(input: Expression): Expression = {
val (funcClass, methodName) = func match {
case m: FilterFunction[_] => classOf[FilterFunction[_]] -> "call"
case _ => classOf[Any => Boolean] -> "apply"
case _ => FunctionUtils.getFunctionOneName(BooleanType, input.dataType)
}
val funcObj = Literal.create(func, ObjectType(funcClass))
Invoke(funcObj, methodName, BooleanType, input :: Nil)
}
}

object FunctionUtils {
private def getMethodType(dt: DataType, isOutput: Boolean): Option[String] = {
dt match {
case BooleanType if isOutput => Some("Z")
case IntegerType => Some("I")
case LongType => Some("J")
case FloatType => Some("F")
case DoubleType => Some("D")
case _ => None
}
}

def getFunctionOneName(outputDT: DataType, inputDT: DataType): (Class[_], String) = {
// load "scala.Function1" using Java API to avoid requirements of type parameters
Utils.classForName("scala.Function1") -> {
// If the pair of argument and return types is one of the specific combinations
// for which scalac generates a specialized method (apply$mc..$sp),
// Catalyst generates a direct call to that specialized method.
// The following are references for this specialization:
// http://www.scala-lang.org/api/2.12.0/scala/Function1.html
// https://github.com/scala/scala/blob/2.11.x/src/compiler/scala/tools/nsc/transform/
// SpecializeTypes.scala
// http://www.cakesolutions.net/teamblogs/scala-dissection-functions
// http://axel22.github.io/2013/11/03/specialization-quirks.html
val inputType = getMethodType(inputDT, false)
val outputType = getMethodType(outputDT, true)
if (inputType.isDefined && outputType.isDefined) {
s"apply$$mc${outputType.get}${inputType.get}$$sp"
} else {
"apply"
}
}
}
}
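Not part of the patch: a few examples of what getFunctionOneName resolves to, assuming the patched Catalyst classes above are on the classpath. The return-type code comes first in the specialized name, then the argument-type code, matching the string built above.

import org.apache.spark.sql.catalyst.plans.logical.FunctionUtils
import org.apache.spark.sql.types._

object FunctionUtilsDemo {
  def main(args: Array[String]): Unit = {
    // Long => Long lambda in Dataset[Long].map: specialized method, no boxing.
    println(FunctionUtils.getFunctionOneName(LongType, LongType))    // (interface scala.Function1,apply$mcJJ$sp)
    // Long => Boolean predicate in Dataset[Long].filter: also specialized.
    println(FunctionUtils.getFunctionOneName(BooleanType, LongType)) // (interface scala.Function1,apply$mcZJ$sp)
    // Non-primitive input such as String: falls back to the generic, boxing apply.
    println(FunctionUtils.getFunctionOneName(LongType, StringType))  // (interface scala.Function1,apply)
  }
}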

/** Factory for constructing new `AppendColumn` nodes. */
object AppendColumns {
def apply[T : Encoder, U : Encoder](
@@ -28,11 +28,13 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.objects.Invoke
import org.apache.spark.sql.catalyst.plans.logical.FunctionUtils
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.plans.logical.LogicalKeyedState
import org.apache.spark.sql.execution.streaming.KeyedStateImpl
import org.apache.spark.sql.types.{DataType, ObjectType, StructType}
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils


/**
@@ -219,7 +221,7 @@ case class MapElementsExec(
override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
val (funcClass, methodName) = func match {
case m: MapFunction[_, _] => classOf[MapFunction[_, _]] -> "call"
case _ => classOf[Any => Any] -> "apply"
case _ => FunctionUtils.getFunctionOneName(outputObjAttr.dataType, child.output(0).dataType)
}
val funcObj = Literal.create(func, ObjectType(funcClass))
val callFunc = Invoke(funcObj, methodName, outputObjAttr.dataType, child.output)
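Also not part of the patch: a small reflection sketch of why the Invoke above can target the specialized name directly. scala.Function1 itself declares the apply$mc..$sp variants, and a Long => Long lambda implements apply$mcJJ$sp(long): long, so the call stays on primitives instead of going through apply(Object): Object.

object SpecializedApplyDemo {
  def main(args: Array[String]): Unit = {
    val func: Long => Long = _ + 1

    // The specialized variant is declared on scala.Function1 itself,
    // which is why it can be resolved from the erased interface alone.
    val specialized = classOf[scala.Function1[_, _]]
      .getMethod("apply$mcJJ$sp", classOf[Long])
    println(specialized) // ... long scala.Function1.apply$mcJJ$sp(long)

    // Invoking it on the lambda keeps the computation on primitives; the result
    // is boxed here only because java.lang.reflect.Method.invoke boxes.
    println(specialized.invoke(func, java.lang.Long.valueOf(41L))) // 42
  }
}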
122 changes: 116 additions & 6 deletions sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala
@@ -31,6 +31,49 @@ object DatasetBenchmark {

case class Data(l: Long, s: String)

def backToBackMapLong(spark: SparkSession, numRows: Long, numChains: Int): Benchmark = {
import spark.implicits._

val rdd = spark.sparkContext.range(0, numRows)
val ds = spark.range(0, numRows)
val df = ds.toDF("l")
val func = (l: Long) => l + 1

val benchmark = new Benchmark("back-to-back map long", numRows)

benchmark.addCase("RDD") { iter =>
var res = rdd
var i = 0
while (i < numChains) {
res = res.map(func)
i += 1
}
res.foreach(_ => Unit)
}

benchmark.addCase("DataFrame") { iter =>
var res = df
var i = 0
while (i < numChains) {
res = res.select($"l" + 1 as "l")
i += 1
}
res.queryExecution.toRdd.foreach(_ => Unit)
}

benchmark.addCase("Dataset") { iter =>
var res = ds.as[Long]
var i = 0
while (i < numChains) {
res = res.map(func)
i += 1
}
res.queryExecution.toRdd.foreach(_ => Unit)
}

benchmark
}

def backToBackMap(spark: SparkSession, numRows: Long, numChains: Int): Benchmark = {
import spark.implicits._

@@ -72,6 +115,49 @@ object DatasetBenchmark {
benchmark
}

def backToBackFilterLong(spark: SparkSession, numRows: Long, numChains: Int): Benchmark = {
import spark.implicits._

val rdd = spark.sparkContext.range(1, numRows)
val ds = spark.range(1, numRows)
val df = ds.toDF("l")
val func = (l: Long) => l % 2L == 0L

val benchmark = new Benchmark("back-to-back filter Long", numRows)

benchmark.addCase("RDD") { iter =>
var res = rdd
var i = 0
while (i < numChains) {
res = res.filter(func)
i += 1
}
res.foreach(_ => Unit)
}

benchmark.addCase("DataFrame") { iter =>
var res = df
var i = 0
while (i < numChains) {
res = res.filter($"l" % 2L === 0L)
i += 1
}
res.queryExecution.toRdd.foreach(_ => Unit)
}

benchmark.addCase("Dataset") { iter =>
var res = ds.as[Long]
var i = 0
while (i < numChains) {
res = res.filter(func)
i += 1
}
res.queryExecution.toRdd.foreach(_ => Unit)
}

benchmark
}

def backToBackFilter(spark: SparkSession, numRows: Long, numChains: Int): Benchmark = {
import spark.implicits._

@@ -165,9 +251,22 @@ object DatasetBenchmark {
val numRows = 100000000
val numChains = 10

val benchmark = backToBackMap(spark, numRows, numChains)
val benchmark2 = backToBackFilter(spark, numRows, numChains)
val benchmark3 = aggregate(spark, numRows)
val benchmark0 = backToBackMapLong(spark, numRows, numChains)
val benchmark1 = backToBackMap(spark, numRows, numChains)
val benchmark2 = backToBackFilterLong(spark, numRows, numChains)
val benchmark3 = backToBackFilter(spark, numRows, numChains)
val benchmark4 = aggregate(spark, numRows)

/*
OpenJDK 64-Bit Server VM 1.8.0_111-8u111-b14-2ubuntu0.16.04.2-b14 on Linux 4.4.0-47-generic
Intel(R) Xeon(R) CPU E5-2667 v3 @ 3.20GHz
back-to-back map long: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------
RDD 1883 / 1892 53.1 18.8 1.0X
DataFrame 502 / 642 199.1 5.0 3.7X
Dataset 657 / 784 152.2 6.6 2.9X
*/
benchmark0.run()

/*
OpenJDK 64-Bit Server VM 1.8.0_91-b14 on Linux 3.10.0-327.18.2.el7.x86_64
@@ -178,7 +277,18 @@ object DatasetBenchmark {
DataFrame 2647 / 3116 37.8 26.5 1.3X
Dataset 4781 / 5155 20.9 47.8 0.7X
*/
benchmark.run()
benchmark1.run()

/*
OpenJDK 64-Bit Server VM 1.8.0_121-8u121-b13-0ubuntu1.16.04.2-b13 on Linux 4.4.0-47-generic
Intel(R) Xeon(R) CPU E5-2667 v3 @ 3.20GHz
back-to-back filter Long: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------
RDD 846 / 1120 118.1 8.5 1.0X
DataFrame 270 / 329 370.9 2.7 3.1X
Dataset 545 / 789 183.5 5.4 1.6X
*/
benchmark2.run()

/*
OpenJDK 64-Bit Server VM 1.8.0_91-b14 on Linux 3.10.0-327.18.2.el7.x86_64
@@ -189,7 +299,7 @@ object DatasetBenchmark {
DataFrame 59 / 72 1695.4 0.6 22.8X
Dataset 2777 / 2805 36.0 27.8 0.5X
*/
benchmark2.run()
benchmark3.run()

/*
Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.12.1
@@ -201,6 +311,6 @@ object DatasetBenchmark {
Dataset sum using Aggregator 4656 / 4758 21.5 46.6 0.4X
Dataset complex Aggregator 6636 / 7039 15.1 66.4 0.3X
*/
benchmark3.run()
benchmark4.run()
}
}
@@ -62,13 +62,64 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext {
2, 3, 4)
}

test("mapPrimitive") {
val dsInt = Seq(1, 2, 3).toDS()
checkDataset(dsInt.map(_ > 1), false, true, true)
checkDataset(dsInt.map(_ + 1), 2, 3, 4)
checkDataset(dsInt.map(_ + 8589934592L), 8589934593L, 8589934594L, 8589934595L)
checkDataset(dsInt.map(_ + 1.1F), 2.1F, 3.1F, 4.1F)
checkDataset(dsInt.map(_ + 1.23D), 2.23D, 3.23D, 4.23D)

val dsLong = Seq(1L, 2L, 3L).toDS()
checkDataset(dsLong.map(_ > 1), false, true, true)
checkDataset(dsLong.map(e => (e + 1).toInt), 2, 3, 4)
checkDataset(dsLong.map(_ + 8589934592L), 8589934593L, 8589934594L, 8589934595L)
checkDataset(dsLong.map(_ + 1.1F), 2.1F, 3.1F, 4.1F)
checkDataset(dsLong.map(_ + 1.23D), 2.23D, 3.23D, 4.23D)

val dsFloat = Seq(1F, 2F, 3F).toDS()
checkDataset(dsFloat.map(_ > 1), false, true, true)
checkDataset(dsFloat.map(e => (e + 1).toInt), 2, 3, 4)
checkDataset(dsFloat.map(e => (e + 123456L).toLong), 123457L, 123458L, 123459L)
checkDataset(dsFloat.map(_ + 1.1F), 2.1F, 3.1F, 4.1F)
checkDataset(dsFloat.map(_ + 1.23D), 2.23D, 3.23D, 4.23D)

val dsDouble = Seq(1D, 2D, 3D).toDS()
checkDataset(dsDouble.map(_ > 1), false, true, true)
checkDataset(dsDouble.map(e => (e + 1).toInt), 2, 3, 4)
checkDataset(dsDouble.map(e => (e + 8589934592L).toLong),
8589934593L, 8589934594L, 8589934595L)
checkDataset(dsDouble.map(e => (e + 1.1F).toFloat), 2.1F, 3.1F, 4.1F)
checkDataset(dsDouble.map(_ + 1.23D), 2.23D, 3.23D, 4.23D)

val dsBoolean = Seq(true, false).toDS()
checkDataset(dsBoolean.map(e => !e), false, true)
}

test("filter") {
val ds = Seq(1, 2, 3, 4).toDS()
checkDataset(
ds.filter(_ % 2 == 0),
2, 4)
}

test("filterPrimitive") {
val dsInt = Seq(1, 2, 3).toDS()
checkDataset(dsInt.filter(_ > 1), 2, 3)

val dsLong = Seq(1L, 2L, 3L).toDS()
checkDataset(dsLong.filter(_ > 1), 2L, 3L)

val dsFloat = Seq(1F, 2F, 3F).toDS()
checkDataset(dsFloat.filter(_ > 1), 2F, 3F)

val dsDouble = Seq(1D, 2D, 3D).toDS()
checkDataset(dsDouble.filter(_ > 1), 2D, 3D)

val dsBoolean = Seq(true, false).toDS()
checkDataset(dsBoolean.filter(e => !e), false)
}

test("foreach") {
val ds = Seq(1, 2, 3).toDS()
val acc = sparkContext.longAccumulator