From c8294aea77daa51090f3403a5ac5056435b7ecd4 Mon Sep 17 00:00:00 2001
From: Vyacheslav Baranov
Date: Wed, 29 Apr 2015 17:46:35 +0300
Subject: [PATCH 1/3] [SPARK-6913] Fixed "No suitable driver found" when using
 JDBC driver added with SparkContext.addJar

---
 .../org/apache/spark/sql/jdbc/JDBCRDD.scala      |  2 +-
 .../apache/spark/sql/jdbc/JDBCRelation.scala     |  4 +-
 .../org/apache/spark/sql/jdbc/jdbc.scala         | 57 ++++++++++++++++++-
 3 files changed, 59 insertions(+), 4 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala
index f3b5455574d1a..3be185686bb28 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala
@@ -154,7 +154,7 @@ private[sql] object JDBCRDD extends Logging {
   def getConnector(driver: String, url: String, properties: Properties): () => Connection = {
     () => {
       try {
-        if (driver != null) Utils.getContextOrSparkClassLoader.loadClass(driver)
+        if (driver != null) DriverRegistry.register(driver)
       } catch {
         case e: ClassNotFoundException => {
           logWarning(s"Couldn't find class $driver", e);
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala
index 5f480083d5a49..d6b3fb3291a2e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala
@@ -100,7 +100,7 @@ private[sql] class DefaultSource extends RelationProvider {
     val upperBound = parameters.getOrElse("upperBound", null)
     val numPartitions = parameters.getOrElse("numPartitions", null)
 
-    if (driver != null) Utils.getContextOrSparkClassLoader.loadClass(driver)
+    if (driver != null) DriverRegistry.register(driver)
 
     if (partitionColumn != null
         && (lowerBound == null || upperBound == null || numPartitions == null)) {
@@ -136,7 +136,7 @@ private[sql] case class JDBCRelation(
   override val schema: StructType = JDBCRDD.resolveTable(url, table, properties)
 
   override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = {
-    val driver: String = DriverManager.getDriver(url).getClass.getCanonicalName
+    val driver: String = DriverRegistry.getDriverClassName(url)
     JDBCRDD.scanTable(
       sqlContext.sparkContext,
       schema,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala
index d4e0abc040bc6..ccbf652192d8e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala
@@ -17,10 +17,14 @@
 
 package org.apache.spark.sql
 
-import java.sql.{Connection, DriverManager, PreparedStatement}
+import java.sql.{Connection, Driver, DriverManager, DriverPropertyInfo, PreparedStatement}
+import java.util.Properties
+
+import scala.collection.concurrent.TrieMap
 
 import org.apache.spark.Logging
 import org.apache.spark.sql.types._
+import org.apache.spark.util.Utils
 
 package object jdbc {
   private[sql] object JDBCWriteDetails extends Logging {
@@ -179,4 +183,55 @@ package object jdbc {
       }
     }
   }
+
+  private [sql] case class DriverWrapper(wrapped: Driver) extends Driver {
+    override def acceptsURL(url: String): Boolean = wrapped.acceptsURL(url)
+
+    override def jdbcCompliant(): Boolean = wrapped.jdbcCompliant()
+
+    override def getPropertyInfo(url: String, info: Properties): Array[DriverPropertyInfo] = {
+      wrapped.getPropertyInfo(url, info)
+    }
+
+    override def getMinorVersion: Int = wrapped.getMinorVersion
+
+    override def getParentLogger: java.util.logging.Logger = wrapped.getParentLogger
+
+    override def connect(url: String, info: Properties): Connection = wrapped.connect(url, info)
+
+    override def getMajorVersion: Int = wrapped.getMajorVersion
+  }
+
+  /**
+   * java.sql.DriverManager is always loaded by the bootstrap classloader,
+   * so it can't load JDBC drivers that are accessible only to Spark's ClassLoader.
+   *
+   * To solve the problem, drivers from user-supplied jars are wrapped in a
+   * thin DriverWrapper.
+   */
+  private [sql] object DriverRegistry extends Logging {
+
+    val wrapperMap: TrieMap[String, DriverWrapper] = TrieMap.empty
+
+    def register(className: String): Unit = {
+      val cls = Utils.getContextOrSparkClassLoader.loadClass(className)
+      if (cls.getClassLoader == null) {
+        logTrace(s"$className has been loaded with bootstrap ClassLoader, wrapper is not required")
+      } else if (wrapperMap.get(className).isDefined) {
+        logTrace(s"Wrapper for $className already exists")
+      } else {
+        val wrapper = new DriverWrapper(cls.newInstance().asInstanceOf[Driver])
+        if (wrapperMap.putIfAbsent(className, wrapper).isEmpty) {
+          DriverManager.registerDriver(wrapper)
+          logTrace(s"Wrapper for $className registered")
+        }
+      }
+    }
+
+    def getDriverClassName(url: String): String = DriverManager.getDriver(url) match {
+      case DriverWrapper(wrapped) => wrapped.getClass.getCanonicalName
+      case driver => driver.getClass.getCanonicalName
+    }
+  }
+
 } // package object jdbc
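
Note: java.sql.DriverManager only hands out a driver whose class the calling
code's classloader can load. A driver that lives only in a jar added via
SparkContext.addJar or --jars is visible to Spark's context classloader but not
to the caller DriverManager checks against, so loadClass succeeds while
getConnection still fails with "No suitable driver found". DriverWrapper
sidesteps this: the wrapper class ships with Spark itself, so it always passes
the check, and it delegates every call to the wrapped driver.

A minimal sketch of the user-facing scenario this patch fixes, using the Spark
1.3-era load API; the jar name, URL, and table below are hypothetical
placeholders, not values taken from the patch:

    // spark-submit --jars postgresql-9.4-1201.jdbc41.jar ...
    // (sqlContext as provided by spark-shell)
    val df = sqlContext.load("jdbc", Map(
      "url"     -> "jdbc:postgresql://db-host:5432/mydb",
      "driver"  -> "org.postgresql.Driver",  // routed through DriverRegistry.register
      "dbtable" -> "my_table"))
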
From b2a727c7ed4412c9b9d42b75c5a52369bd9ac43e Mon Sep 17 00:00:00 2001
From: Vyacheslav Baranov
Date: Wed, 29 Apr 2015 19:54:46 +0300
Subject: [PATCH 2/3] [SPARK-6913] Fixed thread race on driver registration

---
 .../org/apache/spark/sql/jdbc/jdbc.scala | 21 +++++++++++++------
 1 file changed, 15 insertions(+), 6 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala
index ccbf652192d8e..5ddc6806650f7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala
@@ -19,8 +19,9 @@ package org.apache.spark.sql
 
 import java.sql.{Connection, Driver, DriverManager, DriverPropertyInfo, PreparedStatement}
 import java.util.Properties
+import java.util.concurrent.locks.ReentrantLock
 
-import scala.collection.concurrent.TrieMap
+import scala.collection.mutable
 
 import org.apache.spark.Logging
 import org.apache.spark.sql.types._
@@ -211,7 +212,9 @@ package object jdbc {
    */
   private [sql] object DriverRegistry extends Logging {
 
-    val wrapperMap: TrieMap[String, DriverWrapper] = TrieMap.empty
+    val wrapperMap: mutable.Map[String, DriverWrapper] = mutable.Map.empty
+
+    val lock = new ReentrantLock()
 
     def register(className: String): Unit = {
       val cls = Utils.getContextOrSparkClassLoader.loadClass(className)
@@ -220,10 +223,16 @@ package object jdbc {
       } else if (wrapperMap.get(className).isDefined) {
         logTrace(s"Wrapper for $className already exists")
       } else {
-        val wrapper = new DriverWrapper(cls.newInstance().asInstanceOf[Driver])
-        if (wrapperMap.putIfAbsent(className, wrapper).isEmpty) {
-          DriverManager.registerDriver(wrapper)
-          logTrace(s"Wrapper for $className registered")
+        lock.lock()
+        try {
+          if (wrapperMap.get(className).isEmpty) {
+            val wrapper = new DriverWrapper(cls.newInstance().asInstanceOf[Driver])
+            DriverManager.registerDriver(wrapper)
+            wrapperMap(className) = wrapper
+            logTrace(s"Wrapper for $className registered")
+          }
+        } finally {
+          lock.unlock()
         }
       }
     }
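
Note: the race fixed here is one of publication order. In patch 1,
wrapperMap.putIfAbsent published the wrapper before DriverManager.registerDriver
had run, so a concurrent register call could see the existing map entry, log
"already exists", and let its caller proceed to getConnection against a driver
that was not yet registered. The lock plus the re-check inside the critical
section serialize creation and registration, and the map entry is written only
after registration succeeds. A standalone sketch of that check/lock/re-check
shape, with hypothetical names rather than the patched Spark code (a TrieMap
keeps the unsynchronized fast-path read safe, and object-level synchronized is
used here, which patch 3 below also adopts in place of the ReentrantLock):

    import java.sql.{Driver, DriverManager}
    import scala.collection.concurrent.TrieMap

    object OnceRegistry {
      private val registered = TrieMap.empty[String, Driver]

      // Register each driver class at most once, even under concurrent callers.
      def registerOnce(className: String)(make: => Driver): Unit = {
        if (!registered.contains(className)) {      // cheap fast-path check
          synchronized {
            if (!registered.contains(className)) {  // re-check under the lock
              val driver = make
              DriverManager.registerDriver(driver)  // side effect runs exactly once
              registered(className) = driver        // publish only after registering
            }
          }
        }
      }
    }
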
$className registered") + } + } finally { + lock.unlock() } } } From 510c43f691f4b3901ca877585e25b28121a0cf36 Mon Sep 17 00:00:00 2001 From: Vyacheslav Baranov Date: Thu, 30 Apr 2015 10:01:07 +0300 Subject: [PATCH 3/3] [SPARK-6913] Fixed review comments --- .../scala/org/apache/spark/sql/jdbc/jdbc.scala | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala index 5ddc6806650f7..ae9af1eabe68e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql import java.sql.{Connection, Driver, DriverManager, DriverPropertyInfo, PreparedStatement} import java.util.Properties -import java.util.concurrent.locks.ReentrantLock import scala.collection.mutable @@ -185,7 +184,7 @@ package object jdbc { } - private [sql] case class DriverWrapper(wrapped: Driver) extends Driver { + private [sql] class DriverWrapper(val wrapped: Driver) extends Driver { override def acceptsURL(url: String): Boolean = wrapped.acceptsURL(url) override def jdbcCompliant(): Boolean = wrapped.jdbcCompliant() @@ -212,9 +211,7 @@ package object jdbc { */ private [sql] object DriverRegistry extends Logging { - val wrapperMap: mutable.Map[String, DriverWrapper] = mutable.Map.empty - - val lock = new ReentrantLock() + private val wrapperMap: mutable.Map[String, DriverWrapper] = mutable.Map.empty def register(className: String): Unit = { val cls = Utils.getContextOrSparkClassLoader.loadClass(className) @@ -223,22 +220,19 @@ package object jdbc { } else if (wrapperMap.get(className).isDefined) { logTrace(s"Wrapper for $className already exists") } else { - lock.lock() - try { + synchronized { if (wrapperMap.get(className).isEmpty) { val wrapper = new DriverWrapper(cls.newInstance().asInstanceOf[Driver]) DriverManager.registerDriver(wrapper) wrapperMap(className) = wrapper logTrace(s"Wrapper for $className registered") } - } finally { - lock.unlock() } } } def getDriverClassName(url: String): String = DriverManager.getDriver(url) match { - case DriverWrapper(wrapped) => wrapped.getClass.getCanonicalName + case wrapper: DriverWrapper => wrapper.wrapped.getClass.getCanonicalName case driver => driver.getClass.getCanonicalName } }