From e085be4af39db91e9a3bb6c8efc38054351f54af Mon Sep 17 00:00:00 2001 From: Szehon Ho Date: Thu, 16 Apr 2026 18:22:21 -0700 Subject: [PATCH 1/8] [SPARK-56521][SQL] Support PartitionPredicate in runtime filters --- .../read/SupportsRuntimeV2Filtering.java | 15 + ...yEnhancedRuntimePartitionFilterTable.scala | 88 ++++ ...nhancedRuntimePartitionFilterCatalog.scala | 50 ++ .../datasources/v2/BatchScanExec.scala | 22 +- .../datasources/v2/PushDownUtils.scala | 34 +- ...2EnhancedRuntimePartitionFilterSuite.scala | 433 ++++++++++++++++++ 6 files changed, 637 insertions(+), 5 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryEnhancedRuntimePartitionFilterTable.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableEnhancedRuntimePartitionFilterCatalog.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2EnhancedRuntimePartitionFilterSuite.scala diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsRuntimeV2Filtering.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsRuntimeV2Filtering.java index f5acdf885bf5c..d158ab413649a 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsRuntimeV2Filtering.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsRuntimeV2Filtering.java @@ -19,6 +19,7 @@ import org.apache.spark.annotation.Experimental; import org.apache.spark.sql.connector.expressions.NamedReference; +import org.apache.spark.sql.connector.expressions.filter.PartitionPredicate; import org.apache.spark.sql.connector.expressions.filter.Predicate; import org.apache.spark.sql.sources.Filter; @@ -64,4 +65,18 @@ public interface SupportsRuntimeV2Filtering extends Scan { * @param predicates data source V2 predicates used to filter the scan at runtime */ void filter(Predicate[] predicates); + + /** + * Returns true if this scan supports iterative runtime filtering. When true, + * {@link #filter(Predicate[])} may be called multiple times with additional predicates. + *
+ * When enabled, Spark will derive {@link PartitionPredicate} instances from the runtime + * filters and push them via a subsequent {@link #filter(Predicate[])} call. + *
+ * + * @since 4.2.0 + */ + default boolean supportsIterativeFiltering() { + return false; + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryEnhancedRuntimePartitionFilterTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryEnhancedRuntimePartitionFilterTable.scala new file mode 100644 index 0000000000000..4825a7f7a3d6e --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryEnhancedRuntimePartitionFilterTable.scala @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.catalog + +import java.util + +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.sql.connector.expressions.{NamedReference, Transform} +import org.apache.spark.sql.connector.expressions.filter.{PartitionPredicate, Predicate} +import org.apache.spark.sql.connector.read.{InputPartition, Scan, ScanBuilder, SupportsRuntimeV2Filtering} +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.util.ArrayImplicits._ + +/** + * In-memory table whose batch scan implements [[SupportsRuntimeV2Filtering]] with + * iterative filtering support, so that [[PartitionPredicate]] instances derived from + * runtime filters are pushed via a second [[SupportsRuntimeV2Filtering#filter]] call. 
+ */ +class InMemoryEnhancedRuntimePartitionFilterTable( + name: String, + columns: Array[Column], + partitioning: Array[Transform], + properties: util.Map[String, String]) + extends InMemoryTableWithV2Filter(name, columns, partitioning, properties) { + + override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { + new InMemoryEnhancedRuntimePartitionFilterScanBuilder(schema, options) + } + + class InMemoryEnhancedRuntimePartitionFilterScanBuilder( + tableSchema: StructType, + options: CaseInsensitiveStringMap) + extends InMemoryScanBuilder(tableSchema, options) { + override def build: Scan = InMemoryEnhancedRuntimePartitionFilterBatchScan( + data.map(_.asInstanceOf[InputPartition]).toImmutableArraySeq, + schema, tableSchema, options) + } + + case class InMemoryEnhancedRuntimePartitionFilterBatchScan( + var _data: Seq[InputPartition], + readSchema: StructType, + tableSchema: StructType, + options: CaseInsensitiveStringMap) + extends BatchScanBaseClass(_data, readSchema, tableSchema) + with SupportsRuntimeV2Filtering { + + private val _pushedPartitionPredicates = ArrayBuffer.empty[PartitionPredicate] + + def pushedPartitionPredicates: Seq[PartitionPredicate] = + _pushedPartitionPredicates.toSeq + + override def supportsIterativeFiltering(): Boolean = true + + override def filterAttributes(): Array[NamedReference] = { + val scanFields = readSchema.fields.map(_.name).toSet + partitioning.flatMap(_.references()) + .filter(ref => scanFields.contains(ref.fieldNames.mkString("."))) + } + + override def filter(filters: Array[Predicate]): Unit = { + filters.foreach { + case pp: PartitionPredicate => + _pushedPartitionPredicates += pp + data = data.filter { partition => + pp.eval(partition.asInstanceOf[BufferedRows].partitionKey()) + } + case _ => + } + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableEnhancedRuntimePartitionFilterCatalog.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableEnhancedRuntimePartitionFilterCatalog.scala new file mode 100644 index 0000000000000..8d479c848920c --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableEnhancedRuntimePartitionFilterCatalog.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.connector.catalog + +import java.util + +import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException +import org.apache.spark.sql.connector.expressions.Transform + +class InMemoryTableEnhancedRuntimePartitionFilterCatalog extends InMemoryTableCatalog { + import CatalogV2Implicits._ + + override def createTable( + ident: Identifier, + columns: Array[Column], + partitions: Array[Transform], + properties: util.Map[String, String]): Table = { + if (tables.containsKey(ident)) { + throw new TableAlreadyExistsException(ident.asMultipartIdentifier) + } + + InMemoryTableCatalog.maybeSimulateFailedTableCreation(properties) + + val tableName = s"$name.${ident.quoted}" + val table = new InMemoryEnhancedRuntimePartitionFilterTable( + tableName, columns, partitions, properties) + tables.put(ident, table) + namespaces.putIfAbsent(ident.namespace.toList, Map()) + table + } + + override def createTable(ident: Identifier, tableInfo: TableInfo): Table = { + createTable(ident, tableInfo.columns(), tableInfo.partitions(), tableInfo.properties) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala index dde294c6019b0..e27d8d5f2db7d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala @@ -67,11 +67,29 @@ case class BatchScanExec( } val originalPartitioning = outputPartitioning + // the cast is safe as runtime filters are only assigned if the scan can be filtered + val filterableScan = scan.asInstanceOf[SupportsRuntimeV2Filtering] + var filtered = false + if (dataSourceFilters.nonEmpty) { - // the cast is safe as runtime filters are only assigned if the scan can be filtered - val filterableScan = scan.asInstanceOf[SupportsRuntimeV2Filtering] filterableScan.filter(dataSourceFilters.toArray) + filtered = true + } + + // If the scan supports iterative filtering, derive PartitionPredicates from the + // runtime filters and push them in a second pass. 
(See SPARK-55596) + if (filterableScan.supportsIterativeFiltering()) { + PushDownUtils.getPartitionPredicateSchema(table, output).foreach { partitionFields => + val partPredicates = + PushDownUtils.createRuntimePartitionPredicates(runtimeFilters, partitionFields) + if (partPredicates.nonEmpty) { + filterableScan.filter(partPredicates.toArray) + filtered = true + } + } + } + if (filtered) { // call toBatch again to get filtered partitions val newPartitions = scan.toBatch.planInputPartitions() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala index 4b50159132737..8cab3855a7bb1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala @@ -21,13 +21,15 @@ import scala.collection.mutable import org.apache.spark.internal.{Logging, LogKeys} import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, Expression, ExpressionSet, GetStructField, NamedExpression, PythonUDF, SchemaPruning, SubqueryExpression} +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, DynamicPruning, DynamicPruningExpression, Expression, ExpressionSet, GetStructField, NamedExpression, PythonUDF, SchemaPruning, SubqueryExpression} import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.catalyst.util.CharVarcharUtils +import org.apache.spark.sql.connector.catalog.Table import org.apache.spark.sql.connector.expressions.{IdentityTransform, SortOrder} import org.apache.spark.sql.connector.expressions.filter.Predicate import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownFilters, SupportsPushDownLimit, SupportsPushDownOffset, SupportsPushDownRequiredColumns, SupportsPushDownTableSample, SupportsPushDownTopN, SupportsPushDownV2Filters} +import org.apache.spark.sql.execution.{ScalarSubquery => ExecScalarSubquery} import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, DataSourceUtils} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.connector.{PartitionPredicateField, PartitionPredicateImpl, SupportsPushDownCatalystFilters} @@ -135,11 +137,20 @@ object PushDownUtils extends Logging { */ def getPartitionPredicateSchema(relation: DataSourceV2Relation) : Option[Seq[PartitionPredicateField]] = { - val transforms = relation.table.partitioning + getPartitionPredicateSchema(relation.table, relation.output) + } + + /** + * Returns a Seq of [[PartitionPredicateField]] representing partition transform expression types, + * if schema is supported for [[PartitionPredicate]] push down. None if not supported. + */ + def getPartitionPredicateSchema(table: Table, output: Seq[AttributeReference]) + : Option[Seq[PartitionPredicateField]] = { + val transforms = table.partitioning if (transforms.isEmpty) { None } else { - val rootStruct = StructType(relation.output.map { a => + val rootStruct = StructType(output.map { a => StructField(a.name, a.dataType, a.nullable)}) val fields = transforms.flatMap { case t: IdentityTransform => @@ -223,6 +234,23 @@ object PushDownUtils extends Logging { .map(flattenedToOriginal) } + /** + * Creates [[PartitionPredicateImpl]] instances from runtime filter expressions. 
+ * Extracts Catalyst expressions from the runtime filters (unwrapping DPP and literalizing + * scalar subqueries), then converts partition-column filters to [[PartitionPredicateImpl]]. + */ + private[v2] def createRuntimePartitionPredicates( + runtimeFilters: Seq[Expression], + partitionFields: Seq[PartitionPredicateField]): Seq[PartitionPredicateImpl] = { + val catalystExprs = runtimeFilters.flatMap { + case DynamicPruningExpression(e) => Some(e) + case _: DynamicPruning => None + case f => Some(f.transform { case s: ExecScalarSubquery => s.toLiteral }) + } + val flattened = flattenNestedPartitionFilters(catalystExprs, partitionFields).keys + createPartitionPredicates(flattened.toSeq, partitionFields)._1 + } + private def isPushablePartitionFilter(f: Expression) = f.deterministic && !SubqueryExpression.hasSubquery(f) && diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2EnhancedRuntimePartitionFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2EnhancedRuntimePartitionFilterSuite.scala new file mode 100644 index 0000000000000..69392a3439d65 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2EnhancedRuntimePartitionFilterSuite.scala @@ -0,0 +1,433 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector + +import org.scalatest.BeforeAndAfter + +import org.apache.spark.sql.{DataFrame, QueryTest, Row} +import org.apache.spark.sql.catalyst.expressions.{DynamicPruning, DynamicPruningExpression} +import org.apache.spark.sql.connector.catalog.{BufferedRows, InMemoryEnhancedRuntimePartitionFilterTable, InMemoryTableEnhancedRuntimePartitionFilterCatalog} +import org.apache.spark.sql.connector.expressions.PartitionFieldReference +import org.apache.spark.sql.connector.expressions.filter.PartitionPredicate +import org.apache.spark.sql.execution.ExplainUtils.stripAQEPlan +import org.apache.spark.sql.execution.datasources.v2.BatchScanExec +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSparkSession + +/** + * Tests that [[PartitionPredicate]] instances are pushed via iterative runtime filtering + * (second [[SupportsRuntimeV2Filtering#filter]] call) for both DPP and scalar subquery + * runtime filters. 
+ */ +class DataSourceV2EnhancedRuntimePartitionFilterSuite + extends QueryTest with SharedSparkSession with BeforeAndAfter { + + protected val v2Source = classOf[FakeV2ProviderWithCustomSchema].getName + protected val catalogName = "testruntimepartfilter" + + before { + spark.conf.set(s"spark.sql.catalog.$catalogName", + classOf[InMemoryTableEnhancedRuntimePartitionFilterCatalog].getName) + } + + after { + spark.sessionState.catalogManager.reset() + } + + test("DPP: PartitionPredicate pushed via iterative runtime filtering") { + val fact = s"$catalogName.fact" + val dim = s"$catalogName.dim" + withTable(fact, dim) { + sql(s"CREATE TABLE $fact (id INT, part INT) USING $v2Source PARTITIONED BY (part)") + for (i <- 0 until 5) { + sql(s"INSERT INTO $fact VALUES ($i, $i)") + } + sql(s"CREATE TABLE $dim (dim_id INT, dim_val STRING) USING $v2Source") + sql(s"INSERT INTO $dim VALUES (2, 'two')") + + withSQLConf( + SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "true", + SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "false", + SQLConf.DYNAMIC_PARTITION_PRUNING_FALLBACK_FILTER_RATIO.key -> "10") { + val df = sql( + s"""SELECT f.id, f.part FROM $fact f JOIN $dim d + |ON f.part = d.dim_id WHERE d.dim_val = 'two'""".stripMargin) + checkAnswer(df, Row(2, 2)) + + assertDPPRuntimeFilters(df) + + assertPushedPartitionPredicates(df, 1) + assertScanReturnsPartitionKeys(df, Set("2")) + assertReferencedPartitionFieldOrdinals(df, Array(0), Array("part")) + } + } + } + + test("DPP: PartitionPredicate on non-first partition column") { + val fact = s"$catalogName.fact2" + val dim = s"$catalogName.dim2" + withTable(fact, dim) { + sql(s"CREATE TABLE $fact (id INT, p1 INT, p2 INT) " + + s"USING $v2Source PARTITIONED BY (p1, p2)") + for (i <- 0 until 5; j <- 0 until 2) { + sql(s"INSERT INTO $fact VALUES (${i * 2 + j}, $i, $j)") + } + sql(s"CREATE TABLE $dim (dim_id INT, dim_val STRING) " + + s"USING $v2Source") + sql(s"INSERT INTO $dim VALUES (1, 'one')") + + withSQLConf( + SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "true", + SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> + "false", + SQLConf.DYNAMIC_PARTITION_PRUNING_FALLBACK_FILTER_RATIO.key -> + "10") { + val df = sql( + s"""SELECT f.id, f.p1, f.p2 FROM $fact f JOIN $dim d + |ON f.p2 = d.dim_id + |WHERE d.dim_val = 'one'""".stripMargin) + checkAnswer(df, Seq( + Row(1, 0, 1), Row(3, 1, 1), Row(5, 2, 1), + Row(7, 3, 1), Row(9, 4, 1))) + + assertDPPRuntimeFilters(df) + + assertPushedPartitionPredicates(df, 1) + assertReferencedPartitionFieldOrdinals( + df, Array(1), Array("p1", "p2")) + } + } + } + + test("scalar subquery: PartitionPredicate pushed via iterative runtime filtering") { + val tbl = s"$catalogName.tbl" + val dim = s"$catalogName.dim" + withTable(tbl, dim) { + sql(s"CREATE TABLE $tbl (id INT, part INT) USING $v2Source PARTITIONED BY (part)") + for (i <- 0 until 5) { + sql(s"INSERT INTO $tbl VALUES ($i, $i)") + } + sql(s"CREATE TABLE $dim (val INT) USING $v2Source") + sql(s"INSERT INTO $dim VALUES (3)") + + val df = sql(s"SELECT * FROM $tbl WHERE part = (SELECT max(val) FROM $dim)") + checkAnswer(df, Row(3, 3)) + + assertScalarSubqueryRuntimeFilters(df) + + assertPushedPartitionPredicates(df, 1) + assertScanReturnsPartitionKeys(df, Set("3")) + assertReferencedPartitionFieldOrdinals(df, Array(0), Array("part")) + } + } + + test("scalar subquery: complex expression with arithmetic on subquery result") { + val tbl = s"$catalogName.tbl_complex" + val dim = s"$catalogName.dim_complex" + withTable(tbl, dim) { + sql(s"CREATE 
TABLE $tbl (id INT, part INT) USING $v2Source PARTITIONED BY (part)") + for (i <- 0 until 5) { + sql(s"INSERT INTO $tbl VALUES ($i, $i)") + } + sql(s"CREATE TABLE $dim (val INT) USING $v2Source") + sql(s"INSERT INTO $dim VALUES (2)") + + val df = sql(s"SELECT * FROM $tbl WHERE part > (SELECT max(val) FROM $dim) + 1") + checkAnswer(df, Row(4, 4)) + + assertScalarSubqueryRuntimeFilters(df) + + assertPushedPartitionPredicates(df, 1) + assertScanReturnsPartitionKeys(df, Set("4")) + assertReferencedPartitionFieldOrdinals(df, Array(0), Array("part")) + } + } + + test("scalar subquery: RLIKE (untranslatable) with subquery pattern") { + val tbl = s"$catalogName.tbl_rlike" + val dim = s"$catalogName.dim_rlike" + withTable(tbl, dim) { + sql(s"CREATE TABLE $tbl (id INT, part STRING) USING $v2Source PARTITIONED BY (part)") + sql(s"INSERT INTO $tbl VALUES (1, 'abc')") + sql(s"INSERT INTO $tbl VALUES (2, 'def')") + sql(s"INSERT INTO $tbl VALUES (3, 'abx')") + sql(s"INSERT INTO $tbl VALUES (4, 'xyz')") + + sql(s"CREATE TABLE $dim (pattern STRING) USING $v2Source") + sql(s"INSERT INTO $dim VALUES ('^ab')") + + val df = sql(s"SELECT * FROM $tbl WHERE part RLIKE (SELECT max(pattern) FROM $dim)") + checkAnswer(df, Seq(Row(1, "abc"), Row(3, "abx"))) + + assertScalarSubqueryRuntimeFilters(df) + + assertPushedPartitionPredicates(df, 1) + assertScanReturnsPartitionKeys(df, Set("abc", "abx")) + assertReferencedPartitionFieldOrdinals(df, Array(0), Array("part")) + } + } + + test("scalar subquery: UDF on partition column with subquery value") { + val tbl = s"$catalogName.tbl_udf" + val dim = s"$catalogName.dim_udf" + withTable(tbl, dim) { + sql(s"CREATE TABLE $tbl (id INT, part STRING) USING $v2Source PARTITIONED BY (part)") + sql(s"INSERT INTO $tbl VALUES (1, 'a')") + sql(s"INSERT INTO $tbl VALUES (2, 'A')") + sql(s"INSERT INTO $tbl VALUES (3, 'b')") + sql(s"INSERT INTO $tbl VALUES (4, 'B')") + + sql(s"CREATE TABLE $dim (val STRING) USING $v2Source") + sql(s"INSERT INTO $dim VALUES ('A')") + + spark.udf.register("my_upper_runtime", + (s: String) => if (s == null) null else s.toUpperCase(java.util.Locale.ROOT)) + + val df = sql( + s"SELECT * FROM $tbl WHERE my_upper_runtime(part) = (SELECT max(val) FROM $dim)") + checkAnswer(df, Seq(Row(1, "a"), Row(2, "A"))) + + assertScalarSubqueryRuntimeFilters(df) + + assertPushedPartitionPredicates(df, 1) + assertScanReturnsPartitionKeys(df, Set("a", "A")) + assertReferencedPartitionFieldOrdinals(df, Array(0), Array("part")) + } + } + + test("scalar subquery: two PartitionPredicates for two subqueries on different partition cols") { + val fact = s"$catalogName.fact_2sub" + val dim1 = s"$catalogName.dim_2sub1" + val dim2 = s"$catalogName.dim_2sub2" + withTable(fact, dim1, dim2) { + sql(s"CREATE TABLE $fact (id INT, p1 INT, p2 STRING) " + + s"USING $v2Source PARTITIONED BY (p1, p2)") + sql(s"INSERT INTO $fact VALUES (1, 1, 'a')") + sql(s"INSERT INTO $fact VALUES (2, 1, 'b')") + sql(s"INSERT INTO $fact VALUES (3, 2, 'a')") + sql(s"INSERT INTO $fact VALUES (4, 2, 'b')") + + sql(s"CREATE TABLE $dim1 (val INT) USING $v2Source") + sql(s"INSERT INTO $dim1 VALUES (1)") + + sql(s"CREATE TABLE $dim2 (val STRING) USING $v2Source") + sql(s"INSERT INTO $dim2 VALUES ('a')") + + val df = sql( + s"""SELECT * FROM $fact + |WHERE p1 = (SELECT max(val) FROM $dim1) + | AND p2 = (SELECT max(val) FROM $dim2)""".stripMargin) + checkAnswer(df, Row(1, 1, "a")) + + assertScalarSubqueryRuntimeFilters(df, expectedCount = 2) + + assertPushedPartitionPredicates(df, 2) + assertScanReturnsPartitionKeys(df, 
Set("1/a")) + + val predicates = getPushedPartitionPredicates(df) + val partFieldNames = Array("p1", "p2") + val p1Pred = predicates.find(_.references().exists( + _.asInstanceOf[PartitionFieldReference].ordinal() == 0)) + val p2Pred = predicates.find(_.references().exists( + _.asInstanceOf[PartitionFieldReference].ordinal() == 1)) + assert(p1Pred.isDefined, "Expected a PartitionPredicate referencing p1 (ordinal 0)") + assert(p2Pred.isDefined, "Expected a PartitionPredicate referencing p2 (ordinal 1)") + assertPartitionPredicateOrdinals(p1Pred.get, Array(0), partFieldNames) + assertPartitionPredicateOrdinals(p2Pred.get, Array(1), partFieldNames) + } + } + + test("scalar subquery: PartitionPredicate on non-first column of three-partition-column table") { + val tbl = s"$catalogName.tbl_3part" + val dim = s"$catalogName.dim_3part" + withTable(tbl, dim) { + sql(s"CREATE TABLE $tbl (id INT, p0 INT, p1 STRING, p2 INT) " + + s"USING $v2Source PARTITIONED BY (p0, p1, p2)") + sql(s"INSERT INTO $tbl VALUES (1, 1, 'a', 10)") + sql(s"INSERT INTO $tbl VALUES (2, 1, 'b', 10)") + sql(s"INSERT INTO $tbl VALUES (3, 2, 'a', 20)") + sql(s"INSERT INTO $tbl VALUES (4, 2, 'b', 20)") + + sql(s"CREATE TABLE $dim (val STRING) USING $v2Source") + sql(s"INSERT INTO $dim VALUES ('a')") + + val df = sql( + s"SELECT * FROM $tbl WHERE p1 = (SELECT max(val) FROM $dim)") + checkAnswer(df, Seq(Row(1, 1, "a", 10), Row(3, 2, "a", 20))) + + assertScalarSubqueryRuntimeFilters(df) + + assertPushedPartitionPredicates(df, 1) + assertScanReturnsPartitionKeys(df, Set("1/a/10", "2/a/20")) + assertReferencedPartitionFieldOrdinals(df, Array(1), Array("p0", "p1", "p2")) + } + } + + test("no PartitionPredicate for scalar subquery on data column") { + val tbl = s"$catalogName.tbl_data" + val dim = s"$catalogName.dim_data" + withTable(tbl, dim) { + sql(s"CREATE TABLE $tbl (id INT, data INT, part INT) " + + s"USING $v2Source PARTITIONED BY (part)") + for (i <- 0 until 5) { + sql(s"INSERT INTO $tbl VALUES ($i, ${i * 10}, $i)") + } + sql(s"CREATE TABLE $dim (val INT) USING $v2Source") + sql(s"INSERT INTO $dim VALUES (30)") + + val df = sql( + s"SELECT * FROM $tbl WHERE data = (SELECT max(val) FROM $dim)") + checkAnswer(df, Row(3, 30, 3)) + + assertPushedPartitionPredicates(df, 0) + } + } + + test("no PartitionPredicate when supportsIterativeFiltering is false") { + val baseCatalog = "testv2filterNoIterative" + spark.conf.set(s"spark.sql.catalog.$baseCatalog", + classOf[catalog.InMemoryTableWithV2FilterCatalog].getName) + + val tbl = s"$baseCatalog.tbl" + val dim = s"$baseCatalog.dim" + withTable(tbl, dim) { + sql(s"CREATE TABLE $tbl (id INT, part INT) USING $v2Source PARTITIONED BY (part)") + for (i <- 0 until 5) { + sql(s"INSERT INTO $tbl VALUES ($i, $i)") + } + sql(s"CREATE TABLE $dim (val INT) USING $v2Source") + sql(s"INSERT INTO $dim VALUES (3)") + + val df = sql(s"SELECT * FROM $tbl WHERE part = (SELECT max(val) FROM $dim)") + checkAnswer(df, Row(3, 3)) + + val batchScan = collectBatchScan(df) + assert(batchScan.runtimeFilters.nonEmpty) + + val scan = batchScan.scan + assert( + !scan.asInstanceOf[ + catalog.InMemoryTableWithV2Filter#InMemoryV2FilterBatchScan + ].supportsIterativeFiltering(), + "Base V2 filter table should not support iterative filtering") + } + } + + private def assertDPPRuntimeFilters( + df: DataFrame, expectedCount: Int = 1): Unit = { + val batchScan = collectBatchScan(df) + val dppFilters = batchScan.runtimeFilters.collect { + case d: DynamicPruningExpression => d + } + assert(dppFilters.size === expectedCount, + 
s"Expected $expectedCount DynamicPruningExpression(s) " + + s"in runtimeFilters, got ${dppFilters.size}") + } + + private def assertScalarSubqueryRuntimeFilters( + df: DataFrame, expectedCount: Int = 1): Unit = { + val batchScan = collectBatchScan(df) + val scalarFilters = batchScan.runtimeFilters.collect { + case f if !f.isInstanceOf[DynamicPruning] => f + } + val dppFilters = batchScan.runtimeFilters.collect { + case d: DynamicPruning => d + } + assert(scalarFilters.size === expectedCount, + s"Expected $expectedCount scalar subquery runtime filter(s), " + + s"got ${scalarFilters.size}") + assert(dppFilters.isEmpty, + "Expected non-DPP runtime filters (scalar subquery)") + } + + private def collectBatchScan(df: DataFrame): BatchScanExec = { + stripAQEPlan(df.queryExecution.executedPlan).collectFirst { + case b: BatchScanExec => b + }.getOrElse(fail("Expected BatchScanExec in plan")) + } + + private def getPushedPartitionPredicates( + df: DataFrame): Seq[PartitionPredicate] = { + val batchScan = collectBatchScan(df) + batchScan.scan match { + case s: InMemoryEnhancedRuntimePartitionFilterTable# + InMemoryEnhancedRuntimePartitionFilterBatchScan => + s.pushedPartitionPredicates + case _ => Seq.empty + } + } + + private def assertPushedPartitionPredicates( + df: DataFrame, + expectedCount: Int): Unit = { + val predicates = getPushedPartitionPredicates(df) + assert(predicates.size === expectedCount, + s"Expected $expectedCount pushed partition predicate(s), " + + s"got ${predicates.size}: $predicates") + } + + private def assertPartitionPredicateOrdinals( + predicate: PartitionPredicate, + expectedOrdinals: Array[Int], + expectedPartitionFieldNames: Array[String]): Unit = { + val refs = predicate.references() + val ordinals = refs.map(_.asInstanceOf[PartitionFieldReference].ordinal()).sorted + assert(ordinals.sameElements(expectedOrdinals.sorted), + s"Expected references().map(_.ordinal()) " + + s"${expectedOrdinals.sorted.mkString("[", ", ", "]")}, " + + s"got ${ordinals.mkString("[", ", ", "]")}") + + val names = expectedPartitionFieldNames + refs.foreach { ref => + assert(ref.isInstanceOf[PartitionFieldReference], + s"Expected PartitionFieldReference, got ${ref.getClass.getName}") + val partRef = ref.asInstanceOf[PartitionFieldReference] + assert(partRef.fieldNames().nonEmpty, + s"PartitionFieldReference.ordinal=${partRef.ordinal()} has empty fieldNames") + assert(partRef.ordinal() < names.length, + s"PartitionFieldReference.ordinal=${partRef.ordinal()} " + + s"out of range for names length ${names.length}") + val expectedName = names(partRef.ordinal()) + val actualName = partRef.fieldNames().mkString(".") + assert(actualName === expectedName, + s"PartitionFieldReference.ordinal=${partRef.ordinal()}: " + + s"expected fieldNames '${expectedName}', got '${actualName}'") + } + } + + private def assertReferencedPartitionFieldOrdinals( + df: DataFrame, + expectedOrdinals: Array[Int], + expectedPartitionFieldNames: Array[String]): Unit = { + getPushedPartitionPredicates(df).foreach { p => + assertPartitionPredicateOrdinals(p, expectedOrdinals, expectedPartitionFieldNames) + } + } + + private def assertScanReturnsPartitionKeys( + df: DataFrame, + expectedPartitionKeys: Set[String]): Unit = { + val batchScan = collectBatchScan(df) + val partitions = batchScan.batch.planInputPartitions() + assert(partitions.length === expectedPartitionKeys.size, + s"Expected ${expectedPartitionKeys.size} partition(s), got ${partitions.length}") + val partKeys = 
partitions.map(_.asInstanceOf[BufferedRows].keyString()).toSet + assert(partKeys === expectedPartitionKeys, + s"Partition keys should be $expectedPartitionKeys, got $partKeys") + } +} From 0ba04455a32fb42c061755368fe82ec3081580f0 Mon Sep 17 00:00:00 2001 From: Szehon Ho Date: Thu, 16 Apr 2026 18:50:03 -0700 Subject: [PATCH 2/8] [SPARK-56521][SQL] Refactor BatchScanExec: guard cast with runtimeFilters.nonEmpty, simplify partPredicates --- .../datasources/v2/BatchScanExec.scala | 50 ++++++++++--------- 1 file changed, 27 insertions(+), 23 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala index e27d8d5f2db7d..66985331dfab3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala @@ -61,35 +61,39 @@ case class BatchScanExec( // Visible for testing @transient private[sql] lazy val filteredPartitions: Seq[Option[InputPartition]] = { - val dataSourceFilters = runtimeFilters.flatMap { - case DynamicPruningExpression(e) => DataSourceV2Strategy.translateRuntimeFilterV2(e) - case f => DataSourceV2Strategy.translateScalarSubqueryFilterV2(f) - } - val originalPartitioning = outputPartitioning - // the cast is safe as runtime filters are only assigned if the scan can be filtered - val filterableScan = scan.asInstanceOf[SupportsRuntimeV2Filtering] - var filtered = false - - if (dataSourceFilters.nonEmpty) { - filterableScan.filter(dataSourceFilters.toArray) - filtered = true - } + val pushedFilters = if (runtimeFilters.nonEmpty) { + // the cast is safe as runtime filters are only assigned if the scan can be filtered + val filterableScan = scan.asInstanceOf[SupportsRuntimeV2Filtering] + + // push down translatable runtime filters + val dataSourceFilters = runtimeFilters.flatMap { + case DynamicPruningExpression(e) => DataSourceV2Strategy.translateRuntimeFilterV2(e) + case f => DataSourceV2Strategy.translateScalarSubqueryFilterV2(f) + } + if (dataSourceFilters.nonEmpty) { + filterableScan.filter(dataSourceFilters.toArray) + } - // If the scan supports iterative filtering, derive PartitionPredicates from the - // runtime filters and push them in a second pass. (See SPARK-55596) - if (filterableScan.supportsIterativeFiltering()) { - PushDownUtils.getPartitionPredicateSchema(table, output).foreach { partitionFields => - val partPredicates = + // If the scan supports iterative filtering, derive PartitionPredicates from the + // runtime filters and push them in a second pass. 
(See SPARK-55596) + val partPredicates = if (filterableScan.supportsIterativeFiltering()) { + PushDownUtils.getPartitionPredicateSchema(table, output).map { partitionFields => PushDownUtils.createRuntimePartitionPredicates(runtimeFilters, partitionFields) - if (partPredicates.nonEmpty) { - filterableScan.filter(partPredicates.toArray) - filtered = true - } + }.getOrElse(Seq.empty) + } else { + Seq.empty + } + if (partPredicates.nonEmpty) { + filterableScan.filter(partPredicates.toArray) } + + dataSourceFilters ++ partPredicates + } else { + Seq.empty } - if (filtered) { + if (pushedFilters.nonEmpty) { // call toBatch again to get filtered partitions val newPartitions = scan.toBatch.planInputPartitions() From 2ba33ff2a1c192682e544c7a8758d0a0ebfb5089 Mon Sep 17 00:00:00 2001 From: Szehon Ho Date: Fri, 17 Apr 2026 10:59:20 -0700 Subject: [PATCH 3/8] [SPARK-56521][SQL] Refactor pushRuntimeFilters into BatchScanExec companion object Extract runtime filter pushing logic from filteredPartitions into a companion object method with a pattern match guard, removing the asInstanceOf cast. --- .../datasources/v2/BatchScanExec.scala | 79 +++++++++++-------- 1 file changed, 48 insertions(+), 31 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala index 66985331dfab3..b5144a329f0b6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala @@ -62,38 +62,9 @@ case class BatchScanExec( // Visible for testing @transient private[sql] lazy val filteredPartitions: Seq[Option[InputPartition]] = { val originalPartitioning = outputPartitioning - val pushedFilters = if (runtimeFilters.nonEmpty) { - // the cast is safe as runtime filters are only assigned if the scan can be filtered - val filterableScan = scan.asInstanceOf[SupportsRuntimeV2Filtering] - - // push down translatable runtime filters - val dataSourceFilters = runtimeFilters.flatMap { - case DynamicPruningExpression(e) => DataSourceV2Strategy.translateRuntimeFilterV2(e) - case f => DataSourceV2Strategy.translateScalarSubqueryFilterV2(f) - } - if (dataSourceFilters.nonEmpty) { - filterableScan.filter(dataSourceFilters.toArray) - } - - // If the scan supports iterative filtering, derive PartitionPredicates from the - // runtime filters and push them in a second pass. (See SPARK-55596) - val partPredicates = if (filterableScan.supportsIterativeFiltering()) { - PushDownUtils.getPartitionPredicateSchema(table, output).map { partitionFields => - PushDownUtils.createRuntimePartitionPredicates(runtimeFilters, partitionFields) - }.getOrElse(Seq.empty) - } else { - Seq.empty - } - if (partPredicates.nonEmpty) { - filterableScan.filter(partPredicates.toArray) - } - - dataSourceFilters ++ partPredicates - } else { - Seq.empty - } + val filtered = BatchScanExec.pushRuntimeFilters(scan, runtimeFilters, table, output) - if (pushedFilters.nonEmpty) { + if (filtered) { // call toBatch again to get filtered partitions val newPartitions = scan.toBatch.planInputPartitions() @@ -182,3 +153,49 @@ case class BatchScanExec( s"BatchScan ${table.name()}".trim } } + +object BatchScanExec { + + /** + * Pushes runtime filters to a [[SupportsRuntimeV2Filtering]] scan. 
Translatable filters are + * pushed first, followed by + * [[org.apache.spark.sql.connector.expressions.filter.PartitionPredicate]] + * instances if the scan supports iterative filtering. + * + * @return true if any filters were actually pushed to the data source + */ + private[sql] def pushRuntimeFilters( + scan: Scan, + runtimeFilters: Seq[Expression], + table: Table, + output: Seq[AttributeReference]): Boolean = { + scan match { + case filterableScan: SupportsRuntimeV2Filtering if runtimeFilters.nonEmpty => + // push down translatable runtime filters + val dataSourceFilters = runtimeFilters.flatMap { + case DynamicPruningExpression(e) => DataSourceV2Strategy.translateRuntimeFilterV2(e) + case f => DataSourceV2Strategy.translateScalarSubqueryFilterV2(f) + } + if (dataSourceFilters.nonEmpty) { + filterableScan.filter(dataSourceFilters.toArray) + } + + // If the scan supports iterative filtering, derive PartitionPredicates from the + // runtime filters and push them in a second pass. (See SPARK-55596) + val partPredicates = if (filterableScan.supportsIterativeFiltering()) { + PushDownUtils.getPartitionPredicateSchema(table, output).map { partitionFields => + PushDownUtils.createRuntimePartitionPredicates(runtimeFilters, partitionFields) + }.getOrElse(Seq.empty) + } else { + Seq.empty + } + if (partPredicates.nonEmpty) { + filterableScan.filter(partPredicates.toArray) + } + + dataSourceFilters.nonEmpty || partPredicates.nonEmpty + case _ => + false + } + } +} From f5aa92aba3fe638a3ee5ea0d78050b4c15e9a4f3 Mon Sep 17 00:00:00 2001 From: Szehon Ho Date: Fri, 17 Apr 2026 13:58:44 -0700 Subject: [PATCH 4/8] [SPARK-56521][SQL] Move pushRuntimeFilters to PushDownUtils Move the runtime filter pushing logic from the BatchScanExec companion object to PushDownUtils, co-locating it with the related partition predicate helpers. --- .../datasources/v2/BatchScanExec.scala | 48 +------------------ .../datasources/v2/PushDownUtils.scala | 46 +++++++++++++++++- 2 files changed, 46 insertions(+), 48 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala index b5144a329f0b6..28f4d12d366b2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala @@ -62,8 +62,8 @@ case class BatchScanExec( // Visible for testing @transient private[sql] lazy val filteredPartitions: Seq[Option[InputPartition]] = { val originalPartitioning = outputPartitioning - val filtered = BatchScanExec.pushRuntimeFilters(scan, runtimeFilters, table, output) + val filtered = PushDownUtils.pushRuntimeFilters(scan, runtimeFilters, table, output) if (filtered) { // call toBatch again to get filtered partitions val newPartitions = scan.toBatch.planInputPartitions() @@ -153,49 +153,3 @@ case class BatchScanExec( s"BatchScan ${table.name()}".trim } } - -object BatchScanExec { - - /** - * Pushes runtime filters to a [[SupportsRuntimeV2Filtering]] scan. Translatable filters are - * pushed first, followed by - * [[org.apache.spark.sql.connector.expressions.filter.PartitionPredicate]] - * instances if the scan supports iterative filtering. 
- * - * @return true if any filters were actually pushed to the data source - */ - private[sql] def pushRuntimeFilters( - scan: Scan, - runtimeFilters: Seq[Expression], - table: Table, - output: Seq[AttributeReference]): Boolean = { - scan match { - case filterableScan: SupportsRuntimeV2Filtering if runtimeFilters.nonEmpty => - // push down translatable runtime filters - val dataSourceFilters = runtimeFilters.flatMap { - case DynamicPruningExpression(e) => DataSourceV2Strategy.translateRuntimeFilterV2(e) - case f => DataSourceV2Strategy.translateScalarSubqueryFilterV2(f) - } - if (dataSourceFilters.nonEmpty) { - filterableScan.filter(dataSourceFilters.toArray) - } - - // If the scan supports iterative filtering, derive PartitionPredicates from the - // runtime filters and push them in a second pass. (See SPARK-55596) - val partPredicates = if (filterableScan.supportsIterativeFiltering()) { - PushDownUtils.getPartitionPredicateSchema(table, output).map { partitionFields => - PushDownUtils.createRuntimePartitionPredicates(runtimeFilters, partitionFields) - }.getOrElse(Seq.empty) - } else { - Seq.empty - } - if (partPredicates.nonEmpty) { - filterableScan.filter(partPredicates.toArray) - } - - dataSourceFilters.nonEmpty || partPredicates.nonEmpty - case _ => - false - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala index 8cab3855a7bb1..85b4b522f22e9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.util.CharVarcharUtils import org.apache.spark.sql.connector.catalog.Table import org.apache.spark.sql.connector.expressions.{IdentityTransform, SortOrder} import org.apache.spark.sql.connector.expressions.filter.Predicate -import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownFilters, SupportsPushDownLimit, SupportsPushDownOffset, SupportsPushDownRequiredColumns, SupportsPushDownTableSample, SupportsPushDownTopN, SupportsPushDownV2Filters} +import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownFilters, SupportsPushDownLimit, SupportsPushDownOffset, SupportsPushDownRequiredColumns, SupportsPushDownTableSample, SupportsPushDownTopN, SupportsPushDownV2Filters, SupportsRuntimeV2Filtering} import org.apache.spark.sql.execution.{ScalarSubquery => ExecScalarSubquery} import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, DataSourceUtils} import org.apache.spark.sql.internal.SQLConf @@ -131,6 +131,50 @@ object PushDownUtils extends Logging { } } + /** + * Pushes runtime filters to a [[SupportsRuntimeV2Filtering]] scan. + * Translatable filters are pushed first, followed by [[PartitionPredicate]] if the scan supports + * iterative filtering. + * + * @return true if any filters were pushed to the data source + */ + def pushRuntimeFilters( + scan: Scan, + runtimeFilters: Seq[Expression], + table: Table, + output: Seq[AttributeReference]): Boolean = { + scan match { + case filterableScan: SupportsRuntimeV2Filtering if runtimeFilters.nonEmpty => + + // First push down translatable runtime filters. 
+ val translatedFilters = runtimeFilters.flatMap { + case DynamicPruningExpression(e) => DataSourceV2Strategy.translateRuntimeFilterV2(e) + case f => DataSourceV2Strategy.translateScalarSubqueryFilterV2(f) + } + if (translatedFilters.nonEmpty) { + filterableScan.filter(translatedFilters.toArray) + } + + // If the scan supports iterative filtering, derive PartitionPredicates + // and push them in a second pass. (See SPARK-55596) + val partPredicates = + if (filterableScan.supportsIterativeFiltering()) { + getPartitionPredicateSchema(table, output).map { + fields => createRuntimePartitionPredicates(runtimeFilters, fields) + }.getOrElse(Seq.empty) + } else { + Seq.empty + } + if (partPredicates.nonEmpty) { + filterableScan.filter(partPredicates.toArray) + } + + translatedFilters.nonEmpty || partPredicates.nonEmpty + case _ => + false + } + } + /** * Returns a Seq of [[PartitionPredicateField]] representing partition transform expression types, * if schema is supported for [[PartitionPredicate]] push down. None if not supported. From 3f5dd946882e03e44be234b791b0dd4b6643c244 Mon Sep 17 00:00:00 2001 From: Szehon Ho Date: Fri, 17 Apr 2026 14:52:29 -0700 Subject: [PATCH 5/8] [SPARK-56521][SQL] Add pushedPredicates() to SupportsRuntimeV2Filtering Add a pushedPredicates() API to SupportsRuntimeV2Filtering, mirroring SupportsPushDownV2Filters. Use it in pushRuntimeFilters to exclude already-pushed predicates from the second pass and to determine whether replanning is needed. --- .../read/SupportsRuntimeV2Filtering.java | 35 +++++++++++++-- ...yEnhancedRuntimePartitionFilterTable.scala | 20 ++++++--- .../datasources/v2/PushDownUtils.scala | 44 ++++++++++++------- 3 files changed, 74 insertions(+), 25 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsRuntimeV2Filtering.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsRuntimeV2Filtering.java index d158ab413649a..bee1d3b7e9d19 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsRuntimeV2Filtering.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsRuntimeV2Filtering.java @@ -30,7 +30,12 @@ * data source V2 {@link Predicate} instead of data source V1 {@link Filter}. * {@link SupportsRuntimeV2Filtering} is preferred over {@link SupportsRuntimeFiltering} * and only one of them should be implemented by the data sources. - * + *
+ * Iterative filtering: When {@link #supportsIterativeFiltering()} returns true, + * {@link #filter(Predicate[])} may be called multiple times on the same + * {@link Scan} instance with additional predicates (e.g. {@link PartitionPredicate}). + * The implementation must accumulate state across all calls, and + * {@link #pushedPredicates()} must return predicates from all of them. *
* Note that Spark will push runtime filters only if they are beneficial. * @@ -60,19 +65,43 @@ public interface SupportsRuntimeV2Filtering extends Scan { * partition values (omitting those with no data) via {@link Batch#planInputPartitions()}. The * scan must not report new partition values that were not present in the original partitioning. *
+ * This method may be called multiple times with additional predicates (e.g. + * {@link PartitionPredicate}) when {@link #supportsIterativeFiltering()} returns true. + * The implementation must accumulate state across all calls so that + * {@link #pushedPredicates()} can return predicates from all of them. + *
* Note that Spark will call {@link Scan#toBatch()} again after filtering the scan at runtime. * * @param predicates data source V2 predicates used to filter the scan at runtime */ void filter(Predicate[] predicates); + /** + * Returns the predicates that are pushed to the data source via + * {@link #filter(Predicate[])}. + *
+ * When iterative filtering is supported and {@link #filter(Predicate[])} was called + * multiple times, this method must return predicates from all calls. + *
+ * It's possible that there are no runtime predicates and + * {@link #filter(Predicate[])} is never called; + * an empty array should be returned for this case. + * + * @since 4.2.0 + */ + default Predicate[] pushedPredicates() { + return new Predicate[0]; + } + /** * Returns true if this scan supports iterative runtime filtering. When true, - * {@link #filter(Predicate[])} may be called multiple times with additional predicates. + * {@link #filter(Predicate[])} may be called multiple times with additional + * predicates. The implementation must accumulate state across all calls, + * and {@link #pushedPredicates()} must return predicates from all of them. + * See the class-level Javadoc for the full contract. *
* When enabled, Spark will derive {@link PartitionPredicate} instances from the runtime * filters and push them via a subsequent {@link #filter(Predicate[])} call. - *
* * @since 4.2.0 */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryEnhancedRuntimePartitionFilterTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryEnhancedRuntimePartitionFilterTable.scala index 4825a7f7a3d6e..4f1a1eb8ce023 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryEnhancedRuntimePartitionFilterTable.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryEnhancedRuntimePartitionFilterTable.scala @@ -61,27 +61,35 @@ class InMemoryEnhancedRuntimePartitionFilterTable( extends BatchScanBaseClass(_data, readSchema, tableSchema) with SupportsRuntimeV2Filtering { - private val _pushedPartitionPredicates = ArrayBuffer.empty[PartitionPredicate] + private val _allPushedPredicates = ArrayBuffer.empty[Predicate] def pushedPartitionPredicates: Seq[PartitionPredicate] = - _pushedPartitionPredicates.toSeq + _allPushedPredicates.collect { + case pp: PartitionPredicate => pp + }.toSeq + + override def pushedPredicates(): Array[Predicate] = + _allPushedPredicates.toArray override def supportsIterativeFiltering(): Boolean = true override def filterAttributes(): Array[NamedReference] = { val scanFields = readSchema.fields.map(_.name).toSet partitioning.flatMap(_.references()) - .filter(ref => scanFields.contains(ref.fieldNames.mkString("."))) + .filter(ref => scanFields.contains( + ref.fieldNames.mkString("."))) } override def filter(filters: Array[Predicate]): Unit = { filters.foreach { case pp: PartitionPredicate => - _pushedPartitionPredicates += pp + _allPushedPredicates += pp data = data.filter { partition => - pp.eval(partition.asInstanceOf[BufferedRows].partitionKey()) + pp.eval(partition + .asInstanceOf[BufferedRows].partitionKey()) } - case _ => + case other => + _allPushedPredicates += other } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala index 85b4b522f22e9..1474523199ff9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala @@ -132,9 +132,12 @@ object PushDownUtils extends Logging { } /** - * Pushes runtime filters to a [[SupportsRuntimeV2Filtering]] scan. - * Translatable filters are pushed first, followed by [[PartitionPredicate]] if the scan supports - * iterative filtering. + * Pushes runtime filters to a [[SupportsRuntimeV2Filtering]] + * scan. Translatable filters are pushed first, followed by + * [[PartitionPredicate]] if the scan supports iterative + * filtering. Already-pushed predicates (reported by + * [[SupportsRuntimeV2Filtering#pushedPredicates]]) are excluded + * from subsequent passes. * * @return true if any filters were pushed to the data source */ @@ -144,32 +147,41 @@ object PushDownUtils extends Logging { table: Table, output: Seq[AttributeReference]): Boolean = { scan match { - case filterableScan: SupportsRuntimeV2Filtering if runtimeFilters.nonEmpty => + case filterableScan: SupportsRuntimeV2Filtering + if runtimeFilters.nonEmpty => // First push down translatable runtime filters. 
val translatedFilters = runtimeFilters.flatMap { - case DynamicPruningExpression(e) => DataSourceV2Strategy.translateRuntimeFilterV2(e) - case f => DataSourceV2Strategy.translateScalarSubqueryFilterV2(f) + case DynamicPruningExpression(e) => + DataSourceV2Strategy.translateRuntimeFilterV2(e) + case f => + DataSourceV2Strategy + .translateScalarSubqueryFilterV2(f) } if (translatedFilters.nonEmpty) { filterableScan.filter(translatedFilters.toArray) } - // If the scan supports iterative filtering, derive PartitionPredicates - // and push them in a second pass. (See SPARK-55596) - val partPredicates = - if (filterableScan.supportsIterativeFiltering()) { + // If the scan supports iterative filtering, derive + // PartitionPredicates and push them in a second pass, + // excluding predicates already accepted by the scan. + // (See SPARK-55596) + if (filterableScan.supportsIterativeFiltering()) { + val alreadyPushed = + filterableScan.pushedPredicates().toSet + val partPredicates = getPartitionPredicateSchema(table, output).map { - fields => createRuntimePartitionPredicates(runtimeFilters, fields) + fields => + createRuntimePartitionPredicates( + runtimeFilters, fields) }.getOrElse(Seq.empty) - } else { - Seq.empty + .filterNot(alreadyPushed.contains) + if (partPredicates.nonEmpty) { + filterableScan.filter(partPredicates.toArray) } - if (partPredicates.nonEmpty) { - filterableScan.filter(partPredicates.toArray) } - translatedFilters.nonEmpty || partPredicates.nonEmpty + filterableScan.pushedPredicates().nonEmpty case _ => false } From 9a2cdce0eb7c271fd9d95602d321197aeda456e4 Mon Sep 17 00:00:00 2001 From: Szehon Ho Date: Fri, 17 Apr 2026 16:53:07 -0700 Subject: [PATCH 6/8] [SPARK-56521][SQL] Refine pushRuntimeFilters: filterAttributes gate, pushedPredicates dedup, and comprehensive tests - Use pushedPredicates() to avoid deriving PartitionPredicates from runtime filters whose V2 translation was already accepted in the first filter() pass, preventing duplicate pushdown. - Gate PartitionPredicate candidates on filterAttributes(), consistent with PartitionPruning's planning-time check, using ExprId-based AttributeSet.subsetOf comparison. - Reorganize test suite into 12 numbered cases (with subcases) covering all combinations of DPP/scalar, translated/untranslatable, accepted/rejected, partition/data column, and filterAttributes. - Add configurable test table properties (accept-v2-predicates, filter-attributes) for targeted scenario testing. 
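
For connector authors, a minimal, illustrative sketch (not part of this patch) of a scan
opting into the contract above. The class name, constructor parameters, and the pruning
comments are assumptions for illustration; only the SupportsRuntimeV2Filtering methods
added in this series and the existing Expressions.column helper are real APIs.

    import scala.collection.mutable.ArrayBuffer

    import org.apache.spark.sql.connector.expressions.{Expressions, NamedReference}
    import org.apache.spark.sql.connector.expressions.filter.{PartitionPredicate, Predicate}
    import org.apache.spark.sql.connector.read.{Scan, SupportsRuntimeV2Filtering}
    import org.apache.spark.sql.types.StructType

    // Hypothetical connector scan opting into iterative runtime filtering.
    class ExampleIterativeScan(schema: StructType, partitionColumns: Seq[String])
      extends Scan with SupportsRuntimeV2Filtering {

      // Accumulates predicates from every filter() call, as the contract requires.
      private val accepted = ArrayBuffer.empty[Predicate]

      override def readSchema(): StructType = schema

      // Only columns reported here are considered for runtime filtering.
      override def filterAttributes(): Array[NamedReference] =
        partitionColumns.map(c => Expressions.column(c)).toArray

      // Opting in: filter() may be invoked again with derived PartitionPredicates.
      override def supportsIterativeFiltering(): Boolean = true

      // Spark consults this to skip predicates already accepted in an earlier pass.
      override def pushedPredicates(): Array[Predicate] = accepted.toArray

      override def filter(predicates: Array[Predicate]): Unit = predicates.foreach {
        case pp: PartitionPredicate =>
          // A real connector would evaluate pp against its partition keys and prune here.
          accepted += pp
        case _ =>
          // Predicates this source cannot apply are simply not reported as pushed.
      }
    }

Keeping a single accumulating buffer behind pushedPredicates() is what lets the second
pass skip predicates that were already accepted in the first pass.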
--- ...yEnhancedRuntimePartitionFilterTable.scala | 47 ++- .../datasources/v2/PushDownUtils.scala | 60 ++-- ...2EnhancedRuntimePartitionFilterSuite.scala | 281 ++++++++++++++---- 3 files changed, 285 insertions(+), 103 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryEnhancedRuntimePartitionFilterTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryEnhancedRuntimePartitionFilterTable.scala index 4f1a1eb8ce023..ed70d10057d39 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryEnhancedRuntimePartitionFilterTable.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryEnhancedRuntimePartitionFilterTable.scala @@ -32,6 +32,12 @@ import org.apache.spark.util.ArrayImplicits._ * In-memory table whose batch scan implements [[SupportsRuntimeV2Filtering]] with * iterative filtering support, so that [[PartitionPredicate]] instances derived from * runtime filters are pushed via a second [[SupportsRuntimeV2Filtering#filter]] call. + * + * Table properties: + * - `accept-v2-predicates` (default `false`): when true, non-PartitionPredicate + * V2 predicates are reported via `pushedPredicates()` (i.e. accepted). + * - `filter-attributes` (default: all partition cols): comma-separated list of + * column names to expose from `filterAttributes()`. */ class InMemoryEnhancedRuntimePartitionFilterTable( name: String, @@ -63,6 +69,19 @@ class InMemoryEnhancedRuntimePartitionFilterTable( private val _allPushedPredicates = ArrayBuffer.empty[Predicate] + private val acceptV2Predicates = + InMemoryEnhancedRuntimePartitionFilterTable.this.properties + .getOrDefault( + InMemoryEnhancedRuntimePartitionFilterTable + .AcceptV2PredicatesKey, "false").toBoolean + + private val restrictedFilterAttrs: Option[Set[String]] = + Option(InMemoryEnhancedRuntimePartitionFilterTable.this + .properties.get( + InMemoryEnhancedRuntimePartitionFilterTable + .FilterAttributesKey)) + .map(_.split(",").map(_.trim).toSet) + def pushedPartitionPredicates: Seq[PartitionPredicate] = _allPushedPredicates.collect { case pp: PartitionPredicate => pp @@ -75,9 +94,11 @@ class InMemoryEnhancedRuntimePartitionFilterTable( override def filterAttributes(): Array[NamedReference] = { val scanFields = readSchema.fields.map(_.name).toSet - partitioning.flatMap(_.references()) - .filter(ref => scanFields.contains( - ref.fieldNames.mkString("."))) + partitioning.flatMap(_.references()).filter { ref => + val name = ref.fieldNames.mkString(".") + scanFields.contains(name) && + restrictedFilterAttrs.forall(_.contains(name)) + } } override def filter(filters: Array[Predicate]): Unit = { @@ -85,12 +106,26 @@ class InMemoryEnhancedRuntimePartitionFilterTable( case pp: PartitionPredicate => _allPushedPredicates += pp data = data.filter { partition => - pp.eval(partition - .asInstanceOf[BufferedRows].partitionKey()) + pp.eval( + partition.asInstanceOf[BufferedRows].partitionKey()) } case other => - _allPushedPredicates += other + if (acceptV2Predicates) _allPushedPredicates += other } } } } + +object InMemoryEnhancedRuntimePartitionFilterTable { + /** + * Table property: when "true", non-PartitionPredicate V2 predicates + * pushed via filter() are reported in pushedPredicates() (accepted). + */ + private[catalog] val AcceptV2PredicatesKey = "accept-v2-predicates" + + /** + * Table property: comma-separated column names to expose from + * filterAttributes(). Default: all partition columns. 
+ */ + private[catalog] val FilterAttributesKey = "filter-attributes" +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala index 1474523199ff9..e2f1f7c660126 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala @@ -132,12 +132,10 @@ object PushDownUtils extends Logging { } /** - * Pushes runtime filters to a [[SupportsRuntimeV2Filtering]] - * scan. Translatable filters are pushed first, followed by - * [[PartitionPredicate]] if the scan supports iterative - * filtering. Already-pushed predicates (reported by - * [[SupportsRuntimeV2Filtering#pushedPredicates]]) are excluded - * from subsequent passes. + * Pushes runtime filters to a [[SupportsRuntimeV2Filtering]] scan. Translatable filters are + * pushed first, followed by [[PartitionPredicate]] if the scan supports iterative filtering. + * Only runtime filters that were not already translated are used to derive PartitionPredicates + * in the second pass, avoiding duplicate pushdown. * * @return true if any filters were pushed to the data source */ @@ -147,35 +145,33 @@ object PushDownUtils extends Logging { table: Table, output: Seq[AttributeReference]): Boolean = { scan match { - case filterableScan: SupportsRuntimeV2Filtering - if runtimeFilters.nonEmpty => - - // First push down translatable runtime filters. - val translatedFilters = runtimeFilters.flatMap { - case DynamicPruningExpression(e) => - DataSourceV2Strategy.translateRuntimeFilterV2(e) - case f => - DataSourceV2Strategy - .translateScalarSubqueryFilterV2(f) - } - if (translatedFilters.nonEmpty) { - filterableScan.filter(translatedFilters.toArray) + case filterableScan: SupportsRuntimeV2Filtering if runtimeFilters.nonEmpty => + // Push down translatable runtime filters. + val filtersToTranslated = runtimeFilters.flatMap { f => + (f match { + case DynamicPruningExpression(e) => DataSourceV2Strategy.translateRuntimeFilterV2(e) + case o => DataSourceV2Strategy.translateScalarSubqueryFilterV2(o) + }).map(f -> _) + }.toMap + + if (filtersToTranslated.nonEmpty) { + filterableScan.filter(filtersToTranslated.values.toArray) } - // If the scan supports iterative filtering, derive - // PartitionPredicates and push them in a second pass, - // excluding predicates already accepted by the scan. - // (See SPARK-55596) + // If the scan supports iterative filtering, derive PartitionPredicates from runtime + // filters whose translation was not already accepted in the first pass. (See SPARK-55596) + // Only candidates whose referenced columns are declared in filterAttributes() are eligible. 
if (filterableScan.supportsIterativeFiltering()) { - val alreadyPushed = - filterableScan.pushedPredicates().toSet - val partPredicates = - getPartitionPredicateSchema(table, output).map { - fields => - createRuntimePartitionPredicates( - runtimeFilters, fields) - }.getOrElse(Seq.empty) - .filterNot(alreadyPushed.contains) + val filterAttrs = AttributeSet(filterableScan.filterAttributes() + .flatMap(r => output.find(a => SQLConf.get.resolver(a.name, r.fieldNames.head)))) + val pushed = filterableScan.pushedPredicates().toSet + val candidates = runtimeFilters.filter { f => + !filtersToTranslated.get(f).exists(pushed.contains) && + f.references.subsetOf(filterAttrs) + } + val partPredicates = getPartitionPredicateSchema(table, output) + .map(createRuntimePartitionPredicates(candidates, _)) + .getOrElse(Seq.empty) if (partPredicates.nonEmpty) { filterableScan.filter(partPredicates.toArray) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2EnhancedRuntimePartitionFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2EnhancedRuntimePartitionFilterSuite.scala index 69392a3439d65..8e3a645be0630 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2EnhancedRuntimePartitionFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2EnhancedRuntimePartitionFilterSuite.scala @@ -33,6 +33,25 @@ import org.apache.spark.sql.test.SharedSparkSession * Tests that [[PartitionPredicate]] instances are pushed via iterative runtime filtering * (second [[SupportsRuntimeV2Filtering#filter]] call) for both DPP and scalar subquery * runtime filters. + * + * Pushdown cases (DPP / Scalar Subquery, Translated / Untranslatable, + * Accepted / Rejected, Partition / Data Column, In filterAttributes): + * + * PartitionPredicate IS created: + * 1. DPP, translated, rejected in 1st pass, partition col -> PartitionPredicate + * 2. DPP, translated, rejected in 1st pass, non-first partition col -> PartitionPredicate + * 3. Scalar, translatable, rejected in 1st pass, partition col -> PartitionPredicate + * 4. Scalar, untranslatable, partition col -> PartitionPredicate + * 5. Scalar, two subqueries on two partition cols -> 2 PartitionPredicates + * 6. Scalar, non-first of 3 partition cols -> PartitionPredicate + * 7. Mixed: 1st pass accepted + untranslatable -> only untranslatable gets PartitionPredicate + * + * PartitionPredicate is NOT created: + * 8. DPP, translated, accepted in 1st pass -> no PartitionPredicate + * 9. Scalar, translatable, accepted in 1st pass -> no PartitionPredicate + * 10. Scalar on data column -> no PartitionPredicate + * 11. supportsIterativeFiltering is false -> no PartitionPredicate + * 12. 
Partition col not in filterAttributes -> no PartitionPredicate */ class DataSourceV2EnhancedRuntimePartitionFilterSuite extends QueryTest with SharedSparkSession with BeforeAndAfter { @@ -49,7 +68,18 @@ class DataSourceV2EnhancedRuntimePartitionFilterSuite spark.sessionState.catalogManager.reset() } - test("DPP: PartitionPredicate pushed via iterative runtime filtering") { + private def withDPPConf(f: => Unit): Unit = { + withSQLConf( + SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "true", + SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "false", + SQLConf.DYNAMIC_PARTITION_PRUNING_FALLBACK_FILTER_RATIO.key -> "10")(f) + } + + // --------------------------------------------------------------------------- + // PartitionPredicate IS created + // --------------------------------------------------------------------------- + + test("case 1: DPP translated, rejected in 1st pass -> PartitionPredicate") { val fact = s"$catalogName.fact" val dim = s"$catalogName.dim" withTable(fact, dim) { @@ -60,17 +90,13 @@ class DataSourceV2EnhancedRuntimePartitionFilterSuite sql(s"CREATE TABLE $dim (dim_id INT, dim_val STRING) USING $v2Source") sql(s"INSERT INTO $dim VALUES (2, 'two')") - withSQLConf( - SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "true", - SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "false", - SQLConf.DYNAMIC_PARTITION_PRUNING_FALLBACK_FILTER_RATIO.key -> "10") { + withDPPConf { val df = sql( s"""SELECT f.id, f.part FROM $fact f JOIN $dim d |ON f.part = d.dim_id WHERE d.dim_val = 'two'""".stripMargin) checkAnswer(df, Row(2, 2)) assertDPPRuntimeFilters(df) - assertPushedPartitionPredicates(df, 1) assertScanReturnsPartitionKeys(df, Set("2")) assertReferencedPartitionFieldOrdinals(df, Array(0), Array("part")) @@ -78,7 +104,7 @@ class DataSourceV2EnhancedRuntimePartitionFilterSuite } } - test("DPP: PartitionPredicate on non-first partition column") { + test("case 2: DPP translated, rejected, non-first partition col -> PartitionPredicate") { val fact = s"$catalogName.fact2" val dim = s"$catalogName.dim2" withTable(fact, dim) { @@ -87,16 +113,10 @@ class DataSourceV2EnhancedRuntimePartitionFilterSuite for (i <- 0 until 5; j <- 0 until 2) { sql(s"INSERT INTO $fact VALUES (${i * 2 + j}, $i, $j)") } - sql(s"CREATE TABLE $dim (dim_id INT, dim_val STRING) " + - s"USING $v2Source") + sql(s"CREATE TABLE $dim (dim_id INT, dim_val STRING) USING $v2Source") sql(s"INSERT INTO $dim VALUES (1, 'one')") - withSQLConf( - SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "true", - SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> - "false", - SQLConf.DYNAMIC_PARTITION_PRUNING_FALLBACK_FILTER_RATIO.key -> - "10") { + withDPPConf { val df = sql( s"""SELECT f.id, f.p1, f.p2 FROM $fact f JOIN $dim d |ON f.p2 = d.dim_id @@ -106,15 +126,13 @@ class DataSourceV2EnhancedRuntimePartitionFilterSuite Row(7, 3, 1), Row(9, 4, 1))) assertDPPRuntimeFilters(df) - assertPushedPartitionPredicates(df, 1) - assertReferencedPartitionFieldOrdinals( - df, Array(1), Array("p1", "p2")) + assertReferencedPartitionFieldOrdinals(df, Array(1), Array("p1", "p2")) } } } - test("scalar subquery: PartitionPredicate pushed via iterative runtime filtering") { + test("case 3: scalar subquery translatable, rejected in 1st pass -> PartitionPredicate") { val tbl = s"$catalogName.tbl" val dim = s"$catalogName.dim" withTable(tbl, dim) { @@ -129,14 +147,13 @@ class DataSourceV2EnhancedRuntimePartitionFilterSuite checkAnswer(df, Row(3, 3)) assertScalarSubqueryRuntimeFilters(df) - 
assertPushedPartitionPredicates(df, 1) assertScanReturnsPartitionKeys(df, Set("3")) assertReferencedPartitionFieldOrdinals(df, Array(0), Array("part")) } } - test("scalar subquery: complex expression with arithmetic on subquery result") { + test("case 4a: scalar subquery untranslatable (complex expr) -> PartitionPredicate") { val tbl = s"$catalogName.tbl_complex" val dim = s"$catalogName.dim_complex" withTable(tbl, dim) { @@ -151,14 +168,13 @@ class DataSourceV2EnhancedRuntimePartitionFilterSuite checkAnswer(df, Row(4, 4)) assertScalarSubqueryRuntimeFilters(df) - assertPushedPartitionPredicates(df, 1) assertScanReturnsPartitionKeys(df, Set("4")) assertReferencedPartitionFieldOrdinals(df, Array(0), Array("part")) } } - test("scalar subquery: RLIKE (untranslatable) with subquery pattern") { + test("case 4b: scalar subquery untranslatable (RLIKE) -> PartitionPredicate") { val tbl = s"$catalogName.tbl_rlike" val dim = s"$catalogName.dim_rlike" withTable(tbl, dim) { @@ -175,14 +191,13 @@ class DataSourceV2EnhancedRuntimePartitionFilterSuite checkAnswer(df, Seq(Row(1, "abc"), Row(3, "abx"))) assertScalarSubqueryRuntimeFilters(df) - assertPushedPartitionPredicates(df, 1) assertScanReturnsPartitionKeys(df, Set("abc", "abx")) assertReferencedPartitionFieldOrdinals(df, Array(0), Array("part")) } } - test("scalar subquery: UDF on partition column with subquery value") { + test("case 4c: scalar subquery untranslatable (UDF) -> PartitionPredicate") { val tbl = s"$catalogName.tbl_udf" val dim = s"$catalogName.dim_udf" withTable(tbl, dim) { @@ -196,21 +211,21 @@ class DataSourceV2EnhancedRuntimePartitionFilterSuite sql(s"INSERT INTO $dim VALUES ('A')") spark.udf.register("my_upper_runtime", - (s: String) => if (s == null) null else s.toUpperCase(java.util.Locale.ROOT)) + (s: String) => if (s == null) null + else s.toUpperCase(java.util.Locale.ROOT)) val df = sql( s"SELECT * FROM $tbl WHERE my_upper_runtime(part) = (SELECT max(val) FROM $dim)") checkAnswer(df, Seq(Row(1, "a"), Row(2, "A"))) assertScalarSubqueryRuntimeFilters(df) - assertPushedPartitionPredicates(df, 1) assertScanReturnsPartitionKeys(df, Set("a", "A")) assertReferencedPartitionFieldOrdinals(df, Array(0), Array("part")) } } - test("scalar subquery: two PartitionPredicates for two subqueries on different partition cols") { + test("case 5: scalar subquery two partition cols -> 2 PartitionPredicates") { val fact = s"$catalogName.fact_2sub" val dim1 = s"$catalogName.dim_2sub1" val dim2 = s"$catalogName.dim_2sub2" @@ -235,24 +250,16 @@ class DataSourceV2EnhancedRuntimePartitionFilterSuite checkAnswer(df, Row(1, 1, "a")) assertScalarSubqueryRuntimeFilters(df, expectedCount = 2) - assertPushedPartitionPredicates(df, 2) assertScanReturnsPartitionKeys(df, Set("1/a")) - val predicates = getPushedPartitionPredicates(df) val partFieldNames = Array("p1", "p2") - val p1Pred = predicates.find(_.references().exists( - _.asInstanceOf[PartitionFieldReference].ordinal() == 0)) - val p2Pred = predicates.find(_.references().exists( - _.asInstanceOf[PartitionFieldReference].ordinal() == 1)) - assert(p1Pred.isDefined, "Expected a PartitionPredicate referencing p1 (ordinal 0)") - assert(p2Pred.isDefined, "Expected a PartitionPredicate referencing p2 (ordinal 1)") - assertPartitionPredicateOrdinals(p1Pred.get, Array(0), partFieldNames) - assertPartitionPredicateOrdinals(p2Pred.get, Array(1), partFieldNames) + assertPredicateForOrdinal(df, 0, partFieldNames) + assertPredicateForOrdinal(df, 1, partFieldNames) } } - test("scalar subquery: PartitionPredicate on 
non-first column of three-partition-column table") { + test("case 6: scalar subquery non-first of 3 partition cols -> PartitionPredicate") { val tbl = s"$catalogName.tbl_3part" val dim = s"$catalogName.dim_3part" withTable(tbl, dim) { @@ -271,14 +278,109 @@ class DataSourceV2EnhancedRuntimePartitionFilterSuite checkAnswer(df, Seq(Row(1, 1, "a", 10), Row(3, 2, "a", 20))) assertScalarSubqueryRuntimeFilters(df) - assertPushedPartitionPredicates(df, 1) assertScanReturnsPartitionKeys(df, Set("1/a/10", "2/a/20")) - assertReferencedPartitionFieldOrdinals(df, Array(1), Array("p0", "p1", "p2")) + assertReferencedPartitionFieldOrdinals( + df, Array(1), Array("p0", "p1", "p2")) } } - test("no PartitionPredicate for scalar subquery on data column") { + test("case 7: mixed - accepted in 1st pass + untranslatable -> " + + "only untranslatable gets PartitionPredicate") { + val tbl = s"$catalogName.tbl_mixed" + val dim1 = s"$catalogName.dim_mixed1" + val dim2 = s"$catalogName.dim_mixed2" + withTable(tbl, dim1, dim2) { + sql(s"CREATE TABLE $tbl (id INT, p1 INT, p2 STRING) " + + s"USING $v2Source PARTITIONED BY (p1, p2) " + + "TBLPROPERTIES('accept-v2-predicates' = 'true')") + sql(s"INSERT INTO $tbl VALUES (1, 1, 'a')") + sql(s"INSERT INTO $tbl VALUES (2, 1, 'b')") + sql(s"INSERT INTO $tbl VALUES (3, 2, 'a')") + sql(s"INSERT INTO $tbl VALUES (4, 2, 'b')") + + sql(s"CREATE TABLE $dim1 (val INT) USING $v2Source") + sql(s"INSERT INTO $dim1 VALUES (1)") + sql(s"CREATE TABLE $dim2 (val STRING) USING $v2Source") + sql(s"INSERT INTO $dim2 VALUES ('A')") + + spark.udf.register("my_upper_mixed", + (s: String) => if (s == null) null + else s.toUpperCase(java.util.Locale.ROOT)) + + // p1 = (subquery) is translatable and accepted in 1st pass. + // my_upper_mixed(p2) = (subquery) is untranslatable -> PartitionPredicate. + // Only the untranslatable filter should produce a PartitionPredicate. + val df = sql( + s"""SELECT * FROM $tbl + |WHERE p1 = (SELECT max(val) FROM $dim1) + | AND my_upper_mixed(p2) = (SELECT max(val) FROM $dim2) + |""".stripMargin) + checkAnswer(df, Row(1, 1, "a")) + + assertScalarSubqueryRuntimeFilters(df, expectedCount = 2) + assertPushedPartitionPredicates(df, 1) + // The V2 predicate for p1=1 was accepted but not evaluated by the + // test table, so both p2='a' partitions remain after the + // PartitionPredicate. Spark applies p1=1 as a post-scan filter. 
+ assertScanReturnsPartitionKeys(df, Set("1/a", "2/a")) + assertReferencedPartitionFieldOrdinals(df, Array(1), Array("p1", "p2")) + } + } + + // --------------------------------------------------------------------------- + // PartitionPredicate is NOT created + // --------------------------------------------------------------------------- + + test("case 8: DPP translated, accepted in 1st pass -> no PartitionPredicate") { + val fact = s"$catalogName.fact_acc" + val dim = s"$catalogName.dim_acc" + withTable(fact, dim) { + sql(s"CREATE TABLE $fact (id INT, part INT) USING $v2Source " + + "PARTITIONED BY (part) " + + "TBLPROPERTIES('accept-v2-predicates' = 'true')") + for (i <- 0 until 5) { + sql(s"INSERT INTO $fact VALUES ($i, $i)") + } + sql(s"CREATE TABLE $dim (dim_id INT, dim_val STRING) USING $v2Source") + sql(s"INSERT INTO $dim VALUES (2, 'two')") + + withDPPConf { + val df = sql( + s"""SELECT f.id, f.part FROM $fact f JOIN $dim d + |ON f.part = d.dim_id WHERE d.dim_val = 'two'""".stripMargin) + checkAnswer(df, Row(2, 2)) + + assertDPPRuntimeFilters(df) + assertPushedPartitionPredicates(df, 0) + } + } + } + + test("case 9: scalar subquery translatable, accepted in 1st pass -> " + + "no PartitionPredicate") { + val tbl = s"$catalogName.tbl_acc" + val dim = s"$catalogName.dim_acc2" + withTable(tbl, dim) { + sql(s"CREATE TABLE $tbl (id INT, part INT) USING $v2Source " + + "PARTITIONED BY (part) " + + "TBLPROPERTIES('accept-v2-predicates' = 'true')") + for (i <- 0 until 5) { + sql(s"INSERT INTO $tbl VALUES ($i, $i)") + } + sql(s"CREATE TABLE $dim (val INT) USING $v2Source") + sql(s"INSERT INTO $dim VALUES (3)") + + val df = sql( + s"SELECT * FROM $tbl WHERE part = (SELECT max(val) FROM $dim)") + checkAnswer(df, Row(3, 3)) + + assertScalarSubqueryRuntimeFilters(df) + assertPushedPartitionPredicates(df, 0) + } + } + + test("case 10: scalar subquery on data column -> no PartitionPredicate") { val tbl = s"$catalogName.tbl_data" val dim = s"$catalogName.dim_data" withTable(tbl, dim) { @@ -298,7 +400,7 @@ class DataSourceV2EnhancedRuntimePartitionFilterSuite } } - test("no PartitionPredicate when supportsIterativeFiltering is false") { + test("case 11: supportsIterativeFiltering is false -> no PartitionPredicate") { val baseCatalog = "testv2filterNoIterative" spark.conf.set(s"spark.sql.catalog.$baseCatalog", classOf[catalog.InMemoryTableWithV2FilterCatalog].getName) @@ -306,28 +408,53 @@ class DataSourceV2EnhancedRuntimePartitionFilterSuite val tbl = s"$baseCatalog.tbl" val dim = s"$baseCatalog.dim" withTable(tbl, dim) { - sql(s"CREATE TABLE $tbl (id INT, part INT) USING $v2Source PARTITIONED BY (part)") + sql(s"CREATE TABLE $tbl (id INT, part INT) " + + s"USING $v2Source PARTITIONED BY (part)") for (i <- 0 until 5) { sql(s"INSERT INTO $tbl VALUES ($i, $i)") } sql(s"CREATE TABLE $dim (val INT) USING $v2Source") sql(s"INSERT INTO $dim VALUES (3)") - val df = sql(s"SELECT * FROM $tbl WHERE part = (SELECT max(val) FROM $dim)") + val df = sql( + s"SELECT * FROM $tbl WHERE part = (SELECT max(val) FROM $dim)") checkAnswer(df, Row(3, 3)) - val batchScan = collectBatchScan(df) - assert(batchScan.runtimeFilters.nonEmpty) + assertHasRuntimeFilters(df) + assertPushedPartitionPredicates(df, 0) + } + } - val scan = batchScan.scan - assert( - !scan.asInstanceOf[ - catalog.InMemoryTableWithV2Filter#InMemoryV2FilterBatchScan - ].supportsIterativeFiltering(), - "Base V2 filter table should not support iterative filtering") + test("case 12: partition col not in filterAttributes -> no PartitionPredicate") { + val 
tbl = s"$catalogName.tbl_noattr" + val dim = s"$catalogName.dim_noattr" + withTable(tbl, dim) { + sql(s"CREATE TABLE $tbl (id INT, p1 INT, p2 INT) " + + s"USING $v2Source PARTITIONED BY (p1, p2) " + + "TBLPROPERTIES('filter-attributes' = 'p1')") + sql(s"INSERT INTO $tbl VALUES (1, 1, 10)") + sql(s"INSERT INTO $tbl VALUES (2, 1, 20)") + sql(s"INSERT INTO $tbl VALUES (3, 2, 10)") + sql(s"INSERT INTO $tbl VALUES (4, 2, 20)") + + sql(s"CREATE TABLE $dim (val INT) USING $v2Source") + sql(s"INSERT INTO $dim VALUES (10)") + + // Scalar subquery on p2, which is NOT in filterAttributes. + // p2 is a partition column, but since it's not declared as a + // filterable attribute the PartitionPredicate should not be pushed. + val df = sql( + s"SELECT * FROM $tbl WHERE p2 = (SELECT max(val) FROM $dim)") + checkAnswer(df, Seq(Row(1, 1, 10), Row(3, 2, 10))) + + assertPushedPartitionPredicates(df, 0) } } + // --------------------------------------------------------------------------- + // Helper methods + // --------------------------------------------------------------------------- + private def assertDPPRuntimeFilters( df: DataFrame, expectedCount: Int = 1): Unit = { val batchScan = collectBatchScan(df) @@ -339,6 +466,11 @@ class DataSourceV2EnhancedRuntimePartitionFilterSuite s"in runtimeFilters, got ${dppFilters.size}") } + private def assertHasRuntimeFilters(df: DataFrame): Unit = { + assert(collectBatchScan(df).runtimeFilters.nonEmpty, + "Expected non-empty runtimeFilters on BatchScanExec") + } + private def assertScalarSubqueryRuntimeFilters( df: DataFrame, expectedCount: Int = 1): Unit = { val batchScan = collectBatchScan(df) @@ -361,7 +493,7 @@ class DataSourceV2EnhancedRuntimePartitionFilterSuite }.getOrElse(fail("Expected BatchScanExec in plan")) } - private def getPushedPartitionPredicates( + private[connector] def getPushedPartitionPredicates( df: DataFrame): Seq[PartitionPredicate] = { val batchScan = collectBatchScan(df) batchScan.scan match { @@ -386,7 +518,8 @@ class DataSourceV2EnhancedRuntimePartitionFilterSuite expectedOrdinals: Array[Int], expectedPartitionFieldNames: Array[String]): Unit = { val refs = predicate.references() - val ordinals = refs.map(_.asInstanceOf[PartitionFieldReference].ordinal()).sorted + val ordinals = + refs.map(_.asInstanceOf[PartitionFieldReference].ordinal()).sorted assert(ordinals.sameElements(expectedOrdinals.sorted), s"Expected references().map(_.ordinal()) " + s"${expectedOrdinals.sorted.mkString("[", ", ", "]")}, " + @@ -395,27 +528,42 @@ class DataSourceV2EnhancedRuntimePartitionFilterSuite val names = expectedPartitionFieldNames refs.foreach { ref => assert(ref.isInstanceOf[PartitionFieldReference], - s"Expected PartitionFieldReference, got ${ref.getClass.getName}") + s"Expected PartitionFieldReference, " + + s"got ${ref.getClass.getName}") val partRef = ref.asInstanceOf[PartitionFieldReference] assert(partRef.fieldNames().nonEmpty, - s"PartitionFieldReference.ordinal=${partRef.ordinal()} has empty fieldNames") + s"ordinal=${partRef.ordinal()} has empty fieldNames") assert(partRef.ordinal() < names.length, - s"PartitionFieldReference.ordinal=${partRef.ordinal()} " + - s"out of range for names length ${names.length}") + s"ordinal=${partRef.ordinal()} out of range " + + s"for names length ${names.length}") val expectedName = names(partRef.ordinal()) val actualName = partRef.fieldNames().mkString(".") assert(actualName === expectedName, - s"PartitionFieldReference.ordinal=${partRef.ordinal()}: " + - s"expected fieldNames '${expectedName}', got 
'${actualName}'") + s"ordinal=${partRef.ordinal()}: expected " + + s"fieldNames '$expectedName', got '$actualName'") } } + private def assertPredicateForOrdinal( + df: DataFrame, + ordinal: Int, + expectedPartitionFieldNames: Array[String]): Unit = { + val predicates = getPushedPartitionPredicates(df) + val pred = predicates.find(_.references().exists( + _.asInstanceOf[PartitionFieldReference].ordinal() == ordinal)) + assert(pred.isDefined, + s"Expected a PartitionPredicate referencing ordinal $ordinal") + assertPartitionPredicateOrdinals( + pred.get, Array(ordinal), expectedPartitionFieldNames) + } + private def assertReferencedPartitionFieldOrdinals( df: DataFrame, expectedOrdinals: Array[Int], expectedPartitionFieldNames: Array[String]): Unit = { getPushedPartitionPredicates(df).foreach { p => - assertPartitionPredicateOrdinals(p, expectedOrdinals, expectedPartitionFieldNames) + assertPartitionPredicateOrdinals( + p, expectedOrdinals, expectedPartitionFieldNames) } } @@ -425,9 +573,12 @@ class DataSourceV2EnhancedRuntimePartitionFilterSuite val batchScan = collectBatchScan(df) val partitions = batchScan.batch.planInputPartitions() assert(partitions.length === expectedPartitionKeys.size, - s"Expected ${expectedPartitionKeys.size} partition(s), got ${partitions.length}") - val partKeys = partitions.map(_.asInstanceOf[BufferedRows].keyString()).toSet + s"Expected ${expectedPartitionKeys.size} partition(s), " + + s"got ${partitions.length}") + val partKeys = + partitions.map(_.asInstanceOf[BufferedRows].keyString()).toSet assert(partKeys === expectedPartitionKeys, - s"Partition keys should be $expectedPartitionKeys, got $partKeys") + s"Partition keys should be $expectedPartitionKeys, " + + s"got $partKeys") } } From 6b07f10b4ef64f1b2cd18a2cf9649b7c559d7672 Mon Sep 17 00:00:00 2001 From: Szehon Ho Date: Wed, 22 Apr 2026 18:45:11 +0200 Subject: [PATCH 7/8] [SPARK-56521][SQL] Address review feedback: robustness, javadoc, and test improvements - Return translatedFiltersPushed || partPredicatesPushed instead of pushedPredicates().nonEmpty, so filter() side effects are visible even if the connector does not override pushedPredicates(). - Extract V2ExpressionUtils.resolveAttributeRefs to share resolution logic between PartitionPruning and PushDownUtils. - Clarify SupportsRuntimeV2Filtering javadoc: document two-pass call order and that the second pass excludes already-accepted filters. - Refactor case 11 to use the enhanced catalog with supports-iterative-filtering=false property and withSQLConf. - Add regression test for buggy connector that omits first-pass filters from pushedPredicates(). - Code tidying in InMemoryEnhancedRuntimePartitionFilterTable. --- .../read/SupportsRuntimeV2Filtering.java | 4 +- .../expressions/V2ExpressionUtils.scala | 12 +++- ...yEnhancedRuntimePartitionFilterTable.scala | 48 ++++++------- .../datasources/v2/PushDownUtils.scala | 19 ++--- .../dynamicpruning/PartitionPruning.scala | 8 +-- ...2EnhancedRuntimePartitionFilterSuite.scala | 70 ++++++++++++++----- 6 files changed, 101 insertions(+), 60 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsRuntimeV2Filtering.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsRuntimeV2Filtering.java index bee1d3b7e9d19..e0e76b246caac 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsRuntimeV2Filtering.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsRuntimeV2Filtering.java @@ -33,7 +33,9 @@ *

* Iterative filtering: When {@link #supportsIterativeFiltering()} returns true, * {@link #filter(Predicate[])} may be called multiple times on the same - * {@link Scan} instance with additional predicates (e.g. {@link PartitionPredicate}). + * {@link Scan} instance. The first call pushes translated V2 predicates; the second call + * pushes {@link PartitionPredicate} instances derived from runtime filters whose translated + * form was not already accepted (via {@link #pushedPredicates()}) in the first call. * The implementation must accumulate state across all calls, and * {@link #pushedPredicates()} must return predicates from all of them. *
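For connector authors, the contract above can be illustrated with a minimal scan sketch. The class and constructor parameters below are hypothetical and evaluation of PartitionPredicate against partition values is elided; the in-memory test table in this series follows the same shape.

    import scala.collection.mutable.ArrayBuffer

    import org.apache.spark.sql.connector.expressions.NamedReference
    import org.apache.spark.sql.connector.expressions.filter.{PartitionPredicate, Predicate}
    import org.apache.spark.sql.connector.read.SupportsRuntimeV2Filtering
    import org.apache.spark.sql.types.StructType

    // Hypothetical connector scan that opts into iterative runtime filtering.
    class ExamplePartitionedScan(
        schema: StructType,
        partitionRefs: Array[NamedReference]) extends SupportsRuntimeV2Filtering {

      // Every predicate accepted across all filter() calls.
      private val accepted = ArrayBuffer.empty[Predicate]

      override def readSchema(): StructType = schema

      // Opt into the second, PartitionPredicate pass.
      override def supportsIterativeFiltering(): Boolean = true

      // Columns Spark may use when deriving runtime filters and PartitionPredicates.
      override def filterAttributes(): Array[NamedReference] = partitionRefs

      override def filter(predicates: Array[Predicate]): Unit = predicates.foreach {
        case pp: PartitionPredicate =>
          // Second pass: evaluate pp against each partition's key values here.
          accepted += pp
        case other =>
          // First pass: record the translated V2 predicates this source can honor.
          accepted += other
      }

      // Must report predicates from every call, so Spark skips deriving a
      // PartitionPredicate for filters already accepted here.
      override def pushedPredicates(): Array[Predicate] = accepted.toArray
    }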

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala index c4c6a60ce9314..d747bebd5cfe6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.{InternalRow, SQLConfHelper} import org.apache.spark.sql.catalyst.analysis.{NoSuchFunctionException, UnresolvedAttribute} import org.apache.spark.sql.catalyst.encoders.EncoderUtils import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, StaticInvoke} -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.connector.catalog.{FunctionCatalog, Identifier} import org.apache.spark.sql.connector.catalog.functions._ import org.apache.spark.sql.connector.catalog.functions.ScalarFunction.MAGIC_METHOD_NAME @@ -59,6 +59,16 @@ object V2ExpressionUtils extends SQLConfHelper with Logging { refs.map(ref => resolveRef[T](ref, plan)) } + /** + * Resolves [[NamedReference]]s against the given output and returns them as an [[AttributeSet]]. + */ + def resolveAttributeRefs( + refs: Array[NamedReference], + output: Seq[Attribute]): AttributeSet = { + val plan = LocalRelation(output) + AttributeSet(resolveRefs[Attribute](refs.toImmutableArraySeq, plan)) + } + /** * Converts the array of input V2 [[V2SortOrder]] into their counterparts in catalyst. */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryEnhancedRuntimePartitionFilterTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryEnhancedRuntimePartitionFilterTable.scala index ed70d10057d39..3c6fd4439bf91 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryEnhancedRuntimePartitionFilterTable.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryEnhancedRuntimePartitionFilterTable.scala @@ -21,6 +21,8 @@ import java.util import scala.collection.mutable.ArrayBuffer +import InMemoryEnhancedRuntimePartitionFilterTable._ + import org.apache.spark.sql.connector.expressions.{NamedReference, Transform} import org.apache.spark.sql.connector.expressions.filter.{PartitionPredicate, Predicate} import org.apache.spark.sql.connector.read.{InputPartition, Scan, ScanBuilder, SupportsRuntimeV2Filtering} @@ -69,28 +71,21 @@ class InMemoryEnhancedRuntimePartitionFilterTable( private val _allPushedPredicates = ArrayBuffer.empty[Predicate] + private val props = InMemoryEnhancedRuntimePartitionFilterTable.this.properties + private val acceptV2Predicates = - InMemoryEnhancedRuntimePartitionFilterTable.this.properties - .getOrDefault( - InMemoryEnhancedRuntimePartitionFilterTable - .AcceptV2PredicatesKey, "false").toBoolean + props.getOrDefault(AcceptV2PredicatesKey, "false").toBoolean private val restrictedFilterAttrs: Option[Set[String]] = - Option(InMemoryEnhancedRuntimePartitionFilterTable.this - .properties.get( - InMemoryEnhancedRuntimePartitionFilterTable - .FilterAttributesKey)) - .map(_.split(",").map(_.trim).toSet) + Option(props.get(FilterAttributesKey)).map(_.split(",").map(_.trim).toSet) def pushedPartitionPredicates: Seq[PartitionPredicate] = - _allPushedPredicates.collect { - case pp: PartitionPredicate => pp - 
}.toSeq + _allPushedPredicates.collect { case pp: PartitionPredicate => pp }.toSeq - override def pushedPredicates(): Array[Predicate] = - _allPushedPredicates.toArray + override def pushedPredicates(): Array[Predicate] = _allPushedPredicates.toArray - override def supportsIterativeFiltering(): Boolean = true + override def supportsIterativeFiltering(): Boolean = + props.getOrDefault(SupportsIterativeFilteringKey, "true").toBoolean override def filterAttributes(): Array[NamedReference] = { val scanFields = readSchema.fields.map(_.name).toSet @@ -101,17 +96,12 @@ class InMemoryEnhancedRuntimePartitionFilterTable( } } - override def filter(filters: Array[Predicate]): Unit = { - filters.foreach { - case pp: PartitionPredicate => - _allPushedPredicates += pp - data = data.filter { partition => - pp.eval( - partition.asInstanceOf[BufferedRows].partitionKey()) - } - case other => - if (acceptV2Predicates) _allPushedPredicates += other - } + override def filter(filters: Array[Predicate]): Unit = filters.foreach { + case pp: PartitionPredicate => + _allPushedPredicates += pp + data = data.filter(p => pp.eval(p.asInstanceOf[BufferedRows].partitionKey())) + case other => + if (acceptV2Predicates) _allPushedPredicates += other } } } @@ -128,4 +118,10 @@ object InMemoryEnhancedRuntimePartitionFilterTable { * filterAttributes(). Default: all partition columns. */ private[catalog] val FilterAttributesKey = "filter-attributes" + + /** + * Table property: when "false", supportsIterativeFiltering() returns false. + * Default: "true". + */ + private[catalog] val SupportsIterativeFilteringKey = "supports-iterative-filtering" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala index e2f1f7c660126..e479041114749 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala @@ -21,7 +21,7 @@ import scala.collection.mutable import org.apache.spark.internal.{Logging, LogKeys} import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, DynamicPruning, DynamicPruningExpression, Expression, ExpressionSet, GetStructField, NamedExpression, PythonUDF, SchemaPruning, SubqueryExpression} +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, DynamicPruning, DynamicPruningExpression, Expression, ExpressionSet, GetStructField, NamedExpression, PythonUDF, SchemaPruning, SubqueryExpression, V2ExpressionUtils} import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.catalyst.util.CharVarcharUtils @@ -134,8 +134,9 @@ object PushDownUtils extends Logging { /** * Pushes runtime filters to a [[SupportsRuntimeV2Filtering]] scan. Translatable filters are * pushed first, followed by [[PartitionPredicate]] if the scan supports iterative filtering. - * Only runtime filters that were not already translated are used to derive PartitionPredicates - * in the second pass, avoiding duplicate pushdown. + * Only runtime filters whose translated form was not already accepted by the data source in + * the first pass are used to derive PartitionPredicates in the second pass, avoiding duplicate + * pushdown. 
* * @return true if any filters were pushed to the data source */ @@ -154,16 +155,17 @@ object PushDownUtils extends Logging { }).map(f -> _) }.toMap - if (filtersToTranslated.nonEmpty) { + val translatedFiltersPushed = filtersToTranslated.nonEmpty + if (translatedFiltersPushed) { filterableScan.filter(filtersToTranslated.values.toArray) } // If the scan supports iterative filtering, derive PartitionPredicates from runtime // filters whose translation was not already accepted in the first pass. (See SPARK-55596) // Only candidates whose referenced columns are declared in filterAttributes() are eligible. - if (filterableScan.supportsIterativeFiltering()) { - val filterAttrs = AttributeSet(filterableScan.filterAttributes() - .flatMap(r => output.find(a => SQLConf.get.resolver(a.name, r.fieldNames.head)))) + val partPredicatesPushed = filterableScan.supportsIterativeFiltering() && { + val filterAttrs = V2ExpressionUtils.resolveAttributeRefs( + filterableScan.filterAttributes(), output) val pushed = filterableScan.pushedPredicates().toSet val candidates = runtimeFilters.filter { f => !filtersToTranslated.get(f).exists(pushed.contains) && @@ -175,9 +177,10 @@ object PushDownUtils extends Logging { if (partPredicates.nonEmpty) { filterableScan.filter(partPredicates.toArray) } + partPredicates.nonEmpty } - filterableScan.pushedPredicates().nonEmpty + translatedFiltersPushed || partPredicatesPushed case _ => false } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PartitionPruning.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PartitionPruning.scala index 05dddfd3e7b4a..fd34162044483 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PartitionPruning.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PartitionPruning.scala @@ -27,8 +27,6 @@ import org.apache.spark.sql.connector.read.SupportsRuntimeV2Filtering import org.apache.spark.sql.execution.columnar.InMemoryRelation import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation} import org.apache.spark.sql.execution.datasources.v2.ExtractV2Scan -import org.apache.spark.util.ArrayImplicits._ - /** * Dynamic partition pruning optimization is performed based on the type and * selectivity of the join operation. 
During query optimization, we insert a @@ -80,9 +78,9 @@ object PartitionPruning extends Rule[LogicalPlan] with PredicateHelper with Join None } case (resExp, r @ ExtractV2Scan(scan: SupportsRuntimeV2Filtering)) => - val filterAttrs = V2ExpressionUtils.resolveRefs[Attribute]( - scan.filterAttributes.toImmutableArraySeq, r) - if (resExp.references.subsetOf(AttributeSet(filterAttrs))) { + val filterAttrs = V2ExpressionUtils.resolveAttributeRefs( + scan.filterAttributes, r.output) + if (resExp.references.subsetOf(filterAttrs)) { Some(r) } else { None diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2EnhancedRuntimePartitionFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2EnhancedRuntimePartitionFilterSuite.scala index 8e3a645be0630..efcecc3a29efc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2EnhancedRuntimePartitionFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2EnhancedRuntimePartitionFilterSuite.scala @@ -401,27 +401,28 @@ class DataSourceV2EnhancedRuntimePartitionFilterSuite } test("case 11: supportsIterativeFiltering is false -> no PartitionPredicate") { - val baseCatalog = "testv2filterNoIterative" - spark.conf.set(s"spark.sql.catalog.$baseCatalog", - classOf[catalog.InMemoryTableWithV2FilterCatalog].getName) + val noIterCatalog = "testNoIterativeFiltering" + withSQLConf(s"spark.sql.catalog.$noIterCatalog" -> + classOf[InMemoryTableEnhancedRuntimePartitionFilterCatalog].getName) { + val tbl = s"$noIterCatalog.tbl" + val dim = s"$noIterCatalog.dim" + withTable(tbl, dim) { + sql(s"CREATE TABLE $tbl (id INT, part INT) " + + s"USING $v2Source PARTITIONED BY (part) " + + s"TBLPROPERTIES('supports-iterative-filtering' = 'false')") + for (i <- 0 until 5) { + sql(s"INSERT INTO $tbl VALUES ($i, $i)") + } + sql(s"CREATE TABLE $dim (val INT) USING $v2Source") + sql(s"INSERT INTO $dim VALUES (3)") - val tbl = s"$baseCatalog.tbl" - val dim = s"$baseCatalog.dim" - withTable(tbl, dim) { - sql(s"CREATE TABLE $tbl (id INT, part INT) " + - s"USING $v2Source PARTITIONED BY (part)") - for (i <- 0 until 5) { - sql(s"INSERT INTO $tbl VALUES ($i, $i)") - } - sql(s"CREATE TABLE $dim (val INT) USING $v2Source") - sql(s"INSERT INTO $dim VALUES (3)") - - val df = sql( - s"SELECT * FROM $tbl WHERE part = (SELECT max(val) FROM $dim)") - checkAnswer(df, Row(3, 3)) + val df = sql( + s"SELECT * FROM $tbl WHERE part = (SELECT max(val) FROM $dim)") + checkAnswer(df, Row(3, 3)) - assertHasRuntimeFilters(df) - assertPushedPartitionPredicates(df, 0) + assertHasRuntimeFilters(df) + assertPushedPartitionPredicates(df, 0) + } } } @@ -451,6 +452,37 @@ class DataSourceV2EnhancedRuntimePartitionFilterSuite } } + // Regression test for a buggy connector that supports iterative filtering but + // does not correctly report first-pass filters in pushedPredicates(). + // The second round should still push a PartitionPredicate and prune partitions. 
+ test("pushedPredicates() omits first-pass filters -> second round still prunes") { + val tbl = s"$catalogName.tbl_nopushed" + val dim = s"$catalogName.dim_nopushed" + withTable(tbl, dim) { + sql(s"CREATE TABLE $tbl (id INT, part INT) USING $v2Source PARTITIONED BY (part)") + for (i <- 0 until 5) { + sql(s"INSERT INTO $tbl VALUES ($i, $i)") + } + sql(s"CREATE TABLE $dim (dim_id INT, dim_val STRING) USING $v2Source") + sql(s"INSERT INTO $dim VALUES (2, 'two')") + + withDPPConf { + val df = sql( + s"""SELECT f.id, f.part FROM $tbl f JOIN $dim d + |ON f.part = d.dim_id WHERE d.dim_val = 'two'""".stripMargin) + checkAnswer(df, Row(2, 2)) + + assertDPPRuntimeFilters(df) + assertPushedPartitionPredicates(df, 1) + assertScanReturnsPartitionKeys(df, Set("2")) + + val batchScan = collectBatchScan(df) + assert(batchScan.filteredPartitions.flatten.length < 5, + "Expected PartitionPredicate from second round to actually prune partitions") + } + } + } + // --------------------------------------------------------------------------- // Helper methods // --------------------------------------------------------------------------- From f426fbcd2fb4e7af50d9eb343e811443e69a2d24 Mon Sep 17 00:00:00 2001 From: Szehon Ho Date: Thu, 23 Apr 2026 00:14:54 +0200 Subject: [PATCH 8/8] [SPARK-56521][SQL] Rename supportsIterativeFiltering to supportsIterativePushdown Align with SupportsPushDownV2Filters.supportsIterativePushdown() naming. --- .../sql/connector/read/SupportsRuntimeV2Filtering.java | 6 +++--- .../InMemoryEnhancedRuntimePartitionFilterTable.scala | 8 ++++---- .../sql/execution/datasources/v2/PushDownUtils.scala | 2 +- ...taSourceV2EnhancedRuntimePartitionFilterSuite.scala | 10 +++++----- 4 files changed, 13 insertions(+), 13 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsRuntimeV2Filtering.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsRuntimeV2Filtering.java index e0e76b246caac..94dbc3865958a 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsRuntimeV2Filtering.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsRuntimeV2Filtering.java @@ -31,7 +31,7 @@ * {@link SupportsRuntimeV2Filtering} is preferred over {@link SupportsRuntimeFiltering} * and only one of them should be implemented by the data sources. *

- * Iterative filtering: When {@link #supportsIterativeFiltering()} returns true, + * Iterative filtering: When {@link #supportsIterativePushdown()} returns true, * {@link #filter(Predicate[])} may be called multiple times on the same * {@link Scan} instance. The first call pushes translated V2 predicates; the second call * pushes {@link PartitionPredicate} instances derived from runtime filters whose translated @@ -68,7 +68,7 @@ public interface SupportsRuntimeV2Filtering extends Scan { * scan must not report new partition values that were not present in the original partitioning. *

* This method may be called multiple times with additional predicates (e.g. - * {@link PartitionPredicate}) when {@link #supportsIterativeFiltering()} returns true. + * {@link PartitionPredicate}) when {@link #supportsIterativePushdown()} returns true. * The implementation must accumulate state across all calls so that * {@link #pushedPredicates()} can return predicates from all of them. *

@@ -107,7 +107,7 @@ default Predicate[] pushedPredicates() { * * @since 4.2.0 */ - default boolean supportsIterativeFiltering() { + default boolean supportsIterativePushdown() { return false; } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryEnhancedRuntimePartitionFilterTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryEnhancedRuntimePartitionFilterTable.scala index 3c6fd4439bf91..1d575d52d6078 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryEnhancedRuntimePartitionFilterTable.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryEnhancedRuntimePartitionFilterTable.scala @@ -84,8 +84,8 @@ class InMemoryEnhancedRuntimePartitionFilterTable( override def pushedPredicates(): Array[Predicate] = _allPushedPredicates.toArray - override def supportsIterativeFiltering(): Boolean = - props.getOrDefault(SupportsIterativeFilteringKey, "true").toBoolean + override def supportsIterativePushdown(): Boolean = + props.getOrDefault(SupportsIterativePushdownKey, "true").toBoolean override def filterAttributes(): Array[NamedReference] = { val scanFields = readSchema.fields.map(_.name).toSet @@ -120,8 +120,8 @@ object InMemoryEnhancedRuntimePartitionFilterTable { private[catalog] val FilterAttributesKey = "filter-attributes" /** - * Table property: when "false", supportsIterativeFiltering() returns false. + * Table property: when "false", supportsIterativePushdown() returns false. * Default: "true". */ - private[catalog] val SupportsIterativeFilteringKey = "supports-iterative-filtering" + private[catalog] val SupportsIterativePushdownKey = "supports-iterative-pushdown" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala index e479041114749..0d34dfc91c39f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala @@ -163,7 +163,7 @@ object PushDownUtils extends Logging { // If the scan supports iterative filtering, derive PartitionPredicates from runtime // filters whose translation was not already accepted in the first pass. (See SPARK-55596) // Only candidates whose referenced columns are declared in filterAttributes() are eligible. - val partPredicatesPushed = filterableScan.supportsIterativeFiltering() && { + val partPredicatesPushed = filterableScan.supportsIterativePushdown() && { val filterAttrs = V2ExpressionUtils.resolveAttributeRefs( filterableScan.filterAttributes(), output) val pushed = filterableScan.pushedPredicates().toSet diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2EnhancedRuntimePartitionFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2EnhancedRuntimePartitionFilterSuite.scala index efcecc3a29efc..0e845734792f7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2EnhancedRuntimePartitionFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2EnhancedRuntimePartitionFilterSuite.scala @@ -50,7 +50,7 @@ import org.apache.spark.sql.test.SharedSparkSession * 8. DPP, translated, accepted in 1st pass -> no PartitionPredicate * 9. Scalar, translatable, accepted in 1st pass -> no PartitionPredicate * 10. 
Scalar on data column -> no PartitionPredicate
- * 11. supportsIterativeFiltering is false -> no PartitionPredicate
+ * 11. supportsIterativePushdown is false -> no PartitionPredicate
 * 12. Partition col not in filterAttributes -> no PartitionPredicate
 */
 class DataSourceV2EnhancedRuntimePartitionFilterSuite
@@ -400,7 +400,7 @@ class DataSourceV2EnhancedRuntimePartitionFilterSuite
     }
   }
 
-  test("case 11: supportsIterativeFiltering is false -> no PartitionPredicate") {
+  test("case 11: supportsIterativePushdown is false -> no PartitionPredicate") {
     val noIterCatalog = "testNoIterativeFiltering"
     withSQLConf(s"spark.sql.catalog.$noIterCatalog" ->
       classOf[InMemoryTableEnhancedRuntimePartitionFilterCatalog].getName) {
@@ -409,7 +409,7 @@ class DataSourceV2EnhancedRuntimePartitionFilterSuite
       withTable(tbl, dim) {
         sql(s"CREATE TABLE $tbl (id INT, part INT) " +
           s"USING $v2Source PARTITIONED BY (part) " +
-          s"TBLPROPERTIES('supports-iterative-filtering' = 'false')")
+          s"TBLPROPERTIES('supports-iterative-pushdown' = 'false')")
         for (i <- 0 until 5) {
           sql(s"INSERT INTO $tbl VALUES ($i, $i)")
         }
@@ -452,9 +452,9 @@ class DataSourceV2EnhancedRuntimePartitionFilterSuite
     }
   }
 
-  // Regression test for a buggy connector that supports iterative filtering but
+  // Test for a buggy connector that supports iterative filtering but
   // does not correctly report first-pass filters in pushedPredicates().
-  // The second round should still push a PartitionPredicate and prune partitions.
+  // The second round will push a duplicate PartitionPredicate and still prune.
   test("pushedPredicates() omits first-pass filters -> second round still prunes") {
     val tbl = s"$catalogName.tbl_nopushed"
     val dim = s"$catalogName.dim_nopushed"