[SPARK-13893][SQL] Remove SQLContext.catalog/analyzer (internal method)
## What changes were proposed in this pull request?
Our internal code can now go through SessionState.catalog and SessionState.analyzer instead. This brings two small benefits:
1. Reduces internal dependency on SQLContext.
2. Removes two methods that were public from Java's point of view (Java does not obey Scala's package-private visibility).

More importantly, according to the design in SPARK-13485, we'd need to reclaim the name `catalog` for user-facing public functions, rather than having it refer to an internal field.
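
For illustration, a minimal sketch of the call-site migration (the `SessionStateAccessExample` object and its two helpers are hypothetical; the `sessionState.analyzer`/`sessionState.catalog` access paths are the ones used throughout the diff below):

```scala
package org.apache.spark.sql

// Sketch only: `sessionState` is protected[sql], so code like this
// compiles only inside the org.apache.spark.sql package.
private[sql] object SessionStateAccessExample {

  // Before this commit: df.sqlContext.analyzer.resolver
  def resolveColumn(df: DataFrame, colName: String): Option[String] = {
    val resolver = df.sqlContext.sessionState.analyzer.resolver
    df.schema.fieldNames.find(resolver(_, colName))
  }

  // Before this commit: sqlContext.catalog.tableExists(...)
  def tableExists(sqlContext: SQLContext, name: String): Boolean = {
    sqlContext.sessionState.catalog.tableExists(
      sqlContext.sessionState.sqlParser.parseTableIdentifier(name))
  }
}
```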

## How was this patch tested?
Existing unit/integration test code.

Author: Reynold Xin <rxin@databricks.com>

Closes #11716 from rxin/SPARK-13893.
rxin committed Mar 15, 2016
1 parent 48978ab commit 5e6f2f4
Showing 27 changed files with 105 additions and 99 deletions.
@@ -155,7 +155,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
    * @since 1.3.1
    */
   def fill(value: Double, cols: Seq[String]): DataFrame = {
-    val columnEquals = df.sqlContext.analyzer.resolver
+    val columnEquals = df.sqlContext.sessionState.analyzer.resolver
     val projections = df.schema.fields.map { f =>
       // Only fill if the column is part of the cols list.
       if (f.dataType.isInstanceOf[NumericType] && cols.exists(col => columnEquals(f.name, col))) {
@@ -182,7 +182,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
    * @since 1.3.1
    */
   def fill(value: String, cols: Seq[String]): DataFrame = {
-    val columnEquals = df.sqlContext.analyzer.resolver
+    val columnEquals = df.sqlContext.sessionState.analyzer.resolver
     val projections = df.schema.fields.map { f =>
       // Only fill if the column is part of the cols list.
       if (f.dataType.isInstanceOf[StringType] && cols.exists(col => columnEquals(f.name, col))) {
@@ -353,7 +353,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
       case _: String => StringType
     }
 
-    val columnEquals = df.sqlContext.analyzer.resolver
+    val columnEquals = df.sqlContext.sessionState.analyzer.resolver
     val projections = df.schema.fields.map { f =>
       val shouldReplace = cols.exists(colName => columnEquals(colName, f.name))
       if (f.dataType.isInstanceOf[NumericType] && targetColumnType == DoubleType && shouldReplace) {
@@ -382,7 +382,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
       }
     }
 
-    val columnEquals = df.sqlContext.analyzer.resolver
+    val columnEquals = df.sqlContext.sessionState.analyzer.resolver
     val projections = df.schema.fields.map { f =>
       values.find { case (k, _) => columnEquals(k, f.name) }.map { case (_, v) =>
         v match {
@@ -394,7 +394,7 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging {
    */
   def table(tableName: String): DataFrame = {
     Dataset.newDataFrame(sqlContext,
-      sqlContext.catalog.lookupRelation(
+      sqlContext.sessionState.catalog.lookupRelation(
         sqlContext.sessionState.sqlParser.parseTableIdentifier(tableName)))
   }
 
@@ -323,7 +323,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
    */
   private def normalize(columnName: String, columnType: String): String = {
     val validColumnNames = df.logicalPlan.output.map(_.name)
-    validColumnNames.find(df.sqlContext.analyzer.resolver(_, columnName))
+    validColumnNames.find(df.sqlContext.sessionState.analyzer.resolver(_, columnName))
       .getOrElse(throw new AnalysisException(s"$columnType column $columnName not found in " +
         s"existing columns (${validColumnNames.mkString(", ")})"))
   }
@@ -358,7 +358,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
   }
 
   private def saveAsTable(tableIdent: TableIdentifier): Unit = {
-    val tableExists = df.sqlContext.catalog.tableExists(tableIdent)
+    val tableExists = df.sqlContext.sessionState.catalog.tableExists(tableIdent)
 
     (tableExists, mode) match {
       case (true, SaveMode.Ignore) =>
22 changes: 12 additions & 10 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -166,15 +166,16 @@ class Dataset[T] private[sql](
   private implicit def classTag = unresolvedTEncoder.clsTag
 
   protected[sql] def resolve(colName: String): NamedExpression = {
-    queryExecution.analyzed.resolveQuoted(colName, sqlContext.analyzer.resolver).getOrElse {
-      throw new AnalysisException(
-        s"""Cannot resolve column name "$colName" among (${schema.fieldNames.mkString(", ")})""")
-    }
+    queryExecution.analyzed.resolveQuoted(colName, sqlContext.sessionState.analyzer.resolver)
+      .getOrElse {
+        throw new AnalysisException(
+          s"""Cannot resolve column name "$colName" among (${schema.fieldNames.mkString(", ")})""")
+      }
   }
 
   protected[sql] def numericColumns: Seq[Expression] = {
     schema.fields.filter(_.dataType.isInstanceOf[NumericType]).map { n =>
-      queryExecution.analyzed.resolveQuoted(n.name, sqlContext.analyzer.resolver).get
+      queryExecution.analyzed.resolveQuoted(n.name, sqlContext.sessionState.analyzer.resolver).get
     }
   }
 
@@ -1400,7 +1401,7 @@ class Dataset[T] private[sql](
    * @since 1.3.0
    */
   def withColumn(colName: String, col: Column): DataFrame = {
-    val resolver = sqlContext.analyzer.resolver
+    val resolver = sqlContext.sessionState.analyzer.resolver
     val output = queryExecution.analyzed.output
     val shouldReplace = output.exists(f => resolver(f.name, colName))
     if (shouldReplace) {
@@ -1421,7 +1422,7 @@ class Dataset[T] private[sql](
    * Returns a new [[DataFrame]] by adding a column with metadata.
    */
   private[spark] def withColumn(colName: String, col: Column, metadata: Metadata): DataFrame = {
-    val resolver = sqlContext.analyzer.resolver
+    val resolver = sqlContext.sessionState.analyzer.resolver
     val output = queryExecution.analyzed.output
     val shouldReplace = output.exists(f => resolver(f.name, colName))
     if (shouldReplace) {
@@ -1445,7 +1446,7 @@ class Dataset[T] private[sql](
    * @since 1.3.0
    */
   def withColumnRenamed(existingName: String, newName: String): DataFrame = {
-    val resolver = sqlContext.analyzer.resolver
+    val resolver = sqlContext.sessionState.analyzer.resolver
     val output = queryExecution.analyzed.output
     val shouldRename = output.exists(f => resolver(f.name, existingName))
     if (shouldRename) {
@@ -1480,7 +1481,7 @@ class Dataset[T] private[sql](
    */
   @scala.annotation.varargs
   def drop(colNames: String*): DataFrame = {
-    val resolver = sqlContext.analyzer.resolver
+    val resolver = sqlContext.sessionState.analyzer.resolver
     val remainingCols =
       schema.filter(f => colNames.forall(n => !resolver(f.name, n))).map(f => Column(f.name))
     if (remainingCols.size == this.schema.size) {
@@ -1501,7 +1502,8 @@ class Dataset[T] private[sql](
   def drop(col: Column): DataFrame = {
     val expression = col match {
       case Column(u: UnresolvedAttribute) =>
-        queryExecution.analyzed.resolveQuoted(u.name, sqlContext.analyzer.resolver).getOrElse(u)
+        queryExecution.analyzed.resolveQuoted(
+          u.name, sqlContext.sessionState.analyzer.resolver).getOrElse(u)
       case Column(expr: Expression) => expr
     }
     val attrs = this.logicalPlan.output
13 changes: 6 additions & 7 deletions sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -120,8 +120,6 @@ class SQLContext private[sql](
   @transient
   protected[sql] lazy val sessionState: SessionState = new SessionState(self)
   protected[sql] def conf: SQLConf = sessionState.conf
-  protected[sql] def catalog: Catalog = sessionState.catalog
-  protected[sql] def analyzer: Analyzer = sessionState.analyzer
 
   /**
    * An interface to register custom [[org.apache.spark.sql.util.QueryExecutionListener]]s
@@ -699,7 +697,8 @@ class SQLContext private[sql](
    * only during the lifetime of this instance of SQLContext.
    */
   private[sql] def registerDataFrameAsTable(df: DataFrame, tableName: String): Unit = {
-    catalog.registerTable(sessionState.sqlParser.parseTableIdentifier(tableName), df.logicalPlan)
+    sessionState.catalog.registerTable(
+      sessionState.sqlParser.parseTableIdentifier(tableName), df.logicalPlan)
   }
 
   /**
@@ -712,7 +711,7 @@ class SQLContext private[sql](
    */
   def dropTempTable(tableName: String): Unit = {
     cacheManager.tryUncacheQuery(table(tableName))
-    catalog.unregisterTable(TableIdentifier(tableName))
+    sessionState.catalog.unregisterTable(TableIdentifier(tableName))
   }
 
   /**
@@ -797,7 +796,7 @@ class SQLContext private[sql](
   }
 
   private def table(tableIdent: TableIdentifier): DataFrame = {
-    Dataset.newDataFrame(this, catalog.lookupRelation(tableIdent))
+    Dataset.newDataFrame(this, sessionState.catalog.lookupRelation(tableIdent))
   }
 
   /**
@@ -839,7 +838,7 @@ class SQLContext private[sql](
    * @since 1.3.0
    */
   def tableNames(): Array[String] = {
-    catalog.getTables(None).map {
+    sessionState.catalog.getTables(None).map {
       case (tableName, _) => tableName
     }.toArray
   }
@@ -851,7 +850,7 @@ class SQLContext private[sql](
    * @since 1.3.0
    */
   def tableNames(databaseName: String): Array[String] = {
-    catalog.getTables(Some(databaseName)).map {
+    sessionState.catalog.getTables(Some(databaseName)).map {
       case (tableName, _) => tableName
     }.toArray
   }
@@ -31,14 +31,14 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer}
  */
 class QueryExecution(val sqlContext: SQLContext, val logical: LogicalPlan) {
 
-  def assertAnalyzed(): Unit = try sqlContext.analyzer.checkAnalysis(analyzed) catch {
+  def assertAnalyzed(): Unit = try sqlContext.sessionState.analyzer.checkAnalysis(analyzed) catch {
     case e: AnalysisException =>
       val ae = new AnalysisException(e.message, e.line, e.startPosition, Some(analyzed))
       ae.setStackTrace(e.getStackTrace)
       throw ae
   }
 
-  lazy val analyzed: LogicalPlan = sqlContext.analyzer.execute(logical)
+  lazy val analyzed: LogicalPlan = sqlContext.sessionState.analyzer.execute(logical)
 
   lazy val withCachedData: LogicalPlan = {
     assertAnalyzed()
@@ -330,7 +330,7 @@ case class ShowTablesCommand(databaseName: Option[String]) extends RunnableCommand {
   override def run(sqlContext: SQLContext): Seq[Row] = {
     // Since we need to return a Seq of rows, we will call getTables directly
     // instead of calling tables in sqlContext.
-    val rows = sqlContext.catalog.getTables(databaseName).map {
+    val rows = sqlContext.sessionState.catalog.getTables(databaseName).map {
       case (tableName, isTemporary) => Row(tableName, isTemporary)
     }
 
@@ -417,7 +417,7 @@ case class DescribeFunction(
 case class SetDatabaseCommand(databaseName: String) extends RunnableCommand {
 
   override def run(sqlContext: SQLContext): Seq[Row] = {
-    sqlContext.catalog.setCurrentDatabase(databaseName)
+    sqlContext.sessionState.catalog.setCurrentDatabase(databaseName)
     Seq.empty[Row]
   }
 
@@ -67,7 +67,8 @@ private[sql] object FileSourceStrategy extends Strategy with Logging {
     val filterSet = ExpressionSet(filters)
 
     val partitionColumns =
-      AttributeSet(l.resolve(files.partitionSchema, files.sqlContext.analyzer.resolver))
+      AttributeSet(
+        l.resolve(files.partitionSchema, files.sqlContext.sessionState.analyzer.resolver))
     val partitionKeyFilters =
       ExpressionSet(filters.filter(_.references.subsetOf(partitionColumns)))
     logInfo(s"Pruning directories with: ${partitionKeyFilters.mkString(",")}")
@@ -99,7 +99,7 @@ case class CreateTempTableUsing(
       userSpecifiedSchema = userSpecifiedSchema,
       className = provider,
       options = options)
-    sqlContext.catalog.registerTable(
+    sqlContext.sessionState.catalog.registerTable(
       tableIdent,
       Dataset.newDataFrame(sqlContext, LogicalRelation(dataSource.resolveRelation())).logicalPlan)
 
@@ -124,7 +124,7 @@ case class CreateTempTableUsingAsSelect(
       bucketSpec = None,
       options = options)
     val result = dataSource.write(mode, df)
-    sqlContext.catalog.registerTable(
+    sqlContext.sessionState.catalog.registerTable(
       tableIdent,
       Dataset.newDataFrame(sqlContext, LogicalRelation(result)).logicalPlan)
 
@@ -137,11 +137,11 @@ case class RefreshTable(tableIdent: TableIdentifier)
 
   override def run(sqlContext: SQLContext): Seq[Row] = {
     // Refresh the given table's metadata first.
-    sqlContext.catalog.refreshTable(tableIdent)
+    sqlContext.sessionState.catalog.refreshTable(tableIdent)
 
     // If this table is cached as a InMemoryColumnarRelation, drop the original
     // cached version and make the new version cached lazily.
-    val logicalPlan = sqlContext.catalog.lookupRelation(tableIdent)
+    val logicalPlan = sqlContext.sessionState.catalog.lookupRelation(tableIdent)
     // Use lookupCachedData directly since RefreshTable also takes databaseName.
     val isCached = sqlContext.cacheManager.lookupCachedData(logicalPlan).nonEmpty
     if (isCached) {
@@ -33,7 +33,7 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter with SharedSQLContext {
   }
 
   after {
-    sqlContext.catalog.unregisterTable(TableIdentifier("ListTablesSuiteTable"))
+    sqlContext.sessionState.catalog.unregisterTable(TableIdentifier("ListTablesSuiteTable"))
   }
 
   test("get all tables") {
@@ -45,7 +45,7 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter with SharedSQLContext {
       sql("SHOW tables").filter("tableName = 'ListTablesSuiteTable'"),
       Row("ListTablesSuiteTable", true))
 
-    sqlContext.catalog.unregisterTable(TableIdentifier("ListTablesSuiteTable"))
+    sqlContext.sessionState.catalog.unregisterTable(TableIdentifier("ListTablesSuiteTable"))
     assert(sqlContext.tables().filter("tableName = 'ListTablesSuiteTable'").count() === 0)
   }
 
@@ -58,7 +58,7 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter with SharedSQLContext {
       sql("show TABLES in DB").filter("tableName = 'ListTablesSuiteTable'"),
       Row("ListTablesSuiteTable", true))
 
-    sqlContext.catalog.unregisterTable(TableIdentifier("ListTablesSuiteTable"))
+    sqlContext.sessionState.catalog.unregisterTable(TableIdentifier("ListTablesSuiteTable"))
     assert(sqlContext.tables().filter("tableName = 'ListTablesSuiteTable'").count() === 0)
   }
 
@@ -51,7 +51,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext {
       sql("INSERT INTO TABLE t SELECT * FROM tmp")
       checkAnswer(sqlContext.table("t"), (data ++ data).map(Row.fromTuple))
     }
-    sqlContext.catalog.unregisterTable(TableIdentifier("tmp"))
+    sqlContext.sessionState.catalog.unregisterTable(TableIdentifier("tmp"))
   }
 
   test("overwriting") {
@@ -61,7 +61,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext {
       sql("INSERT OVERWRITE TABLE t SELECT * FROM tmp")
       checkAnswer(sqlContext.table("t"), data.map(Row.fromTuple))
     }
-    sqlContext.catalog.unregisterTable(TableIdentifier("tmp"))
+    sqlContext.sessionState.catalog.unregisterTable(TableIdentifier("tmp"))
   }
 
   test("self-join") {
@@ -113,8 +113,6 @@ class HiveContext private[hive](
   @transient
   protected[sql] override lazy val sessionState = new HiveSessionState(self)
 
-  protected[sql] override def catalog = sessionState.catalog
-
   // The Hive UDF current_database() is foldable, will be evaluated by optimizer,
   // but the optimizer can't access the SessionState of metadataHive.
   sessionState.functionRegistry.registerFunction(
@@ -349,12 +347,12 @@ class HiveContext private[hive](
    */
   def refreshTable(tableName: String): Unit = {
     val tableIdent = sessionState.sqlParser.parseTableIdentifier(tableName)
-    catalog.refreshTable(tableIdent)
+    sessionState.catalog.refreshTable(tableIdent)
   }
 
   protected[hive] def invalidateTable(tableName: String): Unit = {
     val tableIdent = sessionState.sqlParser.parseTableIdentifier(tableName)
-    catalog.invalidateTable(tableIdent)
+    sessionState.catalog.invalidateTable(tableIdent)
   }
 
   /**
@@ -368,7 +366,7 @@ class HiveContext private[hive](
    */
   def analyze(tableName: String) {
     val tableIdent = sessionState.sqlParser.parseTableIdentifier(tableName)
-    val relation = EliminateSubqueryAliases(catalog.lookupRelation(tableIdent))
+    val relation = EliminateSubqueryAliases(sessionState.catalog.lookupRelation(tableIdent))
 
     relation match {
       case relation: MetastoreRelation =>
@@ -429,7 +427,7 @@ class HiveContext private[hive](
       // recorded in the Hive metastore.
       // This logic is based on org.apache.hadoop.hive.ql.exec.StatsTask.aggregateStats().
       if (newTotalSize > 0 && newTotalSize != oldTotalSize) {
-        catalog.client.alterTable(
+        sessionState.catalog.client.alterTable(
           relation.table.copy(
             properties = relation.table.properties +
               (StatsSetupConst.TOTAL_SIZE -> newTotalSize.toString)))
@@ -69,17 +69,17 @@ case class CreateTableAsSelect(
       withFormat
     }
 
-    hiveContext.catalog.client.createTable(withSchema, ignoreIfExists = false)
+    hiveContext.sessionState.catalog.client.createTable(withSchema, ignoreIfExists = false)
 
     // Get the Metastore Relation
-    hiveContext.catalog.lookupRelation(tableIdentifier, None) match {
+    hiveContext.sessionState.catalog.lookupRelation(tableIdentifier, None) match {
       case r: MetastoreRelation => r
     }
   }
   // TODO ideally, we should get the output data ready first and then
   // add the relation into catalog, just in case of failure occurs while data
   // processing.
-  if (hiveContext.catalog.tableExists(tableIdentifier)) {
+  if (hiveContext.sessionState.catalog.tableExists(tableIdentifier)) {
     if (allowExisting) {
       // table already exists, will do nothing, to keep consistent with Hive
     } else {
@@ -49,14 +49,14 @@ private[hive] case class CreateViewAsSelect(
   override def run(sqlContext: SQLContext): Seq[Row] = {
     val hiveContext = sqlContext.asInstanceOf[HiveContext]
 
-    hiveContext.catalog.tableExists(tableIdentifier) match {
+    hiveContext.sessionState.catalog.tableExists(tableIdentifier) match {
       case true if allowExisting =>
         // Handles `CREATE VIEW IF NOT EXISTS v0 AS SELECT ...`. Does nothing when the target view
        // already exists.
 
       case true if orReplace =>
         // Handles `CREATE OR REPLACE VIEW v0 AS SELECT ...`
-        hiveContext.catalog.client.alertView(prepareTable(sqlContext))
+        hiveContext.sessionState.catalog.client.alertView(prepareTable(sqlContext))
 
       case true =>
         // Handles `CREATE VIEW v0 AS SELECT ...`. Throws exception when the target view already
@@ -66,7 +66,7 @@ private[hive] case class CreateViewAsSelect(
           "CREATE OR REPLACE VIEW AS")
 
       case false =>
-        hiveContext.catalog.client.createView(prepareTable(sqlContext))
+        hiveContext.sessionState.catalog.client.createView(prepareTable(sqlContext))
     }
 
     Seq.empty[Row]
(remaining changed files not shown)