[SPARK-13890][SQL] Remove some internal classes' dependency on SQLContext

## What changes were proposed in this pull request?
In general, it is better for internal classes not to depend on an external, user-facing class (in this case SQLContext), to reduce coupling between user-facing APIs and the internal implementations. This patch removes the SQLContext dependency from some internal classes such as SparkPlanner and SparkOptimizer.
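
For illustration, the constructor changes in this diff look roughly like this (signatures excerpted from the hunks below; class bodies omitted, so treat it as a sketch rather than compilable code):

```scala
// Before: internal classes received the whole user-facing SQLContext.
class SparkOptimizer(val sqlContext: SQLContext) extends Optimizer
class SparkPlanner(val sqlContext: SQLContext) extends SparkStrategies

// After: each class declares only the collaborators it actually needs.
class SparkOptimizer(experimentalMethods: ExperimentalMethods) extends Optimizer
class SparkPlanner(
    val sparkContext: SparkContext,
    val conf: SQLConf,
    val experimentalMethods: ExperimentalMethods)
  extends SparkStrategies
```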

As part of this patch, I also removed the following internal methods from SQLContext:
```
protected[sql] def functionRegistry: FunctionRegistry
protected[sql] def optimizer: Optimizer
protected[sql] def sqlParser: ParserInterface
protected[sql] def planner: SparkPlanner
protected[sql] def continuousQueryManager
protected[sql] def prepareForExecution: RuleExecutor[SparkPlan]
```
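
Call sites reach these components through `SessionState` instead. A minimal sketch of the replacement pattern, assuming a `SQLContext` named `sqlContext` and the `tableName`, `withCachedData`, and `sparkPlan` values that are in scope at the touched call sites:

```scala
// The removed SQLContext forwarders are replaced by explicit access via sessionState.
val tableIdent = sqlContext.sessionState.sqlParser.parseTableIdentifier(tableName)
val functions  = sqlContext.sessionState.functionRegistry.listFunction()
val optimized  = sqlContext.sessionState.optimizer.execute(withCachedData)
val executed   = sqlContext.sessionState.prepareForExecution.execute(sparkPlan)
```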

## How was this patch tested?
Existing unit/integration tests.

Author: Reynold Xin <rxin@databricks.com>

Closes #11712 from rxin/sqlContext-planner.
rxin committed Mar 15, 2016
1 parent a51f877 commit 276c2d5
Showing 28 changed files with 95 additions and 95 deletions.
@@ -394,7 +394,8 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging {
*/
def table(tableName: String): DataFrame = {
Dataset.newDataFrame(sqlContext,
sqlContext.catalog.lookupRelation(sqlContext.sqlParser.parseTableIdentifier(tableName)))
sqlContext.catalog.lookupRelation(
sqlContext.sessionState.sqlParser.parseTableIdentifier(tableName)))
}

/**
@@ -242,7 +242,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
options = extraOptions.toMap,
partitionColumns = normalizedParCols.getOrElse(Nil))

df.sqlContext.continuousQueryManager.startQuery(
df.sqlContext.sessionState.continuousQueryManager.startQuery(
extraOptions.getOrElse("queryName", StreamExecution.nextName), df, dataSource.createSink())
}

@@ -255,7 +255,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
* @since 1.4.0
*/
def insertInto(tableName: String): Unit = {
insertInto(df.sqlContext.sqlParser.parseTableIdentifier(tableName))
insertInto(df.sqlContext.sessionState.sqlParser.parseTableIdentifier(tableName))
}

private def insertInto(tableIdent: TableIdentifier): Unit = {
@@ -354,7 +354,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
* @since 1.4.0
*/
def saveAsTable(tableName: String): Unit = {
saveAsTable(df.sqlContext.sqlParser.parseTableIdentifier(tableName))
saveAsTable(df.sqlContext.sessionState.sqlParser.parseTableIdentifier(tableName))
}

private def saveAsTable(tableIdent: TableIdentifier): Unit = {
6 changes: 3 additions & 3 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -818,7 +818,7 @@ class Dataset[T] private[sql](
@scala.annotation.varargs
def selectExpr(exprs: String*): DataFrame = {
select(exprs.map { expr =>
Column(sqlContext.sqlParser.parseExpression(expr))
Column(sqlContext.sessionState.sqlParser.parseExpression(expr))
}: _*)
}

@@ -919,7 +919,7 @@ class Dataset[T] private[sql](
* @since 1.3.0
*/
def filter(conditionExpr: String): Dataset[T] = {
filter(Column(sqlContext.sqlParser.parseExpression(conditionExpr)))
filter(Column(sqlContext.sessionState.sqlParser.parseExpression(conditionExpr)))
}

/**
@@ -943,7 +943,7 @@ class Dataset[T] private[sql](
* @since 1.5.0
*/
def where(conditionExpr: String): Dataset[T] = {
filter(Column(sqlContext.sqlParser.parseExpression(conditionExpr)))
filter(Column(sqlContext.sessionState.sqlParser.parseExpression(conditionExpr)))
}

/**
@@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.rules.Rule
* @since 1.3.0
*/
@Experimental
class ExperimentalMethods protected[sql](sqlContext: SQLContext) {
class ExperimentalMethods private[sql]() {

/**
* Allows extra strategies to be injected into the query planner at runtime. Note this API
23 changes: 7 additions & 16 deletions sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -121,14 +121,7 @@ class SQLContext private[sql](
protected[sql] lazy val sessionState: SessionState = new SessionState(self)
protected[sql] def conf: SQLConf = sessionState.conf
protected[sql] def catalog: Catalog = sessionState.catalog
protected[sql] def functionRegistry: FunctionRegistry = sessionState.functionRegistry
protected[sql] def analyzer: Analyzer = sessionState.analyzer
protected[sql] def optimizer: Optimizer = sessionState.optimizer
protected[sql] def sqlParser: ParserInterface = sessionState.sqlParser
protected[sql] def planner: SparkPlanner = sessionState.planner
protected[sql] def continuousQueryManager = sessionState.continuousQueryManager
protected[sql] def prepareForExecution: RuleExecutor[SparkPlan] =
sessionState.prepareForExecution

/**
* An interface to register custom [[org.apache.spark.sql.util.QueryExecutionListener]]s
@@ -197,7 +190,7 @@ class SQLContext private[sql](
*/
def getAllConfs: immutable.Map[String, String] = conf.getAllConfs

protected[sql] def parseSql(sql: String): LogicalPlan = sqlParser.parsePlan(sql)
protected[sql] def parseSql(sql: String): LogicalPlan = sessionState.sqlParser.parsePlan(sql)

protected[sql] def executeSql(sql: String): QueryExecution = executePlan(parseSql(sql))

@@ -244,7 +237,7 @@ class SQLContext private[sql](
*/
@Experimental
@transient
val experimental: ExperimentalMethods = new ExperimentalMethods(this)
def experimental: ExperimentalMethods = sessionState.experimentalMethods

/**
* :: Experimental ::
@@ -641,7 +634,7 @@ class SQLContext private[sql](
tableName: String,
source: String,
options: Map[String, String]): DataFrame = {
val tableIdent = sqlParser.parseTableIdentifier(tableName)
val tableIdent = sessionState.sqlParser.parseTableIdentifier(tableName)
val cmd =
CreateTableUsing(
tableIdent,
@@ -687,7 +680,7 @@ class SQLContext private[sql](
source: String,
schema: StructType,
options: Map[String, String]): DataFrame = {
val tableIdent = sqlParser.parseTableIdentifier(tableName)
val tableIdent = sessionState.sqlParser.parseTableIdentifier(tableName)
val cmd =
CreateTableUsing(
tableIdent,
@@ -706,7 +699,7 @@ class SQLContext private[sql](
* only during the lifetime of this instance of SQLContext.
*/
private[sql] def registerDataFrameAsTable(df: DataFrame, tableName: String): Unit = {
catalog.registerTable(sqlParser.parseTableIdentifier(tableName), df.logicalPlan)
catalog.registerTable(sessionState.sqlParser.parseTableIdentifier(tableName), df.logicalPlan)
}

/**
@@ -800,7 +793,7 @@ class SQLContext private[sql](
* @since 1.3.0
*/
def table(tableName: String): DataFrame = {
table(sqlParser.parseTableIdentifier(tableName))
table(sessionState.sqlParser.parseTableIdentifier(tableName))
}

private def table(tableIdent: TableIdentifier): DataFrame = {
@@ -837,9 +830,7 @@ class SQLContext private[sql](
*
* @since 2.0.0
*/
def streams: ContinuousQueryManager = {
continuousQueryManager
}
def streams: ContinuousQueryManager = sessionState.continuousQueryManager

/**
* Returns the names of tables in the current database as an array.
@@ -45,16 +45,16 @@ class QueryExecution(val sqlContext: SQLContext, val logical: LogicalPlan) {
sqlContext.cacheManager.useCachedData(analyzed)
}

lazy val optimizedPlan: LogicalPlan = sqlContext.optimizer.execute(withCachedData)
lazy val optimizedPlan: LogicalPlan = sqlContext.sessionState.optimizer.execute(withCachedData)

lazy val sparkPlan: SparkPlan = {
SQLContext.setActive(sqlContext)
sqlContext.planner.plan(ReturnAnswer(optimizedPlan)).next()
sqlContext.sessionState.planner.plan(ReturnAnswer(optimizedPlan)).next()
}

// executedPlan should not be used to initialize any SparkPlan. It should be
// only used for execution.
lazy val executedPlan: SparkPlan = sqlContext.prepareForExecution.execute(sparkPlan)
lazy val executedPlan: SparkPlan = sqlContext.sessionState.prepareForExecution.execute(sparkPlan)

/** Internal version of the RDD. Avoids copies and has no schema */
lazy val toRdd: RDD[InternalRow] = executedPlan.execute()
@@ -17,11 +17,10 @@

package org.apache.spark.sql.execution

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.optimizer._
import org.apache.spark.sql.ExperimentalMethods
import org.apache.spark.sql.catalyst.optimizer.Optimizer

class SparkOptimizer(val sqlContext: SQLContext)
extends Optimizer {
override def batches: Seq[Batch] = super.batches :+ Batch(
"User Provided Optimizers", FixedPoint(100), sqlContext.experimental.extraOptimizations: _*)
class SparkOptimizer(experimentalMethods: ExperimentalMethods) extends Optimizer {
override def batches: Seq[Batch] = super.batches :+ Batch(
"User Provided Optimizers", FixedPoint(100), experimentalMethods.extraOptimizations: _*)
}
@@ -21,14 +21,18 @@ import org.apache.spark.SparkContext
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, FileSourceStrategy}
import org.apache.spark.sql.internal.SQLConf

class SparkPlanner(val sqlContext: SQLContext) extends SparkStrategies {
val sparkContext: SparkContext = sqlContext.sparkContext
class SparkPlanner(
val sparkContext: SparkContext,
val conf: SQLConf,
val experimentalMethods: ExperimentalMethods)
extends SparkStrategies {

def numPartitions: Int = sqlContext.conf.numShufflePartitions
def numPartitions: Int = conf.numShufflePartitions

def strategies: Seq[Strategy] =
sqlContext.experimental.extraStrategies ++ (
experimentalMethods.extraStrategies ++ (
FileSourceStrategy ::
DataSourceStrategy ::
DDLStrategy ::
@@ -80,8 +80,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
*/
object CanBroadcast {
def unapply(plan: LogicalPlan): Option[LogicalPlan] = {
if (sqlContext.conf.autoBroadcastJoinThreshold > 0 &&
plan.statistics.sizeInBytes <= sqlContext.conf.autoBroadcastJoinThreshold) {
if (conf.autoBroadcastJoinThreshold > 0 &&
plan.statistics.sizeInBytes <= conf.autoBroadcastJoinThreshold) {
Some(plan)
} else {
None
@@ -19,7 +19,6 @@ package org.apache.spark.sql.execution

import org.apache.spark.broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen._
@@ -29,6 +28,7 @@ import org.apache.spark.sql.catalyst.util.toCommentSafeString
import org.apache.spark.sql.execution.aggregate.TungstenAggregate
import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, SortMergeJoin}
import org.apache.spark.sql.execution.metric.LongSQLMetricValue
import org.apache.spark.sql.internal.SQLConf

/**
* An interface for those physical operators that support codegen.
@@ -427,7 +427,7 @@ case class WholeStageCodegen(child: SparkPlan) extends UnaryNode with CodegenSup
/**
* Find the chained plans that support codegen, collapse them together as WholeStageCodegen.
*/
private[sql] case class CollapseCodegenStages(sqlContext: SQLContext) extends Rule[SparkPlan] {
case class CollapseCodegenStages(conf: SQLConf) extends Rule[SparkPlan] {

private def supportCodegen(e: Expression): Boolean = e match {
case e: LeafExpression => true
@@ -472,7 +472,7 @@ private[sql] case class CollapseCodegenStages(sqlContext: SQLContext) extends Ru
}

def apply(plan: SparkPlan): SparkPlan = {
if (sqlContext.conf.wholeStageEnabled) {
if (conf.wholeStageEnabled) {
insertWholeStageCodegen(plan)
} else {
plan
@@ -358,13 +358,14 @@ case class ShowFunctions(db: Option[String], pattern: Option[String]) extends Ru
case Some(p) =>
try {
val regex = java.util.regex.Pattern.compile(p)
sqlContext.functionRegistry.listFunction().filter(regex.matcher(_).matches()).map(Row(_))
sqlContext.sessionState.functionRegistry.listFunction()
.filter(regex.matcher(_).matches()).map(Row(_))
} catch {
// probably will failed in the regex that user provided, then returns empty row.
case _: Throwable => Seq.empty[Row]
}
case None =>
sqlContext.functionRegistry.listFunction().map(Row(_))
sqlContext.sessionState.functionRegistry.listFunction().map(Row(_))
}
}

@@ -395,7 +396,7 @@ case class DescribeFunction(
}

override def run(sqlContext: SQLContext): Seq[Row] = {
sqlContext.functionRegistry.lookupFunction(functionName) match {
sqlContext.sessionState.functionRegistry.lookupFunction(functionName) match {
case Some(info) =>
val result =
Row(s"Function: ${info.getName}") ::
@@ -17,11 +17,11 @@

package org.apache.spark.sql.execution.exchange

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution._
import org.apache.spark.sql.internal.SQLConf

/**
* Ensures that the [[org.apache.spark.sql.catalyst.plans.physical.Partitioning Partitioning]]
@@ -30,15 +30,15 @@ import org.apache.spark.sql.execution._
* each operator by inserting [[ShuffleExchange]] Operators where required. Also ensure that the
* input partition ordering requirements are met.
*/
private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[SparkPlan] {
private def defaultNumPreShufflePartitions: Int = sqlContext.conf.numShufflePartitions
case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
private def defaultNumPreShufflePartitions: Int = conf.numShufflePartitions

private def targetPostShuffleInputSize: Long = sqlContext.conf.targetPostShuffleInputSize
private def targetPostShuffleInputSize: Long = conf.targetPostShuffleInputSize

private def adaptiveExecutionEnabled: Boolean = sqlContext.conf.adaptiveExecutionEnabled
private def adaptiveExecutionEnabled: Boolean = conf.adaptiveExecutionEnabled

private def minNumPostShufflePartitions: Option[Int] = {
val minNumPostShufflePartitions = sqlContext.conf.minNumPostShufflePartitions
val minNumPostShufflePartitions = conf.minNumPostShufflePartitions
if (minNumPostShufflePartitions > 0) Some(minNumPostShufflePartitions) else None
}

@@ -22,11 +22,11 @@ import scala.collection.mutable.ArrayBuffer

import org.apache.spark.broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.{LeafNode, SparkPlan, UnaryNode}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.StructType

/**
@@ -64,10 +64,10 @@ case class ReusedExchange(override val output: Seq[Attribute], child: Exchange)
* Find out duplicated exchanges in the spark plan, then use the same exchange for all the
* references.
*/
private[sql] case class ReuseExchange(sqlContext: SQLContext) extends Rule[SparkPlan] {
case class ReuseExchange(conf: SQLConf) extends Rule[SparkPlan] {

def apply(plan: SparkPlan): SparkPlan = {
if (!sqlContext.conf.exchangeReuseEnabled) {
if (!conf.exchangeReuseEnabled) {
return plan
}
// Build a hash map using schema of exchanges to avoid O(N*N) sameResult calls.
@@ -17,12 +17,12 @@

package org.apache.spark.sql.execution

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.{expressions, InternalRow}
import org.apache.spark.sql.catalyst.expressions.{ExprId, Literal, SubqueryExpression}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.internal.SessionState
import org.apache.spark.sql.types.DataType

/**
@@ -62,12 +62,12 @@ case class ScalarSubquery(
/**
* Convert the subquery from logical plan into executed plan.
*/
case class PlanSubqueries(sqlContext: SQLContext) extends Rule[SparkPlan] {
case class PlanSubqueries(sessionState: SessionState) extends Rule[SparkPlan] {
def apply(plan: SparkPlan): SparkPlan = {
plan.transformAllExpressions {
case subquery: expressions.ScalarSubquery =>
val sparkPlan = sqlContext.planner.plan(ReturnAnswer(subquery.query)).next()
val executedPlan = sqlContext.prepareForExecution.execute(sparkPlan)
val sparkPlan = sessionState.planner.plan(ReturnAnswer(subquery.query)).next()
val executedPlan = sessionState.prepareForExecution.execute(sparkPlan)
ScalarSubquery(executedPlan, subquery.exprId)
}
}
@@ -1161,7 +1161,7 @@ object functions {
* @group normal_funcs
*/
def expr(expr: String): Column = {
val parser = SQLContext.getActive().map(_.sqlParser).getOrElse(new CatalystQl())
val parser = SQLContext.getActive().map(_.sessionState.sqlParser).getOrElse(new CatalystQl())
Column(parser.parseExpression(expr))
}
