From 1354d2d0de550af1f951a199332cd648644c8c48 Mon Sep 17 00:00:00 2001
From: Gabor Somogyi
Date: Thu, 9 Apr 2020 09:20:02 -0700
Subject: [PATCH] [SPARK-31021][SQL] Support MariaDB Kerberos login in JDBC connector

### What changes were proposed in this pull request?
When loading a DataFrame from a JDBC data source with Kerberos authentication, remote executors (in yarn-client/cluster and similar modes) fail to establish a connection because they have no Kerberos ticket and no way to generate one. This is a real issue when trying to ingest data from kerberized data sources (SQL Server, Oracle) in enterprise environments where exposing simple-authentication access is not an option due to IT policy.

In this PR I've added MariaDB support (other databases will come in later PRs).

What this PR contains:
* Introduced `SecureConnectionProvider` and added basic security functionality (see the sketch after this list)
* Added `MariaDBConnectionProvider`
* Added `MariaDBConnectionProviderSuite`
* Added `MariaDBKrbIntegrationSuite` docker integration test
* Added some missing code documentation
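Under the hood, the new `SecureConnectionProvider.JDBCConfiguration` installs a keytab-based JAAS entry before the driver opens a connection. A minimal standalone sketch of that setup, mirroring what the integration suite's `setAuthentication` does in this patch (the keytab path and principal below are placeholders):

```scala
import javax.security.auth.login.Configuration

import org.apache.spark.sql.execution.datasources.jdbc.connection.SecureConnectionProvider

// Register a JAAS entry named "Krb5ConnectorContext" (the application name the
// MariaDB driver looks up) that logs in from a keytab instead of a ticket cache.
// Lookups of any other entry name are delegated to the parent configuration.
val config = new SecureConnectionProvider.JDBCConfiguration(
  Configuration.getConfiguration, // parent configuration to delegate to
  "Krb5ConnectorContext",         // JAAS application entry name
  "/path/to/mariadb.keytab",      // placeholder keytab path
  "user@EXAMPLE.COM")             // placeholder principal
Configuration.setConfiguration(config)
```

The providers apply this lazily through `setAuthenticationConfigIfNeeded()`, so an already-present keytab entry is left untouched.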
### Why are the changes needed?
JDBC Kerberos support is missing.

### Does this PR introduce any user-facing change?
Yes, users are now able to connect to MariaDB using Kerberos.
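A minimal sketch of the user-facing path (the URL, table name, and keytab location are placeholders; `keytab` and `principal` are the options exercised by this patch's tests):

```scala
// Read from a kerberized MariaDB instance. The keytab and principal are shipped
// to the executors, which log in with them, so no ticket cache is needed there.
val df = spark.read.format("jdbc")
  .option("url", "jdbc:mysql://mariadb.example.com:3306/mysql") // placeholder host
  .option("dbtable", "bar")                                     // placeholder table
  .option("keytab", "/path/to/mariadb.keytab")                  // placeholder keytab
  .option("principal", "user@EXAMPLE.COM")                      // placeholder principal
  .load()
```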
### How was this patch tested?
* Additional + existing unit tests
* Additional + existing integration tests
* Manual test on a cluster

Closes #28019 from gaborgsomogyi/SPARK-31021.

Authored-by: Gabor Somogyi
Signed-off-by: Marcelo Vanzin
---
 external/docker-integration-tests/pom.xml     |  4 +-
 .../resources/mariadb_docker_entrypoint.sh    | 24 ++++++
 .../src/test/resources/mariadb_krb_setup.sh   | 20 +++++
 .../sql/jdbc/DockerJDBCIntegrationSuite.scala | 22 ++++--
 .../jdbc/DockerKrbJDBCIntegrationSuite.scala  | 75 +++++++++++++++++-
 .../sql/jdbc/MariaDBKrbIntegrationSuite.scala | 67 ++++++++++++++++
 .../jdbc/MsSqlServerIntegrationSuite.scala    |  2 -
 .../sql/jdbc/MySQLIntegrationSuite.scala      |  1 -
 .../sql/jdbc/OracleIntegrationSuite.scala     |  1 -
 .../sql/jdbc/PostgresIntegrationSuite.scala   |  1 -
 .../jdbc/PostgresKrbIntegrationSuite.scala    | 76 +------------------
 pom.xml                                       |  6 ++
 sql/core/pom.xml                              |  4 +-
 .../jdbc/connection/ConnectionProvider.scala  |  7 ++
 .../MariaDBConnectionProvider.scala           | 54 +++++++++++++
 .../PostgresConnectionProvider.scala          | 50 ++----------
 .../connection/SecureConnectionProvider.scala | 75 ++++++++++++++++++
 .../ConnectionProviderSuiteBase.scala         | 69 +++++++++++++++++
 .../MariaDBConnectionProviderSuite.scala      | 27 +++++++
 .../PostgresConnectionProviderSuite.scala     | 61 +--------------
 20 files changed, 457 insertions(+), 189 deletions(-)
 create mode 100755 external/docker-integration-tests/src/test/resources/mariadb_docker_entrypoint.sh
 create mode 100755 external/docker-integration-tests/src/test/resources/mariadb_krb_setup.sh
 create mode 100644 external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MariaDBKrbIntegrationSuite.scala
 create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/MariaDBConnectionProvider.scala
 create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/SecureConnectionProvider.scala
 create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/ConnectionProviderSuiteBase.scala
 create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/MariaDBConnectionProviderSuite.scala

diff --git a/external/docker-integration-tests/pom.xml b/external/docker-integration-tests/pom.xml
index 8743d72b887e1..3b7bd2a71d2d2 100644
--- a/external/docker-integration-tests/pom.xml
+++ b/external/docker-integration-tests/pom.xml
@@ -121,8 +121,8 @@
       <scope>test</scope>
     </dependency>
     <dependency>
-      <groupId>mysql</groupId>
-      <artifactId>mysql-connector-java</artifactId>
+      <groupId>org.mariadb.jdbc</groupId>
+      <artifactId>mariadb-java-client</artifactId>
       <scope>test</scope>
     </dependency>
diff --git a/external/docker-integration-tests/src/test/resources/mariadb_docker_entrypoint.sh b/external/docker-integration-tests/src/test/resources/mariadb_docker_entrypoint.sh
new file mode 100755
index 0000000000000..00885a3b62327
--- /dev/null
+++ b/external/docker-integration-tests/src/test/resources/mariadb_docker_entrypoint.sh
@@ -0,0 +1,24 @@
+#!/usr/bin/env bash
+
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+dpkg-divert --add /bin/systemctl && ln -sT /bin/true /bin/systemctl
+apt update
+apt install -y mariadb-plugin-gssapi-server
+echo "gssapi_keytab_path=/docker-entrypoint-initdb.d/mariadb.keytab" >> /etc/mysql/mariadb.conf.d/auth_gssapi.cnf
+echo "gssapi_principal_name=mariadb/__IP_ADDRESS_REPLACE_ME__@EXAMPLE.COM" >> /etc/mysql/mariadb.conf.d/auth_gssapi.cnf
+docker-entrypoint.sh mysqld
diff --git a/external/docker-integration-tests/src/test/resources/mariadb_krb_setup.sh b/external/docker-integration-tests/src/test/resources/mariadb_krb_setup.sh
new file mode 100755
index 0000000000000..e97be805b4592
--- /dev/null
+++ b/external/docker-integration-tests/src/test/resources/mariadb_krb_setup.sh
@@ -0,0 +1,20 @@
+#!/usr/bin/env bash
+
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+mysql -u root -p'rootpass' -e 'CREATE USER "mariadb/__IP_ADDRESS_REPLACE_ME__@EXAMPLE.COM" IDENTIFIED WITH gssapi;'
+mysql -u root -p'rootpass' -D mysql -e 'GRANT ALL PRIVILEGES ON *.* TO "mariadb/__IP_ADDRESS_REPLACE_ME__@EXAMPLE.COM";'
diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DockerJDBCIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DockerJDBCIntegrationSuite.scala
index cd26fb3628151..376dd4646608c 100644
--- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DockerJDBCIntegrationSuite.scala
+++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DockerJDBCIntegrationSuite.scala
@@ -58,10 +58,19 @@ abstract class DatabaseOnDocker {
    */
   def getJdbcUrl(ip: String, port: Int): String
 
+  /**
+   * Optional entry point when container starts
+   *
+   * The startup process is passed to the entry point as a parameter, which may or may not be
+   * honored during startup. Prefer the entry point over the startup process when a command must
+   * always be executed or the initialization order has to change.
+   */
+  def getEntryPoint: Option[String] = None
+
   /**
    * Optional process to run when container starts
    */
-  def getStartupProcessName: Option[String]
+  def getStartupProcessName: Option[String] = None
 
   /**
    * Optional step before container starts
@@ -77,6 +86,7 @@ abstract class DockerJDBCIntegrationSuite extends SharedSparkSession with Eventu
   val db: DatabaseOnDocker
 
   private var docker: DockerClient = _
+  protected var externalPort: Int = _
   private var containerId: String = _
   protected var jdbcUrl: String = _
 
@@ -101,7 +111,7 @@ abstract class DockerJDBCIntegrationSuite extends SharedSparkSession with Eventu
       docker.pull(db.imageName)
     }
     // Configure networking (necessary for boot2docker / Docker Machine)
-    val externalPort: Int = {
+    externalPort = {
       val sock = new ServerSocket(0)
       val port = sock.getLocalPort
       sock.close()
@@ -118,9 +128,11 @@ abstract class DockerJDBCIntegrationSuite extends SharedSparkSession with Eventu
       .networkDisabled(false)
       .env(db.env.map { case (k, v) => s"$k=$v" }.toSeq.asJava)
       .exposedPorts(s"${db.jdbcPort}/tcp")
-    if(db.getStartupProcessName.isDefined) {
-      containerConfigBuilder
-        .cmd(db.getStartupProcessName.get)
+    if (db.getEntryPoint.isDefined) {
+      containerConfigBuilder.entrypoint(db.getEntryPoint.get)
+    }
+    if (db.getStartupProcessName.isDefined) {
+      containerConfigBuilder.cmd(db.getStartupProcessName.get)
     }
     db.beforeContainerStart(hostConfigBuilder, containerConfigBuilder)
     containerConfigBuilder.hostConfig(hostConfigBuilder.build())
diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DockerKrbJDBCIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DockerKrbJDBCIntegrationSuite.scala
index 583d8108c716c..009b4a2b1b32e 100644
--- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DockerKrbJDBCIntegrationSuite.scala
+++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DockerKrbJDBCIntegrationSuite.scala
@@ -18,17 +18,22 @@
 package org.apache.spark.sql.jdbc
 
 import java.io.{File, FileInputStream, FileOutputStream}
+import java.sql.Connection
+import java.util.Properties
 import javax.security.auth.login.Configuration
 
 import scala.io.Source
 
 import org.apache.hadoop.minikdc.MiniKdc
 
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.types.StringType
 import org.apache.spark.util.{SecurityUtils, Utils}
 
 abstract class DockerKrbJDBCIntegrationSuite extends DockerJDBCIntegrationSuite {
   private var kdc: MiniKdc = _
-  protected var workDir: File = _
+  protected var entryPointDir: File = _
+  protected var initDbDir: File = _
   protected val userName: String
   protected var principal: String = _
   protected val keytabFileName: String
@@ -46,8 +51,9 @@ abstract class DockerKrbJDBCIntegrationSuite extends DockerJDBCIntegrationSuite
 
     principal = s"$userName@${kdc.getRealm}"
 
-    workDir = Utils.createTempDir()
-    val keytabFile = new File(workDir, keytabFileName)
+    entryPointDir = Utils.createTempDir()
+    initDbDir = Utils.createTempDir()
+    val keytabFile = new File(initDbDir, keytabFileName)
     keytabFullPath = keytabFile.getAbsolutePath
     kdc.createPrincipal(keytabFile, userName)
     logInfo(s"Created keytab file: $keytabFullPath")
@@ -62,6 +68,7 @@ abstract class DockerKrbJDBCIntegrationSuite extends DockerJDBCIntegrationSuite
     try {
       if (kdc != null) {
         kdc.stop()
+        kdc = null
       }
       Configuration.setConfiguration(null)
       SecurityUtils.setGlobalKrbDebug(false)
@@ -71,7 +78,7 @@ abstract class DockerKrbJDBCIntegrationSuite extends DockerJDBCIntegrationSuite
   }
 
   protected def copyExecutableResource(
-      fileName: String, dir: File, processLine: String => String) = {
+      fileName: String, dir: File, processLine: String => String = identity) = {
     val newEntry = new File(dir.getAbsolutePath, fileName)
     newEntry.createNewFile()
     Utils.tryWithResource(
@@ -91,4 +98,64 @@ abstract class DockerKrbJDBCIntegrationSuite extends DockerJDBCIntegrationSuite
     logInfo(s"Created executable resource file: ${newEntry.getAbsolutePath}")
     newEntry
   }
+
+  override def dataPreparation(conn: Connection): Unit = {
+    conn.prepareStatement("CREATE TABLE bar (c0 text)").executeUpdate()
+    conn.prepareStatement("INSERT INTO bar VALUES ('hello')").executeUpdate()
+  }
+
+  test("Basic read test in query option") {
+    // This makes sure Spark must do authentication
+    Configuration.setConfiguration(null)
+
+    val expectedResult = Set("hello").map(Row(_))
+
+    val query = "SELECT c0 FROM bar"
+    // query option to pass on the query string.
+    val df = spark.read.format("jdbc")
+      .option("url", jdbcUrl)
+      .option("keytab", keytabFullPath)
+      .option("principal", principal)
+      .option("query", query)
+      .load()
+    assert(df.collect().toSet === expectedResult)
+  }
+
+  test("Basic read test in create table path") {
+    // This makes sure Spark must do authentication
+    Configuration.setConfiguration(null)
+
+    val expectedResult = Set("hello").map(Row(_))
+
+    val query = "SELECT c0 FROM bar"
+    // query option in the create table path.
+    sql(
+      s"""
+         |CREATE OR REPLACE TEMPORARY VIEW queryOption
+         |USING org.apache.spark.sql.jdbc
+         |OPTIONS (url '$jdbcUrl', query '$query', keytab '$keytabFullPath', principal '$principal')
+       """.stripMargin.replaceAll("\n", " "))
+    assert(sql("select c0 from queryOption").collect().toSet === expectedResult)
+  }
+
+  test("Basic write test") {
+    // This makes sure Spark must do authentication
+    Configuration.setConfiguration(null)
+
+    val props = new Properties
+    props.setProperty("keytab", keytabFullPath)
+    props.setProperty("principal", principal)
+
+    val tableName = "write_test"
+    sqlContext.createDataFrame(Seq(("foo", "bar")))
+      .write.jdbc(jdbcUrl, tableName, props)
+    val df = sqlContext.read.jdbc(jdbcUrl, tableName, props)
+
+    val schema = df.schema
+    assert(schema.map(_.dataType).toSeq === Seq(StringType, StringType))
+    val rows = df.collect()
+    assert(rows.length === 1)
+    assert(rows(0).getString(0) === "foo")
+    assert(rows(0).getString(1) === "bar")
+  }
 }
diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MariaDBKrbIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MariaDBKrbIntegrationSuite.scala
new file mode 100644
index 0000000000000..7c1adc990bab3
--- /dev/null
+++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MariaDBKrbIntegrationSuite.scala
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.jdbc
+
+import javax.security.auth.login.Configuration
+
+import com.spotify.docker.client.messages.{ContainerConfig, HostConfig}
+
+import org.apache.spark.sql.execution.datasources.jdbc.connection.SecureConnectionProvider
+import org.apache.spark.tags.DockerTest
+
+@DockerTest
+class MariaDBKrbIntegrationSuite extends DockerKrbJDBCIntegrationSuite {
+  override protected val userName = s"mariadb/$dockerIp"
+  override protected val keytabFileName = "mariadb.keytab"
+
+  override val db = new DatabaseOnDocker {
+    override val imageName = "mariadb:10.4"
+    override val env = Map(
+      "MYSQL_ROOT_PASSWORD" -> "rootpass"
+    )
+    override val usesIpc = false
+    override val jdbcPort = 3306
+
+    override def getJdbcUrl(ip: String, port: Int): String =
+      s"jdbc:mysql://$ip:$port/mysql?user=$principal"
+
+    override def getEntryPoint: Option[String] =
+      Some("/docker-entrypoint/mariadb_docker_entrypoint.sh")
+
+    override def beforeContainerStart(
+        hostConfigBuilder: HostConfig.Builder,
+        containerConfigBuilder: ContainerConfig.Builder): Unit = {
+      def replaceIp(s: String): String = s.replace("__IP_ADDRESS_REPLACE_ME__", dockerIp)
+      copyExecutableResource("mariadb_docker_entrypoint.sh", entryPointDir, replaceIp)
+      copyExecutableResource("mariadb_krb_setup.sh", initDbDir, replaceIp)
+
+      hostConfigBuilder.appendBinds(
+        HostConfig.Bind.from(entryPointDir.getAbsolutePath)
+          .to("/docker-entrypoint").readOnly(true).build(),
+        HostConfig.Bind.from(initDbDir.getAbsolutePath)
+          .to("/docker-entrypoint-initdb.d").readOnly(true).build()
+      )
+    }
+  }
+
+  override protected def setAuthentication(keytabFile: String, principal: String): Unit = {
+    val config = new SecureConnectionProvider.JDBCConfiguration(
+      Configuration.getConfiguration, "Krb5ConnectorContext", keytabFile, principal)
+    Configuration.setConfiguration(config)
+  }
+}
diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala
index 5738307095933..42d64873c44d9 100644
--- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala
+++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala
@@ -37,8 +37,6 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationSuite {
 
     override def getJdbcUrl(ip: String, port: Int): String =
       s"jdbc:sqlserver://$ip:$port;user=sa;password=Sapass123;"
-
-    override def getStartupProcessName: Option[String] = None
   }
 
   override def dataPreparation(conn: Connection): Unit = {
diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala
index bba1b5275269b..4cbcb59e02de1 100644
--- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala
+++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala
@@ -35,7 +35,6 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationSuite {
     override val jdbcPort: Int = 3306
     override def getJdbcUrl(ip: String, port: Int): String =
       s"jdbc:mysql://$ip:$port/mysql?user=root&password=rootpass"
-    override def getStartupProcessName: Option[String] = None
   }
 
   override def dataPreparation(conn: Connection): Unit = {
diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala
index 6faa888cf18ed..24c3adb9c0153 100644
--- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala
+++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala
@@ -66,7 +66,6 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with SharedSpark
     override val jdbcPort: Int = 1521
     override def getJdbcUrl(ip: String, port: Int): String =
       s"jdbc:oracle:thin:system/oracle@//$ip:$port/xe"
-    override def getStartupProcessName: Option[String] = None
   }
 
   override def dataPreparation(conn: Connection): Unit = {
diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala
index 599f00def0750..6611bc2d19ed8 100644
--- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala
+++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala
@@ -37,7 +37,6 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite {
     override val jdbcPort = 5432
     override def getJdbcUrl(ip: String, port: Int): String =
       s"jdbc:postgresql://$ip:$port/postgres?user=postgres&password=rootpass"
-    override def getStartupProcessName: Option[String] = None
   }
 
   override def dataPreparation(conn: Connection): Unit = {
diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresKrbIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresKrbIntegrationSuite.scala
index 721a4882b986a..adf30fbdc1e12 100644
--- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresKrbIntegrationSuite.scala
+++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresKrbIntegrationSuite.scala
@@ -17,15 +17,11 @@
 
 package org.apache.spark.sql.jdbc
 
-import java.sql.Connection
-import java.util.Properties
 import javax.security.auth.login.Configuration
 
 import com.spotify.docker.client.messages.{ContainerConfig, HostConfig}
 
-import org.apache.spark.sql.Row
-import org.apache.spark.sql.execution.datasources.jdbc.connection.PostgresConnectionProvider
-import org.apache.spark.sql.types.StringType
+import org.apache.spark.sql.execution.datasources.jdbc.connection.SecureConnectionProvider
 import org.apache.spark.tags.DockerTest
 
 @DockerTest
@@ -44,86 +40,22 @@ class PostgresKrbIntegrationSuite extends DockerKrbJDBCIntegrationSuite {
     override def getJdbcUrl(ip: String, port: Int): String =
       s"jdbc:postgresql://$ip:$port/postgres?user=$principal&gsslib=gssapi"
 
-    override def getStartupProcessName: Option[String] = None
-
     override def beforeContainerStart(
         hostConfigBuilder: HostConfig.Builder,
         containerConfigBuilder: ContainerConfig.Builder): Unit = {
       def replaceIp(s: String): String = s.replace("__IP_ADDRESS_REPLACE_ME__", dockerIp)
-      copyExecutableResource("postgres_krb_setup.sh", workDir, replaceIp)
+      copyExecutableResource("postgres_krb_setup.sh", initDbDir, replaceIp)
 
       hostConfigBuilder.appendBinds(
-        HostConfig.Bind.from(workDir.getAbsolutePath)
+        HostConfig.Bind.from(initDbDir.getAbsolutePath)
          .to("/docker-entrypoint-initdb.d").readOnly(true).build()
      )
    }
  }

  override protected def setAuthentication(keytabFile: String, principal: String): Unit = {
-    val config = new PostgresConnectionProvider.PGJDBCConfiguration(
+    val config = new SecureConnectionProvider.JDBCConfiguration(
      Configuration.getConfiguration, "pgjdbc", keytabFile, principal)
    Configuration.setConfiguration(config)
  }
-
-  override def dataPreparation(conn: Connection): Unit = {
-    conn.prepareStatement("CREATE DATABASE foo").executeUpdate()
-    conn.setCatalog("foo")
-    conn.prepareStatement("CREATE TABLE bar (c0 text)").executeUpdate()
-    conn.prepareStatement("INSERT INTO bar VALUES ('hello')").executeUpdate()
-  }
-
-  test("Basic read test in query option") {
-    // This makes sure Spark must do authentication
-    Configuration.setConfiguration(null)
-
-    val expectedResult = Set("hello").map(Row(_))
-
-    val query = "SELECT c0 FROM bar"
-    // query option to pass on the query string.
-    val df = spark.read.format("jdbc")
-      .option("url", jdbcUrl)
-      .option("keytab", keytabFullPath)
-      .option("principal", principal)
-      .option("query", query)
-      .load()
-    assert(df.collect().toSet === expectedResult)
-  }
-
-  test("Basic read test in create table path") {
-    // This makes sure Spark must do authentication
-    Configuration.setConfiguration(null)
-
-    val expectedResult = Set("hello").map(Row(_))
-
-    val query = "SELECT c0 FROM bar"
-    // query option in the create table path.
-    sql(
-      s"""
-         |CREATE OR REPLACE TEMPORARY VIEW queryOption
-         |USING org.apache.spark.sql.jdbc
-         |OPTIONS (url '$jdbcUrl', query '$query', keytab '$keytabFullPath', principal '$principal')
-       """.stripMargin.replaceAll("\n", " "))
-    assert(sql("select c0 from queryOption").collect().toSet === expectedResult)
-  }
-
-  test("Basic write test") {
-    // This makes sure Spark must do authentication
-    Configuration.setConfiguration(null)
-
-    val props = new Properties
-    props.setProperty("keytab", keytabFullPath)
-    props.setProperty("principal", principal)
-
-    val tableName = "write_test"
-    sqlContext.createDataFrame(Seq(("foo", "bar")))
-      .write.jdbc(jdbcUrl, tableName, props)
-    val df = sqlContext.read.jdbc(jdbcUrl, tableName, props)
-
-    val schema = df.schema
-    assert(schema.map(_.dataType).toSeq === Seq(StringType, StringType))
-    val rows = df.collect()
-    assert(rows.length === 1)
-    assert(rows(0).getString(0) === "foo")
-    assert(rows(0).getString(1) === "bar")
-  }
 }
diff --git a/pom.xml b/pom.xml
index cc48ee794ea04..cd85db6e03264 100644
--- a/pom.xml
+++ b/pom.xml
@@ -951,6 +951,12 @@
         <version>5.1.38</version>
         <scope>test</scope>
       </dependency>
+      <dependency>
+        <groupId>org.mariadb.jdbc</groupId>
+        <artifactId>mariadb-java-client</artifactId>
+        <version>2.5.4</version>
+        <scope>test</scope>
+      </dependency>
       <dependency>
         <groupId>org.postgresql</groupId>
         <artifactId>postgresql</artifactId>
diff --git a/sql/core/pom.xml b/sql/core/pom.xml
index c95fe3ce1c120..e97c7fd3280be 100644
--- a/sql/core/pom.xml
+++ b/sql/core/pom.xml
@@ -131,8 +131,8 @@
       <scope>test</scope>
     </dependency>
    <dependency>
-      <groupId>mysql</groupId>
-      <artifactId>mysql-connector-java</artifactId>
+      <groupId>org.mariadb.jdbc</groupId>
+      <artifactId>mariadb-java-client</artifactId>
       <scope>test</scope>
     </dependency>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/ConnectionProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/ConnectionProvider.scala
index ccaff0d6ca7d4..c864f1f52fcce 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/ConnectionProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/ConnectionProvider.scala
@@ -28,6 +28,9 @@ import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
  * the parameters.
  */
 private[jdbc] trait ConnectionProvider {
+  /**
+   * Opens a connection to the database.
+   */
   def getConnection(): Connection
 }
 
@@ -43,6 +46,10 @@ private[jdbc] object ConnectionProvider extends Logging {
       logDebug("Postgres connection provider found")
       new PostgresConnectionProvider(driver, options)
 
+    case MariaDBConnectionProvider.driverClass =>
+      logDebug("MariaDB connection provider found")
+      new MariaDBConnectionProvider(driver, options)
+
     case _ =>
       throw new IllegalArgumentException(s"Driver ${options.driverClass} does not support " +
         "Kerberos authentication")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/MariaDBConnectionProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/MariaDBConnectionProvider.scala
new file mode 100644
index 0000000000000..eb2f0f78022ba
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/MariaDBConnectionProvider.scala
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.jdbc.connection
+
+import java.sql.Driver
+import javax.security.auth.login.Configuration
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
+
+private[jdbc] class MariaDBConnectionProvider(driver: Driver, options: JDBCOptions)
+  extends SecureConnectionProvider(driver, options) {
+  override val appEntry: String = {
+    "Krb5ConnectorContext"
+  }
+
+  override def setAuthenticationConfigIfNeeded(): Unit = {
+    val parent = Configuration.getConfiguration
+    val configEntry = parent.getAppConfigurationEntry(appEntry)
+    /**
+     * Two things to note here:
+     * 1. MariaDB doesn't support configuring the JAAS application name
+     * 2. MariaDB sets a default JAAS config if "java.security.auth.login.config" is not set
+     */
+    val entryUsesKeytab = configEntry != null &&
+      configEntry.exists(_.getOptions().get("useKeyTab") == "true")
+    if (configEntry == null || configEntry.isEmpty || !entryUsesKeytab) {
+      val config = new SecureConnectionProvider.JDBCConfiguration(
+        parent, appEntry, options.keytab, options.principal)
+      logDebug("Adding database specific security configuration")
+      Configuration.setConfiguration(config)
+    }
+  }
+}
+
+private[sql] object MariaDBConnectionProvider {
+  val driverClass = "org.mariadb.jdbc.Driver"
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/PostgresConnectionProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/PostgresConnectionProvider.scala
index e793c4dfd780e..14911fc75ebc1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/PostgresConnectionProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/PostgresConnectionProvider.scala
@@ -17,66 +17,32 @@
 
 package org.apache.spark.sql.execution.datasources.jdbc.connection
 
-import java.sql.{Connection, Driver}
+import java.sql.Driver
 import java.util.Properties
-import javax.security.auth.login.{AppConfigurationEntry, Configuration}
-
-import scala.collection.JavaConverters._
+import javax.security.auth.login.Configuration
 
 import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
-import org.apache.spark.sql.execution.datasources.jdbc.connection.PostgresConnectionProvider.PGJDBCConfiguration
-import org.apache.spark.util.SecurityUtils
 
 private[jdbc] class PostgresConnectionProvider(driver: Driver, options: JDBCOptions)
-  extends BasicConnectionProvider(driver, options) {
-  val appEntry: String = {
+  extends SecureConnectionProvider(driver, options) {
+  override val appEntry: String = {
     val parseURL = driver.getClass.getMethod("parseURL", classOf[String], classOf[Properties])
     val properties = parseURL.invoke(driver, options.url, null).asInstanceOf[Properties]
     properties.getProperty("jaasApplicationName", "pgjdbc")
   }
 
-  def setAuthenticationConfigIfNeeded(): Unit = {
+  override def setAuthenticationConfigIfNeeded(): Unit = {
     val parent = Configuration.getConfiguration
     val configEntry = parent.getAppConfigurationEntry(appEntry)
     if (configEntry == null || configEntry.isEmpty) {
-      val config = new PGJDBCConfiguration(parent, appEntry, options.keytab, options.principal)
+      val config = new SecureConnectionProvider.JDBCConfiguration(
+        parent, appEntry, options.keytab, options.principal)
+      logDebug("Adding database specific security configuration")
       Configuration.setConfiguration(config)
     }
   }
-
-  override def getConnection(): Connection = {
-    setAuthenticationConfigIfNeeded()
-    super.getConnection()
-  }
 }
 
 private[sql] object PostgresConnectionProvider {
-  class PGJDBCConfiguration(
-      parent: Configuration,
-      appEntry: String,
-      keytab: String,
-      principal: String) extends Configuration {
-    private val entry =
-      new AppConfigurationEntry(
-        SecurityUtils.getKrb5LoginModuleName(),
-        AppConfigurationEntry.LoginModuleControlFlag.REQUIRED,
-        Map[String, Object](
-          "useTicketCache" -> "false",
-          "useKeyTab" -> "true",
-          "keyTab" -> keytab,
-          "principal" -> principal,
-          "debug" -> "true"
-        ).asJava
-      )
-
-    override def getAppConfigurationEntry(name: String): Array[AppConfigurationEntry] = {
-      if (name.equals(appEntry)) {
-        Array(entry)
-      } else {
-        parent.getAppConfigurationEntry(name)
-      }
-    }
-  }
-
   val driverClass = "org.postgresql.Driver"
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/SecureConnectionProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/SecureConnectionProvider.scala
new file mode 100644
index 0000000000000..ff192d71e6f33
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/SecureConnectionProvider.scala
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.jdbc.connection
+
+import java.sql.{Connection, Driver}
+import javax.security.auth.login.{AppConfigurationEntry, Configuration}
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
+import org.apache.spark.util.SecurityUtils
+
+private[jdbc] abstract class SecureConnectionProvider(driver: Driver, options: JDBCOptions)
+  extends BasicConnectionProvider(driver, options) with Logging {
+  override def getConnection(): Connection = {
+    setAuthenticationConfigIfNeeded()
+    super.getConnection()
+  }
+
+  /**
+   * Returns the JAAS application name. This is sometimes configurable at the JDBC driver level.
+   */
+  val appEntry: String
+
+  /**
+   * Sets the database-specific authentication configuration when needed. If the configuration is
+   * already set, later calls must be no-ops.
+   */
+  def setAuthenticationConfigIfNeeded(): Unit
+}
+
+object SecureConnectionProvider {
+  class JDBCConfiguration(
+      parent: Configuration,
+      appEntry: String,
+      keytab: String,
+      principal: String) extends Configuration {
+    val entry =
+      new AppConfigurationEntry(
+        SecurityUtils.getKrb5LoginModuleName(),
+        AppConfigurationEntry.LoginModuleControlFlag.REQUIRED,
+        Map[String, Object](
+          "useTicketCache" -> "false",
+          "useKeyTab" -> "true",
+          "keyTab" -> keytab,
+          "principal" -> principal,
+          "debug" -> "true"
+        ).asJava
+      )
+
+    override def getAppConfigurationEntry(name: String): Array[AppConfigurationEntry] = {
+      if (name.equals(appEntry)) {
+        Array(entry)
+      } else {
+        parent.getAppConfigurationEntry(name)
+      }
+    }
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/ConnectionProviderSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/ConnectionProviderSuiteBase.scala
new file mode 100644
index 0000000000000..d18a3088c4f2f
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/ConnectionProviderSuiteBase.scala
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.jdbc.connection
+
+import java.sql.{Driver, DriverManager}
+import javax.security.auth.login.Configuration
+
+import scala.collection.JavaConverters._
+
+import org.scalatest.BeforeAndAfterEach
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.execution.datasources.jdbc.{DriverRegistry, JDBCOptions}
+
+abstract class ConnectionProviderSuiteBase extends SparkFunSuite with BeforeAndAfterEach {
+  protected def registerDriver(driverClass: String): Driver = {
+    DriverRegistry.register(driverClass)
+    DriverManager.getDrivers.asScala.collectFirst {
+      case d if d.getClass.getCanonicalName == driverClass => d
+    }.get
+  }
+
+  protected def options(url: String) = new JDBCOptions(Map[String, String](
+    JDBCOptions.JDBC_URL -> url,
+    JDBCOptions.JDBC_TABLE_NAME -> "table",
+    JDBCOptions.JDBC_KEYTAB -> "/path/to/keytab",
+    JDBCOptions.JDBC_PRINCIPAL -> "principal"
+  ))
+
+  override def afterEach(): Unit = {
+    try {
+      Configuration.setConfiguration(null)
+    } finally {
+      super.afterEach()
+    }
+  }
+
+  protected def testSecureConnectionProvider(provider: SecureConnectionProvider): Unit = {
+    // Make sure no authentication for the database is set
+    assert(Configuration.getConfiguration.getAppConfigurationEntry(provider.appEntry) == null)
+
+    // Make sure the first call sets authentication properly
+    val savedConfig = Configuration.getConfiguration
+    provider.setAuthenticationConfigIfNeeded()
+    val config = Configuration.getConfiguration
+    assert(savedConfig != config)
+    val appEntry = config.getAppConfigurationEntry(provider.appEntry)
+    assert(appEntry != null)
+
+    // Make sure a second call is not modifying the existing authentication
+    provider.setAuthenticationConfigIfNeeded()
+    assert(config.getAppConfigurationEntry(provider.appEntry) === appEntry)
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/MariaDBConnectionProviderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/MariaDBConnectionProviderSuite.scala
new file mode 100644
index 0000000000000..70cad2097eb43
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/MariaDBConnectionProviderSuite.scala
@@ -0,0 +1,27 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.jdbc.connection
+
+class MariaDBConnectionProviderSuite extends ConnectionProviderSuiteBase {
+  test("setAuthenticationConfigIfNeeded must set authentication if not set") {
+    val driver = registerDriver(MariaDBConnectionProvider.driverClass)
+    val provider = new MariaDBConnectionProvider(driver, options("jdbc:mysql://localhost/mysql"))
+
+    testSecureConnectionProvider(provider)
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/PostgresConnectionProviderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/PostgresConnectionProviderSuite.scala
index 59ff1c79bd064..8cef7652f9c54 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/PostgresConnectionProviderSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/PostgresConnectionProviderSuite.scala
@@ -17,69 +17,16 @@
 
 package org.apache.spark.sql.execution.datasources.jdbc.connection
 
-import java.sql.{Driver, DriverManager}
-import javax.security.auth.login.Configuration
-
-import scala.collection.JavaConverters._
-
-import org.scalatest.BeforeAndAfterEach
-
-import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.execution.datasources.jdbc.{DriverRegistry, JDBCOptions}
-
-class PostgresConnectionProviderSuite extends SparkFunSuite with BeforeAndAfterEach {
-  private def options(url: String) = new JDBCOptions(Map[String, String](
-    JDBCOptions.JDBC_URL -> url,
-    JDBCOptions.JDBC_TABLE_NAME -> "table",
-    JDBCOptions.JDBC_KEYTAB -> "/path/to/keytab",
-    JDBCOptions.JDBC_PRINCIPAL -> "principal"
-  ))
-
-  override def afterEach(): Unit = {
-    try {
-      Configuration.setConfiguration(null)
-    } finally {
-      super.afterEach()
-    }
-  }
-
+class PostgresConnectionProviderSuite extends ConnectionProviderSuiteBase {
   test("setAuthenticationConfigIfNeeded must set authentication if not set") {
-    DriverRegistry.register(PostgresConnectionProvider.driverClass)
-    val driver = DriverManager.getDrivers.asScala.collectFirst {
-      case d if d.getClass.getCanonicalName == PostgresConnectionProvider.driverClass => d
-    }.get
+    val driver = registerDriver(PostgresConnectionProvider.driverClass)
     val defaultProvider = new PostgresConnectionProvider(
       driver, options("jdbc:postgresql://localhost/postgres"))
     val customProvider = new PostgresConnectionProvider(
       driver, options(s"jdbc:postgresql://localhost/postgres?jaasApplicationName=custompgjdbc"))
 
     assert(defaultProvider.appEntry !== customProvider.appEntry)
-
-    // Make sure no authentication for postgres is set
-    assert(Configuration.getConfiguration.getAppConfigurationEntry(
-      defaultProvider.appEntry) == null)
-    assert(Configuration.getConfiguration.getAppConfigurationEntry(
-      customProvider.appEntry) == null)
-
-    // Make sure the first call sets authentication properly
-    val savedConfig = Configuration.getConfiguration
-    defaultProvider.setAuthenticationConfigIfNeeded()
-    val defaultConfig = Configuration.getConfiguration
-    assert(savedConfig != defaultConfig)
-    val defaultAppEntry = defaultConfig.getAppConfigurationEntry(defaultProvider.appEntry)
-    assert(defaultAppEntry != null)
-    customProvider.setAuthenticationConfigIfNeeded()
-    val customConfig = Configuration.getConfiguration
-    assert(savedConfig != customConfig)
-    assert(defaultConfig != customConfig)
-    val customAppEntry = customConfig.getAppConfigurationEntry(customProvider.appEntry)
-    assert(customAppEntry != null)
-
-    // Make sure a second call is not modifying the existing authentication
-    defaultProvider.setAuthenticationConfigIfNeeded()
-    customProvider.setAuthenticationConfigIfNeeded()
-    assert(customConfig == Configuration.getConfiguration)
-    assert(defaultConfig.getAppConfigurationEntry(defaultProvider.appEntry) === defaultAppEntry)
-    assert(customConfig.getAppConfigurationEntry(customProvider.appEntry) === customAppEntry)
+    testSecureConnectionProvider(defaultProvider)
+    testSecureConnectionProvider(customProvider)
   }
 }