Skip to content

Commit

Permalink
Incorporate feedback. Fix association of incorrect SparkSession while…
Browse files Browse the repository at this point in the history
… cloning SessionState.
  • Loading branch information
kunalkhamar committed Feb 10, 2017
1 parent a343d8a commit 4210079
Show file tree
Hide file tree
Showing 8 changed files with 47 additions and 63 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ trait FunctionRegistry {
/** Clear all registered functions. */
def clear(): Unit

/** Create a copy of this registry containing the same registered functions. */
def copy(): FunctionRegistry
}

class SimpleFunctionRegistry extends FunctionRegistry {
Expand Down Expand Up @@ -107,7 +109,7 @@ class SimpleFunctionRegistry extends FunctionRegistry {
functionBuilders.clear()
}

def copy(): SimpleFunctionRegistry = synchronized {
override def copy(): FunctionRegistry = synchronized {
val registry = new SimpleFunctionRegistry
functionBuilders.iterator.foreach { case (name, (info, builder)) =>
registry.registerFunction(name, info, builder)
Expand Down Expand Up @@ -150,6 +152,9 @@ object EmptyFunctionRegistry extends FunctionRegistry {
throw new UnsupportedOperationException
}

/**
 * Copying is not supported by this registry.
 *
 * @throws UnsupportedOperationException always
 */
override def copy(): FunctionRegistry =
  throw new UnsupportedOperationException
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,21 +50,15 @@ class ExperimentalMethods private[sql]() {

/**
* Get an identical copy of this `ExperimentalMethods` instance.
* @note This is used when forking a `SparkSession`.
* `clone` is provided here instead of implementing equivalent functionality
*
* @note `clone` is provided here instead of implementing equivalent functionality
* in `SparkSession.clone` since it needs to be updated
* as the class `ExperimentalMethods` is extended or modified.
*/
override def clone: ExperimentalMethods = {
def cloneSeq[T](seq: Seq[T]): Seq[T] = {
val newSeq = new ListBuffer[T]
newSeq ++= seq
newSeq
}

val result = new ExperimentalMethods
result.extraStrategies = cloneSeq(extraStrategies)
result.extraOptimizations = cloneSeq(extraOptimizations)
result.extraStrategies = extraStrategies
result.extraOptimizations = extraOptimizations
result
}
}
23 changes: 13 additions & 10 deletions sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ class SparkSession private(
*/
@transient
private[sql] lazy val sessionState: SessionState = {
existingSessionState.getOrElse(SparkSession.reflect[SessionState, SparkSession](
existingSessionState.map(_.clone(this)).getOrElse(SparkSession.reflect[SessionState, SparkSession](
SparkSession.sessionStateClassName(sparkContext.conf),
self))
}
Expand Down Expand Up @@ -219,21 +219,24 @@ class SparkSession private(
}

/**
* Start a new session, sharing the underlying `SparkContext` and cached data.
* If inheritSessionState is enabled, then SQL configurations, temporary tables,
* registered functions are copied over from parent `SparkSession`.
* :: Experimental ::
* Create an identical copy of this `SparkSession`, sharing the underlying `SparkContext`
* and cached data. SessionState (SQL configurations, temporary tables, registered functions)
* is also copied over.
* Changes to the base session are not propagated to the cloned session; the clone is
* independent after creation.
*
* @note Other than the `SparkContext`, all shared state is initialized lazily.
* This method will force the initialization of the shared state to ensure that parent
* and child sessions are set up with the same shared state. If the underlying catalog
* implementation is Hive, this will initialize the metastore, which may take some time.
*
* @since 2.1.1
*/
def newSession(inheritSessionState: Boolean): SparkSession = {
if (inheritSessionState) {
new SparkSession(sparkContext, Some(sharedState), Some(sessionState.clone))
} else {
newSession()
}
@Experimental
@InterfaceStability.Evolving
def cloneSession(): SparkSession = {
  // Hand this session's shared state and session state to the new session,
  // which reuses the same SparkContext.
  val parentSharedState = Some(sharedState)
  val parentSessionState = Some(sessionState)
  new SparkSession(sparkContext, parentSharedState, parentSessionState)
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -980,6 +980,14 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging {
def clear(): Unit = {
settings.clear()
}

/**
 * Create a new `SQLConf` populated with the settings returned by `getAllConfs`.
 * Entries whose value is `null` are not carried over.
 */
override def clone: SQLConf = {
  val copied = new SQLConf
  for ((key, value) <- getAllConfs if (value ne null)) {
    copied.setConfString(key, value)
  }
  copied
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ import org.apache.spark.sql.util.ExecutionListenerManager
*/
private[sql] class SessionState(
sparkSession: SparkSession,
existingSessionState: Option[SessionState]) {
parentSessionState: Option[SessionState]) {

private[sql] def this(sparkSession: SparkSession) = {
this(sparkSession, None)
Expand All @@ -55,13 +55,7 @@ private[sql] class SessionState(
* SQL-specific key-value configurations.
*/
lazy val conf: SQLConf = {
val result = new SQLConf
if (existingSessionState.nonEmpty) {
existingSessionState.get.conf.getAllConfs.foreach {
case (k, v) => if (v ne null) result.setConfString(k, v)
}
}
result
parentSessionState.map(_.conf.clone).getOrElse(new SQLConf)
}

def newHadoopConf(): Configuration = {
Expand All @@ -81,7 +75,7 @@ private[sql] class SessionState(
}

lazy val experimentalMethods: ExperimentalMethods = {
existingSessionState
parentSessionState
.map(_.experimentalMethods.clone)
.getOrElse(new ExperimentalMethods)
}
Expand All @@ -90,18 +84,7 @@ private[sql] class SessionState(
* Internal catalog for managing functions registered by the user.
*/
lazy val functionRegistry: FunctionRegistry = {
val registry = FunctionRegistry.builtin.copy()

if (existingSessionState.nonEmpty) {
val sourceRegistry = existingSessionState.get.functionRegistry
sourceRegistry
.listFunction()
.foreach(name => registry.registerFunction(
name,
sourceRegistry.lookupFunction(name).get,
sourceRegistry.lookupFunctionBuilder(name).get))
}
registry
parentSessionState.map(_.functionRegistry.copy()).getOrElse(FunctionRegistry.builtin.copy())
}

/**
Expand All @@ -126,7 +109,7 @@ private[sql] class SessionState(
* Internal catalog for managing table and database states.
*/
lazy val catalog: SessionCatalog = {
existingSessionState
parentSessionState
.map(_.catalog.clone)
.getOrElse(new SessionCatalog(
sparkSession.sharedState.externalCatalog,
Expand Down Expand Up @@ -202,10 +185,10 @@ private[sql] class SessionState(
}

/**
* Get an identical copy of the `SessionState`.
* Get an identical copy of the `SessionState` and associate it with the given `SparkSession`.
*/
override def clone: SessionState = {
new SessionState(sparkSession, Some(this))
def clone(sc: SparkSession): SessionState = {
new SessionState(sc, Some(this))
}

// ------------------------------------------------------
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ class SparkSessionBuilderSuite extends SparkFunSuite {

test("fork new session and inherit a copy of the session state") {
val activeSession = SparkSession.builder().master("local").getOrCreate()
val forkedSession = activeSession.newSession(inheritSessionState = true)
val forkedSession = activeSession.cloneSession()

assert(forkedSession ne activeSession)
assert(forkedSession.sessionState ne activeSession.sessionState)
Expand All @@ -141,9 +141,9 @@ class SparkSessionBuilderSuite extends SparkFunSuite {
val activeSession = SparkSession
.builder()
.master("local")
.config("spark-configb", "b")
.getOrCreate()
val forkedSession = activeSession.newSession(inheritSessionState = true)
activeSession.conf.set("spark-configb", "b")
val forkedSession = activeSession.cloneSession()

assert(forkedSession ne activeSession)
assert(forkedSession.conf ne activeSession.conf)
Expand All @@ -156,7 +156,7 @@ class SparkSessionBuilderSuite extends SparkFunSuite {
test("fork new session and inherit function registry and udf") {
val activeSession = SparkSession.builder().master("local").getOrCreate()
activeSession.udf.register("strlenScala", (_: String).length + (_: Int))
val forkedSession = activeSession.newSession(inheritSessionState = true)
val forkedSession = activeSession.cloneSession()

assert(forkedSession ne activeSession)
assert(forkedSession.sessionState.functionRegistry ne
Expand All @@ -179,12 +179,9 @@ class SparkSessionBuilderSuite extends SparkFunSuite {
val activeSession = SparkSession.builder().master("local").getOrCreate()
activeSession.experimental.extraOptimizations = optimizations

val forkedSession = activeSession.newSession(inheritSessionState = true)
val forkedSession = activeSession.cloneSession()

assert(forkedSession ne activeSession)
assert(forkedSession.experimental ne activeSession.experimental)
assert(forkedSession.experimental.extraOptimizations ne
activeSession.experimental.extraOptimizations)
assert(forkedSession.experimental.extraOptimizations.toSet ==
activeSession.experimental.extraOptimizations.toSet)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -500,7 +500,7 @@ class CatalogSuite
createTempTable("my_temp_table")
assert(spark.catalog.listTables().collect().map(_.name).toSet == Set("my_temp_table"))

val forkedSession = spark.newSession(inheritSessionState = true)
val forkedSession = spark.cloneSession()
assert(forkedSession.catalog.listTables().collect().map(_.name).toSet == Set("my_temp_table"))

dropTable("my_temp_table")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.execution.command.CacheTableCommand
import org.apache.spark.sql.hive._
import org.apache.spark.sql.internal.{SessionState, SharedState, SQLConf}
import org.apache.spark.sql.internal.{SharedState, SQLConf}
import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION
import org.apache.spark.util.{ShutdownHookManager, Utils}

Expand Down Expand Up @@ -113,22 +113,16 @@ class TestHiveContext(
private[hive] class TestHiveSparkSession(
@transient private val sc: SparkContext,
@transient private val existingSharedState: Option[SharedState],
existingSessionState: Option[SessionState],
private val loadTestTables: Boolean)
extends SparkSession(sc) with Logging { self =>

def this(sc: SparkContext, loadTestTables: Boolean) {
this(
sc,
existingSharedState = None,
existingSessionState = None,
loadTestTables)
}

def this(sc: SparkContext, existingSharedState: Option[SharedState], loadTestTables: Boolean) {
this(sc, existingSharedState, existingSessionState = None, loadTestTables)
}

{ // set the metastore temporary configuration
val metastoreTempConf = HiveUtils.newTemporaryConfiguration(useInMemoryDerby = false) ++ Map(
ConfVars.METASTORE_INTEGER_JDO_PUSHDOWN.varname -> "true",
Expand All @@ -155,7 +149,7 @@ private[hive] class TestHiveSparkSession(
new TestHiveSessionState(self)

override def newSession(): TestHiveSparkSession = {
new TestHiveSparkSession(sc, Some(sharedState), None, loadTestTables)
new TestHiveSparkSession(sc, Some(sharedState), loadTestTables)
}

private var cacheTables: Boolean = false
Expand Down

0 comments on commit 4210079

Please sign in to comment.