Skip to content

Commit

Permalink
[SPARK-10316] [SQL] respect nondeterministic expressions in PhysicalO…
Browse files Browse the repository at this point in the history
…peration

We did a lot of special handling for non-deterministic expressions in `Optimizer`. However, `PhysicalOperation` just collects all Projects and Filters and mess it up. We should respect the operators order caused by non-deterministic expressions in `PhysicalOperation`.

Author: Wenchen Fan <cloud0fan@outlook.com>

Closes apache#8486 from cloud-fan/fix.
  • Loading branch information
cloud-fan authored and marmbrus committed Sep 8, 2015
1 parent 5b2192e commit 5fd5795
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,35 +17,12 @@

package org.apache.spark.sql.catalyst.planning

import scala.annotation.tailrec

import org.apache.spark.Logging
import org.apache.spark.sql.catalyst.trees.TreeNodeRef
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._

/**
* A pattern that matches any number of filter operations on top of another relational operator.
* Adjacent filter operators are collected and their conditions are broken up and returned as a
* sequence of conjunctive predicates.
*
* @return A tuple containing a sequence of conjunctive predicates that should be used to filter the
* output and a relational operator.
*/
object FilteredOperation extends PredicateHelper {
type ReturnType = (Seq[Expression], LogicalPlan)

def unapply(plan: LogicalPlan): Option[ReturnType] = Some(collectFilters(Nil, plan))

@tailrec
private def collectFilters(filters: Seq[Expression], plan: LogicalPlan): ReturnType = plan match {
case Filter(condition, child) =>
collectFilters(filters ++ splitConjunctivePredicates(condition), child)
case other => (filters, other)
}
}

/**
* A pattern that matches any number of project or filter operations on top of another relational
* operator. All filter operators are collected and their conditions are broken up and returned
Expand All @@ -62,8 +39,9 @@ object PhysicalOperation extends PredicateHelper {
}

/**
* Collects projects and filters, in-lining/substituting aliases if necessary. Here are two
* examples for alias in-lining/substitution. Before:
* Collects all deterministic projects and filters, in-lining/substituting aliases if necessary.
* Here are two examples for alias in-lining/substitution.
* Before:
* {{{
* SELECT c1 FROM (SELECT key AS c1 FROM t1) t2 WHERE c1 > 10
* SELECT c1 AS c2 FROM (SELECT key AS c1 FROM t1) t2 WHERE c1 > 10
Expand All @@ -74,15 +52,15 @@ object PhysicalOperation extends PredicateHelper {
* SELECT key AS c2 FROM t1 WHERE key > 10
* }}}
*/
def collectProjectsAndFilters(plan: LogicalPlan):
private def collectProjectsAndFilters(plan: LogicalPlan):
(Option[Seq[NamedExpression]], Seq[Expression], LogicalPlan, Map[Attribute, Expression]) =
plan match {
case Project(fields, child) =>
case Project(fields, child) if fields.forall(_.deterministic) =>
val (_, filters, other, aliases) = collectProjectsAndFilters(child)
val substitutedFields = fields.map(substitute(aliases)).asInstanceOf[Seq[NamedExpression]]
(Some(substitutedFields), filters, other, collectAliases(substitutedFields))

case Filter(condition, child) =>
case Filter(condition, child) if condition.deterministic =>
val (fields, filters, other, aliases) = collectProjectsAndFilters(child)
val substitutedCondition = substitute(aliases)(condition)
(fields, filters ++ splitConjunctivePredicates(substitutedCondition), other, aliases)
Expand All @@ -91,11 +69,11 @@ object PhysicalOperation extends PredicateHelper {
(None, Nil, other, Map.empty)
}

def collectAliases(fields: Seq[Expression]): Map[Attribute, Expression] = fields.collect {
private def collectAliases(fields: Seq[Expression]): Map[Attribute, Expression] = fields.collect {
case a @ Alias(child, _) => a.toAttribute -> child
}.toMap

def substitute(aliases: Map[Attribute, Expression])(expr: Expression): Expression = {
private def substitute(aliases: Map[Attribute, Expression])(expr: Expression): Expression = {
expr.transform {
case a @ Alias(ref: AttributeReference, name) =>
aliases.get(ref).map(Alias(_, name)(a.exprId, a.qualifiers)).getOrElse(a)
Expand Down
12 changes: 12 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ import java.io.File
import scala.language.postfixOps
import scala.util.Random

import org.scalatest.Matchers._

import org.apache.spark.sql.catalyst.plans.logical.OneRowRelation
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
Expand Down Expand Up @@ -895,4 +897,14 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
.orderBy(sum('j))
checkAnswer(query, Row(1, 2))
}

test("SPARK-10316: respect non-deterministic expressions in PhysicalOperation") {
val input = sqlContext.read.json(sqlContext.sparkContext.makeRDD(
(1 to 10).map(i => s"""{"id": $i}""")))

val df = input.select($"id", rand(0).as('r))
df.as("a").join(df.filter($"r" < 0.5).as("b"), $"a.id" === $"b.id").collect().foreach { row =>
assert(row.getDouble(1) - row.getDouble(3) === 0.0 +- 0.001)
}
}
}

0 comments on commit 5fd5795

Please sign in to comment.