Skip to content

Commit c59988a

Browse files
committed
[SPARK-34638][SQL] Single field nested column prune on generator output
### What changes were proposed in this pull request? This patch proposes an improvement on nested column pruning if the pruning target is generator's output. Previously we disallow such case. This patch allows to prune on it if there is only one single nested column is accessed after `Generate`. E.g., `df.select(explode($"items").as('item)).select($"item.itemId")`. As we only need `itemId` from `item`, we can prune other fields out and only keep `itemId`. In this patch, we only address explode-like generators. We will address other generators in followups. ### Why are the changes needed? This helps to extend the availability of nested column pruning. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Unit test Closes #31966 from viirya/SPARK-34638. Authored-by: Liang-Chi Hsieh <viirya@gmail.com> Signed-off-by: Liang-Chi Hsieh <viirya@gmail.com>
1 parent 1db031f commit c59988a

File tree

4 files changed

+177
-8
lines changed

4 files changed

+177
-8
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,8 @@ case class ProjectionOverSchema(schema: StructType) {
4242
getProjection(a.child).map(p => (p, p.dataType)).map {
4343
case (projection, ArrayType(projSchema @ StructType(_), _)) =>
4444
// For case-sensitivity aware field resolution, we should take `ordinal` which
45-
// points to correct struct field.
45+
// points to correct struct field, because `ExtractValue` actually does column
46+
// name resolving correctly.
4647
val selectedField = a.child.dataType.asInstanceOf[ArrayType]
4748
.elementType.asInstanceOf[StructType](a.ordinal)
4849
val prunedField = projSchema(selectedField.name)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala

Lines changed: 94 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,27 @@ object NestedColumnAliasing {
231231
* of it.
232232
*/
233233
object GeneratorNestedColumnAliasing {
234+
// Partitions `attrToAliases` based on whether the attribute is in Generator's output.
235+
private def aliasesOnGeneratorOutput(
236+
attrToAliases: Map[ExprId, Seq[Alias]],
237+
generatorOutput: Seq[Attribute]) = {
238+
val generatorOutputExprId = generatorOutput.map(_.exprId)
239+
attrToAliases.partition { k =>
240+
generatorOutputExprId.contains(k._1)
241+
}
242+
}
243+
244+
// Partitions `nestedFieldToAlias` based on whether the attribute of nested field extractor
245+
// is in Generator's output.
246+
private def nestedFieldOnGeneratorOutput(
247+
nestedFieldToAlias: Map[ExtractValue, Alias],
248+
generatorOutput: Seq[Attribute]) = {
249+
val generatorOutputSet = AttributeSet(generatorOutput)
250+
nestedFieldToAlias.partition { pair =>
251+
pair._1.references.subsetOf(generatorOutputSet)
252+
}
253+
}
254+
234255
def unapply(plan: LogicalPlan): Option[LogicalPlan] = plan match {
235256
// Either `nestedPruningOnExpressions` or `nestedSchemaPruningEnabled` is enabled, we
236257
// need to prune nested columns through Project and under Generate. The difference is
@@ -241,12 +262,81 @@ object GeneratorNestedColumnAliasing {
241262
// On top on `Generate`, a `Project` that might have nested column accessors.
242263
// We try to get alias maps for both project list and generator's children expressions.
243264
val exprsToPrune = projectList ++ g.generator.children
244-
NestedColumnAliasing.getAliasSubMap(exprsToPrune, g.qualifiedGeneratorOutput).map {
265+
NestedColumnAliasing.getAliasSubMap(exprsToPrune).map {
245266
case (nestedFieldToAlias, attrToAliases) =>
267+
val (nestedFieldsOnGenerator, nestedFieldsNotOnGenerator) =
268+
nestedFieldOnGeneratorOutput(nestedFieldToAlias, g.qualifiedGeneratorOutput)
269+
val (attrToAliasesOnGenerator, attrToAliasesNotOnGenerator) =
270+
aliasesOnGeneratorOutput(attrToAliases, g.qualifiedGeneratorOutput)
271+
272+
// Push nested column accessors through `Generator`.
246273
// Defer updating `Generate.unrequiredChildIndex` to next round of `ColumnPruning`.
247-
val newChild =
248-
NestedColumnAliasing.replaceWithAliases(g, nestedFieldToAlias, attrToAliases)
249-
Project(NestedColumnAliasing.getNewProjectList(projectList, nestedFieldToAlias), newChild)
274+
val newChild = NestedColumnAliasing.replaceWithAliases(g,
275+
nestedFieldsNotOnGenerator, attrToAliasesNotOnGenerator)
276+
val pushedThrough = Project(NestedColumnAliasing
277+
.getNewProjectList(projectList, nestedFieldsNotOnGenerator), newChild)
278+
279+
// If the generator output is `ArrayType`, we cannot push through the extractor.
280+
// It is because we don't allow field extractor on two-level array,
281+
// i.e., attr.field when attr is a ArrayType(ArrayType(...)).
282+
// Similarily, we also cannot push through if the child of generator is `MapType`.
283+
g.generator.children.head.dataType match {
284+
case _: MapType => return Some(pushedThrough)
285+
case ArrayType(_: ArrayType, _) => return Some(pushedThrough)
286+
case _ =>
287+
}
288+
289+
// Pruning on `Generator`'s output. We only process single field case.
290+
// For multiple field case, we cannot directly move field extractor into
291+
// the generator expression. A workaround is to re-construct array of struct
292+
// from multiple fields. But it will be more complicated and may not worth.
293+
// TODO(SPARK-34956): support multiple fields.
294+
if (nestedFieldsOnGenerator.size > 1 || nestedFieldsOnGenerator.isEmpty) {
295+
pushedThrough
296+
} else {
297+
// Only one nested column accessor.
298+
// E.g., df.select(explode($"items").as("item")).select($"item.a")
299+
pushedThrough match {
300+
case p @ Project(_, newG: Generate) =>
301+
// Replace the child expression of `ExplodeBase` generator with
302+
// nested column accessor.
303+
// E.g., df.select(explode($"items").as("item")).select($"item.a") =>
304+
// df.select(explode($"items.a").as("item.a"))
305+
val rewrittenG = newG.transformExpressions {
306+
case e: ExplodeBase =>
307+
val extractor = nestedFieldsOnGenerator.head._1.transformUp {
308+
case _: Attribute =>
309+
e.child
310+
case g: GetStructField =>
311+
ExtractValue(g.child, Literal(g.extractFieldName), SQLConf.get.resolver)
312+
}
313+
e.withNewChildren(Seq(extractor))
314+
}
315+
316+
// As we change the child of the generator, its output data type must be updated.
317+
val updatedGeneratorOutput = rewrittenG.generatorOutput
318+
.zip(rewrittenG.generator.elementSchema.toAttributes)
319+
.map { case (oldAttr, newAttr) =>
320+
newAttr.withExprId(oldAttr.exprId).withName(oldAttr.name)
321+
}
322+
assert(updatedGeneratorOutput.length == rewrittenG.generatorOutput.length,
323+
"Updated generator output must have the same length " +
324+
"with original generator output.")
325+
val updatedGenerate = rewrittenG.copy(generatorOutput = updatedGeneratorOutput)
326+
327+
// Replace nested column accessor with generator output.
328+
p.withNewChildren(Seq(updatedGenerate)).transformExpressions {
329+
case f: ExtractValue if nestedFieldsOnGenerator.contains(f) =>
330+
updatedGenerate.output
331+
.find(a => attrToAliasesOnGenerator.contains(a.exprId))
332+
.getOrElse(f)
333+
}
334+
335+
case other =>
336+
// We should not reach here.
337+
throw new IllegalStateException(s"Unreasonable plan after optimization: $other")
338+
}
339+
}
250340
}
251341

252342
case g: Generate if SQLConf.get.nestedSchemaPruningEnabled &&

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasingSuite.scala

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -329,14 +329,14 @@ class NestedColumnAliasingSuite extends SchemaPruningTest {
329329
comparePlans(optimized, expected)
330330
}
331331

332-
test("Nested field pruning for Project and Generate: not prune on generator output") {
332+
test("Nested field pruning for Project and Generate: multiple-field case is not supported") {
333333
val companies = LocalRelation(
334334
'id.int,
335335
'employers.array(employer))
336336

337337
val query = companies
338338
.generate(Explode('employers.getField("company")), outputNames = Seq("company"))
339-
.select('company.getField("name"))
339+
.select('company.getField("name"), 'company.getField("address"))
340340
.analyze
341341
val optimized = Optimize.execute(query)
342342

@@ -347,7 +347,8 @@ class NestedColumnAliasingSuite extends SchemaPruningTest {
347347
.generate(Explode($"${aliases(0)}"),
348348
unrequiredChildIndex = Seq(0),
349349
outputNames = Seq("company"))
350-
.select('company.getField("name").as("company.name"))
350+
.select('company.getField("name").as("company.name"),
351+
'company.getField("address").as("company.address"))
351352
.analyze
352353
comparePlans(optimized, expected)
353354
}
@@ -684,6 +685,29 @@ class NestedColumnAliasingSuite extends SchemaPruningTest {
684685
).analyze
685686
comparePlans(optimized2, expected2)
686687
}
688+
689+
test("SPARK-34638: nested column prune on generator output for one field") {
690+
val companies = LocalRelation(
691+
'id.int,
692+
'employers.array(employer))
693+
694+
val query = companies
695+
.generate(Explode('employers.getField("company")), outputNames = Seq("company"))
696+
.select('company.getField("name"))
697+
.analyze
698+
val optimized = Optimize.execute(query)
699+
700+
val aliases = collectGeneratedAliases(optimized)
701+
702+
val expected = companies
703+
.select('employers.getField("company").getField("name").as(aliases(0)))
704+
.generate(Explode($"${aliases(0)}"),
705+
unrequiredChildIndex = Seq(0),
706+
outputNames = Seq("company"))
707+
.select('company.as("company.name"))
708+
.analyze
709+
comparePlans(optimized, expected)
710+
}
687711
}
688712

689713
object NestedColumnAliasingSuite {

sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -351,6 +351,43 @@ abstract class SchemaPruningSuite
351351
}
352352
}
353353

354+
testSchemaPruning("SPARK-34638: nested column prune on generator output") {
355+
val query1 = spark.table("contacts")
356+
.select(explode(col("friends")).as("friend"))
357+
.select("friend.first")
358+
checkScan(query1, "struct<friends:array<struct<first:string>>>")
359+
checkAnswer(query1, Row("Susan") :: Nil)
360+
361+
// Currently we don't prune multiple field case.
362+
val query2 = spark.table("contacts")
363+
.select(explode(col("friends")).as("friend"))
364+
.select("friend.first", "friend.middle")
365+
checkScan(query2, "struct<friends:array<struct<first:string,middle:string,last:string>>>")
366+
checkAnswer(query2, Row("Susan", "Z.") :: Nil)
367+
368+
val query3 = spark.table("contacts")
369+
.select(explode(col("friends")).as("friend"))
370+
.select("friend.first", "friend.middle", "friend")
371+
checkScan(query3, "struct<friends:array<struct<first:string,middle:string,last:string>>>")
372+
checkAnswer(query3, Row("Susan", "Z.", Row("Susan", "Z.", "Smith")) :: Nil)
373+
}
374+
375+
testSchemaPruning("SPARK-34638: nested column prune on generator output - case-sensitivity") {
376+
withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") {
377+
val query1 = spark.table("contacts")
378+
.select(explode(col("friends")).as("friend"))
379+
.select("friend.First")
380+
checkScan(query1, "struct<friends:array<struct<first:string>>>")
381+
checkAnswer(query1, Row("Susan") :: Nil)
382+
383+
val query2 = spark.table("contacts")
384+
.select(explode(col("friends")).as("friend"))
385+
.select("friend.MIDDLE")
386+
checkScan(query2, "struct<friends:array<struct<middle:string>>>")
387+
checkAnswer(query2, Row("Z.") :: Nil)
388+
}
389+
}
390+
354391
testSchemaPruning("select one deep nested complex field after repartition") {
355392
val query = sql("select * from contacts")
356393
.repartition(100)
@@ -816,4 +853,21 @@ abstract class SchemaPruningSuite
816853
Row("John", "Y.") :: Nil)
817854
}
818855
}
856+
857+
test("SPARK-34638: queries should not fail on unsupported cases") {
858+
withTable("nested_array") {
859+
sql("select * from values array(array(named_struct('a', 1, 'b', 3), " +
860+
"named_struct('a', 2, 'b', 4))) T(items)").write.saveAsTable("nested_array")
861+
val query = sql("select d.a from (select explode(c) d from " +
862+
"(select explode(items) c from nested_array))")
863+
checkAnswer(query, Row(1) :: Row(2) :: Nil)
864+
}
865+
866+
withTable("map") {
867+
sql("select * from values map(1, named_struct('a', 1, 'b', 3), " +
868+
"2, named_struct('a', 2, 'b', 4)) T(items)").write.saveAsTable("map")
869+
val query = sql("select d.a from (select explode(items) (c, d) from map)")
870+
checkAnswer(query, Row(1) :: Row(2) :: Nil)
871+
}
872+
}
819873
}

0 commit comments

Comments
 (0)