Commit d2ad5c5

Refactor putting SQLContext into SparkPlan. Fix ordering, other test cases.
marmbrus committed Jul 22, 2014
1 parent be2cd6b commit d2ad5c5
Showing 14 changed files with 104 additions and 95 deletions.
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
@@ -17,13 +17,15 @@

package org.apache.spark.sql.catalyst.expressions.codegen

import com.typesafe.scalalogging.slf4j.Logging
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.types.{StringType, NumericType}

/**
* Generates bytecode for an [[Ordering]] of [[Row Rows]] for a given set of
* [[Expression Expressions]].
*/
object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] {
object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] with Logging {
import scala.reflect.runtime.{universe => ru}
import scala.reflect.runtime.universe._

@@ -40,6 +42,22 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] {
val evalA = expressionEvaluator(order.child)
val evalB = expressionEvaluator(order.child)

val compare = order.child.dataType match {
case _: NumericType =>
q"""
val comp = ${evalA.primitiveTerm} - ${evalB.primitiveTerm}
if(comp != 0) {
return ${if (order.direction == Ascending) q"comp.toInt" else q"-comp.toInt"}
}
"""
case StringType =>
if (order.direction == Ascending) {
q"""return ${evalA.primitiveTerm}.compare(${evalB.primitiveTerm})"""
} else {
q"""return ${evalB.primitiveTerm}.compare(${evalA.primitiveTerm})"""
}
}

q"""
i = $a
..${evalA.code}
@@ -52,9 +70,7 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] {
} else if (${evalB.nullTerm}) {
return ${if (order.direction == Ascending) q"1" else q"-1"}
} else {
i = a
val comp = ${evalA.primitiveTerm} - ${evalB.primitiveTerm}
if(comp != 0) return comp.toInt
$compare
}
"""
}
@@ -76,6 +92,7 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] {
}
new $orderingName()
"""
logger.debug(s"Generated Ordering: $code")
toolBox.eval(code).asInstanceOf[Ordering[Row]]
}
}
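
The heart of the ordering fix is visible above: the old generated comparator applied numeric subtraction to every data type (and carried a stray i = a reassignment), while the new compare fragment dispatches on dataType, keeping subtraction for NumericType and delegating to compare for StringType. Below is a minimal hand-written sketch of what the generated Ordering amounts to for ascending keys of type (Int, String) — illustrative names only, not the emitted source.

// Hand-written analogue of the generated Ordering for ascending keys
// (Int, String); names are illustrative, not the emitted source.
case class ExampleRow(i: java.lang.Integer, s: String)

object ExampleOrdering extends Ordering[ExampleRow] {
  override def compare(a: ExampleRow, b: ExampleRow): Int = {
    // Numeric key: two nulls tie and fall through; one null sorts first;
    // otherwise the NumericType branch compares by subtraction.
    // (Caveat: Int subtraction can overflow for values of opposite sign.)
    if (a.i == null && b.i == null) {
      () // tie on this key, continue with the next one
    } else if (a.i == null) {
      return -1
    } else if (b.i == null) {
      return 1
    } else {
      val comp = a.i - b.i
      if (comp != 0) return comp
    }
    // String key: the StringType branch uses compare, since subtraction
    // is only defined for numeric types.
    if (a.s == null && b.s == null) 0
    else if (a.s == null) -1
    else if (b.s == null) 1
    else a.s.compare(b.s)
  }
}

For example, Seq(ExampleRow(2, "b"), ExampleRow(2, "a")).sorted(ExampleOrdering) ties on the numeric key and orders by the string key.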
18 changes: 5 additions & 13 deletions sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -304,18 +304,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
@transient
protected[sql] val prepareForExecution = new RuleExecutor[SparkPlan] {
val batches =
Batch("Add exchange", Once, AddExchange(self)) ::
Batch("CodeGen", Once, TurnOnCodeGen) :: Nil
}

protected object TurnOnCodeGen extends Rule[SparkPlan] {
def apply(plan: SparkPlan): SparkPlan = {
if (self.codegenEnabled) {
plan.foreach(p => println(p.simpleString))
plan.foreach(_._codegenEnabled = true)
}
plan
}
Batch("Add exchange", Once, AddExchange(self)) :: Nil
}

/**
@@ -330,7 +319,10 @@ class SQLContext(@transient val sparkContext: SparkContext)
lazy val analyzed = analyzer(logical)
lazy val optimizedPlan = optimizer(analyzed)
// TODO: Don't just pick the first one...
lazy val sparkPlan = planner(optimizedPlan).next()
lazy val sparkPlan = {
SparkPlan.currentContext.set(self)
planner(optimizedPlan).next()
}
// executedPlan should not be used to initialize any SparkPlan. It should be
// only used for execution.
lazy val executedPlan: SparkPlan = prepareForExecution(sparkPlan)
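
This hunk is the core of the refactor: prepareForExecution no longer needs a TurnOnCodeGen rule (or its debug println), because sparkPlan now publishes the active context through the thread-local SparkPlan.currentContext before running the planner, and every operator captures it at construction time. A minimal sketch of the pattern with simplified stand-in types follows — the commit itself only sets the value; the try/finally cleanup here is an extra nicety, not the Spark source.

// Thread-local context capture, simplified. Context stands in for
// SQLContext and Operator for SparkPlan.
class Context(val codegenEnabled: Boolean)

object Operator {
  private val currentContext = new ThreadLocal[Context]()

  def withContext[T](ctx: Context)(body: => T): T = {
    currentContext.set(ctx)
    try body finally currentContext.remove()
  }

  private def current: Context = currentContext.get()
}

abstract class Operator extends Serializable {
  // Captured once, on the planner thread. @transient: the driver-side
  // context is never shipped to executors.
  @transient protected val ctx: Context = Operator.current

  // Initialized eagerly from the captured context; falls back to false
  // when the operator is constructed outside of planning (ctx == null).
  val codegenEnabled: Boolean = ctx != null && ctx.codegenEnabled
}

// Usage sketch: operators built inside the block observe the context.
//   Operator.withContext(new Context(codegenEnabled = true)) { buildPlan() }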
sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala
@@ -42,7 +42,7 @@ case class Aggregate(
partial: Boolean,
groupingExpressions: Seq[Expression],
aggregateExpressions: Seq[NamedExpression],
child: SparkPlan)(@transient sqlContext: SQLContext)
child: SparkPlan)
extends UnaryNode {

override def requiredChildDistribution =
@@ -56,8 +56,6 @@ case class Aggregate(
}
}

override def otherCopyArgs = sqlContext :: Nil

// HACK: Generators don't correctly preserve their output through serializations so we grab
// our child's output attributes statically here.
private[this] val childOutput = child.output
sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala
@@ -51,9 +51,11 @@ case class Generate(
if (join) child.output ++ generatorOutput else generatorOutput

/** Codegenned rows are not serializable... */
override def codegenEnabled = false
override val codegenEnabled = false

override def execute() = {
val boundGenerator = BindReferences.bindReference(generator, child.output)

if (join) {
child.execute().mapPartitions { iter =>
val nullValues = Seq.fill(generator.output.size)(Literal(null))
@@ -66,7 +68,7 @@ case class Generate(
val joinedRow = new JoinedRow

iter.flatMap {row =>
val outputRows = generator.eval(row)
val outputRows = boundGenerator.eval(row)
if (outer && outputRows.isEmpty) {
outerProjection(row) :: Nil
} else {
@@ -75,7 +77,7 @@ case class Generate(
}
}
} else {
child.execute().mapPartitions(iter => iter.flatMap(row => generator.eval(row)))
child.execute().mapPartitions(iter => iter.flatMap(row => boundGenerator.eval(row)))
}
}
}
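
The functional fix in Generate is that the generator is now bound to child.output once per execute, and boundGenerator — not the unbound generator — is evaluated against each row. A toy sketch of what reference binding accomplishes follows; Catalyst's BindReferences.bindReference transforms whole expression trees, whereas this handles a single node, and all names here are illustrative.

// Toy model of reference binding: turn a by-name column reference into a
// positional one so that per-row evaluation is a simple index lookup.
sealed trait Expr { def eval(row: IndexedSeq[Any]): Any }

case class Reference(name: String) extends Expr {
  def eval(row: IndexedSeq[Any]) =
    sys.error(s"unbound reference: $name") // must be bound before eval
}

case class BoundReference(ordinal: Int) extends Expr {
  def eval(row: IndexedSeq[Any]) = row(ordinal)
}

def bindReference(expr: Expr, inputSchema: Seq[String]): Expr = expr match {
  case Reference(name) => BoundReference(inputSchema.indexOf(name))
  case alreadyBound    => alreadyBound
}

// bindReference(Reference("b"), Seq("a", "b")).eval(Vector(1, 2)) == 2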
sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
@@ -46,11 +46,9 @@ case class GeneratedAggregate(
partial: Boolean,
groupingExpressions: Seq[Expression],
aggregateExpressions: Seq[NamedExpression],
child: SparkPlan)(@transient sqlContext: SQLContext)
child: SparkPlan)
extends UnaryNode {

println(s"new $codegenEnabled")

override def requiredChildDistribution =
if (partial) {
UnspecifiedDistribution :: Nil
@@ -62,12 +60,9 @@ case class GeneratedAggregate(
}
}

override def otherCopyArgs = sqlContext :: Nil

override def output = aggregateExpressions.map(_.toAttribute)

override def execute() = {
println(s"codegen: $codegenEnabled")
val aggregatesToCompute = aggregateExpressions.flatMap { a =>
a.collect { case agg: AggregateExpression => agg}
}
@@ -160,7 +155,6 @@ case class GeneratedAggregate(
// TODO: Codegening anything other than the updateProjection is probably over kill.
val buffer = newAggregationBuffer(EmptyRow).asInstanceOf[MutableRow]
var currentRow: Row = null
println(codegenEnabled)

while (iter.hasNext) {
currentRow = iter.next()
@@ -172,7 +166,6 @@ case class GeneratedAggregate(
} else {
val buffers = new java.util.HashMap[Row, MutableRow]()

println(codegenEnabled)
var currentRow: Row = null
while (iter.hasNext) {
currentRow = iter.next()
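
Beyond deleting the debug printlns, GeneratedAggregate (like Aggregate above) loses its curried (@transient sqlContext: SQLContext) parameter list and, with it, the otherCopyArgs override. That override existed because Catalyst rebuilds case-class operators reflectively from their first parameter list during tree transformations, so arguments in a second list had to be re-supplied by hand on every copy. Schematically, with hypothetical toy classes:

// Before: a curried context argument forces hand-maintained copy plumbing,
// because tree rewrites only see the first parameter list (productIterator).
class SqlCtx
case class OldAgg(partial: Boolean, child: String)(ctx: SqlCtx) {
  def otherCopyArgs: Seq[AnyRef] = ctx :: Nil // re-supplied on every copy
}

// After: the context is captured from a thread-local at construction time
// (see the SparkPlan changes below), so the second list disappears.
case class NewAgg(partial: Boolean, child: String)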
sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -18,8 +18,9 @@
package org.apache.spark.sql.execution

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{SQLContext, Logging, Row}
import org.apache.spark.sql.{SQLContext, Row}
import org.apache.spark.sql.catalyst.trees
import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
import org.apache.spark.sql.catalyst.expressions._
@@ -28,17 +29,35 @@ import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical.BaseRelation
import org.apache.spark.sql.catalyst.plans.physical._


object SparkPlan {
protected[sql] val currentContext = new ThreadLocal[SQLContext]()
}

/**
* :: DeveloperApi ::
*/
@DeveloperApi
abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging {
abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializable {
self: Product =>

def codegenEnabled = _codegenEnabled
/**
* A handle to the SQL Context that was used to create this plan. Since many operators need
* access to the sqlContext for RDD operations or configuration this field is automatically
* populated by the query planning infrastructure.
*/
@transient
protected val sqlContext = SparkPlan.currentContext.get()

/** Will be set to true during planning if code generation should be used for this operator. */
private[sql] var _codegenEnabled = false
protected def sparkContext = sqlContext.sparkContext

def logger = log

val codegenEnabled: Boolean = if(sqlContext != null) {
sqlContext.codegenEnabled
} else {
false
}

// TODO: Move to `DistributedPlan`
/** Specifies how data is partitioned across different nodes in the cluster. */
@@ -57,16 +76,22 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging {
*/
def executeCollect(): Array[Row] = execute().map(_.copy()).collect()

def newProjection(expressions: Seq[Expression], inputSchema: Seq[Attribute]): Projection =
protected def newProjection(
expressions: Seq[Expression], inputSchema: Seq[Attribute]): Projection = {
log.debug(
s"Creating Projection: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled")
if (codegenEnabled) {
GenerateProjection(expressions, inputSchema)
} else {
new InterpretedProjection(expressions, inputSchema)
}
}

def newMutableProjection(
protected def newMutableProjection(
expressions: Seq[Expression],
inputSchema: Seq[Attribute]): () => MutableProjection = {
log.debug(
s"Creating MutableProj: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled")
if(codegenEnabled) {
GenerateMutableProjection(expressions, inputSchema)
} else {
@@ -75,15 +100,16 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging {
}


def newPredicate(expression: Expression, inputSchema: Seq[Attribute]): (Row) => Boolean = {
protected def newPredicate(
expression: Expression, inputSchema: Seq[Attribute]): (Row) => Boolean = {
if (codegenEnabled) {
GeneratePredicate(expression, inputSchema)
} else {
InterpretedPredicate(expression, inputSchema)
}
}

def newOrdering(order: Seq[SortOrder], inputSchema: Seq[Attribute]): Ordering[Row] = {
protected def newOrdering(order: Seq[SortOrder], inputSchema: Seq[Attribute]): Ordering[Row] = {
if (codegenEnabled) {
GenerateOrdering(order, inputSchema)
} else {
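
Two things happen in SparkPlan: the operator now captures its SQLContext from the thread-local (with codegenEnabled derived eagerly and defaulting to false when no context is set), and the newProjection / newMutableProjection / newPredicate / newOrdering factories become protected, log their inputs, and all share one shape — compile when codegen is on, fall back to the interpreted implementation otherwise. A skeleton of that choice, with simplified signatures rather than the Spark source:

// Codegen-or-interpreted factory skeleton. By-name parameters keep the
// untaken branch unevaluated, so compilation cost is only paid when enabled.
def newEvaluator[T](codegenEnabled: Boolean)
                   (compiled: => T)(interpreted: => T): T =
  if (codegenEnabled) compiled else interpreted

// Usage sketch: the compiling branch is never touched when codegen is off.
val ordering: Ordering[Int] =
  newEvaluator(codegenEnabled = false)(sys.error("never compiled"))(Ordering.Int)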
sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -39,7 +39,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
// no predicate can be evaluated by matching hash keys
case logical.Join(left, right, LeftSemi, condition) =>
execution.LeftSemiJoinBNL(
planLater(left), planLater(right), condition)(sqlContext) :: Nil
planLater(left), planLater(right), condition) :: Nil
case _ => Nil
}
}
@@ -58,7 +58,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
condition: Option[Expression],
side: BuildSide) = {
val broadcastHashJoin = execution.BroadcastHashJoin(
leftKeys, rightKeys, side, planLater(left), planLater(right))(sqlContext)
leftKeys, rightKeys, side, planLater(left), planLater(right))
condition.map(Filter(_, broadcastHashJoin)).getOrElse(broadcastHashJoin) :: Nil
}

@@ -118,7 +118,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
partial = true,
groupingExpressions,
partialComputation,
planLater(child))(sqlContext))(sqlContext) :: Nil
planLater(child))) :: Nil

// Cases where some aggregate can not be codegened
case PartialAggregation(
Expand All @@ -135,7 +135,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
partial = true,
groupingExpressions,
partialComputation,
planLater(child))(sqlContext))(sqlContext) :: Nil
planLater(child))) :: Nil

case _ => Nil
}
@@ -153,7 +153,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case logical.Join(left, right, joinType, condition) =>
execution.BroadcastNestedLoopJoin(
planLater(left), planLater(right), joinType, condition)(sqlContext) :: Nil
planLater(left), planLater(right), joinType, condition) :: Nil
case _ => Nil
}
}
@@ -175,7 +175,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
object TakeOrdered extends Strategy {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case logical.Limit(IntegerLiteral(limit), logical.Sort(order, child)) =>
execution.TakeOrdered(limit, order, planLater(child))(sqlContext) :: Nil
execution.TakeOrdered(limit, order, planLater(child)) :: Nil
case _ => Nil
}
}
@@ -187,9 +187,9 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
val relation =
ParquetRelation.create(path, child, sparkContext.hadoopConfiguration)
// Note: overwrite=false because otherwise the metadata we just created will be deleted
InsertIntoParquetTable(relation, planLater(child), overwrite=false)(sqlContext) :: Nil
InsertIntoParquetTable(relation, planLater(child), overwrite=false) :: Nil
case logical.InsertIntoTable(table: ParquetRelation, partition, child, overwrite) =>
InsertIntoParquetTable(table, planLater(child), overwrite)(sqlContext) :: Nil
InsertIntoParquetTable(table, planLater(child), overwrite) :: Nil
case PhysicalOperation(projectList, filters: Seq[Expression], relation: ParquetRelation) =>
val prunePushedDownFilters =
if (sparkContext.conf.getBoolean(ParquetFilters.PARQUET_FILTER_PUSHDOWN_ENABLED, true)) {
@@ -218,7 +218,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
projectList,
filters,
prunePushedDownFilters,
ParquetTableScan(_, relation, filters)(sqlContext)) :: Nil
ParquetTableScan(_, relation, filters)) :: Nil

case _ => Nil
}
@@ -243,7 +243,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case logical.Distinct(child) =>
execution.Aggregate(
partial = false, child.output, child.output, planLater(child))(sqlContext) :: Nil
partial = false, child.output, child.output, planLater(child)) :: Nil
case logical.Sort(sortExprs, child) =>
// This sort is a global sort. Its requiredDistribution will be an OrderedDistribution.
execution.Sort(sortExprs, global = true, planLater(child)):: Nil
@@ -256,17 +256,17 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case logical.Filter(condition, child) =>
execution.Filter(condition, planLater(child)) :: Nil
case logical.Aggregate(group, agg, child) =>
execution.Aggregate(partial = false, group, agg, planLater(child))(sqlContext) :: Nil
execution.Aggregate(partial = false, group, agg, planLater(child)) :: Nil
case logical.Sample(fraction, withReplacement, seed, child) =>
execution.Sample(fraction, withReplacement, seed, planLater(child)) :: Nil
case logical.LocalRelation(output, data) =>
ExistingRdd(
output,
ExistingRdd.productToRowRdd(sparkContext.parallelize(data, numPartitions))) :: Nil
case logical.Limit(IntegerLiteral(limit), child) =>
execution.Limit(limit, planLater(child))(sqlContext) :: Nil
execution.Limit(limit, planLater(child)) :: Nil
case Unions(unionChildren) =>
execution.Union(unionChildren.map(planLater))(sqlContext) :: Nil
execution.Union(unionChildren.map(planLater)) :: Nil
case logical.Except(left,right) =>
execution.Except(planLater(left),planLater(right)) :: Nil
case logical.Intersect(left, right) =>
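
Every hunk in SparkStrategies is the same mechanical rewrite: construction sites stop passing sqlContext through a second parameter list, since operators now capture it themselves. In toy form, with hypothetical classes mirroring the change:

// Call-site shape before and after the refactor (toy classes).
class SqlCtx
case class LimitBefore(limit: Int, child: String)(ctx: SqlCtx)
case class LimitAfter(limit: Int, child: String)

def planBefore(child: String, ctx: SqlCtx) = LimitBefore(10, child)(ctx) :: Nil
def planAfter(child: String)               = LimitAfter(10, child) :: Nil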