From 939f8775c3db072e083117fdbce9ae8a5d3242d3 Mon Sep 17 00:00:00 2001 From: Anton Okolnychyi Date: Tue, 28 Oct 2025 11:13:41 -0700 Subject: [PATCH] [SPARK-54022][SQL] Make DSv2 table resolution aware of cached tables --- .../sql/catalyst/analysis/Analyzer.scala | 10 ++- .../sql/catalyst/analysis/RelationCache.scala | 28 +++++++ .../analysis/RelationResolution.scala | 74 ++++++++++++------ .../analysis/resolver/HybridAnalyzer.scala | 1 + .../catalyst/analysis/resolver/Resolver.scala | 13 +++- .../catalyst/plans/logical/v2Commands.scala | 8 +- .../sql/connector/catalog/CatalogV2Util.scala | 24 +++++- .../spark/sql/execution/CacheManager.scala | 75 ++++++++++++++++--- .../spark/sql/execution/QueryExecution.scala | 2 +- .../datasources/v2/V2TableRefreshUtil.scala | 30 ++++++-- .../v2/WriteToDataSourceV2Exec.scala | 9 ++- .../internal/BaseSessionStateBuilder.scala | 4 +- .../spark/sql/internal/SharedState.scala | 8 ++ .../apache/spark/sql/CachedTableSuite.scala | 35 +++++++++ .../sql/analysis/resolver/ResolverSuite.scala | 5 +- .../DataSourceV2DataFrameSuite.scala | 67 +++++++++++++++++ .../command/PlanResolutionSuite.scala | 4 +- .../sql/hive/HiveSessionStateBuilder.scala | 2 +- 18 files changed, 337 insertions(+), 62 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RelationCache.scala rename sql/{catalyst => core}/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2TableRefreshUtil.scala (79%) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 6b0665c1b7f3..8b0bc59dc80b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -80,7 +80,8 @@ object SimpleAnalyzer extends Analyzer( FunctionRegistry.builtin, TableFunctionRegistry.builtin) { override def createDatabase(dbDefinition: CatalogDatabase, ignoreIfExists: Boolean): Unit = {} - })) { + }), + RelationCache.empty) { override def resolver: Resolver = caseSensitiveResolution } @@ -285,11 +286,14 @@ object Analyzer { * Provides a logical query plan analyzer, which translates [[UnresolvedAttribute]]s and * [[UnresolvedRelation]]s into fully typed objects using information in a [[SessionCatalog]]. 
*/ -class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor[LogicalPlan] +class Analyzer( + override val catalogManager: CatalogManager, + private[sql] val sharedRelationCache: RelationCache = RelationCache.empty) + extends RuleExecutor[LogicalPlan] with CheckAnalysis with AliasHelper with SQLConfHelper with ColumnResolutionHelper { private val v1SessionCatalog: SessionCatalog = catalogManager.v1SessionCatalog - private val relationResolution = new RelationResolution(catalogManager) + private val relationResolution = new RelationResolution(catalogManager, sharedRelationCache) private val functionResolution = new FunctionResolution(catalogManager, relationResolution) override protected def validatePlanChanges( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RelationCache.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RelationCache.scala new file mode 100644 index 000000000000..770a5e780b24 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RelationCache.scala @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan + +private[sql] trait RelationCache { + def lookup(nameParts: Seq[String], resolver: Resolver): Option[LogicalPlan] +} + +private[sql] object RelationCache { + val empty: RelationCache = (_, _) => None +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RelationResolution.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RelationResolution.scala index c7b92bc2a9fe..15d5e4874dbb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RelationResolution.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RelationResolution.scala @@ -46,7 +46,9 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.util.ArrayImplicits._ -class RelationResolution(override val catalogManager: CatalogManager) +class RelationResolution( + override val catalogManager: CatalogManager, + sharedRelationCache: RelationCache) extends DataTypeErrorsBase with Logging with LookupCatalog @@ -118,36 +120,62 @@ class RelationResolution(override val catalogManager: CatalogManager) val planId = u.getTagValue(LogicalPlan.PLAN_ID_TAG) relationCache .get(key) - .map { cache => - val cachedRelation = cache.transform { - case multi: MultiInstanceRelation => - val newRelation = multi.newInstance() - newRelation.copyTagsFrom(multi) - newRelation - } - cloneWithPlanId(cachedRelation, planId) - } + .map(adaptCachedRelation(_, planId)) .orElse { - val writePrivilegesString = - Option(u.options.get(UnresolvedRelation.REQUIRED_WRITE_PRIVILEGES)) - val table = - CatalogV2Util.loadTable(catalog, ident, finalTimeTravelSpec, writePrivilegesString) - val loaded = createRelation( + val writePrivileges = u.options.get(UnresolvedRelation.REQUIRED_WRITE_PRIVILEGES) + val finalOptions = u.clearWritePrivileges.options + val table = CatalogV2Util.loadTable( catalog, ident, - table, - u.clearWritePrivileges.options, - u.isStreaming, - finalTimeTravelSpec - ) - loaded.foreach(relationCache.update(key, _)) - loaded.map(cloneWithPlanId(_, planId)) - } + finalTimeTravelSpec, + Option(writePrivileges)) + + val sharedRelationCacheMatch = for { + t <- table + if finalTimeTravelSpec.isEmpty && writePrivileges == null && !u.isStreaming + cached <- lookupSharedRelationCache(catalog, ident, t) + } yield { + val updatedRelation = cached.copy(options = finalOptions) + val nameParts = ident.toQualifiedNameParts(catalog) + val aliasedRelation = SubqueryAlias(nameParts, updatedRelation) + relationCache.update(key, aliasedRelation) + adaptCachedRelation(aliasedRelation, planId) + } + + sharedRelationCacheMatch.orElse { + val loaded = createRelation( + catalog, + ident, + table, + finalOptions, + u.isStreaming, + finalTimeTravelSpec) + loaded.foreach(relationCache.update(key, _)) + loaded.map(cloneWithPlanId(_, planId)) + } + } case _ => None } } } + private def lookupSharedRelationCache( + catalog: CatalogPlugin, + ident: Identifier, + table: Table): Option[DataSourceV2Relation] = { + CatalogV2Util.lookupCachedRelation(sharedRelationCache, catalog, ident, table, conf) + } + + private def adaptCachedRelation(cached: LogicalPlan, planId: Option[Long]): LogicalPlan = { + val plan = cached transform { + case multi: MultiInstanceRelation => + val newRelation = multi.newInstance() + newRelation.copyTagsFrom(multi) + newRelation + } + cloneWithPlanId(plan, planId) + } + private 
def createRelation( catalog: CatalogPlugin, ident: Identifier, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/HybridAnalyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/HybridAnalyzer.scala index 0117b3fc2fb5..d346969be8ef 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/HybridAnalyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/HybridAnalyzer.scala @@ -302,6 +302,7 @@ object HybridAnalyzer { resolverGuard = new ResolverGuard(legacyAnalyzer.catalogManager), resolver = new Resolver( catalogManager = legacyAnalyzer.catalogManager, + sharedRelationCache = legacyAnalyzer.sharedRelationCache, extensions = legacyAnalyzer.singlePassResolverExtensions, metadataResolverExtensions = legacyAnalyzer.singlePassMetadataResolverExtensions, externalRelationResolution = Some(legacyAnalyzer.getRelationResolution) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/Resolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/Resolver.scala index 75d23f29ecfc..78029d593df1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/Resolver.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/Resolver.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.analysis.{ AnalysisErrorAt, FunctionResolution, MultiInstanceRelation, + RelationCache, RelationResolution, ResolvedInlineTable, UnresolvedHaving, @@ -71,6 +72,7 @@ import org.apache.spark.sql.errors.QueryCompilationErrors */ class Resolver( catalogManager: CatalogManager, + sharedRelationCache: RelationCache = RelationCache.empty, override val extensions: Seq[ResolverExtension] = Seq.empty, metadataResolverExtensions: Seq[ResolverExtension] = Seq.empty, externalRelationResolution: Option[RelationResolution] = None) @@ -81,8 +83,9 @@ class Resolver( private val cteRegistry = new CteRegistry private val subqueryRegistry = new SubqueryRegistry private val identifierAndCteSubstitutor = new IdentifierAndCteSubstitutor - private val relationResolution = - externalRelationResolution.getOrElse(Resolver.createRelationResolution(catalogManager)) + private val relationResolution = externalRelationResolution.getOrElse { + Resolver.createRelationResolution(catalogManager, sharedRelationCache) + } private val functionResolution = new FunctionResolution(catalogManager, relationResolution) private val expressionResolver = new ExpressionResolver(this, functionResolution, planLogger) private val aggregateResolver = new AggregateResolver(this, expressionResolver) @@ -788,7 +791,9 @@ object Resolver { /** * Create a new instance of the [[RelationResolution]]. 
*/ - def createRelationResolution(catalogManager: CatalogManager): RelationResolution = { - new RelationResolution(catalogManager) + def createRelationResolution( + catalogManager: CatalogManager, + sharedRelationCache: RelationCache = RelationCache.empty): RelationResolution = { + new RelationResolution(catalogManager, sharedRelationCache) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala index 614b73b1547f..62a336f8f308 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala @@ -40,7 +40,6 @@ import org.apache.spark.sql.connector.write.RowLevelOperation.Command.{DELETE, M import org.apache.spark.sql.errors.DataTypeErrors.toSQLType import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, ExtractV2Table} -import org.apache.spark.sql.execution.datasources.v2.V2TableRefreshUtil import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{ArrayType, AtomicType, BooleanType, DataType, IntegerType, MapType, MetadataBuilder, StringType, StructField, StructType} import org.apache.spark.util.ArrayImplicits._ @@ -689,12 +688,7 @@ case class ReplaceTableAsSelect( extends V2CreateTableAsSelectPlan { override def markAsAnalyzed(ac: AnalysisContext): LogicalPlan = { - // RTAS may drop and recreate table before query execution, breaking self-references - // refresh and pin versions here to read from original table versions instead of - // newly created empty table that is meant to serve as target for append/overwrite - val refreshedQuery = V2TableRefreshUtil.refresh(query, versionedOnly = true) - val pinnedQuery = V2TableRefreshUtil.pinVersions(refreshedQuery) - copy(query = pinnedQuery, isAnalyzed = true) + copy(isAnalyzed = true) } override def withPartitioning(rewritten: Seq[Transform]): V2CreateTablePlan = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala index 28bca400f5b8..07cb370d18dd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala @@ -25,7 +25,7 @@ import scala.jdk.CollectionConverters._ import org.apache.spark.{SparkException, SparkIllegalArgumentException} import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.CurrentUserContext -import org.apache.spark.sql.catalyst.analysis.{AsOfTimestamp, AsOfVersion, NamedRelation, NoSuchDatabaseException, NoSuchFunctionException, NoSuchTableException, TimeTravelSpec} +import org.apache.spark.sql.catalyst.analysis.{AsOfTimestamp, AsOfVersion, NamedRelation, NoSuchDatabaseException, NoSuchFunctionException, NoSuchTableException, RelationCache, TimeTravelSpec} import org.apache.spark.sql.catalyst.catalog.ClusterBySpec import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, V2ExpressionUtils} import org.apache.spark.sql.catalyst.plans.logical.{SerdeInfo, TableSpec} @@ -36,6 +36,7 @@ import org.apache.spark.sql.connector.catalog.constraints.Constraint import org.apache.spark.sql.connector.catalog.functions.UnboundFunction import 
org.apache.spark.sql.connector.expressions.{ClusterByTransform, LiteralValue, Transform} import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{ArrayType, MapType, Metadata, MetadataBuilder, StructField, StructType} import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.util.ArrayImplicits._ @@ -497,6 +498,27 @@ private[sql] object CatalogV2Util { loadTable(catalog, ident).map(DataSourceV2Relation.create(_, Some(catalog), Some(ident))) } + def isSameTable( + rel: DataSourceV2Relation, + catalog: CatalogPlugin, + ident: Identifier, + table: Table): Boolean = { + rel.catalog.contains(catalog) && rel.identifier.contains(ident) && rel.table.id == table.id + } + + def lookupCachedRelation( + cache: RelationCache, + catalog: CatalogPlugin, + ident: Identifier, + table: Table, + conf: SQLConf): Option[DataSourceV2Relation] = { + val nameParts = ident.toQualifiedNameParts(catalog) + val cached = cache.lookup(nameParts, conf.resolver) + cached.collect { + case r: DataSourceV2Relation if isSameTable(r, catalog, ident, table) => r + } + } + def isSessionCatalog(catalog: CatalogPlugin): Boolean = { catalog.name().equalsIgnoreCase(CatalogManager.SESSION_CATALOG_NAME) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala index a35efd96060f..5a38751b61e1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala @@ -22,6 +22,7 @@ import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.spark.internal.{Logging, MessageWithContext} import org.apache.spark.internal.LogKeys._ import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases +import org.apache.spark.sql.catalyst.analysis.Resolver import org.apache.spark.sql.catalyst.analysis.V2TableReference import org.apache.spark.sql.catalyst.catalog.HiveTableRelation import org.apache.spark.sql.catalyst.expressions.{Attribute, SubqueryExpression} @@ -30,13 +31,14 @@ import org.apache.spark.sql.catalyst.plans.logical.{IgnoreCachedData, LogicalPla import org.apache.spark.sql.catalyst.trees.TreePattern.PLAN_EXPRESSION import org.apache.spark.sql.catalyst.util.sideBySide import org.apache.spark.sql.classic.{Dataset, SparkSession} -import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.IdentifierHelper +import org.apache.spark.sql.connector.catalog.CatalogPlugin +import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.{IdentifierHelper, MultipartIdentifierHelper} +import org.apache.spark.sql.connector.catalog.Identifier import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper import org.apache.spark.sql.execution.columnar.InMemoryRelation import org.apache.spark.sql.execution.command.CommandUtils import org.apache.spark.sql.execution.datasources.{FileIndex, HadoopFsRelation, LogicalRelation, LogicalRelationWithTable} -import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, ExtractV2Table, FileTable} -import org.apache.spark.sql.execution.datasources.v2.V2TableRefreshUtil +import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, ExtractV2CatalogAndIdentifier, ExtractV2Table, FileTable, V2TableRefreshUtil} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.storage.StorageLevel import 
org.apache.spark.storage.StorageLevel.MEMORY_AND_DISK @@ -240,31 +242,51 @@ class CacheManager extends Logging with AdaptiveSparkPlanHelper { name: Seq[String], conf: SQLConf, includeTimeTravel: Boolean): Boolean = { - def isSameName(nameInCache: Seq[String]): Boolean = { - nameInCache.length == name.length && nameInCache.zip(name).forall(conf.resolver.tupled) - } + isMatchedTableOrView(plan, name, conf.resolver, includeTimeTravel) + } + + private def isMatchedTableOrView( + plan: LogicalPlan, + name: Seq[String], + resolver: Resolver, + includeTimeTravel: Boolean): Boolean = { EliminateSubqueryAliases(plan) match { case LogicalRelationWithTable(_, Some(catalogTable)) => - isSameName(catalogTable.identifier.nameParts) + isSameName(name, catalogTable.identifier.nameParts, resolver) case DataSourceV2Relation(_, _, Some(catalog), Some(v2Ident), _, timeTravelSpec) => val nameInCache = v2Ident.toQualifiedNameParts(catalog) - isSameName(nameInCache) && (includeTimeTravel || timeTravelSpec.isEmpty) + isSameName(name, nameInCache, resolver) && (includeTimeTravel || timeTravelSpec.isEmpty) case r: V2TableReference => - isSameName(r.identifier.toQualifiedNameParts(r.catalog)) + isSameName(name, r.identifier.toQualifiedNameParts(r.catalog), resolver) case v: View => - isSameName(v.desc.identifier.nameParts) + isSameName(name, v.desc.identifier.nameParts, resolver) case HiveTableRelation(catalogTable, _, _, _, _) => - isSameName(catalogTable.identifier.nameParts) + isSameName(name, catalogTable.identifier.nameParts, resolver) case _ => false } } + private def isSameName( + name: Seq[String], + catalog: CatalogPlugin, + ident: Identifier, + resolver: Resolver): Boolean = { + isSameName(name, ident.toQualifiedNameParts(catalog), resolver) + } + + private def isSameName( + name: Seq[String], + nameInCache: Seq[String], + resolver: Resolver): Boolean = { + nameInCache.length == name.length && nameInCache.zip(name).forall(resolver.tupled) + } + private def uncacheByCondition( spark: SparkSession, isMatchedPlan: LogicalPlan => Boolean, @@ -354,7 +376,7 @@ class CacheManager extends Logging with AdaptiveSparkPlanHelper { cd.cachedRepresentation.cacheBuilder.clearCache() val sessionWithConfigsOff = getOrCloneSessionWithConfigsOff(spark) val (newKey, newCache) = sessionWithConfigsOff.withActive { - val refreshedPlan = V2TableRefreshUtil.refresh(cd.plan) + val refreshedPlan = V2TableRefreshUtil.refresh(sessionWithConfigsOff, cd.plan) val qe = sessionWithConfigsOff.sessionState.executePlan(refreshedPlan) qe.normalized -> InMemoryRelation(cd.cachedRepresentation.cacheBuilder, qe) } @@ -371,6 +393,35 @@ class CacheManager extends Logging with AdaptiveSparkPlanHelper { } } + private[sql] def lookupCachedTable( + name: Seq[String], + resolver: Resolver): Option[LogicalPlan] = { + val cachedRelations = findCachedRelations(name, resolver) + cachedRelations match { + case cachedRelation +: _ => + CacheManager.logCacheOperation( + log"Relation cache hit for table ${MDC(TABLE_NAME, name.quoted)}") + Some(cachedRelation) + case _ => + None + } + } + + private def findCachedRelations( + name: Seq[String], + resolver: Resolver): Seq[LogicalPlan] = { + cachedData.flatMap { cd => + val plan = EliminateSubqueryAliases(cd.plan) + plan match { + case r @ ExtractV2CatalogAndIdentifier(catalog, ident) + if isSameName(name, catalog, ident, resolver) && r.timeTravelSpec.isEmpty => + Some(r) + case _ => + None + } + } + } + /** * Optionally returns cached data for the given [[Dataset]] */ diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 12fce2f91dac..26d2078791aa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -207,7 +207,7 @@ class QueryExecution( // there may be delay between analysis and subsequent phases // therefore, refresh captured table versions to reflect latest data private val lazyTableVersionsRefreshed = LazyTry { - V2TableRefreshUtil.refresh(commandExecuted, versionedOnly = true) + V2TableRefreshUtil.refresh(sparkSession, commandExecuted, versionedOnly = true) } private[sql] def tableVersionsRefreshed: LogicalPlan = lazyTableVersionsRefreshed.get diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2TableRefreshUtil.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2TableRefreshUtil.scala similarity index 79% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2TableRefreshUtil.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2TableRefreshUtil.scala index e98b80b6a5a0..db75f83d658a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2TableRefreshUtil.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2TableRefreshUtil.scala @@ -23,7 +23,9 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.SQLConfHelper import org.apache.spark.sql.catalyst.analysis.AsOfVersion import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.classic.SparkSession import org.apache.spark.sql.connector.catalog.{Identifier, Table, TableCatalog, V2TableUtil} +import org.apache.spark.sql.connector.catalog.CatalogV2Util import org.apache.spark.sql.errors.QueryCompilationErrors private[sql] object V2TableRefreshUtil extends SQLConfHelper with Logging { @@ -59,19 +61,29 @@ private[sql] object V2TableRefreshUtil extends SQLConfHelper with Logging { * Tables with time travel specifications are skipped as they reference a specific point * in time and don't have to be refreshed. 
* + * @param spark the currently active Spark session * @param plan the logical plan to refresh * @param versionedOnly indicates whether to refresh only versioned tables * @return plan with refreshed table metadata */ - def refresh(plan: LogicalPlan, versionedOnly: Boolean = false): LogicalPlan = { - val cache = mutable.HashMap.empty[(TableCatalog, Identifier), Table] + def refresh( + spark: SparkSession, + plan: LogicalPlan, + versionedOnly: Boolean = false): LogicalPlan = { + val currentTables = mutable.HashMap.empty[(TableCatalog, Identifier), Table] plan transform { case r @ ExtractV2CatalogAndIdentifier(catalog, ident) if (r.isVersioned || !versionedOnly) && r.timeTravelSpec.isEmpty => - val currentTable = cache.getOrElseUpdate((catalog, ident), { + val currentTable = currentTables.getOrElseUpdate((catalog, ident), { val tableName = V2TableUtil.toQualifiedName(catalog, ident) - logDebug(s"Refreshing table metadata for $tableName") - catalog.loadTable(ident) + lookupCachedRelation(spark, catalog, ident, r.table) match { + case Some(cached) => + logDebug(s"Refreshing table metadata for $tableName using shared relation cache") + cached.table + case None => + logDebug(s"Refreshing table metadata for $tableName using catalog") + catalog.loadTable(ident) + } }) validateTableIdentity(currentTable, r) validateDataColumns(currentTable, r) @@ -80,6 +92,14 @@ private[sql] object V2TableRefreshUtil extends SQLConfHelper with Logging { } } + private def lookupCachedRelation( + spark: SparkSession, + catalog: TableCatalog, + ident: Identifier, + table: Table): Option[DataSourceV2Relation] = { + CatalogV2Util.lookupCachedRelation(spark.sharedState.relationCache, catalog, ident, table, conf) + } + private def validateTableIdentity(currentTable: Table, relation: DataSourceV2Relation): Unit = { if (relation.table.id != null && relation.table.id != currentTable.id) { throw QueryCompilationErrors.tableIdChangedAfterAnalysis( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala index 1d7566ce7f3e..9e5e45e984eb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala @@ -168,6 +168,11 @@ case class ReplaceTableAsSelectExec( // 1. Creating the new table fails, // 2. Writing to the new table fails, // 3. The table returned by catalog.createTable doesn't support writing. 
+ // + // RTAS must refresh and pin versions in query to read from original table versions instead of + // newly created empty table that is meant to serve as target for append/overwrite + val refreshedQuery = V2TableRefreshUtil.refresh(session, query, versionedOnly = true) + val pinnedQuery = V2TableRefreshUtil.pinVersions(refreshedQuery) if (catalog.tableExists(ident)) { invalidateCache(catalog, ident) catalog.dropTable(ident) @@ -175,13 +180,13 @@ case class ReplaceTableAsSelectExec( throw QueryCompilationErrors.cannotReplaceMissingTableError(ident) } val tableInfo = new TableInfo.Builder() - .withColumns(getV2Columns(query.schema, catalog.useNullableQuerySchema)) + .withColumns(getV2Columns(pinnedQuery.schema, catalog.useNullableQuerySchema)) .withPartitions(partitioning.toArray) .withProperties(properties.asJava) .build() val table = Option(catalog.createTable(ident, tableInfo)) .getOrElse(catalog.loadTable(ident, Set(TableWritePrivilege.INSERT).asJava)) - writeToTable(catalog, table, writeOptions, ident, query, overwrite = true) + writeToTable(catalog, table, writeOptions, ident, pinnedQuery, overwrite = true) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala index c967497b660c..ef829eaae68c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala @@ -176,6 +176,8 @@ abstract class BaseSessionStateBuilder( protected lazy val catalogManager = new CatalogManager(v2SessionCatalog, catalog) + protected lazy val sharedRelationCache = session.sharedState.relationCache + /** * Interface exposed to the user for registering user-defined functions. * @@ -197,7 +199,7 @@ abstract class BaseSessionStateBuilder( * * Note: this depends on the `conf` and `catalog` fields. */ - protected def analyzer: Analyzer = new Analyzer(catalogManager) { + protected def analyzer: Analyzer = new Analyzer(catalogManager, sharedRelationCache) { override val hintResolutionRules: Seq[Rule[LogicalPlan]] = customHintResolutionRules diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala index af1f38caab68..8e641294bf8c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala @@ -31,6 +31,7 @@ import org.apache.hadoop.fs.{FsUrlStreamHandlerFactory, Path} import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.internal.Logging import org.apache.spark.internal.LogKeys.{CONFIG, CONFIG2, PATH, VALUE} +import org.apache.spark.sql.catalyst.analysis.RelationCache import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.CacheManager @@ -96,6 +97,13 @@ private[sql] class SharedState( */ val cacheManager: CacheManager = new CacheManager + /** + * A relation cache backed by the cache manager. + */ + private[sql] val relationCache: RelationCache = { + (nameParts, resolver) => cacheManager.lookupCachedTable(nameParts, resolver) + } + /** A global lock for all streaming query lifecycle tracking and management. 
*/ private[sql] val activeQueriesLock = new Object diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 7faf580b6f7f..12d26c4e195f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql import java.io.{File, FilenameFilter} import java.nio.file.{Files, Paths} import java.time.{Duration, LocalDateTime, LocalTime, Period} +import java.util import java.util.concurrent.atomic.AtomicBoolean import scala.collection.mutable.HashSet @@ -41,6 +42,8 @@ import org.apache.spark.sql.connector.catalog.CatalogPlugin import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.CatalogHelper import org.apache.spark.sql.connector.catalog.Identifier import org.apache.spark.sql.connector.catalog.InMemoryCatalog +import org.apache.spark.sql.connector.catalog.TableWritePrivilege +import org.apache.spark.sql.connector.catalog.TruncatableTable import org.apache.spark.sql.execution.{ColumnarToRowExec, ExecSubqueryExpression, RDDScanExec, SparkPlan, SparkPlanInfo} import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, AQEPropagateEmptyRelation} import org.apache.spark.sql.execution.columnar._ @@ -2563,6 +2566,38 @@ class CachedTableSuite extends QueryTest with SQLTestUtils context = ExpectedContext("non_existent", 14, 25)) } + test("SPARK-54022: caching table via CACHE TABLE should pin table state") { + val t = "testcat.ns1.ns2.tbl" + val ident = Identifier.of(Array("ns1", "ns2"), "tbl") + withTable(t) { + sql(s"CREATE TABLE $t (id INT, value INT, category STRING) USING foo") + sql(s"INSERT INTO $t VALUES (1, 10, 'A'), (2, 20, 'B'), (3, 30, 'A')") + + // cache table + sql(s"CACHE TABLE $t") + + // verify caching works as expected + assertCached(spark.table(t)) + checkAnswer(spark.table(t), Seq(Row(1, 10, "A"), Row(2, 20, "B"), Row(3, 30, "A"))) + + // modify table directly to mimic external changes + val tableCatalog = catalog("testcat").asTableCatalog + val table = tableCatalog.loadTable(ident, util.Set.of(TableWritePrivilege.DELETE)) + table.asInstanceOf[TruncatableTable].truncateTable() + + // verify this has no impact on cached state + assertCached(spark.table(t)) + checkAnswer(spark.table(t), Seq(Row(1, 10, "A"), Row(2, 20, "B"), Row(3, 30, "A"))) + + // add more data within session that should invalidate cache + sql(s"INSERT INTO $t VALUES (10, 100, 'x')") + + // table should be re-cached correctly + assertCached(spark.table(t)) + checkAnswer(spark.table(t), Seq(Row(10, 100, "x"))) + } + } + private def cacheManager = spark.sharedState.cacheManager private def pinTable( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/analysis/resolver/ResolverSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/analysis/resolver/ResolverSuite.scala index 8a54f6520974..0e23f984d922 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/analysis/resolver/ResolverSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/analysis/resolver/ResolverSuite.scala @@ -108,7 +108,10 @@ class ResolverSuite extends QueryTest with SharedSparkSession { } private def createResolver(extensions: Seq[ResolverExtension] = Seq.empty): Resolver = { - new Resolver(spark.sessionState.catalogManager, extensions) + new Resolver( + spark.sessionState.catalogManager, + spark.sharedState.relationCache, + extensions) } private class TestRelationResolver extends 
ResolverExtension { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala index 205fa561a5b0..0d56d95fdb5b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.connector +import java.util import java.util.Collections import scala.jdk.CollectionConverters._ @@ -29,6 +30,8 @@ import org.apache.spark.sql.catalyst.plans.logical.{AppendData, CreateTableAsSel import org.apache.spark.sql.connector.catalog.{Column, ColumnDefaultValue, DefaultValue, Identifier, InMemoryTableCatalog, TableInfo} import org.apache.spark.sql.connector.catalog.BasicInMemoryTableCatalog import org.apache.spark.sql.connector.catalog.TableChange.{AddColumn, UpdateColumnDefaultValue} +import org.apache.spark.sql.connector.catalog.TableWritePrivilege +import org.apache.spark.sql.connector.catalog.TruncatableTable import org.apache.spark.sql.connector.expressions.{ApplyTransform, GeneralScalarExpression, LiteralValue, Transform} import org.apache.spark.sql.connector.expressions.filter.{AlwaysFalse, AlwaysTrue} import org.apache.spark.sql.execution.{QueryExecution, SparkPlan} @@ -1644,6 +1647,70 @@ class DataSourceV2DataFrameSuite } } + test("SPARK-54022: caching table via Dataset API should pin table state") { + val t = "testcat.ns1.ns2.tbl" + val ident = Identifier.of(Array("ns1", "ns2"), "tbl") + withTable(t) { + sql(s"CREATE TABLE $t (id INT, value INT, category STRING) USING foo") + sql(s"INSERT INTO $t VALUES (1, 10, 'A'), (2, 20, 'B'), (3, 30, 'A')") + + // cache table + spark.table(t).cache() + + // verify caching works as expected + assertCached(spark.table(t)) + checkAnswer(spark.table(t), Seq(Row(1, 10, "A"), Row(2, 20, "B"), Row(3, 30, "A"))) + + // modify table directly to mimic external changes + val table = catalog("testcat").loadTable(ident, util.Set.of(TableWritePrivilege.DELETE)) + table.asInstanceOf[TruncatableTable].truncateTable() + + // verify external changes have no impact on cached state + assertCached(spark.table(t)) + checkAnswer(spark.table(t), Seq(Row(1, 10, "A"), Row(2, 20, "B"), Row(3, 30, "A"))) + + // add more data within session that should invalidate cache + sql(s"INSERT INTO $t VALUES (10, 100, 'x')") + + // table should be re-cached correctly + assertCached(spark.table(t)) + checkAnswer(spark.table(t), Seq(Row(10, 100, "x"))) + } + } + + test("SPARK-54022: caching a query via Dataset API should not pin table state") { + val t = "testcat.ns1.ns2.tbl" + val ident = Identifier.of(Array("ns1", "ns2"), "tbl") + withTable(t) { + sql(s"CREATE TABLE $t (id INT, value INT, category STRING) USING foo") + sql(s"INSERT INTO $t VALUES (1, 10, 'A'), (2, 20, 'B'), (3, 30, 'A')") + + // cache query on top of table + val df = spark.table(t).select("id") + df.cache() + + // verify query caching works as expected + assertCached(spark.table(t).select("id")) + checkAnswer(spark.table(t).select("id"), Seq(Row(1), Row(2), Row(3))) + + // verify table itself is not cached + assertNotCached(spark.table(t)) + checkAnswer(spark.table(t), Seq(Row(1, 10, "A"), Row(2, 20, "B"), Row(3, 30, "A"))) + + // modify table directly to mimic external changes + val table = catalog("testcat").loadTable(ident, util.Set.of(TableWritePrivilege.DELETE)) + table.asInstanceOf[TruncatableTable].truncateTable() + + // 
verify cached DataFrame is unaffected by external changes + assertCached(df) + checkAnswer(df, Seq(Row(1), Row(2), Row(3))) + + // verify external changes are reflected correctly when table is queried + assertNotCached(spark.table(t)) + checkAnswer(spark.table(t), Seq.empty) + } + } + private def pinTable(catalogName: String, ident: Identifier, version: String): Unit = { catalog(catalogName) match { case inMemory: BasicInMemoryTableCatalog => inMemory.pinTable(ident, version) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala index dfd24a1ebe97..b677ea78fdca 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala @@ -3284,7 +3284,9 @@ class PlanResolutionSuite extends SharedSparkSession with AnalysisTest { unresolvedRelation: UnresolvedRelation, timeTravelSpec: Option[TimeTravelSpec] = None, planId: Option[Long] = None): DataSourceV2Relation = { - val rule = new RelationResolution(catalogManagerWithDefault) + val rule = new RelationResolution( + catalogManagerWithDefault, + spark.sharedState.relationCache) rule.resolveRelation(unresolvedRelation, timeTravelSpec) match { case Some(p @ AsDataSourceV2Relation(relation)) => assert(unresolvedRelation.getTagValue(LogicalPlan.PLAN_ID_TAG) == planId) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala index aa801b6e2f68..dec947651dd6 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala @@ -85,7 +85,7 @@ class HiveSessionStateBuilder( /** * A logical query plan `Analyzer` with rules specific to Hive. */ - override protected def analyzer: Analyzer = new Analyzer(catalogManager) { + override protected def analyzer: Analyzer = new Analyzer(catalogManager, sharedRelationCache) { override val singlePassResolverExtensions: Seq[ResolverExtension] = Seq( new LogicalRelationResolver, new HiveTableRelationNoopResolver
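
Note on wiring (illustrative, not part of the patch): the RelationCache trait introduced above is a single-abstract-method interface, so SharedState implements it with a lambda that delegates to CacheManager.lookupCachedTable. The sketch below shows how a caller might supply its own cache to the Analyzer; the InMemoryRelationCache class and buildAnalyzer helper are hypothetical names used only for illustration, the name matching mirrors the isSameName logic added to CacheManager, and since RelationCache is private[sql] such code would have to live inside Spark's sql package.

    import org.apache.spark.sql.catalyst.analysis.{Analyzer, RelationCache, Resolver}
    import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
    import org.apache.spark.sql.connector.catalog.CatalogManager

    // Hypothetical map-backed cache; the patch itself backs RelationCache with
    // CacheManager via SharedState.relationCache rather than a plain map.
    class InMemoryRelationCache(entries: Map[Seq[String], LogicalPlan]) extends RelationCache {
      override def lookup(nameParts: Seq[String], resolver: Resolver): Option[LogicalPlan] = {
        // Same multipart-name comparison as CacheManager.isSameName: equal length and
        // each part compared with the session's case-sensitivity-aware resolver.
        entries.collectFirst {
          case (name, plan)
              if name.length == nameParts.length &&
                name.zip(nameParts).forall(resolver.tupled) =>
            plan
        }
      }
    }

    // The cache is the second Analyzer constructor argument and defaults to
    // RelationCache.empty, matching how BaseSessionStateBuilder passes
    // session.sharedState.relationCache.
    def buildAnalyzer(catalogManager: CatalogManager, cache: RelationCache): Analyzer =
      new Analyzer(catalogManager, cache)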