SPARK-24923: Add v2 CTAS and RTAS support.
This uses the catalog API introduced in SPARK-24252 to implement CTAS
and RTAS plans.
rdblue committed Aug 15, 2018
1 parent 622180a commit e50d94b
Showing 12 changed files with 420 additions and 95 deletions.
NamedRelation.scala
@@ -17,8 +17,11 @@

package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

trait NamedRelation extends LogicalPlan {
def name: String

def output: Seq[AttributeReference]
}
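
The change above gives NamedRelation a typed output. A minimal sketch of an implementing leaf node, purely illustrative and not part of this commit:

import org.apache.spark.sql.catalyst.analysis.NamedRelation
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.logical.LeafNode
import org.apache.spark.sql.types.StringType

// Hypothetical relation used only to illustrate the revised trait: `name` is
// satisfied by the constructor val, and `output` must now be AttributeReferences.
case class DummyNamedRelation(name: String) extends LeafNode with NamedRelation {
  override val output: Seq[AttributeReference] =
    Seq(AttributeReference("value", StringType)())
}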
basicLogicalOperators.scala
@@ -17,7 +17,8 @@

package org.apache.spark.sql.catalyst.plans.logical

import org.apache.spark.sql.catalyst.{AliasIdentifier}
import org.apache.spark.sql.catalog.v2.{PartitionTransform, TableCatalog}
import org.apache.spark.sql.catalyst.{AliasIdentifier, TableIdentifier}
import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, NamedRelation}
import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable}
import org.apache.spark.sql.catalyst.expressions._
@@ -384,6 +385,37 @@ object AppendData {
}
}

/**
* Create a new table from a select query.
*/
case class CreateTableAsSelect(
catalog: TableCatalog,
table: TableIdentifier,
partitioning: Seq[PartitionTransform],
query: LogicalPlan,
writeOptions: Map[String, String],
ignoreIfExists: Boolean) extends LogicalPlan {

override def children: Seq[LogicalPlan] = Seq(query)
override def output: Seq[Attribute] = Seq.empty
override lazy val resolved = true
}

/**
* Replace a table with the results of a select query.
*/
case class ReplaceTableAsSelect(
catalog: TableCatalog,
table: TableIdentifier,
partitioning: Seq[PartitionTransform],
query: LogicalPlan,
writeOptions: Map[String, String]) extends LogicalPlan {

override def children: Seq[LogicalPlan] = Seq(query)
override def output: Seq[Attribute] = Seq.empty
override lazy val resolved = true
}
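
The two new plans above are plain logical nodes that carry the target catalog, identifier, partitioning, and write options. A minimal construction sketch, assuming a TableCatalog instance and an analyzed query plan are supplied by the caller; the helper function and the `db.target` identifier are illustrative, not part of the commit:

import org.apache.spark.sql.catalog.v2.TableCatalog
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.plans.logical.{CreateTableAsSelect, LogicalPlan, ReplaceTableAsSelect}

// Builds a CTAS and an RTAS plan for a hypothetical target table `db.target`.
def buildCtasAndRtas(catalog: TableCatalog, query: LogicalPlan): (LogicalPlan, LogicalPlan) = {
  val ident = TableIdentifier("target", Some("db"))
  val ctas = CreateTableAsSelect(catalog, ident, Seq.empty, query,
    writeOptions = Map.empty, ignoreIfExists = false)
  val rtas = ReplaceTableAsSelect(catalog, ident, Seq.empty, query,
    writeOptions = Map.empty)
  (ctas, rtas)
}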

/**
* Insert some data into a table. Note that this plan is unresolved and has to be replaced by the
* concrete implementations during analysis.
ReadSupport.java
@@ -27,7 +27,7 @@
* provide data reading ability and scan the data from the data source.
*/
@InterfaceStability.Evolving
public interface ReadSupport extends DataSourceV2 {
public interface ReadSupport {

/**
* Creates a {@link DataSourceReader} to scan the data from this data source.
WriteSupport.java
@@ -29,7 +29,7 @@
* provide data writing ability and save the data to the data source.
*/
@InterfaceStability.Evolving
public interface WriteSupport extends DataSourceV2 {
public interface WriteSupport {

/**
* Creates an optional {@link DataSourceWriter} to save the data to this data source. Data
DataFrameReader.scala
@@ -37,7 +37,7 @@ import org.apache.spark.sql.execution.datasources.jdbc._
import org.apache.spark.sql.execution.datasources.json.TextInputJsonDataSource
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils
import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, ReadSupport}
import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, DataSourceV2Implicits, ReadSupport}
import org.apache.spark.sql.types.{StringType, StructType}
import org.apache.spark.unsafe.types.UTF8String

@@ -191,6 +191,21 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
"read files of Hive data source directly.")
}

import DataSourceV2Implicits._

extraOptions.get("catalog") match {
case Some(catalogName) if extraOptions.get(DataSourceOptions.TABLE_KEY).isDefined =>
val catalog = sparkSession.catalog(catalogName).asTableCatalog
val options = extraOptions.toMap
val identifier = options.table.get

return Dataset.ofRows(sparkSession,
DataSourceV2Relation.create(
catalogName, identifier, catalog.loadTable(identifier), options))

case _ =>
}

val cls = DataSource.lookupDataSource(source, sparkSession.sessionState.conf)
if (classOf[DataSourceV2].isAssignableFrom(cls)) {
val ds = cls.newInstance().asInstanceOf[DataSourceV2]
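
With the branch added to DataFrameReader.load() above, a read can be routed to a v2 catalog table by setting the "catalog" option and the table option instead of a path. A usage sketch, assuming a catalog registered under the hypothetical name "testcat":

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.sources.v2.DataSourceOptions

val spark = SparkSession.builder().getOrCreate()

// Both options are checked before any DataSource lookup, so no format or path
// is needed; the table is loaded through the named catalog instead.
val df = spark.read
  .option("catalog", "testcat")
  .option(DataSourceOptions.TABLE_KEY, "tbl")
  .load()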
DataFrameWriter.scala
@@ -25,7 +25,7 @@ import org.apache.spark.annotation.InterfaceStability
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, UnresolvedRelation}
import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.plans.logical.{AppendData, InsertIntoTable, LogicalPlan}
import org.apache.spark.sql.catalyst.plans.logical.{AppendData, CreateTableAsSelect, InsertIntoTable, LogicalPlan, ReplaceTableAsSelect}
import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.sql.execution.command.DDLUtils
import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource, LogicalRelation}
@@ -236,6 +236,51 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {

assertNotBucketed("save")

import DataSourceV2Implicits._

extraOptions.get("catalog") match {
case Some(catalogName) if extraOptions.get(DataSourceOptions.TABLE_KEY).isDefined =>
val catalog = df.sparkSession.catalog(catalogName).asTableCatalog
val options = extraOptions.toMap
val identifier = options.table.get
val exists = catalog.tableExists(identifier)

(exists, mode) match {
case (true, SaveMode.ErrorIfExists) =>
throw new AnalysisException(s"Table already exists: ${identifier.quotedString}")

case (true, SaveMode.Overwrite) =>
runCommand(df.sparkSession, "save") {
ReplaceTableAsSelect(catalog, identifier, Seq.empty, df.logicalPlan, options)
}

case (true, SaveMode.Append) =>
val relation = DataSourceV2Relation.create(
catalogName, identifier, catalog.loadTable(identifier), options)

runCommand(df.sparkSession, "save") {
AppendData.byName(relation, df.logicalPlan)
}

case (false, SaveMode.Append) =>
throw new AnalysisException(s"Table does not exist: ${identifier.quotedString}")

case (false, SaveMode.ErrorIfExists) |
(false, SaveMode.Ignore) |
(false, SaveMode.Overwrite) =>

runCommand(df.sparkSession, "save") {
CreateTableAsSelect(catalog, identifier, Seq.empty, df.logicalPlan, options,
ignoreIfExists = mode == SaveMode.Ignore)
}

case _ =>
return // table exists and mode is ignore
}

case _ =>
}

val cls = DataSource.lookupDataSource(source, df.sparkSession.sessionState.conf)
if (classOf[DataSourceV2].isAssignableFrom(cls)) {
val source = cls.newInstance().asInstanceOf[DataSourceV2]
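
On the write side, the same two options select the v2 catalog path in save(), and the SaveMode picks the plan: Overwrite of an existing table becomes ReplaceTableAsSelect, Append becomes AppendData on the loaded table relation, and the remaining modes over a missing table become CreateTableAsSelect. A usage sketch with the same hypothetical "testcat" catalog:

import org.apache.spark.sql.{SaveMode, SparkSession}
import org.apache.spark.sql.sources.v2.DataSourceOptions

val spark = SparkSession.builder().getOrCreate()
val df = spark.range(10).toDF("id")

// Overwrite of an existing table runs ReplaceTableAsSelect; if the table does
// not exist, the same call falls into the CreateTableAsSelect branch.
df.write
  .option("catalog", "testcat")
  .option(DataSourceOptions.TABLE_KEY, "tbl")
  .mode(SaveMode.Overwrite)
  .save()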
DataSourceV2Relation.scala
@@ -17,20 +17,18 @@

package org.apache.spark.sql.execution.datasources.v2

import java.util.UUID

import scala.collection.JavaConverters._

import org.apache.spark.sql.{AnalysisException, SaveMode}
import org.apache.spark.sql.catalog.v2.Table
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, NamedRelation}
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression}
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics}
import org.apache.spark.sql.sources.DataSourceRegister
import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, ReadSupport, WriteSupport}
import org.apache.spark.sql.sources.v2.DataSourceV2Implicits._
import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, SupportsReportStatistics}
import org.apache.spark.sql.sources.v2.writer.DataSourceWriter
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.Utils

/**
* A logical plan representing a data source v2 scan.
@@ -48,10 +46,10 @@ case class DataSourceV2Relation(
userSpecifiedSchema: Option[StructType] = None)
extends LeafNode with MultiInstanceRelation with NamedRelation with DataSourceV2StringFormat {

import DataSourceV2Relation._
override def sourceName: String = source.name

override def name: String = {
tableIdent.map(_.unquotedString).getOrElse(s"${source.name}:unknown")
tableIdent.map(_.unquotedString).getOrElse(s"$sourceName:unknown")
}

override def pushedFilters: Seq[Expression] = Seq.empty
@@ -62,7 +60,7 @@

def newWriter(): DataSourceWriter = source.createWriter(options, schema)

override def computeStats(): Statistics = newReader match {
override def computeStats(): Statistics = newReader() match {
case r: SupportsReportStatistics =>
Statistics(sizeInBytes = r.getStatistics.sizeInBytes().orElse(conf.defaultSizeInBytes))
case _ =>
@@ -74,6 +72,43 @@ }
}
}

/**
* A logical plan representing a data source v2 table.
*
* @param ident The table's TableIdentifier.
* @param table The table.
* @param output The output attributes of the table.
* @param options The options for this scan or write.
*/
case class TableV2Relation(
catalogName: String,
ident: TableIdentifier,
table: Table,
output: Seq[AttributeReference],
options: Map[String, String])
extends LeafNode with MultiInstanceRelation with NamedRelation {

import org.apache.spark.sql.sources.v2.DataSourceV2Implicits._

override def name: String = ident.unquotedString

override def simpleString: String =
s"RelationV2 $name ${Utils.truncatedString(output, "[", ", ", "]")}"

def newReader(): DataSourceReader = table.createReader(options)

override def computeStats(): Statistics = newReader() match {
case r: SupportsReportStatistics =>
Statistics(sizeInBytes = r.getStatistics.sizeInBytes().orElse(conf.defaultSizeInBytes))
case _ =>
Statistics(sizeInBytes = conf.defaultSizeInBytes)
}

override def newInstance(): TableV2Relation = {
copy(output = output.map(_.newInstance()))
}
}

/**
* A specialization of [[DataSourceV2Relation]] with the streaming bit set to true.
*
@@ -88,6 +123,8 @@ case class StreamingDataSourceV2Relation(
reader: DataSourceReader)
extends LeafNode with MultiInstanceRelation with DataSourceV2StringFormat {

override def sourceName: String = source.name

override def isStreaming: Boolean = true

override def simpleString: String = "Streaming RelationV2 " + metadataString
@@ -116,68 +153,22 @@ case class StreamingDataSourceV2Relation(
}

object DataSourceV2Relation {
private implicit class SourceHelpers(source: DataSourceV2) {
def asReadSupport: ReadSupport = {
source match {
case support: ReadSupport =>
support
case _ =>
throw new AnalysisException(s"Data source is not readable: $name")
}
}

def asWriteSupport: WriteSupport = {
source match {
case support: WriteSupport =>
support
case _ =>
throw new AnalysisException(s"Data source is not writable: $name")
}
}

def name: String = {
source match {
case registered: DataSourceRegister =>
registered.shortName()
case _ =>
source.getClass.getSimpleName
}
}

def createReader(
options: Map[String, String],
userSpecifiedSchema: Option[StructType]): DataSourceReader = {
val v2Options = new DataSourceOptions(options.asJava)
userSpecifiedSchema match {
case Some(s) =>
asReadSupport.createReader(s, v2Options)
case _ =>
asReadSupport.createReader(v2Options)
}
}

def createWriter(
options: Map[String, String],
schema: StructType): DataSourceWriter = {
val v2Options = new DataSourceOptions(options.asJava)
asWriteSupport.createWriter(UUID.randomUUID.toString, schema, SaveMode.Append, v2Options).get
}
}

def create(
source: DataSourceV2,
options: Map[String, String],
tableIdent: Option[TableIdentifier] = None,
userSpecifiedSchema: Option[StructType] = None): DataSourceV2Relation = {
userSpecifiedSchema: Option[StructType] = None): NamedRelation = {
val reader = source.createReader(options, userSpecifiedSchema)
val ident = tableIdent.orElse(tableFromOptions(options))
val ident = tableIdent.orElse(options.table)
DataSourceV2Relation(
source, reader.readSchema().toAttributes, options, ident, userSpecifiedSchema)
}

private def tableFromOptions(options: Map[String, String]): Option[TableIdentifier] = {
options
.get(DataSourceOptions.TABLE_KEY)
.map(TableIdentifier(_, options.get(DataSourceOptions.DATABASE_KEY)))
def create(
catalogName: String,
ident: TableIdentifier,
table: Table,
options: Map[String, String]): NamedRelation = {
TableV2Relation(catalogName, ident, table, table.schema.toAttributes, options)
}
}
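
The helpers formerly defined here as SourceHelpers (removed above) are now reached through DataSourceV2Implicits, which is not among the files shown. A sketch of the one piece inferable from the call sites (`options.table` in DataFrameReader, DataFrameWriter, and `create`); the object and member names below are assumptions mirroring the removed tableFromOptions logic:

import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.sources.v2.DataSourceOptions

// Assumed shape of the relocated option helpers: enriches a plain options map
// with a `table` accessor built from the table/database option keys.
object DataSourceV2ImplicitsSketch {
  implicit class OptionsHelper(options: Map[String, String]) {
    def table: Option[TableIdentifier] =
      options.get(DataSourceOptions.TABLE_KEY)
        .map(TableIdentifier(_, options.get(DataSourceOptions.DATABASE_KEY)))
  }
}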
DataSourceV2ScanExec.scala
@@ -26,7 +26,6 @@ import org.apache.spark.sql.catalyst.plans.physical
import org.apache.spark.sql.catalyst.plans.physical.SinglePartition
import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, WholeStageCodegenExec}
import org.apache.spark.sql.execution.streaming.continuous._
import org.apache.spark.sql.sources.v2.DataSourceV2
import org.apache.spark.sql.sources.v2.reader._
import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader
import org.apache.spark.sql.vectorized.ColumnarBatch
@@ -36,7 +35,7 @@ import org.apache.spark.sql.vectorized.ColumnarBatch
*/
case class DataSourceV2ScanExec(
output: Seq[AttributeReference],
@transient source: DataSourceV2,
@transient sourceName: String,
@transient options: Map[String, String],
@transient pushedFilters: Seq[Expression],
@transient reader: DataSourceReader)
@@ -52,7 +51,7 @@ case class DataSourceV2ScanExec(
}

override def hashCode(): Int = {
Seq(output, source, options).hashCode()
Seq(output, sourceName, options).hashCode()
}

override def outputPartitioning: physical.Partitioning = reader match {
