[SPARK-32511][SQL] Add dropFields method to Column class #29322

Closed · wants to merge 7 commits
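For orientation, here is a minimal sketch of the user-facing behaviour this patch targets. Only the Catalyst side appears in the diff below, so the Column-side wiring of `withField`/`dropFields` is assumed from the PR title rather than shown:

    // Hypothetical end-to-end usage; the Column methods are assumed to compile
    // down to the UpdateFields expression introduced in the diff below.
    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.functions.lit

    val spark = SparkSession.builder().master("local[1]").getOrCreate()
    import spark.implicits._

    val df = spark.sql("SELECT named_struct('a', 1, 'b', 2) AS struct_col")

    // Replace field "a" and drop field "b" in one chained expression;
    // the result is a struct containing only a = 10.
    df.select($"struct_col".withField("a", lit(10)).dropFields("b")).show()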
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
+import org.apache.spark.sql.catalyst.analysis.{Resolver, TypeCheckResult, TypeCoercion, UnresolvedException}
 import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.{FUNC_ALIAS, FunctionBuilder}
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
@@ -541,57 +541,97 @@ case class StringToMap(text: Expression, pairDelim: Expression, keyValueDelim: E
 }
 
 /**
- * Adds/replaces field in struct by name.
+ * Represents an operation to be applied to the fields of a struct.
  */
-case class WithFields(
-    structExpr: Expression,
-    names: Seq[String],
-    valExprs: Seq[Expression]) extends Unevaluable {
+trait StructFieldsOperation {
 
-  assert(names.length == valExprs.length)
+  val resolver: Resolver = SQLConf.get.resolver
 
+  /**
+   * Returns an updated list of expressions which will ultimately be used as the children argument
+   * for [[CreateNamedStruct]].
+   */
+  def apply(exprs: Seq[(String, Expression)]): Seq[(String, Expression)]
+}
+
+/**
+ * Add or replace a field by name.
+ *
+ * We extend [[Unevaluable]] here to ensure that [[UpdateFields]] can include it as part of its
+ * children, and thereby enable the analyzer to resolve and transform valExpr as necessary.
+ */
+case class WithField(name: String, valExpr: Expression)
+  extends Unevaluable with StructFieldsOperation {
+
+  override def apply(exprs: Seq[(String, Expression)]): Seq[(String, Expression)] =
+    if (exprs.exists(x => resolver(x._1, name))) {
+      exprs.map {
+        case (existingName, _) if resolver(existingName, name) => (name, valExpr)
+        case x => x
+      }
+    } else {
+      exprs :+ (name, valExpr)
+    }
+
+  override def children: Seq[Expression] = valExpr :: Nil
+
+  override def dataType: DataType = throw new UnresolvedException(this, "dataType")
+
+  override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
+
+  override def prettyName: String = "WithField"
+}
+
+/**
+ * Drop a field by name.
+ */
+case class DropField(name: String) extends StructFieldsOperation {
+  override def apply(exprs: Seq[(String, Expression)]): Seq[(String, Expression)] =
+    exprs.filterNot(expr => resolver(expr._1, name))
+}
+
+/**
+ * Updates fields in struct by name.
+ */
+case class UpdateFields(structExpr: Expression, fieldOps: Seq[StructFieldsOperation])
+  extends Unevaluable {
 
   override def checkInputDataTypes(): TypeCheckResult = {
-    if (!structExpr.dataType.isInstanceOf[StructType]) {
-      TypeCheckResult.TypeCheckFailure(
-        "struct argument should be struct type, got: " + structExpr.dataType.catalogString)
+    val dataType = structExpr.dataType
+    if (!dataType.isInstanceOf[StructType]) {
+      TypeCheckResult.TypeCheckFailure("struct argument should be struct type, got: " +
+        dataType.catalogString)
+    } else if (newExprs.isEmpty) {
+      TypeCheckResult.TypeCheckFailure("cannot drop all fields in struct")
     } else {
       TypeCheckResult.TypeCheckSuccess
     }
   }
 
-  override def children: Seq[Expression] = structExpr +: valExprs
+  override def children: Seq[Expression] = structExpr +: fieldOps.collect {
+    case e: Expression => e
+  }
 
   override def dataType: StructType = evalExpr.dataType.asInstanceOf[StructType]
 
   override def nullable: Boolean = structExpr.nullable
 
-  override def prettyName: String = "with_fields"
+  override def prettyName: String = "update_fields"
 
-  lazy val evalExpr: Expression = {
-    val existingExprs = structExpr.dataType.asInstanceOf[StructType].fieldNames.zipWithIndex.map {
-      case (name, i) => (name, GetStructField(KnownNotNull(structExpr), i).asInstanceOf[Expression])
+  private lazy val existingExprs: Seq[(String, Expression)] =
+    structExpr.dataType.asInstanceOf[StructType].fieldNames.zipWithIndex.map {
+      case (name, i) => (name, GetStructField(KnownNotNull(structExpr), i))
     }
 
-    val addOrReplaceExprs = names.zip(valExprs)
-
-    val resolver = SQLConf.get.resolver
-    val newExprs = addOrReplaceExprs.foldLeft(existingExprs) {
-      case (resultExprs, newExpr @ (newExprName, _)) =>
-        if (resultExprs.exists(x => resolver(x._1, newExprName))) {
-          resultExprs.map {
-            case (name, _) if resolver(name, newExprName) => newExpr
-            case x => x
-          }
-        } else {
-          resultExprs :+ newExpr
-        }
-    }.flatMap { case (name, expr) => Seq(Literal(name), expr) }
+  private lazy val newExprs = fieldOps.foldLeft(existingExprs)((exprs, op) => op(exprs))
 
-    val expr = CreateNamedStruct(newExprs)
-    if (structExpr.nullable) {
-      If(IsNull(structExpr), Literal(null, expr.dataType), expr)
-    } else {
-      expr
-    }
+  private lazy val createNamedStructExpr = CreateNamedStruct(newExprs.flatMap {
+    case (name, expr) => Seq(Literal(name), expr)
+  })
+
+  lazy val evalExpr: Expression = if (structExpr.nullable) {
+    If(IsNull(structExpr), Literal(null, createNamedStructExpr.dataType), createNamedStructExpr)
+  } else {
+    createNamedStructExpr
+  }
 }
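The `apply` contract above is easiest to see as a left fold over the struct's (name, value) pairs. Below is a self-contained sketch of that semantics, with plain `String` values standing in for Catalyst `Expression`s and a case-sensitive `==` standing in for the configured `Resolver` (both simplifications are mine, not the patch's):

    // Sketch of the StructFieldsOperation fold semantics.
    sealed trait FieldOp {
      def apply(fields: Seq[(String, String)]): Seq[(String, String)]
    }

    // WithField semantics: replace an existing field in place, else append.
    case class AddOrReplace(name: String, value: String) extends FieldOp {
      def apply(fields: Seq[(String, String)]): Seq[(String, String)] =
        if (fields.exists(_._1 == name)) {
          fields.map { case (n, v) => if (n == name) (name, value) else (n, v) }
        } else {
          fields :+ (name -> value)
        }
    }

    // DropField semantics: remove every field with a matching name.
    case class Drop(name: String) extends FieldOp {
      def apply(fields: Seq[(String, String)]): Seq[(String, String)] =
        fields.filterNot(_._1 == name)
    }

    // Fold the ops over the struct's current fields, as UpdateFields.newExprs does.
    val ops: Seq[FieldOp] = Seq(AddOrReplace("b", "4"), Drop("a"))
    val result = ops.foldLeft(Seq("a" -> "1", "b" -> "2"))((fields, op) => op(fields))
    assert(result == Seq("b" -> "4"))

Note that this fold is also why `checkInputDataTypes` can reject "cannot drop all fields in struct": if every field is dropped, `newExprs` folds down to empty.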
@@ -39,17 +39,17 @@ object SimplifyExtractValueOps extends Rule[LogicalPlan] {
       // Remove redundant field extraction.
       case GetStructField(createNamedStruct: CreateNamedStruct, ordinal, _) =>
         createNamedStruct.valExprs(ordinal)
-      case GetStructField(w @ WithFields(struct, names, valExprs), ordinal, maybeName) =>
-        val name = w.dataType(ordinal).name
-        val matches = names.zip(valExprs).filter(_._1 == name)
+      case GetStructField(u: UpdateFields, ordinal, maybeName) =>
+        val name = u.dataType(ordinal).name
+        val matches = u.fieldOps.collect { case w: WithField if w.name == name => w }
         if (matches.nonEmpty) {
           // return last matching element as that is the final value for the field being extracted.
           // For example, if a user submits a query like this:
           // `$"struct_col".withField("b", lit(1)).withField("b", lit(2)).getField("b")`
           // we want to return `lit(2)` (and not `lit(1)`).
-          matches.last._2
+          matches.last.valExpr
         } else {
-          GetStructField(struct, ordinal, maybeName)
+          GetStructField(u.structExpr, ordinal, maybeName)
         }
       // Remove redundant array indexing.
       case GetArrayStructFields(CreateArray(elems, useStringTypeWhenEmpty), field, ordinal, _, _) =>
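A standalone illustration of the "last write wins" lookup this rule relies on (toy types, mine; the real rule matches `WithField` expressions inside `UpdateFields.fieldOps`):

    // Ints stand in for Catalyst value expressions.
    case class ToyWithField(name: String, value: Int)

    val fieldOps = Seq(ToyWithField("b", 1), ToyWithField("a", 7), ToyWithField("b", 2))

    // Mirrors: u.fieldOps.collect { case w: WithField if w.name == name => w }
    val matches = fieldOps.collect { case w if w.name == "b" => w }

    // The last match is the field's final value, so getField("b") simplifies
    // to lit(2) without ever materialising the intermediate structs.
    assert(matches.last.value == 2)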
@@ -108,7 +108,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
         EliminateSerialization,
         RemoveRedundantAliases,
         RemoveNoopOperators,
-        CombineWithFields,
+        CombineUpdateFields,
         SimplifyExtractValueOps,
         CombineConcats) ++
         extendedOperatorOptimizationRules
@@ -217,8 +217,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
         RemoveNoopOperators) :+
       // This batch must be executed after the `RewriteSubquery` batch, which creates joins.
       Batch("NormalizeFloatingNumbers", Once, NormalizeFloatingNumbers) :+
-      Batch("ReplaceWithFieldsExpression", Once, ReplaceWithFieldsExpression)
-
+      Batch("ReplaceUpdateFieldsExpression", Once, ReplaceUpdateFieldsExpression)
     // remove any batches with no rules. this may happen when subclasses do not add optional rules.
     batches.filter(_.rules.nonEmpty)
   }
@@ -251,7 +250,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
       RewriteCorrelatedScalarSubquery.ruleName ::
       RewritePredicateSubquery.ruleName ::
       NormalizeFloatingNumbers.ruleName ::
-      ReplaceWithFieldsExpression.ruleName :: Nil
+      ReplaceUpdateFieldsExpression.ruleName :: Nil
 
   /**
    * Optimize all the subqueries inside expression.
@@ -17,26 +17,26 @@
 
 package org.apache.spark.sql.catalyst.optimizer
 
-import org.apache.spark.sql.catalyst.expressions.WithFields
+import org.apache.spark.sql.catalyst.expressions.UpdateFields
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.rules.Rule
 
 
 /**
- * Combines all adjacent [[WithFields]] expression into a single [[WithFields]] expression.
+ * Combines all adjacent [[UpdateFields]] expressions into a single [[UpdateFields]] expression.
  */
-object CombineWithFields extends Rule[LogicalPlan] {
+object CombineUpdateFields extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
-    case WithFields(WithFields(struct, names1, valExprs1), names2, valExprs2) =>
-      WithFields(struct, names1 ++ names2, valExprs1 ++ valExprs2)
+    case UpdateFields(UpdateFields(struct, fieldOps1), fieldOps2) =>
+      UpdateFields(struct, fieldOps1 ++ fieldOps2)
   }
 }
 
 /**
- * Replaces [[WithFields]] expression with an evaluable expression.
+ * Replaces an [[UpdateFields]] expression with an evaluable expression.
  */
-object ReplaceWithFieldsExpression extends Rule[LogicalPlan] {
+object ReplaceUpdateFieldsExpression extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
-    case w: WithFields => w.evalExpr
+    case u: UpdateFields => u.evalExpr
  }
 }
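Why combining matters: each chained `withField`/`dropFields` call wraps the previous result in a fresh `UpdateFields` node, so without this rule the struct would be rebuilt once per call when `ReplaceUpdateFieldsExpression` runs. A toy model of the flattening (types are illustrative, not Catalyst's):

    // Adjacent (nested) update nodes collapse into one node whose operation
    // list is the ordered concatenation.
    sealed trait Node
    case class Leaf(name: String) extends Node
    case class Update(child: Node, ops: Seq[String]) extends Node

    def combine(n: Node): Node = n match {
      case Update(Update(child, ops1), ops2) => combine(Update(child, ops1 ++ ops2))
      case other => other
    }

    val nested = Update(Update(Update(Leaf("a"), Seq("+b1")), Seq("+c1")), Seq("-b1"))
    assert(combine(nested) == Update(Leaf("a"), Seq("+b1", "+c1", "-b1")))

Keeping the operation list in order is essential: a WithField followed by a DropField of the same name must still drop the field after flattening.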
@@ -19,56 +19,53 @@ package org.apache.spark.sql.catalyst.optimizer
 
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.expressions.{Alias, Literal, WithFields}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Literal, UpdateFields, WithField}
 import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
 
 
-class CombineWithFieldsSuite extends PlanTest {
+class CombineUpdateFieldsSuite extends PlanTest {
 
   object Optimize extends RuleExecutor[LogicalPlan] {
-    val batches = Batch("CombineWithFields", FixedPoint(10), CombineWithFields) :: Nil
+    val batches = Batch("CombineUpdateFields", FixedPoint(10), CombineUpdateFields) :: Nil
   }
 
   private val testRelation = LocalRelation('a.struct('a1.int))
 
-  test("combines two WithFields") {
+  test("combines two adjacent UpdateFields expressions") {
     val originalQuery = testRelation
       .select(Alias(
-        WithFields(
-          WithFields(
+        UpdateFields(
+          UpdateFields(
             'a,
-            Seq("b1"),
-            Seq(Literal(4))),
-          Seq("c1"),
-          Seq(Literal(5))), "out")())
+            WithField("b1", Literal(4)) :: Nil),
+          WithField("c1", Literal(5)) :: Nil), "out")())
 
     val optimized = Optimize.execute(originalQuery.analyze)
     val correctAnswer = testRelation
-      .select(Alias(WithFields('a, Seq("b1", "c1"), Seq(Literal(4), Literal(5))), "out")())
+      .select(Alias(UpdateFields('a, WithField("b1", Literal(4)) :: WithField("c1", Literal(5)) ::
+        Nil), "out")())
       .analyze
 
     comparePlans(optimized, correctAnswer)
   }
 
-  test("combines three WithFields") {
+  test("combines three adjacent UpdateFields expressions") {
    val originalQuery = testRelation
       .select(Alias(
-        WithFields(
-          WithFields(
-            WithFields(
+        UpdateFields(
+          UpdateFields(
+            UpdateFields(
               'a,
-              Seq("b1"),
-              Seq(Literal(4))),
-            Seq("c1"),
-            Seq(Literal(5))),
-          Seq("d1"),
-          Seq(Literal(6))), "out")())
+              WithField("b1", Literal(4)) :: Nil),
+            WithField("c1", Literal(5)) :: Nil),
+          WithField("d1", Literal(6)) :: Nil), "out")())
 
     val optimized = Optimize.execute(originalQuery.analyze)
     val correctAnswer = testRelation
-      .select(Alias(WithFields('a, Seq("b1", "c1", "d1"), Seq(4, 5, 6).map(Literal(_))), "out")())
+      .select(Alias(UpdateFields('a, WithField("b1", Literal(4)) :: WithField("c1", Literal(5)) ::
+        WithField("d1", Literal(6)) :: Nil), "out")())
       .analyze
 
     comparePlans(optimized, correctAnswer)
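The hunk above only exercises `WithField` operations. A hypothetical companion test mixing in `DropField` (not part of this diff; it assumes `DropField` is added to the suite's imports) could follow the same pattern:

    test("combines adjacent UpdateFields expressions containing different operations") {
      val originalQuery = testRelation
        .select(Alias(
          UpdateFields(
            UpdateFields(
              'a,
              WithField("b1", Literal(4)) :: Nil),
            DropField("b1") :: Nil), "out")())

      val optimized = Optimize.execute(originalQuery.analyze)
      val correctAnswer = testRelation
        .select(Alias(
          UpdateFields('a, WithField("b1", Literal(4)) :: DropField("b1") :: Nil), "out")())
        .analyze

      comparePlans(optimized, correctAnswer)
    }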