Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,13 @@ class Analyzer(override val catalogManager: CatalogManager)
errorOnExceed = true,
maxIterationsSetting = SQLConf.ANALYZER_MAX_ITERATIONS.key)

/**
 * Override to provide rules that run before the main resolution batch ("pre-resolution").
 * Note that these rules are executed in their own dedicated batch, placed immediately
 * before the "Resolution" batch, and that batch runs its rules in a single pass
 * (strategy `Once`, not to a fixed point).
 */
val preResolutionRules: Seq[Rule[LogicalPlan]] = Nil

/**
* Override to provide additional rules for the "Resolution" batch.
*/
Expand Down Expand Up @@ -276,6 +283,8 @@ class Analyzer(override val catalogManager: CatalogManager)
LookupFunctions),
Batch("Keep Legacy Outputs", Once,
KeepLegacyOutputs),
Batch("PreResolution", Once,
preResolutionRules: _*),
Batch("Resolution", fixedPoint,
ResolveTableValuedFunctions(v1SessionCatalog) ::
ResolveNamespace(catalogManager) ::
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,9 @@ import org.apache.spark.sql.execution.{ColumnarRule, SparkPlan}
* This current provides the following extension points:
*
* <ul>
* <li>Pre Analyzer Rules.</li>
* <li>Analyzer Rules.</li>
* <li>Post-hoc Analyzer Rules.</li>
* <li>Check Analysis Rules.</li>
* <li>Optimizer Rules.</li>
* <li>Pre CBO Rules.</li>
Expand Down Expand Up @@ -165,6 +167,23 @@ class SparkSessionExtensions {
runtimeOptimizerRules += builder
}

// Builders registered via injectPreResolutionRule; materialized per-session below.
private[this] val preResolutionRuleBuilders = mutable.Buffer.empty[RuleBuilder]

/**
 * Build the analyzer pre-resolution `Rule`s using the given [[SparkSession]].
 */
private[sql] def buildPreResolutionRules(session: SparkSession): Seq[Rule[LogicalPlan]] = {
  preResolutionRuleBuilders.map(_.apply(session)).toSeq
}

/**
 * Inject an analyzer pre-resolution `Rule` builder into the [[SparkSession]]. These rules
 * are executed in their own "PreResolution" batch, which runs once, immediately before
 * the analyzer's main resolution batch.
 */
def injectPreResolutionRule(builder: RuleBuilder): Unit = {
  preResolutionRuleBuilders += builder
}

private[this] val resolutionRuleBuilders = mutable.Buffer.empty[RuleBuilder]

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,9 @@ abstract class BaseSessionStateBuilder(
* Note: this depends on the `conf` and `catalog` fields.
*/
protected def analyzer: Analyzer = new Analyzer(catalogManager) {

override val preResolutionRules: Seq[Rule[LogicalPlan]] = customPreResolutionRules

override val extendedResolutionRules: Seq[Rule[LogicalPlan]] =
new FindDataSourceTable(session) +:
new ResolveSQLOnFile(session) +:
Expand All @@ -207,6 +210,16 @@ abstract class BaseSessionStateBuilder(
customCheckRules
}

/**
 * Custom pre-resolution rules to add to the Analyzer. Prefer overriding this instead of
 * creating your own Analyzer.
 *
 * Note that this may NOT depend on the `analyzer` function.
 */
protected def customPreResolutionRules: Seq[Rule[LogicalPlan]] = {
  extensions.buildPreResolutionRules(session)
}

/**
* Custom resolution rules to add to the Analyzer. Prefer overriding this instead of creating
* your own Analyzer.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,20 @@ import java.util.{Locale, UUID}

import scala.concurrent.Future

import org.apache.spark.{MapOutputStatistics, SparkFunSuite, TaskContext}
import org.apache.spark.{MapOutputStatistics, SparkException, SparkFunSuite, TaskContext}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow, TableIdentifier}
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParserInterface}
import org.apache.spark.sql.catalyst.plans.SQLHelper
import org.apache.spark.sql.catalyst.plans.logical.{Limit, LocalRelation, LogicalPlan, Statistics, UnresolvedHint}
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
import org.apache.spark.sql.connector.catalog.{CatalogManager, LookupCatalog}
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, QueryStageExec}
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, BroadcastExchangeLike, ShuffleExchangeExec, ShuffleExchangeLike, ShuffleOrigin}
Expand Down Expand Up @@ -428,6 +430,22 @@ class SparkSessionExtensionSuite extends SparkFunSuite with SQLHelper {
}
}
}

test("SPARK-39194: inject pre resolution for catalog loading") {
  // Build a session whose extensions register a pre-resolution rule that installs
  // a catalog conf entry for multi-part identifiers before resolution runs.
  val session = SparkSession.builder()
    .master("local[1]")
    .withExtensions(MyExtensionsWithCatalog)
    .getOrCreate()
  try {
    // The rule registers "a" as a catalog backed by a bogus class, so the query
    // fails while loading that class — proving the rule ran before catalog lookup.
    val catalogError = intercept[SparkException](session.sql("select * from a.b.c"))
    assert(catalogError.getMessage contains "org.apache.spark.YourCatalogClass",
      "catalog shall be pre installed")
    // A two-part identifier is untouched by the rule and fails normal resolution.
    val resolutionError = intercept[AnalysisException](session.sql("select * from b.c"))
    assert(resolutionError.getMessage contains "Table or view not found: b.c")
  } finally {
    stop(session)
  }
}
}

case class MyRule(spark: SparkSession) extends Rule[LogicalPlan] {
Expand Down Expand Up @@ -1053,3 +1071,20 @@ object AddLimit extends Rule[LogicalPlan] {
case _ => Limit(Literal(1), plan)
}
}

/**
 * Extensions provider that registers a pre-resolution rule. The rule, for a relation
 * addressed through the session catalog with a multi-part namespace, registers the
 * namespace head as a catalog by setting the corresponding `spark.sql.catalog.*`
 * conf entry; the plan itself is returned unchanged.
 */
object MyExtensionsWithCatalog extends SparkSessionExtensionsProvider {
  override def apply(v1: SparkSessionExtensions): Unit = {
    v1.injectPreResolutionRule(session => new MyInjectCatalogs(session))
  }

  class MyInjectCatalogs(spark: SparkSession) extends Rule[LogicalPlan] with LookupCatalog {
    override protected lazy val catalogManager: CatalogManager = spark.sessionState.catalogManager

    override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
      case rel @ UnresolvedRelation(CatalogAndIdentifier(resolvedCatalog, ident), _, _)
          if resolvedCatalog.name() == CatalogManager.SESSION_CATALOG_NAME &&
            ident.namespace().length > 1 =>
        // Register a (nonexistent) catalog class for the leading namespace part;
        // the subsequent lookup will then attempt — and fail — to instantiate it.
        conf.setConfString(s"spark.sql.catalog.${ident.namespace().head}",
          "org.apache.spark.YourCatalogClass")
        rel
    }
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,10 @@ class HiveSessionStateBuilder(
* A logical query plan `Analyzer` with rules specific to Hive.
*/
override protected def analyzer: Analyzer = new Analyzer(catalogManager) {

override val preResolutionRules: Seq[Rule[LogicalPlan]] =
customPreResolutionRules

override val extendedResolutionRules: Seq[Rule[LogicalPlan]] =
new ResolveHiveSerdeTable(session) +:
new FindDataSourceTable(session) +:
Expand Down