diff --git a/arangodb-spark-commons/src/main/scala/org/apache/spark/sql/arangodb/commons/ArangoDBConf.scala b/arangodb-spark-commons/src/main/scala/org/apache/spark/sql/arangodb/commons/ArangoDBConf.scala index ca36836..7f538c0 100644 --- a/arangodb-spark-commons/src/main/scala/org/apache/spark/sql/arangodb/commons/ArangoDBConf.scala +++ b/arangodb-spark-commons/src/main/scala/org/apache/spark/sql/arangodb/commons/ArangoDBConf.scala @@ -1,13 +1,13 @@ package org.apache.spark.sql.arangodb.commons -import com.arangodb.{ArangoDB, entity} import com.arangodb.model.OverwriteMode +import com.arangodb.{ArangoDB, entity} import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DropMalformedMode, FailFastMode, ParseMode, PermissiveMode} -import java.io.ByteArrayInputStream +import java.io.{ByteArrayInputStream, FileInputStream} import java.security.KeyStore import java.security.cert.CertificateFactory import java.util @@ -69,11 +69,23 @@ object ArangoDBConf { .createWithDefault(false) val SSL_VERIFY_HOST = "ssl.verifyHost" - val verifyHostConf: ConfigEntry[Boolean] = ConfigBuilder(SSL_VERIFY_HOST) + val sslVerifyHostConf: ConfigEntry[Boolean] = ConfigBuilder(SSL_VERIFY_HOST) .doc("hostname verification") .booleanConf .createWithDefault(true) + val SSL_TRUST_STORE_PASSWORD = "ssl.trustStore.password" + val sslTrustStorePasswordConf: OptionalConfigEntry[String] = ConfigBuilder(SSL_TRUST_STORE_PASSWORD) + .doc("trustStore password") + .stringConf + .createOptional + + val SSL_TRUST_STORE_PATH = "ssl.trustStore.path" + val sslTrustStorePathConf: OptionalConfigEntry[String] = ConfigBuilder(SSL_TRUST_STORE_PATH) + .doc("trustStore path") + .stringConf + .createOptional + val SSL_CERT_VALUE = "ssl.cert.value" val sslCertValueConf: OptionalConfigEntry[String] = ConfigBuilder(SSL_CERT_VALUE) .doc("base64 encoded certificate") @@ -100,7 +112,7 @@ object ArangoDBConf { val SSL_KEYSTORE_TYPE = "ssl.keystore.type" val sslKeystoreTypeConf: ConfigEntry[String] = ConfigBuilder(SSL_KEYSTORE_TYPE) - .doc("keystore type") + .doc("keystore type, deprecated: use ssl.trustStore.type instead") .stringConf .createWithDefault("jks") @@ -268,7 +280,9 @@ object ArangoDBConf { CONTENT_TYPE -> contentTypeConf, TIMEOUT -> timeoutConf, SSL_ENABLED -> sslEnabledConf, - SSL_VERIFY_HOST -> verifyHostConf, + SSL_VERIFY_HOST -> sslVerifyHostConf, + SSL_TRUST_STORE_PASSWORD -> sslTrustStorePasswordConf, + SSL_TRUST_STORE_PATH -> sslTrustStorePathConf, SSL_CERT_VALUE -> sslCertValueConf, SSL_CERT_TYPE -> sslCertTypeConf, SSL_CERT_ALIAS -> sslCertAliasConf, @@ -479,7 +493,11 @@ class ArangoDBDriverConf(opts: Map[String, String]) extends ArangoDBConf(opts) { val sslEnabled: Boolean = getConf(sslEnabledConf) - val verifyHost: Boolean = getConf(verifyHostConf) + val verifyHost: Boolean = getConf(sslVerifyHostConf) + + val sslTrustStorePassword: Option[String] = getConf(sslTrustStorePasswordConf) + + val sslTrustStorePath: Option[String] = getConf(sslTrustStorePathConf) val sslCertValue: Option[String] = getConf(sslCertValueConf) @@ -489,6 +507,7 @@ class ArangoDBDriverConf(opts: Map[String, String]) extends ArangoDBConf(opts) { val sslAlgorithm: String = getConf(sslAlgorithmConf) + // FIXME: merge with sslTrustStoreType val sslKeystoreType: String = getConf(sslKeystoreTypeConf) val sslProtocol: String = getConf(sslProtocolConf) @@ -514,21 +533,35 @@ class ArangoDBDriverConf(opts: Map[String, String]) extends ArangoDBConf(opts) { builder } - def getSslContext: SSLContext = sslCertValue match { - case Some(b64cert) => - val is = new ByteArrayInputStream(Base64.getDecoder.decode(b64cert)) + def getSslContext: SSLContext = { + if (sslCertValue.isDefined) { + val is = new ByteArrayInputStream(Base64.getDecoder.decode(sslCertValue.get)) val cert = CertificateFactory.getInstance(sslCertType).generateCertificate(is) val ks = KeyStore.getInstance(sslKeystoreType) ks.load(null) // scalastyle:ignore null ks.setCertificateEntry(sslCertAlias, cert) - val tmf = TrustManagerFactory.getInstance(sslAlgorithm) - tmf.init(ks) - val sc = SSLContext.getInstance(sslProtocol) - sc.init(null, tmf.getTrustManagers, null) // scalastyle:ignore null - sc - case None => SSLContext.getDefault + createSslContext(ks) + } else if (sslTrustStorePath.isDefined) { + val ks = KeyStore.getInstance(sslKeystoreType) + val is = new FileInputStream(sslTrustStorePath.get) + try { + ks.load(is, sslTrustStorePassword.map(_.toCharArray).orNull) + } finally { + is.close() + } + createSslContext(ks) + } else { + SSLContext.getDefault + } } + private def createSslContext(ks: KeyStore): SSLContext = { + val tmf = TrustManagerFactory.getInstance(sslAlgorithm) + tmf.init(ks) + val sc = SSLContext.getInstance(sslProtocol) + sc.init(null, tmf.getTrustManagers, null) // scalastyle:ignore null + sc + } } diff --git a/integration-tests/src/test/resources/cert.p12 b/integration-tests/src/test/resources/cert.p12 new file mode 100644 index 0000000..1e9fa4f Binary files /dev/null and b/integration-tests/src/test/resources/cert.p12 differ diff --git a/integration-tests/src/test/scala/org/apache/spark/sql/arangodb/datasource/SslTest.scala b/integration-tests/src/test/scala/org/apache/spark/sql/arangodb/datasource/SslTest.scala index 60d9f4f..c8e7ee9 100644 --- a/integration-tests/src/test/scala/org/apache/spark/sql/arangodb/datasource/SslTest.scala +++ b/integration-tests/src/test/scala/org/apache/spark/sql/arangodb/datasource/SslTest.scala @@ -36,18 +36,34 @@ class SslTest { @ParameterizedTest @ValueSource(strings = Array("vst", "http", "http2")) - def sslTest(protocol: String): Unit = { + def sslB64CertTest(protocol: String): Unit = { assumeTrue(protocol != "vst" || isLessThanVersion(version.getVersion, 3, 12, 0)) val df = spark.read .format(classOf[DefaultSource].getName) .options(options ++ Map( ArangoDBConf.PROTOCOL -> protocol, - ArangoDBConf.COLLECTION -> "sslTest" + ArangoDBConf.COLLECTION -> "sslTest", + ArangoDBConf.SSL_CERT_VALUE -> SslTest.b64cert, )) .load() df.show() } + @ParameterizedTest + @ValueSource(strings = Array("vst", "http", "http2")) + def sslTrustStorePathTest(protocol: String): Unit = { + assumeTrue(protocol != "vst" || isLessThanVersion(version.getVersion, 3, 12, 0)) + val df = spark.read + .format(classOf[DefaultSource].getName) + .options(options ++ Map( + ArangoDBConf.PROTOCOL -> protocol, + ArangoDBConf.COLLECTION -> "sslTest", + ArangoDBConf.SSL_TRUST_STORE_PATH -> "src/test/resources/cert.p12", + ArangoDBConf.SSL_TRUST_STORE_PASSWORD -> "12345678" + )) + .load() + df.show() + } } object SslTest { @@ -101,7 +117,6 @@ object SslTest { private val options = Map( ArangoDBConf.SSL_ENABLED -> "true", ArangoDBConf.SSL_VERIFY_HOST -> "false", - ArangoDBConf.SSL_CERT_VALUE -> b64cert, ArangoDBConf.DB -> database, ArangoDBConf.USER -> user, ArangoDBConf.PASSWORD -> password,