Skip to content

Commit

Permalink
[SPARK-21957][SQL] Support current_user function
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?

Currently, we do not have a suitable definition of the `user` concept in Spark. We only have a `sparkUser` app widely but do not support identify or retrieve the user information from a session in STS or a runtime query execution.

`current_user()` is very popular and supported by plenty of other modern or old school databases, and also ANSI compliant.

This PR add `current_user()`  as a SQL function. And, they are the same.  In this PR, we add these functions w/o ambiguity.
1. For a normal single-threaded Spark application, clearly the `sparkUser` is always equivalent to `current_user()` .
2. For a multi-threaded Spark application, e.g. Spark thrift server, we use a `ThreadLocal` variable to store the client-side user(after authenticated) before running the query and retrieve it in the parser.

### Why are the changes needed?

`current_user()` is very popular and supported by plenty of other modern or old school databases, and also ANSI compliant.

### Does this PR introduce _any_ user-facing change?

yes, added  `current_user()`  as a SQL function
### How was this patch tested?

new tests in thrift server and sql/catalyst

Closes #32718 from yaooqinn/SPARK-21957.

Authored-by: Kent Yao <yao@apache.org>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
  • Loading branch information
yaooqinn authored and cloud-fan committed Jun 2, 2021
1 parent daf9d19 commit 345d35e
Show file tree
Hide file tree
Showing 11 changed files with 81 additions and 11 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst

object CurrentUserContext {
val CURRENT_USER: InheritableThreadLocal[String] = new InheritableThreadLocal[String] {
override protected def initialValue(): String = null
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -617,6 +617,7 @@ object FunctionRegistry {
expression[MonotonicallyIncreasingID]("monotonically_increasing_id"),
expression[CurrentDatabase]("current_database"),
expression[CurrentCatalog]("current_catalog"),
expression[CurrentUser]("current_user"),
expression[CallMethodViaReflection]("reflect"),
expression[CallMethodViaReflection]("java_method", true),
expression[SparkVersion]("version"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -283,3 +283,21 @@ case class TypeOf(child: Expression) extends UnaryExpression {

override protected def withNewChildInternal(newChild: Expression): TypeOf = copy(child = newChild)
}

// scalastyle:off line.size.limit
@ExpressionDescription(
usage = """_FUNC_() - user name of current execution context.""",
examples = """
Examples:
> SELECT _FUNC_();
mockingjay
""",
since = "3.2.0",
group = "misc_funcs")
// scalastyle:on line.size.limit
case class CurrentUser() extends LeafExpression with Unevaluable {
override def nullable: Boolean = false
override def dataType: DataType = StringType
override def prettyName: String = "current_user"
final override val nodePatterns: Seq[TreePattern] = Seq(CURRENT_LIKE)
}
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
RewriteNonCorrelatedExists,
PullOutGroupingExpressions,
ComputeCurrentTime,
GetCurrentDatabaseAndCatalog(catalogManager)) ::
ReplaceCurrentLike(catalogManager)) ::
//////////////////////////////////////////////////////////////////////////////////////////
// Optimizer rules start here
//////////////////////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -256,7 +256,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
EliminateView.ruleName ::
ReplaceExpressions.ruleName ::
ComputeCurrentTime.ruleName ::
GetCurrentDatabaseAndCatalog(catalogManager).ruleName ::
ReplaceCurrentLike(catalogManager).ruleName ::
RewriteDistinctAggregates.ruleName ::
ReplaceDeduplicateWithAggregate.ruleName ::
ReplaceIntersectWithSemiJoin.ruleName ::
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer

import scala.collection.mutable

import org.apache.spark.sql.catalyst.CurrentUserContext.CURRENT_USER
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.logical._
Expand All @@ -27,6 +28,7 @@ import org.apache.spark.sql.catalyst.trees.TreePattern._
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.connector.catalog.CatalogManager
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils


/**
Expand Down Expand Up @@ -98,17 +100,20 @@ object ComputeCurrentTime extends Rule[LogicalPlan] {
* Replaces the expression of CurrentDatabase with the current database name.
* Replaces the expression of CurrentCatalog with the current catalog name.
*/
case class GetCurrentDatabaseAndCatalog(catalogManager: CatalogManager) extends Rule[LogicalPlan] {
case class ReplaceCurrentLike(catalogManager: CatalogManager) extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = {
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
val currentNamespace = catalogManager.currentNamespace.quoted
val currentCatalog = catalogManager.currentCatalog.name()
val currentUser = Option(CURRENT_USER.get()).getOrElse(Utils.getCurrentUserName())

plan.transformAllExpressionsWithPruning(_.containsPattern(CURRENT_LIKE)) {
case CurrentDatabase() =>
Literal.create(currentNamespace, StringType)
case CurrentCatalog() =>
Literal.create(currentCatalog, StringType)
case CurrentUser() =>
Literal.create(currentUser, StringType)
}
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
<!-- Automatically generated by ExpressionsSchemaSuite -->
## Summary
- Number of queries: 352
- Number of queries: 353
- Number of expressions that missing example: 13
- Expressions missing examples: bigint,binary,boolean,date,decimal,double,float,int,smallint,string,timestamp,tinyint,window
## Schema of Built-in Functions
Expand Down Expand Up @@ -89,6 +89,7 @@
| org.apache.spark.sql.catalyst.expressions.CurrentDate | current_date | SELECT current_date() | struct<current_date():date> |
| org.apache.spark.sql.catalyst.expressions.CurrentTimeZone | current_timezone | SELECT current_timezone() | struct<current_timezone():string> |
| org.apache.spark.sql.catalyst.expressions.CurrentTimestamp | current_timestamp | SELECT current_timestamp() | struct<current_timestamp():timestamp> |
| org.apache.spark.sql.catalyst.expressions.CurrentUser | current_user | SELECT current_user() | struct<current_user():string> |
| org.apache.spark.sql.catalyst.expressions.DateAdd | date_add | SELECT date_add('2016-07-30', 1) | struct<date_add(2016-07-30, 1):date> |
| org.apache.spark.sql.catalyst.expressions.DateDiff | datediff | SELECT datediff('2009-07-31', '2009-07-30') | struct<datediff(2009-07-31, 2009-07-30):int> |
| org.apache.spark.sql.catalyst.expressions.DateFormatClass | date_format | SELECT date_format('2016-04-08', 'y') | struct<date_format(2016-04-08, y):string> |
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,11 @@ class MiscFunctionsSuite extends QueryTest with SharedSparkSession {
Row(SPARK_VERSION_SHORT + " " + SPARK_REVISION))
assert(df.schema.fieldNames === Seq("version()"))
}

test("SPARK-21957: get current_user in normal spark apps") {
val df = sql("select current_user()")
checkAnswer(df, Row(spark.sparkContext.sparkUser))
}
}

object ReflectClass {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,8 @@ class ExpressionInfoSuite extends SparkFunSuite with SharedSparkSession {
"org.apache.spark.sql.catalyst.expressions.CallMethodViaReflection",
"org.apache.spark.sql.catalyst.expressions.SparkVersion",
// Throws an error
"org.apache.spark.sql.catalyst.expressions.RaiseError")
"org.apache.spark.sql.catalyst.expressions.RaiseError",
classOf[CurrentUser].getName)

val parFuncs = new ParVector(spark.sessionState.functionRegistry.listFunction().toVector)
parFuncs.foreach { funcId =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import org.apache.hive.service.cli.operation.Operation
import org.apache.spark.SparkContext
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{SparkSession, SQLContext}
import org.apache.spark.sql.catalyst.CurrentUserContext.CURRENT_USER
import org.apache.spark.sql.catalyst.catalog.CatalogTableType
import org.apache.spark.sql.catalyst.catalog.CatalogTableType.{EXTERNAL, MANAGED, VIEW}
import org.apache.spark.sql.internal.SQLConf
Expand Down Expand Up @@ -73,10 +74,11 @@ private[hive] trait SparkOperation extends Operation with Logging {
sqlContext.sparkContext.setLocalProperty(SparkContext.SPARK_SCHEDULER_POOL, pool)
case None =>
}

CURRENT_USER.set(getParentSession.getUserName)
// run the body
f
} finally {
CURRENT_USER.remove()
// reset local properties, will also reset SPARK_SCHEDULER_POOL
sqlContext.sparkContext.setLocalProperties(originalProps)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,15 +99,16 @@ trait SharedThriftServer extends SharedSparkSession {
}
}

protected def withCLIServiceClient(f: ThriftCLIServiceClient => Unit): Unit = {
protected def withCLIServiceClient(username: String = user)
(f: ThriftCLIServiceClient => Unit): Unit = {
require(serverPort != 0, "Failed to bind an actual port for HiveThriftServer2")
val transport = mode match {
case ServerMode.binary =>
val rawTransport = new TSocket("localhost", serverPort)
PlainSaslHelper.getPlainTransport(user, "anonymous", rawTransport)
PlainSaslHelper.getPlainTransport(username, "anonymous", rawTransport)
case ServerMode.http =>
val interceptor = new HttpBasicAuthInterceptor(
user,
username,
"anonymous",
null, null, true, new util.HashMap[String, String]())
new THttpClient(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ trait ThriftServerWithSparkContextSuite extends SharedThriftServer {

test("Full stack traces as error message for jdbc or thrift client") {
val sql = "select date_sub(date'2011-11-11', '1.2')"
withCLIServiceClient { client =>
withCLIServiceClient() { client =>
val sessionHandle = client.openSession(user, "")

val confOverlay = new java.util.HashMap[java.lang.String, java.lang.String]
Expand Down Expand Up @@ -117,8 +117,20 @@ trait ThriftServerWithSparkContextSuite extends SharedThriftServer {
}
}
}
}

test("SPARK-21957: get current_user through thrift server") {
val clientUser = "storm_earth_fire_heed_my_call"
val sql = "select current_user()"

withCLIServiceClient(clientUser) { client =>
val sessionHandle = client.openSession(clientUser, "")
val confOverlay = new java.util.HashMap[java.lang.String, java.lang.String]
val opHandle = client.executeStatement(sessionHandle, sql, confOverlay)
val rowSet = client.fetchResults(opHandle)
assert(rowSet.toTRowSet.getColumns.get(0).getStringVal.getValues.get(0) === clientUser)
}
}
}

class ThriftServerWithSparkContextInBinarySuite extends ThriftServerWithSparkContextSuite {
override def mode: ServerMode.Value = ServerMode.binary
Expand Down

0 comments on commit 345d35e

Please sign in to comment.