@@ -19,6 +19,7 @@

import org.apache.spark.annotation.Experimental;
import org.apache.spark.sql.connector.expressions.NamedReference;
import org.apache.spark.sql.connector.expressions.filter.PartitionPredicate;
import org.apache.spark.sql.connector.expressions.filter.Predicate;
import org.apache.spark.sql.sources.Filter;

@@ -29,7 +30,14 @@
* data source V2 {@link Predicate} instead of data source V1 {@link Filter}.
* {@link SupportsRuntimeV2Filtering} is preferred over {@link SupportsRuntimeFiltering}
* and only one of them should be implemented by the data sources.
*
* <p>
* <b>Iterative filtering:</b> When {@link #supportsIterativePushdown()} returns true,
* {@link #filter(Predicate[])} may be called <i>multiple times</i> on the same
* {@link Scan} instance. The first call pushes translated V2 predicates; the second call
* pushes {@link PartitionPredicate} instances derived from runtime filters whose translated
* form was not already accepted (via {@link #pushedPredicates()}) in the first call.
* The implementation must accumulate state across all calls, and
* {@link #pushedPredicates()} must return predicates from all of them.
* <p>
* Note that Spark will push runtime filters only if they are beneficial.
*
@@ -59,9 +67,47 @@ public interface SupportsRuntimeV2Filtering extends Scan {
* partition values (omitting those with no data) via {@link Batch#planInputPartitions()}. The
* scan must not report new partition values that were not present in the original partitioning.
* <p>
* This method may be called multiple times with additional predicates (e.g.
* {@link PartitionPredicate}) when {@link #supportsIterativePushdown()} returns true.
* The implementation must accumulate state across all calls so that
* {@link #pushedPredicates()} can return predicates from all of them.
* <p>
* Note that Spark will call {@link Scan#toBatch()} again after filtering the scan at runtime.
*
* @param predicates data source V2 predicates used to filter the scan at runtime
*/
void filter(Predicate[] predicates);

/**
* Returns the predicates that are pushed to the data source via
* {@link #filter(Predicate[])}.
* <p>
* When iterative filtering is supported and {@link #filter(Predicate[])} was called
* multiple times, this method must return predicates from <i>all</i> calls.
* <p>
* It's possible that there are no runtime predicates and
* {@link #filter(Predicate[])} is never called;
* in that case, an empty array should be returned.
*
* @since 4.2.0
*/
default Predicate[] pushedPredicates() {
return new Predicate[0];
}

/**
* Returns true if this scan supports iterative runtime filtering. When true,
* {@link #filter(Predicate[])} may be called multiple times with additional
* predicates. The implementation must accumulate state across all calls,
* and {@link #pushedPredicates()} must return predicates from all of them.
* See the class-level Javadoc for the full contract.
* <p>
* When enabled, Spark will derive {@link PartitionPredicate} instances from the runtime
* filters and push them via a subsequent {@link #filter(Predicate[])} call.
*
* @since 4.2.0
*/
default boolean supportsIterativePushdown() {
return false;
}
}
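To make the iterative contract above concrete, here is a minimal Scala sketch of a connector scan that opts in. `ExampleRuntimeFilteredScan`, `planBatch`, and `canHandle` are hypothetical names introduced only for illustration; they are not part of this change.

```scala
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.connector.expressions.NamedReference
import org.apache.spark.sql.connector.expressions.filter.{PartitionPredicate, Predicate}
import org.apache.spark.sql.connector.read.{Batch, Scan, SupportsRuntimeV2Filtering}
import org.apache.spark.sql.types.StructType

class ExampleRuntimeFilteredScan(
    schema: StructType,
    partitionCols: Array[NamedReference],
    planBatch: Array[Predicate] => Batch,   // connector-specific partition planning
    canHandle: Predicate => Boolean)        // connector-specific predicate support check
  extends Scan with SupportsRuntimeV2Filtering {

  // State accumulates across every filter() call, as the iterative contract requires.
  private val accepted = ArrayBuffer.empty[Predicate]

  override def readSchema(): StructType = schema

  // Spark calls toBatch() again after runtime filtering, so partitions are
  // re-planned against everything accepted so far.
  override def toBatch(): Batch = planBatch(accepted.toArray)

  override def filterAttributes(): Array[NamedReference] = partitionCols

  override def supportsIterativePushdown(): Boolean = true

  // First call: translated V2 predicates. Second call (iterative mode only):
  // PartitionPredicate instances derived from runtime filters that the first
  // call did not report via pushedPredicates().
  override def filter(predicates: Array[Predicate]): Unit = predicates.foreach {
    case pp: PartitionPredicate => accepted += pp
    case other if canHandle(other) => accepted += other
    case _ => // ignore predicates the source cannot evaluate
  }

  // Must reflect predicates from all filter() calls.
  override def pushedPredicates(): Array[Predicate] = accepted.toArray
}
```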
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.{InternalRow, SQLConfHelper}
import org.apache.spark.sql.catalyst.analysis.{NoSuchFunctionException, UnresolvedAttribute}
import org.apache.spark.sql.catalyst.encoders.EncoderUtils
import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, StaticInvoke}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.connector.catalog.{FunctionCatalog, Identifier}
import org.apache.spark.sql.connector.catalog.functions._
import org.apache.spark.sql.connector.catalog.functions.ScalarFunction.MAGIC_METHOD_NAME
@@ -59,6 +59,16 @@ object V2ExpressionUtils extends SQLConfHelper with Logging {
refs.map(ref => resolveRef[T](ref, plan))
}

/**
* Resolves [[NamedReference]]s against the given output and returns them as an [[AttributeSet]].
*/
def resolveAttributeRefs(
refs: Array[NamedReference],
output: Seq[Attribute]): AttributeSet = {
val plan = LocalRelation(output)
AttributeSet(resolveRefs[Attribute](refs.toImmutableArraySeq, plan))
}

/**
* Converts the array of input V2 [[V2SortOrder]] into their counterparts in catalyst.
*/
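A small usage sketch of the new helper. The column names below are made up, and the import paths are the usual catalyst/connector locations assumed here.

```scala
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet}
import org.apache.spark.sql.catalyst.expressions.V2ExpressionUtils
import org.apache.spark.sql.connector.expressions.{FieldReference, NamedReference}
import org.apache.spark.sql.types.{IntegerType, StringType}

// Output of some relation: columns `id` and `region` (hypothetical).
val output = Seq(
  AttributeReference("id", IntegerType)(),
  AttributeReference("region", StringType)())

// References reported by a data source, e.g. from filterAttributes().
val refs: Array[NamedReference] = Array(FieldReference("region"))

// Resolution goes through a temporary LocalRelation built from `output`.
val attrs: AttributeSet = V2ExpressionUtils.resolveAttributeRefs(refs, output)
// `attrs` now contains the `region` attribute from `output`.
```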
@@ -0,0 +1,127 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.connector.catalog

import java.util

import scala.collection.mutable.ArrayBuffer

import InMemoryEnhancedRuntimePartitionFilterTable._

import org.apache.spark.sql.connector.expressions.{NamedReference, Transform}
import org.apache.spark.sql.connector.expressions.filter.{PartitionPredicate, Predicate}
import org.apache.spark.sql.connector.read.{InputPartition, Scan, ScanBuilder, SupportsRuntimeV2Filtering}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.apache.spark.util.ArrayImplicits._

/**
* In-memory table whose batch scan implements [[SupportsRuntimeV2Filtering]] with
* iterative filtering support, so that [[PartitionPredicate]] instances derived from
* runtime filters are pushed via a second [[SupportsRuntimeV2Filtering#filter]] call.
*
* Table properties:
* - `accept-v2-predicates` (default `false`): when true, non-PartitionPredicate
* V2 predicates are reported via `pushedPredicates()` (i.e. accepted).
* - `filter-attributes` (default: all partition cols): comma-separated list of
* column names to expose from `filterAttributes()`.
* - `supports-iterative-pushdown` (default `true`): when false, the scan reports
* that iterative pushdown is unsupported via `supportsIterativePushdown()`.
*/
class InMemoryEnhancedRuntimePartitionFilterTable(
name: String,
columns: Array[Column],
partitioning: Array[Transform],
properties: util.Map[String, String])
extends InMemoryTableWithV2Filter(name, columns, partitioning, properties) {

override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = {
new InMemoryEnhancedRuntimePartitionFilterScanBuilder(schema, options)
}

class InMemoryEnhancedRuntimePartitionFilterScanBuilder(
tableSchema: StructType,
options: CaseInsensitiveStringMap)
extends InMemoryScanBuilder(tableSchema, options) {
override def build: Scan = InMemoryEnhancedRuntimePartitionFilterBatchScan(
data.map(_.asInstanceOf[InputPartition]).toImmutableArraySeq,
schema, tableSchema, options)
}

case class InMemoryEnhancedRuntimePartitionFilterBatchScan(
var _data: Seq[InputPartition],
readSchema: StructType,
tableSchema: StructType,
options: CaseInsensitiveStringMap)
extends BatchScanBaseClass(_data, readSchema, tableSchema)
with SupportsRuntimeV2Filtering {

private val _allPushedPredicates = ArrayBuffer.empty[Predicate]

private val props = InMemoryEnhancedRuntimePartitionFilterTable.this.properties

private val acceptV2Predicates =
props.getOrDefault(AcceptV2PredicatesKey, "false").toBoolean

private val restrictedFilterAttrs: Option[Set[String]] =
Option(props.get(FilterAttributesKey)).map(_.split(",").map(_.trim).toSet)

def pushedPartitionPredicates: Seq[PartitionPredicate] =
_allPushedPredicates.collect { case pp: PartitionPredicate => pp }.toSeq

override def pushedPredicates(): Array[Predicate] = _allPushedPredicates.toArray

override def supportsIterativePushdown(): Boolean =
props.getOrDefault(SupportsIterativePushdownKey, "true").toBoolean

override def filterAttributes(): Array[NamedReference] = {
val scanFields = readSchema.fields.map(_.name).toSet
partitioning.flatMap(_.references()).filter { ref =>
val name = ref.fieldNames.mkString(".")
scanFields.contains(name) &&
restrictedFilterAttrs.forall(_.contains(name))
}
}

override def filter(filters: Array[Predicate]): Unit = filters.foreach {
case pp: PartitionPredicate =>
_allPushedPredicates += pp
data = data.filter(p => pp.eval(p.asInstanceOf[BufferedRows].partitionKey()))
case other =>
if (acceptV2Predicates) _allPushedPredicates += other
}
}
}

object InMemoryEnhancedRuntimePartitionFilterTable {
/**
* Table property: when "true", non-PartitionPredicate V2 predicates
* pushed via filter() are reported in pushedPredicates() (accepted).
*/
private[catalog] val AcceptV2PredicatesKey = "accept-v2-predicates"

/**
* Table property: comma-separated column names to expose from
* filterAttributes(). Default: all partition columns.
*/
private[catalog] val FilterAttributesKey = "filter-attributes"

/**
* Table property: when "false", supportsIterativePushdown() returns false.
* Default: "true".
*/
private[catalog] val SupportsIterativePushdownKey = "supports-iterative-pushdown"
}
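For reference, a hypothetical direct construction of this test table with all three recognized properties; the table name, columns, and partitioning below are illustrative only.

```scala
import java.util.{HashMap => JHashMap}

import org.apache.spark.sql.connector.catalog.Column
import org.apache.spark.sql.connector.expressions.{Expressions, Transform}
import org.apache.spark.sql.types.{LongType, StringType}

val props = new JHashMap[String, String]()
props.put("accept-v2-predicates", "true")         // report non-partition V2 predicates as pushed
props.put("filter-attributes", "region")          // expose only `region` from filterAttributes()
props.put("supports-iterative-pushdown", "true")  // allow the second filter() call

val table = new InMemoryEnhancedRuntimePartitionFilterTable(
  name = "ns.fact",
  columns = Array(Column.create("id", LongType), Column.create("region", StringType)),
  partitioning = Array[Transform](Expressions.identity("region")),
  properties = props)
```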
@@ -0,0 +1,50 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.connector.catalog

import java.util

import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException
import org.apache.spark.sql.connector.expressions.Transform

class InMemoryTableEnhancedRuntimePartitionFilterCatalog extends InMemoryTableCatalog {
import CatalogV2Implicits._

override def createTable(
ident: Identifier,
columns: Array[Column],
partitions: Array[Transform],
properties: util.Map[String, String]): Table = {
if (tables.containsKey(ident)) {
throw new TableAlreadyExistsException(ident.asMultipartIdentifier)
}

InMemoryTableCatalog.maybeSimulateFailedTableCreation(properties)

val tableName = s"$name.${ident.quoted}"
val table = new InMemoryEnhancedRuntimePartitionFilterTable(
tableName, columns, partitions, properties)
tables.put(ident, table)
namespaces.putIfAbsent(ident.namespace.toList, Map())
table
}

override def createTable(ident: Identifier, tableInfo: TableInfo): Table = {
createTable(ident, tableInfo.columns(), tableInfo.partitions(), tableInfo.properties)
}
}
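A hypothetical test wiring for this catalog, assuming a running SparkSession `spark`; the catalog name, namespace, and table are illustrative.

```scala
import org.apache.spark.sql.connector.catalog.InMemoryTableEnhancedRuntimePartitionFilterCatalog

spark.conf.set(
  "spark.sql.catalog.testcat",
  classOf[InMemoryTableEnhancedRuntimePartitionFilterCatalog].getName)

spark.sql(
  """CREATE TABLE testcat.ns.fact (id BIGINT, region STRING)
    |PARTITIONED BY (region)
    |TBLPROPERTIES (
    |  'accept-v2-predicates' = 'false',
    |  'supports-iterative-pushdown' = 'true')""".stripMargin)

// A join against a small dimension table filtered on `region` can then trigger a
// runtime partition filter, which reaches the scan via the iterative filter() calls.
```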
@@ -61,17 +61,10 @@ case class BatchScanExec(

// Visible for testing
@transient private[sql] lazy val filteredPartitions: Seq[Option[InputPartition]] = {
val dataSourceFilters = runtimeFilters.flatMap {
case DynamicPruningExpression(e) => DataSourceV2Strategy.translateRuntimeFilterV2(e)
case f => DataSourceV2Strategy.translateScalarSubqueryFilterV2(f)
}

val originalPartitioning = outputPartitioning
if (dataSourceFilters.nonEmpty) {
// the cast is safe as runtime filters are only assigned if the scan can be filtered
val filterableScan = scan.asInstanceOf[SupportsRuntimeV2Filtering]
filterableScan.filter(dataSourceFilters.toArray)

val filtered = PushDownUtils.pushRuntimeFilters(scan, runtimeFilters, table, output)
if (filtered) {
// call toBatch again to get filtered partitions
val newPartitions = scan.toBatch.planInputPartitions()
