From c55a1f95491b10208ccd2cdf5910e6ec813c3522 Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Thu, 19 Jan 2017 22:04:44 +0800
Subject: [PATCH] add post-hoc resolution

---
 .../sql/catalyst/analysis/Analyzer.scala          |  8 +++++
 .../datasources/DataSourceStrategy.scala          | 25 ++--------------
 .../sql/execution/datasources/rules.scala         |  4 +--
 .../spark/sql/internal/SessionState.scala         |  8 +++--
 .../spark/sql/hive/HiveSessionState.scala         | 10 ++++---
 .../spark/sql/hive/HiveStrategies.scala           | 30 +++++--------------
 6 files changed, 30 insertions(+), 55 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 98851cb8557a3..cb56e94c0a77a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -106,6 +106,13 @@ class Analyzer(
    */
   val extendedResolutionRules: Seq[Rule[LogicalPlan]] = Nil
 
+  /**
+   * Override to provide rules to do post-hoc resolution. Note that these rules will be executed
+   * in an individual batch. This batch is to run right after the normal resolution batch and
+   * execute its rules in one pass.
+   */
+  val postHocResolutionRules: Seq[Rule[LogicalPlan]] = Nil
+
   lazy val batches: Seq[Batch] = Seq(
     Batch("Substitution", fixedPoint,
       CTESubstitution,
@@ -139,6 +146,7 @@ class Analyzer(
       ResolveInlineTables ::
       TypeCoercion.typeCoercionRules ++
       extendedResolutionRules : _*),
+    Batch("Post-Hoc Resolution", Once, postHocResolutionRules: _*),
     Batch("View", Once,
       AliasViewChild(conf)),
     Batch("Nondeterministic", Once,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index 21b07ee85adc8..19db293132f54 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -44,6 +44,8 @@ import org.apache.spark.unsafe.types.UTF8String
 /**
  * Replaces generic operations with specific variants that are designed to work with Spark
  * SQL Data Sources.
+ *
+ * Note that, this rule must be run after [[PreprocessTableInsertion]].
  */
 case class DataSourceAnalysis(conf: CatalystConf) extends Rule[LogicalPlan] {
 
@@ -127,30 +129,9 @@ case class DataSourceAnalysis(conf: CatalystConf) extends Rule[LogicalPlan] {
     projectList
   }
 
-  /**
-   * Returns true if the [[InsertIntoTable]] plan has already been preprocessed by analyzer rule
-   * [[PreprocessTableInsertion]]. It is important that this rule([[DataSourceAnalysis]]) has to
-   * be run after [[PreprocessTableInsertion]], to normalize the column names in partition spec and
-   * fix the schema mismatch by adding Cast.
-   */
-  private def hasBeenPreprocessed(
-      tableOutput: Seq[Attribute],
-      partSchema: StructType,
-      partSpec: Map[String, Option[String]],
-      query: LogicalPlan): Boolean = {
-    val partColNames = partSchema.map(_.name).toSet
-    query.resolved && partSpec.keys.forall(partColNames.contains) && {
-      val staticPartCols = partSpec.filter(_._2.isDefined).keySet
-      val expectedColumns = tableOutput.filterNot(a => staticPartCols.contains(a.name))
-      expectedColumns.toStructType.sameType(query.schema)
-    }
-  }
-
   override def apply(plan: LogicalPlan): LogicalPlan = plan transform {
     case InsertIntoTable(
-        l @ LogicalRelation(t: HadoopFsRelation, _, table), parts, query, overwrite, false)
-        if hasBeenPreprocessed(l.output, t.partitionSchema, parts, query) =>
-
+        l @ LogicalRelation(t: HadoopFsRelation, _, table), parts, query, overwrite, false) =>
       // If the InsertIntoTable command is for a partitioned HadoopFsRelation and
       // the user has specified static partitions, we add a Project operator on top of the query
       // to include those constant column values in the query result.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
index beacb08994430..87e7017aee3a4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
@@ -17,8 +17,6 @@
 
 package org.apache.spark.sql.execution.datasources
 
-import scala.util.control.NonFatal
-
 import org.apache.spark.sql.{AnalysisException, SaveMode, SparkSession}
 import org.apache.spark.sql.catalyst.analysis._
 import org.apache.spark.sql.catalyst.catalog._
@@ -385,7 +383,7 @@ case class PreprocessTableInsertion(conf: SQLConf) extends Rule[LogicalPlan] {
   }
 
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
-    case i @ InsertIntoTable(table, partition, child, _, _) if table.resolved && child.resolved =>
+    case i @ InsertIntoTable(table, _, child, _, _) if table.resolved && child.resolved =>
       table match {
         case relation: CatalogRelation =>
           val metadata = relation.catalogTable
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
index 64ec62f41d1f8..68b774b52fd7f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
@@ -114,12 +114,14 @@ private[sql] class SessionState(sparkSession: SparkSession) {
   lazy val analyzer: Analyzer = {
     new Analyzer(catalog, conf) {
       override val extendedResolutionRules =
-        AnalyzeCreateTable(sparkSession) ::
-        PreprocessTableInsertion(conf) ::
         new FindDataSourceTable(sparkSession) ::
-        DataSourceAnalysis(conf) ::
         new ResolveDataSource(sparkSession) :: Nil
 
+      override val postHocResolutionRules =
+        AnalyzeCreateTable(sparkSession) ::
+        PreprocessTableInsertion(conf) ::
+        DataSourceAnalysis(conf) :: Nil
+
       override val extendedCheckRules =
         Seq(PreWriteCheck(conf, catalog), HiveOnlyCheck)
     }
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala
index d3cef6e0cb0cf..9fd03ef8ba037 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala
@@ -62,15 +62,17 @@ private[hive] class HiveSessionState(sparkSession: SparkSession)
       override val extendedResolutionRules =
         catalog.ParquetConversions ::
         catalog.OrcConversions ::
-        AnalyzeCreateTable(sparkSession) ::
-        PreprocessTableInsertion(conf) ::
-        DataSourceAnalysis(conf) ::
         new DetermineHiveSerde(conf) ::
-        new HiveAnalysis(sparkSession) ::
         new FindDataSourceTable(sparkSession) ::
         new FindHiveSerdeTable(sparkSession) ::
         new ResolveDataSource(sparkSession) :: Nil
 
+      override val postHocResolutionRules =
+        AnalyzeCreateTable(sparkSession) ::
+        PreprocessTableInsertion(conf) ::
+        DataSourceAnalysis(conf) ::
+        new HiveAnalysis(sparkSession) :: Nil
+
       override val extendedCheckRules = Seq(PreWriteCheck(conf, catalog))
     }
   }
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
index 838e6f4008108..6cde783c5ae32 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
@@ -25,10 +25,9 @@ import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, LogicalPlan
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.command.DDLUtils
-import org.apache.spark.sql.execution.datasources.CreateTable
+import org.apache.spark.sql.execution.datasources.{CreateTable, PreprocessTableInsertion}
 import org.apache.spark.sql.hive.execution._
 import org.apache.spark.sql.internal.{HiveSerDe, SQLConf}
-import org.apache.spark.sql.types.StructType
 
 
 /**
@@ -78,10 +77,14 @@ class DetermineHiveSerde(conf: SQLConf) extends Rule[LogicalPlan] {
   }
 }
 
+/**
+ * Replaces generic operations with specific variants that are designed to work with Hive.
+ *
+ * Note that, this rule must be run after [[PreprocessTableInsertion]].
+ */
 class HiveAnalysis(session: SparkSession) extends Rule[LogicalPlan] {
   override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
-    case InsertIntoTable(table: MetastoreRelation, partSpec, query, overwrite, ifNotExists)
-        if hasBeenPreprocessed(table.output, table.partitionKeys.toStructType, partSpec, query) =>
+    case InsertIntoTable(table: MetastoreRelation, partSpec, query, overwrite, ifNotExists) =>
       InsertIntoHiveTable(table, partSpec, query, overwrite, ifNotExists)
 
     case CreateTable(tableDesc, mode, Some(query)) if DDLUtils.isHiveTable(tableDesc) =>
@@ -98,25 +101,6 @@ class HiveAnalysis(session: SparkSession) extends Rule[LogicalPlan] {
         query,
         mode == SaveMode.Ignore)
   }
-
-  /**
-   * Returns true if the [[InsertIntoTable]] plan has already been preprocessed by analyzer rule
-   * [[PreprocessTableInsertion]]. It is important that this rule([[HiveAnalysis]]) has to
-   * be run after [[PreprocessTableInsertion]], to normalize the column names in partition spec and
-   * fix the schema mismatch by adding Cast.
-   */
-  private def hasBeenPreprocessed(
-      tableOutput: Seq[Attribute],
-      partSchema: StructType,
-      partSpec: Map[String, Option[String]],
-      query: LogicalPlan): Boolean = {
-    val partColNames = partSchema.map(_.name).toSet
-    query.resolved && partSpec.keys.forall(partColNames.contains) && {
-      val staticPartCols = partSpec.filter(_._2.isDefined).keySet
-      val expectedColumns = tableOutput.filterNot(a => staticPartCols.contains(a.name))
-      expectedColumns.toStructType.sameType(query.schema)
-    }
-  }
 }
 
 /**
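
Below is a minimal sketch, not part of the commit above, of how a session could plug its own rule into the new single-pass "Post-Hoc Resolution" batch, mirroring the SessionState and HiveSessionState overrides in the patch. MyPostHocRule is a hypothetical name used only for illustration, and the commented-out wiring assumes the same Analyzer(catalog, conf) constructor shown in the hunks above.

import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule

// Hypothetical rule for illustration: it runs once, after the fixed-point "Resolution"
// batch, so any plan it sees is already resolved (partition column names normalized,
// casts added by PreprocessTableInsertion, etc.).
object MyPostHocRule extends Rule[LogicalPlan] {
  override def apply(plan: LogicalPlan): LogicalPlan = plan transform {
    // A real rule would rewrite resolved operators here; this placeholder is a no-op.
    case other => other
  }
}

// Wiring, following the same pattern the patch uses in SessionState/HiveSessionState:
//
//   lazy val analyzer: Analyzer = new Analyzer(catalog, conf) {
//     override val extendedResolutionRules = Nil                   // fixed-point batch
//     override val postHocResolutionRules = MyPostHocRule :: Nil   // one-pass batch
//   }

Keeping these rules out of the fixed-point "Resolution" batch means they are applied exactly once and only after resolution has converged, which is why DataSourceAnalysis and HiveAnalysis no longer need the hasBeenPreprocessed guard that this patch removes.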