add dropFields method to Column class
fqaiser94 committed Jul 31, 2020
1 parent 1c6dff7 commit 19587e8
Showing 8 changed files with 531 additions and 122 deletions.
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.catalyst.expressions

 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
+import org.apache.spark.sql.catalyst.analysis.{Resolver, TypeCheckResult, TypeCoercion, UnresolvedException}
 import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.{FUNC_ALIAS, FunctionBuilder}
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
@@ -541,59 +541,94 @@ case class StringToMap(text: Expression, pairDelim: Expression, keyValueDelim: E
 }

 /**
- * Adds/replaces field in struct by name.
+ * Represents an operation to be applied to the fields of a struct.
  */
-case class WithFields(
-    structExpr: Expression,
-    names: Seq[String],
-    valExprs: Seq[Expression]) extends Unevaluable {
+trait StructFieldsOperation {

-  assert(names.length == valExprs.length)
+  val resolver: Resolver = SQLConf.get.resolver
+
+  /**
+   * Returns an updated list of expressions which will ultimately be used as the children argument
+   * for [[CreateNamedStruct]].
+   */
+  def apply(exprs: Seq[(String, Expression)]): Seq[(String, Expression)]
+}
+
+/**
+ * Add or replace a field by name.
+ */
+case class WithField(name: String, valExpr: Expression)
+  extends Unevaluable with StructFieldsOperation {
+
+  override def apply(exprs: Seq[(String, Expression)]): Seq[(String, Expression)] =
+    if (exprs.exists(x => resolver(x._1, name))) {
+      exprs.map {
+        case (existingName, _) if resolver(existingName, name) => (name, valExpr)
+        case x => x
+      }
+    } else {
+      exprs :+ (name, valExpr)
+    }
+
+  override def children: Seq[Expression] = valExpr :: Nil
+
+  override def dataType: DataType = throw new UnresolvedException(this, "dataType")
+
+  override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
+
+  override def prettyName: String = "WithField"
+}
+
+/**
+ * Drop a field by name.
+ */
+case class DropField(name: String) extends StructFieldsOperation {
+  override def apply(exprs: Seq[(String, Expression)]): Seq[(String, Expression)] =
+    exprs.filterNot(expr => resolver(expr._1, name))
+}
+
+/**
+ * Updates fields in struct by name.
+ */
+case class UpdateFields(structExpr: Expression, fieldOps: Seq[StructFieldsOperation])
+  extends Unevaluable {

   override def checkInputDataTypes(): TypeCheckResult = {
-    if (!structExpr.dataType.isInstanceOf[StructType]) {
-      TypeCheckResult.TypeCheckFailure(
-        "struct argument should be struct type, got: " + structExpr.dataType.catalogString)
+    val dataType = structExpr.dataType
+    if (!dataType.isInstanceOf[StructType]) {
+      TypeCheckResult.TypeCheckFailure("struct argument should be struct type, got: " +
+        dataType.catalogString)
+    } else if (newExprs.isEmpty) {
+      TypeCheckResult.TypeCheckFailure("cannot drop all fields in struct")
     } else {
       TypeCheckResult.TypeCheckSuccess
     }
   }

-  override def children: Seq[Expression] = structExpr +: valExprs
+  override def children: Seq[Expression] = structExpr +: fieldOps.collect {
+    case e: Expression => e
+  }

   override def dataType: StructType = evalExpr.dataType.asInstanceOf[StructType]

-  override def foldable: Boolean = structExpr.foldable && valExprs.forall(_.foldable)
-
   override def nullable: Boolean = structExpr.nullable

-  override def prettyName: String = "with_fields"
+  override def prettyName: String = "update_fields"

-  lazy val evalExpr: Expression = {
-    val existingExprs = structExpr.dataType.asInstanceOf[StructType].fieldNames.zipWithIndex.map {
-      case (name, i) => (name, GetStructField(KnownNotNull(structExpr), i).asInstanceOf[Expression])
+  private lazy val existingExprs: Seq[(String, Expression)] =
+    structExpr.dataType.asInstanceOf[StructType].fieldNames.zipWithIndex.map {
+      case (name, i) => (name, GetStructField(KnownNotNull(structExpr), i))
     }

-    val addOrReplaceExprs = names.zip(valExprs)
-
-    val resolver = SQLConf.get.resolver
-    val newExprs = addOrReplaceExprs.foldLeft(existingExprs) {
-      case (resultExprs, newExpr @ (newExprName, _)) =>
-        if (resultExprs.exists(x => resolver(x._1, newExprName))) {
-          resultExprs.map {
-            case (name, _) if resolver(name, newExprName) => newExpr
-            case x => x
-          }
-        } else {
-          resultExprs :+ newExpr
-        }
-    }.flatMap { case (name, expr) => Seq(Literal(name), expr) }
+  private lazy val newExprs = fieldOps.foldLeft(existingExprs)((exprs, op) => op(exprs))

-    val expr = CreateNamedStruct(newExprs)
-    if (structExpr.nullable) {
-      If(IsNull(structExpr), Literal(null, expr.dataType), expr)
-    } else {
-      expr
-    }
-  }
+  private lazy val createNamedStructExpr = CreateNamedStruct(newExprs.flatMap {
+    case (name, expr) => Seq(Literal(name), expr)
+  })
+
+  lazy val evalExpr: Expression = if (structExpr.nullable) {
+    If(IsNull(structExpr), Literal(null, createNamedStructExpr.dataType), createNamedStructExpr)
+  } else {
+    createNamedStructExpr
+  }
 }
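The heart of the new design is the fold in newExprs: each StructFieldsOperation takes the current list of (field name, expression) pairs and returns an updated list, so operations compose in order. The standalone sketch below mirrors that logic with plain String values standing in for Catalyst expressions; it is an illustration, not code from this commit, and all names in it (FieldOpsSketch, With, Drop) are invented:

// Standalone sketch of the UpdateFields fold, with String stand-ins
// for Catalyst expressions. Illustrative only; not part of the commit.
object FieldOpsSketch {
  sealed trait Op {
    def apply(fields: Seq[(String, String)]): Seq[(String, String)]
  }

  // Mirrors WithField: replace a matching field, else append a new one.
  final case class With(name: String, value: String) extends Op {
    def apply(fields: Seq[(String, String)]): Seq[(String, String)] =
      if (fields.exists(_._1 == name)) {
        fields.map { case (n, v) => if (n == name) (name, value) else (n, v) }
      } else {
        fields :+ (name -> value)
      }
  }

  // Mirrors DropField: remove the field by name.
  final case class Drop(name: String) extends Op {
    def apply(fields: Seq[(String, String)]): Seq[(String, String)] =
      fields.filterNot(_._1 == name)
  }

  def main(args: Array[String]): Unit = {
    val existing = Seq("a" -> "col.a", "b" -> "col.b")
    val ops = Seq(With("c", "lit(1)"), Drop("a"), With("b", "lit(2)"))
    // Thread the field list through every operation, exactly as
    // UpdateFields.newExprs folds fieldOps over existingExprs.
    println(ops.foldLeft(existing)((fields, op) => op(fields)))
    // Prints: List((b,lit(2)), (c,lit(1)))
  }
}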
@@ -39,17 +39,17 @@ object SimplifyExtractValueOps extends Rule[LogicalPlan] {
       // Remove redundant field extraction.
       case GetStructField(createNamedStruct: CreateNamedStruct, ordinal, _) =>
         createNamedStruct.valExprs(ordinal)
-      case GetStructField(w @ WithFields(struct, names, valExprs), ordinal, maybeName) =>
-        val name = w.dataType(ordinal).name
-        val matches = names.zip(valExprs).filter(_._1 == name)
+      case GetStructField(u: UpdateFields, ordinal, maybeName) =>
+        val name = u.dataType(ordinal).name
+        val matches = u.fieldOps.collect { case w: WithField if w.name == name => w }
         if (matches.nonEmpty) {
           // return last matching element as that is the final value for the field being extracted.
           // For example, if a user submits a query like this:
           // `$"struct_col".withField("b", lit(1)).withField("b", lit(2)).getField("b")`
           // we want to return `lit(2)` (and not `lit(1)`).
-          matches.last._2
+          matches.last.valExpr
         } else {
-          GetStructField(struct, ordinal, maybeName)
+          GetStructField(u.structExpr, ordinal, maybeName)
         }
       // Remove redundant array indexing.
       case GetArrayStructFields(CreateArray(elems, useStringTypeWhenEmpty), field, ordinal, _, _) =>
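At the API level, this rule means that writing a field and then immediately reading it back costs nothing at runtime. A hedged sketch of the scenario from the comment above (df and its struct_col column are hypothetical; withField and getField are existing Column methods):

// Hypothetical example of the rewrite this rule performs.
import org.apache.spark.sql.functions.{col, lit}

val result = df.select(
  col("struct_col").withField("b", lit(1)).withField("b", lit(2)).getField("b"))
// After SimplifyExtractValueOps the projected expression is simply lit(2):
// the last WithField for "b" wins and no intermediate struct is built.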
@@ -106,7 +106,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
         EliminateSerialization,
         RemoveRedundantAliases,
         RemoveNoopOperators,
-        CombineWithFields,
+        CombineUpdateFields,
         SimplifyExtractValueOps,
         CombineConcats) ++
         extendedOperatorOptimizationRules
@@ -215,8 +215,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
       RemoveNoopOperators) :+
     // This batch must be executed after the `RewriteSubquery` batch, which creates joins.
     Batch("NormalizeFloatingNumbers", Once, NormalizeFloatingNumbers) :+
-    Batch("ReplaceWithFieldsExpression", Once, ReplaceWithFieldsExpression)
-
+    Batch("ReplaceUpdateFieldsExpression", Once, ReplaceUpdateFieldsExpression)
     // remove any batches with no rules. this may happen when subclasses do not add optional rules.
     batches.filter(_.rules.nonEmpty)
   }
@@ -249,7 +248,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
       RewriteCorrelatedScalarSubquery.ruleName ::
       RewritePredicateSubquery.ruleName ::
       NormalizeFloatingNumbers.ruleName ::
-      ReplaceWithFieldsExpression.ruleName :: Nil
+      ReplaceUpdateFieldsExpression.ruleName :: Nil

 /**
  * Optimize all the subqueries inside expression.
@@ -17,26 +17,26 @@

 package org.apache.spark.sql.catalyst.optimizer

-import org.apache.spark.sql.catalyst.expressions.WithFields
+import org.apache.spark.sql.catalyst.expressions.UpdateFields
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.rules.Rule


 /**
- * Combines all adjacent [[WithFields]] expression into a single [[WithFields]] expression.
+ * Combines all adjacent [[UpdateFields]] expressions into a single [[UpdateFields]] expression.
  */
-object CombineWithFields extends Rule[LogicalPlan] {
+object CombineUpdateFields extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
-    case WithFields(WithFields(struct, names1, valExprs1), names2, valExprs2) =>
-      WithFields(struct, names1 ++ names2, valExprs1 ++ valExprs2)
+    case UpdateFields(UpdateFields(struct, fieldOps1), fieldOps2) =>
+      UpdateFields(struct, fieldOps1 ++ fieldOps2)
   }
 }

 /**
- * Replaces [[WithFields]] expression with an evaluable expression.
+ * Replaces [[UpdateFields]] expressions with evaluable expressions.
  */
-object ReplaceWithFieldsExpression extends Rule[LogicalPlan] {
+object ReplaceUpdateFieldsExpression extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
-    case w: WithFields => w.evalExpr
+    case u: UpdateFields => u.evalExpr
   }
 }
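Taken together, the two rules stage the work: CombineUpdateFields runs during operator optimization to flatten nested nodes, while ReplaceUpdateFieldsExpression runs once at the very end to lower the combined node into something evaluable. A sketch in the Catalyst DSL (assumes the dsl.expressions._ import used by the test suite below; the plan shapes are illustrative):

// Chained column operations initially yield nested nodes:
val nested = UpdateFields(
  UpdateFields('a, WithField("b1", Literal(1)) :: Nil),
  DropField("c1") :: Nil)
// CombineUpdateFields flattens the nesting into a single node:
val combined = UpdateFields('a,
  WithField("b1", Literal(1)) :: DropField("c1") :: Nil)
// ReplaceUpdateFieldsExpression then substitutes combined.evalExpr: a
// CreateNamedStruct over the folded fields, null-guarded if 'a is nullable.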
@@ -19,56 +19,53 @@ package org.apache.spark.sql.catalyst.optimizer

 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.expressions.{Alias, Literal, WithFields}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Literal, UpdateFields, WithField}
 import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._


-class CombineWithFieldsSuite extends PlanTest {
+class CombineUpdateFieldsSuite extends PlanTest {

   object Optimize extends RuleExecutor[LogicalPlan] {
-    val batches = Batch("CombineWithFields", FixedPoint(10), CombineWithFields) :: Nil
+    val batches = Batch("CombineUpdateFields", FixedPoint(10), CombineUpdateFields) :: Nil
   }

   private val testRelation = LocalRelation('a.struct('a1.int))

-  test("combines two WithFields") {
+  test("combines two adjacent UpdateFields Expressions") {
     val originalQuery = testRelation
       .select(Alias(
-        WithFields(
-          WithFields(
+        UpdateFields(
+          UpdateFields(
             'a,
-            Seq("b1"),
-            Seq(Literal(4))),
-          Seq("c1"),
-          Seq(Literal(5))), "out")())
+            WithField("b1", Literal(4)) :: Nil),
+          WithField("c1", Literal(5)) :: Nil), "out")())

     val optimized = Optimize.execute(originalQuery.analyze)
     val correctAnswer = testRelation
-      .select(Alias(WithFields('a, Seq("b1", "c1"), Seq(Literal(4), Literal(5))), "out")())
+      .select(Alias(UpdateFields('a, WithField("b1", Literal(4)) :: WithField("c1", Literal(5)) ::
+        Nil), "out")())
       .analyze

     comparePlans(optimized, correctAnswer)
   }

-  test("combines three WithFields") {
+  test("combines three adjacent UpdateFields Expressions") {
     val originalQuery = testRelation
       .select(Alias(
-        WithFields(
-          WithFields(
-            WithFields(
+        UpdateFields(
+          UpdateFields(
+            UpdateFields(
               'a,
-              Seq("b1"),
-              Seq(Literal(4))),
-            Seq("c1"),
-            Seq(Literal(5))),
-          Seq("d1"),
-          Seq(Literal(6))), "out")())
+              WithField("b1", Literal(4)) :: Nil),
+            WithField("c1", Literal(5)) :: Nil),
+          WithField("d1", Literal(6)) :: Nil), "out")())

     val optimized = Optimize.execute(originalQuery.analyze)
     val correctAnswer = testRelation
-      .select(Alias(WithFields('a, Seq("b1", "c1", "d1"), Seq(4, 5, 6).map(Literal(_))), "out")())
+      .select(Alias(UpdateFields('a, WithField("b1", Literal(4)) :: WithField("c1", Literal(5)) ::
+        WithField("d1", Literal(6)) :: Nil), "out")())
       .analyze

     comparePlans(optimized, correctAnswer)
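The Column.scala changes that expose this machinery to users are among the 8 changed files but are not rendered above (the page cut off while loading). Based on the commit title and the expressions added here, end-user code would look roughly like the following sketch; df is a hypothetical DataFrame with a struct column, and the exact dropFields signature lives in the part of the diff not shown:

import org.apache.spark.sql.functions.{col, lit}

// Drop a nested field, keeping its sibling fields intact:
df.withColumn("struct_col", col("struct_col").dropFields("b"))

// Chained operations produce nested UpdateFields nodes that
// CombineUpdateFields later collapses into one:
df.withColumn("struct_col",
  col("struct_col").withField("c", lit(3)).dropFields("a"))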