[SPARK-24336][SQL] Support 'pass through' transformation in BasicOperators
HeartSaVioR committed May 21, 2018
1 parent 5be8aab commit 139aefa
Showing 2 changed files with 107 additions and 26 deletions.
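
For context: this patch removes a dozen hand-written strategy cases of the form "case logical.X(..., child) => execution.XExec(..., planLater(child)) :: Nil" and replaces them with a single reflection-driven conversion. The standalone sketch below is not Spark code; all class and method names are made up for illustration. It shows the underlying 'pass through' idea: when a source class and a target class share a primary constructor whose parameter lists differ in only one type, the target can be instantiated reflectively from the source's getters.

    import java.lang.reflect.Constructor

    // Stand-ins for a logical operator, a physical operator, and their child plans.
    case class SrcChild(name: String)
    case class TgtChild(name: String)
    case class SrcOp(condition: String, child: SrcChild)
    case class TgtOp(condition: String, child: TgtChild)

    object PassThroughSketch {
      // Read the source's constructor arguments through its getters, convert the child,
      // and invoke the target's constructor reflectively.
      def passThrough(src: SrcOp, convertChild: SrcChild => TgtChild): TgtOp = {
        val paramNames = Seq("condition", "child") // primary constructor parameter names
        val values = paramNames.map(name => src.getClass.getMethod(name).invoke(src))
        val converted = values.map {
          case c: SrcChild => convertChild(c)
          case other => other
        }
        val constructor: Constructor[_] = classOf[TgtOp].getConstructors.head
        constructor.newInstance(converted: _*).asInstanceOf[TgtOp]
      }

      def main(args: Array[String]): Unit = {
        val result = passThrough(SrcOp("x > 1", SrcChild("scan")), c => TgtChild(c.name))
        println(result) // prints: TgtOp(x > 1,TgtChild(scan))
      }
    }
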
@@ -703,6 +703,33 @@ object ScalaReflection extends ScalaReflection {
*/
def getClassFromType(tpe: Type): Class[_] = mirror.runtimeClass(tpe.dealias.typeSymbol.asClass)

/**
 * Variant of `getClassFromType` that also resolves `Array[_]` types to their runtime array
 * classes, e.g. `Array[Int]` to `classOf[Array[Int]]`.
 */
def getClassFromTypeHandleArray(tpe: Type): Class[_] = cleanUpReflectionObjects {
  tpe.dealias match {
    case ty if ty <:< localTypeOf[Array[_]] =>
      def arrayClassFromType(tpe: `Type`): Class[_] =
        ScalaReflection.cleanUpReflectionObjects {
          tpe.dealias match {
            case t if t <:< definitions.IntTpe => classOf[Array[Int]]
            case t if t <:< definitions.LongTpe => classOf[Array[Long]]
            case t if t <:< definitions.DoubleTpe => classOf[Array[Double]]
            case t if t <:< definitions.FloatTpe => classOf[Array[Float]]
            case t if t <:< definitions.ShortTpe => classOf[Array[Short]]
            case t if t <:< definitions.ByteTpe => classOf[Array[Byte]]
            case t if t <:< definitions.BooleanTpe => classOf[Array[Boolean]]
            case _ =>
              // Non-primitive element type: resolve the element class recursively and build
              // the corresponding array class via java.lang.reflect.Array.
              val elementType = getClassFromTypeHandleArray(tpe)
              java.lang.reflect.Array.newInstance(elementType, 1).getClass
          }
        }

      val TypeRef(_, _, Seq(elementType)) = ty
      arrayClassFromType(elementType)

    case ty => getClassFromType(ty)
  }
}

case class Schema(dataType: DataType, nullable: Boolean)

/** Returns a Sequence of attributes for the given case class type. */
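
The helper above matters when a constructor parameter is typed as an array: getClassFromType maps a type to the runtime class of its type constructor, so it cannot produce array classes such as classOf[Array[Int]]. A small sketch of the expected behaviour, assuming spark-catalyst is on the classpath and the new method stays public as in this diff:

    import org.apache.spark.sql.catalyst.ScalaReflection
    import org.apache.spark.sql.catalyst.ScalaReflection.universe._

    object HandleArrayDemo {
      def main(args: Array[String]): Unit = {
        // Non-array types fall through to getClassFromType.
        println(ScalaReflection.getClassFromTypeHandleArray(typeOf[Seq[String]]))
        // -> scala.collection.Seq (the erased class of the type constructor)

        // Array types are resolved to their runtime array classes, including primitive arrays.
        println(ScalaReflection.getClassFromTypeHandleArray(typeOf[Array[Int]]))
        // -> class [I, i.e. classOf[Array[Int]]

        println(ScalaReflection.getClassFromTypeHandleArray(typeOf[Array[String]]))
        // -> class [Ljava.lang.String;
      }
    }
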
@@ -17,9 +17,12 @@

package org.apache.spark.sql.execution

import java.lang.reflect.Constructor

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{execution, AnalysisException, Strategy}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.{InternalRow, ScalaReflection}
import org.apache.spark.sql.catalyst.ScalaReflection._
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
@@ -474,8 +477,80 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
}
}

// Can we automate these 'pass through' operations?
object BasicOperators extends Strategy {

  import universe._

  // Enumerates the pairs of logical plan and physical plan classes that can be converted via
  // 'pass-through': the conversion applies when the only difference between the primary
  // constructor parameters of the two plans is LogicalPlan vs. SparkPlan.
  // Pairs whose conversion would have to rely on default values of constructor parameters
  // must not be listed here.
  val passThroughOperators: Map[Class[_ <: LogicalPlan], Class[_ <: SparkPlan]] = Map(
    (classOf[logical.DeserializeToObject], classOf[execution.DeserializeToObjectExec]),
    (classOf[logical.SerializeFromObject], classOf[execution.SerializeFromObjectExec]),
    (classOf[logical.MapPartitions], classOf[execution.MapPartitionsExec]),
    (classOf[logical.FlatMapGroupsInR], classOf[execution.FlatMapGroupsInRExec]),
    (classOf[logical.FlatMapGroupsInPandas], classOf[execution.python.FlatMapGroupsInPandasExec]),
    (classOf[logical.AppendColumnsWithObject], classOf[execution.AppendColumnsWithObjectExec]),
    (classOf[logical.MapGroups], classOf[execution.MapGroupsExec]),
    (classOf[logical.CoGroup], classOf[execution.CoGroupExec]),
    (classOf[logical.Project], classOf[execution.ProjectExec]),
    (classOf[logical.Filter], classOf[execution.FilterExec]),
    (classOf[logical.Window], classOf[execution.window.WindowExec]),
    (classOf[logical.Sample], classOf[execution.SampleExec])
  )

  // For each pass-through logical operator, the names and types of its primary constructor
  // parameters, resolved via reflection.
  lazy val operatorToConstructorParameters: Map[Class[_ <: LogicalPlan], Seq[(String, Type)]] =
    passThroughOperators.map {
      case (srcOpCls, _) =>
        (srcOpCls, ScalaReflection.getConstructorParameters(srcOpCls))
    }.toMap

  // For each pass-through logical operator, the constructor of the target physical operator
  // whose parameter list matches the logical operator's, with LogicalPlan replaced by SparkPlan.
  lazy val operatorToTargetConstructor: Map[Class[_ <: LogicalPlan], Constructor[_]] =
    passThroughOperators.map {
      case (srcOpCls, tgtOpCls) =>
        val logicalPlanCls = classOf[LogicalPlan]
        val m = runtimeMirror(logicalPlanCls.getClassLoader)
        val classSymbol = m.staticClass(logicalPlanCls.getName)
        val logicalPlanType = classSymbol.selfType

        val paramTypes = operatorToConstructorParameters(srcOpCls).map(_._2)
        val convertedParamTypes = ScalaReflection.cleanUpReflectionObjects {
          paramTypes.map {
            case ty if ty <:< logicalPlanType =>
              m.staticClass(classOf[SparkPlan].getName).selfType

            case ty => ty
          }
        }

        val convertedParamClasses = convertedParamTypes.map(
          ScalaReflection.getClassFromTypeHandleArray)
        val constructorOption = ScalaReflection.findConstructor(tgtOpCls, convertedParamClasses)

        constructorOption match {
          case Some(const: Constructor[_]) => (srcOpCls, const)
          case _ => throw new IllegalStateException(
            s"A constructor matching ${srcOpCls.getName} must be present in ${tgtOpCls.getName}!")
        }
    }.toMap

  // Builds the physical plan for a pass-through logical operator: read the logical operator's
  // constructor arguments via its getters, replace every LogicalPlan argument with
  // planLater(...), and invoke the matching constructor of the target physical operator.
  def createPassThroughOutputPlan(src: LogicalPlan): SparkPlan = {
    val srcClass = src.getClass
    require(passThroughOperators.contains(srcClass))
    val paramValues = operatorToConstructorParameters(srcClass).map(_._1).map { name =>
      srcClass.getMethod(name).invoke(src)
    }
    val convertedParamValues = paramValues.map {
      case p: LogicalPlan => planLater(p)
      case p => p
    }

    val const = operatorToTargetConstructor(srcClass)
    const.newInstance(convertedParamValues: _*).asInstanceOf[SparkPlan]
  }
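  // Illustrative trace (not part of the patch itself), using logical.Filter as an example of
  // how the reflective path reproduces a removed hand-written case:
  //   operatorToConstructorParameters(classOf[logical.Filter])
  //     resolves to Seq(("condition", Expression), ("child", LogicalPlan));
  //   operatorToTargetConstructor swaps LogicalPlan for SparkPlan and finds the constructor
  //     FilterExec(condition: Expression, child: SparkPlan);
  //   createPassThroughOutputPlan(filter) then invokes filter.condition and filter.child,
  //   wraps the child in planLater(...), and calls that constructor, which is equivalent to:
  //     case logical.Filter(condition, child) =>
  //       execution.FilterExec(condition, planLater(child)) :: Nil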

def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case d: DataWritingCommand => DataWritingCommandExec(d, planLater(d.query)) :: Nil
case r: RunnableCommand => ExecutedCommandExec(r) :: Nil
@@ -497,36 +572,23 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
throw new IllegalStateException(
"logical except operator should have been replaced by anti-join in the optimizer")

case logical.DeserializeToObject(deserializer, objAttr, child) =>
execution.DeserializeToObjectExec(deserializer, objAttr, planLater(child)) :: Nil
case logical.SerializeFromObject(serializer, child) =>
execution.SerializeFromObjectExec(serializer, planLater(child)) :: Nil
case logical.MapPartitions(f, objAttr, child) =>
execution.MapPartitionsExec(f, objAttr, planLater(child)) :: Nil
case src if passThroughOperators.contains(src.getClass) =>
createPassThroughOutputPlan(src) :: Nil

case logical.MapPartitionsInR(f, p, b, is, os, objAttr, child) =>
execution.MapPartitionsExec(
execution.r.MapPartitionsRWrapper(f, p, b, is, os), objAttr, planLater(child)) :: Nil
case logical.FlatMapGroupsInR(f, p, b, is, os, key, value, grouping, data, objAttr, child) =>
execution.FlatMapGroupsInRExec(f, p, b, is, os, key, value, grouping,
data, objAttr, planLater(child)) :: Nil
case logical.FlatMapGroupsInPandas(grouping, func, output, child) =>
execution.python.FlatMapGroupsInPandasExec(grouping, func, output, planLater(child)) :: Nil
case logical.MapElements(f, _, _, objAttr, child) =>
execution.MapElementsExec(f, objAttr, planLater(child)) :: Nil
case logical.AppendColumns(f, _, _, in, out, child) =>
execution.AppendColumnsExec(f, in, out, planLater(child)) :: Nil
case logical.AppendColumnsWithObject(f, childSer, newSer, child) =>
execution.AppendColumnsWithObjectExec(f, childSer, newSer, planLater(child)) :: Nil
case logical.MapGroups(f, key, value, grouping, data, objAttr, child) =>
execution.MapGroupsExec(f, key, value, grouping, data, objAttr, planLater(child)) :: Nil
case logical.FlatMapGroupsWithState(
f, key, value, grouping, data, output, _, _, _, timeout, child) =>
execution.MapGroupsExec(
f, key, value, grouping, data, output, timeout, planLater(child)) :: Nil
case logical.CoGroup(f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, oAttr, left, right) =>
execution.CoGroupExec(
f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, oAttr,
planLater(left), planLater(right)) :: Nil

case logical.Repartition(numPartitions, shuffle, child) =>
if (shuffle) {
@@ -536,18 +598,10 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
}
case logical.Sort(sortExprs, global, child) =>
execution.SortExec(sortExprs, global, planLater(child)) :: Nil
case logical.Project(projectList, child) =>
execution.ProjectExec(projectList, planLater(child)) :: Nil
case logical.Filter(condition, child) =>
execution.FilterExec(condition, planLater(child)) :: Nil
case f: logical.TypedFilter =>
execution.FilterExec(f.typedCondition(f.deserializer), planLater(f.child)) :: Nil
case e @ logical.Expand(_, _, child) =>
execution.ExpandExec(e.projections, e.output, planLater(child)) :: Nil
case logical.Window(windowExprs, partitionSpec, orderSpec, child) =>
execution.window.WindowExec(windowExprs, partitionSpec, orderSpec, planLater(child)) :: Nil
case logical.Sample(lb, ub, withReplacement, seed, child) =>
execution.SampleExec(lb, ub, withReplacement, seed, planLater(child)) :: Nil
case logical.LocalRelation(output, data, _) =>
LocalTableScanExec(output, data) :: Nil
case logical.LocalLimit(IntegerLiteral(limit), child) =>
