diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
index c9d6fb61ba69..614b73b1547f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
@@ -692,7 +692,7 @@ case class ReplaceTableAsSelect(
     // RTAS may drop and recreate table before query execution, breaking self-references
     // refresh and pin versions here to read from original table versions instead of
     // newly created empty table that is meant to serve as target for append/overwrite
-    val refreshedQuery = V2TableRefreshUtil.refreshVersions(query)
+    val refreshedQuery = V2TableRefreshUtil.refresh(query, versionedOnly = true)
     val pinnedQuery = V2TableRefreshUtil.pinVersions(refreshedQuery)
     copy(query = pinnedQuery, isAnalyzed = true)
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala
index 3462ae0e4206..b0fb414fce97 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala
@@ -133,6 +133,8 @@ case class DataSourceV2Relation(
 
   def autoSchemaEvolution(): Boolean =
     table.capabilities().contains(TableCapability.AUTOMATIC_SCHEMA_EVOLUTION)
+
+  def isVersioned: Boolean = table.currentVersion != null
 }
 
 /**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2TableRefreshUtil.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2TableRefreshUtil.scala
index 4a2141629240..e98b80b6a5a0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2TableRefreshUtil.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2TableRefreshUtil.scala
@@ -40,7 +40,7 @@ private[sql] object V2TableRefreshUtil extends SQLConfHelper with Logging {
   def pinVersions(plan: LogicalPlan): LogicalPlan = {
     plan transform {
       case r @ ExtractV2CatalogAndIdentifier(catalog, ident)
-          if r.table.currentVersion != null && r.timeTravelSpec.isEmpty =>
+          if r.isVersioned && r.timeTravelSpec.isEmpty =>
        val tableName = V2TableUtil.toQualifiedName(catalog, ident)
        val version = r.table.currentVersion
        logDebug(s"Pinning table version for $tableName to $version")
@@ -49,21 +49,25 @@
   }
 
   /**
-   * Refreshes table metadata for all versioned tables in the plan.
+   * Refreshes table metadata for tables in the plan.
    *
    * This method reloads table metadata from the catalog and validates:
    * - Table identity: Ensures table ID has not changed
    * - Data columns: Verifies captured columns match the current schema
    * - Metadata columns: Checks metadata column consistency
    *
+   * Tables with time travel specifications are skipped as they reference a specific point
+   * in time and don't have to be refreshed.
+   *
    * @param plan the logical plan to refresh
+   * @param versionedOnly indicates whether to refresh only versioned tables
    * @return plan with refreshed table metadata
    */
-  def refreshVersions(plan: LogicalPlan): LogicalPlan = {
+  def refresh(plan: LogicalPlan, versionedOnly: Boolean = false): LogicalPlan = {
     val cache = mutable.HashMap.empty[(TableCatalog, Identifier), Table]
     plan transform {
       case r @ ExtractV2CatalogAndIdentifier(catalog, ident)
-          if r.table.currentVersion != null && r.timeTravelSpec.isEmpty =>
+          if (r.isVersioned || !versionedOnly) && r.timeTravelSpec.isEmpty =>
         val currentTable = cache.getOrElseUpdate((catalog, ident), {
           val tableName = V2TableUtil.toQualifiedName(catalog, ident)
           logDebug(s"Refreshing table metadata for $tableName")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
index 3944cf818895..a35efd96060f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
@@ -36,6 +36,7 @@ import org.apache.spark.sql.execution.columnar.InMemoryRelation
 import org.apache.spark.sql.execution.command.CommandUtils
 import org.apache.spark.sql.execution.datasources.{FileIndex, HadoopFsRelation, LogicalRelation, LogicalRelationWithTable}
 import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, ExtractV2Table, FileTable}
+import org.apache.spark.sql.execution.datasources.v2.V2TableRefreshUtil
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.storage.StorageLevel.MEMORY_AND_DISK
@@ -352,11 +353,12 @@ class CacheManager extends Logging with AdaptiveSparkPlanHelper {
     needToRecache.foreach { cd =>
       cd.cachedRepresentation.cacheBuilder.clearCache()
       val sessionWithConfigsOff = getOrCloneSessionWithConfigsOff(spark)
-      val newCache = sessionWithConfigsOff.withActive {
-        val qe = sessionWithConfigsOff.sessionState.executePlan(cd.plan)
-        InMemoryRelation(cd.cachedRepresentation.cacheBuilder, qe)
+      val (newKey, newCache) = sessionWithConfigsOff.withActive {
+        val refreshedPlan = V2TableRefreshUtil.refresh(cd.plan)
+        val qe = sessionWithConfigsOff.sessionState.executePlan(refreshedPlan)
+        qe.normalized -> InMemoryRelation(cd.cachedRepresentation.cacheBuilder, qe)
       }
-      val recomputedPlan = cd.copy(cachedRepresentation = newCache)
+      val recomputedPlan = cd.copy(plan = newKey, cachedRepresentation = newCache)
       this.synchronized {
         if (lookupCachedDataInternal(recomputedPlan.plan).nonEmpty) {
           logWarning("While recaching, data was already added to cache.")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
index 5c7fbcc8edd0..12fce2f91dac 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
@@ -207,7 +207,7 @@ class QueryExecution(
   // there may be delay between analysis and subsequent phases
   // therefore, refresh captured table versions to reflect latest data
   private val lazyTableVersionsRefreshed = LazyTry {
-    V2TableRefreshUtil.refreshVersions(commandExecuted)
+    V2TableRefreshUtil.refresh(commandExecuted, versionedOnly = true)
   }
 
   private[sql] def tableVersionsRefreshed: LogicalPlan = lazyTableVersionsRefreshed.get
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
index bf7491625fa0..56faf2032065 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
@@ -406,11 +406,12 @@ object InMemoryRelation {
   def apply(cacheBuilder: CachedRDDBuilder, qe: QueryExecution): InMemoryRelation = {
     val optimizedPlan = qe.optimizedPlan
     val serializer = cacheBuilder.serializer
-    val newBuilder = if (serializer.supportsColumnarInput(optimizedPlan.output)) {
-      cacheBuilder.copy(cachedPlan = serializer.convertToColumnarPlanIfPossible(qe.executedPlan))
+    val newCachedPlan = if (serializer.supportsColumnarInput(optimizedPlan.output)) {
+      serializer.convertToColumnarPlanIfPossible(qe.executedPlan)
     } else {
-      cacheBuilder.copy(cachedPlan = qe.executedPlan)
+      qe.executedPlan
     }
+    val newBuilder = cacheBuilder.copy(cachedPlan = newCachedPlan, logicalPlan = qe.logical)
     val relation = new InMemoryRelation(
       newBuilder.cachedPlan.output, newBuilder, optimizedPlan.outputOrdering)
     relation.statsOfPlanToCache = optimizedPlan.stats
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
index 91addd72ab2b..7faf580b6f7f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
@@ -68,6 +68,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils
 
   override def sparkConf: SparkConf = super.sparkConf
     .set("spark.sql.catalog.testcat", classOf[InMemoryCatalog].getName)
+    .set("spark.sql.catalog.testcat.copyOnLoad", "true")
 
   setupTestData()
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala
index c59e624cb178..a0a65b8524b7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala
@@ -1592,6 +1592,30 @@ class DataSourceV2DataFrameSuite
     }
   }
 
+  test("cached DSv2 table DataFrame is refreshed and reused after insert") {
+    val t = "testcat.ns1.ns2.tbl"
+    withTable(t) {
+      sql(s"CREATE TABLE $t (id bigint, data string) USING foo")
+      val df1 = Seq((1L, "a"), (2L, "b")).toDF("id", "data")
+      df1.write.insertInto(t)
+
+      // cache DataFrame pointing to table
+      val readDF1 = spark.table(t)
+      readDF1.cache()
+      assertCached(readDF1)
+      checkAnswer(readDF1, Seq(Row(1L, "a"), Row(2L, "b")))
+
+      // insert more data, invalidating and refreshing cache entry
+      val df2 = Seq((3L, "c"), (4L, "d")).toDF("id", "data")
+      df2.write.insertInto(t)
+
+      // verify underlying plan is recached and picks up new data
+      val readDF2 = spark.table(t)
+      assertCached(readDF2)
+      checkAnswer(readDF2, Seq(Row(1L, "a"), Row(2L, "b"), Row(3L, "c"), Row(4L, "d")))
+    }
+  }
+
   private def pinTable(catalogName: String, ident: Identifier, version: String): Unit = {
     catalog(catalogName) match {
       case inMemory: BasicInMemoryTableCatalog => inMemory.pinTable(ident, version)