Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
package org.apache.spark.sql.arangodb.commons

import com.arangodb.{ArangoDB, entity}
import com.arangodb.model.OverwriteMode
import com.arangodb.{ArangoDB, entity}
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config._
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DropMalformedMode, FailFastMode, ParseMode, PermissiveMode}

import java.io.ByteArrayInputStream
import java.io.{ByteArrayInputStream, FileInputStream}
import java.security.KeyStore
import java.security.cert.CertificateFactory
import java.util
Expand Down Expand Up @@ -69,11 +69,23 @@ object ArangoDBConf {
.createWithDefault(false)

val SSL_VERIFY_HOST = "ssl.verifyHost"
val verifyHostConf: ConfigEntry[Boolean] = ConfigBuilder(SSL_VERIFY_HOST)
val sslVerifyHostConf: ConfigEntry[Boolean] = ConfigBuilder(SSL_VERIFY_HOST)
.doc("hostname verification")
.booleanConf
.createWithDefault(true)

val SSL_TRUST_STORE_PASSWORD = "ssl.trustStore.password"
val sslTrustStorePasswordConf: OptionalConfigEntry[String] = ConfigBuilder(SSL_TRUST_STORE_PASSWORD)
.doc("trustStore password")
.stringConf
.createOptional

val SSL_TRUST_STORE_PATH = "ssl.trustStore.path"
val sslTrustStorePathConf: OptionalConfigEntry[String] = ConfigBuilder(SSL_TRUST_STORE_PATH)
.doc("trustStore path")
.stringConf
.createOptional

val SSL_CERT_VALUE = "ssl.cert.value"
val sslCertValueConf: OptionalConfigEntry[String] = ConfigBuilder(SSL_CERT_VALUE)
.doc("base64 encoded certificate")
Expand All @@ -100,7 +112,7 @@ object ArangoDBConf {

val SSL_KEYSTORE_TYPE = "ssl.keystore.type"
val sslKeystoreTypeConf: ConfigEntry[String] = ConfigBuilder(SSL_KEYSTORE_TYPE)
.doc("keystore type")
.doc("keystore type, deprecated: use ssl.trustStore.type instead")
.stringConf
.createWithDefault("jks")

Expand Down Expand Up @@ -268,7 +280,9 @@ object ArangoDBConf {
CONTENT_TYPE -> contentTypeConf,
TIMEOUT -> timeoutConf,
SSL_ENABLED -> sslEnabledConf,
SSL_VERIFY_HOST -> verifyHostConf,
SSL_VERIFY_HOST -> sslVerifyHostConf,
SSL_TRUST_STORE_PASSWORD -> sslTrustStorePasswordConf,
SSL_TRUST_STORE_PATH -> sslTrustStorePathConf,
SSL_CERT_VALUE -> sslCertValueConf,
SSL_CERT_TYPE -> sslCertTypeConf,
SSL_CERT_ALIAS -> sslCertAliasConf,
Expand Down Expand Up @@ -479,7 +493,11 @@ class ArangoDBDriverConf(opts: Map[String, String]) extends ArangoDBConf(opts) {

val sslEnabled: Boolean = getConf(sslEnabledConf)

val verifyHost: Boolean = getConf(verifyHostConf)
val verifyHost: Boolean = getConf(sslVerifyHostConf)

val sslTrustStorePassword: Option[String] = getConf(sslTrustStorePasswordConf)

val sslTrustStorePath: Option[String] = getConf(sslTrustStorePathConf)

val sslCertValue: Option[String] = getConf(sslCertValueConf)

Expand All @@ -489,6 +507,7 @@ class ArangoDBDriverConf(opts: Map[String, String]) extends ArangoDBConf(opts) {

val sslAlgorithm: String = getConf(sslAlgorithmConf)

// FIXME: merge with sslTrustStoreType
val sslKeystoreType: String = getConf(sslKeystoreTypeConf)

val sslProtocol: String = getConf(sslProtocolConf)
Expand All @@ -514,21 +533,35 @@ class ArangoDBDriverConf(opts: Map[String, String]) extends ArangoDBConf(opts) {
builder
}

def getSslContext: SSLContext = sslCertValue match {
case Some(b64cert) =>
val is = new ByteArrayInputStream(Base64.getDecoder.decode(b64cert))
def getSslContext: SSLContext = {
if (sslCertValue.isDefined) {
val is = new ByteArrayInputStream(Base64.getDecoder.decode(sslCertValue.get))
val cert = CertificateFactory.getInstance(sslCertType).generateCertificate(is)
val ks = KeyStore.getInstance(sslKeystoreType)
ks.load(null) // scalastyle:ignore null
ks.setCertificateEntry(sslCertAlias, cert)
val tmf = TrustManagerFactory.getInstance(sslAlgorithm)
tmf.init(ks)
val sc = SSLContext.getInstance(sslProtocol)
sc.init(null, tmf.getTrustManagers, null) // scalastyle:ignore null
sc
case None => SSLContext.getDefault
createSslContext(ks)
} else if (sslTrustStorePath.isDefined) {
val ks = KeyStore.getInstance(sslKeystoreType)
val is = new FileInputStream(sslTrustStorePath.get)
try {
ks.load(is, sslTrustStorePassword.map(_.toCharArray).orNull)
} finally {
is.close()
}
createSslContext(ks)
} else {
SSLContext.getDefault
}
}

private def createSslContext(ks: KeyStore): SSLContext = {
val tmf = TrustManagerFactory.getInstance(sslAlgorithm)
tmf.init(ks)
val sc = SSLContext.getInstance(sslProtocol)
sc.init(null, tmf.getTrustManagers, null) // scalastyle:ignore null
sc
}

}

Expand Down
Binary file added integration-tests/src/test/resources/cert.p12
Binary file not shown.
Original file line number Diff line number Diff line change
Expand Up @@ -36,18 +36,34 @@ class SslTest {

@ParameterizedTest
@ValueSource(strings = Array("vst", "http", "http2"))
def sslTest(protocol: String): Unit = {
def sslB64CertTest(protocol: String): Unit = {
assumeTrue(protocol != "vst" || isLessThanVersion(version.getVersion, 3, 12, 0))
val df = spark.read
.format(classOf[DefaultSource].getName)
.options(options ++ Map(
ArangoDBConf.PROTOCOL -> protocol,
ArangoDBConf.COLLECTION -> "sslTest"
ArangoDBConf.COLLECTION -> "sslTest",
ArangoDBConf.SSL_CERT_VALUE -> SslTest.b64cert,
))
.load()
df.show()
}

@ParameterizedTest
@ValueSource(strings = Array("vst", "http", "http2"))
def sslTrustStorePathTest(protocol: String): Unit = {
assumeTrue(protocol != "vst" || isLessThanVersion(version.getVersion, 3, 12, 0))
val df = spark.read
.format(classOf[DefaultSource].getName)
.options(options ++ Map(
ArangoDBConf.PROTOCOL -> protocol,
ArangoDBConf.COLLECTION -> "sslTest",
ArangoDBConf.SSL_TRUST_STORE_PATH -> "src/test/resources/cert.p12",
ArangoDBConf.SSL_TRUST_STORE_PASSWORD -> "12345678"
))
.load()
df.show()
}
}

object SslTest {
Expand Down Expand Up @@ -101,7 +117,6 @@ object SslTest {
private val options = Map(
ArangoDBConf.SSL_ENABLED -> "true",
ArangoDBConf.SSL_VERIFY_HOST -> "false",
ArangoDBConf.SSL_CERT_VALUE -> b64cert,
ArangoDBConf.DB -> database,
ArangoDBConf.USER -> user,
ArangoDBConf.PASSWORD -> password,
Expand Down