From 19587e830a7889616583f48b44da61ca296c5215 Mon Sep 17 00:00:00 2001
From: "fqaiser94@gmail.com" <fqaiser94@gmail.com>
Date: Sat, 11 Jul 2020 19:11:27 -0400
Subject: [PATCH 1/2] add dropFields method to Column class

---
 .../expressions/complexTypeCreator.scala      | 109 +++---
 .../sql/catalyst/optimizer/ComplexTypes.scala |  10 +-
 .../sql/catalyst/optimizer/Optimizer.scala    |   7 +-
 .../{WithFields.scala => UpdateFields.scala}  |  16 +-
 ...e.scala => CombineUpdateFieldsSuite.scala} |  41 ++-
 .../optimizer/complexTypesSuite.scala         |  81 +++--
 .../scala/org/apache/spark/sql/Column.scala   |  80 ++++-
 .../spark/sql/ColumnExpressionSuite.scala     | 309 +++++++++++++++++-
 8 files changed, 531 insertions(+), 122 deletions(-)
 rename sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/{WithFields.scala => UpdateFields.scala} (68%)
 rename sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/{CombineWithFieldsSuite.scala => CombineUpdateFieldsSuite.scala} (65%)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
index cf7cc3a5e16ff..432db8e2f6f78 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
+import org.apache.spark.sql.catalyst.analysis.{Resolver, TypeCheckResult, TypeCoercion, UnresolvedException}
 import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.{FUNC_ALIAS, FunctionBuilder}
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
@@ -541,59 +541,94 @@ case class StringToMap(text: Expression, pairDelim: Expression, keyValueDelim: E
 }
 
 /**
- * Adds/replaces field in struct by name.
+ * Represents an operation to be applied to the fields of a struct.
  */
-case class WithFields(
-    structExpr: Expression,
-    names: Seq[String],
-    valExprs: Seq[Expression]) extends Unevaluable {
+trait StructFieldsOperation {
 
-  assert(names.length == valExprs.length)
+  val resolver: Resolver = SQLConf.get.resolver
+
+  /**
+   * Returns an updated list of expressions which will ultimately be used as the children argument
+   * for [[CreateNamedStruct]].
+   */
+  def apply(exprs: Seq[(String, Expression)]): Seq[(String, Expression)]
+}
+
+/**
+ * Add or replace a field by name.
+ */
+case class WithField(name: String, valExpr: Expression)
+  extends Unevaluable with StructFieldsOperation {
+
+  override def apply(exprs: Seq[(String, Expression)]): Seq[(String, Expression)] =
+    if (exprs.exists(x => resolver(x._1, name))) {
+      exprs.map {
+        case (existingName, _) if resolver(existingName, name) => (name, valExpr)
+        case x => x
+      }
+    } else {
+      exprs :+ (name, valExpr)
+    }
+
+  override def children: Seq[Expression] = valExpr :: Nil
+
+  override def dataType: DataType = throw new UnresolvedException(this, "dataType")
+
+  override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
+
+  override def prettyName: String = "WithField"
+}
+
+/**
+ * Drop a field by name.
+ */
+case class DropField(name: String) extends StructFieldsOperation {
+  override def apply(exprs: Seq[(String, Expression)]): Seq[(String, Expression)] =
+    exprs.filterNot(expr => resolver(expr._1, name))
+}
+
+/**
+ * Updates fields in struct by name.
+ */
+case class UpdateFields(structExpr: Expression, fieldOps: Seq[StructFieldsOperation])
+  extends Unevaluable {
 
   override def checkInputDataTypes(): TypeCheckResult = {
-    if (!structExpr.dataType.isInstanceOf[StructType]) {
-      TypeCheckResult.TypeCheckFailure(
-        "struct argument should be struct type, got: " + structExpr.dataType.catalogString)
+    val dataType = structExpr.dataType
+    if (!dataType.isInstanceOf[StructType]) {
+      TypeCheckResult.TypeCheckFailure("struct argument should be struct type, got: " +
+        dataType.catalogString)
+    } else if (newExprs.isEmpty) {
+      TypeCheckResult.TypeCheckFailure("cannot drop all fields in struct")
     } else {
       TypeCheckResult.TypeCheckSuccess
     }
   }
 
-  override def children: Seq[Expression] = structExpr +: valExprs
+  override def children: Seq[Expression] = structExpr +: fieldOps.collect {
+    case e: Expression => e
+  }
 
   override def dataType: StructType = evalExpr.dataType.asInstanceOf[StructType]
 
-  override def foldable: Boolean = structExpr.foldable && valExprs.forall(_.foldable)
-
   override def nullable: Boolean = structExpr.nullable
 
-  override def prettyName: String = "with_fields"
+  override def prettyName: String = "update_fields"
 
-  lazy val evalExpr: Expression = {
-    val existingExprs = structExpr.dataType.asInstanceOf[StructType].fieldNames.zipWithIndex.map {
-      case (name, i) => (name, GetStructField(KnownNotNull(structExpr), i).asInstanceOf[Expression])
+  private lazy val existingExprs: Seq[(String, Expression)] =
+    structExpr.dataType.asInstanceOf[StructType].fieldNames.zipWithIndex.map {
+      case (name, i) => (name, GetStructField(KnownNotNull(structExpr), i))
     }
 
-    val addOrReplaceExprs = names.zip(valExprs)
-
-    val resolver = SQLConf.get.resolver
-    val newExprs = addOrReplaceExprs.foldLeft(existingExprs) {
-      case (resultExprs, newExpr @ (newExprName, _)) =>
-        if (resultExprs.exists(x => resolver(x._1, newExprName))) {
-          resultExprs.map {
-            case (name, _) if resolver(name, newExprName) => newExpr
-            case x => x
-          }
-        } else {
-          resultExprs :+ newExpr
-        }
-    }.flatMap { case (name, expr) => Seq(Literal(name), expr) }
+  private lazy val newExprs = fieldOps.foldLeft(existingExprs)((exprs, op) => op(exprs))
 
-    val expr = CreateNamedStruct(newExprs)
-    if (structExpr.nullable) {
-      If(IsNull(structExpr), Literal(null, expr.dataType), expr)
-    } else {
-      expr
-    }
+  private lazy val createNamedStructExpr = CreateNamedStruct(newExprs.flatMap {
+    case (name, expr) => Seq(Literal(name), expr)
+  })
+
+  lazy val evalExpr: Expression = if (structExpr.nullable) {
+    If(IsNull(structExpr), Literal(null, createNamedStructExpr.dataType), createNamedStructExpr)
+  } else {
+    createNamedStructExpr
   }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala
index 1c33a2c7c3136..5d3e9302ccccc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala
@@ -39,17 +39,17 @@ object SimplifyExtractValueOps extends Rule[LogicalPlan] {
       // Remove redundant field extraction.
       case GetStructField(createNamedStruct: CreateNamedStruct, ordinal, _) =>
         createNamedStruct.valExprs(ordinal)
-      case GetStructField(w @ WithFields(struct, names, valExprs), ordinal, maybeName) =>
-        val name = w.dataType(ordinal).name
-        val matches = names.zip(valExprs).filter(_._1 == name)
+      case GetStructField(u: UpdateFields, ordinal, maybeName) =>
+        val name = u.dataType(ordinal).name
+        val matches = u.fieldOps.collect { case w: WithField if w.name == name => w }
         if (matches.nonEmpty) {
           // return last matching element as that is the final value for the field being extracted.
           // For example, if a user submits a query like this:
           // `$"struct_col".withField("b", lit(1)).withField("b", lit(2)).getField("b")`
           // we want to return `lit(2)` (and not `lit(1)`).
-          matches.last._2
+          matches.last.valExpr
         } else {
-          GetStructField(struct, ordinal, maybeName)
+          GetStructField(u.structExpr, ordinal, maybeName)
         }
       // Remove redundant array indexing.
       case GetArrayStructFields(CreateArray(elems, useStringTypeWhenEmpty), field, ordinal, _, _) =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index b8da954d938c4..ebf7d2dd5d1c0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -106,7 +106,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
         EliminateSerialization,
         RemoveRedundantAliases,
         RemoveNoopOperators,
-        CombineWithFields,
+        CombineUpdateFields,
         SimplifyExtractValueOps,
         CombineConcats) ++
         extendedOperatorOptimizationRules
@@ -215,8 +215,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
       RemoveNoopOperators) :+
     // This batch must be executed after the `RewriteSubquery` batch, which creates joins.
     Batch("NormalizeFloatingNumbers", Once, NormalizeFloatingNumbers) :+
-    Batch("ReplaceWithFieldsExpression", Once, ReplaceWithFieldsExpression)
-
+    Batch("ReplaceUpdateFieldsExpression", Once, ReplaceUpdateFieldsExpression)
     // remove any batches with no rules. this may happen when subclasses do not add optional rules.
     batches.filter(_.rules.nonEmpty)
   }
@@ -249,7 +248,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
       RewriteCorrelatedScalarSubquery.ruleName ::
       RewritePredicateSubquery.ruleName ::
       NormalizeFloatingNumbers.ruleName ::
-      ReplaceWithFieldsExpression.ruleName :: Nil
+      ReplaceUpdateFieldsExpression.ruleName :: Nil
 
   /**
    * Optimize all the subqueries inside expression.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/WithFields.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UpdateFields.scala
similarity index 68%
rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/WithFields.scala
rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UpdateFields.scala
index 05c90864e4bb0..c7154210e0c62 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/WithFields.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UpdateFields.scala
@@ -17,26 +17,26 @@
 
 package org.apache.spark.sql.catalyst.optimizer
 
-import org.apache.spark.sql.catalyst.expressions.WithFields
+import org.apache.spark.sql.catalyst.expressions.UpdateFields
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.rules.Rule
 
 /**
- * Combines all adjacent [[WithFields]] expression into a single [[WithFields]] expression.
+ * Combines all adjacent [[UpdateFields]] expressions into a single [[UpdateFields]] expression.
  */
-object CombineWithFields extends Rule[LogicalPlan] {
+object CombineUpdateFields extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
-    case WithFields(WithFields(struct, names1, valExprs1), names2, valExprs2) =>
-      WithFields(struct, names1 ++ names2, valExprs1 ++ valExprs2)
+    case UpdateFields(UpdateFields(struct, fieldOps1), fieldOps2) =>
+      UpdateFields(struct, fieldOps1 ++ fieldOps2)
   }
 }
 
 /**
- * Replaces [[WithFields]] expression with an evaluable expression.
+ * Replaces an [[UpdateFields]] expression with an evaluable expression.
  */
-object ReplaceWithFieldsExpression extends Rule[LogicalPlan] {
+object ReplaceUpdateFieldsExpression extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
-    case w: WithFields => w.evalExpr
+    case u: UpdateFields => u.evalExpr
   }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombineWithFieldsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombineUpdateFieldsSuite.scala
similarity index 65%
rename from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombineWithFieldsSuite.scala
rename to sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombineUpdateFieldsSuite.scala
index a3e0bbc57e639..ff9c60a2fa5bd 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombineWithFieldsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombineUpdateFieldsSuite.scala
@@ -19,56 +19,53 @@ package org.apache.spark.sql.catalyst.optimizer
 
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.expressions.{Alias, Literal, WithFields}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Literal, UpdateFields, WithField}
 import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
 
-class CombineWithFieldsSuite extends PlanTest {
+class CombineUpdateFieldsSuite extends PlanTest {
 
   object Optimize extends RuleExecutor[LogicalPlan] {
-    val batches = Batch("CombineWithFields", FixedPoint(10), CombineWithFields) :: Nil
+    val batches = Batch("CombineUpdateFields", FixedPoint(10), CombineUpdateFields) :: Nil
  }
 
   private val testRelation =
     LocalRelation('a.struct('a1.int))
 
-  test("combines two WithFields") {
+  test("combines two adjacent UpdateFields Expressions") {
     val originalQuery = testRelation
       .select(Alias(
-        WithFields(
-          WithFields(
-            'a,
-            Seq("b1"),
-            Seq(Literal(4))),
-          Seq("c1"),
-          Seq(Literal(5))), "out")())
+        UpdateFields(
+          UpdateFields(
+            'a,
+            WithField("b1", Literal(4)) :: Nil),
+          WithField("c1", Literal(5)) :: Nil), "out")())
 
     val optimized = Optimize.execute(originalQuery.analyze)
     val correctAnswer = testRelation
-      .select(Alias(WithFields('a, Seq("b1", "c1"), Seq(Literal(4), Literal(5))), "out")())
+      .select(Alias(UpdateFields('a, WithField("b1", Literal(4)) :: WithField("c1", Literal(5)) ::
+        Nil), "out")())
       .analyze
 
     comparePlans(optimized, correctAnswer)
   }
 
-  test("combines three WithFields") {
+  test("combines three adjacent UpdateFields Expressions") {
     val originalQuery = testRelation
       .select(Alias(
-        WithFields(
-          WithFields(
-            WithFields(
-              'a,
-              Seq("b1"),
-              Seq(Literal(4))),
-            Seq("c1"),
-            Seq(Literal(5))),
-          Seq("d1"),
-          Seq(Literal(6))), "out")())
+        UpdateFields(
+          UpdateFields(
+            UpdateFields(
+              'a,
+              WithField("b1", Literal(4)) :: Nil),
+            WithField("c1", Literal(5)) :: Nil),
+          WithField("d1", Literal(6)) :: Nil), "out")())
 
     val optimized = Optimize.execute(originalQuery.analyze)
     val correctAnswer = testRelation
-      .select(Alias(WithFields('a, Seq("b1", "c1", "d1"), Seq(4, 5, 6).map(Literal(_))), "out")())
+      .select(Alias(UpdateFields('a, WithField("b1", Literal(4)) :: WithField("c1", Literal(5)) ::
+        WithField("d1", Literal(6)) :: Nil), "out")())
      .analyze
 
     comparePlans(optimized, correctAnswer)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala
index c71e7dbe7d6f9..b2821044ce93f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala
@@ -453,49 +453,72 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
     checkEvaluation(GetMapValue(mb0, Literal(Array[Byte](3, 4))), null)
   }
 
-  private val structAttr = 'struct1.struct('a.int)
+  private val structAttr = 'struct1.struct('a.int, 'b.int)
   private val testStructRelation = LocalRelation(structAttr)
 
-  test("simplify GetStructField on WithFields that is not changing the attribute being extracted") {
-    val query = testStructRelation.select(
-      GetStructField(WithFields('struct1, Seq("b"), Seq(Literal(1))), 0, Some("a")) as "outerAtt")
-    val expected = testStructRelation.select(GetStructField('struct1, 0, Some("a")) as "outerAtt")
-    checkRule(query, expected)
+  test("simplify GetStructField on UpdateFields that is not modifying the attribute being " +
+    "extracted") {
+    // add attribute, extract an attribute from the original struct
+    val query1 = testStructRelation.select(GetStructField(UpdateFields('struct1,
+      WithField("b", Literal(1)) :: Nil), 0, None) as "outerAtt")
+    // drop attribute, extract an attribute from the original struct
+    val query2 = testStructRelation.select(GetStructField(UpdateFields('struct1, DropField("b") ::
+      Nil), 0, None) as "outerAtt")
+    // drop attribute, add attribute, extract an attribute from the original struct
+    val query3 = testStructRelation.select(GetStructField(UpdateFields('struct1, DropField("b") ::
+      WithField("c", Literal(2)) :: Nil), 0, None) as "outerAtt")
+    // drop attribute, add attribute, extract an attribute from the original struct
+    val query4 = testStructRelation.select(GetStructField(UpdateFields('struct1, DropField("a") ::
+      WithField("a", Literal(1)) :: Nil), 0, None) as "outerAtt")
+    val expected = testStructRelation.select(GetStructField('struct1, 0, None) as "outerAtt")
+
+    Seq(query1, query2, query3, query4).foreach {
+      query => checkRule(query, expected)
+    }
   }
 
-  test("simplify GetStructField on WithFields that is changing the attribute being extracted") {
-    val query = testStructRelation.select(
-      GetStructField(WithFields('struct1, Seq("b"), Seq(Literal(1))), 1, Some("b")) as "outerAtt")
+  test("simplify GetStructField on UpdateFields that is modifying the attribute being extracted") {
+    // add attribute, and then extract it
+    val query1 = testStructRelation.select(GetStructField(UpdateFields('struct1,
+      WithField("c", Literal(1)) :: Nil), 2, None) as "outerAtt")
+    // replace attribute, and then extract it
+    val query2 = testStructRelation.select(GetStructField(UpdateFields('struct1,
+      WithField("b", Literal(1)) :: Nil), 1, None) as "outerAtt")
+    // add attribute, replace the same attribute, and then extract it
+    val query3 = testStructRelation.select(GetStructField(UpdateFields('struct1,
+      WithField("c", Literal(2)) :: WithField("c", Literal(1)) :: Nil), 2, None) as "outerAtt")
+    // replace the same attribute twice, and then extract it
+    val query4 = testStructRelation.select(GetStructField(UpdateFields('struct1,
+      WithField("b", Literal(2)) :: WithField("b", Literal(1)) :: Nil), 1, None) as "outerAtt")
+    // replace attribute, drop another attribute, extract the replaced attribute
+    val query5 = testStructRelation.select(GetStructField(UpdateFields('struct1,
+      WithField("a", Literal(1)) :: DropField("b") :: Nil), 0, None) as "outerAtt")
+    // drop attribute, add attribute with same name, and then extract the added attribute
+    val query6 = testStructRelation.select(GetStructField(UpdateFields('struct1, DropField("a") ::
+      WithField("a", Literal(1)) :: Nil), 1, None) as "outerAtt")
     val expected = testStructRelation.select(Literal(1) as "outerAtt")
-    checkRule(query, expected)
-  }
 
-  test(
-    "simplify GetStructField on WithFields that is changing the attribute being extracted twice") {
-    val query = testStructRelation
-      .select(GetStructField(WithFields('struct1, Seq("b", "b"), Seq(Literal(1), Literal(2))), 1,
-        Some("b")) as "outerAtt")
-    val expected = testStructRelation.select(Literal(2) as "outerAtt")
-    checkRule(query, expected)
+    Seq(query1, query2, query3, query4, query5, query6).foreach {
+      query => checkRule(query, expected)
+    }
   }
 
-  test("collapse multiple GetStructField on the same WithFields") {
+  test("simplify multiple GetStructField on the same UpdateFields expression") {
     val query = testStructRelation
-      .select(WithFields('struct1, Seq("b"), Seq(Literal(2))) as "struct2")
+      .select(UpdateFields('struct1, WithField("b", Literal(2)) :: Nil) as "struct2")
       .select(
         GetStructField('struct2, 0, Some("a")) as "struct1A",
         GetStructField('struct2, 1, Some("b")) as "struct1B")
-    val expected = testStructRelation.select(
-      GetStructField('struct1, 0, Some("a")) as "struct1A",
-      Literal(2) as "struct1B")
+    val expected = testStructRelation
+      .select(GetStructField('struct1, 0, Some("a")) as "struct1A", Literal(2) as "struct1B")
     checkRule(query, expected)
   }
 
-  test("collapse multiple GetStructField on different WithFields") {
+  test("simplify multiple GetStructField on different UpdateFields expressions") {
     val query = testStructRelation
       .select(
-        WithFields('struct1, Seq("b"), Seq(Literal(2))) as "struct2",
-        WithFields('struct1, Seq("b"), Seq(Literal(3))) as "struct3")
+        UpdateFields('struct1, WithField("b", Literal(2)) :: Nil) as "struct2",
+        UpdateFields('struct1, WithField("b", Literal(3)) :: Nil) as "struct3")
       .select(
         GetStructField('struct2, 0, Some("a")) as "struct2A",
         GetStructField('struct2, 1, Some("b")) as "struct2B",
@@ -503,10 +526,8 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
         GetStructField('struct3, 1, Some("b")) as "struct3B")
     val expected = testStructRelation
       .select(
-        GetStructField('struct1, 0, Some("a")) as "struct2A",
-        Literal(2) as "struct2B",
-        GetStructField('struct1, 0, Some("a")) as "struct3A",
-        Literal(3) as "struct3B")
+        GetStructField('struct1, 0, Some("a")) as "struct2A", Literal(2) as "struct2B",
+        GetStructField('struct1, 0, Some("a")) as "struct3A", Literal(3) as "struct3B")
     checkRule(query, expected)
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index da542c67d9c51..0ee49e79ade7a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -906,34 +906,84 @@ class Column(val expr: Expression) extends Logging {
    */
   // scalastyle:on line.size.limit
   def withField(fieldName: String, col: Column): Column = withExpr {
-    require(fieldName != null, "fieldName cannot be null")
     require(col != null, "col cannot be null")
+    updateFieldsHelper(expr, nameParts(fieldName), name => WithField(name, col.expr))
+  }
+
+  // scalastyle:off line.size.limit
+  /**
+   * An expression that drops fields in `StructType` by name.
+   *
+   * {{{
+   *   val df = sql("SELECT named_struct('a', 1, 'b', 2) struct_col")
+   *   df.select($"struct_col".dropFields("b"))
+   *   // result: {"a":1}
+   *
+   *   val df = sql("SELECT named_struct('a', 1, 'b', 2) struct_col")
+   *   df.select($"struct_col".dropFields("c"))
+   *   // result: {"a":1,"b":2}
+   *
+   *   val df = sql("SELECT named_struct('a', 1, 'b', 2, 'c', 3) struct_col")
+   *   df.select($"struct_col".dropFields("b", "c"))
+   *   // result: {"a":1}
+   *
+   *   val df = sql("SELECT named_struct('a', 1, 'b', 2) struct_col")
+   *   df.select($"struct_col".dropFields("a", "b"))
+   *   // result: org.apache.spark.sql.AnalysisException: cannot resolve 'update_fields(update_fields(`struct_col`))' due to data type mismatch: cannot drop all fields in struct
+   *
+   *   val df = sql("SELECT CAST(NULL AS struct<a:int,b:int>) struct_col")
+   *   df.select($"struct_col".dropFields("b"))
+   *   // result: null of type struct<a:int>
+   *
+   *   val df = sql("SELECT named_struct('a', 1, 'b', 2, 'b', 3) struct_col")
+   *   df.select($"struct_col".dropFields("b"))
+   *   // result: {"a":1}
+   *
+   *   val df = sql("SELECT named_struct('a', named_struct('a', 1, 'b', 2)) struct_col")
+   *   df.select($"struct_col".dropFields("a.b"))
+   *   // result: {"a":{"a":1}}
+   *
+   *   val df = sql("SELECT named_struct('a', named_struct('b', 1), 'a', named_struct('c', 2)) struct_col")
+   *   df.select($"struct_col".dropFields("a.c"))
+   *   // result: org.apache.spark.sql.AnalysisException: Ambiguous reference to fields
+   * }}}
+   *
+   * @group expr_ops
+   * @since 3.1.0
+   */
+  // scalastyle:on line.size.limit
+  def dropFields(fieldNames: String*): Column = withExpr {
+    def dropField(expr: Expression, fieldName: String): UpdateFields =
+      updateFieldsHelper(expr, nameParts(fieldName), name => DropField(name))
+
+    fieldNames.tail.foldLeft(dropField(expr, fieldNames.head)) {
+      (resExpr, fieldName) => dropField(resExpr, fieldName)
+    }
+  }
+
+  private def nameParts(fieldName: String): Seq[String] = {
+    require(fieldName != null, "fieldName cannot be null")
 
-    val nameParts = if (fieldName.isEmpty) {
+    if (fieldName.isEmpty) {
       fieldName :: Nil
     } else {
       CatalystSqlParser.parseMultipartIdentifier(fieldName)
     }
-    withFieldHelper(expr, nameParts, Nil, col.expr)
   }
 
-  private def withFieldHelper(
+  private def updateFieldsHelper(
       struct: Expression,
       namePartsRemaining: Seq[String],
-      namePartsDone: Seq[String],
-      value: Expression) : WithFields = {
-    val name = namePartsRemaining.head
+      value: String => StructFieldsOperation): UpdateFields = {
+    val fieldName = namePartsRemaining.head
     if (namePartsRemaining.length == 1) {
-      WithFields(struct, name :: Nil, value :: Nil)
+      UpdateFields(struct, value(fieldName) :: Nil)
     } else {
-      val newNamesRemaining = namePartsRemaining.tail
-      val newNamesDone = namePartsDone :+ name
-      val newValue = withFieldHelper(
-        struct = UnresolvedExtractValue(struct, Literal(name)),
-        namePartsRemaining = newNamesRemaining,
-        namePartsDone = newNamesDone,
+      val newValue = updateFieldsHelper(
+        struct = UnresolvedExtractValue(struct, Literal(fieldName)),
+        namePartsRemaining = namePartsRemaining.tail,
         value = value)
-      WithFields(struct, name :: Nil, newValue :: Nil)
+      UpdateFields(struct, WithField(fieldName, newValue) :: Nil)
     }
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index 761632e76b165..9d9625c89c383 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -984,7 +984,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
 
     intercept[IllegalArgumentException] {
       structLevel1.withColumn("a", $"a".withField(null, null))
-    }.getMessage should include("fieldName cannot be null")
+    }.getMessage should include("col cannot be null")
   }
 
   test("withField should throw an exception if any intermediate structs don't exist") {
@@ -1420,4 +1420,311 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
     }.getMessage should include("No such struct field b in a, B")
     }
   }
+
+  test("dropFields should throw an exception if called on a non-StructType column") {
+    intercept[AnalysisException] {
+      testData.withColumn("key", $"key".dropFields("a"))
+    }.getMessage should include("struct argument should be struct type, got: int")
+  }
+
+  test("dropFields should throw an exception if fieldName argument is null") {
+    intercept[IllegalArgumentException] {
+      structLevel1.withColumn("a", $"a".dropFields(null))
+    }.getMessage should include("fieldName cannot be null")
+  }
+
+  test("dropFields should throw an exception if any intermediate structs don't exist") {
+    intercept[AnalysisException] {
+      structLevel2.withColumn("a", 'a.dropFields("x.b"))
+    }.getMessage should include("No such struct field x in a")
+
+    intercept[AnalysisException] {
+      structLevel3.withColumn("a", 'a.dropFields("a.x.b"))
+    }.getMessage should include("No such struct field x in a")
+  }
+
+  test("dropFields should throw an exception if intermediate field is not a struct") {
+    intercept[AnalysisException] {
+      structLevel1.withColumn("a", 'a.dropFields("b.a"))
+    }.getMessage should include("struct argument should be struct type, got: int")
+  }
+
+  test("dropFields should throw an exception if intermediate field reference is ambiguous") {
+    intercept[AnalysisException] {
+      val structLevel2: DataFrame = spark.createDataFrame(
+        sparkContext.parallelize(Row(Row(Row(1, null, 3), 4)) :: Nil),
+        StructType(Seq(
+          StructField("a", StructType(Seq(
+            StructField("a", structType, nullable = false),
+            StructField("a", structType, nullable = false))),
+            nullable = false))))
+
+      structLevel2.withColumn("a", 'a.dropFields("a.b"))
+    }.getMessage should include("Ambiguous reference to fields")
+  }
+
+  test("dropFields should drop field in struct") {
+    checkAnswerAndSchema(
+      structLevel1.withColumn("a", 'a.dropFields("b")),
+      Row(Row(1, 3)) :: Nil,
+      StructType(Seq(
+        StructField("a", StructType(Seq(
+          StructField("a", IntegerType, nullable = false),
+          StructField("c", IntegerType, nullable = false))),
+          nullable = false))))
+  }
+
+  test("dropFields should drop field in null struct") {
+    checkAnswerAndSchema(
+      nullStructLevel1.withColumn("a", $"a".dropFields("b")),
+      Row(null) :: Nil,
+      StructType(Seq(
+        StructField("a", StructType(Seq(
+          StructField("a", IntegerType, nullable = false),
+          StructField("c", IntegerType, nullable = false))),
+          nullable = true))))
+  }
+
+  test("dropFields should drop multiple fields in struct") {
+    Seq(
+      structLevel1.withColumn("a", $"a".dropFields("b", "c")),
+      structLevel1.withColumn("a", 'a.dropFields("b").dropFields("c"))
+    ).foreach { df =>
+      checkAnswerAndSchema(
+        df,
+        Row(Row(1)) :: Nil,
+        StructType(Seq(
+          StructField("a", StructType(Seq(
+            StructField("a", IntegerType, nullable = false))),
+            nullable = false))))
+    }
+  }
+
+  test("dropFields should throw an exception if no fields will be left in struct") {
+    intercept[AnalysisException] {
+      structLevel1.withColumn("a", 'a.dropFields("a", "b", "c"))
+    }.getMessage should include("cannot drop all fields in struct")
+  }
+
+  test("dropFields should drop field in nested struct") {
+    checkAnswerAndSchema(
+      structLevel2.withColumn("a", 'a.dropFields("a.b")),
+      Row(Row(Row(1, 3))) :: Nil,
+      StructType(
+        Seq(StructField("a", StructType(Seq(
+          StructField("a", StructType(Seq(
+            StructField("a", IntegerType, nullable = false),
+            StructField("c", IntegerType, nullable = false))),
+            nullable = false))),
+          nullable = false))))
+  }
+
+  test("dropFields should drop multiple fields in nested struct") {
+    checkAnswerAndSchema(
+      structLevel2.withColumn("a", 'a.dropFields("a.b", "a.c")),
+      Row(Row(Row(1))) :: Nil,
+      StructType(
+        Seq(StructField("a", StructType(Seq(
+          StructField("a", StructType(Seq(
+            StructField("a", IntegerType, nullable = false))),
+            nullable = false))),
+          nullable = false))))
+  }
+
+  test("dropFields should drop field in nested null struct") {
+    checkAnswerAndSchema(
+      nullStructLevel2.withColumn("a", $"a".dropFields("a.b")),
+      Row(Row(null)) :: Nil,
+      StructType(
+        Seq(StructField("a", StructType(Seq(
+          StructField("a", StructType(Seq(
+            StructField("a", IntegerType, nullable = false),
+            StructField("c", IntegerType, nullable = false))),
+            nullable = true))),
+          nullable = false))))
+  }
+
+  test("dropFields should drop multiple fields in nested null struct") {
+    checkAnswerAndSchema(
+      nullStructLevel2.withColumn("a", $"a".dropFields("a.b", "a.c")),
+      Row(Row(null)) :: Nil,
+      StructType(
+        Seq(StructField("a", StructType(Seq(
+          StructField("a", StructType(Seq(
+            StructField("a", IntegerType, nullable = false))),
+            nullable = true))),
+          nullable = false))))
+  }
+
+  test("dropFields should drop field in deeply nested struct") {
+    checkAnswerAndSchema(
+      structLevel3.withColumn("a", 'a.dropFields("a.a.b")),
+      Row(Row(Row(Row(1, 3)))) :: Nil,
+      StructType(Seq(
+        StructField("a", StructType(Seq(
+          StructField("a", StructType(Seq(
StructField("a", StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("c", IntegerType, nullable = false))), + nullable = false))), + nullable = false))), + nullable = false)))) + } + + test("dropFields should drop all fields with given name in struct") { + val structLevel1 = spark.createDataFrame( + sparkContext.parallelize(Row(Row(1, 2, 3)) :: Nil), + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("b", IntegerType, nullable = false), + StructField("b", IntegerType, nullable = false))), + nullable = false)))) + + checkAnswerAndSchema( + structLevel1.withColumn("a", 'a.dropFields("b")), + Row(Row(1)) :: Nil, + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", IntegerType, nullable = false))), + nullable = false)))) + } + + test("dropFields should drop field in struct even if casing is different") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { + checkAnswerAndSchema( + mixedCaseStructLevel1.withColumn("a", 'a.dropFields("A")), + Row(Row(1)) :: Nil, + StructType(Seq( + StructField("a", StructType(Seq( + StructField("B", IntegerType, nullable = false))), + nullable = false)))) + + checkAnswerAndSchema( + mixedCaseStructLevel1.withColumn("a", 'a.dropFields("b")), + Row(Row(1)) :: Nil, + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", IntegerType, nullable = false))), + nullable = false)))) + } + } + + test("dropFields should not drop field in struct because casing is different") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + checkAnswerAndSchema( + mixedCaseStructLevel1.withColumn("a", 'a.dropFields("A")), + Row(Row(1, 1)) :: Nil, + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("B", IntegerType, nullable = false))), + nullable = false)))) + + checkAnswerAndSchema( + mixedCaseStructLevel1.withColumn("a", 'a.dropFields("b")), + Row(Row(1, 1)) :: Nil, + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("B", IntegerType, nullable = false))), + nullable = false)))) + } + } + + test("dropFields should drop nested field in struct even if casing is different") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { + checkAnswerAndSchema( + mixedCaseStructLevel2.withColumn("a", 'a.dropFields("A.a")), + Row(Row(Row(1), Row(1, 1))) :: Nil, + StructType(Seq( + StructField("a", StructType(Seq( + StructField("A", StructType(Seq( + StructField("b", IntegerType, nullable = false))), + nullable = false), + StructField("B", StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("b", IntegerType, nullable = false))), + nullable = false))), + nullable = false)))) + + checkAnswerAndSchema( + mixedCaseStructLevel2.withColumn("a", 'a.dropFields("b.a")), + Row(Row(Row(1, 1), Row(1))) :: Nil, + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("b", IntegerType, nullable = false))), + nullable = false), + StructField("b", StructType(Seq( + StructField("b", IntegerType, nullable = false))), + nullable = false))), + nullable = false)))) + } + } + + test("dropFields should throw an exception because casing is different") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + intercept[AnalysisException] { + mixedCaseStructLevel2.withColumn("a", 'a.dropFields("A.a")) + }.getMessage should include("No 
+      }.getMessage should include("No such struct field A in a, B")
+
+      intercept[AnalysisException] {
+        mixedCaseStructLevel2.withColumn("a", 'a.dropFields("b.a"))
+      }.getMessage should include("No such struct field b in a, B")
+    }
+  }
+
+  test("dropFields should drop only fields that exist") {
+    checkAnswerAndSchema(
+      structLevel1.withColumn("a", 'a.dropFields("d")),
+      Row(Row(1, null, 3)) :: Nil,
+      StructType(Seq(
+        StructField("a", StructType(Seq(
+          StructField("a", IntegerType, nullable = false),
+          StructField("b", IntegerType, nullable = true),
+          StructField("c", IntegerType, nullable = false))),
+          nullable = false))))
+
+    checkAnswerAndSchema(
+      structLevel1.withColumn("a", 'a.dropFields("b", "d")),
+      Row(Row(1, 3)) :: Nil,
+      StructType(Seq(
+        StructField("a", StructType(Seq(
+          StructField("a", IntegerType, nullable = false),
+          StructField("c", IntegerType, nullable = false))),
+          nullable = false))))
+
+    checkAnswerAndSchema(
+      structLevel2.withColumn("a", $"a".dropFields("a.b", "a.d")),
+      Row(Row(Row(1, 3))) :: Nil,
+      StructType(
+        Seq(StructField("a", StructType(Seq(
+          StructField("a", StructType(Seq(
+            StructField("a", IntegerType, nullable = false),
+            StructField("c", IntegerType, nullable = false))),
+            nullable = false))),
+          nullable = false))))
+  }
+
+  test("dropFields should drop multiple fields at arbitrary levels of nesting in a single call") {
+    val df: DataFrame = spark.createDataFrame(
+      sparkContext.parallelize(Row(Row(Row(1, null, 3), 4)) :: Nil),
+      StructType(Seq(
+        StructField("a", StructType(Seq(
+          StructField("a", structType, nullable = false),
+          StructField("b", IntegerType, nullable = false))),
+          nullable = false))))
+
+    checkAnswerAndSchema(
+      df.withColumn("a", $"a".dropFields("a.b", "b")),
+      Row(Row(Row(1, 3))) :: Nil,
+      StructType(Seq(
+        StructField("a", StructType(Seq(
+          StructField("a", StructType(Seq(
+            StructField("a", IntegerType, nullable = false),
+            StructField("c", IntegerType, nullable = false))), nullable = false))),
+          nullable = false))))
+  }
 }

From 948fc9cb7cd64a3cc69e297717bafea690ea0056 Mon Sep 17 00:00:00 2001
From: "fqaiser94@gmail.com" <fqaiser94@gmail.com>
Date: Tue, 4 Aug 2020 15:35:20 -0400
Subject: [PATCH 2/2] add dropFields user-facing examples

---
 .../spark/sql/ColumnExpressionSuite.scala     | 44 ++++++++++++++++++-
 1 file changed, 43 insertions(+), 1 deletion(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index c0f847ba56648..8c9e36e289021 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -1452,7 +1452,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
         .select($"struct_col".withField("a.c", lit(3)))
     }.getMessage should include("Ambiguous reference to fields")
   }
-
+
   test("dropFields should throw an exception if called on a non-StructType column") {
     intercept[AnalysisException] {
       testData.withColumn("key", $"key".dropFields("a"))
    }.getMessage should include("struct argument should be struct type, got: int")
@@ -1759,4 +1759,46 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
             StructField("c", IntegerType, nullable = false))),
           nullable = false))),
         nullable = false))))
   }
+
+  test("dropFields user-facing examples") {
+    checkAnswer(
+      sql("SELECT named_struct('a', 1, 'b', 2) struct_col")
+        .select($"struct_col".dropFields("b")),
+      Row(Row(1)))
+
+    checkAnswer(
+      sql("SELECT named_struct('a', 1, 'b', 2) struct_col")
+        .select($"struct_col".dropFields("c")),
+      Row(Row(1, 2)))
+
+    checkAnswer(
sql("SELECT named_struct('a', 1, 'b', 2, 'c', 3) struct_col") + .select($"struct_col".dropFields("b", "c")), + Row(Row(1))) + + intercept[AnalysisException] { + sql("SELECT named_struct('a', 1, 'b', 2) struct_col") + .select($"struct_col".dropFields("a", "b")) + }.getMessage should include("cannot drop all fields in struct") + + checkAnswer( + sql("SELECT CAST(NULL AS struct) struct_col") + .select($"struct_col".dropFields("b")), + Row(null)) + + checkAnswer( + sql("SELECT named_struct('a', 1, 'b', 2, 'b', 3) struct_col") + .select($"struct_col".dropFields("b")), + Row(Row(1))) + + checkAnswer( + sql("SELECT named_struct('a', named_struct('a', 1, 'b', 2)) struct_col") + .select($"struct_col".dropFields("a.b")), + Row(Row(Row(1)))) + + intercept[AnalysisException] { + sql("SELECT named_struct('a', named_struct('b', 1), 'a', named_struct('c', 2)) struct_col") + .select($"struct_col".dropFields("a.c")) + }.getMessage should include("Ambiguous reference to fields") + } }