Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import org.apache.spark.internal.LogKeys._
import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases
import org.apache.spark.sql.catalyst.analysis.Resolver
import org.apache.spark.sql.catalyst.catalog.HiveTableRelation
import org.apache.spark.sql.catalyst.expressions.{Attribute, SubqueryExpression}
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, SubqueryExpression}
import org.apache.spark.sql.catalyst.optimizer.EliminateResolvedHint
import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan, ResolvedHint, View}
import org.apache.spark.sql.catalyst.trees.TreePattern.PLAN_EXPRESSION
Expand All @@ -36,11 +36,12 @@ import org.apache.spark.sql.connector.catalog.CatalogPlugin
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.{IdentifierHelper, MultipartIdentifierHelper}
import org.apache.spark.sql.connector.catalog.Identifier
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.columnar.InMemoryRelation
import org.apache.spark.sql.execution.columnar.{InMemoryCacheTable, InMemoryRelation}
import org.apache.spark.sql.execution.command.CommandUtils
import org.apache.spark.sql.execution.datasources.{FileIndex, HadoopFsRelation, LogicalRelation, LogicalRelationWithTable}
import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, ExtractV2CatalogAndIdentifier, ExtractV2Table, FileTable, V2TableRefreshUtil}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.apache.spark.storage.StorageLevel
import org.apache.spark.storage.StorageLevel.MEMORY_AND_DISK

Expand Down Expand Up @@ -332,8 +333,17 @@ class CacheManager extends Logging with AdaptiveSparkPlanHelper {
cachedData: CachedData,
column: Seq[Attribute]): Unit = {
val relation = cachedData.cachedRepresentation
// Wrap in DataSourceV2Relation so the DSv2 planning path is used consistently
// (DataSourceV2Strategy handles InMemoryTableScanExec via InMemoryCacheScan).
val dsv2Relation = DataSourceV2Relation(
table = new InMemoryCacheTable(relation),
output = relation.output.map { case ar: AttributeReference => ar },
catalog = None,
identifier = None,
options = CaseInsensitiveStringMap.empty()
)
val (rowCount, newColStats) =
CommandUtils.computeColumnStats(sparkSession, relation, column)
CommandUtils.computeColumnStats(sparkSession, dsv2Relation, column)
relation.updateStats(rowCount, newColStats)
}

Expand Down Expand Up @@ -502,9 +512,19 @@ class CacheManager extends Logging with AdaptiveSparkPlanHelper {
// After cache lookup, we should still keep the hints from the input plan.
val hints = EliminateResolvedHint.extractHintsFromPlan(currentFragment)._2
val cachedPlan = cached.cachedRepresentation.withOutput(currentFragment.output)
// Wrap the InMemoryRelation in a DataSourceV2Relation so that V2ScanRelationPushDown
// optimizer rules can apply column pruning, filter pushdown, and ordering/statistics
// reporting. Physical execution is still routed to InMemoryTableScanExec.
val substitutedPlan: LogicalPlan = DataSourceV2Relation(
table = new InMemoryCacheTable(cachedPlan),
output = cachedPlan.output.map { case ar: AttributeReference => ar },
catalog = None,
identifier = None,
options = CaseInsensitiveStringMap.empty()
)
// The returned hint list is in top-down order, we should create the hint nodes from
// right to left.
hints.foldRight[LogicalPlan](cachedPlan) { case (hint, p) =>
hints.foldRight[LogicalPlan](substitutedPlan) { case (hint, p) =>
ResolvedHint(p, hint)
}
}.getOrElse(currentFragment)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -703,6 +703,12 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
}
}

/**
* Fallback strategy for cached in-memory tables when the DSv2 cache path is disabled
* (spark.sql.inMemoryColumnarStorage.useDataSourceV2 = false).
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This config spark.sql.inMemoryColumnarStorage.useDataSourceV2 doesn't exist anywhere in the codebase. Also, after this PR InMemoryScans becomes unreachable dead code since CacheManager always wraps in DataSourceV2Relation — should we add a config toggle, or remove this strategy?

* Under the default (DSv2) path InMemoryRelation is never exposed to the planner because
* CacheManager wraps it in DataSourceV2Relation before planning.
*/
object InMemoryScans extends Strategy {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case PhysicalOperation(projectList, filters, mem: InMemoryRelation) =>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,295 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.columnar

import java.util
import java.util.OptionalLong

import org.apache.spark.sql.catalyst.expressions.{
Ascending, Attribute, AttributeReference, Descending, NullsFirst, NullsLast,
SortOrder => CatalystSortOrder
}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.connector.catalog.{SupportsRead, Table, TableCapability}
import org.apache.spark.sql.connector.expressions.{
Expression => V2Expression, FieldReference, NamedReference,
NullOrdering => V2NullOrdering, SortDirection => V2SortDirection,
SortOrder => V2SortOrder, SortValue
}
import org.apache.spark.sql.connector.expressions.filter.{Predicate => V2Predicate}
import org.apache.spark.sql.connector.read.{
Scan, ScanBuilder, Statistics => V2Statistics, SupportsPushDownLimit,
SupportsPushDownRequiredColumns, SupportsPushDownV2Filters, SupportsReportOrdering,
SupportsReportPartitioning, SupportsReportStatistics, SupportsRuntimeV2Filtering
}
import org.apache.spark.sql.connector.read.colstats.ColumnStatistics
import org.apache.spark.sql.connector.read.partitioning.{
KeyGroupedPartitioning, Partitioning, UnknownPartitioning
}
import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2ScanRelation}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap

/**
 * Exposes an [[InMemoryRelation]] (a cached DataFrame) as a DSv2 [[Table]] so that
 * [[V2ScanRelationPushDown]] optimizer rules can apply column pruning, filter pushdown,
 * and ordering/statistics reporting to cache scans.
 */
private[sql] class InMemoryCacheTable(val relation: InMemoryRelation)
  extends Table with SupportsRead {

  // Table identity is the shared CachedRDDBuilder: every InMemoryRelation copy derived
  // from the same CachedData holds the very same cacheBuilder instance, so reference
  // equality on it identifies "the same cache entry" across re-aliased copies.
  override def equals(other: Any): Boolean = other match {
    case that: InMemoryCacheTable => this.relation.cacheBuilder eq that.relation.cacheBuilder
    case _ => false
  }

  override def hashCode(): Int = System.identityHashCode(relation.cacheBuilder)

  override def name(): String = relation.cacheBuilder.cachedName

  override def schema(): StructType = DataTypeUtils.fromAttributes(relation.output)

  override def capabilities(): util.Set[TableCapability] =
    util.EnumSet.of(TableCapability.BATCH_READ)

  // Read options carry no meaning for a cache scan; a fresh builder is handed out each time.
  override def newScanBuilder(options: CaseInsensitiveStringMap): InMemoryScanBuilder =
    new InMemoryScanBuilder(relation)
}

/**
 * DSv2 [[ScanBuilder]] over an [[InMemoryRelation]].
 *
 *  - Column pruning ([[SupportsPushDownRequiredColumns]]): only the requested columns are
 *    handed to [[InMemoryTableScanExec]], reducing deserialization work.
 *  - Filter pushdown ([[SupportsPushDownV2Filters]]): predicates are recorded for
 *    batch-level pruning via per-batch min/max statistics, but every predicate is handed
 *    back so Spark still re-evaluates them row by row after the scan.
 *  - Limit pushdown ([[SupportsPushDownLimit]]): always reported as partially pushed.
 */
private[sql] class InMemoryScanBuilder(relation: InMemoryRelation)
  extends ScanBuilder
  with SupportsPushDownRequiredColumns
  with SupportsPushDownV2Filters
  with SupportsPushDownLimit {

  // Mutable builder state, filled in by the push-down callbacks below and consumed by build().
  private var prunedSchema: StructType = DataTypeUtils.fromAttributes(relation.output)
  private var recordedPredicates: Array[V2Predicate] = Array.empty
  private var limitToPush: Option[Int] = None

  override def pruneColumns(required: StructType): Unit = {
    prunedSchema = required
  }

  /**
   * Records the predicates and returns all of them, so Spark keeps a post-scan
   * [[FilterExec]] for exact row-level evaluation. Batch-level min/max pruning is handled
   * at physical planning instead: [[DataSourceV2Strategy]] passes the Catalyst filter
   * expressions extracted by [[PhysicalOperation]] straight to [[InMemoryTableScanExec]],
   * which forwards them to [[CachedBatchSerializer.buildFilter]]. The V2 predicates kept
   * here are not themselves used for batch pruning.
   */
  override def pushPredicates(predicates: Array[V2Predicate]): Array[V2Predicate] = {
    recordedPredicates = predicates
    predicates
  }

  override def pushedPredicates(): Array[V2Predicate] = recordedPredicates

  /**
   * Accepts a pushed-down LIMIT. Cached data may be interleaved arbitrarily across
   * partitions, so the push is only ever partial and Spark still enforces the exact
   * count with a LocalLimit on top.
   */
  override def pushLimit(limit: Int): Boolean = {
    limitToPush = Some(limit)
    true
  }

  /** Always partial: an enforcing LocalLimit remains in the plan. */
  override def isPartiallyPushed(): Boolean = true

  override def build(): InMemoryCacheScan = {
    val wanted = prunedSchema.fieldNames.toSet
    // Keep the relation's own attribute order; fall back to the full output when
    // nothing was actually pruned.
    val attrs =
      if (wanted == relation.output.map(_.name).toSet) {
        relation.output
      } else {
        relation.output.filter(a => wanted.contains(a.name))
      }
    new InMemoryCacheScan(relation, attrs, recordedPredicates, limitToPush)
  }
}

/**
 * DSv2 [[Scan]] for an [[InMemoryRelation]].
 *
 * Physical execution is handled by [[InMemoryTableScanExec]] via [[DataSourceV2Strategy]]
 * rather than through [[Batch]]/[[InputPartition]], preserving the existing columnar
 * cached-batch path.
 *
 * Reports:
 *  - Ordering ([[SupportsReportOrdering]]): propagates the ordering of the original cached
 *    plan so the optimizer can eliminate redundant sorts on top of the cache.
 *  - Statistics ([[SupportsReportStatistics]]): exposes row count, size and column stats
 *    derived from the relation's `computeStats()`.
 *  - Partitioning ([[SupportsReportPartitioning]]): reports [[KeyGroupedPartitioning]]
 *    when the cached plan was hash-partitioned on plain columns, letting the optimizer
 *    skip shuffles for downstream joins/aggregates on the same key.
 *  - Runtime filtering ([[SupportsRuntimeV2Filtering]]): enables Dynamic Partition Pruning
 *    on cached scans; the `filter()` callback itself is a no-op — see its doc.
 *
 * @param relation         the cached relation this scan reads from
 * @param prunedAttrs      output attributes that survived column pruning (relation order)
 * @param pushedPredicates V2 predicates recorded by the scan builder (informational only)
 * @param pushedLimit      LIMIT accepted by the builder, if any (partial push)
 */
private[sql] class InMemoryCacheScan(
    val relation: InMemoryRelation,
    val prunedAttrs: Seq[Attribute],
    val pushedPredicates: Array[V2Predicate],
    val pushedLimit: Option[Int] = None)
  extends Scan
  with SupportsReportOrdering
  with SupportsReportStatistics
  with SupportsReportPartitioning
  with SupportsRuntimeV2Filtering {

  // The schema exposed to Spark is exactly the pruned attribute set.
  override def readSchema(): StructType = DataTypeUtils.fromAttributes(prunedAttrs)

/**
* Converts the Catalyst sort ordering of the cached plan to V2 [[SortOrder]]s.
* Only attribute-reference based orderings whose column is present in [[prunedAttrs]] are
* emitted; sort keys that were pruned away are dropped so that [[V2ScanPartitioningAndOrdering]]
* does not attempt to resolve a column that is no longer in the scan output.
*/
override def outputOrdering(): Array[V2SortOrder] = {
val prunedNames = prunedAttrs.map(_.name).toSet
relation.outputOrdering.flatMap {
case CatalystSortOrder(attr: AttributeReference, direction, nullOrdering, _)
if prunedNames.contains(attr.name) =>
val v2Dir = direction match {
case Ascending => V2SortDirection.ASCENDING
case Descending => V2SortDirection.DESCENDING
}
val v2Nulls = nullOrdering match {
case NullsFirst => V2NullOrdering.NULLS_FIRST
case NullsLast => V2NullOrdering.NULLS_LAST
}
Some(SortValue(FieldReference.column(attr.name), v2Dir, v2Nulls))
case _ => None
}.toArray
}

/**
* Reports the output partitioning of the cached plan so the optimizer can skip
* shuffles for downstream operations on the same partitioning key.
*/
override def outputPartitioning(): Partitioning = {
relation.cachedPlan.outputPartitioning match {
case HashPartitioning(expressions, numPartitions) =>
val keys = expressions.collect { case a: AttributeReference =>
FieldReference.column(a.name).asInstanceOf[V2Expression]
}
if (keys.size == expressions.size) {
new KeyGroupedPartitioning(keys.toArray, numPartitions)
} else {
new UnknownPartitioning(numPartitions)
}
case other => new UnknownPartitioning(other.numPartitions)
}
}

/**
* Exposes hash-partitioning key columns for Dynamic Partition Pruning.
* Spark will inject runtime IN-list filters on these attributes when it can
* derive them from a broadcast side of a join.
*/
override def filterAttributes(): Array[NamedReference] = {
relation.cachedPlan.outputPartitioning match {
case HashPartitioning(exprs, _) =>
exprs.collect { case a: AttributeReference =>
FieldReference.column(a.name).asInstanceOf[NamedReference]
}.toArray
case _ => Array.empty
}
}

  /**
   * Intentionally a no-op: runtime predicates for cached scans are routed entirely through
   * [[InMemoryTableScanExec.runtimeFilters]], not through this interface method. The DPP
   * pipeline injects [[DynamicPruning]] expressions into the plan, which
   * [[DataSourceV2Strategy]] separates and passes as runtimeFilters to the exec node for
   * batch-level min/max pruning.
   *
   * NOTE(review): this deviates from the V2 contract, where filter() is expected to prune
   * InputPartitions — confirm every caller of SupportsRuntimeV2Filtering tolerates a no-op
   * here and that the exec-node path covers all runtime-filter consumers.
   */
  override def filter(predicates: Array[V2Predicate]): Unit = {}
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

filter() is a no-op — runtime filtering is actually routed through InMemoryTableScanExec.runtimeFilters instead. This deviates from the V2 contract where filter() is expected to prune InputPartitions.


override def estimateStatistics(): V2Statistics = {
val stats = relation.computeStats()
// Scale sizeInBytes proportionally to the number of columns actually read.
// This gives the optimizer an accurate size estimate after column pruning.
val scaledSize: Long =
if (relation.output.nonEmpty && prunedAttrs.size < relation.output.size) {
(stats.sizeInBytes * prunedAttrs.size / relation.output.size).toLong.max(1)
} else {
stats.sizeInBytes.toLong
}
// Only report column stats for pruned (selected) attributes.
val prunedNames = prunedAttrs.map(_.name).toSet
val v2ColStats = new util.HashMap[NamedReference, ColumnStatistics]()
stats.attributeStats
.filter { case (attr, _) => prunedNames.contains(attr.name) }
.foreach { case (attr, colStat) =>
val cs = new ColumnStatistics {
override def distinctCount(): OptionalLong =
colStat.distinctCount
.map(v => OptionalLong.of(v.toLong)).getOrElse(OptionalLong.empty())
override def min(): util.Optional[Object] =
colStat.min.map(v => util.Optional.of(v.asInstanceOf[Object]))
.getOrElse(util.Optional.empty[Object]())
override def max(): util.Optional[Object] =
colStat.max.map(v => util.Optional.of(v.asInstanceOf[Object]))
.getOrElse(util.Optional.empty[Object]())
override def nullCount(): OptionalLong =
colStat.nullCount.map(v => OptionalLong.of(v.toLong)).getOrElse(OptionalLong.empty())
override def avgLen(): OptionalLong =
colStat.avgLen.map(OptionalLong.of).getOrElse(OptionalLong.empty())
override def maxLen(): OptionalLong =
colStat.maxLen.map(OptionalLong.of).getOrElse(OptionalLong.empty())
}
v2ColStats.put(FieldReference.column(attr.name), cs)
}
new V2Statistics {
override def sizeInBytes(): OptionalLong = OptionalLong.of(scaledSize)
override def numRows(): OptionalLong =
stats.rowCount.map(c => OptionalLong.of(c.toLong)).getOrElse(OptionalLong.empty())
override def columnStats(): util.Map[NamedReference, ColumnStatistics] = v2ColStats
}
}
}

/**
 * Extractor matching any in-plan representation of a cached DataFrame and yielding its
 * underlying [[InMemoryRelation]].
 *
 * A cache entry surfaces in three shapes depending on the compilation stage:
 *  1. a bare [[InMemoryRelation]] — the direct node (e.g. as stored in [[CachedData]]);
 *  2. a [[DataSourceV2Relation]] over an [[InMemoryCacheTable]] — produced by
 *     [[CacheManager]] in `useCachedData`, visible in `QueryExecution.withCachedData`;
 *  3. a [[DataSourceV2ScanRelation]] over an [[InMemoryCacheScan]] — after
 *     [[V2ScanRelationPushDown]] optimizes the above, visible in
 *     `QueryExecution.optimizedPlan`.
 */
object CachedRelation {
  def unapply(plan: LogicalPlan): Option[InMemoryRelation] = plan match {
    case cached: InMemoryRelation =>
      Some(cached)
    case r: DataSourceV2Relation =>
      r.table match {
        case t: InMemoryCacheTable => Some(t.relation)
        case _ => None
      }
    case s: DataSourceV2ScanRelation =>
      s.scan match {
        case sc: InMemoryCacheScan => Some(sc.relation)
        case _ => None
      }
    case _ =>
      None
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,15 @@ case class InMemoryRelation(
val newOutputOrdering = outputOrdering
.map(_.transform { case a: Attribute => map(a) })
.asInstanceOf[Seq[SortOrder]]
InMemoryRelation(newOutput, cacheBuilder, newOutputOrdering, statsOfPlanToCache)
// Remap attributeStats keys to new ExprIds by column name so column statistics survive
// attribute re-aliasing (withOutput is called on every cache lookup).
val nameToNew = newOutput.map(a => a.name -> a).toMap
val remappedColStats = statsOfPlanToCache.attributeStats.flatMap { case (attr, stat) =>
nameToNew.get(attr.name).map(_ -> stat)
}
val remappedStats = statsOfPlanToCache.copy(
attributeStats = AttributeMap(remappedColStats.toSeq))
InMemoryRelation(newOutput, cacheBuilder, newOutputOrdering, remappedStats)
}

override def newInstance(): this.type = {
Expand Down
Loading