[SPARK-15075][SPARK-15345][SQL] Clean up SparkSession builder and propagate config options to existing sessions if specified

## What changes were proposed in this pull request?
Currently SparkSession.Builder uses SQLContext.getOrCreate. It should probably be the other way around, i.e. all the core logic goes in SparkSession, and SQLContext just calls that. This patch does that.

This patch also makes sure config options specified in the builder are propagated to the existing (and of course the new) SparkSession.
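
For illustration only, here is a minimal sketch (not part of the patch) of the builder behavior described above, assuming a local master and an arbitrary example option:

```scala
import org.apache.spark.sql.SparkSession

// First builder call: no session exists yet, so one is created with these options.
val spark = SparkSession.builder()
  .master("local[2]")
  .config("spark.sql.shuffle.partitions", "4")
  .getOrCreate()

// Second builder call: the existing session is returned, and the config
// options set on this builder are applied to it rather than ignored.
val same = SparkSession.builder()
  .config("spark.sql.shuffle.partitions", "8")
  .getOrCreate()

assert(same eq spark)
assert(spark.conf.get("spark.sql.shuffle.partitions") == "8")
```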

## How was this patch tested?
Updated tests to reflect the change, and also introduced a new SparkSessionBuilderSuite that should cover all the branches.

Author: Reynold Xin <rxin@databricks.com>

Closes #13200 from rxin/SPARK-15075.

(cherry picked from commit f2ee0ed)
Signed-off-by: Reynold Xin <rxin@databricks.com>
rxin committed May 20, 2016
1 parent e6810e9 commit 52b967f
Showing 43 changed files with 367 additions and 357 deletions.
@@ -56,7 +56,7 @@ public void testDefaultReadWrite() throws IOException {
} catch (IOException e) {
// expected
}
instance.write().context(spark.wrapped()).overwrite().save(outputPath);
instance.write().context(spark.sqlContext()).overwrite().save(outputPath);
MyParams newInstance = MyParams.load(outputPath);
Assert.assertEquals("UID should match.", instance.uid(), newInstance.uid());
Assert.assertEquals("Params should be preserved.",
5 changes: 4 additions & 1 deletion python/pyspark/sql/context.py
@@ -34,7 +34,10 @@


class SQLContext(object):
"""Wrapper around :class:`SparkSession`, the main entry point to Spark SQL functionality.
"""The entry point for working with structured data (rows and columns) in Spark, in Spark 1.x.
As of Spark 2.0, this is replaced by :class:`SparkSession`. However, we are keeping the class
here for backward compatibility.
A SQLContext can be used create :class:`DataFrame`, register :class:`DataFrame` as
tables, execute SQL over tables, cache tables, and read parquet files.
17 changes: 14 additions & 3 deletions python/pyspark/sql/session.py
@@ -120,6 +120,8 @@ def master(self, master):
def appName(self, name):
"""Sets a name for the application, which will be shown in the Spark web UI.
If no application name is set, a randomly generated name will be used.
:param name: an application name
"""
return self.config("spark.app.name", name)
@@ -133,8 +135,17 @@ def enableHiveSupport(self):

@since(2.0)
def getOrCreate(self):
"""Gets an existing :class:`SparkSession` or, if there is no existing one, creates a new
one based on the options set in this builder.
"""Gets an existing :class:`SparkSession` or, if there is no existing one, creates a
new one based on the options set in this builder.
This method first checks whether there is a valid thread-local SparkSession,
and if yes, return that one. It then checks whether there is a valid global
default SparkSession, and if yes, return that one. If no valid global default
SparkSession exists, the method creates a new SparkSession and assigns the
newly created SparkSession as the global default.
In case an existing SparkSession is returned, the config options specified
in this builder will be applied to the existing SparkSession.
"""
with self._lock:
from pyspark.conf import SparkConf
@@ -175,7 +186,7 @@ def __init__(self, sparkContext, jsparkSession=None):
if jsparkSession is None:
jsparkSession = self._jvm.SparkSession(self._jsc.sc())
self._jsparkSession = jsparkSession
self._jwrapped = self._jsparkSession.wrapped()
self._jwrapped = self._jsparkSession.sqlContext()
self._wrapped = SQLContext(self._sc, self, self._jwrapped)
_monkey_patch_RDD(self)
install_exception_handler()
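The lookup order described in the new `getOrCreate` docstring can be summarized with a simplified, self-contained sketch; the names `activeThreadSession` and `defaultSession` below are illustrative stand-ins, not Spark's actual internals:

```scala
import org.apache.spark.sql.SparkSession

object SessionLookupSketch {
  // Illustrative registries; Spark keeps its own equivalents internally.
  private val activeThreadSession = new InheritableThreadLocal[SparkSession]
  @volatile private var defaultSession: SparkSession = _

  def getOrCreate(options: Map[String, String], create: () => SparkSession): SparkSession = {
    // 1. Prefer a valid thread-local (active) session.
    val session = Option(activeThreadSession.get())
      .filter(s => !s.sparkContext.isStopped)
      // 2. Otherwise fall back to the global default session.
      .orElse(Option(defaultSession).filter(s => !s.sparkContext.isStopped))
      // 3. Otherwise create a new session and register it as the global default.
      .getOrElse {
        val created = create()
        defaultSession = created
        created
      }
    // Config options from the builder are applied even when an existing
    // session is returned, which is the behavioral change in this patch.
    options.foreach { case (k, v) => session.conf.set(k, v) }
    session
  }
}
```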
2 changes: 1 addition & 1 deletion sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -213,7 +213,7 @@ class Dataset[T] private[sql](
private implicit def classTag = unresolvedTEncoder.clsTag

// sqlContext must be val because a stable identifier is expected when you import implicits
@transient lazy val sqlContext: SQLContext = sparkSession.wrapped
@transient lazy val sqlContext: SQLContext = sparkSession.sqlContext

protected[sql] def resolve(colName: String): NamedExpression = {
queryExecution.analyzed.resolveQuoted(colName, sparkSession.sessionState.analyzer.resolver)
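The comment about a stable identifier is why the rename keeps `sqlContext` as a `lazy val`: Scala only allows `import x._` when `x` is a stable path (a `val` or an `object`). A usage sketch, assuming a local session and a hypothetical app name:

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .master("local[2]")
  .appName("implicits-demo")   // hypothetical name, for illustration only
  .getOrCreate()

// Legal only because `spark` is a val (a stable identifier); importing from
// a def or a var would not compile.
import spark.implicits._

val ds = Seq(1, 2, 3).toDS()   // toDS() comes from the imported implicits
ds.show()
```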
124 changes: 12 additions & 112 deletions sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -19,25 +19,22 @@ package org.apache.spark.sql

import java.beans.BeanInfo
import java.util.Properties
import java.util.concurrent.atomic.AtomicReference

import scala.collection.immutable
import scala.reflect.runtime.universe.TypeTag

import org.apache.spark.{SparkConf, SparkContext, SparkException}
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config.ConfigEntry
import org.apache.spark.rdd.RDD
import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd}
import org.apache.spark.sql.catalyst._
import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.command.ShowTablesCommand
import org.apache.spark.sql.execution.ui.{SQLListener, SQLTab}
import org.apache.spark.sql.internal.{SessionState, SharedState, SQLConf}
import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.sql.types._
@@ -46,8 +43,8 @@ import org.apache.spark.sql.util.ExecutionListenerManager
/**
* The entry point for working with structured data (rows and columns) in Spark, in Spark 1.x.
*
* As of Spark 2.0, this is replaced by [[SparkSession]]. However, we are keeping the class here
* for backward compatibility.
* As of Spark 2.0, this is replaced by [[SparkSession]]. However, we are keeping the class
* here for backward compatibility.
*
* @groupname basic Basic Operations
* @groupname ddl_ops Persistent Catalog DDL
@@ -76,42 +73,21 @@ class SQLContext private[sql](
this(sparkSession, true)
}

@deprecated("Use SparkSession.builder instead", "2.0.0")
def this(sc: SparkContext) = {
this(new SparkSession(sc))
}

@deprecated("Use SparkSession.builder instead", "2.0.0")
def this(sparkContext: JavaSparkContext) = this(sparkContext.sc)

// TODO: move this logic into SparkSession

// If spark.sql.allowMultipleContexts is true, we will throw an exception if a user
// wants to create a new root SQLContext (a SQLContext that is not created by newSession).
private val allowMultipleContexts =
sparkContext.conf.getBoolean(
SQLConf.ALLOW_MULTIPLE_CONTEXTS.key,
SQLConf.ALLOW_MULTIPLE_CONTEXTS.defaultValue.get)

// Assert no root SQLContext is running when allowMultipleContexts is false.
{
if (!allowMultipleContexts && isRootContext) {
SQLContext.getInstantiatedContextOption() match {
case Some(rootSQLContext) =>
val errMsg = "Only one SQLContext/HiveContext may be running in this JVM. " +
s"It is recommended to use SQLContext.getOrCreate to get the instantiated " +
s"SQLContext/HiveContext. To ignore this error, " +
s"set ${SQLConf.ALLOW_MULTIPLE_CONTEXTS.key} = true in SparkConf."
throw new SparkException(errMsg)
case None => // OK
}
}
}

protected[sql] def sessionState: SessionState = sparkSession.sessionState
protected[sql] def sharedState: SharedState = sparkSession.sharedState
protected[sql] def conf: SQLConf = sessionState.conf
protected[sql] def runtimeConf: RuntimeConfig = sparkSession.conf
protected[sql] def cacheManager: CacheManager = sparkSession.cacheManager
protected[sql] def listener: SQLListener = sparkSession.listener
protected[sql] def externalCatalog: ExternalCatalog = sparkSession.externalCatalog

def sparkContext: SparkContext = sparkSession.sparkContext
@@ -123,7 +99,7 @@ class SQLContext private[sql](
*
* @since 1.6.0
*/
def newSession(): SQLContext = sparkSession.newSession().wrapped
def newSession(): SQLContext = sparkSession.newSession().sqlContext

/**
* An interface to register custom [[org.apache.spark.sql.util.QueryExecutionListener]]s
@@ -760,21 +736,6 @@ class SQLContext private[sql](
schema: StructType): DataFrame = {
sparkSession.applySchemaToPythonRDD(rdd, schema)
}

// TODO: move this logic into SparkSession

// Register a successfully instantiated context to the singleton. This should be at the end of
// the class definition so that the singleton is updated only if there is no exception in the
// construction of the instance.
sparkContext.addSparkListener(new SparkListener {
override def onApplicationEnd(applicationEnd: SparkListenerApplicationEnd): Unit = {
SQLContext.clearInstantiatedContext()
SQLContext.clearSqlListener()
}
})

sparkSession.setWrappedContext(self)
SQLContext.setInstantiatedContext(self)
}

/**
@@ -787,19 +748,6 @@
*/
object SQLContext {

/**
* The active SQLContext for the current thread.
*/
private val activeContext: InheritableThreadLocal[SQLContext] =
new InheritableThreadLocal[SQLContext]

/**
* Reference to the created SQLContext.
*/
@transient private val instantiatedContext = new AtomicReference[SQLContext]()

@transient private val sqlListener = new AtomicReference[SQLListener]()

/**
* Get the singleton SQLContext if it exists or create a new one using the given SparkContext.
*
@@ -811,41 +759,9 @@ object SQLContext {
*
* @since 1.5.0
*/
@deprecated("Use SparkSession.builder instead", "2.0.0")
def getOrCreate(sparkContext: SparkContext): SQLContext = {
val ctx = activeContext.get()
if (ctx != null && !ctx.sparkContext.isStopped) {
return ctx
}

synchronized {
val ctx = instantiatedContext.get()
if (ctx == null || ctx.sparkContext.isStopped) {
new SQLContext(sparkContext)
} else {
ctx
}
}
}

private[sql] def clearInstantiatedContext(): Unit = {
instantiatedContext.set(null)
}

private[sql] def setInstantiatedContext(sqlContext: SQLContext): Unit = {
synchronized {
val ctx = instantiatedContext.get()
if (ctx == null || ctx.sparkContext.isStopped) {
instantiatedContext.set(sqlContext)
}
}
}

private[sql] def getInstantiatedContextOption(): Option[SQLContext] = {
Option(instantiatedContext.get())
}

private[sql] def clearSqlListener(): Unit = {
sqlListener.set(null)
SparkSession.builder().sparkContext(sparkContext).getOrCreate().sqlContext
}

/**
@@ -855,8 +771,9 @@
*
* @since 1.6.0
*/
@deprecated("Use SparkSession.setActiveSession instead", "2.0.0")
def setActive(sqlContext: SQLContext): Unit = {
activeContext.set(sqlContext)
SparkSession.setActiveSession(sqlContext.sparkSession)
}

/**
@@ -865,12 +782,9 @@
*
* @since 1.6.0
*/
@deprecated("Use SparkSession.clearActiveSession instead", "2.0.0")
def clearActive(): Unit = {
activeContext.remove()
}

private[sql] def getActive(): Option[SQLContext] = {
Option(activeContext.get())
SparkSession.clearActiveSession()
}

/**
@@ -894,20 +808,6 @@
}
}

/**
* Create a SQLListener then add it into SparkContext, and create an SQLTab if there is SparkUI.
*/
private[sql] def createListenerAndUI(sc: SparkContext): SQLListener = {
if (sqlListener.get() == null) {
val listener = new SQLListener(sc.conf)
if (sqlListener.compareAndSet(null, listener)) {
sc.addSparkListener(listener)
sc.ui.foreach(new SQLTab(listener, _))
}
}
sqlListener.get()
}

/**
* Extract `spark.sql.*` properties from the conf and return them as a [[Properties]].
*/
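To illustrate the deprecations above, a hedged migration sketch: the Spark 1.x entry points still compile but now delegate to `SparkSession`, which is the recommended API (the app name and master below are arbitrary examples):

```scala
import org.apache.spark.sql.SparkSession

// Spark 2.0 style: build the session directly instead of SQLContext.getOrCreate.
val spark = SparkSession.builder()
  .master("local[2]")
  .appName("migration-demo")
  .getOrCreate()

// Replacements for the deprecated SQLContext entry points shown above:
//   SQLContext.getOrCreate(sc)  ->  SparkSession.builder().getOrCreate()
//   SQLContext.setActive(ctx)   ->  SparkSession.setActiveSession(spark)
//   SQLContext.clearActive()    ->  SparkSession.clearActiveSession()
SparkSession.setActiveSession(spark)
SparkSession.clearActiveSession()
```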
