Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ class Analyzer(
ResolveWindowOrder ::
ResolveWindowFrame ::
ResolveNaturalAndUsingJoin ::
InsertRelationScanner ::
ExtractWindowExpressions ::
GlobalAggregates ::
ResolveAggregateFunctions ::
Expand Down Expand Up @@ -466,6 +467,36 @@ class Analyzer(
}
}

/**
 * Inserts a [[Scanner]] operator over every [[MultiInstanceRelation]]; the operator can hold
 * both a project list and filter predicates for the scan.
 */
object InsertRelationScanner extends Rule[LogicalPlan] {

  /**
   * Collects relations that already sit directly under a [[Scanner]] (possibly through a
   * [[SubqueryAlias]]) or under an [[InsertIntoTable]] operator; such relations must not be
   * wrapped in another [[Scanner]].
   */
  def collectProcessedRelations(plan: LogicalPlan): Set[LogicalPlan] = {
    plan.collect {
      case Scanner(_, _, r: MultiInstanceRelation) =>
        r
      case InsertIntoTable(r: MultiInstanceRelation, _, _, _, _) =>
        r
      case Scanner(_, _, SubqueryAlias(_, r: MultiInstanceRelation, _)) =>
        r
    }.toSet
  }

  def apply(plan: LogicalPlan): LogicalPlan = {
    val relationsProcessed = collectProcessedRelations(plan)

    plan transformUp {
      // Wrap bare relations in a pass-through Scanner: full output, no filters.
      case relation: MultiInstanceRelation if !relationsProcessed.contains(relation) =>
        Scanner(relation.output, Nil, relation)
      // InsertIntoTable must target the relation itself, so strip a Scanner above it.
      case i @ InsertIntoTable(Scanner(_, _, r: MultiInstanceRelation), _, _, _, _) =>
        i.copy(table = r)
    }
  }
}

/**
* Replaces [[UnresolvedAttribute]]s with concrete [[AttributeReference]]s from
* a logical plan node's children.
Expand Down Expand Up @@ -2100,6 +2131,11 @@ object CleanupAliases extends Rule[LogicalPlan] {
projectList.map(trimNonTopLevelAliases(_).asInstanceOf[NamedExpression])
Project(cleanedProjectList, child)

case Scanner(projectList, filters, child) =>
val cleanedProjectList =
projectList.map(trimNonTopLevelAliases(_).asInstanceOf[NamedExpression])
Scanner(cleanedProjectList, filters, child)

case Aggregate(grouping, aggs, child) =>
val cleanedAggs = aggs.map(trimNonTopLevelAliases(_).asInstanceOf[NamedExpression])
Aggregate(grouping.map(trimAliases), cleanedAggs, child)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Scanner, SubqueryAlias}
import org.apache.spark.sql.catalyst.util.StringUtils

object SessionCatalog {
Expand Down Expand Up @@ -417,9 +417,9 @@ class SessionCatalog(
val view = Option(metadata.tableType).collect {
case CatalogTableType.VIEW => name
}
SubqueryAlias(relationAlias, SimpleCatalogRelation(db, metadata), view)
SubqueryAlias(relationAlias, Scanner(SimpleCatalogRelation(db, metadata)), view)
} else {
SubqueryAlias(relationAlias, tempTables(table), Option(name))
SubqueryAlias(relationAlias, Scanner(tempTables(table)), Option(name))
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,10 @@ package object dsl {

def filter[T : Encoder](func: T => Boolean): LogicalPlan = TypedFilter(func, logicalPlan)

/** Wraps this logical plan in a [[Scanner]] operator with no filter predicates. */
def scanner(): LogicalPlan = Scanner(logicalPlan)

/** Wraps this logical plan in a [[Scanner]] carrying `condition` as a filter predicate. */
def scanner(condition: Expression): LogicalPlan = Scanner(condition, logicalPlan)

def serialize[T : Encoder]: LogicalPlan = CatalystSerde.serialize[T](logicalPlan)

def deserialize[T : Encoder]: LogicalPlan = CatalystSerde.deserialize[T](logicalPlan)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf)
CombineFilters,
CombineLimits,
CombineUnions,
CombineScanners,
// Constant folding and strength reduction
NullPropagation,
FoldablePropagation,
Expand All @@ -102,6 +103,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf)
RemoveDispensableExpressions,
SimplifyBinaryComparison,
PruneFilters,
PruneScanners,
EliminateSorts,
SimplifyCasts,
SimplifyCaseConversionExpressions,
Expand Down Expand Up @@ -423,7 +425,12 @@ object ColumnPruning extends Rule[LogicalPlan] {
case p @ Project(_, child) =>
val required = child.references ++ p.references
if ((child.inputSet -- required).nonEmpty) {
val newChildren = child.children.map(c => prunedChild(c, required))
val newChildren = child.children.map { c =>
c match {
case r: MultiInstanceRelation => r
case _ => prunedChild(c, required)
}
}
p.copy(child = child.withNewChildren(newChildren))
} else {
p
Expand Down Expand Up @@ -556,6 +563,23 @@ object InferFiltersFromConstraints extends Rule[LogicalPlan] with PredicateHelpe
filter
}

case scanner @ Scanner(projectList, filters, child) =>
val aliasMap = AttributeMap(projectList.collect {
case a: Alias => (a.toAttribute, a.child)
})

val newFilters = scanner.constraints --
(child.constraints ++ filters)

if (newFilters.nonEmpty) {
val replaced = replaceAlias(newFilters.reduce(And), aliasMap)
Scanner(projectList,
filters ++ (splitConjunctivePredicates(replaced).toSet -- filters),
child)
} else {
scanner
}

case join @ Join(left, right, joinType, conditionOpt) =>
// Only consider constraints that can be pushed down completely to either the left or the
// right child
Expand Down Expand Up @@ -601,6 +625,56 @@ object CombineFilters extends Rule[LogicalPlan] with PredicateHelper {
}
}

/**
 * Combines adjacent [[Scanner]] operators with [[Project]]s and [[Filter]]s as well as other
 * [[Scanner]]s, merging the non-redundant conditions into one conjunctive predicate.
 */
object CombineScanners extends Rule[LogicalPlan] with PredicateHelper {

  def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
    case Project(fields, child: Scanner) =>
      // Fold the projection into the scanner; alias substitution happens inside
      // buildCleanedProjectList. (A previously computed, unused alias map was removed here.)
      child.copy(projectList = buildCleanedProjectList(fields, child.projectList))

    case Filter(condition, child: Scanner) if child.projectList.forall(_.deterministic) =>
      // Push the filter into the scanner. Attributes produced by aliases in the scanner's
      // project list must first be replaced by the expressions they alias.
      val aliasMap = AttributeMap(child.projectList.collect {
        case a: Alias => (a.toAttribute, a.child)
      })

      val newFilters =
        splitConjunctivePredicates(replaceAlias(condition, aliasMap)) ++ child.filters
      child.copy(filters = newFilters)

    case Scanner(fields, filters, child: Scanner) =>
      // NOTE(review): the upper `filters` are concatenated without rewriting through the
      // child's aliases — assumes stacked scanners are built with filters already expressed
      // in terms of the child's input; confirm against the rules that create them.
      val newFilters = filters ++ child.filters
      child.copy(projectList = buildCleanedProjectList(fields, child.projectList),
        filters = newFilters)
  }

  /**
   * Merges an upper project list into a lower one, substituting attributes the lower list
   * produces via aliases and trimming redundant aliases afterwards.
   */
  private def buildCleanedProjectList(
      upper: Seq[NamedExpression],
      lower: Seq[NamedExpression]): Seq[NamedExpression] = {
    // Create a map of Aliases to their values from the lower projection.
    // e.g., 'SELECT ... FROM (SELECT a + b AS c, d ...)' produces Map(c -> Alias(a + b, c)).
    val aliases = AttributeMap(lower.collect {
      case a: Alias => (a.toAttribute, a)
    })

    // Substitute any attributes that are produced by the lower projection, so that we safely
    // eliminate it.
    // e.g., 'SELECT c + 1 FROM (SELECT a + b AS c ...' produces 'SELECT a + b + 1 ...'
    // Use transformUp to prevent infinite recursion.
    val rewrittenUpper = upper.map(_.transformUp {
      case a: Attribute => aliases.getOrElse(a, a)
    })
    // Collapsing the upper and lower projections may introduce unnecessary Aliases; trim them.
    rewrittenUpper.map { p =>
      CleanupAliases.trimNonTopLevelAliases(p).asInstanceOf[NamedExpression]
    }
  }
}

/**
* Removes no-op SortOrder from Sort
*/
Expand Down Expand Up @@ -644,6 +718,31 @@ object PruneFilters extends Rule[LogicalPlan] with PredicateHelper {
}
}

/**
 * Removes filter predicates in a [[Scanner]] operator that can be evaluated trivially:
 *
 *  - predicates that are literally `true` are simply dropped;
 *  - if any predicate is literally `false` or `null`, the scan can never produce a row,
 *    so the whole operator is replaced by a scan over an empty [[LocalRelation]] that
 *    preserves the original output attributes.
 */
object PruneScanners extends Rule[LogicalPlan] with PredicateHelper {
  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
    case s @ Scanner(projectList, filters, _) =>
      // Drop filters that always evaluate to true.
      val newFilters = filters.filterNot(_.fastEquals(Literal(true, BooleanType)))

      if (newFilters.exists { filter =>
        filter.fastEquals(Literal(false, BooleanType)) || filter.fastEquals(Literal(null))}) {
        // At least one filter always evaluates to null or false:
        // replace the input with an empty relation.
        Scanner(LocalRelation(projectList.map(_.toAttribute), data = Seq.empty))
      } else if (newFilters.length == filters.length) {
        // Nothing was pruned; keep the original operator to avoid a needless copy.
        s
      } else {
        s.copy(filters = newFilters)
      }
  }
}

/**
* Pushes [[Filter]] operators through many operators iff:
* 1) the operator is deterministic
Expand Down Expand Up @@ -751,6 +850,29 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper {
filter
}

case filter @ Filter(condition, child: Scanner) if child.projectList.forall(_.deterministic) =>
// Deterministic parts placed before any non-deterministic predicates in [[Filter]] could
// be pushed down to [[Scanner]], combine with `Scanner.filters`.
val aliasMap = AttributeMap(child.projectList.collect {
case a: Alias => (a.toAttribute, a.child)
})

val (pushDown, stayUp) = splitConjunctivePredicates(condition).span(_.deterministic)

if (pushDown.nonEmpty) {
val replaced = replaceAlias(pushDown.reduce(And), aliasMap)
val newScanner =
child.copy(filters =
child.filters ++ (splitConjunctivePredicates(replaced).toSet -- child.filters))
if (stayUp.nonEmpty) {
Filter(stayUp.reduceLeft(And), newScanner)
} else {
newScanner
}
} else {
filter
}

case filter @ Filter(condition, u: UnaryNode)
if canPushThrough(u) && u.expressions.forall(_.deterministic) =>
pushDownPredicate(filter, u.child) { predicate =>
Expand Down Expand Up @@ -1022,15 +1144,16 @@ object DecimalAggregates extends Rule[LogicalPlan] {
/**
 * Converts local operations (i.e. ones that don't require data exchange) on LocalRelation to
 * another LocalRelation.
 *
 * This is relatively simple as it currently handles only a single case: a filter-free
 * Scanner over a LocalRelation.
 */
object ConvertToLocalRelation extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
  // A filter-free Scanner over a LocalRelation can be evaluated eagerly, provided every
  // projected expression is evaluable without a full query plan.
  // (Diff residue removed: the superseded `Project` case header had been left in place.)
  case Scanner(projectList, filters, LocalRelation(output, data))
      if !projectList.exists(hasUnevaluableExpr) && filters.isEmpty =>
    val projection = new InterpretedProjection(projectList, output)
    val newProjectList = projectList.map(_.toAttribute)
    val newRelation = LocalRelation(newProjectList, data.map(projection))

    // Keep a Scanner on top so the plan retains the canonical scan shape for later rules.
    Scanner(newProjectList, filters, newRelation)
}

private def hasUnevaluableExpr(expr: Expression): Boolean = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import org.apache.spark.sql.catalyst.rules._
*/
object PropagateEmptyRelation extends Rule[LogicalPlan] with PredicateHelper {
// True iff `plan` is a LocalRelation with no rows, either bare or directly
// underneath a Scanner operator.
private def isEmptyLocalRelation(plan: LogicalPlan): Boolean = {
  val localRelation = plan match {
    case Scanner(_, _, r: LocalRelation) => Some(r)
    case r: LocalRelation => Some(r)
    case _ => None
  }
  localRelation.exists(_.data.isEmpty)
}
Expand All @@ -43,7 +44,7 @@ object PropagateEmptyRelation extends Rule[LogicalPlan] with PredicateHelper {
e.collectFirst { case _: AggregateFunction => () }.isDefined
}

// Builds an empty relation that preserves `plan`'s output attributes, wrapped in a Scanner
// so the result has the same canonical scan shape as other relations.
// (Diff residue removed: the superseded duplicate definition of `empty` had been left in place.)
private def empty(plan: LogicalPlan) = Scanner(LocalRelation(plan.output, data = Seq.empty))

def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
case p: Union if p.children.forall(isEmptyLocalRelation) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,27 +53,55 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
}

// Filter the plan by applying left semi and left anti joins.
withSubquery.foldLeft(newFilter) {
case (p, PredicateSubquery(sub, conditions, _, _)) =>
val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p)
Join(outerPlan, sub, LeftSemi, joinCond)
case (p, Not(PredicateSubquery(sub, conditions, false, _))) =>
val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p)
Join(outerPlan, sub, LeftAnti, joinCond)
case (p, Not(PredicateSubquery(sub, conditions, true, _))) =>
// This is a NULL-aware (left) anti join (NAAJ) e.g. col NOT IN expr
// Construct the condition. A NULL in one of the conditions is regarded as a positive
// result; such a row will be filtered out by the Anti-Join operator.

// Note that will almost certainly be planned as a Broadcast Nested Loop join.
// Use EXISTS if performance matters to you.
val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p)
val anyNull = splitConjunctivePredicates(joinCond.get).map(IsNull).reduceLeft(Or)
Join(outerPlan, sub, LeftAnti, Option(Or(anyNull, joinCond.get)))
case (p, predicate) =>
val (newCond, inputPlan) = rewriteExistentialExpr(Seq(predicate), p)
Project(p.output, Filter(newCond.get, inputPlan))
rewritePredicateSubquery(withSubquery, newFilter)

case Scanner(projectList, filters, child) if filters.exists(
PredicateSubquery.hasPredicateSubquery(_)) =>
val (withSubquery, withoutSubquery) =
filters.partition(PredicateSubquery.hasPredicateSubquery)

val newFilter: LogicalPlan = withoutSubquery match {
case Nil => Scanner(child)
case conditions => Scanner(conditions.reduce(And), child)
}

val newPlan = rewritePredicateSubquery(withSubquery, newFilter)

if (newPlan.output.forall(projectList.contains(_))) {
newPlan
} else {
Project(projectList, newPlan)
}
}

/**
 * Re-constructs the plan by rewriting each predicate subquery in `predicateSubquerys`
 * into a left semi or left anti join stacked on top of `filter`.
 *
 * @param predicateSubquerys predicates that each contain a [[PredicateSubquery]]
 * @param filter the plan the rewritten joins are applied on top of
 */
private def rewritePredicateSubquery(
    predicateSubquerys: Seq[Expression],
    filter: LogicalPlan): LogicalPlan = {
  // Each subquery predicate becomes one join layered on the accumulated plan.
  predicateSubquerys.foldLeft(filter) {
    case (p, PredicateSubquery(sub, conditions, _, _)) =>
      // EXISTS / IN: keep outer rows that have at least one match -> left semi join.
      val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p)
      Join(outerPlan, sub, LeftSemi, joinCond)
    case (p, Not(PredicateSubquery(sub, conditions, false, _))) =>
      // NOT EXISTS (non null-aware): keep outer rows with no match -> left anti join.
      val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p)
      Join(outerPlan, sub, LeftAnti, joinCond)
    case (p, Not(PredicateSubquery(sub, conditions, true, _))) =>
      // This is a NULL-aware (left) anti join (NAAJ) e.g. col NOT IN expr
      // Construct the condition. A NULL in one of the conditions is regarded as a positive
      // result; such a row will be filtered out by the Anti-Join operator.

      // Note that will almost certainly be planned as a Broadcast Nested Loop join.
      // Use EXISTS if performance matters to you.
      val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p)
      val anyNull = splitConjunctivePredicates(joinCond.get).map(IsNull).reduceLeft(Or)
      Join(outerPlan, sub, LeftAnti, Option(Or(anyNull, joinCond.get)))
    case (p, predicate) =>
      // Subquery nested inside a larger expression: extract it via
      // rewriteExistentialExpr, then filter on the rewritten condition.
      val (newCond, inputPlan) = rewriteExistentialExpr(Seq(predicate), p)
      Project(p.output, Filter(newCond.get, inputPlan))
  }
}

/**
Expand Down
Loading