TOREE-390: Lazily initialize spark sessions
This updates the kernel so that it lazily initializes
Spark sessions. The user can either configure and
start a session explicitly, or wait until Spark is
first referenced to start one. This shortens the time
to the first interaction with Toree.

It also adds a notification that Toree is waiting
for a Spark session to start when the getOrCreate
call takes more than 100 ms.

Closes #109.
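
In effect, the session becomes a lazily evaluated value guarded by a timed wait. A minimal, self-contained sketch of the pattern described above; `createSession` and `notifyUser` are illustrative stand-ins for `SparkSession.builder.getOrCreate` and Toree's display output, not Toree APIs:

```scala
import java.util.concurrent.{TimeUnit, TimeoutException}
import scala.concurrent.{Await, Future}
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration.Duration

object LazySessionSketch {
  // Stand-in for SparkSession.builder.config(conf).getOrCreate,
  // which can block for a while against a remote cluster.
  private def createSession(): String = { Thread.sleep(500); "session" }

  private def notifyUser(msg: String): Unit = println(msg)

  // Nothing starts until `session` is first referenced.
  lazy val session: String = {
    val pending = Future(createSession())
    try {
      // give the fast path 100 ms before telling the user anything
      Await.result(pending, Duration(100, TimeUnit.MILLISECONDS))
    } catch {
      case _: TimeoutException =>
        notifyUser("Waiting for a Spark session to start...")
        Await.result(pending, Duration.Inf)
    }
  }
}
```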
rdblue authored and lresende committed Jun 7, 2017
1 parent 6997e4d commit 5cfbc83
Showing 10 changed files with 100 additions and 218 deletions.
@@ -17,22 +17,18 @@

package org.apache.toree.kernel.api

-import java.io.{InputStream, OutputStream, PrintStream}
+import java.io.{InputStream, PrintStream}
+import java.net.URI
import com.typesafe.config.Config
import org.apache.spark.api.java.JavaSparkContext
-import org.apache.spark.{SparkConf, SparkContext}
-import org.apache.spark.sql.{SQLContext, SparkSession}
+import org.apache.spark.{SparkContext, SparkConf}
+import org.apache.spark.sql.SparkSession

/**
* Interface for the kernel API. This does not include exposed variables.
*/
trait KernelLike {

-def createSparkContext(conf: SparkConf): SparkContext
-
-def createSparkContext(master: String): SparkContext

/**
* Executes a block of code represented as a string and returns the result.
*
@@ -105,6 +101,8 @@ trait KernelLike {

def config: Config

+def addJars(uris: URI*)

def sparkContext: SparkContext

def sparkConf: SparkConf

This file was deleted.

This file was deleted.

@@ -45,7 +45,6 @@ trait ComponentInitialization {
* Initializes and registers all components (not needed by bare init).
*
* @param config The config used for initialization
-* @param appName The name of the "application" for Spark
* @param actorLoader The actor loader to use for some initialization
*/
def initializeComponents(
@@ -83,8 +82,6 @@ trait StandardComponentInitialization extends ComponentInitialization {

initializePlugins(config, pluginManager)

-initializeSparkContext(config, kernel)

interpreterManager.initializeInterpreters(kernel)

pluginManager.fireEvent(AllInterpretersReady)
@@ -97,13 +94,6 @@ trait StandardComponentInitialization extends ComponentInitialization {

}


-def initializeSparkContext(config:Config, kernel:Kernel) = {
-  if(!config.getBoolean("nosparkcontext")) {
-    kernel.createSparkContext(config.getString("spark.master"))
-  }
-}

private def initializeCommObjects(actorLoader: ActorLoader) = {
logger.debug("Constructing Comm storage")
val commStorage = new CommStorage()
84 changes: 54 additions & 30 deletions kernel/src/main/scala/org/apache/toree/kernel/api/Kernel.scala
@@ -18,7 +18,8 @@
package org.apache.toree.kernel.api

import java.io.{InputStream, PrintStream}
-import java.util.concurrent.ConcurrentHashMap
+import java.net.URI
+import java.util.concurrent.{ConcurrentHashMap, TimeUnit, TimeoutException}
+import scala.collection.mutable
import com.typesafe.config.Config
import org.apache.spark.api.java.JavaSparkContext
@@ -35,14 +36,15 @@ import org.apache.toree.kernel.protocol.v5
import org.apache.toree.kernel.protocol.v5.kernel.ActorLoader
import org.apache.toree.kernel.protocol.v5.magic.MagicParser
import org.apache.toree.kernel.protocol.v5.stream.KernelOutputStream
-import org.apache.toree.kernel.protocol.v5.{KMBuilder, KernelMessage}
+import org.apache.toree.kernel.protocol.v5.{KMBuilder, KernelMessage, MIMEType}
import org.apache.toree.magic.MagicManager
import org.apache.toree.plugins.PluginManager
-import org.apache.toree.utils.{KeyValuePairUtils, LogLike}
+import org.apache.toree.utils.LogLike
import scala.language.dynamics
import scala.reflect.runtime.universe._
-import scala.util.{DynamicVariable, Try}
-import org.apache.toree.plugins.SparkReady
+import scala.util.DynamicVariable
+import scala.concurrent.duration.Duration
+import scala.concurrent.{Future, Await}

/**
* Represents the main kernel API to be used for interaction.
@@ -61,6 +63,23 @@ class Kernel (
val pluginManager: PluginManager
) extends KernelLike with LogLike {

+/**
+ * Jars that have been added to the kernel
+ */
+private val jars = new mutable.ArrayBuffer[URI]()
+
+override def addJars(uris: URI*): Unit = {
+  uris.foreach { uri =>
+    if (uri.getScheme != "file") {
+      throw new RuntimeException("Cannot add non-local jar: " + uri)
+    }
+  }
+
+  jars ++= uris
+  interpreter.addJars(uris.map(_.toURL):_*)
+  uris.foreach(uri => sparkContext.addJar(uri.getPath))
+}
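
A hypothetical call site for the new API; the jar path here is invented, and only file-scheme URIs pass the check above:

```scala
import java.net.URI

// assumes some Kernel instance `kernel` is in scope
kernel.addJars(new URI("file:/tmp/extra-library.jar"))

// a non-local URI fails fast instead of silently misbehaving:
// kernel.addJars(new URI("http://example.com/lib.jar")) // throws RuntimeException
```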

/**
* Represents the current input stream used by the kernel for the specific
* thread.
@@ -339,30 +358,6 @@ class Kernel (
someKernelMessage.get
}

-override def createSparkContext(conf: SparkConf): SparkContext = {
-  val sconf = createSparkConf(conf)
-  val _sparkSession = SparkSession.builder.config(sconf).getOrCreate()
-
-  val sparkMaster = sconf.getOption("spark.master").getOrElse("not_set")
-  logger.info( s"Connecting to spark.master $sparkMaster")
-
-  // TODO: Convert to events
-  pluginManager.dependencyManager.add(_sparkSession.sparkContext.getConf)
-  pluginManager.dependencyManager.add(_sparkSession)
-  pluginManager.dependencyManager.add(_sparkSession.sparkContext)
-  pluginManager.dependencyManager.add(javaSparkContext(_sparkSession))
-
-  pluginManager.fireEvent(SparkReady)
-
-  _sparkSession.sparkContext
-}
-
-override def createSparkContext(
-  master: String
-): SparkContext = {
-  createSparkContext(new SparkConf().setMaster(master))
-}

// TODO: Think of a better way to test without exposing this
protected[toree] def createSparkConf(conf: SparkConf) = {

@@ -401,7 +396,36 @@ class Kernel (
interpreterManager.interpreters.get(name)
}

-override def sparkSession: SparkSession = SparkSession.builder.getOrCreate
+private lazy val defaultSparkConf: SparkConf = createSparkConf(new SparkConf())

+override def sparkSession: SparkSession = {
+  defaultSparkConf.getOption("spark.master") match {
+    case Some(master) if !master.contains("local") =>
+      // when connecting to a remote cluster, the first call to getOrCreate
+      // may create a session and take a long time, so this starts a future
+      // to get the session. if it takes longer than 100 ms, then print a
+      // message to the user that Spark is starting.
+      import scala.concurrent.ExecutionContext.Implicits.global
+      val sessionFuture = Future {
+        SparkSession.builder.config(defaultSparkConf).getOrCreate
+      }
+
+      try {
+        Await.result(sessionFuture, Duration(100, TimeUnit.MILLISECONDS))
+      } catch {
+        case timeout: TimeoutException =>
+          // getting the session is taking a long time, so assume that Spark
+          // is starting and print a message
+          display.content(
+            MIMEType.PlainText, "Waiting for a Spark session to start...")
+          Await.result(sessionFuture, Duration.Inf)
+      }
+
+    case _ =>
+      SparkSession.builder.config(defaultSparkConf).getOrCreate
+  }
+}
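
Note that the guard on spark.master above is a substring check, so "local", "local[4]", and "local[*]" all take the quiet fast path. A sketch of that decision pulled into a named helper; `needsStartupNotice` is illustrative, not part of this commit:

```scala
import org.apache.spark.SparkConf

// true when getOrCreate will contact a remote cluster, i.e. when the
// slow path warrants the "Waiting for a Spark session" notice
def needsStartupNotice(conf: SparkConf): Boolean =
  conf.getOption("spark.master").exists(master => !master.contains("local"))
```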

override def sparkContext: SparkContext = sparkSession.sparkContext
override def sparkConf: SparkConf = sparkSession.sparkContext.getConf
override def javaSparkContext: JavaSparkContext = javaSparkContext(sparkSession)
@@ -30,7 +30,7 @@ import org.apache.toree.plugins.annotations.Event


class AddDeps extends LineMagic with IncludeInterpreter
-with IncludeOutputStream with IncludeSparkContext with ArgumentParsingSupport
+with IncludeOutputStream with ArgumentParsingSupport
with IncludeDependencyDownloader with IncludeKernel
{

@@ -78,7 +78,7 @@ class AddDeps extends LineMagic with IncludeInterpreter

if (nonOptionArgs.size == 3) {
// get the jars and hold onto the paths at which they reside
-val urls = dependencyDownloader.retrieve(
+val uris = dependencyDownloader.retrieve(
groupId = nonOptionArgs.head,
artifactId = nonOptionArgs(1),
version = nonOptionArgs(2),
@@ -87,11 +87,10 @@
extraRepositories = repositoriesWithCreds,
verbose = _verbose,
trace = _trace
-).map(_.toURL)
+)

-// add the jars to the interpreter and spark context
-interpreter.addJars(urls:_*)
-urls.foreach(url => sparkContext.addJar(url.getPath))
+// pass the new Jars to the kernel
+kernel.addJars(uris:_*)
} else {
printHelp(printStream, """%AddDeps my.company artifact-id version""")
}
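
For reference, the user-facing entry point that now funnels through kernel.addJars is this magic; an invocation matching the printHelp shape above (the coordinate is just an example):

```
%AddDeps org.apache.commons commons-lang3 3.5
```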
@@ -47,7 +47,7 @@ object AddJar {
}

class AddJar
-extends LineMagic with IncludeInterpreter with IncludeSparkContext
+extends LineMagic with IncludeInterpreter
with IncludeOutputStream with DownloadSupport with ArgumentParsingSupport
with IncludeKernel with IncludePluginManager with IncludeConfig with LogLike
{
@@ -137,8 +137,7 @@ class AddJar
val plugins = pluginManager.loadPlugins(fileDownloadLocation)
pluginManager.initializePlugins(plugins)
} else {
-interpreter.addJars(fileDownloadLocation.toURI.toURL)
-sparkContext.addJar(fileDownloadLocation.getCanonicalPath)
+kernel.addJars(fileDownloadLocation.toURI)
}
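
The design point here: AddJar downloads jars to a local location first (fileDownloadLocation), so the URI handed to kernel.addJars is always file-scheme and passes the local-only check that Kernel.addJars now enforces.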
}
}
