[SPARK-32941][SQL] Optimize UpdateFields expression chain and put the rule early in Analysis phase #29812

Closed · wants to merge 6 commits

@@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.SubExprUtils._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.expressions.objects._
+import org.apache.spark.sql.catalyst.optimizer.OptimizeUpdateFields
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
@@ -207,6 +208,11 @@ class Analyzer(

lazy val batches: Seq[Batch] = Seq(
Batch("Substitution", fixedPoint,
+// This rule optimizes `UpdateFields` expression chains, and so looks more like an
+// optimization rule. However, when manipulating deeply nested schemas, the `UpdateFields`
+// expression tree can become very complex and make analysis impossible. Thus we need to
+// optimize `UpdateFields` early, at the beginning of analysis.
+OptimizeUpdateFields,
CTESubstitution,
WindowsSubstitution,
EliminateUnions,
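To make the new comment's motivation concrete, here is a minimal sketch (not from the patch; `df`, the column `s`, and the `f$i` field names are made up) of how the DataFrame API produces the chains this rule collapses: every `Column.withField` call wraps the previous column in one more `UpdateFields` node, so updating a nested struct in a loop builds a tree whose depth grows with the number of updates.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, lit, struct}

object UpdateFieldsChainSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("sketch").getOrCreate()
    import spark.implicits._

    val df = Seq((1, 2)).toDF("a", "b").select(struct(col("a"), col("b")).as("s"))

    // Each withField wraps the column in another UpdateFields node:
    //   UpdateFields(UpdateFields(UpdateFields('s, f1), f2), f3) ...
    val updated = (1 to 100).foldLeft(col("s")) { (c, i) =>
      c.withField(s"f$i", lit(i))
    }

    // With OptimizeUpdateFields running in the Substitution batch, the 100-deep
    // chain is flattened to a single UpdateFields carrying 100 WithField
    // operations before the rest of analysis sees it.
    df.select(updated.as("s")).explain(true)
    spark.stop()
  }
}
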
@@ -21,7 +21,7 @@ import scala.collection.mutable

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.optimizer.CombineUnions
+import org.apache.spark.sql.catalyst.optimizer.{CombineUnions, OptimizeUpdateFields}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Union}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.internal.SQLConf
@@ -88,13 +88,6 @@ object ResolveUnion extends Rule[LogicalPlan] {
}
}

-  def simplifyWithFields(expr: Expression): Expression = {
-    expr.transformUp {
-      case UpdateFields(UpdateFields(struct, fieldOps1), fieldOps2) =>
-        UpdateFields(struct, fieldOps1 ++ fieldOps2)
-    }
-  }
-
/**
* Adds missing fields recursively into given `col` expression, based on the target `StructType`.
* This is called by `compareAndAddFields` when we find two struct columns with same name but
@@ -119,7 +112,7 @@ object ResolveUnion extends Rule[LogicalPlan] {
missingFieldsOpt.map { s =>
val struct = addFieldsInto(col, s.fields)
// Combines `WithFields`s to reduce expression tree.
-    val reducedStruct = simplifyWithFields(struct)
+    val reducedStruct = struct.transformUp(OptimizeUpdateFields.optimizeUpdateFields)
val sorted = sortStructFieldsInWithFields(reducedStruct)
sorted
}.get
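The `simplifyWithFields` helper deleted above performed exactly the chain-merging case that `OptimizeUpdateFields.optimizeUpdateFields` now owns, which is why `ResolveUnion` can simply reuse the shared partial function. A rough standalone sketch of that rewrite (REPL-style; the attribute and field names are illustrative, not from the patch):

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, UpdateFields, WithField}
import org.apache.spark.sql.catalyst.optimizer.OptimizeUpdateFields

// (('s withField "a" = 1) withField "b" = 2): two nested UpdateFields nodes.
val chain: Expression = UpdateFields(
  UpdateFields('s.struct('a.int), WithField("a", Literal(1)) :: Nil),
  WithField("b", Literal(2)) :: Nil)

// One bottom-up pass merges the adjacent UpdateFields, concatenating their
// field operations in order:
//   UpdateFields('s, WithField("a", 1) :: WithField("b", 2) :: Nil)
val merged = chain.transformUp(OptimizeUpdateFields.optimizeUpdateFields)
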
@@ -46,7 +46,12 @@ object SimplifyExtractValueOps extends Rule[LogicalPlan] {
// if the struct itself is null, then any value extracted from it (expr) will be null
// so we don't need to wrap expr in If(IsNull(struct), Literal(null, expr.dataType), expr)
case expr: GetStructField if expr.child.semanticEquals(structExpr) => expr
-case expr => If(IsNull(structExpr), Literal(null, expr.dataType), expr)
+case expr =>
+  if (structExpr.nullable) {
+    If(IsNull(structExpr), Literal(null, expr.dataType), expr)
+  } else {
+    expr
+  }
}
// Remove redundant array indexing.
case GetArrayStructFields(CreateArray(elems, useStringTypeWhenEmpty), field, ordinal, _, _) =>
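To see what the new nullability check buys, here is a hedged sketch (not patch code; it mirrors the `testRelation2` relation added to the test suite below): when the struct is non-nullable, extracting a freshly written field can fold all the way down to the literal, instead of being wrapped in an `If(IsNull(...), ...)` guard that can never fire.

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.{GetStructField, Literal, UpdateFields, WithField}

// 'a is struct<a1: int> NOT NULL; after the update its fields are [a1, b1],
// so ordinal 1 extracts the freshly written b1.
val nonNullableStruct = 'a.struct('a1.int).notNull
val extracted = GetStructField(
  UpdateFields(nonNullableStruct, WithField("b1", Literal(4)) :: Nil), 1)
// Old rule:  If(IsNull('a), Literal(null, IntegerType), Literal(4))
// New rule:  Literal(4), because structExpr.nullable is false.
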
@@ -109,7 +109,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
RemoveRedundantAliases,
UnwrapCastInBinaryComparison,
RemoveNoopOperators,
-CombineUpdateFields,
+OptimizeUpdateFields,
SimplifyExtractValueOps,
OptimizeJsonExprs,
CombineConcats) ++
@@ -17,19 +17,68 @@

package org.apache.spark.sql.catalyst.optimizer

-import org.apache.spark.sql.catalyst.expressions.UpdateFields
+import java.util.Locale
+
+import scala.collection.mutable
+
+import org.apache.spark.sql.catalyst.expressions.{Expression, UpdateFields, WithField}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.internal.SQLConf


/**
- * Combines all adjacent [[UpdateFields]] expression into a single [[UpdateFields]] expression.
+ * Optimizes [[UpdateFields]] expression chains.
*/
-object CombineUpdateFields extends Rule[LogicalPlan] {
-  def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+object OptimizeUpdateFields extends Rule[LogicalPlan] {
+  private def canOptimize(names: Seq[String]): Boolean = {
+    if (SQLConf.get.caseSensitiveAnalysis) {
+      names.distinct.length != names.length
+    } else {
+      names.map(_.toLowerCase(Locale.ROOT)).distinct.length != names.length
+    }
+  }
+
+  val optimizeUpdateFields: PartialFunction[Expression, Expression] = {
+    case UpdateFields(structExpr, fieldOps)
+        if fieldOps.forall(_.isInstanceOf[WithField]) &&
+          canOptimize(fieldOps.map(_.asInstanceOf[WithField].name)) =>
+      val caseSensitive = SQLConf.get.caseSensitiveAnalysis
+
+      val withFields = fieldOps.map(_.asInstanceOf[WithField])
+      val names = withFields.map(_.name)
+      val values = withFields.map(_.valExpr)
+
+      val newNames = mutable.ArrayBuffer.empty[String]
+      val newValues = mutable.ArrayBuffer.empty[Expression]
+
+      if (caseSensitive) {
+        names.zip(values).reverse.foreach { case (name, value) =>

[Inline review comment from a Member] I wonder if we could just do like: collection.immutable.ListMap(names.zip(values): _*) which will keep the last win here and keep the order of fields to use later. But I guess it's no big deal. Just saying.

+          if (!newNames.contains(name)) {
+            newNames += name
+            newValues += value
+          }
+        }
+      } else {
+        val nameSet = mutable.HashSet.empty[String]
+        names.zip(values).reverse.foreach { case (name, value) =>
+          val lowercaseName = name.toLowerCase(Locale.ROOT)
+          if (!nameSet.contains(lowercaseName)) {
+            newNames += name
+            newValues += value
+            nameSet += lowercaseName
+          }
+        }
+      }
+
+      val newWithFields = newNames.reverse.zip(newValues.reverse).map(p => WithField(p._1, p._2))
+      UpdateFields(structExpr, newWithFields.toSeq)
+
+    case UpdateFields(UpdateFields(struct, fieldOps1), fieldOps2) =>
+      UpdateFields(struct, fieldOps1 ++ fieldOps2)
+  }
+
+  def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions(optimizeUpdateFields)
+}
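As a worked example of the new de-duplication case (the attribute and field names are made up for illustration): because the name/value pairs are scanned in reverse, only the last write to each name survives, and it keeps the position of its last occurrence.

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.{Literal, UpdateFields, WithField}
import org.apache.spark.sql.catalyst.optimizer.OptimizeUpdateFields

// "b1" is written three times (case-insensitively), once as "B1":
val input = UpdateFields('a.struct('a1.int),
  WithField("b1", Literal(1)) :: WithField("B1", Literal(2)) ::
  WithField("c1", Literal(3)) :: WithField("b1", Literal(4)) :: Nil)

val optimized = input.transform(OptimizeUpdateFields.optimizeUpdateFields)
// Case-insensitive analysis (default):
//   UpdateFields('a, WithField("c1", 3) :: WithField("b1", 4) :: Nil)
// Case-sensitive (spark.sql.caseSensitive=true), where "B1" stays distinct:
//   UpdateFields('a, WithField("B1", 2) :: WithField("c1", 3) :: WithField("b1", 4) :: Nil)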

@@ -19,19 +19,21 @@ package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.expressions.{Alias, Literal, UpdateFields, WithField}
+import org.apache.spark.sql.catalyst.expressions.{Alias, GetStructField, Literal, UpdateFields, WithField}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.internal.SQLConf


-class CombineUpdateFieldsSuite extends PlanTest {
+class OptimizeWithFieldsSuite extends PlanTest {

object Optimize extends RuleExecutor[LogicalPlan] {
val batches = Batch("CombineUpdateFields", FixedPoint(10), CombineUpdateFields) :: Nil
val batches = Batch("OptimizeUpdateFields", FixedPoint(10),
OptimizeUpdateFields, SimplifyExtractValueOps) :: Nil
}

private val testRelation = LocalRelation('a.struct('a1.int))
+private val testRelation2 = LocalRelation('a.struct('a1.int).notNull)

test("combines two adjacent UpdateFields Expressions") {
val originalQuery = testRelation
@@ -70,4 +72,58 @@

comparePlans(optimized, correctAnswer)
}

test("SPARK-32941: optimize WithFields followed by GetStructField") {
val originalQuery = testRelation2
.select(Alias(
GetStructField(UpdateFields('a,
WithField("b1", Literal(4)) :: Nil), 1), "out")())

val optimized = Optimize.execute(originalQuery.analyze)
val correctAnswer = testRelation2
.select(Alias(Literal(4), "out")())
.analyze

comparePlans(optimized, correctAnswer)
}

test("SPARK-32941: optimize WithFields chain - case insensitive") {
val originalQuery = testRelation
.select(
Alias(UpdateFields('a,
WithField("b1", Literal(4)) :: WithField("b1", Literal(5)) :: Nil), "out1")(),
Alias(UpdateFields('a,
WithField("b1", Literal(4)) :: WithField("B1", Literal(5)) :: Nil), "out2")())

val optimized = Optimize.execute(originalQuery.analyze)
val correctAnswer = testRelation
.select(
Alias(UpdateFields('a, WithField("b1", Literal(5)) :: Nil), "out1")(),
Alias(UpdateFields('a, WithField("B1", Literal(5)) :: Nil), "out2")())
.analyze

comparePlans(optimized, correctAnswer)
}

test("SPARK-32941: optimize WithFields chain - case sensitive") {
withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") {
val originalQuery = testRelation
.select(
Alias(UpdateFields('a,
WithField("b1", Literal(4)) :: WithField("b1", Literal(5)) :: Nil), "out1")(),
Alias(UpdateFields('a,
WithField("b1", Literal(4)) :: WithField("B1", Literal(5)) :: Nil), "out2")())

val optimized = Optimize.execute(originalQuery.analyze)
val correctAnswer = testRelation
.select(
Alias(UpdateFields('a, WithField("b1", Literal(5)) :: Nil), "out1")(),
Alias(
UpdateFields('a,
WithField("b1", Literal(4)) :: WithField("B1", Literal(5)) :: Nil), "out2")())
.analyze

comparePlans(optimized, correctAnswer)
}
}
}
@@ -44,7 +44,7 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
BooleanSimplification,
SimplifyConditionals,
SimplifyBinaryComparison,
-CombineUpdateFields,
+OptimizeUpdateFields,
SimplifyExtractValueOps) :: Nil
}

@@ -698,7 +698,6 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
val expected = structLevel2.select(
UpdateFields('a1, Seq(
// scalastyle:off line.size.limit
WithField("a2", UpdateFields(GetStructField('a1, 0), WithField("b3", 2) :: Nil)),
WithField("a2", UpdateFields(GetStructField('a1, 0), WithField("b3", 2) :: WithField("c3", 3) :: Nil))
// scalastyle:on line.size.limit
)).as("a1"))
@@ -732,7 +731,6 @@

structLevel2.select(
UpdateFields('a1, Seq(
WithField("a2", repeatedExpr),
WithField("a2", UpdateFields(
If(IsNull('a1), Literal(null, repeatedExprDataType), repeatedExpr),
WithField("c3", Literal(3)) :: Nil))
@@ -763,7 +761,6 @@

val expected = structLevel2.select(
UpdateFields('a1, Seq(
WithField("a2", UpdateFields(GetStructField('a1, 0), Seq(DropField("b3")))),
WithField("a2", UpdateFields(GetStructField('a1, 0), Seq(DropField("b3"), DropField("c3"))))
)).as("a1"))

@@ -797,7 +794,6 @@

structLevel2.select(
UpdateFields('a1, Seq(
WithField("a2", repeatedExpr),
WithField("a2", UpdateFields(
If(IsNull('a1), Literal(null, repeatedExprDataType), repeatedExpr),
DropField("c3") :: Nil))