@@ -231,6 +231,27 @@ object NestedColumnAliasing {
231
231
* of it.
232
232
*/
233
233
object GeneratorNestedColumnAliasing {
234
+ // Partitions `attrToAliases` based on whether the attribute is in Generator's output.
235
+ private def aliasesOnGeneratorOutput (
236
+ attrToAliases : Map [ExprId , Seq [Alias ]],
237
+ generatorOutput : Seq [Attribute ]) = {
238
+ val generatorOutputExprId = generatorOutput.map(_.exprId)
239
+ attrToAliases.partition { k =>
240
+ generatorOutputExprId.contains(k._1)
241
+ }
242
+ }
243
+
244
+ // Partitions `nestedFieldToAlias` based on whether the attribute of nested field extractor
245
+ // is in Generator's output.
246
+ private def nestedFieldOnGeneratorOutput (
247
+ nestedFieldToAlias : Map [ExtractValue , Alias ],
248
+ generatorOutput : Seq [Attribute ]) = {
249
+ val generatorOutputSet = AttributeSet (generatorOutput)
250
+ nestedFieldToAlias.partition { pair =>
251
+ pair._1.references.subsetOf(generatorOutputSet)
252
+ }
253
+ }
254
+
234
255
def unapply (plan : LogicalPlan ): Option [LogicalPlan ] = plan match {
235
256
// Either `nestedPruningOnExpressions` or `nestedSchemaPruningEnabled` is enabled, we
236
257
// need to prune nested columns through Project and under Generate. The difference is
@@ -241,12 +262,81 @@ object GeneratorNestedColumnAliasing {
241
262
// On top on `Generate`, a `Project` that might have nested column accessors.
242
263
// We try to get alias maps for both project list and generator's children expressions.
243
264
val exprsToPrune = projectList ++ g.generator.children
244
- NestedColumnAliasing .getAliasSubMap(exprsToPrune, g.qualifiedGeneratorOutput ).map {
265
+ NestedColumnAliasing .getAliasSubMap(exprsToPrune).map {
245
266
case (nestedFieldToAlias, attrToAliases) =>
267
+ val (nestedFieldsOnGenerator, nestedFieldsNotOnGenerator) =
268
+ nestedFieldOnGeneratorOutput(nestedFieldToAlias, g.qualifiedGeneratorOutput)
269
+ val (attrToAliasesOnGenerator, attrToAliasesNotOnGenerator) =
270
+ aliasesOnGeneratorOutput(attrToAliases, g.qualifiedGeneratorOutput)
271
+
272
+ // Push nested column accessors through `Generator`.
246
273
// Defer updating `Generate.unrequiredChildIndex` to next round of `ColumnPruning`.
247
- val newChild =
248
- NestedColumnAliasing .replaceWithAliases(g, nestedFieldToAlias, attrToAliases)
249
- Project (NestedColumnAliasing .getNewProjectList(projectList, nestedFieldToAlias), newChild)
274
+ val newChild = NestedColumnAliasing .replaceWithAliases(g,
275
+ nestedFieldsNotOnGenerator, attrToAliasesNotOnGenerator)
276
+ val pushedThrough = Project (NestedColumnAliasing
277
+ .getNewProjectList(projectList, nestedFieldsNotOnGenerator), newChild)
278
+
279
+ // If the generator output is `ArrayType`, we cannot push through the extractor.
280
+ // It is because we don't allow field extractor on two-level array,
281
+ // i.e., attr.field when attr is a ArrayType(ArrayType(...)).
282
+ // Similarily, we also cannot push through if the child of generator is `MapType`.
283
+ g.generator.children.head.dataType match {
284
+ case _ : MapType => return Some (pushedThrough)
285
+ case ArrayType (_ : ArrayType , _) => return Some (pushedThrough)
286
+ case _ =>
287
+ }
288
+
289
+ // Pruning on `Generator`'s output. We only process single field case.
290
+ // For multiple field case, we cannot directly move field extractor into
291
+ // the generator expression. A workaround is to re-construct array of struct
292
+ // from multiple fields. But it will be more complicated and may not worth.
293
+ // TODO(SPARK-34956): support multiple fields.
294
+ if (nestedFieldsOnGenerator.size > 1 || nestedFieldsOnGenerator.isEmpty) {
295
+ pushedThrough
296
+ } else {
297
+ // Only one nested column accessor.
298
+ // E.g., df.select(explode($"items").as("item")).select($"item.a")
299
+ pushedThrough match {
300
+ case p @ Project (_, newG : Generate ) =>
301
+ // Replace the child expression of `ExplodeBase` generator with
302
+ // nested column accessor.
303
+ // E.g., df.select(explode($"items").as("item")).select($"item.a") =>
304
+ // df.select(explode($"items.a").as("item.a"))
305
+ val rewrittenG = newG.transformExpressions {
306
+ case e : ExplodeBase =>
307
+ val extractor = nestedFieldsOnGenerator.head._1.transformUp {
308
+ case _ : Attribute =>
309
+ e.child
310
+ case g : GetStructField =>
311
+ ExtractValue (g.child, Literal (g.extractFieldName), SQLConf .get.resolver)
312
+ }
313
+ e.withNewChildren(Seq (extractor))
314
+ }
315
+
316
+ // As we change the child of the generator, its output data type must be updated.
317
+ val updatedGeneratorOutput = rewrittenG.generatorOutput
318
+ .zip(rewrittenG.generator.elementSchema.toAttributes)
319
+ .map { case (oldAttr, newAttr) =>
320
+ newAttr.withExprId(oldAttr.exprId).withName(oldAttr.name)
321
+ }
322
+ assert(updatedGeneratorOutput.length == rewrittenG.generatorOutput.length,
323
+ " Updated generator output must have the same length " +
324
+ " with original generator output." )
325
+ val updatedGenerate = rewrittenG.copy(generatorOutput = updatedGeneratorOutput)
326
+
327
+ // Replace nested column accessor with generator output.
328
+ p.withNewChildren(Seq (updatedGenerate)).transformExpressions {
329
+ case f : ExtractValue if nestedFieldsOnGenerator.contains(f) =>
330
+ updatedGenerate.output
331
+ .find(a => attrToAliasesOnGenerator.contains(a.exprId))
332
+ .getOrElse(f)
333
+ }
334
+
335
+ case other =>
336
+ // We should not reach here.
337
+ throw new IllegalStateException (s " Unreasonable plan after optimization: $other" )
338
+ }
339
+ }
250
340
}
251
341
252
342
case g : Generate if SQLConf .get.nestedSchemaPruningEnabled &&
0 commit comments