Skip to content

Commit

Permalink
support Boolean for returnType
Browse files Browse the repository at this point in the history
Use compositional approach instead of enumeration approach
  • Loading branch information
kiszk committed Mar 9, 2017
1 parent dfbce2a commit 8ee91af
Showing 1 changed file with 22 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
package org.apache.spark.sql.execution

import scala.language.existentials

import org.apache.spark.api.java.function.MapFunction
import org.apache.spark.api.r._
import org.apache.spark.broadcast.Broadcast
Expand All @@ -33,6 +32,7 @@ import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.plans.logical.LogicalKeyedState
import org.apache.spark.sql.execution.streaming.KeyedStateImpl
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils


/**
Expand Down Expand Up @@ -216,35 +216,38 @@ case class MapElementsExec(
child.asInstanceOf[CodegenSupport].produce(ctx, this)
}

private def getMethodType(dt: DataType, isOutput: Boolean): String = {
dt match {
case BooleanType if isOutput => "Z"
case IntegerType => "I"
case LongType => "J"
case FloatType => "F"
case DoubleType => "D"
case _ => null
}
}

override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
val (funcClass, methodName) = func match {
case m: MapFunction[_, _] => classOf[MapFunction[_, _]] -> "call"
case _ => (child.output(0).dataType, outputObjAttr.dataType) match {
// load "scala.Function1" using Java API to avoid requirements of type parameters
case _ => Utils.classForName("scala.Function1") -> {
// if a pair of an argument and return types is one of specific types
// whose specialized method (apply$mc..$sp) is generated by scalac,
// Catalyst generated a direct method call to the specialized method.
// The followings are references for this specialization:
// http://www.scala-lang.org/api/2.12.0/scala/Function1.html
// https://github.com/scala/scala/blob/2.11.x/src/compiler/scala/tools/nsc/transform/
// SpecializeTypes.scala
// http://www.cakesolutions.net/teamblogs/scala-dissection-functions
// http://axel22.github.io/2013/11/03/specialization-quirks.html
case (IntegerType, IntegerType) => classOf[Int => Int] -> "apply$mcII$sp"
case (IntegerType, LongType) => classOf[Int => Long] -> "apply$mcJI$sp"
case (IntegerType, FloatType) => classOf[Int => Float] -> "apply$mcFI$sp"
case (IntegerType, DoubleType) => classOf[Int => Double] -> "apply$mcDI$sp"
case (LongType, IntegerType) => classOf[Long => Int] -> "apply$mcIJ$sp"
case (LongType, LongType) => classOf[Long => Long] -> "apply$mcJJ$sp"
case (LongType, FloatType) => classOf[Long => Float] -> "apply$mcFJ$sp"
case (LongType, DoubleType) => classOf[Long => Double] -> "apply$mcDJ$sp"
case (FloatType, IntegerType) => classOf[Float => Int] -> "apply$mcIF$sp"
case (FloatType, LongType) => classOf[Float => Long] -> "apply$mcJF$sp"
case (FloatType, FloatType) => classOf[Float => Float] -> "apply$mcFF$sp"
case (FloatType, DoubleType) => classOf[Float => Double] -> "apply$mcDF$sp"
case (DoubleType, IntegerType) => classOf[Double => Int] -> "apply$mcID$sp"
case (DoubleType, LongType) => classOf[Double => Long] -> "apply$mcJD$sp"
case (DoubleType, FloatType) => classOf[Double => Float] -> "apply$mcFD$sp"
case (DoubleType, DoubleType) => classOf[Double => Double] -> "apply$mcDD$sp"
case _ => classOf[Any => Any] -> "apply"
val inputType = getMethodType(child.output(0).dataType, false)
val outputType = getMethodType(outputObjAttr.dataType, true)
if (inputType != null && outputType != null) {
s"apply$$mc$outputType$inputType$$sp"
} else {
"apply"
}
}
}
val funcObj = Literal.create(func, ObjectType(funcClass))
Expand Down

0 comments on commit 8ee91af

Please sign in to comment.