diff --git a/kyuubi-server/src/main/scala/org/apache/spark/sql/SparkSQLUtils.scala b/kyuubi-server/src/main/scala/org/apache/spark/sql/SparkSQLUtils.scala index 1b3ab9ed26f..8602bf6a834 100644 --- a/kyuubi-server/src/main/scala/org/apache/spark/sql/SparkSQLUtils.scala +++ b/kyuubi-server/src/main/scala/org/apache/spark/sql/SparkSQLUtils.scala @@ -38,4 +38,8 @@ object SparkSQLUtils { def toDataFrame(sparkSession: SparkSession, plan: LogicalPlan): DataFrame = { Dataset.ofRows(sparkSession, plan) } + + def initializeMetaStoreClient(sparkSession: SparkSession): Seq[String] = { + sparkSession.sessionState.catalog.listDatabases("default") + } } diff --git a/kyuubi-server/src/main/scala/yaooqinn/kyuubi/session/SessionManager.scala b/kyuubi-server/src/main/scala/yaooqinn/kyuubi/session/SessionManager.scala index 91f8bb2995c..eddc4f11cde 100644 --- a/kyuubi-server/src/main/scala/yaooqinn/kyuubi/session/SessionManager.scala +++ b/kyuubi-server/src/main/scala/yaooqinn/kyuubi/session/SessionManager.scala @@ -255,7 +255,7 @@ private[kyuubi] class SessionManager private( ipAddress: String, sessionConf: Map[String, String], withImpersonation: Boolean): SessionHandle = { - val kyuubiSession = new KyuubiSession( + val session = new KyuubiSession( protocol, username, password, @@ -265,24 +265,21 @@ private[kyuubi] class SessionManager private( this, operationManager) info(s"Opening session for $username") - kyuubiSession.open(sessionConf) + session.open(sessionConf) - kyuubiSession.setResourcesSessionDir(resourcesRootDir) + session.setResourcesSessionDir(resourcesRootDir) if (isOperationLogEnabled) { - kyuubiSession.setOperationLogSessionDir(operationLogRootDir) + session.setOperationLogSessionDir(operationLogRootDir) } - val sessionHandle = kyuubiSession.getSessionHandle - handleToSession.put(sessionHandle, kyuubiSession) - KyuubiServerMonitor.getListener(kyuubiSession.getUserName).foreach { - _.onSessionCreated( - kyuubiSession.getIpAddress, - sessionHandle.getSessionId.toString, - kyuubiSession.getUserName) + val handle = session.getSessionHandle + handleToSession.put(handle, session) + KyuubiServerMonitor.getListener(session.getUserName).foreach { + _.onSessionCreated(session.getIpAddress, handle.getSessionId.toString, session.getUserName) } - info(s"Session [$sessionHandle] opened, current opening sessions: $getOpenSessionCount") + info(s"$username's Session [$handle] opened, current opening sessions: $getOpenSessionCount") - sessionHandle + handle } @throws[KyuubiSQLException] @@ -300,12 +297,12 @@ private[kyuubi] class SessionManager private( if (session == null) { throw new KyuubiSQLException(s"Session $sessionHandle does not exist!") } - val sessionUser = session.getUserName - KyuubiServerMonitor.getListener(sessionUser).foreach { + val user = session.getUserName + KyuubiServerMonitor.getListener(user).foreach { _.onSessionClosed(sessionHandle.getSessionId.toString) } - cacheManager.decrease(sessionUser) - info(s"Session [$sessionHandle] closed, current opening sessions: $getOpenSessionCount") + cacheManager.decrease(user) + info(s"$user's Session [$sessionHandle] closed, current opening sessions: $getOpenSessionCount") try { session.close() } finally { diff --git a/kyuubi-server/src/main/scala/yaooqinn/kyuubi/spark/SparkSessionWithUGI.scala b/kyuubi-server/src/main/scala/yaooqinn/kyuubi/spark/SparkSessionWithUGI.scala index c005b7d009f..afd36381b5f 100644 --- a/kyuubi-server/src/main/scala/yaooqinn/kyuubi/spark/SparkSessionWithUGI.scala +++ b/kyuubi-server/src/main/scala/yaooqinn/kyuubi/spark/SparkSessionWithUGI.scala @@ -28,7 +28,7 @@ import org.apache.hadoop.security.UserGroupInformation import org.apache.spark.{KyuubiSparkUtil, SparkConf, SparkContext} import org.apache.spark.KyuubiConf._ import org.apache.spark.KyuubiSparkUtil._ -import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.{SparkSession, SparkSQLUtils} import org.apache.spark.ui.KyuubiSessionTab import yaooqinn.kyuubi.{KyuubiSQLException, Logging} @@ -183,6 +183,9 @@ class SparkSessionWithUGI( getOrCreate(sessionConf) try { + doAs(user) { + SparkSQLUtils.initializeMetaStoreClient(_sparkSession) + } initialDatabase.foreach { db => doAs(user)(_sparkSession.sql(db)) } diff --git a/kyuubi-server/src/test/scala/org/apache/spark/sql/SparkSQLUtilsSuite.scala b/kyuubi-server/src/test/scala/org/apache/spark/sql/SparkSQLUtilsSuite.scala index 6109eff26f2..ea286e967c6 100644 --- a/kyuubi-server/src/test/scala/org/apache/spark/sql/SparkSQLUtilsSuite.scala +++ b/kyuubi-server/src/test/scala/org/apache/spark/sql/SparkSQLUtilsSuite.scala @@ -20,17 +20,31 @@ package org.apache.spark.sql import org.apache.spark.SparkFunSuite class SparkSQLUtilsSuite extends SparkFunSuite { + var sparkSession: SparkSession = _ - test("get user jar class loader") { - val sparkSession = SparkSession + override def beforeAll(): Unit = { + sparkSession = SparkSession .builder() .appName(classOf[SparkSQLUtilsSuite].getSimpleName) .master("local") .getOrCreate() + super.beforeAll() + } + + override def afterAll(): Unit = { + if (sparkSession != null) { + sparkSession.stop() + } + } + + test("initialize metastore client ahead") { + val dbs = SparkSQLUtils.initializeMetaStoreClient(sparkSession) + assert(dbs.contains("default")) + } + + test("get user jar class loader") { sparkSession.sql("add jar udf-test.jar") val loader = SparkSQLUtils.getUserJarClassLoader(sparkSession) assert(loader.getResource("udf-test.jar") !== null) - sparkSession.stop() } - }