diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala index 199ba5ce6969b..45db1f110f2d8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.plans.logical.LogicalKeyedState
 import org.apache.spark.sql.execution.streaming.KeyedStateImpl
-import org.apache.spark.sql.types.{DataType, ObjectType, StructType}
+import org.apache.spark.sql.types._
 
 /**
@@ -217,9 +217,33 @@ case class MapElementsExec(
   }
 
   override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
+    val inType = if (child.output.length == 1) child.output(0).dataType else NullType
+    val outType = outputObjAttr.dataType
     val (funcClass, methodName) = func match {
       case m: MapFunction[_, _] => classOf[MapFunction[_, _]] -> "call"
-      case _ => classOf[Any => Any] -> "apply"
+      case _ =>
+        (inType, outType) match {
+          // If the pair of argument and return types is one for which scalac
+          // generates a specialized apply method (apply$mc..$sp), generate a
+          // direct method call to that specialized method to avoid boxing.
+          case (IntegerType, IntegerType) => classOf[Int => Int] -> "apply$mcII$sp"
+          case (IntegerType, LongType) => classOf[Int => Long] -> "apply$mcJI$sp"
+          case (IntegerType, FloatType) => classOf[Int => Float] -> "apply$mcFI$sp"
+          case (IntegerType, DoubleType) => classOf[Int => Double] -> "apply$mcDI$sp"
+          case (LongType, IntegerType) => classOf[Long => Int] -> "apply$mcIJ$sp"
+          case (LongType, LongType) => classOf[Long => Long] -> "apply$mcJJ$sp"
+          case (LongType, FloatType) => classOf[Long => Float] -> "apply$mcFJ$sp"
+          case (LongType, DoubleType) => classOf[Long => Double] -> "apply$mcDJ$sp"
+          case (FloatType, IntegerType) => classOf[Float => Int] -> "apply$mcIF$sp"
+          case (FloatType, LongType) => classOf[Float => Long] -> "apply$mcJF$sp"
+          case (FloatType, FloatType) => classOf[Float => Float] -> "apply$mcFF$sp"
+          case (FloatType, DoubleType) => classOf[Float => Double] -> "apply$mcDF$sp"
+          case (DoubleType, IntegerType) => classOf[Double => Int] -> "apply$mcID$sp"
+          case (DoubleType, LongType) => classOf[Double => Long] -> "apply$mcJD$sp"
+          case (DoubleType, FloatType) => classOf[Double => Float] -> "apply$mcFD$sp"
+          case (DoubleType, DoubleType) => classOf[Double => Double] -> "apply$mcDD$sp"
+          case _ => classOf[Any => Any] -> "apply"
+        }
     }
     val funcObj = Literal.create(func, ObjectType(funcClass))
     val callFunc = Invoke(funcObj, methodName, outputObjAttr.dataType, child.output)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala index 6b50cb3e48c76..8e52f36a18428 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala @@ -62,6 +62,33 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext {
       2, 3, 4)
   }
 
+  test("mapPrimitive") {
+    val dsInt = Seq(1, 2, 3).toDS()
+    checkDataset(dsInt.map(e => e + 1), 2, 3, 4)
+    checkDataset(dsInt.map(e => e + 8589934592L), 8589934593L, 8589934594L, 8589934595L)
+    checkDataset(dsInt.map(e => e + 1.1F), 2.1F, 3.1F, 4.1F)
+    checkDataset(dsInt.map(e => e + 1.23D), 2.23D, 3.23D, 4.23D)
+
+    val dsLong = Seq(1L, 2L, 3L).toDS()
+    checkDataset(dsLong.map(e => (e + 1).toInt), 2, 3, 4)
+    checkDataset(dsLong.map(e => e + 8589934592L), 8589934593L, 8589934594L, 8589934595L)
+    checkDataset(dsLong.map(e => e + 1.1F), 2.1F, 3.1F, 4.1F)
+    checkDataset(dsLong.map(e => e + 1.23D), 2.23D, 3.23D, 4.23D)
+
+    val dsFloat = Seq(1F, 2F, 3F).toDS()
+    checkDataset(dsFloat.map(e => (e + 1).toInt), 2, 3, 4)
+    checkDataset(dsFloat.map(e => (e + 123456L).toLong), 123457L, 123458L, 123459L)
+    checkDataset(dsFloat.map(e => e + 1.1F), 2.1F, 3.1F, 4.1F)
+    checkDataset(dsFloat.map(e => e + 1.23D), 2.23D, 3.23D, 4.23D)
+
+    val dsDouble = Seq(1D, 2D, 3D).toDS()
+    checkDataset(dsDouble.map(e => (e + 1).toInt), 2, 3, 4)
+    checkDataset(dsDouble.map(e => (e + 8589934592L).toLong),
+      8589934593L, 8589934594L, 8589934595L)
+    checkDataset(dsDouble.map(e => (e + 1.1F).toFloat), 2.1F, 3.1F, 4.1F)
+    checkDataset(dsDouble.map(e => e + 1.23D), 2.23D, 3.23D, 4.23D)
+  }
+
   test("filter") {
     val ds = Seq(1, 2, 3, 4).toDS()
     checkDataset(
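
A note on the apply$mc..$sp naming used in the patch above: the letters after "mc" encode the return type first and the argument type second (I = Int, J = Long, F = Float, D = Double), which is why (IntegerType, LongType) maps to "apply$mcJI$sp". Below is a minimal standalone sketch, not part of this patch and with purely illustrative names, that checks the specialized entry point actually exists on an ordinary Scala function object:

object SpecializedApplyCheck {
  def main(args: Array[String]): Unit = {
    // An Int => Long function; scalac compiles it against Function1's
    // specialized variant, so an unboxed apply method is available.
    val f: Int => Long = i => i + 8589934592L

    // "apply$mcJI$sp": J (Long result), I (Int argument), i.e. the method
    // name MapElementsExec now emits for the (IntegerType, LongType) case.
    val m = f.getClass.getMethod("apply$mcJI$sp", classOf[Int])
    println(m.invoke(f, Integer.valueOf(1)))  // prints 8589934593

    // The generic apply still works, but it goes through boxing bridges.
    println(f(1))                             // prints 8589934593
  }
}

The generated code in doConsume simply invokes that method name directly on the function object, so the primitive argument and result never need to be boxed.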