Refactor SessionState to remove passing of base SessionState, and initialize all fields directly instead. Same change for HiveSessionState.
kunalkhamar committed Feb 16, 2017
1 parent 579d0b7 commit 2837e73
Showing 7 changed files with 313 additions and 174 deletions.
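
The gist of the refactor, as a simplified self-contained sketch (not part of the diff; Conf, Session, State, and Demo are placeholder names standing in for SQLConf, SparkSession, SessionState, and a driver): the optional parent state disappears from the constructor, every field becomes a plain constructor parameter, a companion apply builds a fresh state, and copy clones an existing one for another session.

// Before: new SessionState(sparkSession, parentSessionState: Option[SessionState]),
// with each val deciding whether to copy itself from the parent.
// After: fields are constructor parameters; apply initializes them directly,
// and copy clones them for another session.
final class Conf(val entries: Map[String, String]) {
  def copy: Conf = new Conf(entries)
}

final class Session(val name: String) // placeholder for SparkSession

class State(
    val session: Session,
    val conf: Conf,
    val functionNames: Set[String]) {

  // Mirrors SessionState.copy(associatedSparkSession): copy the mutable pieces,
  // reuse the immutable ones, and attach the result to the given session.
  def copy(target: Session): State =
    new State(target, conf.copy, functionNames)
}

object State {
  // Mirrors SessionState.apply(sparkSession): build every field up front.
  def apply(session: Session): State =
    new State(session, new Conf(Map.empty), Set("upper", "lower"))
}

object Demo {
  def main(args: Array[String]): Unit = {
    val base   = State(new Session("s1"))
    val cloned = base.copy(new Session("s2"))
    println(cloned.session.name)      // s2
    println(cloned.conf eq base.conf) // false: the conf object was copied
  }
}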
17 changes: 9 additions & 8 deletions sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -22,7 +22,6 @@ import java.io.Closeable
import java.util.concurrent.atomic.AtomicReference

import scala.collection.JavaConverters._
import scala.reflect.ClassTag
import scala.reflect.runtime.universe.TypeTag
import scala.util.control.NonFatal

@@ -114,7 +113,7 @@ class SparkSession private(
private[sql] lazy val sessionState: SessionState = {
existingSessionState
.map(_.copy(this))
.getOrElse(SparkSession.reflect[SessionState, SparkSession](
.getOrElse(SparkSession.instantiateSessionState(
SparkSession.sessionStateClassName(sparkContext.conf),
self))
}
@@ -994,16 +993,18 @@ object SparkSession {
}

/**
* Helper method to create an instance of [[T]] using a single-arg constructor that
* accepts an [[Arg]].
* Helper method to create an instance of `SessionState`
* The result is either `SessionState` or `HiveSessionState`
*/
private def reflect[T, Arg <: AnyRef](
private def instantiateSessionState(
className: String,
ctorArg: Arg)(implicit ctorArgTag: ClassTag[Arg]): T = {
sparkSession: SparkSession): SessionState = {

try {
// get `SessionState.apply(SparkSession)`
val clazz = Utils.classForName(className)
val ctor = clazz.getDeclaredConstructor(ctorArgTag.runtimeClass)
ctor.newInstance(ctorArg).asInstanceOf[T]
val method = clazz.getMethod("apply", sparkSession.getClass)
method.invoke(null, sparkSession).asInstanceOf[SessionState]
} catch {
case NonFatal(e) =>
throw new IllegalArgumentException(s"Error while instantiating '$className':", e)
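
For reference, a minimal standalone sketch of the reflection technique instantiateSessionState relies on: a Scala companion object's public apply is emitted as a static forwarder on the companion class, so it can be looked up with getMethod and invoked with a null receiver. Greeter and ReflectDemo are made-up names standing in for SessionState/HiveSessionState, and the sketch assumes it is compiled as a normal class file (Class.forName would not resolve a REPL-defined class this way).

class Greeter(val msg: String)

object Greeter {
  def apply(name: String): Greeter = new Greeter("hello, " + name)
}

object ReflectDemo {
  def main(args: Array[String]): Unit = {
    // Spark derives the class name from sessionStateClassName(sparkContext.conf);
    // here it is hard-coded for illustration.
    val className = "Greeter"
    val clazz     = Class.forName(className)
    // The companion's apply(String) is available as a static method on Greeter,
    // so the receiver passed to invoke is null.
    val method    = clazz.getMethod("apply", classOf[String])
    val greeter   = method.invoke(null, "world").asInstanceOf[Greeter]
    println(greeter.msg) // hello, world
  }
}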
245 changes: 128 additions & 117 deletions sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
@@ -32,105 +32,28 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.command.AnalyzeTableCommand
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryManager}
import org.apache.spark.sql.streaming.{StreamingQueryManager}
import org.apache.spark.sql.util.ExecutionListenerManager


/**
* A class that holds all session-specific state in a given [[SparkSession]].
* If an `existingSessionState` is supplied, then its members will be copied over.
*/
private[sql] class SessionState(
sparkSession: SparkSession,
parentSessionState: Option[SessionState]) {
val conf: SQLConf,
val experimentalMethods: ExperimentalMethods,
val functionRegistry: FunctionRegistry,
val catalog: SessionCatalog,
val sqlParser: ParserInterface) {

private[sql] def this(sparkSession: SparkSession) = {
this(sparkSession, None)
}

// Note: Many of these vals depend on each other (e.g. conf) and should be initialized
// with an early initializer if we want subclasses to override some of the fields.
// Otherwise, we would get a lot of NPEs.

/**
* SQL-specific key-value configurations.
*/
val conf: SQLConf = {
parentSessionState.map(_.conf.copy).getOrElse(new SQLConf)
}

def newHadoopConf(): Configuration = {
val hadoopConf = new Configuration(sparkSession.sparkContext.hadoopConfiguration)
conf.getAllConfs.foreach { case (k, v) => if (v ne null) hadoopConf.set(k, v) }
hadoopConf
}

def newHadoopConfWithOptions(options: Map[String, String]): Configuration = {
val hadoopConf = newHadoopConf()
options.foreach { case (k, v) =>
if ((v ne null) && k != "path" && k != "paths") {
hadoopConf.set(k, v)
}
}
hadoopConf
}

val experimentalMethods: ExperimentalMethods = {
parentSessionState
.map(_.experimentalMethods.copy)
.getOrElse(new ExperimentalMethods)
}

/**
* Internal catalog for managing functions registered by the user.
*/
val functionRegistry: FunctionRegistry = {
parentSessionState.map(_.functionRegistry.copy).getOrElse(FunctionRegistry.builtin.copy)
}

/**
* A class for loading resources specified by a function.
*/
val functionResourceLoader: FunctionResourceLoader = {
new FunctionResourceLoader {
override def loadResource(resource: FunctionResource): Unit = {
resource.resourceType match {
case JarResource => addJar(resource.uri)
case FileResource => sparkSession.sparkContext.addFile(resource.uri)
case ArchiveResource =>
throw new AnalysisException(
"Archive is not allowed to be loaded. If YARN mode is used, " +
"please use --archives options while calling spark-submit.")
}
}
}
}

/**
* Internal catalog for managing table and database states.
*/
val catalog: SessionCatalog = {
parentSessionState
.map(_.catalog.copy)
.getOrElse(new SessionCatalog(
sparkSession.sharedState.externalCatalog,
sparkSession.sharedState.globalTempViewManager,
functionResourceLoader,
functionRegistry,
conf,
newHadoopConf(),
sqlParser))
}

/**
/*
* Interface exposed to the user for registering user-defined functions.
* Note that the user-defined functions must be deterministic.
*/
val udf: UDFRegistration = new UDFRegistration(functionRegistry)

/**
* Logical query plan analyzer for resolving unresolved attributes and relations.
*/
// Logical query plan analyzer for resolving unresolved attributes and relations.
val analyzer: Analyzer = {
new Analyzer(catalog, conf) {
override val extendedResolutionRules =
@@ -147,49 +70,56 @@ private[sql] class SessionState(
}
}

/**
* Logical query plan optimizer.
*/
// Logical query plan optimizer.
val optimizer: Optimizer = new SparkOptimizer(catalog, conf, experimentalMethods)

/**
* Parser that extracts expressions, plans, table identifiers etc. from SQL texts.
*/
val sqlParser: ParserInterface = new SparkSqlParser(conf)

/**
* Planner that converts optimized logical plans to physical plans.
*/
def planner: SparkPlanner =
new SparkPlanner(sparkSession.sparkContext, conf, experimentalMethods.extraStrategies)

/**
/*
* An interface to register custom [[org.apache.spark.sql.util.QueryExecutionListener]]s
* that listen for execution metrics.
*/
val listenerManager: ExecutionListenerManager = new ExecutionListenerManager

/**
* Interface to start and stop [[StreamingQuery]]s.
*/
val streamingQueryManager: StreamingQueryManager = {
new StreamingQueryManager(sparkSession)
}

private val jarClassLoader: NonClosableMutableURLClassLoader =
sparkSession.sharedState.jarClassLoader
// Interface to start and stop [[StreamingQuery]]s.
val streamingQueryManager: StreamingQueryManager = new StreamingQueryManager(sparkSession)

// Automatically extract all entries and put it in our SQLConf
// We need to call it after all of vals have been initialized.
sparkSession.sparkContext.getConf.getAll.foreach { case (k, v) =>
conf.setConfString(k, v)
}

/**
* Planner that converts optimized logical plans to physical plans.
*/
def planner: SparkPlanner =
new SparkPlanner(sparkSession.sparkContext, conf, experimentalMethods.extraStrategies)

def newHadoopConf(): Configuration = SessionState.newHadoopConf(sparkSession, conf)

def newHadoopConfWithOptions(options: Map[String, String]): Configuration = {
val hadoopConf = newHadoopConf()
options.foreach { case (k, v) =>
if ((v ne null) && k != "path" && k != "paths") {
hadoopConf.set(k, v)
}
}
hadoopConf
}

/**
* Get an identical copy of the `SessionState` and associate it with the given `SparkSession`
*/
def copy(associatedSparkSession: SparkSession): SessionState = {
new SessionState(associatedSparkSession, Some(this))
val sqlConf = conf.copy
val sqlParser: ParserInterface = new SparkSqlParser(sqlConf)

new SessionState(
sparkSession,
sqlConf,
experimentalMethods.copy,
functionRegistry.copy,
catalog.copy,
sqlParser)
}

// ------------------------------------------------------
@@ -202,7 +132,89 @@ private[sql] class SessionState(
catalog.refreshTable(sqlParser.parseTableIdentifier(tableName))
}

def addJar(path: String): Unit = {
private val jarClassLoader: NonClosableMutableURLClassLoader =
sparkSession.sharedState.jarClassLoader

def addJar(path: String): Unit = SessionState.addJar(sparkSession, path, jarClassLoader)

/**
* Analyzes the given table in the current database to generate statistics, which will be
* used in query optimizations.
*/
def analyze(tableIdent: TableIdentifier, noscan: Boolean = true): Unit = {
AnalyzeTableCommand(tableIdent, noscan).run(sparkSession)
}

}


object SessionState {

def apply(sparkSession: SparkSession): SessionState = {
apply(sparkSession, None)
}

def apply(
sparkSession: SparkSession,
conf: Option[SQLConf]): SessionState = {

// SQL-specific key-value configurations.
val sqlConf = conf.getOrElse(new SQLConf)

// Internal catalog for managing functions registered by the user.
val functionRegistry = FunctionRegistry.builtin.copy

val jarClassLoader: NonClosableMutableURLClassLoader = sparkSession.sharedState.jarClassLoader

// A class for loading resources specified by a function.
val functionResourceLoader: FunctionResourceLoader =
createFunctionResourceLoader(sparkSession, jarClassLoader)

// Parser that extracts expressions, plans, table identifiers etc. from SQL texts.
val sqlParser: ParserInterface = new SparkSqlParser(sqlConf)

// Internal catalog for managing table and database states.
val catalog = new SessionCatalog(
sparkSession.sharedState.externalCatalog,
sparkSession.sharedState.globalTempViewManager,
functionResourceLoader,
functionRegistry,
sqlConf,
newHadoopConf(sparkSession, sqlConf),
sqlParser)

new SessionState(
sparkSession,
sqlConf,
new ExperimentalMethods,
functionRegistry,
catalog,
sqlParser)
}

def createFunctionResourceLoader(
sparkSession: SparkSession,
jarClassLoader: NonClosableMutableURLClassLoader): FunctionResourceLoader = {

new FunctionResourceLoader {
override def loadResource(resource: FunctionResource): Unit = {
resource.resourceType match {
case JarResource => addJar(sparkSession, resource.uri, jarClassLoader)
case FileResource => sparkSession.sparkContext.addFile(resource.uri)
case ArchiveResource =>
throw new AnalysisException(
"Archive is not allowed to be loaded. If YARN mode is used, " +
"please use --archives options while calling spark-submit.")
}
}
}
}

def addJar(
sparkSession: SparkSession,
path: String,
jarClassLoader: NonClosableMutableURLClassLoader): Unit = {

sparkSession.sparkContext.addJar(path)

val uri = new Path(path).toUri
Expand All @@ -217,11 +229,10 @@ private[sql] class SessionState(
Thread.currentThread().setContextClassLoader(jarClassLoader)
}

/**
* Analyzes the given table in the current database to generate statistics, which will be
* used in query optimizations.
*/
def analyze(tableIdent: TableIdentifier, noscan: Boolean = true): Unit = {
AnalyzeTableCommand(tableIdent, noscan).run(sparkSession)
def newHadoopConf(sparkSession: SparkSession, conf: SQLConf): Configuration = {
val hadoopConf = new Configuration(sparkSession.sparkContext.hadoopConfiguration)
conf.getAllConfs.foreach { case (k, v) => if (v ne null) hadoopConf.set(k, v) }
hadoopConf
}

}
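
Taken together, the new entry points are the companion SessionState.apply (build a fresh state with all fields initialized directly) and the instance copy (clone an existing state for another session). A hypothetical usage sketch follows; it is not part of the commit and assumes it is compiled inside the org.apache.spark.sql.internal package against this branch, since SessionState is private[sql].

package org.apache.spark.sql.internal

import org.apache.spark.sql.SparkSession

// Hypothetical driver exercising the new apply(...) and copy(...) entry points.
object SessionStateCloneSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("session-state-sketch")
      .getOrCreate()

    // Build a fresh state directly, instead of passing a parent state around.
    val state: SessionState = SessionState(spark)

    // Clone the state and associate the clone with a different session.
    val other  = spark.newSession()
    val cloned = state.copy(other)

    println(cloned ne state)           // true: a distinct SessionState instance
    println(cloned.conf ne state.conf) // true: the SQLConf was copied, not shared
    spark.stop()
  }
}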