Skip to content

Commit

Permalink
[SPARK-13738][SQL] Cleanup Data Source resolution
Browse files Browse the repository at this point in the history
Follow-up to #11509, that simply refactors the interface that we use when resolving a pluggable `DataSource`.
 - Multiple functions share the same set of arguments so we make this a case class, called `DataSource`.  Actual resolution is now done by calling a function on this class.
 - Instead of having multiple methods named `apply` (some of which do writing some of which do reading) we now explicitly have `resolveRelation()` and `write(mode, df)`.
 - Get rid of `Array[String]` since this is an internal API and was forcing us to awkwardly call `toArray` in a bunch of places.

Author: Michael Armbrust <michael@databricks.com>

Closes #11572 from marmbrus/dataSourceResolution.
  • Loading branch information
marmbrus authored and rxin committed Mar 8, 2016
1 parent 076009b commit 1e28840
Show file tree
Hide file tree
Showing 12 changed files with 197 additions and 197 deletions.
34 changes: 18 additions & 16 deletions sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.execution.LogicalRDD
import org.apache.spark.sql.execution.datasources.{LogicalRelation, ResolvedDataSource}
import org.apache.spark.sql.execution.datasources.{DataSource, LogicalRelation}
import org.apache.spark.sql.execution.datasources.jdbc.{JDBCPartition, JDBCPartitioningInfo, JDBCRelation}
import org.apache.spark.sql.execution.datasources.json.{InferSchema, JacksonParser, JSONOptions}
import org.apache.spark.sql.execution.streaming.StreamingRelation
Expand Down Expand Up @@ -122,12 +122,13 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging {
* @since 1.4.0
*/
def load(): DataFrame = {
val resolved = ResolvedDataSource(
sqlContext,
userSpecifiedSchema = userSpecifiedSchema,
provider = source,
options = extraOptions.toMap)
DataFrame(sqlContext, LogicalRelation(resolved.relation))
val dataSource =
DataSource(
sqlContext,
userSpecifiedSchema = userSpecifiedSchema,
className = source,
options = extraOptions.toMap)
DataFrame(sqlContext, LogicalRelation(dataSource.resolveRelation()))
}

/**
Expand All @@ -152,12 +153,12 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging {
sqlContext.emptyDataFrame
} else {
sqlContext.baseRelationToDataFrame(
ResolvedDataSource.apply(
DataSource.apply(
sqlContext,
paths = paths,
userSpecifiedSchema = userSpecifiedSchema,
provider = source,
options = extraOptions.toMap).relation)
className = source,
options = extraOptions.toMap).resolveRelation())
}
}

Expand All @@ -168,12 +169,13 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging {
* @since 2.0.0
*/
def stream(): DataFrame = {
val resolved = ResolvedDataSource.createSource(
sqlContext,
userSpecifiedSchema = userSpecifiedSchema,
providerName = source,
options = extraOptions.toMap)
DataFrame(sqlContext, StreamingRelation(resolved))
val dataSource =
DataSource(
sqlContext,
userSpecifiedSchema = userSpecifiedSchema,
className = source,
options = extraOptions.toMap)
DataFrame(sqlContext, StreamingRelation(dataSource.createSource()))
}

/**
Expand Down
29 changes: 15 additions & 14 deletions sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import org.apache.spark.annotation.Experimental
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, Project}
import org.apache.spark.sql.execution.datasources.{BucketSpec, CreateTableUsingAsSelect, ResolvedDataSource}
import org.apache.spark.sql.execution.datasources.{BucketSpec, CreateTableUsingAsSelect, DataSource}
import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils
import org.apache.spark.sql.execution.streaming.StreamExecution
import org.apache.spark.sql.sources.HadoopFsRelation
Expand Down Expand Up @@ -195,14 +195,14 @@ final class DataFrameWriter private[sql](df: DataFrame) {
*/
def save(): Unit = {
assertNotBucketed()
ResolvedDataSource(
val dataSource = DataSource(
df.sqlContext,
source,
partitioningColumns.map(_.toArray).getOrElse(Array.empty[String]),
getBucketSpec,
mode,
extraOptions.toMap,
df)
className = source,
partitionColumns = partitioningColumns.getOrElse(Nil),
bucketSpec = getBucketSpec,
options = extraOptions.toMap)

dataSource.write(mode, df)
}

/**
Expand Down Expand Up @@ -235,14 +235,15 @@ final class DataFrameWriter private[sql](df: DataFrame) {
* @since 2.0.0
*/
def stream(): ContinuousQuery = {
val sink = ResolvedDataSource.createSink(
df.sqlContext,
source,
extraOptions.toMap,
normalizedParCols.getOrElse(Nil))
val dataSource =
DataSource(
df.sqlContext,
className = source,
options = extraOptions.toMap,
partitionColumns = normalizedParCols.getOrElse(Nil))

df.sqlContext.continuousQueryManager.startQuery(
extraOptions.getOrElse("queryName", StreamExecution.nextName), df, sink)
extraOptions.getOrElse("queryName", StreamExecution.nextName), df, dataSource.createSink())
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,39 @@ import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.{CalendarIntervalType, StructType}
import org.apache.spark.util.Utils

case class ResolvedDataSource(provider: Class[_], relation: BaseRelation)

/**
* Responsible for taking a description of a datasource (either from
* [[org.apache.spark.sql.DataFrameReader]], or a metastore) and converting it into a logical
* relation that can be used in a query plan.
* The main class responsible for representing a pluggable Data Source in Spark SQL. In addition to
* acting as the canonical set of parameters that can describe a Data Source, this class is used to
* resolve a description to a concrete implementation that can be used in a query plan
* (either batch or streaming) or to write out data using an external library.
*
* From an end user's perspective a DataSource description can be created explicitly using
* [[org.apache.spark.sql.DataFrameReader]] or CREATE TABLE USING DDL. Additionally, this class is
* used when resolving a description from a metastore to a concrete implementation.
*
* Many of the arguments to this class are optional, though depending on the specific API being used
* these optional arguments might be filled in during resolution using either inference or external
* metadata. For example, when reading a partitioned table from a file system, partition columns
* will be inferred from the directory layout even if they are not specified.
*
* @param paths A list of file system paths that hold data. These will be globbed before and
* qualified. This option only works when reading from a [[FileFormat]].
* @param userSpecifiedSchema An optional specification of the schema of the data. When present
* we skip attempting to infer the schema.
* @param partitionColumns A list of column names that the relation is partitioned by. When this
* list is empty, the relation is unpartitioned.
* @param bucketSpec An optional specification for bucketing (hash-partitioning) of the data.
*/
object ResolvedDataSource extends Logging {
case class DataSource(
sqlContext: SQLContext,
className: String,
paths: Seq[String] = Nil,
userSpecifiedSchema: Option[StructType] = None,
partitionColumns: Seq[String] = Seq.empty,
bucketSpec: Option[BucketSpec] = None,
options: Map[String, String] = Map.empty) extends Logging {

lazy val providingClass: Class[_] = lookupDataSource(className)

/** A map to maintain backward compatibility in case we move data sources around. */
private val backwardCompatibilityMap = Map(
Expand All @@ -54,7 +79,7 @@ object ResolvedDataSource extends Logging {
)

/** Given a provider name, look up the data source class definition. */
def lookupDataSource(provider0: String): Class[_] = {
private def lookupDataSource(provider0: String): Class[_] = {
val provider = backwardCompatibilityMap.getOrElse(provider0, provider0)
val provider2 = s"$provider.DefaultSource"
val loader = Utils.getContextOrSparkClassLoader
Expand Down Expand Up @@ -96,15 +121,11 @@ object ResolvedDataSource extends Logging {
}
}

// TODO: Combine with apply?
def createSource(
sqlContext: SQLContext,
userSpecifiedSchema: Option[StructType],
providerName: String,
options: Map[String, String]): Source = {
val provider = lookupDataSource(providerName).newInstance() match {
/** Returns a source that can be used to continually read data. */
def createSource(): Source = {
providingClass.newInstance() match {
case s: StreamSourceProvider =>
s.createSource(sqlContext, userSpecifiedSchema, providerName, options)
s.createSource(sqlContext, userSpecifiedSchema, className, options)

case format: FileFormat =>
val caseInsensitiveOptions = new CaseInsensitiveMap(options)
Expand Down Expand Up @@ -135,53 +156,38 @@ object ResolvedDataSource extends Logging {
new DataFrame(
sqlContext,
LogicalRelation(
apply(
DataSource(
sqlContext,
paths = files,
userSpecifiedSchema = Some(dataSchema),
provider = providerName,
options = options.filterKeys(_ != "path")).relation))
className = className,
options = options.filterKeys(_ != "path")).resolveRelation()))
}

new FileStreamSource(
sqlContext, metadataPath, path, Some(dataSchema), providerName, dataFrameBuilder)
sqlContext, metadataPath, path, Some(dataSchema), className, dataFrameBuilder)
case _ =>
throw new UnsupportedOperationException(
s"Data source $providerName does not support streamed reading")
s"Data source $className does not support streamed reading")
}

provider
}

def createSink(
sqlContext: SQLContext,
providerName: String,
options: Map[String, String],
partitionColumns: Seq[String]): Sink = {
val provider = lookupDataSource(providerName).newInstance() match {
/** Returns a sink that can be used to continually write data. */
def createSink(): Sink = {
val datasourceClass = providingClass.newInstance() match {
case s: StreamSinkProvider => s
case _ =>
throw new UnsupportedOperationException(
s"Data source $providerName does not support streamed writing")
s"Data source $className does not support streamed writing")
}

provider.createSink(sqlContext, options, partitionColumns)
datasourceClass.createSink(sqlContext, options, partitionColumns)
}

/** Create a [[ResolvedDataSource]] for reading data in. */
def apply(
sqlContext: SQLContext,
paths: Seq[String] = Nil,
userSpecifiedSchema: Option[StructType] = None,
partitionColumns: Array[String] = Array.empty,
bucketSpec: Option[BucketSpec] = None,
provider: String,
options: Map[String, String]): ResolvedDataSource = {
val clazz: Class[_] = lookupDataSource(provider)
def className: String = clazz.getCanonicalName

/** Create a resolved [[BaseRelation]] that can be used to read data from this [[DataSource]] */
def resolveRelation(): BaseRelation = {
val caseInsensitiveOptions = new CaseInsensitiveMap(options)
val relation = (clazz.newInstance(), userSpecifiedSchema) match {
val relation = (providingClass.newInstance(), userSpecifiedSchema) match {
// TODO: Throw when too much is given.
case (dataSource: SchemaRelationProvider, Some(schema)) =>
dataSource.createRelation(sqlContext, caseInsensitiveOptions, schema)
Expand Down Expand Up @@ -238,43 +244,19 @@ object ResolvedDataSource extends Logging {
throw new AnalysisException(
s"$className is not a valid Spark SQL Data Source.")
}
new ResolvedDataSource(clazz, relation)
}

def partitionColumnsSchema(
schema: StructType,
partitionColumns: Array[String],
caseSensitive: Boolean): StructType = {
val equality = columnNameEquality(caseSensitive)
StructType(partitionColumns.map { col =>
schema.find(f => equality(f.name, col)).getOrElse {
throw new RuntimeException(s"Partition column $col not found in schema $schema")
}
}).asNullable
relation
}

private def columnNameEquality(caseSensitive: Boolean): (String, String) => Boolean = {
if (caseSensitive) {
org.apache.spark.sql.catalyst.analysis.caseSensitiveResolution
} else {
org.apache.spark.sql.catalyst.analysis.caseInsensitiveResolution
}
}

/** Create a [[ResolvedDataSource]] for saving the content of the given DataFrame. */
def apply(
sqlContext: SQLContext,
provider: String,
partitionColumns: Array[String],
bucketSpec: Option[BucketSpec],
/** Writes the give [[DataFrame]] out to this [[DataSource]]. */
def write(
mode: SaveMode,
options: Map[String, String],
data: DataFrame): ResolvedDataSource = {
data: DataFrame): BaseRelation = {
if (data.schema.map(_.dataType).exists(_.isInstanceOf[CalendarIntervalType])) {
throw new AnalysisException("Cannot save interval data type into external storage.")
}
val clazz: Class[_] = lookupDataSource(provider)
clazz.newInstance() match {

providingClass.newInstance() match {
case dataSource: CreatableRelationProvider =>
dataSource.createRelation(sqlContext, mode, options, data)
case format: FileFormat =>
Expand All @@ -295,27 +277,28 @@ object ResolvedDataSource extends Logging {
PartitioningUtils.validatePartitionColumnDataTypes(
data.schema, partitionColumns, caseSensitive)

val equality = columnNameEquality(caseSensitive)
val equality =
if (sqlContext.conf.caseSensitiveAnalysis) {
org.apache.spark.sql.catalyst.analysis.caseSensitiveResolution
} else {
org.apache.spark.sql.catalyst.analysis.caseInsensitiveResolution
}

val dataSchema = StructType(
data.schema.filterNot(f => partitionColumns.exists(equality(_, f.name))))

// If we are appending to a table that already exists, make sure the partitioning matches
// up. If we fail to load the table for whatever reason, ignore the check.
if (mode == SaveMode.Append) {
val existingPartitionColumnSet = try {
val resolved = apply(
sqlContext,
userSpecifiedSchema = Some(data.schema.asNullable),
provider = provider,
options = options)

Some(resolved.relation
.asInstanceOf[HadoopFsRelation]
.location
.partitionSpec(None)
.partitionColumns
.fieldNames
.toSet)
Some(
resolveRelation()
.asInstanceOf[HadoopFsRelation]
.location
.partitionSpec(None)
.partitionColumns
.fieldNames
.toSet)
} catch {
case e: Exception =>
None
Expand Down Expand Up @@ -346,15 +329,10 @@ object ResolvedDataSource extends Logging {
sqlContext.executePlan(plan).toRdd

case _ =>
sys.error(s"${clazz.getCanonicalName} does not allow create table as select.")
sys.error(s"${providingClass.getCanonicalName} does not allow create table as select.")
}

apply(
sqlContext,
userSpecifiedSchema = Some(data.schema.asNullable),
partitionColumns = partitionColumns,
bucketSpec = bucketSpec,
provider = provider,
options = options)
// We replace the schema with that of the DataFrame we just wrote out to avoid re-inferring it.
copy(userSpecifiedSchema = Some(data.schema.asNullable)).resolveRelation()
}
}
Loading

0 comments on commit 1e28840

Please sign in to comment.