diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index 52b567ea250b1..76b8d71ac9359 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -394,7 +394,8 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging {
    */
   def table(tableName: String): DataFrame = {
     Dataset.newDataFrame(sqlContext,
-      sqlContext.catalog.lookupRelation(sqlContext.sqlParser.parseTableIdentifier(tableName)))
+      sqlContext.catalog.lookupRelation(
+        sqlContext.sessionState.sqlParser.parseTableIdentifier(tableName)))
   }
 
   /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index 3349b8421b3e8..de87f4d7c24ef 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -242,7 +242,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
         options = extraOptions.toMap,
         partitionColumns = normalizedParCols.getOrElse(Nil))
 
-    df.sqlContext.continuousQueryManager.startQuery(
+    df.sqlContext.sessionState.continuousQueryManager.startQuery(
       extraOptions.getOrElse("queryName", StreamExecution.nextName), df, dataSource.createSink())
   }
 
@@ -255,7 +255,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
    * @since 1.4.0
    */
   def insertInto(tableName: String): Unit = {
-    insertInto(df.sqlContext.sqlParser.parseTableIdentifier(tableName))
+    insertInto(df.sqlContext.sessionState.sqlParser.parseTableIdentifier(tableName))
   }
 
   private def insertInto(tableIdent: TableIdentifier): Unit = {
@@ -354,7 +354,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
    * @since 1.4.0
    */
   def saveAsTable(tableName: String): Unit = {
-    saveAsTable(df.sqlContext.sqlParser.parseTableIdentifier(tableName))
+    saveAsTable(df.sqlContext.sessionState.sqlParser.parseTableIdentifier(tableName))
   }
 
   private def saveAsTable(tableIdent: TableIdentifier): Unit = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index b5079cf2763ff..ef239a1e2f324 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -818,7 +818,7 @@ class Dataset[T] private[sql](
   @scala.annotation.varargs
   def selectExpr(exprs: String*): DataFrame = {
     select(exprs.map { expr =>
-      Column(sqlContext.sqlParser.parseExpression(expr))
+      Column(sqlContext.sessionState.sqlParser.parseExpression(expr))
     }: _*)
   }
 
@@ -919,7 +919,7 @@ class Dataset[T] private[sql](
    * @since 1.3.0
    */
   def filter(conditionExpr: String): Dataset[T] = {
-    filter(Column(sqlContext.sqlParser.parseExpression(conditionExpr)))
+    filter(Column(sqlContext.sessionState.sqlParser.parseExpression(conditionExpr)))
   }
 
   /**
@@ -943,7 +943,7 @@ class Dataset[T] private[sql](
    * @since 1.5.0
    */
   def where(conditionExpr: String): Dataset[T] = {
-    filter(Column(sqlContext.sqlParser.parseExpression(conditionExpr)))
+    filter(Column(sqlContext.sessionState.sqlParser.parseExpression(conditionExpr)))
   }
 
   /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala b/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala
index deed45d273c33..d7cd84fd246c4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala
@@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.rules.Rule
  * @since 1.3.0
  */
 @Experimental
-class ExperimentalMethods protected[sql](sqlContext: SQLContext) {
+class ExperimentalMethods private[sql]() {
 
   /**
    * Allows extra strategies to be injected into the query planner at runtime. Note this API
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 36fe57f78be1d..0f5d1c8cab519 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -121,14 +121,7 @@ class SQLContext private[sql](
   protected[sql] lazy val sessionState: SessionState = new SessionState(self)
   protected[sql] def conf: SQLConf = sessionState.conf
   protected[sql] def catalog: Catalog = sessionState.catalog
-  protected[sql] def functionRegistry: FunctionRegistry = sessionState.functionRegistry
   protected[sql] def analyzer: Analyzer = sessionState.analyzer
-  protected[sql] def optimizer: Optimizer = sessionState.optimizer
-  protected[sql] def sqlParser: ParserInterface = sessionState.sqlParser
-  protected[sql] def planner: SparkPlanner = sessionState.planner
-  protected[sql] def continuousQueryManager = sessionState.continuousQueryManager
-  protected[sql] def prepareForExecution: RuleExecutor[SparkPlan] =
-    sessionState.prepareForExecution
 
   /**
    * An interface to register custom [[org.apache.spark.sql.util.QueryExecutionListener]]s
@@ -197,7 +190,7 @@ class SQLContext private[sql](
    */
   def getAllConfs: immutable.Map[String, String] = conf.getAllConfs
 
-  protected[sql] def parseSql(sql: String): LogicalPlan = sqlParser.parsePlan(sql)
+  protected[sql] def parseSql(sql: String): LogicalPlan = sessionState.sqlParser.parsePlan(sql)
 
   protected[sql] def executeSql(sql: String): QueryExecution = executePlan(parseSql(sql))
 
@@ -244,7 +237,7 @@ class SQLContext private[sql](
    */
   @Experimental
   @transient
-  val experimental: ExperimentalMethods = new ExperimentalMethods(this)
+  def experimental: ExperimentalMethods = sessionState.experimentalMethods
 
   /**
    * :: Experimental ::
@@ -641,7 +634,7 @@ class SQLContext private[sql](
       tableName: String,
       source: String,
       options: Map[String, String]): DataFrame = {
-    val tableIdent = sqlParser.parseTableIdentifier(tableName)
+    val tableIdent = sessionState.sqlParser.parseTableIdentifier(tableName)
     val cmd =
       CreateTableUsing(
         tableIdent,
@@ -687,7 +680,7 @@ class SQLContext private[sql](
       source: String,
       schema: StructType,
       options: Map[String, String]): DataFrame = {
-    val tableIdent = sqlParser.parseTableIdentifier(tableName)
+    val tableIdent = sessionState.sqlParser.parseTableIdentifier(tableName)
     val cmd =
       CreateTableUsing(
         tableIdent,
@@ -706,7 +699,7 @@ class SQLContext private[sql](
    * only during the lifetime of this instance of SQLContext.
    */
   private[sql] def registerDataFrameAsTable(df: DataFrame, tableName: String): Unit = {
-    catalog.registerTable(sqlParser.parseTableIdentifier(tableName), df.logicalPlan)
+    catalog.registerTable(sessionState.sqlParser.parseTableIdentifier(tableName), df.logicalPlan)
   }
 
   /**
@@ -800,7 +793,7 @@ class SQLContext private[sql](
    * @since 1.3.0
    */
   def table(tableName: String): DataFrame = {
-    table(sqlParser.parseTableIdentifier(tableName))
+    table(sessionState.sqlParser.parseTableIdentifier(tableName))
   }
 
   private def table(tableIdent: TableIdentifier): DataFrame = {
@@ -837,9 +830,7 @@ class SQLContext private[sql](
    *
    * @since 2.0.0
    */
-  def streams: ContinuousQueryManager = {
-    continuousQueryManager
-  }
+  def streams: ContinuousQueryManager = sessionState.continuousQueryManager
 
   /**
    * Returns the names of tables in the current database as an array.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
index 9e60c1cd6141c..5b4254f741ab1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
@@ -45,16 +45,16 @@ class QueryExecution(val sqlContext: SQLContext, val logical: LogicalPlan) {
     sqlContext.cacheManager.useCachedData(analyzed)
   }
 
-  lazy val optimizedPlan: LogicalPlan = sqlContext.optimizer.execute(withCachedData)
+  lazy val optimizedPlan: LogicalPlan = sqlContext.sessionState.optimizer.execute(withCachedData)
 
   lazy val sparkPlan: SparkPlan = {
     SQLContext.setActive(sqlContext)
-    sqlContext.planner.plan(ReturnAnswer(optimizedPlan)).next()
+    sqlContext.sessionState.planner.plan(ReturnAnswer(optimizedPlan)).next()
   }
 
   // executedPlan should not be used to initialize any SparkPlan. It should be
   // only used for execution.
-  lazy val executedPlan: SparkPlan = sqlContext.prepareForExecution.execute(sparkPlan)
+  lazy val executedPlan: SparkPlan = sqlContext.sessionState.prepareForExecution.execute(sparkPlan)
 
   /** Internal version of the RDD. Avoids copies and has no schema */
   lazy val toRdd: RDD[InternalRow] = executedPlan.execute()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
index edaf3b36aa52e..cbde777d98415 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
@@ -17,11 +17,10 @@
 
 package org.apache.spark.sql.execution
 
-import org.apache.spark.sql.SQLContext
-import org.apache.spark.sql.catalyst.optimizer._
+import org.apache.spark.sql.ExperimentalMethods
+import org.apache.spark.sql.catalyst.optimizer.Optimizer
 
-class SparkOptimizer(val sqlContext: SQLContext)
-  extends Optimizer {
-  override def batches: Seq[Batch] = super.batches :+ Batch(
-    "User Provided Optimizers", FixedPoint(100), sqlContext.experimental.extraOptimizations: _*)
+class SparkOptimizer(experimentalMethods: ExperimentalMethods) extends Optimizer {
+  override def batches: Seq[Batch] = super.batches :+ Batch(
+    "User Provided Optimizers", FixedPoint(100), experimentalMethods.extraOptimizations: _*)
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala
index 292d366e727d3..9da2c74c62fc6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala
@@ -21,14 +21,18 @@ import org.apache.spark.SparkContext
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, FileSourceStrategy}
+import org.apache.spark.sql.internal.SQLConf
 
-class SparkPlanner(val sqlContext: SQLContext) extends SparkStrategies {
-  val sparkContext: SparkContext = sqlContext.sparkContext
+class SparkPlanner(
+    val sparkContext: SparkContext,
+    val conf: SQLConf,
+    val experimentalMethods: ExperimentalMethods)
+  extends SparkStrategies {
 
-  def numPartitions: Int = sqlContext.conf.numShufflePartitions
+  def numPartitions: Int = conf.numShufflePartitions
 
   def strategies: Seq[Strategy] =
-    sqlContext.experimental.extraStrategies ++ (
+    experimentalMethods.extraStrategies ++ (
       FileSourceStrategy ::
       DataSourceStrategy ::
       DDLStrategy ::
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 6352c48c76ea5..113cf9ae2f222 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -80,8 +80,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
      */
     object CanBroadcast {
       def unapply(plan: LogicalPlan): Option[LogicalPlan] = {
-        if (sqlContext.conf.autoBroadcastJoinThreshold > 0 &&
-          plan.statistics.sizeInBytes <= sqlContext.conf.autoBroadcastJoinThreshold) {
+        if (conf.autoBroadcastJoinThreshold > 0 &&
+          plan.statistics.sizeInBytes <= conf.autoBroadcastJoinThreshold) {
           Some(plan)
         } else {
           None
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
index 8fb4705581a38..81676d3ebb346 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
@@ -19,7 +19,6 @@ package org.apache.spark.sql.execution
 
 import org.apache.spark.broadcast
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.SQLContext
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen._
@@ -29,6 +28,7 @@ import org.apache.spark.sql.catalyst.util.toCommentSafeString
 import org.apache.spark.sql.execution.aggregate.TungstenAggregate
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, SortMergeJoin}
 import org.apache.spark.sql.execution.metric.LongSQLMetricValue
+import org.apache.spark.sql.internal.SQLConf
 
 /**
  * An interface for those physical operators that support codegen.
@@ -427,7 +427,7 @@ case class WholeStageCodegen(child: SparkPlan) extends UnaryNode with CodegenSup
 /**
  * Find the chained plans that support codegen, collapse them together as WholeStageCodegen.
  */
-private[sql] case class CollapseCodegenStages(sqlContext: SQLContext) extends Rule[SparkPlan] {
+case class CollapseCodegenStages(conf: SQLConf) extends Rule[SparkPlan] {
 
   private def supportCodegen(e: Expression): Boolean = e match {
     case e: LeafExpression => true
@@ -472,7 +472,7 @@ private[sql] case class CollapseCodegenStages(sqlContext: SQLContext) extends Ru
   }
 
   def apply(plan: SparkPlan): SparkPlan = {
-    if (sqlContext.conf.wholeStageEnabled) {
+    if (conf.wholeStageEnabled) {
       insertWholeStageCodegen(plan)
     } else {
       plan
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala
index 6e36a15a6d033..e711797c1b51a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala
@@ -358,13 +358,14 @@ case class ShowFunctions(db: Option[String], pattern: Option[String]) extends Ru
     case Some(p) =>
       try {
         val regex = java.util.regex.Pattern.compile(p)
-        sqlContext.functionRegistry.listFunction().filter(regex.matcher(_).matches()).map(Row(_))
+        sqlContext.sessionState.functionRegistry.listFunction()
+          .filter(regex.matcher(_).matches()).map(Row(_))
       } catch {
         // probably will failed in the regex that user provided, then returns empty row.
         case _: Throwable => Seq.empty[Row]
       }
     case None =>
-      sqlContext.functionRegistry.listFunction().map(Row(_))
+      sqlContext.sessionState.functionRegistry.listFunction().map(Row(_))
   }
 }
 
@@ -395,7 +396,7 @@ case class DescribeFunction(
   }
 
   override def run(sqlContext: SQLContext): Seq[Row] = {
-    sqlContext.functionRegistry.lookupFunction(functionName) match {
+    sqlContext.sessionState.functionRegistry.lookupFunction(functionName) match {
       case Some(info) =>
         val result =
           Row(s"Function: ${info.getName}") ::
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
index 709a4246365dd..4864db7f2ac9b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
@@ -17,11 +17,11 @@
 
 package org.apache.spark.sql.execution.exchange
 
-import org.apache.spark.sql.SQLContext
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution._
+import org.apache.spark.sql.internal.SQLConf
 
 /**
  * Ensures that the [[org.apache.spark.sql.catalyst.plans.physical.Partitioning Partitioning]]
@@ -30,15 +30,15 @@ import org.apache.spark.sql.execution._
  * each operator by inserting [[ShuffleExchange]] Operators where required. Also ensure that the
  * input partition ordering requirements are met.
  */
-private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[SparkPlan] {
-  private def defaultNumPreShufflePartitions: Int = sqlContext.conf.numShufflePartitions
+case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
+  private def defaultNumPreShufflePartitions: Int = conf.numShufflePartitions
 
-  private def targetPostShuffleInputSize: Long = sqlContext.conf.targetPostShuffleInputSize
+  private def targetPostShuffleInputSize: Long = conf.targetPostShuffleInputSize
 
-  private def adaptiveExecutionEnabled: Boolean = sqlContext.conf.adaptiveExecutionEnabled
+  private def adaptiveExecutionEnabled: Boolean = conf.adaptiveExecutionEnabled
 
   private def minNumPostShufflePartitions: Option[Int] = {
-    val minNumPostShufflePartitions = sqlContext.conf.minNumPostShufflePartitions
+    val minNumPostShufflePartitions = conf.minNumPostShufflePartitions
     if (minNumPostShufflePartitions > 0) Some(minNumPostShufflePartitions) else None
   }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala
index 12513e9106707..9eaadea1b11ff 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala
@@ -22,11 +22,11 @@ import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark.broadcast
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.SQLContext
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution.{LeafNode, SparkPlan, UnaryNode}
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.StructType
 
 /**
@@ -64,10 +64,10 @@ case class ReusedExchange(override val output: Seq[Attribute], child: Exchange)
  * Find out duplicated exchanges in the spark plan, then use the same exchange for all the
  * references.
  */
-private[sql] case class ReuseExchange(sqlContext: SQLContext) extends Rule[SparkPlan] {
+case class ReuseExchange(conf: SQLConf) extends Rule[SparkPlan] {
 
   def apply(plan: SparkPlan): SparkPlan = {
-    if (!sqlContext.conf.exchangeReuseEnabled) {
+    if (!conf.exchangeReuseEnabled) {
       return plan
     }
     // Build a hash map using schema of exchanges to avoid O(N*N) sameResult calls.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
index e6d7480b0422c..0d580703f5547 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
@@ -17,12 +17,12 @@
 
 package org.apache.spark.sql.execution
 
-import org.apache.spark.sql.SQLContext
 import org.apache.spark.sql.catalyst.{expressions, InternalRow}
 import org.apache.spark.sql.catalyst.expressions.{ExprId, Literal, SubqueryExpression}
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer}
 import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.internal.SessionState
 import org.apache.spark.sql.types.DataType
 
 /**
@@ -62,12 +62,12 @@ case class ScalarSubquery(
 /**
 * Convert the subquery from logical plan into executed plan.
 */
-case class PlanSubqueries(sqlContext: SQLContext) extends Rule[SparkPlan] {
+case class PlanSubqueries(sessionState: SessionState) extends Rule[SparkPlan] {
   def apply(plan: SparkPlan): SparkPlan = {
     plan.transformAllExpressions {
       case subquery: expressions.ScalarSubquery =>
-        val sparkPlan = sqlContext.planner.plan(ReturnAnswer(subquery.query)).next()
-        val executedPlan = sqlContext.prepareForExecution.execute(sparkPlan)
+        val sparkPlan = sessionState.planner.plan(ReturnAnswer(subquery.query)).next()
+        val executedPlan = sessionState.prepareForExecution.execute(sparkPlan)
         ScalarSubquery(executedPlan, subquery.exprId)
     }
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 326c1e5a7cc03..dd4aa9e93ae4a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1161,7 +1161,7 @@ object functions {
    * @group normal_funcs
    */
   def expr(expr: String): Column = {
-    val parser = SQLContext.getActive().map(_.sqlParser).getOrElse(new CatalystQl())
+    val parser = SQLContext.getActive().map(_.sessionState.sqlParser).getOrElse(new CatalystQl())
     Column(parser.parseExpression(expr))
   }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
index 98ada4d58af7e..e6be0ab3bc420 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.internal
 
-import org.apache.spark.sql.{ContinuousQueryManager, SQLContext, UDFRegistration}
+import org.apache.spark.sql.{ContinuousQueryManager, ExperimentalMethods, SQLContext, UDFRegistration}
 import org.apache.spark.sql.catalyst.analysis.{Analyzer, Catalog, FunctionRegistry, SimpleCatalog}
 import org.apache.spark.sql.catalyst.optimizer.Optimizer
 import org.apache.spark.sql.catalyst.parser.ParserInterface
@@ -40,6 +40,8 @@ private[sql] class SessionState(ctx: SQLContext) {
    */
   lazy val conf = new SQLConf
 
+  lazy val experimentalMethods = new ExperimentalMethods
+
   /**
    * Internal catalog for managing table and database states.
   */
@@ -73,7 +75,7 @@ private[sql] class SessionState(ctx: SQLContext) {
   /**
    * Logical query plan optimizer.
    */
-  lazy val optimizer: Optimizer = new SparkOptimizer(ctx)
+  lazy val optimizer: Optimizer = new SparkOptimizer(experimentalMethods)
 
   /**
    * Parser that extracts expressions, plans, table identifiers etc. from SQL texts.
@@ -83,7 +85,7 @@ private[sql] class SessionState(ctx: SQLContext) {
   /**
    * Planner that converts optimized logical plans to physical plans.
   */
-  lazy val planner: SparkPlanner = new SparkPlanner(ctx)
+  lazy val planner: SparkPlanner = new SparkPlanner(ctx.sparkContext, conf, experimentalMethods)
 
   /**
    * Prepares a planned [[SparkPlan]] for execution by inserting shuffle operations and internal
@@ -91,10 +93,10 @@ private[sql] class SessionState(ctx: SQLContext) {
    */
   lazy val prepareForExecution = new RuleExecutor[SparkPlan] {
     override val batches: Seq[Batch] = Seq(
-      Batch("Subquery", Once, PlanSubqueries(ctx)),
-      Batch("Add exchange", Once, EnsureRequirements(ctx)),
-      Batch("Whole stage codegen", Once, CollapseCodegenStages(ctx)),
-      Batch("Reuse duplicated exchanges", Once, ReuseExchange(ctx))
+      Batch("Subquery", Once, PlanSubqueries(SessionState.this)),
+      Batch("Add exchange", Once, EnsureRequirements(conf)),
+      Batch("Whole stage codegen", Once, CollapseCodegenStages(conf)),
+      Batch("Reuse duplicated exchanges", Once, ReuseExchange(conf))
     )
   }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index 2bd29ef19b649..50647c28402eb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -37,7 +37,7 @@ class JoinSuite extends QueryTest with SharedSQLContext {
     val x = testData2.as("x")
     val y = testData2.as("y")
     val join = x.join(y, $"x.a" === $"y.a", "inner").queryExecution.optimizedPlan
-    val planned = sqlContext.planner.EquiJoinSelection(join)
+    val planned = sqlContext.sessionState.planner.EquiJoinSelection(join)
     assert(planned.size === 1)
   }
 
@@ -139,7 +139,7 @@ class JoinSuite extends QueryTest with SharedSQLContext {
     val x = testData2.as("x")
     val y = testData2.as("y")
     val join = x.join(y, ($"x.a" === $"y.a") && ($"x.b" === $"y.b")).queryExecution.optimizedPlan
-    val planned = sqlContext.planner.EquiJoinSelection(join)
+    val planned = sqlContext.sessionState.planner.EquiJoinSelection(join)
     assert(planned.size === 1)
   }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala
index ec19d97d8cec2..2ad92b52c4ff0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala
@@ -76,6 +76,6 @@ class SQLContextSuite extends SparkFunSuite with SharedSparkContext{
   test("Catalyst optimization passes are modifiable at runtime") {
     val sqlContext = SQLContext.getOrCreate(sc)
     sqlContext.experimental.extraOptimizations = Seq(DummyRule)
-    assert(sqlContext.optimizer.batches.flatMap(_.rules).contains(DummyRule))
+    assert(sqlContext.sessionState.optimizer.batches.flatMap(_.rules).contains(DummyRule))
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 98d0008489f4d..836fb1ce853c8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -54,7 +54,8 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
   test("show functions") {
     def getFunctions(pattern: String): Seq[Row] = {
       val regex = java.util.regex.Pattern.compile(pattern)
-      sqlContext.functionRegistry.listFunction().filter(regex.matcher(_).matches()).map(Row(_))
+      sqlContext.sessionState.functionRegistry.listFunction()
+        .filter(regex.matcher(_).matches()).map(Row(_))
     }
     checkAnswer(sql("SHOW functions"), getFunctions(".*"))
     Seq("^c.*", ".*e$", "log.*", ".*date.*").foreach { pattern =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index ab0a7ff628962..88fbcda296cac 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -37,7 +37,7 @@ class PlannerSuite extends SharedSQLContext {
   setupTestData()
 
   private def testPartialAggregationPlan(query: LogicalPlan): Unit = {
-    val planner = sqlContext.planner
+    val planner = sqlContext.sessionState.planner
     import planner._
     val plannedOption = Aggregation(query).headOption
     val planned =
@@ -294,7 +294,7 @@ class PlannerSuite extends SharedSQLContext {
       requiredChildDistribution = Seq(distribution, distribution),
       requiredChildOrdering = Seq(Seq.empty, Seq.empty)
     )
-    val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan)
+    val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan)
     assertDistributionRequirementsAreSatisfied(outputPlan)
     if (outputPlan.collect { case e: ShuffleExchange => true }.isEmpty) {
       fail(s"Exchange should have been added:\n$outputPlan")
@@ -314,7 +314,7 @@ class PlannerSuite extends SharedSQLContext {
       requiredChildDistribution = Seq(distribution, distribution),
       requiredChildOrdering = Seq(Seq.empty, Seq.empty)
     )
-    val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan)
+    val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan)
     assertDistributionRequirementsAreSatisfied(outputPlan)
   }
 
@@ -332,7 +332,7 @@ class PlannerSuite extends SharedSQLContext {
       requiredChildDistribution = Seq(distribution, distribution),
       requiredChildOrdering = Seq(Seq.empty, Seq.empty)
     )
-    val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan)
+    val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan)
     assertDistributionRequirementsAreSatisfied(outputPlan)
     if (outputPlan.collect { case e: ShuffleExchange => true }.isEmpty) {
       fail(s"Exchange should have been added:\n$outputPlan")
@@ -352,7 +352,7 @@ class PlannerSuite extends SharedSQLContext {
       requiredChildDistribution = Seq(distribution, distribution),
       requiredChildOrdering = Seq(Seq.empty, Seq.empty)
     )
-    val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan)
+    val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan)
     assertDistributionRequirementsAreSatisfied(outputPlan)
     if (outputPlan.collect { case e: ShuffleExchange => true }.nonEmpty) {
       fail(s"Exchange should not have been added:\n$outputPlan")
@@ -375,7 +375,7 @@ class PlannerSuite extends SharedSQLContext {
       requiredChildDistribution = Seq(distribution, distribution),
       requiredChildOrdering = Seq(outputOrdering, outputOrdering)
     )
-    val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan)
+    val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan)
     assertDistributionRequirementsAreSatisfied(outputPlan)
     if (outputPlan.collect { case e: ShuffleExchange => true }.nonEmpty) {
       fail(s"No Exchanges should have been added:\n$outputPlan")
@@ -391,7 +391,7 @@ class PlannerSuite extends SharedSQLContext {
       requiredChildOrdering = Seq(Seq(orderingB)),
       requiredChildDistribution = Seq(UnspecifiedDistribution)
     )
-    val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan)
+    val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan)
     assertDistributionRequirementsAreSatisfied(outputPlan)
     if (outputPlan.collect { case s: Sort => true }.isEmpty) {
       fail(s"Sort should have been added:\n$outputPlan")
@@ -407,7 +407,7 @@ class PlannerSuite extends SharedSQLContext {
       requiredChildOrdering = Seq(Seq(orderingA)),
       requiredChildDistribution = Seq(UnspecifiedDistribution)
     )
-    val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan)
+    val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan)
     assertDistributionRequirementsAreSatisfied(outputPlan)
     if (outputPlan.collect { case s: Sort => true }.nonEmpty) {
       fail(s"No sorts should have been added:\n$outputPlan")
@@ -424,7 +424,7 @@ class PlannerSuite extends SharedSQLContext {
       requiredChildOrdering = Seq(Seq(orderingA, orderingB)),
       requiredChildDistribution = Seq(UnspecifiedDistribution)
     )
-    val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan)
+    val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan)
     assertDistributionRequirementsAreSatisfied(outputPlan)
     if (outputPlan.collect { case s: Sort => true }.isEmpty) {
       fail(s"Sort should have been added:\n$outputPlan")
@@ -443,7 +443,7 @@ class PlannerSuite extends SharedSQLContext {
         requiredChildOrdering = Seq(Seq.empty)),
       None)
 
-    val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan)
+    val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan)
     assertDistributionRequirementsAreSatisfied(outputPlan)
     if (outputPlan.collect { case e: ShuffleExchange => true }.size == 2) {
       fail(s"Topmost Exchange should have been eliminated:\n$outputPlan")
@@ -463,7 +463,7 @@ class PlannerSuite extends SharedSQLContext {
         requiredChildOrdering = Seq(Seq.empty)),
       None)
 
-    val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan)
+    val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan)
     assertDistributionRequirementsAreSatisfied(outputPlan)
     if (outputPlan.collect { case e: ShuffleExchange => true }.size == 1) {
       fail(s"Topmost Exchange should not have been eliminated:\n$outputPlan")
@@ -491,7 +491,7 @@ class PlannerSuite extends SharedSQLContext {
       shuffle,
       shuffle)
 
-    val outputPlan = ReuseExchange(sqlContext).apply(inputPlan)
+    val outputPlan = ReuseExchange(sqlContext.sessionState.conf).apply(inputPlan)
     if (outputPlan.collect { case e: ReusedExchange => true }.size != 1) {
       fail(s"Should re-use the shuffle:\n$outputPlan")
     }
@@ -507,7 +507,7 @@ class PlannerSuite extends SharedSQLContext {
       ShuffleExchange(finalPartitioning, inputPlan),
       ShuffleExchange(finalPartitioning, inputPlan))
 
-    val outputPlan2 = ReuseExchange(sqlContext).apply(inputPlan2)
+    val outputPlan2 = ReuseExchange(sqlContext.sessionState.conf).apply(inputPlan2)
     if (outputPlan2.collect { case e: ReusedExchange => true }.size != 2) {
       fail(s"Should re-use the two shuffles:\n$outputPlan2")
     }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
index aa928cfc8096f..ed0d3f56e5ca9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
@@ -233,7 +233,7 @@ object SparkPlanTest {
   private def executePlan(outputPlan: SparkPlan, sqlContext: SQLContext): Seq[Row] = {
     // A very simple resolver to make writing tests easier. In contrast to the real resolver
     // this is always case sensitive and does not try to handle scoping or complex type resolution.
-    val resolvedPlan = sqlContext.prepareForExecution.execute(
+    val resolvedPlan = sqlContext.sessionState.prepareForExecution.execute(
       outputPlan transform {
         case plan: SparkPlan =>
           val inputMap = plan.children.flatMap(_.output).map(a => (a.name, a)).toMap
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
index a256ee95a153c..6d5b777733f41 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
@@ -63,7 +63,8 @@ class BroadcastJoinSuite extends QueryTest with BeforeAndAfterAll {
     // Comparison at the end is for broadcast left semi join
     val joinExpression = df1("key") === df2("key") && df1("value") > df2("value")
     val df3 = df1.join(broadcast(df2), joinExpression, joinType)
-    val plan = EnsureRequirements(sqlContext).apply(df3.queryExecution.sparkPlan)
+    val plan =
+      EnsureRequirements(sqlContext.sessionState.conf).apply(df3.queryExecution.sparkPlan)
     assert(plan.collect { case p: T => p }.size === 1)
     plan.executeCollect()
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
index 7eb15249ebbd6..eeb44404e9e47 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
@@ -98,7 +98,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
         boundCondition,
         leftPlan,
         rightPlan)
-      EnsureRequirements(sqlContext).apply(broadcastJoin)
+      EnsureRequirements(sqlContext.sessionState.conf).apply(broadcastJoin)
     }
 
     def makeSortMergeJoin(
@@ -109,7 +109,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
         rightPlan: SparkPlan) = {
       val sortMergeJoin =
         joins.SortMergeJoin(leftKeys, rightKeys, boundCondition, leftPlan, rightPlan)
-      EnsureRequirements(sqlContext).apply(sortMergeJoin)
+      EnsureRequirements(sqlContext.sessionState.conf).apply(sortMergeJoin)
     }
 
     test(s"$testName using BroadcastHashJoin (build=left)") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
index 0d1c29fe574a6..45254864309eb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
@@ -98,7 +98,7 @@ class OuterJoinSuite extends SparkPlanTest with SharedSQLContext {
       extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) =>
         withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
           checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
-            EnsureRequirements(sqlContext).apply(
+            EnsureRequirements(sqlContext.sessionState.conf).apply(
               SortMergeOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right)),
             expectedAnswer.map(Row.fromTuple),
             sortAnswers = true)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala
index bc341db5571be..d8c9564f1e4fb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala
@@ -76,7 +76,7 @@ class SemiJoinSuite extends SparkPlanTest with SharedSQLContext {
     extractJoinParts().foreach { case (joinType, leftKeys, rightKeys, boundCondition, _, _) =>
       withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
         checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
-          EnsureRequirements(left.sqlContext).apply(
+          EnsureRequirements(left.sqlContext.sessionState.conf).apply(
            LeftSemiJoinHash(leftKeys, rightKeys, left, right, boundCondition)),
           expectedAnswer.map(Row.fromTuple),
           sortAnswers = true)
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
index 8244dd4230102..a78b7b0cc4961 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
@@ -348,12 +348,12 @@ class HiveContext private[hive](
    * @since 1.3.0
    */
  def refreshTable(tableName: String): Unit = {
-    val tableIdent = sqlParser.parseTableIdentifier(tableName)
+    val tableIdent = sessionState.sqlParser.parseTableIdentifier(tableName)
     catalog.refreshTable(tableIdent)
   }
 
   protected[hive] def invalidateTable(tableName: String): Unit = {
-    val tableIdent = sqlParser.parseTableIdentifier(tableName)
+    val tableIdent = sessionState.sqlParser.parseTableIdentifier(tableName)
     catalog.invalidateTable(tableIdent)
   }
 
@@ -367,7 +367,7 @@ class HiveContext private[hive](
    * @since 1.2.0
   */
  def analyze(tableName: String) {
-    val tableIdent = sqlParser.parseTableIdentifier(tableName)
+    val tableIdent = sessionState.sqlParser.parseTableIdentifier(tableName)
     val relation = EliminateSubqueryAliases(catalog.lookupRelation(tableIdent))
 
     relation match {
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala
index cbb6333336383..d9cd96d66f493 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala
@@ -74,11 +74,11 @@ private[hive] class HiveSessionState(ctx: HiveContext) extends SessionState(ctx)
    * Planner that takes into account Hive-specific strategies.
    */
   override lazy val planner: SparkPlanner = {
-    new SparkPlanner(ctx) with HiveStrategies {
+    new SparkPlanner(ctx.sparkContext, conf, experimentalMethods) with HiveStrategies {
       override val hiveContext = ctx
 
       override def strategies: Seq[Strategy] = {
-        ctx.experimental.extraStrategies ++ Seq(
+        experimentalMethods.extraStrategies ++ Seq(
          FileSourceStrategy,
          DataSourceStrategy,
          HiveCommandStrategy(ctx),
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala
index d77c88fa4b384..33c1bb059e2fe 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala
@@ -69,7 +69,7 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle
   private val df = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k")
 
   def tableDir: File = {
-    val identifier = hiveContext.sqlParser.parseTableIdentifier("bucketed_table")
+    val identifier = hiveContext.sessionState.sqlParser.parseTableIdentifier("bucketed_table")
     new File(URI.create(hiveContext.catalog.hiveDefaultTableFilePath(identifier)))
  }
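
Side note on the recurring pattern in this patch: every rule and planner touched above now receives only the state it actually reads (a SQLConf, an ExperimentalMethods, or the owning SessionState) instead of a whole SQLContext. Below is a minimal, self-contained Scala sketch of that dependency-narrowing idea; Conf, Plan and CollapseStages are illustrative stand-ins, not Spark classes.

  // Toy stand-ins; only the shape of the refactoring is real.
  final case class Conf(wholeStageEnabled: Boolean)

  sealed trait Plan
  final case class Scan(table: String) extends Plan
  final case class Codegened(child: Plan) extends Plan

  // The rule depends on a Conf, not on a whole context object, so a test can
  // construct it directly with whatever configuration it wants to exercise.
  final case class CollapseStages(conf: Conf) {
    def apply(plan: Plan): Plan =
      if (conf.wholeStageEnabled) Codegened(plan) else plan
  }

  object Demo extends App {
    println(CollapseStages(Conf(wholeStageEnabled = true))(Scan("t")))   // Codegened(Scan(t))
    println(CollapseStages(Conf(wholeStageEnabled = false))(Scan("t")))  // Scan(t)
  }

This is the same reason the test hunks above can now write EnsureRequirements(sqlContext.sessionState.conf) or ReuseExchange(sqlContext.sessionState.conf) rather than threading an entire SQLContext through each rule.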