Extends Analyze commands for cached tables
maropu committed Jan 16, 2019
1 parent 5018f27 commit 1f78144
Showing 5 changed files with 117 additions and 28 deletions.
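At a glance: before this change, ANALYZE TABLE ... COMPUTE STATISTICS FOR COLUMNS wrote statistics only to the session catalog, so it could not handle cached queries that have no catalog entry. This commit lets CacheManager compute column statistics for an already-cached query and attach them to the cached InMemoryRelation itself. A minimal end-to-end sketch of the new behavior, assuming a running SparkSession named `spark` (it mirrors the tests added below):

    // Cache a derived temporary view, then analyze a column of the cached plan.
    spark.sql("CACHE TABLE cachedTempView AS SELECT id % 3 AS c0 FROM range(30)")
    spark.sql("ANALYZE TABLE cachedTempView COMPUTE STATISTICS FOR COLUMNS c0")

    // The column stats now live on the in-memory relation and are visible
    // through the optimized plan of any query over the cached view.
    val stats = spark.table("cachedTempView").queryExecution.optimizedPlan.stats
    assert(stats.attributeStats.nonEmpty)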
sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala

@@ -25,9 +25,10 @@ import org.apache.hadoop.fs.{FileSystem, Path}

 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.{Dataset, SparkSession}
-import org.apache.spark.sql.catalyst.expressions.SubqueryExpression
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, SubqueryExpression}
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ResolvedHint}
 import org.apache.spark.sql.execution.columnar.InMemoryRelation
+import org.apache.spark.sql.execution.command.CommandUtils
 import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation}
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.storage.StorageLevel.MEMORY_AND_DISK
@@ -157,6 +158,38 @@ class CacheManager extends Logging {
     }
   }
 
+  private[sql] def analyzeColumnCacheQuery(
+      query: Dataset[_],
+      columnNames: Seq[Attribute]): Unit = writeLock {
+    val cachedData = lookupCachedData(query)
+    if (cachedData.isEmpty) {
+      logWarning("Cached data not found, so you need to cache the query first.")
+    } else {
+      cachedData.foreach { cachedData =>
+        val relation = cachedData.cachedRepresentation
+        val (rowCount, newColStats) =
+          CommandUtils.computeColumnStats(query.sparkSession, relation, columnNames)
+        val oldStats = cachedData.cachedRepresentation.statsOfPlanToCache
+        val newStats = oldStats.copy(
+          rowCount = Some(rowCount),
+          attributeStats = AttributeMap((oldStats.attributeStats ++ newColStats).toSeq)
+        )
+        cachedData.cachedRepresentation.statsOfPlanToCache = newStats
+      }
+    }
+  }
+
+  /**
+   * Analyzes column statistics in an already-cached table.
+   *
+   * @param spark The Spark session.
+   * @param tableName The identifier of a cached table.
+   * @param columnNames The names of columns to be analyzed for computing statistics.
+   */
+  def analyzeColumn(spark: SparkSession, tableName: String, columnNames: Seq[Attribute]): Unit = {
+    analyzeColumnCacheQuery(spark.table(tableName), columnNames)
+  }
+
   /**
    * Tries to re-cache all the cache entries that refer to the given plan.
    */
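The new CacheManager entry point can also be driven programmatically, as the CachedTableSuite test below does. A minimal sketch, assuming code living in the org.apache.spark.sql package (analyzeColumnCacheQuery is private[sql]) and an active session `spark`:

    val df = spark.range(100).selectExpr("id % 3 AS c0", "id % 5 AS c1")
    df.cache() // registers an InMemoryRelation entry for this plan

    val cacheManager = spark.sharedState.cacheManager
    // Look up the cache entry and analyze the first output attribute;
    // the resulting stats are written into statsOfPlanToCache.
    cacheManager.lookupCachedData(df).foreach { cd =>
      cacheManager.analyzeColumnCacheQuery(df, cd.plan.output.take(1))
    }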
sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala

@@ -164,7 +164,7 @@ case class InMemoryRelation(
     output: Seq[Attribute],
     @transient cacheBuilder: CachedRDDBuilder,
     override val outputOrdering: Seq[SortOrder])(
-    statsOfPlanToCache: Statistics)
+    var statsOfPlanToCache: Statistics)
   extends logical.LeafNode with MultiInstanceRelation {
 
   override protected def innerChildren: Seq[SparkPlan] = Seq(cachedPlan)
@@ -186,7 +186,8 @@
       // Underlying columnar RDD hasn't been materialized, use the stats from the plan to cache.
       statsOfPlanToCache
     } else {
-      Statistics(sizeInBytes = cacheBuilder.sizeInBytesStats.value.longValue)
+      statsOfPlanToCache.copy(
+        sizeInBytes = cacheBuilder.sizeInBytesStats.value.longValue)
     }
   }

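The computeStats change matters because, previously, materializing the columnar RDD rebuilt Statistics from scratch with only sizeInBytes, discarding any row count or column statistics that ANALYZE had attached. Using copy preserves them. A trimmed-down illustration (Stats here is a hypothetical stand-in for Catalyst's Statistics, which is likewise a case class):

    case class Stats(sizeInBytes: BigInt, rowCount: Option[BigInt] = None)

    val analyzed = Stats(sizeInBytes = BigInt(1) << 20, rowCount = Some(100))
    // Old behavior: Stats(newSize) would silently drop rowCount.
    // New behavior: copy replaces sizeInBytes and keeps everything else.
    val updated = analyzed.copy(sizeInBytes = 4096)
    assert(updated.rowCount.contains(BigInt(100)))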
sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala

@@ -40,34 +40,29 @@ case class AnalyzeColumnCommand(
     require(columnNames.isDefined ^ allColumns, "Parameter `columnNames` or `allColumns` are " +
       "mutually exclusive. Only one of them should be specified.")
     val sessionState = sparkSession.sessionState
-    val db = tableIdent.database.getOrElse(sessionState.catalog.getCurrentDatabase)
-    val tableIdentWithDB = TableIdentifier(tableIdent.table, Some(db))
-    val tableMeta = sessionState.catalog.getTableMetadata(tableIdentWithDB)
-    if (tableMeta.tableType == CatalogTableType.VIEW) {
-      throw new AnalysisException("ANALYZE TABLE is not supported on views.")
-    }
-    val sizeInBytes = CommandUtils.calculateTotalSize(sparkSession, tableMeta)
-    val relation = sparkSession.table(tableIdent).logicalPlan
-    val columnsToAnalyze = getColumnsToAnalyze(tableIdent, relation, columnNames, allColumns)
-
-    // Compute stats for the computed list of columns.
-    val (rowCount, newColStats) =
-      CommandUtils.computeColumnStats(sparkSession, relation, columnsToAnalyze)
-
-    val newColCatalogStats = newColStats.map {
-      case (attr, columnStat) =>
-        attr.name -> columnStat.toCatalogColumnStat(attr.name, attr.dataType)
-    }
-
-    // We also update table-level stats in order to keep them consistent with column-level stats.
-    val statistics = CatalogStatistics(
-      sizeInBytes = sizeInBytes,
-      rowCount = Some(rowCount),
-      // Newly computed column stats should override the existing ones.
-      colStats = tableMeta.stats.map(_.colStats).getOrElse(Map.empty) ++ newColCatalogStats)
-
-    sessionState.catalog.alterTableStats(tableIdentWithDB, Some(statistics))
+
+    tableIdent.database match {
+      case None =>
+        sessionState.catalog.getTempView(tableIdent.identifier) match {
+          case Some(tempView) =>
+            val cacheManager = sparkSession.sharedState.cacheManager
+            cacheManager.lookupCachedData(tempView) match {
+              case Some(cachedData) =>
+                val columnsToAnalyze = getColumnsToAnalyze(
+                  tableIdent, cachedData.plan, columnNames, allColumns)
+                cacheManager.analyzeColumn(sparkSession, tableIdent.identifier, columnsToAnalyze)
+              case None =>
+                throw new NoSuchTableException(
+                  db = sessionState.catalog.getCurrentDatabase, table = tableIdent.identifier)
+            }
+          case _ =>
+            analyzeColumnInCatalog(sparkSession)
+        }
+
+      case _ =>
+        analyzeColumnInCatalog(sparkSession)
+    }
 
     Seq.empty[Row]
   }
 
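With this dispatch, an unqualified table name is first resolved as a temporary view: if the view is cached, its column statistics are computed in place on the cache; if it is a temporary view that is not cached, there is no catalog entry to update, so the command fails with NoSuchTableException. A short sketch of the user-visible behavior (it mirrors the StatisticsCollectionSuite test below):

    // Cached temp view: ANALYZE now targets the cache itself.
    spark.sql("CACHE TABLE cachedTempView AS SELECT id FROM range(10)")
    spark.sql("ANALYZE TABLE cachedTempView COMPUTE STATISTICS FOR COLUMNS id")

    // Uncached temp view: throws NoSuchTableException
    // ("Table or view 'tempView' not found in database 'default'").
    spark.sql("CREATE TEMPORARY VIEW tempView AS SELECT id FROM range(10)")
    spark.sql("ANALYZE TABLE tempView COMPUTE STATISTICS FOR COLUMNS id")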
sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala

@@ -946,4 +946,33 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
     // Clean-up
     df.unpersist()
   }
+
+  test("SPARK-25196 analyzes column statistics in cached query") {
+    def query(): DataFrame = {
+      spark.range(100)
+        .selectExpr("id % 3 AS c0", "id % 5 AS c1", "2 AS c2")
+        .groupBy("c0")
+        .agg(avg("c1").as("v1"), sum("c2").as("v2"))
+    }
+    // First, check that the cached query has no column statistics yet.
+    val queryStats1 = query().cache.queryExecution.optimizedPlan.stats.attributeStats
+    assert(queryStats1.map(_._1.name).isEmpty)
+
+    val cacheManager = spark.sharedState.cacheManager
+    val cachedData = cacheManager.lookupCachedData(query().logicalPlan)
+    assert(cachedData.isDefined)
+    val queryAttrs = cachedData.get.plan.output
+    assert(queryAttrs.size === 3)
+    val (c0, v1, v2) = (queryAttrs(0), queryAttrs(1), queryAttrs(2))
+
+    // Analyze one column in the query output.
+    cacheManager.analyzeColumnCacheQuery(query(), v1 :: Nil)
+    val queryStats2 = query().queryExecution.optimizedPlan.stats.attributeStats
+    assert(queryStats2.map(_._1.name).toSet === Set("v1"))
+
+    // Analyze two more columns.
+    cacheManager.analyzeColumnCacheQuery(query(), c0 :: v2 :: Nil)
+    val queryStats3 = query().queryExecution.optimizedPlan.stats.attributeStats
+    assert(queryStats3.map(_._1.name).toSet === Set("c0", "v1", "v2"))
+  }
 }
sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala

@@ -22,6 +22,7 @@ import java.io.File
 import scala.collection.mutable
 
 import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.catalyst.analysis.NoSuchTableException
 import org.apache.spark.sql.catalyst.catalog.CatalogColumnStat
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.internal.SQLConf
@@ -427,4 +428,34 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with SharedSQLContext
       }
     }
   }
+
+  test("SPARK-25196 analyzes column statistics in cached query") {
+    withTempView("cachedTempView", "tempView") {
+      spark.sql(
+        """CACHE TABLE cachedTempView AS
+          | SELECT c0, avg(c1) AS v1, avg(c2) AS v2
+          | FROM (SELECT id % 3 AS c0, id % 5 AS c1, 2 AS c2 FROM range(1, 30))
+          | GROUP BY c0
+        """.stripMargin)
+
+      // Analyze one column in the cached logical plan.
+      spark.sql("ANALYZE TABLE cachedTempView COMPUTE STATISTICS FOR COLUMNS v1")
+      val queryStats1 = spark.table("cachedTempView").queryExecution
+        .optimizedPlan.stats.attributeStats
+      assert(queryStats1.map(_._1.name).toSet === Set("v1"))
+
+      // Analyze two more columns.
+      spark.sql("ANALYZE TABLE cachedTempView COMPUTE STATISTICS FOR COLUMNS c0, v2")
+      val queryStats2 = spark.table("cachedTempView").queryExecution
+        .optimizedPlan.stats.attributeStats
+      assert(queryStats2.map(_._1.name).toSet === Set("c0", "v1", "v2"))
+
+      // Analyzing an uncached temporary view fails with NoSuchTableException.
+      spark.sql("CREATE TEMPORARY VIEW tempView AS SELECT * FROM range(1, 30)")
+      val errMsg = intercept[NoSuchTableException] {
+        spark.sql("ANALYZE TABLE tempView COMPUTE STATISTICS FOR COLUMNS id")
+      }.getMessage
+      assert(errMsg.contains("Table or view 'tempView' not found in database 'default'"))
+    }
+  }
 }
