[SPARK-20048][SQL] Cloning SessionState does not clone query execution listeners

## What changes were proposed in this pull request?

Bugfix from [SPARK-19540](#16826).
Cloning `SessionState` does not clone query execution listeners, so a cloned session is unable to listen to events on queries.
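To illustrate the symptom and the fix, a minimal sketch (the demo scaffolding below is ours, not code from this patch; note that `cloneSession()` is `private[sql]`, so it only compiles from within the `org.apache.spark.sql` package):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.util.QueryExecutionListener

object ListenerCloneDemo extends App {
  val spark = SparkSession.builder().master("local").appName("demo").getOrCreate()

  // Register a listener on the original session.
  spark.listenerManager.register(new QueryExecutionListener {
    override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit =
      println(s"query succeeded: $funcName")
    override def onFailure(funcName: String, qe: QueryExecution, ex: Exception): Unit = ()
  })

  // Before this fix, the clone got a fresh, empty ExecutionListenerManager,
  // so the listener above was never notified of queries run on the clone.
  // After the fix, the clone starts with a copy of the parent's listeners.
  val forked = spark.cloneSession() // private[sql]; shown for illustration only
  forked.range(10).collect()        // now reaches onSuccess("collect", ...)
}
```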

## How was this patch tested?

- Unit test

Author: Kunal Khamar <kkhamar@outlook.com>

Closes #17379 from kunalkhamar/clone-bugfix.
kunalkhamar authored and hvanhovell committed Mar 29, 2017
1 parent d6ddfdf commit 142f6d1
Showing 9 changed files with 111 additions and 56 deletions.
22 changes: 11 additions & 11 deletions sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -38,7 +38,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Range}
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.execution.ui.SQLListener
-import org.apache.spark.sql.internal.{CatalogImpl, SessionState, SharedState}
+import org.apache.spark.sql.internal.{BaseSessionStateBuilder, CatalogImpl, SessionState, SessionStateBuilder, SharedState}
import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION
import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.sql.streaming._
@@ -194,7 +194,7 @@ class SparkSession private(
*
* @since 2.0.0
*/
-def udf: UDFRegistration = sessionState.udf
+def udf: UDFRegistration = sessionState.udfRegistration

/**
* :: Experimental ::
@@ -990,28 +990,28 @@ object SparkSession {
/** Reference to the root SparkSession. */
private val defaultSession = new AtomicReference[SparkSession]

-private val HIVE_SESSION_STATE_CLASS_NAME = "org.apache.spark.sql.hive.HiveSessionState"
+private val HIVE_SESSION_STATE_BUILDER_CLASS_NAME =
+"org.apache.spark.sql.hive.HiveSessionStateBuilder"

private def sessionStateClassName(conf: SparkConf): String = {
conf.get(CATALOG_IMPLEMENTATION) match {
case "hive" => HIVE_SESSION_STATE_CLASS_NAME
case "in-memory" => classOf[SessionState].getCanonicalName
case "hive" => HIVE_SESSION_STATE_BUILDER_CLASS_NAME
case "in-memory" => classOf[SessionStateBuilder].getCanonicalName
}
}

/**
* Helper method to create an instance of `SessionState` based on `className` from conf.
-* The result is either `SessionState` or `HiveSessionState`.
+* The result is either `SessionState` or a Hive based `SessionState`.
*/
private def instantiateSessionState(
className: String,
sparkSession: SparkSession): SessionState = {

try {
-// get `SessionState.apply(SparkSession)`
+// invoke `new [Hive]SessionStateBuilder(SparkSession, Option[SessionState])`
val clazz = Utils.classForName(className)
-val method = clazz.getMethod("apply", sparkSession.getClass)
-method.invoke(null, sparkSession).asInstanceOf[SessionState]
+val ctor = clazz.getConstructors.head
+ctor.newInstance(sparkSession, None).asInstanceOf[BaseSessionStateBuilder].build()
} catch {
case NonFatal(e) =>
throw new IllegalArgumentException(s"Error while instantiating '$className':", e)
@@ -1023,7 +1023,7 @@ object SparkSession {
*/
private[spark] def hiveClassesArePresent: Boolean = {
try {
-Utils.classForName(HIVE_SESSION_STATE_CLASS_NAME)
+Utils.classForName(HIVE_SESSION_STATE_BUILDER_CLASS_NAME)
Utils.classForName("org.apache.hadoop.hive.conf.HiveConf")
true
} catch {
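For readers unfamiliar with the reflective instantiation above, a self-contained miniature of the pattern, using stand-in types of our own rather than Spark's classes:

```scala
// Stand-in types of our own; the real code threads a SparkSession and an
// Option[SessionState] (the parent state, when cloning) through instead.
class Session
class State
class Builder(session: Session, parentState: Option[State]) {
  def build(): State = new State
}

object ReflectiveBuildDemo extends App {
  def instantiate(className: String, session: Session): State = {
    val clazz = Class.forName(className)
    // Each builder is expected to expose a single constructor of shape
    // (Session, Option[State]), mirroring `new [Hive]SessionStateBuilder(...)`.
    val ctor = clazz.getConstructors.head
    ctor.newInstance(session, None).asInstanceOf[Builder].build()
  }

  println(instantiate(classOf[Builder].getName, new Session))
}
```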
@@ -18,7 +18,7 @@ package org.apache.spark.sql.internal

import org.apache.spark.SparkConf
import org.apache.spark.annotation.{Experimental, InterfaceStability}
-import org.apache.spark.sql.{ExperimentalMethods, SparkSession, Strategy}
+import org.apache.spark.sql.{ExperimentalMethods, SparkSession, Strategy, UDFRegistration}
import org.apache.spark.sql.catalyst.analysis.{Analyzer, FunctionRegistry}
import org.apache.spark.sql.catalyst.catalog.SessionCatalog
import org.apache.spark.sql.catalyst.optimizer.Optimizer
@@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.{QueryExecution, SparkOptimizer, SparkPlanner, SparkSqlParser}
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.streaming.StreamingQueryManager
+import org.apache.spark.sql.util.ExecutionListenerManager

/**
* Builder class that coordinates construction of a new [[SessionState]].
@@ -133,6 +134,14 @@ abstract class BaseSessionStateBuilder(
catalog
}

+/**
+* Interface exposed to the user for registering user-defined functions.
+*
+* Note 1: The user-defined functions must be deterministic.
+* Note 2: This depends on the `functionRegistry` field.
+*/
+protected def udfRegistration: UDFRegistration = new UDFRegistration(functionRegistry)

/**
* Logical query plan analyzer for resolving unresolved attributes and relations.
*
@@ -232,6 +241,16 @@
*/
protected def streamingQueryManager: StreamingQueryManager = new StreamingQueryManager(session)

+/**
+* An interface to register custom [[org.apache.spark.sql.util.QueryExecutionListener]]s
+* that listen for execution metrics.
+*
+* This gets cloned from parent if available, otherwise a new instance is created.
+*/
+protected def listenerManager: ExecutionListenerManager = {
+parentState.map(_.listenerManager.clone()).getOrElse(new ExecutionListenerManager)
+}

/**
* Function used to make clones of the session state.
*/
@@ -245,17 +264,18 @@
*/
def build(): SessionState = {
new SessionState(
-session.sparkContext,
session.sharedState,
conf,
experimentalMethods,
functionRegistry,
+udfRegistration,
catalog,
sqlParser,
analyzer,
optimizer,
planner,
streamingQueryManager,
+listenerManager,
resourceLoader,
createQueryExecution,
createClone)
@@ -32,43 +32,46 @@ import org.apache.spark.sql.catalyst.parser.ParserInterface
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution._
import org.apache.spark.sql.streaming.StreamingQueryManager
-import org.apache.spark.sql.util.ExecutionListenerManager
+import org.apache.spark.sql.util.{ExecutionListenerManager, QueryExecutionListener}

/**
* A class that holds all session-specific state in a given [[SparkSession]].
*
-* @param sparkContext The [[SparkContext]].
-* @param sharedState The shared state.
+* @param sharedState The state shared across sessions, e.g. global view manager, external catalog.
* @param conf SQL-specific key-value configurations.
-* @param experimentalMethods The experimental methods.
+* @param experimentalMethods Interface to add custom planning strategies and optimizers.
* @param functionRegistry Internal catalog for managing functions registered by the user.
+* @param udfRegistration Interface exposed to the user for registering user-defined functions.
* @param catalog Internal catalog for managing table and database states.
* @param sqlParser Parser that extracts expressions, plans, table identifiers etc. from SQL texts.
* @param analyzer Logical query plan analyzer for resolving unresolved attributes and relations.
* @param optimizer Logical query plan optimizer.
-* @param planner Planner that converts optimized logical plans to physical plans
+* @param planner Planner that converts optimized logical plans to physical plans.
* @param streamingQueryManager Interface to start and stop streaming queries.
+* @param listenerManager Interface to register custom [[QueryExecutionListener]]s.
* @param resourceLoader Session shared resource loader to load JARs, files, etc.
* @param createQueryExecution Function used to create QueryExecution objects.
* @param createClone Function used to create clones of the session state.
*/
private[sql] class SessionState(
-sparkContext: SparkContext,
sharedState: SharedState,
val conf: SQLConf,
val experimentalMethods: ExperimentalMethods,
val functionRegistry: FunctionRegistry,
+val udfRegistration: UDFRegistration,
val catalog: SessionCatalog,
val sqlParser: ParserInterface,
val analyzer: Analyzer,
val optimizer: Optimizer,
val planner: SparkPlanner,
val streamingQueryManager: StreamingQueryManager,
+val listenerManager: ExecutionListenerManager,
val resourceLoader: SessionResourceLoader,
createQueryExecution: LogicalPlan => QueryExecution,
createClone: (SparkSession, SessionState) => SessionState) {

def newHadoopConf(): Configuration = SessionState.newHadoopConf(
-sparkContext.hadoopConfiguration,
+sharedState.sparkContext.hadoopConfiguration,
conf)

def newHadoopConfWithOptions(options: Map[String, String]): Configuration = {
@@ -81,18 +84,6 @@ private[sql] class SessionState(
hadoopConf
}

-/**
-* Interface exposed to the user for registering user-defined functions.
-* Note that the user-defined functions must be deterministic.
-*/
-val udf: UDFRegistration = new UDFRegistration(functionRegistry)
-
-/**
-* An interface to register custom [[org.apache.spark.sql.util.QueryExecutionListener]]s
-* that listen for execution metrics.
-*/
-val listenerManager: ExecutionListenerManager = new ExecutionListenerManager

/**
* Get an identical copy of the `SessionState` and associate it with the given `SparkSession`
*/
@@ -110,13 +101,6 @@
}

private[sql] object SessionState {
-/**
-* Create a new [[SessionState]] for the given session.
-*/
-def apply(session: SparkSession): SessionState = {
-new SessionStateBuilder(session).build()
-}

def newHadoopConf(hadoopConf: Configuration, sqlConf: SQLConf): Configuration = {
val newHadoopConf = new Configuration(hadoopConf)
sqlConf.getAllConfs.foreach { case (k, v) => if (v ne null) newHadoopConf.set(k, v) }
@@ -155,7 +139,7 @@ class SessionResourceLoader(session: SparkSession) extends FunctionResourceLoader
/**
* Add a jar path to [[SparkContext]] and the classloader.
*
-* Note: this method seems not access any session state, but the subclass `HiveSessionState` needs
+* Note: this method seems not access any session state, but a Hive based `SessionState` needs
* to add the jar to its hive client for the current session. Hence, it still needs to be in
* [[SessionState]].
*/
@@ -98,6 +98,16 @@ class ExecutionListenerManager private[sql] () extends Logging {
listeners.clear()
}

+/**
+* Get an identical copy of this listener manager.
+*/
+@DeveloperApi
+override def clone(): ExecutionListenerManager = writeLock {
+val newListenerManager = new ExecutionListenerManager
+listeners.foreach(newListenerManager.register)
+newListenerManager
+}

private[sql] def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {
readLock {
withErrorHandling { listener =>
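The override above gives copy-on-clone (snapshot) semantics: a fork starts with the parent's listeners, after which the two managers evolve independently. A self-contained miniature of that contract, with a stand-in class of our own (Spark's `ExecutionListenerManager` constructor is `private[sql]`):

```scala
import scala.collection.mutable.ListBuffer

// Illustrative stand-in, not Spark's actual ExecutionListenerManager.
class MiniListenerManager {
  private val listeners = ListBuffer.empty[String => Unit]
  def register(l: String => Unit): Unit = listeners += l
  def fire(event: String): Unit = listeners.foreach(_(event))

  // Copy the currently registered listeners into a fresh manager;
  // afterwards the two managers evolve independently.
  override def clone(): MiniListenerManager = {
    val copy = new MiniListenerManager
    listeners.foreach(copy.register)
    copy
  }
}

object CloneSemanticsDemo extends App {
  val parent = new MiniListenerManager
  parent.register(e => println(s"inherited listener saw: $e"))

  val child = parent.clone()  // child starts with a copy of parent's listeners
  child.register(e => println(s"child-only listener saw: $e"))

  child.fire("collect")  // both listeners fire on the child
  parent.fire("collect") // only the inherited listener fires on the parent
}
```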
@@ -19,10 +19,13 @@ package org.apache.spark.sql

import org.scalatest.BeforeAndAfterAll
import org.scalatest.BeforeAndAfterEach
+import scala.collection.mutable.ArrayBuffer

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution.QueryExecution
+import org.apache.spark.sql.util.QueryExecutionListener

class SessionStateSuite extends SparkFunSuite
with BeforeAndAfterEach with BeforeAndAfterAll {
@@ -122,6 +125,56 @@ class SessionStateSuite extends SparkFunSuite
}
}

test("fork new session and inherit listener manager") {
class CommandCollector extends QueryExecutionListener {
val commands: ArrayBuffer[String] = ArrayBuffer.empty[String]
override def onFailure(funcName: String, qe: QueryExecution, ex: Exception) : Unit = {}
override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {
commands += funcName
}
}
val collectorA = new CommandCollector
val collectorB = new CommandCollector
val collectorC = new CommandCollector

try {
def runCollectQueryOn(sparkSession: SparkSession): Unit = {
val tupleEncoder = Encoders.tuple(Encoders.scalaInt, Encoders.STRING)
val df = sparkSession.createDataset(Seq(1 -> "a"))(tupleEncoder).toDF("i", "j")
df.select("i").collect()
}

activeSession.listenerManager.register(collectorA)
val forkedSession = activeSession.cloneSession()

// inheritance
assert(forkedSession ne activeSession)
assert(forkedSession.listenerManager ne activeSession.listenerManager)
runCollectQueryOn(forkedSession)
assert(collectorA.commands.length == 1) // forked should callback to A
assert(collectorA.commands(0) == "collect")

// independence
// => changes to forked do not affect original
forkedSession.listenerManager.register(collectorB)
runCollectQueryOn(activeSession)
assert(collectorB.commands.isEmpty) // original should not callback to B
assert(collectorA.commands.length == 2) // original should still callback to A
assert(collectorA.commands(1) == "collect")
// <= changes to original do not affect forked
activeSession.listenerManager.register(collectorC)
runCollectQueryOn(forkedSession)
assert(collectorC.commands.isEmpty) // forked should not callback to C
assert(collectorA.commands.length == 3) // forked should still callback to A
assert(collectorB.commands.length == 1) // forked should still callback to B
assert(collectorA.commands(2) == "collect")
assert(collectorB.commands(0) == "collect")
} finally {
activeSession.listenerManager.unregister(collectorA)
activeSession.listenerManager.unregister(collectorC)
}
}

test("fork new sessions and run query on inherited table") {
def checkTableExists(sparkSession: SparkSession): Unit = {
QueryTest.checkAnswer(sparkSession.sql(
@@ -38,7 +38,7 @@ import org.apache.thrift.transport.TSocket

import org.apache.spark.internal.Logging
import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.hive.{HiveSessionState, HiveUtils}
+import org.apache.spark.sql.hive.HiveUtils
import org.apache.spark.util.ShutdownHookManager

/**
@@ -28,19 +28,7 @@ import org.apache.spark.sql.hive.client.HiveClient
import org.apache.spark.sql.internal.{BaseSessionStateBuilder, SessionResourceLoader, SessionState}

-/**
-* Entry object for creating a Hive aware [[SessionState]].
-*/
-private[hive] object HiveSessionState {
-/**
-* Create a new Hive aware [[SessionState]]. for the given session.
-*/
-def apply(session: SparkSession): SessionState = {
-new HiveSessionStateBuilder(session).build()
-}
-}

/**
-* Builder that produces a [[HiveSessionState]].
+* Builder that produces a Hive aware [[SessionState]].
*/
@Experimental
@InterfaceStability.Unstable
@@ -39,7 +39,7 @@ import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.execution.command.CacheTableCommand
import org.apache.spark.sql.hive._
import org.apache.spark.sql.hive.client.HiveClient
-import org.apache.spark.sql.internal._
+import org.apache.spark.sql.internal.{SessionState, SharedState, SQLConf, WithTestConf}
import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION
import org.apache.spark.util.{ShutdownHookManager, Utils}

@@ -23,7 +23,7 @@ import org.apache.spark.sql._
import org.apache.spark.sql.hive.test.TestHiveSingleton

/**
-* Run all tests from `SessionStateSuite` with a `HiveSessionState`.
+* Run all tests from `SessionStateSuite` with a Hive based `SessionState`.
*/
class HiveSessionStateSuite extends SessionStateSuite
with TestHiveSingleton with BeforeAndAfterEach {
