[SPARK-24336][SQL] Support 'pass through' transformation in BasicOperators #21388

Closed · wants to merge 3 commits
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -703,6 +703,33 @@ object ScalaReflection extends ScalaReflection {
*/
def getClassFromType(tpe: Type): Class[_] = mirror.runtimeClass(tpe.dealias.typeSymbol.asClass)

/**
 * Variant of `getClassFromType` that also resolves `Array[_]` types to their
 * corresponding Java array classes.
 */
def getClassFromTypeHandleArray(tpe: Type): Class[_] = cleanUpReflectionObjects {
tpe.dealias match {
case ty if ty <:< localTypeOf[Array[_]] =>
val TypeRef(_, _, Seq(elementType)) = ty
arrayClassFromType(elementType)

case ty => getClassFromType(ty)
}
}

private def arrayClassFromType(tpe: Type): Class[_] =
ScalaReflection.cleanUpReflectionObjects {
tpe.dealias match {
case t if t <:< definitions.IntTpe => classOf[Array[Int]]
case t if t <:< definitions.LongTpe => classOf[Array[Long]]
case t if t <:< definitions.DoubleTpe => classOf[Array[Double]]
case t if t <:< definitions.FloatTpe => classOf[Array[Float]]
case t if t <:< definitions.ShortTpe => classOf[Array[Short]]
case t if t <:< definitions.ByteTpe => classOf[Array[Byte]]
case t if t <:< definitions.BooleanTpe => classOf[Array[Boolean]]
case _ =>
// There is probably a better way to do this, but I couldn't find it...
// Build the array class reflectively from the (possibly nested) element class.
val elementType = getClassFromTypeHandleArray(tpe)
java.lang.reflect.Array.newInstance(elementType, 1).getClass
}
}

case class Schema(dataType: DataType, nullable: Boolean)

/** Returns a Sequence of attributes for the given case class type. */
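
The fallback branch of arrayClassFromType relies on a JVM idiom: instantiate a one-element array for the element class and ask for its runtime class. A minimal standalone sketch of that idiom, separate from the patch (the object name ArrayClassDemo is illustrative):

object ArrayClassDemo {
  def main(args: Array[String]): Unit = {
    // Primitive arrays have dedicated runtime classes ("[I" for int[]), which
    // is why the patch enumerates IntTpe, LongTpe, etc. explicitly.
    assert(classOf[Array[Int]] == Class.forName("[I"))

    // Reference arrays can be built reflectively from the element class,
    // exactly like the `case _` fallback above.
    val stringArrayClass =
      java.lang.reflect.Array.newInstance(classOf[String], 1).getClass
    assert(stringArrayClass == classOf[Array[String]])

    // The recursion in getClassFromTypeHandleArray makes nested arrays compose.
    val nestedArrayClass =
      java.lang.reflect.Array.newInstance(stringArrayClass, 1).getClass
    assert(nestedArrayClass == classOf[Array[Array[String]]])
  }
}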
sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -17,9 +17,12 @@

package org.apache.spark.sql.execution

import java.lang.reflect.Constructor

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{execution, AnalysisException, Strategy}
-import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.{InternalRow, ScalaReflection}
import org.apache.spark.sql.catalyst.ScalaReflection._
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
@@ -474,8 +477,78 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
}
}

// Can we automate these 'pass through' operations?
object BasicOperators extends Strategy {

import universe._

// Enumerates the pairs of logical and physical plans that can be transformed via
// 'pass-through': conversion is possible when the only difference between the
// parameters of the two primary constructors is LogicalPlan vs. SparkPlan.
// The map must exclude pairs whose 'pass-through' would need to rely on default
// values of constructor parameters.
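// For example, logical.Filter(condition: Expression, child: LogicalPlan) pairs
// with execution.FilterExec(condition: Expression, child: SparkPlan): only the
// child's type differs. By contrast, logical.MapElements cannot be listed,
// because execution.MapElementsExec drops two of its constructor parameters
// (see the retained `case logical.MapElements(f, _, _, objAttr, child)` below).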
val passThroughOperators: Map[Class[_ <: LogicalPlan], Class[_ <: SparkPlan]] = Map(
(classOf[logical.DeserializeToObject], classOf[execution.DeserializeToObjectExec]),
(classOf[logical.SerializeFromObject], classOf[execution.SerializeFromObjectExec]),
(classOf[logical.MapPartitions], classOf[execution.MapPartitionsExec]),
(classOf[logical.FlatMapGroupsInR], classOf[execution.FlatMapGroupsInRExec]),
(classOf[logical.FlatMapGroupsInPandas], classOf[execution.python.FlatMapGroupsInPandasExec]),
(classOf[logical.AppendColumnsWithObject], classOf[execution.AppendColumnsWithObjectExec]),
(classOf[logical.MapGroups], classOf[execution.MapGroupsExec]),
(classOf[logical.CoGroup], classOf[execution.CoGroupExec]),
(classOf[logical.Project], classOf[execution.ProjectExec]),
(classOf[logical.Filter], classOf[execution.FilterExec]),
(classOf[logical.Window], classOf[execution.window.WindowExec]),
(classOf[logical.Sample], classOf[execution.SampleExec])
)

lazy val operatorToConstructorParameters: Map[Class[_ <: LogicalPlan], Seq[(String, Type)]] =
passThroughOperators.map { case (srcOpCls, _) =>
(srcOpCls, ScalaReflection.getConstructorParameters(srcOpCls))
}.toMap

lazy val operatorToTargetConstructor: Map[Class[_ <: LogicalPlan], Constructor[_]] =
passThroughOperators.map { case (srcOpCls, tgtOpCls) =>
val logicalPlanCls = classOf[LogicalPlan]
val m = runtimeMirror(logicalPlanCls.getClassLoader)
val classSymbol = m.staticClass(logicalPlanCls.getName)
val logicalPlanType = classSymbol.selfType

val paramTypes = operatorToConstructorParameters(srcOpCls).map(_._2)
val convertedParamTypes = ScalaReflection.cleanUpReflectionObjects {
paramTypes.map {
case ty if ty <:< logicalPlanType =>
m.staticClass(classOf[SparkPlan].getName).selfType

case ty => ty
}
}

val convertedParamClasses = convertedParamTypes.map(
ScalaReflection.getClassFromTypeHandleArray)
val constructorOption = ScalaReflection.findConstructor(tgtOpCls, convertedParamClasses)

constructorOption match {
case Some(const: Constructor[_]) => (srcOpCls, const)
case _ => throw new IllegalStateException(
s"A constructor matching ${srcOpCls.getName} must be present in ${tgtOpCls.getName}!")
}
}.toMap
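// Worked example (illustrative): logical.Sample's constructor parameter types
// (Double, Double, Boolean, Long, LogicalPlan) convert to
// (Double, Double, Boolean, Long, SparkPlan), which matches the primary
// constructor of execution.SampleExec, so that Constructor is cached here.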

// Instantiates the target physical operator by reading the logical operator's
// constructor arguments back through its accessors and replacing every
// LogicalPlan argument with planLater(...).
def createPassThroughOutputPlan(src: LogicalPlan): SparkPlan = {
val srcClass = src.getClass
require(passThroughOperators.contains(srcClass))
val paramValues = operatorToConstructorParameters(srcClass).map(_._1).map { name =>
srcClass.getMethod(name).invoke(src)
}
val convertedParamValues = paramValues.map {
case p: LogicalPlan => planLater(p)
case p => p
}

val const = operatorToTargetConstructor(srcClass)
const.newInstance(convertedParamValues: _*).asInstanceOf[SparkPlan]
}
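// Worked trace of createPassThroughOutputPlan (illustrative): for
// src = logical.Filter(cond, child), the accessor calls yield Seq(cond, child),
// the LogicalPlan argument becomes planLater(child), and the cached constructor
// builds execution.FilterExec(cond, planLater(child)).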

def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case d: DataWritingCommand => DataWritingCommandExec(d, planLater(d.query)) :: Nil
case r: RunnableCommand => ExecutedCommandExec(r) :: Nil
@@ -497,36 +570,20 @@
throw new IllegalStateException(
"logical except operator should have been replaced by anti-join in the optimizer")

-case logical.DeserializeToObject(deserializer, objAttr, child) =>
-execution.DeserializeToObjectExec(deserializer, objAttr, planLater(child)) :: Nil
-case logical.SerializeFromObject(serializer, child) =>
-execution.SerializeFromObjectExec(serializer, planLater(child)) :: Nil
-case logical.MapPartitions(f, objAttr, child) =>
-execution.MapPartitionsExec(f, objAttr, planLater(child)) :: Nil
case src if passThroughOperators.contains(src.getClass) =>
createPassThroughOutputPlan(src) :: Nil

case logical.MapPartitionsInR(f, p, b, is, os, objAttr, child) =>
execution.MapPartitionsExec(
execution.r.MapPartitionsRWrapper(f, p, b, is, os), objAttr, planLater(child)) :: Nil
-case logical.FlatMapGroupsInR(f, p, b, is, os, key, value, grouping, data, objAttr, child) =>
-execution.FlatMapGroupsInRExec(f, p, b, is, os, key, value, grouping,
-data, objAttr, planLater(child)) :: Nil
-case logical.FlatMapGroupsInPandas(grouping, func, output, child) =>
-execution.python.FlatMapGroupsInPandasExec(grouping, func, output, planLater(child)) :: Nil
case logical.MapElements(f, _, _, objAttr, child) =>
execution.MapElementsExec(f, objAttr, planLater(child)) :: Nil
case logical.AppendColumns(f, _, _, in, out, child) =>
execution.AppendColumnsExec(f, in, out, planLater(child)) :: Nil
-case logical.AppendColumnsWithObject(f, childSer, newSer, child) =>
-execution.AppendColumnsWithObjectExec(f, childSer, newSer, planLater(child)) :: Nil
-case logical.MapGroups(f, key, value, grouping, data, objAttr, child) =>
-execution.MapGroupsExec(f, key, value, grouping, data, objAttr, planLater(child)) :: Nil
case logical.FlatMapGroupsWithState(
f, key, value, grouping, data, output, _, _, _, timeout, child) =>
execution.MapGroupsExec(
f, key, value, grouping, data, output, timeout, planLater(child)) :: Nil
-case logical.CoGroup(f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, oAttr, left, right) =>
-execution.CoGroupExec(
-f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, oAttr,
-planLater(left), planLater(right)) :: Nil

case logical.Repartition(numPartitions, shuffle, child) =>
if (shuffle) {
@@ -536,18 +593,10 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
}
case logical.Sort(sortExprs, global, child) =>
execution.SortExec(sortExprs, global, planLater(child)) :: Nil
-case logical.Project(projectList, child) =>
-execution.ProjectExec(projectList, planLater(child)) :: Nil
-case logical.Filter(condition, child) =>
-execution.FilterExec(condition, planLater(child)) :: Nil
case f: logical.TypedFilter =>
execution.FilterExec(f.typedCondition(f.deserializer), planLater(f.child)) :: Nil
case e @ logical.Expand(_, _, child) =>
execution.ExpandExec(e.projections, e.output, planLater(child)) :: Nil
-case logical.Window(windowExprs, partitionSpec, orderSpec, child) =>
-execution.window.WindowExec(windowExprs, partitionSpec, orderSpec, planLater(child)) :: Nil
-case logical.Sample(lb, ub, withReplacement, seed, child) =>
-execution.SampleExec(lb, ub, withReplacement, seed, planLater(child)) :: Nil
case logical.LocalRelation(output, data, _) =>
LocalTableScanExec(output, data) :: Nil
case logical.LocalLimit(IntegerLiteral(limit), child) =>
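
Taken together, the strategy works like this: look up the logical operator's class, read its constructor arguments back through the case-class accessors, swap each LogicalPlan argument for planLater(...), and invoke the cached constructor of the physical class. A self-contained miniature of the same pass-through idea, with hypothetical stand-in classes (LNode, PNode, LFilter, PFilter are not Spark types) and the simplification that it takes the target's only public constructor instead of matching parameter classes as the patch does:

import java.lang.reflect.Constructor

// Hypothetical stand-ins for LogicalPlan and SparkPlan.
trait LNode
trait PNode
case class LLeaf() extends LNode
case class PLeaf() extends PNode
case class LFilter(condition: String, child: LNode) extends LNode
case class PFilter(condition: String, child: PNode) extends PNode

object PassThroughDemo {
  // Analogue of Spark's planLater: convert a logical child to a physical one.
  def planLater(l: LNode): PNode = l match {
    case LLeaf()    => PLeaf()
    case f: LFilter => passThrough(f, classOf[PFilter])
  }

  // Copy the source node's constructor arguments, swapping logical children
  // for physical ones, and invoke the target's (only) public constructor.
  def passThrough(src: Product, tgtCls: Class[_ <: PNode]): PNode = {
    val const: Constructor[_] = tgtCls.getConstructors.head
    val args: Array[AnyRef] = src.productIterator.map {
      case child: LNode => planLater(child)
      case other        => other.asInstanceOf[AnyRef]
    }.toArray
    const.newInstance(args: _*).asInstanceOf[PNode]
  }

  def main(args: Array[String]): Unit = {
    val physical = passThrough(LFilter("x > 1", LLeaf()), classOf[PFilter])
    assert(physical == PFilter("x > 1", PLeaf()))
    println(physical) // PFilter(x > 1,PLeaf())
  }
}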