diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala index a0bc360b9aa3c..c0b72123065f7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala @@ -21,7 +21,7 @@ import java.util.Locale import scala.collection.mutable -import org.apache.spark.SparkException +import org.apache.spark.{SparkException, SparkIllegalArgumentException} import org.apache.spark.internal.LogKeys.{AGGREGATE_FUNCTIONS, COLUMN_NAMES, GROUP_BY_EXPRS, JOIN_CONDITION, JOIN_TYPE, POST_SCAN_FILTERS, PUSHED_FILTERS, RELATION_NAME, RELATION_OUTPUT} import org.apache.spark.sql.catalyst.expressions.{aggregate, Alias, And, Attribute, AttributeMap, AttributeReference, AttributeSet, Cast, Expression, ExpressionSet, ExprId, IntegerLiteral, Literal, NamedExpression, PredicateHelper, ProjectionOverSchema, SortOrder, SubqueryExpression} import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression @@ -808,9 +808,16 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { } // Remap pushed filter attributes to the pruned output schema and drop filters - // whose references are no longer in the pruned output. - val remappedPushedFilters = sHolder.pushedFilterExpressions.map(projectionFunc) - .filter(_.references.subsetOf(AttributeSet(output))) + // whose references are no longer in the pruned output. Catch FIELD_NOT_FOUND + // because ProjectionOverSchema throws when a pushed filter references a nested + // struct field that was pruned from the schema. + val remappedPushedFilters = sHolder.pushedFilterExpressions.flatMap { filter => + try Some(projectionFunc(filter)) + catch { + case e: SparkIllegalArgumentException if e.getCondition == "FIELD_NOT_FOUND" => + None + } + }.filter(_.references.subsetOf(AttributeSet(output))) val scanRelation = DataSourceV2ScanRelation(sHolder.relation, wrappedScan, output, pushedFilters = remappedPushedFilters) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala index 6ea1ea3faa0ec..6d3ad69994a62 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala @@ -1288,6 +1288,26 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS s"struct column in pushed filter should be pruned to struct but was $prunedStructType") } + test("pushedFilters drops filters referencing pruned nested struct fields") { + // Disable constraint propagation so IsNotNull(s.a) is not added as a post-scan + // filter (it would keep field a alive in the struct). + withSQLConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED.key -> "false") { + val df = spark.read.format(classOf[NestedSchemaDataSourceV2].getName).load() + // Filter on s.a but select only s.b. Column pruning narrows s to struct, + // so the pushed filter on s.a can't be remapped and should be dropped. + val q = df.filter($"s.a" > 3).select($"s.b") + checkAnswer(q, (4 until 10).map(i => Row(-i))) + + val scanRelation = getScanRelation(q) + val referencedStructFields = scanRelation.pushedFilters.flatMap { filter => + filter.collect { case a: AttributeReference if a.name == "s" => a } + .flatMap(_.dataType.asInstanceOf[StructType].fieldNames) + } + assert(!referencedStructFields.contains("a"), + "pushedFilters should not reference pruned nested field a") + } + } + test("scan canonicalization with pushedFilters") { // Use SimpleDataSourceV2 whose scan implements equals, so canonicalization comparison works val table = new SimpleDataSourceV2().getTable(CaseInsensitiveStringMap.empty())