Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,24 @@ case class ProjectionOverSchema(schema: StructType, output: AttributeSet) {
s"unmatched child schema for GetArrayStructFields: ${projSchema.toString}"
)
}
case a: GetNestedArrayStructFields =>
getProjection(a.child).map(p => (p, p.dataType)).map {
case (projection, projArrayType: ArrayType) =>
// Find the innermost struct in both original and projected types
val originalStruct = findInnermostStruct(a.child.dataType)
val projStruct = findInnermostStruct(projArrayType)
val selectedField = originalStruct(a.ordinal)
val prunedField = projStruct(selectedField.name)
GetNestedArrayStructFields(projection,
prunedField.copy(name = a.field.name),
projStruct.fieldIndex(selectedField.name),
projStruct.size,
a.containsNull)
case (_, projSchema) =>
throw new IllegalStateException(
s"unmatched child schema for GetNestedArrayStructFields: ${projSchema.toString}"
)
}
case MapKeys(child) =>
getProjection(child).map { projection => MapKeys(projection) }
case MapValues(child) =>
Expand All @@ -79,7 +97,40 @@ case class ProjectionOverSchema(schema: StructType, output: AttributeSet) {
}
case ElementAt(left, right, defaultValueOutOfBound, failOnError) if right.foldable =>
getProjection(left).map(p => ElementAt(p, right, defaultValueOutOfBound, failOnError))
case az: ArraysZip =>
// Project each child expression and rebuild ArraysZip with projected children
val projectedChildren = az.children.map(getProjection)
if (projectedChildren.forall(_.isDefined)) {
Some(az.copy(children = projectedChildren.map(_.get)))
} else {
None
}
case naz: NestedArraysZip =>
// Project each child expression and rebuild NestedArraysZip with projected children
val projectedChildren = naz.children.map(getProjection)
if (projectedChildren.forall(_.isDefined)) {
Some(naz.copy(children = projectedChildren.map(_.get)))
} else {
None
}
case a: Alias =>
// Project the child and wrap it back in an Alias with the same metadata
getProjection(a.child).map { projectedChild =>
a.copy(child = projectedChild)(
a.exprId, a.qualifier, a.explicitMetadata, a.nonInheritableMetadataKeys)
}
case _ =>
None
}

/**
* Finds the innermost StructType within a nested array type.
* For example, `array<array<struct<a, b>>>` returns `struct<a, b>`.
*/
@scala.annotation.tailrec
private def findInnermostStruct(dt: DataType): StructType = dt match {
case ArrayType(elementType: ArrayType, _) => findInnermostStruct(elementType)
case ArrayType(st: StructType, _) => st
case _ => throw new IllegalStateException(s"Expected nested array of struct, got: $dt")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,10 @@ object SchemaPruning extends SQLConfHelper {
RootField(StructField(att.name, att.dataType, att.nullable, att.metadata),
derivedFromAtt = true) :: Nil
case SelectedField(field) => RootField(field, derivedFromAtt = false) :: Nil
// Handle multi-field expressions like ArraysZip and NestedArraysZip that combine
// multiple field accesses into a single expression. unapplySeq returns all fields.
case expr if SelectedField.unapplySeq(expr).exists(_.size > 1) =>
SelectedField.unapplySeq(expr).get.map(f => RootField(f, derivedFromAtt = false))
// Root field accesses by `IsNotNull` and `IsNull` are special cases as the expressions
// don't actually use any nested fields. These root field accesses might be excluded later
// if there are any nested fields accesses in the query plan.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,40 @@ object SelectedField {
selectField(unaliased, None)
}

/**
* Like unapply, but returns all fields from expressions that combine multiple field accesses.
* This is needed for expressions like ArraysZip and NestedArraysZip that merge multiple
* array field extractions into a single output.
*
* For example, `ArraysZip([arr.f1, arr.f2], names)` accesses both f1 and f2 from arr.
* Regular unapply would return None, but unapplySeq returns both fields.
*
* @return None if no fields are accessed, Some(Seq[StructField]) otherwise
*/
def unapplySeq(expr: Expression): Option[Seq[StructField]] = {
val unaliased = expr match {
case Alias(child, _) => child
case e => e
}
unaliased match {
// ArraysZip combines multiple array field extractions
// Use unapplySeq recursively to handle nested ArraysZip/NestedArraysZip
case ArraysZip(children, _) =>
val fields = children.flatMap(c => unapplySeq(c).getOrElse(Seq.empty))
if (fields.nonEmpty) Some(fields) else None

// NestedArraysZip combines nested array field extractions
// Use unapplySeq recursively to handle further nested expressions
case NestedArraysZip(children, _, _) =>
val fields = children.flatMap(c => unapplySeq(c).getOrElse(Seq.empty))
if (fields.nonEmpty) Some(fields) else None

// For other expressions, delegate to regular unapply
case _ =>
unapply(expr).map(Seq(_))
}
}

/**
* Convert an expression into the parts of the schema (the field) it accesses.
*/
Expand Down Expand Up @@ -96,6 +130,26 @@ object SelectedField {
}
val newField = StructField(field.name, newFieldDataType, field.nullable)
selectField(child, Option(ArrayType(struct(newField), containsNull)))
case GetNestedArrayStructFields(child, field, ordinal, _, containsNull) =>
// GetNestedArrayStructFields extracts a field from the innermost struct of a
// nested array like array<array<struct>>. We need to find the innermost struct
// field and rebuild the full nested array schema.
val innermostField = findInnermostStructField(child.dataType, ordinal)
val newFieldDataType = dataTypeOpt match {
case None =>
// Top level extractor - use the field's type
innermostField.dataType
case Some(dt) =>
// Part of a chain - peel off only the parent's array layers, not the field's own
// For example, if child is array<array<struct<..., field: array<struct<...>>>>>,
// the parent chain contributes 2 array levels, and field contributes 1 more.
// We should peel only the parent's 2 levels, keeping the field's array type.
val parentArrayDepth = arrayDepth(child.dataType)
peelNArrayLayers(dt, parentArrayDepth)
}
val newField = StructField(innermostField.name, newFieldDataType, innermostField.nullable)
val wrappedType = wrapInArrays(child.dataType, struct(newField), containsNull)
selectField(child, Option(wrappedType))
case GetMapValue(child, key) if key.foldable =>
// GetMapValue does not select a field from a struct (i.e. prune the struct) so it can't be
// the top-level extractor. However it can be part of an extractor chain.
Expand Down Expand Up @@ -154,4 +208,66 @@ object SelectedField {
}

private def struct(field: StructField): StructType = StructType(Array(field))

/**
* Finds the struct field at the given ordinal in the innermost struct of a nested array type.
* For example, for `array<array<struct<a, b, c>>>` with ordinal 1, returns field `b`.
*/
@scala.annotation.tailrec
private def findInnermostStructField(dt: DataType, ordinal: Int): StructField = dt match {
case ArrayType(elementType: ArrayType, _) => findInnermostStructField(elementType, ordinal)
case ArrayType(st: StructType, _) => st(ordinal)
case _ => throw new IllegalArgumentException(s"Expected nested array of struct, got: $dt")
}

/**
* Removes all ArrayType wrappers from a data type, returning the innermost element type.
* For example, `array<array<int>>` becomes `int`.
*/
@scala.annotation.tailrec
private def peelArrayLayers(dt: DataType): DataType = dt match {
case ArrayType(elementType, _) => peelArrayLayers(elementType)
case other => other
}

/**
* Counts the number of array layers in a data type.
* For example, `array<array<struct>>` returns 2.
*/
private def arrayDepth(dt: DataType): Int = {
@scala.annotation.tailrec
def loop(dt: DataType, depth: Int): Int = dt match {
case ArrayType(elementType, _) => loop(elementType, depth + 1)
case _ => depth
}
loop(dt, 0)
}

/**
* Removes exactly N ArrayType wrappers from a data type.
* For example, `peelNArrayLayers(array<array<array<int>>>, 2)` returns `array<int>`.
*/
@scala.annotation.tailrec
private def peelNArrayLayers(dt: DataType, n: Int): DataType = {
if (n <= 0) dt
else dt match {
case ArrayType(elementType, _) => peelNArrayLayers(elementType, n - 1)
case other => other
}
}

/**
* Wraps an innermost struct type in the same array nesting as the source type.
* For example, if sourceType is `array<array<struct<a, b>>>` and innerStruct is `struct<a>`,
* returns `array<array<struct<a>>>`.
*/
private def wrapInArrays(sourceType: DataType, innerStruct: StructType,
containsNull: Boolean): DataType = sourceType match {
case ArrayType(elementType: ArrayType, outerContainsNull) =>
ArrayType(wrapInArrays(elementType, innerStruct, containsNull), outerContainsNull)
case ArrayType(_: StructType, _) =>
ArrayType(innerStruct, containsNull)
case _ =>
throw new IllegalArgumentException(s"Expected nested array of struct, got: $sourceType")
}
}
Loading