[SPARK-10810] [SPARK-10902] [SQL] Improve session management in SQL
This PR improves session management by replacing the thread-local-based approach with one SQLContext per session, introducing separate temporary tables and UDFs/UDAFs for each session.

A new SQLContext session can be created in either of two ways (a minimal sketch follows the list):

1) creating a new SQLContext
2) calling newSession() on an existing SQLContext
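A minimal sketch of both paths in Scala, assuming an existing SparkContext named sc (the variable names are illustrative; SQLContext(sc), newSession(), range, registerTempTable and table are the APIs as of this change):

import org.apache.spark.sql.SQLContext

val ctx1 = new SQLContext(sc)     // way 1: a brand-new context, i.e. a new session
val ctx2 = ctx1.newSession()      // way 2: a second session sharing sc and the cache

// Temporary tables are per-session after this change:
ctx1.range(10).registerTempTable("t")
ctx1.table("t").count()           // works: "t" is registered in ctx1
// ctx2.table("t")                // would fail to resolve: "t" is not visible in ctx2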

For HiveContext, in order to reduce the per-session cost, the classloader and Hive client are shared across the multiple sessions created by newSession().

The CacheManager is also shared by all sessions, so caching a table multiple times in different sessions will not create multiple copies of the in-memory cache; a sketch follows.
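A sketch of the shared cache, under the assumption that both sessions produce the same logical plan, for example by reading the same (purely illustrative) path:

val s1 = new SQLContext(sc)       // assumes an existing SparkContext sc
val s2 = s1.newSession()

val df = s1.read.parquet("/tmp/people")   // hypothetical data
df.cache()
df.count()                                // materializes the in-memory cache once

// s2 shares the CacheManager, so an equivalent plan should be answered from
// the same in-memory copy rather than being cached a second time.
s2.read.parquet("/tmp/people").count()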

Added jars are still shared by all sessions, because SparkContext itself does not support sessions.

cc marmbrus yhuai rxin

Author: Davies Liu <davies@databricks.com>

Closes #8909 from davies/sessions.
Davies Liu authored and davies committed Oct 9, 2015
1 parent 84ea287 commit 3390b40
Showing 22 changed files with 540 additions and 440 deletions.
22 changes: 21 additions & 1 deletion project/MimaExcludes.scala
@@ -79,7 +79,27 @@ object MimaExcludes {
ProblemFilters.exclude[MissingMethodProblem](
"org.apache.spark.ml.regression.LeastSquaresAggregator.add"),
ProblemFilters.exclude[MissingMethodProblem](
"org.apache.spark.ml.regression.LeastSquaresCostFun.this")
"org.apache.spark.ml.regression.LeastSquaresCostFun.this"),
ProblemFilters.exclude[MissingMethodProblem](
"org.apache.spark.sql.SQLContext.clearLastInstantiatedContext"),
ProblemFilters.exclude[MissingMethodProblem](
"org.apache.spark.sql.SQLContext.setLastInstantiatedContext"),
ProblemFilters.exclude[MissingClassProblem](
"org.apache.spark.sql.SQLContext$SQLSession"),
ProblemFilters.exclude[MissingMethodProblem](
"org.apache.spark.sql.SQLContext.detachSession"),
ProblemFilters.exclude[MissingMethodProblem](
"org.apache.spark.sql.SQLContext.tlSession"),
ProblemFilters.exclude[MissingMethodProblem](
"org.apache.spark.sql.SQLContext.defaultSession"),
ProblemFilters.exclude[MissingMethodProblem](
"org.apache.spark.sql.SQLContext.currentSession"),
ProblemFilters.exclude[MissingMethodProblem](
"org.apache.spark.sql.SQLContext.openSession"),
ProblemFilters.exclude[MissingMethodProblem](
"org.apache.spark.sql.SQLContext.setSession"),
ProblemFilters.exclude[MissingMethodProblem](
"org.apache.spark.sql.SQLContext.createSession")
)
case v if v.startsWith("1.5") =>
Seq(
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -51,23 +51,37 @@ class SimpleFunctionRegistry extends FunctionRegistry {
private val functionBuilders =
StringKeyHashMap[(ExpressionInfo, FunctionBuilder)](caseSensitive = false)

-  override def registerFunction(name: String, info: ExpressionInfo, builder: FunctionBuilder)
-  : Unit = {
+  override def registerFunction(
+      name: String,
+      info: ExpressionInfo,
+      builder: FunctionBuilder): Unit = synchronized {
functionBuilders.put(name, (info, builder))
}

override def lookupFunction(name: String, children: Seq[Expression]): Expression = {
-    val func = functionBuilders.get(name).map(_._2).getOrElse {
-      throw new AnalysisException(s"undefined function $name")
+    val func = synchronized {
+      functionBuilders.get(name).map(_._2).getOrElse {
+        throw new AnalysisException(s"undefined function $name")
+      }
}
func(children)
}

-  override def listFunction(): Seq[String] = functionBuilders.iterator.map(_._1).toList.sorted
+  override def listFunction(): Seq[String] = synchronized {
+    functionBuilders.iterator.map(_._1).toList.sorted
+  }

-  override def lookupFunction(name: String): Option[ExpressionInfo] = {
+  override def lookupFunction(name: String): Option[ExpressionInfo] = synchronized {
functionBuilders.get(name).map(_._1)
}

+  def copy(): SimpleFunctionRegistry = synchronized {
+    val registry = new SimpleFunctionRegistry
+    functionBuilders.iterator.foreach { case (name, (info, builder)) =>
+      registry.registerFunction(name, info, builder)
+    }
+    registry
+  }
}

@@ -257,7 +271,7 @@ object FunctionRegistry {
expression[InputFileName]("input_file_name")
)

-  val builtin: FunctionRegistry = {
+  val builtin: SimpleFunctionRegistry = {
val fr = new SimpleFunctionRegistry
expressions.foreach { case (name, (info, builder)) => fr.registerFunction(name, info, builder) }
fr
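Registration and lookup are synchronized above because a registry may now be used from more than one thread, and the new copy() seeds a fresh registry with everything already registered: this is what backs the per-session registries in SQLContext below, where functionRegistry becomes FunctionRegistry.builtin.copy(). A tiny sketch of the intended pattern (the variable name is illustrative):

// Each session clones the builtin registry, so functions registered on the
// clone stay local to that session and never touch the shared builtin.
val sessionRegistry = FunctionRegistry.builtin.copy()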
164 changes: 93 additions & 71 deletions sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -30,6 +30,7 @@ import org.apache.spark.SparkContext
import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
import org.apache.spark.rdd.RDD
+import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd}
import org.apache.spark.sql.SQLConf.SQLConfEntry
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.errors.DialectException
@@ -38,15 +39,12 @@ import org.apache.spark.sql.catalyst.optimizer.{DefaultOptimizer, Optimizer}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.catalyst.{InternalRow, ParserDialect, _}
import org.apache.spark.sql.execution.{Filter, _}
import org.apache.spark.sql.{execution => sparkexecution}
import org.apache.spark.sql.execution._
import org.apache.spark.sql.sources._
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.ui.{SQLListener, SQLTab}
import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.sql.types._
import org.apache.spark.sql.{execution => sparkexecution}
import org.apache.spark.util.Utils

/**
@@ -64,18 +62,30 @@ import org.apache.spark.util.Utils
*
* @since 1.0.0
*/
-class SQLContext(@transient val sparkContext: SparkContext)
-  extends org.apache.spark.Logging
-  with Serializable {
+class SQLContext private[sql](
+    @transient val sparkContext: SparkContext,
+    @transient protected[sql] val cacheManager: CacheManager)
+  extends org.apache.spark.Logging with Serializable {

self =>

+  def this(sparkContext: SparkContext) = this(sparkContext, new CacheManager)
def this(sparkContext: JavaSparkContext) = this(sparkContext.sc)

+  /**
+   * Returns a SQLContext as a new session, with separated SQL configurations, temporary
+   * tables and registered functions, but sharing the same SparkContext and CacheManager.
+   *
+   * @since 1.6.0
+   */
+  def newSession(): SQLContext = {
+    new SQLContext(sparkContext, cacheManager)
+  }

/**
* @return Spark SQL configuration
*/
-  protected[sql] def conf = currentSession().conf
+  protected[sql] lazy val conf = new SQLConf

// `listener` should be only used in the driver
@transient private[sql] val listener = new SQLListener(this)
@@ -142,13 +152,11 @@ class SQLContext(@transient val sparkContext: SparkContext)
*/
def getAllConfs: immutable.Map[String, String] = conf.getAllConfs

-  // TODO how to handle the temp table per user session?
  @transient
  protected[sql] lazy val catalog: Catalog = new SimpleCatalog(conf)

-  // TODO how to handle the temp function per user session?
  @transient
-  protected[sql] lazy val functionRegistry: FunctionRegistry = FunctionRegistry.builtin
+  protected[sql] lazy val functionRegistry: FunctionRegistry = FunctionRegistry.builtin.copy()

@transient
protected[sql] lazy val analyzer: Analyzer =
@@ -198,20 +206,19 @@ class SQLContext(@transient val sparkContext: SparkContext)
protected[sql] def executePlan(plan: LogicalPlan) =
new sparkexecution.QueryExecution(this, plan)

-  @transient
-  protected[sql] val tlSession = new ThreadLocal[SQLSession]() {
-    override def initialValue: SQLSession = defaultSession
-  }
-
-  @transient
-  protected[sql] val defaultSession = createSession()
-
  protected[sql] def dialectClassName = if (conf.dialect == "sql") {
    classOf[DefaultParserDialect].getCanonicalName
  } else {
    conf.dialect
  }

+  /**
+   * Add a jar to SQLContext
+   */
+  protected[sql] def addJar(path: String): Unit = {
+    sparkContext.addJar(path)
+  }

{
// We extract spark sql settings from SparkContext's conf and put them to
// Spark SQL's conf.
@@ -236,9 +243,6 @@ class SQLContext(@transient val sparkContext: SparkContext)
}
}

-  @transient
-  protected[sql] val cacheManager = new CacheManager(this)

/**
* :: Experimental ::
* A collection of methods that are considered experimental, but can be used to hook into
@@ -300,21 +304,25 @@ class SQLContext(@transient val sparkContext: SparkContext)
* @group cachemgmt
* @since 1.3.0
*/
-  def isCached(tableName: String): Boolean = cacheManager.isCached(tableName)
+  def isCached(tableName: String): Boolean = {
+    cacheManager.lookupCachedData(table(tableName)).nonEmpty
+  }

/**
* Caches the specified table in-memory.
* @group cachemgmt
* @since 1.3.0
*/
-  def cacheTable(tableName: String): Unit = cacheManager.cacheTable(tableName)
+  def cacheTable(tableName: String): Unit = {
+    cacheManager.cacheQuery(table(tableName), Some(tableName))
+  }

/**
* Removes the specified table from the in-memory cache.
* @group cachemgmt
* @since 1.3.0
*/
-  def uncacheTable(tableName: String): Unit = cacheManager.uncacheTable(tableName)
+  def uncacheTable(tableName: String): Unit = cacheManager.uncacheQuery(table(tableName))

/**
* Removes all cached tables from the in-memory cache.
@@ -830,36 +838,6 @@ class SQLContext(@transient val sparkContext: SparkContext)
)
}

-  protected[sql] def openSession(): SQLSession = {
-    detachSession()
-    val session = createSession()
-    tlSession.set(session)
-
-    session
-  }
-
-  protected[sql] def currentSession(): SQLSession = {
-    tlSession.get()
-  }
-
-  protected[sql] def createSession(): SQLSession = {
-    new this.SQLSession()
-  }
-
-  protected[sql] def detachSession(): Unit = {
-    tlSession.remove()
-  }
-
-  protected[sql] def setSession(session: SQLSession): Unit = {
-    detachSession()
-    tlSession.set(session)
-  }
-
-  protected[sql] class SQLSession {
-    // Note that this is a lazy val so we can override the default value in subclasses.
-    protected[sql] lazy val conf: SQLConf = new SQLConf
-  }

@deprecated("use org.apache.spark.sql.QueryExecution", "1.6.0")
protected[sql] class QueryExecution(logical: LogicalPlan)
extends sparkexecution.QueryExecution(this, logical)
@@ -1196,46 +1174,90 @@ class SQLContext(@transient val sparkContext: SparkContext)
  // Register a successfully instantiated context to the singleton. This should be at the end of
  // the class definition so that the singleton is updated only if there is no exception in the
  // construction of the instance.
-  SQLContext.setLastInstantiatedContext(self)
+  sparkContext.addSparkListener(new SparkListener {
+    override def onApplicationEnd(applicationEnd: SparkListenerApplicationEnd): Unit = {
+      SQLContext.clearInstantiatedContext(self)
+    }
+  })
+
+  SQLContext.setInstantiatedContext(self)
}

/**
* This SQLContext object contains utility functions to create a singleton SQLContext instance,
- * or to get the last created SQLContext instance.
+ * or to get the created SQLContext instance.
+ *
+ * It also provides utility functions to support a preferred SQLContext per thread in
+ * multi-session scenarios: setActive sets a SQLContext for the current thread, and
+ * getOrCreate will return it instead of the global one.
*/
object SQLContext {

-  private val INSTANTIATION_LOCK = new Object()
+  /**
+   * The active SQLContext for the current thread.
+   */
+  private val activeContext: InheritableThreadLocal[SQLContext] =
+    new InheritableThreadLocal[SQLContext]

/**
-   * Reference to the last created SQLContext.
+   * Reference to the created SQLContext.
   */
-  @transient private val lastInstantiatedContext = new AtomicReference[SQLContext]()
+  @transient private val instantiatedContext = new AtomicReference[SQLContext]()

/**
* Get the singleton SQLContext if it exists or create a new one using the given SparkContext.
*
* This function can be used to create a singleton SQLContext object that can be shared across
* the JVM.
*
+ * If there is an active SQLContext for the current thread, it will be returned instead of
+ * the global one.
*
* @since 1.5.0
*/
def getOrCreate(sparkContext: SparkContext): SQLContext = {
-    INSTANTIATION_LOCK.synchronized {
-      if (lastInstantiatedContext.get() == null) {
+    val ctx = activeContext.get()
+    if (ctx != null) {
+      return ctx
+    }
+
+    synchronized {
+      val ctx = instantiatedContext.get()
+      if (ctx == null) {
        new SQLContext(sparkContext)
+      } else {
+        ctx
      }
    }
-    lastInstantiatedContext.get()
  }

-  private[sql] def clearLastInstantiatedContext(): Unit = {
-    INSTANTIATION_LOCK.synchronized {
-      lastInstantiatedContext.set(null)
-    }
+  private[sql] def clearInstantiatedContext(sqlContext: SQLContext): Unit = {
+    instantiatedContext.compareAndSet(sqlContext, null)
  }

-  private[sql] def setLastInstantiatedContext(sqlContext: SQLContext): Unit = {
-    INSTANTIATION_LOCK.synchronized {
-      lastInstantiatedContext.set(sqlContext)
-    }
+  private[sql] def setInstantiatedContext(sqlContext: SQLContext): Unit = {
+    instantiatedContext.compareAndSet(null, sqlContext)
  }

+  /**
+   * Changes the SQLContext that will be returned in this thread and its children when
+   * SQLContext.getOrCreate() is called. This can be used to ensure that a given thread receives
+   * a SQLContext with an isolated session, instead of the global (first created) context.
+   *
+   * @since 1.6.0
+   */
+  def setActive(sqlContext: SQLContext): Unit = {
+    activeContext.set(sqlContext)
+  }
+
+  /**
+   * Clears the active SQLContext for the current thread. Subsequent calls to getOrCreate will
+   * return the first created context instead of a thread-local override.
+   *
+   * @since 1.6.0
+   */
+  def clearActive(): Unit = {
+    activeContext.remove()
+  }

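A sketch of the thread-local pattern enabled by setActive/clearActive, assuming an existing SparkContext named sc:

val global = SQLContext.getOrCreate(sc)   // the shared singleton
val session = global.newSession()         // isolated conf, temp tables, UDFs

SQLContext.setActive(session)
try {
  // On this thread, getOrCreate now returns the session, not the singleton.
  assert(SQLContext.getOrCreate(sc) eq session)
} finally {
  SQLContext.clearActive()                // later calls see the singleton again
}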
