From 315e972a1d36a316f40a7f3e073ee5ead7b3db54 Mon Sep 17 00:00:00 2001
From: tribbloid
Date: Tue, 15 Sep 2015 19:02:57 -0400
Subject: [PATCH] test case demonstrating SPARK-10625: Spark SQL JDBC
 read/write is unable to handle JDBC Drivers that adds unserializable objects
 into connection properties

add one more unit test

fix JDBCRelation & DataFrameWriter to pass all tests

revise scala style

put driver replacement code into a shared function

fix styling

upgrade to master and resolve all related issues
---
 .../datasources/jdbc/DriverRegistry.scala     |  1 -
 .../execution/datasources/jdbc/JDBCRDD.scala  |  2 +-
 .../datasources/jdbc/JDBCRelation.scala       | 22 ++++++-
 .../org/apache/spark/sql/jdbc/JDBCSuite.scala |  7 +++
 .../spark/sql/jdbc/JDBCWriteSuite.scala       |  8 +++
 .../sql/jdbc/UnserializableDriverHelper.scala | 58 +++++++++++++++++++
 6 files changed, 93 insertions(+), 5 deletions(-)
 create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/jdbc/UnserializableDriverHelper.scala

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverRegistry.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverRegistry.scala
index 7ccd61ed469e9..72d155cad846f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverRegistry.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverRegistry.scala
@@ -57,4 +57,3 @@ object DriverRegistry extends Logging {
     case driver => driver.getClass.getCanonicalName
   }
 }
-
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
index f9b72597dd2a9..b1fe66d9a73d6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
@@ -185,7 +185,7 @@ private[sql] object JDBCRDD extends Logging {
         case e: ClassNotFoundException =>
           logWarning(s"Couldn't find class $driver", e)
       }
-      DriverManager.getConnection(url, properties)
+      DriverManager.getConnection(url, JDBCRelation.getEffectiveProperties(properties))
     }
   }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
index f9300dc2cb529..195a5e5a97557 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
@@ -75,6 +75,18 @@ private[sql] object JDBCRelation {
     }
     ans.toArray
   }
+
+  def getEffectiveProperties(
+      connectionProperties: Properties,
+      extraOptions: scala.collection.Map[String, String] = Map()): Properties = {
+    val props = new Properties()
+    extraOptions.foreach { case (key, value) =>
+      props.put(key, value)
+    }
+    // connectionProperties should override settings in extraOptions
+    props.putAll(connectionProperties)
+    props
+  }
 }
 
 private[sql] case class JDBCRelation(
@@ -88,7 +100,11 @@ private[sql] case class JDBCRelation(
 
   override val needConversion: Boolean = false
 
-  override val schema: StructType = JDBCRDD.resolveTable(url, table, properties)
+  override val schema: StructType = JDBCRDD.resolveTable(
+    url,
+    table,
+    JDBCRelation.getEffectiveProperties(properties)
+  )
 
   override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = {
     val driver: String = DriverRegistry.getDriverClassName(url)
@@ -98,7 +114,7 @@ private[sql] case class JDBCRelation(
       schema,
       driver,
       url,
-      properties,
+      JDBCRelation.getEffectiveProperties(properties),
       table,
       requiredColumns,
       filters,
@@ -108,6 +124,6 @@ private[sql] case class JDBCRelation(
   override def insert(data: DataFrame, overwrite: Boolean): Unit = {
     data.write
       .mode(if (overwrite) SaveMode.Overwrite else SaveMode.Append)
-      .jdbc(url, table, properties)
+      .jdbc(url, table, JDBCRelation.getEffectiveProperties(properties))
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
index d530b1a469ce2..2a5fc9bc46228 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
@@ -484,4 +484,11 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext
     assert(h2.getTableExistsQuery(table) == defaultQuery)
     assert(derby.getTableExistsQuery(table) == defaultQuery)
   }
+
+  test("Basic API with Unserializable Driver Properties") {
+    UnserializableDriverHelper.replaceDriverDuring {
+      assert(sqlContext.read.jdbc(
+        urlWithUserAndPass, "TEST.PEOPLE", new Properties).collect().length === 3)
+    }
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
index e23ee6693133b..3564e1823cb28 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
@@ -151,4 +151,12 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter {
     assert(2 === sqlContext.read.jdbc(url1, "TEST.PEOPLE1", properties).count)
     assert(2 === sqlContext.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length)
   }
+
+  test("INSERT to JDBC Datasource with Unserializable Driver Properties") {
+    UnserializableDriverHelper.replaceDriverDuring {
+      sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE")
+      assert(2 === sqlContext.read.jdbc(url1, "TEST.PEOPLE1", properties).count)
+      assert(2 === sqlContext.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length)
+    }
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/UnserializableDriverHelper.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/UnserializableDriverHelper.scala
new file mode 100644
index 0000000000000..737ba8366c605
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/UnserializableDriverHelper.scala
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.jdbc
+
+import java.sql.{DriverManager, Connection}
+import java.util.Properties
+import java.util.logging.Logger
+
+object UnserializableDriverHelper {
+
+  def replaceDriverDuring[T](f: => T): T = {
+    import scala.collection.JavaConverters._
+
+    object UnserializableH2Driver extends org.h2.Driver {
+
+      override def connect(url: String, info: Properties): Connection = {
+
+        val result = super.connect(url, info)
+        info.put("unserializableDriver", this)
+        result
+      }
+
+      override def getParentLogger: Logger = null
+    }
+
+    val oldDrivers = DriverManager.getDrivers.asScala.filter(_.acceptsURL("jdbc:h2:")).toSeq
+    oldDrivers.foreach{
+      DriverManager.deregisterDriver
+    }
+    DriverManager.registerDriver(UnserializableH2Driver)
+
+    val result = try {
+      f
+    }
+    finally {
+      DriverManager.deregisterDriver(UnserializableH2Driver)
+      oldDrivers.foreach{
+        DriverManager.registerDriver
+      }
+    }
+    result
+  }
+}
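
Note (not part of the patch): the sketch below illustrates why the defensive copy made by JDBCRelation.getEffectiveProperties matters. A JDBC driver such as the H2 stand-in above may stash a reference to itself in the java.util.Properties it is handed during connect(); if that same Properties instance is later captured in the JDBCRDD closure, Java serialization fails. Copying the entries into a fresh Properties before calling DriverManager keeps the caller's instance clean. EffectivePropertiesSketch, PollutedDriver and copyOf are hypothetical names used only here; the copy loop mirrors, but is not, the patched Spark code.

// Hypothetical, self-contained illustration; only JDK classes are used.
import java.io.{ByteArrayOutputStream, NotSerializableException, ObjectOutputStream}
import java.util.Properties

import scala.collection.JavaConverters._

object EffectivePropertiesSketch {

  // Stand-in for a JDBC driver instance; like most real drivers, it is not Serializable.
  class PollutedDriver

  // Mirrors the copying idea of getEffectiveProperties: hand the driver a fresh
  // Properties object so the caller's instance cannot be polluted.
  def copyOf(connectionProperties: Properties): Properties = {
    val props = new Properties()
    connectionProperties.asScala.foreach { case (key, value) => props.put(key, value) }
    props
  }

  // True if the object survives Java serialization, which Spark requires for
  // anything captured in an RDD closure.
  private def serializes(o: AnyRef): Boolean =
    try {
      new ObjectOutputStream(new ByteArrayOutputStream()).writeObject(o)
      true
    } catch {
      case _: NotSerializableException => false
    }

  def main(args: Array[String]): Unit = {
    val original = new Properties()
    original.put("user", "testUser")

    // The driver mutates the Properties it is given during connect(), just as
    // UnserializableH2Driver does in the test helper above.
    val handedToDriver = copyOf(original)
    handedToDriver.put("unserializableDriver", new PollutedDriver)

    assert(!serializes(handedToDriver)) // the polluted copy can no longer be serialized
    assert(serializes(original))        // the caller's Properties stays serializable
    println("defensive copy keeps the original Properties serializable")
  }
}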