[SPARK-32511][SQL] Add dropFields method to Column class
### What changes were proposed in this pull request?

Added a new `dropFields` method to the `Column` class.
This method allows users to drop a `StructField` from a `StructType` column (with semantics similar to the `drop` method on `Dataset`).
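
For example, a user with a DataFrame `df` containing a struct column `a` could drop a nested field in a single call (a minimal sketch of the intended usage; a full worked example follows below):
```
// Drop the nested field a.b.a.b from the struct column "a".
val result = df.withColumn("a", $"a".dropFields("b.a.b"))
```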

### Why are the changes needed?

Often Spark users have to work with deeply nested data, e.g. to fix a data quality issue in an existing `StructField`. To do this with the existing Spark APIs, users have to rebuild the entire struct column.

For example, let's say you have the following deeply nested data structure which has a data quality issue (`5` is missing):
```
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._

val data = spark.createDataFrame(sc.parallelize(
      Seq(Row(Row(Row(1, 2, 3), Row(Row(4, null, 6), Row(7, 8, 9), Row(10, 11, 12)), Row(13, 14, 15))))),
      StructType(Seq(
        StructField("a", StructType(Seq(
          StructField("a", StructType(Seq(
            StructField("a", IntegerType),
            StructField("b", IntegerType),
            StructField("c", IntegerType)))),
          StructField("b", StructType(Seq(
            StructField("a", StructType(Seq(
              StructField("a", IntegerType),
              StructField("b", IntegerType),
              StructField("c", IntegerType)))),
            StructField("b", StructType(Seq(
              StructField("a", IntegerType),
              StructField("b", IntegerType),
              StructField("c", IntegerType)))),
            StructField("c", StructType(Seq(
              StructField("a", IntegerType),
              StructField("b", IntegerType),
              StructField("c", IntegerType))))
          ))),
          StructField("c", StructType(Seq(
            StructField("a", IntegerType),
            StructField("b", IntegerType),
            StructField("c", IntegerType))))
        )))))).cache

data.show(false)
+-------------------------------------------------------------+
|a                                                            |
+-------------------------------------------------------------+
|[[1, 2, 3], [[4,, 6], [7, 8, 9], [10, 11, 12]], [13, 14, 15]]|
+-------------------------------------------------------------+
```
Currently, to drop the field containing the missing value, users would have to do something like this:
```
val result = data.withColumn("a",
  struct(
    $"a.a",
    struct(
      struct(
        $"a.b.a.a",
        $"a.b.a.c"
      ).as("a"),
      $"a.b.b",
      $"a.b.c"
    ).as("b"),
    $"a.c"
  ))

result.show(false)
+------------------------------------------------------------+
|a                                                           |
+------------------------------------------------------------+
|[[1, 2, 3], [[4, 6], [7, 8, 9], [10, 11, 12]], [13, 14, 15]]|
+------------------------------------------------------------+
```
As you can see above, with the existing methods users must call the `struct` function and list all fields, including fields they don't want to change. This is not ideal as:
> this leads to complex, fragile code that cannot survive schema evolution.
[SPARK-16483](https://issues.apache.org/jira/browse/SPARK-16483)

In contrast, with the method added in this PR, a user could simply do something like this to get the same result:
```
val result = data.withColumn("a", 'a.dropFields("b.a.b"))
result.show(false)
+------------------------------------------------------------+
|a                                                           |
+------------------------------------------------------------+
|[[1, 2, 3], [[4, 6], [7, 8, 9], [10, 11, 12]], [13, 14, 15]]|
+------------------------------------------------------------+

```

This is the second of perhaps three methods that could be added to the `Column` class to make it easier to manipulate nested data.
Other methods under discussion in [SPARK-22231](https://issues.apache.org/jira/browse/SPARK-22231) include `withFieldRenamed`; however, that method should be added in a separate PR.
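
For illustration only, one hypothetical shape such a `withFieldRenamed` method could take (this is not part of this PR, and the final signature, if it is ever added, may differ):
```
// Hypothetical sketch only, not part of this PR:
// rename the nested field a.b.a to a.b.x without rebuilding the struct.
val renamed = df.withColumn("a", $"a".withFieldRenamed("b.a", "b.x"))
```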

### Does this PR introduce _any_ user-facing change?

Only one minor change. If the user submits the following query:
```
df.withColumn("a", $"a".withField(null, null))
```
instead of throwing:
```
java.lang.IllegalArgumentException: requirement failed: fieldName cannot be null
```
it will now throw:
```
java.lang.IllegalArgumentException: requirement failed: col cannot be null
```
I don't believe it should be an issue to change this because:
- neither message is incorrect
- Spark 3.1.0 has yet to be released

but please feel free to correct me if I am wrong.

### How was this patch tested?

New unit tests were added, and they must pass in Jenkins.

### Related JIRAs:
More discussion on this topic can be found here:
- https://issues.apache.org/jira/browse/SPARK-22231
- https://issues.apache.org/jira/browse/SPARK-16483

Closes #29322 from fqaiser94/SPARK-32511.

Lead-authored-by: fqaiser94@gmail.com <fqaiser94@gmail.com>
Co-authored-by: fqaiser94 <fqaiser94@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
fqaiser94 authored and cloud-fan committed Aug 13, 2020
1 parent 08d86eb commit 0c850c7
Showing 8 changed files with 579 additions and 123 deletions.
@@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
import org.apache.spark.sql.catalyst.analysis.{Resolver, TypeCheckResult, TypeCoercion, UnresolvedException}
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.{FUNC_ALIAS, FunctionBuilder}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
@@ -541,57 +541,97 @@ case class StringToMap(text: Expression, pairDelim: Expression, keyValueDelim: E
}

/**
* Adds/replaces field in struct by name.
* Represents an operation to be applied to the fields of a struct.
*/
case class WithFields(
structExpr: Expression,
names: Seq[String],
valExprs: Seq[Expression]) extends Unevaluable {
trait StructFieldsOperation {

assert(names.length == valExprs.length)
val resolver: Resolver = SQLConf.get.resolver

/**
* Returns an updated list of expressions which will ultimately be used as the children argument
* for [[CreateNamedStruct]].
*/
def apply(exprs: Seq[(String, Expression)]): Seq[(String, Expression)]
}

/**
* Add or replace a field by name.
*
* We extend [[Unevaluable]] here to ensure that [[UpdateFields]] can include it as part of its
* children, and thereby enable the analyzer to resolve and transform valExpr as necessary.
*/
case class WithField(name: String, valExpr: Expression)
extends Unevaluable with StructFieldsOperation {

override def apply(exprs: Seq[(String, Expression)]): Seq[(String, Expression)] =
if (exprs.exists(x => resolver(x._1, name))) {
exprs.map {
case (existingName, _) if resolver(existingName, name) => (name, valExpr)
case x => x
}
} else {
exprs :+ (name, valExpr)
}

override def children: Seq[Expression] = valExpr :: Nil

override def dataType: DataType = throw new UnresolvedException(this, "dataType")

override def nullable: Boolean = throw new UnresolvedException(this, "nullable")

override def prettyName: String = "WithField"
}

/**
* Drop a field by name.
*/
case class DropField(name: String) extends StructFieldsOperation {
override def apply(exprs: Seq[(String, Expression)]): Seq[(String, Expression)] =
exprs.filterNot(expr => resolver(expr._1, name))
}

/**
* Updates fields in struct by name.
*/
case class UpdateFields(structExpr: Expression, fieldOps: Seq[StructFieldsOperation])
extends Unevaluable {

override def checkInputDataTypes(): TypeCheckResult = {
if (!structExpr.dataType.isInstanceOf[StructType]) {
TypeCheckResult.TypeCheckFailure(
"struct argument should be struct type, got: " + structExpr.dataType.catalogString)
val dataType = structExpr.dataType
if (!dataType.isInstanceOf[StructType]) {
TypeCheckResult.TypeCheckFailure("struct argument should be struct type, got: " +
dataType.catalogString)
} else if (newExprs.isEmpty) {
TypeCheckResult.TypeCheckFailure("cannot drop all fields in struct")
} else {
TypeCheckResult.TypeCheckSuccess
}
}

override def children: Seq[Expression] = structExpr +: valExprs
override def children: Seq[Expression] = structExpr +: fieldOps.collect {
case e: Expression => e
}

override def dataType: StructType = evalExpr.dataType.asInstanceOf[StructType]

override def nullable: Boolean = structExpr.nullable

override def prettyName: String = "with_fields"
override def prettyName: String = "update_fields"

lazy val evalExpr: Expression = {
val existingExprs = structExpr.dataType.asInstanceOf[StructType].fieldNames.zipWithIndex.map {
case (name, i) => (name, GetStructField(KnownNotNull(structExpr), i).asInstanceOf[Expression])
private lazy val existingExprs: Seq[(String, Expression)] =
structExpr.dataType.asInstanceOf[StructType].fieldNames.zipWithIndex.map {
case (name, i) => (name, GetStructField(KnownNotNull(structExpr), i))
}

val addOrReplaceExprs = names.zip(valExprs)

val resolver = SQLConf.get.resolver
val newExprs = addOrReplaceExprs.foldLeft(existingExprs) {
case (resultExprs, newExpr @ (newExprName, _)) =>
if (resultExprs.exists(x => resolver(x._1, newExprName))) {
resultExprs.map {
case (name, _) if resolver(name, newExprName) => newExpr
case x => x
}
} else {
resultExprs :+ newExpr
}
}.flatMap { case (name, expr) => Seq(Literal(name), expr) }
private lazy val newExprs = fieldOps.foldLeft(existingExprs)((exprs, op) => op(exprs))

val expr = CreateNamedStruct(newExprs)
if (structExpr.nullable) {
If(IsNull(structExpr), Literal(null, expr.dataType), expr)
} else {
expr
}
private lazy val createNamedStructExpr = CreateNamedStruct(newExprs.flatMap {
case (name, expr) => Seq(Literal(name), expr)
})

lazy val evalExpr: Expression = if (structExpr.nullable) {
If(IsNull(structExpr), Literal(null, createNamedStructExpr.dataType), createNamedStructExpr)
} else {
createNamedStructExpr
}
}
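
To make the new abstraction concrete, the following is an illustrative sketch (not part of this commit) of how a sequence of field operations folds over a struct's existing (name, value) pairs, mirroring the `newExprs` logic above; plain strings stand in for Catalyst expressions:
```
object UpdateFieldsSketch extends App {
  // Stand-ins for Catalyst expressions: a struct is just a list of (fieldName, value) pairs here.
  type Fields = Seq[(String, String)]
  type FieldOp = Fields => Fields

  // Add or replace a field by name (case-sensitive for simplicity; Spark uses the configured resolver).
  def withField(name: String, value: String): FieldOp = fields =>
    if (fields.exists(_._1 == name)) {
      fields.map { case (n, v) => if (n == name) (name, value) else (n, v) }
    } else {
      fields :+ (name -> value)
    }

  // Drop a field by name.
  def dropField(name: String): FieldOp = fields => fields.filterNot(_._1 == name)

  val existing: Fields = Seq("a" -> "1", "b" -> "2", "c" -> "3")
  val ops: Seq[FieldOp] = Seq(withField("b", "20"), dropField("c"), withField("d", "4"))

  // Fold the operations over the existing fields, as UpdateFields.newExprs does above.
  val result = ops.foldLeft(existing)((fields, op) => op(fields))
  println(result) // List((a,1), (b,20), (d,4))
}
```
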
@@ -39,17 +39,17 @@ object SimplifyExtractValueOps extends Rule[LogicalPlan] {
// Remove redundant field extraction.
case GetStructField(createNamedStruct: CreateNamedStruct, ordinal, _) =>
createNamedStruct.valExprs(ordinal)
case GetStructField(w @ WithFields(struct, names, valExprs), ordinal, maybeName) =>
val name = w.dataType(ordinal).name
val matches = names.zip(valExprs).filter(_._1 == name)
case GetStructField(u: UpdateFields, ordinal, maybeName) =>
val name = u.dataType(ordinal).name
val matches = u.fieldOps.collect { case w: WithField if w.name == name => w }
if (matches.nonEmpty) {
// return last matching element as that is the final value for the field being extracted.
// For example, if a user submits a query like this:
// `$"struct_col".withField("b", lit(1)).withField("b", lit(2)).getField("b")`
// we want to return `lit(2)` (and not `lit(1)`).
matches.last._2
matches.last.valExpr
} else {
GetStructField(struct, ordinal, maybeName)
GetStructField(u.structExpr, ordinal, maybeName)
}
// Remove redundant array indexing.
case GetArrayStructFields(CreateArray(elems, useStringTypeWhenEmpty), field, ordinal, _, _) =>
@@ -108,7 +108,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
EliminateSerialization,
RemoveRedundantAliases,
RemoveNoopOperators,
CombineWithFields,
CombineUpdateFields,
SimplifyExtractValueOps,
CombineConcats) ++
extendedOperatorOptimizationRules
@@ -217,8 +217,7 @@
RemoveNoopOperators) :+
// This batch must be executed after the `RewriteSubquery` batch, which creates joins.
Batch("NormalizeFloatingNumbers", Once, NormalizeFloatingNumbers) :+
Batch("ReplaceWithFieldsExpression", Once, ReplaceWithFieldsExpression)

Batch("ReplaceUpdateFieldsExpression", Once, ReplaceUpdateFieldsExpression)
// remove any batches with no rules. this may happen when subclasses do not add optional rules.
batches.filter(_.rules.nonEmpty)
}
@@ -251,7 +250,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
RewriteCorrelatedScalarSubquery.ruleName ::
RewritePredicateSubquery.ruleName ::
NormalizeFloatingNumbers.ruleName ::
ReplaceWithFieldsExpression.ruleName :: Nil
ReplaceUpdateFieldsExpression.ruleName :: Nil

/**
* Optimize all the subqueries inside expression.
@@ -17,26 +17,26 @@

package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.expressions.WithFields
import org.apache.spark.sql.catalyst.expressions.UpdateFields
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule


/**
* Combines all adjacent [[WithFields]] expression into a single [[WithFields]] expression.
* Combines all adjacent [[UpdateFields]] expression into a single [[UpdateFields]] expression.
*/
object CombineWithFields extends Rule[LogicalPlan] {
object CombineUpdateFields extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
case WithFields(WithFields(struct, names1, valExprs1), names2, valExprs2) =>
WithFields(struct, names1 ++ names2, valExprs1 ++ valExprs2)
case UpdateFields(UpdateFields(struct, fieldOps1), fieldOps2) =>
UpdateFields(struct, fieldOps1 ++ fieldOps2)
}
}

/**
* Replaces [[WithFields]] expression with an evaluable expression.
* Replaces [[UpdateFields]] expression with an evaluable expression.
*/
object ReplaceWithFieldsExpression extends Rule[LogicalPlan] {
object ReplaceUpdateFieldsExpression extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
case w: WithFields => w.evalExpr
case u: UpdateFields => u.evalExpr
}
}
@@ -19,56 +19,53 @@ package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{Alias, Literal, WithFields}
import org.apache.spark.sql.catalyst.expressions.{Alias, Literal, UpdateFields, WithField}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._


class CombineWithFieldsSuite extends PlanTest {
class CombineUpdateFieldsSuite extends PlanTest {

object Optimize extends RuleExecutor[LogicalPlan] {
val batches = Batch("CombineWithFields", FixedPoint(10), CombineWithFields) :: Nil
val batches = Batch("CombineUpdateFields", FixedPoint(10), CombineUpdateFields) :: Nil
}

private val testRelation = LocalRelation('a.struct('a1.int))

test("combines two WithFields") {
test("combines two adjacent UpdateFields Expressions") {
val originalQuery = testRelation
.select(Alias(
WithFields(
WithFields(
UpdateFields(
UpdateFields(
'a,
Seq("b1"),
Seq(Literal(4))),
Seq("c1"),
Seq(Literal(5))), "out")())
WithField("b1", Literal(4)) :: Nil),
WithField("c1", Literal(5)) :: Nil), "out")())

val optimized = Optimize.execute(originalQuery.analyze)
val correctAnswer = testRelation
.select(Alias(WithFields('a, Seq("b1", "c1"), Seq(Literal(4), Literal(5))), "out")())
.select(Alias(UpdateFields('a, WithField("b1", Literal(4)) :: WithField("c1", Literal(5)) ::
Nil), "out")())
.analyze

comparePlans(optimized, correctAnswer)
}

test("combines three WithFields") {
test("combines three adjacent UpdateFields Expressions") {
val originalQuery = testRelation
.select(Alias(
WithFields(
WithFields(
WithFields(
UpdateFields(
UpdateFields(
UpdateFields(
'a,
Seq("b1"),
Seq(Literal(4))),
Seq("c1"),
Seq(Literal(5))),
Seq("d1"),
Seq(Literal(6))), "out")())
WithField("b1", Literal(4)) :: Nil),
WithField("c1", Literal(5)) :: Nil),
WithField("d1", Literal(6)) :: Nil), "out")())

val optimized = Optimize.execute(originalQuery.analyze)
val correctAnswer = testRelation
.select(Alias(WithFields('a, Seq("b1", "c1", "d1"), Seq(4, 5, 6).map(Literal(_))), "out")())
.select(Alias(UpdateFields('a, WithField("b1", Literal(4)) :: WithField("c1", Literal(5)) ::
WithField("d1", Literal(6)) :: Nil), "out")())
.analyze

comparePlans(optimized, correctAnswer)