[SPARK-41713][SQL] Make CTAS hold a nested execution for data writing
### What changes were proposed in this pull request?

This PR aims to make CTAS use a nested execution instead of running the data writing command directly.

This lets us clean up CTAS itself and remove the unnecessary V1 write information. Now `V1WriteCommand` has only two implementations: `InsertIntoHadoopFsRelationCommand` and `InsertIntoHiveTable`.
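
In the patch, `CreateDataSourceTableAsSelectCommand` becomes a plain `LeafRunnableCommand`: it creates the table and then hands the actual write to a separate, nested query execution via `sessionState.executePlan(...)` followed by `assertCommandExecuted()`. The snippet below is a condensed sketch of that pattern, assuming it compiles inside Spark's own `org.apache.spark.sql` packages (these APIs are internal); `buildInsertCommand` is a hypothetical stand-in for the planning logic that produces e.g. `InsertIntoHadoopFsRelationCommand`.

```scala
package org.apache.spark.sql.execution.command

import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

// Sketch of the nested-execution pattern CTAS follows after this change.
case class NestedWriteCtasSketch(
    query: LogicalPlan,
    buildInsertCommand: LogicalPlan => LogicalPlan) extends LeafRunnableCommand {

  override def run(sparkSession: SparkSession): Seq[Row] = {
    // 1. Create the catalog table (omitted in this sketch).
    // 2. Plan the concrete write, e.g. InsertIntoHadoopFsRelationCommand.
    val insert = buildInsertCommand(query)
    // 3. Run the write as its own (nested) query execution, so it gets its
    //    own plan, metrics and AQE handling and shows up as a separate
    //    query in the SQL UI.
    val qe = sparkSession.sessionState.executePlan(insert)
    qe.assertCommandExecuted()
    Seq.empty
  }
}
```

In the actual patch the nested execution is triggered inside `DataSource.writeAndRead` (see the diff below); the sketch collapses that indirection.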

### Why are the changes needed?

It makes the V1 writes code clearer. The formatted plan of a CTAS now looks like this:

```sql
EXPLAIN FORMATTED CREATE TABLE t2 USING PARQUET AS SELECT * FROM t;

== Physical Plan ==
Execute CreateDataSourceTableAsSelectCommand (1)
   +- CreateDataSourceTableAsSelectCommand (2)
         +- Project (5)
            +- SubqueryAlias (4)
               +- LogicalRelation (3)

(1) Execute CreateDataSourceTableAsSelectCommand
Output: []

(2) CreateDataSourceTableAsSelectCommand
Arguments: `spark_catalog`.`default`.`t2`, ErrorIfExists, [c1, c2]

(3) LogicalRelation
Arguments: parquet, [c1#11, c2#12], `spark_catalog`.`default`.`t`, org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe, false

(4) SubqueryAlias
Arguments: spark_catalog.default.t

(5) Project
Arguments: [c1#11, c2#12]
```
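
Note that the file write no longer shows up inside the CTAS plan above: it runs as a nested execution whose own plan is `Execute InsertIntoHadoopFsRelationCommand`, which is what the updated `AdaptiveQueryExecSuite` asserts. The listener sketch below is illustrative only, not part of this patch; it assumes a local session and that `SparkListenerSQLExecutionStart` (a developer API) keeps its current shape.

```scala
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.ui.SparkListenerSQLExecutionStart

object ObserveNestedCtasExecutions {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").getOrCreate()

    // Record the root plan node of every SQL execution that starts.
    val rootNodes = ArrayBuffer.empty[String]
    val listener = new SparkListener {
      override def onOtherEvent(event: SparkListenerEvent): Unit = event match {
        case start: SparkListenerSQLExecutionStart =>
          rootNodes.synchronized { rootNodes += start.sparkPlanInfo.nodeName }
        case _ => // ignore other events
      }
    }

    spark.sparkContext.addSparkListener(listener)
    try {
      spark.sql("CREATE TABLE t2 USING parquet AS SELECT 1 AS col").collect()
      Thread.sleep(2000) // crude wait; the test suite uses listenerBus.waitUntilEmpty()
      // With this patch the updated test expects three executions, roughly:
      //   Execute CreateDataSourceTableAsSelectCommand
      //   AdaptiveSparkPlan / Execute InsertIntoHadoopFsRelationCommand  (nested write)
      //   CommandResult
      rootNodes.synchronized(rootNodes.foreach(println))
    } finally {
      spark.sparkContext.removeSparkListener(listener)
      spark.stop()
    }
  }
}
```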

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Improved existing tests.

Closes #39220 from ulysses-you/SPARK-41713.

Authored-by: ulysses-you <ulyssesyou18@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
ulysses-you authored and cloud-fan committed Dec 28, 2022
1 parent 87a235c commit 4b40920
Showing 11 changed files with 159 additions and 183 deletions.
@@ -20,13 +20,11 @@ package org.apache.spark.sql.execution.command
import java.net.URI

import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.util.CharVarcharUtils
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution.{CommandExecutionMode, SparkPlan}
import org.apache.spark.sql.execution.CommandExecutionMode
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.sql.types.StructType
@@ -143,29 +141,11 @@ case class CreateDataSourceTableAsSelectCommand(
mode: SaveMode,
query: LogicalPlan,
outputColumnNames: Seq[String])
extends V1WriteCommand {

override def fileFormatProvider: Boolean = {
table.provider.forall { provider =>
classOf[FileFormat].isAssignableFrom(DataSource.providingClass(provider, conf))
}
}

override lazy val partitionColumns: Seq[Attribute] = {
val unresolvedPartitionColumns = table.partitionColumnNames.map(UnresolvedAttribute.quoted)
DataSource.resolvePartitionColumns(
unresolvedPartitionColumns,
outputColumns,
query,
SparkSession.active.sessionState.conf.resolver)
}

override def requiredOrdering: Seq[SortOrder] = {
val options = table.storage.properties
V1WritesUtils.getSortOrder(outputColumns, partitionColumns, table.bucketSpec, options)
}
extends LeafRunnableCommand {
assert(query.resolved)
override def innerChildren: Seq[LogicalPlan] = query :: Nil

override def run(sparkSession: SparkSession, child: SparkPlan): Seq[Row] = {
override def run(sparkSession: SparkSession): Seq[Row] = {
assert(table.tableType != CatalogTableType.VIEW)
assert(table.provider.isDefined)

@@ -187,7 +167,7 @@ }
}

saveDataIntoTable(
sparkSession, table, table.storage.locationUri, child, SaveMode.Append, tableExists = true)
sparkSession, table, table.storage.locationUri, SaveMode.Append, tableExists = true)
} else {
table.storage.locationUri.foreach { p =>
DataWritingCommand.assertEmptyRootPath(p, mode, sparkSession.sessionState.newHadoopConf)
@@ -200,7 +180,7 @@ }
table.storage.locationUri
}
val result = saveDataIntoTable(
sparkSession, table, tableLocation, child, SaveMode.Overwrite, tableExists = false)
sparkSession, table, tableLocation, SaveMode.Overwrite, tableExists = false)
val tableSchema = CharVarcharUtils.getRawSchema(result.schema, sessionState.conf)
val newTable = table.copy(
storage = table.storage.copy(locationUri = tableLocation),
@@ -232,7 +212,6 @@ case class CreateDataSourceTableAsSelectCommand(
session: SparkSession,
table: CatalogTable,
tableLocation: Option[URI],
physicalPlan: SparkPlan,
mode: SaveMode,
tableExists: Boolean): BaseRelation = {
// Create the relation based on the input logical plan: `query`.
@@ -246,14 +225,11 @@ catalogTable = if (tableExists) Some(table) else None)
catalogTable = if (tableExists) Some(table) else None)

try {
dataSource.writeAndRead(mode, query, outputColumnNames, physicalPlan, metrics)
dataSource.writeAndRead(mode, query, outputColumnNames)
} catch {
case ex: AnalysisException =>
logError(s"Failed to write to table ${table.identifier.unquotedString}", ex)
throw ex
}
}

override protected def withNewChildInternal(newChild: LogicalPlan): LogicalPlan =
copy(query = newChild)
}
@@ -36,7 +36,6 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, TypeUtils}
import org.apache.spark.sql.connector.catalog.TableProvider
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.command.DataWritingCommand
import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat
import org.apache.spark.sql.execution.datasources.jdbc.JdbcRelationProvider
@@ -45,7 +44,6 @@ import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat
import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2
import org.apache.spark.sql.execution.datasources.v2.orc.OrcDataSourceV2
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.execution.streaming.sources.{RateStreamProvider, TextSocketSourceProvider}
import org.apache.spark.sql.internal.SQLConf
@@ -97,8 +95,19 @@ case class DataSource(

case class SourceInfo(name: String, schema: StructType, partitionColumns: Seq[String])

lazy val providingClass: Class[_] =
DataSource.providingClass(className, sparkSession.sessionState.conf)
lazy val providingClass: Class[_] = {
val cls = DataSource.lookupDataSource(className, sparkSession.sessionState.conf)
// `providingClass` is used for resolving data source relation for catalog tables.
// As now catalog for data source V2 is under development, here we fall back all the
// [[FileDataSourceV2]] to [[FileFormat]] to guarantee the current catalog works.
// [[FileDataSourceV2]] will still be used if we call the load()/save() method in
// [[DataFrameReader]]/[[DataFrameWriter]], since they use method `lookupDataSource`
// instead of `providingClass`.
cls.newInstance() match {
case f: FileDataSourceV2 => f.fallbackFileFormat
case _ => cls
}
}

private[sql] def providingInstance(): Any = providingClass.getConstructor().newInstance()

@@ -483,17 +492,11 @@ case class DataSource(
* @param outputColumnNames The original output column names of the input query plan. The
* optimizer may not preserve the output column's names' case, so we need
* this parameter instead of `data.output`.
* @param physicalPlan The physical plan of the input query plan. We should run the writing
* command with this physical plan instead of creating a new physical plan,
* so that the metrics can be correctly linked to the given physical plan and
* shown in the web UI.
*/
def writeAndRead(
mode: SaveMode,
data: LogicalPlan,
outputColumnNames: Seq[String],
physicalPlan: SparkPlan,
metrics: Map[String, SQLMetric]): BaseRelation = {
outputColumnNames: Seq[String]): BaseRelation = {
val outputColumns = DataWritingCommand.logicalPlanOutputWithNames(data, outputColumnNames)
providingInstance() match {
case dataSource: CreatableRelationProvider =>
@@ -503,13 +506,8 @@ case class DataSource(
case format: FileFormat =>
disallowWritingIntervals(outputColumns.map(_.dataType), forbidAnsiIntervals = false)
val cmd = planForWritingFileFormat(format, mode, data)
val resolvedPartCols =
DataSource.resolvePartitionColumns(cmd.partitionColumns, outputColumns, data, equality)
val resolved = cmd.copy(
partitionColumns = resolvedPartCols,
outputColumnNames = outputColumnNames)
resolved.run(sparkSession, physicalPlan)
DataWritingCommand.propogateMetrics(sparkSession.sparkContext, resolved, metrics)
val qe = sparkSession.sessionState.executePlan(cmd)
qe.assertCommandExecuted()
// Replace the schema with that of the DataFrame we just wrote out to avoid re-inferring
copy(userSpecifiedSchema = Some(outputColumns.toStructType.asNullable)).resolveRelation()
case _ => throw new IllegalStateException(
@@ -832,18 +830,4 @@ object DataSource extends Logging {
}
}
}

def providingClass(className: String, conf: SQLConf): Class[_] = {
val cls = DataSource.lookupDataSource(className, conf)
// `providingClass` is used for resolving data source relation for catalog tables.
// As now catalog for data source V2 is under development, here we fall back all the
// [[FileDataSourceV2]] to [[FileFormat]] to guarantee the current catalog works.
// [[FileDataSourceV2]] will still be used if we call the load()/save() method in
// [[DataFrameReader]]/[[DataFrameWriter]], since they use method `lookupDataSource`
// instead of `providingClass`.
cls.newInstance() match {
case f: FileDataSourceV2 => f.fallbackFileFormat
case _ => cls
}
}
}
@@ -31,11 +31,6 @@ import org.apache.spark.sql.types.StringType
import org.apache.spark.unsafe.types.UTF8String

trait V1WriteCommand extends DataWritingCommand {
/**
* Return if the provider is [[FileFormat]]
*/
def fileFormatProvider: Boolean = true

/**
* Specify the partition columns of the V1 write command.
*/
@@ -58,8 +53,7 @@ object V1Writes extends Rule[LogicalPlan] with SQLConfHelper {
override def apply(plan: LogicalPlan): LogicalPlan = {
if (conf.plannedWriteEnabled) {
plan.transformUp {
case write: V1WriteCommand if write.fileFormatProvider &&
!write.child.isInstanceOf[WriteFiles] =>
case write: V1WriteCommand if !write.child.isInstanceOf[WriteFiles] =>
val newQuery = prepareQuery(write, write.query)
val attrMap = AttributeMap(write.query.output.zip(newQuery.output))
val newChild = WriteFiles(newQuery)
@@ -250,7 +250,12 @@ class ExplainSuite extends ExplainSuiteHelper with DisableAdaptiveExecutionSuite
withTable("temptable") {
val df = sql("create table temptable using parquet as select * from range(2)")
withNormalizedExplain(df, SimpleMode) { normalizedOutput =>
assert("Create\\w*?TableAsSelectCommand".r.findAllMatchIn(normalizedOutput).length == 1)
// scalastyle:off
// == Physical Plan ==
// Execute CreateDataSourceTableAsSelectCommand
// +- CreateDataSourceTableAsSelectCommand `spark_catalog`.`default`.`temptable`, ErrorIfExists, Project [id#5L], [id]
// scalastyle:on
assert("Create\\w*?TableAsSelectCommand".r.findAllMatchIn(normalizedOutput).length == 2)
}
}
}
@@ -29,15 +29,15 @@ import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent, SparkListe
import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession, Strategy}
import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
import org.apache.spark.sql.execution.{CollectLimitExec, CommandResultExec, LocalTableScanExec, PartialReducerPartitionSpec, QueryExecution, ReusedSubqueryExec, ShuffledRowRDD, SortExec, SparkPlan, UnionExec}
import org.apache.spark.sql.execution.{CollectLimitExec, LocalTableScanExec, PartialReducerPartitionSpec, QueryExecution, ReusedSubqueryExec, ShuffledRowRDD, SortExec, SparkPlan, SparkPlanInfo, UnionExec}
import org.apache.spark.sql.execution.aggregate.BaseAggregateExec
import org.apache.spark.sql.execution.command.DataWritingCommandExec
import org.apache.spark.sql.execution.datasources.noop.NoopDataSource
import org.apache.spark.sql.execution.datasources.v2.V2TableWriteExec
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ENSURE_REQUIREMENTS, Exchange, REPARTITION_BY_COL, REPARTITION_BY_NUM, ReusedExchangeExec, ShuffleExchangeExec, ShuffleExchangeLike, ShuffleOrigin}
import org.apache.spark.sql.execution.joins.{BaseJoinExec, BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, ShuffledHashJoinExec, ShuffledJoin, SortMergeJoinExec}
import org.apache.spark.sql.execution.metric.SQLShuffleReadMetricsReporter
import org.apache.spark.sql.execution.ui.SparkListenerSQLAdaptiveExecutionUpdate
import org.apache.spark.sql.execution.ui.{SparkListenerSQLAdaptiveExecutionUpdate, SparkListenerSQLExecutionStart}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.PartitionOverwriteMode
@@ -1150,18 +1150,31 @@ class AdaptiveQueryExecSuite
SQLConf.ADAPTIVE_EXECUTION_FORCE_APPLY.key -> "true",
SQLConf.PLANNED_WRITE_ENABLED.key -> enabled.toString) {
withTable("t1") {
val df = sql("CREATE TABLE t1 USING parquet AS SELECT 1 col")
val plan = df.queryExecution.executedPlan
assert(plan.isInstanceOf[CommandResultExec])
val commandPhysicalPlan = plan.asInstanceOf[CommandResultExec].commandPhysicalPlan
if (enabled) {
assert(commandPhysicalPlan.isInstanceOf[AdaptiveSparkPlanExec])
assert(commandPhysicalPlan.asInstanceOf[AdaptiveSparkPlanExec]
.executedPlan.isInstanceOf[DataWritingCommandExec])
} else {
assert(commandPhysicalPlan.isInstanceOf[DataWritingCommandExec])
assert(commandPhysicalPlan.asInstanceOf[DataWritingCommandExec]
.child.isInstanceOf[AdaptiveSparkPlanExec])
var checkDone = false
val listener = new SparkListener {
override def onOtherEvent(event: SparkListenerEvent): Unit = {
event match {
case SparkListenerSQLAdaptiveExecutionUpdate(_, _, planInfo) =>
if (enabled) {
assert(planInfo.nodeName == "AdaptiveSparkPlan")
assert(planInfo.children.size == 1)
assert(planInfo.children.head.nodeName ==
"Execute InsertIntoHadoopFsRelationCommand")
} else {
assert(planInfo.nodeName == "Execute InsertIntoHadoopFsRelationCommand")
}
checkDone = true
case _ => // ignore other events
}
}
}
spark.sparkContext.addSparkListener(listener)
try {
sql("CREATE TABLE t1 USING parquet AS SELECT 1 col").collect()
spark.sparkContext.listenerBus.waitUntilEmpty()
assert(checkDone)
} finally {
spark.sparkContext.removeSparkListener(listener)
}
}
}
@@ -1209,16 +1222,12 @@ class AdaptiveQueryExecSuite
withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
SQLConf.ADAPTIVE_EXECUTION_FORCE_APPLY.key -> "true") {
withTable("t1") {
var checkDone = false
var commands: Seq[SparkPlanInfo] = Seq.empty
val listener = new SparkListener {
override def onOtherEvent(event: SparkListenerEvent): Unit = {
event match {
case SparkListenerSQLAdaptiveExecutionUpdate(_, _, planInfo) =>
assert(planInfo.nodeName == "AdaptiveSparkPlan")
assert(planInfo.children.size == 1)
assert(planInfo.children.head.nodeName ==
"Execute CreateDataSourceTableAsSelectCommand")
checkDone = true
case start: SparkListenerSQLExecutionStart =>
commands = commands ++ Seq(start.sparkPlanInfo)
case _ => // ignore other events
}
}
@@ -1227,7 +1236,12 @@
try {
sql("CREATE TABLE t1 USING parquet AS SELECT 1 col").collect()
spark.sparkContext.listenerBus.waitUntilEmpty()
assert(checkDone)
assert(commands.size == 3)
assert(commands.head.nodeName == "Execute CreateDataSourceTableAsSelectCommand")
assert(commands(1).nodeName == "AdaptiveSparkPlan")
assert(commands(1).children.size == 1)
assert(commands(1).children.head.nodeName == "Execute InsertIntoHadoopFsRelationCommand")
assert(commands(2).nodeName == "CommandResult")
} finally {
spark.sparkContext.removeSparkListener(listener)
}
@@ -65,7 +65,7 @@ trait V1WriteCommandSuiteBase extends SQLTestUtils {
override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
qe.optimizedPlan match {
case w: V1WriteCommand =>
if (hasLogicalSort) {
if (hasLogicalSort && conf.getConf(SQLConf.PLANNED_WRITE_ENABLED)) {
assert(w.query.isInstanceOf[WriteFiles])
optimizedPlan = w.query.asInstanceOf[WriteFiles].child
} else {
@@ -86,16 +86,15 @@

sparkContext.listenerBus.waitUntilEmpty()

assert(optimizedPlan != null)
// Check whether a logical sort node is at the top of the logical plan of the write query.
if (optimizedPlan != null) {
assert(optimizedPlan.isInstanceOf[Sort] == hasLogicalSort,
s"Expect hasLogicalSort: $hasLogicalSort, Actual: ${optimizedPlan.isInstanceOf[Sort]}")
assert(optimizedPlan.isInstanceOf[Sort] == hasLogicalSort,
s"Expect hasLogicalSort: $hasLogicalSort, Actual: ${optimizedPlan.isInstanceOf[Sort]}")

// Check empty2null conversion.
val empty2nullExpr = optimizedPlan.exists(p => V1WritesUtils.hasEmptyToNull(p.expressions))
assert(empty2nullExpr == hasEmpty2Null,
s"Expect hasEmpty2Null: $hasEmpty2Null, Actual: $empty2nullExpr. Plan:\n$optimizedPlan")
}
// Check empty2null conversion.
val empty2nullExpr = optimizedPlan.exists(p => V1WritesUtils.hasEmptyToNull(p.expressions))
assert(empty2nullExpr == hasEmpty2Null,
s"Expect hasEmpty2Null: $hasEmpty2Null, Actual: $empty2nullExpr. Plan:\n$optimizedPlan")

spark.listenerManager.unregister(listener)
}