
[SPARK-13477][SQL] Expose new user-facing Catalog interface #12713

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed · wants to merge 11 commits

Changes from all commits
@@ -22,7 +22,15 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.Utils


/**
* A helper trait to create [[org.apache.spark.sql.catalyst.encoders.ExpressionEncoder]]s
* for classes whose fields are entirely defined by constructor params but should not be
* case classes.
*/
private[sql] trait DefinedByConstructorParams


/**
* A default version of ScalaReflection that uses the runtime universe.
@@ -333,7 +341,7 @@ object ScalaReflection extends ScalaReflection {
"toScalaMap",
keyData :: valueData :: Nil)

case t if t <:< localTypeOf[Product] =>
case t if definedByConstructorParams(t) =>
val params = getConstructorParameters(t)

val cls = getClassFromType(tpe)
@@ -401,7 +409,7 @@ object ScalaReflection extends ScalaReflection {
val clsName = getClassNameFromType(tpe)
val walkedTypePath = s"""- root class: "${clsName}"""" :: Nil
serializerFor(inputObject, tpe, walkedTypePath) match {
case expressions.If(_, _, s: CreateNamedStruct) if tpe <:< localTypeOf[Product] => s
case expressions.If(_, _, s: CreateNamedStruct) if definedByConstructorParams(tpe) => s
case other => CreateNamedStruct(expressions.Literal("value") :: other :: Nil)
}
}
@@ -491,7 +499,7 @@ object ScalaReflection extends ScalaReflection {
serializerFor(unwrapped, optType, newPath))
}

case t if t <:< localTypeOf[Product] =>
case t if definedByConstructorParams(t) =>
val params = getConstructorParameters(t)
val nonNullOutput = CreateNamedStruct(params.flatMap { case (fieldName, fieldType) =>
val fieldValue = Invoke(inputObject, fieldName, dataTypeFor(fieldType))
@@ -680,7 +688,7 @@ object ScalaReflection extends ScalaReflection {
val Schema(valueDataType, valueNullable) = schemaFor(valueType)
Schema(MapType(schemaFor(keyType).dataType,
valueDataType, valueContainsNull = valueNullable), nullable = true)
case t if t <:< localTypeOf[Product] =>
case t if definedByConstructorParams(t) =>
val params = getConstructorParameters(t)
Schema(StructType(
params.map { case (fieldName, fieldType) =>
@@ -712,6 +720,14 @@ object ScalaReflection extends ScalaReflection {
throw new UnsupportedOperationException(s"Schema for type $other is not supported")
}
}

/**
* Whether the fields of the given type are defined entirely by its constructor parameters.
*/
private[sql] def definedByConstructorParams(tpe: Type): Boolean = {
tpe <:< localTypeOf[Product] || tpe <:< localTypeOf[DefinedByConstructorParams]
}

}

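The new marker trait plus the `definedByConstructorParams` check let encoders be derived for plain classes whose fields come entirely from constructor parameters, not only case classes. Below is a minimal sketch of how a class could opt in, assuming the 2.0-era internals shown above; `Point` and the enclosing object are hypothetical, and because the trait is `private[sql]` the code has to live under the `org.apache.spark.sql` package (as the PR's new catalog classes appear to):

```scala
// Hypothetical sketch; must compile inside the org.apache.spark.sql package
// because DefinedByConstructorParams is private[sql].
package org.apache.spark.sql.sketch

import org.apache.spark.sql.catalyst.DefinedByConstructorParams
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

// Not a case class, but every field is a constructor parameter exposed as a
// val, which is what serializerFor invokes per field.
class Point(val x: Int, val y: Int) extends DefinedByConstructorParams

object DefinedByConstructorParamsSketch {
  def main(args: Array[String]): Unit = {
    // With definedByConstructorParams(typeOf[Point]) now true, the derived
    // encoder is a struct of the constructor fields (x, y) rather than a
    // single opaque "value" column.
    val enc = ExpressionEncoder[Point]()
    println(enc.schema) // expected: fields x and y, both int
    println(enc.flat)   // expected: false (see the ExpressionEncoder change below)
  }
}
```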
@@ -612,6 +612,25 @@ class SessionCatalog(
s"a permanent function registered in the database $currentDb.")
}

/**
* Look up the [[ExpressionInfo]] associated with the specified function, assuming it exists.
*/
private[spark] def lookupFunctionInfo(name: FunctionIdentifier): ExpressionInfo = {
// TODO: just make function registry take in FunctionIdentifier instead of duplicating this
val qualifiedName = name.copy(database = name.database.orElse(Some(currentDb)))
functionRegistry.lookupFunction(name.funcName)
.orElse(functionRegistry.lookupFunction(qualifiedName.unquotedString))
.getOrElse {
val db = qualifiedName.database.get
if (externalCatalog.functionExists(db, name.funcName)) {
val metadata = externalCatalog.getFunction(db, name.funcName)
new ExpressionInfo(metadata.className, qualifiedName.unquotedString)
} else {
failFunctionLookup(name.funcName)
}
}
}
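
The resolution order above (registry by bare name, then by database-qualified name, then the external catalog, and finally `failFunctionLookup`) can be read off a small self-contained sketch; this is plain Scala that mirrors the logic rather than the Spark API, and all names are illustrative:

```scala
// Self-contained sketch of lookupFunctionInfo's fallback order (illustrative names).
object FunctionLookupOrderSketch {
  final case class Info(className: String, qualifiedName: String)

  // Stand-ins for the FunctionRegistry and the ExternalCatalog.
  val registry = Map("upper" -> Info("org.example.Upper", "upper"))
  val externalCatalog = Map(("db1", "my_func") -> "org.example.MyFunc")

  def lookupFunctionInfo(name: String, database: Option[String], currentDb: String): Info = {
    val db = database.getOrElse(currentDb)    // qualify with the current database if none given
    val qualified = s"$db.$name"
    registry.get(name)                        // 1. temp or built-in function by bare name
      .orElse(registry.get(qualified))        // 2. already-loaded permanent function
      .orElse(externalCatalog.get((db, name)) // 3. permanent function in the metastore
        .map(cls => Info(cls, qualified)))
      .getOrElse(sys.error(s"Undefined function: '$name'")) // 4. failFunctionLookup
  }

  def main(args: Array[String]): Unit = {
    println(lookupFunctionInfo("upper", None, "default"))          // hits the registry
    println(lookupFunctionInfo("my_func", Some("db1"), "default")) // falls through to the catalog
  }
}
```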

/**
* Return an [[Expression]] that represents the specified function, assuming it exists.
*
@@ -646,6 +665,7 @@ class SessionCatalog(
// The function has not been loaded to the function registry, which means
// that the function is a permanent function (if it actually has been registered
// in the metastore). We need to first put the function in the FunctionRegistry.
// TODO: why not just check whether the function exists first?
val catalogFunction = try {
externalCatalog.getFunction(currentDb, name.funcName)
} catch {
@@ -662,7 +682,7 @@
val builder = makeFunctionBuilder(qualifiedName.unquotedString, catalogFunction.className)
createTempFunction(qualifiedName.unquotedString, info, builder, ignoreIfExists = false)
// Now, we need to create the Expression.
return functionRegistry.lookupFunction(qualifiedName.unquotedString, children)
functionRegistry.lookupFunction(qualifiedName.unquotedString, children)
}

/**
@@ -687,8 +707,8 @@
// -----------------

/**
* Drop all existing databases (except "default") along with all associated tables,
* partitions and functions, and set the current database to "default".
* Drop all existing databases (except "default"), tables, partitions and functions,
* and set the current database to "default".
*
* This is mainly used for tests.
*/
@@ -697,6 +717,16 @@
listDatabases().filter(_ != default).foreach { db =>
dropDatabase(db, ignoreIfNotExists = false, cascade = true)
}
listTables(default).foreach { table =>
dropTable(table, ignoreIfNotExists = false)
}
listFunctions(default).foreach { func =>
if (func.database.isDefined) {
dropFunction(func, ignoreIfNotExists = false)
} else {
dropTempFunction(func.funcName, ignoreIfNotExists = false)
}
}
tempTables.clear()
functionRegistry.clear()
// restore built-in functions
@@ -299,10 +299,10 @@ case class CatalogTable(

case class CatalogTableType private(name: String)
object CatalogTableType {
val EXTERNAL_TABLE = new CatalogTableType("EXTERNAL_TABLE")
val MANAGED_TABLE = new CatalogTableType("MANAGED_TABLE")
val INDEX_TABLE = new CatalogTableType("INDEX_TABLE")
val VIRTUAL_VIEW = new CatalogTableType("VIRTUAL_VIEW")
val EXTERNAL = new CatalogTableType("EXTERNAL")
val MANAGED = new CatalogTableType("MANAGED")
val INDEX = new CatalogTableType("INDEX")
val VIEW = new CatalogTableType("VIEW")
}


@@ -47,8 +47,9 @@ object ExpressionEncoder {
def apply[T : TypeTag](): ExpressionEncoder[T] = {
// We convert the not-serializable TypeTag into StructType and ClassTag.
val mirror = typeTag[T].mirror
val cls = mirror.runtimeClass(typeTag[T].tpe)
val flat = !classOf[Product].isAssignableFrom(cls)
val tpe = typeTag[T].tpe
val cls = mirror.runtimeClass(tpe)
val flat = !ScalaReflection.definedByConstructorParams(tpe)

val inputObject = BoundReference(0, ScalaReflection.dataTypeFor[T], nullable = false)
val serializer = ScalaReflection.serializerFor[T](inputObject)
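
With `flat` now computed from `definedByConstructorParams` instead of `classOf[Product].isAssignableFrom(cls)`, the behaviour for existing types stays the same: primitive-like types still encode to a single `value` column, while constructor-defined types get one column per field. A small sketch of that expectation, assuming the 2.0-era `ExpressionEncoder.flat` field:

```scala
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

// An ordinary case class: defined by its constructor parameters.
final case class Person(name: String, age: Int)

object FlatEncoderSketch {
  def main(args: Array[String]): Unit = {
    println(ExpressionEncoder[String]().flat) // true: a single "value" column
    println(ExpressionEncoder[Person]().flat) // false: one column per field (name, age)
  }
}
```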
@@ -30,7 +30,7 @@ import org.apache.spark.util.Utils
*
* Implementations of the [[ExternalCatalog]] interface can create test suites by extending this.
*/
abstract class CatalogTestCases extends SparkFunSuite with BeforeAndAfterEach {
abstract class ExternalCatalogSuite extends SparkFunSuite with BeforeAndAfterEach {
protected val utils: CatalogTestUtils
import utils._

@@ -152,10 +152,10 @@ abstract class CatalogTestCases extends SparkFunSuite with BeforeAndAfterEach {
test("the table type of an external table should be EXTERNAL_TABLE") {
val catalog = newBasicCatalog()
val table =
newTable("external_table1", "db2").copy(tableType = CatalogTableType.EXTERNAL_TABLE)
newTable("external_table1", "db2").copy(tableType = CatalogTableType.EXTERNAL)
catalog.createTable("db2", table, ignoreIfExists = false)
val actual = catalog.getTable("db2", "external_table1")
assert(actual.tableType === CatalogTableType.EXTERNAL_TABLE)
assert(actual.tableType === CatalogTableType.EXTERNAL)
}

test("drop table") {
@@ -551,14 +551,15 @@ abstract class CatalogTestUtils {
def newTable(name: String, database: Option[String] = None): CatalogTable = {
CatalogTable(
identifier = TableIdentifier(name, database),
tableType = CatalogTableType.EXTERNAL_TABLE,
tableType = CatalogTableType.EXTERNAL,
storage = storageFormat,
schema = Seq(
CatalogColumn("col1", "int"),
CatalogColumn("col2", "string"),
CatalogColumn("a", "int"),
CatalogColumn("b", "string")),
partitionColumnNames = Seq("a", "b"))
partitionColumnNames = Seq("a", "b"),
bucketColumnNames = Seq("col1"))
}

def newFunc(name: String, database: Option[String] = None): CatalogFunction = {
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.catalog


/** Test suite for the [[InMemoryCatalog]]. */
class InMemoryCatalogSuite extends CatalogTestCases {
class InMemoryCatalogSuite extends ExternalCatalogSuite {

protected override val utils: CatalogTestUtils = new CatalogTestUtils {
override val tableInputFormat: String = "org.apache.park.SequenceFileInputFormat"
@@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{Range, SubqueryAlias}
/**
* Tests for [[SessionCatalog]] that assume that [[InMemoryCatalog]] is correctly implemented.
*
* Note: many of the methods here are very similar to the ones in [[CatalogTestCases]].
* Note: many of the methods here are very similar to the ones in [[ExternalCatalogSuite]].
* This is because [[SessionCatalog]] and [[ExternalCatalog]] share many similar method
* signatures but do not extend a common parent. This is largely by design but
* unfortunately leads to very similar test code in two places.
33 changes: 17 additions & 16 deletions sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -36,6 +36,7 @@ import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.command.ShowTablesCommand
import org.apache.spark.sql.execution.ui.{SQLListener, SQLTab}
import org.apache.spark.sql.internal.{SessionState, SharedState, SQLConf}
import org.apache.spark.sql.sources.BaseRelation
@@ -258,7 +259,7 @@ class SQLContext private[sql](
* @since 1.3.0
*/
def isCached(tableName: String): Boolean = {
sparkSession.isCached(tableName)
sparkSession.catalog.isCached(tableName)
}

/**
@@ -267,7 +268,7 @@
* @since 1.3.0
*/
private[sql] def isCached(qName: Dataset[_]): Boolean = {
sparkSession.isCached(qName)
sparkSession.cacheManager.lookupCachedData(qName).nonEmpty
}

/**
@@ -276,7 +277,7 @@
* @since 1.3.0
*/
def cacheTable(tableName: String): Unit = {
sparkSession.cacheTable(tableName)
sparkSession.catalog.cacheTable(tableName)
}

/**
@@ -285,15 +286,15 @@
* @since 1.3.0
*/
def uncacheTable(tableName: String): Unit = {
sparkSession.uncacheTable(tableName)
sparkSession.catalog.uncacheTable(tableName)
}

/**
* Removes all cached tables from the in-memory cache.
* @since 1.3.0
*/
def clearCache(): Unit = {
sparkSession.clearCache()
sparkSession.catalog.clearCache()
}

// scalastyle:off
@@ -507,7 +508,7 @@
*/
@Experimental
def createExternalTable(tableName: String, path: String): DataFrame = {
sparkSession.createExternalTable(tableName, path)
sparkSession.catalog.createExternalTable(tableName, path)
}

/**
@@ -523,7 +524,7 @@
tableName: String,
path: String,
source: String): DataFrame = {
sparkSession.createExternalTable(tableName, path, source)
sparkSession.catalog.createExternalTable(tableName, path, source)
}

/**
@@ -539,7 +540,7 @@
tableName: String,
source: String,
options: java.util.Map[String, String]): DataFrame = {
sparkSession.createExternalTable(tableName, source, options)
sparkSession.catalog.createExternalTable(tableName, source, options)
}

/**
@@ -556,7 +557,7 @@
tableName: String,
source: String,
options: Map[String, String]): DataFrame = {
sparkSession.createExternalTable(tableName, source, options)
sparkSession.catalog.createExternalTable(tableName, source, options)
}

/**
@@ -573,7 +574,7 @@
source: String,
schema: StructType,
options: java.util.Map[String, String]): DataFrame = {
sparkSession.createExternalTable(tableName, source, schema, options)
sparkSession.catalog.createExternalTable(tableName, source, schema, options)
}

/**
@@ -591,7 +592,7 @@
source: String,
schema: StructType,
options: Map[String, String]): DataFrame = {
sparkSession.createExternalTable(tableName, source, schema, options)
sparkSession.catalog.createExternalTable(tableName, source, schema, options)
}

/**
@@ -611,7 +612,7 @@
* @since 1.3.0
*/
def dropTempTable(tableName: String): Unit = {
sparkSession.dropTempTable(tableName)
sparkSession.catalog.dropTempTable(tableName)
}

/**
@@ -700,7 +701,7 @@
* @since 1.3.0
*/
def tables(): DataFrame = {
sparkSession.tables()
Dataset.ofRows(sparkSession, ShowTablesCommand(None, None))
}

/**
@@ -712,7 +713,7 @@
* @since 1.3.0
*/
def tables(databaseName: String): DataFrame = {
sparkSession.tables(databaseName)
Dataset.ofRows(sparkSession, ShowTablesCommand(Some(databaseName), None))
}

/**
@@ -730,7 +731,7 @@
* @since 1.3.0
*/
def tableNames(): Array[String] = {
sparkSession.tableNames()
sparkSession.catalog.listTables().collect().map(_.name)
}

/**
@@ -740,7 +741,7 @@
* @since 1.3.0
*/
def tableNames(databaseName: String): Array[String] = {
sparkSession.tableNames(databaseName)
sparkSession.catalog.listTables(databaseName).collect().map(_.name)
}

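Each of these SQLContext methods now forwards to the session's user-facing Catalog. A usage sketch of that surface, limited to calls visible in this diff but written against the Spark 2.0 API as it eventually shipped; the session bootstrap and table names are illustrative:

```scala
import org.apache.spark.sql.SparkSession

object CatalogUsageSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("catalog-sketch").getOrCreate()

    // A temporary table to look at.
    spark.range(10).createOrReplaceTempView("nums")

    // Listing: the same calls SQLContext.tableNames() now delegates to.
    val names = spark.catalog.listTables().collect().map(_.name)
    println(names.mkString(", "))

    // Caching: SQLContext.cacheTable/isCached/uncacheTable/clearCache are now thin wrappers.
    spark.catalog.cacheTable("nums")
    println(spark.catalog.isCached("nums"))
    spark.catalog.uncacheTable("nums")
    spark.catalog.clearCache()

    spark.stop()
  }
}
```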