Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.connector.read;

import org.apache.spark.annotation.Evolving;
import org.apache.spark.sql.connector.expressions.Expression;

/**
 * A mix-in interface for {@link ScanBuilder}. Data sources can implement this interface to
 * accept all the join or aggregate keys pushed down by Spark. A return value of true
 * indicates that the data source will return input partitions (via
 * {@code Batch#planInputPartitions}) already grouped by the given clustering keys.
 * A return value of false indicates the data source makes no such guarantee, even though
 * it may still report a partitioning that may or may not be compatible with the given
 * clustering keys; in that case it is Spark's responsibility to group the input
 * partitions before a clustering-dependent operation can be applied.
 *
 * @since 3.4.0
 */
@Evolving
public interface SupportsPushDownClusterKeys extends ScanBuilder {

  /**
   * Pushes down the cluster keys (e.g. join or aggregate keys) to the data source.
   *
   * @param expressions the clustering key expressions, typically column references
   * @return true if the data source guarantees its input partitions are grouped by
   *         these keys, false otherwise
   */
  boolean pushClusterKeys(Expression[] expressions);
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,15 @@ object RewriteDeleteFromTable extends RewriteRowLevelCommand {
override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case d @ DeleteFromTable(aliasedTable, cond) if d.resolved =>
EliminateSubqueryAliases(aliasedTable) match {
case DataSourceV2Relation(_: TruncatableTable, _, _, _, _) if cond == TrueLiteral =>
case DataSourceV2Relation(_: TruncatableTable, _, _, _, _, _) if cond == TrueLiteral =>
// don't rewrite as the table supports truncation
d

case r @ DataSourceV2Relation(t: SupportsRowLevelOperations, _, _, _, _) =>
case r @ DataSourceV2Relation(t: SupportsRowLevelOperations, _, _, _, _, _) =>
val table = buildOperationTable(t, DELETE, CaseInsensitiveStringMap.empty())
buildReplaceDataPlan(r, table, cond)

case DataSourceV2Relation(_: SupportsDeleteV2, _, _, _, _) =>
case DataSourceV2Relation(_: SupportsDeleteV2, _, _, _, _, _) =>
// don't rewrite as the table supports deletes only with filters
d

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,7 @@ object GroupBasedRowLevelOperation {
type ReturnType = (ReplaceData, Expression, LogicalPlan)

def unapply(plan: LogicalPlan): Option[ReturnType] = plan match {
case rd @ ReplaceData(DataSourceV2Relation(table, _, _, _, _), cond, query, _, _) =>
case rd @ ReplaceData(DataSourceV2Relation(table, _, _, _, _, _), cond, query, _, _) =>
val readRelation = findReadRelation(table, query)
readRelation.map((rd, cond, _))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ case class ReplaceData(

lazy val operation: RowLevelOperation = {
EliminateSubqueryAliases(table) match {
case DataSourceV2Relation(RowLevelOperationTable(_, operation), _, _, _, _) =>
case DataSourceV2Relation(RowLevelOperationTable(_, operation), _, _, _, _, _) =>
operation
case _ =>
throw new AnalysisException(s"Cannot retrieve row-level operation from $table")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,16 @@ import org.apache.spark.util.Utils
* @param options The options for this table operation. It's used to create fresh
* [[org.apache.spark.sql.connector.read.ScanBuilder]] and
* [[org.apache.spark.sql.connector.write.WriteBuilder]].
* @param partitionGroupedByClusterKeys whether the input partitions of this relation are
*                                      grouped by the cluster keys
*/
case class DataSourceV2Relation(
table: Table,
output: Seq[AttributeReference],
catalog: Option[CatalogPlugin],
identifier: Option[Identifier],
options: CaseInsensitiveStringMap)
options: CaseInsensitiveStringMap,
var partitionGroupedByClusterKeys: Boolean = false)
extends LeafNode with MultiInstanceRelation with NamedRelation with ExposesMetadataColumns {

import DataSourceV2Implicits._
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -257,20 +257,47 @@ abstract class InMemoryBaseTable(
TableCapability.TRUNCATE)

override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = {
new InMemoryScanBuilder(schema)
new InMemoryScanBuilder(schema, Some(partitioning))
}

class InMemoryScanBuilder(tableSchema: StructType) extends ScanBuilder
with SupportsPushDownRequiredColumns {
class InMemoryScanBuilder(
    tableSchema: StructType,
    partitioning: Option[Array[Transform]] = None) extends ScanBuilder
    with SupportsPushDownRequiredColumns
    with SupportsPushDownClusterKeys {
  // Columns remaining after pruning; starts as the full table schema.
  private var schema: StructType = tableSchema
  // Cluster (join) keys pushed down so far, de-duplicated.
  private var pushedJoinKeys: Array[Expression] = Array.empty[Expression]

  override def build: Scan = {
    val scan = InMemoryBatchScan(data.map(_.asInstanceOf[InputPartition]), schema, tableSchema)
    // Propagate the pushed keys so the scan can report them.
    scan.pushedJoinKeys = pushedJoinKeys
    scan
  }

  override def pruneColumns(requiredSchema: StructType): Unit = {
    // Keep only the requested fields that actually exist as data or metadata columns.
    val schemaNames = metadataColumnNames ++ tableSchema.map(_.name)
    schema = StructType(requiredSchema.filter(f => schemaNames.contains(f.name)))
  }

  override def pushClusterKeys(expressions: Array[Expression]): Boolean = {
    // Accumulate keys across multiple calls, dropping duplicates.
    pushedJoinKeys = (pushedJoinKeys ++ expressions).distinct
    if (partitioning.isEmpty) {
      false
    } else {
      // NOTE(review): this reports true as soon as ANY pushed key is a partition
      // column (matching the original early-return behavior); confirm whether ALL
      // keys should be required for the grouping guarantee to hold.
      val partitionColumns = getPartitionColumns
      pushedJoinKeys.map(_.asInstanceOf[NamedReference]).exists(partitionColumns.contains)
    }
  }

  // Partition columns of the first transform only (mirrors the original behavior);
  // only bucket transforms expose their columns here. Callers must ensure
  // `partitioning` is non-empty before calling.
  private def getPartitionColumns: Array[NamedReference] = {
    partitioning.get.headOption match {
      case Some(BucketTransform(_, cols, _)) => cols.toArray
      case _ => Array.empty[NamedReference]
    }
  }
}

case class InMemoryStats(sizeInBytes: OptionalLong, numRows: OptionalLong) extends Statistics
Expand Down Expand Up @@ -322,6 +349,8 @@ abstract class InMemoryBaseTable(
tableSchema: StructType)
extends BatchScanBaseClass(_data, readSchema, tableSchema) with SupportsRuntimeFiltering {

var pushedJoinKeys = Array.empty[Expression]

override def filterAttributes(): Array[NamedReference] = {
val scanFields = readSchema.fields.map(_.name).toSet
partitioning.flatMap(_.references)
Expand Down
2 changes: 1 addition & 1 deletion sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3923,7 +3923,7 @@ class Dataset[T] private[sql](
fr.inputFiles
case r: HiveTableRelation =>
r.tableMeta.storage.locationUri.map(_.toString).toArray
case DataSourceV2ScanRelation(DataSourceV2Relation(table: FileTable, _, _, _, _),
case DataSourceV2ScanRelation(DataSourceV2Relation(table: FileTable, _, _, _, _, _),
_, _, _, _) =>
table.fileIndex.inputFiles
}.flatten
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ class CacheManager extends Logging with AdaptiveSparkPlanHelper {
isSameName(ident.qualifier :+ ident.name) &&
isSameName(v1Ident.catalog.toSeq ++ v1Ident.database :+ v1Ident.table)

case SubqueryAlias(ident, DataSourceV2Relation(_, _, Some(catalog), Some(v2Ident), _)) =>
case SubqueryAlias(ident, DataSourceV2Relation(_, _, Some(catalog), Some(v2Ident), _, _)) =>
isSameName(ident.qualifier :+ ident.name) &&
isSameName(catalog.name() +: v2Ident.namespace() :+ v2Ident.name())

Expand Down Expand Up @@ -345,7 +345,7 @@ class CacheManager extends Logging with AdaptiveSparkPlanHelper {
case _ => false
}

case DataSourceV2Relation(fileTable: FileTable, _, _, _, _) =>
case DataSourceV2Relation(fileTable: FileTable, _, _, _, _, _) =>
refreshFileIndexIfNecessary(fileTable.fileIndex, fs, qualifiedPath)

case _ => false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, File
class FallBackFileSourceV2(sparkSession: SparkSession) extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case i @ InsertIntoStatement(
d @ DataSourceV2Relation(table: FileTable, _, _, _, _), _, _, _, _, _) =>
d @ DataSourceV2Relation(table: FileTable, _, _, _, _, _), _, _, _, _, _) =>
val v1FileFormat = table.fallbackFileFormat.newInstance()
val relation = HadoopFsRelation(
table.fileIndex,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,10 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
case _: DynamicPruning => true
case _ => false
}
// TODO: if (!relation.relation.partitionGroupedByClusterKeys), Spark may need to
// regroup the input partitions, since partitions returned from the data source
// are not guaranteed to be grouped by the cluster keys.
val batchExec = BatchScanExec(relation.output, relation.scan, runtimeFilters,
relation.keyGroupedPartitioning, relation.ordering, relation.relation.table)
withProjectAndFilter(project, postScanFilters, batchExec, !batchExec.supportsColumnar) :: Nil
Expand Down Expand Up @@ -236,7 +240,7 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
invalidateCache) :: Nil
}

case AppendData(r @ DataSourceV2Relation(v1: SupportsWrite, _, _, _, _), query, _,
case AppendData(r @ DataSourceV2Relation(v1: SupportsWrite, _, _, _, _, _), query, _,
_, Some(write)) if v1.supports(TableCapability.V1_BATCH_WRITE) =>
write match {
case v1Write: V1Write =>
Expand All @@ -249,7 +253,7 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
case AppendData(r: DataSourceV2Relation, query, _, _, Some(write)) =>
AppendDataExec(planLater(query), refreshCache(r), write) :: Nil

case OverwriteByExpression(r @ DataSourceV2Relation(v1: SupportsWrite, _, _, _, _), _, query,
case OverwriteByExpression(r @ DataSourceV2Relation(v1: SupportsWrite, _, _, _, _, _), _, query,
_, _, Some(write)) if v1.supports(TableCapability.V1_BATCH_WRITE) =>
write match {
case v1Write: V1Write =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ import scala.collection.mutable

import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, Expression, NamedExpression, SchemaPruning}
import org.apache.spark.sql.catalyst.util.CharVarcharUtils
import org.apache.spark.sql.connector.expressions.SortOrder
import org.apache.spark.sql.connector.expressions.{Expression => V2Expression, SortOrder}
import org.apache.spark.sql.connector.expressions.filter.Predicate
import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownFilters, SupportsPushDownLimit, SupportsPushDownOffset, SupportsPushDownRequiredColumns, SupportsPushDownTableSample, SupportsPushDownTopN, SupportsPushDownV2Filters}
import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownClusterKeys, SupportsPushDownFilters, SupportsPushDownLimit, SupportsPushDownOffset, SupportsPushDownRequiredColumns, SupportsPushDownTableSample, SupportsPushDownTopN, SupportsPushDownV2Filters}
import org.apache.spark.sql.execution.datasources.DataSourceStrategy
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources
Expand Down Expand Up @@ -116,6 +116,15 @@ object PushDownUtils {
}
}

/**
 * Pushes the given cluster keys into the scan builder, when it supports cluster-key
 * push-down. Returns the data source's answer (whether partitions will come back
 * grouped by the keys), or false when the builder does not support the mix-in.
 */
def pushClusterKeys(scanBuilder: ScanBuilder, keys: Seq[V2Expression]): Boolean =
  scanBuilder match {
    case clusterKeyBuilder: SupportsPushDownClusterKeys =>
      clusterKeyBuilder.pushClusterKeys(keys.toArray)
    case _ =>
      false
  }

/**
* Pushes down LIMIT to the data source Scan.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@ package org.apache.spark.sql.execution.datasources.v2

import scala.collection.mutable

import org.apache.spark.sql.catalyst.expressions.{aggregate, Alias, And, Attribute, AttributeMap, AttributeReference, AttributeSet, Cast, Expression, IntegerLiteral, Literal, NamedExpression, PredicateHelper, ProjectionOverSchema, SortOrder, SubqueryExpression}
import org.apache.spark.sql.catalyst.expressions.{aggregate, Alias, And, Attribute, AttributeMap, AttributeReference, AttributeSet, Cast, EqualTo, Expression, IntegerLiteral, Literal, NamedExpression, PredicateHelper, ProjectionOverSchema, SortOrder, SubqueryExpression}
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.optimizer.CollapseProject
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LeafNode, Limit, LimitAndOffset, LocalLimit, LogicalPlan, Offset, OffsetAndLimit, Project, Sample, Sort}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, Join, LeafNode, Limit, LimitAndOffset, LocalLimit, LogicalPlan, Offset, OffsetAndLimit, Project, Sample, Sort}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.connector.expressions.{SortOrder => V2SortOrder}
import org.apache.spark.sql.connector.expressions.{FieldReference, NamedReference, SortOrder => V2SortOrder}
import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, Avg, Count, CountStar, Max, Min, Sum}
import org.apache.spark.sql.connector.expressions.filter.Predicate
import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, V1Scan}
Expand All @@ -42,6 +42,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
createScanBuilder,
pushDownSample,
pushDownFilters,
pushDownJoinKey,
pushDownAggregates,
pushDownLimitAndOffset,
buildScanWithPushedAggregate,
Expand Down Expand Up @@ -523,6 +524,83 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
}
}

/**
 * Attempts to push join keys down to the data sources on both sides of every join
 * in the plan. Only pure equi-joins whose conjuncts are all column-to-column
 * equalities (col1 = col2) are considered. The plan itself is returned unchanged;
 * push-down happens as a side effect on the underlying scan builders.
 */
def pushDownJoinKey(plan: LogicalPlan): LogicalPlan = plan.transform {
  case join: Join if join.condition.nonEmpty =>
    val predicates = splitConjunctivePredicates(join.condition.get)
    val equiConditions = predicates.collect { case e: EqualTo => e }
    // Only push down when every conjunct is an equality between plain columns.
    if (equiConditions.length == predicates.length &&
        equiConditions.forall(c => isColumn(c.left) && isColumn(c.right))) {
      // For each equality, pick the operand that belongs to the left child;
      // when the left operand isn't in the left child's output, fall back to
      // the right operand (preserves the original heuristic).
      val leftKeys = equiConditions.map { condition =>
        if (join.left.outputSet.contains(condition.left.asInstanceOf[AttributeReference])) {
          getFieldReference(condition.left)
        } else {
          getFieldReference(condition.right)
        }
      }
      if (leftKeys.nonEmpty) {
        pushDownJoinKeys(join.left, leftKeys)
      }

      // Symmetric selection for the right child.
      val rightKeys = equiConditions.map { condition =>
        if (join.right.outputSet.contains(condition.right.asInstanceOf[AttributeReference])) {
          getFieldReference(condition.right)
        } else {
          getFieldReference(condition.left)
        }
      }
      if (rightKeys.nonEmpty) {
        pushDownJoinKeys(join.right, rightKeys)
      }
    }
    join
}

/**
 * Returns true when the expression is a plain column reference, i.e. an
 * [[AttributeReference]] whose name can be parsed as a [[FieldReference]].
 */
private def isColumn(expr: Expression): Boolean = expr match {
  case attr: AttributeReference =>
    import scala.util.control.NonFatal
    try {
      FieldReference(attr.name)
      true
    } catch {
      // Only swallow non-fatal parse failures; let OOM/interrupts propagate.
      case NonFatal(_) => false
    }
  case _ => false
}

/** Converts a column expression (must be an [[AttributeReference]]) to a v2 field reference. */
private def getFieldReference(expr: Expression): NamedReference =
  FieldReference(expr.asInstanceOf[AttributeReference].name)

/**
 * Pushes the given keys into the first (left-most, depth-first) scan builder under
 * `plan` whose table output contains all of the key columns. At most one builder
 * receives the keys; the traversal short-circuits once a push-down succeeds.
 */
private def pushDownJoinKeys(plan: LogicalPlan, keys: Seq[NamedReference]): Unit = {
  // Returns true once the keys have been pushed, so callers can stop descending.
  def pushJoinKeys(node: LogicalPlan): Boolean = node match {
    case PhysicalOperation(_, _, sHolder: ScanBuilderHolder) =>
      val tableContainsKeys = keys.map(_.describe()).forall(sHolder.output.map(_.name).contains)
      if (tableContainsKeys) {
        // Record whether the source guarantees partitions grouped by the keys.
        sHolder.relation.partitionGroupedByClusterKeys =
          PushDownUtils.pushClusterKeys(sHolder.builder, keys)
        true
      } else {
        false
      }
    case join: Join =>
      // Short-circuit: try the left subtree first, right only if left didn't push.
      pushJoinKeys(join.left) || pushJoinKeys(join.right)
    case _ =>
      false
  }
  pushJoinKeys(plan)
}

private def getWrappedScan(scan: Scan, sHolder: ScanBuilderHolder): Scan = {
scan match {
case v1: V1Scan =>
Expand Down
Loading