diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 20c1756ef4efa..d2a119556f7fb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -238,6 +238,13 @@ class Analyzer(override val catalogManager: CatalogManager)
errorOnExceed = true,
maxIterationsSetting = SQLConf.ANALYZER_MAX_ITERATIONS.key)
+ /**
+ * Override to provide rules that run before resolution. Note that these rules will be
+ * executed in a separate batch, which runs right before the normal resolution batch and
+ * applies its rules in a single pass.
+ */
+ val preResolutionRules: Seq[Rule[LogicalPlan]] = Nil
+
/**
* Override to provide additional rules for the "Resolution" batch.
*/
@@ -276,6 +283,8 @@ class Analyzer(override val catalogManager: CatalogManager)
LookupFunctions),
Batch("Keep Legacy Outputs", Once,
KeepLegacyOutputs),
+ Batch("PreResolution", Once,
+ preResolutionRules: _*),
Batch("Resolution", fixedPoint,
ResolveTableValuedFunctions(v1SessionCatalog) ::
ResolveNamespace(catalogManager) ::
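
For context, a rule placed in the new "PreResolution" batch sees the plan before any
resolution has happened, and the batch applies its rules only once. A minimal sketch of
wiring such a rule in by subclassing the Analyzer (LogTablesRule and makeAnalyzer are
hypothetical, not part of this patch):

    import org.apache.spark.sql.catalyst.analysis.{Analyzer, UnresolvedRelation}
    import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
    import org.apache.spark.sql.catalyst.rules.Rule
    import org.apache.spark.sql.connector.catalog.CatalogManager

    // Hypothetical rule: observes still-unresolved relations before the
    // Resolution batch runs, leaving the plan unchanged.
    object LogTablesRule extends Rule[LogicalPlan] {
      override def apply(plan: LogicalPlan): LogicalPlan = {
        plan.foreach {
          case u: UnresolvedRelation =>
            logInfo(s"about to resolve: ${u.multipartIdentifier.mkString(".")}")
          case _ =>
        }
        plan
      }
    }

    def makeAnalyzer(catalogManager: CatalogManager): Analyzer =
      new Analyzer(catalogManager) {
        override val preResolutionRules: Seq[Rule[LogicalPlan]] = LogTablesRule :: Nil
      }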
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala
index a8ccc39ac478f..bc68d0e1800ab 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala
@@ -38,7 +38,9 @@ import org.apache.spark.sql.execution.{ColumnarRule, SparkPlan}
* This currently provides the following extension points:
*
*
+ * - Pre-resolution Analyzer Rules.
* - Analyzer Rules.
+ * - Post-hoc Analyzer Rules.
* - Check Analysis Rules.
* - Optimizer Rules.
* - Pre CBO Rules.
@@ -165,6 +167,23 @@ class SparkSessionExtensions {
runtimeOptimizerRules += builder
}
+ private[this] val preResolutionRuleBuilders = mutable.Buffer.empty[RuleBuilder]
+
+ /**
+ * Build the analyzer pre-resolution `Rule`s using the given [[SparkSession]].
+ */
+ private[sql] def buildPreResolutionRules(session: SparkSession): Seq[Rule[LogicalPlan]] = {
+ preResolutionRuleBuilders.map(_.apply(session)).toSeq
+ }
+
+ /**
+ * Inject an analyzer pre-resolution `Rule` builder into the [[SparkSession]]. These analyzer
+ * rules will be executed once, in a dedicated batch that runs before the resolution phase of analysis.
+ */
+ def injectPreResolutionRule(builder: RuleBuilder): Unit = {
+ preResolutionRuleBuilders += builder
+ }
+
private[this] val resolutionRuleBuilders = mutable.Buffer.empty[RuleBuilder]
/**
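
From the user's side, registering a pre-resolution rule goes through the new
injection point. A minimal sketch (MyPreRule is a hypothetical no-op rule, not
part of this patch):

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
    import org.apache.spark.sql.catalyst.rules.Rule

    // Hypothetical pre-resolution rule; a real one would inspect or rewrite
    // the still-unresolved plan.
    case class MyPreRule(spark: SparkSession) extends Rule[LogicalPlan] {
      override def apply(plan: LogicalPlan): LogicalPlan = plan
    }

    val spark = SparkSession.builder()
      .master("local[1]")
      .withExtensions(_.injectPreResolutionRule(MyPreRule))
      .getOrCreate()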
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
index f3cbb789a94b5..6a8a91b371324 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
@@ -181,6 +181,9 @@ abstract class BaseSessionStateBuilder(
* Note: this depends on the `conf` and `catalog` fields.
*/
protected def analyzer: Analyzer = new Analyzer(catalogManager) {
+
+ override val preResolutionRules: Seq[Rule[LogicalPlan]] = customPreResolutionRules
+
override val extendedResolutionRules: Seq[Rule[LogicalPlan]] =
new FindDataSourceTable(session) +:
new ResolveSQLOnFile(session) +:
@@ -207,6 +210,16 @@ abstract class BaseSessionStateBuilder(
customCheckRules
}
+ /**
+ * Custom pre-resolution rules to add to the Analyzer. Prefer overriding this instead of
+ * creating your own Analyzer.
+ *
+ * Note that this may NOT depend on the `analyzer` function.
+ */
+ protected def customPreResolutionRules: Seq[Rule[LogicalPlan]] = {
+ extensions.buildPreResolutionRules(session)
+ }
+
/**
* Custom resolution rules to add to the Analyzer. Prefer overriding this instead of creating
* your own Analyzer.
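
A custom session state builder can also layer additional rules on top of the
extension-injected ones. A hypothetical sketch, reusing MyPreRule from the sketch
above (BaseSessionStateBuilder is marked @Unstable, so this is subject to change):

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
    import org.apache.spark.sql.catalyst.rules.Rule
    import org.apache.spark.sql.internal.{BaseSessionStateBuilder, SessionState}

    class MyStateBuilder(
        session: SparkSession,
        parentState: Option[SessionState] = None)
      extends BaseSessionStateBuilder(session, parentState) {

      // Prepend one extra rule to whatever the extensions injected.
      override protected def customPreResolutionRules: Seq[Rule[LogicalPlan]] =
        MyPreRule(session) +: super.customPreResolutionRules

      override protected def newBuilder: NewBuilder = new MyStateBuilder(_, _)
    }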
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
index 1aef458a3529f..a92e8e2f87352 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
@@ -20,11 +20,12 @@ import java.util.{Locale, UUID}
import scala.concurrent.Future
-import org.apache.spark.{MapOutputStatistics, SparkFunSuite, TaskContext}
+import org.apache.spark.{MapOutputStatistics, SparkException, SparkFunSuite, TaskContext}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow, TableIdentifier}
+import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParserInterface}
import org.apache.spark.sql.catalyst.plans.SQLHelper
@@ -32,6 +33,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{Limit, LocalRelation, Logica
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
+import org.apache.spark.sql.connector.catalog.{CatalogManager, LookupCatalog}
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, QueryStageExec}
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, BroadcastExchangeLike, ShuffleExchangeExec, ShuffleExchangeLike, ShuffleOrigin}
@@ -428,6 +430,26 @@ class SparkSessionExtensionSuite extends SparkFunSuite with SQLHelper {
}
}
}
+
+ test("SPARK-39194: inject pre resolution for catalog loading") {
+ val spark = SparkSession.builder()
+ .master("local[1]")
+ .withExtensions(MyExtensionsWithCatalog)
+ .getOrCreate()
+ try {
+ // org.apache.spark.YourCatalogClass does not exist on the classpath, so resolution
+ // fails while loading it; that failure proves the pre-resolution rule already
+ // registered the catalog before the Resolution batch ran.
+ val e1 = intercept[SparkException](spark.sql("select * from a.b.c"))
+ assert(e1.getMessage contains "org.apache.spark.YourCatalogClass",
+ "catalog should be pre-installed")
+ // A single-part namespace does not trigger the rule, so the usual error surfaces.
+ val e2 = intercept[AnalysisException](spark.sql("select * from b.c"))
+ assert(e2.getMessage contains "Table or view not found: b.c")
+ } finally {
+ stop(spark)
+ }
+ }
}
case class MyRule(spark: SparkSession) extends Rule[LogicalPlan] {
@@ -1053,3 +1075,23 @@ object AddLimit extends Rule[LogicalPlan] {
case _ => Limit(Literal(1), plan)
}
}
+
+object MyExtensionsWithCatalog extends SparkSessionExtensionsProvider {
+ override def apply(v1: SparkSessionExtensions): Unit = {
+ v1.injectPreResolutionRule(spark => new MyInjectCatalogs(spark))
+ }
+ class MyInjectCatalogs(spark: SparkSession) extends Rule[LogicalPlan] with LookupCatalog {
+ override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
+ case u @ UnresolvedRelation(CatalogAndIdentifier(catalog, ident), _, _)
+ if CatalogManager.SESSION_CATALOG_NAME == catalog.name() && ident.namespace().length > 1 =>
+ conf.setConfString(s"spark.sql.catalog.${ident.namespace().head}",
+ "org.apache.spark.YourCatalogClass")
+ u
+ }
+
+ override protected lazy val catalogManager: CatalogManager = spark.sessionState.catalogManager
+ }
+}
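
Design note on the test extension: MyInjectCatalogs deliberately leaves the plan
untouched and only mutates session configuration. The actual catalog loading and
relation resolution still happen in the normal Resolution batch; running the rule
once, before that batch, guarantees the catalog is registered by the time the
analyzer's relation resolution fires.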
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala
index 471f2c2303048..12ef2ea19221e 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala
@@ -83,6 +83,10 @@ class HiveSessionStateBuilder(
* A logical query plan `Analyzer` with rules specific to Hive.
*/
override protected def analyzer: Analyzer = new Analyzer(catalogManager) {
+
+ override val preResolutionRules: Seq[Rule[LogicalPlan]] =
+ customPreResolutionRules
+
override val extendedResolutionRules: Seq[Rule[LogicalPlan]] =
new ResolveHiveSerdeTable(session) +:
new FindDataSourceTable(session) +: