diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala index bb67c173b9460..6b531a3fac7cf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala @@ -62,6 +62,24 @@ case class ProjectionOverSchema(schema: StructType, output: AttributeSet) { s"unmatched child schema for GetArrayStructFields: ${projSchema.toString}" ) } + case a: GetNestedArrayStructFields => + getProjection(a.child).map(p => (p, p.dataType)).map { + case (projection, projArrayType: ArrayType) => + // Find the innermost struct in both original and projected types + val originalStruct = findInnermostStruct(a.child.dataType) + val projStruct = findInnermostStruct(projArrayType) + val selectedField = originalStruct(a.ordinal) + val prunedField = projStruct(selectedField.name) + GetNestedArrayStructFields(projection, + prunedField.copy(name = a.field.name), + projStruct.fieldIndex(selectedField.name), + projStruct.size, + a.containsNull) + case (_, projSchema) => + throw new IllegalStateException( + s"unmatched child schema for GetNestedArrayStructFields: ${projSchema.toString}" + ) + } case MapKeys(child) => getProjection(child).map { projection => MapKeys(projection) } case MapValues(child) => @@ -79,7 +97,40 @@ case class ProjectionOverSchema(schema: StructType, output: AttributeSet) { } case ElementAt(left, right, defaultValueOutOfBound, failOnError) if right.foldable => getProjection(left).map(p => ElementAt(p, right, defaultValueOutOfBound, failOnError)) + case az: ArraysZip => + // Project each child expression and rebuild ArraysZip with projected children + val projectedChildren = az.children.map(getProjection) + if (projectedChildren.forall(_.isDefined)) { + Some(az.copy(children = projectedChildren.map(_.get))) + } else { + None + } + case naz: NestedArraysZip => + // Project each child expression and rebuild NestedArraysZip with projected children + val projectedChildren = naz.children.map(getProjection) + if (projectedChildren.forall(_.isDefined)) { + Some(naz.copy(children = projectedChildren.map(_.get))) + } else { + None + } + case a: Alias => + // Project the child and wrap it back in an Alias with the same metadata + getProjection(a.child).map { projectedChild => + a.copy(child = projectedChild)( + a.exprId, a.qualifier, a.explicitMetadata, a.nonInheritableMetadataKeys) + } case _ => None } + + /** + * Finds the innermost StructType within a nested array type. + * For example, `array>>` returns `struct`. + */ + @scala.annotation.tailrec + private def findInnermostStruct(dt: DataType): StructType = dt match { + case ArrayType(elementType: ArrayType, _) => findInnermostStruct(elementType) + case ArrayType(st: StructType, _) => st + case _ => throw new IllegalStateException(s"Expected nested array of struct, got: $dt") + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruning.scala index dd2d6c2cb610c..ffb7dc9f01985 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruning.scala @@ -144,6 +144,10 @@ object SchemaPruning extends SQLConfHelper { RootField(StructField(att.name, att.dataType, att.nullable, att.metadata), derivedFromAtt = true) :: Nil case SelectedField(field) => RootField(field, derivedFromAtt = false) :: Nil + // Handle multi-field expressions like ArraysZip and NestedArraysZip that combine + // multiple field accesses into a single expression. unapplySeq returns all fields. + case expr if SelectedField.unapplySeq(expr).exists(_.size > 1) => + SelectedField.unapplySeq(expr).get.map(f => RootField(f, derivedFromAtt = false)) // Root field accesses by `IsNotNull` and `IsNull` are special cases as the expressions // don't actually use any nested fields. These root field accesses might be excluded later // if there are any nested fields accesses in the query plan. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SelectedField.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SelectedField.scala index 820dc452d7e84..6e99f2ac1e3b8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SelectedField.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SelectedField.scala @@ -62,6 +62,40 @@ object SelectedField { selectField(unaliased, None) } + /** + * Like unapply, but returns all fields from expressions that combine multiple field accesses. + * This is needed for expressions like ArraysZip and NestedArraysZip that merge multiple + * array field extractions into a single output. + * + * For example, `ArraysZip([arr.f1, arr.f2], names)` accesses both f1 and f2 from arr. + * Regular unapply would return None, but unapplySeq returns both fields. + * + * @return None if no fields are accessed, Some(Seq[StructField]) otherwise + */ + def unapplySeq(expr: Expression): Option[Seq[StructField]] = { + val unaliased = expr match { + case Alias(child, _) => child + case e => e + } + unaliased match { + // ArraysZip combines multiple array field extractions + // Use unapplySeq recursively to handle nested ArraysZip/NestedArraysZip + case ArraysZip(children, _) => + val fields = children.flatMap(c => unapplySeq(c).getOrElse(Seq.empty)) + if (fields.nonEmpty) Some(fields) else None + + // NestedArraysZip combines nested array field extractions + // Use unapplySeq recursively to handle further nested expressions + case NestedArraysZip(children, _, _) => + val fields = children.flatMap(c => unapplySeq(c).getOrElse(Seq.empty)) + if (fields.nonEmpty) Some(fields) else None + + // For other expressions, delegate to regular unapply + case _ => + unapply(expr).map(Seq(_)) + } + } + /** * Convert an expression into the parts of the schema (the field) it accesses. */ @@ -96,6 +130,26 @@ object SelectedField { } val newField = StructField(field.name, newFieldDataType, field.nullable) selectField(child, Option(ArrayType(struct(newField), containsNull))) + case GetNestedArrayStructFields(child, field, ordinal, _, containsNull) => + // GetNestedArrayStructFields extracts a field from the innermost struct of a + // nested array like array>. We need to find the innermost struct + // field and rebuild the full nested array schema. + val innermostField = findInnermostStructField(child.dataType, ordinal) + val newFieldDataType = dataTypeOpt match { + case None => + // Top level extractor - use the field's type + innermostField.dataType + case Some(dt) => + // Part of a chain - peel off only the parent's array layers, not the field's own + // For example, if child is array>>>>, + // the parent chain contributes 2 array levels, and field contributes 1 more. + // We should peel only the parent's 2 levels, keeping the field's array type. + val parentArrayDepth = arrayDepth(child.dataType) + peelNArrayLayers(dt, parentArrayDepth) + } + val newField = StructField(innermostField.name, newFieldDataType, innermostField.nullable) + val wrappedType = wrapInArrays(child.dataType, struct(newField), containsNull) + selectField(child, Option(wrappedType)) case GetMapValue(child, key) if key.foldable => // GetMapValue does not select a field from a struct (i.e. prune the struct) so it can't be // the top-level extractor. However it can be part of an extractor chain. @@ -154,4 +208,66 @@ object SelectedField { } private def struct(field: StructField): StructType = StructType(Array(field)) + + /** + * Finds the struct field at the given ordinal in the innermost struct of a nested array type. + * For example, for `array>>` with ordinal 1, returns field `b`. + */ + @scala.annotation.tailrec + private def findInnermostStructField(dt: DataType, ordinal: Int): StructField = dt match { + case ArrayType(elementType: ArrayType, _) => findInnermostStructField(elementType, ordinal) + case ArrayType(st: StructType, _) => st(ordinal) + case _ => throw new IllegalArgumentException(s"Expected nested array of struct, got: $dt") + } + + /** + * Removes all ArrayType wrappers from a data type, returning the innermost element type. + * For example, `array>` becomes `int`. + */ + @scala.annotation.tailrec + private def peelArrayLayers(dt: DataType): DataType = dt match { + case ArrayType(elementType, _) => peelArrayLayers(elementType) + case other => other + } + + /** + * Counts the number of array layers in a data type. + * For example, `array>` returns 2. + */ + private def arrayDepth(dt: DataType): Int = { + @scala.annotation.tailrec + def loop(dt: DataType, depth: Int): Int = dt match { + case ArrayType(elementType, _) => loop(elementType, depth + 1) + case _ => depth + } + loop(dt, 0) + } + + /** + * Removes exactly N ArrayType wrappers from a data type. + * For example, `peelNArrayLayers(array>>, 2)` returns `array`. + */ + @scala.annotation.tailrec + private def peelNArrayLayers(dt: DataType, n: Int): DataType = { + if (n <= 0) dt + else dt match { + case ArrayType(elementType, _) => peelNArrayLayers(elementType, n - 1) + case other => other + } + } + + /** + * Wraps an innermost struct type in the same array nesting as the source type. + * For example, if sourceType is `array>>` and innerStruct is `struct`, + * returns `array>>`. + */ + private def wrapInArrays(sourceType: DataType, innerStruct: StructType, + containsNull: Boolean): DataType = sourceType match { + case ArrayType(elementType: ArrayType, outerContainsNull) => + ArrayType(wrapInArrays(elementType, innerStruct, containsNull), outerContainsNull) + case ArrayType(_: StructType, _) => + ArrayType(innerStruct, containsNull) + case _ => + throw new IllegalArgumentException(s"Expected nested array of struct, got: $sourceType") + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 6dddd9e6646c3..2a5aa7b1ea8ae 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -471,6 +471,250 @@ object ArraysZip { } } +/** + * Zips multiple nested arrays at their innermost level, producing a nested array of structs. + * + * Unlike [[ArraysZip]] which zips at the top level, this expression handles nested arrays + * (e.g., `array>`) and zips at the innermost array level while preserving the + * outer array structure. + * + * This enables rebuilding nested struct arrays without higher-order functions, which is + * useful for nested column pruning through generators. + * + * Example (depth 2): + * {{{ + * NestedArraysZip([[[1,2],[3]], [[a,b],[c]]], ["x", "y"]) + * produces [[(1,a),(2,b)], [(3,c)]] // type: array>> + * }}} + * + * Null handling: + * - If any top-level input is null, output is null + * - If an inner array at some position is null in any input, that output position is null + * - Uses max-length semantics at each level (shorter arrays padded with nulls) + * + * @param children The arrays to zip (all must have the same nesting depth) + * @param names The field names for the output struct + * @param depth The nesting depth at which to zip (1 = same as ArraysZip, 2 = one level nested) + */ +/** + * Zips multiple arrays at a specified nesting depth, creating struct elements at the innermost + * level. Created by PruneNestedFieldsThroughGenerateForScan during optimization, then executed + * at runtime in the scan Project to reconstruct pruned nested arrays from individual field arrays. + * + * For depth=1: Standard arrays_zip semantics (array, array) -> array> + * For depth>1: Recursively zips inner arrays at each position of outer arrays. + * + * Note: Uses CodegenFallback intentionally. Custom doGenCode implementations cause type + * resolution errors ("cannot be converted to numeric type"). This disables whole-stage + * codegen for the scan Project that contains this expression, which may affect runtime + * throughput on large explode pipelines. However, the scan IO savings from nested field + * pruning (reading fewer columns/sub-fields from Parquet) typically outweigh the codegen + * overhead. This expression only appears in depth>=2 nested array pruning + * (array>); depth-1 cases use the standard ArraysZip which supports codegen. + * + * @param children The array expressions to zip + * @param names The field names for the resulting struct + * @param depth The nesting depth at which to perform the zip (must be >= 1) + */ +case class NestedArraysZip(children: Seq[Expression], names: Seq[Expression], depth: Int) + extends Expression with ExpectsInputTypes with CodegenFallback { + + require(depth >= 1, s"NestedArraysZip depth must be >= 1, got $depth") + + if (children.size != names.size) { + throw new IllegalArgumentException( + "The numbers of zipped arrays and field names should be the same") + } + + override lazy val resolved: Boolean = + childrenResolved && checkInputDataTypes().isSuccess && names.forall(_.resolved) + + override def inputTypes: Seq[AbstractDataType] = { + // Each child must be an array nested to the specified depth + Seq.fill(children.length)(ArrayType) + } + + override def checkInputDataTypes(): TypeCheckResult = { + if (children.isEmpty) { + return TypeCheckResult.TypeCheckSuccess + } + + // All children must have at least the specified nesting depth. + // We use >= (not ==) because recursive buildPrunedArrayFromSchema may produce children + // with extra array nesting (e.g., array>> for depth=2 when + // the innermost struct itself contains array fields). NestedArraysZip.eval correctly + // handles this by zipping at the specified depth, leaving deeper nesting intact. + val depths = children.map(c => computeArrayDepth(c.dataType)) + val depthErrors = children.zipWithIndex.flatMap { case (child, idx) => + val actualDepth = depths(idx) + if (actualDepth < depth) { + Some(s"Argument ${idx + 1} has array depth $actualDepth but required at least $depth") + } else { + None + } + } + + if (depthErrors.nonEmpty) { + TypeCheckResult.TypeCheckFailure(depthErrors.mkString("; ")) + } else if (depths.distinct.size > 1) { + // All children must have the same depth to produce consistent struct element types + TypeCheckResult.TypeCheckFailure( + s"All arguments must have the same array depth, but got: ${depths.mkString(", ")}") + } else { + TypeCheckResult.TypeCheckSuccess + } + } + + private def computeArrayDepth(dt: DataType): Int = dt match { + case ArrayType(elementType, _) => 1 + computeArrayDepth(elementType) + case _ => 0 + } + + /** Extract the innermost element type at the specified depth */ + private def elementTypeAtDepth(dt: DataType, d: Int): DataType = { + if (d <= 0) dt + else dt match { + case ArrayType(elementType, _) => elementTypeAtDepth(elementType, d - 1) + case other => other + } + } + + /** Get the containsNull flag at a given depth */ + private def containsNullAtDepth(dt: DataType, d: Int): Boolean = { + if (d <= 1) dt match { + case ArrayType(_, containsNull) => containsNull + case _ => true + } + else dt match { + case ArrayType(elementType, _) => containsNullAtDepth(elementType, d - 1) + case _ => true + } + } + + @transient private lazy val innermostElementTypes: Seq[DataType] = + children.map(c => elementTypeAtDepth(c.dataType, depth)) + + @transient override lazy val dataType: DataType = { + val structFields = innermostElementTypes.zip(names).map { + case (elementType, Literal(name, StringType)) => + StructField(name.toString, elementType, nullable = true) + case (elementType, _) => + StructField("_", elementType, nullable = true) + } + val innerStruct = StructType(structFields) + + // Wrap in array types for each level of depth + (1 until depth).foldLeft(ArrayType(innerStruct, containsNull = false): DataType) { + (acc, _) => ArrayType(acc, containsNull = true) + } + } + + override def nullable: Boolean = children.exists(_.nullable) + + override def eval(input: InternalRow): Any = { + val inputArrays = children.map(_.eval(input).asInstanceOf[ArrayData]) + if (inputArrays.contains(null)) { + null + } else { + zipAtDepth(inputArrays, depth) + } + } + + /** + * Recursively zips arrays at the specified depth. + * At depth 1, performs standard arrays_zip semantics. + * At depth > 1, iterates outer arrays and recursively zips inner arrays. + */ + private def zipAtDepth(arrays: Seq[ArrayData], d: Int): ArrayData = { + if (d == 1) { + // Base case: standard arrays_zip semantics + zipArrays(arrays) + } else { + // Recursive case: zip at each position of the outer arrays + if (arrays.isEmpty) { + new GenericArrayData(Array.empty[Any]) + } else { + val maxLen = arrays.map(_.numElements()).max + val result = new Array[Any](maxLen) + + for (i <- 0 until maxLen) { + // Collect inner arrays at position i from each input + val innerArrays = arrays.zipWithIndex.map { case (arr, idx) => + if (i < arr.numElements() && !arr.isNullAt(i)) { + // Get the inner array, handling nested array types + arr.getArray(i) + } else { + null + } + } + + // If any inner array is null, the result at this position is null + if (innerArrays.contains(null)) { + result(i) = null + } else { + result(i) = zipAtDepth(innerArrays.map(_.asInstanceOf[ArrayData]), d - 1) + } + } + + new GenericArrayData(result) + } + } + } + + /** Standard arrays_zip at the innermost level */ + private def zipArrays(arrays: Seq[ArrayData]): ArrayData = { + if (arrays.isEmpty) { + new GenericArrayData(Array.empty[Any]) + } else { + val maxLen = arrays.map(_.numElements()).max + val result = new Array[InternalRow](maxLen) + + for (i <- 0 until maxLen) { + val row = arrays.zipWithIndex.map { case (arr, idx) => + if (i < arr.numElements() && !arr.isNullAt(i)) { + arr.get(i, innermostElementTypes(idx)) + } else { + null + } + } + result(i) = InternalRow.apply(row: _*) + } + + new GenericArrayData(result) + } + } + + override def prettyName: String = "nested_arrays_zip" + + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[Expression]): NestedArraysZip = + copy(children = newChildren) +} + +object NestedArraysZip { + /** Create with auto-detected depth from the first child's type. */ + def apply(children: Seq[Expression], names: Seq[Expression]): NestedArraysZip = { + // Compute depth from first child's type + val depth = children.headOption.map { child => + def computeDepth(dt: DataType): Int = dt match { + case ArrayType(elementType, _) => 1 + computeDepth(elementType) + case _ => 0 + } + computeDepth(child.dataType) + }.getOrElse(1) + + new NestedArraysZip(children, names, depth) + } + + /** Create with explicitly specified depth. */ + def withDepth( + children: Seq[Expression], + names: Seq[Expression], + depth: Int): NestedArraysZip = { + new NestedArraysZip(children, names, depth) + } +} + /** * Returns an unordered array containing the values of the map. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index dba061eeb870d..64fc64eafaa5c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -314,6 +314,192 @@ case class GetArrayStructFields( copy(child = newChild) } +/** + * For a child whose data type is a nested array ending in structs (e.g., `array>`), + * extracts the `ordinal`-th fields of all innermost struct elements, preserving the outer array + * nesting structure. + * + * For example, given `array>>` and ordinal 0 (field "a"), + * this returns `array>`. + * + * This expression supports arbitrary nesting depth (determined from the input type at runtime), + * without requiring higher-order functions. + * + * @param child The nested array expression (must be `array<...>>`) + * @param field The struct field to extract + * @param ordinal The ordinal of the field within the innermost struct + * @param numFields The number of fields in the innermost struct + * @param containsNull Whether the innermost array can contain nulls + */ +case class GetNestedArrayStructFields( + child: Expression, + field: StructField, + ordinal: Int, + numFields: Int, + containsNull: Boolean) extends UnaryExpression with ExtractValue { + + /** + * Computes the depth of array nesting (number of ArrayType levels). + * For `array`, depth is 1. + * For `array>`, depth is 2. + */ + @transient + private lazy val arrayDepth: Int = computeArrayDepth(child.dataType) + + private def computeArrayDepth(dt: DataType): Int = dt match { + case ArrayType(elementType: ArrayType, _) => 1 + computeArrayDepth(elementType) + case ArrayType(_: StructType, _) => 1 + case _ => throw new IllegalArgumentException( + s"GetNestedArrayStructFields requires array<...>, got: $dt") + } + + /** + * Returns the output type: same array nesting with innermost struct replaced by field type. + * E.g., `array>>` with field "a" -> `array>` + */ + override def dataType: DataType = { + def buildOutputType(dt: DataType, depth: Int): DataType = { + if (depth == 1) { + // Innermost array level - replace struct with field type + val ArrayType(_, elemContainsNull) = dt: @unchecked + ArrayType(field.dataType, elemContainsNull || containsNull) + } else { + val ArrayType(elemType, elemContainsNull) = dt: @unchecked + ArrayType(buildOutputType(elemType, depth - 1), elemContainsNull) + } + } + buildOutputType(child.dataType, arrayDepth) + } + + override def toString: String = s"$child.${field.name}" + override def sql: String = s"${child.sql}.${quoteIdentifier(field.name)}" + + protected override def nullSafeEval(input: Any): Any = { + extractNested(input.asInstanceOf[ArrayData], arrayDepth) + } + + /** + * Recursively traverses nested arrays, extracting the field at the innermost struct level. + */ + private def extractNested(array: ArrayData, depth: Int): ArrayData = { + val length = array.numElements() + val result = new Array[Any](length) + var i = 0 + while (i < length) { + if (array.isNullAt(i)) { + result(i) = null + } else if (depth == 1) { + // At innermost array level, extract struct field + val row = array.getStruct(i, numFields) + if (row.isNullAt(ordinal)) { + result(i) = null + } else { + result(i) = row.get(ordinal, field.dataType) + } + } else { + // Recurse into nested array + val nestedArray = array.getArray(i) + result(i) = extractNested(nestedArray, depth - 1) + } + i += 1 + } + new GenericArrayData(result) + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val arrayClass = classOf[GenericArrayData].getName + nullSafeCodeGen(ctx, ev, eval => { + generateNestedLoop(ctx, eval, ev, arrayDepth) + }) + } + + /** + * Generates nested loop code for extracting fields from nested arrays. + * This is the entry point that assigns to ev.value at the end. + */ + private def generateNestedLoop( + ctx: CodegenContext, + inputArray: String, + ev: ExprCode, + depth: Int): String = { + val resultVar = ctx.freshName("finalResult") + val arrayClass = classOf[GenericArrayData].getName + s""" + ArrayData $resultVar = null; + ${generateInnerExtraction(ctx, inputArray, resultVar, depth)} + ${ev.value} = $resultVar; + """ + } + + /** + * Generates extraction code for a given nesting depth. + * Recursively builds nested loops, assigning to resultVar at the end. + */ + private def generateInnerExtraction( + ctx: CodegenContext, + inputArray: String, + resultVar: String, + depth: Int): String = { + val n = ctx.freshName("n") + val values = ctx.freshName("values") + val j = ctx.freshName("j") + val arrayClass = classOf[GenericArrayData].getName + + if (depth == 1) { + // Base case: innermost array with struct elements + val row = ctx.freshName("row") + val nullSafeFieldEval = if (field.nullable) { + s""" + if ($row.isNullAt($ordinal)) { + $values[$j] = null; + } else + """ + } else { + "" + } + s""" + final int $n = $inputArray.numElements(); + final Object[] $values = new Object[$n]; + for (int $j = 0; $j < $n; $j++) { + if ($inputArray.isNullAt($j)) { + $values[$j] = null; + } else { + final InternalRow $row = $inputArray.getStruct($j, $numFields); + $nullSafeFieldEval { + $values[$j] = ${CodeGenerator.getValue(row, field.dataType, ordinal.toString)}; + } + } + } + $resultVar = new $arrayClass($values); + """ + } else { + // Recursive case: nested array + val innerArray = ctx.freshName("innerArray") + val innerResult = ctx.freshName("innerResult") + + // Generate code that extracts from each nested array element + s""" + final int $n = $inputArray.numElements(); + final Object[] $values = new Object[$n]; + for (int $j = 0; $j < $n; $j++) { + if ($inputArray.isNullAt($j)) { + $values[$j] = null; + } else { + ArrayData $innerArray = $inputArray.getArray($j); + ArrayData $innerResult = null; + ${generateInnerExtraction(ctx, innerArray, innerResult, depth - 1)} + $values[$j] = $innerResult; + } + } + $resultVar = new $arrayClass($values); + """ + } + } + + override protected def withNewChildInternal(newChild: Expression): GetNestedArrayStructFields = + copy(child = newChild) +} + /** * Returns the field at `ordinal` in the Array `child`. * diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PruneNestedFieldsThroughGenerateForScan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PruneNestedFieldsThroughGenerateForScan.scala new file mode 100644 index 0000000000000..10eafc263bd5f --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PruneNestedFieldsThroughGenerateForScan.scala @@ -0,0 +1,2523 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import scala.collection.mutable + +import org.apache.spark.sql.catalyst.SQLConfHelper +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes +import org.apache.spark.sql.types._ + +/** + * Prunes nested struct fields within arrays used by explode/posexplode generators, + * materializing pruned arrays as a Project below the Generate node so that downstream + * [[org.apache.spark.sql.execution.datasources.SchemaPruning]] and + * V2ScanRelationPushDown can reduce scan IO. + * + * This rule handles the multi-field case that [[GeneratorNestedColumnAliasing]] does + * not support (see SPARK-34956), and additionally provides: + * - Posexplode pos-only optimisation (minimal-weight field selection) + * - Ordinal-safe field resolution using field names instead of ordinal reuse + * - Chained Generate support (multiple consecutive lateral views) + * + * The rule runs in the earlyScanPushDownRules batch, before SchemaPruning and + * V2ScanRelationPushDown, so that the materialised fields are visible through + * [[org.apache.spark.sql.catalyst.planning.ScanOperation]]. + * + * === Example (single Generate) === + * Before (multi-field on generator output, no scan pruning): + * {{{ + * Project [item.f1, item.f2] + * Generate [explode(col)] // col: array> + * Scan [col] + * }}} + * + * After this rule (pruned array materialised for SchemaPruning): + * {{{ + * Project [item.f1, item.f2] // ordinals fixed + * Generate [explode(_pruned)] // _pruned: array> + * Project [_pruned = arrays_zip(col.f1, col.f2)] + * Scan [col] + * }}} + * + * === Example (chained Generates) === + * Before: + * {{{ + * Project [complex.col1, val] + * Generate [explode(complex.col2)] // inner: explodes col2 array + * Generate [explode(arr)] // outer: arr is array> + * Scan [arr] + * }}} + * + * After (col3 pruned from outer struct): + * {{{ + * Project [complex.col1, val] // ordinals fixed + * Generate [explode(complex.col2)] // ordinals fixed in generator child + * Generate [explode(_pruned)] // _pruned: array> + * Project [_pruned = arrays_zip(arr.col1, arr.col2)] + * Scan [arr] + * }}} + */ +object PruneNestedFieldsThroughGenerateForScan + extends Rule[LogicalPlan] with SQLConfHelper { + + override def apply(plan: LogicalPlan): LogicalPlan = { + if (!conf.nestedSchemaPruningEnabled) return plan + rewriteGenerateChains(plan) + } + + /** + * Represents a Generate node in a chain with its context. + * + * @param generate The Generate node + * @param generator The ExplodeBase generator + * @param colAttr The exploded element attribute + * @param posAttrOpt The position attribute (for posexplode) + * @param elementStruct The struct type of array elements + * @param containsNull Whether the array can contain nulls + * @param intermediateNodes Nodes between this Generate and the next one in the chain + * (or the leaf). These are typically Projects from + * GeneratorNestedColumnAliasing that need to be preserved + * and have their expressions fixed when rebuilding. + */ + private case class GenerateInfo( + generate: Generate, + generator: ExplodeBase, + colAttr: Attribute, + posAttrOpt: Option[Attribute], + elementStruct: StructType, + containsNull: Boolean, + intermediateNodes: Seq[LogicalPlan] = Nil) + + /** + * Main traversal that detects and rewrites Generate chains. + * + * We look for patterns: + * - Project -> Generate -> ... (chain of Generates) -> leaf + * - Project -> Filter -> Generate -> ... -> leaf + * + * For each chain, we collect required fields for each Generate and rewrite + * bottom-up, inserting a single pruned Project at the leaf. + * + * Uses nested schema approach for requirement propagation, enabling inner + * generate pruning by embedding inner requirements into outer schemas. + */ + private def rewriteGenerateChains(plan: LogicalPlan): LogicalPlan = { + plan.transformDown { + // Pattern 1: Project directly above a Generate chain + case p @ Project(projectList, child) if startsGenerateChain(child) => + tryRewriteChainWithNestedSchema(p, projectList, Nil, child, None) + + // Pattern 2: Project -> Filter -> Generate chain + case p @ Project(projectList, f @ Filter(condition, child)) + if startsGenerateChain(child) => + tryRewriteChainWithNestedSchema( + p, projectList, Seq(condition), child, Some(f)) + + // Pattern 3: Project above a Generate that has an intermediate Project below + case p @ Project(projectList, g: Generate) + if !startsGenerateChain(g) && startsGenerateChainThroughProjects(g.child) => + tryRewriteChainWithNestedSchema(p, projectList, Nil, g, None) + } + } + + /** + * Looks through Project and Filter nodes to check if there's a Generate chain. + */ + private def startsGenerateChainThroughProjects(plan: LogicalPlan): Boolean = plan match { + case g: Generate if startsGenerateChain(g) => true + case Project(_, child) => startsGenerateChainThroughProjects(child) + case Filter(_, child) => startsGenerateChainThroughProjects(child) + case _ => false + } + + /** + * Checks if a plan node starts a Generate chain we can potentially prune. + */ + private def startsGenerateChain(plan: LogicalPlan): Boolean = plan match { + case Generate(gen: ExplodeBase, _, _, _, _, _) => + val result = isArrayOfStruct(gen.child.dataType) + result + case _ => false + } + + /** + * Extracts a chain of Generate nodes from a plan. + * Returns the chain (top-to-bottom: closest to Project first) and the leaf node below the chain. + * + * Looks through intermediate Projects and Filters to find consecutive Generate nodes. + * For each Generate, captures the intermediate nodes (Projects, Filters) between it + * and the next Generate (or leaf). This is needed to preserve and fix those nodes + * during chain rewriting. + */ + private def extractGenerateChain(plan: LogicalPlan): (Seq[GenerateInfo], LogicalPlan) = { + + // Helper to collect intermediate nodes (Projects, Filters) until we hit + // a Generate or a leaf node. Returns (intermediates, next plan to process). + def collectIntermediates(p: LogicalPlan): (Seq[LogicalPlan], LogicalPlan) = { + p match { + case g @ Generate(gen: ExplodeBase, _, _, _, _, _) + if isArrayOfStruct(gen.child.dataType) => + // Found next Generate + (Nil, g) + case proj @ Project(_, child) => + val (rest, next) = collectIntermediates(child) + (proj +: rest, next) + case flt @ Filter(_, child) => + val (rest, next) = collectIntermediates(child) + (flt +: rest, next) + case other => + // Leaf node + (Nil, other) + } + } + + plan match { + case g @ Generate(gen: ExplodeBase, _, _, _, genOutput, child) + if isDirectArrayOfStruct(gen.child.dataType) => + // Only include generates where the element type (after explode) is a struct. + // For array>, the element is array, not struct - skip those. + val (elementStruct, containsNull) = extractDirectStruct(gen.child.dataType) + val colAttr = if (gen.position) genOutput(1) else genOutput.head + val posAttrOpt = if (gen.position) Some(genOutput.head) else None + + // Collect intermediate nodes between this Generate and the next (or leaf) + val (intermediates, nextPlan) = collectIntermediates(child) + + + val info = GenerateInfo(g, gen, colAttr, posAttrOpt, elementStruct, containsNull, + intermediateNodes = intermediates) + + // Continue extracting from the next plan (next Generate or leaf) + val (childChain, leaf) = extractGenerateChain(nextPlan) + (info +: childChain, leaf) + + // Handle nested array case (array>) for inner generate pruning. + // After GNA transforms the outer generate to explode nested arrays, we get + // a generate with array> source producing array elements. + // We MUST include this in the chain to preserve it during rewriting, even though + // we can't prune it directly (its elements are arrays, not structs). + case g @ Generate(gen: ExplodeBase, _, _, _, genOutput, child) + if isArrayOfStruct(gen.child.dataType) && !isDirectArrayOfStruct(gen.child.dataType) => + // Extract the innermost struct for tracking, but this generate can't be pruned directly + val (innermostStruct, containsNull) = extractInnermostStruct(gen.child.dataType) + val colAttr = if (gen.position) genOutput(1) else genOutput.head + val posAttrOpt = if (gen.position) Some(genOutput.head) else None + + // Collect intermediate nodes between this Generate and the next (or leaf) + val (intermediates, nextPlan) = collectIntermediates(child) + + + // Mark this as a "nested array" generate by using a special empty StructType + // This signals to the pruning logic that we can't prune this generate directly + // but must preserve it in the chain. + // We use the innermost struct so ordinal fixing can still work if needed. + val info = GenerateInfo(g, gen, colAttr, posAttrOpt, innermostStruct, containsNull, + intermediateNodes = intermediates) + + // Continue extracting from the next plan + val (childChain, leaf) = extractGenerateChain(nextPlan) + (info +: childChain, leaf) + + // If we start with a Generate that isn't array-of-struct, look through its child + // This handles Pattern 3 where the top Generate is exploding a scalar array + // (e.g., explode(l2.l3_f1) producing array) + case g: Generate => + extractGenerateChain(g.child) + + // If we start with non-Generate, look through to find the first Generate + case Project(_, child) => + extractGenerateChain(child) + + case Filter(_, child) => + extractGenerateChain(child) + + case other => + (Nil, other) + } + } + + /** + * Collects the nodes between chainStart and the first chain Generate (chain[0]). + * These nodes need to be preserved and rebuilt after the chain rewrite. + * + * For example, if chainStart is a non-struct Generate with intermediate Projects/Filters + * before chain[0], we need to collect [Generate, Project, Filter] and rebuild them later. + * + * @return Sequence of nodes from chainStart down to (but not including) targetGenerate + */ + private def collectAboveChainNodes( + chainStart: LogicalPlan, + targetGenerate: Generate): Seq[LogicalPlan] = { + val nodes = mutable.ArrayBuffer[LogicalPlan]() + + @scala.annotation.tailrec + def collect(plan: LogicalPlan): Unit = { + if (plan eq targetGenerate) return + + plan match { + case g: Generate => + nodes += g + collect(g.child) + case p: Project => + nodes += p + collect(p.child) + case f: Filter => + nodes += f + collect(f.child) + case _ => + // Shouldn't reach here if targetGenerate is in the tree + } + } + + collect(chainStart) + nodes.toSeq + } + + /** + * Collects expressions from intermediate nodes between the start plan and the target generate. + * This captures expressions like `outer_elem.inner_array.inner_f1` that may exist in + * intermediate Projects created by GeneratorNestedColumnAliasing. + * + * @param start The starting plan node + * @param targetGenerate The Generate node to stop at + * @return Expressions from intermediate Projects that may reference Generate outputs + */ + private def collectIntermediateExpressions( + start: LogicalPlan, + targetGenerate: Generate): Seq[Expression] = { + val exprs = mutable.ArrayBuffer[Expression]() + + def collect(plan: LogicalPlan): Unit = { + if (plan eq targetGenerate) return + + plan match { + case g: Generate => + // Look at expressions in the generator child (for ExplodeBase generators) + g.generator match { + case e: ExplodeBase => exprs += e.child + case _ => // Other generators may not have a simple child expression + } + collect(g.child) + + case Project(projectList, child) => + // Collect expressions from this Project + exprs ++= projectList + collect(child) + + case Filter(condition, child) => + exprs += condition + collect(child) + + case _ => + plan.children.foreach(collect) + } + } + + collect(start) + exprs.toSeq + } + + /** + * Attempts to rewrite a Generate chain using nested schema propagation. + * This enables inner generate pruning by embedding inner requirements into outer schemas. + */ + private def tryRewriteChainWithNestedSchema( + originalProject: Project, + projectList: Seq[NamedExpression], + filterConditions: Seq[Expression], + chainStart: LogicalPlan, + filterOpt: Option[Filter]): LogicalPlan = { + + val (chain, leaf) = extractGenerateChain(chainStart) + if (chain.isEmpty) return originalProject + + // Collect nodes between chainStart and chain[0] that need to be preserved. + // These are rebuilt after the chain rewrite using rebuildAboveChainNodes. + val aboveChainNodes = collectAboveChainNodes(chainStart, chain.head.generate) + + // Collect expressions from intermediate nodes between the top Project and the chain + // This captures nested array field accesses that may exist in intermediate Projects + // (e.g., from GeneratorNestedColumnAliasing) + val intermediateExprs = collectIntermediateExpressions(chainStart, chain.head.generate) + + // Compute nested schema requirements with backward propagation + val allTopExprs = projectList ++ filterConditions ++ intermediateExprs + val requirements = computeNestedChainRequirements(chain, allTopExprs, Nil, leaf, chainStart) + + // Check if any pruning is possible (top-level count reduction OR nested type changes) + val anyPruning = requirements.zip(chain).exists { case (req, info) => + req.isDefined && hasNestedTypePruning(req.get, info.elementStruct) + } + + if (!anyPruning) return originalProject + + // Rewrite using nested schema builder + rewriteChainWithNestedSchema( + originalProject, projectList, filterConditions, chain, requirements, leaf, filterOpt, + aboveChainNodes, chainStart) + } + + /** + * Rewrites the Generate chain using nested schema requirements. + * This is the schema-driven replacement for rewriteChainBottomUp. + * + * @param aboveChainNodes Nodes between originalProject and chain[0] that need to be + * preserved and rebuilt after the chain rewrite + */ + /** + * Traces an expression through generate outputs and aliases to find a scan-rooted source. + * Returns the resolved expression if it can be traced to a scan attribute. + */ + private def traceToScanRootedExpr( + expr: Expression, + generateSourceMap: Map[ExprId, Expression], + aliasMap: Map[ExprId, Expression], + scanAttrIds: Set[ExprId], + visited: Set[ExprId] = Set.empty): Option[Expression] = { + + + // Check if this expression is scan-rooted + extractRootAttribute(expr) match { + case Some(rootAttr) if scanAttrIds.contains(rootAttr.exprId) => + Some(expr) + case Some(rootAttr) if visited.contains(rootAttr.exprId) => + None // Cycle detected + case Some(rootAttr) => + // Try tracing through generate sources + generateSourceMap.get(rootAttr.exprId) match { + case Some(sourceExpr) => + traceToScanRootedExpr(sourceExpr, generateSourceMap, aliasMap, scanAttrIds, + visited + rootAttr.exprId).map { resolvedSource => + // Rebuild the path with the resolved source + rebuildExprWithNewBase(expr, rootAttr, resolvedSource) + } + case None => + // Try tracing through aliases + aliasMap.get(rootAttr.exprId) match { + case Some(aliasedExpr) => + traceToScanRootedExpr(aliasedExpr, generateSourceMap, aliasMap, scanAttrIds, + visited + rootAttr.exprId).map { resolvedAlias => + rebuildExprWithNewBase(expr, rootAttr, resolvedAlias) + } + case None => + None + } + } + case None => + None + } + } + + /** + * Rebuilds an expression by replacing its base attribute with a new expression. + * For example, rebuildExprWithNewBase(outer_elem.inner_f1, outer_elem, outer_array.inner_array) + * returns outer_array.inner_array.inner_f1 + */ + private def rebuildExprWithNewBase( + expr: Expression, + oldBase: Attribute, + newBase: Expression): Expression = { + expr match { + case a: Attribute if a.exprId == oldBase.exprId => + newBase + case gsf @ GetStructField(child, ordinal, name) => + GetStructField(rebuildExprWithNewBase(child, oldBase, newBase), ordinal, name) + case gasf @ GetArrayStructFields(child, field, ordinal, numFields, containsNull) => + GetArrayStructFields(rebuildExprWithNewBase(child, oldBase, newBase), + field, ordinal, numFields, containsNull) + case gnasf @ GetNestedArrayStructFields( + child, field, ordinal, numFields, containsNull) => + GetNestedArrayStructFields(rebuildExprWithNewBase(child, oldBase, newBase), + field, ordinal, numFields, containsNull) + case _ => + expr + } + } + + private def rewriteChainWithNestedSchema( + originalProject: Project, + projectList: Seq[NamedExpression], + filterConditions: Seq[Expression], + chain: Seq[GenerateInfo], + requirements: Seq[Option[StructType]], + leaf: LogicalPlan, + filterOpt: Option[Filter], + aboveChainNodes: Seq[LogicalPlan] = Seq.empty, + chainStart: LogicalPlan = null): LogicalPlan = { + + // Decompose the leaf to find where to insert our pruning Project + val (leafFilters, actualLeaf) = decomposeChild(leaf) + + // Find the TRUE scan relation to determine scan attribute IDs. + // We can't use actualLeaf.output because actualLeaf might be an intermediate + // Project (e.g., from GeneratorNestedColumnAliasing) that defines alias attributes. + // Those alias attributes would be incorrectly treated as "scan-rooted". + val scanLeaf = findScanLeaf(leaf) + val scanAttrIds = scanLeaf.output.map(_.exprId).toSet + + // Collect generate sources and aliases from the FULL plan (not just chain). + // This allows tracing through generates that aren't in the chain (e.g., outer + // generates with array> sources that were skipped). + val chainStartPlan = if (chainStart != null) chainStart else originalProject + val generateSourceMap = collectGenerateSources(chainStartPlan) + val fullAliasMap = collectAliasDefinitions(chainStartPlan) + + + // Find the outermost Generate that can be pruned AND has a scan-rooted source. + // We trace through BOTH aliases AND generate outputs to find the ultimate scan source. + // For inner generates (i < chain.length - 1), we only support GNA-transformed cases + // where the outer generate produces array elements. + val outermostPrunableIdx = chain.indices.find { i => + val info = chain(i) + val rootAttrOpt = extractRootAttribute(info.generator.child) + val hasScanRootedSource = rootAttrOpt match { + case Some(rootAttr) => + // First check direct match + scanAttrIds.contains(rootAttr.exprId) || + // Then try tracing through chain aliases + traceToScanAttribute(rootAttr, chain, scanAttrIds).isDefined || + // Finally try tracing through all generates and aliases + traceToScanRootedExpr(info.generator.child, generateSourceMap, fullAliasMap, + scanAttrIds).isDefined + case None => + // Even if root is not a simple attribute, try tracing the full expression + traceToScanRootedExpr(info.generator.child, generateSourceMap, fullAliasMap, + scanAttrIds).isDefined + } + val canPrune = requirements(i).isDefined && + hasNestedTypePruning(requirements(i).get, info.elementStruct) + + // For inner generates, check if we can support this case + // (GNA-transformed = outer gen produces array elements) + val hasGeneratesBelow = i < chain.length - 1 + val canPruneInnerGen = if (hasGeneratesBelow) { + val outerGenInfo = chain(i + 1) + outerGenInfo.colAttr.dataType match { + case _: ArrayType => true // GNA-transformed case + case _: StructType => false // Non-GNA - skip, try outer generate + case _ => false + } + } else { + true // Bottommost generate - always supported + } + + hasScanRootedSource && canPrune && canPruneInnerGen + } + + outermostPrunableIdx match { + case None => originalProject + case Some(prunableIdx) => + val prunableInfo = chain(prunableIdx) + val requiredSchema = requirements(prunableIdx).get + + if (!hasNestedTypePruning(requiredSchema, prunableInfo.elementStruct)) { + return originalProject + } + + // Note: GNA vs non-GNA check is now in outermostPrunableIdx.find predicate + val hasGeneratesBelow = prunableIdx < chain.length - 1 + + // Resolve the source expression to a scan-rooted path for SchemaPruning + val rawSourceExpr: Expression = prunableInfo.generator.child + val resolvedSourceExpr: Expression = traceToScanRootedExpr( + rawSourceExpr, generateSourceMap, fullAliasMap, scanAttrIds + ).getOrElse(resolveExpressionThroughChain(rawSourceExpr, chain)) + val sourceRootAttr: Option[Attribute] = extractRootAttribute(resolvedSourceExpr) + + + val prunedArrayExpr = buildPrunedArrayFromSchema( + resolvedSourceExpr, + prunableInfo.elementStruct, + requiredSchema, + prunableInfo.containsNull) + + + val prunedAlias = Alias(prunedArrayExpr, "_pruned_explode")() + val prunedAttr = prunedAlias.toAttribute + + // When the source expression is resolved (e.g., outer_array.inner_array), + // we need to insert the pruned Project at the scan level where those attributes + // are available. If the resolved expression references scan attributes, use scanLeaf; + // otherwise fall back to actualLeaf for compatibility. + val insertionPoint = sourceRootAttr match { + case Some(rootAttr) if scanAttrIds.contains(rootAttr.exprId) => scanLeaf + case _ => actualLeaf + } + + // Compute pass-through attributes from the insertion point. + // Include references from above-chain nodes (non-struct Generates, Projects, Filters) + // that sit between the top Project and the chain. These nodes may reference scan + // attributes (e.g., explode(tags) needs tags from scan) that must be passed through. + val aboveChainNodeRefs: Seq[Attribute] = aboveChainNodes.flatMap { + case g: Generate => g.generator.references + case p: Project => p.projectList.flatMap(_.references) + case f: Filter => f.condition.references + case _ => Nil + } + val aboveReferences: Set[Attribute] = + (projectList.flatMap(_.references) ++ + filterConditions.flatMap(_.references) ++ + aboveChainNodeRefs).toSet + val allGenOutputs = AttributeSet(chain.flatMap(_.generate.generatorOutput)) + val childAttrsNeededAbove: Seq[Attribute] = insertionPoint.output + .filter(a => aboveReferences.exists(r => + r.exprId == a.exprId && !allGenOutputs.contains(r))) + + // Inner Project directly above the insertion point + val innerProject = Project(childAttrsNeededAbove :+ prunedAlias, insertionPoint) + + + // Rewrite leaf filters + val rewrittenLeaf = sourceRootAttr match { + case Some(_) => + leafFilters.foldLeft(innerProject: LogicalPlan) { (plan, cond) => + val rewritten = fixArrayStructOrdinalsInExprChain( + cond, rawSourceExpr, prunedAttr, requiredSchema, prunableInfo.containsNull) + Filter(rewritten, plan) + } + case None => + leafFilters.foldLeft(innerProject: LogicalPlan) { (plan, cond) => + Filter(cond, plan) + } + } + + // Build new generator for the prunable Generate. + // GNA-transformed: outer gen explodes array>, produces array + // Non-GNA: outer gen explodes array, produces struct (early return above) + val prunableGenSource = if (hasGeneratesBelow) { + val outerGenInfo = chain(prunableIdx + 1) + outerGenInfo.colAttr.dataType match { + case _: ArrayType => + // GNA-transformed case: outer gen produces array elements, use colAttr directly + outerGenInfo.colAttr + case _: StructType => + // Non-GNA case: outer gen produces struct elements + // We use prunedAttr - the pruned version of inner_array at scan level + prunedAttr + } + } else { + // Bottommost prunable generate - use prunedAttr directly + prunedAttr + } + val newGenerator: ExplodeBase = prunableInfo.generator match { + case _: Explode => Explode(prunableGenSource) + case _: PosExplode => PosExplode(prunableGenSource) + } + + // Updated generator output with new pruned schema + val newGenOutput = prunableInfo.generate.generatorOutput + .zip(toAttributes(newGenerator.elementSchema)).map { + case (oldAttr, newAttr) => + newAttr.withExprId(oldAttr.exprId).withName(oldAttr.name) + } + + // Build the rewritten chain from the prunable Generate upward. + // Intermediate nodes (Projects, Filters) are stored with each Generate and + // represent nodes BELOW that Generate (between it and the next one or the leaf). + // We rebuild bottom-up: for each Generate, first add its intermediate nodes, + // then add the Generate itself. + var currentChild: LogicalPlan = rewrittenLeaf + + // Track colAttr -> new schema mappings for all generates that got updated. + // Initialize early so intermediate node rebuilds can use it. + val colAttrSchemaMap = mutable.Map[ExprId, StructType]() + colAttrSchemaMap(prunableInfo.colAttr.exprId) = requiredSchema + + // Build a map of attribute types to update (colAttr exprId -> new type) + // Initialize with the prunable generate's output type + val attrTypeMap = mutable.Map[ExprId, DataType]() + val prunableNewColType = prunableInfo.generator.child.dataType match { + case ArrayType(ArrayType(_, innerContainsNull), _) => + ArrayType(requiredSchema, innerContainsNull) + case ArrayType(_, _) => + requiredSchema + case _ => + requiredSchema + } + attrTypeMap(prunableInfo.colAttr.exprId) = prunableNewColType + + // Rebuild Generates below the prunable one (if any). + // These are typically GNA-transformed generates that we traced THROUGH to reach the scan. + // They need to use the pruned output (_pruned_explode) instead of their original source. + for (i <- (chain.length - 1) until prunableIdx by -1) { + val info = chain(i) + + // For the BOTTOMMOST generate (directly above the scan-level Project), we: + // 1. SKIP its intermediate nodes (they define the unpruned extraction) + // 2. Use the pruned attribute as its source instead + // This is analogous to how we skip prunableInfo's intermediate nodes + val isBottommostGenerate = (i == chain.length - 1) + + if (!isBottommostGenerate) { + // Not the bottommost - rebuild its intermediate nodes normally + currentChild = rebuildIntermediateNodes( + info.intermediateNodes, currentChild, colAttrSchemaMap.toMap, chain, attrTypeMap) + } + // If isBottommostGenerate, SKIP intermediate nodes - our pruned Project replaces them + + // For the bottommost generate, use the pruned attribute as source. + // For others, fix ordinals in their generator child. + val newGenChild = if (isBottommostGenerate) { + // Use the pruned attribute directly - it has the correct pruned schema + prunedAttr + } else { + // Fix ordinals based on any updated schemas + fixOrdinalsInExprWithSchema( + info.generator.child, prunableInfo.colAttr, requiredSchema) + } + + val fixedGenerator: ExplodeBase = info.generator match { + case _: Explode => Explode(newGenChild) + case _: PosExplode => PosExplode(newGenChild) + } + + // Update the generate's output schema. + // For the bottommost generate with nested array source (array>), + // the output is one level unwrapped: array with the pruned inner schema. + val newGenOutput = { + val newElementSchema = fixedGenerator.elementSchema + info.generate.generatorOutput.zip(toAttributes(newElementSchema)).map { + case (oldAttr, newAttr) => + newAttr.withExprId(oldAttr.exprId).withName(oldAttr.name) + } + } + + // Track the new colAttr schema for downstream rebuilding + val newColAttr = if (info.generator.position) newGenOutput(1) else newGenOutput.head + newColAttr.dataType match { + case st: StructType => + colAttrSchemaMap(info.colAttr.exprId) = st + case ArrayType(st: StructType, _) => + // For nested array generates, the colAttr is array + colAttrSchemaMap(info.colAttr.exprId) = st + case _ => // Non-struct output, skip + } + + + currentChild = info.generate.copy( + generator = fixedGenerator, + generatorOutput = newGenOutput, + child = currentChild) + } + + // SKIP the prunable Generate's intermediate nodes - they were inserted by GNA to + // extract the source array (e.g., outer_array.inner_array), but our pruned Project + // already provides the pruned array directly as _pruned_explode. The intermediate + // extraction is now redundant. + // + // If we were to include them, they would: + // 1. Reference the old source (outer_array.inner_array) which still has all 4 fields + // 2. Block ScanOperation from matching our pruned Project + // 3. Leave orphaned nodes with unresolved references + // + // By skipping them, the Generate uses _pruned_explode directly above our pruned Project. + + // Now add the prunable Generate with updated output + val prunedIdx = childAttrsNeededAbove.length + val prunableGenerate = prunableInfo.generate.copy( + generator = newGenerator, + unrequiredChildIndex = Seq(prunedIdx), + generatorOutput = newGenOutput, + child = currentChild) + currentChild = prunableGenerate + + // Rebuild Generates above the prunable one (closer to Project). + // These may reference the prunable generate's output, so we need to: + // 1. Fix ordinals in their generator child + // 2. Update attribute types in their generator child + // 3. Update their generatorOutput to match the new element schema + // 4. Rebuild intermediate nodes between this Generate and the one below it + // + // Note: attrTypeMap was already initialized earlier with the prunable generate's type + + for (i <- (prunableIdx - 1) to 0 by -1) { + val info = chain(i) + + // IMPORTANT: Rebuild intermediate nodes FIRST so that alias types are tracked + // in attrTypeMap BEFORE we try to update the generate's expression. + // This is critical because the inner generate's source (e.g., _extract_inner_array#19) + // is defined by an alias in the intermediate Project. We need that alias's new type + // in attrTypeMap before we can update the inner generate's generator child. + currentChild = rebuildIntermediateNodes( + info.intermediateNodes, currentChild, colAttrSchemaMap.toMap, chain, attrTypeMap) + + // Now fix ordinals in generator child (with updated attrTypeMap from intermediates) + var fixedGenChild = fixOrdinalsInExprWithSchema( + info.generator.child, prunableInfo.colAttr, requiredSchema) + // Also update attribute types in generator child + fixedGenChild = updateAttributeTypes(fixedGenChild, attrTypeMap.toMap) + + val fixedGenerator: ExplodeBase = info.generator match { + case _: Explode => Explode(fixedGenChild) + case _: PosExplode => PosExplode(fixedGenChild) + } + + // The inner generate's output schema is derived from the fixed generator's + // element schema. If the source array was pruned, the element schema changes. + val newGenOutput = info.generate.generatorOutput + .zip(toAttributes(fixedGenerator.elementSchema)).map { + case (oldAttr, newAttr) => + newAttr.withExprId(oldAttr.exprId).withName(oldAttr.name) + } + + // Track the new schema for this generate's colAttr + val newColAttr = if (info.generator.position) newGenOutput(1) else newGenOutput.head + newColAttr.dataType match { + case st: StructType => + colAttrSchemaMap(info.colAttr.exprId) = st + // Also update the type map for intermediate nodes above this generate + val newColType = fixedGenerator.child.dataType match { + case ArrayType(ArrayType(_, innerContainsNull), _) => + ArrayType(st, innerContainsNull) + case ArrayType(_, _) => + st + case _ => + st + } + attrTypeMap(info.colAttr.exprId) = newColType + case _ => // Non-struct output, skip + } + + // Add this Generate (intermediate nodes already added above) + currentChild = info.generate.copy( + generator = fixedGenerator, + generatorOutput = newGenOutput, + child = currentChild) + } + + // Rebuild above-chain nodes (nodes between originalProject and chain[0]) + // These need their expressions fixed for the new schema + + currentChild = rebuildAboveChainNodes( + aboveChainNodes, currentChild, colAttrSchemaMap.toMap, chain, attrTypeMap) + + + // Fix ordinals in the top-level expressions for ALL updated generates + val fixedProjectList = projectList.map { expr => + colAttrSchemaMap.foldLeft(expr: Expression) { case (e, (exprId, schema)) => + val colAttr = chain.find(_.colAttr.exprId == exprId).map(_.colAttr).get + fixOrdinalsInExprWithSchema(e, colAttr, schema) + }.asInstanceOf[NamedExpression] + } + + filterOpt match { + case Some(filter) => + val fixedCond = colAttrSchemaMap.foldLeft(filter.condition: Expression) { + case (e, (exprId, schema)) => + val colAttr = chain.find(_.colAttr.exprId == exprId).map(_.colAttr).get + fixOrdinalsInExprWithSchema(e, colAttr, schema) + } + Project(fixedProjectList, Filter(fixedCond, currentChild)) + case None => + Project(fixedProjectList, currentChild) + } + } + } + + /** + * Rebuilds the nodes between the original top Project and chain[0]. + * These nodes include non-struct Generates, Projects, and Filters that reference + * chain outputs and need their expressions fixed. + * + * Rebuilds in reverse order (bottom-to-top: closest to chain first). + */ + private def rebuildAboveChainNodes( + aboveChainNodes: Seq[LogicalPlan], + currentChild: LogicalPlan, + colAttrSchemaMap: Map[ExprId, StructType], + chain: Seq[GenerateInfo], + attrTypeMap: mutable.Map[ExprId, DataType]): LogicalPlan = { + + if (aboveChainNodes.isEmpty) return currentChild + + // Rebuild in reverse order (bottom-to-top) + aboveChainNodes.reverse.foldLeft(currentChild) { (child, node) => + node match { + case g: Generate => + // Fix ordinals and types in generator child + var fixedGenChild: Expression = g.generator match { + case e: ExplodeBase => + var expr: Expression = e.child + // Fix ordinals for struct field access + expr = colAttrSchemaMap.foldLeft(expr) { case (ex, (exprId, schema)) => + chain.find(_.colAttr.exprId == exprId).map(_.colAttr) match { + case Some(colAttr) => fixOrdinalsInExprWithSchema(ex, colAttr, schema) + case None => ex + } + } + // Update attribute types + expr = updateAttributeTypes(expr, attrTypeMap.toMap) + expr + case other => other.children.headOption.getOrElse(Literal(null)) + } + val fixedGenerator = g.generator match { + case _: Explode => Explode(fixedGenChild) + case _: PosExplode => PosExplode(fixedGenChild) + case other => other // Keep as-is for other generator types + } + + // Preserve the original unrequiredChildIndex. It was correctly computed + // by ColumnPruning (which runs before this rule) based on which child + // columns are only needed by the generator and not by the plan above. + // Recomputing it from only generator references would incorrectly drop + // columns that are needed as pass-throughs (e.g., SELECT arr, explode(arr)). + g.copy( + generator = fixedGenerator, + child = child) + + case p: Project => + // Fix ordinals and types in project expressions + val fixedProjectList = p.projectList.map { expr => + var e: Expression = expr + e = colAttrSchemaMap.foldLeft(e) { case (ex, (exprId, schema)) => + chain.find(_.colAttr.exprId == exprId).map(_.colAttr) match { + case Some(colAttr) => fixOrdinalsInExprWithSchema(ex, colAttr, schema) + case None => ex + } + } + e = updateAttributeTypes(e, attrTypeMap.toMap) + // Track alias types for downstream + e match { + case alias: Alias => attrTypeMap(alias.exprId) = alias.dataType + case _ => + } + e.asInstanceOf[NamedExpression] + } + Project(fixedProjectList, child) + + case f: Filter => + // Fix ordinals and types in filter condition + var fixedCond: Expression = f.condition + fixedCond = colAttrSchemaMap.foldLeft(fixedCond) { case (ex, (exprId, schema)) => + chain.find(_.colAttr.exprId == exprId).map(_.colAttr) match { + case Some(colAttr) => + fixOrdinalsInExprWithSchema(ex, colAttr, schema) + case None => ex + } + } + fixedCond = updateAttributeTypes(fixedCond, attrTypeMap.toMap) + Filter(fixedCond, child) + + case other => + // For other node types, just update the child + other.withNewChildren(Seq(child)) + } + } + } + + /** + * Fixes ordinals in expressions using a nested StructType schema. + * This handles both direct field access and nested struct/array access. + */ + private def fixOrdinalsInExprWithSchema( + expr: Expression, + colAttr: Attribute, + newSchema: StructType): Expression = { + + // For GetStructField: colAttr should have StructType + val structColAttr = colAttr.withDataType(newSchema) + + // For GetArrayStructFields: colAttr should have ArrayType(StructType) + // Preserve the containsNull from the original array type if available + val arrayContainsNull = colAttr.dataType match { + case ArrayType(_, cn) => cn + case _ => true + } + val arrayColAttr = colAttr.withDataType(ArrayType(newSchema, arrayContainsNull)) + + // Transform bottom-up to properly propagate type changes + expr.transformUp { + // GetStructField directly on colAttr + case gsf @ GetStructField(child, _, _) if isAttrRef(child, colAttr) => + val fieldName = gsf.extractFieldName + if (newSchema.fieldNames.contains(fieldName)) { + val newOrdinal = newSchema.fieldIndex(fieldName) + gsf.copy(child = structColAttr, ordinal = newOrdinal) + } else { + gsf + } + + // GetArrayStructFields directly on colAttr - use newSchema parameter + // because child.dataType is still the OLD type (not yet transformed) + case gasf @ GetArrayStructFields(child, field, oldOrdinal, _, containsNull) + if isAttrRef(child, colAttr) => + if (newSchema.fieldNames.contains(field.name)) { + val newOrdinal = newSchema.fieldIndex(field.name) + val newField = newSchema(newOrdinal) + gasf.copy( + child = arrayColAttr, + field = newField, + ordinal = newOrdinal, + numFields = newSchema.length) + } else { + gasf + } + + // GetArrayStructFields on a chain rooted at colAttr (not directly on colAttr) + // For example: l1.level2.level3 where l1.level2 was already fixed above + // Here we can use child.dataType since the child has already been transformed + case gasf @ GetArrayStructFields(child, field, oldOrdinal, _, containsNull) + if !isAttrRef(child, colAttr) && isColAttrChainChild(child, colAttr) => + child.dataType match { + case ArrayType(innerStruct: StructType, _) => + if (innerStruct.fieldNames.contains(field.name)) { + val newOrdinal = innerStruct.fieldIndex(field.name) + val newField = innerStruct(newOrdinal) + gasf.copy( + field = newField, + ordinal = newOrdinal, + numFields = innerStruct.length) + } else { + gasf + } + case _ => gasf + } + + // GetStructField on a result that comes from our colAttr chain + // For example: l1.level2.some_struct_field where level2 contains a struct + case gsf @ GetStructField(child, oldOrdinal, _) + if !isAttrRef(child, colAttr) && isColAttrChainChild(child, colAttr) => + child.dataType match { + case innerStruct: StructType => + val fieldName = gsf.extractFieldName + if (innerStruct.fieldNames.contains(fieldName)) { + val newOrdinal = innerStruct.fieldIndex(fieldName) + gsf.copy(ordinal = newOrdinal) + } else { + gsf + } + case _ => gsf + } + + // GetNestedArrayStructFields directly on colAttr + // (for nested arrays like array>) + case gnasf @ GetNestedArrayStructFields(child, field, oldOrdinal, _, containsNull) + if isAttrRef(child, colAttr) => + // For nested arrays, we need to wrap newSchema in the appropriate array nesting + val nestedArrayColAttr = colAttr.dataType match { + case at: ArrayType => + // Preserve original array nesting, just update innermost struct + def updateInnermostStruct(dt: DataType): DataType = dt match { + case ArrayType(inner: ArrayType, cn) => + ArrayType(updateInnermostStruct(inner), cn) + case ArrayType(_: StructType, cn) => + ArrayType(newSchema, cn) + case other => other + } + colAttr.withDataType(updateInnermostStruct(at)) + case _ => colAttr + } + if (newSchema.fieldNames.contains(field.name)) { + val newOrdinal = newSchema.fieldIndex(field.name) + val newField = newSchema(newOrdinal) + gnasf.copy( + child = nestedArrayColAttr, + field = newField, + ordinal = newOrdinal, + numFields = newSchema.length) + } else { + gnasf + } + + // GetNestedArrayStructFields on a chain rooted at colAttr + case gnasf @ GetNestedArrayStructFields(child, field, oldOrdinal, _, containsNull) + if !isAttrRef(child, colAttr) && isColAttrChainChild(child, colAttr) => + // Extract innermost struct from child's nested array type + def getInnermostStruct(dt: DataType): Option[StructType] = dt match { + case ArrayType(inner: ArrayType, _) => getInnermostStruct(inner) + case ArrayType(st: StructType, _) => Some(st) + case _ => None + } + getInnermostStruct(child.dataType) match { + case Some(innerStruct) if innerStruct.fieldNames.contains(field.name) => + val newOrdinal = innerStruct.fieldIndex(field.name) + val newField = innerStruct(newOrdinal) + gnasf.copy( + field = newField, + ordinal = newOrdinal, + numFields = innerStruct.length) + case _ => gnasf + } + } + } + + /** + * Checks if an expression is a chain of field accesses rooted at the given attribute. + * For example, GetStructField(GetStructField(colAttr, ...), ...) would return true. + */ + private def isColAttrChainChild(expr: Expression, colAttr: Attribute): Boolean = { + expr match { + case a: Attribute => a.exprId == colAttr.exprId + case gsf: GetStructField => isColAttrChainChild(gsf.child, colAttr) + case gasf: GetArrayStructFields => isColAttrChainChild(gasf.child, colAttr) + case gnasf: GetNestedArrayStructFields => isColAttrChainChild(gnasf.child, colAttr) + case _ => false + } + } + + /** + * Rebuilds intermediate nodes (Projects, Filters) between Generates. + * These nodes need their expressions fixed when the schema changes. + * + * The nodes are stored top-to-bottom (closer to the Generate first), + * so we rebuild them in reverse order (bottom-to-top). + * + * Important: We must update BOTH: + * 1. GetStructField ordinals (handled by fixOrdinalsInExprWithSchema) + * 2. AttributeReference data types (for aliases like `Alias(outer_elem, "_extract")`) + * + * When we update a Generate's output schema, any aliases that reference that + * output need their AttributeReference types updated so downstream consumers + * (like inner generates) see the correct pruned schema. + * + * This function also updates the attrTypeMap with alias output types, so that + * attributes defined by these aliases can be updated in subsequent generates. + * + * @param intermediateNodes The intermediate nodes to rebuild + * @param currentChild The plan to use as the child of the bottommost intermediate node + * @param colAttrSchemaMap Map of colAttr exprId -> new schema for ordinal fixes + * @param chain The full chain (for looking up colAttr by exprId) + * @param attrTypeMap Mutable map of attribute types to update (also receives new alias types) + * @return The rebuilt plan with intermediate nodes above currentChild + */ + private def rebuildIntermediateNodes( + intermediateNodes: Seq[LogicalPlan], + currentChild: LogicalPlan, + colAttrSchemaMap: Map[ExprId, StructType], + chain: Seq[GenerateInfo], + attrTypeMap: mutable.Map[ExprId, DataType]): LogicalPlan = { + + // Rebuild in reverse order (bottom-to-top) + intermediateNodes.reverse.foldLeft(currentChild) { (child, node) => + node match { + case Project(projectList, _) => + // Fix ordinals and attribute types in all expressions + val fixedProjectList = projectList.map { expr => + var e: Expression = expr + // First, fix ordinals for GetStructField + e = colAttrSchemaMap.foldLeft(e) { case (ex, (exprId, schema)) => + chain.find(_.colAttr.exprId == exprId).map(_.colAttr) match { + case Some(colAttr) => fixOrdinalsInExprWithSchema(ex, colAttr, schema) + case None => ex + } + } + // Then, update AttributeReference types + e = updateAttributeTypes(e, attrTypeMap.toMap) + + // Track alias output types for downstream use + e match { + case alias: Alias => + // The alias's output type is derived from its child's type + // Add it to the type map so attributes referencing this alias can be updated + attrTypeMap(alias.exprId) = alias.dataType + case _ => + } + + e.asInstanceOf[NamedExpression] + } + Project(fixedProjectList, child) + + case Filter(condition, _) => + // Fix ordinals and attribute types in the condition + var fixedCondition = colAttrSchemaMap.foldLeft(condition) { + case (e, (exprId, schema)) => + chain.find(_.colAttr.exprId == exprId).map(_.colAttr) match { + case Some(colAttr) => fixOrdinalsInExprWithSchema(e, colAttr, schema) + case None => e + } + } + fixedCondition = updateAttributeTypes(fixedCondition, attrTypeMap.toMap) + Filter(fixedCondition, child) + + case other => + // Unexpected node type - preserve as-is + other.withNewChildren(Seq(child)) + } + } + } + + /** + * Updates AttributeReference data types in an expression tree. + * This is needed when a Generate's output schema changes - any aliases + * or expressions that reference the generator's output attributes need + * their types updated. + * + * @param expr The expression to update + * @param attrTypeMap Map of exprId -> new dataType + * @return Expression with updated attribute types + */ + private def updateAttributeTypes( + expr: Expression, + attrTypeMap: Map[ExprId, DataType]): Expression = { + expr.transformDown { + case attr: Attribute if attrTypeMap.contains(attr.exprId) => + attr.withDataType(attrTypeMap(attr.exprId)) + } + } + + // --------------------------------------------------------------------------- + // Helpers + // --------------------------------------------------------------------------- + + /** + * Checks if a data type is an array containing structs at any nesting level. + * Returns true for `array`, `array>`, etc. + */ + private def isArrayOfStruct(dt: DataType): Boolean = dt match { + case ArrayType(_: StructType, _) => true + case ArrayType(elementType: ArrayType, _) => isArrayOfStruct(elementType) + case _ => false + } + + /** + * Checks if a data type is a DIRECT array of struct (not nested arrays). + * Returns true only for `array`, false for `array>`. + * + * This is important for chain extraction: when exploding array>, + * the output element is array, not struct, so it's not a prunable + * struct access. + */ + private def isDirectArrayOfStruct(dt: DataType): Boolean = dt match { + case ArrayType(_: StructType, _) => true + case _ => false + } + + /** + * Extracts the direct struct type from an array type. + * Unlike extractInnermostStruct, this does NOT dig into nested arrays. + */ + private def extractDirectStruct(dt: DataType): (StructType, Boolean) = dt match { + case ArrayType(st: StructType, containsNull) => (st, containsNull) + case _ => throw new IllegalArgumentException(s"Expected array, got: $dt") + } + + /** + * Extracts the innermost struct type and containsNull from a nested array type. + * For `array`, returns the struct directly. + * For `array>`, returns the innermost struct. + */ + private def extractInnermostStruct(dt: DataType): (StructType, Boolean) = dt match { + case ArrayType(st: StructType, containsNull) => (st, containsNull) + case ArrayType(elementType: ArrayType, _) => extractInnermostStruct(elementType) + case _ => throw new IllegalArgumentException(s"Expected nested array of struct, got: $dt") + } + + private def isAttrRef(e: Expression, attr: Attribute): Boolean = e match { + case a: Attribute => a.exprId == attr.exprId + case _ => false + } + + /** + * Checks if the required schema has any pruning compared to the original schema. + * Currently only checks top-level field count reduction. + * + * Note: We intentionally only check top-level field count, not nested types, + * because the rewrite logic for nested type changes is complex and not fully + * implemented for all cases (e.g., inner generate pruning through chained generates). + */ + private def hasNestedTypePruning(req: StructType, orig: StructType): Boolean = { + // Check top-level field count reduction + if (req.length < orig.length) return true + + // Check if any field has a pruned nested type + req.fields.exists { reqField => + orig.fields.find(_.name == reqField.name) match { + case Some(origField) => + hasNestedDataTypePruning(reqField.dataType, origField.dataType) + case None => + false // Field not in original (shouldn't happen) + } + } + } + + /** + * Recursively checks if a data type has been pruned compared to the original. + * Returns true if the required type is "smaller" than the original. + */ + private def hasNestedDataTypePruning(req: DataType, orig: DataType): Boolean = { + (req, orig) match { + case (reqSt: StructType, origSt: StructType) => + hasNestedTypePruning(reqSt, origSt) + case (ArrayType(reqElem, _), ArrayType(origElem, _)) => + hasNestedDataTypePruning(reqElem, origElem) + case _ => + false + } + } + + /** + * Extracts the root attribute from an expression that is either: + * 1. A direct Attribute + * 2. A chain of GetStructField expressions rooted at an Attribute + * + * This allows pruning for patterns like `explode(col.structArr)` where + * the array is nested inside a struct column from the scan. + * + * Returns None for complex expressions (e.g., UDFs, computations). + */ + private def extractRootAttribute(expr: Expression): Option[Attribute] = { + expr match { + case a: Attribute => Some(a) + case GetStructField(child, _, _) => extractRootAttribute(child) + case GetArrayStructFields(child, _, _, _, _) => extractRootAttribute(child) + case GetNestedArrayStructFields(child, _, _, _, _) => extractRootAttribute(child) + case _ => None + } + } + + /** + * Traces an attribute through intermediate Projects to find its ultimate scan source. + * This is needed because GeneratorNestedColumnAliasing may have extracted array fields + * into aliases. For example: + * - Generate source: _extract_inner_array#21 + * - Intermediate Project defines: _extract_inner_array = outer_array.inner_array + * - This function returns: outer_array (the scan attribute) + * + * @param attr The attribute to trace + * @param chain The generate chain (contains intermediate nodes) + * @param scanAttrIds Set of scan attribute IDs + * @return The ultimate scan-rooted attribute, or None if can't be traced + */ + private def traceToScanAttribute( + attr: Attribute, + chain: Seq[GenerateInfo], + scanAttrIds: Set[ExprId]): Option[Attribute] = { + + // If already a scan attribute, return it + if (scanAttrIds.contains(attr.exprId)) { + return Some(attr) + } + + + // Look through all intermediate Projects in the chain for alias definitions + val allIntermediates = chain.flatMap(_.intermediateNodes) + allIntermediates.foreach { + case proj @ Project(projectList, _) => + projectList.foreach { + case a @ Alias(child, _) if a.exprId == attr.exprId => + // Found the alias definition. Trace the child expression. + extractRootAttribute(child) match { + case Some(childAttr) => + // Recursively trace + val traced = traceToScanAttribute(childAttr, chain, scanAttrIds) + if (traced.isDefined) return traced + case None => + // Child is a complex expression. Try GetArrayStructFields + // or GetNestedArrayStructFields. + child match { + case gasf: GetArrayStructFields => + extractRootAttribute(gasf.child) match { + case Some(gasfRootAttr) => + val traced = traceToScanAttribute(gasfRootAttr, chain, scanAttrIds) + if (traced.isDefined) return traced + case None => + } + case gnasf: GetNestedArrayStructFields => + extractRootAttribute(gnasf.child) match { + case Some(gnasfRootAttr) => + val traced = traceToScanAttribute(gnasfRootAttr, chain, scanAttrIds) + if (traced.isDefined) return traced + case None => + } + case _ => + } + } + case _ => + } + case other => + } + + // Couldn't trace to scan + None + } + + /** + * Resolves an expression through intermediate Projects to find its scan-rooted form. + * Similar to traceToScanAttribute but returns the full expression path. + * + * For example: + * - Input: _extract_inner_array#21 + * - Intermediate: _extract_inner_array = outer_array.inner_array + * - Returns: outer_array.inner_array + * + * This resolved expression can be used to build GetArrayStructFields that + * SchemaPruning can trace to scan attributes. + */ + private def resolveExpressionThroughChain( + expr: Expression, + chain: Seq[GenerateInfo]): Expression = { + + expr match { + case attr: Attribute => + // Look through all intermediate Projects for alias definitions + val allIntermediates = chain.flatMap(_.intermediateNodes) + allIntermediates.foreach { + case Project(projectList, _) => + projectList.foreach { + case a @ Alias(child, _) if a.exprId == attr.exprId => + // Found the alias definition. Recursively resolve the child. + return resolveExpressionThroughChain(child, chain) + case _ => + } + case _ => + } + // No alias found, return as-is + attr + + case GetStructField(child, ordinal, name) => + // Resolve the child and rebuild + val resolvedChild = resolveExpressionThroughChain(child, chain) + GetStructField(resolvedChild, ordinal, name) + + case other => + // For other expressions, return as-is + other + } + } + + // --------------------------------------------------------------------------- + // Nested schema requirement analysis (Step 11) + // --------------------------------------------------------------------------- + + /** + * Analyzes expressions to build a nested StructType representing required fields + * from the given attribute. Unlike the flat Set[String] approach, this captures + * the full nested structure needed for inner generate pruning. + * + * For example, if expressions access: + * - item.requestId + * - item.items (used by inner explode selecting item.itemId) + * + * This returns: struct>> + * + * @param exprs Expressions to analyze + * @param colAttr The exploded element attribute + * @param elementStruct The original struct type of array elements + * @param chain Optional chain for resolving aliases (if provided) + * @return None if the attribute is referenced directly (all fields needed), + * Some(StructType) with required fields otherwise + */ + /** + * Recursive nested requirements type. Maps field name to its nested requirements. + * An empty map means "need this field with original type". + * A non-empty map means "need this field but prune its contents to only these nested fields". + */ + private type NestedReqs = mutable.Map[String, mutable.Map[String, Any]] + + /** + * Extracts the full field path from a chained GetStructField/GetArrayStructFields expression. + * For example, for `l1.level2.level3` returns Some(Seq("level2", "level3")). + * + * Also traces through generate outputs and aliases to find the ultimate path. + * For example, if `l2` is the output of exploding `l1.level2.level3`, then + * `l2.l3_f1` returns Some(Seq("level2", "level3", "l3_f1")). + */ + private def extractFieldPath( + expr: Expression, + colAttr: Attribute, + generateSourceMap: Map[ExprId, Expression] = Map.empty, + aliasMap: Map[ExprId, Expression] = Map.empty): Option[Seq[String]] = { + expr match { + case gsf: GetStructField if isAttrRef(gsf.child, colAttr) => + Some(Seq(gsf.extractFieldName)) + case gsf: GetStructField => + extractFieldPath(gsf.child, colAttr, generateSourceMap, aliasMap) + .map(_ :+ gsf.extractFieldName) + case gasf: GetArrayStructFields if isAttrRef(gasf.child, colAttr) => + Some(Seq(gasf.field.name)) + case gasf: GetArrayStructFields => + extractFieldPath(gasf.child, colAttr, generateSourceMap, aliasMap) + .map(_ :+ gasf.field.name) + case gnasf: GetNestedArrayStructFields if isAttrRef(gnasf.child, colAttr) => + Some(Seq(gnasf.field.name)) + case gnasf: GetNestedArrayStructFields => + extractFieldPath(gnasf.child, colAttr, generateSourceMap, aliasMap) + .map(_ :+ gnasf.field.name) + + // Handle generate output attributes: l2.l3_f1 where l2 is from exploding an expression + case attr: Attribute if generateSourceMap.contains(attr.exprId) => + val sourceExpr = generateSourceMap(attr.exprId) + // Resolve through alias if needed + val resolvedSource = resolveAliasExpr(sourceExpr, aliasMap) + extractFieldPath(resolvedSource, colAttr, generateSourceMap, aliasMap) + + // Handle alias attributes: resolve to their definition + case attr: Attribute if aliasMap.contains(attr.exprId) => + val aliasedExpr = aliasMap(attr.exprId) + extractFieldPath(aliasedExpr, colAttr, generateSourceMap, aliasMap) + + case _ => + None + } + } + + /** + * Resolves an expression through alias definitions. + */ + private def resolveAliasExpr( + expr: Expression, + aliasMap: Map[ExprId, Expression]): Expression = { + expr match { + case attr: Attribute if aliasMap.contains(attr.exprId) => + resolveAliasExpr(aliasMap(attr.exprId), aliasMap) + case gsf @ GetStructField(child, _, _) => + gsf.copy(child = resolveAliasExpr(child, aliasMap)) + case gasf @ GetArrayStructFields(child, field, ordinal, numFields, containsNull) => + GetArrayStructFields(resolveAliasExpr(child, aliasMap), field, ordinal, + numFields, containsNull) + case gnasf @ GetNestedArrayStructFields( + child, field, ordinal, numFields, containsNull) => + GetNestedArrayStructFields(resolveAliasExpr(child, aliasMap), field, ordinal, + numFields, containsNull) + case _ => expr + } + } + + /** + * Collects all generate outputs and their source expressions from the plan. + * This allows tracing through generate chains to understand field paths. + */ + private def collectGenerateSources( + plan: LogicalPlan): Map[ExprId, Expression] = { + val sources = mutable.Map.empty[ExprId, Expression] + + def collect(p: LogicalPlan): Unit = { + p match { + case g: Generate => + g.generator match { + case e: ExplodeBase => + // For explode/posexplode, the element output attribute comes from the source array + val colAttr = if (e.position) g.generatorOutput(1) else g.generatorOutput.head + sources(colAttr.exprId) = e.child + case _ => + } + collect(g.child) + case _ => + p.children.foreach(collect) + } + } + + collect(plan) + sources.toMap + } + + /** + * Adds a field path to nested requirements. Creates intermediate nodes as needed. + */ + private def addPathToNestedReqs( + reqs: mutable.Map[String, Any], + path: Seq[String]): Unit = { + if (path.isEmpty) return + val field = path.head + if (path.length == 1) { + // Leaf of path - mark as needed (empty map means use original type) + if (!reqs.contains(field)) { + reqs(field) = mutable.Map.empty[String, Any] + } + } else { + // Intermediate field - recurse into nested requirements + val innerReqs = reqs.getOrElseUpdate(field, mutable.Map.empty[String, Any]) + .asInstanceOf[mutable.Map[String, Any]] + addPathToNestedReqs(innerReqs, path.tail) + } + } + + /** + * Unwraps nested ArrayTypes and returns the innermost struct along with + * the containsNull flags at each array level (outermost first). + * + * @return (innerStruct, arrayNullFlags) or None if no struct found + */ + private def unwrapArraysToStruct(dt: DataType): Option[(StructType, Seq[Boolean])] = { + dt match { + case st: StructType => Some((st, Seq.empty)) + case ArrayType(inner, containsNull) => + unwrapArraysToStruct(inner).map { case (st, flags) => + (st, containsNull +: flags) + } + case _ => None + } + } + + /** + * Rewraps a struct type with the given array nesting levels. + * + * @param st The struct type to wrap + * @param arrayNullFlags containsNull flags for each array level (outermost first) + * @return The wrapped type (array<...array>) + */ + private def rewrapWithArrays(st: StructType, arrayNullFlags: Seq[Boolean]): DataType = { + arrayNullFlags.foldRight[DataType](st) { (containsNull, inner) => + ArrayType(inner, containsNull) + } + } + + /** + * Builds a pruned StructField based on nested requirements. + * Recursively prunes nested struct/array types, supporting arbitrary array nesting depth. + */ + private def buildPrunedField( + origField: StructField, + nestedReqs: mutable.Map[String, Any]): StructField = { + if (nestedReqs.isEmpty) { + // Empty nested reqs means use original field as-is + return origField + } + + // Helper to prune a struct type based on requirements + def pruneStructType( + st: StructType, + reqs: mutable.Map[String, Any]): (Array[StructField], Boolean) = { + val prunedFields = st.fieldNames.flatMap { fname => + reqs.get(fname).map { innerReqsAny => + val innerReqs = innerReqsAny.asInstanceOf[mutable.Map[String, Any]] + val innerOrigField = st(st.fieldIndex(fname)) + buildPrunedField(innerOrigField, innerReqs) + } + } + val anyNestedPruning = prunedFields.zip(st.fields).exists { + case (pruned, orig) => pruned.dataType != orig.dataType + } + val hasPruning = prunedFields.length < st.length || anyNestedPruning + (prunedFields, hasPruning) + } + + origField.dataType match { + case st: StructType => + val (prunedFields, hasPruning) = pruneStructType(st, nestedReqs) + if (hasPruning) { + origField.copy(dataType = StructType(prunedFields)) + } else { + origField + } + + case _ => + // Try to unwrap arrays to find inner struct + unwrapArraysToStruct(origField.dataType) match { + case Some((innerStruct, arrayNullFlags)) if arrayNullFlags.nonEmpty => + val (prunedFields, hasPruning) = pruneStructType(innerStruct, nestedReqs) + if (hasPruning) { + val prunedType = rewrapWithArrays(StructType(prunedFields), arrayNullFlags) + origField.copy(dataType = prunedType) + } else { + origField + } + + case _ => + // Not a struct or array-of-struct - return as-is + origField + } + } + } + + private def analyzeRequiredSchema( + exprs: Seq[Expression], + colAttr: Attribute, + elementStruct: StructType, + chain: Seq[GenerateInfo] = Seq.empty, + generateSourceMap: Map[ExprId, Expression] = Map.empty, + fullAliasMap: Map[ExprId, Expression] = Map.empty): Option[StructType] = { + + // Track fields with their nested requirements as a tree structure. + // For deeply nested access like l1.level2.level3.l3_f1, we track: + // level2 -> { level3 -> { l3_f1 -> {} } } + // This allows pruning at all levels, not just one. + val nestedReqs = mutable.Map.empty[String, Any] + var canPrune = true + + // Build a map of alias definitions from intermediate nodes + val aliasDefinitions: Map[ExprId, Expression] = { + val allIntermediates = chain.flatMap(_.intermediateNodes) + val intermediateAliases = allIntermediates.flatMap { + case Project(projectList, _) => + projectList.collect { + case a @ Alias(child, _) => a.exprId -> child + } + case _ => Seq.empty + }.toMap + // Merge with full alias map (from entire plan) + fullAliasMap ++ intermediateAliases + } + + def analyze(e: Expression): Unit = { + if (!canPrune) return + e match { + // GetStructField on our element attribute (single level access) + case gsf: GetStructField if isAttrRef(gsf.child, colAttr) => + val fieldName = gsf.extractFieldName + if (elementStruct.fieldNames.contains(fieldName)) { + // Single-level access: need the full field + if (!nestedReqs.contains(fieldName)) { + nestedReqs(fieldName) = mutable.Map.empty[String, Any] + } + } + + // GetArrayStructFields on our element attribute (single level access) + case gasf: GetArrayStructFields if isAttrRef(gasf.child, colAttr) => + val fieldName = gasf.field.name + if (elementStruct.fieldNames.contains(fieldName)) { + if (!nestedReqs.contains(fieldName)) { + nestedReqs(fieldName) = mutable.Map.empty[String, Any] + } + } + + // GetNestedArrayStructFields on our element attribute (single level access) + case gnasf: GetNestedArrayStructFields if isAttrRef(gnasf.child, colAttr) => + val fieldName = gnasf.field.name + if (elementStruct.fieldNames.contains(fieldName)) { + if (!nestedReqs.contains(fieldName)) { + nestedReqs(fieldName) = mutable.Map.empty[String, Any] + } + } + + // Chained GetStructField or GetArrayStructFields - extract full path + // Now also traces through generate outputs and aliases + case gsf: GetStructField => + extractFieldPath(gsf, colAttr, generateSourceMap, aliasDefinitions) match { + case Some(path) if path.nonEmpty => + if (elementStruct.fieldNames.contains(path.head)) { + addPathToNestedReqs(nestedReqs, path) + } + case _ => + // Not rooted at our colAttr, recurse + gsf.children.foreach(analyze) + } + + case gasf: GetArrayStructFields => + extractFieldPath(gasf, colAttr, generateSourceMap, aliasDefinitions) match { + case Some(path) if path.nonEmpty => + if (elementStruct.fieldNames.contains(path.head)) { + addPathToNestedReqs(nestedReqs, path) + } + case _ => + // Not rooted at our colAttr, recurse + gasf.children.foreach(analyze) + } + + case gnasf: GetNestedArrayStructFields => + extractFieldPath(gnasf, colAttr, generateSourceMap, aliasDefinitions) match { + case Some(path) if path.nonEmpty => + if (elementStruct.fieldNames.contains(path.head)) { + addPathToNestedReqs(nestedReqs, path) + } + case _ => + gnasf.children.foreach(analyze) + } + + // Direct reference to the element attribute - all fields needed + case a: Attribute if a.exprId == colAttr.exprId => + canPrune = false + + // Alias reference - resolve and analyze the aliased expression + case a: Attribute if aliasDefinitions.contains(a.exprId) => + val aliasedExpr = aliasDefinitions(a.exprId) + analyze(aliasedExpr) + + case other => + other.children.foreach(analyze) + } + } + + exprs.foreach(analyze) + + if (!canPrune) { + None + } else if (nestedReqs.isEmpty) { + // No fields needed - shouldn't happen but return empty struct + Some(StructType(Seq.empty)) + } else { + // Build the result schema with pruned nested arrays + // CRITICAL: Order fields according to the original elementStruct order, not discovery order! + // This ensures that ordinals in expressions (e.g., filter on b_int at ordinal 1) match + // the actual ArraysZip output order (which also uses original struct order). + val resultFields = elementStruct.fieldNames.flatMap { fieldName => + nestedReqs.get(fieldName).map { innerReqsAny => + val origField = elementStruct(elementStruct.fieldIndex(fieldName)) + val innerReqs = innerReqsAny.asInstanceOf[mutable.Map[String, Any]] + buildPrunedField(origField, innerReqs) + } + }.toSeq + Some(StructType(resultFields)) + } + } + + /** + * Merges an inner generate's required element schema into the outer generate's + * required schema for the array field that the inner generate explodes. + * + * For example: + * - Outer generates elements with struct>> + * - Inner explodes items, requiring struct + * - Merged outer requirement: struct>> + * + * Note: This REPLACES the existing field with the pruned version. The inner + * requirements are MORE specific (fewer fields needed) than the original. + * + * Merges inner generate's requirements into outer schema following a field path. + * + * For example, with: + * - outerElementStruct = struct>>> + * - fieldPath = Seq("level2", "level3") + * - innerElementSchema = struct + * + * Returns a StructType representing: + * struct>>>> + * + * This properly handles nested paths where inner generates access deeply nested arrays. + */ + private def mergeInnerRequirementsWithPath( + outerSchema: StructType, + fieldPath: Seq[String], + innerElementSchema: StructType, + outerElementStruct: StructType): StructType = { + + + if (fieldPath.isEmpty) { + return outerSchema + } + + val firstField = fieldPath.head + val remainingPath = fieldPath.tail + + // Get the original field definition from the struct + if (!outerElementStruct.fieldNames.contains(firstField)) { + return outerSchema // Field not found, return unchanged + } + + val origFieldIdx = outerElementStruct.fieldIndex(firstField) + val origField = outerElementStruct(origFieldIdx) + + // Build the pruned data type + val prunedDataType = if (remainingPath.isEmpty) { + // At the leaf of the path - this is the array that the inner generate explodes + // Handle arbitrary array nesting: array<...array> -> array<...array> + unwrapArraysToStruct(origField.dataType) match { + case Some((_, arrayNullFlags)) if arrayNullFlags.nonEmpty => + // Replace innermost struct with the inner element schema + rewrapWithArrays(innerElementSchema, arrayNullFlags) + case _ => + origField.dataType + } + } else { + // Not at leaf - recurse into the nested array/struct + // Handle arbitrary array nesting depth + unwrapArraysToStruct(origField.dataType) match { + case Some((innerStruct, arrayNullFlags)) if arrayNullFlags.nonEmpty => + // Recurse to build the nested pruned structure + val nestedPruned = mergeInnerRequirementsWithPath( + StructType(Seq.empty), remainingPath, innerElementSchema, innerStruct) + rewrapWithArrays(nestedPruned, arrayNullFlags) + case Some((innerStruct, _)) => + // Direct struct, no array wrapping + mergeInnerRequirementsWithPath( + StructType(Seq.empty), remainingPath, innerElementSchema, innerStruct) + case None => + // Not an array of struct or struct - can't navigate deeper + origField.dataType + } + } + + // Create the field with pruned type + val prunedField = origField.copy(dataType = prunedDataType) + + // Check if the field already exists in outer schema + val existingFieldOpt = outerSchema.fields.find(_.name == firstField) + + val result = existingFieldOpt match { + case Some(_) => + // Replace the existing field with the pruned version + StructType(outerSchema.fields.map { f => + if (f.name == firstField) prunedField else f + }) + case None => + // Add the new field to outer schema + StructType(outerSchema.fields :+ prunedField) + } + result + } + + /** + * Computes nested schema requirements for each Generate in the chain. + * This is the key method for inner generate pruning (Step 12). + * + * Unlike the flat `computeChainRequirements` which returns `Set[String]`, + * this returns a nested `StructType` per Generate that includes: + * 1. Direct field accesses (e.g., item.requestId) + * 2. Nested array fields with their own pruned element schema (for inner generates) + * + * The propagation works backward through the chain: + * - Start with the innermost generate's requirements + * - For each outer generate, merge inner requirements into its schema + * + * @param chain The sequence of GenerateInfo (top-to-bottom: closest to Project first) + * @param inputExprs Expressions to analyze (from top-level project, filters, intermediate nodes) + * @param additionalFilters Additional filter conditions + * @param leaf The leaf node below the chain + * @return Sequence of optional StructTypes (None = can't prune, Some = required schema) + */ + /** + * Collects all alias definitions from a plan tree. + */ + private def collectAliasDefinitions(plan: LogicalPlan): Map[ExprId, Expression] = { + val aliases = mutable.Map.empty[ExprId, Expression] + + def collect(p: LogicalPlan): Unit = { + p match { + case Project(projectList, child) => + projectList.foreach { + case a @ Alias(childExpr, _) => aliases(a.exprId) = childExpr + case _ => + } + collect(child) + case _ => + p.children.foreach(collect) + } + } + + collect(plan) + aliases.toMap + } + + private def computeNestedChainRequirements( + chain: Seq[GenerateInfo], + inputExprs: Seq[Expression], + additionalFilters: Seq[Expression], + leaf: LogicalPlan, + chainStart: LogicalPlan): Seq[Option[StructType]] = { + + // Collect generate sources from the entire plan (to trace through generate outputs) + val generateSourceMap = collectGenerateSources(chainStart) + + // Collect all alias definitions from the plan + val fullAliasMap = collectAliasDefinitions(chainStart) + + + // Extract filters from the leaf path + val (leafFilters, _) = decomposeChild(leaf) + + // Collect expressions from ALL intermediate nodes in the chain + // These include filters and projects between generates that may reference the source array + // + // IMPORTANT: Filter out direct attribute references from intermediate Projects. + // When ColumnPruning inserts a pass-through Project like `Project(a_array_item, ...)`, + // the `a_array_item` attribute reference is just a pass-through, NOT an indication + // that all fields of the struct are needed. We only care about GetStructField accesses + // from intermediate nodes, not direct attribute references. + val intermediateNodeExprs = chain.flatMap { info => + info.intermediateNodes.flatMap { + case Project(projectList, _) => + // Filter out direct attribute references (pass-throughs) + projectList.filter { + case _: Attribute => false + case Alias(_: Attribute, _) => false + case _ => true + } + case Filter(condition, _) => Seq(condition) + case _ => Nil + } + } + + val topExprs: Seq[Expression] = + inputExprs ++ additionalFilters ++ leafFilters ++ intermediateNodeExprs + + // First pass: compute direct requirements for each generate (without inner propagation) + val directRequirements: Seq[Option[StructType]] = chain.indices.map { i => + val info = chain(i) + + // Collect expressions that may reference this generate's output + val parentGenExprs = chain.take(i).flatMap { parentInfo => + Seq(parentInfo.generator.child) + } + val exprs = topExprs ++ parentGenExprs + + // Analyze direct requirements from expressions + // Pass generate source map and alias map for tracing through generate chains + val directOpt = analyzeRequiredSchema( + exprs, info.colAttr, info.elementStruct, chain, generateSourceMap, fullAliasMap) + + // Also analyze source array field accesses + val sourceRootAttr = extractRootAttribute(info.generator.child) + val sourceFieldsOpt = analyzeSourceArrayFields(exprs, sourceRootAttr, info.generator.child) + + (directOpt, sourceFieldsOpt) match { + case (None, _) => None // Direct element reference - can't prune + case (_, None) => None // Complex source expression - can't prune + case (Some(directSchema), Some(sourceFields)) => + // Merge source array field requirements into schema + val withSourceFields = sourceFields.foldLeft(directSchema) { (schema, fieldName) => + if (info.elementStruct.fieldNames.contains(fieldName)) { + val origField = info.elementStruct(info.elementStruct.fieldIndex(fieldName)) + if (schema.fieldNames.contains(fieldName)) { + schema // Already present + } else { + StructType(schema.fields :+ origField) + } + } else { + schema + } + } + Some(withSourceFields) + } + } + + // Second pass: propagate requirements from generates closer to Project to those closer to Scan. + // Chain is ordered top-to-bottom: [generate closest to Project, ..., generate closest to Scan] + // If chain[i]'s source is from chain[j] (where j > i), merge chain[i]'s element requirements + // into chain[j]'s schema for that array field. + val propagatedRequirements = mutable.ArrayBuffer[Option[StructType]]() + propagatedRequirements ++= directRequirements + + // Iterate from generates closest to Project (index 0) toward Scan + for (i <- 0 until chain.length - 1) { + val currentInfo = chain(i) + val currentReqOpt = propagatedRequirements(i) + + + currentReqOpt match { + case Some(currentSchema) if currentSchema.nonEmpty => + // Find which source generate (at larger index) provides this generate's input + val sourceIdx = findSourceGenerateForCurrent(chain, i, currentInfo) + + sourceIdx match { + case Some(si) => + // Extract the full field path that current accesses from source's element + val sourceFieldPath = extractSourceFieldPath(chain(si), currentInfo) + val sourceStruct = chain(si).elementStruct + + if (sourceFieldPath.nonEmpty) { + if (propagatedRequirements(si).isDefined) { + // Source already has requirements - merge inner's into it + val sourceSchema = propagatedRequirements(si).get + val merged = mergeInnerRequirementsWithPath( + sourceSchema, sourceFieldPath, currentSchema, sourceStruct) + propagatedRequirements(si) = Some(merged) + } else { + // Source has no direct requirements (None). + // Initialize with inner's requirements for the accessed path. + val emptySchema = StructType(Seq.empty) + val merged = mergeInnerRequirementsWithPath( + emptySchema, sourceFieldPath, currentSchema, sourceStruct) + propagatedRequirements(si) = Some(merged) + } + } else { + // Inner generate explodes outer's element directly (not a field of it). + // This happens when GeneratorNestedColumnAliasing has already extracted + // the array field. In this case, inner's element type IS outer's innermost + // element type, so inner's requirements apply directly to outer. + if (sourceStruct.fieldNames.toSet == + currentInfo.elementStruct.fieldNames.toSet) { + // Inner's requirements become outer's requirements + propagatedRequirements(si) = Some(currentSchema) + } + } + + case _ => // No source generate found + } + + case _ => // Current can't be pruned or has no requirements + } + } + + // Handle pos-only posexplode optimization and empty requirements + propagatedRequirements.zipWithIndex.map { case (reqOpt, i) => + val info = chain(i) + reqOpt match { + case Some(schema) if schema.isEmpty && info.posAttrOpt.isDefined => + // Pos-only: select minimal-weight field + if (info.elementStruct.fields.length > 1) { + selectMinimalWeightField(info.elementStruct).map { f => + StructType(Array(f)) + } + } else { + // Already at minimal - no pruning possible (idempotency) + None + } + case Some(schema) if schema.isEmpty => + // Empty requirements with non-posexplode means no explicit field accesses were found. + // This typically happens when GeneratorNestedColumnAliasing has already extracted + // the needed fields into an alias. The generate's source is already optimized, + // so we should not try to "further prune" to an empty struct. + // Return None to indicate no pruning should happen. + None + case other => other + } + }.toSeq + } + + /** + * Finds the index of the source generate whose output is used by the current generate. + * Returns None if the current generate's source is from the scan (not from another generate). + * + * Note: Chain is ordered top-to-bottom, so source generates have LARGER indices. + * + * When GeneratorNestedColumnAliasing creates intermediate Projects with aliases, + * we need to trace through those aliases to find the connection. + */ + private def findSourceGenerateForCurrent( + chain: Seq[GenerateInfo], + currentIdx: Int, + currentInfo: GenerateInfo): Option[Int] = { + + // The current generate's source is an expression like `outer.items` or an alias + // We need to find which source generate's colAttr matches the root + extractRootAttribute(currentInfo.generator.child) match { + case Some(rootAttr) => + // First, try direct match + val directMatch = chain.drop(currentIdx + 1).zipWithIndex.find { case (sourceInfo, _) => + sourceInfo.colAttr.exprId == rootAttr.exprId + }.map { case (_, relIdx) => currentIdx + 1 + relIdx } + + directMatch.orElse { + // If no direct match, check if the rootAttr is an alias for another generate's output. + // This happens when GeneratorNestedColumnAliasing creates intermediate aliases. + // Look at intermediate Projects between current and source generates. + findSourceThroughAliases(chain, currentIdx, currentInfo, rootAttr) + } + case None => + None + } + } + + /** + * Traces through intermediate alias Projects to find the source generate. + * When GeneratorNestedColumnAliasing runs, it may create aliases like: + * _extract = outer_elem + * or more complex expressions like: + * _extract_level3 = l1.level2.level3 + * We need to resolve these aliases to find the true source. + */ + private def findSourceThroughAliases( + chain: Seq[GenerateInfo], + currentIdx: Int, + currentInfo: GenerateInfo, + aliasAttr: Attribute): Option[Int] = { + + + // Extracts the root attribute from expressions like: + // - Attribute: returns it directly + // - GetStructField(child, ...) / GetArrayStructFields(child, ...) chains: + // recursively extracts root from child + @scala.annotation.tailrec + def extractRootAttrFromExpr(expr: Expression): Option[Attribute] = { + expr match { + case a: Attribute => Some(a) + case GetStructField(child, _, _) => extractRootAttrFromExpr(child) + case GetArrayStructFields(child, _, _, _, _) => extractRootAttrFromExpr(child) + case GetNestedArrayStructFields(child, _, _, _, _) => extractRootAttrFromExpr(child) + case _ => None + } + } + + // Look at the child plan of the current generate to find intermediate Projects + // that might define the alias + def findAliasDefinition(plan: LogicalPlan): Option[ExprId] = { + plan match { + case Project(projectList, child) => + // Find the alias that defines our aliasAttr + val result = projectList.collectFirst { + case a @ Alias(childExpr, _) if a.exprId == aliasAttr.exprId => + // Extract the root attribute from the child expression + val root = extractRootAttrFromExpr(childExpr) + root.map(_.exprId) + }.flatten + result.orElse(findAliasDefinition(child)) + case Filter(_, child) => + findAliasDefinition(child) + case _ => + None + } + } + + // Find what the alias refers to + val resolvedResult = findAliasDefinition(currentInfo.generate.child) + resolvedResult.flatMap { resolvedExprId => + // Now look for a source generate whose colAttr matches the resolved exprId + chain.drop(currentIdx + 1).zipWithIndex.find { case (sourceInfo, _) => + sourceInfo.colAttr.exprId == resolvedExprId + }.map { case (_, relIdx) => currentIdx + 1 + relIdx } + } + } + + /** + * Extracts the full field path that the current generate accesses from the source + * generate's element. For example, if current explodes `source.level2.level3`, + * returns Seq("level2", "level3"). + * + * Handles aliased expressions by tracing through intermediate Projects. + * + * @param sourceInfo The generate whose output provides the array + * @param currentInfo The generate that explodes a field from sourceInfo's output + * @return The field path from source element to the exploded array + */ + private def extractSourceFieldPath( + sourceInfo: GenerateInfo, + currentInfo: GenerateInfo): Seq[String] = { + + /** + * Extracts the full path of field accesses from an expression rooted at colAttr. + * For GetArrayStructFields(GetStructField(colAttr, level2), level3), returns + * Seq("level2", "level3"). + */ + def extractPathFromExpr(expr: Expression, colAttr: Attribute): Seq[String] = { + expr match { + case gsf: GetStructField if isAttrRef(gsf.child, colAttr) => + Seq(gsf.extractFieldName) + case gsf: GetStructField => + extractPathFromExpr(gsf.child, colAttr) :+ gsf.extractFieldName + case gasf: GetArrayStructFields if isAttrRef(gasf.child, colAttr) => + Seq(gasf.field.name) + case gasf: GetArrayStructFields => + extractPathFromExpr(gasf.child, colAttr) :+ gasf.field.name + case gnasf: GetNestedArrayStructFields if isAttrRef(gnasf.child, colAttr) => + Seq(gnasf.field.name) + case gnasf: GetNestedArrayStructFields => + extractPathFromExpr(gnasf.child, colAttr) :+ gnasf.field.name + case _ => + Seq.empty + } + } + + // First try direct match + val directPath = extractPathFromExpr(currentInfo.generator.child, sourceInfo.colAttr) + if (directPath.nonEmpty) { + return directPath + } + + // If generator child is an attribute (alias), resolve it through intermediate Projects + currentInfo.generator.child match { + case attr: Attribute => + currentInfo.intermediateNodes.foreach { + case Project(projectList, _) => + projectList.foreach { + case a @ Alias(childExpr, _) if a.exprId == attr.exprId => + val path = extractPathFromExpr(childExpr, sourceInfo.colAttr) + if (path.nonEmpty) { + return path + } + case _ => + } + case _ => + } + Seq.empty + case _ => + Seq.empty + } + } + + /** + * Checks if two expressions are semantically equivalent for the purpose of + * matching source array expressions (supports Attribute and GetStructField chains). + */ + private def exprMatches(e1: Expression, e2: Expression): Boolean = { + (e1, e2) match { + case (a1: Attribute, a2: Attribute) => a1.exprId == a2.exprId + case (gsf1: GetStructField, gsf2: GetStructField) => + gsf1.extractFieldName == gsf2.extractFieldName && exprMatches(gsf1.child, gsf2.child) + case _ => false + } + } + + /** + * Returns field names accessed on the source array via [[GetArrayStructFields]]. + * These are fields like `someArray.field` where `someArray` is an array of structs. + * + * Note: Direct references to the source array (e.g., in `isnotnull(array)` or + * `size(array)`) are allowed - they don't access specific struct fields. + * We only collect fields that are explicitly extracted via [[GetArrayStructFields]]. + * + * @param exprs Expressions to analyze + * @param sourceRootAttr The root attribute of the source array expression (for eligibility) + * @param sourceArrayExpr The full source array expression (e.g., `col.structArr`) + */ + private def analyzeSourceArrayFields( + exprs: Seq[Expression], + sourceRootAttr: Option[Attribute], + sourceArrayExpr: Expression): Option[Set[String]] = { + sourceRootAttr match { + case None => Some(Set.empty) // Complex expression - allow pruning but no source fields + case Some(_) => + val fieldNames = mutable.LinkedHashSet[String]() + + def analyze(e: Expression): Unit = { + e match { + case gasf: GetArrayStructFields if exprMatches(gasf.child, sourceArrayExpr) => + // Source array field access like someArray.field or col.structArr.field + fieldNames += gasf.field.name + case gnasf: GetNestedArrayStructFields + if exprMatches(gnasf.child, sourceArrayExpr) => + fieldNames += gnasf.field.name + case _ => + // Direct attribute references like isnotnull(someArray) are OK - + // they don't access specific struct fields. + // Continue recursing to find any nested GetArrayStructFields. + e.children.foreach(analyze) + } + } + exprs.foreach(analyze) + Some(fieldNames.toSet) + } + } + + // --------------------------------------------------------------------------- + // Minimal-weight field for pos-only posexplode + // --------------------------------------------------------------------------- + + /** + * Selects the field with the smallest `defaultSize`. Tie-breaker: + * lexicographic field name. + */ + private def selectMinimalWeightField( + st: StructType): Option[StructField] = { + if (st.isEmpty) None + else Some(st.fields.minBy(f => (f.dataType.defaultSize, f.name))) + } + + /** + * Checks if a Project list contains only simple attribute references + * (possibly with aliases). These are typically inserted by ColumnPruning + * and can be safely looked through. + * + * Returns false for Projects with GetStructField or other expressions, + * as these create aliases that the generator may reference. + */ + private def isAttributeOnlyProject(projectList: Seq[NamedExpression]): Boolean = { + projectList.forall { + case _: Attribute => true + case Alias(_: Attribute, _) => true + case _ => false + } + } + + /** + * Walks through a plan extracting Filter conditions and looking through + * attribute-only Projects to find the leaf (non-Project, non-Filter) child. + * This allows our inner Project to be inserted directly above the leaf, + * enabling ScanOperation to match it without interference from filters that + * reference the full source array attribute. + * + * Only looks through Projects that contain simple attribute references + * (from ColumnPruning). Projects with GetStructField or other expressions + * create aliases that the generator may reference, so we don't look through + * them to avoid breaking the reference chain. + * + * Returns (all filter conditions in bottom-first order, the leaf child). + */ + private def decomposeChild( + plan: LogicalPlan): (Seq[Expression], LogicalPlan) = { + plan match { + case Filter(condition, child) => + val (deeper, base) = decomposeChild(child) + (deeper :+ condition, base) + case Project(projectList, child) if isAttributeOnlyProject(projectList) => + // Look through column-pruning Projects inserted by ColumnPruning. + // Our inner Project will handle column selection instead. + decomposeChild(child) + case other => + (Nil, other) + } + } + + /** + * Finds the TRUE scan relation by looking through ALL Projects, Filters, and Generates. + * This is needed to correctly determine scan attribute IDs - we can't use the first + * non-attribute-only Project because that might be an intermediate alias Project + * (e.g., from GeneratorNestedColumnAliasing) whose output includes alias attributes + * that would be incorrectly treated as "scan attributes". + * + * @param plan The plan to traverse + * @return The actual scan relation (or the deepest non-Project/Filter/Generate node) + */ + private def findScanLeaf(plan: LogicalPlan): LogicalPlan = plan match { + case Project(_, child) => findScanLeaf(child) + case Filter(_, child) => findScanLeaf(child) + case Generate(_, _, _, _, _, child) => findScanLeaf(child) + case other => other + } + + // --------------------------------------------------------------------------- + // Pruned-array builder (ArraysZip + GetArrayStructFields) + // --------------------------------------------------------------------------- + + /** + * Builds a pruned array expression from a nested StructType schema. It handles: + * - Primitive fields: extracted with GetArrayStructFields + * - Nested array fields: recursively builds pruned nested arrays + * - Nested struct fields: preserved with nested pruning applied + * + * For example, given: + * - Source: array>>> + * - Required: struct>> + * + * Produces: + * arrays_zip( + * source.requestId, + * arrays_zip(source.items.itemId, source.items.clicked) AS items + * ) + * + * @param arrayExpr The source array expression + * @param originalStruct The original struct type of array elements + * @param requiredSchema The required nested schema (subset of originalStruct) + * @param containsNull Whether the array can contain nulls + * @return Expression that produces the pruned array + */ + private def buildPrunedArrayFromSchema( + arrayExpr: Expression, + originalStruct: StructType, + requiredSchema: StructType, + containsNull: Boolean): Expression = { + + val arrayDepth = computeArrayDepth(arrayExpr.dataType) + + // Order fields according to original struct order + val orderedRequiredFields = originalStruct.fieldNames.toSeq + .filter(name => requiredSchema.fieldNames.contains(name)) + .map(name => requiredSchema(requiredSchema.fieldIndex(name))) + + val fieldArrays = orderedRequiredFields.map { requiredField => + val fieldName = requiredField.name + val origOrdinal = originalStruct.fieldIndex(fieldName) + val origField = originalStruct(origOrdinal) + + (origField.dataType, requiredField.dataType) match { + // Nested array of struct - recursively prune if schemas differ + case (ArrayType(origInnerStruct: StructType, innerContainsNull), + ArrayType(requiredInnerStruct: StructType, _)) + if hasNestedTypePruning(requiredInnerStruct, origInnerStruct) => + // Extract the field arrays first + val fieldExpr = if (arrayDepth > 1) { + GetNestedArrayStructFields( + arrayExpr, origField, origOrdinal, originalStruct.length, + containsNull || origField.nullable) + } else { + GetArrayStructFields( + arrayExpr, origField, origOrdinal, originalStruct.length, + containsNull || origField.nullable) + } + // Recursively prune the nested array + buildPrunedArrayFromSchema( + fieldExpr, origInnerStruct, requiredInnerStruct, innerContainsNull) + + // Nested struct field - recursively prune sub-fields + // e.g., meta:struct to meta:struct when only meta.a is accessed + case (origSt: StructType, reqSt: StructType) + if hasNestedTypePruning(reqSt, origSt) => + val fieldExpr = if (arrayDepth > 1) { + GetNestedArrayStructFields( + arrayExpr, origField, origOrdinal, originalStruct.length, + containsNull || origField.nullable) + } else { + GetArrayStructFields( + arrayExpr, origField, origOrdinal, originalStruct.length, + containsNull || origField.nullable) + } + // fieldExpr is array, recursively prune to array + buildPrunedArrayFromSchema( + fieldExpr, origSt, reqSt, containsNull || origField.nullable) + + // Non-nested field or no pruning needed for nested struct + case _ => + if (arrayDepth > 1) { + GetNestedArrayStructFields( + arrayExpr, origField, origOrdinal, originalStruct.length, + containsNull || origField.nullable) + } else { + GetArrayStructFields( + arrayExpr, origField, origOrdinal, originalStruct.length, + containsNull || origField.nullable) + } + } + } + + val names = orderedRequiredFields.map(f => Literal(f.name)) + + // For nested arrays (depth > 1), use NestedArraysZip to zip at the innermost level + // For simple arrays (depth == 1), use the standard ArraysZip + if (arrayDepth > 1) { + NestedArraysZip.withDepth(fieldArrays, names, arrayDepth) + } else { + ArraysZip(fieldArrays, names) + } + } + + /** + * Computes the depth of array nesting in a data type. + * `array` has depth 1, `array>` has depth 2, etc. + */ + private def computeArrayDepth(dt: DataType): Int = dt match { + case ArrayType(elementType: ArrayType, _) => 1 + computeArrayDepth(elementType) + case ArrayType(_, _) => 1 + case _ => 0 + } + + // --------------------------------------------------------------------------- + // Ordinal fixer (name-based) + // --------------------------------------------------------------------------- + + /** + * Rewrites [[GetArrayStructFields]] and [[GetNestedArrayStructFields]] expressions + * that match the source array expression (supports both direct Attributes and + * GetStructField chains like `col.structArr`). + * + * This handles patterns like: + * - `arr.field` where generator is `explode(arr)` + * - `col.structArr.field` where generator is `explode(col.structArr)` + * + * @param expr The expression to fix + * @param srcExpr The original source array expression (Attribute or GetStructField chain) + * @param prunedAttr The _pruned attribute to replace srcExpr with + * @param newStruct The pruned struct type (element type of the pruned array) + * @param containsNull Whether the array can contain nulls + */ + private def fixArrayStructOrdinalsInExprChain( + expr: Expression, + srcExpr: Expression, + prunedAttr: Attribute, + newStruct: StructType, + containsNull: Boolean): Expression = { + + expr.transformDown { + case gasf @ GetArrayStructFields(child, field, _, _, _) + if exprMatches(child, srcExpr) => + // Find the new ordinal in the pruned struct + val fieldName = field.name + val newOrdinal = newStruct.fieldIndex(fieldName) + val newField = newStruct(newOrdinal) + // Rebuild with updated child (pruned attr), field, ordinal, and numFields + GetArrayStructFields( + prunedAttr, newField, newOrdinal, newStruct.length, + containsNull || newField.nullable) + + case gnasf @ GetNestedArrayStructFields(child, field, _, _, _) + if exprMatches(child, srcExpr) => + val fieldName = field.name + val newOrdinal = newStruct.fieldIndex(fieldName) + val newField = newStruct(newOrdinal) + GetNestedArrayStructFields( + prunedAttr, newField, newOrdinal, newStruct.length, + containsNull || newField.nullable) + + case e if exprMatches(e, srcExpr) => + // Direct reference to source array expression - replace with pruned + prunedAttr + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index 7baad5ea92a00..2ba63941b1d95 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -225,6 +225,474 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { assert(!get3.containsNull) } + test("GetNestedArrayStructFields - basic extraction") { + // Test extraction from array>> + val innerStruct = new StructType() + .add("a", IntegerType, nullable = false) + .add("b", StringType, nullable = true) + val nestedArrayType = ArrayType( + ArrayType(innerStruct, containsNull = false), containsNull = false) + + // Create test data: [[{1, "x"}, {2, "y"}], [{3, "z"}]] + val testData = Seq( + Seq(create_row(1, "x"), create_row(2, "y")), + Seq(create_row(3, "z")) + ) + val input = Literal.create(testData, nestedArrayType) + + // Extract field "a" - should get [[1, 2], [3]] + val fieldA = innerStruct("a") + val extractA = GetNestedArrayStructFields(input, fieldA, 0, 2, containsNull = false) + checkEvaluation(extractA, Seq(Seq(1, 2), Seq(3))) + + // Extract field "b" - should get [["x", "y"], ["z"]] + val fieldB = innerStruct("b") + val extractB = GetNestedArrayStructFields(input, fieldB, 1, 2, containsNull = true) + checkEvaluation(extractB, Seq(Seq("x", "y"), Seq("z"))) + } + + test("GetNestedArrayStructFields - null handling") { + val innerStruct = new StructType() + .add("a", IntegerType, nullable = true) + val nestedArrayType = ArrayType( + ArrayType(innerStruct, containsNull = true), containsNull = true) + + // Test with nulls at various levels + val testData = Seq( + Seq(create_row(1), null, create_row(3)), + null, + Seq(create_row(null)) + ) + val input = Literal.create(testData, nestedArrayType) + + val fieldA = innerStruct("a") + val extract = GetNestedArrayStructFields(input, fieldA, 0, 1, containsNull = true) + checkEvaluation(extract, Seq(Seq(1, null, 3), null, Seq(null))) + + // Test with null input + val nullInput = Literal.create(null, nestedArrayType) + checkEvaluation(GetNestedArrayStructFields(nullInput, fieldA, 0, 1, containsNull = true), null) + } + + test("GetNestedArrayStructFields - triple nesting") { + // Test array>>> + val innerStruct = new StructType().add("x", IntegerType, nullable = false) + val tripleArrayType = ArrayType( + ArrayType( + ArrayType(innerStruct, containsNull = false), + containsNull = false), + containsNull = false) + + // [[[{1}, {2}]], [[{3}], [{4}, {5}]]] + val testData = Seq( + Seq(Seq(create_row(1), create_row(2))), + Seq(Seq(create_row(3)), Seq(create_row(4), create_row(5))) + ) + val input = Literal.create(testData, tripleArrayType) + + val fieldX = innerStruct("x") + val extract = GetNestedArrayStructFields(input, fieldX, 0, 1, containsNull = false) + checkEvaluation(extract, Seq(Seq(Seq(1, 2)), Seq(Seq(3), Seq(4, 5)))) + } + + test("GetNestedArrayStructFields - dataType preserves nesting") { + val innerStruct = new StructType() + .add("a", IntegerType) + .add("b", StringType) + val nestedArrayType = ArrayType( + ArrayType(innerStruct, containsNull = true), + containsNull = false) + + val attr = AttributeReference("arr", nestedArrayType)() + val fieldA = innerStruct("a") + val extract = GetNestedArrayStructFields(attr, fieldA, 0, 2, containsNull = true) + + // Output type should be array> + val expectedType = ArrayType(ArrayType(IntegerType, containsNull = true), containsNull = false) + assert(extract.dataType === expectedType) + } + + test("GetNestedArrayStructFields - SelectedField integration") { + val innerStruct = new StructType() + .add("a", IntegerType) + .add("b", StringType) + val nestedArrayType = ArrayType( + ArrayType(innerStruct, containsNull = true), + containsNull = false) + + val attr = AttributeReference("arr", nestedArrayType)() + val fieldA = innerStruct("a") + val extract = GetNestedArrayStructFields(attr, fieldA, 0, 2, containsNull = true) + + // SelectedField should recognize this expression and return the appropriate schema + SelectedField.unapply(extract) match { + case Some(selectedField) => + assert(selectedField.name == "arr") + // The data type should preserve the nested array structure with pruned inner struct + selectedField.dataType match { + case ArrayType(ArrayType(StructType(fields), _), _) => + assert(fields.length == 1) + assert(fields.head.name == "a") + assert(fields.head.dataType == IntegerType) + case other => fail(s"Unexpected data type: $other") + } + case None => fail("SelectedField should match GetNestedArrayStructFields") + } + } + + test("NestedArraysZip - depth 1 (same as ArraysZip)") { + // At depth 1, NestedArraysZip should behave like ArraysZip + val arr1 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType)) + val arr2 = Literal.create(Seq("a", "b", "c"), ArrayType(StringType)) + val names = Seq(Literal("x"), Literal("y")) + + val zip = NestedArraysZip(Seq(arr1, arr2), names, 1) + val expected = Seq( + create_row(1, "a"), + create_row(2, "b"), + create_row(3, "c") + ) + checkEvaluation(zip, expected) + + // Test max-length semantics (shorter arrays padded with nulls) + val arr3 = Literal.create(Seq(1, 2), ArrayType(IntegerType)) + val arr4 = Literal.create(Seq("a", "b", "c", "d"), ArrayType(StringType)) + val zip2 = NestedArraysZip(Seq(arr3, arr4), names, 1) + val expected2 = Seq( + create_row(1, "a"), + create_row(2, "b"), + create_row(null, "c"), + create_row(null, "d") + ) + checkEvaluation(zip2, expected2) + } + + test("NestedArraysZip - depth 2") { + // array> zip array> => array>> + val arr1 = Literal.create( + Seq(Seq(1, 2), Seq(3)), + ArrayType(ArrayType(IntegerType)) + ) + val arr2 = Literal.create( + Seq(Seq("a", "b"), Seq("c")), + ArrayType(ArrayType(StringType)) + ) + val names = Seq(Literal("x"), Literal("y")) + + val zip = NestedArraysZip(Seq(arr1, arr2), names, 2) + val expected = Seq( + Seq(create_row(1, "a"), create_row(2, "b")), + Seq(create_row(3, "c")) + ) + checkEvaluation(zip, expected) + } + + test("NestedArraysZip - depth 2 with different inner lengths") { + // Inner arrays have different lengths - should use max-length semantics + val arr1 = Literal.create( + Seq(Seq(1, 2, 3), Seq(4)), + ArrayType(ArrayType(IntegerType)) + ) + val arr2 = Literal.create( + Seq(Seq("a"), Seq("b", "c")), + ArrayType(ArrayType(StringType)) + ) + val names = Seq(Literal("x"), Literal("y")) + + val zip = NestedArraysZip(Seq(arr1, arr2), names, 2) + val expected = Seq( + Seq(create_row(1, "a"), create_row(2, null), create_row(3, null)), + Seq(create_row(4, "b"), create_row(null, "c")) + ) + checkEvaluation(zip, expected) + } + + test("NestedArraysZip - null handling at outer level") { + val arr1 = Literal.create( + Seq(Seq(1, 2), Seq(3)), + ArrayType(ArrayType(IntegerType)) + ) + val nullArr = Literal.create( + null, + ArrayType(ArrayType(StringType)) + ) + val names = Seq(Literal("x"), Literal("y")) + + // If any top-level input is null, output is null + val zip = NestedArraysZip(Seq(arr1, nullArr), names, 2) + checkEvaluation(zip, null) + } + + test("NestedArraysZip - null handling at inner level") { + // If an inner array is null at some position, that output position is null + val arr1 = Literal.create( + Seq(Seq(1, 2), null, Seq(3)), + ArrayType(ArrayType(IntegerType), containsNull = true) + ) + val arr2 = Literal.create( + Seq(Seq("a", "b"), Seq("c"), null), + ArrayType(ArrayType(StringType), containsNull = true) + ) + val names = Seq(Literal("x"), Literal("y")) + + val zip = NestedArraysZip(Seq(arr1, arr2), names, 2) + val expected = Seq( + Seq(create_row(1, "a"), create_row(2, "b")), + null, // arr1[1] is null + null // arr2[2] is null + ) + checkEvaluation(zip, expected) + } + + test("NestedArraysZip - dataType") { + val arr1Type = ArrayType(ArrayType(IntegerType)) + val arr2Type = ArrayType(ArrayType(StringType)) + val arr1 = AttributeReference("a", arr1Type)() + val arr2 = AttributeReference("b", arr2Type)() + val names = Seq(Literal("x"), Literal("y")) + + val zip = NestedArraysZip(Seq(arr1, arr2), names, 2) + + // Output type should be array>> + val expectedElementStruct = StructType(Seq( + StructField("x", IntegerType, nullable = true), + StructField("y", StringType, nullable = true) + )) + val expectedType = ArrayType(ArrayType(expectedElementStruct, containsNull = false)) + assert(zip.dataType === expectedType) + } + + test("NestedArraysZip - depth 3 (triple nesting)") { + // array>> zip array>> + val arr1 = Literal.create( + Seq(Seq(Seq(1, 2), Seq(3)), Seq(Seq(4))), + ArrayType(ArrayType(ArrayType(IntegerType))) + ) + val arr2 = Literal.create( + Seq(Seq(Seq("a", "b"), Seq("c")), Seq(Seq("d"))), + ArrayType(ArrayType(ArrayType(StringType))) + ) + val names = Seq(Literal("x"), Literal("y")) + + val zip = NestedArraysZip(Seq(arr1, arr2), names, 3) + val expected = Seq( + Seq(Seq(create_row(1, "a"), create_row(2, "b")), Seq(create_row(3, "c"))), + Seq(Seq(create_row(4, "d"))) + ) + checkEvaluation(zip, expected) + } + + test("NestedArraysZip - empty arrays") { + // Empty outer array + val emptyOuter1 = Literal.create( + Seq.empty[Seq[Int]], + ArrayType(ArrayType(IntegerType)) + ) + val emptyOuter2 = Literal.create( + Seq.empty[Seq[String]], + ArrayType(ArrayType(StringType)) + ) + val names = Seq(Literal("x"), Literal("y")) + + val zip1 = NestedArraysZip(Seq(emptyOuter1, emptyOuter2), names, 2) + checkEvaluation(zip1, Seq.empty) + + // Empty inner arrays + val emptyInner1 = Literal.create( + Seq(Seq.empty[Int], Seq(1)), + ArrayType(ArrayType(IntegerType)) + ) + val emptyInner2 = Literal.create( + Seq(Seq.empty[String], Seq("a")), + ArrayType(ArrayType(StringType)) + ) + + val zip2 = NestedArraysZip(Seq(emptyInner1, emptyInner2), names, 2) + checkEvaluation(zip2, Seq(Seq.empty, Seq(create_row(1, "a")))) + } + + test("NestedArraysZip - single element arrays") { + val arr1 = Literal.create( + Seq(Seq(42)), + ArrayType(ArrayType(IntegerType)) + ) + val arr2 = Literal.create( + Seq(Seq("answer")), + ArrayType(ArrayType(StringType)) + ) + val names = Seq(Literal("x"), Literal("y")) + + val zip = NestedArraysZip(Seq(arr1, arr2), names, 2) + checkEvaluation(zip, Seq(Seq(create_row(42, "answer")))) + } + + test("NestedArraysZip - three input arrays") { + val arr1 = Literal.create( + Seq(Seq(1, 2), Seq(3)), + ArrayType(ArrayType(IntegerType)) + ) + val arr2 = Literal.create( + Seq(Seq("a", "b"), Seq("c")), + ArrayType(ArrayType(StringType)) + ) + val arr3 = Literal.create( + Seq(Seq(true, false), Seq(true)), + ArrayType(ArrayType(BooleanType)) + ) + val names = Seq(Literal("x"), Literal("y"), Literal("z")) + + val zip = NestedArraysZip(Seq(arr1, arr2, arr3), names, 2) + val expected = Seq( + Seq(create_row(1, "a", true), create_row(2, "b", false)), + Seq(create_row(3, "c", true)) + ) + checkEvaluation(zip, expected) + } + + test("NestedArraysZip - withDepth factory method") { + val arr1 = Literal.create(Seq(Seq(1)), ArrayType(ArrayType(IntegerType))) + val arr2 = Literal.create(Seq(Seq("a")), ArrayType(ArrayType(StringType))) + val names = Seq(Literal("x"), Literal("y")) + + val zip = NestedArraysZip.withDepth(Seq(arr1, arr2), names, 2) + checkEvaluation(zip, Seq(Seq(create_row(1, "a")))) + + // Verify depth is set correctly + assert(zip.depth === 2) + } + + test("NestedArraysZip - auto-detect depth factory") { + val arr1 = Literal.create(Seq(Seq(1)), ArrayType(ArrayType(IntegerType))) + val arr2 = Literal.create(Seq(Seq("a")), ArrayType(ArrayType(StringType))) + val names = Seq(Literal("x"), Literal("y")) + + val zip = NestedArraysZip(Seq(arr1, arr2), names) + // Should auto-detect depth as 2 + assert(zip.depth === 2) + checkEvaluation(zip, Seq(Seq(create_row(1, "a")))) + } + + test("NestedArraysZip - checkInputDataTypes rejects insufficient depth") { + // Child with depth 1 when depth 2 is required + val arr1 = AttributeReference("a", ArrayType(ArrayType(IntegerType)))() + val arr2 = AttributeReference("b", ArrayType(StringType))() // depth 1, not 2 + val names = Seq(Literal("x"), Literal("y")) + + val zip = NestedArraysZip(Seq(arr1, arr2), names, 2) + val result = zip.checkInputDataTypes() + assert(result.isFailure) + } + + test("NestedArraysZip - checkInputDataTypes accepts excess depth") { + // Children with depth 3 when depth 2 is specified - allowed because recursive + // buildPrunedArrayFromSchema may produce deeper arrays when inner struct fields + // are themselves arrays. NestedArraysZip zips at the specified depth, leaving + // deeper nesting intact. + val arr1 = AttributeReference("a", ArrayType(ArrayType(ArrayType(IntegerType))))() + val arr2 = AttributeReference("b", ArrayType(ArrayType(ArrayType(StringType))))() + val names = Seq(Literal("x"), Literal("y")) + + val zip = NestedArraysZip(Seq(arr1, arr2), names, 2) + val result = zip.checkInputDataTypes() + assert(result.isSuccess, + "Should accept children with depth > specified depth (zips at specified level)") + } + + test("NestedArraysZip - checkInputDataTypes rejects mixed depths") { + // Children with different depths should be rejected + val arr1 = AttributeReference("a", ArrayType(ArrayType(IntegerType)))() // depth 2 + val arr2 = AttributeReference("b", ArrayType(ArrayType(ArrayType(StringType))))() // depth 3 + val names = Seq(Literal("x"), Literal("y")) + + val zip = NestedArraysZip(Seq(arr1, arr2), names, 2) + val result = zip.checkInputDataTypes() + assert(result.isFailure, + "Should reject children with different array depths") + } + + test("NestedArraysZip - checkInputDataTypes accepts exact depth match") { + val arr1 = AttributeReference("a", ArrayType(ArrayType(IntegerType)))() + val arr2 = AttributeReference("b", ArrayType(ArrayType(StringType)))() + val names = Seq(Literal("x"), Literal("y")) + + val zip = NestedArraysZip(Seq(arr1, arr2), names, 2) + val result = zip.checkInputDataTypes() + assert(result.isSuccess) + } + + test("NestedArraysZip - withNewChildren preserves properties") { + val arr1 = Literal.create(Seq(Seq(1)), ArrayType(ArrayType(IntegerType))) + val arr2 = Literal.create(Seq(Seq("a")), ArrayType(ArrayType(StringType))) + val names = Seq(Literal("x"), Literal("y")) + + val zip = NestedArraysZip(Seq(arr1, arr2), names, 2) + + // Create new children + val newArr1 = Literal.create(Seq(Seq(99)), ArrayType(ArrayType(IntegerType))) + val newArr2 = Literal.create(Seq(Seq("z")), ArrayType(ArrayType(StringType))) + + // Use withNewChildren (public method from Expression trait) + val newZip = zip.withNewChildren(Seq(newArr1, newArr2)).asInstanceOf[NestedArraysZip] + checkEvaluation(newZip, Seq(Seq(create_row(99, "z")))) + + // Depth should be preserved + assert(newZip.depth == 2) + // Names should be preserved + assert(newZip.names == names) + } + + test("NestedArraysZip - complex nested types") { + // Zipping arrays of arrays of structs + val structType = StructType(Seq( + StructField("id", IntegerType), + StructField("name", StringType) + )) + val arr1 = Literal.create( + Seq(Seq(create_row(1, "a"))), + ArrayType(ArrayType(structType)) + ) + val arr2 = Literal.create( + Seq(Seq(100L)), + ArrayType(ArrayType(LongType)) + ) + val names = Seq(Literal("struct_field"), Literal("long_field")) + + val zip = NestedArraysZip(Seq(arr1, arr2), names, 2) + val expected = Seq(Seq(create_row(create_row(1, "a"), 100L))) + checkEvaluation(zip, expected) + } + + test("NestedArraysZip - prettyName") { + val arr1 = AttributeReference("a", ArrayType(ArrayType(IntegerType)))() + val arr2 = AttributeReference("b", ArrayType(ArrayType(StringType)))() + val names = Seq(Literal("x"), Literal("y")) + + val zip = NestedArraysZip(Seq(arr1, arr2), names, 2) + assert(zip.prettyName == "nested_arrays_zip") + } + + test("NestedArraysZip - mismatched outer lengths at depth 2") { + // Outer arrays have different lengths + val arr1 = Literal.create( + Seq(Seq(1, 2), Seq(3), Seq(4, 5, 6)), + ArrayType(ArrayType(IntegerType)) + ) + val arr2 = Literal.create( + Seq(Seq("a", "b")), // Only 1 outer element + ArrayType(ArrayType(StringType)) + ) + val names = Seq(Literal("x"), Literal("y")) + + val zip = NestedArraysZip(Seq(arr1, arr2), names, 2) + // Should extend with nulls at outer level + val expected = Seq( + Seq(create_row(1, "a"), create_row(2, "b")), + null, // arr2[1] doesn't exist + null // arr2[2] doesn't exist + ) + checkEvaluation(zip, expected) + } + test("CreateArray") { val intSeq = Seq(5, 10, 15, 20, 25) val longSeq = intSeq.map(_.toLong) @@ -359,36 +827,40 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { ) } - // map key can't be variant - val map6 = CreateMap(Seq( - Literal.create(new VariantVal(Array[Byte](), Array[Byte]())), - Literal.create(1) - )) - map6.checkInputDataTypes() match { - case TypeCheckResult.TypeCheckSuccess => fail("should not allow variant as a part of map key") - case TypeCheckResult.DataTypeMismatch(errorSubClass, messageParameters) => - assert(errorSubClass == "INVALID_MAP_KEY_TYPE") - assert(messageParameters === Map("keyType" -> "\"VARIANT\"")) - } - - // map key can't contain variant - val map7 = CreateMap( - Seq( - CreateStruct( - Seq(Literal.create(1), Literal.create(new VariantVal(Array[Byte](), Array[Byte]()))) - ), + test("CreateMap: variant key validation") { + // map key can't be variant + val map6 = CreateMap(Seq( + Literal.create(new VariantVal(Array[Byte](), Array[Byte]())), Literal.create(1) + )) + map6.checkInputDataTypes() match { + case TypeCheckResult.TypeCheckSuccess => + fail("should not allow variant as a part of map key") + case TypeCheckResult.DataTypeMismatch(errorSubClass, messageParameters) => + assert(errorSubClass == "INVALID_MAP_KEY_TYPE") + assert(messageParameters === Map("keyType" -> "\"VARIANT\"")) + } + + // map key can't contain variant + val map7 = CreateMap( + Seq( + CreateStruct( + Seq(Literal.create(1), Literal.create(new VariantVal(Array[Byte](), Array[Byte]()))) + ), + Literal.create(1) + ) ) - ) - map7.checkInputDataTypes() match { - case TypeCheckResult.TypeCheckSuccess => fail("should not allow variant as a part of map key") - case TypeCheckResult.DataTypeMismatch(errorSubClass, messageParameters) => - assert(errorSubClass == "INVALID_MAP_KEY_TYPE") - assert( - messageParameters === Map( - "keyType" -> "\"STRUCT\"" + map7.checkInputDataTypes() match { + case TypeCheckResult.TypeCheckSuccess => + fail("should not allow variant as a part of map key") + case TypeCheckResult.DataTypeMismatch(errorSubClass, messageParameters) => + assert(errorSubClass == "INVALID_MAP_KEY_TYPE") + assert( + messageParameters === Map( + "keyType" -> "\"STRUCT\"" + ) ) - ) + } } test("MapFromArrays") { @@ -417,8 +889,6 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { MapFromArrays(strArray, intWithNullArray), create_map(strSeq, intWithNullSeq)) checkEvaluation( MapFromArrays(strArray, longWithNullArray), create_map(strSeq, longWithNullSeq)) - checkEvaluation( - MapFromArrays(strArray, longWithNullArray), create_map(strSeq, longWithNullSeq)) checkEvaluation(MapFromArrays(nullArray, nullArray), null) // Map key can't be null diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneNestedFieldsThroughGenerateForScanSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneNestedFieldsThroughGenerateForScanSuite.scala new file mode 100644 index 0000000000000..fd97dcee11b3d --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneNestedFieldsThroughGenerateForScanSuite.scala @@ -0,0 +1,436 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.SchemaPruningTest +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ + +class PruneNestedFieldsThroughGenerateForScanSuite extends SchemaPruningTest { + + private val itemStruct = StructType(Seq( + StructField("a", IntegerType), + StructField("b", StringType), + StructField("c", DoubleType))) + + private val rel = LocalRelation( + $"id".int, + $"items".array(itemStruct)) + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Column Pruning", FixedPoint(10), + ColumnPruning, + CollapseProject, + RemoveNoopOperators) :: + Batch("PruneNestedFieldsThroughGenerate", FixedPoint(1), + PruneNestedFieldsThroughGenerateForScan) :: Nil + } + + private def explodeItems(outer: Boolean = false): Generate = { + val explode = Explode($"items") + Generate( + explode, + unrequiredChildIndex = Nil, + outer = outer, + qualifier = None, + generatorOutput = Seq(AttributeReference("item", itemStruct)()), + child = rel) + } + + private def posexplodeItems(outer: Boolean = false): Generate = { + val posexplode = PosExplode($"items") + Generate( + posexplode, + unrequiredChildIndex = Nil, + outer = outer, + qualifier = None, + generatorOutput = Seq( + AttributeReference("pos", IntegerType)(), + AttributeReference("item", itemStruct)()), + child = rel) + } + + test("multi-field: prunes to required fields only") { + withSQLConf(SQLConf.NESTED_SCHEMA_PRUNING_ENABLED.key -> "true") { + val gen = explodeItems() + val item = gen.generatorOutput.head + val first = GetStructField(item, 0, Some("a")) + val second = GetStructField(item, 1, Some("b")) + val query = gen.select(first, second).analyze + + val optimized = Optimize.execute(query) + + // The optimized plan should have a Generate with a pruned array type + val generates = optimized.collect { case g: Generate => g } + assert(generates.nonEmpty, "Expected a Generate node in the optimized plan") + + val newGen = generates.head + val newChildType = newGen.generator.children.head.dataType + newChildType match { + case ArrayType(st: StructType, _) => + assert(st.fieldNames.toSet === Set("a", "b"), + s"Expected pruned struct with fields {a, b} but got ${st.fieldNames.mkString(", ")}") + case other => + fail(s"Expected ArrayType(StructType) but got $other") + } + } + } + + test("multi-field: ordinals are correct after pruning non-contiguous fields") { + withSQLConf(SQLConf.NESTED_SCHEMA_PRUNING_ENABLED.key -> "true") { + val gen = explodeItems() + val item = gen.generatorOutput.head + // Select field 'a' (ordinal 0) and 'c' (ordinal 2), skipping 'b' (ordinal 1) + val first = GetStructField(item, 0, Some("a")) + val third = GetStructField(item, 2, Some("c")) + val query = gen.select(first, third).analyze + + val optimized = Optimize.execute(query) + + // Check that the element struct has only a and c + val generates = optimized.collect { case g: Generate => g } + assert(generates.nonEmpty) + val newGen = generates.head + newGen.generator.children.head.dataType match { + case ArrayType(st: StructType, _) => + assert(st.fieldNames.toSeq === Seq("a", "c"), + "Fields should be in original schema order") + assert(st.fields(0).dataType === IntegerType) + assert(st.fields(1).dataType === DoubleType) + case other => + fail(s"Expected ArrayType(StructType) but got $other") + } + + // Check that GetStructField ordinals in the project are correct + val projects = optimized.collect { case p: Project => p.projectList } + assert(projects.nonEmpty) + val topProject = projects.head + val structFields = topProject.flatMap(_.collect { + case gsf: GetStructField => gsf + }) + val ordinals = structFields.map(_.ordinal) + assert(ordinals === Seq(0, 1), + s"Expected ordinals [0, 1] for pruned struct {a, c} but got $ordinals") + } + } + + test("no pruning when whole struct is referenced") { + withSQLConf(SQLConf.NESTED_SCHEMA_PRUNING_ENABLED.key -> "true") { + val gen = explodeItems() + val item = gen.generatorOutput.head + // Reference the whole struct directly + val query = gen.select(item).analyze + + val optimized = Optimize.execute(query) + + // The generator child should still be the original items column + val generates = optimized.collect { case g: Generate => g } + assert(generates.nonEmpty) + generates.head.generator.children.head.dataType match { + case ArrayType(st: StructType, _) => + assert(st.fields.length === 3, + "No pruning should occur when whole struct is referenced") + case _ => + } + } + } + + test("no pruning when all fields are selected") { + withSQLConf(SQLConf.NESTED_SCHEMA_PRUNING_ENABLED.key -> "true") { + val gen = explodeItems() + val item = gen.generatorOutput.head + val fieldA = GetStructField(item, 0, Some("a")) + val fieldB = GetStructField(item, 1, Some("b")) + val fieldC = GetStructField(item, 2, Some("c")) + val query = gen.select(fieldA, fieldB, fieldC).analyze + + val optimized = Optimize.execute(query) + + val generates = optimized.collect { case g: Generate => g } + assert(generates.nonEmpty) + generates.head.generator.children.head.dataType match { + case ArrayType(st: StructType, _) => + assert(st.fields.length === 3, + "No pruning should occur when all fields are selected") + case _ => + } + } + } + + test("disabled when nestedSchemaPruningEnabled is false") { + withSQLConf(SQLConf.NESTED_SCHEMA_PRUNING_ENABLED.key -> "false") { + val gen = explodeItems() + val item = gen.generatorOutput.head + val first = GetStructField(item, 0, Some("a")) + val second = GetStructField(item, 1, Some("b")) + val query = gen.select(first, second).analyze + + val optimized = Optimize.execute(query) + + val generates = optimized.collect { case g: Generate => g } + assert(generates.nonEmpty) + generates.head.generator.children.head.dataType match { + case ArrayType(st: StructType, _) => + assert(st.fields.length === 3, + "No pruning when nestedSchemaPruningEnabled is false") + case _ => + } + } + } + + test("posexplode: multi-field prune on element fields") { + withSQLConf(SQLConf.NESTED_SCHEMA_PRUNING_ENABLED.key -> "true") { + val gen = posexplodeItems() + val pos = gen.generatorOutput(0) + val item = gen.generatorOutput(1) + val fieldA = GetStructField(item, 0, Some("a")) + val fieldB = GetStructField(item, 1, Some("b")) + val query = gen.select(pos, fieldA, fieldB).analyze + + val optimized = Optimize.execute(query) + + val generates = optimized.collect { case g: Generate => g } + assert(generates.nonEmpty) + val newGen = generates.head + newGen.generator.children.head.dataType match { + case ArrayType(st: StructType, _) => + assert(st.fieldNames.toSet === Set("a", "b"), + s"Expected pruned struct {a, b} but got ${st.fieldNames.mkString(", ")}") + case other => + fail(s"Expected ArrayType(StructType) but got $other") + } + } + } + + test("posexplode: pos-only selects minimal-weight field") { + withSQLConf(SQLConf.NESTED_SCHEMA_PRUNING_ENABLED.key -> "true") { + val gen = posexplodeItems() + val pos = gen.generatorOutput(0) + // Only reference pos, not any element fields + val query = gen.select(pos).analyze + + val optimized = Optimize.execute(query) + + val generates = optimized.collect { case g: Generate => g } + assert(generates.nonEmpty) + val newGen = generates.head + newGen.generator.children.head.dataType match { + case ArrayType(st: StructType, _) => + // Should pick the minimal-weight field: 'a' (IntegerType, defaultSize=4) + // over 'b' (StringType, defaultSize=20) and 'c' (DoubleType, defaultSize=8) + assert(st.fields.length === 1, + s"Expected 1 field but got ${st.fields.length}") + assert(st.fieldNames.toSet === Set("a"), + s"Expected minimal-weight field 'a' but got '${st.fieldNames.head}'") + case other => + fail(s"Expected ArrayType(StructType) but got $other") + } + } + } + + // Test case for bug: when posexplode element fields are selected WITHOUT the position column, + // pruning should still work correctly. This test verifies the fix for the "consecutive + // OUTER POSEXPLODE without pos columns" issue. + test("posexplode: element fields only (no pos column) still prunes correctly") { + withSQLConf(SQLConf.NESTED_SCHEMA_PRUNING_ENABLED.key -> "true") { + val gen = posexplodeItems() + val item = gen.generatorOutput(1) + val fieldA = GetStructField(item, 0, Some("a")) + val fieldB = GetStructField(item, 1, Some("b")) + // Only element fields, NOT the position column + val query = gen.select(fieldA, fieldB).analyze + + val optimized = Optimize.execute(query) + + val generates = optimized.collect { case g: Generate => g } + assert(generates.nonEmpty) + val newGen = generates.head + newGen.generator.children.head.dataType match { + case ArrayType(st: StructType, _) => + // Should prune to just a, b - c should be removed + assert(st.fieldNames.toSet === Set("a", "b"), + s"Expected pruned struct {a, b} but got ${st.fieldNames.mkString(", ")}") + case other => + fail(s"Expected ArrayType(StructType) but got $other") + } + } + } + + // Nested data for consecutive posexplode tests + // Structure: outer_array -> array> + // b_array -> array> + private val innerStruct = StructType(Seq( + StructField("c", StringType), + StructField("c_int", LongType), + StructField("c_2", StringType))) + + private val outerStruct = StructType(Seq( + StructField("b", StringType), + StructField("b_int", LongType), + StructField("b_string", StringType), + StructField("b_array", ArrayType(innerStruct)))) + + private val nestedRel = LocalRelation( + $"a".string, + $"a_int".long, + $"a_array".array(outerStruct)) + + // Test consecutive posexplodes with position columns selected (should work) + test("consecutive posexplode: with pos columns selected") { + withSQLConf(SQLConf.NESTED_SCHEMA_PRUNING_ENABLED.key -> "true") { + // Create outer posexplode + val outerPosExplode = PosExplode($"a_array") + val outerGen = Generate( + outerPosExplode, + unrequiredChildIndex = Nil, + outer = true, + qualifier = None, + generatorOutput = Seq( + AttributeReference("a_idx", IntegerType)(), + AttributeReference("a_array_item", outerStruct)()), + child = nestedRel) + + val outerPos = outerGen.generatorOutput(0) + val outerItem = outerGen.generatorOutput(1) + val outerFieldB = GetStructField(outerItem, 0, Some("b")) + val outerFieldBArray = GetStructField(outerItem, 3, Some("b_array")) + + // Create inner posexplode on a_array_item.b_array + val innerPosExplode = PosExplode(outerFieldBArray) + val innerGen = Generate( + innerPosExplode, + unrequiredChildIndex = Nil, + outer = true, + qualifier = None, + generatorOutput = Seq( + AttributeReference("b_idx", IntegerType)(), + AttributeReference("b_array_item", innerStruct)()), + child = outerGen) + + val innerPos = innerGen.generatorOutput(0) + val innerItem = innerGen.generatorOutput(1) + val innerFieldC = GetStructField(innerItem, 0, Some("c")) + + // Select WITH position columns: a_idx, a_array_item.b, b_idx, b_array_item.c + val query = innerGen.select(outerPos, outerFieldB, innerPos, innerFieldC).analyze + + val optimized = Optimize.execute(query) + + // Check outer struct is pruned correctly (only b and b_array) + val generates = optimized.collect { case g: Generate => g } + // Find the outer generate (the one with a_array as source) + val outerGenerates = generates.filter { g => + g.generator.children.head.dataType match { + case ArrayType(ArrayType(_, _), _) => false // This would be inner with nested array + case ArrayType(st: StructType, _) => + // Check if this is outer struct (has b_array field) + st.fieldNames.exists(_.contains("b")) + case _ => false + } + } + + if (outerGenerates.nonEmpty) { + val outerGenResult = outerGenerates.head + outerGenResult.generator.children.head.dataType match { + case ArrayType(st: StructType, _) => + // Outer struct should have b and b_array (b_int and b_string pruned) + assert(st.fieldNames.toSet === Set("b", "b_array"), + s"Expected pruned outer struct {b, b_array} but got " + + s"${st.fieldNames.mkString(", ")}") + case other => + fail(s"Expected ArrayType(StructType) for outer but got $other") + } + } + } + } + + // Test consecutive posexplodes WITHOUT position columns (this is the bug case) + test("consecutive posexplode: without pos columns should still prune outer fields") { + withSQLConf(SQLConf.NESTED_SCHEMA_PRUNING_ENABLED.key -> "true") { + // Create outer posexplode + val outerPosExplode = PosExplode($"a_array") + val outerGen = Generate( + outerPosExplode, + unrequiredChildIndex = Nil, + outer = true, + qualifier = None, + generatorOutput = Seq( + AttributeReference("a_idx", IntegerType)(), + AttributeReference("a_array_item", outerStruct)()), + child = nestedRel) + + val outerItem = outerGen.generatorOutput(1) + val outerFieldB = GetStructField(outerItem, 0, Some("b")) + val outerFieldBArray = GetStructField(outerItem, 3, Some("b_array")) + + // Create inner posexplode on a_array_item.b_array + val innerPosExplode = PosExplode(outerFieldBArray) + val innerGen = Generate( + innerPosExplode, + unrequiredChildIndex = Nil, + outer = true, + qualifier = None, + generatorOutput = Seq( + AttributeReference("b_idx", IntegerType)(), + AttributeReference("b_array_item", innerStruct)()), + child = outerGen) + + val innerItem = innerGen.generatorOutput(1) + val innerFieldC = GetStructField(innerItem, 0, Some("c")) + + // Select WITHOUT position columns: a_array_item.b, b_array_item.c + // This is the bug case - outer field "b" may be missing from scan + val query = innerGen.select(outerFieldB, innerFieldC).analyze + + val optimized = Optimize.execute(query) + + // Check outer struct is pruned correctly (should have b and b_array) + val generates = optimized.collect { case g: Generate => g } + val outerGenerates = generates.filter { g => + g.generator.children.head.dataType match { + case ArrayType(ArrayType(_, _), _) => false + case ArrayType(st: StructType, _) => + st.fieldNames.exists(_.contains("b")) + case _ => false + } + } + + if (outerGenerates.nonEmpty) { + val outerGenResult = outerGenerates.head + outerGenResult.generator.children.head.dataType match { + case ArrayType(st: StructType, _) => + // BUG: Without pos columns, outer struct may only have b_array, missing "b" + // EXPECTED: {b, b_array} + assert(st.fieldNames.contains("b"), + s"BUG: Outer struct missing 'b' field! Got: ${st.fieldNames.mkString(", ")}. " + + "This happens when consecutive posexplode doesn't select position columns.") + assert(st.fieldNames.contains("b_array"), + s"Expected b_array in outer struct, got: ${st.fieldNames.mkString(", ")}") + case other => + fail(s"Expected ArrayType(StructType) for outer but got $other") + } + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala index 7f3b8383f0f8f..ea87d49839e3d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala @@ -37,6 +37,7 @@ class SparkOptimizer( override def earlyScanPushDownRules: Seq[Rule[LogicalPlan]] = // TODO: move SchemaPruning into catalyst Seq( + PruneNestedFieldsThroughGenerateForScan, SchemaPruning, GroupBasedRowLevelOperationScanPlanning, V1Writes, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SchemaPruning.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SchemaPruning.scala index 1b23fd1a5e829..f60d84cc8863c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SchemaPruning.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SchemaPruning.scala @@ -38,7 +38,7 @@ import org.apache.spark.sql.util.SchemaUtils._ object SchemaPruning extends Rule[LogicalPlan] { import org.apache.spark.sql.catalyst.expressions.SchemaPruning._ - override def apply(plan: LogicalPlan): LogicalPlan = + override def apply(plan: LogicalPlan): LogicalPlan = { plan transformDown { case op @ ScanOperation(projects, filtersStayUp, filtersPushDown, l @ LogicalRelationWithTable(hadoopFsRelation: HadoopFsRelation, _)) => @@ -50,6 +50,7 @@ object SchemaPruning extends Rule[LogicalPlan] { buildPrunedRelation(l, prunedHadoopRelation, prunedMetadataSchema) }).getOrElse(op) } + } /** * This method returns optional logical plan. `None` is returned if no nested field is required or diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/ExplodeNestedSchemaPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/ExplodeNestedSchemaPruningSuite.scala new file mode 100644 index 0000000000000..156064ed206b8 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/ExplodeNestedSchemaPruningSuite.scala @@ -0,0 +1,3096 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import java.io.File + +import org.scalactic.Equality + +import org.apache.spark.sql.{DataFrame, QueryTest, Row} +import org.apache.spark.sql.catalyst.SchemaPruningTest +import org.apache.spark.sql.catalyst.parser.CatalystSqlParser +import org.apache.spark.sql.catalyst.types.DataTypeUtils +import org.apache.spark.sql.execution.FileSourceScanExec +import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types._ + +/** + * Tests for nested schema pruning through explode/posexplode generators. + * + * Converted from Taboola's SchemaOnReadGeneratorTest to validate that + * Spark's native [[PruneNestedFieldsThroughGenerateForScan]] rule + * correctly prunes nested struct fields within exploded arrays. + * + * The test data provides rich array-of-struct schemas: + * - `sample`: flat arrays, complex arrays (struct with sub-arrays) + * - `double_nested`: two-level nested arrays (a_array -> b_array) + */ +abstract class ExplodeNestedSchemaPruningSuite + extends QueryTest + with FileBasedDataSourceTest + with SchemaPruningTest + with SharedSparkSession + with AdaptiveSparkPlanHelper { + + // ---- Case classes for test data ---- + + case class ComplexElement(col1: Long, col2: Long) + case class ArrayOfComplexElement(col1: Long, col2: Array[Long], col3: Long) + case class InnerStruct(aSubArray: Array[Long], col1: Long, col2: Long, col3: Long) + + case class SampleRecord( + someStr: String, + someLong: Long, + someStrArray: Array[String], + someComplexArray: Array[ComplexElement], + struct: InnerStruct, + someArrayOfComplexArrays: Array[ArrayOfComplexElement]) + + case class BArrayElement(c: String, c_int: Long, c_2: String) + case class AArrayElement(b: String, b_int: Long, b_string: String, + b_array: Array[BArrayElement]) + + case class DoubleNestedRecord( + a: String, + a_bool: Boolean, + a_int: Long, + a_string: String, + a_array: Array[AArrayElement]) + + case class ItemDetail(color: String, size: Int) + case class MixedItem(name: String, detail: ItemDetail, qty: Int) + case class MixedArrayRecord(id: Int, items: Array[MixedItem], tags: Array[String]) + + // ---- Test data ---- + + private val sampleData = Seq( + SampleRecord( + someStr = "bla", + someLong = 12345678987654321L, + someStrArray = Array("a", "b", "c"), + someComplexArray = Array(ComplexElement(1, 2)), + struct = InnerStruct(Array(1, 2, 3), 1, 2, 3), + someArrayOfComplexArrays = Array(ArrayOfComplexElement(1, Array(1, 2, 3), 4)))) + + private val doubleNestedData = Seq( + DoubleNestedRecord("a1", a_bool = true, 1, "dummy", + Array( + AArrayElement("a1_b1", 1, "da", + Array(BArrayElement("a1_b1_c1", 1, "a1_b1_c1_d"), + BArrayElement("a1_b1_c2", 1, "a1_b1_c2_d"))), + AArrayElement("a1_b2", 2, "da", + Array(BArrayElement("a1_b2_c1", 1, "a1_b2_c2_d"), + BArrayElement("a1_b2_c2", 1, "a1_b2_c2_d"))))), + DoubleNestedRecord("a2", a_bool = false, 2, "dummy", + Array( + AArrayElement("a2_b1", 3, "da", + Array(BArrayElement("a2_b1_c1", 1, "a2_b1_c1_d"), + BArrayElement("a2_b1_c2", 1, "a2_b1_c2_d"))), + AArrayElement("a2_b2", 4, "da", + Array(BArrayElement("a2_b2_c1", 1, "da"), + BArrayElement("a2_b2_c2", 1, "da")))))) + + // Mixed array data: struct array with nested struct + scalar array. + // Used to test Pattern 3 (non-struct Generate above struct chain). + private val mixedArrayData = Seq( + MixedArrayRecord(1, + Array( + MixedItem("apple", ItemDetail("red", 3), 5), + MixedItem("banana", ItemDetail("yellow", 2), 3)), + Array("fruit", "food")), + MixedArrayRecord(2, + Array(MixedItem("cherry", ItemDetail("dark red", 1), 1)), + Array("berry"))) + + // ---- Infrastructure ---- + + protected val vectorizedReaderEnabledKey: String + protected val vectorizedReaderNestedEnabledKey: String + + protected val schemaEquality: Equality[StructType] = new Equality[StructType] { + override def areEqual(a: StructType, b: Any): Boolean = + b match { + case otherType: StructType => DataTypeUtils.sameType(a, otherType) + case _ => false + } + } + + protected def checkScan(df: DataFrame, expectedSchemaCatalogStrings: String*): Unit = { + checkScanSchemata(df, expectedSchemaCatalogStrings: _*) + df.collect() + } + + protected def checkScanSchemata( + df: DataFrame, expectedSchemaCatalogStrings: String*): Unit = { + val fileSourceScanSchemata = + collect(df.queryExecution.executedPlan) { + case scan: FileSourceScanExec => scan.requiredSchema + } + assert(fileSourceScanSchemata.size === expectedSchemaCatalogStrings.size, + s"Found ${fileSourceScanSchemata.size} file sources in dataframe, " + + s"but expected $expectedSchemaCatalogStrings") + fileSourceScanSchemata.zip(expectedSchemaCatalogStrings).foreach { + case (scanSchema, expectedScanSchemaCatalogString) => + val expectedScanSchema = + CatalystSqlParser.parseDataType(expectedScanSchemaCatalogString) + implicit val equality = schemaEquality + assert(scanSchema === expectedScanSchema) + } + } + + private def withSampleData(testThunk: => Unit): Unit = { + withTempPath { dir => + val path = dir.getCanonicalPath + makeDataSourceFile(sampleData, new File(path + "/sample")) + + val schema = + """`someStr` STRING, `someLong` BIGINT, + |`someStrArray` ARRAY, + |`someComplexArray` ARRAY>, + |`struct` STRUCT<`aSubArray`: ARRAY, + | `col1`: BIGINT, `col2`: BIGINT, `col3`: BIGINT>, + |`someArrayOfComplexArrays` ARRAY, `col3`: BIGINT>>""".stripMargin + spark.read.format(dataSourceName).schema(schema).load(path + "/sample") + .createOrReplaceTempView("sample") + + testThunk + } + } + + private def withDoubleNestedData(testThunk: => Unit): Unit = { + withTempPath { dir => + val path = dir.getCanonicalPath + makeDataSourceFile(doubleNestedData, new File(path + "/double_nested")) + + val schema = + """`a` STRING, `a_bool` BOOLEAN, `a_int` BIGINT, `a_string` STRING, + |`a_array` ARRAY>>>""".stripMargin + spark.read.format(dataSourceName).schema(schema).load(path + "/double_nested") + .createOrReplaceTempView("double_nested") + + testThunk + } + } + + private def withMixedArrayData(testThunk: => Unit): Unit = { + withTempPath { dir => + val path = dir.getCanonicalPath + makeDataSourceFile(mixedArrayData, new File(path + "/mixed_array")) + + val schema = + """`id` INT, + |`items` ARRAY, `qty`: INT>>, + |`tags` ARRAY""".stripMargin + spark.read.format(dataSourceName).schema(schema).load(path + "/mixed_array") + .createOrReplaceTempView("mixed_array") + + testThunk + } + } + + protected def testExplodePruning(testName: String)(testThunk: => Unit): Unit = { + test(s"Vectorized - $testName") { + withSQLConf(vectorizedReaderEnabledKey -> "true") { + testThunk + } + } + test(s"Non-vectorized - $testName") { + withSQLConf(vectorizedReaderEnabledKey -> "false") { + testThunk + } + } + } + + // ========================================================================= + // Tests on sample data + // ========================================================================= + + testExplodePruning("explode complex array - select single field") { + withSampleData { + val query = sql( + "SELECT someStr, arrayVal.col1 " + + "FROM sample LATERAL VIEW EXPLODE(someComplexArray) as arrayVal") + checkScan(query, + "struct>>") + checkAnswer(query, Row("bla", 1) :: Nil) + } + } + + testExplodePruning("posexplode complex array - select single field") { + withSampleData { + val query = sql( + "SELECT someStr, arrayVal.col1 " + + "FROM sample LATERAL VIEW POSEXPLODE(someComplexArray) " + + "as arrayIdx, arrayVal") + checkScan(query, + "struct>>") + checkAnswer(query, Row("bla", 1) :: Nil) + } + } + + // ========================================================================= + // Multi-field selection tests (SPARK-34956) + // ========================================================================= + + // This is the key test for PruneNestedFieldsThroughGenerateForScan - + // selecting multiple fields from an exploded struct element. + testExplodePruning("explode multi-field selection from struct") { + withSampleData { + val query = sql( + "SELECT arrayVal.col1, arrayVal.col2 " + + "FROM sample LATERAL VIEW EXPLODE(someArrayOfComplexArrays) as arrayVal") + // Both col1 and col2 selected, col3 pruned + checkScan(query, + "struct>>>") + checkAnswer(query, Row(1, Array(1, 2, 3)) :: Nil) + } + } + + testExplodePruning("posexplode multi-field selection from struct") { + withSampleData { + val query = sql( + "SELECT arrayIdx, arrayVal.col1, arrayVal.col3 " + + "FROM sample LATERAL VIEW POSEXPLODE(someArrayOfComplexArrays) " + + "as arrayIdx, arrayVal") + // col1 and col3 selected, col2 pruned (non-contiguous fields) + checkScan(query, + "struct>>") + checkAnswer(query, Row(0, 1, 4) :: Nil) + } + } + + // Multi-level selection: fields from both outer and inner exploded elements. + // The outer explode needs col1 (for project) and col2 (for inner explode), + // so col3 can be pruned even in chained generates. + testExplodePruning("consecutive explode - multi-level field selection") { + withSampleData { + val query = sql( + "SELECT complex.col1, val " + + "FROM sample " + + "LATERAL VIEW EXPLODE(someArrayOfComplexArrays) as complex " + + "LATERAL VIEW EXPLODE(complex.col2) as val") + // col1 + col2 needed, col3 pruned + checkScan(query, + "struct>>>") + checkAnswer(query, + Row(1, 1) :: Row(1, 2) :: Row(1, 3) :: Nil) + } + } + + testExplodePruning("consecutive explode - prune outer struct") { + withSampleData { + val query = sql( + "SELECT someStr, val " + + "FROM sample " + + "LATERAL VIEW EXPLODE(someArrayOfComplexArrays) as complex " + + "LATERAL VIEW EXPLODE(complex.col2) as val") + checkScan(query, + "struct>>>") + checkAnswer(query, + Row("bla", 1) :: Row("bla", 2) :: Row("bla", 3) :: Nil) + } + } + + // Consecutive posexplode now supports pruning through chained generators. + // Only col2 is needed from the outer struct. + testExplodePruning("consecutive posexplode - prune outer struct") { + withSampleData { + val query = sql( + "SELECT someStr, val " + + "FROM sample " + + "LATERAL VIEW POSEXPLODE(someArrayOfComplexArrays) " + + "as complex_idx, complex " + + "LATERAL VIEW POSEXPLODE(complex.col2) as val_idx, val") + // Only col2 needed from outer struct, col1 and col3 pruned + checkScan(query, + "struct>>>") + checkAnswer(query, + Row("bla", 1) :: Row("bla", 2) :: Row("bla", 3) :: Nil) + } + } + + testExplodePruning("explode with filter on nested array field in subquery") { + withSampleData { + val query = sql( + "WITH base AS (SELECT someArrayOfComplexArrays FROM sample " + + "WHERE someArrayOfComplexArrays.col2 IS NOT NULL) " + + "SELECT item.col1 AS str " + + "FROM base LATERAL VIEW EXPLODE(someArrayOfComplexArrays) as item") + checkScan(query, + "struct>>>") + checkAnswer(query, Row(1) :: Nil) + } + } + + // Posexplode with filter: the projection needs col1, and the filter needs col2. + // Both fields are included in the pruned scan. + testExplodePruning("posexplode with filter on nested array field in subquery") { + withSampleData { + val query = sql( + "WITH base AS (SELECT someArrayOfComplexArrays FROM sample " + + "WHERE someArrayOfComplexArrays.col2 IS NOT NULL) " + + "SELECT item.col1 AS str " + + "FROM base " + + "LATERAL VIEW POSEXPLODE(someArrayOfComplexArrays) " + + "as item_idx, item") + // Both col1 (from projection) and col2 (from filter) are needed + checkScan(query, + "struct>>>") + checkAnswer(query, Row(1) :: Nil) + } + } + + testExplodePruning("explode with filter over nested array - direct WHERE") { + withSampleData { + val query = sql( + "SELECT item.col1 AS rst " + + "FROM sample " + + "LATERAL VIEW EXPLODE(someArrayOfComplexArrays) as item " + + "WHERE someArrayOfComplexArrays.col2 IS NOT NULL") + checkScan(query, + "struct>>>") + checkAnswer(query, Row(1) :: Nil) + } + } + + // Same as subquery variant - both col1 (projection) and col2 (filter) needed + testExplodePruning("posexplode with filter over nested array - direct WHERE") { + withSampleData { + val query = sql( + "SELECT item.col1 AS rst " + + "FROM sample " + + "LATERAL VIEW POSEXPLODE(someArrayOfComplexArrays) " + + "as item_idx, item " + + "WHERE someArrayOfComplexArrays.col2 IS NOT NULL") + checkScan(query, + "struct>>>") + checkAnswer(query, Row(1) :: Nil) + } + } + + testExplodePruning("explode sub-array from struct") { + withSampleData { + val query = sql( + "SELECT arrayVal FROM sample " + + "LATERAL VIEW EXPLODE(struct.aSubArray) as arrayVal") + checkScan(query, + "struct>>") + checkAnswer(query, + Row(1) :: Row(2) :: Row(3) :: Nil) + } + } + + testExplodePruning("posexplode sub-array from struct") { + withSampleData { + val query = sql( + "SELECT arrayVal FROM sample " + + "LATERAL VIEW POSEXPLODE(struct.aSubArray) " + + "as arrayIdx, arrayVal") + checkScan(query, + "struct>>") + checkAnswer(query, + Row(1) :: Row(2) :: Row(3) :: Nil) + } + } + + // ========================================================================= + // Tests on double-nested data + // ========================================================================= + + // Double-nested explode: pruning through chained generators now works. + // Outer struct is pruned to only needed fields (b + b_array). + // Inner struct (b_array elements) is also pruned to only c. + testExplodePruning("double-nested explode - select leaf fields") { + withDoubleNestedData { + val query = sql( + "WITH base AS (" + + "SELECT * FROM double_nested " + + "LATERAL VIEW OUTER EXPLODE(a_array) as a_array_item " + + "LATERAL VIEW OUTER EXPLODE(a_array_item.b_array) as b_array_item) " + + "SELECT a_array_item.b, b_array_item.c FROM base") + // b + b_array needed from outer struct, inner struct pruned to just c + checkScan(query, + "struct>>>>") + checkAnswer(query, + Row("a1_b1", "a1_b1_c1") :: Row("a1_b1", "a1_b1_c2") :: + Row("a1_b2", "a1_b2_c1") :: Row("a1_b2", "a1_b2_c2") :: + Row("a2_b1", "a2_b1_c1") :: Row("a2_b1", "a2_b1_c2") :: + Row("a2_b2", "a2_b2_c1") :: Row("a2_b2", "a2_b2_c2") :: Nil) + } + } + + // Same as explode variant - both outer and inner structs pruned + testExplodePruning("double-nested posexplode - select leaf fields with pos") { + withDoubleNestedData { + val query = sql( + "WITH base AS (" + + "SELECT * FROM double_nested " + + "LATERAL VIEW OUTER POSEXPLODE(a_array) " + + "as a_array_index, a_array_item " + + "LATERAL VIEW OUTER POSEXPLODE(a_array_item.b_array) " + + "as b_array_index, b_array_item) " + + "SELECT a_array_index, a_array_item.b, " + + "b_array_index, b_array_item.c FROM base") + checkScan(query, + "struct>>>>") + checkAnswer(query, + Row(0, "a1_b1", 0, "a1_b1_c1") :: Row(0, "a1_b1", 1, "a1_b1_c2") :: + Row(1, "a1_b2", 0, "a1_b2_c1") :: Row(1, "a1_b2", 1, "a1_b2_c2") :: + Row(0, "a2_b1", 0, "a2_b1_c1") :: Row(0, "a2_b1", 1, "a2_b1_c2") :: + Row(1, "a2_b2", 0, "a2_b2_c1") :: Row(1, "a2_b2", 1, "a2_b2_c2") :: Nil) + } + } + + // Multi-field from single posexplode - tests SPARK-34956 multi-field pruning + testExplodePruning("single posexplode - multi-field selection") { + withDoubleNestedData { + val query = sql( + "WITH base AS (" + + "SELECT * FROM double_nested " + + "LATERAL VIEW OUTER POSEXPLODE(a_array) " + + "as a_array_index, a_array_item) " + + "SELECT a_array_index, a_array_item.b, a_array_item.b_int FROM base") + // Both b and b_int selected, b_string and b_array pruned + checkScan(query, + "struct>>") + checkAnswer(query, + Row(0, "a1_b1", 1) :: Row(1, "a1_b2", 2) :: + Row(0, "a2_b1", 3) :: Row(1, "a2_b2", 4) :: Nil) + } + } + + testExplodePruning("single posexplode - select struct field") { + withDoubleNestedData { + val query = sql( + "WITH base AS (" + + "SELECT * FROM double_nested " + + "LATERAL VIEW OUTER POSEXPLODE(a_array) " + + "as a_array_index, a_array_item) " + + "SELECT a_array_index, a_array_item.b FROM base") + checkScan(query, + "struct>>") + checkAnswer(query, + Row(0, "a1_b1") :: Row(1, "a1_b2") :: + Row(0, "a2_b1") :: Row(1, "a2_b2") :: Nil) + } + } + + testExplodePruning( + "single posexplode with struct selection in subquery") { + withDoubleNestedData { + val query = sql( + "WITH base AS (" + + "SELECT a_array_index, a_array_item FROM double_nested " + + "LATERAL VIEW OUTER POSEXPLODE(a_array) " + + "as a_array_index, a_array_item) " + + "SELECT a_array_index, a_array_item.b FROM base") + checkScan(query, + "struct>>") + checkAnswer(query, + Row(0, "a1_b1") :: Row(1, "a1_b2") :: + Row(0, "a2_b1") :: Row(1, "a2_b2") :: Nil) + } + } + + // DF API double-nested - pruning now works through chained generators + // Both outer and inner structs are pruned to only needed fields + testExplodePruning("double-nested explode with struct selection via DF API") { + withDoubleNestedData { + var df = spark.table("double_nested") + df = df.select(col("a_int"), + explode_outer(col("a_array")).as("a_array_item")) + df = df.select(col("a_int"), col("a_array_item.b"), + explode_outer(col("a_array_item.b_array")).as("b_array_item")) + val query = df.select("a_int", "b", "b_array_item.c") + // b + b_array from outer struct, inner struct pruned to just c + checkScan(query, + "struct>>>>") + checkAnswer(query, + Row(1, "a1_b1", "a1_b1_c1") :: Row(1, "a1_b1", "a1_b1_c2") :: + Row(1, "a1_b2", "a1_b2_c1") :: Row(1, "a1_b2", "a1_b2_c2") :: + Row(2, "a2_b1", "a2_b1_c1") :: Row(2, "a2_b1", "a2_b1_c2") :: + Row(2, "a2_b2", "a2_b2_c1") :: Row(2, "a2_b2", "a2_b2_c2") :: Nil) + } + } + + // Chained posexplode - both outer and inner structs pruned + testExplodePruning( + "double-nested posexplode selecting struct fields") { + withDoubleNestedData { + val query = sql( + "WITH base AS (" + + "SELECT a_array_index, a_array_item, " + + "b_array_index, b_array_item " + + "FROM double_nested " + + "LATERAL VIEW OUTER POSEXPLODE(a_array) " + + "as a_array_index, a_array_item " + + "LATERAL VIEW OUTER POSEXPLODE(a_array_item.b_array) " + + "as b_array_index, b_array_item) " + + "SELECT a_array_index, a_array_item.b, " + + "b_array_index, b_array_item.c FROM base") + checkScan(query, + "struct>>>>") + checkAnswer(query, + Row(0, "a1_b1", 0, "a1_b1_c1") :: Row(0, "a1_b1", 1, "a1_b1_c2") :: + Row(1, "a1_b2", 0, "a1_b2_c1") :: Row(1, "a1_b2", 1, "a1_b2_c2") :: + Row(0, "a2_b1", 0, "a2_b1_c1") :: Row(0, "a2_b1", 1, "a2_b1_c2") :: + Row(1, "a2_b2", 0, "a2_b2_c1") :: Row(1, "a2_b2", 1, "a2_b2_c2") :: Nil) + } + } + + // Aggregation creates a barrier - the exploded array goes through FIRST(), + // which Spark doesn't optimize through, so full array is scanned. + testExplodePruning("explode with pass-through and aggregation") { + withSampleData { + val query = sql( + "WITH base AS (" + + "SELECT someStr, FIRST(someComplexArray) as complexArray " + + "FROM sample GROUP BY someStr) " + + "SELECT complex.col1 " + + "FROM base LATERAL VIEW EXPLODE(complexArray) as complex") + // Full array scanned due to aggregation barrier + checkScan(query, + "struct>>") + checkAnswer(query, Row(1) :: Nil) + } + } + + // Same as explode - aggregation barrier prevents pruning + testExplodePruning("posexplode with pass-through and aggregation") { + withSampleData { + val query = sql( + "WITH base AS (" + + "SELECT someStr, FIRST(someComplexArray) as complexArray " + + "FROM sample GROUP BY someStr) " + + "SELECT complex.col1 " + + "FROM base LATERAL VIEW POSEXPLODE(complexArray) " + + "as complexIdx, complex") + checkScan(query, + "struct>>") + checkAnswer(query, Row(1) :: Nil) + } + } + + // ========================================================================= + // Corner case tests for chained Generate support + // ========================================================================= + + // Mixed explode/posexplode chain - outer is explode, inner is posexplode + testExplodePruning("mixed chain - explode then posexplode") { + withSampleData { + val query = sql( + "SELECT complex.col1, valIdx, val " + + "FROM sample " + + "LATERAL VIEW EXPLODE(someArrayOfComplexArrays) as complex " + + "LATERAL VIEW POSEXPLODE(complex.col2) as valIdx, val") + // col1 + col2 needed, col3 pruned + checkScan(query, + "struct>>>") + checkAnswer(query, + Row(1, 0, 1) :: Row(1, 1, 2) :: Row(1, 2, 3) :: Nil) + } + } + + // Mixed chain - outer is posexplode, inner is explode + testExplodePruning("mixed chain - posexplode then explode") { + withSampleData { + val query = sql( + "SELECT complexIdx, complex.col1, val " + + "FROM sample " + + "LATERAL VIEW POSEXPLODE(someArrayOfComplexArrays) " + + "as complexIdx, complex " + + "LATERAL VIEW EXPLODE(complex.col2) as val") + // col1 + col2 needed, col3 pruned + checkScan(query, + "struct>>>") + checkAnswer(query, + Row(0, 1, 1) :: Row(0, 1, 2) :: Row(0, 1, 3) :: Nil) + } + } + + // Triple-nested chain: three levels of lateral views + testExplodePruning("triple-nested explode chain") { + withDoubleNestedData { + val query = sql( + "SELECT a_array_item.b, b_array_item.c, " + + "c_char FROM double_nested " + + "LATERAL VIEW EXPLODE(a_array) as a_array_item " + + "LATERAL VIEW EXPLODE(a_array_item.b_array) as b_array_item " + + "LATERAL VIEW EXPLODE(ARRAY(b_array_item.c)) as c_char") + // b + b_array needed from outer, inner struct pruned to just c + checkScan(query, + "struct>>>>") + checkAnswer(query, + Row("a1_b1", "a1_b1_c1", "a1_b1_c1") :: + Row("a1_b1", "a1_b1_c2", "a1_b1_c2") :: + Row("a1_b2", "a1_b2_c1", "a1_b2_c1") :: + Row("a1_b2", "a1_b2_c2", "a1_b2_c2") :: + Row("a2_b1", "a2_b1_c1", "a2_b1_c1") :: + Row("a2_b1", "a2_b1_c2", "a2_b1_c2") :: + Row("a2_b2", "a2_b2_c1", "a2_b2_c1") :: + Row("a2_b2", "a2_b2_c2", "a2_b2_c2") :: Nil) + } + } + + // Direct element reference blocks pruning - whole struct is used + testExplodePruning("direct element reference blocks pruning") { + withSampleData { + val query = sql( + "SELECT complex, val " + + "FROM sample " + + "LATERAL VIEW EXPLODE(someArrayOfComplexArrays) as complex " + + "LATERAL VIEW EXPLODE(complex.col2) as val") + // complex is referenced directly, so all fields needed + checkScan(query, + "struct,col3:bigint>>>") + checkAnswer(query, + Row(Row(1L, Array(1, 2, 3), 4L), 1) :: + Row(Row(1L, Array(1, 2, 3), 4L), 2) :: + Row(Row(1L, Array(1, 2, 3), 4L), 3) :: Nil) + } + } + + // Pos-only optimization in chain - only position from outer, field from inner + testExplodePruning("pos-only outer with field from inner") { + withSampleData { + val query = sql( + "SELECT complexIdx, val " + + "FROM sample " + + "LATERAL VIEW POSEXPLODE(someArrayOfComplexArrays) " + + "as complexIdx, complex " + + "LATERAL VIEW EXPLODE(complex.col2) as val") + // Only col2 needed from outer (for inner explode), col1 and col3 pruned + checkScan(query, + "struct>>>") + checkAnswer(query, + Row(0, 1) :: Row(0, 2) :: Row(0, 3) :: Nil) + } + } + + // Non-contiguous fields in chain - tests ordinal stability + // Selecting b (idx 0), b_int (idx 1) from outer, c (idx 0), c_2 (idx 2) from inner + // This verifies ordinal fixing works for non-adjacent fields + testExplodePruning("non-contiguous fields ordinal stability in chain") { + withDoubleNestedData { + val query = sql( + "SELECT a_array_item.b, a_array_item.b_int, " + + "b_array_item.c, b_array_item.c_2 FROM double_nested " + + "LATERAL VIEW EXPLODE(a_array) as a_array_item " + + "LATERAL VIEW EXPLODE(a_array_item.b_array) as b_array_item") + // Outer: b (idx 0) + b_int (idx 1) + b_array (idx 3) - b_string pruned + // Inner: c + c_2 needed, c_int pruned + checkScan(query, + "struct>>>>") + checkAnswer(query, + Row("a1_b1", 1, "a1_b1_c1", "a1_b1_c1_d") :: + Row("a1_b1", 1, "a1_b1_c2", "a1_b1_c2_d") :: + Row("a1_b2", 2, "a1_b2_c1", "a1_b2_c2_d") :: + Row("a1_b2", 2, "a1_b2_c2", "a1_b2_c2_d") :: + Row("a2_b1", 3, "a2_b1_c1", "a2_b1_c1_d") :: + Row("a2_b1", 3, "a2_b1_c2", "a2_b1_c2_d") :: + Row("a2_b2", 4, "a2_b2_c1", "da") :: + Row("a2_b2", 4, "a2_b2_c2", "da") :: Nil) + } + } + + // Chain with filter on generator output between generates + testExplodePruning("chain with filter on outer generator output") { + withDoubleNestedData { + val query = sql( + "SELECT a_array_item.b, b_array_item.c FROM double_nested " + + "LATERAL VIEW EXPLODE(a_array) as a_array_item " + + "LATERAL VIEW EXPLODE(a_array_item.b_array) as b_array_item " + + "WHERE a_array_item.b_int > 2") + // b + b_int (for filter) + b_array needed from outer, inner struct pruned to c + checkScan(query, + "struct>>>>") + checkAnswer(query, + Row("a2_b1", "a2_b1_c1") :: Row("a2_b1", "a2_b1_c2") :: + Row("a2_b2", "a2_b2_c1") :: Row("a2_b2", "a2_b2_c2") :: Nil) + } + } + + // Chain with filter on inner generator output + testExplodePruning("chain with filter on inner generator output") { + withDoubleNestedData { + val query = sql( + "SELECT a_array_item.b, b_array_item.c FROM double_nested " + + "LATERAL VIEW EXPLODE(a_array) as a_array_item " + + "LATERAL VIEW EXPLODE(a_array_item.b_array) as b_array_item " + + "WHERE b_array_item.c_int > 0") + // b + b_array needed from outer, c + c_int (for filter) from inner + checkScan(query, + "struct>>>>") + checkAnswer(query, + Row("a1_b1", "a1_b1_c1") :: Row("a1_b1", "a1_b1_c2") :: + Row("a1_b2", "a1_b2_c1") :: Row("a1_b2", "a1_b2_c2") :: + Row("a2_b1", "a2_b1_c1") :: Row("a2_b1", "a2_b1_c2") :: + Row("a2_b2", "a2_b2_c1") :: Row("a2_b2", "a2_b2_c2") :: Nil) + } + } + + // Only inner Generate can prune - outer uses all fields + testExplodePruning("only inner generate can prune") { + withDoubleNestedData { + val query = sql( + "SELECT a_array_item.b, a_array_item.b_int, " + + "a_array_item.b_string, b_array_item.c FROM double_nested " + + "LATERAL VIEW EXPLODE(a_array) as a_array_item " + + "LATERAL VIEW EXPLODE(a_array_item.b_array) as b_array_item") + // All fields from outer (b, b_int, b_string, b_array), inner pruned to just c + checkScan(query, + "struct>>>>") + checkAnswer(query, + Row("a1_b1", 1, "da", "a1_b1_c1") :: Row("a1_b1", 1, "da", "a1_b1_c2") :: + Row("a1_b2", 2, "da", "a1_b2_c1") :: Row("a1_b2", 2, "da", "a1_b2_c2") :: + Row("a2_b1", 3, "da", "a2_b1_c1") :: Row("a2_b1", 3, "da", "a2_b1_c2") :: + Row("a2_b2", 4, "da", "a2_b2_c1") :: Row("a2_b2", 4, "da", "a2_b2_c2") :: Nil) + } + } + + // All fields selected from chain - no pruning expected + testExplodePruning("all fields from chain - no pruning") { + withDoubleNestedData { + val query = sql( + "SELECT a_array_item.*, b_array_item.* FROM double_nested " + + "LATERAL VIEW EXPLODE(a_array) as a_array_item " + + "LATERAL VIEW EXPLODE(a_array_item.b_array) as b_array_item") + // All fields needed from both levels + checkScan(query, + "struct>>>>") + // Just check it runs without error + assert(query.collect().length == 8) + } + } + + // Posexplode chain with pos from both levels + testExplodePruning("posexplode chain using both positions") { + withSampleData { + val query = sql( + "SELECT complexIdx, valIdx " + + "FROM sample " + + "LATERAL VIEW POSEXPLODE(someArrayOfComplexArrays) " + + "as complexIdx, complex " + + "LATERAL VIEW POSEXPLODE(complex.col2) as valIdx, val") + // Only positions used, but col2 needed for inner explode + // Outer prunes to just col2 (minimal for inner explode) + checkScan(query, + "struct>>>") + checkAnswer(query, + Row(0, 0) :: Row(0, 1) :: Row(0, 2) :: Nil) + } + } + + // Single explode with OUTER - null handling + testExplodePruning("explode outer with null array") { + withSampleData { + // Create a query that would have null arrays (using a filter that produces no matches) + val query = sql( + "SELECT complex.col1 " + + "FROM sample " + + "LATERAL VIEW OUTER EXPLODE(someArrayOfComplexArrays) as complex " + + "WHERE someStr = 'bla'") + checkScan(query, + "struct>>") + checkAnswer(query, Row(1) :: Nil) + } + } + + // Verify pruning doesn't break when array has multiple elements + testExplodePruning("multi-element array pruning correctness") { + withDoubleNestedData { + // Simple test without filter to avoid V1/V2 differences + val query = sql( + "SELECT a_array_item.b, a_array_item.b_int FROM double_nested " + + "LATERAL VIEW EXPLODE(a_array) as a_array_item") + // b + b_int selected from exploded struct, b_string and b_array pruned + checkScan(query, + "struct>>") + checkAnswer(query, + Row("a1_b1", 1) :: Row("a1_b2", 2) :: + Row("a2_b1", 3) :: Row("a2_b2", 4) :: Nil) + } + } + + // ========================================================================= + // Step 7 tests: Additional correctness fixes + // ========================================================================= + + // Step 7.1: Fix ordinals inside rewritten leaf filters + // Test filter on non-contiguous field (keep col1 & col3, filter on col3) + testExplodePruning("filter on non-contiguous field with ordinal fix") { + withSampleData { + // Filter on source array field (col3) while selecting col1 and col3 + // This tests that GetArrayStructFields ordinals are fixed in leaf filters + val query = sql( + """SELECT arrayVal.col1, arrayVal.col3 + |FROM sample + |LATERAL VIEW EXPLODE(someArrayOfComplexArrays) as arrayVal + |WHERE size(someArrayOfComplexArrays) > 0""".stripMargin) + // col1 and col3 selected, col2 pruned + checkScan(query, + "struct>>") + checkAnswer(query, Row(1, 4) :: Nil) + } + } + + // Step 7.4: Safer handling of Projects in decomposeChild + // Test that explode works correctly with nested pruning through the DF API + testExplodePruning("explode via DF API - multi field selection") { + withSampleData { + val query = spark.table("sample") + .selectExpr("someStr", "explode(someArrayOfComplexArrays) as elem") + .selectExpr("someStr", "elem.col1", "elem.col3") + // col1 and col3 selected, col2 pruned + checkScan(query, + "struct>>") + checkAnswer(query, Row("bla", 1, 4) :: Nil) + } + } + + // ========================================================================= + // Nested array inside struct tests (Step 7.2) + // ========================================================================= + + // Case class for nested array inside struct test data + case class NestedArrayElement(x: Long, y: Long, z: Long) + case class WrapperStruct( + name: String, + nestedArray: Array[NestedArrayElement]) + case class NestedArrayRecord( + id: Long, + wrapper: WrapperStruct) + + private val nestedArrayData = Seq( + NestedArrayRecord(1, WrapperStruct("w1", + Array(NestedArrayElement(10, 20, 30), NestedArrayElement(11, 21, 31)))), + NestedArrayRecord(2, WrapperStruct("w2", + Array(NestedArrayElement(12, 22, 32))))) + + private def withNestedArrayData(testThunk: => Unit): Unit = { + withTempPath { dir => + val path = dir.getCanonicalPath + makeDataSourceFile(nestedArrayData, new File(path + "/nested_array")) + + val schema = + """`id` BIGINT, + |`wrapper` STRUCT<`name`: STRING, + | `nestedArray`: ARRAY>> + """.stripMargin + spark.read.format(dataSourceName).schema(schema).load(path + "/nested_array") + .createOrReplaceTempView("nested_array") + + testThunk + } + } + + // Step 7.2: Allow pruning when generator child is derived from scan attribute + // Test explode(col.structArr) where array is nested inside a struct + testExplodePruning("explode array nested inside struct - single field") { + withNestedArrayData { + val query = sql( + """SELECT id, elem.x + |FROM nested_array + |LATERAL VIEW EXPLODE(wrapper.nestedArray) as elem""".stripMargin) + // Should prune y and z, keeping only x + checkScan(query, + "struct>>>") + checkAnswer(query, + Row(1, 10) :: Row(1, 11) :: Row(2, 12) :: Nil) + } + } + + // Note: Multi-field selection from nested array inside struct is not yet fully supported. + // When ColumnPruning creates intermediate aliases for nested fields, our rule may not + // be able to trace the aliases back to scan attributes. The single-field case works + // because ColumnPruning doesn't always create aliases for simple nested projections. + + testExplodePruning("posexplode array nested inside struct") { + withNestedArrayData { + val query = sql( + """SELECT id, pos, elem.y + |FROM nested_array + |LATERAL VIEW POSEXPLODE(wrapper.nestedArray) as pos, elem""".stripMargin) + // Should prune x and z, keeping only y + checkScan(query, + "struct>>>") + checkAnswer(query, + Row(1, 0, 20) :: Row(1, 1, 21) :: Row(2, 0, 22) :: Nil) + } + } + + // Step 8: Test for nested arrays (array>) + // This tests the GetNestedArrayStructFields expression integration + protected def withDoublyNestedArrayData(testThunk: => Unit): Unit = { + withTempPath { dir => + val path = dir.getAbsolutePath + val schema = StructType(Seq( + StructField("id", LongType), + StructField("doublyNested", ArrayType(ArrayType( + StructType(Seq( + StructField("a", LongType), + StructField("b", StringType), + StructField("c", LongType) + )) + ))) + )) + + // Create test data: array>> + // Row 1: [[[1, "x", 10], [2, "y", 20]], [[3, "z", 30]]] + // Row 2: [[[4, "w", 40]]] + val data = Seq( + Row(1L, Seq( + Seq(Row(1L, "x", 10L), Row(2L, "y", 20L)), + Seq(Row(3L, "z", 30L)) + )), + Row(2L, Seq( + Seq(Row(4L, "w", 40L)) + )) + ) + + spark.createDataFrame(spark.sparkContext.parallelize(data), schema) + .write.format(dataSourceName).save(path + "/doubly_nested") + + spark.read.format(dataSourceName).schema(schema).load(path + "/doubly_nested") + .createOrReplaceTempView("doubly_nested") + + testThunk + } + } + + testExplodePruning("explode doubly nested array - single field") { + withDoublyNestedArrayData { + // Explode array> produces array elements + // Then explode again to get struct elements + val query = sql( + """SELECT id, inner_elem.a + |FROM doubly_nested + |LATERAL VIEW EXPLODE(doublyNested) as outer_elem + |LATERAL VIEW EXPLODE(outer_elem) as inner_elem""".stripMargin) + + // After two explodes, we access inner_elem.a + // The outer array can be pruned to only include field 'a' in the innermost struct + checkAnswer(query, + Row(1, 1) :: Row(1, 2) :: Row(1, 3) :: Row(2, 4) :: Nil) + } + } + + testExplodePruning("explode doubly nested array - multi field") { + withDoublyNestedArrayData { + val query = sql( + """SELECT id, inner_elem.a, inner_elem.c + |FROM doubly_nested + |LATERAL VIEW EXPLODE(doublyNested) as outer_elem + |LATERAL VIEW EXPLODE(outer_elem) as inner_elem""".stripMargin) + + checkAnswer(query, + Row(1, 1, 10) :: Row(1, 2, 20) :: Row(1, 3, 30) :: Row(2, 4, 40) :: Nil) + } + } + + // ============================================================================ + // PageView-style comprehensive tests: 3-level nesting with fields at each level + // Models real-world data: PageView -> Requests -> Items + // ============================================================================ + + /** + * PageView data model: + * - Root: pageId, country, platform, userId (select some, prune others) + * - requests: array> (4 fields) + * - items: array> (5 fields) + * + * This mirrors real analytics data where pruning at each level is important. + */ + protected def withPageViewData(testThunk: => Unit): Unit = { + withTempPath { dir => + val path = dir.getAbsolutePath + + val itemStruct = StructType(Seq( + StructField("itemId", LongType), + StructField("clicked", BooleanType), + StructField("visible", BooleanType), + StructField("charged", BooleanType), + StructField("campaignId", LongType) + )) + + val requestStruct = StructType(Seq( + StructField("requestId", StringType), + StructField("ts", LongType), + StructField("source", StringType), + StructField("items", ArrayType(itemStruct)) + )) + + val schema = StructType(Seq( + StructField("pageId", LongType), + StructField("country", StringType), + StructField("platform", StringType), + StructField("userId", LongType), + StructField("requests", ArrayType(requestStruct)) + )) + + // PageView 1: US, desktop, 2 requests with 2 items each + // PageView 2: UK, mobile, 1 request with 3 items + val data = Seq( + Row(100L, "US", "desktop", 1001L, Seq( + Row("req1", 1000L, "organic", Seq( + Row(1L, true, true, false, 500L), + Row(2L, false, true, true, 501L) + )), + Row("req2", 1001L, "paid", Seq( + Row(3L, true, false, false, 502L), + Row(4L, true, true, true, 503L) + )) + )), + Row(200L, "UK", "mobile", 1002L, Seq( + Row("req3", 2000L, "organic", Seq( + Row(5L, false, true, false, 600L), + Row(6L, true, true, true, 601L), + Row(7L, false, false, false, 602L) + )) + )) + ) + + spark.createDataFrame(spark.sparkContext.parallelize(data), schema) + .write.format(dataSourceName).save(path + "/pageview") + + spark.read.format(dataSourceName).schema(schema).load(path + "/pageview") + .createOrReplaceTempView("pageview") + + testThunk + } + } + + // ============================================================================ + // PageView-style tests: 3-level nesting + // Inner generate pruning IS supported - we can prune the innermost struct + // fields even when the inner generate's source comes from the outer's output. + // ============================================================================ + + // Test: Select fields from ALL 3 levels + // Inner items struct IS pruned to only itemId and clicked + testExplodePruning("pageview - fields from all 3 levels") { + withPageViewData { + val query = sql( + """SELECT + | pv.pageId, pv.country, + | req.requestId, req.ts, + | item.itemId, item.clicked + |FROM pageview pv + |LATERAL VIEW EXPLODE(pv.requests) as req + |LATERAL VIEW EXPLODE(req.items) as item""".stripMargin) + + // Can prune: platform, userId (root), source (request) + // CAN prune inner items struct - only itemId and clicked needed + checkScan(query, + "struct>>>>") + + checkAnswer(query, Seq( + Row(100L, "US", "req1", 1000L, 1L, true), + Row(100L, "US", "req1", 1000L, 2L, false), + Row(100L, "US", "req2", 1001L, 3L, true), + Row(100L, "US", "req2", 1001L, 4L, true), + Row(200L, "UK", "req3", 2000L, 5L, false), + Row(200L, "UK", "req3", 2000L, 6L, true), + Row(200L, "UK", "req3", 2000L, 7L, false) + )) + } + } + + // Test: Select only from root + innermost level (skip middle) + testExplodePruning("pageview - root and innermost levels only") { + withPageViewData { + val query = sql( + """SELECT + | pv.pageId, pv.platform, + | item.itemId, item.campaignId + |FROM pageview pv + |LATERAL VIEW EXPLODE(pv.requests) as req + |LATERAL VIEW EXPLODE(req.items) as item""".stripMargin) + + // Request fields pruned except items, inner items struct IS pruned + checkScan(query, + "struct>>>>") + + checkAnswer(query, Seq( + Row(100L, "desktop", 1L, 500L), + Row(100L, "desktop", 2L, 501L), + Row(100L, "desktop", 3L, 502L), + Row(100L, "desktop", 4L, 503L), + Row(200L, "mobile", 5L, 600L), + Row(200L, "mobile", 6L, 601L), + Row(200L, "mobile", 7L, 602L) + )) + } + } + + // Test: Select only from middle + innermost levels (no root fields) + testExplodePruning("pageview - middle and innermost levels only") { + withPageViewData { + val query = sql( + """SELECT + | req.requestId, req.source, + | item.clicked, item.visible, item.charged + |FROM pageview pv + |LATERAL VIEW EXPLODE(pv.requests) as req + |LATERAL VIEW EXPLODE(req.items) as item""".stripMargin) + + // Root fields fully pruned, request fields pruned, inner items IS pruned + checkScan(query, + "struct>>>>") + + checkAnswer(query, Seq( + Row("req1", "organic", true, true, false), + Row("req1", "organic", false, true, true), + Row("req2", "paid", true, false, false), + Row("req2", "paid", true, true, true), + Row("req3", "organic", false, true, false), + Row("req3", "organic", true, true, true), + Row("req3", "organic", false, false, false) + )) + } + } + + // Test: Posexplode at both levels with field selection + testExplodePruning("pageview - posexplode at both levels") { + withPageViewData { + val query = sql( + """SELECT + | pv.country, + | reqIdx, req.requestId, + | itemIdx, item.itemId + |FROM pageview pv + |LATERAL VIEW POSEXPLODE(pv.requests) as reqIdx, req + |LATERAL VIEW POSEXPLODE(req.items) as itemIdx, item""".stripMargin) + + // Inner items struct IS pruned + checkScan(query, + "struct>>>>") + + checkAnswer(query, Seq( + Row("US", 0, "req1", 0, 1L), + Row("US", 0, "req1", 1, 2L), + Row("US", 1, "req2", 0, 3L), + Row("US", 1, "req2", 1, 4L), + Row("UK", 0, "req3", 0, 5L), + Row("UK", 0, "req3", 1, 6L), + Row("UK", 0, "req3", 2, 7L) + )) + } + } + + // Test: Filter at middle level + testExplodePruning("pageview - filter on middle level") { + withPageViewData { + val query = sql( + """SELECT + | pv.pageId, + | req.requestId, req.ts, + | item.itemId, item.clicked + |FROM pageview pv + |LATERAL VIEW EXPLODE(pv.requests) as req + |LATERAL VIEW EXPLODE(req.items) as item + |WHERE req.source = 'organic'""".stripMargin) + + // source needed for filter, inner items struct IS pruned + checkScan(query, + "struct>>>>") + + checkAnswer(query, Seq( + Row(100L, "req1", 1000L, 1L, true), + Row(100L, "req1", 1000L, 2L, false), + Row(200L, "req3", 2000L, 5L, false), + Row(200L, "req3", 2000L, 6L, true), + Row(200L, "req3", 2000L, 7L, false) + )) + } + } + + // Test: Filter at innermost level + // Inner items struct IS pruned - clicked needed for filter + testExplodePruning("pageview - filter on innermost level") { + withPageViewData { + val query = sql( + """SELECT + | pv.country, + | req.requestId, + | item.itemId, item.campaignId + |FROM pageview pv + |LATERAL VIEW EXPLODE(pv.requests) as req + |LATERAL VIEW EXPLODE(req.items) as item + |WHERE item.clicked = true""".stripMargin) + + // clicked needed for filter; inner items struct IS pruned + checkScan(query, + "struct>>>>") + + checkAnswer(query, Seq( + Row("US", "req1", 1L, 500L), + Row("US", "req2", 3L, 502L), + Row("US", "req2", 4L, 503L), + Row("UK", "req3", 6L, 601L) + )) + } + } + + // Test: Many fields per level + testExplodePruning("pageview - many fields per level") { + withPageViewData { + val query = sql( + """SELECT + | pv.pageId, pv.country, pv.platform, + | req.requestId, req.ts, req.source, + | item.itemId, item.clicked, item.visible, item.charged + |FROM pageview pv + |LATERAL VIEW EXPLODE(pv.requests) as req + |LATERAL VIEW EXPLODE(req.items) as item""".stripMargin) + + // Only userId (root) pruned; inner items struct IS pruned (campaignId not needed) + checkScan(query, + "struct>>>>") + + val result = query.collect() + assert(result.length == 7) + assert(result(0) == Row(100L, "US", "desktop", "req1", 1000L, "organic", + 1L, true, true, false)) + } + } + + // ============================================================================ + // Inner Generate Pruning Tests (NestedArraysZip feature) + // These tests verify that fields from inner generates CAN be pruned when + // the source array comes from an outer generate's output. + // ============================================================================ + + /** + * Inner-pruning test data model: + * - Root: id (for identification) + * - outer_array: array> + * - inner_array: array> + * + * This allows testing that inner_f3, inner_f4 can be pruned when only + * inner_f1, inner_f2 are selected. + */ + protected def withInnerPruningData(testThunk: => Unit): Unit = { + withTempPath { dir => + val path = dir.getAbsolutePath + + val innerStruct = StructType(Seq( + StructField("inner_f1", LongType), + StructField("inner_f2", StringType), + StructField("inner_f3", BooleanType), + StructField("inner_f4", DoubleType) + )) + + val outerStruct = StructType(Seq( + StructField("outer_field1", StringType), + StructField("outer_field2", LongType), + StructField("inner_array", ArrayType(innerStruct)) + )) + + val schema = StructType(Seq( + StructField("id", LongType), + StructField("outer_array", ArrayType(outerStruct)) + )) + + // Row 1: 2 outer elements, each with 2 inner elements + // Row 2: 1 outer element with 3 inner elements + val data = Seq( + Row(1L, Seq( + Row("a1", 100L, Seq( + Row(1L, "x1", true, 1.0), + Row(2L, "x2", false, 2.0) + )), + Row("a2", 200L, Seq( + Row(3L, "x3", true, 3.0), + Row(4L, "x4", true, 4.0) + )) + )), + Row(2L, Seq( + Row("b1", 300L, Seq( + Row(5L, "y1", false, 5.0), + Row(6L, "y2", true, 6.0), + Row(7L, "y3", false, 7.0) + )) + )) + ) + + spark.createDataFrame(spark.sparkContext.parallelize(data), schema) + .write.format(dataSourceName).save(path + "/inner_pruning") + + spark.read.format(dataSourceName).schema(schema).load(path + "/inner_pruning") + .createOrReplaceTempView("inner_pruning") + + testThunk + } + } + + // Test: Only inner_f1 selected - other inner fields should be pruned + testExplodePruning("inner generate pruning - single field from inner") { + withInnerPruningData { + val query = sql( + """SELECT id, outer_elem.outer_field1, inner_elem.inner_f1 + |FROM inner_pruning + |LATERAL VIEW EXPLODE(outer_array) as outer_elem + |LATERAL VIEW EXPLODE(outer_elem.inner_array) as inner_elem""".stripMargin) + + // Inner fields pruned: only inner_f1 needed (inner_f2, inner_f3, inner_f4 pruned) + // Outer fields: outer_field1 + inner_array (outer_field2 pruned) + checkScan(query, + "struct>>>>") + + checkAnswer(query, Seq( + Row(1L, "a1", 1L), Row(1L, "a1", 2L), + Row(1L, "a2", 3L), Row(1L, "a2", 4L), + Row(2L, "b1", 5L), Row(2L, "b1", 6L), Row(2L, "b1", 7L) + )) + } + } + + // Test: Two inner fields selected - prune the other two + testExplodePruning("inner generate pruning - two fields from inner") { + withInnerPruningData { + val query = sql( + """SELECT inner_elem.inner_f1, inner_elem.inner_f2 + |FROM inner_pruning + |LATERAL VIEW EXPLODE(outer_array) as outer_elem + |LATERAL VIEW EXPLODE(outer_elem.inner_array) as inner_elem""".stripMargin) + + // Only inner_f1 and inner_f2 needed + checkScan(query, + "struct>>>>") + + checkAnswer(query, Seq( + Row(1L, "x1"), Row(2L, "x2"), + Row(3L, "x3"), Row(4L, "x4"), + Row(5L, "y1"), Row(6L, "y2"), Row(7L, "y3") + )) + } + } + + // Test: Non-contiguous inner fields (first and last) + testExplodePruning("inner generate pruning - non-contiguous inner fields") { + withInnerPruningData { + val query = sql( + """SELECT inner_elem.inner_f1, inner_elem.inner_f4 + |FROM inner_pruning + |LATERAL VIEW EXPLODE(outer_array) as outer_elem + |LATERAL VIEW EXPLODE(outer_elem.inner_array) as inner_elem""".stripMargin) + + // inner_f1 and inner_f4 needed, middle fields pruned + checkScan(query, + "struct>>>>") + + checkAnswer(query, Seq( + Row(1L, 1.0), Row(2L, 2.0), + Row(3L, 3.0), Row(4L, 4.0), + Row(5L, 5.0), Row(6L, 6.0), Row(7L, 7.0) + )) + } + } + + // Test: Fields from both outer and inner, with filter on inner + testExplodePruning("inner generate pruning - with filter on inner field") { + withInnerPruningData { + val query = sql( + """SELECT id, outer_elem.outer_field1, inner_elem.inner_f1, inner_elem.inner_f2 + |FROM inner_pruning + |LATERAL VIEW EXPLODE(outer_array) as outer_elem + |LATERAL VIEW EXPLODE(outer_elem.inner_array) as inner_elem + |WHERE inner_elem.inner_f3 = true""".stripMargin) + + // Filter requires inner_f3, but it's not in SELECT, so... + // Note: Filter pushdown may or may not include inner_f3 in scan + // For now, just verify query correctness + checkAnswer(query, Seq( + Row(1L, "a1", 1L, "x1"), + Row(1L, "a2", 3L, "x3"), Row(1L, "a2", 4L, "x4"), + Row(2L, "b1", 6L, "y2") + )) + } + } + + // Test: Posexplode at inner level with field pruning + testExplodePruning("inner generate pruning - posexplode inner") { + withInnerPruningData { + val query = sql( + """SELECT outer_elem.outer_field1, inner_pos, inner_elem.inner_f1 + |FROM inner_pruning + |LATERAL VIEW EXPLODE(outer_array) as outer_elem + |LATERAL VIEW POSEXPLODE(outer_elem.inner_array) as inner_pos, inner_elem""".stripMargin) + + // Inner pruned to just inner_f1 + checkScan(query, + "struct>>>>") + + checkAnswer(query, Seq( + Row("a1", 0, 1L), Row("a1", 1, 2L), + Row("a2", 0, 3L), Row("a2", 1, 4L), + Row("b1", 0, 5L), Row("b1", 1, 6L), Row("b1", 2, 7L) + )) + } + } + + // Test: Posexplode at outer level, explode at inner with pruning + testExplodePruning("inner generate pruning - posexplode outer explode inner") { + withInnerPruningData { + val query = sql( + """SELECT outer_pos, inner_elem.inner_f2 + |FROM inner_pruning + |LATERAL VIEW POSEXPLODE(outer_array) as outer_pos, outer_elem + |LATERAL VIEW EXPLODE(outer_elem.inner_array) as inner_elem""".stripMargin) + + // Only position from outer, only inner_f2 from inner + checkScan(query, + "struct>>>>") + + checkAnswer(query, Seq( + Row(0, "x1"), Row(0, "x2"), + Row(1, "x3"), Row(1, "x4"), + Row(0, "y1"), Row(0, "y2"), Row(0, "y3") + )) + } + } + + // Test: All inner fields selected - no inner pruning + testExplodePruning("inner generate pruning - all inner fields no pruning") { + withInnerPruningData { + val query = sql( + """SELECT inner_elem.* + |FROM inner_pruning + |LATERAL VIEW EXPLODE(outer_array) as outer_elem + |LATERAL VIEW EXPLODE(outer_elem.inner_array) as inner_elem""".stripMargin) + + // All inner fields needed, but outer_field1, outer_field2 can be pruned + checkScan(query, + "struct>>>>") + + val result = query.collect() + assert(result.length == 7) + } + } + + // Test: Mixed selection from all levels + testExplodePruning("inner generate pruning - mixed levels with pruning") { + withInnerPruningData { + val query = sql( + """SELECT id, outer_elem.outer_field2, inner_elem.inner_f3 + |FROM inner_pruning + |LATERAL VIEW EXPLODE(outer_array) as outer_elem + |LATERAL VIEW EXPLODE(outer_elem.inner_array) as inner_elem""".stripMargin) + + // Root: id, Outer: outer_field2 + inner_array, Inner: inner_f3 only + checkScan(query, + "struct>>>>") + + checkAnswer(query, Seq( + Row(1L, 100L, true), Row(1L, 100L, false), + Row(1L, 200L, true), Row(1L, 200L, true), + Row(2L, 300L, false), Row(2L, 300L, true), Row(2L, 300L, false) + )) + } + } + + // Test: Triple nesting with innermost pruning + /** + * Triple-nested data for testing deepest level pruning: + * - Root: id + * - level1: array> + * - level2: array> + * - level3: array> + */ + protected def withTripleNestedData(testThunk: => Unit): Unit = { + withTempPath { dir => + val path = dir.getAbsolutePath + + val l3Struct = StructType(Seq( + StructField("l3_f1", LongType), + StructField("l3_f2", StringType), + StructField("l3_f3", BooleanType) + )) + + val l2Struct = StructType(Seq( + StructField("l2_field", StringType), + StructField("level3", ArrayType(l3Struct)) + )) + + val l1Struct = StructType(Seq( + StructField("l1_field", StringType), + StructField("level2", ArrayType(l2Struct)) + )) + + val schema = StructType(Seq( + StructField("id", LongType), + StructField("level1", ArrayType(l1Struct)) + )) + + val data = Seq( + Row(1L, Seq( + Row("L1A", Seq( + Row("L2A", Seq(Row(1L, "a", true), Row(2L, "b", false))) + )) + )) + ) + + spark.createDataFrame(spark.sparkContext.parallelize(data), schema) + .write.format(dataSourceName).save(path + "/triple_nested") + + spark.read.format(dataSourceName).schema(schema).load(path + "/triple_nested") + .createOrReplaceTempView("triple_nested") + + testThunk + } + } + + testExplodePruning("inner generate pruning - triple nesting deepest field") { + withTripleNestedData { + val query = sql( + """SELECT l1.l1_field, l3.l3_f1 + |FROM triple_nested + |LATERAL VIEW EXPLODE(level1) as l1 + |LATERAL VIEW EXPLODE(l1.level2) as l2 + |LATERAL VIEW EXPLODE(l2.level3) as l3""".stripMargin) + + // Only l3_f1 from deepest level (l3_f2, l3_f3 pruned) + checkScan(query, + "struct>>>>>>") + + checkAnswer(query, Seq( + Row("L1A", 1L), Row("L1A", 2L) + )) + } + } + + // ============================================================================ + // Code Review Integration Tests (Feb 2026) + // Additional tests for edge cases identified during code review + // ============================================================================ + + // Test 1: Inner-generate pruning through nested arrays (core regression test) + // Validates that inner generate's array from outer generate output is pruned + protected def withBlockedItemsData(testThunk: => Unit): Unit = { + withTempPath { dir => + val path = dir.getAbsolutePath + + val itemStruct = StructType(Seq( + StructField("itemId", LongType), + StructField("blockReason", StringType), + StructField("score", DoubleType), + StructField("timestamp", LongType) + )) + + val requestStruct = StructType(Seq( + StructField("requestId", StringType), + StructField("blockedItemsList", ArrayType(itemStruct)) + )) + + val schema = StructType(Seq( + StructField("pageId", LongType), + StructField("requests", ArrayType(requestStruct)) + )) + + val data = Seq( + Row(1L, Seq( + Row("req1", Seq( + Row(100L, "policy_violation", 0.9, 1000L), + Row(101L, "spam", 0.8, 1001L) + )), + Row("req2", Seq( + Row(102L, "duplicate", 0.7, 1002L) + )) + )), + Row(2L, Seq( + Row("req3", Seq( + Row(200L, "low_quality", 0.5, 2000L) + )) + )) + ) + + spark.createDataFrame(spark.sparkContext.parallelize(data), schema) + .write.format(dataSourceName).save(path + "/blocked_items") + + spark.read.format(dataSourceName).schema(schema).load(path + "/blocked_items") + .createOrReplaceTempView("blocked_items") + + testThunk + } + } + + testExplodePruning("inner-generate pruning - blockedItemsList only blockReason") { + withBlockedItemsData { + // Explode requests, then explode blockedItemsList from each request + // Only select blockReason from the inner items + val query = sql( + """SELECT req.requestId, item.blockReason + |FROM blocked_items + |LATERAL VIEW EXPLODE(requests) as req + |LATERAL VIEW EXPLODE(req.blockedItemsList) as item""".stripMargin) + + // blockedItemsList should be pruned to only include blockReason + // itemId, score, timestamp should NOT be in scan schema + checkScan(query, + "struct>>>>") + + checkAnswer(query, Seq( + Row("req1", "policy_violation"), + Row("req1", "spam"), + Row("req2", "duplicate"), + Row("req3", "low_quality") + )) + } + } + + // Test 2: Nested array inside struct + multi-field selection + protected def withWrapperStructData(testThunk: => Unit): Unit = { + withTempPath { dir => + val path = dir.getAbsolutePath + + val elemStruct = StructType(Seq( + StructField("x", LongType), + StructField("y", LongType), + StructField("z", LongType) + )) + + val wrapperStruct = StructType(Seq( + StructField("name", StringType), + StructField("arr", ArrayType(elemStruct)) + )) + + val schema = StructType(Seq( + StructField("id", LongType), + StructField("wrapper", wrapperStruct) + )) + + val data = Seq( + Row(1L, Row("first", Seq(Row(10L, 20L, 30L), Row(11L, 21L, 31L)))), + Row(2L, Row("second", Seq(Row(12L, 22L, 32L)))) + ) + + spark.createDataFrame(spark.sparkContext.parallelize(data), schema) + .write.format(dataSourceName).save(path + "/wrapper_struct") + + spark.read.format(dataSourceName).schema(schema).load(path + "/wrapper_struct") + .createOrReplaceTempView("wrapper_struct") + + testThunk + } + } + + // NOTE: Multi-field selection from nested array inside struct is a known limitation. + // When ColumnPruning creates intermediate aliases for nested fields, our rule cannot + // trace the aliases back to scan attributes. We verify correctness but not pruning. + testExplodePruning("nested array inside struct - multi-field selection (correctness)") { + withWrapperStructData { + // Select wrapper.name and non-contiguous fields x, z from exploded array + val query = sql( + """SELECT wrapper.name, e.x, e.z + |FROM wrapper_struct + |LATERAL VIEW EXPLODE(wrapper.arr) as e""".stripMargin) + + // Pruning doesn't currently work for nested array inside struct with multi-field. + // This test verifies the query executes correctly. + // When this limitation is fixed, update expected schema to: + // "struct>>>" + checkScan(query, + "struct>>>") + + checkAnswer(query, Seq( + Row("first", 10L, 30L), + Row("first", 11L, 31L), + Row("second", 12L, 32L) + )) + } + } + + // Test 3: pos-only posexplode chooses minimal-weight field + protected def withMinimalWeightData(testThunk: => Unit): Unit = { + withTempPath { dir => + val path = dir.getAbsolutePath + + // Struct with fields of different sizes: boolean < int < string + val elemStruct = StructType(Seq( + StructField("b", BooleanType), + StructField("i", IntegerType), + StructField("s", StringType) + )) + + val schema = StructType(Seq( + StructField("id", LongType), + StructField("items", ArrayType(elemStruct)) + )) + + val data = Seq( + Row(1L, Seq(Row(true, 100, "long_string_value_1"), + Row(false, 200, "long_string_value_2"))), + Row(2L, Seq(Row(true, 300, "long_string_value_3"))) + ) + + spark.createDataFrame(spark.sparkContext.parallelize(data), schema) + .write.format(dataSourceName).save(path + "/minimal_weight") + + spark.read.format(dataSourceName).schema(schema).load(path + "/minimal_weight") + .createOrReplaceTempView("minimal_weight") + + testThunk + } + } + + testExplodePruning("pos-only posexplode chooses minimal-weight field") { + withMinimalWeightData { + // Only select position, not any fields from the struct + val query = sql( + """SELECT id, pos + |FROM minimal_weight + |LATERAL VIEW POSEXPLODE(items) as pos, item""".stripMargin) + + // Should scan with minimal-weight field (boolean 'b') + // Not the full struct with int and string + checkScan(query, + "struct>>") + + checkAnswer(query, Seq( + Row(1L, 0), + Row(1L, 1), + Row(2L, 0) + )) + } + } + + // Test 4: Filters on source array fields are preserved + testExplodePruning("filter on source array field preserved in scan") { + withWrapperStructData { + // Select only x, but filter on y - both should be in scan + val query = sql( + """SELECT e.x + |FROM wrapper_struct + |LATERAL VIEW EXPLODE(wrapper.arr) as e + |WHERE wrapper.arr.y IS NOT NULL""".stripMargin) + + // Scan must include both x (projection) and y (filter) + checkScan(query, + "struct>>>") + + checkAnswer(query, Seq( + Row(10L), Row(11L), Row(12L) + )) + } + } + + // Test 5: Non-contiguous fields + ordinal stability + protected def withFiveFieldData(testThunk: => Unit): Unit = { + withTempPath { dir => + val path = dir.getAbsolutePath + + val elemStruct = StructType(Seq( + StructField("a", LongType), + StructField("b", LongType), + StructField("c", LongType), + StructField("d", LongType), + StructField("e", LongType) + )) + + val schema = StructType(Seq( + StructField("id", LongType), + StructField("arr", ArrayType(elemStruct)) + )) + + val data = Seq( + Row(1L, Seq(Row(1L, 2L, 3L, 4L, 5L), Row(10L, 20L, 30L, 40L, 50L))), + Row(2L, Seq(Row(100L, 200L, 300L, 400L, 500L))) + ) + + spark.createDataFrame(spark.sparkContext.parallelize(data), schema) + .write.format(dataSourceName).save(path + "/five_field") + + spark.read.format(dataSourceName).schema(schema).load(path + "/five_field") + .createOrReplaceTempView("five_field") + + testThunk + } + } + + testExplodePruning("non-contiguous fields ordinal stability") { + withFiveFieldData { + // Select fields a, c, e (skip b, d) - tests ordinal fix for gaps + val query = sql( + """SELECT e.a, e.c, e.e + |FROM five_field + |LATERAL VIEW EXPLODE(arr) as e""".stripMargin) + + // Should only include a, c, e in scan schema + checkScan(query, + "struct>>") + + // Verify correct values despite ordinal gaps + checkAnswer(query, Seq( + Row(1L, 3L, 5L), + Row(10L, 30L, 50L), + Row(100L, 300L, 500L) + )) + } + } + + // Test 6: Alias chain (GNA + ColumnPruning interaction) + testExplodePruning("alias chain with GNA interaction") { + withFiveFieldData { + // Subquery with alias creates intermediate Project + // Tests that pruning traces through alias definitions + val query = sql( + """WITH aliased AS ( + | SELECT id, arr as my_arr + | FROM five_field + |) + |SELECT e.a, e.d + |FROM aliased + |LATERAL VIEW EXPLODE(my_arr) as e""".stripMargin) + + // Pruning should still work through the alias + checkScan(query, + "struct>>") + + checkAnswer(query, Seq( + Row(1L, 4L), + Row(10L, 40L), + Row(100L, 400L) + )) + } + } + + // Test 7: Depth-2 nested arrays (array>) + protected def withDepth2NestedData(testThunk: => Unit): Unit = { + withTempPath { dir => + val path = dir.getAbsolutePath + + val innerStruct = StructType(Seq( + StructField("x", IntegerType), + StructField("y", IntegerType), + StructField("z", IntegerType) + )) + + val schema = StructType(Seq( + StructField("id", LongType), + StructField("deep", ArrayType(ArrayType(innerStruct))) + )) + + // array>> + val data = Seq( + Row(1L, Seq( + Seq(Row(1, 10, 100), Row(2, 20, 200)), + Seq(Row(3, 30, 300)) + )), + Row(2L, Seq( + Seq(Row(4, 40, 400)) + )) + ) + + spark.createDataFrame(spark.sparkContext.parallelize(data), schema) + .write.format(dataSourceName).save(path + "/depth2_nested") + + spark.read.format(dataSourceName).schema(schema).load(path + "/depth2_nested") + .createOrReplaceTempView("depth2_nested") + + testThunk + } + } + + testExplodePruning("depth-2 nested arrays - single inner field") { + withDepth2NestedData { + // Explode array> to get array, then explode again + val query = sql( + """SELECT id, inner_elem.x + |FROM depth2_nested + |LATERAL VIEW EXPLODE(deep) as outer_arr + |LATERAL VIEW EXPLODE(outer_arr) as inner_elem""".stripMargin) + + // Inner struct should be pruned to only x (y, z not needed) + checkScan(query, + "struct>>>") + + checkAnswer(query, Seq( + Row(1L, 1), Row(1L, 2), Row(1L, 3), + Row(2L, 4) + )) + } + } + + testExplodePruning("depth-2 nested arrays - multi inner field") { + withDepth2NestedData { + // Select x and z, pruning y + val query = sql( + """SELECT id, inner_elem.x, inner_elem.z + |FROM depth2_nested + |LATERAL VIEW EXPLODE(deep) as outer_arr + |LATERAL VIEW EXPLODE(outer_arr) as inner_elem""".stripMargin) + + // Inner struct pruned to x and z + checkScan(query, + "struct>>>") + + checkAnswer(query, Seq( + Row(1L, 1, 100), Row(1L, 2, 200), Row(1L, 3, 300), + Row(2L, 4, 400) + )) + } + } + + testExplodePruning("depth-2 nested arrays - filter on pruned fields") { + withDepth2NestedData { + // Select x and z (prune y), with filter on x + // Exercises filter ordinal fixup path with GetNestedArrayStructFields + val query = sql( + """SELECT id, inner_elem.x, inner_elem.z + |FROM depth2_nested + |LATERAL VIEW EXPLODE(deep) as outer_arr + |LATERAL VIEW EXPLODE(outer_arr) as inner_elem + |WHERE inner_elem.x > 2""".stripMargin) + + checkScan(query, + "struct>>>") + + checkAnswer(query, Seq( + Row(1L, 3, 300), + Row(2L, 4, 400) + )) + } + } + + testExplodePruning("depth-2 nested arrays - posexplode with pruning") { + withDepth2NestedData { + // POSEXPLODE on depth-2 array, select x only (prune y and z) + val query = sql( + """SELECT id, pos, inner_elem.x + |FROM depth2_nested + |LATERAL VIEW EXPLODE(deep) as outer_arr + |LATERAL VIEW POSEXPLODE(outer_arr) as pos, inner_elem""".stripMargin) + + checkScan(query, + "struct>>>") + + checkAnswer(query, Seq( + Row(1L, 0, 1), Row(1L, 1, 2), Row(1L, 0, 3), + Row(2L, 0, 4) + )) + } + } + + // ========================================================================= + // Leaf filter ordinal fixup tests + // + // These tests verify that filters on the source array column (not on + // exploded elements) have correct ordinals after pruning. When such filters + // reference struct fields via GetArrayStructFields, the ordinals must be + // updated to match the pruned schema. + // ========================================================================= + + testExplodePruning("leaf filter on source array field - ordinal fixup after pruning") { + withMixedArrayData { + // SELECT item.name with filter on items.qty (source array struct field access). + // items.qty generates GetArrayStructFields(items, qty, 2, 3, false) to extract + // the qty field from each struct element. After pruning detail (not needed), + // the ordinal must be rewritten: GetArrayStructFields(_pruned, qty, 1, 2, false). + val query = sql( + """SELECT item.name + |FROM mixed_array + |LATERAL VIEW EXPLODE(items) AS item + |WHERE size(items.qty) > 1""".stripMargin) + + // Scan needs name (projected) and qty (leaf filter via items.qty). + // detail is not referenced anywhere, so it's pruned. + checkScan(query, + "struct>>") + + // id=1: items.qty = [5,3], size=2 > 1 -> apple, banana included + // id=2: items.qty = [1], size=1 > 1 -> cherry excluded + checkAnswer(query, + Row("apple") :: Row("banana") :: Nil) + } + } + + testExplodePruning("depth-2 nested arrays - leaf filter on source column") { + withDepth2NestedData { + // Filter on source column (deep IS NOT NULL) is a leaf filter that + // references the source array directly. After pruning to struct, + // the direct reference is rewritten to use the pruned attribute. + val query = sql( + """SELECT id, inner_elem.x + |FROM depth2_nested + |LATERAL VIEW EXPLODE(deep) as outer_arr + |LATERAL VIEW EXPLODE(outer_arr) as inner_elem + |WHERE deep IS NOT NULL""".stripMargin) + + checkScan(query, + "struct>>>") + + checkAnswer(query, Seq( + Row(1L, 1), Row(1L, 2), Row(1L, 3), + Row(2L, 4) + )) + } + } + + // ========================================================================= + // OUTER POSEXPLODE with aggregation tests + // + // These tests verify that schema pruning works correctly when aggregation + // is combined with OUTER POSEXPLODE. The PruneNestedFieldsThroughGenerateForScan + // rule properly tracks field requirements through Aggregate nodes. + // ========================================================================= + + // Test with single OUTER POSEXPLODE + aggregation - works correctly + testExplodePruning("OUTER POSEXPLODE with aggregation - large struct") { + withSampleData { + val query = sql( + """SELECT complex.col1, COUNT(*) as cnt, SUM(complex.col1) as total + |FROM sample + |LATERAL VIEW OUTER POSEXPLODE(someComplexArray) AS idx, complex + |WHERE complex.col1 IS NOT NULL + |GROUP BY complex.col1""".stripMargin) + + // Only col1 needed, col2 should be pruned + checkScan(query, + "struct>>") + + checkAnswer(query, Row(1L, 1, 1L) :: Nil) + } + } + + // ========================================================================= + // Advanced SQL Constructs - Window Functions + // ========================================================================= + + testExplodePruning("window function - ROW_NUMBER over exploded data") { + withDoubleNestedData { + val query = sql( + """SELECT a, a_item.b, a_item.b_int, + | ROW_NUMBER() OVER (PARTITION BY a ORDER BY a_item.b_int) as rn + |FROM double_nested + |LATERAL VIEW EXPLODE(a_array) AS a_item""".stripMargin) + + // Only b and b_int needed from a_array element, b_string and b_array pruned + checkScan(query, + "struct>>") + + // Check results include row numbers + val result = query.collect() + assert(result.length == 4) + // ROW_NUMBER returns Int, not Long + assert(result.map(_.getInt(3)).toSet == Set(1, 2)) + } + } + + testExplodePruning("window function - LAG/LEAD over exploded data") { + withDoubleNestedData { + val query = sql( + """SELECT a, a_item.b, + | LAG(a_item.b_int, 1) OVER (PARTITION BY a ORDER BY a_item.b_int) as prev_val, + | LEAD(a_item.b_int, 1) OVER (PARTITION BY a ORDER BY a_item.b_int) as next_val + |FROM double_nested + |LATERAL VIEW EXPLODE(a_array) AS a_item""".stripMargin) + + checkScan(query, + "struct>>") + + query.collect() + } + } + + testExplodePruning("window function - SUM OVER with exploded nested data") { + withDoubleNestedData { + val query = sql( + """SELECT a, a_item.b, b_item.c_int, + | SUM(b_item.c_int) OVER (PARTITION BY a ORDER BY a_item.b) as running_sum + |FROM double_nested + |LATERAL VIEW EXPLODE(a_array) AS a_item + |LATERAL VIEW EXPLODE(a_item.b_array) AS b_item""".stripMargin) + + checkScan(query, + "struct>>>>") + + query.collect() + } + } + + testExplodePruning("window function - RANK and DENSE_RANK") { + withSampleData { + val query = sql( + """SELECT someStr, complex.col1, + | RANK() OVER (ORDER BY complex.col1) as rnk, + | DENSE_RANK() OVER (ORDER BY complex.col1) as dense_rnk + |FROM sample + |LATERAL VIEW EXPLODE(someComplexArray) AS complex""".stripMargin) + + checkScan(query, + "struct>>") + + checkAnswer(query, Row("bla", 1L, 1, 1) :: Nil) + } + } + + // ========================================================================= + // Advanced SQL Constructs - Higher-Order Functions + // ========================================================================= + + testExplodePruning("higher-order function - TRANSFORM on exploded array") { + withSampleData { + val query = sql( + """SELECT someStr, + | TRANSFORM(someComplexArray, x -> x.col1 * 2) as doubled + |FROM sample""".stripMargin) + + // TRANSFORM uses higher-order functions, not generators - full struct is read + // This is a known limitation: higher-order function pruning is separate from + // generator-based pruning handled by PruneNestedFieldsThroughGenerateForScan + checkScan(query, + "struct>>") + + checkAnswer(query, Row("bla", Array(2L)) :: Nil) + } + } + + testExplodePruning("higher-order function - FILTER on array") { + withSampleData { + val query = sql( + """SELECT someStr, + | FILTER(someComplexArray, x -> x.col1 > 0) as filtered + |FROM sample""".stripMargin) + + checkScan(query, + "struct>>") + + checkAnswer(query, Row("bla", Array(Row(1L, 2L))) :: Nil) + } + } + + testExplodePruning("higher-order function - AGGREGATE on array") { + withSampleData { + val query = sql( + """SELECT someStr, + | AGGREGATE(someComplexArray, 0L, (acc, x) -> acc + x.col1) as total + |FROM sample""".stripMargin) + + // AGGREGATE uses higher-order functions, not generators - full struct is read + checkScan(query, + "struct>>") + + checkAnswer(query, Row("bla", 1L) :: Nil) + } + } + + testExplodePruning("higher-order function - EXISTS on array") { + withSampleData { + val query = sql( + """SELECT someStr, + | EXISTS(someComplexArray, x -> x.col1 > 0) as has_positive + |FROM sample""".stripMargin) + + // EXISTS uses higher-order functions, not generators - full struct is read + checkScan(query, + "struct>>") + + checkAnswer(query, Row("bla", true) :: Nil) + } + } + + // ========================================================================= + // Advanced SQL Constructs - UNION Operations + // ========================================================================= + + testExplodePruning("UNION ALL with exploded data") { + withDoubleNestedData { + val query = sql( + """SELECT a, a_item.b FROM double_nested + |LATERAL VIEW EXPLODE(a_array) AS a_item + |WHERE a = 'a1' + |UNION ALL + |SELECT a, a_item.b FROM double_nested + |LATERAL VIEW EXPLODE(a_array) AS a_item + |WHERE a = 'a2'""".stripMargin) + + // Both branches should prune to same schema + checkScan(query, + "struct>>", + "struct>>") + + assert(query.collect().length == 4) + } + } + + testExplodePruning("UNION with different field selections") { + withDoubleNestedData { + val query = sql( + """SELECT a, a_item.b as field FROM double_nested + |LATERAL VIEW EXPLODE(a_array) AS a_item + |UNION + |SELECT a, a_item.b_string as field FROM double_nested + |LATERAL VIEW EXPLODE(a_array) AS a_item""".stripMargin) + + // First branch needs b, second needs b_string + checkScan(query, + "struct>>", + "struct>>") + + query.collect() + } + } + + // ========================================================================= + // Advanced SQL Constructs - JOIN Operations + // ========================================================================= + + testExplodePruning("INNER JOIN on exploded data") { + withDoubleNestedData { + // Create exploded views for join - LATERAL VIEW must be part of the FROM clause + spark.sql( + """SELECT a, a_item.b, a_item.b_int + |FROM double_nested + |LATERAL VIEW EXPLODE(a_array) AS a_item""".stripMargin + ).createOrReplaceTempView("exploded1") + + spark.sql( + """SELECT a as a2, a_item.b as b2, a_item.b_int as b_int2 + |FROM double_nested + |LATERAL VIEW EXPLODE(a_array) AS a_item""".stripMargin + ).createOrReplaceTempView("exploded2") + + val query = sql( + """SELECT e1.a, e1.b, e2.b_int2 + |FROM exploded1 e1 + |INNER JOIN exploded2 e2 + |ON e1.a = e2.a2 AND e1.b = e2.b2""".stripMargin) + + // Each branch should be pruned independently + checkScan(query, + "struct>>", + "struct>>") + + query.collect() + } + } + + testExplodePruning("LEFT JOIN with exploded data and NULL handling") { + withDoubleNestedData { + // Create exploded views - LATERAL VIEW must be part of the FROM clause + spark.sql( + """SELECT a, a_item.b, a_item.b_int + |FROM double_nested + |LATERAL VIEW EXPLODE(a_array) AS a_item + |WHERE a = 'a1'""".stripMargin + ).createOrReplaceTempView("left_exploded") + + spark.sql( + """SELECT a_item.b_int as r_b_int + |FROM double_nested + |LATERAL VIEW EXPLODE(a_array) AS a_item + |WHERE a = 'a2'""".stripMargin + ).createOrReplaceTempView("right_exploded") + + val query = sql( + """SELECT l.a, l.b, r.r_b_int + |FROM left_exploded l + |LEFT JOIN right_exploded r + |ON l.b_int = r.r_b_int""".stripMargin) + + // Both branches need 'a' because of WHERE clause filter + checkScan(query, + "struct>>", + "struct>>") + + query.collect() + } + } + + // ========================================================================= + // Advanced SQL Constructs - Subqueries + // ========================================================================= + + testExplodePruning("scalar subquery with exploded data") { + withDoubleNestedData { + val query = sql( + """SELECT a, + | (SELECT MAX(sub_item.b_int) + | FROM double_nested sub + | LATERAL VIEW EXPLODE(sub.a_array) AS sub_item + | WHERE sub.a = double_nested.a) as max_b_int + |FROM double_nested""".stripMargin) + + // Main query needs a, subquery needs a and b_int + query.collect() + } + } + + testExplodePruning("IN subquery with exploded data") { + withDoubleNestedData { + val query = sql( + """SELECT a, a_item.b + |FROM double_nested + |LATERAL VIEW EXPLODE(a_array) AS a_item + |WHERE a_item.b_int IN ( + | SELECT b_item.c_int + | FROM double_nested sub + | LATERAL VIEW EXPLODE(sub.a_array) AS sub_a_item + | LATERAL VIEW EXPLODE(sub_a_item.b_array) AS b_item + |)""".stripMargin) + + query.collect() + } + } + + testExplodePruning("EXISTS subquery with exploded data") { + withDoubleNestedData { + val query = sql( + """SELECT a, a_item.b + |FROM double_nested + |LATERAL VIEW EXPLODE(a_array) AS a_item + |WHERE EXISTS ( + | SELECT 1 + | FROM double_nested sub + | LATERAL VIEW EXPLODE(sub.a_array) AS sub_item + | WHERE sub_item.b_int > a_item.b_int + |)""".stripMargin) + + query.collect() + } + } + + // ========================================================================= + // Advanced SQL Constructs - DISTINCT and HAVING + // ========================================================================= + + testExplodePruning("SELECT DISTINCT with exploded fields") { + withDoubleNestedData { + val query = sql( + """SELECT DISTINCT a_item.b_int + |FROM double_nested + |LATERAL VIEW EXPLODE(a_array) AS a_item""".stripMargin) + + checkScan(query, + "struct>>") + + checkAnswer(query, Row(1L) :: Row(2L) :: Row(3L) :: Row(4L) :: Nil) + } + } + + testExplodePruning("GROUP BY with HAVING on exploded data") { + withDoubleNestedData { + val query = sql( + """SELECT a, a_item.b, COUNT(*) as cnt + |FROM double_nested + |LATERAL VIEW EXPLODE(a_array) AS a_item + |LATERAL VIEW EXPLODE(a_item.b_array) AS b_item + |GROUP BY a, a_item.b + |HAVING COUNT(*) >= 2""".stripMargin) + + // Inner array can't be pruned to empty struct - full inner struct is read + checkScan(query, + "struct>>>>") + + val results = query.collect() + assert(results.forall(_.getLong(2) >= 2)) + } + } + + // ========================================================================= + // Advanced SQL Constructs - CUBE/ROLLUP/GROUPING SETS + // ========================================================================= + + testExplodePruning("ROLLUP with exploded data") { + withDoubleNestedData { + val query = sql( + """SELECT a, a_item.b, SUM(a_item.b_int) as total + |FROM double_nested + |LATERAL VIEW EXPLODE(a_array) AS a_item + |GROUP BY ROLLUP(a, a_item.b)""".stripMargin) + + // ROLLUP changes plan structure - full struct is read + checkScan(query, + "struct>>>>") + + query.collect() + } + } + + testExplodePruning("CUBE with exploded data") { + withDoubleNestedData { + val query = sql( + """SELECT a, a_item.b, COUNT(*) as cnt + |FROM double_nested + |LATERAL VIEW EXPLODE(a_array) AS a_item + |GROUP BY CUBE(a, a_item.b)""".stripMargin) + + checkScan(query, + "struct>>") + + query.collect() + } + } + + testExplodePruning("GROUPING SETS with exploded data") { + withDoubleNestedData { + val query = sql( + """SELECT a, a_item.b, a_item.b_int, COUNT(*) as cnt + |FROM double_nested + |LATERAL VIEW EXPLODE(a_array) AS a_item + |GROUP BY GROUPING SETS ((a), (a_item.b), (a, a_item.b, a_item.b_int))""".stripMargin) + + checkScan(query, + "struct>>") + + query.collect() + } + } + + // ========================================================================= + // Advanced SQL Constructs - CASE WHEN Expressions + // ========================================================================= + + testExplodePruning("CASE WHEN with exploded fields") { + withDoubleNestedData { + val query = sql( + """SELECT a, + | CASE + | WHEN a_item.b_int > 2 THEN 'high' + | WHEN a_item.b_int > 1 THEN 'medium' + | ELSE 'low' + | END as category, + | a_item.b + |FROM double_nested + |LATERAL VIEW EXPLODE(a_array) AS a_item""".stripMargin) + + checkScan(query, + "struct>>") + + query.collect() + } + } + + testExplodePruning("nested CASE WHEN with multiple explode levels") { + withDoubleNestedData { + val query = sql( + """SELECT a, + | CASE + | WHEN b_item.c_int > 0 THEN a_item.b + | ELSE 'unknown' + | END as result + |FROM double_nested + |LATERAL VIEW EXPLODE(a_array) AS a_item + |LATERAL VIEW EXPLODE(a_item.b_array) AS b_item""".stripMargin) + + checkScan(query, + "struct>>>>") + + query.collect() + } + } + + // ========================================================================= + // Advanced SQL Constructs - Complex CTEs + // ========================================================================= + + testExplodePruning("multiple CTEs with exploded data") { + withDoubleNestedData { + val query = sql( + """WITH cte1 AS ( + | SELECT a, a_item.b, a_item.b_int + | FROM double_nested + | LATERAL VIEW EXPLODE(a_array) AS a_item + |), + |cte2 AS ( + | SELECT a, b, SUM(b_int) as total + | FROM cte1 + | GROUP BY a, b + |) + |SELECT * FROM cte2 WHERE total > 0""".stripMargin) + + checkScan(query, + "struct>>") + + query.collect() + } + } + + testExplodePruning("CTE with window function over exploded data") { + withDoubleNestedData { + val query = sql( + """WITH ranked AS ( + | SELECT a, a_item.b, a_item.b_int, + | ROW_NUMBER() OVER (PARTITION BY a ORDER BY a_item.b_int DESC) as rn + | FROM double_nested + | LATERAL VIEW EXPLODE(a_array) AS a_item + |) + |SELECT a, b, b_int FROM ranked WHERE rn = 1""".stripMargin) + + checkScan(query, + "struct>>") + + checkAnswer(query, Row("a1", "a1_b2", 2L) :: Row("a2", "a2_b2", 4L) :: Nil) + } + } + + // ========================================================================= + // Advanced SQL Constructs - ORDER BY and LIMIT + // ========================================================================= + + testExplodePruning("ORDER BY on exploded nested field") { + withDoubleNestedData { + val query = sql( + """SELECT a, a_item.b, b_item.c + |FROM double_nested + |LATERAL VIEW EXPLODE(a_array) AS a_item + |LATERAL VIEW EXPLODE(a_item.b_array) AS b_item + |ORDER BY b_item.c_int DESC + |LIMIT 5""".stripMargin) + + checkScan(query, + "struct>>>>") + + val result = query.collect() + assert(result.length == 5) + } + } + + testExplodePruning("complex ORDER BY with multiple explode levels") { + withDoubleNestedData { + val query = sql( + """SELECT a, a_item.b, b_item.c_int + |FROM double_nested + |LATERAL VIEW EXPLODE(a_array) AS a_item + |LATERAL VIEW EXPLODE(a_item.b_array) AS b_item + |ORDER BY a DESC, a_item.b_int ASC, b_item.c_int DESC""".stripMargin) + + checkScan(query, + "struct>>>>") + + query.collect() + } + } + + // ========================================================================= + // Advanced SQL Constructs - COALESCE and NULL handling + // ========================================================================= + + testExplodePruning("COALESCE with exploded fields") { + withDoubleNestedData { + val query = sql( + """SELECT a, + | COALESCE(a_item.b, 'default') as b_value, + | COALESCE(a_item.b_int, 0) as b_int_value + |FROM double_nested + |LATERAL VIEW OUTER EXPLODE(a_array) AS a_item""".stripMargin) + + checkScan(query, + "struct>>") + + query.collect() + } + } + + testExplodePruning("IFNULL and NULLIF with exploded data") { + withDoubleNestedData { + val query = sql( + """SELECT a, + | IFNULL(a_item.b_int, -1) as safe_int, + | NULLIF(a_item.b, '') as non_empty_b + |FROM double_nested + |LATERAL VIEW EXPLODE(a_array) AS a_item""".stripMargin) + + checkScan(query, + "struct>>") + + query.collect() + } + } + + // ========================================================================= + // Advanced SQL Constructs - String Functions + // ========================================================================= + + testExplodePruning("string functions on exploded fields") { + withDoubleNestedData { + val query = sql( + """SELECT a, + | UPPER(a_item.b) as upper_b, + | LENGTH(a_item.b_string) as len, + | CONCAT(a_item.b, '-', a_item.b_string) as combined + |FROM double_nested + |LATERAL VIEW EXPLODE(a_array) AS a_item""".stripMargin) + + checkScan(query, + "struct>>") + + query.collect() + } + } + + // ========================================================================= + // Advanced SQL Constructs - Date/Time Functions + // ========================================================================= + + testExplodePruning("arithmetic on exploded numeric fields") { + withDoubleNestedData { + val query = sql( + """SELECT a, + | a_item.b_int + 100 as adjusted, + | a_item.b_int * 2 as doubled, + | CAST(a_item.b_int AS DOUBLE) / 3.0 as third + |FROM double_nested + |LATERAL VIEW EXPLODE(a_array) AS a_item""".stripMargin) + + checkScan(query, + "struct>>") + + query.collect() + } + } + + // ========================================================================= + // Advanced SQL Constructs - Array Aggregate Functions + // ========================================================================= + + testExplodePruning("COLLECT_LIST on exploded data") { + withDoubleNestedData { + val query = sql( + """SELECT a, COLLECT_LIST(a_item.b) as all_bs + |FROM double_nested + |LATERAL VIEW EXPLODE(a_array) AS a_item + |GROUP BY a""".stripMargin) + + checkScan(query, + "struct>>") + + val result = query.collect() + assert(result.length == 2) + } + } + + testExplodePruning("COLLECT_SET on nested exploded data") { + withDoubleNestedData { + val query = sql( + """SELECT a, COLLECT_SET(b_item.c_int) as unique_c_ints + |FROM double_nested + |LATERAL VIEW EXPLODE(a_array) AS a_item + |LATERAL VIEW EXPLODE(a_item.b_array) AS b_item + |GROUP BY a""".stripMargin) + + checkScan(query, + "struct>>>>") + + query.collect() + } + } + + // ========================================================================= + // Infrastructure Layer Pattern - CTE with SELECT * + // ========================================================================= + + // This test validates a common real-world pattern where an "infra" layer + // defines a base CTE using SELECT * for multiple explode levels, but the + // final query only selects specific columns. Schema pruning should still + // apply based on what the outer query actually uses. + testExplodePruning("infrastructure layer CTE with SELECT * and specific outer select") { + withDoubleNestedData { + val query = sql( + """WITH exploded AS ( + | SELECT * + | FROM double_nested + | LATERAL VIEW OUTER EXPLODE(a_array) AS a_item + | LATERAL VIEW OUTER EXPLODE(a_item.b_array) AS b_item + |) + |SELECT + | a, + | a_item.b, + | b_item.c_int, + | COUNT(*) as cnt, + | SUM(CASE WHEN a_item.b_int > 1 THEN 1 ELSE 0 END) as high_count + |FROM exploded + |WHERE a_bool = true + |GROUP BY a, a_item.b, b_item.c_int + |ORDER BY cnt DESC + |LIMIT 10""".stripMargin) + + // Despite SELECT * in CTE, only used fields from arrays should be read. + // Note: The filter on a_bool may be evaluated after scan depending on + // optimization order, so a_bool might not be in scan schema. + // Key verification: a_array struct fields ARE pruned (no b_string, c, c_2) + checkScan(query, + "struct>>>>") + + query.collect() + } + } + + // Simpler variant to verify the basic pattern works + testExplodePruning("CTE with SELECT * - single explode level") { + withDoubleNestedData { + val query = sql( + """WITH base AS ( + | SELECT * + | FROM double_nested + | LATERAL VIEW EXPLODE(a_array) AS item + |) + |SELECT a, item.b, item.b_int + |FROM base + |WHERE a_bool = true""".stripMargin) + + // Key verification: a_array struct fields ARE pruned (no b_string, b_array) + // Note: a_bool filter may be evaluated after scan + checkScan(query, + "struct>>") + + query.collect() + } + } + + // ========================================================================= + // Regression tests: non-struct Generate above struct chain (Pattern 3) + // + // These verify that when a non-struct Generate (e.g., EXPLODE(tags) on + // array) sits above a struct Generate chain (EXPLODE(items) on + // array), the scan includes the scalar array column even when it + // is not explicitly selected, and Generate.unrequiredChildIndex is + // preserved correctly. + // ========================================================================= + + testExplodePruning("mixed arrays - non-struct generate above struct chain") { + withMixedArrayData { + // Regression: tags is NOT in SELECT but is needed by EXPLODE(tags). + // Before fix, this failed with INTERNAL_ERROR_ATTRIBUTE_NOT_FOUND + // because tags was dropped from scan and from Generate child output. + val query = sql( + """SELECT item.name, item.detail.color, tag + |FROM mixed_array + |LATERAL VIEW EXPLODE(items) AS item + |LATERAL VIEW EXPLODE(tags) AS tag""".stripMargin) + + // tags must be in scan schema; qty is pruned (top-level field). + // detail is pruned to struct since only detail.color is accessed. + checkScan(query, + "struct>>,tags:array>") + + checkAnswer(query, + Row("apple", "red", "fruit") :: + Row("apple", "red", "food") :: + Row("banana", "yellow", "fruit") :: + Row("banana", "yellow", "food") :: + Row("cherry", "dark red", "berry") :: Nil) + } + } + + testExplodePruning("mixed arrays - pass-through column from non-struct generate") { + withMixedArrayData { + // tags appears both as a pass-through column AND as the generator source + val query = sql( + """SELECT tags, item.name, item.detail.color, tag + |FROM mixed_array + |LATERAL VIEW EXPLODE(items) AS item + |LATERAL VIEW EXPLODE(tags) AS tag""".stripMargin) + + checkScan(query, + "struct>>,tags:array>") + + checkAnswer(query, + Row(Seq("fruit", "food"), "apple", "red", "fruit") :: + Row(Seq("fruit", "food"), "apple", "red", "food") :: + Row(Seq("fruit", "food"), "banana", "yellow", "fruit") :: + Row(Seq("fruit", "food"), "banana", "yellow", "food") :: + Row(Seq("berry"), "cherry", "dark red", "berry") :: Nil) + } + } + + testExplodePruning("mixed arrays - scalar column with non-struct generate") { + withMixedArrayData { + // id (scalar) + tags (array for non-struct generate) + pruned struct fields + val query = sql( + """SELECT id, item.name, item.detail.color, tag + |FROM mixed_array + |LATERAL VIEW EXPLODE(items) AS item + |LATERAL VIEW EXPLODE(tags) AS tag""".stripMargin) + + checkScan(query, + "struct>>,tags:array>") + + checkAnswer(query, + Row(1, "apple", "red", "fruit") :: + Row(1, "apple", "red", "food") :: + Row(1, "banana", "yellow", "fruit") :: + Row(1, "banana", "yellow", "food") :: + Row(2, "cherry", "dark red", "berry") :: Nil) + } + } + + testExplodePruning("mixed arrays - nested sub-field pruning with all top-level fields") { + withMixedArrayData { + // All top-level struct fields (name, detail, qty) are selected so no + // top-level pruning. detail is pruned to struct (size not accessed). + val query = sql( + """SELECT item.name, item.detail.color, item.qty, tag + |FROM mixed_array + |LATERAL VIEW EXPLODE(items) AS item + |LATERAL VIEW EXPLODE(tags) AS tag""".stripMargin) + + checkScan(query, + "struct,qty:int>>,tags:array>") + + checkAnswer(query, + Row("apple", "red", 5, "fruit") :: + Row("apple", "red", 5, "food") :: + Row("banana", "yellow", 3, "fruit") :: + Row("banana", "yellow", 3, "food") :: + Row("cherry", "dark red", 1, "berry") :: Nil) + } + } + + // Nested struct pruning within array element - prunes both top-level and nested fields + testExplodePruning("nested struct pruning - prune sub-fields of struct inside array") { + withMixedArrayData { + // Select name and detail.color only: prunes qty (top-level) AND detail.size (nested) + val query = sql( + """SELECT item.name, item.detail.color + |FROM mixed_array + |LATERAL VIEW EXPLODE(items) AS item""".stripMargin) + + // detail should be pruned from struct to struct + checkScan(query, + "struct>>>") + + checkAnswer(query, + Row("apple", "red") :: + Row("banana", "yellow") :: + Row("cherry", "dark red") :: Nil) + } + } + + testExplodePruning("nested struct pruning - only nested sub-field selected") { + withMixedArrayData { + // Select only detail.color: prunes name, qty (top-level) AND detail.size (nested) + val query = sql( + """SELECT item.detail.color + |FROM mixed_array + |LATERAL VIEW EXPLODE(items) AS item""".stripMargin) + + checkScan(query, + "struct>>>") + + checkAnswer(query, + Row("red") :: Row("yellow") :: Row("dark red") :: Nil) + } + } + + testExplodePruning("nested struct pruning - with filter on nested sub-field") { + withMixedArrayData { + // Select name, filter on detail.size: prunes qty but keeps detail.color + detail.size + val query = sql( + """SELECT item.name, item.detail.color + |FROM mixed_array + |LATERAL VIEW EXPLODE(items) AS item + |WHERE item.detail.size > 1""".stripMargin) + + // detail needs both color (projected) and size (filtered), so no nested pruning on detail + checkScan(query, + "struct>>>") + + checkAnswer(query, + Row("apple", "red") :: Row("banana", "yellow") :: Nil) + } + } + + // ============================================================================ + // Non-consecutive generates: generates separated by Aggregate nodes + // Verifies the rule handles broken chains gracefully and pruning still works + // ============================================================================ + + testExplodePruning("generate above aggregate - posexplode after group by") { + withMixedArrayData { + // Aggregate passes the array through, then posexplode selects only some fields. + // The Aggregate (first(items)) prevents scan-level pruning of array element fields + // because the aggregate function needs the full array value. + val query = sql( + """SELECT pos, elem.name + |FROM ( + | SELECT first(items) as items + | FROM mixed_array + | GROUP BY id + |) t + |LATERAL VIEW POSEXPLODE(items) AS pos, elem""".stripMargin) + + // Scan reads full items array: the Aggregate blocks schema pruning push-down. + // The Generate chain is broken by the Aggregate, but the system handles this + // correctly - no crashes, correct results. + checkScan(query, + "struct,qty:int>>>") + + checkAnswer(query, + Row(0, "apple") :: Row(1, "banana") :: Row(0, "cherry") :: Nil) + } + } + + testExplodePruning("non-consecutive generates separated by aggregate") { + withMixedArrayData { + // Pattern: Scan -> EXPLODE -> GROUP BY -> POSEXPLODE -> Project + // The Aggregate breaks the generate chain into two independent generates. + val query = sql( + """SELECT category, pos, agg_item.name + |FROM ( + | SELECT item.detail.color as category, + | collect_list(named_struct('name', item.name, 'qty', item.qty)) as agg_items + | FROM mixed_array + | LATERAL VIEW EXPLODE(items) AS item + | GROUP BY item.detail.color + |) t + |LATERAL VIEW POSEXPLODE(agg_items) AS pos, agg_item""".stripMargin) + + // The inner EXPLODE needs name, qty, detail.color from items. + // The outer POSEXPLODE is on collect_list output (not a scan array). + checkScan(query, + "struct,qty:int>>>") + + checkAnswer(query, + Row("dark red", 0, "cherry") :: + Row("red", 0, "apple") :: + Row("yellow", 0, "banana") :: Nil) + } + } + + testExplodePruning( + "full pipeline: scan -> filter -> agg -> explode -> agg -> filter -> posexplode -> agg") { + withMixedArrayData { + // Full pipeline: Scan -> Filter -> Aggregate -> EXPLODE -> + // GROUP BY + collect_list -> Filter -> POSEXPLODE -> Aggregate + // Tests that the system handles a complex pipeline with multiple generates + // separated by aggregates and filters without crashing, producing correct results. + val query = sql( + """WITH filtered_data AS ( + | SELECT flatten(collect_list(items)) as all_items + | FROM mixed_array + | WHERE id > 0 + |), + |exploded AS ( + | SELECT item.name, item.detail.color as color, item.qty + | FROM filtered_data + | LATERAL VIEW EXPLODE(all_items) AS item + |), + |grouped AS ( + | SELECT color, + | collect_list(named_struct('name', name, 'color', color, 'qty', qty)) as agg_items + | FROM exploded + | GROUP BY color + |), + |filtered_groups AS ( + | SELECT * FROM grouped WHERE size(agg_items) >= 1 + |), + |posexploded AS ( + | SELECT color, pos, agg_item.name, agg_item.qty + | FROM filtered_groups + | LATERAL VIEW POSEXPLODE(agg_items) AS pos, agg_item + |) + |SELECT color, count(*) as cnt, sum(qty) as total_qty + |FROM posexploded + |GROUP BY color""".stripMargin) + + // Aggregates block scan-level nested pruning: the full items array is read. + // detail.size is not used but cannot be pruned through flatten(collect_list(...)). + checkScan(query, + "struct,qty:int>>>") + + checkAnswer(query, + Row("dark red", 1, 1) :: + Row("red", 1, 5) :: + Row("yellow", 1, 3) :: Nil) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala index 269990d7d14e8..b30406aee3865 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala @@ -369,13 +369,14 @@ abstract class SchemaPruningSuite checkScan(query1, "struct>>") checkAnswer(query1, Row("Susan") :: Nil) - // Currently we don't prune multiple field case. + // Multi-field case: should now prune to only first and middle. val query2 = spark.table("contacts") .select(explode(col("friends")).as("friend")) .select("friend.first", "friend.middle") - checkScan(query2, "struct>>") + checkScan(query2, "struct>>") checkAnswer(query2, Row("Susan", "Z.") :: Nil) + // When the whole struct is also selected, no pruning is possible. val query3 = spark.table("contacts") .select(explode(col("friends")).as("friend")) .select("friend.first", "friend.middle", "friend") @@ -404,10 +405,10 @@ abstract class SchemaPruningSuite checkScan(query1, "struct>>") checkAnswer(query1, Row("Susan") :: Nil) - // Currently we don't prune multiple field case. + // Multi-field case: should now prune to only first and middle. val query2 = sql( "select friend.first, friend.middle from contacts, lateral explode(friends) t(friend)") - checkScan(query2, "struct>>") + checkScan(query2, "struct>>") checkAnswer(query2, Row("Susan", "Z.") :: Nil) val query3 = sql( @@ -1178,4 +1179,101 @@ abstract class SchemaPruningSuite checkAnswer(mapQuery, Row(0, null) :: Row(1, null) :: Row(null, null) :: Row(null, null) :: Nil) } + + // ---- Tests for PruneNestedFieldsThroughGenerateForScan ---- + + testSchemaPruning( + "SPARK-47230: multi-field nested column prune on explode generator output") { + // Two fields from the generator output → should prune to only those fields + val query = spark.table("contacts") + .select(explode(col("friends")).as("friend")) + .select("friend.first", "friend.last") + checkScan(query, "struct>>") + checkAnswer(query, Row("Susan", "Smith") :: Nil) + } + + testSchemaPruning( + "SPARK-47230: multi-field prune with pass-through columns") { + // explode + pass-through of a non-array column + val query = spark.table("contacts") + .select(col("id"), explode(col("friends")).as("friend")) + .select("id", "friend.first", "friend.middle") + checkScan(query, + "struct>>") + checkAnswer(query, Row(0, "Susan", "Z.") :: Nil) + } + + testSchemaPruning( + "SPARK-47230: no pruning when whole struct is referenced directly") { + // friend is referenced directly → can't prune + val query = spark.table("contacts") + .select(explode(col("friends")).as("friend")) + .select("friend") + checkScan(query, + "struct>>") + checkAnswer(query, Row(Row("Susan", "Z.", "Smith")) :: Nil) + } + + testSchemaPruning( + "SPARK-47230: multi-field prune with filter on generator output") { + // Filter above Generate that references generator output + val query = spark.table("contacts") + .select(explode(col("friends")).as("friend")) + .where("friend.first = 'Susan'") + .select("friend.first", "friend.last") + checkScan(query, "struct>>") + checkAnswer(query, Row("Susan", "Smith") :: Nil) + } + + testSchemaPruning( + "SPARK-47230: all fields selected means no pruning") { + // Selecting all three fields → no pruning needed + val query = spark.table("contacts") + .select(explode(col("friends")).as("friend")) + .select("friend.first", "friend.middle", "friend.last") + checkScan(query, + "struct>>") + checkAnswer(query, Row("Susan", "Z.", "Smith") :: Nil) + } + + testSchemaPruning( + "SPARK-47230: ordinal stability after pruning non-contiguous fields") { + // Select first (ordinal 0) and last (ordinal 2) - skipping middle (ordinal 1) + // After pruning the struct becomes , ordinals should be 0 and 1 + val query = spark.table("contacts") + .select(explode(col("friends")).as("friend")) + .select("friend.first", "friend.last") + checkScan(query, "struct>>") + checkAnswer(query, Row("Susan", "Smith") :: Nil) + } + + testSchemaPruning( + "SPARK-47230: multi-field prune via lateral explode") { + val query = sql( + "select friend.first, friend.last from contacts, lateral explode(friends) t(friend)") + checkScan(query, "struct>>") + checkAnswer(query, Row("Susan", "Smith") :: Nil) + } + + testSchemaPruning( + "SPARK-47230: posexplode multi-field prune") { + val query = sql( + "select pos, friend.first, friend.middle " + + "from contacts, lateral posexplode(friends) t(pos, friend)") + checkScan(query, + "struct>>") + checkAnswer(query, Row(0, "Susan", "Z.") :: Nil) + } + + testSchemaPruning( + "SPARK-47230: posexplode pos-only selects minimal-weight field") { + // Only pos is referenced; the rule picks the lightest struct field + // All three fields are StringType (defaultSize=20), so tie-break by + // name: "first" < "last" < "middle" + val query = sql( + "select pos from contacts, lateral posexplode(friends) t(pos, friend)") + checkScan(query, + "struct>>") + checkAnswer(query, Row(0) :: Nil) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetExplodeNestedSchemaPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetExplodeNestedSchemaPruningSuite.scala new file mode 100644 index 0000000000000..b41c0ca734d0e --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetExplodeNestedSchemaPruningSuite.scala @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.parquet + +import org.apache.spark.SparkConf +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.catalyst.parser.CatalystSqlParser +import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper +import org.apache.spark.sql.execution.datasources.ExplodeNestedSchemaPruningSuite +import org.apache.spark.sql.execution.datasources.v2.BatchScanExec +import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.tags.ExtendedSQLTest + +abstract class ParquetExplodeNestedSchemaPruningSuite + extends ExplodeNestedSchemaPruningSuite with AdaptiveSparkPlanHelper { + override protected val dataSourceName: String = "parquet" + override protected val vectorizedReaderEnabledKey: String = + SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key + override protected val vectorizedReaderNestedEnabledKey: String = + SQLConf.PARQUET_VECTORIZED_READER_NESTED_COLUMN_ENABLED.key +} + +@ExtendedSQLTest +class ParquetV1ExplodeNestedSchemaPruningSuite + extends ParquetExplodeNestedSchemaPruningSuite { + override protected def sparkConf: SparkConf = + super + .sparkConf + .set(SQLConf.USE_V1_SOURCE_LIST, "parquet") +} + +@ExtendedSQLTest +class ParquetV2ExplodeNestedSchemaPruningSuite + extends ParquetExplodeNestedSchemaPruningSuite { + override protected def sparkConf: SparkConf = + super + .sparkConf + .set(SQLConf.USE_V1_SOURCE_LIST, "") + + override def checkScanSchemata( + df: DataFrame, expectedSchemaCatalogStrings: String*): Unit = { + val fileSourceScanSchemata = + collect(df.queryExecution.executedPlan) { + case scan: BatchScanExec => + scan.scan.asInstanceOf[ParquetScan].readDataSchema + } + assert(fileSourceScanSchemata.size === expectedSchemaCatalogStrings.size, + s"Found ${fileSourceScanSchemata.size} file sources in dataframe, " + + s"but expected $expectedSchemaCatalogStrings") + fileSourceScanSchemata.zip(expectedSchemaCatalogStrings).foreach { + case (scanSchema, expectedScanSchemaCatalogString) => + val expectedScanSchema = + CatalystSqlParser.parseDataType(expectedScanSchemaCatalogString) + implicit val equality = schemaEquality + assert(scanSchema === expectedScanSchema) + } + } +}