@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.catalog

 import java.net.URI
 import java.util.Locale
-import java.util.concurrent.Callable
+import java.util.concurrent.{Callable, ExecutionException, TimeUnit}
 import javax.annotation.concurrent.GuardedBy

 import scala.collection.mutable
@@ -31,7 +31,7 @@ import org.apache.hadoop.fs.Path

 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst._
+import org.apache.spark.sql.catalyst.{QualifiedTableName, _}
 import org.apache.spark.sql.catalyst.analysis._
 import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
 import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo, ImplicitCastInputTypes}
@@ -104,6 +104,12 @@ class SessionCatalog(

   private val validNameFormat = "([\\w_]+)".r

+  // Cache of table metadata keyed by qualified table name. Entries expire after a
+  // configurable number of seconds, which bounds how stale cached metadata can get.
+  private val catalogTableCache = {
+    val expireSeconds = conf.tableCatalogCacheExpireSeconds
+    CacheBuilder.newBuilder().expireAfterWrite(expireSeconds, TimeUnit.SECONDS)
+      .build[QualifiedTableName, CatalogTable]()
+  }
+
   /**
    * Checks if the given name conforms the Hive standard ("[a-zA-Z_0-9]+"),
    * i.e. if this name only contains characters, numbers, and _.
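
The hunks that follow all apply one discipline around this cache: reads go through the TTL-bounded cache, and every write path invalidates the affected entry before delegating to the external catalog. A minimal, self-contained sketch of that discipline, with QTable and the string payload as placeholders for Spark's QualifiedTableName and CatalogTable:

  import java.util.concurrent.TimeUnit
  import com.google.common.cache.CacheBuilder

  object InvalidateOnWriteSketch {
    final case class QTable(db: String, table: String)

    // Reads hit this cache; the TTL bounds staleness for writes made elsewhere.
    private val cache = CacheBuilder.newBuilder()
      .expireAfterWrite(2, TimeUnit.SECONDS)
      .build[QTable, String]()

    // Every local write path drops the cached entry before mutating the source of truth.
    def alterTable(key: QTable): Unit = {
      cache.invalidate(key)
      // ... the call to the external catalog would go here ...
    }

    def main(args: Array[String]): Unit = {
      val k = QTable("db2", "tbl1")
      cache.put(k, "metadata-v1")
      assert(cache.getIfPresent(k) == "metadata-v1")
      alterTable(k)
      assert(cache.getIfPresent(k) == null) // stale entry gone
    }
  }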
@@ -222,6 +228,7 @@
     if (cascade && databaseExists(dbName)) {
       listTables(dbName).foreach { t =>
         invalidateCachedTable(QualifiedTableName(dbName, t.table))

Review comment: The two methods invalidateCachedTable and invalidateCachedCatalogTable have confusingly similar names. I suggest renaming them so the distinction is more intuitive.

+        invalidateCachedCatalogTable(QualifiedTableName(dbName, t.table))
       }
     }
     externalCatalog.dropDatabase(dbName, ignoreIfNotExists, cascade)
@@ -366,6 +373,7 @@
       tableDefinition.copy(identifier = tableIdentifier)
     }

+    invalidateCachedCatalogTable(QualifiedTableName(db, table))
     externalCatalog.alterTable(newTableDefinition)
   }

@@ -386,7 +394,7 @@
     requireDbExists(db)
     requireTableExists(tableIdentifier)

-    val catalogTable = externalCatalog.getTable(db, table)
+    val catalogTable = getTableMetadata(tableIdentifier)
     val oldDataSchema = catalogTable.dataSchema
     // not supporting dropping columns yet
     val nonExistentColumnNames =
@@ -399,6 +407,7 @@
       """.stripMargin)
     }

+    invalidateCachedCatalogTable(QualifiedTableName(db, table))
     externalCatalog.alterTableDataSchema(db, table, newDataSchema)
   }

@@ -416,6 +425,7 @@
     val tableIdentifier = TableIdentifier(table, Some(db))
     requireDbExists(db)
     requireTableExists(tableIdentifier)
+    invalidateCachedCatalogTable(QualifiedTableName(db, table))
     externalCatalog.alterTableStats(db, table, newStats)
     // Invalidate the table relation cache
     refreshTable(identifier)
@@ -428,7 +438,12 @@
   def tableExists(name: TableIdentifier): Boolean = synchronized {
     val db = formatDatabaseName(name.database.getOrElse(currentDb))
     val table = formatTableName(name.table)
-    externalCatalog.tableExists(db, table)
+    val exists = externalCatalog.tableExists(db, table)
+    if (!exists) {
+      // Best effort: the table is gone, so drop any stale cached metadata for it.
+      invalidateCachedCatalogTable(QualifiedTableName(db, table))
+    }
+    exists
   }

   /**
@@ -440,9 +455,12 @@
   def getTableMetadata(name: TableIdentifier): CatalogTable = {
     val db = formatDatabaseName(name.database.getOrElse(getCurrentDatabase))
     val table = formatTableName(name.table)
-    requireDbExists(db)
-    requireTableExists(TableIdentifier(table, Some(db)))
-    externalCatalog.getTable(db, table)
+    val qtn = QualifiedTableName(db, table)
+    getOrCacheCatalogTable(qtn, () => {
+      requireDbExists(db)
+      requireTableExists(TableIdentifier(table, Some(db)))
+      externalCatalog.getTable(db, table)
+    })
   }

   /**
@@ -669,6 +687,7 @@
       requireTableNotExists(TableIdentifier(newTableName, Some(db)))
       validateName(newTableName)
       validateNewLocationOfRename(oldName, newName)
+      invalidateCachedCatalogTable(QualifiedTableName(db, oldTableName))
       externalCatalog.renameTable(db, oldTableName, newTableName)
     } else {
       if (newName.database.isDefined) {
@@ -711,6 +730,7 @@
       // When ignoreIfNotExists is false, no exception is issued when the table does not exist.
       // Instead, log it as an error message.
       if (tableExists(TableIdentifier(table, Option(db)))) {
+        invalidateCachedCatalogTable(QualifiedTableName(db, table))
         externalCatalog.dropTable(db, table, ignoreIfNotExists = true, purge = purge)
       } else if (!ignoreIfNotExists) {
         throw new NoSuchTableException(db = db, table = table)
@@ -872,6 +892,8 @@
     // Also invalidate the table relation cache.
     val qualifiedTableName = QualifiedTableName(dbName, tableName)
     tableRelationCache.invalidate(qualifiedTableName)
+
+    invalidateCachedCatalogTable(qualifiedTableName)
   }

   /**
@@ -908,6 +930,7 @@
     requireTableExists(TableIdentifier(table, Option(db)))
     requireExactMatchedPartitionSpec(parts.map(_.spec), getTableMetadata(tableName))
     requireNonEmptyValueInPartitionSpec(parts.map(_.spec))
+    invalidateCachedCatalogTable(QualifiedTableName(db, table))
     externalCatalog.createPartitions(
       db, table, partitionWithQualifiedPath(tableName, parts), ignoreIfExists)
   }
@@ -928,6 +951,7 @@
     requireTableExists(TableIdentifier(table, Option(db)))
     requirePartialMatchedPartitionSpec(specs, getTableMetadata(tableName))
     requireNonEmptyValueInPartitionSpec(specs)
+    invalidateCachedCatalogTable(QualifiedTableName(db, table))
     externalCatalog.dropPartitions(db, table, specs, ignoreIfNotExists, purge, retainData)
   }

@@ -950,6 +974,7 @@
     requireExactMatchedPartitionSpec(newSpecs, tableMetadata)
     requireNonEmptyValueInPartitionSpec(specs)
     requireNonEmptyValueInPartitionSpec(newSpecs)
+    invalidateCachedCatalogTable(QualifiedTableName(db, table))
     externalCatalog.renamePartitions(db, table, specs, newSpecs)
   }

@@ -969,6 +994,7 @@
     requireTableExists(TableIdentifier(table, Option(db)))
     requireExactMatchedPartitionSpec(parts.map(_.spec), getTableMetadata(tableName))
     requireNonEmptyValueInPartitionSpec(parts.map(_.spec))
+    invalidateCachedCatalogTable(QualifiedTableName(db, table))
     externalCatalog.alterPartitions(db, table, partitionWithQualifiedPath(tableName, parts))
   }

@@ -1484,6 +1510,41 @@
       require(functionBuilder.isDefined, s"built-in function '$f' is missing function builder")
       functionRegistry.registerFunction(f, expressionInfo.get, functionBuilder.get)
     }
+    invalidateAllCachedCatalogTables()
   }

+  // Returns the cached metadata for the given table, if present.
+  private[sql] def getCachedCatalogTable(qtn: QualifiedTableName): Option[CatalogTable] = {

Review comment: Add comments to this and the other new methods as well.

+    catalogTableCache.getIfPresent(qtn) match {
+      case null => None
+      case catalogTable => Some(catalogTable)
+    }
+  }

+  // Returns the cached metadata for the given table, computing and caching it
+  // via `init` on a miss.
+  private[sql] def getOrCacheCatalogTable(
+      qtn: QualifiedTableName,
+      init: Callable[CatalogTable]): CatalogTable = {
+    try {
+      catalogTableCache.get(qtn, init)
+    } catch {
+      case e: ExecutionException =>
+        // Guava wraps loader failures in an ExecutionException; rethrow the original cause.
+        throw e.getCause
+    }
+  }

+  // Puts the given metadata straight into the cache, bypassing the loader.
+  private[sql] def cacheCatalogTable(qtn: QualifiedTableName, catalogTable: CatalogTable): Unit = {

Review comment (Contributor): Caches should take a Callable so that populating the cache can be combined with a get operation (get or initialize). Instead of cacheCatalogTable, this should be getOrCacheCatalogTable(qtn: QualifiedTableName, init: Callable[CatalogTable]), which calls catalogTableCache.get(qtn, init).
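
A minimal standalone sketch of the get-or-initialize pattern the reviewer describes, using the Guava Cache API directly (QTable and the string values are placeholders, not Spark types):

  import java.util.concurrent.{Callable, ExecutionException, TimeUnit}
  import com.google.common.cache.CacheBuilder

  object GetOrInitializeSketch {
    final case class QTable(db: String, table: String)

    private val cache = CacheBuilder.newBuilder()
      .expireAfterWrite(2, TimeUnit.SECONDS)
      .build[QTable, String]()

    // On a miss, `init` runs exactly once and its result is stored; concurrent
    // callers for the same key block on one loader instead of racing a separate put().
    def getOrInit(key: QTable, init: Callable[String]): String =
      try cache.get(key, init)
      catch { case e: ExecutionException => throw e.getCause }

    def main(args: Array[String]): Unit = {
      val k = QTable("db2", "tbl1")
      assert(getOrInit(k, () => "loaded") == "loaded")   // miss: loader runs
      assert(getOrInit(k, () => "ignored") == "loaded")  // hit: loader skipped
    }
  }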

Author reply: OK, will fix these.

Review comment: Here getOrCacheCatalogTable is just a thin wrapper that adds nothing by itself. I suggest moving the lambda its caller passes into this method; the loading logic is really part of "get or cache".
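
A sketch of that suggestion with standalone placeholder types: the loading lambda moves inside the method, so callers ask for metadata by key only (loadFromExternalCatalog stands in for externalCatalog.getTable):

  import java.util.concurrent.{ExecutionException, TimeUnit}
  import com.google.common.cache.CacheBuilder

  object FoldedLoaderSketch {
    final case class QTable(db: String, table: String)
    final case class Metadata(schemaDdl: String)

    private val cache = CacheBuilder.newBuilder()
      .expireAfterWrite(2, TimeUnit.SECONDS)
      .build[QTable, Metadata]()

    // Hypothetical stand-in for the external catalog lookup done on a miss.
    private def loadFromExternalCatalog(key: QTable): Metadata = Metadata("id INT")

    // "Get or cache" as one call: the miss path is no longer the caller's concern.
    def getOrCacheMetadata(key: QTable): Metadata =
      try cache.get(key, () => loadFromExternalCatalog(key))
      catch { case e: ExecutionException => throw e.getCause }

    def main(args: Array[String]): Unit =
      assert(getOrCacheMetadata(QTable("db2", "tbl1")) == Metadata("id INT"))
  }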

+    catalogTableCache.put(qtn, catalogTable)
+  }
+
+  // Drops every cached entry. Guava's cleanUp() only runs pending maintenance and
+  // would not evict live entries, hence invalidateAll().
+  private[sql] def invalidateAllCachedCatalogTables(): Unit = {
+    catalogTableCache.invalidateAll()
+  }
+
+  // Drops the cached entry for a single table.
+  private[sql] def invalidateCachedCatalogTable(qtn: QualifiedTableName): Unit = {
+    catalogTableCache.invalidate(qtn)
+  }

/**
@@ -2253,6 +2253,9 @@ class SQLConf extends Serializable with Logging {
   def tableRelationCacheSize: Int =
     getConf(StaticSQLConf.FILESOURCE_TABLE_RELATION_CACHE_SIZE)

+  def tableCatalogCacheExpireSeconds: Long =
+    getConf(StaticSQLConf.FILESOURCE_TABLE_CATALOG_CACHE_EXPIRE_SECONDS)
+
   def codegenCacheMaxEntries: Int = getConf(StaticSQLConf.CODEGEN_CACHE_MAX_ENTRIES)

   def exchangeReuseEnabled: Boolean = getConf(EXCHANGE_REUSE_ENABLED)
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.internal

 import java.util.Locale
+import java.util.concurrent.TimeUnit

 import org.apache.spark.util.Utils

@@ -69,6 +70,15 @@ object StaticSQLConf {
     .checkValue(cacheSize => cacheSize >= 0, "The maximum size of the cache must not be negative")
     .createWithDefault(1000)

+  val FILESOURCE_TABLE_CATALOG_CACHE_EXPIRE_SECONDS =
+    buildStaticConf("spark.sql.filesourceTableCatalogCacheExpireSeconds")
+      .internal()
+      .doc("Expiration time in seconds for entries in the catalog table cache.")
+      .timeConf(TimeUnit.SECONDS)
+      .checkValue(time => time >= 0 && time < 8,
+        "The cache expiration time must be in [0, 8) seconds.")
+      .createWithDefault(2)
+
   val CODEGEN_CACHE_MAX_ENTRIES = buildStaticConf("spark.sql.codegen.cache.maxEntries")
     .internal()
     .doc("When nonzero, enable caching of generated classes for operators and expressions. " +
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.catalyst.catalog

 import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.{AliasIdentifier, FunctionIdentifier, TableIdentifier}
+import org.apache.spark.sql.catalyst.{AliasIdentifier, FunctionIdentifier, QualifiedTableName, TableIdentifier}
 import org.apache.spark.sql.catalyst.analysis._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
@@ -1617,4 +1617,38 @@ abstract class SessionCatalogSuite extends AnalysisTest {
       assert(cause.cause.get.getMessage.contains("Actual error"))
     }
   }
+
+  test("catalog table cache") {
+    withBasicCatalog { catalog =>
+      val tableId = TableIdentifier("tbl1", Some("db2"))
+      val qtn = QualifiedTableName("db2", "tbl1")
+
+      // getTableMetadata populates the cache.
+      val ct1 = catalog.getTableMetadata(tableId)
+      assert(catalog.getCachedCatalogTable(qtn).contains(ct1))
+
+      // Entries expire once the configured TTL has passed.
+      Thread.sleep(conf.tableCatalogCacheExpireSeconds * 1000 + 1)
+      assert(catalog.getCachedCatalogTable(qtn).isEmpty)
+
+      // refreshTable invalidates the cached entry.
+      catalog.getTableMetadata(tableId)
+      catalog.refreshTable(tableId)
+      assert(catalog.getCachedCatalogTable(qtn).isEmpty)
+
+      // dropTable invalidates the cached entry.
+      catalog.getTableMetadata(tableId)
+      catalog.dropTable(tableId, ignoreIfNotExists = false, purge = false)
+      assert(catalog.getCachedCatalogTable(qtn).isEmpty)
+    }
+
+    // Dropping a database with cascade invalidates cached entries for its tables.
+    withBasicCatalog { catalog =>
+      val tableId = TableIdentifier("tbl1", Some("db2"))
+      val qtn = QualifiedTableName("db2", "tbl1")
+      catalog.getTableMetadata(tableId)
+      catalog.dropDatabase(qtn.database, ignoreIfNotExists = false, cascade = true)
+      assert(catalog.getCachedCatalogTable(qtn).isEmpty)
+    }
+  }
 }
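
The sleep-based expiry check above is timing-sensitive. If the cache were built with an injectable Guava Ticker (the SessionCatalog change above does not currently expose one), expiry could be exercised deterministically. A minimal sketch under that assumption:

  import java.util.concurrent.TimeUnit
  import java.util.concurrent.atomic.AtomicLong
  import com.google.common.base.Ticker
  import com.google.common.cache.CacheBuilder

  object FakeTickerSketch {
    def main(args: Array[String]): Unit = {
      val nowNanos = new AtomicLong(0L)
      val fakeTicker = new Ticker { override def read(): Long = nowNanos.get() }
      val cache = CacheBuilder.newBuilder()
        .ticker(fakeTicker)
        .expireAfterWrite(2, TimeUnit.SECONDS)
        .build[String, String]()

      cache.put("db2.tbl1", "metadata")
      assert(cache.getIfPresent("db2.tbl1") == "metadata")

      // Advance the fake clock past the TTL; no real sleeping involved.
      nowNanos.addAndGet(TimeUnit.SECONDS.toNanos(3))
      assert(cache.getIfPresent("db2.tbl1") == null)
    }
  }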