Skip to content

Commit

Permalink
Merge pull request #113 from yaooqinn/KYUUBI-107
Browse files Browse the repository at this point in the history
[KYUUBI-107] kill yarn application fast when spark context fails to initialize
  • Loading branch information
yaooqinn committed Oct 30, 2018
2 parents 8bbc3d5 + a4cb010 commit 8a30b63
Show file tree
Hide file tree
Showing 16 changed files with 425 additions and 63 deletions.
5 changes: 5 additions & 0 deletions kyuubi-server/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,11 @@
<groupId>org.mockito</groupId>
<artifactId>mockito-core</artifactId>
</dependency>

<dependency>
<groupId>org.apache.hadoop</groupId>
<artifactId>hadoop-minicluster</artifactId>
</dependency>
</dependencies>

<build>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import java.util.HashMap
import java.util.concurrent.TimeUnit

import scala.collection.JavaConverters._
import scala.language.implicitConversions

import org.apache.spark.internal.config.{ConfigBuilder, ConfigEntry}

Expand Down Expand Up @@ -365,4 +366,9 @@ object KyuubiConf {
(kv.getKey, kv.getValue.defaultValueString)
}.toMap
}

implicit def convertBooleanConf(config: ConfigEntry[Boolean]): String = config.key
implicit def convertIntConf(config: ConfigEntry[Int]): String = config.key
implicit def convertLongConf(config: ConfigEntry[Long]): String = config.key
implicit def convertStringConf(config: ConfigEntry[String]): String = config.key
}
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class TSetIpAddressProcessor[I <: Iface](iface: Iface)
}
}

private def setUserName(in: TProtocol) = {
private def setUserName(in: TProtocol): Unit = {
val transport = in.getTransport
transport match {
case transport1: TSaslServerTransport =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class HandleIdentifier(val publicId: UUID, val secretId: UUID) {
Option(secret).map(id => new UUID(id.getLong(), id.getLong())).getOrElse(UUID.randomUUID()))

def this(tHandleId: THandleIdentifier) =
this(tHandleId.bufferForGuid(), tHandleId.bufferForSecret())
this(ByteBuffer.wrap(tHandleId.getGuid), ByteBuffer.wrap(tHandleId.getSecret))

def getPublicId: UUID = this.publicId
def getSecretId: UUID = this.secretId
Expand Down Expand Up @@ -80,5 +80,5 @@ class HandleIdentifier(val publicId: UUID, val secretId: UUID) {
true
}

override def toString: String = publicId.toString
override def toString: String = Option(publicId).map(_.toString).getOrElse("")
}
Original file line number Diff line number Diff line change
Expand Up @@ -49,24 +49,24 @@ import yaooqinn.kyuubi.utils.NamedThreadFactory
private[kyuubi] class FrontendService private(name: String, beService: BackendService)
extends AbstractService(name) with TCLIService.Iface with Runnable with Logging {

private[this] var hadoopConf: Configuration = _
private[this] var authFactory: KyuubiAuthFactory = _
private var hadoopConf: Configuration = _
private var authFactory: KyuubiAuthFactory = _

private[this] val OK_STATUS = new TStatus(TStatusCode.SUCCESS_STATUS)
private val OK_STATUS = new TStatus(TStatusCode.SUCCESS_STATUS)

private[this] var serverEventHandler: TServerEventHandler = _
private[this] var currentServerContext: ThreadLocal[ServerContext] = _
private var serverEventHandler: TServerEventHandler = _
private var currentServerContext: ThreadLocal[ServerContext] = _

private[this] var server: Option[TServer] = None
private[this] var portNum = 0
private[this] var serverIPAddress: InetAddress = _
private[this] var serverSocket: ServerSocket = _
private var server: Option[TServer] = None
private var portNum = 0
private var serverIPAddress: InetAddress = _
private var serverSocket: ServerSocket = _

private[this] val threadPoolName = name + "-Handler-Pool"
private val threadPoolName = name + "-Handler-Pool"

private[this] var isStarted = false
private var isStarted = false

private[this] var realUser: String = _
private var realUser: String = _

def this(beService: BackendService) = {
this(classOf[FrontendService].getSimpleName, beService)
Expand Down Expand Up @@ -112,16 +112,15 @@ private[kyuubi] class FrontendService private(name: String, beService: BackendSe
}

override def init(conf: SparkConf): Unit = synchronized {
this.conf = conf
hadoopConf = SparkHadoopUtil.get.newConfiguration(conf)
val serverHost = conf.get(FRONTEND_BIND_HOST.key)
val serverHost = conf.get(FRONTEND_BIND_HOST)
try {
if (serverHost.nonEmpty) {
serverIPAddress = InetAddress.getByName(serverHost)
} else {
serverIPAddress = InetAddress.getLocalHost
}
portNum = conf.get(FRONTEND_BIND_PORT.key).toInt
portNum = conf.get(FRONTEND_BIND_PORT).toInt
serverSocket = new ServerSocket(portNum, 1, serverIPAddress)
} catch {
case e: Exception => throw new ServiceException(e.getMessage + ": " + portNum, e)
Expand Down Expand Up @@ -151,11 +150,11 @@ private[kyuubi] class FrontendService private(name: String, beService: BackendSe

def getServerIPAddress: InetAddress = serverIPAddress

private[this] def isKerberosAuthMode = {
conf.get(KyuubiConf.AUTHENTICATION_METHOD.key).equalsIgnoreCase(AuthType.KERBEROS.name)
private def isKerberosAuthMode: Boolean = {
conf.get(KyuubiConf.AUTHENTICATION_METHOD).equalsIgnoreCase(AuthType.KERBEROS.name)
}

private[this] def getUserName(req: TOpenSessionReq) = {
private def getUserName(req: TOpenSessionReq): String = {
// Kerberos
if (isKerberosAuthMode) {
realUser = authFactory.getRemoteUser.orNull
Expand All @@ -168,10 +167,15 @@ private[kyuubi] class FrontendService private(name: String, beService: BackendSe
realUser = req.getUsername
}
realUser = getShortName(realUser)
getProxyUser(req.getConfiguration.asScala.toMap, getIpAddress)

if (req.getConfiguration == null) {
realUser
} else {
getProxyUser(req.getConfiguration.asScala.toMap, getIpAddress)
}
}

private[this] def getShortName(userName: String): String = {
private def getShortName(userName: String): String = {
val indexOfDomainMatch = ServiceUtils.indexOfDomainMatch(userName)
if (indexOfDomainMatch <= 0) {
userName
Expand All @@ -181,10 +185,10 @@ private[kyuubi] class FrontendService private(name: String, beService: BackendSe
}

@throws[KyuubiSQLException]
private[this] def getProxyUser(sessionConf: Map[String, String], ipAddress: String): String = {
private def getProxyUser(sessionConf: Map[String, String], ipAddress: String): String = {
Option(sessionConf).flatMap(_.get(KyuubiAuthFactory.HS2_PROXY_USER)) match {
case None => realUser
case Some(_) if !conf.get(FRONTEND_ALLOW_USER_SUBSTITUTION.key).toBoolean =>
case Some(_) if !conf.get(FRONTEND_ALLOW_USER_SUBSTITUTION).toBoolean =>
throw new KyuubiSQLException("Proxy user substitution is not allowed")
case Some(p) if !isKerberosAuthMode => p
case Some(p) => // Verify proxy user privilege of the realUser for the proxyUser
Expand All @@ -193,15 +197,15 @@ private[kyuubi] class FrontendService private(name: String, beService: BackendSe
}
}

private[this] def getIpAddress: String = {
private def getIpAddress: String = {
if (isKerberosAuthMode) {
this.authFactory.getIpAddress.orNull
} else {
TSetIpAddressProcessor.getUserIpAddress
}
}

private[this] def getMinVersion(versions: TProtocolVersion*): TProtocolVersion = {
private def getMinVersion(versions: TProtocolVersion*): TProtocolVersion = {
val values = TProtocolVersion.values
var current = values(values.length - 1).getValue
for (version <- versions) {
Expand All @@ -218,17 +222,17 @@ private[kyuubi] class FrontendService private(name: String, beService: BackendSe
}

@throws[KyuubiSQLException]
private[this] def getSessionHandle(req: TOpenSessionReq, res: TOpenSessionResp) = {
private def getSessionHandle(req: TOpenSessionReq, res: TOpenSessionResp): SessionHandle = {
val userName = getUserName(req)
val ipAddress = getIpAddress
val protocol = getMinVersion(BackendService.SERVER_VERSION, req.getClient_protocol)
val sessionHandle =
if (conf.get(FRONTEND_ENABLE_DOAS.key).toBoolean && (userName != null)) {
val confMap = Option(req.getConfiguration).map(_.asScala.toMap).getOrElse(Map.empty)
val sessionHandle = if (conf.get(FRONTEND_ENABLE_DOAS).toBoolean && (userName != null)) {
beService.openSessionWithImpersonation(
protocol, userName, req.getPassword, ipAddress, req.getConfiguration.asScala.toMap, null)
protocol, userName, req.getPassword, ipAddress, confMap, null)
} else {
beService.openSession(
protocol, userName, req.getPassword, ipAddress, req.getConfiguration.asScala.toMap)
protocol, userName, req.getPassword, ipAddress, confMap)
}
res.setServerProtocolVersion(protocol)
sessionHandle
Expand Down Expand Up @@ -567,12 +571,12 @@ private[kyuubi] class FrontendService private(name: String, beService: BackendSe
override def run(): Unit = {
try {
// Server thread pool
val minThreads = conf.get(FRONTEND_MIN_WORKER_THREADS.key).toInt
val maxThreads = conf.get(FRONTEND_MAX_WORKER_THREADS.key).toInt
val minThreads = conf.get(FRONTEND_MIN_WORKER_THREADS).toInt
val maxThreads = conf.get(FRONTEND_MAX_WORKER_THREADS).toInt
val executorService = new ThreadPoolExecutor(
minThreads,
maxThreads,
conf.getTimeAsSeconds(FRONTEND_WORKER_KEEPALIVE_TIME.key),
conf.getTimeAsSeconds(FRONTEND_WORKER_KEEPALIVE_TIME),
TimeUnit.SECONDS,
new SynchronousQueue[Runnable],
new NamedThreadFactory(threadPoolName))
Expand All @@ -584,9 +588,9 @@ private[kyuubi] class FrontendService private(name: String, beService: BackendSe
val tSocket = new TServerSocket(serverSocket)

// Server args
val maxMessageSize = conf.get(FRONTEND_MAX_MESSAGE_SIZE.key).toInt
val requestTimeout = conf.getTimeAsSeconds(FRONTEND_LOGIN_TIMEOUT.key).toInt
val beBackoffSlotLength = conf.getTimeAsMs(FRONTEND_LOGIN_BEBACKOFF_SLOT_LENGTH.key).toInt
val maxMessageSize = conf.get(FRONTEND_MAX_MESSAGE_SIZE).toInt
val requestTimeout = conf.getTimeAsSeconds(FRONTEND_LOGIN_TIMEOUT).toInt
val beBackoffSlotLength = conf.getTimeAsMs(FRONTEND_LOGIN_BEBACKOFF_SLOT_LENGTH).toInt
val args = new TThreadPoolServer.Args(tSocket)
.processorFactory(processorFactory)
.transportFactory(transportFactory)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ private[kyuubi] class SessionManager private(

val sessionHandle = kyuubiSession.getSessionHandle
handleToSession.put(sessionHandle, kyuubiSession)
KyuubiServerMonitor.getListener(username).foreach {
KyuubiServerMonitor.getListener(kyuubiSession.getUserName).foreach {
_.onSessionCreated(
kyuubiSession.getIpAddress,
sessionHandle.getSessionId.toString,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

package yaooqinn.kyuubi.spark

import java.security.PrivilegedExceptionAction
import java.util.concurrent.TimeUnit

import scala.collection.mutable.{HashSet => MHSet}
Expand All @@ -36,7 +35,7 @@ import org.apache.spark.ui.KyuubiServerTab
import yaooqinn.kyuubi.{KyuubiSQLException, Logging}
import yaooqinn.kyuubi.author.AuthzHelper
import yaooqinn.kyuubi.ui.{KyuubiServerListener, KyuubiServerMonitor}
import yaooqinn.kyuubi.utils.ReflectUtils
import yaooqinn.kyuubi.utils.{KyuubiHadoopUtil, ReflectUtils}

class SparkSessionWithUGI(
user: UserGroupInformation,
Expand Down Expand Up @@ -124,9 +123,9 @@ class SparkSessionWithUGI(
}

private[this] def getOrCreate(sessionConf: Map[String, String]): Unit = synchronized {
val totalRounds = math.max(conf.get(BACKEND_SESSION_WAIT_OTHER_TIMES.key).toInt, 15)
val totalRounds = math.max(conf.get(BACKEND_SESSION_WAIT_OTHER_TIMES).toInt, 15)
var checkRound = totalRounds
val interval = conf.getTimeAsMs(BACKEND_SESSION_WAIT_OTHER_INTERVAL.key)
val interval = conf.getTimeAsMs(BACKEND_SESSION_WAIT_OTHER_INTERVAL)
// if user's sc is being constructed by another
while (SparkSessionWithUGI.isPartiallyConstructed(userName)) {
wait(interval)
Expand All @@ -151,25 +150,30 @@ class SparkSessionWithUGI(

private[this] def create(sessionConf: Map[String, String]): Unit = {
info(s"--------- Create new SparkSession for $userName ----------")
val appName = s"KyuubiSession[$userName]@" + conf.get(FRONTEND_BIND_HOST.key)
// kyuubi|user name|canonical host name| port
val appName = Seq(
"kyuubi", userName, conf.get(FRONTEND_BIND_HOST), conf.get(FRONTEND_BIND_PORT)).mkString("|")
conf.setAppName(appName)
configureSparkConf(sessionConf)
val totalWaitTime: Long = conf.getTimeAsSeconds(BACKEND_SESSTION_INIT_TIMEOUT.key)
val totalWaitTime: Long = conf.getTimeAsSeconds(BACKEND_SESSTION_INIT_TIMEOUT)
try {
user.doAs(new PrivilegedExceptionAction[Unit] {
override def run(): Unit = {
newContext().start()
val context =
Await.result(promisedSparkContext.future, Duration(totalWaitTime, TimeUnit.SECONDS))
_sparkSession = ReflectUtils.newInstance(
classOf[SparkSession].getName,
Seq(classOf[SparkContext]),
Seq(context)).asInstanceOf[SparkSession]
}
})
KyuubiHadoopUtil.doAs(user) {
newContext().start()
val context =
Await.result(promisedSparkContext.future, Duration(totalWaitTime, TimeUnit.SECONDS))
_sparkSession = ReflectUtils.newInstance(
classOf[SparkSession].getName,
Seq(classOf[SparkContext]),
Seq(context)).asInstanceOf[SparkSession]
}
cache.set(userName, _sparkSession)
} catch {
case e: Exception =>
if (conf.getOption("spark.master").contains("yarn")) {
KyuubiHadoopUtil.doAs(user) {
KyuubiHadoopUtil.killYarnAppByName(appName)
}
}
stopContext()
val ke = new KyuubiSQLException(
s"Get SparkSession for [$userName] failed", "08S01", 1001, findCause(e))
Expand All @@ -193,9 +197,9 @@ class SparkSessionWithUGI(

try {
initialDatabase.foreach { db =>
user.doAs(new PrivilegedExceptionAction[Unit] {
override def run(): Unit = _sparkSession.sql(db)
})
KyuubiHadoopUtil.doAs(user) {
_sparkSession.sql(db)
}
}
} catch {
case e: Exception =>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package yaooqinn.kyuubi.utils

import java.security.PrivilegedExceptionAction

import scala.collection.JavaConverters._

import org.apache.hadoop.security.UserGroupInformation
import org.apache.hadoop.yarn.api.records.ApplicationReport
import org.apache.hadoop.yarn.api.records.YarnApplicationState._
import org.apache.hadoop.yarn.client.api.YarnClient
import org.apache.hadoop.yarn.conf.YarnConfiguration

import yaooqinn.kyuubi.Logging

private[kyuubi] object KyuubiHadoopUtil extends Logging {

// YarnClient is thread safe. Create once, share it across threads.
private lazy val yarnClient = {
val c = YarnClient.createYarnClient()
c.init(new YarnConfiguration())
c.start()
c
}

def killYarnApp(report: ApplicationReport): Unit = {
try {
yarnClient.killApplication(report.getApplicationId)
} catch {
case e: Exception => error("Failed to kill Application: " + report.getApplicationId, e)
}
}

def getApplications: Seq[ApplicationReport] = {
yarnClient.getApplications(Set("SPARK").asJava).asScala.filter { p =>
p.getYarnApplicationState match {
case ACCEPTED | NEW | NEW_SAVING | SUBMITTED | RUNNING => true
case _ => false
}
}
}

def killYarnAppByName(appName: String): Unit = {
getApplications.filter(app => app.getName.equals(appName)).foreach(killYarnApp)
}

def doAs[T](user: UserGroupInformation)(f: => T): T = {
user.doAs(new PrivilegedExceptionAction[T] {
override def run(): T = f
})
}
}
2 changes: 1 addition & 1 deletion kyuubi-server/src/test/resources/kyuubi-test.conf
Original file line number Diff line number Diff line change
@@ -1 +1 @@
spark.kyuubi.test 1
spark.kyuubi.test=1

0 comments on commit 8a30b63

Please sign in to comment.