Skip to content

Commit

Permalink
Merge pull request #131 from yaooqinn/KYUUBI-128
Browse files Browse the repository at this point in the history
[KYUUBI-128]add ldap authenticate support
  • Loading branch information
yaooqinn committed Dec 12, 2018
2 parents 6886529 + 942762b commit e960540
Show file tree
Hide file tree
Showing 18 changed files with 379 additions and 79 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,5 @@ hs_err_pid*
spark-warehouse/
metastore_db
derby.log
ldap

21 changes: 20 additions & 1 deletion kyuubi-server/src/main/scala/org/apache/spark/KyuubiConf.scala
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,8 @@ object KyuubiConf {
KyuubiConfigBuilder("spark.kyuubi.authentication")
.doc("Client authentication types." +
" NONE: no authentication check." +
" KERBEROS: Kerberos/GSSAPI authentication")
" KERBEROS: Kerberos/GSSAPI authentication." +
" LDAP: Lightweight Directory Access Protocol authentication.")
.stringConf
.createWithDefault("NONE")

Expand All @@ -314,6 +315,24 @@ object KyuubiConf {
.stringConf
.createWithDefault("auth")

val AUTHENTICATION_LDAP_URL: ConfigEntry[String] =
KyuubiConfigBuilder("spark.kyuubi.authentication.ldap.url")
.doc("SPACE character separated LDAP connection URL(s).")
.stringConf
.createWithDefault("")

val AUTHENTICATION_LDAP_BASEDN: ConfigEntry[String] =
KyuubiConfigBuilder("spark.kyuubi.authentication.ldap.baseDN")
.doc("LDAP base DN.")
.stringConf
.createWithDefault("")

val AUTHENTICATION_LDAP_DOMAIN: ConfigEntry[String] =
KyuubiConfigBuilder("spark.kyuubi.authentication.ldap.Domain")
.doc("LDAP base DN.")
.stringConf
.createWithDefault("")

/////////////////////////////////////////////////////////////////////////////////////////////////
// Authorization //
/////////////////////////////////////////////////////////////////////////////////////////////////
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,15 @@ package yaooqinn.kyuubi.auth

import javax.security.sasl.AuthenticationException

trait AuthMethods {
def authMethodStr: String = null
}
object AuthMethods extends Enumeration {

object AuthMethods {
type AuthMethods = Value

case object NONE extends AuthMethods {
override val authMethodStr = "NONE"
}
val NONE, LDAP = Value

def getValidAuthMethod(authMethodStr: String): AuthMethods = authMethodStr match {
case NONE.authMethodStr => NONE
case "NONE" => NONE
case "LDAP" => LDAP
case _ => throw new AuthenticationException("Not a valid authentication method")
}
}
27 changes: 7 additions & 20 deletions kyuubi-server/src/main/scala/yaooqinn/kyuubi/auth/AuthType.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,30 +19,17 @@ package yaooqinn.kyuubi.auth

import yaooqinn.kyuubi.service.ServiceException

abstract class AuthType {
def name: String = ""
override def toString: String = name
object AuthType extends Enumeration {

}

object AuthType {

case object NOSASL extends AuthType {
override val name: String = "NOSASL"
}
type AuthType = Value

case object NONE extends AuthType {
override val name: String = "NONE"
}

case object KERBEROS extends AuthType {
override val name: String = "KERBEROS"
}
val NOSASL, NONE, LDAP, KERBEROS = Value

def toAuthType(authTypeStr: String): AuthType = authTypeStr.toUpperCase match {
case NOSASL.name => NOSASL
case NONE.name => NONE
case KERBEROS.name => KERBEROS
case "NOSASL" => NOSASL
case "NONE" => NONE
case "LDAP" => LDAP
case "KERBEROS" => KERBEROS
case other => throw new ServiceException("Unsupported authentication method: " + other)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,20 @@ package yaooqinn.kyuubi.auth

import javax.security.sasl.AuthenticationException

import org.apache.spark.SparkConf

import yaooqinn.kyuubi.auth.AuthMethods.AuthMethods

/**
* This class helps select a [[PasswdAuthenticationProvider]] for a given [[AuthMethods]]
*/
object AuthenticationProviderFactory {
@throws[AuthenticationException]
def getAuthenticationProvider(method: AuthMethods): PasswdAuthenticationProvider = method match {
def getAuthenticationProvider(
method: AuthMethods,
conf: SparkConf): PasswdAuthenticationProvider = method match {
case AuthMethods.NONE => new AnonymousAuthenticationProviderImpl
case AuthMethods.LDAP => new LdapAuthenticationProviderImpl(conf)
case _ => throw new AuthenticationException("Not a valid authentication method")

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ class KyuubiAuthFactory(conf: SparkConf) extends Logging {
throw new TTransportException("Failed to start token manager", e)
}
Some(server)
case AuthType.NONE | AuthType.NOSASL => None
case AuthType.NONE | AuthType.NOSASL | AuthType.LDAP => None
case other => throw new ServiceException("Unsupported authentication method: " + other)
}

Expand All @@ -86,7 +86,7 @@ class KyuubiAuthFactory(conf: SparkConf) extends Logging {
}
case _ => authMethod match {
case AuthType.NOSASL => new TTransportFactory
case _ => PlainSaslHelper.getTransportFactory(authMethod.name)
case _ => PlainSaslHelper.getTransportFactory(authMethod.toString, conf)
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package yaooqinn.kyuubi.auth

import java.util.Hashtable
import javax.naming.{Context, NamingException}
import javax.naming.directory.InitialDirContext
import javax.security.sasl.AuthenticationException

import org.apache.commons.lang3.StringUtils
import org.apache.spark.{KyuubiConf, SparkConf}

import yaooqinn.kyuubi.service.ServiceUtils

class LdapAuthenticationProviderImpl(conf: SparkConf) extends PasswdAuthenticationProvider {

import KyuubiConf._

/**
* The authenticate method is called by the Kyuubi Server authentication layer
* to authenticate users for their requests.
* If a user is to be granted, return nothing/throw nothing.
* When a user is to be disallowed, throw an appropriate [[AuthenticationException]].
*
* @param user The username received over the connection request
* @param password The password received over the connection request
*
* @throws AuthenticationException When a user is found to be invalid by the implementation
*/
override def authenticate(user: String, password: String): Unit = {
if (StringUtils.isBlank(user)) {
throw new AuthenticationException(s"Error validating LDAP user, user is null" +
s" or contains blank space")
}

if (StringUtils.isBlank(password)) {
throw new AuthenticationException(s"Error validating LDAP user, password is null" +
s" or contains blank space")
}

val env = new Hashtable[String, Any]()
env.put(Context.INITIAL_CONTEXT_FACTORY, "com.sun.jndi.ldap.LdapCtxFactory")
env.put(Context.SECURITY_AUTHENTICATION, "simple")

conf.getOption(AUTHENTICATION_LDAP_URL).foreach(env.put(Context.PROVIDER_URL, _))

val domain = conf.get(AUTHENTICATION_LDAP_DOMAIN, "")
val u = if (!hasDomain(user) && StringUtils.isNotBlank(domain)) {
user + "@" + domain
} else {
user
}

val bindDn = conf.getOption(AUTHENTICATION_LDAP_BASEDN) match {
case Some(dn) => "uid=" + u + "," + dn
case _ => u
}

env.put(Context.SECURITY_PRINCIPAL, bindDn)
env.put(Context.SECURITY_CREDENTIALS, password)

try {
val ctx = new InitialDirContext(env)
ctx.close()
} catch {
case e: NamingException =>
throw new AuthenticationException(s"Error validating LDAP user: $bindDn", e)
}
}

private def hasDomain(userName: String): Boolean = ServiceUtils.indexOfDomainMatch(userName) > 0
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,11 @@ import javax.security.sasl.{AuthenticationException, AuthorizeCallback}
import scala.collection.JavaConverters._

import org.apache.hive.service.cli.thrift.TCLIService.Iface
import org.apache.spark.SparkConf
import org.apache.thrift.{TProcessor, TProcessorFactory}
import org.apache.thrift.transport.{TSaslServerTransport, TTransport, TTransportFactory}

import yaooqinn.kyuubi.auth.AuthMethods.AuthMethods
import yaooqinn.kyuubi.auth.PlainSaslServer.SaslPlainProvider

object PlainSaslHelper {
Expand All @@ -41,10 +43,10 @@ object PlainSaslHelper {
}

@throws[LoginException]
def getTransportFactory(authTypeStr: String): TTransportFactory = {
def getTransportFactory(authTypeStr: String, conf: SparkConf): TTransportFactory = {
val saslFactory = new TSaslServerTransport.Factory()
try {
val handler = new PlainServerCallbackHandler(authTypeStr)
val handler = new PlainServerCallbackHandler(authTypeStr, conf)
val props = Map.empty[String, String]
saslFactory.addServerDefinition("PLAIN", authTypeStr, null, props.asJava, handler)
} catch {
Expand All @@ -54,10 +56,11 @@ object PlainSaslHelper {
saslFactory
}

private class PlainServerCallbackHandler private(authMethod: AuthMethods)
private class PlainServerCallbackHandler private(authMethod: AuthMethods, conf: SparkConf)
extends CallbackHandler {
@throws[AuthenticationException]
def this(authMethodStr: String) = this(AuthMethods.getValidAuthMethod(authMethodStr))
def this(authMethodStr: String, conf: SparkConf) =
this(AuthMethods.getValidAuthMethod(authMethodStr), conf)

@throws[IOException]
@throws[UnsupportedCallbackException]
Expand All @@ -75,7 +78,7 @@ object PlainSaslHelper {
case _ => throw new UnsupportedCallbackException(callback)
}
}
val provider = AuthenticationProviderFactory.getAuthenticationProvider(authMethod)
val provider = AuthenticationProviderFactory.getAuthenticationProvider(authMethod, conf)
provider.authenticate(username, password)
if (ac != null) ac.setAuthorized(true)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ import java.util.{ArrayDeque => JDeque, Map => JMap}
import javax.security.auth.callback._
import javax.security.sasl.{AuthorizeCallback, SaslException, SaslServer, SaslServerFactory}

import yaooqinn.kyuubi.auth.AuthMethods.AuthMethods

object PlainSaslServer {

val PLAIN_METHOD = "PLAIN"
Expand Down Expand Up @@ -70,7 +72,7 @@ class PlainSaslServer(handler: CallbackHandler, authMethod: AuthMethods) extends
case 0 =>
tokenList.addLast(messageToken.toString)
messageToken.setLength(0)
case b: Byte => messageToken.append(b)
case b: Byte => messageToken.append(b.toChar)
}
tokenList.addLast(messageToken.toString)
// validate response
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ private[kyuubi] class FrontendService private(name: String, beService: BackendSe
def getServerIPAddress: InetAddress = serverIPAddress

private def isKerberosAuthMode: Boolean = {
conf.get(KyuubiConf.AUTHENTICATION_METHOD).equalsIgnoreCase(AuthType.KERBEROS.name)
conf.get(KyuubiConf.AUTHENTICATION_METHOD).equalsIgnoreCase(AuthType.KERBEROS.toString)
}

private def getUserName(req: TOpenSessionReq): String = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,17 @@ import javax.security.sasl.AuthenticationException

import org.apache.spark.SparkFunSuite

import yaooqinn.kyuubi.auth.AuthMethods._

class AuthMethodsSuite extends SparkFunSuite {

test("testGetValidAuthMethod") {
assert(new AuthMethods{}.authMethodStr === null)
assert(AuthMethods.getValidAuthMethod("NONE") === AuthMethods.NONE)
assert(AuthMethods.getValidAuthMethod(AuthMethods.NONE.authMethodStr) === AuthMethods.NONE)
test("get valid auth method") {
assert(getValidAuthMethod("NONE") === NONE)
assert(getValidAuthMethod(NONE.toString) === NONE)
assert(getValidAuthMethod("LDAP") === LDAP)
assert(getValidAuthMethod(LDAP.toString) === LDAP)

assert(AuthMethods.NONE.authMethodStr === "NONE")
intercept[AuthenticationException](AuthMethods.getValidAuthMethod("ELSE"))
assert(NONE.toString === "NONE")
intercept[AuthenticationException](getValidAuthMethod("ELSE"))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,27 +19,24 @@ package yaooqinn.kyuubi.auth

import org.apache.spark.SparkFunSuite

import yaooqinn.kyuubi.auth.AuthType._
import yaooqinn.kyuubi.service.ServiceException

class AuthTypeSuite extends SparkFunSuite {

test("test name") {
assert(new AuthType {}.name === "")
assert(AuthType.NONE.name === "NONE")
assert(AuthType.KERBEROS.name === "KERBEROS")
assert(AuthType.NOSASL.name === "NOSASL")
}

test("to auth type") {
assert(AuthType.toAuthType("NONE") === AuthType.NONE)
assert(AuthType.toAuthType(AuthType.NONE.name) === AuthType.NONE)
assert(toAuthType("NOSASL") === NOSASL)
assert(toAuthType(NOSASL.toString) === NOSASL)

assert(toAuthType("NONE") === NONE)
assert(toAuthType(NONE.toString) === NONE)

assert(AuthType.toAuthType("KERBEROS") === AuthType.KERBEROS)
assert(AuthType.toAuthType(AuthType.KERBEROS.name) === AuthType.KERBEROS)
assert(toAuthType("LDAP") === LDAP)
assert(toAuthType(LDAP.toString) === LDAP)

assert(AuthType.toAuthType("NOSASL") === AuthType.NOSASL)
assert(AuthType.toAuthType(AuthType.NOSASL.name) === AuthType.NOSASL)
assert(toAuthType("KERBEROS") === KERBEROS)
assert(toAuthType(KERBEROS.toString) === KERBEROS)

intercept[ServiceException](AuthType.toAuthType("ELSE"))
intercept[ServiceException](toAuthType("ELSE"))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,22 @@ package yaooqinn.kyuubi.auth

import javax.security.sasl.AuthenticationException

import org.apache.spark.SparkFunSuite
import org.apache.spark.{SparkConf, SparkFunSuite}

class AuthenticationProviderFactorySuite extends SparkFunSuite {

test("testGetAuthenticationProvider") {
AuthenticationProviderFactory
.getAuthenticationProvider(AuthMethods.NONE).authenticate("test", "test")
intercept[AuthenticationException](AuthenticationProviderFactory
.getAuthenticationProvider(new AuthMethods {
override def authMethodStr: String = "TEST"
}))
val conf = new SparkConf()
val anonymous = AuthenticationProviderFactory.getAuthenticationProvider(AuthMethods.NONE, conf)
anonymous.authenticate("test", "test")

val ldap = AuthenticationProviderFactory.getAuthenticationProvider(AuthMethods.LDAP, conf)
val exception = intercept[AuthenticationException](ldap.authenticate("test", "test"))
assert(exception.getMessage.contains("Error validating LDAP user:"))

val e2 = intercept[AuthenticationException](
AuthenticationProviderFactory.getAuthenticationProvider(null, conf))
assert(e2.getMessage === "Not a valid authentication method")
}

}

0 comments on commit e960540

Please sign in to comment.