Commit: Address review comments.

marmbrus committed Jul 26, 2014
1 parent 0672e8a commit 1a61293
Showing 8 changed files with 92 additions and 61 deletions.
@@ -65,6 +65,14 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
/** Binds an input expression to a given input schema */
protected def bind(in: InType, inputSchema: Seq[Attribute]): InType

+ /**
+  * A cache of generated classes.
+  *
+  * From the Guava Docs: A Cache is similar to ConcurrentMap, but not quite the same. The most
+  * fundamental difference is that a ConcurrentMap persists all elements that are added to it until
+  * they are explicitly removed. A Cache on the other hand is generally configured to evict entries
+  * automatically, in order to constrain its memory footprint
+  */
protected val cache = CacheBuilder.newBuilder()
.maximumSize(1000)
.build(
@@ -74,9 +82,11 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
}
})

- def apply(expressions: InType, inputSchema: Seq[Attribute]): OutType=
+ /** Generates the requested evaluator binding the given expression(s) to the inputSchema. */
+ def apply(expressions: InType, inputSchema: Seq[Attribute]): OutType =
apply(bind(expressions, inputSchema))

+ /** Generates the requested evaluator given already bound expression(s). */
def apply(expressions: InType): OutType = cache.get(canonicalize(expressions))
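
To make the cache's behavior concrete, here is a minimal, self-contained sketch of the same Guava pattern (the key and value types are invented stand-ins; the real cache keys on canonicalized expression trees and stores compiled evaluator classes):

import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache}

object CacheSketch {
  // Invented stand-ins for the generator's InType/OutType.
  type CanonicalExpr = String
  type CompiledEval = String => Int

  val cache: LoadingCache[CanonicalExpr, CompiledEval] =
    CacheBuilder.newBuilder()
      .maximumSize(1000) // evict (approximately least-recently-used) entries past 1000
      .build(new CacheLoader[CanonicalExpr, CompiledEval] {
        // Invoked once per missing key; this is where compilation would happen.
        override def load(key: CanonicalExpr): CompiledEval =
          input => input.length + key.length
      })

  def main(args: Array[String]): Unit = {
    val first = cache.get("a + b")  // loads (compiles) on first access
    val second = cache.get("a + b") // served from the cache, no recompilation
    println(first eq second)        // true: same cached instance
  }
}

Because apply canonicalizes the expressions before the lookup, trees that differ only cosmetically can share one compiled class.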

/**
@@ -233,7 +243,8 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
case Cast(child @ NumericType(), FloatType) =>
child.castOrNull(c => q"$c.toFloat", IntegerType)

- // Special handling required for timestamps in hive test cases.
+ // Special handling required for timestamps in hive test cases since the toString function
+ // does not match the expected output.
case Cast(e, StringType) if e.dataType != TimestampType =>
val eval = expressionEvaluator(e)
eval.code ++
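
As background for the q"..." fragments in this hunk: the generator assembles trees with Scala 2 quasiquotes and compiles them at runtime through a reflection toolbox. A self-contained sketch of that mechanism (not Spark code; it needs scala-reflect and scala-compiler on the classpath):

import scala.reflect.runtime.currentMirror
import scala.reflect.runtime.universe._
import scala.tools.reflect.ToolBox

object QuasiquoteSketch {
  def main(args: Array[String]): Unit = {
    val toolBox = currentMirror.mkToolBox()

    // Splice one tree into another, the way $c.toFloat is assembled above.
    val c: Tree = q"input"
    val tree: Tree = q"{ val input = 42; $c.toFloat }"

    // Compile and run the assembled tree.
    println(toolBox.eval(tree)) // 42.0
  }
}
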
@@ -355,18 +366,18 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
var $nullTerm = true
var $primitiveTerm: ${termForType(c.dataType)} = ${defaultPrimitive(c.dataType)}
""".children ++
- children.map { c =>
-   val eval = expressionEvaluator(c)
-   q"""
+ children.map { c =>
+   val eval = expressionEvaluator(c)
+   q"""
if($nullTerm) {
..${eval.code}
if(!${eval.nullTerm}) {
$nullTerm = false
$primitiveTerm = ${eval.primitiveTerm}
}
}
"""
}
"""
}

case i @ expressions.If(condition, trueValue, falseValue) =>
val condEval = expressionEvaluator(condition)
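
Written out by hand for two children, the code the children.map block above generates has this shape (a simplified sketch with invented evaluators; the real generator emits one such `if` per child via quasiquotes):

object CoalesceShapeSketch {
  // Invented child evaluators returning (isNull, primitiveValue).
  def evalChild1(): (Boolean, Int) = (true, 0)
  def evalChild2(): (Boolean, Int) = (false, 7)

  def main(args: Array[String]): Unit = {
    var nullTerm = true
    var primitiveTerm = 0 // defaultPrimitive for an integer-typed expression

    // One generated `if` per child: later children are only evaluated
    // while the result is still null.
    if (nullTerm) {
      val (childNull, childValue) = evalChild1()
      if (!childNull) { nullTerm = false; primitiveTerm = childValue }
    }
    if (nullTerm) {
      val (childNull, childValue) = evalChild2()
      if (!childNull) { nullTerm = false; primitiveTerm = childValue }
    }

    println(s"isNull=$nullTerm value=$primitiveTerm") // isNull=false value=7
  }
}
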
@@ -392,8 +403,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
// If there was no match in the partial function above, we fall back on calling the interpreted
// expression evaluator.
val code: Seq[Tree] =
- primitiveEvaluation.lift.apply(e)
-   .getOrElse {
+ primitiveEvaluation.lift.apply(e).getOrElse {
log.debug(s"No rules to generate $e")
val tree = reify { e }
q"""
@@ -70,7 +70,7 @@ package object expressions {
abstract class MutableProjection extends Projection {
def currentValue: Row

- /** Updates the target of this projection to a new MutableRow */
+ /** Uses the given row to store the output of the projection. */
def target(row: MutableRow): MutableProjection
}
}
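
A minimal sketch of the contract the reworded comment describes (invented classes, not Spark's): target redirects where the projection writes, so one caller-owned buffer can be reused across many input rows:

object MutableProjectionSketch {
  final class MutableRow(val values: Array[Any])

  // Projects the sum of the first two columns into slot 0 of the target row.
  final class SumProjection {
    private var dest = new MutableRow(new Array[Any](1))

    // Store all future output in `row` instead of the current buffer.
    def target(row: MutableRow): SumProjection = { dest = row; this }

    def apply(input: Array[Int]): MutableRow = {
      dest.values(0) = input(0) + input(1) // overwrite in place, no allocation
      dest
    }
  }

  def main(args: Array[String]): Unit = {
    val buffer = new MutableRow(new Array[Any](1))
    val project = new SumProjection().target(buffer)
    project(Array(1, 2))
    println(buffer.values(0)) // 3
  }
}
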
@@ -67,42 +67,3 @@ class GeneratedEvaluationSuite extends ExpressionEvaluationSuite {
futures.foreach(Await.result(_, 10.seconds))
}
}

- /**
-  * Overrides our expression evaluation tests to use generated code on mutable rows.
-  */
- class GeneratedMutableEvaluationSuite extends ExpressionEvaluationSuite {
-   override def checkEvaluation(
-       expression: Expression,
-       expected: Any,
-       inputRow: Row = EmptyRow): Unit = {
-     lazy val evaluated = GenerateProjection.expressionEvaluator(expression)
- 
-     val plan = try {
-       GenerateProjection(Alias(expression, s"Optimized($expression)")() :: Nil)
-     } catch {
-       case e: Throwable =>
-         fail(
-           s"""
-             |Code generation of $expression failed:
-             |${evaluated.code.mkString("\n")}
-             |$e
-           """.stripMargin)
-     }
- 
-     val actual = plan(inputRow)
-     val expectedRow = new GenericRow(Array[Any](expected))
-     if (actual.hashCode() != expectedRow.hashCode()) {
-       fail(
-         s"""
-           |Mismatched hashCodes for values: $actual, $expectedRow
-           |Hash Codes: ${actual.hashCode()} != ${expectedRow.hashCode()}
-           |${evaluated.code.mkString("\n")}
-         """.stripMargin)
-     }
-     if (actual != expectedRow) {
-       val input = if(inputRow == EmptyRow) "" else s", input: $inputRow"
-       fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input")
-     }
-   }
- }
@@ -0,0 +1,61 @@
+ /*
+  * Licensed to the Apache Software Foundation (ASF) under one or more
+  * contributor license agreements. See the NOTICE file distributed with
+  * this work for additional information regarding copyright ownership.
+  * The ASF licenses this file to You under the Apache License, Version 2.0
+  * (the "License"); you may not use this file except in compliance with
+  * the License. You may obtain a copy of the License at
+  *
+  *    http://www.apache.org/licenses/LICENSE-2.0
+  *
+  * Unless required by applicable law or agreed to in writing, software
+  * distributed under the License is distributed on an "AS IS" BASIS,
+  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+  * See the License for the specific language governing permissions and
+  * limitations under the License.
+  */
+ 
+ package org.apache.spark.sql.catalyst.optimizer
+ 
+ import org.apache.spark.sql.catalyst.expressions._
+ import org.apache.spark.sql.catalyst.dsl.expressions._
+ import org.apache.spark.sql.catalyst.expressions.codegen._
+ 
+ /**
+  * Overrides our expression evaluation tests to use generated code on mutable rows.
+  */
+ class GeneratedMutableEvaluationSuite extends ExpressionEvaluationSuite {
+   override def checkEvaluation(
+       expression: Expression,
+       expected: Any,
+       inputRow: Row = EmptyRow): Unit = {
+     lazy val evaluated = GenerateProjection.expressionEvaluator(expression)
+ 
+     val plan = try {
+       GenerateProjection(Alias(expression, s"Optimized($expression)")() :: Nil)
+     } catch {
+       case e: Throwable =>
+         fail(
+           s"""
+             |Code generation of $expression failed:
+             |${evaluated.code.mkString("\n")}
+             |$e
+           """.stripMargin)
+     }
+ 
+     val actual = plan(inputRow)
+     val expectedRow = new GenericRow(Array[Any](expected))
+     if (actual.hashCode() != expectedRow.hashCode()) {
+       fail(
+         s"""
+           |Mismatched hashCodes for values: $actual, $expectedRow
+           |Hash Codes: ${actual.hashCode()} != ${expectedRow.hashCode()}
+           |${evaluated.code.mkString("\n")}
+         """.stripMargin)
+     }
+     if (actual != expectedRow) {
+       val input = if(inputRow == EmptyRow) "" else s", input: $inputRow"
+       fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input")
+     }
+   }
+ }
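
A note on why the suite asserts on hashCode() as well as equality: a generated row class must hash exactly like a GenericRow holding the same values, or hash-based operators would misgroup rows. The property under test, sketched with invented row classes:

object RowHashSketch {
  final class GenericRowSketch(val values: Seq[Any]) {
    override def hashCode(): Int = values.hashCode()
  }

  // A "generated" row stores fields directly but must hash and compare
  // like the generic representation of the same values.
  final class GeneratedRowSketch(a: Any, b: Any) {
    def values: Seq[Any] = Seq(a, b)
    override def hashCode(): Int = values.hashCode()
    override def equals(other: Any): Boolean = other match {
      case g: GenericRowSketch => g.values == values
      case g: GeneratedRowSketch => g.values == values
      case _ => false
    }
  }

  def main(args: Array[String]): Unit = {
    val generic = new GenericRowSketch(Seq(1, "a"))
    val generated = new GeneratedRowSketch(1, "a")
    assert(generated.hashCode() == generic.hashCode())
    assert(generated == generic)
    println("hashCode and equality agree")
  }
}
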
@@ -107,7 +107,7 @@ case class GeneratedAggregate(

val computationSchema = computeFunctions.flatMap(_.schema)

- val resultMap = aggregatesToCompute.zip(computeFunctions).map {
+ val resultMap: Map[Long, Expression] = aggregatesToCompute.zip(computeFunctions).map {
case (agg, func) => agg.id -> func.result
}.toMap

@@ -116,34 +116,33 @@ case class GeneratedAggregate(
case (e, i) => (e, Alias(e, s"GroupingExpr$i")())
}

- val groupMap = namedGroups.map { case (k, v) => k -> v.toAttribute}.toMap
+ val groupMap: Map[Expression, Attribute] =
+   namedGroups.map { case (k, v) => k -> v.toAttribute}.toMap

+ // The set of expressions that produce the final output given the aggregation buffer and the
+ // grouping expressions.
val resultExpressions = aggregateExpressions.map(_.transform {
case e: Expression if resultMap.contains(e.id) => resultMap(e.id)
case e: Expression if groupMap.contains(e) => groupMap(e)
})

child.execute().mapPartitions { iter =>
// Builds a new custom class for holding the results of aggregation for a group.
- @transient
val newAggregationBuffer =
newProjection(computeFunctions.flatMap(_.initialValues), child.output)

// A projection that is used to update the aggregate values for a group given a new tuple.
// This projection should be targeted at the current values for the group and then applied
// to a joined row of the current values with the new input row.
- @transient
val updateProjection =
newMutableProjection(
computeFunctions.flatMap(_.update),
computeFunctions.flatMap(_.schema) ++ child.output)()

// A projection that computes the group given an input tuple.
- @transient
val groupProjection = newProjection(groupingExpressions, child.output)

// A projection that produces the final result, given a computation.
- @transient
val resultProjectionBuilder =
newMutableProjection(
resultExpressions,
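
The explicit Map types added above feed the resultExpressions transform earlier in this hunk, which swaps aggregate references for expressions over the aggregation buffer and grouping expressions for their named attributes. A self-contained sketch of that rewrite on a toy expression tree (invented types; Catalyst's transform works the same way, applying the rule top-down):

object SubstitutionSketch {
  sealed trait Expr {
    def transform(rule: PartialFunction[Expr, Expr]): Expr = {
      // Apply the rule at this node, then recurse into the children.
      val applied = rule.applyOrElse(this, (e: Expr) => e)
      applied match {
        case Add(l, r) => Add(l.transform(rule), r.transform(rule))
        case leaf => leaf
      }
    }
  }
  final case class Add(left: Expr, right: Expr) extends Expr
  final case class AggRef(id: Long) extends Expr
  final case class BufferSlot(ordinal: Int) extends Expr

  def main(args: Array[String]): Unit = {
    val resultMap: Map[Long, Expr] = Map(7L -> BufferSlot(0))
    val expr: Expr = Add(AggRef(7L), AggRef(7L))
    val rewritten = expr.transform {
      case AggRef(id) if resultMap.contains(id) => resultMap(id)
    }
    println(rewritten) // Add(BufferSlot(0),BufferSlot(0))
  }
}
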
@@ -155,10 +154,11 @@ case class GeneratedAggregate(
// TODO: Codegening anything other than the updateProjection is probably over kill.
val buffer = newAggregationBuffer(EmptyRow).asInstanceOf[MutableRow]
var currentRow: Row = null
+ updateProjection.target(buffer)

while (iter.hasNext) {
currentRow = iter.next()
- updateProjection.target(buffer)(joinedRow(buffer, currentRow))
+ updateProjection(joinedRow(buffer, currentRow))
}

val resultProjection = resultProjectionBuilder()
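
The hoist above works because the projection's destination never changes within a partition: target(buffer) is called once, and the per-row loop only applies the projection. Schematically, with invented types:

object HoistedTargetSketch {
  final class Buffer { var sum = 0 }

  final class UpdateProjection {
    private var dest: Buffer = _
    def target(b: Buffer): UpdateProjection = { dest = b; this }
    def apply(row: Int): Unit = dest.sum += row // mutate the buffer in place
  }

  def main(args: Array[String]): Unit = {
    val buffer = new Buffer
    val updateProjection = new UpdateProjection
    updateProjection.target(buffer) // set once per partition

    val iter = Iterator(1, 2, 3)
    while (iter.hasNext) {
      updateProjection(iter.next()) // per row: apply only, no re-targeting
    }
    println(buffer.sum) // 6
  }
}
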
@@ -187,7 +187,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
val relation =
ParquetRelation.create(path, child, sparkContext.hadoopConfiguration)
// Note: overwrite=false because otherwise the metadata we just created will be deleted
- InsertIntoParquetTable(relation, planLater(child), overwrite=false) :: Nil
+ InsertIntoParquetTable(relation, planLater(child), overwrite = false) :: Nil
case logical.InsertIntoTable(table: ParquetRelation, partition, child, overwrite) =>
InsertIntoParquetTable(table, planLater(child), overwrite) :: Nil
case PhysicalOperation(projectList, filters: Seq[Expression], relation: ParquetRelation) =>
@@ -304,9 +304,7 @@ case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNod

leftResults.cartesian(rightResults).mapPartitions { iter =>
val joinedRow = new JoinedRow
- iter.map {
-   case (l: Row, r: Row) => joinedRow(l, r)
- }
+ iter.map(r => joinedRow(r._1, r._2))
}
}
}
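
Context for the rewrite above: JoinedRow is a reusable view over two rows, so the single instance allocated per partition is repositioned for every pair rather than copied, and consumers must use each joined row before the next one is produced. A sketch with invented types:

object JoinedRowSketch {
  // Invented stand-in for Spark's JoinedRow: a mutable view over two values.
  final class JoinedPair {
    private var left: Int = _
    private var right: Int = _
    def apply(l: Int, r: Int): JoinedPair = { left = l; right = r; this }
    override def toString: String = s"($left, $right)"
  }

  def main(args: Array[String]): Unit = {
    val pairs = Iterator((1, 10), (2, 20), (3, 30))
    val joinedRow = new JoinedPair // one instance per partition
    // Every element below is the same object, repositioned over new inputs,
    // so it must be consumed immediately (as foreach does here).
    pairs.map(r => joinedRow(r._1, r._2)).foreach(println)
  }
}
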
@@ -17,7 +17,6 @@

package org.apache.spark.sql.parquet

- import org.apache.spark.sql.execution.SparkPlan
import org.scalatest.{BeforeAndAfterAll, FunSuiteLike}

import parquet.hadoop.ParquetFileWriter
@@ -26,13 +25,15 @@ import parquet.schema.MessageTypeParser

import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.hadoop.mapreduce.Job

import org.apache.spark.SparkContext
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.{SqlLexical, SqlParser}
import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedAttribute}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.types.{BooleanType, IntegerType}
import org.apache.spark.sql.catalyst.util.getTempFilePath
+ import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.TestSQLContext._
import org.apache.spark.util.Utils
